tensorflow.js实现了几种RNN的接口,包括SimpleRNN、GRU和LSTM。这篇笔记介绍如何在浏览器环境下利用tensorflow.js训练RNN学习加法运算,即给出一个加法算式的字符串,算出数字结果,类似于自然语言处理。
1、生成训练、测试数据
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
/**
 * Generates random addition problems and their answers as fixed-width strings.
 *
 * Each operand is assembled from `digits` random decimal characters and then
 * parsed as a base-10 integer, so a leading zero shortens the printed number
 * (e.g. "07" is printed as "7"). To keep every example the same length, queries
 * are right-padded with spaces to `digits + 1 + digits` characters and answers
 * to `digits + 1` characters.
 *
 * @param {number} digits - Decimal digits per operand; must be a positive integer.
 * @param {number} trainingSize - Number of examples to generate.
 * @returns {Array<[string, string]>} `[query, answer]` pairs, e.g. `['12+7 ', '19 ']`.
 * @throws {RangeError} If `digits` is not a positive integer.
 */
function generateData(digits, trainingSize) {
  // Fail with a clear message instead of a stray "Invalid count value" later on.
  if (!Number.isInteger(digits) || digits < 1) {
    throw new RangeError(`digits must be a positive integer, got ${digits}`);
  }

  const digitArray = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'];
  const maxLen = digits + 1 + digits;

  const randomOperand = () => {
    let str = '';
    while (str.length < digits) {
      const index = Math.floor(Math.random() * digitArray.length);
      str += digitArray[index];
    }
    // Explicit radix: the string may start with '0'.
    return Number.parseInt(str, 10);
  };

  const output = [];
  while (output.length < trainingSize) {
    const a = randomOperand();
    const b = randomOperand();
    const query = `${a}+${b}`.padEnd(maxLen, ' ');
    const ans = String(a + b).padEnd(digits + 1, ' ');
    output.push([query, ans]);
  }
  return output;
}
|
digits代表输入数字的位数,比如567的位数是3。函数f从digitArray中随机挑选digits个字符拼为一个输入数字。输入a、加号、输入b整体拼为一个query,a+b的真实结果拼为ans。由于首位为0时解析后的数字位数会变少,query和ans均在末尾补空格以保持固定长度,函数返回query、ans字符串对组成的数组。
2、数据分组并转为tensor
1 2 3 4 5 6 7 8
// The first 90% of the generated pairs are used for training; the remainder
// is held out and passed to fit() as validation data.
const trainCount = Math.floor(trainingSize * 0.9);
this.trainData = data.slice(0, trainCount);
this.testData = data.slice(trainCount);

// Encode each split into an [inputs, targets] pair.
const toTensors = (pairs) => convertDataToTensors(pairs, this.charTable, digits);
[this.trainXs, this.trainYs] = toTensors(this.trainData);
[this.testXs, this.testYs] = toTensors(this.testData);
|
将generateData生成的数据分为训练组和验证组,并将字符串转为tensor。转换函数convertDataToTensors如下。
1 2 3 4 5 6 7 8 9 10
/**
 * Encodes `[question, answer]` string pairs with the given character table.
 *
 * @param {Array<[string, string]>} data - `[question, answer]` pairs.
 * @param {{encodeBatch: Function}} charTable - Encoder; `encodeBatch(strings, maxLen)`
 *     is called once for the questions and then once for the answers.
 * @param {number} digits - Digits per operand; questions are encoded with length
 *     `digits + 1 + digits` and answers with length `digits + 1`.
 * @returns {Array} `[encodedQuestions, encodedAnswers]`.
 */
function convertDataToTensors(data, charTable, digits) {
  const questions = [];
  const answers = [];
  for (const pair of data) {
    questions.push(pair[0]);
    answers.push(pair[1]);
  }

  const questionLen = digits + 1 + digits;
  const answerLen = digits + 1;

  const xs = charTable.encodeBatch(questions, questionLen);
  const ys = charTable.encodeBatch(answers, answerLen);
  return [xs, ys];
}
|
对query、ans编码,需要字符集类CharacterTable。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
/**
 * Bidirectional mapping between characters and integer indices, plus one-hot
 * encoding of strings over that vocabulary.
 */
class CharacterTable {
  /**
   * @param {string} chars - The vocabulary; each character's position is its index.
   */
  constructor(chars) {
    this.chars = chars;
    this.charIndices = {};
    this.indicesChar = {};
    this.size = this.chars.length;
    for (let i = 0; i < this.size; ++i) {
      const char = this.chars[i];
      this.charIndices[char] = i;
      this.indicesChar[i] = char;
    }
  }

