[关闭]
@nataliecai1988 2017-09-08T03:16:32.000000Z 字数 5525 阅读 778

深度学习利器:TensorFlow在智能终端中的应用

投稿


作者:武维

前言

深度学习在图像处理、语音识别、自然语言处理领域的应用取得了巨大成功,但是它通常在功能强大的服务器端进行运算。如果智能手机通过网络远程连接服务器,也可以利用深度学习技术,但这样可能会很慢,而且只有在设备处于良好的网络连接环境下才行,这就需要把深度学习模型迁移到智能终端。

由于智能终端CPU和内存资源有限,为了提高运算性能和内存利用率,需要对服务器端的模型进行量化处理并支持低精度算法。TensorFlow版本增加了对Android、iOS和Raspberry Pi硬件平台的支持,允许它在这些设备上执行图像分类等操作。这样就可以创建在智能手机上工作并且不需要云端每时每刻都支持的机器学习模型,带来了新的APP。

本文主要基于看花识名APP应用,讲解TensorFlow模型如何应用于Android系统;在服务器端训练TensorFlow模型,并把模型文件迁移到智能终端;TensorFlow Android开发环境构建以及应用开发API。

看花识名APP

使用AlexNet模型、Flowers数据以及Android平台构建了“看花识名”APP。TensorFlow模型对五种类型的花数据进行训练。如下图所示:

Daisy:雏菊
image

Dandelion:蒲公英
image

Roses:玫瑰
image

Sunflowers:向日葵
image

Tulips:郁金香
image

在服务器上把模型训练好后,把模型文件迁移到Android平台,在手机上安装APP。使用效果如下图所示,界面上端显示的是模型识别的置信度,界面中间是要识别的花:

image

TensorFlow模型如何应用于看花识名APP中,主要包括以下几个关键步骤:模型选择和应用、模型文件转换以及Android开发。如下图所示:

image

image

模型训练及模型文件

本章采用AlexNet模型对Flowers数据进行训练。AlexNet在2012取得了ImageNet最好成绩,top 5准确率达到80.2%。这对于传统的机器学习分类算法而言,已经相当出色。模型结构如下:

image

本文采用TensorFlow官方Slim(https://github.com/tensorflow/models/tree/master/slim)AlexNet模型进行训练。

  1. DATA_DIR=/tmp/data/flowers
  2. python download_and_convert_data.py --dataset_name=flowers --dataset_dir="${DATA_DIR}"
  1. TRAIN_DIR=/tmp/data/train
  2. python train_image_classifier.py --train_dir=${TRAIN_DIR} --dataset_dir=${DATASET_DIR} --dataset_name=flowers --dataset_split_name=train --model_name=alexnet_v2 --preprocessing_name=vgg
  1. python export_inference_graph.py --alsologtostderr --model_name=alexnet_v2 --dataset_name=flowers --dataset_dir=${DATASET_DIR} --output_file=alexnet_v2_inf_graph.pb
  1. python freeze_graph.py --input_graph=alexnet_v2_inf_graph.pb --input_checkpoint= ${TRAIN_DIR}/model.ckpt-36618 --input_binary=true --output_graph=frozen_alexnet_v2.pb --output_node_names=alexnet_v2/fc8/squeezed
  1. bazel-bin/tensorflow/tools/graph_transforms/transform_graph --in_graph=frozen_alexnet_v2.pb --outputs="alexnet_v2/fc8/squeezed" --out_graph=quantized_alexnet_v2_graph.pb --transforms='add_default_attributes strip_unused_nodes(type=float, shape="1,224,224,3") remove_nodes(op=Identity, op=CheckNumerics) fold_constants(ignore_errors=true) fold_batch_norms fold_old_batch_norms quantize_weights quantize_nodes strip_unused_nodes sort_by_execution_order'

为了减少智能终端上模型文件的大小,TensorFlow中常用的方法是对模型文件进行量化处理,本文对AlexNet CheckPoint文件进行Freeze和Quantized处理后的文件大小变化如下图所示:

image

量化操作的主要思想是在模型的Inference阶段采用等价的8位整数操作代替32位的浮点数操作,替换的操作包括:卷积操作、矩阵相乘、激活函数、池化操作等。量化节点的输入、输出为浮点数,但是内部运算会通过量化计算转换为8位整数(范围为0到255)的运算,浮点数和8位量化整数的对应关系示例如下图所示:

image

量化Relu操作的基本思想如下图所示:

image

TensorFlow Android应用开发环境构建

在Android系统上使用TensorFlow模型做Inference依赖于两个文件libtensorflow_inference.so和libandroid_tensorflow_inference_java.jar。这两个文件可以通过下载TensorFlow源代码后,采用bazel编译出来,如下所示:

git clone --recurse-submodules https://github.com/tensorflow/tensorflow.git

https://developer.android.com/ndk/downloads/older_releases.html#ndk-12b-downloads

https://developer.android.com/studio/command-line/sdkmanager.html

  1. android_sdk_repository(name = "androidsdk", api_level = 23, build_tools_version = "25.0.2", path = "/opt/android",)
  2. android_ndk_repository(name="androidndk", path="/opt/android/android-ndk-r12b", api_level=14)
  1. bazel build -c opt //tensorflow/contrib/android:libtensorflow_inference.so --crosstool_top=//external:android/crosstool --host_crosstool_top=@bazel_tools//tools/cpp:toolchain --cpu=armeabi-v7a
  1. bazel build //tensorflow/contrib/android:android_tensorflow_inference_java

TensorFlow在https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android下提供了Android开发的示例框架,下面基于AlexNet模型的看花识名APP做一些相应源码的修改,并编译生成Android的安装包:

  1. private static final String INPUT_NAME = "input";
  2. private static final String OUTPUT_NAME = "alexnet_v2/fc8/squeezed";
  1. private static final String MODEL_FILE = "file:///android_asset/quantized_alexnet_v2_graph.pb";
  2. private static final String LABEL_FILE = "file:///android_asset/labels.txt";
  1. bazel build -c opt //tensorflow/examples/android:tensorflow_demo

image

TensorFlow移动端应用开发API

在Android系统中执行TensorFlow Inference操作,需要调用libandroid_tensorflow_inference_java.jar中的JNI接口,主要接口如下:

  1. TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
  1. inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
  1. inferenceInterface.run(outputNames);
  1. inferenceInterface.fetch(outputName, outputs);

总结

本文基于看花识名APP,讲解了TensorFlow在Android智能终端中的应用技术。首先回顾了AlexNet模型结构,基于AlexNet的slim模型对Flowers数据进行训练;对训练后的CheckPoint数据,进行Freeze和Quantized处理,生成智能终端要用的Inference模型。然后介绍了TensorFlow Android应用开发环境的构建,编译生成TensorFlow在Android上的动态链接库以及java开发包;文章最后介绍了Inference API的使用方式。

参考文献

[1] http://www.tensorflow.org

[2] 深度学习利器:分布式TensorFlow及实例分析

[3] 深度学习利器:TensorFlow使用实战

[4] 深度学习利器:TensorFlow系统架构与高性能程序设计

[5] 深度学习利器:TensorFlow与深度卷积神经网络

[6] 深度学习利器:TensorFlow与NLP模型

作者简介

武维(微信:3381209@qq.com):博士,系统架构师,主要从事大数据,深度学习,云计算等领域的研发工作。

添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注