File size: 15,917 Bytes
87fbc7a
f1dacad
 
 
 
87fbc7a
 
f1dacad
 
 
 
 
 
 
87fbc7a
 
 
f1dacad
87fbc7a
 
f1dacad
 
 
 
 
 
482b6b9
f1dacad
482b6b9
f1dacad
 
 
 
 
87fbc7a
f1dacad
 
 
 
 
 
 
 
 
 
 
 
87fbc7a
 
f1dacad
 
87fbc7a
 
f1dacad
 
 
87fbc7a
 
f1dacad
 
 
 
 
87fbc7a
 
f1dacad
87fbc7a
482b6b9
 
f1dacad
87fbc7a
 
 
f1dacad
87fbc7a
 
 
 
 
f1dacad
 
 
c316ccb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1dacad
 
c316ccb
f1dacad
c316ccb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1dacad
 
 
 
 
 
c316ccb
 
 
f1dacad
 
 
 
 
 
482b6b9
f1dacad
 
 
482b6b9
 
f1dacad
f45912f
482b6b9
f1dacad
 
 
 
 
 
 
 
87fbc7a
482b6b9
 
 
 
 
 
 
 
 
 
f45912f
482b6b9
f1dacad
482b6b9
f1dacad
482b6b9
f1dacad
482b6b9
f1dacad
 
 
 
 
 
 
 
 
 
 
 
 
482b6b9
 
c316ccb
 
 
f1dacad
 
 
 
c316ccb
f1dacad
 
 
 
 
 
 
 
482b6b9
 
 
f1dacad
482b6b9
f1dacad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87fbc7a
f1dacad
 
 
 
 
 
c316ccb
 
 
 
f1dacad
 
 
c316ccb
 
f1dacad
 
 
 
 
 
 
 
 
 
87fbc7a
f1dacad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c316ccb
 
f1dacad
 
 
c316ccb
f1dacad
 
 
 
 
 
 
 
 
 
 
 
 
87fbc7a
 
 
f1dacad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c316ccb
f1dacad
 
 
 
 
 
 
 
 
87fbc7a
 
f1dacad
 
 
 
 
 
 
 
 
 
 
 
87fbc7a
 
 
f1dacad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87fbc7a
f1dacad
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
"""
server.py — FastAPI/OpenEnv server wrapper for CustomerSupportEnv.

Exposes the environment as REST endpoints compatible with OpenEnv specification.
Handles session management and action validation.

Endpoints:
  POST   /reset              → Initialize new episode, return initial observation
  POST   /step               → Apply action, return (obs, reward, done)
  GET    /state              → Get current environment state
  GET    /tasks              → List all tasks
  POST   /grade              → Grade current episode
  GET    /health             → Health check
  GET    /openenv.yaml       → Spec file
"""
from __future__ import annotations

import json
import os
import sys
import traceback
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Optional

# FastAPI imports
try:
    from fastapi import FastAPI, HTTPException, Request, Body
    from fastapi.responses import FileResponse, JSONResponse
    from pydantic import BaseModel, ConfigDict
    import uvicorn
except ImportError as e:
    print(f"[ERROR] Missing FastAPI dependency: {e}", flush=True)
    print("Run: pip install fastapi uvicorn pydantic", flush=True)
    sys.exit(1)

# Local env imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
try:
    from env.environment import CustomerSupportEnv, TASKS
    from env.models import Action, ActionType, Observation, Reward
    from graders.graders import grade
except ImportError as e:
    print(f"[ERROR] Missing local env module: {e}", flush=True)
    traceback.print_exc()
    sys.exit(1)

# ── FastAPI App ──────────────────────────────────────────────────────────────
# FastAPI application instance; the route and lifecycle decorators further
# down in this module register their handlers on it, and the __main__ guard
# hands it to uvicorn.
app = FastAPI(
    title="CustomerSupportEnv",
    description="OpenEnv-compatible customer support RL environment",
    version="1.0.0"
)

# ── Session Storage (in-memory for single deployment) ───────────────────────
# Per-session state records keyed by session ID; held in process memory only.
_sessions: Dict[str, Dict[str, Any]] = dict()
# How many session IDs new_session_id() has handed out so far.
_session_counter = 0


def new_session_id() -> str:
    """Return the next sequential session ID.

    Advances the module-level counter by one and formats the new count as
    ``session_`` followed by the number zero-padded to at least six digits.
    """
    global _session_counter
    issued = _session_counter + 1
    _session_counter = issued
    return f"session_{issued:06d}"


