To get a clearer picture of how model training works, I am recording the training process here. Additions and corrections are welcome.
1. Training with the fit function:
model.compile(loss="sparse_categorical_crossentropy",
              optimizer=keras.optimizers.SGD(0.001),
              metrics=["accuracy"])
history = model.fit(x_train_scaled, y_train, epochs=10,
                    validation_data=(x_valid_scaled, y_valid),
                    callbacks=callbacks)
The above is typical tf2.0 model-training code. In this setup, because epochs=10, the training set is traversed 10 times, and after each full pass the model is evaluated on the validation set. Because metrics specifies that we also care about accuracy, each epoch prints the loss and accuracy on the training set, along with val_loss and val_accuracy on the validation set.
Train on 55000 samples, validate on 5000 samples
Epoch 1/10
55000/55000 [==============================] - 4s 78us/sample - loss: 0.9102 - accuracy: 0.7004 - val_loss: 0.6062 - val_accuracy: 0.7928
Epoch 2/10
55000/55000 [==============================] - 4s 64us/sample - loss: 0.5768 - accuracy: 0.8011 - val_loss: 0.5134 - val_accuracy: 0.8244
Epoch 3/10
55000/55000 [==============================] - 3s 62us/sample - loss: 0.5108 - accuracy: 0.8220 - val_loss: 0.4736 - val_accuracy: 0.8384
Epoch 4/10
55000/55000 [==============================] - 4s 64us/sample - loss: 0.4759 - accuracy: 0.8330 - val_loss: 0.4490 - val_accuracy: 0.8438
Epoch 5/10
55000/55000 [==============================] - 4s 64us/sample - loss: 0.4527 - accuracy: 0.8412 - val_loss: 0.4328 - val_accuracy: 0.8494
Epoch 6/10
55000/55000 [==============================] - 4s 68us/sample - loss: 0.4357 - accuracy: 0.8479 - val_loss: 0.4200 - val_accuracy: 0.8528
Epoch 7/10
55000/55000 [==============================] - 4s 66us/sample - loss: 0.4219 - accuracy: 0.8518 - val_loss: 0.4082 - val_accuracy: 0.8586
Epoch 8/10
55000/55000 [==============================] - 4s 66us/sample - loss: 0.4108 - accuracy: 0.8547 - val_loss: 0.3997 - val_accuracy: 0.8614
Epoch 9/10
55000/55000 [==============================] - 4s 68us/sample - loss: 0.4012 - accuracy: 0.8586 - val_loss: 0.3963 - val_accuracy: 0.8614
Epoch 10/10
55000/55000 [==============================] - 4s 66us/sample - loss: 0.3925 - accuracy: 0.8607 - val_loss: 0.3881 - val_accuracy: 0.8630
If you want a final evaluation on the test set, run the following:
model.evaluate(x_test_scaled, y_test, verbose=0)
[0.34827180202007296, 0.8757]   # [loss, accuracy]
A note: fit also traverses the training set in batches; the default is batch_size=32. So what fit does can be summarized as: 1. traverse the training set batch by batch and accumulate the metrics on the training set (this step includes the forward pass, automatic differentiation, and the parameter update, so if you want to replace fit you have to implement these parts yourself); 2. after each epoch, evaluate on the validation set and record the validation loss.
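The batch loop, metric accumulation, and gradient update described above can be sketched without any framework. Below is a minimal, framework-free illustration of what fit does internally, using a toy linear model with a hand-derived MSE gradient (fit_sketch and all names in it are made up for illustration, not Keras internals):

```python
import numpy as np

def fit_sketch(x, y, epochs=5, batch_size=32, lr=0.1, x_val=None, y_val=None):
    """A toy imitation of Keras fit for the model y_hat = w*x + b."""
    rng = np.random.default_rng(0)
    w, b = 0.0, 0.0
    history = {"loss": [], "val_loss": []}
    n = len(x)
    for epoch in range(epochs):
        idx = rng.permutation(n)                 # shuffle each epoch, like fit
        batch_losses = []
        for start in range(0, n, batch_size):    # 1. traverse the training set in batches
            bi = idx[start:start + batch_size]
            xb, yb = x[bi], y[bi]
            err = w * xb + b - yb
            batch_losses.append(np.mean(err ** 2))   # accumulate the training metric
            # hand-written MSE gradient (Keras would do this by autodiff)
            w -= lr * np.mean(2 * err * xb)
            b -= lr * np.mean(2 * err)
        history["loss"].append(float(np.mean(batch_losses)))
        if x_val is not None:                    # 2. validate once at the end of each epoch
            verr = w * x_val + b - y_val
            history["val_loss"].append(float(np.mean(verr ** 2)))
    return history

x = np.linspace(-1, 1, 200)
y = 3 * x + 1
hist = fit_sketch(x, y, epochs=10, x_val=x[:50], y_val=y[:50])
```

In real TensorFlow code the gradient step would come from tf.GradientTape; the gradient is written out by hand here only to keep the sketch dependency-free.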
2. Training during hyperparameter search with sklearn (the cross_validation mechanism)
The cross_validation mechanism: the training set is split into n folds; n-1 folds are used for training and the remaining fold for validation. (In the log below n=3; note that older versions of sklearn default to cv=3 and newer ones to cv=5.)
The relationship between training, validation, and testing is: after each pass over the training set (each epoch), a validation pass is run on the validation set; once training is completely finished, a final evaluation is run on the test set.
During hyperparameter search, the cross_validation mechanism effectively adds one more validation set. How does it actually run? The model first trains on the first n-1 folds; after each epoch it validates on the n-th fold. Once the (up to) 100 epochs are done (assuming epochs=100), it runs one more validation pass on x_valid_scaled (i.e., the validation set). After the whole hyperparameter search finishes, the model is trained once more on the full training set with the best parameters found, again validating on the validation set after each epoch. If this feels confusing, the training logs below should make it clear.
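The fold bookkeeping described above is easy to reproduce in pure Python. The sketch below (kfold_indices is an illustrative helper, not sklearn's API, though sklearn's KFold does the same job) shows how splitting 11610 samples into 3 folds yields the 7740/3870 sizes seen in the logs:

```python
def kfold_indices(n_samples, n_splits=3):
    """Yield (train_indices, held_out_indices) for each of n_splits folds."""
    fold_size = n_samples // n_splits
    folds = [list(range(i * fold_size, (i + 1) * fold_size))
             for i in range(n_splits)]
    # any leftover samples go to the last fold
    folds[-1].extend(range(n_splits * fold_size, n_samples))
    for i in range(n_splits):
        held_out = folds[i]
        train = [j for k, f in enumerate(folds) if k != i for j in f]
        yield train, held_out

for train_idx, val_idx in kfold_indices(11610, n_splits=3):
    print(len(train_idx), len(val_idx))   # prints "7740 3870" for each fold
```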
Below is the log of one run (up to 100 epochs) that I captured from the cross_validation training. It trains on 7740 samples (2/3 of 11610) and validates on the remaining 1/3, which produces the val_loss in the output; after the run ends (early stopping triggered here, at epoch 61), it is evaluated on the 3870-sample validation set.
Train on 7740 samples, validate on 3870 samples
Epoch 1/100
7740/7740 [==============================] - 1s 112us/sample - loss: 5.2672 - val_loss: 4.9883
Epoch 2/100
7740/7740 [==============================] - 1s 80us/sample - loss: 4.3754 - val_loss: 4.1726
Epoch 3/100
7740/7740 [==============================] - 1s 95us/sample - loss: 3.6575 - val_loss: 3.4975
Epoch 4/100
7740/7740 [==============================] - 1s 87us/sample - loss: 3.0589 - val_loss: 2.9331
Epoch 5/100
7740/7740 [==============================] - 1s 102us/sample - loss: 2.5663 - val_loss: 2.4767
Epoch 6/100
7740/7740 [==============================] - 1s 87us/sample - loss: 2.1773 - val_loss: 2.1219
Epoch 7/100
7740/7740 [==============================] - 1s 80us/sample - loss: 1.8800 - val_loss: 1.8536
Epoch 8/100
7740/7740 [==============================] - 1s 79us/sample - loss: 1.6559 - val_loss: 1.6512
Epoch 9/100
7740/7740 [==============================] - 1s 81us/sample - loss: 1.4865 - val_loss: 1.4947
Epoch 10/100
7740/7740 [==============================] - 1s 80us/sample - loss: 1.3585 - val_loss: 1.3753
Epoch 11/100
7740/7740 [==============================] - 1s 81us/sample - loss: 1.2567 - val_loss: 1.2781
Epoch 12/100
7740/7740 [==============================] - 1s 80us/sample - loss: 1.1739 - val_loss: 1.1986
Epoch 13/100
7740/7740 [==============================] - 1s 95us/sample - loss: 1.1043 - val_loss: 1.1311
Epoch 14/100
7740/7740 [==============================] - 1s 81us/sample - loss: 1.0435 - val_loss: 1.0722
Epoch 15/100
7740/7740 [==============================] - 1s 80us/sample - loss: 0.9901 - val_loss: 1.0206
Epoch 16/100
7740/7740 [==============================] - 1s 82us/sample - loss: 0.9431 - val_loss: 0.9755
Epoch 17/100
7740/7740 [==============================] - 1s 80us/sample - loss: 0.9022 - val_loss: 0.9364
Epoch 18/100
7740/7740 [==============================] - 1s 81us/sample - loss: 0.8660 - val_loss: 0.9022
Epoch 19/100
7740/7740 [==============================] - 1s 81us/sample - loss: 0.8341 - val_loss: 0.8723
Epoch 20/100
7740/7740 [==============================] - 1s 89us/sample - loss: 0.8067 - val_loss: 0.8463
Epoch 21/100
7740/7740 [==============================] - 1s 80us/sample - loss: 0.7832 - val_loss: 0.8239
Epoch 22/100
7740/7740 [==============================] - 1s 81us/sample - loss: 0.7632 - val_loss: 0.8043
Epoch 23/100
7740/7740 [==============================] - 1s 81us/sample - loss: 0.7458 - val_loss: 0.7873
Epoch 24/100
7740/7740 [==============================] - 1s 81us/sample - loss: 0.7307 - val_loss: 0.7726
Epoch 25/100
7740/7740 [==============================] - 1s 81us/sample - loss: 0.7175 - val_loss: 0.7598
Epoch 26/100
7740/7740 [==============================] - 1s 92us/sample - loss: 0.7060 - val_loss: 0.7486
Epoch 27/100
7740/7740 [==============================] - 1s 84us/sample - loss: 0.6958 - val_loss: 0.7387
Epoch 28/100
7740/7740 [==============================] - 1s 83us/sample - loss: 0.6866 - val_loss: 0.7300
Epoch 29/100
7740/7740 [==============================] - 1s 88us/sample - loss: 0.6786 - val_loss: 0.7223
Epoch 30/100
7740/7740 [==============================] - 1s 105us/sample - loss: 0.6716 - val_loss: 0.7155
Epoch 31/100
7740/7740 [==============================] - 1s 88us/sample - loss: 0.6653 - val_loss: 0.7093
Epoch 32/100
7740/7740 [==============================] - 1s 92us/sample - loss: 0.6597 - val_loss: 0.7037
Epoch 33/100
7740/7740 [==============================] - 1s 105us/sample - loss: 0.6547 - val_loss: 0.6987
Epoch 34/100
7740/7740 [==============================] - 1s 92us/sample - loss: 0.6503 - val_loss: 0.6942
Epoch 35/100
7740/7740 [==============================] - 1s 91us/sample - loss: 0.6463 - val_loss: 0.6900
Epoch 36/100
7740/7740 [==============================] - 1s 84us/sample - loss: 0.6425 - val_loss: 0.6860
Epoch 37/100
7740/7740 [==============================] - 1s 81us/sample - loss: 0.6391 - val_loss: 0.6823
Epoch 38/100
7740/7740 [==============================] - 1s 83us/sample - loss: 0.6359 - val_loss: 0.6788
Epoch 39/100
7740/7740 [==============================] - 1s 80us/sample - loss: 0.6329 - val_loss: 0.6756
Epoch 40/100
7740/7740 [==============================] - 1s 79us/sample - loss: 0.6301 - val_loss: 0.6725
Epoch 41/100
7740/7740 [==============================] - 1s 76us/sample - loss: 0.6274 - val_loss: 0.6696
Epoch 42/100
7740/7740 [==============================] - 1s 80us/sample - loss: 0.6248 - val_loss: 0.6667
Epoch 43/100
7740/7740 [==============================] - 1s 77us/sample - loss: 0.6224 - val_loss: 0.6640
Epoch 44/100
7740/7740 [==============================] - 1s 77us/sample - loss: 0.6201 - val_loss: 0.6614
Epoch 45/100
7740/7740 [==============================] - 1s 79us/sample - loss: 0.6178 - val_loss: 0.6588
Epoch 46/100
7740/7740 [==============================] - 1s 78us/sample - loss: 0.6157 - val_loss: 0.6563
Epoch 47/100
7740/7740 [==============================] - 1s 78us/sample - loss: 0.6136 - val_loss: 0.6540
Epoch 48/100
7740/7740 [==============================] - 1s 77us/sample - loss: 0.6115 - val_loss: 0.6516
Epoch 49/100
7740/7740 [==============================] - 1s 76us/sample - loss: 0.6095 - val_loss: 0.6494
Epoch 50/100
7740/7740 [==============================] - 1s 78us/sample - loss: 0.6076 - val_loss: 0.6472
Epoch 51/100
7740/7740 [==============================] - 1s 77us/sample - loss: 0.6057 - val_loss: 0.6450
Epoch 52/100
7740/7740 [==============================] - 1s 78us/sample - loss: 0.6038 - val_loss: 0.6429
Epoch 53/100
7740/7740 [==============================] - 1s 77us/sample - loss: 0.6020 - val_loss: 0.6409
Epoch 54/100
7740/7740 [==============================] - 1s 82us/sample - loss: 0.6002 - val_loss: 0.6388
Epoch 55/100
7740/7740 [==============================] - 1s 87us/sample - loss: 0.5985 - val_loss: 0.6369
Epoch 56/100
7740/7740 [==============================] - 1s 78us/sample - loss: 0.5968 - val_loss: 0.6349
Epoch 57/100
7740/7740 [==============================] - 1s 76us/sample - loss: 0.5951 - val_loss: 0.6330
Epoch 58/100
7740/7740 [==============================] - 1s 76us/sample - loss: 0.5934 - val_loss: 0.6312
Epoch 59/100
7740/7740 [==============================] - 1s 75us/sample - loss: 0.5918 - val_loss: 0.6293
Epoch 60/100
7740/7740 [==============================] - 1s 77us/sample - loss: 0.5902 - val_loss: 0.6275
Epoch 61/100
7740/7740 [==============================] - 1s 77us/sample - loss: 0.5887 - val_loss: 0.6258
3870/1 [==================================================================] - 0s 40us/sample - loss: 0.5191
After the whole hyperparameter search finished, this is the best of the 10 sampled hyperparameter combinations. It was selected by its val_loss on the validation folds (3870 samples each) carved out of the training set (11610 samples), and the model is then retrained with those parameters on the full training set (11610 samples).
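The selection step itself is simple to sketch: among the sampled combinations, pick the one with the lowest mean validation loss across the CV folds. All names and numbers below are made up for illustration; sklearn's RandomizedSearchCV does this bookkeeping internally:

```python
def best_params(cv_results):
    """cv_results maps a hyperparameter combination to its per-fold val losses;
    return the combination with the lowest mean validation loss."""
    return min(cv_results,
               key=lambda p: sum(cv_results[p]) / len(cv_results[p]))

# two hypothetical sampled combinations with per-fold validation losses
sampled = {
    ("lr=0.01", "units=30"):  [0.52, 0.55, 0.51],   # mean ~0.527
    ("lr=0.001", "units=50"): [0.61, 0.60, 0.63],   # mean ~0.613
}
print(best_params(sampled))   # prints ('lr=0.01', 'units=30')
```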
Train on 11610 samples, validate on 3870 samples
Epoch 1/100
11610/11610 [==============================] - 1s 89us/sample - loss: 0.6970 - val_loss: 0.5471
Epoch 2/100
11610/11610 [==============================] - 1s 70us/sample - loss: 0.4846 - val_loss: 0.4733
Epoch 3/100
11610/11610 [==============================] - 1s 71us/sample - loss: 0.4354 - val_loss: 0.4388
Epoch 4/100
11610/11610 [==============================] - 1s 72us/sample - loss: 0.4078 - val_loss: 0.4094
Epoch 5/100
11610/11610 [==============================] - 1s 71us/sample - loss: 0.3932 - val_loss: 0.4081
Epoch 6/100
11610/11610 [==============================] - 1s 71us/sample - loss: 0.3837 - val_loss: 0.3907
Epoch 7/100
11610/11610 [==============================] - 1s 71us/sample - loss: 0.3764 - val_loss: 0.3780
Epoch 8/100
11610/11610 [==============================] - 1s 70us/sample - loss: 0.3675 - val_loss: 0.3833
Epoch 9/100
11610/11610 [==============================] - 1s 74us/sample - loss: 0.3629 - val_loss: 0.3767
Epoch 10/100
11610/11610 [==============================] - 1s 73us/sample - loss: 0.3576 - val_loss: 0.3696
Epoch 11/100
11610/11610 [==============================] - 1s 71us/sample - loss: 0.3520 - val_loss: 0.3596
Epoch 12/100
11610/11610 [==============================] - 1s 78us/sample - loss: 0.3478 - val_loss: 0.3549
Epoch 13/100
11610/11610 [==============================] - 1s 81us/sample - loss: 0.3439 - val_loss: 0.3583
Epoch 14/100
11610/11610 [==============================] - 1s 80us/sample - loss: 0.3411 - val_loss: 0.3544
Epoch 15/100
11610/11610 [==============================] - 1s 81us/sample - loss: 0.3374 - val_loss: 0.3472
Epoch 16/100
11610/11610 [==============================] - 1s 76us/sample - loss: 0.3335 - val_loss: 0.3424
Epoch 17/100
11610/11610 [==============================] - 1s 76us/sample - loss: 0.3318 - val_loss: 0.3452
Epoch 18/100
11610/11610 [==============================] - 1s 71us/sample - loss: 0.3287 - val_loss: 0.3410
Epoch 19/100
11610/11610 [==============================] - 1s 69us/sample - loss: 0.3278 - val_loss: 0.3427
Epoch 20/100
11610/11610 [==============================] - 1s 71us/sample - loss: 0.3249 - val_loss: 0.3404