File size: 7,227 Bytes
5dd1bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
"""E2E-style smoke coverage for the GRPO training notebook."""

from __future__ import annotations

import json
from pathlib import Path

from sql_env.training import rollout as rollout_module
from sql_env.training.config import GRPOConfig
from sql_env.training.notebook_pipeline import (
    build_trainer,
    run_training_with_metrics,
    sample_random_baseline,
)
from sql_env.training.data_loading import filter_questions_by_difficulty
from sql_env.training.rewards import (
    reward_correctness,
    reward_operational,
    reward_progress,
)
from sql_env.training.rollout import rollout_func


# Relative location of the GRPO training notebook inspected by these tests;
# being relative, it resolves against the current working directory.
NOTEBOOK_PATH = Path("notebooks") / "train_grpo.ipynb"


def _read_notebook() -> dict:
    """Return the notebook at ``NOTEBOOK_PATH`` decoded from JSON.

    The file is opened as UTF-8 text and parsed in one go.
    """
    with NOTEBOOK_PATH.open(encoding="utf-8") as handle:
        return json.load(handle)


def _code_sources(notebook: dict) -> list[str]:
    """Collect the source text of every code cell in *notebook*.

    Each code cell's ``source`` entry is joined into a single string; cells
    whose ``cell_type`` is anything other than ``"code"`` are skipped.  A
    missing ``cells`` or ``source`` key is treated as empty.
    """
    joined_sources: list[str] = []
    for cell in notebook.get("cells", []):
        if cell.get("cell_type") != "code":
            continue
        joined_sources.append("".join(cell.get("source", [])))
    return joined_sources


def test_training_notebook_smoke_structure() -> None:
    """Notebook code cells contain the core GRPO training flow, in order."""

    assert NOTEBOOK_PATH.exists(), "notebooks/train_grpo.ipynb must exist"

    combined = "\n".join(_code_sources(_read_notebook()))

    # Snippets that must appear somewhere in the notebook's code cells.
    required_snippets = (
        "GRPOConfig(",
        "load_model_and_tokenizer(config.model_name)",
        "grpo_trainer_cls=GRPOTrainer",
        "run_training_with_metrics",
        "matplotlib.pyplot as plt",
    )
    for snippet in required_snippets:
        assert snippet in combined

    baseline_marker = "before_rollouts = sample_random_baseline"
    training_marker = "run_training_with_metrics(trainer)"
    assert baseline_marker in combined
    assert training_marker in combined
    # The baseline sample has to be taken before the training call.
    assert combined.index(baseline_marker) < combined.index(training_marker)


def test_question_filtering_by_difficulty() -> None:
    """Only questions whose difficulty is in the allowed list are kept."""

    difficulty_by_text = {"q1": "easy", "q2": "medium", "q3": "hard"}
    questions = [
        {"question_text": text, "difficulty": level}
        for text, level in difficulty_by_text.items()
    ]

    kept = filter_questions_by_difficulty(questions, ["easy"])
    kept_texts = [question["question_text"] for question in kept]
    assert kept_texts == ["q1"]


class _FakeTokenizer:
    """Tokenizer stub whose chat template always renders the same text."""

    def apply_chat_template(
        self,
        messages: list[dict[str, str]],
        tokenize: bool = False,
        add_generation_prompt: bool = True,
    ) -> str:
        """Return the fixed string ``"prompt"`` whatever the arguments are."""
        # The parameters exist only to match the call signature; none is used.
        del messages, tokenize, add_generation_prompt
        return "prompt"


class _FakeModel:
    """Model stub that emits one query and then keeps answering."""

    def __init__(self) -> None:
        # Number of generate() calls made so far.
        self._count = 0

    def generate(self, prompt: str, max_new_tokens: int) -> str:
        """Return a scripted completion chosen by the call count.

        The first call yields ``"QUERY: SELECT 1"``; every later call yields
        ``"ANSWER: 42"``.  Both arguments are ignored.
        """
        del prompt, max_new_tokens
        self._count += 1
        is_first_call = self._count == 1
        return "QUERY: SELECT 1" if is_first_call else "ANSWER: 42"


