Spaces:
Running
Running
File size: 1,778 Bytes
51c9156 a2f131f 51c9156 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 | from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, declarative_base
import os
from dotenv import load_dotenv
# Load variables from a .env file (when one is found) into the process
# environment so the os.getenv() lookups that follow can see them.
load_dotenv()
from urllib.parse import parse_qsl, quote, urlencode, urlsplit, urlunsplit


def to_asyncpg_url(url):
    """Rewrite a libpq-style Postgres URL so SQLAlchemy uses the asyncpg driver.

    ``postgres://`` and ``postgresql://`` schemes become
    ``postgresql+asyncpg://``; any other scheme is left as-is.

    asyncpg does not accept libpq's ``sslmode`` query parameter, but it takes
    the same mode names ("disable", "prefer", "require", "verify-full", ...)
    through its ``ssl`` argument. When ``sslmode`` is present, the query string
    is therefore replaced by ``ssl=<mode>`` so the requested TLS mode is kept
    instead of being silently dropped. Other query parameters are discarded in
    that case, as before, because libpq-only or provider-specific options
    (e.g. ``supa=...``) would be rejected by asyncpg. URLs without ``sslmode``
    keep their query string untouched.

    Args:
        url: Database URL as provided by the environment.

    Returns:
        The URL rewritten for the ``postgresql+asyncpg`` dialect.
    """
    for prefix in ("postgres://", "postgresql://"):
        if url.startswith(prefix):
            url = "postgresql+asyncpg://" + url[len(prefix):]
            break

    parts = urlsplit(url)
    ssl_modes = [
        value
        for key, value in parse_qsl(parts.query, keep_blank_values=True)
        if key == "sslmode"
    ]
    if not ssl_modes:
        return url

    # Last occurrence wins, mirroring how repeated query keys usually resolve.
    # An empty sslmode carries no information, so just drop the query.
    mode = ssl_modes[-1]
    query = urlencode({"ssl": mode}) if mode else ""
    return urlunsplit(parts._replace(query=query))


# Support both DATABASE_URL (Supabase/production) and individual env vars
# (local dev).
DATABASE_URL = os.getenv("DATABASE_URL")
if DATABASE_URL:
    # Use DATABASE_URL if provided (Supabase/Vercel), adapted for asyncpg.
    DATABASE_URL = to_asyncpg_url(DATABASE_URL)
else:
    # Fall back to individual environment variables (local development).
    DB_USER = os.getenv("DB_USER", "postgres")
    DB_PASSWORD = os.getenv("DB_PASSWORD", "password")
    DB_HOST = os.getenv("DB_HOST", "localhost")
    DB_PORT = os.getenv("DB_PORT", "5432")
    DB_NAME = os.getenv("DB_NAME", "lawbot_db")
    # Percent-encode credentials: characters such as "@", ":", "/" or "#"
    # in a raw password would otherwise corrupt the URL structure.
    DATABASE_URL = (
        f"postgresql+asyncpg://{quote(DB_USER, safe='')}:"
        f"{quote(DB_PASSWORD, safe='')}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
    )
def _env_flag(name, default):
    """Return the boolean value of environment variable *name*.

    Returns *default* when the variable is unset. Otherwise "1", "true",
    "yes" and "on" (case-insensitive, surrounding whitespace ignored) are
    True and any other value is False.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


# DB_ECHO toggles SQLAlchemy's SQL statement logging (echo). It defaults to on
# to match the previous hard-coded echo=True; set DB_ECHO=false in production
# to stop logging every query.
engine = create_async_engine(
    DATABASE_URL,
    echo=_env_flag("DB_ECHO", True),
    pool_pre_ping=True,  # Check connection alive before using
    pool_recycle=300,  # Refresh connection every 5 minutes
)
# Factory producing AsyncSession objects bound to `engine`.
# expire_on_commit=False keeps loaded ORM attributes usable after commit()
# rather than expiring them.
AsyncSessionLocal = sessionmaker(
    engine,
    expire_on_commit=False,
    class_=AsyncSession,
)

# Declarative base class that ORM models inherit from.
Base = declarative_base()
async def get_db():
    """Yield a single AsyncSession and close it once the consumer is done.

    Written as an async generator; presumably consumed as a request-scoped
    dependency (e.g. FastAPI ``Depends``) — confirm against callers.
    """
    session_factory_result = AsyncSessionLocal()
    async with session_factory_result as db_session:
        yield db_session