| import os |
| import json |
| import torch |
| import argparse |
| import tempfile |
| import glob |
| from tqdm import tqdm |
| from transformers import AutoModel, AutoProcessor |
| from torch.nn.functional import cosine_similarity |
| import torch.multiprocessing as mp |
|
|
| |
# Local filesystem path to the SigLIP so400m/patch14/384 checkpoint. Each
# worker process loads this model to embed question text; the frame
# embeddings are precomputed elsewhere (presumably with the same checkpoint
# — TODO confirm, mismatched encoders would make the similarities meaningless).
MODEL_ID = "/mnt/bn/ziyang-storage-cloudnative-hl/huggingface/siglip-so400m-patch14-384"
|
|
|
|
def parse_arguments():
    """Parse command-line arguments for the similarity-computation step.

    Returns:
        argparse.Namespace with ``data_file``, ``embeddings_path`` and
        ``output_file`` attributes (all required string paths).
    """
    parser = argparse.ArgumentParser(
        description="步骤 2: 从预计算的嵌入加载并计算问-帧相似度。"
    )
    # (long flag, short flag, help text) — all three are required str paths,
    # so they are declared uniformly from one table.
    option_table = (
        ("--data-file", "-df", "包含评估数据集的JSON文件的绝对路径。"),
        ("--embeddings-path", "-ep", "包含预计算嵌入.pt文件的目录的绝对路径。"),
        ("--output-file", "-o", "用于保存最终相似度分数的JSON文件路径。"),
    )
    for long_flag, short_flag, help_text in option_table:
        parser.add_argument(
            long_flag,
            short_flag,
            type=str,
            required=True,
            help=help_text,
        )
    return parser.parse_args()
|
|
|
|
def load_test_data(json_file):
    """Load the evaluation dataset from a JSON file.

    Args:
        json_file: Path to the dataset JSON file.

    Returns:
        The deserialized JSON payload (the dataset is expected to be a list
        of question items).

    Raises:
        SystemExit: with status 1 when the file is missing or contains
            invalid JSON — the pipeline cannot proceed without the dataset.
    """
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"错误: 在 {json_file} 未找到数据文件")
        # raise SystemExit instead of the site-injected exit() helper, which
        # is not guaranteed to exist (e.g. under `python -S`).
        raise SystemExit(1)
    except json.JSONDecodeError:
        print(f"错误: 无法从 {json_file} 解码JSON")
        raise SystemExit(1)
    # NOTE: the former trailing `return []` was unreachable (every path above
    # either returns or exits) and has been removed.
|
|
|
|
def save_json_file(data, output_file):
    """Serialize *data* as pretty-printed JSON and write it to *output_file*.

    Args:
        data: Any JSON-serializable object (here: the merged results dict).
        output_file: Destination path; overwritten if it already exists.
    """
    serialized = json.dumps(data, indent=4)
    with open(output_file, "w", encoding="utf-8") as sink:
        sink.write(serialized)
    print(f"\n成功将最终相似度结果保存到 {output_file}")
