# NOTE: Hugging Face Spaces page residue ("Spaces: Paused" status header)
# removed from the code path — kept here as a comment so the module parses.
| """ | |
| Model Manager for real-time motion generation (HF Space version) | |
| Loads model from Hugging Face Hub instead of local checkpoints. | |
| """ | |
| import json | |
| import os | |
| import threading | |
| import time | |
| from collections import deque | |
| import numpy as np | |
| import torch | |
| import traceback | |
| import gc | |
| import math | |
| import glob | |
| import urllib.request | |
| from transformers import AutoModel | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # JOINT RECOVERY β inlined from motion_process.py | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββ | |
def qinv(q):
    """Return the conjugate of quaternion(s) *q* in wxyz layout, shape (*, 4).

    For unit quaternions (as used here) the conjugate equals the inverse.
    """
    assert q.shape[-1] == 4, "q must be a tensor of shape (*, 4)"
    # Flip the sign of the vector part (x, y, z); keep the scalar part w.
    sign = torch.ones_like(q)
    sign[..., 1:] = -sign[..., 1:]
    return q * sign
def qrot(q, v):
    """Rotate 3-vectors *v* by quaternions *q* (wxyz), batched over leading dims.

    Shapes: q is (*, 4), v is (*, 3) with identical leading dimensions; the
    result has the same shape as *v*.
    """
    assert q.shape[-1] == 4
    assert v.shape[-1] == 3
    assert q.shape[:-1] == v.shape[:-1]
    out_shape = list(v.shape)
    flat_q = q.contiguous().view(-1, 4)
    flat_v = v.contiguous().view(-1, 3)
    # Rodrigues-style expansion: v' = v + 2 * (w * (u x v) + u x (u x v)),
    # where u is the quaternion's vector part and w its scalar part.
    u = flat_q[:, 1:]
    w = flat_q[:, :1]
    cross1 = torch.cross(u, flat_v, dim=1)
    cross2 = torch.cross(u, cross1, dim=1)
    rotated = flat_v + 2.0 * (w * cross1 + cross2)
    return rotated.view(out_shape)
class StreamJointRecovery263:
    """
    Stream version of recover_joint_positions_263 that processes one frame at a time.

    Maintains cumulative state for the root rotation angle and root position.

    Key insight: the batch version uses the PREVIOUS frame's velocity for the
    current frame, so velocity application is delayed by one frame here.

    Args:
        joints_num: Number of joints in the skeleton.
        smoothing_alpha: EMA smoothing factor (0.0 to 1.0)
            - 1.0 = no smoothing (default), output follows input exactly
            - 0.0 = infinite smoothing, output never changes
            - Recommended values: 0.3-0.7 for visible smoothing
            - Formula: smoothed = alpha * current + (1 - alpha) * previous
    """

    def __init__(self, joints_num: int, smoothing_alpha: float = 1.0):
        self.joints_num = joints_num
        # Clamp so the EMA below stays a convex combination even for bad input.
        self.smoothing_alpha = np.clip(smoothing_alpha, 0.0, 1.0)
        self.reset()

    def reset(self):
        """Reset the accumulated state (call when a new stream starts)."""
        self.r_rot_ang_accum = 0.0  # accumulated root yaw angle (radians)
        self.r_pos_accum = np.array([0.0, 0.0, 0.0])  # accumulated root position
        # Previous frame's velocities, applied with a one-frame delay.
        self.prev_rot_vel = 0.0
        self.prev_linear_vel = np.array([0.0, 0.0])
        # Previous smoothed joints for the EMA filter.
        self.prev_smoothed_joints = None

    def process_frame(self, frame_data: np.ndarray, heading_override=None) -> np.ndarray:
        """
        Process a single frame and return joint positions for that frame.

        Args:
            frame_data: numpy array of shape (263,) for a single frame
            heading_override: float or None. If set, overrides AI rotation with
                this angle (in radians). AI velocity magnitude is preserved,
                applied in heading direction. None = original AI behavior.

        Returns:
            joints: numpy array of shape (joints_num, 3) representing joint positions
        """
        # Convert to torch tensor. NOTE: torch.from_numpy shares memory with
        # frame_data; .float() only copies when the dtype is not float32.
        feature_vec = torch.from_numpy(frame_data).float()
        # Extract current frame's velocities (will be used in NEXT frame)
        curr_rot_vel = feature_vec[0].item()
        # BUGFIX: .copy() — when frame_data is already float32, .numpy() here is
        # a view into the caller's buffer. We hold this value until the next
        # call, so a caller that reuses/mutates its frame buffer would silently
        # corrupt the delayed velocity without the copy.
        curr_linear_vel = feature_vec[1:3].numpy().copy()
        # --- HEADING OVERRIDE ---
        if heading_override is not None:
            # User controls direction — override AI rotation
            self.r_rot_ang_accum = heading_override
        else:
            # Original behavior — AI controls direction (one-frame delay)
            self.r_rot_ang_accum += self.prev_rot_vel
        # Calculate current rotation quaternion (wxyz, yaw about Y axis)
        # using the accumulated angle.
        r_rot_quat = torch.zeros(4, dtype=torch.float32)
        r_rot_quat[0] = np.cos(self.r_rot_ang_accum)
        r_rot_quat[2] = np.sin(self.r_rot_ang_accum)
        # Create velocity vector with Y=0 using PREVIOUS frame's velocity
        r_vel = np.array([self.prev_linear_vel[0], 0.0, self.prev_linear_vel[1]])
        # Apply inverse rotation to velocity using CURRENT rotation
        r_vel_torch = torch.from_numpy(r_vel.astype(np.float32)).float()
        r_vel_rotated = qrot(qinv(r_rot_quat).unsqueeze(0), r_vel_torch.unsqueeze(0))
        r_vel_rotated = r_vel_rotated.squeeze(0).numpy()
        # Update accumulated position with rotated velocity
        self.r_pos_accum += r_vel_rotated
        # Root Y (height) comes directly from the feature vector, not integrated.
        r_pos = self.r_pos_accum.copy()
        r_pos[1] = feature_vec[3].item()
        # Extract local joint positions ((joints_num - 1) * 3 values).
        positions = feature_vec[4 : (self.joints_num - 1) * 3 + 4]
        positions = positions.view(-1, 3).float()
        # Apply inverse rotation to local joints
        r_rot_quat_expanded = (
            qinv(r_rot_quat).unsqueeze(0).expand(positions.shape[0], 4)
        )
        positions = qrot(r_rot_quat_expanded, positions)
        # Add root XZ to joints
        positions[:, 0] += r_pos[0]
        positions[:, 2] += r_pos[2]
        # Concatenate root and joints
        r_pos_torch = torch.from_numpy(r_pos).float()
        positions = torch.cat([r_pos_torch.unsqueeze(0), positions], dim=0)
        # Convert to numpy
        joints_np = positions.detach().cpu().numpy()
        # Apply EMA smoothing if enabled
        if self.smoothing_alpha < 1.0:
            if self.prev_smoothed_joints is None:
                # First frame, no smoothing possible
                self.prev_smoothed_joints = joints_np.copy()
            else:
                # EMA: smoothed = alpha * current + (1 - alpha) * previous
                joints_np = (
                    self.smoothing_alpha * joints_np
                    + (1.0 - self.smoothing_alpha) * self.prev_smoothed_joints
                )
                self.prev_smoothed_joints = joints_np.copy()
        # Store current velocities for next frame (one-frame delay scheme)
        self.prev_rot_vel = curr_rot_vel
        self.prev_linear_vel = curr_linear_vel
        return joints_np
| # βββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # BRAIN MODULE β LLM Cognitive Loop (Kimi K2.5) | |
| # | |
| # Perceive β Think β Act | |
| # Brain reads only from scene_context (sensory data). | |
| # Stimuli originate in the client (body) and arrive via sensors. | |
| # Brain has no concept of "stimulus" β it only sees sensor readings. | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# System prompt for the LLM brain: enforces a strict Perceive -> Predict ->
# Decide protocol and an exact two-line answer (PREDICT + MOTION) that
# _parse_world_model_output() parses. Runtime string — do not reword; the
# downstream parser depends on the "PREDICT:" / "MOTION: a person ..." shape.
BRAIN_SYSTEM = """You are the cognitive brain of a 3D humanoid character in a 3D world.
PROCESS β you MUST follow these steps:
1. PERCEIVE: Read all sensor data carefully, including any equipped tool.
2. PREDICT: For each direction (left, right, forward, back), predict what would happen in 3 seconds. Write safe or danger with a 1-2 word reason.
3. DECIDE: Based on predictions AND equipped tool, choose the best motion.
TOOL RULES β if a tool is equipped, USE IT when appropriate:
- sword/axe: ATTACK approaching threats instead of fleeing. Include "swinging sword" or "chopping with axe" in motion.
- shield: BLOCK charging threats instead of fleeing. Include "blocking with shield" or "raising shield" in motion.
- torch: USE to scare beasts or illuminate dark areas. Include "thrusting torch" or "holding torch forward" in motion.
- rpg: ANTI-TANK weapon! Fire at enemy tanks or armored threats. Include "firing rpg at the tank" in motion. Against non-armored targets, use other weapons.
- No tool: Default behavior β flee from danger, walk when safe.
KEY PRINCIPLE: A character WITH a weapon should FIGHT or DEFEND, not flee. Only flee if overwhelmed (multiple threats, no escape route AND no weapon advantage).
OUTPUT FORMAT β exactly 2 lines, nothing else:
PREDICT: left=safe/danger, right=safe/danger, fwd=safe/danger, back=safe/danger
MOTION: a person [max 12 words describing the chosen motion]
EXAMPLES (with tools):
PREDICT: left=safe(open), right=safe(open), fwd=danger(beast), back=safe(open)
MOTION: a person charging forward swinging sword at the approaching beast
PREDICT: left=danger(wall), right=safe(open), fwd=danger(beast), back=safe(open)
MOTION: a person raising shield and bracing for the beast attack
PREDICT: left=safe(open), right=safe(open), fwd=danger(beast), back=safe(open)
MOTION: a person thrusting torch forward to scare the growling beast
EXAMPLES (without tools):
PREDICT: left=safe(open), right=danger(wall), fwd=danger(beast), back=safe(open)
MOTION: a person turning left and running away from the beast
PREDICT: left=safe(open), right=safe(open), fwd=safe(open), back=safe(open)
MOTION: a person walking forward confidently on open ground"""
class BrainModule:
    """LLM cognitive brain — a small world-model loop.

    Perceive -> Predict -> Decide -> Act:
      1. Read sensor data (Perceive)
      2. Predict each direction's future (Predict) — core world model
      3. Choose the best action based on predictions (Decide)
      4. Pass the motion description to FloodDiffusion (Act)

    Runs on a daemon thread; the body feeds sensor readings in via
    set_sensory_data() and polls results via get_decision()/get_prediction().
    """

    def __init__(self):
        self.api_key = os.environ.get("FIREWORKS_API_KEY", "")
        self.model = "accounts/fireworks/models/kimi-k2p5"
        self.api_url = "https://api.fireworks.ai/inference/v1/chat/completions"
        # Without an API key the brain stays disabled and callers use the
        # rule-based fallback prompt instead.
        self.enabled = bool(self.api_key)
        self.interval = 3.0  # seconds between LLM calls
        self._last_applied_decision = None  # last applied decision
        self.last_call_time = 0
        self.current_decision = None
        self.current_prediction = None  # world model prediction result
        self.memory = deque(maxlen=5)  # recent decisions, for prompt continuity
        self._lock = threading.Lock()
        self._thread = None
        self._stop = False
        if self.enabled:
            print("[Brain] Kimi K2.5 world model brain ready (PerceiveβPredictβDecide)")
        else:
            print("[Brain] FIREWORKS_API_KEY not set β rule-based fallback")

    def start(self):
        """Start the background think loop (no-op when disabled)."""
        if not self.enabled:
            return
        self._stop = False
        self._thread = threading.Thread(target=self._think_loop, daemon=True)
        self._thread.start()
        print("[Brain] Think thread started")

    def stop(self):
        """Stop the think loop and clear the published decision/memory."""
        self._stop = True
        if self._thread:
            self._thread.join(timeout=3.0)
        # Clear under the lock so readers never see a torn update.
        with self._lock:
            self.current_decision = None
        self.memory.clear()
        print("[Brain] Think thread stopped")

    def get_decision(self):
        """Return the latest motion decision string (or None)."""
        with self._lock:
            return self.current_decision

    def get_prediction(self):
        """Return world model prediction result."""
        with self._lock:
            return self.current_prediction

    def _think_loop(self):
        # Poll at 5 Hz; actually call the LLM only every self.interval seconds.
        while not self._stop:
            now = time.time()
            if now - self.last_call_time >= self.interval:
                self._do_think()
                self.last_call_time = now
            time.sleep(0.2)

    def set_sensory_data(self, scene_ctx, current_text, heading_rad):
        """Store the latest sensor snapshot for the next think cycle."""
        with self._lock:
            self._scene_ctx = scene_ctx
            self._current_text = current_text
            self._heading_rad = heading_rad

    def _do_think(self):
        """One think cycle: build the prompt, call the LLM, parse and publish."""
        try:
            with self._lock:
                ctx = getattr(self, '_scene_ctx', None)
                base_text = getattr(self, '_current_text', 'a person standing idle')
                heading = getattr(self, '_heading_rad', None)
            user_msg = self._build_brain_prompt(ctx, base_text, heading)
            messages = [
                {"role": "system", "content": BRAIN_SYSTEM},
                {"role": "user", "content": user_msg},
            ]
            payload = json.dumps({
                "model": self.model,
                "messages": messages,
                "max_tokens": 120,
                "temperature": 0.7,
                "top_p": 0.9,
                "reasoning_effort": "off",
            })
            req = urllib.request.Request(
                self.api_url,
                data=payload.encode('utf-8'),
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {self.api_key}",
                },
            )
            with urllib.request.urlopen(req, timeout=8) as resp:
                result = json.loads(resp.read().decode('utf-8'))
            raw = result["choices"][0]["message"]["content"].strip()
            prediction, decision = self._parse_world_model_output(raw)
            with self._lock:
                self.current_decision = decision
                self.current_prediction = prediction
                # BUGFIX: the parser may return None for malformed LLM output.
                # Storing None in memory crashed the str.join over recent
                # actions in _build_brain_prompt() on every later cycle.
                if decision:
                    self.memory.append(decision)
            pred_short = prediction[:60] if prediction else "none"
            print(f"[Brain] Prediction: {pred_short}")
            # Diagnostic: NPC detection
            _ctx = getattr(self, '_scene_ctx', None)
            if _ctx and _ctx.get('npc_nearby') is not None:
                print(f"[Brain] NPC detected: {_ctx.get('npc_type')} {_ctx['npc_nearby']}m β {_ctx.get('npc_behavior','?')}")
            print(f"[Brain] Decision: {decision}")
        except Exception as e:
            print(f"[Brain] Error: {e}")

    def _parse_world_model_output(self, raw):
        """Parse world model output: PREDICT line + MOTION line.

        Returns:
            (prediction, decision): either element may be None when the
            corresponding line is missing or the motion text is too short.
        """
        prediction = None
        decision = None
        for line in raw.split('\n'):
            line = line.strip()
            if not line:
                continue
            up = line.upper()
            if up.startswith('PREDICT'):
                # content after "PREDICT:"
                idx = line.find(':')
                if idx >= 0:
                    prediction = line[idx+1:].strip()
            elif up.startswith('MOTION'):
                idx = line.find(':')
                if idx >= 0:
                    motion = line[idx+1:].strip().strip('"\'`')
                    # anchor on "a person" so any LLM preamble is dropped
                    pidx = motion.lower().find('a person')
                    if pidx >= 0:
                        decision = motion[pidx:]
                    elif len(motion) > 5:
                        decision = motion
            # no PREDICT/MOTION tags — legacy format starting with "a person"
            elif 'a person' in line.lower() and decision is None:
                pidx = line.lower().find('a person')
                decision = line[pidx:].strip('"\'`')
        # limit decision length
        if decision and len(decision) > 100:
            decision = decision[:100]
        # too-short decisions are treated as parse failures
        if not decision or len(decision) < 5:
            decision = None
        return prediction, decision

    def _build_brain_prompt(self, ctx, base_text, heading_rad):
        """Convert sensory data (scene_context) to natural language.

        The brain sees only what the sensors report — no concept of a
        "stimulus", only raw sensor readings.

        Args:
            ctx: scene-context dict from the client (may be None).
            base_text: current motion prompt text.
            heading_rad: heading override in radians, or None when standing.
        """
        lines = []
        if ctx:
            # -- Vision (eyes) --
            wf = ctx.get('wall_front')
            wl = ctx.get('wall_left')
            wr = ctx.get('wall_right')
            lines.append(f"Eyes: front={'open' if wf is None else f'{wf}m wall'}, "
                         f"left={'open' if wl is None else f'{wl}m wall'}, "
                         f"right={'open' if wr is None else f'{wr}m wall'}")
            # visibility state
            vis = ctx.get('visibility')
            if vis:
                lines.append(f"Visibility: {vis}")
            # what is visible ahead
            visual = ctx.get('visual')
            if visual:
                lines.append(f"Sees: {visual}")
            # -- Feet/Ground (touch) --
            ground_parts = []
            slope = ctx.get('ground_slope', 'flat')
            ground_parts.append(slope)
            if ctx.get('on_stairs'):
                ground_parts.append('stairs')
            if ctx.get('ground_shaking'):
                ground_parts.append('SHAKING VIOLENTLY')
            if ctx.get('ground_temperature'):
                ground_parts.append(f'temperature: {ctx["ground_temperature"]}')
            lines.append(f"Ground: {', '.join(ground_parts)}")
            # -- Skin (wind, rain) --
            wind = ctx.get('wind')
            if wind:
                lines.append(f"Wind: {wind}")
            weather = ctx.get('weather')
            if weather:
                lines.append(f"Weather: {weather}")
            # -- Hearing (ears) --
            sound = ctx.get('sound')
            if sound:
                lines.append(f"Hears: {sound}")
            # -- Proximity detection (other entities) --
            npc_dist = ctx.get('npc_nearby')
            if npc_dist is not None:
                behavior = ctx.get('npc_behavior', 'present')
                npc_type = ctx.get('npc_type', 'unknown')
                npc_dir = ctx.get('npc_direction', 'nearby')
                if npc_type == 'enemy_tank':
                    lines.append(f"π¨ ENEMY TANK {npc_dir}: {behavior}, {npc_dist}m away!")
                    lines.append(f"FACE {npc_dir.upper()} and fire! Include direction in MOTION!")
                    lines.append("If you have RPG: 'firing rpg at the tank'. If tank: 'firing cannon'.")
                else:
                    lines.append(f"Entity {npc_dir}: {behavior}, {npc_dist}m away")
                    lines.append(f"FACE {npc_dir.upper()} to react!")
            crowd = ctx.get('crowd')
            if crowd:
                lines.append("Surrounded by crowd")
            # -- Avatar identity --
            avatar = ctx.get('avatar_type', 'humanoid')
            if avatar == 'tank':
                lines.append("IDENTITY: You are a MILITARY TANK, not a human.")
                lines.append("TANK RULES: You roll on treads, you have a cannon. You do NOT walk, run, or sprint.")
                lines.append(" - ANY threat detected β IMMEDIATELY fire cannon. Say 'firing cannon' in MOTION.")
                lines.append(" - Enemy tank detected β ENGAGE! Fire cannon while advancing. NEVER retreat from equal threat.")
                lines.append(" - Beast/person threat β Fire cannon to eliminate. You are armored, do not flee.")
                lines.append(" - Multiple threats β ROTATE turret and engage each. Retreat only if completely surrounded.")
                lines.append(" - No threats β PATROL forward steadily. Scan surroundings.")
                lines.append(" - Walls/obstacles β REVERSE and find alternate route. You cannot jump or climb.")
                lines.append(" - CRITICAL: When under attack, ALWAYS fire back. Include 'firing cannon' in your MOTION.")
                lines.append(" - Use TANK verbs: roll, advance, reverse, rotate, fire, aim, engage, patrol, halt.")
            # -- World context --
            world = ctx.get('world')
            if world:
                world_desc = {
                    'inferno': 'INFERNO: Fire pillars appear and disappear. Ground is burning. Stay alert and dodge constantly.',
                    'horde': 'HORDE: Multiple hostile creatures surround you. Fight or find a gap to escape.',
                    'countdown': 'COUNTDOWN: Walls are closing in from left, right, and front. ONLY escape is BACKWARD. Hurry!',
                    'dilemma': 'DILEMMA: A woman is being chased by a beast nearby. You can choose to help her or flee.',
                }.get(world)
                if world_desc:
                    lines.append(f"β SCENARIO: {world_desc}")
            # -- Equipped tool (hand) --
            tool = ctx.get('equipped_tool')
            auto_tool = ctx.get('auto_tool_mode', False)
            tool_descs = {
                'sword': 'a sharp sword (melee attack weapon)',
                'axe': 'a heavy axe (melee attack/chop weapon)',
                'torch': 'a burning torch (light source, can scare beasts)',
                'rpg': 'RPG-7 anti-tank rocket launcher',
                'shield': 'a sturdy shield (defensive blocking)',
            }
            if tool:
                lines.append(f"Equipped: {tool_descs.get(tool, tool)}")
            elif auto_tool:
                avail = ctx.get('available_tools', [])
                avail_str = ', '.join(avail)
                lines.append(f"Equipped: nothing β but you have access to: [{avail_str}]")
                lines.append("AUTO-TOOL: Choose the best tool for this situation. Say 'grab [tool]' in MOTION if needed.")
            else:
                lines.append("Equipped: nothing (bare hands)")
            # -- Internal body sensors (proprioception) --
            fatigue = ctx.get('body_fatigue')
            if fatigue:
                lines.append(f"Body fatigue: {fatigue}")
            balance = ctx.get('body_balance')
            if balance:
                lines.append(f"Balance: {balance}")
            instinct = ctx.get('body_instinct')
            if instinct:
                lines.append(f"Instinct: {instinct}")
            body_state = ctx.get('body_state')
            if body_state:
                lines.append(f"Feeling: {body_state}")
        else:
            lines.append("Eyes: all open, Ground: flat")
        # movement state
        if heading_rad is not None:
            deg = math.degrees(heading_rad) % 360
            lines.append(f"Moving forward, heading {deg:.0f}deg")
        else:
            lines.append("Standing still")
        lines.append(f"Current: {base_text}")
        # recent memory (previous decisions — context continuity);
        # filter out any falsy entries defensively before joining.
        if self.memory:
            recent = [d for d in list(self.memory)[-3:] if d]
            if recent:
                lines.append(f"Recent actions: {' β '.join(recent)}")
        # recent prediction (world model continuity)
        if self.current_prediction:
            lines.append(f"Last prediction: {self.current_prediction}")
        lines.append("")
        lines.append("Now PREDICT each direction, then choose MOTION:")
        return "\n".join(lines)
