| """Unit tests for evaluation package random policy and evaluate().""" |
|
|
| import json |
| import sqlite3 |
|
|
| import pytest |
|
|
| from sql_env.evaluation import RandomPolicy, evaluate |
| from sql_env.models import SQLAction, SQLObservation |
| from sql_env.server.sql_environment import SQLEnvironment |
| from sql_env.server.test_sql_env import MockTokenizer |
|
|
|
|
def _build_sql_environment(tmp_path, *, db_id: str) -> SQLEnvironment:
    """Create a SQLEnvironment backed by a small on-disk SQLite fixture.

    Builds an ``employees`` table with three rows under
    ``tmp_path/databases/<db_id>/<db_id>.sqlite`` plus a one-question JSON
    file, then wires both into a SQLEnvironment using the mock tokenizer.
    """
    databases_root = tmp_path / "databases"
    database_dir = databases_root / db_id
    database_dir.mkdir(parents=True)
    sqlite_file = database_dir / f"{db_id}.sqlite"

    # Populate the fixture database, then close the connection so the
    # environment can open the file independently.
    connection = sqlite3.connect(sqlite_file)
    try:
        connection.execute(
            "CREATE TABLE employees (id INTEGER PRIMARY KEY, name TEXT, dept TEXT)"
        )
        connection.executemany(
            "INSERT INTO employees (id, name, dept) VALUES (?, ?, ?)",
            [
                (1, "Alice", "engineering"),
                (2, "Bob", "engineering"),
                (3, "Cara", "sales"),
            ],
        )
        connection.commit()
    finally:
        connection.close()

    questions = [
        {
            "question": "How many employees are there?",
            "db_id": db_id,
            "query": "SELECT COUNT(*) FROM employees",
        }
    ]
    questions_file = tmp_path / "questions.json"
    questions_file.write_text(json.dumps(questions), encoding="utf-8")

    return SQLEnvironment(
        questions_path=str(questions_file),
        db_dir=str(databases_root),
        tokenizer=MockTokenizer(),
    )
|
|
|
|
def _build_observation(*, budget_remaining: int, result: str = "") -> SQLObservation:
    """Return a non-terminal observation with the given budget and result text."""
    fields = {
        "question": "How many rows?",
        "schema_info": "Available tables:\n- employees\n- departments",
        "result": result,
        "error": "",
        "step_count": 0,
        "budget_remaining": budget_remaining,
        "action_history": [],
        "done": False,
        "reward": None,
    }
    return SQLObservation(**fields)
|
|
|
|
def _terminal_observation(*, reward: float) -> SQLObservation:
    """Return a finished (done=True) observation carrying the given reward."""
    fields = {
        "question": "How many rows?",
        "schema_info": "Available tables:\n- employees\n- departments",
        "result": "",
        "error": "",
        "step_count": 1,
        "budget_remaining": 0,
        "action_history": [],
        "done": True,
        "reward": reward,
    }
    return SQLObservation(**fields)
|
|
|
|
class _FixedPolicy:
    """Policy stub that always issues the same trivial query."""

    def select_action(self, observation: SQLObservation) -> SQLAction:
        del observation  # the choice never depends on the observation
        return SQLAction(action_type="QUERY", argument="SELECT 1")
|
|
|
|
class _RaisingPolicy:
    """Policy stub that raises on every step of one designated episode."""

    def __init__(self, fail_on_episode: int) -> None:
        self._fail_on_episode = fail_on_episode
        # -1 so the first episode start bumps this to 0.
        self._episode_index = -1

    def select_action(self, observation: SQLObservation) -> SQLAction:
        # A step_count of 0 marks the beginning of a fresh episode.
        if observation.step_count == 0:
            self._episode_index += 1
        if self._episode_index == self._fail_on_episode:
            raise RuntimeError("policy failed")
        return SQLAction(action_type="QUERY", argument="SELECT 1")
|
|
|
|
class _SeedTrackingEnv:
    """Environment double that records reset seeds and emits preset rewards."""

    def __init__(self, rewards: list[float]) -> None:
        self._rewards = rewards
        # -1 so the first reset() bumps this to episode index 0.
        self._episode_index = -1
        self.reset_seeds: list[int | None] = []

    def reset(self, *, seed: int | None = None) -> SQLObservation:
        """Record the seed, advance the episode counter, return a fresh observation."""
        self.reset_seeds.append(seed)
        self._episode_index += 1
        return _build_observation(budget_remaining=2)

    def step(self, action: SQLAction) -> SQLObservation:
        """Ignore the action; end the episode with the current preset reward."""
        del action
        return _terminal_observation(reward=self._rewards[self._episode_index])
