# Source: sql_env/scripts/generate_models_from_schema.py
# (uploaded by hjerpe via huggingface_hub, commit 5dd1bb4)
"""
Script to download Spider schema and auto-generate SQLAlchemy models.
The spider-schema dataset contains detailed database schemas including
table names, column names, types, and relationships. This script
downloads the schema and generates SQLAlchemy ORM models.
Usage:
# Generate models for student_assessment database
python generate_models_from_schema.py --db-id student_assessment
# Generate models for every database in the split
python generate_models_from_schema.py --db-id all --output-dir models/
# Load from validation split
python generate_models_from_schema.py --db-id student_assessment --split validation
"""
import json
import argparse
from pathlib import Path
from typing import Any, Dict, List, Optional
from datasets import load_dataset
# Type mapping from Spider schema to SQLAlchemy type names.
# Keys are lower-case type names (or fragments of them); get_sqlalchemy_type
# tries an exact key match first and then falls back to substring matching,
# so entries are grouped by target type and key order can matter for ties.
SQLALCHEMY_TYPE_MAP = {
    # Integer types
    "number": "Integer",
    "int": "Integer",
    # Floating-point / fixed-point numeric types (REAL and DOUBLE are common
    # SQLite column types and previously fell through to String)
    "float": "Float",
    "double": "Float",
    "real": "Float",
    "decimal": "Float",
    "numeric": "Float",
    # Text types
    "text": "String",
    "string": "String",
    "varchar": "String",
    "char": "String",
    # Temporal types (the generated module only imports Date/DateTime,
    # so "time" is mapped to DateTime rather than Time)
    "date": "Date",
    "datetime": "DateTime",
    "timestamp": "DateTime",
    "time": "DateTime",
    # Boolean types
    "boolean": "Boolean",
    "bool": "Boolean",
}
def get_sqlalchemy_type(col_type: Optional[str]) -> str:
    """Convert a Spider schema column type to a SQLAlchemy type name.

    Matching is case-insensitive and ignores surrounding whitespace. An exact
    key match in ``SQLALCHEMY_TYPE_MAP`` wins. Otherwise the LONGEST key that
    occurs inside the type string is used, so parameterized types such as
    ``"varchar(255)"`` or ``"datetime(6)"`` resolve to ``String`` and
    ``DateTime``, not to a shorter fragment like ``"date"``.

    Args:
        col_type: Column type as written in the schema. ``None`` or an empty
            string is accepted.

    Returns:
        Name of a SQLAlchemy type class (e.g. ``"Integer"``). Defaults to
        ``"String"`` when nothing matches or the type is missing.
    """
    if not col_type:
        return "String"

    col_type_lower = col_type.lower().strip()

    # Exact match
    exact = SQLALCHEMY_TYPE_MAP.get(col_type_lower)
    if exact is not None:
        return exact

    # Substring match, longest key first. sorted() is stable even with
    # reverse=True, so equal-length keys keep the map's own order.
    for key in sorted(SQLALCHEMY_TYPE_MAP, key=len, reverse=True):
        if key in col_type_lower:
            return SQLALCHEMY_TYPE_MAP[key]

    # Default to String
    return "String"
