| | import os |
# Request extra torch.distributed / DDP debug logging. This runs at import time,
# i.e. before main() builds the Accelerator and any process group exists.
os.environ.update(TORCH_DISTRIBUTED_DEBUG="INFO")
| |
|
| | import gc |
| | import lpips |
| | import clip |
| | import random |
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | import torch.utils.checkpoint |
| | import transformers |
| |
|
| | from omegaconf import OmegaConf |
| | from accelerate import Accelerator |
| | from accelerate.utils import set_seed |
| | from PIL import Image |
| | from torchvision import transforms |
| | from tqdm.auto import tqdm |
| |
|
| | import diffusers |
| | from diffusers.utils.import_utils import is_xformers_available |
| | from diffusers.optimization import get_scheduler |
| |
|
| | from de_net import DEResNet |
| | from s3diff import S3Diff |
| | from my_utils.training_utils import parse_args_paired_training, PairedDataset, degradation_proc |
| |
|
def _collect_trainable_params(net_sr):
    """Gather every S3Diff parameter that the generator optimizer updates.

    Order: the VAE/UNet block embeddings, the ``*_de_mlp``, ``*_block_mlp`` and
    ``*_fuse_mlp`` modules (VAE before UNet in each pair), every UNet parameter
    whose name contains ``"lora"``, the UNet ``conv_in`` layer, and finally every
    VAE parameter whose name contains ``"lora"``.

    Raises:
        AssertionError: if a LoRA parameter does not require grad.
    """
    params = []
    for module in (
        net_sr.vae_block_embeddings, net_sr.unet_block_embeddings,
        net_sr.vae_de_mlp, net_sr.unet_de_mlp,
        net_sr.vae_block_mlp, net_sr.unet_block_mlp,
        net_sr.vae_fuse_mlp, net_sr.unet_fuse_mlp,
    ):
        params.extend(module.parameters())

    for name, param in net_sr.unet.named_parameters():
        if "lora" in name:
            assert param.requires_grad, f"LoRA parameter {name} is frozen"
            params.append(param)
    params.extend(net_sr.unet.conv_in.parameters())

    for name, param in net_sr.vae.named_parameters():
        if "lora" in name:
            assert param.requires_grad, f"LoRA parameter {name} is frozen"
            params.append(param)
    return params


def _resolve_weight_dtype(mixed_precision):
    """Map accelerate's ``mixed_precision`` string to the dtype used for model weights.

    ``"fp16"`` -> ``torch.float16``, ``"bf16"`` -> ``torch.bfloat16``; anything else
    (e.g. ``"no"``) -> ``torch.float32``.
    """
    if mixed_precision == "fp16":
        return torch.float16
    if mixed_precision == "bf16":
        return torch.bfloat16
    return torch.float32


def _upcast_trainable_params(params):
    """Cast every parameter in *params* that requires grad to float32, in place.

    Reassigning ``param.data`` keeps the same ``Parameter`` objects, so optimizers
    that already reference them remain valid.
    """
    for param in params:
        if param.requires_grad:
            param.data = param.data.to(torch.float32)


def _is_trigger_step(global_step, every):
    """Return True on steps 1, every + 1, 2 * every + 1, ...

    Same schedule as ``global_step % every == 1`` for ``every > 1``, but unlike that
    test it also fires when ``every == 1``. A non-positive ``every`` disables it.
    """
    return every > 0 and (global_step - 1) % every == 0


