物理超好玩

  • 首页
  • Noip学习助手
  • 资源列表
  • 格式化代码
  • 习题答案
  • 创建块
  • 关于
物理超好玩
真诚面对自己
  1. 首页
  2. 程序设计
  3. 正文

利用JavaCV+OpenCV的ANN_MLP神经网络训练识别MNIST手写数字

2021年01月22日 2103点热度 1人点赞 0条评论

文本目录

  • 利用JavaCV+OpenCV的ANN_MLP神经网络训练识别MNIST手写数字
    • MNIST手写数字数据集
    • 读取MNIST数据
    • 从MNIST生成训练图片Mat和标签Mat数据
      • 图片Mat的格式
      • 标签Mat数据格式
    • 创建ANN_MLP神网络训练数据
    • 正确率测试
    • 源代码下载

利用JavaCV+OpenCV的ANN_MLP神经网络训练识别MNIST手写数字

JavaCV是可以在java中使用OpenCV的一个库。OpenCV是一个跨平台的开源计算机视觉和机器学习软件库。白话就是一个处理图片和进行人工智能识别图片的一个软件库。

MNIST手写数字数据集

MNIST数据集(Mixed National Institute of Standards and Technology database)是美国国家标准与技术研究院收集整理的大型手写数字数据库,包含60,000个示例的训练集以及10,000个示例的测试集。

下载地址:http://yann.lecun.com/exdb/mnist/

做为许多神经网络学习的入门数据,一直没找到javaCV的相关例子。

用IDEA架设创建JAVACV的开发环境,请参考:在IDEA和Android Studio中用Gradle构建javacv开发环境

读取MNIST数据

MNIST数据有4个文件,分别为训练和测试的图片和标签。关于用java读取的方法可参考文章:

使用 Java 读取 MNIST 数据集

上面文章中介绍有MNIST文件的格式等信息,在这里就不再重复。

从MNIST生成训练图片Mat和标签Mat数据

JavaCV训练时的所有数据,都用Mat的形式提供。说白了就是一个float数组。注意神经网训练时,最好用float数据,MNIST数据集是一个byte数组,这里需要转换一下。

图片Mat的格式

x = new Mat(number, size, CvType.CV_32FC1);

number是样本数量,做为mat的行数

size是图片像素点数,即28*28。每个样本图片生成一个单行的数组放入Mat中

CV_32FC1是数据类型,为32位的float数据

完整的代码如下:

/**
 * 生成训练数据
 *
 * @param fileName the file of 'train' or 'test' about image
 * @return one row show a `picture`
 */
