Initial commit
This commit is contained in:
318
test_embedding_system.py
Executable file
318
test_embedding_system.py
Executable file
@@ -0,0 +1,318 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test du nouveau système d'embeddings.
|
||||
|
||||
Ce script teste:
|
||||
1. CLIPEmbedder - Génération d'embeddings
|
||||
2. FAISSIndex - Ajout, recherche, save/load
|
||||
3. EmbeddingManager - Cache, fallback
|
||||
4. Intégration complète
|
||||
"""
|
||||
|
||||
import sys
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
|
||||
# Add project to path
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from geniusia2.core.embedders import CLIPEmbedder, FAISSIndex, EmbeddingManager
|
||||
|
||||
|
||||
def create_test_image(size=(224, 224), color=(255, 0, 0)):
|
||||
"""Crée une image de test."""
|
||||
return Image.new('RGB', size, color)
|
||||
|
||||
|
||||
def test_clip_embedder():
|
||||
"""Test 1: CLIPEmbedder"""
|
||||
print("\n" + "="*60)
|
||||
print("TEST 1: CLIPEmbedder")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
# Initialize
|
||||
print(" Initializing CLIPEmbedder...")
|
||||
embedder = CLIPEmbedder(device='cpu')
|
||||
print(f" ✓ Model loaded: {embedder.get_model_name()}")
|
||||
print(f" ✓ Dimension: {embedder.get_dimension()}")
|
||||
|
||||
# Test single embedding
|
||||
print("\n Testing single embedding...")
|
||||
img = create_test_image(color=(255, 0, 0))
|
||||
embedding = embedder.embed(img)
|
||||
print(f" ✓ Embedding shape: {embedding.shape}")
|
||||
print(f" ✓ Embedding norm: {np.linalg.norm(embedding):.4f} (should be ~1.0)")
|
||||
|
||||
# Test batch embedding
|
||||
print("\n Testing batch embedding...")
|
||||
images = [
|
||||
create_test_image(color=(255, 0, 0)),
|
||||
create_test_image(color=(0, 255, 0)),
|
||||
create_test_image(color=(0, 0, 255))
|
||||
]
|
||||
embeddings = embedder.embed_batch(images)
|
||||
print(f" ✓ Batch embeddings shape: {embeddings.shape}")
|
||||
print(f" ✓ Expected: (3, {embedder.get_dimension()})")
|
||||
|
||||
# Test consistency
|
||||
print("\n Testing consistency...")
|
||||
img1 = create_test_image(color=(128, 128, 128))
|
||||
emb1 = embedder.embed(img1)
|
||||
emb2 = embedder.embed(img1)
|
||||
diff = np.abs(emb1 - emb2).max()
|
||||
print(f" ✓ Max difference: {diff:.10f} (should be 0.0)")
|
||||
|
||||
print("\n ✅ CLIPEmbedder: ALL TESTS PASSED")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n ❌ CLIPEmbedder: FAILED - {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def test_faiss_index():
|
||||
"""Test 2: FAISSIndex"""
|
||||
print("\n" + "="*60)
|
||||
print("TEST 2: FAISSIndex")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
# Initialize
|
||||
print(" Initializing FAISSIndex...")
|
||||
dimension = 512
|
||||
index = FAISSIndex(dimension)
|
||||
print(f" ✓ Index created with dimension={dimension}")
|
||||
|
||||
# Test add
|
||||
print("\n Testing add...")
|
||||
embeddings = np.random.randn(5, dimension).astype('float32')
|
||||
# Normalize
|
||||
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||
|
||||
metadata = [
|
||||
{"id": i, "label": f"item_{i}"}
|
||||
for i in range(5)
|
||||
]
|
||||
index.add(embeddings, metadata)
|
||||
print(f" ✓ Added 5 embeddings")
|
||||
print(f" ✓ Index size: {len(index)}")
|
||||
|
||||
# Test search
|
||||
print("\n Testing search...")
|
||||
query = embeddings[0] # Search for first embedding
|
||||
results = index.search(query, k=3)
|
||||
print(f" ✓ Found {len(results)} results")
|
||||
print(f" ✓ Best match distance: {results[0]['distance']:.6f}")
|
||||
print(f" ✓ Best match metadata: {results[0]['metadata']}")
|
||||
|
||||
# Test save/load
|
||||
print("\n Testing save/load...")
|
||||
test_path = "data/test_faiss_index"
|
||||
Path("data").mkdir(exist_ok=True)
|
||||
|
||||
index.save(test_path)
|
||||
print(f" ✓ Index saved to {test_path}")
|
||||
|
||||
# Load in new index
|
||||
index2 = FAISSIndex(dimension)
|
||||
index2.load(test_path)
|
||||
print(f" ✓ Index loaded")
|
||||
print(f" ✓ Loaded size: {len(index2)}")
|
||||
|
||||
# Verify search results are identical
|
||||
results2 = index2.search(query, k=3)
|
||||
assert len(results) == len(results2), "Result count mismatch"
|
||||
assert abs(results[0]['distance'] - results2[0]['distance']) < 1e-6, "Distance mismatch"
|
||||
print(f" ✓ Search results match after load")
|
||||
|
||||
# Test dimension mismatch
|
||||
print("\n Testing dimension validation...")
|
||||
try:
|
||||
wrong_embedding = np.random.randn(1, 256).astype('float32')
|
||||
index.add(wrong_embedding, [{"id": "wrong"}])
|
||||
print(f" ❌ Should have raised ValueError for dimension mismatch")
|
||||
return False
|
||||
except ValueError as e:
|
||||
print(f" ✓ Correctly rejected wrong dimension: {e}")
|
||||
|
||||
print("\n ✅ FAISSIndex: ALL TESTS PASSED")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n ❌ FAISSIndex: FAILED - {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def test_embedding_manager():
|
||||
"""Test 3: EmbeddingManager"""
|
||||
print("\n" + "="*60)
|
||||
print("TEST 3: EmbeddingManager")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
# Initialize
|
||||
print(" Initializing EmbeddingManager...")
|
||||
manager = EmbeddingManager(model_name="clip", cache_size=10)
|
||||
print(f" ✓ Manager initialized: {manager.get_model_name()}")
|
||||
|
||||
# Test single embedding
|
||||
print("\n Testing single embedding...")
|
||||
img = create_test_image(color=(255, 128, 0))
|
||||
emb1 = manager.embed(img)
|
||||
print(f" ✓ Embedding generated: shape={emb1.shape}")
|
||||
|
||||
# Test cache hit
|
||||
print("\n Testing cache...")
|
||||
emb2 = manager.embed(img) # Should hit cache
|
||||
assert np.array_equal(emb1, emb2), "Cache should return identical embedding"
|
||||
stats = manager.get_stats()
|
||||
print(f" ✓ Cache hits: {stats['cache_hits']}")
|
||||
print(f" ✓ Cache misses: {stats['cache_misses']}")
|
||||
print(f" ✓ Hit rate: {stats['cache_hit_rate']:.1%}")
|
||||
|
||||
# Test cache eviction
|
||||
print("\n Testing cache eviction...")
|
||||
for i in range(15): # More than cache_size=10
|
||||
img = create_test_image(color=(i*10, i*10, i*10))
|
||||
manager.embed(img)
|
||||
|
||||
stats = manager.get_stats()
|
||||
print(f" ✓ Cache size: {stats['cache_size']} (max: {stats['cache_capacity']})")
|
||||
assert stats['cache_size'] <= 10, "Cache should not exceed max size"
|
||||
|
||||
# Test batch embedding
|
||||
print("\n Testing batch embedding...")
|
||||
images = [create_test_image(color=(i, i, i)) for i in range(5)]
|
||||
embeddings = manager.embed_batch(images)
|
||||
print(f" ✓ Batch embeddings: shape={embeddings.shape}")
|
||||
|
||||
# Test fallback (simulate)
|
||||
print("\n Testing fallback...")
|
||||
manager_fallback = EmbeddingManager(
|
||||
model_name="nonexistent_model",
|
||||
fallback_enabled=True
|
||||
)
|
||||
print(f" ✓ Fallback to: {manager_fallback.get_model_name()}")
|
||||
assert "clip" in manager_fallback.get_model_name().lower(), "Should fallback to CLIP"
|
||||
|
||||
print("\n ✅ EmbeddingManager: ALL TESTS PASSED")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n ❌ EmbeddingManager: FAILED - {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def test_integration():
|
||||
"""Test 4: Intégration complète"""
|
||||
print("\n" + "="*60)
|
||||
print("TEST 4: Intégration Complète")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
# Setup
|
||||
print(" Setting up integrated system...")
|
||||
manager = EmbeddingManager(model_name="clip")
|
||||
index = FAISSIndex(manager.get_dimension())
|
||||
print(f" ✓ Manager: {manager.get_model_name()}")
|
||||
print(f" ✓ Index dimension: {index.dimension}")
|
||||
|
||||
# Create and index some images
|
||||
print("\n Creating and indexing images...")
|
||||
images = [
|
||||
create_test_image(color=(255, 0, 0)), # Red
|
||||
create_test_image(color=(0, 255, 0)), # Green
|
||||
create_test_image(color=(0, 0, 255)), # Blue
|
||||
create_test_image(color=(255, 255, 0)), # Yellow
|
||||
create_test_image(color=(255, 0, 255)), # Magenta
|
||||
]
|
||||
|
||||
labels = ["red", "green", "blue", "yellow", "magenta"]
|
||||
|
||||
for img, label in zip(images, labels):
|
||||
emb = manager.embed(img)
|
||||
index.add(emb.reshape(1, -1), [{"label": label}])
|
||||
|
||||
print(f" ✓ Indexed {len(index)} images")
|
||||
|
||||
# Search for similar image
|
||||
print("\n Testing similarity search...")
|
||||
query_img = create_test_image(color=(255, 0, 0)) # Red
|
||||
query_emb = manager.embed(query_img)
|
||||
results = index.search(query_emb, k=3)
|
||||
|
||||
print(f" ✓ Found {len(results)} similar images:")
|
||||
for i, result in enumerate(results):
|
||||
print(f" {i+1}. {result['metadata']['label']}: "
|
||||
f"similarity={result['similarity']:.4f}")
|
||||
|
||||
# Verify red is most similar to red
|
||||
assert results[0]['metadata']['label'] == 'red', \
|
||||
"Most similar should be red"
|
||||
print(f" ✓ Correct match found")
|
||||
|
||||
# Test persistence
|
||||
print("\n Testing persistence...")
|
||||
test_path = "data/test_integrated_index"
|
||||
index.save(test_path)
|
||||
|
||||
index2 = FAISSIndex(manager.get_dimension())
|
||||
index2.load(test_path)
|
||||
|
||||
results2 = index2.search(query_emb, k=3)
|
||||
assert results[0]['metadata']['label'] == results2[0]['metadata']['label'], \
|
||||
"Results should match after load"
|
||||
print(f" ✓ Persistence verified")
|
||||
|
||||
print("\n ✅ Integration: ALL TESTS PASSED")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n ❌ Integration: FAILED - {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all tests."""
|
||||
print("\n" + "="*60)
|
||||
print("TESTING NEW EMBEDDING SYSTEM")
|
||||
print("="*60)
|
||||
|
||||
results = {
|
||||
"CLIPEmbedder": test_clip_embedder(),
|
||||
"FAISSIndex": test_faiss_index(),
|
||||
"EmbeddingManager": test_embedding_manager(),
|
||||
"Integration": test_integration()
|
||||
}
|
||||
|
||||
# Summary
|
||||
print("\n" + "="*60)
|
||||
print("TEST SUMMARY")
|
||||
print("="*60)
|
||||
|
||||
for test_name, passed in results.items():
|
||||
status = "✅ PASSED" if passed else "❌ FAILED"
|
||||
print(f" {test_name}: {status}")
|
||||
|
||||
all_passed = all(results.values())
|
||||
|
||||
if all_passed:
|
||||
print("\n🎉 ALL TESTS PASSED! System is ready.")
|
||||
return 0
|
||||
else:
|
||||
print("\n❌ SOME TESTS FAILED. Please review errors above.")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
Reference in New Issue
Block a user