"""Unit tests for training rollout helpers.""" from types import SimpleNamespace from sql_env.models import SQLAction from sql_env.models import SQLObservation from sql_env.training.config import GRPOConfig from sql_env.training import rollout as rollout_module from sql_env.training.rollout import parse_model_output, rollout_func class FakeTokenizer: def __init__(self) -> None: self.messages_seen: list[list[dict[str, str]]] = [] def apply_chat_template( self, messages: list[dict[str, str]], tokenize: bool = False, add_generation_prompt: bool = True, ) -> str: del tokenize del add_generation_prompt self.messages_seen.append(messages) return "\n".join(f"{msg['role']}::{msg['content']}" for msg in messages) class FakeModel: def __init__(self, outputs: list[str]) -> None: self._outputs = outputs def generate(self, prompt: str, max_new_tokens: int) -> str: del prompt del max_new_tokens if self._outputs: return self._outputs.pop(0) return "ANSWER: done" class FakeEnvironment: def __init__( self, *, step_budget: int, done_after: int | None = None, questions: list[SimpleNamespace] | None = None, answer_is_correct: bool = True, ) -> None: self._step_budget = step_budget self._done_after = done_after if done_after is not None else step_budget self._step = 0 self.actions: list[SQLAction] = [] self.state = SimpleNamespace(episode_id="ep-test") self.questions = questions if questions is not None else [] self.last_reset_question_text: str | None = None self.answer_is_correct = answer_is_correct def reset(self, *, seed: int | None = None) -> SQLObservation: del seed self._step = 0 self.actions = [] if self.questions: self.last_reset_question_text = self.questions[0].question_text return self._build_observation(done=False, error="", result="") def step(self, action: SQLAction) -> SQLObservation: self.actions.append(action) self._step += 1 error = "" result = "ok" reward = 0.0 if action.argument == "hello world random text": error = "unparseable action" if action.action_type == "QUERY" and not error: reward = 0.1 done = self._step >= self._done_after if action.action_type == "ANSWER": done = True if self.answer_is_correct: result = "Answer submitted: correct." reward = 1.0 else: result = "Answer submitted: incorrect." reward = 0.0 return self._build_observation( done=done, error=error, result=result, reward=reward ) def _build_observation( self, *, done: bool, error: str, result: str, reward: float | None = None, ) -> SQLObservation: return SQLObservation( question="How many students?", schema_info="Available tables:\n- students", result=result, error=error, step_count=self._step, budget_remaining=max(0, self._step_budget - self._step), action_history=[f"step-{idx}" for idx in range(self._step)], done=done, reward=reward, ) class HFTokenizer: def __init__(self) -> None: self.messages_seen: list[list[dict[str, str]]] = [] def apply_chat_template( self, messages: list[dict[str, str]], tokenize: bool = False, add_generation_prompt: bool = True, ) -> str: del tokenize del add_generation_prompt self.messages_seen.append(messages) return "prompt" def __call__( self, text: str, return_tensors: str = "pt" ) -> dict[str, list[list[int]]]: del text del return_tensors return {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]} def decode(self, token_ids, skip_special_tokens: bool = True) -> str: del skip_special_tokens if token_ids == [4, 5, 6]: return "ANSWER: 42" return "QUERY: SELECT 1" class HFModel: def generate(self, **kwargs) -> list[list[int]]: del kwargs return [[1, 2, 3, 4, 5, 6]] class FakeTensor: def __init__(self, values: list[list[int]]) -> None: self._values = values def tolist(self) -> list[list[int]]: return self._values class HFTensorTokenizer(HFTokenizer): def __call__(self, text: str, return_tensors: str = "pt") -> dict[str, FakeTensor]: del text del return_tensors return { "input_ids": FakeTensor([[1, 2, 3]]), "attention_mask": FakeTensor([[1, 1, 1]]), } class HFTensorModel: def generate(self, **kwargs) -> FakeTensor: del kwargs return FakeTensor([[1, 2, 3, 4, 5, 6]]) def _build_config(step_budget: int = 5) -> GRPOConfig: return GRPOConfig( questions_path="data/questions/questions_train.json", db_dir="data/databases", output_dir="outputs/grpo_test", step_budget=step_budget, ) def test_parse_describe() -> None: action = parse_model_output("DESCRIBE employees") assert action == SQLAction(action_type="DESCRIBE", argument="employees") def test_parse_sample() -> None: action = parse_model_output("SAMPLE departments") assert action == SQLAction(action_type="SAMPLE", argument="departments") def test_parse_query() -> None: action = parse_model_output("QUERY SELECT COUNT(*) FROM employees") assert action == SQLAction( action_type="QUERY", argument="SELECT COUNT(*) FROM employees", ) def test_parse_answer() -> None: action = parse_model_output("ANSWER 42") assert action == SQLAction(action_type="ANSWER", argument="42") def test_parse_case_insensitive() -> None: action = parse_model_output("describe employees") assert action == SQLAction(action_type="DESCRIBE", argument="employees") def test_parse_with_colon_separator() -> None: action = parse_model_output("QUERY: SELECT 1") assert action == SQLAction(action_type="QUERY", argument="SELECT 1") def test_parse_garbage_fallback() -> None: raw = "hello world random text" action = parse_model_output(raw) assert action == SQLAction(action_type="QUERY", argument=raw) def test_parse_empty_string_fallback() -> None: action = parse_model_output("") assert action == SQLAction(action_type="QUERY", argument="") def test_parse_only_action_no_argument() -> None: raw = "DESCRIBE" action = parse_model_output(raw) assert action == SQLAction(action_type="QUERY", argument=raw) def test_parse_multiline_output() -> None: action = parse_model_output("Let me think...\nQUERY SELECT 1") assert action == SQLAction(action_type="QUERY", argument="SELECT 1") def test_parse_whitespace_padded() -> None: action = parse_model_output(" ANSWER 42 ") assert action == SQLAction(action_type="ANSWER", argument="42") def test_rollout_returns_completions(monkeypatch) -> None: config = _build_config(step_budget=5) tokenizer = FakeTokenizer() model = FakeModel(outputs=["ANSWER: 42"]) fake_env = FakeEnvironment(step_budget=5, done_after=5) monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env) results = rollout_func(["Count students"], model, tokenizer, config) assert len(results) == 1 result = results[0] assert "content" in result assert "metadata" in result assert "correct" in result assert "progress" in result assert "operational" in result def test_rollout_batch_size(monkeypatch) -> None: config = _build_config(step_budget=4) tokenizer = FakeTokenizer() model = FakeModel(outputs=["ANSWER: 1", "ANSWER: 2", "ANSWER: 3"]) fake_env = FakeEnvironment(step_budget=4) monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env) results = rollout_func(["q1", "q2", "q3"], model, tokenizer, config) assert len(results) == 3 def test_rollout_episode_terminates(monkeypatch) -> None: config = _build_config(step_budget=5) tokenizer = FakeTokenizer() model = FakeModel(outputs=["QUERY: SELECT 1"] * 20) fake_env = FakeEnvironment(step_budget=5, done_after=50) monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env) results = rollout_func(["q1"], model, tokenizer, config) assert results[0]["metadata"]["step_count"] <= 5 def test_rollout_metadata_present(monkeypatch) -> None: config = _build_config(step_budget=3) tokenizer = FakeTokenizer() model = FakeModel(outputs=["ANSWER: 42"]) fake_env = FakeEnvironment(step_budget=3) monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env) result = rollout_func(["q1"], model, tokenizer, config)[0] assert "correct" in result assert "progress" in result assert "operational" in result assert "episode_id" in result["metadata"] assert "step_count" in result["metadata"] assert "done" in result["metadata"] def test_rollout_unparseable_action(monkeypatch) -> None: config = _build_config(step_budget=3) tokenizer = FakeTokenizer() model = FakeModel(outputs=["hello world random text", "ANSWER: 42"]) fake_env = FakeEnvironment(step_budget=3) monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env) results = rollout_func(["q1"], model, tokenizer, config) assert len(results) == 1 assert fake_env.actions[0].action_type == "QUERY" assert fake_env.actions[0].argument == "hello world random text" def test_rollout_truncation(monkeypatch) -> None: config = _build_config(step_budget=20) tokenizer = FakeTokenizer() model = FakeModel(outputs=["QUERY: SELECT 1"] * 20) fake_env = FakeEnvironment(step_budget=20, done_after=20) monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env) rollout_func(["q1"], model, tokenizer, config) assert tokenizer.messages_seen assert any(len(messages) <= 8 for messages in tokenizer.messages_seen[6:]) def test_rollout_uses_hf_style_generate(monkeypatch) -> None: config = _build_config(step_budget=2) tokenizer = HFTokenizer() model = HFModel() fake_env = FakeEnvironment(step_budget=2) monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env) result = rollout_func(["q1"], model, tokenizer, config)[0] assert result["correct"] is True assert fake_env.actions[0].action_type == "ANSWER" def test_rollout_binds_environment_to_prompt_when_available(monkeypatch) -> None: config = _build_config(step_budget=1) tokenizer = FakeTokenizer() model = FakeModel(outputs=["ANSWER: 42"]) questions = [ SimpleNamespace(question_text="q1"), SimpleNamespace(question_text="q2"), ] fake_env = FakeEnvironment(step_budget=1, questions=questions) monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env) rollout_func(["q2"], model, tokenizer, config) assert fake_env.last_reset_question_text == "q2" def test_rollout_incorrect_answer_not_marked_correct(monkeypatch) -> None: config = _build_config(step_budget=1) tokenizer = FakeTokenizer() model = FakeModel(outputs=["ANSWER: 42"]) fake_env = FakeEnvironment(step_budget=1, answer_is_correct=False) monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env) result = rollout_func(["q1"], model, tokenizer, config)[0] assert result["correct"] is False def test_rollout_handles_tensor_like_generate_outputs(monkeypatch) -> None: config = _build_config(step_budget=2) tokenizer = HFTensorTokenizer() model = HFTensorModel() fake_env = FakeEnvironment(step_budget=2) monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env) result = rollout_func(["q1"], model, tokenizer, config)[0] assert result["correct"] is True assert fake_env.actions[0].action_type == "ANSWER"