# Geniusia_v2/geniusia2/core/embedders/embedding_manager.py
"""
Embedding manager with model selection, caching, and fallback.
This module provides a high-level interface for generating embeddings,
with automatic model selection, LRU caching, and fallback to CLIP if
the selected model fails to load.
"""

import hashlib
import logging
from typing import Optional, Dict, Any
from collections import OrderedDict
from PIL import Image
import numpy as np
from .base import EmbedderBase
from .clip_embedder import CLIPEmbedder
logger = logging.getLogger(__name__)
class EmbeddingManager:
"""
High-level manager for image embeddings.
Features:
- Model selection (CLIP, Pix2Struct, etc.)
- Automatic fallback to CLIP on errors
    - Configurable LRU cache (default 1000 entries) for performance
- GPU/CPU management
- Logging and monitoring
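
    Example (illustrative sketch; assumes the CLIP backend loads and
    that embed() returns a 1-D vector of get_dimension() entries):
        >>> manager = EmbeddingManager(model_name="clip")
        >>> vec = manager.embed(Image.new("RGB", (224, 224)))
        >>> vec.shape == (manager.get_dimension(),)
        True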
"""
def __init__(
self,
model_name: str = "clip",
fallback_enabled: bool = True,
cache_size: int = 1000,
device: Optional[str] = None
):
"""
Initialize the embedding manager.
Args:
model_name: Model to use ("clip" or "pix2struct")
fallback_enabled: If True, fallback to CLIP on model load failure
cache_size: Maximum number of cached embeddings (LRU eviction)
device: Device to use ('cuda', 'cpu', or None for auto)
Raises:
RuntimeError: If model loading fails and fallback is disabled
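
        Example (sketch; with fallback enabled, a missing Pix2Struct
        dependency degrades gracefully to CLIP):
            >>> manager = EmbeddingManager(model_name="pix2struct",
            ...                            fallback_enabled=True)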
"""
self.model_name = model_name.lower()
self.fallback_enabled = fallback_enabled
self.cache_size = cache_size
self.device = device
# Initialize embedder
self.embedder = self._load_embedder()
# Initialize LRU cache
self._cache: OrderedDict[str, np.ndarray] = OrderedDict()
# Statistics
self._cache_hits = 0
self._cache_misses = 0
logger.info(
f"EmbeddingManager initialized: model={self.embedder.get_model_name()}, "
f"dimension={self.embedder.get_dimension()}, "
f"cache_size={cache_size}"
)
def _load_embedder(self) -> EmbedderBase:
"""
Load the specified embedder with fallback support.
Returns:
EmbedderBase: Loaded embedder instance
Raises:
RuntimeError: If loading fails and fallback is disabled
"""
try:
if self.model_name == "clip":
return CLIPEmbedder(device=self.device)
elif self.model_name == "pix2struct":
# Import here to avoid dependency if not used
try:
from .pix2struct_embedder import Pix2StructEmbedder
return Pix2StructEmbedder(device=self.device)
except ImportError as e:
if self.fallback_enabled:
logger.warning(
f"Pix2Struct not available ({e}), falling back to CLIP"
)
return CLIPEmbedder(device=self.device)
raise
else:
raise ValueError(f"Unknown model: {self.model_name}")
except Exception as e:
if self.fallback_enabled:
logger.warning(
f"Failed to load {self.model_name} ({e}), falling back to CLIP"
)
return CLIPEmbedder(device=self.device)
            raise RuntimeError(f"Failed to load embedder: {e}") from e
def embed(self, image: Image.Image, use_cache: bool = True) -> np.ndarray:
"""
Generate embedding for an image with caching.
Args:
image: PIL Image to embed
use_cache: If True, use cache for identical images
Returns:
np.ndarray: Normalized embedding vector
Raises:
ValueError: If image is invalid
RuntimeError: If embedding generation fails
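
        Example (illustrative; the second call with an identical image
        is served from the LRU cache):
            >>> img = Image.new("RGB", (64, 64))
            >>> v1 = manager.embed(img)
            >>> v2 = manager.embed(img)
            >>> bool(np.array_equal(v1, v2))
            True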
"""
if not isinstance(image, Image.Image):
raise ValueError("Input must be a PIL Image")
# Check cache if enabled
if use_cache:
cache_key = self._get_cache_key(image)
if cache_key in self._cache:
# Move to end (most recently used)
self._cache.move_to_end(cache_key)
self._cache_hits += 1
logger.debug(f"Cache hit (total: {self._cache_hits})")
return self._cache[cache_key]
self._cache_misses += 1
# Generate embedding
embedding = self.embedder.embed(image)
# Store in cache
if use_cache:
self._add_to_cache(cache_key, embedding)
return embedding
def embed_batch(
self,
images: list[Image.Image],
use_cache: bool = True
) -> np.ndarray:
"""
Generate embeddings for multiple images.
This method checks cache for each image individually and only
generates embeddings for cache misses.
Args:
images: List of PIL Images to embed
use_cache: If True, use cache for identical images
Returns:
np.ndarray: Array of embeddings (len(images), dimension)
Raises:
ValueError: If any image is invalid
RuntimeError: If embedding generation fails
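
        Example (sketch; output rows follow the input order):
            >>> imgs = [Image.new("RGB", (64, 64), c) for c in ("red", "blue")]
            >>> embs = manager.embed_batch(imgs)
            >>> embs.shape == (2, manager.get_dimension())
            True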
"""
        if not images:
            # Nothing to embed: return an empty (0, dimension) array.
            # float32 matches the dtype the embedders typically produce.
            return np.empty((0, self.get_dimension()), dtype=np.float32)
embeddings = []
images_to_embed = []
indices_to_embed = []
# Check cache for each image
for i, img in enumerate(images):
if not isinstance(img, Image.Image):
raise ValueError(f"Image at index {i} is not a PIL Image")
if use_cache:
cache_key = self._get_cache_key(img)
if cache_key in self._cache:
self._cache.move_to_end(cache_key)
self._cache_hits += 1
embeddings.append((i, self._cache[cache_key]))
continue
self._cache_misses += 1
# Need to generate embedding
images_to_embed.append(img)
indices_to_embed.append(i)
# Generate embeddings for cache misses
if images_to_embed:
if self.embedder.supports_batch():
new_embeddings = self.embedder.embed_batch(images_to_embed)
else:
new_embeddings = np.array([
self.embedder.embed(img) for img in images_to_embed
])
# Add to cache and results
for img, idx, emb in zip(images_to_embed, indices_to_embed, new_embeddings):
if use_cache:
cache_key = self._get_cache_key(img)
self._add_to_cache(cache_key, emb)
embeddings.append((idx, emb))
# Sort by original index and extract embeddings
embeddings.sort(key=lambda x: x[0])
return np.array([emb for _, emb in embeddings])
    def _get_cache_key(self, image: Image.Image) -> str:
        """
        Generate cache key from image content.
        Uses an MD5 hash of the image mode, size, and raw pixel bytes.
        Mode and size are included because Image.tobytes() alone can
        collide for images whose bytes match but whose shapes differ.
        Args:
            image: PIL Image
        Returns:
            str: Cache key (MD5 hex digest)
        """
        hasher = hashlib.md5()
        hasher.update(image.mode.encode("utf-8"))
        hasher.update(str(image.size).encode("utf-8"))
        hasher.update(image.tobytes())
        return hasher.hexdigest()
def _add_to_cache(self, key: str, embedding: np.ndarray):
"""
Add embedding to cache with LRU eviction.
Args:
key: Cache key
embedding: Embedding to cache
"""
# Add to cache
self._cache[key] = embedding
        # Evict the least recently used entry once capacity is exceeded
        if len(self._cache) > self.cache_size:
            self._cache.popitem(last=False)
            logger.debug(f"Cache eviction: size={len(self._cache)}")
def clear_cache(self):
"""Clear all cached embeddings."""
self._cache.clear()
logger.info("Cache cleared")
def get_dimension(self) -> int:
"""
Get embedding dimension.
Returns:
int: Embedding dimension
"""
return self.embedder.get_dimension()
def get_model_name(self) -> str:
"""
Get current model name.
Returns:
str: Model identifier
"""
return self.embedder.get_model_name()
def get_stats(self) -> Dict[str, Any]:
"""
Get manager statistics.
Returns:
Dict with keys:
- model_name: Current model
- dimension: Embedding dimension
- cache_size: Current cache size
- cache_capacity: Maximum cache size
- cache_hits: Number of cache hits
- cache_misses: Number of cache misses
- cache_hit_rate: Hit rate (0-1)
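
        Example (illustrative values only):
            >>> manager.get_stats()  # doctest: +SKIP
            {'model_name': 'clip', 'dimension': 512, 'cache_size': 2, ...}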
"""
total_requests = self._cache_hits + self._cache_misses
hit_rate = self._cache_hits / total_requests if total_requests > 0 else 0.0
return {
'model_name': self.get_model_name(),
'dimension': self.get_dimension(),
'cache_size': len(self._cache),
'cache_capacity': self.cache_size,
'cache_hits': self._cache_hits,
'cache_misses': self._cache_misses,
'cache_hit_rate': hit_rate
}
def __repr__(self) -> str:
"""String representation."""
stats = self.get_stats()
return (
f"EmbeddingManager(model={stats['model_name']}, "
f"cache={stats['cache_size']}/{stats['cache_capacity']}, "
f"hit_rate={stats['cache_hit_rate']:.2%})"
)
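

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; requires the dependencies
    # that CLIPEmbedder needs, e.g. a local CLIP model). Embeds two
    # in-memory test images, re-embeds one to show a cache hit, then
    # prints the manager's statistics.
    logging.basicConfig(level=logging.INFO)
    manager = EmbeddingManager(model_name="clip", cache_size=100)
    red = Image.new("RGB", (224, 224), "red")
    blue = Image.new("RGB", (224, 224), "blue")
    batch = manager.embed_batch([red, blue])  # two cache misses
    single = manager.embed(red)               # cache hit for the red image
    print(f"batch: {batch.shape}, single: {single.shape}")
    print(manager)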