有了之前搭建模型并训练的基础,就可以来个实际的例子了,新手必备MNIST登场。
实现一个手写数字识别非常简单,数据集是准备好的,一行代码下载即可,然后按照之前的方式搭建模型就能训练出一个识别手写数字效果还不错的模型了。
Step1:加载tensorflow并下载数据集
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
mnist = tf.keras.datasets.mnist # 下载
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0 # 归一化处理
Step2:搭建模型并设置训练流程
model = tf.keras.Sequential()
model.add(layers.Flatten(input_shape=(28, 28)))
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dropout(0.2))
model.add(layers.Dense(10, activation='softmax'))
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
Step3:启动训练并验证模型
model.fit(x_train, y_train, epochs=10)
model.evaluate(x_test, y_test, verbose=2)
可以看到使用全连接层的模型acc也能达到98%。
现在结合前面的知识,修改一下模型。
Step4:修改模型,用卷积层