刚开始学习深度学习,用 MNIST 数据集练手。最初自己定义了一个包含 3 个卷积层(每层为卷积 + 池化)和 1 个全连接(fc)层的网络,跑出来的准确率居然只有 50%,太低了,一脸懵逼。后来想到,很可能是训练的 batch 数太少(只有 10 * 10 = 100 个 batch):最初每个 batch 有 5000 个样本,而网络参数每个 batch 才更新一次,只更新 100 次,准确率肯定低到没朋友(训练量远没有达到拟合数据所需的数量级)。现在已把训练的 batch 数改成了 500(对应代码中的 train_loops,每个 batch 256 个样本)。另外还有学习率的调参:现在使用的是 0.01,也尝试过 0.02 和 0.05,在其他参数不变的情况下,效果都不如 0.01 好;更系统的调参,后续打算用网格搜索法测试一下。代码如下:
import tensorflow as tf
import argparse
from tensorflow.examples.tutorials.mnist import input_data
from sklearn import metrics
import numpy as np
class MnistCnnMethod:
    """Three conv/pool stages plus one fully connected layer for MNIST (TF 1.x graph mode).

    Constructing an instance adds the whole network to the default TensorFlow
    graph and exposes these attributes:

    * ``x``          -- float32 placeholder, shape [None, 28, 28, 1] (input images)
    * ``y``          -- float32 placeholder, shape [None, 10] (labels for the loss)
    * ``keep_prob``  -- float32 placeholder, dropout keep probability
    * ``predict``    -- argmax over the 10 output logits
    * ``train_step`` -- Adam update that minimises the mean softmax cross-entropy
    * ``merged``     -- merged summary op (includes the scalar summary of the cost)

    Args:
        learning_rate: Step size handed to ``tf.train.AdamOptimizer``. Defaults
            to 0.01, the value that used to be hard-coded, so ``MnistCnnMethod()``
            keeps working; exposing it makes learning-rate sweeps possible
            without editing the class.
    """

    def __init__(self, learning_rate=0.01):
        self.learning_rate = learning_rate
        self.x = tf.placeholder(tf.float32, [None, 28, 28, 1], name='x')
        self.y = tf.placeholder(tf.float32, [None, 10], name='y')
        self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
        self.cnn()

    @staticmethod
    def _he_stddev(shape):
        """Return sqrt(2 / fan_in), where fan_in is the product of all dims but the last."""
        fan_in = 1
        for dim in shape[:-1]:
            fan_in *= dim
        return (2.0 / fan_in) ** 0.5

    @classmethod
    def _weight(cls, shape, name):
        """Create a trainable weight variable with fan-in scaled normal initialisation.

        The previous code used ``tf.random_normal(shape)`` with its default
        stddev of 1.0. With fan-ins between 9 and 1024 and ReLU activations
        that makes the activations grow at every layer and yields saturated
        logits right from the first step.
        """
        initial = tf.random_normal(shape, stddev=cls._he_stddev(shape))
        return tf.Variable(initial, trainable=True, name=name)

    @staticmethod
    def _conv_relu_pool(inputs, kernel):
        """Apply a stride-1 SAME convolution, ReLU, then 2x2 max pooling with stride 2."""
        conv = tf.nn.relu(tf.nn.conv2d(inputs, kernel, strides=[1, 1, 1, 1], padding='SAME'))
        return tf.nn.max_pool(conv, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

    def cnn(self):
        """Build the network, loss, optimiser and summary ops on the default graph."""
        with tf.name_scope('init_weights'):
            # Variable names and shapes are unchanged, so existing checkpoints
            # remain loadable.
            w1 = self._weight([3, 3, 1, 32], 'first_layer_w')
            w2 = self._weight([3, 3, 32, 32], 'second_layer_w')
            w3 = self._weight([3, 3, 32, 64], 'third_layer_w')
            # Spatial size after three SAME stride-2 poolings: 28 -> 14 -> 7 -> 4.
            w4 = self._weight([64 * 4 * 4, 1024], 'full_connect_w')
            w_o = self._weight([1024, 10], 'softmax_out_w')

        with tf.name_scope('first_layer'):
            lp1 = tf.nn.dropout(self._conv_relu_pool(self.x, w1), self.keep_prob)

        with tf.name_scope('second_layer'):
            lp2 = tf.nn.dropout(self._conv_relu_pool(lp1, w2), self.keep_prob)

        with tf.name_scope('third_layer'):
            lp3 = self._conv_relu_pool(lp2, w3)
            # Flatten to [batch, 64*4*4] so it can be multiplied by w4.
            lp3 = tf.reshape(lp3, [-1, w4.get_shape().as_list()[0]])
            lp3 = tf.nn.dropout(lp3, self.keep_prob)

        with tf.name_scope('full_connect'):
            lc4 = tf.nn.relu(tf.matmul(lp3, w4))
            ld4 = tf.nn.dropout(lc4, self.keep_prob)
            # Raw logits; the softmax is applied inside the loss op.
            pyx = tf.matmul(ld4, w_o)

        with tf.name_scope('predict'):
            self.predict = tf.argmax(pyx, 1)

        with tf.name_scope('loss'):
            entropy = tf.nn.softmax_cross_entropy_with_logits_v2(logits=pyx, labels=self.y)
            cost = tf.reduce_mean(entropy, name='cost')
            self.train_step = tf.train.AdamOptimizer(
                learning_rate=self.learning_rate
            ).minimize(cost, var_list=tf.trainable_variables())
            tf.summary.scalar('cost', cost)

        with tf.name_scope('summary'):
            self.merged = tf.summary.merge_all()
def train(mnistD, model):
    """Train ``model`` on mini-batches and save a checkpoint to ``model/model.ckpt``.

    Runs 500 optimisation steps with batches of 256 samples and a dropout keep
    probability of 0.9. The cost summary of every step is written to ``./log``
    and the step index is printed as a progress indicator.

    Args:
        mnistD: Dataset object; ``mnistD.train.next_batch(batch_size)`` must
            return an ``(images, labels)`` pair whose images can be reshaped
            to [-1, 28, 28, 1].
        model: Graph wrapper exposing ``x``, ``y``, ``keep_prob``, ``merged``
            and ``train_step``.
    """
    import os  # local import so this function does not rely on a file-level one

    batch_size = 256
    keep_prob = 0.9
    train_loops = 500
    checkpoint_path = 'model/model.ckpt'

    with tf.Session() as sess:
        saver = tf.train.Saver()
        writer = tf.summary.FileWriter('./log', sess.graph)
        try:
            sess.run(tf.global_variables_initializer())
            for i in range(train_loops):
                print(i)
                batch_x, batch_y = mnistD.train.next_batch(batch_size)
                batch_x = np.reshape(batch_x, [-1, 28, 28, 1])
                # Every step, including the last one, is logged; previously the
                # final step skipped the summary for no functional reason.
                summary, _ = sess.run(
                    [model.merged, model.train_step],
                    feed_dict={model.x: batch_x, model.y: batch_y, model.keep_prob: keep_prob},
                )
                writer.add_summary(summary, i)

            # Saving can fail when the parent directory is missing, so create it first.
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            saver.save(sess, checkpoint_path)
        finally:
            # Flush buffered events and release the file handle even on failure.
            writer.close()
def test(mnistD, model):
    """Restore ``model/model.ckpt`` and print a classification report for the test split.

    Args:
        mnistD: Dataset object providing ``test.images`` and one-hot ``test.labels``.
        model: Graph wrapper exposing ``x``, ``keep_prob`` and ``predict``.
    """
    images = np.reshape(mnistD.test.images, [-1, 28, 28, 1])
    # Turn every one-hot row into the index of its single 1 entry.
    true_labels = [np.where(row == 1)[0][0] for row in mnistD.test.labels]

    with tf.Session() as sess:
        restorer = tf.train.Saver()
        restorer.restore(sess=sess, save_path='model/model.ckpt')
        # Dropout is disabled for evaluation (keep probability 1.0).
        feed = {model.x: images, model.keep_prob: 1.}
        predicted = sess.run(model.predict, feed_dict=feed)

    print(metrics.classification_report(true_labels, predicted))
if __name__ == '__main__':
    # Command-line options: where the MNIST files live and which phase to run.
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='./mnist',
                        help='Directory for storing input data')
    parser.add_argument('--type', type=str, default='train',
                        help='input train or test')
    # Unrecognised arguments are accepted and discarded.
    FLAGS, _ = parser.parse_known_args()

    # The graph is built first, then the data set is loaded with one-hot labels.
    cnnModel = MnistCnnMethod()
    mnistData = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

    # Any value other than 'train' selects the evaluation path.
    run_phase = train if FLAGS.type == 'train' else test
    run_phase(mnistData, cnnModel)
cnn 结构:
另外,这里才发现 numpy 原来也有 reshape 功能(np.reshape),汗,之前一直不知道。