1.tf.concat()
tensorflow中用来拼接张量的函数tf.concat(),用法:
tf.concat([tensor1, tensor2, tensor3,...], axis)
先给出tf源代码中的解释:
t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 0) # [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 1) # [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
# tensor t3 with shape [2, 3]
# tensor t4 with shape [2, 3]
tf.shape(tf.concat([t3, t4], 0)) # [4, 3]
tf.shape(tf.concat([t3, t4], 1)) # [2, 6]
这里解释了当axis=0和axis=1的情况,怎么理解这个axis呢?其实这和numpy中的np.concatenate()用法是一样的。
axis=0 代表在第0个维度拼接
axis=1 代表在第1个维度拼接
对于一个二维矩阵,第0个维度代表最外层方括号所框下的子集,第1个维度代表内部方括号所框下的子集。维度越高,括号越小。
对于这种情况,我可以再解释清楚一点:
对于[ [ ], [ ]]和[[ ], [ ]],低维拼接等于拿掉最外面括号,高维拼接是拿掉里面的括号(保证其他维度不变)。注意:tf.concat()拼接的张量只会改变一个维度,其他维度是保存不变的。比如两个shape为[2,3]的矩阵拼接,要么通过axis=0变成[4,3],要么通过axis=1变成[2,6]。改变的维度索引对应axis的值。
这样就可以理解多维矩阵的拼接了,可以用axis的设置来从不同维度进行拼接。
对于三维矩阵的拼接,自然axis取值范围是[0, 1, 2]。
对于axis等于负数的情况
负数在数组索引里面表示倒数(countdown)。比如,对于列表ls = [1,2,3]而言,ls[-1] = 3,表示读取倒数第一个索引对应值。
axis=-1表示倒数第一个维度,对于三维矩阵拼接来说,axis=-1等价于axis=2。同理,axis=-2代表倒数第二个维度,对于三维矩阵拼接来说,axis=-2等价于axis=1。
一般在维度非常高的情况下,我们想在最'高'的维度进行拼接,一般就直接用countdown机制,直接axis=-1就搞定了。
2.tf.reduce_mean()
tf.reduce_mean 函数用于计算张量 tensor 沿着指定的数轴( tensor 的某一维度)上的的平均值,主要用作降维或者计算 tensor(图像)的平均值。
reduce_mean(input_tensor,
axis=None,
keepdims=False,
name=None,
reduction_indices=None)
第一个参数 input_tensor: 输入的待降维的 tensor;
第二个参数 axis: 指定的轴,如果不指定,则计算所有元素的均值;
第三个参数 keepdims:是否降维度,设置为 True,输出的结果保持输入 tensor 的形状,设置为 False,输出结果会降低维度;
第四个参数 name: 操作的名称;
第五个参数 reduction_indices:在以前版本中用来指定轴,已弃用;
举例:
import tensorflow as tf
x = [[1,2,3],[1,2,3]]
xx = tf.cast(x,tf.float32)
mean_all = tf.reduce_mean(xx, keep_dims=False)
mean_0 = tf.reduce_mean(xx, axis=0, keep_dims=False)
mean_1 = tf.reduce_mean(xx, axis=1, keep_dims=False)
with tf.Session() as sess:
m_a,m_0,m_1 = sess.run([mean_all, mean_0, mean_1])
print m_a # output: 2.0
print m_0 # output: [ 1. 2. 3.]
print m_1 #output: [ 2. 2.]
相似的还有:
tf.reduce_sum
:计算 tensor 指定轴方向上的所有元素的累加和;tf.reduce_max
: 计算 tensor 指定轴方向上的各个元素的最大值;tf.reduce_all
: 计算 tensor 指定轴方向上的各个元素的逻辑和(and 运算);tf.reduce_any
: 计算 tensor 指定轴方向上的各个元素的逻辑或(or 运算);
3.tf.cond()
类似于c语言中的if...else...,用来控制数据流向
例子:
x = tf.constant(2 )
y = tf.constant(5 )
def f1 (): return tf .multiply( x , 17 )
def f2 (): return tf .add ( y , 23 )
r = tf .cond( tf.less( X ,y ), f1 , f2 )
#r 设置为f1().
#f2 中的操作(例如,tf.add)不执行.
4.tf.argmax(input,axis)
根据axis取值的不同返回每行或者每列最大值的索引。 这个很好理解,只是tf.argmax()的参数让人有些迷惑,比如,tf.argmax(array, 1)和tf.argmax(array, 0)有啥区别呢? 这里面就涉及到一个概念:axis。上面例子中的1和0就是axis。我先笼统的解释这个问题,设置axis的主要原因是方便我们进行多个维度的计算。
比如:
test = np.array([[1, 2, 3], [2, 3, 4], [5, 4, 3], [8, 7, 2]])
np.argmax(test, 0) #输出:array([3, 3, 1]
np.argmax(test, 1) #输出:array([2, 2, 0, 0]123
axis = 0:
axis=0时比较每一列的元素,将每一列最大元素所在的索引记录下来,最后输出每一列最大元素所在的索引数组。
test[0] = array([1, 2, 3])
test[1] = array([2, 3, 4])
test[2] = array([5, 4, 3])
test[3] = array([8, 7, 2])
# output : [3, 3, 1]
axis = 1:
axis=1的时候,将每一行最大元素所在的索引记录下来,最后返回每一行最大元素所在的索引数组。
test[0] = array([1, 2, 3]) #2
test[1] = array([2, 3, 4]) #2
test[2] = array([5, 4, 3]) #0
test[3] = array([8, 7, 2]) #0
这是里面都是数组长度一致的情况,如果不一致,axis最大值为最小的数组长度-1,超过则报错。
当不一致的时候,axis=0的比较也就变成了每个数组的和的比较。
5.tf.transpose()
x = [[1,3,5], [2,4,6]] 二维数组为2行3列的矩阵
tf.transpose(x, perm=[1,
0]),perm[1,0]代表将数组的行和列进行交换,代表矩阵的转置,转置之后为3行2列
结果为:
[
[1,2],
[3,4],
[5,6]
]
input_data = tf.constant([[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]]) 此3维数组为1x4x3,可以看成是一个4x3的二维数组
output_data = tf.transpose(input_data, perm=[1, 2, 0])
print(sess.run(output_data))
# [[[ 1]
# [ 2]
# [ 3]]
# [[ 4]
# [ 5]
# [ 6]]
#
# [[ 7]
# [ 8]
# [ 9]]
#
# [[10]
# [11]
# [12]]]
6.tf.reshape()
-1相当于一个未知数,可以把它当做x,即tf.reshape(a,(3,x))
程序内部会自动计算x的值,并自动变
7.tf.sequence_mask()
sequence_mask(
lengths,
maxlen=None,
dtype=tf.bool,
name=None
)
返回一个表示每个单元的前N个位置的mask张量
函数参数
- lengths:整数张量,其所有值小于等于maxlen。
- maxlen:标量整数张量,返回张量的最后维度的大小;默认值是lengths中的最大值。
- dtype:结果张量的输出类型。
- name:操作的名字。
示例:
8. tf.shape()
将矩阵的维度输出为一个维度矩阵
import tensorflow as tf
import numpy as np
A = np.array([[[1, 1, 1],
[2, 2, 2]],
[[3, 3, 3],
[4, 4, 4]],
[[5, 5, 5],
[6, 6, 6]]])
t = tf.shape(A)
with tf.Session() as sess:
print(sess.run(t))
# 输出
[3 2 3]