模型保存与加载常用有两种方法,第一种是保存整个模型,包括模型的结构和参数;第二种是保存模型的参数。推荐使用第二种,因为模型一旦很大,第一种加载耗时长,其次第二种加载方式更加灵活,可以加载其他模型的预训练参数,从而使用迁移学习的方法减小训练时长。
一、保存/加载整个模型
- 保存模型:
torch.save(net, 'model_net1.pkl')
- 加载模型
net_parm = 'model_net1.pkl'
net = torch.load(net_parm)
二、保存/加载模型参数
- 保存参数:
torch.save({
'epoch': nums_epoch,
'state_dict': net.state_dict(),
}, 'model_net.pkl')
- 加载参数:
cuda_gpu = torch.cuda.is_available()
if(cuda_gpu):
net = torch.nn.DataParallel(net, device_ids=gpus).cuda()
if os.path.isfile(net_parm):
print("=>loading model '{}'".format(net_parm))
checkpoint = torch.load(net_parm)
epoch = checkpoint['epoch']
print(epoch)
net.load_state_dict(checkpoint['state_dict'])
print("=>load model success, start epoch: '{}'".format(epoch))
torch.load 返回的是一个 OrderedDict。OrderedDict 是 collections 提供的一种数据结构, 它提供了有序的dict结构。可以将加载的模型打印出来:
checkpoint = torch.load(net_parm)
print (checkpoint)
因此我们知道参数加载的原理就是将相同的Key进行赋值操作,load一般是依据key来加载的,一旦有key不匹配则出错。
3. 部分加载参数
这时候就出现一个问题,如果我们修改了网络结构,但是相同的部分想加载预训练模型的参数,应该如何加载呢?
net.load_state_dict(pretrained_net, strict=False)
设置strict=False,则直接忽略不匹配的key,对于匹配的key则进行正常的赋值。
或者自己设置一个过滤器,过滤不需要的网络层。参考apaszke推荐的做法,即删除与当前model不匹配的key。
pretrained_dict = ...
model_dict = model.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(pretrained_dict)