nnunet项目官方地址
准备工作
关于nnUnet代码包的安装和配置参考nn-UNet使用记录–代码配置
nnUnet最经典的部分在于数据处理,本文简单介绍nnUnet的数据读取和数据增强方法。
以nnunet/training/network_training/nnUNetTrainer.py
为例
数据读取
self.dl_tr, self.dl_val = self.get_basic_generators()
- self.dl_tr – dataloader_train
- self.dl_val – dataloader_valid
# 这里我只放了3D部分
def get_basic_generators(self):
    """Build the plain (un-augmented) 3D train/val dataloaders.

    Returns:
        tuple: (dl_tr, dl_val) — DataLoader3D instances for the training and
        validation splits respectively.
    """
    self.load_dataset()  # index the preprocessed dataset (paths only)
    self.do_split()      # split into training and validation sets
    # The training loader crops basic_generator_patch_size patches — larger than
    # the final patch_size so spatial augmentation can later crop back to
    # patch_size without sampling outside the image.
    dl_tr = DataLoader3D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size,
                         False, oversample_foreground_percent=self.oversample_foreground_percent,
                         pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
    # The validation loader needs no augmentation headroom, so it crops
    # patch_size directly.
    # NOTE(review): the quoted snippet was truncated here (unbalanced
    # parenthesis); completed with the same keyword arguments as dl_tr.
    dl_val = DataLoader3D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, False,
                          oversample_foreground_percent=self.oversample_foreground_percent,
                          pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
    return dl_tr, dl_val
- self.dataset_tr:训练集列表
- self.dataset_val:验证集列表
- self.basic_generator_patch_size:训练集的DataLoader中,做空间变换的数据集尺寸(后面会细讲)
# '/root/data/nnUNet_preprocessed/Task040_KiTS/nnUNetData_plans_v2.1_stage0' 数据集地址
def load_dataset(self):
    """Populate self.dataset with the file index of the preprocessed folder.

    Delegates to the module-level load_dataset(); only file paths are recorded
    here, the .npz arrays themselves are not read.
    """
    self.dataset = load_dataset(self.folder_with_preprocessed_data)
def load_dataset(folder, num_cases_properties_loading_threshold=1000):
    """Index the preprocessed cases in *folder* without reading the arrays.

    Each entry maps a case identifier to the paths of its ``.npz`` data file
    and ``.pkl`` properties file. The properties themselves are unpickled
    eagerly only when the dataset is small enough.

    Returns:
        OrderedDict: case identifier -> dict of file paths (and, for small
        datasets, the loaded 'properties').
    """
    # we don't load the actual data but instead return the filename to the np file.
    print('loading dataset')
    identifiers = sorted(get_case_identifiers(folder))
    dataset = OrderedDict()
    for case_id in identifiers:
        entry = OrderedDict()
        entry['data_file'] = join(folder, "%s.npz" % case_id)
        entry['properties_file'] = join(folder, "%s.pkl" % case_id)
        # NOTE(review): 'seg_from_prev_stage_file' is never set before this
        # check, so the branch is dead code in this excerpt; kept to mirror
        # the original behavior exactly.
        if entry.get('seg_from_prev_stage_file') is not None:
            entry['seg_from_prev_stage_file'] = join(folder, "%s_segs.npz" % case_id)
        dataset[case_id] = entry
    # Only unpickle every case's properties up-front for small datasets.
    if len(identifiers) <= num_cases_properties_loading_threshold:
        print('loading all case properties')
        for case_id in dataset:
            dataset[case_id]['properties'] = load_pickle(dataset[case_id]['properties_file'])
    return dataset
