深度学习利器:TensorFlow在智能终端中的应用

作者: 武维 2017-09-21 12:29:58

深度学习利器:TensorFlow在智能终端中的应用

前言

深度学习在图像处理、语音识别、自然语言处理领域的应用取得了巨大成功,但是它通常在功能强大的服务器端进行运算。如果智能手机通过网络远程连接服务器,也可以利用深度学习技术,但这样可能会很慢,而且只有在设备处于良好的网络连接环境下才行,这就需要把深度学习模型迁移到智能终端。

由于智能终端CPU和内存资源有限,为了提高运算性能和内存利用率,需要对服务器端的模型进行量化处理并支持低精度算法。TensorFlow版本增加了对Android、iOS和Raspberry Pi硬件平台的支持,允许它在这些设备上执行图像分类等操作。这样就可以创建在智能手机上工作并且不需要云端每时每刻都支持的机器学习模型,带来了新的APP。

本文主要基于看花识名APP应用,讲解TensorFlow模型如何应用于Android系统;在服务器端训练TensorFlow模型,并把模型文件迁移到智能终端;TensorFlow Android开发环境构建以及应用开发API。

看花识名APP

使用AlexNet模型、Flowers数据以及Android平台构建了“看花识名”APP。TensorFlow模型对五种类型的花数据进行训练。如下图所示:

Daisy:雏菊

Dandelion:蒲公英

Roses:玫瑰

Sunflowers:向日葵

Tulips:郁金香

在服务器上把模型训练好后,把模型文件迁移到Android平台,在手机上安装APP。使用效果如下图所示,界面上端显示的是模型识别的置信度,界面中间是要识别的花:

TensorFlow模型如何应用于看花识名APP中,主要包括以下几个关键步骤:模型选择和应用、模型文件转换以及Android开发。如下图所示:

模型训练及模型文件

本章采用AlexNet模型对Flowers数据进行训练。AlexNet在2012取得了ImageNet***成绩,top 5准确率达到80.2%。这对于传统的机器学习分类算法而言,已经相当出色。模型结构如下:

本文采用TensorFlow官方Slim(https://github.com/tensorflow/models/tree/master/slim)AlexNet模型进行训练。

  • 首先下载Flowers数据,并转换为TFRecord格式:
  1. DATA_DIR=/tmp/data/flowers 
  2. python download_and_convert_data.py --dataset_name=flowers 
  3.  --dataset_dir="${DATA_DIR}"  
  • 执行模型训练,经过36618次迭代后,模型精度达到85%
  1. TRAIN_DIR=/tmp/data/train 
  2. python train_image_classifier.py --train_dir=${TRAIN_DIR}  
  3. --dataset_dir=${DATASET_DIR} --dataset_name=flowers   
  4. --dataset_split_name=train  --model_name=alexnet_v2  
  5.  --preprocessing_name=vgg  
  • 生成Inference Graph的PB文件
  1. python export_inference_graph.py  --alsologtostderr   
  2. --model_name=alexnet_v2  --dataset_name=flowers --dataset_dir=${DATASET_DIR}  
  3.  --output_file=alexnet_v2_inf_graph.pb  
  • 结合CheckPoint文件和Inference GraphPB文件,生成Freeze Graph的PB文件 
  1. python freeze_graph.py  --input_graph=alexnet_v2_inf_graph.pb  
  2. --input_checkpoint= ${TRAIN_DIR}/model.ckpt-36618  --input_binary=true  
  3. --output_graph=frozen_alexnet_v2.pb --output_node_names=alexnet_v2/fc8/squeezed 
  • 对Freeze Graph的PB文件进行数据量化处理,减少模型文件的大小,生成的quantized_alexnet_v2_graph.pb为智能终端中应用的模型文件
  1. bazel-bin/tensorflow/tools/graph_transforms/transform_graph   
  2. --in_graph=frozen_alexnet_v2.pb  --outputs="alexnet_v2/fc8/squeezed"  
  3. --out_graph=quantized_alexnet_v2_graph.pb --transforms='add_default_attributes 
  4.  strip_unused_nodes(type=float, shape="1,224,224,3")  remove_nodes(op=Identity,  
  5. op=CheckNumerics) fold_constants(ignore_errors=true)  fold_batch_norms  
  6. fold_old_batch_norms quantize_weights quantize_nodes  
  7.  strip_unused_nodes sort_by_execution_order'  

为了减少智能终端上模型文件的大小,TensorFlow中常用的方法是对模型文件进行量化处理,本文对AlexNet CheckPoint文件进行Freeze和Quantized处理后的文件大小变化如下图所示:

量化操作的主要思想是在模型的Inference阶段采用等价的8位整数操作代替32位的浮点数操作,替换的操作包括:卷积操作、矩阵相乘、激活函数、池化操作等。量化节点的输入、输出为浮点数,但是内部运算会通过量化计算转换为8位整数(范围为0到255)的运算,浮点数和8位量化整数的对应关系示例如下图所示:

量化Relu操作的基本思想如下图所示:

TensorFlow Android应用开发环境构建

在Android系统上使用TensorFlow模型做Inference依赖于两个文件libtensorflow_inference.so和libandroid_tensorflow_inference_java.jar。这两个文件可以通过下载TensorFlow源代码后,采用bazel编译出来,如下所示:

  1. android_sdk_repository(name = "androidsdk", api_level = 23, build_tools_version = "25.0.2", path = "/opt/android",) 
  2. android_ndk_repository(name="androidndk",  path="/opt/android/android-ndk-r12b",  api_level=14)  
  • 编译libtensorflow_inference.so
  1. bazel build -c opt //tensorflow/contrib/android:libtensorflow_inference.so   
  2.   --crosstool_top=//external:android/crosstool --host_crosstool_top= 
  3. @bazel_tools//tools/cpp:toolchain --cpu=armeabi-v7a  
  • 编译libandroid_tensorflow_inference_java.jar
  1. bazel build //tensorflow/contrib/android:android_tensorflow_inference_java 

TensorFlow提供了Android开发的示例框架,下面基于AlexNet模型的看花识名APP做一些相应源码的修改,并编译生成Android的安装包:

  • 基于AlexNet模型,修改Inference的输入、输出的Tensor名称
  1. private static final String INPUT_NAME = "input"
  2.  
  3. private static final String OUTPUT_NAME = "alexnet_v2/fc8/squeezed" 
  • 放置quantized_alexnet_v2_graph.pb和对应的labels.txt文件到assets目录下,并修改Android文件路径
  1. private static final String MODEL_FILE = "file:///android_asset/quantized_alexnet_v2_graph.pb"
  2.  
  3. private static final String LABEL_FILE = "file:///android_asset/labels.txt" 
  • 编译生成安装包
  1. bazel build -c opt //tensorflow/examples/android:tensorflow_demo 
  • 拷贝tensorflow_demo.apk到手机上,并执行安装,太阳花识别效果如下图所示:(点击放大图像)

TensorFlow移动端应用开发API

在Android系统中执行TensorFlow Inference操作,需要调用libandroid_tensorflow_inference_java.jar中的JNI接口,主要接口如下:

  • 构建TensorFlow Inference对象,构建该对象时候会加载TensorFlow动态链接库libtensorflow_inference.so到系统中;参数assetManager为android asset管理器;参数modelFilename为TensorFlow模型文件在android_asset中的路径。
  1. TensorFlowInferenceInterface inferenceInterface = new 
  2.  
  3. TensorFlowInferenceInterface(assetManager, modelFilename);  
  • 向TensorFlow图中加载输入数据,本App中输入数据为摄像头截取到的图片;参数inputName为TensorFlow Inference中的输入数据Tensor的名称;参数floatValues为输入图片的像素数据,进行预处理后的浮点值;[1,inputSize,inputSize,3]为裁剪后图片的大小,比如1张224*224*3的RGB图片。
  1. inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3); 
  • 执行模型推理; outputNames为TensorFlow Inference模型中要运算Tensor的名称,本APP中为分类的Logist值。
  1. inferenceInterface.run(outputNames); 
  • 获取模型Inference的运算结果,其中outputName为Tensor名称,参数outputs存储Tensor的运算结果。本APP中,outputs为计算得到的Logist浮点数组。
  1. inferenceInterface.fetch(outputName, outputs); 

