| """ |
| Cache Management and SAM2 Loading Utilities |
| Comprehensive cache cleaning system to resolve model loading issues on HF Spaces |
| """ |
|
|
| import os |
| import gc |
| import sys |
| import shutil |
| import tempfile |
| import logging |
| import traceback |
| from pathlib import Path |
| from typing import Optional, Dict, Any, Tuple |
|
|
# Module-level logger shared by the cache cleaner and the SAM2 loaders defined in this module.
logger = logging.getLogger(__name__)
|
|
class HardCacheCleaner:
    """
    Best-effort cache cleaner run before loading SAM2.

    Removes SAM2 modules from ``sys.modules``, ``__pycache__`` directories
    under the current working directory, SAM2-named files/directories in the
    HuggingFace, PyTorch, checkpoint and temp directories, empties the CUDA
    allocator cache, invalidates import finder caches and runs the garbage
    collector. Each step catches and logs its own exceptions.
    """

    # Case-insensitive name fragments identifying SAM2 artifacts on disk.
    # Shared by the cache sweep and the temp sweep so both delete the same
    # kinds of entries. Kept deliberately narrow: the temp sweep also scans
    # the system-wide temp directory, which other processes use.
    _SAM2_NAME_PATTERNS = ("sam2", "segment-anything-2")

    @staticmethod
    def clean_all_caches(verbose: bool = True):
        """Clean all caches that might interfere with SAM2 loading.

        Args:
            verbose: When True, log each step and each removed item at INFO.
        """
        if verbose:
            logger.info("Starting comprehensive cache cleanup...")

        HardCacheCleaner._clean_python_cache(verbose)
        HardCacheCleaner._clean_huggingface_cache(verbose)
        HardCacheCleaner._clean_pytorch_cache(verbose)
        HardCacheCleaner._clean_temp_directories(verbose)
        HardCacheCleaner._clear_import_cache(verbose)
        HardCacheCleaner._force_gc_cleanup(verbose)

        if verbose:
            logger.info("Cache cleanup completed")

    @staticmethod
    def _is_sam2_name(name: str) -> bool:
        """Return True if *name* contains one of the SAM2 name patterns (case-insensitive)."""
        lowered = name.lower()
        return any(pattern in lowered for pattern in HardCacheCleaner._SAM2_NAME_PATTERNS)

    @staticmethod
    def _get_config() -> Optional[Any]:
        """Return the application config, or None if it cannot be loaded.

        Imported lazily so that a missing or broken ``config.app_config`` only
        skips the config-provided directories instead of aborting a whole sweep.
        """
        try:
            from config.app_config import get_config
            return get_config()
        except Exception as e:
            logger.warning(f"Could not load app config for cache cleanup: {e}")
            return None

    @staticmethod
    def _remove_matching_entries(root, verbose: bool = True):
        """Walk *root* and delete every file or directory whose name matches a SAM2 pattern.

        Matching directories are removed wholesale and pruned from the walk so
        ``os.walk`` does not descend into them. Per-entry failures are logged at
        DEBUG and skipped.
        """
        for current, dirs, files in os.walk(root):
            for file_name in files:
                if not HardCacheCleaner._is_sam2_name(file_name):
                    continue
                file_path = os.path.join(current, file_name)
                try:
                    os.remove(file_path)
                except OSError as e:
                    logger.debug(f"Could not remove {file_path}: {e}")
                    continue
                if verbose:
                    logger.info(f"Removed cached file: {file_path}")

            # Iterate over a copy because matching names are pruned from `dirs`.
            for dir_name in list(dirs):
                if not HardCacheCleaner._is_sam2_name(dir_name):
                    continue
                dir_path = os.path.join(current, dir_name)
                shutil.rmtree(dir_path, ignore_errors=True)
                if verbose:
                    logger.info(f"Removed cached directory: {dir_path}")
                dirs.remove(dir_name)

    @staticmethod
    def _clean_python_cache(verbose: bool = True):
        """Drop SAM2 modules from sys.modules and delete __pycache__ dirs under the CWD."""
        try:
            # Snapshot the keys first: sys.modules can change while we iterate.
            sam2_modules = [name for name in list(sys.modules) if "sam2" in name.lower()]
            for module in sam2_modules:
                if verbose:
                    logger.info(f"Removing cached module: {module}")
                sys.modules.pop(module, None)

            for root, dirs, _files in os.walk("."):
                if "__pycache__" in dirs:
                    cache_path = os.path.join(root, "__pycache__")
                    if verbose:
                        logger.info(f"Removing __pycache__: {cache_path}")
                    shutil.rmtree(cache_path, ignore_errors=True)
                    # Prune so os.walk does not try to descend into the deleted dir.
                    dirs.remove("__pycache__")

        except Exception as e:
            logger.warning(f"Python cache cleanup failed: {e}")

    @staticmethod
    def _clean_huggingface_cache(verbose: bool = True):
        """Delete SAM2-named entries from the HuggingFace, torch, model and checkpoint caches."""
        try:
            config = HardCacheCleaner._get_config()

            cache_paths = [
                os.path.expanduser("~/.cache/huggingface/"),
                os.path.expanduser("~/.cache/torch/"),
                # None when the config failed to load or does not set it.
                getattr(config, "model_cache_dir", None),
                "./checkpoints/",
                "./.cache/",
            ]

            for cache_path in cache_paths:
                if not cache_path or not os.path.isdir(cache_path):
                    continue
                if verbose:
                    logger.info(f"Cleaning cache directory: {cache_path}")
                HardCacheCleaner._remove_matching_entries(cache_path, verbose)

        except Exception as e:
            logger.warning(f"HuggingFace cache cleanup failed: {e}")

    @staticmethod
    def _clean_pytorch_cache(verbose: bool = True):
        """Release cached CUDA memory held by PyTorch's allocator, if CUDA is available."""
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                if verbose:
                    logger.info("Cleared PyTorch CUDA cache")
        except Exception as e:
            logger.warning(f"PyTorch cache cleanup failed: {e}")

    @staticmethod
    def _clean_temp_directories(verbose: bool = True):
        """Delete SAM2-named top-level entries from the app and system temp directories."""
        try:
            config = HardCacheCleaner._get_config()

            temp_dirs = [
                getattr(config, "temp_dir", None),
                tempfile.gettempdir(),
                "/tmp",
                "./tmp",
                "./temp",
            ]

            # gettempdir() is usually /tmp; resolve paths so each dir is scanned once.
            seen = set()
            for temp_dir in temp_dirs:
                if not temp_dir or not os.path.isdir(temp_dir):
                    continue
                resolved = os.path.realpath(temp_dir)
                if resolved in seen:
                    continue
                seen.add(resolved)

                try:
                    entries = os.listdir(temp_dir)
                except OSError as e:
                    logger.warning(f"Temp directory cleanup failed: {e}")
                    continue

                for item in entries:
                    if not HardCacheCleaner._is_sam2_name(item):
                        continue
                    item_path = os.path.join(temp_dir, item)
                    try:
                        if os.path.isfile(item_path):
                            os.remove(item_path)
                        elif os.path.isdir(item_path):
                            shutil.rmtree(item_path, ignore_errors=True)
                    except OSError as e:
                        logger.debug(f"Could not remove {item_path}: {e}")
                        continue
                    if verbose:
                        logger.info(f"Removed temp item: {item_path}")

        except Exception as e:
            logger.warning(f"Temp directory cleanup failed: {e}")

    @staticmethod
    def _clear_import_cache(verbose: bool = True):
        """Invalidate importlib's finder caches so re-imports see the current filesystem."""
        try:
            import importlib

            importlib.invalidate_caches()

            if verbose:
                logger.info("Cleared Python import cache")

        except Exception as e:
            logger.warning(f"Import cache cleanup failed: {e}")

    @staticmethod
    def _force_gc_cleanup(verbose: bool = True):
        """Run a full garbage collection pass and log how many objects it found unreachable."""
        try:
            collected = gc.collect()
            if verbose:
                logger.info(f"Garbage collection freed {collected} objects")
        except Exception as e:
            logger.warning(f"Garbage collection failed: {e}")
