"""Core types for policy evaluation."""

from __future__ import annotations

from dataclasses import dataclass
import random
import re
from typing import Callable, Protocol, runtime_checkable

try:
    from ..models import SQLAction, SQLObservation
except ImportError:
    try:
        from models import SQLAction, SQLObservation  # type: ignore[no-redef]
    except ImportError:
        from sql_env.models import SQLAction, SQLObservation  # type: ignore[no-redef]


@runtime_checkable
class Policy(Protocol):
    """Structural interface the evaluator expects from a policy."""

    def select_action(self, observation: SQLObservation) -> SQLAction:
        """Return the action to take for ``observation``."""


@dataclass(frozen=True)
class EpisodeResult:
    """Outcome of a single evaluation episode."""

    # Zero-based position of the episode within the evaluation run.
    episode_index: int
    # True when the reward on the final observation was strictly positive.
    correct: bool
    # Sum of the rewards returned by every step of the episode.
    total_reward: float
    # Number of environment steps taken.
    steps: int
    # ``str()`` of the exception that aborted the episode, or None on success.
    error: str | None = None


@dataclass(frozen=True)
class EvaluationResult:
    """Summary statistics for a whole evaluation run."""

    # Fraction of error-free episodes that were marked correct.
    success_rate: float
    # Mean total reward over error-free episodes.
    avg_reward: float
    # Mean step count over error-free episodes.
    avg_steps: float
    # Number of episodes requested.
    n_episodes: int
    # Number of episodes that finished without raising.
    n_completed: int
    # Per-episode results, failed episodes included, in execution order.
    episodes: list[EpisodeResult]


class RandomPolicy:
    """Baseline policy that explores at random and then guesses an answer.

    While more than one unit of budget remains it issues a random
    DESCRIBE / SAMPLE / QUERY against a randomly chosen table parsed from the
    observation's schema text. Once the budget drops to one or below it
    answers with a random value parsed from the latest result text.
    """

    _EXPLORATION_ACTIONS = ("DESCRIBE", "SAMPLE", "QUERY")
    # Matches numbered result rows such as "3. value | other".
    _ROW_PATTERN = re.compile(r"^\d+\.\s*(.+)$")

    def __init__(self, seed: int | None = None) -> None:
        """Create the policy with its own RNG seeded by ``seed``."""
        self._rng = random.Random(seed)

    def select_action(self, observation: SQLObservation) -> SQLAction:
        """Pick an ANSWER when the budget is nearly spent, else explore."""
        if observation.budget_remaining <= 1:
            guess = self._random_answer(observation.result)
            return SQLAction(action_type="ANSWER", argument=guess)

        # The action kind is drawn before the table so the sequence of RNG
        # draws is the same for every exploration step.
        kind = self._rng.choice(self._EXPLORATION_ACTIONS)
        table = self._random_table(observation.schema_info)
        if kind != "QUERY":
            return SQLAction(action_type=kind, argument=table)

        # Double any embedded quotes so the name stays one quoted identifier.
        escaped = table.replace('"', '""')
        return SQLAction(
            action_type=kind,
            argument=f'SELECT * FROM "{escaped}" LIMIT 5',
        )

    def _random_table(self, schema_info: str) -> str:
        """Return a random table name from ``schema_info`` or "unknown"."""
        names = self._extract_table_names(schema_info)
        return self._rng.choice(names) if names else "unknown"

    @classmethod
    def _extract_table_names(cls, schema_info: str) -> list[str]:
        """Collect table names from "- name" / "- name: ..." bullet lines."""
        names: list[str] = []
        for raw_line in schema_info.splitlines():
            bullet = raw_line.strip()
            if bullet.startswith("- "):
                # Keep only the text between the bullet and the first colon.
                name = bullet[2:].partition(":")[0].strip()
                if name:
                    names.append(name)
        return names

    def _random_answer(self, result_text: str) -> str:
        """Return a random candidate from ``result_text`` or "unknown"."""
        pool = self._extract_answer_candidates(result_text)
        return self._rng.choice(pool) if pool else "unknown"

    @classmethod
    def _extract_answer_candidates(cls, result_text: str) -> list[str]:
        """List answer candidates found in numbered result rows.

        Each matching row contributes the full row text followed by every
        non-empty ``|``-separated cell of that row.
        """
        pool: list[str] = []
        for raw_line in result_text.splitlines():
            hit = cls._ROW_PATTERN.match(raw_line.strip())
            row = hit.group(1).strip() if hit else ""
            if not row:
                continue
            pool.append(row)
            cells = (piece.strip() for piece in row.split("|"))
            pool.extend(cell for cell in cells if cell)
        return pool


