由于中文文档里面没有写这个类但是我们经常用它,所以这里进行一下分析
类定义
参数
额外信息
使用方法以及要点
不用sampler
# 训练数据集的加载器,自动将数据分割成batch,顺序随机打乱
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
drop_last = True ,
shuffle=True)
使用sampler
首先,我们定义下标数组indices,它相当于对所有test_dataset中数据的编码
# 然后定义下标indices_val来表示校验集数据的那些下标,indices_test表示测试集的下标
indices = range(len(test_dataset))
indices_val = indices[:5000]
indices_test = indices[5000:]
# 根据这些下标,构造两个数据集的SubsetRandomSampler采样器,它会对下标进行采样
sampler_val = torch.utils.data.sampler.SubsetRandomSampler(indices_val)
sampler_test = torch.utils.data.sampler.SubsetRandomSampler(indices_test)
# 根据两个采样器来定义加载器,注意将sampler_val和sampler_test分别赋值给了validation_loader和test_loader
validation_loader = torch.utils.data.DataLoader(dataset =test_dataset,
batch_size = batch_size,
sampler = sampler_val
)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
sampler = sampler_test
)
特别注意
可能出现batch_size小于预期的情况,请指定drop_last = True解决