"""Set-of-Mark engine: detect and number the UI elements of a screenshot.

Pipeline: YOLO (icon detection) + docTR (OCR) + visual numbering.
The VLM (qwen3-vl) is used afterwards for semantic identification.

Usage:
    from core.detection.som_engine import SomEngine

    engine = SomEngine()
    result = engine.analyze(screenshot_pil)
    # result.elements      : list of elements with their coordinates
    # result.som_image     : screenshot annotated with the numbers
    # result.som_image_b64 : the annotated screenshot, base64-encoded

Architecture:
    - YOLO v8 (icon_detect): detects interactive elements (~15 ms on GPU)
    - docTR: OCR that reads the visible text (~100 ms on GPU)
    - Annotation: numbers every element on the screenshot
    - The VLM is NOT called here (detection and identification are kept apart)
"""

from __future__ import annotations

import base64
import io
import logging
import math
import os
import tempfile
import threading
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple

from PIL import Image, ImageDraw, ImageFont

logger = logging.getLogger(__name__)

# Path to the OmniParser YOLO weights (overridable through the environment).
_YOLO_WEIGHTS = Path(
    os.environ.get("SOM_YOLO_WEIGHTS", "/home/dom/ai/OmniParser/weights/icon_detect/model.pt")
)

# Detection thresholds passed to YOLO.predict().
_YOLO_CONF = 0.15
_YOLO_IOU = 0.5

# OCR words shorter than this (after stripping) are ignored.
_OCR_MIN_CHARS = 2

# JPEG quality of the temporary file handed to docTR.
_OCR_JPEG_QUALITY = 90

# Image modes Pillow's JPEG encoder can write directly; anything else
# (RGBA, P, LA, ...) has to be converted to RGB before saving as JPEG.
_JPEG_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})

# Font used for the number badges; Pillow's default font is the fallback.
_BADGE_FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"
_BADGE_FONT_SIZE = 12
_BADGE_MIN_WIDTH = 24
_BADGE_HEIGHT = 18


@dataclass
class SomElement:
    """A detected and numbered UI element."""

    id: int  # Set-of-Mark number (0, 1, 2, ...)
    bbox: Tuple[int, int, int, int]  # Pixels (x1, y1, x2, y2)
    bbox_norm: Tuple[float, float, float, float]  # Normalised to 0-1
    center: Tuple[int, int]  # Centre in pixels
    center_norm: Tuple[float, float]  # Centre normalised to 0-1
    source: str  # 'yolo' or 'ocr'
    label: str = ""  # OCR text or description
    confidence: float = 0.0


@dataclass
class SomResult:
    """Result of a Set-of-Mark analysis."""

    elements: List[SomElement] = field(default_factory=list)
    som_image: Optional[Image.Image] = None  # Annotated screenshot
    som_image_b64: str = ""  # Annotated screenshot as base64 JPEG
    width: int = 0
    height: int = 0
    analysis_time_ms: float = 0.0

    def find_element_at(self, x: int, y: int, margin: int = 20) -> Optional[SomElement]:
        """Return the element closest to the point (x, y).

        Only elements whose bounding box, grown by ``margin`` pixels on every
        side, contains the point are considered. Among those, the one whose
        centre is nearest wins; on a tie the earliest element is returned.

        Returns:
            The matching element, or None when no box contains the point.
        """
        best: Optional[SomElement] = None
        best_dist = float("inf")
        for elem in self.elements:
            x1, y1, x2, y2 = elem.bbox
            inside = x1 - margin <= x <= x2 + margin and y1 - margin <= y <= y2 + margin
            if not inside:
                continue
            cx, cy = elem.center
            dist = math.hypot(x - cx, y - cy)
            if dist < best_dist:
                best_dist = dist
                best = elem
        return best

    def get_element_by_id(self, element_id: int) -> Optional[SomElement]:
        """Return the element carrying the given SoM number, or None."""
        return next((elem for elem in self.elements if elem.id == element_id), None)


