在深度学习中,迁移学习是一个很普遍的操作,即将一个训练好的网络的一部分迁移到另一个网络,作为另一个网络结构的一部分.但是,我们要怎么保存和迁移呢?今天将以tensorflow的代码为例,给大家一个简单的介绍.
采用的函数是: tf.train.Saver()
1.存储和读取的步骤
(1)存储saver.save(sess, save_dir)
saver = tf.train.Saver()  # 声明 tf.train.Saver() 类用于保存
save_path = saver.save(sess,'save/filename.ckpt')#保存路径为相对路径的save文件夹,保存名为filename.ckpt
存储之后总共有几个后缀的文件:
filename.ckpt.meta:保存tensorflow的网络(计算图)结构
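filename.ckpt:保存tensorflow中每一个变量的值(TensorFlow 0.12 之后默认使用 V2 格式,变量值实际保存在 filename.ckpt.index 和 filename.ckpt.data-00000-of-00001 两个文件中,读取时仍使用前缀 filename.ckpt)
checkpoint:保存一个目录下所有的模型文件列表(以及最近一次保存的模型)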
(2)读取saver.restore()
saver.restore(sess, 'save/filename.ckpt')  # 从保存路径读取
在读取之前,先定义好和原来模型中相同的变量.读取出的结果会直接赋值给这些变量,无需再单独初始化.
(3)直接测试已经训练好的模型
可以通过meta graph构建网络、载入训练时得到的参数,并使用默认的session(注意:tf.get_default_session() 只有在 with tf.Session() as sess: 块内才会返回该 session,否则返回 None):
saver = tf.train.import_meta_graph('save/filename.ckpt-16000.meta')
saver.restore(tf.get_default_session(), 'save/filename.ckpt-16000')
2.代码实现
代码实现我懒得写了,引用一个作者(Traphix)写好的,比较清晰明了: https://www.jianshu.com/p/83fa3aa2d0e9
(1)训练网络
import tensorflow as tf
import sys
# load MNIST data
from tensorflow.examples.tutorials.mnist import input_data
import os

# Load MNIST with one-hot labels (10 classes per row).
mnist = input_data.read_data_sets('data', one_hot=True)

# Hyperparameters
activation = tf.nn.relu
batch_size = 100
iteration = 20000
hidden1_units = 30

# NOTE: this is where the checkpoint is written.
# os.path.join (instead of "+ '/...'") avoids producing '/simple_mnist.ckpt'
# at the filesystem root when sys.path[0] is '' (e.g. interactive sessions);
# os.curdir keeps the path explicitly relative to the working directory.
model_path = os.path.join(sys.path[0] or os.curdir, 'simple_mnist.ckpt')

# Inputs: flattened 28x28 images and one-hot labels.
X = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])

# Two fully connected layers: 784 -> hidden1_units -> 10.
W_fc1 = tf.Variable(tf.truncated_normal([784, hidden1_units], stddev=0.2))
b_fc1 = tf.Variable(tf.zeros([hidden1_units]))
W_fc2 = tf.Variable(tf.truncated_normal([hidden1_units, 10], stddev=0.2))
b_fc2 = tf.Variable(tf.zeros([10]))
def inference(img):
    """Build the forward pass of the two-layer fully connected network.

    Args:
        img: batch of flattened input images (fed from placeholder X).

    Returns:
        Unnormalized class scores (logits) produced by the second layer.
    """
    hidden = activation(tf.nn.bias_add(tf.matmul(img, W_fc1), b_fc1))
    output = tf.nn.bias_add(tf.matmul(hidden, W_fc2), b_fc2)
    return output
def loss(logits, labels):
    """Return the mean softmax cross-entropy between logits and labels.

    Args:
        logits: unnormalized class scores, e.g. the output of inference().
        labels: per-row target distribution; in this script these are the
            one-hot labels from read_data_sets(..., one_hot=True).

    Returns:
        Scalar tensor: the per-row cross-entropy averaged over all rows.
    """
    # Since TF 1.0 this op only accepts keyword arguments; the positional
    # call (logits, labels) raises, and the old positional order differed
    # from the current signature anyway.
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    # Named mean_loss so the local does not shadow the function name.
    mean_loss = tf.reduce_mean(cross_entropy)
    return mean_loss
def evaluation(logits, labels):
    """Build a scalar tensor giving the fraction of correctly classified rows.

    A row counts as correct when the index of its largest logit equals the
    index of its largest label entry.
    """
    predicted = tf.argmax(logits, 1)
    expected = tf.argmax(labels, 1)
    hits = tf.cast(tf.equal(predicted, expected), tf.float32)
    return tf.reduce_mean(hits)
# Build the graph: forward pass, loss, optimizer step and accuracy metric.
logits = inference(X)
# Named loss_op so the loss() function is not rebound to a tensor.
loss_op = loss(logits, y_)
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss_op)
accuracy = evaluation(logits, y_)

# Instantiate the Saver after all variables exist: with no var_list it
# covers the variables present in the graph when it is constructed.
saver = tf.train.Saver()
# global_variables_initializer replaces the deprecated initialize_all_variables.
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for i in range(iteration):
        batch = mnist.train.next_batch(batch_size)
        # Report training accuracy every 1000 steps (skipping step 0).
        if i % 1000 == 0 and i:
            train_accuracy = sess.run(accuracy, feed_dict={X: batch[0], y_: batch[1]})
            print("step %d, train accuracy %g" % (i, train_accuracy))
        sess.run(train_op, feed_dict={X: batch[0], y_: batch[1]})
    print('[+] Test accuracy is %f' % sess.run(accuracy, feed_dict={X: mnist.test.images, y_: mnist.test.labels}))
    # Save the trained variables; save() returns the path actually written.
    save_path = saver.save(sess, model_path)
    print("[+] Model saved in file: %s" % save_path)
(2)测试
import tensorflow as tf
import sys
from tensorflow.examples.tutorials.mnist import input_data
import os

# Load MNIST with one-hot labels; only the test split is used below.
mnist = input_data.read_data_sets('data', one_hot=True)

activation = tf.nn.relu
hidden1_units = 30

# Must point at the same checkpoint the training script wrote.
# os.path.join avoids '/simple_mnist.ckpt' (filesystem root) when
# sys.path[0] is ''; os.curdir keeps the path relative to the working dir.
model_path = os.path.join(sys.path[0] or os.curdir, 'simple_mnist.ckpt')

X = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])

# Same shapes and creation order as in the training script, so the
# auto-generated variable names match those stored in the checkpoint.
# The initial values are overwritten by saver.restore().
W_fc1 = tf.Variable(tf.truncated_normal([784, hidden1_units], stddev=0.2))
b_fc1 = tf.Variable(tf.zeros([hidden1_units]))
W_fc2 = tf.Variable(tf.truncated_normal([hidden1_units, 10], stddev=0.2))
b_fc2 = tf.Variable(tf.zeros([10]))
def inference(img):
    """Build the forward pass of the two-layer fully connected network.

    Args:
        img: batch of flattened input images (fed from placeholder X).

    Returns:
        Unnormalized class scores (logits) produced by the second layer.
    """
    hidden = activation(tf.nn.bias_add(tf.matmul(img, W_fc1), b_fc1))
    output = tf.nn.bias_add(tf.matmul(hidden, W_fc2), b_fc2)
    return output
def evaluation(logits, labels):
    """Build a scalar tensor giving the fraction of correctly classified rows.

    A row counts as correct when the index of its largest logit equals the
    index of its largest label entry.
    """
    predicted = tf.argmax(logits, 1)
    expected = tf.argmax(labels, 1)
    hits = tf.cast(tf.equal(predicted, expected), tf.float32)
    return tf.reduce_mean(hits)
# Rebuild the same inference/evaluation graph that was trained.
logits = inference(X)
accuracy = evaluation(logits, y_)
saver = tf.train.Saver()

with tf.Session() as sess:
    # Load the previously trained variables. Restoring also initializes
    # them, so no initializer is run here. restore() returns None, so the
    # path is reported from model_path rather than from its return value.
    saver.restore(sess, model_path)
    print("[+] Model restored from %s" % model_path)
    print('[+] Test accuracy is %f' % sess.run(accuracy, feed_dict={X: mnist.test.images, y_: mnist.test.labels}))