Initial commit
19
geniusia2/core/embedders/__init__.py
Normal file
@@ -0,0 +1,19 @@
"""
Embedding system for visual similarity matching.

This module provides an abstraction layer for different embedding models
(CLIP, Pix2Struct) used for workflow matching and visual analysis.
"""

from .base import EmbedderBase
from .clip_embedder import CLIPEmbedder
from .faiss_index import FAISSIndex
from .embedding_manager import EmbeddingManager
from .fine_tuner import LightweightFineTuner

# Pix2Struct is optional (requires transformers>=4.35.0)
try:
    from .pix2struct_embedder import Pix2StructEmbedder
    __all__ = ['EmbedderBase', 'CLIPEmbedder', 'Pix2StructEmbedder', 'FAISSIndex', 'EmbeddingManager', 'LightweightFineTuner']
except ImportError:
    __all__ = ['EmbedderBase', 'CLIPEmbedder', 'FAISSIndex', 'EmbeddingManager', 'LightweightFineTuner']
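
Usage note (illustrative, not part of this commit): because Pix2StructEmbedder is exported only when transformers>=4.35.0 is installed, downstream code can probe __all__ before using it. A minimal sketch:

# Hypothetical consumer-side guard, mirroring the optional export above.
import geniusia2.core.embedders as embedders

if 'Pix2StructEmbedder' in embedders.__all__:
    embedder = embedders.Pix2StructEmbedder()
else:
    embedder = embedders.CLIPEmbedder()  # always available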
BIN
geniusia2/core/embedders/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/embedders/__pycache__/base.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/embedders/__pycache__/faiss_index.cpython-312.pyc
Normal file
Binary file not shown.
BIN
geniusia2/core/embedders/__pycache__/fine_tuner.cpython-312.pyc
Normal file
Binary file not shown.
100
geniusia2/core/embedders/base.py
Normal file
@@ -0,0 +1,100 @@
"""
Abstract base class for embedding models.

This module defines the interface that all embedding models must implement,
ensuring consistency across different model implementations (CLIP, Pix2Struct, etc.).
"""

from abc import ABC, abstractmethod
from typing import List
from PIL import Image
import numpy as np


class EmbedderBase(ABC):
    """
    Abstract base class for image embedding models.

    All embedding models must implement this interface to ensure
    compatibility with the workflow matching system.
    """

    @abstractmethod
    def embed(self, image: Image.Image) -> np.ndarray:
        """
        Generate an embedding vector for a single image.

        Args:
            image: PIL Image to embed

        Returns:
            np.ndarray: Normalized embedding vector of shape (dimension,)
                The vector should be L2-normalized for cosine similarity

        Raises:
            ValueError: If image is invalid or cannot be processed
            RuntimeError: If model inference fails
        """
        pass

    @abstractmethod
    def get_dimension(self) -> int:
        """
        Get the dimensionality of embeddings produced by this model.

        Returns:
            int: Embedding dimension (e.g., 512 for CLIP ViT-B/32, 768 for Pix2Struct)
        """
        pass

    @abstractmethod
    def get_model_name(self) -> str:
        """
        Get a unique identifier for this model.

        Returns:
            str: Model name (e.g., "clip-vit-b-32", "pix2struct-base")
        """
        pass

    @abstractmethod
    def supports_batch(self) -> bool:
        """
        Check if this model supports batch processing.

        Returns:
            bool: True if embed_batch() is optimized, False otherwise
        """
        pass

    def embed_batch(self, images: List[Image.Image]) -> np.ndarray:
        """
        Generate embeddings for multiple images.

        Default implementation processes images one by one.
        Subclasses can override this for optimized batch processing.

        Args:
            images: List of PIL Images to embed

        Returns:
            np.ndarray: Array of embeddings with shape (len(images), dimension)
                Each row is a normalized embedding vector

        Raises:
            ValueError: If any image is invalid
            RuntimeError: If model inference fails
        """
        if not images:
            return np.array([]).reshape(0, self.get_dimension())

        embeddings = []
        for img in images:
            embedding = self.embed(img)
            embeddings.append(embedding)

        return np.array(embeddings)

    def __repr__(self) -> str:
        """String representation of the embedder."""
        return f"{self.__class__.__name__}(model={self.get_model_name()}, dim={self.get_dimension()})"
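
Usage note (illustrative, not part of this commit): the minimum a new backend must provide to satisfy this interface. The constant vector stands in for a real model; the class name is hypothetical.

import numpy as np
from PIL import Image
from geniusia2.core.embedders import EmbedderBase

class DummyEmbedder(EmbedderBase):
    def embed(self, image: Image.Image) -> np.ndarray:
        vec = np.ones(4, dtype=np.float32)
        return vec / np.linalg.norm(vec)  # keep the L2-normalization contract

    def get_dimension(self) -> int:
        return 4

    def get_model_name(self) -> str:
        return "dummy"

    def supports_batch(self) -> bool:
        return False  # inherit the one-by-one embed_batch() default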
358
geniusia2/core/embedders/clip_embedder.py
Normal file
@@ -0,0 +1,358 @@
"""
CLIP-based embedder implementation.

This module provides a wrapper around OpenCLIP for generating image embeddings
using the CLIP (Contrastive Language-Image Pre-training) model.
"""

import torch
import numpy as np
from PIL import Image
from typing import List, Optional
import logging

try:
    import open_clip
except ImportError:
    open_clip = None

from .base import EmbedderBase


logger = logging.getLogger(__name__)


class CLIPEmbedder(EmbedderBase):
    """
    CLIP-based image embedder using OpenCLIP.

    This embedder uses the ViT-B/32 architecture by default, which produces
    512-dimensional embeddings. It defaults to CPU to leave GPU memory free
    for larger vision models; pass device='cuda' to run on GPU.
    """

    def __init__(
        self,
        model_name: str = "ViT-B-32",
        pretrained: str = "openai",
        device: Optional[str] = None
    ):
        """
        Initialize the CLIP embedder.

        Args:
            model_name: CLIP model architecture (default: ViT-B-32)
            pretrained: Pretrained weights to use (default: openai)
            device: Device to use ('cuda', 'cpu', or None)
                Note: None defaults to CPU to save GPU memory for other models

        Raises:
            ImportError: If open_clip is not installed
            RuntimeError: If model loading fails
        """
        if open_clip is None:
            raise ImportError(
                "OpenCLIP is not installed. "
                "Install it with: pip install open-clip-torch"
            )

        # Default to CPU to save GPU for vision models (Qwen3-VL)
        if device is None:
            device = "cpu"

        self.model_name = model_name
        self.pretrained = pretrained
        self.device = device
        self._embedding_dim = None

        # Load model
        try:
            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
                model_name,
                pretrained=pretrained,
                device=device
            )
            self.model.eval()

            # Determine embedding dimension
            with torch.no_grad():
                dummy_image = torch.zeros(1, 3, 224, 224).to(self.device)
                dummy_embedding = self.model.encode_image(dummy_image)
                self._embedding_dim = dummy_embedding.shape[-1]

            logger.info(
                f"CLIPEmbedder loaded: {model_name} on {device}, "
                f"dimension={self._embedding_dim}"
            )

        except Exception as e:
            raise RuntimeError(f"Failed to load CLIP model: {e}") from e
    def embed(self, image: Image.Image) -> np.ndarray:
        """
        Generate embedding for a single image.

        Args:
            image: PIL Image to embed

        Returns:
            np.ndarray: Normalized embedding vector of shape (dimension,)

        Raises:
            ValueError: If image is invalid
            RuntimeError: If embedding generation fails
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL Image")

        try:
            # Preprocess image
            image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

            # Generate embedding
            with torch.no_grad():
                embedding = self.model.encode_image(image_tensor)
                # L2 normalize for cosine similarity
                embedding = embedding / embedding.norm(dim=-1, keepdim=True)

            return embedding.cpu().numpy().flatten()

        except Exception as e:
            raise RuntimeError(f"Failed to generate embedding: {e}") from e

    def embed_batch(self, images: List[Image.Image]) -> np.ndarray:
        """
        Generate embeddings for multiple images (optimized batch processing).

        Args:
            images: List of PIL Images to embed

        Returns:
            np.ndarray: Array of embeddings with shape (len(images), dimension)

        Raises:
            ValueError: If any image is invalid
            RuntimeError: If embedding generation fails
        """
        if not images:
            return np.array([]).reshape(0, self.get_dimension())

        # Validate all images
        for i, img in enumerate(images):
            if not isinstance(img, Image.Image):
                raise ValueError(f"Image at index {i} is not a PIL Image")

        try:
            # Preprocess all images
            image_tensors = torch.stack([
                self.preprocess(img) for img in images
            ]).to(self.device)

            # Generate embeddings in batch
            with torch.no_grad():
                embeddings = self.model.encode_image(image_tensors)
                # L2 normalize for cosine similarity
                embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)

            return embeddings.cpu().numpy()

        except Exception as e:
            raise RuntimeError(f"Failed to generate batch embeddings: {e}") from e

    def get_dimension(self) -> int:
        """
        Get the dimensionality of embeddings.

        Returns:
            int: Embedding dimension (512 for ViT-B/32)
        """
        return self._embedding_dim

    def get_model_name(self) -> str:
        """
        Get model identifier.

        Returns:
            str: Model name (e.g., "clip-vit-b-32")
        """
        return f"clip-{self.model_name.lower().replace('/', '-')}"

    def supports_batch(self) -> bool:
        """
        Check if batch processing is supported.

        Returns:
            bool: True (CLIP supports efficient batch processing)
        """
        return True
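
Usage note (illustrative, not part of this commit; the image path is hypothetical):

from PIL import Image
from geniusia2.core.embedders import CLIPEmbedder

embedder = CLIPEmbedder()                   # CPU by default
img = Image.open("screenshot.png")          # hypothetical path
vec = embedder.embed(img)                   # shape (512,), L2-normalized
sim = float(vec @ embedder.embed(img))      # cosine similarity; ~1.0 for the same image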
    def fine_tune(
        self,
        positive_images: List[Image.Image],
        negative_images: List[Image.Image],
        epochs: int = 1,
        learning_rate: float = 1e-4
    ) -> dict:
        """
        Fine-tune the model using contrastive learning.

        This method fine-tunes only the final projection layer to adapt
        the model to user-specific workflows. It uses a simple contrastive
        loss: positive examples should be similar, negative examples should
        be dissimilar.

        Args:
            positive_images: Images from successful workflows
            negative_images: Images from rejected workflows
            epochs: Number of training epochs (default: 1 for speed)
            learning_rate: Learning rate (default: 1e-4)

        Returns:
            dict: Training metrics (loss, accuracy, etc.)
        """
        if not positive_images and not negative_images:
            return {'loss': 0.0, 'note': 'No examples to train on'}

        # Set model to training mode
        self.model.train()

        # Only train the visual projection layer (last layer);
        # freeze all other parameters
        for param in self.model.parameters():
            param.requires_grad = False

        # Unfreeze visual projection
        if hasattr(self.model, 'visual') and hasattr(self.model.visual, 'proj'):
            if self.model.visual.proj is not None:
                self.model.visual.proj.requires_grad = True

        # Setup optimizer (only for trainable parameters)
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]

        if not trainable_params:
            logger.warning("No trainable parameters found, skipping fine-tuning")
            self.model.eval()
            return {'loss': 0.0, 'note': 'No trainable parameters'}

        optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)

        total_loss = 0.0
        num_batches = 0

        try:
            for epoch in range(epochs):
                # Process positive examples (should be similar to each other)
                if len(positive_images) >= 2:
                    pos_loss = self._contrastive_loss_positive(positive_images)

                    optimizer.zero_grad()
                    pos_loss.backward()
                    optimizer.step()

                    # Accumulate as a float so the computation graph is freed
                    total_loss += pos_loss.item()
                    num_batches += 1

                # Process negative examples (should be dissimilar from positives)
                if positive_images and negative_images:
                    neg_loss = self._contrastive_loss_negative(
                        positive_images[:5],  # Use subset for speed
                        negative_images[:5]
                    )

                    optimizer.zero_grad()
                    neg_loss.backward()
                    optimizer.step()

                    total_loss += neg_loss.item()
                    num_batches += 1

        finally:
            # Always restore eval mode and freeze parameters
            self.model.eval()
            for param in self.model.parameters():
                param.requires_grad = False

        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0

        return {
            'loss': float(avg_loss),
            'epochs': epochs,
            'learning_rate': learning_rate,
            'positive_count': len(positive_images),
            'negative_count': len(negative_images),
            'num_batches': num_batches
        }

    def _contrastive_loss_positive(self, images: List[Image.Image]) -> torch.Tensor:
        """
        Contrastive loss for positive examples (should be similar).

        Args:
            images: List of positive example images

        Returns:
            torch.Tensor: Loss value
        """
        # Generate embeddings
        embeddings = []
        for img in images:
            img_tensor = self.preprocess(img).unsqueeze(0).to(self.device)
            emb = self.model.encode_image(img_tensor)
            emb = emb / emb.norm(dim=-1, keepdim=True)
            embeddings.append(emb)

        embeddings = torch.cat(embeddings, dim=0)

        # Compute pairwise cosine similarities
        similarities = torch.mm(embeddings, embeddings.t())

        # Loss: maximize similarity (minimize negative similarity),
        # excluding the diagonal (self-similarity)
        mask = torch.eye(len(images), device=self.device).bool()
        similarities = similarities.masked_fill(mask, 0)

        # We want high similarity, so minimize (1 - similarity)
        loss = (1 - similarities).mean()

        return loss

    def _contrastive_loss_negative(
        self,
        positive_images: List[Image.Image],
        negative_images: List[Image.Image]
    ) -> torch.Tensor:
        """
        Contrastive loss for negative examples (should be dissimilar).

        Args:
            positive_images: Positive example images
            negative_images: Negative example images

        Returns:
            torch.Tensor: Loss value
        """
        # Generate embeddings for positives
        pos_embeddings = []
        for img in positive_images:
            img_tensor = self.preprocess(img).unsqueeze(0).to(self.device)
            emb = self.model.encode_image(img_tensor)
            emb = emb / emb.norm(dim=-1, keepdim=True)
            pos_embeddings.append(emb)

        pos_embeddings = torch.cat(pos_embeddings, dim=0)

        # Generate embeddings for negatives
        neg_embeddings = []
        for img in negative_images:
            img_tensor = self.preprocess(img).unsqueeze(0).to(self.device)
            emb = self.model.encode_image(img_tensor)
            emb = emb / emb.norm(dim=-1, keepdim=True)
            neg_embeddings.append(emb)

        neg_embeddings = torch.cat(neg_embeddings, dim=0)

        # Compute cross-similarities (positive vs negative)
        similarities = torch.mm(pos_embeddings, neg_embeddings.t())

        # We want low similarity, so minimize similarity directly
        loss = similarities.mean()

        return loss
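
Usage note (illustrative, not part of this commit): driving fine_tune() directly. In this commit it is normally invoked by LightweightFineTuner (see fine_tuner.py); the image paths are hypothetical.

from PIL import Image
from geniusia2.core.embedders import CLIPEmbedder

embedder = CLIPEmbedder()
positives = [Image.open(p) for p in ("ok_1.png", "ok_2.png")]   # hypothetical paths
negatives = [Image.open(p) for p in ("rejected_1.png",)]

metrics = embedder.fine_tune(positives, negatives, epochs=1, learning_rate=1e-4)
print(metrics['loss'], metrics['num_batches'])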
309
geniusia2/core/embedders/embedding_manager.py
Normal file
@@ -0,0 +1,309 @@
"""
Embedding manager with model selection, caching, and fallback.

This module provides a high-level interface for generating embeddings,
with automatic model selection, LRU caching, and fallback to CLIP if
the selected model fails to load.
"""

import hashlib
import logging
from typing import Optional, Dict, Any
from collections import OrderedDict
from PIL import Image
import numpy as np

from .base import EmbedderBase
from .clip_embedder import CLIPEmbedder


logger = logging.getLogger(__name__)


class EmbeddingManager:
    """
    High-level manager for image embeddings.

    Features:
    - Model selection (CLIP, Pix2Struct, etc.)
    - Automatic fallback to CLIP on errors
    - LRU cache (1000 entries by default) for performance
    - GPU/CPU management
    - Logging and monitoring
    """

    def __init__(
        self,
        model_name: str = "clip",
        fallback_enabled: bool = True,
        cache_size: int = 1000,
        device: Optional[str] = None
    ):
        """
        Initialize the embedding manager.

        Args:
            model_name: Model to use ("clip" or "pix2struct")
            fallback_enabled: If True, fall back to CLIP on model load failure
            cache_size: Maximum number of cached embeddings (LRU eviction)
            device: Device to use ('cuda', 'cpu', or None for auto)

        Raises:
            RuntimeError: If model loading fails and fallback is disabled
        """
        self.model_name = model_name.lower()
        self.fallback_enabled = fallback_enabled
        self.cache_size = cache_size
        self.device = device

        # Initialize embedder
        self.embedder = self._load_embedder()

        # Initialize LRU cache
        self._cache: OrderedDict[str, np.ndarray] = OrderedDict()

        # Statistics
        self._cache_hits = 0
        self._cache_misses = 0

        logger.info(
            f"EmbeddingManager initialized: model={self.embedder.get_model_name()}, "
            f"dimension={self.embedder.get_dimension()}, "
            f"cache_size={cache_size}"
        )
    def _load_embedder(self) -> EmbedderBase:
        """
        Load the specified embedder with fallback support.

        Returns:
            EmbedderBase: Loaded embedder instance

        Raises:
            RuntimeError: If loading fails and fallback is disabled
        """
        try:
            if self.model_name == "clip":
                return CLIPEmbedder(device=self.device)

            elif self.model_name == "pix2struct":
                # Import here to avoid the dependency if not used
                try:
                    from .pix2struct_embedder import Pix2StructEmbedder
                    return Pix2StructEmbedder(device=self.device)
                except ImportError as e:
                    if self.fallback_enabled:
                        logger.warning(
                            f"Pix2Struct not available ({e}), falling back to CLIP"
                        )
                        return CLIPEmbedder(device=self.device)
                    raise

            else:
                raise ValueError(f"Unknown model: {self.model_name}")

        except Exception as e:
            if self.fallback_enabled:
                logger.warning(
                    f"Failed to load {self.model_name} ({e}), falling back to CLIP"
                )
                return CLIPEmbedder(device=self.device)
            raise RuntimeError(f"Failed to load embedder: {e}") from e

    def embed(self, image: Image.Image, use_cache: bool = True) -> np.ndarray:
        """
        Generate embedding for an image with caching.

        Args:
            image: PIL Image to embed
            use_cache: If True, use cache for identical images

        Returns:
            np.ndarray: Normalized embedding vector

        Raises:
            ValueError: If image is invalid
            RuntimeError: If embedding generation fails
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL Image")

        # Check cache if enabled
        if use_cache:
            cache_key = self._get_cache_key(image)

            if cache_key in self._cache:
                # Move to end (most recently used)
                self._cache.move_to_end(cache_key)
                self._cache_hits += 1
                logger.debug(f"Cache hit (total: {self._cache_hits})")
                return self._cache[cache_key]

            self._cache_misses += 1

        # Generate embedding
        embedding = self.embedder.embed(image)

        # Store in cache
        if use_cache:
            self._add_to_cache(cache_key, embedding)

        return embedding
    def embed_batch(
        self,
        images: list[Image.Image],
        use_cache: bool = True
    ) -> np.ndarray:
        """
        Generate embeddings for multiple images.

        This method checks the cache for each image individually and only
        generates embeddings for cache misses.

        Args:
            images: List of PIL Images to embed
            use_cache: If True, use cache for identical images

        Returns:
            np.ndarray: Array of embeddings (len(images), dimension)

        Raises:
            ValueError: If any image is invalid
            RuntimeError: If embedding generation fails
        """
        if not images:
            return np.array([]).reshape(0, self.get_dimension())

        embeddings = []
        images_to_embed = []
        indices_to_embed = []

        # Check cache for each image
        for i, img in enumerate(images):
            if not isinstance(img, Image.Image):
                raise ValueError(f"Image at index {i} is not a PIL Image")

            if use_cache:
                cache_key = self._get_cache_key(img)
                if cache_key in self._cache:
                    self._cache.move_to_end(cache_key)
                    self._cache_hits += 1
                    embeddings.append((i, self._cache[cache_key]))
                    continue

                self._cache_misses += 1

            # Need to generate embedding
            images_to_embed.append(img)
            indices_to_embed.append(i)

        # Generate embeddings for cache misses
        if images_to_embed:
            if self.embedder.supports_batch():
                new_embeddings = self.embedder.embed_batch(images_to_embed)
            else:
                new_embeddings = np.array([
                    self.embedder.embed(img) for img in images_to_embed
                ])

            # Add to cache and results
            for img, idx, emb in zip(images_to_embed, indices_to_embed, new_embeddings):
                if use_cache:
                    cache_key = self._get_cache_key(img)
                    self._add_to_cache(cache_key, emb)
                embeddings.append((idx, emb))

        # Sort by original index and extract embeddings
        embeddings.sort(key=lambda x: x[0])
        return np.array([emb for _, emb in embeddings])
    def _get_cache_key(self, image: Image.Image) -> str:
        """
        Generate cache key from image content.

        Uses MD5 hash of image bytes for fast lookup.

        Args:
            image: PIL Image

        Returns:
            str: Cache key (MD5 hash)
        """
        return hashlib.md5(image.tobytes()).hexdigest()

    def _add_to_cache(self, key: str, embedding: np.ndarray):
        """
        Add embedding to cache with LRU eviction.

        Args:
            key: Cache key
            embedding: Embedding to cache
        """
        # Add to cache
        self._cache[key] = embedding

        # Evict oldest if cache is full
        if len(self._cache) > self.cache_size:
            oldest_key = next(iter(self._cache))
            del self._cache[oldest_key]
            logger.debug(f"Cache eviction: size={len(self._cache)}")

    def clear_cache(self):
        """Clear all cached embeddings."""
        self._cache.clear()
        logger.info("Cache cleared")

    def get_dimension(self) -> int:
        """
        Get embedding dimension.

        Returns:
            int: Embedding dimension
        """
        return self.embedder.get_dimension()

    def get_model_name(self) -> str:
        """
        Get current model name.

        Returns:
            str: Model identifier
        """
        return self.embedder.get_model_name()

    def get_stats(self) -> Dict[str, Any]:
        """
        Get manager statistics.

        Returns:
            Dict with keys:
                - model_name: Current model
                - dimension: Embedding dimension
                - cache_size: Current cache size
                - cache_capacity: Maximum cache size
                - cache_hits: Number of cache hits
                - cache_misses: Number of cache misses
                - cache_hit_rate: Hit rate (0-1)
        """
        total_requests = self._cache_hits + self._cache_misses
        hit_rate = self._cache_hits / total_requests if total_requests > 0 else 0.0

        return {
            'model_name': self.get_model_name(),
            'dimension': self.get_dimension(),
            'cache_size': len(self._cache),
            'cache_capacity': self.cache_size,
            'cache_hits': self._cache_hits,
            'cache_misses': self._cache_misses,
            'cache_hit_rate': hit_rate
        }

    def __repr__(self) -> str:
        """String representation."""
        stats = self.get_stats()
        return (
            f"EmbeddingManager(model={stats['model_name']}, "
            f"cache={stats['cache_size']}/{stats['cache_capacity']}, "
            f"hit_rate={stats['cache_hit_rate']:.2%})"
        )
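
Usage note (illustrative, not part of this commit; the image path is hypothetical):

from PIL import Image
from geniusia2.core.embedders import EmbeddingManager

manager = EmbeddingManager(model_name="pix2struct")  # falls back to CLIP if unavailable
img = Image.open("screenshot.png")                   # hypothetical path

vec1 = manager.embed(img)                            # cache miss: runs the model
vec2 = manager.embed(img)                            # cache hit: served from the LRU cache
print(manager.get_stats()['cache_hit_rate'])         # 0.5 after one miss and one hit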
309
geniusia2/core/embedders/faiss_index.py
Normal file
@@ -0,0 +1,309 @@
"""
FAISS index wrapper with proper dimension handling and persistence.

This module provides a robust wrapper around FAISS for storing and searching
image embeddings, with proper error handling for dimension mismatches and
reliable save/load functionality.
"""

import pickle
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional
import numpy as np

try:
    import faiss
except ImportError:
    faiss = None


logger = logging.getLogger(__name__)


class FAISSIndex:
    """
    Wrapper around FAISS index with metadata storage and dimension validation.

    This class handles:
    - Dimension validation on add/search operations
    - Metadata storage alongside embeddings
    - Reliable persistence (save/load)
    - Automatic index rebuilding on dimension changes
    """

    def __init__(self, dimension: int):
        """
        Initialize a new FAISS index.

        Args:
            dimension: Embedding dimension (e.g., 512 for CLIP, 768 for Pix2Struct)

        Raises:
            ImportError: If FAISS is not installed
            ValueError: If dimension is invalid
        """
        if faiss is None:
            raise ImportError(
                "FAISS is not installed. "
                "Install it with: pip install faiss-cpu or faiss-gpu"
            )

        if dimension <= 0:
            raise ValueError(f"Dimension must be positive, got {dimension}")

        self.dimension = dimension
        self.index = faiss.IndexFlatL2(dimension)
        self.metadata: List[Dict[str, Any]] = []

        logger.info(f"FAISSIndex created with dimension={dimension}")
    def add(self, embeddings: np.ndarray, metadata: List[Dict[str, Any]]):
        """
        Add embeddings to the index with associated metadata.

        Args:
            embeddings: Array of shape (N, dimension) containing N embeddings
            metadata: List of N metadata dictionaries

        Raises:
            ValueError: If dimensions don't match or array shapes are invalid
        """
        # Validate input shape
        if embeddings.ndim == 1:
            # Single embedding, reshape to (1, dimension)
            embeddings = embeddings.reshape(1, -1)
        elif embeddings.ndim != 2:
            raise ValueError(
                f"Embeddings must be 1D or 2D array, got shape {embeddings.shape}"
            )

        # Validate dimension
        if embeddings.shape[1] != self.dimension:
            raise ValueError(
                f"Embedding dimension {embeddings.shape[1]} doesn't match "
                f"index dimension {self.dimension}"
            )

        # Validate metadata count
        if len(metadata) != embeddings.shape[0]:
            raise ValueError(
                f"Number of metadata entries ({len(metadata)}) doesn't match "
                f"number of embeddings ({embeddings.shape[0]})"
            )

        # Add to FAISS index
        self.index.add(embeddings.astype('float32'))

        # Store metadata
        self.metadata.extend(metadata)

        logger.debug(
            f"Added {embeddings.shape[0]} embeddings to index "
            f"(total: {self.index.ntotal})"
        )

    def search(
        self,
        query: np.ndarray,
        k: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Search for the k most similar embeddings.

        Args:
            query: Query embedding of shape (dimension,) or (1, dimension)
            k: Number of results to return

        Returns:
            List of dicts with keys:
                - 'index': Index in the FAISS index
                - 'distance': L2 distance
                - 'similarity': Similarity score (1 / (1 + distance))
                - 'metadata': Associated metadata dict

        Raises:
            ValueError: If query dimension doesn't match index dimension
        """
        if self.index.ntotal == 0:
            logger.warning("Search called on empty index")
            return []

        # Reshape query if needed
        if query.ndim == 1:
            query = query.reshape(1, -1)
        elif query.ndim != 2:
            raise ValueError(
                f"Query must be 1D or 2D array, got shape {query.shape}"
            )

        # Validate dimension
        if query.shape[1] != self.dimension:
            raise ValueError(
                f"Query dimension {query.shape[1]} doesn't match "
                f"index dimension {self.dimension}"
            )

        # Limit k to available embeddings
        k = min(k, self.index.ntotal)

        # Search
        distances, indices = self.index.search(query.astype('float32'), k)

        # Format results
        results = []
        for dist, idx in zip(distances[0], indices[0]):
            # FAISS returns -1 if not enough results
            if idx >= 0 and idx < len(self.metadata):
                results.append({
                    'index': int(idx),
                    'distance': float(dist),
                    'similarity': float(1.0 / (1.0 + dist)),
                    'metadata': self.metadata[idx]
                })

        return results
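
Usage note (illustrative, not part of this commit). Since the embedders L2-normalize their vectors, squared L2 distance and cosine similarity are monotonically related (d^2 = 2(1 - cos)), so nearest-by-L2 is also nearest-by-cosine:

import numpy as np
from geniusia2.core.embedders import FAISSIndex

index = FAISSIndex(dimension=512)
vecs = np.random.randn(3, 512).astype('float32')
vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)   # normalize like the embedders do

index.add(vecs, metadata=[{'workflow_id': f'wf-{i}'} for i in range(3)])
hits = index.search(vecs[0], k=2)
print(hits[0]['metadata'], hits[0]['distance'])       # {'workflow_id': 'wf-0'}, ~0.0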
    def save(self, path: str):
        """
        Save index and metadata to disk.

        Args:
            path: Base path for saving (will create .index and .metadata files)

        Raises:
            RuntimeError: If the save operation fails
        """
        try:
            path_obj = Path(path)
            path_obj.parent.mkdir(parents=True, exist_ok=True)

            # Save FAISS index
            index_file = f"{path}.index"
            faiss.write_index(self.index, index_file)

            # Save metadata
            metadata_file = f"{path}.metadata"
            with open(metadata_file, 'wb') as f:
                pickle.dump({
                    'dimension': self.dimension,
                    'metadata': self.metadata
                }, f)

            logger.info(
                f"Saved index with {self.index.ntotal} embeddings to {path}"
            )

        except Exception as e:
            raise RuntimeError(f"Failed to save index: {e}") from e

    def load(self, path: str):
        """
        Load index and metadata from disk.

        Args:
            path: Base path for loading (will read .index and .metadata files)

        Raises:
            FileNotFoundError: If the files don't exist
            RuntimeError: If the load operation fails or dimensions mismatch
        """
        try:
            index_file = f"{path}.index"
            metadata_file = f"{path}.metadata"

            # Check that both files exist
            if not Path(index_file).exists():
                raise FileNotFoundError(f"Index file not found: {index_file}")
            if not Path(metadata_file).exists():
                raise FileNotFoundError(f"Metadata file not found: {metadata_file}")

            # Load FAISS index
            loaded_index = faiss.read_index(index_file)

            # Load metadata
            with open(metadata_file, 'rb') as f:
                data = pickle.load(f)

            loaded_dimension = data['dimension']
            loaded_metadata = data['metadata']

            # Validate dimension
            if loaded_dimension != self.dimension:
                raise RuntimeError(
                    f"Loaded index dimension ({loaded_dimension}) doesn't match "
                    f"current dimension ({self.dimension}). "
                    f"Use rebuild_if_needed() to handle dimension changes."
                )

            # Update state
            self.index = loaded_index
            self.metadata = loaded_metadata

            logger.info(
                f"Loaded index with {self.index.ntotal} embeddings from {path}"
            )

        except Exception as e:
            if isinstance(e, (FileNotFoundError, RuntimeError)):
                raise
            raise RuntimeError(f"Failed to load index: {e}") from e

    def rebuild_if_needed(self, new_dimension: int) -> bool:
        """
        Rebuild the index if the dimension has changed.

        This creates a new empty index with the new dimension.
        Old embeddings are lost and need to be regenerated.

        Args:
            new_dimension: New embedding dimension

        Returns:
            bool: True if the index was rebuilt, False if the dimension is unchanged
        """
        if new_dimension == self.dimension:
            return False

        logger.warning(
            f"Rebuilding FAISS index: dimension changed from "
            f"{self.dimension} to {new_dimension}. "
            f"Old embeddings ({self.index.ntotal}) will be lost."
        )

        # Create new index
        self.dimension = new_dimension
        self.index = faiss.IndexFlatL2(new_dimension)
        self.metadata = []

        return True

    def clear(self):
        """Clear all embeddings from the index."""
        self.index = faiss.IndexFlatL2(self.dimension)
        self.metadata = []
        logger.info("Index cleared")

    def get_stats(self) -> Dict[str, Any]:
        """
        Get index statistics.

        Returns:
            Dict with keys: num_embeddings, dimension, is_trained
        """
        return {
            'num_embeddings': self.index.ntotal,
            'dimension': self.dimension,
            'is_trained': self.index.is_trained
        }

    def __len__(self) -> int:
        """Return the number of embeddings in the index."""
        return self.index.ntotal

    def __repr__(self) -> str:
        """String representation of the index."""
        return (
            f"FAISSIndex(dimension={self.dimension}, "
            f"num_embeddings={self.index.ntotal})"
        )
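
Persistence note (illustrative, not part of this commit; continues the sketch above). save() writes "<path>.index" and "<path>.metadata" side by side, and the path here is hypothetical:

index.save("data/indexes/workflows")       # hypothetical base path

restored = FAISSIndex(dimension=512)       # dimension must match the saved index
restored.load("data/indexes/workflows")
assert len(restored) == len(index)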
321
geniusia2/core/embedders/fine_tuner.py
Normal file
@@ -0,0 +1,321 @@
"""
Lightweight fine-tuner for embedding models.

This module provides incremental fine-tuning capabilities that run in the
background, adapting the embedding model to user-specific workflows over time.
"""

import threading
import time
import pickle
import logging
from collections import deque
from pathlib import Path
from typing import List, Dict, Any, Optional
from PIL import Image
import numpy as np


logger = logging.getLogger(__name__)


class LightweightFineTuner:
    """
    Lightweight fine-tuner for incremental model adaptation.

    This class collects positive and negative examples from user interactions
    and periodically fine-tunes the embedding model to improve accuracy on
    user-specific workflows.

    Features:
    - Automatic triggering after N examples
    - Background training (non-blocking)
    - Checkpoint save/load for recovery
    - Metrics tracking
    """

    def __init__(
        self,
        embedder,
        trigger_threshold: int = 10,
        max_examples: int = 1000,
        checkpoint_dir: str = "data/fine_tuning"
    ):
        """
        Initialize the fine-tuner.

        Args:
            embedder: Embedder instance to fine-tune (must support a fine_tune method)
            trigger_threshold: Number of new examples before triggering fine-tuning
            max_examples: Maximum examples to keep (oldest are evicted first)
            checkpoint_dir: Directory for saving checkpoints
        """
        self.embedder = embedder
        self.trigger_threshold = trigger_threshold
        self.max_examples = max_examples
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Example storage (bounded deques evict the oldest entries automatically)
        self.positive_examples = deque(maxlen=max_examples)
        self.negative_examples = deque(maxlen=max_examples)

        # Training state
        self.is_training = False
        self.training_thread: Optional[threading.Thread] = None
        self.last_training_time = 0
        self.training_count = 0

        # Metrics
        self.metrics_history: List[Dict[str, Any]] = []

        logger.info(
            f"LightweightFineTuner initialized: "
            f"trigger={trigger_threshold}, max_examples={max_examples}"
        )
    def add_positive_example(self, image: Image.Image, workflow_id: str, metadata: Optional[Dict] = None):
        """
        Add a positive example (successful workflow execution).

        Args:
            image: Screenshot where the workflow succeeded
            workflow_id: ID of the successful workflow
            metadata: Optional additional metadata
        """
        example = {
            'image': image,
            'workflow_id': workflow_id,
            'metadata': metadata or {},
            'timestamp': time.time()
        }

        self.positive_examples.append(example)

        logger.debug(
            f"Added positive example: workflow={workflow_id}, "
            f"total_positive={len(self.positive_examples)}"
        )

        self._check_trigger()

    def add_negative_example(self, image: Image.Image, workflow_id: str, metadata: Optional[Dict] = None):
        """
        Add a negative example (rejected workflow suggestion).

        Args:
            image: Screenshot where the workflow was rejected
            workflow_id: ID of the rejected workflow
            metadata: Optional additional metadata
        """
        example = {
            'image': image,
            'workflow_id': workflow_id,
            'metadata': metadata or {},
            'timestamp': time.time()
        }

        self.negative_examples.append(example)

        logger.debug(
            f"Added negative example: workflow={workflow_id}, "
            f"total_negative={len(self.negative_examples)}"
        )

        self._check_trigger()

    def _check_trigger(self):
        """Check whether fine-tuning should be triggered."""
        total_new = len(self.positive_examples) + len(self.negative_examples)

        # Don't trigger if already training
        if self.is_training:
            logger.debug("Fine-tuning already in progress, skipping trigger check")
            return

        # Check if we have enough examples
        if total_new >= self.trigger_threshold:
            logger.info(
                f"Fine-tuning triggered: {total_new} examples "
                f"({len(self.positive_examples)} positive, "
                f"{len(self.negative_examples)} negative)"
            )
            self._start_training()

    def _start_training(self):
        """Start training in a background thread."""
        if self.is_training:
            logger.warning("Training already in progress")
            return

        # Set the flag here (not inside the thread) so two triggers arriving
        # back-to-back cannot both pass the check above and start twice
        self.is_training = True

        self.training_thread = threading.Thread(
            target=self._train,
            name="FineTuningThread",
            daemon=True
        )
        self.training_thread.start()
        logger.info("Fine-tuning thread started")
    def _train(self):
        """Fine-tune the model (runs in a background thread)."""
        start_time = time.time()

        try:
            # Check if the embedder supports fine-tuning
            if not hasattr(self.embedder, 'fine_tune'):
                logger.info(
                    f"Embedder {self.embedder.get_model_name()} doesn't support "
                    f"fine-tuning, skipping"
                )
                return

            # Prepare training data
            positive_images = [ex['image'] for ex in self.positive_examples]
            negative_images = [ex['image'] for ex in self.negative_examples]

            if not positive_images and not negative_images:
                logger.warning("No examples to train on")
                return

            logger.info(
                f"Starting fine-tuning: {len(positive_images)} positive, "
                f"{len(negative_images)} negative examples"
            )

            # Fine-tune (implementation depends on the embedder)
            metrics = self.embedder.fine_tune(
                positive_images=positive_images,
                negative_images=negative_images,
                epochs=1,
                learning_rate=1e-4
            )

            # Record metrics
            duration = time.time() - start_time
            metrics['duration_seconds'] = duration
            metrics['timestamp'] = time.time()
            metrics['positive_count'] = len(positive_images)
            metrics['negative_count'] = len(negative_images)
            metrics['training_number'] = self.training_count

            self.metrics_history.append(metrics)
            self.last_training_time = time.time()
            self.training_count += 1

            # Format the loss separately: applying :.4f to a missing/'N/A'
            # value would raise inside the f-string
            loss_str = f"{metrics['loss']:.4f}" if 'loss' in metrics else "N/A"
            logger.info(
                f"Fine-tuning complete #{self.training_count}: "
                f"loss={loss_str}, duration={duration:.1f}s"
            )

            # Clear examples after successful training
            self.positive_examples.clear()
            self.negative_examples.clear()
            logger.debug("Training examples cleared")

        except Exception as e:
            logger.error(f"Fine-tuning failed: {e}", exc_info=True)

        finally:
            # is_training was set by _start_training(); always release it here
            self.is_training = False
    def save_checkpoint(self, name: str = "checkpoint"):
        """
        Save training examples and metrics for recovery.

        Args:
            name: Checkpoint name
        """
        try:
            checkpoint_path = self.checkpoint_dir / f"{name}.pkl"

            data = {
                'positive_examples': list(self.positive_examples),
                'negative_examples': list(self.negative_examples),
                'metrics_history': self.metrics_history,
                'training_count': self.training_count,
                'last_training_time': self.last_training_time
            }

            with open(checkpoint_path, 'wb') as f:
                pickle.dump(data, f)

            logger.info(f"Checkpoint saved: {checkpoint_path}")

        except Exception as e:
            logger.error(f"Failed to save checkpoint: {e}")

    def load_checkpoint(self, name: str = "checkpoint"):
        """
        Load training examples and metrics from a checkpoint.

        Args:
            name: Checkpoint name

        Returns:
            bool: True if loaded successfully, False otherwise
        """
        try:
            checkpoint_path = self.checkpoint_dir / f"{name}.pkl"

            if not checkpoint_path.exists():
                logger.warning(f"Checkpoint not found: {checkpoint_path}")
                return False

            with open(checkpoint_path, 'rb') as f:
                data = pickle.load(f)

            self.positive_examples.extend(data.get('positive_examples', []))
            self.negative_examples.extend(data.get('negative_examples', []))
            self.metrics_history = data.get('metrics_history', [])
            self.training_count = data.get('training_count', 0)
            self.last_training_time = data.get('last_training_time', 0)

            logger.info(
                f"Checkpoint loaded: {len(self.positive_examples)} positive, "
                f"{len(self.negative_examples)} negative examples"
            )
            return True

        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            return False

    def get_stats(self) -> Dict[str, Any]:
        """
        Get fine-tuning statistics.

        Returns:
            Dict with statistics
        """
        return {
            'positive_examples': len(self.positive_examples),
            'negative_examples': len(self.negative_examples),
            'total_examples': len(self.positive_examples) + len(self.negative_examples),
            'is_training': self.is_training,
            'training_count': self.training_count,
            'last_training_time': self.last_training_time,
            'metrics_history': self.metrics_history,
            'trigger_threshold': self.trigger_threshold
        }

    def wait_for_training(self, timeout: Optional[float] = None):
        """
        Wait for the current training run to complete.

        Args:
            timeout: Maximum time to wait in seconds (None = wait forever)
        """
        if self.training_thread and self.training_thread.is_alive():
            self.training_thread.join(timeout=timeout)

    def __repr__(self) -> str:
        """String representation."""
        stats = self.get_stats()
        return (
            f"LightweightFineTuner("
            f"examples={stats['total_examples']}, "
            f"trainings={stats['training_count']}, "
            f"is_training={stats['is_training']})"
        )
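
Usage note (illustrative, not part of this commit; 'embedder' is a CLIPEmbedder as in the earlier sketch and 'screenshot' is a hypothetical PIL Image):

from geniusia2.core.embedders import LightweightFineTuner

tuner = LightweightFineTuner(embedder, trigger_threshold=10)

# Feedback loop: every accepted/rejected suggestion becomes an example;
# the 10th example starts a background fine-tuning pass automatically.
tuner.add_positive_example(screenshot, workflow_id="wf-42")
tuner.save_checkpoint()                 # examples and metrics survive a restart
tuner.wait_for_training(timeout=60)     # e.g., before a clean shutdown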
193
geniusia2/core/embedders/pix2struct_embedder.py
Normal file
@@ -0,0 +1,193 @@
"""
Pix2Struct-based embedder implementation.

This module provides a wrapper around Google's Pix2Struct model for generating
image embeddings specialized for UI understanding and document screenshots.
"""

import torch
import numpy as np
from PIL import Image
from typing import List, Optional
import logging

try:
    from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
except ImportError:
    Pix2StructProcessor = None
    Pix2StructForConditionalGeneration = None

from .base import EmbedderBase


logger = logging.getLogger(__name__)


class Pix2StructEmbedder(EmbedderBase):
    """
    Pix2Struct-based image embedder specialized for UI understanding.

    Pix2Struct is a vision-language model trained on screenshots and structured
    documents, making it particularly well-suited for RPA and UI automation tasks.

    This embedder uses the encoder's hidden states as embeddings, which capture
    visual features optimized for understanding UI elements and layouts.
    """

    def __init__(
        self,
        model_name: str = "google/pix2struct-base",
        device: Optional[str] = None
    ):
        """
        Initialize the Pix2Struct embedder.

        Args:
            model_name: Pix2Struct model to use (default: google/pix2struct-base)
            device: Device to use ('cuda', 'cpu', or None for auto-detect)

        Raises:
            ImportError: If transformers is not installed
            RuntimeError: If model loading fails
        """
        if Pix2StructProcessor is None or Pix2StructForConditionalGeneration is None:
            raise ImportError(
                "Transformers is not installed or the version is too old. "
                "Install it with: pip install 'transformers>=4.35.0'"
            )

        # Auto-detect device
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.model_name = model_name
        self.device = device
        self._embedding_dim = None

        # Load model and processor
        try:
            logger.info(f"Loading Pix2Struct model: {model_name}")

            self.processor = Pix2StructProcessor.from_pretrained(model_name)
            self.model = Pix2StructForConditionalGeneration.from_pretrained(
                model_name
            ).to(device)
            self.model.eval()

            # Determine embedding dimension from the encoder
            # (embeddings are mean-pooled from the last hidden state)
            with torch.no_grad():
                dummy_image = Image.new('RGB', (224, 224), color=(128, 128, 128))
                inputs = self.processor(images=dummy_image, return_tensors="pt").to(device)
                encoder_outputs = self.model.encoder(**inputs)
                self._embedding_dim = encoder_outputs.last_hidden_state.shape[-1]

            logger.info(
                f"Pix2StructEmbedder loaded: {model_name} on {device}, "
                f"dimension={self._embedding_dim}"
            )

        except Exception as e:
            raise RuntimeError(f"Failed to load Pix2Struct model: {e}") from e
    def embed(self, image: Image.Image) -> np.ndarray:
        """
        Generate embedding for a single image.

        Args:
            image: PIL Image to embed

        Returns:
            np.ndarray: Normalized embedding vector of shape (dimension,)

        Raises:
            ValueError: If image is invalid
            RuntimeError: If embedding generation fails
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL Image")

        try:
            # Process image
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)

            # Generate embedding from the encoder
            with torch.no_grad():
                encoder_outputs = self.model.encoder(**inputs)
                # Mean pooling over the sequence dimension
                embedding = encoder_outputs.last_hidden_state.mean(dim=1)
                # L2 normalize for cosine similarity
                embedding = embedding / embedding.norm(dim=-1, keepdim=True)

            return embedding.cpu().numpy().flatten()

        except Exception as e:
            raise RuntimeError(f"Failed to generate embedding: {e}") from e

    def embed_batch(self, images: List[Image.Image]) -> np.ndarray:
        """
        Generate embeddings for multiple images (optimized batch processing).

        Args:
            images: List of PIL Images to embed

        Returns:
            np.ndarray: Array of embeddings with shape (len(images), dimension)

        Raises:
            ValueError: If any image is invalid
            RuntimeError: If embedding generation fails
        """
        if not images:
            return np.array([]).reshape(0, self.get_dimension())

        # Validate all images
        for i, img in enumerate(images):
            if not isinstance(img, Image.Image):
                raise ValueError(f"Image at index {i} is not a PIL Image")

        try:
            # Process all images in one batch
            inputs = self.processor(images=images, return_tensors="pt").to(self.device)

            # Generate embeddings in batch
            with torch.no_grad():
                encoder_outputs = self.model.encoder(**inputs)
                # Mean pooling over the sequence dimension
                embeddings = encoder_outputs.last_hidden_state.mean(dim=1)
                # L2 normalize for cosine similarity
                embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)

            return embeddings.cpu().numpy()

        except Exception as e:
            raise RuntimeError(f"Failed to generate batch embeddings: {e}") from e

    def get_dimension(self) -> int:
        """
        Get the dimensionality of embeddings.

        Returns:
            int: Embedding dimension (768 for pix2struct-base)
        """
        return self._embedding_dim

    def get_model_name(self) -> str:
        """
        Get model identifier.

        Returns:
            str: Model name (e.g., "pix2struct-base")
        """
        # Extract the model name from the full hub path
        model_id = self.model_name.split('/')[-1]
        return f"pix2struct-{model_id}" if not model_id.startswith('pix2struct') else model_id

    def supports_batch(self) -> bool:
        """
        Check if batch processing is supported.

        Returns:
            bool: True (Pix2Struct supports efficient batch processing)
        """
        return True
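
Usage note (illustrative, not part of this commit). Requires transformers>=4.35.0; otherwise construction raises ImportError and the EmbeddingManager falls back to CLIP. The image path is hypothetical:

from PIL import Image
from geniusia2.core.embedders.pix2struct_embedder import Pix2StructEmbedder

p2s = Pix2StructEmbedder()                 # auto-detects CUDA, falls back to CPU
vec = p2s.embed(Image.open("form.png"))    # hypothetical path; shape (768,)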