| """Curate multi-database Spider questions for SQLEnv.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import io |
| import json |
| import logging |
| import re |
| import sqlite3 |
| import time |
| import zipfile |
| from collections.abc import Iterable |
| from pathlib import Path |
| from typing import Any, Callable |
| from urllib.parse import quote |
|
|
| import requests |
|
|
|
|
# URL templates for a single Spider SQLite file. Each template is formatted
# with ``db_id``; ``_download_sqlite_file`` tries them in the order listed.
SPIDER_SQLITE_URLS = (
    "https://raw.githubusercontent.com/taoyds/spider/master/database/{db_id}/{db_id}.sqlite",
    "https://github.com/taoyds/spider/raw/master/database/{db_id}/{db_id}.sqlite",
)
# File id and download endpoint used by ``_download_spider_archive`` to fetch
# the Spider dataset zip.
SPIDER_DATASET_FILE_ID = "1403EGqzIDoHMdQF4c9Bkyl7dZLZ5Wt6J"
SPIDER_DATASET_DOWNLOAD_URL = "https://drive.usercontent.google.com/download"


# Leading bytes a file must start with to be accepted as a SQLite database.
SQLITE_MAGIC_HEADER = b"SQLite format 3\x00"
# Database ids are restricted to this alphabet before being used in paths/URLs.
DB_ID_PATTERN = re.compile(r"^[A-Za-z0-9_]+$")
# Captures the (optionally quoted, optionally schema-qualified) identifier that
# directly follows a FROM or JOIN keyword.
TABLE_TOKEN_PATTERN = re.compile(
    r"\b(?:FROM|JOIN)\s+([`\"\[]?[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?[`\"\]]?)",
    flags=re.IGNORECASE,
)
# Captures names introduced as ``WITH name AS (`` or ``, name AS (`` so they
# can be excluded from the extracted table list.
CTE_ALIAS_PATTERN = re.compile(
    r"(?:\bWITH\b|,)\s*([A-Za-z_][A-Za-z0-9_]*)\s+AS\s*\(",
    flags=re.IGNORECASE,
)


# Split labels written to the curated records.
TRAIN_SPLIT = "train"
EVAL_SPLIT = "eval"
VALID_SPLITS = {TRAIN_SPLIT, EVAL_SPLIT}
# Allowed values for the ``answer_type`` and ``difficulty`` record fields.
VALID_ANSWER_TYPES = {"integer", "float", "string", "list", "table"}
VALID_DIFFICULTIES = {"easy", "medium", "hard"}
# Keys every curated record must contain to pass ``validate_dataset``.
REQUIRED_FIELDS = (
    "question_id",
    "question_text",
    "database_name",
    "gold_sql",
    "gold_answer",
    "answer_type",
    "difficulty",
    "tables_involved",
    "split",
)


LOGGER = logging.getLogger(__name__)
# In-memory cache of the downloaded Spider zip; filled by
# ``_download_spider_archive`` on its first successful call.
_SPIDER_ARCHIVE_BYTES: bytes | None = None
|
|
|
|
| def _normalize_table_name(raw_table: str) -> str: |
| """Normalize a table token extracted from SQL text.""" |
| token = raw_table.strip().strip('`"[]') |
| if "." in token: |
| token = token.split(".", maxsplit=1)[1] |
| return token |
|
|
|
|
def _validate_db_id(db_id: str) -> None:
    """Reject any ``db_id`` that does not fully match ``DB_ID_PATTERN``.

    Raises:
        ValueError: If ``db_id`` contains anything outside ``[A-Za-z0-9_]``
            or is empty.
    """
    if DB_ID_PATTERN.fullmatch(db_id) is not None:
        return
    raise ValueError(f"Invalid db_id '{db_id}'. Expected [A-Za-z0-9_]+")
|
|
|
|
def _is_valid_sqlite_file(path: Path) -> bool:
    """Return True when ``path`` is a regular file starting with the SQLite header.

    Args:
        path: Candidate database file.

    Returns:
        ``False`` when the path is missing, is not a regular file (for example
        a directory), or is shorter than ``SQLITE_MAGIC_HEADER``; otherwise
        whether its leading bytes equal ``SQLITE_MAGIC_HEADER``.
    """
    header_size = len(SQLITE_MAGIC_HEADER)
    # ``is_file`` (rather than ``exists``) keeps a directory at this path from
    # reaching ``open``, which would raise instead of reporting "not valid".
    if not path.is_file() or path.stat().st_size < header_size:
        return False
    with path.open("rb") as handle:
        return handle.read(header_size) == SQLITE_MAGIC_HEADER
|
|
|
|
def _download_sqlite_file(db_id: str, destination: Path) -> None:
    """Fetch the SQLite file for one Spider database and store it at ``destination``.

    Every template in ``SPIDER_SQLITE_URLS`` is tried up to twice, pausing
    five seconds after the first failed attempt. The payload is staged in a
    ``.sqlite.tmp`` sibling file and only moved into place once it passes the
    SQLite header check. If no direct URL works, the file is extracted from
    the Spider dataset archive instead.

    Args:
        db_id: Spider database identifier.
        destination: Path to write ``{db_id}.sqlite``.

    Raises:
        ValueError: If ``db_id`` fails ``_validate_db_id``.
        FileNotFoundError: If all sources fail for this ``db_id``.
    """
    _validate_db_id(db_id)
    destination.parent.mkdir(parents=True, exist_ok=True)

    failure: str | None = None
    for template in SPIDER_SQLITE_URLS:
        source_url = template.format(db_id=db_id)
        for try_number in (1, 2):
            try:
                reply = requests.get(source_url, timeout=30)
                reply.raise_for_status()
                staging_path = destination.with_suffix(".sqlite.tmp")
                staging_path.write_bytes(reply.content)
                if _is_valid_sqlite_file(staging_path):
                    staging_path.replace(destination)
                    return
                # Drop the bogus payload and treat it like a failed attempt.
                staging_path.unlink(missing_ok=True)
                raise FileNotFoundError(
                    f"Downloaded payload for '{db_id}' was not a valid SQLite file"
                )
            except (requests.RequestException, OSError, FileNotFoundError) as exc:
                failure = str(exc)
                if try_number == 1:
                    time.sleep(5)

    # Direct URLs are exhausted; fall back to the full dataset archive.
    try:
        _extract_sqlite_from_archive(
            archive_bytes=_download_spider_archive(),
            db_id=db_id,
            destination=destination,
        )
    except (
        requests.RequestException,
        OSError,
        FileNotFoundError,
        zipfile.BadZipFile,
    ) as exc:
        failure = str(exc)
    else:
        return

    raise FileNotFoundError(
        f"Unable to download Spider SQLite for '{db_id}'. Last error: {failure}"
    )
|
|
|
|
def _download_spider_archive() -> bytes:
    """Download the Spider dataset zip once and cache its bytes in memory.

    The first request may return an HTML page instead of the file. In that
    case any ``confirm`` / ``uuid`` form values found in the page are sent to
    ``SPIDER_DATASET_DOWNLOAD_URL`` to obtain the actual payload. The download
    is attempted twice, with a five-second pause after the first failure.

    Returns:
        Raw bytes of the zip archive (also stored in ``_SPIDER_ARCHIVE_BYTES``
        so later calls return immediately).

    Raises:
        FileNotFoundError: If both attempts fail or the payload does not start
            with the zip signature ``PK``.
    """
    global _SPIDER_ARCHIVE_BYTES
    if _SPIDER_ARCHIVE_BYTES is not None:
        return _SPIDER_ARCHIVE_BYTES

    last_error: str | None = None
    for attempt in range(2):
        try:
            # Use the session as a context manager so its pooled connections
            # are released on success and on every failure path.
            with requests.Session() as session:
                warning_page = session.get(
                    f"https://drive.google.com/uc?export=download&id={SPIDER_DATASET_FILE_ID}",
                    timeout=60,
                )
                warning_page.raise_for_status()

                payload = warning_page.content
                content_type = warning_page.headers.get("content-type", "")
                if "text/html" in content_type.lower():
                    # An HTML response is an interstitial page; replay its
                    # hidden form fields against the download endpoint.
                    page_text = warning_page.text
                    params: dict[str, str] = {
                        "id": SPIDER_DATASET_FILE_ID,
                        "export": "download",
                    }
                    for field in ("confirm", "uuid"):
                        match = re.search(
                            rf'name="{field}" value="([^"]+)"',
                            page_text,
                        )
                        if match:
                            params[field] = match.group(1)

                    download_response = session.get(
                        SPIDER_DATASET_DOWNLOAD_URL,
                        params=params,
                        timeout=240,
                    )
                    download_response.raise_for_status()
                    payload = download_response.content

                if not payload.startswith(b"PK"):
                    raise FileNotFoundError(
                        "Spider dataset download did not return a zip file"
                    )

                _SPIDER_ARCHIVE_BYTES = payload
                return _SPIDER_ARCHIVE_BYTES
        except (requests.RequestException, FileNotFoundError) as exc:
            last_error = str(exc)
            if attempt == 0:
                time.sleep(5)

    raise FileNotFoundError(
        f"Unable to download Spider dataset zip. Last error: {last_error}"
    )
|
|
|
|
| def _extract_sqlite_from_archive( |
| archive_bytes: bytes, db_id: str, destination: Path |
| ) -> None: |
| """Extract one SQLite file from the Spider zip archive.""" |
| candidate_members = ( |
| f"spider_data/database/{db_id}/{db_id}.sqlite", |
| f"spider/database/{db_id}/{db_id}.sqlite", |
| f"spider-master/database/{db_id}/{db_id}.sqlite", |
| ) |
|
|
| payload: bytes | None = None |
| with zipfile.ZipFile(io.BytesIO(archive_bytes)) as archive: |
| for member_name in candidate_members: |
| try: |
| payload = archive.read(member_name) |
| break |
| except KeyError: |
| continue |
|
|
| if payload is None: |
| raise FileNotFoundError(f"Database '{db_id}' not found in Spider archive") |
|
|
| tmp_path = destination.with_suffix(".sqlite.tmp") |
| tmp_path.write_bytes(payload) |
| if not _is_valid_sqlite_file(tmp_path): |
| tmp_path.unlink(missing_ok=True) |
| raise FileNotFoundError( |
| f"Archive payload for '{db_id}' was not a valid SQLite file" |
| ) |
| tmp_path.replace(destination) |
|
|
|
|
def download_spider_databases(db_ids: list[str], output_dir: Path) -> dict[str, Path]:
    """Ensure a local SQLite file exists for each requested Spider database.

    A file that already passes the SQLite header check is reused as-is.
    Databases that cannot be downloaded are logged and left out of the result.

    Args:
        db_ids: Spider database IDs.
        output_dir: Base output directory (e.g. ``data/databases``).

    Returns:
        Mapping of ``db_id`` to ``output_dir/{db_id}/{db_id}.sqlite`` for every
        database that is available locally.

    Raises:
        ValueError: If a ``db_id`` is invalid or its resolved path falls
            outside ``output_dir``.
        FileNotFoundError: If no requested database can be prepared.
    """
    root = output_dir.resolve()
    prepared: dict[str, Path] = {}

    for db_id in db_ids:
        _validate_db_id(db_id)
        target = output_dir / db_id / f"{db_id}.sqlite"
        resolved_target = target.resolve()
        if root not in resolved_target.parents:
            raise ValueError(
                "Resolved path "
                f"'{resolved_target}' escapes output directory '{root}'"
            )

        if not _is_valid_sqlite_file(target):
            try:
                _download_sqlite_file(db_id=db_id, destination=target)
            except FileNotFoundError as exc:
                LOGGER.warning("Skipping database '%s': %s", db_id, exc)
                continue
        prepared[db_id] = target

    if prepared:
        return prepared
    raise FileNotFoundError("No Spider SQLite databases could be prepared")
|
|
|
|
def _load_questions_from_hf_datasets(db_ids: set[str]) -> list[dict[str, Any]]:
    """Load matching Spider questions via the optional ``datasets`` package.

    Both the ``train`` and ``validation`` splits of ``xlangai/spider`` are
    scanned, and rows whose ``db_id`` is in ``db_ids`` are kept.

    Raises:
        ConnectionError: If ``datasets`` cannot be imported.
    """
    try:
        from datasets import load_dataset
    except ImportError as exc:
        raise ConnectionError("`datasets` package is not installed") from exc

    collected: list[dict[str, Any]] = []
    for split_name in ("train", "validation"):
        dataset = load_dataset("xlangai/spider", split=split_name)
        for row in dataset:
            db_id = row.get("db_id")
            if db_id in db_ids:
                collected.append(
                    {
                        "db_id": db_id,
                        "query": row.get("query", ""),
                        "question": row.get("question", ""),
                        "spider_split": split_name,
                    }
                )
    return collected
|
|
|
|
def _load_questions_from_spider_archive(db_ids: set[str]) -> list[dict[str, Any]]:
    """Load Spider questions from the official dataset zip archive.

    Reads ``spider_data/train_spider.json`` (labelled ``train``) and
    ``spider_data/dev.json`` (labelled ``validation``). A split file that is
    missing from the archive, or whose JSON payload is not a list, is skipped.

    Args:
        db_ids: Database identifiers whose questions should be kept.

    Returns:
        One record per matching question with ``db_id``, ``query``,
        ``question`` and ``spider_split`` keys.

    Raises:
        ConnectionError: If the archive is not a readable zip, a split file is
            corrupt or not valid UTF-8 JSON, or no question matches ``db_ids``.
            Errors raised by ``_download_spider_archive`` propagate unchanged.
    """
    archive_bytes = _download_spider_archive()
    records: list[dict[str, Any]] = []

    split_files = (
        ("spider_data/train_spider.json", "train"),
        ("spider_data/dev.json", "validation"),
    )

    # Archive corruption is reported as ConnectionError, the failure type this
    # loader already uses, so ``load_spider_questions`` can move on to its
    # other loaders instead of crashing on an uncaught BadZipFile/ValueError.
    try:
        archive = zipfile.ZipFile(io.BytesIO(archive_bytes))
    except zipfile.BadZipFile as exc:
        raise ConnectionError(
            f"Spider archive is not a readable zip file: {exc}"
        ) from exc

    with archive:
        for member_name, spider_split in split_files:
            try:
                payload = archive.read(member_name)
                rows = json.loads(payload.decode("utf-8"))
            except KeyError:
                # The split file is not part of this archive; skip it.
                continue
            except (zipfile.BadZipFile, ValueError) as exc:
                # ValueError covers both UnicodeDecodeError and JSONDecodeError.
                raise ConnectionError(
                    f"Spider archive member '{member_name}' is corrupt: {exc}"
                ) from exc

            if not isinstance(rows, list):
                continue

            for row in rows:
                if not isinstance(row, dict):
                    continue
                db_id = row.get("db_id")
                if db_id not in db_ids:
                    continue
                records.append(
                    {
                        "db_id": db_id,
                        "query": row.get("query", ""),
                        "question": row.get("question", ""),
                        "spider_split": spider_split,
                    }
                )

    if not records:
        raise ConnectionError(
            "No Spider questions found in archive for selected db_ids"
        )

    return records
|
|
|
|
def _load_questions_from_hf_rows_api(db_ids: set[str]) -> list[dict[str, Any]]:
    """Page through the HuggingFace datasets-server rows API for Spider.

    The ``train`` and ``validation`` splits are fetched 100 rows at a time
    until a page comes back empty. Rows whose ``db_id`` is in ``db_ids`` are
    kept.
    """
    endpoint = "https://datasets-server.huggingface.co/rows"
    page_size = 100
    collected: list[dict[str, Any]] = []

    for split_name in ("train", "validation"):
        offset = 0
        while True:
            response = requests.get(
                endpoint,
                params={
                    "dataset": "xlangai/spider",
                    "config": "spider",
                    "split": split_name,
                    "offset": offset,
                    "length": page_size,
                },
                timeout=30,
            )
            response.raise_for_status()
            page = response.json().get("rows", [])
            if not page:
                break

            for entry in page:
                row = entry.get("row", {})
                db_id = row.get("db_id")
                if db_id in db_ids:
                    collected.append(
                        {
                            "db_id": db_id,
                            "query": row.get("query", ""),
                            "question": row.get("question", ""),
                            "spider_split": split_name,
                        }
                    )
            # Advance by the number of rows actually returned.
            offset += len(page)

    return collected
|
|
|
|
def load_spider_questions(db_ids: list[str]) -> list[dict[str, Any]]:
    """Fetch raw Spider questions for the selected databases.

    Three loaders are tried in order: the Spider zip archive, the ``datasets``
    package, then the HuggingFace rows API. Each loader gets two attempts,
    with a five-second pause after its first failure.

    Args:
        db_ids: Spider database IDs.

    Returns:
        Filtered list of question records including ``spider_split`` metadata;
        an empty list when ``db_ids`` is empty.

    Raises:
        ValueError: If any ``db_id`` fails ``_validate_db_id``.
        ConnectionError: If all loading strategies fail.
    """
    if not db_ids:
        return []

    wanted = set(db_ids)
    for candidate in wanted:
        _validate_db_id(candidate)

    strategies: tuple[Callable[[set[str]], list[dict[str, Any]]], ...] = (
        _load_questions_from_spider_archive,
        _load_questions_from_hf_datasets,
        _load_questions_from_hf_rows_api,
    )

    failure: str | None = None
    for strategy in strategies:
        for try_number in (1, 2):
            try:
                return strategy(wanted)
            except (ConnectionError, OSError, requests.RequestException) as exc:
                failure = f"{strategy.__name__}: {exc}"
                if try_number == 1:
                    time.sleep(5)

    raise ConnectionError(
        f"Unable to load Spider questions from HuggingFace. Last error: {failure}"
    )
|
|
|
|
| def _shape_rows(rows: list[tuple[Any, ...]]) -> Any: |
| """Shape SQL rows into scalar/list/table forms used by the dataset.""" |
| if not rows: |
| return [] |
|
|
| column_count = len(rows[0]) |
| if column_count == 1: |
| values = [row[0] for row in rows] |
| if len(values) == 1: |
| return values[0] |
| return values |
|
|
| return [list(row) for row in rows] |
|
|
|
|
def compute_gold_answer(gold_sql: str, db_path: Path) -> Any:
    """Execute gold SQL read-only against a SQLite file and shape the result.

    Args:
        gold_sql: SQL statement to execute.
        db_path: Path to the SQLite database file.

    Returns:
        The fetched rows after ``_shape_rows`` has normalized them.

    Raises:
        FileNotFoundError: If ``db_path`` does not exist.
        sqlite3.Error: If the file fails the SQLite header check, or if
            opening the database or executing the statement fails.
    """
    if not db_path.exists():
        raise FileNotFoundError(f"Database not found: {db_path}")
    if not _is_valid_sqlite_file(db_path):
        raise sqlite3.Error(f"Invalid SQLite database file: {db_path}")

    # ``mode=ro`` opens the database read-only so gold SQL cannot modify it.
    db_uri = f"file:{quote(str(db_path.resolve()))}?mode=ro"
    conn = sqlite3.connect(db_uri, uri=True)
    try:
        rows = conn.execute(gold_sql).fetchall()
    finally:
        # ``with sqlite3.connect(...)`` only manages the transaction; the
        # connection has to be closed explicitly or one handle leaks per call.
        conn.close()
    return _shape_rows(rows)
|
|
|
|
def classify_answer_type(gold_answer: Any) -> str:
    """Map a computed gold answer to its answer-type label.

    Scalars map to ``"integer"`` (including ``bool``), ``"float"`` or
    ``"string"``. A one-element tuple is classified by its element; any other
    tuple is a ``"table"``. A list whose first item is a list or tuple is a
    ``"table"``; every other list, and ``None``, is a ``"list"``.

    Raises:
        ValueError: If the value is of any other type.
    """
    # ``bool`` is a subclass of ``int``, so the ``int`` entry covers it too.
    scalar_labels = ((int, "integer"), (float, "float"), (str, "string"))
    for scalar_type, label in scalar_labels:
        if isinstance(gold_answer, scalar_type):
            return label

    if isinstance(gold_answer, tuple):
        if len(gold_answer) != 1:
            return "table"
        return classify_answer_type(gold_answer[0])

    if isinstance(gold_answer, list):
        holds_rows = bool(gold_answer) and isinstance(gold_answer[0], (list, tuple))
        return "table" if holds_rows else "list"

    if gold_answer is None:
        return "list"

    raise ValueError(f"Unsupported gold_answer type: {type(gold_answer).__name__}")
|
|
|
|
def extract_tables_involved(gold_sql: str) -> list[str]:
    """Return the sorted, de-duplicated table names following FROM/JOIN.

    Names that the query itself introduces as CTE aliases are excluded; that
    comparison is case-insensitive. Blank SQL yields an empty list.
    """
    if not gold_sql.strip():
        return []

    cte_names = {
        found.group(1).lower() for found in CTE_ALIAS_PATTERN.finditer(gold_sql)
    }
    candidates = (
        _normalize_table_name(found.group(1))
        for found in TABLE_TOKEN_PATTERN.finditer(gold_sql)
    )
    return sorted(
        {name for name in candidates if name and name.lower() not in cte_names}
    )
|
|
|
|
def classify_difficulty(tables_involved: Iterable[str]) -> str:
    """Derive a difficulty label from the number of distinct tables.

    Empty names are ignored. Up to two tables is ``"easy"``, exactly three is
    ``"medium"``, and more than three is ``"hard"``.
    """
    distinct_tables = {name for name in tables_involved if name}
    if len(distinct_tables) > 3:
        return "hard"
    if len(distinct_tables) == 3:
        return "medium"
    return "easy"
|
|
|
|
| def _load_db_list(db_list_path: Path) -> list[str]: |
| """Load database IDs from a JSON array file.""" |
| payload = json.loads(db_list_path.read_text(encoding="utf-8")) |
| if not isinstance(payload, list) or not all( |
| isinstance(item, str) for item in payload |
| ): |
| raise ValueError(f"Expected JSON list[str] in {db_list_path}") |
| return payload |
|
|
|
|
def assign_splits(questions: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Label each question as train or eval, then rebalance towards 30% eval.

    The initial label comes from ``spider_split``: ``validation``/``eval``
    become eval, ``train`` becomes train, and anything else is logged and
    treated as train. When both splits are populated, records are moved from
    the front of the larger-than-target side until eval holds
    ``max(1, round(total * 0.3))`` records. Input dicts are copied, never
    mutated.
    """
    labelled: list[dict[str, Any]] = []
    for question in questions:
        origin = str(question.get("spider_split", "")).lower()
        if origin in {"validation", EVAL_SPLIT}:
            assigned = EVAL_SPLIT
        else:
            if origin not in {"train", TRAIN_SPLIT}:
                LOGGER.warning(
                    "Unknown spider_split '%s' for database '%s'; defaulting to train",
                    origin,
                    question.get("database_name", "unknown"),
                )
            assigned = TRAIN_SPLIT
        labelled.append({**question, "split": assigned})

    total = len(labelled)
    if total <= 1:
        return labelled

    train_pool = [record for record in labelled if record["split"] == TRAIN_SPLIT]
    eval_pool = [record for record in labelled if record["split"] == EVAL_SPLIT]
    if not (train_pool and eval_pool):
        # With only one split present there is nothing to rebalance against.
        return labelled

    target_eval_count = max(1, round(total * 0.3))
    surplus = len(eval_pool) - target_eval_count

    if surplus > 0:
        # Too many eval records: hand the first ``surplus`` back to train.
        for record in eval_pool[:surplus]:
            record["split"] = TRAIN_SPLIT
    elif surplus < 0:
        # Too few eval records: promote the first train records instead.
        for record in train_pool[:-surplus]:
            record["split"] = EVAL_SPLIT

    return labelled
