Initial commit
This commit is contained in:
194
geniusia2/tests/test_embeddings_manager.py
Normal file
194
geniusia2/tests/test_embeddings_manager.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
Tests pour le gestionnaire d'embeddings.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import tempfile
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from geniusia2.core.embeddings_manager import EmbeddingsManager
|
||||
|
||||
|
||||
@pytest.fixture
def temp_index_dir():
    """Yield a temporary directory path for index files, removed after the test.

    Uses try/finally so cleanup also runs when the test fails: pytest
    re-raises the test's exception at the ``yield`` point, which would
    otherwise skip the ``rmtree`` teardown line and leak the directory.
    """
    temp_dir = tempfile.mkdtemp()
    try:
        yield temp_dir
    finally:
        shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
@pytest.fixture
def embeddings_manager(temp_index_dir):
    """Provide an EmbeddingsManager backed by a temporary index directory."""
    manager = EmbeddingsManager(
        model_name="ViT-B-32",
        pretrained="openai",
        index_path=temp_index_dir,
        device="cpu",
    )
    return manager
|
||||
|
||||
|
||||
def test_initialization(embeddings_manager):
    """The manager must come up with a loaded model and a usable index."""
    manager = embeddings_manager
    assert manager.clip_model is not None
    assert manager.faiss_index is not None
    assert manager.embedding_dim > 0
|
||||
|
||||
|
||||
def test_encode_image(embeddings_manager):
    """Encoding an image yields a 1-D float vector of the model's dimension."""
    # Random RGB image in the model's expected input size.
    fake_image = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)

    vector = embeddings_manager.encode_image(fake_image)

    # Shape and dtype checks.
    assert isinstance(vector, np.ndarray)
    assert vector.shape == (embeddings_manager.embedding_dim,)
    assert vector.dtype in (np.float32, np.float64)
|
||||
|
||||
|
||||
def test_add_to_index(embeddings_manager):
    """Adding an embedding returns index 0 and stores its metadata."""
    vector = np.random.randn(embeddings_manager.embedding_dim).astype(np.float32)
    meta = {"task_id": "test_task", "label": "button"}

    position = embeddings_manager.add_to_index(vector, meta)

    assert position == 0
    assert embeddings_manager.faiss_index.ntotal == 1
    assert position in embeddings_manager.metadata_store
|
||||
|
||||
|
||||
def test_search_similar(embeddings_manager):
    """Similarity search returns k results with distance/similarity/metadata."""
    # Populate the index with five random vectors.
    stored = []
    for idx in range(5):
        vector = np.random.randn(embeddings_manager.embedding_dim).astype(np.float32)
        stored.append(vector)
        embeddings_manager.add_to_index(vector, {"id": idx})

    # Query with the first stored vector.
    hits = embeddings_manager.search_similar(stored[0], k=3)

    assert len(hits) == 3
    assert hits[0]["id"] == 0  # the closest match should be the query itself
    for key in ("distance", "similarity", "metadata"):
        assert key in hits[0]
|
||||
|
||||
|
||||
def test_embedding_similarity(embeddings_manager):
    """Cosine similarity is ~1 for identical vectors, ~0.5 for orthogonal ones."""
    dim = embeddings_manager.embedding_dim

    # Identical vectors.
    base = np.random.randn(dim).astype(np.float32)
    twin = base.copy()
    score_identical = embeddings_manager.get_embedding_similarity(base, twin)
    assert 0.99 <= score_identical <= 1.0

    # Orthogonal unit vectors along two different axes.
    axis_a = np.zeros(dim, dtype=np.float32)
    axis_a[0] = 1.0
    axis_b = np.zeros(dim, dtype=np.float32)
    axis_b[1] = 1.0

    score_orthogonal = embeddings_manager.get_embedding_similarity(axis_a, axis_b)
    # The manager maps cosine similarity into [0, 1], so orthogonal ~ 0.5.
    assert 0.4 <= score_orthogonal <= 0.6
|
||||
|
||||
|
||||
def test_save_and_load_index(temp_index_dir):
    """An index saved by one manager is reloaded by a fresh manager instance."""
    # Writer: populate and persist the index.
    writer = EmbeddingsManager(
        model_name="ViT-B-32",
        index_path=temp_index_dir,
        device="cpu",
    )
    vector = np.random.randn(writer.embedding_dim).astype(np.float32)
    writer.add_to_index(vector, {"test": "data"})
    writer.save_index()

    # Reader: a new manager pointed at the same path should load everything.
    reader = EmbeddingsManager(
        model_name="ViT-B-32",
        index_path=temp_index_dir,
        device="cpu",
    )

    assert reader.faiss_index.ntotal == 1
    assert 0 in reader.metadata_store
    assert reader.metadata_store[0]["test"] == "data"
|
||||
|
||||
|
||||
def test_rebuild_index(embeddings_manager):
    """Rebuilding the index keeps the same number of stored embeddings."""
    for idx in range(3):
        vector = np.random.randn(embeddings_manager.embedding_dim).astype(np.float32)
        embeddings_manager.add_to_index(vector, {"id": idx})

    count_before = embeddings_manager.faiss_index.ntotal
    embeddings_manager.rebuild_index()

    assert embeddings_manager.faiss_index.ntotal == count_before
|
||||
|
||||
|
||||
def test_get_stats(embeddings_manager):
    """get_stats exposes the expected keys and the configured model name."""
    stats = embeddings_manager.get_stats()

    for key in ("num_embeddings", "embedding_dim", "model_name", "device"):
        assert key in stats
    assert stats["model_name"] == "ViT-B-32"
|
||||
|
||||
|
||||
def test_clear_index(embeddings_manager):
    """clear_index empties both the FAISS index and the metadata store."""
    for idx in range(3):
        vector = np.random.randn(embeddings_manager.embedding_dim).astype(np.float32)
        embeddings_manager.add_to_index(vector, {"id": idx})

    assert embeddings_manager.faiss_index.ntotal == 3

    embeddings_manager.clear_index()

    assert embeddings_manager.faiss_index.ntotal == 0
    assert len(embeddings_manager.metadata_store) == 0
|
||||
|
||||
|
||||
def test_empty_search(embeddings_manager):
    """Searching an empty index returns an empty result list."""
    probe = np.random.randn(embeddings_manager.embedding_dim).astype(np.float32)
    assert embeddings_manager.search_similar(probe, k=5) == []
