01_pytorch中的DataSet

在pytorch 中,
Dataset: 用于数据集的创建;
DataLoader: 用于在训练过程中,传递获取一个batch的数据;

这里先介绍 pytorch 中的 Dataset 这个类,
torch.utils.data. dataset.py 是一个表示数据集的抽象类。任何自定义的数据集都需要继承这个类并覆写相关方法。

数据集,其实就是一个负责处理索引(index)到样本(sample)映射的一个类(class)。

在这里插入图片描述
在torch.utils.data. dataset.py 中可知,
pytorch 提供两种数据集:

  • Map 式数据集, 上图中MapDataPipe()
  • Iterable 式数据集, IterDataPipe()

1. Map 式数据集

即上图中MapDataPipe(),

1.1 需要重写的方法

一个Map式的数据集必须要重写__getitem__(self, index),
len(self) 两个内建方法,用来表示从索引到样本的映射(Map).

这样一个数据集dataset的作用如下,

  • 当使用dataset[idx]命令时,可以在你的硬盘中读取你的数据集中第idx张图片以及其标签(如果有的话);
  • len(dataset)则会返回这个数据集的容量。

1.2 使用方法

例子-1: 自己实验中写的一个例子:这里我们的图片文件储存在“./data/faces/”文件夹下,图片的名字并不是从1开始,而是从final_train_tag_dict.txt这个文件保存的字典中读取,label信息也是用这个文件中读取。大家可以照着注释阅读这段代码。

from torch.utils import data
import numpy as np
from PIL import Image


class face_dataset(data.Dataset):
	def __init__(self):
		self.file_path = './data/faces/'
		f=open("final_train_tag_dict.txt","r")
		self.label_dict=eval(f.read())
		f.close()

	def __getitem__(self,index):
		label = list(self.label_dict.values())[index-1]
		img_id = list(self.label_dict.keys())[index-1]
		img_path = self.file_path+str(img_id)+".jpg"
		img = np.array(Image.open(img_path))
		return img,label

	def __len__(self):
		return len(self.label_dict)

2. Iterable 式数据集

一个Iterable(迭代)式数据集是抽象类data.IterableDataset的子类,并且覆写了__iter__方法成为一个迭代器。

这种数据集主要用于数据大小未知,或者以流的形式的输入,本地文件不固定的情况,需要以迭代的方式来获取样本索引。

一个 Iterable 式的数据集必须要重写__iter__,

class IterableDataset(Dataset[T_co]):
    r"""An iterable Dataset.

    All datasets that represent an iterable of data samples should subclass it.
    Such form of datasets is particularly useful when data come from a stream.

    All subclasses should overwrite :meth:`__iter__`, which would return an
    iterator of samples in this dataset.

    When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
    item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
    iterator. When :attr:`num_workers > 0`, each worker process will have a
    different copy of the dataset object, so it is often desired to configure
    each copy independently to avoid having duplicate data returned from the
    workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
    process, returns information about the worker. It can be used in either the
    dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
    :attr:`worker_init_fn` option to modify each copy's behavior.

    """
    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

    def __add__(self, other: Dataset[T_co]):
        return ChainDataset([self, other])

所有表示数据样本可迭代的数据集都应该对其进行子类化。
这种形式的数据集在数据来自流的时候特别有用。

所有的子类都应该覆盖 :meth:__iter__,它将返回这个数据集中样本的迭代器。

  • 当子类与 :class:~torch.utils.data.DataLoader一起使用时,数据集中的每个项目将从 :class:~torch.utils.data.DataLoader迭代器中产生。
  • 当 :attr:num_workers > 0时,每个工作进程将有一个不同的数据集对象的副本,所以通常需要独立配置每个副本,以避免从工作进程返回重复的数据。
  • :func:~torch.utils.data.get_worker_info,当在工作进程中调用时,返回关于工作者的信息。它可以被用于数据集的 :meth:__iter__ 方法或 :class:~torch.utils.data.DataLoader 的 j-
  • :attr:worker_init_fn 选项来修改每个副本的行为

2.1  迭代器和生成器之间的关系

在这里插入图片描述

2.2 python 中的迭代器

顾名思义,迭代器就是用于迭代操作(for 循环)的对象,它像列表一样可以迭代获取其中的每一个元素,任何实现了 __next__ 方法 (python2 是 next)的对象都可以称为迭代器。

它与列表的区别在于,构建迭代器的时候,不像列表把所有元素一次性加载到内存,而是以一种延迟计算(lazy evaluation)方式返回元素,这正是它的优点。

比如列表含有中一千万个整数,需要占超过400M的内存,而迭代器只需要几十个字节的空间。
因为它并没有把所有元素装载到内存中,而是等到调用 next 方法时候才返回该元素
(按需调用 call by need 的方式,本质上 for 循环就是不断地调用迭代器的__next__方法)。

以斐波那契数列为例来实现一个迭代器:

class Fib:
    def __init__(self, n):
        self.prev = 0
        self.cur = 1
        self.n = n

    def __iter__(self):
        return self

    def __next__(self):
        if self.n > 0:
            value = self.cur
            self.cur = self.cur + self.prev
            self.prev = value
            self.n -= 1
            return value
        else:
            raise StopIteration()
    # 兼容python2
    def next(self):
        return self.__next__()

