"""Extract per-clip training features (text, audio latent, video, sync) to .npz files.

Intended to be launched with one process per GPU (e.g. via torchrun), which
supplies the RANK / WORLD_SIZE (and normally LOCAL_RANK) environment variables.
"""
import argparse
import os

# Ask the XLA client not to preallocate GPU memory. This must be set before
# any module that initialises XLA is imported, so it stays above the imports below.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
import sys
import torch
import logging
import torch.distributed as dist
from torch.utils.data import DataLoader, distributed
from tqdm import tqdm
import time
import numpy as np

# Make the repository root importable so `data_utils` resolves.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils
from data_utils.v2a_utils.thinksound_288_al import VGGSound
from torch.utils.data.dataloader import default_collate


def setup_distributed(rank, world_size, local_rank=None):
    """Initialise the NCCL process group and bind this process to a GPU.

    Args:
        rank: Global rank of this process.
        world_size: Total number of processes in the job.
        local_rank: GPU index on this node to bind to. Defaults to ``rank``,
            which is only correct for single-node jobs.
    """
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank if local_rank is None else local_rank)


def cleanup_distributed():
    """Destroy the default process group if one is initialised."""
    if dist.is_initialized():
        dist.destroy_process_group()


def error_avoidance_collate(batch):
    """Collate a batch after dropping samples the dataset returned as None.

    Returns:
        None when every sample in the batch was dropped; otherwise the
        result of ``default_collate`` on the surviving samples.
    """
    batch = list(filter(lambda x: x is not None, batch))
    if len(batch) == 0:
        return None  # or return {}
    return default_collate(batch)


def process_batch(data, model, rank, inference_mode=False):
    """Run all encoders on one collated batch.

    Args:
        data: Collated batch dict. Reads 'caption_cot', 'clip_video',
            'sync_video' and, unless ``inference_mode``, 'audio'.
        model: DistributedDataParallel-wrapped FeaturesUtils; encoders are
            reached through ``model.module``.
        rank: CUDA device index that 'audio' and 'sync_video' are moved to.
        inference_mode: When true, audio is not encoded and 'latent' is a
            list of None, one per caption feature.

    Returns:
        Dict of outputs. 'text_features', 'latent' and 'sync_features' are
        indexed by batch position. The three VideoPrism fields are lists
        holding a single array for the whole batch (save_outputs indexes
        them the same way, which only lines up for batch_size=1).
    """
    output = {
        'caption_cot': data['caption_cot'],
        'latent': [],
        'global_video_features': [],
        'video_features': [],
        'global_text_features': [],
        'text_features': [],
        'sync_features': [],
    }
    with torch.no_grad():
        text_features = model.module.encode_t5_text(data['caption_cot'])
        output['text_features'] = text_features.detach().cpu().numpy()

        if not inference_mode:
            latent = model.module.encode_audio(data['audio'].cuda(rank, non_blocking=True))
            output['latent'] = latent.detach().cpu().numpy()
        else:
            output['latent'] = [None] * len(text_features)

        # 'clip_video' is passed as-is (not moved to CUDA); presumably the
        # VideoPrism encoder handles its own device placement — TODO confirm.
        video_feat, frame_embed, _, text_feat = model.module.encode_video_and_text_with_videoprism(
            data['clip_video'], data['caption_cot']
        )
        output['global_video_features'].append(np.array(video_feat))
        output['video_features'].append(np.array(frame_embed))
        output['global_text_features'].append(np.array(text_feat))

        sync_video = data['sync_video'].cuda(rank, non_blocking=True)
        sync_features = model.module.encode_video_with_sync(sync_video)
        output['sync_features'] = sync_features.detach().cpu().numpy()

    return output


def save_outputs(output, ids, save_dir, add_audio_path=None, add_video_path=None):
    """Write one ``<id>.npz`` file per sample id into ``save_dir``.

    Args:
        output: Dict returned by ``process_batch``.
        ids: Sample ids of the batch, in batch order.
        save_dir: Existing directory that receives the .npz files.
        add_audio_path: Optional directory; if given, ``<dir>/<id>.wav`` is
            stored as 'audio_path'.
        add_video_path: Optional directory; if given, ``<dir>/<id>.mp4`` is
            stored as 'video_path'.

    Note:
        None values (paths, inference-mode latents) are saved by numpy as
        object arrays, so reading them back requires ``allow_pickle=True``.
    """
    for i, sample_id in enumerate(ids):
        np.savez(
            os.path.join(save_dir, f"{sample_id}.npz"),
            id=sample_id,
            audio_path=os.path.join(add_audio_path, f"{sample_id}.wav") if add_audio_path is not None else None,
            video_path=os.path.join(add_video_path, f"{sample_id}.mp4") if add_video_path is not None else None,
            caption_cot=output['caption_cot'][i],
            latent=output['latent'][i],
            global_video_features=output['global_video_features'][i],
            video_features=output['video_features'][i],
            global_text_features=output['global_text_features'][i],
            text_features=output['text_features'][i],
            sync_features=output['sync_features'][i],
        )


