一. Dataset处理
1. 导入torch的库
import torchvision.datasets as dset
import torchvision.transforms as transforms
2. dset.xxx函数
例如 :
dataset = dset.CIFAR10(root='../data/', download=True, transform=None)
解释 :
将相对目录../data下的cifar-10-batches-py文件夹中的全部数据(50000张图片作为训练数据)加载到内存中,若download为True时,会自动从网上下载数据并解压
补充 :
(1) root,表示cifar10数据的加载的相对目录
(2) train,表示是否加载数据库的训练集,False的时候加载测试集
(3) download,表示是否自动下载cifar数据集
(4) transform,表示是否需要对数据进行预处理,none为不进行预处理
3. transform预处理
例如 :
transform=transforms.Compose([ transforms.Scale(opt.imageSize),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])
解释 :
(1) transform.Scale(size) 定义图片大小
(2) transform.ToTensor()
将 PIL.Image/numpy.ndarray 数据转化为torch.FloatTensor,并归一化到[0, 1.0]
(3) transform.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
通过下面公式实现数据归一化
channel=(channel-mean)/std
4. loader
例如 :
loader = torch.utils.data.DataLoader( dataset = dataset,
batch_size = opt.batchSize,
shuffle = True)
解释 :
(1) 第一个参数transformed_dataset,即已经用了transform的Dataset实例。
(2) 第二个参数batch_size,表示每个batch包含多少个数据。
(3) 第三个参数shuffle,布尔型变量,表示是否打乱。
(4) 第四个参数num_workers表示使用几个线程来加载数据
5. 数据集补充
(1) CIFAR-10
CIFAR-10是多伦多大学提供的图片数据库,图片分辨率压缩至32x32,一共有10种图片分类,均进行了标注。适合监督式学习。
…
6. 完整代码
############### DATASET ##################
# Build the training dataset; both branches download into ../data/ on first use.
# NOTE: the original used transforms.Scale, which was renamed to Resize and
# later removed from torchvision — Resize is the drop-in replacement.
if opt.dataset == 'CIFAR':
    # CIFAR-10: RGB images; per-channel Normalize maps [0, 1] -> [-1, 1],
    # matching the Tanh output range of the generator.
    dataset = dset.CIFAR10(root='../data/', download=True,
                           transform=transforms.Compose([
                               transforms.Resize(opt.imageSize),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
else:
    # MNIST: grayscale; only rescaled to [0, 1] by ToTensor().
    dataset = dset.MNIST(root='../data/',
                         transform=transforms.Compose([
                             transforms.Resize(opt.imageSize),
                             transforms.ToTensor(),
                         ]),
                         download=True)
# Shuffled mini-batch loader over whichever dataset was selected above.
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=opt.batchSize,
                                     shuffle=True)
二. Model设计
1. Generator
class Generator(nn.Module):
    """DCGAN generator: maps a latent vector of shape (nz, 1, 1) to an
    image of shape (nc, 32, 32) with values in [-1, 1] (Tanh output).

    Args:
        nc:  number of output image channels.
        ngf: base width of the generator feature maps.
        nz:  dimensionality of the input noise vector.
    """

    def __init__(self, nc, ngf, nz):
        super(Generator, self).__init__()
        # Each stage doubles spatial resolution (except the first, 1x1 -> 4x4).
        self.layer1 = self._up_block(nz, ngf * 4)                        # -> 4 x 4
        self.layer2 = self._up_block(ngf * 4, ngf * 2, stride=2, padding=1)  # -> 8 x 8
        self.layer3 = self._up_block(ngf * 2, ngf, stride=2, padding=1)      # -> 16 x 16
        # Final stage: no batch-norm; Tanh squashes pixels into [-1, 1].
        self.layer4 = nn.Sequential(
            nn.ConvTranspose2d(ngf, nc, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )                                                                # -> 32 x 32

    @staticmethod
    def _up_block(c_in, c_out, **conv_kwargs):
        # Transposed-conv upsampling stage: ConvTranspose2d -> BatchNorm -> ReLU.
        return nn.Sequential(
            nn.ConvTranspose2d(c_in, c_out, kernel_size=4, **conv_kwargs),
            nn.BatchNorm2d(c_out),
            nn.ReLU(),
        )

    def forward(self, x):
        """Run the four upsampling stages in sequence."""
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return x
2. Discriminator
class Discriminator(nn.Module):
    """DCGAN discriminator: maps an (nc, 32, 32) image to a single
    real/fake probability of shape (1, 1, 1) via a Sigmoid head.

    Args:
        nc:  number of input image channels.
        ndf: base width of the discriminator feature maps.
    """

    def __init__(self, nc, ndf):
        super(Discriminator, self).__init__()
        # Each stage halves spatial resolution: 32 -> 16 -> 8 -> 4.
        self.layer1 = self._down_block(nc, ndf)
        self.layer2 = self._down_block(ndf, ndf * 2)
        self.layer3 = self._down_block(ndf * 2, ndf * 4)
        # Head: a 4x4 conv collapses the 4x4 map to 1x1; Sigmoid -> probability.
        self.layer4 = nn.Sequential(
            nn.Conv2d(ndf * 4, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid(),
        )

    @staticmethod
    def _down_block(c_in, c_out):
        # Strided-conv downsampling stage: Conv2d -> BatchNorm -> LeakyReLU(0.2).
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(c_out),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        """Run the three downsampling stages and the classification head."""
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return x
3. model调用
############### MODEL ####################
# Network width hyperparameters come from the CLI options.
ndf = opt.ndf
ngf = opt.ngf
# CIFAR-10 images are RGB (3 channels); the MNIST branch is grayscale (1).
nc = 3 if opt.dataset == 'CIFAR' else 1
netD = Discriminator(nc, ndf)
netG = Generator(nc, ngf, opt.nz)
# Move both networks to the GPU when requested.
if opt.cuda:
    netD.cuda()
    netG.cuda()
源代码网址:https://github.com/sunshineatnoon/Paper-Implementations/tree/master/dcgan