pytorch使用——(十六)模型保存与加载

1、序列化与反序列化

2、 PyTorch中的序列化(save)与反序列化(load)

  • torch.save:主要参数为obj:对象;f:输出路径
  • torch.load:主要参数为f:文件路径;map_location:指定存放位置, cpu or gpu

3、保存补充

  • 保存整个Module:torch.save(net, path)
  • 保存模型参数:state_dict = net.state_dict();torch.save(state_dict , path)

4、断点续训练

checkpoint = {"model_state_dict": net.state_dict(),"optimizer_state_dict":optimizer.state_dict(),"epoch": epoch}

猜你喜欢

转载自blog.csdn.net/weixin_37799689/article/details/106486477