feat(grounding): centralized pipeline + transformers UI-TARS server + dead-code cleanup

Complete grounding architecture:
- core/grounding/server.py: FastAPI server (port 8200) running UI-TARS-1.5-7B in 4-bit NF4,
  in a separate process with its own CUDA context (fixes the Flask/CUDA crash)
- core/grounding/pipeline.py: cascade orchestrator template → OCR → UI-TARS → static
- core/grounding/template_matcher.py: centralized TemplateMatcher (replaces 5 duplicated copies)
- core/grounding/ui_tars_grounder.py: HTTP client for the grounding server
- core/grounding/target.py: GroundingTarget + GroundingResult

ORA changes:
- _act_click(): single screen capture sent to the grounding server
- VLM pre-check skipped for ui_tars (redundant, and Ollama no longer has VRAM headroom)
- verify_level='none' by default (OCR title verification planned for Phase 2)
- Detection of negative UI-TARS answers ("I don't see it" → OCR fallback)

Cleanup:
- 9 dead files archived in _archive/ (~6300 lines removed)
- 21 tests added for TemplateMatcher

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
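Editor's note — to make the process split concrete, a minimal sketch of querying the two server endpoints directly with requests. It assumes the server was already started on localhost:8200 (the default); the target strings are illustrative:

    # Minimal sketch: query the grounding server over HTTP.
    # Assumes: .venv/bin/python3 -m core.grounding.server is running.
    import requests

    health = requests.get("http://localhost:8200/health", timeout=3).json()
    print(health["status"], health["model_loaded"])  # "ok" True once UI-TARS is loaded

    resp = requests.post("http://localhost:8200/ground", json={
        "target_text": "Valider",                        # visible text of the element
        "target_description": "green button at the bottom",
        "image_b64": "",                                 # empty -> the server captures the screen
    }, timeout=30)
    if resp.status_code == 200:
        data = resp.json()
        # x/y may be null when the model reports it cannot see the element
        print(data["x"], data["y"], data["confidence"])  # screen-pixel coordinates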
core/grounding/__init__.py — Normal file (20 lines)
@@ -0,0 +1,20 @@
# core/grounding — UI element localization module
#
# Centralizes the visual grounding methods: template matching,
# OCR, VLM, etc. Each method produces a uniform GroundingResult.
#
# The grounding server (server.py) runs in a separate process
# on port 8200. The HTTP client (UITarsGrounder) calls it over HTTP.
# The pipeline (GroundingPipeline) orchestrates template → OCR → UI-TARS → static.

from core.grounding.template_matcher import TemplateMatcher, MatchResult
from core.grounding.target import GroundingTarget, GroundingResult
from core.grounding.ui_tars_grounder import UITarsGrounder
from core.grounding.pipeline import GroundingPipeline

__all__ = [
    'TemplateMatcher', 'MatchResult',
    'GroundingTarget', 'GroundingResult',
    'UITarsGrounder',
    'GroundingPipeline',
]
core/grounding/pipeline.py — Normal file (190 lines)
@@ -0,0 +1,190 @@
"""
core/grounding/pipeline.py — Cascading grounding pipeline

Orchestrates the localization methods in order:
1. Template matching (TemplateMatcher, local, ~80ms)
2. OCR (docTR via input_handler, local, ~1s)
3. UI-TARS (HTTP to the grounding server, ~3s)
4. Static fallback (original workflow coordinates)

Each method is tried in order; as soon as one succeeds, its result is
returned. This balances speed (template) against robustness (UI-TARS
for elements whose position or appearance has changed).

Usage:
    from core.grounding.pipeline import GroundingPipeline
    from core.grounding.target import GroundingTarget

    pipeline = GroundingPipeline()
    result = pipeline.locate(GroundingTarget(
        text="Valider",
        description="green button at the bottom",
        template_b64=screenshot_b64,
        original_bbox={"x": 100, "y": 200, "width": 80, "height": 30},
    ))
    if result:
        print(f"Found at ({result.x}, {result.y}) via {result.method}")
"""

from __future__ import annotations

import time
from typing import Optional

from core.grounding.target import GroundingTarget, GroundingResult


class GroundingPipeline:
    """Cascading localization pipeline: template -> OCR -> UI-TARS -> static."""

    def __init__(self, template_threshold: float = 0.75, enable_uitars: bool = True):
        self.template_threshold = template_threshold
        self.enable_uitars = enable_uitars

    def locate(self, target: GroundingTarget) -> Optional[GroundingResult]:
        """Locate a UI element by trying the methods in cascade.

        Args:
            target: description of the element to locate

        Returns:
            GroundingResult, or None if no method finds the element
        """
        t0 = time.time()

        # --- Method 1: template matching (~80ms) ---
        result = self._try_template(target)
        if result:
            print(f"[GroundingPipeline] Located via {result.method} in "
                  f"{(time.time() - t0) * 1000:.0f}ms")
            return result

        # --- Method 2: text OCR (~1s) ---
        result = self._try_ocr(target)
        if result:
            print(f"[GroundingPipeline] Located via {result.method} in "
                  f"{(time.time() - t0) * 1000:.0f}ms")
            return result

        # --- Method 3: UI-TARS via the HTTP server (~3s) ---
        if self.enable_uitars:
            result = self._try_uitars(target)
            if result:
                print(f"[GroundingPipeline] Located via {result.method} in "
                      f"{(time.time() - t0) * 1000:.0f}ms")
                return result

        # --- Method 4: static fallback ---
        result = self._try_static(target)
        if result:
            print(f"[GroundingPipeline] Located via {result.method} in "
                  f"{(time.time() - t0) * 1000:.0f}ms")
            return result

        print(f"[GroundingPipeline] FAILURE: '{target.text}' not found "
              f"(all methods exhausted, {(time.time() - t0) * 1000:.0f}ms)")
        return None

    # ------------------------------------------------------------------
    # Individual methods
    # ------------------------------------------------------------------

    def _try_template(self, target: GroundingTarget) -> Optional[GroundingResult]:
        """Template matching — fast and exact, but sensitive to visual changes."""
        if not target.template_b64:
            return None

        try:
            from core.grounding.template_matcher import TemplateMatcher
            matcher = TemplateMatcher(threshold=self.template_threshold)
            match = matcher.match_screen(anchor_b64=target.template_b64)
            if match:
                print(f"[GroundingPipeline/template] score={match.score:.3f} "
                      f"pos=({match.x},{match.y}) ({match.time_ms:.0f}ms)")
                return GroundingResult(
                    x=match.x,
                    y=match.y,
                    method='template',
                    confidence=match.score,
                    time_ms=match.time_ms,
                )
            else:
                diag = matcher.match_screen_diagnostic(anchor_b64=target.template_b64)
                print(f"[GroundingPipeline/template] no match — best={diag}")
        except Exception as e:
            print(f"[GroundingPipeline/template] ERROR: {e}")

        return None

    def _try_ocr(self, target: GroundingTarget) -> Optional[GroundingResult]:
        """OCR: look for the target text on screen via docTR."""
        if not target.text:
            return None

        try:
            from core.execution.input_handler import _grounding_ocr
            bbox = target.original_bbox if target.original_bbox else None
            result = _grounding_ocr(target.text, anchor_bbox=bbox)
            if result:
                print(f"[GroundingPipeline/OCR] '{target.text}' -> ({result['x']}, {result['y']})")
                return GroundingResult(
                    x=result['x'],
                    y=result['y'],
                    method='ocr',
                    confidence=result.get('confidence', 0.80),
                    time_ms=result.get('time_ms', 0),
                )
            else:
                print(f"[GroundingPipeline/OCR] '{target.text}' not found")
        except Exception as e:
            print(f"[GroundingPipeline/OCR] ERROR: {e}")

        return None

    def _try_uitars(self, target: GroundingTarget) -> Optional[GroundingResult]:
        """UI-TARS via the HTTP server — robust, handles layout changes."""
        if not target.text and not target.description:
            return None

        try:
            from core.grounding.ui_tars_grounder import UITarsGrounder
            grounder = UITarsGrounder.get_instance()
            result = grounder.ground(
                target_text=target.text,
                target_description=target.description,
            )
            if result:
                print(f"[GroundingPipeline/UI-TARS] ({result.x}, {result.y}) "
                      f"conf={result.confidence:.2f} ({result.time_ms:.0f}ms)")
                return result
            else:
                print("[GroundingPipeline/UI-TARS] no result")
        except Exception as e:
            print(f"[GroundingPipeline/UI-TARS] ERROR: {e}")

        return None

    def _try_static(self, target: GroundingTarget) -> Optional[GroundingResult]:
        """Fallback: original workflow coordinates (center of the bounding box)."""
        bbox = target.original_bbox
        if not bbox:
            return None

        w = bbox.get('width', 0)
        h = bbox.get('height', 0)
        if not w or not h:
            return None

        x = int(bbox.get('x', 0) + w / 2)
        y = int(bbox.get('y', 0) + h / 2)

        print(f"[GroundingPipeline/static] fallback ({x}, {y}) "
              f"from bbox {bbox}")

        return GroundingResult(
            x=x,
            y=y,
            method='static_fallback',
            confidence=0.30,
            time_ms=0.0,
        )
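Editor's note — a minimal usage sketch of the cascade with the GPU stage turned off, e.g. on a machine without the grounding server. The constructor flags come from the diff above; the target values are illustrative:

    # Illustrative: template stage is skipped (no template_b64), OCR runs,
    # UI-TARS is disabled, and the static fallback catches the rest.
    from core.grounding.pipeline import GroundingPipeline
    from core.grounding.target import GroundingTarget

    pipeline = GroundingPipeline(template_threshold=0.8, enable_uitars=False)
    result = pipeline.locate(GroundingTarget(
        text="Valider",                                                   # searched by the OCR stage
        original_bbox={"x": 100, "y": 200, "width": 80, "height": 30},    # static fallback
    ))
    print(result)  # GroundingResult(...) or None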
core/grounding/server.py — Normal file (433 lines)
@@ -0,0 +1,433 @@
"""
core/grounding/server.py — FastAPI visual grounding server (port 8200)

Loads UI-TARS-1.5-7B in 4-bit NF4 in its own Python process with its own
CUDA context. The VWB Flask backend (port 5002) and the ORA loop call this
server over HTTP instead of loading the model in-process.

Launch:
    .venv/bin/python3 -m core.grounding.server

Endpoints:
    GET  /health — checks that the model is loaded
    POST /ground — locates a UI element on a screenshot
"""

import base64
import gc
import io
import math
import os
import re
import time
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

PORT = int(os.environ.get("GROUNDING_PORT", 8200))
MODEL_ID = "ByteDance-Seed/UI-TARS-1.5-7B"
MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28

# ---------------------------------------------------------------------------
# Smart resize — identical to /tmp/test_uitars.py
# ---------------------------------------------------------------------------

def _smart_resize(height: int, width: int, factor: int = 28,
                  min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS):
    """UI-TARS smart resize (same defaults as the validated test script)."""
    h_bar = max(factor, round(height / factor) * factor)
    w_bar = max(factor, round(width / factor) * factor)
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar


# ---------------------------------------------------------------------------
# Official UI-TARS prompt — identical to /tmp/test_uitars.py
# ---------------------------------------------------------------------------

_GROUNDING_PROMPT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.

## Output Format

Thought: ...
Action: ...


## Action Space
click(start_box='(x1, y1)')


## User Instruction
{instruction}"""


# ---------------------------------------------------------------------------
# Model singleton
# ---------------------------------------------------------------------------

_model = None
_processor = None
_model_loaded = False


def _evict_ollama_models():
    """Evict Ollama models from VRAM before loading UI-TARS."""
    try:
        import requests
        try:
            ps_resp = requests.get('http://localhost:11434/api/ps', timeout=3)
            if ps_resp.status_code == 200:
                loaded = ps_resp.json().get('models', [])
                model_names = [m.get('name', '') for m in loaded if m.get('name')]
            else:
                model_names = []
        except Exception:
            model_names = []

        if not model_names:
            print("[grounding-server] No Ollama model in VRAM")
            return

        for model_name in model_names:
            try:
                requests.post(
                    'http://localhost:11434/api/generate',
                    json={'model': model_name, 'keep_alive': '0'},
                    timeout=5,
                )
                print(f"[grounding-server] Ollama: evicting '{model_name}'")
            except Exception:
                pass

        time.sleep(1.0)
        print("[grounding-server] Ollama models released")
    except ImportError:
        print("[grounding-server] requests unavailable, skipping Ollama eviction")


def _load_model():
    """Load UI-TARS-1.5-7B in 4-bit NF4 — code identical to /tmp/test_uitars.py."""
    global _model, _processor, _model_loaded

    if _model_loaded:
        return

    print("=" * 60)
    print(f"[grounding-server] Loading {MODEL_ID}")
    print("=" * 60)

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA unavailable — the grounding server requires a GPU")

    # Free the Ollama VRAM
    _evict_ollama_models()

    torch.cuda.empty_cache()
    gc.collect()

    from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )

    t0 = time.time()
    _model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        device_map="auto",
    )
    _model.eval()

    _processor = AutoProcessor.from_pretrained(
        MODEL_ID,
        min_pixels=MIN_PIXELS,
        max_pixels=MAX_PIXELS,
    )

    _model_loaded = True
    load_time = time.time() - t0
    alloc = torch.cuda.memory_allocated() / 1024**3
    peak = torch.cuda.max_memory_allocated() / 1024**3
    print(f"[grounding-server] Model loaded in {load_time:.1f}s | "
          f"VRAM: {alloc:.2f} GB (peak: {peak:.2f} GB)")


def _capture_screen():
    """Capture the full screen via mss. Returns a PIL Image or None."""
    try:
        import mss as mss_lib
        from PIL import Image
        with mss_lib.mss() as sct:
            mon = sct.monitors[0]
            grab = sct.grab(mon)
            return Image.frombytes('RGB', grab.size, grab.bgra, 'raw', 'BGRX')
    except Exception as e:
        print(f"[grounding-server] Screen capture error: {e}")
        return None


def _parse_coordinates(raw: str, orig_w: int, orig_h: int,
                       resized_w: int, resized_h: int):
    """Parse the model's coordinates — identical to /tmp/test_uitars.py.

    Returns (px, py, method_detail, confidence) or None.
    """
    cx, cy = None, None

    # Format 1: <point>x y</point>
    pm = re.search(r'<point>\s*(\d+)\s+(\d+)\s*</point>', raw)
    if pm:
        cx, cy = int(pm.group(1)), int(pm.group(2))

    # Format 2: start_box='(x, y)'
    if cx is None:
        bm = re.search(r"start_box=\s*['\"]?\((\d+)\s*,\s*(\d+)\)", raw)
        if bm:
            cx, cy = int(bm.group(1)), int(bm.group(2))

    # Format 3: fallback x, y
    if cx is None:
        fm = re.search(r'(\d+)\s*,\s*(\d+)', raw)
        if fm:
            cx, cy = int(fm.group(1)), int(fm.group(2))

    if cx is None or cy is None:
        return None

    # Two candidate interpretations of the model output:
    # A) coordinates in the resized-image space
    px_r = int(cx / resized_w * orig_w)
    py_r = int(cy / resized_h * orig_h)
    # B) coordinates normalized to 0-1000
    px_1k = int(cx / 1000 * orig_w)
    py_1k = int(cy / 1000 * orig_h)

    # Heuristic from the validated script: if the coords fit within the
    # resized image and map onto the screen, prefer the resize-space
    # interpretation (UI-TARS uses the resized space natively); otherwise
    # fall back to the 0-1000 interpretation.
    if (cx <= resized_w and cy <= resized_h
            and 0 <= px_r <= orig_w and 0 <= py_r <= orig_h):
        px, py = px_r, py_r
        method_detail = "resized"
    else:
        px, py = px_1k, py_1k
        method_detail = "0-1000"

    confidence = 0.85 if ("start_box" in raw or "<point>" in raw) else 0.70

    print(f"[grounding-server] model=({cx},{cy}) -> pixel=({px},{py}) "
          f"[{method_detail}] resized={resized_w}x{resized_h} orig={orig_w}x{orig_h}")

    return px, py, method_detail, confidence


# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------

app = FastAPI(title="RPA Vision Grounding Server", version="1.0.0")


class GroundRequest(BaseModel):
    target_text: str = ""
    target_description: str = ""
    image_b64: str = ""


class GroundResponse(BaseModel):
    x: Optional[int] = None
    y: Optional[int] = None
    method: str = "ui_tars"
    confidence: float = 0.85
    time_ms: float = 0.0
    raw_output: str = ""


@app.get("/health")
def health():
    return {
        "status": "ok" if _model_loaded else "loading",
        "model": MODEL_ID,
        "model_loaded": _model_loaded,
        "cuda_available": torch.cuda.is_available(),
        "vram_allocated_gb": round(torch.cuda.memory_allocated() / 1024**3, 2) if torch.cuda.is_available() else 0,
    }


@app.post("/ground", response_model=GroundResponse)
def ground(req: GroundRequest):
    if not _model_loaded:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    from PIL import Image
    from qwen_vl_utils import process_vision_info

    # Build the instruction
    parts = []
    if req.target_text:
        parts.append(req.target_text)
    if req.target_description:
        parts.append(req.target_description)
    if not parts:
        raise HTTPException(status_code=400, detail="target_text or target_description required")

    instruction = f"Click on the {' — '.join(parts)}"

    # Get the image (provided as b64, or capture the screen)
    if req.image_b64:
        try:
            raw_b64 = req.image_b64.split(',')[1] if ',' in req.image_b64 else req.image_b64
            img_data = base64.b64decode(raw_b64)
            screen_pil = Image.open(io.BytesIO(img_data)).convert('RGB')
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Image decode error: {e}")
    else:
        screen_pil = _capture_screen()
        if screen_pil is None:
            raise HTTPException(status_code=500, detail="Screen capture failed")

    W, H = screen_pil.size
    rH, rW = _smart_resize(H, W, min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS)

    # Temporarily save the image for qwen_vl_utils
    import tempfile
    tmp_path = os.path.join(tempfile.gettempdir(), f"grounding_screen_{os.getpid()}.png")
    screen_pil.save(tmp_path)

    try:
        system_prompt = _GROUNDING_PROMPT.format(instruction=instruction)

        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": f"file://{tmp_path}",
                        "min_pixels": MIN_PIXELS,
                        "max_pixels": MAX_PIXELS,
                    },
                    {
                        "type": "text",
                        "text": system_prompt,
                    },
                ],
            }
        ]

        text = _processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = _processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(_model.device)

        # Inference
        t0 = time.time()
        with torch.no_grad():
            gen = _model.generate(**inputs, max_new_tokens=256)
        infer_ms = (time.time() - t0) * 1000

        # Decode
        trimmed = [o[len(i):] for i, o in zip(inputs.input_ids, gen)]
        raw = _processor.batch_decode(
            trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0].strip()

        print(f"[grounding-server] '{instruction}' -> raw='{raw[:150]}' ({infer_ms:.0f}ms)")

        # Detect negative answers (the model says it cannot see the element)
        _raw_lower = raw.lower()
        _negative_markers = ["don't see", "do not see", "cannot find", "can't find",
                             "not visible", "not found", "doesn't appear", "does not appear",
                             "i don't", "unable to find", "unable to locate"]
        for _neg in _negative_markers:
            if _neg in _raw_lower:
                print(f"[grounding-server] NEGATIVE detected: '{_neg}' → element not found")
                return GroundResponse(x=None, y=None, method="ui_tars", confidence=0.0,
                                      time_ms=round(infer_ms, 1), raw_output=raw[:300])

        # Parse the coordinates
        parsed = _parse_coordinates(raw, W, H, rW, rH)
        if parsed is None:
            raise HTTPException(
                status_code=422,
                detail=f"No coordinates parsed from the answer: {raw[:200]}"
            )

        px, py, method_detail, confidence = parsed

        print(f"[grounding-server] Result: ({px}, {py}) conf={confidence:.2f} "
              f"[{method_detail}] ({infer_ms:.0f}ms)")

        return GroundResponse(
            x=px,
            y=py,
            method="ui_tars",
            confidence=confidence,
            time_ms=round(infer_ms, 1),
            raw_output=raw[:300],
        )

    finally:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass


# ---------------------------------------------------------------------------
# Entrypoint
# ---------------------------------------------------------------------------

@app.on_event("startup")
async def startup_event():
    """Load the model at server startup."""
    print(f"[grounding-server] Starting on port {PORT}...")
    _load_model()
    print(f"[grounding-server] Ready to serve requests at http://localhost:{PORT}")


if __name__ == "__main__":
    uvicorn.run(
        "core.grounding.server:app",
        host="0.0.0.0",
        port=PORT,
        log_level="info",
        workers=1,  # single worker (single GPU)
    )
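Editor's note — a worked sanity check of _smart_resize with the module defaults (the 1920×1080 resolution is just an example; importing the function pulls in the module's torch/fastapi imports):

    # Illustrative check of _smart_resize with the defaults above.
    # round(1080 / 28) * 28 = 39 * 28 = 1092 ; round(1920 / 28) * 28 = 69 * 28 = 1932
    # 1092 * 1932 = 2_109_744, within [MIN_PIXELS=78_400, MAX_PIXELS=12_845_056],
    # so neither the shrink nor the grow branch fires.
    from core.grounding.server import _smart_resize

    assert _smart_resize(1080, 1920) == (1092, 1932)  # (height, width) in, (h_bar, w_bar) out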
core/grounding/target.py — Normal file (48 lines)
@@ -0,0 +1,48 @@
"""
core/grounding/target.py — Shared types for visual grounding

Dataclasses describing a target to locate (GroundingTarget) and
the result of a localization (GroundingResult).

These types are the common building block for all grounding modules:
template matching, OCR, VLM, CLIP, etc.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class GroundingTarget:
    """Description of a UI element to locate on screen.

    Attributes:
        text: visible text of the element (button, label, etc.)
        description: free-form semantic description (e.g. "the Valider button at the bottom right")
        template_b64: visual capture of the element, base64-encoded PNG/JPEG
        original_bbox: original position at capture time {x, y, width, height}
    """
    text: str = ""
    description: str = ""
    template_b64: str = ""
    original_bbox: Optional[Dict[str, int]] = None


@dataclass
class GroundingResult:
    """Result of a UI element localization.

    Attributes:
        x: X coordinate of the center of the found element (screen pixels)
        y: Y coordinate of the center of the found element (screen pixels)
        method: method that produced the result ('template', 'ocr', 'vlm', 'clip', etc.)
        confidence: confidence score [0.0 – 1.0]
        time_ms: search time in milliseconds
    """
    x: int
    y: int
    method: str
    confidence: float
    time_ms: float
core/grounding/template_matcher.py — Normal file (350 lines)
@@ -0,0 +1,350 @@
"""
core/grounding/template_matcher.py — Centralized template matching

Provides a TemplateMatcher class that locates a visual anchor (template image)
in a screenshot via cv2.matchTemplate. Supports single-scale and multi-scale.

Replaces the duplicated implementations in:
- core/execution/observe_reason_act.py (~1348-1375)
- visual_workflow_builder/backend/api_v3/execute.py (~930-963)
- visual_workflow_builder/backend/catalog_routes_v2_vlm.py (~339-381)
- visual_workflow_builder/backend/services/intelligent_executor.py (~131-210)
- core/detection/omniparser_adapter.py (~330)

Usage:
    from core.grounding import TemplateMatcher, MatchResult

    matcher = TemplateMatcher(threshold=0.75)
    result = matcher.match_screen(anchor_b64="...")
    if result:
        print(f"Found at ({result.x}, {result.y}) score={result.score:.3f}")
"""

from __future__ import annotations

import base64
import io
import logging
import time
from dataclasses import dataclass
from typing import List, Optional, Tuple

logger = logging.getLogger(__name__)

# Optional imports — the module loads even without cv2/PIL/mss
try:
    import cv2
    _CV2 = True
except ImportError:
    _CV2 = False

try:
    import numpy as np
    _NP = True
except ImportError:
    _NP = False

try:
    from PIL import Image
    _PIL = True
except ImportError:
    _PIL = False

try:
    import mss as mss_lib
    _MSS = True
except ImportError:
    _MSS = False


# ---------------------------------------------------------------------------
# Match result
# ---------------------------------------------------------------------------

@dataclass
class MatchResult:
    """Result of a template match."""
    x: int
    y: int
    score: float
    method: str  # 'template' | 'template_multiscale'
    time_ms: float
    scale: float = 1.0  # Scale at which the best match was found


# ---------------------------------------------------------------------------
# TemplateMatcher
# ---------------------------------------------------------------------------

class TemplateMatcher:
    """Locates a visual anchor in a screenshot via template matching.

    Parameters:
        threshold: minimum score to accept a match (default 0.75)
        multiscale: enable multi-scale matching (default False)
        scales: list of scales to try in multi-scale mode
        grayscale: convert to grayscale before matching (default False)
    """

    # Default scales for multi-scale mode, ordered by decreasing
    # likelihood (1.0 first = fast path when it matches)
    DEFAULT_SCALES: List[float] = [1.0, 0.95, 1.05, 0.9, 1.1, 0.85, 1.15, 0.8, 1.2]

    def __init__(
        self,
        threshold: float = 0.75,
        multiscale: bool = False,
        scales: Optional[List[float]] = None,
        grayscale: bool = False,
    ):
        self.threshold = threshold
        self.multiscale = multiscale
        self.scales = scales or self.DEFAULT_SCALES
        self.grayscale = grayscale
        # cv2.TM_CCOEFF_NORMED is the matching method used throughout the project
        self._cv2_method = cv2.TM_CCOEFF_NORMED if _CV2 else None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def match_screen(
        self,
        anchor_b64: Optional[str] = None,
        anchor_pil: Optional["Image.Image"] = None,
        screen_pil: Optional["Image.Image"] = None,
    ) -> Optional[MatchResult]:
        """Look for the anchor in the current (or provided) screenshot.

        The anchor can be passed as base64 or as a PIL Image.
        The screenshot is captured via mss if not provided.

        Returns a MatchResult, or None if no match >= threshold.
        """
        if not (_CV2 and _NP and _PIL):
            logger.debug("[TemplateMatcher] cv2/numpy/PIL unavailable")
            return None

        # --- Prepare the anchor ---
        anchor_img = self._decode_anchor(anchor_b64, anchor_pil)
        if anchor_img is None:
            return None

        # --- Prepare the screenshot ---
        if screen_pil is None:
            screen_pil = self._capture_screen()
            if screen_pil is None:
                return None

        # --- Convert to cv2 arrays ---
        screen_cv = cv2.cvtColor(np.array(screen_pil), cv2.COLOR_RGB2BGR)
        anchor_cv = cv2.cvtColor(np.array(anchor_img), cv2.COLOR_RGB2BGR)

        # --- Matching ---
        if self.multiscale:
            return self._match_multiscale(screen_cv, anchor_cv)
        else:
            return self._match_single(screen_cv, anchor_cv)

    def match_in_region(
        self,
        region_cv: "np.ndarray",
        anchor_cv: "np.ndarray",
        threshold: Optional[float] = None,
    ) -> Optional[MatchResult]:
        """Match inside an already-cropped region (BGR arrays).

        Used by pipelines that do their own capture/cropping.
        """
        if not (_CV2 and _NP):
            return None

        thr = threshold if threshold is not None else self.threshold

        if self.multiscale:
            return self._match_multiscale(region_cv, anchor_cv, threshold_override=thr)
        else:
            return self._match_single(region_cv, anchor_cv, threshold_override=thr)

    def match_screen_diagnostic(
        self,
        anchor_b64: Optional[str] = None,
        anchor_pil: Optional["Image.Image"] = None,
        screen_pil: Optional["Image.Image"] = None,
    ) -> str:
        """Return a textual diagnostic (score + position) even without a match."""
        if not (_CV2 and _NP and _PIL):
            return "cv2/numpy/PIL unavailable"

        anchor_img = self._decode_anchor(anchor_b64, anchor_pil)
        if anchor_img is None:
            return "anchor not decodable"

        if screen_pil is None:
            screen_pil = self._capture_screen()
            if screen_pil is None:
                return "screen capture failed"

        screen_cv = cv2.cvtColor(np.array(screen_pil), cv2.COLOR_RGB2BGR)
        anchor_cv = cv2.cvtColor(np.array(anchor_img), cv2.COLOR_RGB2BGR)

        if anchor_cv.shape[0] >= screen_cv.shape[0] or anchor_cv.shape[1] >= screen_cv.shape[1]:
            return f"anchor {anchor_cv.shape[:2]} >= screen {screen_cv.shape[:2]}"

        s_img, a_img = self._maybe_grayscale(screen_cv, anchor_cv)
        result_tm = cv2.matchTemplate(s_img, a_img, self._cv2_method)
        _, max_val, _, max_loc = cv2.minMaxLoc(result_tm)
        return f"{max_val:.3f} pos={max_loc}"

    # ------------------------------------------------------------------
    # Internal methods
    # ------------------------------------------------------------------

    def _match_single(
        self,
        screen_cv: "np.ndarray",
        anchor_cv: "np.ndarray",
        threshold_override: Optional[float] = None,
    ) -> Optional[MatchResult]:
        """Single-scale template matching."""
        threshold = threshold_override if threshold_override is not None else self.threshold

        if anchor_cv.shape[0] >= screen_cv.shape[0] or anchor_cv.shape[1] >= screen_cv.shape[1]:
            logger.debug("[TemplateMatcher] Anchor larger than the screen")
            return None

        s_img, a_img = self._maybe_grayscale(screen_cv, anchor_cv)

        t0 = time.time()
        result_tm = cv2.matchTemplate(s_img, a_img, self._cv2_method)
        _, max_val, _, max_loc = cv2.minMaxLoc(result_tm)
        elapsed_ms = (time.time() - t0) * 1000

        logger.debug(
            "[TemplateMatcher] score=%.3f pos=%s (%.0fms)",
            max_val, max_loc, elapsed_ms,
        )

        if max_val >= threshold:
            cx = max_loc[0] + anchor_cv.shape[1] // 2
            cy = max_loc[1] + anchor_cv.shape[0] // 2
            return MatchResult(
                x=cx,
                y=cy,
                score=float(max_val),
                method='template',
                time_ms=elapsed_ms,
                scale=1.0,
            )
        return None

    def _match_multiscale(
        self,
        screen_cv: "np.ndarray",
        anchor_cv: "np.ndarray",
        threshold_override: Optional[float] = None,
    ) -> Optional[MatchResult]:
        """Multi-scale template matching."""
        threshold = threshold_override if threshold_override is not None else self.threshold

        best_score = -1.0
        best_loc = None
        best_scale = 1.0
        best_anchor_shape = anchor_cv.shape

        t0 = time.time()

        for scale in self.scales:
            if scale == 1.0:
                scaled = anchor_cv
            else:
                new_w = int(anchor_cv.shape[1] * scale)
                new_h = int(anchor_cv.shape[0] * scale)
                if new_w < 8 or new_h < 8:
                    continue
                if new_h >= screen_cv.shape[0] or new_w >= screen_cv.shape[1]:
                    continue
                scaled = cv2.resize(anchor_cv, (new_w, new_h), interpolation=cv2.INTER_AREA)

            if scaled.shape[0] >= screen_cv.shape[0] or scaled.shape[1] >= screen_cv.shape[1]:
                continue

            s_img, a_img = self._maybe_grayscale(screen_cv, scaled)
            result_tm = cv2.matchTemplate(s_img, a_img, self._cv2_method)
            _, max_val, _, max_loc = cv2.minMaxLoc(result_tm)

            if max_val > best_score:
                best_score = max_val
                best_loc = max_loc
                best_scale = scale
                best_anchor_shape = scaled.shape

        elapsed_ms = (time.time() - t0) * 1000

        logger.debug(
            "[TemplateMatcher/multiscale] best_score=%.3f scale=%.2f (%.0fms)",
            best_score, best_scale, elapsed_ms,
        )

        if best_score >= threshold and best_loc is not None:
            cx = best_loc[0] + best_anchor_shape[1] // 2
            cy = best_loc[1] + best_anchor_shape[0] // 2
            return MatchResult(
                x=cx,
                y=cy,
                score=float(best_score),
                method='template_multiscale',
                time_ms=elapsed_ms,
                scale=best_scale,
            )
        return None

    def _maybe_grayscale(
        self,
        screen: "np.ndarray",
        anchor: "np.ndarray",
    ) -> Tuple["np.ndarray", "np.ndarray"]:
        """Convert to grayscale if self.grayscale is True."""
        if not self.grayscale:
            return screen, anchor
        s = cv2.cvtColor(screen, cv2.COLOR_BGR2GRAY) if len(screen.shape) == 3 else screen
        a = cv2.cvtColor(anchor, cv2.COLOR_BGR2GRAY) if len(anchor.shape) == 3 else anchor
        return s, a

    @staticmethod
    def _decode_anchor(
        anchor_b64: Optional[str],
        anchor_pil: Optional["Image.Image"],
    ) -> Optional["Image.Image"]:
        """Decode the anchor from base64, or return the PIL image directly."""
        if anchor_pil is not None:
            return anchor_pil

        if anchor_b64 is None:
            logger.debug("[TemplateMatcher] Neither anchor_b64 nor anchor_pil provided")
            return None

        try:
            raw = anchor_b64.split(',')[1] if ',' in anchor_b64 else anchor_b64
            data = base64.b64decode(raw)
            # Force RGB: RGBA/palette PNGs would otherwise break cv2.cvtColor(RGB2BGR)
            return Image.open(io.BytesIO(data)).convert('RGB')
        except Exception as e:
            logger.debug("[TemplateMatcher] Anchor decode error: %s", e)
            return None

    @staticmethod
    def _capture_screen() -> Optional["Image.Image"]:
        """Capture the full screen via mss (monitor 0 = all screens)."""
        if not _MSS:
            logger.debug("[TemplateMatcher] mss unavailable")
            return None

        try:
            with mss_lib.mss() as sct:
                mon = sct.monitors[0]
                grab = sct.grab(mon)
                return Image.frombytes('RGB', grab.size, grab.bgra, 'raw', 'BGRX')
        except Exception as e:
            logger.debug("[TemplateMatcher] Screen capture error: %s", e)
            return None
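Editor's note — a minimal sketch of match_in_region for pipelines that crop their own capture. The file paths "screenshot.png" and "anchor.png" are hypothetical; coordinates come back relative to the region, not the full screen:

    # Illustrative: multi-scale match inside a pre-cropped BGR region.
    import cv2
    from core.grounding import TemplateMatcher

    matcher = TemplateMatcher(threshold=0.7, multiscale=True)

    screen_cv = cv2.imread("screenshot.png")   # hypothetical saved capture (BGR)
    anchor_cv = cv2.imread("anchor.png")       # hypothetical anchor image (BGR)
    region_cv = screen_cv[0:400, 0:600]        # top-left 600x400 crop

    result = matcher.match_in_region(region_cv, anchor_cv)
    if result:
        # x/y are region-relative; add the crop offset for screen coordinates.
        print(f"({result.x}, {result.y}) score={result.score:.3f} scale={result.scale:.2f}")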
core/grounding/ui_tars_grounder.py — Normal file (204 lines)
@@ -0,0 +1,204 @@
"""
core/grounding/ui_tars_grounder.py — HTTP client for the grounding server

Replaces the in-process loading of the UI-TARS model (which crashes inside
Flask due to CUDA conflicts) with an HTTP CLIENT that calls the separate
grounding server on port 8200.

The server is launched separately via:
    .venv/bin/python3 -m core.grounding.server

Usage (unchanged):
    from core.grounding.ui_tars_grounder import UITarsGrounder

    grounder = UITarsGrounder.get_instance()
    result = grounder.ground("Bouton Valider", "the green button at the bottom right")
    if result:
        print(f"Found at ({result.x}, {result.y})")
"""

from __future__ import annotations

import base64
import io
import os
import threading
import time
from typing import Optional, TYPE_CHECKING

from core.grounding.target import GroundingResult

if TYPE_CHECKING:
    from PIL import Image

# ---------------------------------------------------------------------------
# Singleton
# ---------------------------------------------------------------------------

_instance: Optional[UITarsGrounder] = None
_instance_lock = threading.Lock()


class UITarsGrounder:
    """HTTP client for the UI-TARS grounding server (port 8200).

    Singleton: use get_instance() to obtain the unique instance.
    The server must be launched separately (.venv/bin/python3 -m core.grounding.server).
    """

    SERVER_URL = os.environ.get("GROUNDING_SERVER_URL", "http://localhost:8200")

    def __init__(self):
        self._server_available: Optional[bool] = None
        self._last_check = 0.0

    @classmethod
    def get_instance(cls) -> UITarsGrounder:
        """Return the singleton grounder instance."""
        global _instance
        if _instance is None:
            with _instance_lock:
                if _instance is None:
                    _instance = cls()
        return _instance

    # ------------------------------------------------------------------
    # Server check
    # ------------------------------------------------------------------

    def _check_server(self, force: bool = False) -> bool:
        """Check whether the grounding server is available.

        Caches the result for 30 seconds to avoid spamming.
        """
        now = time.time()
        if not force and self._server_available is not None and (now - self._last_check) < 30:
            return self._server_available

        try:
            import requests
            resp = requests.get(f"{self.SERVER_URL}/health", timeout=3)
            if resp.status_code == 200:
                data = resp.json()
                self._server_available = data.get("model_loaded", False)
                if not self._server_available:
                    print("[UI-TARS/client] Server still loading...")
            else:
                self._server_available = False
        except Exception:
            self._server_available = False

        self._last_check = now

        if not self._server_available:
            print(f"[UI-TARS/client] Server unavailable at {self.SERVER_URL} "
                  f"— launch it: .venv/bin/python3 -m core.grounding.server")

        return self._server_available

    @property
    def is_loaded(self) -> bool:
        """Compatibility: check whether the server is ready."""
        return self._check_server()

    def load(self) -> None:
        """Compatibility: no-op (the server loads the model at startup)."""
        if not self._check_server(force=True):
            print(f"[UI-TARS/client] WARNING: server unavailable at {self.SERVER_URL}")
            print("[UI-TARS/client] Launch the server: .venv/bin/python3 -m core.grounding.server")

    def unload(self) -> None:
        """Compatibility: no-op (the model lives in the server process)."""
        pass

    # ------------------------------------------------------------------
    # Grounding over HTTP
    # ------------------------------------------------------------------

    def ground(
        self,
        target_text: str = "",
        target_description: str = "",
        screen_pil: Optional["Image.Image"] = None,
    ) -> Optional[GroundingResult]:
        """Locate a UI element by calling the grounding server.

        Args:
            target_text: visible text of the element (e.g. "Valider", "Rechercher")
            target_description: semantic description (e.g. "the green button at the bottom")
            screen_pil: PIL screenshot; the server captures one if None

        Returns:
            GroundingResult with screen-pixel coordinates, or None on failure
        """
        if not target_text and not target_description:
            print("[UI-TARS/client] No target_text or target_description")
            return None

        # Check that the server is available
        if not self._check_server():
            return None

        import requests

        # Encode the image as base64 if provided
        image_b64 = ""
        if screen_pil is not None:
            try:
                buffer = io.BytesIO()
                screen_pil.save(buffer, format='PNG')
                image_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
            except Exception as e:
                print(f"[UI-TARS/client] Image encoding error: {e}")
                # Continue without an image — the server will capture the screen

        payload = {
            "target_text": target_text,
            "target_description": target_description,
            "image_b64": image_b64,
        }

        try:
            t0 = time.time()
            resp = requests.post(
                f"{self.SERVER_URL}/ground",
                json=payload,
                timeout=30,  # UI-TARS can take 3-5s + network overhead
            )
            total_ms = (time.time() - t0) * 1000

            if resp.status_code == 200:
                data = resp.json()
                # The server answers 200 with x/y = null when the model gave a
                # negative answer ("I don't see it") — treat that as not found.
                if data.get("x") is None or data.get("y") is None:
                    print(f"[UI-TARS/client] Element not found: {data.get('raw_output', '')[:150]}")
                    return None
                result = GroundingResult(
                    x=data["x"],
                    y=data["y"],
                    method=data.get("method", "ui_tars"),
                    confidence=data.get("confidence", 0.85),
                    time_ms=data.get("time_ms", total_ms),
                )
                print(f"[UI-TARS/client] '{target_text or target_description}' -> "
                      f"({result.x}, {result.y}) conf={result.confidence:.2f} "
                      f"({result.time_ms:.0f}ms)")
                return result

            elif resp.status_code == 422:
                # Coordinates could not be parsed
                detail = resp.json().get("detail", "")
                print(f"[UI-TARS/client] No coordinates parsed: {detail[:150]}")
                return None

            elif resp.status_code == 503:
                print("[UI-TARS/client] Server not ready yet (model loading)")
                return None

            else:
                print(f"[UI-TARS/client] HTTP error {resp.status_code}: {resp.text[:200]}")
                return None

        except requests.exceptions.ConnectionError:
            self._server_available = False
            print(f"[UI-TARS/client] Server unreachable at {self.SERVER_URL}")
            return None
        except requests.exceptions.Timeout:
            print(f"[UI-TARS/client] Timeout (>30s) for '{target_text}'")
            return None
        except Exception as e:
            print(f"[UI-TARS/client] Unexpected error: {e}")
            return None