Initial commit
geniusia2/core/embedders/embedding_manager.py (new file, 309 lines)
@@ -0,0 +1,309 @@
"""
Embedding manager with model selection, caching, and fallback.

This module provides a high-level interface for generating embeddings,
with automatic model selection, LRU caching, and fallback to CLIP if
the selected model fails to load.
"""

import hashlib
import logging
from collections import OrderedDict
from typing import Any, Dict, Optional

import numpy as np
from PIL import Image

from .base import EmbedderBase
from .clip_embedder import CLIPEmbedder

logger = logging.getLogger(__name__)


class EmbeddingManager:
    """
    High-level manager for image embeddings.

    Features:
    - Model selection (CLIP, Pix2Struct, etc.)
    - Automatic fallback to CLIP on errors
    - LRU cache (default 1000 entries) for performance
    - GPU/CPU management
    - Logging and monitoring
    """

    def __init__(
        self,
        model_name: str = "clip",
        fallback_enabled: bool = True,
        cache_size: int = 1000,
        device: Optional[str] = None
    ):
        """
        Initialize the embedding manager.

        Args:
            model_name: Model to use ("clip" or "pix2struct")
            fallback_enabled: If True, fall back to CLIP on model load failure
            cache_size: Maximum number of cached embeddings (LRU eviction)
            device: Device to use ('cuda', 'cpu', or None for auto)

        Raises:
            RuntimeError: If model loading fails and fallback is disabled
        """
        self.model_name = model_name.lower()
        self.fallback_enabled = fallback_enabled
        self.cache_size = cache_size
        self.device = device

        # Initialize embedder
        self.embedder = self._load_embedder()

        # Initialize LRU cache
        self._cache: OrderedDict[str, np.ndarray] = OrderedDict()

        # Statistics
        self._cache_hits = 0
        self._cache_misses = 0

        logger.info(
            f"EmbeddingManager initialized: model={self.embedder.get_model_name()}, "
            f"dimension={self.embedder.get_dimension()}, "
            f"cache_size={cache_size}"
        )

    def _load_embedder(self) -> EmbedderBase:
        """
        Load the specified embedder with fallback support.

        Returns:
            EmbedderBase: Loaded embedder instance

        Raises:
            RuntimeError: If loading fails and fallback is disabled
        """
        try:
            if self.model_name == "clip":
                return CLIPEmbedder(device=self.device)

            elif self.model_name == "pix2struct":
                # Import here to avoid the dependency if not used
                try:
                    from .pix2struct_embedder import Pix2StructEmbedder
                    return Pix2StructEmbedder(device=self.device)
                except ImportError as e:
                    if self.fallback_enabled:
                        logger.warning(
                            f"Pix2Struct not available ({e}), falling back to CLIP"
                        )
                        return CLIPEmbedder(device=self.device)
                    raise

            else:
                raise ValueError(f"Unknown model: {self.model_name}")

        except Exception as e:
            if self.fallback_enabled:
                logger.warning(
                    f"Failed to load {self.model_name} ({e}), falling back to CLIP"
                )
                return CLIPEmbedder(device=self.device)
            raise RuntimeError(f"Failed to load embedder: {e}") from e
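
    # Note (added commentary): a ValueError for an unknown model name is also
    # caught by the broad `except Exception` above, so with fallback enabled an
    # unrecognized model degrades to CLIP (with a warning) rather than raising.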

    def embed(self, image: Image.Image, use_cache: bool = True) -> np.ndarray:
        """
        Generate embedding for an image with caching.

        Args:
            image: PIL Image to embed
            use_cache: If True, use cache for identical images

        Returns:
            np.ndarray: Normalized embedding vector

        Raises:
            ValueError: If image is invalid
            RuntimeError: If embedding generation fails
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL Image")

        # Check cache if enabled
        if use_cache:
            cache_key = self._get_cache_key(image)

            if cache_key in self._cache:
                # Move to end (most recently used)
                self._cache.move_to_end(cache_key)
                self._cache_hits += 1
                logger.debug(f"Cache hit (total: {self._cache_hits})")
                return self._cache[cache_key]

            self._cache_misses += 1

        # Generate embedding
        embedding = self.embedder.embed(image)

        # Store in cache (cache_key is always defined when use_cache is True)
        if use_cache:
            self._add_to_cache(cache_key, embedding)

        return embedding

    def embed_batch(
        self,
        images: list[Image.Image],
        use_cache: bool = True
    ) -> np.ndarray:
        """
        Generate embeddings for multiple images.

        This method checks the cache for each image individually and only
        generates embeddings for cache misses.

        Args:
            images: List of PIL Images to embed
            use_cache: If True, use cache for identical images

        Returns:
            np.ndarray: Array of embeddings, shape (len(images), dimension)

        Raises:
            ValueError: If any image is invalid
            RuntimeError: If embedding generation fails
        """
        if not images:
            return np.array([]).reshape(0, self.get_dimension())

        embeddings = []
        images_to_embed = []
        indices_to_embed = []

        # Check cache for each image
        for i, img in enumerate(images):
            if not isinstance(img, Image.Image):
                raise ValueError(f"Image at index {i} is not a PIL Image")

            if use_cache:
                cache_key = self._get_cache_key(img)
                if cache_key in self._cache:
                    self._cache.move_to_end(cache_key)
                    self._cache_hits += 1
                    embeddings.append((i, self._cache[cache_key]))
                    continue

                self._cache_misses += 1

            # Need to generate embedding
            images_to_embed.append(img)
            indices_to_embed.append(i)

        # Generate embeddings for cache misses
        if images_to_embed:
            if self.embedder.supports_batch():
                new_embeddings = self.embedder.embed_batch(images_to_embed)
            else:
                new_embeddings = np.array([
                    self.embedder.embed(img) for img in images_to_embed
                ])

            # Add to cache and results
            for img, idx, emb in zip(images_to_embed, indices_to_embed, new_embeddings):
                if use_cache:
                    cache_key = self._get_cache_key(img)
                    self._add_to_cache(cache_key, emb)
                embeddings.append((idx, emb))

        # Sort by original index and extract embeddings
        embeddings.sort(key=lambda x: x[0])
        return np.array([emb for _, emb in embeddings])
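
    # Note (added commentary): embed_batch accumulates (index, embedding) pairs
    # so that cache hits and freshly computed embeddings can be interleaved,
    # then sorts by index to return results in the original input order.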

    def _get_cache_key(self, image: Image.Image) -> str:
        """
        Generate cache key from image content.

        Uses an MD5 hash of the image mode, size, and raw bytes. Mode and
        size are included because two images with different dimensions can
        share identical raw bytes.

        Args:
            image: PIL Image

        Returns:
            str: Cache key (MD5 hex digest)
        """
        raw = f"{image.mode}:{image.size}".encode() + image.tobytes()
        return hashlib.md5(raw).hexdigest()

    def _add_to_cache(self, key: str, embedding: np.ndarray):
        """
        Add embedding to cache with LRU eviction.

        Args:
            key: Cache key
            embedding: Embedding to cache
        """
        self._cache[key] = embedding

        # Evict the least recently used entry if the cache is full
        if len(self._cache) > self.cache_size:
            self._cache.popitem(last=False)
            logger.debug(f"Cache eviction: size={len(self._cache)}")
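
    # Note (added commentary): together with move_to_end() on cache hits in
    # embed()/embed_batch(), the OrderedDict gives O(1) LRU behavior: hits are
    # promoted to the back, and popitem(last=False) evicts from the front.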

    def clear_cache(self):
        """Clear all cached embeddings."""
        self._cache.clear()
        logger.info("Cache cleared")

    def get_dimension(self) -> int:
        """
        Get embedding dimension.

        Returns:
            int: Embedding dimension
        """
        return self.embedder.get_dimension()

    def get_model_name(self) -> str:
        """
        Get current model name.

        Returns:
            str: Model identifier
        """
        return self.embedder.get_model_name()

    def get_stats(self) -> Dict[str, Any]:
        """
        Get manager statistics.

        Returns:
            Dict with keys:
            - model_name: Current model
            - dimension: Embedding dimension
            - cache_size: Current cache size
            - cache_capacity: Maximum cache size
            - cache_hits: Number of cache hits
            - cache_misses: Number of cache misses
            - cache_hit_rate: Hit rate (0-1)
        """
        total_requests = self._cache_hits + self._cache_misses
        hit_rate = self._cache_hits / total_requests if total_requests > 0 else 0.0

        return {
            'model_name': self.get_model_name(),
            'dimension': self.get_dimension(),
            'cache_size': len(self._cache),
            'cache_capacity': self.cache_size,
            'cache_hits': self._cache_hits,
            'cache_misses': self._cache_misses,
            'cache_hit_rate': hit_rate
        }
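
    # Note (added commentary): cache_hit_rate = hits / (hits + misses);
    # e.g. 3 hits and 1 miss give 3 / 4 = 0.75 (shown as 75.00% by __repr__).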

    def __repr__(self) -> str:
        """String representation."""
        stats = self.get_stats()
        return (
            f"EmbeddingManager(model={stats['model_name']}, "
            f"cache={stats['cache_size']}/{stats['cache_capacity']}, "
            f"hit_rate={stats['cache_hit_rate']:.2%})"
        )
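

# A minimal usage sketch (illustrative, not part of the original commit). It
# assumes the default CLIP weights can be loaded on this machine; the tiny
# solid-color test image is a stand-in for real input.
if __name__ == "__main__":
    manager = EmbeddingManager(model_name="clip", cache_size=100)

    img = Image.new("RGB", (224, 224), color="white")
    first = manager.embed(img)    # cache miss: runs the model
    second = manager.embed(img)   # cache hit: served from the LRU cache
    assert np.array_equal(first, second)

    batch = manager.embed_batch([img, img])  # both resolve from cache now
    print(batch.shape)            # (2, manager.get_dimension())
    print(manager.get_stats())    # hits/misses and hit rate so far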