Spaces:
Sleeping
Sleeping
| import tempfile | |
| from contextlib import asynccontextmanager, contextmanager | |
| from typing import Annotated, Generator | |
| from apscheduler.schedulers.background import BackgroundScheduler | |
| from fastapi import Depends, FastAPI, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse, RedirectResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from huggingface_hub import ( | |
| OAuthInfo, | |
| attach_huggingface_oauth, | |
| create_repo, | |
| parse_huggingface_oauth, | |
| snapshot_download, | |
| upload_folder, | |
| ) | |
| from sqlalchemy.engine import Engine | |
| from sqlmodel import Session, SQLModel, create_engine | |
| from . import constants | |
| from .parquet import export_to_parquet, import_from_parquet | |
# Process-wide engine, created lazily by get_engine() and disposed/reset
# during application shutdown (see _database_lifespan).
_ENGINE_SINGLETON: Engine | None = None
| # OAuth utilities | |
async def _oauth_info_optional(request: Request) -> OAuthInfo | None:
    """Dependency: OAuth info for the current request, or None if not signed in."""
    return parse_huggingface_oauth(request)
async def _oauth_info_required(request: Request) -> OAuthInfo:
    """Dependency: OAuth info for the current request.

    Raises:
        HTTPException: 401 when the request carries no Hugging Face OAuth session.
    """
    info = parse_huggingface_oauth(request)
    if info is not None:
        return info
    raise HTTPException(
        status_code=401, detail="Unauthorized. Please Sign in with Hugging Face."
    )
# Dependency aliases for endpoint signatures: OptionalOAuth resolves to None
# for anonymous requests, while RequiredOAuth rejects them with a 401
# (see _oauth_info_required).
OptionalOAuth = Annotated[OAuthInfo | None, Depends(_oauth_info_optional)]
RequiredOAuth = Annotated[OAuthInfo, Depends(_oauth_info_required)]
def get_engine() -> Engine:
    """Return the shared database engine, creating it on first use."""
    global _ENGINE_SINGLETON
    engine = _ENGINE_SINGLETON
    if engine is None:
        # First call: build the engine from the configured database URL and
        # cache it for the rest of the process lifetime.
        engine = create_engine(constants.DATABASE_URL)
        _ENGINE_SINGLETON = engine
    return engine
def get_session() -> Generator[Session, None, None]:
    """Dependency that yields a database session bound to the shared engine.

    The session is closed automatically when the request handler finishes.
    """
    with Session(get_engine()) as session:
        yield session
@asynccontextmanager
async def _database_lifespan(app: FastAPI):
    """Handle database lifespan.

    1. Initialize database tables.
    2. If backup is enabled:
       a. Try to load backup from remote dataset. If it fails, reset the local
          database for a fresh start (and skip the import).
       b. Start the back-up scheduler.
    3. Yield control to the FastAPI app.
    4. On shutdown: stop the scheduler, force-push a final backup, then
       dispose the database engine.
    """
    scheduler = BackgroundScheduler()
    engine = get_engine()
    SQLModel.metadata.create_all(engine)
    repo_id: str | None = None
    if constants.BACKUP_DB:
        print("Back-up database is enabled")
        # Create remote dataset if it doesn't exist
        repo_url = create_repo(
            repo_id=constants.BACKUP_DATASET_ID,  # type: ignore[arg-type]
            repo_type="dataset",
            token=constants.HF_TOKEN,
            private=True,
            exist_ok=True,
        )
        print(f"Backup dataset: {repo_url}")
        repo_id = repo_url.repo_id
        # Try to load backup from remote dataset
        print("Trying to load backup from remote dataset...")
        try:
            backup_dir = snapshot_download(
                repo_id=repo_id,
                repo_type="dataset",
                token=constants.HF_TOKEN,
                allow_patterns="*.parquet",
            )
        except Exception:
            # No usable backup: reset the local database to prevent confusion.
            # NOTE: the import must NOT run in this path — `backup_dir` is
            # unbound when the download failed (previously raised NameError).
            print("Couldn't find backup in remote dataset.")
            print("Deleting local database for a fresh start.")
            SQLModel.metadata.drop_all(engine)
            SQLModel.metadata.create_all(engine)
        else:
            # Import parquet files to database only when a backup was fetched.
            import_from_parquet(engine, backup_dir)
        # Start back-up scheduler
        scheduler.add_job(_backup_to_hub, args=[repo_id], trigger="interval", minutes=5)
        scheduler.start()
    yield
    # Shutdown: stop periodic backups and push one final snapshot while the
    # engine is still alive (disposing first would force _backup_to_hub to
    # recreate a throwaway engine).
    if constants.BACKUP_DB and repo_id is not None:
        scheduler.shutdown(wait=False)
        print("Pushing backup to remote dataset...")
        _backup_to_hub(repo_id)
    print("Closing database...")
    global _ENGINE_SINGLETON
    if _ENGINE_SINGLETON is not None:
        _ENGINE_SINGLETON.dispose()
        _ENGINE_SINGLETON = None
def _backup_to_hub(repo_id: str) -> None:
    """Export the database to parquet files and push them to the backup dataset."""
    with tempfile.TemporaryDirectory() as export_dir:
        # Dump every table as parquet, then mirror the folder to the Hub,
        # deleting parquet files left over from previous backups.
        export_to_parquet(get_engine(), export_dir)
        upload_folder(
            repo_id=repo_id,
            repo_type="dataset",
            folder_path=export_dir,
            token=constants.HF_TOKEN,
            allow_patterns="*.parquet",
            delete_patterns=["*.parquet"],
            commit_message="Backup database as parquet",
        )
def create_app() -> FastAPI:
    """Build the FastAPI app: CORS, frontend serving/redirect, and HF OAuth."""
    # FastAPI app
    app = FastAPI(lifespan=_database_lifespan)
    # Set CORS headers
    app.add_middleware(
        CORSMiddleware,
        allow_origins=[
            # Can't use "*" because frontend doesn't like it with "credentials: true"
            "http://localhost:5173",
            "http://0.0.0.0:9481",
            "http://localhost:9481",
            "http://127.0.0.1:9481",
        ],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    # Mount frontend from dist directory (if configured)
    if constants.SERVE_FRONTEND:
        # Production => Serve frontend from dist directory
        app.mount(
            "/assets",
            StaticFiles(directory=constants.FRONTEND_ASSETS_PATH),  # type: ignore[invalid-argument-type]
            name="assets",
        )

        # Bug fix: the handler was defined but never registered, so "/" 404'd.
        @app.get("/", include_in_schema=False)
        async def serve_frontend():
            return FileResponse(constants.FRONTEND_INDEX_PATH)  # type: ignore[invalid-argument-type]

    else:
        # Development => Redirect to dev frontend
        # Bug fix: same as above — the handler was never attached to a route.
        @app.get("/", include_in_schema=False)
        async def redirect_to_frontend():
            return RedirectResponse("http://localhost:5173/")

    # Set up Hugging Face OAuth
    # To get OAuthInfo in an endpoint
    attach_huggingface_oauth(app)
    return app