1. 保存模型:
# -*- coding: UTF-8 -*-
# Build a tiny TensorFlow 1.x graph with two variables and save it to a checkpoint.
import os
# Quiet TensorFlow's C++ logging. Set this before importing TensorFlow so it is
# already in effect for the earliest log messages.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf

checkpoint_dir = './checkpoint_dir'  # directory the model is saved into

W1 = tf.Variable(tf.random_normal([2]), name='w1')  # random values, shape [2]
W2 = tf.Variable(tf.constant(1, shape=[2, 2]), name='w2')  # int32 ones, shape [2, 2]
saver = tf.train.Saver()

# Make sure the target directory exists. Depending on the TF version,
# Saver.save may refuse to write when the parent directory is missing.
if not os.path.isdir(checkpoint_dir):
    os.makedirs(checkpoint_dir)

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    # 'mymodel' is the model name, i.e. the file prefix inside checkpoint_dir.
    saver.save(session, os.path.join(checkpoint_dir, 'mymodel'))
2. 加载模型:
# -*- coding: UTF-8 -*-
# Restore the checkpoint written by the save-model script and print both variables.
import os
# Quiet TensorFlow's C++ logging. Set this before importing TensorFlow so it is
# already in effect for the earliest log messages.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf

checkpoint_dir = './checkpoint_dir/'

# Rebuild the saved graph structure from the .meta file; the returned Saver
# is bound to the imported variables.
saver = tf.train.import_meta_graph(os.path.join(checkpoint_dir, 'mymodel.meta'))
with tf.Session() as session:
    # Load the most recent saved variable values into this session.
    saver.restore(session, tf.train.latest_checkpoint(checkpoint_dir))
    g = tf.get_default_graph()
    # Tensors are looked up as '<op name>:<output index>'.
    w1 = g.get_tensor_by_name('w1:0')
    w2 = g.get_tensor_by_name('w2:0')
    print(session.run(w1))
    print(session.run(w2))
# Example output (w1 was randomly initialized, so its values vary per save):
# [0.01617053 1.2160776 ]
# [[1 1]
#  [1 1]]
3. practice-保存模型:
# -*- coding: UTF-8 -*-
# Practice: build w4 = (w1 + w2) * bias, evaluate it once, then save the model.
import os
# Quiet TensorFlow's C++ logging. Set this before importing TensorFlow so it is
# already in effect for the earliest log messages.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import numpy as np

checkpoint_dir = './checkpoint_dir'

# w1 and w2 are placeholders: they are fed at run time and are not variables,
# so the only value the Saver stores is the variable 'bias'.
w1 = tf.placeholder(dtype=np.float32, name='w1')
w2 = tf.placeholder(dtype=np.float32, name='w2')
b1 = tf.Variable(2.0, name='bias')
feed_dict = {w1: 4, w2: 8}
w3 = tf.add(w1, w2)
# Named explicitly so a loader can fetch it as 'multiply_op:0'.
w4 = tf.multiply(w3, b1, name='multiply_op')
saver = tf.train.Saver()

# Depending on the TF version, Saver.save may fail if the directory is missing.
if not os.path.isdir(checkpoint_dir):
    os.makedirs(checkpoint_dir)

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    print(session.run(w4, feed_dict=feed_dict))
    # NOTE: same 'mymodel' prefix as the first example, so this overwrites it.
    saver.save(session, os.path.join(checkpoint_dir, 'mymodel'))
# Output: 24.0   i.e. (4 + 8) * 2.0
4. practice-加载模型,最后一层添加自己的op
# -*- coding: UTF-8 -*-
# Practice: restore the graph saved by the practice save script and append our
# own op after its last layer ('multiply_op').
import os
# Quiet TensorFlow's C++ logging. Set this before importing TensorFlow so it is
# already in effect for the earliest log messages.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import numpy as np

# Expects the checkpoint written by the "practice - save model" script, whose
# graph contains:
#   placeholders 'w1', 'w2' (float32), variable 'bias' = 2.0,
#   'multiply_op' = (w1 + w2) * bias   -> 24.0 when fed w1=4, w2=8.
checkpoint_dir = "./checkpoint_dir"

# 1. Load the graph structure from the .meta file.
saver = tf.train.import_meta_graph(os.path.join(checkpoint_dir, 'mymodel.meta'))
with tf.Session() as session:
    # 2. Load the saved variable values.
    saver.restore(session, save_path=tf.train.latest_checkpoint(checkpoint_dir))
    # 3. Access the tensors to feed and the op to run, by name.
    g = tf.get_default_graph()
    w1 = g.get_tensor_by_name('w1:0')
    w2 = g.get_tensor_by_name('w2:0')
    w4 = g.get_tensor_by_name('multiply_op:0')
    # 4. Add our own op on top of the restored output.
    w5 = tf.add(w4, 10)
    # Expected values: w4 = 24.0, w5 = 34.0
    print(session.run([w4, w5], feed_dict={w1: 4, w2: 8}))