#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on 2019/03/13
author: lujie
"""
import cv2
import numpy as np
import matplotlib.pyplot as plt
from IPython import embed
from skimage import exposure, img_as_float
# from utils.data_utils import load_CIFAR10
class DataArgmentation(object):
    '''
    Batch image augmentation helpers.

    Every augmentation method reads the image batch stored under
    self.data['X_train'] (documented as N x H x W x C) and returns a newly
    allocated array; the stored batch is never modified in place.
    '''

    def __init__(self, dataset=None):
        '''
        Keep a reference to the dataset dictionary.
        - dataset : dict with dataset['X_train'] : N x H x W x C
        '''
        self.data = dataset

    def _visual(self, img=None, normalize=True):
        '''
        Draw one image on the current matplotlib axes with the axis hidden.
        - img       : image array, cast to uint8 before plt.imshow
        - normalize : if True, linearly stretch the values to [0, 255] first
        '''
        if normalize:
            # work in float so the subtraction cannot wrap around for
            # signed integer dtypes
            img = np.asarray(img, dtype=np.float64)
            img_min = np.min(img)
            span = np.max(img) - img_min
            if span > 0:
                img = 255.0 * (img - img_min) / span
            else:
                # constant image: the plain formula would divide by zero
                img = np.zeros_like(img)
        plt.imshow(img.astype('uint8'))
        plt.gca().axis('off')

    def _flip(self, mode=1):
        '''
        Flip every image according to mode (same convention as cv2.flip) :
        -  0 : vertical flip   (reverse the rows)
        - >0 : horizontal flip (reverse the columns)
        - <0 : diag flip       (reverse rows and columns)
        Returns an array with the same shape and dtype as the stored batch.
        '''
        images = self.data['X_train']
        if mode == 0:
            flipped = images[:, ::-1]
        elif mode > 0:
            flipped = images[:, :, ::-1]
        else:
            flipped = images[:, ::-1, ::-1]
        # the slices above are views; copy so the caller gets its own array
        return flipped.copy()

    def _trans(self, ratio=0.10):
        '''
        Translate each image by int(H * ratio) rows or int(W * ratio) columns
        in a direction drawn uniformly at random per image :
        - 0 : trans up
        - 1 : trans down
        - 2 : trans left
        - 3 : trans right
        The vacated border is filled with zeros.
        '''
        images = self.data['X_train']
        N, H, W, C = images.shape
        shift_H, shift_W = int(H * ratio), int(W * ratio)
        choice = np.random.randint(0, 4, N)
        trans_data = np.zeros_like(images)

        everything = slice(None)
        # (destination, source) slices over the (H, W) axes, indexed by direction
        moves = (
            ((slice(None, H - shift_H), everything), (slice(shift_H, None), everything)),
            ((slice(shift_H, None), everything), (slice(None, H - shift_H), everything)),
            ((everything, slice(None, W - shift_W)), (everything, slice(shift_W, None))),
            ((everything, slice(shift_W, None)), (everything, slice(None, W - shift_W))),
        )
        for direction, (dst, src) in enumerate(moves):
            flag = choice == direction
            trans_data[(flag,) + dst] = images[(flag,) + src]
        return trans_data

    def _crop(self, ratio=0.9, random_flag=True):
        '''
        Crop a window of int(H * ratio) x int(W * ratio) pixels from each
        image and place it at the center of a zero-filled image of the
        original size.
        - ratio : the ratio to crop from image
        - random_flag:
            - False : crop the center of image
            - True  : crop at a random offset, drawn independently per image
        '''
        images = self.data['X_train']
        N, H, W, C = images.shape
        H_crop, W_crop = int(H * ratio), int(W * ratio)
        H_range, W_range = H - H_crop, W - W_crop
        top, left = H_range // 2, W_range // 2
        rows, cols = slice(top, top + H_crop), slice(left, left + W_crop)
        crop_data = np.zeros_like(images)

        if not random_flag:
            crop_data[:, rows, cols, :] = images[:, rows, cols, :]
            return crop_data

        # randint's upper bound is exclusive: the +1 makes the last valid
        # offset reachable and keeps ratio == 1.0 (range 0) from raising
        offsets_y = np.random.randint(0, H_range + 1, N)
        offsets_x = np.random.randint(0, W_range + 1, N)
        for index, (y, x) in enumerate(zip(offsets_y, offsets_x)):
            crop_data[index, rows, cols, :] = \
                images[index, y:(y + H_crop), x:(x + W_crop), :]
        return crop_data

    def _color_jitter(self, mode=1, gamma=0.5):
        '''
        Color jitter according to mode
        - mode :
            - -1 : rgb -> gray (result is N x H x W)
            -  0 : rgb -> hsv
            -  1 : adjust brightness (exposure.adjust_gamma with gamma)
            -  2 : rescale intensity of the images that
                   exposure.is_low_contrast flags; others are copied as is
        - gamma : exponent passed to exposure.adjust_gamma (mode 1 only)
        Raises ValueError for any other mode.
        '''
        images = self.data['X_train']
        N, H, W, C = images.shape
        if mode not in (-1, 0, 1, 2):
            raise ValueError('Unrecognized color jitter mode')

        if mode == -1:
            # the shape must be a single tuple; np.zeros(N, H, W) would take
            # H as the dtype argument and raise TypeError
            color_data = np.zeros((N, H, W), dtype=images.dtype)
            for index in range(N):
                # NOTE(review): the docstring says rgb but the conversion code
                # is BGR2GRAY - confirm the channel order of X_train
                color_data[index] = cv2.cvtColor(images[index], cv2.COLOR_BGR2GRAY)
            return color_data

        color_data = np.zeros_like(images)
        if mode == 0:
            for index in range(N):
                # NOTE(review): same rgb/BGR question as the gray conversion
                color_data[index] = cv2.cvtColor(images[index], cv2.COLOR_BGR2HSV)
        elif mode == 1:
            for index in range(N):
                color_data[index] = exposure.adjust_gamma(images[index], gamma)
        else:
            for index in range(N):
                if exposure.is_low_contrast(images[index]):
                    color_data[index] = exposure.rescale_intensity(images[index])
                else:
                    color_data[index] = images[index]
        return color_data
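# Common data augmentation strategies in deep learning and their code implementations
# (web-page residue) You may also like
# Reposted from blog.csdn.net/On_theway10/article/details/88644721
# (web-page residue) Today's recommendations
# (web-page residue) Weekly ranking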