"""SQLEnv server: a structured DESCRIBE / SAMPLE / QUERY / ANSWER action loop."""

import json
import logging
from pathlib import Path
import random
import re
import sqlite3
import time
import uuid

from openenv.core.env_server.interfaces import (
    Environment,
    Message,
    ModelTokenizer,
    Transform,
)

from .reward import compute_step_reward
from .verifier import verify_answer

try:
    from sql_env.models import (
        EpisodeContext,
        QuestionRecord,
        SQLAction,
        SQLObservation,
        SQLState,
    )
except ImportError:
    # Fallback for Docker where PYTHONPATH=/app/env
    from models import (  # type: ignore[no-redef]
        EpisodeContext,
        QuestionRecord,
        SQLAction,
        SQLObservation,
        SQLState,
    )

logger = logging.getLogger(__name__)

# Captures the identifier that directly follows a FROM or JOIN keyword.
_TABLE_FROM_JOIN_PATTERN = re.compile(
    r"\b(?:FROM|JOIN)\s+([A-Za-z_][A-Za-z0-9_]*)", re.IGNORECASE
)
# Captures the first word of a statement, skipping leading whitespace.
_FIRST_KEYWORD_PATTERN = re.compile(r"^[\s\n\r\t]*(\w+)")


class SQLEnvironment(Environment[SQLAction, SQLObservation, SQLState]):
    """SQLEnv server implementation with a structured SQL action loop."""

    def __init__(
        self,
        questions_path: str,
        db_dir: str,
        tokenizer: ModelTokenizer,
        step_budget: int = 15,
        transform: Transform | None = None,
    ):
        """Validate configuration and load the question set.

        Args:
            questions_path: Path to the questions JSON file.
            db_dir: Directory that holds the SQLite database files.
            tokenizer: Object exposing an ``apply_chat_template`` attribute.
            step_budget: Initial per-episode budget; must be positive.
            transform: Optional transform forwarded to the base class.

        Raises:
            ValueError: If the tokenizer lacks ``apply_chat_template``, the
                budget is not positive, or the questions file is empty or
                malformed.
            FileNotFoundError: If the questions file or database directory
                does not exist.
        """
        super().__init__(transform=transform)

        if not hasattr(tokenizer, "apply_chat_template"):
            raise ValueError("Tokenizer must have 'apply_chat_template' method")
        if step_budget <= 0:
            raise ValueError("step_budget must be a positive integer")

        questions_file = Path(questions_path)
        database_dir = Path(db_dir)
        if not questions_file.exists():
            raise FileNotFoundError(f"Questions file not found: {questions_file}")
        if not database_dir.exists() or not database_dir.is_dir():
            raise FileNotFoundError(f"Database directory not found: {database_dir}")

        self.tokenizer = tokenizer
        self.questions_path = questions_file
        self.db_dir = database_dir
        self.step_budget = step_budget
        self.questions = self._load_questions(str(questions_file))
        if not self.questions:
            raise ValueError("Questions file contains no questions")

        # Per-episode context; None until reset() succeeds.
        self._episode: EpisodeContext | None = None
        # Outcome of the most recent step, surfaced through observations.
        self._last_result = ""
        self._last_error = ""
        self._last_reward: float | None = None
        self._last_query_truncated = False
        self._state = SQLState()

    def _extract_tables_from_sql(self, sql: str) -> list[str]:
        """Extract table names from basic FROM/JOIN clauses.

        Returns the distinct identifiers in first-seen order.
        """
        tables: list[str] = []
        for match in _TABLE_FROM_JOIN_PATTERN.findall(sql):
            if match not in tables:
                tables.append(match)
        return tables

    def _load_questions(self, path: str) -> list[QuestionRecord]:
        """Load Spider questions JSON into QuestionRecord instances.

        Args:
            path: Location of a JSON file holding an array of objects, each
                with non-empty string ``question``, ``db_id`` and ``query``.

        Returns:
            One record per array entry, with ids of the form ``q-<index>``.
            ``gold_answer`` is left empty here; reset() fills it per episode.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            ValueError: If the JSON is invalid, is not an array, or an entry
                is missing a required field or has an unsafe ``db_id``.
        """
        questions_path = Path(path)
        if not questions_path.exists():
            raise FileNotFoundError(f"Questions file not found: {questions_path}")

        try:
            with questions_path.open("r", encoding="utf-8") as handle:
                payload = json.load(handle)
        except json.JSONDecodeError as exc:
            raise ValueError(
                f"Invalid questions JSON format: {questions_path}"
            ) from exc

        if not isinstance(payload, list):
            raise ValueError("Questions JSON must be an array of records")

        question_records: list[QuestionRecord] = []
        for idx, item in enumerate(payload):
            if not isinstance(item, dict):
                raise ValueError(f"Question at index {idx} must be an object")

            question_text = item.get("question")
            db_name = item.get("db_id")
            gold_sql = item.get("query")

            if not isinstance(question_text, str) or not question_text.strip():
                raise ValueError(
                    f"Question at index {idx} missing non-empty 'question'"
                )
            if not isinstance(db_name, str) or not db_name.strip():
                raise ValueError(f"Question at index {idx} missing non-empty 'db_id'")
            if not isinstance(gold_sql, str) or not gold_sql.strip():
                raise ValueError(f"Question at index {idx} missing non-empty 'query'")

            # db_id later becomes part of a filesystem path, so restrict it
            # to a conservative character set.
            normalized_db_name = db_name.strip()
            if not re.fullmatch(r"[A-Za-z0-9_]+", normalized_db_name):
                raise ValueError(
                    f"Question at index {idx} has invalid db_id '{normalized_db_name}'"
                )

            question_records.append(
                QuestionRecord(
                    question_id=f"q-{idx}",
                    question_text=question_text,
                    database_name=normalized_db_name,
                    gold_sql=gold_sql,
                    gold_answer="",
                    answer_type="string",
                    difficulty="medium",
                    tables_involved=self._extract_tables_from_sql(gold_sql),
                )
            )

        return question_records

    def _open_db(self, db_name: str) -> sqlite3.Connection:
        """Open a read-only SQLite connection for the requested database.

        Looks for ``<db_dir>/<name>/<name>.sqlite`` first and then
        ``<db_dir>/<name>.sqlite``; only files located under ``db_dir`` are
        accepted.

        Raises:
            ValueError: If ``db_name`` contains characters outside
                ``[A-Za-z0-9_]``.
            FileNotFoundError: If no candidate file exists under ``db_dir``.
        """
        normalized_db_name = db_name.strip()
        if not re.fullmatch(r"[A-Za-z0-9_]+", normalized_db_name):
            raise ValueError(f"Invalid database name: '{db_name}'")

        candidates = [
            (
                self.db_dir / normalized_db_name / f"{normalized_db_name}.sqlite"
            ).resolve(),
            (self.db_dir / f"{normalized_db_name}.sqlite").resolve(),
        ]
        db_root = self.db_dir.resolve()
        db_path = next(
            (
                candidate
                for candidate in candidates
                if candidate.exists() and db_root in candidate.parents
            ),
            None,
        )
        if db_path is None:
            raise FileNotFoundError(
                f"Database '{normalized_db_name}' not found in {self.db_dir}"
            )

        # as_uri() percent-encodes the (absolute) path, so characters such as
        # '?', '#' or '%' in db_dir cannot be mistaken for URI syntax.
        uri = f"{db_path.as_uri()}?mode=ro"
        return sqlite3.connect(uri, uri=True)

    def _format_gold_answer(self, rows: list[tuple]) -> str:
        """Convert SQL rows into a stable string answer for episode comparison.

        No rows yields ``""``; a single one-column row yields that value as a
        string; anything else yields one line per row with values joined by
        ``" | "``.
        """
        if not rows:
            return ""
        if len(rows) == 1 and len(rows[0]) == 1:
            return str(rows[0][0])
        return "\n".join(" | ".join(str(value) for value in row) for row in rows)

    def _execute_gold_sql(
        self,
        connection: sqlite3.Connection,
        sql: str,
        timeout_s: float = 5.0,
    ) -> list[tuple]:
        """Execute gold SQL with read-only/SELECT-only timeout protections.

        Args:
            connection: Connection to run the statement on.
            sql: Statement text; its first keyword must be SELECT.
            timeout_s: Wall-clock limit enforced through a progress handler.

        Returns:
            All result rows.

        Raises:
            ValueError: If the statement is empty or not a SELECT.
            sqlite3.OperationalError: On timeout or other operational failure.
        """
        sql_stripped = sql.strip()
        if not sql_stripped:
            raise ValueError("SQL query cannot be empty")

        first_keyword_match = _FIRST_KEYWORD_PATTERN.match(sql_stripped)
        first_keyword = (
            first_keyword_match.group(1).upper() if first_keyword_match else ""
        )
        if first_keyword != "SELECT":
            raise ValueError(f"Only SELECT queries are allowed. Got: {first_keyword}")

        deadline = time.monotonic() + timeout_s

        def _progress_callback() -> int:
            # A non-zero return asks SQLite to abort the running statement.
            return 1 if time.monotonic() > deadline else 0

        connection.set_progress_handler(_progress_callback, 1000)
        try:
            cursor = connection.cursor()
            cursor.execute(sql_stripped)
            return cursor.fetchall()
        except sqlite3.OperationalError as exc:
            if "interrupted" in str(exc).lower():
                raise sqlite3.OperationalError(
                    f"Query timed out after {timeout_s:.1f} seconds"
                ) from exc
            raise
        finally:
            connection.set_progress_handler(None, 0)

    def reset(
        self,
        *,
        seed: int | None = None,
        episode_id: str | None = None,
        **kwargs,
    ) -> SQLObservation:
        """Reset episode context and return the initial rich observation.

        Args:
            seed: When given, selects the question via ``random.Random(seed)``;
                otherwise the module-level ``random`` generator is used.
            episode_id: Identifier for the episode; a UUID4 string is
                generated when omitted or empty.
            **kwargs: Accepted and ignored.

        Returns:
            The observation for the freshly started episode.

        Raises:
            ValueError, FileNotFoundError, sqlite3.Error: If the database
                cannot be opened or the gold query cannot be executed. In
                that case no episode is active afterwards.
        """
        del kwargs

        if self._episode is not None:
            # Clear the reference before closing so that a failure further
            # down never leaves self._episode bound to a closed connection.
            previous_episode = self._episode
            self._episode = None
            previous_episode.db_connection.close()

        chooser = random.Random(seed) if seed is not None else random
        question = chooser.choice(self.questions)
        connection = self._open_db(question.database_name)

        try:
            gold_rows = self._execute_gold_sql(connection, question.gold_sql)
        except Exception:
            # _execute_gold_sql raises ValueError as well as sqlite3.Error;
            # release the connection in every failure case, then propagate.
            connection.close()
            raise

        gold_answer = self._format_gold_answer(gold_rows)
        question_for_episode = QuestionRecord(
            question_id=question.question_id,
            question_text=question.question_text,
            database_name=question.database_name,
            gold_sql=question.gold_sql,
            gold_answer=gold_answer,
            answer_type=question.answer_type,
            difficulty=question.difficulty,
            tables_involved=list(question.tables_involved),
        )

        resolved_episode_id = episode_id or str(uuid.uuid4())
        self._episode = EpisodeContext(
            episode_id=resolved_episode_id,
            db_connection=connection,
            question_record=question_for_episode,
            step_count=0,
            budget=self.step_budget,
            done=False,
            gold_answer=gold_answer,
            gold_rows=gold_rows,
        )

        self._state.episode_id = resolved_episode_id
        self._state.step_count = 0
        self._state.current_action_type = "QUERY"
        self._state.history_messages = []
        self._state.history_tokens = []

        self._last_result = ""
        self._last_error = ""
        self._last_reward = None
        self._last_query_truncated = False

        return self._build_observation()

    def _get_table_names(self, connection: sqlite3.Connection) -> list[str]:
        """Return user-visible table names for the active SQLite database.

        Tables whose name starts with ``sqlite_`` are excluded; the result is
        ordered by name.
        """
        cursor = connection.cursor()
        cursor.execute(
            """
            SELECT name
            FROM sqlite_master
            WHERE type = 'table' AND name NOT LIKE 'sqlite_%'
            ORDER BY name
            """
        )
        return [str(row[0]) for row in cursor.fetchall()]

    def _resolve_table_name(self, table_name: str) -> tuple[str | None, list[str]]:
        """Resolve requested table name against active DB tables.

        Matching ignores case and surrounding whitespace.

        Returns:
            ``(resolved_name_or_None, available_tables)``; ``(None, [])`` when
            no episode is active.
        """
        if self._episode is None:
            return None, []

        available_tables = self._get_table_names(self._episode.db_connection)
        lookup = {table.lower(): table for table in available_tables}
        resolved = lookup.get(table_name.strip().lower())
        return resolved, available_tables

    def _format_rows(self, rows: list[tuple]) -> str:
        """Format SQL rows as readable text.

        Each row becomes ``"<n>. v1 | v2 | ..."`` with numbering from 1; an
        empty result yields ``"No rows returned."``.
        """
        if not rows:
            return "No rows returned."

        lines = [
            f"{idx}. {' | '.join(str(value) for value in row)}"
            for idx, row in enumerate(rows, start=1)
        ]
        return "\n".join(lines)

    def _execute_sql(self, sql: str, timeout_s: float = 5.0) -> list[tuple]:
        """Execute SQL in sandbox: SELECT-only, single statement, timeout, truncation.

        Args:
            sql: Statement text; must start with SELECT and contain no ``;``
                other than trailing ones.
            timeout_s: Wall-clock limit enforced through a progress handler.

        Returns:
            At most 20 rows. ``self._last_query_truncated`` records whether
            more rows were available.

        Raises:
            RuntimeError: If no episode is active.
            ValueError: If the statement is empty, not a SELECT, or contains
                more than one statement.
            sqlite3.OperationalError: On timeout or other operational failure.
        """
        if self._episode is None:
            raise RuntimeError("No active episode. Call reset() before step().")

        sql_stripped = sql.strip()
        if not sql_stripped:
            raise ValueError("SQL query cannot be empty")

        first_keyword_match = _FIRST_KEYWORD_PATTERN.match(sql_stripped)
        first_keyword = (
            first_keyword_match.group(1).upper() if first_keyword_match else ""
        )
        if first_keyword != "SELECT":
            raise ValueError(f"Only SELECT queries are allowed. Got: {first_keyword}")

        # Trailing semicolons are tolerated; any other one is rejected.
        single_statement_sql = sql_stripped.rstrip(";").strip()
        if ";" in single_statement_sql:
            raise ValueError("Only a single SELECT statement is allowed")

        deadline = time.monotonic() + timeout_s

        def _progress_callback() -> int:
            # A non-zero return asks SQLite to abort the running statement.
            return 1 if time.monotonic() > deadline else 0

        connection = self._episode.db_connection
        connection.set_progress_handler(_progress_callback, 1000)
        self._last_query_truncated = False

        try:
            cursor = connection.cursor()
            cursor.execute(sql_stripped)
            # Fetch one row beyond the cap purely to detect truncation.
            rows = cursor.fetchmany(21)
            if len(rows) > 20:
                self._last_query_truncated = True
                rows = rows[:20]
            return rows
        except sqlite3.OperationalError as exc:
            if "interrupted" in str(exc).lower():
                raise sqlite3.OperationalError(
                    f"Query timed out after {timeout_s:.1f} seconds"
                ) from exc
            raise
        finally:
            connection.set_progress_handler(None, 0)

    def _handle_describe(self, table_name: str) -> str:
        """Return table schema and row count.

        Also adds the resolved table to the episode's ``described_tables``.

        Raises:
            RuntimeError: If no episode is active.
            ValueError: If the argument is empty, the table is unknown, or it
                reports no columns.
        """
        if self._episode is None:
            raise RuntimeError("No active episode. Call reset() before step().")

        requested = table_name.strip()
        if not requested:
            raise ValueError("Argument cannot be empty for DESCRIBE")

        resolved_table, available_tables = self._resolve_table_name(requested)
        if resolved_table is None:
            available = ", ".join(available_tables) if available_tables else "none"
            raise ValueError(
                f"Table '{requested}' not found. Available tables: {available}"
            )

        # Double any embedded quotes so the name is safe inside "...".
        safe_identifier = resolved_table.replace('"', '""')
        cursor = self._episode.db_connection.cursor()
        cursor.execute(f'PRAGMA table_info("{safe_identifier}")')
        columns = cursor.fetchall()
        if not columns:
            raise ValueError(f"Table '{resolved_table}' has no visible columns")

        cursor.execute(f'SELECT COUNT(*) FROM "{safe_identifier}"')
        row_count = int(cursor.fetchone()[0])

        self._episode.described_tables.add(resolved_table)

        lines = [f"Table '{resolved_table}' columns:"]
        for _, col_name, col_type, _, _, _ in columns:
            normalized_type = str(col_type).strip() or "UNKNOWN"
            lines.append(f"- {col_name}: {normalized_type}")
        lines.append(f"Row count: {row_count}")
        return "\n".join(lines)

    def _handle_sample(self, table_name: str, limit: int = 5) -> str:
        """Return sample rows from a table.

        Args:
            table_name: Requested table, matched case-insensitively.
            limit: Desired row count, clamped to the range 1..20.

        Raises:
            RuntimeError: If no episode is active.
            ValueError: If the argument is empty or the table is unknown.
        """
        if self._episode is None:
            raise RuntimeError("No active episode. Call reset() before step().")

        requested = table_name.strip()
        if not requested:
            raise ValueError("Argument cannot be empty for SAMPLE")

        resolved_table, available_tables = self._resolve_table_name(requested)
        if resolved_table is None:
            available = ", ".join(available_tables) if available_tables else "none"
            raise ValueError(
                f"Table '{requested}' not found. Available tables: {available}"
            )

        # Double any embedded quotes so the name is safe inside "...".
        safe_identifier = resolved_table.replace('"', '""')
        bounded_limit = max(1, min(limit, 20))
        rows = self._execute_sql(
            f'SELECT * FROM "{safe_identifier}" LIMIT {bounded_limit}'
        )
        return f"Sample from '{resolved_table}':\n{self._format_rows(rows)}"

    def _handle_query(self, sql: str) -> tuple[str, list[tuple]]:
        """Execute query and return formatted output with raw result rows.

        A truncation notice is appended to the text when the result was cut
        to 20 rows.

        Raises:
            ValueError: If the argument is empty or rejected by the sandbox.
        """
        sql_text = sql.strip()
        if not sql_text:
            raise ValueError("Argument cannot be empty for QUERY")

        rows = self._execute_sql(sql_text, timeout_s=5.0)
        output = self._format_rows(rows)
        if self._last_query_truncated:
            output = f"{output}\n... (truncated to 20 rows)"
        return output, rows

    def _handle_answer(self, value: str) -> tuple[bool, float]:
        """Compare submitted answer against episode gold answer.

        Marks the episode as done regardless of the verdict.

        Returns:
            ``(is_correct, reward)`` where reward is 1.0 when correct, else 0.0.

        Raises:
            RuntimeError: If no episode is active.
        """
        if self._episode is None:
            raise RuntimeError("No active episode. Call reset() before step().")

        is_correct = verify_answer(
            predicted=value,
            gold=self._episode.gold_answer or "",
            answer_type=self._episode.question_record.answer_type,
            gold_rows=self._episode.gold_rows,
        )
        self._episode.done = True
        return is_correct, 1.0 if is_correct else 0.0

    def step(
        self,
        action: SQLAction,
        *,
        timeout_s: float = 30,
        **kwargs,
    ) -> SQLObservation:
        """Dispatch one structured action and return updated observation.

        Behaviour by case:
          * No active episode: returns an observation carrying an error.
          * Episode already done: returns the current observation unchanged.
          * Unknown action type or empty argument: consumes one step and one
            unit of budget and reports the error.
          * ANSWER: ends the episode with reward 1.0 or 0.0; the step count
            increases but the budget is not reduced.
          * DESCRIBE / SAMPLE / QUERY: consumes one step and one unit of
            budget; ``ValueError`` and ``sqlite3.Error`` are reported in the
            observation rather than raised. The step reward comes from
            ``compute_step_reward`` while budget remains; once the budget
            reaches zero the episode ends and a missing reward becomes 0.0.

        Args:
            action: Action whose ``action_type`` and ``argument`` are read.
            timeout_s: Accepted and ignored.
            **kwargs: Accepted and ignored.
        """
        del timeout_s
        del kwargs

        if self._episode is None:
            self._last_result = ""
            self._last_error = "No active episode. Call reset() before step()."
            self._last_reward = None
            return self._build_observation()

        if self._episode.done:
            return self._build_observation()

        action_type = str(action.action_type).strip().upper()
        argument = str(action.argument)

        self._state.current_action_type = action_type or "QUERY"
        self._last_result = ""
        self._last_error = ""
        self._last_reward = None
        reward_rows: list[tuple] | None = []
        reward_sql = ""

        def _consume_invalid_step(error_text: str) -> SQLObservation:
            # Malformed actions still cost a step and a unit of budget.
            self._last_error = error_text
            self._episode.step_count += 1
            self._episode.budget = max(0, self._episode.budget - 1)
            self._episode.action_log.append(f"{action_type} -> ERROR: {error_text}")
            if self._episode.budget == 0:
                self._episode.done = True
                self._last_reward = 0.0
            self._state.step_count = self._episode.step_count
            return self._build_observation()

        valid_action_types = {"DESCRIBE", "SAMPLE", "QUERY", "ANSWER"}
        if action_type not in valid_action_types:
            return _consume_invalid_step(
                f"Unknown action type '{action.action_type}'. "
                "Valid types: DESCRIBE, SAMPLE, QUERY, ANSWER"
            )

        argument_stripped = argument.strip()
        if not argument_stripped:
            return _consume_invalid_step(
                f"Argument cannot be empty for {action_type}"
            )

        try:
            if action_type == "DESCRIBE":
                self._last_result = self._handle_describe(argument_stripped)
            elif action_type == "SAMPLE":
                self._last_result = self._handle_sample(argument_stripped)
            elif action_type == "QUERY":
                reward_sql = argument_stripped
                self._last_result, reward_rows = self._handle_query(argument_stripped)
            else:
                # ANSWER terminates the episode and skips budget accounting.
                is_correct, reward = self._handle_answer(argument_stripped)
                verdict = "correct" if is_correct else "incorrect"
                self._last_result = f"Answer submitted: {verdict}."
                self._last_reward = reward
                self._episode.step_count += 1
                self._episode.action_log.append(
                    f"ANSWER {argument_stripped} -> {verdict}"
                )
                self._state.step_count = self._episode.step_count
                return self._build_observation()
        except ValueError as exc:
            self._last_error = str(exc)
        except sqlite3.Error as exc:
            self._last_error = f"SQL error: {exc}"

        self._episode.step_count += 1
        self._episode.budget = max(0, self._episode.budget - 1)
        self._state.step_count = self._episode.step_count

        if self._episode.budget > 0:
            self._last_reward = compute_step_reward(
                ctx=self._episode,
                action_type=action_type,
                sql=reward_sql,
                rows=reward_rows,
                error=self._last_error or None,
            )

        if self._last_error:
            self._episode.action_log.append(
                f"{action_type} -> ERROR: {self._last_error}"
            )
        else:
            preview = self._last_result.splitlines()[0] if self._last_result else "ok"
            self._episode.action_log.append(f"{action_type} -> {preview}")

        if self._episode.budget == 0:
            self._episode.done = True
            if self._last_reward is None:
                self._last_reward = 0.0

        return self._build_observation()

    def _build_observation(self) -> SQLObservation:
        """Construct a rich observation from the current episode context.

        Without an active episode the observation carries only the last
        result, error and reward. Otherwise it lists the available tables,
        a column summary for each described table, and the episode counters.
        The observation is passed through ``_apply_transform``; a result that
        is not an ``SQLObservation`` is copied field by field into one.
        """
        if self._episode is None:
            observation = SQLObservation(
                question="",
                schema_info="",
                result=self._last_result,
                error=self._last_error,
                step_count=0,
                budget_remaining=0,
                action_history=[],
                done=False,
                reward=self._last_reward,
            )
        else:
            table_names = self._get_table_names(self._episode.db_connection)
            known_tables = set(table_names)
            schema_lines = [
                "Available tables:",
                *[f"- {name}" for name in table_names],
            ]

            if self._episode.described_tables:
                schema_lines.append("")
                schema_lines.append("Described tables:")
                for table_name in sorted(self._episode.described_tables):
                    if table_name not in known_tables:
                        schema_lines.append(
                            f"- {table_name}: unavailable (not in active schema)"
                        )
                        continue

                    # Double any embedded quotes so the name is safe in "...".
                    safe_identifier = table_name.replace('"', '""')
                    cursor = self._episode.db_connection.cursor()
                    cursor.execute(f'PRAGMA table_info("{safe_identifier}")')
                    columns = cursor.fetchall()
                    if not columns:
                        schema_lines.append(f"- {table_name}: no columns available")
                        continue

                    column_summary = ", ".join(
                        f"{str(column[1])} {str(column[2]) or 'UNKNOWN'}"
                        for column in columns
                    )
                    schema_lines.append(f"- {table_name}: {column_summary}")

            observation = SQLObservation(
                question=self._episode.question_record.question_text,
                schema_info="\n".join(schema_lines),
                result=self._last_result,
                error=self._last_error,
                step_count=self._episode.step_count,
                budget_remaining=self._episode.budget,
                action_history=list(self._episode.action_log),
                done=self._episode.done,
                reward=self._last_reward,
            )

        transformed = self._apply_transform(observation)
        if isinstance(transformed, SQLObservation):
            return transformed

        return SQLObservation(
            question=getattr(transformed, "question", ""),
            schema_info=getattr(transformed, "schema_info", ""),
            result=getattr(transformed, "result", ""),
            error=getattr(transformed, "error", ""),
            step_count=getattr(transformed, "step_count", 0),
            budget_remaining=getattr(transformed, "budget_remaining", 0),
            action_history=getattr(transformed, "action_history", []),
            done=transformed.done,
            reward=transformed.reward,
        )

    @property
    def state(self) -> SQLState:
        """Get current exposed state metadata."""
        return self._state

    def message_to_action(self, message: Message) -> SQLAction:
        """Convert free-form messages into structured SQLAction values.

        For ``user`` messages whose first word is DESCRIBE, SAMPLE, QUERY or
        ANSWER (any case), that word becomes the action type and the rest of
        the text becomes the argument. Every other message maps to a QUERY
        whose argument is the unmodified content. The message is appended to
        ``state.history_messages``.

        Raises:
            ValueError: If ``role`` or ``content`` is missing, or ``content``
                is None.
        """
        if "role" not in message:
            raise ValueError("Message must contain a 'role' key")
        if "content" not in message:
            raise ValueError("Message must contain a 'content' key")
        if message["content"] is None:
            raise ValueError("Message content cannot be None")

        content = str(message["content"])
        parsed = content.strip()

        action_type = "QUERY"
        argument = content

        if message["role"].lower() == "user" and parsed:
            # Split on any whitespace run so "QUERY\nSELECT ..." or a
            # tab-separated command is recognised, not only "QUERY SELECT".
            parts = parsed.split(maxsplit=1)
            normalized_prefix = parts[0].upper()
            if normalized_prefix in {"DESCRIBE", "SAMPLE", "QUERY", "ANSWER"}:
                action_type = normalized_prefix
                argument = parts[1] if len(parts) > 1 else ""

        self._state.current_action_type = action_type
        self._state.history_messages.append(message)

        return SQLAction(action_type=action_type, argument=argument)