tf.rehspe()用法 : tf.reshape(tensor, shape, name=None)
>>> import tensorflow as tf>>> import numpy as np
## 创建一个数组a
>>> a = np.arange(24)
>>> a
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23])
① reshape中shape默认用列表传入
>>> tf.reshape(a,[12,2])
<tf.Tensor 'Reshape:0' shape=(12, 2) dtype=int32>
array([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23]])
② reshape中shape中的-1用法
## 行数固定4,列数默认计算
>>> sess.run(tf.reshape(a,[4,-1]))
array([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23]])
## 列数固定6,行数默认计算
>>> sess.run(tf.reshape(a,[-1,4]))
array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]])
③ reshape中shape中传入三个参数的意思
>>> sess.run(tf.reshape(a,[2,3,4]))array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
tf.shape(a)用法和a.get_shape()
① 两个函数都可以的到tensor的尺寸
② tf.shape(a)中的数据类型可以是tensor,array,list 但是a.get_shape 只能是tensor, 且返回值是元组
##创建一个数组a
>>> a = np.arange(24)
>>> a
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23])
>>> b = tf.reshape(a,[4,6])
>>> b
<tf.Tensor 'Reshape_11:0' shape=(4, 6) dtype=int32>
>>> sess.run(b)
array([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23]])
## 创建一个列表c
>>> c = [1,2,3]
>>> c
[1, 2, 3]
我们来看tf.shape(x)函数作用a,b,c的输出:
>>> sess.run(tf.shape(a))
array([24])
>>> sess.run(tf.shape(b))
array([4, 6])
>>> sess.run(tf.shape(c))
array([3])
我们来看x.get_shape()函数作用a,b,c的输出:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: 'numpy.ndarray' object has no attribute 'get_shape'
>>> b.get_shape()TensorShape([Dimension(4), Dimension(6)])
>>> c.get_shape()Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: 'list' object has no attribute 'get_shape'
说明x.get_shape()中,x只能是tensor,否则报错,并且返回的是一个元组,可以分别取出行和列:
>>> b.get_shape()
TensorShape([Dimension(4), Dimension(6)])
>>> print(b.get_shape())
(4, 6)
>>> print( b.get_shape()[0])
4
>>> print( b.get_shape()[1])
6
>>> b.get_shape()[0].value
4
>>> b.get_shape()[1].value
6
参考: https://blog.csdn.net/fireflychh/article/details/73611021