"""Script to download Spider dataset questions for specific databases.

Usage:
    python download_spider_data.py --db-id student_assessment
    python download_spider_data.py --db-id student_assessment --split validation
    python download_spider_data.py --db-id all   # downloads all db_ids
"""

import argparse
import json
from collections import defaultdict
from pathlib import Path

from datasets import load_dataset


def download_spider_questions(
    db_id: str = "student_assessment",
    split: str = "train",
    output_dir: str = "data/questions",
) -> None:
    """Download Spider dataset questions for specified database(s).

    Writes one ``<db_id>.json`` file per database into *output_dir*.

    Args:
        db_id: Database ID to filter by, or "all" to get all databases.
        split: Dataset split ("train" or "validation").
        output_dir: Directory to save JSON files (created if missing).
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"Loading Spider dataset ({split} split)...")
    # Requires network access on first call; cached by `datasets` afterwards.
    dataset = load_dataset("xlangai/spider", split=split)

    if db_id.lower() == "all":
        # Group every question under its db_id; defaultdict avoids the
        # explicit "if key not in dict" initialization dance.
        grouped: dict[str, list] = defaultdict(list)
        for item in dataset:
            grouped[item.get("db_id")].append(item)

        total_questions = 0
        for current_db_id, questions in grouped.items():
            filepath = output_path / f"{current_db_id}.json"
            # Explicit UTF-8: open()'s default encoding is OS-dependent, and
            # ensure_ascii=False keeps non-ASCII question text readable.
            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(questions, f, indent=2, ensure_ascii=False)
            print(f" {current_db_id}: {len(questions)} questions → {filepath}")
            total_questions += len(questions)

        print(f"\nTotal: {total_questions} questions across {len(grouped)} databases")
    else:
        # Keep only the questions belonging to the requested database.
        filtered_data = [item for item in dataset if item.get("db_id") == db_id]

        if not filtered_data:
            print(f"No questions found for db_id='{db_id}'")
            return

        filepath = output_path / f"{db_id}.json"
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(filtered_data, f, indent=2, ensure_ascii=False)

        print(f"Found {len(filtered_data)} questions for db_id='{db_id}'")
        print(f"Saved to {filepath}")

        # Show the first question as a quick sanity check, minus the
        # (potentially long) "evidence" field. filtered_data is known to be
        # non-empty here — the empty case returned above.
        sample = filtered_data[0]
        print("\nFirst question sample:")
        print(
            json.dumps(
                {k: v for k, v in sample.items() if k != "evidence"},
                indent=2,
            )
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Download Spider dataset questions for specific databases",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--db-id",
        type=str,
        default="student_assessment",
        help="Database ID to filter by (or 'all' for all databases)",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="train",
        choices=["train", "validation"],
        help="Dataset split to download",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="data/questions",
        help="Directory to save JSON files",
    )
    args = parser.parse_args()
    download_spider_questions(
        db_id=args.db_id, split=args.split, output_dir=args.output_dir
    )