| """Unit tests for training rollout helpers.""" |
|
|
| from types import SimpleNamespace |
|
|
| from sql_env.models import SQLAction |
| from sql_env.models import SQLObservation |
| from sql_env.training.config import GRPOConfig |
| from sql_env.training import rollout as rollout_module |
| from sql_env.training.rollout import parse_model_output, rollout_func |
|
|
|
|
class FakeTokenizer:
    """Tokenizer double that records chat messages and renders them as text."""

    def __init__(self) -> None:
        # One entry per apply_chat_template call; holds the list object passed in.
        self.messages_seen: list[list[dict[str, str]]] = []

    def apply_chat_template(
        self,
        messages: list[dict[str, str]],
        tokenize: bool = False,
        add_generation_prompt: bool = True,
    ) -> str:
        """Record *messages* and return one ``role::content`` line per message.

        ``tokenize`` and ``add_generation_prompt`` are accepted and ignored.
        """
        del tokenize, add_generation_prompt
        self.messages_seen.append(messages)
        rendered_lines = [f"{msg['role']}::{msg['content']}" for msg in messages]
        return "\n".join(rendered_lines)
|
|
|
|
class FakeModel:
    """Model double that replays a fixed script of generations.

    Each ``generate`` call returns the next scripted output; once the script
    is exhausted every further call returns ``"ANSWER: done"``.
    """

    def __init__(self, outputs: list[str]) -> None:
        # Snapshot the script so replaying it never mutates the caller's list
        # and later edits to that list cannot change what this model returns.
        self._outputs = iter(list(outputs))

    def generate(self, prompt: str, max_new_tokens: int) -> str:
        """Return the next scripted output, ignoring both arguments."""
        del prompt, max_new_tokens
        return next(self._outputs, "ANSWER: done")
|
|
|
|
class FakeEnvironment:
    """Scripted stand-in for the SQL environment used by the rollout tests.

    Every action passed to ``step`` is recorded in ``actions``. An episode is
    reported as done when an ``ANSWER`` action arrives or once ``done_after``
    steps have been taken (``done_after`` falls back to ``step_budget`` when
    it is not given).
    """

    def __init__(
        self,
        *,
        step_budget: int,
        done_after: int | None = None,
        questions: list[SimpleNamespace] | None = None,
        answer_is_correct: bool = True,
    ) -> None:
        self._step_budget = step_budget
        self._done_after = step_budget if done_after is None else done_after
        self._step = 0
        self.actions: list[SQLAction] = []
        self.state = SimpleNamespace(episode_id="ep-test")
        self.questions = [] if questions is None else questions
        # Text of questions[0] at the time of the most recent reset, if any.
        self.last_reset_question_text: str | None = None
        self.answer_is_correct = answer_is_correct

    def reset(self, *, seed: int | None = None) -> SQLObservation:
        """Clear step count and action log, then return the opening observation.

        ``seed`` is accepted and ignored. When ``questions`` is non-empty the
        text of its first entry is remembered in ``last_reset_question_text``.
        """
        del seed
        self._step, self.actions = 0, []
        if self.questions:
            first_question = self.questions[0]
            self.last_reset_question_text = first_question.question_text
        return self._build_observation(done=False, error="", result="")

    def step(self, action: SQLAction) -> SQLObservation:
        """Record *action*, advance one step and return the scripted outcome.

        * The argument ``"hello world random text"`` yields the error
          ``"unparseable action"``.
        * A ``QUERY`` without an error earns a reward of 0.1.
        * An ``ANSWER`` always ends the episode, with reward 1.0 or 0.0
          according to ``answer_is_correct``.
        """
        self.actions.append(action)
        self._step += 1

        is_garbage = action.argument == "hello world random text"
        error = "unparseable action" if is_garbage else ""

        if action.action_type == "ANSWER":
            # An answer terminates the episode regardless of the step count.
            done = True
            if self.answer_is_correct:
                result, reward = "Answer submitted: correct.", 1.0
            else:
                result, reward = "Answer submitted: incorrect.", 0.0
        else:
            done = self._step >= self._done_after
            result = "ok"
            reward = 0.1 if action.action_type == "QUERY" and not error else 0.0

        return self._build_observation(
            done=done, error=error, result=result, reward=reward
        )

    def _build_observation(
        self,
        *,
        done: bool,
        error: str,
        result: str,
        reward: float | None = None,
    ) -> SQLObservation:
        """Assemble an observation reflecting the current step counter."""
        steps_taken = self._step
        return SQLObservation(
            question="How many students?",
            schema_info="Available tables:\n- students",
            result=result,
            error=error,
            step_count=steps_taken,
            # Never report a negative budget, even if stepped past the limit.
            budget_remaining=max(0, self._step_budget - steps_taken),
            action_history=[f"step-{idx}" for idx in range(steps_taken)],
            done=done,
            reward=reward,
        )
