版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/MOU_IT/article/details/82764699
Slim.data模块下包含很多关于数据集处理的模块,包括:dataset、data_decoder、prefetch_queue、dataset_data_provider、tfexample_decoder、data_provider、parallel_reader,下面我们一次介绍。
1、Slim.dataset
Dataset类的源码如下:
"""Dataset包含一个数据集的定义,一个数据集通常是几个组件组成
(1) 数据源的文件列表
(2) 一个读取器reader,它能读取数据源并返回编码后的数据样本
(3) 一个解码器decoder,它可以解码由reader提供的每个数据样本
(4) 数据集的样本总数
(5) 一个可选的字典,它是一个items的列表和items的描述的映射
通过使用一个dataset_data_provider,数据可以从一个指定的数据集加载,比如:
dataset = CreateMyDataset(...)
provider = dataset_data_provider.DatasetDataProvider(dataset, shuffle=False)
image, label = provider.get(['image', 'label'])
查看 slim.data.dataset_data_provider 获取更过例子。
"""
class Dataset(object):
"""Represents a Dataset specification."""
def __init__(self, data_sources, reader, decoder, num_samples,
items_to_descriptions, **kwargs):
"""初始化 dataset.
参数:
data_sources: 组成数据集的文件列表
reader: 一个reader类, 比如TextLineReader或者TFRecordReader.
decoder: 一个data_decoder类的实例
num_samples: 数据集中样本总数
items_to_descriptions: 一个items列表到items描述的字典映射
**kwargs: Any remaining dataset-specific fields.
"""
kwargs['data_sources'] = data_sources
kwargs['reader'] = reader
kwargs['decoder'] = decoder
kwargs['num_samples'] = num_samples
kwargs['items_to_descriptions'] = items_to_descriptions
self.__dict__.update(kwargs)
2、Slim.data_decoder
这是一个抽象类,它的源码如下:
"""
当data provider从磁盘读取数据时,data decoder负责对数据进行解码。一般情况下,需要提供给decoder一个序列化的或者
编码后的数据和一个items列表,decoder处理后返回一个tensors的集合,每个tensors对应数据中提取出来的items。比如,如果数据是
一个压缩字典,decoder的实现过程可能是:
def Decode(self, data, items):
decompressed_map = _Decompress(data)
outputs = []
for item in items:
outputs.append(decompressed_map[item])
return outputs.
"""
import abc
class DataDecoder(object):
"""一个抽象类,用来解码provider提供的数据"""
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
def decode(self, data, items):
"""对data进行解码,返回由items列表指定的Tensor
参数:
data: 编码的数据
items: 字符串列表,每个表示一个特殊的数据类型A
返回值:
Tensors的列表,它的长度和items的长度相同,每个tensor对应一个item
"""
pass
@abc.abstractmethod
def list_items(self):
"""列出decoder可以解码的items的名字L
返回值:
字符串名的列表
"""
pass
3、Slim.tfexample_decoder.TFExampleDecoder
TFExampleDecoder是一个数据解码器,它主要是用来解码TFrecord文件中的tf.train.Example,TFExampleDecoder类定义源码如下:
"""这个文件包含TFExampleDecode的解码器类和与这个类相关联的帮助类。TFExampleDecode是一个用于对TensorFlow样本进行解码
的解码器。解码过程中,每个请求的item必须与一个或多个样本配对,每个样本被解析并生成基于Tensor的item的表现形式。
"""
import abc
from tensorflow.contrib.slim.python.slim.data import data_decoder
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import sparse_ops
class ItemHandler(object):
"""为tf.parse_example指定item-to-Features的映射,一个ItemHandler同时指定了用来解析样本的特征列表和一个用于对样本解析结果进行后处理的函数
"""
__metaclass__ = abc.ABCMeta
def __init__(self, keys):
"""用所使用的tf.feature关键字的名称来构建一个handler
参数:
keys: tf.train.Example的关键字
"""
if not isinstance(keys, (tuple, list)):
keys = [keys]
self._keys = keys
@property
def keys(self):
return self._keys
@abc.abstractmethod
def tensors_to_item(self, keys_to_tensors):
"""把所给的tensor字典映射为所请求的item
参数:
keys_to_tensors: 一个TF样本关键字到解析的tensors的一个字典
返回值:
最终的tensor,每个tensor是对应item被处理的表征。
"""
pass
class ItemHandlerCallback(ItemHandler):
""""一个ItemHandler通过一个给定的函数把对解析后的tensors进行转换,不同于其他ItemHandlers,ItemHandlerCallback不是通过预指定的函数对结果进行解析,而是通过一个回调函数对结果解析。An
"""
class BoundingBox(ItemHandler):
"""一个ItemHandler,它将一组已解析的Tensors连接到Bounding Boxes。
"""
class LookupTensor(Tensor):
"""An ItemHandler that returns a parsed Tensor, the result of a lookup."""
class BackupHandler(ItemHandler):
"""An ItemHandler that tries two ItemHandlers in order."""
class SparseTensor(ItemHandler):
"""An ItemHandler for SparseTensors."""
class Tensor(ItemHandler):
"""一个ItemHandler,它返回一个解析后的Tensor"""
def __init__(self, tensor_key, shape_keys=None, shape=None, default_value=0):
"""初始化Tensor handler.
默认情况下,返回的Tensors没有任何reshape,然而,这里有两种机制允许reshape发生。如果`shape_keys`不为None,则与‘tensor_key’相关的‘Tensor’和‘shape_keys’被加载,前一个tensor用后一个‘tensor_keys’来reshape。当‘shape’不为None时,则与‘tensor_key’相关的‘Tensor’被加载,然后被reshape。如果既没有提供`shape_keys`也没有`shape`,那么`Tensor`没有任何reshape就返回了。
参数:
tensor_key:tf.train.Example的的关键字名称。
shape_keys: 可选名称或存储张量形状的tf.train.Example的名称列表。 如果是列表,则每个列表对应于形状的一个维度。
shape: 可选的关于tensor输出形状。
default_value: 当 `tensor_key`在tf.train.Example没有存在时,该值被使用。
"""
if shape_keys and shape is not None:
raise ValueError('Cannot specify both shape_keys and shape parameters.')
if shape_keys and not isinstance(shape_keys, list):
shape_keys = [shape_keys]
self._tensor_key = tensor_key
self._shape_keys = shape_keys
self._shape = shape
self._default_value = default_value
keys = [tensor_key]
if shape_keys:
keys.extend(shape_keys)
super(Tensor, self).__init__(keys)
def tensors_to_item(self, keys_to_tensors):
tensor = keys_to_tensors[self._tensor_key]
shape = self._shape
if self._shape_keys:
shape_dims = []
for k in self._shape_keys:
shape_dim = keys_to_tensors[k]
if isinstance(shape_dim, sparse_tensor.SparseTensor):
shape_dim = sparse_ops.sparse_tensor_to_dense(shape_dim)
shape_dims.append(shape_dim)
shape = array_ops.reshape(array_ops.stack(shape_dims), [-1])
if isinstance(tensor, sparse_tensor.SparseTensor):
if shape is not None:
tensor = sparse_ops.sparse_reshape(tensor, shape)
tensor = sparse_ops.sparse_tensor_to_dense(tensor, self._default_value)
else:
if shape is not None:
tensor = array_ops.reshape(tensor, shape)
return tensor
class Image(ItemHandler):
"""一个ItemHandler,用来把一个解析的tensor解码为一张图像 """
def __init__(self,image_key=None, format_key=None,
shape=None,channels=3,dtype=dtypes.uint8,
repeated=False, dct_method=''):
"""初始化图像.
参数:
image_key: tf.train.Example中编码图像的关键字名称
format_key:tf.train.Example中图像格式的关键字名称
shape: 图像的输出形状[height, width, channels]. 如果不为空,则图像被reshape.如果为空,则不被reshape。当且仅当所有保存的图像有相同的图像时,shape才不为空。
channels:图像的通道数
dtype: 图像将在此位深度处被解码,不同的format支持不同的位深度。参考tf.image.decode_image, tf.decode_raw,
repeated: 如果为False,则解码单个图像。 如果为True,则从1D张量的字符串中解码可变数量的图像字符串。
dct_method: 一个可选的字符串,默认为空字符串。当图像时就jpeg时,它才起作用。它用于指定有关用于jpeg解压缩的算法的提示. 当前的有效值是['INTEGER_FAST', 'INTEGER_ACCURATE'].
"""
if not image_key:
image_key = 'image/encoded'
if not format_key:
format_key = 'image/format'
super(Image, self).__init__([image_key, format_key])
self._image_key = image_key
self._format_key = format_key
self._shape = shape
self._channels = channels
self._dtype = dtype
self._repeated = repeated
self._dct_method = dct_method
def tensors_to_item(self, keys_to_tensors):
image_buffer = keys_to_tensors[self._image_key]
image_format = keys_to_tensors[self._format_key]
if self._repeated:
return functional_ops.map_fn(lambda x: self._decode(x, image_format),
image_buffer, dtype=self._dtype)
else:
return self._decode(image_buffer, image_format)
def _decode(self, image_buffer, image_format):
"""解码图像缓存.
参数:
image_buffer: 表示编码图像张量。
image_format: `image_buffer`中图像的图像格式。 如果图像格式是“raw”,则所有图像都应采用此格式,否则此操作可以解码混合的“jpg”和“png”格式。
返回值:
代表解码后的图像的tensor,其形状为self._shape。如果self._shape没有指定的话,返回的tensor的形状为(?, ?, self._channels)
"""
def decode_image():
"""根据标题对图像进行解码."""
return math_ops.cast(
image_ops.decode_image(image_buffer, channels=self._channels),
self._dtype)
def decode_jpeg():
"""使用指定的'_dct_method'解码jpeg图像."""
return math_ops.cast(
image_ops.decode_jpeg(
image_buffer,
channels=self._channels,
dct_method=self._dct_method), self._dtype)
def check_jpeg():
"""检查图像是否为jpeg."""
# 对于jpeg图像,我们直接使用image_ops.decode_jpeg,而不解码图像。
return control_flow_ops.cond(
image_ops.is_jpeg(image_buffer),decode_jpeg,
decode_image,name='cond_jpeg')
def decode_raw():
"""解码原始图像."""
return parsing_ops.decode_raw(image_buffer, out_type=self._dtype)
pred_fn_pairs = {
math_ops.logical_or(
math_ops.equal(image_format, 'raw'),
math_ops.equal(image_format, 'RAW')): decode_raw,
}
image = control_flow_ops.case(
pred_fn_pairs, default=check_jpeg, exclusive=True)
image.set_shape([None, None, self._channels])
if self._shape is not None:
image = array_ops.reshape(image, self._shape)
return image
class TFExampleDecoder(data_decoder.DataDecoder):
"""一个TensorFlow样本的数据解码器,解码样本时由以下两个阶段组成:
(1) 样本解析 (2)tensor操作
第一阶段,tf.parse_example()函数被调用,传入的参数为FixedLenFeatures和SparseLenFeaturesIn的列表。 这些实例告诉TF如何解析样本。第一阶段的输出是一个tensors的集合。
第二阶段,对第一阶段生成的tensors进行操作,然后生成所请求的‘items’的tensors。为了执行此解码操作,给ExampleDecoder一个ItemHandler列表。 每个ItemHandler指示第1阶段的功能集,并包含对第2阶段的张量进行post_processing的说明。
"""
def __init__(self, keys_to_features, items_to_handlers):
"""创建解码器
参数:
keys_to_features: 这是一个字典,它把tf.train.Example关键字映射为tf.VarLenFeature或者tf.FixedLenFeaturea实例。
items_to_handlers: 这也是一个字典,它把items(string)映射为ItemHandler实例。注意,需要提供keys_to_features中的关键字给ItemHandler,它用关键字来返回最终的Items tensor。
"""
self._keys_to_features = keys_to_features
self._items_to_handlers = items_to_handlers
def list_items(self):
return list(self._items_to_handlers.keys())
def decode(self, serialized_example, items=None):
"""对给定的序列化tf.train.Example进行解码.
参数:
serialized_example: 一个序列化的tf.train.Example张量.
items: 需要被解码的items列表。这些items必须是self._items_to_handlers的子集。如果items为空,那么所有的在self._items_to_handlers中的items都会被解码。
返回值:
解码的items,一个tensor的列表
"""
example = parsing_ops.parse_single_example(serialized_example,self._keys_to_features)
# Reshape non-sparse elements just once, adding the reshape ops in
# deterministic order.
for k in sorted(self._keys_to_features):
v = self._keys_to_features[k]
if isinstance(v, parsing_ops.FixedLenFeature):
example[k] = array_ops.reshape(example[k], v.shape)
if not items:
items = self._items_to_handlers.keys()
outputs = []
for item in items:
handler = self._items_to_handlers[item]
keys_to_tensors = {key: example[key] for key in handler.keys}
outputs.append(handler.tensors_to_item(keys_to_tensors))
return outputs
4、Slim.data_provider
"""
DataProvider类从一些数据源中读取数据,并且提供预定义的数据类型。一个DataProvider最基本的函数是一个Get操作,它是一个
请求一种或多种类型的数据或items的操作:provider.get(items=['image', 'sentence', 'class']) 具体而言,一个DataProvider
(BaseDataProvider的子类)为每个请求的item(数据类型)返回一个tensor,比如:
provider = MyDataProvider(...)
image, sentence, clazz = provider.get(['image', 'sentence', 'class'])
在这个例子中,MyDataProvider必须知道如何加载每个item。可以以这样的方式编写DataProvider:从每个item映射到tensor所需的逻辑完全封装在data_provider本身内。
"""
import abc
class DataProvider(object):
"""从一个数据源中,把请求的item列表映射为tensors,所有的DataProvider必须继承这个类,并且实现Get方法,没有假设数据来源和提供数据的机制。
"""
__metaclass__ = abc.ABCMeta
def __init__(self, items_to_tensors, num_samples):
"""构建Data Provider.
参数:
items_to_tensors: 一个name到tensor的字典
num_samples: 数据集所提供的样本总数
"""
self._items_to_tensors = items_to_tensors
self._num_samples = num_samples
def get(self, items):
"""返回一个由所给的items列表指定的tensors列表。items列表是任意的,不同的DataProvider满足不同的item列表。比如,对于Pascal Voc数据集而言,可能接受‘image’和‘semantics’的items,而对于NYUDepthV2数据集而言,可能接受‘image’,‘depths’和‘normals’的items。
参数:
items: 一个字符串列表,每个表示一个特定的数据类型
返回值:
一个tensors的列表,列表长度和items的长度相同,每个tensor对每个item一一对应。
"""
self._validate_items(items)
return [self._items_to_tensors[item] for item in items]
def list_items(self):
"""返回由DataProvider提供的item列表
Returns:
返回一个可以传给Get([items])的item列表。
"""
return self._items_to_tensors.keys()
def num_samples(self):
"""返回数据集的样本总数
Returns:
一个正整数
"""
return self._num_samples
def _validate_items(self, items):
"""验证所给的items是否是一个元组或者列表
参数:
items: 一个字符串元组或列表
"""
if not isinstance(items, (list, tuple)):
raise ValueError('items must be a list or tuple')
valid_items = self.list_items()
for item in items:
if item not in valid_items:
raise ValueError('Item [%s] is invalid. Valid entries include: %s' %(item, valid_items))
5、Slim.parallel_reader
并行读取器的源码如下:
from tensorflow.python.framework import dtypes as tf_dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import gfile
from tensorflow.python.summary import summary
from tensorflow.python.training import input as tf_input
from tensorflow.python.training import queue_runner
class ParallelReader(io_ops.ReaderBase):
"""Reader 类,它使用多个并行的读取器加快速度
"""
def __init__(self,reader_class, common_queue,num_readers=4,reader_kwargs=None):
"""
并行读取器创建了num_reader个读取器类的实例。每个实例通过调用‘reader_class’函数并传入reader_kwargs参数而创建,如下所示:reader_class(**read_kwargs)
当你使用并行读取器的read()方法读取数据时,只需要从‘common_queue’中dequeue一个样本即可。reader将并行读取不同的文件,异步将其输出enqueue到“common_queue”。 `common_queue.dtypes`必须是[tf.string,tf.string]。由于reader从不同的文件读取,因此‘common_queue’中的样本可能来自不同的文件。由于是异步读取,无法保证所有的reader将读取相同的样本个数。如果‘common_queue’是一个混乱队列,那么样本也同样是混乱的。使用方法如下:
common_queue = tf.RandomShuffleQueue(capacity=256,min_after_dequeue=128,dtypes=[tf.string, tf.string])
p_reader = ParallelReader(tf.TFRecordReader, common_queue)
common_queue = tf.FIFOQueue(capacity=256,dtypes=[tf.string, tf.string])
p_reader = ParallelReader(readers, common_queue, num_readers=2)
参数:
reader_class:一个io_ops.ReaderBase的子类,比如:TFRecordReader
common_queue: 一个队列,每个元素是一个(key,value)对,类型为[tf.string, tf.string]. 它必须是一个data_flow_ops.Queues的实例, 比如: `tf.FIFOQueue()`, `tf.RandomShuffleQueue()`,
num_readers: 一个整数值,创建的reader_class的实例个数
reader_kwargs: 一个用于创建reader的可选的参数字典
"""
if len(common_queue.dtypes) != 2:
raise TypeError('common_queue.dtypes must be [tf.string, tf.string]')
for dtype in common_queue.dtypes:
if not dtype.is_compatible_with(tf_dtypes.string):
raise TypeError('common_queue.dtypes must be [tf.string, tf.string]')
reader_kwargs = reader_kwargs or {}
self._readers = [reader_class(**reader_kwargs) for _ in range(num_readers)]
self._common_queue = common_queue
@property
def num_readers(self):
return len(self._readers)
@property
def common_queue(self):
return self._common_queue
def read(self, queue, name=None):
"""返回下一个记录(key,value).conmom_queue中,入队时一个queue runner自动添加到TF QueueRunners的集合中。
参数:
queue: 一个队列
name: (可选)操作名
Returns:
commom_queue中的下一个记录
"""
self._configure_readers_by(queue)
return self._common_queue.dequeue(name=name)
def read_up_to(self, queue, num_records, name=None):
"""返回多达num_records个记录
参数:
queue:一个队列
num_records: 需要读取的record数量
name: (可选)操作名
Returns:
一个Tensors (keys, values)的元组,keys: A 1-D string Tensor,values: A 1-D string Tensor.
"""
self._configure_readers_by(queue)
return self._common_queue.dequeue_up_to(num_records, name)
def _configure_readers_by(self, queue):
enqueue_ops = []
for reader in self._readers:
enqueue_ops.append(self._common_queue.enqueue(reader.read(queue)))
queue_runner.add_queue_runner(
queue_runner.QueueRunner(self._common_queue, enqueue_ops))
def num_records_produced(self, name=None):
"""返回这个reader生成的record的数量。
参数:
name: (可选)操作名
返回值:
一个 int64 Tensor.
"""
num_records = [r.num_records_produced() for r in self._readers]
return math_ops.add_n(num_records, name=name)
def num_work_units_completed(self, name=None):
"""返回这个reader已经完成处理的工作单元
参数:
name:(可选)操作名
Returns:
一个int64 Tensor.
"""
num_work_units = [r.num_work_units_completed() for r in self._readers]
return math_ops.add_n(num_work_units, name=name)
def parallel_read(data_sources, reader_class,num_epochs=None,num_readers=4,
reader_kwargs=None,shuffle=True,dtypes=None,
capacity=256, min_after_dequeue=128,
seed=None,scope=None):
"""使用n个reader从数据源并行读取多个records。用法:
data_sources = ['path_to/train*']
key, value = parallel_read(data_sources, tf.CSVReader, num_readers=4)
参数:
data_sources: 一个文件的列表/元组, 比如:/path/to/train@128, /path/to/train* 或/tmp/.../train*
reader_class:一个io_ops.ReaderBase的子类,比如: TFRecordReader
num_epochs: 每个数据集被读取的次数,如果为空,则无限次读取。
num_readers: reader的数量。
reader_kwargs: 一个reader的可选的参数列表。
shuffle: 是否使用RandomShuffleQueue打乱文件和记录的顺序
dtypes: 类型列表,它的长度应该和每个record中元素的个数相同。如果为空,则默认为 [tf.string, tf.string]
capacity: common_queue的容量
min_after_dequeue: 在dequue之后common_queue中最小的记录个数
seed: 随机数种子
scope: 操作的可选scope
返回值:
key, value: 一个来自数据源的键值对元组
"""
data_files = get_data_files(data_sources)
with ops.name_scope(scope, 'parallel_read'):
filename_queue = tf_input.string_input_producer(
data_files, num_epochs=num_epochs, shuffle=shuffle, seed=seed,
name='filenames')
dtypes = dtypes or [tf_dtypes.string, tf_dtypes.string]
if shuffle:
common_queue = data_flow_ops.RandomShuffleQueue(
capacity=capacity,
min_after_dequeue=min_after_dequeue,
dtypes=dtypes,
seed=seed,
name='common_queue')
else:
common_queue = data_flow_ops.FIFOQueue(capacity=capacity, dtypes=dtypes, name='common_queue')
summary.scalar('fraction_of_%d_full' % capacity,
math_ops.to_float(common_queue.size()) * (1. / capacity))
return ParallelReader(reader_class,common_queue,num_readers=num_readers,
reader_kwargs=reader_kwargs).read(filename_queue)
def single_pass_read(data_sources, reader_class, reader_kwargs=None,scope=None):
"""使用reader顺序读取data_sources,执行单次传递。
Args:
data_sources: 文件列表
reader_class: 一个io_ops.ReaderBase 的子类,比如: TFRecordReader.
reader_kwargs: reader的可选的字典参宿
scope: 可选的操作的scope
Returns:
key, value: 来自数据源的(key,value)元组
"""
data_files = get_data_files(data_sources)
with ops.name_scope(scope, 'single_pass_read'):
filename_queue = tf_input.string_input_producer(
data_files, num_epochs=1, shuffle=False, capacity=1, name='filenames')
reader_kwargs = reader_kwargs or {}
return reader_class(**reader_kwargs).read(filename_queue)
def get_data_files(data_sources):
"""读取数据源获取数据文件列表
参数:
data_sources: 文件列表,比如/path/to/train@128, /path/to/train* 或 /tmp/.../train*
Returns:
数据文件的列表
"""
if isinstance(data_sources, (list, tuple)):
data_files = []
for source in data_sources:
data_files += get_data_files(source)
else:
if '*' in data_sources or '?' in data_sources or '[' in data_sources:
data_files = gfile.Glob(data_sources)
else:
data_files = [data_sources]
if not data_files:
raise ValueError('No data files found in %s' % (data_sources,))
return data_files
6、Slim.dataset_data_provider
"""
DatasetDataProviders从数据集中提供数据. 通过配置,可以同时使用多个readers或者使用单个reader提供数据。此外,被读取的数据
可以被打乱顺序。比如,使用一个单线程读取数据而不打乱顺序的例子如下:
pascal_voc_data_provider = DatasetDataProvider(
slim.datasets.pascal_voc.get_split('train'),shuffle=False)
images, labels = pascal_voc_data_provider.get(['images', 'labels'])
使用多个readers同时读取数据并且打乱顺序的例子如下:
pascal_voc_data_provider = DatasetDataProvider(
slim.datasets.pascal_voc.Dataset(),num_readers=10, shuffle=True)
images, labels = pascal_voc_data_provider.get(['images', 'labels'])
同样地,我们可以分开请求相同样本的不同属性,比如:
[images] = pascal_voc_data_provider.get(['images'])
[labels] = pascal_voc_data_provider.get(['labels'])
"""
from tensorflow.contrib.slim.python.slim.data import data_provider
from tensorflow.contrib.slim.python.slim.data import parallel_reader
class DatasetDataProvider(data_provider.DataProvider):
def __init__(self, dataset, num_readers=1, reader_kwargs=None,
shuffle=True, num_epochs=None,common_queue_capacity=256,
common_queue_min=128, record_key='record_key',
seed=None, scope=None):
"""创建一个DatasetDataProvider.
注意: 如果`num_epochs` 不为 `None`, 局部计数器 `epochs`将会被相关的函数创建. 使用`local_variables_initializer()` 来初始化局部变量。
参数:
dataset: 一个dataset类的实例
num_readers: 使用的并行读取器的数量
reader_kwargs: reader的一个可选的字典参数
shuffle:读取时是否打乱顺序
num_epochs: 每个数据源被读取的次数,如果为None,这个数据集将会被无限次读取。
common_queue_capacity: 公共队列的容量。
common_queue_min: 公共队列出队后的最小元素个数
record_key: 记录关键字
seed: 用于打乱顺序的seed
scope: 可选的该操作scope名
"""
key, data = parallel_reader.parallel_read(
dataset.data_sources, reader_class=dataset.reader,
num_epochs=num_epochs, num_readers=num_readers,
reader_kwargs=reader_kwargs,shuffle=shuffle,
capacity=common_queue_capacity,
min_after_dequeue=common_queue_min,
seed=seed, scope=scope)
items = dataset.decoder.list_items()
tensors = dataset.decoder.decode(data, items)
items_to_tensors = dict(zip(items, tensors))
if record_key in items_to_tensors:
raise ValueError('The item name used for `record_key` cannot also be '
'used for a dataset item: %s', record_key)
items_to_tensors[record_key] = key
super(DatasetDataProvider, self).__init__(
items_to_tensors=items_to_tensors,
num_samples=dataset.num_samples)
7、Slim.prefetch_queue
一个简单的预取队列prefetch_queue的实现。
from tensorflow.python.framework import ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.summary import summary
from tensorflow.python.training import queue_runner
def _which_queue(dynamic_pad):
return (data_flow_ops.PaddingFIFOQueue if dynamic_pad else data_flow_ops.FIFOQueue)
def prefetch_queue(tensors,capacity=8,num_threads=1,dynamic_pad=False,
shared_name=None,name=None):
"""创建一个队列,从传入的'tensors'里面预取出一些tensor。一个使tensor入对(prefetch_queue)的的queue runner将自动添加到TF QueueRunners集合中。
例如:
可以用`tf.train.batch()`预先组装输入batch,并将预先组装的batch入队。则从队列中出队时,将不会消耗组装batch的成本。
images, labels = tf.train.batch([image, label], batch_size=32, num_threads=4)
batch_queue = prefetch_queue([images, labels])
images, labels = batch_queue.dequeue()
logits = Net(images)
loss = Loss(logits, labels)
参数:
tensors: 一个tensors列表或者字典A list or dictionary of `Tensors` to enqueue in the buffer.
capacity: 一个整数,在队列中元素个数最大值
num_threads: 一个整数,入队操作的线程个数
dynamic_pad: 布尔值,是否允许输入形状可变
shared_name: (可选). 如果设置,这个队列将会被多个session共享
name: (可选) 操作名
返回值:
一个队列,你可以从这个队列中出队tensorsA
"""
if isinstance(tensors, dict):
# Need to wrap the keys and values in list() since Python3 returns views.
# We sort the keys so the order is consistent across runs.
names = list(sorted(tensors.keys()))
tensor_list = list([tensors[n] for n in names])
else:
names = None
tensor_list = tensors
with ops.name_scope(name, "prefetch_queue", tensor_list) as name:
dtypes = [t.dtype for t in tensor_list]
shapes = [t.get_shape() for t in tensor_list]
queue = _which_queue(dynamic_pad)(
capacity=capacity,
dtypes=dtypes,
shapes=shapes,
names=names,
shared_name=shared_name)
enqueue_op = queue.enqueue(tensors)
queue_runner.add_queue_runner(
queue_runner.QueueRunner(queue, [enqueue_op] * num_threads))
summary.scalar("fraction_of_%d_full" % capacity,
math_ops.to_float(queue.size()) * (1. / capacity))
return queue
参考:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim/python/slim/data