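# Batch inference script for PrismAudio: builds a latent diffusion model from a
# JSON config, loads its checkpoint and a pretrained VAE (the "pretransform"),
# runs conditional sampling over a predict dataloader, and writes the decoded
# audio to per-id 44.1 kHz WAV files.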
from prefigure.prefigure import get_all_args, push_wandb_config
import json
import os
import re
import torch
import torchaudio
from lightning.pytorch import seed_everything
import random
from datetime import datetime
import numpy as np
from PrismAudio.data.datamodule import DataModule
from PrismAudio.models import create_model_from_config
from PrismAudio.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model
from PrismAudio.inference.sampling import sample, sample_discrete_euler
from pathlib import Path
from tqdm import tqdm


def predict_step(diffusion, batch, diffusion_objective, device='cuda:0'):
    diffusion = diffusion.to(device)
    reals, metadata = batch
    batch_size, length = reals.shape[0], reals.shape[2]

    with torch.amp.autocast('cuda'):
        conditioning = diffusion.conditioner(metadata, device)

    # For items without video, swap in the model's learned "empty" CLIP and
    # sync embeddings so the visual conditioning is effectively disabled.
    video_exist = torch.stack([item['video_exist'] for item in metadata], dim=0)
    conditioning['metaclip_features'][~video_exist] = diffusion.model.model.empty_clip_feat
    conditioning['sync_features'][~video_exist] = diffusion.model.model.empty_sync_feat
    cond_inputs = diffusion.get_conditioning_inputs(conditioning)

    if batch_size > 1:
        # Draw noise one sample at a time so each generation advances the RNG
        # state, keeping per-sample noise independent of the batch size.
        noise_list = []
        for _ in range(batch_size):
            noise_list.append(torch.randn([1, diffusion.io_channels, length]).to(device))
        noise = torch.cat(noise_list, dim=0)
    else:
        noise = torch.randn([batch_size, diffusion.io_channels, length]).to(device)

    with torch.amp.autocast('cuda'):
        model = diffusion.model
        # 24 sampling steps with classifier-free guidance (cfg_scale=5).
        if diffusion_objective == "v":
            fakes = sample(model, noise, 24, 0, **cond_inputs, cfg_scale=5, batch_cfg=True)
        elif diffusion_objective == "rectified_flow":
            fakes = sample_discrete_euler(model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True)
        else:
            raise ValueError(f"Unknown diffusion objective: {diffusion_objective}")
        if diffusion.pretransform is not None:
            fakes = diffusion.pretransform.decode(fakes)

    # Peak-normalize the batch and convert to int16 PCM on the CPU.
    audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
    return audios
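
# Illustrative call sketch (not from the source): predict_step expects `reals`
# shaped [batch, io_channels, latent_length] and metadata dicts carrying at
# least 'id', a boolean 'video_exist' tensor, and whatever fields the
# conditioner consumes. The dataloader built in main() provides real batches
# in this layout.
#
#   reals = torch.zeros(2, model.io_channels, 194)
#   metadata = [
#       {'id': 'clip_0', 'video_exist': torch.tensor(True), ...},
#       {'id': 'clip_1', 'video_exist': torch.tensor(False), ...},
#   ]
#   audio = predict_step(model, (reals, metadata), 'rectified_flow')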


def main():
    args = get_all_args()

    if args.save_dir == '':
        args.save_dir = args.results_dir

    # Offset the seed per SLURM rank so parallel jobs sample different noise.
    seed = args.seed
    if os.environ.get("SLURM_PROCID") is not None:
        seed += int(os.environ.get("SLURM_PROCID"))
    seed_everything(seed, workers=True)

    # Load the model config and patch the sequence lengths for the requested duration.
    if args.model_config == '':
        args.model_config = "PrismAudio/configs/model_configs/thinksound.json"
    with open(args.model_config) as f:
        model_config = json.load(f)

    duration = float(args.duration_sec)
    sample_rate = model_config["sample_rate"]
    latent_length = round(44100 / 64 / 32 * duration)
    model_config["sample_size"] = duration * sample_rate
    model_config["model"]["diffusion"]["config"]["sync_seq_len"] = 24 * int(duration)
    model_config["model"]["diffusion"]["config"]["clip_seq_len"] = 8 * int(duration)
    model_config["model"]["diffusion"]["config"]["latent_seq_len"] = latent_length

    # Build the diffusion model, then load its weights and the VAE (pretransform) weights.
    model = create_model_from_config(model_config)
    if args.compile:
        model = torch.compile(model)
    model.load_state_dict(torch.load(args.ckpt_dir))
    vae_state = load_ckpt_state_dict(args.pretransform_ckpt_path, prefix='autoencoder.')
    model.pretransform.load_state_dict(vae_state)

    # Point every test dataset at the results dir and build the predict dataloader.
    if args.dataset_config == '':
        args.dataset_config = "PrismAudio/configs/multimodal_dataset_demo.json"
    with open(args.dataset_config) as f:
        dataset_config = json.load(f)
    for td in dataset_config["test_datasets"]:
        td["path"] = args.results_dir

    dm = DataModule(
        dataset_config,
        batch_size=args.batch_size,
        test_batch_size=args.test_batch_size,
        num_workers=args.num_workers,
        sample_rate=sample_rate,
        sample_size=duration * sample_rate,
        audio_channels=model_config.get("audio_channels", 2),
        latent_length=latent_length,
    )
    dm.setup('predict')
    dl = dm.predict_dataloader()

    # Outputs go under <save_dir>/<MMDD>_batch_size<test_batch_size>/.
    current_date = datetime.now()
    formatted_date = current_date.strftime('%m%d')
    audio_dir = os.path.join(args.save_dir, f'{formatted_date}_batch_size{args.test_batch_size}')
    os.makedirs(audio_dir, exist_ok=True)

    for batch in tqdm(dl, desc="Predicting"):
        audio = predict_step(
            model,
            batch=batch,
            diffusion_objective=model_config["model"]["diffusion"]["diffusion_objective"],
            device='cuda:0',
        )
        _, metadata = batch
        ids = [item['id'] for item in metadata]
        for i in range(audio.size(0)):
            id_str = ids[i] if i < len(ids) else f"unknown_{i}"
            torchaudio.save(os.path.join(audio_dir, f"{id_str}.wav"), audio[i], 44100)


if __name__ == '__main__':
    main()
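
# Example invocation (hypothetical script name and paths; assumes prefigure's
# get_all_args exposes each arg referenced above as a --dashed-flag, per its
# usual ini-driven argparse setup):
#   python predict.py --model-config PrismAudio/configs/model_configs/thinksound.json \
#       --ckpt-dir ckpts/model.ckpt --pretransform-ckpt-path ckpts/vae.ckpt \
#       --results-dir demo_inputs --duration-sec 9 --test-batch-size 4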