使用的 fairseq 版本:1.0
根据 dataset 创建 batch iterator 的代码位于 `fairseq/tasks/fairseq_task.py` 中的 `FairseqTask.get_batch_iterator`,其代码逻辑及添加的代码注释如下:
def get_batch_iterator(
    self,
    dataset,
    max_tokens=None,
    max_sentences=None,
    max_positions=None,
    ignore_invalid_inputs=False,
    required_batch_size_multiple=1,
    seed=1,
    num_shards=1,
    shard_id=0,
    num_workers=0,
    epoch=1,
    data_buffer_size=0,
    disable_iterator_cache=False,
):
    """
    Get an iterator that yields batches of data from the given dataset.

    Pipeline: return a cached iterator if reuse is allowed, otherwise
    order the dataset's indices, drop examples that exceed
    ``max_positions``, group the remaining indices into mini-batches under
    the size constraints, and wrap them in an ``EpochBatchIterator``
    (which is cached for later reuse when allowed).

    Args:
        dataset (~fairseq.data.FairseqDataset): dataset to batch
        max_tokens (int, optional): max number of tokens in each batch
            (default: None).
        max_sentences (int, optional): max number of sentences in each
            batch (default: None).
        max_positions (optional): max sentence length supported by the
            model (default: None).
        ignore_invalid_inputs (bool, optional): don't raise Exception for
            sentences that are too long (default: False).
        required_batch_size_multiple (int, optional): require batch size to
            be a multiple of N (default: 1).
        seed (int, optional): seed for random number generator for
            reproducibility (default: 1).
        num_shards (int, optional): shard the data iterator into N
            shards (default: 1).
        shard_id (int, optional): which shard of the data iterator to
            return (default: 0).
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means the data will be loaded in the main process
            (default: 0).
        epoch (int, optional): the epoch to start the iterator from
            (default: 1).
        data_buffer_size (int, optional): number of batches to
            preload (default: 0).
        disable_iterator_cache (bool, optional): don't cache the
            EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`)
            (default: False).

    Returns:
        ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
        given dataset split
    """
    # Reuse a previously built iterator for this dataset when the task allows
    # it and the caller has not disabled caching.
    can_reuse_epoch_itr = not disable_iterator_cache and self.can_reuse_epoch_itr(
        dataset
    )
    if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter:
        logger.debug("reusing EpochBatchIterator for epoch {}".format(epoch))
        return self.dataset_to_epoch_iter[dataset]

    assert isinstance(dataset, FairseqDataset)

    # initialize the dataset with the correct starting epoch
    dataset.set_epoch(epoch)

    # Get indices ordered by example size. The exact ordering is defined by
    # the dataset's ``ordered_indices``. For example, a language-pair dataset
    # may sort by target length and then by source length; this depends on
    # the dataset class. ``numpy_seed`` presumably seeds numpy's RNG so that
    # any randomness inside ``ordered_indices`` is reproducible per ``seed``.
    with data_utils.numpy_seed(seed):
        indices = dataset.ordered_indices()

    # filter examples that are too large
    if max_positions is not None:
        indices = self.filter_indices_by_size(
            indices, dataset, max_positions, ignore_invalid_inputs
        )

    # Create mini-batches under the given size constraints (max_tokens,
    # max_sentences, required_batch_size_multiple). ``batch_by_size`` receives
    # the already-ordered indices, so examples adjacent in that order
    # (e.g. of similar length) are presumably grouped into the same batch.
    batch_sampler = dataset.batch_by_size(
        indices,
        max_tokens=max_tokens,
        max_sentences=max_sentences,
        required_batch_size_multiple=required_batch_size_multiple,
    )

    # Return a reusable, sharded iterator. ``seed`` and ``epoch`` are
    # forwarded, presumably so batch shuffling is reproducible per epoch
    # (TODO confirm in EpochBatchIterator). ``buffer_size`` is the number of
    # batches to preload.
    epoch_iter = iterators.EpochBatchIterator(
        dataset=dataset,
        collate_fn=dataset.collater,
        batch_sampler=batch_sampler,
        seed=seed,
        num_shards=num_shards,
        shard_id=shard_id,
        num_workers=num_workers,
        epoch=epoch,
        buffer_size=data_buffer_size,
    )

    if can_reuse_epoch_itr:
        self.dataset_to_epoch_iter[dataset] = epoch_iter

    return epoch_iter