"""
Set-of-Mark engine: detect and number the UI elements of a screenshot.

Pipeline: YOLO (icon detection) + docTR (OCR) + visual numbering.
The VLM (qwen3-vl) is used afterwards for semantic identification.

Usage:
    from core.detection.som_engine import SomEngine

    engine = SomEngine()
    result = engine.analyze(screenshot_pil)
    # result.elements      : list of elements with their coordinates
    # result.som_image     : screenshot annotated with the numbers
    # result.som_image_b64 : annotated screenshot, base64-encoded JPEG

Architecture:
    - YOLO v8 (icon_detect): detects the interactive elements (~15 ms on GPU)
    - docTR: OCR that reads the visible text (~100 ms on GPU)
    - Annotation: draws every element's number on the screenshot
    - The VLM is NOT called here (detection and identification are separate)
"""

from __future__ import annotations

import base64
import io
import logging
import math
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple, Union

from PIL import Image, ImageDraw, ImageFont

logger = logging.getLogger(__name__)

# Default path to the OmniParser YOLO weights (can be overridden per engine).
_YOLO_WEIGHTS = Path("/home/dom/ai/OmniParser/weights/icon_detect/model.pt")

# Font used for the number badges; the PIL default font is used if missing.
_BADGE_FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"

# Image modes Pillow's JPEG encoder can write without conversion.
_JPEG_MODES = ("RGB", "L", "CMYK")


@dataclass
class SomElement:
    """A detected and numbered UI element."""

    id: int  # Set-of-Mark number (0, 1, 2, ...)
    bbox: Tuple[int, int, int, int]  # Pixels (x1, y1, x2, y2)
    bbox_norm: Tuple[float, float, float, float]  # Normalised to 0-1
    center: Tuple[int, int]  # Centre in pixels
    center_norm: Tuple[float, float]  # Centre normalised to 0-1
    source: str  # 'yolo' or 'ocr'
    label: str = ""  # OCR text or description
    confidence: float = 0.0


@dataclass
class SomResult:
    """Result of a Set-of-Mark analysis."""

    elements: List[SomElement] = field(default_factory=list)
    som_image: Optional[Image.Image] = None  # Annotated screenshot
    som_image_b64: str = ""  # Annotated screenshot as base64 JPEG
    width: int = 0
    height: int = 0
    analysis_time_ms: float = 0.0

    def find_element_at(self, x: int, y: int, margin: int = 20) -> Optional[SomElement]:
        """Return the element closest to the point (x, y).

        Only elements whose bounding box, grown by ``margin`` pixels on every
        side, contains the point are considered. Among those, the one whose
        centre is nearest to the point wins (the first one on ties).

        Returns:
            The matching element, or None when no box contains the point.
        """
        candidates = [
            elem
            for elem in self.elements
            if elem.bbox[0] - margin <= x <= elem.bbox[2] + margin
            and elem.bbox[1] - margin <= y <= elem.bbox[3] + margin
        ]
        return min(
            candidates,
            key=lambda elem: math.hypot(x - elem.center[0], y - elem.center[1]),
            default=None,
        )

    def get_element_by_id(self, element_id: int) -> Optional[SomElement]:
        """Return the element carrying the given SoM number, or None."""
        return next((elem for elem in self.elements if elem.id == element_id), None)