|
|
|
|
| def _sort_enriched_questions( |
| questions: list[dict[str, Any]], |
| ) -> list[dict[str, Any]]: |
| """Return deterministically ordered records for stable output files.""" |
| return sorted( |
| questions, |
| key=lambda item: ( |
| str(item.get("database_name", "")), |
| str(item.get("spider_split", "")), |
| str(item.get("gold_sql", "")), |
| str(item.get("question_text", "")), |
| ), |
| ) |
|
|
|
|
| def _assign_question_ids(questions: list[dict[str, Any]]) -> list[dict[str, Any]]: |
| """Assign IDs with format ``{db_id}_{split}_{index:03d}`` per db/split.""" |
| counters: dict[tuple[str, str], int] = {} |
| with_ids: list[dict[str, Any]] = [] |
|
|
| for question in questions: |
| db_id = str(question["database_name"]) |
| split = str(question["split"]) |
| key = (db_id, split) |
| index = counters.get(key, 0) |
| counters[key] = index + 1 |
|
|
| updated = dict(question) |
| updated["question_id"] = f"{db_id}_{split}_{index:03d}" |
| with_ids.append(updated) |
|
|
| return with_ids |
|
|
|
|
| def _write_output(path: Path, records: list[dict[str, Any]]) -> None: |
| """Write JSON records to disk.""" |
| path.parent.mkdir(parents=True, exist_ok=True) |
| path.write_text(json.dumps(records, indent=2, ensure_ascii=False), encoding="utf-8") |
|
|
|
|
| def _load_output_questions(path: Path) -> list[dict[str, Any]]: |
| """Load curated output records from a JSON file.""" |
| try: |
| payload = json.loads(path.read_text(encoding="utf-8")) |
| except FileNotFoundError as exc: |
| raise ValueError(f"Output dataset file not found: {path}") from exc |
| except json.JSONDecodeError as exc: |
| raise ValueError(f"Output dataset file is invalid JSON: {path}") from exc |
|
|
| if not isinstance(payload, list): |
| raise ValueError(f"Expected JSON list in {path}") |
| records: list[dict[str, Any]] = [] |
| for index, item in enumerate(payload): |
| if not isinstance(item, dict): |
| raise ValueError(f"Expected record object at index {index} in {path}") |
| records.append(item) |
| return records |
|
|
|
|
| def _question_fingerprint(record: dict[str, Any]) -> tuple[str, str, str]: |
| """Build a stable identity tuple for split leakage checks.""" |
| return ( |
| str(record.get("database_name", "")), |
| str(record.get("question_text", "")), |
| str(record.get("gold_sql", "")), |
| ) |
|
|
|
|
def validate_dataset(
    questions: list[dict[str, Any]],
    db_paths: dict[str, Path],
) -> list[str]:
    """Check curated records and collect a message for every problem found.

    Per record this verifies that all ``REQUIRED_FIELDS`` are present, that
    the text fields are non-empty, that ``answer_type``, ``difficulty`` and
    ``split`` hold allowed values, that ``tables_involved`` is a non-empty
    list of non-empty strings, that ``question_id`` is unique, and that
    re-running ``gold_sql`` reproduces ``gold_answer``. Across records it
    reports questions present in both splits. Difficulty ratios far from the
    40/40/20 target are only logged as warnings.

    Args:
        questions: Curated records to check.
        db_paths: Mapping of database name to its SQLite file.

    Returns:
        Error messages in the order they were detected; empty when valid.
    """
    problems: list[str] = []
    seen_ids: set[str] = set()
    fingerprints: dict[str, set[tuple[str, str, str]]] = {
        TRAIN_SPLIT: set(),
        EVAL_SPLIT: set(),
    }
    per_difficulty: dict[str, int] = dict.fromkeys(VALID_DIFFICULTIES, 0)

    for position, record in enumerate(questions):
        label = f"record[{position}]"

        absent = [name for name in REQUIRED_FIELDS if name not in record]
        if absent:
            # Nothing else can be checked reliably without the required keys.
            problems.append(f"{label}: missing required fields: {', '.join(absent)}")
            continue

        question_id = str(record["question_id"]).strip()
        if not question_id:
            problems.append(f"{label}: question_id must be non-empty")
        elif question_id in seen_ids:
            problems.append(f"{label}: duplicate question_id '{question_id}'")
        else:
            seen_ids.add(question_id)

        if not str(record["question_text"]).strip():
            problems.append(f"{label}: question_text must be non-empty")

        db_id = str(record["database_name"]).strip()
        if not db_id:
            problems.append(f"{label}: database_name must be non-empty")
            continue

        gold_sql = str(record["gold_sql"]).strip()
        if not gold_sql:
            problems.append(f"{label}: gold_sql must be non-empty")

        answer_type = str(record["answer_type"]).strip()
        if answer_type not in VALID_ANSWER_TYPES:
            problems.append(
                f"{label}: answer_type '{answer_type}' is invalid "
                f"(expected one of {sorted(VALID_ANSWER_TYPES)})"
            )

        difficulty = str(record["difficulty"]).strip()
        if difficulty in VALID_DIFFICULTIES:
            per_difficulty[difficulty] += 1
        else:
            problems.append(
                f"{label}: difficulty '{difficulty}' is invalid "
                f"(expected one of {sorted(VALID_DIFFICULTIES)})"
            )

        tables = record["tables_involved"]
        if not (isinstance(tables, list) and tables):
            problems.append(f"{label}: tables_involved must be a non-empty list")
        elif any(not (isinstance(name, str) and name.strip()) for name in tables):
            problems.append(
                f"{label}: tables_involved must contain non-empty table name strings"
            )

        split = str(record["split"]).strip()
        if split in VALID_SPLITS:
            fingerprints[split].add(_question_fingerprint(record))
        else:
            problems.append(
                f"{label}: split '{split}' is invalid "
                f"(expected one of {sorted(VALID_SPLITS)})"
            )

        if db_id not in db_paths:
            problems.append(
                f"{label}: missing database path"
                f" for '{db_id}' (expected in data/databases)"
            )
        elif gold_sql:
            try:
                recomputed = compute_gold_answer(
                    gold_sql=gold_sql, db_path=db_paths[db_id]
                )
            except (sqlite3.Error, FileNotFoundError) as exc:
                problems.append(
                    f"{label}: gold_sql execution failed"
                    f" for database '{db_id}': {exc}"
                )
            else:
                if recomputed != record["gold_answer"]:
                    problems.append(
                        f"{label}: gold_answer mismatch"
                        f" for question_id '{question_id}'"
                    )

    overlap = fingerprints[TRAIN_SPLIT] & fingerprints[EVAL_SPLIT]
    if overlap:
        problems.append(
            f"train/eval split leak detected:"
            f" {len(overlap)} question(s) appear in both splits"
        )

    total = len(questions)
    if total > 0:
        # (difficulty, target share, allowed deviation, warning template)
        distribution_checks = (
            (
                "easy",
                0.40,
                0.20,
                "Difficulty distribution off target: easy=%s (target 40%%)",
            ),
            (
                "medium",
                0.40,
                0.20,
                "Difficulty distribution off target: medium=%s (target 40%%)",
            ),
            (
                "hard",
                0.20,
                0.15,
                "Difficulty distribution off target: hard=%s (target 20%%)",
            ),
        )
        for name, target, tolerance, template in distribution_checks:
            ratio = per_difficulty[name] / total
            if abs(ratio - target) > tolerance:
                LOGGER.warning(template, f"{ratio:.2%}")

    return problems
