File size: 3,902 Bytes
ccc1c93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from __future__ import annotations

import logging
from fastapi import Request
from openenv.core.env_server.http_server import create_app
from fastapi import HTTPException

try:
    from ..models import (
        BaselineRequest,
        BaselineScores,
        GraderRequest,
        GraderResponse,
        GridAction,
        GridObservation,
        PlanningContextRequest,
        PlanningContextResponse,
        SimulationRequest,
        SimulationResponse,
        TaskListResponse,
    )
    from .graders import grade_episode
    from .grid_environment import GridEnvironment
    from .logging_utils import configure_logging
    from .tasks import task_list
except ImportError:
    from models import (
        BaselineRequest,
        BaselineScores,
        GraderRequest,
        GraderResponse,
        GridAction,
        GridObservation,
        PlanningContextRequest,
        PlanningContextResponse,
        SimulationRequest,
        SimulationResponse,
        TaskListResponse,
    )
    from server.graders import grade_episode
    from server.grid_environment import GridEnvironment
    from server.logging_utils import configure_logging
    from server.tasks import task_list

configure_logging()
logger = logging.getLogger(__name__)

# Identifier handed to OpenEnv's app factory for this environment.
_ENV_NAME = "grid2op_env"
# Passed through as ``max_concurrent_envs``; presumably caps how many
# GridEnvironment instances the server keeps alive at once — TODO confirm
# against openenv's create_app.
_MAX_CONCURRENT_ENVS = 2

# OpenEnv-generated FastAPI application; the custom routes below are
# registered on this same instance.
app = create_app(
    GridEnvironment,
    GridAction,
    GridObservation,
    env_name=_ENV_NAME,
    max_concurrent_envs=_MAX_CONCURRENT_ENVS,
)


@app.get("/tasks", response_model=TaskListResponse)
def get_tasks() -> TaskListResponse:
    """Describe the available tasks and the action format.

    Returns:
        A ``TaskListResponse`` holding ``task_list()`` and the JSON schema
        generated from the ``GridAction`` model.
    """
    logger.info("Serving /tasks")
    available = task_list()
    schema = GridAction.model_json_schema()
    return TaskListResponse(tasks=available, action_schema=schema)


@app.post("/grader", response_model=GraderResponse)
def post_grader(payload: GraderRequest) -> GraderResponse:
    """Score a submitted episode log for one task.

    Args:
        payload: Carries the ``task_id`` and the ``episode_log`` to grade.

    Returns:
        A ``GraderResponse`` echoing the task id with the value returned by
        ``grade_episode``.
    """
    task_id = payload.task_id
    episode_log = payload.episode_log
    logger.info("Serving /grader task_id=%s steps=%s", task_id, len(episode_log))
    score = grade_episode(task_id, episode_log)
    return GraderResponse(task_id=task_id, score=score)


@app.post("/baseline", response_model=BaselineScores)
def run_baseline_route(payload: BaselineRequest, request: Request) -> BaselineScores:
    """Run the baseline suite against this server and return its scores.

    Args:
        payload: Baseline configuration; forwarded unchanged as ``config``.
        request: Incoming request, used only to derive this server's base URL.

    Returns:
        Whatever ``run_baseline_suite`` produces, validated as ``BaselineScores``.
    """
    # Imported lazily inside the route (presumably to keep the inference stack
    # out of server start-up — TODO confirm). Mirror the module-level
    # package/script import fallback: the relative form raises ImportError
    # when this module is loaded outside its parent package.
    try:
        from ..inference import run_baseline_suite
    except ImportError:
        from inference import run_baseline_suite

    # The suite is pointed back at this same server.
    base_url = str(request.base_url).rstrip("/")
    logger.info("Serving /baseline model=%s base_url=%s", payload.model, base_url)
    return run_baseline_suite(base_url=base_url, config=payload)


@app.post("/planning_context", response_model=PlanningContextResponse)
def post_planning_context(payload: PlanningContextRequest) -> PlanningContextResponse:
    """Return the planning context of an active episode.

    Args:
        payload: Identifies the episode via ``episode_id``.

    Returns:
        The result of ``get_planning_context()`` on the episode's environment.

    Raises:
        HTTPException: 404 when ``GridEnvironment.get_active_instance`` finds
            no environment for the given episode id.
    """
    episode_id = payload.episode_id
    env = GridEnvironment.get_active_instance(episode_id)
    if env is not None:
        logger.info("Serving /planning_context episode_id=%s", episode_id)
        return env.get_planning_context()
    raise HTTPException(status_code=404, detail=f"Unknown episode_id: {episode_id}")


@app.post("/simulate", response_model=SimulationResponse)
def post_simulate(payload: SimulationRequest) -> SimulationResponse:
    """Simulate candidate actions on an active episode without committing them.

    Args:
        payload: Carries the ``episode_id`` and the candidate ``actions``.

    Returns:
        A ``SimulationResponse`` echoing the episode id together with the
        output of ``simulate_actions`` on the episode's environment.

    Raises:
        HTTPException: 404 when ``GridEnvironment.get_active_instance`` finds
            no environment for the given episode id.
    """
    episode_id = payload.episode_id
    env = GridEnvironment.get_active_instance(episode_id)
    if env is None:
        raise HTTPException(status_code=404, detail=f"Unknown episode_id: {episode_id}")

    candidates = payload.actions
    logger.info(
        "Serving /simulate episode_id=%s candidate_count=%s",
        episode_id,
        len(candidates),
    )
    outcomes = env.simulate_actions(candidates)
    return SimulationResponse(episode_id=episode_id, results=outcomes)


def main(host: str = "0.0.0.0", port: int = 7860, argv: list[str] | None = None) -> None:
    """Parse command-line options and serve ``app`` with uvicorn.

    Args:
        host: Bind address used when ``--host`` is not supplied.
        port: Port used when ``--port`` is not supplied.
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour; pass a list
            when calling ``main`` programmatically so the host process's own
            arguments are not parsed.
    """
    import argparse
    import uvicorn

    parser = argparse.ArgumentParser(description="Run the Grid2Op FastAPI server.")
    parser.add_argument("--host", default=host, help="interface to bind (default: %(default)s)")
    parser.add_argument(
        "--port", type=int, default=port, help="port to listen on (default: %(default)s)"
    )
    args = parser.parse_args(argv)
    logger.info("Starting Grid2Op FastAPI server host=%s port=%s", args.host, args.port)
    uvicorn.run(app, host=args.host, port=args.port)


# Script entry point: start the server with the CLI-configurable host/port.
if __name__ == "__main__":
    main()