一文梳理 PyTorch 保存和重载模型参数攻略
查看当前模型结构与参数值
# Without the call parentheses this prints the bound method itself; its repr
# embeds the module, so the defined network structure is shown.
print(model.state_dict)
# Calling state_dict() returns an OrderedDict; this prints every parameter
# name together with its value.
print(model.state_dict())
输出如下:
<bound method Module.state_dict of Digit(
(conv1): Conv1d(2, 10, kernel_size=(5,), stride=(1,))
(conv3): Conv1d(5, 20, kernel_size=(3,), stride=(1,))
(fc6): Linear(in_features=2480, out_features=500, bias=True)
(drop8): Dropout(p=0.5, inplace=False)
(fc9): Linear(in_features=500, out_features=1, bias=True)
)>
OrderedDict([('conv1.weight', tensor([[[-0.2759, 0.1526, 0.2299, -0.2617, -0.0128],
[ 0.2975, -0.1635, -0.1661, 0.1830, 0.1413]],
[[ 0.0064, -0.1616, -0.2967, -0.3151, 0.0642],
[-0.0369, 0.0338, 0.2795, 0.0888, -0.2408]],
[[ 0.2387, -0.1673, -0.2089, 0.2312, -0.2677],
[ 0.1646, -0.0508, -0.0151, 0.3200, -0.0355]],
[[-0.2255, 0.0793, -0.2272, -0.0198, -0.2901],
[-0.2260, 0.0601, -0.0991, 0.0732, -0.0444]],
保存模型
# Save only the parameters (the state_dict) to ./net.pth. The path is relative
# to the current working directory (the author's desktop, per the original note).
torch.save(obj = model.state_dict(), f = "./net.pth")
对新的网络加载参数值
首先定义一个新的空白参数值网络
class Model(nn.Module):
    """Minimal network consisting of a single 1-in / 1-out linear layer.

    The layer's weight and bias are explicitly overwritten with zeros, so a
    freshly constructed instance has a known "blank" state_dict. This makes
    it easy to see whether loading parameters later actually changed them.
    """

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(1, 1)
        # nn.Linear initializes its parameters randomly by default; replace
        # them with explicit zeros of the shapes the layer expects.
        self.layer.weight = nn.Parameter(torch.zeros(1, 1))
        self.layer.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        return self.layer(x)
# The network defined above has only a single linear layer.
modeldemo = Model()
print(modeldemo.state_dict())
print(modeldemo.state_dict)
# Weight and bias are both 0 here because Model.__init__ explicitly sets them
# to zero -- not merely because the model has not been trained.
OrderedDict([('layer.weight', tensor([[0.]])), ('layer.bias', tensor([0.]))])
<bound method Module.state_dict of Model(
(layer): Linear(in_features=1, out_features=1, bias=True)
)>
可以看出,由于在 __init__ 中把权重和偏置显式设置为 0,此时 w、b 都为 0(nn.Linear 默认是随机初始化的,并非“未训练就为 0”)
print(model.state_dict())
# modeldemo is meant to load the weights of the other network, `model`; the two
# networks must have the same structure for that to work.
# NOTE(review): the loading call itself does not appear in this snippet --
# presumably modeldemo.load_state_dict(torch.load("./net.pth")); confirm.
保存该模型的其他数值
创建一个字典,然后保存这个字典,字典的字段是需要的数值即可
# Collect the model parameters, the optimizer state and the epoch number in
# one dict, then save that dict as a single file.
net = Digit()
Adam = optim.Adam(params=net.parameters(), lr=0.001, betas=(0.5, 0.999))
Epo = 97

all_states = {
    "net": net.state_dict(),    # model parameters
    "Adam": Adam.state_dict(),  # optimizer state
    "epoch": Epo,               # epoch number
}
torch.save(obj=all_states, f="./all_states.pth")
查看已保存的内容(例如用 torch.load("./all_states.pth") 将字典读回后打印):
{
'net': OrderedDict([('conv1.weight', tensor([[[-0.0992, 0.1028, 0.1915, 0.2423, -0.3130],
[ 0.0308, -0.0206, 0.0133, -0.2522, -0.2496]],
[[-0.2329, -0.1573, 0.3153, 0.1176, 0.0190],
[-0.2168, 0.1106, 0.1726, 0.0559, 0.2262]],
[[ 0.3109, -0.3043, -0.2859, -0.1401, -0.0489],
[-0.0905, -0.0871, -0.0425, -0.1573, -0.2254]],
[[-0.1303, -0.0006, 0.2278, -0.0243, 0.2638],
[-0.0177, -0.0474, -0.1561, 0.2652, -0.3036]],
。。。
-2.0413e-02, 2.7196e-02, 3.4429e-02, -3.7110e-02, 1.7414e-02,
-7.8588e-03, -4.4491e-02, -4.4779e-03, -4.4562e-02, -4.7492e-03,
-2.2696e-03, -1.9462e-02, -2.8391e-02, 2.0047e-02, -3.8300e-02,
7.0216e-03, -2.8285e-02, -1.8722e-03, 2.6953e-03, -6.4457e-03,
1.9489e-03, 2.9594e-02, 5.5762e-03, -1.9028e-02, 3.8116e-02]])),
('fc9.bias', tensor([0.0003]))]),
'Adam': {
'state': {
}, 'param_groups': [{
'lr': 0.001, 'betas': (0.5, 0.999),
'eps': 1e-08,
'weight_decay': 0,
'amsgrad': False,
'params': [0, 1, 2, 3, 4, 5, 6, 7]}]},
'epoch': 97}
总结:参数的存储需要“名称—值”一一对应(即键值对),读取时同样通过键来取出对应的值,因此理解字典(state_dict)的概念很重要