# sql_env/tests/unit/test_rollout.py
# hjerpe — "Upload folder using huggingface_hub" — commit 5dd1bb4 (verified)
"""Unit tests for training rollout helpers."""
from types import SimpleNamespace
from sql_env.models import SQLAction
from sql_env.models import SQLObservation
from sql_env.training.config import GRPOConfig
from sql_env.training import rollout as rollout_module
from sql_env.training.rollout import parse_model_output, rollout_func
class FakeTokenizer:
    """Chat-template stub that renders messages as ``role::content`` lines.

    Every call to ``apply_chat_template`` records a snapshot of the messages
    it received in ``messages_seen``, so tests can inspect the conversation
    exactly as it looked at each generation step.
    """

    def __init__(self) -> None:
        # One entry per apply_chat_template call, in call order.
        self.messages_seen: list[list[dict[str, str]]] = []

    def apply_chat_template(
        self,
        messages: list[dict[str, str]],
        tokenize: bool = False,
        add_generation_prompt: bool = True,
    ) -> str:
        """Record a copy of ``messages`` and return them joined as text.

        ``tokenize`` and ``add_generation_prompt`` are accepted for signature
        compatibility only and are ignored.
        """
        del tokenize
        del add_generation_prompt
        # Copy the list and each message dict: the caller may keep mutating
        # (appending to / truncating) the same list between turns, and storing
        # a bare reference would make every recorded entry show the final state.
        snapshot = [dict(message) for message in messages]
        self.messages_seen.append(snapshot)
        return "\n".join(f"{msg['role']}::{msg['content']}" for msg in snapshot)
class FakeModel:
    """Scripted text generator that replays a fixed list of outputs in order.

    Once the script is exhausted every call returns ``"ANSWER: done"``, so a
    rollout driven by this model always ends up submitting an answer.
    """

    def __init__(self, outputs: list[str]) -> None:
        # Copy so consuming scripted outputs never mutates the caller's list.
        self._outputs = list(outputs)

    def generate(self, prompt: str, max_new_tokens: int) -> str:
        """Return the next scripted output; ``prompt`` and ``max_new_tokens`` are ignored."""
        del prompt
        del max_new_tokens
        if self._outputs:
            return self._outputs.pop(0)
        return "ANSWER: done"
class FakeEnvironment:
    """Scripted stand-in for the SQL environment driven by ``rollout_func``.

    Behaviour per step:
    * an argument equal to ``"hello world random text"`` is reported as an
      ``"unparseable action"`` error;
    * a QUERY without that error earns reward 0.1;
    * an ANSWER always ends the episode, with reward 1.0 and a "correct"
      result when ``answer_is_correct`` is set, otherwise 0.0 / "incorrect";
    * any other action ends the episode once ``done_after`` steps were taken.
    The step budget is only reported via ``budget_remaining``; it never ends
    the episode by itself.
    """

    def __init__(
        self,
        *,
        step_budget: int,
        done_after: int | None = None,
        questions: list[SimpleNamespace] | None = None,
        answer_is_correct: bool = True,
    ) -> None:
        self._step_budget = step_budget
        # Without an explicit done_after, the env signals done at the budget.
        self._done_after = step_budget if done_after is None else done_after
        self._step = 0
        # Actions received since the most recent reset, in order.
        self.actions: list[SQLAction] = []
        # Fixed episode id (presumably read by the rollout for result metadata).
        self.state = SimpleNamespace(episode_id="ep-test")
        self.questions = [] if questions is None else questions
        # question_text of questions[0] at the last reset; None if never set.
        self.last_reset_question_text: str | None = None
        self.answer_is_correct = answer_is_correct

    def reset(self, *, seed: int | None = None) -> SQLObservation:
        """Start a fresh episode; ``seed`` is accepted but ignored."""
        del seed
        self._step = 0
        self.actions = []
        if self.questions:
            first_question = self.questions[0]
            self.last_reset_question_text = first_question.question_text
        return self._build_observation(done=False, error="", result="")

    def step(self, action: SQLAction) -> SQLObservation:
        """Record ``action``, advance one step and return the scripted observation."""
        self.actions.append(action)
        self._step += 1

        is_garbage = action.argument == "hello world random text"
        error = "unparseable action" if is_garbage else ""

        if action.action_type == "ANSWER":
            finished = True
            if self.answer_is_correct:
                result, reward = "Answer submitted: correct.", 1.0
            else:
                result, reward = "Answer submitted: incorrect.", 0.0
        else:
            finished = self._step >= self._done_after
            result = "ok"
            earns_reward = action.action_type == "QUERY" and not is_garbage
            reward = 0.1 if earns_reward else 0.0

        return self._build_observation(
            done=finished, error=error, result=result, reward=reward
        )

    def _build_observation(
        self,
        *,
        done: bool,
        error: str,
        result: str,
        reward: float | None = None,
    ) -> SQLObservation:
        """Assemble an observation with fixed question/schema text and current counters."""
        steps_taken = self._step
        return SQLObservation(
            question="How many students?",
            schema_info="Available tables:\n- students",
            result=result,
            error=error,
            step_count=steps_taken,
            budget_remaining=max(0, self._step_budget - steps_taken),
            action_history=[f"step-{i}" for i in range(steps_taken)],
            done=done,
            reward=reward,
        )
