LawBot / src /apps /database.py
Vishwanath77's picture
Update src/apps/database.py
a2f131f verified
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, declarative_base
import os
from dotenv import load_dotenv
load_dotenv()
def normalize_database_url(url):
    """Return *url* rewritten so SQLAlchemy's asyncpg driver accepts it.

    Two adjustments are made:

    * A ``postgres://`` or ``postgresql://`` scheme is replaced with
      ``postgresql+asyncpg://``. URLs with any other scheme are left alone.
    * A ``sslmode=...`` query parameter is removed, because asyncpg does not
      support it in the connection string. Every other query parameter is
      kept byte-for-byte, in its original order.

    Args:
        url: The connection URL, e.g. the value of ``DATABASE_URL``.

    Returns:
        The adjusted URL string.
    """
    # Convert postgres:// / postgresql:// to postgresql+asyncpg:// if needed.
    for scheme in ("postgres://", "postgresql://"):
        if url.startswith(scheme):
            url = "postgresql+asyncpg://" + url[len(scheme):]
            break

    base, _, query = url.partition("?")
    pairs = query.split("&")
    kept = [pair for pair in pairs if pair.partition("=")[0] != "sslmode"]
    if len(kept) == len(pairs):
        # No sslmode parameter present: hand the URL back untouched.
        return url

    # NOTE(review): dropping sslmode means the requested TLS mode is not
    # forwarded to the driver -- confirm whether it should instead be mapped
    # to the driver's own SSL option.
    kept = [pair for pair in kept if pair]
    if not kept:
        return base
    return base + "?" + "&".join(kept)


# Support both DATABASE_URL (Supabase/production) and individual env vars (local dev)
DATABASE_URL = os.getenv("DATABASE_URL")

if DATABASE_URL:
    # Use DATABASE_URL if provided (Supabase/Vercel)
    DATABASE_URL = normalize_database_url(DATABASE_URL)
else:
    # Fall back to individual environment variables (local development)
    DB_USER = os.getenv("DB_USER", "postgres")
    DB_PASSWORD = os.getenv("DB_PASSWORD", "password")
    DB_HOST = os.getenv("DB_HOST", "localhost")
    DB_PORT = os.getenv("DB_PORT", "5432")
    DB_NAME = os.getenv("DB_NAME", "lawbot_db")
    DATABASE_URL = f"postgresql+asyncpg://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
# SQL statement logging stays enabled by default (the previous hard-coded
# behaviour) but can be switched off per environment, e.g. DB_ECHO=false,
# so production logs are not flooded with every statement.
_ECHO_SQL = os.getenv("DB_ECHO", "true").strip().lower() in {"1", "true", "yes", "on"}

# Module-level async engine shared by every session created in this file.
engine = create_async_engine(
    DATABASE_URL,
    echo=_ECHO_SQL,
    pool_pre_ping=True,  # Check connection alive before using
    pool_recycle=300,  # Refresh connection every 5 minutes
)
# Session factory bound to the module-level engine; each call produces a new
# AsyncSession. expire_on_commit=False is passed so that objects are not
# expired when the session commits.
AsyncSessionLocal = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
# Declarative base class produced by SQLAlchemy's declarative_base().
# Presumably the project's ORM models (defined elsewhere) subclass it -- verify.
Base = declarative_base()
async def get_db():
    """Yield one database session, scoped to the lifetime of this generator.

    Async generator: a session is created from ``AsyncSessionLocal``, yielded
    exactly once, and the surrounding ``async with`` block is exited when the
    generator is resumed or closed.

    NOTE(review): presumably consumed as a framework dependency (one session
    per request) -- confirm against the callers.
    """
    async with AsyncSessionLocal() as db_session:
        yield db_session