tensorflow中获取shape的方法比较

tf.shape(xxx) 和 xxx.get_shape()比较

  1. 相同点:都可以得到tensor xxx 的尺寸
  2. 不同点:tf.shape(xxx)中xxx数据的类型可以是tensor,list,array;而xxx.get_shape()中的xxx的数据类型必须是tensor,且返回的是一个tuple.可以通过xxx.get_shape().as_list()得到一个list。

例如:

x= tf.truncated_normal([32, 32, 3], dtype=tf.float32)

print(tf.shape(x))
print(x.get_shape())
print(x.get_shape().as_list())
  • 1
  • 2
  • 3
  • 4
  • 5

输出:

Tensor("Shape:0", shape=(3,), dtype=int32)
(32, 32, 3)
[32, 32, 3]
  • 1
  • 2
  • 3

注意:dtype=int32是tf.shape()这个op的输出类型,默认为tf.int32。

猜你喜欢

转载自blog.csdn.net/yinxingtianxia/article/details/78121941
今日推荐