| import argparse |
| import torch |
| import torchaudio |
| from pathlib import Path |
| from spectral_ops import STFT, iSTFT |
| from model import Renaissance |
|
|
def load_and_preprocess_audio(input_path, device, dtype, *, target_sr=48000, highpass_cutoff=60.0):
    """Load an audio file and prepare it as model input.

    Processing steps, in order:
      1. downmix to mono by averaging the channels,
      2. resample to ``target_sr`` if the file's rate differs,
      3. apply a high-pass biquad at ``highpass_cutoff`` Hz,
      4. move to ``device`` and cast to ``dtype``.

    Args:
        input_path: Path of the audio file, passed straight to ``torchaudio.load``.
        device: Device the returned tensor is placed on.
        dtype: Floating-point dtype of the returned tensor.
        target_sr: Sample rate of the returned waveform (default 48000).
        highpass_cutoff: High-pass cutoff frequency in Hz (default 60.0).

    Returns:
        Waveform tensor of shape ``(1, num_samples)`` sampled at ``target_sr``.
    """
    # torchaudio.load returns (channels, frames) by default.
    waveform, sr = torchaudio.load(input_path)

    # Record the channel count before downmixing; afterwards it is always 1.
    num_channels = waveform.shape[0]
    if num_channels > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
        print(f"Converted to mono from {num_channels} channels")

    if sr != target_sr:
        print(f"Resampling from {sr} Hz to {target_sr} Hz")
        resampler = torchaudio.transforms.Resample(sr, target_sr)
        waveform = resampler(waveform)

    # Attenuate content below the cutoff before it reaches the model.
    waveform = torchaudio.functional.highpass_biquad(
        waveform, target_sr, cutoff_freq=highpass_cutoff
    )

    return waveform.to(device=device, dtype=dtype)
|
|
def normalize_audio(audio):
    """Peak-normalize ``audio`` so its largest absolute sample becomes 1.

    Args:
        audio: Waveform tensor.

    Returns:
        A ``(normalized, peak)`` pair, where ``peak`` is the 0-d tensor
        ``max(|audio|)``. If ``peak`` is not positive (e.g. silent input),
        ``audio`` itself is returned unchanged to avoid dividing by zero.
    """
    peak = audio.abs().max()
    normalized = audio / peak if peak > 0 else audio
    return normalized, peak
|
|
|
|
def process_audio(model, stft, istft, input_wav, device):
    """Run ``model`` on a waveform and return the resulting waveform.

    The input is peak-normalized with ``normalize_audio`` and transformed
    with ``stft``. It is then passed through ``model`` and converted back
    with ``istft``. Finally it is multiplied by the same peak value, so the
    output is at the input's original scale. All of this runs under
    ``torch.no_grad()``.

    Args:
        model: Callable applied to the output of ``stft``.
        stft: Callable mapping a waveform to the model's input representation.
        istft: Callable mapping the model's output back to a waveform.
        input_wav: Input waveform tensor.
        device: Device the computation runs on (``torch.device`` or string).
            Mixed-precision autocast is enabled only when this is a CUDA device.

    Returns:
        The output waveform tensor produced by ``istft``, rescaled by the
        input's peak when that peak is positive.
    """
    input_wav_norm, norm_factor = normalize_audio(input_wav)
    use_amp = torch.device(device).type == "cuda"

    with torch.no_grad():
        input_stft = stft(input_wav_norm)

        # torch.cuda.amp.autocast is deprecated; torch.autocast("cuda", ...)
        # is the supported equivalent and is a no-op when enabled=False.
        with torch.autocast(device_type="cuda", enabled=use_amp):
            enhanced_stft = model(input_stft)

        enhanced_wav = istft(enhanced_stft)

    # Undo the peak normalization applied above.
    if norm_factor > 0:
        enhanced_wav = enhanced_wav * norm_factor

    return enhanced_wav
|
|
|
|
def main():
    """Command-line entry point: enhance a single audio file.

    Arguments:
        input: path of the audio file to process.
        -o/--output: destination path; defaults to ``<stem>_enhanced.wav``
            next to the input.
        -c/--checkpoint: path of the model weights (required).

    On CUDA machines the model and audio run in float16; otherwise they run
    in float32 on CPU. The result is written as a float32 file at 48000 Hz.
    """
    parser = argparse.ArgumentParser(description="Smule Renaissance Vocal Restoration")
    parser.add_argument("input", type=str, help="Input audio file path")
    parser.add_argument(
        "-o", "--output",
        type=str,
        default=None,
        help="Output audio file path (default: input_enhanced.wav)",
    )
    parser.add_argument(
        "-c", "--checkpoint",
        type=str,
        required=True,
        help="Model checkpoint path",
    )
    args = parser.parse_args()

    if args.output is None:
        # Place "<stem>_enhanced.wav" in the same directory as the input.
        src = Path(args.input)
        args.output = str(src.parent / f"{src.stem}_enhanced.wav")

    sample_rate = 48000
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        device, dtype = torch.device("cuda"), torch.float16
        print("Using device: CUDA with FP16 precision")
    else:
        device, dtype = torch.device("cpu"), torch.float32
        print("Using device: CPU with FP32 precision")

    print(f"Loading model from {args.checkpoint}...")
    model = Renaissance().to(device).to(dtype)
    # NOTE(review): torch.load unpickles the file. Unless weights_only=True is
    # in effect (the default only from PyTorch 2.6), a malicious checkpoint can
    # execute arbitrary code. Consider passing weights_only=True if the
    # checkpoint is a plain state_dict.
    state_dict = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    n_fft, hop_length = 4096, 2048
    # NOTE(review): these transforms are never moved to `device`; if they hold
    # tensor buffers (e.g. a window), confirm this works on CUDA.
    stft = STFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft)
    istft = iSTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft)

    print(f"Loading audio from {args.input}...")
    input_wav = load_and_preprocess_audio(args.input, device, dtype)
    print(f"Audio duration: {input_wav.shape[1] / sample_rate:.2f} seconds")

    print("Processing audio...")
    enhanced_wav = process_audio(model, stft, istft, input_wav, device)

    print(f"Saving enhanced audio to {args.output}...")
    torchaudio.save(args.output, enhanced_wav.cpu().to(torch.float32), sample_rate)

    print("Done!")
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()