feat: VLM-first visual replay, separate worker, Léa package, AZERTY, HTTPS security
Visual replay pipeline:
- VLM-first: the agent calls Ollama directly to locate elements
- Template matching as fallback (strict 0.90 threshold)
- Immediate stop when an element is not found (no blind clicks)
- Replay from the raw session (/replay-session) without waiting for the VLM
- Post-action verification (screenshot hash before/after)
- Popup handling (Enter/Escape/Tab+Enter)

Separate VLM worker:
- run_worker.py: a process distinct from the HTTP server
- Communication through files (_worker_queue.txt + _replay_active.lock)
- The HTTP server never runs VLM calls anymore → always responsive
- systemd service rpa-worker.service

Keyboard capture:
- raw_keys (vk + press/release) for exact, layout-independent replay
- AZERTY fix: ToUnicodeEx + AltGr detection
- Enter captured as \n, Tab as \t
- Bare modifiers filtered out (stray Ctrl/Alt/Shift)
- Consecutive text_input events merged, key_combo deduplicated

Security & Internet:
- HTTPS via Let's Encrypt (lea.labs + vwb.labs.laurinebazin.design)
- Fixed API token in .env.local
- HTTP Basic Auth on VWB
- Security headers (HSTS, CSP, nosniff)
- CORS limited to the public domains, no more wildcard

Infrastructure:
- DPI awareness (SetProcessDpiAwareness) in Python + Rust
- System metadata (dpi_scale, window_bounds, monitors, os_theme)
- Multi-scale template matching [0.5, 2.0]
- Dynamic resolution (no more hardcoded 1920x1080)
- VLM prefill fix (47x speedup, 3.5s instead of 180s)

Modules:
- core/auth/: credential vault (Fernet AES), TOTP (RFC 6238), auth handler
- core/federation/: anonymized LearningPack export/import, global FAISS index
- deploy/: Léa package (config.txt, Lea.bat, install.bat, LISEZMOI.txt)

UX:
- OS filtering (VWB + Chat only show workflows from the current OS)
- Persistent library (local cache + SQLite)
- Hybrid clustering (window title + DBSCAN)
- EdgeConstraints + PostConditions populated
- GraphBuilder compound actions (all keystrokes)

Rust agent:
- Bearer token auth (network.rs)
- sysinfo.rs (DPI, resolution, window bounds via the Win32 API)
- config.txt read automatically
- Chrome/Brave/Firefox support (not just Edge)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
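run_worker.py itself is not part of this excerpt, so the hand-off between the HTTP server and the worker is only described above. A minimal sketch of that file-based protocol, assuming one pending request per line in _worker_queue.txt and a lock file that simply exists while a replay is running (everything besides the two file names is illustrative, not the actual implementation):

import time
from pathlib import Path

QUEUE = Path("_worker_queue.txt")   # the HTTP server only appends to this file
LOCK = Path("_replay_active.lock")  # present while a replay is in progress

def pop_next_job():
    """Return one pending request, or None when busy or idle (assumed format)."""
    if LOCK.exists() or not QUEUE.exists():
        return None
    pending = QUEUE.read_text(encoding="utf-8").splitlines()
    if not pending:
        return None
    job, rest = pending[0], pending[1:]
    QUEUE.write_text("\n".join(rest) + ("\n" if rest else ""), encoding="utf-8")
    return job

def worker_loop(poll_interval=1.0):
    """Polling loop of the separate VLM worker process (sketch only)."""
    while True:
        job = pop_next_job()
        if job:
            LOCK.touch()
            try:
                print("replaying", job)  # VLM calls would happen here
            finally:
                LOCK.unlink(missing_ok=True)
        time.sleep(poll_interval)

Because the server only appends and the worker only consumes, the HTTP process never blocks on VLM work, which is the responsiveness property the commit message claims.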
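The core/auth tests below exercise a TOTPGenerator against the RFC 6238 vectors without showing the class itself; the sketch below is only the standard HOTP/TOTP computation those vectors pin down (function name and signature are illustrative, not the module's API):

import base64
import hmac
import struct
import time

def totp_code(secret_b32, timestamp=None, digits=6, interval=30, algorithm="sha1"):
    """Plain RFC 6238: HMAC over the time counter, dynamic truncation, zero-padded."""
    key = base64.b32decode(secret_b32.replace(" ", "").upper())
    counter = int((time.time() if timestamp is None else timestamp) // interval)
    digest = hmac.new(key, struct.pack(">Q", counter), algorithm).digest()
    offset = digest[-1] & 0x0F
    value = struct.unpack(">I", digest[offset:offset + 4])[0] & 0x7FFFFFFF
    return str(value % (10 ** digits)).zfill(digits)

# RFC 6238 Appendix B vector: SHA1, 8 digits, step 30, time 59 -> 94287082
assert totp_code("GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ", 59, digits=8) == "94287082"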
@@ -220,7 +220,7 @@ class TestStreamWorker:
         event_file = session_dir / "live_events.jsonl"
         event_file.write_text(
             json.dumps({"type": "click", "timestamp": 100}) + "\n"
-            + json.dumps({"type": "key_press", "timestamp": 200}) + "\n"
+            + json.dumps({"type": "key_press", "keys": ["enter"], "timestamp": 200}) + "\n"
         )

         # Simulate one polling pass
tests/unit/test_auth.py (new file, 576 lines)
@@ -0,0 +1,576 @@
"""
Tests for the automatic authentication module (core/auth).

Covers:
- TOTPGenerator: generation, verification, RFC 6238 test vectors
- CredentialVault: CRUD, encryption, persistence
- AuthHandler: auth screen detection, action generation
"""

import json
import os
import tempfile
import time

import pytest

from core.auth.credential_vault import CredentialVault, _HAS_FERNET
from core.auth.totp_generator import TOTPGenerator
from core.auth.auth_handler import AuthHandler, AuthRequest


# =========================================================================
# TOTP tests
# =========================================================================


class TestTOTPGenerator:
    """Tests for the RFC 6238 TOTP generator."""

    def test_generate_returns_6_digits(self):
        """The generated code is exactly 6 digits."""
        totp = TOTPGenerator("JBSWY3DPEHPK3PXP")
        code = totp.generate()
        assert len(code) == 6
        assert code.isdigit()

    def test_generate_deterministic(self):
        """The same timestamp yields the same code."""
        totp = TOTPGenerator("JBSWY3DPEHPK3PXP")
        ts = 1700000000.0
        code1 = totp.generate(timestamp=ts)
        code2 = totp.generate(timestamp=ts)
        assert code1 == code2

    def test_verify_current_code(self):
        """The generated code is accepted by verify()."""
        totp = TOTPGenerator("JBSWY3DPEHPK3PXP")
        ts = time.time()
        code = totp.generate(timestamp=ts)
        assert totp.verify(code, timestamp=ts)

    def test_verify_rejects_wrong_code(self):
        """An incorrect code is rejected."""
        totp = TOTPGenerator("JBSWY3DPEHPK3PXP")
        # Use a timestamp large enough to avoid issues
        # with window=-1 (negative counter)
        assert not totp.verify("000000", timestamp=1700000000.0)

    def test_verify_with_window(self):
        """The tolerance window accepts adjacent codes."""
        totp = TOTPGenerator("JBSWY3DPEHPK3PXP", interval=30)
        ts = 1700000000.0
        # Code from the previous interval
        prev_code = totp.generate(timestamp=ts - 30)
        assert totp.verify(prev_code, timestamp=ts, window=1)
        # Code from the next interval
        next_code = totp.generate(timestamp=ts + 30)
        assert totp.verify(next_code, timestamp=ts, window=1)

    def test_verify_window_zero_strict(self):
        """Window=0 only accepts the exact code of the current interval."""
        totp = TOTPGenerator("JBSWY3DPEHPK3PXP", interval=30)
        ts = 1700000000.0
        code = totp.generate(timestamp=ts)
        assert totp.verify(code, timestamp=ts, window=0)
        prev_code = totp.generate(timestamp=ts - 30)
        assert not totp.verify(prev_code, timestamp=ts, window=0)

    def test_time_remaining_in_range(self):
        """time_remaining() returns a value between 1 and interval."""
        totp = TOTPGenerator("JBSWY3DPEHPK3PXP", interval=30)
        remaining = totp.time_remaining()
        assert 1 <= remaining <= 30

    def test_8_digits(self):
        """8-digit codes are supported."""
        totp = TOTPGenerator("JBSWY3DPEHPK3PXP", digits=8)
        code = totp.generate()
        assert len(code) == 8
        assert code.isdigit()

    def test_rfc6238_sha1_vector(self):
        """RFC 6238 test vector for SHA1.

        Test secret: "12345678901234567890" (ASCII)
        In base32: "GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ"
        Timestamp: 59 → T = 59 // 30 = 1 → expected 8-digit code 94287082
        """
        # The ASCII secret "12345678901234567890" encoded in base32
        secret_b32 = "GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ"
        totp = TOTPGenerator(secret_b32, digits=8, interval=30, algorithm="SHA1")
        code = totp.generate(timestamp=59)
        assert code == "94287082"

    def test_rfc6238_sha1_vector_t1111111109(self):
        """RFC 6238 test vector at T=1111111109."""
        secret_b32 = "GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ"
        totp = TOTPGenerator(secret_b32, digits=8, interval=30, algorithm="SHA1")
        code = totp.generate(timestamp=1111111109)
        assert code == "07081804"

    def test_rfc6238_sha256_vector(self):
        """RFC 6238 test vector for SHA256.

        32-byte secret: "12345678901234567890123456789012"
        In base32: "GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQGEZA"
        Timestamp: 59 → expected code 46119246
        """
        secret_b32 = "GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQGEZA"
        totp = TOTPGenerator(secret_b32, digits=8, interval=30, algorithm="SHA256")
        code = totp.generate(timestamp=59)
        assert code == "46119246"

    def test_invalid_secret_raises(self):
        """An invalid secret raises ValueError."""
        with pytest.raises(ValueError, match="base32 invalide"):
            TOTPGenerator("!!! not base32 !!!")

    def test_invalid_algorithm_raises(self):
        """An unknown algorithm raises ValueError."""
        with pytest.raises(ValueError, match="non supporté"):
            TOTPGenerator("JBSWY3DPEHPK3PXP", algorithm="MD5")

    def test_secret_with_spaces(self):
        """Spaces in the secret are tolerated."""
        totp1 = TOTPGenerator("JBSWY3DPEHPK3PXP")
        totp2 = TOTPGenerator("JBSW Y3DP EHPK 3PXP")
        ts = 1700000000.0
        assert totp1.generate(timestamp=ts) == totp2.generate(timestamp=ts)

    def test_zero_padded_code(self):
        """Short codes are zero-padded (e.g. 003271, not 3271)."""
        totp = TOTPGenerator("JBSWY3DPEHPK3PXP")
        # Try many timestamps so codes with a leading 0 are covered as well
        for ts in range(1700000000, 1700001000, 30):
            code = totp.generate(timestamp=float(ts))
            assert len(code) == 6, f"Code {code!r} does not have 6 digits for ts={ts}"


# =========================================================================
# CredentialVault tests
# =========================================================================


class TestCredentialVault:
    """Tests for the encrypted credential vault."""

    def test_create_add_get(self):
        """Create a vault, add a credential, read it back."""
        with tempfile.NamedTemporaryFile(suffix=".enc", delete=False) as f:
            vault_path = f.name
        try:
            os.unlink(vault_path)  # Delete so the vault gets created from scratch
            vault = CredentialVault(vault_path, "test_password")
            vault.add_credential("TestApp", "login", {
                "username": "user1",
                "password": "pass1",
            })
            cred = vault.get_credential("TestApp", "login")
            assert cred is not None
            assert cred["username"] == "user1"
            assert cred["password"] == "pass1"
        finally:
            if os.path.exists(vault_path):
                os.unlink(vault_path)

    def test_save_and_reload(self):
        """Saving and reloading a vault preserves the data."""
        with tempfile.NamedTemporaryFile(suffix=".enc", delete=False) as f:
            vault_path = f.name
        try:
            os.unlink(vault_path)
            vault = CredentialVault(vault_path, "master123")
            vault.add_credential("MyApp", "login", {
                "username": "admin",
                "password": "secret",
            })
            vault.add_credential("MyApp", "totp_seed", {
                "secret": "JBSWY3DPEHPK3PXP",
                "digits": 6,
                "interval": 30,
                "algorithm": "SHA1",
            })
            vault.save()

            # Reload
            vault2 = CredentialVault(vault_path, "master123")
            assert vault2.list_apps() == ["MyApp"]
            login = vault2.get_credential("MyApp", "login")
            assert login["username"] == "admin"
            totp = vault2.get_credential("MyApp", "totp_seed")
            assert totp["secret"] == "JBSWY3DPEHPK3PXP"
        finally:
            if os.path.exists(vault_path):
                os.unlink(vault_path)

    def test_remove_credential(self):
        """Removing a credential works."""
        with tempfile.NamedTemporaryFile(suffix=".enc", delete=False) as f:
            vault_path = f.name
        try:
            os.unlink(vault_path)
            vault = CredentialVault(vault_path, "pw")
            vault.add_credential("App1", "login", {"username": "u", "password": "p"})
            assert vault.remove_credential("App1", "login") is True
            assert vault.get_credential("App1", "login") is None
            assert vault.list_apps() == []
        finally:
            if os.path.exists(vault_path):
                os.unlink(vault_path)

    def test_remove_nonexistent(self):
        """Removing a nonexistent credential returns False."""
        with tempfile.NamedTemporaryFile(suffix=".enc", delete=False) as f:
            vault_path = f.name
        try:
            os.unlink(vault_path)
            vault = CredentialVault(vault_path, "pw")
            assert vault.remove_credential("NopApp", "login") is False
        finally:
            if os.path.exists(vault_path):
                os.unlink(vault_path)

    def test_list_apps_sorted(self):
        """list_apps() returns the apps sorted."""
        with tempfile.NamedTemporaryFile(suffix=".enc", delete=False) as f:
            vault_path = f.name
        try:
            os.unlink(vault_path)
            vault = CredentialVault(vault_path, "pw")
            vault.add_credential("Zebra", "login", {"username": "z", "password": "z"})
            vault.add_credential("Alpha", "login", {"username": "a", "password": "a"})
            vault.add_credential("Middle", "login", {"username": "m", "password": "m"})
            assert vault.list_apps() == ["Alpha", "Middle", "Zebra"]
        finally:
            if os.path.exists(vault_path):
                os.unlink(vault_path)

    def test_invalid_credential_type(self):
        """An invalid credential type raises ValueError."""
        with tempfile.NamedTemporaryFile(suffix=".enc", delete=False) as f:
            vault_path = f.name
        try:
            os.unlink(vault_path)
            vault = CredentialVault(vault_path, "pw")
            with pytest.raises(ValueError, match="invalide"):
                vault.add_credential("App1", "invalid_type", {})
        finally:
            if os.path.exists(vault_path):
                os.unlink(vault_path)

    def test_encryption_on_disk(self):
        """The vault file on disk contains no plaintext data."""
        with tempfile.NamedTemporaryFile(suffix=".enc", delete=False) as f:
            vault_path = f.name
        try:
            os.unlink(vault_path)
            vault = CredentialVault(vault_path, "strong_password_42")
            vault.add_credential("SecretApp", "login", {
                "username": "robot_lea",
                "password": "super_secret_password_xyz",
            })
            vault.save()

            # Read the raw file
            raw_bytes = open(vault_path, "rb").read()
            raw_str = raw_bytes.decode("latin-1")  # To search for ASCII text

            # Sensitive data must NOT appear in plaintext
            assert "robot_lea" not in raw_str
            assert "super_secret_password_xyz" not in raw_str
            assert "SecretApp" not in raw_str
        finally:
            if os.path.exists(vault_path):
                os.unlink(vault_path)

    def test_wrong_password_raises(self):
        """A wrong master password prevents decryption."""
        with tempfile.NamedTemporaryFile(suffix=".enc", delete=False) as f:
            vault_path = f.name
        try:
            os.unlink(vault_path)
            vault = CredentialVault(vault_path, "correct_password")
            vault.add_credential("App", "login", {"username": "u", "password": "p"})
            vault.save()

            # Try to load with a wrong password
            with pytest.raises(ValueError, match="[Mm]ot de passe|corrompu"):
                CredentialVault(vault_path, "wrong_password")
        finally:
            if os.path.exists(vault_path):
                os.unlink(vault_path)

    def test_multiple_credential_types_per_app(self):
        """An app can have several credential types."""
        with tempfile.NamedTemporaryFile(suffix=".enc", delete=False) as f:
            vault_path = f.name
        try:
            os.unlink(vault_path)
            vault = CredentialVault(vault_path, "pw")
            vault.add_credential("DPI", "login", {
                "username": "lea", "password": "p"
            })
            vault.add_credential("DPI", "totp_seed", {
                "secret": "JBSWY3DPEHPK3PXP"
            })
            assert vault.list_credential_types("DPI") == ["login", "totp_seed"]
            assert vault.get_credential("DPI", "login")["username"] == "lea"
            assert vault.get_credential("DPI", "totp_seed")["secret"] == "JBSWY3DPEHPK3PXP"
        finally:
            if os.path.exists(vault_path):
                os.unlink(vault_path)


# =========================================================================
# AuthHandler tests
# =========================================================================


class TestAuthHandler:
    """Tests for the authentication handler."""

    @pytest.fixture
    def vault_with_creds(self, tmp_path):
        """Vault populated with test credentials."""
        vault_path = str(tmp_path / "test_vault.enc")
        vault = CredentialVault(vault_path, "test_pw")
        vault.add_credential("DPI_Crossway", "login", {
            "username": "robot_lea",
            "password": "secret123",
            "domain": "HOPITAL",
        })
        vault.add_credential("DPI_Crossway", "totp_seed", {
            "secret": "JBSWY3DPEHPK3PXP",
            "digits": 6,
            "interval": 30,
            "algorithm": "SHA1",
        })
        vault.add_credential("Outlook", "login", {
            "username": "lea@hopital.fr",
            "password": "outlook_pass",
        })
        return vault

    @pytest.fixture
    def handler(self, vault_with_creds):
        return AuthHandler(vault_with_creds)

    def test_detect_login_screen(self, handler):
        """Detect a classic login screen."""
        screen_state = {
            "perception": {
                "detected_text": [
                    "Bienvenue sur DPI Crossway",
                    "Identifiant",
                    "Mot de passe",
                    "Se connecter",
                ],
            },
            "ui_elements": [
                {"type": "text_input", "role": "text", "label": "Identifiant", "center": [500, 300], "element_id": "e1", "tags": []},
                {"type": "text_input", "role": "password", "label": "Mot de passe", "center": [500, 350], "element_id": "e2", "tags": []},
                {"type": "button", "role": "primary_action", "label": "Se connecter", "center": [500, 420], "element_id": "e3", "tags": []},
            ],
            "window": {"app_name": "DPI_Crossway", "window_title": "DPI Crossway - Connexion"},
        }

        auth_req = handler.detect_auth_screen(screen_state)
        assert auth_req is not None
        assert auth_req.auth_type == "login"
        assert auth_req.app_name == "DPI_Crossway"
        assert auth_req.confidence >= 0.6  # Several signals

    def test_detect_totp_screen(self, handler):
        """Detect a 2FA/TOTP screen (without login elements)."""
        screen_state = {
            "perception": {
                "detected_text": [
                    "Entrez votre code 2FA",
                    "Code à 6 chiffres",
                ],
            },
            "ui_elements": [
                {"type": "text_input", "role": "text", "label": "Code OTP", "center": [500, 350], "element_id": "e1", "tags": []},
                {"type": "button", "role": "primary_action", "label": "Confirmer", "center": [500, 420], "element_id": "e2", "tags": []},
            ],
            "window": {"app_name": "DPI_Crossway"},
        }

        auth_req = handler.detect_auth_screen(screen_state)
        assert auth_req is not None
        assert auth_req.auth_type == "totp"
        assert auth_req.confidence >= 0.3

    def test_detect_login_and_totp(self, handler):
        """Detect a combined login + TOTP screen."""
        screen_state = {
            "perception": {
                "detected_text": [
                    "Connexion sécurisée",
                    "Identifiant",
                    "Mot de passe",
                    "Code OTP",
                ],
            },
            "ui_elements": [
                {"type": "text_input", "role": "text", "label": "Identifiant", "center": [500, 300], "element_id": "e1", "tags": []},
                {"type": "text_input", "role": "password", "label": "Mot de passe", "center": [500, 350], "element_id": "e2", "tags": []},
                {"type": "text_input", "role": "text", "label": "Code OTP", "center": [500, 400], "element_id": "e3", "tags": []},
                {"type": "button", "role": "primary_action", "label": "Valider", "center": [500, 450], "element_id": "e4", "tags": []},
            ],
            "window": {"app_name": "DPI_Crossway"},
        }

        auth_req = handler.detect_auth_screen(screen_state)
        assert auth_req is not None
        assert auth_req.auth_type == "login_and_totp"
        assert auth_req.confidence >= 0.85  # Many signals

    def test_no_auth_on_normal_screen(self, handler):
        """A normal screen does not trigger a detection."""
        screen_state = {
            "perception": {
                "detected_text": ["Patient: Jean Dupont", "Dossier médical", "Résultats"],
            },
            "ui_elements": [
                {"type": "button", "role": "navigation", "label": "Suivant", "center": [500, 500], "element_id": "e1", "tags": []},
            ],
            "window": {"app_name": "DPI_Crossway"},
        }

        auth_req = handler.detect_auth_screen(screen_state)
        assert auth_req is None

    def test_get_auth_actions_login(self, handler):
        """Generate the actions for a classic login."""
        auth_req = AuthRequest(
            auth_type="login",
            app_name="DPI_Crossway",
            detected_fields={
                "username_field": {"type": "text_input", "label": "Identifiant", "center": [500, 300], "element_id": "e1"},
                "password_field": {"type": "text_input", "label": "Mot de passe", "center": [500, 350], "element_id": "e2"},
                "submit_button": {"type": "button", "label": "Se connecter", "center": [500, 420], "element_id": "e3"},
            },
            confidence=0.85,
        )

        actions = handler.get_auth_actions(auth_req)
        assert len(actions) > 0

        # Expected sequence: click username, type username, click password, type password, click submit, wait
        action_types = [(a["type"], a.get("text", "")) for a in actions]

        # There must be clicks and text entries
        has_click = any(a["type"] == "click" for a in actions)
        has_type = any(a["type"] == "type_text" for a in actions)
        has_wait = any(a["type"] == "wait" for a in actions)
        assert has_click
        assert has_type
        assert has_wait

        # Check that the username and password come from the vault
        typed_texts = [a["text"] for a in actions if a["type"] == "type_text"]
        assert "robot_lea" in typed_texts
        assert "secret123" in typed_texts

        # Every action carries the _auth_action flag
        for action in actions:
            assert action.get("_auth_action") is True

    def test_get_auth_actions_totp(self, handler):
        """Generate the actions for a TOTP authentication."""
        auth_req = AuthRequest(
            auth_type="totp",
            app_name="DPI_Crossway",
            detected_fields={
                "otp_field": {"type": "text_input", "label": "Code", "center": [500, 350], "element_id": "e1"},
                "submit_button": {"type": "button", "label": "Valider", "center": [500, 420], "element_id": "e2"},
            },
            confidence=0.85,
        )

        actions = handler.get_auth_actions(auth_req)
        assert len(actions) > 0

        # Check that a TOTP code is typed (6 digits)
        typed_texts = [a["text"] for a in actions if a["type"] == "type_text"]
        assert len(typed_texts) >= 1
        totp_code = typed_texts[0]
        assert len(totp_code) == 6
        assert totp_code.isdigit()

    def test_get_auth_actions_login_and_totp(self, handler):
        """Generate the actions for a combined login + TOTP."""
        auth_req = AuthRequest(
            auth_type="login_and_totp",
            app_name="DPI_Crossway",
            detected_fields={
                "username_field": {"type": "text_input", "label": "Identifiant", "center": [500, 300], "element_id": "e1"},
                "password_field": {"type": "text_input", "label": "Mot de passe", "center": [500, 350], "element_id": "e2"},
                "otp_field": {"type": "text_input", "label": "Code OTP", "center": [500, 400], "element_id": "e3"},
                "submit_button": {"type": "button", "label": "Valider", "center": [500, 450], "element_id": "e4"},
            },
            confidence=0.95,
        )

        actions = handler.get_auth_actions(auth_req)
        assert len(actions) > 0

        typed_texts = [a["text"] for a in actions if a["type"] == "type_text"]
        # username + password + TOTP code
        assert len(typed_texts) >= 3
        assert "robot_lea" in typed_texts
        assert "secret123" in typed_texts
        # The third one is a 6-digit TOTP code
        totp_code = typed_texts[2]
        assert len(totp_code) == 6
        assert totp_code.isdigit()

    def test_get_auth_actions_missing_credentials(self, handler):
        """If the vault has no matching credentials, an empty list is returned."""
        auth_req = AuthRequest(
            auth_type="login",
            app_name="AppInconnue",
            detected_fields={
                "username_field": {"type": "text_input", "label": "Login", "center": [500, 300], "element_id": "e1"},
                "password_field": {"type": "text_input", "label": "Password", "center": [500, 350], "element_id": "e2"},
            },
            confidence=0.85,
        )

        actions = handler.get_auth_actions(auth_req)
        assert actions == []

    def test_detect_english_auth_screen(self, handler):
        """Detect an English-language auth screen."""
        screen_state = {
            "perception": {
                "detected_text": ["Sign in to your account", "Username", "Password"],
            },
            "ui_elements": [
                {"type": "text_input", "role": "text", "label": "Username", "center": [500, 300], "element_id": "e1", "tags": []},
                {"type": "text_input", "role": "password", "label": "Password", "center": [500, 350], "element_id": "e2", "tags": []},
                {"type": "button", "role": "primary_action", "label": "Sign in", "center": [500, 420], "element_id": "e3", "tags": []},
            ],
            "window": {"app_name": "Outlook"},
        }

        auth_req = handler.detect_auth_screen(screen_state)
        assert auth_req is not None
        assert auth_req.auth_type == "login"
        assert auth_req.app_name == "Outlook"

    def test_detect_password_tag(self, handler):
        """Detect a password field through the UI element tags."""
        screen_state = {
            "perception": {"detected_text": []},
            "ui_elements": [
                {"type": "text_input", "role": "text", "label": "", "center": [500, 300], "element_id": "e1", "tags": ["password"]},
            ],
            "window": {"app_name": "SomeApp"},
        }

        auth_req = handler.detect_auth_screen(screen_state)
        assert auth_req is not None
        assert "password_field" in auth_req.detected_fields
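The vault tests above only pin down observable behaviour (no plaintext on disk, wrong master password rejected). The commit message says the vault uses Fernet (AES), but how the master password becomes a key is not shown in this excerpt, so the following is only a plausible sketch assuming a PBKDF2 derivation with a stored salt; none of these names belong to the real module:

import base64
import os
from cryptography.fernet import Fernet, InvalidToken
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

def fernet_from_password(password: str, salt: bytes) -> Fernet:
    # Derive a 32-byte key from the master password, then wrap it for Fernet
    kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), length=32, salt=salt, iterations=480_000)
    return Fernet(base64.urlsafe_b64encode(kdf.derive(password.encode("utf-8"))))

salt = os.urandom(16)  # would be stored next to the vault file
token = fernet_from_password("master123", salt).encrypt(b'{"username": "admin"}')
try:
    fernet_from_password("wrong_password", salt).decrypt(token)
except InvalidToken:
    print("a wrong master password cannot decrypt the vault")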
tests/unit/test_learning_pack.py (new file, 721 lines)
@@ -0,0 +1,721 @@
"""
Unit tests for core.federation.learning_pack

Checks:
- Export of a simple workflow → no screenshots/OCR in the pack
- Merge of 2 packs → correct deduplication of the prototypes
- JSON serialization / deserialization round-trip
- client_id anonymization (SHA-256, never in plaintext)
- Filtering of sensitive data (long OCR texts, metadata)
- Global FAISS index (build, search, persistence)
"""

import hashlib
import json
import tempfile
from datetime import datetime
from pathlib import Path
from typing import List

import numpy as np
import pytest

from core.federation.learning_pack import (
    DEDUP_COSINE_THRESHOLD,
    LEARNING_PACK_VERSION,
    AppSignature,
    EdgeStatistic,
    ErrorPattern,
    LearningPack,
    LearningPackExporter,
    LearningPackMerger,
    ScreenPrototype,
    UIPattern,
    WorkflowSkeleton,
    _hash_client_id,
    _sanitize_text,
)
from core.models.workflow_graph import (
    Action,
    EdgeConstraints,
    EdgeStats,
    EmbeddingPrototype,
    PostConditionCheck,
    PostConditions,
    ScreenTemplate,
    TargetSpec,
    TextConstraint,
    UIConstraint,
    WindowConstraint,
    Workflow,
    WorkflowEdge,
    WorkflowNode,
)


# ============================================================================
# Helpers: building test workflows
# ============================================================================

def _make_node(
    node_id: str,
    name: str,
    process_name: str = "Notepad.exe",
    title_pattern: str = ".*Sans titre.*",
    required_roles: List[str] = None,
    prototype_vector: List[float] = None,
) -> WorkflowNode:
    """Create a minimal WorkflowNode for the tests."""
    window = WindowConstraint(
        title_pattern=title_pattern,
        process_name=process_name,
    )
    text = TextConstraint(
        required_texts=["Fichier", "Edition"],
        forbidden_texts=["Erreur critique"],
    )
    ui = UIConstraint(
        required_roles=required_roles or ["button", "textfield"],
    )
    embedding = EmbeddingPrototype(
        provider="openclip_ViT-B-32",
        vector_id="",
        min_cosine_similarity=0.85,
        sample_count=5,
    )
    template = ScreenTemplate(window=window, text=text, ui=ui, embedding=embedding)

    metadata = {}
    if prototype_vector is not None:
        metadata["_prototype_vector"] = prototype_vector

    return WorkflowNode(
        node_id=node_id,
        name=name,
        description=f"Node de test : {name}",
        template=template,
        metadata=metadata,
    )


def _make_edge(
    edge_id: str,
    from_node: str,
    to_node: str,
    action_type: str = "mouse_click",
    target_role: str = "button",
    fail_fast_texts: List[str] = None,
) -> WorkflowEdge:
    """Create a minimal WorkflowEdge for the tests."""
    target = TargetSpec(by_role=target_role)
    action = Action(type=action_type, target=target)
    constraints = EdgeConstraints()

    fail_fast = []
    for txt in (fail_fast_texts or []):
        fail_fast.append(PostConditionCheck(kind="text_present", value=txt))

    post_conditions = PostConditions(fail_fast=fail_fast)
    stats = EdgeStats(execution_count=10, success_count=9, avg_execution_time_ms=150.0)

    return WorkflowEdge(
        edge_id=edge_id,
        from_node=from_node,
        to_node=to_node,
        action=action,
        constraints=constraints,
        post_conditions=post_conditions,
        stats=stats,
    )


def _make_workflow(
    workflow_id: str = "wf_test_001",
    name: str = "Workflow Test",
    with_vectors: bool = True,
) -> Workflow:
    """Create a minimal but complete Workflow for the tests."""
    vec_a = np.random.randn(512).tolist() if with_vectors else None
    vec_b = np.random.randn(512).tolist() if with_vectors else None

    node_a = _make_node("node_a", "Écran principal", prototype_vector=vec_a)
    node_b = _make_node(
        "node_b", "Dialogue Enregistrer",
        process_name="Notepad.exe",
        title_pattern=".*Enregistrer.*",
        prototype_vector=vec_b,
    )

    edge_ab = _make_edge(
        "edge_ab", "node_a", "node_b",
        fail_fast_texts=["Accès refusé", "Fichier introuvable"],
    )

    now = datetime.now()
    return Workflow(
        workflow_id=workflow_id,
        name=name,
        description="Workflow de test pour Learning Pack",
        version=1,
        learning_state="COACHING",
        created_at=now,
        updated_at=now,
        entry_nodes=["node_a"],
        end_nodes=["node_b"],
        nodes=[node_a, node_b],
        edges=[edge_ab],
        safety_rules=Workflow.from_dict({
            "workflow_id": "tmp", "name": "tmp", "nodes": [], "edges": [],
            "safety_rules": {}, "stats": {}, "learning": {},
            "entry_nodes": [], "end_nodes": [], "created_at": now.isoformat(),
            "updated_at": now.isoformat(),
        }).safety_rules,
        stats=Workflow.from_dict({
            "workflow_id": "tmp", "name": "tmp", "nodes": [], "edges": [],
            "safety_rules": {}, "stats": {}, "learning": {},
            "entry_nodes": [], "end_nodes": [], "created_at": now.isoformat(),
            "updated_at": now.isoformat(),
        }).stats,
        learning=Workflow.from_dict({
            "workflow_id": "tmp", "name": "tmp", "nodes": [], "edges": [],
            "safety_rules": {}, "stats": {}, "learning": {},
            "entry_nodes": [], "end_nodes": [], "created_at": now.isoformat(),
            "updated_at": now.isoformat(),
        }).learning,
    )


# ============================================================================
# Tests: anonymization
# ============================================================================

class TestAnonymisation:
    """Check that anonymization works correctly."""

    def test_client_id_est_hashe(self):
        """The client_id must NOT appear in plaintext in the pack."""
        wf = _make_workflow()
        exporter = LearningPackExporter()
        pack = exporter.export([wf], client_id="CHU-Lyon-001")

        pack_json = json.dumps(pack.to_dict())
        assert "CHU-Lyon-001" not in pack_json, \
            "The client_id appears in plaintext in the pack!"

    def test_source_hash_est_sha256(self):
        """The source_hash must be a SHA-256 hash of the client_id."""
        wf = _make_workflow()
        exporter = LearningPackExporter()
        pack = exporter.export([wf], client_id="CHU-Lyon-001")

        expected_hash = hashlib.sha256(b"CHU-Lyon-001").hexdigest()
        assert pack.source_hash == expected_hash

    def test_hash_client_id_deterministe(self):
        """The same client_id must always produce the same hash."""
        h1 = _hash_client_id("Clinique-Pasteur")
        h2 = _hash_client_id("Clinique-Pasteur")
        assert h1 == h2

    def test_hash_client_id_differents(self):
        """Two different client_id values must produce different hashes."""
        h1 = _hash_client_id("CHU-Lyon")
        h2 = _hash_client_id("CHU-Marseille")
        assert h1 != h2

    def test_pas_de_screenshots_dans_pack(self):
        """The pack must not contain any screenshot path."""
        wf = _make_workflow()
        # Add a screenshot path to the node metadata
        wf.nodes[0].metadata["screenshot_path"] = "/tmp/capture_001.png"
        wf.nodes[0].metadata["ocr_text"] = "Texte OCR brut avec données patient"

        exporter = LearningPackExporter()
        pack = exporter.export([wf], client_id="test")

        pack_json = json.dumps(pack.to_dict())
        assert "/tmp/capture_001.png" not in pack_json
        assert "données patient" not in pack_json

    def test_texte_ocr_long_filtre(self):
        """Long OCR texts (> 120 chars) must be filtered out."""
        assert _sanitize_text("OK") == "OK"
        assert _sanitize_text("x" * 200) is None
        assert _sanitize_text("") is None

    def test_texte_patient_filtre(self):
        """Texts containing patient identifiers must be filtered out."""
        assert _sanitize_text("patient Dupont") is None
        assert _sanitize_text("NIP: 123456") is None
        assert _sanitize_text("Dossier n°789") is None

    def test_texte_court_et_sur_passe(self):
        """Short, non-sensitive texts must pass through unchanged."""
        assert _sanitize_text("Enregistrer") == "Enregistrer"
        assert _sanitize_text("Fichier") == "Fichier"
        assert _sanitize_text("Erreur de connexion") == "Erreur de connexion"


# ============================================================================
# Tests: export
# ============================================================================

class TestExport:
    """Check exporting workflows into a Learning Pack."""

    def test_export_basique(self):
        """Exporting a simple workflow must produce a valid pack."""
        wf = _make_workflow()
        exporter = LearningPackExporter()
        pack = exporter.export([wf], client_id="test_client")

        assert pack.version == LEARNING_PACK_VERSION
        assert pack.pack_id.startswith("lp_")
        assert pack.source_hash  # Not empty
        assert pack.created_at  # Not empty

    def test_export_stats(self):
        """The pack stats must reflect its content."""
        wf = _make_workflow()
        exporter = LearningPackExporter()
        pack = exporter.export([wf], client_id="test")

        assert pack.stats["workflows_count"] == 1
        assert pack.stats["total_nodes"] == 2
        assert pack.stats["total_edges"] == 1
        assert "Notepad.exe" in pack.stats["apps_seen"]

    def test_export_prototypes_avec_vecteurs(self):
        """The prototypes must contain the 512-d vectors."""
        wf = _make_workflow(with_vectors=True)
        exporter = LearningPackExporter()
        pack = exporter.export([wf], client_id="test")

        assert len(pack.screen_prototypes) == 2
        for proto in pack.screen_prototypes:
            assert proto.vector is not None
            assert len(proto.vector) == 512

    def test_export_prototypes_sans_vecteurs(self):
        """Export must work even without prototype vectors."""
        wf = _make_workflow(with_vectors=False)
        exporter = LearningPackExporter()
        pack = exporter.export([wf], client_id="test")

        # The prototypes are exported, but without a vector
        assert len(pack.screen_prototypes) == 2
        for proto in pack.screen_prototypes:
            assert proto.vector is None

    def test_export_app_signatures(self):
        """Application signatures must be collected."""
        wf = _make_workflow()
        exporter = LearningPackExporter()
        pack = exporter.export([wf], client_id="test")

        app_names = [sig.app_name for sig in pack.app_signatures]
        assert "Notepad.exe" in app_names

    def test_export_error_patterns(self):
        """Error patterns from the PostConditions must be extracted."""
        wf = _make_workflow()
        exporter = LearningPackExporter()
        pack = exporter.export([wf], client_id="test")

        error_texts = [ep.error_text for ep in pack.error_patterns]
        assert "Accès refusé" in error_texts
        assert "Fichier introuvable" in error_texts

    def test_export_edge_statistics(self):
        """Edge statistics must be exported."""
        wf = _make_workflow()
        exporter = LearningPackExporter()
        pack = exporter.export([wf], client_id="test")

        assert len(pack.edge_statistics) == 1
        stat = pack.edge_statistics[0]
        assert stat.action_type == "mouse_click"
        assert stat.execution_count == 10
        assert stat.success_rate == 0.9

    def test_export_workflow_skeleton(self):
        """The workflow skeleton must reflect the structure."""
        wf = _make_workflow()
        exporter = LearningPackExporter()
        pack = exporter.export([wf], client_id="test")

        assert len(pack.workflow_skeletons) == 1
        skel = pack.workflow_skeletons[0]
        assert skel.node_count == 2
        assert skel.edge_count == 1
        assert "Écran principal" in skel.node_names
        assert skel.learning_state == "COACHING"

    def test_export_action_sans_texte_saisi(self):
        """The export must NOT include typed text (text_input actions)."""
        wf = _make_workflow()
        # Add a text_input edge carrying sensitive text
        edge_text = _make_edge(
            "edge_text", "node_a", "node_b",
            action_type="text_input", target_role="textfield",
        )
        edge_text.action.parameters["text"] = "mot_de_passe_secret_123"
        wf.edges.append(edge_text)

        exporter = LearningPackExporter()
        pack = exporter.export([wf], client_id="test")

        pack_json = json.dumps(pack.to_dict())
        assert "mot_de_passe_secret_123" not in pack_json


# ============================================================================
# Tests: serialization
# ============================================================================

class TestSerialisation:
    """Check the JSON round-trip (to_dict → from_dict)."""

    def test_round_trip_learning_pack(self):
        """Serialization followed by deserialization must be lossless."""
        wf = _make_workflow()
        exporter = LearningPackExporter()
        pack = exporter.export([wf], client_id="round_trip_test")

        # Serialize → deserialize
        data = pack.to_dict()
        restored = LearningPack.from_dict(data)

        assert restored.version == pack.version
        assert restored.source_hash == pack.source_hash
        assert restored.pack_id == pack.pack_id
        assert len(restored.screen_prototypes) == len(pack.screen_prototypes)
        assert len(restored.workflow_skeletons) == len(pack.workflow_skeletons)
        assert len(restored.error_patterns) == len(pack.error_patterns)
        assert len(restored.edge_statistics) == len(pack.edge_statistics)

    def test_round_trip_json_string(self):
        """The JSON must be parseable and reproducible."""
        wf = _make_workflow()
        exporter = LearningPackExporter()
        pack = exporter.export([wf], client_id="json_test")

        json_str = json.dumps(pack.to_dict(), sort_keys=True)
        data = json.loads(json_str)
        restored = LearningPack.from_dict(data)

        assert json.dumps(restored.to_dict(), sort_keys=True) == json_str

    def test_save_load_fichier(self, tmp_path):
        """Saving then loading from a file must preserve the pack."""
        wf = _make_workflow()
        exporter = LearningPackExporter()
        pack = exporter.export([wf], client_id="file_test")

        filepath = tmp_path / "test_pack.json"
        pack.save(filepath)

        loaded = LearningPack.load(filepath)
        assert loaded.pack_id == pack.pack_id
        assert loaded.source_hash == pack.source_hash
        assert len(loaded.screen_prototypes) == len(pack.screen_prototypes)

    def test_all_sub_dataclasses_round_trip(self):
        """Every sub-structure must support the round-trip."""
        sig = AppSignature(app_name="Chrome.exe", version="120.0", observation_count=5)
        assert AppSignature.from_dict(sig.to_dict()).app_name == "Chrome.exe"

        proto = ScreenPrototype(
            prototype_id="test",
            vector=[1.0, 2.0, 3.0],
            provider="test_provider",
        )
        restored = ScreenPrototype.from_dict(proto.to_dict())
        assert restored.vector == [1.0, 2.0, 3.0]

        skel = WorkflowSkeleton(
            skeleton_id="sk1", name="Test", description="",
            learning_state="OBSERVATION", node_names=["A", "B"],
            edge_summaries=[], entry_nodes=["A"], end_nodes=["B"],
        )
        assert WorkflowSkeleton.from_dict(skel.to_dict()).name == "Test"

        err = ErrorPattern(pattern_id="e1", error_text="Timeout")
        assert ErrorPattern.from_dict(err.to_dict()).error_text == "Timeout"


# ============================================================================
# Tests: merge
# ============================================================================

class TestMerge:
    """Check merging several Learning Packs."""

    def test_merge_deux_packs(self):
        """Merging 2 packs must produce a combined pack."""
        wf1 = _make_workflow("wf_1", "Workflow A")
        wf2 = _make_workflow("wf_2", "Workflow B")

        exporter = LearningPackExporter()
        pack_a = exporter.export([wf1], client_id="Client-A")
        pack_b = exporter.export([wf2], client_id="Client-B")

        merger = LearningPackMerger()
        merged = merger.merge([pack_a, pack_b])

        assert merged.stats["workflows_count"] == 2
        assert merged.stats["source_packs_count"] == 2
        assert merged.pack_id.startswith("lp_merged_")

    def test_merge_deduplication_prototypes_identiques(self):
        """Two prototypes with the same vector must be merged."""
        # Create one fixed vector shared by both packs
        fixed_vec = np.random.randn(512).tolist()

        wf1 = _make_workflow("wf_same_1")
        wf1.nodes[0].metadata["_prototype_vector"] = fixed_vec
        wf2 = _make_workflow("wf_same_2")
        wf2.nodes[0].metadata["_prototype_vector"] = fixed_vec

        exporter = LearningPackExporter()
        pack_a = exporter.export([wf1], client_id="A")
        pack_b = exporter.export([wf2], client_id="B")

        # Before the merge: 2 prototypes with the same vector for node_a
        total_before = len(pack_a.screen_prototypes) + len(pack_b.screen_prototypes)
        assert total_before == 4  # 2 nodes × 2 packs

        merger = LearningPackMerger()
        merged = merger.merge([pack_a, pack_b])

        # After the merge, identical prototypes (node_a) must be deduplicated.
        # node_b has different (random) vectors, so no dedup there;
        # node_a is identical in both packs → merged into one.
        # Expected result: between 2 and 3 prototypes (1 deduplicated + 2 distinct)
        assert len(merged.screen_prototypes) < total_before

    def test_merge_prototypes_differents_conserves(self):
        """Two very different prototypes must NOT be merged."""
        # Create two orthogonal vectors
        vec_a = np.zeros(512, dtype=np.float32)
        vec_a[0] = 1.0
        vec_b = np.zeros(512, dtype=np.float32)
        vec_b[1] = 1.0

        wf1 = _make_workflow("wf_diff_1")
        wf1.nodes[0].metadata["_prototype_vector"] = vec_a.tolist()
        # Drop node_b to keep the setup simple
        wf1.nodes = [wf1.nodes[0]]
        wf1.edges = []

        wf2 = _make_workflow("wf_diff_2")
        wf2.nodes[0].metadata["_prototype_vector"] = vec_b.tolist()
        wf2.nodes = [wf2.nodes[0]]
        wf2.edges = []

        exporter = LearningPackExporter()
        pack_a = exporter.export([wf1], client_id="A")
        pack_b = exporter.export([wf2], client_id="B")

        merger = LearningPackMerger()
        merged = merger.merge([pack_a, pack_b])

        # The two prototypes are very different → no dedup
        assert len(merged.screen_prototypes) == 2

    def test_merge_error_patterns_cross_clients(self):
        """Error patterns seen by several clients have a cross_client_count > 1."""
        # Same error in both packs
        wf1 = _make_workflow("wf_err_1")
        wf2 = _make_workflow("wf_err_2")

        exporter = LearningPackExporter()
        pack_a = exporter.export([wf1], client_id="Hôpital-A")
        pack_b = exporter.export([wf2], client_id="Hôpital-B")

        merger = LearningPackMerger()
        merged = merger.merge([pack_a, pack_b])

        # "Accès refusé" and "Fichier introuvable" appear in both packs
        for ep in merged.error_patterns:
            if ep.error_text == "Accès refusé":
                assert ep.cross_client_count == 2
                assert ep.observation_count == 2  # 1 per pack
                break
        else:
            pytest.fail("Pattern 'Accès refusé' not found in the merged pack")

    def test_merge_app_signatures_union(self):
        """Application signatures must be the union across packs."""
        wf1 = _make_workflow("wf_app_1")
        wf2 = _make_workflow("wf_app_2")
        # Change the app of the second workflow
        wf2.nodes[0].template.window.process_name = "Chrome.exe"

        exporter = LearningPackExporter()
        pack_a = exporter.export([wf1], client_id="A")
        pack_b = exporter.export([wf2], client_id="B")

        merger = LearningPackMerger()
        merged = merger.merge([pack_a, pack_b])

        app_names = {sig.app_name for sig in merged.app_signatures}
        assert "Notepad.exe" in app_names
        assert "Chrome.exe" in app_names

    def test_merge_liste_vide(self):
        """Merging an empty list returns an empty pack."""
        merger = LearningPackMerger()
        merged = merger.merge([])
        assert merged.pack_id.startswith("lp_merged_")
        assert len(merged.screen_prototypes) == 0

    def test_merge_un_seul_pack(self):
        """Merging a single pack returns it with a new pack_id."""
        wf = _make_workflow()
        exporter = LearningPackExporter()
        pack = exporter.export([wf], client_id="solo")

        merger = LearningPackMerger()
        merged = merger.merge([pack])

        assert merged.pack_id != pack.pack_id
        assert merged.pack_id.startswith("lp_merged_")
        assert len(merged.screen_prototypes) == len(pack.screen_prototypes)

    def test_merge_edge_statistics_moyennes(self):
        """Edge statistics must be combined as a weighted average."""
        wf1 = _make_workflow("wf_stat_1")
        wf2 = _make_workflow("wf_stat_2")

        exporter = LearningPackExporter()
        pack_a = exporter.export([wf1], client_id="A")
        pack_b = exporter.export([wf2], client_id="B")

        merger = LearningPackMerger()
        merged = merger.merge([pack_a, pack_b])

        # The edges share the same node names → they get merged
        for stat in merged.edge_statistics:
            if stat.from_node_name == "Écran principal":
                # 10 executions per pack → 20 in total
                assert stat.execution_count == 20
                # success_rate = 0.9 for both → average = 0.9
                assert abs(stat.success_rate - 0.9) < 0.01
                break


# ============================================================================
# Tests: global FAISS index
# ============================================================================

class TestGlobalFAISSIndex:
    """Tests for the global FAISS index (requires faiss-cpu)."""

    @pytest.fixture
    def sample_packs(self):
        """Create two test packs carrying prototype vectors."""
        wf1 = _make_workflow("wf_faiss_1", "Workflow FAISS A")
        wf2 = _make_workflow("wf_faiss_2", "Workflow FAISS B")

        exporter = LearningPackExporter()
        pack_a = exporter.export([wf1], client_id="Client-FAISS-A")
        pack_b = exporter.export([wf2], client_id="Client-FAISS-B")
        return [pack_a, pack_b]

    def test_build_from_packs(self, sample_packs):
        """Build the index from the packs."""
        try:
            from core.federation.faiss_global import GlobalFAISSIndex
        except ImportError:
            pytest.skip("FAISS not installed")

        index = GlobalFAISSIndex(dimensions=512)
        count = index.build_from_packs(sample_packs)

        assert count > 0
        assert index.total_vectors == count

    def test_search(self, sample_packs):
        """Search the global index."""
        try:
            from core.federation.faiss_global import GlobalFAISSIndex
        except ImportError:
            pytest.skip("FAISS not installed")

        index = GlobalFAISSIndex(dimensions=512)
        index.build_from_packs(sample_packs)

        # Search with a random vector
        query = np.random.randn(512).astype(np.float32)
        results = index.search(query, k=3)

        assert len(results) > 0
        assert len(results) <= 3
        for r in results:
            assert r.prototype_id
            assert r.pack_source_hash
            assert -1.0 <= r.similarity <= 1.0

    def test_search_index_vide(self):
        """Searching an empty index returns an empty list."""
        try:
            from core.federation.faiss_global import GlobalFAISSIndex
        except ImportError:
            pytest.skip("FAISS not installed")

        index = GlobalFAISSIndex(dimensions=512)
        results = index.search(np.random.randn(512).astype(np.float32))
        assert results == []

    def test_add_pack_incremental(self, sample_packs):
        """Incrementally add packs to the index."""
        try:
            from core.federation.faiss_global import GlobalFAISSIndex
        except ImportError:
            pytest.skip("FAISS not installed")

        index = GlobalFAISSIndex(dimensions=512)
        count1 = index.add_pack(sample_packs[0])
        count2 = index.add_pack(sample_packs[1])

        assert count1 > 0
        assert count2 > 0
        assert index.total_vectors == count1 + count2

    def test_save_load(self, sample_packs, tmp_path):
        """Save and reload the index."""
        try:
            from core.federation.faiss_global import GlobalFAISSIndex
        except ImportError:
            pytest.skip("FAISS not installed")

        index = GlobalFAISSIndex(dimensions=512)
        index.build_from_packs(sample_packs)

        base_path = tmp_path / "global_index"
        index.save(base_path)

        loaded = GlobalFAISSIndex.load(base_path)
        assert loaded.total_vectors == index.total_vectors
        assert loaded.dimensions == index.dimensions

        # Check that search works on the reloaded index
        query = np.random.randn(512).astype(np.float32)
        results = loaded.search(query, k=2)
        assert len(results) > 0

    def test_get_stats(self, sample_packs):
        """Stats of the global index."""
        try:
            from core.federation.faiss_global import GlobalFAISSIndex
        except ImportError:
            pytest.skip("FAISS not installed")

        index = GlobalFAISSIndex(dimensions=512)
        index.build_from_packs(sample_packs)

        stats = index.get_stats()
        assert stats["dimensions"] == 512
        assert stats["total_vectors"] > 0
        assert stats["unique_sources"] >= 1
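GlobalFAISSIndex's internals are not part of this diff; the tests above only constrain its surface (512 dimensions, similarities within [-1, 1], persistence). A minimal sketch of the underlying idea, cosine similarity via an inner-product index over L2-normalized vectors (the real class may organize this differently):

import faiss
import numpy as np

prototypes = np.random.randn(10, 512).astype(np.float32)
faiss.normalize_L2(prototypes)           # after normalization, inner product == cosine
index = faiss.IndexFlatIP(512)
index.add(prototypes)

query = np.random.randn(1, 512).astype(np.float32)
faiss.normalize_L2(query)
similarities, ids = index.search(query, 3)       # both arrays have shape (1, 3)
faiss.write_index(index, "global_index.faiss")   # persistence for the save/load tests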