| import tempfile |
| from contextlib import asynccontextmanager, contextmanager |
| from typing import Annotated, Generator |
|
|
| from apscheduler.schedulers.background import BackgroundScheduler |
| from fastapi import Depends, FastAPI, HTTPException, Request |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import FileResponse, RedirectResponse |
| from fastapi.staticfiles import StaticFiles |
| from huggingface_hub import ( |
| OAuthInfo, |
| attach_huggingface_oauth, |
| create_repo, |
| parse_huggingface_oauth, |
| snapshot_download, |
| upload_folder, |
| ) |
| from sqlalchemy.engine import Engine |
| from sqlmodel import Session, SQLModel, create_engine |
|
|
| from . import constants |
| from .parquet import export_to_parquet, import_from_parquet |
|
|
# Process-wide engine, created lazily by get_engine() and disposed/reset in
# _database_lifespan() on shutdown.
_ENGINE_SINGLETON: Engine | None = None
|
|
| |
|
|
|
|
async def _oauth_info_optional(request: Request) -> OAuthInfo | None:
    """Return the Hugging Face OAuth info for this request, or None when absent."""
    oauth_info = parse_huggingface_oauth(request)
    return oauth_info
|
|
|
|
async def _oauth_info_required(request: Request) -> OAuthInfo:
    """Return the Hugging Face OAuth info for this request.

    Raises:
        HTTPException: 401 when the request carries no OAuth session.
    """
    oauth_info = parse_huggingface_oauth(request)
    if oauth_info is not None:
        return oauth_info
    raise HTTPException(
        status_code=401, detail="Unauthorized. Please Sign in with Hugging Face."
    )
|
|
|
|
# FastAPI dependency aliases: OptionalOAuth resolves to None for anonymous
# requests; RequiredOAuth rejects them with a 401 response.
OptionalOAuth = Annotated[OAuthInfo | None, Depends(_oauth_info_optional)]
RequiredOAuth = Annotated[OAuthInfo, Depends(_oauth_info_required)]
|
|
|
|
def get_engine() -> Engine:
    """Return the shared database engine, creating it on first use."""
    global _ENGINE_SINGLETON
    if _ENGINE_SINGLETON is not None:
        return _ENGINE_SINGLETON
    _ENGINE_SINGLETON = create_engine(constants.DATABASE_URL)
    return _ENGINE_SINGLETON
|
|
|
|
@contextmanager
def get_session() -> Generator[Session, None, None]:
    """Yield a database session bound to the shared engine, closing it on exit."""
    with Session(get_engine()) as db_session:
        yield db_session
|
|
|
|
@asynccontextmanager
async def _database_lifespan(app: FastAPI):
    """Handle database lifespan.

    1. If backup is enabled,
       a. Try to load backup from remote dataset. If it fails, delete local database for a fresh start.
       b. Start back-up scheduler.
    2. If disabled, create local database file or reuse existing one.
    3. Initialize database.
    4. Yield control to FastAPI app.
    5. Close database + force push backup to remote dataset.
    """
    scheduler = BackgroundScheduler()
    engine = get_engine()
    SQLModel.metadata.create_all(engine)

    if constants.BACKUP_DB:
        print("Back-up database is enabled")

        # Ensure the private backup dataset exists (no-op when it already does).
        repo_url = create_repo(
            repo_id=constants.BACKUP_DATASET_ID,
            repo_type="dataset",
            token=constants.HF_TOKEN,
            private=True,
            exist_ok=True,
        )
        print(f"Backup dataset: {repo_url}")
        repo_id = repo_url.repo_id

        print("Trying to load backup from remote dataset...")
        try:
            backup_dir = snapshot_download(
                repo_id=repo_id,
                repo_type="dataset",
                token=constants.HF_TOKEN,
                allow_patterns="*.parquet",
            )
        except Exception:
            # No usable backup (e.g. brand-new or empty dataset): reset the
            # local database so we start from a clean schema.
            print("Couldn't find backup in remote dataset.")
            print("Deleting local database for a fresh start.")
            SQLModel.metadata.drop_all(engine)
            SQLModel.metadata.create_all(engine)
        else:
            # Only import when the download succeeded — `backup_dir` is
            # unbound on the failure path and must not be referenced there.
            import_from_parquet(get_engine(), backup_dir)

        # Periodically push a fresh backup while the app is running.
        scheduler.add_job(_backup_to_hub, args=[repo_id], trigger="interval", minutes=5)
        scheduler.start()

    yield

    # Stop the periodic backup job before tearing down the engine so a late
    # job cannot fire against a disposed engine.
    if scheduler.running:
        scheduler.shutdown()

    print("Closing database...")
    global _ENGINE_SINGLETON
    if _ENGINE_SINGLETON is not None:
        _ENGINE_SINGLETON.dispose()
        _ENGINE_SINGLETON = None

    if constants.BACKUP_DB:
        print("Pushing backup to remote dataset...")
        _backup_to_hub(repo_id)
|
|
|
|
def _backup_to_hub(repo_id: str) -> None:
    """Export backup to remote dataset as parquet files."""
    with tempfile.TemporaryDirectory() as export_dir:
        # Dump every table into a scratch directory as parquet...
        export_to_parquet(get_engine(), export_dir)
        # ...and push it while the directory still exists; stale parquet files
        # on the remote side are removed via delete_patterns.
        upload_folder(
            folder_path=export_dir,
            repo_id=repo_id,
            repo_type="dataset",
            token=constants.HF_TOKEN,
            allow_patterns="*.parquet",
            delete_patterns=["*.parquet"],
            commit_message="Backup database as parquet",
        )
|
|
|
|
def create_app() -> FastAPI:
    """Build the FastAPI application: CORS, frontend serving, and HF OAuth."""
    app = FastAPI(lifespan=_database_lifespan)

    # Origins allowed to call the API with credentials (Vite dev server and
    # local deployments).
    allowed_origins = [
        "http://localhost:5173",
        "http://0.0.0.0:9481",
        "http://localhost:9481",
        "http://127.0.0.1:9481",
    ]
    app.add_middleware(
        CORSMiddleware,
        allow_origins=allowed_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    if constants.SERVE_FRONTEND:
        # Serve the built frontend from this process.
        app.mount(
            "/assets",
            StaticFiles(directory=constants.FRONTEND_ASSETS_PATH),
            name="assets",
        )

        @app.get("/")
        async def serve_frontend():
            return FileResponse(constants.FRONTEND_INDEX_PATH)

    else:
        # Frontend runs separately (Vite dev server); send visitors there.
        @app.get("/")
        async def redirect_to_frontend():
            return RedirectResponse("http://localhost:5173/")

    # Wires up the Sign-in-with-Hugging-Face flow consumed by the OAuth
    # dependencies above.
    attach_huggingface_oauth(app)

    return app
|
|