TensorFlow.js实现了几种RNN的接口，包括SimpleRNN、GRU和LSTM。这篇笔记介绍如何在浏览器环境下利用TensorFlow.js训练RNN学习加法运算：输入一个加法算式的字符串，由模型逐字符输出计算结果。这是一个类似于自然语言处理中序列到序列（seq2seq）的任务。  
1、生成训练、测试数据
| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 
 /**
  * Generates random addition problems and their answers as fixed-length strings.
  *
  * Each operand is built from `digits` random decimal characters and parsed,
  * so leading zeros disappear (e.g. "07" becomes 7). Queries are right-padded
  * with spaces to `digits + 1 + digits` characters and answers to `digits + 1`
  * characters, so every example has the same length.
  *
  * Problems are de-duplicated up to commutativity (a+b and b+a count as the
  * same problem), so splitting the result into training and validation sets
  * does not place the same problem in both. Once every distinct unordered pair
  * has been produced, further examples may repeat, so asking for more examples
  * than exist still terminates.
  *
  * @param {number} digits - Maximum number of decimal digits per operand (1-15).
  * @param {number} trainingSize - Number of [query, answer] pairs to generate.
  * @returns {Array<[string, string]>} Array of [paddedQuery, paddedAnswer].
  * @throws {RangeError} If `digits` is not an integer between 1 and 15.
  */
 function generateData(digits, trainingSize) {
   // Up to 15 digits keeps a, b and a + b exactly representable as doubles;
   // larger values would round and could even change the string length.
   if (!Number.isInteger(digits) || digits < 1 || digits > 15) {
     throw new RangeError(`digits must be an integer between 1 and 15, got ${digits}`);
   }

   const digitArray = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'];
   const arraySize = digitArray.length;
   const maxLen = digits + 1 + digits;
   const operandCount = 10 ** digits;
   // Number of distinct unordered {a, b} pairs of operands.
   const maxUniquePairs = (operandCount * (operandCount + 1)) / 2;

   // Random operand made of `digits` random decimal characters.
   const randomOperand = () => {
     let str = '';
     while (str.length < digits) {
       str += digitArray[Math.floor(Math.random() * arraySize)];
     }
     return Number.parseInt(str, 10);
   };

   const seen = new Set();
   const output = [];
   while (output.length < trainingSize) {
     const a = randomOperand();
     const b = randomOperand();
     const key = a <= b ? `${a},${b}` : `${b},${a}`;
     // Skip repeats until the pool of distinct pairs is exhausted.
     if (seen.has(key) && seen.size < maxUniquePairs) {
       continue;
     }
     seen.add(key);

     const query = `${a}+${b}`.padEnd(maxLen);
     const ans = String(a + b).padEnd(digits + 1);
     output.push([query, ans]);
   }
   return output;
 }
 
 | 
digits代表每个加数的最大位数，比如567的位数是3。函数f从digitArray中随机挑选digits个数字字符，拼接后解析为一个加数。加数a、加号和加数b拼接为query，a+b的真实结果转为字符串作为ans。首位为0时，解析出的数字会变短（例如"07"解析为7），和的位数也不固定。因此query和ans分别在末尾补空格，长度固定为2*digits+1和digits+1。函数返回由[query, ans]字符串对组成的数组。  
2、数据分组并转为tensor
| 12
 3
 4
 5
 6
 7
 8
 
 | const split = Math.floor(trainingSize * 0.9);  // index where the validation split starts (90% of trainingSize)
 this.trainData = data.slice(0, split);  // examples before `split`: training set
 this.testData = data.slice(split);  // remaining examples: validation set
 // Encode both splits as [inputs, targets] tensors via convertDataToTensors.
 // NOTE(review): assumes `data`, `trainingSize`, `digits` and `this.charTable` are set earlier in the enclosing (not shown) constructor — confirm.
 [this.trainXs, this.trainYs] = convertDataToTensors(this.trainData, this.charTable, digits);
 [this.testXs, this.testYs] = convertDataToTensors(this.testData, this.charTable, digits);
 
 | 
将generateData生成的数据按9:1分为训练集和验证集，并将字符串转为tensor。转换函数convertDataToTensors如下。
| 12
 3
 4
 5
 6
 7
 8
 9
 10
 
 /**
  * Separates [query, answer] pairs and encodes each side with the character table.
  *
  * @param {Array<Array<string>>} data - [query, answer] string pairs.
  * @param {CharacterTable} charTable - Table whose encodeBatch() does the encoding.
  * @param {number} digits - Operand digit count. Queries are encoded with
  *     digits + 1 + digits time steps and answers with digits + 1.
  * @returns {Array} [encodedQueries, encodedAnswers] as returned by encodeBatch().
  */
 function convertDataToTensors(data, charTable, digits) {
   const queryLen = digits + 1 + digits;
   const answerLen = digits + 1;

   const questions = [];
   const answers = [];
   data.forEach((datum) => {
     questions.push(datum[0]);
     answers.push(datum[1]);
   });

   const encodedQuestions = charTable.encodeBatch(questions, queryLen);
   const encodedAnswers = charTable.encodeBatch(answers, answerLen);
   return [encodedQuestions, encodedAnswers];
 }
 
 | 
对query和ans进行one-hot编码需要用到字符表类CharacterTable。它为每个字符分配一个索引；encodeBatch在每个时间步对应字符的索引位置置1，字符串长度不足maxLen时，其余时间步保持全0。
| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 
 /**
  * Two-way mapping between characters and integer indices, used to one-hot
  * encode strings into tensors.
  */
 class CharacterTable {
   /**
    * @param {string|string[]} chars - The vocabulary. Each character's position
    *     in `chars` becomes its index.
    */
   constructor(chars) {
     this.chars = chars;
     // char -> index
     this.charIndices = {};
     // index -> char
     this.indicesChar = {};
     this.size = this.chars.length;
     for (let i = 0; i < this.size; ++i) {
       this.charIndices[this.chars[i]] = i;
       this.indicesChar[i] = this.chars[i];
     }
   }

   /**
    * One-hot encodes a batch of strings.
    *
    * @param {string[]} strings - Strings to encode. Each must be at most
    *     `maxLen` characters and contain only characters from the table.
    * @param {number} maxLen - Time steps per example. Steps past the end of a
    *     shorter string stay all-zero.
    * @returns {tf.Tensor} Tensor of shape [strings.length, maxLen, this.size].
    * @throws {RangeError} If a string is longer than `maxLen`.
    * @throws {Error} If a string contains a character missing from the table.
    */
   encodeBatch(strings, maxLen) {
     const numExamples = strings.length;
     const buf = tf.buffer([numExamples, maxLen, this.size]);
     for (let i = 0; i < numExamples; ++i) {
       const str = strings[i];
       // TensorBuffer.set maps (i, j, k) to a flat offset without bounds checks,
       // so an over-long string would silently overwrite the next example.
       if (str.length > maxLen) {
         throw new RangeError(
           `String ${JSON.stringify(str)} has length ${str.length}, exceeding maxLen ${maxLen}`);
       }
       for (let j = 0; j < str.length; ++j) {
         const char = str[j];
         const index = this.charIndices[char];
         if (index === undefined) {
           throw new Error(`Unknown character ${JSON.stringify(char)} in ${JSON.stringify(str)}`);
         }
         buf.set(1, i, j, index);
       }
     }
     // The buffer already has the 3-D shape, so no extra reshape is needed.
     return buf.toTensor();
   }
 }
 
 | 
3、构建RNN
| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 
 /**
  * Builds and compiles an encoder-decoder RNN for the addition task.
  *
  * @param {number} hiddenSize - Units in each recurrent layer.
  * @param {string} rnnType - One of 'SimpleRNN', 'GRU' or 'LSTM'.
  * @param {number} digits - Operand digit count. The input has
  *     digits + 1 + digits time steps and the output has digits + 1.
  * @param {number} vocabularySize - Size of the one-hot character vocabulary.
  * @returns {tf.Sequential} The compiled model.
  * @throws {Error} If `rnnType` is not one of the supported values.
  */
 function createAndCompileModel(hiddenSize, rnnType, digits, vocabularySize) {
   const maxLen = digits + 1 + digits;

   // Creates one recurrent layer of the requested type. Unknown types throw
   // here. Otherwise repeatVector would become the first layer, without an
   // inputShape, and tf.js would fail later with a less helpful message.
   const createRnnLayer = (config) => {
     switch (rnnType) {
       case 'SimpleRNN':
         return tf.layers.simpleRNN(config);
       case 'GRU':
         return tf.layers.gru(config);
       case 'LSTM':
         return tf.layers.lstm(config);
       default:
         throw new Error(
           `Unsupported rnnType '${rnnType}'; expected 'SimpleRNN', 'GRU' or 'LSTM'`);
     }
   };

   const model = tf.sequential();
   // Encoder: reads the one-hot query and outputs only its last step
   // (returnSequences is not set here).
   model.add(createRnnLayer({
     units: hiddenSize,
     recurrentInitializer: 'glorotNormal',
     inputShape: [maxLen, vocabularySize],
   }));
   // Repeat the encoded vector once per answer character.
   model.add(tf.layers.repeatVector({n: digits + 1}));
   // Decoder: returnSequences yields one hidden vector per answer position.
   model.add(createRnnLayer({
     units: hiddenSize,
     recurrentInitializer: 'glorotNormal',
     returnSequences: true,
   }));
   // The same dense classifier over the vocabulary is applied at every position.
   model.add(tf.layers.timeDistributed({
     layer: tf.layers.dense({units: vocabularySize}),
   }));
   model.add(tf.layers.activation({
     activation: 'softmax',
   }));
   model.compile({
     loss: 'categoricalCrossentropy',
     optimizer: 'adam',
     metrics: ['accuracy'],
   });
   return model;
 }
 
 | 
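使用TensorFlow.js的三种RNN接口即可。模型采用编码器-解码器结构，各层作用如下：
- 第一层RNN读取整个算式，只输出最后一个时间步的结果；
- repeatVector把它复制digits+1次；
- 第二层RNN设置returnSequences: true，为结果的每一位输出一个向量；
- 最后由timeDistributed包装的全连接层加softmax，在每个时间步预测一个字符的概率分布。  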
4、训练RNN
| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 
 | async train(iterations, batchSize, numTestExamples) {  
 const trainLossArr = [];
 const valLossArr = [];
 const trainAccuracyArr = [];
 const valAccuracyArr = [];
 const examplesPerSecArr = [];
 
 for (let i = 0; i < iterations; ++i) {
 
 const beginMs = performance.now();
 
 const history = await this.model.fit(this.trainXs, this.trainYs, {
 epochs: 1,
 batchSize,
 validationData: [this.testXs, this.testYs],
 });
 
 const elapsedMs = performance.now() - beginMs;
 
 const examplesPerSec = this.testXs.shape[0] / (elapsedMs / 1000);
 
 const trainLoss = history.history['loss'][0];
 const trainAccuracy = history.history['acc'][0];
 
 const valLoss = history.history['val_loss'][0];
 const valAccuracy = history.history['val_acc'][0];
 
 
 document.getElementById('trainStatus').textContent =
 `Iteration ${i}: train loss = ${trainLoss.toFixed(6)}; ` +
 `train accuracy = ${trainAccuracy.toFixed(6)}; ` +
 `validation loss = ${valLoss.toFixed(6)}; ` +
 `validation accuracy = ${valAccuracy.toFixed(6)} ` +
 `(${examplesPerSec.toFixed(2)} examples/s)`;
 
 
 trainLossArr.push([i, trainLoss]);
 valLossArr.push([i, valLoss]);
 trainAccuracyArr.push([i, trainAccuracy]);
 valAccuracyArr.push([i, valAccuracy]);
 examplesPerSecArr.push([i, examplesPerSec]);
 
 
 await tf.nextFrame();
 }
 }
 
 | 
5、可视化
| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 
 // Cache a slice of the first validation inputs for display, re-slicing only
 // when the requested count changes. The count is clamped to the validation set size.
 const displayCount = Math.min(numTestExamples, this.testXs.shape[0]);
 if (this.testXsForDisplay == null || this.testXsForDisplay.shape[0] !== displayCount) {
   // Release the previously cached slice before replacing it.
   if (this.testXsForDisplay) {
     this.testXsForDisplay.dispose();
   }
   this.testXsForDisplay = this.testXs.slice(
     [0, 0, 0],
     [displayCount, this.testXs.shape[1], this.testXs.shape[2]]);
 }

 const examples = [];
 const isCorrect = [];
 // tidy() disposes the prediction and the per-example slices created here.
 tf.tidy(() => {
   const predictOut = this.model.predict(this.testXsForDisplay);
   for (let j = 0; j < displayCount; ++j) {
     const scores = predictOut
       .slice([j, 0, 0], [1, predictOut.shape[1], predictOut.shape[2]])
       .as2D(predictOut.shape[1], predictOut.shape[2]);
     const decoded = this.charTable.decode(scores);
     examples.push(this.testData[j][0] + ' = ' + decoded);
     // Padding spaces are ignored when comparing against the expected answer.
     isCorrect.push(this.testData[j][1].trim() === decoded.trim());
   }
 });

 const examplesDiv = document.getElementById('testExamples');
 // Remove previously displayed examples.
 while (examplesDiv.firstChild) {
   examplesDiv.removeChild(examplesDiv.firstChild);
 }
 for (let i = 0; i < examples.length; ++i) {
   const exampleDiv = document.createElement('div');
   exampleDiv.textContent = examples[i];
   // The page's CSS colours these classes green (correct) and red (wrong).
   exampleDiv.className = isCorrect[i] ? 'answer-correct' : 'answer-wrong';
   examplesDiv.appendChild(exampleDiv);
 }
 
 | 
选择numTestExamples个验证数据展示在页面上，以观察训练过程：预测正确的标为绿色，预测错误的标为红色。将RNN的输出转换为字符串需要用到字符表类CharacterTable的解码方法decode。
| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 
 | class CharacterTable {  ...
 decode(x, calcArgmax = true) {
 return tf.tidy(() => {
 if (calcArgmax) {
 x = x.argMax(1);
 }
 const xData = x.dataSync();
 let output = '';
 for (const index of Array.from(xData)) {
 output += this.indicesChar[index];
 }
 return output;
 });
 }
 }
 
 | 
| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 
 // Returns the chart already attached to the element, creating it on first use.
 // If this code runs once per training iteration, re-calling echarts.init() on
 // an initialized element would only log a warning; reusing the instance avoids that.
 const getChart = (id) => {
   const dom = document.getElementById(id);
   return echarts.getInstanceByDom(dom) || echarts.init(dom);
 };
 // A line series drawn without point markers.
 const lineSeries = (name, data) => ({
   name,
   type: 'line',
   symbol: 'none',
   data,
 });
 // Line-chart option with numeric x/y axes; `grid` is optional.
 const lineChartOption = (text, series, grid) => {
   const option = {
     title: {text},
     xAxis: {type: 'value'},
     yAxis: {type: 'value'},
     series,
   };
   if (grid) {
     option.grid = grid;
   }
   return option;
 };

 const lossCanvas = getChart('lossCanvas');
 const accuracyCanvas = getChart('accuracyCanvas');
 const examplesPerSecCanvas = getChart('examplesPerSecCanvas');

 lossCanvas.setOption(lineChartOption('Loss Values', [
   lineSeries('trainLoss', trainLossArr),
   lineSeries('valLoss', valLossArr),
 ]));
 accuracyCanvas.setOption(lineChartOption('Accuracy Values', [
   lineSeries('trainAccuracy', trainAccuracyArr),
   lineSeries('valAccuracy', valAccuracyArr),
 ]));
 // Wider left margin, presumably so the larger examples/s axis labels fit.
 examplesPerSecCanvas.setOption(lineChartOption(
   'Examples Per Second',
   lineSeries('examplesPerSec', examplesPerSecArr),
   {left: '12%'}));
 
 | 
用ECharts绘制训练损失、验证损失、训练准确率、验证准确率以及每秒处理的样本数。index.html如下。
| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 
 <!DOCTYPE html>
 <html lang="en">
 <head>
   <meta charset="UTF-8">
   <title>RNN Addition</title>
   <style>
     body {
       margin-top: 50px;
       margin-left: 50px;
     }
     .setting {
       padding-bottom: 6px;
     }
     .setting-label {
       display: inline-block;
       width: 12em;
     }
     .answer-correct {
       color: green;
     }
     .answer-wrong {
       color: red;
     }
     .canvases {
       display: inline-block;
       width: 300px;
       height: 300px;
     }
   </style>
 </head>
 <body>
   <h1>RNN</h1>
   <div>
     <div>
       <div class="setting">
         <label class="setting-label" for="digits">Digits:</label>
         <input id="digits" value="2">
       </div>
       <div class="setting">
         <label class="setting-label" for="trainingSize">Training Size:</label>
         <input id="trainingSize" value="5000">
       </div>
       <div class="setting">
         <label class="setting-label" for="rnnType">RNN Type:</label>
         <select id="rnnType">
           <option value="SimpleRNN">SimpleRNN</option>
           <option value="GRU">GRU</option>
           <option value="LSTM">LSTM</option>
         </select>
       </div>
       <div class="setting">
         <label class="setting-label" for="rnnLayerSize">RNN Hidden Layer Size:</label>
         <input id="rnnLayerSize" value="128">
       </div>
       <div class="setting">
         <label class="setting-label" for="batchSize">Batch Size:</label>
         <input id="batchSize" value="128">
       </div>
       <div class="setting">
         <label class="setting-label" for="trainIterations">Train Iterations:</label>
         <input id="trainIterations" value="100">
       </div>
       <div class="setting">
         <label class="setting-label" for="numTestExamples"># of test examples:</label>
         <input id="numTestExamples" value="20">
       </div>
     </div>
     <button id="trainModel">Train Model</button>
     <div id="trainStatus"></div>
     <div>
       <div class="canvases" id="lossCanvas"></div>
       <div class="canvases" id="accuracyCanvas"></div>
       <div class="canvases" id="examplesPerSecCanvas"></div>
     </div>
     <div id="testExamples"></div>
   </div>
   <!-- Last element of <body>, so the elements above exist when the script runs. -->
   <script src="bundle.js"></script>
 </body>
 </html>
 
 | 
6、获取参数
| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 
 /**
  * Wires up the "Train Model" button. Each click reads the settings from the
  * form, builds an RNN demo and trains it.
  */
 async function runRNN() {
   const trainButton = document.getElementById('trainModel');
   trainButton.addEventListener('click', async () => {
     // Unary + converts the input text to a number (NaN if it is not numeric).
     const readNumber = (id) => +document.getElementById(id).value;

     const digits = readNumber('digits');
     const trainingSize = readNumber('trainingSize');
     const rnnType = document.getElementById('rnnType').value;
     const hiddenSize = readNumber('rnnLayerSize');
     const trainIterations = readNumber('trainIterations');
     const batchSize = readNumber('batchSize');
     const numTestExamples = readNumber('numTestExamples');

     // Disable the button while training, so a second click cannot start a
     // concurrent run that writes to the same page elements.
     trainButton.disabled = true;
     try {
       const demo = new RNN(digits, trainingSize, rnnType, hiddenSize);
       await demo.train(trainIterations, batchSize, numTestExamples);
     } catch (err) {
       // Show failures (e.g. invalid settings) on the page instead of leaving
       // an unhandled promise rejection in the console.
       const message = err instanceof Error ? err.message : String(err);
       document.getElementById('trainStatus').textContent = `Training failed: ${message}`;
       console.error(err);
     } finally {
       trainButton.disabled = false;
     }
   });
 }
 runRNN();
 
 | 
结果如下,准确率97.9%。

完整程序见我的github。