| | import os |
| | from argparse import ArgumentParser |
| |
|
| | from omegaconf import OmegaConf |
| | import torch |
| | from torch.utils.data import DataLoader |
| | from torchvision.utils import make_grid |
| | from accelerate import Accelerator |
| | from accelerate.utils import set_seed |
| | from einops import rearrange |
| | from tqdm import tqdm |
| | from torch.utils.tensorboard import SummaryWriter |
| | from PIL import Image, ImageDraw, ImageFont |
| | import numpy as np |
| |
|
| | from model import ControlLDM, SwinIR, Diffusion |
| | from utils.common import instantiate_from_config |
| | from utils.sampler import SpacedSampler |
| |
|
| |
|
def log_txt_as_img(wh, xc):
    """Render each caption in ``xc`` onto its own white canvas for image logging.

    Args:
        wh: ``(width, height)`` of every rendered canvas in pixels, as passed to
            ``PIL.Image.new``.
        xc: Sequence of caption strings, one per output image.

    Returns:
        A float64 ``torch.Tensor`` of shape ``(len(xc), 3, height, width)`` with
        pixel values mapped from ``[0, 255]`` to ``[-1, 1]`` (the white
        background becomes 1.0). An empty ``xc`` yields a tensor whose leading
        dimension is 0 instead of raising.
    """
    width, height = wh
    # Both the font and the wrap width depend only on ``wh``, so compute them
    # once instead of per caption. Roughly 40 characters per 256 px of width;
    # clamp to >= 1 so very narrow canvases do not produce a zero ``range`` step.
    font = ImageFont.load_default()
    nc = max(1, int(40 * (width / 256)))

    txts = []
    for caption in xc:
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        # Hard-wrap every ``nc`` characters (glyph widths are not measured).
        lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc))
        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            # Best effort: keep the blank canvas rather than abort logging.
            print("Cant encode string for logging. Skipping.")
        # HWC uint8 image -> CHW float array scaled to [-1, 1].
        txts.append(np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0)

    if not txts:
        # np.stack([]) raises; return an empty batch with the same layout/dtype.
        return torch.empty((0, 3, height, width), dtype=torch.float64)
    return torch.tensor(np.stack(txts))
| |
|
| |
|
def _save_controlnet(controlnet: torch.nn.Module, ckpt_dir: str, step: int) -> str:
    """Save ``controlnet.state_dict()`` to ``<ckpt_dir>/<step:07d>.pt`` and return that path."""
    ckpt_path = os.path.join(ckpt_dir, f"{step:07d}.pt")
    torch.save(controlnet.state_dict(), ckpt_path)
    return ckpt_path


