如果我们直接用keras的fit函数来训练模型的话,是需要传入全部训练数据,但是好在提供了fit_generator
,可以分批次的读取数据,节省了我们的内存,我们唯一要做的就是实现一个生成器(generator)。
直接给出关键代码,数据处理和组网部分略过
处理后的数据格式如下:
- 第一个列表是特征,[39, 500000]
- 第二个列表是label
fit_generator
batch_size = 2048
model.fit_generator(
GeneratorRandomPatchs(train_x, train_y, batch_size),
validation_data=(val_x, val_y),
steps_per_epoch=len(train_data) // batch_size,
epochs=100,
verbose=1
)
def GeneratorRandomPatchs(train_x, train_y, batch_size):
totl, col = np.array(train_x).shape # (39, 500000) 特征数、样本数
# 保证 steps_per_epoch * epoch 批次的数据够
while True:
for index in range(0, col, batch_size):
xs, ys = [], []
for t in range(totl):
xs.append(train_x[t][index: index + batch_size])
ys.append(train_y[0][index: index + batch_size])
# print(np.array(xs).shape, np.array(ys).shape)
yield (xs, ys)
fit_generator
需要传递一个迭代器,如上述例子:GeneratorRandomPatchs,通过yield
返回训练数据
batch_size
:批处理大小,就是每次入模的样本数。steps_per_epoch
:每个epoch要处理的批数。比如训练数据50W,batch_size是2048,那么一个epoch的批数就是244。
其它参数解释
参考自:
https://blog.csdn.net/zhangpeterx/article/details/90900118
https://blog.csdn.net/qq_39783265/article/details/106752903