"""Smoke tests for the structured SQL environment loop."""

import json
import sqlite3

import pytest
import torch

from sql_env.client import SQLEnvClient
from sql_env.models import SQLAction, SQLObservation, SQLState
from sql_env.server.sql_environment import SQLEnvironment
from sql_env.server.test_sql_env import MockTokenizer


@pytest.fixture
def environment_paths(tmp_path):
    """Create a temporary SQLite database and a matching questions file.

    The database lives at ``<tmp_path>/databases/testdb/testdb.sqlite`` and
    holds two tables: ``departments`` (2 rows) and ``employees`` (25 rows).
    The questions file contains two questions that both target ``testdb``.

    Returns:
        A ``(questions_path, db_root)`` tuple of strings.
    """
    db_id = "testdb"
    db_root = tmp_path / "databases"
    db_dir = db_root / db_id
    db_dir.mkdir(parents=True)
    db_path = db_dir / f"{db_id}.sqlite"

    connection = sqlite3.connect(db_path)
    # Close the handle even when seeding fails, so a broken fixture does not
    # leave an open connection behind.
    try:
        cursor = connection.cursor()
        cursor.execute(
            "CREATE TABLE employees (id INTEGER PRIMARY KEY, name TEXT, dept TEXT)"
        )
        cursor.execute(
            "CREATE TABLE departments (id INTEGER PRIMARY KEY, name TEXT)"
        )
        cursor.executemany(
            "INSERT INTO departments (id, name) VALUES (?, ?)",
            [(1, "engineering"), (2, "sales")],
        )
        # 25 rows: more than the 20 rows asserted by the truncation test.
        cursor.executemany(
            "INSERT INTO employees (id, name, dept) VALUES (?, ?, ?)",
            [(idx, f"emp-{idx}", "engineering") for idx in range(1, 26)],
        )
        connection.commit()
    finally:
        connection.close()

    questions_path = tmp_path / "questions.json"
    questions = [
        {
            "question": "How many employees are there?",
            "db_id": db_id,
            "query": "SELECT COUNT(*) FROM employees",
        },
        {
            "question": "How many departments are there?",
            "db_id": db_id,
            "query": "SELECT COUNT(*) FROM departments",
        },
    ]
    questions_path.write_text(json.dumps(questions), encoding="utf-8")

    return str(questions_path), str(db_root)


@pytest.fixture
def env(environment_paths):
    """Return an ``SQLEnvironment`` wired to the temporary fixture data."""
    questions_path, db_dir = environment_paths
    return SQLEnvironment(
        questions_path=questions_path,
        db_dir=db_dir,
        tokenizer=MockTokenizer(),
    )


class TestModels:
    """Construction checks for the action, observation and state models."""

    def test_action_creation(self):
        action = SQLAction(action_type="DESCRIBE", argument="employees")

        assert action.action_type == "DESCRIBE"
        assert action.argument == "employees"

    def test_observation_creation(self):
        observation = SQLObservation(
            question="How many employees are there?",
            schema_info="Available tables:\n- employees",
            result="",
            error="",
            step_count=0,
            budget_remaining=15,
            action_history=[],
            done=False,
            reward=None,
        )

        assert observation.done is False
        assert observation.reward is None
        assert observation.question.startswith("How many")

    def test_state_defaults(self):
        state = SQLState()

        assert state.history_messages == []
        assert state.history_tokens == []
        assert state.current_action_type == "QUERY"


