# mnist_backward.py — training loop (backward pass) for the LeNet MNIST model.

import os

import numpy as np
import tensorflow as tf

import input_data
import lenet_forward

# Training hyperparameters.
BATCH_SIZ = 100  # mini-batch size (NOTE(review): name has a typo — "BATCH_SIZE" — kept for compatibility)
REG = 0.0001  # L2 regularization weight forwarded to lenet_forward.forward
STEPS = 20000  # total number of training iterations


def backward(mnist):
    """Train the LeNet network on MNIST with Adam and checkpoint periodically.

    Args:
        mnist: dataset object exposing ``train.next_batch(batch_size)``
            returning (images, one-hot labels), e.g. the object returned
            by ``input_data.read_data_sets``.

    Side effects:
        Prints the loss every 1000 steps and writes checkpoints under
        ``./model/``.
    """
    # LeNet consumes NHWC image tensors; batches arrive as flat 784-vectors
    # and are reshaped before feeding, hence the fixed-shape placeholder.
    x = tf.placeholder(tf.float32, shape=(BATCH_SIZ, 28, 28, 1))
    y_ = tf.placeholder(tf.float32, shape=(None, 10))
    # Third arg True: training mode — presumably enables dropout in the
    # forward pass (confirm against lenet_forward.forward).
    y = lenet_forward.forward(x, REG, True)

    # Cross-entropy averaged over the batch, plus the regularization terms
    # the forward pass registered in the 'losses' collection.
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_))
    loss = cross_entropy + tf.add_n(tf.get_collection('losses'))
    global_step = tf.Variable(0, trainable=False)

    train_step = tf.train.AdamOptimizer(1e-4).minimize(
        loss, global_step=global_step)
    saver = tf.train.Saver()

    # BUG FIX: the original passed the bare directory './model/' as the
    # checkpoint prefix, yielding files with an empty basename (e.g.
    # '-1000.index') and failing outright if the directory was missing.
    # Use a proper model-name prefix and ensure the directory exists.
    os.makedirs('./model', exist_ok=True)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZ)
            reshaped_xs = np.reshape(xs, (BATCH_SIZ, 28, 28, 1))
            _, lossval, step = sess.run(
                [train_step, loss, global_step],
                feed_dict={x: reshaped_xs, y_: ys})
            if step % 1000 == 0:
                print('step=%d ,loss = %g' % (step, lossval))
                saver.save(sess, './model/mnist_model',
                           global_step=global_step)

if __name__ == '__main__':
    # Load MNIST (one-hot labels; downloads on first run) and train.
    dataset = input_data.read_data_sets("MNIST_data/", one_hot=True)
    backward(dataset)


# Adapted from: blog.csdn.net/u010795146/article/details/81907343