import os

import cv2
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist.input_data import read_data_sets
class Config:
    """Hyper-parameters and filesystem locations for the mnist07 experiment."""

    def __init__(self):
        name = 'mnist07'
        # Directory the MNIST loader reads from.
        self.sample_path = '../deeplearning_ai12/p07_mnist/MNIST_data'
        # Optimisation settings.
        self.lr = 0.001
        self.epoches = 200
        self.batch_size = 50
        # Small epsilon kept for numerical-stability purposes.
        self.eps = 1e-10
        # Channel count of the first conv layer; later convs double it.
        # Should be 32 at least.
        self.base_filters = 16
        # Experiment identifier, used to derive checkpoint and log locations.
        self.name = name
        self.save_path = '../models/{name}/{name}'.format(name=name)
        self.logdir = '../logs/{name}'.format(name=name)
class Tensors:
    """Build the MNIST CNN graph: placeholders, logits, loss, train op and summaries.

    Ops are added to the current default graph, so construct this inside the
    graph it should populate.
    """

    def __init__(self, config: Config):
        self.config = config
        # Flattened 28x28 images, one row per example.
        self.x = tf.placeholder(tf.float32, [None, 784], 'x')
        x = tf.reshape(self.x, [-1, 28, 28, 1])  # [-1, 28, 28, 1]
        logits = self.get_logits(x)  # [-1, 10]
        self.y_predict = tf.argmax(logits, axis=1, output_type=tf.int32)  # [-1]

        # Integer class labels, one per example.
        self.y = tf.placeholder(tf.int32, [None], 'y')
        # Cross-entropy straight from the logits. The former
        # -sum(one_hot * log(max(softmax, eps))) sent zero gradient to the
        # probabilities whenever the true class fell below eps (tf.maximum
        # picks the constant), so confidently wrong samples stopped training.
        # The fused op is also numerically stable, making eps unnecessary here.
        self.loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.y, logits=logits))
        opt = tf.train.AdamOptimizer(config.lr)
        self.train_op = opt.minimize(self.loss)
        # Fraction of examples whose predicted class equals the label.
        self.precise = tf.reduce_mean(tf.cast(tf.equal(self.y, self.y_predict), tf.float32))

        # Report the size of every trainable variable and the total.
        params = 0
        for var in tf.trainable_variables():
            ps = _params(var.shape)
            print(var.name, var.shape, ps)
            params += ps
        print('-' * 200)
        print('Total:', params)

        tf.summary.scalar('loss', self.loss)
        tf.summary.scalar('precise', self.precise)
        self.summary_op = tf.summary.merge_all()

    def get_logits(self, x):
        """Map a batch of images to unnormalised class scores.

        With f = config.base_filters the feature maps are:
        conv1 -> [-1, 28, 28, f]; conv2_0 + pool -> [-1, 14, 14, 2f];
        conv2_1 + pool -> [-1, 7, 7, 4f]; flatten -> [-1, 7*7*4f].

        :param x: [-1, 28, 28, 1]
        :return: [-1, 10]
        """
        config = self.config
        filters = config.base_filters
        x = tf.layers.conv2d(x, filters, (3, 3), (1, 1), 'same',
                             activation=tf.nn.relu, name='conv1')  # [-1, 28, 28, f]
        for i in range(2):
            filters *= 2
            x = tf.layers.conv2d(x, filters, 3, 1, 'same',
                                 name='conv2_%d' % i)  # spatial size unchanged, 2x channels
            x = tf.layers.max_pooling2d(x, (2, 2), (2, 2), 'same')  # spatial size halves
            # ReLU commutes with max pooling (ReLU is monotonic), so applying it
            # after the pool gives the same result on a quarter of the elements.
            x = tf.nn.relu(x)
        # x: [-1, 7, 7, 4f]
        x = tf.layers.flatten(x)  # [-1, 7*7*4f]
        x = tf.layers.dense(x, 1000, activation=tf.nn.relu, name='dense1')
        x = tf.layers.dense(x, 10, name='dense2')  # [-1, 10]
        return x
def _params(shape):
result = 1
for sh in shape:
result *= sh.value
return result
class Samples:
    """The train / validation / test splits of MNIST, each wrapped in SubSamples."""

    def __init__(self, config):
        datasets = read_data_sets(config.sample_path)
        # Wrap every split so callers see the same small interface.
        self.train, self.validation, self.test = (
            SubSamples(datasets.train),
            SubSamples(datasets.validation),
            SubSamples(datasets.test),
        )