|
||||
|
||||
|
||||
# Allow running this test module directly (without invoking pytest by hand).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
528
geniusia2/tests/test_human_logger.py
Normal file
528
geniusia2/tests/test_human_logger.py
Normal file
@@ -0,0 +1,528 @@
|
||||
"""
|
||||
Tests unitaires pour HumanLogger
|
||||
Teste le formatage des messages, les emojis et les messages contextuels.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Ajouter le répertoire parent au path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from geniusia2.gui.human_logger import HumanLogger
|
||||
|
||||
|
||||
def test_initialization():
    """A fresh HumanLogger has every first-time-event flag present and False."""
    print("Test 1: Initialisation...")
    logger = HumanLogger()

    expected_events = (
        "workflow_detected",
        "mode_change",
        "pattern_detected",
        "finetuning_started",
    )
    for event in expected_events:
        # Each known event is registered and starts un-triggered.
        assert event in logger.first_time_events
        assert logger.first_time_events[event] is False

    print("✓ Initialisation OK")
|
||||
|
||||
|
||||
def test_log_observation_format():
    """Observation messages carry the 👀 emoji and echo the window name."""
    print("\nTest 2: Formatage log_observation...")
    logger = HumanLogger()

    message = logger.log_observation("click", "Calculator")
    assert "👀" in message, "L'emoji d'observation doit être présent"
    assert "Calculator" in message, "Le nom de la fenêtre doit être présent"
    assert "J'observe" in message, "Le message doit contenir 'J'observe'"

    # Same shape regardless of the action type or window.
    for action, window in (("type", "Firefox"), ("scroll", "LibreOffice Writer")):
        message = logger.log_observation(action, window)
        assert "👀" in message
        assert window in message

    print("✓ Formatage log_observation OK")
|
||||
|
||||
|
||||
def test_log_pattern_detected_emoji():
    """Pattern-detected messages show 🎯 and the repetition count."""
    print("\nTest 3: Emoji log_pattern_detected...")
    logger = HumanLogger()

    message = logger.log_pattern_detected(3, "Calculer 9/9")
    assert "🎯" in message, "L'emoji de pattern doit être présent"
    assert "3 fois" in message, "Le nombre de répétitions doit être présent"

    print("✓ Emoji log_pattern_detected OK")
|
||||
|
||||
|
||||
def test_log_pattern_detected_contextual():
    """The learning explanation appears only on the first detected pattern."""
    print("\nTest 4: Messages contextuels log_pattern_detected...")
    logger = HumanLogger()

    # First call: context included and the flag flips to True.
    first = logger.log_pattern_detected(3, "Calculer 9/9")
    assert "Je commence à apprendre" in first, "Le message contextuel doit être présent la première fois"
    assert logger.first_time_events["pattern_detected"] is True, "Le flag doit être mis à True"

    # Subsequent calls: no repeated context.
    second = logger.log_pattern_detected(4, "Ouvrir fichier")
    assert "Je commence à apprendre" not in second, "Le message contextuel ne doit pas être répété"

    print("✓ Messages contextuels log_pattern_detected OK")
|
||||
|
||||
|
||||
def test_log_workflow_learned_format():
    """Workflow-learned messages show 📚, the name and observation count."""
    print("\nTest 5: Formatage log_workflow_learned...")
    logger = HumanLogger()

    message = logger.log_workflow_learned("Ouvrir facture", 5)
    assert "📚" in message, "L'emoji de workflow doit être présent"
    assert "Ouvrir facture" in message, "Le nom du workflow doit être présent"
    assert "5 observations" in message, "Le nombre d'observations doit être présent"
    assert "J'apprends" in message, "Le message doit contenir 'J'apprends'"

    print("✓ Formatage log_workflow_learned OK")
|
||||
|
||||
|
||||
def test_log_workflow_learned_contextual():
    """The workflow explanation appears only the first time one is learned."""
    print("\nTest 6: Messages contextuels log_workflow_learned...")
    logger = HumanLogger()

    # First call: explanation present, flag flipped.
    first = logger.log_workflow_learned("Ouvrir facture", 5)
    assert "Un workflow est une séquence" in first, "L'explication doit être présente la première fois"
    assert logger.first_time_events["workflow_detected"] is True

    # Second call: explanation omitted.
    second = logger.log_workflow_learned("Calculer total", 3)
    assert "Un workflow est une séquence" not in second, "L'explication ne doit pas être répétée"

    print("✓ Messages contextuels log_workflow_learned OK")
|
||||
|
||||
|
||||
def test_log_mode_change_format():
    """Every mode transition message shows ✅ and a translated mode label."""
    print("\nTest 7: Formatage log_mode_change...")
    logger = HumanLogger()

    message = logger.log_mode_change("assist", "shadow")
    assert "✅" in message, "L'emoji de succès doit être présent"
    assert "Observation" in message, "Le nom du mode doit être traduit"

    # Remaining transitions, table-driven: (old mode, new mode, label).
    transitions = (
        ("shadow", "assist", "Suggestions"),
        ("assist", "copilot", "Copilote"),
        ("copilot", "auto", "Autonome"),
    )
    for old_mode, new_mode, label in transitions:
        message = logger.log_mode_change(old_mode, new_mode)
        assert "✅" in message
        assert label in message

    print("✓ Formatage log_mode_change OK")
|
||||
|
||||
|
||||
def test_log_mode_change_contextual():
    """The mode explanation appears only on the first mode change."""
    print("\nTest 8: Messages contextuels log_mode_change...")
    logger = HumanLogger()

    # First transition to assist: explanation included, flag flipped.
    first = logger.log_mode_change("shadow", "assist")
    assert "Je vais maintenant vous suggérer" in first, "L'explication du mode assist doit être présente"
    assert logger.first_time_events["mode_change"] is True

    # Any later change: explanation omitted.
    second = logger.log_mode_change("assist", "auto")
    assert "Je vais maintenant vous suggérer" not in second, "L'explication ne doit pas être répétée"

    print("✓ Messages contextuels log_mode_change OK")
|
||||
|
||||
|
||||
def test_log_mode_change_explanations():
    """Each target mode gets its own explanation; shadow gets none.

    A fresh logger is used per case so the first-time context is always shown.
    """
    print("\nTest 9: Explications spécifiques des modes...")

    # (old mode, new mode, expected explanation fragment)
    expectations = (
        ("shadow", "assist", "suggérer des actions"),
        ("assist", "copilot", "exécuter avec votre validation"),
        ("copilot", "auto", "agir de manière autonome"),
    )
    for old_mode, new_mode, fragment in expectations:
        fresh_logger = HumanLogger()
        assert fragment in fresh_logger.log_mode_change(old_mode, new_mode)

    # Shadow mode carries no special explanation.
    shadow_logger = HumanLogger()
    shadow_message = shadow_logger.log_mode_change("assist", "shadow")
    assert "💡" not in shadow_message or "Je vais maintenant" not in shadow_message

    print("✓ Explications spécifiques des modes OK")
|
||||
|
||||
|
||||
def test_log_finetuning_started_format():
    """Fine-tuning start messages show 🧠 and the example count."""
    print("\nTest 10: Formatage log_finetuning_started...")
    logger = HumanLogger()

    message = logger.log_finetuning_started(10)
    assert "🧠" in message, "L'emoji de cerveau doit être présent"
    assert "10 exemples" in message, "Le nombre d'exemples doit être présent"
    assert "Amélioration du modèle" in message, "Le message doit contenir 'Amélioration du modèle'"

    print("✓ Formatage log_finetuning_started OK")
|
||||
|
||||
|
||||
def test_log_finetuning_started_contextual():
    """The fine-tuning explanation appears only on the first run."""
    print("\nTest 11: Messages contextuels log_finetuning_started...")
    logger = HumanLogger()

    # First run: explanation present, flag flipped.
    first = logger.log_finetuning_started(10)
    assert "J'améliore ma compréhension" in first, "L'explication doit être présente la première fois"
    assert logger.first_time_events["finetuning_started"] is True

    # Second run: explanation omitted.
    second = logger.log_finetuning_started(15)
    assert "J'améliore ma compréhension" not in second, "L'explication ne doit pas être répétée"

    print("✓ Messages contextuels log_finetuning_started OK")
|
||||
|
||||
|
||||
def test_log_finetuning_completed_format():
    """Completion messages show ✅ and the duration rounded to one decimal."""
    print("\nTest 12: Formatage log_finetuning_completed...")
    logger = HumanLogger()

    short_run = logger.log_finetuning_completed(2.5)
    assert "✅" in short_run, "L'emoji de succès doit être présent"
    assert "2.5s" in short_run, "La durée doit être présente avec 1 décimale"
    assert "Modèle amélioré" in short_run

    long_run = logger.log_finetuning_completed(10.123)
    assert "10.1s" in long_run, "La durée doit être arrondie à 1 décimale"

    print("✓ Formatage log_finetuning_completed OK")
|
||||
|
||||
|
||||
def test_log_error_types():
    """Every error type gets the ⚠️ emoji plus its specific wording."""
    print("\nTest 13: Types d'erreurs...")
    logger = HumanLogger()

    # Generic check: the warning emoji is always present.
    for error_type in ("connection", "permission", "not_found", "timeout", "whitelist", "unknown"):
        message = logger.log_error(error_type)
        assert "⚠️" in message, f"L'emoji d'avertissement doit être présent pour {error_type}"

    # Specific wording per error type (an unrecognized type falls back to
    # the generic message).
    expected_fragments = {
        "connection": "Impossible de se connecter",
        "permission": "Permission refusée",
        "not_found": "Élément introuvable",
        "timeout": "Délai d'attente dépassé",
        "whitelist": "Application non autorisée",
        "unknown_type": "Une erreur est survenue",
    }
    for error_type, fragment in expected_fragments.items():
        assert fragment in logger.log_error(error_type)

    print("✓ Types d'erreurs OK")
|
||||
|
||||
|
||||
def test_log_error_with_context():
    """The optional context argument is echoed in the error message."""
    print("\nTest 14: Erreurs avec contexte...")
    logger = HumanLogger()

    message = logger.log_error("connection", "Calculator")
    assert "Calculator" in message, "Le contexte doit être présent"

    other = logger.log_error("permission", "Firefox")
    assert "Firefox" in other

    print("✓ Erreurs avec contexte OK")
|
||||
|
||||
|
||||
def test_log_error_suggestions():
    """Each error type comes with a 💡 corrective suggestion."""
    print("\nTest 15: Suggestions correctives...")
    logger = HumanLogger()

    # Connection errors: checked explicitly to pin the assertion message.
    connection_message = logger.log_error("connection")
    assert "💡" in connection_message, "Une suggestion doit être présente"
    assert "Vérifiez que l'application est ouverte" in connection_message

    # Remaining error types: (error type, expected suggestion fragment).
    suggestion_fragments = (
        ("permission", "Vérifiez les permissions"),
        ("not_found", "interface a peut-être changé"),
        ("timeout", "application est peut-être trop lente"),
        ("whitelist", "Ajoutez l'application à la liste autorisée"),
    )
    for error_type, fragment in suggestion_fragments:
        message = logger.log_error(error_type)
        assert "💡" in message
        assert fragment in message

    print("✓ Suggestions correctives OK")
|
||||
|
||||
|
||||
def test_log_idle():
    """The idle message shows 💤 and signals that the agent is waiting."""
    print("\nTest 16: Message d'inactivité...")
    logger = HumanLogger()

    message = logger.log_idle()
    assert "💤" in message, "L'emoji de sommeil doit être présent"
    assert "En attente" in message, "Le message doit indiquer l'attente"

    print("✓ Message d'inactivité OK")
|
||||
|
||||
|
||||
def test_log_stats_update():
    """Statistics messages show 📊 and pluralize 'workflow' correctly."""
    print("\nTest 17: Formatage des statistiques...")
    logger = HumanLogger()

    # Exactly one workflow: singular form.
    singular = logger.log_stats_update(12, 2, 1)
    assert "📊" in singular, "L'emoji de statistiques doit être présent"
    assert "12 actions" in singular
    assert "2 patterns" in singular
    assert "1 workflow" in singular
    assert "workflows" not in singular, "Doit être au singulier pour 1 workflow"

    # Several workflows: plural form.
    plural = logger.log_stats_update(25, 5, 3)
    assert "25 actions" in plural
    assert "5 patterns" in plural
    assert "3 workflows" in plural, "Doit être au pluriel pour plusieurs workflows"

    # Zero workflows: still singular.
    zero = logger.log_stats_update(5, 0, 0)
    assert "0 workflow" in zero, "Doit être au singulier pour 0 workflow"

    print("✓ Formatage des statistiques OK")
|
||||
|
||||
|
||||
def test_log_suggestion_ready():
    """Suggestion-ready messages show 💡 and name the task."""
    print("\nTest 18: Message de suggestion prête...")
    logger = HumanLogger()

    message = logger.log_suggestion_ready("Ouvrir facture")
    assert "💡" in message, "L'emoji d'idée doit être présent"
    assert "Prêt à suggérer" in message
    assert "Ouvrir facture" in message, "Le nom de la tâche doit être présent"

    print("✓ Message de suggestion prête OK")
|
||||
|
||||
|
||||
def test_log_collecting_examples():
    """Collection messages show 🧠 and a current/total progress counter."""
    print("\nTest 19: Message de collecte d'exemples...")
    logger = HumanLogger()

    message = logger.log_collecting_examples(8, 10)
    assert "🧠" in message, "L'emoji de cerveau doit être présent"
    assert "8/10" in message, "La progression doit être présente"
    assert "Collecte d'exemples" in message

    # The counter tracks whatever values are passed in.
    for current, total in ((1, 5), (10, 10)):
        assert f"{current}/{total}" in logger.log_collecting_examples(current, total)

    print("✓ Message de collecte d'exemples OK")
|
||||
|
||||
|
||||
def test_emoji_consistency():
    """Every log method consistently uses its designated emoji."""
    print("\nTest 20: Cohérence des emojis...")
    logger = HumanLogger()

    # (expected emoji, message produced by the corresponding log method).
    # The list literal evaluates the calls in the same order as before.
    checks = [
        ("👀", logger.log_observation("click", "Test")),          # observation
        ("🎯", logger.log_pattern_detected(3, "Test")),            # pattern
        ("📚", logger.log_workflow_learned("Test", 5)),            # workflow
        ("✅", logger.log_mode_change("shadow", "assist")),        # success
        ("✅", logger.log_finetuning_completed(2.5)),              # success
        ("🧠", logger.log_finetuning_started(10)),                 # brain
        ("🧠", logger.log_collecting_examples(8, 10)),             # brain
        ("⚠️", logger.log_error("connection")),                    # warning
        ("💡", logger.log_suggestion_ready("Test")),               # idea
        ("📊", logger.log_stats_update(12, 2, 1)),                 # stats
        ("💤", logger.log_idle()),                                 # sleep
    ]
    for emoji, message in checks:
        assert emoji in message

    print("✓ Cohérence des emojis OK")
|
||||
|
||||
|
||||
def test_message_length():
    """The first line of any base message stays within 80 characters."""
    print("\nTest 21: Longueur des messages...")
    logger = HumanLogger()

    observation = logger.log_observation("click", "Calculator")
    first_line = observation.split("\n")[0]  # only the base line counts
    assert len(first_line) <= 80, f"Message trop long: {len(first_line)} caractères"

    # A second logger with all first-time flags pre-set, so no contextual
    # explanation inflates the messages below.
    primed_logger = HumanLogger()
    primed_logger.first_time_events = {k: True for k in primed_logger.first_time_events}

    for message in (
        primed_logger.log_pattern_detected(3, "Calculer 9/9"),
        primed_logger.log_workflow_learned("Ouvrir facture", 5),
    ):
        first_line = message.split("\n")[0]
        assert len(first_line) <= 80, f"Message trop long: {len(first_line)} caractères"

    print("✓ Longueur des messages OK")
|
||||
|
||||
|
||||
def test_first_time_tracking():
    """Each log method flips its own first-time flag exactly once triggered."""
    print("\nTest 22: Tracking des premières fois...")
    logger = HumanLogger()

    # All flags start cleared.
    assert all(not v for v in logger.first_time_events.values()), "Tous les flags doivent être False au départ"

    # Trigger each event in turn and check the matching flag flips.
    triggers = (
        (lambda: logger.log_pattern_detected(3, "Test"), "pattern_detected"),
        (lambda: logger.log_workflow_learned("Test", 5), "workflow_detected"),
        (lambda: logger.log_mode_change("shadow", "assist"), "mode_change"),
        (lambda: logger.log_finetuning_started(10), "finetuning_started"),
    )
    for trigger, flag in triggers:
        trigger()
        assert logger.first_time_events[flag] is True

    print("✓ Tracking des premières fois OK")
|
||||
|
||||
|
||||
def run_all_tests():
    """Run every test in this module in order and report the outcome.

    Returns:
        bool: True when all tests pass, False on the first assertion
        failure or unexpected exception (the traceback is printed).

    The tests are driven from a single tuple of callables instead of 22
    hand-written calls, so adding a test only requires one list entry.
    """
    print("=" * 60)
    print("Tests unitaires de HumanLogger")
    print("=" * 60)

    # Kept in the original execution order.
    tests = (
        test_initialization,
        test_log_observation_format,
        test_log_pattern_detected_emoji,
        test_log_pattern_detected_contextual,
        test_log_workflow_learned_format,
        test_log_workflow_learned_contextual,
        test_log_mode_change_format,
        test_log_mode_change_contextual,
        test_log_mode_change_explanations,
        test_log_finetuning_started_format,
        test_log_finetuning_started_contextual,
        test_log_finetuning_completed_format,
        test_log_error_types,
        test_log_error_with_context,
        test_log_error_suggestions,
        test_log_idle,
        test_log_stats_update,
        test_log_suggestion_ready,
        test_log_collecting_examples,
        test_emoji_consistency,
        test_message_length,
        test_first_time_tracking,
    )

    try:
        for test in tests:
            test()

        print("\n" + "=" * 60)
        print("✓ TOUS LES TESTS RÉUSSIS!")
        print("=" * 60)
        return True

    except AssertionError as e:
        # A test assertion failed: report it and signal failure.
        print(f"\n✗ ÉCHEC DU TEST: {e}")
        import traceback
        traceback.print_exc()
        return False
    except Exception as e:
        # Any other error (import, attribute, ...) is also a failure.
        print(f"\n✗ ERREUR: {e}")
        import traceback
        traceback.print_exc()
        return False
|
||||
|
||||
|
||||
# Standalone entry point: run the suite and exit non-zero on failure so the
# script can be used in shell pipelines / CI.
if __name__ == "__main__":
    success = run_all_tests()
    sys.exit(0 if success else 1)
|
||||
405
geniusia2/tests/test_learning_manager.py
Normal file
405
geniusia2/tests/test_learning_manager.py
Normal file
@@ -0,0 +1,405 @@
|
||||
"""
|
||||
Tests pour le gestionnaire d'apprentissage.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import tempfile
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from unittest.mock import Mock, MagicMock
|
||||
|
||||
from geniusia2.core.learning_manager import LearningManager
|
||||
from geniusia2.core.models import Action, TaskProfile
|
||||
from geniusia2.core.embeddings_manager import EmbeddingsManager
|
||||
from geniusia2.core.logger import Logger
|
||||
|
||||
|
||||
@pytest.fixture
def temp_dirs():
    """Yield a dict of temporary sub-directory paths, removed after the test.

    Keys: "profiles", "logs", "index", "keys" — each an existing directory
    under one common temporary root.

    Uses try/finally so cleanup also runs when the test fails: pytest
    re-raises the test's exception at the ``yield`` point, which would
    otherwise skip the ``rmtree`` teardown line and leak the directory.
    """
    temp_dir = tempfile.mkdtemp()
    root = Path(temp_dir)

    # Create one sub-directory per role under the common root.
    paths = {name: root / name for name in ("profiles", "logs", "index", "keys")}
    for directory in paths.values():
        directory.mkdir(parents=True, exist_ok=True)

    try:
        yield {name: str(directory) for name, directory in paths.items()}
    finally:
        shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_embeddings_manager(temp_dirs):
    """Provide a CPU EmbeddingsManager writing its index to a temp dir."""
    manager = EmbeddingsManager(
        model_name="ViT-B-32",
        index_path=temp_dirs["index"],
        device="cpu",
    )
    return manager
