Initial commit
This commit is contained in:
210
test_complete_embedding_system.py
Executable file
210
test_complete_embedding_system.py
Executable file
@@ -0,0 +1,210 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test complet du système d'embeddings avec fine-tuning.
|
||||
|
||||
Ce script teste le workflow complet:
|
||||
1. Génération d'embeddings
|
||||
2. Indexation FAISS
|
||||
3. Recherche de similarité
|
||||
4. Collection d'exemples
|
||||
5. Fine-tuning automatique
|
||||
6. Amélioration de la précision
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from geniusia2.core.embedders import (
|
||||
CLIPEmbedder,
|
||||
EmbeddingManager,
|
||||
FAISSIndex,
|
||||
LightweightFineTuner
|
||||
)
|
||||
|
||||
|
||||
def create_button_image(text, color, size=(200, 100)):
    """Render a synthetic button image used as test input.

    A filled rectangle inset 20 px from every edge is drawn on a light-grey
    canvas, and ``text`` is written in white, centred on the canvas using the
    measured text bounding box.

    Args:
        text: Label drawn on the button.
        color: Fill colour of the button rectangle (callers in this file
            pass RGB tuples).
        size: ``(width, height)`` of the canvas in pixels.

    Returns:
        The rendered RGB ``PIL.Image.Image``.
    """
    img = Image.new('RGB', size, color=(240, 240, 240))
    draw = ImageDraw.Draw(img)

    # Button body
    button_rect = [20, 20, size[0] - 20, size[1] - 20]
    draw.rectangle(button_rect, fill=color, outline=(50, 50, 50), width=2)

    # Label font: prefer DejaVu Sans, fall back to Pillow's built-in font.
    # Only the failures Pillow raises for an unreadable font file (OSError)
    # or missing FreeType support (ImportError) are caught, so interrupts and
    # genuine programming errors are no longer silently swallowed.
    try:
        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 16
        )
    except (OSError, ImportError):
        font = ImageFont.load_default()

    # Centre the label on the canvas from its measured bounding box.
    text_bbox = draw.textbbox((0, 0), text, font=font)
    text_width = text_bbox[2] - text_bbox[0]
    text_height = text_bbox[3] - text_bbox[1]
    text_pos = (size[0] // 2 - text_width // 2,
                size[1] // 2 - text_height // 2)
    draw.text(text_pos, text, fill=(255, 255, 255), font=font)

    return img
|
||||
|
||||
|
||||
def _print_hits(results):
    """Print each search hit as a ranked line with its label and similarity."""
    for rank, hit in enumerate(results, start=1):
        print(f" {rank}. {hit['metadata']['label']}: "
              f"similarity={hit['similarity']:.4f}")


def _workflow_accuracy(results, workflow_id):
    """Return ``(correct, total, accuracy)`` for hits tagged ``workflow_id``.

    An empty result list yields an accuracy of ``0.0`` instead of raising
    ``ZeroDivisionError``.
    """
    total = len(results)
    correct = sum(1 for hit in results
                  if hit['metadata']['workflow_id'] == workflow_id)
    accuracy = correct / total if total else 0.0
    return correct, total, accuracy


def main():
    """Run the end-to-end embedding + fine-tuning scenario and print a report.

    Steps: build the manager/index/fine-tuner, index two synthetic workflows
    ('submit' and 'cancel'), search for a 'Submit' button, feed positive and
    negative examples to the fine-tuner, wait for training if it is running,
    search again and compare the top-k precision before and after.

    Returns:
        ``0``, used by the script entry point as the process exit status.
    """
    print("\n" + "=" * 70)
    print("TEST COMPLET DU SYSTÈME D'EMBEDDINGS AVEC FINE-TUNING")
    print("=" * 70)

    # 1. Setup
    print("\n1. Initialisation du système...")
    manager = EmbeddingManager(model_name="clip", cache_size=100)
    index = FAISSIndex(manager.get_dimension())
    fine_tuner = LightweightFineTuner(
        embedder=manager.embedder,
        trigger_threshold=5,  # low threshold so the test triggers quickly
        max_examples=100,
    )
    print(f" ✓ Manager: {manager.get_model_name()}")
    print(f" ✓ Index dimension: {index.dimension}")
    print(f" ✓ Fine-tuner: trigger={fine_tuner.trigger_threshold}")

    # 2. Build simulated workflows
    print("\n2. Création de workflows simulés...")
    workflows = {
        'submit': [
            create_button_image("Submit", (100, 150, 255)),
            create_button_image("Send", (100, 150, 255)),
            create_button_image("OK", (100, 150, 255)),
        ],
        'cancel': [
            create_button_image("Cancel", (255, 100, 100)),
            create_button_image("Close", (255, 100, 100)),
            create_button_image("Abort", (255, 100, 100)),
        ],
    }
    print(f" ✓ Créé {len(workflows)} workflows avec 3 variantes chacun")

    # 3. Index every workflow variant with its metadata
    print("\n3. Indexation des workflows...")
    for workflow_id, images in workflows.items():
        for i, img in enumerate(images):
            emb = manager.embed(img)
            index.add(emb.reshape(1, -1), [{
                'workflow_id': workflow_id,
                'variant': i,
                'label': f"{workflow_id}_{i}",
            }])
    print(f" ✓ Indexé {len(index)} images")

    # 4. Search BEFORE fine-tuning
    print("\n4. Test de recherche AVANT fine-tuning...")
    test_submit = create_button_image("Submit", (100, 150, 255))

    emb_submit = manager.embed(test_submit)
    results_submit = index.search(emb_submit, k=3)

    print(" Recherche pour 'Submit':")
    _print_hits(results_submit)

    # Precision = share of returned hits that belong to the 'submit' workflow
    correct_before, total_before, accuracy_before = _workflow_accuracy(
        results_submit, 'submit'
    )
    print(f" Précision AVANT: {accuracy_before:.1%} "
          f"({correct_before}/{total_before})")

    # 5. Simulate user feedback
    print("\n5. Simulation d'interactions utilisateur...")
    print(" Scénario: L'utilisateur accepte 'submit', rejette 'cancel'")

    # Accepted 'submit' workflows become positive examples
    for img in workflows['submit']:
        fine_tuner.add_positive_example(img, 'submit')
        print(" ✓ Exemple positif ajouté (submit)")

    # Rejected 'cancel' workflows become negative examples
    for img in workflows['cancel']:
        fine_tuner.add_negative_example(img, 'cancel')
        print(" ✓ Exemple négatif ajouté (cancel)")

    stats = fine_tuner.get_stats()
    print(f" Total exemples: {stats['total_examples']} "
          f"(+{stats['positive_examples']}, -{stats['negative_examples']})")

    # 6. Wait for fine-tuning
    print("\n6. Fine-tuning en cours...")
    if stats['is_training']:
        print(" ⏳ Attente de la fin du fine-tuning...")
        fine_tuner.wait_for_training(timeout=120)
        print(" ✓ Fine-tuning terminé")

        # Show the metrics of the most recent training run
        if fine_tuner.metrics_history:
            metrics = fine_tuner.metrics_history[-1]
            # A missing loss must not be pushed through the ':.4f' format:
            # formatting the 'N/A' placeholder as a float raises ValueError.
            loss = metrics.get('loss')
            loss_text = "N/A" if loss is None else f"{loss:.4f}"
            print(" Métriques:")
            print(f" - Loss: {loss_text}")
            print(f" - Durée: {metrics.get('duration_seconds', 0):.1f}s")
            print(f" - Exemples: +{metrics.get('positive_count', 0)}, "
                  f"-{metrics.get('negative_count', 0)}")
    else:
        print(" ⚠️ Fine-tuning pas encore déclenché")
        print(f" (Besoin de {fine_tuner.trigger_threshold} exemples, "
              f"actuellement {stats['total_examples']})")

    # 7. Search AFTER fine-tuning
    print("\n7. Test de recherche APRÈS fine-tuning...")

    # Drop cached embeddings so the query is re-embedded.
    # NOTE(review): this reaches into the manager's private `_cache`; prefer
    # a public cache-reset API if the manager offers one.
    manager._cache.clear()

    # NOTE(review): the index still holds the embeddings added in step 3;
    # only the query is re-embedded here. Confirm whether the index should
    # be rebuilt after fine-tuning for a like-for-like comparison.
    emb_submit_after = manager.embed(test_submit)
    results_submit_after = index.search(emb_submit_after, k=3)

    print(" Recherche pour 'Submit':")
    _print_hits(results_submit_after)

    # Measure the precision again with the same criterion
    correct_after, total_after, accuracy_after = _workflow_accuracy(
        results_submit_after, 'submit'
    )
    print(f" Précision APRÈS: {accuracy_after:.1%} "
          f"({correct_after}/{total_after})")

    # 8. Summary
    print("\n" + "=" * 70)
    print("RÉSUMÉ")
    print("=" * 70)

    print("\nPrécision:")
    print(f" Avant fine-tuning: {accuracy_before:.1%}")
    print(f" Après fine-tuning: {accuracy_after:.1%}")

    if accuracy_after > accuracy_before:
        improvement = (accuracy_after - accuracy_before) * 100
        print(f" Amélioration: +{improvement:.1f}%")
        print("\n✅ Le fine-tuning a amélioré la précision!")
    elif accuracy_after == accuracy_before:
        print(" Amélioration: 0%")
        print("\n⚠️ Pas d'amélioration (peut-être besoin de plus d'exemples)")
    else:
        print("\n⚠️ Précision dégradée (rare, peut nécessiter plus d'epochs)")

    print("\nStatistiques du cache:")
    cache_stats = manager.get_stats()
    print(f" Hit rate: {cache_stats['cache_hit_rate']:.1%}")
    print(f" Taille: {cache_stats['cache_size']}/{cache_stats['cache_capacity']}")

    print("\nStatistiques du fine-tuning:")
    ft_stats = fine_tuner.get_stats()
    print(f" Trainings: {ft_stats['training_count']}")
    print(f" Exemples collectés: {ft_stats['total_examples']}")

    print("\n✅ Test complet terminé!")
    return 0
|
||||
|
||||
|
||||
# Script entry point: run the scenario and propagate main()'s return value
# as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
Reference in New Issue
Block a user