Sequential构建模型
- 任务:CIFAR100数据集进行模型训练
CIFAR100数据集有100个类别,每个类别包含600张图像,其中500张为训练图像,100张为测试图像。CIFAR-100中的100个类别被划分为20个超类。每张图像都带有一个“精细”标签(它所属的类别)和一个“粗糙”标签(它所属的超类)。
代码如下:
import tensorflow as tf
import os
import numpy as np
class sequ_model():
    """CNN that classifies CIFAR-100 images into 100 classes.

    The constructor loads the dataset through ``tf.keras.datasets`` and builds
    a small network: two (5x5 convolution + 2x2 max-pool) stages followed by
    a 1024-unit dense layer and a 100-way softmax output.

    Attributes:
        train, test: image arrays scaled to the range [0, 1].
        train_label, test_label: label arrays as returned by ``load_data()``.
        model: the ``tf.keras.Sequential`` network.
    """

    def __init__(self):
        (self.train, self.train_label), (self.test, self.test_label) = \
            tf.keras.datasets.cifar100.load_data()
        print(self.train.shape)

        # Cast to float32 before scaling: dividing the raw integer pixels by a
        # Python float would otherwise produce float64 copies twice the size.
        self.train = self.train.reshape(-1, 32, 32, 3).astype(np.float32) / 255.0
        self.test = self.test.reshape(-1, 32, 32, 3).astype(np.float32) / 255.0

        self.model = tf.keras.Sequential([
            # input_shape builds the model immediately, so its variables exist
            # and load_weights() can restore an HDF5 file before any training
            # or inference call has been made.
            tf.keras.layers.Conv2D(32, kernel_size=[5, 5], strides=1,
                                   padding='same', activation=tf.nn.relu,
                                   input_shape=(32, 32, 3)),
            tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2),
            tf.keras.layers.Conv2D(64, kernel_size=[5, 5], strides=1,
                                   padding='same', activation=tf.nn.relu),
            tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(1024, activation=tf.nn.relu),
            tf.keras.layers.Dense(100, activation=tf.nn.softmax),
        ])

    def compile(self, learning_rate):
        """Configure the model for training.

        Args:
            learning_rate: step size handed to the Adam optimizer.
        """
        self.model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
            loss=tf.keras.losses.sparse_categorical_crossentropy,
            # The keyword is lowercase ``metrics``; without it evaluate()
            # reports only the loss and cannot be unpacked into two values.
            metrics=['accuracy'])

    def fit(self):
        """Train for 2 epochs, validating on the test split.

        Weight checkpoints go to ``./ckpt`` and TensorBoard logs to
        ``./graph``. ``compile()`` must have been called first.
        """
        # HDF5 checkpoint writing needs the target directory to exist already.
        os.makedirs('./ckpt', exist_ok=True)

        # NOTE(review): 'val_accuracy' is the name TF 2.x logs for
        # metrics=['accuracy']; older releases logged 'val_acc' -- confirm
        # against the installed version. The deprecated ``period=1`` argument
        # was dropped because saving every epoch is already the default.
        check = tf.keras.callbacks.ModelCheckpoint(
            './ckpt/epochs_{epoch:02d}-{val_loss:.2f}.h5',
            monitor='val_accuracy',
            save_best_only=True,
            save_weights_only=True,
            mode='auto')
        tensbord = tf.keras.callbacks.TensorBoard(
            log_dir='./graph',
            histogram_freq=1, write_graph=True, write_images=True)

        self.model.fit(self.train, self.train_label, epochs=2, batch_size=32,
                       callbacks=[check, tensbord],
                       validation_data=(self.test, self.test_label))

    def evaluate(self):
        """Evaluate on the test split and print loss and accuracy.

        Returns:
            The ``(test_loss, test_acc)`` pair produced by ``model.evaluate``.
        """
        test_loss, test_acc = self.model.evaluate(self.test, self.test_label)
        print("损失{},准确率{}".format(test_loss, test_acc))
        return test_loss, test_acc

    def model_save(self):
        """Write the current model weights to ``./model.h5``."""
        self.model.save_weights("./model.h5")

    def predict(self):
        """Predict class probabilities for the test split.

        Weights are restored from ``./model.h5`` first.

        Returns:
            The output of ``model.predict`` on the test images, or ``None``
            when ``./model.h5`` does not exist (no prediction is attempted
            with untrained weights).
        """
        if not os.path.exists("./model.h5"):
            return None
        self.model.load_weights("./model.h5")
        return self.model.predict(self.test)
if __name__ == '__main__':
    # Inference-only run: build the network, then predict on the test split.
    classifier = sequ_model()
    # To train from scratch instead, run these steps before predicting:
    #   classifier.compile(0.02)
    #   classifier.fit()
    #   classifier.evaluate()
    #   classifier.model_save()
    classifier.predict()