File size: 2,317 Bytes
5dd1bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""Error-handling tests for GRPO notebook helpers."""

from __future__ import annotations

from pathlib import Path

import pytest

from sql_env.training.data_loading import (
    load_model_and_tokenizer,
    load_question_prompts,
)
from sql_env.training.notebook_pipeline import format_oom_guidance
from sql_env.training.rollout import parse_model_output


def test_model_load_error_bad_name(monkeypatch) -> None:
    """Model loading failures include the configured model name."""

    class _StubTokenizer:
        # Tokenizer load succeeds so the failure is attributable to the model.
        @staticmethod
        def from_pretrained(name: str):
            del name
            return object()

    class _FailingModel:
        # Model load fails, echoing the requested name in the error message.
        @staticmethod
        def from_pretrained(name: str):
            raise RuntimeError(f"missing model {name}")

    import sql_env.training.data_loading as data_loading

    monkeypatch.setattr(data_loading, "AutoTokenizer", _StubTokenizer)
    monkeypatch.setattr(data_loading, "AutoModelForCausalLM", _FailingModel)

    bad_name = "nonexistent/model-xyz-999"
    with pytest.raises(RuntimeError, match=bad_name):
        load_model_and_tokenizer(bad_name)


def test_question_load_missing_file() -> None:
    """A nonexistent questions file raises FileNotFoundError naming the path."""
    missing = "/nonexistent/questions.json"
    with pytest.raises(FileNotFoundError, match=missing):
        load_question_prompts(missing, ["easy", "medium"])


def test_question_load_empty_file(tmp_path: Path) -> None:
    """An empty JSON question list is rejected with a ValueError."""
    questions_file = tmp_path / "questions.json"
    questions_file.write_text("[]", encoding="utf-8")
    with pytest.raises(ValueError, match="empty or invalid"):
        load_question_prompts(str(questions_file), ["easy"])


def test_question_load_invalid_json(tmp_path: Path) -> None:
    """Malformed JSON in the questions file is rejected with a ValueError."""
    questions_file = tmp_path / "questions.json"
    questions_file.write_text("{broken", encoding="utf-8")
    with pytest.raises(ValueError, match="Invalid JSON"):
        load_question_prompts(str(questions_file), ["easy"])


def test_oom_guidance() -> None:
    """OOM guidance names the batch-size and generation knobs to reduce."""
    guidance = format_oom_guidance(RuntimeError("CUDA out of memory"))
    for knob in ("per_device_train_batch_size", "num_generations"):
        assert knob in guidance


def test_action_parse_fallback_logged(caplog) -> None:
    """Unparseable model output falls back to a QUERY action and logs a warning."""
    raw = "¯\\_(ツ)_/¯"
    caplog.set_level("WARNING")

    action = parse_model_output(raw)

    assert (action.action_type, action.argument) == ("QUERY", raw)
    assert "falling back to QUERY" in caplog.text