目录
1.变量管理
import tensorflow as tf
v=tf.get_variable("v",shape=[1],initializer=tf.constant_initializer(1.0))
print(v)
v=tf.Variable(tf.constant(1.0,shape=[1]),name="v")
print(v)
tf.get_variable()必须指定变量名称,从名称的位置也可以看出,变量名称的重要性。该函数通过该名称去创建或获取变量。
通过tf.variable_scope()创建上下文管理器,这时可以通过tf.get_variable()获取已有变量
with tf.variable_scope("foo"):
v=tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0))
with tf.variable_scope("foo",reuse=True): #设置reuse=True,get_variable可以直接获取已经生命的变量
v1=tf.get_variable("v",[1])
print(v==v1)
嵌套上下文管理器:
a.不指定reuse为True,取值与外层一样
b.推出reuse设置为True的上下文后,reuse又回到False
使用上下文管理器后,就不需要重复传入变量
使用了上下文管理器
def inference(input_tensor,reuse=False):
with tf.variable_scope("layer1",reuse=reuse):
#根据传进来的reuse来判断是使用新的变量还是已有的,
#第一次使用创建新的变量,以后都直接使用reuse=True就不需要每次都将变量传进来了
weights=tf.get_variable("weights",[INPUT_NODE,LAYER1_NODE],
initializer=tf.truncated_normal_initializer(stddev=0.1))
biases=tf.get_variable("biaes",[LAYER1_NODE],
initializer=tf.constant_initializer(0.0))
layer1=tf.nn.relu(tf.matmul(input_tensor,weights)+biaes)
with tf.variable_scope("layer2",reuse=reuse):
#根据传进来的reuse来判断是使用新的变量还是已有的,
#第一次使用创建新的变量,以后都直接使用reuse=True就不需要每次都将变量传进来了
weights=tf.get_variable("weights",[LAYER1_NODE,OUTPUT_NODE],
initializer=tf.truncated_normal_initializer(stddev=0.1))
biases=tf.get_variable("biaes",[OUTPUT_NODE],
initializer=tf.constant_initializer(0.0))
layer2=tf.nn.relu(tf.matmul(input_tensor,weights)+biaes)
return layer2
没有使用上下文管理器
def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2):
# 不使用滑动平均类
if avg_class == None:
layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
return tf.matmul(layer1, weights2) + biases2
else:
# 使用滑动平均类
layer1 = tf.nn.relu(tf.matmul(input_tensor, avg_class.average(weights1)) + avg_class.average(biases1))
return tf.matmul(layer1, avg_class.average(weights2)) + avg_class.average(biases2)
2.模型持久化
tensorflow通过tf.train.Saver类让训练结果可以重复使用
#定义
saver=tf.train.Saver()
saver.save(sess,"/path/to/model/model.ckpt")#此时保存是将计算图和图上的参数取值分开保存的,会有三个文件
#加载,带重复定义
saver=tf.train.Saver()
saver.resore(sess,"/path/to/model/model.ckpt")
#直接加载已经持久化的图
import tensorflow as tf
saver=tf.train.import_meta_graph(
"/path/to/model/model.ckpt/model.ckpt.meta")
with tf.Session() as sess:
saver.resore(sess,"/path/to/model/model/model.ckpt")
print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0"))
#通过字典重命名变量
saver=tf.train.Saver({"v1":v1,"v2":v2})
#加载重命名滑动平均变量,tf.train.ExponentialMovingAverage中的variables_to_restore
具体的持久化原理和数据格式参考chapter 5.4.2