class SomEngine:
    """Set-of-Mark engine: YOLO + docTR + annotation."""

    # Thresholds passed to the YOLO predictor.
    _YOLO_CONF = 0.15
    _YOLO_IOU = 0.5
    # OCR words shorter than this (after stripping) are ignored.
    _MIN_TEXT_LEN = 2

    def __init__(
        self,
        device: str = "cuda",
        yolo_weights: Optional[Union[str, Path]] = None,
    ):
        """Create an engine; the models are only loaded on first use.

        Args:
            device: Device the models are moved to (e.g. "cuda", "cuda:0", "cpu").
            yolo_weights: Path to the YOLO weights file. Defaults to the
                module-level OmniParser path.
        """
        self._device = device
        self._yolo_weights = Path(yolo_weights) if yolo_weights is not None else _YOLO_WEIGHTS
        self._yolo = None
        self._ocr = None
        self._loaded = False

    def _ensure_loaded(self):
        """Lazily load the models (runs once per engine).

        Both models are best-effort: when one cannot be loaded a warning is
        logged and the corresponding detection step is skipped by analyze().
        """
        if self._loaded:
            return

        t_start = time.perf_counter()

        # YOLO
        if self._yolo_weights.is_file():
            try:
                from ultralytics import YOLO

                self._yolo = YOLO(str(self._yolo_weights))
                self._yolo.to(self._device)
                logger.info("SoM: YOLO chargé sur %s", self._device)
            except Exception as e:
                # Same policy as docTR below: degrade instead of failing every call.
                self._yolo = None
                logger.warning("SoM: YOLO non disponible: %s", e)
        else:
            logger.warning("SoM: YOLO weights introuvable: %s", self._yolo_weights)

        # docTR
        try:
            from doctr.models import ocr_predictor

            self._ocr = ocr_predictor(
                det_arch="db_resnet50",
                reco_arch="crnn_vgg16_bn",
                pretrained=True,
            )
            # Also covers indexed devices such as "cuda:1".
            if self._device.startswith("cuda"):
                self._ocr = self._ocr.to(self._device)
            logger.info("SoM: docTR chargé")
        except Exception as e:
            logger.warning("SoM: docTR non disponible: %s", e)

        self._loaded = True
        logger.info("SoM: modèles chargés en %.1fs", time.perf_counter() - t_start)

    def analyze(self, screenshot: Image.Image) -> SomResult:
        """Analyse a screenshot: detect every element and number it.

        Args:
            screenshot: The screenshot to analyse. It is not modified; a
                non-RGB image is converted to RGB before processing.

        Returns:
            SomResult with the detected elements and the annotated screenshot.
        """
        t_start = time.perf_counter()

        self._ensure_loaded()

        # The coloured annotations and the JPEG export both need an RGB image
        # (an RGBA screenshot cannot be written as JPEG).
        if screenshot.mode != "RGB":
            screenshot = screenshot.convert("RGB")

        width, height = screenshot.size

        # 1. YOLO: detect the interactive elements
        elements = self._detect_icons(screenshot)

        # 2. docTR: OCR to read the text
        self._merge_ocr_words(screenshot, elements)

        # 3. Annotate a copy of the screenshot with the numbers
        som_image = self._annotate(screenshot.copy(), elements)
        som_b64 = self._image_to_b64(som_image)

        elapsed_ms = (time.perf_counter() - t_start) * 1000
        logger.info(
            "SoM: %d éléments (%d yolo, %d ocr) en %.0fms",
            len(elements),
            sum(1 for e in elements if e.source == "yolo"),
            sum(1 for e in elements if e.source == "ocr"),
            elapsed_ms,
        )

        return SomResult(
            elements=elements,
            som_image=som_image,
            som_image_b64=som_b64,
            width=width,
            height=height,
            analysis_time_ms=elapsed_ms,
        )

    def _detect_icons(self, screenshot: Image.Image) -> List[SomElement]:
        """Run YOLO and return one element per detected box (ids start at 0)."""
        if self._yolo is None:
            return []

        results = self._yolo.predict(
            source=screenshot,
            conf=self._YOLO_CONF,
            iou=self._YOLO_IOU,
            verbose=False,
        )
        boxes = results[0].boxes if results else None
        if boxes is None:
            return []

        elements: List[SomElement] = []
        for box in boxes:
            x1, y1, x2, y2 = (int(v) for v in box.xyxy[0].tolist())
            elements.append(
                self._make_element(
                    len(elements),
                    (x1, y1, x2, y2),
                    screenshot.size,
                    "yolo",
                    confidence=float(box.conf[0]),
                )
            )
        return elements

    def _merge_ocr_words(self, screenshot: Image.Image, elements: List[SomElement]) -> None:
        """Run docTR and merge the recognised words into ``elements`` in place.

        A word whose box intersects an element already collected (a YOLO box
        or an earlier OCR word) is not added; it becomes that element's label
        when the element has none yet. Every other word is appended as a new
        'ocr' element. OCR is best-effort: any error is logged and the
        elements collected so far are kept.
        """
        if self._ocr is None:
            return

        width, height = screenshot.size
        try:
            import numpy as np

            # Feed the pixels straight to the predictor: no temporary file to
            # clean up and no lossy JPEG round trip. np.array() makes a
            # writable copy of the image data.
            page_pixels = np.array(screenshot.convert("RGB"))
            ocr_output = self._ocr([page_pixels])

            words = (
                word
                for page in ocr_output.pages
                for block in page.blocks
                for line in block.lines
                for word in line.words
            )
            for word in words:
                text = word.value.strip()
                if len(text) < self._MIN_TEXT_LEN:
                    continue

                # docTR returns normalised (0-1) corner coordinates.
                (nx1, ny1), (nx2, ny2) = word.geometry
                bbox = (
                    int(nx1 * width),
                    int(ny1 * height),
                    int(nx2 * width),
                    int(ny2 * height),
                )

                hit = self._first_overlap(elements, bbox)
                if hit is not None:
                    # Overlap: use the text as label of the existing element.
                    if not hit.label:
                        hit.label = text
                    continue

                elements.append(
                    self._make_element(
                        len(elements),
                        bbox,
                        (width, height),
                        "ocr",
                        bbox_norm=(float(nx1), float(ny1), float(nx2), float(ny2)),
                        label=text,
                        confidence=float(word.confidence),
                    )
                )
        except Exception as e:
            logger.warning("SoM: Erreur OCR docTR: %s", e)

    @staticmethod
    def _first_overlap(
        elements: List[SomElement], bbox: Tuple[int, int, int, int]
    ) -> Optional[SomElement]:
        """Return the first element whose box strictly intersects ``bbox``.

        Boxes that merely touch along an edge do not count as overlapping.
        """
        x1, y1, x2, y2 = bbox
        for existing in elements:
            ex1, ey1, ex2, ey2 = existing.bbox
            if max(x1, ex1) < min(x2, ex2) and max(y1, ey1) < min(y2, ey2):
                return existing
        return None

    @staticmethod
    def _make_element(
        elem_id: int,
        bbox: Tuple[int, int, int, int],
        size: Tuple[int, int],
        source: str,
        *,
        bbox_norm: Optional[Tuple[float, float, float, float]] = None,
        label: str = "",
        confidence: float = 0.0,
    ) -> SomElement:
        """Build a SomElement, deriving the centre and normalised coordinates.

        ``bbox_norm`` is computed from the pixel box unless given explicitly.
        """
        x1, y1, x2, y2 = bbox
        width, height = size
        cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
        if bbox_norm is None:
            bbox_norm = (x1 / width, y1 / height, x2 / width, y2 / height)
        return SomElement(
            id=elem_id,
            bbox=bbox,
            bbox_norm=bbox_norm,
            center=(cx, cy),
            center_norm=(cx / width, cy / height),
            source=source,
            label=label,
            confidence=confidence,
        )

    @staticmethod
    def _annotate(image: Image.Image, elements: List[SomElement]) -> Image.Image:
        """Draw the boxes and SoM numbers on ``image`` (in place) and return it."""
        draw = ImageDraw.Draw(image)

        try:
            badge_font = ImageFont.truetype(_BADGE_FONT_PATH, 12)
        except (OSError, ImportError):
            # Font file missing or FreeType support unavailable.
            badge_font = ImageFont.load_default()

        badge_h = 18
        for elem in elements:
            x1, y1, x2, y2 = elem.bbox
            # Box colour depends on the source: red for YOLO, blue for OCR.
            outline = (255, 50, 50) if elem.source == "yolo" else (50, 50, 255)
            draw.rectangle([x1, y1, x2, y2], outline=outline, width=2)

            # Number badge (always red), placed just above the top-left corner
            # and clamped so it stays inside the image.
            number = str(elem.id)
            # Grow with the digit count so ids >= 100 do not overflow the badge.
            badge_w = max(24, 8 * len(number) + 6)
            badge_x = max(0, x1 - 2)
            badge_y = max(0, y1 - badge_h - 2)
            draw.rectangle(
                [badge_x, badge_y, badge_x + badge_w, badge_y + badge_h],
                fill=(220, 30, 30),
            )
            draw.text(
                (badge_x + 3, badge_y + 1),
                number,
                fill="white",
                font=badge_font,
            )

        return image

    @staticmethod
    def _image_to_b64(image: Image.Image, quality: int = 70) -> str:
        """Encode a PIL image as a base64 JPEG string.

        Images in a mode JPEG cannot store (e.g. RGBA, P, LA) are converted
        to RGB first instead of raising.
        """
        if image.mode not in _JPEG_MODES:
            image = image.convert("RGB")
        buf = io.BytesIO()
        image.save(buf, format="JPEG", quality=quality)
        return base64.b64encode(buf.getvalue()).decode()