pytorch模型转换,将.t7转为.pt

将deep_sort跟踪模型ckpt.t7转为pt格式,用于pc端的推断。

def cvt_model():
    print("===> Loading model")
    model = Net()
    modelname = 'ckpt.t7'
    checkpoint = torch.load(savedataroot + modelname)
    model.load_state_dict(checkpoint['net_dict'])  # 从字典中依次读取,具体值查看字典更改
    print('===> Load last checkpoint data')

    # 模型转换,Torch Script
    model.eval()
    example = torch.rand(4,3,128,64)
    traced_script_module = torch.jit.trace(model, example)
    traced_script_module.save("deep_model.pt")
    print("Export of model.pt complete!")
发布了191 篇原创文章 · 获赞 104 · 访问量 34万+

猜你喜欢

转载自blog.csdn.net/u013925378/article/details/103281368
T7