Files
Geniusia_v2/geniusia2/core/embedders/base.py
2026-03-05 00:20:25 +01:00

101 lines
2.9 KiB
Python

"""
Abstract base class for embedding models.
This module defines the interface that all embedding models must implement,
ensuring consistency across different model implementations (CLIP, Pix2Struct, etc.).
"""
from abc import ABC, abstractmethod
from typing import List
from PIL import Image
import numpy as np
class EmbedderBase(ABC):
    """
    Abstract base class for image embedding models.

    All embedding models must implement this interface to ensure
    compatibility with the workflow matching system.
    """

    @abstractmethod
    def embed(self, image: Image.Image) -> np.ndarray:
        """
        Generate an embedding vector for a single image.

        Args:
            image: PIL Image to embed

        Returns:
            np.ndarray: Normalized embedding vector of shape (dimension,).
                The vector should be L2-normalized for cosine similarity.

        Raises:
            ValueError: If image is invalid or cannot be processed
            RuntimeError: If model inference fails
        """
        pass

    @abstractmethod
    def get_dimension(self) -> int:
        """
        Get the dimensionality of embeddings produced by this model.

        Returns:
            int: Embedding dimension (e.g., 512 for CLIP ViT-B/32, 768 for Pix2Struct)
        """
        pass

    @abstractmethod
    def get_model_name(self) -> str:
        """
        Get a unique identifier for this model.

        Returns:
            str: Model name (e.g., "clip-vit-b32", "pix2struct-base")
        """
        pass

    @abstractmethod
    def supports_batch(self) -> bool:
        """
        Check if this model supports batch processing.

        Returns:
            bool: True if embed_batch() is optimized, False otherwise
        """
        pass

    def embed_batch(self, images: List[Image.Image]) -> np.ndarray:
        """
        Generate embeddings for multiple images.

        Default implementation processes images one by one via embed().
        Subclasses can override this for optimized batch processing
        (and should report that via supports_batch()).

        Args:
            images: List of PIL Images to embed

        Returns:
            np.ndarray: Array of embeddings with shape (len(images), dimension).
                Each row is a normalized embedding vector.

        Raises:
            ValueError: If any image is invalid or a subclass returns
                embeddings with inconsistent shapes
            RuntimeError: If model inference fails
        """
        if not images:
            # Keep the 2-D (0, dimension) shape so downstream code can rely
            # on a consistent array rank even for empty input.
            return np.empty((0, self.get_dimension()))
        # np.stack (unlike np.array) raises a clear error if a subclass
        # returns embeddings with mismatched shapes, instead of silently
        # producing an object array.
        return np.stack([self.embed(img) for img in images])

    def __repr__(self) -> str:
        """String representation of the embedder."""
        return f"{self.__class__.__name__}(model={self.get_model_name()}, dim={self.get_dimension()})"