物理超好玩

  • 首页
  • Noip学习助手
  • 书法字典APP下载
  • 资源列表
  • 格式化代码
  • 习题答案
  • 关于
物理超好玩
真诚面对自己
  1. 首页
  2. 程序设计
  3. 正文

利用JavaCV+OpenCV的ANN_MLP神经网络训练识别MNIST手写数字

2021年01月22日 1478点热度 1人点赞 0条评论

文本目录

  • 利用JavaCV+OpenCV的ANN_MLP神经网络训练识别MNIST手写数字
    • MNIST手写数字数据集
    • 读取MNIST数据
    • 从MNIST生成训练图片Mat和标签Mat数据
      • 图片Mat的格式
      • 标签Mat数据格式
    • 创建ANN_MLP神网络训练数据
    • 正确率测试
    • 源代码下载

利用JavaCV+OpenCV的ANN_MLP神经网络训练识别MNIST手写数字

JavaCV是可以在java中使用OpenCV的一个库。OpenCV是一个跨平台的开源计算机视觉和机器学习软件库。白话就是一个处理图片和进行人工智能识别图片的一个软件库。

MNIST手写数字数据集

MNIST数据集(Mixed National Institute of Standards and Technology database)是美国国家标准与技术研究院收集整理的大型手写数字数据库,包含60,000个示例的训练集以及10,000个示例的测试集。

下载地址:http://yann.lecun.com/exdb/mnist/

做为许多神经网络学习的入门数据,一直没找到javaCV的相关例子。

用IDEA架设创建JAVACV的开发环境,请参考:在IDEA和Android Studio中用Gradle构建javacv开发环境

读取MNIST数据

MNIST数据有4个文件,分别为训练和测试的图片和标签。关于用java读取的方法可参考文章:

使用 Java 读取 MNIST 数据集

上面文章中介绍有MNIST文件的格式等信息,在这里就不再重复。

从MNIST生成训练图片Mat和标签Mat数据

JavaCV训练时的所有数据,都用Mat的形式提供。说白了就是一个float数组。注意神经网训练时,最好用float数据,MNIST数据集是一个byte数组,这里需要转换一下。

图片Mat的格式

x = new Mat(number, size, CvType.CV_32FC1);

number是样本数量,做为mat的行数

size是图片像素点数,即28*28。每个样本图片生成一个单行的数组放入Mat中

CV_32FC1是数据类型,为32位的float数据

完整的代码如下:

/**
 * 生成训练数据
 *
 * @param fileName the file of 'train' or 'test' about image
 * @return one row show a `picture`
 */