def generate_model_code(
    db_id: str,
    tables: List[Dict[str, Any]],
    schema: Dict[str, Any],
) -> str:
    """Generate SQLAlchemy model code, including foreign keys, from a schema.

    Args:
        db_id: Database ID, used in the generated module docstring.
        tables: List of table schemas. Each table is a dict with a ``"name"``
            key and an optional ``"columns"`` list; each column is a dict with
            a ``"name"`` key and optional ``"type"`` and ``"is_primary_key"``
            keys.
        schema: Full schema dictionary. Only ``schema["foreign_keys"]`` is
            read. Each entry's first element is the source
            ``(table_index, column_index)`` pair (list or tuple). The second
            element is either a target ``(table_index, column_index)`` pair,
            or a bare target table index with an optional target column
            index as a third element (column 0 when omitted). Entries whose
            source is not a two-element pair are ignored.

    Returns:
        Generated Python source code as a string.

    Raises:
        KeyError / IndexError: if a foreign key points at a table or column
            index that does not exist in ``tables``.
    """
    lines = [
        '"""',
        f"SQLAlchemy ORM models for '{db_id}' database.",
        "",
        "Auto-generated from Spider schema dataset.",
        '"""',
        "",
        "from datetime import datetime",
        "from sqlalchemy import Column, Integer, String, Float, Date, DateTime, Boolean, ForeignKey",
        "from sqlalchemy.ext.declarative import declarative_base",
        "from sqlalchemy.orm import relationship",
        "",
        "Base = declarative_base()",
        "",
    ]

    # Index foreign keys once by their source (table_idx, col_idx) pair.
    # Sources are normalized to tuples because schemas loaded from JSON hold
    # lists, and a list never compares equal to a tuple.
    fk_targets: Dict[tuple, tuple] = {}
    for fk in schema.get("foreign_keys", []):
        source = fk[0]
        if not isinstance(source, (list, tuple)) or len(source) != 2:
            continue  # not a (table, column) pair; format not supported here
        target = fk[1]
        if isinstance(target, (list, tuple)):
            resolved = (target[0], target[1])
        else:
            resolved = (target, fk[2] if len(fk) > 2 else 0)
        fk_targets.setdefault(tuple(source), resolved)  # first entry wins

    for table_idx, table in enumerate(tables):
        table_name = table["name"]
        class_name = "".join(word.capitalize() for word in table_name.split("_"))
        lines.append("")
        lines.append(f"class {class_name}(Base):")
        lines.append(f'    """Model for {table_name} table."""')
        lines.append(f'    __tablename__ = "{table_name}"')
        lines.append("")

        for col_idx, col in enumerate(table.get("columns", [])):
            col_name = col["name"]
            col_type = col.get("type") or "text"
            sa_type = get_sqlalchemy_type(col_type)

            # Only String takes a length: use "(N)" when the schema type
            # carries a numeric length, nothing for non-numeric parameters,
            # and a 255 default when no length is given at all.
            length_spec = ""
            if sa_type == "String":
                if "(" in col_type and ")" in col_type:
                    length = col_type.split("(")[1].split(")")[0]
                    if length.isdigit():
                        length_spec = f"({length})"
                else:
                    length_spec = "(255)"

            is_pk = col.get("is_primary_key", False)

            fk_str = ""
            target = fk_targets.get((table_idx, col_idx))
            if target is not None:
                target_table_idx, target_col_idx = target
                target_table = tables[target_table_idx]["name"]
                target_col = tables[target_table_idx]["columns"][target_col_idx]["name"]
                fk_str = f', ForeignKey("{target_table}.{target_col}")'

            # Primary keys are never nullable; other columns default to nullable.
            nullable = "False" if is_pk else "True"
            pk_str = ", primary_key=True" if is_pk else ""
            # ForeignKey(...) is positional, so it must precede the keyword
            # arguments or the generated code is a SyntaxError.
            lines.append(
                f"    {col_name} = Column({sa_type}{length_spec}{fk_str}{pk_str}, nullable={nullable})"
            )
        lines.append("")

    return "\n".join(lines)
