In TensorFlow, a tensor's static shape is already determined when the graph is built, while its dynamic shape is only known when the graph runs.
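# TensorFlow 1.x graph-mode code (on TF 2.x, use tf.compat.v1.placeholder after calling tf.compat.v1.disable_eager_execution())
import tensorflow as tf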
a = tf.placeholder(tf.float32, [None, 128])
static_shape = a.shape.as_list()
dynamic_shape = tf.shape(a)
print(static_shape)
print(dynamic_shape)
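None means the dimension is unknown at graph-construction time and may be any size; its actual value is fixed only when data is fed in Session.run(). tf.shape() returns the dynamic shape, which is itself a tensor (here of shape (2,), one entry per dimension), so its value is only available when the graph runs.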
Output:
[None, 128]
Tensor("Shape:0", shape=(2,), dtype=int32)
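To make the role of None concrete: in the placeholder above it acts as a wildcard, so feeding a batch of 4 or 256 rows both work, while a value whose second dimension is not 128 is rejected. The snippet below is a pure-Python model of that rule, not TensorFlow code; is_compatible is a hypothetical helper that loosely mirrors TensorShape.is_compatible_with for shapes of known rank.

```python
def is_compatible(static_shape, concrete_shape):
    """Hypothetical model: can a value of concrete_shape be fed where static_shape is declared?"""
    if len(static_shape) != len(concrete_shape):
        return False  # ranks must agree
    # None is a wildcard; every known dimension must match exactly
    return all(s is None or s == c for s, c in zip(static_shape, concrete_shape))


static_shape = [None, 128]
print(is_compatible(static_shape, [4, 128]))     # True: any batch size fits None
print(is_compatible(static_shape, [256, 128]))   # True
print(is_compatible(static_shape, [4, 64]))      # False: 64 != 128
print(is_compatible(static_shape, [4, 128, 1]))  # False: rank mismatch
```

In TensorFlow itself, evaluating tf.shape(a) in Session.run() with a (4, 128) array fed to a yields [4, 128].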
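set_shape() refines a tensor's static shape. It does not change the underlying data, and the new shape must be compatible with the existing static shape (here None can be narrowed to 32, but 128 could not become 64):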
a.set_shape([32, 128])
static_shape = a.shape.as_list()
dynamic_shape = tf.shape(a)
print(static_shape)
print(dynamic_shape)
Output:
[32, 128]
Tensor("Shape_1:0", shape=(2,), dtype=int32)
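The rule set_shape() applies can be sketched as a per-dimension merge: an unknown dimension takes the new value, a known dimension is kept, and two different known values are an error (TensorFlow raises a ValueError in that case). The code below is a pure-Python model, not TensorFlow code; merge_shapes is a hypothetical helper modeled on TensorShape.merge_with.

```python
def merge_shapes(current, new):
    """Hypothetical model of how set_shape() refines a static shape."""
    if len(current) != len(new):
        raise ValueError(f"Shapes {current} and {new} have different ranks")
    merged = []
    for cur, req in zip(current, new):
        if cur is None:
            merged.append(req)  # unknown dim: take the new value (which may itself be None)
        elif req is None or req == cur:
            merged.append(cur)  # known dim is kept
        else:
            raise ValueError(f"Shapes {current} and {new} are not compatible")
    return merged


print(merge_shapes([None, 128], [32, 128]))  # [32, 128]
print(merge_shapes([32, 128], [None, 128]))  # [32, 128]: a known dim is never "forgotten"
try:
    merge_shapes([None, 128], [32, 64])
except ValueError as err:
    print("ValueError:", err)
```

Because the merge only ever adds information, set_shape() is best thought of as an assertion about the static shape, not a way to resize data.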
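tf.reshape() can also reset the shape. Unlike set_shape(), it adds a reshape op and returns a new tensor (the name a is rebound to it); when the graph runs, the fed value must contain exactly 32 × 128 elements, i.e. a batch of exactly 32 rows: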
a = tf.placeholder(tf.float32, [None, 128])
a = tf.reshape(a, [32, 128])
static_shape = a.shape.as_list()
dynamic_shape = tf.shape(a)
print(static_shape)
print(dynamic_shape)
Output:
[32, 128]
Tensor("Shape:0", shape=(2,), dtype=int32)
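Because tf.reshape() is a real op, its constraint is checked at run time: the element count must be unchanged, and at most one target dimension may be -1, in which case it is inferred. Writing tf.reshape(a, [-1, 128]) therefore keeps the batch size flexible (static shape [None, 128]) instead of hard-coding 32. The following is a pure-Python model of that run-time rule; resolve_reshape is a hypothetical helper, not part of TensorFlow.

```python
from math import prod


def resolve_reshape(input_shape, target_shape):
    """Hypothetical model of tf.reshape's run-time check.

    input_shape is the concrete shape of the fed value; target_shape may
    contain at most one -1, which is inferred from the element count.
    """
    total = prod(input_shape)
    if target_shape.count(-1) > 1:
        raise ValueError("at most one dimension can be -1")
    known = prod(d for d in target_shape if d != -1)
    if -1 in target_shape:
        if known == 0 or total % known != 0:
            raise ValueError(f"cannot infer -1 for {total} elements into {target_shape}")
        return [total // known if d == -1 else d for d in target_shape]
    if known != total:
        raise ValueError(f"cannot reshape {total} elements into {target_shape}")
    return list(target_shape)


print(resolve_reshape([32, 128], [32, 128]))   # [32, 128]
print(resolve_reshape([16, 128], [-1, 128]))   # [16, 128]: batch inferred
print(resolve_reshape([4, 10, 32], [-1, 320]))  # [4, 320]
try:
    resolve_reshape([16, 128], [32, 128])     # 2048 elements cannot fill 4096
except ValueError as err:
    print("ValueError:", err)
```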
The following helper returns, for each dimension, the static size when it is known and the dynamic size (a scalar tensor) when the static size is None:
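def get_shape(tensor):
    """Return static dims where known, dynamic scalar tensors where the static dim is None."""
    # Both as_list() and tf.unstack() require the tensor's rank to be known
    static_shape = tensor.shape.as_list()
    dynamic_shape = tf.unstack(tf.shape(tensor))
    dims = [dynamic if static is None else static
            for static, dynamic in zip(static_shape, dynamic_shape)]
    return dims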
b = tf.placeholder(tf.float32, [None, 10, 32])
shape = get_shape(b)
print(shape)
b = tf.reshape(b, [shape[0], shape[1] * shape[2]])
shape = get_shape(b)
print(shape)
Output:
[<tf.Tensor 'unstack:0' shape=() dtype=int32>, 10, 32]
[<tf.Tensor 'unstack_1:0' shape=() dtype=int32>, 320]
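Because the known dimensions come back as plain Python ints, an expression like shape[1] * shape[2] is evaluated immediately as 320 while the graph is being built, and only the unknown batch dimension remains a tensor; that is why the second output line shows 320 rather than a tensor. The pure-Python model below replays the same selection rule with the run-time shape written as a list. merge_static_dynamic is a hypothetical name, and in real TensorFlow the entries chosen for None would be scalar tensors, not ints.

```python
def merge_static_dynamic(static_shape, runtime_shape):
    """Hypothetical model of get_shape(): prefer static dims, else the run-time value."""
    dims = []
    for static, dynamic in zip(static_shape, runtime_shape):
        if static is None:
            dims.append(dynamic)  # unknown at build time: use the run-time value
        elif static != dynamic:
            raise ValueError(f"static dim {static} disagrees with run-time dim {dynamic}")
        else:
            dims.append(static)   # known at build time: keep the plain int
    return dims


# b has static shape [None, 10, 32]; suppose a batch of 4 is fed at run time
shape = merge_static_dynamic([None, 10, 32], [4, 10, 32])
print(shape)                           # [4, 10, 32]
print([shape[0], shape[1] * shape[2]])  # [4, 320]
```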