The network architecture is as follows:
Parameters: a1, a2, b1, b2
Network output: y = a1*a2*x + b1 + b2
Target function: y = x
1. Restoring only part of the network's parameters
Use case: the network architecture has changed, but some of the existing parameters should be reused.
How to set it up: pass a var_list argument to tf.train.Saver; the saver will then save/restore only the variables in var_list.
Usage:
(1) Save: a1, a2, b1, b2 are initialized to 10, 20, 30, 40 respectively
import tensorflow as tf
import random

# Target function: y = x
# i.e. at convergence: a1*a2 = 1 and b1+b2 = 0
x = tf.placeholder(tf.float32, [1])
with tf.variable_scope("AB1"):
    a1 = tf.Variable(tf.constant([10], dtype=tf.float32), name="A1")
    b1 = tf.Variable(tf.constant([30], dtype=tf.float32), name="B1")
with tf.variable_scope("AB2"):
    a2 = tf.Variable(tf.constant([20], dtype=tf.float32), name="A2")
    b2 = tf.Variable(tf.constant([40], dtype=tf.float32), name="B2")
y = a1 * a2 * x + b1 + b2

_y = tf.placeholder(tf.float32, [1])
loss = tf.square(y - _y)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

var_list_ab1 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='AB1')
var_list_ab2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='AB2')

saver = tf.train.Saver()  # no var_list: all variables are saved
saver.save(sess, "./10203040")
(2) Restore:
import tensorflow as tf
import random

# Target function: y = x
# i.e. at convergence: a1*a2 = 1 and b1+b2 = 0
x = tf.placeholder(tf.float32, [1])
with tf.variable_scope("AB1"):
    a1 = tf.Variable(tf.constant([1], dtype=tf.float32), name="A1")
    b1 = tf.Variable(tf.constant([3], dtype=tf.float32), name="B1")
with tf.variable_scope("AB2"):
    a2 = tf.Variable(tf.constant([2], dtype=tf.float32), name="A2")
    b2 = tf.Variable(tf.constant([4], dtype=tf.float32), name="B2")
y = a1 * a2 * x + b1 + b2

_y = tf.placeholder(tf.float32, [1])
loss = tf.square(y - _y)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

var_list_ab1 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='AB1')
var_list_ab2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='AB2')

saver = tf.train.Saver(var_list_ab1)  # restore only the variables in scope AB1
saver.restore(sess, "./10203040")
print(sess.run([a1, a2, b1, b2]))  # a1, b1 come from the checkpoint; a2, b2 keep their new initial values
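Because only var_list_ab1 is handed to the Saver, the restore overwrites just a1 and b1 with the checkpoint values (10 and 30), while a2 and b2 keep the fresh initial values 2 and 4. The effect can be sketched without TensorFlow as a dictionary merge, where the "checkpoint" is a plain dict and only the names in var_list are overwritten (a toy illustration of the semantics, not the real checkpoint format):

```python
# Toy model of a partial restore: a "checkpoint" is a dict of saved values,
# and restoring with a var_list overwrites only those entries.
checkpoint = {"AB1/A1": 10.0, "AB1/B1": 30.0, "AB2/A2": 20.0, "AB2/B2": 40.0}

# Freshly initialized variables, as in the restore script above.
variables = {"AB1/A1": 1.0, "AB1/B1": 3.0, "AB2/A2": 2.0, "AB2/B2": 4.0}

def partial_restore(variables, checkpoint, var_list):
    """Overwrite only the variables named in var_list with checkpoint values."""
    for name in var_list:
        variables[name] = checkpoint[name]
    return variables

# Restore only the AB1 scope, mirroring tf.train.Saver(var_list_ab1).
partial_restore(variables, checkpoint, ["AB1/A1", "AB1/B1"])
print(variables)  # a1=10, b1=30 restored; a2=2, b2=4 untouched
```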
2. Training only part of the parameters, e.g. train only a1 and b1 while a2 and b2 stay fixed
How to set it up: pass the list of variables to be trained to the optimizer via its var_list argument:
Method 1 (collect var_list from a variable_scope):
var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='AB1')
train = tf.train.GradientDescentOptimizer(1e-1).minimize(loss, var_list=var_list)
Method 2 (the clumsier way, listing the variables by hand):
import tensorflow as tf
import random

# Target function: y = x
# i.e. at convergence: a1*a2 = 1 and b1+b2 = 0
x = tf.placeholder(tf.float32, [1])
with tf.variable_scope("AB1"):
    a1 = tf.Variable(tf.constant([1], dtype=tf.float32), name="A1")
    b1 = tf.Variable(tf.constant([3], dtype=tf.float32), name="B1")
with tf.variable_scope("AB2"):
    a2 = tf.Variable(tf.constant([2], dtype=tf.float32), name="A2")
    b2 = tf.Variable(tf.constant([4], dtype=tf.float32), name="B2")
y = a1 * a2 * x + b1 + b2

_y = tf.placeholder(tf.float32, [1])
loss = tf.square(y - _y)

# var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='AB1')
var_list = [a1, b1]  # !!! hand-picked var_list: only a1 and b1 are trained
train = tf.train.GradientDescentOptimizer(1e-1).minimize(loss, var_list=var_list)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

while True:
    input = [random.randint(0, 100) * 0.01]  # without scaling the input down, the network fails to converge
    label = [input[0]]  # the label equals the input, since the target is y = x
    _, a1v, a2v, b1v, b2v, lossv = sess.run([train, a1, a2, b1, b2, loss],
                                            feed_dict={x: input, _y: label})
    if lossv < 1e-10:
        break
print("train data=%s %s" % (input, label))
print("a=%s %s\n b=%s %s\n loss=%s" % (a1v, a2v, b1v, b2v, lossv))
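What minimize(loss, var_list=[a1, b1]) does can be checked with plain Python: differentiate loss = (a1*a2*x + b1 + b2 - x)^2 by hand and apply gradient steps to a1 and b1 only, leaving a2 and b2 untouched (a sketch of the underlying math, independent of TensorFlow; the fixed iteration count and seed are choices for reproducibility, not from the original):

```python
import random

# Hand-rolled "train only a1, b1": gradients of
# loss = (a1*a2*x + b1 + b2 - x)**2 are applied to a1 and b1 only,
# which is exactly what minimize(loss, var_list=[a1, b1]) does.
random.seed(0)
a1, b1 = 1.0, 3.0   # trainable parameters
a2, b2 = 2.0, 4.0   # frozen parameters (excluded from var_list)
lr = 1e-1

for _ in range(2000):
    x = random.randint(0, 100) * 0.01     # scaled random input, target y = x
    err = a1 * a2 * x + b1 + b2 - x       # loss = err**2
    a1 -= lr * 2 * err * a2 * x           # d(loss)/d(a1) = 2*err*a2*x
    b1 -= lr * 2 * err                    # d(loss)/d(b1) = 2*err

print(a1 * a2, b1 + b2, a2, b2)  # a1*a2 -> 1, b1+b2 -> 0; a2, b2 unchanged
```

At convergence a1*a2 approaches 1 and b1+b2 approaches 0, while a2 and b2 still hold their initial values, confirming that freezing parameters is just a matter of excluding them from the gradient update.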