""" Gestionnaire d'embeddings visuels avec OpenCLIP et FAISS. Gère l'encodage d'images, l'indexation et la recherche de similarité. """ import os import json import pickle import numpy as np from pathlib import Path from typing import Dict, List, Optional, Tuple, Any import torch from PIL import Image try: import open_clip except ImportError: open_clip = None try: import faiss except ImportError: faiss = None from .logger import Logger class EmbeddingsManager: """ Gestionnaire d'embeddings visuels utilisant OpenCLIP pour l'encodage et FAISS pour l'indexation et la recherche de similarité. """ def __init__( self, model_name: str = "ViT-B-32", pretrained: str = "openai", index_path: str = "data/faiss_index", device: Optional[str] = None, logger: Optional[Logger] = None ): """ Initialise le gestionnaire d'embeddings. Args: model_name: Nom du modèle OpenCLIP pretrained: Dataset de pré-entraînement index_path: Chemin vers l'index FAISS device: Device PyTorch (cuda/cpu) logger: Instance du logger """ self.model_name = model_name self.pretrained = pretrained self.index_path = Path(index_path) # Forcer CPU pour OpenCLIP pour économiser la mémoire GPU (Qwen3-VL prioritaire) self.device = "cpu" self.logger = logger if self.logger: self.logger.log_action({ "action": "embeddings_manager_init", "device": self.device, "reason": "CPU forcé pour économiser GPU pour Qwen3-VL" }) # Créer le répertoire d'index si nécessaire self.index_path.mkdir(parents=True, exist_ok=True) # Chemins des fichiers self.index_file = self.index_path / "embeddings.index" self.metadata_file = self.index_path / "metadata.pkl" # Initialiser le modèle et l'index self.clip_model = None self.preprocess = None self.faiss_index = None self.metadata_store: Dict[int, Dict[str, Any]] = {} self.embedding_dim = 512 # Dimension par défaut pour ViT-B-32 self._load_model() self._load_or_create_index() def _load_model(self): """Charge le modèle OpenCLIP.""" if open_clip is None: raise ImportError( "OpenCLIP n'est pas installé. " "Installez-le avec: pip install open-clip-torch" ) try: self.clip_model, _, self.preprocess = open_clip.create_model_and_transforms( self.model_name, pretrained=self.pretrained, device=self.device ) self.clip_model.eval() # Obtenir la dimension d'embedding réelle with torch.no_grad(): dummy_image = torch.zeros(1, 3, 224, 224).to(self.device) dummy_embedding = self.clip_model.encode_image(dummy_image) self.embedding_dim = dummy_embedding.shape[-1] if self.logger: self.logger.log_action({ "action": "model_loaded", "model": self.model_name, "device": self.device, "embedding_dim": self.embedding_dim }) except Exception as e: error_msg = f"Erreur lors du chargement du modèle OpenCLIP: {e}" if self.logger: self.logger.log_action({ "action": "model_load_error", "error": str(e) }) raise RuntimeError(error_msg) def _load_or_create_index(self): """Charge l'index FAISS existant ou en crée un nouveau.""" if faiss is None: raise ImportError( "FAISS n'est pas installé. " "Installez-le avec: pip install faiss-cpu ou faiss-gpu" ) # Charger l'index existant if self.index_file.exists() and self.metadata_file.exists(): try: self.faiss_index = faiss.read_index(str(self.index_file)) with open(self.metadata_file, 'rb') as f: self.metadata_store = pickle.load(f) if self.logger: self.logger.log_action({ "action": "index_loaded", "num_vectors": self.faiss_index.ntotal, "path": str(self.index_file) }) return except Exception as e: if self.logger: self.logger.log_action({ "action": "index_load_error", "error": str(e) }) # Continuer pour créer un nouvel index # Créer un nouvel index self.faiss_index = faiss.IndexFlatL2(self.embedding_dim) self.metadata_store = {} if self.logger: self.logger.log_action({ "action": "index_created", "embedding_dim": self.embedding_dim }) def encode_image(self, image: np.ndarray) -> np.ndarray: """ Génère un embedding 512-d pour une image. Args: image: Image numpy array (H, W, C) en BGR ou RGB Returns: Embedding numpy array de forme (embedding_dim,) """ try: # Convertir BGR vers RGB si nécessaire if len(image.shape) == 3 and image.shape[2] == 3: # Supposer BGR (format OpenCV) image_rgb = image[:, :, ::-1] else: image_rgb = image # Convertir en PIL Image pil_image = Image.fromarray(image_rgb.astype(np.uint8)) # Prétraiter l'image image_tensor = self.preprocess(pil_image).unsqueeze(0).to(self.device) # Générer l'embedding with torch.no_grad(): embedding = self.clip_model.encode_image(image_tensor) embedding = embedding / embedding.norm(dim=-1, keepdim=True) # Normaliser return embedding.cpu().numpy().flatten() except Exception as e: if self.logger: self.logger.log_action({ "action": "encoding_error", "error": str(e) }) raise RuntimeError(f"Erreur lors de l'encodage de l'image: {e}") def add_to_index(self, embedding: np.ndarray, metadata: Dict[str, Any]) -> int: """ Ajoute un embedding à l'index FAISS avec ses métadonnées. Args: embedding: Embedding numpy array metadata: Dictionnaire de métadonnées associées Returns: ID de l'embedding dans l'index """ try: # Obtenir l'ID avant l'ajout idx = self.faiss_index.ntotal # Ajouter à l'index FAISS embedding_2d = embedding.reshape(1, -1).astype(np.float32) self.faiss_index.add(embedding_2d) # Stocker les métadonnées self.metadata_store[idx] = metadata if self.logger: self.logger.log_action({ "action": "embedding_added", "id": idx, "metadata": metadata }) return idx except Exception as e: if self.logger: self.logger.log_action({ "action": "add_to_index_error", "error": str(e) }) raise RuntimeError(f"Erreur lors de l'ajout à l'index: {e}") def search_similar( self, query_embedding: np.ndarray, k: int = 5 ) -> List[Dict[str, Any]]: """ Recherche les k embeddings les plus similaires. Args: query_embedding: Embedding de requête k: Nombre de résultats à retourner Returns: Liste de dictionnaires avec id, distance et metadata """ try: if self.faiss_index.ntotal == 0: return [] # Limiter k au nombre d'embeddings disponibles k = min(k, self.faiss_index.ntotal) # Rechercher query_2d = query_embedding.reshape(1, -1).astype(np.float32) distances, indices = self.faiss_index.search(query_2d, k) # Formater les résultats results = [] for dist, idx in zip(distances[0], indices[0]): if idx != -1: # FAISS retourne -1 si pas assez de résultats results.append({ "id": int(idx), "distance": float(dist), "similarity": float(1.0 / (1.0 + dist)), # Convertir distance en similarité "metadata": self.metadata_store.get(int(idx), {}) }) return results except Exception as e: if self.logger: self.logger.log_action({ "action": "search_error", "error": str(e) }) raise RuntimeError(f"Erreur lors de la recherche: {e}") def get_embedding_similarity( self, emb1: np.ndarray, emb2: np.ndarray ) -> float: """ Calcule la similarité cosinus entre deux embeddings. Args: emb1: Premier embedding emb2: Deuxième embedding Returns: Similarité cosinus (0-1) """ try: # Normaliser les embeddings emb1_norm = emb1 / np.linalg.norm(emb1) emb2_norm = emb2 / np.linalg.norm(emb2) # Calculer la similarité cosinus similarity = np.dot(emb1_norm, emb2_norm) # Convertir de [-1, 1] à [0, 1] similarity = (similarity + 1.0) / 2.0 return float(similarity) except Exception as e: if self.logger: self.logger.log_action({ "action": "similarity_calculation_error", "error": str(e) }) return 0.0 def rebuild_index(self): """Reconstruit l'index FAISS à partir des embeddings stockés.""" try: if self.faiss_index.ntotal == 0: if self.logger: self.logger.log_action({ "action": "rebuild_skipped", "reason": "index_empty" }) return # Extraire tous les embeddings all_embeddings = [] for i in range(self.faiss_index.ntotal): embedding = self.faiss_index.reconstruct(i) all_embeddings.append(embedding) # Créer un nouvel index new_index = faiss.IndexFlatL2(self.embedding_dim) # Ajouter tous les embeddings embeddings_array = np.array(all_embeddings).astype(np.float32) new_index.add(embeddings_array) # Remplacer l'ancien index self.faiss_index = new_index if self.logger: self.logger.log_action({ "action": "index_rebuilt", "num_vectors": self.faiss_index.ntotal }) except Exception as e: if self.logger: self.logger.log_action({ "action": "rebuild_error", "error": str(e) }) raise RuntimeError(f"Erreur lors de la reconstruction de l'index: {e}") def save_index(self): """Sauvegarde l'index FAISS et les métadonnées sur disque.""" try: # Sauvegarder l'index FAISS faiss.write_index(self.faiss_index, str(self.index_file)) # Sauvegarder les métadonnées with open(self.metadata_file, 'wb') as f: pickle.dump(self.metadata_store, f) if self.logger: self.logger.log_action({ "action": "index_saved", "num_vectors": self.faiss_index.ntotal, "path": str(self.index_file) }) except Exception as e: if self.logger: self.logger.log_action({ "action": "save_error", "error": str(e) }) raise RuntimeError(f"Erreur lors de la sauvegarde de l'index: {e}") def get_stats(self) -> Dict[str, Any]: """ Retourne des statistiques sur l'index. Returns: Dictionnaire de statistiques """ return { "num_embeddings": self.faiss_index.ntotal, "embedding_dim": self.embedding_dim, "model_name": self.model_name, "device": self.device, "index_path": str(self.index_file) } def clear_index(self): """Efface tous les embeddings de l'index.""" self.faiss_index = faiss.IndexFlatL2(self.embedding_dim) self.metadata_store = {} if self.logger: self.logger.log_action({ "action": "index_cleared" })