参考博客
https://blog.csdn.net/learning_tortosie/article/details/85243310
在本教程中,您将了解Keras .fit
和.fit_generator
函数的工作原理,包括它们之间的差异。为了帮助您获得实践经验,我已经提供了一个完整的示例,向您展示如何从头开始实现Keras数据生成器。
Keras深度学习库包括三个独立的函数,可用于训练您自己的模型:
.fit
.fit_generator
.train_on_batch
-
这三个函数基本上可以完成相同的任务,但他们如何去做这件事是非常不同的。
让我们逐个探索这些函数,查看函数调用的示例,然后讨论它们彼此之间的差异。
调用.fit
:
model.fit(trainX, trainY, batch_size=32, epochs=50)
在这里可以看到提供的训练数据(trainX
)和训练标签(trainY
)。然后,我们指示Keras允许我们的模型训练50
个epoch,同时batch size为32
。
对.fit
的调用在这里做出两个主要假设:
- 我们的整个训练集可以放入RAM
- 没有数据增强(即不需要Keras生成器)
我们的网络将在原始数据上训练。原始数据本身适合内存,我们无需将旧批量数据从RAM中移出并将新批量数据移入RAM。此外,我们不会使用数据增强动态操纵训练数据。
对于小型,简单化的数据集,使用Keras的.fit
函数是完全可以接受的。
这些数据集通常不是很具有挑战性,不需要任何数据增强。
但是,真实世界的数据集很少这么简单:
- 真实世界的数据集通常太大而无法放入内存中
- 它们也往往具有挑战性,要求我们执行数据增强以避免过拟合并增加我们的模型的泛化能力
在这些情况下,我们需要利用Keras的.fit_generator
函数,函数原型为,
fit_generator(self, generator,
steps_per_epoch=None,
epochs=1,
verbose=1,
callbacks=None,
validation_data=None,
validation_steps=None,
class_weight=None,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
shuffle=True,
initial_epoch=0)
优点:通过Python generator产生一批批的数据用于训练模型。generator可以和模型并行运行,例如,可以使用CPU生成批数据同时在GPU上训练模型。
参数:
- generator:一个generator或Sequence实例,为了避免在使用multiprocessing时直接复制数据。
- steps_per_epoch:从generator产生的步骤的总数(样本批次总数)。通常情况下,应该等于数据集的样本数量除以批量的大小。
- epochs:整数,在数据集上迭代的总数。
- works:在使用基于进程的线程时,最多需要启动的进程数量。
- use_multiprocessing:布尔值。当为True时,使用基于基于过程的线程。
# initialize the number of epochs and batch size
EPOCHS = 100
BS = 32
# construct the training image generator for data augmentation
aug = ImageDataGenerator(rotation_range=20, zoom_range=0.15,
width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,
horizontal_flip=True, fill_mode="nearest")
# train the network
H = model.fit_generator(aug.flow(trainX, trainY, batch_size=BS),
validation_data=(testX, testY), steps_per_epoch=len(trainX) // BS,
epochs=EPOCHS)
我们首先初始化将要训练的网络的epoch和batch size。
然后我们初始化aug
,这是一个Keras ImageDataGenerator
对象,用于图像的数据增强,随机平移,旋转,调整大小等。
执行数据增强是正则化的一种形式,使我们的模型能够更好的被泛化。
但是,应用数据增强意味着我们的训练数据不再是“静态的” ——数据不断变化。
根据提供给ImageDataGenerator
的参数随机调整每批新数据。
因此,我们现在需要利用Keras的.fit_generator
函数来训练我们的模型。
该函数本身是一个Python生成器。
Keras在使用.fit_generator
训练模型时的过程:
- Keras调用提供给
.fit_generator
的生成器函数(在本例中为aug.flow
) - 生成器函数为
.fit_generator
函数生成一批大小为BS
的数据 .fit_generator
函数接受批量数据,执行反向传播,并更新模型中的权重- 重复该过程直到达到期望的epoch数量
-
您会注意到我们现在需要在调用
.fit_generator
时提供steps_per_epoch
参数(.fit
方法没有这样的参数)。为什么我们需要
steps_per_epoch
?请记住,Keras数据生成器意味着无限循环,它永远不会返回或退出。
由于该函数旨在无限循环,因此Keras无法确定一个epoch何时开始的,并且新的epoch何时开始。
因此,我们将训练数据的总数除以批量大小的结果作为
steps_per_epoch
的值。一旦Keras到达这一步,它就会知道这是一个新的epoch。 -
图像数据集作为CSV文件?
将在这里使用的数据集是Flowers-17数据集,它是17种不同花种的集合,每个类别有80个图像。我们的目标是培训Keras卷积神经网络,以正确分类每种花卉。
但是,这个项目有点不同:
- 不是使用存储在磁盘上的原始图像文件
- 而是将整个图像数据集序列化为两个CSV文件(一个用于训练,一个用于评估)
-
要构建每个CSV文件,我:
- 循环输入数据集中的所有图像
-
我们的目标是现在编写一个自定义Keras生成器来解析CSV文件,并为
.fit_generator
函数生成批量图像和标签。 - 将它们调整为 64×64 像素
- 将 64x64x3 = 12,288 个RGB像素的强度展平为单个列表
- 在CSV文件中写入12,288个像素值和类标签(每行一个)