class HFTokenizer:
    """Tokenizer stub exposing a Hugging Face-style call surface.

    ``apply_chat_template`` always renders ``"prompt"``; calling the instance
    yields a fixed three-token prompt ``[1, 2, 3]``; ``decode`` maps exactly
    ``[4, 5, 6]`` to an ANSWER and anything else to a QUERY.
    """

    def __init__(self) -> None:
        # One snapshot per apply_chat_template call, in call order.
        self.messages_seen: list[list[dict[str, str]]] = []

    def apply_chat_template(
        self,
        messages: list[dict[str, str]],
        tokenize: bool = False,
        add_generation_prompt: bool = True,
    ) -> str:
        """Record a copy of ``messages`` and return the constant prompt text."""
        del tokenize
        del add_generation_prompt
        # Snapshot rather than alias: the caller may mutate the same list
        # between turns, which would rewrite earlier recorded entries.
        self.messages_seen.append([dict(message) for message in messages])
        return "prompt"

    def __call__(
        self, text: str, return_tensors: str = "pt"
    ) -> dict[str, list[list[int]]]:
        """Return a fixed single-row encoding regardless of ``text``."""
        del text
        del return_tensors
        return {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]}

    def decode(self, token_ids: list[int], skip_special_tokens: bool = True) -> str:
        """Decode the completion tokens ``[4, 5, 6]`` as an ANSWER, else a QUERY."""
        del skip_special_tokens
        if token_ids == [4, 5, 6]:
            return "ANSWER: 42"
        return "QUERY: SELECT 1"
class HFModel:
    """Model stub whose ``generate`` returns one fixed token sequence.

    The sequence is the three prompt ids that ``HFTokenizer`` produces,
    followed by the completion ids ``[4, 5, 6]``.
    """

    def generate(self, **kwargs) -> list[list[int]]:
        """Ignore every generation argument and return a one-row batch."""
        del kwargs
        prompt_ids = [1, 2, 3]
        completion_ids = [4, 5, 6]
        return [prompt_ids + completion_ids]
class FakeTensor:
    """Tensor look-alike that wraps nested ints and exposes ``tolist``."""

    def __init__(self, values: list[list[int]]) -> None:
        self._values = values

    def tolist(self) -> list[list[int]]:
        """Return the wrapped nested list itself (not a copy)."""
        wrapped = self._values
        return wrapped
class HFTensorTokenizer(HFTokenizer):
    """``HFTokenizer`` variant whose encoded prompt holds ``FakeTensor`` values."""

    def __call__(self, text: str, return_tensors: str = "pt") -> dict[str, FakeTensor]:
        """Return the fixed three-token prompt wrapped in tensor look-alikes."""
        del text, return_tensors
        ids = FakeTensor([[1, 2, 3]])
        mask = FakeTensor([[1, 1, 1]])
        return {"input_ids": ids, "attention_mask": mask}
class HFTensorModel:
    """Model stub returning its fixed prompt+completion ids as a ``FakeTensor``."""

    def generate(self, **kwargs) -> FakeTensor:
        """Ignore every generation argument and return a one-row tensor look-alike."""
        del kwargs
        sequence = [1, 2, 3, 4, 5, 6]
        return FakeTensor([sequence])
def _build_config(step_budget: int = 5) -> GRPOConfig:
    """Return a ``GRPOConfig`` for rollout tests.

    The data and output paths are fixed relative paths; only ``step_budget``
    varies between tests.
    """
    settings = {
        "questions_path": "data/questions/questions_train.json",
        "db_dir": "data/databases",
        "output_dir": "outputs/grpo_test",
        "step_budget": step_budget,
    }
    return GRPOConfig(**settings)
