import tensorflow as tf
import numpy as np
import os
# ---- Hyper-parameters -------------------------------------------------------
lr = 0.001              # learning rate handed to the Adam optimizer
batch_size = 32         # examples per step (not referenced by the visible code yet)
max_train_step = 10000  # the training loop stops once global_step reaches this
save_path = 'model'     # checkpoint directory, searched on start-up for a restore
# ---- Network structure --------------------------------------------------------
class nn(object):
    """Skeleton network: owns the global step and, once filled in, the loss.

    Module-level training code reads two attributes from an instance:
      * ``global_step`` -- an int32 step counter, created here;
      * ``loss``        -- a scalar loss tensor. NOT defined yet: add it where
                           marked below, otherwise ``train_model.loss`` raises
                           AttributeError.
    """

    def __init__(self, name='nn', trainning=True, reuse=False):
        """Build this network's graph inside the variable scope ``name``.

        Args:
            name: variable-scope name under which the variables are created.
            trainning: whether the graph is built for training. The spelling
                is kept for backward compatibility; the value is exposed as
                ``self.training`` so layers such as dropout or batch-norm can
                branch on it.
            reuse: forwarded to ``tf.variable_scope``.
        """
        # Previously these arguments were discarded after use; keep them so
        # the model definition (and callers) can inspect them.
        self.name = name
        self.training = trainning
        self.reuse = reuse
        with tf.variable_scope(name, reuse=reuse):
            # Step counter: non-trainable, incremented by
            # Optimizer.minimize(..., global_step) and saved in checkpoints.
            # NOTE: `reuse` only affects tf.get_variable; tf.Variable always
            # creates a new variable.
            self.global_step = tf.Variable(0, trainable=False, dtype=tf.int32)
            # TODO: define inputs, layers and a scalar `self.loss` here.

    def summary(self):
        """Register tf.summary ops for TensorBoard (currently a no-op stub).

        Add e.g. ``tf.summary.scalar('loss', self.loss)`` here. Until some
        summary op is registered, ``tf.summary.merge_all()`` returns None.
        """
        pass
# ---- Placeholders -------------------------------------------------------------
# None yet: the training loop feeds an empty feed_dict. Declare input
# tf.placeholder tensors here (or inside nn) once the model has real inputs.

# ---- Model --------------------------------------------------------------------
train_model = nn()  # defaults: scope 'nn', trainning=True, reuse=False

# ---- Optimizer ----------------------------------------------------------------
# minimize() also increments train_model.global_step once per run.
# NOTE: nn must define `.loss` for this to work.
train_up = tf.train.AdamOptimizer(learning_rate=lr).minimize(
    train_model.loss,
    global_step=train_model.global_step,
)

# ---- Checkpointing ------------------------------------------------------------
# Created after the optimizer so that its slot variables are saved as well.
saver = tf.train.Saver()
with tf.Session() as sess:
    # ---- Recorders / summaries ------------------------------------------------
    train_model.summary()
    train_writer = tf.summary.FileWriter('log', graph=sess.graph)
    try:
        # merge_all() returns None when no summary op has been registered
        # (nn.summary() may be a stub); sess.run(None) would raise, so the
        # summary step below is skipped in that case.
        merged = tf.summary.merge_all()

        # ---- Restore or initialise ---------------------------------------------
        ckpt = tf.train.get_checkpoint_state(save_path)
        if ckpt and ckpt.model_checkpoint_path:
            # Rebuild the path from the basename so a moved checkpoint
            # directory still resolves.
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            saver.restore(sess, os.path.join(save_path, ckpt_name))
        else:
            sess.run(tf.global_variables_initializer())

        # Saver.save() refuses to write into a non-existent directory.
        if not os.path.isdir(save_path):
            os.makedirs(save_path)

        # ---- Training loop -----------------------------------------------------
        # Resume counting from whatever step was restored.
        global_step_val = sess.run(train_model.global_step)
        while global_step_val < max_train_step:
            # train_up increments global_step by one, so mirror it locally.
            sess.run(train_up, feed_dict={})
            global_step_val += 1
            if global_step_val % 100 == 0:
                saver.save(sess, os.path.join(save_path, 'nn.ckpt'),
                           global_step_val)
                if merged is not None:
                    merged_summary = sess.run(merged, feed_dict={})
                    train_writer.add_summary(merged_summary, global_step_val)
    finally:
        # Flush buffered events to disk and release the event file.
        train_writer.close()
# tensorflow基础模板 (TensorFlow basic template)
# Reposted from blog.csdn.net/hujiankun073/article/details/88802691