|
|
|
|
class _FlakyEnv(_SeedTrackingEnv):
    """Seed-tracking environment whose step() raises during one chosen episode."""

    def __init__(self, rewards: list[float], fail_on_episode: int) -> None:
        super().__init__(rewards)
        self._fail_on_episode = fail_on_episode

    def step(self, action: SQLAction) -> SQLObservation:
        # Delegate normally except during the designated failing episode.
        if self._episode_index != self._fail_on_episode:
            return super().step(action)
        raise RuntimeError("step failed")
|
|
|
|
def test_random_policy_explores_when_budget_gt_one() -> None:
    """With budget remaining, the policy picks one of the exploration actions."""
    observation = _build_observation(budget_remaining=10)

    chosen = RandomPolicy(seed=42).select_action(observation)

    assert chosen.action_type in {"DESCRIBE", "SAMPLE", "QUERY"}
|
|
|
|
def test_random_policy_answers_when_budget_eq_one() -> None:
    """On the last budgeted step the policy must commit to an ANSWER."""
    observation = _build_observation(budget_remaining=1)

    chosen = RandomPolicy(seed=42).select_action(observation)

    assert chosen.action_type == "ANSWER"
|
|
|
|
def test_random_policy_returns_sql_action() -> None:
    """select_action always yields an SQLAction instance."""
    chosen = RandomPolicy(seed=7).select_action(
        _build_observation(budget_remaining=10)
    )

    assert isinstance(chosen, SQLAction)
|
|
|
|
def test_random_policy_deterministic_with_seed() -> None:
    """Two policies sharing a seed produce identical action sequences."""
    observation = _build_observation(budget_remaining=10)
    policy_a = RandomPolicy(seed=123)
    policy_b = RandomPolicy(seed=123)

    sequence_a = [policy_a.select_action(observation) for _ in range(25)]
    sequence_b = [policy_b.select_action(observation) for _ in range(25)]

    assert sequence_a == sequence_b
|
|
|
|
def test_random_policy_explores_all_action_types() -> None:
    """Over many draws, every exploration action type should appear."""
    policy = RandomPolicy(seed=1)
    observation = _build_observation(budget_remaining=10)

    seen_types = {policy.select_action(observation).action_type for _ in range(200)}

    assert seen_types == {"DESCRIBE", "SAMPLE", "QUERY"}
|
|
|
|
def test_random_policy_uses_result_rows_for_answer_candidates() -> None:
    """ANSWER arguments are drawn from the rows of the latest query result."""
    observation = _build_observation(
        budget_remaining=1,
        result="1. engineering | 25\n2. sales | 10",
    )

    chosen = RandomPolicy(seed=0).select_action(observation)

    # Candidates are whole rows or individual cell values from the result.
    expected_candidates = {
        "engineering",
        "25",
        "sales",
        "10",
        "engineering | 25",
        "sales | 10",
    }
    assert chosen.action_type == "ANSWER"
    assert chosen.argument in expected_candidates
|
|
|
|
def test_evaluate_happy_path() -> None:
    """All episodes complete and the aggregates reflect the preset rewards."""
    env = _SeedTrackingEnv([1.0, 0.0, 1.0])

    result = evaluate(env, _FixedPolicy(), n_episodes=3)

    assert result.n_episodes == 3
    assert result.n_completed == 3
    assert len(result.episodes) == 3
    # Two of three episodes rewarded 1.0; each takes exactly one step.
    assert result.success_rate == 2 / 3
    assert result.avg_reward == 2 / 3
    assert result.avg_steps == 1.0
|
|
|
|
def test_evaluate_zero_episodes_returns_zero_values() -> None:
    """Zero episodes yields an all-zero result and never resets the env."""
    env = _SeedTrackingEnv([])

    result = evaluate(env, _FixedPolicy(), n_episodes=0)

    expected = type(result)(
        success_rate=0.0,
        avg_reward=0.0,
        avg_steps=0.0,
        n_episodes=0,
        n_completed=0,
        episodes=[],
    )
    assert result == expected
    assert env.reset_seeds == []
