"""
Model Cache with LRU Eviction

Intelligent caching system for Micro-SLMs to minimize loading time.
Keeps the most recently used models in memory.
"""

import asyncio
import inspect
import logging
import time
from collections import OrderedDict
from datetime import datetime
from typing import Optional, Callable, Any, Dict

try:
    import psutil
    HAS_PSUTIL = True
except ImportError:
    HAS_PSUTIL = False

logger = logging.getLogger(__name__)


async def _await_if_needed(value: Any) -> Any:
    """Await *value* if it is awaitable, otherwise return it unchanged."""
    if inspect.isawaitable(value):
        return await value
    return value


class ModelCache:
    """
    LRU (Least Recently Used) Cache for Micro-SLM models.

    Features:
    - Automatic eviction of the least recently used model when the cache
      holds ``max_models`` entries or process memory exceeds ``max_memory_mb``
    - Process memory tracking (via ``psutil``; reported as 0 when unavailable)
    - Async model loading (plain synchronous loaders are accepted too)
    - Coroutine-safe mutation: ``get_or_load``, ``clear`` and ``remove`` are
      serialized by an ``asyncio.Lock``. This guards against concurrent
      coroutines on one event loop; it is NOT a thread lock, so the cache
      must not be shared across threads.
    """

    def __init__(
        self,
        max_models: int = 3,
        max_memory_mb: int = 2000,
        enable_stats: bool = True
    ):
        """
        Initialize the model cache.

        Args:
            max_models: Maximum number of models to keep in cache
            max_memory_mb: Soft limit on whole-process memory usage in MB
                (only enforced when ``psutil`` is installed)
            enable_stats: When False, the hit/miss/eviction/load counters
                are not incremented
        """
        # Insertion/access order == recency order: first key is the LRU entry.
        self.cache: OrderedDict[str, Any] = OrderedDict()
        self.max_models = max_models
        self.max_memory_mb = max_memory_mb
        self.enable_stats = enable_stats

        # Statistics counters (only incremented when enable_stats is True)
        self.stats = {
            "hits": 0,
            "misses": 0,
            "evictions": 0,
            "loads": 0
        }

        # Serializes cache mutation and model loading across coroutines.
        self._lock = asyncio.Lock()

        logger.info(f"ModelCache initialized: max_models={max_models}, max_memory={max_memory_mb}MB")

    def _count(self, key: str) -> None:
        """Increment the statistics counter *key* if statistics are enabled."""
        if self.enable_stats:
            self.stats[key] += 1

    async def get_or_load(
        self,
        model_name: str,
        loader_func: Callable,
        *args,
        **kwargs
    ) -> Any:
        """
        Get model from cache or load it if not present.

        The lock is held for the whole load. Concurrent requests for the
        same model therefore trigger only a single load. The trade-off is
        that all loads, and any hits that arrive during a load, are
        serialized.

        Args:
            model_name: Unique identifier for the model (the cache key)
            loader_func: Callable that loads the model. It may be a coroutine
                function (or otherwise return an awaitable) or a plain sync
                callable; a sync loader runs on the event-loop thread and
                blocks it while loading.
            *args, **kwargs: Arguments passed to loader_func on a miss
                (ignored on a hit)

        Returns:
            The cached or freshly loaded model instance

        Raises:
            Any exception raised by loader_func; nothing is cached then.
            Eviction happens before loading, so a failed load may still
            have evicted a model.
        """
        async with self._lock:
            # Cache hit: mark as most recently used
            if model_name in self.cache:
                self.cache.move_to_end(model_name)
                self._count("hits")
                logger.info(
                    f"✅ Cache HIT: {model_name} "
                    f"(hit rate: {self.get_hit_rate():.1%})"
                )
                return self.cache[model_name]

            # Cache miss - need to load
            self._count("misses")
            logger.info(f"❌ Cache MISS: {model_name}")

            # Make room first so the new model does not push us past the limits.
            await self._evict_if_needed()

            logger.info(f"📥 Loading model: {model_name}...")
            load_start = time.perf_counter()  # monotonic: immune to clock jumps
            try:
                model = await _await_if_needed(loader_func(*args, **kwargs))
            except Exception as e:
                logger.error(f"Failed to load {model_name}: {e}")
                raise
            load_duration = time.perf_counter() - load_start
            logger.info(f"✓ Loaded {model_name} in {load_duration:.2f}s")

            # A new key is appended at the end, i.e. as most recently used.
            self.cache[model_name] = model
            self._count("loads")
            return model

    async def _evict_if_needed(self) -> None:
        """
        Evict least recently used models so one more model fits.

        The count limit is enforced first, evicting as many models as needed.
        Looping matters when max_models was lowered at runtime. The memory
        limit is a soft limit on whole-process RSS. It is only consulted when
        no count-based eviction happened, and it evicts at most one model
        per call.
        """
        evicted = False
        while self.cache and len(self.cache) >= self.max_models:
            await self._evict_oldest()
            evicted = True
        if evicted:
            return

        # Check memory limit
        memory_usage = self._get_memory_usage_mb()
        if memory_usage > self.max_memory_mb:
            logger.warning(
                f"Memory usage ({memory_usage:.0f}MB) exceeds limit "
                f"({self.max_memory_mb}MB)"
            )
            await self._evict_oldest()

    async def _evict_oldest(self) -> None:
        """Evict the least recently used model (no-op on an empty cache)."""
        if not self.cache:
            return

        # The first entry of the OrderedDict is the least recently used one.
        oldest_name, oldest_model = self.cache.popitem(last=False)
        self._count("evictions")
        logger.info(f"🗑️ Evicting: {oldest_name}")
        await self._dispose(oldest_name, oldest_model)

    @staticmethod
    async def _dispose(model_name: str, model: Any) -> None:
        """
        Release a model's resources on a best-effort basis.

        Calls ``model.shutdown()`` if present, otherwise ``model.cleanup()``.
        Either may be synchronous or asynchronous. Errors are logged and never
        raised, so one misbehaving model cannot break eviction or clearing.
        """
        if hasattr(model, "shutdown"):
            closer = model.shutdown
        elif hasattr(model, "cleanup"):
            closer = model.cleanup
        else:
            return
        try:
            await _await_if_needed(closer())
        except Exception as e:
            logger.warning(f"Error shutting down {model_name}: {e}")

    def _get_memory_usage_mb(self) -> float:
        """Return the current process RSS in MB, or 0.0 without psutil or on error."""
        if not HAS_PSUTIL:
            return 0.0

        try:
            process = psutil.Process()
            return process.memory_info().rss / (1024 * 1024)
        except Exception:
            return 0.0

    def get_hit_rate(self) -> float:
        """Return hits / (hits + misses), or 0.0 when nothing was counted yet."""
        total = self.stats["hits"] + self.stats["misses"]
        if total == 0:
            return 0.0
        return self.stats["hits"] / total

    def get_stats(self) -> Dict[str, Any]:
        """Return the counters plus current cache contents, hit rate and memory usage."""
        return {
            **self.stats,
            "cached_models": len(self.cache),
            "model_names": list(self.cache.keys()),
            "hit_rate": self.get_hit_rate(),
            "memory_usage_mb": self._get_memory_usage_mb()
        }

    async def clear(self) -> None:
        """Remove every cached model, releasing each one's resources."""
        async with self._lock:
            logger.info("Clearing model cache...")
            # Detach all entries up front, then release them one by one.
            entries = list(self.cache.items())
            self.cache.clear()
            for name, model in entries:
                await self._dispose(name, model)
            logger.info("Cache cleared")

    async def preload(self, model_name: str, loader_func: Callable, *args, **kwargs):
        """
        Preload a model into cache (prefetching).

        Useful for anticipating which model will be needed next. Goes through
        get_or_load, so an already-cached model counts as a hit.
        """
        logger.info(f"🔮 Prefetching: {model_name}")
        await self.get_or_load(model_name, loader_func, *args, **kwargs)

    def contains(self, model_name: str) -> bool:
        """Check if model is in cache (does not affect LRU order or stats)."""
        return model_name in self.cache

    async def remove(self, model_name: str) -> None:
        """Manually remove a model from cache and release its resources."""
        async with self._lock:
            if model_name not in self.cache:
                return
            model = self.cache.pop(model_name)
            logger.info(f"Removed {model_name} from cache")
            await self._dispose(model_name, model)


# Global cache instance
model_cache = ModelCache(
    max_models=3,  # Keep 3 models in memory
    max_memory_mb=2000,  # 2GB soft limit
    enable_stats=True
)