这种模式可以让梯度玩出更多花样,比如说梯度累加(gradient accumulation
)
传统的训练函数,一个batch
是这么训练的:
for i, (image, label) in enumerate(train_loader):
# 1. input output
pred = model(image)
loss = criterion(pred, label)
# 2. backward
optimizer.zero_grad() # reset gradient
loss.backward()
optimizer.step()
- 获取
loss
:输入图像和标签,通过infer
计算得到预测值,计算损失函数; optimizer.zero_grad()
清空过往梯度;loss.backward()
反向传播,计算当前梯度;optimizer.step()
根据梯度更新网络参数;
简单的说就是进来一个 batch
的数据,计算一次梯度,更新一次网络.
梯度累加
使用梯度累加是这么写的:
for i,(image, label) in enumerate(train_loader):
# 1. input output
pred = model(image)
loss = criterion(pred, label)
# 2.1 loss regularization
loss = loss / accumulation_steps
# 2.2 back propagation
loss.backward()
# 3. update parameters of net
if (i+1) % accumulation_steps == 0:
# optimizer the net
optimizer.step() # update parameters of net
optimizer.zero_grad() # reset gradient
【步骤1】:获取 loss
:输入图像和标签,通过infer
计算得到预测值,计算损失函数;
【步骤2】:loss.backward()
反向传播,计算当前梯度;
【步骤3】:多次循环步骤 1-2
,不清空梯度,使梯度累加在已有梯度上;
【步骤4】:梯度累加了一定次数后,先optimizer.step()
根据累计的梯度更新网络参数,然后optimizer.zero_grad()
清空过往梯度,为下一波梯度累加做准备;
总结来说:梯度累加就是,每次获取1
个batch的数据,计算1次梯度,梯度不清空,不断累加,累加一定次数后,根据累加的梯度更新网络参数,然后清空梯度,进行下一次循环。一定条件下,batchsize
越大训练效果越好,梯度累加则实现了batchsize
的变相扩大,如果accumulation_steps
为 8
,则batchsize
变相 扩大了8
倍,是我们这种乞丐实验室解决显存受限的一个不错的trick
,使用时需要注意,学习率也要适当放大。
问题
更新1:
关于BN
是否有影响,之前有人是这么说的:
As far as I know, batch norm statistics get updated on each forward pass, so no problem if you don’t do .backward() every time.
BN的估算是在forward阶段就已经完成的,并不冲突,只是accumulation_steps=8和真实的batchsize放大八倍相比,效果自然是差一些,毕竟八倍Batchsize的BN估算出来的均值和方差肯定更精准一些。
更新2:
根据 @李韶华 的分享,可以适当调低BN自己的momentum参数
bn自己有个momentum参数:x_new_running = (1 - momentum) * x_running + momentum * x_new_observed. momentum越接近0,老的running stats记得越久,所以可以得到更长序列的统计信息