"""Evaluate an image captioning model on the test split with BLEU, METEOR, and ROUGE-L."""

import argparse
import json
import os
from typing import Dict, List

import torch
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu
from nltk.translate.meteor_score import single_meteor_score
from rouge_score import rouge_scorer

from .config import PathsConfig, TrainingConfig, get_device, set_seed
from .dataset import create_dataloader, create_tokenizer
from .model import ImageCaptioningModel


def parse_args() -> argparse.Namespace:
    """
    Parse command-line arguments for evaluation.
    """
    parser = argparse.ArgumentParser(description="Evaluate image captioning model on test set.")
    parser.add_argument("--data_root", type=str, default="/Users/ryan/Downloads/visuallyimpair",
                        help="Root path to dataset.")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to model checkpoint (.pt).")
    parser.add_argument("--batch_size", type=int, default=16,
                        help="Batch size for evaluation.")
    parser.add_argument("--max_length", type=int, default=50,
                        help="Maximum caption length during generation.")
    parser.add_argument("--num_beams", type=int, default=3,
                        help="Number of beams for beam search.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument("--output_samples", type=str, default="evaluation_samples.jsonl",
                        help="File to save sample predictions.")
    return parser.parse_args()


def compute_metrics(
    references: List[List[str]],
    hypotheses: List[str],
) -> Dict[str, float]:
    """
    Compute BLEU (1-4), METEOR, and ROUGE-L metrics.
    """
    if not references or not hypotheses:
        raise ValueError("References and hypotheses must be non-empty.")
    if len(references) != len(hypotheses):
        raise ValueError("Number of references and hypotheses must match.")

    # NLTK's corpus_bleu expects lists of tokens, not raw strings; passing
    # strings would silently compute character-level BLEU. Tokenize by
    # whitespace so BLEU is computed over words.
    tokenized_refs = [[ref.split() for ref in ref_list] for ref_list in references]
    tokenized_hyps = [hyp.split() for hyp in hypotheses]

    smoothie = SmoothingFunction().method4

    # BLEU scores
    bleu1 = corpus_bleu(tokenized_refs, tokenized_hyps,
                        weights=(1.0, 0.0, 0.0, 0.0), smoothing_function=smoothie)
    bleu2 = corpus_bleu(tokenized_refs, tokenized_hyps,
                        weights=(0.5, 0.5, 0.0, 0.0), smoothing_function=smoothie)
    bleu3 = corpus_bleu(tokenized_refs, tokenized_hyps,
                        weights=(1.0 / 3, 1.0 / 3, 1.0 / 3, 0.0), smoothing_function=smoothie)
    bleu4 = corpus_bleu(tokenized_refs, tokenized_hyps,
                        weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smoothie)

    # METEOR
    meteor_scores: List[float] = []
    for ref_list, hyp in zip(references, hypotheses):
        # Use the first reference for METEOR; tokenize by simple whitespace.
        # If NLTK's WordNet data is missing, fall back to a simple unigram F1.
        ref_tokens = ref_list[0].split()
        hyp_tokens = hyp.split()
        try:
            meteor_scores.append(single_meteor_score(ref_tokens, hyp_tokens))
        except LookupError:
            ref_set = set(ref_tokens)
            hyp_set = set(hyp_tokens)
            if not ref_set or not hyp_set:
                meteor_scores.append(0.0)
            else:
                overlap = len(ref_set & hyp_set)
                precision = overlap / len(hyp_set)
                recall = overlap / len(ref_set)
                if precision + recall == 0:
                    meteor_scores.append(0.0)
                else:
                    meteor_scores.append(2 * precision * recall / (precision + recall))
    meteor = sum(meteor_scores) / max(1, len(meteor_scores))

    # ROUGE-L
    rouge = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
    rouge_l_scores: List[float] = []
    for ref_list, hyp in zip(references, hypotheses):
        scores = rouge.score(ref_list[0], hyp)
        rouge_l_scores.append(scores["rougeL"].fmeasure)
    rouge_l = sum(rouge_l_scores) / max(1, len(rouge_l_scores))

    return {
        "BLEU-1": bleu1,
        "BLEU-2": bleu2,
        "BLEU-3": bleu3,
        "BLEU-4": bleu4,
        "METEOR": meteor,
        "ROUGE-L": rouge_l,
    }
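# A quick sanity check for compute_metrics, using made-up captions rather than
# dataset output: an identical reference/hypothesis pair should score 1.0 on
# BLEU and ROUGE-L, and just under 1.0 on METEOR (its fragmentation penalty
# applies even to exact matches).
#
#     metrics = compute_metrics([["a dog runs on the grass"]],
#                               ["a dog runs on the grass"])
#     assert metrics["ROUGE-L"] == 1.0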
""" paths_cfg = PathsConfig(data_root=args.data_root) training_cfg = TrainingConfig( batch_size=args.batch_size, max_caption_length=args.max_length, num_epochs=1, ) set_seed(args.seed) device = get_device() tokenizer = create_tokenizer() test_loader, tokenizer = create_dataloader( paths_cfg=paths_cfg, training_cfg=training_cfg, split="test", tokenizer=tokenizer, shuffle=False, ) model = ImageCaptioningModel(training_cfg=training_cfg) state_dict = torch.load(args.checkpoint, map_location=device) model.load_state_dict(state_dict) model.to(device) model.eval() references: List[List[str]] = [] hypotheses: List[str] = [] num_samples_to_save = 50 saved_samples: List[Dict[str, str]] = [] with torch.no_grad(): for batch in test_loader: images = batch["image"].to(device) # Use the raw caption string from the dataset as reference captions = batch["caption"] # Generate predictions one image at a time to respect generate() constraints for idx in range(images.size(0)): single_image = images[idx : idx + 1] ref_caption = captions[idx] pred_text_list = model.generate( images=single_image, max_length=args.max_length, num_beams=args.num_beams, ) pred_text = pred_text_list[0] references.append([ref_caption]) hypotheses.append(pred_text) if len(saved_samples) < num_samples_to_save: saved_samples.append( { "image_id": batch["image_id"][idx], "reference": ref_caption, "prediction": pred_text, } ) metrics = compute_metrics(references, hypotheses) print("Evaluation metrics:") for name, value in metrics.items(): print(f" {name}: {value:.4f}") # Save sample predictions output_path = args.output_samples with open(output_path, "w", encoding="utf-8") as f: for sample in saved_samples: f.write(json.dumps(sample) + "\n") print(f"Saved {len(saved_samples)} sample predictions to {output_path}") def main() -> None: args = parse_args() if not os.path.exists(args.checkpoint): raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint}") run_evaluation(args) if __name__ == "__main__": main()