nicopi commited on
Commit
5bd11be
·
verified ·
1 Parent(s): 931fffd

Upload 2 files

Browse files
Files changed (2) hide show
  1. models.py +54 -0
  2. server.py +221 -0
models.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import Any, Optional, Dict
3
+ from enum import Enum
4
+ import uuid
5
+ import time
6
+
7
+
8
+ class JobStatus(str, Enum):
9
+ PENDING = "pending"
10
+ CLAIMED = "claimed"
11
+ COMPLETED = "completed"
12
+ FAILED = "failed"
13
+ TIMEOUT = "timeout"
14
+
15
+
16
+ class APIJob(BaseModel):
17
+ job_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
18
+ created_at: float = Field(default_factory=time.time)
19
+ claimed_at: Optional[float] = None
20
+ completed_at: Optional[float] = None
21
+
22
+ # Request fields
23
+ method: str # GET, POST, PUT, DELETE, etc.
24
+ endpoint: str # e.g. "/api/v1/inference"
25
+ headers: Dict[str, str] = {}
26
+ body: Optional[Any] = None
27
+ query_params: Dict[str, str] = {}
28
+
29
+ # Routing: which mirror should handle this (optional, None = any mirror that has it)
30
+ target_mirror: Optional[str] = None
31
+
32
+ # Response fields
33
+ status: JobStatus = JobStatus.PENDING
34
+ response_status_code: Optional[int] = None
35
+ response_headers: Dict[str, str] = {}
36
+ response_body: Optional[Any] = None
37
+ error: Optional[str] = None
38
+
39
+ # TTL: jobs older than this (seconds) are considered timed out
40
+ ttl: float = 30.0
41
+
42
+
43
+ class ClaimRequest(BaseModel):
44
+ mirror_id: str
45
+ available_endpoints: list[str] # list of endpoint prefixes this mirror can serve
46
+
47
+
48
+ class CompleteRequest(BaseModel):
49
+ mirror_id: str
50
+ job_id: str
51
+ response_status_code: int
52
+ response_headers: Dict[str, str] = {}
53
+ response_body: Optional[Any] = None
54
+ error: Optional[str] = None
server.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Queue Server - hosted on Hugging Face Spaces
3
+ ============================================
4
+ Acts as the neutral relay between the public GUI and local mirrors.
5
+
6
+ Run with:
7
+ pip install fastapi uvicorn
8
+ uvicorn server:app --host 0.0.0.0 --port 7860
9
+
10
+ Environment variables:
11
+ QUEUE_API_KEY - shared secret for authenticating mirrors and GUI clients
12
+ JOB_TTL - seconds before a pending job is considered timed out (default 30)
13
+ POLL_INTERVAL - seconds between mirror poll cycles, informational only (default 2)
14
+ """
15
+
16
+ import asyncio
17
+ import os
18
+ import time
19
+ import logging
20
+ from contextlib import asynccontextmanager
21
+ from typing import Optional
22
+
23
+ from fastapi import FastAPI, HTTPException, Header, BackgroundTasks
24
+ from fastapi.middleware.cors import CORSMiddleware
25
+
26
+ # ---------------------------------------------------------------------------
27
+ # Import shared models (copy shared/models.py next to this file on HF Space)
28
+ # ---------------------------------------------------------------------------
29
+ from models import APIJob, ClaimRequest, CompleteRequest, JobStatus
30
+
31
+ logging.basicConfig(level=logging.INFO)
32
+ log = logging.getLogger("queue-server")
33
+
34
+ # ---------------------------------------------------------------------------
35
+ # In-memory store. Swap for Redis on Upstash for production persistence.
36
+ # ---------------------------------------------------------------------------
37
+ _jobs: dict[str, APIJob] = {}
38
+ _lock = asyncio.Lock()
39
+
40
+ API_KEY = os.environ.get("QUEUE_API_KEY", "changeme")
41
+ JOB_TTL = float(os.environ.get("JOB_TTL", 30))
42
+
43
+
44
+ # ---------------------------------------------------------------------------
45
+ # Background task: reap timed-out jobs
46
+ # ---------------------------------------------------------------------------
47
+ async def _reaper():
48
+ while True:
49
+ await asyncio.sleep(5)
50
+ now = time.time()
51
+ async with _lock:
52
+ for job in list(_jobs.values()):
53
+ if job.status in (JobStatus.PENDING, JobStatus.CLAIMED):
54
+ age = now - job.created_at
55
+ if age > job.ttl:
56
+ job.status = JobStatus.TIMEOUT
57
+ log.info(f"Job {job.job_id} timed out after {age:.1f}s")
58
+
59
+
60
+ @asynccontextmanager
61
+ async def lifespan(app: FastAPI):
62
+ task = asyncio.create_task(_reaper())
63
+ yield
64
+ task.cancel()
65
+
66
+
67
+ app = FastAPI(title="API Proxy Queue", lifespan=lifespan)
68
+ app.add_middleware(
69
+ CORSMiddleware,
70
+ allow_origins=["*"], # Tighten this in production
71
+ allow_methods=["*"],
72
+ allow_headers=["*"],
73
+ )
74
+
75
+
76
+ # ---------------------------------------------------------------------------
77
+ # Auth helper
78
+ # ---------------------------------------------------------------------------
79
+ def _check_auth(x_api_key: Optional[str]):
80
+ if x_api_key != API_KEY:
81
+ raise HTTPException(status_code=401, detail="Invalid API key")
82
+
83
+
84
+ # ---------------------------------------------------------------------------
85
+ # Routes: Client-facing (GUI / caller side)
86
+ # ---------------------------------------------------------------------------
87
+
88
+ @app.post("/jobs", response_model=APIJob, summary="Submit a new API job")
89
+ async def submit_job(
90
+ job_in: APIJob,
91
+ x_api_key: Optional[str] = Header(default=None),
92
+ ):
93
+ _check_auth(x_api_key)
94
+ job_in.ttl = JOB_TTL
95
+ job_in.status = JobStatus.PENDING
96
+ async with _lock:
97
+ _jobs[job_in.job_id] = job_in
98
+ log.info(f"Job {job_in.job_id} submitted: {job_in.method} {job_in.endpoint}")
99
+ return job_in
100
+
101
+
102
+ @app.get("/jobs/{job_id}", response_model=APIJob, summary="Poll for a job's result")
103
+ async def get_job(
104
+ job_id: str,
105
+ x_api_key: Optional[str] = Header(default=None),
106
+ ):
107
+ _check_auth(x_api_key)
108
+ async with _lock:
109
+ job = _jobs.get(job_id)
110
+ if not job:
111
+ raise HTTPException(status_code=404, detail="Job not found")
112
+ return job
113
+
114
+
115
+ @app.get("/jobs/{job_id}/wait", response_model=APIJob, summary="Long-poll until job completes or times out")
116
+ async def wait_for_job(
117
+ job_id: str,
118
+ timeout: float = 25.0,
119
+ x_api_key: Optional[str] = Header(default=None),
120
+ ):
121
+ """
122
+ Blocks up to `timeout` seconds waiting for the job to complete.
123
+ Much more efficient than client-side polling.
124
+ """
125
+ _check_auth(x_api_key)
126
+ deadline = time.time() + min(timeout, JOB_TTL)
127
+ while time.time() < deadline:
128
+ async with _lock:
129
+ job = _jobs.get(job_id)
130
+ if not job:
131
+ raise HTTPException(status_code=404, detail="Job not found")
132
+ if job.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.TIMEOUT):
133
+ return job
134
+ await asyncio.sleep(0.3)
135
+
136
+ # Return current state even if still pending
137
+ async with _lock:
138
+ return _jobs[job_id]
139
+
140
+
141
+ # ---------------------------------------------------------------------------
142
+ # Routes: Mirror-facing
143
+ # ---------------------------------------------------------------------------
144
+
145
+ @app.post("/mirror/claim", response_model=Optional[APIJob], summary="Mirror claims a pending job it can serve")
146
+ async def claim_job(
147
+ claim: ClaimRequest,
148
+ x_api_key: Optional[str] = Header(default=None),
149
+ ):
150
+ """
151
+ The mirror sends its ID and the list of endpoint prefixes it can serve.
152
+ The server atomically assigns the first matching pending job.
153
+ Returns null if nothing is available.
154
+ """
155
+ _check_auth(x_api_key)
156
+ now = time.time()
157
+ async with _lock:
158
+ for job in _jobs.values():
159
+ if job.status != JobStatus.PENDING:
160
+ continue
161
+ # Check TTL
162
+ if now - job.created_at > job.ttl:
163
+ continue
164
+ # Check target mirror constraint
165
+ if job.target_mirror and job.target_mirror != claim.mirror_id:
166
+ continue
167
+ # Check endpoint match
168
+ if not any(job.endpoint.startswith(ep) for ep in claim.available_endpoints):
169
+ continue
170
+ # Atomic claim
171
+ job.status = JobStatus.CLAIMED
172
+ job.claimed_at = now
173
+ log.info(f"Job {job.job_id} claimed by mirror '{claim.mirror_id}'")
174
+ return job
175
+ return None
176
+
177
+
178
+ @app.post("/mirror/complete", summary="Mirror posts the result of a completed job")
179
+ async def complete_job(
180
+ result: CompleteRequest,
181
+ x_api_key: Optional[str] = Header(default=None),
182
+ ):
183
+ _check_auth(x_api_key)
184
+ async with _lock:
185
+ job = _jobs.get(result.job_id)
186
+ if not job:
187
+ raise HTTPException(status_code=404, detail="Job not found")
188
+ if job.status != JobStatus.CLAIMED:
189
+ raise HTTPException(status_code=409, detail=f"Job is in state '{job.status}', cannot complete")
190
+
191
+ job.status = JobStatus.FAILED if result.error else JobStatus.COMPLETED
192
+ job.completed_at = time.time()
193
+ job.response_status_code = result.response_status_code
194
+ job.response_headers = result.response_headers
195
+ job.response_body = result.response_body
196
+ job.error = result.error
197
+
198
+ log.info(f"Job {result.job_id} completed by mirror '{result.mirror_id}' → {result.response_status_code}")
199
+ return {"ok": True}
200
+
201
+
202
+ # ---------------------------------------------------------------------------
203
+ # Debug / health
204
+ # ---------------------------------------------------------------------------
205
+
206
+ @app.get("/health")
207
+ async def health():
208
+ async with _lock:
209
+ counts = {s.value: 0 for s in JobStatus}
210
+ for job in _jobs.values():
211
+ counts[job.status.value] += 1
212
+ return {"status": "ok", "jobs": counts}
213
+
214
+
215
+ @app.get("/jobs", summary="List all jobs (debug)")
216
+ async def list_jobs(
217
+ x_api_key: Optional[str] = Header(default=None),
218
+ ):
219
+ _check_auth(x_api_key)
220
+ async with _lock:
221
+ return list(_jobs.values())