def _run_validation(args, config, accelerator, net_sr, net_de, net_lpips, dl_val, global_step):
    """Evaluate the generator on up to ``args.num_samples_eval`` validation batches.

    Intended to run on the main process only. The models are unwrapped so these
    forward passes never go through a DDP wrapper, which may issue collectives that
    the other ranks would not join.

    Returns:
        dict with ``"val/l2"`` (mean MSE) and ``"val/lpips"`` (mean LPIPS) as floats;
        both are nan when no batch was evaluated.

    Side effects:
        When ``args.save_val`` is set, writes up to 5 input|prediction|target
        images to ``<output_dir>/eval/val_<global_step>_<index>.png``.
    """
    model_sr = accelerator.unwrap_model(net_sr)
    model_de = accelerator.unwrap_model(net_de)
    model_lpips = accelerator.unwrap_model(net_lpips)

    l2_scores, lpips_scores = [], []
    num_saved = 0
    for val_step, batch_val in enumerate(dl_val):
        if val_step >= args.num_samples_eval:
            break
        x_src, x_tgt, x_ori_size_src = degradation_proc(config, batch_val, accelerator.device)
        B, _, _, _ = x_src.shape
        assert B == 1, "Use batch size 1 for eval."

        with torch.no_grad():
            deg_score = model_de(x_ori_size_src.detach())
            pos_tag_prompt = [args.pos_prompt for _ in range(B)]
            x_tgt_pred = model_sr(x_src.detach(), deg_score, pos_tag_prompt)
            val_l2 = F.mse_loss(x_tgt_pred.float(), x_tgt.detach().float(), reduction="mean")
            val_lpips = model_lpips(x_tgt_pred.float(), x_tgt.detach().float()).mean()

        l2_scores.append(val_l2.item())
        lpips_scores.append(val_lpips.item())

        if args.save_val and num_saved < 5:
            # Undo the (assumed) [-1, 1] normalisation so ToPILImage gets [0, 1].
            x_src_vis = x_src.cpu().detach() * 0.5 + 0.5
            x_tgt_vis = x_tgt.cpu().detach() * 0.5 + 0.5
            x_pred_vis = x_tgt_pred.cpu().detach() * 0.5 + 0.5

            # Concatenate along dim 3 (width, assuming NCHW): input | prediction | target.
            combined = torch.cat([x_src_vis, x_pred_vis, x_tgt_vis], dim=3)
            output_pil = transforms.ToPILImage()(combined[0])
            # Save into the eval/ directory created by main(); include global_step
            # so successive evaluation rounds do not overwrite each other.
            outf = os.path.join(args.output_dir, "eval", f"val_{global_step}_{val_step}.png")
            output_pil.save(outf)
            num_saved += 1

    gc.collect()
    torch.cuda.empty_cache()
    return {
        "val/l2": float(np.mean(l2_scores)) if l2_scores else float("nan"),
        "val/lpips": float(np.mean(lpips_scores)) if lpips_scores else float("nan"),
    }


