""" Embedding Cache - Cache LRU pour embeddings Implémente un cache LRU (Least Recently Used) pour stocker les embeddings en mémoire et éviter les recalculs coûteux. """ import logging from typing import Optional, Dict, Any from collections import OrderedDict import numpy as np from datetime import datetime logger = logging.getLogger(__name__) class EmbeddingCache: """ Cache LRU pour embeddings. Stocke les embeddings les plus récemment utilisés en mémoire pour éviter les recalculs et chargements depuis disque. Features: - LRU eviction policy - Taille maximale configurable - Statistiques de cache (hits/misses) - Invalidation sélective """ def __init__(self, max_size: int = 1000, max_memory_mb: float = 500.0): """ Initialiser le cache. Args: max_size: Nombre maximum d'embeddings à garder en cache max_memory_mb: Mémoire maximale en MB (approximatif) """ self.max_size = max_size self.max_memory_mb = max_memory_mb self.cache: OrderedDict[str, np.ndarray] = OrderedDict() self.metadata: Dict[str, Dict[str, Any]] = {} # Statistiques self.hits = 0 self.misses = 0 self.evictions = 0 logger.info( f"EmbeddingCache initialized: max_size={max_size}, " f"max_memory_mb={max_memory_mb:.1f}" ) def get(self, key: str) -> Optional[np.ndarray]: """ Récupérer un embedding du cache. Args: key: Clé de l'embedding (embedding_id) Returns: Vecteur numpy si trouvé, None sinon """ if key in self.cache: # Déplacer à la fin (most recently used) self.cache.move_to_end(key) self.hits += 1 logger.debug(f"Cache HIT: {key}") return self.cache[key] self.misses += 1 logger.debug(f"Cache MISS: {key}") return None def put( self, key: str, vector: np.ndarray, metadata: Optional[Dict[str, Any]] = None ): """ Ajouter un embedding au cache. Args: key: Clé de l'embedding vector: Vecteur numpy metadata: Métadonnées optionnelles """ # Si déjà présent, mettre à jour et déplacer à la fin if key in self.cache: self.cache.move_to_end(key) self.cache[key] = vector if metadata: self.metadata[key] = metadata return # Vérifier si on doit évict if len(self.cache) >= self.max_size: self._evict_oldest() # Ajouter le nouvel embedding self.cache[key] = vector if metadata: self.metadata[key] = metadata logger.debug(f"Cache PUT: {key} (size: {len(self.cache)})") def _evict_oldest(self): """Évict l'embedding le moins récemment utilisé.""" if not self.cache: return # Retirer le premier élément (oldest) oldest_key, _ = self.cache.popitem(last=False) self.metadata.pop(oldest_key, None) self.evictions += 1 logger.debug(f"Cache EVICT: {oldest_key} (evictions: {self.evictions})") def invalidate(self, key: str): """ Invalider un embedding spécifique. Args: key: Clé de l'embedding à invalider """ if key in self.cache: del self.cache[key] self.metadata.pop(key, None) logger.debug(f"Cache INVALIDATE: {key}") def invalidate_pattern(self, pattern: str): """ Invalider tous les embeddings dont la clé contient le pattern. Args: pattern: Pattern à rechercher dans les clés """ keys_to_remove = [k for k in self.cache.keys() if pattern in k] for key in keys_to_remove: del self.cache[key] self.metadata.pop(key, None) if keys_to_remove: logger.info(f"Cache INVALIDATE PATTERN '{pattern}': {len(keys_to_remove)} entries") def clear(self): """Vider complètement le cache.""" size_before = len(self.cache) self.cache.clear() self.metadata.clear() logger.info(f"Cache CLEAR: {size_before} entries removed") def get_stats(self) -> Dict[str, Any]: """ Obtenir les statistiques du cache. Returns: Dict avec statistiques """ total_requests = self.hits + self.misses hit_rate = self.hits / total_requests if total_requests > 0 else 0.0 # Estimer la mémoire utilisée memory_mb = 0.0 for vector in self.cache.values(): # Taille en bytes = nombre d'éléments * taille d'un float32 memory_mb += vector.nbytes / (1024 * 1024) return { "size": len(self.cache), "max_size": self.max_size, "hits": self.hits, "misses": self.misses, "evictions": self.evictions, "hit_rate": hit_rate, "memory_mb": memory_mb, "max_memory_mb": self.max_memory_mb, "memory_usage_pct": (memory_mb / self.max_memory_mb * 100) if self.max_memory_mb > 0 else 0.0 } def __len__(self) -> int: """Retourne le nombre d'embeddings en cache.""" return len(self.cache) def __contains__(self, key: str) -> bool: """Vérifie si une clé est dans le cache.""" return key in self.cache class PrototypeCache: """ Cache spécialisé pour les prototypes de WorkflowNodes. Les prototypes sont utilisés fréquemment pour le matching, donc on les garde en cache avec une politique différente. """ def __init__(self, max_size: int = 100): """ Initialiser le cache de prototypes. Args: max_size: Nombre maximum de prototypes à garder """ self.max_size = max_size self.cache: Dict[str, np.ndarray] = {} self.access_count: Dict[str, int] = {} self.last_access: Dict[str, datetime] = {} logger.info(f"PrototypeCache initialized: max_size={max_size}") def get(self, node_id: str) -> Optional[np.ndarray]: """ Récupérer un prototype du cache. Args: node_id: ID du WorkflowNode Returns: Vecteur prototype si trouvé, None sinon """ if node_id in self.cache: self.access_count[node_id] = self.access_count.get(node_id, 0) + 1 self.last_access[node_id] = datetime.now() return self.cache[node_id] return None def put(self, node_id: str, prototype: np.ndarray): """ Ajouter un prototype au cache. Args: node_id: ID du WorkflowNode prototype: Vecteur prototype """ # Si cache plein, évict le moins utilisé if len(self.cache) >= self.max_size and node_id not in self.cache: self._evict_least_used() self.cache[node_id] = prototype self.access_count[node_id] = self.access_count.get(node_id, 0) + 1 self.last_access[node_id] = datetime.now() def _evict_least_used(self): """Évict le prototype le moins utilisé.""" if not self.cache: return # Trouver le moins utilisé least_used = min(self.access_count.items(), key=lambda x: x[1]) node_id = least_used[0] del self.cache[node_id] del self.access_count[node_id] del self.last_access[node_id] logger.debug(f"PrototypeCache EVICT: {node_id}") def invalidate(self, node_id: str): """Invalider un prototype spécifique.""" if node_id in self.cache: del self.cache[node_id] self.access_count.pop(node_id, None) self.last_access.pop(node_id, None) def clear(self): """Vider le cache.""" self.cache.clear() self.access_count.clear() self.last_access.clear() def get_stats(self) -> Dict[str, Any]: """Obtenir les statistiques du cache.""" total_accesses = sum(self.access_count.values()) avg_accesses = total_accesses / len(self.cache) if self.cache else 0.0 return { "size": len(self.cache), "max_size": self.max_size, "total_accesses": total_accesses, "avg_accesses_per_prototype": avg_accesses }