|
|
|
|
class HFTokenizer:
    """Tokenizer double exposing a Hugging Face-style call/decode surface.

    Calling the instance returns fixed ``input_ids`` / ``attention_mask``
    lists, and ``decode`` maps the id sequence 4, 5, 6 to ``"ANSWER: 42"``.
    """

    def __init__(self) -> None:
        # One entry per apply_chat_template call; holds the list object passed in.
        self.messages_seen: list[list[dict[str, str]]] = []

    def apply_chat_template(
        self,
        messages: list[dict[str, str]],
        tokenize: bool = False,
        add_generation_prompt: bool = True,
    ) -> str:
        """Record *messages* and return the constant string ``"prompt"``."""
        del tokenize
        del add_generation_prompt
        self.messages_seen.append(messages)
        return "prompt"

    def __call__(
        self, text: str, return_tensors: str = "pt"
    ) -> dict[str, list[list[int]]]:
        """Return a fixed three-token encoding; both arguments are ignored."""
        del text
        del return_tensors
        return {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]}

    @staticmethod
    def _as_plain_list(token_ids):
        """Best-effort conversion of *token_ids* to a plain ``list``.

        Objects exposing ``tolist()`` and tuples are converted; anything else
        is returned unchanged so the caller's comparison simply fails.
        """
        to_list = getattr(token_ids, "tolist", None)
        if callable(to_list):
            return to_list()
        if isinstance(token_ids, (list, tuple)):
            return list(token_ids)
        return token_ids

    def decode(self, token_ids, skip_special_tokens: bool = True) -> str:
        """Decode ids 4, 5, 6 as ``"ANSWER: 42"``; anything else as a query.

        The ids may arrive as a list, a tuple, or an object with ``tolist()``;
        a bare ``==`` against a list would miss the latter two.
        """
        del skip_special_tokens
        if self._as_plain_list(token_ids) == [4, 5, 6]:
            return "ANSWER: 42"
        return "QUERY: SELECT 1"
|
|
|
|
class HFModel:
    """Model double whose ``generate`` accepts keyword arguments only."""

    def generate(self, **kwargs) -> list[list[int]]:
        """Return one fixed id sequence, ignoring every keyword argument.

        The leading 1, 2, 3 match the ``input_ids`` that ``HFTokenizer``
        returns; the trailing 4, 5, 6 are the ids ``HFTokenizer.decode`` maps
        to ``"ANSWER: 42"``.
        """
        del kwargs
        leading_ids = [1, 2, 3]
        trailing_ids = [4, 5, 6]
        return [leading_ids + trailing_ids]
|
|
|
|
class FakeTensor:
    """Minimal tensor-like wrapper that only supports ``tolist()``."""

    def __init__(self, values: list[list[int]]) -> None:
        # Kept by reference; tolist() hands back this same object.
        self._values = values

    def tolist(self) -> list[list[int]]:
        """Return the wrapped nested list."""
        wrapped = self._values
        return wrapped
|
|
|
|
class HFTensorTokenizer(HFTokenizer):
    """``HFTokenizer`` variant whose encodings are wrapped in ``FakeTensor``."""

    def __call__(self, text: str, return_tensors: str = "pt") -> dict[str, FakeTensor]:
        """Return the fixed three-token encoding as ``FakeTensor`` values."""
        del text, return_tensors
        encoding: dict[str, FakeTensor] = {}
        encoding["input_ids"] = FakeTensor([[1, 2, 3]])
        encoding["attention_mask"] = FakeTensor([[1, 1, 1]])
        return encoding
|
|
|
|
class HFTensorModel:
    """Model double whose ``generate`` returns a ``FakeTensor``."""

    def generate(self, **kwargs) -> FakeTensor:
        """Return ids 1 through 6 wrapped in a ``FakeTensor``; kwargs are ignored."""
        del kwargs
        generated_ids = [[1, 2, 3, 4, 5, 6]]
        return FakeTensor(generated_ids)
|
|
|
|
def _build_config(step_budget: int = 5) -> GRPOConfig:
    """Build a ``GRPOConfig`` with fixed test paths and the given step budget."""
    settings = {
        "questions_path": "data/questions/questions_train.json",
        "db_dir": "data/databases",
        "output_dir": "outputs/grpo_test",
        "step_budget": step_budget,
    }
    return GRPOConfig(**settings)
|
|
|
|
def test_parse_describe() -> None:
    """``DESCRIBE <table>`` parses into a DESCRIBE action for that table."""
    expected = SQLAction(action_type="DESCRIBE", argument="employees")

    assert parse_model_output("DESCRIBE employees") == expected
