Initial commit

Dom
2026-03-05 00:20:25 +01:00
commit dcd4de9945
1954 changed files with 669380 additions and 0 deletions

View File: __init__.py

@@ -0,0 +1,19 @@
"""
Embedding system for visual similarity matching.
This module provides an abstraction layer for different embedding models
(CLIP, Pix2Struct) used for workflow matching and visual analysis.
"""
from .base import EmbedderBase
from .clip_embedder import CLIPEmbedder
from .faiss_index import FAISSIndex
from .embedding_manager import EmbeddingManager
from .fine_tuner import LightweightFineTuner
# Pix2Struct is optional (requires transformers>=4.35.0)
try:
from .pix2struct_embedder import Pix2StructEmbedder
__all__ = ['EmbedderBase', 'CLIPEmbedder', 'Pix2StructEmbedder', 'FAISSIndex', 'EmbeddingManager', 'LightweightFineTuner']
except ImportError:
__all__ = ['EmbedderBase', 'CLIPEmbedder', 'FAISSIndex', 'EmbeddingManager', 'LightweightFineTuner']
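
A minimal import sketch (the package directory name is not shown in this commit; `embeddings` is assumed here):

from embeddings import EmbeddingManager, FAISSIndex  # package name assumed

manager = EmbeddingManager(model_name="clip")
index = FAISSIndex(dimension=manager.get_dimension())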

View File: base.py

@@ -0,0 +1,100 @@
"""
Abstract base class for embedding models.
This module defines the interface that all embedding models must implement,
ensuring consistency across different model implementations (CLIP, Pix2Struct, etc.).
"""
from abc import ABC, abstractmethod
from typing import List
from PIL import Image
import numpy as np
class EmbedderBase(ABC):
"""
Abstract base class for image embedding models.
All embedding models must implement this interface to ensure
compatibility with the workflow matching system.
"""
@abstractmethod
def embed(self, image: Image.Image) -> np.ndarray:
"""
Generate an embedding vector for a single image.
Args:
image: PIL Image to embed
Returns:
np.ndarray: Normalized embedding vector of shape (dimension,)
The vector should be L2-normalized for cosine similarity
Raises:
ValueError: If image is invalid or cannot be processed
RuntimeError: If model inference fails
"""
pass
@abstractmethod
def get_dimension(self) -> int:
"""
Get the dimensionality of embeddings produced by this model.
Returns:
int: Embedding dimension (e.g., 512 for CLIP ViT-B/32, 768 for Pix2Struct)
"""
pass
@abstractmethod
def get_model_name(self) -> str:
"""
Get a unique identifier for this model.
Returns:
str: Model name (e.g., "clip-vit-b32", "pix2struct-base")
"""
pass
@abstractmethod
def supports_batch(self) -> bool:
"""
Check if this model supports batch processing.
Returns:
bool: True if embed_batch() is optimized, False otherwise
"""
pass
def embed_batch(self, images: List[Image.Image]) -> np.ndarray:
"""
Generate embeddings for multiple images.
Default implementation processes images one by one.
Subclasses can override this for optimized batch processing.
Args:
images: List of PIL Images to embed
Returns:
np.ndarray: Array of embeddings with shape (len(images), dimension)
Each row is a normalized embedding vector
Raises:
ValueError: If any image is invalid
RuntimeError: If model inference fails
"""
if not images:
return np.array([]).reshape(0, self.get_dimension())
embeddings = []
for img in images:
embedding = self.embed(img)
embeddings.append(embedding)
return np.array(embeddings)
def __repr__(self) -> str:
"""String representation of the embedder."""
return f"{self.__class__.__name__}(model={self.get_model_name()}, dim={self.get_dimension()})"

View File: clip_embedder.py

@@ -0,0 +1,358 @@
"""
CLIP-based embedder implementation.
This module provides a wrapper around OpenCLIP for generating image embeddings
using the CLIP (Contrastive Language-Image Pre-training) model.
"""
import torch
import numpy as np
from PIL import Image
from typing import List, Optional
import logging
try:
import open_clip
except ImportError:
open_clip = None
from .base import EmbedderBase
logger = logging.getLogger(__name__)
class CLIPEmbedder(EmbedderBase):
"""
CLIP-based image embedder using OpenCLIP.
This embedder uses the ViT-B/32 architecture by default, which produces
512-dimensional embeddings. It automatically handles GPU/CPU device selection.
"""
def __init__(
self,
model_name: str = "ViT-B-32",
pretrained: str = "openai",
device: Optional[str] = None
):
"""
Initialize the CLIP embedder.
Args:
model_name: CLIP model architecture (default: ViT-B-32)
pretrained: Pretrained weights to use (default: openai)
device: Device to use ('cuda', 'cpu', or None for auto-detect)
Note: Defaults to CPU to save GPU memory for other models
Raises:
ImportError: If open_clip is not installed
RuntimeError: If model loading fails
"""
if open_clip is None:
raise ImportError(
"OpenCLIP is not installed. "
"Install it with: pip install open-clip-torch"
)
# Default to CPU to save GPU for vision models (Qwen3-VL)
if device is None:
device = "cpu"
self.model_name = model_name
self.pretrained = pretrained
self.device = device
self._embedding_dim = None
# Load model
try:
self.model, _, self.preprocess = open_clip.create_model_and_transforms(
model_name,
pretrained=pretrained,
device=device
)
self.model.eval()
# Determine embedding dimension
with torch.no_grad():
dummy_image = torch.zeros(1, 3, 224, 224).to(self.device)
dummy_embedding = self.model.encode_image(dummy_image)
self._embedding_dim = dummy_embedding.shape[-1]
logger.info(
f"CLIPEmbedder loaded: {model_name} on {device}, "
f"dimension={self._embedding_dim}"
)
except Exception as e:
raise RuntimeError(f"Failed to load CLIP model: {e}")
def embed(self, image: Image.Image) -> np.ndarray:
"""
Generate embedding for a single image.
Args:
image: PIL Image to embed
Returns:
np.ndarray: Normalized embedding vector of shape (dimension,)
Raises:
ValueError: If image is invalid
RuntimeError: If embedding generation fails
"""
if not isinstance(image, Image.Image):
raise ValueError("Input must be a PIL Image")
try:
# Preprocess image
image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
# Generate embedding
with torch.no_grad():
embedding = self.model.encode_image(image_tensor)
# L2 normalize for cosine similarity
embedding = embedding / embedding.norm(dim=-1, keepdim=True)
return embedding.cpu().numpy().flatten()
except Exception as e:
raise RuntimeError(f"Failed to generate embedding: {e}")
def embed_batch(self, images: List[Image.Image]) -> np.ndarray:
"""
Generate embeddings for multiple images (optimized batch processing).
Args:
images: List of PIL Images to embed
Returns:
np.ndarray: Array of embeddings with shape (len(images), dimension)
Raises:
ValueError: If any image is invalid
RuntimeError: If embedding generation fails
"""
if not images:
return np.array([]).reshape(0, self.get_dimension())
# Validate all images
for i, img in enumerate(images):
if not isinstance(img, Image.Image):
raise ValueError(f"Image at index {i} is not a PIL Image")
try:
# Preprocess all images
image_tensors = torch.stack([
self.preprocess(img) for img in images
]).to(self.device)
# Generate embeddings in batch
with torch.no_grad():
embeddings = self.model.encode_image(image_tensors)
# L2 normalize for cosine similarity
embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
return embeddings.cpu().numpy()
except Exception as e:
raise RuntimeError(f"Failed to generate batch embeddings: {e}")
def get_dimension(self) -> int:
"""
Get the dimensionality of embeddings.
Returns:
int: Embedding dimension (512 for ViT-B/32)
"""
return self._embedding_dim
def get_model_name(self) -> str:
"""
Get model identifier.
Returns:
str: Model name (e.g., "clip-vit-b-32")
"""
return f"clip-{self.model_name.lower().replace('/', '-')}"
def supports_batch(self) -> bool:
"""
Check if batch processing is supported.
Returns:
bool: True (CLIP supports efficient batch processing)
"""
return True
def fine_tune(
self,
positive_images: List[Image.Image],
negative_images: List[Image.Image],
epochs: int = 1,
learning_rate: float = 1e-4
) -> dict:
"""
Fine-tune the model using contrastive learning.
This method fine-tunes only the final projection layer to adapt
the model to user-specific workflows. It uses a simple contrastive
loss: positive examples should be similar, negative examples should
be dissimilar.
Args:
positive_images: Images from successful workflows
negative_images: Images from rejected workflows
epochs: Number of training epochs (default: 1 for speed)
learning_rate: Learning rate (default: 1e-4)
Returns:
dict: Training metrics (loss, accuracy, etc.)
"""
if not positive_images and not negative_images:
return {'loss': 0.0, 'note': 'No examples to train on'}
# Set model to training mode
self.model.train()
# Only train the visual projection layer (last layer)
# Freeze all other parameters
for param in self.model.parameters():
param.requires_grad = False
# Unfreeze visual projection
if hasattr(self.model, 'visual') and hasattr(self.model.visual, 'proj'):
if self.model.visual.proj is not None:
self.model.visual.proj.requires_grad = True
# Setup optimizer (only for trainable parameters)
trainable_params = [p for p in self.model.parameters() if p.requires_grad]
if not trainable_params:
logger.warning("No trainable parameters found, skipping fine-tuning")
self.model.eval()
return {'loss': 0.0, 'note': 'No trainable parameters'}
optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)
total_loss = 0.0
num_batches = 0
try:
for epoch in range(epochs):
    # Process positive examples (should be similar to each other)
    if len(positive_images) >= 2:
        pos_loss = self._contrastive_loss_positive(positive_images)
        optimizer.zero_grad()
        pos_loss.backward()
        optimizer.step()
        # Accumulate as a Python float so the autograd graph is released
        total_loss += pos_loss.item()
        num_batches += 1
    # Process negative examples (should be dissimilar from positives)
    if positive_images and negative_images:
        neg_loss = self._contrastive_loss_negative(
            positive_images[:5],  # Use subset for speed
            negative_images[:5]
        )
        optimizer.zero_grad()
        neg_loss.backward()
        optimizer.step()
        total_loss += neg_loss.item()
        num_batches += 1
finally:
# Always restore eval mode and freeze parameters
self.model.eval()
for param in self.model.parameters():
param.requires_grad = False
avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
return {
'loss': float(avg_loss),
'epochs': epochs,
'learning_rate': learning_rate,
'positive_count': len(positive_images),
'negative_count': len(negative_images),
'num_batches': num_batches
}
def _contrastive_loss_positive(self, images: List[Image.Image]) -> torch.Tensor:
"""
Contrastive loss for positive examples (should be similar).
Args:
images: List of positive example images
Returns:
torch.Tensor: Loss value
"""
# Generate embeddings
embeddings = []
for img in images:
img_tensor = self.preprocess(img).unsqueeze(0).to(self.device)
emb = self.model.encode_image(img_tensor)
emb = emb / emb.norm(dim=-1, keepdim=True)
embeddings.append(emb)
embeddings = torch.cat(embeddings, dim=0)
# Compute pairwise cosine similarities
similarities = torch.mm(embeddings, embeddings.t())
# Loss: maximize similarity between positives
# Exclude the diagonal (self-similarity): it is always 1 and would
# otherwise add a constant offset to the mean
mask = torch.eye(len(images), device=self.device).bool()
# We want high similarity, so minimize (1 - similarity) over off-diagonal pairs
loss = (1 - similarities[~mask]).mean()
return loss
def _contrastive_loss_negative(
self,
positive_images: List[Image.Image],
negative_images: List[Image.Image]
) -> torch.Tensor:
"""
Contrastive loss for negative examples (should be dissimilar).
Args:
positive_images: Positive example images
negative_images: Negative example images
Returns:
torch.Tensor: Loss value
"""
# Generate embeddings for positives
pos_embeddings = []
for img in positive_images:
img_tensor = self.preprocess(img).unsqueeze(0).to(self.device)
emb = self.model.encode_image(img_tensor)
emb = emb / emb.norm(dim=-1, keepdim=True)
pos_embeddings.append(emb)
pos_embeddings = torch.cat(pos_embeddings, dim=0)
# Generate embeddings for negatives
neg_embeddings = []
for img in negative_images:
img_tensor = self.preprocess(img).unsqueeze(0).to(self.device)
emb = self.model.encode_image(img_tensor)
emb = emb / emb.norm(dim=-1, keepdim=True)
neg_embeddings.append(emb)
neg_embeddings = torch.cat(neg_embeddings, dim=0)
# Compute cross-similarities (positive vs negative)
similarities = torch.mm(pos_embeddings, neg_embeddings.t())
# We want low similarity, so minimize similarity directly
# (or maximize dissimilarity)
loss = similarities.mean()
return loss
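
A quick usage sketch (paths are placeholders, not files from this commit):

from PIL import Image

embedder = CLIPEmbedder()  # ViT-B-32 on CPU by default
a = embedder.embed(Image.open("before.png").convert("RGB"))
b = embedder.embed(Image.open("after.png").convert("RGB"))
print(float(a @ b))  # embeddings are L2-normalized, so the dot product is the cosine similarity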

View File: embedding_manager.py

@@ -0,0 +1,309 @@
"""
Embedding manager with model selection, caching, and fallback.
This module provides a high-level interface for generating embeddings,
with automatic model selection, LRU caching, and fallback to CLIP if
the selected model fails to load.
"""
import hashlib
import logging
from typing import Optional, Dict, Any, List
from collections import OrderedDict
from PIL import Image
import numpy as np
from .base import EmbedderBase
from .clip_embedder import CLIPEmbedder
logger = logging.getLogger(__name__)
class EmbeddingManager:
"""
High-level manager for image embeddings.
Features:
- Model selection (CLIP, Pix2Struct, etc.)
- Automatic fallback to CLIP on errors
- LRU cache (1000 entries) for performance
- GPU/CPU management
- Logging and monitoring
"""
def __init__(
self,
model_name: str = "clip",
fallback_enabled: bool = True,
cache_size: int = 1000,
device: Optional[str] = None
):
"""
Initialize the embedding manager.
Args:
model_name: Model to use ("clip" or "pix2struct")
fallback_enabled: If True, fallback to CLIP on model load failure
cache_size: Maximum number of cached embeddings (LRU eviction)
device: Device to use ('cuda', 'cpu', or None for auto)
Raises:
RuntimeError: If model loading fails and fallback is disabled
"""
self.model_name = model_name.lower()
self.fallback_enabled = fallback_enabled
self.cache_size = cache_size
self.device = device
# Initialize embedder
self.embedder = self._load_embedder()
# Initialize LRU cache
self._cache: OrderedDict[str, np.ndarray] = OrderedDict()
# Statistics
self._cache_hits = 0
self._cache_misses = 0
logger.info(
f"EmbeddingManager initialized: model={self.embedder.get_model_name()}, "
f"dimension={self.embedder.get_dimension()}, "
f"cache_size={cache_size}"
)
def _load_embedder(self) -> EmbedderBase:
"""
Load the specified embedder with fallback support.
Returns:
EmbedderBase: Loaded embedder instance
Raises:
RuntimeError: If loading fails and fallback is disabled
"""
try:
if self.model_name == "clip":
return CLIPEmbedder(device=self.device)
elif self.model_name == "pix2struct":
# Import here to avoid dependency if not used
try:
from .pix2struct_embedder import Pix2StructEmbedder
return Pix2StructEmbedder(device=self.device)
except ImportError as e:
if self.fallback_enabled:
logger.warning(
f"Pix2Struct not available ({e}), falling back to CLIP"
)
return CLIPEmbedder(device=self.device)
raise
else:
raise ValueError(f"Unknown model: {self.model_name}")
except Exception as e:
if self.fallback_enabled:
logger.warning(
f"Failed to load {self.model_name} ({e}), falling back to CLIP"
)
return CLIPEmbedder(device=self.device)
raise RuntimeError(f"Failed to load embedder: {e}")
def embed(self, image: Image.Image, use_cache: bool = True) -> np.ndarray:
"""
Generate embedding for an image with caching.
Args:
image: PIL Image to embed
use_cache: If True, use cache for identical images
Returns:
np.ndarray: Normalized embedding vector
Raises:
ValueError: If image is invalid
RuntimeError: If embedding generation fails
"""
if not isinstance(image, Image.Image):
raise ValueError("Input must be a PIL Image")
# Check cache if enabled
if use_cache:
cache_key = self._get_cache_key(image)
if cache_key in self._cache:
# Move to end (most recently used)
self._cache.move_to_end(cache_key)
self._cache_hits += 1
logger.debug(f"Cache hit (total: {self._cache_hits})")
return self._cache[cache_key]
self._cache_misses += 1
# Generate embedding
embedding = self.embedder.embed(image)
# Store in cache
if use_cache:
self._add_to_cache(cache_key, embedding)
return embedding
def embed_batch(
self,
images: List[Image.Image],
use_cache: bool = True
) -> np.ndarray:
"""
Generate embeddings for multiple images.
This method checks cache for each image individually and only
generates embeddings for cache misses.
Args:
images: List of PIL Images to embed
use_cache: If True, use cache for identical images
Returns:
np.ndarray: Array of embeddings (len(images), dimension)
Raises:
ValueError: If any image is invalid
RuntimeError: If embedding generation fails
"""
if not images:
return np.array([]).reshape(0, self.get_dimension())
embeddings = []
images_to_embed = []
indices_to_embed = []
# Check cache for each image
for i, img in enumerate(images):
if not isinstance(img, Image.Image):
raise ValueError(f"Image at index {i} is not a PIL Image")
if use_cache:
cache_key = self._get_cache_key(img)
if cache_key in self._cache:
self._cache.move_to_end(cache_key)
self._cache_hits += 1
embeddings.append((i, self._cache[cache_key]))
continue
self._cache_misses += 1
# Need to generate embedding
images_to_embed.append(img)
indices_to_embed.append(i)
# Generate embeddings for cache misses
if images_to_embed:
if self.embedder.supports_batch():
new_embeddings = self.embedder.embed_batch(images_to_embed)
else:
new_embeddings = np.array([
self.embedder.embed(img) for img in images_to_embed
])
# Add to cache and results
for img, idx, emb in zip(images_to_embed, indices_to_embed, new_embeddings):
if use_cache:
cache_key = self._get_cache_key(img)
self._add_to_cache(cache_key, emb)
embeddings.append((idx, emb))
# Sort by original index and extract embeddings
embeddings.sort(key=lambda x: x[0])
return np.array([emb for _, emb in embeddings])
def _get_cache_key(self, image: Image.Image) -> str:
    """
    Generate cache key from image content.
    Hashes the raw pixel bytes together with the image size and mode,
    since tobytes() alone does not encode the dimensions.
    Args:
        image: PIL Image
    Returns:
        str: Cache key (MD5 hash)
    """
    digest = hashlib.md5(image.tobytes())
    digest.update(f"{image.size}:{image.mode}".encode())
    return digest.hexdigest()
def _add_to_cache(self, key: str, embedding: np.ndarray):
"""
Add embedding to cache with LRU eviction.
Args:
key: Cache key
embedding: Embedding to cache
"""
# Add to cache
self._cache[key] = embedding
# Evict oldest if cache is full
if len(self._cache) > self.cache_size:
oldest_key = next(iter(self._cache))
del self._cache[oldest_key]
logger.debug(f"Cache eviction: size={len(self._cache)}")
def clear_cache(self):
"""Clear all cached embeddings."""
self._cache.clear()
logger.info("Cache cleared")
def get_dimension(self) -> int:
"""
Get embedding dimension.
Returns:
int: Embedding dimension
"""
return self.embedder.get_dimension()
def get_model_name(self) -> str:
"""
Get current model name.
Returns:
str: Model identifier
"""
return self.embedder.get_model_name()
def get_stats(self) -> Dict[str, Any]:
"""
Get manager statistics.
Returns:
Dict with keys:
- model_name: Current model
- dimension: Embedding dimension
- cache_size: Current cache size
- cache_capacity: Maximum cache size
- cache_hits: Number of cache hits
- cache_misses: Number of cache misses
- cache_hit_rate: Hit rate (0-1)
"""
total_requests = self._cache_hits + self._cache_misses
hit_rate = self._cache_hits / total_requests if total_requests > 0 else 0.0
return {
'model_name': self.get_model_name(),
'dimension': self.get_dimension(),
'cache_size': len(self._cache),
'cache_capacity': self.cache_size,
'cache_hits': self._cache_hits,
'cache_misses': self._cache_misses,
'cache_hit_rate': hit_rate
}
def __repr__(self) -> str:
"""String representation."""
stats = self.get_stats()
return (
f"EmbeddingManager(model={stats['model_name']}, "
f"cache={stats['cache_size']}/{stats['cache_capacity']}, "
f"hit_rate={stats['cache_hit_rate']:.2%})"
)
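
A usage sketch for the manager and its cache (the path is a placeholder):

from PIL import Image

manager = EmbeddingManager(model_name="pix2struct", fallback_enabled=True)
img = Image.open("screenshot.png").convert("RGB")  # placeholder path
emb1 = manager.embed(img)  # cache miss: runs the model
emb2 = manager.embed(img)  # cache hit: served from the LRU cache
print(manager.get_stats()["cache_hit_rate"])  # 0.5 after one miss and one hit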

View File: faiss_index.py

@@ -0,0 +1,309 @@
"""
FAISS index wrapper with proper dimension handling and persistence.
This module provides a robust wrapper around FAISS for storing and searching
image embeddings, with proper error handling for dimension mismatches and
reliable save/load functionality.
"""
import pickle
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional
import numpy as np
try:
import faiss
except ImportError:
faiss = None
logger = logging.getLogger(__name__)
class FAISSIndex:
"""
Wrapper around FAISS index with metadata storage and dimension validation.
This class handles:
- Dimension validation on add/search operations
- Metadata storage alongside embeddings
- Reliable persistence (save/load)
- Automatic index rebuilding on dimension changes
"""
def __init__(self, dimension: int):
"""
Initialize a new FAISS index.
Args:
dimension: Embedding dimension (e.g., 512 for CLIP, 768 for Pix2Struct)
Raises:
ImportError: If FAISS is not installed
ValueError: If dimension is invalid
"""
if faiss is None:
raise ImportError(
"FAISS is not installed. "
"Install it with: pip install faiss-cpu or faiss-gpu"
)
if dimension <= 0:
raise ValueError(f"Dimension must be positive, got {dimension}")
self.dimension = dimension
self.index = faiss.IndexFlatL2(dimension)
self.metadata: List[Dict[str, Any]] = []
logger.info(f"FAISSIndex created with dimension={dimension}")
def add(self, embeddings: np.ndarray, metadata: List[Dict[str, Any]]):
"""
Add embeddings to the index with associated metadata.
Args:
embeddings: Array of shape (N, dimension) containing N embeddings
metadata: List of N metadata dictionaries
Raises:
ValueError: If dimensions don't match or array shapes are invalid
"""
# Validate input shape
if embeddings.ndim == 1:
# Single embedding, reshape to (1, dimension)
embeddings = embeddings.reshape(1, -1)
elif embeddings.ndim != 2:
raise ValueError(
f"Embeddings must be 1D or 2D array, got shape {embeddings.shape}"
)
# Validate dimension
if embeddings.shape[1] != self.dimension:
raise ValueError(
f"Embedding dimension {embeddings.shape[1]} doesn't match "
f"index dimension {self.dimension}"
)
# Validate metadata count
if len(metadata) != embeddings.shape[0]:
raise ValueError(
f"Number of metadata entries ({len(metadata)}) doesn't match "
f"number of embeddings ({embeddings.shape[0]})"
)
# Add to FAISS index
self.index.add(embeddings.astype('float32'))
# Store metadata
self.metadata.extend(metadata)
logger.debug(
f"Added {embeddings.shape[0]} embeddings to index "
f"(total: {self.index.ntotal})"
)
def search(
self,
query: np.ndarray,
k: int = 5
) -> List[Dict[str, Any]]:
"""
Search for the k most similar embeddings.
Args:
query: Query embedding of shape (dimension,) or (1, dimension)
k: Number of results to return
Returns:
List of dicts with keys:
- 'index': Index in the FAISS index
- 'distance': Squared L2 distance (FAISS IndexFlatL2 reports squared distances)
- 'similarity': Similarity score, computed as 1 / (1 + distance); for L2-normalized embeddings the squared L2 distance is a monotonic function of cosine similarity (d^2 = 2(1 - cos))
- 'metadata': Associated metadata dict
Raises:
ValueError: If query dimension doesn't match index dimension
"""
if self.index.ntotal == 0:
logger.warning("Search called on empty index")
return []
# Reshape query if needed
if query.ndim == 1:
query = query.reshape(1, -1)
elif query.ndim != 2:
raise ValueError(
f"Query must be 1D or 2D array, got shape {query.shape}"
)
# Validate dimension
if query.shape[1] != self.dimension:
raise ValueError(
f"Query dimension {query.shape[1]} doesn't match "
f"index dimension {self.dimension}"
)
# Limit k to available embeddings
k = min(k, self.index.ntotal)
# Search
distances, indices = self.index.search(query.astype('float32'), k)
# Format results
results = []
for dist, idx in zip(distances[0], indices[0]):
# FAISS returns -1 if not enough results
if idx >= 0 and idx < len(self.metadata):
results.append({
'index': int(idx),
'distance': float(dist),
'similarity': float(1.0 / (1.0 + dist)),
'metadata': self.metadata[idx]
})
return results
def save(self, path: str):
"""
Save index and metadata to disk.
Args:
path: Base path for saving (will create .index and .metadata files)
Raises:
RuntimeError: If save operation fails
"""
try:
path_obj = Path(path)
path_obj.parent.mkdir(parents=True, exist_ok=True)
# Save FAISS index
index_file = f"{path}.index"
faiss.write_index(self.index, index_file)
# Save metadata
metadata_file = f"{path}.metadata"
with open(metadata_file, 'wb') as f:
pickle.dump({
'dimension': self.dimension,
'metadata': self.metadata
}, f)
logger.info(
f"Saved index with {self.index.ntotal} embeddings to {path}"
)
except Exception as e:
raise RuntimeError(f"Failed to save index: {e}")
def load(self, path: str):
"""
Load index and metadata from disk.
Args:
path: Base path for loading (will read .index and .metadata files)
Raises:
FileNotFoundError: If files don't exist
RuntimeError: If load operation fails or dimension mismatch
"""
try:
index_file = f"{path}.index"
metadata_file = f"{path}.metadata"
# Check files exist
if not Path(index_file).exists():
raise FileNotFoundError(f"Index file not found: {index_file}")
if not Path(metadata_file).exists():
raise FileNotFoundError(f"Metadata file not found: {metadata_file}")
# Load FAISS index
loaded_index = faiss.read_index(index_file)
# Load metadata
with open(metadata_file, 'rb') as f:
data = pickle.load(f)
loaded_dimension = data['dimension']
loaded_metadata = data['metadata']
# Validate dimension
if loaded_dimension != self.dimension:
raise RuntimeError(
f"Loaded index dimension ({loaded_dimension}) doesn't match "
f"current dimension ({self.dimension}). "
f"Use rebuild_if_needed() to handle dimension changes."
)
# Update state
self.index = loaded_index
self.metadata = loaded_metadata
logger.info(
f"Loaded index with {self.index.ntotal} embeddings from {path}"
)
except Exception as e:
if isinstance(e, (FileNotFoundError, RuntimeError)):
raise
raise RuntimeError(f"Failed to load index: {e}")
def rebuild_if_needed(self, new_dimension: int) -> bool:
"""
Rebuild index if dimension has changed.
This creates a new empty index with the new dimension.
Old embeddings are lost and need to be regenerated.
Args:
new_dimension: New embedding dimension
Returns:
bool: True if index was rebuilt, False if dimension unchanged
"""
if new_dimension == self.dimension:
return False
logger.warning(
f"Rebuilding FAISS index: dimension changed from "
f"{self.dimension} to {new_dimension}. "
f"Old embeddings ({self.index.ntotal}) will be lost."
)
# Create new index
self.dimension = new_dimension
self.index = faiss.IndexFlatL2(new_dimension)
self.metadata = []
return True
def clear(self):
"""Clear all embeddings from the index."""
self.index = faiss.IndexFlatL2(self.dimension)
self.metadata = []
logger.info("Index cleared")
def get_stats(self) -> Dict[str, Any]:
"""
Get index statistics.
Returns:
Dict with keys: num_embeddings, dimension, is_trained
"""
return {
'num_embeddings': self.index.ntotal,
'dimension': self.dimension,
'is_trained': self.index.is_trained
}
def __len__(self) -> int:
"""Return number of embeddings in the index."""
return self.index.ntotal
def __repr__(self) -> str:
"""String representation of the index."""
return (
f"FAISSIndex(dimension={self.dimension}, "
f"num_embeddings={self.index.ntotal})"
)
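
A self-contained usage sketch with random L2-normalized vectors (the save path is a placeholder):

import numpy as np

index = FAISSIndex(dimension=512)
vecs = np.random.rand(3, 512).astype("float32")
vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)
index.add(vecs, metadata=[{"workflow_id": f"wf-{i}"} for i in range(3)])
for hit in index.search(vecs[0], k=2):
    print(hit["similarity"], hit["metadata"]["workflow_id"])
index.save("data/indexes/workflows")  # writes .index and .metadata files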

View File: fine_tuner.py

@@ -0,0 +1,321 @@
"""
Lightweight fine-tuner for embedding models.
This module provides incremental fine-tuning capabilities that run in the
background, adapting the embedding model to user-specific workflows over time.
"""
import threading
import time
import pickle
import logging
from collections import deque
from pathlib import Path
from typing import List, Dict, Any, Optional
from PIL import Image
import numpy as np
logger = logging.getLogger(__name__)
class LightweightFineTuner:
"""
Lightweight fine-tuner for incremental model adaptation.
This class collects positive and negative examples from user interactions
and periodically fine-tunes the embedding model to improve accuracy on
user-specific workflows.
Features:
- Automatic triggering after N examples
- Background training (non-blocking)
- Checkpoint save/load for recovery
- Metrics tracking
"""
def __init__(
self,
embedder,
trigger_threshold: int = 10,
max_examples: int = 1000,
checkpoint_dir: str = "data/fine_tuning"
):
"""
Initialize the fine-tuner.
Args:
embedder: Embedder instance to fine-tune (must support fine_tune method)
trigger_threshold: Number of new examples before triggering fine-tuning
max_examples: Maximum examples to keep (LRU eviction)
checkpoint_dir: Directory for saving checkpoints
"""
self.embedder = embedder
self.trigger_threshold = trigger_threshold
self.max_examples = max_examples
self.checkpoint_dir = Path(checkpoint_dir)
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
# Example storage (deque for automatic LRU)
self.positive_examples = deque(maxlen=max_examples)
self.negative_examples = deque(maxlen=max_examples)
# Training state
self.is_training = False
self.training_thread: Optional[threading.Thread] = None
self.last_training_time = 0
self.training_count = 0
# Metrics
self.metrics_history: List[Dict[str, Any]] = []
logger.info(
f"LightweightFineTuner initialized: "
f"trigger={trigger_threshold}, max_examples={max_examples}"
)
def add_positive_example(self, image: Image.Image, workflow_id: str, metadata: Optional[Dict] = None):
"""
Add a positive example (successful workflow execution).
Args:
image: Screenshot where workflow succeeded
workflow_id: ID of the successful workflow
metadata: Optional additional metadata
"""
example = {
'image': image,
'workflow_id': workflow_id,
'metadata': metadata or {},
'timestamp': time.time()
}
self.positive_examples.append(example)
logger.debug(
f"Added positive example: workflow={workflow_id}, "
f"total_positive={len(self.positive_examples)}"
)
self._check_trigger()
def add_negative_example(self, image: Image.Image, workflow_id: str, metadata: Optional[Dict] = None):
"""
Add a negative example (rejected workflow suggestion).
Args:
image: Screenshot where workflow was rejected
workflow_id: ID of the rejected workflow
metadata: Optional additional metadata
"""
example = {
'image': image,
'workflow_id': workflow_id,
'metadata': metadata or {},
'timestamp': time.time()
}
self.negative_examples.append(example)
logger.debug(
f"Added negative example: workflow={workflow_id}, "
f"total_negative={len(self.negative_examples)}"
)
self._check_trigger()
def _check_trigger(self):
"""Check if we should trigger fine-tuning."""
total_new = len(self.positive_examples) + len(self.negative_examples)
# Don't trigger if already training
if self.is_training:
logger.debug("Fine-tuning already in progress, skipping trigger check")
return
# Check if we have enough examples
if total_new >= self.trigger_threshold:
logger.info(
f"Fine-tuning triggered: {total_new} examples "
f"({len(self.positive_examples)} positive, "
f"{len(self.negative_examples)} negative)"
)
self._start_training()
def _start_training(self):
    """Start training in background thread."""
    if self.is_training:
        logger.warning("Training already in progress")
        return
    # Set the flag before launching the thread, so a rapid series of
    # add_*_example() calls cannot start two training threads
    self.is_training = True
    self.training_thread = threading.Thread(
        target=self._train,
        name="FineTuningThread",
        daemon=True
    )
    self.training_thread.start()
    logger.info("Fine-tuning thread started")
def _train(self):
"""Fine-tune the model (runs in background thread)."""
self.is_training = True
start_time = time.time()
try:
# Check if embedder supports fine-tuning
if not hasattr(self.embedder, 'fine_tune'):
logger.info(
f"Embedder {self.embedder.get_model_name()} doesn't support "
f"fine-tuning, skipping"
)
return
# Prepare training data
positive_images = [ex['image'] for ex in self.positive_examples]
negative_images = [ex['image'] for ex in self.negative_examples]
if not positive_images and not negative_images:
logger.warning("No examples to train on")
return
logger.info(
f"Starting fine-tuning: {len(positive_images)} positive, "
f"{len(negative_images)} negative examples"
)
# Fine-tune (implementation depends on embedder)
metrics = self.embedder.fine_tune(
positive_images=positive_images,
negative_images=negative_images,
epochs=1,
learning_rate=1e-4
)
# Record metrics
duration = time.time() - start_time
metrics['duration_seconds'] = duration
metrics['timestamp'] = time.time()
metrics['positive_count'] = len(positive_images)
metrics['negative_count'] = len(negative_images)
metrics['training_number'] = self.training_count
self.metrics_history.append(metrics)
self.last_training_time = time.time()
self.training_count += 1
logger.info(
    f"Fine-tuning complete #{self.training_count}: "
    f"loss={metrics.get('loss', float('nan')):.4f}, "
    f"duration={duration:.1f}s"
)
# Clear examples after successful training
self.positive_examples.clear()
self.negative_examples.clear()
logger.debug("Training examples cleared")
except Exception as e:
logger.error(f"Fine-tuning failed: {e}", exc_info=True)
finally:
self.is_training = False
def save_checkpoint(self, name: str = "checkpoint"):
"""
Save training examples and metrics for recovery.
Args:
name: Checkpoint name
"""
try:
checkpoint_path = self.checkpoint_dir / f"{name}.pkl"
data = {
'positive_examples': list(self.positive_examples),
'negative_examples': list(self.negative_examples),
'metrics_history': self.metrics_history,
'training_count': self.training_count,
'last_training_time': self.last_training_time
}
with open(checkpoint_path, 'wb') as f:
pickle.dump(data, f)
logger.info(f"Checkpoint saved: {checkpoint_path}")
except Exception as e:
logger.error(f"Failed to save checkpoint: {e}")
def load_checkpoint(self, name: str = "checkpoint"):
"""
Load training examples and metrics from checkpoint.
Args:
name: Checkpoint name
Returns:
bool: True if loaded successfully, False otherwise
"""
try:
checkpoint_path = self.checkpoint_dir / f"{name}.pkl"
if not checkpoint_path.exists():
logger.warning(f"Checkpoint not found: {checkpoint_path}")
return False
with open(checkpoint_path, 'rb') as f:
data = pickle.load(f)
self.positive_examples.extend(data.get('positive_examples', []))
self.negative_examples.extend(data.get('negative_examples', []))
self.metrics_history = data.get('metrics_history', [])
self.training_count = data.get('training_count', 0)
self.last_training_time = data.get('last_training_time', 0)
logger.info(
f"Checkpoint loaded: {len(self.positive_examples)} positive, "
f"{len(self.negative_examples)} negative examples"
)
return True
except Exception as e:
logger.error(f"Failed to load checkpoint: {e}")
return False
def get_stats(self) -> Dict[str, Any]:
"""
Get fine-tuning statistics.
Returns:
Dict with statistics
"""
return {
'positive_examples': len(self.positive_examples),
'negative_examples': len(self.negative_examples),
'total_examples': len(self.positive_examples) + len(self.negative_examples),
'is_training': self.is_training,
'training_count': self.training_count,
'last_training_time': self.last_training_time,
'metrics_history': self.metrics_history,
'trigger_threshold': self.trigger_threshold
}
def wait_for_training(self, timeout: Optional[float] = None):
"""
Wait for current training to complete.
Args:
timeout: Maximum time to wait in seconds (None = wait forever)
"""
if self.training_thread and self.training_thread.is_alive():
self.training_thread.join(timeout=timeout)
def __repr__(self) -> str:
"""String representation."""
stats = self.get_stats()
return (
f"LightweightFineTuner("
f"examples={stats['total_examples']}, "
f"trainings={stats['training_count']}, "
f"is_training={stats['is_training']})"
)
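
A usage sketch, assuming the CLIPEmbedder from this commit (which implements fine_tune()) and a placeholder screenshot path:

from PIL import Image

embedder = CLIPEmbedder()
tuner = LightweightFineTuner(embedder, trigger_threshold=10)
img = Image.open("screenshot.png").convert("RGB")  # placeholder path
tuner.add_positive_example(img, workflow_id="wf-001")
tuner.add_negative_example(img, workflow_id="wf-002")
# Once trigger_threshold examples accumulate, a background training pass starts
tuner.wait_for_training(timeout=60)
tuner.save_checkpoint()
print(tuner.get_stats()["training_count"])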

View File: pix2struct_embedder.py

@@ -0,0 +1,193 @@
"""
Pix2Struct-based embedder implementation.
This module provides a wrapper around Google's Pix2Struct model for generating
image embeddings specialized for UI understanding and document screenshots.
"""
import torch
import numpy as np
from PIL import Image
from typing import List, Optional
import logging
try:
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
except ImportError:
Pix2StructProcessor = None
Pix2StructForConditionalGeneration = None
from .base import EmbedderBase
logger = logging.getLogger(__name__)
class Pix2StructEmbedder(EmbedderBase):
"""
Pix2Struct-based image embedder specialized for UI understanding.
Pix2Struct is a vision-language model trained on screenshots and structured
documents, making it particularly well-suited for RPA and UI automation tasks.
This embedder uses the encoder's hidden states as embeddings, which capture
visual features optimized for understanding UI elements and layouts.
"""
def __init__(
self,
model_name: str = "google/pix2struct-base",
device: Optional[str] = None
):
"""
Initialize the Pix2Struct embedder.
Args:
model_name: Pix2Struct model to use (default: google/pix2struct-base)
device: Device to use ('cuda', 'cpu', or None for auto-detect)
Raises:
ImportError: If transformers is not installed
RuntimeError: If model loading fails
"""
if Pix2StructProcessor is None or Pix2StructForConditionalGeneration is None:
raise ImportError(
"Transformers is not installed or version is too old. "
"Install it with: pip install transformers>=4.35.0"
)
# Auto-detect device
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.model_name = model_name
self.device = device
self._embedding_dim = None
# Load model and processor
try:
logger.info(f"Loading Pix2Struct model: {model_name}")
self.processor = Pix2StructProcessor.from_pretrained(model_name)
self.model = Pix2StructForConditionalGeneration.from_pretrained(
model_name
).to(device)
self.model.eval()
# Determine embedding dimension from encoder
with torch.no_grad():
dummy_image = Image.new('RGB', (224, 224), color=(128, 128, 128))
inputs = self.processor(images=dummy_image, return_tensors="pt").to(device)
encoder_outputs = self.model.encoder(**inputs)
# Use mean pooling of last hidden state
self._embedding_dim = encoder_outputs.last_hidden_state.shape[-1]
logger.info(
f"Pix2StructEmbedder loaded: {model_name} on {device}, "
f"dimension={self._embedding_dim}"
)
except Exception as e:
raise RuntimeError(f"Failed to load Pix2Struct model: {e}")
def embed(self, image: Image.Image) -> np.ndarray:
"""
Generate embedding for a single image.
Args:
image: PIL Image to embed
Returns:
np.ndarray: Normalized embedding vector of shape (dimension,)
Raises:
ValueError: If image is invalid
RuntimeError: If embedding generation fails
"""
if not isinstance(image, Image.Image):
raise ValueError("Input must be a PIL Image")
try:
# Process image
inputs = self.processor(images=image, return_tensors="pt").to(self.device)
# Generate embedding from encoder
with torch.no_grad():
    encoder_outputs = self.model.encoder(**inputs)
    # Masked mean pooling over the patch dimension, so padded
    # patches do not dilute the embedding
    hidden = encoder_outputs.last_hidden_state
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    embedding = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-6)
    # L2 normalize for cosine similarity
    embedding = embedding / embedding.norm(dim=-1, keepdim=True)
return embedding.cpu().numpy().flatten()
except Exception as e:
raise RuntimeError(f"Failed to generate embedding: {e}")
def embed_batch(self, images: List[Image.Image]) -> np.ndarray:
"""
Generate embeddings for multiple images (optimized batch processing).
Args:
images: List of PIL Images to embed
Returns:
np.ndarray: Array of embeddings with shape (len(images), dimension)
Raises:
ValueError: If any image is invalid
RuntimeError: If embedding generation fails
"""
if not images:
return np.array([]).reshape(0, self.get_dimension())
# Validate all images
for i, img in enumerate(images):
if not isinstance(img, Image.Image):
raise ValueError(f"Image at index {i} is not a PIL Image")
try:
# Process all images in batch
inputs = self.processor(images=images, return_tensors="pt").to(self.device)
# Generate embeddings in batch
with torch.no_grad():
    encoder_outputs = self.model.encoder(**inputs)
    # Masked mean pooling over the patch dimension; padding varies per
    # image in a batch, so this matters more than in the single case
    hidden = encoder_outputs.last_hidden_state
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-6)
    # L2 normalize for cosine similarity
    embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
return embeddings.cpu().numpy()
except Exception as e:
raise RuntimeError(f"Failed to generate batch embeddings: {e}")
def get_dimension(self) -> int:
"""
Get the dimensionality of embeddings.
Returns:
int: Embedding dimension (768 for pix2struct-base)
"""
return self._embedding_dim
def get_model_name(self) -> str:
"""
Get model identifier.
Returns:
str: Model name (e.g., "pix2struct-base")
"""
# Extract model name from full path
model_id = self.model_name.split('/')[-1]
return f"pix2struct-{model_id}" if not model_id.startswith('pix2struct') else model_id
def supports_batch(self) -> bool:
"""
Check if batch processing is supported.
Returns:
bool: True (Pix2Struct supports efficient batch processing)
"""
return True
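
A quick usage sketch (the path is a placeholder; the model weights are downloaded from the Hugging Face Hub on first use):

from PIL import Image

embedder = Pix2StructEmbedder()  # loads google/pix2struct-base
emb = embedder.embed(Image.open("dialog.png").convert("RGB"))  # placeholder path
print(embedder.get_model_name(), embedder.get_dimension())  # pix2struct-base 768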