Importing the MNIST dataset
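from pathlib import Path
import gzip
import pickle
import requests
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
PATH.mkdir(parents=True, exist_ok=True)
URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"
# Download the archive only if it is not already on disk
if not (PATH / FILENAME).exists():
    content = requests.get(URL + FILENAME).content
    with (PATH / FILENAME).open("wb") as f:
        f.write(content)
# The archive holds (train, valid, test) tuples of (images, labels);
# each image is already stored as a flat row of 784 values
with gzip.open(PATH / FILENAME, "rb") as f:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")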
Each 28×28×1 image is flattened into 784 pixel values.
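A minimal NumPy sketch of the flattening step. The arrays `image` and `batch` are hypothetical random stand-ins, not real MNIST data; in mnist.pkl.gz the images are already stored flat, so this only illustrates the reshape in both directions.

```python
import numpy as np

# Hypothetical stand-in for one grayscale image: 28 x 28 pixels, 1 channel
image = np.random.rand(28, 28, 1).astype(np.float32)

# Flatten 28 * 28 * 1 into a single vector of 784 pixel values
flat = image.reshape(-1)
print(flat.shape)  # (784,)

# A batch of N images becomes an N x 784 matrix, the input shape a Dense layer expects
batch = np.random.rand(5, 28, 28, 1).astype(np.float32)
flat_batch = batch.reshape(batch.shape[0], -1)
print(flat_batch.shape)  # (5, 784)

# The reshape is lossless: going back recovers the original 28 x 28 layout
restored = flat.reshape(28, 28)
assert np.array_equal(restored, image[:, :, 0])
```

The same reshape in the other direction, `reshape((28, 28))`, is what turns a flat row back into a displayable image.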
Building the model with TensorFlow (tf.keras)
import tensorflow as tf
from tensorflow.keras import layers
model = tf.keras.Sequential()
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))  # 10 outputs, one per digit 0-9
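As a quick check on the size of this network: each Dense layer has (inputs × units) weights plus one bias per unit. A small Python sketch, assuming 784 input features (the flattened image size); `dense_params` is a hypothetical helper, not a Keras function.

```python
def dense_params(n_in, n_units):
    """Trainable parameters in a fully connected layer: weights plus biases."""
    return n_in * n_units + n_units

# Layer sizes of the Sequential model: 784 inputs -> 32 -> 32 -> 10
sizes = [784, 32, 32, 10]
per_layer = [dense_params(a, b) for a, b in zip(sizes[:-1], sizes[1:])]
print(per_layer)       # [25120, 1056, 330]
print(sum(per_layer))  # 26506
```

`model.summary()` should report the same counts once the model has been built; since no input shape is given here, that happens on the first call to `fit`.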
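from matplotlib import pyplot
# Show the first training image: reshape its 784 pixels back to 28 x 28
pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
pyplot.show()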
# Not tf.keras.losses.CategoricalCrossentropy: that loss expects one-hot labels,
# while y_train holds integer class labels
model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])  # accuracy
model.fit(x_train, y_train, epochs=5, batch_size=64,
          validation_data=(x_valid, y_valid))
When choosing the loss and metric functions, pick ones that match the format of the labels. API reference: https://tensorflow.google.cn/api_docs/python/tf/keras/metrics/SparseCategoricalAccuracy?version=stable
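What `SparseCategoricalAccuracy` measures can be sketched in NumPy: take the predicted class as the argmax of each softmax row and compare it with the integer label. The helper and the arrays below are made-up toy values for illustration, not the Keras implementation or real MNIST outputs; the Keras metric additionally keeps a running average across batches.

```python
import numpy as np

def sparse_categorical_accuracy(y_true, y_pred):
    """Fraction of samples whose highest-probability class equals the integer label."""
    predicted = np.argmax(y_pred, axis=1)
    return float(np.mean(predicted == y_true))

# Toy softmax outputs for 4 samples over 3 classes (hypothetical values)
y_pred = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.6, 0.3, 0.1],
])
# Integer labels, the same format as y_train
y_true = np.array([0, 1, 2, 1])

print(sparse_categorical_accuracy(y_true, y_pred))  # 0.75
```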
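Choosing the right loss function is essential.
CategoricalCrossentropy cannot be used here because the label format does not match: it expects one-hot labels, while the labels in this dataset are integer class indices.
Use SparseCategoricalCrossentropy instead.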
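The two losses differ only in the label format they accept. A NumPy sketch of both formulations (the helper functions are illustrative, not the Keras implementation) shows that sparse cross-entropy on integer labels gives the same value as categorical cross-entropy on the equivalent one-hot labels:

```python
import numpy as np

def categorical_crossentropy(y_onehot, y_pred):
    """Mean cross-entropy when labels are one-hot vectors of shape (N, C)."""
    return float(np.mean(-np.sum(y_onehot * np.log(y_pred), axis=1)))

def sparse_categorical_crossentropy(y_int, y_pred):
    """Mean cross-entropy when labels are integer class indices of shape (N,)."""
    return float(np.mean(-np.log(y_pred[np.arange(len(y_int)), y_int])))

# Toy softmax outputs for 3 samples over 3 classes (hypothetical values)
y_pred = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.8, 0.1],
                   [0.3, 0.3, 0.4]])
y_int = np.array([0, 1, 2])   # integer labels, like y_train
y_onehot = np.eye(3)[y_int]   # the same labels in one-hot form

sparse_loss = sparse_categorical_crossentropy(y_int, y_pred)
onehot_loss = categorical_crossentropy(y_onehot, y_pred)
print(sparse_loss, onehot_loss)  # identical values
```

The alternative would be to keep CategoricalCrossentropy and convert the labels first, e.g. with `tf.keras.utils.to_categorical(y_train, 10)`; the sparse loss simply avoids that conversion.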
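The results are shown below; the performance is only average. The log is the end of a training epoch, so the loss and accuracy shown are training-set values.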
…
40256/50000 [=>…] - ETA: 0s - loss: 0.1138 - sparse_categorical_accuracy: 0.9648
41792/50000 [>…] - ETA: 0s - loss: 0.1151 - sparse_categorical_accuracy: 0.9645
43392/50000 [=>…] - ETA: 0s - loss: 0.1150 - sparse_categorical_accuracy: 0.9645
44992/50000 [=>…] - ETA: 0s - loss: 0.1157 - sparse_categorical_accuracy: 0.9643
46592/50000 [>…] - ETA: 0s - loss: 0.1152 - sparse_categorical_accuracy: 0.9645
47872/50000 [=>…] - ETA: 0s - loss: 0.1141 - sparse_categorical_accuracy: 0.9648
48960/50000 [==========================>.] - ETA: 0s - loss: 0.1140 - sparse_categorical_accuracy: 0.9650