def main(args):
    """Build the sharded dataset and model, then extract and save features.

    Reads RANK and WORLD_SIZE from the environment (required) and LOCAL_RANK
    (optional, defaults to RANK) to pick this process's GPU. Per-batch
    failures are logged with a traceback and skipped. The process group is
    always torn down, even when setup fails.
    """
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # On multi-node jobs RANK exceeds the per-node GPU count, so the device
    # index must come from LOCAL_RANK. Fall back to RANK for launchers that
    # do not set LOCAL_RANK (single-node behaviour is unchanged).
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    setup_distributed(rank, world_size, local_rank)
    try:
        dataset = VGGSound(
            root=args.root,
            tsv_path=args.tsv_path,
            sample_rate=args.sample_rate,
            start_row=args.start_row,
            end_row=args.end_row,
            save_dir=args.save_dir,
            inference_mode=args.inference_mode,
        )
        os.makedirs(args.save_dir, exist_ok=True)

        sampler = distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank)
        # batch_size must stay 1: save_outputs indexes the VideoPrism fields
        # by batch position, which only lines up for single-sample batches.
        dataloader = DataLoader(
            dataset,
            batch_size=1,
            sampler=sampler,
            num_workers=getattr(args, "num_workers", 32),
            drop_last=False,
            collate_fn=error_avoidance_collate,
            pin_memory=True,
        )

        model = FeaturesUtils(
            vae_ckpt=args.vae_ckpt if not args.inference_mode else None,
            vae_config=args.vae_config,
            enable_conditions=True,
            synchformer_ckpt=args.synchformer_ckpt,
        )
        model = model.eval().cuda(local_rank)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])

        for data in tqdm(dataloader, desc="Processing", unit="batch", disable=rank != 0):
            if data is None:
                # Every sample in this batch failed to load.
                continue
            ids = data['id']
            try:
                output = process_batch(data, model, local_rank, args.inference_mode)
                save_outputs(output, ids, args.save_dir, args.add_audio_path, args.add_video_path)
            except Exception as e:
                # Best-effort: log with traceback and keep going with the next batch.
                logging.exception(f"Error processing sample IDs {ids}: {e}")
    finally:
        cleanup_distributed()


def _str2bool(value):
    """Parse a command-line boolean ('true'/'false', '1'/'0', 'yes'/'no', ...).

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    parser = argparse.ArgumentParser(description='Extract Video Training Latents')
    parser.add_argument('--root', default='videos')
    parser.add_argument('--tsv_path', default='cot_coarse/cot.csv')
    parser.add_argument('--save-dir', default='results')
    parser.add_argument('--sample_rate', type=int, default=44100, help='Audio sample rate')
    parser.add_argument('--vae_ckpt', type=str, default='ckpts/vae.ckpt', help='Path to the VAE checkpoint')
    parser.add_argument('--vae_config', type=str, default='PrismAudio/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file')
    parser.add_argument('--synchformer_ckpt', type=str, default='ckpts/synchformer_state_dict.pth', help='Path to the Synchformer checkpoint')
    parser.add_argument('--start-row', '-s', type=int, default=0, help='Start row index')
    parser.add_argument('--end-row', '-e', type=int, default=None, help='End row index')
    parser.add_argument('--add_audio_path', default=None, help='Provide the original audio file path required for ITD reward in GRPO')
    parser.add_argument('--add_video_path', default=None, help='Provide the video path file required for Synchformer reward in GRPO')
    # Previously a plain string option: '--inference_mode False' was truthy
    # and silently enabled inference mode. Bare '--inference_mode' now means True.
    parser.add_argument('--inference_mode', type=_str2bool, nargs='?', const=True, default=False,
                        help='Skip audio latent extraction (true/false; bare flag means true)')
    parser.add_argument('--num_workers', type=int, default=32, help='DataLoader worker processes per rank')
    args = parser.parse_args()
    main(args)