目录
1、torch 的 save 和 load
我们可以直接使用 save 函数 和 load函数 进行存储和读取。
- save 使用 Python 的 pickle 实用程序将对象进行序列化,然后将序列化的对象保存到disk。save可以保存各种对象,包括模型、张量和字典等。
- load 使用 pickle unpickle 工具将 pickle 的对象文件反序列化为内存
import torch
x = [1, 2]
y = {
'name':'xiaoming', 'age':16}
z = (x, y)
torch.save(z, 'z.pt')
z_new = torch.load('z.pt')
print(z_new) # ([1, 2], {'name': 'xiaoming', 'age': 16})
2、state_dict
1)net.state_dict()
在PyTorch
中,Module
的可学习参数 (即权重和偏差),模块模型包含在参数中 (通过 model.parameters()
访问)。state_dict
是一个从参数名称隐射到参数 Tesnor
的有序字典对象。
注意,只有具有可学习参数的层(卷积层、线性层等) 才有 state_dict
中的条目。
import torch.nn as nn
class MLP(nn.Module):
def __init__(self):
super(MLP, self).__init__()
self.hidden = nn.Linear(3, 2)
self.act = nn.ReLU()
self.output = nn.Linear(2, 1)
def forward(self, x):
a = self.act(self.hidden(x))
return self.output(a)
net = MLP()
print(net.state_dict())
# OrderedDict([('hidden.weight', tensor([[-0.2360, -0.3193, -0.2618],[ 0.1759, -0.0888, 0.2635]])),
# ('hidden.bias', tensor([ 0.2161, -0.3944])),
# ('output.weight', tensor([[-0.5358, -0.2140]])),
# ('output.bias', tensor([0.6262]))])
2)optimizer.state_dict()
优化器(optim
) 也有一个 state_dict
,其中包含关于优化器状态以及所使用的超参数的信息。
import torch
import torch.nn as nn
net = nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
print(optimizer.state_dict())
# {'state': {},
# 'param_groups': [{'lr': 0.001,
# 'momentum': 0.9,
# 'dampening': 0,
# 'weight_decay': 0,
# 'nesterov': False,
# 'maximize': False,
# 'foreach': None,
# 'params': [0, 1, 2, 3]}]}
3、保存模型 和 加载模型
1)仅保存和加载模型参数(state_dict
)
import torch
import torch.nn as nn
net = nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))
# 保存模型参数
torch.save(net.state_dict(), 'model_weight.pt') # 推荐的文件后缀名是pt或pth
# 下载模型参数
model_weight = torch.load('model_weight.pt')
print(model_weight)
# OrderedDict([('0.weight', tensor([[-0.3865, -0.4623, 0.1212],[-0.2480, 0.3840, 0.1916]])),
# ('0.bias', tensor([ 0.0698, -0.0641])),
# ('2.weight', tensor([[-0.1499, -0.2895]])),
# ('2.bias', tensor([0.2585]))])
# 下载模型参数 并放到模型中
net.load_state_dict(torch.load('model_weight.pt'))
2)保存 和 加载 整个模型
import torch
import torch.nn as nn
net = nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))
# 保存整个模型
torch.save(net, 'model.pt')
# 下载模型参数 并放到模型中
net_new = torch.load('model.pt')
print(net_new)
# Sequential(
# (0): Linear(in_features=3, out_features=2, bias=True)
# (1): ReLU()
# (2): Linear(in_features=2, out_features=1, bias=True)
# )