Network Optimization Methods -- Regularization
Regularization
1.1 Introduction to Regularization
Regularization (sometimes also called normalization) most commonly takes the form of L1 or L2 regularization. Using L1 or L2 regularization simply means adding a regularization term to an ordinary cost function, such as the mean-squared-error or cross-entropy cost function. For example, the cross-entropy with an L1 regularization term added is:
$$E=-\frac{1}{N}\sum_{i=1}^{N}\left[t_i\ln y_i+(1-t_i)\ln(1-y_i)\right]+\frac{\lambda}{2N}\sum_{w}\left|w\right|$$
The cross-entropy with an L2 regularization term added is:
$$E=-\frac{1}{N}\sum_{i=1}^{N}\left[t_i\ln y_i+(1-t_i)\ln(1-y_i)\right]+\frac{\lambda}{2N}\sum_{w}w^2$$
The L2-regularized cross-entropy can also be written as:
$$E=E_0+\frac{\lambda}{2N}\sum_{w}w^2$$
Here $E_0$ is the original cost function and $\lambda$ is the regularization coefficient, a number greater than 0: the larger $\lambda$ is, the stronger the influence of the regularization term; the smaller $\lambda$ is, the weaker the influence; and when $\lambda$ is 0 the regularization term effectively disappears. N is the number of samples, and w denotes the network's weight parameters (biases can in principle be included, but in practice they are usually left unregularized, as in the code below).
Training a model is essentially minimizing the cost function by gradient descent. For the cross-entropy cost function, the closer the values of t and y, the closer the cost is to 0. Looking at the regularized cost expressions, minimizing the cost now requires not only making y approach t, but also pushing the network's weights w toward 0: both the L1 term $\frac{\lambda}{2N}\sum_{w}|w|$ and the L2 term $\frac{\lambda}{2N}\sum_{w}w^2$ are non-negative, so minimizing them amounts to driving the weights toward 0.
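The regularized loss above is easy to evaluate directly. The following is a minimal sketch of the formulas, using made-up toy values for t, y, w, and $\lambda$ (they are not from the text):

```python
import math

def regularized_cross_entropy(t, y, weights, lam, penalty="l2"):
    """Cross-entropy plus a regularization term, following the formulas
    above: E = E0 + (lambda / 2N) * sum(|w|) or (lambda / 2N) * sum(w^2)."""
    n = len(t)
    e0 = -sum(ti * math.log(yi) + (1 - ti) * math.log(1 - yi)
              for ti, yi in zip(t, y)) / n
    if penalty == "l1":
        reg = sum(abs(w) for w in weights)
    else:
        reg = sum(w * w for w in weights)
    return e0 + lam / (2 * n) * reg

# Toy example: two samples, three weights, lambda = 0.1
t = [1.0, 0.0]
y = [0.9, 0.2]
w = [0.5, -0.3, 0.1]
loss_l2 = regularized_cross_entropy(t, y, w, lam=0.1, penalty="l2")
loss_l1 = regularized_cross_entropy(t, y, w, lam=0.1, penalty="l1")
```

Both losses exceed the unregularized cross-entropy, and setting lam=0 recovers $E_0$ exactly, matching the discussion of $\lambda$ above.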
1.2 The Difference Between L1 and L2 Regularization
The L1 term tends to drive many of the network's weights exactly to 0. If many weights are 0, the network's effective complexity and fitting capacity are reduced, so overfitting is less likely.
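The reason for this difference shows up in the gradient-descent update: the (sub)gradient of the L1 term is proportional to sign(w), a constant-size pull toward 0, while the gradient of the L2 term is proportional to w itself, a pull that weakens as w shrinks. A minimal sketch with made-up learning rate and coefficient (isolating the penalty by assuming the data gradient is zero):

```python
def l1_step(w, lr, lam):
    # Subgradient of lam*|w| is lam*sign(w): a fixed-size pull toward 0
    sign = (w > 0) - (w < 0)
    return w - lr * lam * sign

def l2_step(w, lr, lam):
    # Gradient of (lam/2)*w**2 is lam*w: a pull proportional to w
    return w - lr * lam * w

w1 = w2 = 0.5
for _ in range(100):
    w1 = l1_step(w1, lr=0.1, lam=0.1)
    w2 = l2_step(w2, lr=0.1, lam=0.1)
```

After 100 steps the L1-penalized weight has been driven to (essentially) zero, while the L2-penalized weight has only shrunk geometrically toward zero without reaching it, mirroring the sparsity-versus-decay contrast described here.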
The L2 term causes weight decay: the weights shrink toward values close to 0. Note that close to 0 is not equal to 0 -- L2 regularization rarely makes a weight exactly 0. L2 regularization is effective because once the weights w become very small, Wx + b also stays close to 0. For activation functions such as sigmoid(x) or tanh(x), when x is near 0 the curve is very nearly a straight line, as shown in the figure.
The network therefore gains many near-linear features and loses many nonlinear ones; its complexity drops, so overfitting is less likely.
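The near-linearity claim is easy to check numerically: around x = 0, tanh(x) is close to its tangent line x, and sigmoid(x) is close to 0.5 + 0.25x. A quick sketch (the sample points are arbitrary):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

# Near x = 0, tanh(x) ~ x and sigmoid(x) ~ 0.5 + 0.25*x, so small
# weights keep the pre-activations in the near-linear regime.
for x in (0.01, 0.05, 0.1):
    tanh_err = abs(math.tanh(x) - x)
    sig_err = abs(sigmoid(x) - (0.5 + 0.25 * x))
```

For |x| up to 0.1 both errors stay tiny, whereas for large inputs (say x = 2) tanh saturates and departs far from the line -- which is exactly why small weights make the network behave more linearly.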
1.3 Regularization in Code
Here we apply regularization to MNIST digit recognition.
The code was developed in a Jupyter Notebook.
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,Dropout,Flatten
from tensorflow.keras.optimizers import SGD
import matplotlib.pyplot as plt
import numpy as np
# Import the l1 and l2 regularizers
from tensorflow.keras.regularizers import l1,l2
# Load the dataset
mnist = tf.keras.datasets.mnist
# Load the training and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Normalize the training and test data; this helps speed up training
x_train, x_test = x_train / 255.0, x_test / 255.0
# Convert the training and test labels to one-hot encoding
y_train = tf.keras.utils.to_categorical(y_train,num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test,num_classes=10)
# Model definition; model1 uses L2 regularization
# l2(0.0003) applies L2 regularization with coefficient 0.0003
model1 = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(units=200, activation='tanh', kernel_regularizer=l2(0.0003)),
    Dense(units=100, activation='tanh', kernel_regularizer=l2(0.0003)),
    Dense(units=10, activation='softmax', kernel_regularizer=l2(0.0003))
])
# Define an identical model for comparison; model2 uses no regularization
model2 = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(units=200, activation='tanh'),
    Dense(units=100, activation='tanh'),
    Dense(units=10, activation='softmax')
])
# sgd defines the stochastic gradient descent optimizer (learning rate 0.2)
# loss='categorical_crossentropy' sets the cross-entropy cost function
# metrics=['accuracy'] computes accuracy during training
sgd = SGD(0.2)
model1.compile(optimizer=sgd,
loss='categorical_crossentropy',
metrics=['accuracy'])
model2.compile(optimizer=sgd,
loss='categorical_crossentropy',
metrics=['accuracy'])
# Train the models on the training data and labels
# Number of epochs: 30 (one epoch = one full pass over the training set)
epochs = 30
# Batch size: 32 (32 samples are fed to the model per training step)
batch_size = 32
# Train model1 first
history1 = model1.fit(x_train, y_train, epochs=epochs, batch_size=batch_size, validation_data=(x_test,y_test))
# Then train model2
history2 = model2.fit(x_train, y_train, epochs=epochs, batch_size=batch_size, validation_data=(x_test,y_test))
Training log (abridged):
Train on 60000 samples, validate on 10000 samples
Epoch 1/30
60000/60000 [==============================] - 6s 101us/sample - loss: 0.4063 - accuracy: 0.9205 - val_loss: 0.2799 - val_accuracy: 0.9560
Epoch 2/30
60000/60000 [==============================] - 5s 78us/sample - loss: 0.2611 - accuracy: 0.9605 - val_loss: 0.2427 - val_accuracy: 0.9634
Epoch 3/30
60000/60000 [==============================] - 4s 73us/sample - loss: 0.2177 - accuracy: 0.9692 - val_loss: 0.2148 - val_accuracy: 0.9661
... (epochs 4-28 omitted) ...
Epoch 29/30
60000/60000 [==============================] - 6s 98us/sample - loss: 0.1385 - accuracy: 0.9837 - val_loss: 0.1554 - val_accuracy: 0.9765
Epoch 30/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1371 - accuracy: 0.9843 - val_loss: 0.1631 - val_accuracy: 0.9751
Train on 60000 samples, validate on 10000 samples
Epoch 1/30
60000/60000 [==============================] - 6s 100us/sample - loss: 0.2524 - accuracy: 0.9245 - val_loss: 0.1453 - val_accuracy: 0.9544
Epoch 2/30
60000/60000 [==============================] - 6s 94us/sample - loss: 0.1168 - accuracy: 0.9643 - val_loss: 0.1153 - val_accuracy: 0.9634
Epoch 3/30
60000/60000 [==============================] - 6s 93us/sample - loss: 0.0800 - accuracy: 0.9753 - val_loss: 0.0893 - val_accuracy: 0.9705
... (epochs 4-28 omitted) ...
Epoch 29/30
60000/60000 [==============================] - 6s 98us/sample - loss: 5.3286e-04 - accuracy: 1.0000 - val_loss: 0.0752 - val_accuracy: 0.9803
Epoch 30/30
60000/60000 [==============================] - 6s 98us/sample - loss: 5.0925e-04 - accuracy: 1.0000 - val_loss: 0.0746 - val_accuracy: 0.9801
# Plot model1's validation accuracy curve
plt.plot(np.arange(epochs), history1.history['val_accuracy'], c='b', label='L2 Regularization')
# Plot model2's validation accuracy curve
plt.plot(np.arange(epochs), history2.history['val_accuracy'], c='y', label='FC')
# Legend
plt.legend()
# x-axis label
plt.xlabel('epochs')
# y-axis label
plt.ylabel('accuracy')
# Show the figure
plt.show()
The first 30 epochs in the log are the results of model1 with L2 regularization; the following 30 epochs are those of model2 without regularization. From the results, model1's training and validation accuracies stay close together, showing that regularization does help resist overfitting; model2, by contrast, reaches 100% training accuracy while its validation accuracy plateaus around 98%. However, the regularized model's validation accuracy is not especially impressive, which shows that regularization does not suit every scenario. Regularization tends to work well when the network is relatively complex and the training data is scarce; when the network is not very complex and the task is simple, regularization may actually lower accuracy. Dropout behaves similarly, so whether to use Dropout or regularization should be decided by how each performs in practice.