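# Plenty of ready-made image-augmentation libraries exist nowadays and can be used directly;
# these helpers were written by hand and are kept here for reference.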
import numpy as np
import cv2
# Crop a patch from an image and its label
def _crop_window(center, size, limit):
    """Return (start, stop) of a `size`-long window around `center` inside [0, limit]."""
    start = int(center) - size // 2
    # Slide the window back inside the axis. The outer max() also covers the
    # case limit < size, where the window shrinks to the whole axis.
    start = max(0, min(start, limit - size))
    return start, min(start + size, limit)


def crop_2D(inputImg, label, crop_size=(256, 256)):
    """Crop the same window from an image and its label.

    The window is centered on a randomly chosen pixel where ``label == 1``
    and is shifted as needed so that it stays inside the image. If the label
    contains no such pixel, the window center is drawn uniformly at random
    from the whole image instead.

    Args:
        inputImg: Image array; its first two axes are indexed like ``label``.
        label: Label array whose first two axes are (rows, cols).
        crop_size: ``(rows, cols)`` of the requested window.

    Returns:
        ``(cropped_image, cropped_label)``. Each has ``crop_size`` rows and
        columns, or fewer along an axis where the image itself is smaller
        than the requested window.
    """
    crop_h, crop_w = crop_size
    # Bounds come from the actual array size rather than an assumed
    # "image is twice the crop size" relationship.
    img_h, img_w = label.shape[:2]

    fg_rows, fg_cols = np.where(label == 1)[:2]
    if fg_rows.size == 0:
        # No foreground to center on: fall back to a random location.
        center_r = np.random.randint(img_h)
        center_c = np.random.randint(img_w)
    else:
        pick = np.random.randint(fg_rows.size)
        center_r, center_c = fg_rows[pick], fg_cols[pick]

    top, bottom = _crop_window(center_r, crop_h, img_h)
    left, right = _crop_window(center_c, crop_w, img_w)

    return inputImg[top:bottom, left:right], label[top:bottom, left:right]
# Random zoom (scale up or down) while keeping the original size
def _center_fit(arr, out_h, out_w):
    """Center-crop and/or zero-pad `arr` so its first two axes are (out_h, out_w)."""
    in_h, in_w = arr.shape[:2]
    # Size of the region that exists in both the source and the destination.
    copy_h, copy_w = min(in_h, out_h), min(in_w, out_w)
    src_top, src_left = (in_h - copy_h) // 2, (in_w - copy_w) // 2
    dst_top, dst_left = (out_h - copy_h) // 2, (out_w - copy_w) // 2

    fitted = np.zeros((out_h, out_w) + arr.shape[2:], dtype=arr.dtype)
    region = arr[src_top:src_top + copy_h, src_left:src_left + copy_w]
    fitted[dst_top:dst_top + copy_h, dst_left:dst_left + copy_w] = region
    return fitted


def zoom_2D(inputImg, label, zoom_range=(0.825, 1.125)):
    """Randomly zoom an image and its label about the center, keeping the size.

    A zoom factor is drawn uniformly from ``zoom_range``. Both arrays are
    resized by that factor, then center-cropped (zoom in) or zero-padded
    (zoom out) back to the original height and width of ``inputImg``.

    Args:
        inputImg: Image array accepted by ``cv2.resize``.
        label: Label array with the same height and width as ``inputImg``.
        zoom_range: ``(low, high)`` bounds of the uniform zoom factor.

    Returns:
        ``(zoomed_image, zoomed_label)``, each with the height and width of
        ``inputImg`` and the dtype of the corresponding resized array.
    """
    src_h, src_w = inputImg.shape[:2]
    zoom = np.random.uniform(zoom_range[0], zoom_range[1])

    # Cubic when enlarging, area averaging when shrinking; the `else` branch
    # also covers zoom == 1 so the interpolation flag is always defined.
    interpolation = cv2.INTER_CUBIC if zoom > 1 else cv2.INTER_AREA
    zoomed_img = cv2.resize(src=inputImg, dsize=(0, 0), fx=zoom, fy=zoom,
                            interpolation=interpolation)
    # Nearest-neighbor keeps label values from being blended into new,
    # meaningless values at region borders.
    zoomed_label = cv2.resize(src=label, dsize=(0, 0), fx=zoom, fy=zoom,
                              interpolation=cv2.INTER_NEAREST)

    # Each axis is cropped or padded independently, so odd sizes and
    # mixed-parity shapes still produce exactly (src_h, src_w).
    return (_center_fit(zoomed_img, src_h, src_w),
            _center_fit(zoomed_label, src_h, src_w))
# Random rotation
def rot_2D(inputImg, label, rng=None):
    """Rotate an image and its label by the same random whole-degree angle.

    The angle is an integer in ``[0, 360)``. Rotation is about the array
    center and the output keeps the input height and width.

    Args:
        inputImg: Image array accepted by ``cv2.warpAffine``.
        label: Label array with the same height and width as ``inputImg``.
        rng: Optional random source exposing ``randint`` (for example a
            ``numpy.random.RandomState``). When ``None``, the global
            ``numpy.random`` state is used.

    Returns:
        ``(rotated_image, rotated_label)``.
    """
    random_source = np.random if rng is None else rng
    angle = float(random_source.randint(360))

    rows, cols = inputImg.shape[:2]
    center = ((cols - 1) / 2.0, (rows - 1) / 2.0)
    M = cv2.getRotationMatrix2D(center, angle, 1.0)

    rotated_img = cv2.warpAffine(inputImg, M, (cols, rows))
    # Nearest-neighbor keeps label values from being blended by the default
    # bilinear interpolation.
    rotated_label = cv2.warpAffine(label, M, (cols, rows), flags=cv2.INTER_NEAREST)
    return rotated_img, rotated_label
# Random flip
def flip_2D(inputImg, label):
    """Mirror an image and its label along the same randomly chosen axis.

    One draw from ``np.random.randint(2)`` decides the direction: a nonzero
    draw flips left-right (horizontal), zero flips up-down (vertical).

    Returns:
        ``(flipped_image, flipped_label)``.
    """
    flip_horizontally = bool(np.random.randint(2))
    mirror = np.fliplr if flip_horizontally else np.flipud
    return mirror(inputImg), mirror(label)