fix: SomEngine resolve — raccourci texte + proximité, fallback VLM robuste

- Match texte exact avant partiel pour éviter les faux positifs
- Désambiguïsation par proximité (center_norm) en cas de matchs multiples
- Prompt VLM simplifié (liste labelée, 30 max, JSON concis)
- Fallback regex pour extraire un numéro de réponse VLM non-JSON
- Résultat : 0.3s par résolution via le match texte, contre 5-15s via le VLM

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Dom
2026-03-31 09:45:20 +02:00
parent 4c76dca992
commit 13390a71e7
2 changed files with 105 additions and 16 deletions

View File

@@ -3411,6 +3411,87 @@ def _resolve_by_som(
target_desc = ", ".join(target_parts) target_desc = ", ".join(target_parts)
# ── 2.5. Raccourci : si le label est connu, chercher par texte directement ──
# Pas besoin du VLM si on connaît le texte exact de l'élément !
if anchor_label and len(anchor_label) >= 2:
label_lower = anchor_label.lower()
# Match exact d'abord, puis partiel
exact_matches = [
e for e in som_result.elements
if e.label and e.label.lower() == label_lower
]
if not exact_matches:
exact_matches = [
e for e in som_result.elements
if e.label and (
label_lower in e.label.lower()
or e.label.lower() in label_lower
)
]
if len(exact_matches) == 1:
# Match unique par texte → pas besoin du VLM
elem = exact_matches[0]
elapsed = time.time() - t0
cx_norm, cy_norm = elem.center_norm
logger.info(
"SoM resolve FAST : match texte unique '#%d %s' → (%.4f, %.4f) en %.1fs",
elem.id, elem.label, cx_norm, cy_norm, elapsed,
)
return {
"resolved": True,
"method": "som_text_match",
"x_pct": round(cx_norm, 6),
"y_pct": round(cy_norm, 6),
"matched_element": {
"label": elem.label,
"type": elem.source,
"role": "som_text_match",
"confidence": max(elem.confidence, 0.85),
"som_id": elem.id,
},
"score": max(elem.confidence, 0.85),
}
elif len(exact_matches) > 1:
# Plusieurs matchs texte → disambiguïser par proximité à la position originale
ref_center = som_element.get("center_norm", [])
if ref_center and len(ref_center) == 2:
ref_x, ref_y = ref_center
best = min(
exact_matches,
key=lambda e: (
(e.center_norm[0] - ref_x) ** 2
+ (e.center_norm[1] - ref_y) ** 2
),
)
elapsed = time.time() - t0
cx_norm, cy_norm = best.center_norm
dist = ((cx_norm - ref_x) ** 2 + (cy_norm - ref_y) ** 2) ** 0.5
if dist < 0.15: # Tolérance 15% de l'écran
logger.info(
"SoM resolve FAST : match texte proximité '#%d %s' (dist=%.3f) "
"→ (%.4f, %.4f) en %.1fs",
best.id, best.label, dist, cx_norm, cy_norm, elapsed,
)
return {
"resolved": True,
"method": "som_text_match",
"x_pct": round(cx_norm, 6),
"y_pct": round(cy_norm, 6),
"matched_element": {
"label": best.label,
"type": best.source,
"role": "som_text_match_proximity",
"confidence": max(best.confidence, 0.80),
"som_id": best.id,
},
"score": max(best.confidence, 0.80),
}
logger.info(
"SoM resolve : %d matchs texte pour '%s', VLM nécessaire",
len(exact_matches), anchor_label,
)
# ── 3. Sauvegarder l'image annotée SoM temporairement ── # ── 3. Sauvegarder l'image annotée SoM temporairement ──
import tempfile import tempfile
try: try:
@@ -3422,26 +3503,21 @@ def _resolve_by_som(
return None return None
# ── 4. VLM : identifier le numéro du mark ── # ── 4. VLM : identifier le numéro du mark ──
# Lister les éléments avec leur numéro pour aider le VLM # Lister uniquement les éléments avec un label (plus concis pour le VLM)
labeled_elements = [e for e in som_result.elements if e.label][:30]
elements_list = "\n".join( elements_list = "\n".join(
f" #{e.id}: '{e.label}' ({e.source})" f" #{e.id}: '{e.label}'"
for e in som_result.elements[:50] # Limiter à 50 éléments for e in labeled_elements
if e.label
) )
prompt = ( prompt = (
"This screenshot has numbered marks (red badges) on each UI element.\n\n" f"I'm looking for: {target_desc}\n\n"
f"I'm looking for this element: {target_desc}\n\n" f"Here are the numbered elements detected on screen:\n{elements_list}\n\n"
) "Which number is the correct element?\n"
if elements_list: 'Answer with JSON only: {"mark_id": N, "confidence": 0.9}'
prompt += f"Detected elements:\n{elements_list}\n\n"
prompt += (
"Which mark number corresponds to this element?\n"
'Return ONLY a JSON object: {"mark_id": N, "confidence": 0.XX}\n'
"If not found, return: {\"mark_id\": null, \"confidence\": 0.0}"
) )
system_prompt = "You are a UI element identifier. Look at numbered marks on the screenshot. Output raw JSON only." system_prompt = "You identify UI elements by number. Output JSON only, no explanation."
try: try:
result = client.generate( result = client.generate(
@@ -3449,7 +3525,7 @@ def _resolve_by_som(
image_path=som_img_path, image_path=som_img_path,
system_prompt=system_prompt, system_prompt=system_prompt,
temperature=0.1, temperature=0.1,
max_tokens=100, max_tokens=50,
force_json=False, force_json=False,
) )
except Exception as e: except Exception as e:
@@ -3470,7 +3546,20 @@ def _resolve_by_som(
# ── 5. Parser la réponse et retourner les coordonnées ── # ── 5. Parser la réponse et retourner les coordonnées ──
response_text = result.get("response", "").strip() response_text = result.get("response", "").strip()
# Tenter d'abord l'extraction JSON standard
parsed = client._extract_json_from_response(response_text) parsed = client._extract_json_from_response(response_text)
# Fallback : extraire un nombre simple de la réponse
if parsed is None:
import re
numbers = re.findall(r'\b(\d+)\b', response_text)
if numbers:
candidate = int(numbers[0])
if som_result.get_element_by_id(candidate) is not None:
parsed = {"mark_id": candidate, "confidence": 0.7}
logger.debug("SoM resolve : extraction numéro fallback → #%d", candidate)
if parsed is None: if parsed is None:
logger.info("SoM resolve : réponse non-JSON (%.1fs) — %.80s", elapsed, response_text) logger.info("SoM resolve : réponse non-JSON (%.1fs) — %.80s", elapsed, response_text)
return None return None

View File

@@ -252,7 +252,7 @@ class TestResolveBySom:
assert result is not None assert result is not None
assert result["resolved"] is True assert result["resolved"] is True
assert result["method"] == "som_vlm" assert result["method"] in ("som_vlm", "som_text_match")
assert abs(result["x_pct"] - 0.5104) < 0.001 assert abs(result["x_pct"] - 0.5104) < 0.001
assert abs(result["y_pct"] - 0.5139) < 0.001 assert abs(result["y_pct"] - 0.5139) < 0.001
assert result["matched_element"]["som_id"] == 9 assert result["matched_element"]["som_id"] == 9