class SomEngine:
    """Set-of-Mark engine: YOLO + docTR + annotation."""

    def __init__(self, device: str = "cuda"):
        self._device = device
        self._yolo = None
        self._ocr = None
        self._loaded = False
        # Serialises the lazy model loading: an engine obtained through
        # get_shared_engine() may be used from several threads.
        self._load_lock = threading.Lock()

    def _ensure_loaded(self):
        """Load the models on first use.

        A missing weights file or an unavailable docTR only logs a warning;
        the corresponding pipeline stage is then skipped by analyze().
        """
        if self._loaded:
            return
        with self._load_lock:
            # Another thread may have finished loading while we were waiting.
            if self._loaded:
                return
            t0 = time.perf_counter()

            # YOLO
            if _YOLO_WEIGHTS.is_file():
                from ultralytics import YOLO

                self._yolo = YOLO(str(_YOLO_WEIGHTS))
                self._yolo.to(self._device)
                logger.info("SoM: YOLO chargé sur %s", self._device)
            else:
                logger.warning("SoM: YOLO weights introuvable: %s", _YOLO_WEIGHTS)

            # docTR
            try:
                from doctr.models import ocr_predictor

                self._ocr = ocr_predictor(
                    det_arch="db_resnet50",
                    reco_arch="crnn_vgg16_bn",
                    pretrained=True,
                )
                if self._device == "cuda":
                    self._ocr = self._ocr.cuda()
                logger.info("SoM: docTR chargé")
            except Exception as e:
                logger.warning("SoM: docTR non disponible: %s", e)

            self._loaded = True
            logger.info("SoM: modèles chargés en %.1fs", time.perf_counter() - t0)

    def analyze(self, screenshot: Image.Image) -> SomResult:
        """Analyse a screenshot: detect every element and number it.

        Args:
            screenshot: The image to analyse. It is not modified. Images in a
                mode that JPEG cannot store (e.g. RGBA) are converted to RGB
                first.

        Returns:
            SomResult with the detected elements and the annotated screenshot.
            An image with a zero width or height yields an empty result.
        """
        t_start = time.perf_counter()
        width, height = screenshot.size

        if width == 0 or height == 0:
            # Nothing to detect, and a zero-sized image cannot be JPEG-encoded.
            logger.warning("SoM: screenshot vide (%dx%d), analyse ignorée", width, height)
            return SomResult(
                som_image=screenshot.copy(),
                width=width,
                height=height,
                analysis_time_ms=(time.perf_counter() - t_start) * 1000,
            )

        self._ensure_loaded()

        # Both the OCR temp file and the base64 output are JPEG, so the image
        # must be in a JPEG-writable mode.
        if screenshot.mode not in _JPEG_MODES:
            screenshot = screenshot.convert("RGB")

        elements: List[SomElement] = []

        # 1. YOLO: detect the interactive elements.
        if self._yolo is not None:
            self._detect_icons(screenshot, elements)

        # 2. docTR: OCR to read the text. Best effort: a failure keeps
        # whatever was collected so far.
        if self._ocr is not None:
            try:
                self._read_text(screenshot, elements)
            except Exception as e:
                logger.warning("SoM: Erreur OCR docTR: %s", e)

        # 3. Annotate a copy of the screenshot with the numbers.
        som_image = self._annotate(screenshot.copy(), elements)
        som_b64 = self._image_to_b64(som_image)

        elapsed_ms = (time.perf_counter() - t_start) * 1000
        logger.info(
            "SoM: %d éléments (%d yolo, %d ocr) en %.0fms",
            len(elements),
            sum(1 for e in elements if e.source == "yolo"),
            sum(1 for e in elements if e.source == "ocr"),
            elapsed_ms,
        )

        return SomResult(
            elements=elements,
            som_image=som_image,
            som_image_b64=som_b64,
            width=width,
            height=height,
            analysis_time_ms=elapsed_ms,
        )

    def _detect_icons(self, screenshot: Image.Image, elements: List[SomElement]) -> None:
        """Run YOLO and append one 'yolo' element per detected box."""
        width, height = screenshot.size
        results = self._yolo.predict(
            source=screenshot,
            conf=_YOLO_CONF,
            iou=_YOLO_IOU,
            verbose=False,
        )
        boxes = results[0].boxes if results else []
        for box in boxes:
            x1, y1, x2, y2 = (int(v) for v in box.xyxy[0].tolist())
            elements.append(
                self._make_element(
                    elem_id=len(elements),
                    bbox=(x1, y1, x2, y2),
                    bbox_norm=(x1 / width, y1 / height, x2 / width, y2 / height),
                    size=(width, height),
                    source="yolo",
                    confidence=float(box.conf[0]),
                )
            )

    def _read_text(self, screenshot: Image.Image, elements: List[SomElement]) -> None:
        """Run docTR and merge the recognised words into ``elements``.

        A word whose box intersects an element already in the list is not
        added: it becomes that element's label if the element has none yet.
        Any other word is appended as a new 'ocr' element.
        """
        from doctr.io import DocumentFile

        width, height = screenshot.size

        # docTR is fed through a temporary JPEG file. The file is created
        # inside the try block so that it is removed even when saving fails.
        tmp_path: Optional[str] = None
        try:
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
                tmp_path = tmp.name
                screenshot.save(tmp, format="JPEG", quality=_OCR_JPEG_QUALITY)
            doc = DocumentFile.from_images([tmp_path])
            result_ocr = self._ocr(doc)
        finally:
            if tmp_path is not None:
                os.unlink(tmp_path)

        words = (
            word
            for page in result_ocr.pages
            for block in page.blocks
            for line in block.lines
            for word in line.words
        )
        for word in words:
            text = word.value.strip()
            if len(text) < _OCR_MIN_CHARS:
                continue

            # The geometry is unpacked as two corner points in normalised
            # (0-1) coordinates, which are scaled to pixels here.
            (nx1, ny1), (nx2, ny2) = word.geometry
            bbox = (int(nx1 * width), int(ny1 * height), int(nx2 * width), int(ny2 * height))

            host = self._first_intersecting(elements, bbox)
            if host is not None:
                # Overlap: use the word as the label of the existing element.
                if not host.label:
                    host.label = text
                continue

            elements.append(
                self._make_element(
                    elem_id=len(elements),
                    bbox=bbox,
                    bbox_norm=(float(nx1), float(ny1), float(nx2), float(ny2)),
                    size=(width, height),
                    source="ocr",
                    label=text,
                    confidence=float(word.confidence),
                )
            )

    @staticmethod
    def _make_element(
        elem_id: int,
        bbox: Tuple[int, int, int, int],
        bbox_norm: Tuple[float, float, float, float],
        size: Tuple[int, int],
        source: str,
        label: str = "",
        confidence: float = 0.0,
    ) -> SomElement:
        """Build a SomElement, deriving its centre from the pixel bbox."""
        x1, y1, x2, y2 = bbox
        width, height = size
        cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
        return SomElement(
            id=elem_id,
            bbox=bbox,
            bbox_norm=bbox_norm,
            center=(cx, cy),
            center_norm=(cx / width, cy / height),
            source=source,
            label=label,
            confidence=confidence,
        )

    @staticmethod
    def _first_intersecting(
        elements: List[SomElement], bbox: Tuple[int, int, int, int]
    ) -> Optional[SomElement]:
        """Return the first element whose box intersects ``bbox``, or None.

        Any intersection with a non-zero area counts; no overlap ratio is
        computed.
        """
        x1, y1, x2, y2 = bbox
        for existing in elements:
            ex1, ey1, ex2, ey2 = existing.bbox
            if max(x1, ex1) < min(x2, ex2) and max(y1, ey1) < min(y2, ey2):
                return existing
        return None

    @staticmethod
    def _annotate(image: Image.Image, elements: List[SomElement]) -> Image.Image:
        """Draw the SoM boxes and numbers on ``image`` (modified in place).

        Returns:
            The same image object, for convenience.
        """
        draw = ImageDraw.Draw(image)

        try:
            badge_font = ImageFont.truetype(_BADGE_FONT_PATH, _BADGE_FONT_SIZE)
        except (OSError, ImportError):
            # Font file missing or FreeType support unavailable.
            badge_font = ImageFont.load_default()

        for elem in elements:
            x1, y1, x2, y2 = elem.bbox

            # Outline colour depends on the source.
            color = (255, 50, 50) if elem.source == "yolo" else (50, 50, 255)
            draw.rectangle([x1, y1, x2, y2], outline=color, width=2)

            # Number badge (red), placed above the top-left corner and
            # clamped to the image. It grows with the text so that ids of
            # three or more digits still fit.
            label = str(elem.id)
            text_w = math.ceil(draw.textlength(label, font=badge_font))
            badge_w = max(_BADGE_MIN_WIDTH, text_w + 6)
            badge_x = max(0, x1 - 2)
            badge_y = max(0, y1 - _BADGE_HEIGHT - 2)
            draw.rectangle(
                [badge_x, badge_y, badge_x + badge_w, badge_y + _BADGE_HEIGHT],
                fill=(220, 30, 30),
            )
            draw.text(
                (badge_x + 3, badge_y + 1),
                label,
                fill="white",
                font=badge_font,
            )

        return image

    @staticmethod
    def _image_to_b64(image: Image.Image, quality: int = 70) -> str:
        """Encode a PIL image as a base64 JPEG string.

        Images in a mode JPEG cannot store (e.g. RGBA or P) are converted to
        RGB first; the input image itself is left untouched.
        """
        if image.mode not in _JPEG_MODES:
            image = image.convert("RGB")
        buf = io.BytesIO()
        image.save(buf, format="JPEG", quality=quality)
        return base64.b64encode(buf.getvalue()).decode("ascii")


# ---------------------------------------------------------------------------
# Shared singleton (created lazily under a lock)
# ---------------------------------------------------------------------------

_shared_engine: Optional[SomEngine] = None
_shared_lock = threading.Lock()


def get_shared_engine(device: str = "cuda") -> Optional[SomEngine]:
    """Return the SomEngine instance shared by all modules.

    The engine is created on the first call; ``device`` is only used at that
    point and is ignored once an engine exists.

    Returns:
        The shared engine, or None if it could not be created.
    """
    global _shared_engine
    if _shared_engine is None:
        with _shared_lock:
            # Checked again under the lock: another thread may have created
            # the engine in the meantime.
            if _shared_engine is None:
                try:
                    _shared_engine = SomEngine(device=device)
                    logger.info("SomEngine singleton partagé initialisé")
                except Exception as e:
                    logger.warning("SomEngine non disponible : %s", e)
                    return None
    return _shared_engine