1、tf.control_dependencies
首先我们先介绍tf.control_dependencies
,该函数保证其辖域中的操作必须要在该函数所传递的参数中的操作完成后再进行。请看下面一个例子。
import tensorflow as tf
a_1 = tf.Variable(1)
b_1 = tf.Variable(2)
update_op = tf.assign(a_1, 10)
add = tf.add(a_1, b_1)
a_2 = tf.Variable(1)
b_2 = tf.Variable(2)
update_op = tf.assign(a_2, 10)
with tf.control_dependencies([update_op]):
add_with_dependencies = tf.add(a_2, b_2)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
ans_1, ans_2 = sess.run([add, add_with_dependencies])
print("Add: ", ans_1)
print("Add_with_dependency: ", ans_2)
输出:
Add: 3
Add_with_dependency: 12
可以看到两组加法进行的对比,正常的计算图在计算add时是不会经过update_op操作的,因此在加法时a的值为1,但是采用tf.control_dependencies函数,可以控制在进行add前先完成update_op的操作,因此在加法时a的值为10,因此最后两种加法的结果不同。
2、tf.GraphKeys.UPDATE_OPS
关于tf.GraphKeys.UPDATE_OPS,这是一个tensorflow的计算图中内置的一个集合,其中会保存一些需要在训练操作之前完成的操作,并配合tf.control_dependencies函数使用。
关于在batch_norm中,即为更新mean和variance的操作。通过下面一个例子可以看到tf.layers.batch_normalization中是如何实现的。
import tensorflow as tf
is_traing = tf.placeholder(dtype=tf.bool)
input = tf.ones([1, 2, 2, 3])
output = tf.layers.batch_normalization(input, training=is_traing)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
print(update_ops)
# with tf.control_dependencies(update_ops):
# train_op = optimizer.minimize(loss)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.save(sess, "batch_norm_layer/Model")
输出:
[<tf.Tensor 'batch_normalization/AssignMovingAvg:0' shape=(3,) dtype=float32_ref>, <tf.Tensor 'batch_normalization/AssignMovingAvg_1:0' shape=(3,) dtype=float32_ref>]
可以看到输出的即为两个batch_normalization中更新mean和variance的操作,需要保证它们在train_op前完成。
这两个操作是在tensorflow的内部实现中自动被加入tf.GraphKeys.UPDATE_OPS这个集合的,在tf.contrib.layers.batch_norm的参数中可以看到有一项updates_collections的默认值即为tf.GraphKeys.UPDATE_OPS,而在tf.layers.batch_normalization中则是直接将两个更新操作放入了上
三、关于最初的错误使用的思考
最后我对于一开始的使用方法为什么会导致错误进行了思考,tensorflow中具体实现batch_normalization的代码在tensorflow\python\layers\normalization.py
中,下面展示一些关键代码。
if self.scale:
self.gamma = self.add_variable(
name='gamma',
shape=param_shape,
dtype=param_dtype,
initializer=self.gamma_initializer,
regularizer=self.gamma_regularizer,
constraint=self.gamma_constraint,
trainable=True)
else:
self.gamma = None
if self.center:
self.beta = self.add_variable(
name='beta',
shape=param_shape,
dtype=param_dtype,
initializer=self.beta_initializer,
regularizer=self.beta_regularizer,
constraint=self.beta_constraint,
trainable=True)
else:
self.beta = None
scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
self.moving_mean = self._add_tower_local_variable(
name='moving_mean',
shape=param_shape,
dtype=param_dtype,
initializer=self.moving_mean_initializer,
trainable=False)
self.moving_variance = self._add_tower_local_variable(
name='moving_variance',
shape=param_shape,
dtype=param_dtype,
initializer=self.moving_variance_initializer,
trainable=False)
def _assign_moving_average(self, variable, value, momentum):
with ops.name_scope(None, 'AssignMovingAvg', [variable, value, momentum]) as scope:
decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
if decay.dtype != variable.dtype.base_dtype:
decay = math_ops.cast(decay, variable.dtype.base_dtype)
update_delta = (variable - value) * decay
return state_ops.assign_sub(variable, update_delta, name=scope)
def _do_update(var, value):
return self._assign_moving_average(var, value, self.momentum)
# Determine a boolean value for `training`: could be True, False, or None.
training_value = utils.constant_value(training)
if training_value is not False:
mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
moving_mean = self.moving_mean
moving_variance = self.moving_variance
mean = utils.smart_cond(training,
lambda: mean,
lambda: moving_mean)
variance = utils.smart_cond(training,
lambda: variance,
lambda: moving_variance)
else:
new_mean, new_variance = mean, variance
mean_update = utils.smart_cond(
training,
lambda: _do_update(self.moving_mean, new_mean),
lambda: self.moving_mean)
variance_update = utils.smart_cond(
training,
lambda: _do_update(self.moving_variance, new_variance),
lambda: self.moving_variance)
if not context.executing_eagerly():
self.add_update(mean_update, inputs=inputs)
self.add_update(variance_update, inputs=inputs)
outputs = nn.batch_normalization(inputs,
_broadcast(mean),
_broadcast(variance),
offset,
scale,
self.epsilon)
可以看到其内部逻辑和我在介绍tf.nn.batch_normalization一节中展示的封装时所使用的方法类似。
如果不在使用时添加tf.control_dependencies函数,即在训练时(training=True)每批次时只会计算当批次的mean和var,并传递给tf.nn.batch_normalization进行归一化,由于mean_update和variance_update在计算图中并不在上述操作的依赖路径上,因为并不会主动完成,也就是说,在训练时mean_update和variance_update并不会被使用到,其值一直是初始值。因此在测试阶段(training=False)使用这两个作为mean和variance并进行归一化操作,这样就会出现错误。而如果使用tf.control_dependencies函数,会在训练阶段每次训练操作执行前被动地去执行mean_update和variance_update,因此moving_mean和moving_variance会被不断更新,在测试时使用该参数也就不会出现错误。