|
|
|
|
def _build_cli_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the curation command line."""
    parser = argparse.ArgumentParser(
        description="Curate Spider questions into enriched train/eval JSON files."
    )
    parser.add_argument(
        "--db-list",
        type=Path,
        default=Path("data/questions/db_list.json"),
        help="Path to JSON list of Spider database IDs.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("data/databases"),
        help="Directory where SQLite files will be stored.",
    )
    parser.add_argument(
        "--validate",
        action="store_true",
        help="Validate existing output files instead of running full curation.",
    )
    parser.add_argument(
        "--train-output",
        type=Path,
        default=Path("data/questions/questions_train.json"),
        help="Output path for curated train questions.",
    )
    parser.add_argument(
        "--eval-output",
        type=Path,
        default=Path("data/questions/questions_eval.json"),
        help="Output path for curated eval questions.",
    )
    return parser


def _run_validation_mode(args: argparse.Namespace) -> None:
    """Validate existing output files and exit with status 0 or 1.

    This function never returns normally; it always raises ``SystemExit``.
    """
    try:
        train_questions = _load_output_questions(args.train_output)
        eval_questions = _load_output_questions(args.eval_output)
    except ValueError as exc:
        print(f"ERROR: {exc}")
        raise SystemExit(1) from exc

    questions = train_questions + eval_questions
    db_ids = sorted(
        {str(record.get("database_name", "")).strip() for record in questions}
    )
    try:
        for db_id in db_ids:
            _validate_db_id(db_id)
    except ValueError as exc:
        print(f"ERROR: {exc}")
        raise SystemExit(1) from exc

    # NOTE: empty ids are already rejected by ``_validate_db_id`` above, so
    # the ``if db_id`` filter is purely defensive.
    db_paths = {
        db_id: args.output_dir / db_id / f"{db_id}.sqlite"
        for db_id in db_ids
        if db_id
    }
    errors = validate_dataset(questions=questions, db_paths=db_paths)
    if errors:
        for error in errors:
            print(f"ERROR: {error}")
        raise SystemExit(1)

    print(f"Validation passed for {len(questions)} curated records")
    raise SystemExit(0)


def _enrich_raw_questions(
    raw_questions: list[dict[str, Any]],
    db_paths: dict[str, Path],
) -> tuple[list[dict[str, Any]], int]:
    """Turn raw Spider rows into enriched records.

    Rows are skipped when their database is unavailable, when the SQL or
    question text is empty, when the SQL fails to execute, or when no tables
    can be extracted from the SQL.

    Returns:
        ``(enriched_records, skipped_count)``.
    """
    enriched: list[dict[str, Any]] = []
    skipped = 0
    for raw_question in raw_questions:
        db_id = str(raw_question.get("db_id", "")).strip()
        if db_id not in db_paths:
            skipped += 1
            continue

        gold_sql = str(raw_question.get("query", "")).strip()
        question_text = str(raw_question.get("question", "")).strip()
        if not (gold_sql and question_text):
            skipped += 1
            continue

        try:
            gold_answer = compute_gold_answer(
                gold_sql=gold_sql,
                db_path=db_paths[db_id],
            )
        except sqlite3.Error as exc:
            LOGGER.warning(
                "Skipping question for database '%s' due to SQL execution failure: %s",
                db_id,
                exc,
            )
            skipped += 1
            continue

        tables_involved = extract_tables_involved(gold_sql)
        if not tables_involved:
            LOGGER.warning(
                "Skipping question for database '%s' because no tables were extracted",
                db_id,
            )
            skipped += 1
            continue

        enriched.append(
            {
                "question_text": question_text,
                "database_name": db_id,
                "gold_sql": gold_sql,
                "gold_answer": gold_answer,
                "answer_type": classify_answer_type(gold_answer),
                "difficulty": classify_difficulty(tables_involved),
                "tables_involved": tables_involved,
                "spider_split": raw_question.get("spider_split", "train"),
            }
        )
    return enriched, skipped


def main() -> None:
    """CLI entry point for the dataset curation pipeline.

    With ``--validate`` the existing output files are checked and the process
    exits. Otherwise databases are downloaded, questions are loaded, enriched,
    split, numbered and validated, and the train/eval JSON files are written.
    Validation errors are printed and end the process with exit status 1.
    """
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
    args = _build_cli_parser().parse_args()

    if args.validate:
        _run_validation_mode(args)

    db_ids = _load_db_list(args.db_list)
    db_paths = download_spider_databases(db_ids=db_ids, output_dir=args.output_dir)
    raw_questions = load_spider_questions(db_ids)

    enriched_questions, skipped_count = _enrich_raw_questions(raw_questions, db_paths)

    ordered_questions = _sort_enriched_questions(enriched_questions)
    final_questions = _assign_question_ids(assign_splits(ordered_questions))

    validation_errors = validate_dataset(questions=final_questions, db_paths=db_paths)
    if validation_errors:
        for error in validation_errors:
            print(f"ERROR: {error}")
        raise SystemExit(1)

    train_questions: list[dict[str, Any]] = []
    eval_questions: list[dict[str, Any]] = []
    for record in final_questions:
        # ``spider_split`` is working metadata and is not part of the output.
        output_record = {
            key: value for key, value in record.items() if key != "spider_split"
        }
        bucket = (
            train_questions
            if output_record["split"] == TRAIN_SPLIT
            else eval_questions
        )
        bucket.append(output_record)

    _write_output(args.train_output, train_questions)
    _write_output(args.eval_output, eval_questions)

    print(f"Prepared {len(db_paths)} databases in {args.output_dir}")
    print(f"Loaded {len(raw_questions)} Spider questions")
    print(f"Curated {len(final_questions)} questions (skipped {skipped_count})")
    print("Validation passed")
    print(f"Wrote {len(train_questions)} train records to {args.train_output}")
    print(f"Wrote {len(eval_questions)} eval records to {args.eval_output}")
|
|
|
|
# Run the curation CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|