"""Download Spider SQLite databases used by SQLEnv. Uses the same download logic as curate_questions.py: tries GitHub raw URLs first, then falls back to the official Google Drive Spider archive. Examples -------- Download the default database (student_assessment): uv run python scripts/download_spider_databases.py Download a specific database: uv run python scripts/download_spider_databases.py --db-id concert_singer Download all databases referenced in db_list.json: uv run python scripts/download_spider_databases.py --db-id all Force re-download: uv run python scripts/download_spider_databases.py --force """ from __future__ import annotations import argparse import io import json import re import time import zipfile from pathlib import Path from urllib.error import HTTPError, URLError from urllib.request import Request, urlopen SPIDER_RAW_SQLITE_URLS = ( "https://raw.githubusercontent.com/taoyds/spider/master/database/{db_id}/{db_id}.sqlite", "https://github.com/taoyds/spider/raw/master/database/{db_id}/{db_id}.sqlite", ) SPIDER_ARCHIVE_DRIVE_ID = "1403EGqzIDoHMdQF4c9Bkyl7dZLZ5Wt6J" SQLITE_MAGIC = b"SQLite format 3\x00" DB_LIST_PATH = Path("data/questions/db_list.json") def _validate_db_id(db_id: str) -> str: normalized = db_id.strip() if not normalized: raise ValueError("db_id cannot be empty") if not re.fullmatch(r"[A-Za-z0-9_]+", normalized): raise ValueError( "Invalid db_id — only letters, numbers, and underscores allowed." ) return normalized def _is_valid_sqlite(path: Path) -> bool: if not path.exists() or path.stat().st_size < 16: return False with path.open("rb") as f: return f.read(16) == SQLITE_MAGIC def _safe_sqlite_path(output_dir: Path, db_id: str) -> Path: sqlite_path = output_dir / db_id / f"{db_id}.sqlite" output_root = output_dir.resolve() resolved = sqlite_path.resolve() if output_root not in resolved.parents: raise ValueError(f"Resolved path escapes output directory: {resolved}") return sqlite_path def _try_raw_download(db_id: str, destination: Path) -> bool: """Try downloading from GitHub raw URLs. Returns True on success.""" for url_template in SPIDER_RAW_SQLITE_URLS: url = url_template.format(db_id=db_id) try: req = Request(url, headers={"User-Agent": "sqlenv/1.0"}) with urlopen(req, timeout=30) as resp: data = resp.read() if not data.startswith(SQLITE_MAGIC): continue tmp = destination.with_suffix(".tmp") destination.parent.mkdir(parents=True, exist_ok=True) tmp.write_bytes(data) tmp.replace(destination) return True except (HTTPError, URLError, OSError): continue return False def _download_drive_archive() -> bytes: """Download official Spider archive from Google Drive.""" drive_url = ( f"https://drive.google.com/uc?export=download&id={SPIDER_ARCHIVE_DRIVE_ID}" ) req = Request(drive_url, headers={"User-Agent": "sqlenv/1.0"}) for attempt in range(2): try: with urlopen(req, timeout=120) as resp: payload = resp.read() if payload.startswith(b"PK"): return payload # Google Drive virus-scan warning page — parse confirm token text = payload.decode("utf-8", errors="replace") confirm_match = re.search(r'name="confirm" value="([^"]+)"', text) if confirm_match: confirm_url = ( "https://drive.usercontent.google.com/download" f"?id={SPIDER_ARCHIVE_DRIVE_ID}" f"&export=download&confirm={confirm_match.group(1)}" ) confirm_req = Request( confirm_url, headers={"User-Agent": "sqlenv/1.0"}, ) with urlopen(confirm_req, timeout=240) as resp2: payload = resp2.read() if payload.startswith(b"PK"): return payload raise RuntimeError("Drive response was not a zip file") except (HTTPError, URLError, OSError, RuntimeError): if attempt == 0: time.sleep(3) raise RuntimeError( "Failed to download Spider archive from Google Drive after retries" ) def _extract_from_archive(archive_bytes: bytes, db_id: str, destination: Path) -> None: """Extract a single database from the Spider zip archive.""" candidates = [ f"spider_data/database/{db_id}/{db_id}.sqlite", f"spider/database/{db_id}/{db_id}.sqlite", f"spider-master/database/{db_id}/{db_id}.sqlite", ] with zipfile.ZipFile(io.BytesIO(archive_bytes)) as zf: for member in candidates: try: data = zf.read(member) if data.startswith(SQLITE_MAGIC): destination.parent.mkdir(parents=True, exist_ok=True) tmp = destination.with_suffix(".tmp") tmp.write_bytes(data) tmp.replace(destination) return except KeyError: continue raise FileNotFoundError(f"Database '{db_id}' not found in Spider archive") def _extract_all_from_archive( archive_bytes: bytes, output_dir: Path, force: bool ) -> int: """Extract all databases from the Spider archive.""" count = 0 with zipfile.ZipFile(io.BytesIO(archive_bytes)) as zf: for member in zf.namelist(): if not member.endswith(".sqlite"): continue if "/database/" not in member: continue db_name = Path(member).stem target = output_dir / db_name / f"{db_name}.sqlite" if target.exists() and not force: continue data = zf.read(member) if not data.startswith(SQLITE_MAGIC): continue target.parent.mkdir(parents=True, exist_ok=True) tmp = target.with_suffix(".tmp") tmp.write_bytes(data) tmp.replace(target) count += 1 return count def download_database(db_id: str, output_dir: Path, force: bool = False) -> Path: """Download one Spider database, with Google Drive fallback.""" normalized = _validate_db_id(db_id) sqlite_path = _safe_sqlite_path(output_dir, normalized) if _is_valid_sqlite(sqlite_path) and not force: print(f"Already exists: {sqlite_path}") return sqlite_path print(f"Downloading {normalized}...") if _try_raw_download(normalized, sqlite_path): print(f" -> {sqlite_path} (from GitHub)") return sqlite_path print(" GitHub raw URLs failed, trying Google Drive archive...") archive_bytes = _download_drive_archive() _extract_from_archive(archive_bytes, normalized, sqlite_path) print(f" -> {sqlite_path} (from Drive archive)") return sqlite_path def download_all(output_dir: Path, force: bool = False) -> int: """Download all databases from Google Drive archive.""" output_dir.mkdir(parents=True, exist_ok=True) print("Downloading Spider archive from Google Drive...") archive_bytes = _download_drive_archive() count = _extract_all_from_archive(archive_bytes, output_dir, force) print(f"Extracted {count} database(s) to {output_dir}") return count def download_listed(output_dir: Path, force: bool = False) -> int: """Download databases listed in db_list.json.""" if not DB_LIST_PATH.exists(): raise FileNotFoundError( f"{DB_LIST_PATH} not found — run curate_questions.py first " "or use --db-id to download individual databases" ) db_ids = json.loads(DB_LIST_PATH.read_text()) print(f"Downloading {len(db_ids)} databases from db_list.json...") # Try GitHub raw first, batch fallback to archive for failures remaining = [] for db_id in db_ids: normalized = _validate_db_id(db_id) sqlite_path = _safe_sqlite_path(output_dir, normalized) if _is_valid_sqlite(sqlite_path) and not force: print(f" Already exists: {normalized}") continue if _try_raw_download(normalized, sqlite_path): print(f" Downloaded: {normalized} (GitHub)") else: remaining.append(normalized) if remaining: print( f" {len(remaining)} failed from GitHub, falling back to Drive archive..." ) archive_bytes = _download_drive_archive() for db_id in remaining: sqlite_path = _safe_sqlite_path(output_dir, db_id) try: _extract_from_archive(archive_bytes, db_id, sqlite_path) print(f" Downloaded: {db_id} (Drive archive)") except FileNotFoundError: print(f" FAILED: {db_id} not found in archive") downloaded = sum( 1 for db_id in db_ids if _is_valid_sqlite(output_dir / db_id / f"{db_id}.sqlite") ) print(f"Ready: {downloaded}/{len(db_ids)} databases in {output_dir}") return downloaded def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Download Spider SQLite databases for SQLEnv", ) parser.add_argument( "--db-id", type=str, default=None, help=( "Spider database ID to download. " "Use 'all' for every Spider DB, or omit to download " "databases listed in data/questions/db_list.json" ), ) parser.add_argument( "--output-dir", type=Path, default=Path("data/databases"), help="Directory to store databases (default: data/databases)", ) parser.add_argument( "--force", action="store_true", help="Overwrite existing files", ) return parser.parse_args() def main() -> None: args = parse_args() if args.db_id is None: download_listed(output_dir=args.output_dir, force=args.force) elif args.db_id.lower() == "all": download_all(output_dir=args.output_dir, force=args.force) else: download_database( db_id=args.db_id, output_dir=args.output_dir, force=args.force, ) if __name__ == "__main__": main()