import base64
import json
import logging
import os
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List

import openai
from tqdm import tqdm


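# Configuration: DEBUG_MODE enables verbose logging, MODEL_NAME and
# TEMPERATURE control the OpenAI call, and SUBSET selects the chexbench task.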
DEBUG_MODE = False
OUTPUT_DIR = "results"
MODEL_NAME = "gpt-4o-2024-05-13"
TEMPERATURE = 0.2
SUBSET = "Visual Question Answering"


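# Logging verbosity follows DEBUG_MODE.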
logging_level = logging.DEBUG if DEBUG_MODE else logging.INFO
logging.basicConfig(level=logging_level, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


def get_mime_type(file_path: str) -> str:
    """
    Determine MIME type based on file extension.

    Args:
        file_path (str): Path to the file

    Returns:
        str: MIME type string for the file
    """
    extension = os.path.splitext(file_path)[1].lower()
    mime_types = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".gif": "image/gif",
    }
    return mime_types.get(extension, "application/octet-stream")


def encode_image(image_path: str) -> str:
    """
    Encode image to base64 with extensive error checking.

    Args:
        image_path (str): Path to the image file

    Returns:
        str: Base64 encoded image string

    Raises:
        FileNotFoundError: If image file does not exist
        ValueError: If image file is empty, too large, or in an unsupported format
        Exception: For other image processing errors
    """
    logger.debug(f"Attempting to read image from: {image_path}")
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")

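    # Reject empty files and anything over the 20MB size limit up front.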
    file_size = os.path.getsize(image_path)
    if file_size > 20 * 1024 * 1024:
        raise ValueError("Image file size exceeds 20MB limit")
    if file_size == 0:
        raise ValueError("Image file is empty")
    logger.debug(f"Image file size: {file_size / 1024:.2f} KB")

    try:
        # Imported lazily so Pillow is only required when images are encoded.
        from PIL import Image

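        # Verify the image is readable and in a supported format before encoding.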
        with Image.open(image_path) as img:
            width, height = img.size
            img_format = img.format
            mode = img.mode
            logger.debug(
                f"Image verification - Format: {img_format}, Size: {width}x{height}, Mode: {mode}"
            )

            if img_format not in ["PNG", "JPEG", "GIF"]:
                raise ValueError(f"Unsupported image format: {img_format}")

        with open(image_path, "rb") as image_file:
            # Read the first bytes for debug inspection, then rewind and encode.
            header = image_file.read(8)
            logger.debug(f"File header bytes: {header.hex()}")
            image_file.seek(0)
            encoded = base64.b64encode(image_file.read()).decode("utf-8")
            encoded_length = len(encoded)
            logger.debug(f"Base64 encoded length: {encoded_length} characters")

        # "/9j/", "iVBOR", and "R0lGOD" are the base64 prefixes of JPEG, PNG,
        # and GIF files respectively.
        if encoded_length == 0:
            raise ValueError("Base64 encoding produced empty string")
        if not encoded.startswith(("/9j/", "iVBOR", "R0lGOD")):
            logger.warning("Base64 string doesn't start with an expected JPEG, PNG, or GIF header")

        return encoded
    except Exception as e:
        logger.error(f"Error reading/encoding image: {str(e)}")
        raise


def create_single_request(
    image_path: str, question: str, options: Dict[str, str]
) -> List[Dict[str, Any]]:
    """
    Create a single API request with image and question.

    Args:
        image_path (str): Path to the image file
        question (str): Question text
        options (Dict[str, str]): Dictionary containing options with keys 'option_0' and 'option_1'

    Returns:
        List[Dict[str, Any]]: List of message dictionaries for the API request

    Raises:
        Exception: For errors in request creation
    """
    if DEBUG_MODE:
        logger.debug("Creating API request...")

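    # The prompt presents both options and constrains the answer to A or B.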
| prompt = f"""Given the following medical examination question: |
| Please answer this multiple choice question: |
| |
| Question: {question} |
| |
| Options: |
| A) {options['option_0']} |
| B) {options['option_1']} |
| |
| Base your answer only on the provided image and select either A or B.""" |

    try:
        encoded_image = encode_image(image_path)
        mime_type = get_mime_type(image_path)

        if DEBUG_MODE:
            logger.debug(f"Image encoded with MIME type: {mime_type}")

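        # The image is sent inline as a base64 data URL alongside the text prompt.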
        messages = [
            {
                "role": "system",
                "content": "You are taking a medical exam. Answer ONLY with the letter (A/B) corresponding to your answer.",
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:{mime_type};base64,{encoded_image}"},
                    },
                ],
            },
        ]

        if DEBUG_MODE:
            # Deep-copy via a JSON round-trip so the logged payload can be
            # truncated without mutating the real request.
            log_messages = json.loads(json.dumps(messages))
            log_messages[1]["content"][1]["image_url"][
                "url"
            ] = f"data:{mime_type};base64,[BASE64_IMAGE_TRUNCATED]"
            logger.debug(f"Complete API request payload:\n{json.dumps(log_messages, indent=2)}")

        return messages

    except Exception as e:
        logger.error(f"Error creating request: {str(e)}")
        raise


def check_answer(model_answer: str, correct_answer: int) -> bool:
    """
    Check if the model's answer matches the correct answer.

    Args:
        model_answer (str): The model's answer (A or B)
        correct_answer (int): The correct answer index (0 for A, 1 for B)

    Returns:
        bool: True if answer is correct, False otherwise
    """
    if not isinstance(model_answer, str):
        return False

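    # Match on the leading letter so responses like "A" or "A) ..." all count.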
    model_letter = model_answer.strip().upper()
    if model_letter.startswith("A"):
        model_index = 0
    elif model_letter.startswith("B"):
        model_index = 1
    else:
        return False

    return model_index == correct_answer


def save_results_to_json(results: List[Dict[str, Any]], output_dir: str) -> str:
    """
    Save results to a JSON file with timestamp.

    Args:
        results (List[Dict[str, Any]]): List of result dictionaries
        output_dir (str): Directory to save results

    Returns:
        str: Path to the saved file
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = os.path.join(output_dir, f"batch_results_{timestamp}.json")

    with open(output_file, "w") as f:
        json.dump(results, f, indent=2)

    logger.info(f"Batch results saved to {output_file}")
    return output_file


def calculate_accuracy(results: List[Dict[str, Any]]) -> tuple[float, int, int]:
    """
    Calculate accuracy from results, handling error cases.

    Args:
        results (List[Dict[str, Any]]): List of result dictionaries

    Returns:
        tuple[float, int, int]: Tuple containing (accuracy percentage, number correct, total)
    """
    if not results:
        return 0.0, 0, 0

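    # Errored examples lack an "output" key; they count toward the total but
    # never as correct, so failures lower the reported accuracy.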
    total = len(results)
    valid_results = [r for r in results if "output" in r]
    correct = sum(
        1 for result in valid_results if result.get("output", {}).get("is_correct", False)
    )

    accuracy = correct / total * 100
    return accuracy, correct, total


def calculate_batch_accuracy(results: List[Dict[str, Any]]) -> float:
    """
    Calculate accuracy for the current batch.

    Args:
        results (List[Dict[str, Any]]): List of result dictionaries

    Returns:
        float: Accuracy percentage for the batch
    """
    valid_results = [r for r in results if "output" in r]
    if not valid_results:
        return 0.0
    return sum(1 for r in valid_results if r["output"]["is_correct"]) / len(valid_results) * 100


def process_batch(
    data: List[Dict[str, Any]], client: openai.OpenAI, start_idx: int = 0, batch_size: int = 50
) -> List[Dict[str, Any]]:
    """
    Process a batch of examples and return results.

    Args:
        data (List[Dict[str, Any]]): List of data items to process
        client (openai.OpenAI): OpenAI client instance
        start_idx (int, optional): Starting index for batch. Defaults to 0
        batch_size (int, optional): Size of batch to process. Defaults to 50

    Returns:
        List[Dict[str, Any]]: List of processed results
    """
    batch_results = []
    end_idx = min(start_idx + batch_size, len(data))

    pbar = tqdm(
        range(start_idx, end_idx),
        desc=f"Processing batch {start_idx // batch_size + 1}",
        unit="example",
    )

    for index in pbar:
        vqa_item = data[index]
        options = {"option_0": vqa_item["option_0"], "option_1": vqa_item["option_1"]}

        try:
            messages = create_single_request(
                image_path=vqa_item["image_path"], question=vqa_item["question"], options=options
            )

            response = client.chat.completions.create(
                model=MODEL_NAME, messages=messages, max_tokens=50, temperature=TEMPERATURE
            )

            model_answer = response.choices[0].message.content.strip()
            is_correct = check_answer(model_answer, vqa_item["answer"])

            result = {
                "timestamp": datetime.now().isoformat(),
                "example_index": index,
                "input": {
                    "question": vqa_item["question"],
                    "options": {"A": vqa_item["option_0"], "B": vqa_item["option_1"]},
                    "image_path": vqa_item["image_path"],
                },
                "output": {
                    "model_answer": model_answer,
                    "correct_answer": "A" if vqa_item["answer"] == 0 else "B",
                    "is_correct": is_correct,
                    "usage": {
                        "prompt_tokens": response.usage.prompt_tokens,
                        "completion_tokens": response.usage.completion_tokens,
                        "total_tokens": response.usage.total_tokens,
                    },
                },
            }
            batch_results.append(result)

            # Update the progress bar with the running accuracy for this batch.
            current_accuracy = calculate_batch_accuracy(batch_results)
            pbar.set_description(
                f"Batch {start_idx // batch_size + 1} - Accuracy: {current_accuracy:.2f}% "
                f"({len(batch_results)}/{index - start_idx + 1} examples)"
            )

        except Exception as e:
            # Record failures alongside successes so they appear in the saved
            # results and count toward overall accuracy.
            error_result = {
                "timestamp": datetime.now().isoformat(),
                "example_index": index,
                "error": str(e),
                "input": {
                    "question": vqa_item["question"],
                    "options": {"A": vqa_item["option_0"], "B": vqa_item["option_1"]},
                    "image_path": vqa_item["image_path"],
                },
            }
            batch_results.append(error_result)
            if DEBUG_MODE:
                pbar.write(f"Error processing example {index}: {str(e)}")

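        # Brief pause between requests to stay under API rate limits.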
        time.sleep(1)

    return batch_results


def main() -> None:
    """
    Main function to process the entire dataset.

    Raises:
        ValueError: If OPENAI_API_KEY is not set
        Exception: For other processing errors
    """
    logger.info("Starting full dataset processing...")
    json_path = "../data/chexbench_updated.json"

    try:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable is not set.")
        client = openai.OpenAI(api_key=api_key)

        with open(json_path, "r") as f:
            data = json.load(f)

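        # Each subset item is expected to provide "image_path", "question",
        # "option_0", "option_1", and an integer "answer" (0 = A, 1 = B).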
        subset_data = data[SUBSET]
        total_examples = len(subset_data)
        logger.info(f"Found {total_examples} examples in {SUBSET} subset")

        all_results = []
        batch_size = 50

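        # Process the subset in fixed-size batches.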
        for start_idx in range(0, total_examples, batch_size):
            batch_results = process_batch(subset_data, client, start_idx, batch_size)
            all_results.extend(batch_results)

        # Save the combined results once all batches have finished.
        output_file = save_results_to_json(all_results, OUTPUT_DIR)

        # Report overall accuracy across the full run.
        overall_accuracy, correct, total = calculate_accuracy(all_results)
        logger.info(f"Overall Progress: {len(all_results)}/{total_examples} examples processed")
        logger.info(f"Current Accuracy: {overall_accuracy:.2f}% ({correct}/{total} correct)")

        logger.info("Processing completed!")
        logger.info(f"Final results saved to: {output_file}")

    except Exception as e:
        logger.error(f"Fatal error: {str(e)}")
        raise


if __name__ == "__main__":
    main()