414 lines
14 KiB
Python
414 lines
14 KiB
Python
"""
|
|
Gestionnaire d'embeddings visuels avec OpenCLIP et FAISS.
|
|
Gère l'encodage d'images, l'indexation et la recherche de similarité.
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import pickle
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Tuple, Any
|
|
import torch
|
|
from PIL import Image
|
|
|
|
try:
|
|
import open_clip
|
|
except ImportError:
|
|
open_clip = None
|
|
|
|
try:
|
|
import faiss
|
|
except ImportError:
|
|
faiss = None
|
|
|
|
from .logger import Logger
|
|
|
|
|
|
class EmbeddingsManager:
|
|
"""
|
|
Gestionnaire d'embeddings visuels utilisant OpenCLIP pour l'encodage
|
|
et FAISS pour l'indexation et la recherche de similarité.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
model_name: str = "ViT-B-32",
|
|
pretrained: str = "openai",
|
|
index_path: str = "data/faiss_index",
|
|
device: Optional[str] = None,
|
|
logger: Optional[Logger] = None
|
|
):
|
|
"""
|
|
Initialise le gestionnaire d'embeddings.
|
|
|
|
Args:
|
|
model_name: Nom du modèle OpenCLIP
|
|
pretrained: Dataset de pré-entraînement
|
|
index_path: Chemin vers l'index FAISS
|
|
device: Device PyTorch (cuda/cpu)
|
|
logger: Instance du logger
|
|
"""
|
|
self.model_name = model_name
|
|
self.pretrained = pretrained
|
|
self.index_path = Path(index_path)
|
|
# Forcer CPU pour OpenCLIP pour économiser la mémoire GPU (Qwen3-VL prioritaire)
|
|
self.device = "cpu"
|
|
self.logger = logger
|
|
|
|
if self.logger:
|
|
self.logger.log_action({
|
|
"action": "embeddings_manager_init",
|
|
"device": self.device,
|
|
"reason": "CPU forcé pour économiser GPU pour Qwen3-VL"
|
|
})
|
|
|
|
# Créer le répertoire d'index si nécessaire
|
|
self.index_path.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Chemins des fichiers
|
|
self.index_file = self.index_path / "embeddings.index"
|
|
self.metadata_file = self.index_path / "metadata.pkl"
|
|
|
|
# Initialiser le modèle et l'index
|
|
self.clip_model = None
|
|
self.preprocess = None
|
|
self.faiss_index = None
|
|
self.metadata_store: Dict[int, Dict[str, Any]] = {}
|
|
self.embedding_dim = 512 # Dimension par défaut pour ViT-B-32
|
|
|
|
self._load_model()
|
|
self._load_or_create_index()
|
|
|
|
def _load_model(self):
|
|
"""Charge le modèle OpenCLIP."""
|
|
if open_clip is None:
|
|
raise ImportError(
|
|
"OpenCLIP n'est pas installé. "
|
|
"Installez-le avec: pip install open-clip-torch"
|
|
)
|
|
|
|
try:
|
|
self.clip_model, _, self.preprocess = open_clip.create_model_and_transforms(
|
|
self.model_name,
|
|
pretrained=self.pretrained,
|
|
device=self.device
|
|
)
|
|
self.clip_model.eval()
|
|
|
|
# Obtenir la dimension d'embedding réelle
|
|
with torch.no_grad():
|
|
dummy_image = torch.zeros(1, 3, 224, 224).to(self.device)
|
|
dummy_embedding = self.clip_model.encode_image(dummy_image)
|
|
self.embedding_dim = dummy_embedding.shape[-1]
|
|
|
|
if self.logger:
|
|
self.logger.log_action({
|
|
"action": "model_loaded",
|
|
"model": self.model_name,
|
|
"device": self.device,
|
|
"embedding_dim": self.embedding_dim
|
|
})
|
|
|
|
except Exception as e:
|
|
error_msg = f"Erreur lors du chargement du modèle OpenCLIP: {e}"
|
|
if self.logger:
|
|
self.logger.log_action({
|
|
"action": "model_load_error",
|
|
"error": str(e)
|
|
})
|
|
raise RuntimeError(error_msg)
|
|
|
|
def _load_or_create_index(self):
|
|
"""Charge l'index FAISS existant ou en crée un nouveau."""
|
|
if faiss is None:
|
|
raise ImportError(
|
|
"FAISS n'est pas installé. "
|
|
"Installez-le avec: pip install faiss-cpu ou faiss-gpu"
|
|
)
|
|
|
|
# Charger l'index existant
|
|
if self.index_file.exists() and self.metadata_file.exists():
|
|
try:
|
|
self.faiss_index = faiss.read_index(str(self.index_file))
|
|
with open(self.metadata_file, 'rb') as f:
|
|
self.metadata_store = pickle.load(f)
|
|
|
|
if self.logger:
|
|
self.logger.log_action({
|
|
"action": "index_loaded",
|
|
"num_vectors": self.faiss_index.ntotal,
|
|
"path": str(self.index_file)
|
|
})
|
|
return
|
|
|
|
except Exception as e:
|
|
if self.logger:
|
|
self.logger.log_action({
|
|
"action": "index_load_error",
|
|
"error": str(e)
|
|
})
|
|
# Continuer pour créer un nouvel index
|
|
|
|
# Créer un nouvel index
|
|
self.faiss_index = faiss.IndexFlatL2(self.embedding_dim)
|
|
self.metadata_store = {}
|
|
|
|
if self.logger:
|
|
self.logger.log_action({
|
|
"action": "index_created",
|
|
"embedding_dim": self.embedding_dim
|
|
})
|
|
|
|
def encode_image(self, image: np.ndarray) -> np.ndarray:
|
|
"""
|
|
Génère un embedding 512-d pour une image.
|
|
|
|
Args:
|
|
image: Image numpy array (H, W, C) en BGR ou RGB
|
|
|
|
Returns:
|
|
Embedding numpy array de forme (embedding_dim,)
|
|
"""
|
|
try:
|
|
# Convertir BGR vers RGB si nécessaire
|
|
if len(image.shape) == 3 and image.shape[2] == 3:
|
|
# Supposer BGR (format OpenCV)
|
|
image_rgb = image[:, :, ::-1]
|
|
else:
|
|
image_rgb = image
|
|
|
|
# Convertir en PIL Image
|
|
pil_image = Image.fromarray(image_rgb.astype(np.uint8))
|
|
|
|
# Prétraiter l'image
|
|
image_tensor = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
|
|
|
# Générer l'embedding
|
|
with torch.no_grad():
|
|
embedding = self.clip_model.encode_image(image_tensor)
|
|
embedding = embedding / embedding.norm(dim=-1, keepdim=True) # Normaliser
|
|
|
|
return embedding.cpu().numpy().flatten()
|
|
|
|
except Exception as e:
|
|
if self.logger:
|
|
self.logger.log_action({
|
|
"action": "encoding_error",
|
|
"error": str(e)
|
|
})
|
|
raise RuntimeError(f"Erreur lors de l'encodage de l'image: {e}")
|
|
|
|
def add_to_index(self, embedding: np.ndarray, metadata: Dict[str, Any]) -> int:
|
|
"""
|
|
Ajoute un embedding à l'index FAISS avec ses métadonnées.
|
|
|
|
Args:
|
|
embedding: Embedding numpy array
|
|
metadata: Dictionnaire de métadonnées associées
|
|
|
|
Returns:
|
|
ID de l'embedding dans l'index
|
|
"""
|
|
try:
|
|
# Obtenir l'ID avant l'ajout
|
|
idx = self.faiss_index.ntotal
|
|
|
|
# Ajouter à l'index FAISS
|
|
embedding_2d = embedding.reshape(1, -1).astype(np.float32)
|
|
self.faiss_index.add(embedding_2d)
|
|
|
|
# Stocker les métadonnées
|
|
self.metadata_store[idx] = metadata
|
|
|
|
if self.logger:
|
|
self.logger.log_action({
|
|
"action": "embedding_added",
|
|
"id": idx,
|
|
"metadata": metadata
|
|
})
|
|
|
|
return idx
|
|
|
|
except Exception as e:
|
|
if self.logger:
|
|
self.logger.log_action({
|
|
"action": "add_to_index_error",
|
|
"error": str(e)
|
|
})
|
|
raise RuntimeError(f"Erreur lors de l'ajout à l'index: {e}")
|
|
|
|
def search_similar(
|
|
self,
|
|
query_embedding: np.ndarray,
|
|
k: int = 5
|
|
) -> List[Dict[str, Any]]:
|
|
"""
|
|
Recherche les k embeddings les plus similaires.
|
|
|
|
Args:
|
|
query_embedding: Embedding de requête
|
|
k: Nombre de résultats à retourner
|
|
|
|
Returns:
|
|
Liste de dictionnaires avec id, distance et metadata
|
|
"""
|
|
try:
|
|
if self.faiss_index.ntotal == 0:
|
|
return []
|
|
|
|
# Limiter k au nombre d'embeddings disponibles
|
|
k = min(k, self.faiss_index.ntotal)
|
|
|
|
# Rechercher
|
|
query_2d = query_embedding.reshape(1, -1).astype(np.float32)
|
|
distances, indices = self.faiss_index.search(query_2d, k)
|
|
|
|
# Formater les résultats
|
|
results = []
|
|
for dist, idx in zip(distances[0], indices[0]):
|
|
if idx != -1: # FAISS retourne -1 si pas assez de résultats
|
|
results.append({
|
|
"id": int(idx),
|
|
"distance": float(dist),
|
|
"similarity": float(1.0 / (1.0 + dist)), # Convertir distance en similarité
|
|
"metadata": self.metadata_store.get(int(idx), {})
|
|
})
|
|
|
|
return results
|
|
|
|
except Exception as e:
|
|
if self.logger:
|
|
self.logger.log_action({
|
|
"action": "search_error",
|
|
"error": str(e)
|
|
})
|
|
raise RuntimeError(f"Erreur lors de la recherche: {e}")
|
|
|
|
def get_embedding_similarity(
|
|
self,
|
|
emb1: np.ndarray,
|
|
emb2: np.ndarray
|
|
) -> float:
|
|
"""
|
|
Calcule la similarité cosinus entre deux embeddings.
|
|
|
|
Args:
|
|
emb1: Premier embedding
|
|
emb2: Deuxième embedding
|
|
|
|
Returns:
|
|
Similarité cosinus (0-1)
|
|
"""
|
|
try:
|
|
# Normaliser les embeddings
|
|
emb1_norm = emb1 / np.linalg.norm(emb1)
|
|
emb2_norm = emb2 / np.linalg.norm(emb2)
|
|
|
|
# Calculer la similarité cosinus
|
|
similarity = np.dot(emb1_norm, emb2_norm)
|
|
|
|
# Convertir de [-1, 1] à [0, 1]
|
|
similarity = (similarity + 1.0) / 2.0
|
|
|
|
return float(similarity)
|
|
|
|
except Exception as e:
|
|
if self.logger:
|
|
self.logger.log_action({
|
|
"action": "similarity_calculation_error",
|
|
"error": str(e)
|
|
})
|
|
return 0.0
|
|
|
|
def rebuild_index(self):
|
|
"""Reconstruit l'index FAISS à partir des embeddings stockés."""
|
|
try:
|
|
if self.faiss_index.ntotal == 0:
|
|
if self.logger:
|
|
self.logger.log_action({
|
|
"action": "rebuild_skipped",
|
|
"reason": "index_empty"
|
|
})
|
|
return
|
|
|
|
# Extraire tous les embeddings
|
|
all_embeddings = []
|
|
for i in range(self.faiss_index.ntotal):
|
|
embedding = self.faiss_index.reconstruct(i)
|
|
all_embeddings.append(embedding)
|
|
|
|
# Créer un nouvel index
|
|
new_index = faiss.IndexFlatL2(self.embedding_dim)
|
|
|
|
# Ajouter tous les embeddings
|
|
embeddings_array = np.array(all_embeddings).astype(np.float32)
|
|
new_index.add(embeddings_array)
|
|
|
|
# Remplacer l'ancien index
|
|
self.faiss_index = new_index
|
|
|
|
if self.logger:
|
|
self.logger.log_action({
|
|
"action": "index_rebuilt",
|
|
"num_vectors": self.faiss_index.ntotal
|
|
})
|
|
|
|
except Exception as e:
|
|
if self.logger:
|
|
self.logger.log_action({
|
|
"action": "rebuild_error",
|
|
"error": str(e)
|
|
})
|
|
raise RuntimeError(f"Erreur lors de la reconstruction de l'index: {e}")
|
|
|
|
def save_index(self):
|
|
"""Sauvegarde l'index FAISS et les métadonnées sur disque."""
|
|
try:
|
|
# Sauvegarder l'index FAISS
|
|
faiss.write_index(self.faiss_index, str(self.index_file))
|
|
|
|
# Sauvegarder les métadonnées
|
|
with open(self.metadata_file, 'wb') as f:
|
|
pickle.dump(self.metadata_store, f)
|
|
|
|
if self.logger:
|
|
self.logger.log_action({
|
|
"action": "index_saved",
|
|
"num_vectors": self.faiss_index.ntotal,
|
|
"path": str(self.index_file)
|
|
})
|
|
|
|
except Exception as e:
|
|
if self.logger:
|
|
self.logger.log_action({
|
|
"action": "save_error",
|
|
"error": str(e)
|
|
})
|
|
raise RuntimeError(f"Erreur lors de la sauvegarde de l'index: {e}")
|
|
|
|
def get_stats(self) -> Dict[str, Any]:
|
|
"""
|
|
Retourne des statistiques sur l'index.
|
|
|
|
Returns:
|
|
Dictionnaire de statistiques
|
|
"""
|
|
return {
|
|
"num_embeddings": self.faiss_index.ntotal,
|
|
"embedding_dim": self.embedding_dim,
|
|
"model_name": self.model_name,
|
|
"device": self.device,
|
|
"index_path": str(self.index_file)
|
|
}
|
|
|
|
def clear_index(self):
|
|
"""Efface tous les embeddings de l'index."""
|
|
self.faiss_index = faiss.IndexFlatL2(self.embedding_dim)
|
|
self.metadata_store = {}
|
|
|
|
if self.logger:
|
|
self.logger.log_action({
|
|
"action": "index_cleared"
|
|
})
|