load_dataset
返回的是数据集的文件信息,并没有加载.npz文件
part_1:加载图像
加载预处理阶段得到的图像和信息
def generate_train_batch(self):
# self.list_of_keys是所有的病例列表(包括训练集和验证集)
selected_keys = np.random.choice(self.list_of_keys, self.batch_size, True, None)
# self.data_shape=self.seg_shape=(b,1,139,230,206),shape是setup_DA_params根据patch_size产生的
data = np.zeros(self.data_shape, dtype=np.float32)
seg = np.zeros(self.seg_shape, dtype=np.float32)
case_properties = []
for j, i in enumerate(selected_keys):
# 当 j < round(self.batch_size * (1 - self.oversample_foreground_percent)) 时,不做oversample
if self.get_do_oversample(j):
force_fg = True
else:
force_fg = False
# properties是当前病人信息,是数据预处理阶段产生的
if 'properties' in self._data[i].keys():
properties = self._data[i]['properties']
else:
properties = load_pickle(self._data[i]['properties_file'])
case_properties.append(properties)
加载当前病例的data_file,case_all_data.shape=(2, 193, 201, 201), 2个通道一个是图像,一个是标签
if isfile(self._data[i]['data_file'][:-4] + ".npy"):
case_all_data = np.load(self._data[i]['data_file'][:-4] + ".npy", self.memmap_mode)
else:
case_all_data = np.load(self._data[i]['data_file'])['data']
part_2:填充边界
# 没有做级联
# self.need_to_pad = basic_patch_size - final_patch_size
need_to_pad = self.need_to_pad.copy() # (139,230,206) - (80,160,160) = (59,70,46)
for d in range(3):
# 如果当前图像的尺寸加上pad的尺寸,还是比basic_patch_size要小,需要增加pad的尺寸
if need_to_pad[d] + case_all_data.shape[d + 1] < self.patch_size[d]:
need_to_pad[d] = self.patch_size[d] - case_all_data.shape[d + 1]
# we can now choose the bbox from -need_to_pad // 2 to shape - patch_size + need_to_pad // 2. Here we
# define what the upper and lower bound can be to then sample form them with np.random.randint
# 把当做一个三维的盒子去看,l(lower),u(upper)分别代表盒子的底部和顶部,所以一共有6个
# 最终得到的尺寸是 case_shape - basic_patch_size + need_to_pad
shape = case_all_data.shape[1:]
lb_x = - need_to_pad[0] // 2
ub_x = shape[0] + need_to_pad[0] // 2 + need_to_pad[0] % 2 - self.patch_size[0]
lb_y = - need_to_pad[1] // 2
ub_y = shape[1] + need_to_pad[1] // 2 + need_to_pad[1] % 2 - self.patch_size[1]
lb_z = - need_to_pad[2] // 2
ub_z = shape[2] + need_to_pad[2] // 2 + need_to_pad[2] % 2 - self.patch_size[2]
最终目标是得到basic_patch_size(139,230,206)大小的box
底部边界
# force_fg:可以理解为强制带有前景类(label除0以外的类)
# 如果不使用oversample,随机采样
if not force_fg:
bbox_x_lb = np.random.randint(lb_x, ub_x + 1) # 26
bbox_y_lb = np.random.randint(lb_y, ub_y + 1) # -31
bbox_z_lb = np.random.randint(lb_z, ub_z + 1) # 12
else:
# 从properties中找到前景类的位置
if 'class_locations' not in properties.keys():
raise RuntimeError("Please rerun the preprocessing with the newest version of nnU-Net!")
# this saves us a np.unique. Preprocessing already did that for all cases. Neat.
foreground_classes = np.array(
[i for i in properties['class_locations'].keys() if len(properties['class_locations'][i]) != 0])
foreground_classes = foreground_classes[foreground_classes > 0]
if len(foreground_classes) == 0:
# this only happens if some image does not contain foreground voxels at all
selected_class = None
voxels_of_that_class = None
print('case does not contain any foreground classes', i)
else:
selected_class = np.random.choice(foreground_classes)
voxels_of_that_class = properties['class_locations'][selected_class]
if voxels_of_that_class is not None:
selected_voxel = voxels_of_that_class[np.random.choice(len(voxels_of_that_class))]
# selected voxel is center voxel. Subtract half the patch size to get lower bbox voxel.
# Make sure it is within the bounds of lb and ub
bbox_x_lb = max(lb_x, selected_voxel[0] - self.patch_size[0] // 2)
bbox_y_lb = max(lb_y, selected_voxel[1] - self.patch_size[1] // 2)
bbox_z_lb = max(lb_z, selected_voxel[2] - self.patch_size[2] // 2)
else:
# If the image does not contain any foreground classes, we fall back to random cropping
bbox_x_lb = np.random.randint(lb_x, ub_x + 1)
bbox_y_lb = np.random.randint(lb_y, ub_y + 1)
bbox_z_lb = np.random.randint(lb_z, ub_z + 1)
bbox_x_ub = bbox_x_lb + self.patch_size[0]
bbox_y_ub = bbox_y_lb + self.patch_size[1]
bbox_z_ub = bbox_z_lb + self.patch_size[2]
顶部边界
bbox_x_ub = bbox_x_lb + self.patch_size[0] # 165
bbox_y_ub = bbox_y_lb + self.patch_size[1] # 199
bbox_z_ub = bbox_z_lb + self.patch_size[2] # 218
axis | lower_box | upper_box | box_size |
---|---|---|---|
x | 26 | 165 | 139 |
y | -31 | 199 | 230 |
z | 12 | 218 | 206 |
box_size和basic_patch_size是一致的
有效边界(未填充)
valid_bbox_x_lb = max(0, bbox_x_lb) # 26
valid_bbox_x_ub = min(shape[0], bbox_x_ub) # 165
valid_bbox_y_lb = max(0, bbox_y_lb) # 0
valid_bbox_y_ub = min(shape[1], bbox_y_ub) # 199
valid_bbox_z_lb = max(0, bbox_z_lb) # 12
valid_bbox_z_ub = min(shape[2], bbox_z_ub) # 201
复制+填充
case_all_data = np.copy(case_all_data[:, valid_bbox_x_lb:valid_bbox_x_ub, # shape:(139,199,201)
valid_bbox_y_lb:valid_bbox_y_ub, valid_bbox_z_lb:valid_bbox_z_ub])
# shape: (139,199,201) -> (139,230,206) 边界用0填充
data[j] = np.pad(case_all_data[:-1], ((0, 0),
(-min(0, bbox_x_lb), max(bbox_x_ub - shape[0], 0)),
(-min(0, bbox_y_lb), max(bbox_y_ub - shape[1], 0)),
(-min(0, bbox_z_lb), max(bbox_z_ub - shape[2], 0))),
self.pad_mode, **self.pad_kwargs_data)
seg[j, 0] = np.pad(case_all_data[-1:], ((0, 0),
(-min(0, bbox_x_lb), max(bbox_x_ub - shape[0], 0)),
(-min(0, bbox_y_lb), max(bbox_y_ub - shape[1], 0)),
(-min(0, bbox_z_lb), max(bbox_z_ub - shape[2], 0))),
'constant', **{
'constant_values': -1})
return {
'data': data, 'seg': seg, 'properties': case_properties, 'keys': selected_keys}
数据增强
part_1:默认参数
default_3D_augmentation_params
# Default hyper-parameters for 3D data augmentation. The "do_*" flags switch a
# transform on or off; the "p_*" entries are per-sample application probabilities.
default_3D_augmentation_params = {
    # optional channel selection for data / segmentation
    "selected_data_channels": None,
    "selected_seg_channels": None,
    # elastic deformation
    "do_elastic": True,
    "elastic_deform_alpha": (0., 900.),
    "elastic_deform_sigma": (9., 13.),
    "p_eldef": 0.2,
    # scaling
    "do_scaling": True,
    "scale_range": (0.85, 1.25),  # (0.7,1.4)
    "independent_scale_factor_for_each_axis": False,
    "p_independent_scale_per_axis": 1,
    "p_scale": 0.2,
    # rotation: +/-15 degrees around each axis, expressed in radians
    "do_rotation": True,
    "rotation_x": (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi),
    "rotation_y": (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi),
    "rotation_z": (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi),
    "rotation_p_per_axis": 1,
    "p_rot": 0.2,
    # cropping: center crop by default (random_crop=False)
    "random_crop": False,
    "random_crop_dist_to_border": None,
    # gamma (non-linear intensity) transform
    "do_gamma": True,
    "gamma_retain_stats": True,
    "gamma_range": (0.7, 1.5),
    "p_gamma": 0.3,
    # mirroring along all three spatial axes
    "do_mirror": True,
    "mirror_axes": (0, 1, 2),
    "dummy_2D": False,
    "mask_was_used_for_normalization": None,
    "border_mode_data": "constant",
    # cascade-specific settings (only relevant for cascaded training stages)
    "all_segmentation_labels": None,  # used for cascade
    "move_last_seg_chanel_to_data": False,  # used for cascade
    "cascade_do_cascade_augmentations": False,  # used for cascade
    "cascade_random_binary_transform_p": 0.4,
    "cascade_random_binary_transform_p_per_label": 1,
    "cascade_random_binary_transform_size": (1, 8),
    "cascade_remove_conn_comp_p": 0.2,
    "cascade_remove_conn_comp_max_size_percent_threshold": 0.15,
    "cascade_remove_conn_comp_fill_with_other_class_p": 0.0,
    # additive brightness (disabled by default)
    "do_additive_brightness": False,
    "additive_brightness_p_per_sample": 0.15,
    "additive_brightness_p_per_channel": 0.5,
    "additive_brightness_mu": 0.0,
    "additive_brightness_sigma": 0.1,
    # background-worker configuration for MultiThreadedAugmenter; overridable
    # via the nnUNet_n_proc_DA environment variable
    "num_threads": 12 if 'nnUNet_n_proc_DA' not in os.environ else int(os.environ['nnUNet_n_proc_DA']),
    "num_cached_per_thread": 1,
}
获取patch尺寸
- 输入到神经网络中的patch_size=(80,160,160)
- 数据增强时,对图像尺寸有影响的主要是旋转和缩放,所以要考虑这两个因素
- 可以把final_patch_size=(80,160,160),以及默认参数中的rot_x, rot_y, rot_z, scale_range带入,输出结果为(139,230,206)
def get_patch_size(final_patch_size, rot_x, rot_y, rot_z, scale_range):
    """Compute the patch size the dataloader must deliver so that rotating and
    scaling can still be cropped back to *final_patch_size* without borders.

    Each angle may be a scalar or a (min, max) range; ranges are reduced to
    their largest absolute value, and every angle is capped at 90 degrees.
    The result is the axis-wise maximum extent over single-axis rotations,
    divided by the smallest scale factor.
    """
    quarter_turn = 90 / 360 * 2. * np.pi
    capped = []
    for angle in (rot_x, rot_y, rot_z):
        if isinstance(angle, (tuple, list)):
            angle = max(np.abs(angle))
        # rotations are limited to 90 degrees at most
        capped.append(min(quarter_turn, angle))
    rot_x, rot_y, rot_z = capped
    from batchgenerators.augmentations.utils import rotate_coords_3d, rotate_coords_2d
    coords = np.array(final_patch_size)
    final_shape = np.copy(coords)
    if len(coords) == 3:
        # worst-case extent of the patch under each single-axis rotation
        for rotated in (rotate_coords_3d(coords, rot_x, 0, 0),
                        rotate_coords_3d(coords, 0, rot_y, 0),
                        rotate_coords_3d(coords, 0, 0, rot_z)):
            final_shape = np.max(np.vstack((np.abs(rotated), final_shape)), 0)
    elif len(coords) == 2:
        final_shape = np.max(np.vstack((np.abs(rotate_coords_2d(coords, rot_x)), final_shape)), 0)
    # the smallest scale factor zooms in the most and needs the largest source patch
    final_shape /= min(scale_range)
    return final_shape.astype(int)
def setup_DA_params(self):
    """Configure augmentation parameters and the enlarged dataloader patch size.

    Starts from default_3D_augmentation_params and derives
    basic_generator_patch_size — the patch size the dataloader must deliver so
    that rotation/scaling can later be cropped back to self.patch_size.
    NOTE(review): this aliases the module-level default dict without copying,
    so the assignments below mutate the shared defaults — confirm intended.
    """
    self.data_aug_params = default_3D_augmentation_params
    self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm  # False
    # enlarge the patch so rotation/scaling never samples outside the crop
    self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'],
                                                     self.data_aug_params['rotation_y'],
                                                     self.data_aug_params['rotation_z'],
                                                     self.data_aug_params['scale_range'])
    self.data_aug_params['selected_seg_channels'] = [0]
    self.data_aug_params['patch_size_for_spatialtransform'] = self.patch_size
    # wrap the plain dataloaders with the augmentation pipelines
    self.tr_gen, self.val_gen = get_default_augmentation(self.dl_tr, self.dl_val,
                                                         self.data_aug_params['patch_size_for_spatialtransform'],
                                                         self.data_aug_params)
- self.dl_tr 是训练集的dataloader, self.tr_gen增加了数据增强
- self.dl_val 是验证集的dataloader, self.val_gen增加了数据增强
- self.data_aug_params[‘patch_size_for_spatialtransform’] = patch_size = (80,160,160)
part_2: 变换方法
train_transform
- SegChannelSelectionTransform:标签如果有多个通道,可以选择一个通道。(我觉得用不到,标签一般都是单通道的)
- SpatialTransform:终极空间变换器,包括旋转、变形、缩放、裁剪。
- GammaTransform:Gamma变换,对输入图像灰度值进行非线性操作,使输出图像灰度值与输入图像灰度值呈指数关系
- MirrorTransform:镜像变换,沿着轴随机镜像翻转,每个轴默认的翻转概率是0.5
- MaskTransform:
data[mask < 0] = 0
,将mask之外(mask小于0)的部分置零 - RemoveLabelTransform:替换标签值,如
RemoveLabelTransform(-1, 0)
就是将标签为-1的替换为0 - RenameTransform:重命名data_dict,或者把data_dict中的seg部分丢掉,不算数据增强
- NumpyToTensor:顾名思义,numpy数组转为tensor
旋转、变形、缩放、裁剪,翻转,gamma变换,感觉像素值上的变换有点少,
nnUNetTrainerV2
可能会多一些。
valid_transform
没有做数据增强,只是对数据做了转换
RemoveLabelTransform,SegChannelSelectionTransform,RenameTransform,NumpyToTensor
def get_default_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params,
                             border_val_seg=-1, pin_memory=True,
                             seeds_train=None, seeds_val=None, regions=None):
    """Wrap the train/val dataloaders in multithreaded augmentation pipelines.

    Args:
        dataloader_train, dataloader_val: batch generators yielding dicts with
            'data' and 'seg' numpy arrays.
        patch_size: spatial size the SpatialTransform crops each sample to.
        params: augmentation hyper-parameters (see default_3D_augmentation_params).
        border_val_seg: fill value for segmentation outside the image bounds.
        regions: if given, label maps are converted to region-based targets.

    Returns:
        tuple: (train_augmenter, val_augmenter) MultiThreadedAugmenter instances.
    """
    assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"
    tr_transforms = []
    # optional selection of a subset of data / segmentation channels
    if params.get("selected_data_channels") is not None:
        tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))
    patch_size_spatial = patch_size
    # combined elastic deformation / rotation / scaling, then crop to patch_size
    tr_transforms.append(SpatialTransform(
        patch_size_spatial, patch_center_dist_from_border=None, do_elastic_deform=params.get("do_elastic"),
        alpha=params.get("elastic_deform_alpha"), sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"),
        angle_z=params.get("rotation_z"), do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=3, border_mode_seg="constant",
        border_cval_seg=border_val_seg,
        order_seg=1, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"),
        p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"),
        independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis")
    ))
    # non-linear intensity (gamma) augmentation
    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True, retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))
    # random mirroring along the configured axes
    if params.get("do_mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))  # (0,1,2)
    # zero out everything outside the normalization mask (seg channel 0 < 0)
    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
        tr_transforms.append(MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0))
    # map the padding label -1 back to background 0
    tr_transforms.append(RemoveLabelTransform(-1, 0))
    tr_transforms.append(RenameTransform('seg', 'target', True))
    if regions is not None:
        tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)
    # from batchgenerators.dataloading import SingleThreadedAugmenter
    # batchgenerator_train = SingleThreadedAugmenter(dataloader_train, tr_transforms)
    # import IPython;IPython.embed()
    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
                                                  params.get("num_cached_per_thread"), seeds=seeds_train,
                                                  pin_memory=pin_memory)
    # validation: no augmentation, only relabeling / renaming / tensor conversion
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))
    # cascade: append previous-stage segmentation as one-hot data channels
    if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
    val_transforms.append(RenameTransform('seg', 'target', True))
    if regions is not None:
        val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)
    # batchgenerator_val = SingleThreadedAugmenter(dataloader_val, val_transforms)
    # validation uses half as many worker threads as training
    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1),
                                                params.get("num_cached_per_thread"), seeds=seeds_val,
                                                pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