  /**
   * One-hot encodes a batch of strings.
   *
   * Position `[i, j, k]` of the result is 1 when character `j` of string `i`
   * has index `k`. Positions past the end of a shorter string stay all-zero.
   *
   * @param {string[]} strings - Strings to encode; each at most `maxLen` long.
   * @param {number} maxLen - Size of the second (character position) dimension.
   * @returns {tf.Tensor3D} Tensor of shape `[strings.length, maxLen, this.size]`.
   * @throws {RangeError} If a string is longer than `maxLen`.
   * @throws {Error} If a string contains a character outside the vocabulary.
   */
  encodeBatch(strings, maxLen) {
    const numExamples = strings.length;
    const buf = tf.buffer([numExamples, maxLen, this.size]);
    for (let i = 0; i < numExamples; ++i) {
      const str = strings[i];
      // Writing past maxLen would index outside the buffer.
      if (str.length > maxLen) {
        throw new RangeError(
            `String "${str}" is longer than maxLen (${maxLen})`);
      }
      for (let j = 0; j < str.length; ++j) {
        const index = this.charIndices[str[j]];
        // An unknown character would otherwise pass `undefined` as an index.
        if (index === undefined) {
          throw new Error(`Unknown character "${str[j]}" in string "${str}"`);
        }
        buf.set(1, i, j, index);
      }
    }
    return buf.toTensor().as3D(numExamples, maxLen, this.size);
  }
}
|
3、构建RNN
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
/**
 * Creates one recurrent layer of the requested kind.
 *
 * @param {string} rnnType - 'SimpleRNN', 'GRU' or 'LSTM'.
 * @param {Object} config - Layer configuration passed to the tf.layers factory.
 * @returns {tf.layers.Layer} The new layer.
 * @throws {Error} If `rnnType` is not one of the supported names.
 */
function createRnnLayer(rnnType, config) {
  switch (rnnType) {
    case 'SimpleRNN':
      return tf.layers.simpleRNN(config);
    case 'GRU':
      return tf.layers.gru(config);
    case 'LSTM':
      return tf.layers.lstm(config);
    default:
      throw new Error(`Unsupported RNN type: ${rnnType}`);
  }
}

/**
 * Builds and compiles the sequence-to-sequence addition model.
 *
 * Layer order: a recurrent layer over the one-hot query
 * (`inputShape: [digits + 1 + digits, vocabularySize]`), a RepeatVector with
 * `n = digits + 1`, a second recurrent layer with `returnSequences: true`,
 * a time-distributed dense layer of `vocabularySize` units, and a softmax
 * activation. The model is compiled with categorical cross-entropy, the
 * 'adam' optimizer and the 'accuracy' metric.
 *
 * @param {number} hiddenSize - Units in each recurrent layer.
 * @param {string} rnnType - 'SimpleRNN', 'GRU' or 'LSTM'.
 * @param {number} digits - Digits per operand.
 * @param {number} vocabularySize - Number of distinct characters.
 * @returns {tf.Sequential} The compiled model.
 * @throws {Error} If `rnnType` is not supported (previously this silently
 *     produced a model without recurrent layers).
 */
