diff --git a/agent_v0/server_v1/api_stream.py b/agent_v0/server_v1/api_stream.py index 2f8347b29..8f51103e7 100644 --- a/agent_v0/server_v1/api_stream.py +++ b/agent_v0/server_v1/api_stream.py @@ -3322,6 +3322,197 @@ def _vlm_quick_find( return None +# --------------------------------------------------------------------------- +# Résolution Set-of-Mark : SomEngine (détection) + VLM (identification) +# --------------------------------------------------------------------------- + +_som_engine_api = None # Singleton + + +def _get_som_engine_api(): + """Singleton SomEngine pour la résolution visuelle (lazy-loaded, GPU).""" + global _som_engine_api + if _som_engine_api is None: + try: + from core.detection.som_engine import SomEngine + _som_engine_api = SomEngine(device="cuda") + logger.info("SomEngine API initialisé (lazy singleton)") + except Exception as e: + logger.warning("SomEngine API non disponible : %s", e) + _som_engine_api = False + return _som_engine_api if _som_engine_api is not False else None + + +def _resolve_by_som( + screenshot_path: str, + target_spec: Dict[str, Any], + screen_width: int, + screen_height: int, +) -> Optional[Dict[str, Any]]: + """Résoudre une cible UI via Set-of-Mark + VLM. + + Pipeline : + 1. SomEngine détecte tous les éléments et les numérote sur le screenshot + 2. VLM reçoit l'image annotée + description de la cible + 3. VLM identifie le numéro du mark → coordonnées précises + + Avantages vs VLM direct : + - Le VLM n'a qu'à identifier (son point fort), pas localiser + - Les coordonnées viennent de SomEngine (pixel-perfect) + - Question simple "quel numéro ?" → réponse simple + + Args: + screenshot_path: Chemin du screenshot actuel + target_spec: Spécification de la cible (vlm_description, som_element, etc.) + screen_width: Largeur écran en pixels + screen_height: Hauteur écran en pixels + + Returns: + Dict avec resolved=True et coordonnées, ou None si indisponible. + """ + engine = _get_som_engine_api() + if engine is None: + return None + + client = _get_vlm_client() + if client is None: + return None + + t0 = time.time() + + # ── 1. Lancer SomEngine sur le screenshot actuel ── + try: + from PIL import Image as PILImage + img = PILImage.open(screenshot_path).convert("RGB") + som_result = engine.analyze(img) + except Exception as e: + logger.warning("SoM resolve : erreur analyse — %s", e) + return None + + if not som_result.elements: + logger.info("SoM resolve : 0 éléments détectés") + return None + + # ── 2. Construire la description de la cible ── + som_element = target_spec.get("som_element", {}) + vlm_description = target_spec.get("vlm_description", "") + anchor_label = som_element.get("label", "") + + # Construire un prompt riche + target_parts = [] + if anchor_label: + target_parts.append(f"texte '{anchor_label}'") + if vlm_description: + target_parts.append(vlm_description) + if not target_parts: + # Sans description, SoM resolve ne peut pas fonctionner + logger.debug("SoM resolve : pas de description pour identifier l'élément") + return None + + target_desc = ", ".join(target_parts) + + # ── 3. Sauvegarder l'image annotée SoM temporairement ── + import tempfile + try: + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp: + som_result.som_image.save(tmp, format="JPEG", quality=85) + som_img_path = tmp.name + except Exception as e: + logger.warning("SoM resolve : erreur sauvegarde image annotée — %s", e) + return None + + # ── 4. VLM : identifier le numéro du mark ── + # Lister les éléments avec leur numéro pour aider le VLM + elements_list = "\n".join( + f" #{e.id}: '{e.label}' ({e.source})" + for e in som_result.elements[:50] # Limiter à 50 éléments + if e.label + ) + + prompt = ( + "This screenshot has numbered marks (red badges) on each UI element.\n\n" + f"I'm looking for this element: {target_desc}\n\n" + ) + if elements_list: + prompt += f"Detected elements:\n{elements_list}\n\n" + prompt += ( + "Which mark number corresponds to this element?\n" + 'Return ONLY a JSON object: {"mark_id": N, "confidence": 0.XX}\n' + "If not found, return: {\"mark_id\": null, \"confidence\": 0.0}" + ) + + system_prompt = "You are a UI element identifier. Look at numbered marks on the screenshot. Output raw JSON only." + + try: + result = client.generate( + prompt=prompt, + image_path=som_img_path, + system_prompt=system_prompt, + temperature=0.1, + max_tokens=100, + force_json=False, + ) + except Exception as e: + logger.warning("SoM resolve : erreur VLM — %s", e) + return None + finally: + import os + try: + os.unlink(som_img_path) + except OSError: + pass + + elapsed = time.time() - t0 + + if not result.get("success"): + logger.info("SoM resolve : VLM échoué (%.1fs)", elapsed) + return None + + # ── 5. Parser la réponse et retourner les coordonnées ── + response_text = result.get("response", "").strip() + parsed = client._extract_json_from_response(response_text) + if parsed is None: + logger.info("SoM resolve : réponse non-JSON (%.1fs) — %.80s", elapsed, response_text) + return None + + mark_id = parsed.get("mark_id") + confidence = float(parsed.get("confidence", 0.0)) + + if mark_id is None or confidence < 0.3: + logger.info( + "SoM resolve : mark non trouvé ou confiance trop basse (mark=%s, conf=%.2f, %.1fs)", + mark_id, confidence, elapsed, + ) + return None + + mark_id = int(mark_id) + elem = som_result.get_element_by_id(mark_id) + if elem is None: + logger.warning("SoM resolve : mark #%d inexistant (%.1fs)", mark_id, elapsed) + return None + + cx_norm, cy_norm = elem.center_norm + logger.info( + "SoM resolve OK : mark #%d '%s' → (%.4f, %.4f) conf=%.2f en %.1fs (%d éléments)", + mark_id, elem.label, cx_norm, cy_norm, confidence, elapsed, len(som_result.elements), + ) + + return { + "resolved": True, + "method": "som_vlm", + "x_pct": round(cx_norm, 6), + "y_pct": round(cy_norm, 6), + "matched_element": { + "label": elem.label or f"mark #{mark_id}", + "type": elem.source, + "role": "som_identified", + "confidence": confidence, + "som_id": mark_id, + }, + "score": confidence, + } + + def _resolve_target_sync( screenshot_path: str, target_spec: Dict[str, Any], @@ -3336,6 +3527,7 @@ def _resolve_target_sync( Hiérarchie de résolution (strict_mode=True, replay sessions) — VLM-FIRST : 1. VLM Quick Find (~3-8s) — compréhension sémantique de l'écran, multi-image (screenshot + crop de référence + description riche) + 1.5. SoM + VLM (~5-15s) — SomEngine numérote les éléments, VLM identifie le bon 2. Template matching OpenCV (~100ms) — fallback pixel, seuil STRICT 0.90 3. resolved=False → STOP le replay @@ -3394,6 +3586,30 @@ def _resolve_target_sync( vlm_description[:60] if vlm_description else "(anchor)", ) + # --------------------------------------------------------------- + # Étape 1.5 : SoM + VLM (Set-of-Mark + identification) + # SomEngine numérote les éléments, VLM identifie le bon numéro. + # Plus fiable que le VLM direct car le VLM n'a qu'à identifier, + # pas localiser — et les coordonnées sont pixel-perfect. + # --------------------------------------------------------------- + som_element = target_spec.get("som_element", {}) + if som_element or vlm_description: + som_result = _resolve_by_som( + screenshot_path=screenshot_path, + target_spec=target_spec, + screen_width=screen_width, + screen_height=screen_height, + ) + if som_result and som_result.get("resolved"): + logger.info( + "Strict resolve SoM+VLM : OK (score=%.2f, mark=#%s)", + som_result.get("score", 0), + som_result.get("matched_element", {}).get("som_id", "?"), + ) + return som_result + else: + logger.info("Strict resolve SoM+VLM : échoué, passage template matching") + # --------------------------------------------------------------- # Étape 2 : Template matching (fallback pixel) — seuil STRICT 0.90 # --------------------------------------------------------------- diff --git a/agent_v0/server_v1/stream_processor.py b/agent_v0/server_v1/stream_processor.py index 379293391..6f1afd105 100644 --- a/agent_v0/server_v1/stream_processor.py +++ b/agent_v0/server_v1/stream_processor.py @@ -427,6 +427,111 @@ def _needs_post_wait(action: dict) -> int: return 0 +# --------------------------------------------------------------------------- +# SomEngine — enrichissement Set-of-Mark des clics pendant le build_replay +# --------------------------------------------------------------------------- + +_som_engine = None # Singleton, chargé à la demande + + +def _get_som_engine(): + """Singleton SomEngine (lazy-loaded, GPU).""" + global _som_engine + if _som_engine is None: + try: + from core.detection.som_engine import SomEngine + _som_engine = SomEngine(device="cuda") + logger.info("SomEngine initialisé (lazy singleton)") + except Exception as e: + logger.warning("SomEngine non disponible : %s", e) + _som_engine = False # Marqueur "indisponible" + return _som_engine if _som_engine is not False else None + + +def _som_identify_clicked_element( + event_data: dict, + session_dir: Optional[Path], + screen_w: int, + screen_h: int, +) -> Optional[dict]: + """Identifier l'élément UI cliqué via SomEngine (YOLO + docTR). + + Charge le full screenshot de l'événement, lance SomEngine pour détecter + tous les éléments, puis identifie celui qui se trouve sous le clic. + + Returns: + Dict avec id, label, source, bbox_norm, center_norm, confidence + ou None si SomEngine indisponible ou élément non trouvé. + """ + engine = _get_som_engine() + if engine is None: + return None + + if not session_dir: + return None + + shots_dir = session_dir / "shots" + if not shots_dir.is_dir(): + return None + + # Trouver le full screenshot + screenshot_id = event_data.get("screenshot_id", "") + if not screenshot_id: + return None + + full_path = shots_dir / f"{screenshot_id}_full.png" + if not full_path.is_file(): + # Fallback : essayer sans le suffixe _full + full_path = shots_dir / f"{screenshot_id}.png" + if not full_path.is_file(): + return None + + try: + from PIL import Image + img = Image.open(full_path).convert("RGB") + except Exception as e: + logger.debug("SoM: impossible de charger %s : %s", full_path, e) + return None + + # Lancer SomEngine + try: + result = engine.analyze(img) + except Exception as e: + logger.warning("SoM: erreur d'analyse : %s", e) + return None + + if not result.elements: + return None + + # Trouver l'élément cliqué + pos = event_data.get("pos", []) + if not pos or len(pos) < 2: + return None + + click_x, click_y = int(pos[0]), int(pos[1]) + elem = result.find_element_at(click_x, click_y, margin=30) + if elem is None: + logger.debug( + "SoM: aucun élément trouvé au clic (%d, %d) parmi %d éléments", + click_x, click_y, len(result.elements), + ) + return None + + logger.info( + "SoM: clic (%d,%d) → élément #%d '%s' (source=%s, conf=%.2f)", + click_x, click_y, elem.id, elem.label, elem.source, elem.confidence, + ) + return { + "id": elem.id, + "label": elem.label, + "source": elem.source, + "bbox_norm": list(elem.bbox_norm), + "center_norm": list(elem.center_norm), + "confidence": elem.confidence, + "element_count": len(result.elements), + } + + def _load_crop_for_event( event_data: dict, session_dir: Optional[Path], @@ -919,6 +1024,18 @@ def build_replay_from_raw_events( # Sinon le template matching texte cherche "13071967.txt – Bloc-notes" # sur l'écran et clique sur la barre de titre au lieu du bon élément. + # ── SomEngine : identifier l'élément cliqué ── + som_elem = _som_identify_clicked_element( + evt, session_dir_path, screen_w, screen_h, + ) + if som_elem: + action["target_spec"]["som_element"] = som_elem + # Enrichir la description VLM avec le label SoM + if som_elem.get("label") and not vision_info.get("text"): + action["target_spec"]["vlm_description"] += ( + f", le texte de l'élément est '{som_elem['label']}'" + ) + elif evt_type == "text_input": text = evt.get("text", "") if not text: diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 000000000..659a04127 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,26 @@ +"""Conftest pour les tests unitaires. + +Force le bon chemin agent_v0 (rpa_vision_v3) pour éviter les conflits +avec ~/ai/agent_v0 (standalone). +""" +import sys +from pathlib import Path + +ROOT = str(Path(__file__).resolve().parents[2]) + +if ROOT in sys.path: + sys.path.remove(ROOT) +sys.path.insert(0, ROOT) + +# Si agent_v0 est déjà chargé depuis le mauvais chemin, le remplacer +_agent_mod = sys.modules.get("agent_v0") +if _agent_mod and not getattr(_agent_mod, "__file__", "").startswith(ROOT): + to_remove = [k for k in sys.modules if k == "agent_v0" or k.startswith("agent_v0.")] + for k in to_remove: + del sys.modules[k] + +# Pré-importer le bon agent_v0.server_v1 +try: + import agent_v0.server_v1 # noqa: F401 +except ImportError: + pass diff --git a/tests/unit/test_som_integration.py b/tests/unit/test_som_integration.py new file mode 100644 index 000000000..ac35c4687 --- /dev/null +++ b/tests/unit/test_som_integration.py @@ -0,0 +1,301 @@ +"""Tests unitaires pour l'intégration SomEngine dans build_replay et resolve_target. + +Vérifie : +- Phase 1 : _som_identify_clicked_element enrichit target_spec avec som_element +- Phase 2 : _resolve_by_som utilise SomEngine + VLM pour résoudre une cible +- Fallbacks gracieux quand SomEngine ou VLM indisponible +""" + +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + + +# ── Phase 1 : Enrichissement build_replay ── + + +class TestSomIdentifyClickedElement: + """Tests pour _som_identify_clicked_element (Phase 1).""" + + def test_returns_none_when_engine_unavailable(self): + """Si SomEngine n'est pas disponible, retourne None sans erreur.""" + from agent_v0.server_v1.stream_processor import _som_identify_clicked_element + + with patch( + "agent_v0.server_v1.stream_processor._get_som_engine", + return_value=None, + ): + result = _som_identify_clicked_element( + {"screenshot_id": "shot_0001", "pos": [500, 300]}, + Path("/fake/dir"), + 1920, 1080, + ) + assert result is None + + def test_returns_none_when_no_session_dir(self): + """Sans session_dir, retourne None.""" + from agent_v0.server_v1.stream_processor import _som_identify_clicked_element + + result = _som_identify_clicked_element( + {"screenshot_id": "shot_0001", "pos": [500, 300]}, + None, 1920, 1080, + ) + assert result is None + + def test_returns_none_when_no_screenshot_id(self): + """Sans screenshot_id, retourne None.""" + from agent_v0.server_v1.stream_processor import _som_identify_clicked_element + + result = _som_identify_clicked_element( + {"pos": [500, 300]}, + Path("/fake/dir"), + 1920, 1080, + ) + assert result is None + + def test_returns_element_when_found(self, tmp_path): + """Quand SomEngine trouve un élément sous le clic, retourne ses infos.""" + from core.detection.som_engine import SomElement, SomResult + from agent_v0.server_v1.stream_processor import _som_identify_clicked_element + + # Créer un faux screenshot + shots_dir = tmp_path / "shots" + shots_dir.mkdir() + from PIL import Image + img = Image.new("RGB", (1920, 1080), color="white") + img.save(shots_dir / "shot_0001_full.png") + + # Mock SomEngine + mock_elem = SomElement( + id=5, + bbox=(480, 280, 520, 320), + bbox_norm=(0.25, 0.259, 0.271, 0.296), + center=(500, 300), + center_norm=(0.2604, 0.2778), + source="yolo", + label="Enregistrer", + confidence=0.92, + ) + mock_result = SomResult( + elements=[mock_elem], + width=1920, + height=1080, + ) + mock_engine = MagicMock() + mock_engine.analyze.return_value = mock_result + + with patch( + "agent_v0.server_v1.stream_processor._get_som_engine", + return_value=mock_engine, + ): + result = _som_identify_clicked_element( + {"screenshot_id": "shot_0001", "pos": [500, 300]}, + tmp_path, 1920, 1080, + ) + + assert result is not None + assert result["id"] == 5 + assert result["label"] == "Enregistrer" + assert result["source"] == "yolo" + assert result["confidence"] == 0.92 + assert result["element_count"] == 1 + + def test_returns_none_when_no_element_at_click(self, tmp_path): + """Quand aucun élément n'est sous le clic, retourne None.""" + from core.detection.som_engine import SomResult + from agent_v0.server_v1.stream_processor import _som_identify_clicked_element + + shots_dir = tmp_path / "shots" + shots_dir.mkdir() + from PIL import Image + img = Image.new("RGB", (1920, 1080), color="white") + img.save(shots_dir / "shot_0001_full.png") + + # Résultat avec des éléments mais pas au point du clic + mock_result = SomResult(elements=[], width=1920, height=1080) + mock_engine = MagicMock() + mock_engine.analyze.return_value = mock_result + + with patch( + "agent_v0.server_v1.stream_processor._get_som_engine", + return_value=mock_engine, + ): + result = _som_identify_clicked_element( + {"screenshot_id": "shot_0001", "pos": [500, 300]}, + tmp_path, 1920, 1080, + ) + + assert result is None + + +# ── Phase 2 : Résolution SoM + VLM ── + + +class TestResolveBySom: + """Tests pour _resolve_by_som (Phase 2).""" + + def test_returns_none_when_engine_unavailable(self): + """Sans SomEngine, retourne None.""" + from agent_v0.server_v1.api_stream import _resolve_by_som + + with patch( + "agent_v0.server_v1.api_stream._get_som_engine_api", + return_value=None, + ): + result = _resolve_by_som( + "/fake/path.jpg", + {"vlm_description": "un bouton"}, + 1920, 1080, + ) + assert result is None + + def test_returns_none_when_vlm_unavailable(self): + """Sans VLM, retourne None.""" + from agent_v0.server_v1.api_stream import _resolve_by_som + + mock_engine = MagicMock() + + with patch( + "agent_v0.server_v1.api_stream._get_som_engine_api", + return_value=mock_engine, + ), patch( + "agent_v0.server_v1.api_stream._get_vlm_client", + return_value=None, + ): + result = _resolve_by_som( + "/fake/path.jpg", + {"vlm_description": "un bouton"}, + 1920, 1080, + ) + assert result is None + + def test_returns_none_without_description(self): + """Sans description ni som_element, retourne None.""" + from agent_v0.server_v1.api_stream import _resolve_by_som + + mock_engine = MagicMock() + mock_client = MagicMock() + + with patch( + "agent_v0.server_v1.api_stream._get_som_engine_api", + return_value=mock_engine, + ), patch( + "agent_v0.server_v1.api_stream._get_vlm_client", + return_value=mock_client, + ): + result = _resolve_by_som( + "/fake/path.jpg", + {}, # Pas de description + 1920, 1080, + ) + assert result is None + + def test_resolve_success(self, tmp_path): + """Résolution réussie : SomEngine détecte, VLM identifie le mark.""" + from core.detection.som_engine import SomElement, SomResult + from agent_v0.server_v1.api_stream import _resolve_by_som + + # Créer un faux screenshot + from PIL import Image + img = Image.new("RGB", (1920, 1080), color="white") + screenshot_path = str(tmp_path / "screen.jpg") + img.save(screenshot_path) + + # Mock SomEngine + mock_elem = SomElement( + id=9, + bbox=(960, 540, 1000, 570), + bbox_norm=(0.5, 0.5, 0.521, 0.528), + center=(980, 555), + center_norm=(0.5104, 0.5139), + source="ocr", + label="Ouvrir", + confidence=0.88, + ) + mock_result = SomResult( + elements=[mock_elem], + som_image=img.copy(), + som_image_b64="fake_b64", + width=1920, + height=1080, + ) + mock_engine = MagicMock() + mock_engine.analyze.return_value = mock_result + + # Mock VLM client + mock_client = MagicMock() + mock_client.generate.return_value = { + "success": True, + "response": '{"mark_id": 9, "confidence": 0.95}', + } + mock_client._extract_json_from_response.return_value = { + "mark_id": 9, + "confidence": 0.95, + } + + with patch( + "agent_v0.server_v1.api_stream._get_som_engine_api", + return_value=mock_engine, + ), patch( + "agent_v0.server_v1.api_stream._get_vlm_client", + return_value=mock_client, + ): + result = _resolve_by_som( + screenshot_path, + { + "vlm_description": "le bouton Ouvrir", + "som_element": {"id": 9, "label": "Ouvrir"}, + }, + 1920, 1080, + ) + + assert result is not None + assert result["resolved"] is True + assert result["method"] == "som_vlm" + assert abs(result["x_pct"] - 0.5104) < 0.001 + assert abs(result["y_pct"] - 0.5139) < 0.001 + assert result["matched_element"]["som_id"] == 9 + + def test_resolve_vlm_low_confidence(self, tmp_path): + """VLM retourne une confiance trop basse → None.""" + from core.detection.som_engine import SomResult + from agent_v0.server_v1.api_stream import _resolve_by_som + + from PIL import Image + img = Image.new("RGB", (1920, 1080), color="white") + screenshot_path = str(tmp_path / "screen.jpg") + img.save(screenshot_path) + + mock_result = SomResult( + elements=[MagicMock(id=1, label="test", source="ocr")], + som_image=img.copy(), + width=1920, height=1080, + ) + mock_engine = MagicMock() + mock_engine.analyze.return_value = mock_result + + mock_client = MagicMock() + mock_client.generate.return_value = { + "success": True, + "response": '{"mark_id": 1, "confidence": 0.1}', + } + mock_client._extract_json_from_response.return_value = { + "mark_id": 1, + "confidence": 0.1, + } + + with patch( + "agent_v0.server_v1.api_stream._get_som_engine_api", + return_value=mock_engine, + ), patch( + "agent_v0.server_v1.api_stream._get_vlm_client", + return_value=mock_client, + ): + result = _resolve_by_som( + screenshot_path, + {"vlm_description": "un bouton"}, + 1920, 1080, + ) + + assert result is None