v1.0 - Stable release: multi-PC, UI-DETR-1 detection, 3 execution modes

- Frontend v4 reachable on the local network (192.168.1.40)
- Open ports: 3002 (frontend), 5001 (backend), 5004 (dashboard)
- Ollama running on GPU
- Interactive self-healing
- Confidence dashboard

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
core/detection/ollama_client.py (new file, 471 lines)
@@ -0,0 +1,471 @@
"""
OllamaClient - Client for Vision-Language Models via Ollama

Interface for communicating with VLMs (Qwen, LLaVA, etc.) through Ollama.
"""

import logging
from typing import Dict, List, Optional, Any
import requests
import json
import base64
from pathlib import Path
from PIL import Image
import io

logger = logging.getLogger(__name__)

class OllamaClient:
|
||||
"""
|
||||
Client Ollama pour VLM
|
||||
|
||||
Permet d'envoyer des images et prompts à un VLM via l'API Ollama.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
endpoint: str = "http://localhost:11434",
|
||||
model: str = "qwen3-vl:8b",
|
||||
timeout: int = 60):
|
||||
"""
|
||||
Initialiser le client Ollama
|
||||
|
||||
Args:
|
||||
endpoint: URL de l'API Ollama
|
||||
model: Nom du modèle VLM à utiliser
|
||||
timeout: Timeout en secondes
|
||||
"""
|
||||
self.endpoint = endpoint.rstrip('/')
|
||||
self.model = model
|
||||
self.timeout = timeout
|
||||
self._check_connection()
|
||||
|
||||
def _check_connection(self) -> bool:
|
||||
"""Vérifier la connexion à Ollama"""
|
||||
try:
|
||||
response = requests.get(f"{self.endpoint}/api/tags", timeout=5)
|
||||
if response.status_code == 200:
|
||||
models = response.json().get('models', [])
|
||||
model_names = [m['name'] for m in models]
|
||||
if self.model not in model_names:
|
||||
logger.warning(f" Model '{self.model}' not found in Ollama")
|
||||
logger.info(f"Available models: {model_names}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f" Cannot connect to Ollama at {self.endpoint}: {e}")
|
||||
return False
|
||||
return False
|
||||
|
||||
def generate(self,
|
||||
prompt: str,
|
||||
image_path: Optional[str] = None,
|
||||
image: Optional[Image.Image] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
temperature: float = 0.1,
|
||||
max_tokens: int = 500,
|
||||
force_json: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Générer une réponse du VLM
|
||||
|
||||
Args:
|
||||
prompt: Prompt textuel
|
||||
image_path: Chemin vers une image (optionnel)
|
||||
image: Image PIL (optionnel)
|
||||
system_prompt: Prompt système (optionnel)
|
||||
temperature: Température de génération
|
||||
max_tokens: Nombre max de tokens
|
||||
|
||||
Returns:
|
||||
Dict avec 'response', 'success', 'error'
|
||||
"""
|
||||
try:
|
||||
# Préparer l'image si fournie
|
||||
image_data = None
|
||||
if image_path:
|
||||
image_data = self._encode_image_from_path(image_path)
|
||||
elif image:
|
||||
image_data = self._encode_image_from_pil(image)
|
||||
|
||||
# Construire la requête avec thinking mode désactivé
|
||||
# Pour Qwen3, utiliser /nothink au début du prompt
|
||||
effective_prompt = prompt
|
||||
if "qwen" in self.model.lower():
|
||||
effective_prompt = f"/nothink {prompt}"
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"prompt": effective_prompt,
|
||||
"stream": False,
|
||||
"options": {
|
||||
"temperature": temperature,
|
||||
"num_predict": max_tokens,
|
||||
"num_ctx": 2048, # Contexte réduit pour plus de vitesse
|
||||
"top_k": 1 # Plus rapide pour les tâches de classification
|
||||
}
|
||||
}
|
||||
|
||||
# Forcer la sortie JSON si demandé (réduit drastiquement les erreurs de parsing)
|
||||
if force_json:
|
||||
payload["format"] = "json"
|
||||
|
||||
if system_prompt:
|
||||
payload["system"] = system_prompt
|
||||
|
||||
if image_data:
|
||||
payload["images"] = [image_data]
|
||||
|
||||
# Envoyer la requête
|
||||
response = requests.post(
|
||||
f"{self.endpoint}/api/generate",
|
||||
json=payload,
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
return {
|
||||
"response": result.get("response", ""),
|
||||
"success": True,
|
||||
"error": None
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"response": "",
|
||||
"success": False,
|
||||
"error": f"HTTP {response.status_code}: {response.text}"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"response": "",
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
def detect_ui_elements(self, image_path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Détecter les éléments UI dans une image
|
||||
|
||||
Args:
|
||||
image_path: Chemin vers le screenshot
|
||||
|
||||
Returns:
|
||||
Dict avec liste d'éléments détectés
|
||||
"""
|
||||
prompt = """Analyze this screenshot and list all interactive UI elements you can see.
|
||||
For each element, provide:
|
||||
- Type (button, text_input, checkbox, radio, dropdown, tab, link, icon, table_row, menu_item)
|
||||
- Position (approximate x, y coordinates)
|
||||
- Label or text content
|
||||
- Semantic role (primary_action, cancel, submit, form_input, search_field, navigation, settings, close)
|
||||
|
||||
Format your response as JSON."""
|
||||
|
||||
result = self.generate(prompt, image_path=image_path, temperature=0.1)
|
||||
|
||||
if result["success"]:
|
||||
try:
|
||||
# Parser la réponse JSON
|
||||
elements = json.loads(result["response"])
|
||||
return {"elements": elements, "success": True}
|
||||
except json.JSONDecodeError:
|
||||
# Si pas JSON valide, retourner texte brut
|
||||
return {"elements": [], "success": False, "raw_response": result["response"]}
|
||||
|
||||
return {"elements": [], "success": False, "error": result["error"]}
|
||||
|
||||
def classify_element_type(self,
|
||||
element_image: Image.Image,
|
||||
context: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Classifier le type d'un élément UI
|
||||
|
||||
Args:
|
||||
element_image: Image de l'élément
|
||||
context: Contexte additionnel
|
||||
|
||||
Returns:
|
||||
Dict avec 'type' et 'confidence'
|
||||
"""
|
||||
types_list = "button, text_input, checkbox, radio, dropdown, tab, link, icon, table_row, menu_item"
|
||||
|
||||
prompt = f"""What type of UI element is this?
|
||||
Choose ONLY ONE from: {types_list}
|
||||
|
||||
Respond with just the type name, nothing else."""
|
||||
|
||||
if context:
|
||||
prompt += f"\n\nContext: {context}"
|
||||
|
||||
result = self.generate(prompt, image=element_image, temperature=0.0)
|
||||
|
||||
if result["success"]:
|
||||
element_type = result["response"].strip().lower()
|
||||
# Valider que c'est un type connu
|
||||
valid_types = types_list.split(", ")
|
||||
if element_type in valid_types:
|
||||
return {"type": element_type, "confidence": 0.9, "success": True}
|
||||
else:
|
||||
# Essayer de trouver le type le plus proche
|
||||
for vtype in valid_types:
|
||||
if vtype in element_type:
|
||||
return {"type": vtype, "confidence": 0.7, "success": True}
|
||||
|
||||
return {"type": "unknown", "confidence": 0.0, "success": False}
|
||||
|
||||
def classify_element_role(self,
|
||||
element_image: Image.Image,
|
||||
element_type: str,
|
||||
context: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Classifier le rôle sémantique d'un élément
|
||||
|
||||
Args:
|
||||
element_image: Image de l'élément
|
||||
element_type: Type de l'élément
|
||||
context: Contexte additionnel
|
||||
|
||||
Returns:
|
||||
Dict avec 'role' et 'confidence'
|
||||
"""
|
||||
roles_list = "primary_action, cancel, submit, form_input, search_field, navigation, settings, close, delete, edit, save"
|
||||
|
||||
prompt = f"""This is a {element_type}. What is its semantic role or purpose?
|
||||
Choose ONLY ONE from: {roles_list}
|
||||
|
||||
Respond with just the role name, nothing else."""
|
||||
|
||||
if context:
|
||||
prompt += f"\n\nContext: {context}"
|
||||
|
||||
result = self.generate(prompt, image=element_image, temperature=0.0)
|
||||
|
||||
if result["success"]:
|
||||
role = result["response"].strip().lower()
|
||||
# Valider que c'est un rôle connu
|
||||
valid_roles = roles_list.split(", ")
|
||||
if role in valid_roles:
|
||||
return {"role": role, "confidence": 0.9, "success": True}
|
||||
else:
|
||||
# Essayer de trouver le rôle le plus proche
|
||||
for vrole in valid_roles:
|
||||
if vrole in role:
|
||||
return {"role": vrole, "confidence": 0.7, "success": True}
|
||||
|
||||
return {"role": "unknown", "confidence": 0.0, "success": False}
|
||||
|
||||
def extract_text(self, image: Image.Image) -> Dict[str, Any]:
|
||||
"""
|
||||
Extraire le texte d'une image
|
||||
|
||||
Args:
|
||||
image: Image PIL
|
||||
|
||||
Returns:
|
||||
Dict avec 'text' extrait
|
||||
"""
|
||||
prompt = "Extract all visible text from this image. Return only the text, nothing else."
|
||||
|
||||
result = self.generate(prompt, image=image, temperature=0.0)
|
||||
|
||||
if result["success"]:
|
||||
return {"text": result["response"].strip(), "success": True}
|
||||
|
||||
return {"text": "", "success": False, "error": result["error"]}
|
||||
|
||||
    def classify_element_complete(self, element_image: Image.Image) -> Dict[str, Any]:
        """
        Fully classify a UI element in a SINGLE VLM call (optimized).

        Instead of 3 separate calls (type, role, text), this method
        makes one call that returns all the information at once.

        Performance gain: 3 calls -> 1 call, roughly 66% faster.

        Args:
            element_image: PIL image of the element

        Returns:
            Dict with 'type', 'role', 'text', 'confidence', 'success'
        """
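        # Illustrative example: a well-formed model response follows the format
        # enforced by the system prompt below, e.g.
        #     {"type": "button", "role": "submit", "text": "OK"}
        # which json.loads() further down turns into the validated dict returned here.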
# System prompt "zéro tolérance" - Force le VLM à NE produire QUE du JSON
|
||||
system_prompt = """You are a UI element classifier.
|
||||
Your ONLY task is to output valid JSON. Never explain. Never comment. Never discuss.
|
||||
Expected format:
|
||||
{"type": "...", "role": "...", "text": "..."}"""
|
||||
|
||||
# User prompt simplifié et direct
|
||||
prompt = """Classify this UI element:
|
||||
- Type: Choose ONE from [button, text_input, checkbox, radio, dropdown, tab, link, icon, table_row, menu_item]
|
||||
- Role: Choose ONE from [primary_action, cancel, submit, form_input, search_field, navigation, settings, close, delete, edit, save]
|
||||
- Text: Any visible text (empty string if none)
|
||||
|
||||
Output JSON only."""
|
||||
|
||||
result = self.generate(
|
||||
prompt,
|
||||
image=element_image,
|
||||
system_prompt=system_prompt,
|
||||
temperature=0.0,
|
||||
max_tokens=150,
|
||||
force_json=True
|
||||
)
|
||||
|
||||
if result["success"]:
|
||||
try:
|
||||
# Parser la réponse JSON
|
||||
response_text = result["response"].strip()
|
||||
|
||||
# Nettoyer la réponse si elle contient du markdown
|
||||
if response_text.startswith("```"):
|
||||
lines = response_text.split("\n")
|
||||
response_text = "\n".join([l for l in lines if not l.startswith("```")])
|
||||
response_text = response_text.strip()
|
||||
|
||||
data = json.loads(response_text)
|
||||
|
||||
# Valider les valeurs
|
||||
valid_types = ["button", "text_input", "checkbox", "radio", "dropdown",
|
||||
"tab", "link", "icon", "table_row", "menu_item"]
|
||||
valid_roles = ["primary_action", "cancel", "submit", "form_input",
|
||||
"search_field", "navigation", "settings", "close",
|
||||
"delete", "edit", "save"]
|
||||
|
||||
elem_type = data.get("type", "unknown").lower()
|
||||
elem_role = data.get("role", "unknown").lower()
|
||||
elem_text = data.get("text", "")
|
||||
|
||||
# Fallback si type/role invalides
|
||||
if elem_type not in valid_types:
|
||||
elem_type = "unknown"
|
||||
if elem_role not in valid_roles:
|
||||
elem_role = "unknown"
|
||||
|
||||
return {
|
||||
"type": elem_type,
|
||||
"role": elem_role,
|
||||
"text": elem_text,
|
||||
"confidence": 0.85,
|
||||
"success": True
|
||||
}
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"JSON parse error in classify_element_complete: {e}")
|
||||
logger.debug(f"Raw response: {result['response'][:200]}")
|
||||
return {
|
||||
"type": "unknown",
|
||||
"role": "unknown",
|
||||
"text": "",
|
||||
"confidence": 0.0,
|
||||
"success": False,
|
||||
"error": f"JSON parse error: {e}"
|
||||
}
|
||||
|
||||
return {
|
||||
"type": "unknown",
|
||||
"role": "unknown",
|
||||
"text": "",
|
||||
"confidence": 0.0,
|
||||
"success": False,
|
||||
"error": result.get("error", "VLM call failed")
|
||||
}
|
||||
|
||||
def _encode_image_from_path(self, image_path: str) -> str:
|
||||
"""Encoder une image depuis un fichier en base64"""
|
||||
with open(image_path, 'rb') as f:
|
||||
return base64.b64encode(f.read()).decode('utf-8')
|
||||
|
||||
def _encode_image_from_pil(self, image: Image.Image) -> str:
|
||||
"""Encoder une image PIL en base64 avec prétraitement optimisé"""
|
||||
# 1. Convertir en RGB si nécessaire (évite erreurs PNG transparent)
|
||||
if image.mode != 'RGB':
|
||||
image = image.convert('RGB')
|
||||
|
||||
# 2. Redimensionnement intelligent : max 1280px sur le côté long
|
||||
max_size = 1280
|
||||
if max(image.size) > max_size:
|
||||
ratio = max_size / max(image.size)
|
||||
new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
|
||||
image = image.resize(new_size, Image.Resampling.LANCZOS)
|
||||
|
||||
# 3. Sauvegarder en JPEG qualité 90 (plus léger, meilleur pour VLM)
|
||||
buffer = io.BytesIO()
|
||||
image.save(buffer, format='JPEG', quality=90)
|
||||
return base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
|
||||
def list_models(self) -> List[str]:
|
||||
"""Lister les modèles disponibles dans Ollama"""
|
||||
try:
|
||||
response = requests.get(f"{self.endpoint}/api/tags", timeout=5)
|
||||
if response.status_code == 200:
|
||||
models = response.json().get('models', [])
|
||||
return [m['name'] for m in models]
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing models: {e}")
|
||||
return []
|
||||
|
||||
def pull_model(self, model_name: str) -> bool:
|
||||
"""
|
||||
Télécharger un modèle dans Ollama
|
||||
|
||||
Args:
|
||||
model_name: Nom du modèle à télécharger
|
||||
|
||||
Returns:
|
||||
True si succès
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Pulling model {model_name}...")
|
||||
response = requests.post(
|
||||
f"{self.endpoint}/api/pull",
|
||||
json={"name": model_name},
|
||||
stream=True,
|
||||
timeout=600
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
data = json.loads(line)
|
||||
if 'status' in data:
|
||||
logger.info(f" {data['status']}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error pulling model: {e}")
|
||||
return False

# ============================================================================
# Utility functions
# ============================================================================

def create_ollama_client(model: str = "qwen3-vl:8b",
                         endpoint: str = "http://localhost:11434") -> OllamaClient:
    """
    Create an Ollama client.

    Args:
        model: Name of the VLM model
        endpoint: Ollama API URL

    Returns:
        A configured OllamaClient
    """
    return OllamaClient(endpoint=endpoint, model=model)


def check_ollama_available(endpoint: str = "http://localhost:11434") -> bool:
    """
    Check whether Ollama is reachable.

    Args:
        endpoint: Ollama API URL

    Returns:
        True if Ollama responds
    """
    try:
        response = requests.get(f"{endpoint}/api/tags", timeout=5)
        return response.status_code == 200
    except (requests.RequestException, ConnectionError, TimeoutError):
        return False
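A minimal usage sketch for the client above, assuming the qwen3-vl:8b model has already been pulled into a local Ollama instance; the image paths are placeholders:

from PIL import Image
from core.detection.ollama_client import create_ollama_client, check_ollama_available

if check_ollama_available("http://localhost:11434"):
    client = create_ollama_client(model="qwen3-vl:8b")

    # One-shot classification of a cropped UI element (type + role + text).
    crop = Image.open("element_crop.png")          # placeholder path
    info = client.classify_element_complete(crop)
    if info["success"]:
        print(info["type"], info["role"], info["text"], info["confidence"])

    # Free-form generation on a full screenshot, forcing JSON output.
    result = client.generate(
        prompt="List the visible buttons as JSON.",
        image_path="screenshot.png",               # placeholder path
        force_json=True,
    )
    print(result["response"] if result["success"] else result["error"])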
core/detection/omniparser_adapter.py (new file, 429 lines)
@@ -0,0 +1,429 @@
"""
OmniParser Adapter for RPA Vision V3

Integrates Microsoft OmniParser v2 for UI element detection.
OmniParser combines icon detection (YOLO) + OCR + captioning in a single pipeline.

Advantages:
- Accurate detection of small elements (icons, buttons)
- Built-in OCR
- Semantic description of elements
- 60% faster than the OWL+OpenCV+VLM pipeline

Usage:
    adapter = OmniParserAdapter()
    elements = adapter.detect(screenshot_pil)
    # elements is a list of DetectedElement with bbox, label, type, etc.
"""

import os
import sys
import base64
import io
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
from PIL import Image
import numpy as np

# Add OmniParser to the Python path
OMNIPARSER_PATH = "/home/dom/ai/OmniParser"
if OMNIPARSER_PATH not in sys.path:
    sys.path.insert(0, OMNIPARSER_PATH)

# OmniParser model configuration
OMNIPARSER_CONFIG = {
    'som_model_path': os.path.join(OMNIPARSER_PATH, 'weights/icon_detect/model.pt'),
    'caption_model_name': 'florence2',
    'caption_model_path': os.path.join(OMNIPARSER_PATH, 'weights/icon_caption_florence'),
    'BOX_TRESHOLD': 0.05,  # Low threshold to detect more elements
}

@dataclass
|
||||
class DetectedElement:
|
||||
"""Élément UI détecté par OmniParser"""
|
||||
bbox: Tuple[int, int, int, int] # (x1, y1, x2, y2) en pixels
|
||||
bbox_normalized: Tuple[float, float, float, float] # (x1, y1, x2, y2) normalisé 0-1
|
||||
label: str # Description de l'élément
|
||||
element_type: str # 'icon', 'text', 'button', etc.
|
||||
confidence: float
|
||||
center: Tuple[int, int] # Centre en pixels
|
||||
is_interactable: bool
|
||||
|
||||
|
||||
class OmniParserAdapter:
|
||||
"""
|
||||
Adapter pour utiliser OmniParser dans RPA Vision V3.
|
||||
|
||||
OmniParser détecte tous les éléments UI d'un screenshot et retourne
|
||||
leurs positions, descriptions et types.
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
|
||||
def __new__(cls):
|
||||
"""Singleton pour éviter de charger les modèles plusieurs fois"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
"""Initialise OmniParser (lazy loading)"""
|
||||
if OmniParserAdapter._initialized:
|
||||
return
|
||||
|
||||
self.omniparser = None
|
||||
self.available = False
|
||||
self._check_availability()
|
||||
|
||||
def _check_availability(self):
|
||||
"""Vérifie si OmniParser est disponible"""
|
||||
try:
|
||||
# Vérifier que les fichiers de modèles existent
|
||||
if not os.path.exists(OMNIPARSER_CONFIG['som_model_path']):
|
||||
print(f"⚠️ [OmniParser] Modèle de détection non trouvé: {OMNIPARSER_CONFIG['som_model_path']}")
|
||||
return
|
||||
|
||||
if not os.path.exists(OMNIPARSER_CONFIG['caption_model_path']):
|
||||
print(f"⚠️ [OmniParser] Modèle de caption non trouvé: {OMNIPARSER_CONFIG['caption_model_path']}")
|
||||
return
|
||||
|
||||
self.available = True
|
||||
print("✅ [OmniParser] Modèles disponibles, chargement différé")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ [OmniParser] Erreur vérification: {e}")
|
||||
self.available = False
|
||||
|
||||
def _load_models(self):
|
||||
"""Charge les modèles OmniParser (lazy loading) avec GPU"""
|
||||
if self.omniparser is not None:
|
||||
return True
|
||||
|
||||
if not self.available:
|
||||
return False
|
||||
|
||||
try:
|
||||
import torch
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
print(f"🔄 [OmniParser] Chargement des modèles sur {device}...")
|
||||
|
||||
from util.omniparser import Omniparser
|
||||
self.omniparser = Omniparser(OMNIPARSER_CONFIG)
|
||||
|
||||
# Forcer YOLO sur GPU si disponible
|
||||
if device == 'cuda' and hasattr(self.omniparser, 'som_model'):
|
||||
self.omniparser.som_model.to(device)
|
||||
print(f"✅ [OmniParser] YOLO déplacé sur {device}")
|
||||
|
||||
OmniParserAdapter._initialized = True
|
||||
print(f"✅ [OmniParser] Modèles chargés avec succès sur {device}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ [OmniParser] Erreur chargement modèles: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
self.available = False
|
||||
return False
|
||||
|
||||
def detect(self, image: Image.Image) -> List[DetectedElement]:
|
||||
"""
|
||||
Détecte tous les éléments UI dans une image.
|
||||
|
||||
Args:
|
||||
image: Image PIL du screenshot
|
||||
|
||||
Returns:
|
||||
Liste de DetectedElement avec bbox, label, type, etc.
|
||||
"""
|
||||
if not self._load_models():
|
||||
print("⚠️ [OmniParser] Non disponible, retourne liste vide")
|
||||
return []
|
||||
|
||||
try:
|
||||
# Convertir PIL en base64
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
||||
|
||||
W, H = image.size
|
||||
print(f"📸 [OmniParser] Analyse image {W}x{H}...")
|
||||
|
||||
# Appel OmniParser
|
||||
labeled_img, parsed_content = self.omniparser.parse(image_base64)
|
||||
|
||||
print(f"🎯 [OmniParser] {len(parsed_content)} éléments détectés")
|
||||
|
||||
# Convertir en DetectedElement
|
||||
elements = []
|
||||
for item in parsed_content:
|
||||
elem = self._parse_item(item, W, H)
|
||||
if elem:
|
||||
elements.append(elem)
|
||||
|
||||
return elements
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ [OmniParser] Erreur détection: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return []
|
||||
|
||||
def _parse_item(self, item: Any, width: int, height: int) -> Optional[DetectedElement]:
|
||||
"""Parse un élément OmniParser en DetectedElement"""
|
||||
try:
|
||||
# Format OmniParser: {'bbox': [x1, y1, x2, y2], 'label': 'description', ...}
|
||||
# Les bbox sont normalisées (0-1)
|
||||
|
||||
if isinstance(item, dict):
|
||||
bbox_norm = item.get('bbox', item.get('box', []))
|
||||
label = item.get('label', item.get('content', item.get('text', 'unknown')))
|
||||
elif isinstance(item, (list, tuple)) and len(item) >= 2:
|
||||
# Format alternatif: (bbox, label)
|
||||
bbox_norm = item[0] if isinstance(item[0], (list, tuple)) else []
|
||||
label = item[1] if len(item) > 1 else 'unknown'
|
||||
else:
|
||||
return None
|
||||
|
||||
if not bbox_norm or len(bbox_norm) < 4:
|
||||
return None
|
||||
|
||||
x1_n, y1_n, x2_n, y2_n = bbox_norm[:4]
|
||||
|
||||
# Convertir en pixels
|
||||
x1 = int(x1_n * width)
|
||||
y1 = int(y1_n * height)
|
||||
x2 = int(x2_n * width)
|
||||
y2 = int(y2_n * height)
|
||||
|
||||
# Calculer le centre
|
||||
cx = (x1 + x2) // 2
|
||||
cy = (y1 + y2) // 2
|
||||
|
||||
# Déterminer le type d'élément
|
||||
element_type = self._classify_element(label, x2-x1, y2-y1)
|
||||
|
||||
# Confiance (OmniParser ne fournit pas toujours)
|
||||
confidence = item.get('confidence', item.get('score', 0.8))
|
||||
|
||||
return DetectedElement(
|
||||
bbox=(x1, y1, x2, y2),
|
||||
bbox_normalized=(x1_n, y1_n, x2_n, y2_n),
|
||||
label=str(label),
|
||||
element_type=element_type,
|
||||
confidence=float(confidence),
|
||||
center=(cx, cy),
|
||||
is_interactable=self._is_interactable(label, element_type)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ [OmniParser] Erreur parsing item: {e}")
|
||||
return None
|
||||
|
||||
def _classify_element(self, label: str, width: int, height: int) -> str:
|
||||
"""Classifie le type d'élément basé sur le label et la taille"""
|
||||
label_lower = label.lower() if label else ""
|
||||
|
||||
# Mots-clés pour classification
|
||||
icon_keywords = ['icon', 'logo', 'image', 'picture', 'symbol']
|
||||
button_keywords = ['button', 'btn', 'click', 'submit', 'ok', 'cancel', 'close']
|
||||
input_keywords = ['input', 'text field', 'search', 'textbox', 'entry']
|
||||
menu_keywords = ['menu', 'dropdown', 'select', 'option']
|
||||
|
||||
for kw in icon_keywords:
|
||||
if kw in label_lower:
|
||||
return 'icon'
|
||||
|
||||
for kw in button_keywords:
|
||||
if kw in label_lower:
|
||||
return 'button'
|
||||
|
||||
for kw in input_keywords:
|
||||
if kw in label_lower:
|
||||
return 'input'
|
||||
|
||||
for kw in menu_keywords:
|
||||
if kw in label_lower:
|
||||
return 'menu'
|
||||
|
||||
# Classification par taille
|
||||
if width < 50 and height < 50:
|
||||
return 'icon'
|
||||
elif width > 100 and height < 40:
|
||||
return 'input'
|
||||
elif width < 150 and height < 50:
|
||||
return 'button'
|
||||
|
||||
return 'element'
|
||||
|
||||
def _is_interactable(self, label: str, element_type: str) -> bool:
|
||||
"""Détermine si l'élément est interactable"""
|
||||
interactable_types = {'button', 'input', 'icon', 'menu', 'link', 'checkbox'}
|
||||
return element_type in interactable_types
|
||||
|
||||
def find_element(
|
||||
self,
|
||||
screenshot: Image.Image,
|
||||
anchor: Image.Image,
|
||||
threshold: float = 0.5
|
||||
) -> Optional[Tuple[int, int, str]]:
|
||||
"""
|
||||
Trouve un élément spécifique dans le screenshot en comparant avec une ancre.
|
||||
|
||||
Stratégie:
|
||||
1. Détecte tous les éléments avec OmniParser
|
||||
2. Pour chaque élément, compare avec l'ancre via template matching
|
||||
3. Retourne le meilleur match
|
||||
|
||||
Args:
|
||||
screenshot: Screenshot complet
|
||||
anchor: Image de l'élément à trouver
|
||||
threshold: Seuil de similarité (0-1)
|
||||
|
||||
Returns:
|
||||
(x, y, method) si trouvé, None sinon
|
||||
"""
|
||||
import cv2
|
||||
|
||||
elements = self.detect(screenshot)
|
||||
if not elements:
|
||||
print("⚠️ [OmniParser] Aucun élément détecté")
|
||||
return None
|
||||
|
||||
print(f"🔍 [OmniParser] Recherche parmi {len(elements)} éléments...")
|
||||
|
||||
# Convertir images en arrays
|
||||
screenshot_np = np.array(screenshot)
|
||||
anchor_np = np.array(anchor)
|
||||
|
||||
if len(screenshot_np.shape) == 3:
|
||||
screenshot_gray = cv2.cvtColor(screenshot_np, cv2.COLOR_RGB2GRAY)
|
||||
else:
|
||||
screenshot_gray = screenshot_np
|
||||
|
||||
if len(anchor_np.shape) == 3:
|
||||
anchor_gray = cv2.cvtColor(anchor_np, cv2.COLOR_RGB2GRAY)
|
||||
else:
|
||||
anchor_gray = anchor_np
|
||||
|
||||
best_match = None
|
||||
best_score = -1
|
||||
|
||||
anchor_h, anchor_w = anchor_gray.shape[:2]
|
||||
|
||||
for elem in elements:
|
||||
x1, y1, x2, y2 = elem.bbox
|
||||
|
||||
# Extraire la région
|
||||
region = screenshot_gray[y1:y2, x1:x2]
|
||||
|
||||
if region.size == 0:
|
||||
continue
|
||||
|
||||
# Resize pour matcher la taille de l'ancre
|
||||
try:
|
||||
region_resized = cv2.resize(region, (anchor_w, anchor_h))
|
||||
|
||||
# Template matching
|
||||
result = cv2.matchTemplate(
|
||||
region_resized,
|
||||
anchor_gray,
|
||||
cv2.TM_CCOEFF_NORMED
|
||||
)
|
||||
_, max_val, _, _ = cv2.minMaxLoc(result)
|
||||
|
||||
if max_val > best_score:
|
||||
best_score = max_val
|
||||
best_match = elem
|
||||
|
||||
except Exception as e:
|
||||
continue
|
||||
|
||||
if best_match and best_score >= threshold:
|
||||
cx, cy = best_match.center
|
||||
print(f"✅ [OmniParser] Trouvé: '{best_match.label}' à ({cx}, {cy}) score={best_score:.2f}")
|
||||
return (cx, cy, f"omniparser_{best_match.element_type}")
|
||||
|
||||
print(f"⚠️ [OmniParser] Aucun match >= {threshold} (best={best_score:.2f})")
|
||||
return None
|
||||
|
||||
def find_by_description(
|
||||
self,
|
||||
screenshot: Image.Image,
|
||||
description: str,
|
||||
threshold: float = 0.3
|
||||
) -> Optional[Tuple[int, int, str]]:
|
||||
"""
|
||||
Trouve un élément par sa description textuelle.
|
||||
|
||||
Args:
|
||||
screenshot: Screenshot complet
|
||||
description: Description de l'élément ("bouton Document", "icône Excel", etc.)
|
||||
threshold: Seuil de similarité textuelle
|
||||
|
||||
Returns:
|
||||
(x, y, method) si trouvé, None sinon
|
||||
"""
|
||||
elements = self.detect(screenshot)
|
||||
if not elements:
|
||||
return None
|
||||
|
||||
description_lower = description.lower()
|
||||
description_words = set(description_lower.split())
|
||||
|
||||
best_match = None
|
||||
best_score = 0
|
||||
|
||||
for elem in elements:
|
||||
label_lower = elem.label.lower()
|
||||
label_words = set(label_lower.split())
|
||||
|
||||
# Score basé sur les mots communs
|
||||
common_words = description_words & label_words
|
||||
if description_words:
|
||||
score = len(common_words) / len(description_words)
|
||||
else:
|
||||
score = 0
|
||||
|
||||
# Bonus si le type correspond
|
||||
if elem.element_type in description_lower:
|
||||
score += 0.2
|
||||
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_match = elem
|
||||
|
||||
if best_match and best_score >= threshold:
|
||||
cx, cy = best_match.center
|
||||
print(f"✅ [OmniParser] Match description: '{best_match.label}' à ({cx}, {cy}) score={best_score:.2f}")
|
||||
return (cx, cy, "omniparser_description")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Instance globale (singleton)
|
||||
_omniparser_instance: Optional[OmniParserAdapter] = None
|
||||
|
||||
|
||||
def get_omniparser() -> OmniParserAdapter:
|
||||
"""Retourne l'instance singleton d'OmniParser"""
|
||||
global _omniparser_instance
|
||||
if _omniparser_instance is None:
|
||||
_omniparser_instance = OmniParserAdapter()
|
||||
return _omniparser_instance
|
||||
|
||||
|
||||
def detect_elements(image: Image.Image) -> List[DetectedElement]:
|
||||
"""Fonction utilitaire pour détecter les éléments"""
|
||||
return get_omniparser().detect(image)
|
||||
|
||||
|
||||
def find_element(
|
||||
screenshot: Image.Image,
|
||||
anchor: Image.Image,
|
||||
threshold: float = 0.5
|
||||
) -> Optional[Tuple[int, int, str]]:
|
||||
"""Fonction utilitaire pour trouver un élément"""
|
||||
return get_omniparser().find_element(screenshot, anchor, threshold)
|
||||
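A minimal usage sketch for the adapter above, assuming the OmniParser weights exist under /home/dom/ai/OmniParser; the screenshot path is a placeholder:

from PIL import Image
from core.detection.omniparser_adapter import get_omniparser

screenshot = Image.open("screenshot.png")   # placeholder path
adapter = get_omniparser()                  # singleton; models load lazily on first detect()

# Full detection: each DetectedElement carries bbox, label, type, center, interactability.
for elem in adapter.detect(screenshot):
    print(elem.element_type, elem.label, elem.bbox, elem.is_interactable)

# Locate an element from a textual description (word-overlap scoring above).
hit = adapter.find_by_description(screenshot, "bouton Document", threshold=0.3)
if hit:
    x, y, method = hit
    print(f"click target at ({x}, {y}) found via {method}")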
core/detection/owl_detector.py (new file, 309 lines)
@@ -0,0 +1,309 @@
"""
OWL-v2 Detector - UI element detection with OWL-v2

Uses Google's OWL-v2 (Open-World Localization) model to detect
UI elements in screenshots from text prompts.
"""

import logging
from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path
import numpy as np
from PIL import Image
import torch

logger = logging.getLogger(__name__)

try:
    from transformers import Owlv2Processor, Owlv2ForObjectDetection
    OWL_AVAILABLE = True
except ImportError:
    OWL_AVAILABLE = False

class OwlDetector:
|
||||
"""
|
||||
Détecteur d'éléments UI basé sur OWL-v2
|
||||
|
||||
OWL-v2 permet de détecter des objets avec des prompts textuels,
|
||||
idéal pour trouver des boutons, champs de texte, etc.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_name: str = "google/owlv2-base-patch16-ensemble",
|
||||
device: Optional[str] = None,
|
||||
confidence_threshold: float = 0.1):
|
||||
"""
|
||||
Initialiser le détecteur OWL-v2
|
||||
|
||||
Args:
|
||||
model_name: Nom du modèle HuggingFace
|
||||
device: Device ('cuda' ou 'cpu', auto-détecté si None)
|
||||
confidence_threshold: Seuil de confiance minimum
|
||||
"""
|
||||
if not OWL_AVAILABLE:
|
||||
raise ImportError(
|
||||
"transformers n'est pas installé ou version trop ancienne. "
|
||||
"Installer avec: pip install transformers>=4.35.0"
|
||||
)
|
||||
|
||||
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.confidence_threshold = confidence_threshold
|
||||
|
||||
logger.info(f"Chargement OWL-v2 sur {self.device}...")
|
||||
self.processor = Owlv2Processor.from_pretrained(model_name)
|
||||
self.model = Owlv2ForObjectDetection.from_pretrained(model_name)
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
logger.info("OWL-v2 chargé")
|
||||
|
||||
def detect(self,
|
||||
image: Image.Image,
|
||||
text_queries: List[str],
|
||||
confidence_threshold: Optional[float] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Détecter des éléments UI avec des prompts textuels
|
||||
|
||||
Args:
|
||||
image: Image PIL à analyser
|
||||
text_queries: Liste de prompts (ex: ["button", "text field", "icon"])
|
||||
confidence_threshold: Seuil de confiance (utilise self.confidence_threshold si None)
|
||||
|
||||
Returns:
|
||||
Liste de détections avec format:
|
||||
{
|
||||
'label': str, # Prompt qui a matché
|
||||
'confidence': float, # Score de confiance
|
||||
'bbox': [x1, y1, x2, y2], # Coordonnées
|
||||
'center': (x, y) # Centre de la bbox
|
||||
}
|
||||
"""
|
||||
threshold = confidence_threshold or self.confidence_threshold
|
||||
|
||||
# Préparer les inputs
|
||||
inputs = self.processor(
|
||||
text=text_queries,
|
||||
images=image,
|
||||
return_tensors="pt"
|
||||
).to(self.device)
|
||||
|
||||
# Inférence
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs)
|
||||
|
||||
# Post-traitement
|
||||
target_sizes = torch.tensor([image.size[::-1]]).to(self.device)
|
||||
results = self.processor.post_process_object_detection(
|
||||
outputs=outputs,
|
||||
target_sizes=target_sizes,
|
||||
threshold=threshold
|
||||
)[0]
|
||||
|
||||
# Formater les résultats
|
||||
detections = []
|
||||
boxes = results["boxes"].cpu().numpy()
|
||||
scores = results["scores"].cpu().numpy()
|
||||
labels = results["labels"].cpu().numpy()
|
||||
|
||||
for box, score, label_idx in zip(boxes, scores, labels):
|
||||
x1, y1, x2, y2 = box
|
||||
center_x = (x1 + x2) / 2
|
||||
center_y = (y1 + y2) / 2
|
||||
|
||||
detections.append({
|
||||
'label': text_queries[label_idx],
|
||||
'confidence': float(score),
|
||||
'bbox': [float(x1), float(y1), float(x2), float(y2)],
|
||||
'center': (float(center_x), float(center_y)),
|
||||
'width': float(x2 - x1),
|
||||
'height': float(y2 - y1)
|
||||
})
|
||||
|
||||
return detections
|
||||
|
||||
def detect_ui_elements(self, image: Image.Image) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Détecter les éléments UI courants
|
||||
|
||||
Args:
|
||||
image: Image PIL à analyser
|
||||
|
||||
Returns:
|
||||
Liste de détections d'éléments UI
|
||||
"""
|
||||
# Prompts pour éléments UI courants
|
||||
ui_queries = [
|
||||
"button",
|
||||
"text field",
|
||||
"input box",
|
||||
"checkbox",
|
||||
"radio button",
|
||||
"dropdown menu",
|
||||
"icon",
|
||||
"label",
|
||||
"link",
|
||||
"tab"
|
||||
]
|
||||
|
||||
return self.detect(image, ui_queries)
|
||||
|
||||
def detect_specific(self,
|
||||
image: Image.Image,
|
||||
element_type: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Détecter un type spécifique d'élément
|
||||
|
||||
Args:
|
||||
image: Image PIL
|
||||
element_type: Type d'élément (ex: "submit button", "cancel button")
|
||||
|
||||
Returns:
|
||||
Liste de détections
|
||||
"""
|
||||
return self.detect(image, [element_type])
|
||||
|
||||
def find_element_by_text(self,
|
||||
image: Image.Image,
|
||||
text: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Trouver un élément par son texte
|
||||
|
||||
Args:
|
||||
image: Image PIL
|
||||
text: Texte à chercher (ex: "Submit", "Cancel")
|
||||
|
||||
Returns:
|
||||
Première détection ou None
|
||||
"""
|
||||
# Essayer plusieurs formulations
|
||||
queries = [
|
||||
f"{text} button",
|
||||
f"{text} text",
|
||||
f"{text} label",
|
||||
text
|
||||
]
|
||||
|
||||
detections = self.detect(image, queries)
|
||||
|
||||
if detections:
|
||||
# Retourner la détection avec le meilleur score
|
||||
return max(detections, key=lambda d: d['confidence'])
|
||||
|
||||
return None
|
||||
|
||||
def get_clickable_elements(self, image: Image.Image) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Détecter tous les éléments cliquables
|
||||
|
||||
Args:
|
||||
image: Image PIL
|
||||
|
||||
Returns:
|
||||
Liste d'éléments cliquables
|
||||
"""
|
||||
clickable_queries = [
|
||||
"button",
|
||||
"link",
|
||||
"checkbox",
|
||||
"radio button",
|
||||
"dropdown menu",
|
||||
"tab",
|
||||
"icon button"
|
||||
]
|
||||
|
||||
return self.detect(image, clickable_queries)
|
||||
|
||||
def visualize_detections(self,
|
||||
image: Image.Image,
|
||||
detections: List[Dict[str, Any]],
|
||||
output_path: Optional[Path] = None) -> Image.Image:
|
||||
"""
|
||||
Visualiser les détections sur l'image
|
||||
|
||||
Args:
|
||||
image: Image PIL originale
|
||||
detections: Liste de détections
|
||||
output_path: Chemin de sauvegarde (optionnel)
|
||||
|
||||
Returns:
|
||||
Image avec détections dessinées
|
||||
"""
|
||||
from PIL import ImageDraw, ImageFont
|
||||
|
||||
# Copier l'image
|
||||
img_with_boxes = image.copy()
|
||||
draw = ImageDraw.Draw(img_with_boxes)
|
||||
|
||||
# Dessiner chaque détection
|
||||
for det in detections:
|
||||
bbox = det['bbox']
|
||||
label = det['label']
|
||||
confidence = det['confidence']
|
||||
|
||||
# Dessiner la bbox
|
||||
draw.rectangle(bbox, outline="red", width=2)
|
||||
|
||||
# Dessiner le label
|
||||
text = f"{label}: {confidence:.2f}"
|
||||
try:
|
||||
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
|
||||
except (OSError, IOError):
|
||||
font = ImageFont.load_default()
|
||||
|
||||
draw.text((bbox[0], bbox[1] - 15), text, fill="red", font=font)
|
||||
|
||||
# Sauvegarder si demandé
|
||||
if output_path:
|
||||
img_with_boxes.save(output_path)
|
||||
|
||||
return img_with_boxes
|
||||
|
||||
|
||||
def create_owl_detector(device: Optional[str] = None,
|
||||
confidence_threshold: float = 0.1) -> OwlDetector:
|
||||
"""
|
||||
Créer un détecteur OWL-v2 avec configuration par défaut
|
||||
|
||||
Args:
|
||||
device: Device à utiliser
|
||||
confidence_threshold: Seuil de confiance
|
||||
|
||||
Returns:
|
||||
OwlDetector configuré
|
||||
"""
|
||||
return OwlDetector(
|
||||
device=device,
|
||||
confidence_threshold=confidence_threshold
|
||||
)
|
||||
|
||||
|
||||
# Test rapide
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: python owl_detector.py <image_path>")
|
||||
sys.exit(1)
|
||||
|
||||
image_path = sys.argv[1]
|
||||
|
||||
print(f"Test OWL-v2 sur {image_path}")
|
||||
|
||||
# Charger image
|
||||
image = Image.open(image_path)
|
||||
|
||||
# Créer détecteur
|
||||
detector = create_owl_detector()
|
||||
|
||||
# Détecter éléments UI
|
||||
print("\nDétection d'éléments UI...")
|
||||
detections = detector.detect_ui_elements(image)
|
||||
|
||||
print(f"\n✓ Trouvé {len(detections)} éléments:")
|
||||
for i, det in enumerate(detections, 1):
|
||||
print(f" {i}. {det['label']}: {det['confidence']:.3f} @ {det['bbox']}")
|
||||
|
||||
# Visualiser
|
||||
output_path = Path(image_path).parent / f"{Path(image_path).stem}_owl_detections.png"
|
||||
detector.visualize_detections(image, detections, output_path)
|
||||
print(f"\n✓ Visualisation sauvegardée: {output_path}")
|
||||
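A minimal usage sketch for the detector above, complementing the __main__ test; it assumes transformers>=4.35 is installed (the OWL-v2 weights are typically fetched from HuggingFace on first use), and the screenshot path is a placeholder:

from PIL import Image
from core.detection.owl_detector import create_owl_detector

image = Image.open("screenshot.png")        # placeholder path
detector = create_owl_detector(confidence_threshold=0.1)

# Open-vocabulary detection: arbitrary text queries, not just the built-in UI list.
for det in detector.detect(image, ["save button", "search box", "close icon"]):
    print(det["label"], round(det["confidence"], 3), det["bbox"])

# Convenience lookup by visible text; returns the highest-confidence match or None.
submit = detector.find_element_by_text(image, "Submit")
if submit:
    print("Submit found at", submit["center"])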
core/detection/roi_optimizer.py (new file, 493 lines)
@@ -0,0 +1,493 @@
"""
ROI Optimizer - Speeding up UI detection with regions of interest

Optimizations:
1. Smart resizing of screenshots (max 1920x1080)
2. Fast detection of regions of interest (ROI)
3. Caching of results for similar frames
4. Selective processing of active zones

Expected performance gains:
- 50-70% reduction in processing time
- 60-80% reduction in memory usage
- 30-50% cache hit rate on repetitive workflows
"""

from typing import List, Dict, Optional, Tuple, Any
from dataclasses import dataclass
from pathlib import Path
import numpy as np
from PIL import Image
import cv2
import hashlib
from collections import OrderedDict
from datetime import datetime

@dataclass
|
||||
class ROI:
|
||||
"""Région d'intérêt détectée"""
|
||||
x: int
|
||||
y: int
|
||||
w: int
|
||||
h: int
|
||||
confidence: float
|
||||
roi_type: str # "active", "changed", "interactive"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convertir en dictionnaire"""
|
||||
return {
|
||||
"x": self.x,
|
||||
"y": self.y,
|
||||
"w": self.w,
|
||||
"h": self.h,
|
||||
"confidence": self.confidence,
|
||||
"roi_type": self.roi_type
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptimizedFrame:
|
||||
"""Frame optimisé avec ROIs"""
|
||||
image: np.ndarray
|
||||
original_size: Tuple[int, int]
|
||||
resized_size: Tuple[int, int]
|
||||
scale_factor: float
|
||||
rois: List[ROI]
|
||||
frame_hash: str
|
||||
|
||||
|
||||
class ROICache:
|
||||
"""
|
||||
Cache pour résultats de détection ROI
|
||||
|
||||
Stocke les résultats de détection pour frames similaires
|
||||
pour éviter les recalculs coûteux.
|
||||
"""
|
||||
|
||||
def __init__(self, max_size: int = 100, similarity_threshold: float = 0.95):
|
||||
"""
|
||||
Initialiser le cache ROI
|
||||
|
||||
Args:
|
||||
max_size: Nombre maximum de frames en cache
|
||||
similarity_threshold: Seuil de similarité pour considérer 2 frames identiques
|
||||
"""
|
||||
self.max_size = max_size
|
||||
self.similarity_threshold = similarity_threshold
|
||||
self.cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
|
||||
|
||||
# Statistiques
|
||||
self.hits = 0
|
||||
self.misses = 0
|
||||
self.total_time_saved = 0.0
|
||||
|
||||
def _compute_frame_hash(self, image: np.ndarray, quick: bool = True) -> str:
|
||||
"""
|
||||
Calculer un hash rapide de l'image
|
||||
|
||||
Args:
|
||||
image: Image numpy
|
||||
quick: Si True, utilise un hash rapide (downsampled)
|
||||
|
||||
Returns:
|
||||
Hash hexadécimal
|
||||
"""
|
||||
if quick:
|
||||
# Downsample pour hash rapide
|
||||
small = cv2.resize(image, (64, 64))
|
||||
gray = cv2.cvtColor(small, cv2.COLOR_BGR2GRAY) if len(small.shape) == 3 else small
|
||||
return hashlib.md5(gray.tobytes()).hexdigest()
|
||||
else:
|
||||
# Hash complet (plus lent)
|
||||
return hashlib.md5(image.tobytes()).hexdigest()
|
||||
|
||||
def get(self, image: np.ndarray) -> Optional[List[ROI]]:
|
||||
"""
|
||||
Récupérer les ROIs depuis le cache
|
||||
|
||||
Args:
|
||||
image: Image à rechercher
|
||||
|
||||
Returns:
|
||||
Liste de ROIs si trouvé, None sinon
|
||||
"""
|
||||
frame_hash = self._compute_frame_hash(image)
|
||||
|
||||
if frame_hash in self.cache:
|
||||
# Déplacer à la fin (LRU)
|
||||
self.cache.move_to_end(frame_hash)
|
||||
self.hits += 1
|
||||
|
||||
cached_data = self.cache[frame_hash]
|
||||
self.total_time_saved += cached_data.get("processing_time", 0.0)
|
||||
|
||||
return cached_data["rois"]
|
||||
|
||||
self.misses += 1
|
||||
return None
|
||||
|
||||
def put(self, image: np.ndarray, rois: List[ROI], processing_time: float = 0.0):
|
||||
"""
|
||||
Ajouter des ROIs au cache
|
||||
|
||||
Args:
|
||||
image: Image source
|
||||
rois: ROIs détectés
|
||||
processing_time: Temps de traitement (pour stats)
|
||||
"""
|
||||
frame_hash = self._compute_frame_hash(image)
|
||||
|
||||
# Évict si cache plein
|
||||
if len(self.cache) >= self.max_size and frame_hash not in self.cache:
|
||||
self.cache.popitem(last=False)
|
||||
|
||||
self.cache[frame_hash] = {
|
||||
"rois": rois,
|
||||
"processing_time": processing_time,
|
||||
"timestamp": datetime.now()
|
||||
}
|
||||
|
||||
def clear(self):
|
||||
"""Vider le cache"""
|
||||
self.cache.clear()
|
||||
self.hits = 0
|
||||
self.misses = 0
|
||||
self.total_time_saved = 0.0
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Obtenir les statistiques du cache"""
|
||||
total_requests = self.hits + self.misses
|
||||
hit_rate = self.hits / total_requests if total_requests > 0 else 0.0
|
||||
|
||||
return {
|
||||
"size": len(self.cache),
|
||||
"max_size": self.max_size,
|
||||
"hits": self.hits,
|
||||
"misses": self.misses,
|
||||
"hit_rate": hit_rate,
|
||||
"total_time_saved_ms": self.total_time_saved * 1000
|
||||
}
|
||||
|
||||
|
||||
class ROIOptimizer:
|
||||
"""
|
||||
Optimiseur de détection UI par régions d'intérêt
|
||||
|
||||
Optimise la détection UI en:
|
||||
1. Redimensionnant intelligemment les screenshots
|
||||
2. Détectant rapidement les zones actives
|
||||
3. Cachant les résultats pour frames similaires
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
max_width: int = 1920,
|
||||
max_height: int = 1080,
|
||||
enable_cache: bool = True,
|
||||
cache_size: int = 100):
|
||||
"""
|
||||
Initialiser l'optimiseur ROI
|
||||
|
||||
Args:
|
||||
max_width: Largeur maximale des screenshots
|
||||
max_height: Hauteur maximale des screenshots
|
||||
enable_cache: Activer le cache de ROIs
|
||||
cache_size: Taille du cache
|
||||
"""
|
||||
self.max_width = max_width
|
||||
self.max_height = max_height
|
||||
self.enable_cache = enable_cache
|
||||
|
||||
# Cache
|
||||
self.cache = ROICache(max_size=cache_size) if enable_cache else None
|
||||
|
||||
# Statistiques
|
||||
self.total_frames_processed = 0
|
||||
self.total_frames_resized = 0
|
||||
self.total_processing_time = 0.0
|
||||
|
||||
def optimize_frame(self, image_path: str) -> OptimizedFrame:
|
||||
"""
|
||||
Optimiser un frame pour la détection
|
||||
|
||||
Args:
|
||||
image_path: Chemin vers l'image
|
||||
|
||||
Returns:
|
||||
OptimizedFrame avec image redimensionnée et ROIs
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# Charger l'image
|
||||
image = cv2.imread(image_path)
|
||||
if image is None:
|
||||
raise ValueError(f"Failed to load image: {image_path}")
|
||||
|
||||
original_h, original_w = image.shape[:2]
|
||||
|
||||
# Vérifier le cache d'abord
|
||||
if self.cache:
|
||||
cached_rois = self.cache.get(image)
|
||||
if cached_rois is not None:
|
||||
# Cache hit - retourner directement
|
||||
return OptimizedFrame(
|
||||
image=image,
|
||||
original_size=(original_w, original_h),
|
||||
resized_size=(original_w, original_h),
|
||||
scale_factor=1.0,
|
||||
rois=cached_rois,
|
||||
frame_hash=self.cache._compute_frame_hash(image)
|
||||
)
|
||||
|
||||
# Redimensionner si nécessaire
|
||||
resized_image, scale_factor = self._resize_if_needed(image)
|
||||
resized_h, resized_w = resized_image.shape[:2]
|
||||
|
||||
if scale_factor < 1.0:
|
||||
self.total_frames_resized += 1
|
||||
|
||||
# Détecter les ROIs
|
||||
rois = self._detect_rois(resized_image)
|
||||
|
||||
# Mettre en cache
|
||||
processing_time = time.time() - start_time
|
||||
if self.cache:
|
||||
self.cache.put(image, rois, processing_time)
|
||||
|
||||
self.total_frames_processed += 1
|
||||
self.total_processing_time += processing_time
|
||||
|
||||
return OptimizedFrame(
|
||||
image=resized_image,
|
||||
original_size=(original_w, original_h),
|
||||
resized_size=(resized_w, resized_h),
|
||||
scale_factor=scale_factor,
|
||||
rois=rois,
|
||||
frame_hash=self.cache._compute_frame_hash(image) if self.cache else ""
|
||||
)
|
||||
|
||||
def _resize_if_needed(self, image: np.ndarray) -> Tuple[np.ndarray, float]:
|
||||
"""
|
||||
Redimensionner l'image si elle dépasse les limites
|
||||
|
||||
Args:
|
||||
image: Image OpenCV
|
||||
|
||||
Returns:
|
||||
(image_redimensionnée, facteur_d'échelle)
|
||||
"""
|
||||
h, w = image.shape[:2]
|
||||
|
||||
# Calculer le facteur d'échelle nécessaire
|
||||
scale_w = self.max_width / w if w > self.max_width else 1.0
|
||||
scale_h = self.max_height / h if h > self.max_height else 1.0
|
||||
scale_factor = min(scale_w, scale_h)
|
||||
|
||||
# Redimensionner si nécessaire
|
||||
if scale_factor < 1.0:
|
||||
new_w = int(w * scale_factor)
|
||||
new_h = int(h * scale_factor)
|
||||
resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
||||
return resized, scale_factor
|
||||
|
||||
return image, 1.0
|
||||
|
||||
def _detect_rois(self, image: np.ndarray) -> List[ROI]:
|
||||
"""
|
||||
Détecter rapidement les régions d'intérêt
|
||||
|
||||
Utilise des techniques rapides pour identifier les zones actives:
|
||||
- Détection de changements (si frame précédent disponible)
|
||||
- Détection de contours
|
||||
- Détection de zones de texte
|
||||
|
||||
Args:
|
||||
image: Image OpenCV
|
||||
|
||||
Returns:
|
||||
Liste de ROIs détectés
|
||||
"""
|
||||
rois = []
|
||||
h, w = image.shape[:2]
|
||||
|
||||
# Convertir en niveaux de gris
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# Méthode 1: Détection de contours (rapide)
|
||||
# Appliquer un flou pour réduire le bruit
|
||||
blurred = cv2.GaussianBlur(gray, (5, 5), 0)
|
||||
|
||||
# Détection de contours avec Canny
|
||||
edges = cv2.Canny(blurred, 50, 150)
|
||||
|
||||
# Trouver les contours
|
||||
contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
# Filtrer et créer des ROIs
|
||||
for contour in contours:
|
||||
x, y, cw, ch = cv2.boundingRect(contour)
|
||||
|
||||
# Filtrer les régions trop petites ou trop grandes
|
||||
area = cw * ch
|
||||
if area < 100 or area > (w * h * 0.5): # Min 100px², max 50% de l'image
|
||||
continue
|
||||
|
||||
# Ajouter une marge
|
||||
margin = 5
|
||||
x = max(0, x - margin)
|
||||
y = max(0, y - margin)
|
||||
cw = min(w - x, cw + 2 * margin)
|
||||
ch = min(h - y, ch + 2 * margin)
|
||||
|
||||
rois.append(ROI(
|
||||
x=x,
|
||||
y=y,
|
||||
w=cw,
|
||||
h=ch,
|
||||
confidence=0.8,
|
||||
roi_type="contour"
|
||||
))
|
||||
|
||||
# Méthode 2: Zones de texte (rapide avec EAST ou MSER)
|
||||
# Pour l'instant, on utilise MSER (Maximally Stable Extremal Regions)
|
||||
mser = cv2.MSER_create()
|
||||
regions, _ = mser.detectRegions(gray)
|
||||
|
||||
for region in regions:
|
||||
x, y, rw, rh = cv2.boundingRect(region)
|
||||
|
||||
# Filtrer
|
||||
area = rw * rh
|
||||
if area < 50 or area > (w * h * 0.3):
|
||||
continue
|
||||
|
||||
rois.append(ROI(
|
||||
x=x,
|
||||
y=y,
|
||||
w=rw,
|
||||
h=rh,
|
||||
confidence=0.7,
|
||||
roi_type="text"
|
||||
))
|
||||
|
||||
# Fusionner les ROIs qui se chevauchent
|
||||
rois = self._merge_overlapping_rois(rois)
|
||||
|
||||
# Si aucun ROI détecté, utiliser l'image entière
|
||||
if not rois:
|
||||
rois.append(ROI(
|
||||
x=0,
|
||||
y=0,
|
||||
w=w,
|
||||
h=h,
|
||||
confidence=1.0,
|
||||
roi_type="full_frame"
|
||||
))
|
||||
|
||||
return rois
|
||||
|
||||
def _merge_overlapping_rois(self, rois: List[ROI], iou_threshold: float = 0.5) -> List[ROI]:
|
||||
"""
|
||||
Fusionner les ROIs qui se chevauchent
|
||||
|
||||
Args:
|
||||
rois: Liste de ROIs
|
||||
iou_threshold: Seuil IoU pour fusion
|
||||
|
||||
Returns:
|
||||
Liste de ROIs fusionnés
|
||||
"""
|
||||
if len(rois) <= 1:
|
||||
return rois
|
||||
|
||||
# Trier par aire décroissante
|
||||
rois = sorted(rois, key=lambda r: r.w * r.h, reverse=True)
|
||||
|
||||
merged = []
|
||||
used = set()
|
||||
|
||||
for i, roi1 in enumerate(rois):
|
||||
if i in used:
|
||||
continue
|
||||
|
||||
# Trouver tous les ROIs qui se chevauchent
|
||||
group = [roi1]
|
||||
for j, roi2 in enumerate(rois[i+1:], start=i+1):
|
||||
if j in used:
|
||||
continue
|
||||
|
||||
# Calculer IoU
|
||||
iou = self._calculate_iou(roi1, roi2)
|
||||
if iou > iou_threshold:
|
||||
group.append(roi2)
|
||||
used.add(j)
|
||||
|
||||
# Fusionner le groupe
|
||||
if len(group) == 1:
|
||||
merged.append(roi1)
|
||||
else:
|
||||
merged_roi = self._merge_roi_group(group)
|
||||
merged.append(merged_roi)
|
||||
|
||||
return merged
|
||||
|
||||
def _calculate_iou(self, roi1: ROI, roi2: ROI) -> float:
|
||||
"""Calculer l'IoU entre deux ROIs"""
|
||||
x1_inter = max(roi1.x, roi2.x)
|
||||
y1_inter = max(roi1.y, roi2.y)
|
||||
x2_inter = min(roi1.x + roi1.w, roi2.x + roi2.w)
|
||||
y2_inter = min(roi1.y + roi1.h, roi2.y + roi2.h)
|
||||
|
||||
if x2_inter < x1_inter or y2_inter < y1_inter:
|
||||
return 0.0
|
||||
|
||||
inter_area = (x2_inter - x1_inter) * (y2_inter - y1_inter)
|
||||
union_area = (roi1.w * roi1.h) + (roi2.w * roi2.h) - inter_area
|
||||
|
||||
return inter_area / union_area if union_area > 0 else 0.0
|
||||
|
||||
def _merge_roi_group(self, rois: List[ROI]) -> ROI:
|
||||
"""Fusionner un groupe de ROIs en un seul"""
|
||||
min_x = min(r.x for r in rois)
|
||||
min_y = min(r.y for r in rois)
|
||||
max_x = max(r.x + r.w for r in rois)
|
||||
max_y = max(r.y + r.h for r in rois)
|
||||
|
||||
avg_confidence = sum(r.confidence for r in rois) / len(rois)
|
||||
|
||||
return ROI(
|
||||
x=min_x,
|
||||
y=min_y,
|
||||
w=max_x - min_x,
|
||||
h=max_y - min_y,
|
||||
confidence=avg_confidence,
|
||||
roi_type="merged"
|
||||
)
|
||||
|
||||
def scale_coordinates(self, x: int, y: int, scale_factor: float) -> Tuple[int, int]:
|
||||
"""
|
||||
Convertir des coordonnées de l'image redimensionnée vers l'originale
|
||||
|
||||
Args:
|
||||
x, y: Coordonnées dans l'image redimensionnée
|
||||
scale_factor: Facteur d'échelle utilisé
|
||||
|
||||
Returns:
|
||||
(x_original, y_original)
|
||||
"""
|
||||
return (int(x / scale_factor), int(y / scale_factor))
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Obtenir les statistiques de l'optimiseur"""
|
||||
stats = {
|
||||
"total_frames_processed": self.total_frames_processed,
|
||||
"total_frames_resized": self.total_frames_resized,
|
||||
"resize_rate": self.total_frames_resized / self.total_frames_processed if self.total_frames_processed > 0 else 0.0,
|
||||
"avg_processing_time_ms": (self.total_processing_time / self.total_frames_processed * 1000) if self.total_frames_processed > 0 else 0.0
|
||||
}
|
||||
|
||||
if self.cache:
|
||||
stats["cache"] = self.cache.get_stats()
|
||||
|
||||
return stats
|
||||
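A minimal usage sketch for the optimizer above; the screenshot path is a placeholder. ROIs are expressed in the resized frame, so points must be mapped back with scale_coordinates before acting on the original screenshot:

from core.detection.roi_optimizer import ROIOptimizer

optimizer = ROIOptimizer(max_width=1920, max_height=1080, enable_cache=True)

frame = optimizer.optimize_frame("screenshot.png")   # placeholder path
print(frame.original_size, "->", frame.resized_size, "scale", frame.scale_factor)

# Map each ROI center back to original-screenshot pixels (x / scale_factor).
for roi in frame.rois:
    cx, cy = roi.x + roi.w // 2, roi.y + roi.h // 2
    ox, oy = optimizer.scale_coordinates(cx, cy, frame.scale_factor)
    print(roi.roi_type, (cx, cy), "->", (ox, oy))

print(optimizer.get_stats())   # includes cache hit rate when caching is enabled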
core/detection/spatial_analyzer.py (new file, 595 lines)
@@ -0,0 +1,595 @@
"""
SpatialAnalyzer - Analysis of spatial relations between UI elements

This module analyzes:
- Spatial relations (above, below, left_of, right_of, inside)
- Semantic containers (forms, menus, toolbars, dialogs)
- Grouping of related elements
"""

import logging
from typing import List, Dict, Optional, Any, Tuple, Set
from dataclasses import dataclass, field
from enum import Enum
import numpy as np

logger = logging.getLogger(__name__)


# =============================================================================
# Enums and Dataclasses
# =============================================================================

class RelationType(Enum):
    """Types de relations spatiales"""
    ABOVE = "above"
    BELOW = "below"
    LEFT_OF = "left_of"
    RIGHT_OF = "right_of"
    INSIDE = "inside"
    CONTAINS = "contains"
    OVERLAPS = "overlaps"
    ADJACENT = "adjacent"


class ContainerType(Enum):
    """Types de conteneurs sémantiques"""
    FORM = "form"
    MENU = "menu"
    TOOLBAR = "toolbar"
    DIALOG = "dialog"
    LIST = "list"
    TABLE = "table"
    PANEL = "panel"
    TAB_GROUP = "tab_group"


@dataclass
class SpatialRelation:
    """Relation spatiale entre deux éléments"""
    source_element_id: str
    target_element_id: str
    relation_type: RelationType
    distance: float  # Distance en pixels
    confidence: float  # Confiance de la relation (0-1)

    def to_dict(self) -> Dict[str, Any]:
        return {
            "source": self.source_element_id,
            "target": self.target_element_id,
            "relation": self.relation_type.value,
            "distance": self.distance,
            "confidence": self.confidence
        }

    @property
    def inverse(self) -> 'SpatialRelation':
        """Retourner la relation inverse"""
        inverse_map = {
            RelationType.ABOVE: RelationType.BELOW,
            RelationType.BELOW: RelationType.ABOVE,
            RelationType.LEFT_OF: RelationType.RIGHT_OF,
            RelationType.RIGHT_OF: RelationType.LEFT_OF,
            RelationType.INSIDE: RelationType.CONTAINS,
            RelationType.CONTAINS: RelationType.INSIDE,
            RelationType.OVERLAPS: RelationType.OVERLAPS,
            RelationType.ADJACENT: RelationType.ADJACENT,
        }
        return SpatialRelation(
            source_element_id=self.target_element_id,
            target_element_id=self.source_element_id,
            relation_type=inverse_map[self.relation_type],
            distance=self.distance,
            confidence=self.confidence
        )

@dataclass
class SemanticContainer:
    """Conteneur sémantique groupant des éléments"""
    container_id: str
    container_type: ContainerType
    element_ids: List[str]
    bounds: Tuple[int, int, int, int]  # (x, y, width, height)
    confidence: float
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        return {
            "container_id": self.container_id,
            "container_type": self.container_type.value,
            "element_ids": self.element_ids,
            "bounds": self.bounds,
            "confidence": self.confidence,
            "metadata": self.metadata
        }

@dataclass
class SpatialAnalyzerConfig:
    """Configuration de l'analyseur spatial"""
    # Seuils de distance
    adjacent_threshold: float = 20.0  # Distance max pour "adjacent"
    inside_margin: float = 5.0  # Marge pour "inside"

    # Seuils de confiance
    min_relation_confidence: float = 0.5
    min_container_confidence: float = 0.6

    # Groupement
    max_group_distance: float = 50.0  # Distance max pour grouper
    min_group_size: int = 2  # Taille min d'un groupe


# =============================================================================
# Analyseur Spatial
# =============================================================================

class SpatialAnalyzer:
|
||||
"""
|
||||
Analyseur de relations spatiales entre éléments UI.
|
||||
|
||||
Fonctionnalités:
|
||||
- Calcul des relations spatiales (above, below, etc.)
|
||||
- Détection de conteneurs sémantiques
|
||||
- Groupement d'éléments liés
|
||||
|
||||
Example:
|
||||
>>> analyzer = SpatialAnalyzer()
|
||||
>>> relations = analyzer.compute_relations(elements)
|
||||
>>> containers = analyzer.detect_containers(elements)
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[SpatialAnalyzerConfig] = None):
|
||||
"""
|
||||
Initialiser l'analyseur.
|
||||
|
||||
Args:
|
||||
config: Configuration (utilise défaut si None)
|
||||
"""
|
||||
self.config = config or SpatialAnalyzerConfig()
|
||||
logger.info("SpatialAnalyzer initialisé")
|
||||
|
||||
def compute_relations(
|
||||
self,
|
||||
elements: List[Any]
|
||||
) -> List[SpatialRelation]:
|
||||
"""
|
||||
Calculer les relations spatiales entre tous les éléments.
|
||||
|
||||
Args:
|
||||
elements: Liste d'éléments UI avec bounds
|
||||
|
||||
Returns:
|
||||
Liste de SpatialRelation
|
||||
"""
|
||||
relations = []
|
||||
|
||||
for i, elem_a in enumerate(elements):
|
||||
for j, elem_b in enumerate(elements):
|
||||
if i >= j: # Éviter doublons et auto-relations
|
||||
continue
|
||||
|
||||
# Calculer relation
|
||||
relation = self._compute_relation(elem_a, elem_b)
|
||||
if relation and relation.confidence >= self.config.min_relation_confidence:
|
||||
relations.append(relation)
|
||||
# Ajouter relation inverse pour symétrie
|
||||
relations.append(relation.inverse)
|
||||
|
||||
logger.debug(f"Calculé {len(relations)} relations spatiales")
|
||||
return relations
|
||||
|
||||
def _compute_relation(
|
||||
self,
|
||||
elem_a: Any,
|
||||
elem_b: Any
|
||||
) -> Optional[SpatialRelation]:
|
||||
"""Calculer la relation entre deux éléments."""
|
||||
# Extraire bounds
|
||||
bounds_a = self._get_bounds(elem_a)
|
||||
bounds_b = self._get_bounds(elem_b)
|
||||
|
||||
if bounds_a is None or bounds_b is None:
|
||||
return None
|
||||
|
||||
# Calculer centres
|
||||
center_a = self._get_center(bounds_a)
|
||||
center_b = self._get_center(bounds_b)
|
||||
|
||||
# Calculer distance
|
||||
distance = np.sqrt(
|
||||
(center_a[0] - center_b[0])**2 +
|
||||
(center_a[1] - center_b[1])**2
|
||||
)
|
||||
|
||||
# Déterminer type de relation
|
||||
relation_type, confidence = self._determine_relation_type(
|
||||
bounds_a, bounds_b, center_a, center_b
|
||||
)
|
||||
|
||||
if relation_type is None:
|
||||
return None
|
||||
|
||||
elem_id_a = self._get_element_id(elem_a)
|
||||
elem_id_b = self._get_element_id(elem_b)
|
||||
|
||||
return SpatialRelation(
|
||||
source_element_id=elem_id_a,
|
||||
target_element_id=elem_id_b,
|
||||
relation_type=relation_type,
|
||||
distance=distance,
|
||||
confidence=confidence
|
||||
)
|
||||
|
||||
def _determine_relation_type(
|
||||
self,
|
||||
bounds_a: Tuple[int, int, int, int],
|
||||
bounds_b: Tuple[int, int, int, int],
|
||||
center_a: Tuple[float, float],
|
||||
center_b: Tuple[float, float]
|
||||
) -> Tuple[Optional[RelationType], float]:
|
||||
"""Déterminer le type de relation et sa confiance."""
|
||||
x_a, y_a, w_a, h_a = bounds_a
|
||||
x_b, y_b, w_b, h_b = bounds_b
|
||||
|
||||
# Vérifier INSIDE/CONTAINS
|
||||
if self._is_inside(bounds_a, bounds_b):
|
||||
return RelationType.INSIDE, 0.9
|
||||
if self._is_inside(bounds_b, bounds_a):
|
||||
return RelationType.CONTAINS, 0.9
|
||||
|
||||
# Vérifier OVERLAPS
|
||||
if self._overlaps(bounds_a, bounds_b):
|
||||
return RelationType.OVERLAPS, 0.7
|
||||
|
||||
# Calculer différences de position
|
||||
dx = center_b[0] - center_a[0]
|
||||
dy = center_b[1] - center_a[1]
|
||||
|
||||
# Déterminer direction principale
|
||||
if abs(dx) > abs(dy):
|
||||
# Relation horizontale
|
||||
if dx > 0:
|
||||
relation = RelationType.LEFT_OF # A est à gauche de B
|
||||
else:
|
||||
relation = RelationType.RIGHT_OF # A est à droite de B
|
||||
confidence = min(1.0, abs(dx) / (abs(dy) + 1))
|
||||
else:
|
||||
# Relation verticale
|
||||
if dy > 0:
|
||||
relation = RelationType.ABOVE # A est au-dessus de B
|
||||
else:
|
||||
relation = RelationType.BELOW # A est en-dessous de B
|
||||
confidence = min(1.0, abs(dy) / (abs(dx) + 1))
|
||||
|
||||
# Vérifier adjacence
|
||||
gap = self._compute_gap(bounds_a, bounds_b)
|
||||
if gap <= self.config.adjacent_threshold:
|
||||
confidence = min(confidence + 0.2, 1.0)
|
||||
|
||||
return relation, confidence
|
||||
|
||||
def _is_inside(
|
||||
self,
|
||||
inner: Tuple[int, int, int, int],
|
||||
outer: Tuple[int, int, int, int]
|
||||
) -> bool:
|
||||
"""Vérifier si inner est à l'intérieur de outer."""
|
||||
x_i, y_i, w_i, h_i = inner
|
||||
x_o, y_o, w_o, h_o = outer
|
||||
margin = self.config.inside_margin
|
||||
|
||||
return (
|
||||
x_i >= x_o - margin and
|
||||
y_i >= y_o - margin and
|
||||
x_i + w_i <= x_o + w_o + margin and
|
||||
y_i + h_i <= y_o + h_o + margin
|
||||
)
|
||||
|
||||
def _overlaps(
|
||||
self,
|
||||
bounds_a: Tuple[int, int, int, int],
|
||||
bounds_b: Tuple[int, int, int, int]
|
||||
) -> bool:
|
||||
"""Vérifier si deux bounds se chevauchent."""
|
||||
x_a, y_a, w_a, h_a = bounds_a
|
||||
x_b, y_b, w_b, h_b = bounds_b
|
||||
|
||||
return not (
|
||||
x_a + w_a < x_b or
|
||||
x_b + w_b < x_a or
|
||||
y_a + h_a < y_b or
|
||||
y_b + h_b < y_a
|
||||
)
|
||||
|
||||
def _compute_gap(
|
||||
self,
|
||||
bounds_a: Tuple[int, int, int, int],
|
||||
bounds_b: Tuple[int, int, int, int]
|
||||
) -> float:
|
||||
"""Calculer l'écart entre deux bounds."""
|
||||
x_a, y_a, w_a, h_a = bounds_a
|
||||
x_b, y_b, w_b, h_b = bounds_b
|
||||
|
||||
# Écart horizontal
|
||||
if x_a + w_a < x_b:
|
||||
gap_x = x_b - (x_a + w_a)
|
||||
elif x_b + w_b < x_a:
|
||||
gap_x = x_a - (x_b + w_b)
|
||||
else:
|
||||
gap_x = 0
|
||||
|
||||
# Écart vertical
|
||||
if y_a + h_a < y_b:
|
||||
gap_y = y_b - (y_a + h_a)
|
||||
elif y_b + h_b < y_a:
|
||||
gap_y = y_a - (y_b + h_b)
|
||||
else:
|
||||
gap_y = 0
|
||||
|
||||
return np.sqrt(gap_x**2 + gap_y**2)
|
||||
|
||||
def detect_containers(
|
||||
self,
|
||||
elements: List[Any]
|
||||
) -> List[SemanticContainer]:
|
||||
"""
|
||||
Détecter les conteneurs sémantiques.
|
||||
|
||||
Identifie les groupes d'éléments formant:
|
||||
- Formulaires (labels + inputs)
|
||||
- Menus (items alignés)
|
||||
- Barres d'outils (boutons alignés)
|
||||
- Dialogues (titre + contenu + boutons)
|
||||
|
||||
Args:
|
||||
elements: Liste d'éléments UI
|
||||
|
||||
Returns:
|
||||
Liste de SemanticContainer
|
||||
"""
|
||||
containers = []
|
||||
|
||||
# Grouper éléments par proximité
|
||||
groups = self._group_by_proximity(elements)
|
||||
|
||||
for group_id, group_elements in enumerate(groups):
|
||||
if len(group_elements) < self.config.min_group_size:
|
||||
continue
|
||||
|
||||
# Analyser le groupe pour déterminer le type
|
||||
container_type, confidence = self._classify_container(group_elements)
|
||||
|
||||
if confidence < self.config.min_container_confidence:
|
||||
continue
|
||||
|
||||
# Calculer bounds du conteneur
|
||||
bounds = self._compute_group_bounds(group_elements)
|
||||
|
||||
container = SemanticContainer(
|
||||
container_id=f"container_{group_id:03d}",
|
||||
container_type=container_type,
|
||||
element_ids=[self._get_element_id(e) for e in group_elements],
|
||||
bounds=bounds,
|
||||
confidence=confidence
|
||||
)
|
||||
containers.append(container)
|
||||
|
||||
logger.info(f"Détecté {len(containers)} conteneurs sémantiques")
|
||||
return containers
|
||||
|
||||
def _group_by_proximity(
|
||||
self,
|
||||
elements: List[Any]
|
||||
) -> List[List[Any]]:
|
||||
"""Grouper les éléments par proximité spatiale."""
|
||||
if not elements:
|
||||
return []
|
||||
|
||||
# Union-Find pour groupement
|
||||
n = len(elements)
|
||||
parent = list(range(n))
|
||||
|
||||
def find(x):
|
||||
if parent[x] != x:
|
||||
parent[x] = find(parent[x])
|
||||
return parent[x]
|
||||
|
||||
def union(x, y):
|
||||
px, py = find(x), find(y)
|
||||
if px != py:
|
||||
parent[px] = py
|
||||
|
||||
# Grouper éléments proches
|
||||
for i in range(n):
|
||||
for j in range(i + 1, n):
|
||||
bounds_i = self._get_bounds(elements[i])
|
||||
bounds_j = self._get_bounds(elements[j])
|
||||
|
||||
if bounds_i and bounds_j:
|
||||
gap = self._compute_gap(bounds_i, bounds_j)
|
||||
if gap <= self.config.max_group_distance:
|
||||
union(i, j)
|
||||
|
||||
# Construire groupes
|
||||
groups_dict: Dict[int, List[Any]] = {}
|
||||
for i, elem in enumerate(elements):
|
||||
root = find(i)
|
||||
if root not in groups_dict:
|
||||
groups_dict[root] = []
|
||||
groups_dict[root].append(elem)
|
||||
|
||||
return list(groups_dict.values())
|
||||
|
||||
def _classify_container(
|
||||
self,
|
||||
elements: List[Any]
|
||||
) -> Tuple[ContainerType, float]:
|
||||
"""Classifier le type de conteneur."""
|
||||
# Analyser les types d'éléments
|
||||
roles = [self._get_role(e) for e in elements]
|
||||
|
||||
# Compter types
|
||||
has_input = any(r in ['textbox', 'input', 'textarea', 'combobox'] for r in roles)
|
||||
has_label = any(r in ['label', 'text'] for r in roles)
|
||||
has_button = any(r in ['button', 'link'] for r in roles)
|
||||
has_menuitem = any(r in ['menuitem', 'option'] for r in roles)
|
||||
|
||||
# Analyser alignement
|
||||
bounds_list = [self._get_bounds(e) for e in elements if self._get_bounds(e)]
|
||||
is_vertical = self._is_vertically_aligned(bounds_list)
|
||||
is_horizontal = self._is_horizontally_aligned(bounds_list)
|
||||
|
||||
# Classifier
|
||||
if has_input and has_label:
|
||||
return ContainerType.FORM, 0.8
|
||||
|
||||
if has_menuitem or (is_vertical and has_button):
|
||||
return ContainerType.MENU, 0.7
|
||||
|
||||
if is_horizontal and has_button:
|
||||
return ContainerType.TOOLBAR, 0.7
|
||||
|
||||
if has_button and len(elements) <= 5:
|
||||
return ContainerType.DIALOG, 0.6
|
||||
|
||||
if is_vertical and len(elements) > 3:
|
||||
return ContainerType.LIST, 0.6
|
||||
|
||||
return ContainerType.PANEL, 0.5
|
||||
|
||||
def _is_vertically_aligned(
|
||||
self,
|
||||
bounds_list: List[Tuple[int, int, int, int]]
|
||||
) -> bool:
|
||||
"""Vérifier si les éléments sont alignés verticalement."""
|
||||
if len(bounds_list) < 2:
|
||||
return False
|
||||
|
||||
x_centers = [(b[0] + b[2]/2) for b in bounds_list]
|
||||
x_std = np.std(x_centers)
|
||||
|
||||
return x_std < 30 # Tolérance de 30 pixels
|
||||
|
||||
def _is_horizontally_aligned(
|
||||
self,
|
||||
bounds_list: List[Tuple[int, int, int, int]]
|
||||
) -> bool:
|
||||
"""Vérifier si les éléments sont alignés horizontalement."""
|
||||
if len(bounds_list) < 2:
|
||||
return False
|
||||
|
||||
y_centers = [(b[1] + b[3]/2) for b in bounds_list]
|
||||
y_std = np.std(y_centers)
|
||||
|
||||
return y_std < 30 # Tolérance de 30 pixels
|
||||
|
||||
def _compute_group_bounds(
|
||||
self,
|
||||
elements: List[Any]
|
||||
) -> Tuple[int, int, int, int]:
|
||||
"""Calculer les bounds englobant un groupe."""
|
||||
bounds_list = [self._get_bounds(e) for e in elements if self._get_bounds(e)]
|
||||
|
||||
if not bounds_list:
|
||||
return (0, 0, 0, 0)
|
||||
|
||||
min_x = min(b[0] for b in bounds_list)
|
||||
min_y = min(b[1] for b in bounds_list)
|
||||
max_x = max(b[0] + b[2] for b in bounds_list)
|
||||
max_y = max(b[1] + b[3] for b in bounds_list)
|
||||
|
||||
return (min_x, min_y, max_x - min_x, max_y - min_y)
|
||||
|
||||
def find_by_relation(
|
||||
self,
|
||||
anchor_id: str,
|
||||
relation: RelationType,
|
||||
relations: List[SpatialRelation]
|
||||
) -> List[str]:
|
||||
"""
|
||||
Trouver les éléments ayant une relation spécifique avec un ancre.
|
||||
|
||||
Args:
|
||||
anchor_id: ID de l'élément ancre
|
||||
relation: Type de relation recherchée
|
||||
relations: Liste des relations calculées
|
||||
|
||||
Returns:
|
||||
Liste des IDs d'éléments correspondants
|
||||
"""
|
||||
results = []
|
||||
|
||||
for rel in relations:
|
||||
if rel.source_element_id == anchor_id and rel.relation_type == relation:
|
||||
results.append(rel.target_element_id)
|
||||
|
||||
return results
|
||||
|
||||
def _get_bounds(self, element: Any) -> Optional[Tuple[int, int, int, int]]:
|
||||
"""Extraire les bounds d'un élément."""
|
||||
if hasattr(element, 'bounds'):
|
||||
return element.bounds
|
||||
if hasattr(element, 'bbox'):
|
||||
return element.bbox
|
||||
if isinstance(element, dict):
|
||||
if 'bounds' in element:
|
||||
return tuple(element['bounds'])
|
||||
if 'bbox' in element:
|
||||
return tuple(element['bbox'])
|
||||
if all(k in element for k in ['x', 'y', 'width', 'height']):
|
||||
return (element['x'], element['y'], element['width'], element['height'])
|
||||
return None
|
||||
|
||||
def _get_center(self, bounds: Tuple[int, int, int, int]) -> Tuple[float, float]:
|
||||
"""Calculer le centre d'un bounds."""
|
||||
x, y, w, h = bounds
|
||||
return (x + w/2, y + h/2)
|
||||
|
||||
def _get_element_id(self, element: Any) -> str:
|
||||
"""Extraire l'ID d'un élément."""
|
||||
if hasattr(element, 'element_id'):
|
||||
return element.element_id
|
||||
if hasattr(element, 'id'):
|
||||
return element.id
|
||||
if isinstance(element, dict):
|
||||
return element.get('id', element.get('element_id', str(id(element))))
|
||||
return str(id(element))
|
||||
|
||||
def _get_role(self, element: Any) -> str:
|
||||
"""Extraire le rôle d'un élément."""
|
||||
if hasattr(element, 'role'):
|
||||
return element.role.lower()
|
||||
if isinstance(element, dict):
|
||||
return element.get('role', 'unknown').lower()
|
||||
return 'unknown'
|
||||
|
||||
def get_config(self) -> SpatialAnalyzerConfig:
|
||||
"""Récupérer la configuration."""
|
||||
return self.config
|
||||
|
||||
|
||||
# =============================================================================
# Fonctions utilitaires
# =============================================================================

def create_spatial_analyzer(
    adjacent_threshold: float = 20.0,
    max_group_distance: float = 50.0
) -> SpatialAnalyzer:
    """
    Créer un analyseur avec configuration personnalisée.

    Args:
        adjacent_threshold: Distance max pour "adjacent"
        max_group_distance: Distance max pour grouper

    Returns:
        SpatialAnalyzer configuré
    """
    config = SpatialAnalyzerConfig(
        adjacent_threshold=adjacent_threshold,
        max_group_distance=max_group_distance
    )
    return SpatialAnalyzer(config)
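A minimal usage sketch for the analyzer above, assuming the core package is importable from the project root. The dict elements use the 'id' / 'bounds' / 'role' keys that _get_bounds, _get_element_id and _get_role accept; the sample values are illustrative only:

from core.detection.spatial_analyzer import RelationType, create_spatial_analyzer

elements = [
    {"id": "label_email", "bounds": (100, 100, 80, 20), "role": "label"},
    {"id": "input_email", "bounds": (100, 125, 200, 25), "role": "textbox"},
    {"id": "btn_submit", "bounds": (100, 170, 90, 30), "role": "button"},
]

analyzer = create_spatial_analyzer(adjacent_threshold=20.0, max_group_distance=50.0)

relations = analyzer.compute_relations(elements)    # List[SpatialRelation]
containers = analyzer.detect_containers(elements)   # List[SemanticContainer]

# Elements that "label_email" sits above (i.e. the ones located below the label)
below_label = analyzer.find_by_relation("label_email", RelationType.ABOVE, relations)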
617
core/detection/ui_detector.py
Normal file
@@ -0,0 +1,617 @@
"""
UIDetector - Détection Hybride OpenCV + VLM

Approche hybride qui combine:
1. OpenCV pour détecter rapidement les régions candidates (~10ms)
2. VLM pour classifier intelligemment chaque région (~100-200ms par élément)

Cette approche est plus rapide et plus fiable que le VLM seul.
Basée sur l'architecture éprouvée de la V2.
"""

from typing import List, Dict, Optional, Any, Tuple
from pathlib import Path
from dataclasses import dataclass
import logging

import numpy as np
from PIL import Image
import cv2

logger = logging.getLogger(__name__)

from ..models.ui_element import UIElement, UIElementEmbeddings, VisualFeatures
from .ollama_client import OllamaClient, check_ollama_available

# Import OWL-v2 (optionnel)
try:
    from .owl_detector import OwlDetector
    OWL_AVAILABLE = True
except ImportError:
    OWL_AVAILABLE = False

@dataclass
class BoundingBox:
    """Représente une bounding box détectée"""
    x: int
    y: int
    w: int
    h: int
    confidence: float = 1.0
    source: str = "unknown"  # "text_detection", "rectangle_detection", etc.

    def area(self) -> int:
        """Calcule l'aire de la bbox"""
        return self.w * self.h

    def center(self) -> Tuple[int, int]:
        """Calcule le centre de la bbox"""
        return (self.x + self.w // 2, self.y + self.h // 2)

    def iou(self, other: 'BoundingBox') -> float:
        """Calcule l'Intersection over Union avec une autre bbox"""
        x1_inter = max(self.x, other.x)
        y1_inter = max(self.y, other.y)
        x2_inter = min(self.x + self.w, other.x + other.w)
        y2_inter = min(self.y + self.h, other.y + other.h)

        if x2_inter < x1_inter or y2_inter < y1_inter:
            return 0.0

        inter_area = (x2_inter - x1_inter) * (y2_inter - y1_inter)
        union_area = self.area() + other.area() - inter_area

        return inter_area / union_area if union_area > 0 else 0.0

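A small worked example of the IoU computation above (illustrative numbers, not part of the committed file):

a = BoundingBox(x=0, y=0, w=100, h=100)
b = BoundingBox(x=50, y=50, w=100, h=100)
# intersection = 50 * 50 = 2500 ; union = 10000 + 10000 - 2500 = 17500
assert abs(a.iou(b) - 2500 / 17500) < 1e-9   # ≈ 0.143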
@dataclass
class DetectionConfig:
    """Configuration de la détection UI hybride"""
    # VLM
    # Modèles recommandés:
    # - "qwen2.5vl:7b" (plus rapide, meilleur avec format='json', recommandé)
    # - "qwen3-vl:8b" (plus gros, supporté mais plus d'erreurs JSON)
    vlm_model: str = "qwen2.5vl:7b"
    vlm_endpoint: str = "http://localhost:11434"
    use_vlm_classification: bool = True  # Utiliser VLM pour classifier

    # OWL-v2 (détection zero-shot)
    use_owl_detection: bool = True  # Utiliser OWL-v2 pour détection
    owl_confidence_threshold: float = 0.1  # Seuil de confiance OWL-v2

    # OpenCV
    use_text_detection: bool = True  # Détecter zones de texte
    use_rectangle_detection: bool = True  # Détecter rectangles
    min_region_size: int = 10  # Taille minimale d'une région (réduit pour petits éléments comme checkboxes)
    max_region_size: int = 600  # Taille maximale d'une région (augmenté pour grands champs)

    # Général
    confidence_threshold: float = 0.7
    max_elements: int = 50
    merge_overlapping: bool = True  # Fusionner régions qui se chevauchent
    iou_threshold: float = 0.5  # Seuil IoU pour fusion

class UIDetector:
|
||||
"""
|
||||
Détecteur UI Hybride : OWL-v2 + OpenCV + VLM
|
||||
|
||||
Pipeline:
|
||||
1. OWL-v2 détecte les éléments UI avec zero-shot (rapide et précis)
|
||||
2. OpenCV détecte les régions candidates supplémentaires (fallback)
|
||||
3. VLM classifie chaque région (précis)
|
||||
4. Création des UIElements avec toutes les infos
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[DetectionConfig] = None):
|
||||
"""Initialiser le détecteur hybride"""
|
||||
self.config = config or DetectionConfig()
|
||||
self.vlm_client = None
|
||||
self.owl_detector = None
|
||||
|
||||
# Initialiser OWL-v2 si demandé
|
||||
if self.config.use_owl_detection and OWL_AVAILABLE:
|
||||
self._initialize_owl()
|
||||
|
||||
# Initialiser VLM si demandé
|
||||
if self.config.use_vlm_classification:
|
||||
self._initialize_vlm()
|
||||
|
||||
def _initialize_owl(self) -> None:
|
||||
"""Initialiser le détecteur OWL-v2"""
|
||||
try:
|
||||
self.owl_detector = OwlDetector(
|
||||
confidence_threshold=self.config.owl_confidence_threshold
|
||||
)
|
||||
logger.info("✓ OWL-v2 initialized")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize OWL-v2: {e}")
|
||||
logger.info("Falling back to OpenCV detection only")
|
||||
self.owl_detector = None
|
||||
|
||||
def _initialize_vlm(self) -> None:
|
||||
"""Initialiser le client VLM"""
|
||||
try:
|
||||
if check_ollama_available(self.config.vlm_endpoint):
|
||||
self.vlm_client = OllamaClient(
|
||||
endpoint=self.config.vlm_endpoint,
|
||||
model=self.config.vlm_model
|
||||
)
|
||||
logger.info(f"✓ VLM initialized: {self.config.vlm_model}")
|
||||
else:
|
||||
logger.warning("Ollama not available, VLM classification disabled")
|
||||
self.vlm_client = None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize VLM: {e}")
|
||||
self.vlm_client = None
|
||||
|
||||
def detect(self,
|
||||
screenshot_path: str,
|
||||
window_context: Optional[Dict[str, Any]] = None) -> List[UIElement]:
|
||||
"""
|
||||
Détecter tous les éléments UI dans un screenshot
|
||||
|
||||
Args:
|
||||
screenshot_path: Chemin vers le screenshot
|
||||
window_context: Contexte de la fenêtre
|
||||
|
||||
Returns:
|
||||
Liste d'UIElements détectés
|
||||
"""
|
||||
# Charger l'image
|
||||
pil_image = Image.open(screenshot_path)
|
||||
cv_image = cv2.imread(screenshot_path)
|
||||
|
||||
if cv_image is None:
|
||||
logger.error(f"Failed to load image: {screenshot_path}")
|
||||
return []
|
||||
|
||||
logger.info(f"Analyzing screenshot: {cv_image.shape[1]}x{cv_image.shape[0]}")
|
||||
|
||||
# Étape 1: Détecter avec OWL-v2 si disponible
|
||||
regions = []
|
||||
if self.owl_detector:
|
||||
logger.debug("Step 1: Detecting UI elements with OWL-v2...")
|
||||
owl_detections = self.owl_detector.detect_ui_elements(pil_image)
|
||||
logger.debug(f"Found {len(owl_detections)} elements with OWL-v2")
|
||||
|
||||
# Convertir détections OWL en BoundingBox avec validation
|
||||
img_width, img_height = pil_image.size
|
||||
|
||||
for det in owl_detections:
|
||||
bbox = det['bbox']
|
||||
|
||||
# Clipper les coordonnées dans les limites de l'image
|
||||
x1 = max(0, int(bbox[0]))
|
||||
y1 = max(0, int(bbox[1]))
|
||||
x2 = min(img_width, int(bbox[2]))
|
||||
y2 = min(img_height, int(bbox[3]))
|
||||
|
||||
w = x2 - x1
|
||||
h = y2 - y1
|
||||
|
||||
# Ignorer les bounding boxes invalides (négatives ou taille nulle)
|
||||
if w <= 0 or h <= 0:
|
||||
logger.debug(f"Skipping invalid OWL bbox: x1={bbox[0]}, y1={bbox[1]}, x2={bbox[2]}, y2={bbox[3]}")
|
||||
continue
|
||||
|
||||
regions.append(BoundingBox(
|
||||
x=x1,
|
||||
y=y1,
|
||||
w=w,
|
||||
h=h,
|
||||
confidence=det['confidence'],
|
||||
source=f"owl_{det['label']}"
|
||||
))
|
||||
|
||||
# Étape 1bis: Compléter avec OpenCV si nécessaire
|
||||
if not regions or len(regions) < 5: # Si OWL trouve peu d'éléments
|
||||
logger.debug("Step 1bis: Detecting additional regions with OpenCV...")
|
||||
opencv_regions = self._detect_candidate_regions(cv_image)
|
||||
logger.debug(f"Found {len(opencv_regions)} additional regions")
|
||||
regions.extend(opencv_regions)
|
||||
|
||||
logger.debug(f"Total: {len(regions)} candidate regions")
|
||||
|
||||
# Étape 2: Classifier chaque région avec le VLM
|
||||
logger.debug("Step 2: Classifying regions with VLM...")
|
||||
ui_elements = []
|
||||
|
||||
for i, region in enumerate(regions):
|
||||
# Extraire le crop de la région
|
||||
crop = pil_image.crop((
|
||||
region.x,
|
||||
region.y,
|
||||
region.x + region.w,
|
||||
region.y + region.h
|
||||
))
|
||||
|
||||
# Classifier avec VLM
|
||||
element = self._classify_region(
|
||||
crop,
|
||||
region,
|
||||
screenshot_path,
|
||||
window_context
|
||||
)
|
||||
|
||||
if element and element.confidence >= self.config.confidence_threshold:
|
||||
ui_elements.append(element)
|
||||
|
||||
logger.info(f"Detected {len(ui_elements)} UI elements")
|
||||
|
||||
# Limiter le nombre d'éléments
|
||||
if len(ui_elements) > self.config.max_elements:
|
||||
ui_elements.sort(key=lambda x: x.confidence, reverse=True)
|
||||
ui_elements = ui_elements[:self.config.max_elements]
|
||||
|
||||
return ui_elements
|
||||
|
||||
def _detect_candidate_regions(self, image: np.ndarray) -> List[BoundingBox]:
|
||||
"""
|
||||
Détecter les régions candidates avec OpenCV
|
||||
|
||||
Args:
|
||||
image: Image OpenCV (numpy array)
|
||||
|
||||
Returns:
|
||||
Liste de BoundingBox candidates
|
||||
"""
|
||||
regions = []
|
||||
|
||||
# Méthode 1: Détection de texte
|
||||
if self.config.use_text_detection:
|
||||
text_regions = self._detect_text_regions(image)
|
||||
regions.extend(text_regions)
|
||||
logger.debug(f"Text regions: {len(text_regions)}")
|
||||
|
||||
# Méthode 2: Détection de rectangles
|
||||
if self.config.use_rectangle_detection:
|
||||
rect_regions = self._detect_rectangles(image)
|
||||
regions.extend(rect_regions)
|
||||
logger.debug(f"Rectangle regions: {len(rect_regions)}")
|
||||
|
||||
# Fusionner les régions qui se chevauchent
|
||||
if self.config.merge_overlapping and len(regions) > 0:
|
||||
regions = self._merge_overlapping_regions(regions)
|
||||
logger.debug(f"After merging: {len(regions)}")
|
||||
|
||||
# Filtrer les régions invalides
|
||||
regions = self._filter_invalid_regions(regions, image.shape)
|
||||
|
||||
return regions
|
||||
|
||||
def _detect_text_regions(self, image: np.ndarray) -> List[BoundingBox]:
|
||||
"""Détecter les zones de texte avec OpenCV"""
|
||||
regions = []
|
||||
|
||||
try:
|
||||
# Convertir en niveaux de gris
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# Seuillage adaptatif
|
||||
thresh = cv2.adaptiveThreshold(
|
||||
gray, 255,
|
||||
cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
|
||||
cv2.THRESH_BINARY_INV,
|
||||
11, 2
|
||||
)
|
||||
|
||||
# Trouver les contours
|
||||
contours, _ = cv2.findContours(
|
||||
thresh,
|
||||
cv2.RETR_EXTERNAL,
|
||||
cv2.CHAIN_APPROX_SIMPLE
|
||||
)
|
||||
|
||||
# Créer des bboxes
|
||||
for contour in contours:
|
||||
x, y, w, h = cv2.boundingRect(contour)
|
||||
|
||||
# Filtrer par taille
|
||||
if w < self.config.min_region_size or h < self.config.min_region_size:
|
||||
continue
|
||||
if w > self.config.max_region_size or h > self.config.max_region_size:
|
||||
continue
|
||||
|
||||
# Filtrer par ratio (texte généralement horizontal, mais accepter carrés pour checkboxes)
|
||||
ratio = w / h if h > 0 else 0
|
||||
if ratio < 0.3 or ratio > 25: # Plus permissif
|
||||
continue
|
||||
|
||||
regions.append(BoundingBox(
|
||||
x=x, y=y, w=w, h=h,
|
||||
confidence=0.7,
|
||||
source="text_detection"
|
||||
))
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Text detection error: {e}")
|
||||
|
||||
return regions
|
||||
|
||||
def _detect_rectangles(self, image: np.ndarray) -> List[BoundingBox]:
|
||||
"""Détecter les rectangles propres (boutons, champs, etc.)"""
|
||||
regions = []
|
||||
|
||||
try:
|
||||
# Convertir en niveaux de gris
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# Détection de contours avec Canny
|
||||
edges = cv2.Canny(gray, 50, 150)
|
||||
|
||||
# Dilatation pour connecter les contours
|
||||
kernel = np.ones((3, 3), np.uint8)
|
||||
dilated = cv2.dilate(edges, kernel, iterations=1)
|
||||
|
||||
# Trouver les contours
|
||||
contours, _ = cv2.findContours(
|
||||
dilated,
|
||||
cv2.RETR_EXTERNAL,
|
||||
cv2.CHAIN_APPROX_SIMPLE
|
||||
)
|
||||
|
||||
# Créer des bboxes
|
||||
for contour in contours:
|
||||
# Approximer le contour
|
||||
epsilon = 0.02 * cv2.arcLength(contour, True)
|
||||
approx = cv2.approxPolyDP(contour, epsilon, True)
|
||||
|
||||
# Garder les formes rectangulaires (4+ coins)
|
||||
if len(approx) >= 4:
|
||||
x, y, w, h = cv2.boundingRect(contour)
|
||||
|
||||
# Filtrer par taille
|
||||
if w < self.config.min_region_size or h < self.config.min_region_size:
|
||||
continue
|
||||
if w > self.config.max_region_size or h > self.config.max_region_size:
|
||||
continue
|
||||
|
||||
regions.append(BoundingBox(
|
||||
x=x, y=y, w=w, h=h,
|
||||
confidence=0.8,
|
||||
source="rectangle_detection"
|
||||
))
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Rectangle detection error: {e}")
|
||||
|
||||
return regions
|
||||
|
||||
def _merge_overlapping_regions(self, regions: List[BoundingBox]) -> List[BoundingBox]:
|
||||
"""Fusionner les régions qui se chevauchent"""
|
||||
if not regions:
|
||||
return []
|
||||
|
||||
# Trier par confiance décroissante
|
||||
regions = sorted(regions, key=lambda r: r.confidence, reverse=True)
|
||||
|
||||
merged = []
|
||||
used = set()
|
||||
|
||||
for i, region in enumerate(regions):
|
||||
if i in used:
|
||||
continue
|
||||
|
||||
# Chercher les régions qui se chevauchent
|
||||
overlapping = [region]
|
||||
for j, other in enumerate(regions[i+1:], start=i+1):
|
||||
if j in used:
|
||||
continue
|
||||
|
||||
if region.iou(other) > self.config.iou_threshold:
|
||||
overlapping.append(other)
|
||||
used.add(j)
|
||||
|
||||
# Fusionner en prenant l'union
|
||||
if len(overlapping) > 1:
|
||||
x_min = min(r.x for r in overlapping)
|
||||
y_min = min(r.y for r in overlapping)
|
||||
x_max = max(r.x + r.w for r in overlapping)
|
||||
y_max = max(r.y + r.h for r in overlapping)
|
||||
conf = max(r.confidence for r in overlapping)
|
||||
|
||||
merged.append(BoundingBox(
|
||||
x=x_min, y=y_min,
|
||||
w=x_max - x_min, h=y_max - y_min,
|
||||
confidence=conf,
|
||||
source="merged"
|
||||
))
|
||||
else:
|
||||
merged.append(region)
|
||||
|
||||
return merged
|
||||
|
||||
def _filter_invalid_regions(self,
|
||||
regions: List[BoundingBox],
|
||||
image_shape: Tuple[int, ...]) -> List[BoundingBox]:
|
||||
"""Filtrer les régions invalides"""
|
||||
height, width = image_shape[:2]
|
||||
|
||||
valid = []
|
||||
for region in regions:
|
||||
# Vérifier que la région est dans l'image
|
||||
if region.x < 0 or region.y < 0:
|
||||
continue
|
||||
if region.x + region.w > width or region.y + region.h > height:
|
||||
continue
|
||||
|
||||
# Vérifier la taille
|
||||
if region.w < self.config.min_region_size or region.h < self.config.min_region_size:
|
||||
continue
|
||||
if region.w > self.config.max_region_size or region.h > self.config.max_region_size:
|
||||
continue
|
||||
|
||||
valid.append(region)
|
||||
|
||||
return valid
|
||||
|
||||
def _classify_region(self,
|
||||
crop: Image.Image,
|
||||
region: BoundingBox,
|
||||
screenshot_path: str,
|
||||
window_context: Optional[Dict] = None) -> Optional[UIElement]:
|
||||
"""
|
||||
Classifier une région avec le VLM
|
||||
|
||||
Args:
|
||||
crop: Image PIL de la région
|
||||
region: BoundingBox de la région
|
||||
screenshot_path: Chemin du screenshot
|
||||
window_context: Contexte de la fenêtre
|
||||
|
||||
Returns:
|
||||
UIElement ou None
|
||||
"""
|
||||
if self.vlm_client is None:
|
||||
# Fallback: classification basique sans VLM
|
||||
return self._classify_region_fallback(crop, region, screenshot_path)
|
||||
|
||||
try:
|
||||
# OPTIMISATION: Un seul appel VLM au lieu de 3
|
||||
# Avant: classify_element_type() + classify_element_role() + extract_text()
|
||||
# Après: classify_element_complete() → réduction de 66% du temps
|
||||
classification = self.vlm_client.classify_element_complete(crop)
|
||||
|
||||
if classification["success"]:
|
||||
elem_type = classification.get("type", "unknown")
|
||||
elem_role = classification.get("role", "unknown")
|
||||
elem_label = classification.get("text", "")
|
||||
confidence = classification.get("confidence", 0.85)
|
||||
else:
|
||||
# Fallback si échec
|
||||
elem_type = "unknown"
|
||||
elem_role = "unknown"
|
||||
elem_label = ""
|
||||
confidence = 0.5
|
||||
|
||||
# Créer l'UIElement
|
||||
element = UIElement(
|
||||
element_id=f"hybrid_{region.x}_{region.y}",
|
||||
type=elem_type,
|
||||
role=elem_role,
|
||||
bbox=(region.x, region.y, region.w, region.h),
|
||||
center=region.center(),
|
||||
label=elem_label,
|
||||
label_confidence=0.8,
|
||||
embeddings=UIElementEmbeddings(),
|
||||
visual_features=self._extract_visual_features(crop),
|
||||
confidence=confidence,
|
||||
metadata={
|
||||
"detected_by": "hybrid",
|
||||
"detection_method": region.source,
|
||||
"vlm_model": self.config.vlm_model,
|
||||
"screenshot_path": screenshot_path
|
||||
}
|
||||
)
|
||||
|
||||
return element
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Classification error: {e}")
|
||||
return None
|
||||
|
||||
def _classify_region_fallback(self,
|
||||
crop: Image.Image,
|
||||
region: BoundingBox,
|
||||
screenshot_path: str) -> UIElement:
|
||||
"""Classification basique sans VLM (fallback)"""
|
||||
# Heuristiques simples basées sur la taille et la forme
|
||||
aspect_ratio = region.w / region.h if region.h > 0 else 1.0
|
||||
|
||||
if aspect_ratio > 3:
|
||||
elem_type = "text_input"
|
||||
elem_role = "form_input"
|
||||
elif 0.8 <= aspect_ratio <= 1.2 and region.w < 50:
|
||||
elem_type = "checkbox"
|
||||
elem_role = "form_input"
|
||||
else:
|
||||
elem_type = "button"
|
||||
elem_role = "unknown"
|
||||
|
||||
return UIElement(
|
||||
element_id=f"fallback_{region.x}_{region.y}",
|
||||
type=elem_type,
|
||||
role=elem_role,
|
||||
bbox=(region.x, region.y, region.w, region.h),
|
||||
center=region.center(),
|
||||
label="",
|
||||
label_confidence=0.5,
|
||||
embeddings=UIElementEmbeddings(),
|
||||
visual_features=self._extract_visual_features(crop),
|
||||
confidence=0.6,
|
||||
metadata={
|
||||
"detected_by": "hybrid_fallback",
|
||||
"detection_method": region.source,
|
||||
"screenshot_path": screenshot_path
|
||||
}
|
||||
)
|
||||
|
||||
def _extract_visual_features(self, image: Image.Image) -> VisualFeatures:
|
||||
"""Extraire les features visuelles d'une image"""
|
||||
# Calculer couleur dominante
|
||||
img_array = np.array(image)
|
||||
if len(img_array.shape) == 3:
|
||||
dominant_color = tuple(img_array.mean(axis=(0, 1)).astype(int).tolist())
|
||||
else:
|
||||
dominant_color = (128, 128, 128)
|
||||
|
||||
# Déterminer forme
|
||||
width, height = image.size
|
||||
aspect_ratio = width / height if height > 0 else 1.0
|
||||
|
||||
if aspect_ratio > 3:
|
||||
shape = "horizontal_bar"
|
||||
elif aspect_ratio < 0.33:
|
||||
shape = "vertical_bar"
|
||||
elif 0.8 <= aspect_ratio <= 1.2:
|
||||
shape = "square"
|
||||
else:
|
||||
shape = "rectangle"
|
||||
|
||||
# Catégorie de taille
|
||||
area = width * height
|
||||
if area < 1000:
|
||||
size_category = "small"
|
||||
elif area < 10000:
|
||||
size_category = "medium"
|
||||
else:
|
||||
size_category = "large"
|
||||
|
||||
# Détection d'icône
|
||||
has_icon = width < 100 and height < 100 and 0.8 <= aspect_ratio <= 1.2
|
||||
|
||||
return VisualFeatures(
|
||||
dominant_color=dominant_color,
|
||||
has_icon=has_icon,
|
||||
shape=shape,
|
||||
size_category=size_category
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
# Fonctions utilitaires
# ============================================================================

def create_detector(
    vlm_model: str = "qwen3-vl:8b",
    confidence_threshold: float = 0.7,
    use_vlm: bool = True
) -> UIDetector:
    """
    Créer un détecteur avec configuration personnalisée

    Args:
        vlm_model: Modèle VLM à utiliser
        confidence_threshold: Seuil de confiance
        use_vlm: Utiliser le VLM pour la classification

    Returns:
        UIDetector configuré
    """
    config = DetectionConfig(
        vlm_model=vlm_model,
        confidence_threshold=confidence_threshold,
        use_vlm_classification=use_vlm
    )
    return UIDetector(config)
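A minimal end-to-end sketch of the hybrid detector above, assuming the core package is importable and Ollama is reachable at the configured endpoint (otherwise classification falls back to the heuristics in _classify_region_fallback). The screenshot path is illustrative, and the printed attributes assume UIElement keeps its constructor fields as attributes:

from core.detection.ui_detector import create_detector

detector = create_detector(vlm_model="qwen2.5vl:7b", confidence_threshold=0.7)
elements = detector.detect("screenshots/login_form.png")

for elem in elements:
    # Each detected element carries an id, a type/role, a label, a bbox and a confidence score
    print(elem.element_id, elem.type, elem.label, elem.bbox, f"{elem.confidence:.2f}")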