加载MobileNetV2部分权重

一、问题提出

        在部署的时候考虑到边缘设备的性能,网络模型最好小一点,因此想用简化版本的MobilenetV2,一开始因为更改了网络结构和层数就没有加载预训练模型,但是发现模型很难训练。后面想着既然我只是减少了网络的层数,是不是可以加载保留部分的网络权重?或许可以加快网络的训练。(后面训练确实有效果)

二、解决方法

        MobilenetV2的预训练模型结构如下所示,我的model是每一个block只堆叠一次,即所有n=1。

        思路:加载MobilenetV2的预训练模型pretrained_dict,实例化自己的模型得到model_dict,将两个state_dict进行比对,在pretrained_dict对key值进行删减。

def mobilenet_v2(pretrained=True):
    model = MobileNetV2(width_mult=1)
    if pretrained:
        # try:
        #    from torch.hub import load_state_dict_from_url
        # except ImportError:
        #    from torch.utils.model_zoo import load_url as load_state_dict_from_url
        # state_dict = load_state_dict_from_url(
        #     'https://www.dropbox.com/s/47tyzpofuuyyv1b/mobilenetv2_1.0-f2a8633.pth.tar?dl=1', progress=True)
        pretrained_dict = torch.load("mobilenetv2_1.0-f2a8633.pth.tar") # 预训练的MobilenetV2
        model_dict = model.state_dict()  # 读取自己的网络的结构参数
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if
                           k in model_dict and (v.shape == model_dict[k].shape)}
        model_dict.update(pretrained_dict)  # 将与 pretrained_dict 中 layer_name 相同的参数更新为 pretrained_dict 的参数
        model.load_state_dict(model_dict) # 加载更新后的参数
    return model

猜你喜欢

转载自blog.csdn.net/weixin_44855366/article/details/130553739
今日推荐