sql_env / server /sql_environment.py
hjerpe's picture
Upload folder using huggingface_hub
5dd1bb4 verified
import json
import logging
from pathlib import Path
import random
import re
import sqlite3
import time
import uuid
from openenv.core.env_server.interfaces import Environment, Message, ModelTokenizer, Transform
from .reward import compute_step_reward
from .verifier import verify_answer
try:
from sql_env.models import EpisodeContext, QuestionRecord, SQLAction, SQLObservation, SQLState
except ImportError:
# Fallback for Docker where PYTHONPATH=/app/env
from models import ( # type: ignore[no-redef]
EpisodeContext,
QuestionRecord,
SQLAction,
SQLObservation,
SQLState,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Captures the bare identifier that directly follows a FROM or JOIN keyword
# (case-insensitive). Quoted identifiers and parenthesised subqueries do not
# match because the captured name must start with a letter or underscore.
_TABLE_FROM_JOIN_PATTERN = re.compile(
    r"\b(?:FROM|JOIN)\s+([A-Za-z_][A-Za-z0-9_]*)", re.IGNORECASE
)

# Captures the first word of a statement, skipping any leading whitespace.
_FIRST_KEYWORD_PATTERN = re.compile(r"^[\s\n\r\t]*(\w+)")
class SQLEnvironment(Environment[SQLAction, SQLObservation, SQLState]):
"""SQLEnv server implementation with a structured SQL action loop."""
def __init__(
    self,
    questions_path: str,
    db_dir: str,
    tokenizer: ModelTokenizer,
    step_budget: int = 15,
    transform: Transform | None = None,
):
    """Validate the constructor inputs, load the questions and clear episode state.

    Args:
        questions_path: Location of the questions JSON file.
        db_dir: Directory that holds the SQLite databases.
        tokenizer: Object that must expose an ``apply_chat_template`` attribute.
        step_budget: Per-episode step budget; must be greater than zero.
        transform: Optional transform forwarded to the base environment.

    Raises:
        ValueError: If the tokenizer lacks ``apply_chat_template``, the budget
            is not positive, or the questions file yields no records.
        FileNotFoundError: If the questions file or the database directory
            does not exist.
    """
    super().__init__(transform=transform)

    # Cheap argument checks first, filesystem checks afterwards.
    if not hasattr(tokenizer, "apply_chat_template"):
        raise ValueError("Tokenizer must have 'apply_chat_template' method")
    if step_budget <= 0:
        raise ValueError("step_budget must be a positive integer")

    q_file = Path(questions_path)
    db_root = Path(db_dir)
    if not q_file.exists():
        raise FileNotFoundError(f"Questions file not found: {q_file}")
    if not (db_root.exists() and db_root.is_dir()):
        raise FileNotFoundError(f"Database directory not found: {db_root}")

    # Configuration captured from the arguments.
    self.tokenizer = tokenizer
    self.questions_path = q_file
    self.db_dir = db_root
    self.step_budget = step_budget

    self.questions = self._load_questions(str(q_file))
    if not self.questions:
        raise ValueError("Questions file contains no questions")

    # Per-episode bookkeeping; populated by reset() and step().
    self._episode: EpisodeContext | None = None
    self._last_result = ""
    self._last_error = ""
    self._last_reward: float | None = None
    self._last_query_truncated = False
    self._state = SQLState()
def _extract_tables_from_sql(self, sql: str) -> list[str]:
    """Return the distinct table names that follow FROM/JOIN, in first-seen order."""
    # dict keys keep insertion order, which gives an order-preserving de-dup.
    return list(dict.fromkeys(_TABLE_FROM_JOIN_PATTERN.findall(sql)))
def _load_questions(self, path: str) -> list[QuestionRecord]:
"""Load Spider questions JSON into QuestionRecord instances."""
questions_path = Path(path)
if not questions_path.exists():
raise FileNotFoundError(f"Questions file not found: {questions_path}")
try:
with questions_path.open("r", encoding="utf-8") as handle:
payload = json.load(handle)
except json.JSONDecodeError as exc:
raise ValueError(f"Invalid questions JSON format: {questions_path}") from exc
if not isinstance(payload, list):
raise ValueError("Questions JSON must be an array of records")
question_records: list[QuestionRecord] = []
for idx, item in enumerate(payload):
if not isinstance(item, dict):
raise ValueError(f"Question at index {idx} must be an object")
question_text = item.get("question")
db_name = item.get("db_id")
gold_sql = item.get("query")
if not isinstance(question_text, str) or not question_text.strip():
raise ValueError(f"Question at index {idx} missing non-empty 'question'")
if not isinstance(db_name, str) or not db_name.strip():
raise ValueError(f"Question at index {idx} missing non-empty 'db_id'")
if not isinstance(gold_sql, str) or not gold_sql.strip():
raise ValueError(f"Question at index {idx} missing non-empty 'query'")
normalized_db_name = db_name.strip()
if not re.fullmatch(r"[A-Za-z0-9_]+", normalized_db_name):
raise ValueError(
f"Question at index {idx} has invalid db_id '{normalized_db_name}'"
)
question_records.append(
QuestionRecord(
question_id=f"q-{idx}",
question_text=question_text,
database_name=normalized_db_name,
gold_sql=gold_sql,
gold_answer="",
answer_type="string",
difficulty="medium",
tables_involved=self._extract_tables_from_sql(gold_sql),
)
)
return question_records
def _open_db(self, db_name: str) -> sqlite3.Connection:
"""Open a read-only SQLite connection for the requested database."""
normalized_db_name = db_name.strip()
if not re.fullmatch(r"[A-Za-z0-9_]+", normalized_db_name):
raise ValueError(f"Invalid database name: '{db_name}'")
candidates = [
(self.db_dir / normalized_db_name / f"{normalized_db_name}.sqlite").resolve(),
(self.db_dir / f"{normalized_db_name}.sqlite").resolve(),
]
db_root = self.db_dir.resolve()
db_path = next(
(
candidate
for candidate in candidates
if candidate.exists() and db_root in candidate.parents
),
None,
)
if db_path is None:
raise FileNotFoundError(
f"Database '{normalized_db_name}' not found in {self.db_dir}"
)
uri = f"file:{db_path}?mode=ro"
return sqlite3.connect(uri, uri=True)
def _format_gold_answer(self, rows: list[tuple]) -> str:
"""Convert SQL rows into a stable string answer for episode comparison."""
if not rows:
return ""
if len(rows) == 1 and len(rows[0]) == 1:
return str(rows[0][0])
return "\n".join(" | ".join(str(value) for value in row) for row in rows)
def _execute_gold_sql(
self,
connection: sqlite3.Connection,
sql: str,
timeout_s: float = 5.0,
) -> list[tuple]:
"""Execute gold SQL with read-only/SELECT-only timeout protections."""
sql_stripped = sql.strip()
if not sql_stripped:
raise ValueError("SQL query cannot be empty")
first_keyword_match = _FIRST_KEYWORD_PATTERN.match(sql_stripped)
first_keyword = (
first_keyword_match.group(1).upper() if first_keyword_match else ""
)
if first_keyword != "SELECT":
raise ValueError(f"Only SELECT queries are allowed. Got: {first_keyword}")
deadline = time.monotonic() + timeout_s
def _progress_callback() -> int:
return 1 if time.monotonic() > deadline else 0
connection.set_progress_handler(_progress_callback, 1000)
try:
cursor = connection.cursor()
cursor.execute(sql_stripped)
return cursor.fetchall()
except sqlite3.OperationalError as exc:
if "interrupted" in str(exc).lower():
raise sqlite3.OperationalError(
f"Query timed out after {timeout_s:.1f} seconds"
) from exc
raise
finally:
connection.set_progress_handler(None, 0)
def reset(
self,
*,
seed: int | None = None,
episode_id: str | None = None,
**kwargs,
) -> SQLObservation:
"""Reset episode context and return the initial rich observation."""
del kwargs
if self._episode is not None:
self._episode.db_connection.close()
chooser = random.Random(seed) if seed is not None else random
question = chooser.choice(self.questions)
connection = self._open_db(question.database_name)
try:
gold_rows = self._execute_gold_sql(connection, question.gold_sql)
except sqlite3.Error:
connection.close()
raise
gold_answer = self._format_gold_answer(gold_rows)
question_for_episode = QuestionRecord(
question_id=question.question_id,
question_text=question.question_text,
database_name=question.database_name,
gold_sql=question.gold_sql,
gold_answer=gold_answer,
answer_type=question.answer_type,
difficulty=question.difficulty,
tables_involved=list(question.tables_involved),
)
resolved_episode_id = episode_id or str(uuid.uuid4())
self._episode = EpisodeContext(
episode_id=resolved_episode_id,
db_connection=connection,
question_record=question_for_episode,
step_count=0,
budget=self.step_budget,
done=False,
gold_answer=gold_answer,
gold_rows=gold_rows,
)
self._state.episode_id = resolved_episode_id
self._state.step_count = 0
self._state.current_action_type = "QUERY"
self._state.history_messages = []
self._state.history_tokens = []
self._last_result = ""
self._last_error = ""
self._last_reward = None
self._last_query_truncated = False
return self._build_observation()
def _get_table_names(self, connection: sqlite3.Connection) -> list[str]:
"""Return user-visible table names for the active SQLite database."""
cursor = connection.cursor()
cursor.execute(
"""
SELECT name
FROM sqlite_master
WHERE type = 'table' AND name NOT LIKE 'sqlite_%'
ORDER BY name
"""
)
return [str(row[0]) for row in cursor.fetchall()]
def _resolve_table_name(self, table_name: str) -> tuple[str | None, list[str]]:
"""Resolve requested table name against active DB tables."""
if self._episode is None:
return None, []
available_tables = self._get_table_names(self._episode.db_connection)
lookup = {table.lower(): table for table in available_tables}
resolved = lookup.get(table_name.strip().lower())
return resolved, available_tables
def _format_rows(self, rows: list[tuple]) -> str:
"""Format SQL rows as readable text."""
if not rows:
return "No rows returned."
lines = [f"{idx}. {' | '.join(str(value) for value in row)}" for idx, row in enumerate(rows, start=1)]
return "\n".join(lines)
def _execute_sql(self, sql: str, timeout_s: float = 5.0) -> list[tuple]:
"""Execute SQL in sandbox: SELECT-only, single statement, timeout, truncation."""
if self._episode is None:
raise RuntimeError("No active episode. Call reset() before step().")
sql_stripped = sql.strip()
if not sql_stripped:
raise ValueError("SQL query cannot be empty")
first_keyword_match = _FIRST_KEYWORD_PATTERN.match(sql_stripped)
first_keyword = (
first_keyword_match.group(1).upper() if first_keyword_match else ""
)
if first_keyword != "SELECT":
raise ValueError(f"Only SELECT queries are allowed. Got: {first_keyword}")
single_statement_sql = sql_stripped.rstrip(";").strip()
if ";" in single_statement_sql:
raise ValueError("Only a single SELECT statement is allowed")
deadline = time.monotonic() + timeout_s
def _progress_callback() -> int:
return 1 if time.monotonic() > deadline else 0
connection = self._episode.db_connection
connection.set_progress_handler(_progress_callback, 1000)
self._last_query_truncated = False
try:
cursor = connection.cursor()
cursor.execute(sql_stripped)
rows = cursor.fetchmany(21)
if len(rows) > 20:
self._last_query_truncated = True
rows = rows[:20]
return rows
except sqlite3.OperationalError as exc:
if "interrupted" in str(exc).lower():
raise sqlite3.OperationalError(
f"Query timed out after {timeout_s:.1f} seconds"
) from exc
raise
finally:
connection.set_progress_handler(None, 0)
def _handle_describe(self, table_name: str) -> str:
"""Return table schema and row count."""
if self._episode is None:
raise RuntimeError("No active episode. Call reset() before step().")
requested = table_name.strip()
if not requested:
raise ValueError("Argument cannot be empty for DESCRIBE")
resolved_table, available_tables = self._resolve_table_name(requested)
if resolved_table is None:
available = ", ".join(available_tables) if available_tables else "none"
raise ValueError(
f"Table '{requested}' not found. Available tables: {available}"
)
safe_identifier = resolved_table.replace('"', '""')
cursor = self._episode.db_connection.cursor()
cursor.execute(f'PRAGMA table_info("{safe_identifier}")')
columns = cursor.fetchall()
if not columns:
raise ValueError(f"Table '{resolved_table}' has no visible columns")
cursor.execute(f'SELECT COUNT(*) FROM "{safe_identifier}"')
row_count = int(cursor.fetchone()[0])
self._episode.described_tables.add(resolved_table)
lines = [f"Table '{resolved_table}' columns:"]
for _, col_name, col_type, _, _, _ in columns:
normalized_type = str(col_type).strip() or "UNKNOWN"
lines.append(f"- {col_name}: {normalized_type}")
lines.append(f"Row count: {row_count}")
return "\n".join(lines)
def _handle_sample(self, table_name: str, limit: int = 5) -> str:
"""Return sample rows from a table."""
if self._episode is None:
raise RuntimeError("No active episode. Call reset() before step().")
requested = table_name.strip()
if not requested:
raise ValueError("Argument cannot be empty for SAMPLE")
resolved_table, available_tables = self._resolve_table_name(requested)
if resolved_table is None:
available = ", ".join(available_tables) if available_tables else "none"
raise ValueError(
f"Table '{requested}' not found. Available tables: {available}"
)
safe_identifier = resolved_table.replace('"', '""')
bounded_limit = max(1, min(limit, 20))
rows = self._execute_sql(
f'SELECT * FROM "{safe_identifier}" LIMIT {bounded_limit}'
)
return f"Sample from '{resolved_table}':\n{self._format_rows(rows)}"
def _handle_query(self, sql: str) -> tuple[str, list[tuple]]:
"""Execute query and return formatted output with raw result rows."""
sql_text = sql.strip()
if not sql_text:
raise ValueError("Argument cannot be empty for QUERY")
rows = self._execute_sql(sql_text, timeout_s=5.0)
output = self._format_rows(rows)
if self._last_query_truncated:
output = f"{output}\n... (truncated to 20 rows)"
return output, rows
def _handle_answer(self, value: str) -> tuple[bool, float]:
"""Compare submitted answer against episode gold answer."""
if self._episode is None:
raise RuntimeError("No active episode. Call reset() before step().")
is_correct = verify_answer(
predicted=value,
gold=self._episode.gold_answer or "",
answer_type=self._episode.question_record.answer_type,
gold_rows=self._episode.gold_rows,
)
self._episode.done = True
return is_correct, 1.0 if is_correct else 0.0
def step(
    self,
    action: SQLAction,
    *,
    timeout_s: float = 30,
    **kwargs,
) -> SQLObservation:
    """Dispatch one structured action and return the updated observation.

    Behaviour visible in this method:

    * Without an active episode an error observation is returned.
    * Once the episode is done, further calls return the current
      observation without changing anything.
    * Unknown action types and blank arguments consume one step and one
      unit of budget and are reported through the observation's error.
    * ANSWER increments the step count, leaves the budget untouched and
      returns immediately with the reward from ``_handle_answer``.
    * DESCRIBE / SAMPLE / QUERY consume one step and one unit of budget.
      ``ValueError`` and ``sqlite3.Error`` raised by the handlers are
      turned into error text instead of propagating.

    Args:
        action: Action whose ``action_type`` and ``argument`` are read.
        timeout_s: Accepted but discarded by this implementation.
        **kwargs: Accepted but discarded by this implementation.
    """
    # Both parameters are accepted but deliberately unused here.
    del timeout_s
    del kwargs
    if self._episode is None:
        self._last_result = ""
        self._last_error = "No active episode. Call reset() before step()."
        self._last_reward = None
        return self._build_observation()
    if self._episode.done:
        # Finished episodes are frozen: no counters or results are touched.
        return self._build_observation()
    action_type = str(action.action_type).strip().upper()
    argument = str(action.argument)
    self._state.current_action_type = action_type or "QUERY"
    # Clear the per-step outputs before dispatching.
    self._last_result = ""
    self._last_error = ""
    self._last_reward = None
    # Inputs for the reward function; only the QUERY branch overwrites them.
    reward_rows: list[tuple] | None = []
    reward_sql = ""

    def _consume_invalid_step(error_text: str) -> SQLObservation:
        # Charge a step and a budget unit for an action rejected up front.
        self._last_error = error_text
        self._episode.step_count += 1
        self._episode.budget = max(0, self._episode.budget - 1)
        self._episode.action_log.append(f"{action_type} -> ERROR: {error_text}")
        if self._episode.budget == 0:
            self._episode.done = True
            # NOTE(review): the original indentation was lost; this reward
            # assignment is assumed to sit inside the budget check — confirm.
            self._last_reward = 0.0
        self._state.step_count = self._episode.step_count
        return self._build_observation()

    valid_action_types = {"DESCRIBE", "SAMPLE", "QUERY", "ANSWER"}
    if action_type not in valid_action_types:
        return _consume_invalid_step(
            f"Unknown action type '{action.action_type}'. "
            "Valid types: DESCRIBE, SAMPLE, QUERY, ANSWER"
        )
    argument_stripped = argument.strip()
    if not argument_stripped:
        return _consume_invalid_step(
            f"Argument cannot be empty for {action_type}"
        )
    try:
        if action_type == "DESCRIBE":
            self._last_result = self._handle_describe(argument_stripped)
        elif action_type == "SAMPLE":
            self._last_result = self._handle_sample(argument_stripped)
        elif action_type == "QUERY":
            reward_sql = argument_stripped
            self._last_result, reward_rows = self._handle_query(argument_stripped)
        else:
            # ANSWER: returns from here, so the budget bookkeeping and
            # compute_step_reward call below are skipped.
            is_correct, reward = self._handle_answer(argument_stripped)
            verdict = "correct" if is_correct else "incorrect"
            self._last_result = f"Answer submitted: {verdict}."
            self._last_reward = reward
            self._episode.step_count += 1
            self._episode.action_log.append(
                f"ANSWER {argument_stripped} -> {verdict}"
            )
            self._state.step_count = self._episode.step_count
            return self._build_observation()
    except ValueError as exc:
        self._last_error = str(exc)
    except sqlite3.Error as exc:
        self._last_error = f"SQL error: {exc}"
    # Shared bookkeeping for DESCRIBE / SAMPLE / QUERY, success or failure.
    self._episode.step_count += 1
    self._episode.budget = max(0, self._episode.budget - 1)
    self._state.step_count = self._episode.step_count
    if self._episode.budget > 0:
        # The step reward is only computed while budget remains.
        self._last_reward = compute_step_reward(
            ctx=self._episode,
            action_type=action_type,
            sql=reward_sql,
            rows=reward_rows,
            error=self._last_error or None,
        )
    if self._last_error:
        self._episode.action_log.append(f"{action_type} -> ERROR: {self._last_error}")
    else:
        # Log only the first line of the result to keep the history short.
        preview = self._last_result.splitlines()[0] if self._last_result else "ok"
        self._episode.action_log.append(f"{action_type} -> {preview}")
    if self._episode.budget == 0:
        self._episode.done = True
        # NOTE(review): the original indentation was lost; this default is
        # assumed to apply only when the budget is exhausted — confirm.
        if self._last_reward is None:
            self._last_reward = 0.0
    return self._build_observation()
def _build_observation(self) -> SQLObservation:
    """Assemble the observation for the current state and pass it through the transform.

    Without an episode, the observation carries only the last result, error
    and reward. With one, it also lists the available tables and a
    one-line column summary for every table described so far. If the
    transform returns something other than an ``SQLObservation``, its
    attributes are copied into a new one, with defaults for missing fields
    other than ``done`` and ``reward``.
    """
    episode = self._episode
    if episode is None:
        observation = SQLObservation(
            question="",
            schema_info="",
            result=self._last_result,
            error=self._last_error,
            step_count=0,
            budget_remaining=0,
            action_history=[],
            done=False,
            reward=self._last_reward,
        )
    else:
        connection = episode.db_connection
        table_names = self._get_table_names(connection)
        visible = set(table_names)

        schema_lines = ["Available tables:"]
        schema_lines.extend(f"- {name}" for name in table_names)

        described = episode.described_tables
        if described:
            schema_lines.extend(["", "Described tables:"])
            for table_name in sorted(described):
                if table_name in visible:
                    quoted_name = table_name.replace('"', '""')
                    runner = connection.cursor()
                    runner.execute(f'PRAGMA table_info("{quoted_name}")')
                    columns = runner.fetchall()
                    if columns:
                        parts = [
                            f"{str(column[1])} {str(column[2]) or 'UNKNOWN'}"
                            for column in columns
                        ]
                        detail = ", ".join(parts)
                    else:
                        detail = "no columns available"
                else:
                    # Described earlier but no longer listed by the database.
                    detail = "unavailable (not in active schema)"
                schema_lines.append(f"- {table_name}: {detail}")

        observation = SQLObservation(
            question=episode.question_record.question_text,
            schema_info="\n".join(schema_lines),
            result=self._last_result,
            error=self._last_error,
            step_count=episode.step_count,
            budget_remaining=episode.budget,
            action_history=list(episode.action_log),
            done=episode.done,
            reward=self._last_reward,
        )

    transformed = self._apply_transform(observation)
    if isinstance(transformed, SQLObservation):
        return transformed

    # Rebuild a proper SQLObservation from whatever the transform returned.
    fallbacks = {
        "question": "",
        "schema_info": "",
        "result": "",
        "error": "",
        "step_count": 0,
        "budget_remaining": 0,
        "action_history": [],
    }
    copied = {
        field: getattr(transformed, field, default)
        for field, default in fallbacks.items()
    }
    return SQLObservation(
        **copied,
        done=transformed.done,
        reward=transformed.reward,
    )
@property
def state(self) -> SQLState:
    """Return the environment's ``SQLState`` object.

    The same instance is returned on every access (no copy is made), so
    callers see later updates made by ``reset``, ``step`` and
    ``message_to_action``.
    """
    return self._state
def message_to_action(self, message: Message) -> SQLAction:
"""Convert free-form messages into structured SQLAction values."""
if "role" not in message:
raise ValueError("Message must contain a 'role' key")
if "content" not in message:
raise ValueError("Message must contain a 'content' key")
if message["content"] is None:
raise ValueError("Message content cannot be None")
content = str(message["content"])
parsed = content.strip()
action_type = "QUERY"
argument = content
if message["role"].lower() == "user" and parsed:
prefix, separator, remainder = parsed.partition(" ")
normalized_prefix = prefix.upper()
if normalized_prefix in {"DESCRIBE", "SAMPLE", "QUERY", "ANSWER"}:
action_type = normalized_prefix
if separator:
argument = remainder
else:
argument = ""
self._state.current_action_type = action_type
self._state.history_messages.append(message)
return SQLAction(action_type=action_type, argument=argument)