PyTorch：一次性将任意数量的张量传入 CUDA

import torch

def cuda(*pargs):
    """Move any number of objects to the default CUDA device in one call.

    Args:
        *pargs: Objects exposing a ``.cuda()`` method, typically
            ``torch.Tensor`` instances.

    Returns:
        A tuple with one entry per argument, in the same order. When CUDA is
        available each entry is ``arg.cuda()``. Otherwise the arguments are
        returned unchanged, so callers can always unpack the result
        (e.g. ``a, b = cuda(x, y)``).
    """
    if not torch.cuda.is_available():
        # The original implementation fell through here and returned None,
        # which made `a, b = cuda(x, y)` raise TypeError on CPU-only hosts.
        return pargs
    # Build a tuple rather than a lazy generator so the result can be
    # indexed, measured with len(), and iterated more than once.
    return tuple(data.cuda() for data in pargs)

if __name__ == '__main__':
    # Demo: two random tensors with different spatial sizes, moved together.
    inputs = (
        torch.randn(1, 3, 256, 256),
        torch.randn(1, 3, 512, 512),
    )
    for tensor in cuda(*inputs):
        # Bring each result back to host memory and report its shape.
        print(tensor.cpu().detach().numpy().shape)

猜你喜欢

转载自blog.csdn.net/luolinll1212/article/details/82762674