Initial commit
This commit is contained in:
274
geniusia2/tests/test_vision_utils.py
Normal file
274
geniusia2/tests/test_vision_utils.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""
|
||||
Tests unitaires pour vision_utils.py
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Ajouter le répertoire parent au path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from geniusia2.core.utils.vision_utils import VisionUtils
|
||||
from geniusia2.core.models import Detection
|
||||
|
||||
|
||||
def test_initialization():
|
||||
"""Test l'initialisation de VisionUtils"""
|
||||
print("Test 1: Initialisation...")
|
||||
vision = VisionUtils()
|
||||
assert vision.primary_model == "owl-v2"
|
||||
assert len(vision.fallback_order) == 3
|
||||
assert "owl-v2" in vision.fallback_order
|
||||
assert "dino" in vision.fallback_order
|
||||
assert "yolo" in vision.fallback_order
|
||||
print("✓ Initialisation OK")
|
||||
|
||||
|
||||
def test_filter_detections():
|
||||
"""Test le filtrage des détections"""
|
||||
print("\nTest 2: Filtrage des détections...")
|
||||
vision = VisionUtils()
|
||||
|
||||
detections = [
|
||||
Detection(
|
||||
label="button1",
|
||||
confidence=0.95,
|
||||
bbox=(100, 100, 50, 30),
|
||||
embedding=np.random.rand(512),
|
||||
model_source="owl-v2"
|
||||
),
|
||||
Detection(
|
||||
label="button2",
|
||||
confidence=0.25,
|
||||
bbox=(200, 100, 50, 30),
|
||||
embedding=np.random.rand(512),
|
||||
model_source="owl-v2"
|
||||
),
|
||||
Detection(
|
||||
label="button3",
|
||||
confidence=0.75,
|
||||
bbox=(300, 100, 50, 30),
|
||||
embedding=np.random.rand(512),
|
||||
model_source="dino"
|
||||
),
|
||||
]
|
||||
|
||||
# Filtrer avec seuil de 0.5
|
||||
filtered = vision.filter_detections(detections, min_confidence=0.5)
|
||||
assert len(filtered) == 2
|
||||
assert all(d.confidence >= 0.5 for d in filtered)
|
||||
|
||||
# Vérifier le tri par confiance
|
||||
assert filtered[0].confidence >= filtered[1].confidence
|
||||
|
||||
print("✓ Filtrage OK")
|
||||
|
||||
|
||||
def test_select_best_detection():
|
||||
"""Test la sélection de la meilleure détection"""
|
||||
print("\nTest 3: Sélection de la meilleure détection...")
|
||||
vision = VisionUtils()
|
||||
|
||||
detections = [
|
||||
Detection(
|
||||
label="button1",
|
||||
confidence=0.75,
|
||||
bbox=(100, 100, 50, 30),
|
||||
embedding=np.random.rand(512),
|
||||
model_source="dino"
|
||||
),
|
||||
Detection(
|
||||
label="button2",
|
||||
confidence=0.85,
|
||||
bbox=(200, 100, 50, 30),
|
||||
embedding=np.random.rand(512),
|
||||
model_source="owl-v2" # Modèle principal
|
||||
),
|
||||
]
|
||||
|
||||
# Sans contexte, devrait favoriser owl-v2 avec bonus
|
||||
best = vision.select_best_detection(detections)
|
||||
assert best is not None
|
||||
assert best.model_source == "owl-v2"
|
||||
|
||||
# Avec contexte de position précédente
|
||||
context = {"previous_bbox": (95, 95, 50, 30)} # Proche de button1
|
||||
best_with_context = vision.select_best_detection(detections, context)
|
||||
assert best_with_context is not None
|
||||
|
||||
print("✓ Sélection OK")
|
||||
|
||||
|
||||
def test_merge_overlapping_detections():
|
||||
"""Test la fusion des détections chevauchantes"""
|
||||
print("\nTest 4: Fusion des détections chevauchantes...")
|
||||
vision = VisionUtils()
|
||||
|
||||
# Créer des détections qui se chevauchent
|
||||
detections = [
|
||||
Detection(
|
||||
label="button",
|
||||
confidence=0.95,
|
||||
bbox=(100, 100, 50, 30),
|
||||
embedding=np.random.rand(512),
|
||||
model_source="owl-v2"
|
||||
),
|
||||
Detection(
|
||||
label="button",
|
||||
confidence=0.85,
|
||||
bbox=(105, 102, 48, 28), # Légèrement décalé, fort chevauchement
|
||||
embedding=np.random.rand(512),
|
||||
model_source="dino"
|
||||
),
|
||||
Detection(
|
||||
label="button",
|
||||
confidence=0.70,
|
||||
bbox=(300, 100, 50, 30), # Pas de chevauchement
|
||||
embedding=np.random.rand(512),
|
||||
model_source="yolo"
|
||||
),
|
||||
]
|
||||
|
||||
merged = vision.merge_overlapping_detections(detections, iou_threshold=0.5)
|
||||
|
||||
# Devrait fusionner les 2 premières détections
|
||||
assert len(merged) == 2
|
||||
|
||||
# La meilleure confiance devrait être conservée
|
||||
assert merged[0].confidence == 0.95
|
||||
|
||||
print("✓ Fusion OK")
|
||||
|
||||
|
||||
def test_get_detection_statistics():
|
||||
"""Test le calcul des statistiques"""
|
||||
print("\nTest 5: Statistiques des détections...")
|
||||
vision = VisionUtils()
|
||||
|
||||
detections = [
|
||||
Detection(
|
||||
label="button1",
|
||||
confidence=0.95,
|
||||
bbox=(100, 100, 50, 30),
|
||||
embedding=np.random.rand(512),
|
||||
model_source="owl-v2"
|
||||
),
|
||||
Detection(
|
||||
label="button2",
|
||||
confidence=0.85,
|
||||
bbox=(200, 100, 50, 30),
|
||||
embedding=np.random.rand(512),
|
||||
model_source="owl-v2"
|
||||
),
|
||||
Detection(
|
||||
label="button3",
|
||||
confidence=0.75,
|
||||
bbox=(300, 100, 50, 30),
|
||||
embedding=np.random.rand(512),
|
||||
model_source="dino"
|
||||
),
|
||||
]
|
||||
|
||||
stats = vision.get_detection_statistics(detections)
|
||||
|
||||
assert stats["count"] == 3
|
||||
assert stats["max_confidence"] == 0.95
|
||||
assert stats["min_confidence"] == 0.75
|
||||
assert abs(stats["avg_confidence"] - 0.85) < 0.01
|
||||
assert "owl-v2" in stats["models_used"]
|
||||
assert "dino" in stats["models_used"]
|
||||
assert stats["model_distribution"]["owl-v2"] == 2
|
||||
assert stats["model_distribution"]["dino"] == 1
|
||||
|
||||
print("✓ Statistiques OK")
|
||||
|
||||
|
||||
def test_empty_detections():
|
||||
"""Test le comportement avec des listes vides"""
|
||||
print("\nTest 6: Gestion des listes vides...")
|
||||
vision = VisionUtils()
|
||||
|
||||
# Liste vide
|
||||
empty = []
|
||||
|
||||
# Filtrage
|
||||
filtered = vision.filter_detections(empty)
|
||||
assert len(filtered) == 0
|
||||
|
||||
# Sélection
|
||||
best = vision.select_best_detection(empty)
|
||||
assert best is None
|
||||
|
||||
# Fusion
|
||||
merged = vision.merge_overlapping_detections(empty)
|
||||
assert len(merged) == 0
|
||||
|
||||
# Statistiques
|
||||
stats = vision.get_detection_statistics(empty)
|
||||
assert stats["count"] == 0
|
||||
assert stats["avg_confidence"] == 0.0
|
||||
|
||||
print("✓ Gestion des listes vides OK")
|
||||
|
||||
|
||||
def test_single_detection():
|
||||
"""Test le comportement avec une seule détection"""
|
||||
print("\nTest 7: Gestion d'une seule détection...")
|
||||
vision = VisionUtils()
|
||||
|
||||
single = [
|
||||
Detection(
|
||||
label="button",
|
||||
confidence=0.85,
|
||||
bbox=(100, 100, 50, 30),
|
||||
embedding=np.random.rand(512),
|
||||
model_source="owl-v2"
|
||||
)
|
||||
]
|
||||
|
||||
# Sélection
|
||||
best = vision.select_best_detection(single)
|
||||
assert best is not None
|
||||
assert best.label == "button"
|
||||
|
||||
# Fusion
|
||||
merged = vision.merge_overlapping_detections(single)
|
||||
assert len(merged) == 1
|
||||
|
||||
print("✓ Gestion d'une seule détection OK")
|
||||
|
||||
|
||||
def run_all_tests():
|
||||
"""Exécute tous les tests"""
|
||||
print("=" * 60)
|
||||
print("Tests unitaires de VisionUtils")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
test_initialization()
|
||||
test_filter_detections()
|
||||
test_select_best_detection()
|
||||
test_merge_overlapping_detections()
|
||||
test_get_detection_statistics()
|
||||
test_empty_detections()
|
||||
test_single_detection()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✓ TOUS LES TESTS RÉUSSIS!")
|
||||
print("=" * 60)
|
||||
return True
|
||||
|
||||
except AssertionError as e:
|
||||
print(f"\n✗ ÉCHEC DU TEST: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"\n✗ ERREUR: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = run_all_tests()
|
||||
sys.exit(0 if success else 1)
|
||||
Reference in New Issue
Block a user