def download_schema_and_generate_models(
    db_id: str = "student_assessment",
    split: str = "train",
    output_dir: str = "data/models",
) -> None:
    """Download the Spider schema dataset and write SQLAlchemy model files.

    One file named ``<db_id>.py`` is written per database. The models are
    rendered by ``generate_simplified_models`` from each row's ``"table"``
    field.

    Args:
        db_id: Database ID to generate models for, or ``"all"`` (any case)
            to write one file per distinct ``db_id`` in the split. For
            ``"all"``, the first row seen for each ``db_id`` is used.
        split: Dataset split ("train" or "validation").
        output_dir: Directory to save generated model files. It is created
            if missing.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"Loading Spider schema dataset ({split} split)...")
    dataset = load_dataset("richardr1126/spider-schema", split=split)

    def _write_models(target_db_id: str, tables: List[Dict[str, Any]]) -> Path:
        """Render models for one database, write them as UTF-8, and return the path."""
        code = generate_simplified_models(target_db_id, tables)
        filepath = output_path / f"{target_db_id}.py"
        filepath.write_text(code, encoding="utf-8")
        return filepath

    if db_id.lower() == "all":
        # Generate models for all databases, once per db_id.
        processed = set()
        for item in dataset:
            current_db_id = item.get("db_id")
            if current_db_id in processed:
                continue
            processed.add(current_db_id)

            tables = item.get("table", [])
            filepath = _write_models(current_db_id, tables)
            print(f"  {current_db_id}: {len(tables)} tables → {filepath}")
        return

    # Stop at the first matching row instead of collecting every match.
    item = next((row for row in dataset if row.get("db_id") == db_id), None)
    if item is None:
        print(f"No schema found for db_id='{db_id}'")
        return

    tables = item.get("table", [])
    filepath = _write_models(db_id, tables)

    print(f"Found schema for db_id='{db_id}' with {len(tables)} tables")
    print(f"Generated models → {filepath}")
    print(f"\nTables: {', '.join(t['name'] for t in tables)}")
def generate_simplified_models(db_id: str, tables: List[Dict[str, Any]]) -> str:
    """Generate SQLAlchemy models from table definitions (no foreign keys).

    Args:
        db_id: Database ID, used in the generated module docstring.
        tables: List of table definitions. Each is a dict with optional
            ``"name"`` and ``"columns"`` keys. Each column is a dict with
            optional ``"name"``, ``"type"`` and ``"is_primary_key"`` keys.

    Returns:
        Generated Python source code as a string.
    """
    lines = [
        '"""',
        f"SQLAlchemy ORM models for '{db_id}' database.",
        "",
        "Auto-generated from Spider schema dataset.",
        '"""',
        "",
        "from datetime import datetime",
        "from sqlalchemy import Column, Integer, String, Float, Date, DateTime, Boolean, ForeignKey",
        "from sqlalchemy.ext.declarative import declarative_base",
        "from sqlalchemy.orm import relationship",
        "",
        "Base = declarative_base()",
        "",
    ]

    for table in tables:
        table_name = table.get("name", "Unknown")
        class_name = "".join(word.capitalize() for word in table_name.split("_"))
        lines += [
            "",
            f"class {class_name}(Base):",
            f'    """Model for {table_name} table."""',
            f'    __tablename__ = "{table_name}"',
            "",
        ]

        columns = table.get("columns", [])
        if not columns:
            # No column information: emit a surrogate key so the class maps.
            lines.append("    id = Column(Integer, primary_key=True)")
        else:
            has_pk = False
            for col in columns:
                col_name = col.get("name", "unknown")
                col_type = col.get("type") or "text"
                sa_type = get_sqlalchemy_type(col_type)

                # Determine string length from type if specified
                length_spec = ""
                if sa_type == "String":
                    if "(" in col_type and ")" in col_type:
                        length = col_type.split("(")[1].split(")")[0]
                        if length.isdigit():
                            length_spec = f"({length})"
                    else:
                        length_spec = "(255)"  # default

                if col.get("is_primary_key", False):
                    has_pk = True
                    lines.append(
                        f"    {col_name} = Column({sa_type}{length_spec}, primary_key=True, nullable=False)"
                    )
                else:
                    lines.append(f"    {col_name} = Column({sa_type}{length_spec}, nullable=True)")

            if not has_pk:
                # SQLAlchemy's declarative mapping refuses tables without a
                # primary key, so flag it in the generated code.
                lines.append(
                    "    # TODO: schema lists no primary key; SQLAlchemy declarative mapping requires one."
                )
        lines.append("")

    return "\n".join(lines)
if __name__ == "__main__":
    # Command-line entry point: parse the flags, then hand them straight to
    # the generator (each flag's dest matches a parameter name).
    cli = argparse.ArgumentParser(
        description="Download Spider schema and generate SQLAlchemy models",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    cli.add_argument(
        "--db-id",
        type=str,
        default="student_assessment",
        help="Database ID to generate models for (or 'all' for all databases)",
    )
    cli.add_argument(
        "--split",
        type=str,
        default="train",
        choices=["train", "validation"],
        help="Schema dataset split to use",
    )
    cli.add_argument(
        "--output-dir",
        type=str,
        default="data/models",
        help="Directory to save generated model files",
    )
    opts = cli.parse_args()
    # vars(opts) == {"db_id": ..., "split": ..., "output_dir": ...}
    download_schema_and_generate_models(**vars(opts))