| """Download Spider SQLite databases used by SQLEnv. |
| |
| Uses the same download logic as curate_questions.py: tries GitHub raw URLs |
| first, then falls back to the official Google Drive Spider archive. |
| |
| Examples |
| -------- |
| Download the default database (student_assessment): |
| uv run python scripts/download_spider_databases.py |
| |
| Download a specific database: |
| uv run python scripts/download_spider_databases.py --db-id concert_singer |
| |
| Download all databases referenced in db_list.json: |
| uv run python scripts/download_spider_databases.py --db-id all |
| |
| Force re-download: |
| uv run python scripts/download_spider_databases.py --force |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import io |
| import json |
| import re |
| import time |
| import zipfile |
| from pathlib import Path |
| from urllib.error import HTTPError, URLError |
| from urllib.request import Request, urlopen |
|
|
# GitHub raw-content URL templates tried in order before falling back to Drive.
SPIDER_RAW_SQLITE_URLS = (
    "https://raw.githubusercontent.com/taoyds/spider/master/database/{db_id}/{db_id}.sqlite",
    "https://github.com/taoyds/spider/raw/master/database/{db_id}/{db_id}.sqlite",
)
# Google Drive file ID of the official Spider archive zip.
SPIDER_ARCHIVE_DRIVE_ID = "1403EGqzIDoHMdQF4c9Bkyl7dZLZ5Wt6J"
# First 16 bytes of every SQLite 3 database file (the header magic string).
SQLITE_MAGIC = b"SQLite format 3\x00"
# JSON list of db_ids; produced by curate_questions.py (per download_listed's error hint).
DB_LIST_PATH = Path("data/questions/db_list.json")
|
|
|
|
| def _validate_db_id(db_id: str) -> str: |
| normalized = db_id.strip() |
| if not normalized: |
| raise ValueError("db_id cannot be empty") |
| if not re.fullmatch(r"[A-Za-z0-9_]+", normalized): |
| raise ValueError( |
| "Invalid db_id — only letters, numbers, and underscores allowed." |
| ) |
| return normalized |
|
|
|
|
def _is_valid_sqlite(path: Path) -> bool:
    """Return True when *path* exists and begins with the SQLite header magic."""
    if not path.exists():
        return False
    if path.stat().st_size < 16:
        return False
    with path.open("rb") as handle:
        header = handle.read(16)
    return header == SQLITE_MAGIC
|
|
|
|
| def _safe_sqlite_path(output_dir: Path, db_id: str) -> Path: |
| sqlite_path = output_dir / db_id / f"{db_id}.sqlite" |
| output_root = output_dir.resolve() |
| resolved = sqlite_path.resolve() |
| if output_root not in resolved.parents: |
| raise ValueError(f"Resolved path escapes output directory: {resolved}") |
| return sqlite_path |
|
|
|
|
def _try_raw_download(db_id: str, destination: Path) -> bool:
    """Attempt each GitHub raw URL in turn; return True once one succeeds.

    A response that is not a real SQLite file, or any network/filesystem
    error, simply advances to the next candidate URL. The file is written
    to a .tmp sibling first and atomically renamed into place.
    """
    for template in SPIDER_RAW_SQLITE_URLS:
        candidate_url = template.format(db_id=db_id)
        try:
            request = Request(candidate_url, headers={"User-Agent": "sqlenv/1.0"})
            with urlopen(request, timeout=30) as response:
                payload = response.read()
            if not payload.startswith(SQLITE_MAGIC):
                continue
            destination.parent.mkdir(parents=True, exist_ok=True)
            staging = destination.with_suffix(".tmp")
            staging.write_bytes(payload)
            staging.replace(destination)
            return True
        except (HTTPError, URLError, OSError):
            continue
    return False
|
|
|
|
def _download_drive_archive() -> bytes:
    """Download the official Spider zip archive from Google Drive.

    Returns the raw zip bytes. Retries once (after a 3-second pause) on
    any network error or non-zip response; if both attempts fail, raises
    RuntimeError chained to the last underlying failure so the real cause
    is visible in the traceback.
    """
    drive_url = (
        f"https://drive.google.com/uc?export=download&id={SPIDER_ARCHIVE_DRIVE_ID}"
    )
    req = Request(drive_url, headers={"User-Agent": "sqlenv/1.0"})

    last_error: Exception | None = None
    for attempt in range(2):
        try:
            with urlopen(req, timeout=120) as resp:
                payload = resp.read()

            # A direct zip response starts with the "PK" local-file magic.
            if payload.startswith(b"PK"):
                return payload

            # Large files get an HTML interstitial containing a confirm
            # token; extract it and retry via the usercontent endpoint.
            text = payload.decode("utf-8", errors="replace")
            confirm_match = re.search(r'name="confirm" value="([^"]+)"', text)
            if confirm_match:
                confirm_url = (
                    "https://drive.usercontent.google.com/download"
                    f"?id={SPIDER_ARCHIVE_DRIVE_ID}"
                    f"&export=download&confirm={confirm_match.group(1)}"
                )
                confirm_req = Request(
                    confirm_url,
                    headers={"User-Agent": "sqlenv/1.0"},
                )
                with urlopen(confirm_req, timeout=240) as resp2:
                    payload = resp2.read()
                if payload.startswith(b"PK"):
                    return payload

            raise RuntimeError("Drive response was not a zip file")
        except (HTTPError, URLError, OSError, RuntimeError) as exc:
            # Remember why this attempt failed so the final error can chain it.
            last_error = exc
            if attempt == 0:
                time.sleep(3)

    raise RuntimeError(
        "Failed to download Spider archive from Google Drive after retries"
    ) from last_error
|
|
|
|
def _extract_from_archive(archive_bytes: bytes, db_id: str, destination: Path) -> None:
    """Pull one database out of the Spider zip and write it atomically.

    Tries the member paths used by the known archive layouts in order;
    raises FileNotFoundError when none contain a valid SQLite file.
    """
    member_paths = (
        f"spider_data/database/{db_id}/{db_id}.sqlite",
        f"spider/database/{db_id}/{db_id}.sqlite",
        f"spider-master/database/{db_id}/{db_id}.sqlite",
    )
    with zipfile.ZipFile(io.BytesIO(archive_bytes)) as archive:
        for member_path in member_paths:
            try:
                payload = archive.read(member_path)
            except KeyError:
                # Member not present under this layout; try the next one.
                continue
            if not payload.startswith(SQLITE_MAGIC):
                continue
            destination.parent.mkdir(parents=True, exist_ok=True)
            staging = destination.with_suffix(".tmp")
            staging.write_bytes(payload)
            staging.replace(destination)
            return
    raise FileNotFoundError(f"Database '{db_id}' not found in Spider archive")
|
|
|
|
def _extract_all_from_archive(
    archive_bytes: bytes, output_dir: Path, force: bool
) -> int:
    """Write every ``*/database/*.sqlite`` member of the archive under output_dir.

    Existing targets are skipped unless *force* is set; members lacking the
    SQLite header magic are ignored. Returns how many files were written.
    """
    extracted = 0
    with zipfile.ZipFile(io.BytesIO(archive_bytes)) as archive:
        sqlite_members = [
            name
            for name in archive.namelist()
            if name.endswith(".sqlite") and "/database/" in name
        ]
        for member in sqlite_members:
            # Target path is derived from the file stem only, so hostile
            # member paths cannot escape output_dir.
            db_name = Path(member).stem
            target = output_dir / db_name / f"{db_name}.sqlite"
            if target.exists() and not force:
                continue
            payload = archive.read(member)
            if not payload.startswith(SQLITE_MAGIC):
                continue
            target.parent.mkdir(parents=True, exist_ok=True)
            staging = target.with_suffix(".tmp")
            staging.write_bytes(payload)
            staging.replace(target)
            extracted += 1
    return extracted
|
|
|
|
def download_database(db_id: str, output_dir: Path, force: bool = False) -> Path:
    """Fetch a single Spider database, preferring GitHub with a Drive fallback.

    Returns the path to the local .sqlite file; skips the download entirely
    when a valid copy already exists (unless *force*).
    """
    normalized = _validate_db_id(db_id)
    sqlite_path = _safe_sqlite_path(output_dir, normalized)

    if not force and _is_valid_sqlite(sqlite_path):
        print(f"Already exists: {sqlite_path}")
        return sqlite_path

    print(f"Downloading {normalized}...")

    if _try_raw_download(normalized, sqlite_path):
        print(f" -> {sqlite_path} (from GitHub)")
        return sqlite_path

    # GitHub miss: fall back to the full archive and extract just this DB.
    print(" GitHub raw URLs failed, trying Google Drive archive...")
    _extract_from_archive(_download_drive_archive(), normalized, sqlite_path)
    print(f" -> {sqlite_path} (from Drive archive)")
    return sqlite_path
|
|
|
|
def download_all(output_dir: Path, force: bool = False) -> int:
    """Pull the full Spider archive and extract every database it contains.

    Returns the number of databases written to *output_dir*.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    print("Downloading Spider archive from Google Drive...")
    extracted = _extract_all_from_archive(_download_drive_archive(), output_dir, force)
    print(f"Extracted {extracted} database(s) to {output_dir}")
    return extracted
|
|
|
|
def download_listed(output_dir: Path, force: bool = False) -> int:
    """Download every database named in db_list.json.

    Tries GitHub raw URLs per database first, then falls back to a single
    Google Drive archive download for any that failed. Returns the number
    of databases valid on disk afterwards.

    Raises FileNotFoundError when db_list.json does not exist.
    """
    if not DB_LIST_PATH.exists():
        raise FileNotFoundError(
            f"{DB_LIST_PATH} not found — run curate_questions.py first "
            "or use --db-id <name> to download individual databases"
        )
    # Normalize once so the download loop, the Drive fallback, and the
    # final tally all agree on the same identifiers. (Previously the tally
    # used the raw JSON entries, so an entry with stray whitespace could
    # download successfully yet be counted as missing.)
    db_ids = [_validate_db_id(raw) for raw in json.loads(DB_LIST_PATH.read_text())]
    print(f"Downloading {len(db_ids)} databases from db_list.json...")

    remaining = []
    for db_id in db_ids:
        sqlite_path = _safe_sqlite_path(output_dir, db_id)
        if _is_valid_sqlite(sqlite_path) and not force:
            print(f" Already exists: {db_id}")
            continue
        if _try_raw_download(db_id, sqlite_path):
            print(f" Downloaded: {db_id} (GitHub)")
        else:
            remaining.append(db_id)

    if remaining:
        print(
            f" {len(remaining)} failed from GitHub, falling back to Drive archive..."
        )
        archive_bytes = _download_drive_archive()
        for db_id in remaining:
            sqlite_path = _safe_sqlite_path(output_dir, db_id)
            try:
                _extract_from_archive(archive_bytes, db_id, sqlite_path)
                print(f" Downloaded: {db_id} (Drive archive)")
            except FileNotFoundError:
                # Best-effort: report and keep going with the other DBs.
                print(f" FAILED: {db_id} not found in archive")

    downloaded = sum(
        1
        for db_id in db_ids
        if _is_valid_sqlite(output_dir / db_id / f"{db_id}.sqlite")
    )
    print(f"Ready: {downloaded}/{len(db_ids)} databases in {output_dir}")
    return downloaded
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build and run the command-line parser for the downloader."""
    parser = argparse.ArgumentParser(
        description="Download Spider SQLite databases for SQLEnv",
    )
    db_id_help = (
        "Spider database ID to download. "
        "Use 'all' for every Spider DB, or omit to download "
        "databases listed in data/questions/db_list.json"
    )
    parser.add_argument("--db-id", type=str, default=None, help=db_id_help)
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("data/databases"),
        help="Directory to store databases (default: data/databases)",
    )
    parser.add_argument(
        "--force", action="store_true", help="Overwrite existing files"
    )
    return parser.parse_args()
|
|
|
|
def main() -> None:
    """CLI entry point: route --db-id to the matching download helper.

    No flag -> db_list.json batch; 'all' (any case) -> full archive;
    anything else -> that single database.
    """
    args = parse_args()
    requested = args.db_id

    if requested is None:
        download_listed(output_dir=args.output_dir, force=args.force)
        return
    if requested.lower() == "all":
        download_all(output_dir=args.output_dir, force=args.force)
        return
    download_database(db_id=requested, output_dir=args.output_dir, force=args.force)
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|