# sql_env/tests/test_evaluation.py
# Hosting-page residue kept for provenance: "hjerpe's picture";
# commit message "Upload folder using huggingface_hub"; revision 5dd1bb4 (verified).
"""Unit tests for evaluation package random policy and evaluate()."""
import json
import sqlite3
import pytest
from sql_env.evaluation import RandomPolicy, evaluate
from sql_env.models import SQLAction, SQLObservation
from sql_env.server.sql_environment import SQLEnvironment
from sql_env.server.test_sql_env import MockTokenizer
def _build_sql_environment(tmp_path, *, db_id: str) -> SQLEnvironment:
    """Create an on-disk SQLite fixture and return an environment that uses it.

    The helper writes ``<tmp_path>/databases/<db_id>/<db_id>.sqlite`` containing
    an ``employees`` table with three rows, writes ``<tmp_path>/questions.json``
    with a single counting question for that database, and builds an
    ``SQLEnvironment`` pointing at both, using ``MockTokenizer``.

    Args:
        tmp_path: Base directory object supporting ``/`` and ``mkdir`` (for
            example the pytest ``tmp_path`` fixture or a child of it). Missing
            parents are created.
        db_id: Name used for the database directory, the ``.sqlite`` file, and
            the ``db_id`` field of the question entry.

    Returns:
        The environment configured with the generated question file and
        database root.

    Raises:
        FileExistsError: If the database directory already exists, since it is
            created without ``exist_ok``.
    """
    db_root = tmp_path / "databases"
    db_dir = db_root / db_id
    db_dir.mkdir(parents=True)
    db_path = db_dir / f"{db_id}.sqlite"

    connection = sqlite3.connect(db_path)
    try:
        cursor = connection.cursor()
        cursor.execute(
            "CREATE TABLE employees (id INTEGER PRIMARY KEY, name TEXT, dept TEXT)"
        )
        cursor.executemany(
            "INSERT INTO employees (id, name, dept) VALUES (?, ?, ?)",
            [
                (1, "Alice", "engineering"),
                (2, "Bob", "engineering"),
                (3, "Cara", "sales"),
            ],
        )
        connection.commit()
    finally:
        # Close even when seeding fails so the database file handle is not
        # leaked into the rest of the test session.
        connection.close()

    questions_path = tmp_path / "questions.json"
    questions_path.write_text(
        json.dumps(
            [
                {
                    "question": "How many employees are there?",
                    "db_id": db_id,
                    "query": "SELECT COUNT(*) FROM employees",
                }
            ]
        ),
        encoding="utf-8",
    )

    return SQLEnvironment(
        questions_path=str(questions_path),
        db_dir=str(db_root),
        tokenizer=MockTokenizer(),
    )
def _build_observation(*, budget_remaining: int, result: str = "") -> SQLObservation:
    """Return a non-terminal observation at step 0 with no reward.

    Args:
        budget_remaining: Value stored in the observation's ``budget_remaining``.
        result: Value stored in the observation's ``result``; empty by default.
    """
    schema_text = "Available tables:\n- employees\n- departments"
    observation = SQLObservation(
        question="How many rows?",
        schema_info=schema_text,
        result=result,
        error="",
        step_count=0,
        budget_remaining=budget_remaining,
        action_history=[],
        done=False,
        reward=None,
    )
    return observation
def _terminal_observation(*, reward: float) -> SQLObservation:
    """Return a finished observation (``done=True``, no budget left) carrying ``reward``.

    Args:
        reward: Value stored in the observation's ``reward``.
    """
    schema_text = "Available tables:\n- employees\n- departments"
    finished = SQLObservation(
        question="How many rows?",
        schema_info=schema_text,
        result="",
        error="",
        step_count=1,
        budget_remaining=0,
        action_history=[],
        done=True,
        reward=reward,
    )
    return finished
class _FixedPolicy:
    """Policy double that returns the same QUERY action for every observation."""

    def select_action(self, observation: SQLObservation) -> SQLAction:
        """Return a ``QUERY`` action with argument ``SELECT 1``; ``observation`` is unused."""
        constant_action = SQLAction(action_type="QUERY", argument="SELECT 1")
        return constant_action
class _RaisingPolicy:
    """Policy double that raises during one chosen episode and otherwise queries.

    Episodes are counted by watching for observations whose ``step_count`` is 0.
    """

    def __init__(self, fail_on_episode: int) -> None:
        """Remember the zero-based episode index on which to raise."""
        self._episode_index = -1  # becomes 0 on the first step-0 observation
        self._fail_on_episode = fail_on_episode

    def select_action(self, observation: SQLObservation) -> SQLAction:
        """Return a ``SELECT 1`` query, or raise in the configured episode.

        Raises:
            RuntimeError: With message ``"policy failed"`` whenever the current
                episode index equals ``fail_on_episode``.
        """
        starts_new_episode = observation.step_count == 0
        if starts_new_episode:
            self._episode_index = self._episode_index + 1

        if self._episode_index != self._fail_on_episode:
            return SQLAction(action_type="QUERY", argument="SELECT 1")
        raise RuntimeError("policy failed")
