# scripts/download_spider_databases.py
"""Download Spider SQLite databases used by SQLEnv.
Uses the same download logic as curate_questions.py: tries GitHub raw URLs
first, then falls back to the official Google Drive Spider archive.
Examples
--------
Download the default database (student_assessment):
uv run python scripts/download_spider_databases.py
Download a specific database:
uv run python scripts/download_spider_databases.py --db-id concert_singer
Download all databases referenced in db_list.json:
uv run python scripts/download_spider_databases.py --db-id all
Force re-download:
uv run python scripts/download_spider_databases.py --force
"""
from __future__ import annotations
import argparse
import io
import json
import re
import time
import zipfile
from pathlib import Path
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen
# Raw-file mirrors for a single Spider database; tried in order before the
# Google Drive archive fallback.
SPIDER_RAW_SQLITE_URLS = (
    "https://raw.githubusercontent.com/taoyds/spider/master/database/{db_id}/{db_id}.sqlite",
    "https://github.com/taoyds/spider/raw/master/database/{db_id}/{db_id}.sqlite",
)
# Google Drive file id of the official Spider dataset zip (fallback source).
SPIDER_ARCHIVE_DRIVE_ID = "1403EGqzIDoHMdQF4c9Bkyl7dZLZ5Wt6J"
# First 16 bytes of every SQLite 3 database file; used to validate downloads.
SQLITE_MAGIC = b"SQLite format 3\x00"
# List of db ids to fetch; produced by curate_questions.py (see download_listed).
DB_LIST_PATH = Path("data/questions/db_list.json")
def _validate_db_id(db_id: str) -> str:
normalized = db_id.strip()
if not normalized:
raise ValueError("db_id cannot be empty")
if not re.fullmatch(r"[A-Za-z0-9_]+", normalized):
raise ValueError(
"Invalid db_id — only letters, numbers, and underscores allowed."
)
return normalized
def _is_valid_sqlite(path: Path) -> bool:
if not path.exists() or path.stat().st_size < 16:
return False
with path.open("rb") as f:
return f.read(16) == SQLITE_MAGIC
def _safe_sqlite_path(output_dir: Path, db_id: str) -> Path:
sqlite_path = output_dir / db_id / f"{db_id}.sqlite"
output_root = output_dir.resolve()
resolved = sqlite_path.resolve()
if output_root not in resolved.parents:
raise ValueError(f"Resolved path escapes output directory: {resolved}")
return sqlite_path
def _try_raw_download(db_id: str, destination: Path) -> bool:
    """Try downloading from GitHub raw URLs. Returns True on success."""
    for template in SPIDER_RAW_SQLITE_URLS:
        candidate_url = template.format(db_id=db_id)
        try:
            request = Request(candidate_url, headers={"User-Agent": "sqlenv/1.0"})
            with urlopen(request, timeout=30) as response:
                body = response.read()
            # GitHub serves an HTML page for missing files; accept only
            # payloads that carry the SQLite header.
            if not body.startswith(SQLITE_MAGIC):
                continue
            destination.parent.mkdir(parents=True, exist_ok=True)
            # Write-then-rename so a partial download is never mistaken
            # for a finished database.
            staging = destination.with_suffix(".tmp")
            staging.write_bytes(body)
            staging.replace(destination)
            return True
        except (HTTPError, URLError, OSError):
            continue
    return False
def _download_drive_archive() -> bytes:
    """Download the official Spider archive from Google Drive.

    Returns:
        Raw bytes of the zip archive.

    Raises:
        RuntimeError: if both attempts fail to yield a zip payload. The last
            underlying error is chained as ``__cause__`` (the original code
            silently discarded it, hiding the real failure reason).
    """
    drive_url = (
        f"https://drive.google.com/uc?export=download&id={SPIDER_ARCHIVE_DRIVE_ID}"
    )
    last_error: Exception | None = None
    for attempt in range(2):
        try:
            req = Request(drive_url, headers={"User-Agent": "sqlenv/1.0"})
            with urlopen(req, timeout=120) as resp:
                payload = resp.read()
            # Zip archives start with the local-file-header magic "PK".
            if payload.startswith(b"PK"):
                return payload
            # Google Drive virus-scan warning page — parse confirm token
            text = payload.decode("utf-8", errors="replace")
            confirm_match = re.search(r'name="confirm" value="([^"]+)"', text)
            if confirm_match:
                confirm_url = (
                    "https://drive.usercontent.google.com/download"
                    f"?id={SPIDER_ARCHIVE_DRIVE_ID}"
                    f"&export=download&confirm={confirm_match.group(1)}"
                )
                confirm_req = Request(
                    confirm_url,
                    headers={"User-Agent": "sqlenv/1.0"},
                )
                with urlopen(confirm_req, timeout=240) as resp2:
                    payload = resp2.read()
                if payload.startswith(b"PK"):
                    return payload
            raise RuntimeError("Drive response was not a zip file")
        except (HTTPError, URLError, OSError, RuntimeError) as exc:
            last_error = exc
            if attempt == 0:
                time.sleep(3)  # brief backoff before the single retry
    raise RuntimeError(
        "Failed to download Spider archive from Google Drive after retries"
    ) from last_error
