# sql_env/tests/test_smoke.py
# Uploaded by hjerpe
# Commit message: Upload folder using huggingface_hub
# Commit: 5dd1bb4 (verified)
"""Smoke tests for the structured SQL environment loop."""
import json
import sqlite3
import pytest
import torch
from sql_env.client import SQLEnvClient
from sql_env.models import SQLAction, SQLObservation, SQLState
from sql_env.server.sql_environment import SQLEnvironment
from sql_env.server.test_sql_env import MockTokenizer
@pytest.fixture
def environment_paths(tmp_path):
    """Build a throwaway SQLite database and question file for the tests.

    Creates ``<tmp_path>/databases/testdb/testdb.sqlite`` containing an
    ``employees`` table (25 rows, all with dept "engineering") and a
    ``departments`` table (2 rows), plus ``<tmp_path>/questions.json`` with
    two counting questions that reference that database.

    Args:
        tmp_path: pytest's per-test temporary directory.

    Returns:
        A ``(questions_path, db_root)`` tuple, both as strings.
    """
    db_id = "testdb"
    db_root = tmp_path / "databases"
    db_dir = db_root / db_id
    db_dir.mkdir(parents=True)
    db_path = db_dir / f"{db_id}.sqlite"

    connection = sqlite3.connect(db_path)
    try:
        cursor = connection.cursor()
        cursor.execute(
            "CREATE TABLE employees (id INTEGER PRIMARY KEY, name TEXT, dept TEXT)"
        )
        cursor.execute(
            "CREATE TABLE departments (id INTEGER PRIMARY KEY, name TEXT)"
        )
        cursor.executemany(
            "INSERT INTO departments (id, name) VALUES (?, ?)",
            [(1, "engineering"), (2, "sales")],
        )
        cursor.executemany(
            "INSERT INTO employees (id, name, dept) VALUES (?, ?, ?)",
            [(idx, f"emp-{idx}", "engineering") for idx in range(1, 26)],
        )
        connection.commit()
    finally:
        # Close even when seeding fails so the handle on the temporary
        # database file is never left open.
        connection.close()

    questions_path = tmp_path / "questions.json"
    questions = [
        {
            "question": "How many employees are there?",
            "db_id": db_id,
            "query": "SELECT COUNT(*) FROM employees",
        },
        {
            "question": "How many departments are there?",
            "db_id": db_id,
            "query": "SELECT COUNT(*) FROM departments",
        },
    ]
    questions_path.write_text(json.dumps(questions), encoding="utf-8")
    return str(questions_path), str(db_root)
@pytest.fixture
def env(environment_paths):
    """Return a ``SQLEnvironment`` built on the temporary question/DB fixtures.

    The environment is constructed with a ``MockTokenizer`` and no explicit
    step budget.
    """
    questions_file, database_root = environment_paths
    environment = SQLEnvironment(
        questions_path=questions_file,
        db_dir=database_root,
        tokenizer=MockTokenizer(),
    )
    return environment
class TestModels:
    """Construction checks for the action, observation and state models."""

    def test_action_creation(self):
        """An action exposes the type and argument it was created with."""
        describe = SQLAction(action_type="DESCRIBE", argument="employees")

        assert describe.action_type == "DESCRIBE"
        assert describe.argument == "employees"

    def test_observation_creation(self):
        """An observation built from explicit fields reports them back."""
        fields = {
            "question": "How many employees are there?",
            "schema_info": "Available tables:\n- employees",
            "result": "",
            "error": "",
            "step_count": 0,
            "budget_remaining": 15,
            "action_history": [],
            "done": False,
            "reward": None,
        }
        obs = SQLObservation(**fields)

        assert obs.done is False
        assert obs.reward is None
        assert obs.question.startswith("How many")

    def test_state_defaults(self):
        """A default state has empty histories and a QUERY action type."""
        default_state = SQLState()

        assert default_state.history_messages == []
        assert default_state.history_tokens == []
        assert default_state.current_action_type == "QUERY"
