Mini project: training on an art-style dataset

Continuing from the previous post: with the data already preprocessed, we now define the network.

Define the input placeholders with tf.placeholder

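import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.layers, tf.Session)
datas_placeholder = tf.placeholder(tf.float32, [None, 32, 32, 3])  # batch of 32x32 RGB images; None = any batch size
labels_placeholder = tf.placeholder(tf.int32, [None])  # one integer class id per image
dropout_placeholder = tf.placeholder(tf.float32)  # dropout rate: 0.1 while training, 0 while testing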
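The value fed into dropout_placeholder is a dropout *rate*, i.e. the probability of zeroing a unit. As a rough illustration, here is a minimal pure-Python sketch of inverted dropout, the scheme tf.nn.dropout uses: kept values are scaled by 1/(1 - rate) so their expected value is unchanged, and a rate of 0 leaves the input untouched. The helper name `inverted_dropout` is hypothetical and not part of TensorFlow.

```python
import random


def inverted_dropout(values, rate, rng):
    """Zero each value with probability `rate`; scale survivors by 1 / (1 - rate)."""
    if rate == 0:
        # Test time (rate 0): dropout is a no-op, every unit is kept.
        return list(values)
    keep_prob = 1.0 - rate
    # Each unit survives independently with probability keep_prob.
    return [v / keep_prob if rng.random() < keep_prob else 0.0 for v in values]


rng = random.Random(0)  # fixed seed so the sketch is reproducible
activations = [1.0, 2.0, 3.0, 4.0]
print(inverted_dropout(activations, 0.1, rng))  # training-style rate
print(inverted_dropout(activations, 0.0, rng))  # testing-style rate
```

The rescaling is why the same network can be used at test time without any correction: with rate 0 the layer simply passes activations through.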

Define the network structure with tf.layers

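# Two conv + max-pool stages (tf.layers uses 'valid' padding by default)
conv0 = tf.layers.conv2d(datas_placeholder, 20, 5, activation=tf.nn.relu)
pool0 = tf.layers.max_pooling2d(conv0, [2, 2], [2, 2])
conv1 = tf.layers.conv2d(pool0, 40, 4, activation=tf.nn.relu)
pool1 = tf.layers.max_pooling2d(conv1, [2, 2], [2, 2])
flatten = tf.layers.flatten(pool1)
fc = tf.layers.dense(flatten, 400, activation=tf.nn.relu)
# tf.layers.dropout does nothing unless training=True (the default is False),
# so pass training=True and control dropout through the fed rate (0 at test time).
dropout_fc = tf.layers.dropout(fc, rate=dropout_placeholder, training=True)
# num_classes comes from the data-preparation step in the previous post
logits = tf.layers.dense(dropout_fc, num_classes)
predicted_labels = tf.argmax(logits, 1)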
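To check what the layer stack above produces, the following pure-Python sketch walks through the output shape and trainable-parameter count of each layer. It assumes 32x32x3 inputs, 'valid' padding with stride 1 for the convolutions and stride 2 for the 2x2 pooling (as in the code), and num_classes = 3 (the three labels used in the test section). The helpers `valid_out` and `describe_network` are hypothetical names for illustration only.

```python
def valid_out(size, kernel, stride):
    """Output length along one axis for 'valid' padding (no zero padding)."""
    return (size - kernel) // stride + 1


def describe_network(h=32, w=32, c=3, num_classes=3):
    """Return (name, output_shape, num_params) for each layer of the network."""
    rows = []
    for name, kind, k, filters in [
        ("conv0", "conv", 5, 20),
        ("pool0", "pool", 2, None),
        ("conv1", "conv", 4, 40),
        ("pool1", "pool", 2, None),
    ]:
        if kind == "conv":
            params = k * k * c * filters + filters  # kernel weights + one bias per filter
            h, w, c = valid_out(h, k, 1), valid_out(w, k, 1), filters
        else:
            params = 0  # max pooling has no trainable weights
            h, w = valid_out(h, k, 2), valid_out(w, k, 2)
        rows.append((name, (h, w, c), params))
    flat = h * w * c
    rows.append(("flatten", (flat,), 0))
    rows.append(("fc", (400,), flat * 400 + 400))
    rows.append(("logits", (num_classes,), 400 * num_classes + num_classes))
    return rows


for name, shape, params in describe_network():
    print("{:8s} {:14s} {}".format(name, str(shape), params))
print("total params:", sum(p for _, _, p in describe_network()))
```

Under these assumptions the flattened feature vector has 5 * 5 * 40 = 1000 entries, so the 400-unit dense layer holds by far the most parameters.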

Define the loss function and optimizer

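# Per-example softmax cross-entropy against one-hot encoded labels
losses = tf.nn.softmax_cross_entropy_with_logits_v2(
    labels=tf.one_hot(labels_placeholder, num_classes),
    logits=logits
)
mean_loss = tf.reduce_mean(losses)
# Minimize the batch-mean loss (the same value printed during training)
optimizer = tf.train.AdamOptimizer(learning_rate=1e-2).minimize(mean_loss)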
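For intuition about what the loss computes, here is a minimal pure-Python sketch of softmax cross-entropy against a one-hot label, averaged over a batch in the same way tf.reduce_mean averages the per-example losses. The function names are hypothetical; the log-sum-exp with the maximum subtracted is a standard trick for numerical stability, not necessarily how TensorFlow implements it internally.

```python
import math


def softmax_cross_entropy(logits, label, num_classes):
    """Cross-entropy between a one-hot label and softmax(logits)."""
    one_hot = [1.0 if i == label else 0.0 for i in range(num_classes)]
    # log(sum(exp(z))) computed stably by subtracting the largest logit first
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    log_probs = [z - log_sum_exp for z in logits]
    return -sum(t * lp for t, lp in zip(one_hot, log_probs))


def mean_cross_entropy(batch_logits, batch_labels, num_classes):
    """Average the per-example losses over the batch."""
    losses = [softmax_cross_entropy(z, y, num_classes)
              for z, y in zip(batch_logits, batch_labels)]
    return sum(losses) / len(losses)


# Uniform logits over 3 classes give a loss of ln(3) whatever the label is.
print(mean_cross_entropy([[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [0, 1], 3))
```

Because the label is one-hot, only the log-probability of the true class contributes, so the loss shrinks as the network assigns that class more probability.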

Save the model, then train or test

# train, datas, labels, fpaths and model_path come from the data-preparation step (previous post)
saver = tf.train.Saver()
with tf.Session() as sess:
    if train:
        print("Training")
        sess.run(tf.global_variables_initializer())
        # Full-batch training: the whole dataset is fed at every step
        train_feed_dict = {
            datas_placeholder: datas,
            labels_placeholder: labels,
            dropout_placeholder: 0.1
        }
        for step in range(200):
            _, mean_loss_val = sess.run([optimizer, mean_loss], feed_dict=train_feed_dict)
            if step % 10 == 0:
                print("step = {}\tmean loss = {}".format(step, mean_loss_val))
        saver.save(sess, model_path)
        print("Training finished, model saved to {}".format(model_path))
    else:
        print("Testing")
        # Restore the trained weights instead of initializing new ones
        saver.restore(sess, model_path)
        print("Loaded model from {}".format(model_path))
        label_name_dict = {
            0: "graffiti",
            1: "oil painting",
            2: "sketch"
        }
        # Dropout rate 0 at test time: every unit is kept
        test_feed_dict = {
            datas_placeholder: datas,
            labels_placeholder: labels,
            dropout_placeholder: 0
        }
        predicted_labels_val = sess.run(predicted_labels, feed_dict=test_feed_dict)
        # Print each image path with its true and predicted style
        for fpath, real_label, predicted_label in zip(fpaths, labels, predicted_labels_val):
            real_label_name = label_name_dict[real_label]
            predicted_label_name = label_name_dict[predicted_label]
            print("{}\t{} => {}".format(fpath, real_label_name, predicted_label_name))
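The test branch prints one line per image. To condense that output into a single number, a small helper could report overall accuracy plus per-class hit counts. The sketch below is hypothetical (the name `summarize_predictions` is not from the original post); in the test branch it could be called as summarize_predictions(labels, predicted_labels_val). The toy labels in the usage lines are made up purely for illustration.

```python
from collections import Counter


def summarize_predictions(real_labels, predicted_labels):
    """Return overall accuracy and {class_id: (correct, total)}."""
    pairs = list(zip(real_labels, predicted_labels))
    if not pairs:
        return 0.0, {}
    # int() also accepts NumPy integer types such as the output of argmax
    totals = Counter(int(r) for r, _ in pairs)
    hits = Counter(int(r) for r, p in pairs if r == p)
    accuracy = sum(hits.values()) / len(pairs)
    per_class = {k: (hits[k], totals[k]) for k in sorted(totals)}
    return accuracy, per_class


# Illustrative toy labels, not real results
acc, per_class = summarize_predictions([0, 0, 1, 2], [0, 1, 1, 2])
print("accuracy = {:.2f}".format(acc))
for class_id, (correct, total) in per_class.items():
    print("class {}: {}/{}".format(class_id, correct, total))
```

Note that evaluating on the same images used for training, as the post does, measures how well the model fits the training set rather than how it generalizes.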

Results
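[Figure: training results]
The results look pretty good. With this, the art-style recognition mini project is essentially complete (cue the confetti).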

Work in progress; further improvements to come.


Reposted from blog.csdn.net/qq_38484259/article/details/84971185