File size: 10,508 Bytes
5dd1bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
"""Download Spider SQLite databases used by SQLEnv.

Uses the same download logic as curate_questions.py: tries GitHub raw URLs
first, then falls back to the official Google Drive Spider archive.

Examples
--------
Download the default database (student_assessment):
    uv run python scripts/download_spider_databases.py

Download a specific database:
    uv run python scripts/download_spider_databases.py --db-id concert_singer

Download all databases referenced in db_list.json:
    uv run python scripts/download_spider_databases.py --db-id all

Force re-download:
    uv run python scripts/download_spider_databases.py --force
"""

from __future__ import annotations

import argparse
import io
import json
import re
import time
import zipfile
from pathlib import Path
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen

# GitHub raw-content mirrors tried first, in order; {db_id} is substituted in.
SPIDER_RAW_SQLITE_URLS = (
    "https://raw.githubusercontent.com/taoyds/spider/master/database/{db_id}/{db_id}.sqlite",
    "https://github.com/taoyds/spider/raw/master/database/{db_id}/{db_id}.sqlite",
)
# Google Drive file id of the official Spider dataset archive (fallback source).
SPIDER_ARCHIVE_DRIVE_ID = "1403EGqzIDoHMdQF4c9Bkyl7dZLZ5Wt6J"
# First 16 bytes of every SQLite 3 database file (format header magic).
SQLITE_MAGIC = b"SQLite format 3\x00"
# JSON list of database ids; written by curate_questions.py.
DB_LIST_PATH = Path("data/questions/db_list.json")


def _validate_db_id(db_id: str) -> str:
    normalized = db_id.strip()
    if not normalized:
        raise ValueError("db_id cannot be empty")
    if not re.fullmatch(r"[A-Za-z0-9_]+", normalized):
        raise ValueError(
            "Invalid db_id — only letters, numbers, and underscores allowed."
        )
    return normalized


def _is_valid_sqlite(path: Path) -> bool:
    """Return True when *path* exists and carries the SQLite 3 header magic."""
    if path.exists() and path.stat().st_size >= 16:
        with path.open("rb") as handle:
            return handle.read(16) == SQLITE_MAGIC
    return False


def _safe_sqlite_path(output_dir: Path, db_id: str) -> Path:
    sqlite_path = output_dir / db_id / f"{db_id}.sqlite"
    output_root = output_dir.resolve()
    resolved = sqlite_path.resolve()
    if output_root not in resolved.parents:
        raise ValueError(f"Resolved path escapes output directory: {resolved}")
    return sqlite_path


def _try_raw_download(db_id: str, destination: Path) -> bool:
    """Fetch one database straight from the GitHub raw mirrors.

    Writes atomically via a ``.tmp`` sibling file. Returns True when a
    valid SQLite payload was stored at *destination*; False when every
    mirror failed or served non-SQLite content.
    """
    for template in SPIDER_RAW_SQLITE_URLS:
        target_url = template.format(db_id=db_id)
        try:
            request = Request(target_url, headers={"User-Agent": "sqlenv/1.0"})
            with urlopen(request, timeout=30) as response:
                payload = response.read()
            if not payload.startswith(SQLITE_MAGIC):
                # Mirror answered, but not with a SQLite file — try the next.
                continue
            destination.parent.mkdir(parents=True, exist_ok=True)
            staging = destination.with_suffix(".tmp")
            staging.write_bytes(payload)
            staging.replace(destination)
            return True
        except (HTTPError, URLError, OSError):
            continue
    return False


