| |
| """ |
| Perchance Image-Generation Server v2.0 |
| |
| This variant adds optional pyvirtualdisplay support so the server can be |
| hosted on headless environments (Hugging Face Spaces etc.) while keeping |
| all original behaviour unchanged. |
| |
| Behaviour: |
| - If ZD_HEADLESS is True, zendriver will run headless as before. |
| - If ZD_HEADLESS is False and USE_VIRTUAL_DISPLAY is True and a DISPLAY |
| is not present, we attempt to start a pyvirtualdisplay.Display (Xvfb) |
| automatically before launching browsers. If pyvirtualdisplay is not |
| installed or starting Xvfb fails, we log a warning and continue. |
| - If USE_VIRTUAL_DISPLAY is False, we will NOT attempt to start a virtual |
| display — you must provide a DISPLAY yourself (or set ZD_HEADLESS=True) |
| if running on a headless host. |
| |
| To run on Hugging Face Spaces, add `pyvirtualdisplay` to requirements.txt |
| and ensure `xvfb` is available in the runtime (HF Spaces typically provide it). |
| |
| All original defaults and constants are preserved from the original file. |
| """ |
|
|
| import asyncio |
| import base64 |
| import json |
| import logging |
| import os |
| import random |
| import string |
| import time |
| import uuid |
| from concurrent.futures import ThreadPoolExecutor |
| from contextlib import asynccontextmanager |
| from datetime import datetime |
| from functools import partial |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional |
|
|
| import cloudscraper |
| from fastapi import FastAPI, HTTPException, Request |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import FileResponse |
| from sse_starlette.sse import EventSourceResponse |
| import zendriver as zd |
| from zendriver import cdp |
|
|
| |
| try: |
| from pyvirtualdisplay import Display |
| _HAS_PYVIRTUALDISPLAY = True |
| except Exception: |
| Display = None |
| _HAS_PYVIRTUALDISPLAY = False |
|
|
| |
| |
| |
|
|
| |
| BASE_URL = "https://image-generation.perchance.org" |
| API_GENERATE = "/api/generate" |
| API_DOWNLOAD = "/api/downloadTemporaryImage" |
| API_AWAIT = "/api/awaitExistingGenerationRequest" |
| API_ACCESS_CODE = "/api/getAccessCodeForAdPoweredStuff" |
|
|
| |
| TARGET_URL = "https://perchance.org/ai-text-to-image-generator" |
| IMAGE_GEN_ORIGIN = "https://image-generation.perchance.org" |
| ZD_TIMEOUT = 90 |
| ZD_HEADLESS = False |
| CLICK_INTERVAL = 0.35 |
| CLICK_JITTER = 8.0 |
| KEY_PREFIX = "userKey" |
|
|
| |
| |
| |
| |
| |
| USE_VIRTUAL_DISPLAY = True |
|
|
| |
| HTTP_TIMEOUT = 30 |
| MAX_DOWNLOAD_WAIT = 180 |
| BACKOFF_INIT = 0.7 |
| MAX_GEN_RETRIES = 6 |
|
|
| |
| MAX_KEY_RETRIES = 3 |
| KEY_REFRESH_COOLDOWN = 30 |
| MAX_REFRESH_FAILURES = 5 |
|
|
| |
| WORKER_COUNT = 3 |
| MAX_QUEUE_SIZE = 1000 |
| EXECUTOR_THREADS = 16 |
| OUTPUT_DIR = Path("outputs") |
| OUTPUT_DIR.mkdir(exist_ok=True, parents=True) |
|
|
|
|
| |
| |
| |
|
|
# Root logging: "<timestamp> | <LEVEL padded to 7> | <message>" at INFO.
LOG_FMT = "%(asctime)s | %(levelname)-7s | %(message)s"
log = logging.getLogger("perchance")
logging.basicConfig(format=LOG_FMT, level=logging.INFO)
|
|
|
|
| |
| |
| |
|
|
# ---------------------------------------------------------------------------
# Mutable module state
# ---------------------------------------------------------------------------

# Current Perchance userKey; None until fetched or set via /set_user_key.
USER_KEY: Optional[str] = None

# Key-coordination primitives. They are created in `lifespan` so that they
# belong to the running event loop; None until then.
_key_lock: Optional[asyncio.Lock] = None          # guards reads/writes of USER_KEY
_key_valid: Optional[asyncio.Event] = None        # cleared while a refresh runs
_key_refresh_lock: Optional[asyncio.Lock] = None  # serialises refresh attempts
_key_last_ts: float = 0.0                         # time.time() of last key set/refresh
_key_fail_count: int = 0                          # consecutive refresh failures

# Bounded job queue consumed by the workers (created in `lifespan`).
JOB_QUEUE: Optional[asyncio.Queue] = None

# Task records and their per-task SSE event queues, keyed by task id.
TASKS: Dict[str, Dict[str, Any]] = {}
TASK_QUEUES: Dict[str, asyncio.Queue] = {}

# Thread pool for blocking HTTP/file work, and the shared HTTP session.
EXECUTOR = ThreadPoolExecutor(max_workers=EXECUTOR_THREADS)
SCRAPER = cloudscraper.create_scraper()

# pyvirtualdisplay.Display instance while an Xvfb display is running.
VDISPLAY: Optional[Display] = None
|
|
|
|
| |
| |
| |
|
|
| def _safe(s: str) -> str: |
| """Sanitise string for filenames.""" |
| ok = set(string.ascii_letters + string.digits + "-_.()") |
| return "".join(c if c in ok else "_" for c in s)[:120] |
|
|
|
|
| def _sid() -> str: |
| return "".join(random.choices(string.ascii_lowercase + string.digits, k=8)) |
|
|
|
|
| def _now() -> str: |
| return datetime.utcnow().isoformat(timespec="milliseconds") + "Z" |
|
|
|
|
def _reqid() -> str:
    """Build a request id: epoch seconds with 6 decimals, a dash, and an 8-char random suffix."""
    stamp = "%.6f" % time.time()
    return stamp + "-" + _sid()
