feat: SomEngine — Set-of-Mark avec YOLO + docTR pour détection UI
- SomEngine : détecte et numérote tous les éléments UI d'un screenshot - YOLO v8 (OmniParser) : détection icônes/boutons (~15ms GPU) - docTR : OCR pour le texte visible - Annotation visuelle : numéros rouges sur chaque élément - find_element_at(x, y) : trouve l'élément cliqué par coordonnées - Fix Florence-2 / transformers 4.57 incompatibilité (past_key_values) - Testé : 107 éléments détectés sur screenshot Windows 2560x1600 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
290
core/detection/som_engine.py
Normal file
290
core/detection/som_engine.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""
|
||||
Set-of-Mark Engine — Détection et numérotation des éléments UI.
|
||||
|
||||
Pipeline : YOLO (détection icônes) + docTR (OCR) + numérotation visuelle.
|
||||
Le VLM (qwen3-vl) est utilisé ensuite pour l'identification sémantique.
|
||||
|
||||
Usage :
|
||||
from core.detection.som_engine import SomEngine
|
||||
|
||||
engine = SomEngine()
|
||||
result = engine.analyze(screenshot_pil)
|
||||
# result.elements : liste d'éléments avec coordonnées
|
||||
# result.som_image : screenshot avec numéros
|
||||
# result.som_image_b64 : en base64
|
||||
|
||||
Architecture :
|
||||
- YOLO v8 (icon_detect) : détecte les éléments interactifs (~15ms GPU)
|
||||
- docTR : OCR pour lire le texte visible (~100ms GPU)
|
||||
- Annotation : numérote chaque élément sur le screenshot
|
||||
- Le VLM n'est PAS appelé ici (séparation détection/identification)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
logger = logging.getLogger(__name__)

# Path to OmniParser's YOLO icon-detect weights.  The default is the local
# dev-machine location; override with the SOM_YOLO_WEIGHTS environment
# variable so the engine is usable outside that machine.
_YOLO_WEIGHTS = Path(
    os.environ.get(
        "SOM_YOLO_WEIGHTS",
        "/home/dom/ai/OmniParser/weights/icon_detect/model.pt",
    )
)
|
||||
|
||||
|
||||
@dataclass
class SomElement:
    """A single detected and numbered UI element."""
    id: int  # Set-of-Mark number (0, 1, 2, ...)
    bbox: Tuple[int, int, int, int]  # pixels (x1, y1, x2, y2)
    bbox_norm: Tuple[float, float, float, float]  # normalized to 0-1
    center: Tuple[int, int]  # center, in pixels
    center_norm: Tuple[float, float]  # center, normalized to 0-1
    source: str  # 'yolo' or 'ocr'
    label: str = ""  # OCR text or description
    confidence: float = 0.0  # detector / OCR confidence score
|
||||
|
||||
|
||||
@dataclass
class SomResult:
    """Outcome of a Set-of-Mark analysis pass."""
    elements: List[SomElement] = field(default_factory=list)
    som_image: Optional[Image.Image] = None  # annotated screenshot
    som_image_b64: str = ""  # annotated screenshot as base64 JPEG
    width: int = 0
    height: int = 0
    analysis_time_ms: float = 0.0

    def find_element_at(self, x: int, y: int, margin: int = 20) -> Optional[SomElement]:
        """Return the element nearest to (x, y), or None if nothing is close.

        An element is a candidate when the point falls inside its bbox
        expanded by *margin* pixels on every side; among candidates the one
        whose center is closest to the point wins.
        """
        candidates = []
        for candidate in self.elements:
            left, top, right, bottom = candidate.bbox
            within_x = (left - margin) <= x <= (right + margin)
            within_y = (top - margin) <= y <= (bottom + margin)
            if within_x and within_y:
                cx, cy = candidate.center
                distance = ((x - cx) ** 2 + (y - cy) ** 2) ** 0.5
                candidates.append((distance, candidate))
        if not candidates:
            return None
        return min(candidates, key=lambda pair: pair[0])[1]

    def get_element_by_id(self, element_id: int) -> Optional[SomElement]:
        """Return the element carrying the given SoM number, or None."""
        return next((e for e in self.elements if e.id == element_id), None)