def _download_drive_archive() -> bytes:
    """Fetch the full Spider zip archive from Google Drive.

    Large Drive files return an HTML virus-scan interstitial instead of
    the payload; in that case the confirm token is parsed out of the page
    and a follow-up request is made. The whole sequence is retried once
    (after a short pause) before giving up.

    Raises RuntimeError when no attempt yields a zip payload.
    """
    request = Request(
        f"https://drive.google.com/uc?export=download&id={SPIDER_ARCHIVE_DRIVE_ID}",
        headers={"User-Agent": "sqlenv/1.0"},
    )

    for attempt in range(2):
        try:
            with urlopen(request, timeout=120) as response:
                body = response.read()

            if body.startswith(b"PK"):
                return body

            # Not a zip — assume the virus-scan warning page and look for
            # the confirm token embedded in its form.
            page = body.decode("utf-8", errors="replace")
            token = re.search(r'name="confirm" value="([^"]+)"', page)
            if token is not None:
                follow_up = Request(
                    "https://drive.usercontent.google.com/download"
                    f"?id={SPIDER_ARCHIVE_DRIVE_ID}"
                    f"&export=download&confirm={token.group(1)}",
                    headers={"User-Agent": "sqlenv/1.0"},
                )
                with urlopen(follow_up, timeout=240) as confirmed:
                    body = confirmed.read()
                if body.startswith(b"PK"):
                    return body

            raise RuntimeError("Drive response was not a zip file")
        except (HTTPError, URLError, OSError, RuntimeError):
            if attempt == 0:
                time.sleep(3)

    raise RuntimeError(
        "Failed to download Spider archive from Google Drive after retries"
    )


def _extract_from_archive(archive_bytes: bytes, db_id: str, destination: Path) -> None:
    """Pull a single ``<db_id>.sqlite`` out of the Spider zip archive.

    The archive root folder has varied across releases, so each known
    layout prefix is tried in order. The file is written atomically via a
    ``.tmp`` sibling.

    Raises FileNotFoundError when no layout contains a valid SQLite file
    for *db_id*.
    """
    for prefix in ("spider_data", "spider", "spider-master"):
        member = f"{prefix}/database/{db_id}/{db_id}.sqlite"
        with zipfile.ZipFile(io.BytesIO(archive_bytes)) as archive:
            try:
                blob = archive.read(member)
            except KeyError:
                continue
        if not blob.startswith(SQLITE_MAGIC):
            # Present but not a real SQLite file — keep looking.
            continue
        destination.parent.mkdir(parents=True, exist_ok=True)
        staging = destination.with_suffix(".tmp")
        staging.write_bytes(blob)
        staging.replace(destination)
        return
    raise FileNotFoundError(f"Database '{db_id}' not found in Spider archive")


def _extract_all_from_archive(
    archive_bytes: bytes, output_dir: Path, force: bool
) -> int:
    """Extract every ``*.sqlite`` under a ``database/`` folder in the zip.

    Existing targets are skipped unless *force* is set; members that do
    not start with the SQLite header are ignored. Each file is written
    atomically via a ``.tmp`` sibling.

    Returns the number of databases actually written.
    """
    written = 0
    with zipfile.ZipFile(io.BytesIO(archive_bytes)) as archive:
        sqlite_members = (
            name
            for name in archive.namelist()
            if name.endswith(".sqlite") and "/database/" in name
        )
        for member in sqlite_members:
            db_name = Path(member).stem
            destination = output_dir / db_name / f"{db_name}.sqlite"
            if destination.exists() and not force:
                continue
            blob = archive.read(member)
            if not blob.startswith(SQLITE_MAGIC):
                continue
            destination.parent.mkdir(parents=True, exist_ok=True)
            staging = destination.with_suffix(".tmp")
            staging.write_bytes(blob)
            staging.replace(destination)
            written += 1
    return written


def download_database(db_id: str, output_dir: Path, force: bool = False) -> Path:
    """Download a single Spider database, with Google Drive fallback.

    GitHub raw mirrors are tried first; when they all fail, the full
    Spider archive is fetched from Drive and the database extracted.
    Existing valid files are kept unless *force* is set.

    Returns the local path of the ``.sqlite`` file.
    """
    db_name = _validate_db_id(db_id)
    target = _safe_sqlite_path(output_dir, db_name)

    if not force and _is_valid_sqlite(target):
        print(f"Already exists: {target}")
        return target

    print(f"Downloading {db_name}...")

    if _try_raw_download(db_name, target):
        print(f"  -> {target} (from GitHub)")
        return target

    print("  GitHub raw URLs failed, trying Google Drive archive...")
    _extract_from_archive(_download_drive_archive(), db_name, target)
    print(f"  -> {target} (from Drive archive)")
    return target


