Spaces:
Running
Running
| from __future__ import annotations | |
| import asyncio | |
| import base64 | |
| import concurrent.futures | |
| import json | |
| import logging | |
| import os | |
| import threading | |
| from contextlib import asynccontextmanager | |
| from datetime import datetime, timedelta | |
| from pathlib import Path | |
| from typing import Any, Callable, List, Optional | |
| import aiosqlite | |
| import cv2 | |
| import numpy as np | |
| from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack | |
| from av import VideoFrame | |
| from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
| from api.drawing import draw_face_mesh, draw_hud, get_tesselation_connections | |
| from api.db import ( | |
| EventBuffer, | |
| create_session, | |
| end_session, | |
| get_db_path, | |
| init_database, | |
| store_focus_event, | |
| ) | |
| from config import get | |
| from models.face_mesh import FaceMeshDetector | |
| from ui.pipeline import ( | |
| FaceMeshPipeline, | |
| HybridFocusPipeline, | |
| L2CSPipeline, | |
| MLPPipeline, | |
| XGBoostPipeline, | |
| is_l2cs_weights_available, | |
| ) | |
logger = logging.getLogger(__name__)
db_path = get_db_path()


def _cfg(key, default):
    """Return the config value for *key*, using *default* only when the key is unset.

    ``get(key) or default`` would also discard legitimate falsy settings such as ``0`` or
    ``0.0`` (e.g. a zero fusion weight or a 0.0 no-face cap). Assumes ``config.get``
    returns ``None`` for missing keys -- TODO confirm against the config module.
    """
    value = get(key)
    return default if value is None else value


_inference_size = _cfg("app.inference_size", [640, 480])
# ThreadPoolExecutor rejects max_workers=0, so any falsy value falls back to 4 here.
_inference_workers = get("app.inference_workers") or 4
_fused_threshold = _cfg("l2cs_boost.fused_threshold", 0.52)
_no_face_cap = _cfg("app.no_face_confidence_cap", 0.1)
_BOOST_BASE_W = _cfg("l2cs_boost.base_weight", 0.35)
_BOOST_L2CS_W = _cfg("l2cs_boost.l2cs_weight", 0.65)
_BOOST_VETO = _cfg("l2cs_boost.veto_threshold", 0.38)
_FONT = cv2.FONT_HERSHEY_SIMPLEX
_RED = (0, 0, 255)
@asynccontextmanager
async def lifespan(app):
    """FastAPI lifespan: prepare the database and pipelines, then release workers on exit.

    Startup:
      * runs ``init_database`` and restores the saved model name from ``user_settings``;
      * eagerly constructs the geometric, MLP, hybrid and XGBoost pipelines (a failure
        is printed and that pipeline simply stays unavailable);
      * keeps a saved ``"l2cs"`` choice when its weights exist (L2CS is lazy-loaded on
        first use, so it is never constructed here); otherwise falls back to the first
        loaded pipeline and persists that choice.
    Shutdown (also reached when serving ends with an exception): stops the inference
    thread pool without waiting for queued work.
    """
    global _cached_model_name
    print("Starting Focus Guard API")
    await init_database(db_path)
    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
        row = await cursor.fetchone()
        if row:
            _cached_model_name = row[0]
    print("[OK] Database initialized")

    # (pipelines key, factory, success message, failure prefix) per eager pipeline.
    eager_pipelines = (
        ("geometric", FaceMeshPipeline, "[OK] FaceMeshPipeline (geometric) loaded",
         "[WARN] FaceMeshPipeline unavailable"),
        ("mlp", MLPPipeline, "[OK] MLPPipeline loaded",
         "[ERR] Failed to load MLPPipeline"),
        ("hybrid", HybridFocusPipeline, "[OK] HybridFocusPipeline loaded",
         "[WARN] HybridFocusPipeline unavailable"),
        ("xgboost", XGBoostPipeline, "[OK] XGBoostPipeline loaded",
         "[ERR] Failed to load XGBoostPipeline"),
    )
    for key, factory, ok_msg, fail_msg in eager_pipelines:
        try:
            pipelines[key] = factory()
            print(ok_msg)
        except Exception as e:
            print(f"{fail_msg}: {e}")

    l2cs_weights = is_l2cs_weights_available()
    if _cached_model_name == "l2cs" and l2cs_weights:
        # L2CS is loaded lazily by the frame handlers; do not discard the user's choice
        # just because it has not been constructed yet.
        resolved_model = "l2cs"
    else:
        resolved_model = _first_available_pipeline_name(_cached_model_name)
    if resolved_model is not None and resolved_model != _cached_model_name:
        _cached_model_name = resolved_model
        async with aiosqlite.connect(db_path) as db:
            await db.execute(
                "UPDATE user_settings SET model_name = ? WHERE id = 1",
                (_cached_model_name,),
            )
            await db.commit()
    if resolved_model is not None:
        print(f"[OK] Active model set to {resolved_model}")
    if l2cs_weights:
        print("[OK] L2CS weights found (lazy-loaded on first use)")
    else:
        print("[WARN] L2CS weights not found")

    try:
        yield
    finally:
        _inference_executor.shutdown(wait=False)
        print("Shutting down Focus Guard API")
app = FastAPI(title="Focus Guard API", lifespan=lifespan)