|
|
|
|
| def _stamp() -> str: |
| return datetime.utcnow().strftime("%Y%m%dT%H%M%S") |
|
|
|
|
| |
| |
| |
|
|
def _start_virtual_display_if_needed(headless: bool):
    """Start an Xvfb display through pyvirtualdisplay when a headed browser needs one.

    Nothing is started (the reason is logged) when:
      * *headless* is true,
      * USE_VIRTUAL_DISPLAY is false,
      * the DISPLAY environment variable is already set, or
      * pyvirtualdisplay could not be imported.

    Otherwise a 1280x720 display is created and stored in the module-level
    VDISPLAY. If creating or starting it raises, the error is logged and
    VDISPLAY is reset to None.

    This function is synchronous, so async callers run it in an executor.
    """
    global VDISPLAY

    if headless:
        log.info("ZD_HEADLESS=True β not starting virtual display")
        return
    if not USE_VIRTUAL_DISPLAY:
        log.info("USE_VIRTUAL_DISPLAY=False β not starting virtual display; expecting manual DISPLAY or headless mode.")
        return

    existing_display = os.environ.get("DISPLAY")
    if existing_display:
        log.info("DISPLAY already set: %s", existing_display)
        return

    if Display is None or not _HAS_PYVIRTUALDISPLAY:
        log.warning(
            "pyvirtualdisplay not installed β cannot create virtual DISPLAY. "
            "Install pyvirtualdisplay in your environment to enable Xvfb.")
        return

    try:
        VDISPLAY = Display(visible=0, size=(1280, 720))
        VDISPLAY.start()
    except Exception as exc:
        VDISPLAY = None
        log.exception("Failed to start virtual display: %s", exc)
    else:
        # Report the DISPLAY value as seen after start().
        log.info("Started virtual display via pyvirtualdisplay (DISPLAY=%s)", os.environ.get("DISPLAY"))
|
|
|
|
def _stop_virtual_display_if_needed():
    """Stop the display held in VDISPLAY, if any.

    Errors raised by ``stop()`` are logged rather than propagated. VDISPLAY is
    reset to None whenever a display was present.
    """
    global VDISPLAY
    if VDISPLAY is not None:
        try:
            VDISPLAY.stop()
        except Exception:
            log.exception("Error while stopping virtual display")
        else:
            log.info("Stopped virtual display")
        finally:
            VDISPLAY = None
|
|
|
|
| |
| |
| |
|
|
class PerchanceClient:
    """Blocking HTTP client for the Perchance image-generation API.

    All methods perform synchronous network I/O through the shared ``SCRAPER``
    session. The async server code calls them via ``run_in_executor``.
    """

    def __init__(self):
        self.base = BASE_URL.rstrip("/")
        self.s = SCRAPER
        # Browser-like headers; Origin/Referer point at the embedded generator.
        self.h = {
            "Accept": "*/*",
            "Content-Type": "application/json;charset=UTF-8",
            "Origin": IMAGE_GEN_ORIGIN,
            "Referer": f"{IMAGE_GEN_ORIGIN}/embed",
            "User-Agent": (
                "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                "AppleWebKit/537.36 (KHTML, like Gecko) "
                "Chrome/131.0.0.0 Safari/537.36"
            ),
        }

    # ------------------------------------------------------------------
    # Low-level helpers
    # ------------------------------------------------------------------

    def get_ad_code(self) -> str:
        """Fetch an ad access code (best effort).

        Returns:
            The stripped response body, or "" if the request fails for any reason.
        """
        try:
            r = self.s.get(
                f"{self.base}{API_ACCESS_CODE}",
                timeout=HTTP_TIMEOUT, headers=self.h,
            )
            r.raise_for_status()
            return r.text.strip()
        except Exception as exc:
            log.debug("Ad access code fetch failed: %s", exc)
            return ""

    def _post(self, body: dict, params: dict) -> dict:
        """POST *body* to the generate endpoint and return the decoded JSON.

        Never raises:
          * transport or HTTP errors become ``{"status": "fetch_failure", "error": ...}``;
          * an undecodable body becomes ``{"status": "invalid_json", "raw": ...}``.
        """
        try:
            r = self.s.post(
                f"{self.base}{API_GENERATE}",
                json=body, params=params,
                timeout=HTTP_TIMEOUT, headers=self.h,
            )
            r.raise_for_status()
            try:
                return r.json()
            except Exception:
                return {"status": "invalid_json", "raw": r.text}
        except Exception as exc:
            return {"status": "fetch_failure", "error": str(exc)}

    def _await_prev(self, key: str):
        """Call the await-existing-request endpoint for *key* (20 s timeout).

        Presumably this returns once the key's previous generation finishes
        (verify against the API). Errors are logged at debug level and ignored.
        """
        try:
            self.s.get(
                f"{self.base}{API_AWAIT}",
                params={"userKey": key, "__cacheBust": random.random()},
                timeout=20, headers=self.h,
            )
        except Exception as exc:
            log.debug("await-previous-request call failed: %s", exc)

    # ------------------------------------------------------------------
    # Generation
    # ------------------------------------------------------------------

    def generate_one(
        self, *,
        prompt: str,
        negative_prompt: str = "",
        seed: int = -1,
        resolution: str = "512x768",
        guidance_scale: float = 7.0,
        channel: str = "ai-text-to-image-generator",
        sub_channel: str = "private",
        user_key: str = "",
        ad_access_code: str = "",
        request_id: str = "",
    ) -> dict:
        """Request one image, retrying transient failures.

        Up to MAX_GEN_RETRIES attempts are made. Retried statuses:
          * ``waiting_for_prev_request_to_finish`` (after calling the await endpoint);
          * ``gen_failure`` with ``type == 1``;
          * missing status, ``fetch_failure``, ``invalid_json`` and ``stale_request``.
        ``invalid_ad_access_code`` triggers at most one ad-code refresh.

        Returns ONE of:
            {"imageId": ..., "seed": ...}
            {"inline": <data URL>, "seed": ...}
            {"error": "invalid_key"}      — caller must refresh the key
            {"error": "<other>", ...}
        """
        request_id = request_id or _reqid()
        params = {
            "userKey": user_key,
            "requestId": request_id,
            "adAccessCode": ad_access_code,
            "__cacheBust": random.random(),
        }
        body = {
            "prompt": prompt,
            "negativePrompt": negative_prompt,
            "seed": seed,
            "resolution": resolution,
            "guidanceScale": guidance_scale,
            "channel": channel,
            "subChannel": sub_channel,
            "userKey": user_key,
            "adAccessCode": ad_access_code,
            "requestId": request_id,
        }

        ad_refreshed = False

        for att in range(1, MAX_GEN_RETRIES + 1):
            res = self._post(body, params)
            st = res.get("status")

            if st == "success":
                iid = res.get("imageId")
                urls = res.get("imageDataUrls")
                if iid:
                    log.info("Got imageId: %s", iid)
                    return {"imageId": iid, "seed": res.get("seed")}
                if urls:
                    return {"inline": urls[0], "seed": res.get("seed")}
                log.error("success but empty payload: %s", str(res)[:300])
                return {"error": "empty_success", "raw": res}

            if st == "invalid_key":
                log.warning("Server says invalid_key")
                return {"error": "invalid_key"}

            if st == "waiting_for_prev_request_to_finish":
                log.info("Waiting for prev request to finish β¦")
                self._await_prev(user_key)
                time.sleep(0.3 + random.random() * 0.3)
                continue

            if st == "invalid_ad_access_code" and not ad_refreshed:
                code = self.get_ad_code()
                if code:
                    ad_access_code = code
                    params["adAccessCode"] = code
                    body["adAccessCode"] = code
                    ad_refreshed = True
                    log.info("Refreshed ad code β retry")
                    time.sleep(0.8)
                    continue
                return {"error": "invalid_ad_access_code"}

            if st == "gen_failure" and res.get("type") == 1:
                log.warning("gen_failure type 1 β retry after 2.5 s")
                time.sleep(2.5)
                continue

            if st in (None, "fetch_failure", "invalid_json", "stale_request"):
                log.info("Transient error (status=%s) attempt %d/%d", st, att, MAX_GEN_RETRIES)
                time.sleep(1.0)
                continue

            log.error("Unhandled status '%s': %s", st, str(res)[:300])
            return {"error": f"unhandled_{st}", "raw": res}

        return {"error": "max_retries_exceeded"}

    # ------------------------------------------------------------------
    # Download
    # ------------------------------------------------------------------

    def download_image(self, image_id: str, prefix: str = "img") -> str:
        """Poll the download endpoint until the image is ready and save it.

        The delay between polls starts at BACKOFF_INIT and grows x1.8 per
        attempt, capped at 8 s and never sleeping past the deadline. The file
        extension follows the response Content-Type (png / webp, else .jpg).

        Returns:
            Path (as ``str``) of the file written inside OUTPUT_DIR.

        Raises:
            TimeoutError: if the image is not available within MAX_DOWNLOAD_WAIT seconds.
        """
        url = f"{self.base}{API_DOWNLOAD}?imageId={image_id}"
        t0 = time.time()
        bk = BACKOFF_INIT

        while True:
            elapsed = time.time() - t0
            if elapsed >= MAX_DOWNLOAD_WAIT:
                raise TimeoutError(
                    f"Download timed out ({elapsed:.0f}s) for {image_id}"
                )

            partial_fp: Optional[Path] = None  # set while a file is being written
            try:
                r = self.s.get(url, timeout=HTTP_TIMEOUT,
                               headers=self.h, stream=True)
                # With stream=True the pooled connection is held until the body
                # is consumed or the response is closed, so close it on every
                # path, including non-200 "not ready yet" replies.
                try:
                    if r.status_code == 200:
                        ct = r.headers.get("Content-Type", "")
                        ext = (
                            ".png" if "png" in ct else
                            ".webp" if "webp" in ct else ".jpg"
                        )
                        fp = OUTPUT_DIR / _safe(f"{prefix}_{image_id[:12]}{ext}")
                        partial_fp = fp
                        with open(fp, "wb") as f:
                            for chunk in r.iter_content(8192):
                                if chunk:
                                    f.write(chunk)
                        partial_fp = None  # write completed
                        log.info("Saved β %s", str(fp))
                        return str(fp)
                finally:
                    r.close()
            except Exception as exc:
                log.debug("Download attempt for %s failed: %s", image_id, exc)
                # Do not leave a truncated file behind if the write was interrupted.
                if partial_fp is not None:
                    try:
                        partial_fp.unlink()
                    except OSError:
                        pass

            remaining = MAX_DOWNLOAD_WAIT - (time.time() - t0)
            time.sleep(max(0.0, min(bk, remaining)))
            bk = min(bk * 1.8, 8.0)


