为了加强对Pytroch框架的掌握,本文梳理了一下Pytroch自定义数据集常用的知识。首先便是如何从文件夹中读取图片名称或者从txt文档里读取图片名称方便导入图片。接着便是常见的torchvision.transforms的使用以及如何自定义transform。最后便是基于torch.utils.data.Dataset构建自定义数据集,并使用dataloader导入。这里参考了许多网上博客,属于汇总。
文章目录
-
- 一、常见读取文件操作
-
- 1.[python读取文件夹中的所有图片并将图片名逐行写入txt中](https://wang11.blog.csdn.net/article/details/125666776?spm=1001.2101.3001.6650.3&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7ERate-3-125666776-blog-88117077.235%5Ev27%5Epc_relevant_default&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7ERate-3-125666776-blog-88117077.235%5Ev27%5Epc_relevant_default&utm_relevant_index=4)
- 2.[python读取txt文本内容](https://zhuanlan.zhihu.com/p/42784651)
- 二、transforms
- 三、自定义数据集
- 四、实战
一、常见读取文件操作
1.python读取文件夹中的所有图片并将图片名逐行写入txt中
import os
img_path = r'' # 这里需要写入要读取的文件夹路径,这里的r是为了防止路径转义,\\同样也是为了防止\转义
save_txt_path = r'' # 这里需要写入读取后的txt文件要保存的路径和名称
imgs = os.listdir(img_path) # 这里读取文件中图片名称并且将其用列表存储。os.listdir的返回值是一个列表,列表里面存储该path下面的子目录的名称
txt = open(save_txt_path, 'w')
for img in imgs:
txt.write(img + '\n') # # 逐行写入图片名,'\n'表示换行
txt.close()
这里由于大部分东西都是字符串,故需要用到一些常见的字符串函数。
2.python读取txt文本内容
txt_path = r''
f = open(txt_path, 't') # 这里的t是一个参数,具体参考下面的图片
# 第一种read方法,表示一次读取文件全部内容,该方法返回字符串
lines = f.read()
print(lines)
print(type(lines))
f.close()
# 第二种readline方法,该方法每次读出一行内容,该方法返回一个字符串对象
line = f.readline()
while line:
print(line)
print(type(line))
line = f.readline() # 继续读取下一行
f.close()
# 第三种readlines方法,该方法读取整个文件所有行,保存在一个列表(list)变量中,每次读取一行
lines = f.readline()
for line in lines:
print(line)
print(type(line))
f.close()
二、transforms
1.torchvision.transforms
主要用来进行data augmentation操作,是Pytroch中的图像预处理包。
torchvision是pytroch的一个图形库,主要用来构建计算机视觉模型。
- torchvision.datasets:一些加载数据的函数及常用的数据接口
- torchvision.models:包含常用的模型结构(含预训练模型),例如ResNet等
- torchvision.utils:一些常见的辅助工具代码
- torchvision.transforms:常用的图像处理
transform.Compose() # 把几个常用的变化放到一起
transforms.Compose([
transforms.CenterCrop(10),
transforms.ToTensor(),])
transform.Compose这个类会将列表里面的transform操作进行遍历。
其类实现比较简单
class Compose:
def __init__(self, transforms):
self.transforms = transforms
# 类实例化之后直接调用
def __call__(self, img):
for t in self.transforms:
img = t(img)
return img
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
好了,有一些常见的图像变化,如resize和标准化(transforms.Resize、transforms.Normalize)等,这里就不一一列举了。需要的可以查阅这篇博客。
2.自定义transform
首先,如果自定义transform就需要遵循一定的规则。通过查看Compose这个类的实现,我们可以发现自定义transform需要两个约束:
- 仅接受一个参数,并返回一个参数。如果是多个图片需要同时处理,可以用字典传输
- 实现需要在
__call__
中进行
如下为参考代码:
from PIL import Image
from torchvision import transforms
from utils import transform_invert
import random
import numpy as np
class Enhance(object):
"""增加椒盐噪声
Args:
x():乘
y (): 加
"""
def __init__(self, x=1, y=0):
self.x = x
self.y = y
def __call__(self, img):
"""
Args:
img (PIL Image): PIL Image
Returns:
PIL Image: PIL image.
"""
img_ = np.array(img).copy()
img_ = img_*self.x + self.y
return Image.fromarray(img_.astype('uint8')).convert('RGB')
if __name__ == '__main__':
# 1.读取图像
img = Image.open(r"./cat.png").convert('RGB')
# 2.确定预处理方式
img_transform = transforms.Compose([
transforms.Grayscale(),
Enhance(x=2, y=22),
transforms.ToTensor() # 转Tensor型变量
])
img_tensor = img_transform(img)
# 4.逆Transform变换
img = transform_invert(img_tensor, img_transform) # input: shape=[c h w]
# 5.进行预处理效果展示
img.show()
多个图片需要处理参考代码如下:
from PIL import Image
import random
class RandomFlipOrRotate(object):
def __call__(self, sample):
img1, img2, mask1, mask2, mask_bin = \
sample['img1'], sample['img2'], sample['mask1'], sample['mask2'], sample['mask_bin']
rand = random.random()
if rand < 1 / 6:
img1 = img1.transpose(Image.FLIP_LEFT_RIGHT)
mask1 = mask1.transpose(Image.FLIP_LEFT_RIGHT)
img2 = img2.transpose(Image.FLIP_LEFT_RIGHT)
mask2 = mask2.transpose(Image.FLIP_LEFT_RIGHT)
mask_bin = mask_bin.transpose(Image.FLIP_LEFT_RIGHT)
return {
'img1': img1, 'img2': img2, 'mask1': mask1, 'mask2': mask2, 'mask_bin': mask_bin}
# 这里只展示了部分代码
transform = transforms.Compose([
tr.RandomFlipOrRotate()])
sample = transform({
'img1': img1, 'img2': img2, 'mask1': mask1, 'mask2': mask2, 'mask_bin': mask_bin})
三、自定义数据集
1.torch.utils.data.Dataset
同样,如果自定义数据集也必须满足一定的条件:
-
需要继承data.Dataset
-
__getitem__()
和__len__()
两个方法是必须重写的。__getitem__()
输入索引,根据索引返回训练数据,如图片和label,而__len__()
返回数据长度。class CustomDataset(data.Dataset):#需要继承data.Dataset def __init__(self): # TODO # 1. Initialize file path or list of file names. pass def __getitem__(self, index): # TODO # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open). # 2. Preprocess the data (e.g. torchvision.Transform). # 3. Return a data pair (e.g. image and label). #这里需要注意的是,第一步:read one data,是一个data pass def __len__(self): # You should change 0 to the total size of your dataset. return 0
2.实例
这是一个实现语义变化检测数据集处理的代码,如果有需要可以私信或评论。因只需要看看如何使用即可,就没有详细介绍。
import datasets.transform as tr
import numpy as np
import os
from PIL import Image
import random
import torch
from torch.utils.data import Dataset
from torchvision import transforms
class ChangeDetection(Dataset):
def __init__(self, root, mode, use_pseudo_label=False):
super(ChangeDetection, self).__init__()
self.root = root
self.mode = mode
self.use_pseudo_label = use_pseudo_label
if mode in ['train', 'val', 'pseudo_labeling']:
self.root = os.path.join(self.root, 'train')
self.ids = os.listdir(os.path.join(self.root, "im1"))
self.ids.sort()
if mode == 'val':
self.ids = self.ids[::10]
else:
self.ids = list(set(self.ids) - set(self.ids[::10]))
else:
self.root = os.path.join(self.root, 'val')
self.ids = os.listdir(os.path.join(self.root, 'im1'))
self.ids.sort()
self.transform = transforms.Compose([
tr.RandomFlipOrRotate()
])
self.normalize = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
def __getitem__(self, index):
id = self.ids[index]
img1 = Image.open(os.path.join(self.root, 'im1', id))
img2 = Image.open(os.path.join(self.root, 'im2', id))
if self.mode == "test":
img1 = self.normalize(img1)
img2 = self.normalize(img2)
return img1, img2, id
if self.mode == "val":
mask1 = Image.open(os.path.join(self.root, 'label1', id))
mask2 = Image.open(os.path.join(self.root, 'label2', id))
else:
if self.mode == 'pseudo_labeling' or (self.mode == 'train' and not self.use_pseudo_label):
mask1 = Image.open(os.path.join(self.root, 'label1', id))
mask2 = Image.open(os.path.join(self.root, 'label2', id))
else:
mask1 = Image.open(os.path.join('outdir/masks/train/im1', id))
mask2 = Image.open(os.path.join('outdir/masks/train/im2', id))
if self.mode == 'train':
gt_mask1 = np.array(Image.open(os.path.join(self.root, 'label1', id)))
mask_bin = np.zeros_like(gt_mask1)
mask_bin[gt_mask1 == 0] = 1
mask_bin = Image.fromarray(mask_bin)
sample = self.transform({
'img1': img1, 'img2': img2, 'mask1': mask1, 'mask2': mask2,
'mask_bin': mask_bin})
img1, img2, mask1, mask2, mask_bin = sample['img1'], sample['img2'], sample['mask1'], \
sample['mask2'], sample['mask_bin']
img1 = self.normalize(img1)
img2 = self.normalize(img2)
mask1 = torch.from_numpy(np.array(mask1)).long()
mask2 = torch.from_numpy(np.array(mask2)).long()
if self.mode == 'train':
mask_bin = torch.from_numpy(np.array(mask_bin)).float()
return img1, img2, mask1, mask2, mask_bin
return img1, img2, mask1, mask2, id
def __len__(self):
return len(self.ids)
3.torch.utils.data.DataLoader
下面显示了 PyTorch 库中DataLoader函数的语法及其参数信息。
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
4.python类常见的魔法方法
这个部分有待补充
四、实战
有了上述基础,我们便可以用Pytroch自定义自己的数据集并使用DataLoader载入了。