TensorFlow函数:tf.transpose()及其参数perm的理解
tf.transpose函数
tf.transpose(
a,
perm=None,
name='transpose',
conjugate=False
)
函数参数:
- a:一个 Tensor.
- perm:a 的维数的排列.
- name:操作的名称(可选).
- conjugate:可选 bool,将其设置为 True 在数学上等同于 tf.conj(tf.transpose(input)).
返回:
- tf.transpose 函数返回一个转置 Tensor.
对于参数perm的理解:
tensorflow 里面的 tensor是先从高维向低维算起的 。
我们先设置一个张量(Tensor)a:
a = tf.constant([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]])
这里a的维数就是[2,2,3]
而perm的意思就是将tensor对应的不同的维数之间变换,
比如
- 若a的维度不变时,perm = [0,1,2]
- 当perm = [2,1,0],则原tensor = [3,2,2],全部倒过来
- 当perm = [0,2,1],则原tensor = [2,3,2],后两维置换
tf.transpose(x, perm=[0, 2, 1])
==> [[[1,4],
[2,5] ,
[3,6]] ,
[[7,10] ,
[8,11] ,
[9,12]]]
其他代码例子
x = tf.constant([[1, 2, 3], [4, 5, 6]])
tf.transpose(x) # [[1, 4]
# [2, 5]
# [3, 6]]
# Equivalently
tf.transpose(x, perm=[1, 0]) # [[1, 4]
# [2, 5]
# [3, 6]]
# If x is complex, setting conjugate=True gives the conjugate transpose
x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
[4 + 4j, 5 + 5j, 6 + 6j]])
tf.transpose(x, conjugate=True) # [[1 - 1j, 4 - 4j],
# [2 - 2j, 5 - 5j],
# [3 - 3j, 6 - 6j]]
# 'perm' is more useful for n-dimensional tensors, for n > 2
x = tf.constant([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]])
# Take the transpose of the matrices in dimension-0
# (this common operation has a shorthand `matrix_transpose`)
tf.transpose(x, perm=[0, 2, 1]) # [[[1, 4],
# [2, 5],
# [3, 6]],
# [[7, 10],
# [8, 11],
# [9, 12]]]
参考资料:(感谢!)
https://www.w3cschool.cn/tensorflow_python/tensorflow_python-z61d2no5.html
https://blog.csdn.net/appleml/article/details/71070767