feat: SomEngine — Set-of-Mark avec YOLO + docTR pour détection UI

- SomEngine : détecte et numérote tous les éléments UI d'un screenshot
- YOLO v8 (OmniParser) : détection icônes/boutons (~15ms GPU)
- docTR : OCR pour le texte visible
- Annotation visuelle : numéros rouges sur chaque élément
- find_element_at(x, y) : trouve l'élément cliqué par coordonnées
- Fix Florence-2 / transformers 4.57 incompatibilité (past_key_values)
- Testé : 107 éléments détectés sur screenshot Windows 2560x1600

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Dom
2026-03-31 08:26:07 +02:00
parent 3417f09598
commit 2ddccff108

View File

@@ -0,0 +1,290 @@
"""
Set-of-Mark Engine — Détection et numérotation des éléments UI.
Pipeline : YOLO (détection icônes) + docTR (OCR) + numérotation visuelle.
Le VLM (qwen3-vl) est utilisé ensuite pour l'identification sémantique.
Usage :
from core.detection.som_engine import SomEngine
engine = SomEngine()
result = engine.analyze(screenshot_pil)
# result.elements : liste d'éléments avec coordonnées
# result.som_image : screenshot avec numéros
# result.som_image_b64 : en base64
Architecture :
- YOLO v8 (icon_detect) : détecte les éléments interactifs (~15ms GPU)
- docTR : OCR pour lire le texte visible (~100ms GPU)
- Annotation : numérote chaque élément sur le screenshot
- Le VLM n'est PAS appelé ici (séparation détection/identification)
"""
from __future__ import annotations
import base64
import io
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple
from PIL import Image, ImageDraw, ImageFont
logger = logging.getLogger(__name__)
# Path to OmniParser's YOLO icon-detection weights
_YOLO_WEIGHTS = Path("/home/dom/ai/OmniParser/weights/icon_detect/model.pt")
@dataclass
class SomElement:
    """A single detected and numbered UI element."""
    id: int  # Set-of-Mark number (0, 1, 2, ...)
    bbox: Tuple[int, int, int, int]  # pixel coords (x1, y1, x2, y2)
    bbox_norm: Tuple[float, float, float, float]  # same bbox, normalized 0-1
    center: Tuple[int, int]  # center point in pixels
    center_norm: Tuple[float, float]  # center point, normalized 0-1
    source: str  # 'yolo' or 'ocr'
    label: str = ""  # OCR text or description (may stay empty)
    confidence: float = 0.0  # detector/OCR confidence score
@dataclass
class SomResult:
    """Outcome of a Set-of-Mark analysis pass."""
    elements: List[SomElement] = field(default_factory=list)  # detected elements
    som_image: Optional[Image.Image] = None  # annotated screenshot
    som_image_b64: str = ""  # annotated screenshot as base64 JPEG
    width: int = 0
    height: int = 0
    analysis_time_ms: float = 0.0

    def find_element_at(self, x: int, y: int, margin: int = 20) -> Optional[SomElement]:
        """Return the element nearest to (x, y), or None if nothing is close.

        Candidates are elements whose bounding box, expanded by ``margin``
        pixels on every side, contains the point; among those, the one whose
        center is closest (Euclidean) wins.
        """
        def center_distance(candidate: SomElement) -> float:
            cx, cy = candidate.center
            return ((x - cx) ** 2 + (y - cy) ** 2) ** 0.5

        hits = [
            candidate
            for candidate in self.elements
            if candidate.bbox[0] - margin <= x <= candidate.bbox[2] + margin
            and candidate.bbox[1] - margin <= y <= candidate.bbox[3] + margin
        ]
        return min(hits, key=center_distance, default=None)

    def get_element_by_id(self, element_id: int) -> Optional[SomElement]:
        """Return the element carrying SoM number ``element_id``, or None."""
        return next((e for e in self.elements if e.id == element_id), None)