# Permissive CORS: every origin, method and header is accepted, with credentials.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any site make
# credentialed requests; consider restricting allow_origins if this is exposed publicly.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Open WebRTC peer connections; the signaling handler adds and discards entries.
pcs: set = set()
# Name of the active model; lifespan and update_settings overwrite it at runtime.
_cached_model_name = get("app.default_model") or "mlp"
# When True, frame handlers fuse the L2CS gaze score into the base model's output.
_l2cs_boost_enabled = False
async def _wait_for_ice_gathering(pc: RTCPeerConnection):
    """Wait until *pc* has finished gathering ICE candidates.

    Returns immediately when ``pc.iceGatheringState`` is already ``"complete"``.
    Otherwise subscribes to the connection's ``icegatheringstatechange`` event, waits
    for the ``"complete"`` state and unsubscribes again.
    """
    if pc.iceGatheringState == "complete":
        return
    done = asyncio.Event()

    def _on_state_change():
        if pc.iceGatheringState == "complete":
            done.set()

    # The handler must actually be subscribed; otherwise nothing ever sets `done` and
    # the wait below never finishes.
    pc.on("icegatheringstatechange", _on_state_change)
    try:
        await done.wait()
    finally:
        pc.remove_listener("icegatheringstatechange", _on_state_change)
# ================ PYDANTIC MODELS ================
class SessionCreate(BaseModel):
    """Request body for creating a session; it carries no fields."""


class SessionEnd(BaseModel):
    """Request body naming the session to close."""

    session_id: int


class SettingsUpdate(BaseModel):
    """Partial settings update; a field left as ``None`` is not changed."""

    model_name: Optional[str] = None
    l2cs_boost: Optional[bool] = None
class VideoTransformTrack(VideoStreamTrack):
    """Outgoing video track that annotates an incoming WebRTC track with focus results.

    Every received frame is resized to ``_inference_size``. At most once per
    ``min_inference_interval`` seconds the active pipeline runs on it. On those frames
    the face mesh / HUD are drawn, a focus event is stored for the session, and a
    ``detection`` JSON message is sent on the data channel when one is open. Frames in
    between are replaced by the last annotated frame.
    """

    def __init__(self, track, session_id: int, get_channel: Callable[[], Any]):
        super().__init__()
        self.track = track
        self.session_id = session_id
        # Resolved on every inference, so a data channel opened after the track is used too.
        self.get_channel = get_channel
        # Event-loop (monotonic) time of the last inference; -inf lets the first frame through.
        self.last_inference_time = float("-inf")
        self.min_inference_interval = 1 / 60
        self.last_frame = None

    async def recv(self):
        """Return the next outgoing frame: freshly annotated, or the last annotated one."""
        frame = await self.track.recv()
        img = frame.to_ndarray(format="bgr24")
        if img is None:
            return frame
        w_sz, h_sz = _inference_size[0], _inference_size[1]
        img = cv2.resize(img, (w_sz, h_sz))

        loop = asyncio.get_running_loop()
        # loop.time() is monotonic, so wall-clock adjustments cannot stall or burst inference.
        now = loop.time()
        if now - self.last_inference_time >= self.min_inference_interval:
            self.last_inference_time = now
            img = await self._infer_and_annotate(img, loop)
            self.last_frame = img
        elif self.last_frame is not None:
            img = self.last_frame

        new_frame = VideoFrame.from_ndarray(img, format="bgr24")
        new_frame.pts = frame.pts
        new_frame.time_base = frame.time_base
        return new_frame

    async def _infer_and_annotate(self, img, loop):
        """Run the active pipeline on *img*, draw overlays in place, then record and publish."""
        model_name = _cached_model_name
        if model_name == "l2cs" and pipelines.get("l2cs") is None:
            # Loading L2CS is slow; keep it off the event loop like the other call sites.
            await loop.run_in_executor(_inference_executor, _ensure_l2cs)
        if pipelines.get(model_name) is None:
            model_name = "mlp"
        active_pipeline = pipelines.get(model_name)

        if active_pipeline is not None:
            out = await loop.run_in_executor(
                _inference_executor,
                _process_frame_safe,
                active_pipeline,
                img,
                model_name,
            )
            is_focused = out["is_focused"]
            confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
            metadata = {
                "s_face": out.get("s_face", 0.0),
                "s_eye": out.get("s_eye", 0.0),
                "mar": out.get("mar", 0.0),
                "model": model_name,
            }
            h_f, w_f = img.shape[:2]
            lm = out.get("landmarks")
            if lm is not None:
                draw_face_mesh(img, lm, w_f, h_f)
            draw_hud(img, out, model_name)
        else:
            is_focused = False
            confidence = 0.0
            metadata = {"model": model_name}
            cv2.rectangle(img, (0, 0), (img.shape[1], 55), (0, 0, 0), -1)
            cv2.putText(img, "NO MODEL", (10, 28), _FONT, 0.8, _RED, 2, cv2.LINE_AA)

        if self.session_id:
            await store_focus_event(self.session_id, is_focused, confidence, metadata)
        self._publish(is_focused, confidence, model_name)
        return img

    def _publish(self, is_focused, confidence, model_name):
        """Send a ``detection`` message on the data channel if it is open (best effort)."""
        channel = self.get_channel()
        if not channel or channel.readyState != "open":
            return
        try:
            channel.send(json.dumps({
                "type": "detection",
                "focused": is_focused,
                "confidence": round(confidence, 3),
                "detections": [],
                "model": model_name,
            }))
        except Exception:
            # A channel closing mid-send must not interrupt the video stream.
            logger.debug("WebRTC data channel send failed", exc_info=True)
