diffusers中的controlnet训练

train_controlnet.py

accelerator = Accelerator()->

tokenizer = AutoTokenizer.from_pretrained(,"tokenizer")->
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
noise_scheduler = DDPMScheduler.from_pretrained(,'scheduler')
text_encoder = text_encoder_cls.from_pretrained(,'text_encoder')
vae = AutoencoderKL.from_pretrained(,'vae')
unet = UNet2DConditionModel.from_pretrained(,'unet')

if controlnet_model_name_or_path:
    controlnet = ControlNetModel.from_pretrained(controlnet_model_name_or_path)
else:
    controlnet = ControlNetModel.from_unet(unet)

vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder.requires_grad_(False)
controlnet.train()

optimizer_class = torch.optim.AdamW
params_to_optimize = controlnet.parameters()
optimizer = optimizer_class(params_to_optimize,lr=args.learning_rate,betas=(args.adam_beta1,args.adam_beta2),weight_decay=args.adam_weight_decay,eps=args.adam_epsilon)

train_dataset = make_train_dataset(args,tokenizer,accelerator)
- dataset = load_dataset(args.train_data_dir,cache_dir)
-- builder_instance = load_dataset_builder(path,name,data_dir,data_files,...)
--- dataset_module = dataset_module_factory(path,download_config,...)
-- builder_instance.download_and_prepare()
-- ds = builder_instance.as_dataset()
- column_names = dataset['train'].column_names
- image_column = "image"
- caption_column = "text"
- conditioning_image_column = "conditioning_image"
- image_transforms = transforms.Compose([transforms.Resize(resolution,transforms.InterpolationMode.BILINEAR),transforms.CenterCrop(resolution),transforms.ToTensor(),transforms.Normalize([0.5],[0.5])])
- conditioning_image_transforms = transforms.Compose([transforms.Resize(),transforms.CenterCrop(),transforms.ToTensor()])
- train_dataset = dataset['train'].with_transform(preprocess_train)
- images = [image.convert("RGB") for image in examples[image_column]]
- images = [image_transforms(image) for image in images]        
- conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
- conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
- examples["pixel_values"] = images
- examples["conditioning_pixel_values"] = conditioning_images
- examples['input_ids'] = tokenize_captions(examples)
- inputs = tokenizer(captions,max_length=tokenizer.model_max_length,padding='max_length',truncation=True,return_tensors='pt')
train_dataloader = torch.utils.data.DataLoader(train_dataset,shuffle=True,collate_fn=collate_fn,batch_size=args.train_batch_size)

lr_scheduler = get_scheduler(args.lr_scheduler,optimizer=optimizer,num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,num_training_steps=args.max_train_steps * accelerator.num_processes,num_cycles=args.lr_num_cycles,power=args.lr_power,)

controlnet,optimizer,train_dataloader,lr_scheduler = accelerator.prepare(controlnet,optimizer,train_dataloader,lr_scheduler)

num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

for epoch in range(first_epoch,args.num_train_epochs):
    for step,batch in enumerate(train_dataloader):
        latents = vae.encode(batch['pixel_values'].to()).latent_dist.sample()
        latents = latents*vae.config.scaling_factor
        
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        timesteps = torch.randint(0,noise_scheduler.config.num_train_timesteps,(bsz,))
        
        noisy_latents = noise_scheduler.add_noise(latents,noise,timesteps)
        encoder_hidden_states = text_encoder(batch['input_ids'])[0]
        controlnet_image = batch['conditioning_pixel_values']
        down_block_res_samples,mid_block_res_sample = controlnet(noisy_latents,timesteps,encoder_hidden_states=encoder_hidden_states,controlnet_cond=controlnet_image,return_dict=False)
        model_pred = unet(noisy_latents,timesteps,encoder_hidden_states,down_block_additional_residuals=[sample for sample in down_block_res_samples],mid_block_additional_residual=mid_block_res_sample).sample
        
        target = noise
        loss = F.mse_loss(model_pred.float(),target.float())
        
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

diffusers/models/controlnet.py

forward->
0.check channel order and prepare attention mask

1.timesteps
t_emb = self.time_proj(timesteps)->
emb = self.time_embedding(t_emb,timestep_cond)
if self.class_embedding is not None:
    class_emb = self.class_embedding(class_labels)
    emb = emb+class_emb

2.pre-process
sample = self.conv_in(sample) # sample:noisy input,进行卷积操作
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) # controlnet_cond:条件输入
sample = sample+controlnet_cond # 将噪声和条件输入融合

3.down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
    sample,res_samples = downsample_block(hidden_states=sample,temb=emb)
    down_block_res_samples += res_samples

4.mid
if self.mid_block is not None:
    sample = self.mid_block(sample,emb,encoder_hidden_states,attention_mask,cross_attention_kwargs)

5.control net blocks
controlnet_down_block_res_samples = ()
for down_block_res_sample,controlnet_block in zip(down_block_res_samples,self.controlnet_down_blocks):
    down_block_res_sample = controlnet_block(down_block_res_sample)
    controlnet_down_block_res_samples = controlnet_down_block_res_samples+(down_block_res_sample,)
down_block_res_samples = controlnet_down_block_res_samples
mid_block_res_sample = self.controlnet_mid_block(sample)

6.scaling
down_block_res_samples = [sample*conditioning_scale for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample*conditioning_scale

这里和上面的结构图可能有些区别:从代码上看,训练之后的controlnet输出的残差实际上是加在UNet的encoder(也就是下采样部分)的各层输出以及mid block上,这点在知乎文章《ControlNet如何为扩散模型添加额外模态的引导信息》中也说了。

在diffusers中选择的数据集也是作者提供的fill50K,其中是圆圈图像以及相应的canny图。但是在本地加载时需要注意:

应该使用fill50k.py作为train_data_dir参数的值。

墨滴社区:https://mdnice.com/writing/f404f0a6ee6547f8b9df3a098b349619

猜你喜欢

转载自blog.csdn.net/u012193416/article/details/132986044