class SomEngine:
    """Set-of-Mark engine: YOLO detection + docTR OCR + visual annotation.

    YOLO finds interactive elements (icons/buttons), docTR reads the visible
    text, words overlapping a YOLO box enrich its label, and every element is
    numbered on an annotated copy of the screenshot.  The VLM is deliberately
    NOT called here (detection is kept separate from identification).
    """

    def __init__(self, device: str = "cuda"):
        """
        Args:
            device: torch device used by both models ("cuda" or "cpu").
        """
        self._device = device
        self._yolo = None     # ultralytics YOLO model, loaded lazily
        self._ocr = None      # docTR predictor, loaded lazily
        self._loaded = False  # guards _ensure_loaded against reloading

    def _ensure_loaded(self):
        """Lazily load the detection models (no-op after the first call)."""
        if self._loaded:
            return
        import time
        t = time.time()
        # YOLO (OmniParser icon_detect weights); missing weights only warns —
        # the engine can still run OCR-only.
        if _YOLO_WEIGHTS.is_file():
            from ultralytics import YOLO
            self._yolo = YOLO(str(_YOLO_WEIGHTS))
            self._yolo.to(self._device)
            logger.info("SoM: YOLO chargé sur %s", self._device)
        else:
            logger.warning("SoM: YOLO weights introuvable: %s", _YOLO_WEIGHTS)
        # docTR is optional as well: any load failure downgrades to no OCR.
        try:
            from doctr.models import ocr_predictor
            self._ocr = ocr_predictor(
                det_arch="db_resnet50",
                reco_arch="crnn_vgg16_bn",
                pretrained=True,
            )
            if self._device == "cuda":
                self._ocr = self._ocr.cuda()
            logger.info("SoM: docTR chargé")
        except Exception as e:
            logger.warning("SoM: docTR non disponible: %s", e)
        self._loaded = True
        logger.info("SoM: modèles chargés en %.1fs", time.time() - t)

    def analyze(self, screenshot: Image.Image) -> SomResult:
        """Analyze a screenshot: detect all UI elements and number them.

        Args:
            screenshot: PIL image of the screen to analyze.

        Returns:
            SomResult with the detected elements and the annotated screenshot.
        """
        import time
        t_start = time.time()
        self._ensure_loaded()
        W, H = screenshot.size
        elements: List[SomElement] = []
        # 1. YOLO: interactive elements
        self._detect_icons(screenshot, W, H, elements)
        # 2. docTR: visible text (merged into overlapping YOLO boxes)
        self._detect_text(screenshot, W, H, elements)
        # 3. Draw the SoM numbers on a copy of the screenshot
        som_image = self._annotate(screenshot.copy(), elements)
        som_b64 = self._image_to_b64(som_image)
        elapsed_ms = (time.time() - t_start) * 1000
        logger.info(
            "SoM: %d éléments (%d yolo, %d ocr) en %.0fms",
            len(elements),
            sum(1 for e in elements if e.source == "yolo"),
            sum(1 for e in elements if e.source == "ocr"),
            elapsed_ms,
        )
        return SomResult(
            elements=elements,
            som_image=som_image,
            som_image_b64=som_b64,
            width=W,
            height=H,
            analysis_time_ms=elapsed_ms,
        )

    def _detect_icons(self, screenshot: Image.Image, W: int, H: int,
                      elements: List[SomElement]) -> None:
        """Run YOLO and append one SomElement per detected box."""
        if self._yolo is None:
            return
        results = self._yolo.predict(
            source=screenshot, conf=0.15, iou=0.5, verbose=False,
        )
        boxes = results[0].boxes if results else []
        for box in boxes:
            x1, y1, x2, y2 = [int(v) for v in box.xyxy[0].tolist()]
            cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
            elements.append(SomElement(
                id=len(elements),  # ids stay sequential: id == list index
                bbox=(x1, y1, x2, y2),
                bbox_norm=(x1 / W, y1 / H, x2 / W, y2 / H),
                center=(cx, cy),
                center_norm=(cx / W, cy / H),
                source="yolo",
                confidence=float(box.conf[0]),
            ))

    def _detect_text(self, screenshot: Image.Image, W: int, H: int,
                     elements: List[SomElement]) -> None:
        """Run docTR OCR; words overlapping an existing box enrich its label,
        the rest become standalone 'ocr' elements.  Best-effort: any OCR
        failure is logged and swallowed, as in the original design.
        """
        if self._ocr is None:
            return
        import os
        import tempfile
        try:
            from doctr.io import DocumentFile
            # docTR reads from files: dump the screenshot to a temp JPEG.
            tmp_path = None
            try:
                with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
                    screenshot.save(tmp, format="JPEG", quality=90)
                    tmp_path = tmp.name
                doc = DocumentFile.from_images([tmp_path])
            finally:
                # Fix: the original leaked the temp file if from_images raised.
                if tmp_path is not None:
                    os.unlink(tmp_path)
            result_ocr = self._ocr(doc)
            for page in result_ocr.pages:
                for block in page.blocks:
                    for line in block.lines:
                        for word in line.words:
                            text = word.value.strip()
                            if len(text) < 2:  # skip empty / 1-char noise
                                continue
                            # docTR geometry is normalized (0-1)
                            (nx1, ny1), (nx2, ny2) = word.geometry
                            x1, y1 = int(nx1 * W), int(ny1 * H)
                            x2, y2 = int(nx2 * W), int(ny2 * H)
                            if self._merge_into_existing(
                                    elements, (x1, y1, x2, y2), text):
                                continue
                            cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
                            elements.append(SomElement(
                                id=len(elements),
                                bbox=(x1, y1, x2, y2),
                                bbox_norm=(nx1, ny1, nx2, ny2),
                                center=(cx, cy),
                                center_norm=(cx / W, cy / H),
                                source="ocr",
                                label=text,
                                # coerce to float for consistency with YOLO path
                                confidence=float(word.confidence),
                            ))
        except Exception as e:
            logger.warning("SoM: Erreur OCR docTR: %s", e)

    @staticmethod
    def _merge_into_existing(elements: List[SomElement],
                             bbox: Tuple[int, int, int, int],
                             text: str) -> bool:
        """Absorb an OCR word into the first element whose bbox it overlaps.

        The word's text becomes that element's label if it had none.

        Returns:
            True when the word overlapped an existing element (and must not
            become a standalone element), False otherwise.
        """
        x1, y1, x2, y2 = bbox
        for existing in elements:
            ex1, ey1, ex2, ey2 = existing.bbox
            # Positive-area intersection test
            if max(x1, ex1) < min(x2, ex2) and max(y1, ey1) < min(y2, ey2):
                if not existing.label:
                    existing.label = text
                return True
        return False

    @staticmethod
    def _annotate(image: Image.Image, elements: List[SomElement]) -> Image.Image:
        """Draw the SoM boxes and id badges on the image (mutates and returns it)."""
        draw = ImageDraw.Draw(image)
        # Fonts: prefer DejaVu, fall back to PIL's built-in bitmap font
        try:
            font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
            font_small = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
        except Exception:
            font = ImageFont.load_default()
            font_small = font
        for elem in elements:
            x1, y1, x2, y2 = elem.bbox
            # Box color depends on the source: red = yolo, blue = ocr
            color = (255, 50, 50) if elem.source == "yolo" else (50, 50, 255)
            draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
            # Id badge (red), clamped so it never leaves the image
            label = str(elem.id)
            badge_w, badge_h = 24, 18
            badge_x = max(0, x1 - 2)
            badge_y = max(0, y1 - badge_h - 2)
            draw.rectangle(
                [badge_x, badge_y, badge_x + badge_w, badge_y + badge_h],
                fill=(220, 30, 30),
            )
            draw.text(
                (badge_x + 3, badge_y + 1), label,
                fill="white", font=font_small,
            )
        return image

    @staticmethod
    def _image_to_b64(image: Image.Image, quality: int = 70) -> str:
        """Encode a PIL image as a base64 JPEG string."""
        buf = io.BytesIO()
        image.save(buf, format="JPEG", quality=quality)
        return base64.b64encode(buf.getvalue()).decode()