版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/sinat_29957455/article/details/86100641
tf.concat函数
:函数功能比较简单,主要用于连接两个数组
参数
:
values
:需要连接的数组axis
:从哪个维度来连接数组
例子
:
- 一维数组
import tensorflow as tf
if __name__ == "__main__":
a = [1,2,3]
b = [4,5,6]
c = tf.concat([a,b],0)
sess = tf.InteractiveSession()
print(sess.run(c)) #[1 2 3 4 5 6]
注意
:axis参数不能超过数组的维度。如果超过数组的维度,如下:
c = tf.concat([a,b],1)
则会报,ValueError: Shape must be at least rank 2 but is rank 1 for 'concat'
,意思是数组至少是二维,axis才能为1。
- 二维数组
a = [[1,1],[2,2],[3,3]]
b = [[4,4],[5,5],[6,6]]
c = tf.concat([a,b],0)
print(sess.run(c))
"""
[[1 1]
[2 2]
[3 3]
[4 4]
[5 5]
[6 6]]
"""
c = tf.concat([a,b],1) #等价于tf.concat([a,b],-1)
print(sess.run(c))
"""
[[1 1 4 4]
[2 2 5 5]
[3 3 6 6]]
"""
- 三维数组
a = [[[1,1],[2,2]],[[3,3],[4,4]]]
b = [[[5,5]],[[6,6]]]
c = tf.concat([a,b],1)
print(sess.run(c))
"""
[[[1 1]
[2 2]
[5 5]]
[[3 3]
[4 4]
[6 6]]]
"""
注意
:在使用tf.concat
函数连接两个数组的时候,数组该维度必须是一致的,否则会报错,如下:
c = tf.concat([a,b],0)
错误提示ValueError: Dimension 0 in both shapes must be equal, but are 2 and 1
,意思是a在第1个维度上shape是2,而b在第一个维度上shape是1。
总结
:如何来判断数组是否在该个维度上的shape是相同的呢?其实很简单,我们根据tf.concat的axis参数来去数组的[],0表示去掉最外面的一层,1去掉两层,以此类推,下面举例说明一下。
如:最后一个例子中的c = tf.concat([a,b],1)
,我们先将a去掉最外面两层[]
,变成了[1,1],[2,2]和[3,3],[4,4]]
,然后再将b去掉最外面两层[]
,变成了[5,5]和[6,6]
,此时再进行concat
,可以发现此时的shape是相等的。