# Process-wide client instance shared by all workers.
CLIENT = PerchanceClient()
|
|
|
|
| |
| |
| |
|
|
async def _cdp_mouse(tab, typ, x, y, **kw):
    """Dispatch one CDP mouse event of type *typ* at (x, y) on *tab*.

    Coordinates are coerced to float; extra keyword arguments (button,
    click_count, buttons, pointer_type, ...) are forwarded to
    ``cdp.input_.dispatch_mouse_event``.
    """
    event = cdp.input_.dispatch_mouse_event(
        type_=typ,
        x=float(x),
        y=float(y),
        **kw,
    )
    await tab.send(event)
|
|
|
|
| async def _viewport_center(tab): |
| try: |
| v = await tab.evaluate( |
| "(()=>({w:innerWidth,h:innerHeight}))()", |
| await_promise=False, return_by_value=True, |
| ) |
| return (v["w"] / 2.0, v["h"] / 2.0) |
| except Exception: |
| return (600.0, 400.0) |
|
|
|
|
| async def _ls_get(tab, key): |
| try: |
| return await tab.evaluate( |
| f"localStorage&&localStorage.getItem({json.dumps(key)})", |
| await_promise=True, return_by_value=True, |
| ) |
| except Exception: |
| return None |
|
|
|
|
async def _clicker_loop(tab, stop: asyncio.Event):
    """Click near the centre of *tab* repeatedly until *stop* is set.

    Each round sends mouseMoved → mousePressed → mouseReleased (left button)
    at the viewport centre plus a random offset of up to ±CLICK_JITTER px,
    with short random pauses between the events. It then waits about
    CLICK_INTERVAL seconds (±15 %) or until *stop* is set. The centre is
    re-measured once more than 2.5 s have passed since the last measurement.
    CDP errors are ignored so a flaky tab never aborts the loop.
    """
    try:
        await tab.evaluate(
            "window.focus&&window.focus()",
            await_promise=False,
            return_by_value=False,
        )
    except Exception:
        pass

    centre = await _viewport_center(tab)
    measured_at = time.time()

    while not stop.is_set():
        if time.time() - measured_at > 2.5:
            centre = await _viewport_center(tab)
            measured_at = time.time()

        x = centre[0] + random.uniform(-CLICK_JITTER, CLICK_JITTER)
        y = centre[1] + random.uniform(-CLICK_JITTER, CLICK_JITTER)

        try:
            await _cdp_mouse(tab, "mouseMoved", x, y, pointer_type="mouse")
            await asyncio.sleep(random.uniform(0.02, 0.08))
            await _cdp_mouse(
                tab, "mousePressed", x, y,
                button=cdp.input_.MouseButton.LEFT, click_count=1, buttons=1,
            )
            await asyncio.sleep(random.uniform(0.03, 0.12))
            await _cdp_mouse(
                tab, "mouseReleased", x, y,
                button=cdp.input_.MouseButton.LEFT, click_count=1, buttons=0,
            )
        except Exception:
            pass

        # Sleep until the next round, waking early if stop gets set.
        pause = CLICK_INTERVAL * random.uniform(0.85, 1.15)
        try:
            await asyncio.wait_for(stop.wait(), timeout=pause)
        except asyncio.TimeoutError:
            continue
        break
