| """Error-handling tests for GRPO notebook helpers.""" |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
| import pytest |
|
|
| from sql_env.training.data_loading import ( |
| load_model_and_tokenizer, |
| load_question_prompts, |
| ) |
| from sql_env.training.notebook_pipeline import format_oom_guidance |
| from sql_env.training.rollout import parse_model_output |
|
|
|
|
def test_model_load_error_bad_name(monkeypatch) -> None:
    """Check that a failed model load surfaces the requested model name.

    The ``AutoTokenizer`` and ``AutoModelForCausalLM`` attributes of the
    ``data_loading`` module are replaced with local stubs: the tokenizer
    stub returns a placeholder object, while the model stub raises a
    ``RuntimeError`` whose message contains the name it was given.
    """
    import sql_env.training.data_loading as loading_module

    bad_name = "nonexistent/model-xyz-999"

    class _FailingModel:
        @staticmethod
        def from_pretrained(model_name: str):
            raise RuntimeError(f"missing model {model_name}")

    class _StubTokenizer:
        @staticmethod
        def from_pretrained(model_name: str):
            # The name is irrelevant for the tokenizer stub; discard it.
            del model_name
            return object()

    stubs = {
        "AutoTokenizer": _StubTokenizer,
        "AutoModelForCausalLM": _FailingModel,
    }
    for attribute, stub in stubs.items():
        monkeypatch.setattr(loading_module, attribute, stub)

    with pytest.raises(RuntimeError, match=bad_name):
        load_model_and_tokenizer(bad_name)
|
|
|
|
def test_question_load_missing_file() -> None:
    """A nonexistent questions path raises ``FileNotFoundError`` naming it."""
    # ``match`` is a regular expression searched against the message, so the
    # dot is escaped to require the literal path instead of any character.
    with pytest.raises(FileNotFoundError, match=r"/nonexistent/questions\.json"):
        load_question_prompts("/nonexistent/questions.json", ["easy", "medium"])
|
|
|
|
def test_question_load_empty_file(tmp_path: Path) -> None:
    """A questions file holding an empty JSON array raises ``ValueError``."""
    questions_file = tmp_path / "questions.json"
    questions_file.write_text("[]", encoding="utf-8")
    file_arg = str(questions_file)

    with pytest.raises(ValueError, match="empty or invalid"):
        load_question_prompts(file_arg, ["easy"])
|
|
|
|
def test_question_load_invalid_json(tmp_path: Path) -> None:
    """A questions file with malformed JSON raises ``ValueError``."""
    broken_file = tmp_path / "questions.json"
    broken_file.write_text("{broken", encoding="utf-8")
    file_arg = str(broken_file)

    with pytest.raises(ValueError, match="Invalid JSON"):
        load_question_prompts(file_arg, ["easy"])
|
|
|
|
def test_oom_guidance() -> None:
    """The OOM guidance text mentions both tunable setting names."""
    guidance = format_oom_guidance(RuntimeError("CUDA out of memory"))

    for setting in ("per_device_train_batch_size", "num_generations"):
        assert setting in guidance
|
|
|
|
def test_action_parse_fallback_logged(caplog) -> None:
    """The raw text comes back as a QUERY action and the fallback is logged."""
    raw_output = "¯\\_(ツ)_/¯"
    caplog.set_level("WARNING")

    parsed = parse_model_output(raw_output)

    # The argument must be the untouched input string.
    assert (parsed.action_type, parsed.argument) == ("QUERY", raw_output)
    assert "falling back to QUERY" in caplog.text
|
|