tensorflow中的tf.app.run()函数
转载需要注明出处!
# 省略import
if __name__ = '__main__':
tf.app.run()
tf.app.run() 是函数入口,类似于c++中的main()
根据run()函数的源码描述,一般来说,运行一个程序是需要main函数作为入口的,同时都附带了参数argv是用来接收用户输入的,而在tensorflow这里,tf.app.run() 就相当于这个main() 函数,而且还能解析参数。附上run() 函数的源码
def run(main=None, argv=None):
"""Runs the program with an optional 'main' function and 'argv' list."""
f = flags.FLAGS
# Extract the args from the optional `argv` list.
args = argv[1:] if argv else None
# Parse the known flags from that list, or from the command
# line otherwise.
# pylint: disable=protected-access
flags_passthrough = f._parse_flags(args=args)
# pylint: enable=protected-access
main = main or _sys.modules['__main__'].main
# Call the main function, passing through any arguments
# to the final program.
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
_allowed_symbols = [
'run',
# Allowed submodule.
'flags',
]
remove_undocumented(__name__, _allowed_symbols)
运作机制
run() 函数的运作计制是:先加载flags的参数项,然后执行main() 函数,其中参数使用tf.app.flags.FLAGS定义。
那么,flags如何定义参数?
附上tensorflow的flags参数定义方式
import tensorflow as tf
# 定义参数
# 第一个是参数名称, 第二个参数是默认值, 第三个是参数描述
tf.app.flags.DEFINE_string('string', 'myname', 'The type of myname is string')
tf.app.flags.DEFINE_integer('image_size', 32, 'The size of image')
FLAGS = tf.app.flags.FlAGS
# 定义主函数
# 需要传递参数
def main(argv):
print('string: ', FLAGS.string)
print('image_size: ' , FLAGS.image_size)
if __name__ = '__main__':
tf.app.run()
主函数中run会调用main,并传递刚才定义的参数,所以main函数中需要设置参数位置。
有两种情况需要说明:
- 如果你的代码中的入口函数不叫main(),而是一个其他名字的函数,如test(),则你应该这样写入口tf.app.run(test)
- 如果你的代码中的入口函数叫main(),则你就可以把入口写成tf.app.run()
文章引用
参考文章1