| |
| from __future__ import annotations |
|
|
| import asyncio |
| import mimetypes |
| import os |
| import sys |
| import json |
| import re |
| import time |
| import uuid |
| import math |
| import logging |
| import shutil |
| from pathlib import Path |
| from dataclasses import dataclass, field |
| from typing import Any, Dict, List, Optional, Tuple, Set |
| from contextlib import asynccontextmanager |
| from starlette.websockets import WebSocketState, WebSocketDisconnect |
| try: |
| import tomllib |
| except ModuleNotFoundError: |
| import tomli as tomllib |
| import traceback |
|
|
| try: |
| from uvicorn.protocols.utils import ClientDisconnected |
| except Exception: |
| ClientDisconnected = None |
|
|
|
|
| logger = logging.getLogger(__name__) |
|
|
| import anyio |
| from fastapi import FastAPI, APIRouter, UploadFile, File, Form, HTTPException, WebSocket, WebSocketDisconnect, Request |
| from fastapi.responses import FileResponse, JSONResponse, Response |
| from fastapi.staticfiles import StaticFiles |
|
|
| from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage, AIMessage, ToolMessage |
|
|
| |
# Make the in-repo "src" package importable without a pip install; the
# open_storyline modules imported below live under <repo>/src.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
SRC_DIR = os.path.join(ROOT_DIR, "src")
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)
|
|
| from open_storyline.agent import build_agent, ClientContext |
| from open_storyline.utils.prompts import get_prompt |
| from open_storyline.utils.media_handler import scan_media_dir |
| from open_storyline.config import load_settings, default_config_path |
| from open_storyline.config import Settings |
| from open_storyline.storage.agent_memory import ArtifactStore |
| from open_storyline.mcp.hooks.node_interceptors import ToolInterceptor |
| from open_storyline.mcp.hooks.chat_middleware import set_mcp_log_sink, reset_mcp_log_sink |
|
|
# Static web assets served by the app.
WEB_DIR = os.path.join(ROOT_DIR, "web")
STATIC_DIR = os.path.join(WEB_DIR, "static")
INDEX_HTML = os.path.join(WEB_DIR, "index.html")
NODE_MAP_HTML = os.path.join(WEB_DIR, "node_map/node_map.html")
NODE_MAP_DIR = os.path.join(WEB_DIR, "node_map")

# Server-side scratch space (caches, etc.).
SERVER_CACHE_DIR = os.path.join(ROOT_DIR, '.storyline' , ".server_cache")

# Read size used when streaming uploads to disk (1 MiB).
CHUNK_SIZE = 1024 * 1024

# When True, each session stores media under its own sub-directory.
USE_SESSION_SUBDIR = True

# Sentinel model key meaning "user supplied a custom model config".
CUSTOM_MODEL_KEY = "__custom__"

# Default LLM/VLM endpoints resolved from the environment at import time.
DEFAULT_LLM_API_KEY = os.getenv("DEEPSEEK_API_KEY")
DEFAULT_LLM_API_URL = os.getenv("DEEPSEEK_API_URL")
DEFAULT_LLM_API_NAME = os.getenv("DEEPSEEK_API_NAME", "deepseek-chat")
DEFAULT_VLM_API_KEY = os.getenv("GLM_V4_6_API_KEY")
DEFAULT_VLM_API_URL = os.getenv("GLM_V4_6_API_URL")
DEFAULT_VLM_API_NAME = os.getenv("GLM_V4_6_API_NAME", "qwen3-vl-8b-instruct")
# NOTE(review): the VLM defaults above read GLM_V4_6_* variables while the
# debug prints below check QWEN3_VL_8B_* — confirm which set is canonical.
print("DEEPSEEK_API_KEY exists:", bool(os.getenv("DEEPSEEK_API_KEY")))
print("QWEN3_VL_8B_API_KEY exists:", bool(os.getenv("QWEN3_VL_8B_API_KEY")))
print("DEEPSEEK_API_URL:", repr(os.getenv("DEEPSEEK_API_URL")))
print("QWEN3_VL_8B_API_URL:", repr(os.getenv("QWEN3_VL_8B_API_URL")))
|
|
def debug_traceback_print(cfg: Settings):
    """Print the active exception's traceback, but only in developer mode."""
    if not cfg.developer.developer_mode:
        return
    traceback.print_exc()
|
|
| def _s(x: Any) -> str: |
| return str(x or "").strip() |
|
|
| def _norm_url(u: Any) -> str: |
| u = _s(u) |
| return u.rstrip("/") if u else "" |
|
|
| def _env_fallback_for_model(model_name: str) -> Tuple[str, str]: |
| """ |
| - deepseek* -> DEEPSEEK_API_URL / DEEPSEEK_API_KEY |
| - qwen3* -> QWEN3_VL_8B_API_URL / QWEN3_VL_8B_API_KEY |
| """ |
| m = _s(model_name).lower() |
| if "deepseek" in m: |
| return (_s(os.getenv("DEEPSEEK_API_URL")), _s(os.getenv("DEEPSEEK_API_KEY"))) |
| if m.startswith("qwen3-vl-8b-instruct") or "qwen3-vl-8b-instruct" in m: |
| return (_s(os.getenv("QWEN3_VL_8B_API_URL")), _s(os.getenv("QWEN3_VL_8B_API_KEY"))) |
| return ("", "") |
|
|
def _resolve_default_model_override(cfg: Settings, model_name: str) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
    """
    Build the override dict for a default (non-custom) model.

    Resolution order:
    1. [developer.chat_models_config."<model_name>"] in config.toml
    2. environment-variable fallback (see _env_fallback_for_model)

    Returns (override, None) on success or (None, error_message) when no
    usable base_url/api_key pair could be found.
    """
    model_name = _s(model_name)
    if not model_name:
        return None, "default model name is empty"

    # Per-model block from config.toml; tolerate any shape/access failure.
    model_cfg: Dict[str, Any] = {}
    try:
        model_cfg = (cfg.developer.chat_models_config.get(model_name) or {}) if getattr(cfg, "developer", None) else {}
    except Exception:
        model_cfg = {}

    if not isinstance(model_cfg, dict):
        model_cfg = {}

    base_url = _norm_url(model_cfg.get("base_url"))
    api_key = _s(model_cfg.get("api_key"))

    # Fill whichever of base_url/api_key the config did not supply from env.
    if not base_url or not api_key:
        env_url, env_key = _env_fallback_for_model(model_name)
        if not base_url:
            base_url = _norm_url(env_url)
        if not api_key:
            api_key = _s(env_key)

    override: Dict[str, Any] = {"model": model_name}
    if base_url:
        override["base_url"] = base_url
    if api_key:
        override["api_key"] = api_key

    # Pass through optional tuning knobs only when explicitly set.
    for k in ("timeout", "temperature", "max_retries", "top_p", "max_tokens"):
        if k in model_cfg and model_cfg.get(k) not in (None, ""):
            override[k] = model_cfg.get(k)

    if not override.get("base_url") or not override.get("api_key"):
        return None, (
            f"cannot find base_url/api_key of default model: {model_name}. "
            f"please fill in base_url/api_key of [developer.chat_models_config.\"{model_name}\" in config.toml]"
            f"or set environment variables(DEEPSEEK_API_URL/DEEPSEEK_API_KEY / QWEN3_VL_8B_API_URL/QWEN3_VL_8B_API_KEY)。"
        )

    return override, None
|
|
| def _stable_dict_key(d: Optional[Dict[str, Any]]) -> str: |
| try: |
| return json.dumps(d or {}, sort_keys=True, ensure_ascii=False) |
| except Exception: |
| return str(d or {}) |
|
|
def _parse_service_config(service_cfg: Any) -> Tuple[
        Optional[Dict[str, Any]],
        Optional[Dict[str, Any]],
        Dict[str, Any],
        Dict[str, Any],
        Optional[str]]:
    """
    Parse the client-supplied service_config payload.

    Returns (custom_llm, custom_vlm, tts_cfg, pexels_cfg, err):
    - custom_llm / custom_vlm: {"model","base_url","api_key"} or None
      (supplying only llm or only vlm is allowed)
    - tts_cfg: dict, possibly empty
    - pexels_cfg: {"mode","api_key"}, possibly empty
    - err: human-readable validation error, or None on success
    """
    if not isinstance(service_cfg, dict):
        return None, None, {}, {}, None

    custom_llm = None
    custom_vlm = None
    custom_models = service_cfg.get("custom_models")

    if custom_models is not None:
        if not isinstance(custom_models, dict):
            return None, None, {}, {}, "service_config.custom_models 必须是对象"

        def _pick(m: Any, label: str) -> Tuple[Optional[Dict[str, str]], Optional[str]]:
            # Validate one model entry; None means "not provided".
            if m is None:
                return None, None
            if not isinstance(m, dict):
                return None, f"service_config.custom_models.{label} 必须是对象"

            model = _s(m.get("model"))
            base_url = _norm_url(m.get("base_url"))
            api_key = _s(m.get("api_key"))

            if not (model and base_url and api_key):
                return None, f"自定义 {label.upper()} 配置不完整:请填写 model/base_url/api_key"
            if not (base_url.startswith("http://") or base_url.startswith("https://")):
                return None, f"自定义 {label.upper()} 的 base_url 必须以 http(s) 开头"
            return {"model": model, "base_url": base_url, "api_key": api_key}, None

        custom_llm, err1 = _pick(custom_models.get("llm"), "llm")
        if err1:
            return None, None, {}, {}, err1

        custom_vlm, err2 = _pick(custom_models.get("vlm"), "vlm")
        if err2:
            return None, None, {}, {}, err2

    # Optional TTS block: keyed by provider name.
    tts_cfg: Dict[str, Any] = {}
    tts = service_cfg.get("tts")
    if isinstance(tts, dict):
        provider = (tts.get("provider") or "").strip().lower()
        if provider:
            provider_block = tts.get(provider)
            tts_cfg = {"provider": provider, provider: provider_block}

    # Optional Pexels search config. The original repeated each lookup twice
    # as `x.get(k) or x.get(k)`; those duplicates are removed here.
    pexels_cfg: Dict[str, Any] = {}
    search_media = service_cfg.get("search_media")
    if isinstance(search_media, dict):
        p = search_media.get("pexels")
        if isinstance(p, dict):
            mode = _s(p.get("mode")).lower()
            if mode not in ("default", "custom"):
                mode = "default"
            api_key = _s(p.get("api_key") or p.get("pexels_api_key"))
            pexels_cfg = {"mode": mode, "api_key": api_key}
        else:
            # Flat (legacy) layout without a nested "pexels" object.
            mode = _s(search_media.get("mode") or search_media.get("pexels_mode")).lower()
            if mode not in ("default", "custom"):
                mode = "default"
            api_key = _s(search_media.get("pexels_api_key"))
            pexels_cfg = {"mode": mode, "api_key": api_key}

    return custom_llm, custom_vlm, tts_cfg, pexels_cfg, None
|
|
def is_developer_mode(cfg: Settings) -> bool:
    """Safely report whether developer mode is enabled on *cfg*."""
    try:
        enabled = bool(cfg.developer.developer_mode)
    except Exception:
        enabled = False
    return enabled
|
|
| def _abs(p: str) -> str: |
| return os.path.abspath(os.path.expanduser(p)) |
|
|
|
|
def resolve_media_dir(cfg_media_dir: str, session_id: str) -> str:
    """
    Compute the media directory for one session.

    With USE_SESSION_SUBDIR enabled the session id is inserted between the
    configured media dir's parent and its leaf: <parent>/<session_id>/<leaf>.
    """
    root = _abs(cfg_media_dir).rstrip("/\\")
    if not USE_SESSION_SUBDIR:
        return root
    parent = os.path.dirname(root)
    leaf = os.path.basename(root)
    return os.path.join(parent, session_id, leaf)
|
|
|
|
def sanitize_filename(name: str) -> str:
    """Strip directory components and NUL bytes; never return an empty name."""
    base = os.path.basename(name or "").replace("\x00", "")
    return base if base else "unnamed"
