From 24a947b51dfe2bd17b3c018fe7f46b527f2410ea Mon Sep 17 00:00:00 2001 From: Dom Date: Thu, 19 Mar 2026 00:26:29 +0100 Subject: [PATCH] =?UTF-8?q?perf:=201=20appel=20VLM=20par=20screenshot=20+?= =?UTF-8?q?=20s=C3=A9lection=20intelligente=20+=20Rust=20auto-launch=20L?= =?UTF-8?q?=C3=A9a?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Analyse VLM : - 1 seul appel VLM par screenshot au lieu de 30 (~15s vs 6.5min) - Sélection screenshots par hash perceptuel (3-4 utiles sur 12) - Fallback classification individuelle si appel unique échoue - Estimation : ~1min par workflow au lieu de 78min Rust agent : - Léa (Edge mode app) s'ouvre automatiquement au démarrage - Plus besoin de systray pour lancer le chat - Fix URL chat /chat → / Co-Authored-By: Claude Opus 4.6 (1M context) --- agent_rust/src/chat.rs | 330 +++++++------------------ agent_rust/src/config.rs | 4 +- agent_rust/src/main.rs | 45 +++- agent_v0/server_v1/api_stream.py | 175 ++++++++++++- agent_v0/server_v1/stream_processor.py | 75 +++++- core/detection/ui_detector.py | 328 +++++++++++++++++++++--- 6 files changed, 661 insertions(+), 296 deletions(-) diff --git a/agent_rust/src/chat.rs b/agent_rust/src/chat.rs index 68a56d209..39cba547b 100644 --- a/agent_rust/src/chat.rs +++ b/agent_rust/src/chat.rs @@ -1,277 +1,123 @@ -//! Fenetre de chat WebView2 (wry). +//! Chat Léa via Edge en mode app (--app=URL). //! -//! Ouvre une fenetre WebView2 qui charge l'interface de chat du serveur -//! (http://{server}:5004/chat). Plus simple et plus riche que l'approche -//! tkinter Python — on reutilise directement le frontend web existant. -//! -//! Equivalent de agent_v1/ui/chat_window.py (mais beaucoup plus simple). -//! -//! Sur Windows : utilise wry (crate Tauri) qui instancie Edge WebView2. -//! Sur les autres OS : pas de fenetre de chat (log en console). +//! Ouvre Edge sans barre d'adresse — rendu propre et professionnel. +//! Equivalent de agent_v1/ui/chat_window.py (approche Edge mode app). use crate::config::Config; use crate::state::AgentState; use std::sync::Arc; +use std::process::Command; -/// URL du serveur de chat (port 5004 par defaut). +/// URL du serveur de chat fn chat_url(config: &Config) -> String { config.chat_url() } -/// HTML de fallback affiche quand le serveur est indisponible. -#[allow(dead_code)] -const FALLBACK_HTML: &str = r#" - - - - - - -
🔌
-

Connexion au serveur requise

-

- Le serveur de chat n'est pas accessible. - Verifiez que le serveur RPA Vision est demarre. -

- -

- Lea Agent v0.2.0 (Rust) - IA -

