"""
|
|
Embedding manager with model selection, caching, and fallback.
|
|
|
|
This module provides a high-level interface for generating embeddings,
|
|
with automatic model selection, LRU caching, and fallback to CLIP if
|
|
the selected model fails to load.
|
|
"""

import hashlib
import logging
from typing import Optional, Dict, Any
from collections import OrderedDict

from PIL import Image
import numpy as np

from .base import EmbedderBase
from .clip_embedder import CLIPEmbedder


logger = logging.getLogger(__name__)


class EmbeddingManager:
    """
    High-level manager for image embeddings.

    Features:
    - Model selection (CLIP, Pix2Struct, etc.)
    - Automatic fallback to CLIP on errors
    - LRU cache (1000 entries) for performance
    - GPU/CPU management
    - Logging and monitoring
    """

    def __init__(
        self,
        model_name: str = "clip",
        fallback_enabled: bool = True,
        cache_size: int = 1000,
        device: Optional[str] = None
    ):
        """
        Initialize the embedding manager.

        Args:
            model_name: Model to use ("clip" or "pix2struct")
            fallback_enabled: If True, fallback to CLIP on model load failure
            cache_size: Maximum number of cached embeddings (LRU eviction)
            device: Device to use ('cuda', 'cpu', or None for auto)

        Raises:
            RuntimeError: If model loading fails and fallback is disabled
        """
        self.model_name = model_name.lower()
        self.fallback_enabled = fallback_enabled
        self.cache_size = cache_size
        self.device = device

        # Load the requested embedder (may transparently fall back to CLIP).
        self.embedder = self._load_embedder()

        # LRU cache: insertion order tracks recency; the oldest entry is first.
        self._cache: OrderedDict[str, np.ndarray] = OrderedDict()

        # Cache statistics; hits/misses are only counted when use_cache=True.
        self._cache_hits = 0
        self._cache_misses = 0

        # Lazy %-style args: formatting is skipped when INFO is disabled.
        logger.info(
            "EmbeddingManager initialized: model=%s, dimension=%s, cache_size=%s",
            self.embedder.get_model_name(),
            self.embedder.get_dimension(),
            cache_size,
        )

    def _load_embedder(self) -> EmbedderBase:
        """
        Load the specified embedder with fallback support.

        Returns:
            EmbedderBase: Loaded embedder instance

        Raises:
            RuntimeError: If loading fails and fallback is disabled
        """
        try:
            if self.model_name == "clip":
                return CLIPEmbedder(device=self.device)

            elif self.model_name == "pix2struct":
                # Import here to avoid the dependency if not used.
                try:
                    from .pix2struct_embedder import Pix2StructEmbedder
                    return Pix2StructEmbedder(device=self.device)
                except ImportError as e:
                    if self.fallback_enabled:
                        logger.warning(
                            "Pix2Struct not available (%s), falling back to CLIP", e
                        )
                        return CLIPEmbedder(device=self.device)
                    # Re-raise; the outer handler converts to RuntimeError.
                    raise

            else:
                raise ValueError(f"Unknown model: {self.model_name}")

        except Exception as e:
            if self.fallback_enabled:
                logger.warning(
                    "Failed to load %s (%s), falling back to CLIP",
                    self.model_name, e,
                )
                return CLIPEmbedder(device=self.device)
            # Chain the original exception so the root cause is preserved.
            raise RuntimeError(f"Failed to load embedder: {e}") from e

    def embed(self, image: Image.Image, use_cache: bool = True) -> np.ndarray:
        """
        Generate embedding for an image with caching.

        NOTE: on a cache hit the cached array itself is returned (not a
        copy); callers must not mutate the result in place.

        Args:
            image: PIL Image to embed
            use_cache: If True, use cache for identical images

        Returns:
            np.ndarray: Normalized embedding vector

        Raises:
            ValueError: If image is invalid
            RuntimeError: If embedding generation fails
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL Image")

        # Explicitly initialized so the cache-store step below can never
        # reference an unbound name when use_cache is False.
        cache_key = None
        if use_cache:
            cache_key = self._get_cache_key(image)
            if cache_key in self._cache:
                # Move to end (most recently used).
                self._cache.move_to_end(cache_key)
                self._cache_hits += 1
                logger.debug("Cache hit (total: %s)", self._cache_hits)
                return self._cache[cache_key]
            self._cache_misses += 1

        # Generate embedding.
        embedding = self.embedder.embed(image)

        # Store in cache.
        if cache_key is not None:
            self._add_to_cache(cache_key, embedding)

        return embedding

    def embed_batch(
        self,
        images: list[Image.Image],
        use_cache: bool = True
    ) -> np.ndarray:
        """
        Generate embeddings for multiple images.

        This method checks the cache for each image individually and only
        generates embeddings for the cache misses (batched when the
        embedder supports it).

        Args:
            images: List of PIL Images to embed
            use_cache: If True, use cache for identical images

        Returns:
            np.ndarray: Array of embeddings (len(images), dimension)

        Raises:
            ValueError: If any image is invalid
            RuntimeError: If embedding generation fails
        """
        if not images:
            # Keep the (0, dimension) shape so callers can always index axis 1.
            return np.array([]).reshape(0, self.get_dimension())

        embeddings = []          # list of (original_index, embedding)
        images_to_embed = []
        indices_to_embed = []

        # Resolve cache hits first; collect misses for a single model pass.
        for i, img in enumerate(images):
            if not isinstance(img, Image.Image):
                raise ValueError(f"Image at index {i} is not a PIL Image")

            if use_cache:
                cache_key = self._get_cache_key(img)
                if cache_key in self._cache:
                    self._cache.move_to_end(cache_key)
                    self._cache_hits += 1
                    embeddings.append((i, self._cache[cache_key]))
                    continue

                self._cache_misses += 1

            # Need to generate an embedding for this image.
            images_to_embed.append(img)
            indices_to_embed.append(i)

        # Generate embeddings for the cache misses.
        if images_to_embed:
            if self.embedder.supports_batch():
                new_embeddings = self.embedder.embed_batch(images_to_embed)
            else:
                # Fall back to one-by-one embedding for non-batching models.
                new_embeddings = np.array([
                    self.embedder.embed(img) for img in images_to_embed
                ])

            # Add to cache and results.
            for img, idx, emb in zip(images_to_embed, indices_to_embed, new_embeddings):
                if use_cache:
                    self._add_to_cache(self._get_cache_key(img), emb)
                embeddings.append((idx, emb))

        # Restore the caller's original input order.
        embeddings.sort(key=lambda x: x[0])
        return np.array([emb for _, emb in embeddings])

    def _get_cache_key(self, image: Image.Image) -> str:
        """
        Generate a cache key from image content.

        The image's mode and size are hashed together with the raw pixel
        bytes: ``Image.tobytes()`` alone can be identical for images with
        the same pixel data but different dimensions or modes (e.g. a 2x8
        and a 4x4 grayscale image), which would cause cache collisions and
        return the wrong embedding.

        MD5 is chosen for speed; this is not a security context.

        Args:
            image: PIL Image

        Returns:
            str: Cache key (MD5 hex digest)
        """
        header = f"{image.mode}:{image.size}:".encode("utf-8")
        return hashlib.md5(header + image.tobytes()).hexdigest()

    def _add_to_cache(self, key: str, embedding: np.ndarray):
        """
        Add embedding to cache with LRU eviction.

        Args:
            key: Cache key
            embedding: Embedding to cache
        """
        # Add to cache.
        self._cache[key] = embedding

        # Evict the least recently used entry when over capacity.
        if len(self._cache) > self.cache_size:
            oldest_key = next(iter(self._cache))
            del self._cache[oldest_key]
            logger.debug("Cache eviction: size=%s", len(self._cache))

    def clear_cache(self):
        """Clear all cached embeddings."""
        self._cache.clear()
        logger.info("Cache cleared")

    def get_dimension(self) -> int:
        """
        Get embedding dimension.

        Returns:
            int: Embedding dimension
        """
        return self.embedder.get_dimension()

    def get_model_name(self) -> str:
        """
        Get current model name.

        Returns:
            str: Model identifier
        """
        return self.embedder.get_model_name()

    def get_stats(self) -> Dict[str, Any]:
        """
        Get manager statistics.

        Returns:
            Dict with keys:
            - model_name: Current model
            - dimension: Embedding dimension
            - cache_size: Current cache size
            - cache_capacity: Maximum cache size
            - cache_hits: Number of cache hits
            - cache_misses: Number of cache misses
            - cache_hit_rate: Hit rate (0-1)
        """
        total_requests = self._cache_hits + self._cache_misses
        hit_rate = self._cache_hits / total_requests if total_requests > 0 else 0.0

        return {
            'model_name': self.get_model_name(),
            'dimension': self.get_dimension(),
            'cache_size': len(self._cache),
            'cache_capacity': self.cache_size,
            'cache_hits': self._cache_hits,
            'cache_misses': self._cache_misses,
            'cache_hit_rate': hit_rate
        }

    def __repr__(self) -> str:
        """Concise one-line summary for debugging and logging."""
        stats = self.get_stats()
        return (
            f"EmbeddingManager(model={stats['model_name']}, "
            f"cache={stats['cache_size']}/{stats['cache_capacity']}, "
            f"hit_rate={stats['cache_hit_rate']:.2%})"
        )