feat: pipeline complet MACRO/MÉSO/MICRO — Critic, Observer, Policy, Recovery, Learning, Audit Trail, TaskPlanner
Architecture 3 niveaux implémentée et testée (137 tests unitaires + 21 visuels) : MÉSO (acteur intelligent) : - P0 Critic : vérification sémantique post-action via gemma4 (replay_verifier.py) - P1 Observer : pré-analyse écran avant chaque action (api_stream.py /pre_analyze) - P2 Grounding/Policy : séparation localisation (grounding.py) et décision (policy.py) - P3 Recovery : rollback automatique Ctrl+Z/Escape/Alt+F4 (recovery.py) - P4 Learning : apprentissage runtime avec boucle de consolidation (replay_learner.py) MACRO (planificateur) : - TaskPlanner : comprend les ordres en langage naturel via gemma4 (task_planner.py) - Contexte métier TIM/CIM-10 pour les hôpitaux (domain_context.py) - Endpoint POST /api/v1/task pour l'exécution par instruction Traçabilité : - Audit trail complet avec 18 champs par action (audit_trail.py) - Endpoints GET /audit/history, /audit/summary, /audit/export (CSV) Grounding : - Fix parsing bbox_2d qwen2.5vl (pixels relatifs, pas grille 1000x1000) - Benchmarks visuels sur captures réelles (3 approches : baseline, zoom, Citrix) - Reproductibilité validée : variance < 0.008 sur 10 itérations Sécurité : - Tokens de production retirés du code source → .env.local - Secret key aléatoire si non configuré - Suppression logs qui leakent les tokens Résultats : 80% de replay (vs 12.5% avant), 100% détection visuelle Citrix JPEG Q20 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
683
tests/unit/test_audit_trail.py
Normal file
683
tests/unit/test_audit_trail.py
Normal file
@@ -0,0 +1,683 @@
|
||||
# tests/unit/test_audit_trail.py
|
||||
"""
|
||||
Tests unitaires du module Audit Trail.
|
||||
|
||||
Vérifie l'enregistrement, la recherche, l'export CSV et le résumé
|
||||
journalier des entrées d'audit.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from datetime import date, datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Importer depuis le bon chemin (agent_v0/server_v1/)
|
||||
import sys
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
from agent_v0.server_v1.audit_trail import AuditEntry, AuditTrail
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Fixtures
|
||||
# =========================================================================
|
||||
|
||||
@pytest.fixture
def audit_dir(tmp_path):
    """Temporary directory dedicated to audit files for one test."""
    directory = tmp_path.joinpath("audit")
    directory.mkdir()
    return str(directory)
|
||||
|
||||
|
||||
@pytest.fixture
def audit(audit_dir):
    """AuditTrail instance writing into the temporary audit directory."""
    trail = AuditTrail(audit_dir=audit_dir)
    return trail
|
||||
|
||||
|
||||
def _make_entry(**kwargs) -> AuditEntry:
    """Build an AuditEntry populated with representative default values.

    Any keyword argument overrides the matching default field, which lets
    each test tweak only the fields it cares about.
    """
    base = {
        "timestamp": datetime.now().isoformat(),
        "session_id": "sess_test_001",
        "action_id": "act_001",
        "user_id": "tim_dupont",
        "user_name": "Marie Dupont",
        "machine_id": "PC-TIM-01",
        "action_type": "click",
        "action_detail": "Clic sur 'Enregistrer' dans DxCare",
        "target_app": "DxCare",
        "execution_mode": "assisted",
        "result": "success",
        "resolution_method": "som_text_match",
        "critic_result": "semantic_ok",
        "recovery_action": "",
        "domain": "tim_codage",
        "workflow_id": "wf_codage_cim10",
        "workflow_name": "Codage CIM-10 séjour",
        "duration_ms": 234.5,
    }
    # Overrides win over defaults; unknown keys are passed straight through
    # to AuditEntry, exactly like the dict.update() form would.
    return AuditEntry(**{**base, **kwargs})
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests AuditEntry
|
||||
# =========================================================================
|
||||
|
||||
class TestAuditEntry:
    """Unit tests for the AuditEntry data structure."""

    def test_creation_basique(self):
        """An entry built with every field keeps the provided values."""
        made = _make_entry()
        assert made.user_id == "tim_dupont"
        assert made.action_type == "click"
        assert made.result == "success"
        assert made.duration_ms == 234.5

    def test_to_dict(self):
        """Serializing to a plain dictionary exposes the fields."""
        payload = _make_entry().to_dict()
        assert isinstance(payload, dict)
        assert payload["user_id"] == "tim_dupont"
        assert payload["domain"] == "tim_codage"
        assert payload["duration_ms"] == 234.5

    def test_from_dict(self):
        """An entry → dict → entry round-trip preserves field values."""
        original = _make_entry()
        rebuilt = AuditEntry.from_dict(original.to_dict())
        assert rebuilt.user_id == original.user_id
        assert rebuilt.action_detail == original.action_detail
        assert rebuilt.duration_ms == original.duration_ms

    def test_from_dict_ignore_unknown_keys(self):
        """Unknown keys are silently dropped (forward compatibility)."""
        raw = {"user_id": "test", "unknown_field": "valeur", "future_key": 42}
        parsed = AuditEntry.from_dict(raw)
        assert parsed.user_id == "test"
        # Extra keys must not raise.

    def test_to_dict_json_serializable(self):
        """The produced dictionary survives JSON encoding."""
        made = _make_entry(action_detail="Clic sur 'Validé' — accent français")
        encoded = json.dumps(made.to_dict(), ensure_ascii=False)
        assert "accent français" in encoded

    def test_default_values(self):
        """A bare entry exposes coherent default values."""
        blank = AuditEntry()
        assert blank.timestamp == ""
        assert blank.user_id == ""
        assert blank.duration_ms == 0.0
        assert blank.result == ""
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests AuditTrail — enregistrement et lecture
|
||||
# =========================================================================
|
||||
|
||||
class TestAuditTrailRecord:
    """Tests for recording audit entries into daily JSONL files."""

    def test_record_and_reload(self, audit, audit_dir):
        """Record one entry, then read it back straight from the file."""
        entry = _make_entry()
        audit.record(entry)

        # One file per calendar day, named audit_<ISO date>.jsonl.
        today = date.today().isoformat()
        filepath = Path(audit_dir) / f"audit_{today}.jsonl"
        assert filepath.exists()

        # Read the raw file, bypassing the query API.
        with open(filepath, "r", encoding="utf-8") as f:
            lines = f.readlines()
        assert len(lines) == 1

        data = json.loads(lines[0])
        assert data["user_id"] == "tim_dupont"
        assert data["action_detail"] == "Clic sur 'Enregistrer' dans DxCare"

    def test_record_multiple_entries(self, audit, audit_dir):
        """Several same-day entries are appended to a single file."""
        for i in range(5):
            entry = _make_entry(action_id=f"act_{i:03d}")
            audit.record(entry)

        today = date.today().isoformat()
        filepath = Path(audit_dir) / f"audit_{today}.jsonl"
        with open(filepath, "r", encoding="utf-8") as f:
            lines = f.readlines()
        assert len(lines) == 5

    def test_record_auto_timestamp(self, audit):
        """A missing timestamp is generated automatically by record()."""
        entry = _make_entry(timestamp="")
        audit.record(entry)

        # The timestamp must have been filled in.
        entries = audit.query()
        assert len(entries) == 1
        assert entries[0]["timestamp"] != ""
        # Must parse as ISO 8601 (raises ValueError otherwise).
        datetime.fromisoformat(entries[0]["timestamp"])

    def test_record_utf8_french(self, audit):
        """French accented characters survive the write/read cycle."""
        entry = _make_entry(
            action_detail="Saisie du diagnostic 'Hépatite à cytomégalovirus' — CIM-10: B25.1",
            user_name="François Müller",
            workflow_name="Codage séjour réanimation néonatale",
        )
        audit.record(entry)

        entries = audit.query()
        assert len(entries) == 1
        assert "Hépatite" in entries[0]["action_detail"]
        assert "François Müller" in entries[0]["user_name"]
        assert "néonatale" in entries[0]["workflow_name"]

    def test_record_creates_directory(self, tmp_path):
        """The audit directory is created on demand if it does not exist."""
        new_dir = str(tmp_path / "sub" / "deep" / "audit")
        audit = AuditTrail(audit_dir=new_dir)
        entry = _make_entry()
        audit.record(entry)

        assert Path(new_dir).exists()
        entries = audit.query()
        assert len(entries) == 1

    def test_record_different_dates(self, audit, audit_dir):
        """Entries with different dates land in separate daily files."""
        today = date.today()
        yesterday = today - timedelta(days=1)

        entry_today = _make_entry(timestamp=datetime.now().isoformat())
        entry_yesterday = _make_entry(
            timestamp=datetime.combine(yesterday, datetime.min.time()).isoformat(),
            action_id="act_yesterday",
        )

        audit.record(entry_today)
        audit.record(entry_yesterday)

        # One file per day, named after the entry's own date.
        file_today = Path(audit_dir) / f"audit_{today.isoformat()}.jsonl"
        file_yesterday = Path(audit_dir) / f"audit_{yesterday.isoformat()}.jsonl"
        assert file_today.exists()
        assert file_yesterday.exists()

    def test_jsonl_format(self, audit, audit_dir):
        """Every line of the file is valid standalone JSON (JSONL format)."""
        for i in range(3):
            audit.record(_make_entry(action_id=f"act_{i}"))

        today = date.today().isoformat()
        filepath = Path(audit_dir) / f"audit_{today}.jsonl"
        with open(filepath, "r", encoding="utf-8") as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                assert line, f"Ligne {line_num} vide"
                data = json.loads(line)  # Must not raise
                assert "action_id" in data
                assert "timestamp" in data
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests AuditTrail — requêtes avec filtres
|
||||
# =========================================================================
|
||||
|
||||
class TestAuditTrailQuery:
    """Tests for querying and filtering recorded entries."""

    def _seed_entries(self, audit):
        """Insert a varied set of test entries.

        Distribution used by the assertions below:
        tim_dupont → 3 entries (2 success, 1 failed), tim_martin → 2
        (1 success, 1 recovered); wf_01 → 3, wf_02 → 2; click → 3.
        """
        entries = [
            _make_entry(
                action_id="act_001",
                user_id="tim_dupont",
                result="success",
                action_type="click",
                workflow_id="wf_01",
                domain="tim_codage",
            ),
            _make_entry(
                action_id="act_002",
                user_id="tim_dupont",
                result="failed",
                action_type="type",
                workflow_id="wf_01",
                domain="generic",
            ),
            _make_entry(
                action_id="act_003",
                user_id="tim_martin",
                user_name="Jean Martin",
                result="success",
                action_type="click",
                workflow_id="wf_02",
                domain="generic",
            ),
            _make_entry(
                action_id="act_004",
                user_id="tim_martin",
                user_name="Jean Martin",
                result="recovered",
                action_type="key_combo",
                workflow_id="wf_02",
                domain="generic",
            ),
            _make_entry(
                action_id="act_005",
                user_id="tim_dupont",
                result="success",
                action_type="click",
                workflow_id="wf_01",
                domain="generic",
            ),
        ]
        for e in entries:
            audit.record(e)

    def test_query_all(self, audit):
        """A query without filters returns every entry."""
        self._seed_entries(audit)
        results = audit.query()
        assert len(results) == 5

    def test_query_by_user(self, audit):
        """Filter by user identifier."""
        self._seed_entries(audit)
        results = audit.query(user_id="tim_dupont")
        assert len(results) == 3
        assert all(r["user_id"] == "tim_dupont" for r in results)

    def test_query_by_result(self, audit):
        """Filter by result value."""
        self._seed_entries(audit)
        results = audit.query(result="success")
        assert len(results) == 3
        assert all(r["result"] == "success" for r in results)

    def test_query_by_action_type(self, audit):
        """Filter by action type."""
        self._seed_entries(audit)
        results = audit.query(action_type="click")
        assert len(results) == 3

    def test_query_by_workflow(self, audit):
        """Filter by workflow identifier."""
        self._seed_entries(audit)
        results = audit.query(workflow_id="wf_02")
        assert len(results) == 2

    def test_query_by_domain(self, audit):
        """Filter by business domain."""
        self._seed_entries(audit)
        results = audit.query(domain="tim_codage")
        assert len(results) == 1
        assert results[0]["action_id"] == "act_001"

    def test_query_by_session(self, audit):
        """Filter by session identifier."""
        self._seed_entries(audit)
        results = audit.query(session_id="sess_test_001")
        assert len(results) == 5  # All seeded entries share the same session

    def test_query_combined_filters(self, audit):
        """Multiple filters combine with AND semantics."""
        self._seed_entries(audit)
        results = audit.query(user_id="tim_dupont", result="success")
        assert len(results) == 2

    def test_query_no_match(self, audit):
        """A filter matching nothing returns an empty list."""
        self._seed_entries(audit)
        results = audit.query(user_id="tim_inexistant")
        assert len(results) == 0

    def test_query_pagination_limit(self, audit):
        """limit= caps the number of results."""
        self._seed_entries(audit)
        results = audit.query(limit=2)
        assert len(results) == 2

    def test_query_pagination_offset(self, audit):
        """offset= skips the first results of the same ordering."""
        self._seed_entries(audit)
        all_results = audit.query()
        offset_results = audit.query(offset=3)
        assert len(offset_results) == 2
        assert offset_results[0] == all_results[3]

    def test_query_sorted_by_timestamp_desc(self, audit):
        """Results come back sorted by timestamp, newest first."""
        now = datetime.now()
        for i in range(5):
            # Distinct timestamps one minute apart make the order unambiguous.
            ts = (now - timedelta(minutes=i)).isoformat()
            audit.record(_make_entry(
                timestamp=ts,
                action_id=f"act_{i}",
            ))

        results = audit.query()
        timestamps = [r["timestamp"] for r in results]
        assert timestamps == sorted(timestamps, reverse=True)

    def test_query_date_range(self, audit):
        """date_from / date_to restrict results to a date window."""
        today = date.today()
        yesterday = today - timedelta(days=1)

        # One entry dated yesterday (at midnight)...
        audit.record(_make_entry(
            timestamp=datetime.combine(yesterday, datetime.min.time()).isoformat(),
            action_id="act_yesterday",
        ))
        # ...and one dated today.
        audit.record(_make_entry(
            timestamp=datetime.now().isoformat(),
            action_id="act_today",
        ))

        # Restrict to yesterday only.
        results = audit.query(
            date_from=yesterday.isoformat(),
            date_to=yesterday.isoformat(),
        )
        assert len(results) == 1
        assert results[0]["action_id"] == "act_yesterday"

        # A two-day window returns both entries.
        results = audit.query(
            date_from=yesterday.isoformat(),
            date_to=today.isoformat(),
        )
        assert len(results) == 2
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests AuditTrail — résumé journalier
|
||||
# =========================================================================
|
||||
|
||||
class TestAuditTrailSummary:
    """Tests for the daily summary aggregation."""

    def test_summary_empty(self, audit):
        """Summary of a day with no data is all zeros / empty maps."""
        summary = audit.get_summary("2025-01-01")
        assert summary["total_actions"] == 0
        assert summary["success_rate"] == 0.0
        assert summary["by_user"] == {}

    def test_summary_basic(self, audit):
        """Totals and success rate over a few entries."""
        audit.record(_make_entry(user_id="tim_dupont", result="success"))
        audit.record(_make_entry(user_id="tim_dupont", result="failed"))
        audit.record(_make_entry(user_id="tim_martin", user_name="Jean Martin", result="success"))

        summary = audit.get_summary()
        assert summary["total_actions"] == 3
        # 2 successes out of 3, rounded to 3 decimal places.
        assert summary["success_rate"] == round(2 / 3, 3)

    def test_summary_by_user(self, audit):
        """Per-user breakdown with totals and success rates."""
        audit.record(_make_entry(user_id="tim_dupont", result="success"))
        audit.record(_make_entry(user_id="tim_dupont", result="success"))
        audit.record(_make_entry(user_id="tim_dupont", result="failed"))
        audit.record(_make_entry(user_id="tim_martin", user_name="Jean Martin", result="success"))

        summary = audit.get_summary()
        assert "tim_dupont" in summary["by_user"]
        assert summary["by_user"]["tim_dupont"]["total"] == 3
        assert summary["by_user"]["tim_dupont"]["success"] == 2
        assert summary["by_user"]["tim_dupont"]["success_rate"] == round(2 / 3, 3)
        assert summary["by_user"]["tim_martin"]["total"] == 1
        assert summary["by_user"]["tim_martin"]["success_rate"] == 1.0

    def test_summary_by_result(self, audit):
        """Breakdown by result value (success/failed/recovered)."""
        audit.record(_make_entry(result="success"))
        audit.record(_make_entry(result="success"))
        audit.record(_make_entry(result="failed"))
        audit.record(_make_entry(result="recovered"))

        summary = audit.get_summary()
        assert summary["by_result"]["success"] == 2
        assert summary["by_result"]["failed"] == 1
        assert summary["by_result"]["recovered"] == 1

    def test_summary_by_action_type(self, audit):
        """Breakdown by action type."""
        audit.record(_make_entry(action_type="click"))
        audit.record(_make_entry(action_type="click"))
        audit.record(_make_entry(action_type="type"))

        summary = audit.get_summary()
        assert summary["by_action_type"]["click"] == 2
        assert summary["by_action_type"]["type"] == 1

    def test_summary_by_workflow(self, audit):
        """Breakdown by workflow identifier."""
        audit.record(_make_entry(workflow_id="wf_01"))
        audit.record(_make_entry(workflow_id="wf_01"))
        audit.record(_make_entry(workflow_id="wf_02"))

        summary = audit.get_summary()
        assert summary["by_workflow"]["wf_01"] == 2
        assert summary["by_workflow"]["wf_02"] == 1

    def test_summary_by_execution_mode(self, audit):
        """Breakdown by execution mode (autonomous vs assisted)."""
        audit.record(_make_entry(execution_mode="autonomous"))
        audit.record(_make_entry(execution_mode="assisted"))
        audit.record(_make_entry(execution_mode="assisted"))

        summary = audit.get_summary()
        assert summary["by_execution_mode"]["autonomous"] == 1
        assert summary["by_execution_mode"]["assisted"] == 2

    def test_summary_date_field(self, audit):
        """The summary echoes back the requested date."""
        today = date.today().isoformat()
        summary = audit.get_summary(today)
        assert summary["date"] == today
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests AuditTrail — export CSV
|
||||
# =========================================================================
|
||||
|
||||
class TestAuditTrailExportCSV:
    """Tests for CSV export."""

    def test_export_csv_empty(self, audit):
        """Export with no matching data returns an empty string."""
        csv_data = audit.export_csv(date_from="2025-01-01")
        assert csv_data == ""

    def test_export_csv_basic(self, audit):
        """CSV export contains the recorded entries."""
        audit.record(_make_entry(action_id="act_001"))
        audit.record(_make_entry(action_id="act_002"))

        csv_data = audit.export_csv()
        assert csv_data
        assert "act_001" in csv_data
        assert "act_002" in csv_data

    def test_export_csv_header(self, audit):
        """The CSV header row lists the dataclass fields."""
        audit.record(_make_entry())

        csv_data = audit.export_csv()
        reader = csv.DictReader(io.StringIO(csv_data))
        fieldnames = reader.fieldnames
        assert "timestamp" in fieldnames
        assert "user_id" in fieldnames
        assert "action_detail" in fieldnames
        assert "domain" in fieldnames
        assert "duration_ms" in fieldnames

    def test_export_csv_parseable(self, audit):
        """The produced CSV parses cleanly with the csv module."""
        for i in range(5):
            audit.record(_make_entry(
                action_id=f"act_{i}",
                action_detail=f"Action {i} — avec des 'guillemets' et des, virgules",
            ))

        csv_data = audit.export_csv()
        reader = csv.DictReader(io.StringIO(csv_data))
        rows = list(reader)
        assert len(rows) == 5

        # Values survive despite embedded quotes and commas.
        for row in rows:
            assert "virgules" in row["action_detail"]

    def test_export_csv_filter_by_user(self, audit):
        """Export restricted to a single user."""
        audit.record(_make_entry(user_id="tim_dupont", action_id="act_001"))
        audit.record(_make_entry(user_id="tim_martin", action_id="act_002"))

        csv_data = audit.export_csv(user_id="tim_dupont")
        reader = csv.DictReader(io.StringIO(csv_data))
        rows = list(reader)
        assert len(rows) == 1
        assert rows[0]["user_id"] == "tim_dupont"

    def test_export_csv_utf8(self, audit):
        """CSV export handles French UTF-8 text correctly."""
        audit.record(_make_entry(
            action_detail="Saisie 'Hépatite à cytomégalovirus' — réanimation néonatale",
            user_name="François Müller",
        ))

        csv_data = audit.export_csv()
        assert "Hépatite" in csv_data
        assert "François Müller" in csv_data
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests de robustesse
|
||||
# =========================================================================
|
||||
|
||||
class TestAuditTrailRobustness:
    """Robustness and edge-case tests."""

    def test_directory_auto_creation(self, tmp_path):
        """The audit directory is created as soon as the trail is built."""
        audit_dir = str(tmp_path / "nonexistent" / "deep" / "audit")
        assert not Path(audit_dir).exists()

        audit = AuditTrail(audit_dir=audit_dir)
        assert Path(audit_dir).exists()

    def test_corrupted_jsonl_line(self, audit, audit_dir):
        """A corrupted line in the JSONL file does not break reading."""
        # Write valid entries first.
        audit.record(_make_entry(action_id="act_001"))
        audit.record(_make_entry(action_id="act_002"))

        # Inject a corrupted line directly into today's file.
        today = date.today().isoformat()
        filepath = Path(audit_dir) / f"audit_{today}.jsonl"
        with open(filepath, "a", encoding="utf-8") as f:
            f.write("{invalid json line\n")

        # Append one more valid entry after the corruption.
        audit.record(_make_entry(action_id="act_003"))

        # Reading must succeed and silently skip the corrupted line.
        entries = audit.query()
        assert len(entries) == 3  # 2 valid before + 1 valid after

    def test_empty_file(self, audit, audit_dir):
        """An empty audit file yields no entries and no crash."""
        today = date.today().isoformat()
        filepath = Path(audit_dir) / f"audit_{today}.jsonl"
        filepath.touch()  # Empty file

        entries = audit.query()
        assert len(entries) == 0

    def test_concurrent_writes(self, audit):
        """Concurrent writes succeed thanks to the internal threading lock."""
        import threading

        errors = []

        def write_entries(start):
            # Worker: record 20 entries, collecting any exception instead
            # of raising (exceptions in threads would otherwise be lost).
            try:
                for i in range(20):
                    audit.record(_make_entry(action_id=f"act_{start}_{i}"))
            except Exception as e:
                errors.append(str(e))

        threads = [
            threading.Thread(target=write_entries, args=(t,))
            for t in range(5)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert not errors, f"Erreurs concurrentes: {errors}"
        entries = audit.query(limit=200)
        assert len(entries) == 100  # 5 threads x 20 entries

    def test_query_invalid_date(self, audit):
        """Invalid date filters do not raise."""
        # Must not raise an exception.
        results = audit.query(date_from="not-a-date")
        assert isinstance(results, list)

    def test_summary_invalid_date(self, audit):
        """An invalid date passed to get_summary does not raise."""
        summary = audit.get_summary("not-a-date")
        assert summary["total_actions"] == 0

    def test_entry_all_fields_present_in_export(self, audit):
        """Every dataclass field appears as a column in the CSV export."""
        from dataclasses import fields as dc_fields
        entry = _make_entry()
        audit.record(entry)

        csv_data = audit.export_csv()
        reader = csv.DictReader(io.StringIO(csv_data))
        row = next(reader)

        # Compare the exported column set against the dataclass definition.
        expected_fields = {f.name for f in dc_fields(AuditEntry)}
        actual_fields = set(row.keys())
        assert expected_fields == actual_fields

    def test_date_range_reversed(self, audit):
        """A reversed date range (date_to < date_from) still works."""
        today = date.today()
        yesterday = today - timedelta(days=1)

        audit.record(_make_entry(
            timestamp=datetime.combine(yesterday, datetime.min.time()).isoformat(),
        ))

        # date_from > date_to → must still work.
        results = audit.query(
            date_from=today.isoformat(),
            date_to=yesterday.isoformat(),
        )
        # The implementation swaps the dates automatically.
        assert isinstance(results, list)
|
||||
530
tests/unit/test_policy_grounding_recovery_learning.py
Normal file
530
tests/unit/test_policy_grounding_recovery_learning.py
Normal file
@@ -0,0 +1,530 @@
|
||||
"""
|
||||
Tests fonctionnels pour P2 (Policy/Grounding), P3 (Recovery), P4 (Learning).
|
||||
|
||||
Vérifie que chaque module fait bien son travail :
|
||||
- Grounding : localise ou retourne NOT_FOUND (pas de décision)
|
||||
- Policy : décide RETRY/SKIP/ABORT/SUPERVISE (pas de localisation)
|
||||
- Recovery : exécute Ctrl+Z / Escape / Alt+F4 selon le contexte
|
||||
- Learning : enregistre et requête les résultats structurés
|
||||
"""
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch, PropertyMock
|
||||
|
||||
import pytest
|
||||
|
||||
_ROOT = str(Path(__file__).resolve().parents[2])
|
||||
if _ROOT not in sys.path:
|
||||
sys.path.insert(0, _ROOT)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# P2 : Grounding — localisation pure
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestGroundingEngine:
    """Tests for the grounding engine: pure localization, no decisions.

    The executor is fully mocked; each test wires the mock resolver
    return values to simulate one localization scenario.
    """

    def _make_engine(self):
        # Build a GroundingEngine around a mocked executor whose
        # screenshot capture succeeds by default.
        from agent_v0.agent_v1.core.grounding import GroundingEngine
        executor = MagicMock()
        executor._capture_screenshot_b64.return_value = "fake_b64_data"
        return GroundingEngine(executor), executor

    def test_server_found_retourne_coordonnees(self):
        """When the server finds the element, its coordinates are returned."""
        engine, executor = self._make_engine()
        executor._server_resolve_target.return_value = {
            "resolved": True, "x_pct": 0.5, "y_pct": 0.3,
            "method": "som_text", "score": 0.95,
            "matched_element": {"label": "Enregistrer"},
        }
        result = engine.locate("http://server", {"by_text": "Enregistrer"}, 0.5, 0.3, 1920, 1080)
        assert result.found is True
        assert result.x_pct == 0.5
        assert result.y_pct == 0.3
        assert result.method == "som_text"

    def test_server_not_found_cascade_template(self):
        """When the server fails, the engine cascades to template matching."""
        engine, executor = self._make_engine()
        executor._server_resolve_target.return_value = None
        executor._template_match_anchor.return_value = {
            "resolved": True, "x_pct": 0.4, "y_pct": 0.6,
            "score": 0.85,
        }
        result = engine.locate(
            "http://server",
            {"by_text": "OK", "anchor_image_base64": "abc123"},
            0.5, 0.3, 1920, 1080,
        )
        assert result.found is True
        assert result.method == "anchor_template"

    def test_toutes_strategies_echouent_retourne_not_found(self):
        """When every strategy fails, NOT_FOUND is returned."""
        engine, executor = self._make_engine()
        executor._server_resolve_target.return_value = None
        executor._template_match_anchor.return_value = None
        executor._hybrid_vlm_resolve.return_value = None
        result = engine.locate(
            "http://server",
            {"by_text": "Inexistant", "anchor_image_base64": "abc", "vlm_description": "bouton"},
            0.5, 0.3, 1920, 1080,
        )
        assert result.found is False
        assert "échoué" in result.detail

    def test_screenshot_echoue_retourne_not_found(self):
        """A failed screenshot capture means an immediate NOT_FOUND."""
        engine, executor = self._make_engine()
        executor._capture_screenshot_b64.return_value = None
        result = engine.locate("http://server", {"by_text": "OK"}, 0.5, 0.3, 1920, 1080)
        assert result.found is False
        assert "screenshot" in result.detail.lower()

    def test_strategies_custom(self):
        """The caller can restrict which strategies are used."""
        engine, executor = self._make_engine()
        executor._template_match_anchor.return_value = {
            "resolved": True, "x_pct": 0.2, "y_pct": 0.8, "score": 0.9,
        }
        # Template only, no server round-trip.
        result = engine.locate(
            "", {"anchor_image_base64": "abc"}, 0.5, 0.3, 1920, 1080,
            strategies=["template"],
        )
        assert result.found is True
        # The server resolver must NOT have been called.
        executor._server_resolve_target.assert_not_called()

    def test_grounding_result_to_dict(self):
        """GroundingResult serializes correctly to a dictionary."""
        from agent_v0.agent_v1.core.grounding import GroundingResult
        r = GroundingResult(found=True, x_pct=0.5, y_pct=0.3, method="som", score=0.9)
        d = r.to_dict()
        assert d["found"] is True
        assert d["x_pct"] == 0.5
        assert d["method"] == "som"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# P2 : Policy — décisions quand grounding échoue
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestPolicyEngine:
    """Decision matrix of the PolicyEngine when grounding fails."""

    def _make_engine(self):
        # Fresh engine wired to a fully mocked executor.
        from agent_v0.agent_v1.core.policy import PolicyEngine
        executor = MagicMock()
        return PolicyEngine(executor), executor

    def test_premier_essai_popup_fermee_retry(self):
        """First failure while a popup got closed -> RETRY."""
        from agent_v0.agent_v1.core.policy import Decision
        engine, executor = self._make_engine()
        executor._handle_popup_vlm.return_value = True  # a popup was dismissed

        verdict = engine.decide(
            action={"type": "click"},
            target_spec={"by_text": "OK"},
            retry_count=0,
        )

        assert verdict.decision == Decision.RETRY
        assert "popup" in verdict.reason.lower()

    def test_premier_essai_pas_de_popup_retry(self):
        """First failure with no popup -> still RETRY while max_retries > 0."""
        from agent_v0.agent_v1.core.policy import Decision
        engine, executor = self._make_engine()
        executor._handle_popup_vlm.return_value = False

        verdict = engine.decide(
            action={"type": "click"},
            target_spec={"by_text": "OK"},
            retry_count=0,
            max_retries=2,
        )

        assert verdict.decision == Decision.RETRY

    def test_max_retries_acteur_passer_skip(self):
        """Retries exhausted and the actor answers PASSER -> SKIP."""
        from agent_v0.agent_v1.core.policy import Decision
        engine, executor = self._make_engine()
        executor._actor_decide.return_value = "PASSER"

        verdict = engine.decide(
            action={"type": "click"},
            target_spec={"by_text": "Onglet"},
            retry_count=1,
            max_retries=1,
        )

        assert verdict.decision == Decision.SKIP

    def test_max_retries_acteur_stopper_abort(self):
        """Retries exhausted and the actor answers STOPPER -> ABORT."""
        from agent_v0.agent_v1.core.policy import Decision
        engine, executor = self._make_engine()
        executor._actor_decide.return_value = "STOPPER"

        verdict = engine.decide(
            action={"type": "click"},
            target_spec={"by_text": "X"},
            retry_count=1,
            max_retries=1,
        )

        assert verdict.decision == Decision.ABORT

    def test_max_retries_acteur_executer_supervise(self):
        """Retries exhausted and the actor answers EXECUTER -> SUPERVISE (hand back control)."""
        from agent_v0.agent_v1.core.policy import Decision
        engine, executor = self._make_engine()
        executor._actor_decide.return_value = "EXECUTER"

        verdict = engine.decide(
            action={"type": "click"},
            target_spec={"by_text": "X"},
            retry_count=1,
            max_retries=1,
        )

        assert verdict.decision == Decision.SUPERVISE

    def test_policy_decision_to_dict(self):
        """PolicyDecision serializes into a plain dict."""
        from agent_v0.agent_v1.core.policy import PolicyDecision, Decision

        serialized = PolicyDecision(decision=Decision.SKIP, reason="État atteint").to_dict()

        assert serialized["decision"] == "skip"
        assert serialized["reason"] == "État atteint"
|
||||
|
||||
# =========================================================================
|
||||
# P3 : Recovery — rollback après échec
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestRecoveryEngine:
    """Rollback strategies picked by the RecoveryEngine after a failed action."""

    def _make_engine(self):
        # Executor mock with the keyboard / screen-capture attributes Recovery touches.
        from agent_v0.agent_v1.core.recovery import RecoveryEngine
        executor = MagicMock()
        executor.keyboard = MagicMock()
        executor.sct = MagicMock()
        executor.sct.monitors = [{}, {"width": 1920, "height": 1080}]
        executor._click = MagicMock()
        return RecoveryEngine(executor), executor

    def test_popup_detectee_escape(self):
        """Critic mentions a popup -> Recovery presses Escape."""
        from agent_v0.agent_v1.core.recovery import RecoveryAction
        engine, executor = self._make_engine()

        outcome = engine.attempt(
            failed_action={"type": "click"},
            critic_detail="Une popup d'erreur est apparue",
        )

        assert outcome.action_taken == RecoveryAction.ESCAPE
        assert outcome.success is True
        # The Escape key must actually have been pressed.
        executor.keyboard.press.assert_called()

    def test_frappe_incorrecte_undo(self):
        """Text typed in the wrong place -> Recovery issues Ctrl+Z."""
        from agent_v0.agent_v1.core.recovery import RecoveryAction
        engine, executor = self._make_engine()

        outcome = engine.attempt(
            failed_action={"type": "type"},
            critic_detail="Le texte a été tapé au mauvais endroit",
        )

        assert outcome.action_taken == RecoveryAction.UNDO
        assert outcome.success is True

    def test_mauvaise_fenetre_close(self):
        """Wrong window in focus -> Recovery issues Alt+F4."""
        from agent_v0.agent_v1.core.recovery import RecoveryAction
        engine, executor = self._make_engine()

        outcome = engine.attempt(
            failed_action={"type": "click"},
            critic_detail="Mauvaise fenêtre ouverte au lieu du bloc-notes",
        )

        assert outcome.action_taken == RecoveryAction.CLOSE_WINDOW
        assert outcome.success is True

    def test_menu_ouvert_escape(self):
        """An unexpected dropdown menu opened -> Recovery presses Escape."""
        from agent_v0.agent_v1.core.recovery import RecoveryAction
        engine, executor = self._make_engine()

        outcome = engine.attempt(
            failed_action={"type": "click"},
            critic_detail="Un menu déroulant s'est ouvert",
        )

        assert outcome.action_taken == RecoveryAction.ESCAPE
        assert outcome.success is True

    def test_aucune_strategie_applicable(self):
        """No recognized failure pattern -> NONE and no success."""
        from agent_v0.agent_v1.core.recovery import RecoveryAction
        engine, executor = self._make_engine()

        outcome = engine.attempt(
            failed_action={"type": "wait"},
            critic_detail="Quelque chose d'inattendu",
        )

        assert outcome.action_taken == RecoveryAction.NONE
        assert outcome.success is False

    def test_recovery_result_to_dict(self):
        """RecoveryResult serializes into a plain dict."""
        from agent_v0.agent_v1.core.recovery import RecoveryResult, RecoveryAction

        serialized = RecoveryResult(
            action_taken=RecoveryAction.UNDO, success=True, detail="Ctrl+Z"
        ).to_dict()

        assert serialized["action_taken"] == "undo"
        assert serialized["success"] is True
||||
|
||||
|
||||
# =========================================================================
|
||||
# P4 : Learning — apprentissage runtime
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestReplayLearner:
    """Recording, querying and persistence behaviour of the ReplayLearner."""

    @pytest.fixture
    def learner(self):
        # Each test gets its own throw-away learning directory.
        tmpdir = tempfile.mkdtemp(prefix="test_learning_")
        from agent_v0.server_v1.replay_learner import ReplayLearner
        instance = ReplayLearner(learning_dir=tmpdir)
        yield instance
        shutil.rmtree(tmpdir, ignore_errors=True)

    def test_record_et_load_session(self, learner):
        """An outcome written with record() can be read back from disk."""
        from agent_v0.server_v1.replay_learner import ActionOutcome

        learner.record(ActionOutcome(
            session_id="test_session",
            action_id="act_001",
            action_type="click",
            target_description="Bouton Enregistrer",
            resolution_method="som_text",
            resolution_score=0.95,
            success=True,
        ))

        reloaded = learner.load_session("test_session")
        assert len(reloaded) == 1
        assert reloaded[0].action_id == "act_001"
        assert reloaded[0].success is True
        assert reloaded[0].resolution_method == "som_text"

    def test_record_from_replay_result(self, learner):
        """The replay-result format is converted into an ActionOutcome."""
        learner.record_from_replay_result(
            session_id="s1",
            action={"action_id": "a1", "type": "click", "target_spec": {"by_text": "OK", "window_title": "App"}},
            result={"success": True, "resolution_method": "template", "resolution_score": 0.9},
            verification={"verified": True, "semantic_verified": True, "semantic_detail": "OK"},
        )

        reloaded = learner.load_session("s1")
        assert len(reloaded) == 1
        assert reloaded[0].target_description == "OK"
        assert reloaded[0].semantic_verified is True

    def test_query_similar(self, learner):
        """query_similar() returns only the outcomes matching the description."""
        from agent_v0.server_v1.replay_learner import ActionOutcome

        samples = [
            ("Bouton Enregistrer", "som_text", True),
            ("Bouton Annuler", "template", True),
            ("Bouton Enregistrer", "vlm_direct", False),
            ("Menu Fichier", "som_text", True),
        ]
        for idx, (description, method, ok) in enumerate(samples):
            learner.record(ActionOutcome(
                session_id="s1", action_id=f"a{idx}",
                action_type="click", target_description=description,
                resolution_method=method, success=ok,
            ))

        matches = learner.query_similar(target_description="Enregistrer")
        assert len(matches) == 2
        # Every hit must actually concern "Enregistrer".
        for match in matches:
            assert "enregistrer" in match["outcome"]["target_description"].lower()

    def test_get_stats(self, learner):
        """Global statistics aggregate totals, success rate and per-method rates."""
        from agent_v0.server_v1.replay_learner import ActionOutcome

        for ok, method in [(True, "som"), (True, "som"), (False, "template"), (True, "vlm")]:
            learner.record(ActionOutcome(
                session_id="s1", action_id="a",
                action_type="click", success=ok,
                resolution_method=method,
            ))

        stats = learner.get_stats()
        assert stats["total"] == 4
        assert stats["success_rate"] == 0.75
        assert stats["methods"]["som"]["success_rate"] == 1.0
        assert stats["methods"]["template"]["success_rate"] == 0.0

    def test_gemma4_indisponible_pas_de_crash(self, learner):
        """Learning keeps working without any VLM backend."""
        from agent_v0.server_v1.replay_learner import ActionOutcome

        # Plain record with an error, no VLM involved — must not raise.
        learner.record(ActionOutcome(
            session_id="s1", action_id="a1", action_type="click",
            success=False, error="target_not_found",
        ))

        stats = learner.get_stats()
        assert stats["total"] == 1
        assert stats["success_rate"] == 0.0

    def test_fichier_jsonl_format(self, learner):
        """The JSONL file holds one valid JSON document per line."""
        from agent_v0.server_v1.replay_learner import ActionOutcome

        learner.record(ActionOutcome(
            session_id="s1", action_id="a1", action_type="click", success=True,
        ))
        learner.record(ActionOutcome(
            session_id="s1", action_id="a2", action_type="type", success=False,
        ))

        jsonl_file = learner.learning_dir / "s1.jsonl"
        assert jsonl_file.is_file()

        with open(jsonl_file) as fh:
            lines = fh.readlines()
        assert len(lines) == 2
        for raw in lines:
            record = json.loads(raw)  # each line must parse as JSON
            assert "action_id" in record
            assert "success" in record
||||
|
||||
|
||||
# =========================================================================
|
||||
# Boucle d'apprentissage : consolidation cross-workflow
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestLearningLoop:
    """Learning loop tests: past replays must improve the next ones."""

    @pytest.fixture
    def learner(self):
        # Isolated learning directory, removed after each test.
        tmpdir = tempfile.mkdtemp(prefix="test_learning_loop_")
        from agent_v0.server_v1.replay_learner import ReplayLearner
        instance = ReplayLearner(learning_dir=tmpdir)
        yield instance
        shutil.rmtree(tmpdir, ignore_errors=True)

    def test_best_strategy_apprend_du_succes(self, learner):
        """The best strategy is the one with the highest success record."""
        from agent_v0.server_v1.replay_learner import ActionOutcome

        # anchor_template fails 3 times on "Enregistrer"...
        for idx in range(3):
            learner.record(ActionOutcome(
                session_id=f"s{idx}", action_id=f"a{idx}", action_type="click",
                target_description="Enregistrer", resolution_method="anchor_template",
                success=False,
            ))
        # ...while som_text_match succeeds twice on the same target.
        for idx in range(2):
            learner.record(ActionOutcome(
                session_id=f"s{10+idx}", action_id=f"a{10+idx}", action_type="click",
                target_description="Enregistrer", resolution_method="som_text_match",
                success=True,
            ))

        assert learner.best_strategy_for("Enregistrer") == "som_text_match"

    def test_best_strategy_minimum_2_essais(self, learner):
        """A strategy needs at least 2 attempts before being recommended."""
        from agent_v0.server_v1.replay_learner import ActionOutcome

        # A single success is not enough evidence.
        learner.record(ActionOutcome(
            session_id="s1", action_id="a1", action_type="click",
            target_description="OK", resolution_method="vlm_direct",
            success=True,
        ))

        assert learner.best_strategy_for("OK") is None

    def test_best_strategy_rien_si_historique_vide(self, learner):
        """No history -> no recommendation."""
        assert learner.best_strategy_for("Inexistant") is None

    def test_consolidate_workflow_enrichit_les_actions(self, learner):
        """Consolidation injects _learned_strategy into matching target_specs."""
        from agent_v0.server_v1.replay_learner import ActionOutcome

        # History: som_text_match works for "Fichier".
        for idx in range(3):
            learner.record(ActionOutcome(
                session_id=f"s{idx}", action_id=f"a{idx}", action_type="click",
                target_description="Fichier", resolution_method="som_text_match",
                success=True,
            ))

        # Workflow containing one "Fichier" action among others.
        actions = [
            {"type": "click", "target_spec": {"by_text": "Fichier", "window_title": "Bloc-notes"}},
            {"type": "type", "text": "bonjour"},
            {"type": "click", "target_spec": {"by_text": "Inconnu"}},
        ]

        enriched_count = learner.consolidate_workflow(actions)

        assert enriched_count == 1  # only "Fichier" has history
        assert actions[0]["target_spec"]["_learned_strategy"] == "som_text_match"
        assert "_learned_strategy" not in actions[2].get("target_spec", {})

    def test_consolidation_cross_workflow(self, learner):
        """A success recorded in workflow A improves workflow B."""
        from agent_v0.server_v1.replay_learner import ActionOutcome

        # Workflow A: "Enregistrer" succeeds with grounding_vlm.
        for idx in range(3):
            learner.record(ActionOutcome(
                session_id="workflow_A", action_id=f"a{idx}", action_type="click",
                target_description="Enregistrer",
                window_title="Bloc-notes",
                resolution_method="grounding_vlm", success=True,
            ))

        # Workflow B also targets "Enregistrer".
        workflow_b = [
            {"type": "click", "target_spec": {"by_text": "Enregistrer", "window_title": "Bloc-notes"}},
        ]
        enriched_count = learner.consolidate_workflow(workflow_b, "workflow_B")

        assert enriched_count == 1
        assert workflow_b[0]["target_spec"]["_learned_strategy"] == "grounding_vlm"

    def test_grounding_reordonne_strategies(self):
        """GroundingEngine reorders its strategies based on _learned_strategy."""
        from agent_v0.agent_v1.core.grounding import GroundingEngine

        executor = MagicMock()
        executor._capture_screenshot_b64.return_value = "fake"
        # Only the template path resolves.
        executor._server_resolve_target.return_value = None
        executor._template_match_anchor.return_value = {
            "resolved": True, "x_pct": 0.5, "y_pct": 0.5, "score": 0.9,
        }
        executor._hybrid_vlm_resolve.return_value = None

        engine = GroundingEngine(executor)

        # _learned_strategy = anchor_template -> template is tried first.
        outcome = engine.locate(
            "http://server",
            {"by_text": "OK", "anchor_image_base64": "abc", "_learned_strategy": "anchor_template"},
            0.5, 0.3, 1920, 1080,
        )

        assert outcome.found is True
        assert outcome.method == "anchor_template"
        # Template came first, so the server was never contacted.
        executor._server_resolve_target.assert_not_called()
||||
441
tests/unit/test_replay_critic.py
Normal file
441
tests/unit/test_replay_critic.py
Normal file
@@ -0,0 +1,441 @@
|
||||
"""
|
||||
Tests unitaires pour le Critic (ReplayVerifier.verify_with_critic)
|
||||
et l'enrichissement des actions avec intentions.
|
||||
|
||||
Vérifie les FONCTIONNALITÉS, pas juste la non-régression :
|
||||
1. Le Critic fusionne correctement pixel + sémantique
|
||||
2. La matrice de décision (4 cas) est correcte
|
||||
3. L'enrichissement intentions parse bien les réponses gemma4
|
||||
4. Les fallbacks fonctionnent quand le VLM est indisponible
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
_ROOT = str(Path(__file__).resolve().parents[2])
|
||||
if _ROOT not in sys.path:
|
||||
sys.path.insert(0, _ROOT)
|
||||
|
||||
from agent_v0.server_v1.replay_verifier import ReplayVerifier, VerificationResult
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Fixtures
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _make_screenshot_b64(width=100, height=100, color=(128, 128, 128)):
    """Build a fake base64-encoded JPEG screenshot of the given size and color."""
    from PIL import Image

    buffer = io.BytesIO()
    Image.new("RGB", (width, height), color).save(buffer, format="JPEG", quality=50)
    return base64.b64encode(buffer.getvalue()).decode()
