转自:https://cloud.tencent.com/developer/article/1092432
今天说一下tensorflow的变量共享机制,首先为什么会有变量共享机制? 这个还是要扯一下生成对抗网络GAN,我们知道GAN由两个网络组成,一个是生成器网络G,一个是判别器网络D。G的任务是由输入的隐变量z生成一张图像G(z)出来,D的任务是区分G(z)和训练数据中的真实的图像(real images)。所以这里D的输入就有2个,但是这两个输入是共享D网络的参数的,简单说,也就是权重和偏置。而TensorFlow的变量共享机制,正好可以解决这个问题。但是我现在不能确定,TF的这个机制是不是因为GAN的提出才有的,还是本身就存在。
所以变量共享的目的就是为了在对网络第二次使用的时候,可以使用同一套模型参数。TF中是由Variable_scope来实现的,下面我通过几个栗子,彻底弄明白到底该怎么使用,以及使用中会出现的错误。栗子来源于文档,然后我写了不同的情况,希望能帮到你。
# - * - coding:utf-8 - * - import tensorflow as tf import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' def fc_variable(): v1 = tf.Variable( initial_value=tf.random_normal( shape=[2, 3], mean=0., stddev=1.), dtype=tf.float32, name='variable_1') print v1 print "- v1 - * " * 5 return v1 """ <tf.Variable 'variable_1:0' shape=(2, 3) dtype=float32_ref> - v1 - * - v1 - * - v1 - * - v1 - * - v1 - * """ def variable_value(variables): with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # 如果没有这句会报错,所以tf在调用变量之前主要 # 先初始化 """ tensorflow.python.framework.errors_impl. FailedPreconditionError: Attempting to use uninitialized value variable_1 """ print '- * - value: - * - ' * 3 print sess.run(variables) """ [[ 0.00556329 0.20311342 -0.79569227] [ 0.1700473 0.9499892 -0.46801034]] """ def fc_variable_scope(): with tf.variable_scope("foo"): v = tf.get_variable("v", [1]) print v.name w = tf.get_variable("w", [1]) print w.name with tf.variable_scope("foo", reuse=True): v1 = tf.get_variable("v") print v1.name """ foo/v:0 foo/w:0 foo/v:0 """ # 解释: # 这里说明v1和v的相同的,还有这里用的是 # get_variable定义的变量,这个和Variable # 定义变量的区别是,如果变量存在get_variable # 会获得他的值,如果不存在则创建变量 def fc_variable_scope_v2(): with tf.variable_scope("foo"): v = tf.get_variable("v", [1]) print v.name w = tf.get_variable("w", [1]) print w.name with tf.variable_scope("foo", reuse=False): v1 = tf.get_variable("v") print v1.name """ ValueError: Variable foo/v already exists, disallowed. Did you mean to set reuse=True in VarScope? Originally defined at: """ # 解释: # 当reuse为False的时候由于v1在'fool'这个scope里面, # 所以和v的name是一样的,而reuse为False,变量命名就起了冲突。 def fc_variable_scope_v3(): with tf.variable_scope("foo"): v = tf.get_variable("v", [1]) print v.name w = tf.get_variable("w", [1]) print w.name with tf.variable_scope("foo", reuse=True): v1 = tf.get_variable("u", [1]) print v1.name """ ValueError: Variable foo/u does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=None in VarScope? """ # 解释: # 当reuse为True时时候,而这里定义了新变量u, # 之前不存在,这样也无法reuse。 def fc_variable_scope_v4(): with tf.variable_scope("foo"): v = tf.get_variable("v", [1]) print v.name w = tf.get_variable("w", [1]) print w.name with tf.variable_scope("foo", reuse=False): v1 = tf.get_variable("u") print v1.name """ ValueError: Shape of a new variable (foo/u) must be fully defined, but instead was <unknown>. """ # 解释: # 这里reuse为Flase,但是定义新变量的时候, # 必须define fully变量,也就是要指定变量 # 的shape或者初始值等。 def fc_variable_scope_v5(): with tf.variable_scope("foo"): v = tf.get_variable("v", [1]) print dir(v) print v.name w = tf.get_variable("w", [1]) print w.name with tf.variable_scope("foo", reuse=False): v1 = tf.get_variable("u", [1]) print v1.name """ foo/v:0 foo/w:0 foo/u:0 """ # 这样就没错了 def fc_variable_scope_v6(): with tf.variable_scope("foo"): v1 = tf.Variable(tf.random_normal( shape=[2, 3], mean=0., stddev=1.), dtype=tf.float32, name='v1') print v1.name v2 = tf.get_variable("v2", [1]) print v2.name with tf.variable_scope("foo", reuse=True): v3 = tf.get_variable('v2') print v3.name v4 = tf.get_variable('v1') print v4.name """ foo/v1:0 foo/v2:0 foo/v2:0 ValueError: Variable foo/v1 does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=None in VarScope? """ # 解释: # 这里虽然reuse为True,但是v1是由Variable定义的, # 不能被get。 def compare_name_and_variable_scope(): with tf.name_scope("hello") as ns: arr1 = tf.get_variable( "arr1", shape=[2, 10], dtype=tf.float32) print (arr1.name) print " - * -" * 5 with tf.variable_scope("hello") as vs: arr1 = tf.get_variable( "arr1", shape=[2, 10], dtype=tf.float32) print (arr1.name) """ arr1:0 - * - - * - - * - - * - - * - hello/arr1:0 """ #解释: # 这里除了name_scope和variable_scope不同, # 其他都相同,但是从他们的name,也能看出来区别了。 if __name__ == "__main__": fc_variable_scope_v6() # # 需要测试那个函数,直接写在这里。
简单总结一下,今天的内容主要是变量定义的两种方法,Variable个get_variable,还有变量的范围以及reuse是什么鬼。通过几个栗子,应该明白了。