前言
很久没更新了,第一个原因是学校的课程任务,第二个原因是在kaggle实战去了,我参加的是泰坦尼克那个比赛,调了快一周的代码,收获也是不小,感受最大的就是:在机器学习的任务中,非常非常重要的就是特征工程,同样的模型,一个好的特征处理工程能让你的准确率提升百分之几,在kaggle上这能让你的排名上升非常多,这是一篇kaggle经验文章,也可以看csdn上的中译版:中译版点这里,我看完之后感觉受益良多。然后,步入正题:今天笔记的主题就是Pytorch神经网络参数的保存与读取了。
环境配置:
python版本:3.8.6
torch版本:1.11.0
导入的库:
import torch
from torch import nn
一、创建一个简单神经网络
多层感知机:
net = nn.Sequential(
nn.Linear(5, 2),
nn.ReLU(),
nn.Linear(2, 1)
)
这里的net其实就是nn.Sequential这个类的实例,
二、保存神经网络参数
torch.save(net.state_dict(), 'net.params')
!!! 注意:这里只是保存了神经网络的参数,而没有保存神经网络的结构,也就是说,如果我们后面要读取参数,必须要创建一个与之前相同结构的神经网络。
三、克隆之前的网络
net_clone = nn.Sequential(
nn.Linear(5, 2),
nn.ReLU(),
nn.Linear(2, 1)
)
四、克隆的网络加载本体网络保存的参数
net_clone.load_state_dict(torch.load('net.params'))
五、验证两个网络参数是否一致
print('net:', net.state_dict())
print('net_clone:', net_clone.state_dict())
输出结果如下:
参数一致,保存成功!
或许会有细心的小伙伴发现,为什么只有"0"层和 "2"层的参数,"1"层的呢?其实这里"1"层是ReLU激活层,没有参数,所以没有显示出来了。
附:总代码
我真的没有在凑字数
import torch
from torch import nn
net = nn.Sequential(
nn.Linear(5, 2),
nn.ReLU(),
nn.Linear(2, 1)
)
torch.save(net.state_dict(), 'net.params')
net_clone = nn.Sequential(
nn.Linear(5, 2),
nn.ReLU(),
nn.Linear(2, 1)
)
net_clone.load_state_dict(torch.load('net.params'))
print('net:', net.state_dict())
print('net_clone:', net_clone.state_dict())