# sql_env/tests/e2e/test_training_e2e.py
# NOTE: uploaded via huggingface_hub (revision 5dd1bb4, verified).
"""E2E-style smoke coverage for the GRPO training notebook."""
from __future__ import annotations
import json
from pathlib import Path
from sql_env.training import rollout as rollout_module
from sql_env.training.config import GRPOConfig
from sql_env.training.notebook_pipeline import (
build_trainer,
run_training_with_metrics,
sample_random_baseline,
)
from sql_env.training.data_loading import filter_questions_by_difficulty
from sql_env.training.rewards import (
reward_correctness,
reward_operational,
reward_progress,
)
from sql_env.training.rollout import rollout_func
NOTEBOOK_PATH = Path("notebooks/train_grpo.ipynb")
def _read_notebook() -> dict:
    """Load the training notebook from disk and parse it as JSON."""
    raw = NOTEBOOK_PATH.read_text(encoding="utf-8")
    return json.loads(raw)
def _code_sources(notebook: dict) -> list[str]:
cells = notebook.get("cells", [])
return [
"".join(cell.get("source", []))
for cell in cells
if cell.get("cell_type") == "code"
]
def test_training_notebook_smoke_structure() -> None:
    """Notebook includes the core GRPO training flow cells."""
    assert NOTEBOOK_PATH.exists(), "notebooks/train_grpo.ipynb must exist"
    joined = "\n".join(_code_sources(_read_notebook()))
    # Core building blocks of the training flow must all be present.
    expected_snippets = (
        "GRPOConfig(",
        "load_model_and_tokenizer(config.model_name)",
        "grpo_trainer_cls=GRPOTrainer",
        "run_training_with_metrics",
        "matplotlib.pyplot as plt",
    )
    for snippet in expected_snippets:
        assert snippet in joined
    # The random baseline must be sampled before training runs.
    baseline_pos = joined.find("before_rollouts = sample_random_baseline")
    training_pos = joined.find("run_training_with_metrics(trainer)")
    assert baseline_pos != -1
    assert training_pos != -1
    assert baseline_pos < training_pos
def test_question_filtering_by_difficulty() -> None:
    """Difficulty filtering keeps only questions in the allowed set."""
    questions = [
        {"question_text": f"q{index}", "difficulty": level}
        for index, level in enumerate(["easy", "medium", "hard"], start=1)
    ]
    kept = filter_questions_by_difficulty(questions, ["easy"])
    assert [question["question_text"] for question in kept] == ["q1"]
class _FakeTokenizer:
def apply_chat_template(
self,
messages: list[dict[str, str]],
tokenize: bool = False,
add_generation_prompt: bool = True,
) -> str:
del messages
del tokenize
del add_generation_prompt
return "prompt"
class _FakeModel:
def __init__(self) -> None:
self._count = 0
def generate(self, prompt: str, max_new_tokens: int) -> str:
del prompt
del max_new_tokens
self._count += 1
if self._count == 1:
return "QUERY: SELECT 1"
return "ANSWER: 42"
class _FakeEnvironment:
    """In-memory environment stub: rewards 1.0 for ANSWER, 0.1 otherwise."""

    def __init__(self, step_budget: int) -> None:
        self.step_budget = step_budget
        self.step_count = 0
        # Anonymous state object mirroring the real environment's interface.
        self.state = type("State", (), {"episode_id": "ep-e2e"})()

    def reset(self, *, seed: int | None = None):
        # Seed is accepted for interface parity but has no effect here.
        del seed
        self.step_count = 0
        return self._observation(done=False, result="")

    def step(self, action):
        self.step_count += 1
        answered = getattr(action, "action_type", "") == "ANSWER"
        if answered:
            # Terminal transition: the episode ends with full reward.
            return self._observation(
                done=True, result="Answer submitted: correct.", reward=1.0
            )
        # Any non-answer action keeps the episode running with a small reward.
        return self._observation(done=False, result="ok", reward=0.1)

    def _observation(self, done: bool, result: str, reward: float | None = 0.0):
        # Imported lazily so constructing the stub needs no heavy imports.
        from sql_env.models import SQLObservation

        return SQLObservation(
            question="How many rows?",
            schema_info="Available tables:\n- t",
            result=result,
            error="",
            step_count=self.step_count,
            budget_remaining=max(0, self.step_budget - self.step_count),
            action_history=[],
            done=done,
            reward=reward,
        )
def test_training_pipeline_smoke(monkeypatch) -> None:
    """Happy-path rollout + reward computation produces trainable signals."""
    config = GRPOConfig(
        questions_path="data/questions/questions_train.json",
        db_dir="data/databases",
        output_dir="outputs/grpo_test",
        step_budget=2,
    )
    fake_env = _FakeEnvironment(step_budget=2)
    # Route the rollout to the fake environment instead of a real database.
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env)

    rollouts = rollout_func(["Count rows"], _FakeModel(), _FakeTokenizer(), config)
    assert len(rollouts) == 1

    metadata = [entry["metadata"] for entry in rollouts]
    completions = [
        [{"role": "assistant", "content": entry["content"]}] for entry in rollouts
    ]

    assert reward_correctness(completions, metadata=metadata) == [1.0]
    progress = reward_progress(completions, metadata=metadata)
    assert len(progress) == 1
    assert 0.0 <= progress[0] <= 1.0
    assert len(reward_operational(completions, metadata=metadata)) == 1
class _FakeTRLConfig:
def __init__(self, **kwargs):
self.kwargs = kwargs
class _FakeTrainer:
def __init__(
self,
*,
model,
processing_class,
args,
train_dataset,
reward_funcs,
) -> None:
self.model = model
self.processing_class = processing_class
self.args = args
self.train_dataset = train_dataset
self.reward_funcs = reward_funcs
self.state = type("State", (), {"log_history": []})()
self.train_called = False
def train(self) -> dict[str, str]:
self.train_called = True
self.state.log_history = [{"step": 1, "reward": 0.25}]
return {"status": "ok"}
def test_notebook_pipeline_executes_training_step(monkeypatch) -> None:
    """Notebook pipeline helper builds trainer and executes train()."""
    config = GRPOConfig(
        questions_path="data/questions/questions_train.json",
        db_dir="data/databases",
        output_dir="outputs/grpo_test",
        step_budget=2,
    )
    fake_env = _FakeEnvironment(step_budget=2)
    # Keep rollouts hermetic: no database is touched during build_trainer.
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env)

    trainer = build_trainer(
        model=_FakeModel(),
        tokenizer=_FakeTokenizer(),
        prompts=[{"prompt": "Count rows"}],
        config=config,
        trl_grpo_config_cls=_FakeTRLConfig,
        grpo_trainer_cls=_FakeTrainer,
        reward_funcs=[reward_correctness, reward_progress, reward_operational],
    )
    output, steps, rewards = run_training_with_metrics(trainer)

    assert trainer.train_called is True
    assert output == {"status": "ok"}
    assert steps == [1]
    assert rewards == [0.25]
def test_random_baseline_transcripts_are_generated() -> None:
    """Random baseline helper generates readable transcripts per prompt."""
    prompts = ["q1", "q2"]
    baseline = sample_random_baseline(prompts, step_budget=3, seed=7)
    assert len(baseline) == len(prompts)
    for entry in baseline:
        assert entry["metadata"]["policy"] == "random"
        assert entry["completion"]