Tensorflow中的dynamic shape、static shape及reshape、set_shape

转自:https://blog.csdn.net/qq_21949357/article/details/77987928

这个问题是在学习Tensorflow当中,reshape与set_shape的区别时引出的。

在学习Tensorflow的cifar10代码的时候,发现在处理cifar数据集的读入数据时,demo使用的是如下代码来处理读入数据而不是直接使用熟悉的reshape。

float_image.set_shape([height, width, 3])
read_input.label.set_shape([1])
  • 1
  • 2

官方文档对此的解释是:

The tf.Tensor.set_shape method updates the static shape of a Tensor object, and it is typically used to provide additional shape information when this cannot be inferred directly. It does not change the dynamic shape of the tensor.

The tf.reshape operation creates a new tensor with a different dynamic shape.

由此引出了dynamic shape和static shape的概念。

Tensorflow在构建图的时候,tensor的shape被称为static(inferred);而在实际运行中,常常出现图中tensor的具体维数不确定而用placeholder代替的情况,因此static shape未必是已知的。tensor在训练过程中的实际维数被称为dynamic shape,而dynamic shape是一定的。

看如下例子:

import tensorflow as tf
x1 = tf.placeholder(tf.int32)
print(x1.get_shape())

sess = tf.Session()
print(sess.run(tf.shape(x1), feed_dict={x1:[0,1,2,3]}))
print(sess.run(tf.shape(x1), feed_dict={x1:[[0,1],[2,3]]}))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

在实际运行的时候,显示的结果是

unknown 
[4] 
[2,2]

第一个结果代表构建图时x1的shape,此时它是未知的,尽管在Session中运行了以后给出了[0,1,2,3]的传入值,但无法改变图中的x1的static shape。而后面的print也可以看到,随着传入值的不同,x1的dynamic shape是会变化的。

顺带一说get_shape()方法和tf.shape()的区别,get_shape()是tensor的方法,返回一个tuple,而tf.shape()则返回一个tensor。第三行的unknown直接代表了x1的shape数组,假如我们这里用:

print(tf.shape(x1))
  • 1

就会看到显示的是一个Tensor,不过它的shape里有unknown。因此,我们在获取x1的shape的时候,要用tf.shape方法,让指针指向一个tensor,不然使用get_shape()指针就会指向一个tuple从而报错。

下面说下set_shape()和reshape()的区别。其实从官方说明中可以看出,这两个主要是适用场合的区别,前者用于更新图中某个tensor的shape,而后者则往往用于动态地创建一个新的tensor。

一个set_shape的典型用法如下:

import tensorflow as tf
x1 = tf.placeholder(tf.int32)
x1.set_shape([22])
print(x1.get_shape())

sess = tf.Session()
#print(sess.run(tf.shape(x1), feed_dict={x1:[0,1,2,3]}))
print(sess.run(tf.shape(x1), feed_dict={x1:[[0,1],[2,3]]}))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

此时,运行结果为:

(2,2) 
[2,2]

这代表了图中最开始没有shape的x1在使用了set_shape后,它的图中的信息已经改变了,如果取消掉注释就会报错,因为我们传入了和图不符合的参数。

reshape的典型用法则是这样:

import tensorflow as tf
x1 = tf.placeholder(tf.int32)
x2 = tf.reshape(x1, [2,2])
print(x1.get_shape())

sess = tf.Session()
print(sess.run(tf.shape(x2), feed_dict={x1:[0,1,2,3]}))
print(sess.run(tf.shape(x2), feed_dict={x1:[[0,1],[2,3]]}))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

此时运算结果为:

(2,2) 
[2,2] 
[2,2]

即它并不是想改变图,而只是想创造一个新的tensor以供我们使用。

但是reshape能否和set_shape有着相同的用法,即用来改变图?我们试着修改上面的代码:

import tensorflow as tf
x1 = tf.placeholder(tf.int32)
x1 = tf.reshape(x1, [2,2]) # use tf.reshape()
print(tf.shape(x1))

sess = tf.Session()
#print(sess.run(tf.shape(x1), feed_dict={x1:[0,1,2,3]}))
print(sess.run(tf.shape(x1), feed_dict={x1:[[0,1],[2,3]]}))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

经测试,reshape后x1的shape也发生了变化,注释不取消仍然会有报错现象。

那么set_shape和reshape的用法是否完全一样呢?还是有一定差别的。reshape可以改变原有tensor的shape,而set_shape只能更新信息没办法直接改变值,可以参考下面的程序:

import tensorflow as tf
x1 = tf.Variable([[0, 1], [2, 3]])
print(x1.get_shape())

x1 = tf.reshape(x1, [4, 1]) # if we use x1.set_shape([4, 1]),the program cannot run
print(x1.get_shape())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

最后总结一下吧,reshape应用场合比较广泛,当我们需要创建新的tensor或者动态地改变原有tensor的shape的时候可以使用;而当我们只是想更新图中某个tensor的shape或者补充某个tensor的shape信息可以使用set_shape来进行更新。

猜你喜欢

转载自blog.csdn.net/class_brick/article/details/80594283