这一次主要讲解pytorch读取数据的机制和流程,然后按照流程编写代码
Dataset基类
PyTorch 读取图片,主要是通过 Dataset 类,所以先简单了解一下 Dataset 类。Dataset
类作为所有的 datasets 的基类存在,所有的 datasets 都需要继承它。
看一下源码:
这里有一个getitem函数,getitem函数接收一个index,然后返回图片数据和标签,这个index通常是指一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。
list的制作方法通常是将图片的路径和标签信息存储在一个txt中,然后从txt中读取,所以总结一下基本流程:
- 制作存储了图片路径和标签信息的txt
- 将这些信息转化成list,list的每一个元素对应一个样本
- 通过getitem函数,读取数据和标签。
其实说着了些都没用,因为在训练代码里是感觉不到这些操作的,只会看到通过DataLoader就可以获取一个batch的数据,其实触发去读取图片这些操作的是DataLoader里的__iter__(self)(后面再将)。
总而言之,要让PyTorch读取自己的数据集,只要两步:
- 制作图片数据的索引
- 构建Dataset子类
制作图片数据索引
非常简单,就是一些基本的操作,百度一下“”python如何保存txt文件“”就可以知道了。
然后一般来说,txt都是这样的格式
./Data/train/01.png 0
./Data/train/02.png 0
./Data/train/03.png 1
./Data/train/04.png 1
构建Dataset子类
下面我们构建一下Dataset的子类,叫他MyDataset类:
from PIL import Image
from torch.utils.data import Dataset
class MyDataset(Datset):
def __init__(self,txt_path,transform=None,target_transform=None):
fh = open(txt_path,'r')
imgs = []
for line in fh:
line = line.rstrip()
words = line.split()
imgs.append((words[0].int(words[1])))
self.imgs = imgs
self.transform = transform
def __getitem__(self,index):
fn,label = self.imgs[index]
img=Image.open(fn).convert('RGB')
if self.transform is not None:
img = self.transform(img)
return img,label
def __len__(self):
return len(sefl.imgs)
Init
-
初始化中,我们从已经准备好的txt中获取了图片的路径和表亲啊,并且春初在self.imgs这意味着self.imgs是一个list就像上面我们讲的那样
-
初始化中,初始化了transform,transform是一个Compose类型,transform中包含一个list,list中定义了各种对图像进行的操作,比如随机剪裁,旋转反转等。
-
一个图片都进来之后,会经过数据处理(数据增强),最终变成另外一张图片,也就是模型的输入数据。但是PyTorch的数据增强是将原始图片进行处理,是不会生成新的图片。因此假如我们使用randomcrop这样的随机操作的时候,每次epoch输入进来的图片不会是一摸一样的,达到样本多样性的功能
getitem
- self.imgs是一个list,每一个元素都是一个二元tuple,这很好理解(str1,str2)这样的
- 利用Image.open对图片进行读取,img类型为Image,mode=‘RGB’
- 用transform对图片进行处理,里面可能有什么 标准化(减均值除以标准差),随机剪裁什么的(后面会细说)
这样Mydataset就构建好了,剩下的操作就交给DataLoader,在DataLoader中,会触发Mydataset中的getitem函数读取一张图片的数据和标签,并将多个图片拼接成一个batch返回,每一个batch才是模型真正的输入。
下一章节会讲解DataLoader是如何获取一个batch的