2.pytorch学习:自己导入数据集与利用自带的公开数据集

自己导入数据集

蚂蚁蜜蜂分类数据集下载链接:

https://download.pytorch.org/tutorial/hymenoptera_data.zip

from torch.utils.data import Dataset
import os
from PIL import Image


class MyData(Dataset):
    # root_dir数据集根目录文件夹,label_dir为标注过的图片文件夹
    # 初始化+读取数据集
    def __init__(self, root_dir, label_dir):
        # root_dir数据集根目录文件夹
        self.root_dir = root_dir
        # label_dir为标注过的图片文件夹
        self.label_dir = label_dir
        # 使用os.path.join()函数,拼接路径,因为win是\\拼接,linux是\拼接。
        self.path = os.path.join(self.root_dir, self.label_dir)
        # 将路径下的文件存成数组(array)的形式。数组的元素对应每个图片的名字(str字符串类型)。
        self.img_path = os.listdir(self.path)

    # 对于指定的idx(索引,因为img_path是一个由图片名(字符串)按照一定顺序组成的数组),获取数据并返回。
    def __getitem__(self, idx):
        # 通过idx索引图片名。
        img_name = self.img_path[idx]
        # 拼接数据集根目录、标注图片文件夹目录与图片名。
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        # 利用PIL(Python Imaging Library)中的Image中的open()函数打开图片。
        img = Image.open(img_item_path)
        # 因为这是分类任务,故标注图片的文件夹名就是种类的名字。
        label = self.label_dir
        # 返回图片与
        return img, label

    # 返回图片的个数(图片文件名列表的长度)。
    def __len__(self):
        return len(self.img_path)


# 数据集根目录。
dataset_root_dir = "hymenoptera_data/hymenoptera_data/train"
# 数据集的标注文件夹,因为是蚂蚁的蜜蜂的分类问题。
ants_label_dir = "ants"
bees_label_dir = "bees"
# def __init__(self, root_dir, label_dir): ,方法__init__有两个形参,根据这个类创建实例就必须要指定形参的值。
ants_dataset = MyData(dataset_root_dir, ants_label_dir)
bees_dataset = MyData(dataset_root_dir, ants_label_dir)
# 将两个数据集加载一块组成训练集
train_dataset = ants_dataset + bees_dataset

注意的点1:

方法__init__有两个形参,根据这个类创建实例就必须要指定形参的值。这个目的是将两个路径通过os.path.join()函数,拼成想要的路径,找到数据集图片。同时由于是分类问题,种类为文件名,故第二个形参就是种类名。

def __init__(self, root_dir, label_dir):

注意的点2:

由于有了这个方法,数据集中的图片可以通过列表索引的方式索引出来。为以后的操作提供基础。

def __getitem__(self, idx):

注意的点3:

由于有了这个方法,可以通过len()函数去查询数据集列表的长度,也就是图片的数量。

def __len__(self):

利用自带的公开数据集

import torchvision
from matplotlib import pyplot as plt
from torchvision import transforms
import torch
from torch.utils.data import DataLoader

# transforms模块,用到了ToTensor()与Normalize(),前者是为了将图片转为为tensor张量,转化后才能被神经网络接收,
# 后者是为了使每个信道(BGR)的灰度值平均值为0,标准差为1,数学上的标准差概念是各个元素的值距离平均值的距离的平均值为标准差。
# 故此操作为归一化,许多教程这里的标准差和平均值都是瞎给的,没有跟数据集对应.
# Compose能将一系列的预处理以列表的方式一一给出.
dataset_transforms = transforms.Compose([torchvision.transforms.ToTensor(),
                                         torchvision.transforms.Normalize((0.4915, 0.4823, 0.4468),
                                                                          (0.2470, 0.2435, 0.2616))
                                                     ])

# 加载已有的数据集
# root为下载的路径, ./的意思是当前路径下, ../意思是上一级路径
# train布尔值,True为训练集,False为验证集
# transform预处理模块.
train_dataset = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True, transform=dataset_transforms)
val_dataset = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True, transform=dataset_transforms)

