上一节(TensorFlow2.0之计算过程与名字来历)中我们反复提到了图,但是代码中并没有看到图的定义,也没看到任何跟图有关的代码。其实,在TensorFlow中,TensorFlow会定义默认图。用户可以自己显式定义图,并将自定义图作为默认图。
TensorFlow的图中包含tf.Operation对象集合,一个tf.Operation对象表示一个计算单元,如加、减、乘、除等计算都是tf.Operation对象。TensorFlow中的图还包含tf.Tensor对象,tf.Tensor对象表示参与运算的数据,这些数据在路径的各个tf.Operation节点中参与运算,并在图中的各个路径中传递。
代码:
import tensorflow as tf
# 定义Python中int类型二维矩阵
A = [[1, 2, 3],
[4, 5, 6]]
B = [[1, 1],
[1, 1],
[1, 1]]
my_graph = tf.compat.v1.Graph()
with my_graph.as_default():
# 将Python类型数据A和B传入图中
A_tf = tf.compat.v1.constant(A, dtype=tf.float32, name="A")
B_tf = tf.compat.v1.constant(B, dtype=tf.float32, name="B")
# 构建图中的计算节点
C_tf = tf.compat.v1.matmul(A_tf, B_tf)
print("C_tf is my_graph:", C_tf.graph is my_graph)
# 图构建完毕
with tf.compat.v1.Session(graph=my_graph) as sess:
C = sess.run(C_tf)
print(C)
输出:
C_tf is my_graph: True
[[ 6. 6.]
[15. 15.]]
上述代码中,第8行自定义了图 ,并在第9-14行往自定义的图中加入了数据节点和计算节点:第15行打印验证加入自定义图中的结点是否正确;第17行将自定义的图作为tf.Session的默认图。从输出结果可以看到,每个数据对象的计算节点对象可以在指定图中存放。有一点需要注意的是,计算节点的输出数据对象会被放置到输入数据对象所在的图中。下面通过一个例子说明,代码如下:
import tensorflow as tf
# 定义Python中int类型二维矩阵
A = [[1, 2, 3],
[4, 5, 6]]
B = [[1, 1],
[1, 1],
[1, 1]]
my_graph1 = tf.compat.v1.Graph()
my_graph2 = tf.compat.v1.Graph()
with my_graph1.as_default():
# 将Python类型数据A传入图中
A_tf = tf.compat.v1.constant(A, dtype=tf.float32, name="A")
# 将Python类型数据B传入图中
B_tf = tf.compat.v1.constant(B, dtype=tf.float32, name="B")
with my_graph2.as_default():
# 试图将C_tf放入图my_graph2中
C_tf = tf.compat.v1.matmul(A_tf, B_tf)
print("C_tf.graph is my_graph1:", C_tf.graph is my_graph1)
print("C_tf.graph is my_graph2:", C_tf.graph is my_graph2)
输出:
C_tf.graph is my_graph1: True
C_tf.graph is my_graph2: False
可以看到 ,即使指定将C_tf放在图my_graph2中,还是无法改变C_tf实际存放在my_graph1中的事实。此外,矩阵相乘计算节点也不会在my_graph2中,而会在my_graph1中。接下来我们自定义不同的图,看看不同的图中的数据和计算节点之间交叉引用会怎么样。
代码:
import tensorflow as tf
# 定义Python中int类型二维矩阵
A = [[1, 2, 3],
[4, 5, 6]]
B = [[1, 1],
[1, 1],
[1, 1]]
my_graph1 = tf.compat.v1.Graph()
my_graph2 = tf.compat.v1.Graph()
my_graph3 = tf.compat.v1.Graph()
with my_graph1.as_default():
# 将Python类型数据A传入图中:
A_tf = tf.compat.v1.constant(A, dtype=tf.float32, name="A")
with my_graph2.as_default():
# 将Python类型数据B传入图中:
B_tf = tf.compat.v1.constant(B, dtype=tf.float32, name="B")
with my_graph3.as_default():
# 构建图中的计算节点:
C_tf = tf.matmul(A_tf, B_tf)
# 图构建完毕
with tf.compat.v1.Session(graph=my_graph3) as sess:
C = sess.run(C_tf)
print(C)
此时会报错,输出报错结果如下:
Traceback (most recent call last):
File "E:/Pycharm专业版/Workspace/Data_Science/gensim_operation/word2vec_test/tensorflow_test/preparation_work/图中数据与计算节点交叉引用.py", line 19, in <module>
C_tf = tf.matmul(A_tf, B_tf)
File "E:\Anaconda\Anaconda_Package\lib\site-packages\tensorflow_core\python\util\dispatch.py", line 180, in wrapper
return target(*args, **kwargs)
File "E:\Anaconda\Anaconda_Package\lib\site-packages\tensorflow_core\python\ops\math_ops.py", line 2687, in matmul
with ops.name_scope(name, "MatMul", [a, b]) as name:
File "E:\Anaconda\Anaconda_Package\lib\site-packages\tensorflow_core\python\framework\ops.py", line 6337, in __enter__
g_from_inputs = _get_graph_from_inputs(self._values)
File "E:\Anaconda\Anaconda_Package\lib\site-packages\tensorflow_core\python\framework\ops.py", line 5982, in _get_graph_from_inputs
_assert_same_graph(original_graph_element, graph_element)
File "E:\Anaconda\Anaconda_Package\lib\site-packages\tensorflow_core\python\framework\ops.py", line 5917, in _assert_same_graph
(item, original_item))
ValueError: Tensor("B:0", shape=(3, 2), dtype=float32) must be from the same graph as Tensor("A:0", shape=(2, 3), dtype=float32).
从上述报错结果中可以看到,不同图中的数据和计算节点相互引用计算,会出现错误。ValueError提示很明显,即在第19行中计算矩阵运算时,名为“A:0”的数据对象(即Tensor对象)与名为“B:0”的数据对象在不同图中。在构建图时,各个数据对象和计算节点对象必须在当前图中,不同图之间的资源是不能交叉引用的。
注意:tf.Graph()构造函数是非线程安全的函数,在创建图时需要在单线程或外部保证线程安全。