将deep_sort跟踪模型ckpt.t7转为pt格式,用于pc端的推断。
def cvt_model(): print("===> Loading model") model = Net() modelname = 'ckpt.t7' checkpoint = torch.load(savedataroot + modelname) model.load_state_dict(checkpoint['net_dict']) # 从字典中依次读取,具体值查看字典更改 print('===> Load last checkpoint data') # 模型转换,Torch Script model.eval() example = torch.rand(4,3,128,64) traced_script_module = torch.jit.trace(model, example) traced_script_module.save("deep_model.pt") print("Export of model.pt complete!")