上文介绍了将pytorch模型转为onnx模型的方法,本文将介绍得到.onnx文件后通过onnx_tf将其转为tensorflow文件的方法。
注意事项
onnx_tf中明确说明了onnx只支持tensorflow>=1.15.0,但是由于安卓端的tensorflow更新的比较慢,一般只支持版本较低的tensorflow,如demo中只支持tensorflow 1.12.0(或1.13.0),因此由onnx转换成tensorflow的过程中要转成tensorflow 1.12.0可使用的pb文件。
- 若使用tensorflow 1.15.0环境去输出由tensorflow 1.12.0环境下生成的pb文件是可以正常输出的。
- 若使用tensorflow 1.12.0环境去输出由tensorflow 1.15.0环境下生成的pb文件,则会产生以下报错:
- 若直接使用官方onnx_tf和tensorflow 1.12.0生成pb文件则会产生以下两个报错:
因此需要将onnx_tf的代码进行一些修改。
如果安装onnx_tf时使用pip install onnx_tf
命令,则将anaconda3\Lib\site-packages\onnx_tf\handlers\backend\is_inf.py
文件下的 @tf_func(tf.math.is_inf)
改为 @tf_func(tf.is_inf)
,将anaconda3\Lib\site-packages\onnx_tf\handlers\backend\scatter_nd.py
文件下的@tf_func(tf.tensor_scatter_nd_update)
改为@tf_func(tf.scatter_nd_update)
。
如果去官网下载安装onnx_tf,则将onnx-tensorflow\onnx_tf\handlers\backend\is_inf.py
文件下的@tf_func(tf.math.is_inf)
改为 @tf_func(tf.is_inf)
,将onnx-tensorflow\onnx_tf\handlers\backend\scatter_nd.py
文件下的@tf_func(tf.tensor_scatter_nd_update)
改为@tf_func(tf.scatter_nd_update)
。
代码
onnx2pb
# venv-tf1.12
import onnx
import numpy as np
from onnx_tf.backend import prepare
# 给定输入图片或者随机输入,尺寸要跟.onnx模型生成时dummy_input一样
img = np.load('random.npy')
# img = img.reshape([1, 3, 300, 300])
# 导入onnx到tensorflow中,并获得输出
model = onnx.load('model_005000.onnx')
# 这里必须strict=False,不然生成的pb文件输出会报错
tf_rep = prepare(model, strict=False)
onnx_output = tf_rep.run(img)
print("onnx-tensorflow output: \n",onnx_output)
# 将onnx-tensorflow模型导出成pb格式
name = "model_005000.pb"
tf_rep.export_graph(name)
save_npy
输入图片random.npy
可以通过运行以下代码得到:
import numpy as np
img = np.random.rand(1, 3, 300, 300)
np.save(r'E:\Pythonworkspace\pth2pb\random', img) #保存的路径,random表示文件名
test_pb
验证pb文件:
import tensorflow as tf
import numpy as np
# from PIL import Image
model_path = 'model_005000.pb'
with tf.Graph().as_default():
output_graph_def = tf.GraphDef()
with open(model_path, "rb") as f:
output_graph_def.ParseFromString(f.read())
tf.import_graph_def(output_graph_def, name="")
with tf.Session() as sess:
tf.global_variables_initializer().run()
inp = sess.graph.get_tensor_by_name('actual_input_1:0') #以下节点通过运行print_tensorname得到
out0 = sess.graph.get_tensor_by_name('output1:0')
out1 = sess.graph.get_tensor_by_name('73:0')
out2 = sess.graph.get_tensor_by_name('74:0')
img = np.load('random.npy')
# img = img.reshape([1, 3, 300, 300])
pre_num = sess.run([out0, out1, out2], feed_dict={inp: img})
print(pre_num)
print_tensorname
只有能打印出来的节点才能作为输入输出!注意tensor有两种,一种是保存固定值的Const节点,是各个层训练完得到的固定权重偏置等值;一种是会因输入不同而得到不同数值的变量节点,也是我们所需要的tensor。这里输出的tensor名字只是一半,一般来说后一半都是0或者1,如“481:0”、“Add_51:0”等。
打印pb图中的节点:
import tensorflow as tf
model_name = 'model_005000.pb'
with tf.gfile.GFile(model_name , 'rb') as f:
# 使用tf.GraphDef()定义一个空的Graph
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# 把当前流图读入graph_def中
tf.import_graph_def(graph_def, name='')
# 打印所有tensor名称
tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
for tensor_name in tensor_name_list:
print(tensor_name,'\n')
结果如下图所示:
有问题欢迎在评论区留言,本人水平有限,有错误希望大家指正。