function createAndCompileModel(hiddenSize, rnnType, digits, vocabularySize) {
  const maxLen = digits + 1 + digits;
  const model = tf.sequential();

  model.add(createRnnLayer(rnnType, {
    units: hiddenSize,
    recurrentInitializer: 'glorotNormal',
    inputShape: [maxLen, vocabularySize],
  }));
  // One copy of the first layer's output per output character position.
  model.add(tf.layers.repeatVector({n: digits + 1}));
  model.add(createRnnLayer(rnnType, {
    units: hiddenSize,
    recurrentInitializer: 'glorotNormal',
    returnSequences: true,
  }));
  model.add(tf.layers.timeDistributed({
    layer: tf.layers.dense({units: vocabularySize}),
  }));
  model.add(tf.layers.activation({activation: 'softmax'}));

  model.compile({
    loss: 'categoricalCrossentropy',
    optimizer: 'adam',
    metrics: ['accuracy'],
  });
  return model;
}
|
使用tensorflow.js的三种RNN接口即可。
4、训练RNN
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
| async train(iterations, batchSize, numTestExamples) { const trainLossArr = []; const valLossArr = []; const trainAccuracyArr = []; const valAccuracyArr = []; const examplesPerSecArr = []; for (let i = 0; i < iterations; ++i) { const beginMs = performance.now(); const history = await this.model.fit(this.trainXs, this.trainYs, { epochs: 1, batchSize, validationData: [this.testXs, this.testYs], }); const elapsedMs = performance.now() - beginMs; const examplesPerSec = this.testXs.shape[0] / (elapsedMs / 1000); const trainLoss = history.history['loss'][0]; const trainAccuracy = history.history['acc'][0]; const valLoss = history.history['val_loss'][0]; const valAccuracy = history.history['val_acc'][0]; document.getElementById('trainStatus').textContent = `Iteration ${i}: train loss = ${trainLoss.toFixed(6)}; ` + `train accuracy = ${trainAccuracy.toFixed(6)}; ` + `validation loss = ${valLoss.toFixed(6)}; ` + `validation accuracy = ${valAccuracy.toFixed(6)} ` + `(${examplesPerSec.toFixed(2)} examples/s)`; trainLossArr.push([i, trainLoss]); valLossArr.push([i, valLoss]); trainAccuracyArr.push([i, trainAccuracy]); valAccuracyArr.push([i, valAccuracy]); examplesPerSecArr.push([i, examplesPerSec]); await tf.nextFrame(); } }
|
5、可视化
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
// Keep a cached slice holding the first numTestExamples test inputs. It is
// rebuilt only when no slice exists yet or the requested count has changed.
if (this.testXsForDisplay == null ||
    this.testXsForDisplay.shape[0] !== numTestExamples) {
  // Release the previous slice before replacing it. (The original checked the
  // misspelled `textXsForDisplay`, so the old tensor was never disposed.)
  if (this.testXsForDisplay) {
    this.testXsForDisplay.dispose();
  }
  this.testXsForDisplay = this.testXs.slice(
      [0, 0, 0],
      [numTestExamples, this.testXs.shape[1], this.testXs.shape[2]]);
}

const examples = [];
const isCorrect = [];
// tidy() disposes the prediction tensor and the per-example slices.
tf.tidy(() => {
  const predictOut = this.model.predict(this.testXsForDisplay);
  for (let j = 0; j < numTestExamples; ++j) {
    // Extract example j as a 2D tensor of per-position character scores.
    const scores =
        predictOut
            .slice([j, 0, 0], [1, predictOut.shape[1], predictOut.shape[2]])
            .as2D(predictOut.shape[1], predictOut.shape[2]);
    const decoded = this.charTable.decode(scores);
    examples.push(`${this.testData[j][0]} = ${decoded}`);
    // Compare without the trailing padding spaces.
    isCorrect.push(this.testData[j][1].trim() === decoded.trim());
  }
});

// Replace the previously rendered examples with the new predictions.
const examplesDiv = document.getElementById('testExamples');
while (examplesDiv.firstChild) {
  examplesDiv.removeChild(examplesDiv.firstChild);
}
for (let i = 0; i < examples.length; ++i) {
  const exampleDiv = document.createElement('div');
  exampleDiv.textContent = examples[i];
  exampleDiv.className = isCorrect[i] ? 'answer-correct' : 'answer-wrong';
  examplesDiv.appendChild(exampleDiv);
}
|
选择numTestExamples个测试数据展现在文档上,以展现训练过程。预测错误标为红色,预测正确标为绿色。将RNN的输出转换为数字需要字符集类CharacterTable的解码方法。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
| class CharacterTable { ... decode(x, calcArgmax = true) { return tf.tidy(() => { if (calcArgmax) { x = x.argMax(1); } const xData = x.dataSync(); let output = ''; for (const index of Array.from(xData)) { output += this.indicesChar[index]; } return output; }); } }
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
// Small builders for the pieces every chart below has in common.
const chartFor = (elementId) => echarts.init(document.getElementById(elementId));
const valueAxis = () => ({ type: 'value' });
const lineSeries = (name, data) => ({ name, type: 'line', symbol: 'none', data });

const lossCanvas = chartFor('lossCanvas');
const accuracyCanvas = chartFor('accuracyCanvas');
const examplesPerSecCanvas = chartFor('examplesPerSecCanvas');

// Training vs. validation loss per iteration.
lossCanvas.setOption({
  title: { text: 'Loss Values' },
  xAxis: valueAxis(),
  yAxis: valueAxis(),
  series: [
    lineSeries('trainLoss', trainLossArr),
    lineSeries('valLoss', valLossArr),
  ],
});

// Training vs. validation accuracy per iteration.
accuracyCanvas.setOption({
  title: { text: 'Accuracy Values' },
  xAxis: valueAxis(),
  yAxis: valueAxis(),
  series: [
    lineSeries('trainAccuracy', trainAccuracyArr),
    lineSeries('valAccuracy', valAccuracyArr),
  ],
});

