版权声明:转载请标明出处,谢谢! https://blog.csdn.net/kdongyi/article/details/82910632
函数形式:
tf.split(
value,
num_or_size_splits,
axis=0,
num=None,
name='split'
)
参数:
- value :输入的tensor
- num_or_size_splits :每个分割后的张量的尺寸,如果是个整数n,就将输入的tensor分为n个子tensor。如果是个tensor T,就将输入的tensor分为len(T)个子tensor。
- axis :A 0-D int32 Tensor;表示分割的尺寸;必须在[-rank(value), rank(value))范围内;默认为0。
- num=None :可选的,用于指定无法从 size_splits 的形状推断出的输出数。
- name :操作的名称(可选)
用途:将张量分割成子张量。
-
如果num_or_size_splits是整数类型,num_split,则value沿维度 axis 分割成为num_split更小的张量。要求num_split均匀分配value.shape[axis]。
-
如果num_or_size_splits不是整数类型,则它被认为是一个张量size_splits,然后将value分割成len(size_splits)块。第i部分的形状与value的大小相同,除了沿维度axis之外的大小size_splits[i]。
代码实例:
import tensorflow as tf
value = [[1, 2, 3], [4, 5, 6]]
split1, split2 = tf.split(value, [0, 2], 0)
split3, split4, split5 = tf.split(value, [1, 0, 2], 1)
with tf.Session() as sess:
print("第一个变换结果:")
print(sess.run(split1))
print("**************")
print(sess.run(split2))
print("**************\r\n")
print("第二个变换结果:")
print(sess.run(split3))
print("**************")
print(sess.run(split4))
print("**************")
print(sess.run(split5))
运行结果:
第一个变换结果:
[]
**************
[[1 2 3]
[4 5 6]]
**************
第二个变换结果:
[[1]
[4]]
**************
[]
**************
[[2 3]
[5 6]]