多GPU训练,加载模型时报错

一、问题描述

 多个GPU 训练,保存时没有加module , 导致加载模型时报错。正确写法应该如下:

       # save model
        if num_gpu ==  1:
            torch.save(model.module.state_dict(), os.path.join(opt.outf, 'net.pth'))
        else:
            torch.save(model.state_dict(), os.path.join(opt.outf, 'net.pth'))

二、解决方法

load 模型时,删除多余的module.  可以打印下面代码中的pth , 查看问题症结

具体代码如下:

        print("load pre_training weight. ")
        pth = torch.load(os.path.join(opt.outf, 'net.pth'))
        new_state_dict = OrderedDict()
        for k, v in pth.items():
            name =  k[7:] # remove  'module'
            new_state_dict[name]=v
        net.load_state_dict(new_state_dict)

参考链接:http://www.likecs.com/show-99262.html

猜你喜欢

转载自blog.csdn.net/ljh618625/article/details/115389817