Spaces:
No application file
No application file
| from __future__ import annotations | |
| import logging | |
| from fastapi import Request | |
| from openenv.core.env_server.http_server import create_app | |
| from fastapi import HTTPException | |
| try: | |
| from ..models import ( | |
| BaselineRequest, | |
| BaselineScores, | |
| GraderRequest, | |
| GraderResponse, | |
| GridAction, | |
| GridObservation, | |
| PlanningContextRequest, | |
| PlanningContextResponse, | |
| SimulationRequest, | |
| SimulationResponse, | |
| TaskListResponse, | |
| ) | |
| from .graders import grade_episode | |
| from .grid_environment import GridEnvironment | |
| from .logging_utils import configure_logging | |
| from .tasks import task_list | |
| except ImportError: | |
| from models import ( | |
| BaselineRequest, | |
| BaselineScores, | |
| GraderRequest, | |
| GraderResponse, | |
| GridAction, | |
| GridObservation, | |
| PlanningContextRequest, | |
| PlanningContextResponse, | |
| SimulationRequest, | |
| SimulationResponse, | |
| TaskListResponse, | |
| ) | |
| from server.graders import grade_episode | |
| from server.grid_environment import GridEnvironment | |
| from server.logging_utils import configure_logging | |
| from server.tasks import task_list | |
# Configure logging once at import time, before the module logger is created
# and before any handler below can emit a record.
configure_logging()
logger = logging.getLogger(__name__)

# Settings handed to OpenEnv's application factory.
_ENV_NAME = "grid2op_env"
_MAX_CONCURRENT_ENVS = 2

# Build the FastAPI application around GridEnvironment with the GridAction /
# GridObservation models. NOTE(review): which endpoints create_app registers
# and how max_concurrent_envs is enforced are defined by openenv — confirm there.
app = create_app(
    GridEnvironment,
    GridAction,
    GridObservation,
    env_name=_ENV_NAME,
    max_concurrent_envs=_MAX_CONCURRENT_ENVS,
)
# Register the handler on the app; without this it is never reachable as /tasks.
@app.get("/tasks")
def get_tasks() -> TaskListResponse:
    """List the available tasks together with the action JSON schema.

    Returns:
        TaskListResponse: ``tasks`` taken from ``task_list()`` and
        ``action_schema`` generated by ``GridAction.model_json_schema()``.
    """
    logger.info("Serving /tasks")
    return TaskListResponse(
        tasks=task_list(),
        action_schema=GridAction.model_json_schema(),
    )
# Register the handler on the app; without this it is never reachable as /grader.
@app.post("/grader")
def post_grader(payload: GraderRequest) -> GraderResponse:
    """Grade a recorded episode for the given task.

    Args:
        payload: Request carrying ``task_id`` and the ``episode_log`` to score.

    Returns:
        GraderResponse: the echoed ``task_id`` and the score computed by
        ``grade_episode(task_id, episode_log)``.
    """
    logger.info(
        "Serving /grader task_id=%s steps=%s",
        payload.task_id,
        len(payload.episode_log),
    )
    return GraderResponse(
        task_id=payload.task_id,
        score=grade_episode(payload.task_id, payload.episode_log),
    )
# Register the handler on the app; without this it is never reachable as /baseline.
@app.post("/baseline")
def run_baseline_route(payload: BaselineRequest, request: Request) -> BaselineScores:
    """Run the baseline suite against this same server and return its scores.

    Args:
        payload: Baseline configuration, passed through as ``config``; its
            ``model`` field is logged.
        request: The incoming request, used only to derive the server's own
            base URL.

    Returns:
        BaselineScores: whatever ``run_baseline_suite`` produces.
    """
    # Same package/script fallback as the module-level imports: the relative
    # form raises ImportError when this module is loaded outside its package.
    try:
        from ..inference import run_baseline_suite
    except ImportError:
        from inference import run_baseline_suite

    # NOTE(review): the suite calls back into this server via the URL the client
    # used; behind a TLS-terminating proxy the scheme may differ — confirm.
    base_url = str(request.base_url).rstrip("/")
    logger.info("Serving /baseline model=%s base_url=%s", payload.model, base_url)
    return run_baseline_suite(base_url=base_url, config=payload)
# Register the handler on the app; without this it is never reachable as /planning_context.
@app.post("/planning_context")
def post_planning_context(payload: PlanningContextRequest) -> PlanningContextResponse:
    """Return the planning context of an active episode.

    Args:
        payload: Request identifying the episode via ``episode_id``.

    Returns:
        PlanningContextResponse: the result of the environment's
        ``get_planning_context()``.

    Raises:
        HTTPException: 404 when no active environment matches ``episode_id``.
    """
    env = GridEnvironment.get_active_instance(payload.episode_id)
    if env is None:
        # Log rejections too, so unknown-episode requests show up in server logs.
        logger.warning("Rejecting /planning_context unknown episode_id=%s", payload.episode_id)
        raise HTTPException(status_code=404, detail=f"Unknown episode_id: {payload.episode_id}")
    logger.info("Serving /planning_context episode_id=%s", payload.episode_id)
    return env.get_planning_context()
# Register the handler on the app; without this it is never reachable as /simulate.
@app.post("/simulate")
def post_simulate(payload: SimulationRequest) -> SimulationResponse:
    """Simulate candidate actions against an active episode.

    Args:
        payload: Request carrying ``episode_id`` and the candidate ``actions``.

    Returns:
        SimulationResponse: the echoed ``episode_id`` and the results of
        ``env.simulate_actions(actions)``.

    Raises:
        HTTPException: 404 when no active environment matches ``episode_id``.
    """
    env = GridEnvironment.get_active_instance(payload.episode_id)
    if env is None:
        # Log rejections too, so unknown-episode requests show up in server logs.
        logger.warning("Rejecting /simulate unknown episode_id=%s", payload.episode_id)
        raise HTTPException(status_code=404, detail=f"Unknown episode_id: {payload.episode_id}")
    logger.info(
        "Serving /simulate episode_id=%s candidate_count=%s",
        payload.episode_id,
        len(payload.actions),
    )
    return SimulationResponse(
        episode_id=payload.episode_id,
        results=env.simulate_actions(payload.actions),
    )
def main(host: str = "0.0.0.0", port: int = 7860) -> None:
    """Serve ``app`` with uvicorn, binding to the host/port given on the command line.

    ``host`` and ``port`` act only as defaults for the ``--host`` and
    ``--port`` flags; values passed on the command line take precedence.
    """
    # Imported inside the function, so they are only loaded when the server
    # is started through this entry point.
    import argparse
    import uvicorn

    cli = argparse.ArgumentParser()
    cli.add_argument("--host", default=host)
    cli.add_argument("--port", type=int, default=port)
    options = cli.parse_args()

    bind_host, bind_port = options.host, options.port
    logger.info("Starting Grid2Op FastAPI server host=%s port=%s", bind_host, bind_port)
    uvicorn.run(app, host=bind_host, port=bind_port)


if __name__ == "__main__":
    main()