JAYASREESS's picture
Create app.py
ccc1c93 verified
from __future__ import annotations
import logging
from fastapi import Request
from openenv.core.env_server.http_server import create_app
from fastapi import HTTPException
try:
from ..models import (
BaselineRequest,
BaselineScores,
GraderRequest,
GraderResponse,
GridAction,
GridObservation,
PlanningContextRequest,
PlanningContextResponse,
SimulationRequest,
SimulationResponse,
TaskListResponse,
)
from .graders import grade_episode
from .grid_environment import GridEnvironment
from .logging_utils import configure_logging
from .tasks import task_list
except ImportError:
from models import (
BaselineRequest,
BaselineScores,
GraderRequest,
GraderResponse,
GridAction,
GridObservation,
PlanningContextRequest,
PlanningContextResponse,
SimulationRequest,
SimulationResponse,
TaskListResponse,
)
from server.graders import grade_episode
from server.grid_environment import GridEnvironment
from server.logging_utils import configure_logging
from server.tasks import task_list
configure_logging()
logger = logging.getLogger(__name__)
app = create_app(
GridEnvironment,
GridAction,
GridObservation,
env_name="grid2op_env",
max_concurrent_envs=2,
)
@app.get("/tasks", response_model=TaskListResponse)
def get_tasks() -> TaskListResponse:
logger.info("Serving /tasks")
return TaskListResponse(
tasks=task_list(),
action_schema=GridAction.model_json_schema(),
)
@app.post("/grader", response_model=GraderResponse)
def post_grader(payload: GraderRequest) -> GraderResponse:
logger.info(
"Serving /grader task_id=%s steps=%s",
payload.task_id,
len(payload.episode_log),
)
return GraderResponse(
task_id=payload.task_id,
score=grade_episode(payload.task_id, payload.episode_log),
)
@app.post("/baseline", response_model=BaselineScores)
def run_baseline_route(payload: BaselineRequest, request: Request) -> BaselineScores:
from ..inference import run_baseline_suite
base_url = str(request.base_url).rstrip("/")
logger.info("Serving /baseline model=%s base_url=%s", payload.model, base_url)
return run_baseline_suite(base_url=base_url, config=payload)
@app.post("/planning_context", response_model=PlanningContextResponse)
def post_planning_context(payload: PlanningContextRequest) -> PlanningContextResponse:
env = GridEnvironment.get_active_instance(payload.episode_id)
if env is None:
raise HTTPException(status_code=404, detail=f"Unknown episode_id: {payload.episode_id}")
logger.info("Serving /planning_context episode_id=%s", payload.episode_id)
return env.get_planning_context()
@app.post("/simulate", response_model=SimulationResponse)
def post_simulate(payload: SimulationRequest) -> SimulationResponse:
env = GridEnvironment.get_active_instance(payload.episode_id)
if env is None:
raise HTTPException(status_code=404, detail=f"Unknown episode_id: {payload.episode_id}")
logger.info(
"Serving /simulate episode_id=%s candidate_count=%s",
payload.episode_id,
len(payload.actions),
)
return SimulationResponse(
episode_id=payload.episode_id,
results=env.simulate_actions(payload.actions),
)
def main(host: str = "0.0.0.0", port: int = 7860) -> None:
import argparse
import uvicorn
parser = argparse.ArgumentParser()
parser.add_argument("--host", default=host)
parser.add_argument("--port", type=int, default=port)
args = parser.parse_args()
logger.info("Starting Grid2Op FastAPI server host=%s port=%s", args.host, args.port)
uvicorn.run(app, host=args.host, port=args.port)
if __name__ == "__main__":
main()