"""
PyTorch Lightning DataModule for LoRA Training

Handles data loading and preprocessing for training ACE-Step LoRA adapters.
Supports both raw audio loading and preprocessed tensor loading.
"""

import os
import json
import random
from typing import Optional, List, Dict, Any, Tuple
from loguru import logger

import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader

try:
    from lightning.pytorch import LightningDataModule
    LIGHTNING_AVAILABLE = True
except ImportError:
    LIGHTNING_AVAILABLE = False
    logger.warning("Lightning not installed. Training module will not be available.")

    # Stub base class so the DataModule definitions below still resolve.
    class LightningDataModule:
        pass


class PreprocessedTensorDataset(Dataset):
    """Dataset that loads preprocessed tensor files.

    This is the recommended dataset for training, since every tensor is
    pre-computed:
    - target_latents: VAE-encoded audio [T, 64]
    - encoder_hidden_states: condition encoder output [L, D]
    - encoder_attention_mask: condition mask [L]
    - context_latents: source context [T, 65]
    - attention_mask: audio latent mask [T]

    No VAE, text encoder, or condition encoder is needed during training;
    the tensors are loaded directly from disk.
    """

    def __init__(self, tensor_dir: str):
        """Initialize from a directory of preprocessed .pt files.

        Args:
            tensor_dir: Directory containing preprocessed .pt files and an
                optional manifest.json listing the sample paths
        """
        self.tensor_dir = tensor_dir
        self.sample_paths = []

        # Prefer the manifest if present; otherwise scan the directory
        # (sorted for a deterministic sample order).
        manifest_path = os.path.join(tensor_dir, "manifest.json")
        if os.path.exists(manifest_path):
            with open(manifest_path, 'r') as f:
                manifest = json.load(f)
            self.sample_paths = manifest.get("samples", [])
        else:
            for fname in sorted(os.listdir(tensor_dir)):
                if fname.endswith('.pt'):
                    self.sample_paths.append(os.path.join(tensor_dir, fname))

        # Drop entries whose files have gone missing.
        self.valid_paths = [p for p in self.sample_paths if os.path.exists(p)]

        if len(self.valid_paths) != len(self.sample_paths):
            missing = len(self.sample_paths) - len(self.valid_paths)
            logger.warning(f"{missing} tensor file(s) listed but not found on disk")

        logger.info(f"PreprocessedTensorDataset: {len(self.valid_paths)} samples from {tensor_dir}")

    def __len__(self) -> int:
        return len(self.valid_paths)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Load a preprocessed tensor file.

        Returns:
            Dictionary containing all pre-computed tensors for training
        """
        tensor_path = self.valid_paths[idx]
        data = torch.load(tensor_path, map_location='cpu')

        return {
            "target_latents": data["target_latents"],
            "attention_mask": data["attention_mask"],
            "encoder_hidden_states": data["encoder_hidden_states"],
            "encoder_attention_mask": data["encoder_attention_mask"],
            "context_latents": data["context_latents"],
            "metadata": data.get("metadata", {}),
        }


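# Illustrative sketch (not part of the training pipeline): what a preprocessed
# sample on disk is expected to look like, based on the keys read above. The
# lengths T, L and the hidden size D below are placeholders; the real values
# come from the preprocessing script.
def _write_demo_preprocessed_sample(tensor_dir: str) -> None:
    """Write one synthetic .pt sample plus a manifest.json (for testing only)."""
    os.makedirs(tensor_dir, exist_ok=True)
    T, L, D = 512, 256, 2048  # illustrative sizes, not canonical
    sample = {
        "target_latents": torch.randn(T, 64),
        "attention_mask": torch.ones(T),
        "encoder_hidden_states": torch.randn(L, D),
        "encoder_attention_mask": torch.ones(L),
        "context_latents": torch.randn(T, 65),
        "metadata": {"caption": "demo"},
    }
    path = os.path.join(tensor_dir, "sample_000.pt")
    torch.save(sample, path)
    with open(os.path.join(tensor_dir, "manifest.json"), "w") as f:
        json.dump({"samples": [path]}, f)

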
def collate_preprocessed_batch(batch: List[Dict]) -> Dict[str, Any]:
    """Collate function for preprocessed tensor batches.

    Handles variable-length tensors by zero-padding each one to the longest
    length in the batch.

    Args:
        batch: List of sample dictionaries with pre-computed tensors

    Returns:
        Batched dictionary with all tensors stacked along a new batch dim
    """
    max_latent_len = max(s["target_latents"].shape[0] for s in batch)
    max_encoder_len = max(s["encoder_hidden_states"].shape[0] for s in batch)

    target_latents = []
    attention_masks = []
    encoder_hidden_states = []
    encoder_attention_masks = []
    context_latents = []

    for sample in batch:
        # Pad target latents [T, 64] to [max_latent_len, 64]. new_zeros keeps
        # the padding on the same dtype/device as the source tensor.
        tl = sample["target_latents"]
        if tl.shape[0] < max_latent_len:
            pad = tl.new_zeros(max_latent_len - tl.shape[0], tl.shape[1])
            tl = torch.cat([tl, pad], dim=0)
        target_latents.append(tl)

        # Pad the audio latent mask [T]; padded positions stay 0 (masked out).
        am = sample["attention_mask"]
        if am.shape[0] < max_latent_len:
            pad = am.new_zeros(max_latent_len - am.shape[0])
            am = torch.cat([am, pad], dim=0)
        attention_masks.append(am)

        # Pad source context latents [T, 65] to the same latent length.
        cl = sample["context_latents"]
        if cl.shape[0] < max_latent_len:
            pad = cl.new_zeros(max_latent_len - cl.shape[0], cl.shape[1])
            cl = torch.cat([cl, pad], dim=0)
        context_latents.append(cl)

        # Pad condition encoder outputs [L, D] to the longest condition length.
        ehs = sample["encoder_hidden_states"]
        if ehs.shape[0] < max_encoder_len:
            pad = ehs.new_zeros(max_encoder_len - ehs.shape[0], ehs.shape[1])
            ehs = torch.cat([ehs, pad], dim=0)
        encoder_hidden_states.append(ehs)

        # Pad the condition mask [L] to match.
        eam = sample["encoder_attention_mask"]
        if eam.shape[0] < max_encoder_len:
            pad = eam.new_zeros(max_encoder_len - eam.shape[0])
            eam = torch.cat([eam, pad], dim=0)
        encoder_attention_masks.append(eam)

    return {
        "target_latents": torch.stack(target_latents),
        "attention_mask": torch.stack(attention_masks),
        "encoder_hidden_states": torch.stack(encoder_hidden_states),
        "encoder_attention_mask": torch.stack(encoder_attention_masks),
        "context_latents": torch.stack(context_latents),
        "metadata": [s["metadata"] for s in batch],
    }


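# Quick sanity check of the collate behavior (illustrative only): two samples
# with different lengths come out stacked and zero-padded, with the masks
# marking the real (unpadded) positions. D=8 is arbitrary for the demo.
def _demo_collate_preprocessed() -> None:
    def make(t_len: int, l_len: int) -> Dict[str, torch.Tensor]:
        return {
            "target_latents": torch.randn(t_len, 64),
            "attention_mask": torch.ones(t_len),
            "encoder_hidden_states": torch.randn(l_len, 8),
            "encoder_attention_mask": torch.ones(l_len),
            "context_latents": torch.randn(t_len, 65),
            "metadata": {},
        }

    batch = collate_preprocessed_batch([make(100, 20), make(80, 32)])
    assert batch["target_latents"].shape == (2, 100, 64)
    assert batch["encoder_hidden_states"].shape == (2, 32, 8)
    assert batch["attention_mask"][1, 80:].sum() == 0  # padding is masked out

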
class PreprocessedDataModule(LightningDataModule if LIGHTNING_AVAILABLE else object):
    """DataModule for preprocessed tensor files.

    This is the recommended DataModule for training. It loads pre-computed tensors
    directly without needing VAE, text encoder, or condition encoder at training time.
    """

    def __init__(
        self,
        tensor_dir: str,
        batch_size: int = 1,
        num_workers: int = 4,
        pin_memory: bool = True,
        val_split: float = 0.0,
    ):
        """Initialize the data module.

        Args:
            tensor_dir: Directory containing preprocessed .pt files
            batch_size: Training batch size
            num_workers: Number of data loading workers
            pin_memory: Whether to pin memory for faster GPU transfer
            val_split: Fraction of data for validation (0 = no validation)
        """
        if LIGHTNING_AVAILABLE:
            super().__init__()

        self.tensor_dir = tensor_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.val_split = val_split

        self.train_dataset = None
        self.val_dataset = None

    def setup(self, stage: Optional[str] = None):
        """Set up the train/val datasets."""
        if stage == 'fit' or stage is None:
            full_dataset = PreprocessedTensorDataset(self.tensor_dir)

            if self.val_split > 0 and len(full_dataset) > 1:
                n_val = max(1, int(len(full_dataset) * self.val_split))
                n_train = len(full_dataset) - n_val

                self.train_dataset, self.val_dataset = torch.utils.data.random_split(
                    full_dataset, [n_train, n_val]
                )
            else:
                self.train_dataset = full_dataset
                self.val_dataset = None

    def train_dataloader(self) -> DataLoader:
        """Create training dataloader."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=collate_preprocessed_batch,
            drop_last=True,
        )

    def val_dataloader(self) -> Optional[DataLoader]:
        """Create validation dataloader."""
        if self.val_dataset is None:
            return None

        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=collate_preprocessed_batch,
        )


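# How the module is typically exercised (sketch): build it, call setup() as a
# Lightning Trainer would, and pull one collated batch. With Lightning
# installed, `trainer.fit(model, datamodule=dm)` drives setup() and the
# dataloaders automatically; the LightningModule itself is project-specific
# and not defined in this file.
def _demo_preprocessed_datamodule(tensor_dir: str) -> None:
    dm = PreprocessedDataModule(tensor_dir, batch_size=1, num_workers=0)
    dm.setup("fit")
    batch = next(iter(dm.train_dataloader()))
    logger.info(f"batch latents: {tuple(batch['target_latents'].shape)}")

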
class AceStepTrainingDataset(Dataset):
    """Dataset for ACE-Step LoRA training from raw audio.

    DEPRECATED: Use PreprocessedTensorDataset instead for better performance.

    Audio format requirements (handled automatically):
    - Sample rate: 48 kHz (resampled if different)
    - Channels: stereo (mono is duplicated, extra channels are dropped)
    - Max duration: 240 seconds (4 minutes, excess is truncated)
    - Min duration: 5 seconds (padded with silence if shorter)
    """

    def __init__(
        self,
        samples: List[Dict[str, Any]],
        dit_handler,
        max_duration: float = 240.0,
        target_sample_rate: int = 48000,
    ):
        """Initialize the dataset."""
        self.samples = samples
        self.dit_handler = dit_handler
        self.max_duration = max_duration
        self.target_sample_rate = target_sample_rate

        self.valid_samples = self._validate_samples()
        logger.info(f"Dataset initialized with {len(self.valid_samples)} valid samples")

    def _validate_samples(self) -> List[Dict[str, Any]]:
        """Validate and filter samples, dropping any without audio or caption."""
        valid = []
        for i, sample in enumerate(self.samples):
            audio_path = sample.get("audio_path", "")
            if not audio_path or not os.path.exists(audio_path):
                logger.warning(f"Sample {i}: Audio file not found: {audio_path}")
                continue

            if not sample.get("caption"):
                logger.warning(f"Sample {i}: Missing caption")
                continue

            valid.append(sample)

        return valid

    def __len__(self) -> int:
        return len(self.valid_samples)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Get a single training sample."""
        sample = self.valid_samples[idx]

        audio_path = sample["audio_path"]
        audio, sr = torchaudio.load(audio_path)

        # Resample to the target rate if needed.
        if sr != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate)
            audio = resampler(audio)

        # Force stereo: duplicate mono, drop channels beyond the first two.
        if audio.shape[0] == 1:
            audio = audio.repeat(2, 1)
        elif audio.shape[0] > 2:
            audio = audio[:2, :]

        # Truncate to the maximum duration.
        max_samples = int(self.max_duration * self.target_sample_rate)
        if audio.shape[1] > max_samples:
            audio = audio[:, :max_samples]

        # Pad with silence up to the 5-second minimum duration.
        min_samples = int(5.0 * self.target_sample_rate)
        if audio.shape[1] < min_samples:
            padding = min_samples - audio.shape[1]
            audio = torch.nn.functional.pad(audio, (0, padding))

        return {
            "audio": audio,
            "caption": sample.get("caption", ""),
            "lyrics": sample.get("lyrics", "[Instrumental]"),
            "metadata": {
                "caption": sample.get("caption", ""),
                "lyrics": sample.get("lyrics", "[Instrumental]"),
                "bpm": sample.get("bpm"),
                "keyscale": sample.get("keyscale", ""),
                "timesignature": sample.get("timesignature", ""),
                "duration": sample.get("duration", audio.shape[1] / self.target_sample_rate),
                "language": sample.get("language", "instrumental"),
                "is_instrumental": sample.get("is_instrumental", True),
            },
            "audio_path": audio_path,
        }


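# Smoke test for the raw-audio path (illustrative only): write a short mono
# sine tone, then confirm the dataset resamples, stereo-izes, and pads it to
# the 5-second minimum. `dit_handler` is never touched in __getitem__, so
# None suffices here.
def _demo_raw_audio_dataset() -> None:
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "tone.wav")
        t = torch.arange(44100 * 2) / 44100.0  # 2 s of mono audio at 44.1 kHz
        torchaudio.save(path, torch.sin(2 * torch.pi * 440 * t).unsqueeze(0), 44100)

        ds = AceStepTrainingDataset(
            [{"audio_path": path, "caption": "440 Hz test tone"}], dit_handler=None
        )
        item = ds[0]
        assert item["audio"].shape == (2, 5 * 48000)  # stereo, padded to 5 s

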
def collate_training_batch(batch: List[Dict]) -> Dict[str, Any]:
    """Collate function for raw audio batches (legacy)."""
    max_len = max(sample["audio"].shape[1] for sample in batch)

    padded_audio = []
    attention_masks = []

    for sample in batch:
        audio = sample["audio"]
        audio_len = audio.shape[1]

        if audio_len < max_len:
            padding = max_len - audio_len
            audio = torch.nn.functional.pad(audio, (0, padding))

        padded_audio.append(audio)

        mask = torch.ones(max_len)
        if audio_len < max_len:
            mask[audio_len:] = 0
        attention_masks.append(mask)

    return {
        "audio": torch.stack(padded_audio),
        "attention_mask": torch.stack(attention_masks),
        "captions": [s["caption"] for s in batch],
        "lyrics": [s["lyrics"] for s in batch],
        "metadata": [s["metadata"] for s in batch],
        "audio_paths": [s["audio_path"] for s in batch],
    }


class AceStepDataModule(LightningDataModule if LIGHTNING_AVAILABLE else object):
    """DataModule for raw audio loading (legacy).

    DEPRECATED: Use PreprocessedDataModule for better training performance.
    """

    def __init__(
        self,
        samples: List[Dict[str, Any]],
        dit_handler,
        batch_size: int = 1,
        num_workers: int = 4,
        pin_memory: bool = True,
        max_duration: float = 240.0,
        val_split: float = 0.0,
    ):
        if LIGHTNING_AVAILABLE:
            super().__init__()

        self.samples = samples
        self.dit_handler = dit_handler
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.max_duration = max_duration
        self.val_split = val_split

        self.train_dataset = None
        self.val_dataset = None

    def setup(self, stage: Optional[str] = None):
        if stage == 'fit' or stage is None:
            if self.val_split > 0 and len(self.samples) > 1:
                n_val = max(1, int(len(self.samples) * self.val_split))

                indices = list(range(len(self.samples)))
                random.shuffle(indices)

                val_indices = indices[:n_val]
                train_indices = indices[n_val:]

                train_samples = [self.samples[i] for i in train_indices]
                val_samples = [self.samples[i] for i in val_indices]

                self.train_dataset = AceStepTrainingDataset(
                    train_samples, self.dit_handler, self.max_duration
                )
                self.val_dataset = AceStepTrainingDataset(
                    val_samples, self.dit_handler, self.max_duration
                )
            else:
                self.train_dataset = AceStepTrainingDataset(
                    self.samples, self.dit_handler, self.max_duration
                )
                self.val_dataset = None

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=collate_training_batch,
            drop_last=True,
        )

    def val_dataloader(self) -> Optional[DataLoader]:
        if self.val_dataset is None:
            return None

        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=collate_training_batch,
        )


def load_dataset_from_json(json_path: str) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """Load a dataset description from a JSON file.

    Returns:
        A (samples, metadata) tuple; either may be empty if the corresponding
        key is missing from the file.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    metadata = data.get("metadata", {})
    samples = data.get("samples", [])

    return samples, metadata
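
# Round-trip sketch of the JSON layout load_dataset_from_json expects,
# inferred from the keys the datasets above read. Only "audio_path" and
# "caption" are required per sample; the rest fall back to defaults in
# AceStepTrainingDataset. All field values here are illustrative.
def _demo_dataset_json(json_path: str) -> None:
    dataset = {
        "metadata": {"name": "my_lora_set"},  # free-form; not interpreted here
        "samples": [
            {
                "audio_path": "/data/audio/track_01.wav",
                "caption": "warm lo-fi piano loop",
                "lyrics": "[Instrumental]",
                "bpm": 84,
                "keyscale": "C minor",
                "timesignature": "4/4",
                "duration": 32.0,
                "language": "instrumental",
                "is_instrumental": True,
            }
        ],
    }
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(dataset, f, indent=2)

    samples, metadata = load_dataset_from_json(json_path)
    assert samples[0]["caption"] == "warm lo-fi piano loop"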