# PyTorch 0.4.0: CIFAR-10 数据集显示 (display images from the CIFAR-10 dataset)
# @Time: 2018/6/15
# @Author: xfLi
import torchvision as tv
import torchvision.transforms as transforms
import torch as t
import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
def getData(root='/data/', batch_size=4, download=True):
    """Build the CIFAR-10 training loader and test dataset.

    Both splits use the same preprocessing: ``ToTensor`` followed by a
    per-channel ``Normalize`` with mean 0.5 and std 0.5 for each of the
    three channels.

    Args:
        root: Directory passed to ``torchvision.datasets.CIFAR10`` for
            storing/locating the dataset files. Defaults to ``'/data/'``
            (the previously hard-coded value).
        batch_size: Batch size of the returned training ``DataLoader``.
        download: Forwarded to ``CIFAR10(download=...)`` for both splits.

    Returns:
        A ``(train_loader, test_set, classes)`` tuple:
        a shuffled ``DataLoader`` over the training split, the test-split
        dataset itself (not wrapped in a loader), and a tuple of the ten
        class names, indexed by integer label.
    """
    # Preprocessing shared by both splits: tensor conversion, then
    # (x - 0.5) / 0.5 per channel.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Training split, wrapped in a shuffling loader.
    train_set = tv.datasets.CIFAR10(root=root, train=True,
                                    transform=transform, download=download)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

    # Test split, returned as a plain dataset so callers can iterate
    # (image, label) pairs one at a time.
    test_set = tv.datasets.CIFAR10(root=root, train=False,
                                   transform=transform, download=download)

    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')
    return train_loader, test_set, classes
if __name__ == '__main__':
    # Only the test split is displayed; the training loader is discarded.
    _, test_data, class_names = getData()

    for image, target in test_data:
        print(class_names[target])
        # Invert the Normalize(mean=0.5, std=0.5) from getData (x * 0.5 + 0.5),
        # then reorder axes (0, 1, 2) -> (1, 2, 0): ToTensor yields
        # channel-first images, while matplotlib's imshow expects channels last.
        pixels = np.transpose((image / 2 + 0.5).numpy(), (1, 2, 0))
        plt.imshow(pixels)
        # Blocks until the window is closed, then moves on to the next sample.
        plt.show()
# Original article: 【PyTorch】:数据预处理 (PyTorch: data preprocessing)
# Reposted from: blog.csdn.net/qq_30159015/article/details/80756470