def _extract_from_archive(archive_bytes: bytes, db_id: str, destination: Path) -> None:
"""Extract a single database from the Spider zip archive."""
candidates = [
f"spider_data/database/{db_id}/{db_id}.sqlite",
f"spider/database/{db_id}/{db_id}.sqlite",
f"spider-master/database/{db_id}/{db_id}.sqlite",
]
with zipfile.ZipFile(io.BytesIO(archive_bytes)) as zf:
for member in candidates:
try:
data = zf.read(member)
if data.startswith(SQLITE_MAGIC):
destination.parent.mkdir(parents=True, exist_ok=True)
tmp = destination.with_suffix(".tmp")
tmp.write_bytes(data)
tmp.replace(destination)
return
except KeyError:
continue
raise FileNotFoundError(f"Database '{db_id}' not found in Spider archive")
def _extract_all_from_archive(
archive_bytes: bytes, output_dir: Path, force: bool
) -> int:
"""Extract all databases from the Spider archive."""
count = 0
with zipfile.ZipFile(io.BytesIO(archive_bytes)) as zf:
for member in zf.namelist():
if not member.endswith(".sqlite"):
continue
if "/database/" not in member:
continue
db_name = Path(member).stem
target = output_dir / db_name / f"{db_name}.sqlite"
if target.exists() and not force:
continue
data = zf.read(member)
if not data.startswith(SQLITE_MAGIC):
continue
target.parent.mkdir(parents=True, exist_ok=True)
tmp = target.with_suffix(".tmp")
tmp.write_bytes(data)
tmp.replace(target)
count += 1
return count
def download_database(db_id: str, output_dir: Path, force: bool = False) -> Path:
    """Download one Spider database, with Google Drive fallback.

    Returns the path to the local ``.sqlite`` file.
    """
    validated = _validate_db_id(db_id)
    sqlite_path = _safe_sqlite_path(output_dir, validated)
    # Skip the network entirely when a healthy copy is already on disk.
    if not force and _is_valid_sqlite(sqlite_path):
        print(f"Already exists: {sqlite_path}")
        return sqlite_path
    print(f"Downloading {validated}...")
    if _try_raw_download(validated, sqlite_path):
        print(f" -> {sqlite_path} (from GitHub)")
        return sqlite_path
    # GitHub misses fall back to the single official archive.
    print(" GitHub raw URLs failed, trying Google Drive archive...")
    archive_bytes = _download_drive_archive()
    _extract_from_archive(archive_bytes, validated, sqlite_path)
    print(f" -> {sqlite_path} (from Drive archive)")
    return sqlite_path
def download_all(output_dir: Path, force: bool = False) -> int:
    """Download every Spider database by extracting the Drive archive.

    Returns the number of databases extracted.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    print("Downloading Spider archive from Google Drive...")
    archive = _download_drive_archive()
    extracted = _extract_all_from_archive(archive, output_dir, force)
    print(f"Extracted {extracted} database(s) to {output_dir}")
    return extracted
def download_listed(output_dir: Path, force: bool = False) -> int:
    """Download every database named in db_list.json.

    GitHub raw URLs are tried first per database; misses are then served
    from a single Google Drive archive download.

    Raises:
        FileNotFoundError: if db_list.json has not been generated yet.
    """
    if not DB_LIST_PATH.exists():
        raise FileNotFoundError(
            f"{DB_LIST_PATH} not found — run curate_questions.py first "
            "or use --db-id <name> to download individual databases"
        )
    db_ids = json.loads(DB_LIST_PATH.read_text())
    print(f"Downloading {len(db_ids)} databases from db_list.json...")
    # First pass: GitHub raw; collect ids that need the archive fallback.
    needs_archive: list[str] = []
    for raw_id in db_ids:
        normalized = _validate_db_id(raw_id)
        sqlite_path = _safe_sqlite_path(output_dir, normalized)
        if _is_valid_sqlite(sqlite_path) and not force:
            print(f" Already exists: {normalized}")
        elif _try_raw_download(normalized, sqlite_path):
            print(f" Downloaded: {normalized} (GitHub)")
        else:
            needs_archive.append(normalized)
    if needs_archive:
        print(
            f" {len(needs_archive)} failed from GitHub, falling back to Drive archive..."
        )
        archive_bytes = _download_drive_archive()
        for db_id in needs_archive:
            sqlite_path = _safe_sqlite_path(output_dir, db_id)
            try:
                _extract_from_archive(archive_bytes, db_id, sqlite_path)
            except FileNotFoundError:
                print(f" FAILED: {db_id} not found in archive")
            else:
                print(f" Downloaded: {db_id} (Drive archive)")
    # Recount from disk so the summary reflects what actually landed.
    downloaded = sum(
        1
        for db_id in db_ids
        if _is_valid_sqlite(output_dir / db_id / f"{db_id}.sqlite")
    )
    print(f"Ready: {downloaded}/{len(db_ids)} databases in {output_dir}")
    return downloaded
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for this script."""
    parser = argparse.ArgumentParser(
        description="Download Spider SQLite databases for SQLEnv",
    )
    db_id_help = (
        "Spider database ID to download. "
        "Use 'all' for every Spider DB, or omit to download "
        "databases listed in data/questions/db_list.json"
    )
    parser.add_argument("--db-id", type=str, default=None, help=db_id_help)
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("data/databases"),
        help="Directory to store databases (default: data/databases)",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Overwrite existing files",
    )
    return parser.parse_args()
def main() -> None:
    """Entry point: dispatch on --db-id to the matching download routine."""
    args = parse_args()
    requested = args.db_id
    if requested is None:
        # No id given: pull whatever db_list.json references.
        download_listed(output_dir=args.output_dir, force=args.force)
    elif requested.lower() == "all":
        download_all(output_dir=args.output_dir, force=args.force)
    else:
        download_database(
            db_id=requested,
            output_dir=args.output_dir,
            force=args.force,
        )


if __name__ == "__main__":
    main()