# customer_support / server.py
# Sayed223's picture
# Update server.py
# c316ccb verified
"""
server.py — FastAPI/OpenEnv server wrapper for CustomerSupportEnv.
Exposes the environment as REST endpoints compatible with the OpenEnv specification.
Handles session management and action validation.
Endpoints:
    POST /reset        → Initialize new episode, return initial observation
    POST /step         → Apply action, return (obs, reward, done)
    GET  /state        → Get current environment state
    GET  /tasks        → List all tasks
    POST /grade        → Grade current episode
    GET  /health       → Health check
    GET  /openenv.yaml → Spec file
"""
from __future__ import annotations

import enum
import json
import os
import sys
import traceback
from pathlib import Path
from typing import Any, Dict, Optional
# FastAPI imports
try:
from fastapi import FastAPI, HTTPException, Request, Body
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel, ConfigDict
import uvicorn
except ImportError as e:
print(f"[ERROR] Missing FastAPI dependency: {e}", flush=True)
print("Run: pip install fastapi uvicorn pydantic", flush=True)
sys.exit(1)
# Local env imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
try:
from env.environment import CustomerSupportEnv, TASKS
from env.models import Action, ActionType, Observation, Reward
from graders.graders import grade
except ImportError as e:
print(f"[ERROR] Missing local env module: {e}", flush=True)
traceback.print_exc()
sys.exit(1)
# ── FastAPI App ──────────────────────────────────────────────────────────────
# Single application object for this module; the endpoint functions in this
# file register themselves on it through the @app.get / @app.post decorators.
app = FastAPI(
    title="CustomerSupportEnv",
    description="OpenEnv-compatible customer support RL environment",
    version="1.0.0"
)
# ── Session Storage (in-memory for single deployment) ───────────────────────
# Session ID -> per-episode record. Held only in this process's memory, so
# the contents do not survive a restart.
_sessions: Dict[str, Dict[str, Any]] = {}

# Number of session IDs handed out so far; the next ID is this value plus one.
_session_counter = 0


def new_session_id() -> str:
    """Return the next sequential session ID.

    IDs have the form ``session_NNNNNN`` where ``NNNNNN`` is the running
    count of issued IDs, zero-padded to at least six digits. Each call
    advances the module-level counter by one.
    """
    global _session_counter
    next_number = _session_counter + 1
    _session_counter = next_number
    return "session_" + format(next_number, "06d")
# ── Pydantic Models ──────────────────────────────────────────────────────────
# Optional parameters for starting an episode. Unknown extra fields are
# accepted rather than rejected (extra="allow").
# NOTE(review): in the visible part of this file the /reset handler takes a
# raw dict body and never references this model — confirm whether it is still
# needed or should be wired into that handler.
class ResetRequest(BaseModel):
    model_config = ConfigDict(extra="allow")
    task_id: Optional[str] = None  # task to run; None when not supplied
    seed: Optional[int] = None  # random seed; None when not supplied
# Request body accepted by the /step handler.
class StepRequest(BaseModel):
    session_id: str  # looked up in the in-memory session table
    action_type: str  # passed to Action(action_type=...)
    payload: Optional[str] = None  # passed to Action(payload=...)
# Request body accepted by the /grade handler.
class GradeRequest(BaseModel):
    session_id: str  # looked up in the in-memory session table
# ── Helper: Make JSON serializable ──────────────────────────────────────────
def _json_key(key: Any) -> Any:
    """Return *key* in a form the JSON encoder accepts as an object key."""
    if isinstance(key, enum.Enum):
        key = key.value
    if key is None or isinstance(key, (str, int, float, bool)):
        return key
    # Anything else (tuples, custom objects, ...) would make json.dumps raise.
    return str(key)


