最近在使用tensorflow的预训练模型,把自己的心得记录下来~
Tensorflow读取并输出已保存模型的权重数值
参考链接
https://blog.csdn.net/AManFromEarth/article/details/81057577
https://blog.csdn.net/aiseu001/article/details/79851176
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.python import pywrap_tensorflow
#首先,使用tensorflow自带的python打包库读取模型
model_reader = pywrap_tensorflow.NewCheckpointReader(r"saver-test")
#然后,使reader变换成类似于dict形式的数据
var_dict = model_reader.get_variable_to_shape_map()
print(len(var_dict))#输出模型中的变量个数
print(var_dict) #输出模型中的变量名称
#提取模型中的某一个变量名称和值
w1 = model_reader.get_tensor("conv1/W") #提取模型中名为conv1/W的变量(conv1的权重参数)
print(type(w1)) #输出变量w1的类型 <class 'numpy.ndarray'>
print(w1.shape) #输出变量w1的形状 (11, 11, 3, 96)
# print(w1) #输出变量w1的值
#循环输出模型中所有参数的名称和值
for key in var_dict:
print("variable name: ", key)
print(model_reader.get_tensor(key))
#如果想要输出到文件,使用:
with open("output.txt","w+") as f:
#循环打印输出
for key in var_dict:
f.write(str(key))
f.write(str(model_reader.get_tensor(key)))
读取预训练模型,有选择性的加载参数
我是对原网络进行了一些修改,添加了一些新的参数,所以想要导入预训练模型来初始化原来就有的那部分参数,新添加的参数采用随机初始化。代码:
reader = tf.train.NewCheckpointReader("output/saver-test")
restore_dict = dict()
for v in tf.trainable_variables(): #只读取当前网络结构中待训练的变量
tensor_name = v.name.split(':')[0] # 把变量后面的:0去掉了(conv1/b:0->conv1/b)
# print(tensor_name) #输出训练变量的名称列表
if reader.has_tensor(tensor_name): #如果预训练模型中含有我们想要加载的参数,就把它添加到待restore的参数字典中
print('has tensor', tensor_name)
restore_dict[tensor_name] = v
saver = tf.train.Saver(restore_dict)#恢复指定的变量字典
with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer()) #对所有参数进行随机初始化
sess.run(tf.local_variables_initializer())
saver.restore(sess, "output/saver-test") # restore指定的变量
# 获取当前网络结构中的conv1/b:0变量,查看其值,b的值就是预训练模型中保存的conv1/b的取值,说明restore成功
b=tf.get_default_graph().get_tensor_by_name("conv1/b:0")
print(sess.run(b))
从提取上一次训练结果继续训练
与第二种情况的区别就是,模型结构都是一样的,这个恢复起来更简单,直接这样操作即可,连参数初始化都不用~~~
saver = tf.train.Saver(max_to_keep=100) #max_to_keep参数是保存ckpt文件的个数,默认是5
with tf.Session(config=config) as sess:
saver.restore(sess, "output/saver")
分享一个链接:
https://blog.csdn.net/qq_25737169/article/details/78125061
https://stackoverflow.com/questions/52532150/how-to-restore-pretrained-model-to-initialize-parameters