"""
Model Cache with LRU Eviction
Intelligent caching system for Micro-SLMs to minimize loading time.
Keeps the most recently used models in memory.
"""
import asyncio
import logging
import time
from collections import OrderedDict
from datetime import datetime
from typing import Any, Callable, Dict, Optional
try:
import psutil
HAS_PSUTIL = True
except ImportError:
HAS_PSUTIL = False
logger = logging.getLogger(__name__)
class ModelCache:
    """
    LRU (Least Recently Used) Cache for Micro-SLM models.

    Features:
    - Automatic eviction of least recently used models
    - Memory usage tracking (via psutil, when available)
    - Async model loading
    - Thread-safe operations (mutations guarded by an asyncio.Lock)
    """

    def __init__(
        self,
        max_models: int = 3,
        max_memory_mb: int = 2000,
        enable_stats: bool = True
    ):
        """
        Initialize the model cache.

        Args:
            max_models: Maximum number of models to keep in cache
            max_memory_mb: Maximum memory usage in MB (soft limit)
            enable_stats: Enable statistics tracking (kept for API
                compatibility; counters are always maintained)
        """
        # OrderedDict preserves insertion/access order: first item == LRU.
        self.cache: OrderedDict[str, Any] = OrderedDict()
        self.max_models = max_models
        self.max_memory_mb = max_memory_mb
        self.enable_stats = enable_stats
        # Statistics
        self.stats = {
            "hits": 0,
            "misses": 0,
            "evictions": 0,
            "loads": 0
        }
        # Serializes get_or_load/clear/remove across concurrent coroutines.
        self._lock = asyncio.Lock()
        logger.info(f"ModelCache initialized: max_models={max_models}, max_memory={max_memory_mb}MB")

    async def get_or_load(
        self,
        model_name: str,
        loader_func: Callable,
        *args,
        **kwargs
    ) -> Any:
        """
        Get model from cache or load it if not present.

        Args:
            model_name: Unique identifier for the model
            loader_func: Async function to load the model
            *args, **kwargs: Arguments to pass to loader_func

        Returns:
            The loaded model instance

        Raises:
            Whatever loader_func raises; the failure is logged and the
            cache is left without an entry for model_name.
        """
        async with self._lock:
            # Check cache
            if model_name in self.cache:
                # Cache hit: refresh recency (move to MRU end).
                self.cache.move_to_end(model_name)
                self.stats["hits"] += 1
                logger.info(
                    f"✅ Cache HIT: {model_name} "
                    f"(hit rate: {self.get_hit_rate():.1%})"
                )
                return self.cache[model_name]

            # Cache miss - need to load
            self.stats["misses"] += 1
            logger.info(f"❌ Cache MISS: {model_name}")

            # Make room BEFORE loading so peak memory stays bounded.
            await self._evict_if_needed()

            logger.info(f"📥 Loading model: {model_name}...")
            # time.monotonic() is immune to wall-clock adjustments
            # (NTP/DST), unlike the previous datetime.now() timing.
            load_start = time.monotonic()
            try:
                model = await loader_func(*args, **kwargs)
                load_duration = time.monotonic() - load_start
                logger.info(f"✓ Loaded {model_name} in {load_duration:.2f}s")

                # New keys are appended at the MRU end of the OrderedDict,
                # so no move_to_end is needed here.
                self.cache[model_name] = model
                self.stats["loads"] += 1
                return model
            except Exception as e:
                logger.error(f"Failed to load {model_name}: {e}")
                raise

    async def _evict_if_needed(self):
        """Evict LRU models until a new entry fits within the limits.

        Fix: the original evicted at most ONE model per call and then
        returned early, so the cache could remain over max_models (e.g.
        after the limit was lowered) and the memory check was skipped
        whenever the count check fired.
        """
        # Evict until there is room for the entry about to be inserted.
        while self.cache and len(self.cache) >= self.max_models:
            await self._evict_oldest()

        # Soft memory limit: evict a single model only.  RSS does not
        # shrink immediately after a pop (GC / allocator behavior), so
        # looping on this condition could empty the whole cache at once.
        memory_usage = self._get_memory_usage_mb()
        if memory_usage > self.max_memory_mb and self.cache:
            logger.warning(
                f"Memory usage ({memory_usage:.0f}MB) exceeds limit "
                f"({self.max_memory_mb}MB)"
            )
            await self._evict_oldest()

    async def _evict_oldest(self):
        """Evict the least recently used model and release its resources."""
        if not self.cache:
            return
        # popitem(last=False) removes the LRU (first) entry in one step.
        oldest_name, oldest_model = self.cache.popitem(last=False)
        self.stats["evictions"] += 1
        logger.info(f"🗑️ Evicting: {oldest_name}")
        await self._cleanup_model(oldest_name, oldest_model)

    async def _cleanup_model(self, name: str, model: Any):
        """Best-effort release of a model's resources.

        Shared by eviction, clear() and remove() so every teardown path
        honors BOTH hook names (previously clear()/remove() ignored a
        'cleanup' hook and only called 'shutdown').
        """
        try:
            if hasattr(model, 'shutdown'):
                await model.shutdown()
            elif hasattr(model, 'cleanup'):
                await model.cleanup()
        except Exception as e:
            logger.warning(f"Error during model cleanup: {e}")

    def _get_memory_usage_mb(self) -> float:
        """Get current process memory usage (RSS) in MB; 0.0 if unavailable."""
        if not HAS_PSUTIL:
            return 0.0
        try:
            process = psutil.Process()
            return process.memory_info().rss / (1024 * 1024)
        except Exception:
            # Best-effort metric: never let monitoring break caching.
            return 0.0

    def get_hit_rate(self) -> float:
        """Calculate cache hit rate (0.0 when no lookups have happened)."""
        total = self.stats["hits"] + self.stats["misses"]
        if total == 0:
            return 0.0
        return self.stats["hits"] / total

    def get_stats(self) -> Dict[str, Any]:
        """Get a snapshot of cache statistics and current contents."""
        return {
            **self.stats,
            "cached_models": len(self.cache),
            "model_names": list(self.cache.keys()),
            "hit_rate": self.get_hit_rate(),
            "memory_usage_mb": self._get_memory_usage_mb()
        }

    async def clear(self):
        """Clear all cached models, releasing each one's resources."""
        async with self._lock:
            logger.info("Clearing model cache...")
            for name, model in self.cache.items():
                await self._cleanup_model(name, model)
            self.cache.clear()
            logger.info("Cache cleared")

    async def preload(self, model_name: str, loader_func: Callable, *args, **kwargs):
        """
        Preload a model into cache (prefetching).
        Useful for anticipating which model will be needed next.
        """
        logger.info(f"🔮 Prefetching: {model_name}")
        await self.get_or_load(model_name, loader_func, *args, **kwargs)

    def contains(self, model_name: str) -> bool:
        """Check if model is in cache (does not affect recency)."""
        return model_name in self.cache

    async def remove(self, model_name: str):
        """Manually remove a model from cache (no-op if absent)."""
        async with self._lock:
            if model_name in self.cache:
                model = self.cache.pop(model_name)
                logger.info(f"Removed {model_name} from cache")
                await self._cleanup_model(model_name, model)
# Module-level singleton shared across the application.
model_cache = ModelCache(
    max_models=3,        # keep up to 3 models resident in memory
    max_memory_mb=2000,  # ~2 GB soft memory ceiling
    enable_stats=True,   # track hit/miss statistics
)
|