class TestEnvironment:
    """Behavioural checks for ``reset``/``step`` on ``SQLEnvironment``."""

    def test_init_loads_questions(self, env):
        assert len(env.questions) == 2
        assert env.step_budget == 15

    def test_reset_returns_rich_observation(self, env):
        observation = env.reset(seed=42)

        assert isinstance(observation, SQLObservation)
        assert observation.done is False
        assert observation.reward is None
        assert observation.step_count == 0
        assert observation.budget_remaining == 15
        assert observation.error == ""
        assert observation.action_history == []
        assert "Available tables:" in observation.schema_info
        assert "employees" in observation.schema_info
        # Column details must not appear before a DESCRIBE action.
        assert "name TEXT" not in observation.schema_info

    def test_reset_seed_determinism(self, env):
        first = env.reset(seed=123)
        second = env.reset(seed=123)

        assert first.question == second.question

    def test_step_before_reset_is_graceful(self, env):
        observation = env.step(SQLAction(action_type="QUERY", argument="SELECT 1"))

        assert "No active episode" in observation.error
        assert observation.done is False

    def test_describe_reveals_columns_and_updates_schema(self, env):
        env.reset(seed=42)

        observation = env.step(
            SQLAction(action_type="DESCRIBE", argument="employees")
        )

        assert "Table 'employees' columns:" in observation.result
        assert "- name: TEXT" in observation.result
        assert observation.error == ""
        assert observation.step_count == 1
        assert observation.budget_remaining == 14
        assert observation.reward == pytest.approx(0.015)
        assert "Described tables:" in observation.schema_info
        assert "employees: id INTEGER" in observation.schema_info

    def test_sample_and_query_success(self, env):
        env.reset(seed=42)

        sample_obs = env.step(SQLAction(action_type="SAMPLE", argument="employees"))
        assert "Sample from 'employees':" in sample_obs.result
        assert sample_obs.error == ""
        assert sample_obs.reward == pytest.approx(0.015)

        query_obs = env.step(
            SQLAction(action_type="QUERY", argument="SELECT COUNT(*) FROM employees")
        )
        assert "25" in query_obs.result
        assert query_obs.error == ""
        assert query_obs.reward is not None
        assert query_obs.reward > 0

    def test_query_rejects_non_select(self, env):
        env.reset(seed=42)

        observation = env.step(SQLAction(action_type="QUERY", argument="DROP TABLE x"))

        assert "Only SELECT queries are allowed" in observation.error
        assert observation.step_count == 1
        assert observation.budget_remaining == 14
        assert observation.reward == pytest.approx(-0.005)

    def test_invalid_action_type_consumes_budget(self, env):
        env.reset(seed=42)

        observation = env.step(SQLAction(action_type="HACK", argument="x"))

        assert "Unknown action type" in observation.error
        assert observation.step_count == 1
        assert observation.budget_remaining == 14

    def test_empty_argument_consumes_budget(self, env):
        env.reset(seed=42)

        observation = env.step(SQLAction(action_type="QUERY", argument=" "))

        assert "Argument cannot be empty" in observation.error
        assert observation.step_count == 1
        assert observation.budget_remaining == 14

    def test_answer_ends_episode_without_budget_decrement(self, env):
        env.reset(seed=42)
        before_budget = env._episode.budget

        observation = env.step(SQLAction(action_type="ANSWER", argument="25"))

        assert observation.done is True
        assert observation.reward == 1.0
        assert observation.budget_remaining == before_budget

    def test_step_after_done_is_unchanged(self, env):
        env.reset(seed=42)
        terminal = env.step(SQLAction(action_type="ANSWER", argument="25"))

        again = env.step(SQLAction(action_type="QUERY", argument="SELECT 1"))

        assert again.done is True
        assert again.step_count == terminal.step_count
        assert again.budget_remaining == terminal.budget_remaining

    def test_budget_exhaustion_sets_done_and_zero_reward(self, environment_paths):
        questions_path, db_dir = environment_paths
        # A dedicated environment with a tiny budget so two steps exhaust it.
        budget_env = SQLEnvironment(
            questions_path=questions_path,
            db_dir=db_dir,
            tokenizer=MockTokenizer(),
            step_budget=2,
        )
        budget_env.reset(seed=42)

        first = budget_env.step(
            SQLAction(action_type="DESCRIBE", argument="employees")
        )
        assert first.done is False
        assert first.budget_remaining == 1
        assert first.reward == pytest.approx(0.015)

        second = budget_env.step(SQLAction(action_type="QUERY", argument="SELECT 1"))
        assert second.done is True
        assert second.budget_remaining == 0
        assert second.reward == 0.0

    def test_query_truncates_to_20_rows(self, env):
        env.reset(seed=42)

        observation = env.step(
            SQLAction(action_type="QUERY", argument="SELECT id FROM employees")
        )

        assert "... (truncated to 20 rows)" in observation.result

    def test_query_timeout_returns_error(self, env, monkeypatch):
        env.reset(seed=42)

        def _timeout(*_args, **_kwargs):
            # Simulate a timeout without actually running a slow query.
            raise sqlite3.OperationalError("Query timed out after 5.0 seconds")

        monkeypatch.setattr(env, "_execute_sql", _timeout)

        observation = env.step(
            SQLAction(
                action_type="QUERY",
                argument=(
                    "SELECT e1.id "
                    "FROM employees e1 "
                    "JOIN employees e2 ON 1=1 "
                    "JOIN employees e3 ON 1=1"
                ),
            )
        )

        assert "timed out" in observation.error.lower()

    def test_open_db_connection_is_read_only(self, env):
        connection = env._open_db("testdb")
        # Release the handle even when the expected error is not raised,
        # i.e. when this test fails.
        try:
            with pytest.raises(sqlite3.OperationalError):
                connection.execute(
                    "INSERT INTO departments (id, name) VALUES (3, 'hr')"
                )
        finally:
            connection.close()