class TestEnvironment:
    """Smoke tests for the ``SQLEnvironment`` reset/step loop.

    Most tests use the ``env`` fixture, whose database holds 25 employees and
    2 departments and whose question file holds two questions.
    """

    def test_init_loads_questions(self, env):
        """Both questions are loaded and the step budget is 15."""
        assert len(env.questions) == 2
        assert env.step_budget == 15

    def test_reset_returns_rich_observation(self, env):
        """``reset`` yields a fresh observation listing tables but not columns."""
        observation = env.reset(seed=42)
        assert isinstance(observation, SQLObservation)
        assert observation.done is False
        assert observation.reward is None
        assert observation.step_count == 0
        assert observation.budget_remaining == 15
        assert observation.error == ""
        assert observation.action_history == []
        assert "Available tables:" in observation.schema_info
        assert "employees" in observation.schema_info
        # Column details must not appear before a table has been described.
        assert "name TEXT" not in observation.schema_info

    def test_reset_seed_determinism(self, env):
        """Resetting twice with the same seed selects the same question."""
        first = env.reset(seed=123)
        second = env.reset(seed=123)
        assert first.question == second.question

    def test_step_before_reset_is_graceful(self, env):
        """Stepping without an episode reports an error instead of raising."""
        observation = env.step(SQLAction(action_type="QUERY", argument="SELECT 1"))
        assert "No active episode" in observation.error
        assert observation.done is False

    def test_describe_reveals_columns_and_updates_schema(self, env):
        """DESCRIBE returns column info, uses one step and updates the schema."""
        env.reset(seed=42)
        observation = env.step(SQLAction(action_type="DESCRIBE", argument="employees"))
        assert "Table 'employees' columns:" in observation.result
        assert "- name: TEXT" in observation.result
        assert observation.error == ""
        assert observation.step_count == 1
        assert observation.budget_remaining == 14
        assert observation.reward == pytest.approx(0.015)
        assert "Described tables:" in observation.schema_info
        assert "employees: id INTEGER" in observation.schema_info

    def test_sample_and_query_success(self, env):
        """SAMPLE then a COUNT query both succeed with a positive reward."""
        env.reset(seed=42)
        sample_obs = env.step(SQLAction(action_type="SAMPLE", argument="employees"))
        assert "Sample from 'employees':" in sample_obs.result
        assert sample_obs.error == ""
        assert sample_obs.reward == pytest.approx(0.015)

        query_obs = env.step(
            SQLAction(action_type="QUERY", argument="SELECT COUNT(*) FROM employees")
        )
        # The fixture inserts 25 employee rows.
        assert "25" in query_obs.result
        assert query_obs.error == ""
        assert query_obs.reward is not None
        assert query_obs.reward > 0

    def test_query_rejects_non_select(self, env):
        """A non-SELECT statement is rejected, costs a step and is penalised."""
        env.reset(seed=42)
        observation = env.step(SQLAction(action_type="QUERY", argument="DROP TABLE x"))
        assert "Only SELECT queries are allowed" in observation.error
        assert observation.step_count == 1
        assert observation.budget_remaining == 14
        assert observation.reward == pytest.approx(-0.005)

    def test_invalid_action_type_consumes_budget(self, env):
        """An unknown action type is an error that still uses one step."""
        env.reset(seed=42)
        observation = env.step(SQLAction(action_type="HACK", argument="x"))
        assert "Unknown action type" in observation.error
        assert observation.step_count == 1
        assert observation.budget_remaining == 14

    def test_empty_argument_consumes_budget(self, env):
        """A whitespace-only argument is an error that still uses one step."""
        env.reset(seed=42)
        observation = env.step(SQLAction(action_type="QUERY", argument=" "))
        assert "Argument cannot be empty" in observation.error
        assert observation.step_count == 1
        assert observation.budget_remaining == 14

    def test_answer_ends_episode_without_budget_decrement(self, env):
        """ANSWER finishes the episode and leaves the budget untouched.

        NOTE(review): expecting reward 1.0 for "25" presumably relies on
        seed 42 selecting the employee-count question - confirm.
        """
        env.reset(seed=42)
        before_budget = env._episode.budget
        observation = env.step(SQLAction(action_type="ANSWER", argument="25"))
        assert observation.done is True
        assert observation.reward == 1.0
        assert observation.budget_remaining == before_budget

    def test_step_after_done_is_unchanged(self, env):
        """Stepping after the episode ended changes no counters."""
        env.reset(seed=42)
        terminal = env.step(SQLAction(action_type="ANSWER", argument="25"))
        again = env.step(SQLAction(action_type="QUERY", argument="SELECT 1"))
        assert again.done is True
        assert again.step_count == terminal.step_count
        assert again.budget_remaining == terminal.budget_remaining

    def test_budget_exhaustion_sets_done_and_zero_reward(self, environment_paths):
        """With a budget of 2, the second step ends the episode with reward 0."""
        questions_path, db_dir = environment_paths
        budget_env = SQLEnvironment(
            questions_path=questions_path,
            db_dir=db_dir,
            tokenizer=MockTokenizer(),
            step_budget=2,
        )
        budget_env.reset(seed=42)

        first = budget_env.step(SQLAction(action_type="DESCRIBE", argument="employees"))
        assert first.done is False
        assert first.budget_remaining == 1
        assert first.reward == pytest.approx(0.015)

        second = budget_env.step(SQLAction(action_type="QUERY", argument="SELECT 1"))
        assert second.done is True
        assert second.budget_remaining == 0
        assert second.reward == 0.0

    def test_query_truncates_to_20_rows(self, env):
        """A query returning all 25 employee ids is reported as truncated."""
        env.reset(seed=42)
        observation = env.step(
            SQLAction(action_type="QUERY", argument="SELECT id FROM employees")
        )
        assert "... (truncated to 20 rows)" in observation.result

    def test_query_timeout_returns_error(self, env, monkeypatch):
        """A timeout raised by ``_execute_sql`` surfaces as an observation error."""
        env.reset(seed=42)

        def _timeout(*_args, **_kwargs):
            # Stand-in for the executor: ignore the call and simulate a timeout.
            raise sqlite3.OperationalError("Query timed out after 5.0 seconds")

        monkeypatch.setattr(env, "_execute_sql", _timeout)
        observation = env.step(
            SQLAction(
                action_type="QUERY",
                argument=(
                    "SELECT e1.id "
                    "FROM employees e1 "
                    "JOIN employees e2 ON 1=1 "
                    "JOIN employees e3 ON 1=1"
                ),
            )
        )
        assert "timed out" in observation.error.lower()

    def test_open_db_connection_is_read_only(self, env):
        """Writing through a connection from ``_open_db`` raises."""
        connection = env._open_db("testdb")
        try:
            with pytest.raises(sqlite3.OperationalError):
                connection.execute(
                    "INSERT INTO departments (id, name) VALUES (3, 'hr')"
                )
        finally:
            # Close even when the expectation above fails, so a failing test
            # does not leak the database handle.
            connection.close()
