版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_16949707/article/details/72571509
1 DataParallel
from torch.nn import DataParallel
net = DataParallel(net)
可以实现模块级别(?好处具体是啥不大懂)的并行计算,可以将一个模块forward部分分到各个gpu去计算,然后backwards时,合并gradients 到original module。
>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
>>> output = net(input_var)
2 DataLoader
其实这里trainset已经包含数据集了,dataloader只是定义输入网络的一些参数,入batch_size等等。
3 Transform
对数据集进行的操作
compose函数会将多个transforms包在一起。