def to_json_serializable(obj: Any) -> Any:
    """Recursively convert *obj* into plain JSON-compatible Python values.

    Conversion rules, checked in this order:

    * ``enum.Enum`` members become their (converted) ``value``.
    * ``None``, ``str``, ``int``, ``float`` and ``bool`` are returned as-is.
    * ``dict`` values are converted recursively; keys that are not JSON
      primitives are replaced by their enum value or ``str()`` form.
    * ``list``, ``tuple``, ``set`` and ``frozenset`` become lists of
      converted items.
    * Objects with a callable ``model_dump`` (pydantic v2) or, failing that,
      a callable ``dict`` (pydantic v1 style) are dumped and converted.
    * Objects with a ``__dict__`` are converted from their attribute dict.
    * Everything else falls back to ``str(obj)``.

    Args:
        obj: Any Python object.

    Returns:
        A structure made only of dicts, lists, strings, numbers, booleans
        and ``None``.
    """
    # Enum members must be handled before the generic branches: a plain Enum
    # has a __dict__ full of internals (_value_, _name_, ...) that would
    # otherwise be emitted instead of the member's value.
    if isinstance(obj, enum.Enum):
        return to_json_serializable(obj.value)
    if obj is None or isinstance(obj, (str, int, float, bool)):
        return obj
    if isinstance(obj, dict):
        return {
            _json_key(key): to_json_serializable(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple, set, frozenset)):
        return [to_json_serializable(item) for item in obj]
    # Prefer the pydantic v2 API; .dict() is kept as a fallback for objects
    # that only offer the older method.
    for dump_name in ("model_dump", "dict"):
        dump = getattr(obj, dump_name, None)
        if callable(dump):
            return to_json_serializable(dump())
    if hasattr(obj, "__dict__"):
        return to_json_serializable(obj.__dict__)
    return str(obj)
def serialize_obs(obs: Observation) -> Dict[str, Any]:
    """Build the JSON payload for an observation.

    Each attribute named below is read from ``obs`` and passed through
    ``to_json_serializable``. The resulting dict has one key per attribute,
    in the order listed.
    """
    exported_fields = (
        "ticket_id",
        "task_id",
        "status",
        "sentiment",
        "priority",
        "category",
        "turn",
        "max_turns",
        "history",
        "kb_results",
        "kb_searched",
        "empathized",
        "clarified",
        "solution_offered",
        "escalated",
        "cumulative_reward",
        "done",
    )
    return {
        name: to_json_serializable(getattr(obs, name))
        for name in exported_fields
    }
def serialize_reward(reward: Reward) -> Dict[str, Any]:
    """Build the JSON payload for a reward.

    Reads ``total``, ``breakdown`` and ``reason`` from ``reward`` and passes
    each through ``to_json_serializable``.
    """
    payload: Dict[str, Any] = {}
    for name in ("total", "breakdown", "reason"):
        payload[name] = to_json_serializable(getattr(reward, name))
    return payload
