"""
Model Loading and Memory Management
Handles lazy loading of SAM2 and MatAnyone models with caching.
(Enhanced logging, error handling, and memory safety)
"""

import os
import gc
import logging
from contextlib import contextmanager

import psutil
import streamlit as st
import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@contextmanager
def torch_memory_manager():
    """Context manager that empties the CUDA cache and runs garbage collection on exit."""
    try:
        logger.info("[torch_memory_manager] Enter")
        yield
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        logger.info("[torch_memory_manager] Exit, cleaned up")


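# Illustrative usage of the context manager above: wrap any GPU-heavy step so the
# CUDA cache is released even if the step raises. `run_segmentation` and `frame`
# are hypothetical names, not part of this module:
#
#     with torch_memory_manager():
#         result = run_segmentation(frame)
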
def get_memory_usage():
    """Return current GPU and system RAM usage figures, in GB."""
    memory_info = {}
    if torch.cuda.is_available():
        memory_info['gpu_allocated'] = torch.cuda.memory_allocated() / 1e9
        memory_info['gpu_reserved'] = torch.cuda.memory_reserved() / 1e9
        memory_info['gpu_free'] = (torch.cuda.get_device_properties(0).total_memory -
                                   torch.cuda.memory_allocated()) / 1e9
    memory_info['ram_used'] = psutil.virtual_memory().used / 1e9
    memory_info['ram_available'] = psutil.virtual_memory().available / 1e9
    logger.info(f"[get_memory_usage] {memory_info}")
    return memory_info


def clear_model_cache():
    """Manual/debug only: clear the Streamlit resource cache and free memory."""
    logger.info("[clear_model_cache] Clearing all model caches...")
    if hasattr(st, 'cache_resource'):
        st.cache_resource.clear()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    logger.info("[clear_model_cache] Model cache cleared")


@st.cache_resource(show_spinner=False)
def load_sam2_predictor():
    """Load the SAM2 image predictor, choosing a model size based on available GPU memory."""
    try:
        logger.info("[load_sam2_predictor] Loading SAM2 image predictor...")
        from sam2.build_sam import build_sam2
        from sam2.sam2_image_predictor import SAM2ImagePredictor

        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"[load_sam2_predictor] Using device: {device}")

        checkpoint_path = "/home/user/app/checkpoints/sam2.1_hiera_large.pt"
        model_cfg = "/home/user/app/configs/sam2.1/sam2.1_hiera_l.yaml"

        if not os.path.exists(checkpoint_path) or not os.path.exists(model_cfg):
            # No local checkpoint/config: fall back to the Hugging Face Hub weights.
            logger.warning("[load_sam2_predictor] Local checkpoints not found, using Hugging Face.")
            predictor = SAM2ImagePredictor.from_pretrained(
                "facebook/sam2-hiera-large",
                device=device
            )
        else:
            memory_info = get_memory_usage()
            gpu_free = memory_info.get('gpu_free', 0)
            if device == "cuda" and gpu_free < 4.0:
                # With limited free VRAM, prefer a smaller SAM2 variant over the large checkpoint.
                logger.warning(f"[load_sam2_predictor] Limited GPU memory ({gpu_free:.1f}GB), using smaller SAM2 model.")
                try:
                    predictor = SAM2ImagePredictor.from_pretrained(
                        "facebook/sam2-hiera-tiny",
                        device=device
                    )
                except Exception as e:
                    logger.warning(f"[load_sam2_predictor] Tiny model failed, trying small. {e}")
                    predictor = SAM2ImagePredictor.from_pretrained(
                        "facebook/sam2-hiera-small",
                        device=device
                    )
            else:
                logger.info("[load_sam2_predictor] Using local large model")
                sam2_model = build_sam2(model_cfg, checkpoint_path, device=device)
                predictor = SAM2ImagePredictor(sam2_model)

        if hasattr(predictor, 'model'):
            predictor.model.to(device)
            predictor.model.eval()
            logger.info(f"[load_sam2_predictor] SAM2 model moved to {device} and set to eval mode")

        logger.info(f"✅ SAM2 loaded successfully on {device}!")
        return predictor
    except Exception as e:
        # exc_info=True already records the full traceback; no separate print needed.
        logger.error(f"❌ Failed to load SAM2 predictor: {e}", exc_info=True)
        return None


def load_sam2():
    """Convenience alias for legacy code: returns only the predictor object."""
    return load_sam2_predictor()


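# Illustrative point-prompt usage of the cached predictor (a sketch, not part of
# the app flow). `image` is assumed to be an RGB numpy array of shape (H, W, 3)
# and `x, y` a hypothetical click location:
#
#     import numpy as np
#     predictor = load_sam2()
#     predictor.set_image(image)
#     masks, scores, logits = predictor.predict(
#         point_coords=np.array([[x, y]]),
#         point_labels=np.array([1]),
#         multimask_output=True,
#     )
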
@st.cache_resource(show_spinner=False)
def load_matanyone_processor():
    """Load the MatAnyone processor (inference core) on the best available device."""
    try:
        logger.info("[load_matanyone_processor] Loading MatAnyone processor...")
        from matanyone import InferenceCore

        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"[load_matanyone_processor] MatAnyone using device: {device}")

        try:
            processor = InferenceCore("PeiqingYang/MatAnyone", device=device)
        except Exception as e:
            # The first attempt can fail on a transient checkpoint-path warning;
            # retry once with the same arguments before giving up.
            logger.warning(f"[load_matanyone_processor] Path warning caught: {e}")
            processor = InferenceCore("PeiqingYang/MatAnyone", device=device)

        if hasattr(processor, 'model'):
            processor.model.to(device)
            processor.model.eval()
            logger.info(f"[load_matanyone_processor] MatAnyone model explicitly moved to {device}")

        if not hasattr(processor, 'device'):
            processor.device = device
            logger.info(f"[load_matanyone_processor] Set processor.device to {device}")

        logger.info(f"✅ MatAnyone loaded successfully on {device}!")
        return processor
    except Exception as e:
        # exc_info=True already records the full traceback; no separate print needed.
        logger.error(f"❌ Failed to load MatAnyone: {e}", exc_info=True)
        return None


def load_matanyone():
    """Convenience alias for legacy code: returns only the processor object."""
    return load_matanyone_processor()


def test_models():
    """For admin/diagnosis: attempts to load both models and returns status."""
    results = {
        'sam2': {'loaded': False, 'error': None},
        'matanyone': {'loaded': False, 'error': None}
    }
    try:
        sam2_predictor = load_sam2_predictor()
        if sam2_predictor is not None:
            results['sam2']['loaded'] = True
        else:
            results['sam2']['error'] = "Predictor returned None"
    except Exception as e:
        results['sam2']['error'] = str(e)
        logger.error(f"[test_models] SAM2 error: {e}", exc_info=True)
    try:
        matanyone_processor = load_matanyone_processor()
        if matanyone_processor is not None:
            results['matanyone']['loaded'] = True
        else:
            results['matanyone']['error'] = "Processor returned None"
    except Exception as e:
        results['matanyone']['error'] = str(e)
        logger.error(f"[test_models] MatAnyone error: {e}", exc_info=True)
    logger.info(f"[test_models] Results: {results}")
    return results


def log_memory_usage(stage=""):
    """Log current memory usage, optionally tagged with a pipeline stage name."""
    memory_info = get_memory_usage()
    log_msg = "Memory usage"
    if stage:
        log_msg += f" ({stage})"
    log_msg += ":"
    if 'gpu_allocated' in memory_info:
        log_msg += f" GPU {memory_info['gpu_allocated']:.1f}GB allocated, {memory_info['gpu_free']:.1f}GB free"
    log_msg += f" | RAM {memory_info['ram_used']:.1f}GB used"
    print(log_msg, flush=True)
    logger.info(log_msg)
    return memory_info


def check_memory_available(required_gb=2.0):
    """Return (is_available, free_gb): whether at least `required_gb` of GPU memory is free."""
    if not torch.cuda.is_available():
        return False, 0.0
    memory_info = get_memory_usage()
    free_gb = memory_info.get('gpu_free', 0)
    logger.info(f"[check_memory_available] free_gb={free_gb}, required={required_gb}")
    return free_gb >= required_gb, free_gb


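# Illustrative gating pattern built on check_memory_available(); the 4.0 GB
# threshold and the fallback action are assumptions, not part of the app's pipeline:
#
#     ok, free_gb = check_memory_available(required_gb=4.0)
#     if not ok:
#         logger.warning(f"Only {free_gb:.1f}GB of VRAM free; falling back to a smaller model.")
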
def free_memory_aggressive():
    """For emergency/manual use only! Do NOT call after every video or from UI!"""
    logger.info("[free_memory_aggressive] Performing aggressive memory cleanup...")
    print("Performing aggressive memory cleanup...", flush=True)
    clear_model_cache()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        try:
            torch.cuda.ipc_collect()
        except Exception:
            pass
    gc.collect()
    print("Memory cleanup complete", flush=True)
    logger.info("Memory cleanup complete")
    log_memory_usage("after cleanup")

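
# Minimal smoke test when the module is run directly (a sketch; assumes the
# Streamlit cache decorators tolerate running outside a full app context):
if __name__ == "__main__":
    log_memory_usage("startup")
    status = test_models()
    print(f"SAM2 loaded: {status['sam2']['loaded']}, error: {status['sam2']['error']}")
    print(f"MatAnyone loaded: {status['matanyone']['loaded']}, error: {status['matanyone']['error']}")
    ok, free_gb = check_memory_available(required_gb=2.0)
    print(f"GPU memory check passed: {ok} ({free_gb:.1f}GB free)")
    log_memory_usage("after model load")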
|