tensorflow自己实现SGD功能

手动实现SGD和调用优化器结果比较

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
#mnist已经作为官方的例子,做好了数据下载,分割,转浮点等一系列工作,源码在tensorflow源码中都可以找到
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# 配置每个 GPU 上占用的内存的比例
# 没有GPU直接sess = tf.Session()
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.95)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

#每个批次的大小
batch_size = 20
#定义训练轮数据
train_epoch = 1
#定义每n轮输出一次
test_epoch_n = 1

#计算一共有多少批次
n_batch = mnist.train.num_examples // batch_size
print("batch_size="+str(batch_size)+"n_batch="+str(n_batch))

#占位符,定义了输入,输出
x = tf.placeholder(tf.float32,[None, 784]) 
y_ = tf.placeholder(tf.float32,[None, 10]) 
#权重和偏置,使用0初始化
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

#权重和偏置,使用0初始化
W2 = tf.Variable(tf.zeros([784,10]))
b2 = tf.Variable(tf.zeros([10]))

#这里定义的网络结构
y = tf.matmul(x,W) + b
#损失函数是交叉熵
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_,logits=y))

lr = 0.01#学习率
gw,gb = tf.gradients(ys=cross_entropy, xs=[W,b])
wt = W - lr * gw
bt = b - lr * gb
updatew = tf.assign(W,wt)
updateb = tf.assign(b,bt)

#这里定义的网络结构
y2 = tf.matmul(x,W2) + b2
#损失函数是交叉熵
cross_entropy2 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_,logits=y2))
#训练方法:
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy2)

#初始化sess中所有变量
init = tf.global_variables_initializer() 
sess.run(init) 
batch_xs, batch_ys = mnist.train.next_batch(batch_size)

#输出手动SGD后的b值
for _ in range(2):
    _,_,testsee1,testsee2 = sess.run([updatew,updateb,cross_entropy,b], feed_dict = {x: batch_xs, y_: batch_ys})
    print(testsee1)
    print(testsee2)

#输出优化器后的b值
for _ in range(2):
    _,testsee1,testsee2 = sess.run([train_step,cross_entropy2,b2], feed_dict = {x: batch_xs, y_: batch_ys})
    print(testsee1)
    print(testsee2)

输出结果:

2.30259
[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
2.26225
[ -5.00000024e-04   9.99999931e-04  -5.00000082e-04  -5.00000024e-04
  -1.02445483e-10   4.99999849e-04  -1.00000005e-03  -6.51925805e-11
  -2.79396766e-11   9.99999931e-04]

loss 和b值均一致,说明自己更新网络参数和优化器自动更新一致

此代码网络初始化均为0,mnist也是固定的数据,所以应该必定能复现上面的输出结果。

猜你喜欢

转载自blog.csdn.net/masbbx123/article/details/80771129
SGD