文章目录
官方文档链接:
https://www.tensorflow.org/api_docs/python/tf/GraphKeys#GLOBAL_VARIABLES
公有以下标准键:
GLOBAL_VARIABLES
LOCAL_VARIABLES
MODEL_VARIABLES
TRAINABLE_VARIABLES
:tf.Optimizer
子类默认优化该类下的变量
SUMMARIES
QUEUE_RUNNERS
MOVING_AVERAGE_VARIABLES
REGULARIZATION_LOSSES
定义了以下标准键,但它们的集合不会像其他许多键那样自动填充:
WEIGHTS
BIASES
ACTIVATIONS
目前能够确定的键的关系如下
用tf.get_collection()
可以以list形式获取某个集合。
如
import tensorflow as tf
# 创建变量的3种方式
# 方式1
# tf.Variable(Tensor)
a = tf.Variable(tf.random_uniform(shape=[2,2], minval=0.0, maxval=1.0, dtype=tf.float32), name="a")
# 方式2
# tf.get_varizble(name=, shape=, initializer=)
b = tf.get_variable("b",
shape=[2,2],
initializer=tf.truncated_normal_initializer(mean=0.0, stddev=1.0, dtype=tf.float32))
# 方式3
# 与方式2类似,只是将initializer写到了variable_scope中
with tf.variable_scope("variable___scope", initializer=tf.truncated_normal_initializer(mean=10.0, stddev=1.0, dtype=tf.float32)):
c = tf.get_variable("c", shape=[2,2])
with tf.Session() as sess:
# writer = tf.summary.FileWriter("logs_test", sess.graph)
sess.run(tf.global_variables_initializer())
# aa=tf.global_variables() # 与下一句等价
aa=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
# bb=tf.trainable_variables() # 与下一句等价
bb=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
print(type(aa))
print(aa)
print(type(bb))
print(bb)
print("a:\n", a.eval())
print("b:\n", b.eval())
print("c:\n", c.eval())
<class 'list'>
[<tf.Variable 'a:0' shape=(2, 2) dtype=float32_ref>, <tf.Variable 'b:0' shape=(2, 2) dtype=float32_ref>, <tf.Variable 'variable___scope/c:0' shape=(2, 2) dtype=float32_ref>]
<class 'list'>
[<tf.Variable 'a:0' shape=(2, 2) dtype=float32_ref>, <tf.Variable 'b:0' shape=(2, 2) dtype=float32_ref>, <tf.Variable 'variable___scope/c:0' shape=(2, 2) dtype=float32_ref>]
a:
[[0.794688 0.50156784]
[0.7120378 0.08468437]]
b:
[[ 0.8215166 0.8389353 ]
[ 0.06688708 -0.4017645 ]]
c:
[[11.046042 9.249809 ]
[ 9.4180565 10.788679 ]]
总结
tensorflow的Graph中有很多collection,标准的有大概10个,分别是
公有以下标准键:
GLOBAL_VARIABLES
LOCAL_VARIABLES
MODEL_VARIABLES
TRAINABLE_VARIABLES
:tf.Optimizer
子类默认优化该类下的变量
SUMMARIES
QUEUE_RUNNERS
MOVING_AVERAGE_VARIABLES
REGULARIZATION_LOSSES
定义了以下标准键,但它们的集合不会像其他许多键那样自动填充:
WEIGHTS
BIASES
ACTIVATIONS
在定义变量的时候,这些变量会被自动分配到某些集合中。想要获取集合,可以用tf.get_collection
函数
如获取Graph中的GLOBAL_VARIABLES
集合,可用
tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
或者tf.global_variables()
(二者等价)
。