File size: 2,317 Bytes
"""Error-handling tests for GRPO notebook helpers."""
from __future__ import annotations
from pathlib import Path
import pytest
from sql_env.training.data_loading import (
load_model_and_tokenizer,
load_question_prompts,
)
from sql_env.training.notebook_pipeline import format_oom_guidance
from sql_env.training.rollout import parse_model_output
def test_model_load_error_bad_name(monkeypatch) -> None:
    """A failing model load raises ``RuntimeError`` that names the model."""
    import sql_env.training.data_loading as data_loading

    bad_name = "nonexistent/model-xyz-999"

    class _StubTokenizer:
        """Tokenizer stand-in whose ``from_pretrained`` always succeeds."""

        @staticmethod
        def from_pretrained(model_name: str):
            del model_name
            return object()

    class _StubModel:
        """Model stand-in whose ``from_pretrained`` always raises."""

        @staticmethod
        def from_pretrained(model_name: str):
            # The error text echoes the requested name back to the caller.
            raise RuntimeError(f"missing model {model_name}")

    # Swap the loader classes referenced by the data_loading module for stubs.
    monkeypatch.setattr(data_loading, "AutoModelForCausalLM", _StubModel)
    monkeypatch.setattr(data_loading, "AutoTokenizer", _StubTokenizer)

    with pytest.raises(RuntimeError, match=bad_name):
        load_model_and_tokenizer(bad_name)
def test_question_load_missing_file() -> None:
    """A missing questions file raises ``FileNotFoundError`` naming the path."""
    import re

    missing_path = "/nonexistent/questions.json"
    # ``match`` is applied as a regular expression, so escape the path to make
    # the "." match a literal dot instead of acting as a wildcard.
    with pytest.raises(FileNotFoundError, match=re.escape(missing_path)):
        load_question_prompts(missing_path, ["easy", "medium"])
def test_question_load_empty_file(tmp_path: Path) -> None:
    """A questions file holding an empty JSON list raises ``ValueError``."""
    questions_file = tmp_path / "questions.json"
    questions_file.write_text("[]", encoding="utf-8")

    with pytest.raises(ValueError, match="empty or invalid"):
        load_question_prompts(str(questions_file), ["easy"])
def test_question_load_invalid_json(tmp_path: Path) -> None:
    """A questions file with malformed JSON raises ``ValueError``."""
    malformed_file = tmp_path / "questions.json"
    # "{broken" is deliberately not parseable as JSON.
    malformed_file.write_text("{broken", encoding="utf-8")

    with pytest.raises(ValueError, match="Invalid JSON"):
        load_question_prompts(str(malformed_file), ["easy"])
def test_oom_guidance() -> None:
    """OOM guidance text mentions both training settings checked below."""
    oom_error = RuntimeError("CUDA out of memory")
    guidance = format_oom_guidance(oom_error)

    for setting in ("per_device_train_batch_size", "num_generations"):
        assert setting in guidance
def test_action_parse_fallback_logged(caplog) -> None:
    """The shrug string comes back verbatim as a QUERY and the fallback is logged."""
    raw_output = "¯\\_(ツ)_/¯"
    # Capture records at WARNING and above so the fallback message is visible.
    caplog.set_level("WARNING")

    parsed = parse_model_output(raw_output)

    assert (parsed.action_type, parsed.argument) == ("QUERY", raw_output)
    assert "falling back to QUERY" in caplog.text
|