注意事项
1.fit_generator中用的都是整型数字
2.构建checkpoing
3.模型保存的时候注意,如果有自定义层,容易出错
代码
#定义模型检查点
checkpoint = keras.callbacks.ModelCheckpoint(self.save_path, monitor='val_metric_precision', verbose=1,
save_best_only=True, mode='max')
callbacks_list = [checkpoint]
#模型保存
model.fit_generator(G_train,steps_per_epoch=int(self.total_number/self.batch_size),validation_data=G_eval,#不设置steps_per_epoch=
validation_steps=40,epochs=self.epochs,callbacks=callbacks_list)
model.save_weights(self.save_path)#保存模型权重 保存模型是model.save 注意后缀是h5