|
||||
|
||||
|
||||
@pytest.fixture
def mock_logger(temp_dirs):
    """Provide a Logger writing logs and keys into temporary directories."""
    logger = Logger(
        log_dir=temp_dirs["logs"],
        key_path=temp_dirs["keys"],
    )
    return logger
|
||||
|
||||
|
||||
@pytest.fixture
def learning_manager(mock_embeddings_manager, mock_logger, temp_dirs):
    """Provide a LearningManager wired to the temp-backed test fixtures."""
    thresholds = {
        "autopilot_observations": 20,
        "autopilot_concordance": 0.95,
        "confidence_min": 0.90,
        "rollback_confidence": 0.85,
    }

    return LearningManager(
        embeddings_manager=mock_embeddings_manager,
        logger=mock_logger,
        config={"thresholds": thresholds},
        profiles_path=temp_dirs["profiles"],
    )
|
||||
|
||||
|
||||
def test_initialization(learning_manager):
    """A fresh manager starts in shadow mode with no tasks."""
    manager = learning_manager
    assert manager.mode == "shadow"
    assert len(manager.tasks) == 0
    assert manager.current_task_id is None
|
||||
|
||||
|
||||
def test_observe_action(learning_manager):
    """Observing a first action creates one shadow task tracking it."""
    observed = Action(
        action_type="click",
        target_element="valider_button",
        bbox=(100, 100, 50, 30),
        confidence=0.9,
        embedding=np.random.rand(512).astype(np.float32),
        timestamp=datetime.now(),
        window_title="Test Window",
    )

    learning_manager.observe(observed)

    # Exactly one task was created for this action.
    assert len(learning_manager.tasks) == 1

    # Inspect the newly created task.
    task = next(iter(learning_manager.tasks.values()))
    assert task.observation_count == 1
    assert task.mode == "shadow"
    assert len(task.action_sequence) == 1
    assert task.window_whitelist == ["Test Window"]