# 打印出训练集是一个类class.
print(type(train_dataset))
# 由于类中有__len__()函数,所以可以查询训练集列表的长度.
print("训练集长度:{}".format(len(train_dataset)))
# 由于类中有__getitem__()函数,所以可以使用标准索引对元组和列表进行索引,以访问单个数据项.
img, label = train_dataset[99]
print("图片:{}\n图片类型:{}\n图片形状:{}\n图片数据类型:{}\n值的范围:{}到{}\n标注:{}\n图片种类名:{}"
      .format(img, type(img), img.shape, img.dtype, img.min(), img.max(), label, train_dataset.classes[label]))

# 打开图片,如果用到了transform中的ToTensor,图片是打不开的
# AttributeError: 'Tensor' object has no attribute 'show'
# img.show()

# C×H×W改为H×W×C
plt.imshow(img.permute(1, 2, 0))
plt.show()

# 此处能知道每个信道的平均值与标准差,没进行预处理中的归一化操作,平均值和标准差是不稳定的.
imgs = torch.stack([img for img, _ in train_dataset], dim=3)
print("拼接后的形状:{}".format(imgs.shape))
print("每个信道平均值:{}".format(imgs.view(3, -1).mean(dim=1)))
print("每个信道标准差:{}".format(imgs.view(3, -1).std(dim=1)))

# Dataloader
# 数据集,批处理大小,是否随机抓取(不放回抓取),并行处理,最后剩下不满足批处理大小的数据是否扔掉.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, num_workers=0, drop_last=False)
# CIFAR10类中有__getitem__()函数.
for data in train_loader:
    imgss, labels = data
    print(imgss.shape)
    print(labels)

注意的点1:

ToTensor是必不可少的,归一化的参数不是乱给的,在利用别人的代码训练自己的自定义数据集,均值与标准差需要重新计算。

# transforms模块,用到了ToTensor()与Normalize(),前者是为了将图片转为为tensor张量,转化后才能被神经网络接收,
# 后者是为了使每个信道(BGR)的灰度值平均值为0,标准差为1,数学上的标准差概念是各个元素的值距离平均值的距离的平均值为标准差。
# 故此操作为归一化,许多教程这里的标准差和平均值都是瞎给的,没有跟数据集对应.
# Compose能将一系列的预处理以列表的方式一一给出.
dataset_transforms = transforms.Compose([torchvision.transforms.ToTensor(),
                                         torchvision.transforms.Normalize((0.4915, 0.4823, 0.4468),
                                                                          (0.2470, 0.2435, 0.2616))
                                                     ])

注意的点2:

要想通过matplotlib输出图片,必须给tensor张量换序。

# C×H×W改为H×W×C
plt.imshow(img.permute(1, 2, 0))
plt.show()

注意的点3:

在不知道有些维度的具体值,赋值成-1,系统自动计算。

# 此处能知道每个信道的平均值与标准差,没进行预处理中的归一化操作,平均值和标准差是不稳定的.
imgs = torch.stack([img for img, _ in train_dataset], dim=3)
print("拼接后的形状:{}".format(imgs.shape))
print("每个信道平均值:{}".format(imgs.view(3, -1).mean(dim=1)))
print("每个信道标准差:{}".format(imgs.view(3, -1).std(dim=1)))

运行结果

因为50000张图片,最后for循环结果只给部分:

