#!/usr/bin/env python3
"""
Test of the new embedding system.

This script tests:
1. CLIPEmbedder - embedding generation
2. FAISSIndex - add, search, save/load
3. EmbeddingManager - cache, fallback
4. Full integration
"""

import sys
import traceback
from pathlib import Path

import numpy as np
from PIL import Image

# Add project to path
sys.path.insert(0, str(Path(__file__).parent))

from geniusia2.core.embedders import CLIPEmbedder, FAISSIndex, EmbeddingManager

# Horizontal rule printed above and below every section title.
_RULE = "=" * 60


def _print_banner(title):
    """Print *title* framed by two horizontal rules, preceded by a blank line."""
    print("\n" + _RULE)
    print(title)
    print(_RULE)


def _report_failure(name, exc):
    """Print the failure line for test group *name* and the active traceback.

    Must be called from inside an ``except`` block so that
    ``traceback.print_exc()`` has an exception to display.
    """
    print(f"\n ❌ {name}: FAILED - {exc}")
    traceback.print_exc()


def _ensure_data_dir():
    """Create the relative ``data`` directory used for saved indexes."""
    Path("data").mkdir(parents=True, exist_ok=True)


def create_test_image(size=(224, 224), color=(255, 0, 0)):
    """Create a solid-colour RGB test image.

    Args:
        size: ``(width, height)`` passed to ``Image.new``.
        color: RGB tuple used to fill the image.

    Returns:
        A new ``PIL.Image.Image`` in ``'RGB'`` mode.
    """
    return Image.new('RGB', size, color)


def test_clip_embedder():
    """Test 1: CLIPEmbedder.

    Exercises single embedding, batch embedding and repeatability on the CPU.

    Returns:
        True if every check passed, False if any exception (including a
        failed assertion) was raised.
    """
    _print_banner("TEST 1: CLIPEmbedder")

    try:
        # Initialize
        print(" Initializing CLIPEmbedder...")
        embedder = CLIPEmbedder(device='cpu')
        print(f" ✓ Model loaded: {embedder.get_model_name()}")
        print(f" ✓ Dimension: {embedder.get_dimension()}")

        # Test single embedding
        print("\n Testing single embedding...")
        img = create_test_image(color=(255, 0, 0))
        embedding = embedder.embed(img)
        norm = np.linalg.norm(embedding)
        print(f" ✓ Embedding shape: {embedding.shape}")
        print(f" ✓ Embedding norm: {norm:.4f} (should be ~1.0)")
        # Enforce the expectation stated in the message instead of only
        # printing it. NOTE(review): tolerance chosen here; confirm it suits
        # the embedder's normalisation.
        assert abs(norm - 1.0) < 1e-3, "Embedding should be unit-normalized"

        # Test batch embedding
        print("\n Testing batch embedding...")
        images = [
            create_test_image(color=(255, 0, 0)),
            create_test_image(color=(0, 255, 0)),
            create_test_image(color=(0, 0, 255)),
        ]
        embeddings = embedder.embed_batch(images)
        expected_shape = (len(images), embedder.get_dimension())
        print(f" ✓ Batch embeddings shape: {embeddings.shape}")
        print(f" ✓ Expected: (3, {embedder.get_dimension()})")
        assert tuple(embeddings.shape) == expected_shape, "Batch shape mismatch"

        # Test consistency: the same image embedded twice must give the same
        # vector.
        print("\n Testing consistency...")
        img1 = create_test_image(color=(128, 128, 128))
        emb1 = embedder.embed(img1)
        emb2 = embedder.embed(img1)
        diff = np.abs(emb1 - emb2).max()
        print(f" ✓ Max difference: {diff:.10f} (should be 0.0)")
        # NOTE(review): small tolerance rather than exact equality; confirm.
        assert diff <= 1e-6, "Embedding should be deterministic"

        print("\n ✅ CLIPEmbedder: ALL TESTS PASSED")
        return True

    except Exception as e:
        _report_failure("CLIPEmbedder", e)
        return False


