sql_env / evaluation /green_agent.py
hjerpe's picture
Upload folder using huggingface_hub
5dd1bb4 verified
"""Policy evaluation: result types, a random baseline policy, and the evaluation loop."""
from __future__ import annotations
from dataclasses import dataclass
import random
import re
from typing import Callable, Protocol, runtime_checkable
try:
from ..models import SQLAction, SQLObservation
except ImportError:
try:
from models import SQLAction, SQLObservation # type: ignore[no-redef]
except ImportError:
from sql_env.models import SQLAction, SQLObservation # type: ignore[no-redef]
@runtime_checkable
class Policy(Protocol):
    """Structural interface that every evaluated policy must satisfy.

    Any object with a compatible ``select_action`` method matches this
    protocol. ``runtime_checkable`` allows ``isinstance`` checks, which test
    only that the method exists, not its signature.
    """

    def select_action(self, observation: SQLObservation) -> SQLAction:
        """Return the next action to take for ``observation``."""
        ...
@dataclass(frozen=True)
class EpisodeResult:
    """Immutable record of how a single evaluation episode went.

    ``error`` stays ``None`` for episodes that ran to completion. When the
    episode raised, it holds the exception message instead.
    """

    # Zero-based position of this episode within the evaluation run.
    episode_index: int
    # Whether the episode was judged a success.
    correct: bool
    # Sum of the per-step rewards collected during the episode.
    total_reward: float
    # Number of actions taken before the episode finished.
    steps: int
    # Exception message if the episode failed; None otherwise.
    error: str | None = None
@dataclass(frozen=True)
class EvaluationResult:
    """Summary statistics for a whole evaluation run.

    Rates and averages are computed over completed episodes only, meaning
    those whose ``error`` is ``None``. ``episodes`` still lists every
    attempted episode.
    """

    # Fraction of completed episodes marked correct.
    success_rate: float
    # Mean total reward per completed episode.
    avg_reward: float
    # Mean number of steps per completed episode.
    avg_steps: float
    # Number of episodes requested/attempted.
    n_episodes: int
    # Number of episodes that finished without raising.
    n_completed: int
    # Per-episode results, including failed episodes.
    episodes: list[EpisodeResult]
class RandomPolicy:
    """Built-in random baseline policy.

    While budget remains, it explores by picking DESCRIBE, SAMPLE or QUERY
    at random against a random table parsed from the schema text. Once the
    budget is nearly spent, it answers with a random value scraped from the
    last result. Given the same seed and the same observations, the action
    sequence is reproducible.
    """

    _EXPLORATION_ACTIONS = ("DESCRIBE", "SAMPLE", "QUERY")
    # Matches numbered result rows such as "1. alice | 42".
    _ROW_PATTERN = re.compile(r"^\d+\.\s*(.+)$")

    def __init__(self, seed: int | None = None) -> None:
        # Private RNG so seeding does not touch the global ``random`` state.
        self._rng = random.Random(seed)

    def select_action(self, observation: SQLObservation) -> SQLAction:
        """Pick a random exploration action, or answer when budget runs out.

        If ``observation.budget_remaining`` is 1 or less, the policy submits
        an ANSWER drawn from ``observation.result``. Otherwise it explores a
        randomly chosen table from ``observation.schema_info``.
        """
        if observation.budget_remaining <= 1:
            return SQLAction(
                action_type="ANSWER",
                argument=self._random_answer(observation.result),
            )
        action_type = self._rng.choice(self._EXPLORATION_ACTIONS)
        table_name = self._random_table(observation.schema_info)
        if action_type == "QUERY":
            # Quote the identifier (doubling embedded quotes) so unusual
            # table names cannot break out of the generated SELECT.
            safe_table_name = table_name.replace('"', '""')
            argument = f'SELECT * FROM "{safe_table_name}" LIMIT 5'
        else:
            argument = table_name
        return SQLAction(action_type=action_type, argument=argument)

    def _random_table(self, schema_info: str) -> str:
        """Return a random table name from ``schema_info``, or ``"unknown"``."""
        table_names = self._extract_table_names(schema_info)
        if not table_names:
            return "unknown"
        return self._rng.choice(table_names)

    @classmethod
    def _extract_table_names(cls, schema_info: str) -> list[str]:
        """Parse table names from schema lines shaped like ``- name[: ...]``."""
        table_names: list[str] = []
        for line in schema_info.splitlines():
            stripped = line.strip()
            if not stripped.startswith("- "):
                continue
            # Keep only the text before the first ":" (e.g. "- users: id, name").
            candidate = stripped[2:].partition(":")[0].strip()
            if candidate:
                table_names.append(candidate)
        return table_names

    def _random_answer(self, result_text: str) -> str:
        """Return a random candidate answer from ``result_text``, or ``"unknown"``."""
        candidates = self._extract_answer_candidates(result_text)
        if not candidates:
            return "unknown"
        return self._rng.choice(candidates)

    @classmethod
    def _extract_answer_candidates(cls, result_text: str) -> list[str]:
        """Collect distinct answer guesses from numbered result rows.

        For a row like ``"1. a | b"``, the whole row value (``"a | b"``) and
        each non-empty ``|``-separated cell (``"a"``, ``"b"``) become
        candidates. Duplicates are dropped, keeping first-seen order.
        Without this, a single-column row would be counted twice (its only
        cell equals the whole row), and repeated values would skew the
        random pick.
        """
        # dict keeps insertion order and gives O(1) duplicate detection.
        candidates: dict[str, None] = {}
        for line in result_text.splitlines():
            match = cls._ROW_PATTERN.match(line.strip())
            if not match:
                continue
            row_value = match.group(1).strip()
            if not row_value:
                continue
            candidates.setdefault(row_value)
            for cell in row_value.split("|"):
                cell = cell.strip()
                if cell:
                    candidates.setdefault(cell)
        return list(candidates)