# ================ PIPELINE STATE ================
# Loaded pipelines keyed by model name; None means "not loaded". lifespan fills the eager
# ones at startup and _ensure_l2cs lazy-loads "l2cs".
pipelines = {
    "geometric": None,
    "mlp": None,
    "hybrid": None,
    "xgboost": None,
    "l2cs": None,
}
# Thread pool for CPU-bound inference so the event loop stays responsive.
_inference_executor = concurrent.futures.ThreadPoolExecutor(
    max_workers=_inference_workers,
    thread_name_prefix="inference",
)
# One lock per pipeline so shared per-pipeline state (TemporalTracker, etc.) is not
# corrupted when frames are processed in parallel by the thread pool. Derived from
# `pipelines` so a newly added model can never be missing its lock.
_pipeline_locks = {name: threading.Lock() for name in pipelines}
# Serializes the one-time construction of the L2CS pipeline.
_l2cs_load_lock = threading.Lock()
# Message from the last failed L2CS load attempt; None after a successful load.
_l2cs_error: str | None = None
def _ensure_l2cs():
    """Make sure the L2CS pipeline is constructed, building it on first use.

    Returns True when ``pipelines["l2cs"]`` is available afterwards, False otherwise.
    Failures are recorded in ``_l2cs_error`` ("Weights not found" or the exception text).
    """
    global _l2cs_error
    # Unlocked fast path once the pipeline exists.
    if pipelines["l2cs"] is not None:
        return True
    with _l2cs_load_lock:
        # Re-check under the lock: another thread may have finished the load meanwhile.
        if pipelines["l2cs"] is not None:
            return True
        if not is_l2cs_weights_available():
            _l2cs_error = "Weights not found"
            return False
        try:
            pipelines["l2cs"] = L2CSPipeline()
        except Exception as exc:
            _l2cs_error = str(exc)
            print(f"[ERR] L2CS lazy-load failed: {exc}")
            return False
        else:
            _l2cs_error = None
            print("[OK] L2CSPipeline lazy-loaded")
            return True
def _process_frame_safe(pipeline, frame, model_name):
    """Call ``pipeline.process_frame(frame)`` while holding the lock for *model_name*.

    Used as the inference-executor job so concurrent frames never run the same
    pipeline at once.
    """
    model_lock = _pipeline_locks[model_name]
    with model_lock:
        result = pipeline.process_frame(frame)
    return result
def _first_available_pipeline_name(preferred: str | None = None) -> str | None:
    """Pick a loaded pipeline name.

    Returns *preferred* when that pipeline is loaded. Otherwise returns the first loaded
    entry of ``pipelines`` in insertion order, or None when nothing is loaded.
    """
    if preferred and pipelines.get(preferred) is not None:
        return preferred
    return next((name for name, pipe in pipelines.items() if pipe is not None), None)
def _process_frame_with_l2cs_boost(base_pipeline, frame, base_model_name):
    """Run the base pipeline and fuse its score with the L2CS gaze score.

    fused = ``_BOOST_BASE_W`` * base_score + ``_BOOST_L2CS_W`` * l2cs_score, where
    base_score is the base output's ``mlp_prob`` (falling back to ``raw_score``). The
    frame counts as focused when fused >= ``_fused_threshold`` and the L2CS score is not
    below ``_BOOST_VETO``, so a low gaze score overrides a confident base model. The
    weights and thresholds are the module-level values loaded from config; they are
    deliberately not re-assigned here, so configured values take effect.

    Each pipeline runs under its own lock from ``_pipeline_locks``.

    Returns:
        The base pipeline's output dict, updated in place with ``raw_score`` (fused),
        ``is_focused``, ``boost_active``, ``base_score``, ``l2cs_score`` and, when L2CS
        reports them, ``gaze_yaw``/``gaze_pitch``. If L2CS is not loaded, the base output
        is returned unchanged apart from ``boost_active=False``.
    """
    with _pipeline_locks[base_model_name]:
        base_out = base_pipeline.process_frame(frame)
    l2cs_pipe = pipelines.get("l2cs")
    if l2cs_pipe is None:
        base_out["boost_active"] = False
        return base_out
    with _pipeline_locks["l2cs"]:
        l2cs_out = l2cs_pipe.process_frame(frame)

    base_score = base_out.get("mlp_prob", base_out.get("raw_score", 0.0))
    l2cs_score = l2cs_out.get("raw_score", 0.0)
    fused_score = _BOOST_BASE_W * base_score + _BOOST_L2CS_W * l2cs_score
    # Veto: an L2CS score below the threshold forces not-focused regardless of fusion.
    vetoed = l2cs_score < _BOOST_VETO
    is_focused = fused_score >= _fused_threshold and not vetoed

    base_out["raw_score"] = fused_score
    base_out["is_focused"] = is_focused
    base_out["boost_active"] = True
    base_out["base_score"] = round(base_score, 3)
    base_out["l2cs_score"] = round(l2cs_score, 3)
    if l2cs_out.get("gaze_yaw") is not None:
        base_out["gaze_yaw"] = l2cs_out["gaze_yaw"]
        base_out["gaze_pitch"] = l2cs_out["gaze_pitch"]
    return base_out
