我的需求是,由于我在不停的尝试各种模型,导致模型木块一直会变。如果每次重复重新开始训练要花费大把时间。
我之前运行的模型 ResNet -> 三个ResNet参数共享。
ResNet -> 中间模块 -> 结果
ResNet ->
现在我要改成 ResNet 1-> 三个ResNet不参数共享来重新训练,我想导入之前模型中间模块的参数,
ResNet 2-> 中间模块 -> 结果
ResNet 3->
并且冻结中间模块的参数使训练速度加快。
参考了两位大神的两篇博文:加载部分参数https://blog.csdn.net/weixin_41519463/article/details/101604662,冻结部分参数https://blog.csdn.net/jdzwanghao/article/details/83239111。
具体代码如下:
net = MY_Net( )
######导入部分参数
model_dict = net.state_dict()
for k, v in model_dict.items():
print(k)
pretrained_dict = torch.load(model_file1)#model_file1是之前模型的模型保存路径,这里只是加载参数而已
for k, v in pretrained_dict.items():
print(k)
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict) # 用预训练模型参数更新new_model中的部分参数
net.load_state_dict(model_dict) # 将更新后的model_dict加载进new model中
##### 冻结部分参数
for param in net.parameters():
param.requires_grad = False#设置所有参数不可导,下面选择设置可导的参数
for param in net.ResNet1.parameters():
param.requires_grad = True
for param in net.ResNet2.parameters():
param.requires_grad = True
for param in net.ResNet3.parameters():
param.requires_grad = True
optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr = 0.0001, momentum=0.90,weight_decay=0.0005)#关键是优化器中通filter来过滤掉那些不可导的参数