import torch
def cuda(*pargs):
if torch.cuda.is_available():
return (data.cuda() for data in pargs)
if __name__ == '__main__':
x = torch.randn(1,3,256,256)
y = torch.randn(1,3,512,512)
a,b = cuda(x,y)
print(a.cpu().data.numpy().shape)
print(b.cpu().data.numpy().shape)
pytorch任意批量将数据传入cuda
猜你喜欢
转载自blog.csdn.net/luolinll1212/article/details/82762674
今日推荐
周排行