| | import torch |
| | import math |
| | from tqdm import trange, tqdm |
| |
|
| | import k_diffusion as K |
| | |
| | |
def get_alphas_sigmas(t):
    """Return the (alpha, sigma) scaling factors for a timestep.

    alpha scales the clean image and sigma scales the noise; together they
    trace the quarter-circle schedule cos/sin(t * pi/2) for t in [0, 1].
    """
    half_turn = t * math.pi / 2
    return torch.cos(half_turn), torch.sin(half_turn)
| |
|
def alpha_sigma_to_t(alpha, sigma):
    """Recover the timestep t in [0, 1] from its (alpha, sigma) scaling pair.

    Inverse of the cos/sin schedule: t = atan2(sigma, alpha) / (pi/2).
    """
    angle = torch.atan2(sigma, alpha)
    return angle / math.pi * 2
| |
|
def t_to_alpha_sigma(t):
    """Map a timestep t in [0, 1] to its (alpha, sigma) scaling pair.

    Same cos/sin quarter-circle schedule as get_alphas_sigmas; kept as a
    separate name for callers that use this spelling.
    """
    phase = t * math.pi / 2
    return torch.cos(phase), torch.sin(phase)
| |
|
| |
|
@torch.no_grad()
def sample_discrete_euler(model, x, steps, sigma_max=1, **extra_args):
    """Draws samples from a model given starting noise, via the Euler method.

    Integrates dx/dt = model(x, t) from t = sigma_max down to t = 0 with
    `steps` uniform explicit-Euler steps.

    Args:
        model: callable(x, t_batch, **extra_args) -> tensor, the derivative
            (velocity) at the current state. t_batch has shape (batch,).
        x: starting noise tensor, shape (batch, ...).
        steps: number of Euler steps.
        sigma_max: starting time of the integration (default 1).
        **extra_args: forwarded to every model call.

    Returns:
        The integrated sample tensor, same shape as x.
    """
    # Uniform time grid from sigma_max down to 0, on the same device as x
    # to avoid cross-device scalar tensors on every step.
    t = torch.linspace(sigma_max, 0, steps + 1, device=x.device)

    # Explicit Euler: x_{k+1} = x_k + (t_{k+1} - t_k) * f(x_k, t_k).
    # total=steps lets tqdm render a proper progress bar for the zip.
    for t_curr, t_prev in tqdm(zip(t[:-1], t[1:]), total=steps):
        # Broadcast the scalar timestep across the batch for the model call.
        t_curr_tensor = t_curr * torch.ones(
            (x.shape[0],), dtype=x.dtype, device=x.device
        )
        dt = t_prev - t_curr  # negative: we integrate backwards in time
        x = x + dt * model(x, t_curr_tensor, **extra_args)

    return x
