Initial commit
This commit is contained in:
194
geniusia2/tests/test_embeddings_manager.py
Normal file
194
geniusia2/tests/test_embeddings_manager.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
Tests pour le gestionnaire d'embeddings.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import tempfile
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from geniusia2.core.embeddings_manager import EmbeddingsManager
|
||||
|
||||
|
||||
@pytest.fixture
def temp_index_dir():
    """Yield a freshly created scratch directory and delete it afterwards."""
    scratch_path = tempfile.mkdtemp()
    yield scratch_path
    # Teardown: remove the directory together with anything written into it.
    shutil.rmtree(scratch_path)
|
||||
|
||||
|
||||
@pytest.fixture
def embeddings_manager(temp_index_dir):
    """Build the EmbeddingsManager instance shared by the tests below.

    The manager is configured for the "ViT-B-32" / "openai" model on the CPU
    and keeps its index inside the temporary directory fixture.
    """
    settings = {
        "model_name": "ViT-B-32",
        "pretrained": "openai",
        "index_path": temp_index_dir,
        "device": "cpu",
    }
    return EmbeddingsManager(**settings)
|
||||
|
||||
|
||||
def test_initialization(embeddings_manager):
    """Check that a new manager has a model, an index and a positive dimension."""
    for attribute in ("clip_model", "faiss_index"):
        assert getattr(embeddings_manager, attribute) is not None
    assert embeddings_manager.embedding_dim > 0
|
||||
|
||||
|
||||
def test_encode_image(embeddings_manager):
    """Check that encode_image returns a 1-D float vector of embedding_dim."""
    # Build a random 224x224 RGB test image. The upper bound of
    # np.random.randint is exclusive, so 256 is required for the full
    # uint8 range (0..255) to be reachable.
    test_image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)

    # Encode the image.
    embedding = embeddings_manager.encode_image(test_image)

    # Verify container type, shape and floating-point dtype.
    assert isinstance(embedding, np.ndarray)
    assert embedding.shape == (embeddings_manager.embedding_dim,)
    assert embedding.dtype in (np.float32, np.float64)
|
||||
|
||||
|
||||
def test_add_to_index(embeddings_manager):
    """Check that adding one embedding stores it at position 0 with metadata."""
    # One random float32 vector of the manager's embedding dimension.
    dim = embeddings_manager.embedding_dim
    vector = np.random.randn(dim).astype(np.float32)
    info = {"task_id": "test_task", "label": "button"}

    # Insert it into the index.
    position = embeddings_manager.add_to_index(vector, info)

    # The first insertion is expected at index 0, and it must be recorded in
    # both the FAISS index and the metadata store.
    assert position == 0
    assert embeddings_manager.faiss_index.ntotal == 1
    assert position in embeddings_manager.metadata_store
|
||||
|
||||
|
||||
def test_search_similar(embeddings_manager):
    """Check that search_similar returns k results, best match first."""
    # Populate the index with five random vectors tagged by their position.
    dim = embeddings_manager.embedding_dim
    vectors = [np.random.randn(dim).astype(np.float32) for _ in range(5)]
    for position, vector in enumerate(vectors):
        embeddings_manager.add_to_index(vector, {"id": position})

    # Query with the first stored vector.
    hits = embeddings_manager.search_similar(vectors[0], k=3)

    # Exactly k hits, and the closest one should be the query vector itself.
    assert len(hits) == 3
    best = hits[0]
    assert best["id"] == 0
    for field in ("distance", "similarity", "metadata"):
        assert field in best