|
|
|
|
async def _poll_for_key(tab, stop: asyncio.Event, max_sec: int):
    """Poll *tab*'s localStorage about every 0.25 s for a userKey.

    Each round first checks ``f"{KEY_PREFIX}-0"``, then every localStorage key
    whose name contains "userKey". Returns the first truthy value found, or
    None once *stop* is set or *max_sec* seconds have elapsed.
    """
    started = time.time()
    primary_key = f"{KEY_PREFIX}-0"

    while not stop.is_set() and time.time() - started < max_sec:
        found = await _ls_get(tab, primary_key)
        if found:
            return found

        try:
            candidates = await tab.evaluate(
                "Object.keys(localStorage||{}).filter(k=>k.includes('userKey'))",
                await_promise=False,
                return_by_value=True,
            )
            for name in candidates or []:
                found = await _ls_get(tab, name)
                if found:
                    return found
        except Exception:
            pass

        await asyncio.sleep(0.25)

    return None
|
|
|
|
async def fetch_key_via_browser(
    timeout: int = ZD_TIMEOUT,
    headless: bool = ZD_HEADLESS,
) -> Optional[str]:
    """Obtain a Perchance userKey by driving a real Chrome instance.

    Steps:
      1. ensure a display (a virtual one if configured);
      2. launch the browser and open TARGET_URL;
      3. open IMAGE_GEN_ORIGIN in a new tab;
      4. keep clicking the first tab (to trigger ad/verification) while
         polling the second tab's localStorage for the key;
      5. close the tabs and stop the browser.

    Args:
        timeout: seconds to wait for the key after both tabs are open.
        headless: forwarded to ``zd.start`` and ``_start_virtual_display_if_needed``.

    Returns:
        The key string, or None if the browser could not be started or no key
        appeared within *timeout*. An exception raised by the localStorage
        poller propagates to the caller after cleanup.
    """
    log.info(
        "Launching browser for userKey (timeout=%ds, headless=%s)",
        timeout, headless,
    )

    # A headed browser needs an X display. The helper is a no-op when
    # headless, when DISPLAY already exists, or when virtual displays are off.
    loop = asyncio.get_running_loop()
    try:
        await loop.run_in_executor(
            None, partial(_start_virtual_display_if_needed, headless),
        )
    except Exception:
        log.exception("Error while attempting to start virtual display")

    try:
        browser = await zd.start(headless=headless)
    except Exception as exc:
        log.exception("Browser start failed: %s", exc)
        return None

    stop = asyncio.Event()
    result = None

    try:
        page_tab = await browser.get(TARGET_URL)
        log.info("Opened %s", TARGET_URL)
        await asyncio.sleep(2.0)

        origin_tab = await browser.get(IMAGE_GEN_ORIGIN, new_tab=True)
        log.info("Opened %s", IMAGE_GEN_ORIGIN)
        await asyncio.sleep(1.0)

        # Bring the tab that will receive the clicks to the foreground.
        await page_tab.bring_to_front()
        await asyncio.sleep(0.5)

        clicker = asyncio.create_task(_clicker_loop(page_tab, stop))
        poller = asyncio.create_task(_poll_for_key(origin_tab, stop, timeout))

        try:
            done, _ = await asyncio.wait({poller}, timeout=timeout)
            if poller in done:
                result = poller.result()
        finally:
            stop.set()
            # Reap BOTH helper tasks. After a timeout or cancellation the
            # poller may still be pending; abandoning it while the browser
            # shuts down leaves a pending task or an unretrieved exception.
            for helper in (clicker, poller):
                if not helper.done():
                    helper.cancel()
            await asyncio.gather(clicker, poller, return_exceptions=True)

        for tab in (origin_tab, page_tab):
            try:
                await tab.close()
            except Exception:
                pass
    finally:
        try:
            await browser.stop()
        except Exception:
            pass

    if result:
        log.info("Fetched userKey (len=%d)", len(result))
    else:
        log.warning("Could not fetch userKey within %ds", timeout)
    return result
|
|
|
|
| |
| |
| |
|
|
async def _broadcast(event: dict):
    """Put *event* on the SSE queue of every task whose status is "queued" or "running".

    Queues that are full are skipped silently; other tasks are left alone.
    """
    live_states = ("queued", "running")
    for task_id, sse_queue in TASK_QUEUES.items():
        info = TASKS.get(task_id)
        if not info or info["status"] not in live_states:
            continue
        try:
            sse_queue.put_nowait(event)
        except asyncio.QueueFull:
            pass
