First, the code is at https://github.com/jfzhang95/pytorch-deeplab-xception. Below is its code for loading the pretrained model, with a few small modifications:
import torch.utils.model_zoo as model_zoo

pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
model_dict = {}
state_dict = model.state_dict()
for k, v in pretrain_dict.items():
    # keep only the pretrained weights whose key also exists in the model
    if k in state_dict:
        model_dict[k] = v
state_dict.update(model_dict)
model.load_state_dict(state_dict)
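Here, the for loop finds the keys that the model and the loaded pretrained checkpoint have in common and collects the corresponding values. These are merged into state_dict, which is then loaded back into the model.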
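One thing worth noting about this loop is that any checkpoint key the model does not have is skipped without a message. The following is a minimal sketch of the same filtering step that also reports what was skipped. The helper name `split_matching_keys` is hypothetical, and plain placeholder values stand in for tensors so that the snippet runs without PyTorch.

```python
def split_matching_keys(pretrain_dict, state_dict):
    """Hypothetical helper: split checkpoint entries by whether the model has the key."""
    matched = {k: v for k, v in pretrain_dict.items() if k in state_dict}
    skipped = [k for k in pretrain_dict if k not in state_dict]
    return matched, skipped


# Placeholder values stand in for tensors; a real state dict maps names to tensors.
pretrain_dict = {"conv1.weight": 1, "bn1.weight": 2, "fc.weight": 3}
state_dict = {"conv1.weight": 0, "bn1.weight": 0, "layer1.0.conv1.weight": 0}

matched, skipped = split_matching_keys(pretrain_dict, state_dict)
print(sorted(matched))  # keys whose values would be copied into the model
print(skipped)          # keys the original loop would ignore silently
```

With a real model, `matched` plays the role of `model_dict` in the code above. An empty `matched` is a quick sign that the key names do not line up at all.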
However, when I swapped in a different pretrained model, an error occurred:
RuntimeError: Error(s) in loading state_dict for XXX:
Missing key(s) in state_dict:
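After some reading, I found that the likely cause was that the keys of the two state dicts did not line up, so I printed the keys of both to compare: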
import torch

pretrain_dict = torch.load('/home/yu/Desktop/pytorch-deeplab-clone-1/111.pth')
for k in pretrain_dict.keys():
    print(k)  # keys in the pretrained checkpoint
model_dict = {}
state_dict = model.state_dict()
for k in state_dict.keys():
    print(k)  # keys in the local model
for k, v in pretrain_dict.items():
    if k in state_dict:
        model_dict[k] = v
state_dict.update(model_dict)
model.load_state_dict(state_dict)
for k in model_dict.keys():
    print(k)  # keys that matched and were copied from the checkpoint
The output is as follows:
…
module.backbone.conv1.weight
module.backbone.bn1.weight
…
…
conv1.weight
bn1.weight
…
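Scanning two long printed lists by eye is error-prone, so a set comparison can make the mismatch explicit. This is only a sketch: the helper name `compare_keys` is hypothetical, and the key lists below contain just the four names visible in the output above.

```python
def compare_keys(pretrain_keys, model_keys):
    """Hypothetical helper: group key names by where they appear."""
    pretrain_keys, model_keys = set(pretrain_keys), set(model_keys)
    return {
        "shared": sorted(pretrain_keys & model_keys),
        "only_in_checkpoint": sorted(pretrain_keys - model_keys),
        "only_in_model": sorted(model_keys - pretrain_keys),
    }


# Only the key names shown in the printed output are used here.
checkpoint_keys = ["module.backbone.conv1.weight", "module.backbone.bn1.weight"]
local_keys = ["conv1.weight", "bn1.weight"]

report = compare_keys(checkpoint_keys, local_keys)
for group, names in report.items():
    print(group, names)
```

With a real checkpoint and model, the inputs would be `pretrain_dict.keys()` and `model.state_dict().keys()`.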
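Sure enough, the keys do not match: the checkpoint keys carry a module.backbone. prefix that the model's keys lack. Most of the solutions I found load the weights one by one, which seemed too tedious, so I modified the code in my own way and the weights loaded successfully: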
pretrain_dict = torch.load('/home/yu/Desktop/pytorch-deeplab-clone-1/111.pth')
for k in pretrain_dict.keys():
    print(k)
model_dict = {}
state_dict = model.state_dict()
for k in state_dict.keys():
    print(k)
print("separator")
for k, v in pretrain_dict.items():
    for i, j in state_dict.items():  # prepend the prefix to each model key and look for a match
        m = 'module.backbone.' + i
        if k == m:
            model_dict[i] = v
            print(i)
state_dict.update(model_dict)
model.load_state_dict(state_dict)
for k in model_dict.keys():
    print(k)
return model  # this snippet sits inside a function, hence the return
In effect, this just adds the prefix to the model's keys before comparing them with the checkpoint's keys.
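The same matching can also be done in a single pass by removing the prefix from the checkpoint keys instead of adding it to every model key inside a nested loop. The sketch below assumes the prefix is exactly `module.backbone.` as seen in the printed keys. The helper name `strip_prefix` and the `module.head.weight` entry are made up for illustration, and string placeholders stand in for tensors.

```python
PREFIX = "module.backbone."  # prefix seen in the printed checkpoint keys


def strip_prefix(pretrain_dict, state_dict, prefix=PREFIX):
    """Hypothetical helper: drop `prefix` from checkpoint keys, keep those the model has."""
    renamed = {}
    for k, v in pretrain_dict.items():
        if k.startswith(prefix):
            short = k[len(prefix):]
            if short in state_dict:
                renamed[short] = v
    return renamed


# Placeholder strings stand in for tensors; "module.head.weight" is an invented extra key.
pretrain_dict = {
    "module.backbone.conv1.weight": "w_conv1",
    "module.backbone.bn1.weight": "w_bn1",
    "module.head.weight": "w_head",
}
state_dict = {"conv1.weight": None, "bn1.weight": None}

model_dict = strip_prefix(pretrain_dict, state_dict)
print(model_dict)
```

With a real model, `model_dict` would then be passed through `state_dict.update(model_dict)` and `model.load_state_dict(state_dict)` exactly as in the code above. The result is the same as the nested loop, but each checkpoint key is examined only once.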
I'm new to this and am just recording how I solved the problem. If you know a better approach, I'd be glad to hear it.