|
||||
|
||||
|
||||
@pytest.fixture
def verifier():
    """Fresh ReplayVerifier instance for each test."""
    return ReplayVerifier()


@pytest.fixture
def screenshot_gray():
    """Uniform mid-gray 100x100 screenshot (base64 JPEG)."""
    return _make_screenshot_b64(100, 100, (128, 128, 128))


@pytest.fixture
def screenshot_white():
    """Uniform white 100x100 screenshot (base64 JPEG)."""
    return _make_screenshot_b64(100, 100, (255, 255, 255))
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests VerificationResult — nouveaux champs sémantiques
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestVerificationResult:
    """Serialization of VerificationResult, including the semantic fields."""

    def test_to_dict_sans_semantique(self):
        """Without semantic verification, no semantic_ keys appear in the dict."""
        serialized = VerificationResult(
            verified=True, confidence=0.8, changes_detected=True,
            change_area_pct=5.0, suggestion="continue", detail="test",
        ).to_dict()

        assert "semantic_verified" not in serialized
        assert serialized["verified"] is True
        assert serialized["confidence"] == 0.8

    def test_to_dict_avec_semantique(self):
        """With semantic verification, the semantic_ keys are present."""
        serialized = VerificationResult(
            verified=True, confidence=0.9, changes_detected=True,
            change_area_pct=5.0, suggestion="continue", detail="test",
            semantic_verified=True, semantic_detail="Bouton visible",
            semantic_elapsed_ms=1500.0,
        ).to_dict()

        assert serialized["semantic_verified"] is True
        assert serialized["semantic_detail"] == "Bouton visible"
        assert serialized["semantic_elapsed_ms"] == 1500.0

    def test_to_dict_semantique_false(self):
        """semantic_verified=False must still appear in the dict (not dropped as falsy)."""
        serialized = VerificationResult(
            verified=False, confidence=0.7, changes_detected=True,
            change_area_pct=5.0, suggestion="retry",
            semantic_verified=False, semantic_detail="Mauvais écran",
            semantic_elapsed_ms=2000.0,
        ).to_dict()

        assert serialized["semantic_verified"] is False
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests verify_with_critic — matrice de décision
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestVerifyWithCritic:
    """Decision matrix of verify_with_critic (pixel check x semantic check)."""

    def test_sans_expected_result_retourne_pixel_seul(self, verifier, screenshot_gray):
        """Without expected_result, verify_with_critic degrades to pixel-only verify_action."""
        outcome = verifier.verify_with_critic(
            action={"type": "click", "action_id": "test"},
            result={"success": True},
            screenshot_before=screenshot_gray,
            screenshot_after=screenshot_gray,
            expected_result="",  # nothing expected
        )

        # Pixel-only path: no semantic field populated.
        assert outcome.semantic_verified is None

    def test_sans_screenshots_pas_de_semantique(self, verifier):
        """Without screenshots, no semantic verification is possible."""
        outcome = verifier.verify_with_critic(
            action={"type": "click", "action_id": "test"},
            result={"success": True},
            screenshot_before=None,
            screenshot_after=None,
            expected_result="Le fichier est ouvert",
        )

        # No images -> pixel-only verdict with low confidence.
        assert outcome.verified is True
        assert outcome.confidence < 0.5

    def test_pixel_pas_change_et_expected_result_skip_vlm(
        self, verifier, screenshot_gray,
    ):
        """Identical pixels + expected_result -> VLM skipped, retry suggested."""
        outcome = verifier.verify_with_critic(
            action={"type": "click", "action_id": "test", "x_pct": 0.5, "y_pct": 0.5},
            result={"success": True},
            screenshot_before=screenshot_gray,
            screenshot_after=screenshot_gray,  # same image -> no change
            expected_result="Le menu s'est ouvert",
        )

        # No pixel change -> retry, without calling the VLM at all.
        assert outcome.verified is False
        assert outcome.suggestion == "retry"
        assert outcome.semantic_verified is None  # VLM never called

    @patch("agent_v0.server_v1.replay_verifier.ReplayVerifier._verify_semantic")
    def test_pixel_ok_semantic_ok(
        self, mock_semantic, verifier, screenshot_gray, screenshot_white,
    ):
        """Pixel OK + semantic OK -> verified with high confidence."""
        mock_semantic.return_value = {
            "verified": True,
            "detail": "Le menu est bien ouvert",
            "elapsed_ms": 2000.0,
        }

        outcome = verifier.verify_with_critic(
            action={"type": "click", "action_id": "test"},
            result={"success": True},
            screenshot_before=screenshot_gray,
            screenshot_after=screenshot_white,  # differs -> change detected
            expected_result="Le menu s'est ouvert",
        )

        assert outcome.verified is True
        assert outcome.semantic_verified is True
        assert outcome.confidence >= 0.7
        assert "Critic OK" in outcome.detail

    @patch("agent_v0.server_v1.replay_verifier.ReplayVerifier._verify_semantic")
    def test_pixel_ok_semantic_non(
        self, mock_semantic, verifier, screenshot_gray, screenshot_white,
    ):
        """Pixel OK + semantic NO -> unexpected change, retry."""
        mock_semantic.return_value = {
            "verified": False,
            "detail": "Une erreur est apparue au lieu du menu",
            "elapsed_ms": 2500.0,
        }

        outcome = verifier.verify_with_critic(
            action={"type": "click", "action_id": "test"},
            result={"success": True},
            screenshot_before=screenshot_gray,
            screenshot_after=screenshot_white,
            expected_result="Le menu s'est ouvert",
        )

        assert outcome.verified is False
        assert outcome.semantic_verified is False
        assert outcome.suggestion == "retry"
        assert "Critic NON" in outcome.detail

    @patch("agent_v0.server_v1.replay_verifier.ReplayVerifier._verify_semantic")
    def test_vlm_indisponible_fallback_pixel(
        self, mock_semantic, verifier, screenshot_gray, screenshot_white,
    ):
        """VLM unavailable -> graceful fallback to the pixel-only verdict."""
        mock_semantic.return_value = None  # VLM down

        outcome = verifier.verify_with_critic(
            action={"type": "click", "action_id": "test"},
            result={"success": True},
            screenshot_before=screenshot_gray,
            screenshot_after=screenshot_white,
            expected_result="Le menu s'est ouvert",
        )

        # Pixel-only fallback — the change is still detected.
        assert outcome.verified is True
        assert outcome.semantic_verified is None  # no VLM verdict
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests _verify_semantic — parsing de la réponse VLM
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestVerifySemantic:
    """Parsing of the VLM response inside _verify_semantic."""

    @patch("requests.post")
    def test_parse_verdict_oui(self, mock_post, verifier, screenshot_white):
        """VERDICT: OUI is parsed as verified=True with its reason."""
        fake_response = MagicMock()
        fake_response.ok = True
        fake_response.json.return_value = {
            "message": {"content": "VERDICT: OUI\nRAISON: Le fichier est bien ouvert"}
        }
        mock_post.return_value = fake_response

        parsed = verifier._verify_semantic(
            screenshot_before=screenshot_white,
            screenshot_after=screenshot_white,
            expected_result="Le fichier est ouvert",
        )

        assert parsed is not None
        assert parsed["verified"] is True
        assert "ouvert" in parsed["detail"]

    @patch("requests.post")
    def test_parse_verdict_non(self, mock_post, verifier, screenshot_white):
        """VERDICT: NON is parsed as verified=False."""
        fake_response = MagicMock()
        fake_response.ok = True
        fake_response.json.return_value = {
            "message": {"content": "VERDICT: NON\nRAISON: L'écran n'a pas changé"}
        }
        mock_post.return_value = fake_response

        parsed = verifier._verify_semantic(
            screenshot_before=screenshot_white,
            screenshot_after=screenshot_white,
            expected_result="Le menu s'est ouvert",
        )

        assert parsed is not None
        assert parsed["verified"] is False

    @patch("requests.post")
    def test_vlm_timeout_retourne_none(self, mock_post, verifier, screenshot_white):
        """A VLM timeout yields None (graceful fallback, no exception)."""
        import requests as _real_requests
        mock_post.side_effect = _real_requests.Timeout("timeout")

        parsed = verifier._verify_semantic(
            screenshot_before=screenshot_white,
            screenshot_after=screenshot_white,
            expected_result="Le fichier est ouvert",
        )

        assert parsed is None

    def test_sans_screenshot_after_retourne_none(self, verifier):
        """Without an after-screenshot there is nothing to verify."""
        parsed = verifier._verify_semantic(
            screenshot_before=None,
            screenshot_after=None,
            expected_result="Le fichier est ouvert",
        )

        assert parsed is None
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests _merge_results — matrice pixel x sémantique
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestMergeResults:
    """Full pixel x semantic matrix of _merge_results."""

    def test_pixel_ok_sem_ok(self, verifier):
        """Both agree on success -> verified, confidence at least the pixel one."""
        pixel = VerificationResult(
            verified=True, confidence=0.7, changes_detected=True,
            change_area_pct=5.0, suggestion="continue",
        )
        merged = verifier._merge_results(
            pixel, {"verified": True, "detail": "OK", "elapsed_ms": 1000}
        )

        assert merged.verified is True
        assert merged.semantic_verified is True
        assert merged.confidence >= 0.7

    def test_pixel_ok_sem_non(self, verifier):
        """Pixel OK + semantic NO = unexpected change -> retry."""
        pixel = VerificationResult(
            verified=True, confidence=0.7, changes_detected=True,
            change_area_pct=5.0, suggestion="continue",
        )
        merged = verifier._merge_results(
            pixel, {"verified": False, "detail": "Erreur popup", "elapsed_ms": 2000}
        )

        assert merged.verified is False
        assert merged.semantic_verified is False
        assert merged.suggestion == "retry"

    def test_pixel_non_sem_ok(self, verifier):
        """Pixels unchanged + semantic OK = subtle state change -> continue."""
        pixel = VerificationResult(
            verified=False, confidence=0.5, changes_detected=False,
            change_area_pct=0.1, suggestion="retry",
        )
        merged = verifier._merge_results(
            pixel, {"verified": True, "detail": "Onglet déjà actif", "elapsed_ms": 1500}
        )

        assert merged.verified is True
        assert merged.semantic_verified is True
        assert merged.suggestion == "continue"

    def test_pixel_non_sem_non(self, verifier):
        """Pixels unchanged + semantic NO = complete failure -> retry, high confidence."""
        pixel = VerificationResult(
            verified=False, confidence=0.5, changes_detected=False,
            change_area_pct=0.0, suggestion="retry",
        )
        merged = verifier._merge_results(
            pixel, {"verified": False, "detail": "Rien ne s'est passé", "elapsed_ms": 3000}
        )

        assert merged.verified is False
        assert merged.semantic_verified is False
        assert merged.confidence >= 0.7  # high confidence in the failure
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests enrichissement intentions (stream_processor)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestEnrichActionsWithIntentions:
    """Tests of the intention-enrichment step (stream_processor).

    Improvement over the original: the hand-rolled ``tempfile.mkdtemp()`` +
    ``try/finally shutil.rmtree`` cleanup, repeated in every test, is replaced
    by the idiomatic ``tempfile.TemporaryDirectory()`` context manager, which
    guarantees cleanup without the boilerplate.
    """

    @patch("requests.post")
    @patch("requests.get")
    def test_enrichissement_parse_reponse_gemma4(self, mock_get, mock_post):
        """The gemma4 reply is parsed into intention / before-state / after-state."""
        from agent_v0.server_v1.stream_processor import _enrich_actions_with_intentions
        import tempfile

        # gemma4 reported as reachable.
        mock_tags_resp = MagicMock()
        mock_tags_resp.ok = True
        mock_get.return_value = mock_tags_resp

        mock_chat_resp = MagicMock()
        mock_chat_resp.ok = True
        mock_chat_resp.json.return_value = {
            "message": {
                "content": (
                    "INTENTION: Ouvrir le fichier client dans le logiciel\n"
                    "AVANT: Le logiciel est ouvert sur la page d'accueil\n"
                    "APRÈS: Le fichier client est affiché dans la fenêtre"
                )
            }
        }
        mock_post.return_value = mock_chat_resp

        actions = [
            {
                "type": "click",
                "action_id": "act_001",
                "target_spec": {"by_text": "Ouvrir", "window_title": "Logiciel"},
            },
            {
                "type": "wait",
                "action_id": "act_002",
                "duration_ms": 1000,
            },
        ]

        with tempfile.TemporaryDirectory() as td:
            tmpdir = Path(td)
            (tmpdir / "shots").mkdir()
            _enrich_actions_with_intentions(actions, tmpdir)

            # The click action must be enriched.
            assert actions[0].get("intention") == "Ouvrir le fichier client dans le logiciel"
            assert actions[0].get("expected_state") == "Le logiciel est ouvert sur la page d'accueil"
            assert actions[0].get("expected_result") == "Le fichier client est affiché dans la fenêtre"
            # expected_state must also land in target_spec (for the Observer).
            assert actions[0]["target_spec"]["expected_state"] == "Le logiciel est ouvert sur la page d'accueil"

            # The wait action must NOT be enriched.
            assert "intention" not in actions[1]

    @patch("requests.get")
    def test_gemma4_indisponible_pas_de_crash(self, mock_get):
        """If gemma4 is down, enrichment is silently disabled (no crash)."""
        from agent_v0.server_v1.stream_processor import _enrich_actions_with_intentions
        import tempfile

        mock_get.side_effect = ConnectionError("gemma4 down")

        actions = [
            {"type": "click", "action_id": "act_001", "target_spec": {"by_text": "OK"}},
        ]

        with tempfile.TemporaryDirectory() as td:
            tmpdir = Path(td)
            (tmpdir / "shots").mkdir()
            _enrich_actions_with_intentions(actions, tmpdir)
            # No crash, and no intention added either.
            assert "intention" not in actions[0]

    @patch("requests.post")
    @patch("requests.get")
    def test_reponse_gemma4_malformee(self, mock_get, mock_post):
        """Unstructured gemma4 text must not crash the parser."""
        from agent_v0.server_v1.stream_processor import _enrich_actions_with_intentions
        import tempfile

        mock_tags = MagicMock()
        mock_tags.ok = True
        mock_get.return_value = mock_tags

        mock_resp = MagicMock()
        mock_resp.ok = True
        mock_resp.json.return_value = {
            "message": {"content": "Je ne comprends pas cette demande."}
        }
        mock_post.return_value = mock_resp

        actions = [
            {"type": "click", "action_id": "act_001", "target_spec": {"by_text": "OK"}},
        ]

        with tempfile.TemporaryDirectory() as td:
            tmpdir = Path(td)
            (tmpdir / "shots").mkdir()
            _enrich_actions_with_intentions(actions, tmpdir)
            # No crash, but no intention either.
            assert "intention" not in actions[0]