def test_parse_describe() -> None:
    """``DESCRIBE <table>`` maps to a DESCRIBE action on that table."""
    parsed = parse_model_output("DESCRIBE employees")
    expected = SQLAction(action_type="DESCRIBE", argument="employees")
    assert parsed == expected
def test_parse_sample() -> None:
    """``SAMPLE <table>`` maps to a SAMPLE action on that table."""
    parsed = parse_model_output("SAMPLE departments")
    expected = SQLAction(action_type="SAMPLE", argument="departments")
    assert parsed == expected
def test_parse_query() -> None:
    """``QUERY <sql>`` keeps the full SQL text as the argument."""
    sql = "SELECT COUNT(*) FROM employees"
    parsed = parse_model_output(f"QUERY {sql}")
    assert parsed == SQLAction(action_type="QUERY", argument=sql)
def test_parse_answer() -> None:
    """``ANSWER <value>`` maps to an ANSWER action carrying the value."""
    parsed = parse_model_output("ANSWER 42")
    expected = SQLAction(action_type="ANSWER", argument="42")
    assert parsed == expected
def test_parse_case_insensitive() -> None:
    """A lower-case keyword is recognised and normalised to upper case."""
    parsed = parse_model_output("describe employees")
    expected = SQLAction(action_type="DESCRIBE", argument="employees")
    assert parsed == expected
def test_parse_with_colon_separator() -> None:
    """A colon after the keyword is treated as a separator, not part of the argument."""
    parsed = parse_model_output("QUERY: SELECT 1")
    expected = SQLAction(action_type="QUERY", argument="SELECT 1")
    assert parsed == expected
def test_parse_garbage_fallback() -> None:
    """Text without a recognised keyword falls back to a QUERY of the raw text."""
    garbage = "hello world random text"
    assert parse_model_output(garbage) == SQLAction(action_type="QUERY", argument=garbage)
def test_parse_empty_string_fallback() -> None:
    """Empty model output falls back to a QUERY with an empty argument."""
    parsed = parse_model_output("")
    expected = SQLAction(action_type="QUERY", argument="")
    assert parsed == expected
def test_parse_only_action_no_argument() -> None:
    """A bare keyword with no argument is not accepted; the raw text becomes a QUERY."""
    bare_keyword = "DESCRIBE"
    expected = SQLAction(action_type="QUERY", argument=bare_keyword)
    assert parse_model_output(bare_keyword) == expected
def test_parse_multiline_output() -> None:
    """Leading free-text lines are skipped and the action line is parsed."""
    output = "Let me think...\nQUERY SELECT 1"
    expected = SQLAction(action_type="QUERY", argument="SELECT 1")
    assert parse_model_output(output) == expected
def test_parse_whitespace_padded() -> None:
    """Surrounding whitespace is stripped from both keyword and argument."""
    parsed = parse_model_output(" ANSWER 42 ")
    expected = SQLAction(action_type="ANSWER", argument="42")
    assert parsed == expected
def test_rollout_returns_completions(monkeypatch) -> None:
    """One prompt yields one result carrying content, metadata and reward keys."""
    fake_env = FakeEnvironment(step_budget=5, done_after=5)
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env)

    results = rollout_func(
        ["Count students"],
        FakeModel(outputs=["ANSWER: 42"]),
        FakeTokenizer(),
        _build_config(step_budget=5),
    )

    assert len(results) == 1
    (result,) = results
    for key in ("content", "metadata", "correct", "progress", "operational"):
        assert key in result
def test_rollout_batch_size(monkeypatch) -> None:
    """One result is produced per prompt in the batch."""
    prompts = ["q1", "q2", "q3"]
    fake_env = FakeEnvironment(step_budget=4)
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env)

    model = FakeModel(outputs=[f"ANSWER: {n}" for n in range(1, 4)])
    results = rollout_func(prompts, model, FakeTokenizer(), _build_config(step_budget=4))

    assert len(results) == len(prompts)
def test_rollout_episode_terminates(monkeypatch) -> None:
    """The episode never exceeds the step budget.

    The model emits 20 QUERYs before any ANSWER, and the env's ``done`` flag
    only flips after 50 steps, so the env never ends the episode within the
    budget of 5. The assertion therefore checks that the rollout enforces
    the budget itself.
    """
    budget = 5
    fake_env = FakeEnvironment(step_budget=budget, done_after=50)
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env)

    model = FakeModel(outputs=["QUERY: SELECT 1"] * 20)
    results = rollout_func(["q1"], model, FakeTokenizer(), _build_config(step_budget=budget))

    assert results[0]["metadata"]["step_count"] <= budget
