torchvision 模块
torchvision是独立于pytorch的关于图像操作的工具库,主要包含了如下4个子模块或包:
- datasets
- utils
- transforms
- models
1、datasets
torchvision.datasets包含如下数据集,可以下载和加载
- MNIST
- COCO(用于图像标注和目标检测)(Captioning and Detection)
- LSUN Classification
- ImageFolder
- Imagenet-12
- CIFAR10 and CIFAR100
- STL10
- SVHN
- PhotoTour
from torchvision import datasets
train_dataset = datasets.MNIST(root='./data', train=True,
transform=transforms.ToTensor(),
download=True)
此操作便可下载MNIST的训练数据集,
数据集有 API: - __getitem__ - __len__ 他们都是 torch.utils.data.Dataset的子类。因此, 他们可以使用torch.utils.data.DataLoader里的多线程 (python multithreading) 。
例如:
torch.utils.data.DataLoader(coco_cap, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads)
2、utils
utils主要提供了两个方法:
- make_grid
- save_image
torchvision.utils.make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False)
将输入的minbatch_size图片转换成一张大的网格图片
torchvision.utils.save_image(tensor, filename, nrow=8, padding=2, normalize=False, range=None, scale_each=False)
将输入的图片保存,如果输入的是minbatch_size图片,先make_grid转换成大的网格图再保存
3、transforms
了方便进行数据的操作,pytorch团队提供了一个torchvision.transforms包,我们可以用transforms进行以下操作:
PIL.Image/numpy.ndarray与Tensor的相互转化;
归一化;
对PIL.Image进行裁剪、缩放等操作。
通常,在使用torchvision.transforms,我们通常使用transforms.Compose将transforms组合在一起。
transform = transforms.Compose([
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ],
std = [ 0.229, 0.224, 0.225 ]),
])
- transforms.ToTensor() :把shape=(H x W x C)的像素值范围为[0, 255]的PIL.Image或者numpy.ndarray转换成shape=(C x H x W)的像素值范围为[0.0, 1.0]的torch.FloatTensor。
- transforms.Normalize(mean,std) : 此转换类作用于torch.tensor,给定均值(R, G, B)和标准差(R, G, B),用公式channel = (channel - mean) / std进行规范化。
4、models
torchvision.models包含下列常用网络结构,并且提供了预训练模型,可以通过简单调用来读取网络结构和预训练模型
- AlexNet: AlexNet variant from the “One weird trick” paper.
- VGG: VGG-11, VGG-13, VGG-16, VGG-19 (with and without batch normalization)
- ResNet: ResNet-18, ResNet-34, ResNet-50, ResNet-101, ResNet-152
- SqueezeNet: SqueezeNet 1.0, and SqueezeNet 1.1
import torchvision
model = torchvision.models.resnet50(pretrained=True)
这样就导入了resnet50的预训练模型了,
如果只需要网络结构,不需要用预训练模型的参数来初始化
model = torchvision.models.resnet50(pretrained=False)
如果要导入densenet模型也是同样的道理,比如导入densenet169,且不需要是预训练的模型,
model = torchvision.models.densenet169(pretrained=False)
由于预训练参数默认是假,所以等价于
model = torchvision.models.densenet169()