|
||||
762
tests/unit/test_task_planner.py
Normal file
762
tests/unit/test_task_planner.py
Normal file
@@ -0,0 +1,762 @@
|
||||
# tests/unit/test_task_planner.py
|
||||
"""
|
||||
Tests unitaires du TaskPlanner (planificateur MACRO).
|
||||
|
||||
Vérifie :
|
||||
1. La compréhension d'ordres simples (understand)
|
||||
2. Le matching de workflows par description sémantique
|
||||
3. La détection de boucles et l'extraction de paramètres
|
||||
4. La conversion étapes → actions JSON (format correct)
|
||||
5. L'extraction de descriptions de session
|
||||
|
||||
Toutes les réponses gemma4 sont mockées pour la reproductibilité.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
_ROOT = str(Path(__file__).resolve().parents[2])
|
||||
if _ROOT not in sys.path:
|
||||
sys.path.insert(0, _ROOT)
|
||||
|
||||
from agent_v0.server_v1.task_planner import TaskPlanner, TaskPlan
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Fixtures
|
||||
# =========================================================================
|
||||
|
||||
@pytest.fixture
def planner():
    """A TaskPlanner wired to a dummy gemma4 port (never actually contacted)."""
    instance = TaskPlanner(gemma4_port="11435", domain_id="generic")
    return instance
|
||||
|
||||
|
||||
@pytest.fixture
def sample_workflows():
    """Recorded workflows the planner can match instructions against."""

    def workflow(session_id, name, description, machine, event_count):
        # One catalogue entry, in the shape the planner expects.
        return {
            "session_id": session_id,
            "name": name,
            "description": description,
            "machine": machine,
            "event_count": event_count,
        }

    return [
        workflow(
            "sess_001",
            "Bloc-notes",
            "Ouvrir Bloc-notes via Exécuter (Win+R) et écrire du texte",
            "PC-01",
            25,
        ),
        workflow(
            "sess_002",
            "Explorateur de fichiers",
            "Naviguer dans l'Explorateur de fichiers et ouvrir des images",
            "PC-01",
            40,
        ),
        workflow(
            "sess_003",
            "DxCare, Codage CIM-10",
            "Ouvrir un dossier patient dans DxCare et coder les diagnostics CIM-10",
            "PC-TIM",
            80,
        ),
    ]
|
||||
|
||||
|
||||
def _mock_gemma4_response(content: str):
|
||||
"""Créer un mock de réponse HTTP gemma4."""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.ok = True
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {
|
||||
"message": {"content": content}
|
||||
}
|
||||
return mock_resp
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests : understand — ordre simple
|
||||
# =========================================================================
|
||||
|
||||
class TestUnderstandOrdreSimple:
    """understand() must turn gemma4 replies into a usable TaskPlan."""

    def test_understand_ordre_simple(self, planner, sample_workflows):
        """'Ouvre le bloc-notes' → understood=True."""
        reply = (
            "COMPRIS: OUI\n"
            "WORKFLOW: 1\n"
            "CONFIANCE: 0.9\n"
            "PARAMETRES: AUCUN\n"
            "BOUCLE: NON\n"
            "SOURCE_BOUCLE: aucun\n"
            "PLAN:\n"
            "1. Ouvrir le Bloc-notes via Win+R\n"
            "2. Taper notepad et valider\n"
        )

        with patch("requests.post", return_value=_mock_gemma4_response(reply)):
            plan = planner.understand(
                "Ouvre le bloc-notes",
                available_workflows=sample_workflows,
            )

        assert plan.understood is True
        assert plan.instruction == "Ouvre le bloc-notes"

    def test_understand_instruction_non_comprise(self, planner):
        """Gibberish instruction → understood=False."""
        reply = "COMPRIS: NON\nWORKFLOW: AUCUN\nBOUCLE: NON\n"

        with patch("requests.post", return_value=_mock_gemma4_response(reply)):
            plan = planner.understand("xyzzy blah blah")

        assert plan.understood is False

    def test_understand_gemma4_erreur_http(self, planner):
        """An HTTP error from gemma4 populates plan.error."""
        failing = MagicMock()
        failing.ok = False
        failing.status_code = 500

        with patch("requests.post", return_value=failing):
            plan = planner.understand("Ouvre le bloc-notes")

        assert plan.understood is False
        assert "500" in plan.error

    def test_understand_gemma4_timeout(self, planner):
        """A gemma4 timeout populates plan.error."""
        import requests

        with patch("requests.post", side_effect=requests.Timeout("timeout")):
            plan = planner.understand("Ouvre le bloc-notes")

        assert plan.understood is False
        assert "erreur" in plan.error.lower() or "timeout" in plan.error.lower()
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests : matching workflow
|
||||
# =========================================================================
|
||||
|
||||
class TestUnderstandIdentifieWorkflow:
    """Workflow matching against the recorded catalogue."""

    def test_understand_identifie_workflow(self, planner, sample_workflows):
        """A matching workflow fills workflow_match / name / mode / confidence."""
        reply = (
            "COMPRIS: OUI\n"
            "WORKFLOW: 1\n"
            "CONFIANCE: 0.9\n"
            "PARAMETRES: AUCUN\n"
            "BOUCLE: NON\n"
            "SOURCE_BOUCLE: aucun\n"
            "PLAN:\n"
            "1. Lancer le Bloc-notes\n"
        )

        with patch("requests.post", return_value=_mock_gemma4_response(reply)):
            plan = planner.understand(
                "Ouvre le bloc-notes",
                available_workflows=sample_workflows,
            )

        assert plan.workflow_match == "sess_001"
        assert plan.workflow_name == "Bloc-notes"
        assert plan.mode == "replay"
        assert plan.match_confidence >= 0.8

    def test_understand_workflow_aucun_match(self, planner, sample_workflows):
        """No matching workflow → free mode."""
        reply = (
            "COMPRIS: OUI\n"
            "WORKFLOW: AUCUN\n"
            "PARAMETRES: AUCUN\n"
            "BOUCLE: NON\n"
            "SOURCE_BOUCLE: aucun\n"
            "PLAN:\n"
            "1. Ouvrir Chrome\n"
            "2. Aller sur Google\n"
        )

        with patch("requests.post", return_value=_mock_gemma4_response(reply)):
            plan = planner.understand(
                "Recherche voiture sur Google",
                available_workflows=sample_workflows,
            )

        assert plan.understood is True
        assert plan.workflow_match == ""
        assert plan.mode == "free"

    def test_understand_workflow_second_match(self, planner, sample_workflows):
        """Workflow index 2 resolves to the second catalogue entry."""
        reply = (
            "COMPRIS: OUI\n"
            "WORKFLOW: 2\n"
            "CONFIANCE: 0.85\n"
            "BOUCLE: NON\n"
            "PLAN:\n"
            "1. Ouvrir l'explorateur de fichiers\n"
        )

        with patch("requests.post", return_value=_mock_gemma4_response(reply)):
            plan = planner.understand(
                "Ouvre mes images",
                available_workflows=sample_workflows,
            )

        assert plan.workflow_match == "sess_002"
        assert plan.workflow_name == "Explorateur de fichiers"

    def test_understand_workflow_avec_description_dans_prompt(self, planner, sample_workflows):
        """The prompt sent to gemma4 embeds the workflow descriptions."""
        seen_body = {}

        def capture_post(url, json=None, **kwargs):
            # Record the outgoing request body so the prompt can be inspected.
            seen_body.update(json or {})
            return _mock_gemma4_response("COMPRIS: OUI\nWORKFLOW: AUCUN\nBOUCLE: NON\n")

        with patch("requests.post", side_effect=capture_post):
            planner.understand(
                "Ouvre le bloc-notes",
                available_workflows=sample_workflows,
            )

        prompt_content = seen_body["messages"][0]["content"]
        # Both catalogue descriptions must appear in the prompt.
        assert "Ouvrir Bloc-notes via Exécuter" in prompt_content
        assert "Naviguer dans l'Explorateur" in prompt_content
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests : détection de boucle
|
||||
# =========================================================================
|
||||
|
||||
class TestUnderstandDetecteBoucle:
    """Loop detection over 'process ALL items' style orders."""

    def test_understand_detecte_boucle(self, planner, sample_workflows):
        """'traite TOUS les dossiers' → is_loop=True."""
        reply = (
            "COMPRIS: OUI\n"
            "WORKFLOW: 3\n"
            "CONFIANCE: 0.8\n"
            "PARAMETRES: AUCUN\n"
            "BOUCLE: OUI\n"
            "SOURCE_BOUCLE: écran\n"
            "PLAN:\n"
            "1. Pour chaque dossier dans la liste\n"
            "2. Ouvrir le dossier\n"
            "3. Coder les diagnostics\n"
        )

        with patch("requests.post", return_value=_mock_gemma4_response(reply)):
            plan = planner.understand(
                "Traite TOUS les dossiers de la liste",
                available_workflows=sample_workflows,
            )

        assert plan.is_loop is True
        assert plan.loop_source == "écran"

    def test_understand_pas_de_boucle(self, planner):
        """A one-shot order → is_loop=False."""
        reply = (
            "COMPRIS: OUI\n"
            "WORKFLOW: AUCUN\n"
            "BOUCLE: NON\n"
            "SOURCE_BOUCLE: aucun\n"
            "PLAN:\n"
            "1. Ouvrir le navigateur\n"
        )

        with patch("requests.post", return_value=_mock_gemma4_response(reply)):
            plan = planner.understand("Ouvre le navigateur")

        assert plan.is_loop is False
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests : extraction de paramètres
|
||||
# =========================================================================
|
||||
|
||||
class TestUnderstandExtraitParametres:
    """Parameter extraction from the PARAMETRES section."""

    def test_understand_extrait_parametres(self, planner, sample_workflows):
        """'dossiers de janvier' → parameters carries mois=janvier."""
        reply = (
            "COMPRIS: OUI\n"
            "WORKFLOW: 3\n"
            "CONFIANCE: 0.85\n"
            "PARAMETRES: mois=janvier\n"
            "BOUCLE: OUI\n"
            "SOURCE_BOUCLE: écran\n"
            "PLAN:\n"
            "1. Filtrer les dossiers de janvier\n"
        )

        with patch("requests.post", return_value=_mock_gemma4_response(reply)):
            plan = planner.understand(
                "Traite les dossiers de janvier",
                available_workflows=sample_workflows,
            )

        assert "mois" in plan.parameters
        assert plan.parameters["mois"] == "janvier"

    def test_understand_parametres_multiples(self, planner):
        """Several parameters listed on separate bullet lines."""
        reply = (
            "COMPRIS: OUI\n"
            "WORKFLOW: AUCUN\n"
            "PARAMETRES:\n"
            "- patient=DUPONT\n"
            "- date=2026-01-15\n"
            "BOUCLE: NON\n"
            "PLAN:\n"
            "1. Rechercher le patient DUPONT\n"
        )

        with patch("requests.post", return_value=_mock_gemma4_response(reply)):
            plan = planner.understand("Cherche le dossier de DUPONT du 15 janvier")

        assert plan.parameters.get("patient") == "DUPONT"
        assert plan.parameters.get("date") == "2026-01-15"

    def test_understand_parametres_inline(self, planner):
        """Parameters inlined on the PARAMETRES: line itself."""
        reply = (
            "COMPRIS: OUI\n"
            "WORKFLOW: AUCUN\n"
            "PARAMETRES: nom=Martin, ville=Paris\n"
            "BOUCLE: NON\n"
            "PLAN:\n"
            "1. Chercher Martin à Paris\n"
        )

        with patch("requests.post", return_value=_mock_gemma4_response(reply)):
            plan = planner.understand("Cherche Martin à Paris")

        assert plan.parameters.get("nom") == "Martin"
        assert plan.parameters.get("ville") == "Paris"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests : _parse_understanding (parsing tolérant)
|
||||
# =========================================================================
|
||||
|
||||
class TestParseUnderstanding:
    """_parse_understanding must tolerate varied gemma4 formatting."""

    def test_parse_markdown_gras(self, planner):
        """**bold** markers around the keys are parsed fine."""
        seed = TaskPlan(instruction="test")
        raw = (
            "**COMPRIS:** OUI\n"
            "**WORKFLOW:** AUCUN\n"
            "**BOUCLE:** NON\n"
            "**PLAN:**\n"
            "1. Première étape\n"
        )
        parsed = planner._parse_understanding(seed, raw, [])
        assert parsed.understood is True
        assert parsed.mode == "free"

    def test_parse_confiance_pourcentage(self, planner, sample_workflows):
        """CONFIANCE: 90% → match_confidence=0.9."""
        seed = TaskPlan(instruction="test")
        raw = (
            "COMPRIS: OUI\n"
            "WORKFLOW: 1\n"
            "CONFIANCE: 90%\n"
            "BOUCLE: NON\n"
        )
        parsed = planner._parse_understanding(seed, raw, sample_workflows)
        assert parsed.match_confidence == pytest.approx(0.9)

    def test_parse_confiance_virgule(self, planner, sample_workflows):
        """CONFIANCE: 0,85 (comma decimal) → match_confidence=0.85."""
        seed = TaskPlan(instruction="test")
        raw = (
            "COMPRIS: OUI\n"
            "WORKFLOW: 1\n"
            "CONFIANCE: 0,85\n"
            "BOUCLE: NON\n"
        )
        parsed = planner._parse_understanding(seed, raw, sample_workflows)
        assert parsed.match_confidence == pytest.approx(0.85)

    def test_parse_workflow_avec_parentheses(self, planner, sample_workflows):
        """WORKFLOW: 2 (Explorateur) → index 2 extracted despite the suffix."""
        seed = TaskPlan(instruction="test")
        raw = (
            "COMPRIS: OUI\n"
            "WORKFLOW: 2 (Explorateur de fichiers)\n"
            "BOUCLE: NON\n"
        )
        parsed = planner._parse_understanding(seed, raw, sample_workflows)
        assert parsed.workflow_match == "sess_002"

    def test_parse_workflow_aucun_variantes(self, planner, sample_workflows):
        """Every spelling of 'no workflow' leaves workflow_match empty."""
        for val in ("AUCUN", "None", "N/A", "-", "NON"):
            seed = TaskPlan(instruction="test")
            raw = f"COMPRIS: OUI\nWORKFLOW: {val}\nBOUCLE: NON\n"
            parsed = planner._parse_understanding(seed, raw, sample_workflows)
            assert parsed.workflow_match == "", f"Devrait être vide pour '{val}'"

    def test_parse_etapes_tirets(self, planner):
        """Dash-prefixed steps are appended to the plan."""
        seed = TaskPlan(instruction="test")
        raw = (
            "COMPRIS: OUI\n"
            "WORKFLOW: AUCUN\n"
            "BOUCLE: NON\n"
            "PLAN:\n"
            "- Ouvrir l'application\n"
            "- Cliquer sur Fichier\n"
            "- Sauvegarder\n"
        )
        parsed = planner._parse_understanding(seed, raw, [])
        assert len(parsed.steps) == 3
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests : _steps_to_actions
|
||||
# =========================================================================
|
||||
|
||||
class TestStepsToActions:
    """Step descriptions must be converted into executable JSON actions."""

    def test_steps_to_actions_format(self, planner):
        """Generated actions carry the right type-specific fields."""
        reply = (
            '{"type": "click", "target_spec": {"by_text": "Rechercher"}}\n'
            '{"type": "type", "text": "bloc-notes"}\n'
            '{"type": "key_combo", "keys": ["enter"]}\n'
            '{"type": "wait", "duration_ms": 2000}\n'
        )

        with patch("requests.post", return_value=_mock_gemma4_response(reply)):
            produced = planner._steps_to_actions(
                [{"description": "1. Ouvrir le bloc-notes"}],
                {},
            )

        assert len(produced) == 4
        click, typing, combo, pause = produced
        assert click["type"] == "click"
        assert click["visual_mode"] is True  # injected automatically
        assert click["target_spec"]["by_text"] == "Rechercher"
        assert typing["type"] == "type"
        assert typing["text"] == "bloc-notes"
        assert combo["type"] == "key_combo"
        assert combo["keys"] == ["enter"]
        assert pause["type"] == "wait"
        assert pause["duration_ms"] == 2000

    def test_steps_to_actions_json_array(self, planner):
        """A JSON array reply is parsed correctly."""
        reply = (
            'Voici les actions :\n'
            '```json\n'
            '[\n'
            ' {"type": "click", "target_spec": {"by_text": "Fichier"}},\n'
            ' {"type": "click", "target_spec": {"by_text": "Ouvrir"}}\n'
            ']\n'
            '```\n'
        )

        with patch("requests.post", return_value=_mock_gemma4_response(reply)):
            produced = planner._steps_to_actions(
                [{"description": "1. Ouvrir un fichier"}],
                {},
            )

        assert len(produced) == 2
        assert produced[0]["target_spec"]["by_text"] == "Fichier"
        assert produced[1]["target_spec"]["by_text"] == "Ouvrir"

    def test_steps_to_actions_nested_json(self, planner):
        """Nested JSON (target_spec) survives the parsing."""
        reply = (
            '{"type": "click", "target_spec": {"by_text": "OK", "window_title": "Confirmation"}}\n'
        )

        with patch("requests.post", return_value=_mock_gemma4_response(reply)):
            produced = planner._steps_to_actions(
                [{"description": "1. Confirmer"}],
                {},
            )

        assert len(produced) == 1
        assert produced[0]["target_spec"]["window_title"] == "Confirmation"

    def test_steps_to_actions_gemma4_erreur(self, planner):
        """A gemma4 HTTP failure yields an empty action list."""
        failing = MagicMock()
        failing.ok = False

        with patch("requests.post", return_value=failing):
            produced = planner._steps_to_actions(
                [{"description": "1. Faire quelque chose"}],
                {},
            )

        assert produced == []

    def test_steps_to_actions_filtre_types_invalides(self, planner):
        """Only valid types (click, type, key_combo, wait) are kept."""
        reply = (
            '{"type": "click", "target_spec": {"by_text": "OK"}}\n'
            '{"type": "invalid_action", "foo": "bar"}\n'
            '{"type": "wait", "duration_ms": 500}\n'
            '{"not_a_type": "test"}\n'
        )

        with patch("requests.post", return_value=_mock_gemma4_response(reply)):
            produced = planner._steps_to_actions(
                [{"description": "1. Test"}],
                {},
            )

        assert len(produced) == 2
        assert produced[0]["type"] == "click"
        assert produced[1]["type"] == "wait"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests : _parse_actions_json (parsing robuste)
|
||||
# =========================================================================
|
||||
|
||||
class TestParseActionsJson:
    """Robust parsing of action JSON in assorted layouts."""

    def test_parse_json_une_par_ligne(self):
        """One JSON object per line."""
        raw = (
            '{"type": "click", "target_spec": {"by_text": "A"}}\n'
            '{"type": "type", "text": "hello"}\n'
        )
        assert len(TaskPlanner._parse_actions_json(raw)) == 2

    def test_parse_json_array(self):
        """A single JSON array."""
        raw = '[{"type": "click", "target_spec": {"by_text": "A"}}, {"type": "wait", "duration_ms": 1000}]'
        assert len(TaskPlanner._parse_actions_json(raw)) == 2

    def test_parse_json_avec_texte_autour(self):
        """JSON surrounded by free-text commentary."""
        raw = (
            "Voici les actions RPA :\n\n"
            '{"type": "click", "target_spec": {"by_text": "Envoyer"}}\n'
            "\n"
            "C'est tout.\n"
        )
        parsed = TaskPlanner._parse_actions_json(raw)
        assert len(parsed) == 1
        assert parsed[0]["target_spec"]["by_text"] == "Envoyer"

    def test_parse_json_vide(self):
        """Empty or JSON-free content → empty list."""
        assert TaskPlanner._parse_actions_json("") == []
        assert TaskPlanner._parse_actions_json("Pas de JSON ici") == []

    def test_parse_json_markdown_code_block(self):
        """JSON wrapped in a markdown code fence."""
        raw = (
            "```json\n"
            '{"type": "type", "text": "bonjour"}\n'
            "```\n"
        )
        parsed = TaskPlanner._parse_actions_json(raw)
        assert len(parsed) == 1
        assert parsed[0]["text"] == "bonjour"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests : _extract_session_description
|
||||
# =========================================================================
|
||||
|
||||
class TestExtractSessionDescription:
    """Session descriptions must read as meaningful, semantic prose."""

    def _write_events(self, tmp_path, events):
        """Dump *events* to a temporary JSONL file and return its path."""
        events_file = tmp_path / "live_events.jsonl"
        payload = "".join(json.dumps(evt, ensure_ascii=False) + "\n" for evt in events)
        with open(events_file, "w") as f:
            f.write(payload)
        return events_file

    def test_extract_session_description_bloc_notes(self, tmp_path):
        """Notepad opened through Win+R → semantic name and description."""
        recorded = [
            {"event": {"type": "key_combo", "keys": ["win", "r"],
                       "window": {"title": "Bureau"}}},
            {"event": {"type": "window_focus_change",
                       "from": {"title": "Bureau"},
                       "to": {"title": "Exécuter"}}},
            {"event": {"type": "text_input", "text": "notepad",
                       "window": {"title": "Exécuter"}}},
            {"event": {"type": "mouse_click", "button": "left",
                       "window": {"title": "Exécuter"}}},
            {"event": {"type": "window_focus_change",
                       "from": {"title": "Exécuter"},
                       "to": {"title": "Sans titre – Bloc-notes"}}},
            {"event": {"type": "text_input", "text": "Bonjour le monde",
                       "window": {"title": "Sans titre – Bloc-notes"}}},
        ]
        events_file = self._write_events(tmp_path, recorded)

        # The helper lives at module level in api_stream.
        from agent_v0.server_v1.api_stream import _extract_session_description
        info = _extract_session_description(events_file)

        assert info["event_count"] == 6
        # The description must be readable prose, not just "Bloc-notes, Exécuter".
        description = info["description"]
        assert "Bloc-notes" in description or "bloc-notes" in description.lower()
        # The name must mention the application.
        assert "Bloc-notes" in info["name"]

    def test_extract_session_description_explorateur(self, tmp_path):
        """File-explorer session → relevant description."""
        recorded = [
            {"event": {"type": "window_focus_change",
                       "from": {"title": "Bureau"},
                       "to": {"title": "Images – Explorateur de fichiers"}}},
            {"event": {"type": "mouse_click", "button": "left",
                       "window": {"title": "Images – Explorateur de fichiers"}}},
            {"event": {"type": "mouse_click", "button": "left",
                       "window": {"title": "Images – Explorateur de fichiers"}}},
            {"event": {"type": "mouse_click", "button": "left",
                       "window": {"title": "Images – Explorateur de fichiers"}}},
        ]
        events_file = self._write_events(tmp_path, recorded)

        from agent_v0.server_v1.api_stream import _extract_session_description
        info = _extract_session_description(events_file)

        assert "Explorateur" in info["name"] or "Explorateur" in info["description"]

    def test_extract_session_description_vide(self, tmp_path):
        """Empty event file → default description."""
        events_file = self._write_events(tmp_path, [])

        from agent_v0.server_v1.api_stream import _extract_session_description
        info = _extract_session_description(events_file)

        assert info["event_count"] == 0
        assert info["name"] == "Session sans nom"

    def test_extract_session_description_cmd(self, tmp_path):
        """cmd.exe session → description mentions cmd."""
        recorded = [
            {"event": {"type": "window_focus_change",
                       "from": {"title": "Bureau"},
                       "to": {"title": "C:\\Windows\\system32\\cmd.exe"}}},
            {"event": {"type": "text_input", "text": "dir",
                       "window": {"title": "C:\\Windows\\system32\\cmd.exe"}}},
            {"event": {"type": "text_input", "text": "cd documents",
                       "window": {"title": "C:\\Windows\\system32\\cmd.exe"}}},
        ]
        events_file = self._write_events(tmp_path, recorded)

        from agent_v0.server_v1.api_stream import _extract_session_description
        info = _extract_session_description(events_file)

        assert info["event_count"] == 3
        # Either field may carry the mention; check both at once.
        full = f"{info['name']} {info['description']}"
        assert "cmd" in full.lower()

    def test_extract_session_description_recherche_windows(self, tmp_path):
        """Windows search session (Win+S) → description mentions the search."""
        recorded = [
            {"event": {"type": "key_combo", "keys": ["win", "s"],
                       "window": {"title": "Bureau"}}},
            {"event": {"type": "window_focus_change",
                       "from": {"title": "Bureau"},
                       "to": {"title": "Rechercher"}}},
            {"event": {"type": "text_input", "text": "calculator",
                       "window": {"title": "Rechercher"}}},
        ]
        events_file = self._write_events(tmp_path, recorded)

        from agent_v0.server_v1.api_stream import _extract_session_description
        info = _extract_session_description(events_file)

        assert "recherche" in info["description"].lower()
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests : list_capabilities
|
||||
# =========================================================================
|
||||
|
||||
class TestListCapabilities:
    """Human-readable listing of what the agent can do."""

    def test_list_capabilities_avec_workflows(self, planner, sample_workflows):
        """With workflows → readable text including descriptions."""
        rendered = planner.list_capabilities(sample_workflows)
        assert "Léa sait faire" in rendered
        assert "Bloc-notes" in rendered

    def test_list_capabilities_sans_workflows(self, planner):
        """Without workflows → a help message instead."""
        rendered = planner.list_capabilities([])
        assert "pas encore appris" in rendered
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests : execute (mode replay et free)
|
||||
# =========================================================================
|
||||
|
||||
class TestExecute:
    """Plan execution in replay and free modes."""

    def test_execute_replay(self, planner):
        """Replay mode invokes the callback with the matched session id."""
        plan = TaskPlan(
            instruction="Ouvre le bloc-notes",
            understood=True,
            workflow_match="sess_001",
            workflow_name="Bloc-notes",
            mode="replay",
        )

        run_replay = MagicMock(return_value="replay_123")
        outcome = planner.execute(plan, replay_callback=run_replay)

        assert outcome.success is True
        run_replay.assert_called_once_with(
            session_id="sess_001",
            machine_id="default",
            params={},
        )

    def test_execute_non_compris(self, planner):
        """A plan that was not understood fails."""
        plan = TaskPlan(instruction="blah", understood=False)
        outcome = planner.execute(plan)
        assert outcome.success is False
        assert "non comprise" in outcome.summary.lower() or "non comprise" in outcome.summary

    def test_execute_sans_callback(self, planner):
        """Replay mode without a callback fails."""
        plan = TaskPlan(
            instruction="test",
            understood=True,
            workflow_match="sess_001",
            mode="replay",
        )
        outcome = planner.execute(plan, replay_callback=None)
        assert outcome.success is False
|
||||
419
tests/visual/test_grounding_benchmark.py
Normal file
419
tests/visual/test_grounding_benchmark.py
Normal file
@@ -0,0 +1,419 @@
|
||||
"""
|
||||
Benchmark de grounding — 3 approches testées en boucle.
|
||||
|
||||
Compare la robustesse et la précision de :
|
||||
1. Baseline : qwen2.5vl direct
|
||||
2. Zoom progressif : 2 passes (full → crop → re-grounding)
|
||||
3. OCR-first : docTR localise le texte, VLM seulement pour les icônes
|
||||
|
||||
Chaque approche est testée N fois sur les mêmes cibles.
|
||||
Mesure : taux de détection, variance des coordonnées, temps moyen.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
_ROOT = str(Path(__file__).resolve().parents[2])
|
||||
if _ROOT not in sys.path:
|
||||
sys.path.insert(0, _ROOT)
|
||||
|
||||
_SHOTS_DIR = Path(_ROOT) / "data/training/live_sessions/DESKTOP-ST3VBSD_windows/sess_20260404T135010_cec5c8/shots"
|
||||
|
||||
# Nombre d'itérations par test
|
||||
N_ITERATIONS = 5
|
||||
|
||||
|
||||
def _load_screenshot(name: str) -> str:
    """Return the named benchmark screenshot as base64, or skip the test."""
    shot = _SHOTS_DIR / name
    if not shot.is_file():
        pytest.skip(f"Screenshot {name} non disponible")
    raw_bytes = shot.read_bytes()
    return base64.b64encode(raw_bytes).decode()
