| """Smoke tests for the structured SQL environment loop.""" |
|
|
| import json |
| import sqlite3 |
|
|
| import pytest |
| import torch |
|
|
| from sql_env.client import SQLEnvClient |
| from sql_env.models import SQLAction, SQLObservation, SQLState |
| from sql_env.server.sql_environment import SQLEnvironment |
| from sql_env.server.test_sql_env import MockTokenizer |
|
|
|
|
@pytest.fixture
def environment_paths(tmp_path):
    """Create an on-disk SQLite database and a questions file for one test.

    Layout written under ``tmp_path``::

        databases/testdb/testdb.sqlite
        questions.json

    The database contains an ``employees`` table (25 rows, every row in the
    ``engineering`` dept) and a ``departments`` table (2 rows). The questions
    file is a JSON list with two entries, both targeting ``db_id="testdb"``.

    Returns:
        ``(questions_path, db_root)`` as strings. ``db_root`` is the
        ``databases`` directory, which holds one sub-directory per ``db_id``.
    """
    db_id = "testdb"
    db_root = tmp_path / "databases"
    db_dir = db_root / db_id
    db_dir.mkdir(parents=True)
    db_path = db_dir / f"{db_id}.sqlite"

    connection = sqlite3.connect(db_path)
    # Close in ``finally`` so a failing statement does not leak the handle;
    # ``sqlite3.Connection`` used as a context manager would only commit or
    # roll back, not close.
    try:
        cursor = connection.cursor()
        cursor.execute(
            "CREATE TABLE employees (id INTEGER PRIMARY KEY, name TEXT, dept TEXT)"
        )
        cursor.execute(
            "CREATE TABLE departments (id INTEGER PRIMARY KEY, name TEXT)"
        )
        cursor.executemany(
            "INSERT INTO departments (id, name) VALUES (?, ?)",
            [(1, "engineering"), (2, "sales")],
        )
        # 25 employees: enough to exceed the 20-row truncation limit that
        # the environment tests check for.
        cursor.executemany(
            "INSERT INTO employees (id, name, dept) VALUES (?, ?, ?)",
            [(idx, f"emp-{idx}", "engineering") for idx in range(1, 26)],
        )
        connection.commit()
    finally:
        connection.close()

    questions_path = tmp_path / "questions.json"
    questions = [
        {
            "question": "How many employees are there?",
            "db_id": db_id,
            "query": "SELECT COUNT(*) FROM employees",
        },
        {
            "question": "How many departments are there?",
            "db_id": db_id,
            "query": "SELECT COUNT(*) FROM departments",
        },
    ]
    questions_path.write_text(json.dumps(questions), encoding="utf-8")

    return str(questions_path), str(db_root)
|
|
|
|
@pytest.fixture
def env(environment_paths):
    """Return a ``SQLEnvironment`` backed by the temporary fixture data.

    No ``step_budget`` is passed, so the environment's default budget applies.
    A ``MockTokenizer`` stands in for a real tokenizer.
    """
    questions_file, databases_root = environment_paths
    environment = SQLEnvironment(
        questions_path=questions_file,
        db_dir=databases_root,
        tokenizer=MockTokenizer(),
    )
    return environment
|
|
|
|
class TestModels:
    """Construction and default values of the environment's data models."""

    def test_action_creation(self):
        describe = SQLAction(action_type="DESCRIBE", argument="employees")
        assert describe.action_type == "DESCRIBE"
        assert describe.argument == "employees"

    def test_observation_creation(self):
        field_values = dict(
            question="How many employees are there?",
            schema_info="Available tables:\n- employees",
            result="",
            error="",
            step_count=0,
            budget_remaining=15,
            action_history=[],
            done=False,
            reward=None,
        )
        obs = SQLObservation(**field_values)
        assert obs.done is False
        assert obs.reward is None
        assert obs.question.startswith("How many")

    def test_state_defaults(self):
        default_state = SQLState()
        # A freshly constructed state carries empty histories and defaults
        # to the QUERY action type.
        assert default_state.history_messages == []
        assert default_state.history_tokens == []
        assert default_state.current_action_type == "QUERY"