|
||||
|
||||
|
||||
class SomEngine:
    """Set-of-Mark engine: YOLO detection + docTR OCR + numbered annotation.

    Detection only — the VLM is deliberately NOT called here (detection and
    semantic identification are kept separate).
    """

    def __init__(self, device: str = "cuda"):
        # Models are loaded lazily on the first analyze() call.
        self._device = device
        self._yolo = None
        self._ocr = None
        self._loaded = False

    def _ensure_loaded(self):
        """Lazily load the YOLO and docTR models (no-op after the first call)."""
        if self._loaded:
            return

        import time
        t = time.time()

        # YOLO (OmniParser icon_detect weights) — skipped if weights missing.
        if _YOLO_WEIGHTS.is_file():
            from ultralytics import YOLO
            self._yolo = YOLO(str(_YOLO_WEIGHTS))
            self._yolo.to(self._device)
            logger.info("SoM: YOLO chargé sur %s", self._device)
        else:
            logger.warning("SoM: YOLO weights introuvable: %s", _YOLO_WEIGHTS)

        # docTR OCR — optional: the engine still works without text reading.
        try:
            from doctr.models import ocr_predictor
            self._ocr = ocr_predictor(
                det_arch="db_resnet50",
                reco_arch="crnn_vgg16_bn",
                pretrained=True,
            )
            if self._device == "cuda":
                self._ocr = self._ocr.cuda()
            logger.info("SoM: docTR chargé")
        except Exception as e:
            logger.warning("SoM: docTR non disponible: %s", e)

        self._loaded = True
        logger.info("SoM: modèles chargés en %.1fs", time.time() - t)

    def analyze(self, screenshot: Image.Image) -> SomResult:
        """Analyze a screenshot: detect all UI elements and number them.

        Args:
            screenshot: full-screen capture as a PIL image.

        Returns:
            SomResult with the detected elements and the annotated screenshot.
        """
        import time
        t_start = time.time()

        self._ensure_loaded()

        W, H = screenshot.size
        elements: List[SomElement] = []

        # ── 1. YOLO: interactive elements (icons, buttons) ──
        self._detect_yolo(screenshot, W, H, elements)

        # ── 2. docTR: visible text (overlapping words enrich YOLO labels) ──
        self._detect_text(screenshot, W, H, elements)

        # ── 3. Draw the SoM numbers on a copy of the screenshot ──
        som_image = self._annotate(screenshot.copy(), elements)
        som_b64 = self._image_to_b64(som_image)

        elapsed_ms = (time.time() - t_start) * 1000
        logger.info(
            "SoM: %d éléments (%d yolo, %d ocr) en %.0fms",
            len(elements),
            sum(1 for e in elements if e.source == "yolo"),
            sum(1 for e in elements if e.source == "ocr"),
            elapsed_ms,
        )

        return SomResult(
            elements=elements,
            som_image=som_image,
            som_image_b64=som_b64,
            width=W,
            height=H,
            analysis_time_ms=elapsed_ms,
        )

    def _detect_yolo(self, screenshot: Image.Image, W: int, H: int,
                     elements: List[SomElement]) -> None:
        """Append YOLO detections to *elements* (no-op if YOLO is unavailable)."""
        if self._yolo is None:
            return
        results = self._yolo.predict(
            source=screenshot, conf=0.15, iou=0.5, verbose=False,
        )
        boxes = results[0].boxes if results else []
        for box in boxes:
            x1, y1, x2, y2 = [int(v) for v in box.xyxy[0].tolist()]
            cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
            elements.append(SomElement(
                # ids stay equal to the list index across both passes
                id=len(elements),
                bbox=(x1, y1, x2, y2),
                bbox_norm=(x1 / W, y1 / H, x2 / W, y2 / H),
                center=(cx, cy),
                center_norm=(cx / W, cy / H),
                source="yolo",
                confidence=float(box.conf[0]),
            ))

    def _detect_text(self, screenshot: Image.Image, W: int, H: int,
                     elements: List[SomElement]) -> None:
        """Run docTR OCR and merge the recognized words into *elements*.

        Best-effort: any OCR failure is logged and the YOLO elements are kept.
        No-op if docTR is unavailable.
        """
        if self._ocr is None:
            return
        try:
            import os
            import tempfile
            from doctr.io import DocumentFile

            # docTR wants a file path; write a temporary JPEG.
            # JPEG has no alpha channel, so drop it first if present
            # (screen captures are frequently RGBA).
            rgb = screenshot if screenshot.mode == "RGB" else screenshot.convert("RGB")
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
                rgb.save(tmp, format="JPEG", quality=90)
                tmp_path = tmp.name
            try:
                doc = DocumentFile.from_images([tmp_path])
            finally:
                # Always remove the temp file, even if loading raised.
                os.unlink(tmp_path)
            result_ocr = self._ocr(doc)

            for page in result_ocr.pages:
                for block in page.blocks:
                    for line in block.lines:
                        for word in line.words:
                            self._merge_word(word, W, H, elements)
        except Exception as e:
            logger.warning("SoM: Erreur OCR docTR: %s", e)

    @staticmethod
    def _merge_word(word, W: int, H: int, elements: List[SomElement]) -> None:
        """Merge one OCR word into *elements*.

        A word intersecting an existing box enriches that box's label (if
        still empty) instead of becoming its own element; otherwise a new
        'ocr' element is appended.
        """
        text = word.value.strip()
        if not text or len(text) < 2:
            return  # skip empty / single-character noise
        # docTR returns coordinates normalized to 0-1.
        (nx1, ny1), (nx2, ny2) = word.geometry
        x1, y1 = int(nx1 * W), int(ny1 * H)
        x2, y2 = int(nx2 * W), int(ny2 * H)

        for existing in elements:
            ex1, ey1, ex2, ey2 = existing.bbox
            # Any non-empty intersection counts as an overlap.
            if max(x1, ex1) < min(x2, ex2) and max(y1, ey1) < min(y2, ey2):
                if not existing.label:
                    existing.label = text  # enrich the detected element
                return

        cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
        elements.append(SomElement(
            id=len(elements),
            bbox=(x1, y1, x2, y2),
            bbox_norm=(nx1, ny1, nx2, ny2),
            center=(cx, cy),
            center_norm=(cx / W, cy / H),
            source="ocr",
            label=text,
            confidence=word.confidence,
        ))

    @staticmethod
    def _annotate(image: Image.Image, elements: List[SomElement]) -> Image.Image:
        """Draw a colored box and a numbered badge for each element (in place)."""
        draw = ImageDraw.Draw(image)

        # Font: DejaVu if installed, PIL's built-in bitmap font otherwise.
        try:
            font_small = ImageFont.truetype(
                "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
        except Exception:
            font_small = ImageFont.load_default()

        for elem in elements:
            x1, y1, x2, y2 = elem.bbox
            # Red outline for YOLO detections, blue for OCR words.
            color = (255, 50, 50) if elem.source == "yolo" else (50, 50, 255)
            draw.rectangle([x1, y1, x2, y2], outline=color, width=2)

            # Number badge (red), just above the box's top-left corner,
            # clamped so it never leaves the image.
            label = str(elem.id)
            badge_w, badge_h = 24, 18
            badge_x = max(0, x1 - 2)
            badge_y = max(0, y1 - badge_h - 2)
            draw.rectangle(
                [badge_x, badge_y, badge_x + badge_w, badge_y + badge_h],
                fill=(220, 30, 30),
            )
            draw.text(
                (badge_x + 3, badge_y + 1), label,
                fill="white", font=font_small,
            )

        return image

    @staticmethod
    def _image_to_b64(image: Image.Image, quality: int = 70) -> str:
        """Encode a PIL image as a base64 JPEG string."""
        # JPEG cannot hold an alpha channel; normalize the mode first.
        if image.mode != "RGB":
            image = image.convert("RGB")
        buf = io.BytesIO()
        image.save(buf, format="JPEG", quality=quality)
        return base64.b64encode(buf.getvalue()).decode()
|
||||
Reference in New Issue
Block a user