| | import asyncio |
| | import os |
| | import signal |
| | import sqlite3 |
| | from contextlib import asynccontextmanager |
| |
|
| | import psutil |
| | from fastapi import FastAPI |
| | from loguru import logger |
| |
|
| | from competitions.utils import run_evaluation |
| |
|
| |
|
def get_process_status(pid):
    """Return the status psutil reports for the process with the given PID.

    When psutil raises ``NoSuchProcess`` the miss is logged and the string
    "Completed" is returned instead. No other exception is handled here.
    """
    try:
        return psutil.Process(pid).status()
    except psutil.NoSuchProcess:
        logger.info(f"No process found with PID: {pid}")
        return "Completed"
| |
|
| |
|
def kill_process_by_pid(pid):
    """Send SIGTERM to the process identified by ``pid``.

    Nothing is caught here, so any error raised by ``os.kill`` reaches the
    caller.
    """
    termination_signal = signal.SIGTERM
    os.kill(pid, termination_signal)
| |
|
| |
|
class JobDB:
    """Small SQLite-backed table of job process IDs.

    The database holds a single ``jobs`` table with an auto-assigned ``id``
    and an integer ``pid`` column. One connection and one cursor are opened
    in the constructor and reused by every method.
    """

    def __init__(self, db_path):
        """Open (or create) the database at ``db_path`` and ensure the table exists."""
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        self.c = self.conn.cursor()
        self.create_jobs_table()

    def create_jobs_table(self):
        """Create the ``jobs`` table if it is not already present."""
        # The connection context manager commits on success and rolls back
        # if the statement raises.
        with self.conn:
            self.c.execute(
                """CREATE TABLE IF NOT EXISTS jobs
                (id INTEGER PRIMARY KEY, pid INTEGER)"""
            )

    def add_job(self, pid):
        """Record ``pid`` as a tracked job.

        Raises:
            TypeError, ValueError: if ``pid`` cannot be converted with
                ``int()`` (for example ``None`` or arbitrary text), so a
                bogus value is never stored.
        """
        pid = int(pid)
        # Bound parameter instead of string formatting: the value can never
        # be interpreted as SQL.
        with self.conn:
            self.c.execute("INSERT INTO jobs (pid) VALUES (?)", (pid,))

    def get_running_jobs(self):
        """Return the ``pid`` of every row in the table as a list."""
        self.c.execute("SELECT pid FROM jobs")
        return [row[0] for row in self.c.fetchall()]

    def delete_job(self, pid):
        """Remove every row whose ``pid`` equals the given value.

        The value is bound as a parameter, so a value that matches no stored
        PID deletes nothing rather than being interpreted as SQL.
        """
        with self.conn:
            self.c.execute("DELETE FROM jobs WHERE pid = ?", (pid,))
| |
|
| |
|
# Raw value of the PARAMS environment variable (None when it is unset);
# handed to run_evaluation when the application starts.
PARAMS = os.getenv("PARAMS")

# Module-wide job table, stored in the SQLite file "job.db".
DB = JobDB("job.db")
| |
|
| |
|
class BackgroundRunner:
    """Periodic watcher that prunes finished jobs and stops the server when idle."""

    async def run_main(self):
        """Loop forever, inspecting the job table once every 30 seconds.

        Each pass looks up the status of every PID in the job table. A PID
        whose status is "completed", "error" or "zombie" is sent SIGTERM
        (a failure to do so is only logged) and then removed from the table.
        If the table is empty afterwards, this process sends itself SIGINT.
        """
        finished_states = ("completed", "error", "zombie")
        while True:
            # An empty job list simply skips the loop body.
            for _pid in DB.get_running_jobs():
                normalized = get_process_status(_pid).strip().lower()
                if normalized not in finished_states:
                    continue
                logger.info(f"Process {_pid} is already completed. Skipping...")
                try:
                    kill_process_by_pid(_pid)
                except Exception as e:
                    # Best effort: the row is removed whether or not the
                    # signal could be delivered.
                    logger.info(f"Error while killing process: {e}")
                DB.delete_job(_pid)

            # Re-read the table so the removals made above are reflected.
            if not DB.get_running_jobs():
                logger.info("No running jobs found. Shutting down the server.")
                os.kill(os.getpid(), signal.SIGINT)
            await asyncio.sleep(30)


runner = BackgroundRunner()
| |
|
| |
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start the evaluation and its monitor for the lifetime of the app.

    On startup this launches the evaluation via ``run_evaluation``, records
    the returned PID in the job table and schedules ``runner.run_main`` as a
    background task. On shutdown the background task is cancelled.
    """
    process_pid = run_evaluation(params=PARAMS)
    logger.info(f"Started training with PID {process_pid}")
    DB.add_job(process_pid)
    # Keep a strong reference: the event loop only holds weak references to
    # tasks, so a task whose result is discarded may be garbage-collected
    # before it finishes.
    monitor_task = asyncio.create_task(runner.run_main())
    try:
        yield
    finally:
        # The monitor loops forever; stop it explicitly when the app exits.
        monitor_task.cancel()


api = FastAPI(lifespan=lifespan)
| |
|
| |
|
@api.get("/")
async def root():
    # Landing route: returns a fixed status message.
    # Documented with comments rather than a docstring because FastAPI would
    # publish a docstring as the route description.
    message = "Your model is being evaluated..."
    return message
| |
|
| |
|
@api.get("/health")
async def health():
    # Health-check route: always answers with the fixed string "OK".
    # Documented with comments rather than a docstring because FastAPI would
    # publish a docstring as the route description.
    status = "OK"
    return status
| |
|