|
||||
|
||||
|
||||
def _load_screenshot_pil(name: str):
    """Open a test screenshot as a PIL Image (skip the test if absent)."""
    from PIL import Image

    shot_path = _SHOTS_DIR / name
    if shot_path.is_file():
        return Image.open(shot_path)
    pytest.skip(f"Screenshot {name} non disponible")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Approche 1 : Baseline qwen2.5vl direct
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _parse_bbox_2d(content: str) -> Optional[Tuple[int, int, int, int]]:
|
||||
"""Parser les coordonnées bbox_2d depuis une réponse qwen2.5vl.
|
||||
|
||||
qwen2.5vl retourne du JSON :
|
||||
```json
|
||||
[{"bbox_2d": [x1, y1, x2, y2], "label": "..."}]
|
||||
```
|
||||
Les coordonnées sont en pixels relatifs à l'image envoyée.
|
||||
"""
|
||||
# Stratégie 1 : parser le JSON complet (le plus fiable)
|
||||
# Nettoyer les fences markdown
|
||||
cleaned = re.sub(r'```(?:json)?\s*', '', content).strip()
|
||||
try:
|
||||
data = json.loads(cleaned)
|
||||
if isinstance(data, list) and len(data) > 0:
|
||||
bbox = data[0].get("bbox_2d")
|
||||
if bbox and len(bbox) >= 4:
|
||||
return (int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]))
|
||||
elif isinstance(data, dict):
|
||||
bbox = data.get("bbox_2d")
|
||||
if bbox and len(bbox) >= 4:
|
||||
return (int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]))
|
||||
except (json.JSONDecodeError, ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# Stratégie 2 : regex ciblé sur "bbox_2d": [x1, y1, x2, y2]
|
||||
bbox_match = re.search(
|
||||
r'"bbox_2d"\s*:\s*\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]',
|
||||
content,
|
||||
)
|
||||
if bbox_match:
|
||||
return tuple(int(bbox_match.group(i)) for i in range(1, 5))
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def grounding_baseline(screenshot_b64: str, description: str, img_width: int = 1280, img_height: int = 800) -> Optional[Tuple[float, float]]:
    """Direct qwen2.5vl grounding — returns normalised (x_pct, y_pct).

    qwen2.5vl reports bbox_2d in pixels relative to the submitted image;
    the box centre is normalised by the image dimensions.  Returns None
    on any failure (HTTP error, unparsable answer, out-of-range centre).
    """
    import requests

    payload = {
        "model": "qwen2.5vl:7b",
        "messages": [
            {
                "role": "user",
                "content": f"Detect '{description}' with a bounding box.",
                "images": [screenshot_b64],
            }
        ],
        "stream": False,
        "options": {"temperature": 0.0, "num_predict": 100},
    }
    try:
        resp = requests.post("http://localhost:11434/api/chat", json=payload, timeout=30)
        if not resp.ok:
            return None
        answer = resp.json().get("message", {}).get("content", "")
        box = _parse_bbox_2d(answer)
        if box is None:
            return None
        x1, y1, x2, y2 = box
        # Box centre in pixels → 0-1 fractions of the image.
        cx = (x1 + x2) / 2 / img_width
        cy = (y1 + y2) / 2 / img_height
        if 0.0 <= cx <= 1.0 and 0.0 <= cy <= 1.0:
            return (cx, cy)
    except Exception:
        # Best-effort probe: any network/parsing issue counts as "not found".
        pass
    return None
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Approche 2 : Zoom progressif (2 passes)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def grounding_zoom(screenshot_b64: str, description: str, img_width: int = 1280, img_height: int = 800) -> Optional[Tuple[float, float]]:
    """Progressive zoom — pass 1 (full image) then pass 2 (2x crop).

    Pass 1 localises the target on the full screenshot; pass 2 crops a
    window around that hit and re-grounds on the crop for a finer fix.
    Falls back to the pass-1 result whenever pass 2 fails.
    """
    import requests
    # NOTE(review): requests appears unused here — grounding_baseline
    # performs the HTTP call; confirm before removing.
    from PIL import Image

    # Pass 1: grounding on the full image.
    result1 = grounding_baseline(screenshot_b64, description, img_width, img_height)
    if result1 is None:
        return None

    x1_pct, y1_pct = result1

    # Pass 2: crop around the located area, then re-ground.
    try:
        img_bytes = base64.b64decode(screenshot_b64)
        img = Image.open(io.BytesIO(img_bytes))
        w, h = img.size

        # 2x crop around the located point (25% of the image on each side).
        crop_size = 0.25
        cx_px = int(x1_pct * w)
        cy_px = int(y1_pct * h)
        x_left = max(0, cx_px - int(crop_size * w))
        y_top = max(0, cy_px - int(crop_size * h))
        x_right = min(w, cx_px + int(crop_size * w))
        y_bottom = min(h, cy_px + int(crop_size * h))

        cropped = img.crop((x_left, y_top, x_right, y_bottom))
        crop_w, crop_h = cropped.size

        # Encode the crop as base64 (JPEG keeps the payload small).
        buf = io.BytesIO()
        cropped.save(buf, format="JPEG", quality=85)
        crop_b64 = base64.b64encode(buf.getvalue()).decode()

        # Pass 2: re-ground on the crop (using the crop's own dimensions).
        result2 = grounding_baseline(crop_b64, description, crop_w, crop_h)
        if result2 is None:
            return result1  # Fall back to pass 1

        # Map crop-relative coordinates back into the original image frame.
        x2_in_crop, y2_in_crop = result2
        x_final = (x_left + x2_in_crop * crop_w) / w
        y_final = (y_top + y2_in_crop * crop_h) / h
        return (x_final, y_final)

    except Exception:
        return result1  # Fallback to the pass-1 hit
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Approche 3 : OCR-first (docTR)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def grounding_ocr_first(screenshot_b64: str, description: str) -> Optional[Tuple[float, float]]:
    """OCR-first — docTR localises text; the VLM handles icon targets.

    Runs docTR OCR over the screenshot and fuzzy-matches the target text
    against the recognised words.  Falls back to the VLM baseline when
    docTR is unavailable, errors out, or no good-enough match is found.
    """
    try:
        from doctr.io import DocumentFile
        from doctr.models import ocr_predictor

        # Decode the image.
        img_bytes = base64.b64decode(screenshot_b64)

        # OCR pass over the full screenshot.
        predictor = ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)
        doc = DocumentFile.from_images([img_bytes])
        result = predictor(doc)

        # Look for the target text in the OCR output.
        target_lower = description.lower()
        best_match = None
        best_score = 0

        for page in result.pages:
            for block in page.blocks:
                for line_obj in block.lines:
                    for word in line_obj.words:
                        word_text = word.value.lower()
                        # Exact or partial match, in either direction.
                        if target_lower in word_text or word_text in target_lower:
                            # Score by overlap length relative to the target.
                            score = len(word_text) / max(len(target_lower), 1)
                            if score > best_score:
                                # Normalised coordinates (docTR returns 0-1).
                                box = word.geometry  # ((x1,y1), (x2,y2))
                                cx = (box[0][0] + box[1][0]) / 2
                                cy = (box[0][1] + box[1][1]) / 2
                                best_match = (cx, cy)
                                best_score = score

        # Require a majority-overlap match before trusting OCR.
        if best_match and best_score > 0.5:
            return best_match

    except ImportError:
        pass  # docTR not installed
    except Exception:
        pass

    # VLM fallback for elements that carry no text.
    return grounding_baseline(screenshot_b64, description)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Framework de benchmark
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def run_benchmark(
    approach_fn,
    approach_name: str,
    screenshot_b64: str,
    description: str,
    n_iterations: int = N_ITERATIONS,
) -> Dict:
    """Run a grounding benchmark: N iterations, measuring spread and time.

    Calls ``approach_fn(screenshot_b64, description)`` *n_iterations*
    times and aggregates: detection rate, mean position, coordinate
    spread (max - min), average latency, and a stability verdict
    (spread below 5% of the screen on both axes).
    """
    hits = []
    durations = []

    for _ in range(n_iterations):
        started = time.time()
        point = approach_fn(screenshot_b64, description)
        durations.append(time.time() - started)
        if point is not None:
            hits.append(point)

    detection_rate = len(hits) / n_iterations

    stats = {
        "approach": approach_name,
        "target": description,
        "iterations": n_iterations,
        "detection_rate": round(detection_rate, 2),
        "avg_time_ms": round(sum(durations) / len(durations) * 1000, 0),
    }

    if len(hits) >= 2:
        xs = [p[0] for p in hits]
        ys = [p[1] for p in hits]
        stats["x_mean"] = round(sum(xs) / len(xs), 4)
        stats["y_mean"] = round(sum(ys) / len(ys), 4)
        # "variance" here is the max-min spread, not statistical variance.
        stats["x_variance"] = round(max(xs) - min(xs), 4)
        stats["y_variance"] = round(max(ys) - min(ys), 4)
        stats["stable"] = stats["x_variance"] < 0.05 and stats["y_variance"] < 0.05
    elif len(hits) == 1:
        only = hits[0]
        stats["x_mean"] = round(only[0], 4)
        stats["y_mean"] = round(only[1], 4)
        stats["x_variance"] = 0
        stats["y_variance"] = 0
        stats["stable"] = True
    else:
        stats["stable"] = False

    return stats
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests de benchmark comparatif
|
||||
# =========================================================================
|
||||
|
||||
|
||||
# Targets to test: (screenshot file, grounding description, display name)
_TARGETS = [
    ("shot_0001_full.png", "Rechercher", "Rechercher taskbar"),
    ("shot_0001_full.png", "agent_v1", "Dossier agent_v1"),
    ("shot_0004_full.png", "Fichier", "Menu Fichier"),
    ("shot_0004_full.png", "Modifier", "Menu Modifier"),
    ("shot_0004_full.png", "Ceci est un test.txt", "Onglet fichier"),
    ("shot_0014_full.png", "Rechercher sur Google ou saisir une URL", "Recherche Google"),
    ("shot_0014_full.png", "Gmail", "Lien Gmail"),
]
|
||||
|
||||
|
||||
@pytest.mark.visual
class TestBenchmarkBaseline:
    """Benchmark of the baseline approach (direct qwen2.5vl grounding)."""

    @pytest.mark.parametrize("shot,desc,name", _TARGETS)
    def test_baseline_robustesse(self, shot, desc, name):
        """Run the baseline N times on one target; require >= 60% detection."""
        screenshot = _load_screenshot(shot)
        stats = run_benchmark(grounding_baseline, "baseline", screenshot, desc, N_ITERATIONS)

        # Human-readable summary (shown with pytest -s).
        print(f"\n [{stats['approach']}] {name}:")
        print(f" Détection: {stats['detection_rate']*100:.0f}% ({int(stats['detection_rate']*N_ITERATIONS)}/{N_ITERATIONS})")
        print(f" Temps moyen: {stats['avg_time_ms']:.0f}ms")
        if stats.get("x_mean") is not None:
            print(f" Position: ({stats['x_mean']:.3f}, {stats['y_mean']:.3f})")
            print(f" Variance: X={stats['x_variance']:.4f} Y={stats['y_variance']:.4f}")
        print(f" Stable: {'OUI' if stats['stable'] else 'NON'}")

        # At least 60% of iterations must locate the target.
        assert stats["detection_rate"] >= 0.6, f"{name}: détection trop faible ({stats['detection_rate']})"
