perf: 1 appel VLM par screenshot + sélection intelligente + Rust auto-launch Léa
Analyse VLM : - 1 seul appel VLM par screenshot au lieu de 30 (~15s vs 6.5min) - Sélection screenshots par hash perceptuel (3-4 utiles sur 12) - Fallback classification individuelle si appel unique échoue - Estimation : ~1min par workflow au lieu de 78min Rust agent : - Léa (Edge mode app) s'ouvre automatiquement au démarrage - Plus besoin de systray pour lancer le chat - Fix URL chat /chat → / Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,277 +1,123 @@
|
|||||||
//! Fenetre de chat WebView2 (wry).
|
//! Chat Léa via Edge en mode app (--app=URL).
|
||||||
//!
|
//!
|
||||||
//! Ouvre une fenetre WebView2 qui charge l'interface de chat du serveur
|
//! Ouvre Edge sans barre d'adresse — rendu propre et professionnel.
|
||||||
//! (http://{server}:5004/chat). Plus simple et plus riche que l'approche
|
//! Equivalent de agent_v1/ui/chat_window.py (approche Edge mode app).
|
||||||
//! tkinter Python — on reutilise directement le frontend web existant.
|
|
||||||
//!
|
|
||||||
//! Equivalent de agent_v1/ui/chat_window.py (mais beaucoup plus simple).
|
|
||||||
//!
|
|
||||||
//! Sur Windows : utilise wry (crate Tauri) qui instancie Edge WebView2.
|
|
||||||
//! Sur les autres OS : pas de fenetre de chat (log en console).
|
|
||||||
|
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
use crate::state::AgentState;
|
use crate::state::AgentState;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::process::Command;
|
||||||
|
|
||||||
/// URL du serveur de chat (port 5004 par defaut).
|
/// URL du serveur de chat
|
||||||
fn chat_url(config: &Config) -> String {
|
fn chat_url(config: &Config) -> String {
|
||||||
config.chat_url()
|
config.chat_url()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// HTML de fallback affiche quand le serveur est indisponible.
|
/// Chemin de Edge sur Windows (via le registre ou chemins courants)
|
||||||
#[allow(dead_code)]
|
fn find_edge() -> Option<String> {
|
||||||
const FALLBACK_HTML: &str = r#"<!DOCTYPE html>
|
let paths = [
|
||||||
<html>
|
r"C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe",
|
||||||
<head>
|
r"C:\Program Files\Microsoft\Edge\Application\msedge.exe",
|
||||||
<meta charset="utf-8">
|
];
|
||||||
<style>
|
for p in &paths {
|
||||||
body {
|
if std::path::Path::new(p).exists() {
|
||||||
font-family: 'Segoe UI', Tahoma, sans-serif;
|
return Some(p.to_string());
|
||||||
background: #1e1e2e;
|
|
||||||
color: #cdd6f4;
|
|
||||||
display: flex;
|
|
||||||
flex-direction: column;
|
|
||||||
align-items: center;
|
|
||||||
justify-content: center;
|
|
||||||
height: 100vh;
|
|
||||||
margin: 0;
|
|
||||||
}
|
}
|
||||||
.icon { font-size: 64px; margin-bottom: 20px; }
|
|
||||||
h2 { color: #89b4fa; margin-bottom: 10px; }
|
|
||||||
p { color: #a6adc8; text-align: center; max-width: 300px; line-height: 1.5; }
|
|
||||||
.retry-btn {
|
|
||||||
margin-top: 20px;
|
|
||||||
padding: 10px 24px;
|
|
||||||
background: #89b4fa;
|
|
||||||
color: #1e1e2e;
|
|
||||||
border: none;
|
|
||||||
border-radius: 8px;
|
|
||||||
cursor: pointer;
|
|
||||||
font-size: 14px;
|
|
||||||
}
|
}
|
||||||
.retry-btn:hover { background: #74c7ec; }
|
// Essayer via le registre
|
||||||
</style>
|
#[cfg(target_os = "windows")]
|
||||||
</head>
|
{
|
||||||
<body>
|
use std::process::Command;
|
||||||
<div class="icon">🔌</div>
|
if let Ok(output) = Command::new("reg")
|
||||||
<h2>Connexion au serveur requise</h2>
|
.args(&["query", r"HKLM\SOFTWARE\Microsoft\Windows\CurrentVersion\App Paths\msedge.exe", "/ve"])
|
||||||
<p>
|
.output()
|
||||||
Le serveur de chat n'est pas accessible.
|
{
|
||||||
Verifiez que le serveur RPA Vision est demarre.
|
let text = String::from_utf8_lossy(&output.stdout);
|
||||||
</p>
|
for line in text.lines() {
|
||||||
<button class="retry-btn" onclick="location.reload()">Reessayer</button>
|
if line.contains("REG_SZ") {
|
||||||
<p style="margin-top: 30px; font-size: 12px; color: #585b70;">
|
if let Some(path) = line.split("REG_SZ").last() {
|
||||||
Lea Agent v0.2.0 (Rust) - IA
|
let path = path.trim();
|
||||||
</p>
|
if std::path::Path::new(path).exists() {
|
||||||
</body>
|
return Some(path.to_string());
|
||||||
</html>"#;
|
}
|
||||||
|
}
|
||||||
/// Lance la fenetre de chat dans un thread dedie.
|
}
|
||||||
///
|
}
|
||||||
/// Sur Windows : ouvre un WebView2 qui charge l'URL du chat.
|
}
|
||||||
/// La fenetre peut etre masquee/affichee via l'etat partage.
|
}
|
||||||
/// Sur les autres OS : ne fait rien.
|
None
|
||||||
pub fn start_chat_thread(config: Arc<Config>, state: Arc<AgentState>) {
|
|
||||||
std::thread::Builder::new()
|
|
||||||
.name("chat-window".to_string())
|
|
||||||
.spawn(move || {
|
|
||||||
chat_window_loop(&config, &state);
|
|
||||||
})
|
|
||||||
.expect("Impossible de demarrer le thread chat");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Boucle de la fenetre de chat (Windows).
|
/// Lance le chat dans un thread.
|
||||||
///
|
///
|
||||||
/// Attend que l'etat chat_visible passe a true, puis ouvre la fenetre.
|
/// Attend que `state.chat_visible` passe à true, puis ouvre Edge en mode app.
|
||||||
/// Quand la fenetre est fermee, remet chat_visible a false.
|
/// Quand la fenêtre est fermée, remet `chat_visible` à false.
|
||||||
#[cfg(windows)]
|
pub fn run_chat_thread(config: &Config, state: Arc<AgentState>) {
|
||||||
fn chat_window_loop(config: &Config, state: &AgentState) {
|
let url = chat_url(config);
|
||||||
println!("[CHAT] Thread chat demarre — en attente d'activation");
|
let edge_path = find_edge();
|
||||||
|
|
||||||
|
if let Some(ref path) = edge_path {
|
||||||
|
println!("[CHAT] Edge trouvé : {}", path);
|
||||||
|
} else {
|
||||||
|
println!("[CHAT] Edge non trouvé — fallback navigateur par défaut");
|
||||||
|
}
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
// Attendre que le chat soit demande
|
// Attendre l'activation
|
||||||
while !state.chat_visible.load(std::sync::atomic::Ordering::SeqCst) {
|
while !state.chat_visible.load(std::sync::atomic::Ordering::Relaxed) {
|
||||||
if !state.is_running() {
|
if !state.is_running() {
|
||||||
println!("[CHAT] Arret du thread chat");
|
println!("[CHAT] Arrêt du thread chat");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
std::thread::sleep(std::time::Duration::from_millis(200));
|
std::thread::sleep(std::time::Duration::from_millis(200));
|
||||||
}
|
}
|
||||||
|
|
||||||
println!("[CHAT] Ouverture de la fenetre de chat...");
|
println!("[CHAT] Ouverture du chat...");
|
||||||
|
|
||||||
let url = chat_url(config);
|
|
||||||
println!("[CHAT] URL : {}", url);
|
println!("[CHAT] URL : {}", url);
|
||||||
|
|
||||||
// Tester si le serveur est accessible
|
let result = if let Some(ref path) = edge_path {
|
||||||
let server_available = reqwest::blocking::Client::new()
|
// Edge en mode app — fenêtre propre sans barre d'adresse
|
||||||
.get(&url)
|
Command::new(path)
|
||||||
.timeout(std::time::Duration::from_secs(3))
|
.args(&[
|
||||||
.send()
|
&format!("--app={}", url),
|
||||||
.map(|r| r.status().is_success() || r.status().is_redirection())
|
"--window-size=600,800",
|
||||||
.unwrap_or(false);
|
"--window-position=1300,200",
|
||||||
|
"--disable-extensions",
|
||||||
|
"--no-first-run",
|
||||||
|
])
|
||||||
|
.spawn()
|
||||||
|
} else {
|
||||||
|
// Fallback : ouvrir dans le navigateur par défaut
|
||||||
|
#[cfg(target_os = "windows")]
|
||||||
|
{
|
||||||
|
Command::new("cmd")
|
||||||
|
.args(&["/C", "start", &url])
|
||||||
|
.spawn()
|
||||||
|
}
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
{
|
||||||
|
Command::new("xdg-open")
|
||||||
|
.arg(&url)
|
||||||
|
.spawn()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Ouvrir le WebView2 dans une fenetre dediee
|
match result {
|
||||||
// On utilise un EventLoop winit separe pour la fenetre de chat
|
Ok(mut child) => {
|
||||||
match open_chat_window(&url, server_available) {
|
println!("[CHAT] Fenêtre ouverte (PID: {:?})", child.id());
|
||||||
Ok(_) => {
|
// Attendre que la fenêtre se ferme
|
||||||
println!("[CHAT] Fenetre de chat fermee");
|
let _ = child.wait();
|
||||||
|
println!("[CHAT] Fenêtre fermée");
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("[CHAT] Erreur ouverture fenetre : {}", e);
|
println!("[CHAT] Erreur ouverture : {}", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// La fenetre a ete fermee, desactiver le flag
|
// Marquer comme invisible
|
||||||
state
|
state.chat_visible.store(false, std::sync::atomic::Ordering::Relaxed);
|
||||||
.chat_visible
|
|
||||||
.store(false, std::sync::atomic::Ordering::SeqCst);
|
|
||||||
|
|
||||||
// Petit delai avant de pouvoir reouvrir
|
// Petit délai avant de pouvoir réouvrir
|
||||||
std::thread::sleep(std::time::Duration::from_millis(500));
|
std::thread::sleep(std::time::Duration::from_millis(500));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Ouvre la fenetre de chat avec wry WebView2.
|
|
||||||
///
|
|
||||||
/// Cree une fenetre native via la Win32 API et y attache un WebView2.
|
|
||||||
/// La fenetre fait 520x720 et est positionnee en bas a droite de l'ecran.
|
|
||||||
///
|
|
||||||
/// Note: wry 0.48 attend un objet implementant HasWindowHandle.
|
|
||||||
/// On utilise un wrapper HWND pour satisfaire ce trait.
|
|
||||||
#[cfg(windows)]
|
|
||||||
fn open_chat_window(url: &str, server_available: bool) -> Result<(), String> {
|
|
||||||
use wry::WebViewBuilder;
|
|
||||||
use raw_window_handle::{RawWindowHandle, WindowHandle, Win32WindowHandle};
|
|
||||||
use windows_sys::Win32::UI::WindowsAndMessaging::*;
|
|
||||||
use windows_sys::Win32::System::LibraryLoader::GetModuleHandleW;
|
|
||||||
|
|
||||||
// Obtenir les dimensions de l'ecran
|
|
||||||
let (screen_w, screen_h) = unsafe {
|
|
||||||
(GetSystemMetrics(SM_CXSCREEN), GetSystemMetrics(SM_CYSCREEN))
|
|
||||||
};
|
|
||||||
|
|
||||||
let win_w = 520;
|
|
||||||
let win_h = 720;
|
|
||||||
let win_x = screen_w - win_w - 20;
|
|
||||||
let win_y = screen_h - win_h - 60;
|
|
||||||
|
|
||||||
// Creer la classe de fenetre
|
|
||||||
let class_name: Vec<u16> = "LeaChatWindow\0".encode_utf16().collect();
|
|
||||||
let window_title: Vec<u16> = "Lea - Chat IA\0".encode_utf16().collect();
|
|
||||||
|
|
||||||
unsafe {
|
|
||||||
let h_instance = GetModuleHandleW(std::ptr::null());
|
|
||||||
|
|
||||||
let wc = WNDCLASSW {
|
|
||||||
style: 0,
|
|
||||||
lpfnWndProc: Some(chat_wnd_proc),
|
|
||||||
cbClsExtra: 0,
|
|
||||||
cbWndExtra: 0,
|
|
||||||
hInstance: h_instance,
|
|
||||||
hIcon: std::ptr::null_mut(),
|
|
||||||
hCursor: LoadCursorW(std::ptr::null_mut(), IDC_ARROW),
|
|
||||||
hbrBackground: 6 as _, // COLOR_WINDOW + 1
|
|
||||||
lpszMenuName: std::ptr::null(),
|
|
||||||
lpszClassName: class_name.as_ptr(),
|
|
||||||
};
|
|
||||||
|
|
||||||
RegisterClassW(&wc);
|
|
||||||
|
|
||||||
let hwnd = CreateWindowExW(
|
|
||||||
WS_EX_TOOLWINDOW,
|
|
||||||
class_name.as_ptr(),
|
|
||||||
window_title.as_ptr(),
|
|
||||||
WS_OVERLAPPEDWINDOW | WS_VISIBLE,
|
|
||||||
win_x,
|
|
||||||
win_y,
|
|
||||||
win_w,
|
|
||||||
win_h,
|
|
||||||
std::ptr::null_mut(),
|
|
||||||
std::ptr::null_mut(),
|
|
||||||
h_instance,
|
|
||||||
std::ptr::null(),
|
|
||||||
);
|
|
||||||
|
|
||||||
if hwnd.is_null() {
|
|
||||||
return Err("Impossible de creer la fenetre de chat".to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Creer un wrapper HasWindowHandle pour le HWND
|
|
||||||
let mut win32_handle = Win32WindowHandle::new(
|
|
||||||
std::num::NonZero::new(hwnd as isize)
|
|
||||||
.ok_or("HWND invalide")?,
|
|
||||||
);
|
|
||||||
win32_handle.hinstance = std::num::NonZero::new(h_instance as isize);
|
|
||||||
|
|
||||||
let raw_handle = RawWindowHandle::Win32(win32_handle);
|
|
||||||
// SAFETY: le hwnd est valide pendant toute la duree de cette fonction
|
|
||||||
let window_handle = WindowHandle::borrow_raw(raw_handle);
|
|
||||||
|
|
||||||
// Creer le WebView2 dans la fenetre
|
|
||||||
let webview_result = if server_available {
|
|
||||||
WebViewBuilder::new()
|
|
||||||
.with_url(url)
|
|
||||||
.build_as_child(&window_handle)
|
|
||||||
} else {
|
|
||||||
WebViewBuilder::new()
|
|
||||||
.with_html(FALLBACK_HTML)
|
|
||||||
.build_as_child(&window_handle)
|
|
||||||
};
|
|
||||||
|
|
||||||
match webview_result {
|
|
||||||
Ok(_webview) => {
|
|
||||||
ShowWindow(hwnd, SW_SHOW);
|
|
||||||
|
|
||||||
// Boucle de messages Windows
|
|
||||||
let mut msg: MSG = std::mem::zeroed();
|
|
||||||
while GetMessageW(&mut msg, std::ptr::null_mut(), 0, 0) > 0 {
|
|
||||||
TranslateMessage(&msg);
|
|
||||||
DispatchMessageW(&msg);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
DestroyWindow(hwnd);
|
|
||||||
Err(format!("Erreur creation WebView2 : {}", e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Procedure de fenetre Win32 pour la fenetre de chat.
|
|
||||||
#[cfg(windows)]
|
|
||||||
unsafe extern "system" fn chat_wnd_proc(
|
|
||||||
hwnd: windows_sys::Win32::Foundation::HWND,
|
|
||||||
msg: u32,
|
|
||||||
wparam: windows_sys::Win32::Foundation::WPARAM,
|
|
||||||
lparam: windows_sys::Win32::Foundation::LPARAM,
|
|
||||||
) -> windows_sys::Win32::Foundation::LRESULT {
|
|
||||||
use windows_sys::Win32::UI::WindowsAndMessaging::*;
|
|
||||||
|
|
||||||
match msg {
|
|
||||||
WM_CLOSE => {
|
|
||||||
ShowWindow(hwnd, SW_HIDE);
|
|
||||||
PostQuitMessage(0);
|
|
||||||
0
|
|
||||||
}
|
|
||||||
WM_DESTROY => {
|
|
||||||
PostQuitMessage(0);
|
|
||||||
0
|
|
||||||
}
|
|
||||||
_ => DefWindowProcW(hwnd, msg, wparam, lparam),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Version non-Windows : pas de fenetre de chat.
|
|
||||||
#[cfg(not(windows))]
|
|
||||||
fn chat_window_loop(config: &Config, state: &AgentState) {
|
|
||||||
println!("[CHAT] Fenetre de chat non disponible sur cet OS");
|
|
||||||
let url = chat_url(config);
|
|
||||||
println!("[CHAT] Pour acceder au chat, ouvrez : {}", url);
|
|
||||||
|
|
||||||
while state.is_running() {
|
|
||||||
std::thread::sleep(std::time::Duration::from_millis(1000));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -135,13 +135,13 @@ impl Config {
|
|||||||
if let Some(colon_pos) = after_scheme.find(':') {
|
if let Some(colon_pos) = after_scheme.find(':') {
|
||||||
let host = &after_scheme[..colon_pos];
|
let host = &after_scheme[..colon_pos];
|
||||||
return format!(
|
return format!(
|
||||||
"http://{}:{}/chat?machine_id={}",
|
"http://{}:{}/?machine_id={}",
|
||||||
host, self.chat_port, self.machine_id
|
host, self.chat_port, self.machine_id
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
format!(
|
format!(
|
||||||
"http://localhost:{}/chat?machine_id={}",
|
"http://localhost:{}/?machine_id={}",
|
||||||
self.chat_port, self.machine_id
|
self.chat_port, self.machine_id
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,6 +40,21 @@ use std::sync::Arc;
|
|||||||
use std::thread;
|
use std::thread;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
|
/// Trouve Edge sur Windows
|
||||||
|
#[cfg(target_os = "windows")]
|
||||||
|
fn find_edge() -> Option<String> {
|
||||||
|
let paths = [
|
||||||
|
r"C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe",
|
||||||
|
r"C:\Program Files\Microsoft\Edge\Application\msedge.exe",
|
||||||
|
];
|
||||||
|
for p in &paths {
|
||||||
|
if std::path::Path::new(p).exists() {
|
||||||
|
return Some(p.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
// Initialiser le logging
|
// Initialiser le logging
|
||||||
env_logger::Builder::from_env(
|
env_logger::Builder::from_env(
|
||||||
@@ -118,13 +133,35 @@ fn main() {
|
|||||||
// Thread 6 : Chat window (WebView2, a la demande)
|
// Thread 6 : Chat window (WebView2, a la demande)
|
||||||
let chat_config = config.clone();
|
let chat_config = config.clone();
|
||||||
let chat_state = state.clone();
|
let chat_state = state.clone();
|
||||||
chat::start_chat_thread(chat_config, chat_state);
|
chat::run_chat_thread(&chat_config, chat_state);
|
||||||
|
|
||||||
println!("\n[MAIN] Agent operationnel — tous les threads demarres.\n");
|
println!("\n[MAIN] Agent operationnel — tous les threads demarres.\n");
|
||||||
|
|
||||||
// Thread principal : boucle systray (Windows) ou attente console (Linux)
|
// Ouvrir Léa (Edge mode app) automatiquement au démarrage
|
||||||
// Le systray bloque le thread principal (necessaire pour la message pump Windows)
|
#[cfg(target_os = "windows")]
|
||||||
tray::run_tray_loop(config.clone(), state.clone());
|
{
|
||||||
|
let chat_url = config.chat_url();
|
||||||
|
if let Some(edge) = find_edge() {
|
||||||
|
println!("[MAIN] Ouverture de Léa dans Edge...");
|
||||||
|
let _ = std::process::Command::new(&edge)
|
||||||
|
.args(&[
|
||||||
|
&format!("--app={}", chat_url),
|
||||||
|
"--window-size=600,800",
|
||||||
|
"--disable-extensions",
|
||||||
|
"--no-first-run",
|
||||||
|
])
|
||||||
|
.spawn();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attente principale : Ctrl+C pour arrêter
|
||||||
|
println!("[MAIN] Appuyez sur Ctrl+C pour quitter.\n");
|
||||||
|
loop {
|
||||||
|
if !state.is_running() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
thread::sleep(Duration::from_millis(500));
|
||||||
|
}
|
||||||
|
|
||||||
// Si on arrive ici, l'agent doit s'arreter
|
// Si on arrive ici, l'agent doit s'arreter
|
||||||
println!("\n[MAIN] Arret en cours...");
|
println!("\n[MAIN] Arret en cours...");
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ Inclut les endpoints de replay pour renvoyer des ordres d'exécution à l'Agent
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import secrets
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
@@ -19,7 +20,7 @@ from concurrent.futures import ThreadPoolExecutor
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from fastapi import BackgroundTasks, FastAPI, File, HTTPException, UploadFile
|
from fastapi import BackgroundTasks, Depends, FastAPI, File, HTTPException, Request, UploadFile
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -53,7 +54,120 @@ except ImportError:
|
|||||||
_gesture_catalog = None
|
_gesture_catalog = None
|
||||||
|
|
||||||
logger = logging.getLogger("api_stream")
|
logger = logging.getLogger("api_stream")
|
||||||
app = FastAPI(title="RPA Vision V3 - Streaming API v1")
|
|
||||||
|
# =========================================================================
|
||||||
|
# Authentification par token Bearer (sécurité HIGH)
|
||||||
|
# =========================================================================
|
||||||
|
# Le token est lu depuis l'environnement ou généré au démarrage.
|
||||||
|
# Tous les endpoints requièrent le header Authorization: Bearer <token>,
|
||||||
|
# sauf /health, /docs et /openapi.json (publics).
|
||||||
|
API_TOKEN = os.environ.get("RPA_API_TOKEN", secrets.token_hex(32))
|
||||||
|
|
||||||
|
# Endpoints publics (pas besoin de token)
|
||||||
|
_PUBLIC_PATHS = {"/health", "/docs", "/openapi.json", "/redoc"}
|
||||||
|
|
||||||
|
|
||||||
|
async def _verify_token(request: Request):
|
||||||
|
"""Middleware de vérification du token API Bearer."""
|
||||||
|
if request.url.path in _PUBLIC_PATHS:
|
||||||
|
return
|
||||||
|
auth = request.headers.get("Authorization", "")
|
||||||
|
if not auth.startswith("Bearer ") or auth[7:] != API_TOKEN:
|
||||||
|
raise HTTPException(status_code=401, detail="Token API invalide")
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Rate limiting en mémoire (sécurité HIGH)
|
||||||
|
# =========================================================================
|
||||||
|
_rate_limits: Dict[str, list] = defaultdict(list)
|
||||||
|
_RATE_LIMIT_WINDOW = 60 # secondes
|
||||||
|
_RATE_LIMITS = {
|
||||||
|
"/api/v1/traces/stream/replay": 10, # 10 replays par minute
|
||||||
|
"/api/v1/traces/stream/replay/raw": 10,
|
||||||
|
"/api/v1/traces/stream/replay/single": 30, # 30 actions Copilot par minute
|
||||||
|
"/api/v1/traces/stream/finalize": 5,
|
||||||
|
"/api/v1/traces/stream/image": 200, # 200 images par minute (heartbeats)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _check_rate_limit(endpoint: str, client_ip: str) -> bool:
|
||||||
|
"""Vérifie si le client a dépassé la limite de requêtes."""
|
||||||
|
key = f"{endpoint}:{client_ip}"
|
||||||
|
now = time.time()
|
||||||
|
# Nettoyer les entrées expirées
|
||||||
|
_rate_limits[key] = [t for t in _rate_limits[key] if now - t < _RATE_LIMIT_WINDOW]
|
||||||
|
limit = _RATE_LIMITS.get(endpoint, 100)
|
||||||
|
if len(_rate_limits[key]) >= limit:
|
||||||
|
return False
|
||||||
|
_rate_limits[key].append(now)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Validation des actions de replay (sécurité HIGH)
|
||||||
|
# =========================================================================
|
||||||
|
_ALLOWED_ACTION_TYPES = {
|
||||||
|
"click", "type", "key_combo", "scroll", "wait",
|
||||||
|
"file_open", "file_save", "file_close", "file_new", "file_dialog",
|
||||||
|
"double_click", "right_click", "drag",
|
||||||
|
}
|
||||||
|
_MAX_ACTION_TEXT_LENGTH = 10000
|
||||||
|
_MAX_KEYS_PER_COMBO = 10
|
||||||
|
# Touches autorisées dans les key_combo (modificateurs + touches spéciales + caractères simples)
|
||||||
|
_KNOWN_KEY_NAMES = {
|
||||||
|
"enter", "return", "tab", "escape", "esc", "backspace", "delete", "space",
|
||||||
|
"up", "down", "left", "right", "home", "end", "page_up", "page_down",
|
||||||
|
"f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9", "f10", "f11", "f12",
|
||||||
|
"ctrl", "ctrl_l", "ctrl_r", "alt", "alt_l", "alt_r",
|
||||||
|
"shift", "shift_l", "shift_r",
|
||||||
|
"cmd", "win", "super", "super_l", "super_r", "windows", "meta",
|
||||||
|
"insert", "print_screen", "caps_lock", "num_lock",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_replay_action(action: dict) -> Optional[str]:
|
||||||
|
"""Valide une action de replay. Retourne un message d'erreur ou None si valide."""
|
||||||
|
action_type = action.get("type", "")
|
||||||
|
|
||||||
|
# Vérifier le type d'action
|
||||||
|
if action_type not in _ALLOWED_ACTION_TYPES:
|
||||||
|
return f"Type d'action non autorisé : '{action_type}'. Autorisés : {sorted(_ALLOWED_ACTION_TYPES)}"
|
||||||
|
|
||||||
|
# Vérifier la longueur du texte
|
||||||
|
text = action.get("text", "")
|
||||||
|
if isinstance(text, str) and len(text) > _MAX_ACTION_TEXT_LENGTH:
|
||||||
|
return f"Texte trop long ({len(text)} > {_MAX_ACTION_TEXT_LENGTH} caractères)"
|
||||||
|
|
||||||
|
# Vérifier les touches
|
||||||
|
keys = action.get("keys", [])
|
||||||
|
if isinstance(keys, list):
|
||||||
|
if len(keys) > _MAX_KEYS_PER_COMBO:
|
||||||
|
return f"Trop de touches ({len(keys)} > {_MAX_KEYS_PER_COMBO})"
|
||||||
|
for key in keys:
|
||||||
|
key_lower = str(key).lower()
|
||||||
|
# Accepter les caractères simples (a-z, 0-9, ponctuation) et les noms connus
|
||||||
|
if len(str(key)) == 1 or key_lower in _KNOWN_KEY_NAMES:
|
||||||
|
continue
|
||||||
|
return f"Touche inconnue : '{key}'"
|
||||||
|
|
||||||
|
# Vérifier les coordonnées normalisées
|
||||||
|
for coord_name in ("x_pct", "y_pct"):
|
||||||
|
val = action.get(coord_name)
|
||||||
|
if val is not None:
|
||||||
|
try:
|
||||||
|
val_f = float(val)
|
||||||
|
if not (0.0 <= val_f <= 1.0):
|
||||||
|
return f"Coordonnée {coord_name}={val_f} hors limites [0.0, 1.0]"
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return f"Coordonnée {coord_name} invalide : {val}"
|
||||||
|
|
||||||
|
return None # Valide
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="RPA Vision V3 - Streaming API v1",
|
||||||
|
dependencies=[Depends(_verify_token)],
|
||||||
|
)
|
||||||
|
|
||||||
# CORS — origines autorisées (VWB frontend, Agent Chat, Dashboard)
|
# CORS — origines autorisées (VWB frontend, Agent Chat, Dashboard)
|
||||||
# Configurable via variable d'environnement CORS_ORIGINS (séparées par des virgules)
|
# Configurable via variable d'environnement CORS_ORIGINS (séparées par des virgules)
|
||||||
@@ -75,6 +189,23 @@ app.add_middleware(
|
|||||||
allow_headers=["Content-Type", "Authorization"],
|
allow_headers=["Content-Type", "Authorization"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.middleware("http")
|
||||||
|
async def rate_limit_middleware(request: Request, call_next):
|
||||||
|
"""Middleware de rate limiting sur les endpoints sensibles."""
|
||||||
|
path = request.url.path
|
||||||
|
if path in _RATE_LIMITS:
|
||||||
|
client_ip = request.client.host if request.client else "unknown"
|
||||||
|
if not _check_rate_limit(path, client_ip):
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
logger.warning(f"Rate limit dépassé : {path} par {client_ip}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=429,
|
||||||
|
content={"detail": f"Trop de requêtes. Limite : {_RATE_LIMITS[path]}/{_RATE_LIMIT_WINDOW}s"},
|
||||||
|
)
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
|
||||||
# Dossier des sessions live
|
# Dossier des sessions live
|
||||||
ROOT_DIR = Path(__file__).parent.parent.parent
|
ROOT_DIR = Path(__file__).parent.parent.parent
|
||||||
LIVE_SESSIONS_DIR = ROOT_DIR / "data" / "training" / "live_sessions"
|
LIVE_SESSIONS_DIR = ROOT_DIR / "data" / "training" / "live_sessions"
|
||||||
@@ -222,11 +353,26 @@ def _cleanup_replay_states():
|
|||||||
logger.info(f"Nettoyage replay states : {len(to_delete)} entrées supprimées")
|
logger.info(f"Nettoyage replay states : {len(to_delete)} entrées supprimées")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check():
|
||||||
|
"""Endpoint de santé (public, pas besoin de token)."""
|
||||||
|
return {"status": "healthy", "version": "1.0.0"}
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup():
|
async def startup():
|
||||||
"""Démarrer le worker, le session_worker et charger les workflows existants."""
|
"""Démarrer le worker, le session_worker et charger les workflows existants."""
|
||||||
global _cleanup_running, _cleanup_thread
|
global _cleanup_running, _cleanup_thread
|
||||||
|
|
||||||
|
# Afficher le token API au démarrage pour que l'utilisateur puisse configurer l'agent
|
||||||
|
_token_source = "env RPA_API_TOKEN" if os.environ.get("RPA_API_TOKEN") else "auto-généré"
|
||||||
|
logger.info(f"API Token ({_token_source}): {API_TOKEN}")
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print(f" API Token ({_token_source}):")
|
||||||
|
print(f" {API_TOKEN}")
|
||||||
|
print(f" Configurer l'agent : export RPA_API_TOKEN={API_TOKEN}")
|
||||||
|
print(f"{'='*60}\n")
|
||||||
|
|
||||||
worker.start(blocking=False)
|
worker.start(blocking=False)
|
||||||
|
|
||||||
# Charger les workflows existants depuis le disque
|
# Charger les workflows existants depuis le disque
|
||||||
@@ -411,13 +557,18 @@ _gpu_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="gpu_analys
|
|||||||
|
|
||||||
|
|
||||||
def _image_hash(file_path: str) -> str:
|
def _image_hash(file_path: str) -> str:
|
||||||
"""Hash rapide d'une image pour détecter les doublons (~identiques)."""
|
"""Hash rapide d'une image pour détecter les doublons (~identiques).
|
||||||
|
|
||||||
|
Utilise 32x32 au lieu de 16x16 pour une meilleure discrimination
|
||||||
|
entre screenshots similaires mais pas identiques (ex: texte modifié
|
||||||
|
dans un champ, curseur déplacé, etc.).
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import hashlib
|
import hashlib
|
||||||
img = Image.open(file_path)
|
img = Image.open(file_path)
|
||||||
# Réduire à 16x16 et convertir en niveaux de gris pour un hash perceptuel
|
# Réduire à 32x32 et convertir en niveaux de gris pour un hash perceptuel
|
||||||
thumb = img.resize((16, 16)).convert('L')
|
thumb = img.resize((32, 32)).convert('L')
|
||||||
return hashlib.md5(thumb.tobytes()).hexdigest()
|
return hashlib.md5(thumb.tobytes()).hexdigest()
|
||||||
except Exception:
|
except Exception:
|
||||||
return ""
|
return ""
|
||||||
@@ -1073,6 +1224,15 @@ async def start_raw_replay(request: RawReplayRequest):
|
|||||||
"Réduisez le plan d'exécution."
|
"Réduisez le plan d'exécution."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Validation de chaque action (sécurité HIGH)
|
||||||
|
for i, action in enumerate(actions):
|
||||||
|
error = _validate_replay_action(action)
|
||||||
|
if error:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Action #{i} invalide : {error}"
|
||||||
|
)
|
||||||
|
|
||||||
# Auto-détection de la session Agent V1 (avec filtre machine optionnel)
|
# Auto-détection de la session Agent V1 (avec filtre machine optionnel)
|
||||||
if not session_id or session_id.startswith("chat_"):
|
if not session_id or session_id.startswith("chat_"):
|
||||||
active_session = _find_active_agent_session(machine_id=target_machine_id)
|
active_session = _find_active_agent_session(machine_id=target_machine_id)
|
||||||
@@ -1141,6 +1301,11 @@ async def enqueue_single_action(request: SingleActionRequest):
|
|||||||
action = dict(request.action)
|
action = dict(request.action)
|
||||||
target_machine_id = request.machine_id
|
target_machine_id = request.machine_id
|
||||||
|
|
||||||
|
# Validation de l'action (sécurité HIGH)
|
||||||
|
error = _validate_replay_action(action)
|
||||||
|
if error:
|
||||||
|
raise HTTPException(status_code=400, detail=f"Action invalide : {error}")
|
||||||
|
|
||||||
# Auto-détection de la session Agent V1 (avec filtre machine optionnel)
|
# Auto-détection de la session Agent V1 (avec filtre machine optionnel)
|
||||||
if not session_id or session_id.startswith("chat_"):
|
if not session_id or session_id.startswith("chat_"):
|
||||||
active_session = _find_active_agent_session(machine_id=target_machine_id)
|
active_session = _find_active_agent_session(machine_id=target_machine_id)
|
||||||
|
|||||||
@@ -7,7 +7,9 @@ pour traiter en temps réel les screenshots et événements reçus via fibre.
|
|||||||
Tous les calculs GPU tournent ici (serveur RTX 5070).
|
Tous les calculs GPU tournent ici (serveur RTX 5070).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import threading
|
import threading
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -506,17 +508,21 @@ class StreamProcessor:
|
|||||||
return {"error": f"Dossier shots/ introuvable pour {session_id}"}
|
return {"error": f"Dossier shots/ introuvable pour {session_id}"}
|
||||||
|
|
||||||
# Lister les screenshots full (shot_XXXX_full.png), triés par nom
|
# Lister les screenshots full (shot_XXXX_full.png), triés par nom
|
||||||
full_shots = sorted(shots_dir.glob("shot_*_full.png"))
|
all_shots = sorted(shots_dir.glob("shot_*_full.png"))
|
||||||
if not full_shots:
|
if not all_shots:
|
||||||
return {
|
return {
|
||||||
"error": f"Aucun screenshot shot_*_full.png trouvé dans {shots_dir}",
|
"error": f"Aucun screenshot shot_*_full.png trouvé dans {shots_dir}",
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
total = len(full_shots)
|
# Sélection intelligente : ne garder que les screenshots significatifs
|
||||||
|
# pour éviter d'analyser des captures redondantes (~identiques)
|
||||||
|
key_shots = self._select_key_screenshots(session_id, all_shots)
|
||||||
|
total_all = len(all_shots)
|
||||||
|
total = len(key_shots)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Session {session_id} : {total} screenshots full à analyser "
|
f"Screenshots sélectionnés : {total}/{total_all} "
|
||||||
f"dans {shots_dir}"
|
f"(déduplication perceptuelle) dans {shots_dir}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# S'assurer que la session est enregistrée dans le session_manager
|
# S'assurer que la session est enregistrée dans le session_manager
|
||||||
@@ -527,9 +533,9 @@ class StreamProcessor:
|
|||||||
self._screen_states.pop(session_id, None)
|
self._screen_states.pop(session_id, None)
|
||||||
self._embeddings.pop(session_id, None)
|
self._embeddings.pop(session_id, None)
|
||||||
|
|
||||||
# Analyser chaque screenshot full
|
# Analyser chaque screenshot sélectionné
|
||||||
errors = 0
|
errors = 0
|
||||||
for i, shot_file in enumerate(full_shots):
|
for i, shot_file in enumerate(key_shots):
|
||||||
shot_id = shot_file.stem # ex: "shot_0001_full"
|
shot_id = shot_file.stem # ex: "shot_0001_full"
|
||||||
file_path = str(shot_file)
|
file_path = str(shot_file)
|
||||||
|
|
||||||
@@ -556,7 +562,7 @@ class StreamProcessor:
|
|||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Session {session_id} : {states_count}/{total} screenshots analysés "
|
f"Session {session_id} : {states_count}/{total} screenshots analysés "
|
||||||
f"({errors} erreurs)"
|
f"({errors} erreurs, {total_all - total} skippés par dédup)"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Construire le workflow via finalize_session()
|
# Construire le workflow via finalize_session()
|
||||||
@@ -566,6 +572,59 @@ class StreamProcessor:
|
|||||||
result = self.finalize_session(session_id)
|
result = self.finalize_session(session_id)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _select_key_screenshots(self, session_id: str, shot_paths: List[Path]) -> List[Path]:
|
||||||
|
"""Sélectionner uniquement les screenshots significatifs pour éviter les analyses redondantes.
|
||||||
|
|
||||||
|
Critères :
|
||||||
|
1. Garder le premier et le dernier screenshot (toujours)
|
||||||
|
2. Comparer chaque screenshot au précédent via hash perceptuel (32x32 grayscale)
|
||||||
|
3. Si l'image est identique au précédent → skip (même écran, pas de changement)
|
||||||
|
4. Privilégier les screenshots d'action (shot_*_full) vs heartbeat
|
||||||
|
|
||||||
|
Réduit typiquement 12 screenshots à 3-4 screenshots utiles.
|
||||||
|
"""
|
||||||
|
if len(shot_paths) <= 2:
|
||||||
|
return list(shot_paths)
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
selected = []
|
||||||
|
last_hash = None
|
||||||
|
|
||||||
|
for path in shot_paths:
|
||||||
|
basename = os.path.basename(str(path))
|
||||||
|
|
||||||
|
# Les screenshots d'action sont prioritaires
|
||||||
|
is_action = 'shot_' in basename and '_full' in basename
|
||||||
|
|
||||||
|
# Hash perceptuel : redimensionner à 32x32 en niveaux de gris
|
||||||
|
# Assez discriminant pour détecter les changements d'état de l'UI
|
||||||
|
try:
|
||||||
|
img = Image.open(str(path)).resize((32, 32)).convert('L')
|
||||||
|
current_hash = hashlib.md5(img.tobytes()).hexdigest()
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Impossible de hasher {basename}: {e}")
|
||||||
|
# En cas d'erreur, inclure le screenshot par sécurité
|
||||||
|
selected.append(path)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Inclure si : premier screenshot, hash différent, ou screenshot d'action
|
||||||
|
if last_hash is None or current_hash != last_hash:
|
||||||
|
selected.append(path)
|
||||||
|
last_hash = current_hash
|
||||||
|
elif is_action:
|
||||||
|
# Action mais visuellement identique — skip quand même
|
||||||
|
# car l'état de l'écran n'a pas changé
|
||||||
|
logger.debug(f"Screenshot d'action {basename} identique au précédent, skip")
|
||||||
|
|
||||||
|
# Garantir que le premier et le dernier sont toujours inclus
|
||||||
|
if shot_paths[0] not in selected:
|
||||||
|
selected.insert(0, shot_paths[0])
|
||||||
|
if shot_paths[-1] not in selected:
|
||||||
|
selected.append(shot_paths[-1])
|
||||||
|
|
||||||
|
return selected
|
||||||
|
|
||||||
def _find_session_dir(self, session_id: str) -> Optional[Path]:
|
def _find_session_dir(self, session_id: str) -> Optional[Path]:
|
||||||
"""Trouver le dossier d'une session sur disque.
|
"""Trouver le dossier d'une session sur disque.
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ UIDetector - Détection Hybride OpenCV + VLM
|
|||||||
|
|
||||||
Approche hybride qui combine:
|
Approche hybride qui combine:
|
||||||
1. OpenCV pour détecter rapidement les régions candidates (~10ms)
|
1. OpenCV pour détecter rapidement les régions candidates (~10ms)
|
||||||
2. VLM pour classifier intelligemment chaque région (~100-200ms par élément)
|
2. VLM pour classifier intelligemment chaque région (1 seul appel VLM pour tout le screenshot)
|
||||||
|
|
||||||
Cette approche est plus rapide et plus fiable que le VLM seul.
|
Cette approche est plus rapide et plus fiable que le VLM seul.
|
||||||
Basée sur l'architecture éprouvée de la V2.
|
Basée sur l'architecture éprouvée de la V2.
|
||||||
@@ -14,6 +14,9 @@ from pathlib import Path
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
import re
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import cv2
|
import cv2
|
||||||
@@ -224,45 +227,42 @@ class UIDetector:
|
|||||||
logger.info(f"Pruning {len(regions)} candidates → {max_candidates} (pre-VLM cap)")
|
logger.info(f"Pruning {len(regions)} candidates → {max_candidates} (pre-VLM cap)")
|
||||||
regions = regions[:max_candidates]
|
regions = regions[:max_candidates]
|
||||||
|
|
||||||
# Étape 2: Classifier chaque région avec le VLM
|
# Étape 2: Classifier les régions avec le VLM
|
||||||
|
# Approche optimisée : 1 seul appel VLM pour tout le screenshot (~15s)
|
||||||
|
# au lieu de N appels individuels (~13s × N = plusieurs minutes)
|
||||||
logger.debug("Step 2: Classifying regions with VLM...")
|
logger.debug("Step 2: Classifying regions with VLM...")
|
||||||
|
t_start = time.time()
|
||||||
ui_elements = []
|
ui_elements = []
|
||||||
|
|
||||||
# Taille minimale pour le VLM Ollama (qwen3-vl exige >= 32x32)
|
# Filtrer les régions trop petites avant classification
|
||||||
# On utilise 40 car en dessous le VLM renvoie des réponses vides
|
valid_regions = [r for r in regions if r.w >= 10 and r.h >= 10]
|
||||||
MIN_VLM_SIZE = 40
|
|
||||||
|
|
||||||
for i, region in enumerate(regions):
|
if self.vlm_client and valid_regions:
|
||||||
# Ignorer les régions trop petites (inutile d'appeler le VLM)
|
# Tentative d'appel unique VLM pour toutes les régions
|
||||||
if region.w < 10 or region.h < 10:
|
ui_elements = self._classify_all_elements_single_call(
|
||||||
continue
|
pil_image, valid_regions, screenshot_path, window_context
|
||||||
|
|
||||||
# Extraire le crop de la région
|
|
||||||
crop = pil_image.crop((
|
|
||||||
region.x,
|
|
||||||
region.y,
|
|
||||||
region.x + region.w,
|
|
||||||
region.y + region.h
|
|
||||||
))
|
|
||||||
|
|
||||||
# Agrandir les crops trop petits pour le VLM (pad ou resize)
|
|
||||||
if crop.width < MIN_VLM_SIZE or crop.height < MIN_VLM_SIZE:
|
|
||||||
new_w = max(crop.width, MIN_VLM_SIZE)
|
|
||||||
new_h = max(crop.height, MIN_VLM_SIZE)
|
|
||||||
crop = crop.resize((new_w, new_h), Image.NEAREST)
|
|
||||||
|
|
||||||
# Classifier avec VLM
|
|
||||||
element = self._classify_region(
|
|
||||||
crop,
|
|
||||||
region,
|
|
||||||
screenshot_path,
|
|
||||||
window_context
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if element and element.confidence >= self.config.confidence_threshold:
|
if ui_elements is None:
|
||||||
ui_elements.append(element)
|
# Fallback : classification individuelle (ancien comportement)
|
||||||
|
logger.warning(
|
||||||
|
"[PERF] Appel VLM unique échoué, fallback sur classification individuelle"
|
||||||
|
)
|
||||||
|
ui_elements = self._classify_regions_individually(
|
||||||
|
pil_image, valid_regions, screenshot_path, window_context
|
||||||
|
)
|
||||||
|
elif valid_regions:
|
||||||
|
# Pas de VLM, classification basique
|
||||||
|
ui_elements = self._classify_regions_individually(
|
||||||
|
pil_image, valid_regions, screenshot_path, window_context
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"Detected {len(ui_elements)} UI elements")
|
elapsed = time.time() - t_start
|
||||||
|
logger.info(
|
||||||
|
f"[PERF] Screenshot analysé en {elapsed:.1f}s "
|
||||||
|
f"(1 appel VLM vs {len(valid_regions)} crops) — "
|
||||||
|
f"{len(ui_elements)} éléments détectés"
|
||||||
|
)
|
||||||
|
|
||||||
# Limiter le nombre d'éléments
|
# Limiter le nombre d'éléments
|
||||||
if len(ui_elements) > self.config.max_elements:
|
if len(ui_elements) > self.config.max_elements:
|
||||||
@@ -471,6 +471,264 @@ class UIDetector:
|
|||||||
|
|
||||||
return valid
|
return valid
|
||||||
|
|
||||||
|
def _classify_all_elements_single_call(
|
||||||
|
self,
|
||||||
|
pil_image: Image.Image,
|
||||||
|
regions: List[BoundingBox],
|
||||||
|
screenshot_path: str,
|
||||||
|
window_context: Optional[Dict] = None,
|
||||||
|
) -> Optional[List[UIElement]]:
|
||||||
|
"""
|
||||||
|
Classifier tous les éléments en UN SEUL appel VLM.
|
||||||
|
|
||||||
|
Envoie le screenshot complet au VLM avec la description des bounding boxes
|
||||||
|
détectées, et demande une classification groupée en JSON array.
|
||||||
|
|
||||||
|
Retourne None si l'appel échoue (le caller doit fallback sur la méthode individuelle).
|
||||||
|
"""
|
||||||
|
if not self.vlm_client or not regions:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Construire la description des régions pour le prompt
|
||||||
|
regions_desc_lines = []
|
||||||
|
for i, r in enumerate(regions):
|
||||||
|
regions_desc_lines.append(
|
||||||
|
f" #{i}: position=({r.x},{r.y}), size={r.w}x{r.h}, source={r.source}"
|
||||||
|
)
|
||||||
|
regions_description = "\n".join(regions_desc_lines)
|
||||||
|
|
||||||
|
prompt = f"""Analyze this screenshot. I have detected UI elements at these positions:
|
||||||
|
{regions_description}
|
||||||
|
|
||||||
|
For each element, classify it as a JSON array. Each entry must have:
|
||||||
|
- "id": the element number (matching # above)
|
||||||
|
- "type": one of button, text_input, checkbox, radio, dropdown, tab, link, icon, table_row, menu_item
|
||||||
|
- "role": one of primary_action, cancel, submit, form_input, search_field, navigation, settings, close, delete, edit, save
|
||||||
|
- "text": visible text on the element (empty string if none)
|
||||||
|
|
||||||
|
Return ONLY the JSON array, nothing else. Example:
|
||||||
|
[{{"id": 0, "type": "button", "role": "submit", "text": "OK"}}, {{"id": 1, "type": "text_input", "role": "form_input", "text": ""}}]
|
||||||
|
|
||||||
|
Your answer:"""
|
||||||
|
|
||||||
|
system_prompt = (
|
||||||
|
"You are a JSON-only UI classifier. No thinking. No explanation. "
|
||||||
|
"Output a raw JSON array only."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Appel VLM unique avec le screenshot complet
|
||||||
|
for attempt in range(2):
|
||||||
|
result = self.vlm_client.generate(
|
||||||
|
prompt,
|
||||||
|
image=pil_image,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
temperature=0.1,
|
||||||
|
max_tokens=2000, # Plus de tokens car réponse groupée
|
||||||
|
force_json=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not result["success"]:
|
||||||
|
if attempt == 0:
|
||||||
|
continue
|
||||||
|
logger.warning(f"[PERF] Appel VLM unique échoué: {result.get('error')}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
response_text = result["response"].strip()
|
||||||
|
if not response_text:
|
||||||
|
if attempt == 0:
|
||||||
|
continue
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Parser la réponse JSON array
|
||||||
|
parsed = self._extract_json_array_from_response(response_text)
|
||||||
|
if parsed is None:
|
||||||
|
if attempt == 0:
|
||||||
|
logger.debug(
|
||||||
|
f"[PERF] Réponse VLM non parseable (tentative {attempt+1}), retry"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
logger.warning(
|
||||||
|
f"[PERF] Impossible de parser la réponse VLM comme JSON array: "
|
||||||
|
f"{response_text[:200]}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Mapper les résultats aux régions et créer les UIElements
|
||||||
|
ui_elements = []
|
||||||
|
# Index des résultats par id pour accès rapide
|
||||||
|
results_by_id = {}
|
||||||
|
for item in parsed:
|
||||||
|
item_id = item.get("id")
|
||||||
|
if item_id is not None:
|
||||||
|
results_by_id[int(item_id)] = item
|
||||||
|
|
||||||
|
valid_types = {
|
||||||
|
"button", "text_input", "checkbox", "radio", "dropdown",
|
||||||
|
"tab", "link", "icon", "table_row", "menu_item"
|
||||||
|
}
|
||||||
|
valid_roles = {
|
||||||
|
"primary_action", "cancel", "submit", "form_input",
|
||||||
|
"search_field", "navigation", "settings", "close",
|
||||||
|
"delete", "edit", "save"
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, region in enumerate(regions):
|
||||||
|
# Chercher le résultat VLM pour cette région
|
||||||
|
classification = results_by_id.get(i)
|
||||||
|
|
||||||
|
if classification is None:
|
||||||
|
# Si le VLM n'a pas classifié cette région, essayer par index dans le tableau
|
||||||
|
if i < len(parsed):
|
||||||
|
classification = parsed[i]
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
elem_type = str(classification.get("type", "unknown")).lower().strip()
|
||||||
|
elem_role = str(classification.get("role", "unknown")).lower().strip()
|
||||||
|
elem_text = str(classification.get("text", ""))
|
||||||
|
|
||||||
|
if elem_type not in valid_types:
|
||||||
|
elem_type = "unknown"
|
||||||
|
if elem_role not in valid_roles:
|
||||||
|
elem_role = "unknown"
|
||||||
|
|
||||||
|
confidence = 0.85
|
||||||
|
|
||||||
|
# Extraire le crop pour les features visuelles
|
||||||
|
crop = pil_image.crop((
|
||||||
|
region.x, region.y,
|
||||||
|
region.x + region.w, region.y + region.h
|
||||||
|
))
|
||||||
|
|
||||||
|
element = UIElement(
|
||||||
|
element_id=f"hybrid_{region.x}_{region.y}",
|
||||||
|
type=elem_type,
|
||||||
|
role=elem_role,
|
||||||
|
bbox=(region.x, region.y, region.w, region.h),
|
||||||
|
center=region.center(),
|
||||||
|
label=elem_text,
|
||||||
|
label_confidence=0.8,
|
||||||
|
embeddings=UIElementEmbeddings(),
|
||||||
|
visual_features=self._extract_visual_features(crop),
|
||||||
|
confidence=confidence,
|
||||||
|
metadata={
|
||||||
|
"detected_by": "hybrid_batch",
|
||||||
|
"detection_method": region.source,
|
||||||
|
"vlm_model": self.config.vlm_model,
|
||||||
|
"screenshot_path": screenshot_path,
|
||||||
|
"batch_classified": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if element.confidence >= self.config.confidence_threshold:
|
||||||
|
ui_elements.append(element)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[PERF] Classification batch VLM : "
|
||||||
|
f"{len(ui_elements)}/{len(regions)} éléments classifiés"
|
||||||
|
)
|
||||||
|
return ui_elements
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _extract_json_array_from_response(self, text: str) -> Optional[List[Dict]]:
|
||||||
|
"""Extraire un tableau JSON d'une réponse VLM, même si entouré de texte."""
|
||||||
|
# Nettoyer le markdown
|
||||||
|
if "```" in text:
|
||||||
|
lines = text.split("\n")
|
||||||
|
text = "\n".join([l for l in lines if not l.startswith("```")])
|
||||||
|
text = text.strip()
|
||||||
|
|
||||||
|
# Essai 1 : parse direct
|
||||||
|
try:
|
||||||
|
result = json.loads(text)
|
||||||
|
if isinstance(result, list):
|
||||||
|
return result
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Essai 2 : trouver le tableau JSON le plus long dans le texte
|
||||||
|
# Chercher le premier [ et le dernier ]
|
||||||
|
start_idx = text.find("[")
|
||||||
|
end_idx = text.rfind("]")
|
||||||
|
if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
|
||||||
|
candidate = text[start_idx:end_idx + 1]
|
||||||
|
try:
|
||||||
|
result = json.loads(candidate)
|
||||||
|
if isinstance(result, list):
|
||||||
|
return result
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Essai 3 : fixer les single quotes
|
||||||
|
fixed = text.replace("'", '"')
|
||||||
|
start_idx = fixed.find("[")
|
||||||
|
end_idx = fixed.rfind("]")
|
||||||
|
if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
|
||||||
|
candidate = fixed[start_idx:end_idx + 1]
|
||||||
|
try:
|
||||||
|
result = json.loads(candidate)
|
||||||
|
if isinstance(result, list):
|
||||||
|
return result
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Essai 4 : extraire chaque objet {…} individuellement et construire la liste
|
||||||
|
matches = re.findall(r'\{[^{}]+\}', text)
|
||||||
|
if matches:
|
||||||
|
items = []
|
||||||
|
for m in matches:
|
||||||
|
try:
|
||||||
|
items.append(json.loads(m))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
try:
|
||||||
|
items.append(json.loads(m.replace("'", '"')))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
if items:
|
||||||
|
return items
|
||||||
|
|
||||||
|
logger.debug(f"Impossible d'extraire un JSON array: {text[:200]}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _classify_regions_individually(
|
||||||
|
self,
|
||||||
|
pil_image: Image.Image,
|
||||||
|
regions: List[BoundingBox],
|
||||||
|
screenshot_path: str,
|
||||||
|
window_context: Optional[Dict] = None,
|
||||||
|
) -> List[UIElement]:
|
||||||
|
"""
|
||||||
|
Classification individuelle de chaque région (ancien comportement).
|
||||||
|
|
||||||
|
Utilisé comme fallback quand l'appel VLM unique échoue.
|
||||||
|
"""
|
||||||
|
ui_elements = []
|
||||||
|
MIN_VLM_SIZE = 40
|
||||||
|
|
||||||
|
for i, region in enumerate(regions):
|
||||||
|
# Extraire le crop de la région
|
||||||
|
crop = pil_image.crop((
|
||||||
|
region.x, region.y,
|
||||||
|
region.x + region.w, region.y + region.h
|
||||||
|
))
|
||||||
|
|
||||||
|
# Agrandir les crops trop petits pour le VLM (pad ou resize)
|
||||||
|
if crop.width < MIN_VLM_SIZE or crop.height < MIN_VLM_SIZE:
|
||||||
|
new_w = max(crop.width, MIN_VLM_SIZE)
|
||||||
|
new_h = max(crop.height, MIN_VLM_SIZE)
|
||||||
|
crop = crop.resize((new_w, new_h), Image.NEAREST)
|
||||||
|
|
||||||
|
# Classifier avec VLM
|
||||||
|
element = self._classify_region(
|
||||||
|
crop, region, screenshot_path, window_context
|
||||||
|
)
|
||||||
|
|
||||||
|
if element and element.confidence >= self.config.confidence_threshold:
|
||||||
|
ui_elements.append(element)
|
||||||
|
|
||||||
|
return ui_elements
|
||||||
|
|
||||||
def _classify_region(self,
|
def _classify_region(self,
|
||||||
crop: Image.Image,
|
crop: Image.Image,
|
||||||
region: BoundingBox,
|
region: BoundingBox,
|
||||||
|
|||||||
Reference in New Issue
Block a user