""" Script to download Spider schema and auto-generate SQLAlchemy models. The spider-schema dataset contains detailed database schemas including table names, column names, types, and relationships. This script downloads the schema and generates SQLAlchemy ORM models. Usage: # Generate models for student_assessment database python generate_models_from_schema.py --db-id student_assessment # Generate for multiple databases python generate_models_from_schema.py --db-id all --output-dir models/ # Load from validation split python generate_models_from_schema.py --db-id student_assessment --split validation """ import json import argparse from pathlib import Path from typing import Any, Dict, List, Optional from datasets import load_dataset # Type mapping from Spider schema to SQLAlchemy SQLALCHEMY_TYPE_MAP = { "number": "Integer", "int": "Integer", "float": "Float", "text": "String", "string": "String", "varchar": "String", "char": "String", "date": "Date", "datetime": "DateTime", "timestamp": "DateTime", "time": "DateTime", "boolean": "Boolean", "bool": "Boolean", } def get_sqlalchemy_type(col_type: str) -> str: """Convert Spider schema type to SQLAlchemy type.""" col_type_lower = col_type.lower().strip() # Exact match if col_type_lower in SQLALCHEMY_TYPE_MAP: return SQLALCHEMY_TYPE_MAP[col_type_lower] # Substring match (e.g., "varchar(255)" -> "String") for key, sa_type in SQLALCHEMY_TYPE_MAP.items(): if key in col_type_lower: return sa_type # Default to String return "String" def generate_model_code( db_id: str, tables: List[Dict[str, Any]], schema: Dict[str, Any], ) -> str: """Generate SQLAlchemy model code from schema. Args: db_id: Database ID tables: List of table schemas schema: Full schema dictionary with relationships Returns: Generated Python code as string """ lines = [ f'"""', f"SQLAlchemy ORM models for '{db_id}' database.", f'", f"Auto-generated from Spider schema dataset.", f'"""', f"", f"from datetime import datetime", f"from sqlalchemy import Column, Integer, String, Float, Date, DateTime, Boolean, ForeignKey", f"from sqlalchemy.ext.declarative import declarative_base", f"from sqlalchemy.orm import relationship", f"", f"Base = declarative_base()", f"", ] # Generate model for each table table_names = [t["name"] for t in tables] for table in tables: table_name = table["name"] class_name = "".join(word.capitalize() for word in table_name.split("_")) lines.append(f'class {class_name}(Base):') lines.append(f' """Model for {table_name} table."""') lines.append(f' __tablename__ = "{table_name}"') lines.append(f"") # Add columns columns = table.get("columns", []) for col in columns: col_name = col["name"] col_type = col.get("type", "text") sa_type = get_sqlalchemy_type(col_type) # Determine if primary key is_pk = col.get("is_primary_key", False) # Determine if foreign key fk_str = "" for fk in schema.get("foreign_keys", []): if fk[0] == (table_names.index(table_name), columns.index(col)): source_table_idx, target_table_idx = fk target_col_idx = fk[2] if len(fk) > 2 else 0 target_table = table_names[target_table_idx] target_col = tables[target_table_idx]["columns"][target_col_idx]["name"] fk_str = f', ForeignKey("{target_table}.{target_col}")' # Default nullable to False for primary keys nullable = "False" if is_pk else "True" pk_str = ", primary_key=True" if is_pk else "" lines.append( f' {col_name} = Column({sa_type}({col_type.split("(")[1].rstrip(")")} ' f'if "{sa_type}" == "String" else ""){pk_str}{fk_str}, nullable={nullable})' ) lines.append(f"") return "\n".join(lines) def download_schema_and_generate_models( db_id: str = "student_assessment", split: str = "train", output_dir: str = "data/models", ) -> None: """Download Spider schema and generate SQLAlchemy models. Args: db_id: Database ID to download schema for split: Dataset split ("train" or "validation") output_dir: Directory to save generated model files """ output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) print(f"Loading Spider schema dataset ({split} split)...") dataset = load_dataset("richardr1126/spider-schema", split=split) if db_id.lower() == "all": # Generate models for all databases processed = set() for item in dataset: current_db_id = item.get("db_id") if current_db_id in processed: continue processed.add(current_db_id) tables = item.get("table", []) schema = { "table_names": [t["name"] for t in tables], "column_names": [col for t in tables for col in t.get("columns", [])], "foreign_keys": item.get("foreign_keys", []), } # Generate code (simplified) code = generate_simplified_models(current_db_id, tables) filepath = output_path / f"{current_db_id}.py" with open(filepath, "w") as f: f.write(code) print(f" {current_db_id}: {len(tables)} tables → {filepath}") else: # Filter for specific db_id matching = [item for item in dataset if item.get("db_id") == db_id] if not matching: print(f"No schema found for db_id='{db_id}'") return item = matching[0] tables = item.get("table", []) # Generate simplified model code code = generate_simplified_models(db_id, tables) filepath = output_path / f"{db_id}.py" with open(filepath, "w") as f: f.write(code) print(f"Found schema for db_id='{db_id}' with {len(tables)} tables") print(f"Generated models → {filepath}") print(f"\nTables: {', '.join(t['name'] for t in tables)}") def generate_simplified_models(db_id: str, tables: List[Dict[str, Any]]) -> str: """Generate SQLAlchemy models from table schema (simplified version). Args: db_id: Database ID tables: List of table definitions from schema Returns: Generated Python code """ lines = [ f'"""', f"SQLAlchemy ORM models for '{db_id}' database.", f'", f"Auto-generated from Spider schema dataset.", f'"""', f"", f"from datetime import datetime", f"from sqlalchemy import Column, Integer, String, Float, Date, DateTime, Boolean, ForeignKey", f"from sqlalchemy.ext.declarative import declarative_base", f"from sqlalchemy.orm import relationship", f"", f"Base = declarative_base()", f"", ] for table in tables: table_name = table.get("name", "Unknown") class_name = "".join(word.capitalize() for word in table_name.split("_")) lines.append(f"") lines.append(f"class {class_name}(Base):") lines.append(f' """Model for {table_name} table."""') lines.append(f' __tablename__ = "{table_name}"') lines.append(f"") # Add columns columns = table.get("columns", []) if columns: for col in columns: col_name = col.get("name", "unknown") col_type = col.get("type", "text") sa_type = get_sqlalchemy_type(col_type) # Determine string length from type if specified length_spec = "" if sa_type == "String": if "(" in col_type and ")" in col_type: length = col_type.split("(")[1].split(")")[0] if length.isdigit(): length_spec = f"({length})" else: length_spec = "(255)" # default lines.append(f' {col_name} = Column({sa_type}{length_spec}, nullable=True)') else: lines.append(f" id = Column(Integer, primary_key=True)") lines.append(f"") return "\n".join(lines) if __name__ == "__main__": parser = argparse.ArgumentParser( description="Download Spider schema and generate SQLAlchemy models", formatter_class=argparse.RawDescriptionHelpFormatter, ) parser.add_argument( "--db-id", type=str, default="student_assessment", help="Database ID to generate models for (or 'all' for all databases)", ) parser.add_argument( "--split", type=str, default="train", choices=["train", "validation"], help="Schema dataset split to use", ) parser.add_argument( "--output-dir", type=str, default="data/models", help="Directory to save generated model files", ) args = parser.parse_args() download_schema_and_generate_models( db_id=args.db_id, split=args.split, output_dir=args.output_dir )