# Note (contexte du commit) : Qwen2.5-VL occupe 9.8 GB de VRAM → plus de
# place pour YOLO. SomEngine passe en CPU (1.4s au lieu de 0.1s, acceptable
# car utilisé uniquement pendant le build_replay, pas le replay).
"""
|
|
Set-of-Mark Engine — Détection et numérotation des éléments UI.
|
|
|
|
Pipeline : YOLO (détection icônes) + docTR (OCR) + numérotation visuelle.
|
|
Le VLM (qwen3-vl) est utilisé ensuite pour l'identification sémantique.
|
|
|
|
Usage :
|
|
from core.detection.som_engine import SomEngine
|
|
|
|
engine = SomEngine()
|
|
result = engine.analyze(screenshot_pil)
|
|
# result.elements : liste d'éléments avec coordonnées
|
|
# result.som_image : screenshot avec numéros
|
|
# result.som_image_b64 : en base64
|
|
|
|
Architecture :
|
|
- YOLO v8 (icon_detect) : détecte les éléments interactifs (~15ms GPU)
|
|
- docTR : OCR pour lire le texte visible (~100ms GPU)
|
|
- Annotation : numérote chaque élément sur le screenshot
|
|
- Le VLM n'est PAS appelé ici (séparation détection/identification)
|
|
"""
|
|
|
|
from __future__ import annotations

import base64
import io
import logging
import os
import threading
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple

from PIL import Image, ImageDraw, ImageFont
logger = logging.getLogger(__name__)

# Location of OmniParser's YOLO icon-detection weights.
# Override with the SOM_YOLO_WEIGHTS environment variable.
_DEFAULT_YOLO_WEIGHTS = "/home/dom/ai/OmniParser/weights/icon_detect/model.pt"
_YOLO_WEIGHTS = Path(os.environ.get("SOM_YOLO_WEIGHTS", _DEFAULT_YOLO_WEIGHTS))
@dataclass
class SomElement:
    """One UI element that was detected and assigned a Set-of-Mark number."""
    id: int  # Set-of-Mark number (0, 1, 2, ...), unique within one analysis
    bbox: Tuple[int, int, int, int]  # Pixel coordinates (x1, y1, x2, y2)
    bbox_norm: Tuple[float, float, float, float]  # Same box, normalised to 0-1
    center: Tuple[int, int]  # Box centre in pixels
    center_norm: Tuple[float, float]  # Box centre normalised to 0-1
    source: str  # 'yolo' (icon detector) or 'ocr' (docTR word)
    label: str = ""  # OCR text or description (may be filled in later)
    confidence: float = 0.0  # Detector / OCR confidence score
@dataclass
class SomResult:
    """Outcome of one Set-of-Mark analysis pass."""
    elements: List[SomElement] = field(default_factory=list)
    som_image: Optional[Image.Image] = None  # Annotated screenshot
    som_image_b64: str = ""  # Annotated screenshot as base64 JPEG
    width: int = 0
    height: int = 0
    analysis_time_ms: float = 0.0

    def find_element_at(self, x: int, y: int, margin: int = 20) -> Optional[SomElement]:
        """Return the element closest to (x, y), or None.

        Only elements whose bounding box, expanded by *margin* pixels on
        every side, contains the point are considered; among those, the one
        whose centre is nearest wins (first one on ties).
        """
        def contains(elem: SomElement) -> bool:
            bx1, by1, bx2, by2 = elem.bbox
            return (bx1 - margin <= x <= bx2 + margin
                    and by1 - margin <= y <= by2 + margin)

        def centre_distance(elem: SomElement) -> float:
            cx, cy = elem.center
            return ((x - cx) ** 2 + (y - cy) ** 2) ** 0.5

        hits = [e for e in self.elements if contains(e)]
        return min(hits, key=centre_distance, default=None)

    def get_element_by_id(self, element_id: int) -> Optional[SomElement]:
        """Return the element carrying SoM number *element_id*, or None."""
        return next((e for e in self.elements if e.id == element_id), None)
