Spaces:
Sleeping
Sleeping
| import argparse | |
| import json | |
| import os | |
| from typing import Dict, List, Tuple | |
| import torch | |
| from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu | |
| from nltk.translate.meteor_score import single_meteor_score | |
| from rouge_score import rouge_scorer | |
| from .config import PathsConfig, TrainingConfig, get_device, set_seed | |
| from .dataset import create_dataloader, create_tokenizer | |
| from .model import ImageCaptioningModel | |
def parse_args() -> argparse.Namespace:
    """
    Build the evaluation command-line interface and parse ``sys.argv``.

    Returns:
        Namespace with ``data_root``, ``checkpoint`` (required), ``batch_size``,
        ``max_length``, ``num_beams``, ``seed`` and ``output_samples``.
    """
    cli = argparse.ArgumentParser(description="Evaluate image captioning model on test set.")

    # (flag, add_argument keyword arguments), registered in this order so the
    # generated --help output keeps the same option ordering.
    option_table = (
        (
            "--data_root",
            {"type": str, "default": "/Users/ryan/Downloads/visuallyimpair", "help": "Root path to dataset."},
        ),
        (
            "--checkpoint",
            {"type": str, "required": True, "help": "Path to model checkpoint (.pt)."},
        ),
        (
            "--batch_size",
            {"type": int, "default": 16, "help": "Batch size for evaluation."},
        ),
        (
            "--max_length",
            {"type": int, "default": 50, "help": "Maximum caption length during generation."},
        ),
        (
            "--num_beams",
            {"type": int, "default": 3, "help": "Number of beams for beam search."},
        ),
        (
            "--seed",
            {"type": int, "default": 42, "help": "Random seed."},
        ),
        (
            "--output_samples",
            {"type": str, "default": "evaluation_samples.jsonl", "help": "File to save sample predictions."},
        ),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()
def _unigram_f1(ref_tokens: List[str], hyp_tokens: List[str]) -> float:
    """Return the F1 overlap between the unique tokens of a reference and a hypothesis."""
    ref_set = set(ref_tokens)
    hyp_set = set(hyp_tokens)
    if not ref_set or not hyp_set:
        return 0.0
    overlap = len(ref_set & hyp_set)
    if overlap == 0:
        return 0.0
    precision = overlap / len(hyp_set)
    recall = overlap / len(ref_set)
    return 2 * precision * recall / (precision + recall)


def compute_metrics(
    references: List[List[str]],
    hypotheses: List[str],
) -> Dict[str, float]:
    """
    Compute BLEU (1-4), METEOR, and ROUGE-L metrics for generated captions.

    Args:
        references: For each sample, a non-empty list of reference caption strings.
        hypotheses: For each sample, the generated caption string; aligned
            index-by-index with ``references``.

    Returns:
        Mapping with keys ``"BLEU-1"``, ``"BLEU-2"``, ``"BLEU-3"``, ``"BLEU-4"``,
        ``"METEOR"`` and ``"ROUGE-L"``. BLEU is corpus-level over all references
        of each sample; METEOR and ROUGE-L are the mean of per-sample scores
        computed against the first reference only.

    Raises:
        ValueError: If either input is empty, their lengths differ, or any
            sample has no reference caption.
    """
    if not references or not hypotheses:
        raise ValueError("References and hypotheses must be non-empty.")
    if len(references) != len(hypotheses):
        raise ValueError("Number of references and hypotheses must match.")
    if any(not ref_list for ref_list in references):
        raise ValueError("Each sample must have at least one reference caption.")

    # corpus_bleu operates on token sequences. Feeding it raw strings makes it
    # build n-grams over characters, so tokenize by whitespace first to get
    # word-level BLEU.
    tokenized_refs = [[ref.split() for ref in ref_list] for ref_list in references]
    tokenized_hyps = [hyp.split() for hyp in hypotheses]

    smoothie = SmoothingFunction().method4
    bleu_weights = {
        "BLEU-1": (1.0, 0.0, 0.0, 0.0),
        "BLEU-2": (0.5, 0.5, 0.0, 0.0),
        "BLEU-3": (1.0 / 3, 1.0 / 3, 1.0 / 3, 0.0),
        "BLEU-4": (0.25, 0.25, 0.25, 0.25),
    }
    metrics: Dict[str, float] = {
        name: corpus_bleu(
            tokenized_refs,
            tokenized_hyps,
            weights=weights,
            smoothing_function=smoothie,
        )
        for name, weights in bleu_weights.items()
    }

    # METEOR against the first reference. If NLTK's WordNet data is missing
    # (LookupError), fall back to a simple unigram F1 and stop retrying the
    # WordNet lookup for the remaining samples.
    meteor_scores: List[float] = []
    wordnet_available = True
    for ref_token_lists, hyp_tokens in zip(tokenized_refs, tokenized_hyps):
        ref_tokens = ref_token_lists[0]
        if wordnet_available:
            try:
                meteor_scores.append(single_meteor_score(ref_tokens, hyp_tokens))
                continue
            except LookupError:
                wordnet_available = False
        meteor_scores.append(_unigram_f1(ref_tokens, hyp_tokens))
    metrics["METEOR"] = sum(meteor_scores) / len(meteor_scores)

    # ROUGE-L F-measure against the first reference; the scorer takes raw
    # strings and does its own tokenization and stemming.
    rouge = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
    rouge_l_scores = [
        rouge.score(ref_list[0], hyp)["rougeL"].fmeasure
        for ref_list, hyp in zip(references, hypotheses)
    ]
    metrics["ROUGE-L"] = sum(rouge_l_scores) / len(rouge_l_scores)

    return metrics
def run_evaluation(args: argparse.Namespace) -> None:
    """
    Run evaluation on the test set, compute metrics, and save sample predictions.

    Args:
        args: Parsed CLI namespace. Reads ``data_root``, ``batch_size``,
            ``max_length``, ``num_beams``, ``seed``, ``checkpoint`` and
            ``output_samples``.

    Side effects:
        Prints the metrics to stdout and writes up to 50 sample predictions,
        one JSON object per line, to ``args.output_samples``.
    """
    paths_cfg = PathsConfig(data_root=args.data_root)
    training_cfg = TrainingConfig(
        batch_size=args.batch_size,
        max_caption_length=args.max_length,
        num_epochs=1,
    )
    set_seed(args.seed)
    device = get_device()

    tokenizer = create_tokenizer()
    test_loader, tokenizer = create_dataloader(
        paths_cfg=paths_cfg,
        training_cfg=training_cfg,
        split="test",
        tokenizer=tokenizer,
        shuffle=False,
    )

    model = ImageCaptioningModel(training_cfg=training_cfg)
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from
    # a trusted source. Consider weights_only=True if the supported torch
    # version provides it.
    state_dict = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    references: List[List[str]] = []
    hypotheses: List[str] = []
    num_samples_to_save = 50
    saved_samples: List[Dict[str, object]] = []

    with torch.no_grad():
        for batch in test_loader:
            images = batch["image"].to(device)
            # Use the raw caption string from the dataset as reference
            captions = batch["caption"]
            image_ids = batch["image_id"]
            # Generate predictions one image at a time to respect generate() constraints
            for idx in range(images.size(0)):
                single_image = images[idx : idx + 1]
                ref_caption = captions[idx]
                pred_text_list = model.generate(
                    images=single_image,
                    max_length=args.max_length,
                    num_beams=args.num_beams,
                )
                pred_text = pred_text_list[0]
                references.append([ref_caption])
                hypotheses.append(pred_text)

                if len(saved_samples) < num_samples_to_save:
                    image_id = image_ids[idx]
                    if isinstance(image_id, torch.Tensor):
                        # Numeric ids may arrive collated into a tensor (as
                        # PyTorch's default collate does); json.dumps cannot
                        # serialize tensors, so unwrap to a Python scalar.
                        image_id = image_id.item()
                    saved_samples.append(
                        {
                            "image_id": image_id,
                            "reference": ref_caption,
                            "prediction": pred_text,
                        }
                    )

    metrics = compute_metrics(references, hypotheses)
    print("Evaluation metrics:")
    for name, value in metrics.items():
        print(f" {name}: {value:.4f}")

    # Save sample predictions
    output_path = args.output_samples
    # Create the target directory up front so a missing folder does not make
    # the write fail after the whole evaluation has already run.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(sample) + "\n" for sample in saved_samples)
    print(f"Saved {len(saved_samples)} sample predictions to {output_path}")
def main() -> None:
    """
    Script entry point: parse CLI arguments, check the checkpoint path, evaluate.

    Raises:
        FileNotFoundError: If the path given by ``--checkpoint`` does not exist.
    """
    args = parse_args()
    if os.path.exists(args.checkpoint):
        run_evaluation(args)
        return
    # Fail before any dataset or model setup when the checkpoint is missing.
    raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint}")
# Run the evaluation only when this file is executed as a script, not when it
# is imported. NOTE(review): the module uses package-relative imports
# (from .config / .dataset / .model), so it presumably has to be launched as a
# module (python -m <package>.<module>) — confirm against the project layout.
if __name__ == "__main__":
    main()