pytorch cuda.FloatTensor->FloatTensor

错误类型:

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor)

        定义残差块时定义在model的外面,在使用gpu进行训练的时候,残差块的参数是torch.FloatTensor类型,

虽然使用了model.cuda(),但是只对model里面的参数在gpu部分,所以把残差块对应的操作都在model的__init__(),

重新定义,即可解决问题

参考:https://www.jianshu.com/p/1fa86e060e5a

猜你喜欢

转载自blog.csdn.net/nlite827109223/article/details/84823518