pytorch 多卡并行载入部分网络模型
我们在做深度学习的时候经常会使用预训练的模型。但是一旦自己修改了网络架构,就无法load pretrained model。 因为模型文件保存的参数,有一部分是不需要的,或者有一部分参数是缺失的。
为了在这种情况下,成功导入模型,我们需要如下操作
操作的前提是我们存在已保存的模型参数
model = Net()
torch.save(model.state_dict(),'xxx.path')
接下来就好办了
device = torch.device("cuda:2" if args.cuda else "cpu")
#Try to load models
model = DGCNN(args).to(device)
print(str(model))
device_ids = [2,3]
model = nn.DataParallel(model,device_ids=device_ids) #使用2,3号显卡进行训练
save_model = torch.load('model.t7')
model_dict = model.state_dict()
state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
print(state_dict.keys())
model_dict.update(state_dict)
model.load_state_dict(model_dict)
update之后,model_dict和state_dict中具有相同键的值已经同步了。
可以开始愉快的训练了!
参考
https://blog.csdn.net/qq_34914551/article/details/87871134