|
|
|
|
async def refresh_user_key() -> Optional[str]:
    """Refresh the global USER_KEY via browser automation, one refresh at a time.

    Calls are serialised on ``_key_refresh_lock``. Inside the lock:

    * If a key exists and was set or refreshed less than KEY_REFRESH_COOLDOWN
      seconds ago, that key is returned without launching a browser. This lets
      callers queued behind a successful refresh reuse its result.
    * If MAX_REFRESH_FAILURES consecutive refreshes have failed, a
      ``key_refresh_failed`` event is broadcast and None is returned.
    * Otherwise ``_key_valid`` is cleared (generators waiting on it pause), a
      browser fetch runs, and ``_key_valid`` is set again whatever the outcome.

    Progress is broadcast to active SSE streams as ``key_refreshing``,
    ``key_refreshed`` and ``key_refresh_failed`` events.

    Returns:
        The new (or reused) key, or None on failure.
    """
    global USER_KEY, _key_last_ts, _key_fail_count

    async def notify(kind: str, message: str) -> None:
        await _broadcast({"type": kind, "time": _now(), "message": message})

    async with _key_refresh_lock:
        since_last = time.time() - _key_last_ts
        if USER_KEY and since_last < KEY_REFRESH_COOLDOWN:
            log.info(
                "Key was refreshed %.1fs ago β reusing existing key", since_last,
            )
            return USER_KEY

        # Auto-refresh is disabled after too many consecutive failures.
        if _key_fail_count >= MAX_REFRESH_FAILURES:
            log.error(
                "Key refresh disabled: %d consecutive failures. "
                "Set key manually via POST /set_user_key",
                _key_fail_count,
            )
            await notify(
                "key_refresh_failed",
                f"Auto-refresh disabled after {_key_fail_count} failures. "
                "Please set userKey manually via /set_user_key",
            )
            return None

        # Pause generators until this refresh finishes.
        _key_valid.clear()
        log.info("Starting userKey refresh via browser β¦")
        await notify(
            "key_refreshing",
            "UserKey expired β refreshing via browser automation β¦",
        )

        try:
            new_key = await fetch_key_via_browser(
                timeout=ZD_TIMEOUT, headless=ZD_HEADLESS,
            )

            if not new_key:
                _key_fail_count += 1
                log.error(
                    "Key refresh returned nothing (failure #%d/%d)",
                    _key_fail_count, MAX_REFRESH_FAILURES,
                )
                await notify(
                    "key_refresh_failed",
                    f"Key refresh failed (attempt {_key_fail_count}/{MAX_REFRESH_FAILURES})",
                )
                return None

            async with _key_lock:
                USER_KEY = new_key
                _key_last_ts = time.time()
                _key_fail_count = 0

            log.info("UserKey refreshed OK (len=%d)", len(new_key))
            await notify("key_refreshed", "UserKey refreshed β resuming generation.")
            return new_key

        except Exception as exc:
            _key_fail_count += 1
            log.exception(
                "Key refresh error (failure #%d/%d): %s",
                _key_fail_count, MAX_REFRESH_FAILURES, exc,
            )
            await notify("key_refresh_failed", f"Key refresh error: {exc}")
            return None

        finally:
            # Always release waiting generators, whatever the outcome.
            _key_valid.set()
|
|
|
|
| |
| |
| |
|
|
def create_task(
    prompts: List[str],
    count: int,
    resolution: str,
    guidance: float,
    negative: str,
    sub_channel: str,
) -> dict:
    """Register a new generation task and its SSE event queue.

    The task record is stored in TASKS, and a fresh unbounded asyncio.Queue in
    TASK_QUEUES, both under a new UUID4 id. The task starts in status "queued"
    with ``total_images = len(prompts) * count``.

    Returns:
        The task dict (the same object stored in TASKS).
    """
    task_id = str(uuid.uuid4())
    record = dict(
        id=task_id,
        prompts=prompts,
        count=count,
        resolution=resolution,
        guidance=guidance,
        negative=negative,
        sub_channel=sub_channel,
        created_at=_now(),
        status="queued",
        total_images=len(prompts) * count,
        completed=0,
        results=[],
        error=None,
    )
    TASKS[task_id] = record
    TASK_QUEUES[task_id] = asyncio.Queue()
    return record
|
|
|
|
| |
| |
| |
|
|
async def _save_inline(data_url: str, prompt: str) -> str:
    """Decode a base64 ``data:`` URL and write it to OUTPUT_DIR.

    Accepts a full data URL (``data:image/png;base64,<payload>``) or a bare
    base64 payload with no header. The extension is picked from the header the
    same way download_image picks it from Content-Type: png → .png,
    webp → .webp, anything else → .jpg.

    Returns:
        Path (as ``str``) of the saved file.

    Raises:
        binascii.Error: if the payload is not valid base64.
    """
    loop = asyncio.get_running_loop()
    header, sep, b64 = data_url.partition(",")
    if not sep:
        # No comma: the whole string is the payload.
        header, b64 = "", data_url
    ext = ".png" if "png" in header else ".webp" if "webp" in header else ".jpg"
    fn = _safe(f"{prompt[:30]}_{_stamp()}_{_sid()}{ext}")
    fp = OUTPUT_DIR / fn
    raw = base64.b64decode(b64)
    await loop.run_in_executor(EXECUTOR, fp.write_bytes, raw)
    log.info("Saved inline β %s", fp)
    return str(fp)
|
|
|
|
async def _download(image_id: str, prompt: str) -> str:
    """Run CLIENT.download_image for *image_id* in EXECUTOR and return the saved path.

    The filename prefix is ``<sanitised first 30 prompt chars>_<UTC stamp>_<random id>``.
    """
    prefix = "_".join((_safe(prompt[:30]), _stamp(), _sid()))
    job = partial(CLIENT.download_image, image_id, prefix)
    return await asyncio.get_running_loop().run_in_executor(EXECUTOR, job)
