Geniusia_v2/.kiro/specs/embedding-improvement/design.md
2026-03-05 00:20:25 +01:00

# Design - Embedding System Improvements and Fine-tuning
## Overview
The current embedding system uses CLIP with FAISS but suffers from dimension-handling and performance problems. This improvement fixes FAISS, introduces an abstraction that supports multiple models (CLIP as baseline, Pix2Struct for UI), and adds lightweight incremental fine-tuning that runs after every 10 new interactions to improve accuracy.
## Architecture
```
┌───────────────────────────────────────────────────────┐
│                   Embedding System                    │
├───────────────────────────────────────────────────────┤
│                                                       │
│  ┌──────────────┐      ┌──────────────┐               │
│  │ EmbedderBase │◄─────│ CLIPEmbedder │               │
│  │  (Abstract)  │      └──────────────┘               │
│  └──────┬───────┘                                     │
│         │                                             │
│         │        ┌────────────────────┐               │
│         └────────│ Pix2StructEmbedder │               │
│                  └────────────────────┘               │
│                                                       │
│  ┌─────────────────────────────────────────────────┐  │
│  │ EmbeddingManager                                │  │
│  │  - Model selection & fallback                   │  │
│  │  - Caching (LRU, 1000 entries)                  │  │
│  │  - GPU/CPU management                           │  │
│  └─────────────────────────────────────────────────┘  │
│                                                       │
│  ┌─────────────────────────────────────────────────┐  │
│  │ FAISSIndex                                      │  │
│  │  - Fixed dimension handling                     │  │
│  │  - Persistence (index + metadata)               │  │
│  │  - Auto-rebuild on dimension change             │  │
│  └─────────────────────────────────────────────────┘  │
│                                                       │
│  ┌─────────────────────────────────────────────────┐  │
│  │ LightweightFineTuner                            │  │
│  │  - Collects positive/negative examples          │  │
│  │  - Triggers every 10 examples                   │  │
│  │  - Runs in background thread                    │  │
│  │  - Updates last layer only                      │  │
│  └─────────────────────────────────────────────────┘  │
│                                                       │
└───────────────────────────────────────────────────────┘
        │                  │                  │
        ▼                  ▼                  ▼
 WorkflowMatcher    VisionAnalysis    SessionManager
```
## Components and Interfaces
### 1. EmbedderBase (Abstract Interface)
**Responsibility**: Common interface for all embedding models
```python
from abc import ABC, abstractmethod
from typing import List

import numpy as np
from PIL import Image

class EmbedderBase(ABC):
    @abstractmethod
    def embed(self, image: Image.Image) -> np.ndarray:
        """Generate embedding for an image

        Returns:
            np.ndarray: Normalized embedding vector
        """
        pass

    @abstractmethod
    def get_dimension(self) -> int:
        """Get embedding dimension"""
        pass

    @abstractmethod
    def get_model_name(self) -> str:
        """Get model identifier"""
        pass

    @abstractmethod
    def supports_batch(self) -> bool:
        """Check if model supports batch processing"""
        pass

    def embed_batch(self, images: List[Image.Image]) -> np.ndarray:
        """Generate embeddings for multiple images (optional optimization)"""
        return np.array([self.embed(img) for img in images])
```
### 2. CLIPEmbedder (Concrete Implementation)
**Responsibility**: CLIP wrapper implementing the common interface
```python
import clip  # OpenAI CLIP package
import torch

class CLIPEmbedder(EmbedderBase):
    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.model, self.preprocess = clip.load("ViT-B/32", device=device)
        self.model.eval()

    def embed(self, image: Image.Image) -> np.ndarray:
        with torch.no_grad():
            image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
            embedding = self.model.encode_image(image_tensor)
            embedding = embedding / embedding.norm(dim=-1, keepdim=True)
        return embedding.cpu().numpy().flatten()

    def get_dimension(self) -> int:
        return 512  # CLIP ViT-B/32 output dimension

    def get_model_name(self) -> str:
        return "clip-vit-b32"

    def supports_batch(self) -> bool:
        return True
```
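The key invariant in `embed()` above is that the returned vector is L2-normalized before leaving the model. The normalization step can be illustrated in pure Python, independent of torch (a sketch for intuition, not part of the implementation):

```python
import math

def l2_normalize(vec):
    """L2-normalize a vector, as CLIPEmbedder.embed() does before returning."""
    norm = math.sqrt(sum(x * x for x in vec))
    if norm == 0.0:
        raise ValueError("cannot normalize the zero vector")
    return [x / norm for x in vec]

v = l2_normalize([3.0, 4.0])
print(v)  # [0.6, 0.8]
print(math.isclose(sum(x * x for x in v), 1.0))  # True: unit norm
```

Normalizing keeps all embeddings on the unit sphere, which makes the L2 distances produced by FAISS comparable across images.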
### 3. Pix2StructEmbedder (Concrete Implementation)
**Responsibility**: Wrapper for the UI-specialized Pix2Struct model
```python
import torch
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration

class Pix2StructEmbedder(EmbedderBase):
    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.processor = Pix2StructProcessor.from_pretrained(
            "google/pix2struct-base"
        )
        self.model = Pix2StructForConditionalGeneration.from_pretrained(
            "google/pix2struct-base"
        ).to(device)
        self.model.eval()

    def embed(self, image: Image.Image) -> np.ndarray:
        with torch.no_grad():
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)
            # Extract encoder hidden states as embeddings
            encoder_outputs = self.model.encoder(**inputs)
            embedding = encoder_outputs.last_hidden_state.mean(dim=1)
            embedding = embedding / embedding.norm(dim=-1, keepdim=True)
        return embedding.cpu().numpy().flatten()

    def get_dimension(self) -> int:
        return 768  # Pix2Struct base hidden size

    def get_model_name(self) -> str:
        return "pix2struct-base"

    def supports_batch(self) -> bool:
        return True
```
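The `last_hidden_state.mean(dim=1)` call collapses the per-patch encoder states into a single vector by averaging over the token axis. A minimal pure-Python sketch of that pooling step (nested lists standing in for the tensor; no transformers dependency):

```python
def mean_pool(hidden_states):
    """Mean-pool a (tokens x dim) matrix over the token axis,
    mirroring last_hidden_state.mean(dim=1) in Pix2StructEmbedder."""
    n_tokens = len(hidden_states)
    dim = len(hidden_states[0])
    return [sum(row[j] for row in hidden_states) / n_tokens for j in range(dim)]

# Three token vectors of dimension 2
states = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(mean_pool(states))  # [3.0, 4.0]
```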
### 4. EmbeddingManager
**Responsibility**: Model management, caching, and selection
```python
import hashlib
import logging

logger = logging.getLogger(__name__)

class EmbeddingManager:
    def __init__(self, model_name: str = "clip", fallback: bool = True):
        self.model_name = model_name
        self.fallback_enabled = fallback
        self.embedder = self._load_embedder()
        self._cache = {}        # Manual LRU cache: image hash -> embedding
        self._cache_order = []  # Access order, least recently used first
        self.max_cache_size = 1000

    def _load_embedder(self) -> EmbedderBase:
        """Load embedder with fallback"""
        try:
            if self.model_name == "pix2struct":
                return Pix2StructEmbedder()
            elif self.model_name == "clip":
                return CLIPEmbedder()
            else:
                raise ValueError(f"Unknown model: {self.model_name}")
        except Exception as e:
            if self.fallback_enabled:
                logger.warning(f"Failed to load {self.model_name}, falling back to CLIP: {e}")
                return CLIPEmbedder()
            raise

    def embed(self, image: Image.Image) -> np.ndarray:
        """Generate embedding with caching"""
        # Cache key from raw pixel bytes
        img_hash = hashlib.md5(image.tobytes()).hexdigest()
        if img_hash in self._cache:
            # Refresh recency on hit so eviction is true LRU, not FIFO
            self._cache_order.remove(img_hash)
            self._cache_order.append(img_hash)
            return self._cache[img_hash]
        # Generate embedding
        embedding = self.embedder.embed(image)
        # Update cache
        self._cache[img_hash] = embedding
        self._cache_order.append(img_hash)
        # Evict least recently used entry if over capacity
        if len(self._cache) > self.max_cache_size:
            oldest = self._cache_order.pop(0)
            del self._cache[oldest]
        return embedding

    def get_dimension(self) -> int:
        return self.embedder.get_dimension()

    def get_model_name(self) -> str:
        return self.embedder.get_model_name()
```
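The eviction policy can be exercised in isolation. The stand-alone `LRUCache` below is a hypothetical helper, not part of the design; it isolates the same refresh-on-hit, evict-oldest behavior so it can be unit-tested without loading any model:

```python
class LRUCache:
    """Minimal sketch of the manager's eviction policy: keys refresh
    on hit, the least recently used key is evicted at capacity."""
    def __init__(self, max_size=1000):
        self.max_size = max_size
        self._cache = {}
        self._order = []  # least recently used first

    def get(self, key):
        if key in self._cache:
            self._order.remove(key)   # refresh recency on hit
            self._order.append(key)
            return self._cache[key]
        return None

    def put(self, key, value):
        if key not in self._cache:
            self._order.append(key)
        self._cache[key] = value
        if len(self._cache) > self.max_size:
            oldest = self._order.pop(0)
            del self._cache[oldest]

cache = LRUCache(max_size=2)
cache.put("a", 1)
cache.put("b", 2)
cache.get("a")     # "a" is now most recent
cache.put("c", 3)  # evicts "b", the least recently used
print(sorted(cache._cache))  # ['a', 'c']
```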
### 5. FAISSIndex (Fixed)
**Responsibility**: FAISS index management with the bug fixes
```python
import logging
import pickle
from typing import List

import faiss
import numpy as np

logger = logging.getLogger(__name__)

class FAISSIndex:
    def __init__(self, dimension: int):
        self.dimension = dimension
        self.index = faiss.IndexFlatL2(dimension)
        self.metadata = []  # Workflow IDs and step info, aligned with index rows

    def add(self, embeddings: np.ndarray, metadata: List[dict]):
        """Add embeddings to index

        Args:
            embeddings: Shape (N, dimension)
            metadata: List of N metadata dicts
        """
        if embeddings.shape[1] != self.dimension:
            raise ValueError(
                f"Embedding dimension {embeddings.shape[1]} doesn't match "
                f"index dimension {self.dimension}"
            )
        self.index.add(embeddings.astype('float32'))
        self.metadata.extend(metadata)

    def search(self, query: np.ndarray, k: int = 5) -> List[dict]:
        """Search for similar embeddings

        Args:
            query: Shape (1, dimension) or (dimension,)
            k: Number of results

        Returns:
            List of dicts with 'distance', 'similarity', 'metadata'
        """
        if query.ndim == 1:
            query = query.reshape(1, -1)
        if query.shape[1] != self.dimension:
            raise ValueError(
                f"Query dimension {query.shape[1]} doesn't match "
                f"index dimension {self.dimension}"
            )
        distances, indices = self.index.search(query.astype('float32'), k)
        results = []
        for dist, idx in zip(distances[0], indices[0]):
            # FAISS pads missing results with index -1; skip those too
            if 0 <= idx < len(self.metadata):
                results.append({
                    'distance': float(dist),
                    'similarity': 1.0 / (1.0 + float(dist)),  # Map L2 distance to (0, 1]
                    'metadata': self.metadata[idx]
                })
        return results

    def save(self, path: str):
        """Save index and metadata"""
        faiss.write_index(self.index, f"{path}.index")
        with open(f"{path}.metadata", 'wb') as f:
            pickle.dump({
                'dimension': self.dimension,
                'metadata': self.metadata
            }, f)

    def load(self, path: str):
        """Load index and metadata"""
        self.index = faiss.read_index(f"{path}.index")
        with open(f"{path}.metadata", 'rb') as f:
            data = pickle.load(f)
        self.dimension = data['dimension']
        self.metadata = data['metadata']

    def rebuild_if_needed(self, new_dimension: int) -> bool:
        """Rebuild index if the embedding dimension changed"""
        if new_dimension != self.dimension:
            logger.info(f"Rebuilding FAISS index: {self.dimension} -> {new_dimension}")
            self.dimension = new_dimension
            self.index = faiss.IndexFlatL2(new_dimension)
            self.metadata = []
            # Note: old embeddings are discarded and must be regenerated
            return True
        return False
```
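To make the distance-to-similarity mapping concrete, here is a brute-force reference of what `IndexFlatL2` computes (exact L2 comparisons) together with the same `1 / (1 + dist)` conversion used in `search()`. This is an illustrative pure-Python stand-in; the real index uses faiss:

```python
def brute_force_search(vectors, query, k=5):
    """Exact squared-L2 search with the same similarity conversion
    as FAISSIndex.search(). Reference sketch only, not the real index."""
    scored = []
    for i, vec in enumerate(vectors):
        dist = sum((a - b) ** 2 for a, b in zip(vec, query))
        scored.append({'index': i, 'distance': dist,
                       'similarity': 1.0 / (1.0 + dist)})
    scored.sort(key=lambda r: r['distance'])
    return scored[:k]

vectors = [[0.0, 0.0], [1.0, 0.0], [3.0, 4.0]]
results = brute_force_search(vectors, query=[1.0, 0.0], k=2)
print(results[0]['index'], results[0]['similarity'])  # 1 1.0 (exact match)
```

An exact match has distance 0 and similarity 1.0; similarity decays toward 0 as distance grows, which gives WorkflowMatcher a bounded score to threshold on.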
### 6. LightweightFineTuner
**Responsibility**: Incremental fine-tuning in the background
```python
import logging
import pickle
import threading
import time
from collections import deque

logger = logging.getLogger(__name__)

class LightweightFineTuner:
    def __init__(self, embedder: EmbedderBase, trigger_threshold: int = 10):
        self.embedder = embedder
        self.trigger_threshold = trigger_threshold
        self.positive_examples = deque(maxlen=1000)
        self.negative_examples = deque(maxlen=1000)
        self.is_training = False
        self.training_thread = None

    def add_positive_example(self, image: Image.Image, workflow_id: str):
        """Add successful workflow execution example"""
        self.positive_examples.append({
            'image': image,
            'workflow_id': workflow_id,
            'timestamp': time.time()
        })
        self._check_trigger()

    def add_negative_example(self, image: Image.Image, workflow_id: str):
        """Add rejected workflow suggestion example"""
        self.negative_examples.append({
            'image': image,
            'workflow_id': workflow_id,
            'timestamp': time.time()
        })
        self._check_trigger()

    def _check_trigger(self):
        """Check if we should trigger fine-tuning"""
        total_new = len(self.positive_examples) + len(self.negative_examples)
        if total_new >= self.trigger_threshold and not self.is_training:
            logger.info(f"Triggering fine-tuning with {total_new} new examples")
            self._start_training()

    def _start_training(self):
        """Start training in a background daemon thread"""
        self.training_thread = threading.Thread(target=self._train, daemon=True)
        self.training_thread.start()

    def _train(self):
        """Fine-tune the last layer of the model"""
        self.is_training = True
        try:
            # Only for models that expose a fine_tune hook (not base CLIP)
            if not hasattr(self.embedder, 'fine_tune'):
                logger.info("Model doesn't support fine-tuning, skipping")
                return
            # Prepare training data
            positive_images = [ex['image'] for ex in self.positive_examples]
            negative_images = [ex['image'] for ex in self.negative_examples]
            # Fine-tune (implementation depends on the model)
            metrics = self.embedder.fine_tune(
                positive_images=positive_images,
                negative_images=negative_images,
                epochs=1,
                learning_rate=1e-4
            )
            logger.info(f"Fine-tuning complete: {metrics}")
            # Clear examples once they have been consumed
            self.positive_examples.clear()
            self.negative_examples.clear()
        except Exception as e:
            logger.error(f"Fine-tuning failed: {e}")
        finally:
            self.is_training = False

    def save_checkpoint(self, path: str):
        """Save training examples for recovery"""
        with open(path, 'wb') as f:
            pickle.dump({
                'positive': list(self.positive_examples),
                'negative': list(self.negative_examples)
            }, f)

    def load_checkpoint(self, path: str):
        """Load training examples"""
        with open(path, 'rb') as f:
            data = pickle.load(f)
        self.positive_examples.extend(data['positive'])
        self.negative_examples.extend(data['negative'])
```
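The non-blocking behavior (Property 6) hinges on `_start_training` launching a daemon thread and returning immediately. The following is a stub of that pattern with the training work replaced by a sleep; `BackgroundTrainer` and its names are illustrative, not part of the design:

```python
import threading
import time

class BackgroundTrainer:
    """Sketch of the non-blocking trigger pattern used by LightweightFineTuner."""
    def __init__(self):
        self.is_training = False
        self.result = None

    def _train(self):
        self.is_training = True
        try:
            time.sleep(0.05)   # stand-in for the actual fine-tuning work
            self.result = "done"
        finally:
            self.is_training = False

    def start(self):
        # Daemon thread: never prevents process exit, returns immediately
        thread = threading.Thread(target=self._train, daemon=True)
        thread.start()
        return thread

trainer = BackgroundTrainer()
thread = trainer.start()
# The main thread keeps working while training runs in the background
main_thread_work = sum(range(10))
thread.join()
print(main_thread_work, trainer.result)  # 45 done
```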
## Data Models
### EmbeddingMetadata
```python
from dataclasses import dataclass

@dataclass
class EmbeddingMetadata:
    workflow_id: str
    step_index: int
    action_type: str
    timestamp: float
    model_name: str
    model_version: str
```
### FineTuningMetrics
```python
@dataclass
class FineTuningMetrics:
    loss: float
    accuracy: float
    positive_examples: int
    negative_examples: int
    duration_seconds: float
    timestamp: float
```
## Correctness Properties
*A property is a characteristic or behavior that should hold true across all valid executions of a system: essentially, a formal statement about what the system should do. Properties serve as the bridge between human-readable specifications and machine-verifiable correctness guarantees.*
### Property 1: Embedding Dimension Consistency
*For any* embedder instance, all generated embeddings SHALL have the same dimension as reported by `get_dimension()`
**Validates: Requirements 1.1, 2.2, 3.3**
### Property 2: FAISS Index Persistence Round-trip
*For any* FAISS index with embeddings, saving then loading SHALL restore the exact same search results
**Validates: Requirements 1.3, 1.4**
### Property 3: Embedder Interface Compatibility
*For any* embedder implementation (CLIP or Pix2Struct), calling `embed()` with a valid PIL image SHALL return a normalized numpy array
**Validates: Requirements 2.1, 2.2, 2.5**
### Property 4: Cache Hit Consistency
*For any* image, generating embeddings twice SHALL return identical results (via caching)
**Validates: Requirements 6.1**
### Property 5: Fallback Reliability
*For any* model loading failure, if fallback is enabled, the system SHALL successfully load CLIP without raising exceptions
**Validates: Requirements 2.4, 3.5**
### Property 6: Fine-tuning Non-blocking
*For any* fine-tuning operation, the main thread SHALL continue generating embeddings without blocking
**Validates: Requirements 5.1, 5.2**
### Property 7: Example Collection Bounds
*For any* sequence of positive/negative examples, the collection SHALL never exceed 1000 examples (deque maxlen)
**Validates: Requirements 4.1, 4.2**
## Error Handling
- **Model loading fails**: Fallback to CLIP if enabled, otherwise raise clear error
- **Dimension mismatch**: Rebuild FAISS index automatically, log warning
- **GPU unavailable**: Fall back to CPU seamlessly
- **Fine-tuning fails**: Log error, keep current model, don't crash
- **Cache overflow**: Evict LRU entries automatically
- **Corrupted index file**: Rebuild from scratch, log error
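The first two rows of this table share one pattern: try the requested path, recover to a known-good default when allowed, otherwise surface a clear error. A minimal sketch with stub loaders (all names here are illustrative, not real APIs):

```python
class PrimaryModelError(RuntimeError):
    pass

def load_primary():
    """Stub that simulates a model-loading failure."""
    raise PrimaryModelError("weights not found")

def load_fallback():
    """Stub standing in for loading the CLIP baseline."""
    return "clip-vit-b32"

def load_with_fallback(fallback_enabled=True):
    """Try the requested model; fall back to CLIP if allowed, else re-raise."""
    try:
        return load_primary()
    except PrimaryModelError:
        if fallback_enabled:
            return load_fallback()
        raise

print(load_with_fallback())  # clip-vit-b32
```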
## Testing Strategy
### Unit Tests
- Test each embedder implementation independently
- Test FAISS index add/search/save/load operations
- Test cache eviction logic
- Test dimension mismatch handling
- Test fallback mechanism
### Property-Based Tests
- Property 1: Dimension consistency across random images
- Property 2: FAISS round-trip with random embeddings
- Property 3: Interface compatibility with random PIL images
- Property 4: Cache consistency with duplicate images
- Property 5: Fallback with simulated failures
- Property 6: Fine-tuning non-blocking with concurrent operations
- Property 7: Example collection bounds with random additions
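As an example of the shape these tests take, Property 7 can be checked with randomized inputs using only the standard library (a hypothesis-based version would be analogous; this sketch assumes plain `random`):

```python
import random
from collections import deque

def check_bounds_property(trials=100):
    """Randomized check of Property 7: a deque(maxlen=1000) never
    exceeds 1000 entries, however many examples are appended."""
    for _ in range(trials):
        examples = deque(maxlen=1000)
        for _ in range(random.randint(0, 3000)):
            examples.append(object())
        if len(examples) > 1000:
            return False
    return True

print(check_bounds_property())  # True
```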
### Integration Tests
- Test full workflow: image → embedding → FAISS → search
- Test model switching (CLIP ↔ Pix2Struct)
- Test fine-tuning trigger and execution
- Test system behavior under GPU/CPU switching
## Performance
- **Embedding generation**: < 200ms per image (GPU), < 500ms (CPU)
- **FAISS search**: < 10ms for k=5 in index of 10,000 embeddings
- **Cache hit**: < 1ms
- **Fine-tuning**: < 2 minutes for 100 examples
- **Model loading**: < 5 seconds (CLIP), < 10 seconds (Pix2Struct)
- **Memory usage**: ~2GB (CLIP), ~4GB (Pix2Struct), ~500MB (FAISS index for 10k embeddings)
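These budgets are testable with a small wall-clock harness. The helper below is an illustrative sketch (best-of-N with `time.perf_counter`, which suits sub-millisecond targets like the cache-hit budget better than a single noisy measurement):

```python
import time

def time_call(fn, *args, repeats=5):
    """Return the best wall-clock time of fn(*args) in milliseconds."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        best = min(best, (time.perf_counter() - start) * 1000.0)
    return best

# Example: a dict lookup stands in for a cache hit
cache = {"abc": [0.1] * 512}
ms = time_call(cache.get, "abc")
print(ms < 1.0)  # the < 1 ms cache-hit budget
```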