def test_rollout_metadata_present(monkeypatch) -> None:
    """Results expose reward components plus episode metadata."""
    fake_env = FakeEnvironment(step_budget=3)
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env)

    model = FakeModel(outputs=["ANSWER: 42"])
    result = rollout_func(["q1"], model, FakeTokenizer(), _build_config(step_budget=3))[0]

    for key in ("correct", "progress", "operational"):
        assert key in result
    metadata = result["metadata"]
    for key in ("episode_id", "step_count", "done"):
        assert key in metadata
def test_rollout_unparseable_action(monkeypatch) -> None:
    """Unparseable model text is forwarded to the env as a raw QUERY."""
    garbage = "hello world random text"
    fake_env = FakeEnvironment(step_budget=3)
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env)

    model = FakeModel(outputs=[garbage, "ANSWER: 42"])
    results = rollout_func(["q1"], model, FakeTokenizer(), _build_config(step_budget=3))

    assert len(results) == 1
    first_action = fake_env.actions[0]
    assert first_action.action_type == "QUERY"
    assert first_action.argument == garbage
def test_rollout_truncation(monkeypatch) -> None:
    """Long episodes do not keep feeding an ever-growing history to the template."""
    budget = 20
    tokenizer = FakeTokenizer()
    fake_env = FakeEnvironment(step_budget=budget, done_after=budget)
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env)

    model = FakeModel(outputs=["QUERY: SELECT 1"] * budget)
    rollout_func(["q1"], model, tokenizer, _build_config(step_budget=budget))

    assert tokenizer.messages_seen
    # From the seventh generation onward, at least one prompt must have been
    # cut down to 8 messages or fewer.
    # NOTE(review): ``any`` passes if truncation happens on a single later
    # turn; ``all`` would pin a persistent window, but the window size lives
    # in the rollout module (not visible here) — confirm before tightening.
    later_turns = tokenizer.messages_seen[6:]
    assert any(len(turn) <= 8 for turn in later_turns)
def test_rollout_uses_hf_style_generate(monkeypatch) -> None:
    """A model exposing ``generate(**kwargs)`` with token ids is supported.

    ``HFTokenizer.decode`` only yields an ANSWER for exactly ``[4, 5, 6]``, so
    the first action is an ANSWER only if the rollout decodes just the tokens
    that follow the three prompt ids.
    """
    fake_env = FakeEnvironment(step_budget=2)
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env)

    results = rollout_func(["q1"], HFModel(), HFTokenizer(), _build_config(step_budget=2))
    result = results[0]

    assert result["correct"] is True
    assert fake_env.actions[0].action_type == "ANSWER"
def test_rollout_binds_environment_to_prompt_when_available(monkeypatch) -> None:
    """When the env lists questions, the episode uses the one matching the prompt.

    The fake env records ``questions[0]`` on reset, so seeing ``"q2"`` (not the
    originally-first ``"q1"``) shows the rollout selected the prompt's question.
    """
    questions = [SimpleNamespace(question_text=text) for text in ("q1", "q2")]
    fake_env = FakeEnvironment(step_budget=1, questions=questions)
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env)

    model = FakeModel(outputs=["ANSWER: 42"])
    rollout_func(["q2"], model, FakeTokenizer(), _build_config(step_budget=1))

    assert fake_env.last_reset_question_text == "q2"
def test_rollout_incorrect_answer_not_marked_correct(monkeypatch) -> None:
    """An ANSWER the env reports as incorrect yields ``correct is False``."""
    fake_env = FakeEnvironment(step_budget=1, answer_is_correct=False)
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env)

    model = FakeModel(outputs=["ANSWER: 42"])
    results = rollout_func(["q1"], model, FakeTokenizer(), _build_config(step_budget=1))

    assert results[0]["correct"] is False
def test_rollout_handles_tensor_like_generate_outputs(monkeypatch) -> None:
    """Tokenizer and model outputs that are tensor look-alikes (``.tolist()``) work."""
    fake_env = FakeEnvironment(step_budget=2)
    monkeypatch.setattr(rollout_module, "_build_environment", lambda *_: fake_env)

    results = rollout_func(
        ["q1"], HFTensorModel(), HFTensorTokenizer(), _build_config(step_budget=2)
    )
    result = results[0]

    assert result["correct"] is True
    assert fake_env.actions[0].action_type == "ANSWER"