"""Train a small fully connected network on MNIST and preview predictions with PIL.

Uses the TensorFlow 1.x graph API (placeholders, ``tf.Session``, ``tf.layers``,
``tensorflow.examples.tutorials``), so it requires a TensorFlow 1.x install.
"""
import numpy as np
import tensorflow as tf
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
from tensorflow.examples.tutorials.mnist import input_data

# MNIST with one-hot labels, read from ./MNIST_data.
mnist = input_data.read_data_sets('./MNIST_data', one_hot=True)


class Net:
    """784-512-256-10 fully connected classifier with batch norm and dropout.

    Build the graph by calling ``forward()`` and then ``backward()``.

    Mode switches (both have defaults, so feeding only ``x``/``y`` trains):
      * ``is_training`` -- bool scalar, default True. Batch normalization uses
        per-batch statistics when True and its moving averages when False.
      * ``keep_prob`` -- float scalar, default 0.5. Dropout keep probability;
        feed 1.0 to disable dropout when evaluating.
    """

    def __init__(self):
        self.x = tf.placeholder(shape=[None, 784], dtype=tf.float32)
        self.y = tf.placeholder(shape=[None, 10], dtype=tf.float32)
        self.is_training = tf.placeholder_with_default(True, shape=[])
        self.keep_prob = tf.placeholder_with_default(0.5, shape=[])

        self.w1 = tf.Variable(tf.truncated_normal(shape=[784, 512], dtype=tf.float32, stddev=0.1))
        self.b1 = tf.Variable(tf.zeros([512]))
        self.w2 = tf.Variable(tf.truncated_normal(shape=[512, 256], dtype=tf.float32, stddev=0.1))
        self.b2 = tf.Variable(tf.zeros([256]))
        self.w3 = tf.Variable(tf.truncated_normal(shape=[256, 10], dtype=tf.float32, stddev=0.1))
        self.b3 = tf.Variable(tf.zeros([10]))

    def forward(self):
        """Build the inference graph; sets ``y1``, ``y2``, ``logits`` and ``y3``."""
        h1 = tf.layers.batch_normalization(tf.matmul(self.x, self.w1) + self.b1,
                                           training=self.is_training)
        self.y1 = tf.nn.dropout(tf.nn.relu(h1), keep_prob=self.keep_prob)

        h2 = tf.layers.batch_normalization(tf.matmul(self.y1, self.w2) + self.b2,
                                           training=self.is_training)
        self.y2 = tf.nn.dropout(tf.nn.relu(h2), keep_prob=self.keep_prob)

        # Keep the unscaled logits: the cross-entropy loss must be computed from
        # these, not from the softmax probabilities in y3.
        self.logits = tf.layers.batch_normalization(tf.matmul(self.y2, self.w3) + self.b3,
                                                    training=self.is_training)
        self.y3 = tf.nn.softmax(self.logits)

    def backward(self):
        """Build loss, optimizer step and accuracy. Requires ``forward()`` first."""
        self.loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.y))
        # batch_normalization registers its moving-average updates in UPDATE_OPS;
        # they only run if the train op depends on them.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.opt = tf.train.AdamOptimizer(0.0001).minimize(self.loss)
        self.acc = tf.reduce_mean(tf.cast(
            tf.equal(tf.argmax(self.y3, axis=1), tf.argmax(self.y, axis=1)),
            dtype=tf.float32))


def _load_font(size=10):
    """Return the msyh.ttf font, or PIL's built-in font if it cannot be opened."""
    try:
        return ImageFont.truetype(font='msyh.ttf', size=size)
    except OSError:
        # msyh.ttf may not be installed on every system; don't crash the run.
        return ImageFont.load_default()


def _show_prediction(pixels, label, font):
    """Render one flattened 28x28 image with ``label`` drawn in its top-left corner."""
    # assumes pixel values are in [0, 1] (the original scaled by 255) -- clip to be safe
    gray = np.clip(np.reshape(pixels, [28, 28]) * 255, 0, 255).astype(np.uint8)
    img = Image.fromarray(gray)  # uint8 2-D array -> grayscale ("L") image
    draw = ImageDraw.Draw(img)
    draw.text(xy=(0, 0), text=str(label), font=font, fill=255)
    img.show()


if __name__ == '__main__':
    net = Net()
    net.forward()
    net.backward()
    init = tf.global_variables_initializer()
    font = _load_font(10)  # load once instead of on every evaluation
    with tf.Session() as sess:
        sess.run(init)
        for step in range(1000):
            x, y = mnist.train.next_batch(100)
            loss, _ = sess.run([net.loss, net.opt], feed_dict={net.x: x, net.y: y})
            if step % 100 == 0:
                xs, ys = mnist.validation.next_batch(100)
                # Evaluate the validation batch in inference mode and without the
                # optimizer, so the shown image and its predicted label match.
                error, acc, out = sess.run(
                    [net.loss, net.acc, net.y3],
                    feed_dict={net.x: xs, net.y: ys,
                               net.is_training: False, net.keep_prob: 1.0})
                print('step {}: train loss {:.4f}, val loss {:.4f}, val acc {:.4f}'.format(
                    step, loss, error, acc))
                label = np.argmax(out[0])
                _show_prediction(xs[0], label, font)
tensorflow识别mnist并用PIL显示
猜你喜欢
转载自blog.csdn.net/weixin_38241876/article/details/89533082
今日推荐
周排行