|
|
|
|
def process_question_chunk(args_tuple):
    """
    Worker: score one chunk of questions against precomputed frame embeddings
    on a dedicated GPU, appending results incrementally to a per-GPU file.

    args_tuple is ``(data_chunk, embeddings_base_path, gpu_id, temp_dir)``:
      - data_chunk: list of dataset items; each item is read for "key"
        (embedding-file stem), "uid" (question id) and "question" text.
      - embeddings_base_path: directory holding one ``<key>.pt`` per item
        with keys "filenames" and "embeddings".
      - gpu_id: CUDA device index this worker is pinned to.
      - temp_dir: directory receiving the intermediate JSONL output.
    """
    data_chunk, embeddings_base_path, gpu_id, temp_dir = args_tuple
    device = f"cuda:{gpu_id}"

    # One JSONL file per worker so no cross-process locking is needed; the
    # parent process merges these files after the pool finishes.
    temp_output_file = os.path.join(temp_dir, f"results_gpu_{gpu_id}.jsonl")

    # Every spawned worker loads its own copy of the model onto its GPU.
    model = AutoModel.from_pretrained(MODEL_ID).to(device).eval()
    processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=True)

    # position=gpu_id keeps each worker's tqdm bar on its own terminal line.
    progress_bar = tqdm(data_chunk, position=gpu_id, desc=f"GPU-{gpu_id}")

    # Cache loaded .pt payloads (kept on CPU) so multiple questions about the
    # same video read the embeddings file only once.
    # NOTE(review): the cache is never evicted, so memory grows with the
    # number of distinct keys in this chunk — confirm that is acceptable.
    embedding_cache = {}

    with open(temp_output_file, "a", encoding="utf-8") as f_out:
        for data_item in progress_bar:
            item_key = data_item["key"]
            question_key = data_item["uid"]
            # Keep only the question text; everything from the first
            # multiple-choice marker "\n(A)" onward is discarded.
            question = data_item["question"].split("\n(A)")[0]

            embedding_file_path = os.path.join(embeddings_base_path, f"{item_key}.pt")
            if not os.path.exists(embedding_file_path):
                progress_bar.write(
                    f"Warning: Embedding file not found for '{item_key}', skipping."
                )
                continue

            try:
                if item_key not in embedding_cache:
                    # Load to CPU first; only the embeddings tensor is moved
                    # to the GPU, per item, below.
                    loaded_data = torch.load(embedding_file_path, map_location="cpu")
                    embedding_cache[item_key] = {
                        "filenames": loaded_data["filenames"],
                        "embeddings": loaded_data["embeddings"],
                    }

                frame_files = embedding_cache[item_key]["filenames"]
                frame_embeddings = embedding_cache[item_key]["embeddings"].to(device)

                with torch.no_grad():
                    text_inputs = processor(
                        text=[question],
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                    ).to(device)
                    question_embedding = model.get_text_features(**text_inputs)

                    # Cosine similarity between the single question embedding
                    # and every frame embedding, then rank frames best-first.
                    similarities = cosine_similarity(
                        question_embedding, frame_embeddings
                    )
                    scored_frames = sorted(
                        zip(frame_files, similarities.cpu().numpy()),
                        key=lambda x: x[1],
                        reverse=True,
                    )
                    sorted_frame_filenames = [frame[0] for frame in scored_frames]

                    # One JSON object per line: {question_uid: [ranked frame
                    # filenames]} — scores themselves are not persisted.
                    single_result = {question_key: sorted_frame_filenames}
                    f_out.write(json.dumps(single_result) + "\n")

            except Exception as e:
                # Best-effort: log and continue so one bad embedding file
                # does not abort the whole chunk.
                progress_bar.write(f"Error on GPU-{gpu_id} for item '{item_key}': {e}")
|
|
|
|
def main():
    """Coordinate multi-GPU similarity scoring and merge the per-GPU results.

    Splits the dataset into one chunk per GPU, runs
    ``process_question_chunk`` in a multiprocessing pool, then merges the
    per-GPU JSONL files into a single JSON output file.
    """
    args = parse_arguments()

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        print("错误: 未找到启用CUDA的GPU。正在退出。")
        exit(1)

    print(f"找到 {num_gpus} 个GPU。开始并行计算相似度...")

    test_data = load_test_data(args.data_file)
    if not test_data:
        return

    # Ceiling division so every item lands in a chunk; if there are fewer
    # items than GPUs, len(data_chunks) < num_gpus and extra GPUs stay idle.
    chunk_size = (len(test_data) + num_gpus - 1) // num_gpus
    data_chunks = [
        test_data[i : i + chunk_size] for i in range(0, len(test_data), chunk_size)
    ]

    with tempfile.TemporaryDirectory() as temp_dir:
        print(f"使用临时目录存储中间结果: {temp_dir}")

        # Chunk index doubles as the CUDA device id and the per-GPU output
        # file suffix inside process_question_chunk.
        process_args = [
            (data_chunks[i], args.embeddings_path, i, temp_dir)
            for i in range(len(data_chunks))
        ]

        with mp.Pool(processes=num_gpus) as pool:
            pool.map(process_question_chunk, process_args)

        # Merge phase: each temp file holds one JSON object per line; all of
        # them are folded into a single dict before the temporary directory
        # is cleaned up by the context manager.
        print("\n\n所有GPU进程已完成。正在从临时文件合并结果...")
        final_similarity_results = {}

        temp_files = glob.glob(os.path.join(temp_dir, "*.jsonl"))

        for temp_file in tqdm(temp_files, desc="合并文件"):
            with open(temp_file, "r", encoding="utf-8") as f:
                for line in f:
                    try:
                        data = json.loads(line)
                        final_similarity_results.update(data)
                    except json.JSONDecodeError:
                        # A worker killed mid-write can leave a truncated
                        # final line; skip it rather than fail the merge.
                        print(f"警告: 跳过 {temp_file} 中的损坏行")

        save_json_file(final_similarity_results, args.output_file)
        print(f"总共处理的项目数: {len(final_similarity_results)}")
|
|
|
|
if __name__ == "__main__":
    # Force the "spawn" start method so each worker begins with a fresh
    # interpreter — CUDA state does not survive fork, and the workers each
    # initialize their own CUDA context (see process_question_chunk).
    mp.set_start_method("spawn", force=True)
    main()
|
|