| """Data/model loading helpers for the GRPO training notebook.""" |
|
|
| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
| from typing import Any |
|
|
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
def filter_questions_by_difficulty(
    questions: list[dict[str, Any]], allowed: list[str]
) -> list[dict[str, Any]]:
    """Return the question records whose difficulty label is allowed.

    Labels are compared after lower-casing both sides. A record with no
    ``difficulty`` key is treated as having the label ``""``. Non-string
    labels are converted with ``str()`` before comparison.

    Args:
        questions: Question records (dict-like, read via ``.get``).
        allowed: Difficulty labels to keep.

    Returns:
        A new list containing the matching records (same objects), in input order.
    """
    wanted = {label.lower() for label in allowed}

    kept: list[dict[str, Any]] = []
    for record in questions:
        label = str(record.get("difficulty", "")).lower()
        if label in wanted:
            kept.append(record)
    return kept
|
|
|
|
def load_question_prompts(
    questions_path: str, allowed: list[str]
) -> list[dict[str, str]]:
    """Load question prompts from a JSON file, keeping only allowed difficulties.

    The file must contain a non-empty JSON array of objects. Records are first
    filtered by their ``difficulty`` label (case-insensitive, via
    ``filter_questions_by_difficulty``). Each surviving record with a non-blank
    ``question_text`` is then turned into ``{"prompt": str(question_text)}``.

    Args:
        questions_path: Path to the JSON questions file (read as UTF-8).
        allowed: Difficulty labels to keep.

    Returns:
        A non-empty list of ``{"prompt": ...}`` dicts, in file order.

    Raises:
        FileNotFoundError: If ``questions_path`` is not an existing regular file.
        ValueError: If the file is not valid UTF-8 JSON, is not a non-empty
            array of objects, no record matches ``allowed``, or no matching
            record has a usable ``question_text``.
    """
    path = Path(questions_path)
    # is_file() rather than exists(): a directory would pass exists() and then
    # fail in read_text() with an unhelpful IsADirectoryError/PermissionError.
    if not path.is_file():
        raise FileNotFoundError(f"Questions file not found: {questions_path}")

    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        # A file that is not valid UTF-8 is reported the same way as bad JSON.
        raise ValueError(f"Invalid JSON in questions file: {questions_path}") from exc

    if not isinstance(payload, list) or not payload:
        raise ValueError(f"Questions file is empty or invalid: {questions_path}")
    # The difficulty filter calls .get() on every record, so reject
    # non-objects here with a clear message instead of an AttributeError.
    if not all(isinstance(item, dict) for item in payload):
        raise ValueError(
            f"Questions file must contain only JSON objects: {questions_path}"
        )

    filtered = filter_questions_by_difficulty(payload, allowed)
    if not filtered:
        raise ValueError(
            f"No questions match difficulty_filter={allowed} in {questions_path}"
        )

    # Skip records whose question_text is missing, falsy, or whitespace-only;
    # those would otherwise produce blank prompts.
    prompts = [
        {"prompt": str(item["question_text"])}
        for item in filtered
        if str(item.get("question_text") or "").strip()
    ]
    if not prompts:
        raise ValueError(f"No usable question_text values found in {questions_path}")

    return prompts
|
|
|
|
def load_model_and_tokenizer(model_name: str) -> tuple[Any, Any]:
    """Load a Hugging Face causal-LM model and its tokenizer, failing fast.

    Args:
        model_name: Identifier passed unchanged to both
            ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        RuntimeError: If either component fails to load. The message starts
            with ``Cannot load model '<model_name>'``, names the failing
            component, and the original exception is chained as ``__cause__``.
    """
    # Each step is wrapped separately so the error says whether the tokenizer
    # or the model weights failed, instead of a single generic message.
    # The broad ``except Exception`` is deliberate: this is the loader boundary,
    # and every failure is re-raised uniformly with its cause chained.
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    except Exception as exc:
        raise RuntimeError(
            f"Cannot load model '{model_name}' (tokenizer): {exc}"
        ) from exc

    try:
        model = AutoModelForCausalLM.from_pretrained(model_name)
    except Exception as exc:
        raise RuntimeError(
            f"Cannot load model '{model_name}' (model weights): {exc}"
        ) from exc

    return model, tokenizer
|
|