Table of Contents
一、模型部分(成功)
1.保存的模型
import tensorflow as tf
import numpy as np
# To plot pretty figures
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
# Larger fonts for axis labels and tick labels in every figure below.
plt.rcParams.update({
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
})

# Time range covered by the synthetic series, and its sampling step.
t_min, t_max = 0, 30
resolution = 0.1
def time_series(t):
    """Evaluate the synthetic signal used throughout this notebook.

    The signal is ``t * sin(t) / 3`` (an oscillation whose amplitude grows
    with ``t``) plus a faster ripple ``2 * sin(5 * t)``.

    Args:
        t: a scalar or a NumPy array of time values.

    Returns:
        The signal value(s), computed element-wise over ``t``.
    """
    trend = t * np.sin(t) / 3
    ripple = 2 * np.sin(t * 5)
    return trend + ripple
def next_batch(batch_size, n_steps):
    """Sample a random batch of training windows from ``time_series``.

    Each window starts at a uniformly random time ``t0`` in
    ``[t_min, t_max - n_steps * resolution)`` and covers ``n_steps + 1``
    consecutive samples spaced ``resolution`` apart, so it stays inside
    ``[t_min, t_max]``.

    Args:
        batch_size: number of windows to sample.
        n_steps: number of time steps in each input/target sequence.

    Returns:
        ``(X_batch, y_batch)``, each of shape ``(batch_size, n_steps, 1)``.
        ``y_batch`` is ``X_batch`` shifted one step into the future, so the
        target at every step is the next value of the series.
    """
    # Offset by t_min so windows start inside the configured range even when
    # t_min != 0 (without the offset this only works because t_min is 0).
    t0 = t_min + np.random.rand(batch_size, 1) * (t_max - t_min - n_steps * resolution)
    # Broadcast (batch_size, 1) + (n_steps + 1,) -> (batch_size, n_steps + 1).
    Ts = t0 + np.arange(0., n_steps + 1) * resolution
    ys = time_series(Ts)
    # Inputs are the first n_steps samples, targets the last n_steps samples.
    return ys[:, :-1].reshape(-1, n_steps, 1), ys[:, 1:].reshape(-1, n_steps, 1)
# Full time axis; used by the sequence-generation experiments later on.
t = np.linspace(t_min, t_max, int((t_max - t_min) / resolution))

# --- Hyper-parameters --------------------------------------------------------
n_steps = 20      # time steps per sequence
n_inputs = 1      # one feature (the series value) per step
n_neurons = 100   # units in the recurrent cell
n_outputs = 1     # one predicted value per step; must be defined before `y`
learning_rate = 0.001

# One test window of n_steps + 1 points starting at t = 12.2.
# NOTE(review): the end point 12.2 + resolution * (n_steps + 1) makes the
# spacing slightly larger than `resolution` (21 / 20 of it); using
# 12.2 + resolution * n_steps would give exactly `resolution`. Kept as-is so
# the outputs recorded in this notebook stay valid.
t_instance = np.linspace(12.2, 12.2 + resolution * (n_steps + 1), n_steps + 1)

# --- Graph -------------------------------------------------------------------
# Named placeholders/outputs so they can be looked up after restoring the graph
# ('X:0', 'y:0', 'outputs:0').
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs], name='X')
y = tf.placeholder(tf.float32, [None, n_steps, n_outputs], name='y')

cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons, activation=tf.nn.relu)
rnn_outputs, states = tf.nn.dynamic_rnn(cell=cell, inputs=X, dtype=tf.float32)

# Apply one dense layer to every time step: flatten the time axis into the
# batch axis, project n_neurons -> n_outputs, then restore the time axis.
stacked_rnn_outputs = tf.reshape(tensor=rnn_outputs, shape=[-1, n_neurons])
stacked_outputs = tf.layers.dense(inputs=stacked_rnn_outputs, units=n_outputs)
outputs = tf.reshape(tensor=stacked_outputs, shape=[-1, n_steps, n_outputs], name='outputs')

loss = tf.reduce_mean(tf.square(outputs - y))  # mean squared error
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
training_op = optimizer.minimize(loss)

init = tf.global_variables_initializer()
saver = tf.train.Saver()

# --- Training ----------------------------------------------------------------
n_iterations = 1500
batch_size = 50

with tf.Session() as sess:
    init.run()
    for iteration in range(n_iterations):
        X_batch, y_batch = next_batch(batch_size, n_steps)
        sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        if iteration % 100 == 0:
            mse = loss.eval(feed_dict={X: X_batch, y: y_batch})
            print(iteration, "\tMSE:", mse)

    # Predict on the test window, then save graph + variables to
    # ./my_time_series_model.* for the restore section below.
    X_new = time_series(np.array(t_instance[:-1].reshape(-1, n_steps, n_inputs)))
    y_pred = sess.run(outputs, feed_dict={X: X_new})
    saver.save(sess, "./my_time_series_model")
2.载入模型并用于预测
import tensorflow as tf
import numpy as np
def reset_graph(seed=42):
    """Start from an empty default graph with reproducible randomness.

    Clears TensorFlow's default graph, then seeds the new graph's random
    number generator and NumPy's global RNG with ``seed``.
    """
    tf.reset_default_graph()
    for apply_seed in (tf.set_random_seed, np.random.seed):
        apply_seed(seed)


