Files
Geniusia_v2/geniusia2/core/embeddings_manager.py
2026-03-05 00:20:25 +01:00

414 lines
14 KiB
Python

"""
Gestionnaire d'embeddings visuels avec OpenCLIP et FAISS.
Gère l'encodage d'images, l'indexation et la recherche de similarité.
"""
import os
import json
import pickle
import numpy as np
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
import torch
from PIL import Image
try:
import open_clip
except ImportError:
open_clip = None
try:
import faiss
except ImportError:
faiss = None
from .logger import Logger
class EmbeddingsManager:
"""
Gestionnaire d'embeddings visuels utilisant OpenCLIP pour l'encodage
et FAISS pour l'indexation et la recherche de similarité.
"""
def __init__(
self,
model_name: str = "ViT-B-32",
pretrained: str = "openai",
index_path: str = "data/faiss_index",
device: Optional[str] = None,
logger: Optional[Logger] = None
):
"""
Initialise le gestionnaire d'embeddings.
Args:
model_name: Nom du modèle OpenCLIP
pretrained: Dataset de pré-entraînement
index_path: Chemin vers l'index FAISS
device: Device PyTorch (cuda/cpu)
logger: Instance du logger
"""
self.model_name = model_name
self.pretrained = pretrained
self.index_path = Path(index_path)
# Forcer CPU pour OpenCLIP pour économiser la mémoire GPU (Qwen3-VL prioritaire)
self.device = "cpu"
self.logger = logger
if self.logger:
self.logger.log_action({
"action": "embeddings_manager_init",
"device": self.device,
"reason": "CPU forcé pour économiser GPU pour Qwen3-VL"
})
# Créer le répertoire d'index si nécessaire
self.index_path.mkdir(parents=True, exist_ok=True)
# Chemins des fichiers
self.index_file = self.index_path / "embeddings.index"
self.metadata_file = self.index_path / "metadata.pkl"
# Initialiser le modèle et l'index
self.clip_model = None
self.preprocess = None
self.faiss_index = None
self.metadata_store: Dict[int, Dict[str, Any]] = {}
self.embedding_dim = 512 # Dimension par défaut pour ViT-B-32
self._load_model()
self._load_or_create_index()
def _load_model(self):
"""Charge le modèle OpenCLIP."""
if open_clip is None:
raise ImportError(
"OpenCLIP n'est pas installé. "
"Installez-le avec: pip install open-clip-torch"
)
try:
self.clip_model, _, self.preprocess = open_clip.create_model_and_transforms(
self.model_name,
pretrained=self.pretrained,
device=self.device
)
self.clip_model.eval()
# Obtenir la dimension d'embedding réelle
with torch.no_grad():
dummy_image = torch.zeros(1, 3, 224, 224).to(self.device)
dummy_embedding = self.clip_model.encode_image(dummy_image)
self.embedding_dim = dummy_embedding.shape[-1]
if self.logger:
self.logger.log_action({
"action": "model_loaded",
"model": self.model_name,
"device": self.device,
"embedding_dim": self.embedding_dim
})
except Exception as e:
error_msg = f"Erreur lors du chargement du modèle OpenCLIP: {e}"
if self.logger:
self.logger.log_action({
"action": "model_load_error",
"error": str(e)
})
raise RuntimeError(error_msg)
def _load_or_create_index(self):
"""Charge l'index FAISS existant ou en crée un nouveau."""
if faiss is None:
raise ImportError(
"FAISS n'est pas installé. "
"Installez-le avec: pip install faiss-cpu ou faiss-gpu"
)
# Charger l'index existant
if self.index_file.exists() and self.metadata_file.exists():
try:
self.faiss_index = faiss.read_index(str(self.index_file))
with open(self.metadata_file, 'rb') as f:
self.metadata_store = pickle.load(f)
if self.logger:
self.logger.log_action({
"action": "index_loaded",
"num_vectors": self.faiss_index.ntotal,
"path": str(self.index_file)
})
return
except Exception as e:
if self.logger:
self.logger.log_action({
"action": "index_load_error",
"error": str(e)
})
# Continuer pour créer un nouvel index
# Créer un nouvel index
self.faiss_index = faiss.IndexFlatL2(self.embedding_dim)
self.metadata_store = {}
if self.logger:
self.logger.log_action({
"action": "index_created",
"embedding_dim": self.embedding_dim
})
def encode_image(self, image: np.ndarray) -> np.ndarray:
"""
Génère un embedding 512-d pour une image.
Args:
image: Image numpy array (H, W, C) en BGR ou RGB
Returns:
Embedding numpy array de forme (embedding_dim,)
"""
try:
# Convertir BGR vers RGB si nécessaire
if len(image.shape) == 3 and image.shape[2] == 3:
# Supposer BGR (format OpenCV)
image_rgb = image[:, :, ::-1]
else:
image_rgb = image
# Convertir en PIL Image
pil_image = Image.fromarray(image_rgb.astype(np.uint8))
# Prétraiter l'image
image_tensor = self.preprocess(pil_image).unsqueeze(0).to(self.device)
# Générer l'embedding
with torch.no_grad():
embedding = self.clip_model.encode_image(image_tensor)
embedding = embedding / embedding.norm(dim=-1, keepdim=True) # Normaliser
return embedding.cpu().numpy().flatten()
except Exception as e:
if self.logger:
self.logger.log_action({
"action": "encoding_error",
"error": str(e)
})
raise RuntimeError(f"Erreur lors de l'encodage de l'image: {e}")
def add_to_index(self, embedding: np.ndarray, metadata: Dict[str, Any]) -> int:
"""
Ajoute un embedding à l'index FAISS avec ses métadonnées.
Args:
embedding: Embedding numpy array
metadata: Dictionnaire de métadonnées associées
Returns:
ID de l'embedding dans l'index
"""
try:
# Obtenir l'ID avant l'ajout
idx = self.faiss_index.ntotal
# Ajouter à l'index FAISS
embedding_2d = embedding.reshape(1, -1).astype(np.float32)
self.faiss_index.add(embedding_2d)
# Stocker les métadonnées
self.metadata_store[idx] = metadata
if self.logger:
self.logger.log_action({
"action": "embedding_added",
"id": idx,
"metadata": metadata
})
return idx
except Exception as e:
if self.logger:
self.logger.log_action({
"action": "add_to_index_error",
"error": str(e)
})
raise RuntimeError(f"Erreur lors de l'ajout à l'index: {e}")
def search_similar(
self,
query_embedding: np.ndarray,
k: int = 5
) -> List[Dict[str, Any]]:
"""
Recherche les k embeddings les plus similaires.
Args:
query_embedding: Embedding de requête
k: Nombre de résultats à retourner
Returns:
Liste de dictionnaires avec id, distance et metadata
"""
try:
if self.faiss_index.ntotal == 0:
return []
# Limiter k au nombre d'embeddings disponibles
k = min(k, self.faiss_index.ntotal)
# Rechercher
query_2d = query_embedding.reshape(1, -1).astype(np.float32)
distances, indices = self.faiss_index.search(query_2d, k)
# Formater les résultats
results = []
for dist, idx in zip(distances[0], indices[0]):
if idx != -1: # FAISS retourne -1 si pas assez de résultats
results.append({
"id": int(idx),
"distance": float(dist),
"similarity": float(1.0 / (1.0 + dist)), # Convertir distance en similarité
"metadata": self.metadata_store.get(int(idx), {})
})
return results
except Exception as e:
if self.logger:
self.logger.log_action({
"action": "search_error",
"error": str(e)
})
raise RuntimeError(f"Erreur lors de la recherche: {e}")
def get_embedding_similarity(
self,
emb1: np.ndarray,
emb2: np.ndarray
) -> float:
"""
Calcule la similarité cosinus entre deux embeddings.
Args:
emb1: Premier embedding
emb2: Deuxième embedding
Returns:
Similarité cosinus (0-1)
"""
try:
# Normaliser les embeddings
emb1_norm = emb1 / np.linalg.norm(emb1)
emb2_norm = emb2 / np.linalg.norm(emb2)
# Calculer la similarité cosinus
similarity = np.dot(emb1_norm, emb2_norm)
# Convertir de [-1, 1] à [0, 1]
similarity = (similarity + 1.0) / 2.0
return float(similarity)
except Exception as e:
if self.logger:
self.logger.log_action({
"action": "similarity_calculation_error",
"error": str(e)
})
return 0.0
def rebuild_index(self):
"""Reconstruit l'index FAISS à partir des embeddings stockés."""
try:
if self.faiss_index.ntotal == 0:
if self.logger:
self.logger.log_action({
"action": "rebuild_skipped",
"reason": "index_empty"
})
return
# Extraire tous les embeddings
all_embeddings = []
for i in range(self.faiss_index.ntotal):
embedding = self.faiss_index.reconstruct(i)
all_embeddings.append(embedding)
# Créer un nouvel index
new_index = faiss.IndexFlatL2(self.embedding_dim)
# Ajouter tous les embeddings
embeddings_array = np.array(all_embeddings).astype(np.float32)
new_index.add(embeddings_array)
# Remplacer l'ancien index
self.faiss_index = new_index
if self.logger:
self.logger.log_action({
"action": "index_rebuilt",
"num_vectors": self.faiss_index.ntotal
})
except Exception as e:
if self.logger:
self.logger.log_action({
"action": "rebuild_error",
"error": str(e)
})
raise RuntimeError(f"Erreur lors de la reconstruction de l'index: {e}")
def save_index(self):
"""Sauvegarde l'index FAISS et les métadonnées sur disque."""
try:
# Sauvegarder l'index FAISS
faiss.write_index(self.faiss_index, str(self.index_file))
# Sauvegarder les métadonnées
with open(self.metadata_file, 'wb') as f:
pickle.dump(self.metadata_store, f)
if self.logger:
self.logger.log_action({
"action": "index_saved",
"num_vectors": self.faiss_index.ntotal,
"path": str(self.index_file)
})
except Exception as e:
if self.logger:
self.logger.log_action({
"action": "save_error",
"error": str(e)
})
raise RuntimeError(f"Erreur lors de la sauvegarde de l'index: {e}")
def get_stats(self) -> Dict[str, Any]:
"""
Retourne des statistiques sur l'index.
Returns:
Dictionnaire de statistiques
"""
return {
"num_embeddings": self.faiss_index.ntotal,
"embedding_dim": self.embedding_dim,
"model_name": self.model_name,
"device": self.device,
"index_path": str(self.index_file)
}
def clear_index(self):
"""Efface tous les embeddings de l'index."""
self.faiss_index = faiss.IndexFlatL2(self.embedding_dim)
self.metadata_store = {}
if self.logger:
self.logger.log_action({
"action": "index_cleared"
})