笔记(2)中利用 TensorFlow.js 实现了一个经典的机器学习问题——用 CNN 识别 MNIST 手写数字集。本篇笔记将利用迁移学习,通过 Web 摄像头采集图像并判断上、下、左、右四个方向,以此来玩吃豆人(Pac-Man)游戏。本文参考了官方示例 “Transfer learning - Train a neural network to predict from webcam data”,并修改了部分代码。

1、首先加载已训练好的 MobileNet 模型,并截取其中间层作为特征提取器

/**
 * Loads a pre-trained MobileNet and truncates it at an intermediate layer so
 * the result can be used as a feature extractor for transfer learning.
 *
 * @param {string} [modelUrl='./model.json'] - URL of the model JSON passed to
 *     tf.loadModel.
 * @param {string} [layerName='conv_pw_13_relu'] - name of the layer whose
 *     output becomes the truncated model's output.
 * @returns {Promise<tf.Model>} a model that takes MobileNet's original inputs
 *     and outputs the activations of `layerName`.
 */
async function loadMobilenet(modelUrl = './model.json', layerName = 'conv_pw_13_relu') {
  const mobilenet = await tf.loadModel(modelUrl);
  const layer = mobilenet.getLayer(layerName);
  // Reuse the original input tensors so the new model shares MobileNet's
  // weights up to (and including) the chosen layer.
  return tf.model({inputs: mobilenet.inputs, outputs: layer.output});
}

其中,函数返回的 tf.model 以 MobileNet 的原始输入作为输入,以 MobileNet 中名为 “conv_pw_13_relu” 的层(即最后一个卷积块)的输出作为输出。一般而言,网络中越靠后的层提取的特征越抽象、越接近高层语义,因此做迁移学习时通常选取已训练模型中较靠后的层;但一般不直接使用最终的分类层,因为它的输出与原模型的分类任务绑定得过紧,难以迁移到新任务上。

2、定义摄像头类 Webcam

webcam.js文件内容如下。

import * as tf from '@tensorflow/tfjs';  

/**
 * Wraps a <video> element that shows the webcam stream and turns its current
 * frame into tensors suitable for a MobileNet-style model.
 */
export class Webcam {
  /**
   * @param {HTMLVideoElement} webcamElement - the <video> element that displays
   *     the webcam stream and is used as the pixel source by capture().
   */
  constructor(webcamElement) {
    this.webcamElement = webcamElement;
  }

  /**
   * Captures the current video frame as a batched tensor.
   *
   * The frame is center-cropped to a square, given a leading batch dimension,
   * and rescaled as x / 127 - 1.
   * @returns {tf.Tensor} tensor of shape [1, size, size, 3], where size is the
   *     shorter side of the frame.
   */
  capture() {
    // tidy() disposes every intermediate tensor created in the callback
    // except the returned one.
    return tf.tidy(() => {
      const webcamImage = tf.fromPixels(this.webcamElement);
      const croppedImage = this.cropImage(webcamImage);
      const batchedImage = croppedImage.expandDims(0);
      // NOTE(review): MobileNet's usual preprocessing is x / 127.5 - 1; the
      // divisor 127 is kept unchanged to preserve existing behavior.
      return batchedImage.toFloat().div(tf.scalar(127)).sub(tf.scalar(1));
    });
  }

  /**
   * Cuts the largest centered square out of an image tensor.
   *
   * @param {tf.Tensor3D} img - image of shape [height, width, channels]; the
   *     first 3 channels are kept.
   * @returns {tf.Tensor3D} tensor of shape [size, size, 3], size = min(height, width).
   */
  cropImage(img) {
    const [height, width] = img.shape;
    const size = Math.min(height, width);
    // Floor the offsets: when (height - width) is odd, halving the difference
    // yields x.5, which is not a valid integer slice index.
    const beginHeight = Math.floor((height - size) / 2);
    const beginWidth = Math.floor((width - size) / 2);
    return img.slice([beginHeight, beginWidth, 0], [size, size, 3]);
  }

  /**
   * Resizes the <video> element to the stream's aspect ratio.
   *
   * For landscape (or square) streams the element's current height is kept and
   * its width is recomputed; for portrait streams the width is kept and the
   * height is recomputed.
   * @param {number} width - stream width in pixels.
   * @param {number} height - stream height in pixels.
   */
  adjustVideoSize(width, height) {
    const aspectRatio = width / height;
    if (width >= height) {
      this.webcamElement.width = aspectRatio * this.webcamElement.height;
    } else if (width < height) {
      this.webcamElement.height = this.webcamElement.width / aspectRatio;
    }
  }

  /**
   * Requests a video stream from the browser.
   *
   * Prefers the promise-based navigator.mediaDevices.getUserMedia and falls
   * back to the legacy (possibly vendor-prefixed) callback API.
   * @returns {Promise<MediaStream>}
   */
  requestVideoStream() {
    const constraints = {video: true};
    const mediaDevices = navigator.mediaDevices;
    if (mediaDevices && mediaDevices.getUserMedia) {
      return mediaDevices.getUserMedia(constraints);
    }
    const legacyGetUserMedia = navigator.getUserMedia ||
        navigator.webkitGetUserMedia ||
        navigator.mozGetUserMedia ||
        navigator.msGetUserMedia;
    if (!legacyGetUserMedia) {
      return Promise.reject(new Error('getUserMedia is not supported in this browser'));
    }
    return new Promise((resolve, reject) => {
      legacyGetUserMedia.call(navigator, constraints, resolve, reject);
    });
  }

  /**
   * Connects a MediaStream to the <video> element.
   *
   * @param {MediaStream} stream
   */
  attachStream(stream) {
    const video = this.webcamElement;
    if ('srcObject' in video) {
      video.srcObject = stream;
    } else {
      // Legacy browsers without srcObject only accept an object URL; modern
      // browsers no longer accept a MediaStream in createObjectURL.
      video.src = window.URL.createObjectURL(stream);
    }
  }

  /**
   * Starts the webcam and waits until the first frame has loaded, then sizes
   * the <video> element to the stream's aspect ratio.
   *
   * On failure, the page's #no-webcam element (if present) is made visible and
   * the returned promise rejects with the underlying error.
   * @returns {Promise<void>}
   */
  async setup() {
    let stream;
    try {
      stream = await this.requestVideoStream();
    } catch (error) {
      const notice = document.querySelector('#no-webcam');
      if (notice) {
        notice.style.display = 'block';
      }
      // Propagate so callers awaiting setup() do not hang forever.
      throw error;
    }

    const video = this.webcamElement;
    // Register the listener before attaching the stream so the event cannot
    // be missed.
    const loaded = new Promise(resolve => {
      video.addEventListener('loadeddata', () => resolve(), {once: true});
    });
    this.attachStream(stream);
    await loaded;
    this.adjustVideoSize(video.videoWidth, video.videoHeight);
  }
}

其中,构造器的参数为 DOM 中的 `<video>` 元素,实例化方式如下:

// Wrap the page's <video id="webcam"> element; it displays the live stream
// and is the pixel source for Webcam#capture().
const webcam = new Webcam(document.querySelector('#webcam'));

(未完待续)