| |
|
@torch.no_grad()
def sample(model, x, steps, eta, **extra_args):
    """Draws samples from a model given starting noise. v-diffusion

    model(x, t_batch, **extra_args) is expected to predict v; eta controls
    the amount of fresh noise injected per step (eta == 0 is deterministic
    DDIM-style sampling). Returns the final clean-image prediction `pred`.
    NOTE(review): if steps == 0 the loop never runs and `pred` is unbound —
    presumably callers always pass steps >= 1; confirm.
    """
    # Per-sample vector of ones, used to broadcast the scalar t[i] to a batch.
    ts = x.new_ones([x.shape[0]])

    # Timesteps from 1 down to (but excluding) 0.
    t = torch.linspace(1, 0, steps + 1)[:-1]

    # Per-step scaling factors for the clean image (alpha) and noise (sigma).
    alphas, sigmas = get_alphas_sigmas(t)

    for i in trange(steps):

        # Get the model output (v, the predicted velocity) under autocast.
        with torch.cuda.amp.autocast():
            v = model(x, ts * t[i], **extra_args).float()

        # Recover the clean-image and noise estimates from v at this step.
        pred = x * alphas[i] - v * sigmas[i]
        eps = x * sigmas[i] + v * alphas[i]

        # All steps except the last re-noise to the next timestep.
        if i < steps - 1:
            # DDIM-style split of the next sigma into a stochastic part
            # (ddim_sigma, scaled by eta) and the deterministic remainder.
            ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
                (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
            adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()

            # Recombine the predicted image and noise at the next timestep.
            x = pred * alphas[i + 1] + eps * adjusted_sigma

            # Inject fresh noise for the stochastic portion when eta > 0.
            if eta:
                x += torch.randn_like(x) * ddim_sigma

    # Return the clean-image prediction from the final step.
    return pred
| |
|
| | |
| | |
def get_bmask(i, steps, mask):
    """Binary mask for step i of steps: 1 where mask <= current strength.

    Strength ramps linearly from 1/steps (i == 0) to 1 (i == steps - 1),
    so progressively more of `mask` is switched on as sampling proceeds.
    """
    threshold = (i + 1) / steps
    return torch.where(mask <= threshold, 1, 0)
| |
|
def make_cond_model_fn(model, cond_fn):
    """Wrap `model` so its denoised output is nudged by a guidance gradient.

    cond_fn(x, sigma, denoised=..., **kwargs) returns a gradient w.r.t. x;
    the correction is scaled by sigma**2, broadcast to x's rank.
    """
    def cond_model_fn(x, sigma, **kwargs):
        # Gradients must flow through x even inside no_grad sampling loops.
        with torch.enable_grad():
            x = x.detach().requires_grad_()
            denoised = model(x, sigma, **kwargs)
            grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
            scale = K.utils.append_dims(sigma**2, x.ndim)
            return denoised.detach() + grad * scale
    return cond_model_fn
| |
|
| | |
| | |
| | |
| | |
| | |
def sample_k(
    model_fn,
    noise,
    init_data=None,
    mask=None,
    steps=100,
    sampler_type="dpmpp-2m-sde",
    sigma_min=0.5,
    sigma_max=50,
    rho=1.0, device="cuda",
    callback=None,
    cond_fn=None,
    **extra_args
):
    """Sample from a v-diffusion model using a k-diffusion sampler.

    Args:
        model_fn: the v-prediction model; wrapped in K.external.VDenoiser.
        noise: unit-variance starting noise; scaled by sigmas[0] internally.
        init_data: optional initial data for img2img-style sampling.
        mask: optional soft mask in [0, 1] for inpainting (used with init_data).
        steps: number of sampling steps.
        sampler_type: which K.sampling routine to dispatch to.
        sigma_min, sigma_max, rho: polyexponential sigma-schedule parameters.
        device: device for the sigma schedule.
        callback: optional per-step callback(args_dict).
        cond_fn: optional guidance gradient fn (see make_cond_model_fn).
        **extra_args: forwarded to the model on every call.

    Returns:
        The sampled tensor.

    Raises:
        ValueError: if sampler_type is not one of the supported names.
    """
    denoiser = K.external.VDenoiser(model_fn)

    if cond_fn is not None:
        denoiser = make_cond_model_fn(denoiser, cond_fn)

    # Polyexponential noise schedule; noise is scaled to the initial sigma.
    sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device)
    noise = noise * sigmas[0]

    wrapped_callback = callback

    if mask is None and init_data is not None:
        # img2img: start from the noised initial data.
        x = init_data + noise
    elif mask is not None and init_data is not None:
        # Inpainting: blend noised data (kept region) with pure noise.
        bmask = get_bmask(0, steps, mask)
        input_noised = init_data + noise
        x = input_noised * bmask + noise * (1 - bmask)

        # Re-impose the known region at each step, at the step's noise level,
        # with the binary mask hardening as sampling progresses.
        def inpainting_callback(args):
            i = args["i"]
            x = args["x"]
            sigma = args["sigma"]
            input_noised = init_data + torch.randn_like(init_data) * sigma
            bmask = get_bmask(i, steps, mask)
            new_x = input_noised * bmask + x * (1 - bmask)
            # In-place write so the sampler sees the blended state.
            x[:, :, :] = new_x[:, :, :]

        if callback is None:
            wrapped_callback = inpainting_callback
        else:
            wrapped_callback = lambda args: (inpainting_callback(args), callback(args))
    else:
        # Unconditional generation: start from pure noise.
        x = noise

    with torch.cuda.amp.autocast():
        if sampler_type == "k-heun":
            return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-lms":
            return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpmpp-2s-ancestral":
            return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpm-2":
            return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpm-fast":
            return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpm-adaptive":
            return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "dpmpp-2m-sde":
            return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "dpmpp-3m-sde":
            return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        else:
            # Previously an unknown name silently returned None.
            raise ValueError(f"Unknown sampler_type: {sampler_type!r}")
| |
|
| | |
| | |
| | |
| | |
| | |
def sample_rf(
    model_fn,
    noise,
    init_data=None,
    steps=100,
    sigma_max=1,
    device="cuda",
    callback=None,
    cond_fn=None,
    **extra_args
):
    """Sample from a rectified-flow model via discrete Euler integration.

    Args:
        model_fn: velocity model passed to sample_discrete_euler.
        noise: unit-variance starting noise.
        init_data: optional initial data; linearly mixed with noise at
            sigma_max for img2img-style sampling.
        steps: number of Euler steps.
        sigma_max: starting time, clamped to at most 1.
        device: unused here; kept for interface parity with sample_k.
        callback: unused here; kept for interface parity with sample_k.
        cond_fn: optional guidance gradient fn (see make_cond_model_fn).
        **extra_args: forwarded to the model on every call.

    Returns:
        The sampled tensor.
    """
    # Rectified flow is parameterized on t in [0, 1].
    if sigma_max > 1:
        sigma_max = 1

    if cond_fn is not None:
        # BUG FIX: previously wrapped the undefined name `denoiser`, which
        # raised NameError whenever cond_fn was supplied. Wrap model_fn.
        model_fn = make_cond_model_fn(model_fn, cond_fn)

    wrapped_callback = callback

    if init_data is not None:
        # img2img: linear interpolation between data and noise at sigma_max.
        x = init_data * (1 - sigma_max) + noise * sigma_max
    else:
        # Unconditional generation: start from pure noise.
        x = noise

    with torch.cuda.amp.autocast():
        # TODO: wrapped_callback is not forwarded; sample_discrete_euler
        # does not currently accept a callback.
        return sample_discrete_euler(model_fn, x, steps, sigma_max, **extra_args)