版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/shwan_ma/article/details/80488860
之前一直很困惑tf.Variable和tf.get_variable之间的区别,这几天算稍微明白一些。用简单的语言描述概括一下tf.Variable 和 tf.get_variable的一些特性
tf.Variable 和 tf.get_variable之间最主要的区别:
如果tf.Variable定义的时候,两个变量即使重名,那么是依然是两个独立的变量, tensorflow会自动增加变量后缀,以区分同名变量。
而 tf.get_variable定义的时候,tensorflow会自动去检查有没有命名一样的变量,如果出现一样,则会报错。但是如果设置了re-use,则不会报错,同时可以使得参数共享
首先,对于name_scope来说
如果在tf.name_scope下的话定义变量的话,
tf.Variable会在他的名字前面加上该name_scope的名字。
而 tf.get_varoable 则会无视掉name_scope
with tf.name_scope("a_name_scope"):
var1 = tf.get_variable(name="var1", shape=[1], dtype=tf.float32, initializer=tf.constant_initializer(1.0))
var2 = tf.Variable(name="var2", initial_value=[2], dtype=tf.float32)
var3 = tf.Variable(name="var2", initial_value=[2.1], dtype=tf.float32)
var4 = tf.Variable(name="var2", initial_value=[2.2], dtype=tf.float32)
sess.run(tf.global_variables_initializer())
print(var1.name)
print(sess.run(var1))
print(var2.name)
print(sess.run(var2))
print(var3.name)
print(sess.run(var3))
print(var4.name)
print(sess.run(var4))
输出结果:
var1:0
[ 1.]
a_name_scope/var2:0
[ 2.]
a_name_scope/var2_1:0
[ 2.0999999]
a_name_scope/var2_2:0
[ 2.20000005]
二, 对于variable_scope来说
则不管tf.Variable和tf.get_variable都会加上variable_scope的名字
with tf.variable_scope("a_variable_scope") as scope:
initializer = tf.constant_initializer(value=3)
var3 = tf.get_variable('var3', shape=[1], dtype=tf.float32, initializer=tf.constant_initializer(3.0))
var4 = tf.Variable(name="var4", initial_value=[4], dtype=tf.float32)
var4_reuse = tf.Variable(name="var4", initial_value=[4], dtype=tf.float32)
sess.run(tf.global_variables_initializer())
print(var3.name)
print(sess.run(var3))
print(var4.name)
print(sess.run(var4))
print(var4_reuse.name)
print(sess.run(var4_reuse))
输出结果:
a_variable_scope/var3:0
[ 3.]
a_variable_scope/var4:0
[ 4.]
a_variable_scope/var4_1:0
[ 4.]
三, 最重要的特性, tf_get_variable支持变量重用
with tf.variable_scope("a_variable_scope") as scope:
initializer = tf.constant_initializer(value=3)
var3 = tf.get_variable('var3', shape=[1], dtype=tf.float32, initializer=tf.constant_initializer(3.0))
scope.reuse_variables()
var3_reuse = tf.Variable(name="var3")
sess.run(tf.global_variables_initializer())
print(var3.name)
print(sess.run(var3))
print(var3_reuse.name)
print(sess.run(var3_reuse))
输出结果:(重用完成)
a_variable_scope/var3:0
[ 3.]
a_variable_scope/var3:0
[ 3.]