# Uploaded by vienoux via huggingface_hub (commit f9adcbf, verified).
"""
Model Cache with LRU Eviction
Intelligent caching system for Micro-SLMs to minimize loading time.
Keeps the most recently used models in memory.
"""
import asyncio
import inspect
import logging
from collections import OrderedDict
from datetime import datetime
from typing import Optional, Callable, Any, Dict
# psutil is optional: without it, memory-based eviction is effectively
# disabled because the memory probe reports 0 MB.
try:
    import psutil
except ImportError:
    HAS_PSUTIL = False
else:
    HAS_PSUTIL = True

logger = logging.getLogger(__name__)
class ModelCache:
    """
    LRU (Least Recently Used) cache for Micro-SLM models.

    Features:
    - Automatic eviction of the least recently used model when the cache
      holds ``max_models`` entries, or when process RSS exceeds
      ``max_memory_mb`` (memory check requires the optional ``psutil``)
    - Loader functions and model shutdown/cleanup hooks may be either
      plain callables or coroutine functions
    - Cache mutations are serialized with an ``asyncio.Lock``. This makes
      them safe for concurrent coroutines on one event loop; it does NOT
      make the cache safe to use from multiple OS threads.
    """

    def __init__(
        self,
        max_models: int = 3,
        max_memory_mb: int = 2000,
        enable_stats: bool = True
    ):
        """
        Initialize the model cache.

        Args:
            max_models: Maximum number of models to keep in cache
            max_memory_mb: Soft limit on process RSS in MB; when exceeded,
                one model is evicted before the next load
            enable_stats: When False, the hit/miss/eviction/load counters
                are not updated (they stay at 0)
        """
        # Key order doubles as recency order: the first key is the least
        # recently used model, the last key the most recently used one.
        self.cache: OrderedDict[str, Any] = OrderedDict()
        self.max_models = max_models
        self.max_memory_mb = max_memory_mb
        self.enable_stats = enable_stats

        # Statistics
        self.stats = {
            "hits": 0,
            "misses": 0,
            "evictions": 0,
            "loads": 0
        }

        # Serializes cache mutations across coroutines (not OS threads).
        self._lock = asyncio.Lock()

        logger.info(f"ModelCache initialized: max_models={max_models}, max_memory={max_memory_mb}MB")

    def _record(self, counter: str) -> None:
        """Increment a statistics counter, honoring ``enable_stats``."""
        if self.enable_stats:
            self.stats[counter] += 1

    async def get_or_load(
        self,
        model_name: str,
        loader_func: Callable,
        *args,
        **kwargs
    ) -> Any:
        """
        Get model from cache or load it if not present.

        Args:
            model_name: Unique identifier for the model
            loader_func: Function that builds the model; may be a plain
                function or a coroutine function (its result is awaited
                if awaitable)
            *args, **kwargs: Arguments to pass to loader_func

        Returns:
            The cached or freshly loaded model instance

        Raises:
            Whatever ``loader_func`` raises; the failed model is not cached.
        """
        # NOTE(review): the lock is held for the whole load, so a slow load
        # also delays cache hits for other models. This is kept because it
        # guarantees the same model is never loaded twice concurrently.
        async with self._lock:
            if model_name in self.cache:
                # Cache hit: mark as most recently used.
                self.cache.move_to_end(model_name)
                self._record("hits")
                logger.info(
                    f"✅ Cache HIT: {model_name} "
                    f"(hit rate: {self.get_hit_rate():.1%})"
                )
                return self.cache[model_name]

            self._record("misses")
            logger.info(f"❌ Cache MISS: {model_name}")

            # Make room *before* loading so the new model is not added on
            # top of a full cache. A failed load therefore still costs the
            # evicted entry.
            await self._evict_if_needed()

            logger.info(f"📥 Loading model: {model_name}...")
            loop = asyncio.get_running_loop()
            load_start = loop.time()  # monotonic clock, unlike datetime.now()
            try:
                model = loader_func(*args, **kwargs)
                if inspect.isawaitable(model):
                    model = await model
            except Exception as e:
                logger.error(f"Failed to load {model_name}: {e}")
                raise
            load_duration = loop.time() - load_start
            logger.info(f"✓ Loaded {model_name} in {load_duration:.2f}s")

            # A newly inserted key lands at the end, i.e. most recently used.
            self.cache[model_name] = model
            self._record("loads")
            return model

    async def _evict_if_needed(self):
        """Evict least recently used model(s) so one more model fits."""
        if len(self.cache) >= self.max_models:
            # Loop rather than evict once so a max_models lowered at runtime
            # is enforced; the emptiness guard avoids spinning forever when
            # max_models <= 0.
            while self.cache and len(self.cache) >= self.max_models:
                await self._evict_oldest()
            return

        memory_usage = self._get_memory_usage_mb()
        if memory_usage > self.max_memory_mb:
            logger.warning(
                f"Memory usage ({memory_usage:.0f}MB) exceeds limit "
                f"({self.max_memory_mb}MB)"
            )
            # Only one eviction per load: RSS is not guaranteed to drop as
            # soon as a model is released, so re-checking it in a loop could
            # empty the whole cache.
            await self._evict_oldest()

    async def _evict_oldest(self):
        """Evict the least recently used model and release its resources."""
        if not self.cache:
            return

        oldest_name, oldest_model = self.cache.popitem(last=False)
        self._record("evictions")
        logger.info(f"🗑️ Evicting: {oldest_name}")
        await self._release_model(oldest_name, oldest_model)

    async def _release_model(self, name: str, model: Any) -> None:
        """
        Best-effort resource cleanup for a model leaving the cache.

        Calls ``model.shutdown()`` if the attribute exists, otherwise
        ``model.cleanup()``; either hook may be synchronous or async.
        Exceptions are logged, not raised, so one misbehaving model cannot
        block eviction, removal or clearing.
        """
        for hook_name in ("shutdown", "cleanup"):
            if hasattr(model, hook_name):
                break
        else:
            return

        try:
            result = getattr(model, hook_name)()
            # Only await when the hook actually returned an awaitable;
            # awaiting a sync hook's None would raise TypeError.
            if inspect.isawaitable(result):
                await result
        except Exception as e:
            logger.warning(f"Error shutting down {name}: {e}")

    def _get_memory_usage_mb(self) -> float:
        """Return current process RSS in MB, or 0.0 if it cannot be read."""
        if not HAS_PSUTIL:
            return 0.0
        try:
            process = psutil.Process()
            return process.memory_info().rss / (1024 * 1024)
        except Exception:
            return 0.0

    def get_hit_rate(self) -> float:
        """Return hits / (hits + misses), or 0.0 before any lookup."""
        total = self.stats["hits"] + self.stats["misses"]
        if total == 0:
            return 0.0
        return self.stats["hits"] / total

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of counters, cached model names and memory use."""
        return {
            **self.stats,
            "cached_models": len(self.cache),
            "model_names": list(self.cache.keys()),
            "hit_rate": self.get_hit_rate(),
            "memory_usage_mb": self._get_memory_usage_mb()
        }

    async def clear(self):
        """Remove every cached model and run its shutdown/cleanup hook."""
        async with self._lock:
            logger.info("Clearing model cache...")
            # Empty the dict first so the cache is consistent even if a
            # hook raises something _release_model does not catch.
            entries = list(self.cache.items())
            self.cache.clear()
            for name, model in entries:
                await self._release_model(name, model)
            logger.info("Cache cleared")

    async def preload(self, model_name: str, loader_func: Callable, *args, **kwargs):
        """
        Preload a model into cache (prefetching).

        Useful for anticipating which model will be needed next. Counts as a
        regular lookup for statistics purposes.
        """
        logger.info(f"🔮 Prefetching: {model_name}")
        await self.get_or_load(model_name, loader_func, *args, **kwargs)

    def contains(self, model_name: str) -> bool:
        """Check if model is in cache (does not affect LRU order)."""
        return model_name in self.cache

    async def remove(self, model_name: str):
        """Manually remove a model from cache and run its shutdown/cleanup hook."""
        async with self._lock:
            if model_name not in self.cache:
                return
            model = self.cache.pop(model_name)
            logger.info(f"Removed {model_name} from cache")
            await self._release_model(model_name, model)
# Process-wide shared cache: keeps at most 3 models resident, with a
# 2000 MB soft limit on process memory, and statistics enabled.
model_cache = ModelCache(
    enable_stats=True,
    max_memory_mb=2000,
    max_models=3,
)