转自莫烦大神,转载原因是想把所有相关内容收集到自己的博客中,方便系统的学习。
两种保存方法,1是保存整个神经网络;2是只保存神经网络的所有参数。
一、保存神经网络
1保存整个神经网络。
torch.save(net1,"net1.pkl")
net1为我想要保存的网络,net1.pkl为文件名,保存的格式只能是.pkl
2,保存神经网络参数
torch.save(net1.state_dict(),"net1_parmaer.pkl")
二、恢复神经网络
1恢复完整神经网络(直接load())
net2=torch.load("net1.pkl")
2.从参数中恢复神经网络
需先构建与所要恢复的神经网络相同结构,再load参数。
3,完整程序如下
import torch import matplotlib.pyplot as plt x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1) y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1) def save(): net1 = torch.nn.Sequential( torch.nn.Linear(1, 10), # 一层神经层 torch.nn.ReLU(), # 加激励函数,relu相当于类 torch.nn.Linear(10, 1), ) optimizer=torch.optim.SGD(net1.parameters(),lr=0.5) loss_func=torch.nn.MSELoss() for t in range(100): prediction=net1(x) loss=loss_func(prediction,y) optimizer.zero_grad() loss.backward() optimizer.step() #画图 plt.figure(1, figsize=(10, 3)) plt.subplot(131) plt.title('Net1') plt.scatter(x.data.numpy(), y.data.numpy()) #实际数据 plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) #回归曲线 torch.save(net1,"net1.pkl") #保存整个神经网络 torch.save(net1.state_dict(),"net1_parmaer.pkl") #保存神经网络中的所有参数 def restore_net(): net2=torch.load("net1.pkl") prediction2=net2(x) plt.subplot(132) plt.title('Net2') plt.scatter(x.data.numpy(), y.data.numpy()) #实际数据 plt.plot(x.data.numpy(), prediction2.data.numpy(), 'r-', lw=5) #回归曲线 def restore_paramers(): net3=torch.nn.Sequential( torch.nn.Linear(1, 10), # 一层神经层 torch.nn.ReLU(), # 加激励函数,relu相当于类 torch.nn.Linear(10, 1), ) net3.load_state_dict(torch.load("net1_parmaer.pkl")) #先构建网络在,再加载参数 prediction3 = net3(x) plt.subplot(133) plt.title('Net3') plt.scatter(x.data.numpy(), y.data.numpy()) # 实际数据 plt.plot(x.data.numpy(), prediction3.data.numpy(), 'r-', lw=5) # 回归曲线 plt.show() save() restore_net()
运行结果: