| | import torch |
| | import torch.optim as optim |
| | import torch.nn.functional as F |
| | from torch.utils.data import DataLoader |
| | from torchvision.utils import make_grid, save_image |
| | from tqdm import tqdm |
| | from ddt_model import LocalSongModel |
| | from transformers import get_cosine_schedule_with_warmup |
| | from datasets import load_from_disk |
| | from accelerate import Accelerator |
| | import os |
| | import argparse |
| | from torch.utils.tensorboard import SummaryWriter |
| | from datetime import datetime |
| | from collections import deque |
| | import torchaudio |
| | import re |
| | import sys |
| | import math |
| | from tag_embedder import TagEmbedder |
| |
|
| | |
| | from acestep.music_dcae.music_dcae_pipeline import MusicDCAE |
| |
|
| | |
# Make the repository root importable so sibling packages (acestep, etc.)
# resolve when this script is launched from its own subdirectory.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import timm.optim

# NOTE(review): duplicate of the `import os` at the top of the file — harmless.
import os

# Silence HF tokenizers fork warnings/deadlocks when DataLoader workers are
# spawned after a tokenizer has been used in the parent process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
| |
|
def save(model, optimizer, scheduler, global_step, accelerator):
    """Save a training checkpoint (main process only) and prune old ones.

    Writes model/optimizer/scheduler state plus the current step to
    ``checkpoints/checkpoint_{global_step}.pth`` and keeps only the 5
    most recent checkpoint files on disk.
    """
    if not accelerator.is_main_process:
        return

    checkpoint_dir = "checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Strip DDP / accelerate wrappers so the state dict keys are canonical.
    unwrapped_model = accelerator.unwrap_model(model)

    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_{global_step}.pth")
    save_dict = {
        'model_state_dict': unwrapped_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Fix: scheduler state was previously dropped entirely, so resumed
        # runs restarted the LR schedule from step 0.
        'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
        'global_step': global_step
    }

    accelerator.save(save_dict, checkpoint_path)
    print(f"Checkpoint saved at step {global_step}: {checkpoint_path}")

    # Keep only the 5 most recent checkpoints, newest first by step number.
    checkpoints = sorted(
        [f for f in os.listdir(checkpoint_dir) if f.startswith("checkpoint_") and f.endswith(".pth")],
        key=lambda x: int(x.split("_")[1].split(".")[0]),
        reverse=True,
    )
    for old_checkpoint in checkpoints[5:]:
        os.remove(os.path.join(checkpoint_dir, old_checkpoint))
        print(f"Removed old checkpoint: {old_checkpoint}")
| |
|
| |
|
def load_checkpoint(model, optimizer, scheduler, checkpoint_path, accelerator):
    """Restore model/optimizer/scheduler state from ``checkpoint_path``.

    Returns the global step stored in the checkpoint so training can resume
    from where it left off.
    """
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

    unwrapped_model = accelerator.unwrap_model(model)
    # Strip the torch.compile wrapper prefix so checkpoints written from a
    # compiled model load into an uncompiled one.
    state_dict = {k.replace("_orig_mod.", ""): v for k, v in checkpoint['model_state_dict'].items()}
    missing, unexpected = unwrapped_model.load_state_dict(state_dict, strict=True)
    print("MISSING:", missing)
    print("UNEXPECTED:", unexpected)

    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print("Optimizer loaded")

    # Fix: scheduler state was previously never restored, so the LR schedule
    # restarted from scratch on every resume. Older checkpoints may lack the
    # key, hence the .get() guard.
    if scheduler is not None and checkpoint.get('scheduler_state_dict') is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        print("Scheduler loaded")

    global_step = checkpoint['global_step']
    print(f"Resumed from step {global_step}")
    return global_step
| |
|
def resume(model, optimizer, scheduler, accelerator):
    """Find the newest checkpoint under ./checkpoints and restore from it.

    Returns the restored global step, or 0 when nothing can be resumed.
    """
    checkpoint_dir = "checkpoints"

    def step_of(fname):
        # checkpoint_<step>.pth -> <step>
        return int(fname.split("_")[1].split(".")[0])

    if not os.path.exists(checkpoint_dir):
        if accelerator.is_main_process:
            print("Checkpoint directory not found. Starting from scratch.")
        return 0

    candidates = [
        name for name in os.listdir(checkpoint_dir)
        if name.startswith("checkpoint_") and name.endswith(".pth")
    ]
    if not candidates:
        if accelerator.is_main_process:
            print("No checkpoints found. Starting from scratch.")
        return 0

    newest = max(candidates, key=step_of)
    checkpoint_path = os.path.join(checkpoint_dir, newest)
    if accelerator.is_main_process:
        print(f"Resuming from checkpoint: {checkpoint_path}")

    return load_checkpoint(model, optimizer, scheduler, checkpoint_path, accelerator)
| |
|
class AudioVAE:
    """Thin wrapper around MusicDCAE mapping audio <-> whitened latents."""

    # Per-channel latent statistics; presumably measured over the training
    # corpus — must stay in sync with the constants used in main().
    _MEAN = [0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526]
    _STD = [0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707]

    def __init__(self, device):
        self.device = device
        self.model = MusicDCAE().to(device)
        self.model.eval()
        self.latent_mean = torch.tensor(self._MEAN, device=device).view(1, -1, 1, 1)
        self.latent_std = torch.tensor(self._STD, device=device).view(1, -1, 1, 1)

    def encode(self, audio):
        """Encode an audio batch into normalized latents.

        Assumes audio is (batch, channels, samples) at 48 kHz — TODO confirm
        against MusicDCAE.encode.
        """
        with torch.no_grad():
            batch, length = audio.shape[0], audio.shape[2]
            audio_lengths = torch.tensor([length] * batch).to(self.device)
            latents, _ = self.model.encode(audio, audio_lengths, sr=48000)
        # Whiten per channel so the diffusion model sees unit-scale inputs.
        return (latents - self.latent_mean) / self.latent_std

    def decode(self, latents):
        """Decode normalized latents back into a stacked audio batch on self.device."""
        with torch.no_grad():
            # Undo the whitening before handing latents to the decoder.
            raw = latents * self.latent_std + self.latent_mean
            sr, audio_list = self.model.decode(raw, sr=48000)
        return torch.stack(audio_list).to(self.device)
| |
|
class RF:
    """Rectified-flow training and sampling wrapper.

    The wrapped model is expected to predict the clean sample x0 from
    (zt, t, cond); velocities are derived as (zt - x_pred) / t.
    """

    def __init__(self, model, time_sampling="sigmoid"):
        self.model = model
        self.time_sampling = time_sampling

    def sample_timesteps(self, batch, device):
        """Sample a batch of training timesteps in (0, 1) per the configured strategy."""
        if self.time_sampling == "sigmoid":
            # Logit-normal: concentrates mass around t = 0.5.
            return torch.sigmoid(torch.randn((batch,), device=device))
        elif self.time_sampling == "warped":
            # Resolution-dependent timestep shift: biases sampling toward
            # higher noise levels for larger latents (alpha >= 1).
            pm = 128 * 16 * 16
            alpha = max(1.0, math.sqrt(pm / 4096.0))
            u = torch.rand(batch, device=device)
            return alpha * u / (1.0 + (alpha - 1.0) * u)
        elif self.time_sampling == "uniform":
            return torch.rand(batch, device=device)
        else:
            raise ValueError(f"Unknown time_sampling strategy: {self.time_sampling}")

    def forward(self, x, cond):
        """Compute the velocity-matching MSE loss for a clean batch x."""
        b = x.size(0)
        t = self.sample_timesteps(b, x.device)

        # Broadcast t over all non-batch dimensions.
        texp = t.view([b, *([1] * len(x.shape[1:]))])
        z1 = torch.randn_like(x)
        # Linear interpolation between data (t=0) and pure noise (t=1).
        zt = (1 - texp) * x + texp * z1

        x_pred = self.model(zt, t, cond)

        # +0.05 keeps the division numerically stable as t -> 0.
        target = (zt - x) / (texp + 0.05)
        v_pred = (zt - x_pred) / (texp + 0.05)
        # Note: (input, target) order; value is identical either way for MSE.
        return F.mse_loss(v_pred, target)

    def get_sampling_timesteps(self, steps, device):
        """Return a descending schedule of `steps` timesteps from 1.0, excluding 0."""
        if self.time_sampling in ("uniform", "sigmoid"):
            return torch.linspace(1.0, 0.0, steps + 1, device=device)[:-1]
        elif self.time_sampling == "warped":
            # Apply the same warp used at train time to the uniform grid.
            pm = 128 * 16 * 16
            alpha = max(1.0, math.sqrt(pm / 4096.0))
            u = torch.linspace(1.0, 0.0, steps + 1, device=device)[:-1]
            return alpha * u / (1.0 + (alpha - 1.0) * u)
        else:
            raise ValueError(f"Unknown time_sampling strategy: {self.time_sampling}")

    def sample(self, z, cond, null_cond=None, sample_steps=100, cfg=3.0):
        """Integrate the flow ODE from noise z at t=1 down to t=0 with Euler steps.

        Returns the list of intermediate latents, starting with z and ending
        with the final sample (length sample_steps + 1). Classifier-free
        guidance with scale `cfg` is applied when null_cond is provided.
        """
        b = z.size(0)
        device = z.device

        timesteps = self.get_sampling_timesteps(sample_steps, device)
        images = [z]

        for idx in range(sample_steps):
            t_curr = timesteps[idx]
            # Step to the next scheduled timestep, or straight to 0 at the end.
            t_next = timesteps[idx + 1] if idx + 1 < sample_steps else torch.tensor(0.0, device=device)
            dt = t_curr - t_next
            t = t_curr.expand(b)

            # Convert the model's x0 prediction into a velocity estimate.
            vc = self.model(z, t, cond)
            vc = (z - vc) / t_curr
            if null_cond is not None:
                vu = self.model(z, t, null_cond)
                vu = (z - vu) / t_curr
                vc = vu + cfg * (vc - vu)

            # Euler update toward t=0.
            z = z - dt * vc
            images.append(z)
        return images