// Throughput per iteration; a single series passed as an object.
examplesPerSecCanvas.setOption({
  title: { text: 'Examples Per Second' },
  grid: { left: '12%' },
  xAxis: valueAxis(),
  yAxis: valueAxis(),
  series: lineSeries('examplesPerSec', examplesPerSecArr),
});
|
用三张折线图展现训练损失、验证损失、训练准确率、验证准确率,以及每秒处理的样本数。index.html如下。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
<!DOCTYPE html>
<html>
<head>
  <meta charset="UTF-8">
  <title>Document</title>
  <style>
    body {
      margin-top: 50px;
      margin-left: 50px;
    }
    .setting {
      padding-bottom: 6px;
    }
    .setting-label {
      display: inline-block;
      width: 12em;
    }
    .answer-correct {
      color: green;
    }
    .answer-wrong {
      color: red;
    }
    .canvases {
      display: inline-block;
      width: 300px;
      height: 300px;
    }
  </style>
</head>
<body>
  <h1>RNN</h1>
  <div>
    <!-- Training settings; the ids are read by the script on button click. -->
    <!-- <input> is a void element, so it takes no closing tag. -->
    <div>
      <div class="setting">
        <span class="setting-label">Digits:</span>
        <input id="digits" value="2">
      </div>
      <div class="setting">
        <span class="setting-label">Training Size:</span>
        <input id="trainingSize" value="5000">
      </div>
      <div class="setting">
        <span class="setting-label">RNN Type:</span>
        <select id="rnnType">
          <option value="SimpleRNN">SimpleRNN</option>
          <option value="GRU">GRU</option>
          <option value="LSTM">LSTM</option>
        </select>
      </div>
      <div class="setting">
        <span class="setting-label">RNN Hidden Layer Size:</span>
        <input id="rnnLayerSize" value="128">
      </div>
      <div class="setting">
        <span class="setting-label">Batch Size:</span>
        <input id="batchSize" value="128">
      </div>
      <div class="setting">
        <span class="setting-label">Train Iterations:</span>
        <input id="trainIterations" value="100">
      </div>
      <div class="setting">
        <span class="setting-label"># of test examples:</span>
        <input id="numTestExamples" value="20">
      </div>
    </div>
    <button id="trainModel">Train Model</button>
    <div id="trainStatus"></div>
    <!-- Chart containers -->
    <div>
      <div class="canvases" id="lossCanvas"></div>
      <div class="canvases" id="accuracyCanvas"></div>
      <div class="canvases" id="examplesPerSecCanvas"></div>
    </div>
    <!-- Sample predictions, coloured by the answer-correct / answer-wrong classes -->
    <div id="testExamples"></div>
  </div>
  <!-- Scripts must be inside <body>; placed last so the elements above exist when it runs. -->
  <script src="bundle.js"></script>
</body>
</html>
|
6、获取参数
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
/**
 * Wires the "Train Model" button: on click, reads the settings from the form,
 * builds an RNN demo and trains it.
 *
 * While a run is in progress the button is disabled so that a second click
 * cannot start an overlapping training run. Invalid settings and training
 * errors are logged and shown in the #trainStatus element instead of becoming
 * unhandled promise rejections.
 *
 * @returns {Promise<void>} Resolves once the click handler is registered.
 */
async function runRNN() {
  const trainButton = document.getElementById('trainModel');
  const statusDiv = document.getElementById('trainStatus');

  // Reads a form field that must contain an integer >= min.
  const readInt = (id, min) => {
    const value = Number(document.getElementById(id).value);
    if (!Number.isInteger(value) || value < min) {
      throw new RangeError(`"${id}" must be an integer >= ${min}`);
    }
    return value;
  };

  trainButton.addEventListener('click', async () => {
    trainButton.disabled = true;
    try {
      const digits = readInt('digits', 1);
      const trainingSize = readInt('trainingSize', 1);
      const rnnType = document.getElementById('rnnType').value;
      const hiddenSize = readInt('rnnLayerSize', 1);
      const trainIterations = readInt('trainIterations', 0);
      const batchSize = readInt('batchSize', 1);
      const numTestExamples = readInt('numTestExamples', 0);

      const demo = new RNN(digits, trainingSize, rnnType, hiddenSize);
      await demo.train(trainIterations, batchSize, numTestExamples);
    } catch (err) {
      console.error(err);
      statusDiv.textContent = `Training failed: ${err.message}`;
    } finally {
      trainButton.disabled = false;
    }
  });
}

runRNN();
|
结果如下,准确率97.9%。

完整程序见我的github。