转自:不知道哪里和yychenxie21
相同点:都可以得到tensor a的尺寸
不同点:tf.shape()中a 数据的类型可以是tensor, list, array
a.get_shape()中a的数据类型只能是tensor,且返回的是一个元组(tuple)
注意到tf.shape(a)返回的是一个OP需要再sess.run(),而a.get_shape()得到一个实际的元组。
-
import tensorflow
as tf
-
import numpy
as np
-
-
x=tf.constant([[
1,
2,
3],[
4,
5,
6]]
-
y=[[
1,
2,
3],[
4,
5,
6]]
-
z=np.arange(
24).reshape([
2,
3,
4]))
-
-
sess=tf.Session()
-
# tf.shape()
-
x_shape=tf.shape(x)
# x_shape 是一个tensor
-
y_shape=tf.shape(y)
# <tf.Tensor 'Shape_2:0' shape=(2,) dtype=int32>
-
z_shape=tf.shape(z)
# <tf.Tensor 'Shape_5:0' shape=(3,) dtype=int32>
-
print sess.run(x_shape)
# 结果:[2 3]
-
print sess.run(y_shape)
# 结果:[2 3]
-
print sess.run(z_shape)
# 结果:[2 3 4]
-
-
-
#a.get_shape()
-
x_shape=x.get_shape()
# 返回的是TensorShape([Dimension(2), Dimension(3)]),不能使用 sess.run() 因为返回的不是tensor 或string,而是元组
-
x_shape=x.get_shape().as_list()
# 可以使用 as_list()得到具体的尺寸,x_shape=[2 3]
-
y_shape=y.get_shape()
# AttributeError: 'list' object has no attribute 'get_shape'
-
z_shape=z.get_shape()
# AttributeError: 'numpy.ndarray' object has no attribute 'get_shape'
-
-
-
import tensorflow as tf
input_tensor=tf.get_variable(name="input",shape=[1,5,5,1],initializer=tf.truncated_normal_initializer(stddev=0.1))
input_tensor_size=tf.shape(input_tensor)
print("直接运行tf.shape() :",input_tensor_size) #输出是tf.Tensor 'Shape:0' shape=(4,) dtype=int32
print("运行tensor.get_shape() :",input_tensor.get_shape()) #输出是TensorShape([Dimension(1), Dimension(5), Dimension(5), Dimension(1)])
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
print("在sess.run()中运行tf.shape():",sess.run(input_tensor_size)) #输出是array([1, 5, 5, 1], dtype=int32)
print("===========================")