| |
|
def save_audio_samples(audio_batch, sample_rate, filename):
    """Write the first clip of a batch to audio_samples/<filename> via torchaudio."""
    out_dir = "audio_samples"
    os.makedirs(out_dir, exist_ok=True)

    # Only the first item of the batch is persisted; move it off the GPU first.
    clip = audio_batch[0].cpu()

    destination = os.path.join(out_dir, filename)
    torchaudio.save(destination, clip, sample_rate)
    print(f"Saved audio sample: {destination}")
| |
|
def _str2bool(value):
    """Parse a command-line boolean: true/false, yes/no, 1/0 (case-insensitive)."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "yes", "1"):
        return True
    if lowered in ("false", "no", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got: {value!r}")


def parse_args():
    """Build and parse the training CLI arguments."""
    parser = argparse.ArgumentParser(description='Audio training script with TensorBoard logging')

    # Model / data geometry
    parser.add_argument('--channels', type=int, default=8, help='Number of input channels in the audio latents')
    parser.add_argument('--audio_height', type=int, default=16, help='Height of audio latents')
    parser.add_argument('--max_audio_width', type=int, default=4096, help='Max width of audio latents')
    parser.add_argument('--subsection_length', type=int, default=256, help='Length of random subsection to sample from each audio latent')
    parser.add_argument('--n_layers', type=int, default=36, help='Number of layers in the model')
    parser.add_argument('--n_encoder_layers', type=int, default=36, help='Number of encoder layers in the model')
    parser.add_argument('--n_heads', type=int, default=16, help='Number of heads in the model')
    parser.add_argument('--dim', type=int, default=768, help='Dimension of the encoder')
    parser.add_argument('--decoder_dim', type=int, default=1536, help='Dimension of the decoder (if None, uses --dim)')
    parser.add_argument('--dataset_name', type=str, default="cache", help='Audio dataset name')
    parser.add_argument('--num_workers', type=int, default=16, help='Number of workers for dataloader')

    # Optimization
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=1000, help='Number of epochs to train')
    parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
    parser.add_argument('--warmup_steps', type=int, default=0, help='Number of warmup steps')

    # Logging / checkpointing / sampling
    parser.add_argument('--sample_every', type=int, default=500, help='Audio sampling interval (batches)')
    parser.add_argument('--save_every', type=int, default=1000, help='Model saving interval (batches)')
    parser.add_argument('--num_samples', type=int, default=16, help='Number of samples to generate')
    # Fix: `type=bool` made any non-empty string truthy ("--resume False" was
    # still True); _str2bool parses booleans strictly. Default is unchanged.
    parser.add_argument('--resume', type=_str2bool, default=True, help='Resume training from checkpoint')
    parser.add_argument('--pad_to_length', action='store_true', help='Pad short samples to subsection_length instead of filtering them out')
    parser.add_argument('--time_sampling', type=str, default='warped', choices=['sigmoid', 'warped', 'uniform'], help='Timestep sampling strategy')

    return parser.parse_args()
| |
|
def main():
    """Train the LocalSongModel with rectified flow on cached audio latents.

    Handles dataset loading/filtering, collation with random subsection
    cropping, checkpoint resume, periodic audio sampling, and TensorBoard
    logging. Designed to run under HuggingFace Accelerate (multi-process safe).
    """
    args = parse_args()

    # bf16 mixed precision only when CUDA hardware is present.
    accelerator = Accelerator(mixed_precision="bf16" if torch.cuda.is_available() else "no")

    is_main_process = accelerator.is_main_process

    # TensorBoard writer only on the main process.
    writer = None
    if is_main_process:
        run_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        writer = SummaryWriter(log_dir=f"runs/{run_datetime}")

    dataset = load_from_disk(args.dataset_name).with_format(type="torch")

    if not args.pad_to_length:
        # Require at least 2x the crop length so random crops have variety.
        def filter_by_length(example):
            latent_width = example['latents'].shape[-1]
            return latent_width >= args.subsection_length * 2

        dataset = dataset.filter(filter_by_length)

        if is_main_process:
            print(f"Dataset filtered to {len(dataset)} samples with width >= {args.subsection_length * 2}")
    else:
        if is_main_process:
            print(f"Padding enabled: short samples will be zero-padded to {args.subsection_length}")

    # Per-channel latent whitening statistics; must stay in sync with AudioVAE.
    latent_mean = torch.tensor([0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526]).view(1, -1, 1, 1)
    latent_std = torch.tensor([0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707]).view(1, -1, 1, 1)

    num_classes = 2304
    tag_embedder = TagEmbedder(num_classes=num_classes)

    def collate_fn(batch):
        """Crop/pad each latent to subsection_length, normalize, attach tags."""
        subsection_length = args.subsection_length

        sampled_latents = []
        tags = []

        for item in batch:
            latent = item['latents']
            if len(latent.shape) == 3:
                latent = latent.unsqueeze(0)

            _, _, _, width = latent.shape

            if width < subsection_length:
                # Zero-pad short clips on the right. BUGFIX: this used to be
                # gated on a hard-coded pad_to_length=False flag, which left
                # sampled_latent undefined (NameError) when --pad_to_length
                # was enabled. Unreachable when length filtering is active.
                pad_amount = subsection_length - width
                sampled_latent = torch.nn.functional.pad(latent, (0, pad_amount), mode='constant', value=0)
            else:
                # Random temporal crop of subsection_length frames.
                max_start = width - subsection_length
                start_idx = torch.randint(0, max_start + 1, (1,)).item()
                sampled_latent = latent[:, :, :, start_idx:start_idx + subsection_length]

            sampled_latents.append(sampled_latent.squeeze(0))

            # Conditioning tags are looked up from album/song metadata.
            sample_tags = tag_embedder.get_tags(item['album_name'], item['song_name'])
            tags.append(sample_tags)

        stacked_latents = torch.stack(sampled_latents)
        normalized_latents = (stacked_latents - latent_mean) / latent_std

        return {
            'latents': normalized_latents,
            'tags': tags
        }

    num_workers = args.num_workers if torch.cuda.is_available() else 0
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers,
        # BUGFIX: persistent_workers=True with num_workers=0 (the CPU path)
        # raises ValueError inside DataLoader.
        persistent_workers=num_workers > 0,
        pin_memory=True,
        collate_fn=collate_fn
    )

    channels = args.channels

    model = LocalSongModel(
        in_channels=channels,
        num_groups=args.n_heads,
        hidden_size=args.dim,
        decoder_hidden_size=args.decoder_dim,
        num_blocks=args.n_layers,
        patch_size=(16, 1),
        num_classes=num_classes,
        max_tags=8,
    )

    vae = AudioVAE(accelerator.device)

    rf = RF(model, time_sampling=args.time_sampling)

    optimizer = timm.optim.Muon(model.parameters(), lr=args.lr)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=args.epochs * len(dataloader),
    )

    global_step = 0
    if args.resume:
        global_step = resume(model, optimizer, scheduler, accelerator)

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        model.forward_emb = torch.compile(model.forward_emb)

    model, optimizer, scheduler, dataloader = accelerator.prepare(
        model, optimizer, scheduler, dataloader
    )

    # Point the flow wrapper at the prepared (possibly wrapped) model.
    rf.model = model

    if is_main_process:
        model_size = sum(p.numel() for p in accelerator.unwrap_model(model).parameters() if p.requires_grad)
        print(f"Number of parameters: {model_size}, {model_size / 1e6}M")

    os.makedirs("audio_samples", exist_ok=True)
    num_samples = args.num_samples

    # Fixed noise/conditioning reused at every sampling checkpoint so that
    # generated audio is comparable across training.
    fixed_latents = None
    fixed_tags = None
    fixed_noise = None

    if is_main_process:
        data_iter = iter(dataloader)
        fixed_batch = next(data_iter)
        fixed_latents = fixed_batch["latents"][:num_samples]

        # BUGFIX: the collated batch only exposes 'latents' and 'tags' — the
        # old code indexed 'album_names' (KeyError) and left fixed_tags as an
        # empty list, so periodic sampling conditioned on nothing.
        fixed_tags = fixed_batch["tags"][:num_samples]

        # Invert the tag mapping for human-readable logging.
        idx_to_tag = {v: k for k, v in tag_embedder.tag_mapping.items()}

        print("Fixed tag labels:")
        for i, tag_list in enumerate(fixed_tags):
            labels = [idx_to_tag.get(idx, f"<unknown:{idx}>") for idx in tag_list]
            print(f" Sample {i}: {labels}")

        _, C, H, _ = fixed_latents.shape
        fixed_noise = torch.randn(num_samples, C, H, args.subsection_length, device=accelerator.device)

        fixed_latents = fixed_latents.to(accelerator.device)

    if is_main_process:
        print("Begin training")

    mse_loss_window = deque(maxlen=100)
    start_epoch = 0
    for epoch in range(start_epoch, args.epochs):
        pbar = tqdm(dataloader) if is_main_process else dataloader
        for batch in pbar:
            x = batch["latents"]
            tags = batch["tags"]

            # Classifier-free guidance training: drop conditioning on ~10% of
            # samples so the model also learns the unconditional flow.
            dropout_tags = []
            for tag_list in tags:
                if torch.rand(1).item() < 0.1:
                    dropout_tags.append([])
                else:
                    dropout_tags.append(tag_list)
            c = dropout_tags

            with accelerator.accumulate(model):
                optimizer.zero_grad()
                mse_loss = rf.forward(x, c)
                loss = mse_loss
                accelerator.backward(loss)
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()

            if is_main_process:
                # Smooth the reported loss over the last 100 steps.
                mse_loss_window.append(mse_loss.item())
                avg_mse_loss = sum(mse_loss_window) / len(mse_loss_window)

                if isinstance(pbar, tqdm):
                    pbar.set_postfix({"mse_loss": avg_mse_loss, "lr": optimizer.param_groups[0]['lr']})

                if writer is not None:
                    writer.add_scalar('Learning_Rate', optimizer.param_groups[0]['lr'], global_step)
                    writer.add_scalar('MSE_Loss', avg_mse_loss, global_step)

            global_step += 1

            if is_main_process and global_step % args.save_every == 0:
                save(model, optimizer, scheduler, global_step, accelerator)

            if is_main_process and global_step % args.sample_every == 0:
                model.eval()

                with torch.no_grad():
                    cond = fixed_tags
                    # Empty tag lists act as the unconditional (null) input for CFG.
                    null_cond = [[] for _ in range(len(cond))]

                    sampled_latents = rf.sample(fixed_noise, cond, null_cond)[-1]

                    try:
                        sampled_audio = vae.decode(sampled_latents)

                        for i in range(min(8, sampled_audio.shape[0])):
                            save_audio_samples(
                                sampled_audio[i:i+1],
                                48000,
                                f"sample_{global_step}_generated_{i}.wav"
                            )

                        # Save ground-truth references once, on the first sampling pass.
                        if global_step == args.sample_every:
                            original_audio = vae.decode(fixed_latents)
                            for i in range(min(8, original_audio.shape[0])):
                                save_audio_samples(
                                    original_audio[i:i+1],
                                    48000,
                                    f"sample_{global_step}_original_{i}.wav"
                                )
                    except Exception as e:
                        # Best effort: sampling failures must not kill training.
                        print(f"Error during audio generation: {e}")

                model.train()

    print("Saving final model")
    save(model, optimizer, scheduler, global_step, accelerator)

    if writer is not None:
        writer.close()
| |
|
# Standard script entry point; the guard keeps DataLoader worker processes
# from re-running training when they re-import this module.
if __name__ == '__main__':
    main()
| |
|