【TensorFlow】:获取默认计算图与创建新的计算图

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/bqw18744018044/article/details/83218105

获取默认计算图:tf.get_default_graph()

import tensorflow as tf
a = tf.constant([1.0,2.0],name='a')
b = tf.constant([2.0,3.0],name='b')
result = a+b
# a.graph获取张量a所属的计算图
# tf.get_default_graph()获取默认的计算图
print(a.graph is tf.get_default_graph()) # 判断张量a是否属于默认计算图
True

创建新的计算图:tf.Graph()

import tensorflow as tf
g1 = tf.Graph() # 生成新的计算图g1
with g1.as_default():
    # tf.get_variable()与tf.Variable()均可以用来创建变量,但是前者不保证唯一性
    v = tf.get_variable("v",initializer=tf.zeros(shape=[1]))
    
g2 = tf.Graph() # 生成新的计算图g2
with g2.as_default():
    v = tf.get_variable("v",initializer=tf.ones(shape=[1]))
    
# 在计算图g1中读取变量“v”的取值
with tf.Session(graph=g1) as sess:
    # 初始化所有的变量
    tf.global_variables_initializer().run()
    # tf.variable_scope()指定变量的作用域,字符串作为变量的前缀
    # reuse=True表示重用变量,此时tf.get_variable("v")不会产生新的变量而是使用先前定义的变量
    with tf.variable_scope("",reuse=True):
        # 输出的是图g1的变量,值为0
        print(sess.run(tf.get_variable("v")))
        
# 在计算图g2中读取变量“v”的取值
with tf.Session(graph=g2) as sess:
    tf.global_variables_initializer().run()
    with tf.variable_scope("",reuse=True):
        # 输出是图g2的变量,值为1
        print(sess.run(tf.get_variable("v")))
[0.]
[1.]

猜你喜欢

转载自blog.csdn.net/bqw18744018044/article/details/83218105