class _SeedTrackingEnv:
    """Environment double that logs reset seeds and ends each episode after one step.

    ``reset`` records the seed it receives and advances an episode counter;
    ``step`` returns a terminal observation whose reward is the entry of
    ``rewards`` at the current episode index.
    """

    def __init__(self, rewards: list[float]) -> None:
        """Store the per-episode rewards and start with an empty seed log."""
        self.reset_seeds: list[int | None] = []
        self._episode_index = -1  # the first reset() moves this to 0
        self._rewards = rewards

    def reset(self, *, seed: int | None = None) -> SQLObservation:
        """Log ``seed``, move to the next episode, and return a fresh observation."""
        self._episode_index += 1
        self.reset_seeds.append(seed)
        return _build_observation(budget_remaining=2)

    def step(self, action: SQLAction) -> SQLObservation:
        """Discard ``action`` and finish the episode with its scripted reward."""
        del action  # the outcome of this double never depends on the action
        episode_reward = self._rewards[self._episode_index]
        return _terminal_observation(reward=episode_reward)
class _FlakyEnv(_SeedTrackingEnv):
    """Seed-tracking environment whose ``step`` raises during one chosen episode."""

    def __init__(self, rewards: list[float], fail_on_episode: int) -> None:
        """Set up the parent double and remember which episode index should fail."""
        super().__init__(rewards)
        self._fail_on_episode = fail_on_episode

    def step(self, action: SQLAction) -> SQLObservation:
        """Delegate to the parent ``step`` unless the current episode is the failing one.

        Raises:
            RuntimeError: With message ``"step failed"`` when the current
                episode index equals ``fail_on_episode``.
        """
        if self._episode_index != self._fail_on_episode:
            return super().step(action)
        raise RuntimeError("step failed")
def test_random_policy_explores_when_budget_gt_one() -> None:
    """With a budget of 10, the chosen action type is DESCRIBE, SAMPLE, or QUERY."""
    exploratory_types = {"DESCRIBE", "SAMPLE", "QUERY"}
    chosen = RandomPolicy(seed=42).select_action(
        _build_observation(budget_remaining=10)
    )
    assert chosen.action_type in exploratory_types
def test_random_policy_answers_when_budget_eq_one() -> None:
    """With a budget of exactly 1, the chosen action type is ANSWER."""
    chosen = RandomPolicy(seed=42).select_action(
        _build_observation(budget_remaining=1)
    )
    assert chosen.action_type == "ANSWER"
def test_random_policy_returns_sql_action() -> None:
    """select_action() returns an SQLAction instance."""
    chosen = RandomPolicy(seed=7).select_action(
        _build_observation(budget_remaining=10)
    )
    assert isinstance(chosen, SQLAction)
def test_random_policy_deterministic_with_seed() -> None:
    """Two policies built with the same seed yield the same 25-action sequence."""
    observation = _build_observation(budget_remaining=10)
    # Both policies are constructed before either one is sampled.
    policy_a = RandomPolicy(seed=123)
    policy_b = RandomPolicy(seed=123)

    actions_a = []
    for _ in range(25):
        actions_a.append(policy_a.select_action(observation))

    actions_b = []
    for _ in range(25):
        actions_b.append(policy_b.select_action(observation))

    assert actions_a == actions_b
def test_random_policy_explores_all_action_types() -> None:
    """Over 200 draws with spare budget, exactly DESCRIBE, SAMPLE, and QUERY appear."""
    policy = RandomPolicy(seed=1)
    observation = _build_observation(budget_remaining=10)

    seen_types = set()
    for _ in range(200):
        seen_types.add(policy.select_action(observation).action_type)

    assert seen_types == {"DESCRIBE", "SAMPLE", "QUERY"}
def test_random_policy_uses_result_rows_for_answer_candidates() -> None:
    """On the last budget step, the ANSWER argument comes from the result text.

    Accepted arguments are the individual cell values and the full row strings
    of the two-row result.
    """
    acceptable_answers = {
        "engineering",
        "25",
        "sales",
        "10",
        "engineering | 25",
        "sales | 10",
    }
    last_step = _build_observation(
        budget_remaining=1,
        result="1. engineering | 25\n2. sales | 10",
    )

    chosen = RandomPolicy(seed=0).select_action(last_step)

    assert chosen.action_type == "ANSWER"
    assert chosen.argument in acceptable_answers
def test_evaluate_happy_path() -> None:
    """Three one-step episodes with rewards 1, 0, 1 produce the expected aggregates."""
    env = _SeedTrackingEnv([1.0, 0.0, 1.0])
    result = evaluate(env, _FixedPolicy(), n_episodes=3)

    assert result.n_episodes == 3
    assert result.n_completed == 3
    assert len(result.episodes) == 3
    # Compare computed averages with a tolerance so the test does not depend
    # on the exact floating-point summation order used by evaluate().
    assert result.success_rate == pytest.approx(2 / 3)
    assert result.avg_reward == pytest.approx(2 / 3)
    assert result.avg_steps == 1.0
