import argparse
import os
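# Keep JAX/XLA from preallocating most of the GPU memory (the VideoPrism encoder
# used below is JAX-based); this must be set before JAX is first imported.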
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
import sys
import torch
import logging
import torch.distributed as dist
from torch.utils.data import DataLoader, distributed
from tqdm import tqdm
import numpy as np
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils
from data_utils.v2a_utils.thinksound_288_al import VGGSound
from torch.utils.data.dataloader import default_collate
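
# One process per GPU; RANK and WORLD_SIZE are injected by the launcher (e.g. torchrun).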
def setup_distributed(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup_distributed():
    dist.destroy_process_group()
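
# Collate wrapper that drops samples the dataset returned as None (e.g. unreadable
# clips) so a single bad sample does not crash default_collate.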
def error_avoidance_collate(batch):
    batch = list(filter(lambda x: x is not None, batch))
    if len(batch) == 0:
        return None  # or return {}
    return default_collate(batch)
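
# Encode one batch with every frozen encoder: T5 text features, VAE audio latents
# (skipped in inference mode), VideoPrism video/text embeddings, and Synchformer
# synchronization features.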
def process_batch(data, model, rank, inference_mode=False):
    output = {
        'caption_cot': data['caption_cot'],
        'latent': [],
        'global_video_features': [],
        'video_features': [],
        'global_text_features': [],
        'text_features': [],
        'sync_features': [],
    }
    with torch.no_grad():
        text_features = model.module.encode_t5_text(data['caption_cot'])
        output['text_features'] = text_features.detach().cpu().numpy()
        if not inference_mode:
            latent = model.module.encode_audio(data['audio'].cuda(rank, non_blocking=True))
            output['latent'] = latent.detach().cpu().numpy()
        else:
            output['latent'] = [None] * len(text_features)
        video_feat, frame_embed, _, text_feat = model.module.encode_video_and_text_with_videoprism(
            data['clip_video'], data['caption_cot'])
        # Store the batched arrays directly so save_outputs can index one row per
        # sample, consistent with the other feature entries (appending the whole
        # batch to a list would break for any batch index other than 0).
        output['global_video_features'] = np.array(video_feat)
        output['video_features'] = np.array(frame_embed)
        output['global_text_features'] = np.array(text_feat)
        sync_video = data['sync_video'].cuda(rank, non_blocking=True)
        sync_features = model.module.encode_video_with_sync(sync_video)
        output['sync_features'] = sync_features.detach().cpu().numpy()
    return output
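
# Persist one .npz per sample; the audio/video paths are recorded only when the
# GRPO reward stages need to locate the original files.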
def save_outputs(output, ids, save_dir, add_audio_path=None, add_video_path=None):
    for i, sample_id in enumerate(ids):
        np.savez(
            os.path.join(save_dir, f"{sample_id}.npz"),
            id=sample_id,
            audio_path=os.path.join(add_audio_path, f"{sample_id}.wav") if add_audio_path is not None else None,
            video_path=os.path.join(add_video_path, f"{sample_id}.mp4") if add_video_path is not None else None,
            caption_cot=output['caption_cot'][i],
            latent=output['latent'][i],
            global_video_features=output['global_video_features'][i],
            video_features=output['video_features'][i],
            global_text_features=output['global_text_features'][i],
            text_features=output['text_features'][i],
            sync_features=output['sync_features'][i],
        )
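
# Each rank processes its own shard of the dataset via DistributedSampler.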
def main(args):
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    setup_distributed(rank, world_size)
    dataset = VGGSound(
        root=args.root,
        tsv_path=args.tsv_path,
        sample_rate=args.sample_rate,
        start_row=args.start_row,
        end_row=args.end_row,
        save_dir=args.save_dir,
        inference_mode=args.inference_mode,
    )
    os.makedirs(args.save_dir, exist_ok=True)
    sampler = distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=32,
                            drop_last=False, collate_fn=error_avoidance_collate, pin_memory=True)
    model = FeaturesUtils(
        vae_ckpt=args.vae_ckpt if not args.inference_mode else None,
        vae_config=args.vae_config,
        enable_conditions=True,
        synchformer_ckpt=args.synchformer_ckpt,
    )
    model = model.eval().cuda(rank)
    # torch.compile returns the compiled module rather than compiling in place,
    # so the result must be assigned.
    model = torch.compile(model)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
    for data in tqdm(dataloader, desc="Processing", unit="batch"):
        if data is None:
            continue
        ids = data['id']
        try:
            output = process_batch(data, model, rank, args.inference_mode)
            save_outputs(output, ids, args.save_dir, args.add_audio_path, args.add_video_path)
        except Exception as e:
            logging.error(f"Error processing sample IDs {ids}: {e}")
    cleanup_distributed()
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    parser = argparse.ArgumentParser(description='Extract Video Training Latents')
    parser.add_argument('--root', default='videos')
    parser.add_argument('--tsv_path', default='cot_coarse/cot.csv')
    parser.add_argument('--save-dir', default='results')
    parser.add_argument('--sample_rate', type=int, default=44100, help='Audio sample rate')
    parser.add_argument('--vae_ckpt', type=str, default='ckpts/vae.ckpt', help='Path to the VAE checkpoint')
    parser.add_argument('--vae_config', type=str, default='PrismAudio/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file')
    parser.add_argument('--synchformer_ckpt', type=str, default='ckpts/synchformer_state_dict.pth', help='Path to the Synchformer checkpoint')
    parser.add_argument('--start-row', '-s', type=int, default=0, help='Start row index')
    parser.add_argument('--end-row', '-e', type=int, default=None, help='End row index')
    parser.add_argument('--add_audio_path', default=None,
                        help='Original audio file path required for the ITD reward in GRPO')
    parser.add_argument('--add_video_path', default=None,
                        help='Video file path required for the Synchformer reward in GRPO')
    # store_true makes this a proper boolean flag; a plain default=False with a
    # string value would be truthy for any input, including "False".
    parser.add_argument('--inference_mode', action='store_true',
                        help='Skip audio/VAE encoding and store empty latents')
    args = parser.parse_args()
    main(args)
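
# Example launch (hypothetical script name; torchrun supplies RANK and WORLD_SIZE
# to each process):
#   torchrun --nproc_per_node=8 extract_training_latents.py \
#       --root videos --tsv_path cot_coarse/cot.csv --save-dir results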