def main(args) -> None:
    """Train the ControlNet branch of a ControlLDM conditioned on SwinIR outputs.

    Only ``cldm.controlnet`` parameters are given to the optimizer. The SwinIR
    model is frozen, and its output on the low-quality input is combined with
    the prompt via ``prepare_condition`` to form the control condition. Loss
    scalars, sample grids and ControlNet checkpoints are written under
    ``cfg.train.exp_dir``.

    Args:
        args: Parsed CLI namespace; ``args.config`` is the path of the OmegaConf
            config describing the models, dataset and training schedule.

    Raises:
        ValueError: If the training loader yields no batches, which would
            otherwise make the step loop spin forever.
    """
    # split_batches=True: cfg.train.batch_size is the total batch, split across processes.
    accelerator = Accelerator(split_batches=True)
    set_seed(231)
    device = accelerator.device
    cfg = OmegaConf.load(args.config)

    # Console output happens once per node; file/TensorBoard output once
    # globally, so multi-node jobs on a shared filesystem don't write twice.
    log_console = accelerator.is_local_main_process
    write_files = accelerator.is_main_process

    exp_dir = cfg.train.exp_dir
    ckpt_dir = os.path.join(exp_dir, "checkpoints")
    if write_files:
        os.makedirs(exp_dir, exist_ok=True)
        os.makedirs(ckpt_dir, exist_ok=True)
        print(f"Experiment directory created at {exp_dir}")

    # NOTE(review): torch.load unpickles its input; only point these config
    # paths at trusted checkpoint files.
    cldm: ControlLDM = instantiate_from_config(cfg.model.cldm)
    sd = torch.load(cfg.train.sd_path, map_location="cpu")["state_dict"]
    unused = cldm.load_pretrained_sd(sd)
    del sd  # drop our reference to the full SD state dict early
    if log_console:
        print(f"strictly load pretrained SD weight from {cfg.train.sd_path}\n"
              f"unused weights: {unused}")

    if cfg.train.resume:
        cldm.load_controlnet_from_ckpt(torch.load(cfg.train.resume, map_location="cpu"))
        if log_console:
            print(f"strictly load controlnet weight from checkpoint: {cfg.train.resume}")
    else:
        init_with_new_zero, init_with_scratch = cldm.load_controlnet_from_unet()
        if log_console:
            print(f"strictly load controlnet weight from pretrained SD\n"
                  f"weights initialized with newly added zeros: {init_with_new_zero}\n"
                  f"weights initialized from scratch: {init_with_scratch}")

    swinir: SwinIR = instantiate_from_config(cfg.model.swinir)
    # Strip a leading "module." from keys (presumably saved from a
    # DataParallel/DDP-wrapped model) so they match the bare SwinIR.
    swinir_sd = {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in torch.load(cfg.train.swinir_path, map_location="cpu").items()
    }
    swinir.load_state_dict(swinir_sd, strict=True)
    for p in swinir.parameters():
        p.requires_grad = False
    if log_console:
        print(f"load SwinIR from {cfg.train.swinir_path}")

    diffusion: Diffusion = instantiate_from_config(cfg.model.diffusion)

    # Only the ControlNet branch is handed to the optimizer.
    opt = torch.optim.AdamW(cldm.controlnet.parameters(), lr=cfg.train.learning_rate)

    dataset = instantiate_from_config(cfg.dataset.train)
    loader = DataLoader(
        dataset=dataset, batch_size=cfg.train.batch_size,
        num_workers=cfg.train.num_workers,
        shuffle=True, drop_last=True
    )
    if log_console:
        print(f"Dataset contains {len(dataset):,} images from {dataset.file_list}")

    cldm.train().to(device)
    swinir.eval().to(device)
    diffusion.to(device)
    cldm, opt, loader = accelerator.prepare(cldm, opt, loader)
    # Unwrapped module: used for the no-grad helpers (VAE, conditioning) and
    # for saving ControlNet weights without wrapper key prefixes.
    pure_cldm: ControlLDM = accelerator.unwrap_model(cldm)

    if len(loader) == 0:
        raise ValueError(
            "Training loader yields no batches (dataset smaller than "
            "batch_size with drop_last=True?); training would never progress."
        )

    global_step = 0
    max_steps = cfg.train.train_steps
    step_loss = []
    epoch = 0
    epoch_loss = []
    sampler = SpacedSampler(diffusion.betas)
    writer = SummaryWriter(exp_dir) if write_files else None
    if log_console:
        print(f"Training for {max_steps} steps...")

    while global_step < max_steps:
        pbar = tqdm(iterable=None, disable=not log_console, unit="batch", total=len(loader))
        for gt, lq, prompt in loader:
            # Batches arrive channel-last; the models consume NCHW float tensors.
            gt = rearrange(gt, "b h w c -> b c h w").contiguous().float().to(device)
            lq = rearrange(lq, "b h w c -> b c h w").contiguous().float().to(device)
            with torch.no_grad():
                z_0 = pure_cldm.vae_encode(gt)
                clean = swinir(lq)
                cond = pure_cldm.prepare_condition(clean, prompt)
            t = torch.randint(0, diffusion.num_timesteps, (z_0.shape[0],), device=device)

            loss = diffusion.p_losses(cldm, z_0, t, cond)
            opt.zero_grad()
            accelerator.backward(loss)
            opt.step()

            accelerator.wait_for_everyone()

            global_step += 1
            # Read the scalar once per step; each .item() copies to the host.
            loss_value = loss.item()
            step_loss.append(loss_value)
            epoch_loss.append(loss_value)
            pbar.update(1)
            pbar.set_description(f"Epoch: {epoch:04d}, Global Step: {global_step:07d}, Loss: {loss_value:.6f}")

            # global_step >= 1 here, so the modulo tests alone decide logging.
            if global_step % cfg.train.log_every == 0:
                # Average the recent losses of every process, not just this one.
                avg_loss = accelerator.gather(torch.tensor(step_loss, device=device).unsqueeze(0)).mean().item()
                step_loss.clear()
                if write_files:
                    writer.add_scalar("loss/loss_simple_step", avg_loss, global_step)

            if global_step % cfg.train.ckpt_every == 0 and write_files:
                _save_controlnet(pure_cldm.controlnet, ckpt_dir, global_step)

            if global_step % cfg.train.image_every == 0 or global_step == 1:
                N = 12
                log_clean = clean[:N]
                log_cond = {k: v[:N] for k, v in cond.items()}
                log_gt, log_lq = gt[:N], lq[:N]
                log_prompt = prompt[:N]
                cldm.eval()
                with torch.no_grad():
                    # Every process samples through ``cldm``; only the writer
                    # process decodes and logs the resulting grids.
                    z = sampler.sample(
                        model=cldm, device=device, steps=50, batch_size=len(log_gt), x_size=z_0.shape[1:],
                        cond=log_cond, uncond=None, cfg_scale=1.0, x_T=None,
                        progress=log_console, progress_leave=False
                    )
                    if write_files:
                        # NOTE(review): lq is logged without the (x + 1) / 2
                        # shift applied to gt — presumably already in [0, 1]; confirm.
                        for tag, image in [
                            ("image/samples", (pure_cldm.vae_decode(z) + 1) / 2),
                            ("image/gt", (log_gt + 1) / 2),
                            ("image/lq", log_lq),
                            ("image/condition", log_clean),
                            ("image/condition_decoded", (pure_cldm.vae_decode(log_cond["c_img"]) + 1) / 2),
                            ("image/prompt", (log_txt_as_img((512, 512), log_prompt) + 1) / 2)
                        ]:
                            writer.add_image(tag, make_grid(image, nrow=4), global_step)
                cldm.train()
            accelerator.wait_for_everyone()
            if global_step == max_steps:
                break

        pbar.close()
        epoch += 1
        avg_epoch_loss = accelerator.gather(torch.tensor(epoch_loss, device=device).unsqueeze(0)).mean().item()
        epoch_loss.clear()
        if write_files:
            writer.add_scalar("loss/loss_simple_epoch", avg_epoch_loss, global_step)

    # Persist the final weights unless the last step already landed on a
    # ckpt_every boundary (and was saved inside the loop).
    if write_files and global_step > 0 and global_step % cfg.train.ckpt_every != 0:
        _save_controlnet(pure_cldm.controlnet, ckpt_dir, global_step)

    if log_console:
        print("done!")
    if write_files:
        writer.close()
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: the only CLI option is the path to the training config.
    cli = ArgumentParser()
    cli.add_argument("--config", type=str, required=True)
    main(cli.parse_args())
| |
|