Counting the trainable_variables in a TensorFlow graph

Original post: https://blog.csdn.net/shwan_ma/article/details/78879620. Copyright belongs to the original author.

The original author explained this well, so I am recording the commonly used methods here for future study and reference.

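# Initialize all variables before reading them (TensorFlow 1.x API).
sess.run(tf.global_variables_initializer())

# Collect and print the names of all trainable variables in the graph.
variable_names = [v.name for v in tf.trainable_variables()]

print(variable_names)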

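# Print the name, shape, and current value of every trainable variable.
variable_names = [v.name for v in tf.trainable_variables()]

# sess.run accepts tensor names (strings) as fetches and returns numpy arrays.
values = sess.run(variable_names)

for k, v in zip(variable_names, values):
    print("Variable: ", k)
    print("Shape: ", v.shape)
    print(v)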

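# Count the total number of trainable parameters in the graph.
total_parameters = 0

for variable in tf.trainable_variables():
    shape = variable.get_shape()
    # The parameter count of one variable is the product of its dimensions.
    variable_parameters = 1
    for dim in shape:
        variable_parameters *= dim.value
    total_parameters += variable_parameters

print("Total parameters: ", total_parameters)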
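
The parameter count above comes down to multiplying the dimensions of each variable and summing the products. Below is a minimal sketch of that arithmetic in plain Python, without TensorFlow, assuming the shapes have already been extracted as tuples of ints (for example with variable.get_shape().as_list()). The helper name count_parameters and the example shapes are hypothetical and do not come from the original post.

```python
from functools import reduce
from operator import mul


def count_parameters(shapes):
    """Return the total number of elements across all given shapes."""
    total = 0
    for shape in shapes:
        # Product of the dimensions; an empty shape (a scalar) counts as 1.
        total += reduce(mul, shape, 1)
    return total


# Hypothetical shapes: a 3x3 convolution kernel with 3 input channels
# and 64 output channels, plus its bias vector.
example_shapes = [(3, 3, 3, 64), (64,)]
total_parameters = count_parameters(example_shapes)
print("Total parameters: ", total_parameters)
```

Keeping the arithmetic in a separate function makes it easy to check by hand on a small list of shapes before applying it to a real graph.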


Reposted from blog.csdn.net/u010454261/article/details/81121836