|
|
|
|
def test_parse_sample() -> None:
    """``SAMPLE <table>`` parses into a SAMPLE action for that table."""
    expected = SQLAction(action_type="SAMPLE", argument="departments")

    assert parse_model_output("SAMPLE departments") == expected
|
|
|
|
def test_parse_query() -> None:
    """``QUERY <sql>`` keeps the full SQL text as the action argument."""
    parsed = parse_model_output("QUERY SELECT COUNT(*) FROM employees")
    expected = SQLAction(
        action_type="QUERY",
        argument="SELECT COUNT(*) FROM employees",
    )

    assert parsed == expected
|
|
|
|
def test_parse_answer() -> None:
    """``ANSWER <value>`` parses into an ANSWER action carrying the value."""
    expected = SQLAction(action_type="ANSWER", argument="42")

    assert parse_model_output("ANSWER 42") == expected
|
|
|
|
def test_parse_case_insensitive() -> None:
    """A lower-case action keyword is reported as the upper-case action type."""
    expected = SQLAction(action_type="DESCRIBE", argument="employees")

    assert parse_model_output("describe employees") == expected
|
|
|
|
def test_parse_with_colon_separator() -> None:
    """A colon after the action keyword is not part of the parsed argument."""
    expected = SQLAction(action_type="QUERY", argument="SELECT 1")

    assert parse_model_output("QUERY: SELECT 1") == expected
|
|
|
|
def test_parse_garbage_fallback() -> None:
    """Text with no action keyword falls back to a QUERY of the raw text."""
    text = "hello world random text"
    parsed = parse_model_output(text)

    assert parsed == SQLAction(action_type="QUERY", argument=text)
|
|
|
|
def test_parse_empty_string_fallback() -> None:
    """Empty model output falls back to a QUERY with an empty argument."""
    expected = SQLAction(action_type="QUERY", argument="")

    assert parse_model_output("") == expected
|
|
|
|
def test_parse_only_action_no_argument() -> None:
    """An action keyword with no argument falls back to a QUERY of the raw text."""
    text = "DESCRIBE"
    parsed = parse_model_output(text)

    assert parsed == SQLAction(action_type="QUERY", argument=text)
|
|
|
|
def test_parse_multiline_output() -> None:
    """A non-action first line is skipped and the QUERY on the next line is used."""
    model_text = "Let me think...\nQUERY SELECT 1"
    expected = SQLAction(action_type="QUERY", argument="SELECT 1")

    assert parse_model_output(model_text) == expected
|
|
|
|
def test_parse_whitespace_padded() -> None:
    """Surrounding whitespace does not end up in the action type or argument."""
    expected = SQLAction(action_type="ANSWER", argument="42")

    assert parse_model_output(" ANSWER 42 ") == expected
|
|
|
|
def test_rollout_returns_completions(monkeypatch) -> None:
    """A one-prompt rollout yields one result carrying every top-level key."""
    fake_env = FakeEnvironment(step_budget=5, done_after=5)
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env)

    results = rollout_func(
        ["Count students"],
        FakeModel(outputs=["ANSWER: 42"]),
        FakeTokenizer(),
        _build_config(step_budget=5),
    )

    assert len(results) == 1
    result = results[0]
    required_keys = ("content", "metadata", "correct", "progress", "operational")
    missing = [key for key in required_keys if key not in result]
    assert not missing
|
|
|
|
def test_rollout_batch_size(monkeypatch) -> None:
    """Three prompts in, three results out."""
    fake_env = FakeEnvironment(step_budget=4)
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env)
    scripted_outputs = ["ANSWER: 1", "ANSWER: 2", "ANSWER: 3"]

    results = rollout_func(
        ["q1", "q2", "q3"],
        FakeModel(outputs=scripted_outputs),
        FakeTokenizer(),
        _build_config(step_budget=4),
    )

    assert len(results) == 3
|
|
|
|
def test_rollout_episode_terminates(monkeypatch) -> None:
    """The rollout stops within the step budget even if the env never says done.

    The environment only reports done after 50 steps and the model never
    answers, so the configured budget of 5 is the only thing that can stop it.
    """
    fake_env = FakeEnvironment(step_budget=5, done_after=50)
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env)
    endless_queries = FakeModel(outputs=["QUERY: SELECT 1"] * 20)

    results = rollout_func(
        ["q1"], endless_queries, FakeTokenizer(), _build_config(step_budget=5)
    )

    metadata = results[0]["metadata"]
    assert metadata["step_count"] <= 5
