L2-Regularization 实现的话,需要把所有的参数放在一个集合内,最后计算loss时,再减去加权值。
相比自己乱搞,代码一团糟,Tensorflow 提供了更优美的实现方法。
一、tf.GraphKeys : 多个包含Variables(Tensor)集合
(1)GLOBAL_VARIABLES:使用tf.get_variable()时,默认会将vairable放入这个集合。
我们熟悉的tf.global_variables_initializer()就是初始化这个集合内的Variables。
import tensorflow as tf sess=tf.Session() a=tf.get_variable("a",[3,3,32,64],initializer=tf.random_normal_initializer()) b=tf.get_variable("b",[64],initializer=tf.random_normal_initializer()) #collections=None等价于 collection=[tf.GraphKeys.GLOBAL_VARIABLES] gv= tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) #tf.get_collection(collection_name)返回某个collection的列表 for var in gv: print(var is a) print(var.get_shape())
Tips: tf.GraphKeys.GLOBAL_VARIABLES == "variable"。即其保存的是一个字符串。
(2)自定义集合
想个集合的名字,然后在tf.get_variable时,把集合名字传给 collection 就好了。
import tensorflow as tf
sess=tf.Session()
a=tf.get_variable("a",shape=[10],collections=["mycollection"]) #不把GLOBAL_VARIABLES加进去,那么就不在那个集合里了。
keys=tf.get_collection("mycollection")
for key in keys:
print(key.name)
二、L2正则化
先看看tf.contrib.layers.l2_regularizer(weight_decay)都执行了什么:
import tensorflow as tf sess=tf.Session() weight_decay=0.1 tmp=tf.constant([0,1,2,3],dtype=tf.float32) """ l2_reg=tf.contrib.layers.l2_regularizer(weight_decay) a=tf.get_variable("I_am_a",regularizer=l2_reg,initializer=tmp) """ #**上面代码的等价代码 a=tf.get_variable("I_am_a",initializer=tmp) a2=tf.reduce_sum(a*a)*weight_decay/2; a3=tf.get_variable(a.name.split(":")[0]+"/Regularizer/l2_regularizer",initializer=a2) tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,a2) #** sess.run(tf.global_variables_initializer()) keys = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) for key in keys: print("%s : %s" %(key.name,sess.run(key)))
我们很容易可以模拟出tf.contrib.layers.l2_regularizer都做了什么,不过会让代码变丑。
以下比较完整实现L2 正则化。
import tensorflow as tf sess=tf.Session() weight_decay=0.1 #(1)定义weight_decay l2_reg=tf.contrib.layers.l2_regularizer(weight_decay) #(2)定义l2_regularizer() tmp=tf.constant([0,1,2,3],dtype=tf.float32) a=tf.get_variable("I_am_a",regularizer=l2_reg,initializer=tmp) #(3)创建variable,l2_regularizer复制给regularizer参数。 #目测REXXX_LOSSES集合 #regularizer定义会将a加入REGULARIZATION_LOSSES集合 print("Global Set:") keys = tf.get_collection("variables") for key in keys: print(key.name) print("Regular Set:") keys = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) for key in keys: print(key.name) print("--------------------") sess.run(tf.global_variables_initializer()) print(sess.run(a)) reg_set=tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) #(4)则REGULARIAZTION_LOSSES集合会包含所有被weight_decay后的参数和,将其相加 l2_loss=tf.add_n(reg_set) print("loss=%s" %(sess.run(l2_loss))) """ 此处输出0.7,即: weight_decay*sigmal(w*2)/2=0.1*(0*0+1*1+2*2+3*3)/2=0.7 其实代码自己写也很方便,用API看着比较正规。 在网络模型中,直接将l2_loss加入loss就好了。(loss变大,执行train自然会decay) """