|
||||
|
||||
|
||||
@pytest.mark.visual
class TestBenchmarkZoom:
    """Benchmark of the progressive-zoom approach."""

    @pytest.mark.parametrize("shot,desc,name", _TARGETS)
    def test_zoom_robustesse(self, shot, desc, name):
        """Run the zoom approach N times on one target; require >= 60% detection."""
        screenshot = _load_screenshot(shot)
        stats = run_benchmark(grounding_zoom, "zoom", screenshot, desc, N_ITERATIONS)

        # Human-readable summary (shown with pytest -s).
        print(f"\n [{stats['approach']}] {name}:")
        print(f" Détection: {stats['detection_rate']*100:.0f}% ({int(stats['detection_rate']*N_ITERATIONS)}/{N_ITERATIONS})")
        print(f" Temps moyen: {stats['avg_time_ms']:.0f}ms")
        if stats.get("x_mean") is not None:
            print(f" Position: ({stats['x_mean']:.3f}, {stats['y_mean']:.3f})")
            print(f" Variance: X={stats['x_variance']:.4f} Y={stats['y_variance']:.4f}")
        print(f" Stable: {'OUI' if stats['stable'] else 'NON'}")

        # Same threshold as the baseline for a fair comparison.
        assert stats["detection_rate"] >= 0.6, f"{name}: détection trop faible ({stats['detection_rate']})"