class SomEngine:
    """Set-of-Mark engine: YOLO + docTR + visual numbering.

    Pipeline:
      - YOLO v8 (icon_detect): detects interactive elements (~15ms GPU)
      - docTR: OCR that reads the visible text (~100ms GPU)
      - Annotation: draws a numbered badge on every detected element

    The VLM is deliberately NOT called here — detection and semantic
    identification are kept separate.
    """

    def __init__(self, device: str = "cuda"):
        # Models are loaded lazily (_ensure_loaded) on the first analyze().
        self._device = device
        self._yolo = None    # ultralytics.YOLO, or None if weights are missing
        self._ocr = None     # docTR predictor, or None if docTR is unavailable
        self._loaded = False

    def _ensure_loaded(self):
        """Lazily load the detection models; no-op after the first call."""
        if self._loaded:
            return

        import time
        t = time.time()

        # YOLO (optional: skipped when the weights file does not exist)
        if _YOLO_WEIGHTS.is_file():
            from ultralytics import YOLO
            self._yolo = YOLO(str(_YOLO_WEIGHTS))
            self._yolo.to(self._device)
            logger.info("SoM: YOLO chargé sur %s", self._device)
        else:
            logger.warning("SoM: YOLO weights introuvable: %s", _YOLO_WEIGHTS)

        # docTR (optional: OCR is skipped when the import/download fails)
        try:
            from doctr.models import ocr_predictor
            self._ocr = ocr_predictor(
                det_arch="db_resnet50",
                reco_arch="crnn_vgg16_bn",
                pretrained=True,
            )
            if self._device == "cuda":
                self._ocr = self._ocr.cuda()
            logger.info("SoM: docTR chargé")
        except Exception as e:
            logger.warning("SoM: docTR non disponible: %s", e)

        self._loaded = True
        logger.info("SoM: modèles chargés en %.1fs", time.time() - t)

    def analyze(self, screenshot: Image.Image) -> SomResult:
        """Analyze a screenshot: detect every UI element and number it.

        Args:
            screenshot: PIL image of the screen to analyze.

        Returns:
            SomResult holding the detected elements and the annotated
            screenshot (PIL image + base64 JPEG).
        """
        import time
        t_start = time.time()

        self._ensure_loaded()

        W, H = screenshot.size

        # ── 1. YOLO: interactive elements ──
        elements = self._detect_yolo(screenshot, W, H)

        # ── 2. docTR: visible text (merged into overlapping YOLO boxes) ──
        self._detect_ocr(screenshot, W, H, elements)

        # ── 3. Draw the SoM numbers on a copy of the screenshot ──
        som_image = self._annotate(screenshot.copy(), elements)
        som_b64 = self._image_to_b64(som_image)

        elapsed_ms = (time.time() - t_start) * 1000
        logger.info(
            "SoM: %d éléments (%d yolo, %d ocr) en %.0fms",
            len(elements),
            sum(1 for e in elements if e.source == "yolo"),
            sum(1 for e in elements if e.source == "ocr"),
            elapsed_ms,
        )

        return SomResult(
            elements=elements,
            som_image=som_image,
            som_image_b64=som_b64,
            width=W,
            height=H,
            analysis_time_ms=elapsed_ms,
        )

    def _detect_yolo(self, screenshot: Image.Image, W: int, H: int) -> List[SomElement]:
        """Run YOLO icon detection and return the resulting elements."""
        elements: List[SomElement] = []
        if self._yolo is None:
            return elements
        results = self._yolo.predict(
            source=screenshot, conf=0.15, iou=0.5, verbose=False,
        )
        boxes = results[0].boxes if results else []
        for box in boxes:
            x1, y1, x2, y2 = [int(v) for v in box.xyxy[0].tolist()]
            cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
            elements.append(SomElement(
                id=len(elements),  # ids are assigned in append order, from 0
                bbox=(x1, y1, x2, y2),
                bbox_norm=(x1 / W, y1 / H, x2 / W, y2 / H),
                center=(cx, cy),
                center_norm=(cx / W, cy / H),
                source="yolo",
                confidence=float(box.conf[0]),
            ))
        return elements

    def _detect_ocr(self, screenshot: Image.Image, W: int, H: int,
                    elements: List[SomElement]) -> None:
        """Run docTR OCR and append word elements to *elements* in place.

        Words overlapping an existing (YOLO) element enrich its label
        instead of creating a duplicate entry.  Any OCR failure is logged
        and swallowed (best-effort: YOLO results stay usable without text).
        """
        if self._ocr is None:
            return
        try:
            from doctr.io import DocumentFile
            import tempfile

            # docTR consumes image files, so round-trip through a temp JPEG.
            # JPEG has no alpha channel: RGBA/palette screenshots must be
            # converted first (saving them directly raises OSError).
            rgb = screenshot if screenshot.mode == "RGB" else screenshot.convert("RGB")
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
                rgb.save(tmp, format="JPEG", quality=90)
                tmp_path = tmp.name
            try:
                doc = DocumentFile.from_images([tmp_path])
                result_ocr = self._ocr(doc)
            finally:
                os.unlink(tmp_path)

            for page in result_ocr.pages:
                for block in page.blocks:
                    for line in block.lines:
                        for word in line.words:
                            self._merge_word(word, W, H, elements)
        except Exception as e:
            logger.warning("SoM: Erreur OCR docTR: %s", e)

    @staticmethod
    def _merge_word(word, W: int, H: int, elements: List[SomElement]) -> None:
        """Turn one OCR word into a SomElement, or fold its text into the
        first overlapping existing element's label."""
        text = word.value.strip()
        if len(text) < 2:  # drop empty / one-character noise
            return

        # docTR returns coordinates normalised to [0, 1]
        (nx1, ny1), (nx2, ny2) = word.geometry
        x1, y1 = int(nx1 * W), int(ny1 * H)
        x2, y2 = int(nx2 * W), int(ny2 * H)
        cx, cy = (x1 + x2) // 2, (y1 + y2) // 2

        # Overlap with an existing element → use the text as its label
        # (first overlap wins) instead of adding a duplicate box.
        for existing in elements:
            ex1, ey1, ex2, ey2 = existing.bbox
            if max(x1, ex1) < min(x2, ex2) and max(y1, ey1) < min(y2, ey2):
                if not existing.label:
                    existing.label = text
                return

        elements.append(SomElement(
            id=len(elements),
            bbox=(x1, y1, x2, y2),
            bbox_norm=(nx1, ny1, nx2, ny2),
            center=(cx, cy),
            center_norm=(cx / W, cy / H),
            source="ocr",
            label=text,
            confidence=word.confidence,
        ))

    @staticmethod
    def _annotate(image: Image.Image, elements: List[SomElement]) -> Image.Image:
        """Draw each element's bounding box and numbered badge on *image*."""
        draw = ImageDraw.Draw(image)

        # Badge font (fall back to PIL's built-in font if DejaVu is absent)
        try:
            font_small = ImageFont.truetype(
                "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
        except Exception:
            font_small = ImageFont.load_default()

        for elem in elements:
            x1, y1, x2, y2 = elem.bbox
            # Red boxes for YOLO detections, blue for OCR words
            color = (255, 50, 50) if elem.source == "yolo" else (50, 50, 255)
            draw.rectangle([x1, y1, x2, y2], outline=color, width=2)

            # Red numbered badge just above the box's top-left corner,
            # clamped so it never leaves the image
            badge_w, badge_h = 24, 18
            badge_x = max(0, x1 - 2)
            badge_y = max(0, y1 - badge_h - 2)
            draw.rectangle(
                [badge_x, badge_y, badge_x + badge_w, badge_y + badge_h],
                fill=(220, 30, 30),
            )
            draw.text(
                (badge_x + 3, badge_y + 1), str(elem.id),
                fill="white", font=font_small,
            )

        return image

    @staticmethod
    def _image_to_b64(image: Image.Image, quality: int = 70) -> str:
        """Encode a PIL image as a base64 JPEG string.

        JPEG cannot store alpha, so non-RGB images (RGBA screenshots,
        palette images) are converted to RGB first instead of raising
        OSError inside Pillow's JPEG encoder.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")
        buf = io.BytesIO()
        image.save(buf, format="JPEG", quality=quality)
        return base64.b64encode(buf.getvalue()).decode()
# ---------------------------------------------------------------------------
# Shared singleton (lazy-loaded, thread-safe)
# ---------------------------------------------------------------------------
# Created on demand by get_shared_engine(); guarded by _shared_lock.
_shared_engine: Optional[SomEngine] = None
_shared_lock = threading.Lock()  # plain top-level import beats __import__("threading")
def get_shared_engine(device: str = "cpu") -> Optional[SomEngine]:
    """Return the SomEngine singleton shared by every module.

    The engine is built lazily under a lock (double-checked), so *device*
    only matters on the very first successful call.  Returns None when the
    engine cannot be constructed.
    """
    global _shared_engine

    # Fast path: already built, no locking needed.
    if _shared_engine is not None:
        return _shared_engine

    with _shared_lock:
        # Re-check under the lock: another thread may have won the race.
        if _shared_engine is None:
            try:
                _shared_engine = SomEngine(device=device)
                logger.info("SomEngine singleton partagé initialisé")
            except Exception as exc:
                logger.warning("SomEngine non disponible : %s", exc)
                return None
    return _shared_engine