diffusion 之 mnist 数据集
代码出处:https://github.com/abarankab/DDPM
wandb的问题解决方法:
step1: 按照这个https://blog.csdn.net/weixin_43164054/article/details/124156206一步步走 step2: 修改project_name=“cifar”,然后执行python train_cifar.py 若出现报错"wandb: ERROR It appears that you do not have permission to access the requested resource.",参看这个https://blog.csdn.net/weixin_43835996/article/details/126955917
cifar10数据集
配置好wandb,按照github上的源代码
将DDPM/scripts/train_mnist.py中的entity='treaptofun'
去掉
run = wandb.init(
project=args.project_name,
config=vars(args),
name=args.run_name,
)
# entity='treaptofun',
然后就可以正常进行训练了
mnist数据集
对于mnist数据集需要修改如下两个文件
ddpm/script_utils.py
line 90:img_channel=1,因为cifar图片为3通道,而mnist图片为1通道
line 101: initial_pad=2, 是因为cifar数据集的图片大小为32,为2的指数倍,降采样过程中除以2的话一直能整除;而mnist的图片大小为28,所以要padding为32,即设置initial_pad=2
line 120:cifar10 的图片大小为3232, mnist的图片大小为2828,
import argparse
import torchvision
import torch.nn.functional as F
from .unet import UNet
from .diffusion import (
GaussianDiffusion,
generate_linear_schedule,
generate_cosine_schedule,
)
def cycle(dl):
"""
https://github.com/lucidrains/denoising-diffusion-pytorch/
"""
while True:
for data in dl:
yield data
def get_transform():
class RescaleChannels(object):
def __call__(self, sample):
return 2 * sample - 1
return torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
RescaleChannels(),
])
def str2bool(v):
"""
https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
"""
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("boolean value expected")
def add_dict_to_argparser(parser, default_dict):
"""
https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/script_util.py
"""
for k, v in default_dict.items():
v_type = type(v)
if v is None:
v_type = str
elif isinstance(v, bool):
v_type = str2bool
parser.add_argument(f"--{
k}", default=v, type=v_type)
def diffusion_defaults():
defaults = dict(
num_timesteps=1000,
schedule="linear",
loss_type="l2",
use_labels=False,
base_channels=128,
channel_mults=(1, 2, 2, 2),
num_res_blocks=2,
time_emb_dim=128 * 4,
norm="gn",
dropout=0.1,
activation="silu",
attention_resolutions=(1,),
ema_decay=0.9999,
ema_update_rate=1,
)
return defaults
def get_diffusion_from_args(args):
activations = {
"relu": F.relu,
"mish": F.mish,
"silu": F.silu,
}
# base_channels=128
model = UNet(
img_channels=1,
base_channels=args.base_channels,
channel_mults=args.channel_mults,
time_emb_dim=args.time_emb_dim,
norm=args.norm,
dropout=args.dropout,
activation=activations[args.activation],
attention_resolutions=args.attention_resolutions,
num_classes=None if not args.use_labels else 10,
initial_pad=2,
)
# line102 在cifar中为initial_pad=0,
if args.schedule == "cosine":
betas = generate_cosine_schedule(args.num_timesteps)
else:
betas = generate_linear_schedule(
args.num_timesteps,
args.schedule_low * 1000 / args.num_timesteps,
args.schedule_high * 1000 / args.num_timesteps,
)
# 本py文件共修改了3处:line 90 ; line 101 ;line 120.
# model, (32, 32), 3, 10,
# cifar10 的图片大小为32*32,3channel, mnist的图片大小为28*28,1channel
diffusion = GaussianDiffusion(
model, (28, 28), 1, 10,
betas,
ema_decay=args.ema_decay,
ema_update_rate=args.ema_update_rate,
ema_start=2000,
loss_type=args.loss_type,
)
return diffusion
scripts/train_mnist.py
把entity=‘treaptofun’,给去掉
import argparse
import datetime
import torch
import wandb
from torch.utils.data import DataLoader
from torchvision import datasets
from ddpm import script_utils
def main():
args = create_argparser().parse_args()
device = args.device
try:
diffusion = script_utils.get_diffusion_from_args(args).to(device)
optimizer = torch.optim.Adam(diffusion.parameters(), lr=args.learning_rate)
# 接着上次中断保存的参数继续训练
if args.model_checkpoint is not None:
diffusion.load_state_dict(torch.load(args.model_checkpoint))
if args.optim_checkpoint is not None:
optimizer.load_state_dict(torch.load(args.optim_checkpoint))
if args.log_to_wandb:
if args.project_name is None:
raise ValueError("args.log_to_wandb set to True but args.project_name is None")
# wandb.init(project="ddpm_cifar")
run = wandb.init(
project=args.project_name,
config=vars(args),
name=args.run_name,
)
# entity='treaptofun',
wandb.watch(diffusion)
batch_size = args.batch_size
train_dataset = datasets.MNIST(
root='../dataset/mnist/mnist_train',
train=True,
download=True,
transform=script_utils.get_transform(),
)
test_dataset = datasets.MNIST(
root='../dataset/mnist/mnist_test',
train=False,
download=True,
transform=script_utils.get_transform(),
)
train_loader = script_utils.cycle(DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
drop_last=True,
num_workers=2,
))
test_loader = DataLoader(test_dataset, batch_size=batch_size, drop_last=True, num_workers=2)
acc_train_loss = 0
for iteration in range(1, args.iterations + 1):
diffusion.train()
x, y = next(train_loader)
x = x.to(device)
y = y.to(device)
if args.use_labels:
loss = diffusion(x, y)
else:
loss = diffusion(x)
acc_train_loss += loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
diffusion.update_ema()
if iteration % args.log_rate == 0:
test_loss = 0
with torch.no_grad():
diffusion.eval()
for x, y in test_loader:
x = x.to(device)
y = y.to(device)
if args.use_labels:
loss = diffusion(x, y)
else:
loss = diffusion(x)
test_loss += loss.item()
if args.use_labels:
samples = diffusion.sample(10, device, y=torch.arange(10, device=device))
else:
samples = diffusion.sample(10, device)
samples = ((samples + 1) / 2).clip(0, 1).permute(0, 2, 3, 1).numpy()
test_loss /= len(test_loader)
acc_train_loss /= args.log_rate
wandb.log({
"test_loss": test_loss,
"train_loss": acc_train_loss,
"samples": [wandb.Image(sample) for sample in samples],
})
acc_train_loss = 0
if iteration % args.checkpoint_rate == 0:
model_filename = f"{
args.log_dir}/{
args.project_name}-{
args.run_name}-iteration-{
iteration}-model.pth"
optim_filename = f"{
args.log_dir}/{
args.project_name}-{
args.run_name}-iteration-{
iteration}-optim.pth"
torch.save(diffusion.state_dict(), model_filename)
torch.save(optimizer.state_dict(), optim_filename)
if args.log_to_wandb:
run.finish()
except KeyboardInterrupt:
if args.log_to_wandb:
run.finish()
print("Keyboard interrupt, run finished early")
def create_argparser():
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
run_name = datetime.datetime.now().strftime("ddpm-%Y-%m-%d-%H-%M")
defaults = dict(
learning_rate=2e-4,
batch_size=128,
iterations=80000,
log_to_wandb=True,
log_rate=1000,
checkpoint_rate=1000,
log_dir="./ddpm_logs_mnist",
project_name="mnist",
run_name=run_name,
model_checkpoint=None,
optim_checkpoint=None,
schedule_low=1e-4,
schedule_high=0.02,
device=device,
)
defaults.update(script_utils.diffusion_defaults())
parser = argparse.ArgumentParser()
script_utils.add_dict_to_argparser(parser, defaults)
return parser
if __name__ == "__main__":
main()
命令行执行的训练命令:
python train.py
命令行执行的采样命令
python sample_images.py --model_path "your model path" --save_dir "your save img path" --schedule cosine
展示采样结果
import matplotlib.pyplot as plt
import numpy as np
import os
def show(num_imgs, dir_path):
'''
num_imgs: 要展示的图片的张数
dir_path:图片的路径
'''
img_names=os.listdir (dir_path)
img_names.sort(key=lambda x:int(x.split('.')[0]))
plt.figure(figsize=(20,5)) # 画布大小
N=2
M=10
#形成NxM大小的画布
for i in range(num_imgs):#有张图片
path = dir_path + img_names[i]
img = plt.imread(path)
plt.subplot(N,M,i+1)#表示第i张图片,下标只能从1开始,不能从0,
plt.imshow(img)
plt.title(img_names[i],color='black')
#下面两行是消除每张图片自己单独的横纵坐标,不然每张图片会有单独的横纵坐标,影响美观
plt.xticks([])
plt.yticks([])
plt.show()
print("mnist generation results:")
show(20, './scripts/save_dir_mnist/') # 模型训练出来的保存的结果
这里的名字只是预测出来的图片的序号,并不是预测的label!
无label的训练和采样过程
训练过程:
def get_losses(self, x, t, y):
noise = torch.randn_like(x)
perturbed_x = self.perturb_x(x, t, noise)
estimated_noise = self.model(perturbed_x, t, y) # 输入到Model的是加噪后的图片
# 这个model预测出来的噪声是每个像素点位置上的噪声!!!
# 因为这个model的output的形状和x是一样的,[batch, img_channel, h, w]
if self.loss_type == "l1":
loss = F.l1_loss(estimated_noise, noise)
elif self.loss_type == "l2":
loss = F.mse_loss(estimated_noise, noise)
return loss
- x: (batch_size, img_channel, h, w)
- t: (batch_size, )
在区间[0, num_timesteps]里面随机生成b个时间,扩散过程并不是逐步进行的,t 是一个大小为batch的张量
说明一下:这个t是怎么加入到图片x中的
t最开始为(batch_size,)形状的张量,它经过linear,变成了(batch_size, img_channel),然后经过扩维,变成(batch_size, img_channel, 1, 1),经过广播机制就可以加入 x: (batch_size, img_channel, h, w)里面,即x的同一个channel上的所有像素点加入的t的值是一样的。
- perturb_x: 根据公式 x t = α t ˉ . x 0 + 1 − α t ˉ . z x_t = \sqrt{\bar{\alpha_t}}.x_0 + \sqrt{1 - \bar{\alpha_t}}.z xt=αtˉ.x0+1−αtˉ.z对 x 0 x_0 x0进行加噪,perturbed_x的形状为(batch_size, img_channel, h, w)
- model(perturbed_x, t, y) : 输入加噪后的图片,和对应的时间t,model预测出来的是加入的噪声,通过对perturbed_x进行卷积,激活,降采样,上采样一通操作,最终model输出的形状仍为(batch_size, img_channel, h, w),model的output就是预测加入的噪声。那这里预测的噪声就是预测出来的是加入到每个像素点位置上的噪声!
- 用l1或l2损失函数来计算损失。
采样过程
@torch.no_grad()
def sample(self, batch_size, device, y=None, use_ema=True):
if y is not None and batch_size != len(y):
raise ValueError("sample batch size different from length of given y")
x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)
for t in range(self.num_timesteps - 1, -1, -1): # 从T=[t-1]到T=[0]
t_batch = torch.tensor([t], device=device).repeat(batch_size)
x = self.remove_noise(x, t_batch, y, use_ema)
if t > 0:
x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)
return x.cpu().detach()
x t − 1 = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) + σ t z \mathbf{x}_{t-1}=\frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}-\frac{1-\alpha_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \boldsymbol{\epsilon}_{\theta}\left(\mathbf{x}_{t}, t\right)\right)+\sigma_{t} \mathbf{z} xt−1=αt1(xt−1−αˉt1−αtϵθ(xt,t))+σtz
- x: 随机生成噪声作为初始值,batch_size就是你想生成的图片的张数,比如你想产生1k张图片
- t_batch: 就是说对x的去噪是批处理进行的,我们的目的是 x T , x T − 1 , x T − 2 . . . x 1 , x 0 x_T, x_{T-1},x_{T-2}...x_1,x_0 xT,xT−1,xT−2...x1,x0, 因为x是有batch_size个,t_batch就是让这batch_size张图片同时去噪
- remove_noise: 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) \frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}-\frac{1-\alpha_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \boldsymbol{\epsilon}_{\theta}\left(\mathbf{x}_{t}, t\right)\right) αt1(xt−1−αˉt1−αtϵθ(xt,t))
- if t>0: 就是加上一个随机噪声 σ t z \sigma_{t} \mathbf{z} σtz,为什么在采样的过程中还要加上一个随机噪声呢?为了模拟布朗运动的随机性,当t=0时,说明已经到了 x 0 x_0 x0了,即最后一步得到原图了,对于原图就不需要再加噪声了!
有条件的训练和采样
训练
有条件的训练过程就是把标签y 也加入到图片中去进行训练
self.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else None
out += self.class_bias(y)[:, :, None, None]
y:就是标签,通过nn.Embedding可以把y表示成[batch_size, out_channels],再通过[:, :, None, None]来进行扩维,将y变为[batch_size, out_channels, 1, 1], 然后加入到经过各种操作处理后的x中(也即加入到out中)
这里的对y的操作很像对时间t的操作
采样
if args.use_labels:
for label in range(10):
# 这个就是假设每一类的数量都是一样的,所以在生成标签的时候,每一类的标签y的数量是一样的
# 比如我们想生成1k个图片,label一共有10种,所以每一类有100张
y = torch.ones(args.num_images // 10, dtype=torch.long, device=device) * label
samples = diffusion.sample(args.num_images // 10, device, y=y)
for image_id in range(len(samples)):
image = ((samples[image_id] + 1) / 2).clip(0, 1)
torchvision.utils.save_image(image, f"{
args.save_dir}/{
label}-{
image_id}.png")
采样的过程就是去噪的过程,这个去除的噪声的大小就是用我们训练好的模型预测出来的噪声,因为对于有条件的生成,我们在训练的过程中是加入了label的,所以在生成的时候我们也可以加入label,来指定噪声图片一步步去噪得到 x 0 x_0 x0,那这个 x 0 x_0 x0就更有可能属于指定的label的类别。
有条件生成和无条件生的对比
假设:原始的训练数据集中有猫,狗,猪,三类,这三类的占比分别为0.2. 0.3 0.5
-
有条件生成:
我们可以指定生成哪一类,比如生成1k张图片,我们指定label=猫,那生成的1k张图片大约999+张都是猫 -
无条件生成:
不能指定生成哪一类,比如生成1k张图片,这1k张图片大约有200张是猫,300张是狗,500张是猪