一、数据增强
数据增强又称为数据增广,数据扩增,它是对训练集进行变换,使训练集更丰富,从而让模型更具泛化能力
二、transforms——裁剪
2.1. transforms.CenterCrop
transforms.CenterCrop(size)
功能: 从图像中心裁剪图片
- size: 所需裁剪图片尺寸
2.2 transforms.RandomCrop
transforms.RandomCrop(size,
padding=None,
pad_if_needed=False,
fill=0,
padding_mode=' constant')
功能: 从图片中随机裁剪出尺寸为size的图片
- size: 所需裁剪图片尺寸
- padding: 设置填充大小
- 当为a时,上下左右均填充a个像素
- 当为(a, b)时,上下填充b个像素,左右填充a个像素
- 当为(a, b, c, d)时,左,上,右,下分别填充a, b, c, d
- pad-if-need: 若图像小于设定size ,则填充
- padding_mode: 填充模式,有4种模式
- constant: 像素值由fill设定
- edge: 像素值由图像边缘像素决定
- reflect: 镜像填充,最后一个像素不镜像, eg: [1,2,3,4] → [3,2,1,2,3,4,3,2]
- symmetric: 镜像填充,最后一个像素镜像, eg: [1.2.3.4] → [2,1,1,2,3,4,4,3]
- fill: constant时,设置填充的像素值,即填充颜色
2.3. RandomResizedCrop
RandomResizedCrop (size,scale=(0.08, 1.0),ratio=(3/4, 4/3),interpolation)
功能: 随机大小、长宽比栽剪图片
- size: 所需裁剪图片尺寸
- scale: 随机裁剪面积比例,默认(0.08, 1)
- ratio: 随机长宽比,默认(3/4, 4/3)
- interpolation: 插值方法
- PIL.Image.NEAREST
- PIL.Image.BILINEAR
- PIL.Image.BICUBIC
2.4 FiveCrop
transforms.FiveCrop(size)
功能: 在图像的上下左右以及中心裁剪出尺寸为size的5张图片
2.5 TenCrop
transforms.TenCrop(size,vertical_flip=False)
功能:先获得FiveCrop处理的5张图片,对这5张图片进行水平或者垂直镜像获得10张图片
- size: 所需裁剪图片尺寸
- vertical_flip: 是否垂直翻转
三、transforms——翻转和旋转
3.1 RandomHorizontalFlip
RandomHorizontalFlip(p=0.5)
功能: 依概率垂直(上下)翻转图片
- p: 翻转概率
3.2 RandomVerticalFlip
RandomVerticalFlip(p=0.5)
功能: 依概率水平(左右)翻转图片
- p: 翻转概率
3.3 RandomRotation
RandomRotation(degrees,resample=False,expand=False,center=None)
功能: 随机旋转图片
- degrees: 旋转角度
- 当为a时,在(-a, a)之间选择旋转角度
- 当为(a, b)时,在(a, b)之间随机选择一个旋转角度
- resample: 重采样方法
- expand: 是否扩大图片, 以保持原图信息
- center: 旋转点设置, 默认中心旋转
四、代码实践
基于前面的人民币二分类模型的训练过程,数据增强部分
扫描二维码关注公众号,回复:
9082517 查看本文章
![](/qrcode.jpg)
# -*- coding:utf-8 -*-
import os
import numpy as np
import torch
import random
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from tools.my_dataset import RMBDataset
from PIL import Image
from matplotlib import pyplot as plt
def set_seed(seed=1):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
set_seed(1) # 设置随机种子
# 参数设置
MAX_EPOCH = 10
BATCH_SIZE = 1
LR = 0.01
log_interval = 10
val_interval = 1
rmb_label = {"1": 0, "100": 1}
def transform_invert(img_, transform_train):
"""
将data 进行反transfrom操作
:param img_: tensor
:param transform_train: torchvision.transforms
:return: PIL image
"""
if 'Normalize' in str(transform_train):
norm_transform = list(filter(lambda x: isinstance(x, transforms.Normalize), transform_train.transforms))
mean = torch.tensor(norm_transform[0].mean, dtype=img_.dtype, device=img_.device)
std = torch.tensor(norm_transform[0].std, dtype=img_.dtype, device=img_.device)
img_.mul_(std[:, None, None]).add_(mean[:, None, None])
img_ = img_.transpose(0, 2).transpose(0, 1) # C*H*W --> H*W*C
img_ = np.array(img_) * 255
if img_.shape[2] == 3:
img_ = Image.fromarray(img_.astype('uint8')).convert('RGB')
elif img_.shape[2] == 1:
img_ = Image.fromarray(img_.astype('uint8').squeeze())
else:
raise Exception("Invalid img shape, expected 1 or 3 in axis 2, but got {}!".format(img_.shape[2]) )
return img_
# ============================ step 1/5 数据 ============================
split_dir = os.path.join("..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
# 1 CenterCrop
transforms.CenterCrop(512), # 512
# 2 RandomCrop
# transforms.RandomCrop(224, padding=16),
# transforms.RandomCrop(224, padding=(16, 64)),
# transforms.RandomCrop(224, padding=16, fill=(255, 0, 0)),
# transforms.RandomCrop(512, pad_if_needed=True), # pad_if_needed=True
# transforms.RandomCrop(224, padding=64, padding_mode='edge'),
# transforms.RandomCrop(224, padding=64, padding_mode='reflect'),
# transforms.RandomCrop(1024, padding=1024, padding_mode='symmetric'),
# 3 RandomResizedCrop
# transforms.RandomResizedCrop(size=224, scale=(0.5, 0.5)),
# 4 FiveCrop
# transforms.FiveCrop(112),
# transforms.Lambda(lambda crops: torch.stack([(transforms.ToTensor()(crop)) for crop in crops])),
# 5 TenCrop
# transforms.TenCrop(112, vertical_flip=False),
# transforms.Lambda(lambda crops: torch.stack([(transforms.ToTensor()(crop)) for crop in crops])),
# 1 Horizontal Flip
# transforms.RandomHorizontalFlip(p=1),
# 2 Vertical Flip
# transforms.RandomVerticalFlip(p=0.5),
# 3 RandomRotation
# transforms.RandomRotation(90),
# transforms.RandomRotation((90), expand=True),
# transforms.RandomRotation(30, center=(0, 0)),
# transforms.RandomRotation(30, center=(0, 0), expand=True), # expand only for center rotation
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
valid_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std)
])
# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)
# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)
# ============================ step 5/5 训练 ============================
for epoch in range(MAX_EPOCH):
for i, data in enumerate(train_loader):
inputs, labels = data # B C H W
img_tensor = inputs[0, ...] # C H W
img = transform_invert(img_tensor, train_transform)
plt.imshow(img)
plt.show()
plt.pause(0.5)
plt.close()
# bs, ncrops, c, h, w = inputs.shape
# for n in range(ncrops):
# img_tensor = inputs[0, n, ...] # C H W
# img = transform_invert(img_tensor, train_transform)
# plt.imshow(img)
# plt.show()
# plt.pause(1)