reset_graph()
1.载入图结构和参数
# Open a session, rebuild the saved graph from its .meta file, then load the
# variable values from the most recent checkpoint found in './'.
# The session is intentionally left open: the following cells reuse it.
sess = tf.Session()
saver = tf.train.import_meta_graph('./my_time_series_model.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
INFO:tensorflow:Restoring parameters from ./my_time_series_model
2.获取图
# import_meta_graph above loaded the saved graph into the default graph; keep a
# handle to it so tensors can be looked up by name.
graph = tf.get_default_graph()
3.获取tensor
# Tensor names are '<op name>:<output index>'. The op names 'X', 'y' and
# 'outputs' were set with name=... when the model was built.
X = graph.get_tensor_by_name('X:0')
y = graph.get_tensor_by_name("y:0")
outputs = graph.get_tensor_by_name("outputs:0")
4.新的input准备
# Build one input window. This reuses time_series and t_instance from the first
# part of the notebook. The shape (-1, 20, 1) must match the placeholder shape
# saved with the graph, (None, 20, 1), which y.shape below confirms.
X_new = time_series(np.array(t_instance[:-1].reshape(-1, 20, 1)))
y.shape
TensorShape([Dimension(None), Dimension(20), Dimension(1)])
5.应用与预测
# Evaluate the restored output tensor on the new window. The result has the
# same (batch, n_steps, 1) layout as the input: (1, 20, 1) here.
y_pred = sess.run(outputs, feed_dict={X: X_new})
print(y_pred)
[[[-3.44828582]
[-2.48405623]
[-1.13649726]
[ 0.71962416]
[ 2.01745081]
[ 3.13937259]
[ 3.54828739]
[ 3.36234236]
[ 2.77184248]
[ 2.10781217]
[ 1.64527285]
[ 1.5579648 ]
[ 1.87219918]
[ 2.7233479 ]
[ 3.85228252]
[ 5.06193066]
[ 6.07513857]
[ 6.63054752]
[ 6.59069633]
[ 5.9993453 ]]]
plt.title("Testing the model", fontsize=14)
# Input window: the n_steps points fed to the model.
plt.plot(t_instance[:-1], time_series(t_instance[:-1]), "bo", markersize=10, label="instance")
# Targets: the same series shifted one step ahead. The white stars get an
# explicit blue edge. Without it they are invisible on a white background when
# the marker edge takes the face colour (matplotlib 2.x default style).
plt.plot(t_instance[1:], time_series(t_instance[1:]), "w*", markersize=10,
         markeredgewidth=0.5, markeredgecolor="b", label="target")
# Model prediction for each shifted time step.
plt.plot(t_instance[1:], y_pred[0, :, 0], "r.", markersize=10, label="prediction")
plt.legend(loc="upper left")
plt.xlabel("Time")
# save_fig("time_series_pred_plot")  # save_fig is not defined in this notebook
plt.show()
# Generate new sequences "from scratch": repeatedly feed the model its own
# last n_steps values and append the value it predicts for the next step.
n_steps = 20

# Seed 1: start from n_steps zeros.
sequence1 = [0.] * n_steps
for iteration in range(len(t) - n_steps):
    X_batch = np.array(sequence1[-n_steps:]).reshape(1, n_steps, 1)
    # Bind the fresh prediction to y_pred. A typo ("y_red") previously made
    # this loop append a stale y_pred from an earlier cell on every step.
    y_pred = sess.run(outputs, feed_dict={X: X_batch})
    # The prediction at the last step is the next value of the series.
    sequence1.append(y_pred[0, -1, 0])

# Seed 2: start from n_steps real samples of the series.
# NOTE(review): `t_max - t_min/3` evaluates to 30 here (t_min == 0), so the
# seed starts at t = 30. This matches the first printed value below,
# -11.31 == time_series(30). If an offset of one third of the range was
# intended, write `(t_max - t_min) / 3` instead.
sequence2 = [time_series(i * resolution + t_min + (t_max - t_min / 3)) for i in range(n_steps)]
for iteration in range(len(t) - n_steps):
    X_batch = np.array(sequence2[-n_steps:]).reshape(1, n_steps, 1)
    y_pred = sess.run(outputs, feed_dict={X: X_batch})
    sequence2.append(y_pred[0, -1, 0])
sequence2
[-11.310069100186947,
-10.293466647143219,
-9.0351304228581899,
-7.7791526184875721,
-6.7463786174128,
-6.0810586339718817,
-5.8164786703312812,
-5.8679774217158975,
-6.0550484641549618,
-6.1471037973674534,
-5.9216796338876341,
-5.2208171081907668,
-3.991795656706794,
-2.3022262397260698,
-0.32578778096384303,
1.6979234284416602,
3.5196245383171041,
4.938914081326276,
5.8508622581105945,
6.2692415768875769,
6.2414274,
6.0063138,
5.7208457,
5.6004872,
5.7733955,
6.2796745,
7.0539942,
7.9423361,
8.7371979,
9.2194023,
9.2210941,
8.6577063,
7.55442,
6.0555,
4.37221,
2.7345319,
1.3436384,
0.32472184,
-0.3159388,
-0.6737811,
-0.92999107,
-1.3043951,
-1.9887228,
-3.10217,
-4.6800175,
-6.6327167,
-8.7743168,
-10.866323,
-12.663467,
-13.970881,
-14.69599,
-14.907141,
-14.78536,
-14.409622,
-14.090124,
-14.002066,
-14.257333,
-15.155252,
-16.459118,
-18.0779,
-19.749088,
-21.18313,
-22.134678,
-22.438381,
-22.068893,
-21.052273,
-19.576031,
-18.089518,
-16.549288,
-15.274453,
-14.365784,
-13.75396,
-13.394633,
-13.047565,
-12.466007,
-11.425198,
-9.8138237,
-7.6150599,
-4.9504156,
-2.0713785,
0.75190848,
3.2330799,
5.1569152,
6.4215879,
7.0548859,
7.2106261,
7.115716,
7.0375557,
7.2040081,
7.7478371,
8.6662283,
9.8395824,
11.034622,
11.977497,
12.413702,
12.166514,
11.191109,
9.5760918,
7.5540438,
5.3992167,
3.4150782,
1.812041,
0.70379823,
0.057163272,
-0.30561826,
-0.6333279,
-1.2131691,
-2.2672105,
-3.9158158,
-6.1540804,
-8.8191996,
-11.635813,
-14.286608,
-16.462812,
-17.943993,
-18.685251,
-18.793633,
-18.618328,
-18.118195,
-17.951168,
-18.098774,
-18.851744,
-20.435009,
-22.447792,
-24.747829,
-26.917685,
-28.524897,
-29.264297,
-29.117596,
-28.16255,
-26.49226,
-24.614424,
-23.007265,
-21.564766,
-20.551498,
-19.873325,
-19.364033,
-18.779448,
-17.821594,
-16.258081,
-13.970452,
-10.981444,
-7.4684205,
-3.7501016,
-0.15635261,
2.9772096,
5.4395175,
7.132719,
8.1294727,
8.6363735,
8.9281454,
9.2941914,
9.9697638,
11.079341,
12.562345,
14.237216,
15.809559,
16.950453,
17.381609,
16.938213,
15.613298,
13.565508,
11.076948,
8.5601444,
6.3773327,
4.7868357,
3.8481448,
3.5138502,
3.4309714,
3.2506099,
2.566304,
1.0682412,
-1.3450102,
-4.6130099,
-8.4637041,
-12.477572,
-16.190722,
-19.199072,
-21.232229,
-22.2255,
-22.393307,
-22.079338,
-21.793861,
-21.614162,
-22.374886,
-23.66847,
-25.927404,
-28.822844,
-31.882042,
-34.618858,
-36.46328,
-37.181564,
-36.713959,
-35.274097,
-32.990501,
-30.3675,
-28.14781,
-26.146204,
-24.562925,
-23.419697,
-22.426777,
-21.289103,
-19.713984,
-17.391224,
-14.248542,
-10.339687,
-5.8647647,
-1.2366034,
3.1689858,
6.9699464,
9.9272709,
11.941613,
13.105474,
13.669283,
13.953389,
14.285711,
14.963319,
16.072838,
17.565983,
19.225279,
20.71352,
21.667736,
21.790447,
20.899879,
19.012274,
16.317251,
13.150064,
9.8907776,
7.0072351,
4.6846075,
3.029146,
1.946041,
1.0919217,
0.098311812,
-1.4406949,
-3.7629437,
-7.0077753,
-11.050987,
-15.585717,
-20.164085,
-24.306299,
-27.590773,
-29.736839,
-30.848682,
-31.069016,
-31.061485,
-30.686995,
-31.03178,
-31.748112,
-33.299461,
-35.919754,
-38.907295,
-42.115463,
-44.881657,
-46.655895,
-47.255661,
-46.659004,
-44.970089,
-42.384876,
-39.614048,
-37.0881,
-34.543285,
-32.361893,
-30.367081,
-28.412891,
-26.267321,
-23.673107,
-20.30426,
-16.241322,
-11.44988,
-6.1824598,
-0.85328394,
4.2204714,
8.6730185,
12.300264,
15.019421,
16.92049,
18.221352,
19.190708,
20.100592,
21.194742,
22.507717,
23.986784,
25.428883,
26.544825,
27.042883,
26.692993,
25.375931,
23.142214,
20.169025,
16.782326,
13.413164,
10.326519,
7.692627,
5.611692,
3.8167367,
2.0629253,
-0.033325758,
-2.7715981,
-6.3110862]
6.其他内容
6.1 查看tensor、node等
方法一:pywrap_tensorflow
# Template: pass the directory that contains the `checkpoint` file. For example,
# if it is C:\Users\Administrator\Documents\checkpoint, pass
# r'C:\Users\Administrator\Documents' without a trailing backslash, because a
# raw string literal cannot end with a single backslash.
tf.train.get_checkpoint_state(checkpoint_dir='checkpoint路径')
import os
logdir='./'
from tensorflow.python import pywrap_tensorflow
# Find the newest checkpoint in logdir and open it for reading.
ckpt = tf.train.get_checkpoint_state(logdir)
reader = pywrap_tensorflow.NewCheckpointReader(ckpt.model_checkpoint_path)
# Mapping: saved variable name -> its shape.
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)  # name of the saved tensor
    # print(reader.get_tensor(key))  # uncomment to also print the tensor's value
