网络训练细节

https://www.cnblogs.com/wmlj/p/9917827.html

经典网络的加载和初始化:pytorch中自带几种常用的深度学习网络预训练模型,torchvision.models包中包含alexnet、densenet、inception、resnet、squeezenet、vgg等常用网络结构,并且提供了预训练模型,可通过调用来读取网络结构和预训练模型(模型参数)。往往为了加快学习进度,训练的初期直接加载pretrain模型中预先训练好的参数。加载model如下所示:

import torchvision.models as models
    #加载网络结构和预训练参数
    #参数pretrained在默认情况下是False,表示只加载网络结构而不加载预训练参数来初始化
    resnet34 = models.resnet34(pretrained=True)

    #打印网络结构    
    print(resnet34) 

resnet18.load_state_dict(torch.load(path_params.pkl))#其中,path_params.pkl为预训练模型参数的保存路径。加载预先下载好的预训练参数到resnet18,用预训练模型的参数初始化resnet18的层,此时resnet18发生了改变。调用model的load_state_dict方法用预训练的模型参数来初始化自己定义的新网络结构,这个方法就是PyTorch中通用的用一个模型的参数初始化另一个模型的层的操作。load_state_dict方法还有一个重要的参数是strict,该参数默认是True,表示预训练模型的层和自己定义的网络结构层严格对应相等(比如层名和维度)。故,当新定义的网络(model_dict)和预训练网络(pretrained_dict)的层名不严格相等时,需要先将pretrained_dict里不属于model_dict的键剔除掉 :
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} ,再用预训练模型参数更新model_dict,最后用load_state_dict方法初始化自己定义的新网络结构。

print resnet18 #打印的还是网络结构

注意: cnn = resnet18.load_state_dict(torch.load( path_params.pkl )) #是错误的,这样cnn将是nonetype

pre_dict = resnet18.state_dict() #按键值对将模型参数加载到pre_dict

print for k, v in pre_dict.items(): # 打印模型参数

for k, v in pre_dict.items():

  print k  #打印模型每层命名

#model是自己定义好的新网络模型,将pretrained_dict和model_dict中命名一致的层加入pretrained_dict(包括参数)。

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

猜你喜欢

转载自www.cnblogs.com/czz0508/p/10499700.html