| import torch |
| import torchaudio |
| from pathlib import Path |
| import argparse |
| from tqdm import tqdm |
| from acestep.music_dcae.music_dcae_pipeline import MusicDCAE |
|
|
class AudioVAE:
    """Thin wrapper around MusicDCAE for encoding audio to normalized latents.

    Latents are normalized per channel with precomputed statistics
    (``latent_mean`` / ``latent_std``) so downstream consumers see roughly
    zero-mean, unit-variance inputs; ``decode`` applies the inverse transform.
    """

    def __init__(self, device: torch.device):
        """Load the DCAE model onto *device* and switch it to eval mode."""
        self.model = MusicDCAE().to(device)
        self.model.eval()
        self.device = device
        # Per-channel statistics for the 8 latent channels, shaped (1, C, 1, 1)
        # so they broadcast over (batch, channels, height, width) latents.
        # NOTE(review): presumably dataset-level estimates — confirm their source.
        self.latent_mean = torch.tensor(
            [0.1207, -0.0186, -0.0947, -0.3779, 0.5956, 0.3422, 0.1796, -0.0526],
            device=device,
        ).view(1, -1, 1, 1)
        self.latent_std = torch.tensor(
            [0.4638, 0.3154, 0.6244, 1.5078, 0.4696, 0.4633, 0.5614, 0.2707],
            device=device,
        ).view(1, -1, 1, 1)

    def encode(self, audio: torch.Tensor) -> torch.Tensor:
        """Encode a (batch, channels, samples) waveform batch at 48 kHz.

        Returns normalized latents of shape (batch, 8, H, W).
        """
        with torch.no_grad():
            # Every batch item shares the same sample length (dim 2); build the
            # lengths tensor directly on the target device instead of creating
            # a CPU tensor and moving it with .to().
            audio_lengths = torch.full(
                (audio.shape[0],), audio.shape[2],
                dtype=torch.long, device=self.device,
            )
            latents, _ = self.model.encode(audio, audio_lengths, sr=48000)
            # Normalize with the precomputed per-channel statistics.
            return (latents - self.latent_mean) / self.latent_std

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode normalized latents back to a waveform batch on ``self.device``."""
        with torch.no_grad():
            # Undo the normalization applied in encode().
            latents = latents * self.latent_std + self.latent_mean
            _, audio_list = self.model.decode(latents, sr=48000)
            return torch.stack(audio_list).to(self.device)
|
|
def load_audio(audio_path, target_sr=48000):
    """Load an audio file as a stereo waveform at *target_sr* Hz.

    Mono input is duplicated into two channels; anything with more than two
    channels is truncated to the first two. Returns a (2, samples) tensor.
    """
    waveform, source_sr = torchaudio.load(audio_path)

    channels = waveform.shape[0]
    if channels > 2:
        # Drop everything past the first two channels.
        waveform = waveform[:2]
    elif channels == 1:
        # Duplicate the single channel to produce stereo.
        waveform = waveform.repeat(2, 1)

    if source_sr != target_sr:
        waveform = torchaudio.transforms.Resample(source_sr, target_sr)(waveform)

    return waveform
|
|
|
|
def main():
    """CLI entry point: encode every audio file in a directory to .pt latents."""
    parser = argparse.ArgumentParser(description='Encode audio files to VAE latents')
    parser.add_argument('--audio-dir', type=str, required=True,
                        help='Directory containing audio files')
    parser.add_argument('--output-dir', type=str, default="latents",
                        help='Directory to save encoded latents')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Collect every supported audio file, sorted for deterministic ordering.
    audio_dir = Path(args.audio_dir)
    patterns = ('*.mp3', '*.wav', '*.flac', '*.ogg', '*.m4a')
    audio_files = sorted(
        path for pattern in patterns for path in audio_dir.glob(pattern)
    )

    if not audio_files:
        raise ValueError(f"No audio files found in {args.audio_dir}")

    print(f"Found {len(audio_files)} audio files")

    vae = AudioVAE(device)
    print("VAE loaded")

    print("\nEncoding audio files...")
    for audio_path in tqdm(audio_files, desc="Encoding"):
        try:
            # Shape the waveform into a (1, 2, samples) batch on the device.
            batch = load_audio(audio_path).unsqueeze(0).to(device)
            latents = vae.encode(batch).squeeze(0)
            # Persist latents on CPU so the file loads on any machine.
            torch.save(latents.cpu(), output_dir / f"{audio_path.stem}.pt")
        except Exception as e:
            # Best-effort: report the failure and continue with the rest.
            print(f"\nError encoding {audio_path.name}: {e}")
            continue

    print(f"\nEncoding complete! Saved {len(list(output_dir.glob('*.pt')))} latent files to {output_dir}")
|
|
# Run the CLI only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|
|