只存储模型参数
import torch
import torch.nn as nn
# 定义一个简单的模型
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
return self.fc(x)
# 创建一个模型实例
model = MyModel()
path = 'output/model'
# 保存模型
torch.save(model.state_dict(), path+'.pth')
# 创建一个新的模型对象(与保存模型的模型结构相同)
model_1 = MyModel()
# 加载模型参数
model_1.load_state_dict(torch.load(path+'.pth'))
# 判断两个模型参数是否相同
# 判断模型参数是否相同
params_equal = True
for (a_name, a_param), (b_name, b_param) in zip(model.state_dict().items(), model_1.state_dict().items()):
if a_name != b_name or not torch.equal(a_param, b_param):
params_equal = False
break
if params_equal:
print("模型参数相同")
else:
print("模型参数不相同")
运行结果:
保存整个模型
import torch
import torch.nn as nn
# 定义一个简单的模型
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
return self.fc(x)
# 创建一个模型实例
model = MyModel()
path = 'output/model'
# 保存模型
torch.save(model, path+'.pth')
# 加载模型
model_1 = torch.load(path+'.pth')
# 判断两个模型参数是否相同
# 判断模型参数是否相同
params_equal = True
for (a_name, a_param), (b_name, b_param) in zip(model.state_dict().items(), model_1.state_dict().items()):
if a_name != b_name or not torch.equal(a_param, b_param):
params_equal = False
break
if params_equal:
print("模型参数相同")
else:
print("模型参数不相同")