官方的输入定义如下:
def moments(x, axes, name=None, keep_dims=False)
解释如下:
x 可以理解为我们输出的数据,形如 [batchsize, height, width, kernels]
axes 表示在哪个维度上求解,是个list,例如 [0, 1, 2]
name 就是个名字,不多解释
keep_dims 是否保持维度,不多解释
这个函数的输出有两个,用官方的话说就是:
Two Tensor objects: mean andvariance.
解释如下:
mean 就是均值啦
variance 就是方差啦
个人试验后总结理解就是将x上除去axes所指定的纬度的剩余纬度组成的各个子元素看做个体,个体中的每个位置的值看做个体的不同位置属性,然后求所有个体在每种位置属性上的均值和方差
例子如下
eg1:
代码:
# coding: utf-8 import tensorflow as tf img = tf.Variable(tf.random_normal([128, 4, 2, 3])) axis = [0,1,2]#所以剩余的是四纬看做一个整体shape为[3] mean, variance = tf.nn.moments(img, axis) init = tf.global_variables_initializer() with tf.Session() as sess: sess.run(init) mean_, variance_ = sess.run([mean, variance]) print("均值:",mean_) print("方差:",variance_)
结果:
均值: [-0.03587267 0.06021447 0.02401767]
方差: [0.99473494 0.93040663 0.98113006]
eg2:
代码:
# coding: utf-8 import tensorflow as tf img = tf.Variable(tf.random_normal([128, 4, 2, 3])) axis = [0,1]#所以剩余的是第三 四纬看做一个整体shape为[2, 3] mean, variance = tf.nn.moments(img, axis) init = tf.global_variables_initializer() with tf.Session() as sess: sess.run(init) mean_, variance_ = sess.run([mean, variance]) print("均值:",mean_) print("方差:",variance_)
结果:
均值: [[-0.04313184 0.01417894 0.06847101]
[ 0.04183875 -0.01508999 -0.11406976]]
方差: [[0.976376 0.91841435 1.0207324 ]
[1.0403597 0.9773739 1.0360421 ]]