import copy
import os
import traceback
from typing import Any, Dict, List

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel
from rex.utils.initialization import set_seed_and_log_path
from rex.utils.logging import logger

from src.task import SchemaGuidedInstructBertTask
|
|
# Runs at import time, before the app and model are created.
# Presumably seeds RNGs and points rex's log output at "debug.log"
# (inferred from the helper's name; behavior lives in rex — confirm).
set_seed_and_log_path(log_path="debug.log")
|
|
app = FastAPI()

# Accept cross-origin requests from any origin, using any method and any
# header. Presumably this lets a front-end hosted elsewhere call this API.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
|
|
|
|
class RequestData(BaseModel):
    """Request body accepted by ``POST /process``.

    ``data`` is handed unchanged to ``task.predict``. Pydantic only checks
    that it is a list of dicts; the keys each dict must carry are defined by
    the task, not validated here.
    """

    # Items to run prediction on; per-item schema is task-specific.
    data: List[Dict[str, Any]]
|
|
|
|
# Directory of the trained task/checkpoint to serve. The MIRROR_TASK_DIR
# environment variable overrides it; the default is the originally
# hard-coded path, so existing deployments behave the same.
TASK_DIR = os.environ.get(
    "MIRROR_TASK_DIR", "mirror_outputs/Mirror_Pretrain_AllExcluded_2"
)

# Loaded once at import time; every request handler reuses this instance.
# NOTE(review): the flags presumably mean "load the best saved checkpoint,
# don't re-initialize, don't rewrite the config file, reuse the data cache";
# confirm against rex's from_taskdir.
task = SchemaGuidedInstructBertTask.from_taskdir(
    TASK_DIR,
    load_best_model=True,
    initialize=False,
    dump_configfile=False,
    update_config={
        "regenerate_cache": False,
    },
)
|
|
|
|
@app.post("/process")
def process_data(data: RequestData):
    """Run ``task.predict`` on the posted items and report the outcome.

    Args:
        data: Request body; ``data.data`` is passed as-is to ``task.predict``.

    Returns:
        A dict with three keys:

        - ``ok``: whether prediction succeeded.
        - ``msg``: ``"success"``, or the formatted traceback of the failure.
        - ``results``: the prediction output, or ``{}`` if prediction raised.
    """
    input_data = data.data

    ok = True
    results = {}
    try:
        results = task.predict(input_data)
    except Exception:
        # Failures are reported in the response body instead of as an HTTP
        # error. KeyboardInterrupt/SystemExit derive from BaseException, so
        # they are not caught here and propagate with their original traceback.
        ok = False
        msg = traceback.format_exc()
        # Keep a server-side record as well; the response alone leaves none.
        logger.error(f"Prediction failed: {msg}")
    else:
        msg = "success"

    logger.info(f"Data: {input_data}, Prediction: {results}")
    return {"ok": ok, "msg": msg, "results": results}
|
|
|
|
@app.get("/")
async def api():
    """Serve the front-end HTML page at the root URL.

    The path is relative, so it is resolved against the server process's
    current working directory.
    """
    index_page = "./index.html"
    return FileResponse(index_page, media_type="text/html")
|
|
|
|
if __name__ == "__main__":
    # Launch the dev server: debug logging, auto-reload, port 7860 on all
    # interfaces. Work on a private copy of uvicorn's default logging config.
    # Editing the shared module-level dict in place would change uvicorn's
    # global default, and repeated edits would stack the timestamp prefix.
    log_config = copy.deepcopy(uvicorn.config.LOGGING_CONFIG)
    for formatter_name in ("access", "default"):
        formatter = log_config["formatters"][formatter_name]
        formatter["fmt"] = "%(asctime)s | " + formatter["fmt"]
    uvicorn.run(
        # reload=True requires the app as an import string, not the object.
        "src.app.api_backend:app",
        host="0.0.0.0",
        port=7860,
        log_level="debug",
        log_config=log_config,
        reload=True,
    )
|
|