195 lines
5.9 KiB
Python
195 lines
5.9 KiB
Python
"""
Tests for the embeddings manager.
"""
|
|
|
|
import pytest
|
|
import numpy as np
|
|
import tempfile
|
|
import shutil
|
|
from pathlib import Path
|
|
|
|
from geniusia2.core.embeddings_manager import EmbeddingsManager
|
|
|
|
|
|
@pytest.fixture
def temp_index_dir():
    """Provide a throwaway directory path for tests that need an index location.

    Yields:
        str: Path of a freshly created temporary directory. The directory
        and its contents are removed when the fixture is torn down.
    """
    # The context manager ties removal to leaving the ``with`` block, so the
    # directory is also cleaned up if the generator is closed at the yield
    # instead of being resumed normally.
    with tempfile.TemporaryDirectory() as temp_dir:
        yield temp_dir
|
|
|
|
|
|
@pytest.fixture
def embeddings_manager(temp_index_dir):
    """Build the EmbeddingsManager instance used by the tests.

    Args:
        temp_index_dir: Temporary directory path supplied by the
            ``temp_index_dir`` fixture; passed through as ``index_path``.

    Returns:
        An EmbeddingsManager created with the ``ViT-B-32`` model name, the
        ``openai`` pretrained tag and the ``cpu`` device.
    """
    settings = {
        "model_name": "ViT-B-32",
        "pretrained": "openai",
        "index_path": temp_index_dir,
        "device": "cpu",
    }
    return EmbeddingsManager(**settings)
|
|
|
|
|
|
def test_initialization(embeddings_manager):
    """Check that a new manager has a model, an index and a positive dimension."""
    manager = embeddings_manager

    assert manager.clip_model is not None
    assert manager.faiss_index is not None
    assert manager.embedding_dim > 0
|
|
|
|
|
|
def test_encode_image(embeddings_manager):
    """Check that encoding an image yields a 1-D float vector of the expected size."""
    # Random 224x224 RGB test image. The upper bound of ``randint`` is
    # exclusive, so 256 is required for pixel values to cover 0..255.
    test_image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)

    # Encode the image.
    embedding = embeddings_manager.encode_image(test_image)

    # Verify the container type, the shape and the element type.
    assert isinstance(embedding, np.ndarray)
    assert embedding.shape == (embeddings_manager.embedding_dim,)
    assert embedding.dtype in (np.float32, np.float64)
|
|
|
|
|
|
def test_add_to_index(embeddings_manager):
    """Check that adding one embedding registers it in the index and the metadata store."""
    # One random vector of the manager's dimension, plus its metadata.
    dim = embeddings_manager.embedding_dim
    vector = np.random.randn(dim).astype(np.float32)
    info = {"task_id": "test_task", "label": "button"}

    # Insert it into the index.
    position = embeddings_manager.add_to_index(vector, info)

    # The first insertion is expected at position 0 and must be tracked.
    assert position == 0
    assert embeddings_manager.faiss_index.ntotal == 1
    assert position in embeddings_manager.metadata_store
|
|
|
|
|
|
def test_search_similar(embeddings_manager):
    """Check that a similarity search returns the requested number of well-formed hits.

    Five random vectors are indexed, the first one is used as the query,
    and the top hit is expected to be that same entry.
    """
    # Index several random embeddings, remembering each one.
    stored = []
    for i in range(5):
        vector = np.random.randn(embeddings_manager.embedding_dim).astype(np.float32)
        stored.append(vector)
        embeddings_manager.add_to_index(vector, {"id": i})

    # Search using the first stored embedding as the query.
    hits = embeddings_manager.search_similar(stored[0], k=3)

    # Verify the number of hits and the fields of the best one.
    assert len(hits) == 3
    assert hits[0]["id"] == 0  # the query itself should be the closest match
    assert "distance" in hits[0]
    assert "similarity" in hits[0]
    assert "metadata" in hits[0]
