| """ |
| Script to download Spider dataset questions for specific databases. |
| |
| Usage: |
| python download_spider_data.py --db-id student_assessment |
| python download_spider_data.py --db-id student_assessment --split validation |
| python download_spider_data.py --db-id all # downloads all db_ids |
| """ |
|
|
| import json |
| import argparse |
| from pathlib import Path |
| from datasets import load_dataset |
|
|
|
|
def download_spider_questions(
    db_id: str = "student_assessment",
    split: str = "train",
    output_dir: str = "data/questions",
) -> None:
    """Download Spider dataset questions for specified database(s).

    Writes one JSON file per database into *output_dir* (created if missing).

    Args:
        db_id: Database ID to filter by, or "all" to get all databases.
        split: Dataset split ("train" or "validation").
        output_dir: Directory to save JSON files.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"Loading Spider dataset ({split} split)...")
    dataset = load_dataset("xlangai/spider", split=split)

    if db_id.lower() == "all":
        # Group every question under its db_id so each database gets its own file.
        grouped = {}
        for item in dataset:
            grouped.setdefault(item.get("db_id"), []).append(item)

        total_questions = 0
        for current_db_id, questions in grouped.items():
            filepath = output_path / f"{current_db_id}.json"
            # Explicit encoding: default is platform-dependent, and the
            # script's own output contains non-ASCII characters.
            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(questions, f, indent=2)
            print(f" {current_db_id}: {len(questions)} questions → {filepath}")
            total_questions += len(questions)

        print(f"\nTotal: {total_questions} questions across {len(grouped)} databases")
    else:
        filtered_data = [item for item in dataset if item.get("db_id") == db_id]

        if not filtered_data:
            print(f"No questions found for db_id='{db_id}'")
            return

        filepath = output_path / f"{db_id}.json"
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(filtered_data, f, indent=2)

        print(f"Found {len(filtered_data)} questions for db_id='{db_id}'")
        print(f"Saved to {filepath}")

        # filtered_data is guaranteed non-empty here (early return above),
        # so show a sample question for a quick schema sanity-check.
        sample = filtered_data[0]
        print("\nFirst question sample:")
        print(
            json.dumps(
                {k: v for k, v in sample.items() if k != "evidence"}, indent=2
            )
        )
|
|
|
|
if __name__ == "__main__":
    # CLI wrapper around download_spider_questions; keeps the usage text
    # from the module docstring intact via RawDescriptionHelpFormatter.
    cli = argparse.ArgumentParser(
        description="Download Spider dataset questions for specific databases",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    cli.add_argument(
        "--db-id",
        type=str,
        default="student_assessment",
        help="Database ID to filter by (or 'all' for all databases)",
    )
    cli.add_argument(
        "--split",
        type=str,
        default="train",
        choices=["train", "validation"],
        help="Dataset split to download",
    )
    cli.add_argument(
        "--output-dir",
        type=str,
        default="data/questions",
        help="Directory to save JSON files",
    )

    options = cli.parse_args()
    download_spider_questions(
        db_id=options.db_id,
        split=options.split,
        output_dir=options.output_dir,
    )
|
|