File size: 7,772 Bytes
f9adcbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
"""

Model Cache with LRU Eviction



Intelligent caching system for Micro-SLMs to minimize loading time.

Keeps the most recently used models in memory.

"""
import logging
import asyncio
from collections import OrderedDict
from typing import Optional, Callable, Any, Dict
from datetime import datetime

try:
    import psutil
    HAS_PSUTIL = True
except ImportError:
    HAS_PSUTIL = False

logger = logging.getLogger(__name__)


class ModelCache:
    """LRU (Least Recently Used) cache for Micro-SLM models.

    Features:
    - Automatic eviction of least recently used models
    - Memory usage tracking (soft limit, via psutil when available)
    - Async model loading
    - Serialized access between coroutines via an asyncio.Lock

    NOTE(review): the lock is an ``asyncio.Lock``, so operations are safe
    between coroutines on one event loop, not between OS threads.
    """

    def __init__(
        self,
        max_models: int = 3,
        max_memory_mb: int = 2000,
        enable_stats: bool = True
    ):
        """Initialize the model cache.

        Args:
            max_models: Maximum number of models to keep in cache.
            max_memory_mb: Maximum memory usage in MB (soft limit).
            enable_stats: Enable hit/miss/eviction/load counters.
        """
        # Insertion order doubles as recency order: oldest (LRU) first.
        self.cache: OrderedDict[str, Any] = OrderedDict()
        self.max_models = max_models
        self.max_memory_mb = max_memory_mb
        self.enable_stats = enable_stats

        # Statistics — only incremented when enable_stats is True
        # (fix: the flag previously existed but was never honored).
        self.stats = {
            "hits": 0,
            "misses": 0,
            "evictions": 0,
            "loads": 0
        }

        # Serializes cache mutation and model loading between coroutines.
        self._lock = asyncio.Lock()

        logger.info(f"ModelCache initialized: max_models={max_models}, max_memory={max_memory_mb}MB")

    def _bump(self, key: str) -> None:
        """Increment one statistics counter, honoring the enable_stats flag."""
        if self.enable_stats:
            self.stats[key] += 1

    async def get_or_load(
        self,
        model_name: str,
        loader_func: Callable,
        *args,
        **kwargs
    ) -> Any:
        """Get model from cache or load it if not present.

        The lock is held for the full duration of a load, so concurrent
        requests for the same model will not load it twice.

        Args:
            model_name: Unique identifier for the model.
            loader_func: Async function to load the model.
            *args, **kwargs: Arguments passed through to loader_func.

        Returns:
            The loaded model instance.

        Raises:
            Whatever loader_func raises on failure (logged, then re-raised).
        """
        async with self._lock:
            if model_name in self.cache:
                # Cache hit: refresh recency so this entry is evicted last.
                self.cache.move_to_end(model_name)
                self._bump("hits")

                logger.info(
                    f"✅ Cache HIT: {model_name} "
                    f"(hit rate: {self.get_hit_rate():.1%})"
                )

                return self.cache[model_name]

            # Cache miss - need to load
            self._bump("misses")
            logger.info(f"❌ Cache MISS: {model_name}")

            # Make room BEFORE loading to keep peak memory down.
            await self._evict_if_needed()

            logger.info(f"📥 Loading model: {model_name}...")
            load_start = datetime.now()

            try:
                model = await loader_func(*args, **kwargs)
                load_duration = (datetime.now() - load_start).total_seconds()

                logger.info(f"✓ Loaded {model_name} in {load_duration:.2f}s")

                # Newly loaded model becomes the most recently used entry.
                self.cache[model_name] = model
                self.cache.move_to_end(model_name)
                self._bump("loads")

                return model

            except Exception as e:
                logger.error(f"Failed to load {model_name}: {e}")
                raise

    async def _evict_if_needed(self):
        """Evict the LRU model if the count or memory soft limit is reached."""

        # Count limit: free exactly one slot for the incoming model.
        if len(self.cache) >= self.max_models:
            await self._evict_oldest()
            return

        # Memory soft limit (reads 0.0 when psutil is unavailable → no-op).
        memory_usage = self._get_memory_usage_mb()
        if memory_usage > self.max_memory_mb:
            logger.warning(
                f"Memory usage ({memory_usage:.0f}MB) exceeds limit "
                f"({self.max_memory_mb}MB)"
            )
            await self._evict_oldest()

    async def _evict_oldest(self):
        """Evict the least recently used model and release its resources."""

        if not self.cache:
            return

        # OrderedDict preserves LRU order: the first key is the oldest.
        oldest_name = next(iter(self.cache))
        oldest_model = self.cache.pop(oldest_name)

        self._bump("evictions")

        logger.info(f"🗑️ Evicting: {oldest_name}")

        await self._release_model(oldest_model)

    async def _release_model(self, model: Any) -> None:
        """Best-effort resource release for an evicted/removed model.

        Prefers ``shutdown()``; falls back to ``cleanup()``. Errors are
        logged and swallowed so cache bookkeeping always completes.
        (Fix: clear()/remove() previously skipped the cleanup() fallback.)
        """
        try:
            if hasattr(model, 'shutdown'):
                await model.shutdown()
            elif hasattr(model, 'cleanup'):
                await model.cleanup()
        except Exception as e:
            logger.warning(f"Error during model cleanup: {e}")

    def _get_memory_usage_mb(self) -> float:
        """Get current process RSS in MB; 0.0 when psutil is missing/fails."""
        if not HAS_PSUTIL:
            return 0.0

        try:
            process = psutil.Process()
            return process.memory_info().rss / (1024 * 1024)
        except Exception:
            # Monitoring is best-effort; never let it break caching.
            return 0.0

    def get_hit_rate(self) -> float:
        """Calculate cache hit rate (0.0 when no lookups are recorded)."""
        total = self.stats["hits"] + self.stats["misses"]
        if total == 0:
            return 0.0
        return self.stats["hits"] / total

    def get_stats(self) -> Dict[str, Any]:
        """Get a snapshot of cache statistics and current contents."""
        return {
            **self.stats,
            "cached_models": len(self.cache),
            "model_names": list(self.cache.keys()),
            "hit_rate": self.get_hit_rate(),
            "memory_usage_mb": self._get_memory_usage_mb()
        }

    async def clear(self):
        """Clear all cached models, releasing each one's resources."""
        async with self._lock:
            logger.info("Clearing model cache...")

            for name, model in self.cache.items():
                await self._release_model(model)

            self.cache.clear()
            logger.info("Cache cleared")

    async def preload(self, model_name: str, loader_func: Callable, *args, **kwargs):
        """Preload a model into cache (prefetching).

        Useful for anticipating which model will be needed next.
        """
        logger.info(f"🔮 Prefetching: {model_name}")
        await self.get_or_load(model_name, loader_func, *args, **kwargs)

    def contains(self, model_name: str) -> bool:
        """Check if model is in cache (does not refresh recency)."""
        return model_name in self.cache

    async def remove(self, model_name: str):
        """Manually remove a model from cache and release its resources."""
        async with self._lock:
            if model_name in self.cache:
                model = self.cache.pop(model_name)
                logger.info(f"Removed {model_name} from cache")

                await self._release_model(model)


# Module-level singleton shared by all importers of this module:
# holds up to 3 models with a 2 GB soft memory ceiling.
model_cache = ModelCache(max_models=3, max_memory_mb=2000, enable_stats=True)