1.os.path.join()函数表示把参数字符串按照路径命名规则拼接。
例如: import os
os.path.join('/hello/','good/boy/','doiido') 输出结果:'/hello/good/boy/doiido'
2.字符串.split( )函数表示按照指定“拆分符”对字符串拆分,返回拆分列表。 例如:
'./model/mnist_model-1001'.split('/')[-1].split('-')[-1] 在该例子中,共进行两次拆分。第一个拆分符为‘/’,返回拆分列表,并提取 列表中索引为-1 的元素即倒数第一个元素;第二个拆分符为‘-’,返回拆分列 表,并提取列表中索引为-1 的元素即倒数第一个元素,故函数返回值为 1001。
3.tf.Graph( ).as_default( )函数表示将当前图设置成为默认图,并返回一 个上下文管理器。该函数一般与 with 关键字搭配使用,应用于将已经定义好 的神经网络在计算图中复现。 例如: with tf.Graph().as_default() as g,表示将在 Graph()内定义的节点加入到 计算图 g 中。
4.tf.train.Saver()用来实例化 saver 对象。神经网络每循 环规定的轮数,将神经网络模型中所有的参数等信息保存到指定的路径中,并在 存放网络模型的文件夹名称中注明保存模型时的训练轮数。
5.存储网络模型
-
saver = tf.train.Saver() with tf.Session() as sess: for i in range(STEPS): if i % 轮数 == 0: saver.save(sess, os.path.join( MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
6.加载已有网络模型:
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(存储路径)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
7.tf.argmax()函数取出对应向量中最大值元素对应的索引值