"""Unit tests for evaluation package random policy and evaluate()."""

import json
import sqlite3
from contextlib import closing

import pytest

from sql_env.evaluation import RandomPolicy, evaluate
from sql_env.models import SQLAction, SQLObservation
from sql_env.server.sql_environment import SQLEnvironment
from sql_env.server.test_sql_env import MockTokenizer


def _build_sql_environment(tmp_path, *, db_id: str) -> SQLEnvironment:
    """Build a real ``SQLEnvironment`` backed by a tiny on-disk fixture.

    Creates ``<tmp_path>/databases/<db_id>/<db_id>.sqlite`` containing a
    three-row ``employees`` table, plus ``<tmp_path>/questions.json`` holding a
    single question whose gold query is ``SELECT COUNT(*) FROM employees``.

    Args:
        tmp_path: Directory to write the fixture into (pytest's ``tmp_path``
            or a sub-path of it; missing parents are created).
        db_id: Identifier used for the database directory, the ``.sqlite``
            file name, and the question's ``db_id`` field.

    Returns:
        An ``SQLEnvironment`` over the generated questions file and database
        root, using ``MockTokenizer``.
    """
    db_root = tmp_path / "databases"
    db_dir = db_root / db_id
    db_dir.mkdir(parents=True)
    db_path = db_dir / f"{db_id}.sqlite"

    # ``closing`` releases the handle even if a statement fails; sqlite3's own
    # connection context manager only commits/rolls back and never closes.
    with closing(sqlite3.connect(db_path)) as connection:
        cursor = connection.cursor()
        cursor.execute(
            "CREATE TABLE employees (id INTEGER PRIMARY KEY, name TEXT, dept TEXT)"
        )
        cursor.executemany(
            "INSERT INTO employees (id, name, dept) VALUES (?, ?, ?)",
            [
                (1, "Alice", "engineering"),
                (2, "Bob", "engineering"),
                (3, "Cara", "sales"),
            ],
        )
        connection.commit()

    questions_path = tmp_path / "questions.json"
    questions_path.write_text(
        json.dumps(
            [
                {
                    "question": "How many employees are there?",
                    "db_id": db_id,
                    "query": "SELECT COUNT(*) FROM employees",
                }
            ]
        ),
        encoding="utf-8",
    )

    return SQLEnvironment(
        questions_path=str(questions_path),
        db_dir=str(db_root),
        tokenizer=MockTokenizer(),
    )


def _build_observation(*, budget_remaining: int, result: str = "") -> SQLObservation:
    """Return a non-terminal, step-0 observation with the given budget/result."""
    return SQLObservation(
        question="How many rows?",
        schema_info="Available tables:\n- employees\n- departments",
        result=result,
        error="",
        step_count=0,
        budget_remaining=budget_remaining,
        action_history=[],
        done=False,
        reward=None,
    )


def _terminal_observation(*, reward: float) -> SQLObservation:
    """Return a terminal (``done=True``) observation after one step."""
    return SQLObservation(
        question="How many rows?",
        schema_info="Available tables:\n- employees\n- departments",
        result="",
        error="",
        step_count=1,
        budget_remaining=0,
        action_history=[],
        done=True,
        reward=reward,
    )


class _FixedPolicy:
    """Policy stub that always issues the same ``QUERY`` action."""

    def select_action(self, observation: SQLObservation) -> SQLAction:
        return SQLAction(action_type="QUERY", argument="SELECT 1")


class _RaisingPolicy:
    """Policy stub that raises ``RuntimeError`` during one chosen episode.

    A new episode is detected whenever an observation with ``step_count == 0``
    is seen (the value ``_build_observation`` uses for reset observations).
    """

    def __init__(self, fail_on_episode: int) -> None:
        self._fail_on_episode = fail_on_episode
        # Starts at -1 so the first step-0 observation maps to episode 0.
        self._episode_index = -1

    def select_action(self, observation: SQLObservation) -> SQLAction:
        if observation.step_count == 0:
            self._episode_index += 1
        if self._episode_index == self._fail_on_episode:
            raise RuntimeError("policy failed")
        return SQLAction(action_type="QUERY", argument="SELECT 1")


class _SeedTrackingEnv:
    """Environment stub whose episodes end after one step.

    Records every ``seed`` passed to ``reset`` in ``reset_seeds``; the single
    step of episode ``i`` returns a terminal observation with ``rewards[i]``.
    """

    def __init__(self, rewards: list[float]) -> None:
        self._rewards = rewards
        self._episode_index = -1
        self.reset_seeds: list[int | None] = []

    def reset(self, *, seed: int | None = None) -> SQLObservation:
        self.reset_seeds.append(seed)
        self._episode_index += 1
        return _build_observation(budget_remaining=2)

    def step(self, action: SQLAction) -> SQLObservation:
        del action
        reward = self._rewards[self._episode_index]
        return _terminal_observation(reward=reward)


class _FlakyEnv(_SeedTrackingEnv):
    """``_SeedTrackingEnv`` whose ``step`` raises during one chosen episode."""

    def __init__(self, rewards: list[float], fail_on_episode: int) -> None:
        super().__init__(rewards)
        self._fail_on_episode = fail_on_episode

    def step(self, action: SQLAction) -> SQLObservation:
        if self._episode_index == self._fail_on_episode:
            raise RuntimeError("step failed")
        return super().step(action)


def test_random_policy_explores_when_budget_gt_one() -> None:
    policy = RandomPolicy(seed=42)
    observation = _build_observation(budget_remaining=10)

    action = policy.select_action(observation)

    assert action.action_type in {"DESCRIBE", "SAMPLE", "QUERY"}


