Fork me on GitHub

Android上运行手写数字识别模型

Github源码请移步本文底部。

模型导出pb文件

首先我们需要在我们的python代码中保存训练好的模型,save_path参数就传递**.pb,这里导出文件留给接下来使用

1
2
3
4
def save_model_and_params(session, save_path):
out_graph_def = tf.graph_util.convert_variables_to_constants(session, session.graph_def, ["output"])
with tf.gfile.FastGFile(save_path, 'wb') as file:
file.write(out_graph_def.SerializeToString())

Android中通过JNI调用

Tensorflow与Android整合

整合部分就直接按照Android端运行Tensorflow中的步骤来就行了。

封装输出数据解析逻辑

在手写数字识别模型中的输出是一个size为10的列表,列表元素的索引值对应输出的结果,列表元素对应输出的概率,例如输出是[0.2, 0.7, 0.01……],即表示有0.2的概率是0,0.7的概率是1,0.01的概率是2……

因此我们需要在输出中对数据按照概率进行降序排列,以便让结果一目了然。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
/**
* @author zijiao
* @version 17/8/2
*/
public class MnistData {

private final List<Item> items = new ArrayList<>(10);

public MnistData(float[] data) {
for (int i = 0; i < data.length; i++) {
items.add(new Item(data[i], i));
}
Collections.sort(items);
}

public String top(int topSize) {
StringBuilder builder = new StringBuilder();
for (int i = 0; i < topSize; i++) {
Item item = items.get(i);
builder.append(item.index)
.append(": ")
.append(String.format("%.1f%%", item.value * 100))
.append("\n");
}
return builder.toString();
}

public String output() {
return String.valueOf(items.get(0).index);
}

@Override
public String toString() {
return output();
}

@SuppressWarnings("NullableProblems")
private static class Item implements Comparable<Item> {
final float value;
final float index;

private Item(float value, float index) {
this.value = value;
this.index = index;
}

@Override
public int compareTo(Item o) {
return value < o.value ? 1 : -1;
}
}

}

这时我们就能通过MnistData类的top方法得到概率最大的几个结果分别是什么。

构建数字分类器

这里通过TensorFlowInferenceInterface来调用模型,注释写得很清楚,值得注意的一点是,inputoutput的名称要和模型中的变量名称保持一致。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
/**
* @author zijiao
* @version 17/8/2
*/
public class MnistClassifier {

private final TensorFlowInferenceInterface inference;

public MnistClassifier(AssetManager assetManager) {
inference = new TensorFlowInferenceInterface();
// 加载模型图
inference.initializeTensorFlow(assetManager, TF.MODEL);
// 模型使用阶段, 不需要进行dropout处理, 所以keep_prob直接为1.0
inference.fillNodeFloat(TF.KEEP_PROB_NAME, new int[]{1}, new float[]{1.0f});
}

public MnistData inference(float[] input) {
if (input == null || input.length != 28 * 28) {
throw new RuntimeException("Input data is error.");
}
// 填入Input数据
inference.fillNodeFloat(TF.INPUT_NAME, TF.INPUT_TYPE, input);
// 运行结果, 类似Python中的sess.run([outputs])
inference.runInference(new String[]{TF.OUTPUT_NAME});
float[] output = new float[10];
// 取出结果集中我们需要的
inference.readNodeFloat(TF.OUTPUT_NAME, output);
// 将输出结果交给MnistData处理
return new MnistData(output);
}

}

添加画板

模型处理的逻辑已经写完了,接下来就是如何得到输入源了。由于是手写数字识别,所以接下来就要写画板类。这里只贴出关键代码部分(完整代码可以看本文底部的Github地址)。

手指滑动屏幕时画出手指滑动的轨迹

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
@Override
protected void onDraw(Canvas canvas) {
super.onDraw(canvas);
canvas.drawPath(path, paint);
}

@Override
public boolean onTouchEvent(MotionEvent event) {
float x = event.getX();
float y = event.getY();
switch (event.getAction()) {
case MotionEvent.ACTION_DOWN:
path.moveTo(x, y);
break;
case MotionEvent.ACTION_MOVE:
path.lineTo(x, y);
break;
}
invalidate();
return true;
}

向外部提供读取画布数据的方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
public float[] fetchData(int width, int height) {
float[] data = new float[height * width];
try {
setDrawingCacheEnabled(true);
setDrawingCacheQuality(View.DRAWING_CACHE_QUALITY_LOW);
Bitmap cache = getDrawingCache();
fillInputData(cache, data, width, height);
} finally {
setDrawingCacheEnabled(false);
}
return data;
}

private void fillInputData(Bitmap bm, float[] data, int newWidth, int newHeight) {
// 获得图片的宽高
int width = bm.getWidth();
int height = bm.getHeight();
// 计算缩放比例
float scaleWidth = ((float) newWidth) / width;
float scaleHeight = ((float) newHeight) / height;
// 取得想要缩放的matrix参数
Matrix matrix = new Matrix();
matrix.postScale(scaleWidth, scaleHeight);
// 得到新的图片
Bitmap newbm = Bitmap.createBitmap(bm, 0, 0, width, height, matrix, true);
for (int y = 0; y < newHeight; y++) {
for (int x = 0; x < newWidth; x++) {
int pixel = newbm.getPixel(x, y);
data[newWidth * y + x] = pixel == 0xffffffff ? 0 : 1;
}
}
}

运行测试

布局代码就直接省略了,我们只需要在点击识别的时候,调用下面这段的识别逻辑即可。

1
2
3
4
5
6
7
8
9
// 识别
public void onInference(View view) {
if (canvasView.isEmpty()) {
resultPanel.setText("画板为空");
return;
}
MnistData result = classifier.inference(canvasView.fetchData(28, 28));
resultPanel.setText(result.top(3));
}

最后附上运行效果图

这里是该项目的Github源码


------------- The end -------------