"""
|
|
Abstract base class for embedding models.
|
|
|
|
This module defines the interface that all embedding models must implement,
|
|
ensuring consistency across different model implementations (CLIP, Pix2Struct, etc.).
|
|
"""
|
|
|
|
from abc import ABC, abstractmethod
from typing import List

import numpy as np
from PIL import Image


class EmbedderBase(ABC):
    """
    Abstract base class for image embedding models.

    All embedding models must implement this interface to ensure
    compatibility with the workflow matching system.
    """

    @abstractmethod
    def embed(self, image: Image.Image) -> np.ndarray:
        """
        Generate an embedding vector for a single image.

        Args:
            image: PIL Image to embed

        Returns:
            np.ndarray: Normalized embedding vector of shape (dimension,)
                The vector should be L2-normalized for cosine similarity

        Raises:
            ValueError: If image is invalid or cannot be processed
            RuntimeError: If model inference fails
        """

    @abstractmethod
    def get_dimension(self) -> int:
        """
        Get the dimensionality of embeddings produced by this model.

        Returns:
            int: Embedding dimension (e.g., 512 for CLIP ViT-B/32, 768 for Pix2Struct)
        """

    @abstractmethod
    def get_model_name(self) -> str:
        """
        Get a unique identifier for this model.

        Returns:
            str: Model name (e.g., "clip-vit-b32", "pix2struct-base")
        """

    @abstractmethod
    def supports_batch(self) -> bool:
        """
        Check if this model supports batch processing.

        Returns:
            bool: True if embed_batch() is optimized, False otherwise
        """

    def embed_batch(self, images: List[Image.Image]) -> np.ndarray:
        """
        Generate embeddings for multiple images.

        Default implementation processes images one by one.
        Subclasses can override this for optimized batch processing.

        Args:
            images: List of PIL Images to embed

        Returns:
            np.ndarray: Array of embeddings with shape (len(images), dimension)
                Each row is a normalized embedding vector

        Raises:
            ValueError: If any image is invalid
            RuntimeError: If model inference fails
        """
        if not images:
            # Keep the (0, dimension) shape so callers always receive a
            # 2-D array, even for empty input.
            return np.empty((0, self.get_dimension()))

        # Embed each image and stack into a (len(images), dimension) array.
        return np.stack([self.embed(img) for img in images])

    def __repr__(self) -> str:
        """String representation of the embedder."""
        return f"{self.__class__.__name__}(model={self.get_model_name()}, dim={self.get_dimension()})"
|