"""Error-handling tests for GRPO notebook helpers.""" from __future__ import annotations from pathlib import Path import pytest from sql_env.training.data_loading import ( load_model_and_tokenizer, load_question_prompts, ) from sql_env.training.notebook_pipeline import format_oom_guidance from sql_env.training.rollout import parse_model_output def test_model_load_error_bad_name(monkeypatch) -> None: """Model loading failures include the configured model name.""" class _Tokenizer: @staticmethod def from_pretrained(model_name: str): del model_name return object() class _Model: @staticmethod def from_pretrained(model_name: str): raise RuntimeError(f"missing model {model_name}") import sql_env.training.data_loading as data_loading monkeypatch.setattr(data_loading, "AutoTokenizer", _Tokenizer) monkeypatch.setattr(data_loading, "AutoModelForCausalLM", _Model) with pytest.raises(RuntimeError, match="nonexistent/model-xyz-999"): load_model_and_tokenizer("nonexistent/model-xyz-999") def test_question_load_missing_file() -> None: with pytest.raises(FileNotFoundError, match="/nonexistent/questions.json"): load_question_prompts("/nonexistent/questions.json", ["easy", "medium"]) def test_question_load_empty_file(tmp_path: Path) -> None: path = tmp_path / "questions.json" path.write_text("[]", encoding="utf-8") with pytest.raises(ValueError, match="empty or invalid"): load_question_prompts(str(path), ["easy"]) def test_question_load_invalid_json(tmp_path: Path) -> None: path = tmp_path / "questions.json" path.write_text("{broken", encoding="utf-8") with pytest.raises(ValueError, match="Invalid JSON"): load_question_prompts(str(path), ["easy"]) def test_oom_guidance() -> None: message = format_oom_guidance(RuntimeError("CUDA out of memory")) assert "per_device_train_batch_size" in message assert "num_generations" in message def test_action_parse_fallback_logged(caplog) -> None: caplog.set_level("WARNING") action = parse_model_output("¯\\_(ツ)_/¯") assert action.action_type == "QUERY" assert action.argument == "¯\\_(ツ)_/¯" assert "falling back to QUERY" in caplog.text