public static Mat getTrainData(String fileName) {
    Mat x = null;
    try (BufferedInputStream bin = new BufferedInputStream(new FileInputStream(fileName))) {
        byte[] bytes = new byte[4];
        bin.read(bytes, 0, 4);
        if (!"00000803".equals(bytesToHex(bytes))) {                        
            // 读取魔数
            throw new RuntimeException("Please select the correct file!");
        } else {

            bin.read(bytes, 0, 4);
            // 读取样本总数
            int number = Integer.parseInt(bytesToHex(bytes), 16);       
            bin.read(bytes, 0, 4);
            // 读取每行所含像素点数
            int xPixel = Integer.parseInt(bytesToHex(bytes), 16);          
            bin.read(bytes, 0, 4);
            // 读取每列所含像素点数
            int yPixel = Integer.parseInt(bytesToHex(bytes), 16);           

            int l = xPixel*yPixel;
            x = new Mat(number, l, CvType.CV_32FC1);
            FloatIndexer indexer = x.createIndexer();
            for (int i = 0; i < number; i++) {
                for(int j=0; j<l; j++){
                    indexer.put(i, j, bin.read());
                }
            }
        }
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    return x;
}

标签Mat数据格式

Mat x = new Mat(data.length, 10, CvType.CV_32FC1);

data.length是样本数量,做为mat的行数

10是每个标签的数据量,即为float[10]。每个标签成一个单行的数组放入Mat中

CV_32FC1是数据类型,为32位的float数据

完整的代码如下:

/**
 * 获取训练的标签
 * 格式要求,每个标签为一个 float[10]数组,放在Mat的一行中
 * @param fileName
 * @return
 */
public static Mat getTrainLabels(String fileName) {
    byte[] data = getLabels(fileName);
    Mat x = new Mat(data.length, 10, CvType.CV_32FC1);
    FloatIndexer indexer = x.createIndexer();
    for(int i=0; i<data.length; i++){
        byte b = (byte) data[i];
        for(int j=0; j<10; j++){
            if(j==b)
                indexer.put(i, j, 1);
            else
                indexer.put(i, j, 0);
        }
    }
    return x;
}

/**
 * 获取所有标签的数值
 *
 * @param fileName the file of 'train' or 'test' about label
 * @return
 */
public static byte[] getLabels(String fileName) {
    byte[] y = null;
    try (BufferedInputStream bin = new BufferedInputStream(new FileInputStream(fileName))) {
        byte[] bytes = new byte[4];
        bin.read(bytes, 0, 4);
        if (!"00000801".equals(bytesToHex(bytes))) {
            throw new RuntimeException("Please select the correct file!");
        } else {
            bin.read(bytes, 0, 4);
            int number = Integer.parseInt(bytesToHex(bytes), 16);
            y = new byte[number];

            byte c;
            for (int i = 0; i < number; i++) {
                c = (byte) bin.read();
                y[i] = c;
            }
        }
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    return y;
}

创建ANN_MLP神网络训练数据

创建了一个四层的神经网络,神经元个数分别为 { 28*28 , 512,  256,  10 } ,分别为:

输入层,对应着每个像素,所以是28*28

隐含层两个,神经元个数分别为 512 和 256

输出层,和训练的标签对应,神经元为10个,即数字 0123456789

具体代码如下:

/**
 * 训练数据
 * @param xml 要保存的数据文件名
 */
public static void train(String xml){
    opencv_core.Mat trainData = MnistRead.getTrainData(TRAIN_IMAGES_FILE);
    opencv_core.Mat lables = MnistRead.getTrainLabels(TRAIN_LABELS_FILE);

    opencv_ml.ANN_MLP mlp= opencv_ml.ANN_MLP.create();

    int image_cols = 28; //图片宽
    int image_rows = 28; //图片高
    int class_num = 10; //预测的结果,为 float[10] 数组
    /*
     * 神经网络层
     * */
    int[] layer={ image_cols*image_rows , 512, 256, class_num};
    opencv_core.Mat layerSizes=new opencv_core.Mat(1, layer.length, CV_32FC1);
    org.bytedeco.javacpp.indexer.FloatIndexer indexer = layerSizes.createIndexer();
    for(int i=0;i<layer.length;i++){
        indexer.put(i, layer[i]);
    }

    mlp.setLayerSizes(layerSizes);
    mlp.setActivationFunction(opencv_ml.ANN_MLP.SIGMOID_SYM);
    mlp.train(trainData, ROW_SAMPLE, lables);
    /*
     * 开始训练
     * */

    mlp.save(xml);
    mlp.clear();
    System.out.println("训练结束");

}

正确率测试

数据格式和训练时一样,就不做解释了,代码如下:

/**
 * 使用测试数据,测试识别率
 * @param xml 训练好的数据文件
 */
public static void test(String xml){
    opencv_ml.ANN_MLP ann = opencv_ml.ANN_MLP.load(xml);

    opencv_core.Mat predictData = MnistRead.getTrainData(TEST_IMAGES_FILE);
    byte[] predictLables = MnistRead.getLabels(TEST_LABELS_FILE);

    //正确计数
    int rc = 0;

    for(int i=0; i<predictData.rows(); i++){
        opencv_core.Mat sample = predictData.row(i);
        opencv_core.Mat predict = new opencv_core.Mat();
        ann.predict(sample, predict, UPDATE_MODEL);
        if(predictLables[i] == getMaxIndex(predict)){
            //预测正确
            rc++;
        }
    }

    //计算正确率
    double zql = rc*1.0/predictData.rows();
    System.out.println("正确率:" + zql);
}

源代码下载

本示例源代码已在GITEE上开源,大家可以免费下验证:

https://gitee.com/zizai/StudyJavaCV

 

相关文章:

  1. 在IDEA和Android Studio中用Gradle构建javacv开发环境
  2. 在JavaCV中合并两个Mat
  3. 用OpenCV的K-Means聚类对书法作品进行单字分割
  4. 用Python下载PHET互动仿真程序
  5. Idea中用gradle打包可执行的jar
  6. IDEA用Gradle打包GUI Form为可执行的jar

订阅号“物理超好玩”
标签: ANN_MLP javacv MNIST OpenCV 神经网络
最后更新:2021年01月22日

坚持

真诚的面对自己的内心。 确立志向;全力准备;清净无扰,最终成功。 尊重自我,做自己最擅长的事情,做自己最喜欢的事情。

点赞
< 上一篇
下一篇 >

坚持

真诚的面对自己的内心。 确立志向;全力准备;清净无扰,最终成功。 尊重自我,做自己最擅长的事情,做自己最喜欢的事情。

分类目录
  • NOIP (1)
  • 习题讲解 (9)
  • 未分类 (1)
  • 程序设计 (8)
  • 软件作品 (1)
标签聚合
初中物理 计算题 gradle 电学 javacv Idea 串联电路 OpenCV
最新 热点 随机
最新 热点 随机
26中信奥社团微信群二维码 U盘随身便携Git http服务器 IDEA用Gradle打包GUI Form为可执行的jar 运用浮力求密度解题思路 用OpenCV的K-Means聚类对书法作品进行单字分割 在JavaCV中合并两个Mat
26中信奥社团微信群二维码 在IDEA和Android Studio中用Gradle构建javacv开发环境 Idea中用gradle打包可执行的jar 利用JavaCV+OpenCV的ANN_MLP神经网络训练识别MNIST手写数字 已知两组数据列方程组解题补充习题 初中物理电学计算题第四讲:数据挖掘
  • 利用JavaCV+OpenCV的ANN_MLP神经网络训练识别MNIST手写数字
    • MNIST手写数字数据集
    • 读取MNIST数据
    • 从MNIST生成训练图片Mat和标签Mat数据
      • 图片Mat的格式
      • 标签Mat数据格式
    • 创建ANN_MLP神网络训练数据
    • 正确率测试
    • 源代码下载

COPYRIGHT © 2021 物理超好玩. ALL RIGHTS RESERVED.

THEME KRATOS MADE BY VTROIS

豫ICP备16037997号-2