| | import pytorch_lightning as pl |
| | import sys, gc |
| | import random |
| | import torch |
| | import torchaudio |
| | import typing as tp |
| | import wandb |
| |
|
| | from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image |
| | import auraloss |
| | from ema_pytorch import EMA |
| | from einops import rearrange |
| | from safetensors.torch import save_file |
| | from torch import optim |
| | from torch.nn import functional as F |
| | from pytorch_lightning.utilities.rank_zero import rank_zero_only |
| |
|
| | from ..inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler |
| | from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper |
| | from ..models.autoencoders import DiffusionAutoencoder |
| | from ..models.diffusion_prior import PriorType |
| | from .autoencoders import create_loss_modules_from_bottleneck |
| | from .losses import AuralossLoss, MSELoss, MultiLoss |
| | from .utils import create_optimizer_from_config, create_scheduler_from_config |
| |
|
| | from time import time |
| |
|
| |
|
| | class Profiler: |
| |
|
| | def __init__(self): |
| | self.ticks = [[time(), None]] |
| |
|
| | def tick(self, msg): |
| | self.ticks.append([time(), msg]) |
| |
|
| | def __repr__(self): |
| | rep = 80 * "=" + "\n" |
| | for i in range(1, len(self.ticks)): |
| | msg = self.ticks[i][1] |
| | ellapsed = self.ticks[i][0] - self.ticks[i - 1][0] |
| | rep += msg + f": {ellapsed*1000:.2f}ms\n" |
| | rep += 80 * "=" + "\n\n\n" |
| | return rep |
| |
|
| | class DiffusionUncondTrainingWrapper(pl.LightningModule): |
| | ''' |
| | Wrapper for training an unconditional audio diffusion model (like Dance Diffusion). |
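    A minimal usage sketch (illustrative only; assumes `model` is an already-constructed
    DiffusionModelWrapper and `train_dl` is a DataLoader of audio batches):

        wrapper = DiffusionUncondTrainingWrapper(model, lr=1e-4)
        demo_cb = DiffusionUncondDemoCallback(demo_every=2000, demo_steps=250, sample_rate=48000)
        trainer = pl.Trainer(devices=1, max_steps=100000, callbacks=[demo_cb])
        trainer.fit(wrapper, train_dl)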
| | ''' |
| | def __init__( |
| | self, |
| | model: DiffusionModelWrapper, |
| | lr: float = 1e-4, |
| | pre_encoded: bool = False |
| | ): |
| | super().__init__() |
| |
|
| | self.diffusion = model |
| | |
| | self.diffusion_ema = EMA( |
| | self.diffusion.model, |
| | beta=0.9999, |
| | power=3/4, |
| | update_every=1, |
| | update_after_step=1 |
| | ) |
| |
|
| | self.lr = lr |
| |
|
| | self.rng = torch.quasirandom.SobolEngine(1, scramble=True) |
| |
|
| | loss_modules = [ |
| | MSELoss("v", |
| | "targets", |
| | weight=1.0, |
| | name="mse_loss" |
| | ) |
| | ] |
| |
|
| | self.losses = MultiLoss(loss_modules) |
| |
|
| | self.pre_encoded = pre_encoded |
| |
|
| | def configure_optimizers(self): |
| | return optim.Adam([*self.diffusion.parameters()], lr=self.lr) |
| |
|
| | def training_step(self, batch, batch_idx): |
| | reals = batch[0] |
| |
|
| | if reals.ndim == 4 and reals.shape[0] == 1: |
| | reals = reals[0] |
| | |
| | diffusion_input = reals |
| |
|
| | loss_info = {} |
| |
|
| | if not self.pre_encoded: |
| | loss_info["audio_reals"] = diffusion_input |
| |
|
| | if self.diffusion.pretransform is not None: |
| | if not self.pre_encoded: |
| | with torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): |
| | diffusion_input = self.diffusion.pretransform.encode(diffusion_input) |
| | else: |
| | |
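                # Apply scale to pre-encoded latents if needed, since the pretransform encode step is skipped here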
| | if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: |
| | diffusion_input = diffusion_input / self.diffusion.pretransform.scale |
| |
|
| | loss_info["reals"] = diffusion_input |
| |
|
| | |
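        # Draw uniformly distributed continuous timesteps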
| | t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) |
| |
|
| | |
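        # Calculate the noise schedule parameters for those timesteps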
| | alphas, sigmas = get_alphas_sigmas(t) |
| |
|
| | |
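        # Combine the ground truth data and the noise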
| | alphas = alphas[:, None, None] |
| | sigmas = sigmas[:, None, None] |
| | noise = torch.randn_like(diffusion_input) |
| | noised_inputs = diffusion_input * alphas + noise * sigmas |
| | targets = noise * alphas - diffusion_input * sigmas |
| |
|
| | with torch.cuda.amp.autocast(): |
| | v = self.diffusion(noised_inputs, t) |
| |
|
| | loss_info.update({ |
| | "v": v, |
| | "targets": targets |
| | }) |
| |
|
| | loss, losses = self.losses(loss_info) |
| |
|
| | log_dict = { |
| | 'train/loss': loss.detach(), |
| | 'train/std_data': diffusion_input.std(), |
| | } |
| |
|
| | for loss_name, loss_value in losses.items(): |
| | log_dict[f"train/{loss_name}"] = loss_value.detach() |
| |
|
| | self.log_dict(log_dict, prog_bar=True, on_step=True) |
| | return loss |
| | |
| | def on_before_zero_grad(self, *args, **kwargs): |
| | self.diffusion_ema.update() |
| |
|
| | def export_model(self, path, use_safetensors=False): |
| |
|
| | self.diffusion.model = self.diffusion_ema.ema_model |
| | |
| | if use_safetensors: |
| | save_file(self.diffusion.state_dict(), path) |
| | else: |
| | torch.save({"state_dict": self.diffusion.state_dict()}, path) |
| |
|
| | class DiffusionUncondDemoCallback(pl.Callback): |
| | def __init__(self, |
| | demo_every=2000, |
| | num_demos=8, |
| | demo_steps=250, |
| | sample_rate=48000 |
| | ): |
| | super().__init__() |
| |
|
| | self.demo_every = demo_every |
| | self.num_demos = num_demos |
| | self.demo_steps = demo_steps |
| | self.sample_rate = sample_rate |
| | self.last_demo_step = -1 |
| | |
| | @rank_zero_only |
| | @torch.no_grad() |
| | def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): |
| |
|
| | if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: |
| | return |
| | |
| | self.last_demo_step = trainer.global_step |
| |
|
| | demo_samples = module.diffusion.sample_size |
| |
|
| | if module.diffusion.pretransform is not None: |
| | demo_samples = demo_samples // module.diffusion.pretransform.downsampling_ratio |
| |
|
| | noise = torch.randn([self.num_demos, module.diffusion.io_channels, demo_samples]).to(module.device) |
| |
|
| | try: |
| | with torch.cuda.amp.autocast(): |
| | fakes = sample(module.diffusion_ema, noise, self.demo_steps, 0) |
| |
|
| | if module.diffusion.pretransform is not None: |
| | fakes = module.diffusion.pretransform.decode(fakes) |
| |
|
| | |
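            # Put the demos together into one long clip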
| | fakes = rearrange(fakes, 'b d n -> d (b n)') |
| |
|
| | log_dict = {} |
| | |
| | filename = f'demo_{trainer.global_step:08}.wav' |
| | fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu() |
| | torchaudio.save(filename, fakes, self.sample_rate) |
| |
|
| | log_dict[f'demo'] = wandb.Audio(filename, |
| | sample_rate=self.sample_rate, |
| | caption=f'Reconstructed') |
| | |
| | log_dict[f'demo_melspec_left'] = wandb.Image(audio_spectrogram_image(fakes)) |
| |
|
| | trainer.logger.experiment.log(log_dict) |
| |
|
| | del fakes |
| | |
| | except Exception as e: |
| | print(f'{type(e).__name__}: {e}') |
| | finally: |
| | gc.collect() |
| | torch.cuda.empty_cache() |
| |
|
| | class DiffusionCondTrainingWrapper(pl.LightningModule): |
| | ''' |
| | Wrapper for training a conditional audio diffusion model. |
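    A minimal usage sketch (illustrative only; assumes `model` is a ConditionedDiffusionModelWrapper
    and `train_dl` yields (audio, metadata) batches compatible with the model's conditioner):

        wrapper = DiffusionCondTrainingWrapper(model, lr=1e-4, use_ema=True, timestep_sampler="uniform")
        trainer = pl.Trainer(devices=1, max_steps=500000)
        trainer.fit(wrapper, train_dl)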
| | ''' |
| | def __init__( |
| | self, |
| | model: ConditionedDiffusionModelWrapper, |
| | lr: float = None, |
| | mask_padding: bool = False, |
| | mask_padding_dropout: float = 0.0, |
| | use_ema: bool = True, |
| | log_loss_info: bool = True, |
| | optimizer_configs: dict = None, |
| | pre_encoded: bool = False, |
| | cfg_dropout_prob = 0.1, |
| | timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform", |
| | ): |
| | super().__init__() |
| |
|
| | self.diffusion = model |
| |
|
| | if use_ema: |
| | self.diffusion_ema = EMA( |
| | self.diffusion.model, |
| | beta=0.9999, |
| | power=3/4, |
| | update_every=1, |
| | update_after_step=1, |
| | include_online_model=False |
| | ) |
| | else: |
| | self.diffusion_ema = None |
| |
|
| | self.mask_padding = mask_padding |
| | self.mask_padding_dropout = mask_padding_dropout |
| |
|
| | self.cfg_dropout_prob = cfg_dropout_prob |
| |
|
| | self.rng = torch.quasirandom.SobolEngine(1, scramble=True) |
| |
|
| | self.timestep_sampler = timestep_sampler |
| |
|
| | self.diffusion_objective = model.diffusion_objective |
| |
|
        if optimizer_configs is not None and 'av_loss' in optimizer_configs and optimizer_configs['av_loss']['if_add_av_loss']:
| | av_align_weight = optimizer_configs['av_loss']['config']['weight'] |
| | self.loss_modules = [ |
| | MSELoss("output", |
| | "targets", |
| | weight=1.0 - av_align_weight, |
| | mask_key="padding_mask" if self.mask_padding else None, |
| | name="mse_loss" |
| | ) |
| | ] |
| | else: |
| | self.loss_modules = [ |
| | MSELoss("output", |
| | "targets", |
| | weight=1.0, |
| | mask_key="padding_mask" if self.mask_padding else None, |
| | name="mse_loss" |
| | ) |
| | ] |
| |
|
| |
|
| | self.losses = MultiLoss(self.loss_modules) |
| |
|
| | self.log_loss_info = log_loss_info |
| |
|
| | assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" |
| |
|
| | if optimizer_configs is None: |
| | optimizer_configs = { |
| | "diffusion": { |
| | "optimizer": { |
| | "type": "Adam", |
| | "config": { |
| | "lr": lr |
| | } |
| | } |
| | } |
| | } |
| | else: |
| | if lr is not None: |
| | print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.") |
| |
|
| | self.optimizer_configs = optimizer_configs |
| |
|
| | self.pre_encoded = pre_encoded |
| |
|
| | def configure_optimizers(self): |
| | diffusion_opt_config = self.optimizer_configs['diffusion'] |
| | opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters()) |
| |
|
| | if "scheduler" in diffusion_opt_config: |
| | sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff) |
| | sched_diff_config = { |
| | "scheduler": sched_diff, |
| | "interval": "step" |
| | } |
| | return [opt_diff], [sched_diff_config] |
| |
|
| | return [opt_diff] |
| |
|
| | def training_step(self, batch, batch_idx): |
| |
|
| |
|
| | reals, metadata = batch |
| |
|
| | p = Profiler() |
| |
|
| | if reals.ndim == 4 and reals.shape[0] == 1: |
| | reals = reals[0] |
| |
|
| | loss_info = {} |
| |
|
| | diffusion_input = reals |
| | if not self.pre_encoded: |
| | loss_info["audio_reals"] = diffusion_input |
| |
|
| | p.tick("setup") |
| |
|
| | with torch.cuda.amp.autocast(): |
| | conditioning = self.diffusion.conditioner(metadata, self.device) |
| |
|
| | use_padding_mask = self.mask_padding and random.random() > self.mask_padding_dropout |
| |
|
| | |
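        # Stack the per-item padding masks from the metadata into a batch tensor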
| | if use_padding_mask: |
| | padding_masks = torch.stack([md["padding_mask"][0] for md in metadata], dim=0).to(self.device) |
| |
|
| | p.tick("conditioning") |
| |
|
| | if self.diffusion.pretransform is not None: |
| | self.diffusion.pretransform.to(self.device) |
| |
|
| | if not self.pre_encoded: |
                with torch.cuda.amp.autocast(), torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
| | self.diffusion.pretransform.train(self.diffusion.pretransform.enable_grad) |
| |
|
| | diffusion_input = self.diffusion.pretransform.encode(diffusion_input) |
| | p.tick("pretransform") |
| |
|
| | |
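                # If masking padding, interpolate the padding masks to the length of the pretransformed input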
| | if use_padding_mask: |
| | padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool() |
| | else: |
| | |
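                # Apply scale to pre-encoded latents if needed, since the pretransform encode step is skipped here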
| | if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: |
| | diffusion_input = diffusion_input / self.diffusion.pretransform.scale |
| |
|
| | if self.timestep_sampler == "uniform": |
| | |
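            # Draw uniformly distributed continuous timesteps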
| | t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) |
| | elif self.timestep_sampler == "logit_normal": |
| | t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device)) |
| |
|
| | |
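        # Calculate the noise schedule parameters for those timesteps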
| | if self.diffusion_objective == "v": |
| | alphas, sigmas = get_alphas_sigmas(t) |
| | elif self.diffusion_objective == "rectified_flow": |
| | alphas, sigmas = 1-t, t |
| |
|
| | |
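        # Combine the ground truth data and the noise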
| | alphas = alphas[:, None, None] |
| | sigmas = sigmas[:, None, None] |
| | noise = torch.randn_like(diffusion_input) |
| | noised_inputs = diffusion_input * alphas + noise * sigmas |
| |
|
| | if self.diffusion_objective == "v": |
| | targets = noise * alphas - diffusion_input * sigmas |
| | elif self.diffusion_objective == "rectified_flow": |
| | targets = noise - diffusion_input |
| |
|
| | p.tick("noise") |
| |
|
| | extra_args = {} |
| |
|
| | if use_padding_mask: |
| | extra_args["mask"] = padding_masks |
| |
|
| | with torch.cuda.amp.autocast(): |
| | p.tick("amp") |
| | output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args) |
| | p.tick("diffusion") |
| |
|
| | loss_info.update({ |
| | "output": output, |
| | "targets": targets, |
| | "padding_mask": padding_masks if use_padding_mask else None, |
| | }) |
| |
|
| | loss, losses = self.losses(loss_info) |
| |
|
| | p.tick("loss") |
| |
|
| | if self.log_loss_info: |
| | |
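            # Loss debugging logs: break the loss down by noise level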
| | num_loss_buckets = 10 |
| | bucket_size = 1 / num_loss_buckets |
| | loss_all = F.mse_loss(output, targets, reduction="none") |
| |
|
| | sigmas = rearrange(self.all_gather(sigmas), "b c n -> (b) c n").squeeze() |
| |
|
| | |
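            # Gather the per-element losses from all devices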
| | loss_all = rearrange(self.all_gather(loss_all), "b c n -> (b) c n") |
| |
|
| | |
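            # Bucket the loss values by their corresponding sigma values, in increments of bucket_size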
| | loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)]) |
| |
|
| | |
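            # Log the bucketed losses, skipping any empty (NaN) buckets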
| | debug_log_dict = { |
| | f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i]) |
| | } |
| |
|
| | self.log_dict(debug_log_dict) |
| |
|
| |
|
| | log_dict = { |
| | 'train/loss': loss.detach(), |
| | 'train/std_data': diffusion_input.std(), |
| | 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] |
| | } |
| |
|
| | for loss_name, loss_value in losses.items(): |
| | log_dict[f"train/{loss_name}"] = loss_value.detach() |
| |
|
| | self.log_dict(log_dict, prog_bar=True, on_step=True) |
| | p.tick("log") |
| | |
| | return loss |
| |
|
| | def validation_step(self, batch, batch_idx): |
| | reals, metadata = batch |
| |
|
| | p = Profiler() |
| |
|
| | if reals.ndim == 4 and reals.shape[0] == 1: |
| | reals = reals[0] |
| |
|
| | loss_info = {} |
| |
|
| | diffusion_input = reals |
| |
|
| | if not self.pre_encoded: |
| | loss_info["audio_reals"] = diffusion_input |
| |
|
| | p.tick("setup") |
| | with torch.cuda.amp.autocast(): |
| | conditioning = self.diffusion.conditioner(metadata, self.device) |
| | |
| | |
| | use_padding_mask = self.mask_padding and random.random() > self.mask_padding_dropout |
| |
|
| | |
| | if use_padding_mask: |
| | padding_masks = torch.stack([md["padding_mask"][0] for md in metadata], dim=0).to(self.device) |
| |
|
| | p.tick("conditioning") |
| |
|
| | if self.diffusion.pretransform is not None: |
| | self.diffusion.pretransform.to(self.device) |
| |
|
| | if not self.pre_encoded: |
                with torch.cuda.amp.autocast(), torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
| | self.diffusion.pretransform.train(self.diffusion.pretransform.enable_grad) |
| | |
| | diffusion_input = self.diffusion.pretransform.encode(diffusion_input) |
| | p.tick("pretransform") |
| |
|
| | |
| | if use_padding_mask: |
| | padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool() |
| | else: |
| | |
| | if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: |
| | diffusion_input = diffusion_input / self.diffusion.pretransform.scale |
| |
|
| | if self.timestep_sampler == "uniform": |
| | |
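            # Draw uniformly distributed continuous timesteps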
| | t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) |
| | elif self.timestep_sampler == "logit_normal": |
| | t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device)) |
| | |
| | |
| | if self.diffusion_objective == "v": |
| | alphas, sigmas = get_alphas_sigmas(t) |
| | elif self.diffusion_objective == "rectified_flow": |
| | alphas, sigmas = 1-t, t |
| |
|
| | |
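        # Combine the ground truth data and the noise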
| | alphas = alphas[:, None, None] |
| | sigmas = sigmas[:, None, None] |
| | noise = torch.randn_like(diffusion_input) |
| | noised_inputs = diffusion_input * alphas + noise * sigmas |
| |
|
| | if self.diffusion_objective == "v": |
| | targets = noise * alphas - diffusion_input * sigmas |
| | elif self.diffusion_objective == "rectified_flow": |
| | targets = noise - diffusion_input |
| |
|
| | p.tick("noise") |
| |
|
| | extra_args = {} |
| |
|
| | if use_padding_mask: |
| | extra_args["mask"] = padding_masks |
| |
|
| | with torch.cuda.amp.autocast(): |
| | p.tick("amp") |
| | |
| | output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args) |
| | p.tick("diffusion") |
| |
|
| | loss_info.update({ |
| | "output": output, |
| | "targets": targets, |
| | "padding_mask": padding_masks if use_padding_mask else None, |
| | }) |
| |
|
| | loss, losses = self.losses(loss_info) |
| |
|
| | p.tick("loss") |
| |
|
| | if self.log_loss_info: |
| | |
| | num_loss_buckets = 10 |
| | bucket_size = 1 / num_loss_buckets |
| | loss_all = F.mse_loss(output, targets, reduction="none") |
| | |
| | |
| |
|
| | sigmas = rearrange(self.all_gather(sigmas), "b c n -> (b) c n").squeeze() |
| | |
| |
|
| | |
| | loss_all = rearrange(self.all_gather(loss_all), "b c n -> (b) c n") |
| | |
| |
|
| | |
| | loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)]) |
| |
|
| | |
| | debug_log_dict = { |
| | f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i]) |
| | } |
| |
|
| | self.log_dict(debug_log_dict) |
| |
|
| |
|
| | log_dict = { |
| | 'valid/loss': loss.detach(), |
| | 'valid/std_data': diffusion_input.std(), |
| | 'valid/lr': self.trainer.optimizers[0].param_groups[0]['lr'] |
| | } |
| |
|
| |
|
| | for loss_name, loss_value in losses.items(): |
| | log_dict[f"valid/{loss_name}"] = loss_value.detach() |
| |
|
| | self.log_dict(log_dict, prog_bar=True, on_step=True) |
| | |
| |
|
| | p.tick("log") |
| | |
| | return loss |
| | |
| | def on_before_zero_grad(self, *args, **kwargs): |
| | if self.diffusion_ema is not None: |
| | self.diffusion_ema.update() |
| |
|
| | def export_model(self, path, use_safetensors=False): |
| | if self.diffusion_ema is not None: |
| | self.diffusion.model = self.diffusion_ema.ema_model |
| | |
| | if use_safetensors: |
| | save_file(self.diffusion.state_dict(), path) |
| | else: |
| | torch.save({"state_dict": self.diffusion.state_dict()}, path) |
| |
|
| | class DiffusionCondDemoCallback(pl.Callback): |
| | def __init__(self, |
| | demo_every=2000, |
| | num_demos=8, |
| | sample_size=65536, |
| | demo_steps=250, |
| | sample_rate=48000, |
| | demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = {}, |
| | demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7], |
| | demo_cond_from_batch: bool = False, |
| | display_audio_cond: bool = False |
| | ): |
| | super().__init__() |
| |
|
| | self.demo_every = demo_every |
| | self.num_demos = num_demos |
| | self.demo_samples = sample_size |
| | self.demo_steps = demo_steps |
| | self.sample_rate = sample_rate |
| | self.last_demo_step = -1 |
| | self.demo_conditioning = demo_conditioning |
| | self.demo_cfg_scales = demo_cfg_scales |
| |
|
| | |
| | self.demo_cond_from_batch = demo_cond_from_batch |
| |
|
| | |
| | self.display_audio_cond = display_audio_cond |
| |
|
| | @rank_zero_only |
| | @torch.no_grad() |
| | def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx): |
| |
|
| | if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: |
| | return |
| |
|
| | module.eval() |
| |
|
| | print(f"Generating demo") |
| | self.last_demo_step = trainer.global_step |
| |
|
| | demo_samples = self.demo_samples |
| |
|
| | demo_cond = self.demo_conditioning |
| |
|
| | if self.demo_cond_from_batch: |
| | |
| | demo_cond = batch[1][:self.num_demos] |
| |
|
| | if module.diffusion.pretransform is not None: |
| | demo_samples = demo_samples // module.diffusion.pretransform.downsampling_ratio |
| |
|
| | noise = torch.randn([self.num_demos, module.diffusion.io_channels, demo_samples]).to(module.device) |
| |
|
| | try: |
| | print("Getting conditioning") |
| | with torch.cuda.amp.autocast(): |
| | conditioning = module.diffusion.conditioner(demo_cond, module.device) |
| |
|
| |
|
| | cond_inputs = module.diffusion.get_conditioning_inputs(conditioning) |
| |
|
| | log_dict = {} |
| |
|
| | if self.display_audio_cond: |
| | audio_inputs = torch.cat([cond["audio"] for cond in demo_cond], dim=0) |
| | audio_inputs = rearrange(audio_inputs, 'b d n -> d (b n)') |
| |
|
| | filename = f'demo_audio_cond_{trainer.global_step:08}.wav' |
| | audio_inputs = audio_inputs.to(torch.float32).mul(32767).to(torch.int16).cpu() |
| | torchaudio.save(filename, audio_inputs, self.sample_rate) |
| | log_dict[f'demo_audio_cond'] = wandb.Audio(filename, sample_rate=self.sample_rate, caption="Audio conditioning") |
| | log_dict[f"demo_audio_cond_melspec_left"] = wandb.Image(audio_spectrogram_image(audio_inputs)) |
| | trainer.logger.experiment.log(log_dict) |
| |
|
| | for cfg_scale in self.demo_cfg_scales: |
| |
|
| | print(f"Generating demo for cfg scale {cfg_scale}") |
| | |
| | with torch.cuda.amp.autocast(): |
| | model = module.diffusion_ema.model if module.diffusion_ema is not None else module.diffusion.model |
| |
|
| | if module.diffusion_objective == "v": |
| | fakes = sample(model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) |
| | elif module.diffusion_objective == "rectified_flow": |
| | fakes = sample_discrete_euler(model, noise, self.demo_steps, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) |
| | |
| | if module.diffusion.pretransform is not None: |
| | fakes = module.diffusion.pretransform.decode(fakes) |
| |
|
| | |
| | fakes = rearrange(fakes, 'b d n -> d (b n)') |
| |
|
| | log_dict = {} |
| | |
| | filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' |
| | fakes = fakes.div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu() |
| | torchaudio.save(filename, fakes, self.sample_rate) |
| |
|
| | log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename, |
| | sample_rate=self.sample_rate, |
| | caption=f'Reconstructed') |
| | |
| | log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes)) |
| |
|
| | trainer.logger.experiment.log(log_dict) |
| | |
| | del fakes |
| |
|
| | except Exception as e: |
| | raise e |
| | finally: |
| | gc.collect() |
| | torch.cuda.empty_cache() |
| | module.train() |
| |
|
| | class DiffusionCondInpaintTrainingWrapper(pl.LightningModule): |
| | ''' |
    Wrapper for training a conditional audio diffusion model for inpainting.
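    A minimal usage sketch (illustrative only; assumes `model` is a ConditionedDiffusionModelWrapper
    whose conditioner accepts the "inpaint_mask" and "inpaint_masked_input" keys set below, and
    `train_dl` yields (audio, metadata) batches):

        wrapper = DiffusionCondInpaintTrainingWrapper(model, lr=1e-4, max_mask_segments=10)
        trainer = pl.Trainer(devices=1, max_steps=500000)
        trainer.fit(wrapper, train_dl)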
| | ''' |
| | def __init__( |
| | self, |
| | model: ConditionedDiffusionModelWrapper, |
| | lr: float = 1e-4, |
| | max_mask_segments = 10, |
| | log_loss_info: bool = False, |
| | optimizer_configs: dict = None, |
| | use_ema: bool = True, |
| | pre_encoded: bool = False, |
| | cfg_dropout_prob = 0.1, |
| | timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform", |
| | ): |
| | super().__init__() |
| |
|
| | self.diffusion = model |
| | |
| | self.use_ema = use_ema |
| |
|
| | if self.use_ema: |
| | self.diffusion_ema = EMA( |
| | self.diffusion.model, |
| | beta=0.9999, |
| | power=3/4, |
| | update_every=1, |
| | update_after_step=1, |
| | include_online_model=False |
| | ) |
| | else: |
| | self.diffusion_ema = None |
| |
|
| | self.cfg_dropout_prob = cfg_dropout_prob |
| |
|
| | self.lr = lr |
| | self.max_mask_segments = max_mask_segments |
| |
|
| | self.rng = torch.quasirandom.SobolEngine(1, scramble=True) |
| | |
| | self.timestep_sampler = timestep_sampler |
| |
|
| | self.diffusion_objective = model.diffusion_objective |
| |
|
| | self.loss_modules = [ |
| | MSELoss("output", |
| | "targets", |
| | weight=1.0, |
| | name="mse_loss" |
| | ) |
| | ] |
| |
|
| | self.losses = MultiLoss(self.loss_modules) |
| |
|
| | self.log_loss_info = log_loss_info |
| |
|
| | assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" |
| |
|
| | if optimizer_configs is None: |
| | optimizer_configs = { |
| | "diffusion": { |
| | "optimizer": { |
| | "type": "Adam", |
| | "config": { |
| | "lr": lr |
| | } |
| | } |
| | } |
| | } |
| | else: |
| | if lr is not None: |
| | print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.") |
| |
|
| | self.optimizer_configs = optimizer_configs |
| |
|
| | self.pre_encoded = pre_encoded |
| |
|
| | def configure_optimizers(self): |
| | diffusion_opt_config = self.optimizer_configs['diffusion'] |
| | opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters()) |
| |
|
| | if "scheduler" in diffusion_opt_config: |
| | sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff) |
| | sched_diff_config = { |
| | "scheduler": sched_diff, |
| | "interval": "step" |
| | } |
| | return [opt_diff], [sched_diff_config] |
| |
|
| | return [opt_diff] |
| |
|
| | def random_mask(self, sequence, max_mask_length): |
| | b, _, sequence_length = sequence.size() |
| |
|
| | |
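        # Create a mask tensor for each batch element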
| | masks = [] |
| |
|
| | for i in range(b): |
| | mask_type = random.randint(0, 2) |
| |
|
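            # mask_type 0: mask random segments, 1: mask everything, 2: mask the end of the sequence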
| | if mask_type == 0: |
| | num_segments = random.randint(1, self.max_mask_segments) |
| | max_segment_length = max_mask_length // num_segments |
| |
|
| | segment_lengths = random.sample(range(1, max_segment_length + 1), num_segments) |
| | |
| | mask = torch.ones((1, 1, sequence_length)) |
| | for length in segment_lengths: |
| | mask_start = random.randint(0, sequence_length - length) |
| | mask[:, :, mask_start:mask_start + length] = 0 |
| |
|
| | elif mask_type == 1: |
| | mask = torch.zeros((1, 1, sequence_length)) |
| |
|
| | elif mask_type == 2: |
| | mask = torch.ones((1, 1, sequence_length)) |
| | mask_length = random.randint(1, max_mask_length) |
| | mask[:, :, -mask_length:] = 0 |
| |
|
| | mask = mask.to(sequence.device) |
| | masks.append(mask) |
| |
|
| | |
| | mask = torch.cat(masks, dim=0).to(sequence.device) |
| |
|
| | |
| | masked_sequence = sequence * mask |
| |
|
| | return masked_sequence, mask |
| |
|
| | def training_step(self, batch, batch_idx): |
| | reals, metadata = batch |
| |
|
| | p = Profiler() |
| |
|
| | if reals.ndim == 4 and reals.shape[0] == 1: |
| | reals = reals[0] |
| |
|
| | loss_info = {} |
| |
|
| | diffusion_input = reals |
| |
|
| | if not self.pre_encoded: |
| | loss_info["audio_reals"] = diffusion_input |
| |
|
| | p.tick("setup") |
| |
|
| | with torch.cuda.amp.autocast(): |
| | conditioning = self.diffusion.conditioner(metadata, self.device) |
| |
|
| | p.tick("conditioning") |
| |
|
| | if self.diffusion.pretransform is not None: |
| | self.diffusion.pretransform.to(self.device) |
| |
|
| | if not self.pre_encoded: |
                with torch.cuda.amp.autocast(), torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
| | diffusion_input = self.diffusion.pretransform.encode(diffusion_input) |
| | p.tick("pretransform") |
| |
|
| | |
| | |
| | |
| | else: |
| | |
| | if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: |
| | diffusion_input = diffusion_input / self.diffusion.pretransform.scale |
| |
|
| | |
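        # The maximum mask length is the full (latent) sequence length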
| | max_mask_length = diffusion_input.shape[2] |
| |
|
| | |
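        # Create a mask of random length for a random slice of the input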
| | masked_input, mask = self.random_mask(diffusion_input, max_mask_length) |
| |
|
| | conditioning['inpaint_mask'] = [mask] |
| | conditioning['inpaint_masked_input'] = [masked_input] |
| |
|
| | if self.timestep_sampler == "uniform": |
| | |
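            # Draw uniformly distributed continuous timesteps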
| | t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) |
| | elif self.timestep_sampler == "logit_normal": |
| | t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device)) |
| | |
| | |
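        # Calculate the noise schedule parameters for those timesteps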
| | if self.diffusion_objective == "v": |
| | alphas, sigmas = get_alphas_sigmas(t) |
| | elif self.diffusion_objective == "rectified_flow": |
| | alphas, sigmas = 1-t, t |
| |
|
| | |
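        # Combine the ground truth data and the noise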
| | alphas = alphas[:, None, None] |
| | sigmas = sigmas[:, None, None] |
| | noise = torch.randn_like(diffusion_input) |
| | noised_inputs = diffusion_input * alphas + noise * sigmas |
| |
|
| | if self.diffusion_objective == "v": |
| | targets = noise * alphas - diffusion_input * sigmas |
| | elif self.diffusion_objective == "rectified_flow": |
| | targets = noise - diffusion_input |
| |
|
| | p.tick("noise") |
| |
|
| | extra_args = {} |
| |
|
| | with torch.cuda.amp.autocast(): |
| | p.tick("amp") |
| | output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args) |
| | p.tick("diffusion") |
| |
|
| | loss_info.update({ |
| | "output": output, |
| | "targets": targets, |
| | }) |
| |
|
| | loss, losses = self.losses(loss_info) |
| |
|
| | if self.log_loss_info: |
| | |
| | num_loss_buckets = 10 |
| | bucket_size = 1 / num_loss_buckets |
| | loss_all = F.mse_loss(output, targets, reduction="none") |
| |
|
| | sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze() |
| |
|
| | |
| | loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n") |
| |
|
| | |
| | loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)]) |
| |
|
| | |
| | debug_log_dict = { |
| | f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i]) |
| | } |
| |
|
| | self.log_dict(debug_log_dict) |
| |
|
| | log_dict = { |
| | 'train/loss': loss.detach(), |
| | 'train/std_data': diffusion_input.std(), |
| | 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] |
| | } |
| |
|
| | for loss_name, loss_value in losses.items(): |
| | log_dict[f"train/{loss_name}"] = loss_value.detach() |
| |
|
| | self.log_dict(log_dict, prog_bar=True, on_step=True) |
| | p.tick("log") |
| | |
| | return loss |
| | |
| | def on_before_zero_grad(self, *args, **kwargs): |
| | if self.diffusion_ema is not None: |
| | self.diffusion_ema.update() |
| |
|
| | def export_model(self, path, use_safetensors=False): |
| | if self.diffusion_ema is not None: |
| | self.diffusion.model = self.diffusion_ema.ema_model |
| | |
| | if use_safetensors: |
| | save_file(self.diffusion.state_dict(), path) |
| | else: |
| | torch.save({"state_dict": self.diffusion.state_dict()}, path) |
| |
|
| | class DiffusionCondInpaintDemoCallback(pl.Callback): |
| | def __init__( |
| | self, |
| | demo_dl, |
| | demo_every=2000, |
| | demo_steps=250, |
| | sample_size=65536, |
| | sample_rate=48000, |
| | demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7] |
| | ): |
| | super().__init__() |
| | self.demo_every = demo_every |
| | self.demo_steps = demo_steps |
| | self.demo_samples = sample_size |
| | self.demo_dl = iter(demo_dl) |
| | self.sample_rate = sample_rate |
| | self.demo_cfg_scales = demo_cfg_scales |
| | self.last_demo_step = -1 |
| |
|
| | @rank_zero_only |
| | @torch.no_grad() |
| | def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx): |
| | if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: |
| | return |
| | |
| | self.last_demo_step = trainer.global_step |
| |
|
| | try: |
| | log_dict = {} |
| |
|
| | demo_reals, metadata = next(self.demo_dl) |
| |
|
| | |
| | if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: |
| | demo_reals = demo_reals[0] |
| |
|
| | demo_reals = demo_reals.to(module.device) |
| |
|
| | if not module.pre_encoded: |
| | |
| | log_dict[f'demo_reals_melspec_left'] = wandb.Image(audio_spectrogram_image(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu())) |
| | |
| |
|
| | if module.diffusion.pretransform is not None: |
| | module.diffusion.pretransform.to(module.device) |
| | with torch.cuda.amp.autocast(): |
| | demo_reals = module.diffusion.pretransform.encode(demo_reals) |
| |
|
| | demo_samples = demo_reals.shape[2] |
| |
|
| | |
| | conditioning = module.diffusion.conditioner(metadata, module.device) |
| |
|
| | masked_input, mask = module.random_mask(demo_reals, demo_reals.shape[2]) |
| |
|
| | conditioning['inpaint_mask'] = [mask] |
| | conditioning['inpaint_masked_input'] = [masked_input] |
| |
|
| | if module.diffusion.pretransform is not None: |
| | log_dict[f'demo_masked_input'] = wandb.Image(tokens_spectrogram_image(masked_input.cpu())) |
| | else: |
| | log_dict[f'demo_masked_input'] = wandb.Image(audio_spectrogram_image(rearrange(masked_input, "b c t -> c (b t)").mul(32767).to(torch.int16).cpu())) |
| |
|
| | cond_inputs = module.diffusion.get_conditioning_inputs(conditioning) |
| |
|
| | noise = torch.randn([demo_reals.shape[0], module.diffusion.io_channels, demo_samples]).to(module.device) |
| |
|
| | trainer.logger.experiment.log(log_dict) |
| |
|
| | for cfg_scale in self.demo_cfg_scales: |
| | model = module.diffusion_ema.model if module.diffusion_ema is not None else module.diffusion.model |
| | print(f"Generating demo for cfg scale {cfg_scale}") |
| |
|
| | if module.diffusion_objective == "v": |
| | fakes = sample(model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) |
| | elif module.diffusion_objective == "rectified_flow": |
| | fakes = sample_discrete_euler(model, noise, self.demo_steps, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) |
| |
|
| | if module.diffusion.pretransform is not None: |
| | with torch.cuda.amp.autocast(): |
| | fakes = module.diffusion.pretransform.decode(fakes) |
| |
|
| | |
| | fakes = rearrange(fakes, 'b d n -> d (b n)') |
| |
|
| | log_dict = {} |
| | |
| | filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' |
| | fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu() |
| | torchaudio.save(filename, fakes, self.sample_rate) |
| |
|
| | log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename, |
| | sample_rate=self.sample_rate, |
| | caption=f'Reconstructed') |
| | |
| | log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes)) |
| |
|
| | trainer.logger.experiment.log(log_dict) |
| | except Exception as e: |
| | print(f'{type(e).__name__}: {e}') |
| | raise e |
| |
|
| | class DiffusionAutoencoderTrainingWrapper(pl.LightningModule): |
| | ''' |
| | Wrapper for training a diffusion autoencoder |
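    A minimal usage sketch (illustrative only; assumes `model` is a constructed DiffusionAutoencoder,
    `ema_copy` is a second copy of it used to hold the EMA weights, and `train_dl` is an audio DataLoader):

        wrapper = DiffusionAutoencoderTrainingWrapper(model, lr=1e-4, ema_copy=ema_copy,
                                                      use_reconstruction_loss=True)
        trainer = pl.Trainer(devices=1, max_steps=500000)
        trainer.fit(wrapper, train_dl)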
| | ''' |
| | def __init__( |
| | self, |
| | model: DiffusionAutoencoder, |
| | lr: float = 1e-4, |
| | ema_copy = None, |
| | use_reconstruction_loss: bool = False |
| | ): |
| | super().__init__() |
| |
|
| | self.diffae = model |
| | |
| | self.diffae_ema = EMA( |
| | self.diffae, |
| | ema_model=ema_copy, |
| | beta=0.9999, |
| | power=3/4, |
| | update_every=1, |
| | update_after_step=1, |
| | include_online_model=False |
| | ) |
| |
|
| | self.lr = lr |
| |
|
| | self.rng = torch.quasirandom.SobolEngine(1, scramble=True) |
| |
|
| | loss_modules = [ |
| | MSELoss("v", |
| | "targets", |
| | weight=1.0, |
| | name="mse_loss" |
| | ) |
| | ] |
| |
|
| | if model.bottleneck is not None: |
| | |
| | loss_modules += create_loss_modules_from_bottleneck(model.bottleneck, {}) |
| |
|
| | self.use_reconstruction_loss = use_reconstruction_loss |
| |
|
| | if use_reconstruction_loss: |
| | scales = [2048, 1024, 512, 256, 128, 64, 32] |
| | hop_sizes = [] |
| | win_lengths = [] |
| | overlap = 0.75 |
| | for s in scales: |
| | hop_sizes.append(int(s * (1 - overlap))) |
| | win_lengths.append(s) |
| |
|
| | sample_rate = model.sample_rate |
| |
|
| | stft_loss_args = { |
| | "fft_sizes": scales, |
| | "hop_sizes": hop_sizes, |
| | "win_lengths": win_lengths, |
| | "perceptual_weighting": True |
| | } |
| |
|
| | out_channels = model.out_channels |
| |
|
| | if model.pretransform is not None: |
| | out_channels = model.pretransform.io_channels |
| |
|
| | if out_channels == 2: |
| | self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args) |
| | else: |
| | self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) |
| |
|
| | loss_modules.append( |
| | AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), |
| | ) |
| |
|
| | self.losses = MultiLoss(loss_modules) |
| |
|
| | def configure_optimizers(self): |
| | return optim.Adam([*self.diffae.parameters()], lr=self.lr) |
| |
|
| | def training_step(self, batch, batch_idx): |
| | reals = batch[0] |
| |
|
| | if reals.ndim == 4 and reals.shape[0] == 1: |
| | reals = reals[0] |
| |
|
| | loss_info = {} |
| |
|
| | loss_info["audio_reals"] = reals |
| | |
| | if self.diffae.pretransform is not None: |
| | with torch.no_grad(): |
| | reals = self.diffae.pretransform.encode(reals) |
| |
|
| | loss_info["reals"] = reals |
| |
|
| | |
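        # Encode the audio to latents, skipping the pretransform since it was already applied above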
| | latents, encoder_info = self.diffae.encode(reals, return_info=True, skip_pretransform=True) |
| |
|
| | loss_info["latents"] = latents |
| | loss_info.update(encoder_info) |
| |
|
| | if self.diffae.decoder is not None: |
| | latents = self.diffae.decoder(latents) |
| | |
| | |
| | if latents.shape[2] != reals.shape[2]: |
| | latents = F.interpolate(latents, size=reals.shape[2], mode='nearest') |
| |
|
| | loss_info["latents_upsampled"] = latents |
| |
|
| | |
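        # Draw uniformly distributed continuous timesteps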
| | t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) |
| |
|
| | |
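        # Calculate the noise schedule parameters for those timesteps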
| | alphas, sigmas = get_alphas_sigmas(t) |
| |
|
| | |
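        # Combine the ground truth data and the noise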
| | alphas = alphas[:, None, None] |
| | sigmas = sigmas[:, None, None] |
| | noise = torch.randn_like(reals) |
| | noised_reals = reals * alphas + noise * sigmas |
| | targets = noise * alphas - reals * sigmas |
| |
|
| | with torch.cuda.amp.autocast(): |
| | v = self.diffae.diffusion(noised_reals, t, input_concat_cond=latents) |
| | |
| | loss_info.update({ |
| | "v": v, |
| | "targets": targets |
| | }) |
| |
|
| | if self.use_reconstruction_loss: |
| | pred = noised_reals * alphas - v * sigmas |
| |
|
| | loss_info["pred"] = pred |
| |
|
| | if self.diffae.pretransform is not None: |
| | pred = self.diffae.pretransform.decode(pred) |
| | loss_info["audio_pred"] = pred |
| |
|
| | loss, losses = self.losses(loss_info) |
| |
|
| | log_dict = { |
| | 'train/loss': loss.detach(), |
| | 'train/std_data': reals.std(), |
| | 'train/latent_std': latents.std(), |
| | } |
| |
|
| | for loss_name, loss_value in losses.items(): |
| | log_dict[f"train/{loss_name}"] = loss_value.detach() |
| |
|
| | self.log_dict(log_dict, prog_bar=True, on_step=True) |
| | return loss |
| | |
| | def on_before_zero_grad(self, *args, **kwargs): |
| | self.diffae_ema.update() |
| |
|
| | def export_model(self, path, use_safetensors=False): |
| |
|
| | model = self.diffae_ema.ema_model |
| | |
| | if use_safetensors: |
| | save_file(model.state_dict(), path) |
| | else: |
| | torch.save({"state_dict": model.state_dict()}, path) |
| |
|
| | class DiffusionAutoencoderDemoCallback(pl.Callback): |
| | def __init__( |
| | self, |
| | demo_dl, |
| | demo_every=2000, |
| | demo_steps=250, |
| | sample_size=65536, |
| | sample_rate=48000 |
| | ): |
| | super().__init__() |
| | self.demo_every = demo_every |
| | self.demo_steps = demo_steps |
| | self.demo_samples = sample_size |
| | self.demo_dl = iter(demo_dl) |
| | self.sample_rate = sample_rate |
| | self.last_demo_step = -1 |
| |
|
| | @rank_zero_only |
| | @torch.no_grad() |
| | def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx): |
| | if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: |
| | return |
| | |
| | self.last_demo_step = trainer.global_step |
| |
|
| | demo_reals, _ = next(self.demo_dl) |
| |
|
| | |
| | if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: |
| | demo_reals = demo_reals[0] |
| |
|
| | encoder_input = demo_reals |
| | |
| | encoder_input = encoder_input.to(module.device) |
| |
|
| | demo_reals = demo_reals.to(module.device) |
| |
|
        with torch.no_grad(), torch.cuda.amp.autocast():
| | latents = module.diffae_ema.ema_model.encode(encoder_input).float() |
| | fakes = module.diffae_ema.ema_model.decode(latents, steps=self.demo_steps) |
| |
|
| | |
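        # Interleave reals and fakes so each reconstruction follows its reference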
| | reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n') |
| |
|
| | |
| | reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)') |
| |
|
| | log_dict = {} |
| | |
| | filename = f'recon_{trainer.global_step:08}.wav' |
| | reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu() |
| | torchaudio.save(filename, reals_fakes, self.sample_rate) |
| |
|
| | log_dict[f'recon'] = wandb.Audio(filename, |
| | sample_rate=self.sample_rate, |
| | caption=f'Reconstructed') |
| |
|
| | log_dict[f'embeddings_3dpca'] = pca_point_cloud(latents) |
| | log_dict[f'embeddings_spec'] = wandb.Image(tokens_spectrogram_image(latents)) |
| |
|
| | log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes)) |
| |
|
| | if module.diffae_ema.ema_model.pretransform is not None: |
            with torch.no_grad(), torch.cuda.amp.autocast():
| | initial_latents = module.diffae_ema.ema_model.pretransform.encode(encoder_input) |
| | first_stage_fakes = module.diffae_ema.ema_model.pretransform.decode(initial_latents) |
| | first_stage_fakes = rearrange(first_stage_fakes, 'b d n -> d (b n)') |
| | first_stage_fakes = first_stage_fakes.to(torch.float32).mul(32767).to(torch.int16).cpu() |
| | first_stage_filename = f'first_stage_{trainer.global_step:08}.wav' |
| | torchaudio.save(first_stage_filename, first_stage_fakes, self.sample_rate) |
| |
|
| | log_dict[f'first_stage_latents'] = wandb.Image(tokens_spectrogram_image(initial_latents)) |
| |
|
| | log_dict[f'first_stage'] = wandb.Audio(first_stage_filename, |
| | sample_rate=self.sample_rate, |
| | caption=f'First Stage Reconstructed') |
| | |
| | log_dict[f'first_stage_melspec_left'] = wandb.Image(audio_spectrogram_image(first_stage_fakes)) |
| | |
| |
|
| | trainer.logger.experiment.log(log_dict) |
| |
|
| | def create_source_mixture(reals, num_sources=2): |
| | |
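    # Create synthetic mixtures by summing randomly offset elements of the batch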
| | source = torch.zeros_like(reals) |
| | for i in range(reals.shape[0]): |
| | sources_added = 0 |
| | |
| | js = list(range(reals.shape[0])) |
| | random.shuffle(js) |
| | for j in js: |
| | if i == j or (i != j and sources_added < num_sources): |
                # Randomly offset the mixed element between 0 and the length of the sequence
                seq_len = reals.shape[2]
                offset = random.randint(0, seq_len - 1)
                # Slice with an explicit end index so that offset == 0 keeps the full signal
                # (reals[j, :, :-0] would be an empty slice)
                source[i, :, offset:] += reals[j, :, :seq_len - offset]
                if i == j:
                    # If this is the real source, shift the reals by the same offset to keep them aligned
                    new_reals = torch.zeros_like(reals[i])
                    new_reals[:, offset:] = reals[i, :, :seq_len - offset]
                    reals[i] = new_reals
                sources_added += 1
| |
|
| | return source |
| |
|
| | class DiffusionPriorTrainingWrapper(pl.LightningModule): |
| | ''' |
| | Wrapper for training a diffusion prior for inverse problems |
| | Prior types: |
| | mono_stereo: The prior is conditioned on a mono version of the audio to generate a stereo version |
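    A minimal usage sketch (illustrative only; assumes `model` is a ConditionedDiffusionModelWrapper
    set up for the mono-to-stereo task and `train_dl` yields stereo (audio, metadata) batches):

        wrapper = DiffusionPriorTrainingWrapper(model, lr=1e-4, prior_type=PriorType.MonoToStereo)
        trainer = pl.Trainer(devices=1, max_steps=500000)
        trainer.fit(wrapper, train_dl)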
| | ''' |
| | def __init__( |
| | self, |
| | model: ConditionedDiffusionModelWrapper, |
| | lr: float = 1e-4, |
| | ema_copy = None, |
| | prior_type: PriorType = PriorType.MonoToStereo, |
| | use_reconstruction_loss: bool = False, |
| | log_loss_info: bool = False, |
| | ): |
| | super().__init__() |
| |
|
| | self.diffusion = model |
| | |
| | self.diffusion_ema = EMA( |
| | self.diffusion, |
| | ema_model=ema_copy, |
| | beta=0.9999, |
| | power=3/4, |
| | update_every=1, |
| | update_after_step=1, |
| | include_online_model=False |
| | ) |
| |
|
| | self.lr = lr |
| |
|
| | self.rng = torch.quasirandom.SobolEngine(1, scramble=True) |
| |
|
| | self.log_loss_info = log_loss_info |
| |
|
| | loss_modules = [ |
| | MSELoss("v", |
| | "targets", |
| | weight=1.0, |
| | name="mse_loss" |
| | ) |
| | ] |
| |
|
| | self.use_reconstruction_loss = use_reconstruction_loss |
| |
|
| | if use_reconstruction_loss: |
| | scales = [2048, 1024, 512, 256, 128, 64, 32] |
| | hop_sizes = [] |
| | win_lengths = [] |
| | overlap = 0.75 |
| | for s in scales: |
| | hop_sizes.append(int(s * (1 - overlap))) |
| | win_lengths.append(s) |
| |
|
| | sample_rate = model.sample_rate |
| |
|
| | stft_loss_args = { |
| | "fft_sizes": scales, |
| | "hop_sizes": hop_sizes, |
| | "win_lengths": win_lengths, |
| | "perceptual_weighting": True |
| | } |
| |
|
| | out_channels = model.io_channels |
| |
|
| | self.audio_out_channels = out_channels |
| |
|
| | if model.pretransform is not None: |
| | out_channels = model.pretransform.io_channels |
| |
|
| | if self.audio_out_channels == 2: |
| | self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args) |
| | self.lrstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) |
| |
|
| | |
                loss_modules += [
| | AuralossLoss(self.lrstft, 'audio_reals_left', 'pred_left', name='stft_loss_left', weight=0.05), |
| | AuralossLoss(self.lrstft, 'audio_reals_right', 'pred_right', name='stft_loss_right', weight=0.05), |
| | ] |
| |
|
| | else: |
| | self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) |
| |
|
            loss_modules.append(
| | AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), |
| | ) |
| |
|
| | self.losses = MultiLoss(loss_modules) |
| |
|
| | self.prior_type = prior_type |
| |
|
| | def configure_optimizers(self): |
| | return optim.Adam([*self.diffusion.parameters()], lr=self.lr) |
| |
|
| | def training_step(self, batch, batch_idx): |
| | reals, metadata = batch |
| |
|
| | if reals.ndim == 4 and reals.shape[0] == 1: |
| | reals = reals[0] |
| |
|
| | loss_info = {} |
| |
|
| | loss_info["audio_reals"] = reals |
| |
|
| | if self.prior_type == PriorType.MonoToStereo: |
| | source = reals.mean(dim=1, keepdim=True).repeat(1, reals.shape[1], 1).to(self.device) |
| | loss_info["audio_reals_mono"] = source |
| | else: |
| | raise ValueError(f"Unknown prior type {self.prior_type}") |
| | |
| | if self.diffusion.pretransform is not None: |
| | with torch.no_grad(): |
| | reals = self.diffusion.pretransform.encode(reals) |
| |
|
| | if self.prior_type in [PriorType.MonoToStereo]: |
| | source = self.diffusion.pretransform.encode(source) |
| |
|
| | if self.diffusion.conditioner is not None: |
| | with torch.cuda.amp.autocast(): |
| | conditioning = self.diffusion.conditioner(metadata, self.device) |
| | else: |
| | conditioning = {} |
| |
|
| | loss_info["reals"] = reals |
| |
|
| | |
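        # Draw uniformly distributed continuous timesteps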
| | t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) |
| |
|
| | |
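        # Calculate the noise schedule parameters for those timesteps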
| | alphas, sigmas = get_alphas_sigmas(t) |
| |
|
| | |
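        # Combine the ground truth data and the noise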
| | alphas = alphas[:, None, None] |
| | sigmas = sigmas[:, None, None] |
| | noise = torch.randn_like(reals) |
| | noised_reals = reals * alphas + noise * sigmas |
| | targets = noise * alphas - reals * sigmas |
| |
|
| | with torch.cuda.amp.autocast(): |
| | |
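            # Pass the source audio (e.g. the mono mix) to the model as conditioning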
| | conditioning['source'] = [source] |
| |
|
| | v = self.diffusion(noised_reals, t, cond=conditioning, cfg_dropout_prob = 0.1) |
| | |
| | loss_info.update({ |
| | "v": v, |
| | "targets": targets |
| | }) |
| |
|
| | if self.use_reconstruction_loss: |
| | pred = noised_reals * alphas - v * sigmas |
| |
|
| | loss_info["pred"] = pred |
| |
|
| | if self.diffusion.pretransform is not None: |
| | pred = self.diffusion.pretransform.decode(pred) |
| | loss_info["audio_pred"] = pred |
| |
|
| | if self.audio_out_channels == 2: |
| | loss_info["pred_left"] = pred[:, 0:1, :] |
| | loss_info["pred_right"] = pred[:, 1:2, :] |
| | loss_info["audio_reals_left"] = loss_info["audio_reals"][:, 0:1, :] |
| | loss_info["audio_reals_right"] = loss_info["audio_reals"][:, 1:2, :] |
| |
|
| | loss, losses = self.losses(loss_info) |
| |
|
| | if self.log_loss_info: |
| | |
| | num_loss_buckets = 10 |
| | bucket_size = 1 / num_loss_buckets |
| | loss_all = F.mse_loss(v, targets, reduction="none") |
| |
|
| | sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze() |
| |
|
| | |
| | loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n") |
| |
|
| | |
| | loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)]) |
| |
|
| | |
| | debug_log_dict = { |
| | f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i]) |
| | } |
| |
|
| | self.log_dict(debug_log_dict) |
| |
|
| | log_dict = { |
| | 'train/loss': loss.detach(), |
| | 'train/std_data': reals.std() |
| | } |
| |
|
| | for loss_name, loss_value in losses.items(): |
| | log_dict[f"train/{loss_name}"] = loss_value.detach() |
| |
|
| | self.log_dict(log_dict, prog_bar=True, on_step=True) |
| | return loss |
| | |
| | def on_before_zero_grad(self, *args, **kwargs): |
| | self.diffusion_ema.update() |
| |
|
| | def export_model(self, path, use_safetensors=False): |
| |
|
| | |
| | model = self.diffusion |
| | |
| | if use_safetensors: |
| | save_file(model.state_dict(), path) |
| | else: |
| | torch.save({"state_dict": model.state_dict()}, path) |
| |
|
| | class DiffusionPriorDemoCallback(pl.Callback): |
| | def __init__( |
| | self, |
| | demo_dl, |
| | demo_every=2000, |
| | demo_steps=250, |
| | sample_size=65536, |
| | sample_rate=48000 |
| | ): |
| | super().__init__() |
| | self.demo_every = demo_every |
| | self.demo_steps = demo_steps |
| | self.demo_samples = sample_size |
| | self.demo_dl = iter(demo_dl) |
| | self.sample_rate = sample_rate |
| | self.last_demo_step = -1 |
| |
|
| | @rank_zero_only |
| | @torch.no_grad() |
| | def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx): |
| | if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: |
| | return |
| | |
| | self.last_demo_step = trainer.global_step |
| |
|
| | demo_reals, metadata = next(self.demo_dl) |
| |
|
| | |
| | if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: |
| | demo_reals = demo_reals[0] |
| |
|
| | demo_reals = demo_reals.to(module.device) |
| |
|
| | encoder_input = demo_reals |
| |
|
| | if module.diffusion.conditioner is not None: |
| | with torch.cuda.amp.autocast(): |
| | conditioning_tensors = module.diffusion.conditioner(metadata, module.device) |
| |
|
| | else: |
| | conditioning_tensors = {} |
| |
|
| | |
        with torch.no_grad(), torch.cuda.amp.autocast():
| | if module.prior_type == PriorType.MonoToStereo and encoder_input.shape[1] > 1: |
| | source = encoder_input.mean(dim=1, keepdim=True).repeat(1, encoder_input.shape[1], 1).to(module.device) |
| |
|
| | if module.diffusion.pretransform is not None: |
| | encoder_input = module.diffusion.pretransform.encode(encoder_input) |
| | source_input = module.diffusion.pretransform.encode(source) |
| | else: |
| | source_input = source |
| |
|
| | conditioning_tensors['source'] = [source_input] |
| |
|
| | fakes = sample(module.diffusion_ema.model, torch.randn_like(encoder_input), self.demo_steps, 0, cond=conditioning_tensors) |
| |
|
| | if module.diffusion.pretransform is not None: |
| | fakes = module.diffusion.pretransform.decode(fakes) |
| |
|
| | |
| | reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n') |
| |
|
| | |
| | reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)') |
| |
|
| | log_dict = {} |
| | |
| | filename = f'recon_{trainer.global_step:08}.wav' |
| | reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu() |
| | torchaudio.save(filename, reals_fakes, self.sample_rate) |
| |
|
| | log_dict[f'recon'] = wandb.Audio(filename, |
| | sample_rate=self.sample_rate, |
| | caption=f'Reconstructed') |
| |
|
| | log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes)) |
| |
|
| | |
| | filename = f'source_{trainer.global_step:08}.wav' |
| | source = rearrange(source, 'b d n -> d (b n)') |
| | source = source.to(torch.float32).mul(32767).to(torch.int16).cpu() |
| | torchaudio.save(filename, source, self.sample_rate) |
| |
|
| | log_dict[f'source'] = wandb.Audio(filename, |
| | sample_rate=self.sample_rate, |
| | caption=f'Source') |
| |
|
| | log_dict[f'source_melspec_left'] = wandb.Image(audio_spectrogram_image(source)) |
| |
|
| | trainer.logger.experiment.log(log_dict) |