"""
Script to download Spider dataset questions for specific databases.
Usage:
python download_spider_data.py --db-id student_assessment
python download_spider_data.py --db-id student_assessment --split validation
python download_spider_data.py --db-id all # downloads all db_ids
"""
import json
import argparse
from pathlib import Path
from datasets import load_dataset
def download_spider_questions(
    db_id: str = "student_assessment",
    split: str = "train",
    output_dir: str = "data/questions",
) -> None:
    """Download Spider dataset questions for specified database(s).

    Writes one JSON file per database into ``output_dir`` (created if
    missing) and prints a progress summary to stdout.

    Args:
        db_id: Database ID to filter by, or "all" to get all databases.
        split: Dataset split ("train" or "validation").
        output_dir: Directory to save JSON files.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"Loading Spider dataset ({split} split)...")
    dataset = load_dataset("xlangai/spider", split=split)

    if db_id.lower() == "all":
        # Group every example by its db_id so each database gets its own file.
        grouped: dict[str, list] = {}
        for item in dataset:
            grouped.setdefault(item.get("db_id"), []).append(item)

        total_questions = 0
        for current_db_id, questions in grouped.items():
            filepath = output_path / f"{current_db_id}.json"
            # Explicit UTF-8: question text may contain non-ASCII characters,
            # and the platform default encoding is not guaranteed to handle it.
            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(questions, f, indent=2)
            print(f" {current_db_id}: {len(questions)} questions → {filepath}")
            total_questions += len(questions)
        print(f"\nTotal: {total_questions} questions across {len(grouped)} databases")
    else:
        # Filter for specific db_id
        filtered_data = [item for item in dataset if item.get("db_id") == db_id]
        if not filtered_data:
            print(f"No questions found for db_id='{db_id}'")
            return

        filepath = output_path / f"{db_id}.json"
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(filtered_data, f, indent=2)
        print(f"Found {len(filtered_data)} questions for db_id='{db_id}'")
        print(f"Saved to {filepath}")

        # Print a sample so the user can sanity-check the download.
        # (filtered_data is guaranteed non-empty here by the early return above.)
        sample = filtered_data[0]
        print("\nFirst question sample:")
        print(
            json.dumps(
                {k: v for k, v in sample.items() if k != "evidence"}, indent=2
            )
        )
if __name__ == "__main__":
    # CLI entry point: build the argument parser and hand off to the downloader.
    arg_parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="Download Spider dataset questions for specific databases",
    )
    arg_parser.add_argument(
        "--db-id",
        default="student_assessment",
        type=str,
        help="Database ID to filter by (or 'all' for all databases)",
    )
    arg_parser.add_argument(
        "--split",
        default="train",
        type=str,
        choices=["train", "validation"],
        help="Dataset split to download",
    )
    arg_parser.add_argument(
        "--output-dir",
        default="data/questions",
        type=str,
        help="Directory to save JSON files",
    )
    cli_args = arg_parser.parse_args()
    download_spider_questions(cli_args.db_id, cli_args.split, cli_args.output_dir)