File size: 3,391 Bytes
5dd1bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
Script to download Spider dataset questions for specific databases.

Usage:
    python download_spider_data.py --db-id student_assessment
    python download_spider_data.py --db-id student_assessment --split validation
    python download_spider_data.py --db-id all  # downloads all db_ids
"""

import argparse
import json
from collections import defaultdict
from pathlib import Path

from datasets import load_dataset


def download_spider_questions(
    db_id: str = "student_assessment",
    split: str = "train",
    output_dir: str = "data/questions",
) -> None:
    """Download Spider dataset questions for specified database(s).

    Writes one JSON file per database into ``output_dir`` (named
    ``<db_id>.json``), each containing the list of dataset records for
    that database.

    Args:
        db_id: Database ID to filter by, or "all" to get all databases
        split: Dataset split ("train" or "validation")
        output_dir: Directory to save JSON files (created if missing)
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"Loading Spider dataset ({split} split)...")
    dataset = load_dataset("xlangai/spider", split=split)

    if db_id.lower() == "all":
        # Group records by db_id; defaultdict removes the manual
        # "key not in dict" bookkeeping.
        grouped: dict = defaultdict(list)
        for item in dataset:
            grouped[item.get("db_id")].append(item)

        total_questions = 0
        for current_db_id, questions in grouped.items():
            filepath = output_path / f"{current_db_id}.json"
            # Pin UTF-8 and keep non-ASCII text readable; without
            # encoding= the output charset depends on platform locale.
            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(questions, f, indent=2, ensure_ascii=False)
            print(f"  {current_db_id}: {len(questions)} questions → {filepath}")
            total_questions += len(questions)

        print(f"\nTotal: {total_questions} questions across {len(grouped)} databases")
    else:
        # Filter for the single requested db_id
        filtered_data = [item for item in dataset if item.get("db_id") == db_id]

        if not filtered_data:
            print(f"No questions found for db_id='{db_id}'")
            return

        filepath = output_path / f"{db_id}.json"
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(filtered_data, f, indent=2, ensure_ascii=False)

        print(f"Found {len(filtered_data)} questions for db_id='{db_id}'")
        print(f"Saved to {filepath}")

        # Show a sample record so the user can eyeball the schema.
        # (The early return above guarantees filtered_data is non-empty,
        # so no extra guard is needed here.)
        sample = filtered_data[0]
        print("\nFirst question sample:")
        print(
            json.dumps(
                {k: v for k, v in sample.items() if k != "evidence"}, indent=2
            )
        )


if __name__ == "__main__":
    # Command-line entry point: parse options, then hand off to the downloader.
    cli = argparse.ArgumentParser(
        description="Download Spider dataset questions for specific databases",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    cli.add_argument(
        "--db-id",
        type=str,
        default="student_assessment",
        help="Database ID to filter by (or 'all' for all databases)",
    )
    cli.add_argument(
        "--split",
        type=str,
        default="train",
        choices=["train", "validation"],
        help="Dataset split to download",
    )
    cli.add_argument(
        "--output-dir",
        type=str,
        default="data/questions",
        help="Directory to save JSON files",
    )

    options = cli.parse_args()
    download_spider_questions(
        db_id=options.db_id,
        split=options.split,
        output_dir=options.output_dir,
    )