"""Data/model loading helpers for the GRPO training notebook.""" from __future__ import annotations import json from pathlib import Path from typing import Any from transformers import AutoModelForCausalLM, AutoTokenizer def filter_questions_by_difficulty( questions: list[dict[str, Any]], allowed: list[str] ) -> list[dict[str, Any]]: """Filter question records by case-insensitive difficulty labels.""" allowed_set = {level.lower() for level in allowed} return [ question for question in questions if str(question.get("difficulty", "")).lower() in allowed_set ] def load_question_prompts( questions_path: str, allowed: list[str] ) -> list[dict[str, str]]: """Load question text prompts from JSON and apply difficulty filtering.""" path = Path(questions_path) if not path.exists(): raise FileNotFoundError(f"Questions file not found: {questions_path}") try: payload = json.loads(path.read_text(encoding="utf-8")) except json.JSONDecodeError as exc: raise ValueError(f"Invalid JSON in questions file: {questions_path}") from exc if not isinstance(payload, list) or not payload: raise ValueError(f"Questions file is empty or invalid: {questions_path}") filtered = filter_questions_by_difficulty(payload, allowed) if not filtered: raise ValueError( f"No questions match difficulty_filter={allowed} in {questions_path}" ) prompts = [ {"prompt": str(item["question_text"])} for item in filtered if item.get("question_text") ] if not prompts: raise ValueError(f"No usable question_text values found in {questions_path}") return prompts def load_model_and_tokenizer(model_name: str) -> tuple[Any, Any]: """Load HuggingFace tokenizer and model with fail-fast errors.""" try: tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained(model_name) except Exception as exc: # pragma: no cover - covered by monkeypatched tests raise RuntimeError(f"Cannot load model '{model_name}': {exc}") from exc return model, tokenizer