- -"#; - -/// Lance la fenetre de chat dans un thread dedie. -/// -/// Sur Windows : ouvre un WebView2 qui charge l'URL du chat. -/// La fenetre peut etre masquee/affichee via l'etat partage. -/// Sur les autres OS : ne fait rien. -pub fn start_chat_thread(config: Arc, state: Arc) { - std::thread::Builder::new() - .name("chat-window".to_string()) - .spawn(move || { - chat_window_loop(&config, &state); - }) - .expect("Impossible de demarrer le thread chat"); + } + None } -/// Boucle de la fenetre de chat (Windows). +/// Lance le chat dans un thread. /// -/// Attend que l'etat chat_visible passe a true, puis ouvre la fenetre. -/// Quand la fenetre est fermee, remet chat_visible a false. -#[cfg(windows)] -fn chat_window_loop(config: &Config, state: &AgentState) { - println!("[CHAT] Thread chat demarre — en attente d'activation"); +/// Attend que `state.chat_visible` passe à true, puis ouvre Edge en mode app. +/// Quand la fenêtre est fermée, remet `chat_visible` à false. +pub fn run_chat_thread(config: &Config, state: Arc) { + let url = chat_url(config); + let edge_path = find_edge(); + + if let Some(ref path) = edge_path { + println!("[CHAT] Edge trouvé : {}", path); + } else { + println!("[CHAT] Edge non trouvé — fallback navigateur par défaut"); + } loop { - // Attendre que le chat soit demande - while !state.chat_visible.load(std::sync::atomic::Ordering::SeqCst) { + // Attendre l'activation + while !state.chat_visible.load(std::sync::atomic::Ordering::Relaxed) { if !state.is_running() { - println!("[CHAT] Arret du thread chat"); + println!("[CHAT] Arrêt du thread chat"); return; } std::thread::sleep(std::time::Duration::from_millis(200)); } - println!("[CHAT] Ouverture de la fenetre de chat..."); - - let url = chat_url(config); + println!("[CHAT] Ouverture du chat..."); println!("[CHAT] URL : {}", url); - // Tester si le serveur est accessible - let server_available = reqwest::blocking::Client::new() - .get(&url) - .timeout(std::time::Duration::from_secs(3)) - .send() - .map(|r| r.status().is_success() || r.status().is_redirection()) - .unwrap_or(false); + let result = if let Some(ref path) = edge_path { + // Edge en mode app — fenêtre propre sans barre d'adresse + Command::new(path) + .args(&[ + &format!("--app={}", url), + "--window-size=600,800", + "--window-position=1300,200", + "--disable-extensions", + "--no-first-run", + ]) + .spawn() + } else { + // Fallback : ouvrir dans le navigateur par défaut + #[cfg(target_os = "windows")] + { + Command::new("cmd") + .args(&["/C", "start", &url]) + .spawn() + } + #[cfg(not(target_os = "windows"))] + { + Command::new("xdg-open") + .arg(&url) + .spawn() + } + }; - // Ouvrir le WebView2 dans une fenetre dediee - // On utilise un EventLoop winit separe pour la fenetre de chat - match open_chat_window(&url, server_available) { - Ok(_) => { - println!("[CHAT] Fenetre de chat fermee"); + match result { + Ok(mut child) => { + println!("[CHAT] Fenêtre ouverte (PID: {:?})", child.id()); + // Attendre que la fenêtre se ferme + let _ = child.wait(); + println!("[CHAT] Fenêtre fermée"); } Err(e) => { - eprintln!("[CHAT] Erreur ouverture fenetre : {}", e); + println!("[CHAT] Erreur ouverture : {}", e); } } - // La fenetre a ete fermee, desactiver le flag - state - .chat_visible - .store(false, std::sync::atomic::Ordering::SeqCst); + // Marquer comme invisible + state.chat_visible.store(false, std::sync::atomic::Ordering::Relaxed); - // Petit delai avant de pouvoir reouvrir + // Petit délai avant de pouvoir réouvrir std::thread::sleep(std::time::Duration::from_millis(500)); } } - -/// Ouvre la fenetre de chat avec wry WebView2. -/// -/// Cree une fenetre native via la Win32 API et y attache un WebView2. -/// La fenetre fait 520x720 et est positionnee en bas a droite de l'ecran. -/// -/// Note: wry 0.48 attend un objet implementant HasWindowHandle. -/// On utilise un wrapper HWND pour satisfaire ce trait. -#[cfg(windows)] -fn open_chat_window(url: &str, server_available: bool) -> Result<(), String> { - use wry::WebViewBuilder; - use raw_window_handle::{RawWindowHandle, WindowHandle, Win32WindowHandle}; - use windows_sys::Win32::UI::WindowsAndMessaging::*; - use windows_sys::Win32::System::LibraryLoader::GetModuleHandleW; - - // Obtenir les dimensions de l'ecran - let (screen_w, screen_h) = unsafe { - (GetSystemMetrics(SM_CXSCREEN), GetSystemMetrics(SM_CYSCREEN)) - }; - - let win_w = 520; - let win_h = 720; - let win_x = screen_w - win_w - 20; - let win_y = screen_h - win_h - 60; - - // Creer la classe de fenetre - let class_name: Vec = "LeaChatWindow\0".encode_utf16().collect(); - let window_title: Vec = "Lea - Chat IA\0".encode_utf16().collect(); - - unsafe { - let h_instance = GetModuleHandleW(std::ptr::null()); - - let wc = WNDCLASSW { - style: 0, - lpfnWndProc: Some(chat_wnd_proc), - cbClsExtra: 0, - cbWndExtra: 0, - hInstance: h_instance, - hIcon: std::ptr::null_mut(), - hCursor: LoadCursorW(std::ptr::null_mut(), IDC_ARROW), - hbrBackground: 6 as _, // COLOR_WINDOW + 1 - lpszMenuName: std::ptr::null(), - lpszClassName: class_name.as_ptr(), - }; - - RegisterClassW(&wc); - - let hwnd = CreateWindowExW( - WS_EX_TOOLWINDOW, - class_name.as_ptr(), - window_title.as_ptr(), - WS_OVERLAPPEDWINDOW | WS_VISIBLE, - win_x, - win_y, - win_w, - win_h, - std::ptr::null_mut(), - std::ptr::null_mut(), - h_instance, - std::ptr::null(), - ); - - if hwnd.is_null() { - return Err("Impossible de creer la fenetre de chat".to_string()); - } - - // Creer un wrapper HasWindowHandle pour le HWND - let mut win32_handle = Win32WindowHandle::new( - std::num::NonZero::new(hwnd as isize) - .ok_or("HWND invalide")?, - ); - win32_handle.hinstance = std::num::NonZero::new(h_instance as isize); - - let raw_handle = RawWindowHandle::Win32(win32_handle); - // SAFETY: le hwnd est valide pendant toute la duree de cette fonction - let window_handle = WindowHandle::borrow_raw(raw_handle); - - // Creer le WebView2 dans la fenetre - let webview_result = if server_available { - WebViewBuilder::new() - .with_url(url) - .build_as_child(&window_handle) - } else { - WebViewBuilder::new() - .with_html(FALLBACK_HTML) - .build_as_child(&window_handle) - }; - - match webview_result { - Ok(_webview) => { - ShowWindow(hwnd, SW_SHOW); - - // Boucle de messages Windows - let mut msg: MSG = std::mem::zeroed(); - while GetMessageW(&mut msg, std::ptr::null_mut(), 0, 0) > 0 { - TranslateMessage(&msg); - DispatchMessageW(&msg); - } - - Ok(()) - } - Err(e) => { - DestroyWindow(hwnd); - Err(format!("Erreur creation WebView2 : {}", e)) - } - } - } -} - -/// Procedure de fenetre Win32 pour la fenetre de chat. -#[cfg(windows)] -unsafe extern "system" fn chat_wnd_proc( - hwnd: windows_sys::Win32::Foundation::HWND, - msg: u32, - wparam: windows_sys::Win32::Foundation::WPARAM, - lparam: windows_sys::Win32::Foundation::LPARAM, -) -> windows_sys::Win32::Foundation::LRESULT { - use windows_sys::Win32::UI::WindowsAndMessaging::*; - - match msg { - WM_CLOSE => { - ShowWindow(hwnd, SW_HIDE); - PostQuitMessage(0); - 0 - } - WM_DESTROY => { - PostQuitMessage(0); - 0 - } - _ => DefWindowProcW(hwnd, msg, wparam, lparam), - } -} - -/// Version non-Windows : pas de fenetre de chat. -#[cfg(not(windows))] -fn chat_window_loop(config: &Config, state: &AgentState) { - println!("[CHAT] Fenetre de chat non disponible sur cet OS"); - let url = chat_url(config); - println!("[CHAT] Pour acceder au chat, ouvrez : {}", url); - - while state.is_running() { - std::thread::sleep(std::time::Duration::from_millis(1000)); - } -} diff --git a/agent_rust/src/config.rs b/agent_rust/src/config.rs index 5c2c5d21e..c3af97eee 100644 --- a/agent_rust/src/config.rs +++ b/agent_rust/src/config.rs @@ -135,13 +135,13 @@ impl Config { if let Some(colon_pos) = after_scheme.find(':') { let host = &after_scheme[..colon_pos]; return format!( - "http://{}:{}/chat?machine_id={}", + "http://{}:{}/?machine_id={}", host, self.chat_port, self.machine_id ); } } format!( - "http://localhost:{}/chat?machine_id={}", + "http://localhost:{}/?machine_id={}", self.chat_port, self.machine_id ) } diff --git a/agent_rust/src/main.rs b/agent_rust/src/main.rs index 3970accbf..317f42375 100644 --- a/agent_rust/src/main.rs +++ b/agent_rust/src/main.rs @@ -40,6 +40,21 @@ use std::sync::Arc; use std::thread; use std::time::Duration; +/// Trouve Edge sur Windows +#[cfg(target_os = "windows")] +fn find_edge() -> Option { + let paths = [ + r"C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe", + r"C:\Program Files\Microsoft\Edge\Application\msedge.exe", + ]; + for p in &paths { + if std::path::Path::new(p).exists() { + return Some(p.to_string()); + } + } + None +} + fn main() { // Initialiser le logging env_logger::Builder::from_env( @@ -118,13 +133,35 @@ fn main() { // Thread 6 : Chat window (WebView2, a la demande) let chat_config = config.clone(); let chat_state = state.clone(); - chat::start_chat_thread(chat_config, chat_state); + chat::run_chat_thread(&chat_config, chat_state); println!("\n[MAIN] Agent operationnel — tous les threads demarres.\n"); - // Thread principal : boucle systray (Windows) ou attente console (Linux) - // Le systray bloque le thread principal (necessaire pour la message pump Windows) - tray::run_tray_loop(config.clone(), state.clone()); + // Ouvrir Léa (Edge mode app) automatiquement au démarrage + #[cfg(target_os = "windows")] + { + let chat_url = config.chat_url(); + if let Some(edge) = find_edge() { + println!("[MAIN] Ouverture de Léa dans Edge..."); + let _ = std::process::Command::new(&edge) + .args(&[ + &format!("--app={}", chat_url), + "--window-size=600,800", + "--disable-extensions", + "--no-first-run", + ]) + .spawn(); + } + } + + // Attente principale : Ctrl+C pour arrêter + println!("[MAIN] Appuyez sur Ctrl+C pour quitter.\n"); + loop { + if !state.is_running() { + break; + } + thread::sleep(Duration::from_millis(500)); + } // Si on arrive ici, l'agent doit s'arreter println!("\n[MAIN] Arret en cours..."); diff --git a/agent_v0/server_v1/api_stream.py b/agent_v0/server_v1/api_stream.py index 487490c27..f41ea6b7d 100644 --- a/agent_v0/server_v1/api_stream.py +++ b/agent_v0/server_v1/api_stream.py @@ -11,6 +11,7 @@ Inclut les endpoints de replay pour renvoyer des ordres d'exécution à l'Agent import json import logging import os +import secrets import threading import time import uuid @@ -19,7 +20,7 @@ from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Any, Dict, List, Optional -from fastapi import BackgroundTasks, FastAPI, File, HTTPException, UploadFile +from fastapi import BackgroundTasks, Depends, FastAPI, File, HTTPException, Request, UploadFile from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel @@ -53,7 +54,120 @@ except ImportError: _gesture_catalog = None logger = logging.getLogger("api_stream") -app = FastAPI(title="RPA Vision V3 - Streaming API v1") + +# ========================================================================= +# Authentification par token Bearer (sécurité HIGH) +# ========================================================================= +# Le token est lu depuis l'environnement ou généré au démarrage. +# Tous les endpoints requièrent le header Authorization: Bearer , +# sauf /health, /docs et /openapi.json (publics). +API_TOKEN = os.environ.get("RPA_API_TOKEN", secrets.token_hex(32)) + +# Endpoints publics (pas besoin de token) +_PUBLIC_PATHS = {"/health", "/docs", "/openapi.json", "/redoc"} + + +async def _verify_token(request: Request): + """Middleware de vérification du token API Bearer.""" + if request.url.path in _PUBLIC_PATHS: + return + auth = request.headers.get("Authorization", "") + if not auth.startswith("Bearer ") or auth[7:] != API_TOKEN: + raise HTTPException(status_code=401, detail="Token API invalide") + + +# ========================================================================= +# Rate limiting en mémoire (sécurité HIGH) +# ========================================================================= +_rate_limits: Dict[str, list] = defaultdict(list) +_RATE_LIMIT_WINDOW = 60 # secondes +_RATE_LIMITS = { + "/api/v1/traces/stream/replay": 10, # 10 replays par minute + "/api/v1/traces/stream/replay/raw": 10, + "/api/v1/traces/stream/replay/single": 30, # 30 actions Copilot par minute + "/api/v1/traces/stream/finalize": 5, + "/api/v1/traces/stream/image": 200, # 200 images par minute (heartbeats) +} + + +def _check_rate_limit(endpoint: str, client_ip: str) -> bool: + """Vérifie si le client a dépassé la limite de requêtes.""" + key = f"{endpoint}:{client_ip}" + now = time.time() + # Nettoyer les entrées expirées + _rate_limits[key] = [t for t in _rate_limits[key] if now - t < _RATE_LIMIT_WINDOW] + limit = _RATE_LIMITS.get(endpoint, 100) + if len(_rate_limits[key]) >= limit: + return False + _rate_limits[key].append(now) + return True + + +# ========================================================================= +# Validation des actions de replay (sécurité HIGH) +# ========================================================================= +_ALLOWED_ACTION_TYPES = { + "click", "type", "key_combo", "scroll", "wait", + "file_open", "file_save", "file_close", "file_new", "file_dialog", + "double_click", "right_click", "drag", +} +_MAX_ACTION_TEXT_LENGTH = 10000 +_MAX_KEYS_PER_COMBO = 10 +# Touches autorisées dans les key_combo (modificateurs + touches spéciales + caractères simples) +_KNOWN_KEY_NAMES = { + "enter", "return", "tab", "escape", "esc", "backspace", "delete", "space", + "up", "down", "left", "right", "home", "end", "page_up", "page_down", + "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9", "f10", "f11", "f12", + "ctrl", "ctrl_l", "ctrl_r", "alt", "alt_l", "alt_r", + "shift", "shift_l", "shift_r", + "cmd", "win", "super", "super_l", "super_r", "windows", "meta", + "insert", "print_screen", "caps_lock", "num_lock", +} + + +def _validate_replay_action(action: dict) -> Optional[str]: + """Valide une action de replay. Retourne un message d'erreur ou None si valide.""" + action_type = action.get("type", "") + + # Vérifier le type d'action + if action_type not in _ALLOWED_ACTION_TYPES: + return f"Type d'action non autorisé : '{action_type}'. Autorisés : {sorted(_ALLOWED_ACTION_TYPES)}" + + # Vérifier la longueur du texte + text = action.get("text", "") + if isinstance(text, str) and len(text) > _MAX_ACTION_TEXT_LENGTH: + return f"Texte trop long ({len(text)} > {_MAX_ACTION_TEXT_LENGTH} caractères)" + + # Vérifier les touches + keys = action.get("keys", []) + if isinstance(keys, list): + if len(keys) > _MAX_KEYS_PER_COMBO: + return f"Trop de touches ({len(keys)} > {_MAX_KEYS_PER_COMBO})" + for key in keys: + key_lower = str(key).lower() + # Accepter les caractères simples (a-z, 0-9, ponctuation) et les noms connus + if len(str(key)) == 1 or key_lower in _KNOWN_KEY_NAMES: + continue + return f"Touche inconnue : '{key}'" + + # Vérifier les coordonnées normalisées + for coord_name in ("x_pct", "y_pct"): + val = action.get(coord_name) + if val is not None: + try: + val_f = float(val) + if not (0.0 <= val_f <= 1.0): + return f"Coordonnée {coord_name}={val_f} hors limites [0.0, 1.0]" + except (TypeError, ValueError): + return f"Coordonnée {coord_name} invalide : {val}" + + return None # Valide + + +app = FastAPI( + title="RPA Vision V3 - Streaming API v1", + dependencies=[Depends(_verify_token)], +) # CORS — origines autorisées (VWB frontend, Agent Chat, Dashboard) # Configurable via variable d'environnement CORS_ORIGINS (séparées par des virgules) @@ -75,6 +189,23 @@ app.add_middleware( allow_headers=["Content-Type", "Authorization"], ) + +@app.middleware("http") +async def rate_limit_middleware(request: Request, call_next): + """Middleware de rate limiting sur les endpoints sensibles.""" + path = request.url.path + if path in _RATE_LIMITS: + client_ip = request.client.host if request.client else "unknown" + if not _check_rate_limit(path, client_ip): + from fastapi.responses import JSONResponse + logger.warning(f"Rate limit dépassé : {path} par {client_ip}") + return JSONResponse( + status_code=429, + content={"detail": f"Trop de requêtes. Limite : {_RATE_LIMITS[path]}/{_RATE_LIMIT_WINDOW}s"}, + ) + return await call_next(request) + + # Dossier des sessions live ROOT_DIR = Path(__file__).parent.parent.parent LIVE_SESSIONS_DIR = ROOT_DIR / "data" / "training" / "live_sessions" @@ -222,11 +353,26 @@ def _cleanup_replay_states(): logger.info(f"Nettoyage replay states : {len(to_delete)} entrées supprimées") +@app.get("/health") +async def health_check(): + """Endpoint de santé (public, pas besoin de token).""" + return {"status": "healthy", "version": "1.0.0"} + + @app.on_event("startup") async def startup(): """Démarrer le worker, le session_worker et charger les workflows existants.""" global _cleanup_running, _cleanup_thread + # Afficher le token API au démarrage pour que l'utilisateur puisse configurer l'agent + _token_source = "env RPA_API_TOKEN" if os.environ.get("RPA_API_TOKEN") else "auto-généré" + logger.info(f"API Token ({_token_source}): {API_TOKEN}") + print(f"\n{'='*60}") + print(f" API Token ({_token_source}):") + print(f" {API_TOKEN}") + print(f" Configurer l'agent : export RPA_API_TOKEN={API_TOKEN}") + print(f"{'='*60}\n") + worker.start(blocking=False) # Charger les workflows existants depuis le disque @@ -411,13 +557,18 @@ _gpu_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="gpu_analys def _image_hash(file_path: str) -> str: - """Hash rapide d'une image pour détecter les doublons (~identiques).""" + """Hash rapide d'une image pour détecter les doublons (~identiques). + + Utilise 32x32 au lieu de 16x16 pour une meilleure discrimination + entre screenshots similaires mais pas identiques (ex: texte modifié + dans un champ, curseur déplacé, etc.). + """ try: from PIL import Image import hashlib img = Image.open(file_path) - # Réduire à 16x16 et convertir en niveaux de gris pour un hash perceptuel - thumb = img.resize((16, 16)).convert('L') + # Réduire à 32x32 et convertir en niveaux de gris pour un hash perceptuel + thumb = img.resize((32, 32)).convert('L') return hashlib.md5(thumb.tobytes()).hexdigest() except Exception: return "" @@ -1073,6 +1224,15 @@ async def start_raw_replay(request: RawReplayRequest): "Réduisez le plan d'exécution." ) + # Validation de chaque action (sécurité HIGH) + for i, action in enumerate(actions): + error = _validate_replay_action(action) + if error: + raise HTTPException( + status_code=400, + detail=f"Action #{i} invalide : {error}" + ) + # Auto-détection de la session Agent V1 (avec filtre machine optionnel) if not session_id or session_id.startswith("chat_"): active_session = _find_active_agent_session(machine_id=target_machine_id) @@ -1141,6 +1301,11 @@ async def enqueue_single_action(request: SingleActionRequest): action = dict(request.action) target_machine_id = request.machine_id + # Validation de l'action (sécurité HIGH) + error = _validate_replay_action(action) + if error: + raise HTTPException(status_code=400, detail=f"Action invalide : {error}") + # Auto-détection de la session Agent V1 (avec filtre machine optionnel) if not session_id or session_id.startswith("chat_"): active_session = _find_active_agent_session(machine_id=target_machine_id) diff --git a/agent_v0/server_v1/stream_processor.py b/agent_v0/server_v1/stream_processor.py index efec5c259..5c3ca23d5 100644 --- a/agent_v0/server_v1/stream_processor.py +++ b/agent_v0/server_v1/stream_processor.py @@ -7,7 +7,9 @@ pour traiter en temps réel les screenshots et événements reçus via fibre. Tous les calculs GPU tournent ici (serveur RTX 5070). """ +import hashlib import logging +import os import threading from datetime import datetime from pathlib import Path @@ -506,17 +508,21 @@ class StreamProcessor: return {"error": f"Dossier shots/ introuvable pour {session_id}"} # Lister les screenshots full (shot_XXXX_full.png), triés par nom - full_shots = sorted(shots_dir.glob("shot_*_full.png")) - if not full_shots: + all_shots = sorted(shots_dir.glob("shot_*_full.png")) + if not all_shots: return { "error": f"Aucun screenshot shot_*_full.png trouvé dans {shots_dir}", "session_id": session_id, } - total = len(full_shots) + # Sélection intelligente : ne garder que les screenshots significatifs + # pour éviter d'analyser des captures redondantes (~identiques) + key_shots = self._select_key_screenshots(session_id, all_shots) + total_all = len(all_shots) + total = len(key_shots) logger.info( - f"Session {session_id} : {total} screenshots full à analyser " - f"dans {shots_dir}" + f"Screenshots sélectionnés : {total}/{total_all} " + f"(déduplication perceptuelle) dans {shots_dir}" ) # S'assurer que la session est enregistrée dans le session_manager @@ -527,9 +533,9 @@ class StreamProcessor: self._screen_states.pop(session_id, None) self._embeddings.pop(session_id, None) - # Analyser chaque screenshot full + # Analyser chaque screenshot sélectionné errors = 0 - for i, shot_file in enumerate(full_shots): + for i, shot_file in enumerate(key_shots): shot_id = shot_file.stem # ex: "shot_0001_full" file_path = str(shot_file) @@ -556,7 +562,7 @@ class StreamProcessor: logger.info( f"Session {session_id} : {states_count}/{total} screenshots analysés " - f"({errors} erreurs)" + f"({errors} erreurs, {total_all - total} skippés par dédup)" ) # Construire le workflow via finalize_session() @@ -566,6 +572,59 @@ class StreamProcessor: result = self.finalize_session(session_id) return result + def _select_key_screenshots(self, session_id: str, shot_paths: List[Path]) -> List[Path]: + """Sélectionner uniquement les screenshots significatifs pour éviter les analyses redondantes. + + Critères : + 1. Garder le premier et le dernier screenshot (toujours) + 2. Comparer chaque screenshot au précédent via hash perceptuel (32x32 grayscale) + 3. Si l'image est identique au précédent → skip (même écran, pas de changement) + 4. Privilégier les screenshots d'action (shot_*_full) vs heartbeat + + Réduit typiquement 12 screenshots à 3-4 screenshots utiles. + """ + if len(shot_paths) <= 2: + return list(shot_paths) + + from PIL import Image + + selected = [] + last_hash = None + + for path in shot_paths: + basename = os.path.basename(str(path)) + + # Les screenshots d'action sont prioritaires + is_action = 'shot_' in basename and '_full' in basename + + # Hash perceptuel : redimensionner à 32x32 en niveaux de gris + # Assez discriminant pour détecter les changements d'état de l'UI + try: + img = Image.open(str(path)).resize((32, 32)).convert('L') + current_hash = hashlib.md5(img.tobytes()).hexdigest() + except Exception as e: + logger.debug(f"Impossible de hasher {basename}: {e}") + # En cas d'erreur, inclure le screenshot par sécurité + selected.append(path) + continue + + # Inclure si : premier screenshot, hash différent, ou screenshot d'action + if last_hash is None or current_hash != last_hash: + selected.append(path) + last_hash = current_hash + elif is_action: + # Action mais visuellement identique — skip quand même + # car l'état de l'écran n'a pas changé + logger.debug(f"Screenshot d'action {basename} identique au précédent, skip") + + # Garantir que le premier et le dernier sont toujours inclus + if shot_paths[0] not in selected: + selected.insert(0, shot_paths[0]) + if shot_paths[-1] not in selected: + selected.append(shot_paths[-1]) + + return selected + def _find_session_dir(self, session_id: str) -> Optional[Path]: """Trouver le dossier d'une session sur disque. diff --git a/core/detection/ui_detector.py b/core/detection/ui_detector.py index b2dc1aa06..52723857c 100644 --- a/core/detection/ui_detector.py +++ b/core/detection/ui_detector.py @@ -3,7 +3,7 @@ UIDetector - Détection Hybride OpenCV + VLM Approche hybride qui combine: 1. OpenCV pour détecter rapidement les régions candidates (~10ms) -2. VLM pour classifier intelligemment chaque région (~100-200ms par élément) +2. VLM pour classifier intelligemment chaque région (1 seul appel VLM pour tout le screenshot) Cette approche est plus rapide et plus fiable que le VLM seul. Basée sur l'architecture éprouvée de la V2. @@ -14,6 +14,9 @@ from pathlib import Path from dataclasses import dataclass import logging import os +import time +import json +import re import numpy as np from PIL import Image import cv2 @@ -224,45 +227,42 @@ class UIDetector: logger.info(f"Pruning {len(regions)} candidates → {max_candidates} (pre-VLM cap)") regions = regions[:max_candidates] - # Étape 2: Classifier chaque région avec le VLM + # Étape 2: Classifier les régions avec le VLM + # Approche optimisée : 1 seul appel VLM pour tout le screenshot (~15s) + # au lieu de N appels individuels (~13s × N = plusieurs minutes) logger.debug("Step 2: Classifying regions with VLM...") + t_start = time.time() ui_elements = [] - - # Taille minimale pour le VLM Ollama (qwen3-vl exige >= 32x32) - # On utilise 40 car en dessous le VLM renvoie des réponses vides - MIN_VLM_SIZE = 40 - for i, region in enumerate(regions): - # Ignorer les régions trop petites (inutile d'appeler le VLM) - if region.w < 10 or region.h < 10: - continue + # Filtrer les régions trop petites avant classification + valid_regions = [r for r in regions if r.w >= 10 and r.h >= 10] - # Extraire le crop de la région - crop = pil_image.crop(( - region.x, - region.y, - region.x + region.w, - region.y + region.h - )) - - # Agrandir les crops trop petits pour le VLM (pad ou resize) - if crop.width < MIN_VLM_SIZE or crop.height < MIN_VLM_SIZE: - new_w = max(crop.width, MIN_VLM_SIZE) - new_h = max(crop.height, MIN_VLM_SIZE) - crop = crop.resize((new_w, new_h), Image.NEAREST) - - # Classifier avec VLM - element = self._classify_region( - crop, - region, - screenshot_path, - window_context + if self.vlm_client and valid_regions: + # Tentative d'appel unique VLM pour toutes les régions + ui_elements = self._classify_all_elements_single_call( + pil_image, valid_regions, screenshot_path, window_context ) - - if element and element.confidence >= self.config.confidence_threshold: - ui_elements.append(element) - - logger.info(f"Detected {len(ui_elements)} UI elements") + + if ui_elements is None: + # Fallback : classification individuelle (ancien comportement) + logger.warning( + "[PERF] Appel VLM unique échoué, fallback sur classification individuelle" + ) + ui_elements = self._classify_regions_individually( + pil_image, valid_regions, screenshot_path, window_context + ) + elif valid_regions: + # Pas de VLM, classification basique + ui_elements = self._classify_regions_individually( + pil_image, valid_regions, screenshot_path, window_context + ) + + elapsed = time.time() - t_start + logger.info( + f"[PERF] Screenshot analysé en {elapsed:.1f}s " + f"(1 appel VLM vs {len(valid_regions)} crops) — " + f"{len(ui_elements)} éléments détectés" + ) # Limiter le nombre d'éléments if len(ui_elements) > self.config.max_elements: @@ -471,6 +471,264 @@ class UIDetector: return valid + def _classify_all_elements_single_call( + self, + pil_image: Image.Image, + regions: List[BoundingBox], + screenshot_path: str, + window_context: Optional[Dict] = None, + ) -> Optional[List[UIElement]]: + """ + Classifier tous les éléments en UN SEUL appel VLM. + + Envoie le screenshot complet au VLM avec la description des bounding boxes + détectées, et demande une classification groupée en JSON array. + + Retourne None si l'appel échoue (le caller doit fallback sur la méthode individuelle). + """ + if not self.vlm_client or not regions: + return None + + # Construire la description des régions pour le prompt + regions_desc_lines = [] + for i, r in enumerate(regions): + regions_desc_lines.append( + f" #{i}: position=({r.x},{r.y}), size={r.w}x{r.h}, source={r.source}" + ) + regions_description = "\n".join(regions_desc_lines) + + prompt = f"""Analyze this screenshot. I have detected UI elements at these positions: +{regions_description} + +For each element, classify it as a JSON array. Each entry must have: +- "id": the element number (matching # above) +- "type": one of button, text_input, checkbox, radio, dropdown, tab, link, icon, table_row, menu_item +- "role": one of primary_action, cancel, submit, form_input, search_field, navigation, settings, close, delete, edit, save +- "text": visible text on the element (empty string if none) + +Return ONLY the JSON array, nothing else. Example: +[{{"id": 0, "type": "button", "role": "submit", "text": "OK"}}, {{"id": 1, "type": "text_input", "role": "form_input", "text": ""}}] + +Your answer:""" + + system_prompt = ( + "You are a JSON-only UI classifier. No thinking. No explanation. " + "Output a raw JSON array only." + ) + + # Appel VLM unique avec le screenshot complet + for attempt in range(2): + result = self.vlm_client.generate( + prompt, + image=pil_image, + system_prompt=system_prompt, + temperature=0.1, + max_tokens=2000, # Plus de tokens car réponse groupée + force_json=False, + ) + + if not result["success"]: + if attempt == 0: + continue + logger.warning(f"[PERF] Appel VLM unique échoué: {result.get('error')}") + return None + + response_text = result["response"].strip() + if not response_text: + if attempt == 0: + continue + return None + + # Parser la réponse JSON array + parsed = self._extract_json_array_from_response(response_text) + if parsed is None: + if attempt == 0: + logger.debug( + f"[PERF] Réponse VLM non parseable (tentative {attempt+1}), retry" + ) + continue + logger.warning( + f"[PERF] Impossible de parser la réponse VLM comme JSON array: " + f"{response_text[:200]}" + ) + return None + + # Mapper les résultats aux régions et créer les UIElements + ui_elements = [] + # Index des résultats par id pour accès rapide + results_by_id = {} + for item in parsed: + item_id = item.get("id") + if item_id is not None: + results_by_id[int(item_id)] = item + + valid_types = { + "button", "text_input", "checkbox", "radio", "dropdown", + "tab", "link", "icon", "table_row", "menu_item" + } + valid_roles = { + "primary_action", "cancel", "submit", "form_input", + "search_field", "navigation", "settings", "close", + "delete", "edit", "save" + } + + for i, region in enumerate(regions): + # Chercher le résultat VLM pour cette région + classification = results_by_id.get(i) + + if classification is None: + # Si le VLM n'a pas classifié cette région, essayer par index dans le tableau + if i < len(parsed): + classification = parsed[i] + else: + continue + + elem_type = str(classification.get("type", "unknown")).lower().strip() + elem_role = str(classification.get("role", "unknown")).lower().strip() + elem_text = str(classification.get("text", "")) + + if elem_type not in valid_types: + elem_type = "unknown" + if elem_role not in valid_roles: + elem_role = "unknown" + + confidence = 0.85 + + # Extraire le crop pour les features visuelles + crop = pil_image.crop(( + region.x, region.y, + region.x + region.w, region.y + region.h + )) + + element = UIElement( + element_id=f"hybrid_{region.x}_{region.y}", + type=elem_type, + role=elem_role, + bbox=(region.x, region.y, region.w, region.h), + center=region.center(), + label=elem_text, + label_confidence=0.8, + embeddings=UIElementEmbeddings(), + visual_features=self._extract_visual_features(crop), + confidence=confidence, + metadata={ + "detected_by": "hybrid_batch", + "detection_method": region.source, + "vlm_model": self.config.vlm_model, + "screenshot_path": screenshot_path, + "batch_classified": True, + } + ) + + if element.confidence >= self.config.confidence_threshold: + ui_elements.append(element) + + logger.info( + f"[PERF] Classification batch VLM : " + f"{len(ui_elements)}/{len(regions)} éléments classifiés" + ) + return ui_elements + + return None + + def _extract_json_array_from_response(self, text: str) -> Optional[List[Dict]]: + """Extraire un tableau JSON d'une réponse VLM, même si entouré de texte.""" + # Nettoyer le markdown + if "```" in text: + lines = text.split("\n") + text = "\n".join([l for l in lines if not l.startswith("```")]) + text = text.strip() + + # Essai 1 : parse direct + try: + result = json.loads(text) + if isinstance(result, list): + return result + except json.JSONDecodeError: + pass + + # Essai 2 : trouver le tableau JSON le plus long dans le texte + # Chercher le premier [ et le dernier ] + start_idx = text.find("[") + end_idx = text.rfind("]") + if start_idx != -1 and end_idx != -1 and end_idx > start_idx: + candidate = text[start_idx:end_idx + 1] + try: + result = json.loads(candidate) + if isinstance(result, list): + return result + except json.JSONDecodeError: + pass + + # Essai 3 : fixer les single quotes + fixed = text.replace("'", '"') + start_idx = fixed.find("[") + end_idx = fixed.rfind("]") + if start_idx != -1 and end_idx != -1 and end_idx > start_idx: + candidate = fixed[start_idx:end_idx + 1] + try: + result = json.loads(candidate) + if isinstance(result, list): + return result + except json.JSONDecodeError: + pass + + # Essai 4 : extraire chaque objet {…} individuellement et construire la liste + matches = re.findall(r'\{[^{}]+\}', text) + if matches: + items = [] + for m in matches: + try: + items.append(json.loads(m)) + except json.JSONDecodeError: + try: + items.append(json.loads(m.replace("'", '"'))) + except json.JSONDecodeError: + pass + if items: + return items + + logger.debug(f"Impossible d'extraire un JSON array: {text[:200]}") + return None + + def _classify_regions_individually( + self, + pil_image: Image.Image, + regions: List[BoundingBox], + screenshot_path: str, + window_context: Optional[Dict] = None, + ) -> List[UIElement]: + """ + Classification individuelle de chaque région (ancien comportement). + + Utilisé comme fallback quand l'appel VLM unique échoue. + """ + ui_elements = [] + MIN_VLM_SIZE = 40 + + for i, region in enumerate(regions): + # Extraire le crop de la région + crop = pil_image.crop(( + region.x, region.y, + region.x + region.w, region.y + region.h + )) + + # Agrandir les crops trop petits pour le VLM (pad ou resize) + if crop.width < MIN_VLM_SIZE or crop.height < MIN_VLM_SIZE: + new_w = max(crop.width, MIN_VLM_SIZE) + new_h = max(crop.height, MIN_VLM_SIZE) + crop = crop.resize((new_w, new_h), Image.NEAREST) + + # Classifier avec VLM + element = self._classify_region( + crop, region, screenshot_path, window_context + ) + + if element and element.confidence >= self.config.confidence_threshold: + ui_elements.append(element) + + return ui_elements + def _classify_region(self, crop: Image.Image, region: BoundingBox,