def evaluate(
    env: object,
    policy: Policy,
    n_episodes: int = 100,
    *,
    seed: int | None = None,
    progress_callback: Callable[[int, int], None] | None = None,
) -> EvaluationResult:
    """Roll out ``policy`` in ``env`` for ``n_episodes`` episodes and aggregate.

    Each episode is isolated. An exception raised while resetting or
    stepping the environment, or by the policy, is recorded in that
    episode's ``EpisodeResult.error``, and the run continues. Aggregate
    metrics cover only the episodes that finished without error.

    Args:
        env: Environment exposing ``reset(seed=...)`` and ``step(action)``.
            Both return an observation with ``done`` and ``reward``.
        policy: Object implementing ``select_action(observation)``.
        n_episodes: Number of episodes to run; must be non-negative.
        seed: Base seed; episode ``i`` is reset with ``seed + i``. When
            ``None``, every reset receives ``seed=None``.
        progress_callback: Optional hook, called as ``(finished, total)``
            after every episode, including failed ones.

    Returns:
        An ``EvaluationResult``. Its rate and averages are 0.0 when no
        episode completed.

    Raises:
        ValueError: If ``n_episodes`` is negative.
    """
    if n_episodes < 0:
        raise ValueError("n_episodes must be >= 0")
    if n_episodes == 0:
        return EvaluationResult(
            success_rate=0.0,
            avg_reward=0.0,
            avg_steps=0.0,
            n_episodes=0,
            n_completed=0,
            episodes=[],
        )

    results: list[EpisodeResult] = []
    for index in range(n_episodes):
        try:
            outcome = _run_episode(env, policy, index, seed)
        except Exception as exc:  # deliberate: one bad episode must not abort the run
            outcome = EpisodeResult(
                episode_index=index,
                correct=False,
                total_reward=0.0,
                steps=0,
                error=str(exc),
            )
        results.append(outcome)
        if progress_callback is not None:
            progress_callback(index + 1, n_episodes)

    finished = [ep for ep in results if ep.error is None]
    count = len(finished)
    if count == 0:
        # Every episode failed: report zeros instead of dividing by zero.
        return EvaluationResult(
            success_rate=0.0,
            avg_reward=0.0,
            avg_steps=0.0,
            n_episodes=n_episodes,
            n_completed=0,
            episodes=results,
        )

    wins = sum(1 for ep in finished if ep.correct)
    return EvaluationResult(
        success_rate=wins / count,
        avg_reward=sum(ep.total_reward for ep in finished) / count,
        avg_steps=sum(ep.steps for ep in finished) / count,
        n_episodes=n_episodes,
        n_completed=count,
        episodes=results,
    )


def _run_episode(
    env: object,
    policy: Policy,
    episode_index: int,
    base_seed: int | None,
) -> EpisodeResult:
    """Play one episode to completion and summarize it (exceptions propagate)."""
    episode_seed = None if base_seed is None else base_seed + episode_index
    observation = env.reset(seed=episode_seed)
    reward_sum = 0.0
    step_count = 0
    while not observation.done:
        observation = env.step(policy.select_action(observation))
        # A missing (None) reward counts as zero.
        reward_sum += observation.reward or 0.0
        step_count += 1
    # Success is judged on the reward of the observation that ended the episode.
    final_reward = observation.reward or 0.0
    return EpisodeResult(
        episode_index=episode_index,
        correct=final_reward > 0.0,
        total_reward=reward_sum,
        steps=step_count,
    )