在模型训练完后再进行测试加载模型后出现bug,显示如下错误
据了解是由于pytorch版本导致的错误,可能与自己训练阶段保持的模型方式有关,训练阶段保存方式如下:
解决方案如下:
方法一:
generator.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(generator_1_10.pth).items()})
实际上就是将load进行的权重的有序字典里面的键值前面的的7个字符’module.'去掉。加载进行的权重有序字典如下图所示:
键就是每层的权重或者 bias 的名称,value就是其具体的张量值。
方法二:重新新建个有序字典:
from collections import OrderedDict
# new_state_dict = OrderedDict()
# for k, v in a.items():
# name=k[7:] # reduce `module.`
# new_state_dict[name] = v
# # load params
# # model.load_state_dict(new_state_dict)
# model.load_state_dict(new_state_dict)
显然方法一更简洁明了。