【TensorFlow】将tensorflow训练好的模型移植到Android

原博文:【尊重原创,转载请注明出处】https://blog.csdn.net/guyuealian/article/details/79672257

    本博客将以最简单的方式,利用TensorFlow实现了MNIST手写数字识别,并将Python TensoFlow训练好的模型移植到Android手机上运行。网上也有很多移植教程,大部分是在Ubuntu(Linux)系统,一般先利用Bazel工具把TensoFlow编译成.so库文件和jar包,再进行Android配置,实现模型移植。不会使用Bazel也没关系,实质上TensoFlow已经为开发者提供了最新的.so库文件和对应的jar包了(如libtensorflow_inference.so和libandroid_tensorflow_inference_java.jar),我们只需要下载文件,并在本地Android Studio导入jar包和.so库文件,即可以在Android加载TensoFlow的模型了。 

  当然了,本博客的项目代码都上传到Github:https://github.com/PanJinquan/Mnist-tensorFlow-AndroidDemo

   先说一下,本人的开发环境:

Windows 7
Python3.5
TensoFlow 1.6.0(2018年3月23日—当前最新版)
Android Studio 3.0.1(2018年3月23日—当前最新版)

一、利用Python训练模型

     以MNIST手写数字识别为例,这里首先使用Python版的TensorFlow实现单隐含层的SoftMax Regression分类器,并将训练好的模型的网络拓扑结构和参数保存为pb文件。首先,需要定义模型的输入层和输出层节点的名字(通过形参 'name'指定,名字可以随意,后面加载模型时,都是通过该name来传递数据的):

[python]   view plain  copy
  1. x = tf.placeholder(tf.float32,[None,784],name='x_input')#输入节点:x_input  
  2. .  
  3. .  
  4. .  
  5. pre_num=tf.argmax(y,1,output_type='int32',name="output")#输出节点:output  
PS:说一下鄙人遇到坑:起初,我参照网上相关教程训练了一个模型,在Windows下测试没错,但把模型移植到Android后就出错了,但用别人的模型又正常运行;后来折腾了半天才发现,是类型转换出错啦!!!!
TensorFlow默认类型是float32,但我们希望返回的是一个int型,因此需要指定output_type='int32';但注意了,在Windows下测试使用int64和float64都是可以的,但在Android平台上只能使用int32和float32,并且对应Java的int和float类型。

   将训练好的模型保存为.pb文件,这就需要用到tf.graph_util.convert_variables_to_constants函数了。

[python]   view plain  copy
  1. # 保存训练好的模型  
  2. #形参output_node_names用于指定输出的节点名称,output_node_names=['output']对应pre_num=tf.argmax(y,1,name="output"),  
  3. output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def,output_node_names=['output'])  
  4. with tf.gfile.FastGFile('model/mnist.pb', mode='wb') as f:#’wb’中w代表写文件,b代表将数据以二进制方式写入文件。  
  5.     f.write(output_graph_def.SerializeToString())  

 关于tensorflow保存模型和加载模型的方法,请参考本人另一篇博客:https://blog.csdn.net/guyuealian/article/details/79693741

这里给出Python训练模型完整的代码如下:

[python]   view plain  copy
  1. #coding=utf-8  
  2. # 单隐层SoftMax Regression分类器:训练和保存模型模块  
  3. from tensorflow.examples.tutorials.mnist import input_data  
  4. import tensorflow as tf  
  5. from tensorflow.python.framework import graph_util  
  6. print('tensortflow:{0}'.format(tf.__version__))  
  7.   
  8. mnist = input_data.read_data_sets("Mnist_data/", one_hot=True)  
  9.   
  10. #create model  
  11. with tf.name_scope('input'):  
  12.     x = tf.placeholder(tf.float32,[None,784],name='x_input')#输入节点名:x_input  
  13.     y_ = tf.placeholder(tf.float32,[None,10],name='y_input')  
  14. with tf.name_scope('layer'):  
  15.     with tf.name_scope('W'):  
  16.         #tf.zeros([3, 4], tf.int32) ==> [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]  
  17.         W = tf.Variable(tf.zeros([784,10]),name='Weights')  
  18.     with tf.name_scope('b'):  
  19.         b = tf.Variable(tf.zeros([10]),name='biases')  
  20.     with tf.name_scope('W_p_b'):  
  21.         Wx_plus_b = tf.add(tf.matmul(x, W), b, name='Wx_plus_b')  
  22.   
  23.     y = tf.nn.softmax(Wx_plus_b, name='final_result')  
  24.   
  25. # 定义损失函数和优化方法  
  26. with tf.name_scope('loss'):  
  27.     loss = -tf.reduce_sum(y_ * tf.log(y))  
  28. with tf.name_scope('train_step'):  
  29.     train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)  
  30.     print(train_step)  
  31. # 初始化  
  32. sess = tf.InteractiveSession()  
  33. init = tf.global_variables_initializer()  
  34. sess.run(init)  
  35. # 训练  
  36. for step in range(100):  
  37.     batch_xs,batch_ys =mnist.train.next_batch(100)  
  38.     train_step.run({x:batch_xs,y_:batch_ys})  
  39.     # variables = tf.all_variables()  
  40.     # print(len(variables))  
  41.     # print(sess.run(b))  
  42.   
  43. # 测试模型准确率  
  44. pre_num=tf.argmax(y,1,output_type='int32',name="output")#输出节点名:output  
  45. correct_prediction = tf.equal(pre_num,tf.argmax(y_,1,output_type='int32'))  
  46. accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))  
  47. a = accuracy.eval({x:mnist.test.images,y_:mnist.test.labels})  
  48. print('测试正确率:{0}'.format(a))  
  49.   
  50. # 保存训练好的模型  
  51. #形参output_node_names用于指定输出的节点名称,output_node_names=['output']对应pre_num=tf.argmax(y,1,name="output"),  
  52. output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def,output_node_names=['output'])  
  53. with tf.gfile.FastGFile('model/mnist.pb', mode='wb') as f:#’wb’中w代表写文件,b代表将数据以二进制方式写入文件。  
  54.     f.write(output_graph_def.SerializeToString())  
  55. sess.close()  


上面的代码已经将训练模型保存在model/mnist.pb,当然我们可以先在Python中使用该模型进行简单的预测,测试方法如下:

[python]   view plain  copy
  1. import tensorflow as tf  
  2. import numpy as np  
  3. from PIL import Image  
  4. import matplotlib.pyplot as plt  
  5.   
  6. #模型路径  
  7. model_path = 'model/mnist.pb'  
  8. #测试图片  
  9. testImage = Image.open("data/test_image.jpg");  
  10.   
  11. with tf.Graph().as_default():  
  12.     output_graph_def = tf.GraphDef()  
  13.     with open(model_path, "rb") as f:  
  14.         output_graph_def.ParseFromString(f.read())  
  15.         tf.import_graph_def(output_graph_def, name="")  
  16.   
  17.     with tf.Session() as sess:  
  18.         tf.global_variables_initializer().run()  
  19.         # x_test = x_test.reshape(1, 28 * 28)  
  20.         input_x = sess.graph.get_tensor_by_name("input/x_input:0")  
  21.         output = sess.graph.get_tensor_by_name("output:0")  
  22.   
  23.         #对图片进行测试  
  24.         testImage=testImage.convert('L')  
  25.         testImage = testImage.resize((2828))  
  26.         test_input=np.array(testImage)  
  27.         test_input = test_input.reshape(128 * 28)  
  28.         pre_num = sess.run(output, feed_dict={input_x: test_input})#利用训练好的模型预测结果  
  29.         print('模型预测结果为:',pre_num)  
  30.         #显示测试的图片  
  31.         # testImage = test_x.reshape(28, 28)  
  32.         fig = plt.figure(), plt.imshow(testImage,cmap='binary')  # 显示图片  
  33.         plt.title("prediction result:"+str(pre_num))  
  34.         plt.show()  


二、移植到Android

    相信大家看到很多大神的博客,都是要自己编译TensoFlow的so库和jar包,说实在的,这个过程真TM麻烦,反正我弄了半天都没成功过,然后放弃了……。本博客的移植方法不需要安装Bazel,也不需要构建TensoFlow的so库和jar包,因为Google在TensoFlow github中给我们提供了,为什么不用了!!!

