# Design - Embedding System Improvements and Fine-tuning

## Overview

The current embedding system uses CLIP with FAISS but suffers from dimension-handling and performance problems. This improvement fixes FAISS, introduces an abstraction supporting multiple models (CLIP as the baseline, Pix2Struct for UI screenshots), and adds lightweight incremental fine-tuning that runs after every 10 new interactions to improve matching accuracy.

## Architecture

```
┌─────────────────────────────────────────────────────────────┐
│                      Embedding System                       │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   ┌──────────────┐      ┌──────────────┐                    │
│   │ EmbedderBase │◄─────│ CLIPEmbedder │                    │
│   │  (Abstract)  │      └──────────────┘                    │
│   └──────┬───────┘                                          │
│          │              ┌────────────────────┐              │
│          └──────────────│ Pix2StructEmbedder │              │
│                         └────────────────────┘              │
│                                                             │
│   ┌────────────────────────────────────────────────────┐    │
│   │                  EmbeddingManager                  │    │
│   │  - Model selection & fallback                      │    │
│   │  - Caching (LRU 1000 entries)                      │    │
│   │  - GPU/CPU management                              │    │
│   └────────────────────────────────────────────────────┘    │
│                                                             │
│   ┌────────────────────────────────────────────────────┐    │
│   │                     FAISSIndex                     │    │
│   │  - Fixed dimension handling                        │    │
│   │  - Persistence (index + metadata)                  │    │
│   │  - Auto-rebuild on dimension change                │    │
│   └────────────────────────────────────────────────────┘    │
│                                                             │
│   ┌────────────────────────────────────────────────────┐    │
│   │                LightweightFineTuner                │    │
│   │  - Collects positive/negative examples             │    │
│   │  - Triggers every 10 examples                      │    │
│   │  - Runs in background thread                       │    │
│   │  - Updates last layer only                         │    │
│   └────────────────────────────────────────────────────┘    │
│                                                             │
└─────────────────────────────────────────────────────────────┘
          │                    │                    │
          ▼                    ▼                    ▼
   WorkflowMatcher       VisionAnalysis       SessionManager
```

## Components and Interfaces

### 1. EmbedderBase (Abstract Interface)

**Responsibility**: Common interface for all embedding models

```python
from abc import ABC, abstractmethod
from typing import List

import numpy as np
from PIL import Image


class EmbedderBase(ABC):
    @abstractmethod
    def embed(self, image: Image.Image) -> np.ndarray:
        """Generate embedding for an image

        Returns:
            np.ndarray: Normalized embedding vector
        """
        pass

    @abstractmethod
    def get_dimension(self) -> int:
        """Get embedding dimension"""
        pass

    @abstractmethod
    def get_model_name(self) -> str:
        """Get model identifier"""
        pass

    @abstractmethod
    def supports_batch(self) -> bool:
        """Check if model supports batch processing"""
        pass

    def embed_batch(self, images: List[Image.Image]) -> np.ndarray:
        """Generate embeddings for multiple images (optional optimization)"""
        return np.array([self.embed(img) for img in images])
```

### 2. CLIPEmbedder (Concrete Implementation)

**Responsibility**: Wrapper exposing CLIP through the common interface

```python
import clip
import torch


class CLIPEmbedder(EmbedderBase):
    def __init__(self, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.model, self.preprocess = clip.load("ViT-B/32", device=device)
        self.model.eval()

    def embed(self, image: Image.Image) -> np.ndarray:
        with torch.no_grad():
            image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
            embedding = self.model.encode_image(image_tensor)
            embedding = embedding / embedding.norm(dim=-1, keepdim=True)
        return embedding.cpu().numpy().flatten()

    def get_dimension(self) -> int:
        return 512  # CLIP ViT-B/32 embedding dimension

    def get_model_name(self) -> str:
        return "clip-vit-b32"

    def supports_batch(self) -> bool:
        return True
```
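The interface contract can be exercised without loading any real model. The sketch below uses a hypothetical `DummyEmbedder` (a test stand-in, not part of the design) to show the two invariants every implementation must uphold: the output length equals `get_dimension()`, and the vector is L2-normalized.

```python
# Hypothetical DummyEmbedder exercising the EmbedderBase contract
# without loading CLIP or Pix2Struct.
import hashlib
from abc import ABC, abstractmethod

import numpy as np


class EmbedderBase(ABC):
    @abstractmethod
    def embed(self, image) -> np.ndarray: ...

    @abstractmethod
    def get_dimension(self) -> int: ...


class DummyEmbedder(EmbedderBase):
    """Maps image bytes to a deterministic pseudo-random unit vector."""

    def __init__(self, dimension: int = 512):
        self._dim = dimension

    def embed(self, image: bytes) -> np.ndarray:
        # Seed derived from content so equal inputs give equal embeddings
        seed = int.from_bytes(hashlib.sha256(image).digest()[:4], "big")
        vec = np.random.default_rng(seed).normal(size=self._dim)
        return vec / np.linalg.norm(vec)  # L2-normalize, like the real embedders

    def get_dimension(self) -> int:
        return self._dim


emb = DummyEmbedder()
vec = emb.embed(b"fake screenshot bytes")
assert vec.shape == (emb.get_dimension(),)                    # Property 1
assert abs(float(np.linalg.norm(vec)) - 1.0) < 1e-6           # normalized
assert np.allclose(vec, emb.embed(b"fake screenshot bytes"))  # deterministic
```

The same assertions apply verbatim to `CLIPEmbedder` and `Pix2StructEmbedder` in the property-based tests described below.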
### 3. Pix2StructEmbedder (Concrete Implementation)

**Responsibility**: Wrapper for the UI-specialized Pix2Struct model

```python
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration


class Pix2StructEmbedder(EmbedderBase):
    def __init__(self, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.processor = Pix2StructProcessor.from_pretrained(
            "google/pix2struct-base"
        )
        self.model = Pix2StructForConditionalGeneration.from_pretrained(
            "google/pix2struct-base"
        ).to(device)
        self.model.eval()

    def embed(self, image: Image.Image) -> np.ndarray:
        with torch.no_grad():
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)
            # Mean-pool the encoder hidden states into a single vector
            encoder_outputs = self.model.encoder(**inputs)
            embedding = encoder_outputs.last_hidden_state.mean(dim=1)
            embedding = embedding / embedding.norm(dim=-1, keepdim=True)
        return embedding.cpu().numpy().flatten()

    def get_dimension(self) -> int:
        return 768  # Pix2Struct base hidden size

    def get_model_name(self) -> str:
        return "pix2struct-base"

    def supports_batch(self) -> bool:
        return True
```

### 4. EmbeddingManager

**Responsibility**: Model management, caching, and selection

```python
import hashlib
import logging

logger = logging.getLogger(__name__)


class EmbeddingManager:
    def __init__(self, model_name: str = "clip", fallback: bool = True):
        self.model_name = model_name
        self.fallback_enabled = fallback
        self.embedder = self._load_embedder()
        self._cache = {}          # Manual LRU cache: image hash -> embedding
        self._cache_order = []    # Access order, least recently used first
        self.max_cache_size = 1000

    def _load_embedder(self) -> EmbedderBase:
        """Load embedder with fallback"""
        try:
            if self.model_name == "pix2struct":
                return Pix2StructEmbedder()
            elif self.model_name == "clip":
                return CLIPEmbedder()
            else:
                raise ValueError(f"Unknown model: {self.model_name}")
        except Exception as e:
            if self.fallback_enabled:
                logger.warning(f"Failed to load {self.model_name}, falling back to CLIP: {e}")
                return CLIPEmbedder()
            raise

    def embed(self, image: Image.Image) -> np.ndarray:
        """Generate embedding with caching"""
        # Cache key derived from the raw pixel bytes
        img_hash = hashlib.md5(image.tobytes()).hexdigest()
        if img_hash in self._cache:
            # Refresh recency on hit so eviction stays LRU, not FIFO
            self._cache_order.remove(img_hash)
            self._cache_order.append(img_hash)
            return self._cache[img_hash]

        # Generate embedding
        embedding = self.embedder.embed(image)

        # Update cache and evict the least recently used entry if needed
        self._cache[img_hash] = embedding
        self._cache_order.append(img_hash)
        if len(self._cache) > self.max_cache_size:
            oldest = self._cache_order.pop(0)
            del self._cache[oldest]
        return embedding

    def get_dimension(self) -> int:
        return self.embedder.get_dimension()

    def get_model_name(self) -> str:
        return self.embedder.get_model_name()
```
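The dict-plus-list bookkeeping above can be expressed more compactly with `collections.OrderedDict`, which makes both the hit-path recency refresh and the eviction explicit. A minimal standalone sketch (the `EmbeddingCache` class and `fake_embed` callback are hypothetical; the size is shrunk to 2 so eviction is visible):

```python
# Standalone sketch of the EmbeddingManager cache policy using OrderedDict.
import hashlib
from collections import OrderedDict


class EmbeddingCache:
    def __init__(self, max_size: int = 1000):
        self._cache = OrderedDict()
        self.max_size = max_size

    def get_or_compute(self, image_bytes: bytes, compute):
        key = hashlib.md5(image_bytes).hexdigest()
        if key in self._cache:
            self._cache.move_to_end(key)         # refresh recency on hit
            return self._cache[key]
        value = compute(image_bytes)
        self._cache[key] = value
        if len(self._cache) > self.max_size:
            self._cache.popitem(last=False)      # evict least recently used
        return value


calls = []
fake_embed = lambda b: calls.append(b) or len(b)  # stand-in for embedder.embed
cache = EmbeddingCache(max_size=2)
cache.get_or_compute(b"a", fake_embed)   # miss -> compute
cache.get_or_compute(b"b", fake_embed)   # miss -> compute
cache.get_or_compute(b"a", fake_embed)   # hit  -> "a" becomes most recent
cache.get_or_compute(b"c", fake_embed)   # miss -> evicts "b", not "a"
assert calls == [b"a", b"b", b"c"]
```

Validates Property 4 (cache hit consistency) and the eviction behavior listed under Error Handling.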
### 5. FAISSIndex (Fixed)

**Responsibility**: FAISS index management with the dimension bugs corrected

```python
import logging
import pickle
from typing import List

import faiss
import numpy as np

logger = logging.getLogger(__name__)


class FAISSIndex:
    def __init__(self, dimension: int):
        self.dimension = dimension
        self.index = faiss.IndexFlatL2(dimension)
        self.metadata = []  # Workflow IDs and step info, parallel to the index

    def add(self, embeddings: np.ndarray, metadata: List[dict]):
        """Add embeddings to index

        Args:
            embeddings: Shape (N, dimension)
            metadata: List of N metadata dicts
        """
        if embeddings.shape[1] != self.dimension:
            raise ValueError(
                f"Embedding dimension {embeddings.shape[1]} doesn't match "
                f"index dimension {self.dimension}"
            )
        self.index.add(embeddings.astype('float32'))
        self.metadata.extend(metadata)

    def search(self, query: np.ndarray, k: int = 5) -> List[dict]:
        """Search for similar embeddings

        Args:
            query: Shape (1, dimension) or (dimension,)
            k: Number of results

        Returns:
            List of dicts with 'distance', 'similarity', 'metadata'
        """
        if query.ndim == 1:
            query = query.reshape(1, -1)
        if query.shape[1] != self.dimension:
            raise ValueError(
                f"Query dimension {query.shape[1]} doesn't match "
                f"index dimension {self.dimension}"
            )
        distances, indices = self.index.search(query.astype('float32'), k)
        results = []
        for dist, idx in zip(distances[0], indices[0]):
            # FAISS pads with idx = -1 when fewer than k results exist,
            # so the lower bound check is required
            if 0 <= idx < len(self.metadata):
                results.append({
                    'distance': float(dist),
                    'similarity': 1.0 / (1.0 + dist),  # Map distance to (0, 1]
                    'metadata': self.metadata[idx]
                })
        return results

    def save(self, path: str):
        """Save index and metadata"""
        faiss.write_index(self.index, f"{path}.index")
        with open(f"{path}.metadata", 'wb') as f:
            pickle.dump({
                'dimension': self.dimension,
                'metadata': self.metadata
            }, f)

    def load(self, path: str):
        """Load index and metadata"""
        self.index = faiss.read_index(f"{path}.index")
        with open(f"{path}.metadata", 'rb') as f:
            data = pickle.load(f)
            self.dimension = data['dimension']
            self.metadata = data['metadata']

    def rebuild_if_needed(self, new_dimension: int) -> bool:
        """Rebuild index if the embedding dimension changed"""
        if new_dimension != self.dimension:
            logger.info(f"Rebuilding FAISS index: {self.dimension} -> {new_dimension}")
            self.dimension = new_dimension
            self.index = faiss.IndexFlatL2(new_dimension)
            self.metadata = []
            # Note: old embeddings are lost and must be regenerated
            return True
        return False
```

### 6. LightweightFineTuner

**Responsibility**: Incremental fine-tuning in the background

```python
import logging
import pickle
import threading
import time
from collections import deque

logger = logging.getLogger(__name__)


class LightweightFineTuner:
    def __init__(self, embedder: EmbedderBase, trigger_threshold: int = 10):
        self.embedder = embedder
        self.trigger_threshold = trigger_threshold
        self.positive_examples = deque(maxlen=1000)
        self.negative_examples = deque(maxlen=1000)
        self.is_training = False
        self.training_thread = None

    def add_positive_example(self, image: Image.Image, workflow_id: str):
        """Add a successful workflow execution example"""
        self.positive_examples.append({
            'image': image,
            'workflow_id': workflow_id,
            'timestamp': time.time()
        })
        self._check_trigger()

    def add_negative_example(self, image: Image.Image, workflow_id: str):
        """Add a rejected workflow suggestion example"""
        self.negative_examples.append({
            'image': image,
            'workflow_id': workflow_id,
            'timestamp': time.time()
        })
        self._check_trigger()

    def _check_trigger(self):
        """Check whether fine-tuning should be triggered"""
        total_new = len(self.positive_examples) + len(self.negative_examples)
        if total_new >= self.trigger_threshold and not self.is_training:
            logger.info(f"Triggering fine-tuning with {total_new} new examples")
            self._start_training()

    def _start_training(self):
        """Start training in a background thread"""
        self.training_thread = threading.Thread(target=self._train)
        self.training_thread.daemon = True
        self.training_thread.start()

    def _train(self):
        """Fine-tune the last layer of the model"""
        self.is_training = True
        try:
            # Only models that expose fine_tune() are trainable (not base CLIP)
            if not hasattr(self.embedder, 'fine_tune'):
                logger.info("Model doesn't support fine-tuning, skipping")
                return

            # Prepare training data
            positive_images = [ex['image'] for ex in self.positive_examples]
            negative_images = [ex['image'] for ex in self.negative_examples]

            # Fine-tune (implementation depends on the model)
            metrics = self.embedder.fine_tune(
                positive_images=positive_images,
                negative_images=negative_images,
                epochs=1,
                learning_rate=1e-4
            )
            logger.info(f"Fine-tuning complete: {metrics}")

            # Clear examples after training
            self.positive_examples.clear()
            self.negative_examples.clear()
        except Exception as e:
            logger.error(f"Fine-tuning failed: {e}")
        finally:
            self.is_training = False

    def save_checkpoint(self, path: str):
        """Save training examples for recovery"""
        with open(path, 'wb') as f:
            pickle.dump({
                'positive': list(self.positive_examples),
                'negative': list(self.negative_examples)
            }, f)

    def load_checkpoint(self, path: str):
        """Load training examples"""
        with open(path, 'rb') as f:
            data = pickle.load(f)
            self.positive_examples.extend(data['positive'])
            self.negative_examples.extend(data['negative'])
```

## Data Models

### EmbeddingMetadata

```python
from dataclasses import dataclass


@dataclass
class EmbeddingMetadata:
    workflow_id: str
    step_index: int
    action_type: str
    timestamp: float
    model_name: str
    model_version: str
```

### FineTuningMetrics

```python
@dataclass
class FineTuningMetrics:
    loss: float
    accuracy: float
    positive_examples: int
    negative_examples: int
    duration_seconds: float
    timestamp: float
```

## Correctness Properties

*A property is a characteristic or behavior that should hold true across all valid executions of a system: essentially, a formal statement about what the system should do.
Properties serve as the bridge between human-readable specifications and machine-verifiable correctness guarantees.*

### Property 1: Embedding Dimension Consistency

*For any* embedder instance, all generated embeddings SHALL have the same dimension as reported by `get_dimension()`

**Validates: Requirements 1.1, 2.2, 3.3**

### Property 2: FAISS Index Persistence Round-trip

*For any* FAISS index with embeddings, saving then loading SHALL restore the exact same search results

**Validates: Requirements 1.3, 1.4**

### Property 3: Embedder Interface Compatibility

*For any* embedder implementation (CLIP or Pix2Struct), calling `embed()` with a valid PIL image SHALL return a normalized numpy array

**Validates: Requirements 2.1, 2.2, 2.5**

### Property 4: Cache Hit Consistency

*For any* image, generating embeddings twice SHALL return identical results (via caching)

**Validates: Requirements 6.1**

### Property 5: Fallback Reliability

*For any* model loading failure, if fallback is enabled, the system SHALL successfully load CLIP without raising exceptions

**Validates: Requirements 2.4, 3.5**

### Property 6: Fine-tuning Non-blocking

*For any* fine-tuning operation, the main thread SHALL continue generating embeddings without blocking

**Validates: Requirements 5.1, 5.2**

### Property 7: Example Collection Bounds

*For any* sequence of positive/negative examples, the collection SHALL never exceed 1000 examples (deque maxlen)

**Validates: Requirements 4.1, 4.2**

## Error Handling

- **Model loading fails**: Fall back to CLIP if enabled, otherwise raise a clear error
- **Dimension mismatch**: Rebuild the FAISS index automatically, log a warning
- **GPU unavailable**: Fall back to CPU seamlessly
- **Fine-tuning fails**: Log the error, keep the current model, don't crash
- **Cache overflow**: Evict LRU entries automatically
- **Corrupted index file**: Rebuild from scratch, log the error

## Testing Strategy

### Unit Tests

- Test each embedder implementation independently
- Test FAISS index add/search/save/load operations
- Test cache eviction logic
- Test dimension mismatch handling
- Test the fallback mechanism

### Property-Based Tests

- Property 1: Dimension consistency across random images
- Property 2: FAISS round-trip with random embeddings
- Property 3: Interface compatibility with random PIL images
- Property 4: Cache consistency with duplicate images
- Property 5: Fallback with simulated failures
- Property 6: Fine-tuning non-blocking with concurrent operations
- Property 7: Example collection bounds with random additions

### Integration Tests

- Test the full pipeline: image → embedding → FAISS → search
- Test model switching (CLIP ↔ Pix2Struct)
- Test fine-tuning trigger and execution
- Test system behavior under GPU/CPU switching

## Performance

- **Embedding generation**: < 200 ms per image (GPU), < 500 ms (CPU)
- **FAISS search**: < 10 ms for k=5 in an index of 10,000 embeddings
- **Cache hit**: < 1 ms
- **Fine-tuning**: < 2 minutes for 100 examples
- **Model loading**: < 5 seconds (CLIP), < 10 seconds (Pix2Struct)
- **Memory usage**: ~2 GB (CLIP), ~4 GB (Pix2Struct), ~500 MB (FAISS index for 10k embeddings)
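The unit and property tests for `FAISSIndex.search` need a ground truth for exact search. The sketch below is a brute-force numpy reference that runs without the faiss package; `search_reference` and the fixture data are hypothetical test helpers, not part of the design. One detail it mirrors: `faiss.IndexFlatL2` reports squared L2 distances, which the `1 / (1 + d)` similarity mapping consumes directly.

```python
# Brute-force reference implementation of the FAISSIndex.search contract.
import numpy as np


def search_reference(index_vectors: np.ndarray, metadata: list,
                     query: np.ndarray, k: int = 5) -> list:
    """Exact L2 search the real index must agree with."""
    if query.ndim == 1:
        query = query.reshape(1, -1)
    if query.shape[1] != index_vectors.shape[1]:
        raise ValueError("query dimension doesn't match index dimension")
    d2 = ((index_vectors - query) ** 2).sum(axis=1)   # squared L2, like IndexFlatL2
    order = np.argsort(d2)[:k]
    return [{'distance': float(d2[i]),
             'similarity': 1.0 / (1.0 + float(d2[i])),
             'metadata': metadata[i]} for i in order]


rng = np.random.default_rng(0)
vecs = rng.normal(size=(100, 512)).astype('float32')
meta = [{'workflow_id': f'wf-{i}'} for i in range(100)]

hits = search_reference(vecs, meta, vecs[42], k=3)
assert hits[0]['metadata']['workflow_id'] == 'wf-42'  # exact match ranks first
assert hits[0]['distance'] == 0.0
```

A property-based test can then add random vectors to a real `FAISSIndex` and assert its results match this reference, covering Properties 1 and 2 and the k=5 search path benchmarked above.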