def _zero_result(n_episodes: int, episodes: list[EpisodeResult]) -> EvaluationResult:
    """Build the all-zero summary used when no episode completed."""
    return EvaluationResult(
        success_rate=0.0,
        avg_reward=0.0,
        avg_steps=0.0,
        n_episodes=n_episodes,
        n_completed=0,
        episodes=episodes,
    )


def _run_episode(
    env: object,
    policy: Policy,
    episode_index: int,
    seed: int | None,
) -> EpisodeResult:
    """Roll out one episode, turning any exception into a failed result."""
    try:
        # Offset the base seed by the index so each episode differs.
        episode_seed = None if seed is None else seed + episode_index
        obs = env.reset(seed=episode_seed)
        reward_sum = 0.0
        step_count = 0
        while not obs.done:
            obs = env.step(policy.select_action(obs))
            reward_sum += obs.reward or 0.0
            step_count += 1
        return EpisodeResult(
            episode_index=episode_index,
            correct=(obs.reward or 0.0) > 0.0,
            total_reward=reward_sum,
            steps=step_count,
        )
    except Exception as exc:
        # Error isolation: partial progress is discarded and the episode is
        # recorded with zeroed metrics plus the exception text.
        return EpisodeResult(
            episode_index=episode_index,
            correct=False,
            total_reward=0.0,
            steps=0,
            error=str(exc),
        )


def evaluate(
    env: object,
    policy: Policy,
    n_episodes: int = 100,
    *,
    seed: int | None = None,
    progress_callback: Callable[[int, int], None] | None = None,
) -> EvaluationResult:
    """Evaluate ``policy`` on ``env`` for ``n_episodes`` episodes.

    Args:
        env: Environment exposing ``reset(seed=...)`` and ``step(action)``,
            both returning an observation with ``done`` and ``reward``.
        policy: Object whose ``select_action`` maps an observation to an action.
        n_episodes: Number of episodes to attempt.
        seed: Base seed; episode ``i`` is reset with ``seed + i``. When None,
            every reset receives ``seed=None``.
        progress_callback: If given, called as ``(episodes_done, n_episodes)``
            after each episode. It runs outside the per-episode error
            handling, so exceptions it raises propagate to the caller.

    Returns:
        Aggregate metrics averaged over the episodes that finished without an
        exception, together with the full per-episode list.

    Raises:
        ValueError: If ``n_episodes`` is negative.
    """
    if n_episodes < 0:
        raise ValueError("n_episodes must be >= 0")
    if n_episodes == 0:
        return _zero_result(0, [])

    results: list[EpisodeResult] = []
    for index in range(n_episodes):
        results.append(_run_episode(env, policy, index, seed))
        if progress_callback is not None:
            progress_callback(index + 1, n_episodes)

    finished = [item for item in results if item.error is None]
    n_completed = len(finished)
    if n_completed == 0:
        return _zero_result(n_episodes, results)

    wins = len([item for item in finished if item.correct])
    reward_total = sum(item.total_reward for item in finished)
    step_total = sum(item.steps for item in finished)
    return EvaluationResult(
        success_rate=wins / n_completed,
        avg_reward=reward_total / n_completed,
        avg_steps=step_total / n_completed,
        n_episodes=n_episodes,
        n_completed=n_completed,
        episodes=results,
    )