使用GPU训练模型,遇到显存不足的情况:开始报chunk xxx size 64000的错误。使用tensorflow框架来训练的。
仔细分析原因有两个:
- 数据集padding依据的是整个训练数据集的max_seq_length,这样在一个批内的数据会造成额外的padding,占用显存;
- 在训练时把整个训练数据先全部加载,造成显存占用多。
如果遇到第一种情况,即使使用CPU训练速度也非常慢。
对于第二种情况,要使用generator来解决。不要加载全部数据,要分批加载,根据一个批内的最大length来填充,同时也要限制最大length的长度。丢弃部分很长的数据。
而且,如果使用bert时,会对seq_length有限制。
tensorflow 1.12限制只使用CPU:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'