|
||||
|
||||
|
||||
@pytest.mark.visual
class TestBenchmarkCitrix:
    """Baseline benchmark on degraded images (Citrix JPEG Q20 simulation)."""

    def _degrade_citrix(self, screenshot_b64: str) -> str:
        """Simulate Citrix compression (re-encode as JPEG quality 20)."""
        from PIL import Image
        img_bytes = base64.b64decode(screenshot_b64)
        img = Image.open(io.BytesIO(img_bytes))
        buf = io.BytesIO()
        img.save(buf, "JPEG", quality=20)
        return base64.b64encode(buf.getvalue()).decode()

    @pytest.mark.parametrize("shot,desc,name", _TARGETS)
    def test_citrix_robustesse(self, shot, desc, name):
        """Ground each target on the degraded copy; lower pass threshold."""
        screenshot = _load_screenshot(shot)
        citrix = self._degrade_citrix(screenshot)
        stats = run_benchmark(grounding_baseline, "citrix_q20", citrix, desc, N_ITERATIONS)

        # Human-readable summary (shown with pytest -s).
        print(f"\n [{stats['approach']}] {name}:")
        print(f" Détection: {stats['detection_rate']*100:.0f}%")
        print(f" Temps moyen: {stats['avg_time_ms']:.0f}ms")
        if stats.get("x_mean") is not None:
            print(f" Position: ({stats['x_mean']:.3f}, {stats['y_mean']:.3f})")
            print(f" Variance: X={stats['x_variance']:.4f} Y={stats['y_variance']:.4f}")
        print(f" Stable: {'OUI' if stats['stable'] else 'NON'}")

        # Citrix may be less reliable — lower threshold (40%).
        assert stats["detection_rate"] >= 0.4, f"{name} Citrix: détection trop faible ({stats['detection_rate']})"
|
||||
|
||||
|
||||
@pytest.mark.visual
class TestRapportComparatif:
    """Generates a comparative report for the 3 approaches."""

    def test_rapport_complet(self):
        """Run the 3 approaches on every target and print a comparison table."""
        from PIL import Image

        all_results = []

        for shot, desc, name in _TARGETS:
            screenshot = _load_screenshot(shot)

            # Citrix-degraded copy (JPEG quality 20).
            img_bytes = base64.b64decode(screenshot)
            img = Image.open(io.BytesIO(img_bytes))
            buf = io.BytesIO()
            img.save(buf, "JPEG", quality=20)
            citrix = base64.b64encode(buf.getvalue()).decode()

            for approach_fn, approach_name, img_b64 in [
                (grounding_baseline, "baseline", screenshot),
                (grounding_zoom, "zoom", screenshot),
                (grounding_baseline, "citrix_q20", citrix),
            ]:
                # Only 3 iterations per cell to keep the full matrix affordable.
                stats = run_benchmark(approach_fn, approach_name, img_b64, desc, 3)
                stats["target_name"] = name
                all_results.append(stats)

        # Report table (shown with pytest -s).
        print("\n" + "=" * 80)
        print("RAPPORT COMPARATIF — GROUNDING BENCHMARK")
        print("=" * 80)
        print(f"{'Cible':<25s} {'Approche':<12s} {'Détect.':<8s} {'Temps':<8s} {'Position':<20s} {'Var X':<8s} {'Var Y':<8s} {'Stable'}")
        print("-" * 80)
        for r in all_results:
            # Cells may be missing when nothing was detected — show N/A.
            pos = f"({r.get('x_mean',0):.3f}, {r.get('y_mean',0):.3f})" if r.get('x_mean') is not None else "N/A"
            var_x = f"{r.get('x_variance',0):.4f}" if r.get('x_variance') is not None else "N/A"
            var_y = f"{r.get('y_variance',0):.4f}" if r.get('y_variance') is not None else "N/A"
            stable = "OUI" if r.get('stable') else "NON"
            print(f"{r['target_name']:<25s} {r['approach']:<12s} {r['detection_rate']*100:5.0f}% {r['avg_time_ms']:5.0f}ms {pos:<20s} {var_x:<8s} {var_y:<8s} {stable}")
        print("=" * 80)
|
||||
445
tests/visual/test_visual_grounding.py
Normal file
445
tests/visual/test_visual_grounding.py
Normal file
@@ -0,0 +1,445 @@
|
||||
"""
|
||||
Tests visuels sur captures d'écran réelles — Grounding benchmark.
|
||||
|
||||
Vérifie que le système trouve les bons éléments UI sur des screenshots
|
||||
Windows réels. Pas besoin de VM — juste les images et le serveur.
|
||||
|
||||
Chaque test :
|
||||
1. Charge un screenshot réel (sessions enregistrées)
|
||||
2. Demande au serveur de localiser un élément (via /resolve_target)
|
||||
3. Vérifie que les coordonnées retournées sont dans la zone attendue
|
||||
|
||||
C'est l'apprentissage de l'environnement Windows :
|
||||
- Rechercher un programme
|
||||
- Fermer/réduire/agrandir une fenêtre
|
||||
- Naviguer dans les onglets
|
||||
- Utiliser les menus
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Make the repository root importable regardless of pytest's working directory.
_ROOT = str(Path(__file__).resolve().parents[2])
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)

# Directory of the recorded test screenshots
_SHOTS_DIR = Path(_ROOT) / "data/training/live_sessions/DESKTOP-ST3VBSD_windows/sess_20260404T135010_cec5c8/shots"

# Screenshot resolution (used to normalise qwen2.5vl pixel bboxes)
_SCREEN_W = 1280
_SCREEN_H = 800
|
||||
|
||||
|
||||
def _load_screenshot(name: str) -> Optional[str]:
    """Load a test screenshot as base64 (skip the test when missing)."""
    shot = _SHOTS_DIR / name
    if shot.is_file():
        return base64.b64encode(shot.read_bytes()).decode()
    pytest.skip(f"Screenshot {name} non disponible")
|
||||
|
||||
|
||||
def _in_zone(x_pct: float, y_pct: float, zone: dict) -> bool:
|
||||
"""Vérifier si un point est dans une zone attendue.
|
||||
|
||||
zone = {"x_min": 0.3, "x_max": 0.5, "y_min": 0.9, "y_max": 1.0}
|
||||
"""
|
||||
return (
|
||||
zone["x_min"] <= x_pct <= zone["x_max"]
|
||||
and zone["y_min"] <= y_pct <= zone["y_max"]
|
||||
)
|
||||
|
||||
|
||||
def _resolve_via_server(
    screenshot_b64: str,
    target_spec: dict,
    strict: bool = True,
) -> Optional[dict]:
    """Resolve a target visually via the VLM (direct qwen2.5vl grounding).

    Strategy 1 calls qwen2.5vl directly for bbox_2d grounding; if that
    yields nothing, strategy 2 falls back to the server's resolve_target
    endpoint.  Returns the resolution dict on success, None otherwise.
    Skips the calling test when Ollama times out or is unreachable.
    """
    import requests
    import re

    # ── Strategy 1: direct VLM grounding (qwen2.5vl) ──
    by_text = target_spec.get("by_text", "")
    vlm_desc = target_spec.get("vlm_description", "")
    search_text = by_text or vlm_desc

    if search_text:
        try:
            prompt = f"Detect the element '{search_text}' with a bounding box."
            resp = requests.post(
                "http://localhost:11434/api/chat",
                json={
                    "model": "qwen2.5vl:7b",
                    "messages": [{"role": "user", "content": prompt, "images": [screenshot_b64]}],
                    "stream": False,
                    "options": {"temperature": 0.0, "num_predict": 100},
                },
                timeout=30,
            )
            if resp.ok:
                content = resp.json().get("message", {}).get("content", "")
                # Parse bbox_2d — qwen2.5vl returns pixels relative to the
                # submitted image, NOT a 1000x1000 grid.
                bbox_match = re.search(
                    r'"bbox_2d"\s*:\s*\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]',
                    content,
                )
                if bbox_match:
                    x1, y1, x2, y2 = [int(bbox_match.group(i)) for i in range(1, 5)]
                    # Normalise by the image dimensions (pixels → 0-1).
                    cx = (x1 + x2) / 2 / _SCREEN_W
                    cy = (y1 + y2) / 2 / _SCREEN_H
                    if 0.0 <= cx <= 1.0 and 0.0 <= cy <= 1.0:
                        return {
                            "resolved": True,
                            "method": "vlm_grounding",
                            "x_pct": cx,
                            "y_pct": cy,
                            "score": 0.8,
                            "raw_bbox": [x1, y1, x2, y2],
                        }
        except requests.Timeout:
            pytest.skip("qwen2.5vl timeout — premier chargement ?")
        except requests.ConnectionError:
            pytest.skip("Ollama non disponible (localhost:11434)")

    # ── Strategy 2: server endpoint (fallback) ──
    # API token from the environment, falling back to .env.local.
    token = os.environ.get("RPA_API_TOKEN", "")
    if not token:
        env_file = Path(_ROOT) / ".env.local"
        if env_file.is_file():
            # NOTE(review): no break — the last matching line wins.
            for line in env_file.read_text().splitlines():
                if line.startswith("RPA_API_TOKEN="):
                    token = line.split("=", 1)[1].strip()

    headers = {"Content-Type": "application/json"}
    if token:
        headers["Authorization"] = f"Bearer {token}"

    try:
        resp = requests.post(
            "http://localhost:5005/api/v1/traces/stream/replay/resolve_target",
            json={
                "session_id": "visual_test",
                "screenshot_b64": screenshot_b64,
                "target_spec": target_spec,
                "screen_width": _SCREEN_W,
                "screen_height": _SCREEN_H,
                "fallback_x_pct": 0.5,
                "fallback_y_pct": 0.5,
                "strict_mode": strict,
            },
            headers=headers,
            timeout=30,
        )
        if resp.ok:
            data = resp.json()
            if data.get("resolved"):
                return data
    except Exception:
        pass  # server unavailable → treated as unresolved

    return None
|
||||
|
||||
|
||||
def _assert_found_in_zone(result: dict, zone: dict, element_name: str):
    """Assert that an element was resolved inside the expected zone."""
    assert result is not None, f"{element_name}: pas de réponse du serveur"
    resolved = result.get("resolved")
    assert resolved, (
        f"{element_name}: non trouvé (reason={result.get('reason', '?')})"
    )
    found_x = result.get("x_pct", 0)
    found_y = result.get("y_pct", 0)
    inside = _in_zone(found_x, found_y, zone)
    assert inside, (
        f"{element_name}: trouvé à ({found_x:.3f}, {found_y:.3f}) "
        f"mais attendu dans zone x=[{zone['x_min']:.2f}-{zone['x_max']:.2f}] "
        f"y=[{zone['y_min']:.2f}-{zone['y_max']:.2f}]"
    )
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# shot_0001 : Explorateur de fichiers Windows
|
||||
# =========================================================================
|
||||
|
||||
|
||||
@pytest.mark.visual
class TestExplorateurFichiers:
    """Tests on the Windows File Explorer screenshot (shot_0001)."""

    @pytest.fixture
    def screenshot(self):
        # Shared base64 screenshot for every test in this class.
        return _load_screenshot("shot_0001_full.png")

    def test_trouver_rechercher_taskbar(self, screenshot):
        """Find 'Rechercher' (search) in the taskbar."""
        result = _resolve_via_server(screenshot, {
            "by_text": "Rechercher",
            "vlm_description": "La barre de recherche Windows dans la barre des tâches, en bas de l'écran",
        })
        _assert_found_in_zone(result, {
            "x_min": 0.20, "x_max": 0.50,
            "y_min": 0.90, "y_max": 1.00,
        }, "Rechercher (taskbar)")

    def test_trouver_bouton_fermer_explorateur(self, screenshot):
        """Find the Explorer window's close (X) button."""
        result = _resolve_via_server(screenshot, {
            "by_text": "",
            "vlm_description": "Le bouton fermer (X) de la fenêtre Explorateur de fichiers, en haut à droite",
        })
        _assert_found_in_zone(result, {
            "x_min": 0.90, "x_max": 1.00,
            "y_min": 0.00, "y_max": 0.05,
        }, "Bouton fermer (X)")

    def test_trouver_bouton_reduire(self, screenshot):
        """Find the Explorer window's minimize (-) button."""
        result = _resolve_via_server(screenshot, {
            "by_text": "",
            "vlm_description": "Le bouton réduire (minimize, -) de la fenêtre, en haut à droite à gauche du X",
        })
        _assert_found_in_zone(result, {
            "x_min": 0.85, "x_max": 0.95,
            "y_min": 0.00, "y_max": 0.05,
        }, "Bouton réduire (-)")

    def test_trouver_dossier_agent_v1(self, screenshot):
        """Find the 'agent_v1' folder in the file list."""
        result = _resolve_via_server(screenshot, {
            "by_text": "agent_v1",
            "vlm_description": "Le dossier agent_v1 dans la liste des fichiers de l'Explorateur",
        })
        _assert_found_in_zone(result, {
            "x_min": 0.05, "x_max": 0.50,
            "y_min": 0.10, "y_max": 0.30,
        }, "Dossier agent_v1")

    def test_trouver_bouton_demarrer(self, screenshot):
        """Find the Start button (Windows logo) in the taskbar."""
        result = _resolve_via_server(screenshot, {
            "by_text": "",
            "vlm_description": "Le bouton Démarrer (logo Windows) dans la barre des tâches, en bas",
        })
        _assert_found_in_zone(result, {
            "x_min": 0.18, "x_max": 0.30,
            "y_min": 0.90, "y_max": 1.00,
        }, "Bouton Démarrer")

    def test_trouver_ce_pc(self, screenshot):
        """Find 'Ce PC' (This PC) in the Explorer navigation panel."""
        result = _resolve_via_server(screenshot, {
            "by_text": "Ce PC",
            "vlm_description": "L'élément 'Ce PC' dans le panneau de navigation gauche de l'Explorateur",
        })
        _assert_found_in_zone(result, {
            "x_min": 0.00, "x_max": 0.12,
            "y_min": 0.40, "y_max": 0.55,
        }, "Ce PC")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# shot_0004 : Bloc-notes avec onglets + Explorateur derrière
|
||||
# =========================================================================
|
||||
|
||||
|
||||
@pytest.mark.visual
class TestBlocNotesOnglets:
    """Tests on Notepad with several tabs open (shot_0004)."""

    @pytest.fixture
    def screenshot(self):
        # Shared base64 screenshot for every test in this class.
        return _load_screenshot("shot_0004_full.png")

    def test_trouver_menu_fichier(self, screenshot):
        """Find the 'Fichier' (File) menu of Notepad."""
        result = _resolve_via_server(screenshot, {
            "by_text": "Fichier",
            "vlm_description": "Le menu Fichier dans la barre de menus du Bloc-notes",
        })
        _assert_found_in_zone(result, {
            "x_min": 0.02, "x_max": 0.10,
            "y_min": 0.08, "y_max": 0.15,
        }, "Menu Fichier")

    def test_trouver_onglet_ceci_est_un_test(self, screenshot):
        """Find the 'Ceci est un test.txt' tab in Notepad."""
        result = _resolve_via_server(screenshot, {
            "by_text": "Ceci est un test",
            "vlm_description": "L'onglet 'Ceci est un test.txt' dans le Bloc-notes",
        })
        _assert_found_in_zone(result, {
            "x_min": 0.40, "x_max": 0.70,
            "y_min": 0.03, "y_max": 0.10,
        }, "Onglet 'Ceci est un test.txt'")

    def test_trouver_nouvel_onglet_plus(self, screenshot):
        """Find the '+' button that opens a new tab."""
        result = _resolve_via_server(screenshot, {
            "by_text": "",
            "vlm_description": "Le bouton + (plus) pour ajouter un nouvel onglet dans le Bloc-notes, à droite des onglets",
        })
        _assert_found_in_zone(result, {
            "x_min": 0.55, "x_max": 0.70,
            "y_min": 0.03, "y_max": 0.10,
        }, "Bouton + (nouvel onglet)")

    def test_trouver_bouton_fermer_onglet(self, screenshot):
        """Find the X that closes the active tab."""
        result = _resolve_via_server(screenshot, {
            "by_text": "",
            "vlm_description": "Le bouton X pour fermer l'onglet actif 'Ceci est un test.txt' dans le Bloc-notes",
        })
        _assert_found_in_zone(result, {
            "x_min": 0.50, "x_max": 0.65,
            "y_min": 0.03, "y_max": 0.10,
        }, "Fermer onglet (X)")

    def test_trouver_menu_modifier(self, screenshot):
        """Find the 'Modifier' (Edit) menu of Notepad."""
        result = _resolve_via_server(screenshot, {
            "by_text": "Modifier",
            "vlm_description": "Le menu Modifier dans la barre de menus du Bloc-notes",
        })
        _assert_found_in_zone(result, {
            "x_min": 0.07, "x_max": 0.16,
            "y_min": 0.08, "y_max": 0.15,
        }, "Menu Modifier")

    def test_trouver_encodage_utf8(self, screenshot):
        """Find the UTF-8 encoding indicator in the status bar."""
        result = _resolve_via_server(screenshot, {
            "by_text": "UTF-8",
            "vlm_description": "L'indicateur d'encodage UTF-8 dans la barre de statut en bas du Bloc-notes",
        })
        _assert_found_in_zone(result, {
            "x_min": 0.60, "x_max": 0.80,
            "y_min": 0.90, "y_max": 1.00,
        }, "UTF-8 (barre de statut)")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# shot_0014 : Google Chrome page d'accueil
|
||||
# =========================================================================
|
||||
|
||||
|
||||
@pytest.mark.visual
class TestGoogleChrome:
    """Tests on Google Chrome showing the Google home page (shot_0014)."""

    @pytest.fixture
    def screenshot(self):
        # Shared base64 screenshot for every test in this class.
        return _load_screenshot("shot_0014_full.png")

    def test_trouver_barre_recherche_google(self, screenshot):
        """Find the Google search bar at the centre of the page."""
        result = _resolve_via_server(screenshot, {
            "by_text": "Rechercher sur Google",
            "vlm_description": "La barre de recherche Google au centre de la page d'accueil",
        })
        _assert_found_in_zone(result, {
            "x_min": 0.10, "x_max": 0.60,
            "y_min": 0.30, "y_max": 0.50,
        }, "Barre recherche Google")

    def test_trouver_barre_adresse_chrome(self, screenshot):
        """Find Chrome's address bar at the top."""
        result = _resolve_via_server(screenshot, {
            "by_text": "",
            "vlm_description": "La barre d'adresse URL de Google Chrome, en haut du navigateur",
        })
        _assert_found_in_zone(result, {
            "x_min": 0.10, "x_max": 0.60,
            "y_min": 0.05, "y_max": 0.15,
        }, "Barre d'adresse Chrome")

    def test_trouver_nouvel_onglet_chrome(self, screenshot):
        """Find the '+' button for a new Chrome tab."""
        result = _resolve_via_server(screenshot, {
            "by_text": "",
            "vlm_description": "Le bouton + pour ouvrir un nouvel onglet dans Google Chrome",
        })
        _assert_found_in_zone(result, {
            "x_min": 0.15, "x_max": 0.25,
            "y_min": 0.00, "y_max": 0.06,
        }, "Nouvel onglet (+) Chrome")

    def test_trouver_fermer_chrome(self, screenshot):
        """Find the X button that closes Chrome."""
        result = _resolve_via_server(screenshot, {
            "by_text": "",
            "vlm_description": "Le bouton fermer (X) de la fenêtre Google Chrome, en haut à droite",
        })
        _assert_found_in_zone(result, {
            "x_min": 0.90, "x_max": 1.00,
            "y_min": 0.00, "y_max": 0.06,
        }, "Fermer Chrome (X)")

    def test_trouver_gmail(self, screenshot):
        """Find the Gmail link on the Google home page."""
        result = _resolve_via_server(screenshot, {
            "by_text": "Gmail",
            "vlm_description": "Le lien Gmail en haut à droite de la page Google",
        })
        _assert_found_in_zone(result, {
            "x_min": 0.50, "x_max": 0.80,
            "y_min": 0.10, "y_max": 0.20,
        }, "Gmail")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests transversaux (connaissances de base Windows)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
