File size: 2,198 Bytes
5dd1bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""Data/model loading helpers for the GRPO training notebook."""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any

from transformers import AutoModelForCausalLM, AutoTokenizer


def filter_questions_by_difficulty(
    questions: list[dict[str, Any]], allowed: list[str]
) -> list[dict[str, Any]]:
    """Return the question records whose difficulty is one of *allowed*.

    Matching is case-insensitive: both the allowed labels and each record's
    ``"difficulty"`` value are converted with ``str()`` and lower-cased before
    comparison. A record without a ``"difficulty"`` key is compared as the
    empty string.

    Args:
        questions: Question records. Entries that do not offer a
            mapping-style ``.get`` (e.g. stray strings, numbers or ``None``
            in a JSON array) are skipped.
        allowed: Difficulty labels to keep. A single bare string is treated
            as one label.

    Returns:
        A new list holding the matching records in their original order.
    """

    # A bare string would otherwise be iterated character by character and
    # silently match nothing.
    if isinstance(allowed, str):
        allowed = [allowed]

    # Coerce with str() so the allowed side is normalised exactly like the
    # record side below (labels may arrive as ints from a config file).
    allowed_set = {str(level).lower() for level in allowed}
    return [
        question
        for question in questions
        # Skip non-mapping entries instead of failing on ``.get``.
        if hasattr(question, "get")
        and str(question.get("difficulty", "")).lower() in allowed_set
    ]


def load_question_prompts(
    questions_path: str, allowed: list[str]
) -> list[dict[str, str]]:
    """Build prompt records from a JSON question file.

    The file is read as UTF-8 JSON, narrowed to the requested difficulty
    labels via ``filter_questions_by_difficulty``, and every remaining record
    with a truthy ``"question_text"`` becomes ``{"prompt": <text as str>}``.

    Args:
        questions_path: Location of the JSON file on disk.
        allowed: Difficulty labels passed through to the filter.

    Returns:
        One ``{"prompt": ...}`` dict per usable question.

    Raises:
        FileNotFoundError: If ``questions_path`` does not exist.
        ValueError: If the file is not valid JSON, the top-level value is not
            a non-empty list, no record matches ``allowed``, or none of the
            matching records has a usable ``"question_text"``.
    """

    source = Path(questions_path)
    if not source.exists():
        raise FileNotFoundError(f"Questions file not found: {questions_path}")

    raw_text = source.read_text(encoding="utf-8")
    try:
        records = json.loads(raw_text)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Invalid JSON in questions file: {questions_path}") from exc

    # The top-level JSON value must be a list with at least one entry.
    if not (isinstance(records, list) and records):
        raise ValueError(f"Questions file is empty or invalid: {questions_path}")

    matching = filter_questions_by_difficulty(records, allowed)
    if not matching:
        raise ValueError(
            f"No questions match difficulty_filter={allowed} in {questions_path}"
        )

    prompts: list[dict[str, str]] = []
    for record in matching:
        # Records with a missing or falsy question_text are left out.
        if record.get("question_text"):
            prompts.append({"prompt": str(record["question_text"])})

    if not prompts:
        raise ValueError(f"No usable question_text values found in {questions_path}")

    return prompts


def load_model_and_tokenizer(model_name: str) -> tuple[Any, Any]:
    """Load the causal-LM model and tokenizer identified by *model_name*.

    The tokenizer is loaded first, then the model, both through their
    ``from_pretrained`` constructors.

    Args:
        model_name: Identifier handed unchanged to both ``from_pretrained``
            calls.

    Returns:
        A ``(model, tokenizer)`` tuple -- the model comes first.

    Raises:
        RuntimeError: If either load raises any ``Exception``; the message
            names the model and the original exception is chained as the
            cause.
    """

    try:
        loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
        loaded_model = AutoModelForCausalLM.from_pretrained(model_name)
    except Exception as load_error:  # pragma: no cover - covered by monkeypatched tests
        message = f"Cannot load model '{model_name}': {load_error}"
        raise RuntimeError(message) from load_error
    else:
        return loaded_model, loaded_tokenizer