Part 1: MNIST MLP classifier (train with checkpointing / evaluate from latest checkpoint)
# -*- coding: utf-8 -*-
"""
Created on Sun Jun 4 10:29:48 2017
@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)
n_input=784 #输入输出层神经元个数
learning_rate=0.001#学习率
def main(flag):
X = tf.placeholder(tf.float32, [None, n_input])
y_=tf.placeholder(tf.int32,[None,])
dense1 = tf.layers.dense(inputs=X,
units=1024,
activation=tf.nn.relu,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1,
units=512,
activation=tf.nn.relu,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2,
units=10,
activation=None,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
kernel_regularizer=tf.nn.l2_loss)
loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate).minimize(loss)
"""
equal--相等,tf.equal(A, B)是对比这两个矩阵或者向量的相等的元素,如果是相等的那就返回True,反正返回False,
返回的值的矩阵维度和A是一样的
tf.argmax(vector, 1):返回的是vector中的最大值的索引号
tf.cast(x,dtype,name=None) 将x的数据格式转化成dtype
"""
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_) #返回是否相等
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) #将bool值转化为浮点数
sess=tf.InteractiveSession() #sess=tf.Session()
sess.run(tf.global_variables_initializer())
saver=tf.train.Saver(max_to_keep=3)
#训练阶段
if flag==1:
max_acc=0
f=open('ckpt/acc.txt','w')
for i in range(100):
batch_xs, batch_ys = mnist.train.next_batch(256)
sess.run(train_op, feed_dict={X: batch_xs, y_: batch_ys}) #训练模型
# val验证阶段
val_loss,val_acc=sess.run([loss,acc], feed_dict={X: mnist.test.images, y_: mnist.test.labels})
print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
if val_acc>max_acc:
max_acc=val_acc
saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
f.close()
#验证阶段
else:
model_file=tf.train.latest_checkpoint('ckpt/')
print("Find latest_checkpoint OK:",model_file)
saver.restore(sess,model_file)
print("Start evaluate val-data................")
val_loss,val_acc=sess.run([loss,acc], feed_dict={X: mnist.test.images, y_: mnist.test.labels})
print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()
if __name__=="__main__":
main(0)
Part 2: MNIST single-hidden-layer autoencoder with reconstruction visualization
#!/usr/bin/env python
# -*- coding:utf-8 -*-
from __future__ import division, print_function, absolute_import
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
#导入mnist数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
learning_rate=0.01#学习率
n_hidden=256 #中间层神经元个数
n_input=784 #输入输出层神经元个数
X=tf.placeholder("float",[None,n_input])#待输入数据
#权重
weights={
'encoder_w':tf.Variable(tf.random_normal([n_input,n_hidden])),
'decoder_w':tf.Variable(tf.random_normal([n_hidden,n_input])),
}
#偏置
biases={
'encoder_b':tf.Variable(tf.random_normal([n_hidden])),
'decoder_b':tf.Variable(tf.random_normal([n_input])),
}
#编码过程
def encoder(x):
return tf.nn.sigmoid(tf.add(tf.matmul(x,weights['encoder_w']),biases['encoder_b']))
#解码过程
def decoder(x):
return tf.nn.sigmoid(tf.add(tf.matmul(x,weights['decoder_w']),biases['decoder_b']))
encoder_op=encoder(X)#编码开始
decoder_op=decoder(encoder_op)#解码开始
y_pred=decoder_op#预测值
y_ture=X#真实值(输入值)
loss=tf.reduce_mean(tf.pow(y_ture-y_pred,2))#代价函数(tensor均方误差),tf.pow(x,y)得到x的y次幂
train_op=tf.train.RMSPropOptimizer(learning_rate).minimize(loss)#优化过程
init=tf.initialize_all_variables()
saver=tf.train.Saver(max_to_keep=3)
with tf.Session() as sess:
sess.run(init)
for epoch in range(10):#训练十轮
for i in range(300):#每轮300次
batch_xs,batch_ys=mnist.train.next_batch(256)#每个步骤随机抓取256个批处理数据点
_,c=sess.run([train_op,loss],feed_dict={X:batch_xs})#运行
print("epoch",'%04d'%(epoch+1),"cost","{:.9f}".format(c))
saver.save(sess, 'ckpt-qxp/mnist.ckpt', global_step=epoch + 1)
print("Finished!")
#评估模型(可视化结果)
print("可视化结果")
encode_decode=sess.run(y_pred,feed_dict={X:mnist.test.images[:10]})
fig,ax=plt.subplots(2,10,figsize=(10,2))
for i in range(10):
ax[0][i].imshow(np.reshape(mnist.test.images[i], (28, 28)))
ax[1][i].imshow(np.reshape(encode_decode[i], (28, 28)))
#fig.show()
plt.subplot_tool()
plt.show()
#plt.draw()
print("all done")