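In TensorFlow, session.run() feeds data into the computation graph, runs the computation, and returns the values of the requested tensors/variables/placeholders.
While reading the code that accompanies a paper I ran into a complicated feed_dict, so this post is my review of sess.run().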
1. TensorFlow Session.run()
The definition of session.run() is shown below. You can view it in IPython with `sess = tf.Session()` followed by `?sess.run`, or in the source code (line 846). First, the function's parameters:
run(self, fetches, feed_dict=None, options=None, run_metadata=None)
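Of these, fetches and feed_dict are the arguments you pass most often. fetches lists the graph elements (tensors, variables, operations) whose computed results are retrieved from the graph and returned, while feed_dict supplies data to the corresponding placeholders in the graph. feed_dict is a dictionary, and its values take effect only for that one call.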
For more explanation, see the fetch and feed sections at the very bottom of this example; the original definitions are in make_callable.
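To make the roles of fetches and feed_dict concrete without a TensorFlow install, here is a toy dataflow graph in plain Python. It is a hypothetical sketch (`Node` and `run` are made-up names, not TensorFlow classes) that only mimics the documented behaviour: run() computes just the nodes needed for the fetches, and fed values replace whatever a node would compute, for that call only. Where real TensorFlow raises an error for a required placeholder that was not fed, the toy raises ValueError.

```python
class Node:
    """Toy graph node: a placeholder (fn is None) or an op computed from its inputs."""

    def __init__(self, name, fn=None, inputs=()):
        self.name = name
        self.fn = fn
        self.inputs = list(inputs)
        self.calls = 0  # how many times this node has actually been computed


def run(fetches, feed_dict=None):
    """Evaluate a list of Nodes; fed values override computed ones for this call only."""
    cache = dict(feed_dict or {})  # a fresh cache per call, seeded with the fed values

    def evaluate(node):
        if node in cache:
            return cache[node]
        if node.fn is None:
            raise ValueError("placeholder %r must be fed" % node.name)
        value = node.fn(*[evaluate(i) for i in node.inputs])  # only walks what is needed
        node.calls += 1
        cache[node] = value
        return value

    return [evaluate(node) for node in fetches]


# Tiny graph: x is a placeholder, y = 2 * x, z = y + 1, w = x - 100
x = Node("x")
y = Node("y", lambda v: 2 * v, [x])
z = Node("z", lambda v: v + 1, [y])
w = Node("w", lambda v: v - 100, [x])

print(run([z], feed_dict={x: 3}))         # [7]: x is fed 3, so y = 6 and z = 7
print(w.calls)                            # 0: w is not needed for z, so it is never computed
print(run([z], feed_dict={x: 3, y: 10}))  # [11]: feeding y overrides its computed value
```

The per-call `cache` is the toy's stand-in for "feed_dict only applies within that call": the next run() starts from the graph definition again.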
Next, let's look at how the official source code documents run():
def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
"""Runs operations and evaluates tensors in `fetches`.
This method runs one "step" of TensorFlow computation, by
running the necessary graph fragment to execute every `Operation`
and evaluate every `Tensor` in `fetches`, substituting the values in
`feed_dict` for the corresponding input values.
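Note: this method performs one computation step in TensorFlow: it feeds the data in feed_dict into the graph, runs the operations the graph defines, and returns the evaluated values of the tensors in fetches.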
The `fetches` argument may be a single graph element, or an arbitrarily
nested list, tuple, namedtuple, dict, or OrderedDict containing graph
elements at its leaves. A graph element can be one of the following types:
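Note: fetches names the graph elements to retrieve from the graph. It can be a single graph element or an arbitrarily nested list, tuple, dict, etc. whose leaves are graph elements.
Graph elements include operations, tensors, sparse tensors, tensor handles, and strings (names of tensors or operations).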
* A `tf.Operation`.
The corresponding fetched value will be `None`.
* A `tf.Tensor`.
The corresponding fetched value will be a numpy ndarray containing the
value of that tensor.
* A `tf.SparseTensor`.
The corresponding fetched value will be a
`tf.compat.v1.SparseTensorValue`
containing the value of that sparse tensor.
* A `get_tensor_handle` op. The corresponding fetched value will be a
numpy ndarray containing the handle of that tensor.
* A `string` which is the name of a tensor or operation in the graph.
The value returned by `run()` has the same shape as the `fetches` argument,
where the leaves are replaced by the corresponding values returned by
TensorFlow.
Note: the value returned by run() has the same structure as fetches.
Example:
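```python
import collections
import tensorflow as tf  # TF 1.x API; on TF 2.x use tf.compat.v1 and call disable_eager_execution() first

a = tf.constant([10, 20])
b = tf.constant([1.0, 2.0])
session = tf.Session()
# 'fetches' can be a singleton
v = session.run(a)
# v is the numpy array [10, 20]  (a single element as the fetch)
# 'fetches' can be a list.
v = session.run([a, b])  # values fetched back as a list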
# v is a Python list with 2 numpy arrays: the 1-D array [10, 20] and the
# 1-D array [1.0, 2.0]
# 'fetches' can be arbitrary lists, tuples, namedtuple, dicts:
MyData = collections.namedtuple('MyData', ['a', 'b'])
v = session.run({'k1': MyData(a, b), 'k2': [b, a]})
# v is a dict with
# v['k1'] is a MyData namedtuple with 'a' (the numpy array [10, 20]) and
# 'b' (the numpy array [1.0, 2.0])
# v['k2'] is a list with the numpy array [1.0, 2.0] and the numpy array
# [10, 20].
```
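Note: feed_dict lets the values you pass override tensors defined in the graph. The values must be convertible to the tensors' dtype, and the override only takes effect within that call.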
The optional `feed_dict` argument allows the caller to override
the value of tensors in the graph. Each key in `feed_dict` can be
one of the following types:
* If the key is a `tf.Tensor`, the
value may be a Python scalar, string, list, or numpy ndarray
that can be converted to the same `dtype` as that
tensor. Additionally, if the key is a
`tf.compat.v1.placeholder`, the shape of
the value will be checked for compatibility with the placeholder.
* If the key is a
`tf.SparseTensor`,
the value should be a
`tf.compat.v1.SparseTensorValue`.
* If the key is a nested tuple of `Tensor`s or `SparseTensor`s, the value
should be a nested tuple with the same structure that maps to their
corresponding values as above.
Each value in `feed_dict` must be convertible to a numpy array of the dtype
of the corresponding key.
The optional `options` argument expects a [`RunOptions`] proto. The options
allow controlling the behavior of this particular step (e.g. turning tracing
on).
The optional `run_metadata` argument expects a [`RunMetadata`] proto. When
appropriate, the non-Tensor output of this step will be collected there. For
example, when users turn on tracing in `options`, the profiled info will be
collected into this argument and passed back.
#---------------------- Commonly used arguments --------------------#
fetches: graph elements whose run results you want to retrieve
feed_dict: a dictionary mapping graph elements to their values
Args:
fetches: A single graph element, a list of graph elements, or a dictionary
whose values are graph elements or lists of graph elements (described
above).
feed_dict: A dictionary that maps graph elements to values (described
above).
options: A [`RunOptions`] protocol buffer
run_metadata: A [`RunMetadata`] protocol buffer
Returns:
Either a single value if `fetches` is a single graph element, or
a list of values if `fetches` is a list, or a dictionary with the
same keys as `fetches` if that is a dictionary (described above).
Order in which `fetches` operations are evaluated inside the call
is undefined.
Raises:
RuntimeError: If this `Session` is in an invalid state (e.g. has been
closed).
TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type.
ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a
`Tensor` that doesn't exist.
"""
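The docstring promises that run() returns a value "with the same shape as the `fetches` argument". A minimal pure-Python sketch of that contract follows: it rebuilds a nested list/tuple/namedtuple/dict, replacing each leaf with its evaluated value. `map_fetches` is a hypothetical helper, not TensorFlow's internal fetch handling, and the `values` dict is a made-up stand-in for the evaluated tensors `a` and `b` from the docstring example.

```python
import collections


def map_fetches(fetches, evaluate):
    """Return `fetches` with the same nesting, every leaf replaced by evaluate(leaf)."""
    if isinstance(fetches, dict):
        # dict / OrderedDict: keep the keys and the mapping type
        return type(fetches)((k, map_fetches(v, evaluate)) for k, v in fetches.items())
    if isinstance(fetches, tuple) and hasattr(fetches, "_fields"):
        # namedtuple: rebuild positionally so the field names survive
        return type(fetches)(*(map_fetches(v, evaluate) for v in fetches))
    if isinstance(fetches, (list, tuple)):
        return type(fetches)(map_fetches(v, evaluate) for v in fetches)
    return evaluate(fetches)  # a leaf: one graph element


# Stand-ins for the evaluated tensors a and b of the docstring example
values = {"a": [10, 20], "b": [1.0, 2.0]}
MyData = collections.namedtuple("MyData", ["a", "b"])

v = map_fetches({"k1": MyData("a", "b"), "k2": ["b", "a"]}, values.__getitem__)
print(v)  # {'k1': MyData(a=[10, 20], b=[1.0, 2.0]), 'k2': [[1.0, 2.0], [10, 20]]}
```

This is why `summary, step, _, pred_val, gen_loss_emd = sess.run([...])` works in the next section: a list of five fetches comes back as a list of five values.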
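2. Code example
The following code comes from the main script of PU-Net, whose author defines many custom ops. Let's focus on the feed_dict marked "look here" in the code: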
```python
# copy from: https://github.com/yulequan/PU-Net/blob/master/code/main.py
# (Python 2 `print` statement changed to print(); time, np, pc_util and log_string come from main.py)
def train_one_epoch(sess, ops, fetchworker, train_writer):
    loss_sum = []
    fetch_time = 0
    for batch_idx in range(fetchworker.num_batches):
        start = time.time()
        batch_input_data, batch_data_gt, radius = fetchworker.fetch()
        end = time.time()
        fetch_time += end - start
        feed_dict = {ops['pointclouds_pl']: batch_input_data,  # <<<<<<<<< look here ----------
                     ops['pointclouds_gt']: batch_data_gt[:, :, 0:3],
                     ops['pointclouds_gt_normal']: batch_data_gt[:, :, 0:3],
                     ops['pointclouds_radius']: radius}
        summary, step, _, pred_val, gen_loss_emd = sess.run(
            [ops['pretrain_merged'], ops['step'], ops['pre_gen_train'],
             ops['pred'], ops['gen_loss_emd']], feed_dict=feed_dict)
        # A feed_dict is built here: its keys are ops['xxx'] and its values are the matching input data.
        # The fetches form one list passed to run(), which returns the summary, the step,
        # the result of the training op pre_gen_train (discarded as _), the prediction and the loss.
        # To see exactly which tensors/variables are fed and fetched, here is how ops is defined:
        # ops = {'pointclouds_pl': pointclouds_pl,
        #        'pointclouds_gt': pointclouds_gt,
        #        'pointclouds_gt_normal': pointclouds_gt_normal,
        #        'pointclouds_radius': pointclouds_radius,
        #        'pointclouds_image_input': pointclouds_image_input,
        #        'pointclouds_image_pred': pointclouds_image_pred,
        #        'pointclouds_image_gt': pointclouds_image_gt,
        #        'pretrain_merged': pretrain_merged,
        #        'image_merged': image_merged,
        #        'gen_loss_emd': gen_loss_emd,
        #        'pre_gen_train': pre_gen_train,
        #        'pred': pred,
        #        'step': step,
        #        }
        # It holds the input/output placeholders, the corresponding graph elements,
        # and variables that track training progress, such as step.
        train_writer.add_summary(summary, step)
        loss_sum.append(gen_loss_emd)

        if step % 30 == 0:
            pointclouds_image_input = pc_util.point_cloud_three_views(batch_input_data[0, :, 0:3])
            pointclouds_image_input = np.expand_dims(np.expand_dims(pointclouds_image_input, axis=-1), axis=0)
            pointclouds_image_pred = pc_util.point_cloud_three_views(pred_val[0, :, :])
            pointclouds_image_pred = np.expand_dims(np.expand_dims(pointclouds_image_pred, axis=-1), axis=0)
            pointclouds_image_gt = pc_util.point_cloud_three_views(batch_data_gt[0, :, 0:3])
            pointclouds_image_gt = np.expand_dims(np.expand_dims(pointclouds_image_gt, axis=-1), axis=0)
            # The next two statements fetch one merged visualization summary from the graph,
            # feeding the three corresponding images to compute it.
            feed_dict = {ops['pointclouds_image_input']: pointclouds_image_input,
                         ops['pointclouds_image_pred']: pointclouds_image_pred,
                         ops['pointclouds_image_gt']: pointclouds_image_gt,
                         }
            summary = sess.run(ops['image_merged'], feed_dict)
            train_writer.add_summary(summary, step)

    loss_sum = np.asarray(loss_sum)
    log_string('step: %d mean gen_loss_emd: %f\n' % (step, round(loss_sum.mean(), 4)))
    print('read data time: %s mean gen_loss_emd: %f' % (round(fetch_time, 4), round(loss_sum.mean(), 4)))
```
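The feed_dict above slices `batch_data_gt[:, :, 0:3]` before feeding it. The NumPy sketch below shows what that slice does and mimics the feed rules from the docstring: each value must be convertible to the key's dtype, and for a placeholder its shape must be compatible, with `None` matching any size in that dimension. `check_feed` is a hypothetical helper, not TensorFlow's own check, and the 2 x 4 x 6 array is made-up data; nothing here assumes PU-Net's real array shapes.

```python
import numpy as np


def check_feed(value, dtype, placeholder_shape):
    """Hypothetical check mirroring the documented feed_dict rules.

    placeholder_shape=None means "shape unknown": any shape is accepted.
    """
    array = np.asarray(value, dtype=dtype)  # fails if value cannot be converted to dtype
    if placeholder_shape is not None:
        if array.ndim != len(placeholder_shape):
            raise ValueError("rank mismatch: %s vs %s" % (array.shape, placeholder_shape))
        for actual, expected in zip(array.shape, placeholder_shape):
            if expected is not None and actual != expected:
                raise ValueError("shape %s incompatible with %s" % (array.shape, placeholder_shape))
    return array


# Made-up batch: 2 point clouds x 4 points x 6 channels per point
batch_data_gt = np.arange(2 * 4 * 6, dtype=np.float64).reshape(2, 4, 6)
xyz = batch_data_gt[:, :, 0:3]  # keep only the first three channels of every point
print(xyz.shape)                # (2, 4, 3)

fed = check_feed(xyz, np.float32, (None, 4, 3))  # batch dimension left open with None
print(fed.dtype, fed.shape)     # float32 (2, 4, 3)
```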
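The example above shows that run() uses the data fed through the feed_dict dictionary to compute along the graph, then returns the computed results for the elements of fetches, completing one run.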
References:
w3cschool: https://www.w3cschool.cn/tensorflow_python/tensorflow_python-slp52jz8.html
w3cschool: https://www.w3cschool.cn/tensorflow_python/tensorflow_python-fibz28ss.html
Blog introduction: https://www.cnblogs.com/gengyi/p/9865915.html
Blog explanation by learner_ctr: https://blog.csdn.net/a1066196847/article/details/84104655
A video explanation: https://www.aiworkbox.com/lessons/use-feed_dict-to-feed-values-to-tensorflow-placeholders
TensorFlow Java API docs: https://www.tensorflow.org/api_docs/java/reference/org/tensorflow/Session.Run.html#Session.Run()
TensorFlow Java API docs: https://www.tensorflow.org/api_docs/java/reference/org/tensorflow/Session.Runner.html#runAndFetchMetadata()
cnblogs: https://www.cnblogs.com/yao62995/p/5773043.html