Files already downloaded and verified
Files already downloaded and verified
<class 'torchvision.datasets.cifar.CIFAR10'>
训练集长度:50000
图片:tensor([[[-1.0055, -1.1960, -1.2595,  ...,  0.6615,  0.9156,  0.1852],
         [-0.9896, -1.1167, -1.1643,  ...,  0.5980,  0.7251,  0.3123],
         [-1.0690, -0.9738, -1.1008,  ...,  0.4393,  0.3916, -0.0370],
         ...,
         [ 0.7409,  0.2805,  0.0741,  ..., -0.4975,  0.2487,  0.2170],
         [ 0.9156,  0.3916, -0.7197,  ..., -0.7039,  0.1535,  0.2805],
         [ 1.3284,  0.8997,  0.2170,  ..., -1.0531,  0.0741,  0.6933]],

        [[-0.9500, -1.1754, -1.2721,  ...,  0.7894,  0.9826,  0.2096],
         [-0.9339, -1.1271, -1.1754,  ...,  0.7410,  0.8216,  0.3706],
         [-0.9822, -0.9178, -1.0144,  ...,  0.5156,  0.4995,  0.0807],
         ...,
         [ 0.1935, -0.2091, -1.0788,  ..., -0.7728, -0.2414, -0.2897],
         [ 0.3706, -0.0803, -0.9500,  ..., -0.8211, -0.0803,  0.0324],
         [ 0.8216,  0.4512, -0.2253,  ..., -1.1110, -0.0642,  0.5317]],

        [[-1.0484, -1.3182, -1.4231,  ..., -0.6736, -0.5687, -0.6286],
         [-1.1533, -1.3182, -1.3032,  ..., -0.7935, -0.5836, -0.5537],
         [-1.1683, -1.1533, -1.1533,  ..., -0.7785, -0.7485, -0.8535],
         ...,
         [-0.2239, -0.4487, -1.0783,  ..., -0.8685, -0.4188, -0.4937],
         [ 0.0460, -0.2838, -1.0484,  ..., -0.8085, -0.2389, -0.0590],
         [ 0.4507,  0.1359, -0.4637,  ..., -1.0034, -0.0440,  0.6906]]])
图片类型:<class 'torch.Tensor'>
图片形状:torch.Size([3, 32, 32])
图片数据类型:torch.float32
值的范围:-1.9806982278823853到2.126077890396118
标注:1
图片种类名:automobile
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
拼接后的形状:torch.Size([3, 32, 32, 50000])
每个信道平均值:tensor([-0.0004, -0.0006, -0.0010])
每个信道标准差:tensor([1.0001, 0.9999, 1.0000])
torch.Size([64, 3, 32, 32])
tensor([5, 7, 6, 0, 5, 2, 5, 8, 0, 3, 2, 8, 1, 0, 4, 3, 1, 5, 3, 2, 4, 2, 5, 4,
        8, 9, 2, 6, 9, 4, 6, 9, 7, 6, 2, 6, 3, 4, 7, 0, 8, 7, 0, 1, 4, 7, 0, 5,
        9, 3, 7, 7, 8, 5, 3, 3, 0, 4, 7, 7, 4, 5, 5, 1])
torch.Size([64, 3, 32, 32])
tensor([4, 6, 6, 9, 2, 8, 8, 5, 3, 0, 6, 5, 0, 4, 6, 1, 2, 7, 3, 2, 0, 8, 3, 1,
        5, 4, 3, 9, 6, 6, 7, 3, 7, 3, 3, 2, 4, 6, 3, 8, 6, 8, 2, 4, 0, 9, 4, 8,
        8, 8, 8, 8, 4, 8, 6, 3, 2, 0, 5, 3, 9, 3, 4, 4])
torch.Size([64, 3, 32, 32])
tensor([3, 4, 1, 8, 6, 0, 1, 3, 5, 0, 3, 9, 3, 5, 0, 6, 0, 7, 2, 8, 6, 9, 0, 4,
        5, 3, 7, 1, 9, 2, 4, 2, 9, 3, 5, 6, 0, 6, 0, 9, 1, 1, 8, 4, 1, 2, 9, 3,
        3, 9, 1, 1, 3, 8, 6, 0, 4, 0, 8, 0, 3, 7, 2, 5])
torch.Size([64, 3, 32, 32])

————————————————————————————————————此处省略————————————————————————————————————————————

torch.Size([16, 3, 32, 32])
tensor([3, 8, 1, 4, 5, 1, 5, 2, 4, 1, 3, 2, 3, 3, 9, 6])

Process finished with exit code 0

注意的点:最后留下不足64个图片的尾巴是因为DataLoader的drop_last为False。

torch.Size([16, 3, 32, 32])
tensor([3, 8, 1, 4, 5, 1, 5, 2, 4, 1, 3, 2, 3, 3, 9, 6])

Process finished with exit code 0

猜你喜欢

转载自blog.csdn.net/wzfafabga/article/details/127694889