# ── OpenEnv Endpoints ────────────────────────────────────────────────────────
@app.post("/reset")
async def reset(request: Optional[Dict[str, Any]] = Body(default=None)) -> JSONResponse:
    """
    Reset environment and start a new episode.

    Accepts both an empty POST and a JSON object body with optional parameters.

    Args:
        request: Optional JSON object. Recognized keys:
            task_id: A key of TASKS (missing, null or empty defaults to "task_1")
            seed: Integer random seed (missing or null defaults to 42)

    Returns:
        {
            "session_id": str,
            "observation": {...},
            "info": {"task_id": ..., "difficulty": ..., "description": ...}
        }

    Raises:
        HTTPException: 400 when task_id or seed is invalid (or any other
            ValueError is raised while creating the episode); 500 on any
            other failure.
    """
    try:
        params = request if isinstance(request, dict) else {}

        # A falsy task_id (missing, null, "") and a null seed use the defaults.
        task_id = params.get("task_id") or "task_1"
        seed = params.get("seed")
        if seed is None:
            seed = 42
        print(f"[RESET] task_id={task_id}, seed={seed}", flush=True)

        # Reject wrong types up front so they surface as client errors (400).
        # Without this, an unhashable task_id (e.g. a JSON array) raises
        # TypeError in the TASKS lookup and is reported as a server error,
        # and arbitrary seed values are forwarded to the environment.
        if not isinstance(task_id, str):
            raise ValueError(
                f"task_id must be a string, got {type(task_id).__name__}"
            )
        # bool is a subclass of int; JSON true/false is not a meaningful seed.
        if isinstance(seed, bool) or not isinstance(seed, int):
            raise ValueError(
                f"seed must be an integer, got {type(seed).__name__}"
            )
        if task_id not in TASKS:
            raise ValueError(f"Invalid task_id '{task_id}'. Must be one of: {list(TASKS.keys())}")

        # Create and reset environment
        env = CustomerSupportEnv(task_id=task_id, seed=seed)
        obs = env.reset()

        # Store session
        session_id = new_session_id()
        _sessions[session_id] = {
            "env": env,
            "task_id": task_id,
            "observation": obs,
            "steps": 0,
            "done": False,
        }
        print(f"[RESET] Created session {session_id}", flush=True)

        return JSONResponse(
            status_code=200,
            content={
                "session_id": session_id,
                "observation": serialize_obs(obs),
                "info": {
                    "task_id": task_id,
                    "difficulty": TASKS[task_id].difficulty,
                    "description": TASKS[task_id].description,
                },
            },
        )
    except ValueError as e:
        print(f"[RESET ERROR] Validation error: {e}", flush=True)
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        print(f"[RESET ERROR] {type(e).__name__}: {e}", flush=True)
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Reset failed: {str(e)}") from e
@app.post("/step")
async def step(request: StepRequest) -> JSONResponse:
    """
    Apply an action and step the environment.

    Args:
        request: Body with
            session_id: Session ID from /reset
            action_type: Action name. Earlier documentation of this endpoint
                lists search_kb, empathize, ask_clarify, offer_solution,
                escalate, resolve, send_message; the authoritative set is
                defined by Action / the environment, not by this handler.
            payload: Optional action payload

    Returns:
        {
            "observation": {...},
            "reward": {...},
            "done": bool,
            "info": {"step": int, "action": str}
        }

    Raises:
        HTTPException: 404 for an unknown session; 400 when the episode is
            already done or the action is rejected while being constructed;
            500 on any other failure.
    """
    try:
        session_id = request.session_id
        action_type = request.action_type

        session = _sessions.get(session_id)
        if session is None:
            raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")
        if session["done"]:
            raise HTTPException(status_code=400, detail="Episode already done. Call /reset to start new episode.")

        # A rejected action is the caller's mistake, so report it as 400
        # instead of letting it fall through to the generic 500 handler.
        # NOTE(review): this only has an effect if Action validates its
        # fields by raising ValueError (pydantic's ValidationError is a
        # ValueError subclass) — confirm against env.models.
        try:
            action = Action(action_type=action_type, payload=request.payload)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=f"Invalid action: {e}") from e

        result = session["env"].step(action)

        # Keep the session record in sync with the environment.
        session["observation"] = result.observation
        session["steps"] += 1
        session["done"] = result.observation.done

        return JSONResponse(
            status_code=200,
            content={
                "observation": serialize_obs(result.observation),
                "reward": serialize_reward(result.reward),
                "done": result.observation.done,
                "info": {
                    "step": session["steps"],
                    "action": action_type,
                },
            },
        )
    except HTTPException:
        # Already carries the intended status code; pass it through untouched.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Step failed: {str(e)}") from e
@app.get("/state")
async def state_endpoint(session_id: str) -> JSONResponse:
    """
    Return the stored observation for a session without stepping it.

    Args:
        session_id: Session ID from /reset

    Returns:
        {
            "observation": {...},
            "info": {"task_id": ..., "steps": ..., "done": ...}
        }

    Raises:
        HTTPException: 404 for an unknown session; 500 on any other failure.
    """
    try:
        record = _sessions.get(session_id)
        if record is None:
            raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")

        observation_payload = serialize_obs(record["observation"])
        info_payload = {key: record[key] for key in ("task_id", "steps", "done")}
        body = {"observation": observation_payload, "info": info_payload}
        return JSONResponse(status_code=200, content=body)
    except HTTPException:
        # Preserve the status code chosen above.
        raise
    except Exception as exc:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"State query failed: {exc}")