# ── Pydantic Models ──────────────────────────────────────────────────────────
class ResetRequest(BaseModel):
    """Request-body schema describing the optional parameters of a reset.

    NOTE(review): the /reset handler visible in this file accepts a raw dict
    body rather than this model — confirm whether the model is used elsewhere.
    """
    # Fields other than the two declared below are allowed in the payload.
    model_config = ConfigDict(extra="allow")
    # Task identifier; None when the caller does not specify one.
    task_id: Optional[str] = None
    # Random seed; None when the caller does not specify one.
    seed: Optional[int] = None


class StepRequest(BaseModel):
    """Request body of POST /step: which session to advance and with what action."""
    # Session ID previously returned by /reset.
    session_id: str
    # Name of the action to apply; forwarded unchanged into Action(action_type=...).
    action_type: str
    # Optional free-text payload forwarded unchanged into Action(payload=...).
    payload: Optional[str] = None


class GradeRequest(BaseModel):
    """Request body of POST /grade: identifies the session to grade."""
    # Session ID previously returned by /reset.
    session_id: str


# ── Helper: Make JSON serializable ──────────────────────────────────────────
def _to_json_key(key: Any) -> Any:
    """Return a dict key that ``json.dumps`` accepts.

    Enum keys are replaced by their value; keys that are already JSON-legal
    (``str``, ``int``, ``float``, ``bool``, ``None``) are kept as they are so
    the encoder's own key formatting still applies; anything else is
    stringified.
    """
    if isinstance(key, Enum):
        key = key.value
    if key is None or isinstance(key, (str, int, float, bool)):
        return key
    return str(key)


def to_json_serializable(obj: Any) -> Any:
    """Recursively convert *obj* into JSON-encodable built-in types.

    Rules, checked in this order:

    * ``None`` is returned unchanged.
    * ``Enum`` members are replaced by their converted ``value``. Without
      this, plain enums would fall through to the ``__dict__`` branch and be
      emitted as a dict of internal attributes.
    * ``str``, ``int``, ``float`` and ``bool`` are returned unchanged.
    * ``dict`` values are converted recursively; keys are normalised so the
      JSON encoder does not reject them.
    * ``list``, ``tuple``, ``set`` and ``frozenset`` become a ``list`` of
      converted items.
    * Objects with a callable ``model_dump`` (Pydantic v2) or, failing that,
      a callable ``dict`` (Pydantic v1) are dumped and the result converted.
    * Objects with a ``__dict__`` are converted from their attribute dict.
    * Anything else is replaced by ``str(obj)``.

    Args:
        obj: Arbitrary value to convert.

    Returns:
        A structure made only of ``None``, ``str``, ``int``, ``float``,
        ``bool``, ``list`` and ``dict``.
    """
    if obj is None:
        return None
    # Must precede the primitive check so str/int-based enums also collapse
    # to their plain value.
    if isinstance(obj, Enum):
        return to_json_serializable(obj.value)
    if isinstance(obj, (str, int, float, bool)):
        return obj
    if isinstance(obj, dict):
        return {
            _to_json_key(key): to_json_serializable(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple, set, frozenset)):
        return [to_json_serializable(item) for item in obj]
    # Prefer the Pydantic v2 API; ``dict()`` is kept as a fallback for
    # v1-style models.
    for dump_name in ("model_dump", "dict"):
        dump = getattr(obj, dump_name, None)
        if callable(dump):
            return to_json_serializable(dump())
    if hasattr(obj, "__dict__"):
        return to_json_serializable(obj.__dict__)
    # Last resort: string representation.
    return str(obj)


def serialize_obs(obs: Observation) -> Dict[str, Any]:
    """Build the JSON payload for an observation.

    Reads a fixed list of attributes from *obs* and runs each value through
    ``to_json_serializable``. The returned dict is keyed by attribute name,
    in the order given below.
    """
    exported_fields = (
        "ticket_id",
        "task_id",
        "status",
        "sentiment",
        "priority",
        "category",
        "turn",
        "max_turns",
        "history",
        "kb_results",
        "kb_searched",
        "empathized",
        "clarified",
        "solution_offered",
        "escalated",
        "cumulative_reward",
        "done",
    )
    return {
        name: to_json_serializable(getattr(obs, name))
        for name in exported_fields
    }


def serialize_reward(reward: Reward) -> Dict[str, Any]:
    """Build the JSON payload for a reward.

    Copies the ``total``, ``breakdown`` and ``reason`` attributes of *reward*
    into a dict, passing each through ``to_json_serializable``.
    """
    payload: Dict[str, Any] = {}
    for name in ("total", "breakdown", "reason"):
        payload[name] = to_json_serializable(getattr(reward, name))
    return payload