|
|
|
|
class WorkingSAM2Loader:
    """
    SAM2 loaders built on the HuggingFace Hub.

    The primary loader tries, in order: the Transformers "mask-generation"
    pipeline, the Transformers ``Sam2Processor``/``Sam2Model`` classes, and the
    official ``sam2`` package's ``SAM2ImagePredictor.from_pretrained``. The
    fallback loader loads the large ``Sam2Model`` via Transformers. Loaders
    return None instead of raising when every attempt fails.
    """

    # Model size -> HuggingFace Hub repo id. Unknown sizes fall back to "large".
    _MODEL_MAP = {
        "tiny": "facebook/sam2.1-hiera-tiny",
        "small": "facebook/sam2.1-hiera-small",
        "base": "facebook/sam2.1-hiera-base-plus",
        "large": "facebook/sam2.1-hiera-large",
    }

    # The SAM 2.1 repos name their checkpoints "sam2.1_hiera_*.pt";
    # "sam2_hiera_large.pt" only exists in the older facebook/sam2-hiera-large repo.
    _FALLBACK_REPO_ID = "facebook/sam2.1-hiera-large"
    _FALLBACK_CHECKPOINT = "sam2.1_hiera_large.pt"

    @staticmethod
    def _resolve_model_id(model_size: str) -> str:
        """Map a model size name to its Hub repo id, warning and using "large" if unknown."""
        model_id = WorkingSAM2Loader._MODEL_MAP.get(model_size)
        if model_id is None:
            logger.warning(f"Unknown SAM2 model size {model_size!r}; using 'large'")
            model_id = WorkingSAM2Loader._MODEL_MAP["large"]
        return model_id

    @staticmethod
    def _to_pipeline_device(device: str) -> int:
        """Convert a torch-style device string into a Transformers pipeline device index.

        "cuda" -> 0, "cuda:N" -> N, anything else -> -1 (CPU).
        """
        kind, _, index = str(device).partition(":")
        if kind != "cuda":
            return -1
        return int(index) if index.isdigit() else 0

    @staticmethod
    def load_sam2_transformers_approach(device: str = "cuda", model_size: str = "large") -> Optional[Any]:
        """
        Load SAM2 from the HuggingFace Hub, trying three loading strategies in turn.

        Args:
            device: Torch-style device string ("cuda", "cuda:N", "cpu", ...).
            model_size: One of "tiny", "small", "base", "large".

        Returns:
            A Transformers pipeline, a ``{"model": ..., "processor": ...}`` dict,
            or a ``SAM2ImagePredictor`` depending on which strategy succeeded;
            None if all of them failed.
        """
        try:
            logger.info("Loading SAM2 via HuggingFace Transformers...")

            model_id = WorkingSAM2Loader._resolve_model_id(model_size)
            logger.info(f"Using model: {model_id}")

            # Strategy 1: Transformers pipeline. It expects an int device index.
            try:
                from transformers import pipeline

                sam2_pipeline = pipeline(
                    "mask-generation",
                    model=model_id,
                    device=WorkingSAM2Loader._to_pipeline_device(device),
                )

                logger.info("SAM2 loaded successfully via Transformers pipeline")
                return sam2_pipeline

            except Exception as e:
                logger.warning(f"Pipeline approach failed: {e}")

            # Strategy 2: Transformers model + processor classes.
            try:
                from transformers import Sam2Processor, Sam2Model

                processor = Sam2Processor.from_pretrained(model_id)
                model = Sam2Model.from_pretrained(model_id).to(device)

                logger.info("SAM2 loaded successfully via Transformers classes")
                return {"model": model, "processor": processor}

            except Exception as e:
                logger.warning(f"Direct class approach failed: {e}")

            # Strategy 3: official sam2 package. Pass the device explicitly;
            # its builder otherwise defaults to CUDA.
            try:
                from sam2.sam2_image_predictor import SAM2ImagePredictor

                predictor = SAM2ImagePredictor.from_pretrained(model_id, device=device)

                logger.info("SAM2 loaded successfully via official from_pretrained")
                return predictor

            except Exception as e:
                logger.warning(f"Official from_pretrained approach failed: {e}")

            return None

        except Exception as e:
            logger.error(f"All SAM2 loading methods failed: {e}")
            return None

    @staticmethod
    def load_sam2_fallback_approach(device: str = "cuda") -> Optional[Any]:
        """
        Fallback loader: load the large SAM 2.1 ``Sam2Model`` via Transformers.

        First tries to pre-fetch the raw SAM 2.1 checkpoint into the HF cache.
        This is best-effort: the Transformers load does not read that file, so a
        failed download is logged and does not block the load.

        Args:
            device: Device string passed to ``Sam2Model.to``.

        Returns:
            The model moved to *device*, or None on failure.
        """
        try:
            logger.info("Trying fallback SAM2 loading approach...")

            try:
                from huggingface_hub import hf_hub_download

                checkpoint_path = hf_hub_download(
                    repo_id=WorkingSAM2Loader._FALLBACK_REPO_ID,
                    filename=WorkingSAM2Loader._FALLBACK_CHECKPOINT,
                )
                logger.info(f"Downloaded checkpoint to: {checkpoint_path}")

            except Exception as e:
                logger.warning(f"Checkpoint download failed: {e}")

            try:
                from transformers import Sam2Model

                model = Sam2Model.from_pretrained(WorkingSAM2Loader._FALLBACK_REPO_ID)
                return model.to(device)

            except Exception as e:
                logger.warning(f"Transformers fallback failed: {e}")

            return None

        except Exception as e:
            logger.error(f"Fallback loading failed: {e}")
            return None
