在模型落地应用时,我们往往需要先对模型进行格式转换,本篇博客记录从.pth文件转换为.onnx文件的具体流程。
一、模型结构定义
在模型训练好之后我们有两种保存方式,一种就是将模型的结构与模型的参数一起进行保存,但是由于后续工程的持续改进,这种方法往往并不是很实用,大部分工程采用了第二种方法,也就是只保存模型中的参数部分,由于只有模型的参数,我们在转换模型之前我们需要将模型的结构进行实例化。这也是模型转换的第一步
我们可以在原工程中找到get_model这一接口,但是如果对工程并不熟悉的话这一操作往往会比较困难,所以我们采用了比较简单的方式,也就是将模型的结构文件直接进行实例化。举个例子,有如下模型结构:
import torch
import torch.nn as nn
import torch.nn.functional as F
class cde(nn.Module):
def __init__(self):
super(cde, self).__init__()
self.relu = nn.ReLU(inplace=True)
number_f = 32
self.e_conv1 = nn.Conv2d(3, number_f, 3, 1, 1, bias=True)
self.e_conv2 = nn.Conv2d(number_f, number_f, 3, 1, 1, bias=True)
self.e_conv3 = nn.Conv2d(number_f, number_f, 3, 1, 1, bias=True)
self.e_conv4 = nn.Conv2d(number_f, number_f, 3, 1, 1, bias=True)
self.e_conv5 = nn.Conv2d(number_f * 2, number_f, 3, 1, 1, bias=True)
self.e_conv6 = nn.Conv2d(number_f * 2, number_f, 3, 1, 1, bias=True)
self.e_conv7 = nn.Conv2d(number_f * 2, 24, 3, 1, 1, bias=True)
self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)
self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)
def forward(self, x):
x1 = self.relu(self.e_conv1(x))
# p1 = self.maxpool(x1)
x2 = self.relu(self.e_conv2(x1))
# p2 = self.maxpool(x2)
x3 = self.relu(self.e_conv3(x2))
# p3 = self.maxpool(x3)
x4 = self.relu(self.e_conv4(x3))
x5 = self.relu(self.e_conv5(torch.cat([x3, x4], 1)))
# x5 = self.upsample(x5)
x6 = self.relu(self.e_conv6(torch.cat([x2, x5], 1)))
x_r = F.tanh(self.e_conv7(torch.cat([x1, x6], 1)))
r1, r2, r3, r4, r5, r6, r7, r8 = torch.split(x_r, 3, dim=1)
x = x + r1 * (torch.pow(x, 2) - x)
x = x + r2 * (torch.pow(x, 2) - x)
x = x + r3 * (torch.pow(x, 2) - x)
enhance_image_1 = x + r4 * (torch.pow(x, 2) - x)
x = enhance_image_1 + r5 * (torch.pow(enhance_image_1, 2) - enhance_image_1)
x = x + r6 * (torch.pow(x, 2) - x)
x = x + r7 * (torch.pow(x, 2) - x)
enhance_image = x + r8 * (torch.pow(x, 2) - x)
r = torch.cat([r1, r2, r3, r4, r5, r6, r7, r8], 1)
return enhance_image_1, enhance_image, r
我们只需要将本结构保存为一个model.py文件,便于后续的实用。接下来我们找到需要转换模型的权重就可以开始转换了
转换代码如下:
import torch
import model
# 实例化模型的结构部分
DCE_net = model.cde()
# 训练之后保存的模型权重
a = "./1.pth"
# 加载模型权重文件
DCE_net.load_state_dict(torch.load(a))
# 定义导出文件的路径以及名称
onnx_path = './2.onnx'
# 定义模型的静态输入,也可以指定动态输入,详细操作见官网
input = torch.randn(1, 3, 640, 640)
torch.onnx.export(DCE_net, args=input, f=onnx_path, export_params=False, verbose=True, opset_version=11) # 指定模型的输入,以及onnx的输出路径
运行之后就可以将模型进行转换完毕了!