""" ╔═══════════════════════════════════════════════════════════════╗ ║ server.py — Cloudflare AI REST API ║ ║ ║ ║ OpenAI-compatible endpoints: ║ ║ POST /v1/chat/completions (streaming + non-streaming) ║ ║ GET /v1/models ║ ║ GET /health ║ ║ GET / ║ ║ ║ ║ Pool startup: up to 3 retries per slot, logs exact errors. ║ ║ Health monitor: heals dead idle slots every 60s. ║ ║ SSE: thread→asyncio bridge with backpressure. ║ ╚═══════════════════════════════════════════════════════════════╝ """ import asyncio import json import logging import os import sys import threading import time import traceback import uuid from contextlib import asynccontextmanager from typing import AsyncGenerator, List, Optional import uvicorn from fastapi import FastAPI, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from pydantic import BaseModel, Field sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from cloudflare_provider import CloudflareProvider # ═══════════════════════════════════════════════════════════ # LOGGING # ═══════════════════════════════════════════════════════════ logging.basicConfig( level = logging.INFO, format = "%(asctime)s %(levelname)-8s %(message)s", stream = sys.stdout, datefmt = "%H:%M:%S", ) log = logging.getLogger("cf-api") # ═══════════════════════════════════════════════════════════ # CONFIG # ═══════════════════════════════════════════════════════════ POOL_SIZE = int(os.getenv("POOL_SIZE", "2")) PORT = int(os.getenv("PORT", "7860")) HOST = os.getenv("HOST", "0.0.0.0") HEALTH_INTERVAL = int(os.getenv("HEALTH_INTERVAL", "60")) ACQUIRE_TIMEOUT = int(os.getenv("ACQUIRE_TIMEOUT", "90")) STREAM_TIMEOUT = int(os.getenv("STREAM_TIMEOUT", "120")) DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "@cf/moonshotai/kimi-k2.5") DEFAULT_SYSTEM = os.getenv("DEFAULT_SYSTEM", "You are a helpful assistant.") SLOT_RETRIES = int(os.getenv("SLOT_RETRIES", "3")) SLOT_RETRY_WAIT = int(os.getenv("SLOT_RETRY_WAIT", "10")) # seconds between retries # ═══════════════════════════════════════════════════════════ # PYDANTIC SCHEMAS # ═══════════════════════════════════════════════════════════ class Message(BaseModel): role: str content: str class ChatRequest(BaseModel): model: str = DEFAULT_MODEL messages: List[Message] temperature: float = Field(default=1.0, ge=0.0, le=2.0) max_tokens: Optional[int] = None stream: bool = True system: Optional[str] = None # ═══════════════════════════════════════════════════════════ # MANAGED PROVIDER SLOT # ═══════════════════════════════════════════════════════════ class ManagedProvider: def __init__(self, slot_id: int): self.slot_id = slot_id self.provider: Optional[CloudflareProvider] = None self.busy = False self.born_at = 0.0 self.error_count = 0 self.request_count = 0 self.last_error = "" def is_healthy(self) -> bool: if self.provider is None: return False try: return ( self.provider._on and self.provider._transport is not None and self.provider._transport.alive ) except Exception: return False def close(self): p = self.provider self.provider = None if p: try: p.close() except Exception: pass def __repr__(self): state = "busy" if self.busy else ("ok" if self.is_healthy() else "dead") mode = self.provider._mode if self.provider else "none" return f"" # ═══════════════════════════════════════════════════════════ # PROVIDER POOL # ═══════════════════════════════════════════════════════════ class ProviderPool: def __init__(self, size: int = 2): self.size = size self._slots: List[ManagedProvider] = [] self._queue: asyncio.Queue = None self._loop: asyncio.AbstractEventLoop = None # ─── Startup ────────────────────────────────────────── async def initialize(self): self._loop = asyncio.get_event_loop() self._queue = asyncio.Queue(maxsize=self.size) log.info(f"🚀 Initializing provider pool (slots={self.size})") log.info(f" DISPLAY={os.environ.get('DISPLAY', 'NOT SET')}") log.info(f" XVFB_EXTERNAL={os.environ.get('XVFB_EXTERNAL', '0')}") log.info(f" VR_DISPLAY={os.environ.get('VR_DISPLAY', '0')}") results = await asyncio.gather( *[self._spawn_slot_with_retry(i) for i in range(self.size)], return_exceptions=True, ) ok = sum(1 for r in results if not isinstance(r, Exception)) fail = sum(1 for r in results if isinstance(r, Exception)) if fail: for i, r in enumerate(results): if isinstance(r, Exception): log.error(f" [S{i}] FAILED: {r}") log.info(f" Pool ready — {ok}/{self.size} slots healthy") if ok == 0: raise RuntimeError( f"All {self.size} provider slots failed to connect.\n" f" → Check DISPLAY / XVFB_EXTERNAL environment variables.\n" f" → Ensure entrypoint.sh started Xvfb before the server.\n" f" → Check network connectivity to playground.ai.cloudflare.com." ) async def _spawn_slot_with_retry(self, slot_id: int) -> "ManagedProvider": """Try to create a slot, retrying up to SLOT_RETRIES times.""" managed = ManagedProvider(slot_id) for attempt in range(1, SLOT_RETRIES + 1): try: log.info(f" [S{slot_id}] Connecting... (attempt {attempt}/{SLOT_RETRIES})") def _create(): return CloudflareProvider( model = DEFAULT_MODEL, system = DEFAULT_SYSTEM, debug = True, # verbose during init so we can see failures use_cache = True, ) managed.provider = await asyncio.wait_for( self._loop.run_in_executor(None, _create), timeout=180, ) managed.provider.debug = False # quiet after successful boot managed.born_at = time.time() self._slots.append(managed) await self._queue.put(managed) mode = managed.provider._mode log.info(f" [S{slot_id}] ✓ Ready mode={mode!r}") return managed except asyncio.TimeoutError: err = f"Slot {slot_id} timed out (attempt {attempt})" log.warning(f" [S{slot_id}] ⚠ {err}") managed.last_error = err managed.close() except Exception as exc: err = str(exc) # Print full traceback for debugging log.warning( f" [S{slot_id}] ⚠ Attempt {attempt} failed:\n" + traceback.format_exc() ) managed.last_error = err managed.close() if attempt < SLOT_RETRIES: log.info(f" [S{slot_id}] Retrying in {SLOT_RETRY_WAIT}s...") await asyncio.sleep(SLOT_RETRY_WAIT) raise RuntimeError( f"Slot {slot_id} failed after {SLOT_RETRIES} attempts. " f"Last error: {managed.last_error}" ) # ─── Acquire ────────────────────────────────────────── @asynccontextmanager async def acquire(self): managed: ManagedProvider = await asyncio.wait_for( self._queue.get(), timeout=ACQUIRE_TIMEOUT, ) managed.busy = True try: if not managed.is_healthy(): log.warning(f"[S{managed.slot_id}] Unhealthy at checkout — healing now") await self._heal(managed) managed.request_count += 1 yield managed.provider except Exception: managed.error_count += 1 raise finally: managed.busy = False if managed.is_healthy(): await self._queue.put(managed) else: log.warning(f"[S{managed.slot_id}] Dead after use — background heal") asyncio.create_task(self._heal_then_return(managed)) # ─── Healing ────────────────────────────────────────── async def _heal(self, managed: ManagedProvider): sid = managed.slot_id log.info(f"[S{sid}] Healing slot...") def _recreate(): managed.close() return CloudflareProvider( model = DEFAULT_MODEL, system = DEFAULT_SYSTEM, debug = True, use_cache = True, ) try: managed.provider = await asyncio.wait_for( self._loop.run_in_executor(None, _recreate), timeout=180, ) managed.provider.debug = False managed.born_at = time.time() managed.error_count = 0 managed.last_error = "" log.info(f"[S{sid}] ✓ Healed mode={managed.provider._mode!r}") except Exception as e: managed.last_error = str(e) log.error(f"[S{sid}] ✗ Heal failed: {e}\n{traceback.format_exc()}") raise async def _heal_then_return(self, managed: ManagedProvider): sid = managed.slot_id for attempt in range(1, SLOT_RETRIES + 1): try: await self._heal(managed) await self._queue.put(managed) return except Exception as e: log.warning(f"[S{sid}] Heal attempt {attempt}/{SLOT_RETRIES} failed: {e}") if attempt < SLOT_RETRIES: await asyncio.sleep(SLOT_RETRY_WAIT) # Last resort: put it back anyway so queue doesn't shrink permanently log.error(f"[S{sid}] All heal attempts failed — slot may be non-functional") await self._queue.put(managed) # ─── Health monitor ─────────────────────────────────── async def health_monitor(self): while True: await asyncio.sleep(HEALTH_INTERVAL) healthy = sum(1 for m in self._slots if m.is_healthy()) busy = sum(1 for m in self._slots if m.busy) log.info( f"♥ Pool — {healthy}/{self.size} healthy " f"{busy} busy queue={self._queue.qsize()}" ) for managed in list(self._slots): if not managed.busy and not managed.is_healthy(): log.warning(f"[S{managed.slot_id}] Idle+dead — healing in background") asyncio.create_task(self._heal_then_return(managed)) # ─── Status ─────────────────────────────────────────── @property def status(self) -> dict: return { "pool_size": self.size, "queue_free": self._queue.qsize() if self._queue else 0, "slots": [ { "id": m.slot_id, "healthy": m.is_healthy(), "busy": m.busy, "mode": m.provider._mode if m.provider else "none", "errors": m.error_count, "requests": m.request_count, "age_s": round(time.time() - m.born_at, 1) if m.born_at else 0, "last_error": m.last_error or None, } for m in self._slots ], } # ─── Shutdown ───────────────────────────────────────── async def shutdown(self): log.info("Shutting down provider pool...") for m in self._slots: m.close() log.info("Pool shut down.") # ═══════════════════════════════════════════════════════════ # GLOBAL POOL # ═══════════════════════════════════════════════════════════ pool: ProviderPool = None # ═══════════════════════════════════════════════════════════ # LIFESPAN # ═══════════════════════════════════════════════════════════ @asynccontextmanager async def lifespan(app: FastAPI): global pool pool = ProviderPool(size=POOL_SIZE) await pool.initialize() monitor = asyncio.create_task(pool.health_monitor()) log.info(f"✅ Server ready {HOST}:{PORT}") yield monitor.cancel() try: await monitor except asyncio.CancelledError: pass await pool.shutdown() # ═══════════════════════════════════════════════════════════ # APP # ═══════════════════════════════════════════════════════════ app = FastAPI( title = "Cloudflare AI API", description = "OpenAI-compatible API via Cloudflare AI Playground", version = "1.1.0", lifespan = lifespan, docs_url = "/docs", redoc_url = "/redoc", ) app.add_middleware( CORSMiddleware, allow_origins = ["*"], allow_methods = ["*"], allow_headers = ["*"], ) # ═══════════════════════════════════════════════════════════ # SSE HELPERS # ═══════════════════════════════════════════════════════════ def _sse_chunk(content: str, model: str, cid: str) -> str: return "data: " + json.dumps({ "id": cid, "object": "chat.completion.chunk", "created": int(time.time()), "model": model, "choices": [{"index": 0, "delta": {"content": content}, "finish_reason": None}], }, ensure_ascii=False) + "\n\n" def _sse_done(model: str, cid: str) -> str: return "data: " + json.dumps({ "id": cid, "object": "chat.completion.chunk", "created": int(time.time()), "model": model, "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], }) + "\n\ndata: [DONE]\n\n" def _sse_error(msg: str) -> str: return f'data: {{"error": {json.dumps(msg)}}}\n\ndata: [DONE]\n\n' async def _stream_generator( provider: CloudflareProvider, req: ChatRequest, ) -> AsyncGenerator[str, None]: loop = asyncio.get_event_loop() q: asyncio.Queue = asyncio.Queue(maxsize=512) cid = f"chatcmpl-{uuid.uuid4().hex[:20]}" cancel = threading.Event() messages = [{"role": m.role, "content": m.content} for m in req.messages] kwargs = { "messages": messages, "temperature": req.temperature, "model": req.model, } if req.max_tokens: kwargs["max_tokens"] = req.max_tokens if req.system: kwargs["system"] = req.system def _worker(): try: for chunk in provider.chat(**kwargs): if cancel.is_set(): break fut = asyncio.run_coroutine_threadsafe(q.put(chunk), loop) fut.result(timeout=10) except Exception as exc: err = RuntimeError(str(exc)) asyncio.run_coroutine_threadsafe(q.put(err), loop).result(timeout=5) finally: asyncio.run_coroutine_threadsafe(q.put(None), loop).result(timeout=5) t = threading.Thread(target=_worker, daemon=True) t.start() try: while True: item = await asyncio.wait_for(q.get(), timeout=STREAM_TIMEOUT) if item is None: yield _sse_done(req.model, cid) break if isinstance(item, Exception): yield _sse_error(str(item)) break if item: yield _sse_chunk(item, req.model, cid) except asyncio.TimeoutError: cancel.set() yield _sse_error("Stream timed out") finally: cancel.set() t.join(timeout=5) # ═══════════════════════════════════════════════════════════ # ENDPOINTS # ═══════════════════════════════════════════════════════════ @app.get("/", tags=["Info"]) async def root(): return { "service": "Cloudflare AI API", "version": "1.1.0", "status": "running", "display": os.environ.get("DISPLAY", "not set"), "endpoints": { "chat": "POST /v1/chat/completions", "models": "GET /v1/models", "health": "GET /health", "docs": "GET /docs", }, } @app.get("/health", tags=["Info"]) async def health(): if pool is None: raise HTTPException(503, detail="Pool not initialized") healthy = sum(1 for m in pool._slots if m.is_healthy()) status = "ok" if healthy > 0 else "degraded" return JSONResponse( content={"status": status, "pool": pool.status}, status_code=200 if status == "ok" else 206, ) @app.get("/v1/models", tags=["Models"]) async def list_models(): if pool is None: raise HTTPException(503, detail="Pool not initialized") async with pool.acquire() as provider: models = await asyncio.get_event_loop().run_in_executor( None, provider.list_models ) return { "object": "list", "data": [ { "id": m["name"], "object": "model", "created": 0, "owned_by": "cloudflare", "context_window": m.get("context", 4096), } for m in models ], } @app.post("/v1/chat/completions", tags=["Chat"]) async def chat_completions(req: ChatRequest, request: Request): if pool is None: raise HTTPException(503, detail="Pool not initialized") if not req.messages: raise HTTPException(400, detail="`messages` must not be empty") # ── Streaming ────────────────────────────────────────── if req.stream: async def _gen(): async with pool.acquire() as provider: async for chunk in _stream_generator(provider, req): if await request.is_disconnected(): break yield chunk return StreamingResponse( _gen(), media_type = "text/event-stream", headers = { "Cache-Control": "no-cache", "X-Accel-Buffering": "no", "Connection": "keep-alive", }, ) # ── Non-streaming ────────────────────────────────────── messages = [{"role": m.role, "content": m.content} for m in req.messages] kwargs = { "messages": messages, "temperature": req.temperature, "model": req.model, } if req.max_tokens: kwargs["max_tokens"] = req.max_tokens if req.system: kwargs["system"] = req.system loop = asyncio.get_event_loop() full_parts: list[str] = [] async with pool.acquire() as provider: def _collect(): for chunk in provider.chat(**kwargs): full_parts.append(chunk) await asyncio.wait_for( loop.run_in_executor(None, _collect), timeout=STREAM_TIMEOUT, ) return { "id": f"chatcmpl-{uuid.uuid4().hex[:20]}", "object": "chat.completion", "created": int(time.time()), "model": req.model, "choices": [{ "index": 0, "message": {"role": "assistant", "content": "".join(full_parts)}, "finish_reason": "stop", }], "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, } # ═══════════════════════════════════════════════════════════ # ENTRY POINT # ═══════════════════════════════════════════════════════════ if __name__ == "__main__": uvicorn.run( "server:app", host = HOST, port = PORT, log_level = "info", workers = 1, loop = "asyncio", timeout_keep_alive = 30, )