|
|
|
|
async def _generate_single(
    prompt: str,
    task: dict,
    idx: int,
    queue: asyncio.Queue,
    ad_code: str,
) -> Optional[str]:
    """Generate and save one image for *prompt*, reporting progress on *queue*.

    Each attempt:
      1. waits until no key refresh is in progress (``_key_valid``);
      2. snapshots USER_KEY and runs CLIENT.generate_one in EXECUTOR;
      3. on ``invalid_key``, calls refresh_user_key(), fetches a fresh ad code
         and retries (at most MAX_KEY_RETRIES attempts in total);
      4. on success, saves the inline data URL or downloads by imageId, appends
         to ``task["results"]`` and increments ``task["completed"]``.

    Events put on *queue*: ``key_invalid``, ``error``, ``gen_error``,
    ``image_ready``, ``download_error``.

    Args:
        prompt: text prompt to render.
        task: task record from create_task (``completed``/``results`` are mutated).
        idx: index of this image within the prompt's ``count``.
        queue: the task's SSE queue; if None, events are dropped.
        ad_code: ad access code to send with the request.

    Returns:
        The saved file path, or None on any failure.
    """
    loop = asyncio.get_running_loop()
    tid = task["id"]

    async def _emit(event: dict) -> None:
        # worker_loop tolerates a missing queue (`if queue:`); mirror that here
        # instead of failing with AttributeError on None.
        if queue is not None:
            await queue.put(event)

    for key_try in range(1, MAX_KEY_RETRIES + 1):
        # Block while another coroutine is refreshing the key.
        await _key_valid.wait()

        async with _key_lock:
            active_key = USER_KEY

        if not active_key:
            await _emit({
                "type": "error",
                "time": _now(),
                "task_id": tid,
                "message": "No userKey available. Set via /set_user_key",
            })
            return None

        result = await loop.run_in_executor(
            EXECUTOR,
            partial(
                CLIENT.generate_one,
                prompt=prompt,
                negative_prompt=task["negative"],
                seed=-1,
                resolution=task["resolution"],
                guidance_scale=task["guidance"],
                channel="ai-text-to-image-generator",
                sub_channel=task["sub_channel"],
                user_key=active_key,
                ad_access_code=ad_code,
                request_id=_reqid(),
            ),
        )

        if result.get("error") == "invalid_key":
            log.warning(
                "invalid_key for task %s (key_try %d/%d) β refreshing",
                tid, key_try, MAX_KEY_RETRIES,
            )
            await _emit({
                "type": "key_invalid",
                "time": _now(),
                "task_id": tid,
                "attempt": key_try,
                "max_attempts": MAX_KEY_RETRIES,
                "message": "UserKey invalid β refreshing β¦",
            })

            new_key = await refresh_user_key()
            if not new_key:
                await _emit({
                    "type": "error",
                    "time": _now(),
                    "task_id": tid,
                    "message": "Could not refresh userKey β aborting image",
                })
                return None

            # Fetch a fresh ad code before retrying with the new key.
            ad_code = await loop.run_in_executor(EXECUTOR, CLIENT.get_ad_code)
            continue

        if result.get("error"):
            log.warning(
                "Gen error task=%s prompt='%.40s': %s",
                tid, prompt, result,
            )
            await _emit({
                "type": "gen_error",
                "time": _now(),
                "task_id": tid,
                "prompt": prompt,
                "index": idx,
                "error": result,
            })
            return None

        try:
            if result.get("inline"):
                fp = await _save_inline(result["inline"], prompt)
            elif result.get("imageId"):
                fp = await _download(result["imageId"], prompt)
            else:
                log.error("Unexpected result: %s", result)
                return None

            seed = result.get("seed")
            task["completed"] += 1
            task["results"].append({
                "prompt": prompt,
                "index": idx,
                "path": fp,
                "seed": seed,
            })
            await _emit({
                "type": "image_ready",
                "time": _now(),
                "task_id": tid,
                "prompt": prompt,
                "index": idx,
                "path": fp,
                "seed": seed,
                "completed": task["completed"],
                "total": task["total_images"],
            })
            return fp

        except Exception as exc:
            log.exception("Save/download error task=%s: %s", tid, exc)
            await _emit({
                "type": "download_error",
                "time": _now(),
                "task_id": tid,
                "prompt": prompt,
                "index": idx,
                "error": str(exc),
            })
            return None

    # Only reachable when every attempt was rejected with invalid_key. Tell the
    # SSE client why this image was dropped instead of failing silently.
    log.error("Exhausted key retries for task %s prompt='%.40s'", tid, prompt)
    await _emit({
        "type": "error",
        "time": _now(),
        "task_id": tid,
        "message": f"UserKey still rejected after {MAX_KEY_RETRIES} attempts; aborting image",
    })
    return None
|
|
|
|
async def worker_loop(worker_id: int, semaphore: asyncio.Semaphore):
    """Long-running worker: take tasks from JOB_QUEUE and generate their images.

    A task's images are generated sequentially: prompt by prompt, ``count``
    images each, every one under *semaphore*. While the task runs, a heartbeat
    event is put on its SSE queue every 5 s.

    At the end the status becomes "failed" if no image succeeded (and at least
    one was requested) or on an unexpected error, otherwise "done". A final
    status event follows, then ``eof``.

    A ``None`` job is the shutdown sentinel and ends the loop.
    """
    log.info("Worker %d started", worker_id)
    loop = asyncio.get_running_loop()

    while True:
        job = await JOB_QUEUE.get()

        # Shutdown sentinel.
        if job is None:
            log.info("Worker %d shutting down", worker_id)
            JOB_QUEUE.task_done()
            break

        task = job["task"]
        tid = task["id"]
        queue = TASK_QUEUES.get(tid)

        log.info(
            "Worker %d β task %s (%d images)",
            worker_id, tid, task["total_images"],
        )
        task["status"] = "running"
        if queue:
            await queue.put({
                "type": "started",
                "time": _now(),
                "task_id": tid,
                "total_images": task["total_images"],
            })

        async def _heartbeat():
            # Periodic progress ping for SSE clients while the task runs.
            while task["status"] == "running":
                await asyncio.sleep(5.0)
                if queue and task["status"] == "running":
                    try:
                        queue.put_nowait({
                            "type": "heartbeat",
                            "time": _now(),
                            "task_id": tid,
                            "completed": task["completed"],
                            "total": task["total_images"],
                        })
                    except asyncio.QueueFull:
                        pass

        hb = asyncio.create_task(_heartbeat())

        try:
            # Fetched inside the try: if this raises (e.g. EXECUTOR already shut
            # down), the task is marked failed and still receives its final/eof
            # events. Previously the worker died with the task stuck "running".
            ad_code = await loop.run_in_executor(EXECUTOR, CLIENT.get_ad_code)

            for prompt in task["prompts"]:
                for i in range(task["count"]):
                    async with semaphore:
                        await _generate_single(
                            prompt, task, i, queue, ad_code,
                        )
                    if task["status"] == "failed":
                        break
                if task["status"] == "failed":
                    break

            if task["status"] != "failed":
                if task["completed"] == 0 and task["total_images"] > 0:
                    task["status"] = "failed"
                    task["error"] = "No images generated successfully"
                else:
                    task["status"] = "done"

            if queue:
                await queue.put({
                    "type": task["status"],
                    "time": _now(),
                    "task_id": tid,
                    "completed": task["completed"],
                    "total": task["total_images"],
                    "error": task.get("error"),
                })

        except Exception as exc:
            log.exception("Worker %d task %s crashed: %s", worker_id, tid, exc)
            task["status"] = "failed"
            task["error"] = str(exc)
            if queue:
                await queue.put({
                    "type": "failed",
                    "time": _now(),
                    "task_id": tid,
                    "error": str(exc),
                })

        finally:
            hb.cancel()
            try:
                await hb
            except asyncio.CancelledError:
                pass

            if queue:
                await queue.put({"type": "eof", "time": _now(), "task_id": tid})

        JOB_QUEUE.task_done()
        log.info(
            "Worker %d task %s finished (%s, %d/%d)",
            worker_id, tid, task["status"],
            task["completed"], task["total_images"],
        )