class SubSamples:
    """Thin adapter around a single dataset split."""

    def __init__(self, data):
        # The wrapped split; it must provide `num_examples` and `next_batch`.
        self.data = data

    def num_examples(self):
        """Return the wrapped split's `num_examples` attribute."""
        split = self.data
        return split.num_examples

    def next_batch(self, batch_size):
        """Delegate to the wrapped split's `next_batch`.

        For the MNIST loader this presumably returns (xs, ys) with
        xs: [batch_size, 784] and ys: [batch_size] -- confirm against the loader.
        """
        split = self.data
        return split.next_batch(batch_size)
def show_imgs(xs, ys, cols=20):
    """Print the labels and display the images tiled in a grid window.

    :param xs: images reshapeable to [-1, 28, 28] (e.g. [N, 784]).
    :param ys: labels; printed as-is.
    :param cols: digits per grid row (20 was the previously hard-coded value).
    :raises ValueError: if cols < 1.
    """
    if cols < 1:
        raise ValueError('cols must be >= 1, got %r' % (cols,))
    print(ys)
    imgs = np.reshape(xs, [-1, 28, 28])
    count = imgs.shape[0]
    if count == 0:
        return
    # Pad with blank tiles so the count is a multiple of `cols`; the old code
    # required exactly that and crashed e.g. for a batch of 50 with 20 columns.
    pad = -count % cols
    if pad:
        blank = np.zeros((pad, 28, 28), dtype=imgs.dtype)
        imgs = np.concatenate([imgs, blank], axis=0)
    rows = imgs.shape[0] // cols
    # [rows, cols, 28, 28] -> [rows, 28, cols, 28] -> [rows*28, cols*28]:
    # image r*cols + c ends up in grid cell (r, c).
    grid = imgs.reshape(rows, cols, 28, 28).transpose(0, 2, 1, 3).reshape(rows * 28, cols * 28)
    cv2.imshow('My digits', grid)
    cv2.waitKey()
class App:
    """Own the graph, session and checkpoints used to train and run the model."""

    def __init__(self, config: Config):
        self.config = config
        self.samples = Samples(config)
        graph = tf.Graph()
        # Everything that creates ops or gathers variables must run inside
        # `graph`: tf.train.Saver() collects variables from the *current
        # default* graph, and an initializer op built in another graph cannot
        # be run by a session bound to `graph`.
        with graph.as_default():
            self.tensors = Tensors(config)
            self.session = tf.Session(graph=graph)
            self.saver = tf.train.Saver()
            try:
                self.saver.restore(self.session, config.save_path)
                print('Restore the model from %s successfully' % config.save_path)
            except (ValueError, tf.errors.OpError):
                # Missing or incompatible checkpoint: start from fresh weights.
                print('Fail to restore the model from %s, use a new model instead' % config.save_path)
                self.session.run(tf.global_variables_initializer())

    def close(self):
        """Release the TensorFlow session."""
        self.session.close()

    def train(self):
        """Train for config.epoches epochs.

        A summary is written every step. After each epoch, one validation batch
        is scored and a checkpoint is saved.
        """
        train_samples = self.samples.train
        config = self.config
        ts = self.tensors
        batches = train_samples.num_examples() // config.batch_size
        # Saver.save refuses to write when the parent directory is missing.
        save_dir = os.path.dirname(config.save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fw = tf.summary.FileWriter(config.logdir, self.session.graph)
        try:
            step = 0
            for epoch in range(config.epoches):
                for batch in range(batches):
                    xs, ys = train_samples.next_batch(config.batch_size)
                    _, summary = self.session.run([ts.train_op, ts.summary_op], {ts.x: xs, ts.y: ys})
                    fw.add_summary(summary, step)
                    step += 1
                # Quick sanity check on a single validation batch (noisy: only
                # batch_size examples are scored).
                xs, ys = self.samples.validation.next_batch(config.batch_size)
                precise_v = self.session.run(ts.precise, {ts.x: xs, ts.y: ys})
                print('Epoch: %d, batch %d: precise=%.6f' % (epoch, batches - 1, precise_v))
                self.saver.save(self.session, config.save_path)
                print('Model saved into', config.save_path)
        finally:
            # Flush pending events and release the event file.
            fw.close()
        print('Training is finished!')

    def predict(self):
        """Print the true labels and the model's predictions for one test batch."""
        xs, ys = self.samples.test.next_batch(self.config.batch_size)
        print(ys)
        ts = self.tensors
        ys_predict = self.session.run(ts.y_predict, {ts.x: xs})
        print('predict:')
        print(ys_predict)
if __name__ == '__main__':
    config = Config()
    app = App(config)
    try:
        app.train()
        # To print predictions for one test batch instead, call app.predict().
    finally:
        # Release the TF session even if training fails or is interrupted.
        app.close()
# Source: 25-mnist07_tb
# Reposted from blog.csdn.net/HJZ11/article/details/104793393