PyTorch parallelism: the nn.DataParallel method

```
import numpy as np
import torch
import torch.nn as nn

# 1. Print the current version information
print(torch.__version__)
print(torch.version.cuda)
print(torch.backends.cudnn.version())
print(torch.cuda.get_device_name(0))

# Fix all random seeds so runs are reproducible
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# 2. Set up the device and create the model
model = UNetSeeInDark()      # model class defined elsewhere in the post
model._initialize_weights()

# Replicate the model across the listed GPUs for data-parallel forward passes
gpus = [0, 1, 2, 3]
model = nn.DataParallel(model, device_ids=gpus)

# The wrapped model's parameters must live on the first device in device_ids
device = torch.device('cuda:0')
model = model.cuda(device=gpus[0])
```
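
With the model wrapped, a training step looks the same as in the single-GPU case. The sketch below continues the snippet above; the loss, optimizer, and tensor shapes are illustrative assumptions, not from the original post.

```
# Minimal sketch of one training step with the DataParallel-wrapped model.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Inputs only need to be placed on the first device: nn.DataParallel splits
# the batch along dim 0, runs a replica on each GPU in device_ids, and
# gathers the outputs back on gpus[0].
inputs = torch.randn(8, 4, 512, 512).cuda(device=gpus[0])     # assumed shape
targets = torch.randn(8, 3, 1024, 1024).cuda(device=gpus[0])  # assumed shape

optimizer.zero_grad()
outputs = model(inputs)             # scattered across GPUs, gathered on gpus[0]
loss = criterion(outputs, targets)
loss.backward()                     # per-GPU gradients are summed on gpus[0]
optimizer.step()
```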

If you do not want to use parallelism, simply comment out the line `model = nn.DataParallel(model, device_ids=gpus)`.

Reposted from blog.csdn.net/tywwwww/article/details/131654845