tf.app.flags.FLAGS

下面介绍 tf.app.flags.FLAGS 的使用,主要是在用命令行执行程序时,需要传些参数,代码如下:

新建一个名为:app_flags.py 的文件。

[python]  view plain  copy
  1. #coding:utf-8  
  2.   
  3. # 学习使用 tf.app.flags 使用,全局变量  
  4. # 可以再命令行中运行也是比较方便,如果只写 python app_flags.py 则代码运行时默认程序里面设置的默认设置  
  5. # 若 python app_flags.py --train_data_path <绝对路径 train.txt> --max_sentence_len 100  
  6. #    --embedding_size 100 --learning_rate 0.05  代码再执行的时候将会按照上面的参数来运行程序  
  7.   
  8. import tensorflow as tf  
  9.   
  10. FLAGS = tf.app.flags.FLAGS  
  11.   
  12. # tf.app.flags.DEFINE_string("param_name", "default_val", "description")  
  13. tf.app.flags.DEFINE_string("train_data_path""/home/yongcai/chinese_fenci/train.txt""training data dir")  
  14. tf.app.flags.DEFINE_string("log_dir""./logs"" the log dir")  
  15. tf.app.flags.DEFINE_integer("max_sentence_len"80"max num of tokens per query")  
  16. tf.app.flags.DEFINE_integer("embedding_size"50"embedding size")  
  17.   
  18. tf.app.flags.DEFINE_float("learning_rate"0.001"learning rate")  
  19.   
  20.   
  21. def main(unused_argv):  
  22.     train_data_path = FLAGS.train_data_path  
  23.     print("train_data_path", train_data_path)  
  24.     max_sentence_len = FLAGS.max_sentence_len  
  25.     print("max_sentence_len", max_sentence_len)  
  26.     embdeeing_size = FLAGS.embedding_size  
  27.     print("embedding_size", embdeeing_size)  
  28.     abc = tf.add(max_sentence_len, embdeeing_size)  
  29.   
  30.     init = tf.global_variables_initializer()  
  31.   
  32.     #with tf.Session() as sess:  
  33.         #sess.run(init)  
  34.         #print("abc", sess.run(abc))  
  35.   
  36.     sv = tf.train.Supervisor(logdir=FLAGS.log_dir, init_op=init)  
  37.     with sv.managed_session() as sess:  
  38.         print("abc:", sess.run(abc))  
  39.   
  40.         # sv.saver.save(sess, "/home/yongcai/tmp/")  
  41.   
  42.   
  43. # 使用这种方式保证了,如果此文件被其他文件 import的时候,不会执行main 函数  
  44. if __name__ == '__main__':  
  45.     tf.app.run()   # 解析命令行参数,调用main 函数 main(sys.argv)  


调用方法:

其中参数可以根据需求进行修改。

[python]  view plain  copy
  1. python app_flags.py --train_data_path <绝对路径 train.txt> --max_sentence_len 100 --embedding_size 100 --learning_rate 0.05  

如果这样调用:

[python]  view plain  copy
  1. python app_flags.py  

则会执行程序时会自动调用程序中 default 中的参数。

猜你喜欢

转载自blog.csdn.net/zlrai5895/article/details/80462937