TensorFlow basic template (TF 1.x)

A minimal training skeleton: build the graph, restore the latest checkpoint or initialize the variables, then run a training loop that periodically writes checkpoints and TensorBoard summaries. The network, placeholders, and loss are filled in with a small dummy regression so the template runs as-is; swap them for your own model and data.

import tensorflow as tf
import numpy as np
import os

# hyperparameters
batch_size = 32
input_dim = 10  # feature size of the dummy inputs below; set to your real input size
save_path = 'model'
max_train_step = 10000
lr = 0.001

# network structure
class nn(object):
    def __init__(self, inputs, labels, name='nn', training=True, reuse=False):
        with tf.variable_scope(name, reuse=reuse):
            self.global_step = tf.Variable(0, trainable=False, dtype=tf.int32)
            # minimal dummy network (one dense layer + MSE loss); replace with your own model
            out = tf.layers.dense(inputs, 1, name='fc')
            self.loss = tf.reduce_mean(tf.square(out - labels))

    def summary(self):
        # register summaries so tf.summary.merge_all() has something to merge
        tf.summary.scalar('loss', self.loss)

# placeholders (dummy shapes for the example; replace with your real inputs)
x = tf.placeholder(tf.float32, [None, input_dim], name='x')
y = tf.placeholder(tf.float32, [None, 1], name='y')

# model
train_model = nn(x, y)

# optimizer
train_op = tf.train.AdamOptimizer(lr).minimize(train_model.loss,
                                               global_step=train_model.global_step)

# saver
saver = tf.train.Saver()
if not os.path.exists(save_path):
    os.makedirs(save_path)  # saver.save() needs the directory to exist

with tf.Session() as sess:
    # recorder, summary
    train_model.summary()
    train_writer = tf.summary.FileWriter('log', graph=sess.graph)
    merged = tf.summary.merge_all()

    # restore or initialize
    ckpt = tf.train.get_checkpoint_state(save_path)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        saver.restore(sess, os.path.join(save_path, ckpt_name))
    else:
        sess.run(tf.global_variables_initializer())

    # training loop
    global_step_val = sess.run(train_model.global_step)
    while global_step_val < max_train_step:
        # dummy batch; replace with your real data pipeline
        batch_x = np.random.rand(batch_size, input_dim).astype(np.float32)
        batch_y = np.random.rand(batch_size, 1).astype(np.float32)
        feed = {x: batch_x, y: batch_y}

        sess.run(train_op, feed_dict=feed)
        global_step_val += 1
        if global_step_val % 100 == 0:
            saver.save(sess, os.path.join(save_path, 'nn.ckpt'), global_step=global_step_val)
            merged_summary = sess.run(merged, feed_dict=feed)
            train_writer.add_summary(merged_summary, global_step_val)
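
To watch training, point TensorBoard at the log directory used by the FileWriter above:

tensorboard --logdir log

The checkpoints written every 100 steps end up in the model directory as nn.ckpt-<step> files together with a checkpoint index file, which is what tf.train.get_checkpoint_state reads on the next run to resume training.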

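A rough sketch of reusing the saved checkpoint for evaluation, assuming the same graph-construction code above has already been run so the variable names match (the test batch here is dummy data, purely for illustration):

with tf.Session() as sess:
    # load the newest checkpoint written by the training loop
    latest = tf.train.latest_checkpoint(save_path)
    saver.restore(sess, latest)

    test_x = np.random.rand(1, input_dim).astype(np.float32)  # dummy input
    test_y = np.random.rand(1, 1).astype(np.float32)          # dummy target
    print(sess.run(train_model.loss, feed_dict={x: test_x, y: test_y}))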


Reposted from blog.csdn.net/hujiankun073/article/details/88802691