PyTorch: Saving and Loading Models

A walkthrough of saving and reloading model parameters in PyTorch.
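The examples below reference a Digit network that the post itself never defines. For completeness, here is a minimal sketch with the same layers as the structure printed below; the forward pass (in particular the reshape between conv1 and conv3, and the assumed input length of 67) is a guess chosen only so that the shapes line up.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Digit(nn.Module):
    def __init__(self):
        super(Digit, self).__init__()
        self.conv1 = nn.Conv1d(2, 10, kernel_size=5)
        self.conv3 = nn.Conv1d(5, 20, kernel_size=3)
        self.fc6 = nn.Linear(2480, 500)
        self.drop8 = nn.Dropout(p=0.5)
        self.fc9 = nn.Linear(500, 1)

    def forward(self, x):
        # assumed forward pass: for input of shape (N, 2, 67), conv1 gives
        # (N, 10, 63); reshaping to (N, 5, 126) matches conv3's 5 input
        # channels; conv3 gives (N, 20, 124), and 20 * 124 = 2480 = fc6's
        # in_features
        x = F.relu(self.conv1(x))
        x = x.view(x.size(0), 5, -1)
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc6(x))
        x = self.drop8(x)
        return self.fc9(x)

model = Digit()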

Inspecting the current model's structure and parameter values

print(model.state_dict)
# prints the network structure as defined
print(model.state_dict())
# prints all parameter names and their values

The output is as follows:

<bound method Module.state_dict of Digit(
  (conv1): Conv1d(2, 10, kernel_size=(5,), stride=(1,))
  (conv3): Conv1d(5, 20, kernel_size=(3,), stride=(1,))
  (fc6): Linear(in_features=2480, out_features=500, bias=True)
  (drop8): Dropout(p=0.5, inplace=False)
  (fc9): Linear(in_features=500, out_features=1, bias=True)
)>
OrderedDict([('conv1.weight', tensor([[[-0.2759,  0.1526,  0.2299, -0.2617, -0.0128],
         [ 0.2975, -0.1635, -0.1661,  0.1830,  0.1413]],

        [[ 0.0064, -0.1616, -0.2967, -0.3151,  0.0642],
         [-0.0369,  0.0338,  0.2795,  0.0888, -0.2408]],

        [[ 0.2387, -0.1673, -0.2089,  0.2312, -0.2677],
         [ 0.1646, -0.0508, -0.0151,  0.3200, -0.0355]],

        [[-0.2255,  0.0793, -0.2272, -0.0198, -0.2901],
         [-0.2260,  0.0601, -0.0991,  0.0732, -0.0444]],

...

Saving the model

torch.save(obj=model.state_dict(), f="./net.pth")
# save path; "./" here is the parent directory (the desktop in this example)
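As a quick sanity check (not in the original post), you can load the file back and inspect its keys, which mirror the layer names printed earlier:

loaded = torch.load("./net.pth")
print(loaded.keys())
# odict_keys(['conv1.weight', 'conv1.bias', 'conv3.weight', 'conv3.bias',
#             'fc6.weight', 'fc6.bias', 'fc9.weight', 'fc9.bias'])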

Loading the saved parameter values into a new network

First, define a new network whose parameter values are blank (explicitly zeroed):

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.layer = nn.Linear(1, 1)
        self.layer.weight = nn.Parameter(torch.FloatTensor([[0]]))
        self.layer.bias = nn.Parameter(torch.FloatTensor([0]))

    def forward(self, x):
        out = self.layer(x)
        return out
# this network has only a single linear layer
modeldemo = Model()
print(modeldemo.state_dict())
print(modeldemo.state_dict)
# the weight and bias are both 0 because __init__ explicitly sets them to 0
# (an untrained layer would normally have random initial values)
OrderedDict([('layer.weight', tensor([[0.]])), ('layer.bias', tensor([0.]))])
<bound method Module.state_dict of Model(
  (layer): Linear(in_features=1, out_features=1, bias=True)
)>

As expected, w and b are both 0, since they were initialized to 0 above.

modeldemo.load_state_dict(torch.load("./net.pth"))
# modeldemo loads the weights saved earlier from the other network `model`;
# note that the two networks must have exactly the same structure for this
# to work (the toy Model above does not actually match Digit, so a working
# round trip between matching networks is shown below)
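Since the toy Model and Digit do not match, here is a self-contained round trip between two instances of the same Model class that actually runs (the file name model_demo.pth and the pretend-trained weight 2.0 are illustrative, not from the original post):

trained = Model()
trained.layer.weight = nn.Parameter(torch.FloatTensor([[2.0]]))  # pretend it was trained
torch.save(trained.state_dict(), "./model_demo.pth")

modeldemo = Model()                                    # fresh, zero-initialized copy
modeldemo.load_state_dict(torch.load("./model_demo.pth"))
print(modeldemo.state_dict())
# OrderedDict([('layer.weight', tensor([[2.]])), ('layer.bias', tensor([0.]))])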

Saving other values along with the model

Create a dictionary whose fields are the values you need to keep, then save that dictionary:

import torch.optim as optim

net = Digit()
Adam = optim.Adam(params=net.parameters(), lr=0.001, betas=(0.5, 0.999))
Epo = 97
all_states = {"net": net.state_dict(), "Adam": Adam.state_dict(), "epoch": Epo}
torch.save(obj=all_states, f="./all_states.pth")

Inspecting the saved contents, e.g. with print(torch.load("./all_states.pth")):

{'net': OrderedDict([('conv1.weight', tensor([[[-0.0992,  0.1028,  0.1915,  0.2423, -0.3130],
         [ 0.0308, -0.0206,  0.0133, -0.2522, -0.2496]],

        [[-0.2329, -0.1573,  0.3153,  0.1176,  0.0190],
         [-0.2168,  0.1106,  0.1726,  0.0559,  0.2262]],

        [[ 0.3109, -0.3043, -0.2859, -0.1401, -0.0489],
         [-0.0905, -0.0871, -0.0425, -0.1573, -0.2254]],

        [[-0.1303, -0.0006,  0.2278, -0.0243,  0.2638],
         [-0.0177, -0.0474, -0.1561,  0.2652, -0.3036]],

...

         -2.0413e-02,  2.7196e-02,  3.4429e-02, -3.7110e-02,  1.7414e-02,
         -7.8588e-03, -4.4491e-02, -4.4779e-03, -4.4562e-02, -4.7492e-03,
         -2.2696e-03, -1.9462e-02, -2.8391e-02,  2.0047e-02, -3.8300e-02,
          7.0216e-03, -2.8285e-02, -1.8722e-03,  2.6953e-03, -6.4457e-03,
          1.9489e-03,  2.9594e-02,  5.5762e-03, -1.9028e-02,  3.8116e-02]])),
 ('fc9.bias', tensor([0.0003]))]),
 'Adam': {'state': {},
          'param_groups': [{'lr': 0.001,
                            'betas': (0.5, 0.999),
                            'eps': 1e-08,
                            'weight_decay': 0,
                            'amsgrad': False,
                            'params': [0, 1, 2, 3, 4, 5, 6, 7]}]},
 'epoch': 97}

Summary: a parameter object is stored as "name"-"value" pairs (i.e., key-value pairs), and it is read back through those same keys, so the dictionary concept is central here.
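As a concrete illustration of reading values back by key, here is a minimal restore sketch (assuming the Digit class and the ./all_states.pth file from above):

checkpoint = torch.load("./all_states.pth")

net = Digit()
net.load_state_dict(checkpoint["net"])        # model weights, via key "net"

Adam = optim.Adam(params=net.parameters(), lr=0.001, betas=(0.5, 0.999))
Adam.load_state_dict(checkpoint["Adam"])      # optimizer state, via key "Adam"

start_epoch = checkpoint["epoch"]             # 97, via key "epoch"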

Original reference: https://zhuanlan.zhihu.com/p/94971100


Reposted from blog.csdn.net/HJ33_/article/details/120486391