public static Mat getTrainData(String fileName) {
    Mat x = null;
    try (BufferedInputStream bin = new BufferedInputStream(new FileInputStream(fileName))) {
        byte[] bytes = new byte[4];
        bin.read(bytes, 0, 4);
        if (!"00000803".equals(bytesToHex(bytes))) {                        
            // 读取魔数
            throw new RuntimeException("Please select the correct file!");
        } else {

            bin.read(bytes, 0, 4);
            // 读取样本总数
            int number = Integer.parseInt(bytesToHex(bytes), 16);       
            bin.read(bytes, 0, 4);
            // 读取每行所含像素点数
            int xPixel = Integer.parseInt(bytesToHex(bytes), 16);          
            bin.read(bytes, 0, 4);
            // 读取每列所含像素点数
            int yPixel = Integer.parseInt(bytesToHex(bytes), 16);           

            int l = xPixel*yPixel;
            x = new Mat(number, l, CvType.CV_32FC1);
            FloatIndexer indexer = x.createIndexer();
            for (int i = 0; i < number; i++) {
                for(int j=0; j<l; j++){
                    indexer.put(i, j, bin.read());
                }
            }
        }
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    return x;
}

标签Mat数据格式

Mat x = new Mat(data.length, 10, CvType.CV_32FC1);

data.length是样本数量,做为mat的行数

10是每个标签的数据量,即为float[10]。每个标签成一个单行的数组放入Mat中

CV_32FC1是数据类型,为32位的float数据

完整的代码如下:

/**
 * 获取训练的标签
 * 格式要求,每个标签为一个 float[10]数组,放在Mat的一行中
 * @param fileName
 * @return
 */
public static Mat getTrainLabels(String fileName) {
    byte[] data = getLabels(fileName);
    Mat x = new Mat(data.length, 10, CvType.CV_32FC1);
    FloatIndexer indexer = x.createIndexer();
    for(int i=0; i<data.length; i++){
        byte b = (byte) data[i];
        for(int j=0; j<10; j++){
            if(j==b)
                indexer.put(i, j, 1);
            else
                indexer.put(i, j, 0);
        }
    }
    return x;
}

/**
 * 获取所有标签的数值
 *
 * @param fileName the file of 'train' or 'test' about label
 * @return
 */
public static byte[] getLabels(String fileName) {
    byte[] y = null;
    try (BufferedInputStream bin = new BufferedInputStream(new FileInputStream(fileName))) {
        byte[] bytes = new byte[4];
        bin.read(bytes, 0, 4);
        if (!"00000801".equals(bytesToHex(bytes))) {
            throw new RuntimeException("Please select the correct file!");
        } else {
            bin.read(bytes, 0, 4);
            int number = Integer.parseInt(bytesToHex(bytes), 16);
            y = new byte[number];

            byte c;
            for (int i = 0; i < number; i++) {
                c = (byte) bin.read();
                y[i] = c;
            }
        }
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    return y;
}

创建ANN_MLP神网络训练数据

创建了一个四层的神经网络,神经元个数分别为 { 28*28 , 512,  256,  10 } ,分别为:

输入层,对应着每个像素,所以是28*28

隐含层两个,神经元个数分别为 512 和 256

输出层,和训练的标签对应,神经元为10个,即数字 0123456789

具体代码如下:

/**
 * 训练数据
 * @param xml 要保存的数据文件名
 */
public static void train(String xml){
    opencv_core.Mat trainData = MnistRead.getTrainData(TRAIN_IMAGES_FILE);
    opencv_core.Mat lables = MnistRead.getTrainLabels(TRAIN_LABELS_FILE);

    opencv_ml.ANN_MLP mlp= opencv_ml.ANN_MLP.create();

    int image_cols = 28; //图片宽
    int image_rows = 28; //图片高
    int class_num = 10; //预测的结果,为 float[10] 数组
    /*
     * 神经网络层
     * */
    int[] layer={ image_cols*image_rows , 512, 256, class_num};
    opencv_core.Mat layerSizes=new opencv_core.Mat(1, layer.length, CV_32FC1);
    org.bytedeco.javacpp.indexer.FloatIndexer indexer = layerSizes.createIndexer();
    for(int i=0;i<layer.length;i++){
        indexer.put(i, layer[i]);
    }

    mlp.setLayerSizes(layerSizes);
    mlp.setActivationFunction(opencv_ml.ANN_MLP.SIGMOID_SYM);
    mlp.train(trainData, ROW_SAMPLE, lables);
    /*
     * 开始训练
     * */

    mlp.save(xml);
    mlp.clear();
    System.out.println("训练结束");

}

正确率测试

数据格式和训练时一样,就不做解释了,代码如下:

/**
 * 使用测试数据,测试识别率
 * @param xml 训练好的数据文件
 */
public static void test(String xml){
    opencv_ml.ANN_MLP ann = opencv_ml.ANN_MLP.load(xml);

    opencv_core.Mat predictData = MnistRead.getTrainData(TEST_IMAGES_FILE);
    byte[] predictLables = MnistRead.getLabels(TEST_LABELS_FILE);

    //正确计数
    int rc = 0;

    for(int i=0; i<predictData.rows(); i++){
        opencv_core.Mat sample = predictData.row(i);
        opencv_core.Mat predict = new opencv_core.Mat();
        ann.predict(sample, predict, UPDATE_MODEL);
        if(predictLables[i] == getMaxIndex(predict)){
            //预测正确
            rc++;
        }
    }

    //计算正确率
    double zql = rc*1.0/predictData.rows();
    System.out.println("正确率:" + zql);
}

源代码下载

本示例源代码已在GITEE上开源,大家可以免费下验证:

https://gitee.com/zizai/StudyJavaCV

 

相关文章:

  1. 在IDEA和Android Studio中用Gradle构建javacv开发环境
  2. 在JavaCV中合并两个Mat
  3. 用OpenCV的K-Means聚类对书法作品进行单字分割
  4. 格式混乱的百度文库复制文本格式化工具
  5. 用QT开发百度文库文本下载工具
  6. Idea中用gradle打包可执行的jar

订阅号“物理超好玩”
标签: ANN_MLP javacv MNIST OpenCV 神经网络
最后更新:2021年01月22日

坚持

真诚的面对自己的内心。 确立志向;全力准备;清净无扰,最终成功。 尊重自我,做自己最擅长的事情,做自己最喜欢的事情。

点赞
< 上一篇
下一篇 >

坚持

真诚的面对自己的内心。 确立志向;全力准备;清净无扰,最终成功。 尊重自我,做自己最擅长的事情,做自己最喜欢的事情。

分类
  • NOIP (1)
  • 习题讲解 (9)
  • 克服沉迷 (2)
  • 游戏危害 (1)
  • 程序设计 (10)
  • 软件作品 (2)
标签聚合
初中物理 电学 javacv OpenCV Idea 计算题 gradle 串联电路
最新 热点 随机
最新 热点 随机
中国游戏防沉迷简史 转移注意力 认识游戏的危害 U盘随身便携Git http服务器 IDEA用Gradle打包GUI Form为可执行的jar 运用浮力求密度解题思路
在JavaCV中合并两个Mat 已知两组数据列方程组解题补充习题 转移注意力 IDEA用Gradle打包GUI Form为可执行的jar 用QT开发百度文库文本下载工具 NOIP信息学奥赛视频教程
  • 利用JavaCV+OpenCV的ANN_MLP神经网络训练识别MNIST手写数字
    • MNIST手写数字数据集
    • 读取MNIST数据
    • 从MNIST生成训练图片Mat和标签Mat数据
      • 图片Mat的格式
      • 标签Mat数据格式
    • 创建ANN_MLP神网络训练数据
    • 正确率测试
    • 源代码下载

COPYRIGHT © 2021 物理超好玩. ALL RIGHTS RESERVED.

THEME KRATOS MADE BY VTROIS

豫ICP备16037997号-2