#!/usr/bin/env python3
"""
End-to-end test of the embedding system with fine-tuning.

This script exercises the complete workflow:

1. Embedding generation
2. FAISS indexing
3. Similarity search
4. Example collection
5. Automatic fine-tuning
6. Precision improvement
"""
import sys
import time
from pathlib import Path

from PIL import Image, ImageDraw, ImageFont

# Make the local package importable when this script is run directly.
sys.path.insert(0, str(Path(__file__).parent))

from geniusia2.core.embedders import (
    CLIPEmbedder,
    EmbeddingManager,
    FAISSIndex,
    LightweightFineTuner,
)
def create_button_image(text, color, size=(200, 100)):
    """Create a synthetic button image for testing.

    Args:
        text: Label drawn centered on the button.
        color: RGB tuple used as the button fill color.
        size: (width, height) of the generated image, in pixels.

    Returns:
        A PIL ``Image`` (RGB) containing the rendered button.
    """
    img = Image.new('RGB', size, color=(240, 240, 240))
    draw = ImageDraw.Draw(img)

    # Button rectangle, inset 20 px from each image edge.
    button_rect = [20, 20, size[0] - 20, size[1] - 20]
    draw.rectangle(button_rect, fill=color, outline=(50, 50, 50), width=2)

    # Label text. ImageFont.truetype raises OSError when the font file is
    # missing; fall back to Pillow's built-in bitmap font in that case.
    # (Was a bare `except:`, which also swallowed KeyboardInterrupt etc.)
    try:
        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 16)
    except OSError:
        font = ImageFont.load_default()

    # Center the text inside the button.
    text_bbox = draw.textbbox((0, 0), text, font=font)
    text_width = text_bbox[2] - text_bbox[0]
    text_height = text_bbox[3] - text_bbox[1]
    text_pos = (size[0] // 2 - text_width // 2, size[1] // 2 - text_height // 2)
    draw.text(text_pos, text, fill=(255, 255, 255), font=font)

    return img
def main():
    """Run the end-to-end embedding + fine-tuning test.

    Returns:
        0 on completion (used as the process exit code).
    """
    print("\n" + "="*70)
    print("TEST COMPLET DU SYSTÈME D'EMBEDDINGS AVEC FINE-TUNING")
    print("="*70)

    # 1. Setup
    print("\n1. Initialisation du système...")
    manager = EmbeddingManager(model_name="clip", cache_size=100)
    index = FAISSIndex(manager.get_dimension())
    fine_tuner = LightweightFineTuner(
        embedder=manager.embedder,
        trigger_threshold=5,  # low threshold so the test triggers quickly
        max_examples=100
    )
    print(f"   ✓ Manager: {manager.get_model_name()}")
    print(f"   ✓ Index dimension: {index.dimension}")
    print(f"   ✓ Fine-tuner: trigger={fine_tuner.trigger_threshold}")

    # 2. Build simulated workflows (two intents, three visual variants each)
    print("\n2. Création de workflows simulés...")
    workflows = {
        'submit': [
            create_button_image("Submit", (100, 150, 255)),
            create_button_image("Send", (100, 150, 255)),
            create_button_image("OK", (100, 150, 255)),
        ],
        'cancel': [
            create_button_image("Cancel", (255, 100, 100)),
            create_button_image("Close", (255, 100, 100)),
            create_button_image("Abort", (255, 100, 100)),
        ]
    }
    print(f"   ✓ Créé {len(workflows)} workflows avec 3 variantes chacun")

    # 3. Index every workflow variant with its metadata
    print("\n3. Indexation des workflows...")
    for workflow_id, images in workflows.items():
        for i, img in enumerate(images):
            emb = manager.embed(img)
            index.add(emb.reshape(1, -1), [{
                'workflow_id': workflow_id,
                'variant': i,
                'label': f"{workflow_id}_{i}"
            }])
    print(f"   ✓ Indexé {len(index)} images")

    # 4. Similarity search BEFORE fine-tuning
    print("\n4. Test de recherche AVANT fine-tuning...")
    test_submit = create_button_image("Submit", (100, 150, 255))
    # NOTE(review): test_cancel is created but never queried below — kept for
    # parity with test_submit; a 'cancel' query could be added symmetrically.
    test_cancel = create_button_image("Cancel", (255, 100, 100))

    emb_submit = manager.embed(test_submit)
    results_submit = index.search(emb_submit, k=3)

    print("   Recherche pour 'Submit':")
    for i, r in enumerate(results_submit):
        print(f"   {i+1}. {r['metadata']['label']}: "
              f"similarity={r['similarity']:.4f}")

    # Accuracy = fraction of top-k results belonging to the 'submit' workflow.
    correct_before = sum(1 for r in results_submit
                         if r['metadata']['workflow_id'] == 'submit')
    accuracy_before = correct_before / len(results_submit)
    print(f"   Précision AVANT: {accuracy_before:.1%} ({correct_before}/3)")

    # 5. Simulate user feedback
    print("\n5. Simulation d'interactions utilisateur...")
    print("   Scénario: L'utilisateur accepte 'submit', rejette 'cancel'")

    # Accept the 'submit' workflows (positive examples)
    for img in workflows['submit']:
        fine_tuner.add_positive_example(img, 'submit')
        print(f"   ✓ Exemple positif ajouté (submit)")

    # Reject the 'cancel' workflows (negative examples)
    for img in workflows['cancel']:
        fine_tuner.add_negative_example(img, 'cancel')
        print(f"   ✓ Exemple négatif ajouté (cancel)")

    stats = fine_tuner.get_stats()
    print(f"   Total exemples: {stats['total_examples']} "
          f"(+{stats['positive_examples']}, -{stats['negative_examples']})")

    # 6. Wait for any in-flight fine-tuning run to finish
    print("\n6. Fine-tuning en cours...")
    if stats['is_training']:
        print("   ⏳ Attente de la fin du fine-tuning...")
        fine_tuner.wait_for_training(timeout=120)
        print("   ✓ Fine-tuning terminé")

    # Report the latest training metrics, if a run has happened
    if fine_tuner.metrics_history:
        metrics = fine_tuner.metrics_history[-1]
        print(f"   Métriques:")
        # BUG FIX: the original formatted metrics.get('loss', 'N/A') with
        # :.4f, which raises ValueError when 'loss' is missing (str cannot
        # take a float precision spec). Format only numeric values.
        loss = metrics.get('loss')
        loss_str = f"{loss:.4f}" if isinstance(loss, (int, float)) else "N/A"
        print(f"   - Loss: {loss_str}")
        print(f"   - Durée: {metrics.get('duration_seconds', 0):.1f}s")
        print(f"   - Exemples: +{metrics.get('positive_count', 0)}, "
              f"-{metrics.get('negative_count', 0)}")
    else:
        print("   ⚠️ Fine-tuning pas encore déclenché")
        print(f"   (Besoin de {fine_tuner.trigger_threshold} exemples, "
              f"actuellement {stats['total_examples']})")

    # 7. Similarity search AFTER fine-tuning
    print("\n7. Test de recherche APRÈS fine-tuning...")

    # Clear the embedding cache so queries go through the fine-tuned model.
    manager._cache.clear()

    emb_submit_after = manager.embed(test_submit)
    results_submit_after = index.search(emb_submit_after, k=3)

    print("   Recherche pour 'Submit':")
    for i, r in enumerate(results_submit_after):
        print(f"   {i+1}. {r['metadata']['label']}: "
              f"similarity={r['similarity']:.4f}")

    # Measure the post-fine-tuning accuracy the same way as before
    correct_after = sum(1 for r in results_submit_after
                        if r['metadata']['workflow_id'] == 'submit')
    accuracy_after = correct_after / len(results_submit_after)
    print(f"   Précision APRÈS: {accuracy_after:.1%} ({correct_after}/3)")

    # 8. Summary
    print("\n" + "="*70)
    print("RÉSUMÉ")
    print("="*70)

    print(f"\nPrécision:")
    print(f"   Avant fine-tuning: {accuracy_before:.1%}")
    print(f"   Après fine-tuning: {accuracy_after:.1%}")

    if accuracy_after > accuracy_before:
        improvement = (accuracy_after - accuracy_before) * 100
        print(f"   Amélioration: +{improvement:.1f}%")
        print("\n✅ Le fine-tuning a amélioré la précision!")
    elif accuracy_after == accuracy_before:
        print(f"   Amélioration: 0%")
        print("\n⚠️ Pas d'amélioration (peut-être besoin de plus d'exemples)")
    else:
        print("\n⚠️ Précision dégradée (rare, peut nécessiter plus d'epochs)")

    print(f"\nStatistiques du cache:")
    cache_stats = manager.get_stats()
    print(f"   Hit rate: {cache_stats['cache_hit_rate']:.1%}")
    print(f"   Taille: {cache_stats['cache_size']}/{cache_stats['cache_capacity']}")

    print(f"\nStatistiques du fine-tuning:")
    ft_stats = fine_tuner.get_stats()
    print(f"   Trainings: {ft_stats['training_count']}")
    print(f"   Exemples collectés: {ft_stats['total_examples']}")

    print("\n✅ Test complet terminé!")
    return 0
# Script entry point: propagate main()'s return value as the exit code.
if __name__ == "__main__":
    sys.exit(main())