def test_faiss_index():
    """Test 2: FAISSIndex.

    Exercises add, search, save/load round-trip and rejection of embeddings
    with the wrong dimension.

    Returns:
        True if every check passed, False otherwise.
    """
    _print_banner("TEST 2: FAISSIndex")

    try:
        # Initialize
        print(" Initializing FAISSIndex...")
        dimension = 512
        index = FAISSIndex(dimension)
        print(f" ✓ Index created with dimension={dimension}")

        # Test add
        print("\n Testing add...")
        embeddings = np.random.randn(5, dimension).astype('float32')
        # Normalize each row to unit length
        embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

        metadata = [{"id": i, "label": f"item_{i}"} for i in range(5)]

        index.add(embeddings, metadata)
        print(" ✓ Added 5 embeddings")
        print(f" ✓ Index size: {len(index)}")

        # Test search
        print("\n Testing search...")
        query = embeddings[0]  # Search for first embedding
        results = index.search(query, k=3)
        print(f" ✓ Found {len(results)} results")
        print(f" ✓ Best match distance: {results[0]['distance']:.6f}")
        print(f" ✓ Best match metadata: {results[0]['metadata']}")

        # Test save/load
        print("\n Testing save/load...")
        test_path = "data/test_faiss_index"
        _ensure_data_dir()
        index.save(test_path)
        print(f" ✓ Index saved to {test_path}")

        # Load in new index
        index2 = FAISSIndex(dimension)
        index2.load(test_path)
        print(" ✓ Index loaded")
        print(f" ✓ Loaded size: {len(index2)}")
        assert len(index2) == len(index), "Loaded size mismatch"

        # Verify search results are identical
        results2 = index2.search(query, k=3)
        assert len(results) == len(results2), "Result count mismatch"
        assert abs(results[0]['distance'] - results2[0]['distance']) < 1e-6, \
            "Distance mismatch"
        print(" ✓ Search results match after load")

        # Test dimension mismatch: adding a 256-d vector to this index is
        # expected to raise ValueError.
        print("\n Testing dimension validation...")
        try:
            wrong_embedding = np.random.randn(1, 256).astype('float32')
            index.add(wrong_embedding, [{"id": "wrong"}])
        except ValueError as e:
            print(f" ✓ Correctly rejected wrong dimension: {e}")
        else:
            print(" ❌ Should have raised ValueError for dimension mismatch")
            return False

        print("\n ✅ FAISSIndex: ALL TESTS PASSED")
        return True

    except Exception as e:
        _report_failure("FAISSIndex", e)
        return False


def test_embedding_manager():
    """Test 3: EmbeddingManager.

    Exercises single embedding, cache hit, cache eviction, batch embedding
    and fallback when an unknown model name is requested.

    Returns:
        True if every check passed, False otherwise.
    """
    _print_banner("TEST 3: EmbeddingManager")

    try:
        # Initialize
        print(" Initializing EmbeddingManager...")
        cache_size = 10
        manager = EmbeddingManager(model_name="clip", cache_size=cache_size)
        print(f" ✓ Manager initialized: {manager.get_model_name()}")

        # Test single embedding
        print("\n Testing single embedding...")
        img = create_test_image(color=(255, 128, 0))
        emb1 = manager.embed(img)
        print(f" ✓ Embedding generated: shape={emb1.shape}")

        # Test cache hit
        print("\n Testing cache...")
        emb2 = manager.embed(img)  # Should hit cache
        assert np.array_equal(emb1, emb2), "Cache should return identical embedding"

        stats = manager.get_stats()
        print(f" ✓ Cache hits: {stats['cache_hits']}")
        print(f" ✓ Cache misses: {stats['cache_misses']}")
        print(f" ✓ Hit rate: {stats['cache_hit_rate']:.1%}")

        # Test cache eviction
        print("\n Testing cache eviction...")
        for i in range(15):  # More distinct images than cache_size
            manager.embed(create_test_image(color=(i*10, i*10, i*10)))

        stats = manager.get_stats()
        print(f" ✓ Cache size: {stats['cache_size']} (max: {stats['cache_capacity']})")
        assert stats['cache_size'] <= cache_size, "Cache should not exceed max size"

        # Test batch embedding
        print("\n Testing batch embedding...")
        images = [create_test_image(color=(i, i, i)) for i in range(5)]
        embeddings = manager.embed_batch(images)
        print(f" ✓ Batch embeddings: shape={embeddings.shape}")

        # Test fallback (simulate)
        print("\n Testing fallback...")
        manager_fallback = EmbeddingManager(
            model_name="nonexistent_model",
            fallback_enabled=True,
        )
        print(f" ✓ Fallback to: {manager_fallback.get_model_name()}")
        assert "clip" in manager_fallback.get_model_name().lower(), \
            "Should fallback to CLIP"

        print("\n ✅ EmbeddingManager: ALL TESTS PASSED")
        return True

    except Exception as e:
        _report_failure("EmbeddingManager", e)
        return False