def main(args):
    """Train S3Diff's degradation-guided adapters on paired data.

    Each iteration performs three updates inside one ``accelerator.accumulate``
    context:

    1. A reconstruction step on ``lambda_l2 * L2 + lambda_lpips * LPIPS``. Per
       sample, with probability ``args.neg_prob``, the negative prompt is used and
       the target is the degraded input itself.
    2. An adversarial generator step through the vision-aided discriminator.
    3. Two discriminator steps, on real targets and on detached predictions.

    On the main process, checkpoints and validation run on steps 1, k + 1, 2k + 1, ...
    for ``k = args.checkpointing_steps`` / ``args.eval_freq``.

    Args:
        args: parsed command-line options (from ``parse_args_paired_training``).
    """
    config = OmegaConf.load(args.base_config)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
    )
    # `log_with` only selects tracker backends; `accelerator.log` does nothing until
    # a run is started with `init_trackers`. Done before set_seed so tracker start-up
    # cannot perturb the seeded RNG streams.
    if accelerator.is_main_process:
        tracker_config = {
            key: value if isinstance(value, (bool, int, float, str)) else str(value)
            for key, value in vars(args).items()
        }
        # NOTE(review): `tracker_project_name` may not be a CLI option; fall back
        # to a fixed name when it is absent.
        project_name = getattr(args, "tracker_project_name", None) or "s3diff"
        accelerator.init_trackers(project_name, config=tracker_config)

    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    if accelerator.is_main_process:
        os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True)
        os.makedirs(os.path.join(args.output_dir, "eval"), exist_ok=True)

    # Degradation estimator: its scores condition net_sr. It is only ever run under
    # torch.no_grad() and is never optimized, so freeze it. Accelerate then has no
    # reason to wrap it in DDP.
    net_de = DEResNet(num_in_ch=3, num_degradation=2)
    net_de.load_model(args.de_net_path)
    net_de = net_de.cuda()
    net_de.eval()
    net_de.requires_grad_(False)

    net_sr = S3Diff(lora_rank_unet=args.lora_rank_unet, lora_rank_vae=args.lora_rank_vae,
                    sd_path=args.sd_path, pretrained_path=args.pretrained_path)
    net_sr.set_train()

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            net_sr.unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available, please install it by running `pip install xformers`")

    if args.gradient_checkpointing:
        net_sr.unet.enable_gradient_checkpointing()

    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.gan_disc_type == "vagan":
        import vision_aided_loss
        net_disc = vision_aided_loss.Discriminator(cv_type='dino', output_type='conv_multi_level',
                                                   loss_type=args.gan_loss_type, device="cuda")
    else:
        raise NotImplementedError(f"Discriminator type {args.gan_disc_type} not implemented")

    net_disc = net_disc.cuda()
    net_disc.requires_grad_(True)
    # Keep the `cv_ensemble` feature extractor frozen; everything else in the
    # discriminator is trained.
    net_disc.cv_ensemble.requires_grad_(False)
    net_disc.train()

    net_lpips = lpips.LPIPS(net='vgg').cuda()
    net_lpips.requires_grad_(False)

    layers_to_opt = _collect_trainable_params(net_sr)

    dataset_train = PairedDataset(config.train)
    dl_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.train_batch_size, shuffle=True,
                                           num_workers=args.dataloader_num_workers)
    dataset_val = PairedDataset(config.validation)
    dl_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=0)

    optimizer = torch.optim.AdamW(layers_to_opt, lr=args.learning_rate,
                                  betas=(args.adam_beta1, args.adam_beta2),
                                  weight_decay=args.adam_weight_decay, eps=args.adam_epsilon)
    lr_scheduler = get_scheduler(args.lr_scheduler, optimizer=optimizer,
                                 num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
                                 num_training_steps=args.max_train_steps * accelerator.num_processes,
                                 num_cycles=args.lr_num_cycles, power=args.lr_power)

    optimizer_disc = torch.optim.AdamW(net_disc.parameters(), lr=args.learning_rate,
                                       betas=(args.adam_beta1, args.adam_beta2),
                                       weight_decay=args.adam_weight_decay, eps=args.adam_epsilon)
    lr_scheduler_disc = get_scheduler(args.lr_scheduler, optimizer=optimizer_disc,
                                      num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
                                      num_training_steps=args.max_train_steps * accelerator.num_processes,
                                      num_cycles=args.lr_num_cycles, power=args.lr_power)

    net_sr, net_disc, optimizer, optimizer_disc, dl_train, lr_scheduler, lr_scheduler_disc = accelerator.prepare(
        net_sr, net_disc, optimizer, optimizer_disc, dl_train, lr_scheduler, lr_scheduler_disc
    )
    net_de, net_lpips = accelerator.prepare(net_de, net_lpips)

    weight_dtype = _resolve_weight_dtype(accelerator.mixed_precision)

    net_sr.to(accelerator.device, dtype=weight_dtype)
    net_de.to(accelerator.device, dtype=weight_dtype)
    net_disc.to(accelerator.device, dtype=weight_dtype)
    net_lpips.to(accelerator.device, dtype=weight_dtype)
    if accelerator.mixed_precision == "fp16":
        # The fp16 GradScaler refuses to unscale fp16 gradients ("Attempting to
        # unscale FP16 gradients"), and clip_grad_norm_ below triggers that unscale.
        # Keep the optimized weights in fp32; frozen weights stay in fp16.
        _upcast_trainable_params(layers_to_opt)
        _upcast_trainable_params(net_disc.parameters())

    progress_bar = tqdm(range(0, args.max_train_steps), initial=0, desc="Steps",
                        disable=not accelerator.is_local_main_process)

    for name, module in net_disc.named_modules():
        if "attn" in name:
            module.fused_attn = False

    # NOTE(review): training runs for args.num_training_epochs full epochs and never
    # stops at args.max_train_steps, although the progress bar and LR schedules are
    # sized by it.
    global_step = 0
    for _epoch in range(args.num_training_epochs):
        for batch in dl_train:
            with accelerator.accumulate(net_sr, net_disc):
                x_src, x_tgt, x_ori_size_src = degradation_proc(config, batch, accelerator.device)
                B, _, _, _ = x_src.shape
                with torch.no_grad():
                    deg_score = net_de(x_ori_size_src.detach()).detach()

                pos_tag_prompt = [args.pos_prompt for _ in range(B)]
                neg_tag_prompt = [args.neg_prompt for _ in range(B)]

                # Per sample, with probability args.neg_prob, use the negative prompt
                # and make the degraded input itself the reconstruction target.
                neg_probs = torch.rand(B).to(accelerator.device)
                mixed_tag_prompt = [
                    neg_tag if p_i < args.neg_prob else pos_tag
                    for neg_tag, pos_tag, p_i in zip(neg_tag_prompt, pos_tag_prompt, neg_probs)
                ]
                neg_probs = neg_probs.reshape(B, 1, 1, 1)
                mixed_tgt = torch.where(neg_probs < args.neg_prob, x_src, x_tgt)

                # (1) Reconstruction step: L2 + LPIPS against the mixed target.
                x_tgt_pred = net_sr(x_src.detach(), deg_score, mixed_tag_prompt)
                loss_l2 = F.mse_loss(x_tgt_pred.float(), mixed_tgt.detach().float(), reduction="mean") * args.lambda_l2
                loss_lpips = net_lpips(x_tgt_pred.float(), mixed_tgt.detach().float()).mean() * args.lambda_lpips
                loss = loss_l2 + loss_lpips

                accelerator.backward(loss, retain_graph=False)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(layers_to_opt, args.max_grad_norm)
                optimizer.step()
                # NOTE(review): lr_scheduler steps here and again after the generator
                # step (twice per iteration), while lr_scheduler_disc steps once.
                # Confirm this is intended for non-constant schedules.
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=args.set_grads_to_none)

                # (2) Generator step: fool the discriminator with a positive-prompt output.
                x_tgt_pred = net_sr(x_src.detach(), deg_score, pos_tag_prompt)
                lossG = net_disc(x_tgt_pred, for_G=True).mean() * args.lambda_gan
                accelerator.backward(lossG)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(layers_to_opt, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=args.set_grads_to_none)

                # (3) Discriminator steps: real targets vs. detached predictions.
                # lossG's backward also accumulated gradients into the discriminator's
                # trainable parameters. Clear them so the discriminator is not updated
                # with the generator's objective.
                # NOTE(review): accelerate only zeroes on sync steps, so with
                # gradient_accumulation_steps > 1 this clears the whole accumulation.
                optimizer_disc.zero_grad(set_to_none=args.set_grads_to_none)

                lossD_real = net_disc(x_tgt.detach(), for_real=True).mean() * args.lambda_gan
                accelerator.backward(lossD_real)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(net_disc.parameters(), args.max_grad_norm)
                optimizer_disc.step()
                lr_scheduler_disc.step()
                optimizer_disc.zero_grad(set_to_none=args.set_grads_to_none)

                lossD_fake = net_disc(x_tgt_pred.detach(), for_real=False).mean() * args.lambda_gan
                accelerator.backward(lossD_fake)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(net_disc.parameters(), args.max_grad_norm)
                optimizer_disc.step()
                optimizer_disc.zero_grad(set_to_none=args.set_grads_to_none)
                lossD = lossD_real + lossD_fake

            # Count an iteration only once accelerate has actually applied an update.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if accelerator.is_main_process:
                    logs = {
                        "lossG": lossG.detach().item(),
                        "lossD": lossD.detach().item(),
                        "loss_l2": loss_l2.detach().item(),
                        "loss_lpips": loss_lpips.detach().item(),
                    }
                    progress_bar.set_postfix(**logs)

                    if _is_trigger_step(global_step, args.checkpointing_steps):
                        outf = os.path.join(args.output_dir, "checkpoints", f"model_{global_step}.pkl")
                        accelerator.unwrap_model(net_sr).save_model(outf)

                    if _is_trigger_step(global_step, args.eval_freq):
                        logs.update(_run_validation(args, config, accelerator, net_sr, net_de,
                                                    net_lpips, dl_val, global_step))

                    accelerator.log(logs, step=global_step)

    # Finish any tracker runs started by init_trackers.
    accelerator.end_training()
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: parse the paired-training command line, then train.
    args = parse_args_paired_training()
    main(args=args)
| |
|