最近想要把训练好的Pytorch模型在Android端上部署,发现如果将直接训练好的模型直接运用到Android上会出现闪退的情况,所以需要将转模型进行转换。
查找了很多博客,都不能直接解决我的问题,所以经过一天的试错,终于把模型转换搞好了。
参考:
将Pytorch模型部署到Android端
pytorch官方教程
import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile
import os
from model import ResNet18
MODEL_PATH=''
model_pth = os.path.join(MODEL_PATH, 'test1_dict.pth') #拼接原模型的路径
#搭建网络,可以自己的网络模型,也可以使用torchvision.model提供的模型
model=ResNet18.RestNet18_Net()
#加载参数
model.load_state_dict(torch.load(model_pth))
#模型设置为评测模式
model.eval()
example=torch.rand(1,3,384,384)
#模型转化
traced_script_module = torch.jit.trace(model, example)
#移动端优化
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
#保存模型
traced_script_module_optimized._save_for_lite_interpreter("model4.pt")
最后在官方案例测试一下