|
||||
|
||||
|
||||
def test_mode_transition_shadow_to_assist(learning_manager):
    """Five observations of the same action promote the task to assist mode."""
    repeated = Action(
        action_type="click",
        target_element="button",
        bbox=(100, 100, 50, 30),
        confidence=0.9,
        embedding=np.random.rand(512).astype(np.float32),
        timestamp=datetime.now(),
        window_title="Test",
    )

    # Five observations trigger the shadow -> assist transition.
    for _ in range(5):
        learning_manager.observe(repeated)

    task = next(iter(learning_manager.tasks.values()))
    assert task.mode == "assist"
    assert task.observation_count == 5
|
||||
|
||||
|
||||
def test_calculate_confidence(learning_manager):
    """Confidence blends vision and LLM scores with the documented weights."""
    seed = Action(
        action_type="click",
        target_element="button",
        bbox=(100, 100, 50, 30),
        confidence=0.9,
        embedding=np.random.rand(512).astype(np.float32),
        timestamp=datetime.now(),
        window_title="Test",
    )

    learning_manager.observe(seed)
    task_id = next(iter(learning_manager.tasks))

    score = learning_manager.calculate_confidence(
        vision_conf=0.9,
        llm_score=0.8,
        task_id=task_id,
    )

    # Expected formula: 0.6 * 0.9 + 0.3 * 0.8 + 0.1 * 0.0 = 0.78
    assert 0.77 <= score <= 0.79