|
|
|
|
def test_evaluate_negative_episodes_raises() -> None:
    """A negative episode count must be rejected with a clear ValueError.

    Uses ``pytest.raises`` instead of a hand-rolled try/except/else block;
    the anchored regex preserves the original exact-message check.
    """
    env = _SeedTrackingEnv([])

    with pytest.raises(ValueError, match=r"^n_episodes must be >= 0$"):
        evaluate(env, _FixedPolicy(), n_episodes=-1)
|
|
|
|
def test_evaluate_uses_seed_plus_episode_index() -> None:
    """Each episode resets with the base seed plus its zero-based index."""
    env = _SeedTrackingEnv([1.0, 1.0, 1.0])

    evaluate(env, _FixedPolicy(), n_episodes=3, seed=100)

    assert env.reset_seeds == [100, 101, 102]
|
|
|
|
def test_evaluate_records_episode_errors_and_continues() -> None:
    """A failing episode is recorded with its error and evaluation proceeds."""
    env = _FlakyEnv([1.0, 1.0, 1.0], fail_on_episode=1)

    result = evaluate(env, _FixedPolicy(), n_episodes=3)

    assert result.n_episodes == 3
    assert len(result.episodes) == 3
    # Only episode 1 failed; its error is captured and later episodes are clean.
    assert result.n_completed == 2
    assert result.episodes[1].error == "step failed"
    assert result.episodes[2].error is None
|
|
|
|
def test_evaluate_averages_exclude_failed_episodes() -> None:
    """Aggregate metrics are computed over completed episodes only."""
    env = _FlakyEnv([1.0, 0.0, 0.0], fail_on_episode=1)

    result = evaluate(env, _FixedPolicy(), n_episodes=3)

    # Episodes 0 (reward 1.0) and 2 (reward 0.0) complete; episode 1 fails.
    assert result.n_completed == 2
    assert result.avg_reward == 0.5
    assert result.avg_steps == 1.0
    assert result.success_rate == 0.5
|
|
|
|
def test_evaluate_policy_exception_recorded() -> None:
    """Exceptions raised by the policy are captured per-episode, not propagated."""
    env = _SeedTrackingEnv([1.0, 1.0, 1.0])

    result = evaluate(env, _RaisingPolicy(fail_on_episode=1), n_episodes=3)

    assert result.n_completed == 2
    assert result.episodes[1].error == "policy failed"
|
|
|
|
def test_evaluate_progress_callback_receives_episode_progress() -> None:
    """The callback observes (completed, total) once per finished episode."""
    env = _SeedTrackingEnv([1.0, 1.0, 1.0])
    progress: list[tuple[int, int]] = []

    def record_progress(current: int, total: int) -> None:
        progress.append((current, total))

    evaluate(
        env,
        _FixedPolicy(),
        n_episodes=3,
        progress_callback=record_progress,
    )

    assert progress == [(1, 3), (2, 3), (3, 3)]
|
|
|
|
def test_evaluate_integration_with_sql_environment(tmp_path) -> None:
    """End-to-end run against a real SQLEnvironment completes every episode."""
    env = _build_sql_environment(tmp_path, db_id="integration_eval")

    result = evaluate(env, RandomPolicy(seed=42), n_episodes=10, seed=0)

    episodes = result.episodes
    assert result.n_episodes == 10
    assert result.n_completed == 10
    assert len(episodes) == 10
    # Aggregates must agree with the recorded per-episode values.
    assert result.success_rate == sum(int(ep.correct) for ep in episodes) / 10
    assert result.avg_reward == pytest.approx(
        sum(ep.total_reward for ep in episodes) / 10
    )
|
|
|
|
def test_evaluate_integration_is_deterministic_with_seeds(tmp_path) -> None:
    """Identical seeds over identical fixtures produce identical results."""
    first_env = _build_sql_environment(tmp_path / "run_a", db_id="integration_eval")
    second_env = _build_sql_environment(tmp_path / "run_b", db_id="integration_eval")

    first_result = evaluate(first_env, RandomPolicy(seed=42), n_episodes=10, seed=0)
    second_result = evaluate(second_env, RandomPolicy(seed=42), n_episodes=10, seed=0)

    assert first_result == second_result
|
|