def test_evaluate_zero_episodes_returns_zero_values() -> None:
    """With ``n_episodes=0`` the result is all zeros and the env is never reset."""
    env = _SeedTrackingEnv([])
    result = evaluate(env, _FixedPolicy(), n_episodes=0)

    # Build the expected value with the result's own class, since the result
    # type is not imported by name in this module.
    result_cls = result.__class__
    empty_result = result_cls(
        success_rate=0.0,
        avg_reward=0.0,
        avg_steps=0.0,
        n_episodes=0,
        n_completed=0,
        episodes=[],
    )

    assert result == empty_result
    assert env.reset_seeds == []
def test_evaluate_negative_episodes_raises() -> None:
    """evaluate() rejects a negative episode count with a specific ValueError message."""
    env = _SeedTrackingEnv([])

    # pytest.raises fails the test on its own when no ValueError is raised,
    # which replaces the hand-written try/except/else scaffolding.
    with pytest.raises(ValueError) as exc_info:
        evaluate(env, _FixedPolicy(), n_episodes=-1)

    assert str(exc_info.value) == "n_episodes must be >= 0"
def test_evaluate_uses_seed_plus_episode_index() -> None:
    """With ``seed=100`` and three episodes, the env is reset with 100, 101, 102."""
    tracking_env = _SeedTrackingEnv([1.0, 1.0, 1.0])

    evaluate(tracking_env, _FixedPolicy(), n_episodes=3, seed=100)

    expected_seeds = [100, 101, 102]
    assert tracking_env.reset_seeds == expected_seeds
def test_evaluate_records_episode_errors_and_continues() -> None:
    """A step failure in episode 1 is recorded and the remaining episode still runs."""
    flaky_env = _FlakyEnv([1.0, 1.0, 1.0], fail_on_episode=1)

    result = evaluate(flaky_env, _FixedPolicy(), n_episodes=3)

    assert result.n_episodes == 3
    assert len(result.episodes) == 3
    assert result.n_completed == 2

    failed_episode = result.episodes[1]
    assert failed_episode.error == "step failed"
    following_episode = result.episodes[2]
    assert following_episode.error is None
def test_evaluate_averages_exclude_failed_episodes() -> None:
    """Aggregates cover only the two completed episodes (rewards 1.0 and 0.0)."""
    flaky_env = _FlakyEnv([1.0, 0.0, 0.0], fail_on_episode=1)

    summary = evaluate(flaky_env, _FixedPolicy(), n_episodes=3)

    assert summary.n_completed == 2
    assert summary.avg_reward == 0.5
    assert summary.avg_steps == 1.0
    assert summary.success_rate == 0.5
def test_evaluate_policy_exception_recorded() -> None:
    """An exception raised by the policy in episode 1 is stored on that episode."""
    failing_policy = _RaisingPolicy(fail_on_episode=1)
    tracking_env = _SeedTrackingEnv([1.0, 1.0, 1.0])

    summary = evaluate(tracking_env, failing_policy, n_episodes=3)

    assert summary.n_completed == 2
    assert summary.episodes[1].error == "policy failed"
def test_evaluate_progress_callback_receives_episode_progress() -> None:
    """The progress callback is invoked with (1, 3), (2, 3), (3, 3) for three episodes."""
    env = _SeedTrackingEnv([1.0, 1.0, 1.0])
    calls: list[tuple[int, int]] = []

    def _record_progress(current, total):
        # Parameter names match the original lambda, so the callback accepts
        # the same positional or keyword calls.
        calls.append((current, total))

    evaluate(
        env,
        _FixedPolicy(),
        n_episodes=3,
        progress_callback=_record_progress,
    )

    assert calls == [(1, 3), (2, 3), (3, 3)]
def test_evaluate_integration_with_sql_environment(tmp_path) -> None:
    """Run RandomPolicy for 10 episodes against a real SQLEnvironment.

    All episodes must complete, and the aggregate metrics must agree with the
    values recomputed from the per-episode records.
    """
    env = _build_sql_environment(tmp_path, db_id="integration_eval")
    policy = RandomPolicy(seed=42)

    result = evaluate(env, policy, n_episodes=10, seed=0)

    assert result.n_episodes == 10
    assert result.n_completed == 10
    assert len(result.episodes) == 10

    correct_count = sum(int(episode.correct) for episode in result.episodes)
    assert result.success_rate == correct_count / 10

    reward_total = sum(episode.total_reward for episode in result.episodes)
    assert result.avg_reward == pytest.approx(reward_total / 10)
def test_evaluate_integration_is_deterministic_with_seeds(tmp_path) -> None:
    """Two identically seeded runs on separately built environments compare equal."""
    # Both environments are built before either evaluation starts.
    environments = [
        _build_sql_environment(tmp_path / run_name, db_id="integration_eval")
        for run_name in ("run_a", "run_b")
    ]

    outcomes = [
        evaluate(env, RandomPolicy(seed=42), n_episodes=10, seed=0)
        for env in environments
    ]

    first_outcome, second_outcome = outcomes
    assert first_outcome == second_outcome