Compare commits — 19 commits (main...41c1250c99)

41c1250c99, 2af3bc3b93, 6154423a91, 41eba898c0, 9452e86fd1, 5e31cdf666,
487bcb8618, 3d6868f029, f73a2a59a9, 77faa03ec9, 343d6fbe95, cc64439738,
90007cc7c1, 73cea2385e, e2046837cf, b30d4b6656, e4a48e78bf, ea36bba5cc,
9da589c8c2
@@ -46,6 +46,14 @@ LOGS_PATH=logs
UPLOADS_PATH=data/training/uploads
SESSIONS_PATH=data/training/sessions

+# ============================================================================
+# Feedback Bus (Léa speaks during execution)
+# ============================================================================
+# Unified 'lea:*' SocketIO bus (action_started, action_done, need_confirm, paused).
+# Disabled by default. Set to 1 to enable real-time bubbles in ChatWindow.
+# If the bus connection fails, execution continues normally (fail-safe).
+LEA_FEEDBACK_BUS=0
+
# ============================================================================
# FAISS
# ============================================================================
@@ -133,6 +133,28 @@ def _streaming_headers() -> dict:
        headers["Authorization"] = f"Bearer {_STREAMING_API_TOKEN}"
    return headers


+# ============================================================
+# Feedback Bus — real-time 'lea:*' events to ChatWindow
+# ============================================================
+LEA_FEEDBACK_BUS = os.environ.get("LEA_FEEDBACK_BUS", "0").lower() in ("1", "true", "yes", "on")
+
+
+def _emit_lea(event: str, payload: Dict[str, Any]) -> None:
+    """Emit 'lea:{event}' on the SocketIO bus. Silent no-op if the flag is off or an error occurs."""
+    if not LEA_FEEDBACK_BUS:
+        return
+    try:
+        socketio.emit(f"lea:{event}", payload)
+    except Exception:
+        logger.debug("_emit_lea silenced", exc_info=True)
+
+
+def _emit_dual(legacy_event: str, lea_event: str, payload: Dict[str, Any], **kwargs) -> None:
+    """Emit the legacy event (dashboard compat) AND the 'lea:*' alias (tkinter ChatWindow)."""
+    socketio.emit(legacy_event, payload, **kwargs)
+    _emit_lea(lea_event, payload)
+
execution_status = {
    "running": False,
    "workflow": None,
@@ -623,7 +645,7 @@ def api_execute():
    }

    # Notify via WebSocket
-   socketio.emit('execution_started', {
+   _emit_dual('execution_started', 'action_started', {
        "workflow": match.workflow_name,
        "params": all_params
    })
@@ -1181,28 +1203,28 @@ def _execute_gesture(gesture):
        )

        if resp.status_code == 200:
-           socketio.emit('execution_completed', {
+           _emit_dual('execution_completed', 'done', {
                "workflow": gesture.name,
                "success": True,
                "message": f"Geste '{gesture.name}' ({'+'.join(gesture.keys)}) envoyé",
            })
        else:
            error = resp.text[:200]
-           socketio.emit('execution_completed', {
+           _emit_dual('execution_completed', 'done', {
                "workflow": gesture.name,
                "success": False,
                "message": f"Erreur: {error}",
            })

    except http_requests.ConnectionError:
-       socketio.emit('execution_completed', {
+       _emit_dual('execution_completed', 'done', {
            "workflow": gesture.name,
            "success": False,
            "message": "Serveur de streaming non disponible (port 5005).",
        })
    except Exception as e:
        logger.error(f"Gesture execution error: {e}")
-       socketio.emit('execution_completed', {
+       _emit_dual('execution_completed', 'done', {
            "workflow": gesture.name,
            "success": False,
            "message": f"Erreur: {str(e)}",
@@ -1661,6 +1683,52 @@ def handle_copilot_abort():
    })


+# =============================================================================
+# paused_need_help bubble — SocketIO handlers from ChatWindow (J3.5)
+# =============================================================================
+
+@socketio.on('lea:replay_resume')
+def handle_lea_replay_resume(data):
+    """Continue button: relay the resume to the streaming server."""
+    replay_id = (data or {}).get("replay_id")
+    if not replay_id:
+        _emit_lea("resume_acked", {"status": "error", "detail": "replay_id manquant"})
+        return
+    try:
+        resp = http_requests.post(
+            f"{STREAMING_SERVER_URL}/api/v1/traces/stream/replay/{replay_id}/resume",
+            headers=_streaming_headers(),
+            timeout=5,
+        )
+        if resp.ok:
+            logger.info(f"Replay {replay_id} resume relayé OK")
+            _emit_lea("resume_acked", {"replay_id": replay_id, "status": "ok"})
+        else:
+            detail = resp.text[:200]
+            logger.warning(f"Resume échoué (HTTP {resp.status_code}): {detail}")
+            _emit_lea("resume_acked", {
+                "replay_id": replay_id, "status": "error",
+                "http_status": resp.status_code, "detail": detail,
+            })
+    except Exception as e:
+        logger.warning(f"Resume relay error: {e}")
+        _emit_lea("resume_acked", {
+            "replay_id": replay_id, "status": "error", "detail": str(e),
+        })
+
+
+@socketio.on('lea:replay_abort')
+def handle_lea_replay_abort(data):
+    """Cancel button: stop the local polling. The replay on the streaming side
+    is cleaned up naturally on the next replay (cf. api_stream._replay_states stale)."""
+    global execution_status
+    replay_id = (data or {}).get("replay_id")
+    execution_status["running"] = False
+    execution_status["message"] = "Annulé par l'utilisateur"
+    logger.info(f"Replay {replay_id or '?'} abort par l'utilisateur (paused bubble)")
+    _emit_lea("abort_acked", {"replay_id": replay_id, "status": "ok"})
+
+
# =============================================================================
# Workflow execution
# =============================================================================
@@ -1730,14 +1798,20 @@ def _poll_replay_progress(replay_id: str, workflow_name: str, total_actions: int
    """Track a remote replay's progress via polling."""
    import time

-   max_wait = 120  # 2 minutes max
+   max_wait_running = 120  # 2 min while actively executing
+   max_wait_paused = 600   # 10 min in supervised pause (the human may take their time)
    poll_interval = 2.0
    elapsed = 0
+   was_paused = False

-   while elapsed < max_wait and execution_status.get("running"):
+   while execution_status.get("running"):
        time.sleep(poll_interval)
        elapsed += poll_interval

+       cap = max_wait_paused if was_paused else max_wait_running
+       if elapsed >= cap:
+           break
+
        try:
            resp = http_requests.get(
                f"{STREAMING_SERVER_URL}/api/v1/traces/stream/replay/{replay_id}",
@@ -1753,7 +1827,26 @@ def _poll_replay_progress(replay_id: str, workflow_name: str, total_actions: int
                failed = data.get("failed_actions", 0)
                progress = int(10 + (completed / max(total_actions, 1)) * 80)

-               socketio.emit('execution_progress', {
+               if status == "paused_need_help" and not was_paused:
+                   _emit_lea("paused", {
+                       "workflow": workflow_name,
+                       "replay_id": replay_id,
+                       "completed": completed,
+                       "total": total_actions,
+                       "failed_action": data.get("failed_action"),
+                       "reason": data.get("error") or "Action incertaine",
+                   })
+                   was_paused = True
+                   elapsed = 0
+               elif was_paused and status != "paused_need_help":
+                   _emit_lea("resumed", {
+                       "workflow": workflow_name,
+                       "replay_id": replay_id,
+                       "status_after": status,
+                   })
+                   was_paused = False
+
+               _emit_dual('execution_progress', 'action_progress', {
                    "progress": progress,
                    "step": f"Action {completed}/{total_actions} exécutée",
                    "current": completed,
@@ -1922,7 +2015,7 @@ def execute_workflow_copilot(match, params: Dict[str, Any]):

    actions = _build_actions_from_workflow(match, params)
    if not actions:
-       socketio.emit('copilot_complete', {
+       _emit_dual('copilot_complete', 'done', {
            "workflow": workflow_name,
            "status": "error",
            "message": "Aucune action exécutable dans ce workflow.",
@@ -1959,7 +2052,7 @@ def execute_workflow_copilot(match, params: Dict[str, Any]):
            break

        copilot_state["status"] = "waiting_approval"
-       socketio.emit('copilot_step', {
+       _emit_dual('copilot_step', 'need_confirm', {
            "workflow": workflow_name,
            "step_index": idx,
            "total": total,
@@ -1982,7 +2075,7 @@ def execute_workflow_copilot(match, params: Dict[str, Any]):

        if waited >= max_wait:
            copilot_state["status"] = "aborted"
-           socketio.emit('copilot_complete', {
+           _emit_dual('copilot_complete', 'done', {
                "workflow": workflow_name,
                "status": "timeout",
                "message": f"Timeout : pas de réponse après {max_wait}s.",
@@ -1999,7 +2092,7 @@ def execute_workflow_copilot(match, params: Dict[str, Any]):
        elif decision == "skipped":
            copilot_state["skipped"] += 1
            logger.info(f"Copilot skip étape {idx + 1}/{total}")
-           socketio.emit('copilot_step_result', {
+           _emit_dual('copilot_step_result', 'step_result', {
                "step_index": idx,
                "total": total,
                "status": "skipped",
@@ -2034,7 +2127,7 @@ def execute_workflow_copilot(match, params: Dict[str, Any]):

        if action_success:
            copilot_state["completed"] += 1
-           socketio.emit('copilot_step_result', {
+           _emit_dual('copilot_step_result', 'step_result', {
                "step_index": idx,
                "total": total,
                "status": "completed",
@@ -2042,7 +2135,7 @@ def execute_workflow_copilot(match, params: Dict[str, Any]):
            })
        else:
            copilot_state["failed"] += 1
-           socketio.emit('copilot_step_result', {
+           _emit_dual('copilot_step_result', 'step_result', {
                "step_index": idx,
                "total": total,
                "status": "failed",
@@ -2051,7 +2144,7 @@ def execute_workflow_copilot(match, params: Dict[str, Any]):
        else:
            error = resp.text[:200]
            copilot_state["failed"] += 1
-           socketio.emit('copilot_step_result', {
+           _emit_dual('copilot_step_result', 'step_result', {
                "step_index": idx,
                "total": total,
                "status": "failed",
@@ -2060,7 +2153,7 @@ def execute_workflow_copilot(match, params: Dict[str, Any]):

    except http_requests.ConnectionError:
        copilot_state["failed"] += 1
-       socketio.emit('copilot_step_result', {
+       _emit_dual('copilot_step_result', 'step_result', {
            "step_index": idx,
            "total": total,
            "status": "failed",
@@ -2070,7 +2163,7 @@ def execute_workflow_copilot(match, params: Dict[str, Any]):
    except Exception as e:
        copilot_state["failed"] += 1
        logger.error(f"Copilot action error: {e}")
-       socketio.emit('copilot_step_result', {
+       _emit_dual('copilot_step_result', 'step_result', {
            "step_index": idx,
            "total": total,
            "status": "failed",
@@ -2098,7 +2191,7 @@ def execute_workflow_copilot(match, params: Dict[str, Any]):
        f"Copilot terminé : {completed} réussies, "
        f"{skipped} passées, {failed} échouées sur {total} étapes."
    )
-   socketio.emit('copilot_complete', {
+   _emit_dual('copilot_complete', 'done', {
        "workflow": workflow_name,
        "status": "completed" if success else "partial",
        "message": message,
@@ -2175,7 +2268,7 @@ def execute_workflow(match, params):
    execution_status["progress"] = 10
    execution_status["message"] = f"Envoyé à l'Agent V1 ({target_session})"

-   socketio.emit('execution_progress', {
+   _emit_dual('execution_progress', 'action_progress', {
        "progress": 10,
        "step": f"Replay envoyé à l'Agent V1 — {total_actions} actions en attente",
        "current": 0,
@@ -2523,7 +2616,7 @@ def update_progress(progress: int, message: str, current: int, total: int):
    execution_status["progress"] = progress
    execution_status["message"] = message

-   socketio.emit('execution_progress', {
+   _emit_dual('execution_progress', 'action_progress', {
        "progress": progress,
        "step": message,
        "current": current,
@@ -2543,7 +2636,7 @@ def finish_execution(workflow_name: str, success: bool, message: str):
    if command_history:
        command_history[-1]["status"] = "completed" if success else "failed"

-   socketio.emit('execution_completed', {
+   _emit_dual('execution_completed', 'done', {
        "workflow": workflow_name,
        "success": success,
        "message": message
agent_v0/agent_v1/network/feedback_bus.py (new file, 149 lines)
@@ -0,0 +1,149 @@
# agent_v1/network/feedback_bus.py
"""SocketIO client for the Léa feedback bus.

Consumes the 'lea:*' events emitted by agent_chat (port 5004) and dispatches
them to ChatWindow for display as real-time bubbles.

Events listened to:
    lea:action_started  — start of a workflow or an action
    lea:action_progress — progress within the workflow
    lea:done            — end of a workflow or a copilot run
    lea:need_confirm    — copilot step awaiting validation
    lea:step_result     — result of a copilot step
    lea:paused          — switch to paused_need_help (demo asset)
    lea:resumed         — exit from supervised pause

Fail-safe: any connection or dispatch error is silently logged. ChatWindow
keeps working even if the bus is dead (behaviour strictly identical to
pre-J3).

Usage:
    bus = FeedbackBusClient(
        server_url="http://localhost:5004",
        token=os.environ.get("RPA_API_TOKEN", ""),
        on_event=lambda event, payload: print(event, payload),
    )
    bus.start()  # background connection, non-blocking
    # ... ChatWindow runs ...
    bus.stop()
"""

import logging
import threading
from typing import Callable, Optional

import socketio

logger = logging.getLogger(__name__)

LEA_EVENTS = (
    'lea:action_started',
    'lea:action_progress',
    'lea:done',
    'lea:need_confirm',
    'lea:step_result',
    'lea:paused',
    'lea:resumed',
)

EventCallback = Callable[[str, dict], None]


class FeedbackBusClient:
    """Non-blocking SocketIO client for the 'lea:*' bus."""

    def __init__(
        self,
        server_url: str,
        token: Optional[str] = None,
        on_event: Optional[EventCallback] = None,
    ):
        self._url = server_url.rstrip('/')
        self._token = token or None
        self._on_event: EventCallback = on_event or (lambda e, p: None)
        self._sio = socketio.Client(
            reconnection=True,
            reconnection_attempts=0,  # 0 = unlimited
            reconnection_delay=2,
            reconnection_delay_max=30,
            logger=False,
            engineio_logger=False,
        )
        self._thread: Optional[threading.Thread] = None
        self._register_handlers()

    def _register_handlers(self) -> None:
        @self._sio.event
        def connect():
            logger.info("FeedbackBus connecté à %s", self._url)

        @self._sio.event
        def disconnect():
            logger.info("FeedbackBus déconnecté")

        for ev in LEA_EVENTS:
            self._sio.on(ev, lambda data, e=ev: self._dispatch(e, data))

    def _dispatch(self, event: str, payload: Optional[dict]) -> None:
        try:
            self._on_event(event, payload or {})
        except Exception:
            logger.debug("FeedbackBus dispatch silenced", exc_info=True)

    def start(self) -> None:
        """Start the connection in the background (idempotent, non-blocking)."""
        if self._thread is not None and self._thread.is_alive():
            return
        self._thread = threading.Thread(
            target=self._run, daemon=True, name="LeaFeedbackBus",
        )
        self._thread.start()

    def _run(self) -> None:
        headers = {}
        if self._token:
            headers['Authorization'] = f'Bearer {self._token}'
        try:
            self._sio.connect(self._url, headers=headers, wait=True)
            self._sio.wait()
        except Exception as e:
            logger.warning(
                "FeedbackBus connect échoué (%s) — ChatWindow continue normalement", e,
            )

    def stop(self) -> None:
        """Cleanly stop the connection (idempotent, fail-safe)."""
        try:
            if self._sio.connected:
                self._sio.disconnect()
        except Exception:
            logger.debug("FeedbackBus stop silenced", exc_info=True)

    @property
    def connected(self) -> bool:
        return bool(self._sio.connected)

    # ------------------------------------------------------------------
    # User actions from the paused_need_help bubble (J3.5)
    # ------------------------------------------------------------------

    def resume_replay(self, replay_id: str) -> bool:
        """Continue button: emits 'lea:replay_resume' towards agent_chat.

        Returns True if the event could be emitted, False otherwise (disconnected/error).
        """
        return self._safe_emit("lea:replay_resume", {"replay_id": replay_id})

    def abort_replay(self, replay_id: str) -> bool:
        """Cancel button: emits 'lea:replay_abort' towards agent_chat."""
        return self._safe_emit("lea:replay_abort", {"replay_id": replay_id})

    def _safe_emit(self, event: str, payload: dict) -> bool:
        try:
            if not self._sio.connected:
                return False
            self._sio.emit(event, payload)
            return True
        except Exception:
            logger.debug("FeedbackBus _safe_emit silenced", exc_info=True)
            return False
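One registration detail in _register_handlers deserves a comment: the `e=ev` default argument freezes the event name at each loop iteration. A bare closure over `ev` would late-bind, and every handler would dispatch the final entry of LEA_EVENTS. A standalone demonstration of the pitfall:

# Late binding: every callback sees the loop variable's final value.
bad = [lambda: ev for ev in ("a", "b", "c")]
print([f() for f in bad])     # ['c', 'c', 'c']

# Default-argument capture, as used in _register_handlers, freezes it per iteration.
good = [lambda e=ev: e for ev in ("a", "b", "c")]
print([f() for f in good])    # ['a', 'b', 'c']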
@@ -3,6 +3,7 @@ mss>=9.0.1 # Capture d'écran haute performance
|
||||
pynput>=1.7.7 # Clavier/Souris Cross-plateforme
|
||||
Pillow>=10.0.0 # Crops et processing image
|
||||
requests>=2.31.0 # Streaming réseau
|
||||
python-socketio[client]>=5.10,<6.0 # Bus feedback Léa 'lea:*' (compat Flask-SocketIO 5.3.x serveur)
|
||||
psutil>=5.9.0 # Monitoring CPU/RAM
|
||||
pystray>=0.19.5 # Icône Tray UI
|
||||
plyer>=2.1.0 # Notifications toast natives (remplace PyQt5)
|
||||
|
||||
@@ -16,6 +16,15 @@ from typing import Any, Callable, Dict, Optional

logger = logging.getLogger(__name__)

+# FeedbackBus: fail-safe import (ChatWindow must keep running even if
+# python-socketio is not installed on the client machine, e.g. an old Pauline install)
+try:
+    from ..network.feedback_bus import FeedbackBusClient
+    _HAS_FEEDBACK_BUS = True
+except Exception:
+    FeedbackBusClient = None  # type: ignore
+    _HAS_FEEDBACK_BUS = False
+
# ---------------------------------------------------------------------------
# Theme — light professional palette
# ---------------------------------------------------------------------------
@@ -42,6 +51,25 @@ SCROLLBAR_BG = "#E5E7EB"  # Scrollbar background
SCROLLBAR_FG = "#9CA3AF"      # Scrollbar thumb
MSG_BORDER_COLOR = "#D1D5DB"  # Subtle border on message bubbles

+# paused_need_help bubble (J3.5) — non-blocking alert, major demo asset
+PAUSED_BG = "#FEF3C7"             # Pale yellow
+PAUSED_BORDER = "#F59E0B"         # Amber orange
+PAUSED_FG = "#92400E"             # Dark brown (readable on yellow)
+PAUSED_BTN_RESUME_BG = "#22C55E"  # Green
+PAUSED_BTN_RESUME_HOVER = "#16A34A"
+PAUSED_BTN_ABORT_BG = "#9CA3AF"   # Neutral grey (not dramatic)
+PAUSED_BTN_ABORT_HOVER = "#6B7280"
+
+# "Léa is executing" bubble (J3.4) — distinct from normal chat bubbles
+ACTION_BG = "#F1F5F9"        # Very light grey (distinguishes it from a chat reply)
+ACTION_BORDER = "#CBD5E1"    # Pale grey
+ACTION_FG = "#1E293B"        # Dark grey
+ACTION_META_FG = "#94A3B8"   # Metadata in discreet grey
+ACTION_ICON_RUN = "#3B82F6"  # Blue (running)
+ACTION_ICON_OK = "#22C55E"   # Green (success)
+ACTION_ICON_ERR = "#EF4444"  # Red (failure)
+ACTION_ICON_INFO = "#64748B" # Grey (neutral)
+
# Dimensions — comfortable
WIN_WIDTH = 600
WIN_HEIGHT = 800
@@ -62,6 +90,80 @@ FONT_SEND_BTN = ("Segoe UI", 13)
FONT_RESIZE_GRIP = ("Segoe UI", 10)


+# ---------------------------------------------------------------------------
+# "Léa is executing" bubble templates (J3.4)
+# Each template takes a payload and returns (icon, icon_color, title).
+# Labels are deliberately neutral: the business context comes from the
+# payload (workflow, action, message), not from hardcoding.
+# ---------------------------------------------------------------------------
+
+def _tpl_action_started(payload: Dict[str, Any]) -> tuple:
+    wf = payload.get("workflow") or "?"
+    return ("▶", ACTION_ICON_RUN, f"Démarrage : {wf}")
+
+
+def _tpl_action_progress(payload: Dict[str, Any]) -> tuple:
+    cur = payload.get("current", "?")
+    tot = payload.get("total", "?")
+    step = payload.get("step")
+    title = step if step else f"Étape {cur}/{tot}"
+    return ("⋯", ACTION_ICON_RUN, str(title))
+
+
+def _tpl_done(payload: Dict[str, Any]) -> tuple:
+    success = bool(payload.get("success", True))
+    msg = payload.get("message") or ("Terminé" if success else "Échec")
+    if success:
+        return ("✓", ACTION_ICON_OK, str(msg))
+    return ("✗", ACTION_ICON_ERR, str(msg))
+
+
+def _tpl_need_confirm(payload: Dict[str, Any]) -> tuple:
+    action = payload.get("action") or {}
+    desc = action.get("description") if isinstance(action, dict) else None
+    title = desc or "Validation requise"
+    return ("?", ACTION_ICON_RUN, str(title))
+
+
+def _tpl_step_result(payload: Dict[str, Any]) -> tuple:
+    status = (payload.get("status") or "").lower()
+    msg = payload.get("message") or status or "Étape terminée"
+    if status in ("ok", "success", "approved"):
+        return ("✓", ACTION_ICON_OK, str(msg))
+    if status in ("error", "failed"):
+        return ("✗", ACTION_ICON_ERR, str(msg))
+    return ("·", ACTION_ICON_INFO, str(msg))
+
+
+def _tpl_resumed(payload: Dict[str, Any]) -> tuple:
+    return ("→", ACTION_ICON_OK, "Reprise")
+
+
+_ACTION_TEMPLATES = {
+    "lea:action_started": _tpl_action_started,
+    "lea:action_progress": _tpl_action_progress,
+    "lea:done": _tpl_done,
+    "lea:need_confirm": _tpl_need_confirm,
+    "lea:step_result": _tpl_step_result,
+    "lea:resumed": _tpl_resumed,
+}
+
+
+def _extract_meta(payload: Dict[str, Any]) -> str:
+    """Technical metadata for the bubble footer (workflow, step, short replay_id)."""
+    parts = []
+    wf = payload.get("workflow")
+    if wf:
+        parts.append(str(wf))
+    cur, tot = payload.get("current"), payload.get("total")
+    if cur is not None and tot is not None:
+        parts.append(f"étape {cur}/{tot}")
+    rid = payload.get("replay_id")
+    if rid:
+        parts.append(f"#{str(rid)[-6:]}")
+    return " • ".join(parts)
+
+
class ChatWindow:
    """Native Léa chat window in plain tkinter.

@@ -91,6 +193,8 @@ class ChatWindow:
        self._root = None
        self._ready = threading.Event()
        self._messages = []  # local history
+       self._bus: Optional[Any] = None  # FeedbackBusClient (J3.3, may remain None)
+       self._active_paused_bubble: Optional[Dict[str, Any]] = None  # active paused bubble (J3.5)

        # Subscribe to shared-state changes
        if self._shared_state is not None:
@@ -266,6 +370,9 @@ class ChatWindow:
        # Signal that the window is ready
        self._ready.set()

+       # Start the Léa feedback bus (real-time 'lea:*' events)
+       self._start_feedback_bus()
+
        # tkinter loop
        root.mainloop()

@@ -608,6 +715,12 @@ class ChatWindow:

    def _do_destroy(self) -> None:
        """Destroy the window (called in the tkinter thread)."""
+       if self._bus is not None:
+           try:
+               self._bus.stop()
+           except Exception:
+               pass
+           self._bus = None
        if self._root is not None:
            try:
                self._root.quit()
@@ -617,6 +730,232 @@ class ChatWindow:
        self._root = None
        self._visible = False

+   # ======================================================================
+   # FeedbackBus — real-time bubbles during execution (J3.3)
+   # ======================================================================
+
+   def _start_feedback_bus(self) -> None:
+       """Start the connection to the 'lea:*' bus if the flag is on and the lib is available."""
+       if not _HAS_FEEDBACK_BUS:
+           logger.debug("FeedbackBus non disponible (python-socketio manquant)")
+           return
+       flag = os.environ.get("LEA_FEEDBACK_BUS", "0").lower()
+       if flag not in ("1", "true", "yes", "on"):
+           return
+       try:
+           url = f"http://{self._server_host}:{self._chat_port}"
+           token = os.environ.get("RPA_API_TOKEN", "") or None
+           self._bus = FeedbackBusClient(url, token=token, on_event=self._on_lea_event)
+           self._bus.start()
+           logger.info("FeedbackBus demarre : %s", url)
+       except Exception:
+           logger.debug("FeedbackBus init silenced", exc_info=True)
+           self._bus = None
+
+   def _on_lea_event(self, event: str, payload: Dict[str, Any]) -> None:
+       """Bus callback → Léa bubble. Thread-safe: helpers go through root.after."""
+       payload = payload or {}
+
+       # J3.5: the supervised pause has its own interactive bubble
+       if event == "lea:paused":
+           self._add_paused_bubble(payload)
+           return
+       if event in ("lea:resumed", "lea:done"):
+           self._close_active_paused_bubble(reason=event)
+           # fall through to display the action bubble (cf. dispatch below)
+
+       # Bus acks (resume_acked, abort_acked): silent on the UI side
+       if event in ("lea:resume_acked", "lea:abort_acked"):
+           return
+
+       # J3.4: styled "Léa is executing" bubble (separate from normal chat bubbles)
+       rendered = _ACTION_TEMPLATES.get(event)
+       if rendered is None:
+           # Unknown event: display it as a neutral action bubble
+           self._add_action_bubble(
+               icon="·", icon_color=ACTION_ICON_INFO,
+               title=event.removeprefix("lea:"),
+               meta=_extract_meta(payload),
+           )
+           return
+       icon, icon_color, title = rendered(payload)
+       self._add_action_bubble(
+           icon=icon, icon_color=icon_color, title=title,
+           meta=_extract_meta(payload),
+       )
+
+   # ------------------------------------------------------------------
+   # Styled "Léa is executing" bubble (J3.4)
+   # ------------------------------------------------------------------
+
+   def _add_action_bubble(
+       self, icon: str, icon_color: str, title: str, meta: str = "",
+   ) -> None:
+       if self._root is None:
+           return
+       self._root.after(0, lambda: self._render_action_bubble(icon, icon_color, title, meta))
+
+   def _render_action_bubble(
+       self, icon: str, icon_color: str, title: str, meta: str,
+   ) -> None:
+       tk = self._tk
+       if getattr(self, "_msg_frame", None) is None:
+           return
+       now = datetime.now().strftime("%H:%M")
+
+       container = tk.Frame(self._msg_frame, bg=BG_COLOR)
+       container.pack(fill=tk.X, padx=MARGIN, pady=3)
+
+       inner = tk.Frame(
+           container, bg=ACTION_BG, padx=10, pady=6,
+           highlightbackground=ACTION_BORDER, highlightthickness=1,
+       )
+       inner.pack(anchor=tk.W, padx=(0, 70), fill=tk.X)
+
+       row = tk.Frame(inner, bg=ACTION_BG)
+       row.pack(fill=tk.X, anchor=tk.W)
+
+       tk.Label(
+           row, text=icon, bg=ACTION_BG, fg=icon_color,
+           font=("Segoe UI", 13, "bold"), padx=4,
+       ).pack(side=tk.LEFT)
+
+       tk.Label(
+           row, text=title, bg=ACTION_BG, fg=ACTION_FG,
+           font=FONT_MSG, anchor="w", justify=tk.LEFT,
+           wraplength=MSG_WRAP_WIDTH - 60,
+       ).pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(2, 0))
+
+       if meta:
+           tk.Label(
+               inner, text=f"{meta} • {now}",
+               bg=ACTION_BG, fg=ACTION_META_FG,
+               font=FONT_TIMESTAMP, anchor="w",
+           ).pack(fill=tk.X, anchor=tk.W, pady=(2, 0))
+
+   # ------------------------------------------------------------------
+   # Interactive paused_need_help bubble (J3.5)
+   # ------------------------------------------------------------------
+
+   def _add_paused_bubble(self, payload: Dict[str, Any]) -> None:
+       """Add an interactive paused bubble (demo asset: Léa asks for help)."""
+       if self._root is None:
+           return
+       self._root.after(0, lambda: self._render_paused_bubble(payload))
+
+   def _render_paused_bubble(self, payload: Dict[str, Any]) -> None:
+       tk = self._tk
+       if getattr(self, "_msg_frame", None) is None:
+           return
+
+       replay_id = str(payload.get("replay_id", "") or "")
+       workflow = payload.get("workflow", "?")
+       reason = payload.get("reason") or "Action incertaine — j'ai besoin de votre validation."
+       completed = payload.get("completed", 0)
+       total = payload.get("total", "?")
+       now = datetime.now().strftime("%H:%M")
+
+       container = tk.Frame(self._msg_frame, bg=BG_COLOR)
+       container.pack(fill=tk.X, padx=MARGIN, pady=6)
+
+       inner = tk.Frame(
+           container, bg=PAUSED_BG, padx=14, pady=12,
+           highlightbackground=PAUSED_BORDER, highlightthickness=2,
+       )
+       inner.pack(anchor=tk.W, padx=(0, 50), fill=tk.X)
+
+       tk.Label(
+           inner, text=f"⏸ Pause supervisée • {now}",
+           bg=PAUSED_BG, fg=PAUSED_FG,
+           font=("Segoe UI", 12, "bold"), anchor="w",
+       ).pack(fill=tk.X, anchor=tk.W)
+
+       tk.Label(
+           inner, text=reason, bg=PAUSED_BG, fg=PAUSED_FG,
+           font=FONT_MSG, wraplength=MSG_WRAP_WIDTH - 30,
+           anchor="w", justify=tk.LEFT,
+       ).pack(fill=tk.X, anchor=tk.W, pady=(6, 0))
+
+       tk.Label(
+           inner, text=f"{workflow} — étape {completed}/{total}",
+           bg=PAUSED_BG, fg=TIMESTAMP_FG, font=FONT_TIMESTAMP, anchor="w",
+       ).pack(fill=tk.X, anchor=tk.W, pady=(4, 8))
+
+       btn_frame = tk.Frame(inner, bg=PAUSED_BG)
+       btn_frame.pack(fill=tk.X, anchor=tk.W)
+
+       btn_resume = tk.Button(
+           btn_frame, text="Continuer",
+           bg=PAUSED_BTN_RESUME_BG, fg="white", font=FONT_QUICK_BTN,
+           padx=14, pady=4, bd=0, cursor="hand2",
+           activebackground=PAUSED_BTN_RESUME_HOVER, activeforeground="white",
+           command=lambda: self._on_paused_resume(replay_id),
+       )
+       btn_resume.pack(side=tk.LEFT, padx=(0, 8))
+
+       btn_abort = tk.Button(
+           btn_frame, text="Annuler",
+           bg=PAUSED_BTN_ABORT_BG, fg="white", font=FONT_QUICK_BTN,
+           padx=14, pady=4, bd=0, cursor="hand2",
+           activebackground=PAUSED_BTN_ABORT_HOVER, activeforeground="white",
+           command=lambda: self._on_paused_abort(replay_id),
+       )
+       btn_abort.pack(side=tk.LEFT)
+
+       self._active_paused_bubble = {
+           "container": container, "inner": inner,
+           "btn_resume": btn_resume, "btn_abort": btn_abort,
+           "replay_id": replay_id,
+       }
+
+   def _close_active_paused_bubble(self, reason: str) -> None:
+       if self._active_paused_bubble is None or self._root is None:
+           return
+       self._root.after(0, lambda: self._do_close_paused_bubble(reason))
+
+   def _do_close_paused_bubble(self, reason: str) -> None:
+       bubble = self._active_paused_bubble
+       if bubble is None:
+           return
+       try:
+           bubble["btn_resume"].config(state="disabled")
+           bubble["btn_abort"].config(state="disabled")
+           label_text = {
+               "lea:resumed": "→ Reprise",
+               "lea:done": "→ Terminé",
+           }.get(reason, f"→ {reason}")
+           self._tk.Label(
+               bubble["inner"], text=label_text,
+               bg=PAUSED_BG, fg=PAUSED_FG, font=FONT_TIMESTAMP, anchor="w",
+           ).pack(fill="x", anchor="w", pady=(6, 0))
+       except Exception:
+           logger.debug("close paused bubble silenced", exc_info=True)
+       self._active_paused_bubble = None
+
+   def _on_paused_resume(self, replay_id: str) -> None:
+       if not replay_id or self._bus is None or not self._bus.connected:
+           self._add_lea_message("⚠ Bus indisponible — impossible de relancer")
+           return
+       self._bus.resume_replay(replay_id)
+       if self._active_paused_bubble:
+           try:
+               self._active_paused_bubble["btn_resume"].config(state="disabled")
+               self._active_paused_bubble["btn_abort"].config(state="disabled")
+           except Exception:
+               pass
+
+   def _on_paused_abort(self, replay_id: str) -> None:
+       if self._bus is None or not self._bus.connected:
+           self._add_lea_message("⚠ Bus indisponible — impossible d'annuler")
+           return
+       self._bus.abort_replay(replay_id)
+       if self._active_paused_bubble:
+           try:
+               self._active_paused_bubble["btn_resume"].config(state="disabled")
+               self._active_paused_bubble["btn_abort"].config(state="disabled")
+           except Exception:
+               pass
+
    # ======================================================================
    # Adding messages to the chat area
    # ======================================================================
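The bubble helpers all funnel through root.after(0, ...) because the SocketIO client invokes _on_lea_event on its own thread, while tkinter widgets may only be touched from the mainloop thread. A self-contained sketch of the same pattern (the widget and payload here are illustrative, not the ChatWindow API):

import threading
import tkinter as tk

root = tk.Tk()
label = tk.Label(root, text="waiting...")
label.pack()

def on_bus_event(payload: dict) -> None:
    # Called from a worker thread: marshal the widget update to the
    # tkinter thread instead of touching the Label directly.
    root.after(0, lambda: label.config(text=payload.get("title", "?")))

# Simulate a bus event arriving from another thread one second later.
threading.Timer(1.0, on_bus_event, args=({"title": "lea:done"},)).start()
root.mainloop()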
@@ -3,6 +3,7 @@ mss>=9.0.1 # Capture d'écran haute performance
|
||||
pynput>=1.7.7 # Clavier/Souris Cross-plateforme
|
||||
Pillow>=10.0.0 # Crops et processing image
|
||||
requests>=2.31.0 # Streaming réseau
|
||||
python-socketio[client]>=5.10,<6.0 # Bus feedback Léa 'lea:*' (compat Flask-SocketIO 5.3.x serveur)
|
||||
psutil>=5.9.0 # Monitoring CPU/RAM
|
||||
pystray>=0.19.5 # Icône Tray UI
|
||||
plyer>=2.1.0 # Notifications toast natives (remplace PyQt5)
|
||||
|
||||
@@ -116,13 +116,13 @@ def check_screen_for_patterns() -> Optional[Dict[str, Any]]:
|
||||
|
||||
pattern = lib.find_pattern(ocr_text)
|
||||
if pattern and pattern['category'] in ('dialog', 'popup'):
|
||||
logger.info(f"Pattern UI détecté: {pattern['pattern']} → {pattern['action']} '{pattern['target']}'")
|
||||
print(f"🧠 [PatternCheck] Détecté: '{pattern['pattern']}' → {pattern['action']} '{pattern['target']}'")
|
||||
return pattern
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Pattern check échoué: {e}")
|
||||
print(f"⚠️ [PatternCheck] Erreur: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -145,26 +145,40 @@ def handle_detected_pattern(pattern: Dict[str, Any]) -> bool:
|
||||
|
||||
if action == 'click':
|
||||
candidates_labels = [target] + alternatives
|
||||
print(f"🔧 [Réflexe/handle] Recherche bouton parmi: {candidates_labels}")
|
||||
|
||||
try:
|
||||
import mss
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
# Importer OCR (essayer les deux chemins)
|
||||
try:
|
||||
from services.ocr_service import ocr_extract_words
|
||||
except ImportError:
|
||||
from core.extraction.field_extractor import FieldExtractor
|
||||
extractor = FieldExtractor()
|
||||
def ocr_extract_words(img):
|
||||
return extractor.extract_words_from_image(img)
|
||||
|
||||
with mss.mss() as sct:
|
||||
monitor = sct.monitors[0]
|
||||
screenshot = sct.grab(monitor)
|
||||
screen = Image.frombytes('RGB', screenshot.size, screenshot.bgra, 'raw', 'BGRX')
|
||||
|
||||
words = ocr_extract_words(screen)
|
||||
# EasyOCR (rapide, bonne qualité GUI) avec fallback docTR
|
||||
words = []
|
||||
try:
|
||||
import easyocr
|
||||
_reader = easyocr.Reader(['fr', 'en'], gpu=False, verbose=False)
|
||||
results = _reader.readtext(np.array(screen))
|
||||
for (bbox_pts, text, conf) in results:
|
||||
if not text or len(text.strip()) < 1:
|
||||
continue
|
||||
x1 = int(min(p[0] for p in bbox_pts))
|
||||
y1 = int(min(p[1] for p in bbox_pts))
|
||||
x2 = int(max(p[0] for p in bbox_pts))
|
||||
y2 = int(max(p[1] for p in bbox_pts))
|
||||
words.append({'text': text.strip(), 'bbox': [x1, y1, x2, y2]})
|
||||
except ImportError:
|
||||
try:
|
||||
from services.ocr_service import ocr_extract_words
|
||||
words = ocr_extract_words(screen) or []
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
print(f"🔧 [Réflexe/handle] {len(words)} mots OCR détectés")
|
||||
|
||||
# Collecter tous les matchs, prendre le plus bas (bouton = bas du dialogue)
|
||||
all_matches = []
|
||||
@@ -175,58 +189,28 @@ def handle_detected_pattern(pattern: Dict[str, Any]) -> bool:
|
||||
word_text = word['text'].lower()
|
||||
if len(word_text) < 2 or len(candidate_lower) < 2:
|
||||
continue
|
||||
if word_text == candidate_lower:
|
||||
# Match exact ou inclusion
|
||||
if word_text == candidate_lower or candidate_lower in word_text or word_text in candidate_lower:
|
||||
x1, y1, x2, y2 = word['bbox']
|
||||
all_matches.append({
|
||||
'text': word['text'],
|
||||
'x': int((x1 + x2) / 2),
|
||||
'y': int((y1 + y2) / 2),
|
||||
'match_type': 'exact',
|
||||
'candidate': candidate,
|
||||
})
|
||||
|
||||
# Recherche partielle (lettre soulignée manquante)
|
||||
if not all_matches:
|
||||
for candidate in candidates_labels:
|
||||
if len(candidate) > 3:
|
||||
partial = candidate[1:].lower()
|
||||
for word in words:
|
||||
if partial in word['text'].lower():
|
||||
x1, y1, x2, y2 = word['bbox']
|
||||
all_matches.append({
|
||||
'text': word['text'],
|
||||
'x': int((x1 + x2) / 2),
|
||||
'y': int((y1 + y2) / 2),
|
||||
'match_type': 'partial',
|
||||
})
|
||||
|
||||
if all_matches:
|
||||
best = max(all_matches, key=lambda m: m['y'])
|
||||
logger.info(f"Clic sur '{best['text']}' à ({best['x']}, {best['y']})")
|
||||
print(f"✅ [Réflexe/handle] Clic sur '{best['text']}' à ({best['x']}, {best['y']})")
|
||||
pyautogui.click(best['x'], best['y'])
|
||||
time.sleep(1.0)
|
||||
return True
|
||||
|
||||
logger.info(f"Bouton '{target}' introuvable par OCR — appel VLM...")
|
||||
vlm_result = vlm_reason_about_screen(
|
||||
objective=f"Cliquer sur le bouton '{target}'",
|
||||
context=f"Un dialogue '{pattern.get('pattern')}' est détecté"
|
||||
)
|
||||
if vlm_result and vlm_result.get('action') == 'click' and vlm_result.get('target'):
|
||||
vlm_target = vlm_result['target']
|
||||
for word in words:
|
||||
if vlm_target.lower() in word['text'].lower():
|
||||
x1, y1, x2, y2 = word['bbox']
|
||||
x = int((x1 + x2) / 2)
|
||||
y = int((y1 + y2) / 2)
|
||||
logger.info(f"VLM → clic sur '{word['text']}' à ({x}, {y})")
|
||||
pyautogui.click(x, y)
|
||||
time.sleep(1.0)
|
||||
return True
|
||||
|
||||
print(f"⚠️ [Réflexe/handle] Bouton '{target}' introuvable parmi {[w['text'] for w in words[:15]]}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"OCR bouton échoué: {e}")
|
||||
print(f"⚠️ [Réflexe/handle] Erreur: {e}")
|
||||
return False
|
||||
|
||||
elif action == 'hotkey':
|
||||
|
||||
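One caveat on the EasyOCR path above: easyocr.Reader loads its detection and recognition models at construction time, which typically takes seconds on CPU, and the handler builds a fresh Reader on every pattern check. A lazy module-level cache would keep the behaviour while paying that cost once. This is a suggested refinement, not part of the diff:

from functools import lru_cache

@lru_cache(maxsize=1)
def _get_reader():
    # Imported lazily so the existing ImportError fallback path still applies.
    import easyocr
    return easyocr.Reader(['fr', 'en'], gpu=False, verbose=False)

The call site would then read `results = _get_reader().readtext(np.array(screen))`.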
@@ -213,8 +213,40 @@ class ORALoop:
|
||||
|
||||
# --- Mapper action_type vers action Decision ---
|
||||
|
||||
# Types d'action qui ne sont PAS des descriptions valides
|
||||
_action_type_names = {'click_anchor', 'double_click_anchor', 'right_click_anchor',
|
||||
'hover_anchor', 'focus_anchor', 'scroll_to_anchor',
|
||||
'click', 'type_text', 'keyboard_shortcut', 'wait_for_anchor'}
|
||||
|
||||
if action_type in ('click_anchor', 'click', 'double_click_anchor', 'right_click_anchor'):
|
||||
target_text = anchor.get('target_text', '') or label
|
||||
target_text = anchor.get('target_text', '') or anchor.get('description', '')
|
||||
|
||||
# Détecter les target_text absurdes : vide, nom d'action, ou bruit OCR
|
||||
def _is_garbage(t):
|
||||
if not t or t in _action_type_names:
|
||||
return True
|
||||
# Bruit OCR : que des caractères spéciaux/chiffres/espaces
|
||||
cleaned = t.replace('-', '').replace(' ', '').replace('.', '').replace('_', '')
|
||||
if len(cleaned) < 3:
|
||||
return True
|
||||
# Que des chiffres
|
||||
if cleaned.isdigit():
|
||||
return True
|
||||
return False
|
||||
|
||||
# Note: plus d'appel à _describe_anchor_image() (qwen2.5vl) ici.
|
||||
# Le crop d'ancre (screenshot_b64) servira directement au template matching
|
||||
# cv2 dans _act_click, puis fallback InfiGUI fusionné si nécessaire.
|
||||
# Cela évite le conflit VRAM (qwen2.5vl 9.4GB + InfiGUI 2.4GB > 11.5GB GPU).
|
||||
|
||||
# Dernier fallback : label si pas un nom d'action
|
||||
if _is_garbage(target_text):
|
||||
target_text = label if label not in _action_type_names else ''
|
||||
if target_text:
|
||||
print(f"🏷️ [ORA/reason] Label garbage, fallback texte: '{target_text}'")
|
||||
else:
|
||||
print(f"🏷️ [ORA/reason] Pas de label texte — grounding via crop visuel uniquement")
|
||||
|
||||
action = 'click'
|
||||
value = 'double' if action_type == 'double_click_anchor' else (
|
||||
'right' if action_type == 'right_click_anchor' else 'left')
|
||||
@@ -1222,6 +1254,7 @@ Règles:
|
||||
)
|
||||
|
||||
print(f"🚀 [ORA] Démarrage workflow: {total} étapes, verify={self.verify_level}, retries={self.max_retries}")
|
||||
print(f"🔧 [ORA] CODE VERSION: post-shortcut-dialog-handler ACTIF (26 avril 17h30)")
|
||||
|
||||
for i, step in enumerate(steps):
|
||||
if not self._should_continue():
|
||||
@@ -1234,6 +1267,28 @@ Règles:
|
||||
# --- 1. Observer l'état pré-action ---
|
||||
pre = self.observe()
|
||||
|
||||
# --- 1b. Réflexe : dialogue inattendu ? ---
|
||||
# Déclenché si le pHash a changé de manière inattendue.
|
||||
# Flux : titre fenêtre (50ms) → dialogue connu ? → InfiGUI clique (3s)
|
||||
if i > 0 and hasattr(self, '_last_post_phash') and self._last_post_phash:
|
||||
_phash_distance = self._phash_distance(pre.phash, self._last_post_phash)
|
||||
if _phash_distance > 10:
|
||||
print(f"🧠 [ORA/réflexe] pHash changé (distance={_phash_distance}) → vérification dialogue")
|
||||
try:
|
||||
from core.grounding.dialog_handler import DialogHandler
|
||||
_dh = DialogHandler()
|
||||
_dh_result = _dh.handle_if_dialog(pre.screenshot)
|
||||
if _dh_result.get('handled'):
|
||||
print(f"✅ [ORA/réflexe] Dialogue '{_dh_result['title'][:30]}' géré → {_dh_result['action']}")
|
||||
time.sleep(0.5)
|
||||
pre = self.observe()
|
||||
elif _dh_result.get('dialog_type'):
|
||||
print(f"⚠️ [ORA/réflexe] Dialogue '{_dh_result.get('dialog_type')}' détecté mais non géré: {_dh_result.get('reason')}")
|
||||
else:
|
||||
print(f"🧠 [ORA/réflexe] Pas de dialogue détecté: {_dh_result.get('reason', '?')}")
|
||||
except Exception as _reflex_err:
|
||||
print(f"⚠️ [ORA/réflexe] Erreur: {_reflex_err}")
|
||||
|
||||
# --- 2. Raisonner : construire la Decision ---
|
||||
decision = self.reason_workflow_step(step, pre)
|
||||
|
||||
@@ -1281,11 +1336,74 @@ Règles:
|
||||
)
|
||||
)
|
||||
|
||||
# --- 3b. Post-raccourci : attendre changement écran + gérer dialogue ---
|
||||
# Après un keyboard_shortcut (pas scroll), on polle le pHash pour détecter
|
||||
# si un dialogue est apparu (ex: "Enregistrer sous" après Ctrl+Shift+S).
|
||||
# Si oui → InfiGUI localise et clique le bouton visuellement.
|
||||
if act_success and decision.action == 'hotkey' and not decision.value.startswith('scroll_'):
|
||||
print(f"🔍 [ORA/post-shortcut] ENTRÉ dans le bloc post-shortcut (action={decision.action}, value={decision.value})")
|
||||
dialog_handled = self._handle_post_shortcut(pre)
|
||||
if dialog_handled:
|
||||
time.sleep(0.5)
|
||||
post = self.observe()
|
||||
self._last_post_phash = post.phash
|
||||
if on_progress:
|
||||
on_progress(i + 1, total, VerificationResult(
|
||||
success=True, change_level='major',
|
||||
matches_expected=True,
|
||||
detail="Dialogue géré visuellement après raccourci"
|
||||
))
|
||||
continue
|
||||
else:
|
||||
# Invariant : aucune étape suivante ne doit s'exécuter tant que
|
||||
# la cascade déclenchée par le raccourci n'est pas pleinement résolue.
|
||||
# Cas typique : Ctrl+S → "Enregistrer sous" non géré → on ABORT plutôt
|
||||
# que de cliquer sur des coordonnées potentiellement obsolètes.
|
||||
msg = (
|
||||
f"Étape {i+1}: raccourci '{decision.value}' — cascade post-raccourci "
|
||||
f"non résolue (dialogue absent ou bloqué). Workflow stoppé pour éviter "
|
||||
f"un clic dans un contexte incohérent."
|
||||
)
|
||||
print(f"❌ [ORA/post-shortcut] {msg}")
|
||||
logger.warning(f"🆘 [ORA] {msg}")
|
||||
if on_progress:
|
||||
on_progress(i + 1, total, VerificationResult(
|
||||
success=False, change_level='none',
|
||||
matches_expected=False,
|
||||
detail="Cascade post-raccourci non résolue"
|
||||
))
|
||||
return LoopResult(
|
||||
success=False, steps_completed=i, total_steps=total,
|
||||
reason=msg,
|
||||
)
|
||||
|
||||
# Petit délai pour laisser l'écran se stabiliser
|
||||
time.sleep(0.3)
|
||||
|
||||
# --- 4. Observer l'état post-action ---
|
||||
post = self.observe()
|
||||
# Stocker le pHash post-action pour le réflexe check du step suivant
|
||||
self._last_post_phash = post.phash
|
||||
|
||||
# --- 4b. Vérification titre OCR (non-bloquante, ~120ms) ---
|
||||
_action_type = step.get('action_type', '')
|
||||
if _action_type in ('double_click_anchor', 'click_anchor') and pre.screenshot and post.screenshot:
|
||||
try:
|
||||
from core.grounding.title_verifier import TitleVerifier
|
||||
_tv = TitleVerifier()
|
||||
_tv_result = _tv.verify_action(pre.screenshot, post.screenshot, _action_type)
|
||||
if not _tv_result['success']:
|
||||
print(f"⚠️ [ORA/titre] {_tv_result['reason']} → retry")
|
||||
# Retry : recliquer
|
||||
time.sleep(0.5)
|
||||
self.act(decision, step)
|
||||
time.sleep(0.3)
|
||||
post = self.observe()
|
||||
self._last_post_phash = post.phash
|
||||
elif _tv_result['changed']:
|
||||
print(f"✅ [ORA/titre] '{_tv_result['title_after'][:40]}'")
|
||||
except Exception as _tv_err:
|
||||
print(f"⚠️ [ORA/titre] Erreur: {_tv_err}")
|
||||
|
||||
# --- 5. Vérifier ---
|
||||
verification = self.verify(pre, post, decision)
|
||||
@@ -1345,10 +1463,112 @@ Règles:
|
||||
# Méthodes privées — actions
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
|
||||
def _handle_post_shortcut(self, pre_obs: 'Observation') -> bool:
|
||||
"""Après un raccourci clavier, résoudre la cascade de dialogues réflexes.
|
||||
|
||||
Pilotage par DialogHandler (OCR direct), PAS par pHash. Raison :
|
||||
un dialog modal qui s'ouvre dans une VM ne change quasiment pas le
|
||||
pHash global de l'écran hôte (signature 8x8 sur 1920x1080 — un dialog
|
||||
de 800x500 couvre ~3 pixels pHash, distance Hamming souvent < 3).
|
||||
On poll donc directement DialogHandler.handle_if_dialog().
|
||||
|
||||
Returns:
|
||||
True si au moins un dialog connu a été détecté + géré et qu'aucun
|
||||
autre dialog n'apparaît dans la fenêtre de stabilité finale.
|
||||
False si aucun dialog connu n'apparaît dans la fenêtre d'attente
|
||||
initiale (le workflow doit ABORT — état incohérent).
|
||||
"""
|
||||
from core.grounding.dialog_handler import DialogHandler
|
||||
|
||||
# Fenêtre d'attente du PREMIER dialog après le raccourci. Win11/QEMU :
|
||||
# Ctrl+Shift+S → "Enregistrer sous" apparaît en <2s typiquement.
|
||||
first_dialog_timeout = 8.0
|
||||
# Budget total pour résoudre toute la cascade (InfiGUI ~15s/dialog).
|
||||
total_timeout = 60.0
|
||||
# Fenêtre de stabilité après le dernier dialog géré : si rien d'autre
|
||||
# n'apparaît pendant cette durée, la cascade est considérée terminée.
|
||||
# Doit couvrir l'apparition du popup modal suivant (post_click_wait + marge).
|
||||
stable_window = 3.0
|
||||
# Délai post-clic avant de tester le dialog suivant.
|
||||
post_click_wait = 1.5
|
||||
# Cadence de polling OCR (EasyOCR full-screen ~500ms/poll).
|
||||
poll_interval = 0.5
|
||||
# Garde-fou anti-boucle infinie.
|
||||
max_dialog_iterations = 5
|
||||
|
||||
t_start = time.time()
|
||||
dh = DialogHandler()
|
||||
dialogs_handled = 0
|
||||
|
||||
def _elapsed() -> float:
|
||||
return time.time() - t_start
|
||||
|
||||
def _poll_dialog(deadline: float) -> Optional[Dict[str, Any]]:
|
||||
"""Poll DialogHandler jusqu'à détection d'un dialog connu OU deadline.
|
||||
|
||||
Retourne le dict result si un dialog connu a été géré (cliqué),
|
||||
None si la deadline est atteinte sans match. Si DialogHandler
|
||||
détecte ET clique avec succès, le clic InfiGUI peut excéder la
|
||||
deadline mais on retourne quand même le résultat (action déjà
|
||||
engagée — on ne va pas l'annuler).
|
||||
"""
|
||||
while time.time() < deadline:
|
||||
obs = self.observe()
|
||||
try:
|
||||
result = dh.handle_if_dialog(obs.screenshot)
|
||||
except Exception as e:
|
||||
print(f"⚠️ [ORA/post-shortcut] Erreur dialog handler: {e}")
|
||||
return None
|
||||
if result.get('handled'):
|
||||
return result
|
||||
sleep_left = deadline - time.time()
|
||||
if sleep_left > 0:
|
||||
time.sleep(min(poll_interval, sleep_left))
|
||||
return None
|
||||
|
||||
# --- Étape 1 : attendre le PREMIER dialog ---
|
||||
first_deadline = t_start + min(total_timeout, first_dialog_timeout)
|
||||
result = _poll_dialog(first_deadline)
|
||||
if result is None:
|
||||
print(f"⏳ [ORA/post-shortcut] Aucun dialog connu détecté après "
|
||||
f"{_elapsed():.1f}s (fenêtre={first_dialog_timeout}s) — "
|
||||
f"raccourci sans effet attendu")
|
||||
return False
|
||||
|
||||
dialogs_handled = 1
|
||||
print(f"✅ [ORA/post-shortcut] Dialog #1 géré: {result.get('action')} "
|
||||
f"({_elapsed():.1f}s)")
|
||||
time.sleep(post_click_wait)
|
||||
|
||||
# --- Étape 2 : cascade — chaque dialog suivant doit apparaître dans stable_window ---
|
||||
for iteration in range(1, max_dialog_iterations):
|
||||
if _elapsed() >= total_timeout:
|
||||
print(f"⏳ [ORA/post-shortcut] Timeout cascade ({total_timeout:.0f}s, "
|
||||
f"{dialogs_handled} dialog(s) géré(s))")
|
||||
return True # au moins un dialog traité → considéré OK
|
||||
|
||||
next_deadline = min(time.time() + stable_window, t_start + total_timeout)
|
||||
result = _poll_dialog(next_deadline)
|
||||
if result is None:
|
||||
# Pas de nouveau dialog dans stable_window → cascade terminée
|
||||
print(f"✅ [ORA/post-shortcut] Cascade résolue "
|
||||
f"({dialogs_handled} dialog(s), {_elapsed():.1f}s)")
|
||||
return True
|
||||
|
||||
dialogs_handled += 1
|
||||
print(f"✅ [ORA/post-shortcut] Dialog #{dialogs_handled} géré: "
|
||||
f"{result.get('action')} ({_elapsed():.1f}s)")
|
||||
time.sleep(post_click_wait)
|
||||
|
||||
print(f"⚠️ [ORA/post-shortcut] Trop d'itérations cascade "
|
||||
f"({max_dialog_iterations}) — cascade malformée, on s'arrête là")
|
||||
return dialogs_handled > 0
|
||||
|
||||
def _act_click(self, decision: Decision, step_params: dict) -> bool:
|
||||
"""Exécute un clic (simple, double, droit, hover, focus).
|
||||
|
||||
Pipeline : template matching → find_element_on_screen (OCR → UI-TARS → VLM).
|
||||
Pipeline FAST→SMART→THINK (si activé) ou ancien pipeline en fallback.
|
||||
Activé par la variable d'environnement RPA_USE_FAST_PIPELINE=1.
|
||||
"""
|
||||
if not PYAUTOGUI_AVAILABLE:
|
||||
logger.error("pyautogui non disponible")
|
||||
@@ -1357,29 +1577,23 @@ Règles:
|
||||
anchor = step_params.get('visual_anchor', {})
|
||||
screenshot_b64 = anchor.get('screenshot')
|
||||
bbox = anchor.get('bounding_box', {})
|
||||
target_text = anchor.get('target_text', '') or decision.target
|
||||
# Utiliser le target nettoyé par reason_workflow_step (pas relire le garbage de l'ancre)
|
||||
target_text = decision.target
|
||||
target_desc = anchor.get('description', '')
|
||||
|
||||
print(f"🎯 [ORA/_act_click] target='{target_text}', desc='{target_desc[:40]}', bbox={bbox.get('x','?')},{bbox.get('y','?')}")
|
||||
|
||||
x, y = None, None
|
||||
method_used = ''
|
||||
# Score et position du template-first (réutilisés en fallback intermédiaire)
|
||||
template_score = 0.0
|
||||
template_xy: Optional[tuple] = None
|
||||
|
||||
# --- Méthode 1 : UI-TARS grounding (~3s, 94% précision) ---
|
||||
# Le plus fiable : on dit "click on X" et UI-TARS trouve les coordonnées
|
||||
if target_text or target_desc:
|
||||
try:
|
||||
from core.execution.input_handler import _grounding_ui_tars
|
||||
click_label = target_desc or target_text
|
||||
print(f"🎯 [ORA/UI-TARS] Recherche: '{click_label}'")
|
||||
result = _grounding_ui_tars(target_text, target_desc)
|
||||
if result:
|
||||
                x, y = result['x'], result['y']
                method_used = 'ui_tars'
                print(f"✅ [ORA/UI-TARS] Trouvé à ({x}, {y})")
            except Exception as e:
                logger.debug(f"⚠️ [ORA/UI-TARS] Erreur: {e}")

        # --- Method 2: template matching (~80 ms) ---
        if x is None and screenshot_b64 and CV2_AVAILABLE and PIL_AVAILABLE and MSS_AVAILABLE:
        # --- OUTPOST: cv2 template matching on the anchor crop ---
        # If the UI has not changed (the dominant case during replay), a pixel-perfect
        # match gives us the click in ~50 ms without touching the GPU. The VLM
        # pipeline is only triggered when the score is insufficient.
        if screenshot_b64 and CV2_AVAILABLE and PIL_AVAILABLE and MSS_AVAILABLE:
            try:
                import io as _io
                with mss_lib.mss() as sct:
@@ -1399,15 +1613,70 @@ Règles:
                    result_tm = cv2.matchTemplate(screen_cv, anchor_cv, cv2.TM_CCOEFF_NORMED)
                    _, max_val, _, max_loc = cv2.minMaxLoc(result_tm)
                    elapsed_ms = (time.time() - t0) * 1000
                    print(f"⚡ [ORA/template] score={max_val:.3f} pos={max_loc} ({elapsed_ms:.0f}ms)")
                    if max_val > 0.75:
                        x = max_loc[0] + anchor_cv.shape[1] // 2
                        y = max_loc[1] + anchor_cv.shape[0] // 2
                        method_used = 'template'
                    template_score = float(max_val)
                    template_xy = (
                        max_loc[0] + anchor_cv.shape[1] // 2,
                        max_loc[1] + anchor_cv.shape[0] // 2,
                    )
                    print(f"⚡ [ORA/template-first] score={template_score:.3f} pos={max_loc} ({elapsed_ms:.0f}ms)")
                    # High threshold for "direct" mode: we want to be near-certain
                    # it is the same element, pixel-perfect, before skipping the VLM.
                    if template_score >= 0.95:
                        x, y = template_xy
                        method_used = 'template_direct'
                        print(f"✅ [ORA/template-first] Match direct → ({x}, {y}), skip pipeline")
            except Exception as e:
                logger.debug(f"⚠️ [ORA/template] Erreur: {e}")
                print(f"⚠️ [ORA/template-first] Erreur: {e}")

        # --- FAST→SMART→THINK pipeline (escalation when template-first was not decisive) ---
        _use_fast = os.environ.get('RPA_USE_FAST_PIPELINE', '1') == '1'

        if x is None and _use_fast and (target_text or target_desc or screenshot_b64):
            print(f"🎯 [ORA/_act_click] RPA_USE_FAST_PIPELINE={_use_fast}, has_target={bool(target_text or target_desc)}, template_score={template_score:.3f}")
            try:
                from core.grounding.fast_pipeline import FastSmartThinkPipeline
                from core.grounding.target import GroundingTarget

                _pipeline = FastSmartThinkPipeline.get_instance()

                # Single screen capture
                _screen_pil = None
                if MSS_AVAILABLE and PIL_AVAILABLE:
                    with mss_lib.mss() as _sct:
                        _mon = _sct.monitors[0]
                        _grab = _sct.grab(_mon)
                        _screen_pil = Image.frombytes('RGB', _grab.size, _grab.bgra, 'raw', 'BGRX')

                _target = GroundingTarget(
                    text=target_text,
                    description=target_desc,
                    template_b64=screenshot_b64 or "",
                    original_bbox=bbox if bbox else None,
                )

                _result = _pipeline.locate(
                    _target,
                    screenshot_pil=_screen_pil,
                    window_title=getattr(self, '_last_window_title', ''),
                )

                if _result:
                    x, y = _result.x, _result.y
                    method_used = _result.method
                    print(f"🎯 [ORA/pipeline] ({x}, {y}) via {method_used} "
                          f"conf={_result.confidence:.3f} ({_result.time_ms:.0f}ms)")

            except Exception as e:
                print(f"⚠️ [ORA/pipeline] Erreur: {e}")

        # --- Fallback: reuse the template-first score when it is relevant ---
        # If the VLM pipeline failed but template-first had an intermediate
        # score (0.75-0.95), accept that match as a backup.
        if x is None and template_xy is not None and template_score >= 0.75:
            x, y = template_xy
            method_used = 'template_fallback'
            print(f"⚡ [ORA/template-fallback] Réutilisation score={template_score:.3f} → ({x}, {y})")

        # --- Method 3: text OCR (~1 s) ---
        if x is None and target_text:
            try:
                from core.execution.input_handler import _grounding_ocr
@@ -1417,22 +1686,21 @@ Règles:
                method_used = 'ocr'
                print(f"🔍 [ORA/OCR] Trouvé à ({x}, {y})")
            except Exception as e:
                logger.debug(f"⚠️ [ORA/OCR] Erreur: {e}")
                print(f"⚠️ [ORA/OCR] Erreur: {e}")

        # --- Execute the click ---
        # --- Last resort: static coordinates ---
        if x is None:
            # Last resort: the anchor's static coordinates
            if bbox and bbox.get('width') and bbox.get('height'):
                x = int(bbox.get('x', 0) + bbox.get('width', 0) / 2)
                y = int(bbox.get('y', 0) + bbox.get('height', 0) / 2)
                method_used = 'static_fallback'
                logger.warning(f"⚠️ [ORA/click] Fallback coordonnées statiques: ({x}, {y})")
                print(f"⚠️ [ORA/click] Fallback coordonnées statiques: ({x}, {y})")
            else:
                logger.error(f"❌ [ORA/click] Impossible de localiser '{target_text}' — aucune méthode n'a fonctionné")
                print(f"❌ [ORA/click] Impossible de localiser '{target_text}'")
                return False

        # --- Pre-action check: is this the right element? ---
        if target_text and method_used not in ('template',) and MSS_AVAILABLE and PIL_AVAILABLE:
        # --- No VLM pre-check (the FAST→SMART→THINK pipeline has already validated) ---
        if False:
            try:
                pre_check = self._verify_pre_click(x, y, target_text, target_desc)
                if not pre_check:
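The hunk above chains four tiers with two distinct template thresholds. A condensed restatement of that decision ladder, as a self-contained sketch (names are illustrative, not repo code):

    def choose_method(template_score: float, pipeline_hit, ocr_hit, bbox):
        # template-first: a near-certain pixel match skips everything else
        if template_score >= 0.95:
            return 'template_direct'
        # then the FAST→SMART→THINK pipeline result, if it found something
        if pipeline_hit is not None:
            return pipeline_hit.method
        # an intermediate template score (0.75-0.95) is accepted as backup
        if template_score >= 0.75:
            return 'template_fallback'
        if ocr_hit is not None:
            return 'ocr'
        return 'static_fallback' if bbox else None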
20
core/grounding/__init__.py
Normal file
@@ -0,0 +1,20 @@
# core/grounding — UI element localization module
#
# Centralizes the visual grounding methods: template matching,
# OCR, VLM, etc. Each method produces a uniform GroundingResult.
#
# The grounding server (server.py) runs in a separate process
# on port 8200. The HTTP client (UITarsGrounder) calls it over HTTP.
# The pipeline (GroundingPipeline) orchestrates template → OCR → UI-TARS → static.

from core.grounding.template_matcher import TemplateMatcher, MatchResult
from core.grounding.target import GroundingTarget, GroundingResult
from core.grounding.ui_tars_grounder import UITarsGrounder
from core.grounding.pipeline import GroundingPipeline

__all__ = [
    'TemplateMatcher', 'MatchResult',
    'GroundingTarget', 'GroundingResult',
    'UITarsGrounder',
    'GroundingPipeline',
]
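Of these exports, UITarsGrounder is the piece the other new modules lean on. A minimal client-side sketch, using only calls that appear later in this diff (the screenshot argument is optional, as in pipeline.py's `_try_uitars`):

    from core.grounding import UITarsGrounder

    grounder = UITarsGrounder.get_instance()  # singleton HTTP client for the port-8200 server
    result = grounder.ground(target_text="Enregistrer",
                             target_description="bouton en bas du dialogue")
    if result and result.x is not None:
        print(f"({result.x}, {result.y}) conf={result.confidence:.2f} ({result.time_ms:.0f}ms)")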
256
core/grounding/dialog_handler.py
Normal file
@@ -0,0 +1,256 @@
"""
core/grounding/dialog_handler.py — Intelligent dialog handling

When an unexpected dialog appears (the pHash changes after an action):
1. Read the window text (full-screen EasyOCR, ~500 ms)
2. If the title is known (Save As, Confirm, etc.) → known action
3. Ask InfiGUI to click the right button (~3 s)
4. Verify that the dialog is gone (pHash)

No predefined button patterns. InfiGUI understands the dialog
visually and clicks in the right place.

Usage:
    from core.grounding.dialog_handler import DialogHandler

    handler = DialogHandler()
    result = handler.handle_if_dialog(screenshot_pil)
    if result['handled']:
        print(f"Dialog '{result['title']}' handled → {result['action']}")
"""

from __future__ import annotations

import time
from typing import Any, Dict, Optional


# Known titles → which action to request from InfiGUI.
#
# IMPORTANT — dict order = matching priority.
# The OCR is full-screen and often picks up text from BOTH the parent dialog
# and the modal popup that appears on top of it (e.g. "Enregistrer sous" stays
# visible behind "Confirmer l'enregistrement"). Modal popups MUST match before
# the main windows, otherwise Léa clicks the parent's button, which does not
# have focus.
KNOWN_DIALOGS = {
    # ── Modal confirmation popups (HIGH priority) ──────────────────────
    "voulez-vous le remplacer": {"target": "Oui", "description": "Clique sur Oui pour confirmer le remplacement du fichier"},
    "do you want to replace": {"target": "Yes", "description": "Click Yes to confirm file replacement"},
    "existe déjà": {"target": "Oui", "description": "Clique sur Oui, le fichier existe déjà et doit être remplacé"},
    "already exists": {"target": "Yes", "description": "Click Yes, the file already exists"},
    "remplacer": {"target": "Oui", "description": "Clique sur le bouton Oui pour confirmer le remplacement du fichier"},
    "replace": {"target": "Yes", "description": "Click Yes to confirm file replacement"},
    "écraser": {"target": "Oui", "description": "Clique sur Oui pour écraser le fichier"},
    "overwrite": {"target": "Yes", "description": "Click Yes to overwrite"},
    "confirmer l'enregistrement": {"target": "Oui", "description": "Clique sur Oui dans le popup de confirmation d'enregistrement"},
    "confirmer": {"target": "Oui", "description": "Clique sur le bouton Oui dans le dialogue de confirmation"},
    # ── Warnings/errors (high priority, single OK button) ──────────────
    "erreur": {"target": "OK", "description": "Clique sur OK pour fermer le message d'erreur"},
    "error": {"target": "OK", "description": "Click OK to close the error message"},
    "avertissement": {"target": "OK", "description": "Clique sur OK pour fermer l'avertissement"},
    "warning": {"target": "OK", "description": "Click OK to close the warning"},
    # ── Main save dialogs (LOW priority — parent windows) ──────────────
    "voulez-vous enregistrer": {"target": "Enregistrer", "description": "Clique sur Enregistrer pour sauvegarder les modifications"},
    "do you want to save": {"target": "Save", "description": "Click Save to save changes"},
    "enregistrer sous": {"target": "Enregistrer", "description": "Clique sur le bouton Enregistrer dans le dialogue Enregistrer sous"},
    "save as": {"target": "Save", "description": "Click the Save button in the Save As dialog"},
}


class DialogHandler:
    """Intelligent dialog handling via title + InfiGUI."""

    def __init__(self):
        self._easyocr_reader = None

    def handle_if_dialog(
        self,
        screenshot_pil,
        previous_title: str = "",
    ) -> Dict[str, Any]:
        """Check whether the screen shows a dialog and handle it.

        Args:
            screenshot_pil: Current PIL screenshot.
            previous_title: Window title before the action (for comparison).

        Returns:
            Dict with 'handled' (bool), 'title', 'action', 'position'.
        """
        t0 = time.time()

        # 1. Read the window title
        title = self._read_title(screenshot_pil)
        if not title or len(title) < 3:
            return {'handled': False, 'title': '', 'reason': 'Titre illisible'}

        print(f"🔍 [Dialog] Titre lu: '{title}'")

        # 2. Check whether this is a known dialog
        matched_dialog = None
        for key, action_info in KNOWN_DIALOGS.items():
            if key in title.lower():
                matched_dialog = (key, action_info)
                break

        if not matched_dialog:
            # Not a known dialog — the workflow continues normally
            return {'handled': False, 'title': title, 'reason': 'Pas un dialogue connu'}

        dialog_key, action_info = matched_dialog
        target = action_info['target']
        description = action_info['description']

        print(f"🧠 [Dialog] Dialogue détecté: '{dialog_key}' → clic '{target}'")

        # 3. Ask InfiGUI to click the button
        click_result = self._click_via_infigui(
            target, description, screenshot_pil
        )

        dt = (time.time() - t0) * 1000

        if click_result:
            print(f"✅ [Dialog] Clic '{target}' à ({click_result['x']}, {click_result['y']}) ({dt:.0f}ms)")
            return {
                'handled': True,
                'title': title,
                'dialog_type': dialog_key,
                'action': f"click '{target}'",
                'position': (click_result['x'], click_result['y']),
                'time_ms': dt,
            }
        else:
            # InfiGUI could not find the button — try a direct OCR click
            print(f"⚠️ [Dialog] InfiGUI n'a pas trouvé '{target}', essai OCR direct")
            ocr_result = self._click_via_ocr(target, screenshot_pil)
            dt = (time.time() - t0) * 1000

            if ocr_result:
                print(f"✅ [Dialog] OCR clic '{target}' à ({ocr_result[0]}, {ocr_result[1]}) ({dt:.0f}ms)")
                return {
                    'handled': True,
                    'title': title,
                    'dialog_type': dialog_key,
                    'action': f"click '{target}' (OCR)",
                    'position': ocr_result,
                    'time_ms': dt,
                }

            print(f"❌ [Dialog] Impossible de cliquer '{target}' ({dt:.0f}ms)")
            return {
                'handled': False,
                'title': title,
                'dialog_type': dialog_key,
                'reason': f"Bouton '{target}' introuvable",
                'time_ms': dt,
            }

    # ------------------------------------------------------------------
    # Title reading
    # ------------------------------------------------------------------

    def _read_title(self, screenshot_pil) -> str:
        """Read ALL visible text via full-screen EasyOCR (~500 ms).

        In a QEMU VM the Windows title bar sits inside the framebuffer,
        not at the absolute top of the screen. We OCR the full screen and
        look for the known dialogs' keywords in the complete text.
        """
        try:
            import numpy as np

            reader = self._get_easyocr()
            if reader is None:
                return ""

            results = reader.readtext(np.array(screenshot_pil))
            full_text = ' '.join(r[1] for r in results if r[1].strip())
            return full_text

        except Exception as e:
            print(f"⚠️ [Dialog] Erreur lecture écran: {e}")
            return ""

    # ------------------------------------------------------------------
    # Click via InfiGUI (grounding server)
    # ------------------------------------------------------------------

    def _click_via_infigui(
        self, target: str, description: str, screenshot_pil
    ) -> Optional[Dict]:
        """Ask InfiGUI (one-shot subprocess) to locate and click the button."""
        try:
            from core.grounding.ui_tars_grounder import UITarsGrounder

            grounder = UITarsGrounder.get_instance()
            result = grounder.ground(
                target_text=target,
                target_description=description,
                screen_pil=screenshot_pil,
            )

            if result and result.x is not None:
                import pyautogui
                pyautogui.click(result.x, result.y)
                return {'x': result.x, 'y': result.y}

            return None

        except Exception as e:
            print(f"⚠️ [Dialog/InfiGUI] Erreur: {e}")
            return None

    # ------------------------------------------------------------------
    # Click via OCR (fast fallback)
    # ------------------------------------------------------------------

    def _click_via_ocr(self, target: str, screenshot_pil) -> Optional[tuple]:
        """Find the button via OCR and click it."""
        try:
            import numpy as np

            reader = self._get_easyocr()
            if reader is None:
                return None

            results = reader.readtext(np.array(screenshot_pil))

            target_lower = target.lower()
            matches = []
            for (bbox_pts, text, conf) in results:
                if target_lower in text.lower() or text.lower() in target_lower:
                    x = int(sum(p[0] for p in bbox_pts) / 4)
                    y = int(sum(p[1] for p in bbox_pts) / 4)
                    matches.append((x, y, text))

            if matches:
                # Take the lowest match (buttons sit at the bottom of the dialog)
                best = max(matches, key=lambda m: m[1])
                import pyautogui
                pyautogui.click(best[0], best[1])
                return (best[0], best[1])

            return None

        except Exception as e:
            print(f"⚠️ [Dialog/OCR] Erreur: {e}")
            return None

    # ------------------------------------------------------------------
    # EasyOCR singleton
    # ------------------------------------------------------------------

    def _get_easyocr(self):
        if self._easyocr_reader is not None:
            return self._easyocr_reader

        try:
            import easyocr
            self._easyocr_reader = easyocr.Reader(
                ['fr', 'en'], gpu=True, verbose=False
            )
            return self._easyocr_reader
        except ImportError:
            return None
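The module docstring describes the trigger (a pHash change after an action) but not the call site. A sketch of that wiring under stated assumptions: `grab_screen` and the two pHash values are hypothetical placeholders, not names from this repo.

    from core.grounding.dialog_handler import DialogHandler

    handler = DialogHandler()

    def after_action(phash_before: str, phash_after: str) -> None:
        if phash_after == phash_before:
            return  # screen unchanged: no dialog appeared
        result = handler.handle_if_dialog(grab_screen())  # grab_screen(): hypothetical PIL capture
        if result['handled']:
            print(f"dialog '{result['dialog_type']}' dismissed via {result['action']}")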
239
core/grounding/element_signature.py
Normal file
@@ -0,0 +1,239 @@
"""
core/grounding/element_signature.py — Learned UI element signatures

Every successfully clicked element enriches its signature:
- OCR text, type, relative position, contextual neighbors
- success/failure counts, average confidence
- observed variants (resolutions, positions)

Signatures are stored in SQLite for fast lookup.
Same pattern as TargetMemoryStore (validated in production).

Usage:
    from core.grounding.element_signature import SignatureStore

    store = SignatureStore()

    # After a successful click
    store.record_success("btn_valider", "notepad_1920x1080", element, confidence=0.92)

    # At replay time
    sig = store.lookup("btn_valider", "notepad_1920x1080")
    if sig:
        print(f"Known signature: {sig['text']} position={sig['relative_position']}")
"""

from __future__ import annotations

import hashlib
import json
import os
import sqlite3
import threading
import time
from typing import Any, Dict, List, Optional

from core.grounding.fast_types import DetectedUIElement


# Default DB path
_DEFAULT_DB = os.path.join(
    os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
    "data", "learning", "element_signatures.db",
)


class SignatureStore:
    """SQLite storage for learned UI element signatures."""

    def __init__(self, db_path: str = _DEFAULT_DB):
        self.db_path = db_path
        self._lock = threading.Lock()
        self._ensure_db()

    def _ensure_db(self):
        """Create the DB and table if needed."""
        os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
        with sqlite3.connect(self.db_path) as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS signatures (
                    target_key TEXT NOT NULL,
                    screen_context TEXT NOT NULL,
                    text TEXT DEFAULT '',
                    element_type TEXT DEFAULT 'element',
                    relative_position TEXT DEFAULT '',
                    neighbors TEXT DEFAULT '[]',
                    success_count INTEGER DEFAULT 0,
                    fail_count INTEGER DEFAULT 0,
                    avg_confidence REAL DEFAULT 0.0,
                    last_seen TEXT DEFAULT '',
                    variants TEXT DEFAULT '[]',
                    PRIMARY KEY (target_key, screen_context)
                )
            """)
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_target_key
                ON signatures(target_key)
            """)

    # ------------------------------------------------------------------
    # Lookup
    # ------------------------------------------------------------------

    def lookup(self, target_key: str, screen_context: str = "") -> Optional[Dict[str, Any]]:
        """Look up a known signature.

        Args:
            target_key: Unique key for the target (hash of text + description).
            screen_context: Screen context (hash of window title + resolution).

        Returns:
            Dict with the signature fields, or None.
        """
        with sqlite3.connect(self.db_path) as conn:
            conn.row_factory = sqlite3.Row
            # Try the exact context first
            row = conn.execute(
                "SELECT * FROM signatures WHERE target_key = ? AND screen_context = ?",
                (target_key, screen_context),
            ).fetchone()

            # Fallback: search without the context (all variants)
            if row is None and screen_context:
                row = conn.execute(
                    "SELECT * FROM signatures WHERE target_key = ? ORDER BY success_count DESC LIMIT 1",
                    (target_key,),
                ).fetchone()

            if row is None:
                return None

            return {
                "target_key": row["target_key"],
                "screen_context": row["screen_context"],
                "text": row["text"],
                "element_type": row["element_type"],
                "relative_position": row["relative_position"],
                "neighbors": json.loads(row["neighbors"]),
                "success_count": row["success_count"],
                "fail_count": row["fail_count"],
                "avg_confidence": row["avg_confidence"],
                "last_seen": row["last_seen"],
                "variants": json.loads(row["variants"]),
            }

    # ------------------------------------------------------------------
    # Recording
    # ------------------------------------------------------------------

    def record_success(
        self,
        target_key: str,
        screen_context: str,
        element: DetectedUIElement,
        confidence: float,
    ):
        """Record a success — create or enrich the signature."""
        with self._lock:
            existing = self.lookup(target_key, screen_context)
            now = time.strftime("%Y-%m-%dT%H:%M:%S")

            if existing:
                # Enrich the existing signature
                n = existing["success_count"]
                new_avg = (existing["avg_confidence"] * n + confidence) / (n + 1)

                # Add the variant if the position differs
                variants = existing["variants"]
                variant = {
                    "position": element.relative_position,
                    "center": list(element.center),
                    "confidence": confidence,
                    "timestamp": now,
                }
                variants.append(variant)
                # Keep at most the 20 latest variants
                variants = variants[-20:]

                # Update the neighbors (union)
                neighbors = list(set(existing["neighbors"] + element.neighbors))[:10]

                with sqlite3.connect(self.db_path) as conn:
                    conn.execute("""
                        UPDATE signatures SET
                            success_count = success_count + 1,
                            avg_confidence = ?,
                            last_seen = ?,
                            neighbors = ?,
                            variants = ?,
                            relative_position = ?
                        WHERE target_key = ? AND screen_context = ?
                    """, (
                        new_avg, now,
                        json.dumps(neighbors),
                        json.dumps(variants),
                        element.relative_position,
                        target_key, screen_context,
                    ))
            else:
                # Create a new signature
                with sqlite3.connect(self.db_path) as conn:
                    conn.execute("""
                        INSERT INTO signatures
                        (target_key, screen_context, text, element_type, relative_position,
                         neighbors, success_count, fail_count, avg_confidence, last_seen, variants)
                        VALUES (?, ?, ?, ?, ?, ?, 1, 0, ?, ?, ?)
                    """, (
                        target_key, screen_context,
                        element.ocr_text,
                        element.element_type,
                        element.relative_position,
                        json.dumps(element.neighbors[:10]),
                        confidence, now,
                        json.dumps([{
                            "position": element.relative_position,
                            "center": list(element.center),
                            "confidence": confidence,
                            "timestamp": now,
                        }]),
                    ))

            print(f"📝 [Signature] '{target_key}' {'enrichie' if existing else 'créée'} "
                  f"(conf={confidence:.2f}, ctx='{screen_context[:30]}')")

    def record_failure(self, target_key: str, screen_context: str):
        """Record a failure for a signature."""
        with self._lock:
            with sqlite3.connect(self.db_path) as conn:
                conn.execute("""
                    UPDATE signatures SET fail_count = fail_count + 1, last_seen = ?
                    WHERE target_key = ? AND screen_context = ?
                """, (time.strftime("%Y-%m-%dT%H:%M:%S"), target_key, screen_context))

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------

    @staticmethod
    def make_target_key(text: str, description: str = "") -> str:
        """Generate a unique key for a target."""
        raw = f"{text.lower().strip()}|{description.lower().strip()}"
        return hashlib.md5(raw.encode()).hexdigest()[:16]

    @staticmethod
    def make_screen_context(window_title: str, resolution: tuple = (0, 0)) -> str:
        """Generate a screen context."""
        raw = f"{window_title.lower().strip()}|{resolution[0]}x{resolution[1]}"
        return hashlib.md5(raw.encode()).hexdigest()[:12]

    def get_stats(self) -> Dict[str, Any]:
        """Statistics for the signature database."""
        with sqlite3.connect(self.db_path) as conn:
            total = conn.execute("SELECT COUNT(*) FROM signatures").fetchone()[0]
            reliable = conn.execute(
                "SELECT COUNT(*) FROM signatures WHERE success_count >= 3 AND fail_count = 0"
            ).fetchone()[0]
            return {
                "total_signatures": total,
                "reliable": reliable,
                "db_path": self.db_path,
            }
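A round-trip sketch of the store using the hashing helpers above (the DetectedUIElement literal is illustrative; its fields come from fast_types.py):

    from core.grounding.element_signature import SignatureStore
    from core.grounding.fast_types import DetectedUIElement

    store = SignatureStore()
    key = SignatureStore.make_target_key("Valider", "bouton vert en bas")  # 16-char md5 prefix
    ctx = SignatureStore.make_screen_context("notepad", (1920, 1080))      # 12-char md5 prefix

    elem = DetectedUIElement(id=1, bbox=(100, 200, 180, 230), center=(140, 215),
                             confidence=0.88, ocr_text="Valider")
    store.record_success(key, ctx, elem, confidence=0.92)

    sig = store.lookup(key, ctx)
    assert sig is not None and sig["success_count"] >= 1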
326
core/grounding/fast_detector.py
Normal file
@@ -0,0 +1,326 @@
"""
core/grounding/fast_detector.py — FAST layer: quick UI element detection

Captures the screen, detects all UI elements via RF-DETR (~120 ms),
and enriches each element with OCR text and spatial context.

Produces a ScreenSnapshot usable by the SmartMatcher.

Usage:
    from core.grounding.fast_detector import FastDetector

    detector = FastDetector()
    snapshot = detector.detect()
    print(f"{len(snapshot.elements)} elements in {snapshot.total_time_ms:.0f}ms")
"""

from __future__ import annotations

import math
import time
from typing import Any, Dict, List, Optional, Tuple

from core.grounding.fast_types import DetectedUIElement, ScreenSnapshot


class FastDetector:
    """Fast detection of all UI elements visible on screen.

    Combines RF-DETR (bbox detection) + docTR (OCR) to produce
    an enriched ScreenSnapshot.

    The RF-DETR model is a singleton loaded on first call (~1 s);
    subsequent calls are fast (~120 ms).
    """

    def __init__(self, detection_threshold: float = 0.30):
        self.detection_threshold = detection_threshold
        self._last_snapshot: Optional[ScreenSnapshot] = None
        self._last_phash: str = ""

    def detect(
        self,
        screenshot_pil: Optional[Any] = None,
        phash: str = "",
        window_title: str = "",
    ) -> ScreenSnapshot:
        """Detect and enrich all UI elements on the screen.

        Args:
            screenshot_pil: PIL image. If None, captures via mss.
            phash: Perceptual hash for caching. If identical to the last one, reuse the cache.
            window_title: Title of the active window.

        Returns:
            ScreenSnapshot with all enriched elements.
        """
        t0 = time.time()

        # Cache: same screen → same result
        if phash and phash == self._last_phash and self._last_snapshot is not None:
            print(f"⚡ [FAST] Cache hit (pHash identique)")
            return self._last_snapshot

        # Capture if not provided
        if screenshot_pil is None:
            screenshot_pil = self._capture_screen()
            if screenshot_pil is None:
                return ScreenSnapshot(elements=[], ocr_words=[], resolution=(0, 0))

        w, h = screenshot_pil.size

        # --- RF-DETR detection (~120 ms) ---
        t_det = time.time()
        raw_elements = self._detect_rfdetr(screenshot_pil)
        detection_ms = (time.time() - t_det) * 1000

        # --- OCR over the full screenshot (words are assigned to element crops below) ---
        t_ocr = time.time()
        ocr_words = self._ocr_extract(screenshot_pil)
        ocr_ms = (time.time() - t_ocr) * 1000

        # --- Enrichment: assign text + neighbors + position ---
        enriched = self._enrich_elements(raw_elements, ocr_words, w, h)

        total_ms = (time.time() - t0) * 1000

        snapshot = ScreenSnapshot(
            elements=enriched,
            ocr_words=ocr_words,
            resolution=(w, h),
            window_title=window_title,
            phash=phash,
            detection_time_ms=detection_ms,
            ocr_time_ms=ocr_ms,
            total_time_ms=total_ms,
        )

        # Cache it
        if phash:
            self._last_phash = phash
            self._last_snapshot = snapshot

        print(f"⚡ [FAST] {len(enriched)} éléments détectés en {total_ms:.0f}ms "
              f"(det={detection_ms:.0f}ms, ocr={ocr_ms:.0f}ms)")

        return snapshot

    # ------------------------------------------------------------------
    # RF-DETR detection
    # ------------------------------------------------------------------

    def _detect_rfdetr(self, image) -> List[DetectedUIElement]:
        """Detect elements via RF-DETR (reuses the existing singleton)."""
        try:
            import sys
            sys.path.insert(0, 'visual_workflow_builder/backend')
            from services.ui_detection_service import detect_ui_elements

            result = detect_ui_elements(image, threshold=self.detection_threshold)

            elements = []
            for e in result.elements:
                x1 = e.bbox["x1"]
                y1 = e.bbox["y1"]
                x2 = e.bbox["x2"]
                y2 = e.bbox["y2"]
                elements.append(DetectedUIElement(
                    id=e.id,
                    bbox=(x1, y1, x2, y2),
                    center=(e.center["x"], e.center["y"]),
                    confidence=e.confidence,
                ))

            return elements

        except Exception as ex:
            print(f"⚠️ [FAST/detect] RF-DETR erreur: {ex}")
            return []

    # ------------------------------------------------------------------
    # OCR
    # ------------------------------------------------------------------

    _easyocr_reader = None  # EasyOCR singleton (loaded once)

    def _ocr_extract(self, image) -> List[Dict[str, Any]]:
        """Extract the visible words via EasyOCR (GPU, ~500 ms).

        Falls back to docTR if EasyOCR is unavailable.
        """
        try:
            import numpy as np
            import easyocr

            # Singleton: load the reader only once
            if FastDetector._easyocr_reader is None:
                print(f"🔍 [FAST/ocr] Chargement EasyOCR (GPU)...")
                FastDetector._easyocr_reader = easyocr.Reader(
                    ['fr', 'en'], gpu=True, verbose=False
                )

            results = FastDetector._easyocr_reader.readtext(np.array(image))

            words = []
            for (bbox_pts, text, conf) in results:
                if not text or len(text.strip()) < 1:
                    continue
                # bbox_pts = [[x1,y1],[x2,y1],[x2,y2],[x1,y2]]
                x1 = int(min(p[0] for p in bbox_pts))
                y1 = int(min(p[1] for p in bbox_pts))
                x2 = int(max(p[0] for p in bbox_pts))
                y2 = int(max(p[1] for p in bbox_pts))
                words.append({
                    'text': text.strip(),
                    'bbox': [x1, y1, x2, y2],
                    'confidence': float(conf),
                })

            return words

        except ImportError:
            # docTR fallback
            try:
                import sys
                sys.path.insert(0, 'visual_workflow_builder/backend')
                from services.ocr_service import ocr_extract_words
                return ocr_extract_words(image) or []
            except Exception:
                return []
        except Exception as ex:
            print(f"⚠️ [FAST/ocr] EasyOCR erreur: {ex}")
            return []

    # ------------------------------------------------------------------
    # Enrichment
    # ------------------------------------------------------------------

    def _enrich_elements(
        self,
        elements: List[DetectedUIElement],
        ocr_words: List[Dict[str, Any]],
        screen_w: int,
        screen_h: int,
    ) -> List[DetectedUIElement]:
        """Enrich each element with OCR text, neighbors, and relative position."""

        for elem in elements:
            # 1. Assign OCR text by bbox intersection
            elem.ocr_text = self._assign_ocr_text(elem, ocr_words)

            # 2. Relative position on screen (3x3 grid)
            elem.relative_position = self._compute_relative_position(
                elem.center, screen_w, screen_h
            )

            # 3. Classify the element type (size + ratio heuristic)
            elem.element_type = self._classify_element_type(elem)

        # 4. Compute the neighbors (text of nearby elements)
        for elem in elements:
            elem.neighbors = self._find_neighbors(elem, elements)

        return elements

    def _assign_ocr_text(
        self,
        elem: DetectedUIElement,
        ocr_words: List[Dict[str, Any]],
    ) -> str:
        """Assign OCR text to an element by geometric intersection."""
        x1, y1, x2, y2 = elem.bbox
        # Widen the bbox by 20% to capture surrounding text
        margin_x = int((x2 - x1) * 0.2)
        margin_y = int((y2 - y1) * 0.2)
        ex1, ey1 = x1 - margin_x, y1 - margin_y
        ex2, ey2 = x2 + margin_x, y2 + margin_y

        texts = []
        for word in ocr_words:
            wb = word.get('bbox', [0, 0, 0, 0])
            if len(wb) < 4:
                continue
            wx1, wy1, wx2, wy2 = wb[0], wb[1], wb[2], wb[3]
            # Intersection?
            if wx1 < ex2 and wx2 > ex1 and wy1 < ey2 and wy2 > ey1:
                text = word.get('text', '').strip()
                if text and len(text) > 1:
                    texts.append(text)

        return ' '.join(texts)

    @staticmethod
    def _compute_relative_position(
        center: Tuple[int, int],
        screen_w: int,
        screen_h: int,
    ) -> str:
        """Compute the relative position on a 3x3 grid."""
        cx, cy = center
        col = "left" if cx < screen_w / 3 else ("right" if cx > 2 * screen_w / 3 else "center")
        row = "top" if cy < screen_h / 3 else ("bottom" if cy > 2 * screen_h / 3 else "middle")
        return f"{row}_{col}"

    @staticmethod
    def _classify_element_type(elem: DetectedUIElement) -> str:
        """Classify the element type with a size/ratio heuristic."""
        w, h = elem.width, elem.height
        if w == 0 or h == 0:
            return "element"
        ratio = w / h
        area = w * h

        # Small square → icon
        if area < 5000 and 0.5 < ratio < 2.0:
            return "icon"
        # Wide and thin → button or field
        if ratio > 3.0 and h < 60:
            return "input"
        if ratio > 2.0 and h < 50:
            return "button"
        # Large block → content area
        if area > 50000:
            return "container"

        return "element"

    @staticmethod
    def _find_neighbors(
        elem: DetectedUIElement,
        all_elements: List[DetectedUIElement],
        max_neighbors: int = 5,
    ) -> List[str]:
        """Find the OCR text of nearby elements (radius 1.5x the diagonal)."""
        diag = math.sqrt(elem.width**2 + elem.height**2)
        radius = max(diag * 1.5, 100)  # at least 100 px

        neighbors = []
        for other in all_elements:
            if other.id == elem.id or not other.ocr_text:
                continue
            dx = other.center[0] - elem.center[0]
            dy = other.center[1] - elem.center[1]
            dist = math.sqrt(dx**2 + dy**2)
            if dist < radius:
                neighbors.append(other.ocr_text)

        return neighbors[:max_neighbors]

    # ------------------------------------------------------------------
    # Screen capture
    # ------------------------------------------------------------------

    @staticmethod
    def _capture_screen():
        """Capture the screen via mss."""
        try:
            import mss
            from PIL import Image

            with mss.mss() as sct:
                mon = sct.monitors[0]
                grab = sct.grab(mon)
                return Image.frombytes('RGB', grab.size, grab.bgra, 'raw', 'BGRX')
        except Exception as ex:
            print(f"⚠️ [FAST/capture] Erreur: {ex}")
            return None
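A short consumer-side sketch of the ScreenSnapshot produced above, filtering on the enriched fields (purely illustrative):

    from core.grounding.fast_detector import FastDetector

    detector = FastDetector(detection_threshold=0.30)
    snapshot = detector.detect(window_title="Sans titre - Bloc-notes")

    candidates = [e for e in snapshot.elements
                  if e.element_type == "button" and "enregistrer" in e.ocr_text.lower()]
    for e in candidates:
        print(e.center, e.relative_position, e.neighbors)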
216
core/grounding/fast_pipeline.py
Normal file
@@ -0,0 +1,216 @@
"""
core/grounding/fast_pipeline.py — FAST → SMART → THINK pipeline

Central orchestrator: detects the elements (FAST), matches them against the
target (SMART), and asks the VLM to decide when the score is too low (THINK).

Confidence thresholds:
    ≥ 0.90     → direct action (FAST/SMART)
    0.60-0.90  → VLM confirms (THINK)
    < 0.60     → VLM searches on its own (THINK)

The old GroundingPipeline is used as a fallback if everything fails.

Usage:
    from core.grounding.fast_pipeline import FastSmartThinkPipeline
    from core.grounding.target import GroundingTarget

    pipeline = FastSmartThinkPipeline()
    result = pipeline.locate(GroundingTarget(text="Valider"))
    if result:
        print(f"({result.x}, {result.y}) via {result.method} en {result.time_ms:.0f}ms")
"""

from __future__ import annotations

import time
import threading
from typing import Optional

from core.grounding.target import GroundingTarget, GroundingResult
from core.grounding.fast_types import LocateResult
from core.grounding.fast_detector import FastDetector
from core.grounding.smart_matcher import SmartMatcher
from core.grounding.think_arbiter import ThinkArbiter
from core.grounding.element_signature import SignatureStore


# Singleton
_instance: Optional[FastSmartThinkPipeline] = None
_instance_lock = threading.Lock()


class FastSmartThinkPipeline:
    """FAST → SMART → THINK pipeline for UI element localization.

    Each locate() call follows the cascade:
    1. FAST: RF-DETR detection + OCR enrichment (~120 ms + 1 s)
    2. SMART: text/type/position/neighbor matching (< 1 ms)
    3. THINK: VLM arbitration when the score is insufficient (~3-5 s)
    4. Fallback: old pipeline if everything fails
    """

    def __init__(
        self,
        confidence_direct: float = 0.90,
        confidence_think: float = 0.60,
        enable_think: bool = True,
        enable_learning: bool = True,
    ):
        self.confidence_direct = confidence_direct
        self.confidence_think = confidence_think
        self.enable_think = enable_think
        self.enable_learning = enable_learning

        self._detector = FastDetector()
        self._matcher = SmartMatcher()
        self._arbiter = ThinkArbiter()
        self._signatures = SignatureStore()
        self._fallback_pipeline = None

    @classmethod
    def get_instance(cls) -> FastSmartThinkPipeline:
        """Return the singleton instance."""
        global _instance
        if _instance is None:
            with _instance_lock:
                if _instance is None:
                    _instance = cls()
        return _instance

    def set_fallback_pipeline(self, pipeline) -> None:
        """Configure the old pipeline as a safety net."""
        self._fallback_pipeline = pipeline

    # ------------------------------------------------------------------
    # Main API
    # ------------------------------------------------------------------

    def locate(
        self,
        target: GroundingTarget,
        screenshot_pil=None,
        phash: str = "",
        window_title: str = "",
    ) -> Optional[GroundingResult]:
        """Locate a UI element via the FAST → SMART → THINK cascade.

        Args:
            target: What we are looking for (text, description, original bbox).
            screenshot_pil: PIL image. If None, captures via mss.
            phash: Perceptual hash for caching.
            window_title: Title of the active window.

        Returns:
            GroundingResult compatible with the existing pipeline, or None.
        """
        t0 = time.time()

        # --- FAST: detect all elements ---
        snapshot = self._detector.detect(
            screenshot_pil=screenshot_pil,
            phash=phash,
            window_title=window_title,
        )

        if not snapshot.elements:
            print(f"⚡ [Pipeline] FAST : aucun élément détecté")
            return self._try_fallback(target)

        # --- Learned-signature lookup ---
        target_key = SignatureStore.make_target_key(
            target.text or "", target.description or ""
        )
        screen_ctx = SignatureStore.make_screen_context(
            window_title, snapshot.resolution
        )
        signature = self._signatures.lookup(target_key, screen_ctx)

        # --- SMART: match against the target ---
        candidate = self._matcher.match(snapshot, target, signature)

        if candidate:
            dt = (time.time() - t0) * 1000

            # Sufficient score → direct action
            if candidate.score >= self.confidence_direct:
                print(f"✅ [Pipeline] FAST→SMART direct : '{candidate.element.ocr_text}' "
                      f"score={candidate.score:.3f} ({candidate.method}) "
                      f"→ ({candidate.element.center[0]}, {candidate.element.center[1]}) "
                      f"en {dt:.0f}ms")

                # Learning
                if self.enable_learning:
                    self._signatures.record_success(
                        target_key, screen_ctx,
                        candidate.element, candidate.score,
                    )

                return GroundingResult(
                    x=candidate.element.center[0],
                    y=candidate.element.center[1],
                    method=f"fast_{candidate.method}",
                    confidence=candidate.score,
                    time_ms=dt,
                )

            # Medium score → ask the VLM to confirm
            if candidate.score >= self.confidence_think and self.enable_think:
                print(f"🤔 [Pipeline] SMART score={candidate.score:.3f} — THINK pour confirmer")
                think_result = self._arbiter.arbitrate(
                    target,
                    candidates=[candidate],
                    screenshot_pil=screenshot_pil,
                )
                dt = (time.time() - t0) * 1000

                if think_result:
                    # VLM confirmed
                    if self.enable_learning:
                        self._signatures.record_success(
                            target_key, screen_ctx,
                            candidate.element, think_result.confidence,
                        )
                    return GroundingResult(
                        x=think_result.x, y=think_result.y,
                        method="smart_think_confirmed",
                        confidence=think_result.confidence,
                        time_ms=dt,
                    )

        # --- THINK: score too low or no candidate → VLM searches on its own ---
        if self.enable_think:
            score_info = f"score={candidate.score:.3f}" if candidate else "aucun candidat"
            print(f"🤔 [Pipeline] {score_info} — THINK recherche complète")
            think_result = self._arbiter.arbitrate(
                target, candidates=[], screenshot_pil=screenshot_pil,
            )
            dt = (time.time() - t0) * 1000

            if think_result:
                return GroundingResult(
                    x=think_result.x, y=think_result.y,
                    method="think_vlm",
                    confidence=think_result.confidence,
                    time_ms=dt,
                )

        # --- Fallback: old pipeline ---
        return self._try_fallback(target)

    # ------------------------------------------------------------------
    # Fallback
    # ------------------------------------------------------------------

    def _try_fallback(self, target: GroundingTarget) -> Optional[GroundingResult]:
        """Try the old pipeline as a last resort."""
        if self._fallback_pipeline is None:
            print(f"❌ [Pipeline] Aucune méthode n'a trouvé '{target.text}'")
            return None

        print(f"⚠️ [Pipeline] Fallback ancien pipeline pour '{target.text}'")
        try:
            return self._fallback_pipeline.locate(target)
        except Exception as ex:
            print(f"⚠️ [Pipeline] Fallback échoué: {ex}")
            return None
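A sketch of wiring the safety net the docstring mentions: the legacy GroundingPipeline hooked behind the new cascade (constructor arguments shown are the defaults):

    from core.grounding.fast_pipeline import FastSmartThinkPipeline
    from core.grounding.pipeline import GroundingPipeline
    from core.grounding.target import GroundingTarget

    pipeline = FastSmartThinkPipeline(confidence_direct=0.90, confidence_think=0.60)
    pipeline.set_fallback_pipeline(GroundingPipeline(template_threshold=0.75))
    result = pipeline.locate(GroundingTarget(text="Valider"))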
81
core/grounding/fast_types.py
Normal file
@@ -0,0 +1,81 @@
"""
core/grounding/fast_types.py — Data structures for the FAST→SMART→THINK pipeline

Used exclusively by the fast localization pipeline.
Compatible with the existing GroundingTarget/GroundingResult via conversion.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple


@dataclass
class DetectedUIElement:
    """UI element detected by the FAST layer (RF-DETR), then enriched by OCR."""
    id: int
    bbox: Tuple[int, int, int, int]   # (x1, y1, x2, y2) absolute pixels
    center: Tuple[int, int]           # (cx, cy)
    confidence: float                 # detector confidence (0-1)
    element_type: str = "element"     # "button", "input", "icon", "text", "element"
    ocr_text: str = ""                # OCR text extracted from the region
    neighbors: List[str] = field(default_factory=list)  # texts of nearby elements
    relative_position: str = ""       # "top_left", "center", "bottom_right", etc.

    @property
    def width(self) -> int:
        return self.bbox[2] - self.bbox[0]

    @property
    def height(self) -> int:
        return self.bbox[3] - self.bbox[1]

    @property
    def area(self) -> int:
        return self.width * self.height


@dataclass
class ScreenSnapshot:
    """Complete screen state at a point in time — output of the FAST layer."""
    elements: List[DetectedUIElement]
    ocr_words: List[Dict[str, Any]]   # raw OCR words [{text, bbox}]
    resolution: Tuple[int, int]       # (width, height)
    window_title: str = ""
    phash: str = ""
    detection_time_ms: float = 0.0
    ocr_time_ms: float = 0.0
    total_time_ms: float = 0.0


@dataclass
class MatchCandidate:
    """SMART matching result for a candidate element."""
    element: DetectedUIElement
    score: float                      # combined score (0-1)
    score_detail: Dict[str, float] = field(default_factory=dict)
    method: str = ""                  # "exact_text", "fuzzy_text", "position", etc.


@dataclass
class LocateResult:
    """Final result of the FAST→SMART→THINK pipeline."""
    x: int
    y: int
    confidence: float
    method: str                       # "fast_exact", "fast_fuzzy", "smart_vote", "think_vlm"
    time_ms: float
    tier: str = "fast"                # "fast", "smart", "think"
    element: Optional[DetectedUIElement] = None
    candidates_count: int = 0

    def to_grounding_result(self):
        """Conversion to GroundingResult for compatibility."""
        from core.grounding.target import GroundingResult
        return GroundingResult(
            x=self.x, y=self.y,
            method=self.method,
            confidence=self.confidence,
            time_ms=self.time_ms,
        )
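A quick illustration of the derived properties and the compatibility bridge (values chosen so the arithmetic is easy to check):

    from core.grounding.fast_types import DetectedUIElement, LocateResult

    e = DetectedUIElement(id=7, bbox=(10, 20, 110, 60), center=(60, 40), confidence=0.81)
    assert (e.width, e.height, e.area) == (100, 40, 4000)  # ratio 2.5, h < 50 → "button" per the heuristic

    loc = LocateResult(x=60, y=40, confidence=0.93, method="fast_exact", time_ms=135.0, element=e)
    legacy = loc.to_grounding_result()  # back to the legacy GroundingResult shape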
210
core/grounding/infigui_worker.py
Normal file
@@ -0,0 +1,210 @@
#!/usr/bin/env python3
"""
InfiGUI worker — independent process.

Loads the model, runs one inference, and returns the result. main() below
implements the one-shot mode (request as JSON on stdin, result as JSON on
stdout); the /tmp/infigui_*.json constants support the file-watching variant
(watch /tmp/infigui_request.json, write /tmp/infigui_response.json).

Launch:
    cd ~/ai/rpa_vision_v3
    .venv/bin/python3 -m core.grounding.infigui_worker
"""

import json
import math
import os
import re
import sys
import time
import gc
import warnings

warnings.filterwarnings("ignore")

import torch

REQUEST_FILE = "/tmp/infigui_request.json"
RESPONSE_FILE = "/tmp/infigui_response.json"
READY_FILE = "/tmp/infigui_ready"


def load_model():
    """Load InfiGUI-G1-3B in 4-bit NF4."""
    torch.cuda.empty_cache()
    gc.collect()

    from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig

    model_id = "InfiX-ai/InfiGUI-G1-3B"
    print(f"[infigui-worker] Chargement {model_id}...")

    bnb = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True,
    )
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id, quantization_config=bnb, device_map={"": "cuda:0"},
    )
    model.eval()
    processor = AutoProcessor.from_pretrained(
        model_id, padding_side="left",
        min_pixels=100 * 28 * 28, max_pixels=5600 * 28 * 28,
    )

    vram = torch.cuda.memory_allocated() / 1e9
    print(f"[infigui-worker] Prêt — VRAM: {vram:.2f}GB")

    # "Ready" signal
    with open(READY_FILE, "w") as f:
        f.write(f"ready {vram:.2f}GB")

    return model, processor


def infer(model, processor, req):
    """Run one inference.

    Modes:
    - text only (target/description): classic grounding
    - fused (anchor_image_path present): the anchor crop is also passed as a
      reference image and the model must find that element again on the
      screenshot. Avoids the double describe→ground pass.
    """
    from PIL import Image
    from qwen_vl_utils import process_vision_info

    target = req.get("target", "")
    description = req.get("description", "")
    label = f"{target} — {description}" if description else target

    # Main image (full screenshot)
    image_path = req.get("image_path", "")
    if image_path and os.path.exists(image_path):
        img = Image.open(image_path).convert("RGB")
    else:
        import mss
        with mss.mss() as sct:
            grab = sct.grab(sct.monitors[0])
            img = Image.frombytes("RGB", grab.size, grab.bgra, "raw", "BGRX")

    # Anchor image (optional) — fused describe+ground mode
    anchor_image_path = req.get("anchor_image_path", "")
    anchor_img = None
    if anchor_image_path and os.path.exists(anchor_image_path):
        anchor_img = Image.open(anchor_image_path).convert("RGB")

    if not label.strip() and anchor_img is None:
        return {"x": None, "y": None, "error": "target ou anchor_image requis"}

    W, H = img.size
    factor = 28
    rH = max(factor, round(H / factor) * factor)
    rW = max(factor, round(W / factor) * factor)

    system = (
        "You FIRST think about the reasoning process as an internal monologue "
        "and then provide the final answer.\n"
        "The reasoning process MUST BE enclosed within <think> </think> tags."
    )

    # Build the prompt according to the mode
    if anchor_img is not None:
        # Fused mode: Image1 = anchor crop, Image2 = screenshot
        hint = f' Hint: this element looks like "{label}".' if label.strip() else ""
        user_text = (
            f"The first image is a small crop of a UI element captured previously. "
            f"The second image is the current screen ({rW}x{rH}).{hint}\n"
            f"Locate on the second image the UI element that visually matches the first image. "
            f"Output the coordinates using JSON format: "
            f'[{{"point_2d": [x, y]}}, ...]'
        )
        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": [
                {"type": "image", "image": anchor_img},
                {"type": "image", "image": img},
                {"type": "text", "text": user_text},
            ]},
        ]
    else:
        # Classic mode: text only
        user_text = (
            f'The screen\'s resolution is {rW}x{rH}.\n'
            f'Locate the UI element(s) for "{label}", '
            f'output the coordinates using JSON format: '
            f'[{{"point_2d": [x, y]}}, ...]'
        )
        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": user_text},
            ]},
        ]

    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text], images=image_inputs, videos=video_inputs,
        padding=True, return_tensors="pt",
    ).to(model.device)

    t0 = time.time()
    with torch.no_grad():
        gen = model.generate(**inputs, max_new_tokens=512)
    infer_ms = (time.time() - t0) * 1000

    trimmed = [o[len(i):] for i, o in zip(inputs.input_ids, gen)]
    raw = processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False,
    )[0].strip()

    mode_str = "fused" if anchor_img is not None else "text"
    print(f"[infigui-worker] [{mode_str}] '{label[:40]}' ({infer_ms:.0f}ms)")

    # Parse the point_2d JSON
    json_part = raw.split("</think>")[-1] if "</think>" in raw else raw
    json_part = json_part.replace("```json", "").replace("```", "").strip()

    px, py = None, None
    try:
        parsed = json.loads(json_part)
        if isinstance(parsed, list) and len(parsed) > 0:
            pt = parsed[0].get("point_2d", [])
            if len(pt) >= 2:
                px = int(pt[0] * W / rW)
                py = int(pt[1] * H / rH)
    except json.JSONDecodeError:
        m = re.search(r'"point_2d"\s*:\s*\[(\d+),\s*(\d+)\]', raw)
        if m:
            px = int(int(m.group(1)) * W / rW)
            py = int(int(m.group(2)) * H / rH)

    return {
        "x": px, "y": py,
        "method": "infigui",
        "confidence": 0.90 if px else 0.0,
        "time_ms": round(infer_ms, 1),
    }


def main():
    """One-shot mode: read a request from stdin, infer, write the result to stdout."""
    # Read the request
    input_data = sys.stdin.read().strip()
    if not input_data:
        print(json.dumps({"x": None, "y": None, "error": "pas de requête"}))
        return

    try:
        req = json.loads(input_data)
    except json.JSONDecodeError:
        print(json.dumps({"x": None, "y": None, "error": "JSON invalide"}))
        return

    model, processor = load_model()
    result = infer(model, processor, req)
    print(json.dumps(result))


if __name__ == "__main__":
    main()
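Given main()'s stdin/stdout contract, a caller-side sketch (one-shot, so the model reloads on every call, which is why UITarsGrounder wraps it elsewhere). Taking the last stdout line is an assumption: load_model() also prints progress to stdout.

    import json
    import subprocess

    req = {"target": "Oui", "description": "bouton de confirmation", "image_path": "/tmp/screen.png"}
    proc = subprocess.run(
        [".venv/bin/python3", "-m", "core.grounding.infigui_worker"],
        input=json.dumps(req), capture_output=True, text=True,
    )
    resp = json.loads(proc.stdout.strip().splitlines()[-1])  # result JSON is the last stdout line
    print(resp["x"], resp["y"], resp["confidence"])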
190
core/grounding/pipeline.py
Normal file
190
core/grounding/pipeline.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""
|
||||
core/grounding/pipeline.py — Pipeline de grounding en cascade
|
||||
|
||||
Orchestre les methodes de localisation dans l'ordre :
|
||||
1. Template matching (TemplateMatcher, local, ~80ms)
|
||||
2. OCR (docTR via input_handler, local, ~1s)
|
||||
3. UI-TARS (HTTP vers serveur grounding, ~3s)
|
||||
4. Static fallback (coordonnees d'origine du workflow)
|
||||
|
||||
Chaque methode est essayee dans l'ordre. Des qu'une reussit, on retourne
|
||||
le resultat. Cela permet un equilibre entre vitesse (template) et robustesse
|
||||
(UI-TARS pour les elements qui ont change de position/apparence).
|
||||
|
||||
Utilisation :
|
||||
from core.grounding.pipeline import GroundingPipeline
|
||||
from core.grounding.target import GroundingTarget
|
||||
|
||||
pipeline = GroundingPipeline()
|
||||
result = pipeline.locate(GroundingTarget(
|
||||
text="Valider",
|
||||
description="bouton vert en bas",
|
||||
template_b64=screenshot_b64,
|
||||
original_bbox={"x": 100, "y": 200, "width": 80, "height": 30},
|
||||
))
|
||||
if result:
|
||||
print(f"Trouve a ({result.x}, {result.y}) via {result.method}")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from core.grounding.target import GroundingTarget, GroundingResult
|
||||
|
||||
|
||||
class GroundingPipeline:
|
||||
"""Pipeline de localisation en cascade : template -> OCR -> UI-TARS -> static."""
|
||||
|
    def __init__(self, template_threshold: float = 0.75, enable_uitars: bool = True):
        self.template_threshold = template_threshold
        self.enable_uitars = enable_uitars

    def locate(self, target: GroundingTarget) -> Optional[GroundingResult]:
        """Locates a UI element by trying the methods in cascade.

        Args:
            target: description of the element to locate

        Returns:
            GroundingResult, or None if no method finds the element
        """
        t0 = time.time()

        # --- Method 1: template matching (~80ms) ---
        result = self._try_template(target)
        if result:
            print(f"[GroundingPipeline] Localise via {result.method} en "
                  f"{(time.time() - t0) * 1000:.0f}ms")
            return result

        # --- Method 2: OCR on text (~1s) ---
        result = self._try_ocr(target)
        if result:
            print(f"[GroundingPipeline] Localise via {result.method} en "
                  f"{(time.time() - t0) * 1000:.0f}ms")
            return result

        # --- Method 3: UI-TARS via the one-shot subprocess worker (slow) ---
        if self.enable_uitars:
            result = self._try_uitars(target)
            if result:
                print(f"[GroundingPipeline] Localise via {result.method} en "
                      f"{(time.time() - t0) * 1000:.0f}ms")
                return result

        # --- Method 4: static fallback ---
        result = self._try_static(target)
        if result:
            print(f"[GroundingPipeline] Localise via {result.method} en "
                  f"{(time.time() - t0) * 1000:.0f}ms")
            return result

        print(f"[GroundingPipeline] ECHEC: '{target.text}' introuvable "
              f"(toutes methodes epuisees, {(time.time() - t0) * 1000:.0f}ms)")
        return None

    # ------------------------------------------------------------------
    # Individual methods
    # ------------------------------------------------------------------

    def _try_template(self, target: GroundingTarget) -> Optional[GroundingResult]:
        """Template matching — fast and exact, but sensitive to visual changes."""
        if not target.template_b64:
            return None

        try:
            from core.grounding.template_matcher import TemplateMatcher
            matcher = TemplateMatcher(threshold=self.template_threshold)
            match = matcher.match_screen(anchor_b64=target.template_b64)
            if match:
                print(f"[GroundingPipeline/template] score={match.score:.3f} "
                      f"pos=({match.x},{match.y}) ({match.time_ms:.0f}ms)")
                return GroundingResult(
                    x=match.x,
                    y=match.y,
                    method='template',
                    confidence=match.score,
                    time_ms=match.time_ms,
                )
            else:
                diag = matcher.match_screen_diagnostic(anchor_b64=target.template_b64)
                print(f"[GroundingPipeline/template] pas de match — best={diag}")
        except Exception as e:
            print(f"[GroundingPipeline/template] ERREUR: {e}")

        return None

    def _try_ocr(self, target: GroundingTarget) -> Optional[GroundingResult]:
        """OCR: looks for the target text on screen via docTR."""
        if not target.text:
            return None

        try:
            from core.execution.input_handler import _grounding_ocr
            bbox = target.original_bbox if target.original_bbox else None
            result = _grounding_ocr(target.text, anchor_bbox=bbox)
            if result:
                print(f"[GroundingPipeline/OCR] '{target.text}' -> ({result['x']}, {result['y']})")
                return GroundingResult(
                    x=result['x'],
                    y=result['y'],
                    method='ocr',
                    confidence=result.get('confidence', 0.80),
                    time_ms=result.get('time_ms', 0),
                )
            else:
                print(f"[GroundingPipeline/OCR] '{target.text}' non trouve")
        except Exception as e:
            print(f"[GroundingPipeline/OCR] ERREUR: {e}")

        return None

    def _try_uitars(self, target: GroundingTarget) -> Optional[GroundingResult]:
        """UI-TARS via the one-shot subprocess worker — robust, handles layout changes."""
        if not target.text and not target.description:
            return None

        try:
            from core.grounding.ui_tars_grounder import UITarsGrounder
            grounder = UITarsGrounder.get_instance()
            result = grounder.ground(
                target_text=target.text,
                target_description=target.description,
            )
            if result:
                print(f"[GroundingPipeline/UI-TARS] ({result.x}, {result.y}) "
                      f"conf={result.confidence:.2f} ({result.time_ms:.0f}ms)")
                return result
            else:
                print("[GroundingPipeline/UI-TARS] pas de resultat")
        except Exception as e:
            print(f"[GroundingPipeline/UI-TARS] ERREUR: {e}")

        return None

    def _try_static(self, target: GroundingTarget) -> Optional[GroundingResult]:
        """Fallback: original workflow coordinates (center of the bounding box)."""
        bbox = target.original_bbox
        if not bbox:
            return None

        w = bbox.get('width', 0)
        h = bbox.get('height', 0)
        if not w or not h:
            return None

        x = int(bbox.get('x', 0) + w / 2)
        y = int(bbox.get('y', 0) + h / 2)

        print(f"[GroundingPipeline/static] fallback ({x}, {y}) "
              f"depuis bbox {bbox}")

        return GroundingResult(
            x=x,
            y=y,
            method='static_fallback',
            confidence=0.30,
            time_ms=0.0,
        )
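The four tiers above only run as far as they must: a template hit returns in tens of milliseconds, while UI-TARS spawns a model subprocess. A minimal sketch of driving the cascade, assuming the class is importable from `core.grounding.pipeline` (where `docs/CARTOGRAPHY.md` below places `GroundingPipeline`); the bbox values are illustrative placeholders:

```python
from core.grounding.pipeline import GroundingPipeline  # path per docs/CARTOGRAPHY.md
from core.grounding.target import GroundingTarget

pipeline = GroundingPipeline(template_threshold=0.75, enable_uitars=True)
target = GroundingTarget(
    text="Valider",
    original_bbox={"x": 500, "y": 300, "width": 80, "height": 24},  # placeholder values
)

# Tries template -> OCR -> UI-TARS -> static center of original_bbox.
result = pipeline.locate(target)
if result is not None:
    print(f"{result.method}: ({result.x}, {result.y}) conf={result.confidence:.2f}")
```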
core/grounding/server.py (new file, 113 lines)
@@ -0,0 +1,113 @@
"""Minimal grounding server — single-threaded Flask, same CUDA context."""
import base64, io, json, math, os, re, time, gc
import torch
from flask import Flask, request, jsonify
from PIL import Image

app = Flask(__name__)

MODEL_ID = os.environ.get("GROUNDING_MODEL", "InfiX-ai/InfiGUI-G1-3B")
MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 5600 * 28 * 28
_model = None
_processor = None


def _smart_resize(h, w, factor=28):
    # Snap to multiples of `factor` while keeping the pixel count within bounds
    h_bar = max(factor, round(h / factor) * factor)
    w_bar = max(factor, round(w / factor) * factor)
    if h_bar * w_bar > MAX_PIXELS:
        beta = math.sqrt((h * w) / MAX_PIXELS)
        h_bar = math.floor(h / beta / factor) * factor
        w_bar = math.floor(w / beta / factor) * factor
    elif h_bar * w_bar < MIN_PIXELS:
        beta = math.sqrt(MIN_PIXELS / (h * w))
        h_bar = math.ceil(h * beta / factor) * factor
        w_bar = math.ceil(w * beta / factor) * factor
    return h_bar, w_bar


def load_model():
    global _model, _processor
    if _model is not None:
        return
    from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
    torch.cuda.empty_cache(); gc.collect()
    print(f"[grounding] Chargement {MODEL_ID}...")
    bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                             bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)
    _model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID, quantization_config=bnb, device_map="auto")
    _model.eval()
    _processor = AutoProcessor.from_pretrained(MODEL_ID, min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS, padding_side="left")
    print(f"[grounding] Prêt — VRAM: {torch.cuda.memory_allocated()/1e9:.2f}GB")


@app.route('/health')
def health():
    return jsonify({"status": "ok", "model": MODEL_ID, "model_loaded": _model is not None,
                    "cuda_available": torch.cuda.is_available(),
                    "vram_allocated_gb": round(torch.cuda.memory_allocated()/1e9, 2)})


@app.route('/ground', methods=['POST'])
def ground():
    if _model is None:
        return jsonify({"error": "Modèle pas chargé"}), 503
    from qwen_vl_utils import process_vision_info
    data = request.json
    target = data.get('target_text', '')
    desc = data.get('target_description', '')
    label = f"{target} — {desc}" if desc else target
    if not label.strip():
        return jsonify({"error": "target_text requis"}), 400

    # Image: use the provided base64 payload, else grab the full screen
    if data.get('image_b64'):
        raw = data['image_b64'].split(',')[1] if ',' in data['image_b64'] else data['image_b64']
        img = Image.open(io.BytesIO(base64.b64decode(raw))).convert('RGB')
    else:
        import mss
        with mss.mss() as sct:
            grab = sct.grab(sct.monitors[0])
            img = Image.frombytes('RGB', grab.size, grab.bgra, 'raw', 'BGRX')

    W, H = img.size
    rH, rW = _smart_resize(H, W)

    user_text = f'The screen\'s resolution is {rW}x{rH}.\nLocate the UI element(s) for "{label}", output the coordinates using JSON format: [{{"point_2d": [x, y]}}, ...]'
    system = "You FIRST think about the reasoning process as an internal monologue and then provide the final answer.\nThe reasoning process MUST BE enclosed within <think> </think> tags."

    messages = [{"role": "system", "content": system},
                {"role": "user", "content": [{"type": "image", "image": img}, {"type": "text", "text": user_text}]}]

    text = _processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = _processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt").to(_model.device)

    t0 = time.time()
    with torch.no_grad():
        gen = _model.generate(**inputs, max_new_tokens=512)
    infer_ms = (time.time() - t0) * 1000

    trimmed = [o[len(i):] for i, o in zip(inputs.input_ids, gen)]
    raw = _processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0].strip()
    print(f"[grounding] '{label[:40]}' → {raw[:100]} ({infer_ms:.0f}ms)")

    # Parse the point_2d JSON and rescale back to screen pixels
    json_part = raw.split("</think>")[-1] if "</think>" in raw else raw
    json_part = json_part.replace("```json", "").replace("```", "").strip()
    px, py = None, None
    try:
        parsed = json.loads(json_part)
        if isinstance(parsed, list) and len(parsed) > 0:
            pt = parsed[0].get("point_2d", [])
            if len(pt) >= 2:
                px, py = int(pt[0] * W / rW), int(pt[1] * H / rH)
    except json.JSONDecodeError:
        m = re.search(r'"point_2d"\s*:\s*\[(\d+),\s*(\d+)\]', raw)
        if m:
            px, py = int(int(m.group(1)) * W / rW), int(int(m.group(2)) * H / rH)

    # `px is not None` rather than truthiness: x == 0 is a valid coordinate
    return jsonify({"x": px, "y": py, "method": "infigui", "confidence": 0.90 if px is not None else 0.0,
                    "time_ms": round(infer_ms, 1), "raw_output": raw[:300]})


if __name__ == '__main__':
    load_model()
    app.run(host='0.0.0.0', port=8200, threaded=False)
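For reference, the request/response contract of the two endpoints above, assuming the server is reachable on its default 127.0.0.1:8200 and the model has finished loading (before that, /ground answers 503):

```python
import requests

print(requests.get("http://127.0.0.1:8200/health", timeout=5).json())

resp = requests.post(
    "http://127.0.0.1:8200/ground",
    # Without image_b64 the server grabs its own full-screen capture.
    json={"target_text": "Valider", "target_description": "bouton de confirmation"},
    timeout=120,  # generous: one 4-bit VLM generation per call
)
data = resp.json()
print(data["x"], data["y"], data["confidence"], data["time_ms"])
```

Note that `docs/CARTOGRAPHY.md` below lists this server as not wired: it crashed under CUDA and was replaced by the file-based worker.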
core/grounding/shadow_learning_hook.py (new file, 156 lines)
@@ -0,0 +1,156 @@
"""
core/grounding/shadow_learning_hook.py — Shadow-mode learning hook

Connects the ShadowObserver to the SignatureStore: every click observed
during a Shadow session enriches the element-signature database.

The human clicks somewhere → we detect which UI element sits under the
click → we store its signature (text, type, position, neighbors) for replay.

This module is an optional HOOK — it does not modify the ShadowObserver;
it plugs into it via callback.

Usage:
    from core.grounding.shadow_learning_hook import ShadowLearningHook

    hook = ShadowLearningHook()

    # Inside the ShadowObserver or the capture API:
    hook.on_click_observed(
        click_x=542, click_y=318,
        screenshot_pil=screen,
        window_title="Bloc-notes",
        target_label="Bouton Valider",
    )
"""

from __future__ import annotations

import threading
import time
from typing import Any, Dict, Optional

from core.grounding.element_signature import SignatureStore
from core.grounding.fast_types import DetectedUIElement


class ShadowLearningHook:
    """Learning hook for Shadow mode.

    For every observed human click, detects the element under the click
    and enriches the SignatureStore.
    """

    def __init__(self, signature_store: Optional[SignatureStore] = None):
        self._store = signature_store or SignatureStore()
        self._detector = None  # Lazy-loaded so RF-DETR is not pulled in at startup
        self._lock = threading.Lock()

    def on_click_observed(
        self,
        click_x: int,
        click_y: int,
        screenshot_pil: Optional[Any] = None,
        window_title: str = "",
        target_label: str = "",
        target_description: str = "",
    ) -> Optional[Dict[str, Any]]:
        """Called whenever a human click is observed during Shadow mode.

        Args:
            click_x, click_y: Click position (screen pixels).
            screenshot_pil: PIL screenshot taken at click time.
            window_title: Title of the active window.
            target_label: Step label (if known).
            target_description: Element description (if known).

        Returns:
            Dict with the created/enriched signature, or None on failure.
        """
        t0 = time.time()

        try:
            # Lazy-load the detector
            if self._detector is None:
                from core.grounding.fast_detector import FastDetector
                self._detector = FastDetector()

            # Detect the elements currently on screen
            snapshot = self._detector.detect(screenshot_pil=screenshot_pil)

            if not snapshot.elements:
                print(f"📝 [Shadow/learn] Aucun élément détecté à ({click_x}, {click_y})")
                return None

            # Find the element under the click
            clicked_element = self._find_element_at(click_x, click_y, snapshot.elements)

            if clicked_element is None:
                print(f"📝 [Shadow/learn] Aucun élément sous ({click_x}, {click_y})")
                return None

            # Build the target key
            target_key = SignatureStore.make_target_key(
                target_label or clicked_element.ocr_text,
                target_description,
            )
            screen_ctx = SignatureStore.make_screen_context(
                window_title, snapshot.resolution,
            )

            # Record the signature
            self._store.record_success(
                target_key=target_key,
                screen_context=screen_ctx,
                element=clicked_element,
                confidence=1.0,  # The human clicked it → maximum confidence
            )

            dt = (time.time() - t0) * 1000
            print(f"📝 [Shadow/learn] Signature '{clicked_element.ocr_text}' "
                  f"type={clicked_element.element_type} "
                  f"pos={clicked_element.relative_position} "
                  f"voisins={clicked_element.neighbors[:3]} ({dt:.0f}ms)")

            return {
                "target_key": target_key,
                "text": clicked_element.ocr_text,
                "element_type": clicked_element.element_type,
                "relative_position": clicked_element.relative_position,
                "neighbors": clicked_element.neighbors,
                "center": clicked_element.center,
            }

        except Exception as e:
            print(f"⚠️ [Shadow/learn] Erreur: {e}")
            return None

    @staticmethod
    def _find_element_at(
        x: int, y: int,
        elements: list,
        margin: int = 20,
    ) -> Optional[DetectedUIElement]:
        """Finds the element whose bbox contains the point (x, y).

        If there is no exact match, takes the closest element within `margin` pixels.
        """
        # Exact match: the click falls inside the bbox
        for elem in elements:
            x1, y1, x2, y2 = elem.bbox
            if x1 <= x <= x2 and y1 <= y <= y2:
                return elem

        # Proximity match: the click is near a center
        best_elem = None
        best_dist = float('inf')

        for elem in elements:
            dx = abs(elem.center[0] - x)
            dy = abs(elem.center[1] - y)
            dist = (dx**2 + dy**2) ** 0.5
            if dist < margin and dist < best_dist:
                best_dist = dist
                best_elem = elem

        return best_elem
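Nothing in this diff wires the hook yet (`docs/CARTOGRAPHY.md` below lists it as never imported). One possible wiring, sketched with pynput as a hypothetical click source; the hook itself only needs someone to call `on_click_observed()` with screen coordinates and a screenshot:

```python
# Hypothetical glue code, not part of the diff: pynput and mss are assumptions here.
import mss
from PIL import Image
from pynput import mouse

from core.grounding.shadow_learning_hook import ShadowLearningHook

hook = ShadowLearningHook()

def _on_click(x, y, button, pressed):
    if not pressed:  # learn on press only, ignore release
        return
    with mss.mss() as sct:
        grab = sct.grab(sct.monitors[0])
        screen = Image.frombytes("RGB", grab.size, grab.bgra, "raw", "BGRX")
    hook.on_click_observed(click_x=int(x), click_y=int(y), screenshot_pil=screen)

mouse.Listener(on_click=_on_click).start()
```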
core/grounding/smart_matcher.py (new file, 263 lines)
@@ -0,0 +1,263 @@
"""
core/grounding/smart_matcher.py — SMART layer: deterministic/probabilistic matching

Given a ScreenSnapshot (all detected elements) and a GroundingTarget
(what we are looking for), finds the matching element with a confidence score.

Matching pipeline (short-circuits at the first high-confidence match):
  1. Exact text (2ms)           → score 0.95
  2. Fuzzy text ratio (5ms)     → score 0.70-0.90
  3. Type + position (2ms)      → bonus/malus
  4. Contextual neighbors (5ms) → bonus
  5. Combined score             → MatchCandidate

Usage:
    from core.grounding.smart_matcher import SmartMatcher
    from core.grounding.fast_types import ScreenSnapshot
    from core.grounding.target import GroundingTarget

    matcher = SmartMatcher()
    candidate = matcher.match(snapshot, GroundingTarget(text="Valider"))
    if candidate and candidate.score >= 0.90:
        print(f"Direct match: ({candidate.element.center}) score={candidate.score}")
"""

from __future__ import annotations

import re
from difflib import SequenceMatcher
from typing import Dict, List, Optional

from core.grounding.fast_types import DetectedUIElement, MatchCandidate, ScreenSnapshot
from core.grounding.target import GroundingTarget


class SmartMatcher:
    """Intelligent matching between a target and the detected elements.

    Combines several signals (text, type, position, neighbors) into a single
    confidence score for each candidate.
    """

    def __init__(
        self,
        weight_text: float = 0.50,
        weight_type: float = 0.10,
        weight_position: float = 0.15,
        weight_neighbors: float = 0.25,
    ):
        self.w_text = weight_text
        self.w_type = weight_type
        self.w_position = weight_position
        self.w_neighbors = weight_neighbors

    def match(
        self,
        snapshot: ScreenSnapshot,
        target: GroundingTarget,
        signature: Optional[Dict] = None,
    ) -> Optional[MatchCandidate]:
        """Finds the BEST element matching the target.

        Returns:
            The MatchCandidate with the highest score, or None if no match.
        """
        candidates = self.match_all(snapshot, target, signature)
        if not candidates:
            return None
        return candidates[0]

    def match_all(
        self,
        snapshot: ScreenSnapshot,
        target: GroundingTarget,
        signature: Optional[Dict] = None,
    ) -> List[MatchCandidate]:
        """Finds ALL candidates, sorted by decreasing score.

        Args:
            snapshot: Screen state (detected elements + OCR).
            target: What we are looking for (text, description, original bbox).
            signature: Learned signature (optional, enriches the matching).

        Returns:
            List of MatchCandidate sorted by decreasing score.
        """
        if not snapshot.elements:
            return []

        target_text = (target.text or "").strip()
        target_desc = (target.description or "").strip()
        search_text = target_text or target_desc

        if not search_text:
            return []

        candidates = []
        search_lower = self._normalize(search_text)

        for elem in snapshot.elements:
            score_detail: Dict[str, float] = {}
            method = ""

            # --- 1. Text score ---
            text_score = self._score_text(search_lower, elem.ocr_text)
            score_detail["text"] = text_score

            if text_score >= 0.95:
                method = "exact_text"
            elif text_score >= 0.70:
                method = "fuzzy_text"

            # --- 2. Type score (when a signature is known) ---
            type_score = 0.5  # neutral by default
            if signature and signature.get("element_type"):
                if elem.element_type == signature["element_type"]:
                    type_score = 1.0
                elif elem.element_type == "element":
                    type_score = 0.5  # unclassified, neutral
                else:
                    type_score = 0.2
            score_detail["type"] = type_score

            # --- 3. Position score (when the original bbox is known) ---
            position_score = 0.5  # neutral
            if target.original_bbox:
                position_score = self._score_position(
                    elem.center, target.original_bbox,
                    snapshot.resolution[0], snapshot.resolution[1],
                )
            elif signature and signature.get("relative_position"):
                if elem.relative_position == signature["relative_position"]:
                    position_score = 0.9
                else:
                    position_score = 0.3
            score_detail["position"] = position_score

            # --- 4. Neighbor score (when a signature is known) ---
            neighbor_score = 0.5  # neutral
            if signature and signature.get("neighbors"):
                neighbor_score = self._score_neighbors(
                    elem.neighbors, signature["neighbors"]
                )
            score_detail["neighbors"] = neighbor_score

            # --- Combined score ---
            combined = (
                self.w_text * text_score
                + self.w_type * type_score
                + self.w_position * position_score
                + self.w_neighbors * neighbor_score
            )

            # Minimum threshold: no candidate if the text does not match at all
            if text_score < 0.30:
                continue

            if not method:
                method = "combined"

            candidates.append(MatchCandidate(
                element=elem,
                score=combined,
                score_detail=score_detail,
                method=method,
            ))

        # Sort by decreasing score
        candidates.sort(key=lambda c: c.score, reverse=True)

        return candidates

    # ------------------------------------------------------------------
    # Text scoring
    # ------------------------------------------------------------------

    def _score_text(self, search: str, ocr_text: str) -> float:
        """Text similarity score (0-1)."""
        if not ocr_text:
            return 0.0

        ocr_lower = self._normalize(ocr_text)

        # Exact match
        if search == ocr_lower:
            return 1.0

        # Inclusion (one contains the other)
        if search in ocr_lower or ocr_lower in search:
            overlap = min(len(search), len(ocr_lower))
            total = max(len(search), len(ocr_lower))
            if total > 0:
                return 0.70 + 0.25 * (overlap / total)

        # Fuzzy matching (SequenceMatcher, standard library)
        ratio = SequenceMatcher(None, search, ocr_lower).ratio()
        if ratio >= 0.60:
            return 0.50 + 0.40 * ratio

        return ratio * 0.3

    # ------------------------------------------------------------------
    # Position scoring
    # ------------------------------------------------------------------

    @staticmethod
    def _score_position(
        center: tuple,
        original_bbox: dict,
        screen_w: int,
        screen_h: int,
    ) -> float:
        """Proximity score relative to the original position (0-1)."""
        if not original_bbox:
            return 0.5

        orig_x = original_bbox.get("x", 0) + original_bbox.get("width", 0) / 2
        orig_y = original_bbox.get("y", 0) + original_bbox.get("height", 0) / 2

        dx = abs(center[0] - orig_x) / max(screen_w, 1)
        dy = abs(center[1] - orig_y) / max(screen_h, 1)
        distance_norm = (dx**2 + dy**2) ** 0.5

        # distance 0 = score 1.0; normalized distance 0.5 (half a screen) or more = score 0.0
        return max(0.0, 1.0 - distance_norm * 2.0)

    # ------------------------------------------------------------------
    # Neighbor scoring
    # ------------------------------------------------------------------

    @staticmethod
    def _score_neighbors(
        current_neighbors: List[str],
        expected_neighbors: List[str],
    ) -> float:
        """Jaccard score over the sets of neighboring words (0-1)."""
        if not expected_neighbors:
            return 0.5

        current_set = {n.lower().strip() for n in current_neighbors if n}
        expected_set = {n.lower().strip() for n in expected_neighbors if n}

        if not current_set and not expected_set:
            return 0.5

        intersection = current_set & expected_set
        union = current_set | expected_set

        if not union:
            return 0.5

        return len(intersection) / len(union)

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------

    @staticmethod
    def _normalize(text: str) -> str:
        """Normalizes text for comparison."""
        text = text.lower().strip()
        text = re.sub(r'[_\-\./\\]', ' ', text)
        text = re.sub(r'\s+', ' ', text)
        return text
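One consequence of the default weights is worth spelling out: with no learned signature, type and neighbors both sit at their neutral 0.5, so even a perfect text match plus a perfect position match caps at 0.50 + 0.05 + 0.15 + 0.125 = 0.825, below the 0.90 direct-action threshold quoted in the docstring. Reaching the short-circuit therefore needs a signature. The arithmetic:

```python
# Combined-score arithmetic with the default weights (0.50/0.10/0.15/0.25).
def combined(text, type_, position, neighbors):
    return 0.50 * text + 0.10 * type_ + 0.15 * position + 0.25 * neighbors

# No signature: type and neighbors stay neutral (0.5).
print(combined(1.0, 0.5, 1.0, 0.5))  # 0.825 — never reaches the 0.90 short-circuit
# With a learned signature that fully agrees:
print(combined(1.0, 1.0, 1.0, 1.0))  # 1.0
```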
core/grounding/target.py (new file, 48 lines)
@@ -0,0 +1,48 @@
"""
core/grounding/target.py — Shared types for visual grounding

Dataclasses describing a target to locate (GroundingTarget) and
the result of a localization (GroundingResult).

These types are the common building block for every grounding module:
template matching, OCR, VLM, CLIP, etc.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class GroundingTarget:
    """Description of a UI element to locate on screen.

    Attributes:
        text: visible text of the element (button, label, etc.)
        description: free-form semantic description (e.g. "the Valider button at the bottom right")
        template_b64: visual capture of the element, base64-encoded PNG/JPEG
        original_bbox: original position at capture time {x, y, width, height}
    """
    text: str = ""
    description: str = ""
    template_b64: str = ""
    original_bbox: Optional[Dict[str, int]] = field(default=None)


@dataclass
class GroundingResult:
    """Result of a UI element localization.

    Attributes:
        x: X coordinate of the center of the found element (screen pixels)
        y: Y coordinate of the center of the found element (screen pixels)
        method: method that produced the result ('template', 'ocr', 'vlm', 'clip', etc.)
        confidence: confidence score [0.0 – 1.0]
        time_ms: search time in milliseconds
    """
    x: int
    y: int
    method: str
    confidence: float
    time_ms: float
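Construction is plain dataclass syntax; everything on the target side is optional, which is exactly why downstream matchers have to tolerate empty text (see the `target_text` trace in `docs/CARTOGRAPHY.md` below). Illustrative values only:

```python
from core.grounding.target import GroundingTarget, GroundingResult

target = GroundingTarget(
    text="Enregistrer",
    description="le bouton Enregistrer de la barre d'outils",
    original_bbox={"x": 40, "y": 12, "width": 90, "height": 28},
)
hit = GroundingResult(x=85, y=26, method="ocr", confidence=0.80, time_ms=212.0)
```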
core/grounding/template_matcher.py (new file, 350 lines)
@@ -0,0 +1,350 @@
"""
core/grounding/template_matcher.py — Centralized template matching

Provides a TemplateMatcher class that locates a visual anchor (template image)
in a screenshot via cv2.matchTemplate. Supports single-scale and multi-scale.

Replaces the duplicated implementations in:
- core/execution/observe_reason_act.py (~1348-1375)
- visual_workflow_builder/backend/api_v3/execute.py (~930-963)
- visual_workflow_builder/backend/catalog_routes_v2_vlm.py (~339-381)
- visual_workflow_builder/backend/services/intelligent_executor.py (~131-210)
- core/detection/omniparser_adapter.py (~330)

Usage:
    from core.grounding import TemplateMatcher, MatchResult

    matcher = TemplateMatcher(threshold=0.75)
    result = matcher.match_screen(anchor_b64="...")
    if result:
        print(f"Found at ({result.x}, {result.y}) score={result.score:.3f}")
"""

from __future__ import annotations

import base64
import io
import logging
import time
from dataclasses import dataclass
from typing import List, Optional, Tuple

logger = logging.getLogger(__name__)

# Optional imports — the module loads even without cv2/PIL/mss
try:
    import cv2
    _CV2 = True
except ImportError:
    _CV2 = False

try:
    import numpy as np
    _NP = True
except ImportError:
    _NP = False

try:
    from PIL import Image
    _PIL = True
except ImportError:
    _PIL = False

try:
    import mss as mss_lib
    _MSS = True
except ImportError:
    _MSS = False


# ---------------------------------------------------------------------------
# Match result
# ---------------------------------------------------------------------------

@dataclass
class MatchResult:
    """Result of a template match."""
    x: int
    y: int
    score: float
    method: str  # 'template' | 'template_multiscale'
    time_ms: float
    scale: float = 1.0  # Scale at which the best match was found


# ---------------------------------------------------------------------------
# TemplateMatcher
# ---------------------------------------------------------------------------

class TemplateMatcher:
    """Locates a visual anchor in a screenshot via template matching.

    Parameters:
        threshold: minimum score to accept a match (default 0.75)
        multiscale: enables multi-scale matching (default False)
        scales: list of scales to try in multi-scale mode
        grayscale: convert to grayscale before matching (default False)

    The cv2 method is fixed to cv2.TM_CCOEFF_NORMED, the one used
    everywhere in the project.
    """

    # Default scales for multi-scale mode, ordered by decreasing
    # likelihood (1.0 first = fast when it matches)
    DEFAULT_SCALES: List[float] = [1.0, 0.95, 1.05, 0.9, 1.1, 0.85, 1.15, 0.8, 1.2]

    def __init__(
        self,
        threshold: float = 0.75,
        multiscale: bool = False,
        scales: Optional[List[float]] = None,
        grayscale: bool = False,
    ):
        self.threshold = threshold
        self.multiscale = multiscale
        self.scales = scales or self.DEFAULT_SCALES
        self.grayscale = grayscale
        self._cv2_method = cv2.TM_CCOEFF_NORMED if _CV2 else None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def match_screen(
        self,
        anchor_b64: Optional[str] = None,
        anchor_pil: Optional["Image.Image"] = None,
        screen_pil: Optional["Image.Image"] = None,
    ) -> Optional[MatchResult]:
        """Looks for the anchor in the current (or provided) screenshot.

        The anchor can be passed as base64 or as a PIL Image.
        The screenshot is captured via mss when not provided.

        Returns a MatchResult, or None if no match >= threshold.
        """
        if not (_CV2 and _NP and _PIL):
            logger.debug("[TemplateMatcher] cv2/numpy/PIL non disponible")
            return None

        # --- Prepare the anchor ---
        anchor_img = self._decode_anchor(anchor_b64, anchor_pil)
        if anchor_img is None:
            return None

        # --- Prepare the screenshot ---
        if screen_pil is None:
            screen_pil = self._capture_screen()
            if screen_pil is None:
                return None

        # --- Convert to cv2 arrays ---
        screen_cv = cv2.cvtColor(np.array(screen_pil), cv2.COLOR_RGB2BGR)
        anchor_cv = cv2.cvtColor(np.array(anchor_img), cv2.COLOR_RGB2BGR)

        # --- Matching ---
        if self.multiscale:
            return self._match_multiscale(screen_cv, anchor_cv)
        else:
            return self._match_single(screen_cv, anchor_cv)

    def match_in_region(
        self,
        region_cv: "np.ndarray",
        anchor_cv: "np.ndarray",
        threshold: Optional[float] = None,
    ) -> Optional[MatchResult]:
        """Matches inside an already-cropped region (BGR arrays).

        Used by pipelines that do their own capture/cropping.
        """
        if not (_CV2 and _NP):
            return None

        thr = threshold if threshold is not None else self.threshold

        if self.multiscale:
            return self._match_multiscale(region_cv, anchor_cv, threshold_override=thr)
        else:
            return self._match_single(region_cv, anchor_cv, threshold_override=thr)

    def match_screen_diagnostic(
        self,
        anchor_b64: Optional[str] = None,
        anchor_pil: Optional["Image.Image"] = None,
        screen_pil: Optional["Image.Image"] = None,
    ) -> str:
        """Returns a textual diagnostic (score + position) even without a match."""
        if not (_CV2 and _NP and _PIL):
            return "cv2/numpy/PIL non dispo"

        anchor_img = self._decode_anchor(anchor_b64, anchor_pil)
        if anchor_img is None:
            return "ancre non décodable"

        if screen_pil is None:
            screen_pil = self._capture_screen()
            if screen_pil is None:
                return "capture écran échouée"

        screen_cv = cv2.cvtColor(np.array(screen_pil), cv2.COLOR_RGB2BGR)
        anchor_cv = cv2.cvtColor(np.array(anchor_img), cv2.COLOR_RGB2BGR)

        if anchor_cv.shape[0] >= screen_cv.shape[0] or anchor_cv.shape[1] >= screen_cv.shape[1]:
            return f"ancre {anchor_cv.shape[:2]} >= écran {screen_cv.shape[:2]}"

        s_img, a_img = self._maybe_grayscale(screen_cv, anchor_cv)
        result_tm = cv2.matchTemplate(s_img, a_img, self._cv2_method)
        _, max_val, _, max_loc = cv2.minMaxLoc(result_tm)
        return f"{max_val:.3f} pos={max_loc}"

    # ------------------------------------------------------------------
    # Internal methods
    # ------------------------------------------------------------------

    def _match_single(
        self,
        screen_cv: "np.ndarray",
        anchor_cv: "np.ndarray",
        threshold_override: Optional[float] = None,
    ) -> Optional[MatchResult]:
        """Single-scale template matching."""
        threshold = threshold_override if threshold_override is not None else self.threshold

        if anchor_cv.shape[0] >= screen_cv.shape[0] or anchor_cv.shape[1] >= screen_cv.shape[1]:
            logger.debug("[TemplateMatcher] Ancre plus grande que le screen")
            return None

        s_img, a_img = self._maybe_grayscale(screen_cv, anchor_cv)

        t0 = time.time()
        result_tm = cv2.matchTemplate(s_img, a_img, self._cv2_method)
        _, max_val, _, max_loc = cv2.minMaxLoc(result_tm)
        elapsed_ms = (time.time() - t0) * 1000

        logger.debug(
            "[TemplateMatcher] score=%.3f pos=%s (%.0fms)",
            max_val, max_loc, elapsed_ms,
        )

        if max_val >= threshold:
            cx = max_loc[0] + anchor_cv.shape[1] // 2
            cy = max_loc[1] + anchor_cv.shape[0] // 2
            return MatchResult(
                x=cx,
                y=cy,
                score=float(max_val),
                method='template',
                time_ms=elapsed_ms,
                scale=1.0,
            )
        return None

    def _match_multiscale(
        self,
        screen_cv: "np.ndarray",
        anchor_cv: "np.ndarray",
        threshold_override: Optional[float] = None,
    ) -> Optional[MatchResult]:
        """Multi-scale template matching."""
        threshold = threshold_override if threshold_override is not None else self.threshold

        best_score = -1.0
        best_loc = None
        best_scale = 1.0
        best_anchor_shape = anchor_cv.shape

        t0 = time.time()

        for scale in self.scales:
            if scale == 1.0:
                scaled = anchor_cv
            else:
                new_w = int(anchor_cv.shape[1] * scale)
                new_h = int(anchor_cv.shape[0] * scale)
                if new_w < 8 or new_h < 8:
                    continue
                if new_h >= screen_cv.shape[0] or new_w >= screen_cv.shape[1]:
                    continue
                scaled = cv2.resize(anchor_cv, (new_w, new_h), interpolation=cv2.INTER_AREA)

            if scaled.shape[0] >= screen_cv.shape[0] or scaled.shape[1] >= screen_cv.shape[1]:
                continue

            s_img, a_img = self._maybe_grayscale(screen_cv, scaled)
            result_tm = cv2.matchTemplate(s_img, a_img, self._cv2_method)
            _, max_val, _, max_loc = cv2.minMaxLoc(result_tm)

            if max_val > best_score:
                best_score = max_val
                best_loc = max_loc
                best_scale = scale
                best_anchor_shape = scaled.shape

        elapsed_ms = (time.time() - t0) * 1000

        logger.debug(
            "[TemplateMatcher/multiscale] best_score=%.3f scale=%.2f (%.0fms)",
            best_score, best_scale, elapsed_ms,
        )

        if best_score >= threshold and best_loc is not None:
            cx = best_loc[0] + best_anchor_shape[1] // 2
            cy = best_loc[1] + best_anchor_shape[0] // 2
            return MatchResult(
                x=cx,
                y=cy,
                score=float(best_score),
                method='template_multiscale',
                time_ms=elapsed_ms,
                scale=best_scale,
            )
        return None

    def _maybe_grayscale(
        self,
        screen: "np.ndarray",
        anchor: "np.ndarray",
    ) -> Tuple["np.ndarray", "np.ndarray"]:
        """Converts to grayscale when self.grayscale is True."""
        if not self.grayscale:
            return screen, anchor
        s = cv2.cvtColor(screen, cv2.COLOR_BGR2GRAY) if len(screen.shape) == 3 else screen
        a = cv2.cvtColor(anchor, cv2.COLOR_BGR2GRAY) if len(anchor.shape) == 3 else anchor
        return s, a

    @staticmethod
    def _decode_anchor(
        anchor_b64: Optional[str],
        anchor_pil: Optional["Image.Image"],
    ) -> Optional["Image.Image"]:
        """Decodes the anchor from base64, or returns the PIL image directly."""
        if anchor_pil is not None:
            return anchor_pil

        if anchor_b64 is None:
            logger.debug("[TemplateMatcher] Ni anchor_b64 ni anchor_pil fourni")
            return None

        try:
            raw = anchor_b64.split(',')[1] if ',' in anchor_b64 else anchor_b64
            data = base64.b64decode(raw)
            # .convert('RGB') normalizes RGBA/paletted PNGs to 3 channels,
            # otherwise cv2.cvtColor(..., COLOR_RGB2BGR) rejects the array
            return Image.open(io.BytesIO(data)).convert('RGB')
        except Exception as e:
            logger.debug("[TemplateMatcher] Erreur décodage ancre: %s", e)
            return None

    @staticmethod
    def _capture_screen() -> Optional["Image.Image"]:
        """Captures the full screen via mss (monitor 0 = all screens)."""
        if not _MSS:
            logger.debug("[TemplateMatcher] mss non disponible")
            return None

        try:
            with mss_lib.mss() as sct:
                mon = sct.monitors[0]
                grab = sct.grab(mon)
                return Image.frombytes('RGB', grab.size, grab.bgra, 'raw', 'BGRX')
        except Exception as e:
            logger.debug("[TemplateMatcher] Erreur capture écran: %s", e)
            return None
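The multi-scale path is what you want when the replayed UI renders at a different DPI than the capture. A short sketch; the `anchor_b64` value is a placeholder for a real base64 crop:

```python
from core.grounding.template_matcher import TemplateMatcher

anchor_b64 = "..."  # placeholder: base64 PNG of the captured element
matcher = TemplateMatcher(threshold=0.70, multiscale=True, grayscale=True)
result = matcher.match_screen(anchor_b64=anchor_b64)
if result:
    # result.scale reports which of the DEFAULT_SCALES won
    print(f"({result.x}, {result.y}) score={result.score:.3f} scale={result.scale:.2f}")
```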
core/grounding/think_arbiter.py (new file, 103 lines)
@@ -0,0 +1,103 @@
"""
core/grounding/think_arbiter.py — THINK layer: VLM arbiter (InfiGUI via subprocess)

Called ONLY when the SmartMatcher is not confident enough.
Uses the InfiGUI subprocess worker (no HTTP server).

Usage:
    from core.grounding.think_arbiter import ThinkArbiter

    arbiter = ThinkArbiter()
    result = arbiter.arbitrate(target, candidates, screenshot)
"""

from __future__ import annotations

import time
from typing import Any, Dict, List, Optional

from core.grounding.fast_types import LocateResult, MatchCandidate
from core.grounding.target import GroundingTarget


class ThinkArbiter:
    """VLM arbiter — calls InfiGUI through the subprocess worker."""

    def __init__(self):
        self._grounder = None

    def _get_grounder(self):
        if self._grounder is None:
            from core.grounding.ui_tars_grounder import UITarsGrounder
            self._grounder = UITarsGrounder.get_instance()
        return self._grounder

    @property
    def available(self) -> bool:
        """Always available — the worker starts on demand."""
        return True

    def arbitrate(
        self,
        target: GroundingTarget,
        candidates: List[MatchCandidate],
        screenshot_pil: Optional[Any] = None,
    ) -> Optional[LocateResult]:
        """Asks the VLM to decide.

        If target.template_b64 is provided, we switch to fused mode:
        the crop is passed as a reference image to InfiGUI, which avoids
        a VRAM-hungry Ollama qwen2.5vl description call.
        """
        t0 = time.time()

        # Decode the anchor crop if available (fused mode)
        anchor_pil = None
        if target.template_b64:
            try:
                import base64
                import io
                from PIL import Image

                raw_b64 = target.template_b64
                if ',' in raw_b64:
                    raw_b64 = raw_b64.split(',', 1)[1]
                anchor_pil = Image.open(io.BytesIO(base64.b64decode(raw_b64))).convert("RGB")
            except Exception as ex:
                print(f"⚠️ [THINK] Décodage anchor échoué: {ex}")
                anchor_pil = None

        try:
            grounder = self._get_grounder()
            result = grounder.ground(
                target_text=target.text or "",
                target_description=target.description or "",
                screen_pil=screenshot_pil,
                anchor_pil=anchor_pil,
            )

            dt = (time.time() - t0) * 1000

            if result is None:
                label = target.text or "<crop>"
                print(f"🤔 [THINK] VLM n'a pas trouvé '{label}' ({dt:.0f}ms)")
                return None

            method = "think_vlm_fused" if anchor_pil is not None else "think_vlm"
            locate = LocateResult(
                x=result.x,
                y=result.y,
                confidence=result.confidence,
                method=method,
                time_ms=dt,
                tier="think",
                candidates_count=len(candidates),
            )

            print(f"🤔 [THINK/{method}] ({result.x}, {result.y}) conf={result.confidence:.2f} ({dt:.0f}ms)")
            return locate

        except Exception as ex:
            dt = (time.time() - t0) * 1000
            print(f"⚠️ [THINK] Erreur: {ex} ({dt:.0f}ms)")
            return None
core/grounding/title_verifier.py (new file, 174 lines)
@@ -0,0 +1,174 @@
"""
core/grounding/title_verifier.py — Post-action verification via the window title

After each action (click, double-click), checks that the active window
changed as expected by reading the title with OCR on a 45px crop at the
top of the screen.

Lightweight (~120ms), non-blocking (failure = warning + retry, not a stop).

Usage:
    from core.grounding.title_verifier import TitleVerifier

    verifier = TitleVerifier()
    title = verifier.read_title(screenshot_pil)
    changed = verifier.has_title_changed(title_before, title_after)
"""

from __future__ import annotations

import time
from difflib import SequenceMatcher
from typing import Optional


class TitleVerifier:
    """Checks the active window title via OCR on a crop."""

    # Crop height for the Windows title bar
    TITLE_BAR_HEIGHT = 45

    def __init__(self):
        self._ocr_fn = None  # Lazy-loaded

    def read_title(self, screenshot_pil) -> str:
        """Reads the active window title via OCR on the top crop.

        Args:
            screenshot_pil: PIL image of the full screenshot.

        Returns:
            Title text (may be empty if OCR fails).
        """
        t0 = time.time()

        try:
            w, h = screenshot_pil.size
            # Crop the title bar (top 45px)
            title_crop = screenshot_pil.crop((0, 0, w, min(self.TITLE_BAR_HEIGHT, h)))

            # OCR on the small crop
            ocr_fn = self._get_ocr()
            if ocr_fn is None:
                return ""

            text = ocr_fn(title_crop)
            dt = (time.time() - t0) * 1000

            # Clean up the text
            title = text.strip() if text else ""
            if title:
                print(f"📋 [TitleVerify] Titre lu: '{title[:60]}' ({dt:.0f}ms)")

            return title

        except Exception as e:
            print(f"⚠️ [TitleVerify] Erreur lecture titre: {e}")
            return ""

    def has_title_changed(self, title_before: str, title_after: str) -> bool:
        """Checks whether the title changed significantly."""
        if not title_before and not title_after:
            return False
        if not title_before or not title_after:
            return True  # One of the two is empty = change

        # Fuzzy comparison — titles can have minor variations
        ratio = SequenceMatcher(None, title_before.lower(), title_after.lower()).ratio()
        return ratio < 0.85  # Changed if < 85% similar

    def verify_action(
        self,
        screenshot_before,
        screenshot_after,
        action_type: str,
    ) -> dict:
        """Checks that an action produced the expected effect on the title.

        Args:
            screenshot_before: PIL screenshot before the action.
            screenshot_after: PIL screenshot after the action.
            action_type: Action type ("double_click", "click", "type", "hotkey").

        Returns:
            Dict with success, title_before, title_after, changed.
        """
        # Actions that do not change the title
        if action_type in ('type_text', 'keyboard_shortcut', 'wait_for_anchor', 'hover'):
            return {
                'success': True,
                'title_before': '',
                'title_after': '',
                'changed': False,
                'reason': f"Action '{action_type}' — vérification titre non requise",
            }

        title_before = self.read_title(screenshot_before)
        title_after = self.read_title(screenshot_after)
        changed = self.has_title_changed(title_before, title_after)

        # For a double-click (opening a file/folder) the title MUST change,
        # but only when the titles read are meaningful (> 3 chars):
        # docTR on a 45px crop inside a VM can return noise ('o', 'a', etc.)
        if action_type in ('double_click_anchor',) and not changed:
            if len(title_before) > 3 and len(title_after) > 3:
                return {
                    'success': False,
                    'title_before': title_before,
                    'title_after': title_after,
                    'changed': False,
                    'reason': f"Double-clic sans changement de titre ('{title_after[:40]}')",
                }
            # Titles too short = OCR noise, we cannot conclude
            return {
                'success': True,
                'title_before': title_before,
                'title_after': title_after,
                'changed': False,
                'reason': f"Titre trop court pour vérifier ('{title_after}')",
            }

        # For a single click, the change is optional
        return {
            'success': True,
            'title_before': title_before,
            'title_after': title_after,
            'changed': changed,
            'reason': 'Titre changé' if changed else 'Titre identique (acceptable)',
        }

    _easyocr_reader = None  # Shared singleton

    def _get_ocr(self):
        """Lazy-loads the OCR function (EasyOCR first, docTR fallback)."""
        if self._ocr_fn is not None:
            return self._ocr_fn

        # EasyOCR (fast, good quality on GUIs)
        try:
            import easyocr
            import numpy as np

            if TitleVerifier._easyocr_reader is None:
                TitleVerifier._easyocr_reader = easyocr.Reader(
                    ['fr', 'en'], gpu=True, verbose=False
                )

            def _easyocr_extract_text(img):
                results = TitleVerifier._easyocr_reader.readtext(np.array(img))
                return ' '.join(r[1] for r in results if r[1].strip())

            self._ocr_fn = _easyocr_extract_text
            return self._ocr_fn
        except ImportError:
            pass

        # docTR fallback
        try:
            import sys
            sys.path.insert(0, 'visual_workflow_builder/backend')
            from services.ocr_service import ocr_extract_text
            self._ocr_fn = ocr_extract_text
            return self._ocr_fn
        except ImportError:
            return None
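A typical call site around an action, assuming `shot_before` and `shot_after` are PIL screenshots taken on either side of the click (placeholders here):

```python
from core.grounding.title_verifier import TitleVerifier

verifier = TitleVerifier()
outcome = verifier.verify_action(shot_before, shot_after, action_type="double_click_anchor")
if not outcome["success"]:
    # Non-blocking by design: log and let the caller decide whether to retry
    print(f"⚠️ {outcome['reason']}")
```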
core/grounding/ui_tars_grounder.py (new file, 161 lines)
@@ -0,0 +1,161 @@
"""
core/grounding/ui_tars_grounder.py — Grounding via a one-shot InfiGUI script

Each call spawns a Python subprocess that loads the model, runs inference, and exits.
Slow (~15s) but reliable — no CUDA crashes in a persistent process.
"""

from __future__ import annotations

import json
import os
import subprocess
import sys
import threading
import time
from typing import Optional

from core.grounding.target import GroundingResult

_instance: Optional[UITarsGrounder] = None
_instance_lock = threading.Lock()


class UITarsGrounder:
    """Grounding via a one-shot InfiGUI script."""

    def __init__(self):
        self._lock = threading.Lock()
        self._project_root = os.path.abspath(
            os.path.join(os.path.dirname(__file__), "..", "..")
        )

    @classmethod
    def get_instance(cls) -> UITarsGrounder:
        global _instance
        if _instance is None:
            with _instance_lock:
                if _instance is None:
                    _instance = cls()
        return _instance

    @property
    def available(self) -> bool:
        return True  # Always available — the script starts on demand

    def ground(
        self,
        target_text: str = "",
        target_description: str = "",
        screen_pil=None,
        anchor_pil=None,
    ) -> Optional[GroundingResult]:
        """Locates a UI element through a one-shot InfiGUI script.

        Args:
            target_text: textual name of the target (may be empty if anchor_pil is given).
            target_description: free-form semantic description.
            screen_pil: full screenshot (PIL.Image).
            anchor_pil: visual crop of the previously captured anchor (PIL.Image).
                If provided, the worker switches to fused mode: Image1=crop,
                Image2=screen, "find on image 2 the visual element from image 1".
        """
        t0 = time.time()

        try:
            with self._lock:
                # Save the main image
                image_path = "/tmp/infigui_screen.png"
                if screen_pil is not None:
                    screen_pil.save(image_path)

                # Save the anchor image (fused mode)
                anchor_image_path = ""
                if anchor_pil is not None:
                    anchor_image_path = "/tmp/infigui_anchor.png"
                    anchor_pil.save(anchor_image_path)

                # Build the JSON request
                req = json.dumps({
                    "target": target_text,
                    "description": target_description,
                    "image_path": image_path,
                    "anchor_image_path": anchor_image_path,
                })

                mode_str = "fused" if anchor_pil is not None else "text"
                label_short = target_text[:30] if target_text else "<crop only>"
                print(f"🎯 [InfiGUI] Lancement one-shot [{mode_str}]: '{label_short}'")

                # Launch the one-shot script.
                # IMPORTANT: under a systemd service whose parent already loaded CUDA,
                # the subprocess inherits a broken GPU state ("No CUDA GPUs available").
                # Fixes: start_new_session=True (new session group) + force
                # CUDA_VISIBLE_DEVICES=0 explicitly to bypass the parent's inheritance.
                _child_env = {**os.environ}
                _child_env["PYTHONDONTWRITEBYTECODE"] = "1"
                _child_env["CUDA_VISIBLE_DEVICES"] = "0"
                _child_env["NVIDIA_VISIBLE_DEVICES"] = "all"
                # Drop Python variables that could point at the parent's state
                _child_env.pop("PYTORCH_NVML_BASED_CUDA_CHECK", None)

                result = subprocess.run(
                    [sys.executable, "-m", "core.grounding.infigui_worker"],
                    input=req + "\n",
                    capture_output=True,
                    text=True,
                    timeout=60,
                    cwd=self._project_root,
                    env=_child_env,
                    start_new_session=True,  # new session group, isolated from the parent
                    close_fds=True,
                )

                if result.returncode != 0:
                    stderr_lines = (result.stderr or '').strip().split('\n')
                    # Show the last meaningful stderr lines
                    last_err = [l for l in stderr_lines[-5:] if l.strip()]
                    print(f"⚠️ [InfiGUI] Script échoué (code {result.returncode})")
                    for l in last_err:
                        print(f"   ❌ {l}")
                    return None

                # Parse the output — look for the JSON result line
                data = None
                for line in result.stdout.strip().split("\n"):
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        parsed = json.loads(line)
                        if "x" in parsed:
                            data = parsed
                    except json.JSONDecodeError:
                        continue

                if data is None:
                    print("⚠️ [InfiGUI] Pas de réponse JSON dans la sortie")
                    return None

                dt = (time.time() - t0) * 1000

                if data.get("x") is not None:
                    method_name = "infigui_fused" if anchor_pil is not None else "infigui"
                    print(f"🎯 [InfiGUI/{method_name}] ({data['x']}, {data['y']}) "
                          f"conf={data.get('confidence', 0):.2f} ({dt:.0f}ms)")
                    return GroundingResult(
                        x=data["x"], y=data["y"],
                        method=method_name,
                        confidence=data.get("confidence", 0.90),
                        time_ms=dt,
                    )
                else:
                    print(f"⚠️ [InfiGUI] Pas trouvé ({dt:.0f}ms)")
                    return None

        except subprocess.TimeoutExpired:
            print("⚠️ [InfiGUI] Timeout 60s")
            return None
        except Exception as e:
            print(f"⚠️ [InfiGUI] Erreur: {e}")
            return None
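The stdin/stdout contract the grounder expects from `core/grounding/infigui_worker.py` (the worker itself is not part of this diff), reconstructed from the calls above: a compliant worker reads one JSON request line and prints one JSON result line containing at least `x`.

```python
# Minimal compliant worker skeleton; the actual inference is elided.
import json, sys

req = json.loads(sys.stdin.readline())
# req: {"target", "description", "image_path", "anchor_image_path"}
# ... load InfiGUI, ground req["target"] on req["image_path"] ...
x, y = 640, 360  # placeholder result
print(json.dumps({"x": x, "y": y, "confidence": 0.90}))
```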
@@ -101,6 +101,35 @@ BUILTIN_PATTERNS: List[Dict[str, Any]] = [
        "typical_bbox": [0.35, 0.60, 0.45, 0.68],
        "os": "any",
    },
    {
        "name": "dialog_overwrite",
        "category": "dialog",
        "triggers": [
            "voulez-vous remplacer", "voulez-vous écraser",
            "remplacer le fichier", "replace existing",
            "fichier existe déjà", "already exists",
            "overwrite", "écraser",
        ],
        "action": "click",
        "target": "Oui",
        "alternatives": ["Yes", "Remplacer", "Replace", "Confirmer"],
        "typical_zone": "dialog_center",
        "os": "any",
    },
    {
        "name": "dialog_dont_save",
        "category": "dialog",
        "triggers": [
            "ne pas enregistrer", "don't save",
            "ne pas sauvegarder", "quitter sans enregistrer",
            "discard changes",
        ],
        "action": "click",
        "target": "Ne pas enregistrer",
        "alternatives": ["Don't Save", "Ne pas sauvegarder", "Non"],
        "typical_zone": "dialog_center",
        "os": "any",
    },

    # === WINDOW NAVIGATION ===
    {
docs/CARTOGRAPHY.md (new file, 233 lines)
@@ -0,0 +1,233 @@
|
||||
# Cartographie d'exécution — RPA Vision V3 (Léa)
|
||||
|
||||
> **Date** : 26 avril 2026
|
||||
> **Objectif** : carte complète de ce qui est branché, ce qui ne l'est pas, et comment les données transitent.
|
||||
> **Règle** : LIRE CE DOCUMENT AVANT TOUTE MODIFICATION DE CODE.
|
||||
|
||||
---
|
||||
|
||||
## 1. Point d'entrée : deux chemins disjoints
|
||||
|
||||
```
|
||||
POST /api/v3/execute/start (execute.py:1528)
|
||||
├── execution_mode = "verified" → run_workflow_verified() ← CHEMIN ORA
|
||||
└── execution_mode = "basic"|"intelligent"|"debug" → execute_workflow_thread() ← CHEMIN LEGACY
|
||||
```
|
||||
|
||||
**Il existe DEUX exécuteurs distincts** qui dupliquent le chargement des ancres, la boucle d'étapes, le grounding, la gestion d'erreurs. Ils ne partagent que `input_handler.py`.
|
||||
|
||||
---
|
||||
|
||||
## 2. Chemin LEGACY (modes basic/intelligent/debug)
|
||||
|
||||
```
|
||||
[API] POST /execute/start (mode=intelligent)
|
||||
→ [execute.py:145] execute_workflow_thread()
|
||||
→ [execute.py:160] Charge steps depuis DB
|
||||
→ BOUCLE sur chaque step:
|
||||
│
|
||||
├─ RÉFLEXE PRÉ-ÉTAPE (modes intelligent/debug)
|
||||
│ → [input_handler.py:79] check_screen_for_patterns()
|
||||
│ → UIPatternLibrary.find_pattern(ocr_text) ← BRANCHÉ
|
||||
│ → [input_handler.py:129] handle_detected_pattern()
|
||||
│ → EasyOCR full screen + clic bouton ← BRANCHÉ
|
||||
│
|
||||
├─ CHARGEMENT ANCRE [execute.py:222-256]
|
||||
│ params['visual_anchor'] = {
|
||||
│ screenshot: base64 du crop,
|
||||
│ bounding_box: {x, y, width, height},
|
||||
│ target_text: anchor.target_text, ← PEUT ÊTRE VIDE ("")
|
||||
│ description: anchor.ocr_description ← PEUT ÊTRE VIDE ("")
|
||||
│ }
|
||||
│
|
||||
├─ execute_action(action_type, params) [execute.py:278]
|
||||
│ │
|
||||
│ ├─ ACTION = click_anchor [execute.py:862-1096]
|
||||
│ │ │
|
||||
│ │ ├─ MODE basic: coordonnées statiques (bbox centre)
|
||||
│ │ │
|
||||
│ │ └─ MODE intelligent/debug:
|
||||
│ │ ├─ target_text = anchor.target_text || step.label
|
||||
│ │ │ Si target_text == "click_anchor" et screenshot_base64:
|
||||
│ │ │ → _describe_anchor_image() (VLM qwen2.5vl:3b) ← BRANCHÉ
|
||||
│ │ │
|
||||
│ │ ├─ MÉTHODE 1: Template matching (cv2) ← BRANCHÉ
|
||||
│ │ ├─ MÉTHODE 2: CLIP matching (RF-DETR + CLIP) ← BRANCHÉ
|
||||
│ │ ├─ MÉTHODE 3: OCR → UI-TARS → VLM ← BRANCHÉ
|
||||
│ │ └─ ÉCHEC: self-healing interactif ← BRANCHÉ
|
||||
│ │
|
||||
│ ├─ ACTION = type_text → safe_type_text() ← BRANCHÉ
|
||||
│ ├─ ACTION = wait → sleep + pattern check ← BRANCHÉ
|
||||
│ ├─ ACTION = keyboard_shortcut → pyautogui.hotkey() ← BRANCHÉ
|
||||
│ ├─ ACTION = ai_analyze_text → Ollama ← BRANCHÉ
|
||||
│ ├─ ACTION = extract_text → docTR OCR ← BRANCHÉ
|
||||
│ └─ ACTION = hover/scroll/focus → coords statiques ← PAS DE GROUNDING
|
||||
```
|
||||
|
||||
---

## 3. ORA path ("verified" mode)

```
[API] POST /execute/start (mode=verified)
  → [execute.py:1349] run_workflow_verified()
  → [execute.py:1380-1428] Loads steps + anchors (SAME logic as legacy)
  → [execute.py:1433] ORALoop(verify_level='none', max_retries=2)
  │                           ^^^^^^^^^^^^^^^^^^^
  │                           VERIFICATION HARD-DISABLED
  │
  → [ORA:1478] ora.run_workflow(steps=ora_steps)
  │
  LOOP over each step:
  │
  ├─ [ORA:1258] OBSERVE: screen capture + pHash + window title
  │
  ├─ [ORA:1263] DIALOG REFLEX (if pHash changed > 10)
  │    → DialogHandler.handle_if_dialog(screenshot) ← WIRED
  │    → EasyOCR full screen → known dialog keywords
  │    → InfiGUI worker (/tmp/infigui_*)
  │    → OCR click fallback
  │
  ├─ [ORA:196] REASON: reason_workflow_step()
  │    target_text = anchor.target_text || anchor.description
  │    If empty or an action name → _describe_anchor_image() ← FIXED 26/04
  │    If still empty → label (if not an action name)
  │
  ├─ [ORA:1306] ACT → _act_click()
  │    │
  │    ├─ RPA_USE_FAST_PIPELINE=1 (default)
  │    │    → FastSmartThinkPipeline
  │    │       → FastDetector (RF-DETR 120ms + EasyOCR 192ms) ← WIRED
  │    │       → SmartMatcher (text+type+position+neighbors <1ms) ← WIRED
  │    │       → SignatureStore.lookup() (learning) ← WIRED
  │    │       → Score ≥ 0.90 → direct action ← WIRED
  │    │       → Score 0.60-0.90 → ThinkArbiter
  │    │          → UITarsGrounder → InfiGUI worker (/tmp) ← WIRED
  │    │       → Score < 0.60 → ThinkArbiter alone ← WIRED
  │    │       → FAILURE → _try_fallback()
  │    │          → GroundingPipeline ← NOT WIRED (never connected)
  │    │
  │    ├─ FALLBACK template matching (cv2, >0.75) ← WIRED
  │    ├─ FALLBACK OCR (_grounding_ocr) ← WIRED
  │    └─ LAST RESORT: static coords ← WIRED
  │
  ├─ [ORA:1337] TITLE CHECK (post-action)
  │    → TitleVerifier → EasyOCR 45px crop ← WIRED
  │    *** READS NOTHING IN THE VM (the Windows title lives in the framebuffer) ← PROBLEM
  │
  ├─ [ORA:1358] VERIFY: verify(pre, post, decision)
  │    *** DISABLED (verify_level='none') *** ← NOT WIRED
  │
  └─ [ORA:1362] RECOVERY (5 strategies)
       *** NEVER REACHED *** ← NOT WIRED
       - _recover_element_not_found (wait+scroll+UI-TARS)
       - _recover_overlay_blocking (pattern+Win+D)
       - _recover_wrong_screen (Alt+Tab)
       - _recover_no_effect (double-click+offset)
       - _classify_error (4 types)
```
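
The FAST→SMART→THINK dispatch above reduces to two confidence cut-offs, 0.90 and 0.60. A sketch of just that routing (class and callback names are illustrative; the thresholds are the ones in the diagram):

```python
from dataclasses import dataclass
from typing import Callable, Optional, Tuple

@dataclass
class Candidate:
    center: Tuple[int, int]  # element center found by FastDetector + SmartMatcher
    score: float             # SmartMatcher confidence

def route(candidate: Optional[Candidate],
          arbiter_confirm: Callable[[Candidate], Tuple[int, int]],
          arbiter_ground: Callable[[], Tuple[int, int]]) -> Tuple[int, int]:
    """Score >= 0.90: act directly; 0.60-0.90: ThinkArbiter double-checks the
    candidate; below 0.60 (or no candidate): ThinkArbiter grounds on its own."""
    if candidate and candidate.score >= 0.90:
        return candidate.center
    if candidate and candidate.score >= 0.60:
        return arbiter_confirm(candidate)
    return arbiter_ground()

# A 0.72 candidate goes through the arbiter before being clicked.
print(route(Candidate((640, 360), 0.72), lambda c: c.center, lambda: (0, 0)))
```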

---

## 4. Tracing the `target_text` field

```
CAPTURE (VWB CapturePanel → capture.py:201-263)
  → OCR on the enlarged crop (docTR)
  → VLM qwen2.5vl:3b describes the crop
  → If both fail → target_text = ""
  → No error surfaced to the frontend

STORAGE (DB)
  → VisualAnchor.target_text (nullable) = "" when not filled in

LOADING (execute.py:1400-1428)
  → IF anchor.target_text exists and is non-empty → injected into visual_anchor
  → OTHERWISE → the 'target_text' key DOES NOT EXIST in the dict

LEGACY (execute.py:893-907)
  → target_text = anchor.get('target_text', '')
  → IF empty AND it is an action name → _describe_anchor_image() ← COMPENSATES
  → OTHERWISE → fall back to step_label

ORA (observe_reason_act.py:217) — FIXED ON APRIL 26
  → target_text = anchor.target_text || anchor.description
  → IF empty or an action name → _describe_anchor_image() ← ADDED
  → OTHERWISE → label (if not an action name)
```
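
Condensed, the ORA branch of this trace is a three-step resolution chain. A sketch (the `describe` callable stands in for `_describe_anchor_image`; the action-name set matches `_action_types` in execute.py):

```python
from typing import Callable

ACTION_NAMES = {"click_anchor", "double_click_anchor", "right_click_anchor",
                "hover_anchor", "focus_anchor", "scroll_to_anchor"}

def resolve_target_text(anchor: dict, label: str,
                        describe: Callable[[str], str]) -> str:
    """ORA-side target_text resolution (sketch of the chain traced above)."""
    text = anchor.get("target_text") or anchor.get("description") or ""
    if not text or text in ACTION_NAMES:
        # Empty or garbage (an action name): describe the anchor crop with the VLM.
        text = describe(anchor.get("screenshot", "")) or ""
    if not text and label not in ACTION_NAMES:
        text = label  # last resort, only if the label itself is usable
    return text

# Empty anchor + failing VLM: the step label wins.
print(resolve_target_text({}, "Ouvrir la fiche patient", lambda b64: ""))
```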

---

## 5. Existing functions NOT WIRED

| Function | File | Reason |
|----------|------|--------|
| `verify()` + `_classify_error()` + 5 `_recover_*()` | observe_reason_act.py | verify_level='none' hard-coded |
| `GroundingPipeline` (old) | pipeline.py | set_fallback_pipeline() never called |
| `TemplateMatcher` (centralized class) | template_matcher.py | Only used by the dead GroundingPipeline |
| `ShadowLearningHook` | shadow_learning_hook.py | Never imported in any flow |
| `CognitiveContext` | working_memory.py | Instruction mode only |
| `VLM pre-check` | observe_reason_act.py | `if False:` hard-coded |
| hover/focus grounding | execute.py | Static coords only |
| `grounding/server.py` (FastAPI :8200) | server.py | CUDA crash, replaced by the file-based worker |

---

## 6. The 12 grounding systems

| # | System | File | Wired? |
|---|--------|------|--------|
| 1 | Inline template matching (legacy) | execute.py:914 | ✅ Legacy |
| 2 | Inline template matching (ORA) | ORA:1475 | ✅ ORA fallback |
| 3 | CLIP matching (IntelligentExecutor) | intelligent_executor.py | ✅ Legacy |
| 4 | docTR OCR (_grounding_ocr) | input_handler.py:430 | ✅ Legacy + ORA |
| 5 | UI-TARS Ollama (_grounding_ui_tars) | input_handler.py:513 | ✅ Legacy |
| 6 | VLM reasoning (_grounding_vlm) | input_handler.py:627 | ✅ Legacy only |
| 7 | FastDetector (RF-DETR + EasyOCR) | fast_detector.py | ✅ ORA |
| 8 | SmartMatcher | smart_matcher.py | ✅ ORA |
| 9 | ThinkArbiter → InfiGUI worker | think_arbiter.py + ui_tars_grounder.py | ✅ ORA |
| 10 | DialogHandler → InfiGUI | dialog_handler.py | ✅ ORA reflex |
| 11 | GroundingPipeline (old) | pipeline.py | ❌ Never connected |
| 12 | TemplateMatcher class | template_matcher.py | ❌ Via the dead GroundingPipeline |

---

## 7. Dialog handling (2 parallel systems)

| # | System | Pattern base | OCR | Click | Used by |
|---|--------|--------------|-----|-------|---------|
| 1 | UIPatternLibrary + handle_detected_pattern | 28 builtin patterns | docTR/EasyOCR | OCR button lookup | Legacy |
| 2 | DialogHandler + KNOWN_DIALOGS | 15 known titles | EasyOCR full screen | InfiGUI | ORA |

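Despite the different pattern bases, both systems share one reflex shape: OCR the full screen, look for a known dialog marker, then click a safe button located by OCR (or hand the click to InfiGUI on the ORA side). A minimal sketch of that shared shape (the `KNOWN_DIALOGS` entries and the `click` helper are illustrative, not the actual 28/15-entry bases):

```python
KNOWN_DIALOGS = {
    # dialog marker -> safe button to click (illustrative entries)
    "enregistrer les modifications": "Annuler",
    "mise à jour disponible": "Plus tard",
}

def click(x: int, y: int) -> None:
    print(f"click at ({x}, {y})")  # stand-in for the real mouse driver

def handle_if_dialog(ocr_words: list) -> bool:
    """ocr_words: [{'text': str, 'x': int, 'y': int}, ...] from a full-screen OCR pass."""
    screen_text = " ".join(w["text"].lower() for w in ocr_words)
    for marker, button in KNOWN_DIALOGS.items():
        if marker in screen_text:
            for w in ocr_words:  # OCR button lookup; ORA delegates this click to InfiGUI
                if w["text"].lower() == button.lower():
                    click(w["x"], w["y"])
                    return True
    return False
```
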
---

## 8. VRAM budget (current configuration)

| Component | VRAM | Process |
|-----------|------|---------|
| InfiGUI-G1-3B (NF4) | 2.41 GB | Standalone worker (/tmp) |
| RF-DETR Medium | 0.8 GB | Flask process |
| EasyOCR | ~1 GB (GPU) | Flask process |
| Ollama qwen2.5vl:3b (when called) | ~3.2 GB | Ollama process |
| Chrome + system | ~1.3 GB | — |
| **Total max** | **~8.7 GB / 12 GB** | |

---

## 9. Critical files, by order of importance

1. `core/execution/observe_reason_act.py` — ORA loop, _act_click, reason, verify
2. `visual_workflow_builder/backend/api_v3/execute.py` — API, anchor loading, legacy executor
3. `core/grounding/fast_pipeline.py` — FAST→SMART→THINK pipeline
4. `core/grounding/ui_tars_grounder.py` — InfiGUI worker client
5. `core/grounding/infigui_worker.py` — InfiGUI worker (standalone process)
6. `core/execution/input_handler.py` — OCR, UI-TARS Ollama, safe_type_text, patterns
7. `core/grounding/dialog_handler.py` — ORA dialog handling
8. `core/grounding/fast_detector.py` — RF-DETR + EasyOCR
9. `core/grounding/smart_matcher.py` — contextual matching
10. `core/knowledge/ui_patterns.py` — reflex patterns

---

> **Last updated**: April 26, 2026
> **Next action**: re-wire verify + recovery, converge the 2 executors, clean up the dead code.

129 tests/integration/test_chat_window_templates.py Normal file
@@ -0,0 +1,129 @@
"""Tests for the 'Léa is executing' bubble templates (J3.4).

We test the _tpl_* and _extract_meta functions of chat_window.py — they are
purely functional (input payload → output tuple), no tkinter UI required.
"""

import pytest

from agent_v0.agent_v1.ui import chat_window as cw


# ----------------------------------------------------------------------
# _tpl_* templates
# ----------------------------------------------------------------------

def test_tpl_action_started_uses_workflow_name():
    icon, color, title = cw._tpl_action_started({"workflow": "Demo Urgences UHCD"})
    assert icon == "▶"
    assert color == cw.ACTION_ICON_RUN
    assert "Demo Urgences UHCD" in title


def test_tpl_action_started_fallback_when_no_workflow():
    _, _, title = cw._tpl_action_started({})
    assert "?" in title


def test_tpl_action_progress_uses_step_when_provided():
    _, _, title = cw._tpl_action_progress({"step": "J'ouvre la fiche patient"})
    assert title == "J'ouvre la fiche patient"


def test_tpl_action_progress_fallback_to_counter():
    _, _, title = cw._tpl_action_progress({"current": 4, "total": 7})
    assert "4/7" in title


def test_tpl_done_success():
    icon, color, title = cw._tpl_done({"success": True, "message": "Codage terminé"})
    assert icon == "✓"
    assert color == cw.ACTION_ICON_OK
    assert title == "Codage terminé"


def test_tpl_done_failure():
    icon, color, title = cw._tpl_done({"success": False, "message": "Action échouée"})
    assert icon == "✗"
    assert color == cw.ACTION_ICON_ERR
    assert title == "Action échouée"


def test_tpl_done_default_success_when_unspecified():
    icon, _, _ = cw._tpl_done({})
    assert icon == "✓"  # defaults to success when not specified


def test_tpl_need_confirm_extracts_action_description():
    icon, _, title = cw._tpl_need_confirm({
        "action": {"description": "Cliquer sur l'IPP 25003284"}
    })
    assert icon == "?"
    assert "25003284" in title


def test_tpl_need_confirm_fallback():
    _, _, title = cw._tpl_need_confirm({})
    assert "Validation" in title


def test_tpl_step_result_ok():
    icon, color, _ = cw._tpl_step_result({"status": "ok", "message": "ok"})
    assert icon == "✓"
    assert color == cw.ACTION_ICON_OK


def test_tpl_step_result_failed():
    icon, color, _ = cw._tpl_step_result({"status": "failed", "message": "boom"})
    assert icon == "✗"
    assert color == cw.ACTION_ICON_ERR


def test_tpl_step_result_neutral_status():
    icon, color, _ = cw._tpl_step_result({"status": "skipped", "message": "passé"})
    assert icon == "·"
    assert color == cw.ACTION_ICON_INFO


def test_tpl_resumed():
    icon, color, title = cw._tpl_resumed({})
    assert icon == "→"
    assert color == cw.ACTION_ICON_OK
    assert "Reprise" in title


# ----------------------------------------------------------------------
# Dispatch — every lea:* event (except paused/acks) must have a template
# ----------------------------------------------------------------------

def test_all_relevant_events_have_a_template():
    expected = {
        "lea:action_started", "lea:action_progress", "lea:done",
        "lea:need_confirm", "lea:step_result", "lea:resumed",
    }
    assert set(cw._ACTION_TEMPLATES.keys()) == expected


# ----------------------------------------------------------------------
# _extract_meta
# ----------------------------------------------------------------------

def test_extract_meta_with_workflow():
    meta = cw._extract_meta({"workflow": "Demo Urgences"})
    assert meta == "Demo Urgences"


def test_extract_meta_with_progress():
    meta = cw._extract_meta({"workflow": "Demo Urgences", "current": 4, "total": 7})
    assert "Demo Urgences" in meta
    assert "étape 4/7" in meta


def test_extract_meta_with_replay_id_truncated():
    meta = cw._extract_meta({"replay_id": "rep_abcdef0123456789"})
    assert "#789" in meta or "456789" in meta  # last 6 characters


def test_extract_meta_empty_payload():
    assert cw._extract_meta({}) == ""

164 tests/integration/test_feedback_bus.py Normal file
@@ -0,0 +1,164 @@
"""Tests for the Léa feedback bus (lea:* events via Flask-SocketIO).

Covers J2.5 and J2.6:
- Flag LEA_FEEDBACK_BUS=0 → _emit_lea is a no-op, _emit_dual only propagates the legacy event
- Flag LEA_FEEDBACK_BUS=1 → _emit_lea propagates 'lea:{event}', _emit_dual propagates both

Approach: we intercept socketio.emit with monkeypatch (more reliable than the
Flask-SocketIO test_client, which does not always catch broadcasts emitted
outside a request context).
"""

import importlib

import pytest


def _reload_app(monkeypatch, flag_value: str):
    monkeypatch.setenv("LEA_FEEDBACK_BUS", flag_value)
    import agent_chat.app as app_mod
    importlib.reload(app_mod)
    return app_mod


def _capture_emits(monkeypatch, app_mod):
    calls = []
    monkeypatch.setattr(
        app_mod.socketio, "emit",
        lambda event, payload=None, **kwargs: calls.append((event, payload, kwargs)),
    )
    return calls


@pytest.fixture
def app_off(monkeypatch):
    return _reload_app(monkeypatch, "0")


@pytest.fixture
def app_on(monkeypatch):
    return _reload_app(monkeypatch, "1")


def test_flag_off_by_default(monkeypatch):
    monkeypatch.delenv("LEA_FEEDBACK_BUS", raising=False)
    import agent_chat.app as app_mod
    importlib.reload(app_mod)
    assert app_mod.LEA_FEEDBACK_BUS is False


def test_flag_accepts_truthy_values(monkeypatch):
    for truthy in ["1", "true", "True", "yes", "on", "TRUE"]:
        monkeypatch.setenv("LEA_FEEDBACK_BUS", truthy)
        import agent_chat.app as app_mod
        importlib.reload(app_mod)
        assert app_mod.LEA_FEEDBACK_BUS is True, f"{truthy!r} should enable the flag"


def test_emit_lea_noop_when_flag_off(app_off, monkeypatch):
    calls = _capture_emits(monkeypatch, app_off)
    app_off._emit_lea("paused", {"workflow": "demo", "reason": "test"})
    assert calls == []


def test_emit_lea_emits_when_flag_on(app_on, monkeypatch):
    calls = _capture_emits(monkeypatch, app_on)
    app_on._emit_lea("paused", {"workflow": "demo", "reason": "test"})
    assert len(calls) == 1
    event, payload, _ = calls[0]
    assert event == "lea:paused"
    assert payload == {"workflow": "demo", "reason": "test"}


def test_emit_dual_emits_only_legacy_when_flag_off(app_off, monkeypatch):
    calls = _capture_emits(monkeypatch, app_off)
    app_off._emit_dual("execution_started", "action_started", {"workflow": "demo"})
    assert len(calls) == 1
    assert calls[0][0] == "execution_started"


def test_emit_dual_emits_both_when_flag_on(app_on, monkeypatch):
    calls = _capture_emits(monkeypatch, app_on)
    payload = {"workflow": "demo", "params": {"k": "v"}}
    app_on._emit_dual("execution_started", "action_started", payload)
    events = [c[0] for c in calls]
    assert "execution_started" in events
    assert "lea:action_started" in events
    assert len(calls) == 2


def test_emit_dual_preserves_kwargs(app_on, monkeypatch):
    """broadcast=True and other Flask-SocketIO kwargs must be propagated to the legacy event."""
    calls = _capture_emits(monkeypatch, app_on)
    app_on._emit_dual("execution_cancelled", "cancelled", {}, broadcast=True)
    legacy_call = next(c for c in calls if c[0] == "execution_cancelled")
    assert legacy_call[2].get("broadcast") is True


def test_emit_lea_silenced_on_socketio_error(app_on, monkeypatch):
    """An exception in socketio.emit must never propagate."""
    def boom(*args, **kwargs):
        raise RuntimeError("socketio fail")
    monkeypatch.setattr(app_on.socketio, "emit", boom)
    app_on._emit_lea("paused", {"x": 1})


# ----------------------------------------------------------------------
# J3.5 — SocketIO handlers triggered from ChatWindow
# ----------------------------------------------------------------------

class _FakeResponse:
    def __init__(self, ok=True, status_code=200, text=""):
        self.ok = ok
        self.status_code = status_code
        self.text = text


def test_replay_resume_handler_relays_post_to_streaming(app_on, monkeypatch):
    """The 'lea:replay_resume' handler must POST to /replay/{id}/resume on the streaming API."""
    captured = {}

    def fake_post(url, headers=None, **kwargs):
        captured["url"] = url
        captured["headers"] = headers
        return _FakeResponse(ok=True, status_code=200)

    monkeypatch.setattr(app_on.http_requests, "post", fake_post)
    emit_calls = _capture_emits(monkeypatch, app_on)

    app_on.handle_lea_replay_resume({"replay_id": "rep_abc123"})

    assert "rep_abc123" in captured["url"]
    assert captured["url"].endswith("/api/v1/traces/stream/replay/rep_abc123/resume")
    # The bus must propagate an ack
    acked = [c for c in emit_calls if c[0] == "lea:resume_acked"]
    assert len(acked) == 1
    assert acked[0][1]["status"] == "ok"


def test_replay_resume_handler_emits_error_on_http_failure(app_on, monkeypatch):
    monkeypatch.setattr(
        app_on.http_requests, "post",
        lambda *a, **k: _FakeResponse(ok=False, status_code=500, text="boom"),
    )
    emit_calls = _capture_emits(monkeypatch, app_on)
    app_on.handle_lea_replay_resume({"replay_id": "rep_x"})
    acked = [c for c in emit_calls if c[0] == "lea:resume_acked"]
    assert acked[0][1]["status"] == "error"
    assert acked[0][1]["http_status"] == 500


def test_replay_resume_handler_emits_error_on_no_replay_id(app_on, monkeypatch):
    emit_calls = _capture_emits(monkeypatch, app_on)
    app_on.handle_lea_replay_resume({})
    acked = [c for c in emit_calls if c[0] == "lea:resume_acked"]
    assert acked[0][1]["status"] == "error"
    assert "replay_id manquant" in acked[0][1]["detail"]


def test_replay_abort_handler_stops_local_execution(app_on, monkeypatch):
    app_on.execution_status["running"] = True
    emit_calls = _capture_emits(monkeypatch, app_on)
    app_on.handle_lea_replay_abort({"replay_id": "rep_y"})
    assert app_on.execution_status["running"] is False
    acked = [c for c in emit_calls if c[0] == "lea:abort_acked"]
    assert acked[0][1]["status"] == "ok"

164 tests/integration/test_feedback_bus_client.py Normal file
@@ -0,0 +1,164 @@
"""Tests for FeedbackBusClient (J3.2).

python-socketio is mocked so no real network connection is opened.
The real E2E test (an actual connection to the bus on 5004) is deferred to J4.3.
"""

from unittest.mock import patch

import pytest

from agent_v0.agent_v1.network.feedback_bus import FeedbackBusClient, LEA_EVENTS


def test_init_creates_socketio_client():
    bus = FeedbackBusClient("http://localhost:5004")
    assert bus._sio is not None
    assert bus.connected is False


def test_init_strips_trailing_slash():
    bus = FeedbackBusClient("http://localhost:5004/")
    assert bus._url == "http://localhost:5004"


def test_lea_events_registered():
    bus = FeedbackBusClient("http://localhost:5004")
    handlers = bus._sio.handlers.get('/', {})
    for ev in LEA_EVENTS:
        assert ev in handlers, f"Handler {ev!r} not registered on the client"


def test_dispatch_calls_callback():
    received = []
    bus = FeedbackBusClient(
        "http://localhost:5004",
        on_event=lambda e, p: received.append((e, p)),
    )
    bus._dispatch('lea:paused', {'workflow': 'demo', 'reason': 'incertain'})
    assert received == [('lea:paused', {'workflow': 'demo', 'reason': 'incertain'})]


def test_dispatch_handles_none_payload():
    received = []
    bus = FeedbackBusClient(
        "http://localhost:5004",
        on_event=lambda e, p: received.append((e, p)),
    )
    bus._dispatch('lea:done', None)
    assert received == [('lea:done', {})]


def test_dispatch_silenced_on_callback_error():
    """An exception in the consumer callback must never propagate."""
    def boom(event, payload):
        raise RuntimeError("callback fail")
    bus = FeedbackBusClient("http://localhost:5004", on_event=boom)
    bus._dispatch('lea:paused', {})  # must not raise


def test_default_callback_is_silent():
    """Without a callback provided, dispatch must not break."""
    bus = FeedbackBusClient("http://localhost:5004")
    bus._dispatch('lea:paused', {'x': 1})  # must not raise


def test_token_in_authorization_header():
    bus = FeedbackBusClient("http://localhost:5004", token="abc123")
    captured = {}

    def fake_connect(url, headers=None, **kwargs):
        captured['headers'] = headers
        raise RuntimeError("stop here")

    with patch.object(bus._sio, 'connect', side_effect=fake_connect):
        bus._run()

    assert captured['headers']['Authorization'] == 'Bearer abc123'


def test_no_token_means_no_auth_header():
    bus = FeedbackBusClient("http://localhost:5004")
    captured = {}

    def fake_connect(url, headers=None, **kwargs):
        captured['headers'] = headers
        raise RuntimeError("stop here")

    with patch.object(bus._sio, 'connect', side_effect=fake_connect):
        bus._run()

    assert 'Authorization' not in captured['headers']


def test_run_silenced_on_connect_error():
    """A connect() that raises must not crash the thread."""
    bus = FeedbackBusClient("http://localhost:5004")
    with patch.object(bus._sio, 'connect', side_effect=ConnectionError("boom")):
        bus._run()  # must not raise


def test_start_is_idempotent():
    """A second start() while the thread is running must not spawn another one."""
    import threading
    bus = FeedbackBusClient("http://localhost:5004")
    block = threading.Event()
    with patch.object(bus, '_run', side_effect=lambda: block.wait(timeout=2)):
        bus.start()
        first_thread = bus._thread
        bus.start()
        second_thread = bus._thread
        block.set()
    assert first_thread is second_thread, "start() must be idempotent while a thread is running"


def test_stop_when_not_connected_is_silent():
    bus = FeedbackBusClient("http://localhost:5004")
    bus.stop()  # must not raise even if never connected


def test_stop_silenced_on_disconnect_error():
    bus = FeedbackBusClient("http://localhost:5004")
    # Force connected=True on the instance and make disconnect() raise
    with patch.object(bus._sio, 'disconnect', side_effect=RuntimeError("boom")):
        bus._sio.connected = True
        bus.stop()  # must not raise


# ----------------------------------------------------------------------
# J3.5 — User actions (resume_replay / abort_replay)
# ----------------------------------------------------------------------

def test_resume_replay_emits_when_connected():
    bus = FeedbackBusClient("http://localhost:5004")
    bus._sio.connected = True
    with patch.object(bus._sio, 'emit') as mock_emit:
        ok = bus.resume_replay("rep_abc")
        assert ok is True
        mock_emit.assert_called_once_with("lea:replay_resume", {"replay_id": "rep_abc"})


def test_resume_replay_returns_false_when_disconnected():
    bus = FeedbackBusClient("http://localhost:5004")
    # _sio.connected stays False by default
    with patch.object(bus._sio, 'emit') as mock_emit:
        ok = bus.resume_replay("rep_abc")
        assert ok is False
        mock_emit.assert_not_called()


def test_abort_replay_emits_when_connected():
    bus = FeedbackBusClient("http://localhost:5004")
    bus._sio.connected = True
    with patch.object(bus._sio, 'emit') as mock_emit:
        ok = bus.abort_replay("rep_xyz")
        assert ok is True
        mock_emit.assert_called_once_with("lea:replay_abort", {"replay_id": "rep_xyz"})


def test_safe_emit_silenced_on_error():
    bus = FeedbackBusClient("http://localhost:5004")
    bus._sio.connected = True
    with patch.object(bus._sio, 'emit', side_effect=RuntimeError("boom")):
        ok = bus.resume_replay("rep_abc")
    assert ok is False  # error swallowed silently
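
The client class itself is not part of this compare. Reconstructed from the assertions above, its contract looks roughly like the sketch below (only the attributes and methods the tests touch are grounded; the `LEA_EVENTS` contents and all other details are assumptions, not the shipped implementation):

```python
import threading
from typing import Callable, Optional

import socketio  # python-socketio

# Assumed event list — the tests only require that every entry gets a handler.
LEA_EVENTS = ["lea:action_started", "lea:action_progress", "lea:done",
              "lea:need_confirm", "lea:step_result", "lea:resumed", "lea:paused"]

class FeedbackBusClient:
    def __init__(self, url: str, token: Optional[str] = None,
                 on_event: Optional[Callable[[str, dict], None]] = None):
        self._url = url.rstrip("/")
        self._token = token
        self._on_event = on_event or (lambda event, payload: None)
        self._thread: Optional[threading.Thread] = None
        self._sio = socketio.Client()
        for ev in LEA_EVENTS:
            self._sio.on(ev, handler=lambda payload=None, _ev=ev: self._dispatch(_ev, payload))

    @property
    def connected(self) -> bool:
        return self._sio.connected

    def _dispatch(self, event: str, payload) -> None:
        try:
            self._on_event(event, payload or {})
        except Exception:
            pass  # a consumer error must never kill the bus thread

    def _run(self) -> None:
        headers = {"Authorization": f"Bearer {self._token}"} if self._token else {}
        try:
            self._sio.connect(self._url, headers=headers)
            self._sio.wait()
        except Exception:
            pass  # fail-safe: execution continues without the bus

    def start(self) -> None:
        if self._thread and self._thread.is_alive():
            return  # idempotent while a thread is running
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        try:
            if self._sio.connected:
                self._sio.disconnect()
        except Exception:
            pass

    def _safe_emit(self, event: str, payload: dict) -> bool:
        if not self._sio.connected:
            return False
        try:
            self._sio.emit(event, payload)
            return True
        except Exception:
            return False

    def resume_replay(self, replay_id: str) -> bool:
        return self._safe_emit("lea:replay_resume", {"replay_id": replay_id})

    def abort_replay(self, replay_id: str) -> bool:
        return self._safe_emit("lea:replay_abort", {"replay_id": replay_id})
```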

311 tests/unit/test_template_matcher.py Normal file
@@ -0,0 +1,311 @@
"""Tests for core/grounding/template_matcher.py"""

import base64
import io
import time
from unittest.mock import MagicMock, patch

import cv2
import numpy as np
import pytest
from PIL import Image

from core.grounding.template_matcher import MatchResult, TemplateMatcher


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_image(w: int, h: int, color: tuple = (128, 128, 128)) -> Image.Image:
    """Create a solid-color PIL image."""
    img = Image.new('RGB', (w, h), color)
    return img


def _pil_to_b64(img: Image.Image) -> str:
    """Encode a PIL image as base64 PNG."""
    buf = io.BytesIO()
    img.save(buf, format='PNG')
    return base64.b64encode(buf.getvalue()).decode()


def _make_screen_with_target(
    screen_w: int = 800,
    screen_h: int = 600,
    target_x: int = 300,
    target_y: int = 200,
    target_w: int = 60,
    target_h: int = 40,
):
    """Create a noisy screen with a unique pattern and the matching anchor.

    The screen has a random-noise background so template matching can only
    match at the exact spot where the pattern was injected.
    """
    rng = np.random.RandomState(42)
    # Noisy background — every pixel differs, no false match possible
    screen = rng.randint(0, 256, (screen_h, screen_w, 3), dtype=np.uint8)

    # Inject a unique deterministic pattern (red/blue checkerboard)
    target = np.zeros((target_h, target_w, 3), dtype=np.uint8)
    for r in range(target_h):
        for c in range(target_w):
            if (r + c) % 2 == 0:
                target[r, c] = [255, 0, 0]  # red
            else:
                target[r, c] = [0, 0, 255]  # blue
    screen[target_y:target_y + target_h, target_x:target_x + target_w] = target
    screen_pil = Image.fromarray(screen)

    # The anchor is exactly the same pattern
    anchor_pil = Image.fromarray(target)

    expected_cx = target_x + target_w // 2
    expected_cy = target_y + target_h // 2

    return screen_pil, anchor_pil, expected_cx, expected_cy


# ---------------------------------------------------------------------------
# MatchResult tests
# ---------------------------------------------------------------------------

class TestMatchResult:
    def test_fields(self):
        r = MatchResult(x=100, y=200, score=0.85, method='template', time_ms=5.0)
        assert r.x == 100
        assert r.y == 200
        assert r.score == 0.85
        assert r.method == 'template'
        assert r.time_ms == 5.0
        assert r.scale == 1.0  # default

    def test_with_scale(self):
        r = MatchResult(x=10, y=20, score=0.9, method='template_multiscale', time_ms=12.0, scale=0.95)
        assert r.scale == 0.95


# ---------------------------------------------------------------------------
# TemplateMatcher tests — init
# ---------------------------------------------------------------------------

class TestTemplateMatcherInit:
    def test_defaults(self):
        m = TemplateMatcher()
        assert m.threshold == 0.75
        assert m.multiscale is False
        assert m.grayscale is False

    def test_custom_params(self):
        m = TemplateMatcher(threshold=0.5, multiscale=True, grayscale=True, scales=[1.0, 0.8])
        assert m.threshold == 0.5
        assert m.multiscale is True
        assert m.grayscale is True
        assert m.scales == [1.0, 0.8]


# ---------------------------------------------------------------------------
# TemplateMatcher tests — _decode_anchor
# ---------------------------------------------------------------------------

class TestDecodeAnchor:
    def test_pil_passthrough(self):
        img = _make_image(50, 50)
        result = TemplateMatcher._decode_anchor(None, img)
        assert result is img

    def test_b64_decode(self):
        img = _make_image(50, 50, (255, 0, 0))
        b64 = _pil_to_b64(img)
        result = TemplateMatcher._decode_anchor(b64, None)
        assert result is not None
        assert result.size == (50, 50)

    def test_b64_with_data_prefix(self):
        img = _make_image(30, 30)
        b64 = "data:image/png;base64," + _pil_to_b64(img)
        result = TemplateMatcher._decode_anchor(b64, None)
        assert result is not None

    def test_none_inputs(self):
        result = TemplateMatcher._decode_anchor(None, None)
        assert result is None

    def test_invalid_b64(self):
        result = TemplateMatcher._decode_anchor("not-valid-base64!!!", None)
        assert result is None


# ---------------------------------------------------------------------------
# TemplateMatcher tests — match_screen with a provided screen_pil
# ---------------------------------------------------------------------------

class TestMatchScreenWithPIL:
    def test_exact_match(self):
        screen, anchor, cx, cy = _make_screen_with_target()
        m = TemplateMatcher(threshold=0.75)
        result = m.match_screen(anchor_pil=anchor, screen_pil=screen)
        assert result is not None
        assert abs(result.x - cx) <= 1
        assert abs(result.y - cy) <= 1
        assert result.score > 0.9
        assert result.method == 'template'
        assert result.time_ms >= 0

    def test_no_match(self):
        # Noisy screen, anchor = unique checkerboard absent from the screen
        rng = np.random.RandomState(123)
        screen_np = rng.randint(0, 256, (600, 800, 3), dtype=np.uint8)
        screen = Image.fromarray(screen_np)

        # Anchor = regular checkerboard not present in the noise
        anchor_np = np.zeros((40, 60, 3), dtype=np.uint8)
        for r in range(40):
            for c in range(60):
                anchor_np[r, c] = [255, 255, 0] if (r + c) % 2 == 0 else [0, 255, 255]
        anchor = Image.fromarray(anchor_np)

        m = TemplateMatcher(threshold=0.75)
        result = m.match_screen(anchor_pil=anchor, screen_pil=screen)
        assert result is None

    def test_b64_anchor(self):
        screen, anchor, cx, cy = _make_screen_with_target()
        b64 = _pil_to_b64(anchor)
        m = TemplateMatcher(threshold=0.75)
        result = m.match_screen(anchor_b64=b64, screen_pil=screen)
        assert result is not None
        assert abs(result.x - cx) <= 1

    def test_anchor_bigger_than_screen(self):
        screen = _make_image(100, 100)
        anchor = _make_image(200, 200)
        m = TemplateMatcher()
        result = m.match_screen(anchor_pil=anchor, screen_pil=screen)
        assert result is None

    def test_threshold_configurable(self):
        screen, anchor, cx, cy = _make_screen_with_target()
        # With a 0.999 threshold, the exact match should still pass (score=1.0)
        m = TemplateMatcher(threshold=0.999)
        result = m.match_screen(anchor_pil=anchor, screen_pil=screen)
        # A pixel-perfect match can score 1.0 or just below it,
        # so both outcomes are accepted
        if result:
            assert result.score >= 0.999


# ---------------------------------------------------------------------------
# TemplateMatcher tests — multi-scale
# ---------------------------------------------------------------------------

class TestMultiscale:
    def test_multiscale_exact(self):
        screen, anchor, cx, cy = _make_screen_with_target()
        m = TemplateMatcher(threshold=0.75, multiscale=True)
        result = m.match_screen(anchor_pil=anchor, screen_pil=screen)
        assert result is not None
        assert abs(result.x - cx) <= 2
        assert abs(result.y - cy) <= 2
        assert result.score > 0.9

    def test_multiscale_scaled_anchor(self):
        """The anchor was captured at a slightly different scale.

        A bigger pattern (solid color block) is used here so the resize does
        not destroy the pattern, as it would with a fine checkerboard.
        """
        # Noisy screen + big red block
        rng = np.random.RandomState(42)
        screen_np = rng.randint(50, 200, (600, 800, 3), dtype=np.uint8)
        target = np.zeros((80, 120, 3), dtype=np.uint8)
        target[:, :] = [220, 30, 30]  # unique bright red
        # Add a green border to make it even more unique
        target[:5, :] = [30, 220, 30]
        target[-5:, :] = [30, 220, 30]
        screen_np[200:280, 300:420] = target
        screen = Image.fromarray(screen_np)

        # The original anchor
        anchor_original = Image.fromarray(target)
        # The anchor at 105% (a modest scale so it stays realistic)
        w, h = anchor_original.size
        scaled_anchor = anchor_original.resize((int(w * 1.05), int(h * 1.05)), Image.BILINEAR)

        m_multi = TemplateMatcher(threshold=0.60, multiscale=True)
        result_multi = m_multi.match_screen(anchor_pil=scaled_anchor, screen_pil=screen)
        assert result_multi is not None
        assert result_multi.method == 'template_multiscale'

    def test_multiscale_anchor_too_small(self):
        """Very small anchor — some scales are skipped."""
        screen = _make_image(800, 600)
        anchor = _make_image(5, 5, (255, 0, 0))
        m = TemplateMatcher(threshold=0.99, multiscale=True, scales=[0.5, 0.3])
        result = m.match_screen(anchor_pil=anchor, screen_pil=screen)
        # No crash even with scales that produce < 8px templates;
        # the result may be None or a match depending on content


# ---------------------------------------------------------------------------
# TemplateMatcher tests — match_in_region
# ---------------------------------------------------------------------------

class TestMatchInRegion:
    def test_region_match(self):
        # Create a noisy BGR region with an injected checkerboard pattern
        rng = np.random.RandomState(77)
        region = rng.randint(0, 256, (200, 300, 3), dtype=np.uint8)
        # Checkerboard pattern in BGR
        anchor = np.zeros((40, 60, 3), dtype=np.uint8)
        for r in range(40):
            for c in range(60):
                if (r + c) % 2 == 0:
                    anchor[r, c] = [255, 0, 0]
                else:
                    anchor[r, c] = [0, 0, 255]
        region[50:90, 100:160] = anchor

        m = TemplateMatcher(threshold=0.75)
        result = m.match_in_region(region, anchor)
        assert result is not None
        assert abs(result.x - 130) <= 1  # 100 + 60//2
        assert abs(result.y - 70) <= 1   # 50 + 40//2

    def test_region_no_match(self):
        # Noisy region, checkerboard anchor absent
        rng = np.random.RandomState(88)
        region = rng.randint(0, 256, (200, 300, 3), dtype=np.uint8)
        anchor = np.zeros((40, 60, 3), dtype=np.uint8)
        for r in range(40):
            for c in range(60):
                anchor[r, c] = [255, 255, 0] if (r + c) % 2 == 0 else [0, 255, 255]

        m = TemplateMatcher(threshold=0.75)
        result = m.match_in_region(region, anchor)
        assert result is None


# ---------------------------------------------------------------------------
# Grayscale mode tests
# ---------------------------------------------------------------------------

class TestGrayscale:
    def test_grayscale_match(self):
        screen, anchor, cx, cy = _make_screen_with_target()
        m = TemplateMatcher(threshold=0.75, grayscale=True)
        result = m.match_screen(anchor_pil=anchor, screen_pil=screen)
        assert result is not None
        assert abs(result.x - cx) <= 1


# ---------------------------------------------------------------------------
# _capture_screen tests (mocked)
# ---------------------------------------------------------------------------

class TestCaptureScreen:
    @patch('core.grounding.template_matcher._MSS', False)
    def test_no_mss(self):
        result = TemplateMatcher._capture_screen()
        assert result is None
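
For orientation, the API these tests pin down is driven like this (a usage sketch; the image paths are placeholders):

```python
from PIL import Image

from core.grounding.template_matcher import TemplateMatcher

# Placeholder files — substitute a real capture and an anchor crop.
screen = Image.open("screenshot.png")
anchor = Image.open("anchor_crop.png")

matcher = TemplateMatcher(threshold=0.75, multiscale=True)
result = matcher.match_screen(anchor_pil=anchor, screen_pil=screen)
if result:
    print(f"{result.method}: ({result.x}, {result.y}) "
          f"score={result.score:.2f} scale={result.scale} in {result.time_ms:.0f} ms")
else:
    print("No match above threshold — fall through to the next grounding method.")
```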

218 tools/benchmark_grounding.py Normal file
@@ -0,0 +1,218 @@
#!/usr/bin/env python3
"""
Full benchmark of the visual grounding methods.
Run with the Windows VM visible on screen and the Demo folder on the desktop.

Usage:
    cd ~/ai/rpa_vision_v3
    .venv/bin/python3 tools/benchmark_grounding.py
"""
import base64
import glob
import io
import json
import os
import re
import time

import cv2
import mss
import numpy as np
import requests
from PIL import Image

OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://localhost:11434")
ANCHOR_DIR = 'visual_workflow_builder/backend/data/anchors'


def capture_screen():
    with mss.mss() as sct:
        grab = sct.grab(sct.monitors[0])
        screen = Image.frombytes('RGB', grab.size, grab.rgb)
        return screen


def screen_to_b64(screen):
    buf = io.BytesIO()
    screen.save(buf, format='JPEG', quality=70)
    return base64.b64encode(buf.getvalue()).decode()


def parse_coords(text, screen_w, screen_h):
    for pat in [
        r"start_box='?\<?\|?box_start\|?\>?\((\d+),(\d+)\)",
        r'\((\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\)',
        r'\[(\d+(?:\.\d+)?)\s*,\s*(\d+(?:\.\d+)?)\]',
    ]:
        m = re.search(pat, text)
        if m:
            rx, ry = float(m.group(1)), float(m.group(2))
            if rx <= 1.0 and ry <= 1.0:
                return int(rx * screen_w), int(ry * screen_h)
            elif rx <= 1000 and ry <= 1000:
                return int(rx * screen_w / 1000), int(ry * screen_h / 1000)
            return int(rx), int(ry)
    return None


def test_vlm(model, prompt, b64, screen_w, screen_h):
    t0 = time.time()
    try:
        resp = requests.post(f'{OLLAMA_URL}/api/generate', json={
            'model': model, 'prompt': prompt, 'images': [b64],
            'stream': False, 'options': {'temperature': 0.0, 'num_predict': 50}
        }, timeout=60)
        elapsed = time.time() - t0
        if resp.status_code != 200:
            return elapsed, None, f"HTTP {resp.status_code}"
        text = resp.json().get('response', '').strip()
        coords = parse_coords(text, screen_w, screen_h)
        return elapsed, coords, text[:120]
    except Exception as e:
        return time.time() - t0, None, str(e)[:80]


def test_template(screen_gray, anchor_path):
    anchor = cv2.imread(anchor_path, cv2.IMREAD_GRAYSCALE)
    if anchor is None:
        return None
    ah, aw = anchor.shape[:2]
    if ah >= screen_gray.shape[0] or aw >= screen_gray.shape[1]:
        return None
    t0 = time.time()
    result = cv2.matchTemplate(screen_gray, anchor, cv2.TM_CCOEFF_NORMED)
    _, max_val, _, max_loc = cv2.minMaxLoc(result)
    elapsed = (time.time() - t0) * 1000
    return {
        'method': 'template', 'time_ms': elapsed,
        'score': max_val, 'pos': (max_loc[0] + aw//2, max_loc[1] + ah//2)
    }


def test_template_multiscale(screen_gray, anchor_path, scales=(0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3)):
    anchor = cv2.imread(anchor_path, cv2.IMREAD_GRAYSCALE)
    if anchor is None:
        return None
    ah, aw = anchor.shape[:2]
    t0 = time.time()
    best_val, best_loc, best_scale = 0, None, 1.0
    for s in scales:
        resized = cv2.resize(anchor, None, fx=s, fy=s)
        rh, rw = resized.shape[:2]
        if rh >= screen_gray.shape[0] or rw >= screen_gray.shape[1]:
            continue
        res = cv2.matchTemplate(screen_gray, resized, cv2.TM_CCOEFF_NORMED)
        _, mv, _, ml = cv2.minMaxLoc(res)
        if mv > best_val:
            best_val, best_loc, best_scale = mv, ml, s
    elapsed = (time.time() - t0) * 1000
    if best_loc is None:
        return None
    rh, rw = int(ah * best_scale), int(aw * best_scale)
    return {
        'method': 'template_multiscale', 'time_ms': elapsed,
        'score': best_val, 'pos': (best_loc[0] + rw//2, best_loc[1] + rh//2),
        'scale': best_scale
    }


def test_orb(screen_gray, anchor_path, max_distance=50):
    anchor = cv2.imread(anchor_path, cv2.IMREAD_GRAYSCALE)
    if anchor is None:
        return None
    t0 = time.time()
    orb = cv2.ORB_create(nfeatures=1000)
    kp1, des1 = orb.detectAndCompute(anchor, None)
    kp2, des2 = orb.detectAndCompute(screen_gray, None)
    if des1 is None or des2 is None or len(des1) < 2 or len(des2) < 2:
        return {'method': 'ORB', 'time_ms': (time.time()-t0)*1000, 'matches': 0, 'pos': None}
    bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
    matches = bf.match(des1, des2)
    good = sorted([m for m in matches if m.distance < max_distance], key=lambda m: m.distance)
    elapsed = (time.time() - t0) * 1000
    pos = None
    if len(good) >= 4:
        pts = np.float32([kp2[m.trainIdx].pt for m in good])
        pos = (int(np.median(pts[:, 0])), int(np.median(pts[:, 1])))
    return {'method': 'ORB', 'time_ms': elapsed, 'matches': len(good), 'pos': pos}


def test_akaze(screen_gray, anchor_path, max_distance=80):
    anchor = cv2.imread(anchor_path, cv2.IMREAD_GRAYSCALE)
    if anchor is None:
        return None
    t0 = time.time()
    akaze = cv2.AKAZE_create()
    kp1, des1 = akaze.detectAndCompute(anchor, None)
    kp2, des2 = akaze.detectAndCompute(screen_gray, None)
    if des1 is None or des2 is None or len(des1) < 2 or len(des2) < 2:
        return {'method': 'AKAZE', 'time_ms': (time.time()-t0)*1000, 'matches': 0, 'pos': None}
    bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
    matches = bf.match(des1, des2)
    good = sorted([m for m in matches if m.distance < max_distance], key=lambda m: m.distance)
    elapsed = (time.time() - t0) * 1000
    pos = None
    if len(good) >= 4:
        pts = np.float32([kp2[m.trainIdx].pt for m in good])
        pos = (int(np.median(pts[:, 0])), int(np.median(pts[:, 1])))
    return {'method': 'AKAZE', 'time_ms': elapsed, 'matches': len(good), 'pos': pos}


def main():
    print("="*70)
    print("GROUNDING BENCHMARK — Léa RPA Vision")
    print("="*70)

    screen = capture_screen()
    screen_w, screen_h = screen.size
    b64 = screen_to_b64(screen)
    screen_cv = cv2.cvtColor(np.array(screen), cv2.COLOR_RGB2BGR)
    screen_gray = cv2.cvtColor(screen_cv, cv2.COLOR_BGR2GRAY)
    print(f"Screen: {screen_w}x{screen_h}\n")

    # ── VLM grounding ──
    print("─── VLM GROUNDING (target: 'Demo folder') ───")
    vlm_tests = [
        ("qwen3-vl:8b", 'Click on "Demo folder". Return the action in format: click(start_box="(x,y)") with coordinates normalized 0-1000.'),
        ("qwen2.5vl:7b", 'Click on "Demo folder". Return the action in format: click(start_box="(x,y)") with coordinates normalized 0-1000.'),
        ("moondream:latest", 'Where is the "Demo" folder icon? Give coordinates as (x, y) in pixels.'),
        ("gemma4:latest", 'Click on "Demo folder". Return the action in format: click(start_box="(x,y)") with coordinates normalized 0-1000.'),
    ]
    for model, prompt in vlm_tests:
        elapsed, coords, text = test_vlm(model, prompt, b64, screen_w, screen_h)
        coord_str = f"({coords[0]:4d}, {coords[1]:4d})" if coords else "   —   "
        print(f"  {model:35s} {elapsed:5.1f}s {coord_str} {text[:60]}")

    # ── OpenCV ──
    print(f"\n─── OPENCV (anchors from {ANCHOR_DIR}) ───")
    thumbs = sorted(glob.glob(f'{ANCHOR_DIR}/*_thumb.png'))[:5]
    full_imgs = sorted(glob.glob(f'{ANCHOR_DIR}/*_full.png'))[:5]

    for thumb_path in thumbs:
        name = os.path.basename(thumb_path).replace('_thumb.png', '')[:30]
        thumb = cv2.imread(thumb_path, cv2.IMREAD_GRAYSCALE)
        ah, aw = thumb.shape[:2] if thumb is not None else (0, 0)
        print(f"\n  Anchor: {name} ({aw}x{ah})")

        r = test_template(screen_gray, thumb_path)
        if r:
            print(f"    Template:         {r['time_ms']:6.1f}ms score={r['score']:.3f} pos={r['pos']}")

        r = test_template_multiscale(screen_gray, thumb_path)
        if r:
            print(f"    Template multi-s: {r['time_ms']:6.1f}ms score={r['score']:.3f} pos={r['pos']} scale={r['scale']}")

        r = test_orb(screen_gray, thumb_path)
        if r:
            print(f"    ORB:              {r['time_ms']:6.1f}ms matches={r['matches']:3d} pos={r['pos']}")

        r = test_akaze(screen_gray, thumb_path)
        if r:
            print(f"    AKAZE:            {r['time_ms']:6.1f}ms matches={r['matches']:3d} pos={r['pos']}")

    # ── Summary ──
    print(f"\n{'='*70}")
    print("SUMMARY")
    print("="*70)
    print("""
Recommended pipeline (fastest to slowest):
 1. Classic template matching   ~20-50ms     (score > 0.75 = direct hit)
 2. Multi-scale template        ~80-150ms    (robust to size changes)
 3. OCR (docTR)                 ~500-1000ms  (text only)
 4. Static fallback             ~0ms         (original coordinates)

Note: feature matchers (ORB/AKAZE) are not suited to small UI anchors
(< 200x200px) — too few distinctive keypoints.
""")


if __name__ == '__main__':
    main()

39 tools/start_grounding_server.sh Executable file
@@ -0,0 +1,39 @@
#!/bin/bash
# Launches the UI-TARS grounding server (port 8200)
#
# The server loads UI-TARS-1.5-7B in 4-bit NF4 in its own Python process
# with a clean CUDA context. The VWB Flask backend and the ORA loop
# call this server over HTTP.
#
# Usage:
#   ./tools/start_grounding_server.sh        # foreground
#   ./tools/start_grounding_server.sh --bg   # background (log in /tmp)

set -e

cd /home/dom/ai/rpa_vision_v3

VENV=".venv/bin/python3"
LOG="/tmp/grounding_server.log"

if [ ! -f "$VENV" ]; then
    echo "ERROR: venv not found at $VENV"
    exit 1
fi

echo "=== UI-TARS Grounding Server ==="
echo "Port: 8200"
echo "Model: ByteDance-Seed/UI-TARS-1.5-7B (4-bit NF4)"
echo ""

if [ "$1" = "--bg" ]; then
    echo "Starting in the background (logs in $LOG)"
    nohup $VENV -m core.grounding.server > "$LOG" 2>&1 &
    PID=$!
    echo "PID: $PID"
    echo "$PID" > /tmp/grounding_server.pid
    echo "Check: curl http://localhost:8200/health"
    echo "Logs: tail -f $LOG"
else
    $VENV -m core.grounding.server
fi
@@ -896,15 +896,15 @@ def execute_action(action_type: str, params: dict) -> dict:
     _fc_target_text = params.get('_step_label', '')
     _action_types = {'click_anchor', 'double_click_anchor', 'right_click_anchor',
                      'hover_anchor', 'focus_anchor', 'scroll_to_anchor'}
-    if _fc_target_text in _action_types and screenshot_base64:
-        try:
-            from core.execution.input_handler import _describe_anchor_image
-            _desc = _describe_anchor_image(screenshot_base64)
-            if _desc:
-                print(f"🏷️ [Vision] Ancre décrite: '{_desc}'")
-                _fc_target_text = _desc
-        except Exception:
-            pass
+    # Note: no more _describe_anchor_image() (qwen2.5vl) call here.
+    # The anchor crop (screenshot_base64) is used directly by the
+    # pixel-perfect template matching up front, then by InfiGUI in
+    # fused mode if needed (option 2.c+2.a). Saves the ~9.4 GB of
+    # Ollama VRAM that conflicted with InfiGUI.
+    if _fc_target_text in _action_types:
+        # Mark the label as garbage so the pipeline switches
+        # to fused mode via template_b64.
+        _fc_target_text = ''
     _fc_target_desc = params.get('visual_anchor', {}).get('description', '')
 
     x, y, confidence, method_used = None, None, 0, ''
@@ -1431,7 +1431,7 @@ def run_workflow_verified(execution_id: str, workflow_id: str, app):
     from core.execution.observe_reason_act import ORALoop
 
     ora = ORALoop(
-        max_retries=2, max_steps=50, verify_level='auto',
+        max_retries=2, max_steps=50, verify_level='none',
         should_continue=lambda: not _execution_state.get('should_stop', False)
     )
     ora._variables = _execution_state.get('variables', {})