| |
| """ |
Memory Manager for BackgroundFX Pro
- Safe on CPU/CUDA/MPS (mostly CUDA/T4 on Spaces)
- Accepts `device` as str or torch.device
- Optional per-process VRAM cap (env or method)
- Detailed usage reporting (CPU/RAM + VRAM + torch allocator)
- Light and aggressive cleanup paths
- Background monitor (optional)

Env switches:
    BFX_DISABLE_LIMIT=1     -> do not set VRAM fraction automatically
    BFX_CUDA_FRACTION=0.80  -> fraction to cap per-process VRAM (0.10..0.95)
| """ |
|
|
| from __future__ import annotations |
| import gc |
| import os |
| import time |
| import logging |
| import threading |
| from typing import Dict, Any, Optional, Callable |
|
|
| |
| try: |
| import psutil |
| except Exception: |
| psutil = None |
|
|
| try: |
| import torch |
| except Exception: |
| torch = None |
|
|
| logger = logging.getLogger(__name__) |
|
|
| |
class MemoryManagerError(Exception):
    """Exception type declared for memory-manager failures.

    It adds nothing to ``Exception``; it only provides a distinct type that
    callers can catch.
    """
|
|
|
|
def _bytes_to_gb(x: int | float) -> float:
    """Convert a byte count to gigabytes (1 GB = 1024**3 bytes).

    Returns 0.0 instead of raising when ``x`` cannot be converted to a float
    (for example ``None`` or a non-numeric string).
    """
    bytes_per_gb = 1 << 30  # same value as 1024**3
    try:
        as_float = float(x)
    except Exception:
        return 0.0
    return as_float / bytes_per_gb
|
|
|
|
def _normalize_device(dev) -> "torch.device":
    """Coerce ``dev`` into a device-like object exposing ``.type`` and ``.index``.

    - Without torch, a minimal CPU stand-in is returned whatever ``dev`` is.
    - Strings are passed to ``torch.device``.
    - Anything that already has a ``type`` attribute is returned unchanged.
    - Everything else falls back to ``torch.device("cpu")``.
    """
    if torch is None:
        # Duck-typed placeholder so callers can still read .type / .index.
        class _Fake:
            type = "cpu"
            index = None

        return _Fake()

    if isinstance(dev, str):
        return torch.device(dev)

    looks_like_device = hasattr(dev, "type")
    return dev if looks_like_device else torch.device("cpu")
|
|
|
|
def _cuda_index(device) -> Optional[int]:
    """Return the CUDA ordinal of ``device``, or ``None`` for non-CUDA devices.

    A CUDA device without an explicit index maps to ordinal 0.
    """
    is_cuda = getattr(device, "type", "cpu") == "cuda"
    if not is_cuda:
        return None

    ordinal = getattr(device, "index", None)
    return 0 if ordinal is None else int(ordinal)