class _FakeEnvironment:
    def __init__(self, step_budget: int) -> None:
        self.step_budget = step_budget
        self.step_count = 0
        self.state = type("State", (), {"episode_id": "ep-e2e"})()

    def reset(self, *, seed: int | None = None):
        del seed
        self.step_count = 0
        return self._observation(done=False, result="")

    def step(self, action):
        self.step_count += 1
        if getattr(action, "action_type", "") == "ANSWER":
            return self._observation(
                done=True, result="Answer submitted: correct.", reward=1.0
            )
        return self._observation(done=False, result="ok", reward=0.1)

    def _observation(self, done: bool, result: str, reward: float | None = 0.0):
        from sql_env.models import SQLObservation

        return SQLObservation(
            question="How many rows?",
            schema_info="Available tables:\n- t",
            result=result,
            error="",
            step_count=self.step_count,
            budget_remaining=max(0, self.step_budget - self.step_count),
            action_history=[],
            done=done,
            reward=reward,
        )


def test_training_pipeline_smoke(monkeypatch) -> None:
    """Happy-path rollout plus reward computation yields usable signals."""

    budget = 2
    config = GRPOConfig(
        questions_path="data/questions/questions_train.json",
        db_dir="data/databases",
        output_dir="outputs/grpo_test",
        step_budget=budget,
    )
    tokenizer = _FakeTokenizer()
    model = _FakeModel()
    fake_env = _FakeEnvironment(step_budget=budget)

    def _use_fake_env(*_ignored):
        return fake_env

    # Make the rollout module hand out the stub instead of a real environment.
    monkeypatch.setattr(rollout_module, "_build_environment", _use_fake_env)

    rollouts = rollout_func(["Count rows"], model, tokenizer, config)
    assert len(rollouts) == 1

    metadata = []
    completions = []
    for item in rollouts:
        metadata.append(item["metadata"])
        completions.append([{"role": "assistant", "content": item["content"]}])

    correctness = reward_correctness(completions, metadata=metadata)
    progress = reward_progress(completions, metadata=metadata)
    operational = reward_operational(completions, metadata=metadata)

    assert correctness == [1.0]
    assert len(progress) == 1
    assert 0.0 <= progress[0] <= 1.0
    assert len(operational) == 1


class _FakeTRLConfig:
    """Config stub that simply records the keyword arguments it was given."""

    def __init__(self, **kwargs):
        """Store every keyword argument, untouched, on ``self.kwargs``."""
        self.kwargs = kwargs


class _FakeTrainer:
    """Trainer stub that records its constructor arguments and fakes a run."""

    def __init__(
        self,
        *,
        model,
        processing_class,
        args,
        train_dataset,
        reward_funcs,
    ) -> None:
        # Run bookkeeping: nothing trained yet and an empty log history.
        self.train_called = False
        state_cls = type("State", (), {"log_history": []})
        self.state = state_cls()

        # Collaborators handed in by the caller, kept as-is for inspection.
        self.model = model
        self.processing_class = processing_class
        self.args = args
        self.train_dataset = train_dataset
        self.reward_funcs = reward_funcs

    def train(self) -> dict[str, str]:
        """Flag the run as executed, log one fake step, and report success."""
        self.state.log_history = [{"step": 1, "reward": 0.25}]
        self.train_called = True
        return {"status": "ok"}


def test_notebook_pipeline_executes_training_step(monkeypatch) -> None:
    """The pipeline helper builds a trainer and actually calls ``train()``."""

    budget = 2
    config = GRPOConfig(
        questions_path="data/questions/questions_train.json",
        db_dir="data/databases",
        output_dir="outputs/grpo_test",
        step_budget=budget,
    )
    tokenizer = _FakeTokenizer()
    model = _FakeModel()
    fake_env = _FakeEnvironment(step_budget=budget)

    def _use_fake_env(*_ignored):
        return fake_env

    # Make the rollout module hand out the stub instead of a real environment.
    monkeypatch.setattr(rollout_module, "_build_environment", _use_fake_env)

    trainer_kwargs = {
        "model": model,
        "tokenizer": tokenizer,
        "prompts": [{"prompt": "Count rows"}],
        "config": config,
        "trl_grpo_config_cls": _FakeTRLConfig,
        "grpo_trainer_cls": _FakeTrainer,
        "reward_funcs": [reward_correctness, reward_progress, reward_operational],
    }
    trainer = build_trainer(**trainer_kwargs)

    result = run_training_with_metrics(trainer)
    output, steps, rewards = result

    assert trainer.train_called is True
    assert output == {"status": "ok"}
    assert steps == [1]
    assert rewards == [0.25]


def test_random_baseline_transcripts_are_generated() -> None:
    """The random baseline returns one tagged, non-empty entry per prompt."""

    prompts = ["q1", "q2"]
    baseline = sample_random_baseline(prompts, step_budget=3, seed=7)

    assert len(baseline) == 2
    for entry in baseline:
        assert entry["metadata"]["policy"] == "random"
    for entry in baseline:
        assert entry["completion"]