| | from config import * |
| |
|
| | from torch.cuda import amp |
| |
|
| |
|
def render_uncondition(conf: TrainConfig,
                       model: BeatGANsAutoencModel,
                       x_T,
                       sampler: Sampler,
                       latent_sampler: Sampler,
                       conds_mean=None,
                       conds_std=None,
                       clip_latent_noise: bool = False):
    """Sample from ``model`` starting at ``x_T`` without an encoded input.

    Two training modes are handled:

    * ``TrainMode.diffusion``: ``sampler`` is run directly on ``x_T``.
    * ``TrainMode.latent_diffusion``: Gaussian latent noise of shape
      ``(len(x_T), conf.style_ch)`` is drawn on ``x_T``'s device, turned into
      a condition by running ``latent_sampler`` on ``model.latent_net``, and
      ``sampler`` is then run on ``x_T`` with that condition.

    Args:
        conf: Configuration; ``train_mode``, ``model_type``, ``style_ch``,
            ``latent_clip_sample`` and ``latent_znormalize`` are read.
        model: Model handed to ``sampler``; its ``latent_net`` is handed to
            ``latent_sampler`` in latent-diffusion mode.
        x_T: Starting noise. Must expose ``.device`` and support ``len()``.
        sampler: Sampler used to produce the final output.
        latent_sampler: Sampler used to produce the condition
            (latent-diffusion mode only).
        conds_mean: Offset added to the sampled condition when
            ``conf.latent_znormalize`` is set.
        conds_std: Scale applied to the sampled condition when
            ``conf.latent_znormalize`` is set.
        clip_latent_noise: If true, clip the latent noise to ``[-1, 1]``
            before it is given to ``latent_sampler``.

    Returns:
        Whatever ``sampler.sample`` returns.

    Raises:
        AssertionError: In diffusion mode when
            ``conf.model_type.can_sample()`` is falsy.
        NotImplementedError: For any other training mode.
        ValueError: If ``conf.latent_znormalize`` is set but ``conds_mean``
            or ``conds_std`` is ``None``.
    """
    device = x_T.device

    if conf.train_mode == TrainMode.diffusion:
        assert conf.model_type.can_sample()
        return sampler.sample(model=model, noise=x_T)

    # Only the plain latent-diffusion variant is implemented.
    if not conf.train_mode.is_latent_diffusion():
        raise NotImplementedError()
    if not (conf.train_mode == TrainMode.latent_diffusion):
        raise NotImplementedError()

    # Validate before the latent sampling pass so that a missing statistic
    # fails fast with a clear message instead of an AttributeError on None
    # after the sampling work has already been done.
    if conf.latent_znormalize and (conds_mean is None or conds_std is None):
        raise ValueError(
            'conds_mean and conds_std are required when '
            'conf.latent_znormalize is set')

    latent_noise = torch.randn(len(x_T), conf.style_ch, device=device)
    if clip_latent_noise:
        latent_noise = latent_noise.clip(-1, 1)

    cond = latent_sampler.sample(
        model=model.latent_net,
        noise=latent_noise,
        clip_denoised=conf.latent_clip_sample,
    )

    if conf.latent_znormalize:
        # Rescale with the supplied statistics (presumably undoing a
        # z-normalization applied to the latents -- confirm with the caller).
        cond = cond * conds_std.to(device) + conds_mean.to(device)

    # Final pass: run the main sampler on x_T with the sampled condition.
    return sampler.sample(model=model, noise=x_T, cond=cond)
| |
|
| |
|
def render_condition(
    conf: TrainConfig,
    model: BeatGANsAutoencModel,
    x_T,
    sampler: Sampler,
    x_start=None,
    cond=None,
):
    """Sample from ``model`` starting at ``x_T`` under a given condition.

    Only ``TrainMode.diffusion`` is supported. When ``cond`` is ``None`` it
    is obtained from ``model.encode(x_start)``; otherwise the supplied
    ``cond`` is used unchanged. The condition is passed to
    ``sampler.sample`` as ``model_kwargs={'cond': cond}``.

    Args:
        conf: Configuration; ``train_mode`` and ``model_type`` are read.
        model: Model handed to ``sampler`` and used for encoding.
        x_T: Starting noise handed to ``sampler``.
        sampler: Sampler used to produce the output.
        x_start: Input passed to ``model.encode`` when ``cond`` is ``None``.
        cond: Precomputed condition; skips the encoding step when given.

    Returns:
        Whatever ``sampler.sample`` returns.

    Raises:
        NotImplementedError: If ``conf.train_mode`` is not
            ``TrainMode.diffusion``.
        AssertionError: If ``conf.model_type.has_autoenc()`` is falsy.
    """
    mode_supported = conf.train_mode == TrainMode.diffusion
    if not mode_supported:
        raise NotImplementedError()

    assert conf.model_type.has_autoenc()

    # Encode only when the caller did not hand in a condition.
    condition = model.encode(x_start) if cond is None else cond
    sampling_kwargs = {'cond': condition}
    return sampler.sample(model=model,
                          noise=x_T,
                          model_kwargs=sampling_kwargs)
| |
|