1. Environment
TensorFlow API r1.12
CUDA 9.2 V9.2.148
cudnn64_7.dll
Python 3.6.3
Windows 10
2. Official documentation
Inserts a dimension of size 1 into the shape of the input tensor.
https://www.tensorflow.org/api_docs/python/tf/expand_dims
tf.expand_dims(
    input,
    axis=None,
    name=None,
    dim=None
)
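Arguments:
(1) input: the input tensor.
(2) axis: a scalar (0-D) giving the index at which the new dimension is inserted into the input tensor's shape. It must lie in the range [-rank(input) - 1, rank(input)]; a negative value counts from the end, so -1 places the new dimension after the last one.
(3) name: (optional) the name of the output tensor.
(4) dim: a scalar equivalent to axis; deprecated.
Returns:
(1) A tensor containing the same data as input, with one additional dimension of size 1 inserted at index axis.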
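The axis rule above can be made concrete with a small pure-Python sketch. `expanded_shape` is a hypothetical helper written only for illustration (it is not a TensorFlow API): it computes the output shape, assuming the documented range [-rank - 1, rank] and the convention that a negative axis counts from the end of the output shape.

```python
def expanded_shape(shape, axis):
    """Hypothetical helper: the shape tf.expand_dims would return for `shape`.

    Assumes the documented rule that axis must lie in [-rank - 1, rank].
    """
    rank = len(shape)
    if not -rank - 1 <= axis <= rank:
        raise ValueError("axis %d out of range [%d, %d] for rank %d"
                         % (axis, -rank - 1, rank, rank))
    # The output has rank + 1 slots, so a negative axis is shifted by
    # rank + 1 (not by rank): -1 means "after the last dimension".
    if axis < 0:
        axis += rank + 1
    return list(shape[:axis]) + [1] + list(shape[axis:])


print(expanded_shape([2], 0))         # [1, 2]
print(expanded_shape([2], -1))        # [2, 1]
print(expanded_shape([2, 3, 5], 2))   # [2, 3, 1, 5]
print(expanded_shape([2, 3, 5], -4))  # [1, 2, 3, 5]
```

Shifting by rank + 1 is what makes the range asymmetric: a rank-r tensor has r + 1 insertion points, so both ends of [-r - 1, r] are reachable.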
3. Examples
(1)tf.expand_dims(input_tensor,0)
>>> import tensorflow as tf
>>> input = tf.constant([1,2], shape=[2])
>>> sess = tf.Session()
>>> sess.run(tf.shape(input))
array([2])
>>> sess.run(tf.shape(tf.expand_dims(input,0)))
array([1, 2])
>>> sess.run(tf.expand_dims(input,0))
array([[1, 2]])
>>> sess.close()
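Inserting at axis 0 is the usual way to turn a single sample into a batch of one. The same result can be spelled several ways; the sketch below uses NumPy (assumed installed; TensorFlow 1.x depends on it), whose `np.expand_dims` follows the same axis convention, so it can be checked without a Session. The name `sample` is illustrative.

```python
import numpy as np

sample = np.array([1, 2])            # shape (2,), like `input` above

batch = np.expand_dims(sample, 0)    # shape (1, 2)
via_index = sample[np.newaxis, :]    # np.newaxis (None) inserts a length-1 axis
via_reshape = sample.reshape(1, -1)  # reshape with an explicit leading 1

print(batch.shape)                   # (1, 2)
print(np.array_equal(batch, via_index), np.array_equal(batch, via_reshape))  # True True
```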
(2)tf.expand_dims(input_tensor,1)
>>> import tensorflow as tf
>>> input = tf.constant([1,2], shape=[2])
>>> sess = tf.Session()
>>> sess.run(tf.shape(input))
array([2])
>>> sess.run(tf.shape(tf.expand_dims(input,1)))
array([2, 1])
>>> sess.run(tf.expand_dims(input,1))
array([[1],
       [2]])
>>> sess.close()
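Example (2) turns a vector into a column. A common reason to do this is broadcasting: TensorFlow's element-wise ops follow NumPy-style broadcasting, so a (2, 1) column combined with a (1, 3) row yields a (2, 3) result. The sketch uses NumPy so it runs without a Session; `x` and `y` are made-up sample values.

```python
import numpy as np

x = np.array([1, 2])
y = np.array([10, 20, 30])

col = np.expand_dims(x, 1)   # shape (2, 1): a column, as in example (2)
row = np.expand_dims(y, 0)   # shape (1, 3): a row

# Broadcasting (2, 1) against (1, 3) gives (2, 3): every x[i] times every y[j].
outer = col * row
print(outer)
# [[10 20 30]
#  [20 40 60]]
```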
(3)tf.expand_dims(input_tensor,-1)
>>> import tensorflow as tf
>>> input = tf.constant([1,2], shape=[2])
>>> sess = tf.Session()
>>> sess.run(tf.shape(input))
array([2])
>>> sess.run(tf.shape(tf.expand_dims(input,-1)))
array([2, 1])
>>> sess.run(tf.expand_dims(input,-1))
array([[1],
       [2]])
>>> sess.close()
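Examples (2) and (3) give the same result because, for a rank-1 input, axis -1 and axis 1 name the same insertion point. In general a negative axis a behaves like a + rank + 1. A quick NumPy check (NumPy assumed to share TensorFlow's convention here):

```python
import numpy as np

x = np.array([1, 2])   # rank 1, so the valid axes are -2, -1, 0, 1

# A negative axis counts from the end of the output (rank + 1 slots),
# so axis a < 0 lands at position a + rank + 1.
for neg in (-2, -1):
    pos = neg + x.ndim + 1
    same = np.array_equal(np.expand_dims(x, neg), np.expand_dims(x, pos))
    print(neg, pos, np.expand_dims(x, neg).shape, same)
# -2 0 (1, 2) True
# -1 1 (2, 1) True
```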
(4) Multi-dimensional input, axis 0
>>> import tensorflow as tf
>>> input = tf.constant([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30], shape=[2,3,5])
>>> sess = tf.Session()
>>> sess.run(tf.shape(input))
array([2, 3, 5])
>>> sess.run(tf.shape(tf.expand_dims(input,0)))
array([1, 2, 3, 5])
>>> sess.run(tf.expand_dims(input,0))
array([[[[ 1,  2,  3,  4,  5],
         [ 6,  7,  8,  9, 10],
         [11, 12, 13, 14, 15]],

        [[16, 17, 18, 19, 20],
         [21, 22, 23, 24, 25],
         [26, 27, 28, 29, 30]]]])
>>> sess.close()
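Only the shape changes: the elements keep their row-major order, so expand_dims behaves like a reshape with a 1 inserted at the normalized axis (which suggests tf.reshape would give the same values). The NumPy sketch below checks this for every valid axis of a (2, 3, 5) tensor, -4 through 3.

```python
import numpy as np

x = np.arange(1, 31).reshape(2, 3, 5)   # same values as the tf.constant above

for axis in range(-x.ndim - 1, x.ndim + 1):   # -4 .. 3
    y = np.expand_dims(x, axis)
    # The flattened data is identical; only the shape differs.
    assert np.array_equal(y.ravel(), x.ravel())
    # Equivalent reshape: insert a 1 at the normalized position.
    pos = axis if axis >= 0 else axis + x.ndim + 1
    target = x.shape[:pos] + (1,) + x.shape[pos:]
    assert np.array_equal(y, x.reshape(target))
    print(axis, y.shape)
# -4 (1, 2, 3, 5)
# -3 (2, 1, 3, 5)
# -2 (2, 3, 1, 5)
# -1 (2, 3, 5, 1)
# 0 (1, 2, 3, 5)
# 1 (2, 1, 3, 5)
# 2 (2, 3, 1, 5)
# 3 (2, 3, 5, 1)
```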
(5) Multi-dimensional input, axis 2
>>> import tensorflow as tf
>>> input = tf.constant([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30], shape=[2,3,5])
>>> sess = tf.Session()
>>> sess.run(tf.shape(input))
array([2, 3, 5])
>>> sess.run(tf.shape(tf.expand_dims(input,2)))
array([2, 3, 1, 5])
>>> sess.run(tf.expand_dims(input,2))
array([[[[ 1,  2,  3,  4,  5]],

        [[ 6,  7,  8,  9, 10]],

        [[11, 12, 13, 14, 15]]],


       [[[16, 17, 18, 19, 20]],

        [[21, 22, 23, 24, 25]],

        [[26, 27, 28, 29, 30]]]])
>>> sess.close()
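Inserting in the middle, as in example (5), is the same as slicing with a new axis at that position, and squeezing that axis undoes it (in TensorFlow the counterpart would be tf.squeeze(t, axis=[2])). The sketch below shows both with NumPy; the variable names are illustrative.

```python
import numpy as np

x = np.arange(1, 31).reshape(2, 3, 5)

y = np.expand_dims(x, 2)          # shape (2, 3, 1, 5), as in example (5)
same = x[:, :, np.newaxis, :]     # slice-notation equivalent
back = np.squeeze(y, axis=2)      # removing that length-1 axis restores x

print(y.shape, np.array_equal(y, same))      # (2, 3, 1, 5) True
print(back.shape, np.array_equal(back, x))   # (2, 3, 5) True
```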
(6) Multi-dimensional input, axis 3
>>> import tensorflow as tf
>>> input = tf.constant([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30], shape=[2,3,5])
>>> sess = tf.Session()
>>> sess.run(tf.shape(input))
array([2, 3, 5])
>>> sess.run(tf.shape(tf.expand_dims(input,3)))
array([2, 3, 5, 1])
>>> sess.run(tf.expand_dims(input,3))
array([[[[ 1],
         [ 2],
         [ 3],
         [ 4],
         [ 5]],

        [[ 6],
         [ 7],
         [ 8],
         [ 9],
         [10]],

        [[11],
         [12],
         [13],
         [14],
         [15]]],


       [[[16],
         [17],
         [18],
         [19],
         [20]],

        [[21],
         [22],
         [23],
         [24],
         [25]],

        [[26],
         [27],
         [28],
         [29],
         [30]]]])
>>> sess.close()
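A trailing length-1 axis like the one in example (6) is commonly used to add a channel dimension, for instance when a layer such as tf.nn.conv2d expects 4-D input in its default NHWC layout. The sketch treats the (2, 3, 5) tensor as a batch of two single-channel 3x5 images; that interpretation is an illustrative assumption, and NumPy stands in for TensorFlow.

```python
import numpy as np

# Hypothetical data: a batch of 2 single-channel 3x5 "images" in (N, H, W) order.
images = np.arange(1, 31).reshape(2, 3, 5)

# Appending a trailing axis gives (N, H, W, C) with C = 1.
nhwc = np.expand_dims(images, 3)

# For a rank-3 input, axis 3 and axis -1 both mean "after the last dimension".
assert np.array_equal(nhwc, np.expand_dims(images, -1))
assert np.array_equal(nhwc, images[..., np.newaxis])
print(nhwc.shape)   # (2, 3, 5, 1)
```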