f = Fib(10)
print([i for i in f])
#[1, 1, 2, 3, 5, 8, 13, 21, 34, 5

2.3  python 中的生成器

知道迭代器之后,就可以正式进入生成器的话题了。

2.3.1 为什么需要生成器

通过列表生成式,我们可以直接创建一个列表,但是,受到内存限制,列表容量肯定是有限的,而且创建一个包含100万个元素的列表,不仅占用很大的存储空间,如果我们仅仅需要访问前面几个元素,那后面绝大多数元素占用的空间都白白浪费了。

所以,如果列表元素可以按照某种算法推算出来,那我们是否可以在循环的过程中不断推算出后续的元素呢?这样就不必创建完整的list,从而节省大量的空间,在Python中,这种一边循环一边计算的机制,称为生成器:generator

生成器是一个特殊的程序,可以被用作控制循环的迭代行为,python中生成器是迭代器的一种,使用yield返回值函数,每次调用yield会暂停,而可以使用next()函数和send()函数恢复生成器。

生成器类似于返回值为数组的一个函数,这个函数可以接受参数,可以被调用,但是,不同于一般的函数会一次性返回包括了所有数值的数组,生成器一次只能产生一个值,这样消耗的内存数量将大大减小,而且允许调用函数可以很快的处理前几个返回值,因此生成器看起来像是一个函数,但是表现得却像是迭代器

  • 生成器函数:也是用def定义的,利用关键字yield一次性返回一个结果,阻塞,重新开始

  • 生成器表达式:返回一个对象,这个对象只有在需要的时候才产生结果

普通函数用 return 返回一个值,和 Java 等其他语言是一样的,然而在 Python 中还有一种函数,用关键字 yield 来返回值,这种函数叫生成器函数,函数被调用时会返回一个生成器对象,生成器本质上还是一个迭代器,也是用在迭代操作中,因此它有和迭代器一样的特性,唯一的区别在于实现方式上不一样, 生成器更加简洁。

2.3.2 生成器函数

最简单的生成器函数:

>>> def func(n):
...     yield n*2
...
>>> func
<function func at 0x00000000029F6EB8>
>>> g = func(5)
>>> g
<generator object func at 0x0000000002908630>
>>>

func 就是一个生成器函数,调用该函数时返回对象就是生成器 g ,这个生成器对象的行为和迭代器是非常相似的,可以用在 for 循环等场景中。注意 yield 对应的值在函数被调用时不会立刻返回,而是调用next方法时(本质上 for 循环也是调用 next 方法)才返回。

>>> g = func(5)
>>> next(g)
10

>>> g = func(5)
>>> for i in g:
...     print(i)
...
10

那为什么要用生成器呢?用生成器它没有那么多冗长代码了,而且性能上一样的高效。

不足之处,便是需要多理解一下。

来看看用生成器实现斐波那契数列有多简单。

def fib(n):
    prev, curr = 0, 1
    while n > 0:
        n -= 1
        yield curr
        prev, curr = curr, curr + prev

print([i for i in fib(10)])
#[1, 1, 2, 3, 5, 8, 13, 21, 34, 55]

2.3.2 生成器表达式

  • 生成器表达式
    器表达式与列表推导式长的非常像,但是它俩返回的对象不一样,前者返回生成器对象,后者返回列表对象。
>>> g = (x*2 for x in range(10))
>>> type(g)
<type 'generator'>
>>> l = [x*2 for x in range(10)]
>>> type(l)
<type 'list'>

生成器的优势,就是迭代海量数据时,生成器会更加内存。

2.3  生成器在 DataLoader中应用

深度学习框架PyTorch中的DataLoader模块的实现就使用了生成器的机制来生成一次训练用的batch。

DataLoader的详解在博主的另一篇文章Pytorch之Dataloader中,这里只讲其中运用到生成器机制的Sampler类模块。

首先,是 RandomSampler, iter(randomSampler) 会返回一个可迭代对象,这个可迭代对象 每次 next 都会输出当前要采样的 index,SequentialSampler也是一样,只不过产生的 index 是顺序的

class RandomSampler(Sampler):

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(torch.randperm(len(self.data_source)).long())

    def __len__(self):
        return len(self.data_source)

BatchSampler 是一个普通 Sampler 的 wrapper, 普通Sampler 一次仅产生一个 index, 而 BatchSampler 一次产生一个 batch 的 indices。

class BatchSampler(Sampler):
    """Wraps another sampler to yield a mini-batch of indices. Args: sampler (Sampler): Base sampler. batch_size (int): Size of mini-batch. drop_last (bool): If ``True``, the sampler will drop the last batch if its size would be less than ``batch_size`` Example: >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) [[0, 1, 2], [3, 4, 5], [6, 7, 8]] """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integeral value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size

reference:

https://chenllliang.github.io/2020/02/04/dataloader/
https://chenllliang.github.io/2020/02/06/PyIter/
https://www.zhihu.com/question/20829330/answer/213544776
https://www.cnblogs.com/wj-1314/p/8490822.html

猜你喜欢

转载自blog.csdn.net/chumingqian/article/details/131002167