def test_confirm_action_accept(learning_manager):
    """Accepting a proposed action pushes the concordance rate to 1.0."""
    proposed = Action(
        action_type="click",
        target_element="button",
        bbox=(100, 100, 50, 30),
        confidence=0.9,
        embedding=np.random.rand(512).astype(np.float32),
        timestamp=datetime.now(),
        window_title="Test",
    )

    learning_manager.observe(proposed)
    task_id = next(iter(learning_manager.tasks))

    # User accepts the proposed action.
    learning_manager.confirm_action({"type": "accept", "task_id": task_id})

    assert learning_manager.tasks[task_id].concordance_rate == 1.0
def test_confirm_action_reject(learning_manager):
    """Rejecting a proposed action drops the concordance rate to 0.0."""
    proposed = Action(
        action_type="click",
        target_element="button",
        bbox=(100, 100, 50, 30),
        confidence=0.9,
        embedding=np.random.rand(512).astype(np.float32),
        timestamp=datetime.now(),
        window_title="Test",
    )

    learning_manager.observe(proposed)
    task_id = next(iter(learning_manager.tasks))

    # User rejects the proposed action.
    learning_manager.confirm_action({"type": "reject", "task_id": task_id})

    assert learning_manager.tasks[task_id].concordance_rate == 0.0
def test_confirm_action_correct(learning_manager):
    """A user correction bumps correction_count and appends to the sequence."""
    wrong = Action(
        action_type="click",
        target_element="wrong_button",
        bbox=(100, 100, 50, 30),
        confidence=0.9,
        embedding=np.random.rand(512).astype(np.float32),
        timestamp=datetime.now(),
        window_title="Test",
    )

    learning_manager.observe(wrong)
    task_id = next(iter(learning_manager.tasks))

    # The user supplies the action that should have been taken instead.
    replacement = Action(
        action_type="click",
        target_element="correct_button",
        bbox=(200, 100, 50, 30),
        confidence=0.95,
        embedding=np.random.rand(512).astype(np.float32),
        timestamp=datetime.now(),
        window_title="Test",
    )

    learning_manager.confirm_action(
        {
            "type": "correct",
            "task_id": task_id,
            "corrected_action": replacement,
        }
    )

    corrected = learning_manager.tasks[task_id]
    assert corrected.correction_count == 1
    assert len(corrected.action_sequence) == 2