|
||||
|
||||
|
||||
def test_embedding_similarity(embeddings_manager):
    """Check get_embedding_similarity on identical and on orthogonal vectors."""
    # Slack for floating-point rounding: with float32 inputs the computed
    # similarity of identical vectors may land a hair above 1.0.
    tolerance = 1e-6

    # Two identical embeddings.
    emb1 = np.random.randn(embeddings_manager.embedding_dim).astype(np.float32)
    emb2 = emb1.copy()

    similarity = embeddings_manager.get_embedding_similarity(emb1, emb2)

    # Should be close to 1.0.
    assert 0.99 <= similarity <= 1.0 + tolerance

    # Two orthogonal unit vectors (different axes).
    emb3 = np.zeros(embeddings_manager.embedding_dim, dtype=np.float32)
    emb3[0] = 1.0
    emb4 = np.zeros(embeddings_manager.embedding_dim, dtype=np.float32)
    emb4[1] = 1.0

    similarity_ortho = embeddings_manager.get_embedding_similarity(emb3, emb4)

    # Orthogonal vectors are expected to score about 0.5.
    # NOTE(review): presumably the manager rescales cosine similarity from
    # [-1, 1] to [0, 1] -- confirm against get_embedding_similarity.
    assert 0.4 <= similarity_ortho <= 0.6
|
||||
|
||||
|
||||
def test_save_and_load_index(temp_index_dir):
    """Check that a saved index is available to a new manager on the same path."""
    # Both managers are constructed with exactly the same arguments.
    manager_kwargs = {
        "model_name": "ViT-B-32",
        "index_path": temp_index_dir,
        "device": "cpu",
    }

    # First manager: add one embedding and save the index.
    writer = EmbeddingsManager(**manager_kwargs)
    vector = np.random.randn(writer.embedding_dim).astype(np.float32)
    writer.add_to_index(vector, {"test": "data"})
    writer.save_index()

    # Second manager: created on the same index path.
    reader = EmbeddingsManager(**manager_kwargs)

    # The embedding and its metadata must be present in the second manager.
    assert reader.faiss_index.ntotal == 1
    assert 0 in reader.metadata_store
    assert reader.metadata_store[0]["test"] == "data"
|
||||
|
||||
|
||||
def test_rebuild_index(embeddings_manager):
    """Check that rebuild_index keeps the number of stored embeddings."""
    # Populate the index with three random vectors.
    for tag in range(3):
        vector = np.random.randn(embeddings_manager.embedding_dim).astype(np.float32)
        embeddings_manager.add_to_index(vector, {"id": tag})

    count_before = embeddings_manager.faiss_index.ntotal

    # Rebuild the index.
    embeddings_manager.rebuild_index()

    # The entry count must be unchanged by the rebuild.
    count_after = embeddings_manager.faiss_index.ntotal
    assert count_after == count_before
|
||||
|
||||
|
||||
def test_get_stats(embeddings_manager):
    """Check that get_stats exposes the expected keys and the model name."""
    report = embeddings_manager.get_stats()

    expected_keys = ("num_embeddings", "embedding_dim", "model_name", "device")
    for key in expected_keys:
        assert key in report
    assert report["model_name"] == "ViT-B-32"
|
||||
|
||||
|
||||
def test_clear_index(embeddings_manager):
    """Check that clear_index empties both the index and the metadata store."""
    # Populate the index with three random vectors.
    for tag in range(3):
        vector = np.random.randn(embeddings_manager.embedding_dim).astype(np.float32)
        embeddings_manager.add_to_index(vector, {"id": tag})

    assert embeddings_manager.faiss_index.ntotal == 3

    # Clear the index.
    embeddings_manager.clear_index()

    # Nothing may remain in the index or the metadata store.
    assert embeddings_manager.faiss_index.ntotal == 0
    assert len(embeddings_manager.metadata_store) == 0
|
||||
assert len(embeddings_manager.metadata_store) == 0
|
||||
|
||||
|
||||
def test_empty_search(embeddings_manager):
    """Check that searching an index with no entries returns an empty list."""
    dim = embeddings_manager.embedding_dim
    probe = np.random.randn(dim).astype(np.float32)

    hits = embeddings_manager.search_similar(probe, k=5)

    assert hits == []
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run this file's tests verbosely and propagate pytest's exit status, so
    # that running the file as a script does not exit with 0 on failures.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
||||
Reference in New Issue
Block a user