class FrameBuffer:
    """Thread-safe FIFO of generated frames shared between producer/consumer threads."""

    def __init__(self, target_buffer_size=4):
        # Hard cap of 100 frames; the deque silently drops the oldest beyond that.
        self.buffer = deque(maxlen=100)
        self.target_size = target_buffer_size
        self.lock = threading.Lock()

    def add_frame(self, joints):
        """Append one frame (joint array) to the queue."""
        with self.lock:
            self.buffer.append(joints)

    def get_frame(self):
        """Pop and return the oldest frame, or None when the queue is empty."""
        with self.lock:
            return self.buffer.popleft() if self.buffer else None

    def size(self):
        """Return the number of frames currently buffered."""
        with self.lock:
            return len(self.buffer)

    def clear(self):
        """Drop all buffered frames."""
        with self.lock:
            self.buffer.clear()

    def needs_generation(self):
        """True when the buffer is below target and the producer should refill."""
        return self.size() < self.target_size
| class ModelManager: | |
| """ | |
| Manages model loading from HF Hub and real-time frame generation | |
| """ | |
    def __init__(self, model_name):
        """Load the model from HF Hub and set up all streaming/generation state.

        Args:
            model_name: Hugging Face Hub repo id passed to _load_models().
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        # Load models from HF Hub
        self.vae, self.model = self._load_models(model_name)
        # Build config dicts from model's individual attributes (HF model API);
        # start_generation() restores these after runtime tweaks.
        self._base_schedule_config = {
            "chunk_size": self.model.chunk_size,
            "steps": self.model.noise_steps,
        }
        self._base_cfg_config = {
            "cfg_scale": self.model.cfg_scale,
        }
        # Frame buffer (for active session)
        self.frame_buffer = FrameBuffer(target_buffer_size=16)
        # Broadcast buffer (for spectators) - append-only with frame IDs
        self.broadcast_frames = deque(maxlen=200)
        self.broadcast_id = 0
        self.broadcast_lock = threading.Lock()
        # Stream joint recovery with smoothing
        self.smoothing_alpha = 0.5  # Default: medium smoothing
        self.stream_recovery = StreamJointRecovery263(
            joints_num=22, smoothing_alpha=self.smoothing_alpha
        )
        # World model: heading override (None = AI controls direction)
        self.heading_override = None
        # World model: scene context from client (environment perception)
        self.scene_context = None
        # World model: LLM Brain (Kimi K2.5)
        self.brain = BrainModule()
        # NPC stream (an NPCStream instance is attached externally; None until then)
        self.npc = None
        self._npc_lock = threading.Lock()  # guards access to self.npc
        self._model_name = model_name
        # Generation state
        self.current_text = ""
        self.is_generating = False
        self.generation_thread = None
        self.should_stop = False
        # Model generation state
        self.first_chunk = True  # For VAE stream_decode
        self._model_first_chunk = True  # For model stream_generate_step
        self.history_length = 30
        print("ModelManager initialized successfully")
    def _patch_attention_sdpa(self, model_name):
        """Patch flash_attention() to include SDPA fallback for GPUs without flash-attn (e.g., T4).

        Rewrites the cached remote-code file ``ldf_models/tools/attention.py``
        in place, inserting a scaled_dot_product_attention fallback right after
        the CUDA assert. Idempotent: files already containing "SDPA fallback"
        are skipped. NOTE(review): the indentation inside ``target`` and
        ``replacement`` must match the upstream attention.py source exactly, or
        the substring replace silently does nothing.
        """
        hf_cache = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
        # Both cache layouts: hub snapshots and trust_remote_code module copies.
        patterns = [
            os.path.join(
                hf_cache, "hub", "models--" + model_name.replace("/", "--"),
                "snapshots", "*", "ldf_models", "tools", "attention.py",
            ),
            os.path.join(
                hf_cache, "modules", "transformers_modules", model_name,
                "*", "ldf_models", "tools", "attention.py",
            ),
        ]
        # Use the assert + next line as target to ensure idempotent patching
        target = (
            '    assert q.device.type == "cuda" and q.size(-1) <= 256\n'
            "\n"
            "    # params\n"
        )
        replacement = (
            '    assert q.device.type == "cuda" and q.size(-1) <= 256\n'
            "\n"
            "    # SDPA fallback when flash-attn is not available (e.g., T4 GPU)\n"
            "    if not FLASH_ATTN_2_AVAILABLE and not FLASH_ATTN_3_AVAILABLE:\n"
            "        out_dtype = q.dtype\n"
            "        b, lq, nq, c = q.shape\n"
            "        lk = k.size(1)\n"
            "        q = q.transpose(1, 2).to(dtype)\n"
            "        k = k.transpose(1, 2).to(dtype)\n"
            "        v = v.transpose(1, 2).to(dtype)\n"
            "        attn_mask = None\n"
            "        is_causal_flag = causal\n"
            "        if k_lens is not None:\n"
            "            k_lens = k_lens.to(q.device)\n"
            "            valid = torch.arange(lk, device=q.device).unsqueeze(0) < k_lens.unsqueeze(1)\n"
            "            attn_mask = torch.where(valid[:, None, None, :], 0.0, float('-inf')).to(dtype=dtype)\n"
            "            is_causal_flag = False\n"
            "            if causal:\n"
            "                cm = torch.triu(torch.ones(lq, lk, device=q.device, dtype=torch.bool), diagonal=1)\n"
            "                attn_mask = attn_mask.masked_fill(cm[None, None, :, :], float('-inf'))\n"
            "        out = torch.nn.functional.scaled_dot_product_attention(\n"
            "            q, k, v, attn_mask=attn_mask, is_causal=is_causal_flag, dropout_p=dropout_p\n"
            "        )\n"
            "        return out.transpose(1, 2).contiguous().to(out_dtype)\n"
            "\n"
            "    # params\n"
        )
        for pattern in patterns:
            for filepath in glob.glob(pattern):
                with open(filepath, "r") as f:
                    content = f.read()
                if "SDPA fallback" in content:
                    print(f"Already patched: {filepath}")
                    continue
                if target in content:
                    content = content.replace(target, replacement, 1)
                    with open(filepath, "w") as f:
                        f.write(content)
                    print(f"Patched with SDPA fallback: {filepath}")
    def _load_models(self, model_name):
        """Load VAE and diffusion models from HF Hub.

        Downloads the snapshot, applies the SDPA patch before the remote code
        is imported, then loads and warms up the model on self.device.

        Returns:
            (vae, model): the streaming VAE and the diffusion model, in eval mode.
        """
        torch.set_float32_matmul_precision("high")
        # Pre-download model files to hub cache
        print(f"Downloading model from HF Hub: {model_name}")
        from huggingface_hub import snapshot_download
        snapshot_download(model_name)
        # Patch flash_attention with SDPA fallback for T4 (no flash-attn);
        # must happen before from_pretrained imports the remote code.
        self._patch_attention_sdpa(model_name)
        print("Loading model...")
        hf_model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        hf_model.to(self.device)
        # Trigger lazy loading / warmup with a tiny one-frame generation.
        print("Warming up model...")
        _ = hf_model("test", length=1)
        # Access underlying streaming components of the HF wrapper.
        model = hf_model.ldf_model
        vae = hf_model.vae
        model.eval()
        vae.eval()
        print("Models loaded successfully")
        return vae, model
    def start_generation(self, text, history_length=None):
        """Start generation with *text*, or just update the prompt if already running.

        Args:
            text: motion prompt for the diffusion model.
            history_length: optional history-frame count for model init;
                keeps the previous value when None.
        """
        self.current_text = text
        if history_length is not None:
            self.history_length = history_length
        if not self.is_generating:
            # Reset state before starting (only once at the beginning)
            self.frame_buffer.clear()
            self.stream_recovery.reset()
            self.vae.clear_cache()
            self.first_chunk = True
            self._model_first_chunk = True
            # Restore model params from base config (undo any runtime tweaks)
            self.model.chunk_size = self._base_schedule_config["chunk_size"]
            self.model.noise_steps = self._base_schedule_config["steps"]
            self.model.cfg_scale = self._base_cfg_config["cfg_scale"]
            self.model.init_generated(self.history_length, batch_size=1)
            print(
                f"Model initialized with history length: {self.history_length}"
            )
            # Start generation thread (daemon so it dies with the process)
            self.should_stop = False
            self.generation_thread = threading.Thread(target=self._generation_loop)
            self.generation_thread.daemon = True
            self.generation_thread.start()
            self.is_generating = True
            # Start brain (LLM cognitive loop)
            self.brain.start()
| def update_text(self, text): | |
| """Update text β apply new text to model immediately""" | |
| if text != self.current_text: | |
| old_text = self.current_text | |
| self.current_text = text | |
| # reset model text only (leave VAE untouched) | |
| self._model_first_chunk = True | |
| print(f"Text updated: '{old_text[:40]}' -> '{text[:40]}' (re-encoding)") | |
| def set_heading(self, heading_rad): | |
| """Set heading override for world model mode. | |
| Args: | |
| heading_rad: float or None. Heading in radians. | |
| None = AI controls direction (original behavior). | |
| float = user controls direction. | |
| """ | |
| self.heading_override = heading_rad | |
| def set_scene_context(self, ctx): | |
| """Set scene context from client environment scan. | |
| Args: | |
| ctx: dict with keys like: | |
| wall_front: distance in meters (or None) | |
| wall_left: distance in meters (or None) | |
| wall_right: distance in meters (or None) | |
| ground_slope: 'flat', 'up', 'down' | |
| on_stairs: bool | |
| npc_nearby: distance in meters (or None) | |
| """ | |
| self.scene_context = ctx | |
    def _build_perception_prompt(self):
        """Build motion prompt: Brain (LLM) first, rule-based fallback otherwise.

        1. Enrich scene context with server-side NPC detection.
        2. Feed sensory data to the brain every frame (non-blocking).
        3. If the brain has a decision, translate it (tank mode) and use it.
        4. Otherwise, fall back to a rule-based prompt built from the context.

        Side effects: may flip self._model_first_chunk when a new brain
        decision arrives, and always sets self._prompt_source.
        """
        base = self.current_text
        # Work on a copy so enrichment never mutates the client-supplied dict.
        ctx = dict(self.scene_context or {})
        # -- Server-side NPC detection (removes client dependency) --
        if self.npc and self.npc.is_generating:
            npc_state = self.npc.get_state()
            dist = npc_state.get('distance_to_player', 99)
            if dist < 15.0:  # NPC only relevant within 15 m
                ctx['npc_nearby'] = round(dist, 1)
                ctx['npc_type'] = npc_state.get('type', 'unknown')
                bhv = npc_state.get('behavior', 'present')
                npc_type = ctx['npc_type']
                # Natural-language description per (type, behavior) pair.
                type_desc = {
                    'man': {'approach':'a man walking toward you', 'charge':'a man charging aggressively', 'wander':'a man nearby', 'stop':'a man standing nearby', 'attack':'a man attacking you'},
                    'woman': {'approach':'a woman walking toward you', 'charge':'a woman charging', 'wander':'a woman nearby', 'stop':'a woman standing nearby', 'attack':'a woman attacking'},
                    'beast': {'approach':'a wild beast prowling toward you', 'charge':'a beast charging aggressively', 'wander':'a beast nearby', 'stop':'a beast crouching nearby', 'attack':'a beast lunging and clawing at you'},
                    'enemy_tank': {'approach':'an enemy tank rolling toward you', 'charge':'an enemy tank charging at full speed', 'wander':'an enemy tank patrolling', 'stop':'an enemy tank aiming at you', 'attack':'an enemy tank firing its cannon at you'},
                }
                td = type_desc.get(npc_type, type_desc['man'])
                ctx['npc_behavior'] = td.get(bhv, f'{npc_type} nearby')
                # compute NPC direction relative to player
                npc_pos = npc_state.get('position', {})
                px = ctx.get('player_x', 0)
                pz = ctx.get('player_z', 0)
                nx = npc_pos.get('x', 0) - px
                nz = npc_pos.get('z', 0) - pz
                npc_angle = math.atan2(nx, -nz)  # radians
                heading = self.heading_override or 0
                rel_angle = npc_angle - heading
                # normalize to (-pi, pi]
                while rel_angle > math.pi: rel_angle -= 2*math.pi
                while rel_angle < -math.pi: rel_angle += 2*math.pi
                # direction name (quadrants around the player's heading)
                if abs(rel_angle) < math.pi/4:
                    ctx['npc_direction'] = 'ahead'
                elif rel_angle > 0 and rel_angle < 3*math.pi/4:
                    ctx['npc_direction'] = 'to your right'
                elif rel_angle < 0 and rel_angle > -3*math.pi/4:
                    ctx['npc_direction'] = 'to your left'
                else:
                    ctx['npc_direction'] = 'behind you'
        # Feed sensory data to brain (non-blocking)
        if self.brain.enabled:
            self.brain.set_sensory_data(ctx, base, self.heading_override)
        # Check if brain has a decision
        decision = self.brain.get_decision()
        if decision:
            # tank mode: translate human motion verbs into tank motion verbs
            if ctx and ctx.get('avatar_type') == 'tank':
                decision = decision.replace('a person ', 'a tank ').replace('A person ', 'A tank ')
                for h, t in [
                    ('walking', 'rolling forward'), ('running', 'advancing rapidly'),
                    ('sprinting', 'charging at full speed'), ('turning', 'rotating'),
                    ('fleeing', 'reversing away'), ('stumbling', 'grinding to a halt'),
                    ('spinning', 'rotating turret'), ('swinging sword', 'firing cannon'),
                    ('blocking with shield', 'bracing armor'),
                    ('thrusting torch', 'sweeping searchlight'),
                    ('firing rpg', 'firing main gun'),
                    ('standing still', 'idling engine'),
                ]:
                    decision = decision.replace(h, t)
            # brain decision changed — force new text into model
            if decision != self.brain._last_applied_decision:
                self.brain._last_applied_decision = decision
                self._model_first_chunk = True  # model re-encodes new text!
                # do NOT reset VAE first_chunk — keep decoding continuity
                print(f"[Brain->Body] New motion applied: {decision[:60]}")
            self._prompt_source = "π§ "
            return decision
        # -- FALLBACK: Rule-based (same as before) --
        self._prompt_source = "π"
        if not ctx:
            return base
        parts = []
        # Wall/obstacle awareness (nearest-first thresholds)
        wall_front = ctx.get('wall_front')
        if wall_front is not None:
            if wall_front < 0.8:
                parts.append('stopping in front of a wall')
            elif wall_front < 2.0:
                parts.append('slowing down approaching a wall')
            elif wall_front < 4.0:
                parts.append('a wall ahead in the distance')
        # Stairs / slope
        if ctx.get('on_stairs'):
            parts.append('walking up stairs carefully')
        elif ctx.get('ground_slope') == 'down':
            parts.append('walking downhill')
        elif ctx.get('ground_slope') == 'up':
            parts.append('walking uphill')
        # NPC interaction
        npc_dist = ctx.get('npc_nearby')
        if npc_dist is not None:
            if npc_dist < 1.5:
                parts.append('another person very close')
            elif npc_dist < 4.0:
                parts.append('another person nearby')
        # Open space — only annotate walking prompts
        if not parts:
            if 'walk' in base:
                parts.append('on open ground')
        if parts:
            return base + ', ' + ', '.join(parts)
        return base
| def pause_generation(self): | |
| """Pause generation (keeps all state)""" | |
| self.should_stop = True | |
| if self.generation_thread: | |
| self.generation_thread.join(timeout=2.0) | |
| self.is_generating = False | |
| print("Generation paused (state preserved)") | |
| def resume_generation(self): | |
| """Resume generation from paused state""" | |
| if self.is_generating: | |
| print("Already generating, ignoring resume") | |
| return | |
| # Restart generation thread with existing state | |
| self.should_stop = False | |
| self.generation_thread = threading.Thread(target=self._generation_loop) | |
| self.generation_thread.daemon = True | |
| self.generation_thread.start() | |
| self.is_generating = True | |
| print("Generation resumed") | |
    def reset(self, history_length=None, smoothing_alpha=None):
        """Reset generation state completely.

        Stops the worker thread (if running), clears all streaming caches,
        restores model params from the base config and re-initializes the
        model's generation buffer. Safe to call between sessions.

        Args:
            history_length: History window length for the model
            smoothing_alpha: EMA smoothing factor (0.0 to 1.0)
                - 1.0 = no smoothing (default)
                - 0.0 = infinite smoothing
                - Recommended: 0.3-0.7 for visible smoothing
        """
        # Stop if running (must happen before clearing caches the worker uses)
        if self.is_generating:
            self.pause_generation()
        # Clear everything
        self.frame_buffer.clear()
        self.vae.clear_cache()
        self.first_chunk = True
        if history_length is not None:
            self.history_length = history_length
        # Update smoothing alpha if provided and recreate stream recovery
        if smoothing_alpha is not None:
            # Clip guards against out-of-range values arriving from the HTTP API.
            self.smoothing_alpha = np.clip(smoothing_alpha, 0.0, 1.0)
            print(f"Smoothing alpha updated to: {self.smoothing_alpha}")
        # Recreate stream recovery with new smoothing alpha (drops cumulative
        # rotation/position state from the previous session)
        self.stream_recovery = StreamJointRecovery263(
            joints_num=22, smoothing_alpha=self.smoothing_alpha
        )
        # Reset heading override
        self.heading_override = None
        # Reset scene context
        self.scene_context = None
        # Stop brain
        self.brain.stop()
        # Restore model params from base config (routes may have tweaked them)
        self.model.chunk_size = self._base_schedule_config["chunk_size"]
        self.model.noise_steps = self._base_schedule_config["steps"]
        self.model.cfg_scale = self._base_cfg_config["cfg_scale"]
        self._model_first_chunk = True
        # Initialize model
        self.model.init_generated(self.history_length, batch_size=1)
        print(
            f"Model reset - history: {self.history_length}, smoothing: {self.smoothing_alpha}"
        )
    def _generation_loop(self):
        """Main generation loop that runs in background thread.

        Each iteration: build a perception-aware prompt, run one model step
        (one latent token), decode it with the streaming VAE into motion
        frames, convert each frame to joint positions, and push them to both
        the playback buffer and the spectator broadcast buffer.
        """
        print("Generation loop started")
        step_count = 0
        total_gen_time = 0
        with torch.no_grad():
            while not self.should_stop:
                # Check if buffer needs more frames
                if self.frame_buffer.needs_generation():
                    try:
                        step_start = time.time()
                        # Generate one token (produces frames from VAE)
                        prompt = self._build_perception_prompt()
                        x = {"text": [prompt]}
                        # Generate from model (1 token)
                        output = self.model.stream_generate_step(
                            x, first_chunk=self._model_first_chunk
                        )
                        self._model_first_chunk = False
                        generated = output["generated"]
                        # Skip if no frames committed yet
                        if generated[0].shape[0] == 0:
                            continue
                        # Decode with VAE (1 token -> 4 frames)
                        decoded = self.vae.stream_decode(
                            generated[0][None, :], first_chunk=self.first_chunk
                        )[0]
                        self.first_chunk = False
                        # Convert each frame to joints
                        for i in range(decoded.shape[0]):
                            frame_data = decoded[i].float().cpu().numpy()  # BFloat16->Float32 safe cast
                            joints = self.stream_recovery.process_frame(
                                frame_data, heading_override=self.heading_override
                            )
                            self.frame_buffer.add_frame(joints)
                            # Also add to broadcast buffer for spectators
                            # (ids are monotonic so spectators can poll "after_id")
                            with self.broadcast_lock:
                                self.broadcast_id += 1
                                self.broadcast_frames.append(
                                    (self.broadcast_id, joints)
                                )
                        step_time = time.time() - step_start
                        total_gen_time += step_time
                        step_count += 1
                        # Print performance stats every 10 steps
                        if step_count % 10 == 0:
                            avg_time = total_gen_time / step_count
                            fps = decoded.shape[0] / avg_time
                            print(
                                f"[Generation] Step {step_count}: {step_time * 1000:.1f}ms, "
                                f"Avg: {avg_time * 1000:.1f}ms, "
                                f"FPS: {fps:.1f}, "
                                f"Buffer: {self.frame_buffer.size()}, "
                                f"{getattr(self, '_prompt_source', '?')} Prompt: {prompt[:80]}"
                            )
                    except Exception as e:
                        print(f"Error in generation: {e}")
                        traceback.print_exc()
                        # Back off briefly so a persistent error cannot spin the CPU.
                        time.sleep(0.1)
                else:
                    # Buffer is full, wait a bit
                    time.sleep(0.01)
        print("Generation loop stopped")
| def get_next_frame(self): | |
| """Get the next frame from buffer""" | |
| return self.frame_buffer.get_frame() | |
| def get_broadcast_frames(self, after_id, count=8): | |
| """Get frames from broadcast buffer after the given ID (for spectators).""" | |
| with self.broadcast_lock: | |
| frames = [ | |
| (fid, joints) | |
| for fid, joints in self.broadcast_frames | |
| if fid > after_id | |
| ] | |
| return frames[:count] | |
| # ββ NPC management ββ | |
| def spawn_npc(self, npc_type='man'): | |
| """Spawn and start NPC.""" | |
| if not self._npc_lock.acquire(blocking=False): | |
| print("[NPC] Already spawning β ignored (Lock)") | |
| return | |
| try: | |
| if self.npc: | |
| self.npc.stop() | |
| self.npc = NPCStream(self._model_name, npc_type) | |
| self.npc.start() | |
| except Exception as e: | |
| print(f"[NPC] Spawn error: {e}") | |
| traceback.print_exc() | |
| raise | |
| finally: | |
| self._npc_lock.release() | |
| def despawn_npc(self): | |
| """Remove NPC.""" | |
| # Wait for lock β if spawn in progress, wait until complete | |
| with self._npc_lock: | |
| if self.npc: | |
| self.npc.stop() | |
| self.npc = None | |
| print("[NPC] Removed") | |
| def get_buffer_status(self): | |
| """Get buffer status""" | |
| npc_state = self.npc.get_state() if self.npc else None | |
| return { | |
| "buffer_size": self.frame_buffer.size(), | |
| "target_size": self.frame_buffer.target_size, | |
| "is_generating": self.is_generating, | |
| "current_text": self.current_text, | |
| "smoothing_alpha": self.smoothing_alpha, | |
| "history_length": self.history_length, | |
| "brain_enabled": self.brain.enabled, | |
| "brain_decision": self.brain.get_decision() if self.brain.enabled else None, | |
| "brain_prediction": self.brain.get_prediction() if self.brain.enabled else None, | |
| "npc": npc_state, | |
| "schedule_config": { | |
| "chunk_size": self.model.chunk_size, | |
| "steps": self.model.noise_steps, | |
| }, | |
| "cfg_config": { | |
| "cfg_scale": self.model.cfg_scale, | |
| }, | |
| } | |
# ───────────────────────────────────────────────────
# NPC STREAM — separate FloodDiffusion stream
# own model instance + position movement AI + frame generation
# ───────────────────────────────────────────────────
# Per-type presets consumed by NPCStream: display name, movement speeds
# (units presumably meters/second — TODO confirm against the client), and
# the text prompt used for each behavior state (walk/run/idle/charge/attack).
NPC_TYPES = {
    'man': {'name': 'π§ Male', 'speed': 1.2, 'charge_speed': 3.0,
            'walk': 'a man walking forward steadily',
            'run': 'a man running fast toward someone',
            'idle': 'a man standing still looking around',
            'charge': 'a man running aggressively toward someone',
            'attack': 'a man throwing punches aggressively'},
    'woman': {'name': 'π© Female', 'speed': 1.2, 'charge_speed': 2.5,
              'walk': 'a woman walking forward calmly',
              'run': 'a woman running quickly',
              'idle': 'a woman standing still',
              'charge': 'a woman running toward someone urgently',
              'attack': 'a woman attacking with desperate fury'},
    'beast': {'name': 'πΊ Beast', 'speed': 2.0, 'charge_speed': 5.0,
              'walk': 'a person prowling on all fours like a beast',
              'run': 'a person running on all fours like a wild animal',
              'idle': 'a person crouching low like a wild beast',
              'charge': 'a person charging aggressively on all fours like a beast',
              'attack': 'a person lunging and clawing savagely like a wild beast'},
    'enemy_tank': {'name': 'πͺ Enemy Tank', 'speed': 1.5, 'charge_speed': 3.5,
                   'walk': 'a tank rolling forward on patrol',
                   'run': 'a tank advancing rapidly toward target',
                   'idle': 'a tank idling with engine rumbling',
                   'charge': 'a tank charging at full speed toward enemy',
                   'attack': 'a tank firing cannon and advancing aggressively'},
}
class NPCStream:
    """NPC with independent FloodDiffusion stream + movement AI.

    Owns a second model/VAE instance (lazy-loaded on first start) and runs
    two daemon threads: a movement loop that steers position/heading toward
    the player every 100ms, and a generation loop that streams motion frames
    into ``frame_buffer``. ``stop()`` joins both threads and releases GPU
    memory so the NPC can be respawned cleanly.
    """
    def __init__(self, model_name, npc_type='man'):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_name = model_name
        self.npc_type = npc_type
        # Unknown npc_type falls back to the 'man' preset.
        self.type_info = NPC_TYPES.get(npc_type, NPC_TYPES['man'])
        # NPC position/movement
        self.position = {'x': 8.0, 'z': 0.0}  # spawn position
        self.heading = 0.0  # radians
        self.behavior = 'stop'  # stop, approach, wander, charge
        self.target_pos = {'x': 0.0, 'z': 0.0}  # player position
        # model (lazy load — loaded on spawn)
        self.model = None
        self.vae = None
        self._loaded = False
        # generation state
        self.frame_buffer = FrameBuffer(target_buffer_size=8)
        self.stream_recovery = StreamJointRecovery263(
            joints_num=22, smoothing_alpha=0.5
        )
        self.current_text = self.type_info['idle']
        self.is_generating = False
        self._generation_thread = None
        self._movement_thread = None
        self._should_stop = False
        self._first_chunk = True  # VAE streaming state
        self._model_first_chunk = True  # diffusion-model streaming state
        self.history_length = 30
        print(f"[NPC] Created: {self.type_info['name']} at ({self.position['x']}, {self.position['z']})")
    def load_model(self):
        """Load separate model instance (from HF cache β fast)."""
        if self._loaded:
            return
        print(f"[NPC] Loading model: {self.model_name}")
        hf_model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True)
        hf_model.to(self.device)
        # NOTE: warmup removed — hf_model('test') internally calls init_generated(1)
        # which conflicts with start()'s init_generated(30) → empty exception
        # Main model is fine due to timing gap; NPC is synchronous so it conflicts
        # CUDA kernel compilation already done by main model — NPC skips it
        self.model = hf_model.ldf_model
        self.vae = hf_model.vae
        self.model.eval()
        self.vae.eval()
        self._loaded = True
        print(f"[NPC] Model loaded")
    def start(self):
        """Start generation and movement threads (loads the model if needed)."""
        if not self._loaded:
            self.load_model()
        self._should_stop = False
        # initialize generation
        self.frame_buffer.clear()
        self.stream_recovery.reset()
        try:
            self.vae.clear_cache()
        except Exception as e:
            # best-effort: tolerate VAE builds without a stream cache
            print(f"[NPC] vae.clear_cache warning: {e}")
        self._first_chunk = True
        self._model_first_chunk = True
        # chunk_size: use model default (never hardcode!)
        # FloodDiffusion requires: num_denoise_steps % chunk_size == 0
        # default chunk_size is set at model load time — use as-is
        denoise = getattr(self.model, 'num_denoise_steps', None) or getattr(self.model, 'noise_steps', 10)
        base_cs = self.model.chunk_size
        if denoise % base_cs != 0:
            # find compatible chunk_size (try 2 then 1; 1 always divides)
            for cs in [2, 1]:
                if denoise % cs == 0:
                    self.model.chunk_size = cs
                    break
            print(f"[NPC] chunk_size adjusted: {base_cs}β{self.model.chunk_size} (denoise_steps={denoise})")
        try:
            self.model.init_generated(self.history_length, batch_size=1)
        except Exception as e:
            print(f"[NPC] init_generated error: {e}")
            traceback.print_exc()
            print("[NPC] init_generated failed β cannot start thread")
            return
        # movement thread
        self._movement_thread = threading.Thread(target=self._movement_loop, daemon=True)
        self._movement_thread.start()
        # generation thread
        self._generation_thread = threading.Thread(target=self._generation_loop, daemon=True)
        self._generation_thread.start()
        self.is_generating = True
        print(f"[NPC] Started: {self.behavior}")
    def stop(self):
        """Stop and release GPU memory."""
        self._should_stop = True
        if self._generation_thread:
            self._generation_thread.join(timeout=3)
        if self._movement_thread:
            self._movement_thread.join(timeout=2)
        self.is_generating = False
        self.frame_buffer.clear()
        # release GPU memory
        import torch, gc
        if self.model is not None:
            del self.model
            self.model = None
        if self.vae is not None:
            del self.vae
            self.vae = None
        self._loaded = False
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("[NPC] Stopped + GPU memory released")
    def set_behavior(self, behavior):
        """Change behavior: stop, approach, wander, charge, attack."""
        self.behavior = behavior
        ti = self.type_info
        # Map each behavior to the matching text prompt for this NPC type.
        prompts = {
            'stop': ti['idle'],
            'approach': ti['walk'],
            'wander': ti['walk'],
            'charge': ti['charge'],
            'attack': ti.get('attack', ti['charge']),
        }
        self.current_text = prompts.get(behavior, ti['idle'])
        print(f"[NPC] Behavior: {behavior} -> {self.current_text}")
    def set_target(self, x, z):
        """Update player position."""
        self.target_pos = {'x': x, 'z': z}
    def get_state(self):
        """Return current NPC state as a JSON-serializable dict."""
        dx = self.target_pos['x'] - self.position['x']
        dz = self.target_pos['z'] - self.position['z']
        dist = math.sqrt(dx*dx + dz*dz)
        return {
            'type': self.npc_type,
            'type_name': self.type_info['name'],
            'position': self.position,
            'heading': self.heading,
            'behavior': self.behavior,
            'distance_to_player': round(dist, 1),
            'is_generating': self.is_generating,
            'buffer_size': self.frame_buffer.size(),
        }
    def _movement_loop(self):
        """NPC movement AI β update position every 100ms."""
        while not self._should_stop:
            time.sleep(0.1)
            if self.behavior == 'stop':
                continue
            dx = self.target_pos['x'] - self.position['x']
            dz = self.target_pos['z'] - self.position['z']
            dist = math.sqrt(dx*dx + dz*dz)
            if dist < 0.01:
                continue
            # compute direction toward player (same atan2(x, -z) convention
            # used for the player-relative NPC direction above)
            self.heading = math.atan2(dx, -dz)
            # determine speed and standoff distance per behavior
            ti = self.type_info
            if self.behavior == 'charge':
                speed = ti['charge_speed']
                min_dist = 1.0  # close to 1m
            elif self.behavior == 'attack':
                speed = ti['charge_speed'] * 0.8  # advance while attacking
                min_dist = 2.5  # maintain attack range
            elif self.behavior == 'approach':
                speed = ti['speed']
                min_dist = 2.0  # close to 2m
            elif self.behavior == 'wander':
                speed = ti['speed'] * 0.5
                min_dist = 3.0  # maintain 3m distance
            else:
                continue
            if dist <= min_dist:
                continue  # minimum distance reached
            # move (clamped so we never overshoot past the standoff distance)
            move = min(speed * 0.1, dist - min_dist)  # 0.1s * speed
            nx = dx / dist
            nz = dz / dist
            self.position['x'] += nx * move
            self.position['z'] += nz * move
    def _generation_loop(self):
        """NPC motion generation loop (mirrors the main generation loop)."""
        print("[NPC] Generation loop started")
        step = 0
        with torch.no_grad():
            while not self._should_stop:
                if self.frame_buffer.needs_generation():
                    try:
                        x = {"text": [self.current_text]}
                        output = self.model.stream_generate_step(
                            x, first_chunk=self._model_first_chunk
                        )
                        self._model_first_chunk = False
                        generated = output["generated"]
                        # nothing committed yet — try again
                        if generated[0].shape[0] == 0:
                            continue
                        decoded = self.vae.stream_decode(
                            generated[0][None, :], first_chunk=self._first_chunk
                        )[0]
                        self._first_chunk = False
                        for i in range(decoded.shape[0]):
                            frame_data = decoded[i].float().cpu().numpy()  # BFloat16βFloat32
                            joints = self.stream_recovery.process_frame(
                                frame_data, heading_override=self.heading
                            )
                            self.frame_buffer.add_frame(joints)
                        step += 1
                        if step % 50 == 0:
                            print(f"[NPC] Step {step}: {self.current_text[:50]}")
                    except Exception as e:
                        print(f"[NPC] Generation error: {e}")
                        # back off so a persistent error can't spin the CPU
                        time.sleep(0.1)
                else:
                    time.sleep(0.01)
# Global model manager instance
_model_manager = None
def get_model_manager(model_name=None):
    """Get or create the global model manager instance.

    Note: model_name is only honored on the first call; later calls return
    the existing singleton unchanged. NOTE(review): creation is not guarded
    by a lock — assumes the first call happens before concurrent access.
    """
    global _model_manager
    if _model_manager is None:
        _model_manager = ModelManager(model_name)
    return _model_manager
| # ββββββββββββββββββββββββββββββββ | |
| # Flask Server | |
| # ββββββββββββββββββββββββββββββββ | |
| """ | |
| Flask server for real-time 3D motion generation demo (HF Space version) | |
| """ | |
| import sys | |
| import argparse | |
| from flask import Flask, jsonify, render_template, request | |
| from flask_cors import CORS | |
| def _coerce_value(value, reference): | |
| """Coerce a value to match the type of a reference value""" | |
| if isinstance(reference, bool): | |
| return value if isinstance(value, bool) else str(value).lower() in ("true", "1") | |
| elif isinstance(reference, int): | |
| return int(value) | |
| elif isinstance(reference, float): | |
| return float(value) | |
| return str(value) | |
# Serve templates and static assets straight from the Space's root directory.
app = Flask(__name__, template_folder='.', static_folder='.', static_url_path='')
CORS(app)
# Global model manager (loaded eagerly on startup)
model_manager = None
model_name_global = None  # Will be set once at startup
# Session tracking - only one active session can generate at a time
active_session_id = None  # The session ID currently generating
session_lock = threading.Lock()  # guards active_session_id
# Frame consumption monitoring - detect if client disconnected by tracking frame consumption
last_frame_consumed_time = None  # time.time() of last consumed frame; None = no client
consumption_timeout = (
    5.0  # If no frame consumed for 5 seconds, assume client disconnected
)
consumption_monitor_thread = None  # daemon watchdog (see consumption_monitor)
consumption_monitor_lock = threading.Lock()  # guards last_frame_consumed_time
def init_model():
    """Initialize model manager"""
    global model_manager
    # Guard clauses: return the existing instance, or fail fast if the
    # startup code never set the model name.
    if model_manager is not None:
        return model_manager
    if model_name_global is None:
        raise RuntimeError(
            "model_name_global not set. Server not properly initialized."
        )
    print(f"Initializing model manager with model: {model_name_global}")
    model_manager = get_model_manager(model_name=model_name_global)
    print("Model manager ready!")
    return model_manager
def consumption_monitor():
    """Monitor frame consumption and auto-reset if client stops consuming.

    Runs forever in a daemon thread. Lock discipline: each lock is acquired
    and released on its own (never nested) so this watchdog cannot deadlock
    against the request handlers, and the expensive reset happens with no
    lock held at all.
    """
    global last_frame_consumed_time, active_session_id, model_manager
    while True:
        time.sleep(2.0)  # Check every 2 seconds
        # Read state with proper locking - no nested locks!
        should_reset = False
        current_session = None
        time_since_last_consumption = 0
        # First, check consumption time
        with consumption_monitor_lock:
            if last_frame_consumed_time is not None:
                time_since_last_consumption = time.time() - last_frame_consumed_time
                if time_since_last_consumption > consumption_timeout:
                    # Need to check if still generating before reset
                    if model_manager and model_manager.is_generating:
                        should_reset = True
        # Then, get current session (separate lock)
        if should_reset:
            with session_lock:
                current_session = active_session_id
        # Perform reset outside of locks to avoid deadlock
        if should_reset and current_session is not None:
            print(
                f"No frame consumed for {time_since_last_consumption:.1f}s - client disconnected, auto-resetting..."
            )
            if model_manager:
                model_manager.reset()
            print(
                "Generation reset due to client disconnect (no frame consumption)"
            )
            # Clear state with proper locking - no nested locks!
            with session_lock:
                # Only clear if no other session took over during the reset.
                if active_session_id == current_session:
                    active_session_id = None
            with consumption_monitor_lock:
                last_frame_consumed_time = None
def start_consumption_monitor():
    """Start the consumption monitoring thread if not already running"""
    global consumption_monitor_thread
    existing = consumption_monitor_thread
    # Skip when a live watchdog thread already exists.
    if existing is not None and existing.is_alive():
        return
    watchdog = threading.Thread(target=consumption_monitor, daemon=True)
    consumption_monitor_thread = watchdog
    watchdog.start()
    print("Consumption monitor started")
def index():
    """Serve the demo's main HTML page."""
    return render_template("index.html")
