一、问题提出
在部署的时候考虑到边缘设备的性能,网络模型最好小一点,因此想用简化版本的MobilenetV2,一开始因为更改了网络结构和层数就没有加载预训练模型,但是发现模型很难训练。后面想着既然我只是减少了网络的层数,是不是可以加载保留部分的网络权重?或许可以加快网络的训练。(后面训练确实有效果)
二、解决方法
MobilenetV2的预训练模型结构如下所示,我的model是每一个block只堆叠一次,即所有n=1。
思路:加载MobilenetV2的预训练模型pretrained_dict,实例化自己的模型得到model_dict,将两个state_dict进行比对,在pretrained_dict对key值进行删减。
def mobilenet_v2(pretrained=True):
model = MobileNetV2(width_mult=1)
if pretrained:
# try:
# from torch.hub import load_state_dict_from_url
# except ImportError:
# from torch.utils.model_zoo import load_url as load_state_dict_from_url
# state_dict = load_state_dict_from_url(
# 'https://www.dropbox.com/s/47tyzpofuuyyv1b/mobilenetv2_1.0-f2a8633.pth.tar?dl=1', progress=True)
pretrained_dict = torch.load("mobilenetv2_1.0-f2a8633.pth.tar") # 预训练的MobilenetV2
model_dict = model.state_dict() # 读取自己的网络的结构参数
pretrained_dict = {k: v for k, v in pretrained_dict.items() if
k in model_dict and (v.shape == model_dict[k].shape)}
model_dict.update(pretrained_dict) # 将与 pretrained_dict 中 layer_name 相同的参数更新为 pretrained_dict 的参数
model.load_state_dict(model_dict) # 加载更新后的参数
return model