在训练网络时,若采用数据增强扩充数据集,pytorch中常用
transforms.Compose([
transforms.RandomSizedCrop(max(resize)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
方法来进行变换。transform中有很多随机变换方法,如果需要确保input和label图像进行相同的变换就会出现问题。
所以将所有的变换方法拆开如下:
p1 = random.randint(0,1)
p2 = random.randint(0,1)
transform = transforms.Compose([
transforms.ToPILImage(),
transforms.RandomCrop(128),
transforms.RandomHorizontalFlip(p1),
transforms.RandomVerticalFlip(p2)
])
seed = np.random.randint(2147483647) # make a seed with numpy generator
random.seed(seed) # apply this seed to img tranfsorms
x = transform(x)
random.seed(seed) # apply this seed to img tranfsorms
t = transform(t)
随机给定p1 p2(0或1),分别应用于transforms中的两个概率随机变换,保证input (x) 和label(t)同时变换(概率为1)或同时不变(概率为0),而对于randomcrop这种不依赖与概率的随机变换,可以在每次应用函数前设置相同的种子来确保随机裁剪区域相同。