@pytest.mark.visual
class TestConnaissancesWindowsBase:
    """Basic Windows knowledge every user has (resolution only, no zone)."""

    def test_rechercher_programme_depuis_explorateur(self):
        """From Explorer, locate the Windows taskbar search box."""
        screenshot = _load_screenshot("shot_0001_full.png")
        result = _resolve_via_server(screenshot, {
            "by_text": "Rechercher",
            "vlm_description": "La barre de recherche dans la barre des tâches Windows en bas de l'écran",
        })
        assert result and result.get("resolved"), "Rechercher non trouvé"

    def test_fermer_programme_depuis_blocnotes(self):
        """From Notepad, locate the close button."""
        screenshot = _load_screenshot("shot_0004_full.png")
        result = _resolve_via_server(screenshot, {
            "by_text": "",
            "vlm_description": "Le bouton X pour fermer la fenêtre du Bloc-notes, en haut à droite",
        })
        assert result and result.get("resolved"), "Bouton fermer non trouvé"

    def test_ajouter_onglet_blocnotes(self):
        """Locate the control that adds a new Notepad tab."""
        screenshot = _load_screenshot("shot_0004_full.png")
        result = _resolve_via_server(screenshot, {
            "by_text": "",
            "vlm_description": "Le bouton + pour ajouter un nouvel onglet dans le Bloc-notes",
        })
        assert result and result.get("resolved"), "Bouton + non trouvé"

    def test_rechercher_sur_google(self):
        """Locate the Google search field to type into."""
        screenshot = _load_screenshot("shot_0014_full.png")
        result = _resolve_via_server(screenshot, {
            "by_text": "Rechercher sur Google",
            "vlm_description": "Le champ de recherche Google",
        })
        assert result and result.get("resolved"), "Recherche Google non trouvée"