def test_should_transition_to_auto(learning_manager):
    """Many observations plus high concordance make a task autopilot-eligible."""
    repeated = Action(
        action_type="click",
        target_element="button",
        bbox=(100, 100, 50, 30),
        confidence=0.9,
        embedding=np.random.rand(512).astype(np.float32),
        timestamp=datetime.now(),
        window_title="Test",
    )

    # Accumulate 20 observations of the same action.
    for _ in range(20):
        learning_manager.observe(repeated)

    task_id = next(iter(learning_manager.tasks))

    # Force a concordance rate above the autopilot threshold.
    learning_manager.tasks[task_id].concordance_rate = 0.96

    assert learning_manager.should_transition_to_auto(task_id)
def test_rollback_if_low_confidence(learning_manager):
    """An auto task below the confidence threshold is demoted back to assist."""
    observed = Action(
        action_type="click",
        target_element="button",
        bbox=(100, 100, 50, 30),
        confidence=0.9,
        embedding=np.random.rand(512).astype(np.float32),
        timestamp=datetime.now(),
        window_title="Test",
    )

    learning_manager.observe(observed)
    task_id = next(iter(learning_manager.tasks))
    task = learning_manager.tasks[task_id]

    # Force auto mode with a confidence below the 0.90 threshold.
    task.mode = "auto"
    task.confidence_score = 0.85

    learning_manager.rollback_if_low_confidence(task_id)

    # The task should have been demoted to assist mode.
    assert task.mode == "assist"
def test_evaluate_task(learning_manager):
    """evaluate_task returns a metrics dict with the full expected schema."""
    observed = Action(
        action_type="click",
        target_element="button",
        bbox=(100, 100, 50, 30),
        confidence=0.9,
        embedding=np.random.rand(512).astype(np.float32),
        timestamp=datetime.now(),
        window_title="Test",
    )

    learning_manager.observe(observed)
    task_id = next(iter(learning_manager.tasks))

    metrics = learning_manager.evaluate_task(task_id)

    # Every documented metric key must be present.
    for key in (
        "task_id",
        "task_name",
        "mode",
        "observation_count",
        "concordance_rate",
        "confidence_score",
        "correction_count",
    ):
        assert key in metrics
    assert metrics["observation_count"] == 1
def test_get_all_tasks(learning_manager):
    """get_all_tasks returns one dict per distinct observed task."""
    # Create three distinct tasks (different targets and windows).
    for i in range(3):
        learning_manager.observe(
            Action(
                action_type="click",
                target_element=f"button_{i}",
                bbox=(100 + i * 50, 100, 50, 30),
                confidence=0.9,
                embedding=np.random.rand(512).astype(np.float32),
                timestamp=datetime.now(),
                window_title=f"Test_{i}",
            )
        )

    everything = learning_manager.get_all_tasks()

    assert len(everything) == 3
    assert all(isinstance(entry, dict) for entry in everything)
def test_get_task_stats(learning_manager):
    """Global stats expose per-mode counters and the correct total."""
    # Create three tasks (distinct targets/windows).
    for i in range(3):
        learning_manager.observe(
            Action(
                action_type="click",
                target_element=f"button_{i}",
                bbox=(100, 100, 50, 30),
                confidence=0.9,
                embedding=np.random.rand(512).astype(np.float32),
                timestamp=datetime.now(),
                window_title=f"Test_{i}",
            )
        )

    stats = learning_manager.get_task_stats()

    for key in ("total_tasks", "shadow_tasks", "assist_tasks", "auto_tasks"):
        assert key in stats
    assert stats["total_tasks"] == 3
def test_record_execution(learning_manager):
    """Recording an execution updates confidence and the last-execution stamp."""
    observed = Action(
        action_type="click",
        target_element="button",
        bbox=(100, 100, 50, 30),
        confidence=0.9,
        embedding=np.random.rand(512).astype(np.float32),
        timestamp=datetime.now(),
        window_title="Test",
    )

    learning_manager.observe(observed)
    task_id = next(iter(learning_manager.tasks))

    # Report one execution with its measured confidence.
    learning_manager.record_execution({"task_id": task_id, "confidence": 0.92})

    executed = learning_manager.tasks[task_id]
    assert executed.confidence_score == 0.92
    assert executed.last_execution is not None
if __name__ == "__main__":
    # Allow running this test module directly as a verbose pytest session.
    pytest.main([__file__, "-v"])
230
geniusia2/tests/test_llm_manager.py
Normal file
230
geniusia2/tests/test_llm_manager.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""
|
||||
Tests pour le gestionnaire LLM.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
from geniusia2.core.llm_manager import LLMManager
|
||||
|
||||
|
||||
@pytest.fixture
def mock_ollama_client():
    """Yield a MagicMock standing in for the Ollama client.

    Patches the ``ollama`` module used by llm_manager so no real server
    is contacted; the mock reports one installed model.
    """
    with patch('geniusia2.core.llm_manager.ollama') as patched_ollama:
        client = MagicMock()
        client.list.return_value = {'models': [{'name': 'qwen2.5-vl:3b'}]}
        patched_ollama.Client.return_value = client
        yield client
@pytest.fixture
def llm_manager(mock_ollama_client):
    """Build an LLMManager wired to the mocked Ollama client."""
    manager = LLMManager(
        model_name="qwen2.5-vl:3b",
        ollama_host="localhost:11434",
        fallback_to_vision=True,
    )
    return manager
def test_initialization(llm_manager):
    """Constructor arguments are stored verbatim on the manager."""
    assert llm_manager.model_name == "qwen2.5-vl:3b"
    assert llm_manager.ollama_host == "localhost:11434"
    assert llm_manager.fallback_to_vision is True
def test_image_to_base64(llm_manager):
    """Encoding an RGB array yields a non-empty base64 string."""
    rgb = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)

    encoded = llm_manager._image_to_base64(rgb)

    assert isinstance(encoded, str)
    assert encoded  # must not be empty
