Initial commit
This commit is contained in:
413
geniusia2/core/embeddings_manager.py
Normal file
413
geniusia2/core/embeddings_manager.py
Normal file
@@ -0,0 +1,413 @@
|
||||
"""
|
||||
Gestionnaire d'embeddings visuels avec OpenCLIP et FAISS.
|
||||
Gère l'encodage d'images, l'indexation et la recherche de similarité.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import pickle
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
try:
|
||||
import open_clip
|
||||
except ImportError:
|
||||
open_clip = None
|
||||
|
||||
try:
|
||||
import faiss
|
||||
except ImportError:
|
||||
faiss = None
|
||||
|
||||
from .logger import Logger
|
||||
|
||||
|
||||
class EmbeddingsManager:
    """Visual-embeddings manager built on OpenCLIP and FAISS.

    Encodes images into fixed-size vectors with an OpenCLIP model (forced to
    CPU) and stores them in a flat-L2 FAISS index for nearest-neighbour
    search. The index and a parallel metadata dict (FAISS row id -> metadata)
    are persisted under ``index_path`` as ``embeddings.index`` and
    ``metadata.pkl``.
    """

    def __init__(
        self,
        model_name: str = "ViT-B-32",
        pretrained: str = "openai",
        index_path: str = "data/faiss_index",
        device: Optional[str] = None,
        logger: Optional["Logger"] = None
    ):
        """Initialise the embeddings manager and eagerly load model + index.

        Args:
            model_name: OpenCLIP model name (e.g. "ViT-B-32").
            pretrained: Pretraining tag passed to OpenCLIP (e.g. "openai").
            index_path: Directory holding the FAISS index and metadata files.
            device: Requested PyTorch device. Currently IGNORED — CPU is
                deliberately forced to keep GPU memory free for Qwen3-VL.
            logger: Optional Logger instance used for structured action logs.

        Raises:
            ImportError: If open_clip or faiss is not installed.
            RuntimeError: If the OpenCLIP model fails to load.
        """
        self.model_name = model_name
        self.pretrained = pretrained
        self.index_path = Path(index_path)
        # Force CPU for OpenCLIP to save GPU memory (Qwen3-VL has priority).
        # NOTE(review): the `device` parameter is intentionally unused; kept
        # in the signature for backward compatibility with existing callers.
        self.device = "cpu"
        self.logger = logger

        if self.logger:
            self.logger.log_action({
                "action": "embeddings_manager_init",
                "device": self.device,
                "reason": "CPU forcé pour économiser GPU pour Qwen3-VL"
            })

        # Create the index directory if necessary.
        self.index_path.mkdir(parents=True, exist_ok=True)

        # On-disk locations for the FAISS index and its metadata.
        self.index_file = self.index_path / "embeddings.index"
        self.metadata_file = self.index_path / "metadata.pkl"

        # Model / index state, populated by the two loaders below.
        self.clip_model = None
        self.preprocess = None
        self.faiss_index = None
        self.metadata_store: Dict[int, Dict[str, Any]] = {}
        self.embedding_dim = 512  # Default for ViT-B-32; refined in _load_model().

        self._load_model()
        self._load_or_create_index()

    def _load_model(self):
        """Load the OpenCLIP model and measure its real embedding dimension.

        Raises:
            ImportError: If open_clip is not importable.
            RuntimeError: If model creation or the probe forward pass fails.
        """
        if open_clip is None:
            raise ImportError(
                "OpenCLIP n'est pas installé. "
                "Installez-le avec: pip install open-clip-torch"
            )

        try:
            self.clip_model, _, self.preprocess = open_clip.create_model_and_transforms(
                self.model_name,
                pretrained=self.pretrained,
                device=self.device
            )
            self.clip_model.eval()

            # Probe with a dummy image to discover the actual embedding
            # dimension instead of trusting the 512 default.
            with torch.no_grad():
                dummy_image = torch.zeros(1, 3, 224, 224).to(self.device)
                dummy_embedding = self.clip_model.encode_image(dummy_image)
                self.embedding_dim = dummy_embedding.shape[-1]

            if self.logger:
                self.logger.log_action({
                    "action": "model_loaded",
                    "model": self.model_name,
                    "device": self.device,
                    "embedding_dim": self.embedding_dim
                })

        except Exception as e:
            error_msg = f"Erreur lors du chargement du modèle OpenCLIP: {e}"
            if self.logger:
                self.logger.log_action({
                    "action": "model_load_error",
                    "error": str(e)
                })
            # Chain the original exception so the root cause stays visible.
            raise RuntimeError(error_msg) from e

    def _load_or_create_index(self):
        """Load the persisted FAISS index, or create a fresh empty one.

        A load failure is logged and silently falls through to creating a new
        empty index (best-effort recovery, not a crash).

        Raises:
            ImportError: If faiss is not importable.
        """
        if faiss is None:
            raise ImportError(
                "FAISS n'est pas installé. "
                "Installez-le avec: pip install faiss-cpu ou faiss-gpu"
            )

        # Try to load an existing index + metadata pair from disk.
        if self.index_file.exists() and self.metadata_file.exists():
            try:
                self.faiss_index = faiss.read_index(str(self.index_file))
                # NOTE(review): pickle.load on a local file — only safe as
                # long as the metadata file is trusted (we wrote it).
                with open(self.metadata_file, 'rb') as f:
                    self.metadata_store = pickle.load(f)

                if self.logger:
                    self.logger.log_action({
                        "action": "index_loaded",
                        "num_vectors": self.faiss_index.ntotal,
                        "path": str(self.index_file)
                    })
                return

            except Exception as e:
                if self.logger:
                    self.logger.log_action({
                        "action": "index_load_error",
                        "error": str(e)
                    })
                # Fall through and create a new index.

        # Create a new empty flat-L2 index.
        self.faiss_index = faiss.IndexFlatL2(self.embedding_dim)
        self.metadata_store = {}

        if self.logger:
            self.logger.log_action({
                "action": "index_created",
                "embedding_dim": self.embedding_dim
            })

    def encode_image(self, image: np.ndarray) -> np.ndarray:
        """Compute an L2-normalised CLIP embedding for one image.

        Args:
            image: numpy array (H, W, C). A 3-channel array is assumed to be
                BGR (OpenCV convention) and is flipped to RGB; anything else
                is passed to PIL as-is.

        Returns:
            1-D float numpy array of length ``embedding_dim``.

        Raises:
            RuntimeError: If preprocessing or encoding fails.
        """
        try:
            # Flip BGR -> RGB when the input looks like an OpenCV image.
            if len(image.shape) == 3 and image.shape[2] == 3:
                image_rgb = image[:, :, ::-1]
            else:
                image_rgb = image

            # astype() copies, which also fixes the negative stride from the
            # channel flip before handing the buffer to PIL.
            pil_image = Image.fromarray(image_rgb.astype(np.uint8))

            # Apply the model's preprocessing transform and add a batch dim.
            image_tensor = self.preprocess(pil_image).unsqueeze(0).to(self.device)

            # Encode and L2-normalise so downstream distances are comparable.
            with torch.no_grad():
                embedding = self.clip_model.encode_image(image_tensor)
                embedding = embedding / embedding.norm(dim=-1, keepdim=True)

            return embedding.cpu().numpy().flatten()

        except Exception as e:
            if self.logger:
                self.logger.log_action({
                    "action": "encoding_error",
                    "error": str(e)
                })
            raise RuntimeError(f"Erreur lors de l'encodage de l'image: {e}") from e

    def add_to_index(self, embedding: np.ndarray, metadata: Dict[str, Any]) -> int:
        """Append an embedding to the FAISS index and record its metadata.

        Args:
            embedding: 1-D embedding array (any float dtype; cast to float32).
            metadata: Arbitrary metadata dict associated with this vector.

        Returns:
            The FAISS row id assigned to the new vector.

        Raises:
            RuntimeError: If the FAISS add fails.
        """
        try:
            # The new vector's id is the current size of the index.
            idx = self.faiss_index.ntotal

            # FAISS expects a (1, dim) float32 matrix.
            embedding_2d = embedding.reshape(1, -1).astype(np.float32)
            self.faiss_index.add(embedding_2d)

            # Keep metadata keyed by the FAISS row id.
            self.metadata_store[idx] = metadata

            if self.logger:
                self.logger.log_action({
                    "action": "embedding_added",
                    "id": idx,
                    "metadata": metadata
                })

            return idx

        except Exception as e:
            if self.logger:
                self.logger.log_action({
                    "action": "add_to_index_error",
                    "error": str(e)
                })
            raise RuntimeError(f"Erreur lors de l'ajout à l'index: {e}") from e

    def search_similar(
        self,
        query_embedding: np.ndarray,
        k: int = 5
    ) -> List[Dict[str, Any]]:
        """Return the k nearest stored embeddings to the query.

        Args:
            query_embedding: 1-D query embedding.
            k: Maximum number of results (clamped to the index size).

        Returns:
            List of dicts with keys ``id``, ``distance`` (L2), ``similarity``
            (1 / (1 + distance)) and ``metadata``; empty list if the index
            holds no vectors.

        Raises:
            RuntimeError: If the FAISS search fails.
        """
        try:
            if self.faiss_index.ntotal == 0:
                return []

            # Never ask FAISS for more neighbours than exist.
            k = min(k, self.faiss_index.ntotal)

            query_2d = query_embedding.reshape(1, -1).astype(np.float32)
            distances, indices = self.faiss_index.search(query_2d, k)

            # FAISS pads missing results with id -1; skip those.
            results = []
            for dist, idx in zip(distances[0], indices[0]):
                if idx != -1:
                    results.append({
                        "id": int(idx),
                        "distance": float(dist),
                        "similarity": float(1.0 / (1.0 + dist)),
                        "metadata": self.metadata_store.get(int(idx), {})
                    })

            return results

        except Exception as e:
            if self.logger:
                self.logger.log_action({
                    "action": "search_error",
                    "error": str(e)
                })
            raise RuntimeError(f"Erreur lors de la recherche: {e}") from e

    def get_embedding_similarity(
        self,
        emb1: np.ndarray,
        emb2: np.ndarray
    ) -> float:
        """Cosine similarity between two embeddings, rescaled to [0, 1].

        Args:
            emb1: First embedding (1-D array).
            emb2: Second embedding (1-D array).

        Returns:
            (cos + 1) / 2 as a float in [0, 1]; 0.0 on error or when either
            vector has zero norm (previously a zero norm leaked NaN).
        """
        try:
            norm1 = np.linalg.norm(emb1)
            norm2 = np.linalg.norm(emb2)
            # Guard against division by zero: a zero vector has no direction,
            # so report no similarity instead of propagating NaN.
            if norm1 == 0.0 or norm2 == 0.0:
                return 0.0

            # Cosine of the angle between the normalised vectors.
            similarity = np.dot(emb1 / norm1, emb2 / norm2)

            # Map from [-1, 1] to [0, 1].
            similarity = (similarity + 1.0) / 2.0

            return float(similarity)

        except Exception as e:
            if self.logger:
                self.logger.log_action({
                    "action": "similarity_calculation_error",
                    "error": str(e)
                })
            return 0.0

    def rebuild_index(self):
        """Rebuild the FAISS index from the vectors it currently holds.

        Reconstructs every stored vector and re-adds them, in order, to a
        fresh flat-L2 index (row ids — and thus metadata keys — are
        preserved). No-op when the index is empty.

        Raises:
            RuntimeError: If reconstruction fails.
        """
        try:
            if self.faiss_index.ntotal == 0:
                if self.logger:
                    self.logger.log_action({
                        "action": "rebuild_skipped",
                        "reason": "index_empty"
                    })
                return

            # Pull every vector back out of the current index.
            all_embeddings = []
            for i in range(self.faiss_index.ntotal):
                embedding = self.faiss_index.reconstruct(i)
                all_embeddings.append(embedding)

            # Re-add them, in the same order, to a fresh index.
            new_index = faiss.IndexFlatL2(self.embedding_dim)
            embeddings_array = np.array(all_embeddings).astype(np.float32)
            new_index.add(embeddings_array)

            # Swap in the rebuilt index.
            self.faiss_index = new_index

            if self.logger:
                self.logger.log_action({
                    "action": "index_rebuilt",
                    "num_vectors": self.faiss_index.ntotal
                })

        except Exception as e:
            if self.logger:
                self.logger.log_action({
                    "action": "rebuild_error",
                    "error": str(e)
                })
            raise RuntimeError(f"Erreur lors de la reconstruction de l'index: {e}") from e

    def save_index(self):
        """Persist the FAISS index and metadata dict to disk.

        Raises:
            RuntimeError: If either the index or the metadata write fails.
        """
        try:
            faiss.write_index(self.faiss_index, str(self.index_file))

            # Metadata is pickled alongside the index; read back only by
            # _load_or_create_index, so the file is considered trusted.
            with open(self.metadata_file, 'wb') as f:
                pickle.dump(self.metadata_store, f)

            if self.logger:
                self.logger.log_action({
                    "action": "index_saved",
                    "num_vectors": self.faiss_index.ntotal,
                    "path": str(self.index_file)
                })

        except Exception as e:
            if self.logger:
                self.logger.log_action({
                    "action": "save_error",
                    "error": str(e)
                })
            raise RuntimeError(f"Erreur lors de la sauvegarde de l'index: {e}") from e

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of index statistics.

        Returns:
            Dict with vector count, embedding dimension, model name, device
            and the on-disk index path.
        """
        return {
            "num_embeddings": self.faiss_index.ntotal,
            "embedding_dim": self.embedding_dim,
            "model_name": self.model_name,
            "device": self.device,
            "index_path": str(self.index_file)
        }

    def clear_index(self):
        """Drop all embeddings: replace the index and wipe the metadata.

        Only affects in-memory state; call save_index() to persist.
        """
        self.faiss_index = faiss.IndexFlatL2(self.embedding_dim)
        self.metadata_store = {}

        if self.logger:
            self.logger.log_action({
                "action": "index_cleared"
            })
|
||||
Reference in New Issue
Block a user