# ================ WEBRTC SIGNALING ================
async def webrtc_offer(offer: dict):
    """Answer a browser's WebRTC SDP offer.

    Creates a peer connection and a new focus session. The incoming video track is
    wired through ``VideoTransformTrack``, and the client's data channel is remembered
    for detection messages. When the connection fails, closes or disconnects, the session
    is ended and the connection is closed once.

    Args:
        offer: dict with the offer's ``"sdp"`` and ``"type"``.

    Returns:
        dict with the answer ``"sdp"``, ``"type"`` and the new ``"session_id"``.

    Raises:
        HTTPException: 500 when any step of the negotiation fails; the peer connection
            is closed and dropped from ``pcs`` first.
    """
    pc = None
    try:
        pc = RTCPeerConnection()
        pcs.add(pc)
        session_id = await create_session()
        channel_ref = {"channel": None}
        cleaned_up = False

        def on_datachannel(channel):
            channel_ref["channel"] = channel

        def on_track(track):
            if track.kind == "video":
                local_track = VideoTransformTrack(track, session_id, lambda: channel_ref["channel"])
                pc.addTrack(local_track)

        async def on_connectionstatechange():
            nonlocal cleaned_up
            if cleaned_up or pc.connectionState not in ("failed", "closed", "disconnected"):
                return
            # pc.close() below can trigger another "closed" state change; clean up once.
            cleaned_up = True
            try:
                await end_session(session_id)
            except Exception as e:
                logger.warning("WebRTC session end failed: %s", e)
            pcs.discard(pc)
            await pc.close()

        # Subscribe before applying the remote description so the remote track
        # announcement is not missed.
        pc.on("datachannel", on_datachannel)
        pc.on("track", on_track)
        pc.on("connectionstatechange", on_connectionstatechange)

        await pc.setRemoteDescription(RTCSessionDescription(sdp=offer["sdp"], type=offer["type"]))
        answer = await pc.createAnswer()
        await pc.setLocalDescription(answer)
        await _wait_for_ice_gathering(pc)
        return {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "session_id": session_id}
    except Exception as e:
        logger.exception("WebRTC offer failed")
        if pc is not None:
            pcs.discard(pc)
            try:
                await pc.close()
            except Exception:
                logger.warning("Closing failed WebRTC peer connection raised", exc_info=True)
        raise HTTPException(status_code=500, detail=f"WebRTC error: {str(e)}") from e