def get_config():
    """Get current config"""
    try:
        if not model_manager:
            # Model not loaded yet - return defaults
            return jsonify(
                {
                    "schedule_config": {},
                    "cfg_config": {},
                    "history_length": 30,
                    "smoothing_alpha": 0.5,
                }
            )
        status = model_manager.get_buffer_status()
        payload = {
            "schedule_config": status["schedule_config"],
            "cfg_config": status["cfg_config"],
            "history_length": status["history_length"],
            # np scalar -> plain float so it serializes cleanly
            "smoothing_alpha": float(status["smoothing_alpha"]),
        }
        return jsonify(payload)
    except Exception as e:
        traceback.print_exc()
        return jsonify({"status": "error", "message": str(e)}), 500
def update_config():
    """Update model config in memory"""
    try:
        global active_session_id, last_frame_consumed_time
        if not model_manager or not model_manager.model:
            return jsonify({"status": "error", "message": "Model not loaded yet"}), 400
        data = request.json
        history_length = data.get("history_length")
        smoothing_alpha = data.get("smoothing_alpha")
        # Validate keys then apply values, section by section. The base
        # dicts double as the set of valid keys and the type references
        # for coercion.
        sections = (
            ("schedule_config", model_manager._base_schedule_config),
            ("cfg_config", model_manager._base_cfg_config),
        )
        for section_name, base in sections:
            incoming = data.get(section_name)
            if not incoming:
                continue
            for key in incoming:
                if key not in base:
                    return jsonify(
                        {
                            "status": "error",
                            "message": f"Unknown {section_name} key: {key}",
                        }
                    ), 400
            for key, value in incoming.items():
                base[key] = _coerce_value(value, base[key])
        # Reset with new parameters
        model_manager.reset(
            history_length=history_length,
            smoothing_alpha=smoothing_alpha,
        )
        # Clear active session
        with session_lock:
            active_session_id = None
        with consumption_monitor_lock:
            last_frame_consumed_time = None
        return jsonify({"status": "success"})
    except Exception as e:
        traceback.print_exc()
        return jsonify({"status": "error", "message": str(e)}), 500
