图(graph)
tf.Graph:
TensorFlow计算,表示为数据流图。一个图包含一组表示 tf.Operation计算单位的对象和tf.Tensor表示操作之间流动的数据单元的对象。默认Graph值始终注册,并可通过调用访问 tf.get_default_graph。
a = tf.constant(1.0)
assert c.graph is tf.get_default_graph()
我们可以发现这两个图是一样的。那么如何创建一个图呢,通过tf.Graph()
g1= tf.Graph()
g2= tf.Graph()
with tf.Session() as sess:
tf.global_variables_initializer().run()
print(g1,g2,tf.get_default_graph())
默认图总是被registered,可以通过调用tf.get_default_graph
访问。要向默认图添加操作,只需调用被定义的函数即可:
import tensorflow as tf
a = tf.constant(4.0)
print(a.graph) # 得到a所在图对象的内存地址
print(tf.get_default_graph()) # 得到当前程序默认图对象的内存地址
assert a.graph == tf.get_default_graph() # 断言语句
print('-------------') # 该语句成功输出说明上面语句正确
输出:
<tensorflow.python.framework.ops.Graph object at 0x000001DE26E9CA58>
<tensorflow.python.framework.ops.Graph object at 0x000001DE26E9CA58>
-----
一个程序中多个Graph例子:
import tensorflow as tf
g1 = tf.Graph()
g2 = tf.Graph()
c0 = tf.constant(0.0)
with g1.as_default(): # 将g1设置为默认图然后再图中添加Operation
c1 = tf.constant(1.0)
with tf.Graph().as_default() as g2:
c2 = tf.constant(2.0)
with tf.Session() as sess: # c0被添加在默认图中无需指明图参数
assert c0.graph == tf.get_default_graph()
print(sess.run(c0)) # 0.0
with tf.Session(graph=g1) as sess1:
assert c1.graph is g1
print(sess1.run(c1)) # 1.0
with tf.Session(graph=g2) as sess2:
assert c2.graph is g2
print(sess2.run(c2)) # 2.0
会话(tf.Session())
tf.Session()
对象封装了执行Operation
对象和计算Tensor
对象的环境
a = tf.constant(5.0)
b = tf.constant(6.0)
c = a * b
sess = tf.Session()
print(sess.run(c))
在开启会话的时候指定图:
with tf.Session(graph=g) as sess:
资源释放
会话可能拥有很多资源,如 tf.Variable,tf.QueueBase和tf.ReaderBase。在不再需要这些资源时,重要的是释放这些资源。要做到这一点,既可以调用tf.Session.close会话中的方法,也可以使用会话作为上下文管理器。以下两个例子是等效的:
# 使用close手动关闭
sess = tf.Session()
sess.run(...)
sess.close()
# 使用上下文管理器
with tf.Session() as sess:
sess.run(...)
run方法介绍
run(fetches, feed_dict=None, options=None, run_metadata=None)
运行ops和计算tensor
- fetches 可以是单个图形元素,或任意嵌套列表,元组,namedtuple,dict或OrderedDict
- feed_dict 允许调用者覆盖图中指定张量的值
如果a,b是其它的类型,比如tensor,同样可以覆盖原先的值
a = tf.placeholder(tf.float32, shape=[])
b = tf.placeholder(tf.float32, shape=[])
c = tf.constant([1,2,3])
with tf.Session() as sess:
a,b,c = sess.run([a,b,c],feed_dict={a: 1, b: 2,c:[4,5,6]})
print(a,b,c)
常见错误:
- RuntimeError:如果它Session处于无效状态(例如已关闭)。
- TypeError:如果fetches或feed_dict键是不合适的类型。
- ValueError:如果fetches或feed_dict键无效或引用 Tensor不存在。