tf.reshap() tf.shape(x)与x.get_shape()

 tf.rehspe()用法 : tf.reshape(tensor, shape, name=None)

>>> import tensorflow as tf
>>> import numpy as np

## 创建一个数组a

>>> a = np.arange(24)
>>> a
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23])

① reshape中shape默认用列表传入

>>> tf.reshape(a,[12,2])
<tf.Tensor 'Reshape:0' shape=(12, 2) dtype=int32>

>>> sess.run(tf.reshape(a,[4,6]))
array([[ 0,  1,  2,  3,  4,  5],
       [ 6,  7,  8,  9, 10, 11],
       [12, 13, 14, 15, 16, 17],

       [18, 19, 20, 21, 22, 23]])

② reshape中shape中的-1用法

## 行数固定4,列数默认计算

>>> sess.run(tf.reshape(a,[4,-1]))
array([[ 0,  1,  2,  3,  4,  5],
       [ 6,  7,  8,  9, 10, 11],
       [12, 13, 14, 15, 16, 17],
       [18, 19, 20, 21, 22, 23]])

## 列数固定6,行数默认计算
>>> sess.run(tf.reshape(a,[-1,4]))
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15],
       [16, 17, 18, 19],
       [20, 21, 22, 23]])

③ reshape中shape中传入三个参数的意思

>>> sess.run(tf.reshape(a,[2,3,4]))
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])

 tf.shape(a)用法和a.get_shape()

① 两个函数都可以的到tensor的尺寸

② tf.shape(a)中的数据类型可以是tensor,array,list 但是a.get_shape 只能是tensor, 且返回值是元组

##创建一个数组a
>>> a = np.arange(24)
>>> a
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23])

## 从a的array中创建一个tensor b
>>> b = tf.reshape(a,[4,6])
>>> b
<tf.Tensor 'Reshape_11:0' shape=(4, 6) dtype=int32>
>>> sess.run(b)
array([[ 0,  1,  2,  3,  4,  5],
       [ 6,  7,  8,  9, 10, 11],
       [12, 13, 14, 15, 16, 17],

       [18, 19, 20, 21, 22, 23]])

## 创建一个列表c
>>> c = [1,2,3]
>>> c
[1, 2, 3]
我们来看tf.shape(x)函数作用a,b,c的输出:

>>> sess.run(tf.shape(a))
array([24])
>>> sess.run(tf.shape(b))
array([4, 6])
>>> sess.run(tf.shape(c))
array([3])
我们来看x.get_shape()函数作用a,b,c的输出:

>>> a.get_shape()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>

AttributeError: 'numpy.ndarray' object has no attribute 'get_shape'

>>> b.get_shape()

TensorShape([Dimension(4), Dimension(6)])

>>> c.get_shape()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>

AttributeError: 'list' object has no attribute 'get_shape'

说明x.get_shape()中,x只能是tensor,否则报错,并且返回的是一个元组,可以分别取出行和列:

>>> b.get_shape()
TensorShape([Dimension(4), Dimension(6)])
>>> print(b.get_shape())
(4, 6)
>>> print( b.get_shape()[0])
4
>>> print( b.get_shape()[1])
6
>>> b.get_shape()[0].value
4
>>> b.get_shape()[1].value
6

参考: https://blog.csdn.net/fireflychh/article/details/73611021

猜你喜欢

转载自blog.csdn.net/li_haiyu/article/details/80063842
tf