| """ |
| Script to download Spider schema and auto-generate SQLAlchemy models. |
| |
| The spider-schema dataset contains detailed database schemas including |
| table names, column names, types, and relationships. This script |
| downloads the schema and generates SQLAlchemy ORM models. |
| |
| Usage: |
| # Generate models for student_assessment database |
| python generate_models_from_schema.py --db-id student_assessment |
| |
| # Generate for multiple databases |
| python generate_models_from_schema.py --db-id all --output-dir models/ |
| |
| # Load from validation split |
| python generate_models_from_schema.py --db-id student_assessment --split validation |
| """ |
|
|
| import json |
| import argparse |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional |
| from datasets import load_dataset |
|
|
|
|
| |
# Mapping from Spider schema type names to SQLAlchemy column type names.
SQLALCHEMY_TYPE_MAP = {
    "number": "Integer",
    "int": "Integer",
    "float": "Float",
    "text": "String",
    "string": "String",
    "varchar": "String",
    "char": "String",
    "date": "Date",
    "datetime": "DateTime",
    "timestamp": "DateTime",
    "time": "DateTime",
    "boolean": "Boolean",
    "bool": "Boolean",
}


def get_sqlalchemy_type(col_type: str) -> str:
    """Convert a Spider schema type to a SQLAlchemy type name.

    Tries an exact match first, then falls back to the first map key that
    occurs as a substring of the normalized type (so "varchar(255)" maps
    to "String"), and finally defaults to "String".
    """
    normalized = col_type.strip().lower()

    exact = SQLALCHEMY_TYPE_MAP.get(normalized)
    if exact is not None:
        return exact

    # Substring fallback; dict insertion order determines match priority.
    return next(
        (sa_type for token, sa_type in SQLALCHEMY_TYPE_MAP.items() if token in normalized),
        "String",
    )
|
|
|
|
def generate_model_code(
    db_id: str,
    tables: List[Dict[str, Any]],
    schema: Dict[str, Any],
) -> str:
    """Generate SQLAlchemy model code from schema, including foreign keys.

    Args:
        db_id: Database ID (used only in the generated module docstring).
        tables: List of table schemas; each dict has a "name" and a
            "columns" list of dicts with "name", "type", and optionally
            "is_primary_key".
        schema: Full schema dictionary. Each entry of "foreign_keys" is
            assumed to be a pair of (table_index, column_index) pairs:
            [[src_table, src_col], [tgt_table, tgt_col]].
            TODO(review): confirm against the actual Spider schema layout;
            the previous code here was internally inconsistent.

    Returns:
        Generated Python source code as a single string.
    """
    lines = [
        '"""',
        f"SQLAlchemy ORM models for '{db_id}' database.",
        "",
        "Auto-generated from Spider schema dataset.",
        '"""',
        "",
        "from datetime import datetime",
        "from sqlalchemy import Column, Integer, String, Float, Date, DateTime, Boolean, ForeignKey",
        "from sqlalchemy.ext.declarative import declarative_base",
        "from sqlalchemy.orm import relationship",
        "",
        "Base = declarative_base()",
        "",
    ]

    table_names = [t["name"] for t in tables]

    # Generate a model class for each table.
    for table_idx, table in enumerate(tables):
        table_name = table["name"]
        class_name = "".join(word.capitalize() for word in table_name.split("_"))

        lines.append(f"class {class_name}(Base):")
        lines.append(f'    """Model for {table_name} table."""')
        lines.append(f'    __tablename__ = "{table_name}"')
        lines.append("")

        columns = table.get("columns", [])
        for col_idx, col in enumerate(columns):
            col_name = col["name"]
            col_type = col.get("type", "text")
            sa_type = get_sqlalchemy_type(col_type)
            is_pk = col.get("is_primary_key", False)

            # String columns get an explicit length: parsed from the declared
            # type (e.g. "varchar(50)") when present and numeric, otherwise
            # a 255 default.
            length_spec = ""
            if sa_type == "String":
                if "(" in col_type and ")" in col_type:
                    length = col_type.split("(")[1].split(")")[0]
                    if length.isdigit():
                        length_spec = f"({length})"
                else:
                    length_spec = "(255)"  # default

            # Emit a ForeignKey(...) argument when this column is the source
            # side of a foreign-key relationship.
            fk_str = ""
            for fk in schema.get("foreign_keys", []):
                if len(fk) < 2 or not isinstance(fk[0], (list, tuple)):
                    continue
                if tuple(fk[0]) != (table_idx, col_idx):
                    continue
                target_table_idx, target_col_idx = fk[1]
                target_table = table_names[target_table_idx]
                target_col = tables[target_table_idx]["columns"][target_col_idx]["name"]
                fk_str = f', ForeignKey("{target_table}.{target_col}")'

            pk_str = ", primary_key=True" if is_pk else ""
            # Primary keys are non-nullable; everything else defaults to nullable.
            nullable = "False" if is_pk else "True"

            lines.append(
                f"    {col_name} = Column({sa_type}{length_spec}"
                f"{pk_str}{fk_str}, nullable={nullable})"
            )

        lines.append("")

    return "\n".join(lines)
