How to use tf.transpose (the multi-dimensional case looks complicated but is actually simple)
- tf.transpose(a, perm=None, name='transpose')
- Transposes a. Permutes the dimensions according to perm.
- The returned tensor's dimension i will correspond to the input dimension perm[i]. If perm is not given, it is set to (n-1...0), where n is the rank of the input tensor. Hence by default, this operation performs a regular matrix transpose on 2-D input Tensors.
- For example:
- # 'x' is [[1 2 3]
- # [4 5 6]]
- tf.transpose(x) ==> [[1 4]
- [2 5]
- [3 6]]
- # Equivalently
- tf.transpose(x, perm=[1, 0]) ==> [[1 4]
- [2 5]
- [3 6]]
- # 'perm' is more useful for n-dimensional tensors, for n > 2
- # 'x' is [[[1 2 3]
- # [4 5 6]]
- # [[7 8 9]
- # [10 11 12]]]
- # Take the transpose of the matrices in dimension-0
- tf.transpose(x, perm=[0, 2, 1]) ==> [[[1 4]
- [2 5]
- [3 6]]
- [[7 10]
- [8 11]
- [9 12]]]
- Args:
- a: A Tensor.
- perm: A permutation of the dimensions of a.
- name: A name for the operation (optional).
- Returns:
- A transposed Tensor.
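The default behaviour described above can be reproduced without TensorFlow. The following is a minimal sketch in plain Python, not the library's implementation; `default_perm` and `transpose_2d` are hypothetical helper names introduced only for illustration.

```python
def default_perm(rank):
    """Return the permutation used when perm is omitted: (n-1, ..., 0)."""
    return list(range(rank - 1, -1, -1))


def transpose_2d(matrix):
    """Transpose a 2-D nested list; zip(*matrix) regroups columns as rows."""
    return [list(column) for column in zip(*matrix)]


x = [[1, 2, 3],
     [4, 5, 6]]

print(default_perm(2))   # [1, 0]
print(default_perm(3))   # [2, 1, 0]
print(transpose_2d(x))   # [[1, 4], [2, 5], [3, 6]]
```

For a rank-2 input the default permutation is [1, 0], which is why omitting perm gives an ordinary matrix transpose.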
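This article focuses on the higher-dimensional case.
To make it easier to picture, we use a stack of matrices as the example.
First, some notation: 2 x (3*4) denotes two 3*4 matrices (in effect, a 3-D tensor of shape (2, 3, 4)).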
x = [[[1,2,3,4],[5,6,7,8],[9,10,11,12]],[[21,22,23,24],[25,26,27,28],[29,30,31,32]]]
Output:
---------------
[[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]]
[[21 22 23 24]
[25 26 27 28]
[29 30 31 32]]]
---------------
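Here is the key point:
The second argument of tf.transpose is perm. In perm=[0,1,2] (the identity permutation), 0 refers to the "height" of the 3-D array (that is, the number of 2-D matrices), 1 refers to the rows of each 2-D matrix, and 2 refers to its columns.
tf.transpose(x, perm=[1,0,2]) therefore swaps the height axis and the row axis of the 3-D array.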
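The rule "output dimension i corresponds to input dimension perm[i]" can be sketched in plain Python on nested lists. This is only an illustration under that rule, not how TensorFlow implements the op; `get_shape`, `get_item` and `transpose_nested` are hypothetical helpers, and the input is assumed to be a regular (non-ragged) nested list.

```python
def get_shape(nested):
    """Shape of a regular nested list, e.g. (2, 3, 4)."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return tuple(shape)


def get_item(nested, index):
    """Follow a multi-dimensional index into a nested list."""
    for i in index:
        nested = nested[i]
    return nested


def transpose_nested(nested, perm):
    """Permute the axes of a nested list according to perm."""
    in_shape = get_shape(nested)
    out_shape = tuple(in_shape[p] for p in perm)

    def build(out_index):
        if len(out_index) == len(out_shape):
            # Output axis i walks along input axis perm[i].
            in_index = [0] * len(perm)
            for out_axis, in_axis in enumerate(perm):
                in_index[in_axis] = out_index[out_axis]
            return get_item(nested, in_index)
        size = out_shape[len(out_index)]
        return [build(out_index + (i,)) for i in range(size)]

    return build(())


x = [[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
     [[21, 22, 23, 24], [25, 26, 27, 28], [29, 30, 31, 32]]]

c = transpose_nested(x, [1, 0, 2])
print(get_shape(c))  # (3, 2, 4)
print(c[0])          # [[1, 2, 3, 4], [21, 22, 23, 24]]
```

With perm=[1,0,2] the first output block pairs row 0 of the first matrix with row 0 of the second, which is what swapping height and rows means.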
We write a test program as follows:
- import tensorflow as tf
- # Nested list of shape (2, 3, 4); tf.transpose converts it to a tensor.
- x = [[[1,2,3,4],[5,6,7,8],[9,10,11,12]],[[21,22,23,24],[25,26,27,28],[29,30,31,32]]]
- # One op for each of the six permutations of the three axes.
- a=tf.transpose(x, [0, 1, 2])
- b=tf.transpose(x, [0, 2, 1])
- c=tf.transpose(x, [1, 0, 2])
- d=tf.transpose(x, [1, 2, 0])
- e=tf.transpose(x, [2, 1, 0])
- f=tf.transpose(x, [2, 0, 1])
- # tf.Session is the TensorFlow 1.x API (tf.compat.v1.Session in 2.x).
- # Evaluate each op in turn and print the result.
- with tf.Session() as sess:
- print ('---------------')
- print (sess.run(a))
- print ('---------------')
- print (sess.run(b))
- print ('---------------')
- print (sess.run(c))
- print ('---------------')
- print (sess.run(d))
- print ('---------------')
- print (sess.run(e))
- print ('---------------')
- print (sess.run(f))
- print ('---------------')
We expect the results to have the following shapes:
a: 2 x 3*4
b: 2 x 4*3
c: 3 x 2*4
d: 3 x 4*2
e: 4 x 3*2
f: 4 x 2*3
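These shapes follow from a single rule: output dimension i has the size of input dimension perm[i]. A small sketch that enumerates all six permutations of a (2, 3, 4) input; `transposed_shape` is a hypothetical helper used only for illustration.

```python
from itertools import permutations


def transposed_shape(shape, perm):
    """Output dimension i takes the size of input dimension perm[i]."""
    return tuple(shape[p] for p in perm)


in_shape = (2, 3, 4)
for perm in permutations(range(3)):
    print(list(perm), "->", transposed_shape(in_shape, perm))

# [0, 1, 2] -> (2, 3, 4)
# [0, 2, 1] -> (2, 4, 3)
# [1, 0, 2] -> (3, 2, 4)
# [1, 2, 0] -> (3, 4, 2)
# [2, 0, 1] -> (4, 2, 3)
# [2, 1, 0] -> (4, 3, 2)
```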
Running the script produces results that match these shapes:
- ---------------
- [[[ 1 2 3 4]
- [ 5 6 7 8]
- [ 9 10 11 12]]
- [[21 22 23 24]
- [25 26 27 28]
- [29 30 31 32]]]
- ---------------
- [[[ 1 5 9]
- [ 2 6 10]
- [ 3 7 11]
- [ 4 8 12]]
- [[21 25 29]
- [22 26 30]
- [23 27 31]
- [24 28 32]]]
- ---------------
- [[[ 1 2 3 4]
- [21 22 23 24]]
- [[ 5 6 7 8]
- [25 26 27 28]]
- [[ 9 10 11 12]
- [29 30 31 32]]]
- ---------------
- [[[ 1 21]
- [ 2 22]
- [ 3 23]
- [ 4 24]]
- [[ 5 25]
- [ 6 26]
- [ 7 27]
- [ 8 28]]
- [[ 9 29]
- [10 30]
- [11 31]
- [12 32]]]
- ---------------
- [[[ 1 21]
- [ 5 25]
- [ 9 29]]
- [[ 2 22]
- [ 6 26]
- [10 30]]
- [[ 3 23]
- [ 7 27]
- [11 31]]
- [[ 4 24]
- [ 8 28]
- [12 32]]]
- ---------------
- [[[ 1 5 9]
- [21 25 29]]
- [[ 2 6 10]
- [22 26 30]]
- [[ 3 7 11]
- [23 27 31]]
- [[ 4 8 12]
- [24 28 32]]]
- ---------------
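The printed blocks can also be checked by writing the index rule out explicitly. For perm=[1,2,0] the rule gives out[i][j][k] = x[k][i][j], and for perm=[2,1,0] it gives out[i][j][k] = x[k][j][i]. The sketch below rebuilds both results with list comprehensions in plain Python and compares them with the output copied from the run above; the variable names are illustrative only.

```python
x = [[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
     [[21, 22, 23, 24], [25, 26, 27, 28], [29, 30, 31, 32]]]

# perm=[1, 2, 0]: output shape (3, 4, 2), out[i][j][k] = x[k][i][j]
d = [[[x[k][i][j] for k in range(2)] for j in range(4)] for i in range(3)]

# perm=[2, 1, 0]: output shape (4, 3, 2), out[i][j][k] = x[k][j][i]
e = [[[x[k][j][i] for k in range(2)] for j in range(3)] for i in range(4)]

# Values copied from the printed results for d and e.
printed_d = [[[1, 21], [2, 22], [3, 23], [4, 24]],
             [[5, 25], [6, 26], [7, 27], [8, 28]],
             [[9, 29], [10, 30], [11, 31], [12, 32]]]
printed_e = [[[1, 21], [5, 25], [9, 29]],
             [[2, 22], [6, 26], [10, 30]],
             [[3, 23], [7, 27], [11, 31]],
             [[4, 24], [8, 28], [12, 32]]]

print(d == printed_d)  # True
print(e == printed_e)  # True
```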