def test_random_policy_answers_when_budget_eq_one() -> None:
    policy = RandomPolicy(seed=42)
    observation = _build_observation(budget_remaining=1)

    action = policy.select_action(observation)

    assert action.action_type == "ANSWER"


def test_random_policy_returns_sql_action() -> None:
    policy = RandomPolicy(seed=7)
    observation = _build_observation(budget_remaining=10)

    action = policy.select_action(observation)

    assert isinstance(action, SQLAction)


def test_random_policy_deterministic_with_seed() -> None:
    observation = _build_observation(budget_remaining=10)
    first = RandomPolicy(seed=123)
    second = RandomPolicy(seed=123)

    first_actions = [first.select_action(observation) for _ in range(25)]
    second_actions = [second.select_action(observation) for _ in range(25)]

    assert first_actions == second_actions


def test_random_policy_explores_all_action_types() -> None:
    policy = RandomPolicy(seed=1)
    observation = _build_observation(budget_remaining=10)

    action_types = {policy.select_action(observation).action_type for _ in range(200)}

    assert action_types == {"DESCRIBE", "SAMPLE", "QUERY"}


def test_random_policy_uses_result_rows_for_answer_candidates() -> None:
    policy = RandomPolicy(seed=0)
    observation = _build_observation(
        budget_remaining=1,
        result="1. engineering | 25\n2. sales | 10",
    )

    action = policy.select_action(observation)

    assert action.action_type == "ANSWER"
    assert action.argument in {
        "engineering",
        "25",
        "sales",
        "10",
        "engineering | 25",
        "sales | 10",
    }


def test_evaluate_happy_path() -> None:
    env = _SeedTrackingEnv([1.0, 0.0, 1.0])

    result = evaluate(env, _FixedPolicy(), n_episodes=3)

    assert result.n_episodes == 3
    assert result.n_completed == 3
    assert len(result.episodes) == 3
    # 2/3 is not exactly representable; don't depend on summation order.
    assert result.success_rate == pytest.approx(2 / 3)
    assert result.avg_reward == pytest.approx(2 / 3)
    assert result.avg_steps == 1.0


def test_evaluate_zero_episodes_returns_zero_values() -> None:
    env = _SeedTrackingEnv([])

    result = evaluate(env, _FixedPolicy(), n_episodes=0)

    assert result == result.__class__(
        success_rate=0.0,
        avg_reward=0.0,
        avg_steps=0.0,
        n_episodes=0,
        n_completed=0,
        episodes=[],
    )
    assert env.reset_seeds == []


def test_evaluate_negative_episodes_raises() -> None:
    env = _SeedTrackingEnv([])

    with pytest.raises(ValueError) as exc_info:
        evaluate(env, _FixedPolicy(), n_episodes=-1)

    # Exact-message check (``match=`` would be a regex search, not equality).
    assert str(exc_info.value) == "n_episodes must be >= 0"


def test_evaluate_uses_seed_plus_episode_index() -> None:
    env = _SeedTrackingEnv([1.0, 1.0, 1.0])

    evaluate(env, _FixedPolicy(), n_episodes=3, seed=100)

    assert env.reset_seeds == [100, 101, 102]


def test_evaluate_records_episode_errors_and_continues() -> None:
    env = _FlakyEnv([1.0, 1.0, 1.0], fail_on_episode=1)

    result = evaluate(env, _FixedPolicy(), n_episodes=3)

    assert result.n_episodes == 3
    assert len(result.episodes) == 3
    assert result.n_completed == 2
    assert result.episodes[1].error == "step failed"
    assert result.episodes[2].error is None


def test_evaluate_averages_exclude_failed_episodes() -> None:
    env = _FlakyEnv([1.0, 0.0, 0.0], fail_on_episode=1)

    result = evaluate(env, _FixedPolicy(), n_episodes=3)

    assert result.n_completed == 2
    assert result.avg_reward == 0.5
    assert result.avg_steps == 1.0
    assert result.success_rate == 0.5


def test_evaluate_policy_exception_recorded() -> None:
    env = _SeedTrackingEnv([1.0, 1.0, 1.0])

    result = evaluate(env, _RaisingPolicy(fail_on_episode=1), n_episodes=3)

    assert result.n_completed == 2
    assert result.episodes[1].error == "policy failed"


def test_evaluate_progress_callback_receives_episode_progress() -> None:
    env = _SeedTrackingEnv([1.0, 1.0, 1.0])
    calls: list[tuple[int, int]] = []

    evaluate(
        env,
        _FixedPolicy(),
        n_episodes=3,
        progress_callback=lambda current, total: calls.append((current, total)),
    )

    assert calls == [(1, 3), (2, 3), (3, 3)]


def test_evaluate_integration_with_sql_environment(tmp_path) -> None:
    env = _build_sql_environment(tmp_path, db_id="integration_eval")

    result = evaluate(env, RandomPolicy(seed=42), n_episodes=10, seed=0)

    assert result.n_episodes == 10
    assert result.n_completed == 10
    assert len(result.episodes) == 10
    assert result.success_rate == sum(int(e.correct) for e in result.episodes) / 10
    assert result.avg_reward == pytest.approx(
        sum(e.total_reward for e in result.episodes) / 10
    )


def test_evaluate_integration_is_deterministic_with_seeds(tmp_path) -> None:
    env_a = _build_sql_environment(tmp_path / "run_a", db_id="integration_eval")
    env_b = _build_sql_environment(tmp_path / "run_b", db_id="integration_eval")

    result_a = evaluate(env_a, RandomPolicy(seed=42), n_episodes=10, seed=0)
    result_b = evaluate(env_b, RandomPolicy(seed=42), n_episodes=10, seed=0)

    assert result_a == result_b