@app.get("/tasks")
async def tasks_endpoint() -> JSONResponse:
    """
    Describe every entry of TASKS.

    Returns:
        {
            "tasks": [
                {
                    "id": ...,
                    "name": ...,
                    "difficulty": ...,
                    "description": ...,
                    "max_turns": ...
                },
                ...
            ]
        }

    Raises:
        HTTPException: 500 if the task table cannot be read.
    """
    try:
        summaries = [
            {
                "id": key,
                "name": spec.name,
                "difficulty": spec.difficulty,
                "description": spec.description,
                "max_turns": spec.max_turns,
            }
            for key, spec in TASKS.items()
        ]
        return JSONResponse(status_code=200, content={"tasks": summaries})
    except Exception as exc:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Tasks query failed: {exc}")
@app.post("/grade")
async def grade_endpoint(request: GradeRequest) -> JSONResponse:
    """
    Grade the episode held by a session.

    The environment's current state is passed to ``grade`` together with the
    session's task ID, and the grader result is returned field by field.

    Args:
        request: Body with
            session_id: Session ID from /reset

    Returns:
        {
            "score": ...,
            "passed": ...,
            "breakdown": {...},
            "reason": ...
        }

    Raises:
        HTTPException: 404 for an unknown session; 500 on any other failure.
    """
    try:
        session_id = request.session_id
        record = _sessions.get(session_id)
        if record is None:
            raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")

        environment, task_id = record["env"], record["task_id"]
        verdict = grade(task_id, environment.state())

        body = {
            "score": verdict.score,
            "passed": verdict.passed,
            # Only the breakdown is converted; the other fields are sent as-is.
            "breakdown": to_json_serializable(verdict.breakdown),
            "reason": verdict.reason,
        }
        return JSONResponse(status_code=200, content=body)
    except HTTPException:
        # Preserve the status code chosen above.
        raise
    except Exception as exc:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Grading failed: {exc}")
@app.get("/health")
async def health() -> JSONResponse:
    """Report service identity and the number of sessions currently stored."""
    payload = {
        "status": "healthy",
        "service": "CustomerSupportEnv",
        "version": "1.0.0",
        "sessions_active": len(_sessions),
    }
    return JSONResponse(status_code=200, content=payload)
@app.get("/openenv.yaml")
async def openenv_spec() -> FileResponse:
    """Serve the ``openenv.yaml`` file that sits next to this module.

    Raises:
        HTTPException: 404 when the file does not exist.
    """
    module_dir = Path(__file__).parent
    spec_path = module_dir / "openenv.yaml"
    if spec_path.exists():
        return FileResponse(spec_path, media_type="text/yaml")
    raise HTTPException(status_code=404, detail="openenv.yaml not found")
# ── Root endpoint ────────────────────────────────────────────────────────────
@app.get("/")
async def root() -> JSONResponse:
    """Return the service name, version and a short index of endpoints."""
    endpoint_index = {
        "POST /reset": "Initialize new episode",
        "POST /step": "Apply action",
        "GET /state": "Get current state",
        "GET /tasks": "List tasks",
        "POST /grade": "Grade episode",
        "GET /health": "Health check",
        "GET /openenv.yaml": "Specification",
    }
    payload = {
        "service": "CustomerSupportEnv OpenEnv Server",
        "version": "1.0.0",
        "endpoints": endpoint_index,
    }
    return JSONResponse(status_code=200, content=payload)
# ── Startup/Shutdown ─────────────────────────────────────────────────────────
# NOTE(review): FastAPI's documentation describes lifespan handlers as the
# replacement for on_event — consider migrating.
@app.on_event("startup")
async def startup_event() -> None:
    """Print a start-up notice to stdout, flushed immediately."""
    notice = "[INFO] CustomerSupportEnv server started"
    print(notice, flush=True)
# NOTE(review): FastAPI's documentation describes lifespan handlers as the
# replacement for on_event — consider migrating.
@app.on_event("shutdown")
async def shutdown_event() -> None:
    """Print a shutdown notice to stdout, flushed immediately."""
    notice = "[INFO] CustomerSupportEnv server shutdown"
    print(notice, flush=True)
# ── Main ─────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Bind address and port come from the environment, falling back to
    # 0.0.0.0 and 7860 when HOST / PORT are unset.
    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", 7860))
    print(f"[INFO] Starting server on {host}:{port}", flush=True)
    uvicorn.run(app, host=host, port=port, log_level="info")