|
|
|
|
def load_sam2_with_cache_cleanup(
    device: str = "cuda",
    model_size: str = "large",
    force_cache_clean: bool = True,
    verbose: bool = True
) -> Tuple[Optional[Any], str]:
    """
    Optionally purge SAM2-related caches, then try each SAM2 loader in order.

    ``WorkingSAM2Loader.load_sam2_transformers_approach`` is tried first; if it
    returns None, ``WorkingSAM2Loader.load_sam2_fallback_approach`` is tried.
    Any exception raised along the way is logged with its traceback and
    converted into a ``(None, ...)`` result rather than propagated.

    Args:
        device: Device string forwarded to both loaders.
        model_size: Model size forwarded to the primary loader only.
        force_cache_clean: Run ``HardCacheCleaner.clean_all_caches`` first.
        verbose: Forwarded to the cache cleaner.

    Returns:
        Tuple of (model, status_message). ``model`` is None on failure, and
        ``status_message`` is the newline-joined progress log.
    """
    progress = []

    def _result(model):
        # Pair the outcome with every progress line recorded so far.
        return model, "\n".join(progress)

    try:
        if force_cache_clean:
            progress.append("Cleaning caches...")
            HardCacheCleaner.clean_all_caches(verbose=verbose)
            progress.append("Cache cleanup completed")

        # (announcement, success message, loader), in priority order.
        attempts = (
            (
                "Loading SAM2 (primary method)...",
                "SAM2 loaded successfully!",
                lambda: WorkingSAM2Loader.load_sam2_transformers_approach(device, model_size),
            ),
            (
                "Trying fallback loading method...",
                "SAM2 loaded successfully (fallback)!",
                lambda: WorkingSAM2Loader.load_sam2_fallback_approach(device),
            ),
        )

        for announcement, success, load in attempts:
            progress.append(announcement)
            model = load()
            if model is not None:
                progress.append(success)
                return _result(model)

        progress.append("All SAM2 loading methods failed")
        return _result(None)

    except Exception as exc:
        error_msg = f"Critical error in SAM2 loading: {exc}"
        logger.error(f"{error_msg}\n{traceback.format_exc()}")
        progress.append(error_msg)
        return _result(None)