【tensorflow.js学习笔记(2)】CNN识别手写数字集MNIST
笔记(1)中利用tensorflow.js完成了机器学习中曲线拟合的任务,这篇笔记将实现一个经典的机器学习问题——CNN识别手写数字集MNIST。参考官方示例Training on Images: Recognizing Handwritten Digits with a Convolutional Neural Network,修改部分代码并用echarts改写vega。
1、定义mnist数据类
1 | import * as tf from '@tensorflow/tfjs'; |
mnist数据类在构造器内声明两个index,分别是训练过程的洗牌index和测试过程的洗牌index。引入洗牌index是为了防止模型训练受到传入图像顺序的影响。假设不洗牌,先将所有1的图像传入模型进行训练,那么此时训练的模型将学会预测1的手写体;之后传入所有2的图像,则模型将切换到仅预测2(这样会最小化损失函数);则这样永远无法完整的对全部数据集进行预测。
之后定义方法load(),该方法将图片mnist_images.png进行切割,并从mnist_labels_uint8中找到对应的label。之后的nextTrainBatch和nextTestBatch将分别返回训练样本和测试样本。
2、CNN的构建
1 | import * as tf from '@tensorflow/tfjs'; |
首先定义模型为tf.sequential(),该模型中张量将连续地从一层传递到下一层。之后分别加入卷积层、池化层、卷积层、池化层、flatten层(将输入展为向量)和dense层(完全连接层)。
其中卷积层是2维卷积;inputShape是传入数据的维数(第二个卷积层可以不指定inputShape,tf将从前一层的输出推断该值),三个值分别为行、列、深度,mnist图像格式为2828像素点,深度为1,因为只有一个颜色通道;kernelSize是应用到输入数据上的滑动滤波器窗口的大小,5代表55矩形卷积窗口;filters指滤波器窗口的数量,8代表有8个滤波器;strides指滑动窗口的步长,1代表每次将以1像素为单位滑动滤波器;activation为激活函数,relu代表线性整流函数,函数形状如下所示;kernelInitializer用于随机初始化模型权重,这里使用VarianceScaling方法初始化模型。
池化层使用的是二维最大池,通过计算该层每个滑动窗口的最大值来降低纬度。其中poolSize是滑动窗口的大小,[2, 2]代表2*2的矩形窗口;strides代表滑动窗口移动的步长,[2, 2]代表窗口将在水平和垂直方向是以2像素为单位进行移动。
flatten层将上一层的输出平铺到一个矢量上。dense层(完全连接层)将执行最终的分类任务。其中units是输出的激活数,10代表将有10种不同的输出,满足mnist的10种分类(数字0-9);kernelInitializer设为VarianceScaling初始化方法;分类任务的最后一层激活函数activation通常设为softmax,该函数将10维输出向量归一化为概率分布,以便我们知道该样本属于10个类中每个类的概率。
定义学习率为0.15,优化器为随机梯度下降法(SGD);损失函数为categoricalCrossentropy,即分类任务的交叉熵;评价指标为准确率accuracy,即所有预测中正确预测的百分比。之后编译模型。
3、模型训练
1 | const BATCH_SIZE = 64; |
训练时首先调用ui.isTraining()将training输出到文档,之后进行TEST_BATCH_SIZE = 1000轮迭代,每次记录损失函数的值及准确率,并在训练结束后进行可视化展示。之后调用showPredictions方法将测试样本的预测值与实际像素图输出到文档。
4、可视化部分
1 | const echarts = require('echarts'); |
BUG1:
1 | lossChart.setOption({ |
当绘图时仅传入series参数时,Echarts报错:
Uncaught (in promise) TypeError: Cannot read property ‘get’ of undefined
此时将xAxis、yAxis等添加进options中即可。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16lossChart.setOption({
title: {
text: 'Loss Values'
},
xAxis: {
type: 'value'
},
yAxis: {
type: 'value'
},
series: [{
name: 'loss',
type: 'line',
data: lossValues
}]
});
BUG2:
报错 Uncaught Error: Unsupported core optimizer type: t
原因是相关依赖未正确安装,cd到package.json同目录,运行yarn命令安装相关包。
完整程序见我的github,具体步骤为:
step1 新建文件夹,cmd输入git clone git@github.com:orangecsy/tfjs-exercise.git,cd 2进入文件夹2;
step2 cmd输入webpack,打包;
step3 cd dist进入dist文件夹,cmd中输入http-server(需先npm install http-server)或使用webpack配置开发服务器;
step4 浏览器中输入http://127.0.0.1:8080/,即为结果。