# Reference: https://blog.csdn.net/wc781708249/article/details/78040735
tensor_name: dense/kernel
tensor_name: beta1_power
tensor_name: beta2_power
tensor_name: dense/bias
tensor_name: rnn/basic_rnn_cell/bias
tensor_name: dense/bias/Adam_1
tensor_name: dense/bias/Adam
tensor_name: dense/kernel/Adam_1
tensor_name: dense/kernel/Adam
tensor_name: rnn/basic_rnn_cell/bias/Adam
tensor_name: rnn/basic_rnn_cell/bias/Adam_1
tensor_name: rnn/basic_rnn_cell/kernel
tensor_name: rnn/basic_rnn_cell/kernel/Adam
tensor_name: rnn/basic_rnn_cell/kernel/Adam_1
方法二:inspect_checkpoint
# Usage: chkp.print_tensors_in_checkpoint_file(file_name=..., tensor_name=..., all_tensors=...)
# file_name is the checkpoint prefix, not a single file. For example, if the
# four model files are in C:\Users\Administrator\Documents and the meta file is
# C:\Users\Administrator\Documents\my_time_series_model.meta, pass
# file_name=r'C:\Users\Administrator\Documents\my_time_series_model'.
# The module must be imported before help() can inspect it.
from tensorflow.python.tools import inspect_checkpoint as chkp
help(chkp.print_tensors_in_checkpoint_file)
Help on function print_tensors_in_checkpoint_file in module tensorflow.python.tools.inspect_checkpoint:
print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors)
Prints tensors in a checkpoint file.
If no `tensor_name` is provided, prints the tensor names and shapes
in the checkpoint file.
If `tensor_name` is provided, prints the content of the tensor.
Args:
file_name: Name of the checkpoint file.
tensor_name: Name of the tensor in the checkpoint file to print.
all_tensors: Boolean indicating whether to print all tensors.
# Use inspect_checkpoint to view the contents of the checkpoint.
from tensorflow.python.tools import inspect_checkpoint as chkp
chkp.print_tensors_in_checkpoint_file(file_name="./my_time_series_model",
    tensor_name=None,  # None: no single tensor selected; with all_tensors=False this lists every tensor's name, dtype and shape
    all_tensors=False,  # True would print the values of every tensor, usually far too verbose, so False is recommended
    )  # NOTE(review): the original note ("print the names of all tensors") presumably refers to an `all_tensor_names` flag; the signature shown by help() above has no such parameter
