m = weight_norm(nn.Linear(20, 40), name='weight').cuda()
print(m.weight.device) # on cpu
inputs # a tensor on cuda
outputs = m(inputs)
不会报错,按常识模型和数据要放在同一个设备上才行,其实是weight_norm运算的时候用的不是m.weight,而是m.weight_g和m.weight_v.
print(m.weight_g.device) # on cuda
print(m.weight_v.device) # on cuda