Spaces:
Running
Running
| from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession | |
| from sqlalchemy.orm import sessionmaker, declarative_base | |
| import os | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| # Support both DATABASE_URL (Supabase/production) and individual env vars (local dev) | |
| DATABASE_URL = os.getenv("DATABASE_URL") | |
| if DATABASE_URL: | |
| # Use DATABASE_URL if provided (Supabase/Vercel) | |
| # Convert postgres:// to postgresql+asyncpg:// if needed | |
| if DATABASE_URL.startswith("postgres://"): | |
| DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql+asyncpg://", 1) | |
| elif DATABASE_URL.startswith("postgresql://"): | |
| DATABASE_URL = DATABASE_URL.replace("postgresql://", "postgresql+asyncpg://", 1) | |
| # --- ADD THIS FIX --- | |
| # Strip sslmode if present (asyncpg doesn't support it in the connection string) | |
| if "sslmode=" in DATABASE_URL: | |
| DATABASE_URL = DATABASE_URL.split("?")[0] | |
| # -------------------- | |
| else: | |
| # Fall back to individual environment variables (local development) | |
| DB_USER = os.getenv("DB_USER", "postgres") | |
| DB_PASSWORD = os.getenv("DB_PASSWORD", "password") | |
| DB_HOST = os.getenv("DB_HOST", "localhost") | |
| DB_PORT = os.getenv("DB_PORT", "5432") | |
| DB_NAME = os.getenv("DB_NAME", "lawbot_db") | |
| DATABASE_URL = f"postgresql+asyncpg://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}" | |
# SQL statement logging stays on by default (the previous hard-coded behaviour).
# Set SQL_ECHO to "false"/"0"/"no"/"off" to silence it, e.g. in production,
# where logged statements and parameters may contain user data.
_SQL_ECHO = os.getenv("SQL_ECHO", "true").strip().lower() in ("1", "true", "yes", "on")

# NOTE(review): if DATABASE_URL points at a transaction-mode pooler (e.g. a
# pgbouncer endpoint), asyncpg's prepared-statement cache may need to be
# disabled via connect_args - confirm against the deployment.
engine = create_async_engine(
    DATABASE_URL,
    echo=_SQL_ECHO,
    pool_pre_ping=True,  # test each connection is alive before handing it out
    pool_recycle=300,  # replace pooled connections older than 300 s (5 min) on checkout
)
# Session factory: each call produces a fresh AsyncSession bound to `engine`.
# expire_on_commit=False keeps already-loaded attributes readable after a
# commit instead of expiring them (a later refresh would need an await).
AsyncSessionLocal = sessionmaker(
    engine,
    expire_on_commit=False,
    class_=AsyncSession,
)

# Declarative base class for the ORM models.
Base = declarative_base()
async def get_db():
    """Yield a single database session and close it when the caller is done.

    This is an async generator. It yields exactly one ``AsyncSession`` created
    by ``AsyncSessionLocal``. The ``async with`` block closes that session once
    the generator is resumed or closed after the ``yield``.

    NOTE(review): presumably consumed as a FastAPI ``Depends`` dependency -
    confirm with the callers.
    """
    session_factory = AsyncSessionLocal
    async with session_factory() as db_session:
        yield db_session