# According to the original notes, print_tensors_in_checkpoint_file is implemented on top of NewCheckpointReader.
beta1_power (DT_FLOAT) []
beta2_power (DT_FLOAT) []
dense/bias (DT_FLOAT) [1]
dense/bias/Adam (DT_FLOAT) [1]
dense/bias/Adam_1 (DT_FLOAT) [1]
dense/kernel (DT_FLOAT) [100,1]
dense/kernel/Adam (DT_FLOAT) [100,1]
dense/kernel/Adam_1 (DT_FLOAT) [100,1]
rnn/basic_rnn_cell/bias (DT_FLOAT) [100]
rnn/basic_rnn_cell/bias/Adam (DT_FLOAT) [100]
rnn/basic_rnn_cell/bias/Adam_1 (DT_FLOAT) [100]
rnn/basic_rnn_cell/kernel (DT_FLOAT) [101,100]
rnn/basic_rnn_cell/kernel/Adam (DT_FLOAT) [101,100]
rnn/basic_rnn_cell/kernel/Adam_1 (DT_FLOAT) [101,100]
方法三:查看所有node的名称
先载入模型,获取图结构,然后打印图结构中的node。
(便于获取变量为基于模型的应用服务)
# NodeDef entries of the current default graph, one per operation.
tf.get_default_graph().as_graph_def().node
# List all node (operation) names. Tensor names are these plus an output
# index, e.g. node 'X' produces tensor 'X:0'.
[tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
['X',
'y',
'Rank',
'range/start',
'range/delta',
'range',
'concat/values_0',
'concat/axis',
'concat',
'transpose',
'rnn/Shape',
'rnn/strided_slice/stack',
'rnn/strided_slice/stack_1',
'rnn/strided_slice/stack_2',
'rnn/strided_slice',
'rnn/strided_slice_1/stack',
'rnn/strided_slice_1/stack_1',
'rnn/strided_slice_1/stack_2',
'rnn/strided_slice_1',
'rnn/BasicRNNCellZeroState/ExpandDims/dim',
'rnn/BasicRNNCellZeroState/ExpandDims',
'rnn/BasicRNNCellZeroState/Const',
'rnn/BasicRNNCellZeroState/concat/axis',
'rnn/BasicRNNCellZeroState/concat',
'rnn/BasicRNNCellZeroState/ExpandDims_1/dim',
'rnn/BasicRNNCellZeroState/ExpandDims_1',
'rnn/BasicRNNCellZeroState/Const_1',
'rnn/BasicRNNCellZeroState/zeros/Const',
'rnn/BasicRNNCellZeroState/zeros',
'rnn/Shape_1',
'rnn/strided_slice_2/stack',
'rnn/strided_slice_2/stack_1',
'rnn/strided_slice_2/stack_2',
'rnn/strided_slice_2',
'rnn/strided_slice_3/stack',
'rnn/strided_slice_3/stack_1',
'rnn/strided_slice_3/stack_2',
'rnn/strided_slice_3',
'rnn/ExpandDims/dim',
'rnn/ExpandDims',
'rnn/Const',
'rnn/concat/axis',
'rnn/concat',
'rnn/zeros/Const',
'rnn/zeros',
'rnn/time',
'rnn/TensorArray',
'rnn/TensorArray_1',
'rnn/TensorArrayUnstack/Shape',
'rnn/TensorArrayUnstack/strided_slice/stack',
'rnn/TensorArrayUnstack/strided_slice/stack_1',
'rnn/TensorArrayUnstack/strided_slice/stack_2',
'rnn/TensorArrayUnstack/strided_slice',
'rnn/TensorArrayUnstack/range/start',
'rnn/TensorArrayUnstack/range/delta',
'rnn/TensorArrayUnstack/range',
'rnn/TensorArrayUnstack/TensorArrayScatter/TensorArrayScatterV3',
'rnn/while/Enter',
'rnn/while/Enter_1',
'rnn/while/Enter_2',
'rnn/while/Merge',
'rnn/while/Merge_1',
'rnn/while/Merge_2',
'rnn/while/Less/Enter',
'rnn/while/Less',
'rnn/while/LoopCond',
'rnn/while/Switch',
'rnn/while/Switch_1',
'rnn/while/Switch_2',
'rnn/while/Identity',
'rnn/while/Identity_1',
'rnn/while/Identity_2',
'rnn/while/TensorArrayReadV3/Enter',
'rnn/while/TensorArrayReadV3/Enter_1',
'rnn/while/TensorArrayReadV3',
'rnn/basic_rnn_cell/kernel/Initializer/random_uniform/shape',
'rnn/basic_rnn_cell/kernel/Initializer/random_uniform/min',
'rnn/basic_rnn_cell/kernel/Initializer/random_uniform/max',
'rnn/basic_rnn_cell/kernel/Initializer/random_uniform/RandomUniform',
'rnn/basic_rnn_cell/kernel/Initializer/random_uniform/sub',
'rnn/basic_rnn_cell/kernel/Initializer/random_uniform/mul',
'rnn/basic_rnn_cell/kernel/Initializer/random_uniform',
'rnn/basic_rnn_cell/kernel',
'rnn/basic_rnn_cell/kernel/Assign',
'rnn/basic_rnn_cell/kernel/read',
'rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat/axis',
'rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat',
'rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul/Enter',
'rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul',
'rnn/basic_rnn_cell/bias/Initializer/Const',
'rnn/basic_rnn_cell/bias',
'rnn/basic_rnn_cell/bias/Assign',
'rnn/basic_rnn_cell/bias/read',
'rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/BiasAdd/Enter',
'rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/BiasAdd',
'rnn/while/rnn/basic_rnn_cell/Relu',
'rnn/while/TensorArrayWrite/TensorArrayWriteV3/Enter',
'rnn/while/TensorArrayWrite/TensorArrayWriteV3',
'rnn/while/add/y',
'rnn/while/add',
'rnn/while/NextIteration',
'rnn/while/NextIteration_1',
'rnn/while/NextIteration_2',
'rnn/while/Exit',
'rnn/while/Exit_1',
'rnn/while/Exit_2',
'rnn/TensorArrayStack/TensorArraySizeV3',
'rnn/TensorArrayStack/range/start',
'rnn/TensorArrayStack/range/delta',
'rnn/TensorArrayStack/range',
'rnn/TensorArrayStack/TensorArrayGatherV3',
'rnn/Const_1',
'rnn/Rank',
'rnn/range/start',
'rnn/range/delta',
'rnn/range',
'rnn/concat_1/values_0',
'rnn/concat_1/axis',
'rnn/concat_1',
'rnn/transpose',
'Reshape/shape',
'Reshape',
'dense/kernel/Initializer/random_uniform/shape',
'dense/kernel/Initializer/random_uniform/min',
'dense/kernel/Initializer/random_uniform/max',
'dense/kernel/Initializer/random_uniform/RandomUniform',
'dense/kernel/Initializer/random_uniform/sub',
'dense/kernel/Initializer/random_uniform/mul',
'dense/kernel/Initializer/random_uniform',
'dense/kernel',
'dense/kernel/Assign',
'dense/kernel/read',
'dense/bias/Initializer/zeros',
'dense/bias',
'dense/bias/Assign',
'dense/bias/read',
'dense/MatMul',
'dense/BiasAdd',
'outputs/shape',
'outputs',
'sub',
'Square',
'Const',
'Mean',
'gradients/Shape',
'gradients/Const',
'gradients/Fill',
'gradients/f_count',
'gradients/f_count_1',
'gradients/Merge',
'gradients/Switch',
'gradients/Add/y',
'gradients/Add',
'gradients/NextIteration',
'gradients/f_count_2',
'gradients/b_count',
'gradients/b_count_1',
'gradients/Merge_1',
'gradients/GreaterEqual/Enter',
'gradients/GreaterEqual',
'gradients/b_count_2',
'gradients/Switch_1',
'gradients/Sub',
'gradients/NextIteration_1',
'gradients/b_count_3',
'gradients/Mean_grad/Reshape/shape',
'gradients/Mean_grad/Reshape',
'gradients/Mean_grad/Shape',
'gradients/Mean_grad/Tile',
'gradients/Mean_grad/Shape_1',
'gradients/Mean_grad/Shape_2',
'gradients/Mean_grad/Const',
'gradients/Mean_grad/Prod',
'gradients/Mean_grad/Const_1',
'gradients/Mean_grad/Prod_1',
'gradients/Mean_grad/Maximum/y',
'gradients/Mean_grad/Maximum',
'gradients/Mean_grad/floordiv',
'gradients/Mean_grad/Cast',
'gradients/Mean_grad/truediv',
'gradients/Square_grad/mul/x',
'gradients/Square_grad/mul',
'gradients/Square_grad/mul_1',
'gradients/sub_grad/Shape',
'gradients/sub_grad/Shape_1',
'gradients/sub_grad/BroadcastGradientArgs',
'gradients/sub_grad/Sum',
'gradients/sub_grad/Reshape',
'gradients/sub_grad/Sum_1',
'gradients/sub_grad/Neg',
'gradients/sub_grad/Reshape_1',
'gradients/sub_grad/tuple/group_deps',
'gradients/sub_grad/tuple/control_dependency',
'gradients/sub_grad/tuple/control_dependency_1',
'gradients/outputs_grad/Shape',
'gradients/outputs_grad/Reshape',
'gradients/dense/BiasAdd_grad/BiasAddGrad',
'gradients/dense/BiasAdd_grad/tuple/group_deps',
'gradients/dense/BiasAdd_grad/tuple/control_dependency',
'gradients/dense/BiasAdd_grad/tuple/control_dependency_1',
'gradients/dense/MatMul_grad/MatMul',
'gradients/dense/MatMul_grad/MatMul_1',
'gradients/dense/MatMul_grad/tuple/group_deps',
'gradients/dense/MatMul_grad/tuple/control_dependency',
'gradients/dense/MatMul_grad/tuple/control_dependency_1',
'gradients/Reshape_grad/Shape',
'gradients/Reshape_grad/Reshape',
'gradients/rnn/transpose_grad/InvertPermutation',
'gradients/rnn/transpose_grad/transpose',
'gradients/rnn/TensorArrayStack/TensorArrayGatherV3_grad/TensorArrayGrad/TensorArrayGradV3',
'gradients/rnn/TensorArrayStack/TensorArrayGatherV3_grad/TensorArrayGrad/gradient_flow',
'gradients/rnn/TensorArrayStack/TensorArrayGatherV3_grad/TensorArrayScatter/TensorArrayScatterV3',
'gradients/zeros_like',
'gradients/rnn/while/Exit_1_grad/b_exit',
'gradients/rnn/while/Exit_2_grad/b_exit',
'gradients/rnn/while/Switch_1_grad/b_switch',
'gradients/rnn/while/Switch_2_grad/b_switch',
'gradients/rnn/while/Merge_1_grad/Switch',
'gradients/rnn/while/Merge_1_grad/tuple/group_deps',
'gradients/rnn/while/Merge_1_grad/tuple/control_dependency',
'gradients/rnn/while/Merge_1_grad/tuple/control_dependency_1',
'gradients/rnn/while/Merge_2_grad/Switch',
'gradients/rnn/while/Merge_2_grad/tuple/group_deps',
'gradients/rnn/while/Merge_2_grad/tuple/control_dependency',
'gradients/rnn/while/Merge_2_grad/tuple/control_dependency_1',
'gradients/rnn/while/Enter_1_grad/Exit',
'gradients/rnn/while/Enter_2_grad/Exit',
'gradients/rnn/while/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayGrad/TensorArrayGradV3/Enter',
'gradients/rnn/while/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayGrad/TensorArrayGradV3',
'gradients/rnn/while/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayGrad/gradient_flow',
'gradients/rnn/while/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3/f_acc',
'gradients/rnn/while/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3/RefEnter',
'gradients/rnn/while/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3/StackPush',
'gradients/rnn/while/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3/StackPop/RefEnter',
'gradients/rnn/while/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3/StackPop',
'gradients/rnn/while/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3/b_sync',
'gradients/rnn/while/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3',
'gradients/rnn/while/TensorArrayWrite/TensorArrayWriteV3_grad/tuple/group_deps',
'gradients/rnn/while/TensorArrayWrite/TensorArrayWriteV3_grad/tuple/control_dependency',
'gradients/rnn/while/TensorArrayWrite/TensorArrayWriteV3_grad/tuple/control_dependency_1',
'gradients/AddN',
'gradients/rnn/while/rnn/basic_rnn_cell/Relu_grad/ReluGrad/f_acc',
'gradients/rnn/while/rnn/basic_rnn_cell/Relu_grad/ReluGrad/RefEnter',
'gradients/rnn/while/rnn/basic_rnn_cell/Relu_grad/ReluGrad/StackPush',
'gradients/rnn/while/rnn/basic_rnn_cell/Relu_grad/ReluGrad/StackPop/RefEnter',
'gradients/rnn/while/rnn/basic_rnn_cell/Relu_grad/ReluGrad/StackPop',
'gradients/rnn/while/rnn/basic_rnn_cell/Relu_grad/ReluGrad',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/BiasAdd_grad/BiasAddGrad',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/BiasAdd_grad/tuple/group_deps',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/BiasAdd_grad/tuple/control_dependency',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/BiasAdd_grad/tuple/control_dependency_1',
'gradients/rnn/while/Switch_1_grad_1/NextIteration',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul_grad/MatMul/Enter',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul_grad/MatMul',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul_grad/MatMul_1/f_acc',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul_grad/MatMul_1/RefEnter',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul_grad/MatMul_1/StackPush',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul_grad/MatMul_1/StackPop/RefEnter',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul_grad/MatMul_1/StackPop',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul_grad/MatMul_1',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul_grad/tuple/group_deps',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul_grad/tuple/control_dependency',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul_grad/tuple/control_dependency_1',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/BiasAdd/Enter_grad/b_acc',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/BiasAdd/Enter_grad/b_acc_1',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/BiasAdd/Enter_grad/b_acc_2',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/BiasAdd/Enter_grad/Switch',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/BiasAdd/Enter_grad/Add',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/BiasAdd/Enter_grad/NextIteration',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/BiasAdd/Enter_grad/b_acc_3',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/Rank',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/mod/f_acc',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/mod/RefEnter',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/mod/StackPush',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/mod/StackPop/RefEnter',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/mod/StackPop',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/mod',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/Shape',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/ShapeN/f_acc',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/ShapeN/RefEnter',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/ShapeN/StackPush',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/ShapeN/StackPop/RefEnter',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/ShapeN/StackPop',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/ShapeN/f_acc_1',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/ShapeN/RefEnter_1',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/ShapeN/StackPush_1',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/ShapeN/StackPop_1/RefEnter',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/ShapeN/StackPop_1',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/ShapeN',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/ConcatOffset',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/Slice',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/Slice_1',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/tuple/group_deps',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/tuple/control_dependency',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/concat_grad/tuple/control_dependency_1',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul/Enter_grad/b_acc',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul/Enter_grad/b_acc_1',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul/Enter_grad/b_acc_2',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul/Enter_grad/Switch',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul/Enter_grad/Add',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul/Enter_grad/NextIteration',
'gradients/rnn/while/rnn/basic_rnn_cell/basic_rnn_cell/MatMul/Enter_grad/b_acc_3',
'gradients/rnn/while/Switch_2_grad_1/NextIteration',
'beta1_power/initial_value',
'beta1_power',
'beta1_power/Assign',
'beta1_power/read',
'beta2_power/initial_value',
'beta2_power',
'beta2_power/Assign',
'beta2_power/read',
'rnn/basic_rnn_cell/kernel/Adam/Initializer/zeros',
'rnn/basic_rnn_cell/kernel/Adam',
'rnn/basic_rnn_cell/kernel/Adam/Assign',
'rnn/basic_rnn_cell/kernel/Adam/read',
'rnn/basic_rnn_cell/kernel/Adam_1/Initializer/zeros',
'rnn/basic_rnn_cell/kernel/Adam_1',
'rnn/basic_rnn_cell/kernel/Adam_1/Assign',
'rnn/basic_rnn_cell/kernel/Adam_1/read',
'rnn/basic_rnn_cell/bias/Adam/Initializer/zeros',
'rnn/basic_rnn_cell/bias/Adam',
'rnn/basic_rnn_cell/bias/Adam/Assign',
'rnn/basic_rnn_cell/bias/Adam/read',
'rnn/basic_rnn_cell/bias/Adam_1/Initializer/zeros',
'rnn/basic_rnn_cell/bias/Adam_1',
'rnn/basic_rnn_cell/bias/Adam_1/Assign',
'rnn/basic_rnn_cell/bias/Adam_1/read',
'dense/kernel/Adam/Initializer/zeros',
'dense/kernel/Adam',
'dense/kernel/Adam/Assign',
'dense/kernel/Adam/read',
'dense/kernel/Adam_1/Initializer/zeros',
'dense/kernel/Adam_1',
'dense/kernel/Adam_1/Assign',
'dense/kernel/Adam_1/read',
'dense/bias/Adam/Initializer/zeros',
'dense/bias/Adam',
'dense/bias/Adam/Assign',
'dense/bias/Adam/read',
'dense/bias/Adam_1/Initializer/zeros',
'dense/bias/Adam_1',
'dense/bias/Adam_1/Assign',
'dense/bias/Adam_1/read',
'Adam/learning_rate',
'Adam/beta1',
'Adam/beta2',
'Adam/epsilon',
'Adam/update_rnn/basic_rnn_cell/kernel/ApplyAdam',
'Adam/update_rnn/basic_rnn_cell/bias/ApplyAdam',
'Adam/update_dense/kernel/ApplyAdam',
'Adam/update_dense/bias/ApplyAdam',
'Adam/mul',
'Adam/Assign',
'Adam/mul_1',
'Adam/Assign_1',
'Adam',
'init',
'save/Const',
'save/SaveV2/tensor_names',
'save/SaveV2/shape_and_slices',
'save/SaveV2',
'save/control_dependency',
'save/RestoreV2/tensor_names',
'save/RestoreV2/shape_and_slices',
'save/RestoreV2',
'save/Assign',
'save/RestoreV2_1/tensor_names',
'save/RestoreV2_1/shape_and_slices',
'save/RestoreV2_1',
'save/Assign_1',
'save/RestoreV2_2/tensor_names',
'save/RestoreV2_2/shape_and_slices',
'save/RestoreV2_2',
'save/Assign_2',
'save/RestoreV2_3/tensor_names',
'save/RestoreV2_3/shape_and_slices',
'save/RestoreV2_3',
'save/Assign_3',
'save/RestoreV2_4/tensor_names',
'save/RestoreV2_4/shape_and_slices',
'save/RestoreV2_4',
'save/Assign_4',
'save/RestoreV2_5/tensor_names',
'save/RestoreV2_5/shape_and_slices',
'save/RestoreV2_5',
'save/Assign_5',
'save/RestoreV2_6/tensor_names',
'save/RestoreV2_6/shape_and_slices',
'save/RestoreV2_6',
'save/Assign_6',
'save/RestoreV2_7/tensor_names',
'save/RestoreV2_7/shape_and_slices',
'save/RestoreV2_7',
'save/Assign_7',
'save/RestoreV2_8/tensor_names',
'save/RestoreV2_8/shape_and_slices',
'save/RestoreV2_8',
'save/Assign_8',
'save/RestoreV2_9/tensor_names',
'save/RestoreV2_9/shape_and_slices',
'save/RestoreV2_9',
'save/Assign_9',
'save/RestoreV2_10/tensor_names',
'save/RestoreV2_10/shape_and_slices',
'save/RestoreV2_10',
'save/Assign_10',
'save/RestoreV2_11/tensor_names',
'save/RestoreV2_11/shape_and_slices',
'save/RestoreV2_11',
'save/Assign_11',
'save/RestoreV2_12/tensor_names',
'save/RestoreV2_12/shape_and_slices',
'save/RestoreV2_12',
'save/Assign_12',
'save/RestoreV2_13/tensor_names',
'save/RestoreV2_13/shape_and_slices',
'save/RestoreV2_13',
'save/Assign_13',
'save/restore_all']
6.2关于不同版本的checkpoint文件理解
- 对于tensorflow 1.2及以上版本,只需写出ckpt文件路径中的模型名前缀即可(比如对于my_model.meta,写my_model即可)
参考:
(1) tensorflow的checkpoint文件的版本
(2) TensorFlow查看ckpt中变量的几种方法
二、学习其他简单的
1. 保存变量
# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)
inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model.
inc_v1.op.run()
dec_v2.op.run()
# Save the variables to disk.
save_path = saver.save(sess, "/tmp/model.ckpt")
print("Model saved in path: %s" % save_path)
Model saved in path: /tmp/model.ckpt
2. 恢复变量
采用tf.train.Saver对象恢复变量时,不必事先进行初始化,即tf.get_variable()中的initializer参数不需要设置。
# Start a fresh graph and re-declare the variables under the names they were
# saved with. No initializer is needed: restore() assigns their values.
tf.reset_default_graph()

v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
    # Restore variables from disk.
    saver.restore(sess, "/tmp/model.ckpt")
    print("Model restored.")
    # Check the values of the variables
    print("v1 : %s" % v1.eval())
    print("v2 : %s" % v2.eval())
INFO:tensorflow:Restoring parameters from /tmp/model.ckpt
Model restored.
v1 : [ 1. 1. 1.]
v2 : [-1. -1. -1. -1. -1.]
3. 选择想要保存的和恢复的变量(还不太明白)
tf.reset_default_graph()
# Create some variables.
v1 = tf.get_variable("v1", [3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)

# Add ops to save and restore only `v2` using the name "v2".
# The dict maps a name in the checkpoint ("v2") to the variable that receives
# it. Variables not listed (v1) are not handled by this saver.
saver = tf.train.Saver({"v2": v2})

# Use the saver object normally after that.
with tf.Session() as sess:
    # Initialize v1 since the saver will not.
    v1.initializer.run()
    saver.restore(sess, "/tmp/model.ckpt")
    # v1 keeps its zero initial value; v2 receives the saved values (all -1s).
    print("v1 : %s" % v1.eval())
    print("v2 : %s" % v2.eval())
INFO:tensorflow:Restoring parameters from /tmp/model.ckpt
v1 : [ 0. 0. 0.]
v2 : [-1. -1. -1. -1. -1.]
三、保存和恢复模型
- SavedModel 保存和加载模型(包括变量、图、图的元数据)。对应的API有tf.saved_model和tf.estimator.Estimator。
1. 构建和加载savedmodel
简单保存
采用:tf.saved_model.simple_save函数
# Template: `session`, `export_dir` and the tensors x, y, z come from your own
# model. Writes a SavedModel to export_dir, recording x and y as named inputs
# and z as a named output. The function lives in tf.saved_model; there is no
# bare `simple_save` name in scope.
tf.saved_model.simple_save(session,
                           export_dir,
                           inputs={"x": x, "y": y},
                           outputs={"z": z})
手动构建savedmodel
2. 加载savedmodel
需要的基本信息
- 用于加载图定义和变量的会话(Session)
- 用于标识要加载的 MetaGraphDef 的标签
- SavedModel 的位置(目录)
???
# Placeholder: set export_dir to the directory the SavedModel was written to.
export_dir = ...
with tf.Session(graph=tf.Graph()) as sess:
    # The tags must match those used when saving. simple_save stores the graph
    # under tag_constants.SERVING, so loading with TRAINING would not find it.
    # `tag_constants` is not imported here, so it is reached via tf.saved_model.
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_dir)
参考:
四、a quick complete tutorial to save model
参考自:https://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/
1. 关于tensorflow的model文件基本介绍
(1) 在0.11版本之后(4个)
- checkpoint: 保存的最新checkpoint文件
- model.data-00000-of-00001: 文件中包含training variables(训练变量)的值
- model.index: 变量名到data文件中对应数据位置的索引
- model.meta: 保存图结构(MetaGraphDef),可用tf.train.import_meta_graph重建图
例如:
(2) 0.11版本之前(3个)
- checkpoint
- model.ckpt
- model.meta
2. 保存变量
- 注意:
- 需要在session内进行保存
- 文件不保存原有的placeholder中的值
(1) 简单保存
# Save the variables tracked by `saver` from `sess` under the file prefix 'my-test-model'.
saver.save(sess, 'my-test-model')
(2) 指定global_step:保存时在文件名后追加当前的迭代步数(并不会自动每隔若干步保存)
# global_step only appends '-1000' to the saved file names (my-test-model-1000.*).
# It does not make the saver save every 1000 iterations by itself. For periodic
# saving, call save() inside the training loop, e.g. when step % 1000 == 0.
saver.save(sess, 'my-test-model', global_step=1000)
结果的model文件名称将追加’-1000’
(3) write_meta_graph参数:False表示本次保存不写出.meta文件,True(默认)表示每次保存都写出.meta文件。由于训练过程中图结构不变,.meta文件通常只需保存一次。
# write_meta_graph=False skips writing the .meta (graph structure) file for this
# save. Use it for saves after the first one, since the graph does not change
# between saves. Note Python's boolean is `False`; `false` is undefined.
saver.save(sess, 'my-test-model', global_step=1000, write_meta_graph=False)
(4) 希望只保留最新的m个模型,同时每训练n小时额外永久保留一个模型
# Per the tf.train.Saver documentation: keep at most the 4 most recent
# checkpoints, and additionally preserve one checkpoint per 2 hours of training
# so that those are not deleted by the max_to_keep rotation.
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)
(5) 关于tf.train.Saver()中参数
- 按照默认参数:保存全部的variables
- 具体变量名称的list或dict: 保存部分variables
import tensorflow as tf
# Two randomly initialised variables, explicitly named so they can later be
# looked up as 'w1:0' and 'w2:0'.
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
# Passing a list saves only the listed variables (here both of them).
saver = tf.train.Saver([w1,w2])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Writes my_test_model-1000.*; global_step is appended to the file name.
saver.save(sess, 'my_test_model',global_step=1000)
3. 导入model
主要包括两步骤:
- create the network
通过meta文件来进行获取原有的关系图。
# Rebuild the saved network structure from the .meta file. The model above was
# saved as 'my_test_model' with global_step=1000, hence this file name.
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
- load the parameters
with tf.Session() as sess:
    # Rebuild the graph saved above as 'my_test_model' (global_step=1000),
    # which contains the w1 and w2 variables.
    new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')
    # Load the saved values of w1 and w2 from the latest checkpoint in './'.
    new_saver.restore(sess, tf.train.latest_checkpoint('./'))
    print(sess.run('w1:0'))  # the value of w1 stored in the saved model