|
|
|
|
class TestEnvironment:
    """Episode lifecycle, action handling and safety checks of ``SQLEnvironment``."""

    def test_init_loads_questions(self, env):
        # The fixture writes two questions and does not override the budget.
        assert len(env.questions) == 2
        assert env.step_budget == 15

    def test_reset_returns_rich_observation(self, env):
        observation = env.reset(seed=42)
        assert isinstance(observation, SQLObservation)
        assert observation.done is False
        assert observation.reward is None
        assert observation.step_count == 0
        assert observation.budget_remaining == 15
        assert observation.error == ""
        assert observation.action_history == []
        assert "Available tables:" in observation.schema_info
        assert "employees" in observation.schema_info
        # Column types are not listed at reset time; the DESCRIBE test below
        # checks that they appear in schema_info after describing a table.
        assert "name TEXT" not in observation.schema_info

    def test_reset_seed_determinism(self, env):
        first = env.reset(seed=123)
        second = env.reset(seed=123)
        assert first.question == second.question

    def test_step_before_reset_is_graceful(self, env):
        # Stepping without an active episode reports an error instead of raising.
        observation = env.step(SQLAction(action_type="QUERY", argument="SELECT 1"))
        assert "No active episode" in observation.error
        assert observation.done is False

    def test_describe_reveals_columns_and_updates_schema(self, env):
        env.reset(seed=42)
        observation = env.step(SQLAction(action_type="DESCRIBE", argument="employees"))
        assert "Table 'employees' columns:" in observation.result
        assert "- name: TEXT" in observation.result
        assert observation.error == ""
        assert observation.step_count == 1
        assert observation.budget_remaining == 14
        assert observation.reward == pytest.approx(0.015)
        assert "Described tables:" in observation.schema_info
        assert "employees: id INTEGER" in observation.schema_info

    def test_sample_and_query_success(self, env):
        env.reset(seed=42)
        sample_obs = env.step(SQLAction(action_type="SAMPLE", argument="employees"))
        assert "Sample from 'employees':" in sample_obs.result
        assert sample_obs.error == ""
        assert sample_obs.reward == pytest.approx(0.015)

        query_obs = env.step(
            SQLAction(action_type="QUERY", argument="SELECT COUNT(*) FROM employees")
        )
        # The fixture inserts exactly 25 employees.
        assert "25" in query_obs.result
        assert query_obs.error == ""
        assert query_obs.reward is not None
        assert query_obs.reward > 0

    def test_query_rejects_non_select(self, env):
        env.reset(seed=42)
        observation = env.step(SQLAction(action_type="QUERY", argument="DROP TABLE x"))
        assert "Only SELECT queries are allowed" in observation.error
        # A rejected query still costs one step and yields a small negative reward.
        assert observation.step_count == 1
        assert observation.budget_remaining == 14
        assert observation.reward == pytest.approx(-0.005)

    def test_invalid_action_type_consumes_budget(self, env):
        env.reset(seed=42)
        observation = env.step(SQLAction(action_type="HACK", argument="x"))
        assert "Unknown action type" in observation.error
        assert observation.step_count == 1
        assert observation.budget_remaining == 14

    def test_empty_argument_consumes_budget(self, env):
        env.reset(seed=42)
        # Whitespace-only arguments count as empty.
        observation = env.step(SQLAction(action_type="QUERY", argument=" "))
        assert "Argument cannot be empty" in observation.error
        assert observation.step_count == 1
        assert observation.budget_remaining == 14

    def test_answer_ends_episode_without_budget_decrement(self, env):
        env.reset(seed=42)
        before_budget = env._episode.budget
        observation = env.step(SQLAction(action_type="ANSWER", argument="25"))
        assert observation.done is True
        assert observation.reward == 1.0
        assert observation.budget_remaining == before_budget

    def test_step_after_done_is_unchanged(self, env):
        env.reset(seed=42)
        terminal = env.step(SQLAction(action_type="ANSWER", argument="25"))
        again = env.step(SQLAction(action_type="QUERY", argument="SELECT 1"))
        # Steps after a terminal observation must not advance the episode.
        assert again.done is True
        assert again.step_count == terminal.step_count
        assert again.budget_remaining == terminal.budget_remaining

    def test_budget_exhaustion_sets_done_and_zero_reward(self, environment_paths):
        questions_path, db_dir = environment_paths
        budget_env = SQLEnvironment(
            questions_path=questions_path,
            db_dir=db_dir,
            tokenizer=MockTokenizer(),
            step_budget=2,
        )
        budget_env.reset(seed=42)

        first = budget_env.step(SQLAction(action_type="DESCRIBE", argument="employees"))
        assert first.done is False
        assert first.budget_remaining == 1
        assert first.reward == pytest.approx(0.015)

        # The step that spends the last unit of budget terminates the episode.
        second = budget_env.step(SQLAction(action_type="QUERY", argument="SELECT 1"))
        assert second.done is True
        assert second.budget_remaining == 0
        assert second.reward == 0.0

    def test_query_truncates_to_20_rows(self, env):
        env.reset(seed=42)
        # The fixture has 25 employees, so a full scan exceeds the 20-row cap.
        observation = env.step(
            SQLAction(action_type="QUERY", argument="SELECT id FROM employees")
        )
        assert "... (truncated to 20 rows)" in observation.result

    def test_query_timeout_returns_error(self, env, monkeypatch):
        env.reset(seed=42)

        def _timeout(*_args, **_kwargs):
            raise sqlite3.OperationalError("Query timed out after 5.0 seconds")

        # Simulate the executor timing out instead of running a genuinely
        # slow query, so the test stays fast and deterministic.
        monkeypatch.setattr(env, "_execute_sql", _timeout)

        observation = env.step(
            SQLAction(
                action_type="QUERY",
                argument=(
                    "SELECT e1.id "
                    "FROM employees e1 "
                    "JOIN employees e2 ON 1=1 "
                    "JOIN employees e3 ON 1=1"
                ),
            )
        )
        assert "timed out" in observation.error.lower()

    def test_open_db_connection_is_read_only(self, env):
        connection = env._open_db("testdb")
        try:
            with pytest.raises(sqlite3.OperationalError):
                connection.execute(
                    "INSERT INTO departments (id, name) VALUES (3, 'hr')"
                )
        finally:
            # Close even when the write unexpectedly succeeds and
            # pytest.raises fails the test.
            connection.close()