part_3: 版本二
对比一下nnUNetTrainerV2的数据增强方法:
因为不想做深监督(deep_supervision)损失,深监督相关的代码我就省略了
1.增强旋转:将旋转的角度范围从 [-15, 15] 度增加到 [-30,30] 度
2.减小缩放:缩放范围从(0.7, 1.4) 修改为 (0.85,1.25),但是源码里面好像本来就是(0.85,1.25)
注意,旋转和缩放修改之后,
setup_DA_params
中得到的basic_patch_size也会变化
做个实验,将旋转从15度增加到30度,输入同样的patch_size=(80,160,160),输出的basic_patch_size由原来的(139,230,206)变为(157,257,210)
def setup_DA_params(self):
    """Augmentation setup of nnUNetTrainerV2: wider rotations than the base trainer.

    Rotation range is extended to +/-30 degrees, and strongly anisotropic
    patches fall back to the 2D augmentation defaults. Finally derives
    basic_generator_patch_size from the updated rotation/scale ranges.
    """
    # per-resolution scale factors for deep-supervision targets
    self.deep_supervision_scales = [[1, 1, 1]] + list(list(i) for i in 1 / np.cumprod(
        np.vstack(self.net_num_pool_op_kernel_sizes), axis=0))[:-1]
    # NOTE(review): aliases the shared default dict (no copy) before mutating it
    self.data_aug_params = default_3D_augmentation_params
    # +/-30 degrees instead of the default +/-15
    self.data_aug_params['rotation_x'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
    self.data_aug_params['rotation_y'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
    self.data_aug_params['rotation_z'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
    self.do_dummy_2D_aug = False
    # anisotropic patches (longest axis > 1.5x shortest) use 2D defaults instead
    if max(self.patch_size) / min(self.patch_size) > 1.5:
        default_2D_augmentation_params['rotation_x'] = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi)
        self.data_aug_params = default_2D_augmentation_params
    self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm
    # enlarge the loader patch so the wider rotations can be cropped back losslessly
    self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'],
                                                     self.data_aug_params['rotation_y'],
                                                     self.data_aug_params['rotation_z'],
                                                     self.data_aug_params['scale_range'])
3.丢弃弹性变换
4.增加高斯噪声(GaussianNoiseTransform), 高斯模糊(GaussianBlurTransform)
5.增加亮度变换,BrightnessMultiplicativeTransform,BrightnessTransform,
6.对比度变换,ContrastAugmentationTransform
7.分辨率变换,SimulateLowResolutionTransform
果然,相比
nnUNetTrainer
,nnUNetTrainerV2
重点增加了图像体素值方面的变换。
def get_moreDA_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params,
                            border_val_seg=-1,
                            seeds_train=None, seeds_val=None, order_seg=1, order_data=3, deep_supervision_scales=None,
                            soft_ds=False,
                            classes=None, pin_memory=True, regions=None,
                            use_nondetMultiThreadedAugmenter: bool = False):
    """nnUNetTrainerV2 augmentation: the default pipeline plus intensity transforms.

    Compared with get_default_augmentation this adds Gaussian noise/blur,
    brightness and contrast changes, simulated low resolution and an inverted
    gamma transform. Deep-supervision handling was omitted in this excerpt, so
    deep_supervision_scales / soft_ds / classes are unused here.

    Returns:
        tuple: (train_augmenter, val_augmenter).
    """
    assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"
    tr_transforms = []
    # optional selection of a subset of data / segmentation channels
    if params.get("selected_data_channels") is not None:
        tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))
    patch_size_spatial = patch_size
    ignore_axes = None
    # elastic deformation / rotation / scaling, then crop to patch_size
    tr_transforms.append(SpatialTransform(
        patch_size_spatial, patch_center_dist_from_border=None,
        do_elastic_deform=params.get("do_elastic"), alpha=params.get("elastic_deform_alpha"),
        sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"),
        angle_z=params.get("rotation_z"), p_rot_per_axis=params.get("rotation_p_per_axis"),
        do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=order_data,
        border_mode_seg="constant", border_cval_seg=border_val_seg,
        order_seg=order_seg, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"),
        p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"),
        independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis")
    ))
    # we need to put the color augmentations after the dummy 2d part (if applicable). Otherwise the overloaded color
    # channel gets in the way
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
    tr_transforms.append(GaussianBlurTransform((0.5, 1.), different_sigma_per_channel=True, p_per_sample=0.2,
                                               p_per_channel=0.5))
    # multiplicative and (optionally) additive brightness changes
    tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25), p_per_sample=0.15))
    if params.get("do_additive_brightness"):
        tr_transforms.append(BrightnessTransform(params.get("additive_brightness_mu"),
                                                 params.get("additive_brightness_sigma"),
                                                 True, p_per_sample=params.get("additive_brightness_p_per_sample"),
                                                 p_per_channel=params.get("additive_brightness_p_per_channel")))
    tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15))
    # downsample then upsample again to mimic lower acquisition resolution
    tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True,
                                                        p_per_channel=0.5,
                                                        order_downsample=0, order_upsample=3, p_per_sample=0.25,
                                                        ignore_axes=ignore_axes))
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"), True, True, retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.1))  # inverted gamma
    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True, retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))
    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))
    # zero out everything outside the normalization mask (seg channel 0 < 0)
    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
        tr_transforms.append(MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0))
    # map the padding label -1 back to background 0
    tr_transforms.append(RemoveLabelTransform(-1, 0))
    tr_transforms.append(RenameTransform('seg', 'target', True))
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)
    if use_nondetMultiThreadedAugmenter:
        if NonDetMultiThreadedAugmenter is None:
            raise RuntimeError('NonDetMultiThreadedAugmenter is not yet available')
        batchgenerator_train = NonDetMultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
                                                            params.get("num_cached_per_thread"), seeds=seeds_train,
                                                            pin_memory=pin_memory)
    else:
        batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
                                                      params.get("num_cached_per_thread"),
                                                      seeds=seeds_train, pin_memory=pin_memory)
    # validation: no augmentation, only relabeling / renaming / tensor conversion
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))
    # cascade: append previous-stage segmentation as one-hot data channels
    if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
    val_transforms.append(RenameTransform('seg', 'target', True))
    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)
    if use_nondetMultiThreadedAugmenter:
        if NonDetMultiThreadedAugmenter is None:
            raise RuntimeError('NonDetMultiThreadedAugmenter is not yet available')
        batchgenerator_val = NonDetMultiThreadedAugmenter(dataloader_val, val_transforms,
                                                          max(params.get('num_threads') // 2, 1),
                                                          params.get("num_cached_per_thread"),
                                                          seeds=seeds_val, pin_memory=pin_memory)
    else:
        # validation uses half as many worker threads as training
        batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms,
                                                    max(params.get('num_threads') // 2, 1),
                                                    params.get("num_cached_per_thread"),
                                                    seeds=seeds_val, pin_memory=pin_memory)
    # batchgenerator_val = SingleThreadedAugmenter(dataloader_val, val_transforms)
    return batchgenerator_train, batchgenerator_val
MultiThreadedAugmenter的作用是将dataloader和transforms整合到一起
原train_loader的输出尺寸是basic_patch_size(139,230,206),而输入网络中的patch_size是(80,160,160),说明patch的尺寸在数据增强部分被修改了,观察数据增强部分的源码:
# patch_size_spatial = (80,160,160)
# (139,230,206)大小的图像经过空间变换(旋转、变形、缩放)之后,最后裁剪到(80,160,160)
tr_transforms.append(SpatialTransform(
patch_size_spatial, patch_center_dist_from_border=None, do_elastic_deform=params.get("do_elastic"),
alpha=params.get("elastic_deform_alpha"), sigma=params.get("elastic_deform_sigma"),
do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"),
angle_z=params.get("rotation_z"), do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=3, border_mode_seg="constant",
border_cval_seg=border_val_seg,
order_seg=1, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"),
p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"),
independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis")
))
SpatialTransform把数据增强的参数传递给augment_spatial,空间变换的源码我放在下面:
def augment_spatial(data, seg, patch_size, patch_center_dist_from_border=30,
                    do_elastic_deform=True, alpha=(0., 1000.), sigma=(10., 13.),
                    do_rotation=True, angle_x=(0, 2 * np.pi), angle_y=(0, 2 * np.pi), angle_z=(0, 2 * np.pi),
                    do_scale=True, scale=(0.75, 1.25), border_mode_data='nearest', border_cval_data=0, order_data=3,
                    border_mode_seg='constant', border_cval_seg=0, order_seg=0, random_crop=True, p_el_per_sample=1,
                    p_scale_per_sample=1, p_rot_per_sample=1, independent_scale_for_each_axis=False,
                    p_rot_per_axis: float = 1, p_independent_scale_per_axis: int = 1):
    """Spatially augment a batch: elastic deformation, rotation, scaling, crop.

    data/seg are (b, c, x, y[, z]) arrays. Each sample is resampled through a
    coordinate mesh of shape patch_size, so the output spatial size is always
    patch_size regardless of the input size. When no geometric transform fires
    for a sample, a plain random or center crop is taken instead.

    Returns:
        tuple: (data_result, seg_result); seg_result is None if seg is None.
    """
    dim = len(patch_size)
    seg_result = None
    # pre-allocate outputs at the target patch size
    if seg is not None:
        if dim == 2:
            seg_result = np.zeros((seg.shape[0], seg.shape[1], patch_size[0], patch_size[1]), dtype=np.float32)
        else:
            seg_result = np.zeros((seg.shape[0], seg.shape[1], patch_size[0], patch_size[1], patch_size[2]),
                                  dtype=np.float32)
    if dim == 2:
        data_result = np.zeros((data.shape[0], data.shape[1], patch_size[0], patch_size[1]), dtype=np.float32)
    else:
        data_result = np.zeros((data.shape[0], data.shape[1], patch_size[0], patch_size[1], patch_size[2]),
                               dtype=np.float32)
    # broadcast a scalar border distance to all spatial dims
    if not isinstance(patch_center_dist_from_border, (list, tuple, np.ndarray)):
        patch_center_dist_from_border = dim * [patch_center_dist_from_border]
    for sample_id in range(data.shape[0]):
        # zero-centered mesh of output coordinates so the transforms compose
        coords = create_zero_centered_coordinate_mesh(patch_size)
        modified_coords = False
        # elastic deformation with randomly drawn alpha/sigma
        if do_elastic_deform and np.random.uniform() < p_el_per_sample:
            a = np.random.uniform(alpha[0], alpha[1])
            s = np.random.uniform(sigma[0], sigma[1])
            coords = elastic_deform_coordinates(coords, a, s)
            modified_coords = True
        # random rotation; each axis fires independently with p_rot_per_axis
        if do_rotation and np.random.uniform() < p_rot_per_sample:
            if np.random.uniform() <= p_rot_per_axis:
                a_x = np.random.uniform(angle_x[0], angle_x[1])
            else:
                a_x = 0
            if dim == 3:
                if np.random.uniform() <= p_rot_per_axis:
                    a_y = np.random.uniform(angle_y[0], angle_y[1])
                else:
                    a_y = 0
                if np.random.uniform() <= p_rot_per_axis:
                    a_z = np.random.uniform(angle_z[0], angle_z[1])
                else:
                    a_z = 0
                coords = rotate_coords_3d(coords, a_x, a_y, a_z)
            else:
                coords = rotate_coords_2d(coords, a_x)
            modified_coords = True
        # random scaling: zoom in (<1) and zoom out (>1) are equally likely
        if do_scale and np.random.uniform() < p_scale_per_sample:
            if independent_scale_for_each_axis and np.random.uniform() < p_independent_scale_per_axis:
                sc = []
                for _ in range(dim):
                    if np.random.random() < 0.5 and scale[0] < 1:
                        sc.append(np.random.uniform(scale[0], 1))
                    else:
                        sc.append(np.random.uniform(max(scale[0], 1), scale[1]))
            else:
                if np.random.random() < 0.5 and scale[0] < 1:
                    sc = np.random.uniform(scale[0], 1)
                else:
                    sc = np.random.uniform(max(scale[0], 1), scale[1])
            coords = scale_coords(coords, sc)
            modified_coords = True
        # now find a nice center location
        if modified_coords:
            # shift the zero-centered mesh to a random or central position
            for d in range(dim):
                if random_crop:
                    ctr = np.random.uniform(patch_center_dist_from_border[d],
                                            data.shape[d + 2] - patch_center_dist_from_border[d])
                else:
                    ctr = int(np.round(data.shape[d + 2] / 2.))
                coords[d] += ctr
            # resample image (spline order order_data) and seg (order order_seg)
            for channel_id in range(data.shape[1]):
                data_result[sample_id, channel_id] = interpolate_img(data[sample_id, channel_id], coords, order_data,
                                                                     border_mode_data, cval=border_cval_data)
            if seg is not None:
                for channel_id in range(seg.shape[1]):
                    seg_result[sample_id, channel_id] = interpolate_img(seg[sample_id, channel_id], coords, order_seg,
                                                                       border_mode_seg, cval=border_cval_seg,
                                                                       is_seg=True)
        else:
            # no geometric transform fired: fall back to a plain crop
            if seg is None:
                s = None
            else:
                s = seg[sample_id:sample_id + 1]
            if random_crop:
                margin = [patch_center_dist_from_border[d] - patch_size[d] // 2 for d in range(dim)]
                d, s = random_crop_aug(data[sample_id:sample_id + 1], s, patch_size, margin)
            else:
                d, s = center_crop_aug(data[sample_id:sample_id + 1], patch_size, s)
            data_result[sample_id] = d[0]
            if seg is not None:
                seg_result[sample_id] = s[0]
    return data_result, seg_result
注意到一个参数–random_crop,控制随机裁剪或者中心裁剪
默认参数(default_3D_augmentation_params)中的random_crop=False,说明空间变换之后从图像中心裁剪到(80,160,160)
if random_crop:
margin = [patch_center_dist_from_border[d] - patch_size[d] // 2 for d in range(dim)]
d, s = random_crop_aug(data[sample_id:sample_id + 1], s, patch_size, margin)
else:
d, s = center_crop_aug(data[sample_id:sample_id + 1], patch_size, s)
nnUnet的数据增强方法主要是依赖batchgenerators包,啃了一遍之后发现也没那么复杂。
以后就可以借鉴nnUnet的数据增强方法,用在自己的图像分割项目中。