#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time :2022/8/3 16:19
# @Author :weiz
# @ProjectName :cbir
# @File :pth2onnx.py
# @Description :
from vgg import *
import torch
def pth2onnx_all(input_shape, model_path, onnx_path, *, opset_version=9):
    """Export a fully pickled PyTorch model (architecture + weights) to ONNX.

    Use this when ``model_path`` holds a whole pickled ``nn.Module``
    (written with ``torch.save(model, ...)``) rather than only a state dict.

    Args:
        input_shape: Shape of the random dummy input used to trace the model,
            e.g. ``[1, 3, 224, 224]``. Any rank is accepted.
        model_path: Path of the ``.pth`` file containing the pickled model.
        onnx_path: Destination path of the exported ``.onnx`` file.
        opset_version: ONNX opset to target (keyword-only, defaults to 9).

    Security note: unpickling a full model can execute arbitrary code stored
    in the file -- only call this on checkpoints from a trusted source.
    """
    import inspect

    load_kwargs = {"map_location": "cpu"}  # keep every tensor on the CPU
    # PyTorch >= 2.6 defaults torch.load(weights_only=True), which refuses to
    # unpickle a full nn.Module. Opt out explicitly when the argument exists;
    # older versions (< 1.13) do not accept the keyword at all.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    model = torch.load(model_path, **load_kwargs)
    # Export with inference behaviour (dropout off, batch-norm running stats),
    # the same as pth2onnx / pth2onnx_dynamic do.
    model.eval()

    device = torch.device("cpu")
    dummy_input = torch.randn(tuple(input_shape), device=device)
    torch.onnx.export(model, dummy_input, onnx_path,
                      opset_version=opset_version, verbose=False,
                      input_names=["input"], output_names=["output"])
def pth2onnx(model_net, input_shape, model_path, onnx_path, *, opset_version=9):
    """Load a state dict into ``model_net`` and export it to ONNX.

    The ``.pth`` file holds only the model parameters, so converting it needs
    the network structure supplied by the caller as ``model_net``. This is the
    officially recommended way to save and convert models.

    Args:
        model_net: ``nn.Module`` instance whose architecture matches the
            checkpoint. Its weights are overwritten in place and it is left
            in eval mode.
        input_shape: Shape of the random dummy input used to trace the model,
            e.g. ``[1, 3, 224, 224]``. Any rank is accepted.
        model_path: Path of the ``.pth`` file containing the state dict.
        onnx_path: Destination path of the exported ``.onnx`` file.
        opset_version: ONNX opset to target (keyword-only, defaults to 9).
    """
    state_dict = torch.load(model_path, map_location="cpu")
    model_net.load_state_dict(state_dict)
    # Switch to inference mode so dropout / batch-norm are exported with
    # their evaluation behaviour.
    model_net.eval()

    device = torch.device("cpu")
    dummy_input = torch.randn(tuple(input_shape), device=device)
    torch.onnx.export(model_net, dummy_input, onnx_path,
                      opset_version=opset_version, verbose=False,
                      input_names=["input"], output_names=["output"])
def pth2onnx_dynamic(model_net, input_shape, model_path, onnx_path, *,
                     opset_version=9, dynamic_axes=None):
    """Load a state dict into ``model_net`` and export it with dynamic axes.

    Like ``pth2onnx``, but the exported graph does not have a fixed input
    size; this is achieved through the ``dynamic_axes`` argument of
    ``torch.onnx.export``.

    Args:
        model_net: ``nn.Module`` instance whose architecture matches the
            checkpoint. Its weights are overwritten in place and it is left
            in eval mode.
        input_shape: Shape of the random dummy input used to trace the model,
            e.g. ``[1, 3, 224, 224]``. With the default ``dynamic_axes`` it
            must have at least 4 dimensions.
        model_path: Path of the ``.pth`` file containing the state dict.
        onnx_path: Destination path of the exported ``.onnx`` file.
        opset_version: ONNX opset to target (keyword-only, defaults to 9).
        dynamic_axes: Optional mapping passed straight to
            ``torch.onnx.export``. When omitted, only the input's dims 2 and 3
            are dynamic; batch size and the output shape stay as traced.
    """
    state_dict = torch.load(model_path, map_location="cpu")
    model_net.load_state_dict(state_dict)
    # Switch to inference mode so dropout / batch-norm are exported with
    # their evaluation behaviour.
    model_net.eval()

    device = torch.device("cpu")
    dummy_input = torch.randn(tuple(input_shape), device=device)
    if dynamic_axes is None:
        # Assumes NCHW input (channels at dim 1, as in [1, 3, 224, 224]), so
        # dim 2 is height and dim 3 is width. Batch size, in_height and
        # in_width could all be made dynamic; to also free the batch and
        # output dims pass e.g.
        #   {"input": {0: "batch_size", 2: "in_height", 3: "in_width"},
        #    "output": {0: "batch_size", 2: "out_height", 3: "out_width"}}
        dynamic_axes = {"input": {2: "in_height", 3: "in_width"}}
    torch.onnx.export(model_net, dummy_input, onnx_path,
                      opset_version=opset_version, verbose=False,
                      input_names=["input"], output_names=["output"],
                      dynamic_axes=dynamic_axes)
if __name__ == "__main__":
    # Export the same VGG16 state dict twice (fixed and dynamic input size),
    # then export a checkpoint that stores the entire pickled model.
    sample_shape = [1, 3, 224, 224]
    weights_file = "./vgg_dict_test.pth"
    net = VGG16()

    pth2onnx(net, sample_shape, weights_file, "./vgg_no_full.onnx")
    pth2onnx_dynamic(net, sample_shape, weights_file, "./vgg_no_full_dynamic.onnx")
    pth2onnx_all(sample_shape, "./vgg_test.pth", "./vgg16_no_ful.onnx")
# pth转onnx的三种情况 (three ways to convert a .pth checkpoint to ONNX)
# 转载自 (reposted from) blog.csdn.net/qq_31112205/article/details/126180272