TensorFlow.js实现了几种RNN的接口,包括SimpleRNN、GRU和LSTM。这篇笔记介绍如何在浏览器环境下利用TensorFlow.js训练RNN学习加法运算:给出一个加法算式的字符串(如"12+7"),让模型输出结果的字符串(如"19"),思路与自然语言处理中的序列到序列(seq2seq)任务类似。

1、生成训练、测试数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
/**
 * Generates random addition problems and their answers as fixed-width strings.
 *
 * Each operand is built from `digits` random decimal characters, so it can end
 * up with fewer significant digits when the leading characters are '0'
 * (e.g. "07" -> 7). To keep every sample the same length, the query is
 * right-padded with spaces to `2 * digits + 1` characters and the answer to
 * `digits + 1` characters (the largest sum, 2 * (10^digits - 1), has
 * digits + 1 digits).
 *
 * @param {number} digits - Number of random decimal characters per operand (integer >= 1).
 * @param {number} trainingSize - Number of [query, answer] pairs to generate.
 * @returns {Array<[string, string]>} Pairs such as ['12+7 ', '19 '] for digits = 2.
 * @throws {RangeError} If digits is not a positive integer.
 */
function generateData(digits, trainingSize) {
  if (!Number.isInteger(digits) || digits < 1) {
    throw new RangeError(`digits must be a positive integer, got ${digits}`);
  }

  // Character set operands are drawn from.
  const digitArray = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'];
  // Query width: `digits` characters, '+', `digits` characters.
  const maxLen = digits + 1 + digits;

  // Builds one operand from `digits` randomly chosen characters of digitArray.
  const randomOperand = () => {
    let str = '';
    while (str.length < digits) {
      str += digitArray[Math.floor(Math.random() * digitArray.length)];
    }
    return Number.parseInt(str, 10);
  };

  const output = [];
  while (output.length < trainingSize) {
    const a = randomOperand();
    const b = randomOperand();
    // Right-pad with spaces so every query / answer has a fixed width.
    const query = `${a}+${b}`.padEnd(maxLen, ' ');
    const ans = String(a + b).padEnd(digits + 1, ' ');
    output.push([query, ans]);
  }
  return output;
}

digits代表每个操作数的位数,比如567的位数是3。函数f从digitArray中随机挑选digits个数字字符拼成一个操作数(首位可能是0,例如"07"解析后为7)。操作数a、加号、操作数b拼为query,a+b的真实结果转为字符串作为ans。由于首位为0时数字的实际位数会变少,query和ans都在末尾补空格,使query的长度固定为2*digits+1、ans的长度固定为digits+1。函数返回由[query, ans]字符串对组成的数组。

2、数据分组并转为tensor

1
2
3
4
5
6
7
8
// Excerpt from the demo's setup code. `data` is presumably the output of
// generateData(digits, trainingSize) — TODO confirm against the full source.
// First 90% of the pairs -> training set, remaining 10% -> validation (test) set.
// NOTE(review): the split index is computed from trainingSize, not data.length;
// the two only agree when data holds exactly trainingSize pairs.
const split = Math.floor(trainingSize * 0.9);
this.trainData = data.slice(0, split);
this.testData = data.slice(split);

// Encode both splits as tensors: xs from the padded queries, ys from the padded answers.
[this.trainXs, this.trainYs] = convertDataToTensors(this.trainData, this.charTable, digits);
[this.testXs, this.testYs] = convertDataToTensors(this.testData, this.charTable, digits);

将generateData生成的数据按9:1分为训练集和验证集,并将字符串转为tensor。转换函数convertDataToTensors如下。

1
2
3
4
5
6
7
8
9
10
/**
 * Splits [query, answer] pairs into two encoded batches.
 *
 * @param {Array<[string, string]>} data - Pairs as produced by generateData.
 * @param {{encodeBatch: function(string[], number): *}} charTable - Encoder
 *     (a CharacterTable) used for both batches.
 * @param {number} digits - Digits per operand; fixes both sequence lengths.
 * @returns {Array} [encoded queries (length 2 * digits + 1), encoded answers (length digits + 1)].
 */
function convertDataToTensors(data, charTable, digits) {
  const queryLength = digits + 1 + digits;
  const answerLength = digits + 1;

  const queries = [];
  const answers = [];
  for (const pair of data) {
    queries.push(pair[0]);
    answers.push(pair[1]);
  }

  const encodedQueries = charTable.encodeBatch(queries, queryLength);
  const encodedAnswers = charTable.encodeBatch(answers, answerLength);
  return [encodedQueries, encodedAnswers];
}

对query、ans编码,需要字符集类CharacterTable。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
/**
 * Maps characters to integer indices and one-hot encodes strings with them.
 *
 * @param {string|string[]} chars - The vocabulary; the character at position i gets index i.
 */
class CharacterTable {
  constructor(chars) {
    this.chars = chars;
    // character -> index
    this.charIndices = {};
    // index -> character
    this.indicesChar = {};
    this.size = this.chars.length;
    for (let i = 0; i < this.size; ++i) {
      this.charIndices[this.chars[i]] = i;
      this.indicesChar[i] = this.chars[i];
    }
  }

  /**
   * One-hot encodes a batch of strings into a [strings.length, maxLen, size] tensor.
   * Positions past the end of a shorter string are left as all-zero rows.
   *
   * @param {string[]} strings - Strings to encode (e.g. the padded queries or answers).
   * @param {number} maxLen - Number of time steps; no string may be longer than this.
   * @returns {tf.Tensor} Tensor of shape [strings.length, maxLen, this.size].
   * @throws {RangeError} If a string is longer than maxLen.
   * @throws {Error} If a string contains a character that is not in the table.
   */
  encodeBatch(strings, maxLen) {
    const numExamples = strings.length;
    const buf = tf.buffer([numExamples, maxLen, this.size]);
    for (let i = 0; i < numExamples; ++i) {
      const str = strings[i];
      // Reject bad input up front: an over-long string or an unknown character
      // would otherwise produce an invalid buffer position.
      if (str.length > maxLen) {
        throw new RangeError(
          `String "${str}" has length ${str.length}, which exceeds maxLen ${maxLen}.`);
      }
      for (let j = 0; j < str.length; ++j) {
        const index = this.charIndices[str[j]];
        if (index === undefined) {
          throw new Error(`Unknown character "${str[j]}" in "${str}".`);
        }
        buf.set(1, i, j, index);
      }
    }
    // The buffer already has the target 3-D shape, so no extra reshape
    // (and no extra undisposed tensor) is needed.
    return buf.toTensor();
  }
}

3、构建RNN

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
/**
 * Builds and compiles an encoder/decoder RNN for the addition task.
 *
 * Layers: an encoder RNN reads the query ([maxLen, vocabularySize]) and emits
 * only its final output; repeatVector copies that output digits + 1 times
 * (one per answer character); a decoder RNN with returnSequences emits one
 * vector per answer position; a time-distributed dense layer followed by
 * softmax turns each vector into scores over the vocabulary.
 *
 * @param {number} hiddenSize - Units in both the encoder and decoder RNN layers.
 * @param {string} rnnType - One of 'SimpleRNN', 'GRU' or 'LSTM'.
 * @param {number} digits - Digits per operand; input length is 2 * digits + 1,
 *     output length is digits + 1.
 * @param {number} vocabularySize - Size of the character table.
 * @returns {tf.Sequential} The compiled model.
 * @throws {Error} If rnnType is not one of the supported values.
 */
function createAndCompileModel(hiddenSize, rnnType, digits, vocabularySize) {
  // Query length: `digits` characters, '+', `digits` characters.
  const maxLen = digits + 1 + digits;

  // Creates one recurrent layer of the selected type. Used for both encoder
  // and decoder so the two can never end up with different cell types.
  const createRnnLayer = (config) => {
    switch (rnnType) {
      case 'SimpleRNN':
        return tf.layers.simpleRNN(config);
      case 'GRU':
        return tf.layers.gru(config);
      case 'LSTM':
        return tf.layers.lstm(config);
      default:
        // Fail fast: without an encoder, repeatVector would become the first
        // layer, which has no inputShape.
        throw new Error(`Unsupported rnnType "${rnnType}": expected SimpleRNN, GRU or LSTM.`);
    }
  };

  const model = tf.sequential();
  // Encoder: consumes the whole query and returns only its final output
  // (returnSequences is not set).
  model.add(createRnnLayer({
    units: hiddenSize,
    recurrentInitializer: 'glorotNormal',
    inputShape: [maxLen, vocabularySize],
  }));
  // One copy of the encoding per answer position.
  model.add(tf.layers.repeatVector({n: digits + 1}));
  // Decoder: one output vector per answer position.
  model.add(createRnnLayer({
    units: hiddenSize,
    recurrentInitializer: 'glorotNormal',
    returnSequences: true,
  }));
  // Per-position vocabulary scores, normalized with softmax.
  model.add(tf.layers.timeDistributed({
    layer: tf.layers.dense({units: vocabularySize}),
  }));
  model.add(tf.layers.activation({
    activation: 'softmax',
  }));
  model.compile({
    loss: 'categoricalCrossentropy',
    optimizer: 'adam',
    metrics: ['accuracy'],
  });
  return model;
}

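模型采用编码器-解码器(seq2seq)结构,三种RNN接口(SimpleRNN、GRU、LSTM)的用法相同,按rnnType选择其一即可。第一层RNN读入长度为2*digits+1的one-hot编码query,只输出最后一步的结果。repeatVector将该结果复制digits+1次,对应答案的每一位。第二层RNN设置returnSequences: true,为每一位输出一个向量。最后经timeDistributed包装的dense层和softmax,得到每一位在字符集上的概率分布。训练时以categoricalCrossentropy为损失、adam为优化器。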

4、训练RNN

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
async train(iterations, batchSize, numTestExamples) {  
// 损失、准确率、每秒训练样本数,用于可视化
const trainLossArr = [];
const valLossArr = [];
const trainAccuracyArr = [];
const valAccuracyArr = [];
const examplesPerSecArr = [];

for (let i = 0; i < iterations; ++i) {
// 开始计时
const beginMs = performance.now();
// 训练样本并测试
const history = await this.model.fit(this.trainXs, this.trainYs, {
epochs: 1,
batchSize,
validationData: [this.testXs, this.testYs],
});
// 时间差
const elapsedMs = performance.now() - beginMs;
// 每秒训练多少数据
const examplesPerSec = this.testXs.shape[0] / (elapsedMs / 1000);
// 训练集损失、准确率
const trainLoss = history.history['loss'][0];
const trainAccuracy = history.history['acc'][0];
// 验证集损失、准确率
const valLoss = history.history['val_loss'][0];
const valAccuracy = history.history['val_acc'][0];

// 输出迭代次数、训练损失、训练准确率、验证损失、验证准确滤
document.getElementById('trainStatus').textContent =
`Iteration ${i}: train loss = ${trainLoss.toFixed(6)}; ` +
`train accuracy = ${trainAccuracy.toFixed(6)}; ` +
`validation loss = ${valLoss.toFixed(6)}; ` +
`validation accuracy = ${valAccuracy.toFixed(6)} ` +
`(${examplesPerSec.toFixed(2)} examples/s)`;

// 存入数组便于可视化
trainLossArr.push([i, trainLoss]);
valLossArr.push([i, valLoss]);
trainAccuracyArr.push([i, trainAccuracy]);
valAccuracyArr.push([i, valAccuracy]);
examplesPerSecArr.push([i, examplesPerSec]);

// 迭代
await tf.nextFrame();
}
}

5、可视化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
// Excerpt: render the model's predictions for the first few test examples.
// Relies on `numTestExamples` and on `this` (model, testXs, testData,
// charTable) from the enclosing training code.
// Clamp to the test-set size: slice() cannot take more rows than exist.
const displayCount = Math.min(numTestExamples, this.testXs.shape[0]);
// (Re)build the display batch when there is none yet or its size changed.
if (this.testXsForDisplay == null || this.testXsForDisplay.shape[0] !== displayCount) {
  // Release the previous display tensor before replacing it; tfjs tensors
  // are not freed automatically.
  if (this.testXsForDisplay) {
    this.testXsForDisplay.dispose();
  }
  this.testXsForDisplay = this.testXs.slice(
    [0, 0, 0],
    [displayCount, this.testXs.shape[1], this.testXs.shape[2]]);
}

const examples = [];
const isCorrect = [];
// tidy() disposes the prediction and the per-row slices created inside it.
tf.tidy(() => {
  const predictOut = this.model.predict(this.testXsForDisplay);
  for (let j = 0; j < displayCount; ++j) {
    // Row j of the prediction as a 2-D [time step, score] matrix.
    const scores = predictOut
      .slice([j, 0, 0], [1, predictOut.shape[1], predictOut.shape[2]])
      .as2D(predictOut.shape[1], predictOut.shape[2]);
    const decoded = this.charTable.decode(scores);
    examples.push(`${this.testData[j][0]} = ${decoded}`);
    // Padding spaces are not part of the answer, so compare trimmed strings.
    isCorrect.push(this.testData[j][1].trim() === decoded.trim());
  }
});

// Replace the examples rendered by the previous iteration.
const examplesDiv = document.getElementById('testExamples');
examplesDiv.textContent = '';
examples.forEach((text, k) => {
  const exampleDiv = document.createElement('div');
  exampleDiv.textContent = text;
  // Different classes for correct and wrong predictions (styled in index.html).
  exampleDiv.className = isCorrect[k] ? 'answer-correct' : 'answer-wrong';
  examplesDiv.appendChild(exampleDiv);
});

选择numTestExamples个测试数据展示在页面上,以观察训练过程:预测错误的标为红色,预测正确的标为绿色。将RNN的输出转换回字符串需要用到字符集类CharacterTable的decode方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class CharacterTable {  
...
decode(x, calcArgmax = true) {
return tf.tidy(() => {
if (calcArgmax) {
x = x.argMax(1);
}
const xData = x.dataSync();
let output = '';
for (const index of Array.from(xData)) {
output += this.indicesChar[index];
}
return output;
});
}
}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
// Visualization: one ECharts line chart per metric, drawn from the
// accumulated [iteration, value] arrays.
const lossCanvas = echarts.init(document.getElementById('lossCanvas'));
const accuracyCanvas = echarts.init(document.getElementById('accuracyCanvas'));
const examplesPerSecCanvas = echarts.init(document.getElementById('examplesPerSecCanvas'));

// A marker-free line series for one [iteration, value] array.
const lineSeries = (seriesName, points) => ({
  name: seriesName,
  type: 'line',
  symbol: 'none',
  data: points
});
// Shared chart skeleton: a title, numeric x/y axes and the given series.
const lineChartOption = (titleText, series) => ({
  title: {text: titleText},
  xAxis: {type: 'value'},
  yAxis: {type: 'value'},
  series
});

lossCanvas.setOption(lineChartOption('Loss Values', [
  lineSeries('trainLoss', trainLossArr),
  lineSeries('valLoss', valLossArr)
]));
accuracyCanvas.setOption(lineChartOption('Accuracy Values', [
  lineSeries('trainAccuracy', trainAccuracyArr),
  lineSeries('valAccuracy', valAccuracyArr)
]));
// Plot area shifted 12% from the left so the y-axis labels are not covered.
examplesPerSecCanvas.setOption(Object.assign(
  lineChartOption('Examples Per Second', lineSeries('examplesPerSec', examplesPerSecArr)),
  {grid: {left: '12%'}}
));

以上代码用三张折线图展现训练集与验证集的损失、准确率,以及每秒训练样本数。index.html如下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>TensorFlow.js RNN Addition</title>
<style>
body {
  margin-top: 50px;
  margin-left: 50px;
}
.setting {
  padding-bottom: 6px;
}
.setting-label {
  display: inline-block;
  width: 12em;
}
.answer-correct {
  color: green;
}
.answer-wrong {
  color: red;
}
.canvases {
  display: inline-block;
  width: 300px;
  height: 300px;
}
</style>
</head>
<body>
<h1>RNN</h1>
<div>
  <!-- Training settings; the script reads these ids when "Train Model" is clicked. -->
  <div>
    <div class="setting">
      <label class="setting-label" for="digits">Digits:</label>
      <input id="digits" type="number" min="1" value="2">
    </div>
    <div class="setting">
      <label class="setting-label" for="trainingSize">Training Size:</label>
      <input id="trainingSize" type="number" min="2" value="5000">
    </div>
    <div class="setting">
      <label class="setting-label" for="rnnType">RNN Type:</label>
      <select id="rnnType">
        <option value="SimpleRNN">SimpleRNN</option>
        <option value="GRU">GRU</option>
        <option value="LSTM">LSTM</option>
      </select>
    </div>
    <div class="setting">
      <label class="setting-label" for="rnnLayerSize">RNN Hidden Layer Size:</label>
      <input id="rnnLayerSize" type="number" min="1" value="128">
    </div>
    <div class="setting">
      <label class="setting-label" for="batchSize">Batch Size:</label>
      <input id="batchSize" type="number" min="1" value="128">
    </div>
    <div class="setting">
      <label class="setting-label" for="trainIterations">Train Iterations:</label>
      <input id="trainIterations" type="number" min="1" value="100">
    </div>
    <div class="setting">
      <label class="setting-label" for="numTestExamples"># of test examples:</label>
      <input id="numTestExamples" type="number" min="0" value="20">
    </div>
  </div>
  <button id="trainModel" type="button">Train Model</button>
  <div id="trainStatus"></div>
  <!-- Chart containers; their size comes from the .canvases rule above. -->
  <div>
    <div class="canvases" id="lossCanvas"></div>
    <div class="canvases" id="accuracyCanvas"></div>
    <div class="canvases" id="examplesPerSecCanvas"></div>
  </div>
  <div id="testExamples"></div>
</div>
<!-- Loaded last (inside <body>) so every element above exists when the script runs. -->
<script src="bundle.js"></script>
</body>
</html>

6、获取参数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
/**
 * Wires the "Train Model" button: on click, reads the settings from the page,
 * builds an RNN demo and trains it, reporting failures in #trainStatus.
 *
 * Kept `async` so existing callers that chain on the returned promise still
 * work; it only registers the click handler.
 */
async function runRNN() {
  const trainButton = document.getElementById('trainModel');
  const statusDiv = document.getElementById('trainStatus');

  // Reads an <input> as a number; returns null unless it is an integer >= min.
  const readInt = (id, min) => {
    const value = +(document.getElementById(id)).value;
    return Number.isInteger(value) && value >= min ? value : null;
  };

  trainButton.addEventListener('click', async () => {
    // Digits per operand.
    const digits = readInt('digits', 1);
    // Total number of generated samples (split 9:1); 2 is the minimum that
    // leaves at least one training and one test example.
    const trainingSize = readInt('trainingSize', 2);
    // Hidden units of each RNN layer.
    const hiddenSize = readInt('rnnLayerSize', 1);
    // Number of one-epoch training rounds.
    const trainIterations = readInt('trainIterations', 1);
    // Mini-batch size.
    const batchSize = readInt('batchSize', 1);
    // How many test examples to display.
    const numTestExamples = readInt('numTestExamples', 0);
    const settings = [digits, trainingSize, hiddenSize, trainIterations, batchSize, numTestExamples];
    if (settings.includes(null)) {
      statusDiv.textContent =
          'Invalid settings: every field must be a whole number ' +
          '(Training Size at least 2, # of test examples at least 0, all others at least 1).';
      return;
    }
    // 'SimpleRNN', 'GRU' or 'LSTM' (the selected option's value).
    const rnnType = document.getElementById('rnnType').value;

    // Prevent a second click from starting a concurrent training run that
    // would fight over the same status line and charts.
    trainButton.disabled = true;
    try {
      // NOTE(review): tensors held by a previous demo are not disposed here;
      // that needs an API on the RNN class, which is not visible.
      const demo = new RNN(digits, trainingSize, rnnType, hiddenSize);
      await demo.train(trainIterations, batchSize, numTestExamples);
    } catch (err) {
      console.error(err);
      statusDiv.textContent = `Training failed: ${err.message}`;
    } finally {
      trainButton.disabled = false;
    }
  });
}
runRNN().catch((err) => console.error(err));

结果如下,准确率97.9%。
res
完整程序见我的github