# NOTE(review): the lines that were here ("Spaces / Sleeping / Sleeping") were
# Hugging Face Spaces page-status chrome picked up during extraction, not code.
| """ | |
| Model Cache with LRU Eviction | |
| Intelligent caching system for Micro-SLMs to minimize loading time. | |
| Keeps the most recently used models in memory. | |
| """ | |
| import logging | |
| import asyncio | |
| from collections import OrderedDict | |
| from typing import Optional, Callable, Any, Dict | |
| from datetime import datetime | |
| try: | |
| import psutil | |
| HAS_PSUTIL = True | |
| except ImportError: | |
| HAS_PSUTIL = False | |
| logger = logging.getLogger(__name__) | |
class ModelCache:
    """
    LRU (Least Recently Used) Cache for Micro-SLM models.

    Features:
    - Automatic eviction of least recently used models
    - Memory usage tracking (via psutil when available)
    - Async model loading
    - Coroutine-safe operations (guarded by an asyncio.Lock — this is
      event-loop safety, not OS-thread safety)
    """

    def __init__(
        self,
        max_models: int = 3,
        max_memory_mb: int = 2000,
        enable_stats: bool = True
    ):
        """
        Initialize the model cache.

        Args:
            max_models: Maximum number of models to keep in cache
            max_memory_mb: Maximum memory usage in MB (soft limit)
            enable_stats: Enable statistics tracking
        """
        # OrderedDict doubles as the LRU structure: first key = least
        # recently used, last key = most recently used.
        self.cache: OrderedDict[str, Any] = OrderedDict()
        self.max_models = max_models
        self.max_memory_mb = max_memory_mb
        # NOTE(review): currently informational only — stats counters are
        # always collected regardless of this flag.
        self.enable_stats = enable_stats
        # Statistics
        self.stats = {
            "hits": 0,
            "misses": 0,
            "evictions": 0,
            "loads": 0
        }
        # Serializes all cache mutations between coroutines.
        self._lock = asyncio.Lock()
        logger.info(f"ModelCache initialized: max_models={max_models}, max_memory={max_memory_mb}MB")

    async def get_or_load(
        self,
        model_name: str,
        loader_func: Callable,
        *args,
        **kwargs
    ) -> Any:
        """
        Get model from cache or load it if not present.

        Args:
            model_name: Unique identifier for the model
            loader_func: Async function to load the model
            *args, **kwargs: Arguments to pass to loader_func

        Returns:
            The loaded model instance

        Raises:
            Whatever loader_func raises; failures are logged and re-raised,
            and nothing is cached on failure.
        """
        # The lock is held across the whole load so concurrent callers
        # cannot load the same model twice; long loads therefore block
        # other cache operations for their duration.
        async with self._lock:
            if model_name in self.cache:
                # Cache hit: refresh recency by moving to the MRU end.
                self.cache.move_to_end(model_name)
                self.stats["hits"] += 1
                logger.info(
                    f"✅ Cache HIT: {model_name} "
                    f"(hit rate: {self.get_hit_rate():.1%})"
                )
                return self.cache[model_name]

            # Cache miss - need to load
            self.stats["misses"] += 1
            logger.info(f"❌ Cache MISS: {model_name}")

            # Make room BEFORE loading so the new model fits.
            await self._evict_if_needed()

            logger.info(f"📥 Loading model: {model_name}...")
            load_start = datetime.now()
            try:
                model = await loader_func(*args, **kwargs)
                load_duration = (datetime.now() - load_start).total_seconds()
                logger.info(f"✓ Loaded {model_name} in {load_duration:.2f}s")
                # A fresh insertion lands at the MRU end of the OrderedDict,
                # so no move_to_end is needed here.
                self.cache[model_name] = model
                self.stats["loads"] += 1
                return model
            except Exception as e:
                logger.error(f"Failed to load {model_name}: {e}")
                raise

    async def _evict_if_needed(self):
        """Evict least recently used models until there is room for one more.

        Fixes two issues in the previous version: the count check now loops,
        so ``len(self.cache) < self.max_models`` is guaranteed on return
        (even if ``max_models`` was lowered after construction), and hitting
        the count limit no longer skips the memory check.
        """
        while len(self.cache) >= self.max_models:
            await self._evict_oldest()

        # Soft memory limit: evict at most ONE extra model. Process RSS may
        # not drop immediately after an eviction (GC timing), so looping on
        # the memory reading could empty the cache entirely.
        memory_usage = self._get_memory_usage_mb()
        if memory_usage > self.max_memory_mb and self.cache:
            logger.warning(
                f"Memory usage ({memory_usage:.0f}MB) exceeds limit "
                f"({self.max_memory_mb}MB)"
            )
            await self._evict_oldest()

    async def _evict_oldest(self):
        """Evict the least recently used model"""
        if not self.cache:
            return
        # popitem(last=False) removes the LRU (first-inserted) entry.
        oldest_name, oldest_model = self.cache.popitem(last=False)
        self.stats["evictions"] += 1
        logger.info(f"🗑️ Evicting: {oldest_name}")
        await self._cleanup_model(oldest_name, oldest_model)

    async def _cleanup_model(self, name: str, model: Any):
        """Best-effort release of a model's resources; never raises.

        Shared by eviction, clear() and remove() so every removal path tries
        both ``shutdown()`` and ``cleanup()`` (previously clear/remove only
        tried ``shutdown``, leaking models that expose only ``cleanup``).
        """
        try:
            if hasattr(model, 'shutdown'):
                await model.shutdown()
            elif hasattr(model, 'cleanup'):
                await model.cleanup()
        except Exception as e:
            logger.warning(f"Error during model cleanup: {e}")

    def _get_memory_usage_mb(self) -> float:
        """Return current process memory usage (RSS) in MB; 0.0 when psutil is unavailable."""
        if not HAS_PSUTIL:
            return 0.0
        try:
            process = psutil.Process()
            return process.memory_info().rss / (1024 * 1024)
        except Exception:
            # psutil can fail due to platform quirks or permissions;
            # treat the reading as unknown rather than crash.
            return 0.0

    def get_hit_rate(self) -> float:
        """Return hits / (hits + misses), or 0.0 before any lookup."""
        total = self.stats["hits"] + self.stats["misses"]
        if total == 0:
            return 0.0
        return self.stats["hits"] / total

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics: raw counters plus a current snapshot."""
        return {
            **self.stats,
            "cached_models": len(self.cache),
            "model_names": list(self.cache.keys()),
            "hit_rate": self.get_hit_rate(),
            "memory_usage_mb": self._get_memory_usage_mb()
        }

    async def clear(self):
        """Clear all cached models, releasing their resources."""
        async with self._lock:
            logger.info("Clearing model cache...")
            for name, model in self.cache.items():
                await self._cleanup_model(name, model)
            self.cache.clear()
            logger.info("Cache cleared")

    async def preload(self, model_name: str, loader_func: Callable, *args, **kwargs):
        """
        Preload a model into cache (prefetching).
        Useful for anticipating which model will be needed next.
        """
        logger.info(f"🔮 Prefetching: {model_name}")
        await self.get_or_load(model_name, loader_func, *args, **kwargs)

    def contains(self, model_name: str) -> bool:
        """Check if model is in cache (read-only; does not update recency)."""
        return model_name in self.cache

    async def remove(self, model_name: str):
        """Manually remove a model from cache"""
        async with self._lock:
            if model_name in self.cache:
                model = self.cache.pop(model_name)
                logger.info(f"Removed {model_name} from cache")
                await self._cleanup_model(model_name, model)
# Shared, module-level cache singleton used across the application.
model_cache = ModelCache(
    max_models=3,        # keep up to 3 models resident in memory
    max_memory_mb=2000,  # 2GB soft limit
    enable_stats=True,
)