# ================ WEBSOCKET ================
async def websocket_endpoint(websocket: WebSocket):
    """Stream focus detection results over a WebSocket connection.

    Incoming messages:
      * binary: a JPEG-encoded camera frame;
      * text: a JSON object whose ``"type"`` is ``frame`` (legacy base64 ``image``),
        ``start_session``, ``end_session``, ``calibration_start``, ``calibration_next``
        or ``calibration_cancel``.

    Two loops run concurrently. ``_receive_loop`` keeps only the newest frame (older
    unprocessed frames are dropped) and executes control commands. ``_process_loop``
    runs inference on that frame and replies with a ``detection`` message. Malformed
    control messages are logged and ignored rather than closing the connection. A
    session still open when the socket closes is ended.
    """
    from models.gaze_calibration import GazeCalibration
    from models.gaze_eye_fusion import GazeEyeFusion

    await websocket.accept()
    loop = asyncio.get_running_loop()
    session_id = None
    frame_count = 0
    running = True
    event_buffer = EventBuffer(db_path=db_path, flush_interval=2.0)

    # Per-connection calibration state:
    #   collecting -> gathering L2CS gaze samples for the current calibration target
    #   verifying  -> after a successful fit, gathering samples at a check target
    _cal: dict = {"cal": None, "collecting": False, "fusion": None,
                  "verifying": False, "verify_target": None, "verify_samples": []}

    # Latest frame slot: only the most recent frame is kept, older ones are dropped.
    _slot = {"frame": None}
    _frame_ready = asyncio.Event()

    def _stop():
        """Mark the connection finished and wake the processing loop so it can exit."""
        nonlocal running
        running = False
        _frame_ready.set()

    def _set_l2cs_calibrating(flag: bool):
        """Set L2CS's ``_calibrating`` flag (True: run every frame; False: allow frame skipping)."""
        l2cs_pipe = pipelines.get("l2cs")
        if l2cs_pipe is not None and hasattr(l2cs_pipe, "_calibrating"):
            l2cs_pipe._calibrating = flag

    def _clear_verification():
        """Drop any verification phase so it cannot leak into a cancelled or restarted run."""
        _cal["verifying"] = False
        _cal["verify_target"] = None
        _cal["verify_samples"] = []

    async def _handle_calibration_next():
        """Advance calibration: next target, fit + verification, or the final result."""
        cal = _cal.get("cal")
        if _cal.get("verifying"):
            # Verification phase complete: the user clicked next.
            _cal["verifying"] = False
            _cal["collecting"] = False
            _set_l2cs_calibrating(False)
            v_samples = _cal.get("verify_samples", [])
            vt = _cal.get("verify_target") or [0.5, 0.5]
            if len(v_samples) >= 3:
                med_yaw = float(np.median([s[0] for s in v_samples]))
                med_pitch = float(np.median([s[1] for s in v_samples]))
                px, py, err, passed = cal.verify(med_yaw, med_pitch, vt[0], vt[1])
                print(f"[CAL] Verification: target=({vt[0]:.2f},{vt[1]:.2f}) "
                      f"predicted=({px:.3f},{py:.3f}) error={err:.3f} passed={passed}")
            else:
                passed = True  # not enough samples, trust the fit
            _cal["fusion"] = GazeEyeFusion(cal)
            await websocket.send_json({
                "type": "calibration_done",
                "success": True,
                "verified": passed,
            })
            return
        if cal is None:
            return
        if cal.advance():
            await websocket.send_json({
                "type": "calibration_point",
                "target": list(cal.current_target),
                "index": cal.current_index,
            })
            return
        # All points collected: try to fit.
        _cal["collecting"] = False
        if cal.fit():
            # Enter verification: show the centre target and sample gaze there.
            _cal["verifying"] = True
            _cal["verify_target"] = [0.5, 0.5]
            _cal["verify_samples"] = []
            await websocket.send_json({
                "type": "calibration_verify",
                "target": [0.5, 0.5],
                "message": "Look at the dot to verify calibration",
            })
        else:
            _set_l2cs_calibrating(False)
            await websocket.send_json(
                {
                    "type": "calibration_done",
                    "success": False,
                    "error": "Not enough samples",
                }
            )

    async def _handle_control(data: dict):
        """Execute one JSON control command from the client."""
        nonlocal session_id
        msg_kind = data.get("type")
        if msg_kind == "frame":
            try:
                frame_bytes = base64.b64decode(data["image"])
            except (KeyError, TypeError, ValueError):
                logger.warning("[WS] ignoring frame message without a valid base64 image")
                return
            _slot["frame"] = frame_bytes
            _frame_ready.set()
        elif msg_kind == "start_session":
            if session_id:
                # A second start without an end would otherwise leave the old session open.
                await event_buffer.stop()
                await end_session(session_id)
            session_id = await create_session()
            event_buffer.start()
            for p in pipelines.values():
                if p is not None and hasattr(p, "reset_session"):
                    p.reset_session()
            await websocket.send_json({"type": "session_started", "session_id": session_id})
        elif msg_kind == "end_session":
            if session_id:
                await event_buffer.stop()
                summary = await end_session(session_id)
                if summary:
                    await websocket.send_json({"type": "session_ended", "summary": summary})
                session_id = None
        elif msg_kind == "calibration_start":
            await loop.run_in_executor(_inference_executor, _ensure_l2cs)
            cal = GazeCalibration()
            _cal["cal"] = cal
            _cal["collecting"] = True
            _cal["fusion"] = None
            _clear_verification()
            # Tell the L2CS pipeline to run every frame during calibration.
            _set_l2cs_calibrating(True)
            await websocket.send_json({
                "type": "calibration_started",
                "num_points": cal.num_points,
                "target": list(cal.current_target),
                "index": cal.current_index,
            })
        elif msg_kind == "calibration_next":
            await _handle_calibration_next()
        elif msg_kind == "calibration_cancel":
            _cal["cal"] = None
            _cal["collecting"] = False
            _cal["fusion"] = None
            _clear_verification()
            _set_l2cs_calibrating(False)
            await websocket.send_json({"type": "calibration_cancelled"})

    async def _receive_loop():
        """Receive messages as fast as possible. Binary = frame, text = control."""
        try:
            while running:
                msg = await websocket.receive()
                if msg.get("type", "") == "websocket.disconnect":
                    _stop()
                    return
                # Binary message -> JPEG frame (fast path, no base64).
                raw_bytes = msg.get("bytes")
                if raw_bytes:
                    _slot["frame"] = raw_bytes
                    _frame_ready.set()
                    continue
                # Text message -> JSON control command (or legacy base64 frame).
                text = msg.get("text")
                if not text:
                    continue
                try:
                    data = json.loads(text)
                except ValueError:
                    logger.warning("[WS] ignoring malformed JSON message")
                    continue
                if not isinstance(data, dict):
                    logger.warning("[WS] ignoring non-object JSON message")
                    continue
                await _handle_control(data)
        except WebSocketDisconnect:
            _stop()
        except Exception as e:
            print(f"[WS] receive error: {e}")
            _stop()

    async def _process_frame(raw: bytes) -> Optional[dict]:
        """Decode *raw* JPEG bytes, run the selected model and build the detection reply.

        Returns None when the bytes cannot be decoded as an image.
        """
        frame = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
        if frame is None:
            return None
        frame = cv2.resize(frame, (_inference_size[0], _inference_size[1]))

        collecting = _cal.get("collecting", False)
        # Both calibration phases (collection and verification) need L2CS gaze angles,
        # so route frames through L2CS whenever either is active.
        calibrating = collecting or _cal.get("verifying", False)
        if calibrating and pipelines.get("l2cs") is None:
            await loop.run_in_executor(_inference_executor, _ensure_l2cs)
        if calibrating and pipelines.get("l2cs") is not None:
            model_name = "l2cs"
        else:
            model_name = _cached_model_name
        if model_name == "l2cs" and pipelines.get("l2cs") is None:
            await loop.run_in_executor(_inference_executor, _ensure_l2cs)
        if pipelines.get(model_name) is None:
            model_name = "mlp"
        active_pipeline = pipelines.get(model_name)

        # L2CS boost: run L2CS alongside the base model (never during calibration).
        use_boost = (
            _l2cs_boost_enabled
            and model_name != "l2cs"
            and pipelines.get("l2cs") is not None
            and not calibrating
        )

        is_focused = False
        confidence = 0.0
        out = None
        fuse = None
        has_gaze = False
        landmarks_list = None
        if active_pipeline is not None:
            worker = _process_frame_with_l2cs_boost if use_boost else _process_frame_safe
            out = await loop.run_in_executor(
                _inference_executor, worker, active_pipeline, frame, model_name,
            )
            is_focused = out["is_focused"]
            if out.get("boost_active"):
                # The boosted dict may still hold the base model's own "mlp_prob"; the
                # fused score that decided is_focused is stored in raw_score.
                confidence = out.get("raw_score", 0.0)
            else:
                confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
            lm = out.get("landmarks")
            if lm is not None:
                landmarks_list = [
                    [round(float(lm[i, 0]), 3), round(float(lm[i, 1]), 3)]
                    for i in range(lm.shape[0])
                ]
            gaze_yaw = out.get("gaze_yaw")
            gaze_pitch = out.get("gaze_pitch")
            has_gaze = gaze_yaw is not None
            # Calibration sample collection (L2CS gaze angles).
            if collecting and _cal.get("cal") is not None and has_gaze and gaze_pitch is not None:
                _cal["cal"].collect_sample(gaze_yaw, gaze_pitch)
            # Verification samples. The flag is re-read because a command may have
            # arrived while inference was running.
            if _cal.get("verifying") and has_gaze and gaze_pitch is not None:
                _cal["verify_samples"].append((gaze_yaw, gaze_pitch))
            # Gaze fusion: one update per frame, applied before event logging and the
            # response to avoid double EMA smoothing.
            fusion = _cal.get("fusion")
            if fusion is not None and has_gaze and (model_name == "l2cs" or use_boost):
                fuse = fusion.update(gaze_yaw, gaze_pitch, lm)
                if model_name == "l2cs":
                    # L2CS standalone: fusion fully controls the focus decision.
                    is_focused = fuse["focused"]
                    confidence = fuse["focus_score"]
                else:
                    # Boost mode: blend the fused confidence with the continuous gaze score.
                    confidence = 0.6 * confidence + 0.4 * fuse["focus_score"]
                    is_focused = confidence >= _fused_threshold
            if session_id:
                metadata = {
                    "s_face": out.get("s_face", 0.0),
                    "s_eye": out.get("s_eye", 0.0),
                    "mar": out.get("mar", 0.0),
                    "model": model_name,
                }
                event_buffer.add(session_id, is_focused, confidence, metadata)

        resp = {
            "type": "detection",
            "focused": is_focused,
            "confidence": round(confidence, 3),
            "detections": [],
            "model": model_name,
            "fc": frame_count,
            "frame_count": frame_count,
            "eye_gaze_enabled": _l2cs_boost_enabled or model_name == "l2cs",
        }
        if out is not None:
            if out.get("yaw") is not None:
                resp["yaw"] = round(out["yaw"], 1)
                resp["pitch"] = round(out["pitch"], 1)
                resp["roll"] = round(out["roll"], 1)
            if out.get("mar") is not None:
                resp["mar"] = round(out["mar"], 3)
                resp["sf"] = round(out.get("s_face", 0), 3)
                resp["se"] = round(out.get("s_eye", 0), 3)
            # Gaze fusion fields + raw gaze angles for visualization.
            if fuse is not None:
                resp["gaze_x"] = fuse["gaze_x"]
                resp["gaze_y"] = fuse["gaze_y"]
                resp["on_screen"] = fuse["on_screen"]
                if model_name == "l2cs":
                    resp["focused"] = fuse["focused"]
                    resp["confidence"] = round(fuse["focus_score"], 3)
            if has_gaze and out.get("gaze_pitch") is not None:
                resp["gaze_yaw"] = round(out["gaze_yaw"], 4)
                resp["gaze_pitch"] = round(out["gaze_pitch"], 4)
            if out.get("boost_active"):
                resp["boost"] = True
                resp["base_score"] = out.get("base_score", 0)
                resp["l2cs_score"] = out.get("l2cs_score", 0)
        if landmarks_list is not None:
            resp["lm"] = landmarks_list
        return resp

    async def _process_loop():
        """Process only the latest frame, dropping stale ones."""
        nonlocal frame_count
        while running:
            await _frame_ready.wait()
            _frame_ready.clear()
            if not running:
                return
            raw = _slot["frame"]
            _slot["frame"] = None
            if raw is None:
                continue
            try:
                resp = await _process_frame(raw)
                if resp is not None:
                    await websocket.send_json(resp)
                    frame_count += 1
            except Exception as e:
                print(f"[WS] process error: {e}")

    try:
        await asyncio.gather(_receive_loop(), _process_loop())
    except Exception:
        logger.exception("WebSocket handler failed")
    finally:
        running = False
        if session_id:
            await event_buffer.stop()
            await end_session(session_id)
