""" ComputationCache - Cache intelligent pour calculs redondants Tâche 5.4: Optimiser les calculs redondants dans TargetResolver. Cache les calculs de distance, alignement et relations spatiales. Auteur : Dom, Alice Kiro - 20 décembre 2024 """ import logging import time from typing import Dict, Tuple, Any, Optional, Callable from dataclasses import dataclass from functools import lru_cache import hashlib logger = logging.getLogger(__name__) @dataclass class ComputationCacheStats: """Statistiques du cache de calculs""" hits: int = 0 misses: int = 0 total_time_saved_ms: float = 0.0 cache_size: int = 0 class ComputationCache: """ Cache intelligent pour calculs redondants. Tâche 5.4: Évite les recalculs coûteux de distance, alignement, etc. Réutilise les résultats entre les résolutions d'ancres multiples. """ def __init__(self, max_size: int = 1000): """ Initialiser le cache de calculs. Args: max_size: Taille maximale du cache """ self.max_size = max_size # Caches spécialisés self._distance_cache: Dict[str, float] = {} self._alignment_cache: Dict[str, float] = {} self._spatial_relation_cache: Dict[str, bool] = {} self._bbox_operation_cache: Dict[str, Any] = {} # Stats self._stats = ComputationCacheStats() logger.debug(f"ComputationCache initialized (max_size={max_size})") def _make_key(self, *args) -> str: """ Créer une clé de cache depuis des arguments. Args: *args: Arguments à hasher Returns: Clé de cache unique """ # Convertir les arguments en string hashable key_parts = [] for arg in args: if hasattr(arg, 'element_id'): key_parts.append(arg.element_id) elif isinstance(arg, (tuple, list)): key_parts.append(str(tuple(arg))) else: key_parts.append(str(arg)) key_str = '|'.join(key_parts) # Hasher pour clés longues if len(key_str) > 100: return hashlib.md5(key_str.encode()).hexdigest() return key_str def get_distance(self, elem1_id: str, elem2_id: str, compute_func: Callable[[], float]) -> float: """ Obtenir la distance entre deux éléments avec cache. Args: elem1_id: ID du premier élément elem2_id: ID du deuxième élément compute_func: Fonction pour calculer la distance si absent du cache Returns: Distance calculée ou depuis le cache """ # Clé symétrique (distance(A,B) = distance(B,A)) key = self._make_key(min(elem1_id, elem2_id), max(elem1_id, elem2_id), 'dist') if key in self._distance_cache: self._stats.hits += 1 return self._distance_cache[key] # Cache miss - calculer self._stats.misses += 1 start_time = time.perf_counter() distance = compute_func() compute_time = (time.perf_counter() - start_time) * 1000 self._stats.total_time_saved_ms += compute_time # Ajouter au cache avec éviction si nécessaire self._distance_cache[key] = distance self._ensure_cache_size(self._distance_cache) return distance def get_alignment_score(self, elem_id: str, anchor_id: str, hint_type: str, compute_func: Callable[[], float]) -> float: """ Obtenir le score d'alignement avec cache. Args: elem_id: ID de l'élément anchor_id: ID de l'ancre hint_type: Type de hint (below, right_of, etc.) compute_func: Fonction pour calculer l'alignement Returns: Score d'alignement """ key = self._make_key(elem_id, anchor_id, hint_type, 'align') if key in self._alignment_cache: self._stats.hits += 1 return self._alignment_cache[key] # Cache miss self._stats.misses += 1 start_time = time.perf_counter() score = compute_func() compute_time = (time.perf_counter() - start_time) * 1000 self._stats.total_time_saved_ms += compute_time self._alignment_cache[key] = score self._ensure_cache_size(self._alignment_cache) return score def get_spatial_relation(self, elem_id: str, anchor_id: str, relation_type: str, compute_func: Callable[[], bool]) -> bool: """ Obtenir une relation spatiale avec cache. Args: elem_id: ID de l'élément anchor_id: ID de l'ancre relation_type: Type de relation (below, above, etc.) compute_func: Fonction pour calculer la relation Returns: True si la relation est vérifiée """ key = self._make_key(elem_id, anchor_id, relation_type, 'spatial') if key in self._spatial_relation_cache: self._stats.hits += 1 return self._spatial_relation_cache[key] # Cache miss self._stats.misses += 1 result = compute_func() self._spatial_relation_cache[key] = result self._ensure_cache_size(self._spatial_relation_cache) return result def get_bbox_operation(self, operation: str, *bbox_ids, compute_func: Callable[[], Any]) -> Any: """ Obtenir le résultat d'une opération bbox avec cache. Args: operation: Type d'opération (intersection, union, contains, etc.) *bbox_ids: IDs des bboxes impliquées compute_func: Fonction pour calculer l'opération Returns: Résultat de l'opération """ key = self._make_key(operation, *bbox_ids, 'bbox_op') if key in self._bbox_operation_cache: self._stats.hits += 1 return self._bbox_operation_cache[key] # Cache miss self._stats.misses += 1 start_time = time.perf_counter() result = compute_func() compute_time = (time.perf_counter() - start_time) * 1000 self._stats.total_time_saved_ms += compute_time self._bbox_operation_cache[key] = result self._ensure_cache_size(self._bbox_operation_cache) return result def _ensure_cache_size(self, cache: Dict) -> None: """ S'assurer que le cache ne dépasse pas la taille max. Args: cache: Cache à vérifier """ if len(cache) > self.max_size: # Éviction FIFO simple (supprimer les plus anciennes entrées) keys_to_remove = list(cache.keys())[:len(cache) - self.max_size] for key in keys_to_remove: del cache[key] def clear(self) -> None: """Vider tous les caches""" self._distance_cache.clear() self._alignment_cache.clear() self._spatial_relation_cache.clear() self._bbox_operation_cache.clear() logger.debug("ComputationCache cleared") def get_stats(self) -> Dict[str, Any]: """Obtenir les statistiques du cache""" total_cache_size = ( len(self._distance_cache) + len(self._alignment_cache) + len(self._spatial_relation_cache) + len(self._bbox_operation_cache) ) total_requests = self._stats.hits + self._stats.misses hit_rate = (self._stats.hits / total_requests * 100) if total_requests > 0 else 0.0 return { 'hits': self._stats.hits, 'misses': self._stats.misses, 'hit_rate_percent': round(hit_rate, 2), 'total_time_saved_ms': round(self._stats.total_time_saved_ms, 2), 'cache_sizes': { 'distance': len(self._distance_cache), 'alignment': len(self._alignment_cache), 'spatial_relation': len(self._spatial_relation_cache), 'bbox_operation': len(self._bbox_operation_cache), 'total': total_cache_size }, 'max_size': self.max_size } # Fonctions utilitaires avec cache LRU intégré @lru_cache(maxsize=512) def cached_bbox_center(bbox_tuple: Tuple[int, int, int, int]) -> Tuple[float, float]: """ Calculer le centre d'une bbox avec cache LRU. Args: bbox_tuple: (x, y, w, h) Returns: (center_x, center_y) """ x, y, w, h = bbox_tuple return (float(x + w / 2), float(y + h / 2)) @lru_cache(maxsize=512) def cached_bbox_area(bbox_tuple: Tuple[int, int, int, int]) -> float: """ Calculer l'aire d'une bbox avec cache LRU. Args: bbox_tuple: (x, y, w, h) Returns: Aire en pixels """ x, y, w, h = bbox_tuple return float(w * h) @lru_cache(maxsize=512) def cached_bbox_iou(bbox1: Tuple[int, int, int, int], bbox2: Tuple[int, int, int, int]) -> float: """ Calculer l'IoU entre deux bboxes avec cache LRU. Args: bbox1: (x, y, w, h) bbox2: (x, y, w, h) Returns: IoU dans [0, 1] """ x1, y1, w1, h1 = bbox1 x2, y2, w2, h2 = bbox2 # Intersection x_left = max(x1, x2) y_top = max(y1, y2) x_right = min(x1 + w1, x2 + w2) y_bottom = min(y1 + h1, y2 + h2) if x_right < x_left or y_bottom < y_top: return 0.0 intersection = (x_right - x_left) * (y_bottom - y_top) # Union area1 = w1 * h1 area2 = w2 * h2 union = area1 + area2 - intersection return float(intersection / union) if union > 0 else 0.0 @lru_cache(maxsize=512) def cached_euclidean_distance(point1: Tuple[float, float], point2: Tuple[float, float]) -> float: """ Calculer la distance euclidienne avec cache LRU. Args: point1: (x1, y1) point2: (x2, y2) Returns: Distance euclidienne """ x1, y1 = point1 x2, y2 = point2 return float(((x2 - x1) ** 2 + (y2 - y1) ** 2) ** 0.5) def clear_all_lru_caches() -> None: """Vider tous les caches LRU des fonctions utilitaires""" cached_bbox_center.cache_clear() cached_bbox_area.cache_clear() cached_bbox_iou.cache_clear() cached_euclidean_distance.cache_clear() logger.debug("All LRU caches cleared")