tensorflow中的静态维度和动态维度

参考:

1. TensorFlow: Shapes and dynamic dimensions一文中,对张量的静态和动态维度做了描述。

  • 使用tf.get_shape()获取静态维度
  • 使用tf.shape获取动态维度
    如果你的placeholder输入的维度都是固定的情况下,使用get_shape()。但是很多情况下,我们希望想训练得到的网络可以用于任意大小的图像,这时你的placeholder就的输入维度都是[None,None,None,color_dim]这样的,在这种情况下,后续网络中如果需要得到tensor的维度,则需要使用tf.shape。

2 . https://blog.csdn.net/LoseInVain/article/details/78762739

3. tensor.shap.as_list()返回静态维度。tf.shape(tensor)返回动态维度

参考bert->modeling->get_shape_list()函数

def get_shape_list(tensor, expected_rank=None, name=None):
  """Returns a list of the shape of tensor, preferring static dimensions.

  Args:
    tensor: A tf.Tensor object to find the shape of.
    expected_rank: (optional) int. The expected rank of `tensor`. If this is
      specified and the `tensor` has a different rank, and exception will be
      thrown.
    name: Optional name of the tensor for the error message.

  Returns:
    A list of dimensions of the shape of tensor. All static dimensions will
    be returned as python integers, and dynamic dimensions will be returned
    as tf.Tensor scalars.
  """
  if name is None:
    name = tensor.name

  if expected_rank is not None:
    assert_rank(tensor, expected_rank, name)
  # tensor.shape.as_list()返回静态维度
  shape = tensor.shape.as_list()

  non_static_indexes = []
  for (index, dim) in enumerate(shape):
    if dim is None:
      non_static_indexes.append(index)

  if not non_static_indexes:
    return shape

  dyn_shape = tf.shape(tensor)
  for index in non_static_indexes:
    shape[index] = dyn_shape[index]
  return shape

猜你喜欢

转载自blog.csdn.net/biubiubiu888/article/details/86526653