| |
| |
def download_schema_and_generate_models(
    db_id: str = "student_assessment",
    split: str = "train",
    output_dir: str = "data/models",
) -> None:
    """Download Spider schema and generate SQLAlchemy model files.

    Args:
        db_id: Database ID to generate models for, or "all" to generate
            one module per distinct database in the split.
        split: Dataset split ("train" or "validation").
        output_dir: Directory for generated model files (created if missing).
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"Loading Spider schema dataset ({split} split)...")
    dataset = load_dataset("richardr1126/spider-schema", split=split)

    if db_id.lower() == "all":
        # The dataset may repeat a db_id across items; generate models only
        # for the first occurrence of each database.
        processed = set()
        for item in dataset:
            current_db_id = item.get("db_id")
            if current_db_id in processed:
                continue
            processed.add(current_db_id)

            tables = item.get("table", [])
            code = generate_simplified_models(current_db_id, tables)

            filepath = output_path / f"{current_db_id}.py"
            with open(filepath, "w") as f:
                f.write(code)

            print(f" {current_db_id}: {len(tables)} tables → {filepath}")
    else:
        # Filter for the requested db_id; use the first matching record.
        matching = [item for item in dataset if item.get("db_id") == db_id]

        if not matching:
            print(f"No schema found for db_id='{db_id}'")
            return

        item = matching[0]
        tables = item.get("table", [])

        code = generate_simplified_models(db_id, tables)

        filepath = output_path / f"{db_id}.py"
        with open(filepath, "w") as f:
            f.write(code)

        print(f"Found schema for db_id='{db_id}' with {len(tables)} tables")
        print(f"Generated models → {filepath}")
        print(f"\nTables: {', '.join(t['name'] for t in tables)}")
| |
| |
def generate_simplified_models(db_id: str, tables: List[Dict[str, Any]]) -> str:
    """Generate SQLAlchemy models from table schema (simplified version).

    Unlike ``generate_model_code``, this ignores foreign-key relationships
    and primary-key flags: every column is emitted as a plain nullable
    Column.

    Args:
        db_id: Database ID (used only in the generated module docstring).
        tables: List of table definitions; each dict has a "name" and a
            "columns" list of dicts with "name" and "type".

    Returns:
        Generated Python source code as a single string.
    """
    lines = [
        '"""',
        f"SQLAlchemy ORM models for '{db_id}' database.",
        "",
        "Auto-generated from Spider schema dataset.",
        '"""',
        "",
        "from datetime import datetime",
        "from sqlalchemy import Column, Integer, String, Float, Date, DateTime, Boolean, ForeignKey",
        "from sqlalchemy.ext.declarative import declarative_base",
        "from sqlalchemy.orm import relationship",
        "",
        "Base = declarative_base()",
        "",
    ]

    for table in tables:
        table_name = table.get("name", "Unknown")
        class_name = "".join(word.capitalize() for word in table_name.split("_"))

        lines.append("")
        lines.append(f"class {class_name}(Base):")
        lines.append(f'    """Model for {table_name} table."""')
        lines.append(f'    __tablename__ = "{table_name}"')
        lines.append("")

        columns = table.get("columns", [])
        if columns:
            for col in columns:
                col_name = col.get("name", "unknown")
                col_type = col.get("type", "text")
                sa_type = get_sqlalchemy_type(col_type)

                # String columns get an explicit length: parsed from the
                # declared type (e.g. "varchar(50)") when present and
                # numeric, otherwise a 255 default.
                length_spec = ""
                if sa_type == "String":
                    if "(" in col_type and ")" in col_type:
                        length = col_type.split("(")[1].split(")")[0]
                        if length.isdigit():
                            length_spec = f"({length})"
                    else:
                        length_spec = "(255)"  # default

                lines.append(f"    {col_name} = Column({sa_type}{length_spec}, nullable=True)")
        else:
            # Tables with no declared columns still need a primary key so
            # SQLAlchemy accepts the mapped class.
            lines.append("    id = Column(Integer, primary_key=True)")

        lines.append("")

    return "\n".join(lines)
| |
| |
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser( |
| description="Download Spider schema and generate SQLAlchemy models", |
| formatter_class=argparse.RawDescriptionHelpFormatter, |
| ) |
| parser.add_argument( |
| "--db-id", |
| type=str, |
| default="student_assessment", |
| help="Database ID to generate models for (or 'all' for all databases)", |
| ) |
| parser.add_argument( |
| "--split", |
| type=str, |
| default="train", |
| choices=["train", "validation"], |
| help="Schema dataset split to use", |
| ) |
| parser.add_argument( |
| "--output-dir", |
| type=str, |
| default="data/models", |
| help="Directory to save generated model files", |
| ) |
| |
| args = parser.parse_args() |
| download_schema_and_generate_models( |
| db_id=args.db_id, split=args.split, output_dir=args.output_dir |
| ) |
| |