1、下载TensoFlow的jar包和so库

    TensoFlow在Github已经存放了很多开发文件:https://github.com/PanJinquan/tensorflow


   我们需要做的是,下载Android: native libs ,打包下载全部文件,其中有我们需要的libtensorflow_inference.so和libandroid_tensorflow_inference_java.jar,有了这两个文件,剩下的就是在Android Studio配置的问题了


2、Android Studio配置

(1)新建一个Android项目

(2)把训练好的pb文件(mnist.pb)放入Android项目中app/src/main/assets下,若不存在assets目录,右键main->new->Directory,输入assets。

(3)将下载的libtensorflow_inference.so和libandroid_tensorflow_inference_java.jar如下结构放在libs文件夹下


(4)app\build.gradle配置

    在defaultConfig中添加

[css]   view plain  copy
  1. multiDexEnabled true  
  2.      ndk {  
  3.          abiFilters "armeabi-v7a"  
  4.      }  

    增加sourceSets

[css]   view plain  copy
  1. sourceSets {  
  2.     main {  
  3.         jniLibs.srcDirs = ['libs']  
  4.     }  
  5. }  

    在dependencies中增加TensoFlow编译的jar文件libandroid_tensorflow_inference_java.jar:

[css]   view plain  copy
  1. compile files('libs/libandroid_tensorflow_inference_java.jar')  

   OK了,build.gradle配置完成了,剩下的就是java编程的问题了。

3、模型调用

  在需要调用TensoFlow的地方,加载so库“System.loadLibrary("tensorflow_inference");并”import org.tensorflow.contrib.android.TensorFlowInferenceInterface;就可以使用了

     注意,旧版的TensoFlow,是如下方式进行,该方法可参考大神的博客:https://www.jianshu.com/p/1168384edc1e

[java]   view plain  copy
  1. TensorFlowInferenceInterface.fillNodeFloat(); //送入输入数据  
  2. TensorFlowInferenceInterface.runInference();  //进行模型的推理  
  3. TensorFlowInferenceInterface.readNodeFloat(); //获取输出数据  

     但在最新的libandroid_tensorflow_inference_java.jar中,已经没有这些方法了,换为

[java]   view plain  copy
  1. TensorFlowInferenceInterface.feed()  
  2. TensorFlowInferenceInterface.run()  
  3. TensorFlowInferenceInterface.fetch()  

     下面是以MNIST手写数字识别为例,其实现方法如下:

[java]   view plain  copy
  1. package com.example.jinquan.pan.mnist_ensorflow_androiddemo;  
  2.   
  3. import android.content.res.AssetManager;  
  4. import android.graphics.Bitmap;  
  5. import android.graphics.Color;  
  6. import android.graphics.Matrix;  
  7. import android.util.Log;  
  8.   
  9. import org.tensorflow.contrib.android.TensorFlowInferenceInterface;  
  10.   
  11.   
  12. public class PredictionTF {  
  13.     private static final String TAG = "PredictionTF";  
  14.     //设置模型输入/输出节点的数据维度  
  15.     private static final int IN_COL = 1;  
  16.     private static final int IN_ROW = 28*28;  
  17.     private static final int OUT_COL = 1;  
  18.     private static final int OUT_ROW = 1;  
  19.     //模型中输入变量的名称  
  20.     private static final String inputName = "input/x_input";  
  21.     //模型中输出变量的名称  
  22.     private static final String outputName = "output";  
  23.   
  24.     TensorFlowInferenceInterface inferenceInterface;  
  25.     static {  
  26.         //加载libtensorflow_inference.so库文件  
  27.         System.loadLibrary("tensorflow_inference");  
  28.         Log.e(TAG,"libtensorflow_inference.so库加载成功");  
  29.     }  
  30.   
  31.     PredictionTF(AssetManager assetManager, String modePath) {  
  32.         //初始化TensorFlowInferenceInterface对象  
  33.         inferenceInterface = new TensorFlowInferenceInterface(assetManager,modePath);  
  34.         Log.e(TAG,"TensoFlow模型文件加载成功");  
  35.     }  
  36.   
  37.     /** 
  38.      *  利用训练好的TensoFlow模型预测结果 
  39.      * @param bitmap 输入被测试的bitmap图 
  40.      * @return 返回预测结果,int数组 
  41.      */  
  42.     public int[] getPredict(Bitmap bitmap) {  
  43.         float[] inputdata = bitmapToFloatArray(bitmap,2828);//需要将图片缩放带28*28  
  44.         //将数据feed给tensorflow的输入节点  
  45.         inferenceInterface.feed(inputName, inputdata, IN_COL, IN_ROW);  
  46.         //运行tensorflow  
  47.         String[] outputNames = new String[] {outputName};  
  48.         inferenceInterface.run(outputNames);  
  49.         ///获取输出节点的输出信息  
  50.         int[] outputs = new int[OUT_COL*OUT_ROW]; //用于存储模型的输出数据  
  51.         inferenceInterface.fetch(outputName, outputs);  
  52.         return outputs;  
  53.     }  
  54.   
  55.     /** 
  56.      * 将bitmap转为(按行优先)一个float数组,并且每个像素点都归一化到0~1之间。 
  57.      * @param bitmap 输入被测试的bitmap图片 
  58.      * @param rx 将图片缩放到指定的大小(列)->28 
  59.      * @param ry 将图片缩放到指定的大小(行)->28 
  60.      * @return   返回归一化后的一维float数组 ->28*28 
  61.      */  
  62.     public static float[] bitmapToFloatArray(Bitmap bitmap, int rx, int ry){  
  63.         int height = bitmap.getHeight();  
  64.         int width = bitmap.getWidth();  
  65.         // 计算缩放比例  
  66.         float scaleWidth = ((float) rx) / width;  
  67.         float scaleHeight = ((float) ry) / height;  
  68.         Matrix matrix = new Matrix();  
  69.         matrix.postScale(scaleWidth, scaleHeight);  
  70.         bitmap = Bitmap.createBitmap(bitmap, 00, width, height, matrix, true);  
  71.         Log.i(TAG,"bitmap width:"+bitmap.getWidth()+",height:"+bitmap.getHeight());  
  72.         Log.i(TAG,"bitmap.getConfig():"+bitmap.getConfig());  
  73.         height = bitmap.getHeight();  
  74.         width = bitmap.getWidth();  
  75.         float[] result = new float[height*width];  
  76.         int k = 0;  
  77.         //行优先  
  78.         for(int j = 0;j < height;j++){  
  79.             for (int i = 0;i < width;i++){  
  80.                 int argb = bitmap.getPixel(i,j);  
  81.                 int r = Color.red(argb);  
  82.                 int g = Color.green(argb);  
  83.                 int b = Color.blue(argb);  
  84.                 int a = Color.alpha(argb);  
  85.                 //由于是灰度图,所以r,g,b分量是相等的。  
  86.                 assert(r==g && g==b);  
  87. //                Log.i(TAG,i+","+j+" : argb = "+argb+", a="+a+", r="+r+", g="+g+", b="+b);  
  88.                 result[k++] = r / 255.0f;  
  89.             }  
  90.         }  
  91.         return result;  
  92.     }  
  93. }  

  简单说明一下:项目新建了一个PredictionTF类,该类会先加载libtensorflow_inference.so库文件;PredictionTF(AssetManager assetManager, String modePath) 构造方法需要传入AssetManager对象和pb文件的路径;

    从资源文件中获取BitMap图片,并传入 getPredict(Bitmap bitmap)方法,该方法首先将BitMap图像缩放到28*28的大小,由于原图是灰度图,我们需要获取灰度图的像素值,并将28*28的像素转存为行向量的一个float数组,并且每个像素点都归一化到0~1之间,这个就是bitmapToFloatArray(Bitmap bitmap, int rx, int ry)方法的作用;

    然后将数据feed给tensorflow的输入节点,并运行(run)tensorflow,最后获取(fetch)输出节点的输出信息。

   MainActivity很简单,一个单击事件获取预测结果:

[java]   view plain  copy
  1. package com.example.jinquan.pan.mnist_ensorflow_androiddemo;  
  2.   
  3. import android.graphics.Bitmap;  
  4. import android.graphics.BitmapFactory;  
  5. import android.support.v7.app.AppCompatActivity;  
  6. import android.os.Bundle;  
  7. import android.util.Log;  
  8. import android.view.View;  
  9. import android.widget.ImageView;  
  10. import android.widget.TextView;  
  11.   
  12. public class MainActivity extends AppCompatActivity {  
  13.   
  14.     // Used to load the 'native-lib' library on application startup.  
  15.     static {  
  16.         System.loadLibrary("native-lib");//可以去掉  
  17.     }  
  18.   
  19.     private static final String TAG = "MainActivity";  
  20.     private static final String MODEL_FILE = "file:///android_asset/mnist.pb"; //模型存放路径  
  21.     TextView txt;  
  22.     TextView tv;  
  23.     ImageView imageView;  
  24.     Bitmap bitmap;  
  25.     PredictionTF preTF;  
  26.     @Override  
  27.     protected void onCreate(Bundle savedInstanceState) {  
  28.         super.onCreate(savedInstanceState);  
  29.         setContentView(R.layout.activity_main);  
  30.   
  31.         // Example of a call to a native method  
  32.         tv = (TextView) findViewById(R.id.sample_text);  
  33.         txt=(TextView)findViewById(R.id.txt_id);  
  34.         imageView =(ImageView)findViewById(R.id.imageView1);  
  35.         bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.test_image);  
  36.         imageView.setImageBitmap(bitmap);  
  37.         preTF =new PredictionTF(getAssets(),MODEL_FILE);//输入模型存放路径,并加载TensoFlow模型  
  38.     }  
  39.   
  40.     public void click01(View v){  
  41.         String res="预测结果为:";  
  42.         int[] result= preTF.getPredict(bitmap);  
  43.         for (int i=0;i<result.length;i++){  
  44.             Log.i(TAG, res+result[i] );  
  45.             res=res+String.valueOf(result[i])+" ";  
  46.         }  
  47.         txt.setText(res);  
  48.         tv.setText(stringFromJNI());  
  49.     }  
  50.     /** 
  51.      * A native method that is implemented by the 'native-lib' native library, 
  52.      * which is packaged with this application. 
  53.      */  
  54.     public native String stringFromJNI();//可以去掉  
  55. }  

   activity_main布局文件:

[html]   view plain  copy
  1. <?xml version="1.0" encoding="utf-8"?>  
  2. <LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"  
  3.     android:layout_width="match_parent"  
  4.     android:layout_height="match_parent"  
  5.     android:orientation="vertical"  
  6.     android:paddingBottom="16dp"  
  7.     android:paddingLeft="16dp"  
  8.     android:paddingRight="16dp"  
  9.     android:paddingTop="16dp">  
  10.     <TextView  
  11.         android:id="@+id/sample_text"  
  12.         android:layout_width="wrap_content"  
  13.         android:layout_height="wrap_content"  
  14.         android:text="https://blog.csdn.net/guyuealian"  
  15.         android:layout_gravity="center"/>  
  16.     <Button  
  17.         android:onClick="click01"  
  18.         android:layout_width="match_parent"  
  19.         android:layout_height="wrap_content"  
  20.         android:text="click" />  
  21.     <TextView  
  22.         android:id="@+id/txt_id"  
  23.         android:layout_width="match_parent"  
  24.         android:layout_height="wrap_content"  
  25.         android:gravity="center"  
  26.         android:text="结果为:"/>  
  27.     <ImageView  
  28.         android:id="@+id/imageView1"  
  29.         android:layout_width="wrap_content"  
  30.         android:layout_height="wrap_content"  
  31.         android:layout_gravity="center"/>  
  32. </LinearLayout>  

参考资料:https://blog.csdn.net/gzhermit/article/details/73924515


如果你觉得该帖子帮到你,还望贵人多多支持,鄙人会再接再厉,继续努力的~

猜你喜欢

转载自blog.csdn.net/u011511601/article/details/80426375