# ── OpenEnv Endpoints ────────────────────────────────────────────────────────

@app.post("/reset")
async def reset(request: Optional[Dict[str, Any]] = Body(default=None)) -> JSONResponse:
    """
    Reset environment and start a new episode.

    Accepts both an empty POST and a JSON object body with optional fields.
    A new session is created on every call and stored in the in-memory
    session table.

    Args:
        request: Optional JSON object. Recognised keys:
            task_id: Key of an entry in TASKS (defaults to "task_1" when
                missing or falsy).
            seed: Random seed passed to the environment (defaults to 42
                when missing or None).

    Returns:
        {
            "session_id": str,
            "observation": {...},
            "info": {"task_id": ..., "difficulty": ..., "description": ...}
        }

    Raises:
        HTTPException: 400 when a ValueError is raised (including an unknown
            or non-string task_id), 500 for any other failure.
    """
    try:
        # Defaults used when the body is absent or omits a field.
        task_id = "task_1"
        seed = 42

        if isinstance(request, dict):
            # Falsy task_id values (None, "") fall back to the default.
            task_id = request.get("task_id") or task_id
            if request.get("seed") is not None:
                seed = request["seed"]

        print(f"[RESET] task_id={task_id}, seed={seed}", flush=True)

        # The body is free-form JSON, so task_id may be a list or dict.
        # Those are unhashable and would make the TASKS lookup raise
        # TypeError (reported as a 500); treat them as a client error.
        if not isinstance(task_id, str) or task_id not in TASKS:
            raise ValueError(f"Invalid task_id '{task_id}'. Must be one of: {list(TASKS.keys())}")

        # Create and reset environment
        env = CustomerSupportEnv(task_id=task_id, seed=seed)
        obs = env.reset()

        # Register the episode so later /step, /state and /grade calls can find it.
        session_id = new_session_id()
        _sessions[session_id] = {
            "env": env,
            "task_id": task_id,
            "observation": obs,
            "steps": 0,
            "done": False,
        }

        print(f"[RESET] Created session {session_id}", flush=True)

        task = TASKS[task_id]
        return JSONResponse(
            status_code=200,
            content={
                "session_id": session_id,
                "observation": serialize_obs(obs),
                "info": {
                    "task_id": task_id,
                    "difficulty": task.difficulty,
                    "description": task.description,
                },
            },
        )

    except ValueError as e:
        print(f"[RESET ERROR] Validation error: {e}", flush=True)
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        print(f"[RESET ERROR] {type(e).__name__}: {e}", flush=True)
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Reset failed: {str(e)}") from e


@app.post("/step")
async def step(request: StepRequest) -> JSONResponse:
    """
    Advance a session's environment by one action.

    Args:
        request: Body carrying ``session_id`` (from /reset), ``action_type``
            and an optional ``payload``; the latter two are passed to
            ``Action`` unchanged.

    Returns:
        {
            "observation": {...},
            "reward": {...},
            "done": bool,
            "info": {"step": int, "action": str}
        }

    Raises:
        HTTPException: 404 for an unknown session, 400 when the episode is
            already done, 500 for any other failure.
    """
    try:
        session = _sessions.get(request.session_id)
        if session is None:
            raise HTTPException(status_code=404, detail=f"Session not found: {request.session_id}")
        if session["done"]:
            raise HTTPException(status_code=400, detail="Episode already done. Call /reset to start new episode.")

        # Build the action first, then hand it to the session's environment.
        action = Action(action_type=request.action_type, payload=request.payload)
        outcome = session["env"].step(action)
        next_obs = outcome.observation

        # Record the new state on the session before responding.
        session["observation"] = next_obs
        session["steps"] += 1
        session["done"] = next_obs.done

        body = {
            "observation": serialize_obs(next_obs),
            "reward": serialize_reward(outcome.reward),
            "done": next_obs.done,
            "info": {
                "step": session["steps"],
                "action": request.action_type,
            },
        }
        return JSONResponse(status_code=200, content=body)

    except HTTPException:
        # Deliberate HTTP errors pass through untouched.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Step failed: {str(e)}")


@app.get("/state")
async def state_endpoint(session_id: str) -> JSONResponse:
    """
    Report the latest stored observation of a session without stepping it.

    Args:
        session_id: Session ID from /reset

    Returns:
        {
            "observation": {...},
            "info": {"task_id": ..., "steps": ..., "done": ...}
        }

    Raises:
        HTTPException: 404 for an unknown session, 500 for any other failure.
    """
    try:
        session = _sessions.get(session_id)
        if session is None:
            raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")

        body = {
            "observation": serialize_obs(session["observation"]),
            "info": {key: session[key] for key in ("task_id", "steps", "done")},
        }
        return JSONResponse(status_code=200, content=body)

    except HTTPException:
        # Deliberate HTTP errors pass through untouched.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"State query failed: {str(e)}")


