[ MOOC课程学习 ] 人工智能实践:Tensorflow笔记_CH4_3 滑动平均

滑动平均

  1. 滑动平均:记录了一段时间内模型中所有参数 w 和 b 各自的平均值。利用滑动平均值可以增强模型的泛化能力。
  2. 滑动平均值(影子)计算公式:
    影子 = 衰减率 * 影子 +(1 - 衰减率)* 参数
    影子初值 = 参数初值
    衰减率为:min(decay, (1 + num_updates) / (10 + num_updates))
  3. 用 Tesnsorflow 函数表示为:

    ema = tf.train.ExponentialMovingAverage(
        decay = EMA_DECAY, # 滑动平均衰减率,一般会赋接近 1 的值
        num_updates = global_step # 表示当前训练了多少轮
    )
    
    ema_op = ema.apply(tf.trainable_variables()) # 对括号内参数求滑动平均
    
    with tf.control_dependencies([train_step, ema_op]): #将滑动平均和训练过程同步运行
        train_op = tf.no_op(name='train')
    
    ema.average(var_name) # 查看模型中参数的平均值
  4. 例子:

    import tensorflow as tf
    
    w1 = tf.Variable(0, dtype=tf.float32)
    global_step = tf.Variable(0, trainable=False)
    
    EMA_DECAY = 0.99
    ema = tf.train.ExponentialMovingAverage(
        decay = EMA_DECAY,
        num_updates = global_step
    )
    
    ema_op = ema.apply(tf.trainable_variables())
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print('1.', sess.run([w1, ema.average(w1)]))
        # 1. [0.0, 0.0]
    
        sess.run(tf.assign(w1, 1))
        sess.run(ema_op)
        print('2.', sess.run([w1, ema.average(w1)]))
        # (0.1) * 0 + (1 - 0.1) * 1
        # 2. [1.0, 0.9]
    
    
        sess.run(tf.assign(global_step, 100))
        sess.run(tf.assign(w1, 10))
        sess.run(ema_op)
        print('3.', sess.run([w1, ema.average(w1)]))
        # (101/110) * 0.9 + (1 - 101/110) * 10
        # 3. [10.0, 1.6445453]
    
        sess.run(ema_op)
        print('4.', sess.run([w1, ema.average(w1)]))
        # (101/110) * 1.6445453 + (1 - 101/110) * 10
        # 4. [10.0, 2.3281732]
    
        sess.run(ema_op)
        print('5.', sess.run([w1, ema.average(w1)]))
        # (101/110) * 2.3281732 + (1 - 101/110) * 10
        # 5. [10.0, 2.955868]
    
        sess.run(ema_op)
        print('6.', sess.run([w1, ema.average(w1)]))
        # (101/110) * 2.955868 + (1 - 101/110) * 10
        # 6. [10.0, 3.532206]

猜你喜欢

转载自blog.csdn.net/ranmw1129/article/details/81088686