|
|
|
|
class TestMessageToAction:
    """Conversion of chat-style messages into ``SQLAction`` objects."""

    def test_parses_prefixed_message(self, env):
        env.reset(seed=42)
        message = {"role": "user", "content": "DESCRIBE employees"}
        parsed = env.message_to_action(message)
        assert parsed.action_type == "DESCRIBE"
        assert parsed.argument == "employees"

    def test_defaults_to_query_for_unprefixed_message(self, env):
        env.reset(seed=42)
        sql_text = "SELECT COUNT(*) FROM employees"
        parsed = env.message_to_action({"role": "user", "content": sql_text})
        assert parsed.action_type == "QUERY"
        assert parsed.argument == sql_text

    def test_validates_message_shape(self, env):
        env.reset(seed=42)
        # Missing role, missing content, and a None content must all be rejected.
        malformed_messages = (
            {"content": "missing role"},
            {"role": "user"},
            {"role": "user", "content": None},
        )
        for bad_message in malformed_messages:
            with pytest.raises(ValueError):
                env.message_to_action(bad_message)
|
|
|
|
class TestClientSerialization:
    """Client-side payload building and parsing, exercised without ``__init__``."""

    @staticmethod
    def _bare_client():
        # Bypass __init__ (presumably it sets up transport state) so only
        # the pure (de)serialization helpers are exercised.
        return SQLEnvClient.__new__(SQLEnvClient)

    def test_step_payload_serialization(self):
        client = self._bare_client()
        query_action = SQLAction(action_type="QUERY", argument="SELECT 1")
        serialized = client._step_payload(query_action)
        assert serialized["action_type"] == "QUERY"
        assert serialized["argument"] == "SELECT 1"
        assert "metadata" in serialized

    def test_parse_result_observation_payload(self):
        client = self._bare_client()
        observation_fields = {
            "question": "How many employees are there?",
            "schema_info": "Available tables:\n- employees",
            "result": "1. 25",
            "error": "",
            "step_count": 1,
            "budget_remaining": 14,
            "action_history": ["QUERY -> 1. 25"],
            "done": False,
            "reward": None,
        }
        parsed = client._parse_result(
            {"observation": observation_fields, "done": False, "reward": None}
        )
        assert parsed.observation.question == "How many employees are there?"
        assert parsed.observation.step_count == 1
        assert parsed.done is False

    def test_parse_state_deserializes_token_lists(self):
        client = self._bare_client()
        raw_state = {
            "episode_id": "ep-1",
            "step_count": 2,
            "history_messages": [{"role": "user", "content": "hi"}],
            "history_tokens": [[1, 2, 3]],
            "current_action_type": "QUERY",
        }
        parsed_state = client._parse_state(raw_state)
        assert parsed_state.episode_id == "ep-1"
        assert parsed_state.step_count == 2
        assert len(parsed_state.history_tokens) == 1
        # Nested token lists come back as tensors.
        expected_tokens = torch.tensor([1, 2, 3])
        assert torch.equal(parsed_state.history_tokens[0], expected_tokens)

    def test_client_message_to_action_infers_action(self):
        client = self._bare_client()
        inferred = client.message_to_action(
            {"role": "user", "content": "show me sample rows from employees"},
            tokenizer=MockTokenizer(),
        )
        assert inferred.action_type == "SAMPLE"
        assert "sample" in inferred.argument.lower()
|
|