def test_fallback_to_vision_only(llm_manager):
    """Vision-only fallback picks the detection with the highest confidence."""
    detections = [
        {"label": "button1", "confidence": 0.7, "bbox": (10, 10, 50, 30)},
        {"label": "button2", "confidence": 0.9, "bbox": (100, 10, 50, 30)},
        {"label": "button3", "confidence": 0.6, "bbox": (200, 10, 50, 30)},
    ]

    choice = llm_manager._fallback_to_vision_only(detections)

    # button2 (index 1) holds the top confidence and must be selected.
    assert choice["element_index"] == 1
    assert choice["selected_element"]["label"] == "button2"
    assert choice["confidence"] == 0.9
def test_fallback_empty_detections(llm_manager):
    """With no detections the fallback selects nothing with zero confidence."""
    choice = llm_manager._fallback_to_vision_only([])

    assert choice["selected_element"] is None
    assert choice["confidence"] == 0.0
def test_parse_llm_response_valid_json(llm_manager):
    """A JSON payload embedded in prose is extracted and parsed correctly."""
    detections = [
        {"label": "button1", "confidence": 0.7},
        {"label": "button2", "confidence": 0.9},
    ]

    # The LLM answer wraps the JSON in free-form text; the parser must
    # still locate and decode it.
    response = """
    Voici mon analyse:
    {
        "element_index": 1,
        "confidence": 0.85,
        "reasoning": "Ce bouton correspond le mieux"
    }
    """

    parsed = llm_manager._parse_llm_response(response, detections)

    assert parsed["element_index"] == 1
    assert parsed["confidence"] == 0.85
    assert "reasoning" in parsed
def test_parse_llm_response_invalid_json(llm_manager):
    """An unparseable LLM reply triggers the vision-only fallback path."""
    detections = [
        {"label": "button1", "confidence": 0.7},
        {"label": "button2", "confidence": 0.9},
    ]

    parsed = llm_manager._parse_llm_response(
        "Ceci n'est pas du JSON valide", detections
    )

    # Fallback picks button2 (index 1), the highest-confidence detection,
    # and the reasoning text mentions the fallback.
    assert parsed["element_index"] == 1
    assert "fallback" in parsed["reasoning"].lower()
def test_reason_about_detections_empty(llm_manager):
    """Reasoning over zero detections selects nothing with zero confidence."""
    outcome = llm_manager.reason_about_detections(
        detections=[],
        context={"window": "Test"},
        intent="cliquer sur valider",
    )

    assert outcome["selected_element"] is None
    assert outcome["confidence"] == 0.0
def test_reason_about_detections_with_mock(llm_manager, mock_ollama_client):
    """The manager forwards the mocked LLM's JSON verdict to the caller."""
    detections = [
        {
            "label": "valider",
            "confidence": 0.9,
            "bbox": (100, 100, 50, 30),
            "roi_image": np.random.randint(0, 255, (30, 50, 3), dtype=np.uint8),
        }
    ]

    # Make the mocked client answer with a well-formed JSON verdict.
    mock_ollama_client.generate.return_value = {
        'response': '{"element_index": 0, "confidence": 0.95, "reasoning": "Bouton valider"}'
    }

    outcome = llm_manager.reason_about_detections(
        detections=detections,
        context={"window": "Test"},
        intent="cliquer sur valider",
    )

    assert outcome["element_index"] == 0
    assert outcome["confidence"] == 0.95
def test_generate_with_vision(llm_manager, mock_ollama_client):
    """generate_with_vision returns the mocked response text verbatim."""
    mock_ollama_client.generate.return_value = {
        'response': 'Ceci est une réponse de test'
    }

    frame = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)

    answer = llm_manager.generate_with_vision(
        prompt="Décris cette image",
        images=[frame],
    )

    assert answer == 'Ceci est une réponse de test'
    # The mocked client must actually have been invoked.
    assert mock_ollama_client.generate.called
def test_score_action_relevance(llm_manager, mock_ollama_client):
    """A numeric LLM reply is parsed into a relevance score in [0, 1]."""
    mock_ollama_client.generate.return_value = {'response': '0.85'}

    candidate = {
        "action_type": "click",
        "target_element": "valider_button",
    }

    relevance = llm_manager.score_action_relevance(
        action=candidate,
        intent="valider le formulaire",
    )

    assert 0.0 <= relevance <= 1.0
    assert relevance == 0.85
def test_is_available(llm_manager, mock_ollama_client):
    """is_available always yields a boolean, even with no models listed."""
    mock_ollama_client.list.return_value = {'models': []}

    assert isinstance(llm_manager.is_available(), bool)
def test_get_model_info(llm_manager):
    """get_model_info exposes the full documented key set and model name."""
    info = llm_manager.get_model_info()

    for key in ("model_name", "host", "available", "fallback_enabled"):
        assert key in info
    assert info["model_name"] == "qwen2.5-vl:3b"
def test_llm_manager_without_ollama():
    """Constructing without the ollama package installed raises ImportError."""
    with patch('geniusia2.core.llm_manager.ollama', None):
        with pytest.raises(ImportError):
            LLMManager()
def test_client_initialization_error():
    """A failing client constructor leaves the manager in fallback mode."""
    with patch('geniusia2.core.llm_manager.ollama') as patched_ollama:
        patched_ollama.Client.side_effect = Exception("Connection error")

        # With fallback enabled the manager is created anyway, clientless.
        manager = LLMManager(fallback_to_vision=True)
        assert manager.client is None
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
274
geniusia2/tests/test_vision_utils.py
Normal file
274
geniusia2/tests/test_vision_utils.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""
|
||||
Tests unitaires pour vision_utils.py
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Ajouter le répertoire parent au path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from geniusia2.core.utils.vision_utils import VisionUtils
|
||||
from geniusia2.core.models import Detection
|
||||
|
||||
|
||||
def test_initialization():
    """VisionUtils must default to owl-v2 with a three-model fallback chain."""
    print("Test 1: Initialisation...")
    vision = VisionUtils()
    assert vision.primary_model == "owl-v2"
    assert len(vision.fallback_order) == 3
    for model in ("owl-v2", "dino", "yolo"):
        assert model in vision.fallback_order
    print("✓ Initialisation OK")