class TestMessageToAction:
    """Checks for converting chat-style messages into ``SQLAction`` objects."""

    def test_parses_prefixed_message(self, env):
        """A leading action keyword becomes the type; the rest is the argument."""
        env.reset(seed=42)
        message = {"role": "user", "content": "DESCRIBE employees"}

        parsed = env.message_to_action(message)

        assert parsed.action_type == "DESCRIBE"
        assert parsed.argument == "employees"

    def test_defaults_to_query_for_unprefixed_message(self, env):
        """Content without an action keyword is treated as a QUERY verbatim."""
        env.reset(seed=42)
        sql_text = "SELECT COUNT(*) FROM employees"

        parsed = env.message_to_action({"role": "user", "content": sql_text})

        assert parsed.action_type == "QUERY"
        assert parsed.argument == "SELECT COUNT(*) FROM employees"

    def test_validates_message_shape(self, env):
        """Missing role, missing content and ``None`` content raise ValueError."""
        env.reset(seed=42)
        malformed_messages = (
            {"content": "missing role"},
            {"role": "user"},
            {"role": "user", "content": None},
        )
        # Checked in order; the first message that fails to raise fails the test.
        for bad_message in malformed_messages:
            with pytest.raises(ValueError):
                env.message_to_action(bad_message)
class TestClientSerialization:
    """Checks for the client's payload building and response parsing helpers."""

    @staticmethod
    def _bare_client():
        """Create a client instance without running ``SQLEnvClient.__init__``."""
        return SQLEnvClient.__new__(SQLEnvClient)

    def test_step_payload_serialization(self):
        """The step payload carries the action fields and a metadata entry."""
        query_action = SQLAction(action_type="QUERY", argument="SELECT 1")

        serialized = self._bare_client()._step_payload(query_action)

        assert serialized["action_type"] == "QUERY"
        assert serialized["argument"] == "SELECT 1"
        assert "metadata" in serialized

    def test_parse_result_observation_payload(self):
        """A result payload is parsed into an observation and a done flag."""
        observation_fields = {
            "question": "How many employees are there?",
            "schema_info": "Available tables:\n- employees",
            "result": "1. 25",
            "error": "",
            "step_count": 1,
            "budget_remaining": 14,
            "action_history": ["QUERY -> 1. 25"],
            "done": False,
            "reward": None,
        }
        response = {
            "observation": observation_fields,
            "done": False,
            "reward": None,
        }

        parsed = self._bare_client()._parse_result(response)

        assert parsed.observation.question == "How many employees are there?"
        assert parsed.observation.step_count == 1
        assert parsed.done is False

    def test_parse_state_deserializes_token_lists(self):
        """Token lists in a state payload come back as tensors."""
        raw_state = {
            "episode_id": "ep-1",
            "step_count": 2,
            "history_messages": [{"role": "user", "content": "hi"}],
            "history_tokens": [[1, 2, 3]],
            "current_action_type": "QUERY",
        }

        parsed_state = self._bare_client()._parse_state(raw_state)

        assert parsed_state.episode_id == "ep-1"
        assert parsed_state.step_count == 2
        assert len(parsed_state.history_tokens) == 1
        expected_tokens = torch.tensor([1, 2, 3])
        assert torch.equal(parsed_state.history_tokens[0], expected_tokens)

    def test_client_message_to_action_infers_action(self):
        """A free-form request mentioning sample rows maps to a SAMPLE action."""
        request = {"role": "user", "content": "show me sample rows from employees"}

        inferred = self._bare_client().message_to_action(
            request,
            tokenizer=MockTokenizer(),
        )

        assert inferred.action_type == "SAMPLE"
        assert "sample" in inferred.argument.lower()