| """Core types for policy evaluation.""" |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| import random |
| import re |
| from typing import Callable, Protocol, runtime_checkable |
|
|
| try: |
| from ..models import SQLAction, SQLObservation |
| except ImportError: |
| try: |
| from models import SQLAction, SQLObservation |
| except ImportError: |
| from sql_env.models import SQLAction, SQLObservation |
|
|
|
|
@runtime_checkable
class Policy(Protocol):
    """Structural contract every policy handed to the evaluator must satisfy.

    Any object that provides a ``select_action`` method conforms; no
    inheritance is required. Because the protocol is runtime-checkable,
    ``isinstance(obj, Policy)`` works, but that check only confirms the
    method is present, not that its signature matches.
    """

    def select_action(self, observation: SQLObservation) -> SQLAction:
        """Return the single action to take for the given observation."""
        ...
|
|
|
|
@dataclass(frozen=True)
class EpisodeResult:
    """Immutable record describing the outcome of a single episode.

    Attributes:
        episode_index: Position of the episode within the evaluation run.
        correct: Whether the episode counted as a success.
        total_reward: Reward accumulated over the episode.
        steps: Number of environment steps taken.
        error: Description of the failure that aborted the episode, or
            ``None`` (the default) when no failure was recorded.
    """

    episode_index: int
    correct: bool
    total_reward: float
    steps: int
    error: str | None = None
|
|
|
|
@dataclass(frozen=True)
class EvaluationResult:
    """Immutable summary of one evaluation run.

    Attributes:
        success_rate: Fraction of completed episodes that were correct.
        avg_reward: Mean total reward over completed episodes.
        avg_steps: Mean step count over completed episodes.
        n_episodes: Number of episodes that were attempted.
        n_completed: Number of episodes that finished without an error.
        episodes: Per-episode records, including the ones that failed.
    """

    success_rate: float
    avg_reward: float
    avg_steps: float
    n_episodes: int
    n_completed: int
    episodes: list[EpisodeResult]
|
|
|
|
class RandomPolicy:
    """Baseline policy that explores at random and finally guesses an answer.

    While more than one unit of budget remains it issues a random
    DESCRIBE / SAMPLE / QUERY action against a randomly chosen table parsed
    from the schema text. Once the budget drops to one or less it submits an
    ANSWER picked at random from the values found in the latest result text.
    """

    _EXPLORATION_ACTIONS = ("DESCRIBE", "SAMPLE", "QUERY")
    # Matches numbered result rows such as "3. value" and captures the value.
    _ROW_PATTERN = re.compile(r"^\d+\.\s*(.+)$")

    def __init__(self, seed: int | None = None) -> None:
        """Create the policy with its own private random generator.

        Args:
            seed: Seed passed to ``random.Random``; ``None`` lets the
                generator seed itself.
        """
        self._rng = random.Random(seed)

    def select_action(self, observation: SQLObservation) -> SQLAction:
        """Pick the next action for ``observation``.

        Args:
            observation: Current observation; its ``budget_remaining``,
                ``result`` and ``schema_info`` attributes are read.

        Returns:
            An ANSWER action when ``budget_remaining <= 1``, otherwise a
            random exploration action targeting a random table.
        """
        if observation.budget_remaining <= 1:
            guess = self._random_answer(observation.result)
            return SQLAction(action_type="ANSWER", argument=guess)

        # Draw order matters for seeded reproducibility: action first, table second.
        chosen_kind = self._rng.choice(self._EXPLORATION_ACTIONS)
        target = self._random_table(observation.schema_info)
        if chosen_kind != "QUERY":
            return SQLAction(action_type=chosen_kind, argument=target)

        # Double any embedded quote so the name stays a valid quoted identifier.
        quoted = target.replace('"', '""')
        return SQLAction(
            action_type=chosen_kind,
            argument=f'SELECT * FROM "{quoted}" LIMIT 5',
        )

    def _random_table(self, schema_info: str) -> str:
        """Return a random table name from the schema text, or ``"unknown"``."""
        names = self._extract_table_names(schema_info)
        return self._rng.choice(names) if names else "unknown"

    @classmethod
    def _extract_table_names(cls, schema_info: str) -> list[str]:
        """Collect table names from ``- name`` / ``- name: details`` lines.

        Only lines that start with ``"- "`` after stripping are considered;
        the name is the text before the first colon, and blank names are
        dropped.
        """
        bullets = (
            text[2:]
            for text in (line.strip() for line in schema_info.splitlines())
            if text.startswith("- ")
        )
        names = (bullet.partition(":")[0].strip() for bullet in bullets)
        return [name for name in names if name]

    def _random_answer(self, result_text: str) -> str:
        """Return a random answer candidate from the result text, or ``"unknown"``."""
        options = self._extract_answer_candidates(result_text)
        return self._rng.choice(options) if options else "unknown"

    @classmethod
    def _extract_answer_candidates(cls, result_text: str) -> list[str]:
        """Collect candidate answers from numbered result rows.

        For every line matching ``<number>. <row>`` the whole row is added,
        followed by each non-empty ``|``-separated cell. A row without ``|``
        therefore appears twice in the returned list.
        """
        found: list[str] = []
        for raw_line in result_text.splitlines():
            hit = cls._ROW_PATTERN.match(raw_line.strip())
            row = hit.group(1).strip() if hit else ""
            if row:
                found.append(row)
                cells = (cell.strip() for cell in row.split("|"))
                found.extend(cell for cell in cells if cell)
        return found
