tf.shape(xxx) 和 xxx.get_shape()比较
- 相同点:都可以得到tensor xxx 的尺寸
- 不同点:tf.shape(xxx)中xxx数据的类型可以是tensor,list,array;而xxx.get_shape()中的xxx的数据类型必须是tensor,且返回的是一个tuple.可以通过xxx.get_shape().as_list()得到一个list。
例如:
x= tf.truncated_normal([32, 32, 3], dtype=tf.float32)
print(tf.shape(x))
print(x.get_shape())
print(x.get_shape().as_list())
- 1
- 2
- 3
- 4
- 5
输出:
Tensor("Shape:0", shape=(3,), dtype=int32)
(32, 32, 3)
[32, 32, 3]
- 1
- 2
- 3
注意:dtype=int32是tf.shape()这个op的输出类型,默认为tf.int32。