|
|
|
|
| |
| |
| |
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: initialise shared state, start workers, clean up on exit.

    Startup:
      * create the key lock/event/refresh-lock and the bounded JOB_QUEUE on
        the running loop;
      * start a virtual display if configured;
      * obtain an initial userKey via the browser, unless USER_KEY is already
        set or NO_INITIAL_FETCH is "1"/"true"/"True";
      * launch WORKER_COUNT worker_loop tasks.

    Shutdown runs even if an exception is thrown into the lifespan:
      * send one ``None`` sentinel per worker and wait for the workers (jobs
        queued ahead of the sentinels are still processed);
      * close the HTTP session and shut the thread pool down;
      * stop the virtual display.
    """
    global USER_KEY, _key_lock, _key_valid, _key_refresh_lock
    global _key_last_ts, JOB_QUEUE

    # Primitives must be created on the running loop.
    _key_lock = asyncio.Lock()
    _key_valid = asyncio.Event()
    _key_valid.set()  # no refresh in progress initially
    _key_refresh_lock = asyncio.Lock()
    JOB_QUEUE = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)

    loop = asyncio.get_running_loop()
    try:
        await loop.run_in_executor(None, partial(_start_virtual_display_if_needed, ZD_HEADLESS))
    except Exception:
        log.exception("Failed to ensure virtual display at startup")

    if USER_KEY:
        log.info("Using pre-fetched userKey (len=%d)", len(USER_KEY))
        _key_last_ts = time.time()
    else:
        skip = os.environ.get("NO_INITIAL_FETCH", "") in ("1", "true", "True")
        if not skip:
            try:
                key = await fetch_key_via_browser(
                    timeout=ZD_TIMEOUT, headless=ZD_HEADLESS,
                )
                if key:
                    USER_KEY = key
                    _key_last_ts = time.time()
                    log.info("Fetched userKey at startup (len=%d)", len(key))
                else:
                    log.warning(
                        "Startup key fetch failed. "
                        "Use /set_user_key or /fetch_user_key."
                    )
            except Exception as exc:
                log.exception("Startup key fetch error: %s", exc)
        else:
            log.info("NO_INITIAL_FETCH=1 β skipping browser key fetch")

    sem = asyncio.Semaphore(WORKER_COUNT)
    workers = [
        asyncio.create_task(worker_loop(i + 1, sem))
        for i in range(WORKER_COUNT)
    ]
    log.info("Launched %d workers", WORKER_COUNT)

    try:
        yield
    finally:
        log.info("Shutdown: sending stop sentinels to workers β¦")
        for _ in range(WORKER_COUNT):
            await JOB_QUEUE.put(None)
        await asyncio.gather(*workers, return_exceptions=True)

        try:
            SCRAPER.close()
        except Exception:
            pass
        # shutdown(wait=True) blocks until in-flight jobs finish; run it off
        # the event loop so the loop is not frozen meanwhile.
        await loop.run_in_executor(None, partial(EXECUTOR.shutdown, wait=True))

        try:
            await loop.run_in_executor(None, _stop_virtual_display_if_needed)
        except Exception:
            log.exception("Failed to stop virtual display cleanly")

        log.info("Shutdown complete")
|
|
|
|
| |
app = FastAPI(
    lifespan=lifespan,
    title="Perchance Image Generation Server v2 (pyvirtualdisplay)",
)

# Fully permissive CORS: any origin, method and header may call the API.
# NOTE(review): together with unauthenticated endpoints such as
# /set_user_key, this lets any web page drive the server through a visitor's
# browser. Consider restricting allow_origins when the server is public.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
|
|
|
|
| |
|
|
@app.get("/health")
async def health():
    """Liveness probe: key presence, job-queue depth and number of unfinished tasks."""
    async with _key_lock:
        key_present = USER_KEY is not None
    pending_jobs = JOB_QUEUE.qsize() if JOB_QUEUE else 0
    unfinished = len([t for t in TASKS.values() if t["status"] in ("queued", "running")])
    return {
        "status": "ok",
        "has_user_key": key_present,
        "queue_size": pending_jobs,
        "active_tasks": unfinished,
    }
|
|
|
|
@app.get("/user_key")
async def user_key_info():
    """Report whether a userKey is configured and its length (never the key itself)."""
    async with _key_lock:
        current = USER_KEY
    present = current is not None
    return {"has_user_key": present, "key_length": len(current) if present else 0}
|
|
|
|
@app.post("/set_user_key")
async def set_user_key(payload: Dict[str, str]):
    """Install a userKey supplied by the client.

    Body: ``{"userKey": "<key>"}``. Surrounding whitespace is stripped; a
    missing or empty key yields HTTP 400. Also resets the refresh-cooldown
    timestamp and the consecutive-failure counter, and sets ``_key_valid`` so
    any waiting generators resume.

    NOTE(review): this endpoint is unauthenticated.
    """
    global USER_KEY, _key_last_ts, _key_fail_count
    new_key = payload.get("userKey", "").strip()
    if not new_key:
        raise HTTPException(400, "userKey required")
    async with _key_lock:
        USER_KEY, _key_last_ts, _key_fail_count = new_key, time.time(), 0
        _key_valid.set()
    log.info("userKey set via API (len=%d)", len(new_key))
    return {"status": "ok", "key_length": len(new_key)}
|
|
|
|
# Strong references to fire-and-forget tasks. The event loop keeps only weak
# references to tasks, so an unreferenced task can be garbage-collected
# mid-execution.
_BACKGROUND_TASKS: set = set()


@app.post("/fetch_user_key")
async def fetch_user_key_endpoint():
    """Start a browser-based userKey refresh in the background and return immediately.

    The consecutive-failure counter is reset first, so a refresh is attempted
    even if auto-refresh had been disabled after MAX_REFRESH_FAILURES failures.
    refresh_user_key may still reuse the current key if it was set or refreshed
    less than KEY_REFRESH_COOLDOWN seconds ago.
    """

    async def _bg():
        global _key_fail_count
        _key_fail_count = 0
        await refresh_user_key()

    bg_task = asyncio.create_task(_bg())
    _BACKGROUND_TASKS.add(bg_task)
    bg_task.add_done_callback(_BACKGROUND_TASKS.discard)
    return {"status": "started", "note": "Browser key fetch running in background"}
|
|
|
|
@app.post("/generate")
async def submit_job(payload: Dict[str, Any]):
    """Queue a new image-generation task.

    POST /generate
    Body:
    {
      "prompts": ["a cat in space", "sunset over mountains"],   (or "prompt": "...")
      "count": 2,
      "resolution": "512x768",
      "guidance": 7.0,
      "negative": "",
      "subChannel": "private"
    }
    Returns:
    { "task_id": "...", "stream_url": "/stream/...", "queue_position": N }

    Raises:
        HTTPException(400): prompts missing, empty or not strings, or
            count/guidance not numeric.
        HTTPException(503): the job queue is full.
    """
    prompts = payload.get("prompts") or payload.get("prompt") or []
    if isinstance(prompts, str):
        prompts = [prompts]
    if not isinstance(prompts, list) or not prompts:
        raise HTTPException(400, "prompts must be a non-empty list")
    if not all(isinstance(p, str) for p in prompts):
        raise HTTPException(400, "every prompt must be a string")

    try:
        count = max(1, int(payload.get("count", 1)))
        guidance = float(payload.get("guidance", 7.0))
    except (TypeError, ValueError, OverflowError):
        # Malformed numbers are a client error, not a 500.
        raise HTTPException(400, "count must be an integer and guidance a number") from None
    resolution = payload.get("resolution", "512x768")
    negative = payload.get("negative", "") or ""
    sub_channel = payload.get("subChannel", "private")

    task = create_task(prompts, count, resolution, guidance, negative, sub_channel)

    try:
        # put_nowait, not `await put()`: put() on a bounded queue waits for a
        # free slot instead of raising QueueFull, so the 503 was unreachable
        # and the request hung.
        JOB_QUEUE.put_nowait({"task": task})
    except asyncio.QueueFull:
        # Drop the never-queued task so it does not linger as "queued" forever.
        TASKS.pop(task["id"], None)
        TASK_QUEUES.pop(task["id"], None)
        raise HTTPException(503, "Server queue full β try again later") from None

    position = JOB_QUEUE.qsize()
    q = TASK_QUEUES.get(task["id"])
    if q:
        await q.put({
            "type": "queued",
            "time": _now(),
            "task_id": task["id"],
            "queue_position": position,
            "total_images": task["total_images"],
        })

    return {
        "task_id": task["id"],
        "stream_url": f"/stream/{task['id']}",
        "queue_position": position,
    }
|
|
|
|
@app.get("/stream/{task_id}")
async def stream_task(request: Request, task_id: str):
    """Server-Sent Events stream for one task.

    Event types:
        meta · queued · started · heartbeat · image_ready
        key_invalid · key_refreshing · key_refreshed · key_refresh_failed
        gen_error · download_error · done · failed · eof · ping

    A ``meta`` event is always sent first. If the task has already finished,
    its stored results and final status are replayed, followed by ``eof``.
    Otherwise events are read from the task's queue until ``eof``. After 30 s
    without an event a ``ping`` is sent, and the stream ends if the client has
    disconnected.
    """
    task = TASKS.get(task_id)
    if task is None:
        raise HTTPException(404, "unknown task id")
    queue = TASK_QUEUES[task_id]

    def _sse(kind, payload):
        return {"event": kind, "data": json.dumps(payload)}

    async def event_gen():
        yield _sse("meta", {
            "task_id": task_id,
            "status": task["status"],
            "total_images": task["total_images"],
            "created_at": task["created_at"],
        })

        # Finished task: replay what was produced, then close.
        if task["status"] in ("done", "failed"):
            for item in task["results"]:
                yield _sse("image_ready", {
                    "task_id": task_id,
                    "prompt": item["prompt"],
                    "index": item["index"],
                    "path": item["path"],
                    "seed": item["seed"],
                    "completed": task["completed"],
                    "total": task["total_images"],
                })
            yield _sse(task["status"], {
                "task_id": task_id,
                "completed": task["completed"],
                "total": task["total_images"],
                "error": task.get("error"),
            })
            yield _sse("eof", {"task_id": task_id})
            return

        # Live task: relay queued events until eof.
        while True:
            try:
                ev = await asyncio.wait_for(queue.get(), timeout=30.0)
            except asyncio.TimeoutError:
                yield _sse("ping", {"time": _now()})
                if await request.is_disconnected():
                    log.info("SSE client disconnected (task %s)", task_id)
                    break
                continue

            yield _sse(ev.get("type", "event"), ev)
            if ev.get("type") == "eof":
                break

    return EventSourceResponse(event_gen())
|
|
|
|
@app.get("/status/{task_id}")
async def get_status(task_id: str):
    """Return the full task record (status, counters, results, error), or 404 if unknown."""
    task = TASKS.get(task_id)
    if task:
        return {"task": task}
    raise HTTPException(404, "unknown task id")
|
|
|
|
@app.get("/outputs/{filename}")
async def get_output(filename: str):
    """Serve a generated file from OUTPUT_DIR as a download.

    Only regular files located directly inside OUTPUT_DIR are served. Names
    that resolve anywhere else return 404, as do directories. This covers
    ``..``, Windows backslash separators and symlinks pointing outside.
    """
    out_dir = OUTPUT_DIR.resolve()
    fp = (OUTPUT_DIR / filename).resolve()
    if fp.parent != out_dir or not fp.is_file():
        raise HTTPException(404, "file not found")
    return FileResponse(fp, media_type="application/octet-stream", filename=filename)
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    import uvicorn

    # Environment overrides for the display-related switches. Only the exact
    # values "1", "true" and "True" enable a flag; when a variable is unset,
    # the current module value is kept.
    _enabled_values = ("1", "true", "True")
    ZD_HEADLESS = os.environ.get("ZD_HEADLESS", str(ZD_HEADLESS)) in _enabled_values
    USE_VIRTUAL_DISPLAY = (
        os.environ.get("USE_VIRTUAL_DISPLAY", str(USE_VIRTUAL_DISPLAY)) in _enabled_values
    )

    # Bring up Xvfb (if configured) before the server starts; failures are only logged.
    try:
        _start_virtual_display_if_needed(ZD_HEADLESS)
    except Exception:
        log.exception("Failed to ensure virtual display in __main__")

    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
|
|