class TestMessageToAction:
    """Checks for ``SQLEnvironment.message_to_action`` parsing."""

    def test_parses_prefixed_message(self, env):
        env.reset(seed=42)

        action = env.message_to_action(
            {"role": "user", "content": "DESCRIBE employees"}
        )

        assert action.action_type == "DESCRIBE"
        assert action.argument == "employees"

    def test_defaults_to_query_for_unprefixed_message(self, env):
        env.reset(seed=42)

        action = env.message_to_action(
            {"role": "user", "content": "SELECT COUNT(*) FROM employees"}
        )

        assert action.action_type == "QUERY"
        assert action.argument == "SELECT COUNT(*) FROM employees"

    def test_validates_message_shape(self, env):
        env.reset(seed=42)

        with pytest.raises(ValueError):
            env.message_to_action({"content": "missing role"})

        with pytest.raises(ValueError):
            env.message_to_action({"role": "user"})

        with pytest.raises(ValueError):
            env.message_to_action({"role": "user", "content": None})


class TestClientSerialization:
    """Checks for ``SQLEnvClient`` payload (de)serialization helpers.

    Clients are created with ``__new__`` so ``__init__`` is never run.
    """

    def test_step_payload_serialization(self):
        client = SQLEnvClient.__new__(SQLEnvClient)
        action = SQLAction(action_type="QUERY", argument="SELECT 1")

        payload = client._step_payload(action)

        assert payload["action_type"] == "QUERY"
        assert payload["argument"] == "SELECT 1"
        assert "metadata" in payload

    def test_parse_result_observation_payload(self):
        client = SQLEnvClient.__new__(SQLEnvClient)
        payload = {
            "observation": {
                "question": "How many employees are there?",
                "schema_info": "Available tables:\n- employees",
                "result": "1. 25",
                "error": "",
                "step_count": 1,
                "budget_remaining": 14,
                "action_history": ["QUERY -> 1. 25"],
                "done": False,
                "reward": None,
            },
            "done": False,
            "reward": None,
        }

        result = client._parse_result(payload)

        assert result.observation.question == "How many employees are there?"
        assert result.observation.step_count == 1
        assert result.done is False

    def test_parse_state_deserializes_token_lists(self):
        client = SQLEnvClient.__new__(SQLEnvClient)

        state = client._parse_state(
            {
                "episode_id": "ep-1",
                "step_count": 2,
                "history_messages": [{"role": "user", "content": "hi"}],
                "history_tokens": [[1, 2, 3]],
                "current_action_type": "QUERY",
            }
        )

        assert state.episode_id == "ep-1"
        assert state.step_count == 2
        assert len(state.history_tokens) == 1
        assert torch.equal(state.history_tokens[0], torch.tensor([1, 2, 3]))

    def test_client_message_to_action_infers_action(self):
        client = SQLEnvClient.__new__(SQLEnvClient)

        action = client.message_to_action(
            {"role": "user", "content": "show me sample rows from employees"},
            tokenizer=MockTokenizer(),
        )

        assert action.action_type == "SAMPLE"
        assert "sample" in action.argument.lower()