"""E2E-style smoke coverage for the GRPO training notebook.""" from __future__ import annotations import json from pathlib import Path from sql_env.training import rollout as rollout_module from sql_env.training.config import GRPOConfig from sql_env.training.notebook_pipeline import ( build_trainer, run_training_with_metrics, sample_random_baseline, ) from sql_env.training.data_loading import filter_questions_by_difficulty from sql_env.training.rewards import ( reward_correctness, reward_operational, reward_progress, ) from sql_env.training.rollout import rollout_func NOTEBOOK_PATH = Path("notebooks/train_grpo.ipynb") def _read_notebook() -> dict: return json.loads(NOTEBOOK_PATH.read_text(encoding="utf-8")) def _code_sources(notebook: dict) -> list[str]: cells = notebook.get("cells", []) return [ "".join(cell.get("source", [])) for cell in cells if cell.get("cell_type") == "code" ] def test_training_notebook_smoke_structure() -> None: """Notebook includes the core GRPO training flow cells.""" assert NOTEBOOK_PATH.exists(), "notebooks/train_grpo.ipynb must exist" notebook = _read_notebook() sources = "\n".join(_code_sources(notebook)) assert "GRPOConfig(" in sources assert "load_model_and_tokenizer(config.model_name)" in sources assert "grpo_trainer_cls=GRPOTrainer" in sources assert "run_training_with_metrics" in sources assert "matplotlib.pyplot as plt" in sources before_index = sources.find("before_rollouts = sample_random_baseline") train_index = sources.find("run_training_with_metrics(trainer)") assert before_index != -1 assert train_index != -1 assert before_index < train_index def test_question_filtering_by_difficulty() -> None: """Difficulty filtering keeps only questions in the allowed set.""" questions = [ {"question_text": "q1", "difficulty": "easy"}, {"question_text": "q2", "difficulty": "medium"}, {"question_text": "q3", "difficulty": "hard"}, ] filtered = filter_questions_by_difficulty(questions, ["easy"]) assert [item["question_text"] for item in filtered] == ["q1"] class _FakeTokenizer: def apply_chat_template( self, messages: list[dict[str, str]], tokenize: bool = False, add_generation_prompt: bool = True, ) -> str: del messages del tokenize del add_generation_prompt return "prompt" class _FakeModel: def __init__(self) -> None: self._count = 0 def generate(self, prompt: str, max_new_tokens: int) -> str: del prompt del max_new_tokens self._count += 1 if self._count == 1: return "QUERY: SELECT 1" return "ANSWER: 42" class _FakeEnvironment: def __init__(self, step_budget: int) -> None: self.step_budget = step_budget self.step_count = 0 self.state = type("State", (), {"episode_id": "ep-e2e"})() def reset(self, *, seed: int | None = None): del seed self.step_count = 0 return self._observation(done=False, result="") def step(self, action): self.step_count += 1 if getattr(action, "action_type", "") == "ANSWER": return self._observation( done=True, result="Answer submitted: correct.", reward=1.0 ) return self._observation(done=False, result="ok", reward=0.1) def _observation(self, done: bool, result: str, reward: float | None = 0.0): from sql_env.models import SQLObservation return SQLObservation( question="How many rows?", schema_info="Available tables:\n- t", result=result, error="", step_count=self.step_count, budget_remaining=max(0, self.step_budget - self.step_count), action_history=[], done=done, reward=reward, ) def test_training_pipeline_smoke(monkeypatch) -> None: """Happy-path rollout + reward computation produces trainable signals.""" config = GRPOConfig( questions_path="data/questions/questions_train.json", db_dir="data/databases", output_dir="outputs/grpo_test", step_budget=2, ) tokenizer = _FakeTokenizer() model = _FakeModel() fake_env = _FakeEnvironment(step_budget=2) monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env) rollouts = rollout_func(["Count rows"], model, tokenizer, config) assert len(rollouts) == 1 metadata = [item["metadata"] for item in rollouts] completions = [ [{"role": "assistant", "content": item["content"]}] for item in rollouts ] correctness = reward_correctness(completions, metadata=metadata) progress = reward_progress(completions, metadata=metadata) operational = reward_operational(completions, metadata=metadata) assert correctness == [1.0] assert len(progress) == 1 assert 0.0 <= progress[0] <= 1.0 assert len(operational) == 1 class _FakeTRLConfig: def __init__(self, **kwargs): self.kwargs = kwargs class _FakeTrainer: def __init__( self, *, model, processing_class, args, train_dataset, reward_funcs, ) -> None: self.model = model self.processing_class = processing_class self.args = args self.train_dataset = train_dataset self.reward_funcs = reward_funcs self.state = type("State", (), {"log_history": []})() self.train_called = False def train(self) -> dict[str, str]: self.train_called = True self.state.log_history = [{"step": 1, "reward": 0.25}] return {"status": "ok"} def test_notebook_pipeline_executes_training_step(monkeypatch) -> None: """Notebook pipeline helper builds trainer and executes train().""" config = GRPOConfig( questions_path="data/questions/questions_train.json", db_dir="data/databases", output_dir="outputs/grpo_test", step_budget=2, ) tokenizer = _FakeTokenizer() model = _FakeModel() fake_env = _FakeEnvironment(step_budget=2) monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env) trainer = build_trainer( model=model, tokenizer=tokenizer, prompts=[{"prompt": "Count rows"}], config=config, trl_grpo_config_cls=_FakeTRLConfig, grpo_trainer_cls=_FakeTrainer, reward_funcs=[reward_correctness, reward_progress, reward_operational], ) output, steps, rewards = run_training_with_metrics(trainer) assert trainer.train_called is True assert output == {"status": "ok"} assert steps == [1] assert rewards == [0.25] def test_random_baseline_transcripts_are_generated() -> None: """Random baseline helper generates readable transcripts per prompt.""" baseline = sample_random_baseline(["q1", "q2"], step_budget=3, seed=7) assert len(baseline) == 2 assert all(item["metadata"]["policy"] == "random" for item in baseline) assert all(item["completion"] for item in baseline)