首先获取 checkpoint 的状态，创建保存平均结果的输出目录，并列出 checkpoint 中每个参数的名称和形状：
# Locate the checkpoints to average and prepare the output directory.
# NOTE(review): `model_dir` must be defined before this snippet runs.
ckpt_state = tf.train.get_checkpoint_state(model_dir)
# get_checkpoint_state returns None when model_dir holds no checkpoint index;
# fail with a clear message rather than an AttributeError further down.
if ckpt_state is None or not ckpt_state.all_model_checkpoint_paths:
    raise ValueError("No checkpoints found in %s" % model_dir)
ckpts = list(ckpt_state.all_model_checkpoint_paths)
avg_model_dir = os.path.join(model_dir, "avg_ckpts")
# Create the output directory (and any missing parents) before saving into it.
tf.gfile.MakeDirs(avg_model_dir)
# (name, shape) pairs for every variable stored in the first checkpoint; the
# other checkpoints are assumed to contain the same variables.
var_list = tf.contrib.framework.list_variables(ckpts[0])
然后依次读取每个 checkpoint，对其中的每个参数求平均：
# Sum every variable across all checkpoints, then divide by the number of
# checkpoints. The running sum is kept in float64 to limit rounding error.
var_values, var_dtypes = {}, {}
for name, shape in var_list:
    var_values[name] = np.zeros(shape, dtype=np.float64)
for ckpt in ckpts:
    reader = tf.contrib.framework.load_checkpoint(ckpt)
    for name in var_values:
        tensor = reader.get_tensor(name)
        # Remember the stored dtype so the average can be written back as-is.
        var_dtypes[name] = tensor.dtype
        var_values[name] += tensor
for name in var_values:
    # Cast the mean back to the checkpoint's original dtype; integer-typed
    # variables, if any, are truncated after averaging.
    var_values[name] = (var_values[name] / len(ckpts)).astype(var_dtypes[name])
接下来将平均后的参数保存到一个新的 checkpoint 中：
# Recreate every averaged variable in the default graph and save them to a new
# checkpoint. Values are fed through placeholders instead of being used as
# `initializer=` constants: constants would be embedded in the GraphDef (and
# in the .meta file written by saver.save), duplicating all weights and
# risking protobuf's 2GB limit on large models.
names = list(var_values)
tf_vars = [
    tf.get_variable(name, shape=var_values[name].shape, dtype=var_dtypes[name])
    for name in names
]
placeholders = [tf.placeholder(v.dtype.base_dtype, shape=v.shape) for v in tf_vars]
assign_ops = [tf.assign(v, p) for v, p in zip(tf_vars, placeholders)]
saver = tf.train.Saver(tf.global_variables())
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Overwrite the default initial values with the averaged ones.
    for name, op, ph in zip(names, assign_ops, placeholders):
        sess.run(op, feed_dict={ph: var_values[name].astype(var_dtypes[name])})
    saver.save(sess, os.path.join(avg_model_dir, "qe.ckpt"))