319 lines
11 KiB
Python
Executable File
319 lines
11 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Test du nouveau système d'embeddings.
|
|
|
|
Ce script teste:
|
|
1. CLIPEmbedder - Génération d'embeddings
|
|
2. FAISSIndex - Ajout, recherche, save/load
|
|
3. EmbeddingManager - Cache, fallback
|
|
4. Intégration complète
|
|
"""
|
|
|
|
import sys
|
|
import numpy as np
|
|
from PIL import Image
|
|
from pathlib import Path
|
|
|
|
# Add project to path
|
|
sys.path.insert(0, str(Path(__file__).parent))
|
|
|
|
from geniusia2.core.embedders import CLIPEmbedder, FAISSIndex, EmbeddingManager
|
|
|
|
|
|
def create_test_image(size=(224, 224), color=(255, 0, 0)):
    """Build a solid-colour RGB image to feed the embedders under test.

    Args:
        size: (width, height) of the image in pixels.
        color: RGB fill colour as a 3-tuple.

    Returns:
        A new PIL image in ``'RGB'`` mode uniformly filled with ``color``.
    """
    mode = 'RGB'
    return Image.new(mode, size, color)


def test_clip_embedder():
    """Test 1: CLIPEmbedder.

    Exercises single-image embedding, batch embedding and determinism of a
    CPU ``CLIPEmbedder``. Every printed expectation is backed by an assertion,
    so a wrong shape, an unnormalised vector or a non-deterministic result
    makes the test report FAILED instead of PASSED.

    Returns:
        True if every check passed, False otherwise (the error and its
        traceback are printed).
    """
    print("\n" + "="*60)
    print("TEST 1: CLIPEmbedder")
    print("="*60)

    try:
        # Initialize
        print(" Initializing CLIPEmbedder...")
        embedder = CLIPEmbedder(device='cpu')
        dimension = embedder.get_dimension()
        print(f" ✓ Model loaded: {embedder.get_model_name()}")
        print(f" ✓ Dimension: {dimension}")

        # Test single embedding: right size and unit L2 norm
        print("\n Testing single embedding...")
        img = create_test_image(color=(255, 0, 0))
        embedding = embedder.embed(img)
        norm = float(np.linalg.norm(embedding))
        print(f" ✓ Embedding shape: {embedding.shape}")
        print(f" ✓ Embedding norm: {norm:.4f} (should be ~1.0)")
        assert embedding.shape[-1] == dimension, (
            f"Embedding size {embedding.shape[-1]} != model dimension {dimension}"
        )
        assert abs(norm - 1.0) < 1e-3, f"Embedding is not normalized (norm={norm})"

        # Test batch embedding: one row per input image
        print("\n Testing batch embedding...")
        images = [
            create_test_image(color=(255, 0, 0)),
            create_test_image(color=(0, 255, 0)),
            create_test_image(color=(0, 0, 255)),
        ]
        embeddings = embedder.embed_batch(images)
        print(f" ✓ Batch embeddings shape: {embeddings.shape}")
        print(f" ✓ Expected: (3, {dimension})")
        assert embeddings.shape == (len(images), dimension), (
            f"Batch shape {embeddings.shape} != {(len(images), dimension)}"
        )

        # Test consistency: embedding the same image twice must be deterministic
        print("\n Testing consistency...")
        img1 = create_test_image(color=(128, 128, 128))
        emb1 = embedder.embed(img1)
        emb2 = embedder.embed(img1)
        diff = float(np.abs(emb1 - emb2).max())
        print(f" ✓ Max difference: {diff:.10f} (should be 0.0)")
        # Small tolerance rather than exact equality to avoid float flakiness
        assert diff < 1e-6, f"Embeddings are not deterministic (max diff={diff})"

        print("\n ✅ CLIPEmbedder: ALL TESTS PASSED")
        return True

    except Exception as e:
        print(f"\n ❌ CLIPEmbedder: FAILED - {e}")
        import traceback
        traceback.print_exc()
        return False


def test_faiss_index():
    """Test 2: FAISSIndex.

    Checks add/len, nearest-neighbour search, the save/load round trip and
    rejection of vectors with the wrong dimension. Random vectors come from a
    seeded generator so runs are reproducible, and the save/load files are
    written to a temporary directory that is deleted afterwards (nothing is
    left behind in the working directory).

    Returns:
        True if every check passed, False otherwise (the error and its
        traceback are printed).
    """
    import tempfile

    print("\n" + "="*60)
    print("TEST 2: FAISSIndex")
    print("="*60)

    try:
        # Initialize
        print(" Initializing FAISSIndex...")
        dimension = 512
        n_items = 5
        k = 3
        rng = np.random.default_rng(0)  # seeded: deterministic test data
        index = FAISSIndex(dimension)
        print(f" ✓ Index created with dimension={dimension}")

        # Test add
        print("\n Testing add...")
        embeddings = rng.standard_normal((n_items, dimension)).astype('float32')
        # Normalize each row to unit length, like real embeddings
        embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
        metadata = [{"id": i, "label": f"item_{i}"} for i in range(n_items)]
        index.add(embeddings, metadata)
        print(f" ✓ Added {n_items} embeddings")
        print(f" ✓ Index size: {len(index)}")
        assert len(index) == n_items, "Index size should equal number of added vectors"

        # Test search: a stored vector must be its own best match
        print("\n Testing search...")
        query = embeddings[0]
        results = index.search(query, k=k)
        print(f" ✓ Found {len(results)} results")
        print(f" ✓ Best match distance: {results[0]['distance']:.6f}")
        print(f" ✓ Best match metadata: {results[0]['metadata']}")
        assert len(results) == k, f"Expected {k} results, got {len(results)}"
        assert results[0]['metadata']['id'] == 0, "Query vector should be its own best match"

        # Test save/load in a throw-away directory
        print("\n Testing save/load...")
        with tempfile.TemporaryDirectory() as tmp_dir:
            test_path = str(Path(tmp_dir) / "test_faiss_index")
            index.save(test_path)
            print(f" ✓ Index saved to {test_path}")

            # Load in new index (searched while the files still exist)
            index2 = FAISSIndex(dimension)
            index2.load(test_path)
            print(f" ✓ Index loaded")
            print(f" ✓ Loaded size: {len(index2)}")
            results2 = index2.search(query, k=k)

        # Verify the reloaded index behaves identically
        assert len(index2) == len(index), "Size mismatch after load"
        assert len(results) == len(results2), "Result count mismatch"
        assert abs(results[0]['distance'] - results2[0]['distance']) < 1e-6, "Distance mismatch"
        assert results[0]['metadata']['id'] == results2[0]['metadata']['id'], \
            "Best-match metadata mismatch after load"
        print(f" ✓ Search results match after load")

        # Test dimension mismatch: add() must reject vectors of the wrong size
        print("\n Testing dimension validation...")
        wrong_embedding = rng.standard_normal((1, dimension // 2)).astype('float32')
        try:
            index.add(wrong_embedding, [{"id": "wrong"}])
        except ValueError as e:
            print(f" ✓ Correctly rejected wrong dimension: {e}")
        else:
            print(f" ❌ Should have raised ValueError for dimension mismatch")
            return False

        print("\n ✅ FAISSIndex: ALL TESTS PASSED")
        return True

    except Exception as e:
        print(f"\n ❌ FAISSIndex: FAILED - {e}")
        import traceback
        traceback.print_exc()
        return False


def test_embedding_manager():
    """Test 3: EmbeddingManager.

    Checks single and batch embedding, the embedding cache (a repeated image
    must be served from cache and the cache must never hold more than
    ``cache_size`` entries) and fallback to CLIP when an unknown model name is
    requested with ``fallback_enabled=True``.

    Returns:
        True if every check passed, False otherwise (the error and its
        traceback are printed).
    """
    print("\n" + "="*60)
    print("TEST 3: EmbeddingManager")
    print("="*60)

    # Single source of truth for the cache capacity used below
    cache_size = 10

    try:
        # Initialize
        print(" Initializing EmbeddingManager...")
        manager = EmbeddingManager(model_name="clip", cache_size=cache_size)
        print(f" ✓ Manager initialized: {manager.get_model_name()}")

        # Test single embedding
        print("\n Testing single embedding...")
        img = create_test_image(color=(255, 128, 0))
        emb1 = manager.embed(img)
        print(f" ✓ Embedding generated: shape={emb1.shape}")

        # Test cache hit: embedding the same image again should come from cache
        print("\n Testing cache...")
        emb2 = manager.embed(img)
        assert np.array_equal(emb1, emb2), "Cache should return identical embedding"
        stats = manager.get_stats()
        print(f" ✓ Cache hits: {stats['cache_hits']}")
        print(f" ✓ Cache misses: {stats['cache_misses']}")
        print(f" ✓ Hit rate: {stats['cache_hit_rate']:.1%}")
        assert stats['cache_hits'] >= 1, "Repeated image should produce a cache hit"

        # Test cache eviction: embed more distinct images than the cache can hold
        print("\n Testing cache eviction...")
        for i in range(cache_size + 5):
            manager.embed(create_test_image(color=(i * 10, i * 10, i * 10)))

        stats = manager.get_stats()
        print(f" ✓ Cache size: {stats['cache_size']} (max: {stats['cache_capacity']})")
        assert stats['cache_size'] <= cache_size, "Cache should not exceed max size"

        # Test batch embedding: one embedding per input image
        print("\n Testing batch embedding...")
        images = [create_test_image(color=(i, i, i)) for i in range(5)]
        embeddings = manager.embed_batch(images)
        print(f" ✓ Batch embeddings: shape={embeddings.shape}")
        assert embeddings.shape[0] == len(images), "Expected one embedding per image"

        # Test fallback: an unknown model name should fall back to CLIP
        print("\n Testing fallback...")
        manager_fallback = EmbeddingManager(
            model_name="nonexistent_model",
            fallback_enabled=True,
        )
        print(f" ✓ Fallback to: {manager_fallback.get_model_name()}")
        assert "clip" in manager_fallback.get_model_name().lower(), "Should fallback to CLIP"

        print("\n ✅ EmbeddingManager: ALL TESTS PASSED")
        return True

    except Exception as e:
        print(f"\n ❌ EmbeddingManager: FAILED - {e}")
        import traceback
        traceback.print_exc()
        return False


def test_integration():
    """Test 4: full integration (EmbeddingManager + FAISSIndex).

    Embeds solid-colour images with ``EmbeddingManager``, indexes them in a
    ``FAISSIndex`` and checks two things. First, a red query retrieves the red
    image first. Second, results are unchanged after a save/load round trip.
    The round trip uses a temporary directory, so this test does not depend on
    a ``data/`` directory having been created by another test.

    Returns:
        True if every check passed, False otherwise (the error and its
        traceback are printed).
    """
    import tempfile

    print("\n" + "="*60)
    print("TEST 4: Intégration Complète")
    print("="*60)

    try:
        # Setup
        print(" Setting up integrated system...")
        manager = EmbeddingManager(model_name="clip")
        index = FAISSIndex(manager.get_dimension())
        print(f" ✓ Manager: {manager.get_model_name()}")
        print(f" ✓ Index dimension: {index.dimension}")

        # Create and index some images. Label and colour live together so the
        # two can never drift out of sync.
        print("\n Creating and indexing images...")
        palette = [
            ("red", (255, 0, 0)),
            ("green", (0, 255, 0)),
            ("blue", (0, 0, 255)),
            ("yellow", (255, 255, 0)),
            ("magenta", (255, 0, 255)),
        ]
        for label, color in palette:
            emb = manager.embed(create_test_image(color=color))
            # Reshape to a single-row (1, dim) batch for add()
            index.add(emb.reshape(1, -1), [{"label": label}])

        print(f" ✓ Indexed {len(index)} images")
        assert len(index) == len(palette), "Every image should be indexed"

        # Search for similar image: a red query must match the indexed red first
        print("\n Testing similarity search...")
        query_emb = manager.embed(create_test_image(color=(255, 0, 0)))  # Red
        results = index.search(query_emb, k=3)

        print(f" ✓ Found {len(results)} similar images:")
        for rank, result in enumerate(results, start=1):
            print(f" {rank}. {result['metadata']['label']}: "
                  f"similarity={result['similarity']:.4f}")

        assert results, "Search returned no results"
        assert results[0]['metadata']['label'] == 'red', "Most similar should be red"
        print(f" ✓ Correct match found")

        # Test persistence in a throw-away directory
        print("\n Testing persistence...")
        with tempfile.TemporaryDirectory() as tmp_dir:
            test_path = str(Path(tmp_dir) / "test_integrated_index")
            index.save(test_path)

            index2 = FAISSIndex(manager.get_dimension())
            index2.load(test_path)
            results2 = index2.search(query_emb, k=3)

        assert results2, "Reloaded index returned no results"
        assert results[0]['metadata']['label'] == results2[0]['metadata']['label'], \
            "Results should match after load"
        print(f" ✓ Persistence verified")

        print("\n ✅ Integration: ALL TESTS PASSED")
        return True

    except Exception as e:
        print(f"\n ❌ Integration: FAILED - {e}")
        import traceback
        traceback.print_exc()
        return False


def main():
    """Run every embedding-system test and print a pass/fail summary.

    All four tests always run, in the order CLIPEmbedder, FAISSIndex,
    EmbeddingManager, Integration; one failing test does not stop the others.

    Returns:
        Process exit code: 0 when every test passed, 1 otherwise.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("TESTING NEW EMBEDDING SYSTEM")
    print(banner)

    suite = (
        ("CLIPEmbedder", test_clip_embedder),
        ("FAISSIndex", test_faiss_index),
        ("EmbeddingManager", test_embedding_manager),
        ("Integration", test_integration),
    )
    outcomes = {name: run_test() for name, run_test in suite}

    # Summary
    print("\n" + banner)
    print("TEST SUMMARY")
    print(banner)

    for name, ok in outcomes.items():
        verdict = "✅ PASSED" if ok else "❌ FAILED"
        print(f" {name}: {verdict}")

    if not all(outcomes.values()):
        print("\n❌ SOME TESTS FAILED. Please review errors above.")
        return 1

    print("\n🎉 ALL TESTS PASSED! System is ready.")
    return 0


if __name__ == "__main__":
    # Propagate main()'s return code (0 = all passed, 1 = failures) to the shell
    raise SystemExit(main())