@app.get("/tasks")
async def tasks_endpoint() -> JSONResponse:
    """
    Describe every entry of TASKS.

    Returns:
        {
            "tasks": [
                {
                    "id": "task_1",
                    "name": "...",
                    "difficulty": "...",
                    "description": "...",
                    "max_turns": ...
                },
                ...
            ]
        }

    Raises:
        HTTPException: 500 if building the listing fails.
    """
    try:
        listing = [
            {
                "id": key,
                "name": spec.name,
                "difficulty": spec.difficulty,
                "description": spec.description,
                "max_turns": spec.max_turns,
            }
            for key, spec in TASKS.items()
        ]
        return JSONResponse(status_code=200, content={"tasks": listing})

    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Tasks query failed: {str(e)}")


@app.post("/grade")
async def grade_endpoint(request: GradeRequest) -> JSONResponse:
    """
    Grade a session's episode in its current state.

    Fetches the environment's state via ``env.state()`` and passes it,
    together with the session's task ID, to ``grade``.

    Args:
        request: Body carrying the ``session_id`` from /reset.

    Returns:
        {
            "score": ...,
            "passed": ...,
            "breakdown": {...},
            "reason": ...
        }

    Raises:
        HTTPException: 404 for an unknown session, 500 for any other failure.
    """
    try:
        session = _sessions.get(request.session_id)
        if session is None:
            raise HTTPException(status_code=404, detail=f"Session not found: {request.session_id}")

        # Snapshot the environment, then score it for the session's task.
        snapshot = session["env"].state()
        verdict = grade(session["task_id"], snapshot)

        body = {
            "score": verdict.score,
            "passed": verdict.passed,
            "breakdown": to_json_serializable(verdict.breakdown),
            "reason": verdict.reason,
        }
        return JSONResponse(status_code=200, content=body)

    except HTTPException:
        # Deliberate HTTP errors pass through untouched.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Grading failed: {str(e)}")


@app.get("/health")
async def health() -> JSONResponse:
    """Report service identity and the number of sessions currently stored."""
    report = {
        "status": "healthy",
        "service": "CustomerSupportEnv",
        "version": "1.0.0",
        "sessions_active": len(_sessions),
    }
    return JSONResponse(status_code=200, content=report)


@app.get("/openenv.yaml")
async def openenv_spec() -> FileResponse:
    """Serve the ``openenv.yaml`` file located next to this module.

    Raises:
        HTTPException: 404 when the file does not exist.
    """
    spec_file = Path(__file__).parent.joinpath("openenv.yaml")
    if spec_file.exists():
        return FileResponse(spec_file, media_type="text/yaml")
    raise HTTPException(status_code=404, detail="openenv.yaml not found")


# ── Root endpoint ────────────────────────────────────────────────────────────
@app.get("/")
async def root() -> JSONResponse:
    """Return the service name, version and a short index of the endpoints."""
    endpoint_index = {
        "POST /reset": "Initialize new episode",
        "POST /step": "Apply action",
        "GET /state": "Get current state",
        "GET /tasks": "List tasks",
        "POST /grade": "Grade episode",
        "GET /health": "Health check",
        "GET /openenv.yaml": "Specification",
    }
    summary = {
        "service": "CustomerSupportEnv OpenEnv Server",
        "version": "1.0.0",
        "endpoints": endpoint_index,
    }
    return JSONResponse(status_code=200, content=summary)


# ── Startup/Shutdown ─────────────────────────────────────────────────────────
@app.on_event("startup")
async def startup_event():
    """Print a start-up notice to stdout, flushed immediately."""
    notice = "[INFO] CustomerSupportEnv server started"
    print(notice, flush=True)


@app.on_event("shutdown")
async def shutdown_event():
    """Print a shutdown notice to stdout, flushed immediately."""
    notice = "[INFO] CustomerSupportEnv server shutdown"
    print(notice, flush=True)


# ── Main ─────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Bind address and port are read from the HOST and PORT environment
    # variables, falling back to 0.0.0.0 and 7860.
    env_vars = os.environ
    host = env_vars.get("HOST", "0.0.0.0")
    port = int(env_vars.get("PORT", 7860))

    print(f"[INFO] Starting server on {host}:{port}", flush=True)
    uvicorn.run(app, host=host, port=port, log_level="info")