File size: 1,832 Bytes
734a827
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""
Database initialization and session management
"""
import os
from typing import Iterator
from urllib.parse import urlparse, parse_qs

from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import StaticPool

from models import Base

# Database URL from environment or use SQLite
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./hf_uploader.db")

# Parse and fix the database URL for TiDB Cloud
if DATABASE_URL and DATABASE_URL.startswith("mysql://"):
    # Parse the URL
    parsed = urlparse(DATABASE_URL)
    
    # Build a clean URL without SSL parameters
    clean_url = f"mysql+pymysql://{parsed.username}:{parsed.password}@{parsed.hostname}:{parsed.port or 3306}{parsed.path}"
    
    # Add SSL parameters properly formatted for PyMySQL
    clean_url += "?ssl_verify_cert=false&ssl_verify_identity=false"
    
    DATABASE_URL = clean_url

# Create the engine with backend-specific configuration.
# startswith() rather than a substring test: a MySQL URL whose password,
# host or database name merely contains "sqlite" must not take this branch.
if DATABASE_URL.startswith("sqlite"):
    engine = create_engine(
        DATABASE_URL,
        # sqlite3 refuses to use a connection from a thread other than the one
        # that created it unless this check is disabled.
        connect_args={"check_same_thread": False},
        # NOTE(review): StaticPool hands every checkout the same single DBAPI
        # connection. An in-memory database needs that, but with the
        # file-backed default it also means concurrent sessions share one
        # connection — confirm this is intended.
        poolclass=StaticPool,
        echo=False,
    )

    @event.listens_for(engine, "connect")
    def set_sqlite_pragma(dbapi_conn, connection_record):
        """Turn on foreign-key enforcement for each new SQLite connection.

        SQLite leaves foreign-key constraints unenforced unless this pragma is
        set on the connection.
        """
        cursor = dbapi_conn.cursor()
        cursor.execute("PRAGMA foreign_keys=ON")
        cursor.close()
else:
    engine = create_engine(
        DATABASE_URL,
        echo=False,
        # Test each pooled connection on checkout and replace connections
        # older than an hour, so ones dropped server-side are not handed out.
        pool_pre_ping=True,
        pool_recycle=3600,
    )

# Factory for ORM sessions bound to the module-level engine. Sessions neither
# autocommit nor autoflush, so writes are persisted only on explicit commit.
SessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
)


def init_db():
    """Create, on the module engine, every table registered on ``Base.metadata``.

    Tables that already exist are left as they are.
    """
    metadata = Base.metadata
    metadata.create_all(engine)


def get_db() -> Iterator[Session]:
    """Yield a fresh database session and close it when the consumer is done.

    This is a generator, not a plain function, so the old ``-> Session``
    annotation was wrong. Presumably it is used as a dependency-injection hook
    (e.g. FastAPI ``Depends``) — confirm against callers. The session is
    created on the first ``next()``. It is closed in ``finally`` when the
    generator is exhausted, closed, or has an exception thrown into it.

    Yields:
        A new ``Session`` produced by ``SessionLocal``.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        # Runs even when the consumer raised, so the session is never leaked.
        db.close()