| import sys |
| sys.path.append('cclddg') |
| import wandb |
| from cclddg.data import get_paired_vqgan, tensor_to_image |
| from cclddg.core import UNet, Discriminator |
| from cclddg.ddg_context import DDG_Context |
| from PIL import Image |
| import torch |
| import torchvision.transforms as T |
| from torch_ema import ExponentialMovingAverage |
| from tqdm import tqdm |
| import torch.nn.functional as F |
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


# --- Hyperparameters --------------------------------------------------------
n_batches = 101000            # number of training iterations (one batch each)
batch_size = 5                # images per forward pass
lr = 5e-5                     # initial learning rate handed to the optimizer
img_size = 128                # size passed to the crop/resize transforms
sr = 1                        # selects the lq/hq transform pair below (1, 2 or 4)
n_steps = 200                 # number of diffusion steps given to DDG_Context
grad_accumulation_steps = 6   # batches whose gradients are summed per optimizer step


# Log every hyperparameter that affects training, including the accumulation
# factor (effective batch = batch_size * grad_accumulation_steps), so a run can
# be reproduced from its wandb config alone.
wandb.init(project='dvq_diff',
           config={
               'n_batches': n_batches,
               'batch_size': batch_size,
               'lr': lr,
               'img_size': img_size,
               'sr': sr,
               'n_steps': n_steps,
               'grad_accumulation_steps': grad_accumulation_steps,
           },
           save_code=True)
|
|
| |
# Diffusion helper used by both training and sampling: it supplies
# q_xt_x0 (noising), p_xt (one reverse step) and n_steps.
ddg_context = DDG_Context(
    n_steps=n_steps,
    beta_min=0.005,
    beta_max=0.05,
    device=device,
)
|
|
| |
# Conditional denoiser. Its input is the noisy image concatenated channel-wise
# with the low-quality conditioning image (presumably 3 + 3 channels, hence
# image_channels=6); callers use only the first 3 output channels as the
# predicted noise.
unet = UNet(image_channels=6, n_channels=128, ch_mults=(1, 1, 2, 2, 2),
            is_attn=(False, False, False, True, True),
            n_blocks=4, use_z=False, z_dim=8, n_z_channels=16,
            use_cloob=False, n_cloob_channels=256,
            n_time_channels=-1, denom_factor=1000).to(device)
# Resume from an existing checkpoint. map_location makes a checkpoint saved on
# a GPU loadable when `device` has fallen back to the CPU.
unet.load_state_dict(torch.load('desert_dawn_ema_unet_020000.pt',
                                map_location=device))