# ================ API ENDPOINTS ================
async def api_start_session():
    """Create a new focus session and return ``{"session_id": <new id>}``."""
    new_id = await create_session()
    return {"session_id": new_id}
async def api_end_session(data: SessionEnd):
    """End session ``data.session_id`` and return its summary.

    Raises:
        HTTPException: 404 when ``end_session`` produces no summary.
    """
    summary = await end_session(data.session_id)
    if summary:
        return summary
    raise HTTPException(status_code=404, detail="Session not found")
async def get_sessions(filter: str = "all", limit: int = 50, offset: int = 0):
    """List sessions, newest first.

    Args:
        filter: ``"today"`` (started since local midnight), ``"week"`` (last 7 days),
            ``"month"`` (last 30 days) or ``"all"`` (completed sessions only). Any other
            value applies no filter. The name shadows the builtin but is part of the
            public query interface.
        limit: page size; ``-1`` returns every matching row (used for export).
        offset: number of rows to skip when paginating.
    """
    lookback_days = {"week": 7, "month": 30}
    now = datetime.now()
    where = ""
    params = []
    if filter == "all":
        where = " WHERE end_time IS NOT NULL"
    elif filter == "today":
        where = " WHERE start_time >= ?"
        params.append(now.replace(hour=0, minute=0, second=0, microsecond=0).isoformat())
    elif filter in lookback_days:
        where = " WHERE start_time >= ?"
        params.append((now - timedelta(days=lookback_days[filter])).isoformat())

    sql = "SELECT * FROM focus_sessions" + where + " ORDER BY start_time DESC"
    if limit != -1:
        sql += " LIMIT ? OFFSET ?"
        params += [limit, offset]

    async with aiosqlite.connect(db_path) as db:
        db.row_factory = aiosqlite.Row
        cursor = await db.execute(sql, tuple(params))
        return [dict(record) for record in await cursor.fetchall()]
