""" ModelCache - Cache persistant des modèles ML Tâche 5.3: Cache des modèles ML pour éviter les rechargements multiples. Gère le chargement, la mise en cache et l'éviction des modèles ML. Auteur : Dom, Alice Kiro - 20 décembre 2024 """ import logging import time import threading from typing import Dict, Any, Optional, Callable, Tuple from dataclasses import dataclass, field from pathlib import Path import weakref import gc logger = logging.getLogger(__name__) @dataclass class ModelCacheEntry: """Entrée du cache de modèles""" model: Any load_time: float last_access: float access_count: int = 0 memory_size_mb: float = 0.0 model_type: str = "unknown" def update_access(self): """Mettre à jour les stats d'accès""" self.last_access = time.time() self.access_count += 1 @dataclass class ModelCacheConfig: """Configuration du cache de modèles""" max_models: int = 5 # Nombre max de modèles en cache max_memory_mb: float = 2048.0 # Mémoire max en MB ttl_seconds: float = 3600.0 # TTL par défaut (1h) enable_weak_refs: bool = True # Utiliser WeakValueDictionary auto_cleanup: bool = True # Nettoyage automatique cleanup_interval: float = 300.0 # Intervalle de nettoyage (5min) class ModelCache: """ Cache persistant des modèles ML avec gestion mémoire intelligente. Tâche 5.3: Évite les rechargements multiples des modèles coûteux. Fonctionnalités: - Cache LRU avec limite de mémoire - TTL configurable par modèle - Nettoyage automatique - Support WeakValueDictionary - Thread-safe """ def __init__(self, config: Optional[ModelCacheConfig] = None): """ Initialiser le cache de modèles. Args: config: Configuration du cache """ self.config = config or ModelCacheConfig() # Cache principal avec ou sans weak references if self.config.enable_weak_refs: self._cache: weakref.WeakValueDictionary = weakref.WeakValueDictionary() else: self._cache: Dict[str, ModelCacheEntry] = {} # Métadonnées du cache (toujours dict normal) self._metadata: Dict[str, Dict[str, Any]] = {} # Thread safety self._lock = threading.RLock() # Stats self._stats = { 'hits': 0, 'misses': 0, 'loads': 0, 'evictions': 0, 'cleanups': 0, 'memory_freed_mb': 0.0 } # Nettoyage automatique self._cleanup_timer: Optional[threading.Timer] = None if self.config.auto_cleanup: self._start_cleanup_timer() logger.info(f"ModelCache initialized (max_models={self.config.max_models}, " f"max_memory={self.config.max_memory_mb}MB)") def get_model(self, model_key: str, loader_func: Callable[[], Any], model_type: str = "unknown", ttl_seconds: Optional[float] = None) -> Any: """ Obtenir un modèle depuis le cache ou le charger. Args: model_key: Clé unique du modèle loader_func: Fonction pour charger le modèle si absent du cache model_type: Type de modèle (pour logging/stats) ttl_seconds: TTL spécifique (utilise config par défaut si None) Returns: Modèle chargé """ with self._lock: # Vérifier le cache if model_key in self._cache: entry = self._cache[model_key] # Vérifier TTL ttl = ttl_seconds or self.config.ttl_seconds if time.time() - entry.load_time < ttl: entry.update_access() self._stats['hits'] += 1 logger.debug(f"Model cache hit: {model_key} ({model_type})") return entry.model else: # TTL expiré logger.debug(f"Model TTL expired: {model_key}") self._remove_model(model_key) # Cache miss - charger le modèle self._stats['misses'] += 1 logger.info(f"Loading model: {model_key} ({model_type})") start_time = time.time() try: model = loader_func() load_time = time.time() - start_time # Estimer la taille mémoire (approximation) memory_size = self._estimate_model_size(model) # Créer l'entrée de cache entry = ModelCacheEntry( model=model, load_time=time.time(), last_access=time.time(), access_count=1, memory_size_mb=memory_size, model_type=model_type ) # Vérifier les limites avant d'ajouter self._ensure_cache_limits(memory_size) # Ajouter au cache self._cache[model_key] = entry self._metadata[model_key] = { 'ttl_seconds': ttl_seconds or self.config.ttl_seconds, 'model_type': model_type, 'load_time_seconds': load_time } self._stats['loads'] += 1 logger.info(f"Model loaded and cached: {model_key} " f"({memory_size:.1f}MB, {load_time:.2f}s)") return model except Exception as e: logger.error(f"Failed to load model {model_key}: {e}") raise def remove_model(self, model_key: str) -> bool: """ Supprimer un modèle du cache. Args: model_key: Clé du modèle à supprimer Returns: True si supprimé, False si non trouvé """ with self._lock: return self._remove_model(model_key) def _remove_model(self, model_key: str) -> bool: """Version interne de remove_model (sans lock)""" if model_key in self._cache: entry = self._cache[model_key] memory_freed = entry.memory_size_mb del self._cache[model_key] self._metadata.pop(model_key, None) self._stats['evictions'] += 1 self._stats['memory_freed_mb'] += memory_freed logger.debug(f"Model evicted: {model_key} ({memory_freed:.1f}MB freed)") return True return False def _ensure_cache_limits(self, new_model_size_mb: float) -> None: """ S'assurer que les limites du cache sont respectées. Args: new_model_size_mb: Taille du nouveau modèle à ajouter """ current_memory = self.get_memory_usage() target_memory = current_memory + new_model_size_mb # Éviction par mémoire if target_memory > self.config.max_memory_mb: logger.info(f"Memory limit would be exceeded ({target_memory:.1f}MB > " f"{self.config.max_memory_mb}MB), evicting models...") self._evict_lru_models(target_memory - self.config.max_memory_mb) # Éviction par nombre de modèles if len(self._cache) >= self.config.max_models: logger.info(f"Model count limit reached ({len(self._cache)} >= " f"{self.config.max_models}), evicting oldest...") self._evict_oldest_model() def _evict_lru_models(self, memory_to_free_mb: float) -> None: """Éviction LRU pour libérer de la mémoire""" if not self._cache: return # Trier par dernier accès (LRU) models_by_access = sorted( self._cache.items(), key=lambda x: x[1].last_access ) freed_memory = 0.0 for model_key, entry in models_by_access: if freed_memory >= memory_to_free_mb: break freed_memory += entry.memory_size_mb self._remove_model(model_key) logger.info(f"LRU eviction freed {freed_memory:.1f}MB") def _evict_oldest_model(self) -> None: """Éviction du modèle le plus ancien""" if not self._cache: return oldest_key = min(self._cache.keys(), key=lambda k: self._cache[k].load_time) self._remove_model(oldest_key) def _estimate_model_size(self, model: Any) -> float: """ Estimer la taille mémoire d'un modèle (approximation). Args: model: Modèle à analyser Returns: Taille estimée en MB """ try: # Pour les modèles PyTorch if hasattr(model, 'parameters'): total_params = sum(p.numel() for p in model.parameters()) # Approximation: 4 bytes par paramètre (float32) return (total_params * 4) / (1024 * 1024) # Pour les modèles scikit-learn if hasattr(model, '__sizeof__'): return model.__sizeof__() / (1024 * 1024) # Fallback générique import sys return sys.getsizeof(model) / (1024 * 1024) except Exception: # Estimation par défaut si échec return 50.0 # 50MB par défaut def cleanup_expired(self) -> int: """ Nettoyer les modèles expirés. Returns: Nombre de modèles supprimés """ with self._lock: current_time = time.time() expired_keys = [] for model_key, entry in self._cache.items(): metadata = self._metadata.get(model_key, {}) ttl = metadata.get('ttl_seconds', self.config.ttl_seconds) if current_time - entry.load_time > ttl: expired_keys.append(model_key) for key in expired_keys: self._remove_model(key) if expired_keys: self._stats['cleanups'] += 1 logger.info(f"Cleanup removed {len(expired_keys)} expired models") return len(expired_keys) def _start_cleanup_timer(self) -> None: """Démarrer le timer de nettoyage automatique""" def cleanup_task(): try: self.cleanup_expired() # Force garbage collection après nettoyage gc.collect() except Exception as e: logger.error(f"Error in cleanup task: {e}") finally: # Reprogrammer le prochain nettoyage if self.config.auto_cleanup: self._cleanup_timer = threading.Timer( self.config.cleanup_interval, cleanup_task ) self._cleanup_timer.daemon = True self._cleanup_timer.start() self._cleanup_timer = threading.Timer( self.config.cleanup_interval, cleanup_task ) self._cleanup_timer.daemon = True self._cleanup_timer.start() def get_memory_usage(self) -> float: """ Obtenir l'utilisation mémoire actuelle du cache. Returns: Mémoire utilisée en MB """ with self._lock: return sum(entry.memory_size_mb for entry in self._cache.values()) def get_stats(self) -> Dict[str, Any]: """Obtenir les statistiques du cache""" with self._lock: return { **self._stats, 'cache_size': len(self._cache), 'memory_usage_mb': self.get_memory_usage(), 'memory_limit_mb': self.config.max_memory_mb, 'model_limit': self.config.max_models } def clear(self) -> None: """Vider complètement le cache""" with self._lock: cache_size = len(self._cache) memory_freed = self.get_memory_usage() self._cache.clear() self._metadata.clear() self._stats['evictions'] += cache_size self._stats['memory_freed_mb'] += memory_freed logger.info(f"Cache cleared: {cache_size} models, {memory_freed:.1f}MB freed") def shutdown(self) -> None: """Arrêter le cache et nettoyer les ressources""" if self._cleanup_timer: self._cleanup_timer.cancel() self._cleanup_timer = None self.clear() logger.info("ModelCache shutdown complete") def __del__(self): """Nettoyage automatique à la destruction""" try: self.shutdown() except Exception: pass # Instance globale du cache de modèles _global_model_cache: Optional[ModelCache] = None def get_global_model_cache() -> ModelCache: """ Obtenir l'instance globale du cache de modèles. Returns: Instance globale du ModelCache """ global _global_model_cache if _global_model_cache is None: _global_model_cache = ModelCache() return _global_model_cache def set_global_model_cache(cache: ModelCache) -> None: """ Définir l'instance globale du cache de modèles. Args: cache: Nouvelle instance de ModelCache """ global _global_model_cache if _global_model_cache: _global_model_cache.shutdown() _global_model_cache = cache