|
|
|
|
# Pick the low-quality (conditioning) and high-quality (target) transforms for
# the chosen `sr` setting:
#   sr=4: lq = centre crop img_size//2 -> resize img_size//4 -> resize img_size;
#         hq = centre crop img_size
#   sr=2: lq = centre crop img_size//2 -> resize img_size; hq = centre crop img_size
#   sr=1: both are simply resized to img_size
# Any other value is rejected up front instead of leaving the names undefined.
if sr == 4:
    lq_tfm = T.Compose([T.CenterCrop(img_size//2), T.Resize(img_size//4), T.Resize(img_size)])
    hq_tfm = T.CenterCrop(img_size)
elif sr == 2:
    lq_tfm = T.Compose([T.CenterCrop(img_size//2), T.Resize(img_size)])
    hq_tfm = T.CenterCrop(img_size)
elif sr == 1:
    lq_tfm = T.Compose([T.Resize(img_size)])
    hq_tfm = T.Compose([T.Resize(img_size)])
else:
    raise ValueError(f'Unsupported sr={sr!r}; expected 1, 2 or 4')
|
|
|
|
| |
# Iterator over paired (low-quality, high-quality) batches.
data = get_paired_vqgan(batch_size=batch_size)
data_iter = iter(data)

# Hold back the first batch as a fixed set of examples for periodic sample
# grids. `* 2 - 1` rescales to [-1, 1], assuming the loader yields [0, 1]
# (NOTE(review): confirm). Slicing to n_egs is capped by the real batch size.
n_egs = 10
eg_lq, eg_hq = next(data_iter)
eg_lq, eg_hq = (
    lq_tfm(eg_lq[:n_egs]).to(device) * 2 - 1,
    hq_tfm(eg_hq[:n_egs]).to(device) * 2 - 1,
)
def eg_im(eg_lq, eg_hq, ddg_context, start_t = 99):
    """Render a grid of conditional denoising samples for visual inspection.

    The conditioning batch ``eg_lq`` is noised to step ``start_t`` with
    ``ddg_context.q_xt_x0`` and then denoised step by step down to 0 with the
    module-level ``unet``, conditioning every step on ``eg_lq``. Intermediate
    samples are captured periodically: five snapshots for the start values
    used in this script (99, 120, 199). The grid width adapts to however many
    snapshots are captured.

    Each row of the result is one batch element. The columns are the
    snapshots, then the target (``eg_hq``), then the conditioning input
    (``eg_lq``).

    Args:
        eg_lq: Conditioning batch, scaled like the training inputs.
        eg_hq: Target batch aligned row-for-row with ``eg_lq``.
        ddg_context: Provides ``n_steps``, ``q_xt_x0`` and ``p_xt``.
        start_t: Step to start denoising from; clamped to ``n_steps - 1``.

    Returns:
        A ``PIL.Image`` built from ``img_size`` x ``img_size`` tiles.

    Relies on the module-level ``unet``, ``img_size`` and ``tensor_to_image``.
    """
    batch_size = eg_lq.shape[0]
    cond_0 = eg_lq
    start_t = min(start_t, ddg_context.n_steps - 1)
    dev = cond_0.device  # follow the inputs instead of hard-coding CUDA
    # Same interval as before, clamped to >= 1: for start_t < 8 the raw formula
    # is 0 (ZeroDivisionError) or -1 (snapshot every step).
    snapshot_every = max(start_t // 4 - 1, 1)

    # columns[c][b] is the tile for column c, batch element b.
    columns = []
    with torch.no_grad():
        t = torch.tensor(start_t, dtype=torch.long, device=dev)
        x, _ = ddg_context.q_xt_x0(cond_0, t.unsqueeze(0))
        for i in range(start_t):
            t = torch.tensor(start_t - i - 1, dtype=torch.long, device=dev)
            unet_input = torch.cat((x, cond_0), dim=1)
            # Only the first 3 output channels are used as the noise prediction.
            pred_noise = unet(unet_input, t.unsqueeze(0))[:, :3]
            x = ddg_context.p_xt(x, pred_noise, t.unsqueeze(0))
            if i % snapshot_every == 0:
                columns.append([tensor_to_image(x[b].cpu()) for b in range(batch_size)])

    columns.append([tensor_to_image(eg_hq[b].cpu()) for b in range(batch_size)])
    columns.append([tensor_to_image(cond_0[b].cpu()) for b in range(batch_size)])

    # Size the canvas from the columns actually captured rather than assuming
    # exactly 5 snapshots + target + condition.
    image = Image.new('RGB', size=(img_size * len(columns), batch_size * img_size))
    for col, column in enumerate(columns):
        for row, tile in enumerate(column):
            image.paste(tile, (col * img_size, row * img_size))

    return image
|
|
| |
# Optimizer, EMA copy of the weights, and a learning-rate decay schedule.
optim = torch.optim.RMSprop(params=unet.parameters(), lr=lr)
ema = ExponentialMovingAverage(unet.parameters(), decay=0.995)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=0.9)

# Per-iteration training losses (also streamed to wandb).
losses = []
|
|
# Most recent raw (untransformed) batch; used only as a fallback on loader errors.
last_batch = None

for i in tqdm(range(0, n_batches)):

    # --- Fetch the next batch ------------------------------------------------
    try:
        batch = next(data_iter)
    except StopIteration:
        # Finished a pass over the data: start a fresh one.
        data_iter = iter(data)
        batch = next(data_iter)
    except Exception as err:
        # Best-effort: tolerate a flaky loader, but reuse the last *raw* batch
        # so the transforms and [-1, 1] scaling below run exactly once.
        # Re-transforming already-scaled tensors would compound the scaling.
        if last_batch is None:
            raise
        tqdm.write(f'Batch {i}: data loading failed ({err!r}); reusing previous batch')
        batch = last_batch
    last_batch = batch
    lq, hq = batch
    lq = lq_tfm(lq).to(device)*2-1
    hq = hq_tfm(hq).to(device)*2-1
    batch_size = lq.shape[0]

    # --- Denoising objective: predict the noise added at a random step t -----
    x0 = hq
    cond_0 = lq
    t = torch.randint(1, ddg_context.n_steps, (batch_size,), dtype=torch.long).to(device)
    xt, noise = ddg_context.q_xt_x0(x0, t)
    unet_input = torch.cat((xt, cond_0), dim=1)
    pred_noise = unet(unet_input, t)[:, :3]
    loss = F.mse_loss(noise.float(), pred_noise)
    losses.append(loss.item())
    wandb.log({'Loss': loss.item()})
    # Scale so the accumulated gradient is the mean over the group, not the sum.
    (loss / grad_accumulation_steps).backward()

    # Step once per full group of grad_accumulation_steps batches.
    if (i + 1) % grad_accumulation_steps == 0:
        optim.step()
        optim.zero_grad()
        ema.update()

    # --- Periodic sample grids -----------------------------------------------
    if i % 2000 == 0:
        with torch.no_grad():
            wandb.log({'Examples @120': wandb.Image(eg_im(eg_lq, eg_hq, ddg_context, start_t=120))})
            wandb.log({'Examples @199': wandb.Image(eg_im(eg_lq, eg_hq, ddg_context, start_t=199))})
            wandb.log({'Random Examples @120': wandb.Image(eg_im(lq, hq, ddg_context, start_t=120))})

    # --- Checkpoints: raw weights and EMA-averaged weights -------------------
    if i % 20000 == 0:
        torch.save(unet.state_dict(), f'unet_{i:06}.pt')
        with ema.average_parameters():
            torch.save(unet.state_dict(), f'ema_unet_{i:06}.pt')

    # --- Learning-rate decay every 4000 iterations ---------------------------
    if (i + 1) % 4000 == 0:
        scheduler.step()
        wandb.log({'lr': optim.param_groups[0]['lr']})
| |