|
|
|
|
def evaluate(
    env: object,
    policy: Policy,
    n_episodes: int = 100,
    *,
    seed: int | None = None,
    progress_callback: Callable[[int, int], None] | None = None,
) -> EvaluationResult:
    """Run ``policy`` against ``env`` for several episodes and summarise them.

    An exception raised while running an episode does not abort the run: the
    episode is recorded with zero reward, zero steps and the exception text
    as its ``error``, and evaluation moves on to the next episode.

    Args:
        env: Environment exposing ``reset(seed=...)`` and ``step(action)``,
            both returning an observation with ``done`` and ``reward``.
        policy: Object whose ``select_action`` picks the next action.
        n_episodes: Number of episodes to attempt.
        seed: Base seed; episode ``i`` is reset with ``seed + i``. ``None``
            resets every episode with ``seed=None``.
        progress_callback: Optional hook called after every episode with
            ``(episodes_done, n_episodes)``.

    Returns:
        Aggregate metrics. The rate and the averages are computed over the
        episodes that finished without error, and are ``0.0`` when there
        are none.

    Raises:
        ValueError: If ``n_episodes`` is negative.
    """
    if n_episodes < 0:
        raise ValueError("n_episodes must be >= 0")

    if n_episodes == 0:
        return EvaluationResult(
            success_rate=0.0,
            avg_reward=0.0,
            avg_steps=0.0,
            n_episodes=0,
            n_completed=0,
            episodes=[],
        )

    records: list[EpisodeResult] = []
    for index in range(n_episodes):
        try:
            current = env.reset(seed=None if seed is None else seed + index)
            reward_sum = 0.0
            step_count = 0
            while not current.done:
                current = env.step(policy.select_action(current))
                # A missing (None) reward counts as zero.
                reward_sum += current.reward or 0.0
                step_count += 1
            # Success is judged from the reward on the final observation.
            outcome = EpisodeResult(
                episode_index=index,
                correct=(current.reward or 0.0) > 0.0,
                total_reward=reward_sum,
                steps=step_count,
            )
        except Exception as exc:
            # Isolate the failure: partial reward and steps are discarded.
            outcome = EpisodeResult(
                episode_index=index,
                correct=False,
                total_reward=0.0,
                steps=0,
                error=str(exc),
            )
        records.append(outcome)

        if progress_callback is not None:
            progress_callback(index + 1, n_episodes)

    finished = [record for record in records if record.error is None]
    n_completed = len(finished)
    if n_completed:
        wins = len([record for record in finished if record.correct])
        success_rate = wins / n_completed
        avg_reward = sum(record.total_reward for record in finished) / n_completed
        avg_steps = sum(record.steps for record in finished) / n_completed
    else:
        # Nothing completed: report zeros instead of dividing by zero.
        success_rate = avg_reward = avg_steps = 0.0

    return EvaluationResult(
        success_rate=success_rate,
        avg_reward=avg_reward,
        avg_steps=avg_steps,
        n_episodes=n_episodes,
        n_completed=n_completed,
        episodes=records,
    )
|
|