# Design - Embedding System Improvements and Fine-tuning

## Overview

The current embedding system uses CLIP with FAISS but suffers from dimension-handling and performance problems. This improvement fixes FAISS, introduces an abstraction that supports multiple models (CLIP as the baseline, Pix2Struct for UI), and adds lightweight incremental fine-tuning that runs after every 10 new interactions to improve accuracy.

## Architecture

```
┌─────────────────────────────────────────────────────────────┐
│                      Embedding System                       │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ┌──────────────┐      ┌──────────────┐                     │
│  │ EmbedderBase │◄─────│ CLIPEmbedder │                     │
│  │  (Abstract)  │      └──────────────┘                     │
│  └──────┬───────┘                                           │
│         │                                                   │
│         │     ┌────────────────────┐                        │
│         └─────│ Pix2StructEmbedder │                        │
│               └────────────────────┘                        │
│                                                             │
│  ┌────────────────────────────────────────────────────┐    │
│  │                  EmbeddingManager                  │     │
│  │  - Model selection & fallback                      │     │
│  │  - Caching (LRU 1000 entries)                      │     │
│  │  - GPU/CPU management                              │     │
│  └────────────────────────────────────────────────────┘    │
│                                                             │
│  ┌────────────────────────────────────────────────────┐    │
│  │                     FAISSIndex                     │     │
│  │  - Fixed dimension handling                        │     │
│  │  - Persistence (index + metadata)                  │     │
│  │  - Auto-rebuild on dimension change                │     │
│  └────────────────────────────────────────────────────┘    │
│                                                             │
│  ┌────────────────────────────────────────────────────┐    │
│  │                LightweightFineTuner                │     │
│  │  - Collects positive/negative examples             │     │
│  │  - Triggers every 10 examples                      │     │
│  │  - Runs in background thread                       │     │
│  │  - Updates last layer only                         │     │
│  └────────────────────────────────────────────────────┘    │
│                                                             │
└─────────────────────────────────────────────────────────────┘
         │                   │                   │
         ▼                   ▼                   ▼
  WorkflowMatcher     VisionAnalysis      SessionManager
```

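How the pieces are wired together, as a minimal sketch using the classes defined in the sections below (the call sites are illustrative):

```python
# Minimal wiring sketch (classes are defined in the sections below)
manager = EmbeddingManager(model_name="pix2struct", fallback=True)

# Size the index from whichever model actually loaded (fallback-aware)
index = FAISSIndex(dimension=manager.get_dimension())

# Collects feedback and retrains in the background every 10 examples
tuner = LightweightFineTuner(embedder=manager.embedder, trigger_threshold=10)
```
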
## Components and Interfaces

### 1. EmbedderBase (Abstract Interface)

**Responsibility**: Common interface for all embedding models

```python
from abc import ABC, abstractmethod
from typing import List

from PIL import Image
import numpy as np


class EmbedderBase(ABC):
    @abstractmethod
    def embed(self, image: Image.Image) -> np.ndarray:
        """Generate embedding for an image

        Returns:
            np.ndarray: Normalized embedding vector
        """
        pass

    @abstractmethod
    def get_dimension(self) -> int:
        """Get embedding dimension"""
        pass

    @abstractmethod
    def get_model_name(self) -> str:
        """Get model identifier"""
        pass

    @abstractmethod
    def supports_batch(self) -> bool:
        """Check if model supports batch processing"""
        pass

    def embed_batch(self, images: List[Image.Image]) -> np.ndarray:
        """Generate embeddings for multiple images (optional optimization)"""
        return np.array([self.embed(img) for img in images])
```

### 2. CLIPEmbedder (Concrete Implementation)

**Responsibility**: CLIP wrapper implementing the common interface

```python
import clip
import torch


class CLIPEmbedder(EmbedderBase):
    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.model, self.preprocess = clip.load("ViT-B/32", device=device)
        self.model.eval()

    def embed(self, image: Image.Image) -> np.ndarray:
        with torch.no_grad():
            image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
            embedding = self.model.encode_image(image_tensor)
            # L2-normalize so downstream L2 search behaves like cosine similarity
            embedding = embedding / embedding.norm(dim=-1, keepdim=True)
            return embedding.cpu().numpy().flatten()

    def get_dimension(self) -> int:
        return 512  # CLIP ViT-B/32 embedding dimension

    def get_model_name(self) -> str:
        return "clip-vit-b32"

    def supports_batch(self) -> bool:
        return True
```

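Normalizing the CLIP output is what makes the `IndexFlatL2` index below behave like cosine similarity search: for unit vectors, the squared L2 distance equals `2 - 2*cos(a, b)`, so ranking neighbors by L2 distance and ranking them by cosine similarity give the same order.
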
### 3. Pix2StructEmbedder (Concrete Implementation)

**Responsibility**: Wrapper for Pix2Struct, a model specialized for UI screenshots

```python
import torch
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration


class Pix2StructEmbedder(EmbedderBase):
    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.processor = Pix2StructProcessor.from_pretrained(
            "google/pix2struct-base"
        )
        self.model = Pix2StructForConditionalGeneration.from_pretrained(
            "google/pix2struct-base"
        ).to(device)
        self.model.eval()

    def embed(self, image: Image.Image) -> np.ndarray:
        with torch.no_grad():
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)
            # Mean-pool the encoder hidden states into a fixed-size embedding
            encoder_outputs = self.model.encoder(**inputs)
            embedding = encoder_outputs.last_hidden_state.mean(dim=1)
            embedding = embedding / embedding.norm(dim=-1, keepdim=True)
            return embedding.cpu().numpy().flatten()

    def get_dimension(self) -> int:
        return 768  # Pix2Struct base hidden size

    def get_model_name(self) -> str:
        return "pix2struct-base"

    def supports_batch(self) -> bool:
        return True
```

### 4. EmbeddingManager

**Responsibility**: Model selection, caching, and fallback management

```python
import hashlib
import logging

logger = logging.getLogger(__name__)


class EmbeddingManager:
    def __init__(self, model_name: str = "clip", fallback: bool = True):
        self.model_name = model_name
        self.fallback_enabled = fallback
        self.embedder = self._load_embedder()
        self._cache = {}        # Manual LRU cache: image hash -> embedding
        self._cache_order = []  # Keys ordered from least to most recently used
        self.max_cache_size = 1000

    def _load_embedder(self) -> EmbedderBase:
        """Load embedder with fallback"""
        try:
            if self.model_name == "pix2struct":
                return Pix2StructEmbedder()
            elif self.model_name == "clip":
                return CLIPEmbedder()
            else:
                raise ValueError(f"Unknown model: {self.model_name}")
        except Exception as e:
            if self.fallback_enabled:
                logger.warning(f"Failed to load {self.model_name}, falling back to CLIP: {e}")
                return CLIPEmbedder()
            raise

    def embed(self, image: Image.Image) -> np.ndarray:
        """Generate embedding with caching"""
        # Create cache key from image hash
        img_hash = hashlib.md5(image.tobytes()).hexdigest()

        if img_hash in self._cache:
            # Refresh recency so eviction is true LRU rather than FIFO
            self._cache_order.remove(img_hash)
            self._cache_order.append(img_hash)
            return self._cache[img_hash]

        # Generate embedding
        embedding = self.embedder.embed(image)

        # Update cache
        self._cache[img_hash] = embedding
        self._cache_order.append(img_hash)

        # Evict the least recently used entry if needed
        if len(self._cache) > self.max_cache_size:
            oldest = self._cache_order.pop(0)
            del self._cache[oldest]

        return embedding

    def get_dimension(self) -> int:
        return self.embedder.get_dimension()

    def get_model_name(self) -> str:
        return self.embedder.get_model_name()
```

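A brief usage sketch (the screenshot path is illustrative):

```python
import numpy as np
from PIL import Image

manager = EmbeddingManager(model_name="clip")
screenshot = Image.open("screenshot.png")  # illustrative path

vector = manager.embed(screenshot)        # computed by the model...
vector_again = manager.embed(screenshot)  # ...then served from the cache (Property 4)
assert np.array_equal(vector, vector_again)
print(manager.get_model_name(), manager.get_dimension())  # clip-vit-b32 512
```
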
### 5. FAISSIndex (Fixed)

**Responsibility**: FAISS index management with the dimension bugs fixed

```python
from typing import List

import faiss
import numpy as np
import pickle


class FAISSIndex:
    def __init__(self, dimension: int):
        self.dimension = dimension
        self.index = faiss.IndexFlatL2(dimension)
        self.metadata = []  # Store workflow IDs and step info

    def add(self, embeddings: np.ndarray, metadata: List[dict]):
        """Add embeddings to index

        Args:
            embeddings: Shape (N, dimension)
            metadata: List of N metadata dicts
        """
        if embeddings.shape[1] != self.dimension:
            raise ValueError(
                f"Embedding dimension {embeddings.shape[1]} doesn't match "
                f"index dimension {self.dimension}"
            )

        self.index.add(embeddings.astype('float32'))
        self.metadata.extend(metadata)

    def search(self, query: np.ndarray, k: int = 5) -> List[dict]:
        """Search for similar embeddings

        Args:
            query: Shape (1, dimension) or (dimension,)
            k: Number of results

        Returns:
            List of dicts with 'distance', 'similarity', 'metadata'
        """
        if query.ndim == 1:
            query = query.reshape(1, -1)

        if query.shape[1] != self.dimension:
            raise ValueError(
                f"Query dimension {query.shape[1]} doesn't match "
                f"index dimension {self.dimension}"
            )

        distances, indices = self.index.search(query.astype('float32'), k)

        results = []
        for dist, idx in zip(distances[0], indices[0]):
            # FAISS pads missing results with -1, so skip those as well
            if 0 <= idx < len(self.metadata):
                results.append({
                    'distance': float(dist),
                    'similarity': 1.0 / (1.0 + dist),  # Convert distance to similarity
                    'metadata': self.metadata[idx]
                })

        return results

    def save(self, path: str):
        """Save index and metadata"""
        faiss.write_index(self.index, f"{path}.index")
        with open(f"{path}.metadata", 'wb') as f:
            pickle.dump({
                'dimension': self.dimension,
                'metadata': self.metadata
            }, f)

    def load(self, path: str):
        """Load index and metadata"""
        self.index = faiss.read_index(f"{path}.index")
        with open(f"{path}.metadata", 'rb') as f:
            data = pickle.load(f)
            self.dimension = data['dimension']
            self.metadata = data['metadata']

    def rebuild_if_needed(self, new_dimension: int) -> bool:
        """Rebuild index if dimension changed"""
        if new_dimension != self.dimension:
            logger.info(f"Rebuilding FAISS index: {self.dimension} -> {new_dimension}")
            self.dimension = new_dimension
            self.index = faiss.IndexFlatL2(new_dimension)
            self.metadata = []
            # Note: old embeddings are lost and must be regenerated by the caller
            return True
        return False
```

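A round-trip sketch exercising Property 2 (the `./workflow_index` path prefix is arbitrary):

```python
import numpy as np

dim = 512
index = FAISSIndex(dimension=dim)

# Index normalized random vectors with per-vector metadata
vectors = np.random.rand(100, dim).astype('float32')
vectors /= np.linalg.norm(vectors, axis=1, keepdims=True)
index.add(vectors, [{'workflow_id': f'wf-{i}', 'step_index': 0} for i in range(100)])

index.save("./workflow_index")

restored = FAISSIndex(dimension=dim)
restored.load("./workflow_index")

# Property 2: the restored index returns identical search results
assert index.search(vectors[0], k=5) == restored.search(vectors[0], k=5)
```
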
### 6. LightweightFineTuner

**Responsibility**: Incremental fine-tuning in a background thread

```python
import threading
import time
from collections import deque


class LightweightFineTuner:
    def __init__(self, embedder: EmbedderBase, trigger_threshold: int = 10):
        self.embedder = embedder
        self.trigger_threshold = trigger_threshold
        self.positive_examples = deque(maxlen=1000)
        self.negative_examples = deque(maxlen=1000)
        self.is_training = False
        self.training_thread = None

    def add_positive_example(self, image: Image.Image, workflow_id: str):
        """Add successful workflow execution example"""
        self.positive_examples.append({
            'image': image,
            'workflow_id': workflow_id,
            'timestamp': time.time()
        })
        self._check_trigger()

    def add_negative_example(self, image: Image.Image, workflow_id: str):
        """Add rejected workflow suggestion example"""
        self.negative_examples.append({
            'image': image,
            'workflow_id': workflow_id,
            'timestamp': time.time()
        })
        self._check_trigger()

    def _check_trigger(self):
        """Check if we should trigger fine-tuning"""
        total_new = len(self.positive_examples) + len(self.negative_examples)

        if total_new >= self.trigger_threshold and not self.is_training:
            logger.info(f"Triggering fine-tuning with {total_new} new examples")
            self._start_training()

    def _start_training(self):
        """Start training in a background thread"""
        self.training_thread = threading.Thread(target=self._train)
        self.training_thread.daemon = True
        self.training_thread.start()

    def _train(self):
        """Fine-tune the last layer of the model"""
        self.is_training = True

        try:
            # Only for models that support fine-tuning (not CLIP base)
            if not hasattr(self.embedder, 'fine_tune'):
                logger.info("Model doesn't support fine-tuning, skipping")
                return

            # Prepare training data
            positive_images = [ex['image'] for ex in self.positive_examples]
            negative_images = [ex['image'] for ex in self.negative_examples]

            # Fine-tune (implementation depends on model)
            metrics = self.embedder.fine_tune(
                positive_images=positive_images,
                negative_images=negative_images,
                epochs=1,
                learning_rate=1e-4
            )

            logger.info(f"Fine-tuning complete: {metrics}")

            # Clear examples after training
            self.positive_examples.clear()
            self.negative_examples.clear()

        except Exception as e:
            logger.error(f"Fine-tuning failed: {e}")

        finally:
            self.is_training = False

    def save_checkpoint(self, path: str):
        """Save training examples for recovery"""
        with open(path, 'wb') as f:
            pickle.dump({
                'positive': list(self.positive_examples),
                'negative': list(self.negative_examples)
            }, f)

    def load_checkpoint(self, path: str):
        """Load training examples"""
        with open(path, 'rb') as f:
            data = pickle.load(f)
            self.positive_examples.extend(data['positive'])
            self.negative_examples.extend(data['negative'])
```

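How the tuner is expected to be fed from workflow outcomes, as a sketch; the `on_workflow_*` hook names are hypothetical:

```python
tuner = LightweightFineTuner(embedder=manager.embedder, trigger_threshold=10)

def on_workflow_accepted(screenshot, workflow_id):
    # Hypothetical hook: the user confirmed the suggested workflow
    tuner.add_positive_example(screenshot, workflow_id)

def on_workflow_rejected(screenshot, workflow_id):
    # Hypothetical hook: the user dismissed the suggestion
    tuner.add_negative_example(screenshot, workflow_id)

# Persist pending examples so they survive a restart (path is illustrative)
tuner.save_checkpoint("./finetuner_examples.pkl")
```
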
## Data Models

### EmbeddingMetadata

```python
from dataclasses import dataclass


@dataclass
class EmbeddingMetadata:
    workflow_id: str
    step_index: int
    action_type: str
    timestamp: float
    model_name: str
    model_version: str
```

### FineTuningMetrics

```python
@dataclass
class FineTuningMetrics:
    loss: float
    accuracy: float
    positive_examples: int
    negative_examples: int
    duration_seconds: float
    timestamp: float
```

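Continuing the sketches above, metadata travels into the index as a plain dict via `dataclasses.asdict`; the `model_version` value shown is a hypothetical convention:

```python
import time
from dataclasses import asdict

meta = EmbeddingMetadata(
    workflow_id="wf-42",
    step_index=3,
    action_type="click",
    timestamp=time.time(),
    model_name=manager.get_model_name(),
    model_version="1.0",  # hypothetical versioning convention
)

vector = manager.embed(screenshot)
index.add(vector.reshape(1, -1), [asdict(meta)])
```
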
## Correctness Properties

*A property is a characteristic or behavior that should hold true across all valid executions of a system: essentially, a formal statement about what the system should do. Properties serve as the bridge between human-readable specifications and machine-verifiable correctness guarantees.*

### Property 1: Embedding Dimension Consistency

*For any* embedder instance, all generated embeddings SHALL have the same dimension as reported by `get_dimension()`

**Validates: Requirements 1.1, 2.2, 3.3**

### Property 2: FAISS Index Persistence Round-trip

*For any* FAISS index with embeddings, saving then loading SHALL restore the exact same search results

**Validates: Requirements 1.3, 1.4**

### Property 3: Embedder Interface Compatibility

*For any* embedder implementation (CLIP or Pix2Struct), calling `embed()` with a valid PIL image SHALL return a normalized numpy array

**Validates: Requirements 2.1, 2.2, 2.5**

### Property 4: Cache Hit Consistency

*For any* image, generating embeddings twice SHALL return identical results (via caching)

**Validates: Requirements 6.1**

### Property 5: Fallback Reliability

*For any* model loading failure, if fallback is enabled, the system SHALL successfully load CLIP without raising exceptions

**Validates: Requirements 2.4, 3.5**

### Property 6: Fine-tuning Non-blocking

*For any* fine-tuning operation, the main thread SHALL continue generating embeddings without blocking

**Validates: Requirements 5.1, 5.2**

### Property 7: Example Collection Bounds

*For any* sequence of positive/negative examples, each collection SHALL never exceed 1000 examples (deque maxlen)

**Validates: Requirements 4.1, 4.2**

## Error Handling

- **Model loading fails**: Fall back to CLIP if enabled, otherwise raise a clear error
- **Dimension mismatch**: Rebuild the FAISS index automatically and log a warning
- **GPU unavailable**: Fall back to CPU seamlessly
- **Fine-tuning fails**: Log the error and keep the current model; don't crash
- **Cache overflow**: Evict LRU entries automatically
- **Corrupted index file**: Rebuild from scratch and log the error, as sketched below

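A sketch of the corrupted-index branch, assuming `FAISSIndex` from above; any failure during load falls back to an empty index whose embeddings must be regenerated:

```python
import logging

logger = logging.getLogger(__name__)

def load_or_rebuild(path: str, dimension: int) -> FAISSIndex:
    """Load a persisted index, rebuilding from scratch if its files are corrupt."""
    index = FAISSIndex(dimension)
    try:
        index.load(path)
    except Exception as e:
        logger.error(f"Corrupted index at {path}, rebuilding from scratch: {e}")
        index = FAISSIndex(dimension)  # empty; embeddings must be regenerated
    return index
```
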
## Testing Strategy

### Unit Tests

- Test each embedder implementation independently
- Test FAISS index add/search/save/load operations
- Test cache eviction logic
- Test dimension mismatch handling
- Test fallback mechanism

### Property-Based Tests

- Property 1: Dimension consistency across random images
- Property 2: FAISS round-trip with random embeddings
- Property 3: Interface compatibility with random PIL images
- Property 4: Cache consistency with duplicate images
- Property 5: Fallback with simulated failures
- Property 6: Fine-tuning non-blocking with concurrent operations
- Property 7: Example collection bounds with random additions

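A property-based test sketch for Properties 1, 3, and 7 using Hypothesis (assuming `hypothesis` and `Pillow` are available as test dependencies; the model loads once at import):

```python
import numpy as np
from hypothesis import given, settings, strategies as st
from PIL import Image

embedder = CLIPEmbedder()  # loaded once, reused across examples

@st.composite
def random_images(draw):
    # Small solid-color images keep model calls fast
    w = draw(st.integers(min_value=32, max_value=128))
    h = draw(st.integers(min_value=32, max_value=128))
    rgb = draw(st.tuples(st.integers(0, 255), st.integers(0, 255), st.integers(0, 255)))
    return Image.new("RGB", (w, h), rgb)

@given(random_images())
@settings(max_examples=20, deadline=None)
def test_embed_is_normalized_with_fixed_dimension(image):
    vec = embedder.embed(image)
    assert vec.shape == (embedder.get_dimension(),)          # Property 1
    assert np.isclose(np.linalg.norm(vec), 1.0, atol=1e-3)   # Property 3

@given(st.integers(min_value=0, max_value=2000))
@settings(max_examples=25, deadline=None)
def test_example_collection_is_bounded(n):
    # Threshold set absurdly high so training never triggers during the test
    tuner = LightweightFineTuner(embedder=embedder, trigger_threshold=10**9)
    for i in range(n):
        tuner.add_positive_example(Image.new("RGB", (8, 8)), f"wf-{i}")
    assert len(tuner.positive_examples) <= 1000               # Property 7
```
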
### Integration Tests

- Test full workflow: image → embedding → FAISS → search
- Test model switching (CLIP ↔ Pix2Struct)
- Test fine-tuning trigger and execution
- Test system behavior under GPU/CPU switching

## Performance

- **Embedding generation**: < 200 ms per image (GPU), < 500 ms (CPU)
- **FAISS search**: < 10 ms for k=5 in an index of 10,000 embeddings
- **Cache hit**: < 1 ms
- **Fine-tuning**: < 2 minutes for 100 examples
- **Model loading**: < 5 seconds (CLIP), < 10 seconds (Pix2Struct)
- **Memory usage**: ~2 GB (CLIP), ~4 GB (Pix2Struct), ~500 MB (FAISS index with 10k embeddings)