1、序列化与反序列化
2、 PyTorch中的序列化(save)与反序列化(load)
- torch.save:主要参数为obj:对象;f:输出路径
- torch.load:主要参数为f:文件路径;map_location:指定存放位置, cpu or gpu
3、保存补充
- 保存整个Module:torch.save(net, path)
- 保存模型参数:state_dict = net.state_dict();torch.save(state_dict , path)
4、断点续训练
checkpoint = {"model_state_dict": net.state_dict(),"optimizer_state_dict":optimizer.state_dict(),"epoch": epoch}