版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/zhangpengzp/article/details/85177780
softmax 是将线性分类函数 `y = x * w + b` 的输出转化为概率分布的函数,其定义为 $\mathrm{softmax}(z)_i = \dfrac{e^{z_i}}{\sum_j e^{z_j}}$。下面以使用 TensorFlow 对 MNIST 手写数字(0-9)图像进行训练和识别为例。为了简单,直接上截图:
下面是 TensorFlow 的代码。利用 `with tf.name_scope` 一个节点一个节点地写,有利于网络模型的可视化,也便于更好地维护和管理。
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import argparse
import sys
import os
# Set TensorFlow's native (C++) log threshold to level 2 — presumably this
# hides INFO and WARNING messages; confirm against the TensorFlow docs.
# NOTE(review): this assignment runs after `import tensorflow`; verify whether
# it must be set before the import to reliably take effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# main() first assembles the whole computation graph, then runs it in a session.
def main(data_dir='/datasets', log_dir='logs/mnist_softmax',
         num_steps=2000, batch_size=100, learning_rate=0.01):
    """Train and evaluate a softmax-regression classifier on MNIST.

    Builds a single-layer linear model (``logits = X * W + b``), writes the
    graph to a TensorBoard event file, trains with plain gradient descent on
    mini-batches, and finally prints the accuracy on the MNIST test set.

    Args:
        data_dir: Directory passed to ``input_data.read_data_sets``.
        log_dir: Directory the TensorBoard event file is written to.
        num_steps: Number of training mini-batches to run.
        batch_size: Number of samples per mini-batch.
        learning_rate: Step size of the gradient-descent optimizer.
    """
    print("******开始设计计算图*******")
    # Keep an explicit handle on the graph so the summary writer and the
    # session below are bound to it rather than to whatever graph happens
    # to be the default at that point.
    graph = tf.Graph()
    with graph.as_default():
        # Input: placeholders for flattened images and one-hot labels.
        with tf.name_scope('Input'):
            X = tf.placeholder(tf.float32, shape=[None, 784], name='X')
            Y_true = tf.placeholder(tf.float32, shape=[None, 10], name='Y_true')

        # Inference: linear model y = x*w + b.
        # NOTE: the scope name 'Interence' (sic) is kept unchanged because it
        # is the node name shown in TensorBoard.
        with tf.name_scope('Interence'):
            W = tf.Variable(tf.zeros([784, 10]), name='Weight')
            b = tf.Variable(tf.zeros([10]))
            logits = tf.add(tf.matmul(X, W), b)

        # SoftMax: turn the logits into a probability distribution.
        with tf.name_scope('SoftMax'):
            Y_pred = tf.nn.softmax(logits=logits)

        # Loss: mean cross-entropy over the batch. Computed from the raw
        # logits with the fused op; taking tf.log() of the softmax output
        # yields -inf/NaN as soon as a predicted probability underflows to 0.
        with tf.name_scope('Loss'):
            Loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits_v2(
                    labels=Y_true, logits=logits)
            )

        # Train: apply gradient descent to the loss.
        with tf.name_scope('Train'):
            Optimizer = tf.train.GradientDescentOptimizer(
                learning_rate=learning_rate)
            TrainStep = Optimizer.minimize(Loss)

        # Evaluate: fraction of samples whose arg-max prediction matches
        # the arg-max of the one-hot label.
        with tf.name_scope('Evaluate'):
            correct_prediction = tf.equal(tf.argmax(Y_pred, 1),
                                          tf.argmax(Y_true, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        # Initializer node for every Variable in the graph (W and b).
        initOp = tf.global_variables_initializer()

    print('把计算图写入事件文件,在tensorBoard里面查看')
    writer = tf.summary.FileWriter(logdir=log_dir, graph=graph)
    writer.close()

    print("******开始运行计算图*******")
    # Load the data set with one-hot encoded labels.
    mnist = input_data.read_data_sets(data_dir, one_hot=True)

    # The context manager guarantees the session is closed (and its
    # resources released) even if training raises.
    with tf.Session(graph=graph) as sess:
        # Initialize all variables (W, b).
        sess.run(initOp)

        # Train for `num_steps` mini-batches of `batch_size` samples each
        # (2000 batches of 100 samples with the defaults).
        for step in range(num_steps):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Feed the current batch into the placeholders and run one
            # optimization step, fetching the loss for logging.
            _, train_loss = sess.run([TrainStep, Loss],
                                     feed_dict={X: batch_xs,
                                                Y_true: batch_ys})
            print('train step:', step, ',train_loss s :', train_loss)

        # Evaluate the trained model on the full test split.
        accuracy_score = sess.run(accuracy,
                                  feed_dict={X: mnist.test.images,
                                             Y_true: mnist.test.labels})
        print("准确率 : ", accuracy_score)
# Script entry point: run the training pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == '__main__':
    main()