import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels
print(trainlabel[0]) #输出第一行label值
运行结果:
判别样本是否为对应类别,是则为1,否则为0,完成10分类任务,故上述样本的类别为9。
#初始化变量
x = tf.placeholder("float",[None,784]) #placeholder(先占位,不复制),[None,784]样本的个数(无限大),每个样本的特征(784个像素点)
y = tf.placeholder("float",[None,10])#样本的类别(10个)
W = tf.Variable(tf.zeros([784,10]))#每个特征(784个像素点)对应输出10个分类值
b = tf.Variable(tf.zeros([10]))
#逻辑回归模型(softmax完成多分类任务)
actv = tf.nn.softmax(tf.matmul(x,W)+b)#计算属于正确类别的概率值
#计算损失值(预测值与真实值间的均方差)
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv),reduction_indices=1))
#采用梯度下降优化参数(W,b),最小化损失值
learning_rate=0.01
optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) #学习率为0.01
#optimizer = tf.train.GradientDescentOptimizer(0.01)#学习率为0.01
#最小化损失值
#optm = optimizer.minimize(cost)
#预测值,equal返回的值是布尔类型
#argmax返回矩阵中最大元素的索引,0,代表列方向;1代表行方向
pred = tf.equal(tf.argmax(actv,1),tf.argmax(y,1))
#准确率
accr = tf.reduce_mean(tf.cast(pred,"float")) #cast进行类型转化 (true为1,false为0)
init_op = tf.global_variables_initializer()
#定义全局变量
training_epochs = 50 #将所有样本迭代50次
batch_size = 100 #每次迭代选择样本的个数
display_step =5 #每进行5个epoch进行一次展示
with tf.Session() as sess:
sess.run(init_op)
for epoch in range(training_epochs):
avg_cost =0.0 # 初始化损失值
num_batch = int(mnist.train.num_examples/batch_size)
for i in range(num_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size) #以batch为单位逐次进行
sess.run(optm,feed_dict={x: batch_xs,y: batch_ys}) #给x,y赋值
feeds={x: batch_xs,y: batch_ys}
avg_cost +=sess.run(cost,feed_dict= feeds)/num_batch
#显示
if epoch % display_step == 0:
feeds_train = {x: batch_xs,y: batch_ys}
feeds_test = {x:mnist.test.images,y: mnist.test.labels}
train_acc = sess.run(accr,feed_dict= feeds_train)
test_acc = sess.run(accr,feed_dict= feeds_test)
print("Epoch: %03d/%03d cost:%.9f trian_acc: %.3f test_acc: %.3f"
% (epoch,training_epochs,avg_cost,train_acc,test_acc))
print("Done")
运行结果:
由上图可知训练集的准确率为91%,测试集的准确率为91.8%。