4. 基于载入model的操作
常见的用已经训练好的模型进行prediction, fine-tuning, further training.
关于get_tensor_by_name中张量名称的理解:
# w1:0
# <name>:<index>, where index selects one of the op's outputs (0 = the first).
# 'w1' is a node (operation) name, while 'w1:0' is a tensor name: the first
# output tensor of node 'w1'.
tensor = tf.get_default_graph().get_tensor_by_name("w1:0")
(1) 取得已保存的 variable、tensor、placeholders、operation
# `graph` must be bound first: it is the default graph into which
# import_meta_graph loaded the saved model.
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
(2) 基于原有网络图,喂入新数据
- 载入meta graph并恢复权重——取得placeholder并为新数据创建feed_dict——取得要运行的operation——运行
import tensorflow as tf

sess = tf.Session()
# First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

# Now, let's access and create placeholders variables and
# create feed-dict to feed new data
# NOTE(review): this follows the cv-tricks tutorial, whose model has scalar
# placeholders w1/w2 and an op named 'op_to_restore'. The model saved earlier
# in these notes (w1/w2 as Variables of shape [2] and [5], no 'op_to_restore')
# does not match, so the lookups/feeds below would likely fail against it.
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

# Now, access the op that you want to run.
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print(sess.run(op_to_restore, feed_dict))
# This will print 60 which is calculated
# using new values of w1 and w2 and saved value of b1.
(3) 在原有图关系基础上添加更多的operation
# Add more to the current graph: a new op that doubles the restored op's output.
add_on_op = tf.multiply(op_to_restore,2)
(4) 进行fine-tuning(微调)
e.g. 对原有的vgg网络图,将最后输出层更改成2个,并用新数据微调。
# Fine-tuning example: reuse a saved VGG graph up to fc7 and attach a new
# 2-way output layer. Assumes 'vgg.meta' exists and its graph contains a
# tensor named 'fc7:0'.
saver = tf.train.import_meta_graph('vgg.meta')
# Access the graph
graph = tf.get_default_graph()
## Prepare the feed_dict for feeding data for fine-tuning

# Access the appropriate output for fine-tuning
fc7 = graph.get_tensor_by_name('fc7:0')

# Use this if you only want to change gradients of the last layer:
# stop_gradient passes values through unchanged, but no gradient flows back
# into the pretrained layers.
fc7 = tf.stop_gradient(fc7)  # It's an identity function
fc7_shape = fc7.get_shape().as_list()

new_outputs = 2
# The new layer's width is new_outputs; `num_outputs` was never defined.
# NOTE(review): fc7_shape[3] assumes fc7 is 4-D, but tf.matmul below needs a
# 2-D input. Confirm fc7's rank and flatten it first if necessary.
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], new_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[new_outputs]))

output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)

# Now, you run this with fine-tuning data in sess.run()