""" Embedding manager with model selection, caching, and fallback. This module provides a high-level interface for generating embeddings, with automatic model selection, LRU caching, and fallback to CLIP if the selected model fails to load. """ import hashlib import logging from typing import Optional, Dict, Any from collections import OrderedDict from PIL import Image import numpy as np from .base import EmbedderBase from .clip_embedder import CLIPEmbedder logger = logging.getLogger(__name__) class EmbeddingManager: """ High-level manager for image embeddings. Features: - Model selection (CLIP, Pix2Struct, etc.) - Automatic fallback to CLIP on errors - LRU cache (1000 entries) for performance - GPU/CPU management - Logging and monitoring """ def __init__( self, model_name: str = "clip", fallback_enabled: bool = True, cache_size: int = 1000, device: Optional[str] = None ): """ Initialize the embedding manager. Args: model_name: Model to use ("clip" or "pix2struct") fallback_enabled: If True, fallback to CLIP on model load failure cache_size: Maximum number of cached embeddings (LRU eviction) device: Device to use ('cuda', 'cpu', or None for auto) Raises: RuntimeError: If model loading fails and fallback is disabled """ self.model_name = model_name.lower() self.fallback_enabled = fallback_enabled self.cache_size = cache_size self.device = device # Initialize embedder self.embedder = self._load_embedder() # Initialize LRU cache self._cache: OrderedDict[str, np.ndarray] = OrderedDict() # Statistics self._cache_hits = 0 self._cache_misses = 0 logger.info( f"EmbeddingManager initialized: model={self.embedder.get_model_name()}, " f"dimension={self.embedder.get_dimension()}, " f"cache_size={cache_size}" ) def _load_embedder(self) -> EmbedderBase: """ Load the specified embedder with fallback support. Returns: EmbedderBase: Loaded embedder instance Raises: RuntimeError: If loading fails and fallback is disabled """ try: if self.model_name == "clip": return CLIPEmbedder(device=self.device) elif self.model_name == "pix2struct": # Import here to avoid dependency if not used try: from .pix2struct_embedder import Pix2StructEmbedder return Pix2StructEmbedder(device=self.device) except ImportError as e: if self.fallback_enabled: logger.warning( f"Pix2Struct not available ({e}), falling back to CLIP" ) return CLIPEmbedder(device=self.device) raise else: raise ValueError(f"Unknown model: {self.model_name}") except Exception as e: if self.fallback_enabled: logger.warning( f"Failed to load {self.model_name} ({e}), falling back to CLIP" ) return CLIPEmbedder(device=self.device) raise RuntimeError(f"Failed to load embedder: {e}") def embed(self, image: Image.Image, use_cache: bool = True) -> np.ndarray: """ Generate embedding for an image with caching. Args: image: PIL Image to embed use_cache: If True, use cache for identical images Returns: np.ndarray: Normalized embedding vector Raises: ValueError: If image is invalid RuntimeError: If embedding generation fails """ if not isinstance(image, Image.Image): raise ValueError("Input must be a PIL Image") # Check cache if enabled if use_cache: cache_key = self._get_cache_key(image) if cache_key in self._cache: # Move to end (most recently used) self._cache.move_to_end(cache_key) self._cache_hits += 1 logger.debug(f"Cache hit (total: {self._cache_hits})") return self._cache[cache_key] self._cache_misses += 1 # Generate embedding embedding = self.embedder.embed(image) # Store in cache if use_cache: self._add_to_cache(cache_key, embedding) return embedding def embed_batch( self, images: list[Image.Image], use_cache: bool = True ) -> np.ndarray: """ Generate embeddings for multiple images. This method checks cache for each image individually and only generates embeddings for cache misses. Args: images: List of PIL Images to embed use_cache: If True, use cache for identical images Returns: np.ndarray: Array of embeddings (len(images), dimension) Raises: ValueError: If any image is invalid RuntimeError: If embedding generation fails """ if not images: return np.array([]).reshape(0, self.get_dimension()) embeddings = [] images_to_embed = [] indices_to_embed = [] # Check cache for each image for i, img in enumerate(images): if not isinstance(img, Image.Image): raise ValueError(f"Image at index {i} is not a PIL Image") if use_cache: cache_key = self._get_cache_key(img) if cache_key in self._cache: self._cache.move_to_end(cache_key) self._cache_hits += 1 embeddings.append((i, self._cache[cache_key])) continue self._cache_misses += 1 # Need to generate embedding images_to_embed.append(img) indices_to_embed.append(i) # Generate embeddings for cache misses if images_to_embed: if self.embedder.supports_batch(): new_embeddings = self.embedder.embed_batch(images_to_embed) else: new_embeddings = np.array([ self.embedder.embed(img) for img in images_to_embed ]) # Add to cache and results for img, idx, emb in zip(images_to_embed, indices_to_embed, new_embeddings): if use_cache: cache_key = self._get_cache_key(img) self._add_to_cache(cache_key, emb) embeddings.append((idx, emb)) # Sort by original index and extract embeddings embeddings.sort(key=lambda x: x[0]) return np.array([emb for _, emb in embeddings]) def _get_cache_key(self, image: Image.Image) -> str: """ Generate cache key from image content. Uses MD5 hash of image bytes for fast lookup. Args: image: PIL Image Returns: str: Cache key (MD5 hash) """ return hashlib.md5(image.tobytes()).hexdigest() def _add_to_cache(self, key: str, embedding: np.ndarray): """ Add embedding to cache with LRU eviction. Args: key: Cache key embedding: Embedding to cache """ # Add to cache self._cache[key] = embedding # Evict oldest if cache is full if len(self._cache) > self.cache_size: oldest_key = next(iter(self._cache)) del self._cache[oldest_key] logger.debug(f"Cache eviction: size={len(self._cache)}") def clear_cache(self): """Clear all cached embeddings.""" self._cache.clear() logger.info("Cache cleared") def get_dimension(self) -> int: """ Get embedding dimension. Returns: int: Embedding dimension """ return self.embedder.get_dimension() def get_model_name(self) -> str: """ Get current model name. Returns: str: Model identifier """ return self.embedder.get_model_name() def get_stats(self) -> Dict[str, Any]: """ Get manager statistics. Returns: Dict with keys: - model_name: Current model - dimension: Embedding dimension - cache_size: Current cache size - cache_capacity: Maximum cache size - cache_hits: Number of cache hits - cache_misses: Number of cache misses - cache_hit_rate: Hit rate (0-1) """ total_requests = self._cache_hits + self._cache_misses hit_rate = self._cache_hits / total_requests if total_requests > 0 else 0.0 return { 'model_name': self.get_model_name(), 'dimension': self.get_dimension(), 'cache_size': len(self._cache), 'cache_capacity': self.cache_size, 'cache_hits': self._cache_hits, 'cache_misses': self._cache_misses, 'cache_hit_rate': hit_rate } def __repr__(self) -> str: """String representation.""" stats = self.get_stats() return ( f"EmbeddingManager(model={stats['model_name']}, " f"cache={stats['cache_size']}/{stats['cache_capacity']}, " f"hit_rate={stats['cache_hit_rate']:.2%})" )