def test_integration():
    """Test 4: full integration.

    Embeds five solid-colour images through an EmbeddingManager, indexes
    them in a FAISSIndex, checks that a red query retrieves the red entry
    first, and checks that the top result survives a save/load round-trip.

    Returns:
        True if every check passed, False otherwise.
    """
    _print_banner("TEST 4: Intégration Complète")

    try:
        # Setup
        print(" Setting up integrated system...")
        manager = EmbeddingManager(model_name="clip")
        index = FAISSIndex(manager.get_dimension())
        print(f" ✓ Manager: {manager.get_model_name()}")
        print(f" ✓ Index dimension: {index.dimension}")

        # Create and index some images
        print("\n Creating and indexing images...")
        labelled_colors = [
            ("red", (255, 0, 0)),
            ("green", (0, 255, 0)),
            ("blue", (0, 0, 255)),
            ("yellow", (255, 255, 0)),
            ("magenta", (255, 0, 255)),
        ]

        for label, color in labelled_colors:
            emb = manager.embed(create_test_image(color=color))
            # The index is fed a 2-D (1, dim) array, one row per embedding.
            index.add(emb.reshape(1, -1), [{"label": label}])

        print(f" ✓ Indexed {len(index)} images")

        # Search for similar image
        print("\n Testing similarity search...")
        query_img = create_test_image(color=(255, 0, 0))  # Red
        query_emb = manager.embed(query_img)

        results = index.search(query_emb, k=3)
        print(f" ✓ Found {len(results)} similar images:")
        for rank, result in enumerate(results, start=1):
            print(f" {rank}. {result['metadata']['label']}: "
                  f"similarity={result['similarity']:.4f}")

        # Verify red is most similar to red
        assert results[0]['metadata']['label'] == 'red', \
            "Most similar should be red"
        print(" ✓ Correct match found")

        # Test persistence
        print("\n Testing persistence...")
        test_path = "data/test_integrated_index"
        # Create the directory here so this test does not depend on
        # test_faiss_index having run (and succeeded) first.
        _ensure_data_dir()
        index.save(test_path)

        index2 = FAISSIndex(manager.get_dimension())
        index2.load(test_path)

        results2 = index2.search(query_emb, k=3)
        assert results[0]['metadata']['label'] == results2[0]['metadata']['label'], \
            "Results should match after load"
        print(" ✓ Persistence verified")

        print("\n ✅ Integration: ALL TESTS PASSED")
        return True

    except Exception as e:
        _report_failure("Integration", e)
        return False


def main():
    """Run all tests and print a summary.

    Returns:
        0 if every test group passed, 1 otherwise (used as the process
        exit code).
    """
    _print_banner("TESTING NEW EMBEDDING SYSTEM")

    # Dict literal values are evaluated in order, so the groups run in the
    # order listed.
    results = {
        "CLIPEmbedder": test_clip_embedder(),
        "FAISSIndex": test_faiss_index(),
        "EmbeddingManager": test_embedding_manager(),
        "Integration": test_integration(),
    }

    # Summary
    _print_banner("TEST SUMMARY")

    for test_name, passed in results.items():
        status = "✅ PASSED" if passed else "❌ FAILED"
        print(f" {test_name}: {status}")

    if all(results.values()):
        print("\n🎉 ALL TESTS PASSED! System is ready.")
        return 0

    print("\n❌ SOME TESTS FAILED. Please review errors above.")
    return 1


if __name__ == "__main__":
    sys.exit(main())