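ONNX (Open Neural Network Exchange) is an open model-exchange format shared across frameworks. It serializes models in protobuf's binary format, which keeps files compact and fast to transfer. For protobuf serialization in general, see the article "Netty整合Protobuffer" (Integrating Protobuf with Netty). Official GitHub: onnx/onnx at commit f2daca5e9b9315a2034da61c662d2a7ac28a9488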
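ONNX treats every layer of a network, or more precisely every operator, as a node (Node). These nodes are used to build a Graph, which corresponds to the network. Finally, the Graph is combined with the model's other metadata to produce a model, which is the final ONNX model. An example follows.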
Creating an ONNX model
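There are two ways to create an ONNX model. The first is to convert a model from another framework, such as PyTorch or PaddlePaddle. For PyTorch to ONNX, see the section "Pytorch 权重 pth 转换 onnx" (converting PyTorch .pth weights to ONNX) in the article 模型部署篇 (model deployment). For PaddlePaddle to ONNX, see the Paddle2ONNX section of PaddleOCR使用指南 (PaddleOCR user guide).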
First, let's generate an ONNX file:
import torch
import torch.nn as nn


class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.conv = nn.Conv2d(1, 1, 1)  # 1 input channel, 1 output channel, 1x1 kernel
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.conv(x))


if __name__ == '__main__':
    net = Network()
    # torch.autograd.Variable is deprecated; a plain tensor works as the example input
    dummy_input = torch.randn([1, 1, 1, 1])
    # trace the network and write it to net.onnx using opset 10
    torch.onnx.export(net, dummy_input, 'net.onnx', opset_version=10)
Next, print the structure of this ONNX file:
import onnx

if __name__ == '__main__':
    # load the file exported above and print its protobuf text representation
    print(onnx.load("./net.onnx"))
Output:
ir_version: 5
producer_name: "pytorch"
producer_version: "1.12.1"
graph {
node {
input: "input.1"
input: "conv.weight"
input: "conv.bias"
output: "input"
name: "Conv_0"
op_type: "Conv"
attribute {
name: "dilations"
ints: 1
ints: 1
type: INTS
}
attribute {
name: "group"
i: 1
type: INT
}
attribute {
name: "kernel_shape"
ints: 1
ints: 1
type: INTS
}
attribute {
name: "pads"
ints: 0
ints: 0
ints: 0
ints: 0
type: INTS
}
attribute {
name: "strides"
ints: 1
ints: 1
type: INTS
}
}
node {
input: "input"
output: "4"
name: "Relu_1"
op_type: "Relu"
}
name: "torch_jit"
initializer {
dims: 1
dims: 1
dims: 1
dims: 1
data_type: 1
name: "conv.weight"
raw_data: "\014\317B?"
}
initializer {
dims: 1
data_type: 1
name: "conv.bias"
raw_data: "\344\n\026\277"
}
input {
name: "input.1"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
}
}
}
}
output {
name: "4"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
}
}
}
}
}
opset_import {
version: 10
}
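The first fields are model metadata. ir_version: 5 is the version of the ONNX IR (intermediate representation) format the file uses. producer_name: "pytorch" says the model was converted from PyTorch, and producer_version: "1.12.1" is that PyTorch version.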
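Next comes graph -> node. The first node is a 2D convolution and the second is a ReLU activation.

- op_type is the node (operator) type. The full list of operators is at https://github.com/onnx/onnx/blob/f2daca5e9b9315a2034da61c662d2a7ac28a9488/docs/Operators.md.
- name is the node's name, which is distinct from op_type.
- attribute holds the node's attributes. For Conv_0 these are the 2D convolution settings: for example, "group" is the number of groups for grouped convolution and "kernel_shape" is the kernel size.
- initializer holds the constant tensors the graph starts with, here the convolution weight and bias.
- input describes the graph inputs, including their shapes, and output describes the graph outputs and their shapes.
- opset_import lists the operator-set domains and versions the model depends on.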
Finally, let's validate the model with the ONNX checker. It passes without errors.
import onnx

if __name__ == '__main__':
    model = onnx.load("./net.onnx")
    # raises onnx.checker.ValidationError if the model is invalid; returns None otherwise
    onnx.checker.check_model(model)
The other way is to build the ONNX model directly with ONNX's own API (onnx.helper):
import onnx
import onnx.helper as helper
import numpy as np

if __name__ == '__main__':
    # graph input and output: float32 tensors of shape [1, 3, 244, 244]
    input = helper.make_tensor_value_info(name='input', elem_type=onnx.TensorProto.FLOAT,
                                          shape=[1, 3, 244, 244])
    output = helper.make_tensor_value_info(name='output', elem_type=onnx.TensorProto.FLOAT,
                                           shape=[1, 3, 244, 244])
    # Conv weight [out_channels, in_channels, kH, kW] and bias [out_channels]
    weight = helper.make_tensor(name='weight', data_type=onnx.TensorProto.FLOAT, dims=[3, 3, 1, 1],
                                vals=np.random.randn(3, 3, 1, 1))
    bias = helper.make_tensor(name='bias', data_type=onnx.TensorProto.FLOAT, dims=[3],
                              vals=np.random.randn(3))
    # a 1x1 convolution node; the extra keyword arguments become node attributes
    node = helper.make_node(op_type='Conv', inputs=['input', 'weight', 'bias'], outputs=['output'],
                            kernel_shape=[1, 1], strides=[1, 1], group=1, pads=[0, 0, 0, 0])
    # the graph ties together the nodes, inputs, outputs and initializers
    graph = helper.make_graph(name='graph', nodes=[node], inputs=[input], outputs=[output],
                              initializer=[weight, bias])
    model = helper.make_model(graph)
    onnx.checker.check_model(model)
    print(model)
    onnx.save_model(model, 'model.onnx')
Output:
ir_version: 8
graph {
node {
input: "input"
input: "weight"
input: "bias"
output: "output"
op_type: "Conv"
attribute {
name: "group"
i: 1
type: INT
}
attribute {
name: "kernel_shape"
ints: 1
ints: 1
type: INTS
}
attribute {
name: "pads"
ints: 0
ints: 0
ints: 0
ints: 0
type: INTS
}
attribute {
name: "strides"
ints: 1
ints: 1
type: INTS
}
}
name: "graph"
initializer {
dims: 3
dims: 3
dims: 1
dims: 1
data_type: 1
float_data: 0.45837152004241943
float_data: 0.10209446400403976
float_data: 1.0382566452026367
float_data: -0.09292714297771454
float_data: 1.58871591091156
float_data: 0.3746287226676941
float_data: -0.35588690638542175
float_data: 0.7165427207946777
float_data: 0.10244251787662506
name: "weight"
}
initializer {
dims: 3
data_type: 1
float_data: -0.36782845854759216
float_data: 2.305680513381958
float_data: -0.13051341474056244
name: "bias"
}
input {
name: "input"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 244
}
dim {
dim_value: 244
}
}
}
}
}
output {
name: "output"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 244
}
dim {
dim_value: 244
}
}
}
}
}
}
opset_import {
version: 17
}
Setting a dynamic batch_size
In the output above, every dimension of input is fixed: [1, 3, 244, 244]. We now want the first (batch) dimension to accept any value. First, let's run the model as it is:
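import onnxruntime
import numpy as np

if __name__ == '__main__':
    # GPU builds of onnxruntime require the providers list to be given explicitly
    sess = onnxruntime.InferenceSession('./model.onnx', providers=['CPUExecutionProvider'])
    input = np.random.randn(1, 3, 244, 244).astype(np.float32)
    # run(output_names, {input_name: array}) returns one array per requested output
    print(sess.run(['output'], {'input': input}))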
Output:
[array([[[[-7.4062514e-01, 2.5951520e-01, -3.5876265e-01, ...,
-2.0852795e+00, -1.0078001e-01, -4.9386564e-01],
[-6.0379845e-01, 9.2830718e-01, -4.2096943e-02, ...,
-1.9139317e-01, 1.6547061e+00, 1.4468774e+00],
[ 2.6494553e+00, -9.6209788e-01, 8.2099646e-02, ...,
-1.5899204e+00, -1.3295431e+00, 1.1512205e-01],
...,
[ 1.4135087e+00, 6.4077592e-01, -5.6514746e-01, ...,
2.1367333e+00, 2.6012421e+00, -1.3565271e+00],
[ 6.9879985e-01, 1.2454928e+00, 6.0045028e-01, ...,
-6.1302024e-01, -4.3026954e-02, -7.2975445e-01],
[-2.1020520e+00, -1.2499222e+00, -9.3896770e-01, ...,
-4.6129468e-01, 5.4580927e-01, -7.4599540e-01]],
[[ 5.6230574e+00, 2.6218858e+00, 7.1071947e-01, ...,
3.6510468e-02, 2.5771899e+00, 2.0060635e+00],
[ 4.2759910e+00, 2.5261867e+00, 1.0787441e+00, ...,
3.3373690e+00, 4.5090003e+00, 3.5535808e+00],
[ 1.6522924e+00, 1.5206050e+00, 3.6905313e+00, ...,
1.5963824e+00, 5.1875353e-02, 3.4248161e+00],
...,
[ 1.0295208e+00, 4.5397396e+00, 4.3366423e+00, ...,
1.2408195e+00, 3.1239326e+00, 1.7476916e+00],
[ 9.7080982e-01, 1.9692242e+00, 3.7690439e+00, ...,
-1.6770840e-01, 1.1871569e+00, 4.2690439e+00],
[ 4.4730301e+00, 1.5573008e+00, 7.2707558e+00, ...,
4.7898588e+00, 2.9080591e+00, 7.2294927e-01]],
[[ 1.3509388e+00, -1.9160898e-01, -1.3318433e+00, ...,
-1.0562456e+00, 1.0652192e-01, -4.4993240e-01],
[ 7.3106253e-01, -4.0714890e-03, -5.3625894e-01, ...,
-6.2385768e-02, 3.3464909e-01, 2.7667671e-01],
[-7.8517151e-01, -7.1918708e-01, 5.5366117e-01, ...,
-4.7982591e-01, -1.0322813e+00, 8.0901492e-01],
...,
[-1.0904443e+00, 4.7577775e-01, 9.5288980e-01, ...,
-9.8435390e-01, -5.1632053e-01, 2.4581529e-01],
[-6.4627886e-01, -9.8449951e-01, 1.6146483e-01, ...,
-1.2009792e+00, -7.3006052e-01, 7.0891309e-01],
[ 1.3855783e+00, -8.9338100e-01, 2.4704218e+00, ...,
6.8950468e-01, 1.7709453e-01, -7.6678610e-01]]]],
dtype=float32)]
Now let's change the input batch_size to 2:
import onnx
import onnxruntime
import numpy as np

if __name__ == '__main__':
    model = onnx.load("./model.onnx")
    # set the first (batch) dimension of every graph input and output to 2
    for i in model.graph.input:
        i.type.tensor_type.shape.dim[0].dim_value = 2
    for o in model.graph.output:
        o.type.tensor_type.shape.dim[0].dim_value = 2
    onnx.checker.check_model(model)
    onnx.save_model(model, 'dynamic_model.onnx')

    sess = onnxruntime.InferenceSession('./dynamic_model.onnx', providers=['CPUExecutionProvider'])
    input = np.random.randn(2, 3, 244, 244).astype(np.float32)
    print(sess.run(['output'], {'input': input}))
Output:
[array([[[[-2.10871696e-02, -1.32871771e+00, -1.22335061e-01, ...,
4.77721721e-01, -4.10815179e-01, -1.37511027e+00],
[-1.09181249e+00, -2.02204657e+00, 1.54176390e+00, ...,
-1.88722742e+00, -2.00726366e+00, 4.24929589e-01],
[-7.14685619e-01, 3.82802397e-01, -2.30412316e+00, ...,
7.06834435e-01, -2.36892438e+00, -2.11947155e+00],
...,
[-9.51929450e-01, -1.22408187e+00, -1.35213524e-01, ...,
5.55669367e-02, -5.95110297e-01, -2.15206313e+00],
[ 8.90325904e-01, -1.89442956e+00, 8.34725618e-01, ...,
-2.34860206e+00, -1.09965193e+00, -4.96994108e-01],
[ 1.56639183e+00, 5.97145438e-01, -5.28750658e-01, ...,
5.77995658e-01, -1.46205699e+00, 2.80693078e+00]],
[[ 3.09728765e+00, -1.42589498e+00, 7.58970022e-01, ...,
3.48910093e+00, 2.95971513e+00, 1.96736765e+00],
[ 2.76622701e+00, 1.58350587e+00, 2.41761374e+00, ...,
3.68322372e-01, 3.05963039e-01, 2.99718475e+00],
[-1.75151324e+00, 2.79870439e+00, -3.03543806e-01, ...,
2.86027908e+00, 1.78771615e+00, 4.79569674e+00],
...,
[ 1.30739605e+00, 1.83714139e+00, 4.55001736e+00, ...,
1.44066858e+00, 4.87037659e+00, 2.10291076e+00],
[ 9.44083452e-01, -8.11131001e-02, 2.89160919e+00, ...,
2.34788847e+00, 1.95467031e+00, 3.87145948e+00],
[ 2.71238947e+00, 1.46723819e+00, 7.61192560e-01, ...,
2.69581342e+00, 2.11386037e+00, 4.08577728e+00]],
[[ 3.52043629e-01, -1.83945060e+00, -9.97831583e-01, ...,
-2.60245442e-01, 3.69277894e-01, 1.17505208e-01],
[ 2.62015522e-01, -6.50106370e-01, -7.36498535e-01, ...,
-3.72626394e-01, -9.92001474e-01, 1.87904552e-01],
[-2.00427341e+00, -2.67415404e-01, -1.00334084e+00, ...,
8.22970718e-02, 1.41485706e-01, 1.49001801e+00],
...,
[-9.85595703e-01, -9.74879414e-03, 1.27501774e+00, ...,
-7.10564435e-01, 1.17551017e+00, -4.15902734e-01],
[-9.80473995e-01, -1.07735765e+00, 2.39617974e-02, ...,
-1.93872005e-02, 2.48361230e-02, 7.19040394e-01],
[-6.61614537e-02, -4.85614896e-01, -7.31452227e-01, ...,
-9.65917259e-02, -2.94267178e-01, 1.87805906e-01]]],
[[[-1.05941308e+00, 8.10959578e-01, -9.29054856e-01, ...,
-1.33419132e+00, -5.62950134e-01, 3.15277368e-01],
[-2.45844007e+00, -5.31174302e-01, 8.06264520e-01, ...,
-1.37343729e+00, -1.26287377e+00, -1.79255664e+00],
[ 5.01155496e-01, 2.53203034e+00, -9.11398768e-01, ...,
-2.61194611e+00, -6.27550602e-01, -1.04612875e+00],
...,
[ 5.64767838e-01, 1.82380235e+00, -9.87865806e-01, ...,
-1.48546624e+00, 5.00284791e-01, -1.14099467e+00],
[-1.48488015e-01, -3.75306606e-03, 2.05217457e+00, ...,
-4.82964367e-01, 6.37757182e-01, 5.87742925e-01],
[-7.62285709e-01, 5.78535438e-01, -9.07517672e-01, ...,
-1.40203249e+00, 3.13063234e-01, 9.46564317e-01]],
[[ 2.21778965e+00, 1.17825162e+00, 1.17773283e+00, ...,
4.21785736e+00, 1.93207061e+00, 6.90674305e+00],
[ 5.16840172e+00, 4.03573513e-02, 3.72957373e+00, ...,
2.57324958e+00, 3.23857665e-01, 8.98278236e-01],
[ 1.18916261e+00, 4.03137350e+00, 1.54717636e+00, ...,
5.73142242e+00, 2.54209590e+00, 3.02691102e+00],
...,
[ 2.02949071e+00, 4.00444984e+00, 3.55739307e+00, ...,
5.54533482e-01, 3.57894540e+00, 7.03547835e-01],
[ 2.57975435e+00, 2.32062602e+00, 4.18669128e+00, ...,
2.15663671e+00, 2.39567637e+00, 7.93485880e-01],
[ 3.32399893e+00, 3.12817383e+00, 3.60134292e+00, ...,
1.70791423e+00, 7.71586776e-01, 3.58140349e+00]],
[[-6.90246701e-01, -8.55753422e-01, -1.35433823e-01, ...,
9.99482393e-01, -2.96287388e-01, 2.49611807e+00],
[ 1.56937921e+00, -9.95752215e-01, 5.38442284e-02, ...,
1.63274094e-01, -9.27845955e-01, -6.64922059e-01],
[-5.40241778e-01, 2.26666585e-01, -2.95405626e-01, ...,
1.90356636e+00, 4.94795978e-01, 1.35599896e-01],
...,
[-4.09579694e-01, 1.26961544e-01, 5.97525239e-01, ...,
-9.00853217e-01, 8.11160445e-01, -8.88532698e-01],
[-5.75763881e-01, -1.15364529e-01, 2.42510274e-01, ...,
1.83168098e-01, -3.83193374e-01, -1.10992551e+00],
[ 9.75027233e-02, 1.07848495e-02, 2.93477297e-01, ...,
-2.67393768e-01, -8.09366763e-01, 6.60410523e-03]]]],
dtype=float32)]
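However, batch_size is still a fixed value, now 2. Feeding an input whose first dimension is anything other than 2 raises an error. To accept an arbitrary batch_size, the dimension must be made symbolic instead: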
import onnx
import onnxruntime
import numpy as np

if __name__ == '__main__':
    model = onnx.load("./model.onnx")
    # replace the fixed batch dimension with the symbolic name 'batchsize'
    # (dim_param and dim_value share a oneof, so this clears the fixed dim_value)
    for i in model.graph.input:
        i.type.tensor_type.shape.dim[0].dim_param = 'batchsize'
    for o in model.graph.output:
        o.type.tensor_type.shape.dim[0].dim_param = 'batchsize'
    onnx.checker.check_model(model)
    onnx.save_model(model, 'dynamic_model.onnx')

    sess = onnxruntime.InferenceSession('./dynamic_model.onnx', providers=['CPUExecutionProvider'])
    input = np.random.randn(3, 3, 244, 244).astype(np.float32)  # any batch size now works
    print(sess.run(['output'], {'input': input}))
Output:
[array([[[[-5.6308472e-01, -2.8269453e+00, -2.7103744e+00, ...,
4.2550400e-01, 6.5147376e-01, -4.7779888e-02],
[-2.5536952e+00, 1.1469245e-01, 3.4514198e-01, ...,
-1.8919052e+00, -5.7445437e-01, -1.5864235e+00],
[-1.7443299e-02, -8.9739335e-01, -2.9766396e-01, ...,
2.7872375e-01, -8.8234627e-01, -2.3681331e+00],
...,
[-1.3148707e+00, -5.4888296e-01, 4.1061863e-01, ...,
-1.0763314e+00, -9.6379507e-01, 1.3077673e+00],
[-6.3514382e-02, -5.1493609e-01, -1.5793841e+00, ...,
-2.2589236e-02, -2.2170777e+00, 1.2437304e+00],
[ 7.4394345e-01, 7.8581774e-01, 2.0062235e-01, ...,
-1.4014708e+00, 5.5377036e-02, 3.6608991e-01]],
[[ 5.4129419e+00, 1.7448205e+00, 3.4165416e+00, ...,
1.0320716e+00, 1.6988618e+00, 5.1501741e+00],
[-1.3918903e+00, 1.7199724e+00, 2.1343894e+00, ...,
8.0553353e-01, 4.7985373e+00, 2.5783958e+00],
[ 2.3555427e+00, 6.3222194e-01, 2.9314611e+00, ...,
4.3459427e-01, 1.3417060e+00, 1.6852837e+00],
...,
[ 5.7537341e-01, 3.0654173e+00, -5.7629395e-01, ...,
1.0968879e+00, 3.7861698e+00, 1.4928346e+00],
[ 3.1267416e+00, 2.0358701e+00, 2.2204084e+00, ...,
5.2084265e+00, 3.9166064e+00, 6.4575119e+00],
[-7.2486067e-01, 2.3311584e+00, 2.0912974e+00, ...,
1.8693907e+00, 3.2796674e+00, 3.8991761e+00]],
[[ 1.1898929e+00, 8.9648962e-03, 8.5148907e-01, ...,
-4.7057205e-01, -7.9108685e-01, 1.0573645e+00],
[-1.5732453e+00, -3.8554335e-01, -1.6086581e-01, ...,
-8.1125468e-01, 1.2085729e+00, 5.6812420e-02],
[-3.0767348e-01, -8.5083431e-01, 4.9003422e-02, ...,
-7.8210533e-01, -5.2408022e-01, -2.3199841e-02],
...,
[-7.3540843e-01, -2.9384446e-01, -1.6465921e+00, ...,
-3.8980949e-01, 7.1137357e-01, -8.0783540e-01],
[ 2.3953258e-01, -1.7050017e-01, -1.3933203e-01, ...,
1.6591790e+00, 1.0759927e+00, 1.7683787e+00],
[-1.6956003e+00, -4.6602386e-01, -3.4259117e-01, ...,
-1.0014131e-01, 2.6990986e-01, 8.6363363e-01]]],
[[[ 6.4478827e-01, -8.1067204e-01, -1.2237258e+00, ...,
-1.2951733e+00, -6.2070227e-01, -1.2906476e+00],
[-6.6038930e-01, -2.8674665e-01, -1.0612940e+00, ...,
4.6769258e-01, 4.8500946e-01, -5.6188315e-01],
[ 1.0600269e-02, -1.4934481e+00, 9.1430867e-01, ...,
-6.1285675e-01, -3.0706315e+00, -9.9033105e-01],
...,
[ 1.7771789e+00, -1.3830042e+00, -1.4351614e+00, ...,
-2.6786397e+00, 3.7956804e-02, 6.7189908e-01],
[-2.1517308e+00, -5.8123243e-01, -7.7163374e-01, ...,
1.6774191e+00, 7.2239363e-01, 1.3373801e+00],
[-8.6465418e-01, -1.3932706e+00, -2.2982714e+00, ...,
1.9587449e+00, -6.2718022e-01, -1.1754386e+00]],
[[ 5.2605295e+00, 6.8119764e-01, 1.6433215e+00, ...,
1.4899890e+00, 7.7494907e-01, 1.0885936e+00],
[ 1.7135508e+00, 1.7890544e+00, 1.5538380e+00, ...,
4.2714515e+00, 3.4532502e+00, 4.0540075e+00],
[ 3.2757509e-01, 2.8093519e+00, 4.4473543e+00, ...,
1.6302650e+00, 2.0791094e+00, -2.7314346e+00],
...,
[ 3.1872306e+00, 2.1063502e+00, 4.4839258e+00, ...,
8.6034179e-01, 3.7707591e+00, 3.9809742e+00],
[-2.0055294e-02, -4.3134212e-02, 2.1313593e+00, ...,
3.0318618e+00, 2.2852294e+00, 3.9968524e+00],
[ 2.1781492e+00, 3.6937137e+00, 1.5003638e+00, ...,
3.5955300e+00, 1.7056749e+00, 1.9585730e+00]],
[[ 1.1795213e+00, -7.1754062e-01, -1.3523299e-01, ...,
-5.6350648e-01, -1.3417213e+00, -5.0127864e-02],
[-5.1167816e-01, -7.7823803e-02, -3.1461412e-01, ...,
7.8631788e-01, 5.9256524e-01, 6.9275266e-01],
[-1.3142396e+00, 7.9331988e-01, 5.0062788e-01, ...,
-6.4525604e-03, -3.3234254e-02, -2.1546085e+00],
...,
[-2.0651843e-01, 1.1771068e-02, 1.3835690e+00, ...,
2.8883666e-03, 4.5511311e-01, 2.9804629e-01],
[-9.0822458e-01, -1.3634090e+00, -4.2348909e-01, ...,
2.9903316e-01, -5.9180021e-01, 5.1938176e-01],
[-3.7974668e-01, 6.5785772e-01, -4.8025602e-01, ...,
1.5578230e-01, -8.5666311e-01, -8.2990326e-02]]],
[[[ 2.7270940e-01, 1.6803369e-01, 6.4784336e-01, ...,
-8.6817765e-01, 2.4317000e+00, 9.9560642e-01],
[-1.0902294e+00, -1.5418210e+00, -6.4213789e-01, ...,
3.8346985e-01, -2.2009264e-01, -1.4083362e+00],
[-1.2999996e+00, -1.0029310e+00, -8.0927563e-01, ...,
-9.6844232e-01, 4.7647089e-02, -1.7528368e+00],
...,
[ 7.9181468e-01, -7.1245348e-01, -1.2355906e+00, ...,
-4.4910422e-01, 7.0296872e-01, -1.8157486e+00],
[ 8.5229218e-01, -3.9036795e-01, 3.7029549e-01, ...,
-2.0579123e+00, 9.2259049e-03, -1.2485095e+00],
[-1.0421257e+00, 9.6360290e-01, -1.9165359e+00, ...,
-1.5525728e+00, -2.7757692e+00, 5.9844279e-01]],
[[ 2.0120070e+00, 2.5763493e+00, 2.5311258e+00, ...,
2.0375581e+00, 1.6430848e+00, 4.5296006e+00],
[-6.4119029e-01, 3.2270002e-01, 2.7286339e+00, ...,
3.4792902e+00, 4.8433290e+00, 1.8760866e+00],
[ 5.2160606e+00, 5.8354855e-01, 1.9910555e+00, ...,
3.8761294e-01, 3.4568546e+00, 2.2840927e+00],
...,
[ 2.4697292e+00, 3.1099756e+00, 4.5984769e+00, ...,
3.1638999e+00, 1.7895203e+00, 5.1426482e-01],
[ 2.0174649e+00, 3.7343421e+00, 1.3838698e+00, ...,
6.8948352e-01, 1.9830887e+00, -1.2911747e+00],
[ 2.2970469e+00, 2.8243198e+00, 8.7906146e-01, ...,
3.2837601e+00, 1.0420291e+00, 4.1244802e+00]],
[[-4.5218289e-01, -1.1248827e-02, -3.9010030e-01, ...,
-2.1441557e-01, -8.8925439e-01, 1.0432711e+00],
[-1.5277631e+00, -6.0763943e-01, 8.2450414e-01, ...,
5.1565582e-01, 9.1227055e-01, -4.1257131e-01],
[ 1.1678007e+00, -8.4806198e-01, -4.1370481e-01, ...,
-9.4888353e-01, 2.4556525e-01, 2.7058780e-02],
...,
[-2.6444227e-01, 6.4803612e-01, 1.4935874e+00, ...,
1.9097075e-02, -6.0670388e-01, -3.2186458e-01],
[-5.2368152e-01, 6.9923353e-01, -6.0641676e-01, ...,
-3.2536793e-01, -3.0933461e-01, -1.7596698e+00],
[-5.7884902e-01, -9.0141267e-02, -4.4471401e-01, ...,
3.5021925e-01, -1.7998603e-01, 6.4285696e-01]]]],
dtype=float32)]
Now the first dimension of input, i.e. batch_size, can be set to any value and the program still runs. Let's print the model information:
import onnx

if __name__ == '__main__':
    model = onnx.load("./model.onnx")
    for i in model.graph.input:
        i.type.tensor_type.shape.dim[0].dim_param = 'batchsize'
    for o in model.graph.output:
        o.type.tensor_type.shape.dim[0].dim_param = 'batchsize'
    # print the modified model instead of saving and running it
    print(model)
Output:
ir_version: 8
graph {
node {
input: "input"
input: "weight"
input: "bias"
output: "output"
op_type: "Conv"
attribute {
name: "group"
i: 1
type: INT
}
attribute {
name: "kernel_shape"
ints: 1
ints: 1
type: INTS
}
attribute {
name: "pads"
ints: 0
ints: 0
ints: 0
ints: 0
type: INTS
}
attribute {
name: "strides"
ints: 1
ints: 1
type: INTS
}
}
name: "graph"
initializer {
dims: 3
dims: 3
dims: 1
dims: 1
data_type: 1
float_data: 0.45837152004241943
float_data: 0.10209446400403976
float_data: 1.0382566452026367
float_data: -0.09292714297771454
float_data: 1.58871591091156
float_data: 0.3746287226676941
float_data: -0.35588690638542175
float_data: 0.7165427207946777
float_data: 0.10244251787662506
name: "weight"
}
initializer {
dims: 3
data_type: 1
float_data: -0.36782845854759216
float_data: 2.305680513381958
float_data: -0.13051341474056244
name: "bias"
}
input {
name: "input"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_param: "batchsize"
}
dim {
dim_value: 3
}
dim {
dim_value: 244
}
dim {
dim_value: 244
}
}
}
}
}
output {
name: "output"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_param: "batchsize"
}
dim {
dim_value: 3
}
dim {
dim_value: 244
}
dim {
dim_value: 244
}
}
}
}
}
}
opset_import {
version: 17
}
We can see that the first dim of input, and likewise of output, is now dim_param: "batchsize" instead of a dim_value.
Adding and deleting nodes
- Adding a node
import onnx
import onnx.helper as helper
import onnxruntime
import numpy as np

if __name__ == '__main__':
    model = onnx.load('./model.onnx')
    nodes = model.graph.node
    # new Relu node: reads the intermediate tensor 'conv1' and writes the graph output 'output'
    new_node = helper.make_node(op_type='Relu', name='relu1', inputs=['conv1'], outputs=['output'])
    nodes.append(new_node)
    # rename the Conv output so that it feeds the Relu instead of the graph output
    nodes[0].output[0] = 'conv1'
    onnx.checker.check_model(model)
    onnx.save_model(model, 'add_model.onnx')

    input = np.random.randn(1, 3, 244, 244).astype(np.float32)
    sess = onnxruntime.InferenceSession('./add_model.onnx', providers=['CPUExecutionProvider'])
    print(sess.run(['output'], {'input': input}))
Output:
[array([[[[1.5453527 , 0. , 0. , ..., 0.04255658,
0. , 0.40214583],
[0. , 0.5019511 , 0. , ..., 0.34235588,
0.36859825, 0. ],
[0. , 0. , 0.34334645, ..., 0. ,
0. , 0. ],
...,
[1.1857387 , 1.0710502 , 0. , ..., 0. ,
1.8497316 , 0. ],
[0.37889728, 0. , 0. , ..., 0. ,
0. , 0. ],
[0.73697627, 0. , 0.4978644 , ..., 0. ,
0. , 0.32394186]],
[[1.2723072 , 0. , 0.66669345, ..., 5.6399436 ,
1.4827138 , 2.7300682 ],
[4.5705633 , 2.9856906 , 2.9005556 , ..., 3.505543 ,
4.7502317 , 0. ],
[1.5251542 , 3.3182473 , 3.8036246 , ..., 0. ,
1.6024959 , 1.4051957 ],
...,
[1.7204559 , 4.551407 , 4.172427 , ..., 0.9121852 ,
3.3593512 , 4.6163626 ],
[0.2845726 , 0.13289118, 3.3601975 , ..., 3.9331636 ,
0.3700601 , 1.5711328 ],
[3.3283763 , 2.128338 , 2.1621299 , ..., 1.7635765 ,
0. , 2.1479769 ]],
[[0. , 0. , 0. , ..., 1.4292918 ,
0. , 0.46683455],
[1.0534286 , 0. , 0.02258705, ..., 0.4342987 ,
1.1339298 , 0. ],
[0. , 0.50237906, 0.20627443, ..., 0. ,
0. , 0. ],
...,
[0. , 0.78040606, 1.003104 , ..., 0. ,
0. , 1.0389903 ],
[0. , 0. , 0. , ..., 0.74816215,
0. , 0.02678718],
[0.26068228, 0. , 0. , ..., 0. ,
0. , 0. ]]]], dtype=float32)]
Let's print the model information again:
import onnx
import onnx.helper as helper

if __name__ == '__main__':
    model = onnx.load('./model.onnx')
    nodes = model.graph.node
    new_node = helper.make_node(op_type='Relu', name='relu1', inputs=['conv1'], outputs=['output'])
    nodes.append(new_node)
    nodes[0].output[0] = 'conv1'
    # print the modified model instead of saving and running it
    print(model)
Output:
ir_version: 8
graph {
node {
input: "input"
input: "weight"
input: "bias"
output: "conv1"
op_type: "Conv"
attribute {
name: "group"
i: 1
type: INT
}
attribute {
name: "kernel_shape"
ints: 1
ints: 1
type: INTS
}
attribute {
name: "pads"
ints: 0
ints: 0
ints: 0
ints: 0
type: INTS
}
attribute {
name: "strides"
ints: 1
ints: 1
type: INTS
}
}
node {
input: "conv1"
output: "output"
name: "relu1"
op_type: "Relu"
}
name: "graph"
initializer {
dims: 3
dims: 3
dims: 1
dims: 1
data_type: 1
float_data: 0.45837152004241943
float_data: 0.10209446400403976
float_data: 1.0382566452026367
float_data: -0.09292714297771454
float_data: 1.58871591091156
float_data: 0.3746287226676941
float_data: -0.35588690638542175
float_data: 0.7165427207946777
float_data: 0.10244251787662506
name: "weight"
}
initializer {
dims: 3
data_type: 1
float_data: -0.36782845854759216
float_data: 2.305680513381958
float_data: -0.13051341474056244
name: "bias"
}
input {
name: "input"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 244
}
dim {
dim_value: 244
}
}
}
}
}
output {
name: "output"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 244
}
dim {
dim_value: 244
}
}
}
}
}
}
opset_import {
version: 17
}
We can see that a new node relu1 has been added. The first node's output is now conv1, and the second node takes conv1 as its input and writes output.
- Deleting a node
import onnx
import onnxruntime
import numpy as np

if __name__ == '__main__':
    model = onnx.load('./add_model.onnx')
    nodes = model.graph.node
    # find relu1 and remove it; stop right away, since removing from a
    # container while iterating over it can skip elements
    for node in nodes:
        if node.name == 'relu1':
            nodes.remove(node)
            break
    # the Conv output becomes the graph output again
    nodes[0].output[0] = 'output'
    onnx.checker.check_model(model)
    onnx.save_model(model, 'del_model.onnx')

    input = np.random.randn(1, 3, 244, 244).astype(np.float32)
    sess = onnxruntime.InferenceSession('./del_model.onnx', providers=['CPUExecutionProvider'])
    print(sess.run(['output'], {'input': input}))
Output:
[array([[[[-8.5923064e-01, -4.2249173e-01, 3.8687822e-01, ...,
-4.8348337e-02, 3.1652334e-01, -5.7166600e-01],
[ 3.1469372e-01, -9.4796360e-01, -2.4245100e+00, ...,
4.1007617e-01, -1.4098099e+00, 6.7472184e-01],
[-1.2910874e+00, 1.6070822e-01, -1.0217074e+00, ...,
7.1467435e-01, 1.5835044e-01, -6.4228356e-01],
...,
[-2.5442154e+00, -8.8969648e-01, 1.1389736e+00, ...,
1.7202379e+00, -1.1968368e+00, -3.3861694e-01],
[-9.0216339e-01, 4.8469666e-01, -9.5050204e-01, ...,
4.0511075e-01, -1.0113320e-01, 1.8743831e+00],
[ 3.2901958e-01, 4.3780953e-02, 1.4250931e+00, ...,
-1.4544667e+00, 9.0659869e-01, 1.7170597e+00]],
[[ 1.3439684e+00, 3.0856354e+00, 2.7811766e+00, ...,
4.1714394e-01, -3.3547878e-02, 1.1771207e+00],
[ 2.1574910e+00, 2.1122241e+00, -5.8333945e-01, ...,
1.9629711e+00, 3.4840956e+00, 6.1747317e+00],
[ 5.2136226e+00, 4.8688288e+00, 1.4613919e+00, ...,
4.1095753e+00, 1.4553337e+00, 3.5171165e+00],
...,
[ 7.3736429e-02, 8.4109855e-01, 5.7113109e+00, ...,
3.6336284e+00, 4.4551125e+00, 3.4602299e+00],
[ 1.1054695e+00, 2.7417006e+00, 4.9065466e+00, ...,
2.1775680e+00, 4.4132576e+00, 2.3781679e+00],
[-1.2788355e+00, 2.5300267e+00, 3.2560487e+00, ...,
2.2025514e+00, 4.2551570e+00, 3.5148311e+00]],
[[-8.5124874e-01, 3.1858414e-01, 3.3686757e-03, ...,
-1.1497847e+00, -1.1996644e+00, -9.6176589e-01],
[-4.2057925e-01, -1.8098265e-01, -7.4302059e-01, ...,
-3.5920531e-01, 7.0454830e-01, 1.8304255e+00],
[ 1.4177717e+00, 8.4456313e-01, -1.6396353e-01, ...,
4.2133337e-01, -4.6482396e-01, 6.6906375e-01],
...,
[-1.0060047e+00, -1.2088763e+00, 1.2608007e+00, ...,
5.1739502e-01, 8.9526463e-01, 7.2866821e-01],
[-3.5698372e-01, -5.9943002e-01, 1.0040566e+00, ...,
3.1322885e-01, 3.4513384e-01, -6.2404698e-01],
[-2.0622578e+00, 3.9633280e-01, 2.1701033e-01, ...,
2.6992482e-01, 4.4787437e-01, 2.1187775e-01]]]],
dtype=float32)]
Replacing a node
Now let's replace the Conv node in add_model.onnx with a Squeeze node, which removes dimensions of size 1:
import onnx
import onnx.helper as helper
import onnxruntime
import numpy as np

if __name__ == '__main__':
    model = onnx.load('./add_model.onnx')
    # Squeeze without an axes input removes every dimension of size 1
    new_node = helper.make_node(op_type='Squeeze', inputs=['input'], outputs=['conv1'], name='squeeze1')
    nodes = model.graph.node
    # the new node is appended at the end of the node list, i.e. after relu1
    nodes.append(new_node)
    for node in nodes:
        if node.op_type == 'Conv':
            nodes.remove(node)
            break  # stop iterating once the container has been modified
    # onnx.checker.check_model(model)
    onnx.save_model(model, 'replace.onnx')

    input = np.random.randn(1, 3, 244, 244).astype(np.float32)
    sess = onnxruntime.InferenceSession('./replace.onnx', providers=['CPUExecutionProvider'])
    print(sess.run(['output'], {'input': input}))
Output:
[array([[[0. , 1.1258854 , 0. , ..., 0.54984987,
0.19069785, 0. ],
[1.1481465 , 0. , 1.9025986 , ..., 0. ,
0.11273875, 0. ],
[1.57912 , 0. , 0. , ..., 0. ,
1.7471381 , 0. ],
...,
[0.42386332, 0.30908984, 0. , ..., 0. ,
0. , 1.8173866 ],
[0.07642962, 0.31224537, 0. , ..., 1.6805407 ,
2.0282576 , 0. ],
[0. , 0.2521538 , 0. , ..., 0. ,
0.6431213 , 0.5844705 ]],
[[0. , 0. , 0. , ..., 0.23725364,
0.22994171, 0.316093 ],
[0.85044146, 1.2757416 , 0.28854838, ..., 0. ,
0. , 0. ],
[0. , 0. , 1.1362596 , ..., 1.8543358 ,
1.1296074 , 0.5114057 ],
...,
[0. , 0.00810617, 0. , ..., 1.0819261 ,
1.707781 , 0. ],
[0. , 0.6385371 , 0. , ..., 0.6565783 ,
1.457183 , 0. ],
[0. , 0.8315589 , 1.4111192 , ..., 1.0682058 ,
0.17328343, 2.3547616 ]],
[[0.2426068 , 0. , 0. , ..., 0.89054537,
0.98760164, 0. ],
[1.1344411 , 0.8732987 , 0. , ..., 0. ,
0. , 0. ],
[0. , 0.3664789 , 1.4099371 , ..., 0. ,
0.0588427 , 0.5932818 ],
...,
[0. , 0.68438137, 0.8869638 , ..., 0. ,
0. , 1.4681839 ],
[0. , 0. , 0. , ..., 0.16630006,
1.9389246 , 0. ],
[0. , 0. , 0.03726726, ..., 0.86296386,
0. , 0. ]]], dtype=float32)]
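Note that uncommenting onnx.checker.check_model(model) makes the script raise an error. ONNX requires the nodes of a graph to be topologically sorted, but the new node squeeze1, which produces conv1, was appended after relu1, which consumes conv1. The model fails the check, yet onnxruntime can still run it, as the output above shows. How can we make the model both run and pass the check?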
import onnx
import onnx.helper as helper
import onnxruntime
import numpy as np

if __name__ == '__main__':
    model = onnx.load('./add_model.onnx')
    new_node = helper.make_node(op_type='Squeeze', inputs=['input'], outputs=['conv1'], name='squeeze1')
    nodes = model.graph.node
    # nodes.append(new_node)  (no longer appended at the end)
    for idx, node in enumerate(nodes):
        if node.op_type == 'Conv':
            # put squeeze1 in the Conv's old position so the node list stays topologically sorted
            nodes.remove(node)
            nodes.insert(idx, new_node)
            break
    onnx.checker.check_model(model)
    onnx.save_model(model, 'replace.onnx')

    input = np.random.randn(1, 3, 244, 244).astype(np.float32)
    sess = onnxruntime.InferenceSession('./replace.onnx', providers=['CPUExecutionProvider'])
    print(sess.run(['output'], {'input': input}))
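The only change is to put the new node in the old node's position. Also note that Squeeze removes the leading size-1 dimension, so the actual output is 3-D ([3, 244, 244], as in the output above), even though graph.output still declares [1, 3, 244, 244]. check_model does not run shape inference by default, so it does not flag this mismatch.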
Introduction to ONNXRuntime
ONNXRuntime is an inference engine from Microsoft that makes it easy to run ONNX models. Official GitHub: https://github.com/microsoft/onnxruntime