def start_generation():
    """Start generation with given text.

    Expects JSON: session_id (required), text, history_length,
    smoothing_alpha, force. Only one session may generate at a time; a
    different session gets 409 unless it sends force=true to take over.
    """
    try:
        global active_session_id, last_frame_consumed_time
        data = request.json
        session_id = data.get("session_id")
        text = data.get("text", "walk in a circle.")
        history_length = data.get("history_length")
        smoothing_alpha = data.get(
            "smoothing_alpha", None
        )  # Optional smoothing parameter
        force = data.get("force", False)  # Allow force takeover
        if not session_id:
            return jsonify(
                {"status": "error", "message": "session_id is required"}
            ), 400
        print(
            f"[Session {session_id}] Starting generation with text: {text}, history_length: {history_length}, force: {force}"
        )
        # Initialize model if needed
        mm = init_model()
        # Check if another session is already generating
        need_force_takeover = False
        with session_lock:
            if active_session_id and active_session_id != session_id:
                if not force:
                    # Another session is active, return conflict
                    return jsonify(
                        {
                            "status": "error",
                            "message": "Another session is already generating.",
                            "conflict": True,
                            "active_session_id": active_session_id,
                        }
                    ), 409
                else:
                    # Force takeover
                    print(
                        f"[Session {session_id}] Force takeover from session {active_session_id}"
                    )
                    need_force_takeover = True
            if mm.is_generating and active_session_id == session_id:
                # Duplicate start from the same session — reject.
                return jsonify(
                    {
                        "status": "error",
                        "message": "Generation is already running for this session.",
                    }
                ), 400
            # Set this session as active
            active_session_id = session_id
        # Clear previous session's consumption tracking if force takeover (no nested locks)
        if need_force_takeover:
            with consumption_monitor_lock:
                last_frame_consumed_time = None
        # Reset and start generation
        mm.reset(history_length=history_length, smoothing_alpha=smoothing_alpha)
        mm.start_generation(text, history_length=history_length)
        # Initialize consumption tracking (no nested locks)
        with consumption_monitor_lock:
            last_frame_consumed_time = time.time()
        # Start consumption monitoring
        start_consumption_monitor()
        print(f"[Session {session_id}] Consumption monitoring activated")
        return jsonify(
            {
                "status": "success",
                "message": f"Generation started with text: {text}, history_length: {history_length}",
                "session_id": session_id,
            }
        )
    except Exception as e:
        print(f"Error in start_generation: {e}")
        traceback.print_exc()
        return jsonify({"status": "error", "message": str(e)}), 500
