Initial commit
This commit is contained in:
100
geniusia2/core/embedders/base.py
Normal file
100
geniusia2/core/embedders/base.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
Abstract base class for embedding models.
|
||||
|
||||
This module defines the interface that all embedding models must implement,
|
||||
ensuring consistency across different model implementations (CLIP, Pix2Struct, etc.).
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
class EmbedderBase(ABC):
|
||||
"""
|
||||
Abstract base class for image embedding models.
|
||||
|
||||
All embedding models must implement this interface to ensure
|
||||
compatibility with the workflow matching system.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def embed(self, image: Image.Image) -> np.ndarray:
|
||||
"""
|
||||
Generate an embedding vector for a single image.
|
||||
|
||||
Args:
|
||||
image: PIL Image to embed
|
||||
|
||||
Returns:
|
||||
np.ndarray: Normalized embedding vector of shape (dimension,)
|
||||
The vector should be L2-normalized for cosine similarity
|
||||
|
||||
Raises:
|
||||
ValueError: If image is invalid or cannot be processed
|
||||
RuntimeError: If model inference fails
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_dimension(self) -> int:
|
||||
"""
|
||||
Get the dimensionality of embeddings produced by this model.
|
||||
|
||||
Returns:
|
||||
int: Embedding dimension (e.g., 512 for CLIP ViT-B/32, 768 for Pix2Struct)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model_name(self) -> str:
|
||||
"""
|
||||
Get a unique identifier for this model.
|
||||
|
||||
Returns:
|
||||
str: Model name (e.g., "clip-vit-b32", "pix2struct-base")
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def supports_batch(self) -> bool:
|
||||
"""
|
||||
Check if this model supports batch processing.
|
||||
|
||||
Returns:
|
||||
bool: True if embed_batch() is optimized, False otherwise
|
||||
"""
|
||||
pass
|
||||
|
||||
def embed_batch(self, images: List[Image.Image]) -> np.ndarray:
|
||||
"""
|
||||
Generate embeddings for multiple images.
|
||||
|
||||
Default implementation processes images one by one.
|
||||
Subclasses can override this for optimized batch processing.
|
||||
|
||||
Args:
|
||||
images: List of PIL Images to embed
|
||||
|
||||
Returns:
|
||||
np.ndarray: Array of embeddings with shape (len(images), dimension)
|
||||
Each row is a normalized embedding vector
|
||||
|
||||
Raises:
|
||||
ValueError: If any image is invalid
|
||||
RuntimeError: If model inference fails
|
||||
"""
|
||||
if not images:
|
||||
return np.array([]).reshape(0, self.get_dimension())
|
||||
|
||||
embeddings = []
|
||||
for img in images:
|
||||
embedding = self.embed(img)
|
||||
embeddings.append(embedding)
|
||||
|
||||
return np.array(embeddings)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""String representation of the embedder."""
|
||||
return f"{self.__class__.__name__}(model={self.get_model_name()}, dim={self.get_dimension()})"
|
||||
Reference in New Issue
Block a user