def download_all(output_dir: Path, force: bool = False) -> int:
    """Download every Spider database via the Google Drive archive.

    Returns the number of databases extracted.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    print("Downloading Spider archive from Google Drive...")
    extracted = _extract_all_from_archive(
        _download_drive_archive(), output_dir, force
    )
    print(f"Extracted {extracted} database(s) to {output_dir}")
    return extracted


def download_listed(output_dir: Path, force: bool = False) -> int:
    """Download the databases listed in ``data/questions/db_list.json``.

    Each database is first tried from the GitHub raw mirrors; any that
    fail are batch-extracted from a single Google Drive archive download.

    Returns the number of databases that are valid on disk afterwards.
    Raises FileNotFoundError when db_list.json does not exist and
    ValueError when it contains an invalid db id.
    """
    if not DB_LIST_PATH.exists():
        raise FileNotFoundError(
            f"{DB_LIST_PATH} not found — run curate_questions.py first "
            "or use --db-id <name> to download individual databases"
        )
    db_ids = json.loads(DB_LIST_PATH.read_text())
    print(f"Downloading {len(db_ids)} databases from db_list.json...")

    # Try GitHub raw first, batch fallback to archive for failures
    normalized_ids = []
    remaining = []
    for db_id in db_ids:
        normalized = _validate_db_id(db_id)
        normalized_ids.append(normalized)
        sqlite_path = _safe_sqlite_path(output_dir, normalized)
        if _is_valid_sqlite(sqlite_path) and not force:
            print(f"  Already exists: {normalized}")
            continue
        if _try_raw_download(normalized, sqlite_path):
            print(f"  Downloaded: {normalized} (GitHub)")
        else:
            remaining.append(normalized)

    if remaining:
        print(
            f"  {len(remaining)} failed from GitHub, falling back to Drive archive..."
        )
        archive_bytes = _download_drive_archive()
        for db_id in remaining:
            sqlite_path = _safe_sqlite_path(output_dir, db_id)
            try:
                _extract_from_archive(archive_bytes, db_id, sqlite_path)
                print(f"  Downloaded: {db_id} (Drive archive)")
            except FileNotFoundError:
                print(f"  FAILED: {db_id} not found in archive")

    # Tally against the normalized ids: downloads are written to the
    # normalized path, so a raw entry with stray whitespace would
    # otherwise be checked at a path that was never written (undercount).
    downloaded = sum(
        1
        for db_id in normalized_ids
        if _is_valid_sqlite(output_dir / db_id / f"{db_id}.sqlite")
    )
    print(f"Ready: {downloaded}/{len(db_ids)} databases in {output_dir}")
    return downloaded


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the downloader.

    Parameters
    ----------
    argv:
        Argument list to parse; defaults to ``sys.argv[1:]`` when None,
        preserving the original behavior. Accepting an explicit list
        makes the parser testable and embeddable.
    """
    parser = argparse.ArgumentParser(
        description="Download Spider SQLite databases for SQLEnv",
    )
    parser.add_argument(
        "--db-id",
        type=str,
        default=None,
        help=(
            "Spider database ID to download. "
            "Use 'all' for every Spider DB, or omit to download "
            "databases listed in data/questions/db_list.json"
        ),
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("data/databases"),
        help="Directory to store databases (default: data/databases)",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Overwrite existing files",
    )
    return parser.parse_args(argv)


def main() -> None:
    """CLI entry point: dispatch on the --db-id argument.

    No flag -> download the curated list; 'all' -> the full archive;
    anything else -> that single database.
    """
    args = parse_args()
    requested = args.db_id

    if requested is None:
        download_listed(output_dir=args.output_dir, force=args.force)
        return
    if requested.lower() == "all":
        download_all(output_dir=args.output_dir, force=args.force)
        return
    download_database(
        db_id=requested,
        output_dir=args.output_dir,
        force=args.force,
    )


# Run the downloader only when executed as a script, not on import.
if __name__ == "__main__":
    main()