最近在做图像分类实验时,在4个gpu上使用pytorch的DataParallel 函数并行跑程序,批次为16时会报如下所示的错误:
RuntimeError: CUDA out of memory. Tried to allocate 858.00 MiB (GPU 3; 10.92 GiB total capacity; 10.10 GiB already allocated; 150.69 MiB free; 10.13 GiB reserved in total by PyTorch)
实验发现,每块gpu最多可以跑2条数据,但是我又想设置batch_size=16,参考https://zhuanlan.zhihu.com/p/86441879了解到transformer-XL官方写的BalancedDataParallel 函数,用来解决DataParallel 显存使用不平衡的问题(参考代码见最后)。
为了理解BalancedDataParallel 函数用法,我们先来弄清楚几个问题。
1,DataParallel 函数是如何工作的?
首先将模型加载到主 GPU 上,然后再将模型复制到各个指定的从 GPU 中,然后将输入数据按 batch 维度进行划分,具体来说就是每个 GPU 分配到的数据 batch 数量是总输入数据的 batch 除以指定 GPU 个数。每个 GPU 将针对各自的输入数据独立进行 forward计算,之后会把计算结果传到主GPU 上完成梯度计算和参数更新,最后将更新后的参数复制到从 GPU 中,这样就完成了一次迭代计算。参考https://blog.csdn.net/zhjm07054115/article/details/104799661当gpu=2,batch_size=30时,我们可以从下图清楚的看到首先会在两个gpu上分别分配15条数据,进行forward计算,之后汇总结果再进行梯度计算和参数更新。
我们可以看到反向传播计算和参数更新完全放在主gpu上进行的,这样会造成显存使用不平衡的问题。
2,梯度累加
参考https://blog.csdn.net/wuzhongqiang/article/details/102572324做的反向传播梯度累加实验,发现pytorch在反向传播的时候,默认累加上了上一次求的梯度, 如果不想让上一次的梯度影响自己本次梯度计算的话,需要手动的清零。
了解了DataParallel 函数和梯度累加后,我们就可以来解决显存使用不平衡问题以及如何在显存固定的情况下加大训练批次。
首先,简单介绍BalancedDataParallel 用法【下图截取自https://github.com/Link-Li/Balanced-DataParallel】
简单解释一下:当我们需要在3个gpu并行跑程序,每个gpu最多一次可以处理3条数据,分配是[3,3,3],那么3个gpu最多可以同时处理9条数据,也就是batch_size最大可设为9,因为主gpu上还要进行反向传播,所以这里我们设置主gpu处理2条数据,分布就是[2,3,3],batch_size=8。
此时如果我们想加大批次,使得batch_size=16,那么分布应该是[4,6,6],但是我们知道每个gpu最多可以处理3条数据,这里就用到梯度累加的方法了,即上图中的acc_grad,acc_grad参数表示将batch_size分成多少份送入网络,当acc_grad=2,表示我们会先将16个数据分成2份,每份有8条数据,每次输入8条数据分给3个gpu做并行训练,forward计算结果放入主gpu上进行反向传播,由于梯度可以累加,循环两次后,再更新参数。这样做不仅可以缓解显存不平衡问题也可以解决显存不足的问题。
下面是我根据https://blog.csdn.net/zhjm07054115/article/details/104799661做了修改,加上BalancedDataParallel 完整代码:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from data_parallel_balance import BalancedDataParallel
# Dataset
class RandomDataset(Dataset):
def __init__(self, size, length):
self.len = length
self.data = torch.randn(length, size)
self.target=np.random.randint(3,size=length)
def __getitem__(self, index):
label=torch.tensor(self.target[index])
return self.data[index],label
def __len__(self):
return self.len
# model
class Model(nn.Module):
def __init__(self, input_size, output_size):
super(Model, self).__init__()
self.fc = nn.Linear(input_size, output_size)
def forward(self, input):
output = self.fc(input)
print("\tIn Model: input size", input.size(),
"output size", output.size())
return output
# trian
def train(rand_loader,model,optimizer,criterion):
train_loss=0
# train
model.train()
optimizer.zero_grad()
for image,target in rand_loader:
print('image:',image.shape)
if batch_chunk > 0:
image_chunks = torch.chunk(image, batch_chunk, 0)
target_chunks = torch.chunk(target, batch_chunk, 0)
for i in range(len(image_chunks)):
print('image_chunks:',i)
img=image_chunks[i].to(device)
lab=target_chunks[i].to(device)
out=model(img)
print("Chunks_Outputs: input size", img.size(),
"output_size", out.size())
loss=criterion(out,lab)
# print('{} chunk,loss:{}.'.format(i,loss))
train_loss+=loss.item()
loss = loss.float().mean().type_as(loss) / len(image_chunks)
loss.backward()
else:
image = image.to(device)
target=target.to(device)
output = model(image)
loss=criterion(output,target)
train_loss=loss.item()
print("Outside: input size", image.size(),
"output_size", output.size())
optimizer.step()
return train_loss
if __name__=="__main__":
input_size = 5
output_size = 3
batch_size = 32
data_size = 70
batch_chunk=2
gpu0_bsz=8
epochs=2
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# data
rand_loader = DataLoader(dataset=RandomDataset(input_size, data_size),
batch_size=batch_size, shuffle=True)
# model
model = Model(input_size, output_size)
if torch.cuda.device_count() > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
if gpu0_bsz >= 0:
model = BalancedDataParallel(gpu0_bsz // batch_chunk, model, dim=0)
else:
model = nn.DataParallel(model)
model.to(device)
# optimizer
optimizer= torch.optim.SGD(model.parameters(), lr = 0.01, momentum = 0.9)
# loss
criterion=nn.CrossEntropyLoss()
for epoch in range(epochs):
print('Epoch:',epoch)
train(rand_loader,model,optimizer,criterion)
参考:
pytorch多gpu并行训练
transformer-XL的官方代码
BalancedDataParallel 参考代码
PyTorch-4 nn.DataParallel 数据并行详解
Pytorch反向传播中的细节-计算梯度时的默认累加
欢迎大家留言批评指正!