import os
import json
import base64
import argparse
import time
import re
from datetime import datetime
from functools import partial
from openai import AzureOpenAI, OpenAI
from volcenginesdkarkruntime import Ark
from multiprocessing import Pool, Manager
|
|
REASONING_MULTIPLE_CHOICE_TEMPLATE = """
You are an AI assistant evaluating video frames to answer a multiple-choice question.
The user will provide you with a set of video frames and a question with several options (e.g., A, B, C, D).

First, provide a step-by-step reasoning process that analyzes the video frames and leads to your conclusion.
After your reasoning, provide the final answer in a JSON block. The JSON object must contain a single key "answer" with the value being one of 'A', 'B', 'C', or 'D'.

Your output should follow this format exactly:
<Your step-by-step reasoning here>
```json
{"answer": "A"}
```
Do not include any other text after the JSON block.
"""
|
|
|
|
def parse_arguments():
    """
    Parse command line arguments for evaluation configuration.

    Returns:
        argparse.Namespace: Parsed command line arguments
    """
    parser = argparse.ArgumentParser(
        description="Video QA Evaluation with Pre-computed Similarity Frame Selection"
    )

    # Model under evaluation
    parser.add_argument(
        "--target-model",
        "-tm",
        type=str,
        required=True,
        help="Model to be evaluated (e.g., gpt-4o, gpt-4-vision-preview)",
    )

    # Frame selection and input data
    parser.add_argument(
        "--frame-num",
        "-fn",
        type=int,
        default=32,
        help="Number of most similar frames to select for each video (default: 32)",
    )
    parser.add_argument(
        "--frames-path",
        "-fp",
        type=str,
        required=True,
        help="Absolute path to the base directory containing video frame folders.",
    )
    parser.add_argument(
        "--data-file",
        "-df",
        type=str,
        required=True,
        help="Absolute path to the JSON file containing the evaluation dataset.",
    )
    parser.add_argument(
        "--similarity-file",
        "-sf",
        type=str,
        required=True,
        help="Absolute path to the pre-computed similarity JSON file (e.g., lv_bench_similarity.json).",
    )

    # Retry and parallelism settings
    parser.add_argument(
        "--max-retry-times",
        "-mr",
        type=int,
        default=10,
        help="Maximum number of retries for API calls (default: 10)",
    )
    parser.add_argument(
        "--pool-processes",
        "-pp",
        type=int,
        default=20,
        help="Number of parallel processes for evaluation (default: 20)",
    )

    # API endpoint credentials
    parser.add_argument(
        "--base_url",
        type=str,
        required=True,
        help="API endpoint URL (Azure OpenAI, Ark, or an OpenAI-compatible server).",
    )
    parser.add_argument(
        "--api_key", type=str, required=True, help="API key for the selected endpoint."
    )

    return parser.parse_args()
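
# Illustrative invocation (the script name and all paths are placeholders):
#   python evaluate_precomputed_similarity.py \
#       --target-model gpt-4o \
#       --frames-path /abs/path/to/frames \
#       --data-file /abs/path/to/dataset.json \
#       --similarity-file /abs/path/to/lv_bench_similarity.json \
#       --base_url https://your-endpoint.example.com \
#       --api_key YOUR_KEY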
|
|
|
|
def save_json_file(data, output_file):
    """
    Save data to a JSON file.
    """
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)
|
|
|
|
def extract_json_from_response(response):
    """
    Extract a JSON object from a string that contains reasoning followed by a
    fenced ```json block.
    """
    if not response:
        return None
    try:
        match = re.search(r"```json\s*(\{.*?\})\s*```", response, re.DOTALL)
        if match:
            json_str = match.group(1)
            return json.loads(json_str)
        return None
    except (json.JSONDecodeError, IndexError):
        return None
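
# Illustrative behavior:
#   extract_json_from_response('Step 1: ...\n```json\n{"answer": "B"}\n```')
#   -> {"answer": "B"}
# Returns None for empty input, a missing ```json fence, or malformed JSON.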
|
|
|
|
def calculate_metrics(results):
    """
    Calculate evaluation metrics from the results. Accuracy is computed over
    answered samples only; items with no parsed answer are excluded from the
    denominator.
    """
    total_samples = len(results)
    if total_samples == 0:
        return {
            "total_samples": 0,
            "answered_samples": 0,
            "correct_answers": 0,
            "accuracy": 0.0,
        }

    answered_samples = sum(1 for x in results if x.get("model_answer") is not None)
    correct_answers = sum(1 for x in results if x.get("is_correct"))

    accuracy = correct_answers / answered_samples if answered_samples > 0 else 0.0

    return {
        "total_samples": total_samples,
        "answered_samples": answered_samples,
        "correct_answers": correct_answers,
        "accuracy": accuracy,
    }
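
# Worked example: 3 results where 2 have a parsed model_answer and 1 of those
# is correct -> total_samples=3, answered_samples=2, correct_answers=1,
# accuracy = 1 / 2 = 0.5.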
|
|
|
|
def call_single_model(client, messages, model, item_id, max_retry_times):
    """
    Make a single API call to the specified model with retry logic.
    """
    # Use a smaller completion-token cap for Doubao models.
    if "doubao" in model:
        max_tokens = 32768
    else:
        max_tokens = 65535
    retry_times = 0
    while retry_times < max_retry_times:
        try:
            completion = client.chat.completions.create(
                model=model, messages=messages, max_tokens=max_tokens
            )
            return completion.choices[0].message.content
        except Exception as e:
            retry_times += 1
            print(
                f"Error processing item {item_id} with model {model}: {str(e)}. "
                f"Retrying ({retry_times}/{max_retry_times})..."
            )
            if retry_times == max_retry_times:
                error_log_file = f"error_log_{model.replace('/', '_')}.txt"
                with open(error_log_file, "a") as f:
                    f.write(
                        f"Error processing item {item_id} with model {model} "
                        f"after {max_retry_times} retries: {str(e)}\n"
                    )
                return None
            # Fixed 5-second backoff between retries.
            time.sleep(5)
|
|
|
|
def evaluate_single_item(
    data_item, frames, target_model, api_key, base_url, max_retry_times
):
    """
    Evaluate a single data item using the target model.
    """
    # Pick the client implementation based on the endpoint: Ark, an
    # OpenAI-compatible server, or Azure OpenAI.
    if "ark" in base_url:
        client = Ark(base_url=base_url, api_key=api_key)
    elif "aliyun" in base_url or "127.0.0.1" in base_url:
        client = OpenAI(api_key=api_key, base_url=base_url)
    else:
        client = AzureOpenAI(
            api_version="2023-05-15", api_key=api_key, azure_endpoint=base_url
        )

    messages = [
        {"role": "system", "content": REASONING_MULTIPLE_CHOICE_TEMPLATE},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Here are the video frames:"},
                *frames,
                {"type": "text", "text": f"Question: {data_item['question']}"},
            ],
        },
    ]

    response = call_single_model(
        client, messages, target_model, data_item["key"], max_retry_times
    )

    is_correct = False
    model_answer_cleaned = None
    parsed_json = None

    if response:
        parsed_json = extract_json_from_response(response)
        if parsed_json and "answer" in parsed_json:
            # Normalize case and surrounding whitespace before comparison.
            model_answer_cleaned = str(parsed_json["answer"]).strip().upper()
            gold_answer = data_item["answer"].strip().upper()
            if model_answer_cleaned == gold_answer:
                is_correct = True

    return {
        **data_item,
        "model_reasoning_and_answer": response,
        "model_answer_raw": parsed_json.get("answer") if parsed_json else None,
        "model_answer": model_answer_cleaned,
        "is_correct": is_correct,
    }
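
# The returned record extends the input item, e.g. (illustrative):
#   {..., "model_reasoning_and_answer": "<full response text>",
#    "model_answer_raw": "b", "model_answer": "B", "is_correct": True}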
|
|
|
|
def encode_image(image_path):
    """
    Encode an image file to a base64 string.
    """
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")
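
# Note: this returns raw base64 only; the "data:image/jpeg;base64," prefix is
# added where frames are packed into image_url message parts below.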
|
|
|
|
def process_frames_from_similarity_file(
    frames_base_path, frame_num, data_item, similarity_data
):
    """
    Select and encode the top N frames using a pre-computed similarity file.
    """
    item_key = data_item["key"]
    question_uid = str(data_item["uid"])

    # Frame filenames for this question, pre-sorted from most to least similar.
    sorted_filenames = similarity_data.get(question_uid)

    if not sorted_filenames:
        print(
            f"Warning: No similarity data found for question UID '{question_uid}', skipping."
        )
        return []

    try:
        # Take the top-N most similar frames, then re-sort them into temporal
        # order by their numeric frame id.
        num_frames_to_select = min(frame_num, len(sorted_filenames))
        selected_filenames = sorted_filenames[:num_frames_to_select]
        selected_ids = [int(f.split(".")[0].split("_")[-1]) for f in selected_filenames]
        selected_ids = sorted(selected_ids)
        selected_filenames = [f"frame_{i:06d}.jpg" for i in selected_ids]

        # Build absolute paths to the selected frames.
        video_frames_path = os.path.join(frames_base_path, item_key)
        sampled_paths = [os.path.join(video_frames_path, f) for f in selected_filenames]

        # Encode each frame as a base64 data URL for the chat API payload.
        base64_images = [encode_image(path) for path in sampled_paths]

        return [
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{b64_img}"},
            }
            for b64_img in base64_images
        ]
    except Exception as e:
        print(f"Error during frame processing for key '{item_key}': {e}")
        return []
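
# Expected similarity-file layout, inferred from the lookups above (illustrative):
# {
#     "<question uid>": ["frame_000042.jpg", "frame_000007.jpg", ...],  # most similar first
#     ...
# }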
|
|
|
|
def process_single_data(
    data_item,
    args,
    shared_results,
    progress_counter,
    total_items,
    locks,
    similarity_data,
):
    """
    Process a single data item in a multiprocessing context.
    """
    item_key = data_item["key"]
    try:
        # Select and encode frames using the pre-computed similarity ranking.
        frames = process_frames_from_similarity_file(
            args.frames_path, args.frame_num, data_item, similarity_data
        )

        if not frames:
            raise ValueError(
                f"No frames were processed from similarity file for key '{item_key}'"
            )

        result = evaluate_single_item(
            data_item,
            frames,
            args.target_model,
            args.api_key,
            args.base_url,
            args.max_retry_times,
        )

        if result is not None:
            # Append the result and checkpoint all results to disk so progress
            # survives an interruption.
            with locks["results"]:
                shared_results.append(result)
                data_filename_base = os.path.splitext(
                    os.path.basename(args.data_file)
                )[0]
                model_name_safe = args.target_model.replace("/", "_")
                output_prefix = f"{model_name_safe}_{data_filename_base}_{args.frame_num}frames_precomputed_similar"
                results_output_file = f"{output_prefix}_results.json"
                save_json_file(list(shared_results), results_output_file)

    except Exception as e:
        print(f"Error processing video key {item_key}: {str(e)}")
        with locks["file"]:
            error_log_file = f"error_log_{args.target_model.replace('/', '_')}.txt"
            with open(error_log_file, "a") as f:
                f.write(f"Critical error processing video key {item_key}: {str(e)}\n")
    finally:
        with locks["counter"]:
            progress_counter.value += 1
            print(
                f"\rProcessed: {progress_counter.value}/{total_items} videos...",
                end="",
                flush=True,
            )
|
|
|
|
def load_test_data(json_file):
    """
    Load test data from a JSON file.
    """
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: Data file not found at {json_file}")
        exit(1)
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from {json_file}")
        exit(1)
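
# Each item in the data file is expected to carry at least the fields used in
# this script (illustrative): {"key": "<video frame folder>", "uid": 123,
# "question": "...", "answer": "A"}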
|
|
|
|
def main():
    """
    Main function to run the video QA evaluation framework.
    """
    args = parse_arguments()

    print("--- Evaluation Configuration ---")
    print(f"Target Model: {args.target_model}")
    print(f"Frames to Sample (by pre-computed similarity): {args.frame_num}")
    print(f"Frames Base Path: {args.frames_path}")
    print(f"Similarity File: {args.similarity_file}")
    print(f"Data File: {args.data_file}")
    print(f"Parallel Processes: {args.pool_processes}")
    print("---------------------------------")

    error_log_file = f"error_log_{args.target_model.replace('/', '_')}.txt"
    with open(error_log_file, "w") as f:
        f.write(
            f"=== Error Log Started at {datetime.now()} for model {args.target_model} ===\n"
        )

    data_filename_base = os.path.splitext(os.path.basename(args.data_file))[0]
    model_name_safe = args.target_model.replace("/", "_")
    output_prefix = f"{model_name_safe}_{data_filename_base}_{args.frame_num}frames_precomputed_similar"

    results_output_file = f"{output_prefix}_results.json"
    metrics_output_file = f"{output_prefix}_metrics.json"

    # Load the evaluation data and the pre-computed similarity rankings.
    test_data = load_test_data(args.data_file)
    try:
        with open(args.similarity_file, "r", encoding="utf-8") as f:
            similarity_data = json.load(f)
    except FileNotFoundError:
        print(f"Error: Similarity file not found at {args.similarity_file}")
        exit(1)

    total_videos = len(test_data)
    print(f"\nLoaded {total_videos} videos to process.")

    with Manager() as manager:
        shared_results = manager.list()
        progress_counter = manager.Value("i", 0)
        locks = {
            "results": manager.Lock(),
            "file": manager.Lock(),
            "counter": manager.Lock(),
        }

        # Bind the shared state so each worker only receives its data item.
        process_func = partial(
            process_single_data,
            args=args,
            shared_results=shared_results,
            progress_counter=progress_counter,
            total_items=total_videos,
            locks=locks,
            similarity_data=similarity_data,
        )

        # Evaluate all items in parallel.
        with Pool(processes=args.pool_processes) as pool:
            pool.map(process_func, test_data)

        all_results = list(shared_results)

    print(f"\n\nProcessing complete for model: {args.target_model}")

    final_metrics = calculate_metrics(all_results)
    save_json_file(final_metrics, metrics_output_file)
    print(f"\nMetrics saved to: {metrics_output_file}")
    print(json.dumps(final_metrics, indent=4))

    save_json_file(all_results, results_output_file)
    print(f"Detailed results saved to: {results_output_file}")


if __name__ == "__main__":
    main()
|
|