"""Core types for policy evaluation."""

from __future__ import annotations

import random
import re
from dataclasses import dataclass
from typing import Callable, Protocol, runtime_checkable

try:
    from ..models import SQLAction, SQLObservation
except ImportError:
    try:
        from models import SQLAction, SQLObservation  # type: ignore[no-redef]
    except ImportError:
        from sql_env.models import SQLAction, SQLObservation  # type: ignore[no-redef]


@runtime_checkable
class Policy(Protocol):
    """Interface for policies used by the evaluator."""

    def select_action(self, observation: SQLObservation) -> SQLAction:
        """Choose one action for the current observation."""
        ...


@dataclass(frozen=True)
class EpisodeResult:
    """Per-episode metrics from one evaluation run."""

    episode_index: int
    correct: bool
    total_reward: float
    steps: int
    error: str | None = None


@dataclass(frozen=True)
class EvaluationResult:
    """Aggregate evaluation metrics across all attempted episodes.

    Rates and averages are computed over completed (error-free) episodes
    only; ``n_episodes`` counts every attempt.
    """

    success_rate: float
    avg_reward: float
    avg_steps: float
    n_episodes: int
    n_completed: int
    episodes: list[EpisodeResult]


class RandomPolicy:
    """Built-in random baseline policy."""

    _EXPLORATION_ACTIONS = ("DESCRIBE", "SAMPLE", "QUERY")
    _ROW_PATTERN = re.compile(r"^\d+\.\s*(.+)$")

    def __init__(self, seed: int | None = None) -> None:
        self._rng = random.Random(seed)

    def select_action(self, observation: SQLObservation) -> SQLAction:
        # Answer on the last budgeted step; otherwise keep exploring.
        if observation.budget_remaining <= 1:
            return SQLAction(
                action_type="ANSWER",
                argument=self._random_answer(observation.result),
            )
        action_type = self._rng.choice(self._EXPLORATION_ACTIONS)
        table_name = self._random_table(observation.schema_info)
        if action_type == "QUERY":
            # Double any embedded quotes so the quoted identifier stays valid SQL.
            safe_table_name = table_name.replace('"', '""')
            argument = f'SELECT * FROM "{safe_table_name}" LIMIT 5'
        else:
            argument = table_name
        return SQLAction(action_type=action_type, argument=argument)

    def _random_table(self, schema_info: str) -> str:
        table_names = self._extract_table_names(schema_info)
        if not table_names:
            return "unknown"
        return self._rng.choice(table_names)

    @classmethod
    def _extract_table_names(cls, schema_info: str) -> list[str]:
        # Expects schema lines of the form "- table_name: col, col, ...".
        table_names: list[str] = []
        for line in schema_info.splitlines():
            stripped = line.strip()
            if not stripped.startswith("- "):
                continue
            candidate = stripped[2:]
            if ":" in candidate:
                candidate = candidate.split(":", maxsplit=1)[0]
            candidate = candidate.strip()
            if candidate:
                table_names.append(candidate)
        return table_names

    def _random_answer(self, result_text: str) -> str:
        candidates = self._extract_answer_candidates(result_text)
        if not candidates:
            return "unknown"
        return self._rng.choice(candidates)

    @classmethod
    def _extract_answer_candidates(cls, result_text: str) -> list[str]:
        # Expects numbered result rows such as "1. alice | 30"; offers both
        # the whole row and its "|"-separated cells as answer candidates.
        candidates: list[str] = []
        for line in result_text.splitlines():
            match = cls._ROW_PATTERN.match(line.strip())
            if not match:
                continue
            row_value = match.group(1).strip()
            if not row_value:
                continue
            candidates.append(row_value)
            # Only split multi-cell rows; splitting a plain row would just
            # duplicate it and double its weight in the random choice.
            if "|" in row_value:
                split_values = [value.strip() for value in row_value.split("|")]
                candidates.extend(value for value in split_values if value)
        return candidates
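

# Parsing sketch (assumed renderings; the real formats come from the
# environment, not this module):
#
#     RandomPolicy._extract_table_names("Tables:\n- users: id, name")
#     -> ["users"]
#     RandomPolicy._extract_answer_candidates("1. alice | 30")
#     -> ["alice | 30", "alice", "30"]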


def evaluate(
    env: object,
    policy: Policy,
    n_episodes: int = 100,
    *,
    seed: int | None = None,
    progress_callback: Callable[[int, int], None] | None = None,
) -> EvaluationResult:
    """Run policy evaluation over multiple episodes with error isolation.

    ``env`` is duck-typed: it must provide ``reset(seed=...)`` and
    ``step(action)``, each returning an ``SQLObservation``.
    """
    if n_episodes < 0:
        raise ValueError("n_episodes must be >= 0")
    if n_episodes == 0:
        return EvaluationResult(
            success_rate=0.0,
            avg_reward=0.0,
            avg_steps=0.0,
            n_episodes=0,
            n_completed=0,
            episodes=[],
        )
    episodes: list[EpisodeResult] = []
    for episode_index in range(n_episodes):
        try:
            # Derive a distinct but reproducible per-episode seed.
            episode_seed = seed + episode_index if seed is not None else None
            observation = env.reset(seed=episode_seed)
            total_reward = 0.0
            steps = 0
            while not observation.done:
                action = policy.select_action(observation)
                observation = env.step(action)
                total_reward += observation.reward or 0.0
                steps += 1
            episodes.append(
                EpisodeResult(
                    episode_index=episode_index,
                    correct=(observation.reward or 0.0) > 0.0,
                    total_reward=total_reward,
                    steps=steps,
                )
            )
        except Exception as exc:
            # Isolate failures: a crashing episode is recorded, not fatal.
            episodes.append(
                EpisodeResult(
                    episode_index=episode_index,
                    correct=False,
                    total_reward=0.0,
                    steps=0,
                    error=str(exc),
                )
            )
        if progress_callback is not None:
            progress_callback(episode_index + 1, n_episodes)
    completed_episodes = [episode for episode in episodes if episode.error is None]
    n_completed = len(completed_episodes)
    if n_completed == 0:
        return EvaluationResult(
            success_rate=0.0,
            avg_reward=0.0,
            avg_steps=0.0,
            n_episodes=n_episodes,
            n_completed=0,
            episodes=episodes,
        )
    successful = sum(1 for episode in completed_episodes if episode.correct)
    avg_reward = sum(episode.total_reward for episode in completed_episodes) / n_completed
    avg_steps = sum(episode.steps for episode in completed_episodes) / n_completed
    return EvaluationResult(
        success_rate=successful / n_completed,
        avg_reward=avg_reward,
        avg_steps=avg_steps,
        n_episodes=n_episodes,
        n_completed=n_completed,
        episodes=episodes,
    )
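

# Usage sketch: `_DemoEnv` below is a throwaway stub standing in for a real
# environment (its observation shape, reward rule, and budget are assumptions
# made for illustration only). Run this module directly to smoke-test the loop.
if __name__ == "__main__":

    @dataclass
    class _DemoObservation:
        # Duck-typed stand-in for SQLObservation.
        schema_info: str = "Tables:\n- users: id, name"
        result: str = "1. alice | 30"
        budget_remaining: int = 3
        reward: float | None = None
        done: bool = False

    class _DemoEnv:
        def reset(self, seed: int | None = None) -> _DemoObservation:
            self._obs = _DemoObservation()
            return self._obs

        def step(self, action: SQLAction) -> _DemoObservation:
            obs = self._obs
            obs.budget_remaining -= 1
            # End the episode on an answer or once the budget is spent.
            if action.action_type == "ANSWER" or obs.budget_remaining <= 0:
                obs.done = True
                obs.reward = 1.0 if action.argument == "alice" else 0.0
            else:
                obs.reward = 0.0
            return obs

    demo_result = evaluate(_DemoEnv(), RandomPolicy(seed=0), n_episodes=5, seed=0)
    print(f"success_rate={demo_result.success_rate:.2f} avg_steps={demo_result.avg_steps:.1f}")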