Using early stopping and checkpoints in TensorFlow

Use early stopping to halt training and a checkpoint to save the model during training

When training a model with the number of epochs set too high, we often want to stop training once the loss stops decreasing or the accuracy stops improving, keep the model obtained so far, and avoid wasting training time. TensorFlow provides the EarlyStopping callback to halt training and the ModelCheckpoint callback to save the model:

import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
import pandas as pd
import numpy as np
from scipy import sparse
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint  # import the early stopping and checkpoint callbacks
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

batchsz = 256
....
....
.... # load your data with pandas and build your own x_train, y_train, x_val, y_val, x_test, y_test
db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
db = db.shuffle(6000).batch(batchsz)
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
ds_val = ds_val.batch(batchsz)
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.batch(batchsz)

sample = next(iter(db))
print(sample[0].shape, sample[1].shape)

network = Sequential([layers.Conv1D(512, kernel_size=5, strides=3, padding='same', activation=tf.nn.relu),
                      layers.Dropout(0.4),
                      layers.Conv1D(256, kernel_size=5, strides=3, padding='same', activation=tf.nn.relu),
                      layers.Dropout(0.4),
                      layers.Conv1D(128, kernel_size=5, strides=3, padding='same', activation=tf.nn.relu),
                      layers.Dropout(0.4),
                      layers.Flatten(),
                      layers.Dense(128, activation=tf.nn.relu),
                      layers.Dense(1, activation='sigmoid')]) # build your own network here

network.build(input_shape=(None, 2000,4))
network.summary()

# Build early stopping: monitor validation accuracy and stop training
# if it fails to improve by at least 0.001 for 8 consecutive epochs
early_stopping = EarlyStopping(monitor='val_accuracy', min_delta=0.001, patience=8)

# Build the checkpoint: save the model to 'conv.h5', monitor validation accuracy,
# mode='max' keeps the epoch with the highest val_accuracy, and
# save_best_only=True overwrites the file only when the monitored metric improves
checkpoint = ModelCheckpoint('conv.h5', monitor='val_accuracy', mode='max', verbose=1, save_best_only=True)

network.compile(optimizer=optimizers.Adam(learning_rate=0.001),
                loss='binary_crossentropy',
                metrics=['accuracy']
               )

# Pass early stopping and the checkpoint to fit() via the callbacks argument.
network.fit(db, epochs=100, validation_data=ds_val, validation_steps=2, callbacks=[early_stopping, checkpoint])
network.evaluate(db_test)

# test on a batch from the test set
sample = next(iter(db_test))
x = sample[0]
y = sample[1] 
pred = network.predict(x)  
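The stopping rule used above can be sketched in plain Python, with no TensorFlow required. This is a simplified model of the patience logic, not Keras's exact implementation, and the `val_acc` history below is made up purely for illustration:

```python
def early_stop_epoch(history, min_delta=0.001, patience=8):
    """Return the (0-based) epoch at which training would stop,
    or the last epoch if it runs to completion."""
    best = float('-inf')
    wait = 0
    for epoch, value in enumerate(history):
        if value - best > min_delta:   # meaningful improvement resets the counter
            best = value
            wait = 0
        else:
            wait += 1
            if wait >= patience:       # no improvement for `patience` epochs in a row
                return epoch
    return len(history) - 1

# val_acc improves early, then plateaus: training stops 8 epochs into the plateau
acc = [0.60, 0.70, 0.80, 0.85] + [0.8501] * 20
print(early_stop_epoch(acc))  # -> 11
```

Note that with `min_delta=0.001`, the tiny jump from 0.85 to 0.8501 does not count as an improvement, so the patience counter keeps running and training stops long before all 24 epochs are used.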

To learn more about early stopping, see Jason Brownlee's article "Use Early Stopping to Halt the Training of Neural Networks At the Right Time".


Reposted from blog.csdn.net/weixin_44022515/article/details/103867407