|
|
|
|
def detect_media_kind(filename: str) -> str:
    """Classify *filename* as "image", "video", or "unknown" by extension."""
    ext = os.path.splitext(filename)[1].lower()
    kinds = {
        "image": {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"},
        "video": {".mp4", ".mov", ".avi", ".mkv", ".webm"},
    }
    for kind, extensions in kinds.items():
        if ext in extensions:
            return kind
    return "unknown"
|
|
| _MEDIA_RE = re.compile(r"^media_(\d+)", re.IGNORECASE) |
|
|
def make_media_store_filename(seq: int, ext: str) -> str:
    """Build a zero-padded store filename like "media_0001.jpg"."""
    suffix = (ext or "").lower()
    if suffix and not suffix.startswith("."):
        suffix = "." + suffix
    return f"{MEDIA_PREFIX}{seq:0{MEDIA_SEQ_WIDTH}d}{suffix}"
|
|
def parse_media_seq(filename: str) -> Optional[int]:
    """Extract the numeric sequence from a "media_NNNN..." filename, else None."""
    base = os.path.basename(filename or "")
    match = re.match(r"^media_(\d+)", base, re.IGNORECASE)
    if match is None:
        return None
    try:
        return int(match.group(1))
    except Exception:
        return None
|
|
def safe_save_path_no_overwrite(media_dir: str, filename: str) -> str:
    """
    Return a path under *media_dir* for *filename* that does not clobber an
    existing file, appending " (2)", " (3)", ... before the extension.
    """
    base = os.path.basename(filename or "").replace("\x00", "") or "unnamed"
    stem, ext = os.path.splitext(base)
    candidate = os.path.join(media_dir, base)
    suffix = 2
    while os.path.exists(candidate):
        candidate = os.path.join(media_dir, f"{stem} ({suffix}){ext}")
        suffix += 1
    return candidate
|
|
|
|
def ensure_thumbs_dir(media_dir: str) -> str:
    """Create (if needed) and return the ".thumbs" subdir of *media_dir*."""
    thumbs = os.path.join(media_dir, ".thumbs")
    os.makedirs(thumbs, exist_ok=True)
    return thumbs
|
|
def ensure_uploads_dir(media_dir: str) -> str:
    """Create (if needed) and return the ".uploads" subdir of *media_dir*."""
    uploads = os.path.join(media_dir, ".uploads")
    os.makedirs(uploads, exist_ok=True)
    return uploads
|
|
def guess_media_type(path: str) -> str:
    """MIME type guessed from *path*, defaulting to application/octet-stream."""
    guessed, _encoding = mimetypes.guess_type(path)
    if guessed is None:
        return "application/octet-stream"
    return guessed
|
|
|
|
| def _is_under_dir(path: str, root: str) -> bool: |
| try: |
| path = os.path.abspath(path) |
| root = os.path.abspath(root) |
| return os.path.commonpath([path, root]) == root |
| except Exception: |
| return False |
|
|
|
|
def video_placeholder_svg_bytes() -> bytes:
    """UTF-8 SVG used as a thumbnail placeholder for videos (play triangle)."""
    svg = """<svg xmlns="http://www.w3.org/2000/svg" width="320" height="320" viewBox="0 0 320 320">
<defs>
<linearGradient id="g" x1="0" x2="1" y1="0" y2="1">
<stop stop-color="#f2f2f2" offset="0"/>
<stop stop-color="#e6e6e6" offset="1"/>
</linearGradient>
</defs>
<rect x="0" y="0" width="320" height="320" fill="url(#g)"/>
<rect x="22" y="22" width="276" height="276" rx="22" fill="rgba(0,0,0,0.06)"/>
<polygon points="140,120 140,200 210,160" fill="rgba(0,0,0,0.55)"/>
</svg>"""
    return svg.encode("utf-8")
|
|
|
|
def make_image_thumbnail_sync(src_path: str, dst_path: str, max_size: Tuple[int, int] = (320, 320)) -> bool:
    """
    Write a JPEG thumbnail of *src_path* to *dst_path*.

    Returns False on any failure, including Pillow not being installed.
    """
    try:
        from PIL import Image
        image = Image.open(src_path).convert("RGB")
        image.thumbnail(max_size)
        image.save(dst_path, format="JPEG", quality=85)
    except Exception:
        return False
    return True
|
|
async def make_video_thumbnail_async(
    src_video: str,
    dst_path: str,
    *,
    max_size: Tuple[int, int] = (320, 320),
    seek_sec: float = 0.5,
    timeout_sec: float = 20.0,
) -> bool:
    """
    Extract a single-frame JPEG thumbnail from *src_video* via ffmpeg.

    Tries several seek strategies in order, writing to a temp file first and
    atomically renaming on success. Returns False if ffmpeg is unavailable,
    every attempt fails, or an attempt exceeds *timeout_sec*.
    """
    ffmpeg = os.environ.get("FFMPEG_BIN") or shutil.which("ffmpeg")
    if not ffmpeg:
        logger.warning("ffmpeg not found (PATH/FFMPEG_BIN). skip video thumbnail. src=%s", src_video)
        return False

    src_video = os.path.abspath(src_video)
    dst_path = os.path.abspath(dst_path)
    os.makedirs(os.path.dirname(dst_path), exist_ok=True)

    # Work file; only promoted to dst_path via os.replace on success.
    tmp_path = dst_path + ".tmp.jpg"

    # Scale to fit within max_size, then pad to the exact size (centered).
    vf = (
        f"scale={max_size[0]}:{max_size[1]}:force_original_aspect_ratio=decrease"
        f",pad={max_size[0]}:{max_size[1]}:(ow-iw)/2:(oh-ih)/2"
    )

    async def _run(args: list[str]) -> tuple[bool, str]:
        # Run one ffmpeg invocation; kill it if it exceeds timeout_sec.
        proc = await asyncio.create_subprocess_exec(
            *args,
            stdout=asyncio.subprocess.DEVNULL,
            stderr=asyncio.subprocess.PIPE,
        )
        try:
            _, err = await asyncio.wait_for(proc.communicate(), timeout=timeout_sec)
        except asyncio.TimeoutError:
            try:
                proc.kill()
            except Exception:
                pass
            await proc.wait()
            return False, f"timeout after {timeout_sec}s"
        err_text = (err or b"").decode("utf-8", "ignore").strip()
        return (proc.returncode == 0), err_text

    # Shared output options: one video frame, no audio, mjpeg into tmp_path.
    common_tail = [
        "-an",
        "-frames:v", "1",
        "-vf", vf,
        "-vcodec", "mjpeg",
        "-q:v", "3",
        "-f", "image2",
        tmp_path,
    ]

    attempts = [
        # Fast input-side seek.
        [ffmpeg, "-hide_banner", "-loglevel", "error", "-y", "-ss", f"{seek_sec}", "-i", src_video] + common_tail,
        # Accurate output-side seek (slower; helps some containers).
        [ffmpeg, "-hide_banner", "-loglevel", "error", "-y", "-i", src_video, "-ss", f"{seek_sec}"] + common_tail,
        # Last resort: fixed 1.0s input-side seek.
        [ffmpeg, "-hide_banner", "-loglevel", "error", "-y", "-ss", "1.0", "-i", src_video] + common_tail,
    ]

    last_err: Optional[str] = None
    try:
        for args in attempts:
            ok, err = await _run(args)
            if ok and os.path.exists(tmp_path) and os.path.getsize(tmp_path) > 0:
                os.replace(tmp_path, dst_path)
                return True
            last_err = err or last_err
            # Drop any partial frame before trying the next strategy.
            try:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
            except Exception:
                pass

        logger.warning("ffmpeg thumbnail failed. src=%s dst=%s err=%s", src_video, dst_path, last_err)
        return False
    finally:
        # Best-effort cleanup of any leftover temp file.
        try:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        except Exception:
            pass
|
|
| def _env_int(name: str, default: int) -> int: |
| try: |
| return int(os.environ.get(name, str(default))) |
| except Exception: |
| return default |
|
|
| def _env_float(name: str, default: float) -> float: |
| try: |
| return float(os.environ.get(name, str(default))) |
| except Exception: |
| return float(default) |
|
|
| def _rpm_to_rps(rpm: float) -> float: |
| return float(rpm) / 60.0 |
|
|
|
|
| |
| RATE_LIMIT_TRUST_PROXY_HEADERS = os.environ.get("RATE_LIMIT_TRUST_PROXY_HEADERS", "0") == "1" |
|
|
@dataclass
class _RateBucket:
    # Current token balance for one rate-limit key.
    tokens: float
    # Monotonic timestamp of the last refill computation.
    last_ts: float
    # Monotonic timestamp of the last access (used for TTL cleanup).
    last_seen: float
|
|
class TokenBucketRateLimiter:
    """
    In-memory token bucket with memory-bloat protection:
    - max_buckets: cap on the internal bucket table (guards against huge
      numbers of distinct keys/IPs inflating the dict)
    - evict_batch: how many buckets to evict per overflow, in insertion
      order (oldest-created first)
    """
    def __init__(
        self,
        ttl_sec: int = 900,
        cleanup_interval_sec: int = 60,
        *,
        max_buckets: int = 100000,
        evict_batch: int = 2000,
    ):
        self.ttl_sec = int(ttl_sec)
        self.cleanup_interval_sec = int(cleanup_interval_sec)
        self.max_buckets = int(max(1, max_buckets))
        self.evict_batch = int(max(1, evict_batch))

        self._buckets: Dict[str, _RateBucket] = {}
        self._lock = asyncio.Lock()
        self._last_cleanup = time.monotonic()

    async def allow(
        self,
        key: str,
        *,
        capacity: float,
        refill_rate: float,
        cost: float = 1.0,
    ) -> Tuple[bool, float, float]:
        """
        Returns: (allowed, retry_after_sec, remaining_tokens)
        """
        now = time.monotonic()
        # Clamp inputs so negative values cannot corrupt bucket math.
        capacity = float(max(0.0, capacity))
        refill_rate = float(max(0.0, refill_rate))
        cost = float(max(0.0, cost))

        async with self._lock:
            b = self._buckets.get(key)

            if b is None:
                # Amortize TTL cleanup onto new-key creation.
                if now - self._last_cleanup > self.cleanup_interval_sec:
                    self._cleanup_locked(now)
                    self._last_cleanup = now

                # Table full: try TTL cleanup first...
                if len(self._buckets) >= self.max_buckets:
                    self._cleanup_locked(now)

                # ...then force-evict the oldest buckets...
                if len(self._buckets) >= self.max_buckets:
                    self._evict_locked()

                # ...and if the table is somehow still full, fail closed
                # with a short retry hint rather than grow unboundedly.
                if len(self._buckets) >= self.max_buckets:
                    return False, 1.0, 0.0

                # New buckets start full.
                b = _RateBucket(tokens=capacity, last_ts=now, last_seen=now)
                self._buckets[key] = b
            else:
                b.last_seen = now

            # Refill proportionally to elapsed time, clamped to capacity.
            elapsed = max(0.0, now - b.last_ts)
            if refill_rate > 0:
                b.tokens = min(capacity, b.tokens + elapsed * refill_rate)
            else:
                b.tokens = min(capacity, b.tokens)
            b.last_ts = now

            if b.tokens >= cost:
                b.tokens -= cost
                return True, 0.0, float(max(0.0, b.tokens))

            # Not enough tokens: estimate when `cost` will be available.
            if refill_rate <= 0:
                retry_after = float(self.ttl_sec)
            else:
                need = cost - b.tokens
                retry_after = need / refill_rate
            return False, float(retry_after), float(max(0.0, b.tokens))

    def _cleanup_locked(self, now: float) -> None:
        # Drop buckets idle longer than the TTL. Caller must hold self._lock.
        ttl = float(self.ttl_sec)
        dead = [k for k, b in self._buckets.items() if (now - b.last_seen) > ttl]
        for k in dead:
            self._buckets.pop(k, None)

    def _evict_locked(self) -> None:
        # Evict up to evict_batch buckets in insertion order (dict preserves
        # it, so the earliest-created buckets go first). Caller holds _lock.
        n = min(self.evict_batch, len(self._buckets))
        for _ in range(n):
            try:
                k = next(iter(self._buckets))
            except StopIteration:
                break
            self._buckets.pop(k, None)
|
|
| def _headers_to_dict(scope_headers: List[Tuple[bytes, bytes]]) -> Dict[str, str]: |
| d: Dict[str, str] = {} |
| for k, v in scope_headers or []: |
| try: |
| dk = k.decode("latin1").lower() |
| dv = v.decode("latin1") |
| except Exception: |
| continue |
| d[dk] = dv |
| return d |
|
|
def _client_ip_from_http_scope(scope: dict, trust_proxy_headers: bool) -> str:
    """
    Best-effort client IP for an ASGI HTTP scope.

    Proxy headers (X-Forwarded-For / X-Real-IP) are consulted only when
    *trust_proxy_headers* is set; otherwise the transport peer address is used.
    """
    headers = _headers_to_dict(scope.get("headers") or [])
    if trust_proxy_headers:
        forwarded = headers.get("x-forwarded-for")
        if forwarded:
            # First entry in the X-Forwarded-For chain is the original client.
            return forwarded.split(",")[0].strip() or "unknown"
        real_ip = headers.get("x-real-ip")
        if real_ip:
            return real_ip.strip() or "unknown"

    client = scope.get("client")
    if client and isinstance(client, (list, tuple)) and len(client) >= 1:
        return str(client[0] or "unknown")
    return "unknown"
|
|
def _client_ip_from_ws(ws: WebSocket, trust_proxy_headers: bool) -> str:
    """Best-effort client IP for a WebSocket; mirrors the HTTP-scope variant."""
    try:
        if trust_proxy_headers:
            forwarded = ws.headers.get("x-forwarded-for")
            if forwarded:
                return forwarded.split(",")[0].strip() or "unknown"
            real_ip = ws.headers.get("x-real-ip")
            if real_ip:
                return real_ip.strip() or "unknown"
    except Exception:
        pass

    try:
        if ws.client:
            return str(ws.client.host or "unknown")
    except Exception:
        pass

    return "unknown"
|
|
| |
# Chunk size (bytes) clients must use for resumable (chunked) uploads.
UPLOAD_RESUMABLE_CHUNK_BYTES = _env_int("UPLOAD_RESUMABLE_CHUNK_BYTES", 8 * 1024 * 1024)

# Idle resumable uploads are garbage-collected after this many seconds.
RESUMABLE_UPLOAD_TTL_SEC = _env_int("RESUMABLE_UPLOAD_TTL_SEC", 3600)

# Stored media filenames look like "media_0001.<ext>".
MEDIA_SEQ_WIDTH = 4
MEDIA_PREFIX = "media_"

# ---------------------------------------------------------------------------
# Rate-limit knobs. Convention: *_RPM is the bucket refill rate (requests per
# minute) and *_BURST the bucket capacity. "*_ALL_*" variants are server-wide
# caps shared by all clients; the rest are applied per client IP.
# ---------------------------------------------------------------------------

# Per-IP limit across all HTTP endpoints.
HTTP_GLOBAL_RPM = _env_int("RATE_LIMIT_HTTP_GLOBAL_RPM", 3000)
HTTP_GLOBAL_BURST = _env_int("RATE_LIMIT_HTTP_GLOBAL_BURST", 600)

# Per-IP limit for POST /api/sessions.
HTTP_CREATE_SESSION_RPM = _env_int("RATE_LIMIT_CREATE_SESSION_RPM", 3000)
HTTP_CREATE_SESSION_BURST = _env_int("RATE_LIMIT_CREATE_SESSION_BURST", 50)

# Per-IP limit for media upload endpoints (byte-weighted; see middleware).
HTTP_UPLOAD_MEDIA_RPM = _env_int("RATE_LIMIT_UPLOAD_MEDIA_RPM", 12000)
HTTP_UPLOAD_MEDIA_BURST = _env_int("RATE_LIMIT_UPLOAD_MEDIA_BURST", 300)

# Upload requests cost 1 token per this many bytes of Content-Length.
UPLOAD_COST_BYTES = _env_int("RATE_LIMIT_UPLOAD_COST_BYTES", 10 * 1024 * 1024)

# Hard caps on upload / media counts.
MAX_UPLOAD_FILES_PER_REQUEST = _env_int("MAX_UPLOAD_FILES_PER_REQUEST", 30)
MAX_MEDIA_PER_SESSION = _env_int("MAX_MEDIA_PER_SESSION", 30)
MAX_PENDING_MEDIA_PER_SESSION = _env_int("MAX_PENDING_MEDIA_PER_SESSION", 30)

HTTP_UPLOAD_MEDIA_COUNT_RPM = _env_int("RATE_LIMIT_UPLOAD_MEDIA_COUNT_RPM", 50000)
HTTP_UPLOAD_MEDIA_COUNT_BURST = _env_int("RATE_LIMIT_UPLOAD_MEDIA_COUNT_BURST", 1000)

# Per-IP limit for media/thumbnail downloads.
HTTP_MEDIA_GET_RPM = _env_int("RATE_LIMIT_MEDIA_GET_RPM", 2400)
HTTP_MEDIA_GET_BURST = _env_int("RATE_LIMIT_MEDIA_GET_BURST", 60)

# Per-IP limit for clearing a session.
HTTP_CLEAR_RPM = _env_int("RATE_LIMIT_CLEAR_SESSION_RPM", 3000)
HTTP_CLEAR_BURST = _env_int("RATE_LIMIT_CLEAR_SESSION_BURST", 50)

# Per-IP limit for all other /api/ routes.
HTTP_API_RPM = _env_int("RATE_LIMIT_API_RPM", 2400)
HTTP_API_BURST = _env_int("RATE_LIMIT_API_BURST", 120)

# Per-IP WebSocket connect limit.
WS_CONNECT_RPM = _env_int("RATE_LIMIT_WS_CONNECT_RPM", 600)
WS_CONNECT_BURST = _env_int("RATE_LIMIT_WS_CONNECT_BURST", 50)

# Per-IP chat message send limit.
WS_CHAT_SEND_RPM = _env_int("RATE_LIMIT_WS_CHAT_SEND_RPM", 300)
WS_CHAT_SEND_BURST = _env_int("RATE_LIMIT_WS_CHAT_SEND_BURST", 20)

# Server-wide limits (shared across all clients).
HTTP_ALL_RPM = _env_int("RATE_LIMIT_HTTP_ALL_RPM", 1200)
HTTP_ALL_BURST = _env_int("RATE_LIMIT_HTTP_ALL_BURST", 200)

CREATE_SESSION_ALL_RPM = _env_int("RATE_LIMIT_CREATE_SESSION_ALL_RPM", 120)
CREATE_SESSION_ALL_BURST = _env_int("RATE_LIMIT_CREATE_SESSION_ALL_BURST", 20)

UPLOAD_MEDIA_ALL_RPM = _env_int("RATE_LIMIT_UPLOAD_MEDIA_ALL_RPM", 6000)
UPLOAD_MEDIA_ALL_BURST = _env_int("RATE_LIMIT_UPLOAD_MEDIA_ALL_BURST", 2000)

# Defaults mirror the byte-cost upload limits above.
UPLOAD_MEDIA_COUNT_ALL_RPM = _env_int("RATE_LIMIT_UPLOAD_MEDIA_COUNT_ALL_RPM", UPLOAD_MEDIA_ALL_RPM)
UPLOAD_MEDIA_COUNT_ALL_BURST = _env_int("RATE_LIMIT_UPLOAD_MEDIA_COUNT_ALL_BURST", UPLOAD_MEDIA_ALL_BURST)

MEDIA_GET_ALL_RPM = _env_int("RATE_LIMIT_MEDIA_GET_ALL_RPM", 600)
MEDIA_GET_ALL_BURST = _env_int("RATE_LIMIT_MEDIA_GET_ALL_BURST", 120)

WS_CONNECT_ALL_RPM = _env_int("RATE_LIMIT_WS_CONNECT_ALL_RPM", 60000)
WS_CONNECT_ALL_BURST = _env_int("RATE_LIMIT_WS_CONNECT_ALL_BURST", 2000)

WS_CHAT_SEND_ALL_RPM = _env_int("RATE_LIMIT_WS_CHAT_SEND_ALL_RPM", 500)
WS_CHAT_SEND_ALL_BURST = _env_int("RATE_LIMIT_WS_CHAT_SEND_ALL_BURST", 30)

# Concurrency ceilings, enforced with the semaphores below.
WS_MAX_CONNECTIONS = _env_int("RATE_LIMIT_WS_MAX_CONNECTIONS", 500)
CHAT_MAX_CONCURRENCY = _env_int("RATE_LIMIT_CHAT_MAX_CONCURRENCY", 80)
UPLOAD_MAX_CONCURRENCY = _env_int("RATE_LIMIT_UPLOAD_MAX_CONCURRENCY", 100)

WS_CONN_SEM = asyncio.Semaphore(WS_MAX_CONNECTIONS)
CHAT_TURN_SEM = asyncio.Semaphore(CHAT_MAX_CONCURRENCY)
UPLOAD_SEM = asyncio.Semaphore(UPLOAD_MAX_CONCURRENCY)
|
|
def _global_http_rule_limit(rule_name: str) -> Optional[Tuple[int, int]]:
    """Server-wide (burst, rpm) cap for a named HTTP rule, or None."""
    limits = {
        "create_session": (CREATE_SESSION_ALL_BURST, CREATE_SESSION_ALL_RPM),
        "upload_media": (UPLOAD_MEDIA_ALL_BURST, UPLOAD_MEDIA_ALL_RPM),
        "media_get": (MEDIA_GET_ALL_BURST, MEDIA_GET_ALL_RPM),
    }
    return limits.get(rule_name)
|
|
|
|
def _get_content_length(scope: dict) -> Optional[int]:
    """Parse Content-Length from an ASGI scope; None if absent or invalid."""
    try:
        value = _headers_to_dict(scope.get("headers") or []).get("content-length")
        if value is None:
            return None
        length = int(value)
    except Exception:
        return None
    # Negative values are malformed.
    return length if length >= 0 else None
|
|
def _match_http_rule(method: str, path: str) -> Tuple[str, int, int, float]:
    """
    Map a request to a rate-limit rule.

    Returns (rule_name, burst, rpm, cost). cost defaults to 1; the upload
    endpoints get a content-length-based cost computed inside the middleware.
    An empty rule_name means "no per-rule limit applies".
    """
    method = (method or "").upper()
    path = path or ""

    # Session creation.
    if method == "POST" and path == "/api/sessions":
        return ("create_session", HTTP_CREATE_SESSION_BURST, HTTP_CREATE_SESSION_RPM, 1.0)

    # Media uploads: single-shot and resumable (init/chunk/complete/cancel).
    if method == "POST" and path.startswith("/api/sessions/"):
        if path.endswith("/media") or path.endswith("/media/init"):
            return ("upload_media", HTTP_UPLOAD_MEDIA_BURST, HTTP_UPLOAD_MEDIA_RPM, 1.0)
        if "/media/" in path and (path.endswith("/chunk") or path.endswith("/complete") or path.endswith("/cancel")):
            return ("upload_media", HTTP_UPLOAD_MEDIA_BURST, HTTP_UPLOAD_MEDIA_RPM, 1.0)

    # Media and thumbnail downloads.
    if method == "GET" and path.startswith("/api/sessions/") and (path.endswith("/thumb") or path.endswith("/file")):
        return ("media_get", HTTP_MEDIA_GET_BURST, HTTP_MEDIA_GET_RPM, 1.0)

    # Session clearing.
    if method == "POST" and path.startswith("/api/sessions/") and path.endswith("/clear"):
        return ("clear_session", HTTP_CLEAR_BURST, HTTP_CLEAR_RPM, 1.0)

    # Any other /api/ route.
    if path.startswith("/api/"):
        return ("api_general", HTTP_API_BURST, HTTP_API_RPM, 1.0)

    # Static assets etc.: no per-rule limit.
    return ("", 0, 0, 1.0)
|
|
class HttpRateLimitMiddleware:
    """
    ASGI middleware that rate-limits HTTP requests (WebSockets are handled
    elsewhere). Checks run in order: server-wide, per-IP global, then the
    matched per-endpoint rule (server-wide and per-IP). Rejections short-
    circuit with a 429 response.
    """
    def __init__(self, app: Any, limiter: TokenBucketRateLimiter, trust_proxy_headers: bool = False):
        self.app = app
        self.limiter = limiter
        self.trust_proxy_headers = bool(trust_proxy_headers)

    async def __call__(self, scope: dict, receive: Any, send: Any):
        # Pass non-HTTP traffic (websocket, lifespan) straight through.
        if scope.get("type") != "http":
            return await self.app(scope, receive, send)

        method = scope.get("method", "GET")
        path = scope.get("path", "/")
        ip = _client_ip_from_http_scope(scope, self.trust_proxy_headers)

        # 1) Server-wide limit shared by all clients.
        ok, retry_after, _ = await self.limiter.allow(
            key="http:all",
            capacity=float(HTTP_ALL_BURST),
            refill_rate=_rpm_to_rps(float(HTTP_ALL_RPM)),
            cost=1.0,
        )
        if not ok:
            return await self._reject(send, retry_after)

        # 2) Per-IP limit across all endpoints.
        ok, retry_after, _ = await self.limiter.allow(
            key=f"http:global:{ip}",
            capacity=float(HTTP_GLOBAL_BURST),
            refill_rate=_rpm_to_rps(float(HTTP_GLOBAL_RPM)),
            cost=1.0,
        )
        if not ok:
            return await self._reject(send, retry_after)

        # 3) Endpoint-specific rule, if any.
        rule_name, burst, rpm, cost = _match_http_rule(method, path)

        # Upload cost scales with request size: 1 token per UPLOAD_COST_BYTES.
        if rule_name == "upload_media":
            cl = _get_content_length(scope)
            if cl and cl > 0 and UPLOAD_COST_BYTES > 0:
                cost = max(1.0, float(math.ceil(cl / float(UPLOAD_COST_BYTES))))

        if rule_name:
            # Server-wide cap for this rule, when configured.
            g = _global_http_rule_limit(rule_name)
            if g:
                g_burst, g_rpm = g
                okg, rag, _ = await self.limiter.allow(
                    key=f"http:{rule_name}:all",
                    capacity=float(g_burst),
                    refill_rate=_rpm_to_rps(float(g_rpm)),
                    cost=float(cost),
                )
                if not okg:
                    return await self._reject(send, rag)

            # Per-IP cap for this rule.
            ok2, retry_after2, _ = await self.limiter.allow(
                key=f"http:{rule_name}:{ip}",
                capacity=float(burst),
                refill_rate=_rpm_to_rps(float(rpm)),
                cost=float(cost),
            )
            if not ok2:
                return await self._reject(send, retry_after2)

        return await self.app(scope, receive, send)

    async def _reject(self, send: Any, retry_after: float):
        # Emit a 429 with a Retry-After hint, bypassing the wrapped app.
        ra = int(math.ceil(float(retry_after or 0.0)))
        body = json.dumps(
            {"detail": "Too Many Requests", "retry_after": ra},
            ensure_ascii=False
        ).encode("utf-8")

        headers = [
            (b"content-type", b"application/json; charset=utf-8"),
            (b"retry-after", str(ra).encode("ascii")),
        ]

        await send({"type": "http.response.start", "status": 429, "headers": headers})
        await send({"type": "http.response.body", "body": body, "more_body": False})
|
|
# Single process-wide limiter shared by the HTTP middleware and WS handlers.
RATE_LIMITER = TokenBucketRateLimiter(
    ttl_sec=_env_int("RATE_LIMIT_TTL_SEC", 900),
    cleanup_interval_sec=_env_int("RATE_LIMIT_CLEANUP_INTERVAL_SEC", 60),
    max_buckets=_env_int("RATE_LIMIT_MAX_BUCKETS", 100000),
    evict_batch=_env_int("RATE_LIMIT_EVICT_BATCH", 2000),
)
|
|
|
|
@dataclass
class MediaMeta:
    # Short random id used to address this media item.
    id: str
    # Display name shown in the UI (derived from the original filename).
    name: str
    # "image" | "video" | "unknown" (see detect_media_kind).
    kind: str
    # Absolute path of the stored file.
    path: str
    # Absolute thumbnail path, or None when no thumbnail exists.
    thumb_path: Optional[str]
    # Creation time (unix seconds).
    ts: float
|
|
@dataclass
class ResumableUpload:
    # Server-issued id for this chunked upload session.
    upload_id: str
    # Original (display) filename.
    filename: str
    # Final on-disk name (e.g. "media_0001.mp4").
    store_filename: str
    # Total expected size in bytes.
    size: int
    # Chunk size the client must use.
    chunk_size: int
    # Expected number of chunks.
    total_chunks: int
    # Temp file the chunks are written into before finalization.
    tmp_path: str
    # "image" | "video" | "unknown".
    kind: str
    # Creation timestamp (unix seconds).
    created_ts: float
    # Last-activity timestamp (unix seconds), used for TTL expiry.
    last_ts: float
    # Indices of chunks received so far.
    received: Set[int] = field(default_factory=set)
    # True once the upload was completed or cancelled.
    closed: bool = False
    # Serializes chunk writes for this upload.
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
|
class MediaStore:
    """
    Filesystem layer only:
    - save uploaded files (streamed in async chunks)
    - generate thumbnails (images: worker thread; videos: async subprocess)
    - delete files (restricted to files under media_dir)
    """
    def __init__(self, media_dir: str):
        self.media_dir = os.path.abspath(media_dir)
        os.makedirs(self.media_dir, exist_ok=True)
        self.thumbs_dir = ensure_thumbs_dir(self.media_dir)

    async def save_upload(self, uf: UploadFile, *, store_filename: str, display_name: str) -> MediaMeta:
        """
        Stream *uf* to media_dir as *store_filename* and build a thumbnail.

        Raises HTTPException(409) when the target filename already exists.
        """
        media_id = uuid.uuid4().hex[:10]

        display_name = sanitize_filename(display_name or uf.filename or "unnamed")
        store_filename = sanitize_filename(store_filename)

        # Media kind is derived from the display name's extension.
        kind = detect_media_kind(display_name)

        save_path = os.path.join(self.media_dir, store_filename)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        if os.path.exists(save_path):
            raise HTTPException(status_code=409, detail=f"media filename exists: {store_filename}")

        # Stream to disk without loading the whole upload into memory.
        async with await anyio.open_file(save_path, "wb") as out:
            while True:
                chunk = await uf.read(CHUNK_SIZE)
                if not chunk:
                    break
                await out.write(chunk)

        try:
            await uf.close()
        except Exception:
            pass

        thumb_path: Optional[str] = None
        if kind in ("image", "video"):
            thumb_path = os.path.join(self.thumbs_dir, f"{media_id}.jpg")

            # Images thumbnail in a worker thread; videos via ffmpeg subprocess.
            if kind == "image":
                ok = await anyio.to_thread.run_sync(make_image_thumbnail_sync, save_path, thumb_path)
            else:
                ok = await make_video_thumbnail_async(save_path, thumb_path)

            if not ok:
                # Fall back to the original image; videos get no thumbnail.
                thumb_path = save_path if kind == "image" else None

        return MediaMeta(
            id=media_id,
            name=os.path.basename(display_name),
            kind=kind,
            path=os.path.abspath(save_path),
            thumb_path=os.path.abspath(thumb_path) if thumb_path else None,
            ts=time.time(),
        )

    async def save_from_path(
        self,
        src_path: str,
        *,
        store_filename: str,
        display_name: str,
    ) -> MediaMeta:
        """
        Move a chunked-upload temp file to its final location under media_dir.
        - display_name: UI display name (the original filename)
        - store_filename: on-disk name (media_0001.mp4), records upload order
        """
        media_id = uuid.uuid4().hex[:10]

        display_name = sanitize_filename(display_name or "unnamed")
        store_filename = sanitize_filename(store_filename or "unnamed")

        kind = detect_media_kind(display_name)

        src_path = os.path.abspath(src_path)
        if not os.path.exists(src_path):
            raise HTTPException(status_code=400, detail="upload temp file missing")

        save_path = os.path.abspath(os.path.join(self.media_dir, store_filename))
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        if os.path.exists(save_path):
            raise HTTPException(status_code=409, detail=f"media already exists: {store_filename}")

        # Atomic move of the fully-assembled temp file into place.
        os.replace(src_path, save_path)

        thumb_path: Optional[str] = None
        if kind in ("image", "video"):
            thumb_path = os.path.join(self.thumbs_dir, f"{media_id}.jpg")

            if kind == "image":
                ok = await anyio.to_thread.run_sync(make_image_thumbnail_sync, save_path, thumb_path)
            else:
                ok = await make_video_thumbnail_async(save_path, thumb_path)

            if not ok:
                thumb_path = save_path if kind == "image" else None

        return MediaMeta(
            id=media_id,
            name=os.path.basename(display_name),
            kind=kind,
            path=os.path.abspath(save_path),
            thumb_path=os.path.abspath(thumb_path) if thumb_path else None,
            ts=time.time(),
        )

    async def delete_files(self, meta: MediaMeta) -> None:
        # Remove the media file and its thumbnail, but only when they live
        # under media_dir (guards against traversal via stored metadata).
        root = self.media_dir
        for p in {meta.path, meta.thumb_path}:
            if not p:
                continue
            ap = os.path.abspath(p)
            if not _is_under_dir(ap, root):
                continue
            if os.path.isdir(ap):
                continue
            if os.path.exists(ap):
                try:
                    os.remove(ap)
                except Exception:
                    pass
|
|
|
|
class ChatSession:
    """
    All state for one session:
    - agent / lc_messages (LangChain conversation context)
    - history (replay log sent to the frontend)
    - load_media / pending_media (upload staging)
    - tool trace index (lets tool events update history records in place)
    """
    def __init__(self, session_id: str, cfg: Settings):
        self.session_id = session_id
        self.cfg = cfg
        # UI/prompt language; defaults to "zh", switched via the websocket.
        self.lang = "zh"

        # Default model keys come from cfg.developer when present.
        default_llm = _s(getattr(getattr(cfg, "developer", None), "default_llm", "")) or "deepseek-chat"
        default_vlm = _s(getattr(getattr(cfg, "developer", None), "default_vlm", "")) or "qwen3-vl-8b-instruct"

        self.chat_models = [default_llm, CUSTOM_MODEL_KEY]
        self.chat_model_key = default_llm

        self.vlm_models = [default_vlm, CUSTOM_MODEL_KEY]
        self.vlm_model_key = default_vlm

        self.developer_mode = is_developer_mode(cfg)

        # Per-session media directory, its store, and the resumable-upload staging dir.
        self.media_dir = resolve_media_dir(cfg.project.media_dir, session_id)
        self.media_store = MediaStore(self.media_dir)

        self.uploads_dir = ensure_uploads_dir(self.media_dir)
        self.resumable_uploads: Dict[str, ResumableUpload] = {}

        # Number of direct (non-resumable) uploads currently in flight; counted
        # against the session caps in _check_media_caps_locked.
        self._direct_upload_reservations = 0

        self.agent: Any = None
        self.node_manager = None
        self.client_context = None

        # chat_lock serializes LLM turns; media_lock guards all media/upload state.
        self.chat_lock = asyncio.Lock()
        self.media_lock = asyncio.Lock()

        self.sent_media_total: int = 0
        # Index into lc_messages of the "media upload status" system message.
        self._attach_stats_msg_idx = 1

        self.lc_messages: List[BaseMessage] = [
            SystemMessage(content=get_prompt("instruction.system", lang=self.lang)),
            SystemMessage(content="【User media upload status】{}"),
        ]
        self.history: List[Dict[str, Any]] = []

        # All uploaded media by id, and the subset not yet attached to a message.
        self.load_media: Dict[str, MediaMeta] = {}
        self.pending_media_ids: List[str] = []

        # tool_call_id -> index into self.history, for in-place tool-event updates.
        self._tool_history_index: Dict[str, int] = {}

        # Set by the cancel endpoint; observed by the websocket streaming loop.
        self.cancel_event = asyncio.Event()

        # Custom model/TTS configs supplied by the client (None/empty until set).
        self.custom_llm_config: Optional[Dict[str, Any]] = None
        self.custom_vlm_config: Optional[Dict[str, Any]] = None
        self.tts_config: Dict[str, Any] = {}
        # Cache key of the last agent build; the agent is rebuilt when it changes.
        self._agent_build_key: Optional[Tuple[Any, ...]] = None

        self.pexels_key_mode: str = "default"
        self.pexels_custom_key: str = ""

        # Lazy media filename sequence (media_0001.ext ...); see _init_media_seq_locked.
        self._media_seq_inited = False
        self._media_seq_next = 1

    def _ensure_system_prompt(self) -> None:
        """Insert the system prompt at the front of lc_messages if it is missing."""
        # NOTE(review): the local name "sys" shadows the imported sys module
        # inside this function (harmless here, but worth renaming).
        sys = (get_prompt("instruction.system", lang=self.lang) or "").strip()
        if not sys:
            return

        # Already present (compared by stripped content) — nothing to do.
        for m in self.lc_messages:
            if isinstance(m, SystemMessage) and (getattr(m, "content", "") or "").strip() == sys:
                return

        self.lc_messages.insert(0, SystemMessage(content=sys))

    def _init_media_seq_locked(self) -> None:
        """
        Initialize self._media_seq_next (caller must hold media_lock):
        - numbering keeps increasing after a chat clear, so old files are never overwritten
        """
        if self._media_seq_inited:
            return

        max_seq = 0

        # Files already on disk.
        try:
            for fn in os.listdir(self.media_dir):
                s = parse_media_seq(fn)
                if s is not None:
                    max_seq = max(max_seq, s)
        except Exception:
            pass

        # Media already registered in memory.
        for meta in (self.load_media or {}).values():
            s = parse_media_seq(os.path.basename(meta.path or ""))
            if s is not None:
                max_seq = max(max_seq, s)

        # Filenames reserved by in-flight resumable uploads.
        for u in (self.resumable_uploads or {}).values():
            s = parse_media_seq(getattr(u, "store_filename", "") or "")
            if s is not None:
                max_seq = max(max_seq, s)

        self._media_seq_next = max_seq + 1
        self._media_seq_inited = True

    def _reserve_store_filenames_locked(self, display_filenames: List[str]) -> List[str]:
        """
        Generate store filenames (media_0001.ext ...) in the order given
        (caller must hold media_lock).
        Note: this input order is exactly the upload order being frozen
        into the sequence numbers.
        """
        self._init_media_seq_locked()

        out: List[str] = []
        seq = int(self._media_seq_next)

        for disp in display_filenames:
            disp = sanitize_filename(disp or "unnamed")
            ext = os.path.splitext(disp)[1].lower()

            # Skip any sequence number whose file already exists on disk.
            while True:
                store = make_media_store_filename(seq, ext)
                if not os.path.exists(os.path.join(self.media_dir, store)):
                    break
                seq += 1

            out.append(store)
            seq += 1

        self._media_seq_next = seq
        return out

    def apply_service_config(self, service_cfg: Any) -> Tuple[bool, Optional[str]]:
        """Apply a client-supplied service config; returns (ok, error_message)."""
        llm, vlm, tts, pexels, err = _parse_service_config(service_cfg)
        if err:
            return False, err

        if llm is not None:
            self.custom_llm_config = llm
        if vlm is not None:
            self.custom_vlm_config = vlm

        # Only replace the TTS config with a non-empty dict.
        if isinstance(tts, dict) and tts:
            self.tts_config = tts

        # Pexels key: "custom" uses the user's key, anything else resets to default.
        if isinstance(pexels, dict) and pexels:
            mode = _s(pexels.get("mode")).lower()
            if mode == "custom":
                self.pexels_key_mode = "custom"
                self.pexels_custom_key = _s(pexels.get("api_key"))
            else:
                self.pexels_key_mode = "default"
                self.pexels_custom_key = ""

        return True, None

    async def ensure_agent(self) -> None:
        """Build (or rebuild) the agent and client context for the current model selection.

        Raises RuntimeError when a custom model is selected but not configured,
        or when a named default model cannot be resolved.
        """
        # Resolve the LLM override: custom config or a named default.
        if self.chat_model_key == CUSTOM_MODEL_KEY:
            if not isinstance(self.custom_llm_config, dict):
                raise RuntimeError("please fill in model/base_url/api_key of custom LLM")
            llm_override = self.custom_llm_config
        else:
            llm_override, err = _resolve_default_model_override(self.cfg, self.chat_model_key)
            if err:
                raise RuntimeError(err)

        # Same resolution for the VLM override.
        if self.vlm_model_key == CUSTOM_MODEL_KEY:
            if not isinstance(self.custom_vlm_config, dict):
                raise RuntimeError("please fill in model/base_url/api_key of custom VLM")
            vlm_override = self.custom_vlm_config
        else:
            vlm_override, err = _resolve_default_model_override(self.cfg, self.vlm_model_key)
            if err:
                raise RuntimeError(err)

        agent_build_key: Tuple[Any, ...] = (
            "models",
            _stable_dict_key(llm_override),
            _stable_dict_key(vlm_override),
        )

        # Rebuild only when the model selection actually changed.
        if self.agent is None or self._agent_build_key != agent_build_key:
            artifact_store = ArtifactStore(self.cfg.project.outputs_dir, session_id=self.session_id)
            self.agent, self.node_manager = await build_agent(
                cfg=self.cfg,
                session_id=self.session_id,
                store=artifact_store,
                tool_interceptors=[
                    ToolInterceptor.inject_media_content_before,
                    ToolInterceptor.save_media_content_after,
                    ToolInterceptor.inject_tts_config,
                    ToolInterceptor.inject_pexels_api_key,
                ],
                llm_override=llm_override,
                vlm_override=vlm_override,
            )
            self._agent_build_key = agent_build_key

        if self.client_context is None:
            self.client_context = ClientContext(
                cfg=self.cfg,
                session_id=self.session_id,
                media_dir=self.media_dir,
                bgm_dir=self.cfg.project.bgm_dir,
                outputs_dir=self.cfg.project.outputs_dir,
                node_manager=self.node_manager,
                chat_model_key=self.chat_model_key,
                vlm_model_key=self.vlm_model_key,
                tts_config=(self.tts_config or None),
                pexels_api_key=None,
                lang=self.lang,
            )
        else:
            # Keep the existing context in sync with the latest selections.
            self.client_context.chat_model_key = self.chat_model_key
            self.client_context.vlm_model_key = self.vlm_model_key
            self.client_context.tts_config = (self.tts_config or None)
            self.client_context.lang = self.lang

        # Resolve the effective Pexels API key for this context.
        pexels_api_key = ""
        if (self.pexels_key_mode or "").lower() == "custom":
            pexels_api_key = _s(self.pexels_custom_key)
        else:
            pexels_api_key = _get_default_pexels_api_key(self.cfg)

        self.client_context.pexels_api_key = (pexels_api_key or None)

    def public_media(self, meta: MediaMeta) -> Dict[str, Any]:
        """Frontend-safe view of one media item (URLs instead of server paths)."""
        return {
            "id": meta.id,
            "name": meta.name,
            "kind": meta.kind,
            "thumb_url": f"/api/sessions/{self.session_id}/media/{meta.id}/thumb",
            "file_url": f"/api/sessions/{self.session_id}/media/{meta.id}/file",
        }

    def public_pending_media(self) -> List[Dict[str, Any]]:
        """Pending (not yet attached) media in order, as frontend-safe dicts."""
        out: List[Dict[str, Any]] = []
        for aid in self.pending_media_ids:
            meta = self.load_media.get(aid)
            if meta:
                out.append(self.public_media(meta))
        return out

    def snapshot(self) -> Dict[str, Any]:
        """Full session state as sent to the frontend on (re)connect."""
        return {
            "session_id": self.session_id,
            "developer_mode": self.developer_mode,
            "pending_media": self.public_pending_media(),
            "history": self.history,
            "limits": {
                "max_upload_files_per_request": MAX_UPLOAD_FILES_PER_REQUEST,
                "max_media_per_session": MAX_MEDIA_PER_SESSION,
                "max_pending_media_per_session": MAX_PENDING_MEDIA_PER_SESSION,
                "upload_chunk_bytes": UPLOAD_RESUMABLE_CHUNK_BYTES,
            },
            "stats": {
                "media_count": len(self.load_media),
                "pending_count": len(self.pending_media_ids),
                "inflight_uploads": len(self.resumable_uploads),
            },
            "chat_model_key": self.chat_model_key,
            "chat_models": self.chat_models,
            # NOTE(review): llm_* mirrors chat_* — presumably legacy aliases for
            # older frontends; confirm with the web client before removing.
            "llm_model_key": self.chat_model_key,
            "llm_models": self.chat_models,
            "vlm_model_key": self.vlm_model_key,
            "vlm_models": self.vlm_models,
            "lang": self.lang,
        }

    def _cleanup_stale_uploads_locked(self, now: Optional[float] = None) -> None:
        """Drop resumable uploads idle past the TTL and delete their temp files
        (caller must hold media_lock)."""
        now = float(now or time.time())
        ttl = float(RESUMABLE_UPLOAD_TTL_SEC)
        dead = [uid for uid, u in self.resumable_uploads.items() if (now - u.last_ts) > ttl]
        for uid in dead:
            u = self.resumable_uploads.pop(uid, None)
            if not u:
                continue
            try:
                if u.tmp_path and os.path.exists(u.tmp_path):
                    os.remove(u.tmp_path)
            except Exception:
                pass

    def _check_media_caps_locked(self, add: int = 0) -> None:
        """Raise HTTP 400 when adding `add` items would exceed the session caps
        (caller must hold media_lock)."""
        add = int(max(0, add))
        # In-flight resumable uploads and direct-upload reservations both count
        # against the caps so concurrent requests cannot overshoot them.
        total = len(self.load_media) + len(self.resumable_uploads) + int(self._direct_upload_reservations)
        pending = len(self.pending_media_ids) + len(self.resumable_uploads) + int(self._direct_upload_reservations)

        if MAX_MEDIA_PER_SESSION > 0 and (total + add) > MAX_MEDIA_PER_SESSION:
            raise HTTPException(
                status_code=400,
                detail=f"会话素材总数已达上限:{total}/{MAX_MEDIA_PER_SESSION}",
            )

        if MAX_PENDING_MEDIA_PER_SESSION > 0 and (pending + add) > MAX_PENDING_MEDIA_PER_SESSION:
            raise HTTPException(
                status_code=400,
                detail=f"待发送素材数量已达上限:{pending}/{MAX_PENDING_MEDIA_PER_SESSION}",
            )

    async def add_uploads(self, files: List[UploadFile], store_filenames: List[str]) -> List[MediaMeta]:
        """Persist direct uploads under pre-reserved store filenames and stage
        them as pending media."""
        if len(store_filenames) != len(files):
            raise HTTPException(status_code=500, detail="store_filenames mismatch")

        metas: List[MediaMeta] = []
        for uf, store_fn in zip(files, store_filenames):
            display_name = sanitize_filename(uf.filename or "unnamed")
            metas.append(await self.media_store.save_upload(
                uf,
                store_filename=store_fn,
                display_name=display_name,
            ))

        async with self.media_lock:
            for m in metas:
                self.load_media[m.id] = m
                self.pending_media_ids.append(m.id)

            # Keep pending order stable by store filename (i.e. upload sequence).
            self.pending_media_ids.sort(
                key=lambda aid: os.path.basename(self.load_media[aid].path or "")
                if aid in self.load_media else ""
            )

        return metas

    async def delete_pending_media(self, media_id: str) -> None:
        """Remove a still-pending media item from the session and delete its files.

        Raises HTTP 400 when the item is not pending (already sent media is
        never physically deleted).
        """
        async with self.media_lock:
            if media_id not in self.pending_media_ids:
                raise HTTPException(status_code=400, detail="media is not pending (refuse physical delete)")
            self.pending_media_ids = [x for x in self.pending_media_ids if x != media_id]
            meta = self.load_media.pop(media_id, None)

        if meta:
            await self.media_store.delete_files(meta)

    async def take_pending_media_for_message(self, attachment_ids: Optional[List[str]]) -> List[MediaMeta]:
        """Consume pending media for an outgoing message.

        With attachment_ids, only those (still-pending) items are taken;
        otherwise every pending item is consumed.
        """
        async with self.media_lock:
            if attachment_ids:
                pick = [aid for aid in attachment_ids if aid in self.pending_media_ids]
            else:
                pick = list(self.pending_media_ids)

            pick_set = set(pick)
            self.pending_media_ids = [aid for aid in self.pending_media_ids if aid not in pick_set]
            metas = [self.load_media[aid] for aid in pick if aid in self.load_media]
            return metas

    def _ensure_tool_record(self, tcid: str, server: str, name: str, args: Any) -> Dict[str, Any]:
        """Return the history record for a tool_call_id, creating a fresh
        'running' record (and indexing it) on first sight."""
        idx = self._tool_history_index.get(tcid)
        if idx is None:
            rec = {
                "id": f"tool_{tcid}",
                "role": "tool",
                "tool_call_id": tcid,
                "server": server,
                "name": name,
                "args": args,
                "state": "running",
                "progress": 0.0,
                "message": "",
                "summary": None,
                "ts": time.time(),
            }
            self.history.append(rec)
            self._tool_history_index[tcid] = len(self.history) - 1
            return rec
        return self.history[idx]

    def apply_tool_event(self, raw: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Fold a tool_start/tool_progress/tool_end event into its history
        record in place; returns the updated record, or None for other events."""
        et = raw.get("type")
        tcid = raw.get("tool_call_id")
        if et not in ("tool_start", "tool_progress", "tool_end") or not tcid:
            return None

        server = raw.get("server") or ""
        name = raw.get("name") or ""
        args = raw.get("args") or {}

        rec = self._ensure_tool_record(tcid, server, name, args)

        if et == "tool_start":
            rec.update({
                "server": server,
                "name": name,
                "args": args,
                "state": "running",
                "progress": 0.0,
                "message": "Starting...",
                "summary": None,
            })

        elif et == "tool_progress":
            # Normalize progress to [0, 1]: divide by total when given,
            # otherwise treat values > 1 as percentages.
            progress = float(raw.get("progress", 0.0))
            total = raw.get("total")
            if total and float(total) > 0:
                p = progress / float(total)
            else:
                p = progress / 100.0 if progress > 1 else progress
            p = max(0.0, min(1.0, p))
            rec.update({
                "state": "running",
                "progress": p,
                "message": raw.get("message") or "",
            })

        elif et == "tool_end":
            is_error = bool(raw.get("is_error"))

            # Keep summary only when it is JSON-serializable; fall back to str.
            summary = raw.get("summary")
            try:
                json.dumps(summary, ensure_ascii=False)
            except Exception:
                summary = str(summary) if summary is not None else None
            rec.update({
                "state": "error" if is_error else "complete",
                "progress": 1.0,
                "summary": summary,
                "message": raw.get("message") or rec.get("message") or "",
            })

        return rec
|
|
|
|
class SessionStore:
    """Asyncio-safe registry of ChatSession objects keyed by session id."""

    def __init__(self, cfg: Settings):
        self.cfg = cfg
        self._lock = asyncio.Lock()
        self._sessions: Dict[str, ChatSession] = {}

    async def create(self) -> ChatSession:
        """Create, register, and return a new session with a random hex id."""
        new_id = uuid.uuid4().hex
        session = ChatSession(new_id, self.cfg)
        async with self._lock:
            self._sessions[new_id] = session
        return session

    async def get(self, sid: str) -> Optional[ChatSession]:
        """Look up a session by id; None when unknown."""
        async with self._lock:
            return self._sessions.get(sid)

    async def get_or_404(self, sid: str) -> ChatSession:
        """Like get(), but raises HTTP 404 when the session does not exist."""
        session = await self.get(sid)
        if session is None:
            raise HTTPException(status_code=404, detail="session not found")
        return session
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan: load settings once and attach shared state before serving."""
    settings = load_settings(default_config_path())
    app.state.cfg = settings
    app.state.developer_mode = is_developer_mode(settings)
    app.state.sessions = SessionStore(settings)
    yield
|
|
|
|
app = FastAPI(title="OpenStoryline Web", version="1.0.0", lifespan=lifespan)

# Global HTTP rate limiting (shared token-bucket limiter).
app.add_middleware(
    HttpRateLimitMiddleware,
    limiter=RATE_LIMITER,
    trust_proxy_headers=RATE_LIMIT_TRUST_PROXY_HEADERS,
)

# Static asset mounts are optional; only mount directories that exist.
if os.path.isdir(STATIC_DIR):
    app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")

if os.path.isdir(NODE_MAP_DIR):
    app.mount("/node_map", StaticFiles(directory=NODE_MAP_DIR), name="node_map")

# All JSON endpoints live under /api; the router is included further below.
api = APIRouter(prefix="/api")
|
|
def _rate_limit_reject_json(retry_after: float) -> JSONResponse:
    """Build a 429 response carrying the whole-second Retry-After value in
    both the header and the JSON body."""
    seconds = int(math.ceil(float(retry_after or 0.0)))
    payload = {"detail": "Too Many Requests", "retry_after": seconds}
    return JSONResponse(payload, status_code=429, headers={"Retry-After": str(seconds)})
|
|
async def _enforce_upload_media_count_limit(request: Request, cost: float) -> Optional[JSONResponse]:
    """Charge `cost` uploaded files against the global and the per-IP
    upload-count buckets; returns a 429 JSONResponse when either rejects,
    None when both allow."""
    ip = _client_ip_from_http_scope(request.scope, RATE_LIMIT_TRUST_PROXY_HEADERS)
    spend = float(max(0.0, cost))

    # Global bucket first, then the per-client bucket.
    checks = (
        ("http:upload_media_count:all", float(UPLOAD_MEDIA_COUNT_ALL_BURST), float(UPLOAD_MEDIA_COUNT_ALL_RPM)),
        (f"http:upload_media_count:{ip}", float(HTTP_UPLOAD_MEDIA_COUNT_BURST), float(HTTP_UPLOAD_MEDIA_COUNT_RPM)),
    )
    for key, capacity, rpm in checks:
        allowed, retry_after, _ = await RATE_LIMITER.allow(
            key=key,
            capacity=capacity,
            refill_rate=_rpm_to_rps(rpm),
            cost=spend,
        )
        if not allowed:
            return _rate_limit_reject_json(retry_after)

    return None
|
|
# Field names (compared lower-cased/stripped) that the TTS settings UI must
# treat as secrets and mask in its inputs.
_TTS_UI_SECRET_KEYS = {
    "api_key",
    "access_token",
    "authorization",
    "token",
    "password",
    "secret",
    "x-api-key",
    "apikey",
    "access_key",
    "accesskey",
}
|
|
def _is_secret_field_name(k: str) -> bool:
    """True when the normalized field name denotes a credential to be masked."""
    normalized = str(k or "").strip().lower()
    return normalized in _TTS_UI_SECRET_KEYS
|
|
| def _read_config_toml(path: str) -> dict: |
| if tomllib is None: |
| return {} |
| try: |
| p = Path(path) |
| with p.open("rb") as f: |
| return tomllib.load(f) or {} |
| except Exception: |
| return {} |
|
|
def _get_default_pexels_api_key(cfg: Settings) -> str:
    """Return the configured Pexels API key from cfg.search_media, or "" when
    absent or unreadable."""
    try:
        search_media = getattr(cfg, "search_media", None)
        key = _s(getattr(search_media, "pexels_api_key", None) if search_media else None)
        return key or ""
    except Exception:
        return ""
|
|
def _normalize_field_item(item) -> dict | None:
    """
    Normalize one provider-field spec into {"key", "secret", ...} form.

    item supports:
    - "uid"
    - { key="uid", label="UID", required=true, secret=false, placeholder="..." }

    Returns None for empty keys and unsupported item types.
    """
    if isinstance(item, str):
        key = item.strip()
        if not key:
            return None
        return {
            "key": key,
            "secret": _is_secret_field_name(key),
        }
    # Fix: the docstring promised dict items, but they were silently dropped.
    # Honor the table form; "secret" defaults from the key name when omitted.
    if isinstance(item, dict):
        key = str(item.get("key") or "").strip()
        if not key:
            return None
        out = {
            "key": key,
            "secret": bool(item.get("secret", _is_secret_field_name(key))),
        }
        for opt in ("label", "placeholder", "required"):
            if opt in item:
                out[opt] = item[opt]
        return out
    return None
|
|
| def _build_provider_schema(provider: str, label: str | None, fields: list[dict]) -> dict: |
| seen = set() |
| out = [] |
| for f in fields: |
| k = str(f.get("key") or "").strip() |
| if not k or k in seen: |
| continue |
| seen.add(k) |
| out.append({ |
| "key": k, |
| "label": f.get("label") or k, |
| "placeholder": f.get("placeholder") or f.get("label") or k, |
| "required": bool(f.get("required", False)), |
| "secret": bool(f.get("secret", False)), |
| }) |
| return {"provider": provider, "label": label or provider, "fields": out} |
|
|
def _build_tts_ui_schema_from_config(config_path: str) -> dict:
    """
    Build the TTS settings UI schema from the TOML config file.

    Returns:
    {
        "providers": [
            {"provider": "bytedance", "label": "...", "fields": [{"key": "uid", ...}, ...]},
            ...
        ]
    }
    """
    raw = _read_config_toml(config_path)
    voiceover_cfg = raw.get("generate_voiceover", {})

    schemas: list[dict] = []

    providers = voiceover_cfg.get("providers")
    if isinstance(providers, dict):
        for provider, provider_cfg in providers.items():
            label = str(provider_cfg.get("label") or provider_cfg.get("name") or provider)
            fields = [
                f
                for f in (_normalize_field_item(str(key)) for key in provider_cfg.keys())
                if f
            ]
            schemas.append(_build_provider_schema(provider, label, fields))

    return {"providers": schemas}
|
|
| @app.get("/") |
| async def index(): |
| if not os.path.exists(INDEX_HTML): |
| return Response("index.html not found. Put it under ./web/index.html", media_type="text/plain", status_code=404) |
| return FileResponse(INDEX_HTML, media_type="text/html") |
|
|
| @app.get("/node-map") |
| async def node_map(): |
| if not os.path.exists(NODE_MAP_HTML): |
| return Response( |
| "node_map.html not found. Put it under ./web/node_map/node_map.html", |
| media_type="text/plain", |
| status_code=404, |
| ) |
| return FileResponse(NODE_MAP_HTML, media_type="text/html") |
|
|
| @api.get("/meta/tts") |
| async def get_tts_ui_schema(): |
| schema = _build_tts_ui_schema_from_config(default_config_path()) |
| return JSONResponse(schema) |
|
|
| |
| |
| |
| @api.post("/sessions") |
| async def create_session(): |
| store: SessionStore = app.state.sessions |
| sess = await store.create() |
| return JSONResponse(sess.snapshot()) |
|
|
|
|
| @api.get("/sessions/{session_id}") |
| async def get_session(session_id: str): |
| store: SessionStore = app.state.sessions |
| sess = await store.get_or_404(session_id) |
| return JSONResponse(sess.snapshot()) |
|
|
|
|
| @api.post("/sessions/{session_id}/clear") |
| async def clear_session_chat(session_id: str): |
| store: SessionStore = app.state.sessions |
| sess = await store.get_or_404(session_id) |
| async with sess.chat_lock: |
| sess.sent_media_total = 0 |
| sess._attach_stats_msg_idx = 1 |
| sess.lc_messages = [ |
| SystemMessage(content=get_prompt("instruction.system", lang=sess.lang)), |
| SystemMessage(content="【User media upload status】{}"), |
| ] |
| sess._attach_stats_msg_idx = 1 |
|
|
| sess.history = [] |
| sess._tool_history_index = {} |
| return JSONResponse({"ok": True}) |
|
|
| @api.post("/sessions/{session_id}/cancel") |
| async def cancel_session_turn(session_id: str): |
| """ |
| 打断当前正在进行的 LLM turn(流式回复/工具调用)。 |
| - 不清空 history / lc_messages |
| - 仅设置 cancel_event,由 WS 侧在流式循环中感知并安全收尾 |
| """ |
| store: SessionStore = app.state.sessions |
| sess = await store.get_or_404(session_id) |
| sess.cancel_event.set() |
| return JSONResponse({"ok": True}) |
|
|
| |
| |
| |
| @api.post("/sessions/{session_id}/media") |
| async def upload_media(session_id: str, request: Request, files: List[UploadFile] = File(...)): |
| if not isinstance(files, list) or not files: |
| raise HTTPException(status_code=400, detail="no files") |
|
|
| if MAX_UPLOAD_FILES_PER_REQUEST > 0 and len(files) > MAX_UPLOAD_FILES_PER_REQUEST: |
| raise HTTPException(status_code=400, detail=f"单次上传最多 {MAX_UPLOAD_FILES_PER_REQUEST} 个文件") |
|
|
| |
| rej = await _enforce_upload_media_count_limit(request, cost=float(len(files))) |
| if rej: |
| return rej |
|
|
| if UPLOAD_SEM.locked(): |
| raise HTTPException(status_code=429, detail="上传并发过高,请稍后重试") |
| await UPLOAD_SEM.acquire() |
|
|
| n = len(files) |
| try: |
| store: SessionStore = app.state.sessions |
| sess = await store.get_or_404(session_id) |
|
|
| |
| async with sess.media_lock: |
| sess._cleanup_stale_uploads_locked() |
| sess._check_media_caps_locked(add=n) |
| sess._direct_upload_reservations += n |
|
|
| display_names = [sanitize_filename(uf.filename or "unnamed") for uf in files] |
| store_filenames = sess._reserve_store_filenames_locked(display_names) |
|
|
| try: |
| metas = await sess.add_uploads(files, store_filenames=store_filenames) |
|
|
| finally: |
| async with sess.media_lock: |
| sess._direct_upload_reservations = max(0, sess._direct_upload_reservations - n) |
|
|
| return JSONResponse({ |
| "media": [sess.public_media(m) for m in metas], |
| "pending_media": sess.public_pending_media(), |
| }) |
| finally: |
| try: |
| UPLOAD_SEM.release() |
| except Exception: |
| pass |
|
|
| @api.post("/sessions/{session_id}/media/init") |
| async def init_resumable_media_upload(session_id: str, request: Request): |
| try: |
| data = await request.json() |
| if not isinstance(data, dict): |
| data = {} |
| except Exception: |
| data = {} |
|
|
| filename = sanitize_filename((data.get("filename") or data.get("name") or "unnamed")) |
| size = int(data.get("size") or 0) |
| if size <= 0: |
| raise HTTPException(status_code=400, detail="invalid size") |
|
|
| |
| rej = await _enforce_upload_media_count_limit(request, cost=1.0) |
| if rej: |
| return rej |
|
|
| store: SessionStore = app.state.sessions |
| sess = await store.get_or_404(session_id) |
|
|
| async with sess.media_lock: |
| sess._cleanup_stale_uploads_locked() |
| sess._check_media_caps_locked(add=1) |
|
|
| store_filename = sess._reserve_store_filenames_locked([filename])[0] |
|
|
| upload_id = uuid.uuid4().hex |
| chunk_size = int(max(1, UPLOAD_RESUMABLE_CHUNK_BYTES)) |
| total_chunks = int(math.ceil(size / float(chunk_size))) |
|
|
| tmp_path = os.path.join(sess.uploads_dir, f"{upload_id}.part") |
| os.makedirs(os.path.dirname(tmp_path), exist_ok=True) |
| try: |
| with open(tmp_path, "wb"): |
| pass |
| except Exception as e: |
| raise HTTPException(status_code=500, detail=f"cannot create temp file: {e}") |
|
|
| u = ResumableUpload( |
| upload_id=upload_id, |
| filename=filename, |
| store_filename=store_filename, |
| size=size, |
| chunk_size=chunk_size, |
| total_chunks=total_chunks, |
| tmp_path=os.path.abspath(tmp_path), |
| kind=detect_media_kind(filename), |
| created_ts=time.time(), |
| last_ts=time.time(), |
| ) |
| sess.resumable_uploads[upload_id] = u |
|
|
| return JSONResponse({ |
| "upload_id": upload_id, |
| "chunk_size": chunk_size, |
| "total_chunks": total_chunks, |
| "filename": filename, |
| }) |
|
|
|
|
| @api.post("/sessions/{session_id}/media/{upload_id}/chunk") |
| async def upload_resumable_media_chunk( |
| session_id: str, |
| upload_id: str, |
| index: int = Form(...), |
| chunk: UploadFile = File(...), |
| ): |
| if UPLOAD_SEM.locked(): |
| raise HTTPException(status_code=429, detail="上传并发过高,请稍后重试") |
| await UPLOAD_SEM.acquire() |
| try: |
| store: SessionStore = app.state.sessions |
| sess = await store.get_or_404(session_id) |
|
|
| async with sess.media_lock: |
| sess._cleanup_stale_uploads_locked() |
| u = sess.resumable_uploads.get(upload_id) |
|
|
| if not u: |
| raise HTTPException(status_code=404, detail="upload_id not found or expired") |
|
|
| idx = int(index) |
| if idx < 0 or idx >= u.total_chunks: |
| raise HTTPException(status_code=400, detail="invalid chunk index") |
|
|
| |
| expected_len = u.size - idx * u.chunk_size |
| if expected_len <= 0: |
| raise HTTPException(status_code=400, detail="invalid chunk index") |
| expected_len = min(u.chunk_size, expected_len) |
|
|
| written = 0 |
| async with u.lock: |
| if u.closed: |
| raise HTTPException(status_code=400, detail="upload already closed") |
|
|
| async with await anyio.open_file(u.tmp_path, "r+b") as out: |
| await out.seek(idx * u.chunk_size) |
| while True: |
| buf = await chunk.read(CHUNK_SIZE) |
| if not buf: |
| break |
| written += len(buf) |
| if written > expected_len: |
| raise HTTPException(status_code=400, detail="chunk too large") |
| await out.write(buf) |
|
|
| try: |
| await chunk.close() |
| except Exception: |
| pass |
|
|
| if written != expected_len: |
| raise HTTPException(status_code=400, detail=f"chunk size mismatch: {written} != {expected_len}") |
|
|
| u.received.add(idx) |
| u.last_ts = time.time() |
|
|
| return JSONResponse({ |
| "ok": True, |
| "received_chunks": len(u.received), |
| "total_chunks": u.total_chunks, |
| }) |
| finally: |
| try: |
| UPLOAD_SEM.release() |
| except Exception: |
| pass |
|
|
|
|
| @api.post("/sessions/{session_id}/media/{upload_id}/complete") |
| async def complete_resumable_media_upload(session_id: str, upload_id: str): |
| if UPLOAD_SEM.locked(): |
| raise HTTPException(status_code=429, detail="上传并发过高,请稍后重试") |
| await UPLOAD_SEM.acquire() |
| try: |
| store: SessionStore = app.state.sessions |
| sess = await store.get_or_404(session_id) |
|
|
| async with sess.media_lock: |
| sess._cleanup_stale_uploads_locked() |
| u = sess.resumable_uploads.get(upload_id) |
|
|
| if not u: |
| raise HTTPException(status_code=404, detail="upload_id not found or expired") |
|
|
| |
| async with u.lock: |
| u.closed = True |
| if len(u.received) != u.total_chunks: |
| missing = u.total_chunks - len(u.received) |
| raise HTTPException(status_code=400, detail=f"chunks missing: {missing}") |
|
|
| |
| async with sess.media_lock: |
| u2 = sess.resumable_uploads.pop(upload_id, None) |
|
|
| if not u2: |
| raise HTTPException(status_code=404, detail="upload_id not found") |
|
|
| meta = await sess.media_store.save_from_path( |
| u2.tmp_path, |
| store_filename=u2.store_filename, |
| display_name=u2.filename, |
| ) |
|
|
| async with sess.media_lock: |
| sess.load_media[meta.id] = meta |
| sess.pending_media_ids.append(meta.id) |
|
|
| return JSONResponse({ |
| "media": sess.public_media(meta), |
| "pending_media": sess.public_pending_media(), |
| }) |
| finally: |
| try: |
| UPLOAD_SEM.release() |
| except Exception: |
| pass |
|
|
|
|
| @api.post("/sessions/{session_id}/media/{upload_id}/cancel") |
| async def cancel_resumable_media_upload(session_id: str, upload_id: str): |
| store: SessionStore = app.state.sessions |
| sess = await store.get_or_404(session_id) |
|
|
| async with sess.media_lock: |
| u = sess.resumable_uploads.pop(upload_id, None) |
|
|
| if not u: |
| return JSONResponse({"ok": True}) |
|
|
| async with u.lock: |
| u.closed = True |
| try: |
| if u.tmp_path and os.path.exists(u.tmp_path): |
| os.remove(u.tmp_path) |
| except Exception: |
| pass |
|
|
| return JSONResponse({"ok": True}) |
|
|
| @api.get("/sessions/{session_id}/media/pending") |
| async def get_pending_media(session_id: str): |
| store: SessionStore = app.state.sessions |
| sess = await store.get_or_404(session_id) |
| return JSONResponse({"pending_media": sess.public_pending_media()}) |
|
|
|
|
| @api.delete("/sessions/{session_id}/media/pending/{media_id}") |
| async def delete_pending_media(session_id: str, media_id: str): |
| store: SessionStore = app.state.sessions |
| sess = await store.get_or_404(session_id) |
| await sess.delete_pending_media(media_id) |
| return JSONResponse({"ok": True, "pending_media": sess.public_pending_media()}) |
|
|
|
|
| @api.get("/sessions/{session_id}/media/{media_id}/thumb") |
| async def get_media_thumb(session_id: str, media_id: str): |
| store: SessionStore = app.state.sessions |
| sess = await store.get_or_404(session_id) |
|
|
| meta = sess.load_media.get(media_id) |
| if not meta: |
| raise HTTPException(status_code=404, detail="media not found") |
|
|
| |
| if meta.thumb_path and os.path.exists(meta.thumb_path): |
| return FileResponse(meta.thumb_path, media_type="image/jpeg") |
|
|
| |
| if meta.kind == "video": |
| return Response(content=video_placeholder_svg_bytes(), media_type="image/svg+xml") |
|
|
| |
| if meta.path and os.path.exists(meta.path): |
| return FileResponse(meta.path, media_type=guess_media_type(meta.path)) |
|
|
| raise HTTPException(status_code=404, detail="thumb not available") |
|
|
|
|
| @api.get("/sessions/{session_id}/media/{media_id}/file") |
| async def get_media_file(session_id: str, media_id: str): |
| store: SessionStore = app.state.sessions |
| sess = await store.get_or_404(session_id) |
|
|
| meta = sess.load_media.get(media_id) |
| if not meta: |
| raise HTTPException(status_code=404, detail="media not found") |
| if not meta.path or (not os.path.exists(meta.path)): |
| raise HTTPException(status_code=404, detail="file not found") |
|
|
| |
| if not _is_under_dir(meta.path, sess.media_store.media_dir): |
| raise HTTPException(status_code=403, detail="forbidden") |
|
|
| return FileResponse( |
| meta.path, |
| media_type=guess_media_type(meta.path), |
| filename=meta.name, |
| ) |
|
|
| @api.get("/sessions/{session_id}/preview") |
| async def preview_local_file(session_id: str, path: str): |
| """ |
| 把 summary.preview_urls 里的“服务器本地路径”安全地转成可访问 URL。 |
| 只允许访问:media_dir / outputs_dir / outputs_dir / bgm_dir / .server_cache 这些根目录下的文件。 |
| """ |
| store: SessionStore = app.state.sessions |
| sess = await store.get_or_404(session_id) |
|
|
| p = (path or "").strip() |
| if not p: |
| raise HTTPException(status_code=400, detail="empty path") |
| if "\x00" in p: |
| raise HTTPException(status_code=400, detail="bad path") |
|
|
| |
| if p.startswith("file://"): |
| p = p[len("file://"):] |
|
|
| |
| if os.path.isabs(p): |
| ap = os.path.abspath(p) |
| else: |
| ap = os.path.abspath(os.path.join(ROOT_DIR, p)) |
|
|
| allowed_roots = [ |
| os.path.abspath(sess.media_dir), |
| os.path.abspath(app.state.cfg.project.outputs_dir), |
| os.path.abspath(app.state.cfg.project.outputs_dir), |
| os.path.abspath(app.state.cfg.project.bgm_dir), |
| os.path.abspath(SERVER_CACHE_DIR), |
| ] |
|
|
| if not any(_is_under_dir(ap, r) for r in allowed_roots): |
| raise HTTPException(status_code=403, detail="forbidden") |
|
|
| if (not os.path.exists(ap)) or os.path.isdir(ap): |
| raise HTTPException(status_code=404, detail="file not found") |
|
|
| |
| headers = {"Cache-Control": "public, max-age=31536000, immutable"} if _is_under_dir(ap, SERVER_CACHE_DIR) else None |
|
|
| return FileResponse( |
| ap, |
| media_type=guess_media_type(ap), |
| filename=os.path.basename(ap), |
| headers=headers, |
| ) |
|
|
| app.include_router(api) |
|
|
|
|
| |
| |
| |
def extract_text_delta(msg_chunk: Any) -> str:
    """Extract the plain-text portion of a streamed message chunk.

    Prefers content_blocks (concatenating every {"type": "text"} block);
    otherwise falls back to .content when it is a string, else "".
    """
    blocks = getattr(msg_chunk, "content_blocks", None) or []
    if blocks:
        return "".join(
            b.get("text", "")
            for b in blocks
            if isinstance(b, dict) and b.get("type") == "text"
        )
    content = getattr(msg_chunk, "content", "")
    return content if isinstance(content, str) else ""
|
|
|
|
async def ws_send(ws: WebSocket, type_: str, data: Any = None):
    """Send one {"type", "data"} frame over the websocket.

    Never raises: returns True on success and False when the socket is not
    connected or any disconnect/send error occurs, so streaming loops can
    simply stop on a falsy result. Unexpected errors are logged.
    """
    if getattr(ws, "client_state", None) != WebSocketState.CONNECTED:
        return False
    try:
        await ws.send_json({"type": type_, "data": data})
    except (WebSocketDisconnect, RuntimeError):
        return False
    except Exception as e:
        if ClientDisconnected is not None and isinstance(e, ClientDisconnected):
            return False
        logger.exception("ws_send failed: type=%s err=%r", type_, e)
        return False
    return True
|
|
@asynccontextmanager
async def mcp_sink_context(sink_func):
    """Install *sink_func* as the MCP log sink for the duration of the block.

    The previous sink is always restored on exit via the token returned by
    set_mcp_log_sink, even if the body raises.
    """
    reset_token = set_mcp_log_sink(sink_func)
    try:
        yield
    finally:
        reset_mcp_log_sink(reset_token)
|
|
|
|
@app.websocket("/ws/sessions/{session_id}/chat")
async def ws_chat(ws: WebSocket, session_id: str):
    """Bidirectional chat endpoint for one browser session.

    Protocol (JSON frames shaped ``{"type": ..., "data": ...}``):
      - "ping"             -> replies "pong"
      - "session.set_lang" -> switches session language (zh/en)
      - "chat.clear"       -> resets history and the LangChain message list
      - "chat.send"        -> runs one agent turn, streaming back
                              assistant.delta / tool.* / assistant.end events

    Connection admission is rate limited per client IP and capped by
    WS_CONN_SEM; each chat turn additionally passes global + per-IP send
    limits and the CHAT_TURN_SEM model-concurrency semaphore. A running turn
    can be cancelled via sess.cancel_event, in which case only the text the
    user actually saw is committed and unfinished tool calls are recorded as
    {"cancelled": true} so the LLM context stays consistent.
    """
    client_ip = _client_ip_from_ws(ws, RATE_LIMIT_TRUST_PROXY_HEADERS)

    # Per-IP websocket connect rate limit (refuse before accept()).
    ok, retry_after, _ = await RATE_LIMITER.allow(
        key=f"ws:connect:{client_ip}",
        capacity=float(WS_CONNECT_BURST),
        refill_rate=_rpm_to_rps(float(WS_CONNECT_RPM)),
        cost=1.0,
    )
    if not ok:
        try:
            await ws.close(code=1013, reason=f"rate_limited, retry after {int(math.ceil(retry_after))}s")
        except Exception:
            debug_traceback_print(app.state.cfg)
        return

    # Global cap on concurrent websocket connections.
    if WS_CONN_SEM.locked():
        try:
            await ws.close(code=1013, reason="Server busy (websocket connections limit)")
        except Exception:
            debug_traceback_print(app.state.cfg)
        return

    await WS_CONN_SEM.acquire()

    try:
        await ws.accept()

        store: SessionStore = app.state.sessions
        # NOTE: use get() rather than get_or_404 — raising HTTPException on an
        # already-accepted websocket is meaningless; close with an app code.
        sess = await store.get(session_id)
        if not sess:
            await ws.close(code=4404, reason="session not found")
            return

        # Give the client the current session state up front.
        await ws_send(ws, "session.snapshot", sess.snapshot())

        try:
            while True:
                req = await ws.receive_json()
                if not isinstance(req, dict):
                    continue

                t = req.get("type")
                if t == "ping":
                    await ws_send(ws, "pong", {"ts": time.time()})
                    continue

                if t == "session.set_lang":
                    data = (req.get("data") or {})
                    lang = (data.get("lang") or "").strip().lower()
                    if lang not in ("zh", "en"):
                        lang = "zh"  # default language

                    sess.lang = lang
                    if sess.client_context:
                        sess.client_context.lang = lang

                    await ws_send(ws, "session.lang", {"lang": lang})
                    continue

                if t == "chat.clear":
                    async with sess.chat_lock:
                        sess.sent_media_total = 0
                        # lc_messages[1] is reserved for the media-upload
                        # stats system message; remember its slot index.
                        sess._attach_stats_msg_idx = 1
                        sess.lc_messages = [
                            SystemMessage(content=get_prompt("instruction.system", lang=sess.lang)),
                            SystemMessage(content="【User media upload status】{}"),
                        ]
                        sess.history = []
                        sess._tool_history_index = {}
                    await ws_send(ws, "chat.cleared", {"ok": True})
                    continue

                if t != "chat.send":
                    await ws_send(ws, "error", {"message": f"unknown type: {t}"})
                    continue

                # Only one in-flight turn per session.
                if sess.chat_lock.locked():
                    await ws_send(ws, "error", {"message": "上一条消息尚未完成,请稍后再发送"})
                    continue

                # Global (all clients) chat.send rate limit.
                ok, retry_after, _ = await RATE_LIMITER.allow(
                    key="ws:chat_send:all",
                    capacity=float(WS_CHAT_SEND_ALL_BURST),
                    refill_rate=_rpm_to_rps(float(WS_CHAT_SEND_ALL_RPM)),
                    cost=1.0,
                )
                if not ok:
                    await ws_send(ws, "error", {
                        "message": f"触发全局限流:请 {int(math.ceil(retry_after))} 秒后再试",
                        "retry_after": int(math.ceil(retry_after)),
                    })
                    continue

                # Per-IP chat.send rate limit.
                ok, retry_after, _ = await RATE_LIMITER.allow(
                    key=f"ws:chat_send:{client_ip}",
                    capacity=float(WS_CHAT_SEND_BURST),
                    refill_rate=_rpm_to_rps(float(WS_CHAT_SEND_RPM)),
                    cost=1.0,
                )
                if not ok:
                    await ws_send(ws, "error", {
                        "message": f"触发限流:请 {int(math.ceil(retry_after))} 秒后再试",
                        "retry_after": int(math.ceil(retry_after)),
                    })
                    continue

                # Model concurrency: fail fast instead of queueing the client.
                if CHAT_TURN_SEM.locked():
                    await ws_send(ws, "error", {"message": "服务器繁忙(模型并发已满),请稍后再试"})
                    continue

                await CHAT_TURN_SEM.acquire()
                try:
                    # Re-check after the (possibly slow) semaphore acquire.
                    if sess.chat_lock.locked():
                        await ws_send(ws, "error", {"message": "上一条消息尚未完成,请稍后再发送"})
                        continue

                    data = (req.get("data", {}) or {})

                    prompt = data.get("text", "")
                    prompt = (prompt or "").strip()
                    if not prompt:
                        continue

                    requested_llm = data.get("llm_model")
                    requested_vlm = data.get("vlm_model")

                    attachment_ids = data.get("attachment_ids")
                    if not isinstance(attachment_ids, list):
                        attachment_ids = None

                    async with sess.chat_lock:
                        # Fresh turn: forget any stale cancel request.
                        sess.cancel_event.clear()

                        ok_cfg, err_cfg = sess.apply_service_config(data.get("service_config"))
                        if not ok_cfg:
                            await ws_send(ws, "error", {"message": err_cfg or "service_config invalid"})
                            continue

                        # Optional per-turn model overrides.
                        if isinstance(requested_llm, str):
                            m = requested_llm.strip()
                            if m:
                                sess.chat_model_key = m
                                if sess.client_context:
                                    sess.client_context.chat_model_key = m

                        if isinstance(requested_vlm, str):
                            m2 = requested_vlm.strip()
                            if m2:
                                sess.vlm_model_key = m2
                                if sess.client_context:
                                    sess.client_context.vlm_model_key = m2

                        requested_lang = data.get("lang")
                        if isinstance(requested_lang, str):
                            lang = requested_lang.strip().lower()
                            if lang in ("zh", "en"):
                                sess.lang = lang

                        try:
                            await sess.ensure_agent()
                        except Exception as e:
                            await ws_send(ws, "error", {"message": f"{type(e).__name__}: {e}"})
                            continue

                        sess._ensure_system_prompt()

                        if sess.client_context:
                            sess.client_context.lang = sess.lang

                        # Bind pending uploads to this message.
                        attachments = await sess.take_pending_media_for_message(attachment_ids)
                        attachments_public = [sess.public_media(m) for m in attachments]

                        turn_attached_count = len(attachments)
                        sess.sent_media_total = int(getattr(sess, "sent_media_total", 0)) + turn_attached_count

                        stats = {
                            "Number of media carried in this message sent by the user": turn_attached_count,
                            "Total number of media sent by the user in all conversations": sess.sent_media_total,
                            "Total number of media in user's media library": scan_media_dir(resolve_media_dir(app.state.cfg.project.media_dir, session_id=session_id)),
                        }

                        # Keep the stats system message at its reserved slot,
                        # padding with empty system messages if needed.
                        idx = int(getattr(sess, "_attach_stats_msg_idx", 1))
                        while len(sess.lc_messages) <= idx:
                            sess.lc_messages.append(SystemMessage(content=""))

                        sess.lc_messages[idx] = SystemMessage(
                            content="【User media upload status】The following fields are used to determine the nature of the media provided by the user: \n"
                            + json.dumps(stats, ensure_ascii=False)
                        )

                        # Record the user turn in both the replayable history
                        # and the LangChain message list.
                        user_msg = {
                            "id": uuid.uuid4().hex[:12],
                            "role": "user",
                            "content": prompt,
                            "attachments": attachments_public,
                            "ts": time.time(),
                        }
                        sess.history.append(user_msg)
                        sess.lc_messages.append(HumanMessage(content=prompt))

                        # Echo the accepted user message back to the client.
                        await ws_send(ws, "chat.user", {
                            "text": prompt,
                            "attachments": attachments_public,
                            "pending_media": sess.public_pending_media(),
                            "llm_model_key": sess.chat_model_key,
                            "vlm_model_key": sess.vlm_model_key,
                        })

                        # Events produced by the agent / MCP hooks are funneled
                        # through this queue into the websocket loop below.
                        loop = asyncio.get_running_loop()
                        out_q: asyncio.Queue[Tuple[str, Any]] = asyncio.Queue()

                        def sink(ev: Any):
                            # May be called from worker threads: hop onto the loop.
                            if isinstance(ev, dict):
                                loop.call_soon_threadsafe(out_q.put_nowait, ("mcp", ev))

                        new_messages: List[BaseMessage] = []

                        async def pump_agent():
                            """Drive the agent stream and translate it into queue events."""
                            nonlocal new_messages
                            try:
                                stream = sess.agent.astream(
                                    {"messages": sess.lc_messages},
                                    context=sess.client_context,
                                    stream_mode=["messages", "updates"],
                                )
                                async for mode, chunk in stream:
                                    if mode == "messages":
                                        msg_chunk, meta = chunk
                                        # Only surface text produced by the model node.
                                        if meta.get("langgraph_node") == "model":
                                            delta = extract_text_delta(msg_chunk)
                                            if delta:
                                                await out_q.put(("assistant.delta", delta))
                                    elif mode == "updates":
                                        if isinstance(chunk, dict):
                                            # Collect full messages for the final commit.
                                            for _step, step_data in chunk.items():
                                                msgs = (step_data or {}).get("messages") or []
                                                new_messages.extend(msgs)

                                await out_q.put(("agent.done", None))
                            except asyncio.CancelledError:
                                # Cancelled by the consumer; notify best-effort.
                                try:
                                    out_q.put_nowait(("agent.cancelled", None))
                                except Exception:
                                    debug_traceback_print(app.state.cfg)
                                raise
                            except Exception as e:
                                await out_q.put(("agent.error", f"{type(e).__name__}: {e}"))

                        if not await ws_send(ws, "assistant.start", {}):
                            return

                        # Current assistant text segment (between tool calls).
                        seg_text = ""
                        seg_ts: Optional[float] = None

                        async def flush_segment(send_flush_event: bool):
                            """Close the current assistant bubble.

                            - send_flush_event=True: tell the frontend to end the
                              current assistant bubble now (not the whole turn).
                            - If seg_text is non-empty: persist it to history so
                              refresh/replay shows it.
                            """
                            nonlocal seg_text, seg_ts

                            if send_flush_event:
                                if not await ws_send(ws, "assistant.flush", {}):
                                    return

                            text = (seg_text or "").strip()
                            if text:
                                sess.history.append({
                                    "id": uuid.uuid4().hex[:12],
                                    "role": "assistant",
                                    "content": text,
                                    "ts": seg_ts or time.time(),
                                })

                            seg_text = ""
                            seg_ts = None

                        def _tool_call_ids_from_ai_message(m: BaseMessage) -> set[str]:
                            """Collect tool_call ids from both the native tool_calls
                            attribute and the additional_kwargs fallback form."""
                            ids: set[str] = set()

                            tc = getattr(m, "tool_calls", None) or []
                            for c in tc:
                                _id = None
                                if isinstance(c, dict):
                                    _id = c.get("id") or c.get("tool_call_id")
                                else:
                                    _id = getattr(c, "id", None) or getattr(c, "tool_call_id", None)
                                if _id:
                                    ids.add(str(_id))

                            ak = getattr(m, "additional_kwargs", None) or {}
                            tc2 = ak.get("tool_calls") or []
                            for c in tc2:
                                if isinstance(c, dict):
                                    _id = c.get("id") or c.get("tool_call_id")
                                    if _id:
                                        ids.add(str(_id))

                            return ids

                        def _tool_call_ids_in_msgs(msgs: List[BaseMessage]) -> set[str]:
                            """All tool_call ids requested by AI messages in msgs."""
                            ids: set[str] = set()
                            for m in msgs:
                                if isinstance(m, AIMessage):
                                    ids |= _tool_call_ids_from_ai_message(m)
                            return ids

                        def _tool_result_ids_in_msgs(msgs: List[BaseMessage]) -> set[str]:
                            """All tool_call ids that already have a ToolMessage result."""
                            ids: set[str] = set()
                            for m in msgs:
                                if isinstance(m, ToolMessage):
                                    tcid = getattr(m, "tool_call_id", None)
                                    if tcid:
                                        ids.add(str(tcid))
                            return ids

                        def _force_cancelled_tool_results(msgs: List[BaseMessage], cancel_ids: set[str]) -> List[BaseMessage]:
                            """Rewrite existing ToolMessages whose id is in cancel_ids
                            to a {"cancelled": true} result."""
                            if not cancel_ids:
                                return msgs
                            cancelled_content = json.dumps({"cancelled": True}, ensure_ascii=False)
                            out: List[BaseMessage] = []
                            for m in msgs:
                                if isinstance(m, ToolMessage):
                                    tcid = getattr(m, "tool_call_id", None)
                                    if tcid and str(tcid) in cancel_ids:
                                        out.append(ToolMessage(content=cancelled_content, tool_call_id=str(tcid)))
                                        continue
                                out.append(m)
                            return out

                        def _inject_cancelled_tool_messages(msgs: List[BaseMessage], tool_call_ids: List[str]) -> List[BaseMessage]:
                            """Insert a {"cancelled": true} ToolMessage right after the
                            AIMessage that requested each still-unanswered tool call."""
                            if not tool_call_ids:
                                return msgs

                            out = list(msgs)

                            existing = set()
                            for m in out:
                                if isinstance(m, ToolMessage):
                                    tcid = getattr(m, "tool_call_id", None)
                                    if tcid:
                                        existing.add(str(tcid))

                            cancelled_content = json.dumps({"cancelled": True}, ensure_ascii=False)

                            for tcid in tool_call_ids:
                                tcid = str(tcid)
                                if tcid in existing:
                                    continue

                                # Find the latest AIMessage that issued this call.
                                insert_at = None
                                for i in range(len(out) - 1, -1, -1):
                                    m = out[i]
                                    if isinstance(m, AIMessage) and (tcid in _tool_call_ids_from_ai_message(m)):
                                        insert_at = i + 1
                                        break

                                if insert_at is None:
                                    continue

                                out.insert(insert_at, ToolMessage(content=cancelled_content, tool_call_id=tcid))
                                existing.add(tcid)

                            return out

                        def _sanitize_new_messages_on_cancel(
                            new_messages: List[BaseMessage],
                            *,
                            interrupted_text: str,
                            cancelled_tool_ids_from_ui: List[str],
                        ) -> List[BaseMessage]:
                            """Return the message sequence safe to commit to
                            sess.lc_messages after a user cancel (only the part the
                            user actually saw / acknowledged).

                            - Tools: unanswered tool_calls get ToolMessage({"cancelled": true}).
                            - Reply: the trailing final AIMessage is replaced by
                              interrupted_text so the full (unseen) reply does not
                              leak into the context.
                            """
                            msgs = list(new_messages or [])
                            interrupted_text = (interrupted_text or "").strip()

                            ai_tool_ids = _tool_call_ids_in_msgs(msgs)
                            tool_result_ids = _tool_result_ids_in_msgs(msgs)
                            pending_tool_ids = ai_tool_ids - tool_result_ids

                            ui_cancel_ids = {str(x) for x in (cancelled_tool_ids_from_ui or [])}

                            # Cancel both what the UI flagged and whatever never returned.
                            cancel_ids = set(ui_cancel_ids) | set(pending_tool_ids)

                            msgs = _force_cancelled_tool_results(msgs, cancel_ids)

                            msgs = _inject_cancelled_tool_messages(msgs, list(cancel_ids))

                            def _is_toolcall_ai(m: BaseMessage) -> bool:
                                return isinstance(m, AIMessage) and bool(_tool_call_ids_from_ai_message(m))

                            def _is_text_ai(m: BaseMessage) -> bool:
                                if not isinstance(m, AIMessage):
                                    return False
                                if _tool_call_ids_from_ai_message(m):
                                    return False
                                c = getattr(m, "content", None)
                                return isinstance(c, str) and bool(c.strip())

                            last_text_ai_idx = None
                            for i in range(len(msgs) - 1, -1, -1):
                                if _is_text_ai(msgs[i]):
                                    last_text_ai_idx = i
                                    break

                            if interrupted_text:
                                if last_text_ai_idx is None:
                                    msgs.append(AIMessage(content=interrupted_text))
                                else:
                                    # Drop the full final reply; keep only what was shown.
                                    msgs = msgs[:last_text_ai_idx] + [AIMessage(content=interrupted_text)]
                                return msgs

                            # No visible text: drop a trailing text AIMessage unless a
                            # tool-calling AIMessage follows it.
                            if last_text_ai_idx is not None:
                                has_toolcall_after = any(_is_toolcall_ai(m) for m in msgs[last_text_ai_idx + 1 :])
                                if not has_toolcall_after:
                                    msgs = msgs[:last_text_ai_idx]

                            return msgs

                        pump_task: Optional[asyncio.Task] = None
                        cancel_wait_task: Optional[asyncio.Task] = None

                        was_interrupted = False

                        try:
                            async with mcp_sink_context(sink):
                                pump_task = asyncio.create_task(pump_agent())
                                cancel_wait_task = asyncio.create_task(sess.cancel_event.wait())

                                while True:
                                    # Race the next queue event against a user cancel.
                                    get_task = asyncio.create_task(out_q.get())
                                    done, _ = await asyncio.wait(
                                        {get_task, cancel_wait_task},
                                        return_when=asyncio.FIRST_COMPLETED,
                                    )

                                    if get_task in done:
                                        kind, payload = get_task.result()
                                    else:
                                        # Cancel won: retire the pending queue read.
                                        try:
                                            get_task.cancel()
                                            await get_task
                                        except asyncio.CancelledError:
                                            debug_traceback_print(app.state.cfg)
                                        except Exception:
                                            debug_traceback_print(app.state.cfg)

                                        kind, payload = ("agent.cancelled", None)

                                    if kind == "agent.cancelled":
                                        # Both the cancel race and pump_agent's own
                                        # CancelledError can emit this; handle once.
                                        if was_interrupted:
                                            break
                                        was_interrupted = True

                                        if pump_task and (not pump_task.done()):
                                            pump_task.cancel()

                                        # Mark still-running tool records as cancelled.
                                        cancelled_tool_recs: List[Dict[str, Any]] = []
                                        for _tcid, idx in list(sess._tool_history_index.items()):
                                            rec = sess.history[idx]
                                            if rec.get("role") == "tool" and rec.get("state") == "running":
                                                rec.update({
                                                    "state": "error",
                                                    "progress": 1.0,
                                                    "message": "Cancelled by user",
                                                    "summary": {"cancelled": True},
                                                })
                                                cancelled_tool_recs.append(rec)

                                        for rec in cancelled_tool_recs:
                                            await ws_send(ws, "tool.end", {
                                                "tool_call_id": rec["tool_call_id"],
                                                "server": rec["server"],
                                                "name": rec["name"],
                                                "is_error": True,
                                                "summary": rec.get("summary"),
                                            })

                                        # Persist whatever text the user already saw.
                                        interrupted_text = (seg_text or "").strip()
                                        if interrupted_text:
                                            sess.history.append({
                                                "id": uuid.uuid4().hex[:12],
                                                "role": "assistant",
                                                "content": interrupted_text,
                                                "ts": seg_ts or time.time(),
                                            })

                                        cancelled_tool_ids = [rec["tool_call_id"] for rec in cancelled_tool_recs]

                                        commit_msgs = _sanitize_new_messages_on_cancel(
                                            new_messages,
                                            interrupted_text=interrupted_text,
                                            cancelled_tool_ids_from_ui=cancelled_tool_ids,
                                        )

                                        if commit_msgs:
                                            sess.lc_messages.extend(commit_msgs)
                                        elif interrupted_text:
                                            # Nothing sanitized survived; keep the shown text.
                                            sess.lc_messages.append(AIMessage(content=interrupted_text))

                                        await ws_send(ws, "assistant.end", {"text": interrupted_text, "interrupted": True})

                                        sess.cancel_event.clear()
                                        break

                                    if kind == "assistant.delta":
                                        delta = payload or ""
                                        if delta:
                                            if seg_ts is None:
                                                seg_ts = time.time()
                                            seg_text += delta
                                            if not await ws_send(ws, "assistant.delta", {"delta": delta}):
                                                raise WebSocketDisconnect()
                                        continue

                                    if kind == "mcp":
                                        raw = payload

                                        # A tool starting ends the current text bubble.
                                        if raw.get("type") == "tool_start":
                                            await flush_segment(send_flush_event=True)

                                        rec = sess.apply_tool_event(raw)
                                        if rec:
                                            if raw["type"] == "tool_start":
                                                await ws_send(ws, "tool.start", {
                                                    "tool_call_id": rec["tool_call_id"],
                                                    "server": rec["server"],
                                                    "name": rec["name"],
                                                    "args": rec["args"],
                                                })
                                            elif raw["type"] == "tool_progress":
                                                await ws_send(ws, "tool.progress", {
                                                    "tool_call_id": rec["tool_call_id"],
                                                    "server": rec["server"],
                                                    "name": rec["name"],
                                                    "progress": rec["progress"],
                                                    "message": rec["message"],
                                                })
                                            elif raw["type"] == "tool_end":
                                                await ws_send(ws, "tool.end", {
                                                    "tool_call_id": rec["tool_call_id"],
                                                    "server": rec["server"],
                                                    "name": rec["name"],
                                                    "is_error": rec["state"] == "error",
                                                    "summary": rec["summary"],
                                                })
                                        continue

                                    if kind == "agent.done":
                                        final_text = (seg_text or "").strip()

                                        if final_text:
                                            sess.history.append({
                                                "id": uuid.uuid4().hex[:12],
                                                "role": "assistant",
                                                "content": final_text,
                                                "ts": seg_ts or time.time(),
                                            })

                                        if new_messages:
                                            sess.lc_messages.extend(new_messages)

                                        if not await ws_send(ws, "assistant.end", {"text": final_text}):
                                            return
                                        break

                                    if kind == "agent.error":
                                        err_text = str(payload or "unknown error")
                                        partial = (seg_text or "").strip()

                                        # Keep the partial reply in both histories.
                                        if partial:
                                            sess.history.append({
                                                "id": uuid.uuid4().hex[:12],
                                                "role": "assistant",
                                                "content": partial,
                                                "ts": seg_ts or time.time(),
                                            })
                                            sess.lc_messages.append(AIMessage(content=partial))

                                        if new_messages:
                                            sess.lc_messages.extend(new_messages)

                                        await ws_send(ws, "error", {"message": err_text, "partial_text": partial})
                                        break

                        except WebSocketDisconnect:
                            return
                        except asyncio.CancelledError:
                            # Server shutdown / outer cancellation: just unwind.
                            return
                        except Exception as e:
                            if was_interrupted:
                                return
                            await ws_send(ws, "error", {"message": f"{type(e).__name__}: {e}", "partial_text": (seg_text or "").strip()})
                            return
                        finally:
                            if cancel_wait_task and (not cancel_wait_task.done()):
                                cancel_wait_task.cancel()

                            # Make sure the agent task is fully retired before the
                            # turn's locks are released.
                            if pump_task and (not pump_task.done()):
                                pump_task.cancel()
                            if pump_task:
                                try:
                                    await asyncio.wait_for(pump_task, timeout=2.0)
                                except asyncio.TimeoutError:
                                    debug_traceback_print(app.state.cfg)
                                except asyncio.CancelledError:
                                    debug_traceback_print(app.state.cfg)
                                except Exception:
                                    debug_traceback_print(app.state.cfg)
                finally:
                    try:
                        CHAT_TURN_SEM.release()
                    except Exception:
                        debug_traceback_print(app.state.cfg)

        except WebSocketDisconnect:
            return
    finally:
        try:
            WS_CONN_SEM.release()
        except Exception:
            # Release should never fail; guard against shutdown races only.
            pass
|
|