|
||||
864
tests/visual/test_visual_robustness.py
Normal file
864
tests/visual/test_visual_robustness.py
Normal file
@@ -0,0 +1,864 @@
|
||||
"""
|
||||
Tests de robustesse visuelle — Grounding VLM qwen2.5vl:7b.
|
||||
|
||||
Objectifs :
|
||||
1. Reproductibilité : même screenshot + même cible → même résultat 10 fois
|
||||
2. Robustesse Citrix : screenshots compressés JPEG qualité 15-25 → ça marche
|
||||
3. Mesure de variance : coordonnées stables à < 5% de l'écran
|
||||
|
||||
Architecture des coordonnées qwen2.5vl :
|
||||
- Format bbox_2d : [x1, y1, x2, y2] en pixels relatifs à l'image envoyée
|
||||
- Pour une image 1280x800, X va de 0 à 1280 et Y de 0 à 800
|
||||
- Normalisation : diviser par les dimensions de l'image (pas par 1000)
|
||||
|
||||
Calibration mesurée (5 avril 2026) sur screenshots 1280x800 :
|
||||
- shot_0001/Rechercher (taskbar) : cx=0.458, cy=0.789
|
||||
- shot_0001/agent_v1 (dossier) : cx=0.247, cy=0.201
|
||||
- shot_0004/Fichier (menu) : cx=0.095, cy=0.086
|
||||
- shot_0004/Modifier (menu) : cx=0.142, cy=0.085
|
||||
- shot_0004/Ceci est un test.txt (onglet): cx=0.694, cy=0.053
|
||||
- shot_0004/Close X (Bloc-notes) : cx=0.990, cy=0.041
|
||||
- shot_0014/Google search (centre) : cx=0.539, cy=0.389
|
||||
- shot_0014/Gmail (haut-droite) : cx=0.913, cy=0.130
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
import statistics
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
_ROOT = str(Path(__file__).resolve().parents[2])
|
||||
if _ROOT not in sys.path:
|
||||
sys.path.insert(0, _ROOT)
|
||||
|
||||
# Directory holding the reference screenshots used by these tests.
_SHOTS_DIR = (
    Path(_ROOT)
    / "data/training/live_sessions/DESKTOP-ST3VBSD_windows"
    / "sess_20260404T135010_cec5c8/shots"
)

# Screenshot resolution in pixels; bbox pixel coordinates from the VLM
# are normalised to [0, 1] by dividing by these.
_SCREEN_W = 1280
_SCREEN_H = 800

# Number of repetitions for the reproducibility tests
_N_REPEATS = 10

# Maximum allowed coordinate spread (fraction of the screen, 0.05 = 5%)
_MAX_VARIANCE = 0.05

# Minimum detection rate (X out of _N_REPEATS runs)
_MIN_DETECTION_RATE = 8
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Utilitaires
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _load_screenshot(name: str) -> Optional[str]:
    """Return the named test screenshot encoded as base64.

    Skips the current test (via pytest.skip) when the file is absent,
    so a missing fixture never counts as a failure.
    """
    target = _SHOTS_DIR / name
    if target.is_file():
        return base64.b64encode(target.read_bytes()).decode()
    pytest.skip(f"Screenshot {name} non disponible")
|
||||
|
||||
|
||||
def _degrade_citrix(screenshot_b64: str, quality: int = 20) -> str:
    """Simulate Citrix compression: lossy JPEG pass, back to PNG base64.

    The screenshot is round-tripped through a JPEG encode at the given
    quality (default 20, typical of a Citrix/RDP link) and then re-encoded
    as PNG, which is the format sent to the VLM.
    """
    from PIL import Image

    source = Image.open(io.BytesIO(base64.b64decode(screenshot_b64)))

    # Lossy step: write the image as low-quality JPEG (Citrix simulation)
    jpeg_buffer = io.BytesIO()
    source.save(jpeg_buffer, "JPEG", quality=quality)
    jpeg_buffer.seek(0)
    degraded = Image.open(jpeg_buffer)

    # Re-encode the degraded image as PNG for the VLM request
    png_buffer = io.BytesIO()
    degraded.save(png_buffer, "PNG")
    return base64.b64encode(png_buffer.getvalue()).decode()
|
||||
|
||||
|
||||
def _grounding_vlm(
    screenshot_b64: str,
    element_description: str,
    timeout: int = 60,
) -> Tuple[Optional[float], Optional[float], Optional[List[int]], str]:
    """Call qwen2.5vl (via the local Ollama API) to locate a UI element.

    Returns a ``(cx, cy, [x1, y1, x2, y2], raw_content)`` tuple. ``cx`` and
    ``cy`` are the box centre normalised to [0, 1] by dividing by the
    screenshot dimensions (_SCREEN_W x _SCREEN_H) — qwen2.5vl reports pixel
    coordinates relative to the submitted image, NOT a 1000x1000 grid.
    When no bbox can be parsed, the first three items are ``None`` and only
    the raw model output is returned. Skips the current test when Ollama is
    unreachable or the model request times out.
    """
    import requests

    try:
        resp = requests.post(
            "http://localhost:11434/api/chat",
            json={
                "model": "qwen2.5vl:7b",
                "messages": [
                    {
                        "role": "user",
                        "content": (
                            f"Detect the element '{element_description}' "
                            f"with a bounding box."
                        ),
                        "images": [screenshot_b64],
                    }
                ],
                "stream": False,
                # Low temperature for reproducibility; short completion since
                # only the bbox JSON is needed.
                "options": {"temperature": 0.1, "num_predict": 100},
            },
            timeout=timeout,
        )
    except requests.ConnectionError:
        pytest.skip("Ollama non disponible (localhost:11434)")
    except requests.Timeout:
        pytest.skip("qwen2.5vl timeout — modèle en cours de chargement ?")

    content = resp.json().get("message", {}).get("content", "")

    # Parse bbox_2d from the JSON reply.
    # qwen2.5vl returns coordinates in pixels relative to the submitted
    # image, NOT on a 1000x1000 grid.
    bbox_match = re.search(
        r'"bbox_2d"\s*:\s*\[(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\]',
        content,
    )
    if bbox_match:
        x1, y1, x2, y2 = [int(bbox_match.group(i)) for i in range(1, 5)]
        # Normalise by the image dimensions (pixels → 0-1)
        cx = (x1 + x2) / 2 / _SCREEN_W
        cy = (y1 + y2) / 2 / _SCREEN_H
        return cx, cy, [x1, y1, x2, y2], content

    return None, None, None, content
|
||||
|
||||
|
||||
def _run_n_times(
    screenshot_b64: str,
    description: str,
    n: int = _N_REPEATS,
    delay: float = 0.2,
) -> List[Dict]:
    """Run the grounding call *n* times and collect one record per run."""
    samples: List[Dict] = []
    for run_no in range(1, n + 1):
        cx, cy, bbox, raw = _grounding_vlm(screenshot_b64, description)
        samples.append({
            "run": run_no,
            "cx": cx,
            "cy": cy,
            "bbox": bbox,
            "detected": cx is not None,
            "raw": raw,
        })
        # Small pause between calls, skipped after the final one
        if run_no < n:
            time.sleep(delay)
    return samples
|
||||
|
||||
|
||||
def _compute_stats(results: List[Dict]) -> Dict:
|
||||
"""Calculer les statistiques de détection et de variance."""
|
||||
detected = [r for r in results if r["detected"]]
|
||||
n_total = len(results)
|
||||
n_detected = len(detected)
|
||||
|
||||
stats = {
|
||||
"total": n_total,
|
||||
"detected": n_detected,
|
||||
"rate": n_detected / n_total if n_total > 0 else 0,
|
||||
"rate_str": f"{n_detected}/{n_total}",
|
||||
}
|
||||
|
||||
if n_detected >= 2:
|
||||
xs = [r["cx"] for r in detected]
|
||||
ys = [r["cy"] for r in detected]
|
||||
stats.update({
|
||||
"x_min": min(xs),
|
||||
"x_max": max(xs),
|
||||
"x_mean": statistics.mean(xs),
|
||||
"x_range": max(xs) - min(xs),
|
||||
"x_stdev": statistics.stdev(xs) if n_detected >= 2 else 0,
|
||||
"y_min": min(ys),
|
||||
"y_max": max(ys),
|
||||
"y_mean": statistics.mean(ys),
|
||||
"y_range": max(ys) - min(ys),
|
||||
"y_stdev": statistics.stdev(ys) if n_detected >= 2 else 0,
|
||||
})
|
||||
elif n_detected == 1:
|
||||
stats.update({
|
||||
"x_min": detected[0]["cx"],
|
||||
"x_max": detected[0]["cx"],
|
||||
"x_mean": detected[0]["cx"],
|
||||
"x_range": 0,
|
||||
"x_stdev": 0,
|
||||
"y_min": detected[0]["cy"],
|
||||
"y_max": detected[0]["cy"],
|
||||
"y_mean": detected[0]["cy"],
|
||||
"y_range": 0,
|
||||
"y_stdev": 0,
|
||||
})
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def _assert_reproducible(
    stats: Dict,
    element_name: str,
    min_rate: int = _MIN_DETECTION_RATE,
    max_var: float = _MAX_VARIANCE,
):
    """Assert reproducibility: enough detections plus low spatial spread.

    Fails when fewer than *min_rate* runs detected the element, or when
    the spread of the detected centres reaches *max_var* on either axis
    (spread is only checked with at least two detections).
    """
    hits = stats["detected"]
    assert hits >= min_rate, (
        f"{element_name}: seulement {stats['rate_str']} détections "
        f"(minimum requis: {min_rate}/{stats['total']})"
    )

    if hits < 2:
        return
    for axis, key in (("X", "x_range"), ("Y", "y_range")):
        spread = stats[key]
        assert spread < max_var, (
            f"{element_name}: variance {axis} trop élevée: "
            f"{spread:.4f} (max={max_var})"
        )
|
||||
|
||||
|
||||
def _assert_in_zone(
|
||||
stats: Dict,
|
||||
zone: Dict[str, float],
|
||||
element_name: str,
|
||||
):
|
||||
"""Vérifier que la position moyenne est dans la zone attendue."""
|
||||
assert stats["detected"] >= 1, f"{element_name}: aucune détection"
|
||||
cx = stats["x_mean"]
|
||||
cy = stats["y_mean"]
|
||||
assert zone["x_min"] <= cx <= zone["x_max"], (
|
||||
f"{element_name}: X moyen {cx:.4f} hors zone "
|
||||
f"[{zone['x_min']:.2f}-{zone['x_max']:.2f}]"
|
||||
)
|
||||
assert zone["y_min"] <= cy <= zone["y_max"], (
|
||||
f"{element_name}: Y moyen {cy:.4f} hors zone "
|
||||
f"[{zone['y_min']:.2f}-{zone['y_max']:.2f}]"
|
||||
)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Zones calibrées (mesurées le 5 avril 2026)
|
||||
# =========================================================================
|
||||
|
||||
# Expected normalised (0-1) zones for each target element, keyed by element
# id. Calibrated on the 1280x800 reference screenshots; the measured centre
# points are listed in the module docstring.
CALIBRATED_ZONES = {
    # shot_0001 — Windows File Explorer
    "rechercher_taskbar": {
        "x_min": 0.40, "x_max": 0.60,
        "y_min": 0.74, "y_max": 0.84,
    },
    "agent_v1_folder": {
        "x_min": 0.18, "x_max": 0.30,
        "y_min": 0.16, "y_max": 0.26,
    },
    # shot_0004 — Notepad with tabs
    "fichier_menu": {
        "x_min": 0.06, "x_max": 0.13,
        "y_min": 0.06, "y_max": 0.12,
    },
    "modifier_menu": {
        "x_min": 0.11, "x_max": 0.18,
        "y_min": 0.06, "y_max": 0.12,
    },
    "ceci_est_un_test_tab": {
        "x_min": 0.65, "x_max": 0.75,
        "y_min": 0.03, "y_max": 0.08,
    },
    # x_max deliberately exceeds 1.0 — the close button of a maximised
    # window can overflow the normalised range (see the overflow test).
    "close_x_notepad": {
        "x_min": 0.95, "x_max": 1.02,
        "y_min": 0.02, "y_max": 0.06,
    },
    # shot_0014 — Google Chrome
    "google_search_bar": {
        "x_min": 0.48, "x_max": 0.60,
        "y_min": 0.35, "y_max": 0.43,
    },
    "gmail_link": {
        "x_min": 0.87, "x_max": 0.95,
        "y_min": 0.10, "y_max": 0.16,
    },
}
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests de reproductibilité — 10 appels consécutifs
|
||||
# =========================================================================
|
||||
|
||||
|
||||
@pytest.mark.visual
class TestReproductibilite:
    """Each test calls the VLM 10 times and checks the results agree.

    Success criteria:
    - At least 8/10 detections
    - Coordinate spread < 5% of the screen on each axis
    - Mean position inside the calibrated zone
    """

    # -- shot_0001 : Windows File Explorer --

    @pytest.fixture(scope="class")
    def shot_0001(self):
        # Class-scoped: the screenshot is loaded once, shared by all tests.
        return _load_screenshot("shot_0001_full.png")

    def test_rechercher_10_fois(self, shot_0001):
        """The VLM finds 'Rechercher' at the same spot 10 times in a row."""
        results = _run_n_times(
            shot_0001,
            "the 'Rechercher' search text in the Windows taskbar at the bottom",
        )
        stats = _compute_stats(results)
        _assert_reproducible(stats, "Rechercher (taskbar)")
        _assert_in_zone(stats, CALIBRATED_ZONES["rechercher_taskbar"], "Rechercher")
        # Print the summary for the report
        print(f"\n [Rechercher] {stats['rate_str']} détections, "
              f"X=[{stats.get('x_min', 0):.4f}-{stats.get('x_max', 0):.4f}], "
              f"Y=[{stats.get('y_min', 0):.4f}-{stats.get('y_max', 0):.4f}]")

    def test_agent_v1_10_fois(self, shot_0001):
        """The VLM finds the 'agent_v1' folder at the same spot 10 times."""
        results = _run_n_times(
            shot_0001,
            "the folder named 'agent_v1' in the file list",
        )
        stats = _compute_stats(results)
        _assert_reproducible(stats, "agent_v1 (dossier)")
        _assert_in_zone(stats, CALIBRATED_ZONES["agent_v1_folder"], "agent_v1")
        print(f"\n [agent_v1] {stats['rate_str']} détections, "
              f"X=[{stats.get('x_min', 0):.4f}-{stats.get('x_max', 0):.4f}], "
              f"Y=[{stats.get('y_min', 0):.4f}-{stats.get('y_max', 0):.4f}]")

    def test_close_x_explorateur_10_fois(self, shot_0001):
        """The maximised window's X button: X overflow expected.

        This test checks that the VLM detects the X button consistently.
        On maximised windows (1280px wide), the X coordinates exceed the
        normalised range (cx > 1.0).

        Note: the VLM can confuse the window's X button with the tab's
        (multiple-close-buttons ambiguity). We check that the majority of
        detections target the right button.
        """
        results = _run_n_times(
            shot_0001,
            "the X close button of the 'Lea' window",
        )
        # First make sure the VLM detects anything at all
        detected = [r for r in results if r["detected"]]
        assert len(detected) >= _MIN_DETECTION_RATE, (
            f"Close X: seulement {len(detected)}/{len(results)} détections"
        )

        # Classify detections: overflow (window button) vs non-overflow (tab button)
        overflows = [r for r in detected if r["cx"] > 1.0]
        non_overflows = [r for r in detected if r["cx"] <= 1.0]

        # At least 60% of detections must target the window button (overflow)
        assert len(overflows) >= len(detected) * 0.6, (
            f"Close X: seulement {len(overflows)}/{len(detected)} en overflow. "
            f"Ambiguïté avec bouton onglet ({len(non_overflows)} non-overflow)."
        )

        # Check the consistency of the overflow detections (the main cluster)
        if len(overflows) >= 2:
            bboxes = [r["bbox"] for r in overflows]
            x1s = [b[0] for b in bboxes]
            y1s = [b[1] for b in bboxes]
            assert max(x1s) - min(x1s) < 20, (
                f"Close X overflow: x1 trop variable: {min(x1s)}-{max(x1s)}"
            )
            assert max(y1s) - min(y1s) < 20, (
                f"Close X overflow: y1 trop variable: {min(y1s)}-{max(y1s)}"
            )

        # NOTE(review): adjacent f-string literals concatenate before the
        # conditional applies, so the trailing `if overflows else ""` guards
        # the WHOLE message — nothing is printed when `overflows` is empty.
        # Confirm this is intended (the assert above makes it moot today).
        print(f"\n [Close X Explorer] {len(detected)}/{len(results)} détections, "
              f"{len(overflows)} overflow (fenêtre), {len(non_overflows)} non-overflow (onglet). "
              f"cx_mean_overflow={statistics.mean([r['cx'] for r in overflows]):.4f}" if overflows else "")

    # -- shot_0004 : Notepad --

    @pytest.fixture(scope="class")
    def shot_0004(self):
        return _load_screenshot("shot_0004_full.png")

    def test_fichier_10_fois(self, shot_0004):
        """The VLM finds the 'Fichier' menu at the same spot 10 times."""
        results = _run_n_times(
            shot_0004,
            "the 'Fichier' menu item in the menu bar",
        )
        stats = _compute_stats(results)
        _assert_reproducible(stats, "Fichier (menu)")
        _assert_in_zone(stats, CALIBRATED_ZONES["fichier_menu"], "Fichier")
        print(f"\n [Fichier] {stats['rate_str']} détections, "
              f"X=[{stats.get('x_min', 0):.4f}-{stats.get('x_max', 0):.4f}], "
              f"Y=[{stats.get('y_min', 0):.4f}-{stats.get('y_max', 0):.4f}]")

    def test_modifier_10_fois(self, shot_0004):
        """The VLM finds the 'Modifier' menu at the same spot 10 times."""
        results = _run_n_times(
            shot_0004,
            "the 'Modifier' menu item in the menu bar",
        )
        stats = _compute_stats(results)
        _assert_reproducible(stats, "Modifier (menu)")
        _assert_in_zone(stats, CALIBRATED_ZONES["modifier_menu"], "Modifier")
        print(f"\n [Modifier] {stats['rate_str']} détections, "
              f"X=[{stats.get('x_min', 0):.4f}-{stats.get('x_max', 0):.4f}], "
              f"Y=[{stats.get('y_min', 0):.4f}-{stats.get('y_max', 0):.4f}]")

    def test_ceci_est_un_test_10_fois(self, shot_0004):
        """The VLM finds the 'Ceci est un test.txt' tab at the same spot 10 times."""
        results = _run_n_times(
            shot_0004,
            "the tab labeled 'Ceci est un test.txt'",
        )
        stats = _compute_stats(results)
        _assert_reproducible(stats, "Ceci est un test.txt (onglet)")
        _assert_in_zone(stats, CALIBRATED_ZONES["ceci_est_un_test_tab"], "Ceci est un test.txt")
        print(f"\n [Ceci est un test.txt] {stats['rate_str']} détections, "
              f"X=[{stats.get('x_min', 0):.4f}-{stats.get('x_max', 0):.4f}], "
              f"Y=[{stats.get('y_min', 0):.4f}-{stats.get('y_max', 0):.4f}]")

    # -- shot_0014 : Google Chrome --

    @pytest.fixture(scope="class")
    def shot_0014(self):
        return _load_screenshot("shot_0014_full.png")

    def test_google_search_10_fois(self, shot_0014):
        """The VLM finds the Google search bar at the same spot 10 times."""
        results = _run_n_times(
            shot_0014,
            "the Google search bar 'Rechercher sur Google ou saisir une URL'",
        )
        stats = _compute_stats(results)
        _assert_reproducible(stats, "Recherche Google")
        _assert_in_zone(stats, CALIBRATED_ZONES["google_search_bar"], "Recherche Google")
        print(f"\n [Google search] {stats['rate_str']} détections, "
              f"X=[{stats.get('x_min', 0):.4f}-{stats.get('x_max', 0):.4f}], "
              f"Y=[{stats.get('y_min', 0):.4f}-{stats.get('y_max', 0):.4f}]")

    def test_gmail_10_fois(self, shot_0014):
        """The VLM finds the Gmail link at the same spot 10 times."""
        results = _run_n_times(
            shot_0014,
            "the 'Gmail' link at the top of the page",
        )
        stats = _compute_stats(results)
        _assert_reproducible(stats, "Gmail")
        _assert_in_zone(stats, CALIBRATED_ZONES["gmail_link"], "Gmail")
        print(f"\n [Gmail] {stats['rate_str']} détections, "
              f"X=[{stats.get('x_min', 0):.4f}-{stats.get('x_max', 0):.4f}], "
              f"Y=[{stats.get('y_min', 0):.4f}-{stats.get('y_max', 0):.4f}]")
|
||||
f"Y=[{stats.get('y_min', 0):.4f}-{stats.get('y_max', 0):.4f}]")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests de robustesse Citrix — JPEG dégradé
|
||||
# =========================================================================
|
||||
|
||||
|
||||
@pytest.mark.visual
class TestCitrixRobustesse:
    """Verify that grounding still works on compressed screenshots.

    Simulates a Citrix/RDP environment with JPEG compression at quality
    15-25, and compares results on the original vs the degraded image.
    """

    @pytest.fixture(scope="class")
    def shots_original(self):
        # Pristine PNG screenshots, loaded once for the whole class.
        return {
            "shot_0001": _load_screenshot("shot_0001_full.png"),
            "shot_0004": _load_screenshot("shot_0004_full.png"),
            "shot_0014": _load_screenshot("shot_0014_full.png"),
        }

    @pytest.fixture(scope="class")
    def shots_citrix(self, shots_original):
        # The same screenshots after a JPEG Q20 round-trip (Citrix simulation).
        return {
            name: _degrade_citrix(b64, quality=20)
            for name, b64 in shots_original.items()
        }

    def _compare_original_vs_citrix(
        self,
        original_b64: str,
        citrix_b64: str,
        description: str,
        element_name: str,
        zone: Dict,
        n_runs: int = 5,
    ) -> Dict:
        """Run the same grounding query on both images; return both stats.

        NOTE(review): *element_name* and *zone* are currently unused here —
        the callers perform the zone check themselves. Confirm whether
        these parameters should be dropped or used for reporting.
        """
        # n_runs passes on the original image
        results_orig = _run_n_times(original_b64, description, n=n_runs, delay=0.2)
        stats_orig = _compute_stats(results_orig)

        # n_runs passes on the Citrix-degraded image
        results_citrix = _run_n_times(citrix_b64, description, n=n_runs, delay=0.2)
        stats_citrix = _compute_stats(results_citrix)

        return {
            "original": stats_orig,
            "citrix": stats_citrix,
        }

    def test_rechercher_citrix(self, shots_original, shots_citrix):
        """'Rechercher' detected despite JPEG Q20 compression."""
        comp = self._compare_original_vs_citrix(
            shots_original["shot_0001"],
            shots_citrix["shot_0001"],
            "the 'Rechercher' search text in the Windows taskbar at the bottom",
            "Rechercher",
            CALIBRATED_ZONES["rechercher_taskbar"],
        )
        # At least 3/5 detections on the Citrix image
        assert comp["citrix"]["detected"] >= 3, (
            f"Citrix Rechercher: seulement {comp['citrix']['rate_str']} détections"
        )
        # Mean position inside the calibrated zone
        if comp["citrix"]["detected"] >= 1:
            _assert_in_zone(comp["citrix"], CALIBRATED_ZONES["rechercher_taskbar"], "Rechercher (Citrix)")
        print(f"\n [Rechercher Citrix] orig={comp['original']['rate_str']}, "
              f"citrix={comp['citrix']['rate_str']}")

    def test_fichier_citrix(self, shots_original, shots_citrix):
        """'Fichier' menu detected despite JPEG Q20 compression."""
        comp = self._compare_original_vs_citrix(
            shots_original["shot_0004"],
            shots_citrix["shot_0004"],
            "the 'Fichier' menu item in the menu bar",
            "Fichier",
            CALIBRATED_ZONES["fichier_menu"],
        )
        assert comp["citrix"]["detected"] >= 3, (
            f"Citrix Fichier: seulement {comp['citrix']['rate_str']} détections"
        )
        if comp["citrix"]["detected"] >= 1:
            _assert_in_zone(comp["citrix"], CALIBRATED_ZONES["fichier_menu"], "Fichier (Citrix)")
        print(f"\n [Fichier Citrix] orig={comp['original']['rate_str']}, "
              f"citrix={comp['citrix']['rate_str']}")

    def test_ceci_est_un_test_citrix(self, shots_original, shots_citrix):
        """'Ceci est un test.txt' tab detected despite JPEG Q20 compression."""
        comp = self._compare_original_vs_citrix(
            shots_original["shot_0004"],
            shots_citrix["shot_0004"],
            "the tab labeled 'Ceci est un test.txt'",
            "Ceci est un test.txt",
            CALIBRATED_ZONES["ceci_est_un_test_tab"],
        )
        assert comp["citrix"]["detected"] >= 3, (
            f"Citrix tab: seulement {comp['citrix']['rate_str']} détections"
        )
        if comp["citrix"]["detected"] >= 1:
            _assert_in_zone(
                comp["citrix"],
                CALIBRATED_ZONES["ceci_est_un_test_tab"],
                "Ceci est un test.txt (Citrix)",
            )
        print(f"\n [Ceci est un test.txt Citrix] orig={comp['original']['rate_str']}, "
              f"citrix={comp['citrix']['rate_str']}")

    def test_google_search_citrix(self, shots_original, shots_citrix):
        """Google search bar detected despite JPEG Q20 compression."""
        comp = self._compare_original_vs_citrix(
            shots_original["shot_0014"],
            shots_citrix["shot_0014"],
            "the Google search bar 'Rechercher sur Google ou saisir une URL'",
            "Recherche Google",
            CALIBRATED_ZONES["google_search_bar"],
        )
        assert comp["citrix"]["detected"] >= 3, (
            f"Citrix Google: seulement {comp['citrix']['rate_str']} détections"
        )
        if comp["citrix"]["detected"] >= 1:
            _assert_in_zone(
                comp["citrix"],
                CALIBRATED_ZONES["google_search_bar"],
                "Recherche Google (Citrix)",
            )
        print(f"\n [Google search Citrix] orig={comp['original']['rate_str']}, "
              f"citrix={comp['citrix']['rate_str']}")

    def test_gmail_citrix(self, shots_original, shots_citrix):
        """Gmail link detected despite JPEG Q20 compression."""
        comp = self._compare_original_vs_citrix(
            shots_original["shot_0014"],
            shots_citrix["shot_0014"],
            "the 'Gmail' link at the top of the page",
            "Gmail",
            CALIBRATED_ZONES["gmail_link"],
        )
        assert comp["citrix"]["detected"] >= 3, (
            f"Citrix Gmail: seulement {comp['citrix']['rate_str']} détections"
        )
        if comp["citrix"]["detected"] >= 1:
            _assert_in_zone(comp["citrix"], CALIBRATED_ZONES["gmail_link"], "Gmail (Citrix)")
        print(f"\n [Gmail Citrix] orig={comp['original']['rate_str']}, "
              f"citrix={comp['citrix']['rate_str']}")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests de dégradation progressive — qualité JPEG 50 → 15 → 5
|
||||
# =========================================================================
|
||||
|
||||
|
||||
@pytest.mark.visual
class TestDegradationProgressive:
    """Measure at which JPEG quality level the grounding starts to fail."""

    @pytest.fixture(scope="class")
    def shot_0004(self):
        return _load_screenshot("shot_0004_full.png")

    def test_fichier_degradation_progressive(self, shot_0004):
        """'Fichier' menu: try JPEG Q50, Q25, Q15, Q10, Q5.

        Three grounding passes per quality level; a report is printed with
        the detection rate and whether the mean position stays inside the
        calibrated zone. Only Q50 and Q25 are required to pass.
        """
        qualities = [50, 25, 15, 10, 5]
        results_by_quality = {}

        for q in qualities:
            degraded = _degrade_citrix(shot_0004, quality=q)
            results = _run_n_times(
                degraded,
                "the 'Fichier' menu item in the menu bar",
                n=3,
                delay=0.2,
            )
            stats = _compute_stats(results)
            results_by_quality[q] = stats

        # Print the degradation report
        print("\n === Dégradation progressive : Fichier menu ===")
        for q in qualities:
            s = results_by_quality[q]
            zone_ok = ""
            if s["detected"] >= 1:
                cx = s["x_mean"]
                cy = s["y_mean"]
                z = CALIBRATED_ZONES["fichier_menu"]
                in_zone = z["x_min"] <= cx <= z["x_max"] and z["y_min"] <= cy <= z["y_max"]
                zone_ok = " (in zone)" if in_zone else f" (HORS zone: {cx:.3f},{cy:.3f})"
            print(f" Q{q:>2}: {s['rate_str']} détections{zone_ok}")

        # At least Q50 and Q25 must work
        assert results_by_quality[50]["detected"] >= 2, "Q50 devrait fonctionner"
        assert results_by_quality[25]["detected"] >= 2, "Q25 devrait fonctionner"
|
||||
|
||||
|
||||
# =========================================================================
# Final report — runs last, summarizes everything
# =========================================================================


@pytest.mark.visual
class TestRapportFinal:
    """Full report on the VLM grounding capabilities.

    Runs a battery of detections and produces a structured report with
    detection rates, variance, and an original-vs-Citrix comparison.
    """

    def test_rapport_complet(self):
        """Generate the final robustness report for VLM grounding.

        Fix: removed the unused third-party ``from PIL import Image``
        (nothing in this test references ``Image``) and dropped a spurious
        f-prefix on a placeholder-free string; output is unchanged.
        """
        shots = {
            "shot_0001": _load_screenshot("shot_0001_full.png"),
            "shot_0004": _load_screenshot("shot_0004_full.png"),
            "shot_0014": _load_screenshot("shot_0014_full.png"),
        }

        # (screenshot key, human-readable label, VLM prompt, calibrated zone)
        targets = [
            ("shot_0001", "Rechercher (taskbar)",
             "the 'Rechercher' search text in the Windows taskbar at the bottom",
             CALIBRATED_ZONES["rechercher_taskbar"]),
            ("shot_0001", "agent_v1 (dossier)",
             "the folder named 'agent_v1' in the file list",
             CALIBRATED_ZONES["agent_v1_folder"]),
            ("shot_0004", "Fichier (menu)",
             "the 'Fichier' menu item in the menu bar",
             CALIBRATED_ZONES["fichier_menu"]),
            ("shot_0004", "Modifier (menu)",
             "the 'Modifier' menu item in the menu bar",
             CALIBRATED_ZONES["modifier_menu"]),
            ("shot_0004", "Ceci est un test.txt (onglet)",
             "the tab labeled 'Ceci est un test.txt'",
             CALIBRATED_ZONES["ceci_est_un_test_tab"]),
            ("shot_0004", "Close X (Bloc-notes)",
             "the close button X of the Notepad window at the top right",
             CALIBRATED_ZONES["close_x_notepad"]),
            ("shot_0014", "Recherche Google (barre)",
             "the Google search bar 'Rechercher sur Google ou saisir une URL'",
             CALIBRATED_ZONES["google_search_bar"]),
            ("shot_0014", "Gmail (lien)",
             "the 'Gmail' link at the top of the page",
             CALIBRATED_ZONES["gmail_link"]),
        ]

        report_lines = [
            "",
            "=" * 80,
            "RAPPORT DE ROBUSTESSE — Grounding VLM qwen2.5vl:7b",
            f"Date: {time.strftime('%Y-%m-%d %H:%M:%S')}",
            f"Screenshots: 1280x800 (3 images, {len(targets)} cibles)",
            "Répétitions: 5 par cible (original + Citrix Q20)",
            "=" * 80,
            "",
            "--- ORIGINAL (PNG) ---",
            f"{'Élément':<35} {'Taux':>6} {'X moy':>8} {'Y moy':>8} "
            f"{'Var X':>8} {'Var Y':>8} {'Zone':>6}",
            "-" * 80,
        ]

        all_original_stats = []
        all_citrix_stats = []

        for shot_name, label, desc, zone in targets:
            # Original image: 5 runs per target.
            results_orig = _run_n_times(shots[shot_name], desc, n=5, delay=0.2)
            stats_orig = _compute_stats(results_orig)
            all_original_stats.append((label, stats_orig, zone))

            in_zone = "?"
            if stats_orig["detected"] >= 1:
                cx, cy = stats_orig["x_mean"], stats_orig["y_mean"]
                ok = (zone["x_min"] <= cx <= zone["x_max"]
                      and zone["y_min"] <= cy <= zone["y_max"])
                in_zone = "OK" if ok else "HORS"

            report_lines.append(
                f"{label:<35} {stats_orig['rate_str']:>6} "
                f"{stats_orig.get('x_mean', 0):>8.4f} "
                f"{stats_orig.get('y_mean', 0):>8.4f} "
                f"{stats_orig.get('x_range', 0):>8.4f} "
                f"{stats_orig.get('y_range', 0):>8.4f} "
                f"{in_zone:>6}"
            )

        report_lines.extend([
            "",
            "--- CITRIX (JPEG Q20) ---",
            f"{'Élément':<35} {'Taux':>6} {'X moy':>8} {'Y moy':>8} "
            f"{'Var X':>8} {'Var Y':>8} {'Zone':>6} {'Écart orig':>10}",
            "-" * 90,
        ])

        for i, (shot_name, label, desc, zone) in enumerate(targets):
            # Citrix-degraded image: 5 runs per target.
            citrix_b64 = _degrade_citrix(shots[shot_name], quality=20)
            results_citrix = _run_n_times(citrix_b64, desc, n=5, delay=0.2)
            stats_citrix = _compute_stats(results_citrix)
            all_citrix_stats.append((label, stats_citrix, zone))

            in_zone = "?"
            ecart = "N/A"
            if stats_citrix["detected"] >= 1:
                cx, cy = stats_citrix["x_mean"], stats_citrix["y_mean"]
                ok = (zone["x_min"] <= cx <= zone["x_max"]
                      and zone["y_min"] <= cy <= zone["y_max"])
                in_zone = "OK" if ok else "HORS"

                # Deviation vs the original-image run at the same index.
                orig_stats = all_original_stats[i][1]
                if orig_stats["detected"] >= 1:
                    dx = abs(cx - orig_stats["x_mean"])
                    dy = abs(cy - orig_stats["y_mean"])
                    ecart = f"{dx:.4f}/{dy:.4f}"

            report_lines.append(
                f"{label:<35} {stats_citrix['rate_str']:>6} "
                f"{stats_citrix.get('x_mean', 0):>8.4f} "
                f"{stats_citrix.get('y_mean', 0):>8.4f} "
                f"{stats_citrix.get('x_range', 0):>8.4f} "
                f"{stats_citrix.get('y_range', 0):>8.4f} "
                f"{in_zone:>6} {ecart:>10}"
            )

        # Summary counters.
        orig_total = sum(s["detected"] for _, s, _ in all_original_stats)
        orig_max = sum(s["total"] for _, s, _ in all_original_stats)
        citrix_total = sum(s["detected"] for _, s, _ in all_citrix_stats)
        citrix_max = sum(s["total"] for _, s, _ in all_citrix_stats)

        orig_in_zone = sum(
            1 for _, s, z in all_original_stats
            if s["detected"] >= 1
            and z["x_min"] <= s["x_mean"] <= z["x_max"]
            and z["y_min"] <= s["y_mean"] <= z["y_max"]
        )
        citrix_in_zone = sum(
            1 for _, s, z in all_citrix_stats
            if s["detected"] >= 1
            and z["x_min"] <= s["x_mean"] <= z["x_max"]
            and z["y_min"] <= s["y_mean"] <= z["y_max"]
        )

        # Unreliable elements: low detection rate or excessive variance.
        unreliable = []
        for label, s, _ in all_original_stats:
            if s["detected"] < 3:
                unreliable.append(f"{label} (taux {s['rate_str']})")
            elif s.get("x_range", 0) >= _MAX_VARIANCE or s.get("y_range", 0) >= _MAX_VARIANCE:
                unreliable.append(
                    f"{label} (variance X={s.get('x_range', 0):.4f} "
                    f"Y={s.get('y_range', 0):.4f})"
                )

        report_lines.extend([
            "",
            "=" * 80,
            "RÉSUMÉ",
            "=" * 80,
            f" Détection original : {orig_total}/{orig_max} "
            f"({orig_total/orig_max*100:.0f}%)",
            f" Détection Citrix Q20: {citrix_total}/{citrix_max} "
            f"({citrix_total/citrix_max*100:.0f}%)",
            f" Positionnement correct (original) : {orig_in_zone}/{len(all_original_stats)}",
            f" Positionnement correct (Citrix) : {citrix_in_zone}/{len(all_citrix_stats)}",
            "",
        ])

        if unreliable:
            report_lines.append(" ÉLÉMENTS NON FIABLES :")
            for u in unreliable:
                report_lines.append(f"   - {u}")
        else:
            report_lines.append(" Tous les éléments sont fiables.")

        report_lines.extend([
            "",
            " NOTES TECHNIQUES :",
            " - qwen2.5vl bbox_2d retourne des pixels relatifs à l'image envoyée",
            " - Normalisation : diviser par les dimensions de l'image (W, H)",
            " - temperature=0.1 donne une variance < 0.003 typiquement",
            "=" * 80,
        ])

        report = "\n".join(report_lines)
        print(report)

        # Pass criterion: at least 80% of the original-image detections work.
        assert orig_total / orig_max >= 0.80, (
            f"Taux de détection global trop bas: {orig_total}/{orig_max}"
        )