总结

本文基于看花识名APP,讲解了TensorFlow在Android智能终端中的应用技术。首先回顾了AlexNet模型结构,基于AlexNet的slim模型对Flowers数据进行训练;对训练后的CheckPoint数据,进行Freeze和Quantized处理,生成智能终端要用的Inference模型。然后介绍了TensorFlow Android应用开发环境的构建,编译生成TensorFlow在Android上的动态链接库以及java开发包;文章***介绍了Inference API的使用方式。

参考文献

深度学习 TensorFlow 智能终端
上一篇:学习机器学习前,你首先要掌握这些概率论基础知识 下一篇:从贝叶斯定理到概率分布:综述概率论基本定义
评论
取消
暂无评论,快去成为第一个评论的人吧

更多资讯推荐

深度学习算法

深度学习算法在机器视觉中就如一个巧妙的接收转换器般的存在,它灵活、敏捷、“深度”与广度兼具,强悍的计算与预测能力可以称为其魅力之处。深度计算——可以集数亿个神经网络的自拟,对于数据、语音、图像等多种形式的资源进行分析、解释。

三姆森科技 ·  19h前
中美欧人工智能发展现状比较分析

从投资、人才、研究、硬件、应用、数据多个维度,系统对比中、美、欧人工智能发展现状,最终得出结论称,美国当前依然保持着世界人工智能发展总体领先地位,中国在一些重要领域与美国的差距缩小,欧盟在三者中相对落后。

王璐菲 ·  20h前
解锁人工智能、机器学习和深度学习

深度学习是机器学习的子集,而机器学习又是人工智能的子集,但是这些名称的起源来自一个有趣的历史。此外,还有一些引人入胜的技术特征,可将深度学习与其他类型的机器学习区分开来……对于技能水平较高的ML、DL或AI的任何人来说,这都是必不可少的工作知识。

佚名 ·  21h前
谈谈基于深度学习的目标检测网络为什么会误检,以及如何优化目标检测的误检问题

在训练人脸检测网络时,一般都会做数据增强,为图像模拟不同姿态、不同光照等复杂情况,这就有可能产生过亮的人脸图像,“过亮”的人脸看起来就像发光的灯泡一样。

刘冲 ·  1天前
报告指出:中国人工智能专利申请数量居全球首位

中国在自然语言处理、芯片技术、机器学习等10多个人工智能子领域的科研产出水平居于世界前列。而在人机交互、知识工程、机器人、计算机图形、计算理论领域,中国还需努力追赶。

Yu ·  3天前
深度学习(Deep learning)入门导读

2016年Google人工智能程序阿尔法围棋(AlphaGo)对战世界围棋选手李世石,最终以4:1的成绩获得胜利,这惊人的一幕将国内外研究和学习人工智能的热题推向了新的高潮。然而,何为深度学习?本文将揭开深度学习的面纱。

洛辰不才 ·  3天前
人工智能时代到来后,有哪些工作难以代替?

我们到底应该如何面对人工智能时代?尤其是哪些工作在这个时代难以代替?这是值得人们认真研究和解决的问题。

江东 ·  4天前
启动机器学习/深度学习项目的八种方法

从探索性的数据分析到自动机器学习(AutoML),组织需要使用这些技术来推动其数据科学项目发展,并建立更好的模型。

李睿 ·  4天前
Copyright©2005-2021 51CTO.COM 版权所有 未经许可 请勿转载