def test_filter_detections():
    """Filtering drops low-confidence hits and sorts the rest descending."""
    print("\nTest 2: Filtrage des détections...")
    vision = VisionUtils()

    # Three detections straddling the 0.5 confidence threshold.
    samples = [
        ("button1", 0.95, (100, 100, 50, 30), "owl-v2"),
        ("button2", 0.25, (200, 100, 50, 30), "owl-v2"),
        ("button3", 0.75, (300, 100, 50, 30), "dino"),
    ]
    detections = [
        Detection(
            label=label,
            confidence=conf,
            bbox=box,
            embedding=np.random.rand(512),
            model_source=source,
        )
        for label, conf, box, source in samples
    ]

    kept = vision.filter_detections(detections, min_confidence=0.5)

    # Only the two detections at or above 0.5 survive.
    assert len(kept) == 2
    assert all(d.confidence >= 0.5 for d in kept)

    # Results come back sorted by descending confidence.
    assert kept[0].confidence >= kept[1].confidence

    print("✓ Filtrage OK")
def test_select_best_detection():
    """Selection favours the primary model and accepts positional context."""
    print("\nTest 3: Sélection de la meilleure détection...")
    vision = VisionUtils()

    detections = [
        Detection(
            label="button1",
            confidence=0.75,
            bbox=(100, 100, 50, 30),
            embedding=np.random.rand(512),
            model_source="dino",
        ),
        Detection(
            label="button2",
            confidence=0.85,
            bbox=(200, 100, 50, 30),
            embedding=np.random.rand(512),
            model_source="owl-v2",  # the primary model
        ),
    ]

    # Without context the primary model's detection should win its bonus.
    winner = vision.select_best_detection(detections)
    assert winner is not None
    assert winner.model_source == "owl-v2"

    # With a previous position near button1 the call must still succeed.
    context = {"previous_bbox": (95, 95, 50, 30)}
    contextual_winner = vision.select_best_detection(detections, context)
    assert contextual_winner is not None

    print("✓ Sélection OK")
def test_merge_overlapping_detections():
    """Strongly overlapping boxes merge; the best confidence is retained."""
    print("\nTest 4: Fusion des détections chevauchantes...")
    vision = VisionUtils()

    detections = [
        Detection(
            label="button",
            confidence=0.95,
            bbox=(100, 100, 50, 30),
            embedding=np.random.rand(512),
            model_source="owl-v2",
        ),
        Detection(
            label="button",
            confidence=0.85,
            # Slightly shifted copy: heavy overlap with the first box.
            bbox=(105, 102, 48, 28),
            embedding=np.random.rand(512),
            model_source="dino",
        ),
        Detection(
            label="button",
            confidence=0.70,
            # Far away: no overlap with the others.
            bbox=(300, 100, 50, 30),
            embedding=np.random.rand(512),
            model_source="yolo",
        ),
    ]

    merged = vision.merge_overlapping_detections(detections, iou_threshold=0.5)

    # The two overlapping boxes collapse into one; the distant one stays.
    assert len(merged) == 2

    # The merged entry keeps the higher of the two confidences.
    assert merged[0].confidence == 0.95

    print("✓ Fusion OK")
def test_get_detection_statistics():
    """Statistics report count, confidence extremes/mean, and model usage."""
    print("\nTest 5: Statistiques des détections...")
    vision = VisionUtils()

    samples = [
        ("button1", 0.95, (100, 100, 50, 30), "owl-v2"),
        ("button2", 0.85, (200, 100, 50, 30), "owl-v2"),
        ("button3", 0.75, (300, 100, 50, 30), "dino"),
    ]
    detections = [
        Detection(
            label=label,
            confidence=conf,
            bbox=box,
            embedding=np.random.rand(512),
            model_source=source,
        )
        for label, conf, box, source in samples
    ]

    stats = vision.get_detection_statistics(detections)

    assert stats["count"] == 3
    assert stats["max_confidence"] == 0.95
    assert stats["min_confidence"] == 0.75
    # Mean of 0.95, 0.85, 0.75 is 0.85 (up to float rounding).
    assert abs(stats["avg_confidence"] - 0.85) < 0.01
    assert "owl-v2" in stats["models_used"]
    assert "dino" in stats["models_used"]
    assert stats["model_distribution"]["owl-v2"] == 2
    assert stats["model_distribution"]["dino"] == 1

    print("✓ Statistiques OK")
def test_empty_detections():
    """Every VisionUtils entry point must degrade gracefully on empty input."""
    print("\nTest 6: Gestion des listes vides...")
    vision = VisionUtils()

    nothing = []

    # Filtering an empty list yields an empty list.
    assert len(vision.filter_detections(nothing)) == 0

    # Selection on an empty list yields no winner.
    assert vision.select_best_detection(nothing) is None

    # Merging an empty list yields an empty list.
    assert len(vision.merge_overlapping_detections(nothing)) == 0

    # Statistics on an empty list report zeroed aggregates.
    stats = vision.get_detection_statistics(nothing)
    assert stats["count"] == 0
    assert stats["avg_confidence"] == 0.0

    print("✓ Gestion des listes vides OK")
def test_single_detection():
    """A one-element list passes through selection and merging unchanged."""
    print("\nTest 7: Gestion d'une seule détection...")
    vision = VisionUtils()

    lone = [
        Detection(
            label="button",
            confidence=0.85,
            bbox=(100, 100, 50, 30),
            embedding=np.random.rand(512),
            model_source="owl-v2",
        )
    ]

    # Selection returns the single candidate.
    winner = vision.select_best_detection(lone)
    assert winner is not None
    assert winner.label == "button"

    # Merging a single detection leaves it untouched.
    assert len(vision.merge_overlapping_detections(lone)) == 1

    print("✓ Gestion d'une seule détection OK")
def run_all_tests():
    """Run every VisionUtils test in order; return True iff all passed."""
    print("=" * 60)
    print("Tests unitaires de VisionUtils")
    print("=" * 60)

    try:
        # Execute the suite in its declared order.
        for test in (
            test_initialization,
            test_filter_detections,
            test_select_best_detection,
            test_merge_overlapping_detections,
            test_get_detection_statistics,
            test_empty_detections,
            test_single_detection,
        ):
            test()
    except AssertionError as e:
        # A failed assertion: report which check broke.
        print(f"\n✗ ÉCHEC DU TEST: {e}")
        return False
    except Exception as e:
        # Any other error: report it with a full traceback.
        print(f"\n✗ ERREUR: {e}")
        import traceback
        traceback.print_exc()
        return False

    print("\n" + "=" * 60)
    print("✓ TOUS LES TESTS RÉUSSIS!")
    print("=" * 60)
    return True
if __name__ == "__main__":
    # Propagate the suite outcome through the process exit code.
    sys.exit(0 if run_all_tests() else 1)
Reference in New Issue
Block a user