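Training a simple deep convolutional neural network on the CIFAR10 (small images) dataset
Terminology
CIFAR-10 is a dataset for general object recognition collected by Alex Krizhevsky, Vinod Nair and Geoffrey Hinton.
Official site: http://www.cs.toronto.edu/~kriz/cifar.html
CIFAR (the Canadian Institute for Advanced Research) is an advanced-science research institute funded primarily by the Canadian government. In 2004, Hinton, Bengio and their students obtained a small amount of CIFAR funding and set up the Neural Computation and Adaptive Perception program. The program brought together computer scientists, biologists, electrical engineers, neuroscientists, physicists and psychologists, and accelerated the progress of deep learning. Judging from this lineup, DL has already moved a long way from the data mining branch of ML. Deep learning emphasizes adaptive perception and artificial intelligence and sits at the intersection of computer science and neuroscience; data mining emphasizes high speed, big data and statistical analysis and sits at the intersection of computer science and mathematics.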
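CIFAR-10 consists of 60000 RGB color images of 32*32 pixels in 10 classes: 50000 training images and 10000 test images (the example below uses the test set as its validation data). The most notable feature of this dataset is that it moves recognition to general objects and to a multi-class setting (its sister dataset CIFAR-100 has 100 classes, and the ILSVRC competition has 1000).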
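Compared with face recognition, which is already mature, general object recognition is far more challenging: the data contains a large number of features and a lot of noise, and the objects appear at varying scales. CIFAR-10 is therefore considerably harder than traditional image recognition datasets. For more information, see the CIFAR-10 page and Alex Krizhevsky's technical report.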
PCA and ZCA explained in detail
Keras ImageDataGenerator parameters explained in detail
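As a rough illustration of what the zca_whitening option in the annotated code refers to, here is a minimal NumPy sketch of ZCA whitening on toy data. The function name zca_whiten, the eps value and the toy mixing matrix are assumptions made for this example; it is not the Keras implementation, only the standard construction W = U diag(1/sqrt(s + eps)) U^T built from the SVD of the feature covariance.

```python
import numpy as np


def zca_whiten(x, eps=1e-6):
    """Return (whitened data, whitening matrix) for x of shape (samples, features).

    Hypothetical helper for illustration only.
    """
    x = x - x.mean(axis=0)                 # center every feature
    cov = x.T @ x / x.shape[0]             # feature covariance matrix
    u, s, _ = np.linalg.svd(cov)           # cov = U diag(s) U^T
    w = u @ np.diag(1.0 / np.sqrt(s + eps)) @ u.T  # symmetric ZCA matrix
    return x @ w, w


# Toy correlated data: 1000 samples with 3 mixed features (made-up values).
rng = np.random.default_rng(0)
mix = np.array([[2.0, 0.5, 0.0],
                [0.0, 1.0, 0.3],
                [0.0, 0.0, 0.5]])
data = rng.normal(size=(1000, 3)) @ mix

white, w = zca_whiten(data)
cov_white = white.T @ white / white.shape[0]
print(np.round(cov_white, 3))  # close to the 3x3 identity matrix
```

Unlike PCA used for dimensionality reduction, this transform keeps the same number of features and rotates back into the original axes, so whitened images still look like images. On real CIFAR-10 data each image would first be flattened to a vector of 32*32*3 = 3072 values.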
Annotated code
'''Train a simple deep CNN on the CIFAR10 small images dataset.

It gets to 75% validation accuracy in 25 epochs, and 79% after 50 epochs.
(It's still underfitting at that point, though: training accuracy is still low.)
'''

from __future__ import print_function
import keras
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
import os

batch_size = 32  # number of samples per batch
num_classes = 10  # 10 classes
epochs = 100  # 100 epochs
data_augmentation = True
num_predictions = 20
save_dir = os.path.join(os.getcwd(), 'saved_models')  # directory for the trained model
model_name = 'keras_cifar10_trained_model.h5'  # file name of the trained model

# The data, shuffled and split between train and test sets:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert class vectors to binary class matrices (one-hot encoding).
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# Build the CNN model with Keras.
model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same',
                 input_shape=x_train.shape[1:]))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(Activation('softmax'))

# Initiate the RMSprop (root mean square propagation) optimizer.
opt = keras.optimizers.rmsprop(lr=0.0001, decay=1e-6)

# Let's train the model using RMSprop.
model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])

# Scale pixel values from [0, 255] to [0, 1].
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

if not data_augmentation:
    print('Not using data augmentation.')
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(x_test, y_test),
              shuffle=True)
else:
    print('Using real-time data augmentation.')
    # This will do preprocessing and realtime data augmentation
    # (image transforms such as shifts and flips that enlarge the
    # effective number of training samples):
    datagen = ImageDataGenerator(
        featurewise_center=False,  # set input mean to 0 over the dataset
        samplewise_center=False,  # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by std of the dataset
        samplewise_std_normalization=False,  # divide each input by its std
        zca_whitening=False,  # apply ZCA (Zero-phase Component Analysis) whitening
        rotation_range=0,  # randomly rotate images in the range (degrees, 0 to 180)
        width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
        height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
        horizontal_flip=True,  # randomly flip images horizontally
        vertical_flip=False)  # randomly flip images vertically

    # Compute quantities required for feature-wise normalization
    # (std, mean, and principal components if ZCA whitening is applied).
    datagen.fit(x_train)

    # Fit the model on the batches generated by datagen.flow().
    model.fit_generator(datagen.flow(x_train, y_train,
                                     batch_size=batch_size),
                        epochs=epochs,
                        validation_data=(x_test, y_test),
                        workers=4)

# Save model and weights.
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_path = os.path.join(save_dir, model_name)
model.save(model_path)
print('Saved trained model at %s ' % model_path)

# Score trained model.
scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
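The layer stack in the code above can be traced by hand to see how the 32*32*3 input shrinks before the dense layers. The following is a small sketch in plain Python (no Keras needed); the helper conv_out and the layers list are names made up for this illustration, and it assumes stride-1 3*3 convolutions and non-overlapping 2*2 pooling, which are the Keras defaults used in the example.

```python
def conv_out(size, kernel, padding):
    """Spatial output size of a stride-1 convolution (hypothetical helper)."""
    if padding == 'same':
        return size               # 'same' pads so the size is unchanged
    return size - kernel + 1      # 'valid' (the Keras default) trims the border


# (kind, filters or pool size, padding), mirroring the model in the example
layers = [
    ('conv', 32, 'same'),
    ('conv', 32, 'valid'),
    ('pool', 2, None),
    ('conv', 64, 'same'),
    ('conv', 64, 'valid'),
    ('pool', 2, None),
]

size, channels, params = 32, 3, 0
for kind, arg, padding in layers:
    if kind == 'conv':
        # each filter has 3*3*channels weights plus one bias
        params += (3 * 3 * channels + 1) * arg
        size, channels = conv_out(size, 3, padding), arg
    else:
        size //= arg              # 2x2 max pooling halves the size, rounding down
    print(kind, (size, size, channels))

flat = size * size * channels     # output of Flatten()
params += (flat + 1) * 512        # Dense(512)
params += (512 + 1) * 10          # Dense(num_classes)
print('flattened features:', flat)
print('trainable parameters:', params)
```

The spatial size goes 32, 32, 30, 15, 15, 13, 6, so Flatten() yields 6*6*64 = 2304 features, and most of the parameters sit in the first dense layer. Activation, Dropout and Flatten layers add no parameters.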
Execution output
C:\ProgramData\Anaconda3\python.exe E:/keras-master/examples/cifar10_cnn.py
Using TensorFlow backend.
x_train shape: (50000, 32, 32, 3)
50000 train samples
10000 test samples
Using real-time data augmentation.
Epoch 1/100
1/1563 [..............................] - ETA: 1:03:06 - loss: 2.2554 - acc: 0.1250
4/1563 [..............................] - ETA: 16:11 - loss: 2.3237 - acc: 0.0781
7/1563 [..............................] - ETA: 9:27 - loss: 2.3296 - acc: 0.0759
9/1563 [..............................] - ETA: 7:31 - loss: 2.3343 - acc: 0.1007
11/1563 [..............................] - ETA: 6:16 - loss: 2.3295 - acc: 0.1080
14/1563 [..............................] - ETA: 5:02 - loss: 2.3278 - acc: 0.1004
17/1563 [..............................] - ETA: 4:15 - loss: 2.3202 - acc: 0.1048
19/1563 [..............................] - ETA: 3:52 - loss: 2.3207 - acc: 0.1069
21/1563 [..............................] - ETA: 3:33 - loss: 2.3159 - acc: 0.1131
23/1563 [..............................] - ETA: 3:18 - loss: 2.3157 - acc: 0.1114
26/1563 [..............................] - ETA: 2:58 - loss: 2.3108 - acc: 0.1142
29/1563 [..............................] - ETA: 2:43 - loss: 2.3086 - acc: 0.1207
32/1563 [..............................] - ETA: 2:31 - loss: 2.3099 - acc: 0.1191
35/1563 [..............................] - ETA: 2:20 - loss: 2.3080 - acc: 0.1187
38/1563 [..............................] - ETA: 2:11 - loss: 2.3085 - acc: 0.1143
41/1563 [..............................] - ETA: 2:04 - loss: 2.3073 - acc: 0.1151
44/1563 [..............................] - ETA: 1:57 - loss: 2.3085 - acc: 0.1122
47/1563 [..............................] - ETA: 1:52 - loss: 2.3086 - acc: 0.1117
50/1563 [..............................] - ETA: 1:46 - loss: 2.3073 - acc: 0.1106
53/1563 [>.............................] - ETA: 1:42 - loss: 2.3060 - acc: 0.1132
55/1563 [>.............................] - ETA: 1:40 - loss: 2.3048 - acc: 0.1131
58/1563 [>.............................] - ETA: 1:36 - loss: 2.3030 - acc: 0.1148
61/1563 [>.............................] - ETA: 1:33 - loss: 2.3008 - acc: 0.1158
64/1563 [>.............................] - ETA: 1:30 - loss: 2.2996 - acc: 0.1157
67/1563 [>.............................] - ETA: 1:27 - loss: 2.2991 - acc: 0.1161
70/1563 [>.............................] - ETA: 1:24 - loss: 2.2974 - acc: 0.1192
73/1563 [>.............................] - ETA: 1:22 - loss: 2.2962 - acc: 0.1194
76/1563 [>.............................] - ETA: 1:19 - loss: 2.2938 - acc: 0.1217
79/1563 [>.............................] - ETA: 1:17 - loss: 2.2939 - acc: 0.1222
82/1563 [>.............................] - ETA: 1:16 - loss: 2.2936 - acc: 0.1208
85/1563 [>.............................] - ETA: 1:14 - loss: 2.2913 - acc: 0.1235
88/1563 [>.............................] - ETA: 1:12 - loss: 2.2900 - acc: 0.1250
91/1563 [>.............................] - ETA: 1:11 - loss: 2.2896 - acc: 0.1257
94/1563 [>.............................] - ETA: 1:09 - loss: 2.2884 - acc: 0.1260
97/1563 [>.............................] - ETA: 1:08 - loss: 2.2879 - acc: 0.1263
100/1563 [>.............................] - ETA: 1:07 - loss: 2.2874 - acc: 0.1269
103/1563 [>.............................] - ETA: 1:06 - loss: 2.2865 - acc: 0.1271
106/1563 [=>............................] - ETA: 1:04 - loss: 2.2856 - acc: 0.1282
109/1563 [=>............................] - ETA: 1:03 - loss: 2.2845 - acc: 0.1313
112/1563 [=>............................] - ETA: 1:02 - loss: 2.2837 - acc: 0.1323
115/1563 [=>............................] - ETA: 1:01 - loss: 2.2831 - acc: 0.1318
118/1563 [=>............................] - ETA: 1:00 - loss: 2.2823 - acc: 0.1329
121/1563 [=>............................] - ETA: 1:00 - loss: 2.2808 - acc: 0.1343
124/1563 [=>............................] - ETA: 59s - loss: 2.2800 - acc: 0.1343
127/1563 [=>............................] - ETA: 58s - loss: 2.2800 - acc: 0.1341
130/1563 [=>............................] - ETA: 57s - loss: 2.2788 - acc: 0.1341
133/1563 [=>............................] - ETA: 56s - loss: 2.2770 - acc: 0.1346
136/1563 [=>............................] - ETA: 56s - loss: 2.2768 - acc: 0.1347
139/1563 [=>............................] - ETA: 55s - loss: 2.2760 - acc: 0.1356
141/1563 [=>............................] - ETA: 55s - loss: 2.2745 - acc: 0.1365
144/1563 [=>............................] - ETA: 54s - loss: 2.2740 - acc: 0.1365
147/1563 [=>............................] - ETA: 54s - loss: 2.2727 - acc: 0.1382
150/1563 [=>............................] - ETA: 53s - loss: 2.2702 - acc: 0.1402
153/1563 [=>............................] - ETA: 52s - loss: 2.2689 - acc: 0.1401
156/1563 [=>............................] - ETA: 52s - loss: 2.2657 - acc: 0.1416
159/1563 [==>...........................] - ETA: 51s - loss: 2.2636 - acc: 0.1425
...
8800/10000 [=========================>....] - ETA: 0s
9088/10000 [==========================>...] - ETA: 0s
9408/10000 [===========================>..] - ETA: 0s
9728/10000 [============================>.] - ETA: 0s
10000/10000 [==============================] - 2s 170us/step
Test loss: 0.739190111351
Test accuracy: 0.7629
Process finished with exit code 0
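Two numbers in the log above can be sanity-checked with a few lines of plain Python: the 1563 steps per epoch and the starting loss of roughly 2.3. This is only a back-of-the-envelope sketch; it assumes one step per batch with a smaller final batch, and an untrained softmax that spreads probability evenly over the 10 classes.

```python
import math

train_samples = 50000   # from "50000 train samples" in the log
batch_size = 32         # batch_size in the example code
num_classes = 10

# One step per batch; the last batch holds the leftover samples.
steps_per_epoch = math.ceil(train_samples / batch_size)

# Cross-entropy of a uniform guess over 10 classes is -ln(1/10) = ln(10).
chance_loss = math.log(num_classes)
chance_acc = 1 / num_classes

print(steps_per_epoch)          # 1563
print(round(chance_loss, 4))    # 2.3026
print(chance_acc)               # 0.1
```

The first progress lines report a loss between about 2.26 and 2.33 and an accuracy near 0.10, which matches chance level, so the network starts from an essentially uninformed state and improves from there.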
Keras documentation
Chinese: http://keras-cn.readthedocs.io/en/latest/
Example downloads
https://github.com/keras-team/keras
https://github.com/keras-team/keras/tree/master/examples