def update_text():
    """Update the generation text"""
    try:
        payload = request.json
        session_id = payload.get("session_id")
        new_text = payload.get("text", "")
        if not session_id:
            return jsonify(
                {"status": "error", "message": "session_id is required"}
            ), 400
        # Only the session that started generation may change the prompt.
        with session_lock:
            owns_session = active_session_id == session_id
        if not owns_session:
            return jsonify(
                {"status": "error", "message": "Not the active session"}
            ), 403
        if model_manager is None:
            return jsonify({"status": "error", "message": "Model not initialized"}), 400
        model_manager.update_text(new_text)
        return jsonify({"status": "success", "message": f"Text updated to: {new_text}"})
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
def update_heading():
    """Update heading override for world model mode.
    When heading is set, the character's direction follows the user's input
    while AI controls movement speed and animation. Set to null to return
    to AI-controlled direction.
    """
    try:
        payload = request.json
        session_id = payload.get("session_id")
        heading = payload.get("heading")  # radians, or null for AI control
        if not session_id:
            return jsonify(
                {"status": "error", "message": "session_id is required"}
            ), 400
        # Only the owning session may steer the character.
        with session_lock:
            owns_session = active_session_id == session_id
        if not owns_session:
            return jsonify(
                {"status": "error", "message": "Not the active session"}
            ), 403
        if model_manager is None:
            return jsonify({"status": "error", "message": "Model not initialized"}), 400
        model_manager.set_heading(heading)
        return jsonify({"status": "success"})
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
def update_scene_context():
    """Update scene context for perception-aware motion generation.
    Client sends environment scan data (wall distances, ground type, NPCs).
    Server uses this to build enhanced text prompts for FloodDiffusion.
    """
    try:
        payload = request.json
        session_id = payload.get("session_id")
        ctx = payload.get("context")  # dict with wall_front, on_stairs, etc.
        if not session_id:
            return jsonify(
                {"status": "error", "message": "session_id is required"}
            ), 400
        # Only the owning session may feed the model sensory data.
        with session_lock:
            owns_session = active_session_id == session_id
        if not owns_session:
            return jsonify(
                {"status": "error", "message": "Not the active session"}
            ), 403
        if model_manager is None:
            return jsonify({"status": "error", "message": "Model not initialized"}), 400
        model_manager.set_scene_context(ctx)
        # Forward the player's position so the NPC can track its target.
        npc = model_manager.npc
        if npc and ctx:
            npc.set_target(ctx.get('player_x', 0), ctx.get('player_z', 0))
        return jsonify({"status": "success"})
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
# ── NPC API ──
def npc_spawn():
    """Spawn NPC."""
    try:
        payload = request.json or {}
        requested_type = payload.get("type", "man")
        if model_manager is None:
            return jsonify({"status": "error", "message": "Model not initialized"}), 400
        # A spawn already in flight holds _npc_lock — report busy instead of queuing.
        spawn_lock = getattr(model_manager, '_npc_lock', None)
        if spawn_lock is not None and spawn_lock.locked():
            return jsonify({"status": "busy", "message": "NPC loading"})
        model_manager.spawn_npc(requested_type)
        return jsonify({"status": "success", "type": requested_type})
    except Exception as e:
        print(f"[NPC spawn error] {type(e).__name__}: {e}")
        traceback.print_exc()
        return jsonify({"status": "error", "message": f"{type(e).__name__}: {e}"}), 500