async def import_sessions(sessions: List[dict]):
    """Insert previously exported sessions into ``focus_sessions``.

    Missing fields fall back to defaults (0 / 0.0; ``created_at`` falls back to
    ``start_time``), which tolerates exports from older versions or hand edits. Rows are
    committed once at the end, so on any error nothing is committed and
    ``{"status": "error", "message": ...}`` is returned.
    """
    insert_sql = """
INSERT INTO focus_sessions (
start_time, end_time, duration_seconds, focus_score,
total_frames, focused_frames, created_at
)
VALUES (?, ?, ?, ?, ?, ?, ?)
"""

    def as_row(record):
        return (
            record.get('start_time'),
            record.get('end_time'),
            record.get('duration_seconds', 0),
            record.get('focus_score', 0.0),
            record.get('total_frames', 0),
            record.get('focused_frames', 0),
            record.get('created_at', record.get('start_time')),
        )

    inserted = 0
    try:
        async with aiosqlite.connect(db_path) as db:
            for record in sessions:
                await db.execute(insert_sql, as_row(record))
                inserted += 1
            await db.commit()
    except Exception as e:
        print(f"Import Error: {e}")
        return {"status": "error", "message": str(e)}
    return {"status": "success", "count": inserted}
async def clear_history():
    """Delete every focus event and session.

    Returns ``{"status": "success", ...}``, or ``{"status": "error", "message": ...}``
    when the database operation raises.
    """
    # focus_events rows reference focus_sessions through a foreign key, so they go first.
    statements = ("DELETE FROM focus_events", "DELETE FROM focus_sessions")
    try:
        async with aiosqlite.connect(db_path) as db:
            for statement in statements:
                await db.execute(statement)
            await db.commit()
    except Exception as e:
        return {"status": "error", "message": str(e)}
    return {"status": "success", "message": "History cleared"}
async def get_session(session_id: int):
    """Return one session row with its events (ordered by timestamp) under ``"events"``.

    Raises:
        HTTPException: 404 when no session has this id.
    """
    async with aiosqlite.connect(db_path) as db:
        db.row_factory = aiosqlite.Row
        session_cursor = await db.execute("SELECT * FROM focus_sessions WHERE id = ?", (session_id,))
        session_row = await session_cursor.fetchone()
        if not session_row:
            raise HTTPException(status_code=404, detail="Session not found")
        result = dict(session_row)
        event_cursor = await db.execute("SELECT * FROM focus_events WHERE session_id = ? ORDER BY timestamp", (session_id,))
        result['events'] = [dict(event) for event in await event_cursor.fetchall()]
    return result
async def get_settings():
    """Return the persisted user settings plus the in-memory ``l2cs_boost`` flag.

    Falls back to ``{"model_name": "mlp"}`` when no settings row exists yet.
    """
    async with aiosqlite.connect(db_path) as db:
        db.row_factory = aiosqlite.Row
        cursor = await db.execute("SELECT * FROM user_settings WHERE id = 1")
        row = await cursor.fetchone()
        settings = {"model_name": "mlp"} if not row else dict(row)
    settings['l2cs_boost'] = _l2cs_boost_enabled
    return settings
async def update_settings(settings: SettingsUpdate):
    """Validate, persist and apply a partial settings update.

    ``model_name`` is honoured only when it names a key of ``pipelines``; ``"l2cs"`` is
    lazy-loaded on demand, and any other model must already be loaded.
    ``l2cs_boost=True`` also requires L2CS to load. Every requested change is validated
    before any state is touched, so a rejected request leaves the active model, the
    boost flag and the database unchanged.

    Returns:
        ``{"status": "success", "updated": bool}``; ``updated`` is True when a persisted
        setting (currently only ``model_name``) was written.

    Raises:
        HTTPException: 400 when the requested model or L2CS boost is unavailable.
    """
    global _cached_model_name, _l2cs_boost_enabled
    loop = asyncio.get_running_loop()

    new_model = None
    if settings.model_name is not None and settings.model_name in pipelines:
        if settings.model_name == "l2cs":
            loaded = await loop.run_in_executor(_inference_executor, _ensure_l2cs)
            if not loaded:
                raise HTTPException(status_code=400, detail=f"L2CS model unavailable: {_l2cs_error}")
        elif pipelines[settings.model_name] is None:
            raise HTTPException(status_code=400, detail=f"Model '{settings.model_name}' not loaded")
        new_model = settings.model_name

    if settings.l2cs_boost:
        loaded = await loop.run_in_executor(_inference_executor, _ensure_l2cs)
        if not loaded:
            raise HTTPException(status_code=400, detail=f"L2CS boost unavailable: {_l2cs_error}")

    # Everything is valid; persist first, then update the in-memory state.
    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute("SELECT id FROM user_settings WHERE id = 1")
        if not await cursor.fetchone():
            await db.execute("INSERT INTO user_settings (id, model_name) VALUES (1, 'mlp')")
            await db.commit()
        if new_model is not None:
            await db.execute("UPDATE user_settings SET model_name = ? WHERE id = 1", (new_model,))
            await db.commit()

    if new_model is not None:
        _cached_model_name = new_model
    if settings.l2cs_boost is not None:
        _l2cs_boost_enabled = settings.l2cs_boost
    return {"status": "success", "updated": new_model is not None}
async def get_system_stats():
    """Return server CPU and memory usage for UI display.

    CPU usage is sampled over 0.1 s in the default thread-pool executor, so the event
    loop is not blocked while sampling. All values are ``None`` when ``psutil`` is not
    installed.
    """
    try:
        import psutil
    except ImportError:
        return {
            "cpu_percent": None,
            "memory_percent": None,
            "memory_used_mb": None,
            "memory_total_mb": None,
        }
    loop = asyncio.get_running_loop()
    # cpu_percent(interval=0.1) sleeps for the sampling window; keep that off the loop.
    cpu = await loop.run_in_executor(None, psutil.cpu_percent, 0.1)
    mem = psutil.virtual_memory()
    return {
        "cpu_percent": round(cpu, 1),
        "memory_percent": round(mem.percent, 1),
        "memory_used_mb": round(mem.used / (1024 * 1024), 0),
        "memory_total_mb": round(mem.total / (1024 * 1024), 0),
    }
