Initial commit
This commit is contained in:
798
geniusia2/core/utils/vision_utils.py
Normal file
798
geniusia2/core/utils/vision_utils.py
Normal file
@@ -0,0 +1,798 @@
|
||||
"""
|
||||
Utilitaires de vision pour détection d'éléments UI
|
||||
Fournit des interfaces vers les modèles de vision (OWL-v2, Grounding DINO, YOLO-World)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
from ..models import Detection
|
||||
from ..config import get_config, get_model_config
|
||||
from .image_utils import extract_roi
|
||||
|
||||
# Configuration du logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VisionUtils:
|
||||
"""
|
||||
Classe utilitaire pour la détection d'éléments UI avec plusieurs modèles de vision
|
||||
Supporte OWL-v2, Grounding DINO et YOLO-World avec fallback automatique
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[Dict[str, Any]] = None):
|
||||
"""
|
||||
Initialise VisionUtils avec les modèles de vision
|
||||
|
||||
Args:
|
||||
config: Configuration optionnelle (utilise CONFIG global si None)
|
||||
"""
|
||||
self.config = config or get_config()
|
||||
self.model_config = get_model_config()
|
||||
|
||||
# Modèle principal configuré
|
||||
self.primary_model = self.model_config.get("vision", "owl-v2")
|
||||
|
||||
# Ordre de fallback des modèles
|
||||
self.fallback_order = ["owl-v2", "dino", "yolo"]
|
||||
|
||||
# Modèles chargés (lazy loading)
|
||||
self._models = {}
|
||||
self._models_loaded = {
|
||||
"owl-v2": False,
|
||||
"dino": False,
|
||||
"yolo": False,
|
||||
}
|
||||
|
||||
logger.info(f"VisionUtils initialisé avec modèle principal: {self.primary_model}")
|
||||
|
||||
def _load_owlv2(self) -> Any:
|
||||
"""
|
||||
Charge le modèle OWL-v2 (OWLv2 pour détection open-vocabulary)
|
||||
|
||||
Returns:
|
||||
Modèle OWL-v2 chargé
|
||||
"""
|
||||
try:
|
||||
logger.info("Chargement du modèle OWL-v2...")
|
||||
|
||||
# Import dynamique pour éviter les dépendances si non utilisé
|
||||
from transformers import Owlv2Processor, Owlv2ForObjectDetection
|
||||
import torch
|
||||
|
||||
model_path = self.model_config["paths"].get("owl_v2")
|
||||
|
||||
# Charger le modèle pré-entraîné
|
||||
processor = Owlv2Processor.from_pretrained(
|
||||
"google/owlv2-base-patch16-ensemble",
|
||||
cache_dir=model_path
|
||||
)
|
||||
model = Owlv2ForObjectDetection.from_pretrained(
|
||||
"google/owlv2-base-patch16-ensemble",
|
||||
cache_dir=model_path
|
||||
)
|
||||
|
||||
# Déplacer vers GPU si disponible
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
|
||||
self._models["owl-v2"] = {
|
||||
"processor": processor,
|
||||
"model": model,
|
||||
"device": device
|
||||
}
|
||||
self._models_loaded["owl-v2"] = True
|
||||
|
||||
logger.info(f"OWL-v2 chargé avec succès sur {device}")
|
||||
return self._models["owl-v2"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur lors du chargement d'OWL-v2: {e}")
|
||||
self._models_loaded["owl-v2"] = False
|
||||
raise
|
||||
|
||||
def _load_dino(self) -> Any:
|
||||
"""
|
||||
Charge le modèle Grounding DINO
|
||||
|
||||
Returns:
|
||||
Modèle Grounding DINO chargé
|
||||
"""
|
||||
try:
|
||||
logger.info("Chargement du modèle Grounding DINO...")
|
||||
|
||||
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
||||
import torch
|
||||
|
||||
# Charger le modèle Grounding DINO depuis HuggingFace
|
||||
model_id = "IDEA-Research/grounding-dino-tiny"
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
|
||||
|
||||
# Déplacer vers GPU si disponible
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
|
||||
self._models["dino"] = {
|
||||
"processor": processor,
|
||||
"model": model,
|
||||
"device": device
|
||||
}
|
||||
self._models_loaded["dino"] = True
|
||||
|
||||
logger.info(f"Grounding DINO chargé avec succès sur {device}")
|
||||
return self._models["dino"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur lors du chargement de Grounding DINO: {e}")
|
||||
self._models_loaded["dino"] = False
|
||||
self._models["dino"] = {"model": None, "loaded": False}
|
||||
return self._models["dino"]
|
||||
|
||||
def _load_yolo(self) -> Any:
|
||||
"""
|
||||
Charge le modèle YOLO-World
|
||||
|
||||
Returns:
|
||||
Modèle YOLO-World chargé
|
||||
"""
|
||||
try:
|
||||
logger.info("Chargement du modèle YOLO-World...")
|
||||
|
||||
from ultralytics import YOLOWorld
|
||||
|
||||
# Charger YOLO-World (modèle pré-entraîné)
|
||||
model = YOLOWorld("yolov8s-worldv2.pt")
|
||||
|
||||
self._models["yolo"] = {
|
||||
"model": model
|
||||
}
|
||||
self._models_loaded["yolo"] = True
|
||||
|
||||
logger.info("YOLO-World chargé avec succès")
|
||||
return self._models["yolo"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur lors du chargement de YOLO-World: {e}")
|
||||
self._models_loaded["yolo"] = False
|
||||
self._models["yolo"] = {"model": None, "loaded": False}
|
||||
return self._models["yolo"]
|
||||
|
||||
def _ensure_model_loaded(self, model_name: str) -> bool:
|
||||
"""
|
||||
S'assure qu'un modèle est chargé
|
||||
|
||||
Args:
|
||||
model_name: Nom du modèle ("owl-v2", "dino", "yolo")
|
||||
|
||||
Returns:
|
||||
True si le modèle est chargé avec succès
|
||||
"""
|
||||
if self._models_loaded.get(model_name, False):
|
||||
return True
|
||||
|
||||
try:
|
||||
if model_name == "owl-v2":
|
||||
self._load_owlv2()
|
||||
elif model_name == "dino":
|
||||
self._load_dino()
|
||||
elif model_name == "yolo":
|
||||
self._load_yolo()
|
||||
else:
|
||||
logger.error(f"Modèle inconnu: {model_name}")
|
||||
return False
|
||||
|
||||
return self._models_loaded.get(model_name, False)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Impossible de charger le modèle {model_name}: {e}")
|
||||
return False
|
||||
|
||||
def detect_with_owlv2(self, prompt: str, frame: np.ndarray) -> List[Detection]:
|
||||
"""
|
||||
Détection d'éléments UI avec OWL-v2
|
||||
|
||||
Args:
|
||||
prompt: Description textuelle de l'élément à détecter
|
||||
frame: Image de l'écran (numpy array RGB)
|
||||
|
||||
Returns:
|
||||
Liste de détections trouvées
|
||||
"""
|
||||
try:
|
||||
# S'assurer que le modèle est chargé
|
||||
if not self._ensure_model_loaded("owl-v2"):
|
||||
logger.error("OWL-v2 n'est pas disponible")
|
||||
return []
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
model_data = self._models["owl-v2"]
|
||||
processor = model_data["processor"]
|
||||
model = model_data["model"]
|
||||
device = model_data["device"]
|
||||
|
||||
# Convertir frame numpy en PIL Image
|
||||
if frame.dtype != np.uint8:
|
||||
frame = (frame * 255).astype(np.uint8)
|
||||
image = Image.fromarray(frame)
|
||||
|
||||
# Préparer les prompts (OWL-v2 accepte plusieurs prompts)
|
||||
texts = [[prompt]]
|
||||
|
||||
# Traiter l'image et le texte
|
||||
inputs = processor(text=texts, images=image, return_tensors="pt")
|
||||
inputs = {k: v.to(device) for k, v in inputs.items()}
|
||||
|
||||
# Inférence
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Post-traitement des résultats
|
||||
target_sizes = torch.tensor([image.size[::-1]]).to(device)
|
||||
results = processor.post_process_object_detection(
|
||||
outputs=outputs,
|
||||
threshold=0.1, # Seuil bas pour capturer plus de détections
|
||||
target_sizes=target_sizes
|
||||
)[0]
|
||||
|
||||
# Convertir en objets Detection
|
||||
detections = []
|
||||
boxes = results["boxes"].cpu().numpy()
|
||||
scores = results["scores"].cpu().numpy()
|
||||
labels = results["labels"].cpu().numpy()
|
||||
|
||||
for box, score, label in zip(boxes, scores, labels):
|
||||
# Convertir bbox de [x1, y1, x2, y2] vers [x, y, w, h]
|
||||
x1, y1, x2, y2 = box
|
||||
x, y = int(x1), int(y1)
|
||||
w, h = int(x2 - x1), int(y2 - y1)
|
||||
|
||||
# Extraire ROI pour embedding
|
||||
roi = extract_roi(frame, (x, y, w, h))
|
||||
|
||||
# Créer embedding simple (sera remplacé par OpenCLIP plus tard)
|
||||
embedding = np.random.rand(512) # Placeholder
|
||||
|
||||
detection = Detection(
|
||||
label=prompt,
|
||||
confidence=float(score),
|
||||
bbox=(x, y, w, h),
|
||||
embedding=embedding,
|
||||
model_source="owl-v2",
|
||||
roi_image=roi,
|
||||
metadata={
|
||||
"label_id": int(label),
|
||||
"raw_box": box.tolist()
|
||||
}
|
||||
)
|
||||
detections.append(detection)
|
||||
|
||||
logger.info(f"OWL-v2: {len(detections)} détections pour '{prompt}'")
|
||||
return detections
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur lors de la détection OWL-v2: {e}")
|
||||
return []
|
||||
|
||||
def detect_with_dino(self, prompt: str, frame: np.ndarray) -> List[Detection]:
|
||||
"""
|
||||
Détection d'éléments UI avec Grounding DINO
|
||||
|
||||
Args:
|
||||
prompt: Description textuelle de l'élément à détecter
|
||||
frame: Image de l'écran (numpy array RGB)
|
||||
|
||||
Returns:
|
||||
Liste de détections trouvées
|
||||
"""
|
||||
try:
|
||||
# S'assurer que le modèle est chargé
|
||||
if not self._ensure_model_loaded("dino"):
|
||||
logger.warning("Grounding DINO n'est pas disponible")
|
||||
return []
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
model_data = self._models["dino"]
|
||||
if not model_data.get("model"):
|
||||
return []
|
||||
|
||||
processor = model_data["processor"]
|
||||
model = model_data["model"]
|
||||
device = model_data["device"]
|
||||
|
||||
# Convertir frame numpy en PIL Image
|
||||
if frame.dtype != np.uint8:
|
||||
frame = (frame * 255).astype(np.uint8)
|
||||
image = Image.fromarray(frame)
|
||||
|
||||
# Préparer les inputs
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
inputs = {k: v.to(device) for k, v in inputs.items()}
|
||||
|
||||
# Inférence
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Post-traitement
|
||||
target_sizes = torch.tensor([image.size[::-1]]).to(device)
|
||||
results = processor.post_process_grounded_object_detection(
|
||||
outputs=outputs,
|
||||
input_ids=inputs["input_ids"],
|
||||
threshold=0.3,
|
||||
target_sizes=target_sizes
|
||||
)[0]
|
||||
|
||||
# Convertir en objets Detection
|
||||
detections = []
|
||||
boxes = results["boxes"].cpu().numpy()
|
||||
scores = results["scores"].cpu().numpy()
|
||||
labels = results["labels"]
|
||||
|
||||
for box, score, label in zip(boxes, scores, labels):
|
||||
x1, y1, x2, y2 = box
|
||||
x, y = int(x1), int(y1)
|
||||
w, h = int(x2 - x1), int(y2 - y1)
|
||||
|
||||
roi = extract_roi(frame, (x, y, w, h))
|
||||
embedding = np.random.rand(512) # Placeholder
|
||||
|
||||
detection = Detection(
|
||||
label=label,
|
||||
confidence=float(score),
|
||||
bbox=(x, y, w, h),
|
||||
embedding=embedding,
|
||||
model_source="dino",
|
||||
roi_image=roi,
|
||||
metadata={"raw_box": box.tolist()}
|
||||
)
|
||||
detections.append(detection)
|
||||
|
||||
logger.info(f"Grounding DINO: {len(detections)} détections pour '{prompt}'")
|
||||
return detections
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur lors de la détection Grounding DINO: {e}")
|
||||
return []
|
||||
|
||||
def detect_with_yolo(self, prompt: str, frame: np.ndarray) -> List[Detection]:
|
||||
"""
|
||||
Détection d'éléments UI avec YOLO-World
|
||||
|
||||
Args:
|
||||
prompt: Description textuelle de l'élément à détecter
|
||||
frame: Image de l'écran (numpy array RGB)
|
||||
|
||||
Returns:
|
||||
Liste de détections trouvées
|
||||
"""
|
||||
try:
|
||||
# S'assurer que le modèle est chargé
|
||||
if not self._ensure_model_loaded("yolo"):
|
||||
logger.warning("YOLO-World n'est pas disponible")
|
||||
return []
|
||||
|
||||
model_data = self._models["yolo"]
|
||||
if not model_data.get("model"):
|
||||
return []
|
||||
|
||||
model = model_data["model"]
|
||||
|
||||
# Définir les classes à détecter (YOLO-World accepte des prompts textuels)
|
||||
model.set_classes([prompt])
|
||||
|
||||
# Convertir BGR vers RGB si nécessaire
|
||||
if frame.dtype != np.uint8:
|
||||
frame = (frame * 255).astype(np.uint8)
|
||||
|
||||
# Inférence
|
||||
results = model.predict(frame, conf=0.1, verbose=False)
|
||||
|
||||
# Convertir en objets Detection
|
||||
detections = []
|
||||
for result in results:
|
||||
boxes = result.boxes
|
||||
for box in boxes:
|
||||
# Extraire les coordonnées
|
||||
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
|
||||
x, y = int(x1), int(y1)
|
||||
w, h = int(x2 - x1), int(y2 - y1)
|
||||
|
||||
# Score de confiance
|
||||
confidence = float(box.conf[0])
|
||||
|
||||
# Classe détectée
|
||||
cls_id = int(box.cls[0])
|
||||
label = model.names[cls_id] if cls_id < len(model.names) else prompt
|
||||
|
||||
roi = extract_roi(frame, (x, y, w, h))
|
||||
embedding = np.random.rand(512) # Placeholder
|
||||
|
||||
detection = Detection(
|
||||
label=label,
|
||||
confidence=confidence,
|
||||
bbox=(x, y, w, h),
|
||||
embedding=embedding,
|
||||
model_source="yolo",
|
||||
roi_image=roi,
|
||||
metadata={"class_id": cls_id}
|
||||
)
|
||||
detections.append(detection)
|
||||
|
||||
logger.info(f"YOLO-World: {len(detections)} détections pour '{prompt}'")
|
||||
return detections
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur lors de la détection YOLO-World: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def detect(self, prompt: str, frame: np.ndarray,
|
||||
model: Optional[str] = None) -> List[Detection]:
|
||||
"""
|
||||
Détection d'éléments UI avec fallback automatique entre modèles
|
||||
|
||||
Args:
|
||||
prompt: Description textuelle de l'élément à détecter
|
||||
frame: Image de l'écran (numpy array RGB)
|
||||
model: Modèle spécifique à utiliser (None = utiliser le modèle principal)
|
||||
|
||||
Returns:
|
||||
Liste de détections trouvées
|
||||
"""
|
||||
# Déterminer l'ordre des modèles à essayer
|
||||
if model:
|
||||
models_to_try = [model] + [m for m in self.fallback_order if m != model]
|
||||
else:
|
||||
models_to_try = [self.primary_model] + [m for m in self.fallback_order if m != self.primary_model]
|
||||
|
||||
# Essayer chaque modèle jusqu'à obtenir des détections
|
||||
for model_name in models_to_try:
|
||||
try:
|
||||
logger.info(f"Tentative de détection avec {model_name}...")
|
||||
|
||||
if model_name == "owl-v2":
|
||||
detections = self.detect_with_owlv2(prompt, frame)
|
||||
elif model_name == "dino":
|
||||
detections = self.detect_with_dino(prompt, frame)
|
||||
elif model_name == "yolo":
|
||||
detections = self.detect_with_yolo(prompt, frame)
|
||||
else:
|
||||
logger.warning(f"Modèle inconnu: {model_name}")
|
||||
continue
|
||||
|
||||
# Si des détections sont trouvées, retourner
|
||||
if detections:
|
||||
logger.info(f"Détection réussie avec {model_name}: {len(detections)} éléments")
|
||||
return detections
|
||||
else:
|
||||
logger.warning(f"Aucune détection avec {model_name}, essai du modèle suivant...")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur avec {model_name}: {e}, essai du modèle suivant...")
|
||||
continue
|
||||
|
||||
# Aucun modèle n'a réussi
|
||||
logger.error(f"Aucun modèle n'a pu détecter '{prompt}'")
|
||||
return []
|
||||
|
||||
def select_best_detection(self, detections: List[Detection],
|
||||
context: Optional[Dict[str, Any]] = None) -> Optional[Detection]:
|
||||
"""
|
||||
Sélectionne la meilleure détection parmi une liste
|
||||
|
||||
Args:
|
||||
detections: Liste de détections à évaluer
|
||||
context: Contexte additionnel pour la sélection (position précédente, etc.)
|
||||
|
||||
Returns:
|
||||
La meilleure détection ou None si la liste est vide
|
||||
"""
|
||||
if not detections:
|
||||
return None
|
||||
|
||||
# Si une seule détection, la retourner
|
||||
if len(detections) == 1:
|
||||
return detections[0]
|
||||
|
||||
# Stratégie de sélection basée sur plusieurs critères
|
||||
best_detection = None
|
||||
best_score = -1
|
||||
|
||||
for detection in detections:
|
||||
score = detection.confidence
|
||||
|
||||
# Bonus pour les détections du modèle principal
|
||||
if detection.model_source == self.primary_model:
|
||||
score *= 1.1
|
||||
|
||||
# Si contexte fourni avec position précédente, favoriser les détections proches
|
||||
if context and "previous_bbox" in context:
|
||||
prev_x, prev_y, prev_w, prev_h = context["previous_bbox"]
|
||||
curr_x, curr_y, curr_w, curr_h = detection.bbox
|
||||
|
||||
# Calculer la distance entre les centres
|
||||
prev_center = (prev_x + prev_w / 2, prev_y + prev_h / 2)
|
||||
curr_center = (curr_x + curr_w / 2, curr_y + curr_h / 2)
|
||||
distance = np.sqrt(
|
||||
(prev_center[0] - curr_center[0]) ** 2 +
|
||||
(prev_center[1] - curr_center[1]) ** 2
|
||||
)
|
||||
|
||||
# Bonus inversement proportionnel à la distance (max 20% bonus)
|
||||
proximity_bonus = max(0, 1 - distance / 500) * 0.2
|
||||
score *= (1 + proximity_bonus)
|
||||
|
||||
# Favoriser les détections avec des bounding boxes de taille raisonnable
|
||||
x, y, w, h = detection.bbox
|
||||
area = w * h
|
||||
if 100 < area < 100000: # Taille raisonnable pour un élément UI
|
||||
score *= 1.05
|
||||
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_detection = detection
|
||||
|
||||
logger.info(f"Meilleure détection sélectionnée: {best_detection.label} "
|
||||
f"(confiance: {best_detection.confidence:.2f}, "
|
||||
f"modèle: {best_detection.model_source})")
|
||||
|
||||
return best_detection
|
||||
|
||||
def filter_detections(self, detections: List[Detection],
|
||||
min_confidence: float = 0.3,
|
||||
max_detections: int = 10) -> List[Detection]:
|
||||
"""
|
||||
Filtre les détections selon des critères de qualité
|
||||
|
||||
Args:
|
||||
detections: Liste de détections à filtrer
|
||||
min_confidence: Confiance minimale requise
|
||||
max_detections: Nombre maximum de détections à retourner
|
||||
|
||||
Returns:
|
||||
Liste filtrée et triée de détections
|
||||
"""
|
||||
# Filtrer par confiance minimale
|
||||
filtered = [d for d in detections if d.confidence >= min_confidence]
|
||||
|
||||
# Trier par confiance décroissante
|
||||
filtered.sort(key=lambda d: d.confidence, reverse=True)
|
||||
|
||||
# Limiter le nombre de détections
|
||||
filtered = filtered[:max_detections]
|
||||
|
||||
logger.info(f"Filtrage: {len(detections)} -> {len(filtered)} détections "
|
||||
f"(seuil: {min_confidence})")
|
||||
|
||||
return filtered
|
||||
|
||||
def merge_overlapping_detections(self, detections: List[Detection],
|
||||
iou_threshold: float = 0.5) -> List[Detection]:
|
||||
"""
|
||||
Fusionne les détections qui se chevauchent (même élément détecté plusieurs fois)
|
||||
|
||||
Args:
|
||||
detections: Liste de détections
|
||||
iou_threshold: Seuil d'IoU pour considérer deux détections comme identiques
|
||||
|
||||
Returns:
|
||||
Liste de détections fusionnées
|
||||
"""
|
||||
if len(detections) <= 1:
|
||||
return detections
|
||||
|
||||
def calculate_iou(box1: Tuple[int, int, int, int],
|
||||
box2: Tuple[int, int, int, int]) -> float:
|
||||
"""Calcule l'Intersection over Union entre deux bounding boxes"""
|
||||
x1, y1, w1, h1 = box1
|
||||
x2, y2, w2, h2 = box2
|
||||
|
||||
# Coordonnées de l'intersection
|
||||
xi1 = max(x1, x2)
|
||||
yi1 = max(y1, y2)
|
||||
xi2 = min(x1 + w1, x2 + w2)
|
||||
yi2 = min(y1 + h1, y2 + h2)
|
||||
|
||||
# Aire de l'intersection
|
||||
inter_area = max(0, xi2 - xi1) * max(0, yi2 - yi1)
|
||||
|
||||
# Aires des deux boxes
|
||||
box1_area = w1 * h1
|
||||
box2_area = w2 * h2
|
||||
|
||||
# Union
|
||||
union_area = box1_area + box2_area - inter_area
|
||||
|
||||
# IoU
|
||||
return inter_area / union_area if union_area > 0 else 0
|
||||
|
||||
# Trier par confiance décroissante
|
||||
sorted_detections = sorted(detections, key=lambda d: d.confidence, reverse=True)
|
||||
|
||||
merged = []
|
||||
used = set()
|
||||
|
||||
for i, det1 in enumerate(sorted_detections):
|
||||
if i in used:
|
||||
continue
|
||||
|
||||
# Trouver toutes les détections qui se chevauchent avec det1
|
||||
overlapping = [det1]
|
||||
for j, det2 in enumerate(sorted_detections[i+1:], start=i+1):
|
||||
if j in used:
|
||||
continue
|
||||
|
||||
iou = calculate_iou(det1.bbox, det2.bbox)
|
||||
if iou >= iou_threshold:
|
||||
overlapping.append(det2)
|
||||
used.add(j)
|
||||
|
||||
# Si plusieurs détections se chevauchent, garder celle avec la meilleure confiance
|
||||
# (det1 est déjà la meilleure car la liste est triée)
|
||||
merged.append(det1)
|
||||
used.add(i)
|
||||
|
||||
logger.info(f"Fusion: {len(detections)} -> {len(merged)} détections "
|
||||
f"(seuil IoU: {iou_threshold})")
|
||||
|
||||
return merged
|
||||
|
||||
def get_detection_statistics(self, detections: List[Detection]) -> Dict[str, Any]:
|
||||
"""
|
||||
Calcule des statistiques sur une liste de détections
|
||||
|
||||
Args:
|
||||
detections: Liste de détections
|
||||
|
||||
Returns:
|
||||
Dictionnaire de statistiques
|
||||
"""
|
||||
if not detections:
|
||||
return {
|
||||
"count": 0,
|
||||
"avg_confidence": 0.0,
|
||||
"max_confidence": 0.0,
|
||||
"min_confidence": 0.0,
|
||||
"models_used": []
|
||||
}
|
||||
|
||||
confidences = [d.confidence for d in detections]
|
||||
models = [d.model_source for d in detections]
|
||||
|
||||
stats = {
|
||||
"count": len(detections),
|
||||
"avg_confidence": float(np.mean(confidences)),
|
||||
"max_confidence": float(np.max(confidences)),
|
||||
"min_confidence": float(np.min(confidences)),
|
||||
"std_confidence": float(np.std(confidences)),
|
||||
"models_used": list(set(models)),
|
||||
"model_distribution": {model: models.count(model) for model in set(models)}
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
def unload_models(self):
|
||||
"""Décharge tous les modèles de la mémoire"""
|
||||
logger.info("Déchargement des modèles de vision...")
|
||||
self._models.clear()
|
||||
self._models_loaded = {k: False for k in self._models_loaded}
|
||||
|
||||
# Forcer le garbage collection
|
||||
import gc
|
||||
gc.collect()
|
||||
|
||||
# Si CUDA disponible, vider le cache
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
logger.info("Modèles déchargés")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""Tests basiques de VisionUtils"""
|
||||
import sys
|
||||
|
||||
print("Test de VisionUtils")
|
||||
print("=" * 50)
|
||||
|
||||
# Initialiser VisionUtils
|
||||
print("\n1. Initialisation de VisionUtils...")
|
||||
vision = VisionUtils()
|
||||
print(f" Modèle principal: {vision.primary_model}")
|
||||
print(f" Ordre de fallback: {vision.fallback_order}")
|
||||
|
||||
# Créer une image de test
|
||||
print("\n2. Création d'une image de test...")
|
||||
test_frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
|
||||
print(f" Taille de l'image: {test_frame.shape}")
|
||||
|
||||
# Test de détection (nécessite les modèles installés)
|
||||
print("\n3. Test de détection...")
|
||||
try:
|
||||
detections = vision.detect("button", test_frame)
|
||||
print(f" Détections trouvées: {len(detections)}")
|
||||
|
||||
if detections:
|
||||
print("\n4. Statistiques des détections:")
|
||||
stats = vision.get_detection_statistics(detections)
|
||||
for key, value in stats.items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
print("\n5. Sélection de la meilleure détection:")
|
||||
best = vision.select_best_detection(detections)
|
||||
if best:
|
||||
print(f" Label: {best.label}")
|
||||
print(f" Confiance: {best.confidence:.2f}")
|
||||
print(f" BBox: {best.bbox}")
|
||||
print(f" Modèle: {best.model_source}")
|
||||
except Exception as e:
|
||||
print(f" Erreur lors de la détection: {e}")
|
||||
print(" (Normal si les modèles ne sont pas installés)")
|
||||
|
||||
# Test de filtrage
|
||||
print("\n6. Test de filtrage de détections...")
|
||||
mock_detections = [
|
||||
Detection(
|
||||
label="button1",
|
||||
confidence=0.95,
|
||||
bbox=(100, 100, 50, 30),
|
||||
embedding=np.random.rand(512),
|
||||
model_source="owl-v2"
|
||||
),
|
||||
Detection(
|
||||
label="button2",
|
||||
confidence=0.25,
|
||||
bbox=(200, 100, 50, 30),
|
||||
embedding=np.random.rand(512),
|
||||
model_source="owl-v2"
|
||||
),
|
||||
Detection(
|
||||
label="button3",
|
||||
confidence=0.75,
|
||||
bbox=(300, 100, 50, 30),
|
||||
embedding=np.random.rand(512),
|
||||
model_source="dino"
|
||||
),
|
||||
]
|
||||
|
||||
filtered = vision.filter_detections(mock_detections, min_confidence=0.5)
|
||||
print(f" Détections avant filtrage: {len(mock_detections)}")
|
||||
print(f" Détections après filtrage: {len(filtered)}")
|
||||
|
||||
# Test de fusion
|
||||
print("\n7. Test de fusion de détections chevauchantes...")
|
||||
overlapping_detections = [
|
||||
Detection(
|
||||
label="button",
|
||||
confidence=0.95,
|
||||
bbox=(100, 100, 50, 30),
|
||||
embedding=np.random.rand(512),
|
||||
model_source="owl-v2"
|
||||
),
|
||||
Detection(
|
||||
label="button",
|
||||
confidence=0.85,
|
||||
bbox=(105, 102, 48, 28), # Légèrement décalé
|
||||
embedding=np.random.rand(512),
|
||||
model_source="dino"
|
||||
),
|
||||
]
|
||||
|
||||
merged = vision.merge_overlapping_detections(overlapping_detections, iou_threshold=0.5)
|
||||
print(f" Détections avant fusion: {len(overlapping_detections)}")
|
||||
print(f" Détections après fusion: {len(merged)}")
|
||||
|
||||
print("\n✓ Tests basiques terminés!")
|
||||
Reference in New Issue
Block a user