|
|
|
|
class MemoryManager:
    """
    Comprehensive memory management with VRAM cap + cleanup utilities.

    On construction the manager normalizes ``device``, derives a baseline
    ``memory_limit_gb`` when none is given (80% of device VRAM on CUDA, 50% of
    system RAM on MPS, 60% of system RAM otherwise, 4.0 GB if detection
    raises) and, on CUDA, applies a per-process VRAM fraction unless
    ``BFX_DISABLE_LIMIT`` is set to a truthy value.
    """

    # BFX_DISABLE_LIMIT values (case-insensitive) that leave the automatic cap on.
    _ENV_FALSE_VALUES = ("", "0", "false", "no", "off")

    # Usage percentage at or above which pressure is reported as "critical".
    _CRITICAL_PERCENT = 95

    # Advice returned by check_memory_pressure(), keyed by pressure level.
    _GPU_RECOMMENDATIONS = {
        "critical": [
            "Run aggressive memory cleanup",
            "Reduce frame cache / chunk size",
            "Lower resolution or disable previews",
        ],
        "warning": [
            "Run cleanup",
            "Monitor memory usage",
            "Reduce keyframe interval",
        ],
    }
    _SYSTEM_RECOMMENDATIONS = {
        "critical": [
            "Close other processes",
            "Reduce resolution",
            "Split video into chunks",
        ],
        "warning": [
            "Run cleanup",
            "Monitor usage",
            "Reduce processing footprint",
        ],
    }

    def __init__(self, device, memory_limit_gb: Optional[float] = None):
        """Create a manager bound to ``device``.

        Args:
            device: Device name string or device object (anything exposing
                ``.type``); normalized through ``_normalize_device``.
            memory_limit_gb: Explicit baseline limit in GB. When ``None`` a
                baseline is derived from the detected device/system memory.
        """
        self.device = _normalize_device(device)
        self.device_type = getattr(self.device, "type", "cpu")
        self.cuda_idx = _cuda_index(self.device)

        self.gpu_available = bool(
            torch and self.device_type == "cuda" and torch.cuda.is_available()
        )
        self.mps_available = bool(
            torch
            and self.device_type == "mps"
            and getattr(torch.backends, "mps", None)
            and torch.backends.mps.is_available()
        )

        self.memory_limit_gb = memory_limit_gb
        self.cleanup_callbacks: list[Callable] = []
        self.monitoring_active = False
        self.monitoring_thread: Optional[threading.Thread] = None
        # Set by stop_monitoring() to wake the monitor thread immediately.
        self._monitor_stop_event: Optional[threading.Event] = None
        # True while cleanup callbacks run; blocks recursive re-invocation.
        self._running_callbacks = False
        self.stats = {
            "cleanup_count": 0,
            "peak_memory_usage": 0.0,
            "total_allocated": 0.0,
            "total_freed": 0.0,
        }
        self.applied_fraction: Optional[float] = None

        self._initialize_memory_limits()
        self._maybe_apply_vram_fraction()
        logger.info(f"MemoryManager initialized (device={self.device}, cuda={self.gpu_available})")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @property
    def _device_index(self) -> int:
        """CUDA ordinal handed to ``torch.cuda`` calls (0 when none is set)."""
        return getattr(self, "cuda_idx", None) or 0

    @classmethod
    def _auto_limit_disabled(cls) -> bool:
        """Return True when ``BFX_DISABLE_LIMIT`` holds a truthy value.

        Unset, empty, ``0``, ``false``, ``no`` and ``off`` (any case) leave
        the automatic VRAM cap enabled; any other value disables it.
        """
        raw = os.environ.get("BFX_DISABLE_LIMIT", "")
        return raw.strip().lower() not in cls._ENV_FALSE_VALUES

    def _initialize_memory_limits(self):
        """Fill in ``memory_limit_gb`` from detected memory if it is unset."""
        try:
            if self.gpu_available:
                props = torch.cuda.get_device_properties(self._device_index)
                total_gb = _bytes_to_gb(props.total_memory)
                if self.memory_limit_gb is None:
                    self.memory_limit_gb = max(0.5, total_gb * 0.80)
                logger.info(
                    f"CUDA memory limit baseline ~{self.memory_limit_gb:.1f}GB "
                    f"(device total {total_gb:.1f}GB)"
                )
                return

            # MPS and CPU both size the baseline from system RAM; only the
            # share of RAM and the log label differ.
            label, share = ("MPS", 0.50) if self.mps_available else ("CPU", 0.60)
            vm = psutil.virtual_memory() if psutil else None
            total_gb = _bytes_to_gb(vm.total) if vm else 0.0
            if self.memory_limit_gb is None:
                self.memory_limit_gb = max(0.5, total_gb * share)
            logger.info(f"{label} memory baseline ~{self.memory_limit_gb:.1f}GB (system {total_gb:.1f}GB)")
        except Exception as e:
            logger.warning(f"Memory limit init failed: {e}")
            if self.memory_limit_gb is None:
                self.memory_limit_gb = 4.0

    def _maybe_apply_vram_fraction(self):
        """Apply the per-process VRAM cap on CUDA unless disabled via env."""
        if not self.gpu_available or torch is None:
            return
        if self._auto_limit_disabled():
            return
        frac_env = os.environ.get("BFX_CUDA_FRACTION", "").strip()
        try:
            fraction = float(frac_env) if frac_env else 0.80
        except Exception:
            fraction = 0.80
        applied = self.limit_cuda_memory(fraction=fraction)
        if applied:
            logger.info(f"Per-process CUDA memory fraction set to {applied:.2f} on device {self._device_index}")

    def _system_usage(self) -> Dict[str, Any]:
        """Collect system RAM, swap and process stats via psutil (best effort)."""
        stats: Dict[str, Any] = {}
        if not psutil:
            return stats
        try:
            vm = psutil.virtual_memory()
            stats.update(
                system_total_gb=round(_bytes_to_gb(vm.total), 3),
                system_available_gb=round(_bytes_to_gb(vm.available), 3),
                system_used_gb=round(_bytes_to_gb(vm.used), 3),
                system_percent=float(vm.percent),
            )
            swap = psutil.swap_memory()
            stats.update(
                swap_total_gb=round(_bytes_to_gb(swap.total), 3),
                swap_used_gb=round(_bytes_to_gb(swap.used), 3),
                swap_percent=float(swap.percent),
            )
            mi = psutil.Process().memory_info()
            stats.update(
                process_rss_gb=round(_bytes_to_gb(mi.rss), 3),
                process_vms_gb=round(_bytes_to_gb(mi.vms), 3),
            )
        except Exception as e:
            # Whatever was gathered before the failure is still returned.
            logger.debug(f"psutil stats error: {e}")
        return stats

    def _cuda_usage(self) -> Dict[str, Any]:
        """Collect device VRAM and torch allocator stats (best effort)."""
        stats: Dict[str, Any] = {}
        idx = self._device_index

        try:
            free_b, total_b = torch.cuda.mem_get_info(idx)
            used_b = total_b - free_b
            stats.update(
                vram_total_gb=round(_bytes_to_gb(total_b), 3),
                vram_used_gb=round(_bytes_to_gb(used_b), 3),
                vram_free_gb=round(_bytes_to_gb(free_b), 3),
                vram_used_percent=float(used_b / total_b * 100.0) if total_b else 0.0,
            )
        except Exception as e:
            logger.debug(f"mem_get_info failed: {e}")

        try:
            allocated = torch.cuda.memory_allocated(idx)
            reserved = torch.cuda.memory_reserved(idx)
            stats["torch_allocated_gb"] = round(_bytes_to_gb(allocated), 3)
            stats["torch_reserved_gb"] = round(_bytes_to_gb(reserved), 3)
            try:
                inactive = torch.cuda.memory_stats(idx).get("inactive_split_bytes.all.current", 0)
                stats["torch_inactive_split_gb"] = round(_bytes_to_gb(inactive), 3)
            except Exception as e:
                logger.debug(f"memory_stats failed: {e}")
        except Exception as e:
            logger.debug(f"allocator stats failed: {e}")

        return stats

    def _run_cleanup_callbacks(self) -> None:
        """Invoke every registered cleanup callback with no arguments.

        A failing callback is logged and does not stop the remaining ones.
        The guard flag prevents infinite recursion when a callback itself
        triggers an aggressive cleanup on this manager.
        """
        if getattr(self, "_running_callbacks", False):
            return
        self._running_callbacks = True
        try:
            # Iterate over a copy so callbacks may register further callbacks.
            for callback in list(self.cleanup_callbacks):
                try:
                    callback()
                except Exception as e:
                    logger.warning(f"Cleanup callback failed: {e}")
        finally:
            self._running_callbacks = False

    def _release_aggressive(self) -> None:
        """Collect garbage first, then flush the CUDA caching allocator.

        Garbage collection has to come before ``empty_cache()``: blocks that
        only become free once the collector has run would otherwise stay
        cached by the allocator.
        """
        gc.collect()
        gc.collect()
        if not (self.gpu_available and torch is not None):
            return

        idx = self._device_index
        steps = [
            ("synchronize", lambda: torch.cuda.synchronize(idx)),
            ("empty_cache", lambda: torch.cuda.empty_cache()),
            ("reset_peak_memory_stats", lambda: torch.cuda.reset_peak_memory_stats(idx)),
        ]
        if hasattr(torch.cuda, "ipc_collect"):
            steps.append(("ipc_collect", lambda: torch.cuda.ipc_collect()))

        # Every step is best effort; one failure must not skip the rest.
        for name, step in steps:
            try:
                step()
            except Exception as e:
                logger.debug(f"{name} failed during aggressive cleanup: {e}")

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get_memory_usage(self) -> Dict[str, Any]:
        """Return a snapshot of system, process, VRAM and allocator usage.

        Keys that cannot be gathered (psutil missing, no CUDA, a failing
        query) are simply absent. Also updates ``stats["peak_memory_usage"]``
        with used VRAM when available, otherwise used system RAM.
        """
        usage: Dict[str, Any] = {
            "device_type": self.device_type,
            "memory_limit_gb": self.memory_limit_gb,
            "timestamp": time.time(),
        }
        usage.update(self._system_usage())
        if self.gpu_available and torch is not None:
            usage.update(self._cuda_usage())
        usage["applied_fraction"] = self.applied_fraction

        current = usage.get("vram_used_gb", usage.get("system_used_gb", 0.0))
        try:
            if float(current) > float(self.stats["peak_memory_usage"]):
                self.stats["peak_memory_usage"] = float(current)
        except Exception as e:
            logger.debug(f"peak tracking skipped: {e}")

        return usage

    def limit_cuda_memory(self, fraction: Optional[float] = None, max_gb: Optional[float] = None) -> Optional[float]:
        """Cap this process's share of VRAM on the managed CUDA device.

        Args:
            fraction: Desired fraction of device memory; clamped to
                0.10..0.95. ``None`` or NaN falls back to 0.80.
            max_gb: Absolute cap in GB. When given it takes precedence over
                ``fraction`` and is converted using the device's total memory.

        Returns:
            The fraction actually applied, or ``None`` when CUDA is not
            available or the cap could not be set.
        """
        if not self.gpu_available or torch is None:
            return None

        idx = self._device_index
        if max_gb is not None:
            try:
                _, total_b = torch.cuda.mem_get_info(idx)
                total_gb = _bytes_to_gb(total_b)
                if total_gb <= 0:
                    return None
                fraction = min(max(0.10, max_gb / total_gb), 0.95)
            except Exception as e:
                logger.debug(f"fraction from max_gb failed: {e}")
                return None

        # NaN never equals itself; without this check min()/max() would
        # silently turn it into the 0.95 upper bound.
        if fraction is None or fraction != fraction:
            fraction = 0.80
        fraction = float(max(0.10, min(0.95, fraction)))

        try:
            torch.cuda.set_per_process_memory_fraction(fraction, device=idx)
            self.applied_fraction = fraction
            return fraction
        except Exception as e:
            logger.debug(f"set_per_process_memory_fraction failed: {e}")
            return None

    def cleanup(self) -> None:
        """Light cleanup used frequently between steps.

        Runs the garbage collector and empties the CUDA cache. Registered
        cleanup callbacks are not invoked here; they run only in
        ``cleanup_aggressive``.
        """
        gc.collect()
        if self.gpu_available and torch is not None:
            try:
                torch.cuda.empty_cache()
            except Exception as e:
                logger.debug(f"empty_cache failed: {e}")
        self.stats["cleanup_count"] += 1

    def cleanup_basic(self) -> None:
        """Alias kept for compatibility."""
        self.cleanup()

    def cleanup_aggressive(self) -> None:
        """Aggressive cleanup for OOM recovery or big scene switches.

        Order matters: registered callbacks run first so they can drop their
        own references, then garbage is collected, and only then is the CUDA
        cache flushed so the freed blocks are released too.
        """
        self._run_cleanup_callbacks()
        self._release_aggressive()
        self.stats["cleanup_count"] += 1

    def register_cleanup_callback(self, callback: Callable):
        """Register a zero-argument callable to run during ``cleanup_aggressive``."""
        self.cleanup_callbacks.append(callback)

    def start_monitoring(self, interval_seconds: float = 30.0, pressure_callback: Optional[Callable] = None):
        """Start a daemon thread that periodically checks memory pressure.

        Args:
            interval_seconds: Pause between checks; negative values are
                treated as 0.
            pressure_callback: Optional callable that receives the pressure
                dict whenever the manager is under pressure. Exceptions it
                raises are logged and ignored.

        When pressure is critical the thread also runs ``cleanup_aggressive``.
        Calling this while monitoring is already active only logs a warning.
        """
        if self.monitoring_active:
            logger.warning("Memory monitoring already active")
            return

        # One event per start, so a thread left over from an earlier
        # start/stop cycle can never be revived by a later start.
        stop_event = threading.Event()
        self._monitor_stop_event = stop_event
        self.monitoring_active = True

        def loop():
            while self.monitoring_active and not stop_event.is_set():
                try:
                    pressure = self.check_memory_pressure()
                    if pressure["under_pressure"]:
                        logger.warning(
                            f"Memory pressure: {pressure['pressure_level']} "
                            f"({pressure['usage_percent']:.1f}%)"
                        )
                        if pressure_callback:
                            try:
                                pressure_callback(pressure)
                            except Exception as e:
                                logger.error(f"Pressure callback failed: {e}")
                        if pressure["pressure_level"] == "critical":
                            self.cleanup_aggressive()
                except Exception as e:
                    logger.error(f"Memory monitoring error: {e}")
                # Unlike time.sleep(), this returns as soon as the event is
                # set, so stop_monitoring() does not wait out the interval.
                stop_event.wait(max(0.0, interval_seconds))

        self.monitoring_thread = threading.Thread(target=loop, daemon=True)
        self.monitoring_thread.start()
        logger.info(f"Memory monitoring started (interval: {interval_seconds}s)")

    def stop_monitoring(self):
        """Stop the monitor thread, waiting up to 5 seconds for it to exit.

        Does nothing when monitoring is not active. Safe to call from inside
        the monitor thread (for example from a pressure callback): the join
        is skipped there because a thread cannot join itself.
        """
        if not self.monitoring_active:
            return
        self.monitoring_active = False

        stop_event = getattr(self, "_monitor_stop_event", None)
        if stop_event is not None:
            stop_event.set()

        thread = self.monitoring_thread
        if thread and thread.is_alive() and thread is not threading.current_thread():
            thread.join(timeout=5.0)
        logger.info("Memory monitoring stopped")

    def check_memory_pressure(self, threshold_percent: float = 85.0) -> Dict[str, Any]:
        """Classify current memory usage as normal, warning or critical.

        Uses VRAM usage when CUDA is available, otherwise system RAM usage.

        Args:
            threshold_percent: Usage percentage at which pressure starts.

        Returns:
            Dict with ``under_pressure``, ``pressure_level``
            (``"normal"``/``"warning"``/``"critical"``), ``usage_percent``
            and ``recommendations``.
        """
        usage = self.get_memory_usage()
        if self.gpu_available:
            percent = usage.get("vram_used_percent", 0.0)
            advice = self._GPU_RECOMMENDATIONS
        else:
            percent = usage.get("system_percent", 0.0)
            advice = self._SYSTEM_RECOMMENDATIONS

        info: Dict[str, Any] = {
            "under_pressure": False,
            "pressure_level": "normal",
            "usage_percent": percent,
            "recommendations": [],
        }
        if percent >= threshold_percent:
            level = "critical" if percent >= self._CRITICAL_PERCENT else "warning"
            info["under_pressure"] = True
            info["pressure_level"] = level
            # Copy so callers cannot mutate the class-level advice tables.
            info["recommendations"] = list(advice[level])
        return info

    def estimate_memory_requirement(self, video_width: int, video_height: int, frames_in_memory: int = 5) -> Dict[str, float]:
        """Estimate the memory (GB) needed to process a video of this size.

        Frames are counted as 3 bytes per pixel with a 3x overhead
        multiplier; fixed allowances of 4.0 GB (model) and 2.0 GB (system
        overhead) are added on top.
        """
        bytes_per_frame = video_width * video_height * 3
        overhead_multiplier = 3.0
        frames_gb = _bytes_to_gb(bytes_per_frame * frames_in_memory * overhead_multiplier)
        estimate = {
            "frames_memory_gb": round(frames_gb, 3),
            "model_memory_gb": 4.0,
            "system_overhead_gb": 2.0,
        }
        estimate["total_estimated_gb"] = round(
            estimate["frames_memory_gb"] + estimate["model_memory_gb"] + estimate["system_overhead_gb"], 3
        )
        return estimate

    def can_process_video(self, video_width: int, video_height: int, frames_in_memory: int = 5) -> Dict[str, Any]:
        """Compare the estimated requirement with currently free memory.

        Free VRAM is used when CUDA is available, otherwise available system
        RAM. Returns the verdict, both figures, the margin and, when the
        video does not fit, a list of recommendations.
        """
        estimate = self.estimate_memory_requirement(video_width, video_height, frames_in_memory)
        usage = self.get_memory_usage()
        free_key = "vram_free_gb" if self.gpu_available else "system_available_gb"
        available = usage.get(free_key, 0.0)
        needed = estimate["total_estimated_gb"]

        can = needed <= available
        return {
            "can_process": can,
            "estimated_memory_gb": needed,
            "available_memory_gb": available,
            "memory_margin_gb": round(available - needed, 3),
            "recommendations": [] if can else [
                "Reduce resolution or duration",
                "Process in smaller chunks",
                "Run aggressive cleanup before start",
            ],
        }

    def get_stats(self) -> Dict[str, Any]:
        """Return a summary of manager state and counters."""
        return {
            "cleanup_count": self.stats["cleanup_count"],
            "peak_memory_usage_gb": self.stats["peak_memory_usage"],
            "device_type": self.device_type,
            "memory_limit_gb": self.memory_limit_gb,
            "applied_fraction": self.applied_fraction,
            "monitoring_active": self.monitoring_active,
            "callbacks_registered": len(self.cleanup_callbacks),
        }

    def __del__(self):
        """Best-effort teardown: stop the monitor and release memory.

        Registered callbacks are deliberately not invoked here, because they
        may reference objects that are already being torn down.
        """
        try:
            self.stop_monitoring()
            self._release_aggressive()
        except Exception:
            # Finalizers must never raise; the instance may be only partly
            # initialized or the interpreter may be shutting down.
            pass
|
|