async def get_stats_summary():
    """Aggregate statistics over completed sessions (``end_time IS NOT NULL``).

    Returns the session count, the summed ``duration_seconds`` as an int, the mean
    ``focus_score`` rounded to 3 places, and ``streak_days``. The streak is the number of
    consecutive calendar days, counting back from today, on which a completed session
    started; it is 0 when there was none today.
    """
    async with aiosqlite.connect(db_path) as db:

        async def scalar(sql):
            cur = await db.execute(sql)
            return (await cur.fetchone())[0]

        total_sessions = await scalar("SELECT COUNT(*) FROM focus_sessions WHERE end_time IS NOT NULL")
        total_focus_time = await scalar("SELECT SUM(duration_seconds) FROM focus_sessions WHERE end_time IS NOT NULL") or 0
        avg_focus_score = await scalar("SELECT AVG(focus_score) FROM focus_sessions WHERE end_time IS NOT NULL") or 0.0
        date_cursor = await db.execute(
            """
SELECT DISTINCT DATE(start_time) as session_date
FROM focus_sessions
WHERE end_time IS NOT NULL
ORDER BY session_date DESC
"""
        )
        session_dates = [record[0] for record in await date_cursor.fetchall()]

    today = datetime.now().date()
    streak_days = 0
    for days_back, date_str in enumerate(session_dates):
        if datetime.fromisoformat(date_str).date() != today - timedelta(days=days_back):
            break
        streak_days += 1

    return {
        'total_sessions': total_sessions,
        'total_focus_time': int(total_focus_time),
        'avg_focus_score': round(avg_focus_score, 3),
        'streak_days': streak_days
    }
async def get_available_models():
    """Return model names, statuses, and which is currently active.

    Status values: ``"ready"`` (loaded), ``"lazy"`` (L2CS weights present but not
    loaded), ``"error"`` (L2CS failed to load; message in ``errors``) and
    ``"unavailable"``. ``current`` is the saved model, or the first available one when
    the saved model is not available.
    """
    statuses = {}
    errors = {}
    available = []
    for name, pipe in pipelines.items():
        if pipe is not None:
            status = "ready"
        elif name != "l2cs":
            status = "unavailable"
        elif is_l2cs_weights_available():
            status = "lazy"
        elif _l2cs_error:
            status = "error"
            errors[name] = _l2cs_error
        else:
            status = "unavailable"
        statuses[name] = status
        if status in ("ready", "lazy"):
            available.append(name)

    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
        row = await cursor.fetchone()
    current = row[0] if row else "mlp"
    if available and current not in available:
        current = available[0]

    boost_ok = statuses.get("l2cs") in ("ready", "lazy") and current != "l2cs"
    return {
        "available": available,
        "current": current,
        "statuses": statuses,
        "errors": errors,
        "l2cs_boost": _l2cs_boost_enabled,
        "l2cs_boost_available": boost_ok,
    }
async def l2cs_status():
    """Report whether L2CS weights exist, whether the pipeline is loaded, and the last load error."""
    is_loaded = pipelines.get("l2cs") is not None
    status = {"weights_available": is_l2cs_weights_available()}
    status["loaded"] = is_loaded
    status["error"] = _l2cs_error
    return status
async def get_mesh_topology():
    """Return face-mesh tessellation edge pairs so the client can draw the mesh itself."""
    edges = get_tesselation_connections()
    return {"tessellation": edges}
async def health_check():
    """Report liveness, the names of loaded pipelines, and whether the DB file exists."""
    loaded = [name for name in pipelines if pipelines[name] is not None]
    db_present = os.path.exists(db_path)
    return {"status": "healthy", "models_loaded": loaded, "database": db_present}
# ================ STATIC FILES (SPA SUPPORT) ================
# Frontend paths are anchored to this file's directory so the server works from any cwd.
# A built `dist/` bundle (one containing index.html) takes precedence over `static/`.
_BASE_DIR = Path(__file__).resolve().parent
_DIST_DIR = _BASE_DIR / "dist"
_STATIC_DIR = _BASE_DIR / "static"
if (_DIST_DIR / "index.html").is_file():
    _FRONTEND_DIR = _DIST_DIR
else:
    _FRONTEND_DIR = _STATIC_DIR
_ASSETS_DIR = _FRONTEND_DIR / "assets"

# Mount /assets (JS/CSS) first so /assets/* requests are served as files and never
# fall through to the SPA catch-all.
if _ASSETS_DIR.is_dir():
    app.mount("/assets", StaticFiles(directory=str(_ASSETS_DIR)), name="assets")
# 2. Catch-all for SPA: serve index.html for app routes, never for /assets (would break JS MIME type)
async def serve_react_app(full_path: str, request: Request):
    """Serve a real file from the frontend directory, otherwise the SPA's index.html.

    Paths starting with ``api``, ``ws`` or ``assets`` get a 404 instead of HTML, so API
    typos and missing JS/CSS assets fail visibly rather than returning the app shell
    (HTML for a module script would break its MIME type). A request path that resolves
    outside the frontend directory (``..`` segments, absolute paths) is never served as
    a file; it gets index.html like any other app route.
    """
    if full_path.startswith(("api", "ws", "assets")):
        raise HTTPException(status_code=404, detail="Not Found")
    if full_path:
        frontend_root = _FRONTEND_DIR.resolve()
        file_path = (_FRONTEND_DIR / full_path).resolve()
        # Path-traversal guard: the resolved target must live inside the frontend dir.
        if frontend_root in file_path.parents and file_path.is_file():
            return FileResponse(str(file_path))
    index_path = _FRONTEND_DIR / "index.html"
    if index_path.is_file():
        return FileResponse(str(index_path))
    return {"message": "React app not found. Please run 'npm run build' and copy dist to static if needed."}