799 lines
29 KiB
Python
799 lines
29 KiB
Python
"""
|
|
Utilitaires de vision pour détection d'éléments UI
|
|
Fournit des interfaces vers les modèles de vision (OWL-v2, Grounding DINO, YOLO-World)
|
|
"""
|
|
|
|
import logging
|
|
from typing import List, Dict, Any, Optional, Tuple
|
|
import numpy as np
|
|
from pathlib import Path
|
|
|
|
from ..models import Detection
|
|
from ..config import get_config, get_model_config
|
|
from .image_utils import extract_roi
|
|
|
|
# Configuration du logger
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class VisionUtils:
|
|
"""
|
|
Classe utilitaire pour la détection d'éléments UI avec plusieurs modèles de vision
|
|
Supporte OWL-v2, Grounding DINO et YOLO-World avec fallback automatique
|
|
"""
|
|
|
|
def __init__(self, config: Optional[Dict[str, Any]] = None):
|
|
"""
|
|
Initialise VisionUtils avec les modèles de vision
|
|
|
|
Args:
|
|
config: Configuration optionnelle (utilise CONFIG global si None)
|
|
"""
|
|
self.config = config or get_config()
|
|
self.model_config = get_model_config()
|
|
|
|
# Modèle principal configuré
|
|
self.primary_model = self.model_config.get("vision", "owl-v2")
|
|
|
|
# Ordre de fallback des modèles
|
|
self.fallback_order = ["owl-v2", "dino", "yolo"]
|
|
|
|
# Modèles chargés (lazy loading)
|
|
self._models = {}
|
|
self._models_loaded = {
|
|
"owl-v2": False,
|
|
"dino": False,
|
|
"yolo": False,
|
|
}
|
|
|
|
logger.info(f"VisionUtils initialisé avec modèle principal: {self.primary_model}")
|
|
|
|
def _load_owlv2(self) -> Any:
|
|
"""
|
|
Charge le modèle OWL-v2 (OWLv2 pour détection open-vocabulary)
|
|
|
|
Returns:
|
|
Modèle OWL-v2 chargé
|
|
"""
|
|
try:
|
|
logger.info("Chargement du modèle OWL-v2...")
|
|
|
|
# Import dynamique pour éviter les dépendances si non utilisé
|
|
from transformers import Owlv2Processor, Owlv2ForObjectDetection
|
|
import torch
|
|
|
|
model_path = self.model_config["paths"].get("owl_v2")
|
|
|
|
# Charger le modèle pré-entraîné
|
|
processor = Owlv2Processor.from_pretrained(
|
|
"google/owlv2-base-patch16-ensemble",
|
|
cache_dir=model_path
|
|
)
|
|
model = Owlv2ForObjectDetection.from_pretrained(
|
|
"google/owlv2-base-patch16-ensemble",
|
|
cache_dir=model_path
|
|
)
|
|
|
|
# Déplacer vers GPU si disponible
|
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
model = model.to(device)
|
|
model.eval()
|
|
|
|
self._models["owl-v2"] = {
|
|
"processor": processor,
|
|
"model": model,
|
|
"device": device
|
|
}
|
|
self._models_loaded["owl-v2"] = True
|
|
|
|
logger.info(f"OWL-v2 chargé avec succès sur {device}")
|
|
return self._models["owl-v2"]
|
|
|
|
except Exception as e:
|
|
logger.error(f"Erreur lors du chargement d'OWL-v2: {e}")
|
|
self._models_loaded["owl-v2"] = False
|
|
raise
|
|
|
|
def _load_dino(self) -> Any:
|
|
"""
|
|
Charge le modèle Grounding DINO
|
|
|
|
Returns:
|
|
Modèle Grounding DINO chargé
|
|
"""
|
|
try:
|
|
logger.info("Chargement du modèle Grounding DINO...")
|
|
|
|
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
|
import torch
|
|
|
|
# Charger le modèle Grounding DINO depuis HuggingFace
|
|
model_id = "IDEA-Research/grounding-dino-tiny"
|
|
|
|
processor = AutoProcessor.from_pretrained(model_id)
|
|
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
|
|
|
|
# Déplacer vers GPU si disponible
|
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
model = model.to(device)
|
|
model.eval()
|
|
|
|
self._models["dino"] = {
|
|
"processor": processor,
|
|
"model": model,
|
|
"device": device
|
|
}
|
|
self._models_loaded["dino"] = True
|
|
|
|
logger.info(f"Grounding DINO chargé avec succès sur {device}")
|
|
return self._models["dino"]
|
|
|
|
except Exception as e:
|
|
logger.error(f"Erreur lors du chargement de Grounding DINO: {e}")
|
|
self._models_loaded["dino"] = False
|
|
self._models["dino"] = {"model": None, "loaded": False}
|
|
return self._models["dino"]
|
|
|
|
def _load_yolo(self) -> Any:
|
|
"""
|
|
Charge le modèle YOLO-World
|
|
|
|
Returns:
|
|
Modèle YOLO-World chargé
|
|
"""
|
|
try:
|
|
logger.info("Chargement du modèle YOLO-World...")
|
|
|
|
from ultralytics import YOLOWorld
|
|
|
|
# Charger YOLO-World (modèle pré-entraîné)
|
|
model = YOLOWorld("yolov8s-worldv2.pt")
|
|
|
|
self._models["yolo"] = {
|
|
"model": model
|
|
}
|
|
self._models_loaded["yolo"] = True
|
|
|
|
logger.info("YOLO-World chargé avec succès")
|
|
return self._models["yolo"]
|
|
|
|
except Exception as e:
|
|
logger.error(f"Erreur lors du chargement de YOLO-World: {e}")
|
|
self._models_loaded["yolo"] = False
|
|
self._models["yolo"] = {"model": None, "loaded": False}
|
|
return self._models["yolo"]
|
|
|
|
def _ensure_model_loaded(self, model_name: str) -> bool:
|
|
"""
|
|
S'assure qu'un modèle est chargé
|
|
|
|
Args:
|
|
model_name: Nom du modèle ("owl-v2", "dino", "yolo")
|
|
|
|
Returns:
|
|
True si le modèle est chargé avec succès
|
|
"""
|
|
if self._models_loaded.get(model_name, False):
|
|
return True
|
|
|
|
try:
|
|
if model_name == "owl-v2":
|
|
self._load_owlv2()
|
|
elif model_name == "dino":
|
|
self._load_dino()
|
|
elif model_name == "yolo":
|
|
self._load_yolo()
|
|
else:
|
|
logger.error(f"Modèle inconnu: {model_name}")
|
|
return False
|
|
|
|
return self._models_loaded.get(model_name, False)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Impossible de charger le modèle {model_name}: {e}")
|
|
return False
|
|
|
|
def detect_with_owlv2(self, prompt: str, frame: np.ndarray) -> List[Detection]:
|
|
"""
|
|
Détection d'éléments UI avec OWL-v2
|
|
|
|
Args:
|
|
prompt: Description textuelle de l'élément à détecter
|
|
frame: Image de l'écran (numpy array RGB)
|
|
|
|
Returns:
|
|
Liste de détections trouvées
|
|
"""
|
|
try:
|
|
# S'assurer que le modèle est chargé
|
|
if not self._ensure_model_loaded("owl-v2"):
|
|
logger.error("OWL-v2 n'est pas disponible")
|
|
return []
|
|
|
|
import torch
|
|
from PIL import Image
|
|
|
|
model_data = self._models["owl-v2"]
|
|
processor = model_data["processor"]
|
|
model = model_data["model"]
|
|
device = model_data["device"]
|
|
|
|
# Convertir frame numpy en PIL Image
|
|
if frame.dtype != np.uint8:
|
|
frame = (frame * 255).astype(np.uint8)
|
|
image = Image.fromarray(frame)
|
|
|
|
# Préparer les prompts (OWL-v2 accepte plusieurs prompts)
|
|
texts = [[prompt]]
|
|
|
|
# Traiter l'image et le texte
|
|
inputs = processor(text=texts, images=image, return_tensors="pt")
|
|
inputs = {k: v.to(device) for k, v in inputs.items()}
|
|
|
|
# Inférence
|
|
with torch.no_grad():
|
|
outputs = model(**inputs)
|
|
|
|
# Post-traitement des résultats
|
|
target_sizes = torch.tensor([image.size[::-1]]).to(device)
|
|
results = processor.post_process_object_detection(
|
|
outputs=outputs,
|
|
threshold=0.1, # Seuil bas pour capturer plus de détections
|
|
target_sizes=target_sizes
|
|
)[0]
|
|
|
|
# Convertir en objets Detection
|
|
detections = []
|
|
boxes = results["boxes"].cpu().numpy()
|
|
scores = results["scores"].cpu().numpy()
|
|
labels = results["labels"].cpu().numpy()
|
|
|
|
for box, score, label in zip(boxes, scores, labels):
|
|
# Convertir bbox de [x1, y1, x2, y2] vers [x, y, w, h]
|
|
x1, y1, x2, y2 = box
|
|
x, y = int(x1), int(y1)
|
|
w, h = int(x2 - x1), int(y2 - y1)
|
|
|
|
# Extraire ROI pour embedding
|
|
roi = extract_roi(frame, (x, y, w, h))
|
|
|
|
# Créer embedding simple (sera remplacé par OpenCLIP plus tard)
|
|
embedding = np.random.rand(512) # Placeholder
|
|
|
|
detection = Detection(
|
|
label=prompt,
|
|
confidence=float(score),
|
|
bbox=(x, y, w, h),
|
|
embedding=embedding,
|
|
model_source="owl-v2",
|
|
roi_image=roi,
|
|
metadata={
|
|
"label_id": int(label),
|
|
"raw_box": box.tolist()
|
|
}
|
|
)
|
|
detections.append(detection)
|
|
|
|
logger.info(f"OWL-v2: {len(detections)} détections pour '{prompt}'")
|
|
return detections
|
|
|
|
except Exception as e:
|
|
logger.error(f"Erreur lors de la détection OWL-v2: {e}")
|
|
return []
|
|
|
|
def detect_with_dino(self, prompt: str, frame: np.ndarray) -> List[Detection]:
|
|
"""
|
|
Détection d'éléments UI avec Grounding DINO
|
|
|
|
Args:
|
|
prompt: Description textuelle de l'élément à détecter
|
|
frame: Image de l'écran (numpy array RGB)
|
|
|
|
Returns:
|
|
Liste de détections trouvées
|
|
"""
|
|
try:
|
|
# S'assurer que le modèle est chargé
|
|
if not self._ensure_model_loaded("dino"):
|
|
logger.warning("Grounding DINO n'est pas disponible")
|
|
return []
|
|
|
|
import torch
|
|
from PIL import Image
|
|
|
|
model_data = self._models["dino"]
|
|
if not model_data.get("model"):
|
|
return []
|
|
|
|
processor = model_data["processor"]
|
|
model = model_data["model"]
|
|
device = model_data["device"]
|
|
|
|
# Convertir frame numpy en PIL Image
|
|
if frame.dtype != np.uint8:
|
|
frame = (frame * 255).astype(np.uint8)
|
|
image = Image.fromarray(frame)
|
|
|
|
# Préparer les inputs
|
|
inputs = processor(images=image, text=prompt, return_tensors="pt")
|
|
inputs = {k: v.to(device) for k, v in inputs.items()}
|
|
|
|
# Inférence
|
|
with torch.no_grad():
|
|
outputs = model(**inputs)
|
|
|
|
# Post-traitement
|
|
target_sizes = torch.tensor([image.size[::-1]]).to(device)
|
|
results = processor.post_process_grounded_object_detection(
|
|
outputs=outputs,
|
|
input_ids=inputs["input_ids"],
|
|
threshold=0.3,
|
|
target_sizes=target_sizes
|
|
)[0]
|
|
|
|
# Convertir en objets Detection
|
|
detections = []
|
|
boxes = results["boxes"].cpu().numpy()
|
|
scores = results["scores"].cpu().numpy()
|
|
labels = results["labels"]
|
|
|
|
for box, score, label in zip(boxes, scores, labels):
|
|
x1, y1, x2, y2 = box
|
|
x, y = int(x1), int(y1)
|
|
w, h = int(x2 - x1), int(y2 - y1)
|
|
|
|
roi = extract_roi(frame, (x, y, w, h))
|
|
embedding = np.random.rand(512) # Placeholder
|
|
|
|
detection = Detection(
|
|
label=label,
|
|
confidence=float(score),
|
|
bbox=(x, y, w, h),
|
|
embedding=embedding,
|
|
model_source="dino",
|
|
roi_image=roi,
|
|
metadata={"raw_box": box.tolist()}
|
|
)
|
|
detections.append(detection)
|
|
|
|
logger.info(f"Grounding DINO: {len(detections)} détections pour '{prompt}'")
|
|
return detections
|
|
|
|
except Exception as e:
|
|
logger.error(f"Erreur lors de la détection Grounding DINO: {e}")
|
|
return []
|
|
|
|
def detect_with_yolo(self, prompt: str, frame: np.ndarray) -> List[Detection]:
|
|
"""
|
|
Détection d'éléments UI avec YOLO-World
|
|
|
|
Args:
|
|
prompt: Description textuelle de l'élément à détecter
|
|
frame: Image de l'écran (numpy array RGB)
|
|
|
|
Returns:
|
|
Liste de détections trouvées
|
|
"""
|
|
try:
|
|
# S'assurer que le modèle est chargé
|
|
if not self._ensure_model_loaded("yolo"):
|
|
logger.warning("YOLO-World n'est pas disponible")
|
|
return []
|
|
|
|
model_data = self._models["yolo"]
|
|
if not model_data.get("model"):
|
|
return []
|
|
|
|
model = model_data["model"]
|
|
|
|
# Définir les classes à détecter (YOLO-World accepte des prompts textuels)
|
|
model.set_classes([prompt])
|
|
|
|
# Convertir BGR vers RGB si nécessaire
|
|
if frame.dtype != np.uint8:
|
|
frame = (frame * 255).astype(np.uint8)
|
|
|
|
# Inférence
|
|
results = model.predict(frame, conf=0.1, verbose=False)
|
|
|
|
# Convertir en objets Detection
|
|
detections = []
|
|
for result in results:
|
|
boxes = result.boxes
|
|
for box in boxes:
|
|
# Extraire les coordonnées
|
|
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
|
|
x, y = int(x1), int(y1)
|
|
w, h = int(x2 - x1), int(y2 - y1)
|
|
|
|
# Score de confiance
|
|
confidence = float(box.conf[0])
|
|
|
|
# Classe détectée
|
|
cls_id = int(box.cls[0])
|
|
label = model.names[cls_id] if cls_id < len(model.names) else prompt
|
|
|
|
roi = extract_roi(frame, (x, y, w, h))
|
|
embedding = np.random.rand(512) # Placeholder
|
|
|
|
detection = Detection(
|
|
label=label,
|
|
confidence=confidence,
|
|
bbox=(x, y, w, h),
|
|
embedding=embedding,
|
|
model_source="yolo",
|
|
roi_image=roi,
|
|
metadata={"class_id": cls_id}
|
|
)
|
|
detections.append(detection)
|
|
|
|
logger.info(f"YOLO-World: {len(detections)} détections pour '{prompt}'")
|
|
return detections
|
|
|
|
except Exception as e:
|
|
logger.error(f"Erreur lors de la détection YOLO-World: {e}")
|
|
return []
|
|
|
|
|
|
def detect(self, prompt: str, frame: np.ndarray,
|
|
model: Optional[str] = None) -> List[Detection]:
|
|
"""
|
|
Détection d'éléments UI avec fallback automatique entre modèles
|
|
|
|
Args:
|
|
prompt: Description textuelle de l'élément à détecter
|
|
frame: Image de l'écran (numpy array RGB)
|
|
model: Modèle spécifique à utiliser (None = utiliser le modèle principal)
|
|
|
|
Returns:
|
|
Liste de détections trouvées
|
|
"""
|
|
# Déterminer l'ordre des modèles à essayer
|
|
if model:
|
|
models_to_try = [model] + [m for m in self.fallback_order if m != model]
|
|
else:
|
|
models_to_try = [self.primary_model] + [m for m in self.fallback_order if m != self.primary_model]
|
|
|
|
# Essayer chaque modèle jusqu'à obtenir des détections
|
|
for model_name in models_to_try:
|
|
try:
|
|
logger.info(f"Tentative de détection avec {model_name}...")
|
|
|
|
if model_name == "owl-v2":
|
|
detections = self.detect_with_owlv2(prompt, frame)
|
|
elif model_name == "dino":
|
|
detections = self.detect_with_dino(prompt, frame)
|
|
elif model_name == "yolo":
|
|
detections = self.detect_with_yolo(prompt, frame)
|
|
else:
|
|
logger.warning(f"Modèle inconnu: {model_name}")
|
|
continue
|
|
|
|
# Si des détections sont trouvées, retourner
|
|
if detections:
|
|
logger.info(f"Détection réussie avec {model_name}: {len(detections)} éléments")
|
|
return detections
|
|
else:
|
|
logger.warning(f"Aucune détection avec {model_name}, essai du modèle suivant...")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Erreur avec {model_name}: {e}, essai du modèle suivant...")
|
|
continue
|
|
|
|
# Aucun modèle n'a réussi
|
|
logger.error(f"Aucun modèle n'a pu détecter '{prompt}'")
|
|
return []
|
|
|
|
def select_best_detection(self, detections: List[Detection],
|
|
context: Optional[Dict[str, Any]] = None) -> Optional[Detection]:
|
|
"""
|
|
Sélectionne la meilleure détection parmi une liste
|
|
|
|
Args:
|
|
detections: Liste de détections à évaluer
|
|
context: Contexte additionnel pour la sélection (position précédente, etc.)
|
|
|
|
Returns:
|
|
La meilleure détection ou None si la liste est vide
|
|
"""
|
|
if not detections:
|
|
return None
|
|
|
|
# Si une seule détection, la retourner
|
|
if len(detections) == 1:
|
|
return detections[0]
|
|
|
|
# Stratégie de sélection basée sur plusieurs critères
|
|
best_detection = None
|
|
best_score = -1
|
|
|
|
for detection in detections:
|
|
score = detection.confidence
|
|
|
|
# Bonus pour les détections du modèle principal
|
|
if detection.model_source == self.primary_model:
|
|
score *= 1.1
|
|
|
|
# Si contexte fourni avec position précédente, favoriser les détections proches
|
|
if context and "previous_bbox" in context:
|
|
prev_x, prev_y, prev_w, prev_h = context["previous_bbox"]
|
|
curr_x, curr_y, curr_w, curr_h = detection.bbox
|
|
|
|
# Calculer la distance entre les centres
|
|
prev_center = (prev_x + prev_w / 2, prev_y + prev_h / 2)
|
|
curr_center = (curr_x + curr_w / 2, curr_y + curr_h / 2)
|
|
distance = np.sqrt(
|
|
(prev_center[0] - curr_center[0]) ** 2 +
|
|
(prev_center[1] - curr_center[1]) ** 2
|
|
)
|
|
|
|
# Bonus inversement proportionnel à la distance (max 20% bonus)
|
|
proximity_bonus = max(0, 1 - distance / 500) * 0.2
|
|
score *= (1 + proximity_bonus)
|
|
|
|
# Favoriser les détections avec des bounding boxes de taille raisonnable
|
|
x, y, w, h = detection.bbox
|
|
area = w * h
|
|
if 100 < area < 100000: # Taille raisonnable pour un élément UI
|
|
score *= 1.05
|
|
|
|
if score > best_score:
|
|
best_score = score
|
|
best_detection = detection
|
|
|
|
logger.info(f"Meilleure détection sélectionnée: {best_detection.label} "
|
|
f"(confiance: {best_detection.confidence:.2f}, "
|
|
f"modèle: {best_detection.model_source})")
|
|
|
|
return best_detection
|
|
|
|
def filter_detections(self, detections: List[Detection],
|
|
min_confidence: float = 0.3,
|
|
max_detections: int = 10) -> List[Detection]:
|
|
"""
|
|
Filtre les détections selon des critères de qualité
|
|
|
|
Args:
|
|
detections: Liste de détections à filtrer
|
|
min_confidence: Confiance minimale requise
|
|
max_detections: Nombre maximum de détections à retourner
|
|
|
|
Returns:
|
|
Liste filtrée et triée de détections
|
|
"""
|
|
# Filtrer par confiance minimale
|
|
filtered = [d for d in detections if d.confidence >= min_confidence]
|
|
|
|
# Trier par confiance décroissante
|
|
filtered.sort(key=lambda d: d.confidence, reverse=True)
|
|
|
|
# Limiter le nombre de détections
|
|
filtered = filtered[:max_detections]
|
|
|
|
logger.info(f"Filtrage: {len(detections)} -> {len(filtered)} détections "
|
|
f"(seuil: {min_confidence})")
|
|
|
|
return filtered
|
|
|
|
def merge_overlapping_detections(self, detections: List[Detection],
|
|
iou_threshold: float = 0.5) -> List[Detection]:
|
|
"""
|
|
Fusionne les détections qui se chevauchent (même élément détecté plusieurs fois)
|
|
|
|
Args:
|
|
detections: Liste de détections
|
|
iou_threshold: Seuil d'IoU pour considérer deux détections comme identiques
|
|
|
|
Returns:
|
|
Liste de détections fusionnées
|
|
"""
|
|
if len(detections) <= 1:
|
|
return detections
|
|
|
|
def calculate_iou(box1: Tuple[int, int, int, int],
|
|
box2: Tuple[int, int, int, int]) -> float:
|
|
"""Calcule l'Intersection over Union entre deux bounding boxes"""
|
|
x1, y1, w1, h1 = box1
|
|
x2, y2, w2, h2 = box2
|
|
|
|
# Coordonnées de l'intersection
|
|
xi1 = max(x1, x2)
|
|
yi1 = max(y1, y2)
|
|
xi2 = min(x1 + w1, x2 + w2)
|
|
yi2 = min(y1 + h1, y2 + h2)
|
|
|
|
# Aire de l'intersection
|
|
inter_area = max(0, xi2 - xi1) * max(0, yi2 - yi1)
|
|
|
|
# Aires des deux boxes
|
|
box1_area = w1 * h1
|
|
box2_area = w2 * h2
|
|
|
|
# Union
|
|
union_area = box1_area + box2_area - inter_area
|
|
|
|
# IoU
|
|
return inter_area / union_area if union_area > 0 else 0
|
|
|
|
# Trier par confiance décroissante
|
|
sorted_detections = sorted(detections, key=lambda d: d.confidence, reverse=True)
|
|
|
|
merged = []
|
|
used = set()
|
|
|
|
for i, det1 in enumerate(sorted_detections):
|
|
if i in used:
|
|
continue
|
|
|
|
# Trouver toutes les détections qui se chevauchent avec det1
|
|
overlapping = [det1]
|
|
for j, det2 in enumerate(sorted_detections[i+1:], start=i+1):
|
|
if j in used:
|
|
continue
|
|
|
|
iou = calculate_iou(det1.bbox, det2.bbox)
|
|
if iou >= iou_threshold:
|
|
overlapping.append(det2)
|
|
used.add(j)
|
|
|
|
# Si plusieurs détections se chevauchent, garder celle avec la meilleure confiance
|
|
# (det1 est déjà la meilleure car la liste est triée)
|
|
merged.append(det1)
|
|
used.add(i)
|
|
|
|
logger.info(f"Fusion: {len(detections)} -> {len(merged)} détections "
|
|
f"(seuil IoU: {iou_threshold})")
|
|
|
|
return merged
|
|
|
|
def get_detection_statistics(self, detections: List[Detection]) -> Dict[str, Any]:
|
|
"""
|
|
Calcule des statistiques sur une liste de détections
|
|
|
|
Args:
|
|
detections: Liste de détections
|
|
|
|
Returns:
|
|
Dictionnaire de statistiques
|
|
"""
|
|
if not detections:
|
|
return {
|
|
"count": 0,
|
|
"avg_confidence": 0.0,
|
|
"max_confidence": 0.0,
|
|
"min_confidence": 0.0,
|
|
"models_used": []
|
|
}
|
|
|
|
confidences = [d.confidence for d in detections]
|
|
models = [d.model_source for d in detections]
|
|
|
|
stats = {
|
|
"count": len(detections),
|
|
"avg_confidence": float(np.mean(confidences)),
|
|
"max_confidence": float(np.max(confidences)),
|
|
"min_confidence": float(np.min(confidences)),
|
|
"std_confidence": float(np.std(confidences)),
|
|
"models_used": list(set(models)),
|
|
"model_distribution": {model: models.count(model) for model in set(models)}
|
|
}
|
|
|
|
return stats
|
|
|
|
def unload_models(self):
|
|
"""Décharge tous les modèles de la mémoire"""
|
|
logger.info("Déchargement des modèles de vision...")
|
|
self._models.clear()
|
|
self._models_loaded = {k: False for k in self._models_loaded}
|
|
|
|
# Forcer le garbage collection
|
|
import gc
|
|
gc.collect()
|
|
|
|
# Si CUDA disponible, vider le cache
|
|
try:
|
|
import torch
|
|
if torch.cuda.is_available():
|
|
torch.cuda.empty_cache()
|
|
except ImportError:
|
|
pass
|
|
|
|
logger.info("Modèles déchargés")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
"""Tests basiques de VisionUtils"""
|
|
import sys
|
|
|
|
print("Test de VisionUtils")
|
|
print("=" * 50)
|
|
|
|
# Initialiser VisionUtils
|
|
print("\n1. Initialisation de VisionUtils...")
|
|
vision = VisionUtils()
|
|
print(f" Modèle principal: {vision.primary_model}")
|
|
print(f" Ordre de fallback: {vision.fallback_order}")
|
|
|
|
# Créer une image de test
|
|
print("\n2. Création d'une image de test...")
|
|
test_frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
|
|
print(f" Taille de l'image: {test_frame.shape}")
|
|
|
|
# Test de détection (nécessite les modèles installés)
|
|
print("\n3. Test de détection...")
|
|
try:
|
|
detections = vision.detect("button", test_frame)
|
|
print(f" Détections trouvées: {len(detections)}")
|
|
|
|
if detections:
|
|
print("\n4. Statistiques des détections:")
|
|
stats = vision.get_detection_statistics(detections)
|
|
for key, value in stats.items():
|
|
print(f" {key}: {value}")
|
|
|
|
print("\n5. Sélection de la meilleure détection:")
|
|
best = vision.select_best_detection(detections)
|
|
if best:
|
|
print(f" Label: {best.label}")
|
|
print(f" Confiance: {best.confidence:.2f}")
|
|
print(f" BBox: {best.bbox}")
|
|
print(f" Modèle: {best.model_source}")
|
|
except Exception as e:
|
|
print(f" Erreur lors de la détection: {e}")
|
|
print(" (Normal si les modèles ne sont pas installés)")
|
|
|
|
# Test de filtrage
|
|
print("\n6. Test de filtrage de détections...")
|
|
mock_detections = [
|
|
Detection(
|
|
label="button1",
|
|
confidence=0.95,
|
|
bbox=(100, 100, 50, 30),
|
|
embedding=np.random.rand(512),
|
|
model_source="owl-v2"
|
|
),
|
|
Detection(
|
|
label="button2",
|
|
confidence=0.25,
|
|
bbox=(200, 100, 50, 30),
|
|
embedding=np.random.rand(512),
|
|
model_source="owl-v2"
|
|
),
|
|
Detection(
|
|
label="button3",
|
|
confidence=0.75,
|
|
bbox=(300, 100, 50, 30),
|
|
embedding=np.random.rand(512),
|
|
model_source="dino"
|
|
),
|
|
]
|
|
|
|
filtered = vision.filter_detections(mock_detections, min_confidence=0.5)
|
|
print(f" Détections avant filtrage: {len(mock_detections)}")
|
|
print(f" Détections après filtrage: {len(filtered)}")
|
|
|
|
# Test de fusion
|
|
print("\n7. Test de fusion de détections chevauchantes...")
|
|
overlapping_detections = [
|
|
Detection(
|
|
label="button",
|
|
confidence=0.95,
|
|
bbox=(100, 100, 50, 30),
|
|
embedding=np.random.rand(512),
|
|
model_source="owl-v2"
|
|
),
|
|
Detection(
|
|
label="button",
|
|
confidence=0.85,
|
|
bbox=(105, 102, 48, 28), # Légèrement décalé
|
|
embedding=np.random.rand(512),
|
|
model_source="dino"
|
|
),
|
|
]
|
|
|
|
merged = vision.merge_overlapping_detections(overlapping_detections, iou_threshold=0.5)
|
|
print(f" Détections avant fusion: {len(overlapping_detections)}")
|
|
print(f" Détections après fusion: {len(merged)}")
|
|
|
|
print("\n✓ Tests basiques terminés!")
|