|
|
|
|
def test_rollout_metadata_present(monkeypatch) -> None:
    """A result exposes the reward fields and the episode metadata keys."""
    fake_env = FakeEnvironment(step_budget=3)
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env)

    results = rollout_func(
        ["q1"],
        FakeModel(outputs=["ANSWER: 42"]),
        FakeTokenizer(),
        _build_config(step_budget=3),
    )
    result = results[0]

    missing_fields = [
        key for key in ("correct", "progress", "operational") if key not in result
    ]
    assert not missing_fields

    metadata = result["metadata"]
    missing_metadata = [
        key for key in ("episode_id", "step_count", "done") if key not in metadata
    ]
    assert not missing_metadata
|
|
|
|
def test_rollout_unparseable_action(monkeypatch) -> None:
    """Unparseable model text reaches the environment as a QUERY of that text."""
    garbage = "hello world random text"
    fake_env = FakeEnvironment(step_budget=3)
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env)

    results = rollout_func(
        ["q1"],
        FakeModel(outputs=["hello world random text", "ANSWER: 42"]),
        FakeTokenizer(),
        _build_config(step_budget=3),
    )

    assert len(results) == 1
    first_action = fake_env.actions[0]
    assert first_action.action_type == "QUERY"
    assert first_action.argument == garbage
|
|
|
|
def test_rollout_truncation(monkeypatch) -> None:
    """After several steps, a prompt of at most 8 messages is sent to the tokenizer."""
    tokenizer = FakeTokenizer()
    fake_env = FakeEnvironment(step_budget=20, done_after=20)
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env)

    rollout_func(
        ["q1"],
        FakeModel(outputs=["QUERY: SELECT 1"] * 20),
        tokenizer,
        _build_config(step_budget=20),
    )

    assert tokenizer.messages_seen
    # NOTE(review): `any` only requires ONE prompt from the seventh call onward
    # to be short; `all` would be the stricter check. Confirm against the
    # rollout's truncation policy before tightening.
    later_prompts = tokenizer.messages_seen[6:]
    assert any(len(messages) <= 8 for messages in later_prompts)
|
|
|
|
def test_rollout_uses_hf_style_generate(monkeypatch) -> None:
    """The rollout works with a keyword-only ``generate`` and a callable tokenizer.

    ``HFTokenizer.decode`` yields ``"ANSWER: 42"`` only for ids 4, 5, 6, so
    an ANSWER as the first action shows that exactly those ids were decoded.
    """
    fake_env = FakeEnvironment(step_budget=2)
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env)

    results = rollout_func(
        ["q1"], HFModel(), HFTokenizer(), _build_config(step_budget=2)
    )
    result = results[0]

    assert result["correct"] is True
    first_action = fake_env.actions[0]
    assert first_action.action_type == "ANSWER"
|
|
|
|
def test_rollout_binds_environment_to_prompt_when_available(monkeypatch) -> None:
    """Rolling out prompt ``"q2"`` leaves the environment reset on question q2.

    ``FakeEnvironment.reset`` records ``questions[0]``, and q2 is second in the
    list handed to the environment here. Presumably ``rollout_func`` narrows or
    reorders ``env.questions`` to match the prompt before resetting — confirm
    in the rollout module.
    """
    question_bank = [SimpleNamespace(question_text=text) for text in ("q1", "q2")]
    fake_env = FakeEnvironment(step_budget=1, questions=question_bank)
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env)

    rollout_func(
        ["q2"],
        FakeModel(outputs=["ANSWER: 42"]),
        FakeTokenizer(),
        _build_config(step_budget=1),
    )

    assert fake_env.last_reset_question_text == "q2"
|
|
|
|
def test_rollout_incorrect_answer_not_marked_correct(monkeypatch) -> None:
    """An answer the environment reports as incorrect gives ``correct is False``."""
    fake_env = FakeEnvironment(step_budget=1, answer_is_correct=False)
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env)

    results = rollout_func(
        ["q1"],
        FakeModel(outputs=["ANSWER: 42"]),
        FakeTokenizer(),
        _build_config(step_budget=1),
    )

    outcome = results[0]
    assert outcome["correct"] is False
|
|
|
|
def test_rollout_handles_tensor_like_generate_outputs(monkeypatch) -> None:
    """The rollout copes with tokenizer/model outputs that only offer ``tolist()``.

    Both doubles return ``FakeTensor`` objects rather than plain lists; the
    first action must still be the ANSWER decoded from ids 4, 5, 6.
    """
    fake_env = FakeEnvironment(step_budget=2)
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env)

    results = rollout_func(
        ["q1"], HFTensorModel(), HFTensorTokenizer(), _build_config(step_budget=2)
    )
    result = results[0]

    assert result["correct"] is True
    first_action = fake_env.actions[0]
    assert first_action.action_type == "ANSWER"
|
|