0. 前言
在使用TensorFlow要明确下面的概念:
- 用图表示计算任务;
- 在会话的上下文中执行图;
- 数据用Tensor表示;
- 变量用来维护状态;
- feed为操作赋值,fetch获取数据;
TensorFlow的程序一般有如下2个阶段:
- 构建阶段:构建图;
- 执行阶段:将构建好的图在会话中执行得到结果。
1. 构建图
下面将用代码来简单模拟一下构建图:
在默认的图中添加数据和操作,对2个数据a, b(矩阵)执行乘法操作得到结果c;再将a c作为输入,进行矩阵相加的操作。
import tensorflow as tf
tf.compat.v1.disable_v2_behavior()
# 1. 定义2个常量
a = tf.constant([[1, 2], [3, 4]], dtype=tf.int32, name='a')
b = tf.constant([5, 6, 7, 8], dtype=tf.int32, shape=[2, 2], name='b')
# 2. a,b作为输入,进行矩阵乘法操作
c = tf.matmul(a, b)
print("变量a是否在默认图中:{}".format(a.graph is tf.compat.v1.get_default_graph()))
print("变量c是否在默认图中:{}".format(c.graph is tf.compat.v1.get_default_graph()))
输出
变量a是否在默认图中:True
变量c是否在默认图中:True
因为在TensorFlow库中有一个默认的图,可以直接在其上添加节点,上边就是在默认图上添加节点,所以在查看变量是否在默认图时,返回的是true。
除此之外,也可以自己创建图:
# 创建另一张图
my_graph1 = tf.Graph()
此时,遇到的问题是,如何区分默认图和新创建的图呢,如下:
with my_graph1.as_default():
"""此代码块中使用的是创建的图my_graph1"""
m = tf.constant(7.0, name='const_m')
print(m)
print("变量m是否在默认图中:{}".format(m.graph is tf.compat.v1.get_default_graph()))
print("变量m是否在新图my_graph1中:{}".format(m.graph is my_graph1))
pass
输出
Tensor("const_m:0", shape=(), dtype=float32)
变量m是否在默认图中:True
变量m是否在新图my_graph1中:True
此处可能会有疑问,为什么都返回的是true;这是因为上边说明了在with下面的那个代码块中使用的是创建的新图,即这个代码块中默认的图就是我们创建的新图。
所以,把此句print("变量m是否在默认图中:{}".format(m.graph is tf.compat.v1.get_default_graph()))
放在with外部,再运行就会有如下输出
变量m是否在默认图中:False
注意:还需要明确的一点是,对于不同的图,不能对他们的变量进行操作。
为了说明上面这个问题,再创建一个新图my_graph2,然后对my_graph1中的变量m和my_graph2中的变量n执行add操作,看会出现什么情况:
my_graph2 = tf.Graph()
with my_graph2.as_default():
n = tf.constant(9.0, name='const_n')
print("变量n是否在新图my_graph2中:{}".format(n.graph is my_graph2))
pass
"""不能使用两个图中的变量进行操作"""
res = tf.add(m, n)
此时,程序会报错,有如下的错误提示:
Tensor("const_n:0", shape=(), dtype=float32) must be from the same graph
即必须是同一张图 !
# 3. 以a,c作为输入,对矩阵进行相加操作
add_ac = tf.add(a, c, name='matrix_add')
print(add_ac)
2. 执行图
执行阶段是在会话的上下文中进行的,所以就会涉及到创建会话、运行、最后关闭。
需要注意的是,跟上面图的创建一样,默认情况下,创建的session是属于默认图的。
# 4.执行阶段:在Session的上下文中进行
"""会话创建、运行、关闭"""
sess = tf.compat.v1.Session() # 属于默认图
result = sess.run(c) # 调用run,执行矩阵乘法,得到c的结果
print("type:{}, \nvalue:\n{}".format(type(result), result)) # 打印类型及值
sess.close() # 关闭会话
此句result = sess.run(c)
中run方法中的参数,就是想要得到结果的变量;比如想得到c的值,就传递c;想得到add_ac的值就将add_ac传进去。
输出结果:
type:<class 'numpy.ndarray'>,
value:
[[19 22]
[43 50]]
如果想一起得到2个结果的值:c和add_ac,可以有如下写法:
result = sess.run(fetches=[c, add_ac])
print("type:{}, \nvalue:\n{}\n".format(type(result), result)) # 打印类型及值
此时result的类型就是一个列表了
type:<class 'list'>,
value:
[array([[19, 22],
[43, 50]]), array([[20, 24],
[46, 54]])]
还需要注意的是,session是生命周期的,在关闭了之后就不能再对其进行操作了,不然就会出现此错误提示:RuntimeError: Attempted to use a closed Session.
上面需要在session用完了之后关闭,除了上面这种写法,还可以通过with来写,并且他就不需要人为close了(这种写法更加好):
with tf.compat.v1.Session() as sess2:
# 获得张量add_ac的结果
print("sess2:\n",sess2.run(add_ac))
pass
获得张量的结果有2中方法:c.eval()
或者session.run(c)