def npc_command():
    """Change NPC behavior."""
    try:
        payload = request.json or {}
        behavior = payload.get("behavior", "stop")
        # Best-effort: silently a no-op when no NPC exists.
        npc = model_manager.npc if model_manager else None
        if npc:
            npc.set_behavior(behavior)
        return jsonify({"status": "success", "behavior": behavior})
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
def npc_despawn():
    """Tear down the active NPC, if any (no-op when the model is absent)."""
    try:
        if model_manager:
            model_manager.despawn_npc()
        return jsonify({"status": "success"})
    except Exception as exc:
        return jsonify({"status": "error", "message": str(exc)}), 500
def npc_frame():
    """Return up to `count` buffered NPC frames plus the NPC state dict.

    The buffer is polled `count` times; empty polls are skipped rather than
    terminating the loop, so fewer than `count` frames may be returned.
    """
    try:
        if not model_manager or not model_manager.npc:
            return jsonify({"frames": [], "npc": None})
        requested = request.args.get("count", 4, type=int)
        npc = model_manager.npc
        # Poll the buffer `requested` times, keeping only non-empty results.
        collected = [
            f.tolist()
            for f in (npc.frame_buffer.get_frame() for _ in range(requested))
            if f is not None
        ]
        return jsonify({"frames": collected, "npc": npc.get_state()})
    except Exception as exc:
        return jsonify({"status": "error", "message": str(exc)}), 500
