测试Dataset是否可已返回字典?
且DataLoader设置batch_szie>1时返回的数据是什么样子的?
接下来给出你答案。
首先看一个测试代码:
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
class TestDataset(Dataset):
def __init__(self):
# 随意定义一个数据集类,并且继承Dataset
# 定义一个list模拟数据集
self.lines = [1, 2, 3, 4, 5, 6, 7, 8, 9]
# 获取数据长度,保证类别可以是一个迭代类型
self.length = len(self.lines)
def __len__(self):
# 返回数据集的长度,可以通过len方法获取数据集的长度
return self.length
def __getitem__(self, index):
# 可以根据index获取数据集中的一个元素
index = index % self.length # 保证index安全,如果超过数据集长度,不会导致代码崩溃
line = self.lines[index] # 获取数据集的元素
# 返回一个字典(字典的内容是瞎写的没有任何意义,只不过是为了测试是否可以返回一个字典)
return {
"image_name": str(index + 1),
"image": line,
"shape": np.array([line ** 3 + 5, - line ** 2 + line, 3])}
class Model(nn.Module):
def __init__(self):
"""
随意定义一个模型,可以获取DataLoader的内容,如果直接打印DataLoader返回的元素值,是一个地址
eg:
dataset = TestDataset()
dataloader = DataLoader(dataset, batch_size=4, )
for data_dict in dataloader:
print(data_dict)
output:
<torch.utils.data.dataloader.DataLoader object at 0x000001485C9C6848>
<torch.utils.data.dataloader.DataLoader object at 0x000001485C9C6848>
<torch.utils.data.dataloader.DataLoader object at 0x000001485C9C6848>
所以需要一个模型来对数据进行解码,这个步骤是编写训练模型的一个标准流程
"""
super(Model, self).__init__()
pass
def forward(self, data_dict):
print(data_dict)
def dataset_collate(batch):
# dataloader加载时先是一个batch的数据保存到一个列表中,之后拼接到一起处理成一个整体的批次,但是由于不同图片可能存在不同多个目标边界框,所以需要自己编写拼接规则。
pass
if __name__ == '__main__':
model = Model()
dataset = TestDataset()
dataloader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn(可选参数))
for data_dict in dataloader:
model(data_dict)
输出结果:
{
'image_name': ['1', '2', '3', '4'], 'image': tensor([1, 2, 3, 4]), 'shape': tensor([[ 6, 0, 3],
[ 13, -2, 3],
[ 32, -6, 3],
[ 69, -12, 3]], dtype=torch.int32)}
{
'image_name': ['5', '6', '7', '8'], 'image': tensor([5, 6, 7, 8]), 'shape': tensor([[130, -20, 3],
[221, -30, 3],
[348, -42, 3],
[517, -56, 3]], dtype=torch.int32)}
{
'image_name': ['9'], 'image': tensor([9]), 'shape': tensor([[734, -72, 3]], dtype=torch.int32)}
通过输出结果可以发现返回是一个字典的时候,DataLoader会将一个batch的数据按照key存放到一个list中,并且数据类型全部转成了Tensor