|
|
|
|
|
|
def test_embedding_similarity(embeddings_manager):
    """Check the similarity score for identical and for orthogonal embeddings.

    Identical vectors are expected to score about 1.0 and orthogonal
    vectors about 0.5, as asserted below.
    """
    # Slack for floating-point rounding: a score computed from float32
    # vectors can land a hair above the ideal value of 1.0.
    rounding_slack = 1e-6

    # Two identical embeddings.
    emb1 = np.random.randn(embeddings_manager.embedding_dim).astype(np.float32)
    emb2 = emb1.copy()

    similarity = embeddings_manager.get_embedding_similarity(emb1, emb2)

    # Should be close to 1.0.
    assert 0.99 <= similarity <= 1.0 + rounding_slack

    # Two orthogonal embeddings: unit vectors along different axes.
    emb3 = np.zeros(embeddings_manager.embedding_dim, dtype=np.float32)
    emb3[0] = 1.0
    emb4 = np.zeros(embeddings_manager.embedding_dim, dtype=np.float32)
    emb4[1] = 1.0

    similarity_ortho = embeddings_manager.get_embedding_similarity(emb3, emb4)

    # Should be close to 0.5 (orthogonal).
    assert 0.4 <= similarity_ortho <= 0.6
|
|
|
|
|
|
def test_save_and_load_index(temp_index_dir):
    """Check that an index saved by one manager is visible to a new manager.

    Args:
        temp_index_dir: Temporary directory path used as ``index_path`` by
            both manager instances.
    """

    def build_manager():
        # Both managers use the same arguments and the same directory.
        return EmbeddingsManager(
            model_name="ViT-B-32",
            index_path=temp_index_dir,
            device="cpu",
        )

    # First manager: add one entry and save the index.
    writer = build_manager()
    vector = np.random.randn(writer.embedding_dim).astype(np.float32)
    writer.add_to_index(vector, {"test": "data"})
    writer.save_index()

    # Second manager created on the same directory.
    reader = build_manager()

    # The saved entry and its metadata must be present.
    assert reader.faiss_index.ntotal == 1
    assert 0 in reader.metadata_store
    assert reader.metadata_store[0]["test"] == "data"
|
|
|
|
|
|
def test_rebuild_index(embeddings_manager):
    """Check that rebuilding the index keeps the number of stored embeddings."""
    # Add a few random embeddings.
    for i in range(3):
        vector = np.random.randn(embeddings_manager.embedding_dim).astype(np.float32)
        embeddings_manager.add_to_index(vector, {"id": i})

    count_before = embeddings_manager.faiss_index.ntotal

    # Rebuild the index.
    embeddings_manager.rebuild_index()

    # The embedding count must be unchanged.
    count_after = embeddings_manager.faiss_index.ntotal
    assert count_after == count_before
|
|
|
|
|
|
def test_get_stats(embeddings_manager):
    """Check that the statistics contain the expected keys and model name."""
    stats = embeddings_manager.get_stats()

    # Every expected key must be present, checked in this order.
    for key in ("num_embeddings", "embedding_dim", "model_name", "device"):
        assert key in stats

    assert stats["model_name"] == "ViT-B-32"
|
|
|
|
|
|
def test_clear_index(embeddings_manager):
    """Check that clearing the index removes all embeddings and metadata."""
    # Add three random embeddings.
    for i in range(3):
        vector = np.random.randn(embeddings_manager.embedding_dim).astype(np.float32)
        embeddings_manager.add_to_index(vector, {"id": i})

    assert embeddings_manager.faiss_index.ntotal == 3

    # Clear the index.
    embeddings_manager.clear_index()

    # Both the index and the metadata store must be empty.
    assert embeddings_manager.faiss_index.ntotal == 0
    assert len(embeddings_manager.metadata_store) == 0
|
|
|
|
|
|
def test_empty_search(embeddings_manager):
    """Check that searching an index with no entries returns an empty list."""
    dim = embeddings_manager.embedding_dim
    probe = np.random.randn(dim).astype(np.float32)

    hits = embeddings_manager.search_similar(probe, k=5)

    assert hits == []
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this file's tests verbosely and pass pytest's exit code back to
    # the caller, so a failing run is not reported as success.
    raise SystemExit(pytest.main([__file__, "-v"]))
|