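Basic idea
K-fold cross-validation splits the dataset into k mutually exclusive subsets (usually of equal size). Each subset serves as the validation set exactly once, while the remaining k-1 subsets form the training set. After each round of training, the model is evaluated on that round's validation set, and the k results together give an estimate of model performance.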
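As a minimal sketch of this rotation, the pure-Python snippet below partitions sample indices into k folds and yields each fold once as the validation set. The helper name k_fold_indices and the toy size of 10 samples are assumptions made for illustration only.

```python
def k_fold_indices(n_samples, k):
    """Yield (train_idx, val_idx) pairs, one per fold (hypothetical helper)."""
    indices = list(range(n_samples))
    # spread any remainder so that fold sizes differ by at most one
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        val_idx = indices[start:start + size]
        train_idx = indices[:start] + indices[start + size:]
        yield train_idx, val_idx
        start += size


folds = list(k_fold_indices(10, 5))
for train_idx, val_idx in folds:
    print('train:', train_idx, 'val:', val_idx)
```

Every index appears in exactly one validation fold, so each sample is used for validation once and for training k-1 times.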
Functions used
- tf.range()
tf.range(start, limit, delta=1, dtype=None, name='range')
Creates a sequence of numbers that begins at start and increases in steps of delta up to, but not including, limit.
For example:
a = tf.range(0, 10)
tf.Tensor([0 1 2 3 4 5 6 7 8 9], shape=(10,), dtype=int32)
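To illustrate that limit is exclusive and that delta controls the step, here is a small sketch. The variable names and values are chosen purely for illustration.

```python
import tensorflow as tf

evens = tf.range(0, 10, delta=2)      # limit 10 is excluded
countdown = tf.range(5, 0, delta=-1)  # a negative delta counts down; limit 0 is excluded

print(evens.numpy())      # [0 2 4 6 8]
print(countdown.numpy())  # [5 4 3 2 1]
```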
- tf.random.shuffle()
tf.random.shuffle(
    value,
    seed=None,
    name=None
)
Randomly shuffles the tensor value along its first dimension.
For example (the resulting order differs from run to run):
a = tf.random.shuffle(a)
tf.Tensor([4 3 7 5 9 8 6 1 0 2], shape=(10,), dtype=int32)
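Because only the first dimension is shuffled, the rows of a 2-D tensor move as whole units. The sketch below uses a toy 3x2 tensor (an assumption for illustration) and checks that every row of the result is still one of the original rows.

```python
import tensorflow as tf

rows = tf.constant([[1, 2], [3, 4], [5, 6]])
shuffled = tf.random.shuffle(rows)
print(shuffled.numpy())  # row order varies between runs

# each row stays intact: sorting the rows of both tensors gives the same list
original_rows = sorted(rows.numpy().tolist())
shuffled_rows = sorted(shuffled.numpy().tolist())
print(original_rows == shuffled_rows)  # True
```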
- tf.gather()
Uses an array of indices to extract the corresponding slices from a tensor (along the first axis by default).
For example:
index = tf.range(0, 2) # [0, 1]
x = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(tf.gather(x, index))
Output:
tf.Tensor(
[[1 2 3]
 [4 5 6]], shape=(2, 3), dtype=int32)
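The indices passed to tf.gather do not have to be contiguous or ordered, which is what makes it usable with shuffled indices, and an axis argument selects the dimension to index. A small sketch on the same matrix, with index values chosen for illustration:

```python
import tensorflow as tf

x = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

picked_rows = tf.gather(x, [2, 0])          # rows 2 and 0, in that order
picked_cols = tf.gather(x, [2, 0], axis=1)  # columns 2 and 0 of every row

print(picked_rows.numpy().tolist())  # [[7, 8, 9], [1, 2, 3]]
print(picked_cols.numpy().tolist())  # [[3, 1], [6, 4], [9, 7]]
```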
With the functions above, the elements of a dataset can be randomly shuffled and then split.
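Put together, a shuffle-and-split step might look like the sketch below. The toy tensors and the names n, k and fold_size are assumptions for illustration. The same shuffled index tensor is applied to both features and labels so that each pair stays aligned.

```python
import tensorflow as tf

n, k = 10, 5                                     # toy sizes, assumed for illustration
features = tf.reshape(tf.range(n * 2), [n, 2])   # row i is [2i, 2i + 1]
labels = tf.range(n)                             # label i belongs to row i

index = tf.random.shuffle(tf.range(n))
fold_size = n // k

# take the first fold as validation data and the remaining folds as training data
val_index, train_index = index[:fold_size], index[fold_size:]
x_val, y_val = tf.gather(features, val_index), tf.gather(labels, val_index)
x_train, y_train = tf.gather(features, train_index), tf.gather(labels, train_index)

print(x_train.shape, y_train.shape, x_val.shape, y_val.shape)  # (8, 2) (8,) (2, 2) (2,)
```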
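Final code
The code below takes 10-fold cross-validation as the example. It runs a single fold: one tenth of the MNIST training data (6,000 of the 60,000 samples) is held out for validation, and the remaining nine tenths (54,000 samples) are used for training.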
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
from tensorflow.keras import datasets, layers, Sequential, optimizers
# load data
(x, y), (x_test, y_test) = datasets.mnist.load_data()
# build datasets
def preprocess(x, y):
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = tf.reshape(x, [-1, 28 * 28])
    y = tf.cast(y, dtype=tf.int64)
    y = tf.one_hot(y, depth=10)
    return x, y
print('datasets:', x.shape, y.shape, x_test.shape, y_test.shape)
index = tf.range(0, 60000)
index = tf.random.shuffle(index)
x_train, y_train = tf.gather(x, index[:54000]), tf.gather(y, index[:54000]) # 60000 * 9/10
x_val, y_val = tf.gather(x, index[-6000:]), tf.gather(y, index[-6000:])
# print the shapes of training dataset and validation dataset
print(x_train.shape, y_train.shape, x_val.shape, y_val.shape)
batchsz = 128
db_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
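# shuffle individual samples before batching; shuffling after batch() only reorders whole batches
db_train = db_train.shuffle(54000).batch(batchsz).map(preprocess)
db_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
# the validation set does not need to be shuffled
db_val = db_val.batch(batchsz).map(preprocess)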
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.batch(batchsz).map(preprocess)
# print a sample in training dataset
sample = next(iter(db_train))
print('sample shape:', sample[0].shape, sample[1].shape)
# build network
network = Sequential([
    layers.Dense(256, activation='relu'),  # [b, 784] => [b, 256]
    layers.Dense(128, activation='relu'),  # [b, 256] => [b, 128]
    layers.Dense(64, activation='relu'),   # [b, 128] => [b, 64]
    layers.Dense(32, activation='relu'),   # [b, 64] => [b, 32]
    layers.Dense(10)                       # [b, 32] => [b, 10], raw logits
])
network.build(input_shape=[None, 28 * 28])
network.summary()
# 'lr' is a deprecated alias that newer Keras versions reject; use 'learning_rate'
network.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
network.fit(db_train, epochs=10, validation_data=db_val, validation_freq=1)
# print test accuracy
print('test accuracy:')
network.evaluate(db_test)
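The script above trains and validates on one split only. A full k-fold estimate rotates the validation fold and averages the k scores. The sketch below shows one way this could be done with the same three functions. It is not the original author's code: the helpers build_model and k_fold_accuracy are hypothetical, the data is a small synthetic stand-in so the example runs quickly, and it uses integer labels with SparseCategoricalCrossentropy instead of the one-hot labels used above.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Sequential, optimizers


def build_model(num_classes):
    # a fresh model for every fold, so no weights carry over between folds
    model = Sequential([
        layers.Dense(32, activation='relu'),
        layers.Dense(num_classes),  # raw logits
    ])
    model.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
                  loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    return model


def k_fold_accuracy(x, y, num_classes, k=10, epochs=1, batch_size=32):
    """Hypothetical helper: train k models, return one validation accuracy per fold."""
    n = x.shape[0]
    index = tf.random.shuffle(tf.range(n))
    folds = tf.split(index, k)  # even split; requires n to be divisible by k
    scores = []
    for i in range(k):
        val_index = folds[i]
        train_index = tf.concat(folds[:i] + folds[i + 1:], axis=0)
        x_train, y_train = tf.gather(x, train_index), tf.gather(y, train_index)
        x_val, y_val = tf.gather(x, val_index), tf.gather(y, val_index)
        model = build_model(num_classes)
        model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=0)
        _, acc = model.evaluate(x_val, y_val, verbose=0)
        scores.append(float(acc))
    return scores


# synthetic stand-in data (assumed); replace with flattened MNIST images and integer labels
rng = np.random.default_rng(0)
x = tf.constant(rng.normal(size=(200, 8)).astype('float32'))
y = tf.constant(rng.integers(0, 3, size=200).astype('int32'))

scores = k_fold_accuracy(x, y, num_classes=3, k=5)
print('per-fold accuracy:', scores)
print('mean accuracy:', sum(scores) / len(scores))
```

Building a new model inside the loop matters: reusing one model across folds would let it see every fold's validation data during earlier rounds of training, which inflates the estimate.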