def pause_generation():
    """Pause generation, keeping state so it can be resumed later.

    JSON body must contain "session_id"; only the active session may pause
    (403 otherwise). Returns 400 when session_id is missing.
    """
    try:
        # silent=True avoids raising on a missing/non-JSON body (which the
        # original request.json would turn into a 500); the check below
        # then yields the intended 400 "session_id is required".
        data = request.get_json(silent=True) or {}
        session_id = data.get("session_id")
        if not session_id:
            return jsonify(
                {"status": "error", "message": "session_id is required"}
            ), 400
        # Verify this is the active session
        with session_lock:
            if active_session_id != session_id:
                return jsonify(
                    {"status": "error", "message": "Not the active session"}
                ), 403
        if model_manager:
            model_manager.pause_generation()
        return jsonify({"status": "success", "message": "Generation paused"})
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
def resume_generation():
    """Resume generation from a paused state.

    JSON body must contain "session_id"; only the active session may resume
    (403 otherwise). Resets the frame-consumption timestamp so the stall
    monitor does not immediately fire.
    """
    # Declared at function top (conventional placement; was inside try).
    global last_frame_consumed_time
    try:
        # silent=True: a missing/non-JSON body becomes {} instead of raising,
        # letting the explicit check below return a proper 400.
        data = request.get_json(silent=True) or {}
        session_id = data.get("session_id")
        if not session_id:
            return jsonify(
                {"status": "error", "message": "session_id is required"}
            ), 400
        # Verify this is the active session
        with session_lock:
            if active_session_id != session_id:
                return jsonify(
                    {"status": "error", "message": "Not the active session"}
                ), 403
        if model_manager is None:
            return jsonify({"status": "error", "message": "Model not initialized"}), 400
        model_manager.resume_generation()
        # Reset consumption tracking when resuming
        with consumption_monitor_lock:
            last_frame_consumed_time = time.time()
        return jsonify({"status": "success", "message": "Generation resumed"})
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
def reset():
    """Reset generation state and release the active session slot.

    Optional JSON body: "session_id" (when given, must match the active
    session or a 403 is returned), plus "history_length" and
    "smoothing_alpha" forwarded to the model manager's reset.
    """
    global active_session_id, last_frame_consumed_time
    try:
        # silent=True: a bodyless POST is a valid full reset instead of a 500
        # (request.json raises on missing or non-JSON bodies).
        data = request.get_json(silent=True) or {}
        session_id = data.get("session_id")
        history_length = data.get("history_length")
        smoothing_alpha = data.get("smoothing_alpha")
        # If session_id provided, verify it's the active session
        if session_id:
            with session_lock:
                if active_session_id and active_session_id != session_id:
                    return jsonify(
                        {"status": "error", "message": "Not the active session"}
                    ), 403
        if model_manager:
            model_manager.reset(
                history_length=history_length, smoothing_alpha=smoothing_alpha
            )
        # Clear the active session (always, when no session_id was supplied)
        with session_lock:
            if active_session_id == session_id or not session_id:
                active_session_id = None
        # Clear consumption tracking so the stall monitor goes idle
        with consumption_monitor_lock:
            last_frame_consumed_time = None
        print(f"[Session {session_id}] Reset complete, session cleared")
        return jsonify(
            {
                "status": "success",
                "message": "Reset complete",
            }
        )
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
def get_frame():
    """Pop (active session) or peek (spectator) up to `count` motion frames.

    Query params: session_id (required), count (0-20, default 8), and for
    spectators after_id (last broadcast frame id already seen, default 0).
    The active session destructively pops frames from the generation buffer
    and refreshes the consumption timestamp; spectators read the broadcast
    buffer non-destructively.
    """
    # Declared at function top (conventional placement; was inside try).
    global last_frame_consumed_time
    try:
        session_id = request.args.get("session_id")
        if not session_id:
            return jsonify(
                {"status": "error", "message": "session_id is required"}
            ), 400
        if model_manager is None:
            return jsonify({"status": "error", "message": "Model not initialized"}), 400
        # type=int makes a malformed value fall back to the default instead
        # of raising ValueError (which surfaced as a 500) — consistent with
        # npc_frame; clamp to [0, 20] so negative/huge requests are bounded.
        count = max(0, min(request.args.get("count", 8, type=int), 20))
        # Check if this is the active session or a spectator
        with session_lock:
            is_active = active_session_id == session_id
        if is_active:
            # Active session: pop frames from generation buffer
            frames = []
            for _ in range(count):
                joints = model_manager.get_next_frame()
                if joints is None:
                    break
                frames.append(joints.tolist())
            if frames:
                with consumption_monitor_lock:
                    last_frame_consumed_time = time.time()
                return jsonify(
                    {
                        "status": "success",
                        "frames": frames,
                        "buffer_size": model_manager.frame_buffer.size(),
                    }
                )
        else:
            # Spectator: read from broadcast buffer (non-destructive).
            # type=int again: malformed after_id falls back to 0, not a 500.
            after_id = request.args.get("after_id", 0, type=int)
            broadcast = model_manager.get_broadcast_frames(after_id, count)
            if broadcast:
                last_id = broadcast[-1][0]
                frames = [joints.tolist() for _, joints in broadcast]
                return jsonify(
                    {
                        "status": "success",
                        "frames": frames,
                        "last_id": last_id,
                        "buffer_size": model_manager.frame_buffer.size(),
                    }
                )
        # No frames available (active or spectator)
        return jsonify(
            {
                "status": "waiting",
                "message": "No frame available yet",
                "buffer_size": model_manager.frame_buffer.size(),
            }
        )
    except Exception as e:
        print(f"Error in get_frame: {e}")
        traceback.print_exc()
        return jsonify({"status": "error", "message": str(e)}), 500
def get_status():
    """Report generation/buffer status along with session-ownership info."""
    try:
        session_id = request.args.get("session_id")
        # Snapshot session state under the lock before touching the model.
        # NOTE: `session_id and ...` keeps the original truthiness (None when
        # no session_id was supplied, which jsonifies as null, not false).
        with session_lock:
            owns_session = session_id and active_session_id == session_id
            current_active = active_session_id
        if model_manager is None:
            # Model not yet loaded: report an uninitialized, idle status.
            return jsonify(
                {
                    "initialized": False,
                    "buffer_size": 0,
                    "is_generating": False,
                    "is_active_session": owns_session,
                    "active_session_id": current_active,
                }
            )
        status = model_manager.get_buffer_status()
        status.update(
            initialized=True,
            is_active_session=owns_session,
            active_session_id=current_active,
        )
        return jsonify(status)
    except Exception as exc:
        return jsonify({"status": "error", "message": str(exc)}), 500
if __name__ == "__main__":
    # CLI entry point: parse options, eagerly load the model, start Flask.
    arg_parser = argparse.ArgumentParser(
        description="Flask server for real-time 3D motion generation"
    )
    arg_parser.add_argument(
        "--model_name",
        type=str,
        default="ShandaAI/FloodDiffusionTiny",
        help="HF Hub model name (default: ShandaAI/FloodDiffusionTiny)",
    )
    arg_parser.add_argument(
        "--port",
        type=int,
        default=7860,
        help="Port to run the server on (default: 7860)",
    )
    cli = arg_parser.parse_args()
    model_name_global = cli.model_name
    # Load model eagerly on startup (pre-downloaded in Docker)
    print(f"Loading model: {model_name_global}")
    init_model()
    print("Starting Flask server...")
    app.run(host="0.0.0.0", port=cli.port, debug=False, threaded=True)