feat: replay visuel VLM-first, worker séparé, package Léa, AZERTY, sécurité HTTPS
Pipeline replay visuel : - VLM-first : l'agent appelle Ollama directement pour trouver les éléments - Template matching en fallback (seuil strict 0.90) - Stop immédiat si élément non trouvé (pas de clic blind) - Replay depuis session brute (/replay-session) sans attendre le VLM - Vérification post-action (screenshot hash avant/après) - Gestion des popups (Enter/Escape/Tab+Enter) Worker VLM séparé : - run_worker.py : process distinct du serveur HTTP - Communication par fichiers (_worker_queue.txt + _replay_active.lock) - Le serveur HTTP ne fait plus jamais de VLM → toujours réactif - Service systemd rpa-worker.service Capture clavier : - raw_keys (vk + press/release) pour replay exact indépendant du layout - Fix AZERTY : ToUnicodeEx + AltGr detection - Enter capturé comme \n, Tab comme \t - Filtrage modificateurs seuls (Ctrl/Alt/Shift parasites) - Fusion text_input consécutifs, dédup key_combo Sécurité & Internet : - HTTPS Let's Encrypt (lea.labs + vwb.labs.laurinebazin.design) - Token API fixe dans .env.local - HTTP Basic Auth sur VWB - Security headers (HSTS, CSP, nosniff) - CORS domaines publics, plus de wildcard Infrastructure : - DPI awareness (SetProcessDpiAwareness) Python + Rust - Métadonnées système (dpi_scale, window_bounds, monitors, os_theme) - Template matching multi-scale [0.5, 2.0] - Résolution dynamique (plus de hardcode 1920x1080) - VLM prefill fix (47x speedup, 3.5s au lieu de 180s) Modules : - core/auth/ : credential vault (Fernet AES), TOTP (RFC 6238), auth handler - core/federation/ : LearningPack export/import anonymisé, FAISS global - deploy/ : package Léa (config.txt, Lea.bat, install.bat, LISEZMOI.txt) UX : - Filtrage OS (VWB + Chat ne montrent que les workflows de l'OS courant) - Bibliothèque persistante (cache local + SQLite) - Clustering hybride (titre fenêtre + DBSCAN) - EdgeConstraints + PostConditions peuplés - GraphBuilder compound actions (toutes les frappes) Agent Rust : - Token Bearer auth (network.rs) - sysinfo.rs (DPI, résolution, 
window bounds via Win32 API) - config.txt lu automatiquement - Support Chrome/Brave/Firefox (pas que Edge) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
6
core/auth/__init__.py
Normal file
6
core/auth/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# core/auth — Module d'authentification automatique pour Léa
|
||||
#
|
||||
# Fournit :
|
||||
# - CredentialVault : coffre-fort chiffré pour les credentials
|
||||
# - TOTPGenerator : générateur TOTP RFC 6238 (sans dépendance externe)
|
||||
# - AuthHandler : détection d'écrans d'auth et injection automatique
|
||||
523
core/auth/auth_handler.py
Normal file
523
core/auth/auth_handler.py
Normal file
@@ -0,0 +1,523 @@
|
||||
"""
|
||||
Gestionnaire d'authentification automatique pendant le replay.
|
||||
|
||||
Détecte les écrans d'authentification et injecte les credentials appropriés.
|
||||
Fonctionne avec le ScreenState du core pipeline et le CredentialVault chiffré.
|
||||
|
||||
Stratégie de détection :
|
||||
1. Analyse OCR : cherche des patterns textuels indicatifs d'un écran d'auth
|
||||
("mot de passe", "identifiant", "code de vérification", etc.)
|
||||
2. Analyse UI : cherche des éléments sémantiques typiques (champ password,
|
||||
bouton "Se connecter", etc.)
|
||||
3. Identification de l'application : via window_title du ScreenState
|
||||
|
||||
La confiance est calculée selon le nombre de signaux détectés :
|
||||
- 1 signal = 0.3 (faible)
|
||||
- 2 signaux = 0.6 (moyen)
|
||||
- 3+ signaux = 0.85+ (élevé)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .credential_vault import CredentialVault
|
||||
from .totp_generator import TOTPGenerator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# =========================================================================
# Detection patterns for authentication screens
# =========================================================================

# OCR patterns (text visible on screen) — French + English for bilingual support.
_AUTH_TEXT_PATTERNS = [
    # French
    r"mot\s+de\s+passe",
    r"identifiant",
    r"nom\s+d'utilisateur",
    r"connexion",
    r"se\s+connecter",
    r"authentification",
    r"code\s+de\s+v[ée]rification",
    r"code\s+otp",
    r"double\s+authentification",
    r"v[ée]rification\s+en\s+deux\s+[ée]tapes",
    # English
    r"password",
    r"username",
    r"sign\s+in",
    r"log\s*in",
    r"verification\s+code",
    r"two.factor",
    r"2fa",
    r"one.time\s+password",
    r"enter\s+your\s+code",
]

# Patterns that specifically identify a TOTP / 2FA screen.
_TOTP_TEXT_PATTERNS = [
    r"code\s+de\s+v[ée]rification",
    r"code\s+otp",
    r"double\s+authentification",
    r"v[ée]rification\s+en\s+deux",
    r"two.factor",
    r"2fa",
    r"one.time\s+password",
    r"enter\s+your\s+code",
    r"code\s+[àa]\s+\d+\s+chiffres",
    r"authenticator",
]

# Labels of submit / confirm buttons.
_SUBMIT_BUTTON_PATTERNS = [
    r"se\s+connecter",
    r"connexion",
    r"valider",
    r"envoyer",
    r"confirmer",
    r"sign\s+in",
    r"log\s*in",
    r"submit",
    r"verify",
    r"ok",
]


def _compile_patterns(patterns):
    """Pre-compile the case-insensitive regexes once at import time."""
    return [re.compile(p, re.IGNORECASE) for p in patterns]


_AUTH_REGEXES = _compile_patterns(_AUTH_TEXT_PATTERNS)
_TOTP_REGEXES = _compile_patterns(_TOTP_TEXT_PATTERNS)
_SUBMIT_REGEXES = _compile_patterns(_SUBMIT_BUTTON_PATTERNS)
|
||||
|
||||
|
||||
@dataclass
class AuthRequest:
    """An authentication request detected on a screen.

    Attributes:
        auth_type: Kind of authentication detected ("login", "totp",
            "login_and_totp").
        app_name: Application identified from the window title.
        detected_fields: Fields found on the screen (positions, types).
        confidence: Detection confidence, from 0.0 to 1.0.
    """

    # One of "login", "totp" or "login_and_totp".
    auth_type: str
    # Application name, as extracted from window_title.
    app_name: str
    # Per-category field descriptors (username_field, password_field, ...).
    detected_fields: Dict[str, Any] = field(default_factory=dict)
    confidence: float = 0.0
|
||||
|
||||
|
||||
class AuthHandler:
    """Automatic authentication handler for replay.

    Analyses ScreenStates to spot authentication screens and produces the
    matching replay actions.

    Usage:
        handler = AuthHandler(vault)
        auth_req = handler.detect_auth_screen(screen_state)
        if auth_req:
            actions = handler.get_auth_actions(auth_req)
            # Inject the actions into the replay queue
    """

    def __init__(self, vault: CredentialVault):
        """Set up the handler.

        Args:
            vault: Credential vault instance used to resolve secrets.
        """
        # Only state held by the handler: the (already unlocked) vault.
        self._vault = vault
|
||||
|
||||
def detect_auth_screen(self, screen_state: Any) -> Optional[AuthRequest]:
    """Inspect a ScreenState and decide whether it shows an auth screen.

    Detection combines several independent signals:
      - OCR texts matching authentication wording
      - UI elements of password/text_input type
      - submit buttons ("Se connecter", "Valider", ...)

    Args:
        screen_state: ScreenState from the core pipeline (or compatible dict).

    Returns:
        An AuthRequest when at least one signal fires (confidence >= 0.3),
        None otherwise.
    """
    texts = self._extract_texts(screen_state)
    ui_elements = self._extract_ui_elements(screen_state)
    app_name = self._extract_app_name(screen_state)

    # Every detection signal is recorded under its own key.
    signals: Dict[str, Any] = {}

    # Signal 1: generic authentication wording in the visible text.
    auth_text_matches = [
        rx.pattern for txt in texts for rx in _AUTH_REGEXES if rx.search(txt)
    ]
    if auth_text_matches:
        signals["auth_text"] = auth_text_matches

    # Signal 2: TOTP/2FA-specific wording.
    totp_text_matches = [
        rx.pattern for txt in texts for rx in _TOTP_REGEXES if rx.search(txt)
    ]
    if totp_text_matches:
        signals["totp_text"] = totp_text_matches

    # Signal 3: semantic UI elements.
    password_fields: List[Any] = []
    username_fields: List[Any] = []
    submit_buttons: List[Any] = []
    otp_fields: List[Any] = []

    for elem in ui_elements:
        kind = self._get_elem_attr(elem, "type", "")
        role = self._get_elem_attr(elem, "role", "")
        label = self._get_elem_attr(elem, "label", "").lower()
        tags = self._get_elem_attr(elem, "tags", [])
        is_text_input = kind == "text_input"

        # Password field: explicit role/tag, or a suggestively labelled input.
        if role == "password" or "password" in tags:
            password_fields.append(elem)
        elif is_text_input and any(
            p in label for p in ("mot de passe", "password", "mdp")
        ):
            password_fields.append(elem)

        # Username / identifier field.
        if is_text_input and any(
            p in label
            for p in ("identifiant", "username", "utilisateur", "login", "email", "e-mail")
        ):
            username_fields.append(elem)

        # One-time-code (OTP) field.
        if is_text_input and any(
            p in label for p in ("code", "otp", "vérification", "verification")
        ):
            otp_fields.append(elem)

        # Submit / confirm button (counted once per element).
        if kind == "button" and any(rx.search(label) for rx in _SUBMIT_REGEXES):
            submit_buttons.append(elem)

    # Record the element-based signals, preserving the historical key order.
    for key, found in (
        ("password_field", password_fields),
        ("username_field", username_fields),
        ("submit_button", submit_buttons),
        ("otp_field", otp_fields),
    ):
        if found:
            signals[key] = len(found)

    # No signal at all -> not an authentication screen.
    if not signals:
        return None

    # Classify the auth type. "auth_text" may contain ambiguous patterns
    # (e.g. "2fa" appears in both lists), so only patterns that are NOT
    # also TOTP patterns count as login evidence.
    auth_only_text = set(signals.get("auth_text", [])) - set(signals.get("totp_text", []))
    has_login_signals = bool(password_fields or auth_only_text or username_fields)
    has_totp_signals = bool(otp_fields or "totp_text" in signals)

    if has_login_signals and has_totp_signals:
        auth_type = "login_and_totp"
    elif has_totp_signals:
        auth_type = "totp"
    else:
        auth_type = "login"

    # Confidence grows with the number of distinct signal categories.
    num_signals = len(signals)
    if num_signals >= 4:
        confidence = 0.95
    elif num_signals == 3:
        confidence = 0.85
    elif num_signals == 2:
        confidence = 0.6
    else:
        confidence = 0.3

    # Keep the first element of each category for the replay planner.
    detected_fields: Dict[str, Any] = {
        key: self._elem_to_dict(found[0])
        for key, found in (
            ("username_field", username_fields),
            ("password_field", password_fields),
            ("otp_field", otp_fields),
            ("submit_button", submit_buttons),
        )
        if found
    }

    auth_request = AuthRequest(
        auth_type=auth_type,
        app_name=app_name,
        detected_fields=detected_fields,
        confidence=confidence,
    )

    logger.info(
        "Écran d'authentification détecté : type=%s app=%s confiance=%.2f signaux=%s",
        auth_type,
        app_name,
        confidence,
        list(signals.keys()),
    )

    return auth_request
|
||||
|
||||
def get_auth_actions(self, auth_request: AuthRequest) -> List[Dict[str, Any]]:
    """Build the replay actions that perform the authentication.

    Produces a sequence the Agent V1 can execute: click the username field
    and type the login, click the password field and type the password,
    optionally type a freshly generated TOTP code, then click the submit
    button and wait for the application to load.

    Args:
        auth_request: Previously detected authentication request.

    Returns:
        List of replay actions (replay-queue-compatible dicts). Empty list
        when the required credentials are missing from the vault.
    """
    actions: List[Dict[str, Any]] = []
    app_name = auth_request.app_name
    fields = auth_request.detected_fields

    # Unique prefix so injected action_ids cannot collide across replays.
    prefix = f"auth_{uuid.uuid4().hex[:6]}"

    def _click(suffix: str, target_field: Dict[str, Any], description: str) -> Dict[str, Any]:
        # Click action aimed at the center of a detected field.
        return {
            "action_id": f"{prefix}_{suffix}",
            "type": "click",
            "target": target_field.get("center", [0, 0]),
            "description": description,
            "_auth_action": True,
        }

    def _type(suffix: str, text: str, description: str) -> Dict[str, Any]:
        # Keystroke action typing *text* into the currently focused field.
        return {
            "action_id": f"{prefix}_{suffix}",
            "type": "type_text",
            "text": text,
            "description": description,
            "_auth_action": True,
        }

    # ---- Login: username + password ----
    if auth_request.auth_type in ("login", "login_and_totp"):
        login_creds = self._vault.get_credential(app_name, "login")
        if not login_creds:
            logger.warning(
                "Pas de credential 'login' pour l'app '%s' dans le vault",
                app_name,
            )
            return []

        username_field = fields.get("username_field")
        if username_field:
            actions.append(
                _click("click_username", username_field, f"Clic champ identifiant ({app_name})")
            )
            actions.append(
                _type("type_username", login_creds.get("username", ""), f"Saisie identifiant ({app_name})")
            )

        password_field = fields.get("password_field")
        if password_field:
            actions.append(
                _click("click_password", password_field, f"Clic champ mot de passe ({app_name})")
            )
            actions.append(
                _type("type_password", login_creds.get("password", ""), f"Saisie mot de passe ({app_name})")
            )

    # ---- TOTP: generate and type the one-time code ----
    if auth_request.auth_type in ("totp", "login_and_totp"):
        totp_creds = self._vault.get_credential(app_name, "totp_seed")
        if not totp_creds:
            logger.warning(
                "Pas de credential 'totp_seed' pour l'app '%s' dans le vault",
                app_name,
            )
            # Still worth continuing if the login part produced actions.
            if not actions:
                return []
        else:
            totp = TOTPGenerator(
                secret=totp_creds["secret"],
                digits=totp_creds.get("digits", 6),
                interval=totp_creds.get("interval", 30),
                algorithm=totp_creds.get("algorithm", "SHA1"),
            )

            # Queue a wait when the current code expires within 5 seconds.
            remaining = totp.time_remaining()
            if remaining < 5:
                actions.append({
                    "action_id": f"{prefix}_wait_totp",
                    "type": "wait",
                    "duration_ms": (remaining + 1) * 1000,
                    "reason": "attente_nouveau_code_totp",
                    "description": f"Attente nouveau code TOTP ({remaining}s restantes)",
                    "_auth_action": True,
                })

            # NOTE(review): the code is generated now, at planning time, before
            # the wait above is actually executed — if replay is delayed the
            # typed code may already have expired. Confirm the replay engine
            # regenerates or tolerates this.
            code = totp.generate()

            otp_field = fields.get("otp_field")
            if otp_field:
                actions.append(
                    _click("click_otp", otp_field, f"Clic champ OTP ({app_name})")
                )

            actions.append(
                _type("type_totp", code, f"Saisie code TOTP ({app_name})")
            )

    # ---- Submit button ----
    submit_button = fields.get("submit_button")
    if submit_button and actions:
        actions.append(
            _click("click_submit", submit_button, f"Clic validation ({app_name})")
        )

    # Give the application time to load after submitting.
    if actions:
        actions.append({
            "action_id": f"{prefix}_wait_after_auth",
            "type": "wait",
            "duration_ms": 2000,
            "reason": "attente_chargement_post_auth",
            "description": f"Attente post-authentification ({app_name})",
            "_auth_action": True,
        })

    logger.info(
        "Actions d'authentification générées : %d actions pour %s (type=%s)",
        len(actions),
        app_name,
        auth_request.auth_type,
    )

    return actions
|
||||
|
||||
# =========================================================================
|
||||
# Méthodes d'extraction internes
|
||||
# =========================================================================
|
||||
|
||||
def _extract_texts(self, screen_state: Any) -> List[str]:
    """Collect every piece of detected text from a ScreenState.

    Supports both core ScreenState objects and raw dicts (streaming
    sessions).

    Args:
        screen_state: ScreenState object or compatible dict.

    Returns:
        List of text snippets (OCR lines, UI labels); may be empty.
    """
    texts: List[str] = []

    # Core ScreenState (dataclass with a .perception attribute).
    if hasattr(screen_state, "perception") and hasattr(
        screen_state.perception, "detected_text"
    ):
        texts.extend(screen_state.perception.detected_text)

    # Raw dict (streaming sessions).
    elif isinstance(screen_state, dict):
        perception = screen_state.get("perception", {})
        if isinstance(perception, dict):
            # `or []` guards an explicit {"detected_text": None}, which the
            # .get() default does not cover and would crash extend().
            texts.extend(perception.get("detected_text") or [])
        # Raw OCR text — skip None/empty values so downstream regex
        # searches never receive a non-string or useless entry.
        if screen_state.get("ocr_text"):
            texts.append(screen_state["ocr_text"])
        # Labels of UI elements.
        for elem in screen_state.get("ui_elements", []):
            label = elem.get("label", "")
            if label:
                texts.append(label)

    # Labels of UI elements (object form).
    if hasattr(screen_state, "ui_elements"):
        for elem in screen_state.ui_elements:
            label = self._get_elem_attr(elem, "label", "")
            if label:
                texts.append(label)

    return texts
|
||||
|
||||
def _extract_ui_elements(self, screen_state: Any) -> List[Any]:
    """Return the UI elements of a ScreenState (object or dict), else []."""
    if hasattr(screen_state, "ui_elements"):
        return list(screen_state.ui_elements)
    # Dict form falls back to the "ui_elements" key; anything else is empty.
    return screen_state.get("ui_elements", []) if isinstance(screen_state, dict) else []
|
||||
|
||||
def _extract_app_name(self, screen_state: Any) -> str:
    """Return the application name of a ScreenState, or "unknown"."""
    # Core ScreenState: nested window object carrying app_name.
    if hasattr(screen_state, "window") and hasattr(screen_state.window, "app_name"):
        return screen_state.window.app_name

    # Raw dict form: window is a sub-dict.
    if isinstance(screen_state, dict):
        window = screen_state.get("window", {})
        if isinstance(window, dict):
            return window.get("app_name", "unknown")

    return "unknown"
|
||||
|
||||
@staticmethod
def _get_elem_attr(elem: Any, attr: str, default: Any = None) -> Any:
    """Read *attr* from a UI element, whether it is a dict or an object."""
    if not isinstance(elem, dict):
        return getattr(elem, attr, default)
    return elem.get(attr, default)
|
||||
|
||||
@staticmethod
def _elem_to_dict(elem: Any) -> Dict[str, Any]:
    """Reduce a UI element to the minimal dict stored in detected_fields."""
    if not isinstance(elem, dict):
        # Object form: coerce center to a list so the result stays
        # JSON-serializable.
        return {
            "type": getattr(elem, "type", ""),
            "label": getattr(elem, "label", ""),
            "center": list(getattr(elem, "center", (0, 0))),
            "element_id": getattr(elem, "element_id", ""),
        }
    return {
        "type": elem.get("type", ""),
        "label": elem.get("label", ""),
        "center": elem.get("center", [0, 0]),
        "element_id": elem.get("element_id", ""),
    }
|
||||
298
core/auth/credential_vault.py
Normal file
298
core/auth/credential_vault.py
Normal file
@@ -0,0 +1,298 @@
|
||||
"""
|
||||
Coffre-fort chiffré pour les credentials d'authentification.
|
||||
|
||||
Stocke de façon sécurisée :
|
||||
- Comptes de service (login/password)
|
||||
- Seeds TOTP pour la 2FA
|
||||
- Tokens de session
|
||||
- Certificats client
|
||||
|
||||
Le fichier vault est chiffré avec Fernet (AES-128-CBC + HMAC-SHA256).
|
||||
La clé est dérivée d'un mot de passe maître via PBKDF2 (600000 itérations).
|
||||
|
||||
Choix de sécurité :
|
||||
- PBKDF2 avec 600 000 itérations : recommandation OWASP 2023 pour SHA-256.
|
||||
Compromis acceptable entre temps de dérivation (~0.5s) et résistance au brute-force.
|
||||
- Fernet (AES-128-CBC + HMAC-SHA256) : chiffrement authentifié, empêche les
|
||||
modifications silencieuses du fichier vault. Bibliothèque maintenue et auditée.
|
||||
- Salt aléatoire de 16 bytes : empêche les attaques par rainbow table.
|
||||
Stocké en clair en préfixe du fichier (le salt n'est pas un secret).
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Supported credential categories.
CREDENTIAL_TYPES = {"login", "totp_seed", "session_token", "certificate"}

# Salt length, in bytes.
SALT_SIZE = 16

# PBKDF2 iteration count — OWASP 2023 guidance for SHA-256.
PBKDF2_ITERATIONS = 600_000

# Optional dependency: 'cryptography' provides the Fernet authenticated cipher.
try:
    from cryptography.fernet import Fernet, InvalidToken
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
except ImportError:
    _HAS_FERNET = False
    warnings.warn(
        "Module 'cryptography' non disponible. Le vault utilisera un encodage "
        "base64 NON SÉCURISÉ. NE PAS utiliser en production.",
        stacklevel=2,
    )
else:
    _HAS_FERNET = True
|
||||
|
||||
|
||||
class CredentialVault:
    """Encrypted store for application credentials.

    Usage:
        vault = CredentialVault("/chemin/vault.enc", "mot_de_passe_maitre")
        vault.add_credential("DPI_Crossway", "login", {
            "username": "robot_lea", "password": "xxx", "domain": "HOPITAL"
        })
        vault.save()

        creds = vault.get_credential("DPI_Crossway", "login")
    """

    def __init__(self, vault_path: str, master_password: str):
        """Open an existing encrypted vault, or start an empty one.

        Args:
            vault_path: Location of the encrypted vault file on disk.
            master_password: Master password the cipher key is derived from.
        """
        self._vault_path = Path(vault_path)
        self._master_password = master_password
        # In-memory, decrypted representation of the vault contents.
        self._data: Dict[str, Any] = {
            "version": "1.0",
            "created_at": datetime.now(timezone.utc).isoformat(),
            "credentials": {},
        }

        if not self._vault_path.exists():
            logger.info("Vault inexistant, création d'un nouveau vault : %s", vault_path)
        else:
            self._load()
|
||||
|
||||
# =========================================================================
|
||||
# API publique
|
||||
# =========================================================================
|
||||
|
||||
def add_credential(
    self, app_name: str, credential_type: str, data: Dict[str, Any]
) -> None:
    """Add or replace a credential for an application.

    Args:
        app_name: Application name (e.g. "DPI_Crossway").
        credential_type: One of "login", "totp_seed", "session_token",
            "certificate".
        data: Dict carrying the type-specific fields.

    Raises:
        ValueError: If *credential_type* is not a supported type.
    """
    if credential_type not in CREDENTIAL_TYPES:
        raise ValueError(
            f"Type de credential invalide : {credential_type!r}. "
            f"Types supportés : {CREDENTIAL_TYPES}"
        )

    # Create the per-app bucket on first use, then store/overwrite the entry.
    self._data["credentials"].setdefault(app_name, {})[credential_type] = data
    logger.info(
        "Credential ajouté : app=%s type=%s", app_name, credential_type
    )
|
||||
|
||||
def get_credential(
    self, app_name: str, credential_type: str
) -> Optional[Dict[str, Any]]:
    """Look up a credential for an application.

    Args:
        app_name: Application name.
        credential_type: Credential type to fetch.

    Returns:
        The credential dict, or None when it does not exist.
    """
    # Two chained .get() calls: missing app and missing type both yield None.
    return self._data["credentials"].get(app_name, {}).get(credential_type)
|
||||
|
||||
def remove_credential(self, app_name: str, credential_type: str) -> bool:
    """Delete a credential.

    Args:
        app_name: Application name.
        credential_type: Credential type to delete.

    Returns:
        True when the credential existed and was removed, False otherwise.
    """
    app_creds = self._data["credentials"].get(app_name, {})
    if credential_type not in app_creds:
        return False

    del app_creds[credential_type]
    # Drop the application entry entirely once its last credential is gone.
    if not app_creds:
        del self._data["credentials"][app_name]
    logger.info(
        "Credential supprimé : app=%s type=%s", app_name, credential_type
    )
    return True
|
||||
|
||||
def list_apps(self) -> List[str]:
    """Return the sorted names of all applications present in the vault."""
    # Iterating the dict yields its keys; sorted() materializes the list.
    return sorted(self._data["credentials"])
|
||||
|
||||
def list_credential_types(self, app_name: str) -> List[str]:
    """Return the credential types configured for *app_name*.

    Args:
        app_name: Application name.

    Returns:
        List of configured credential types; empty for unknown apps.
    """
    # list() over the sub-dict yields its keys in insertion order.
    return list(self._data["credentials"].get(app_name, {}))
|
||||
|
||||
def save(self) -> None:
    """Encrypt the in-memory data and write it to disk atomically.

    The payload is written to a temporary sibling file first, then moved
    over the real vault file, so a crash can never leave a half-written
    vault behind.
    """
    plaintext = json.dumps(self._data, ensure_ascii=False, indent=2).encode("utf-8")
    encrypted = self._encrypt(plaintext)

    # Atomic write via a temporary file in the same directory.
    tmp_path = self._vault_path.with_suffix(".tmp")
    self._vault_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path.write_bytes(encrypted)
    # Path.replace (not .rename): on Windows, rename() raises FileExistsError
    # when the destination already exists, which would make every save after
    # the first one fail. replace() overwrites atomically on all platforms.
    tmp_path.replace(self._vault_path)

    logger.info("Vault sauvegardé : %s (%d bytes)", self._vault_path, len(encrypted))
|
||||
|
||||
# =========================================================================
|
||||
# Chiffrement / Déchiffrement
|
||||
# =========================================================================
|
||||
|
||||
def _derive_key(self, password: str, salt: bytes) -> bytes:
    """Derive the Fernet key from the master password.

    Uses PBKDF2-HMAC-SHA256 with 600,000 iterations (OWASP 2023 guidance);
    the 32-byte output is URL-safe base64-encoded as Fernet requires.

    Args:
        password: Master password.
        salt: Random salt (16 bytes minimum).

    Returns:
        URL-safe base64-encoded Fernet key (44 bytes).
    """
    if not _HAS_FERNET:
        # Same KDF via hashlib when 'cryptography' is unavailable — only the
        # encryption step differs in fallback mode.
        import hashlib

        derived = hashlib.pbkdf2_hmac(
            "sha256", password.encode("utf-8"), salt, PBKDF2_ITERATIONS
        )
        return base64.urlsafe_b64encode(derived)

    kdf = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt,
        iterations=PBKDF2_ITERATIONS,
    )
    return base64.urlsafe_b64encode(kdf.derive(password.encode("utf-8")))
|
||||
|
||||
def _encrypt(self, plaintext: bytes) -> bytes:
    """Encrypt data; the returned bytes are `salt || ciphertext`.

    Vault file layout:
        [16-byte salt][Fernet-encrypted payload]

    Without the 'cryptography' package the payload is merely base64-encoded
    (NOT secure — development fallback only).

    Args:
        plaintext: Clear data to encrypt.

    Returns:
        Encrypted bytes with the salt as prefix.
    """
    salt = os.urandom(SALT_SIZE)
    key = self._derive_key(self._master_password, salt)

    if _HAS_FERNET:
        body = Fernet(key).encrypt(plaintext)
    else:
        # Insecure development fallback.
        body = base64.urlsafe_b64encode(plaintext)
    return salt + body
|
||||
|
||||
def _decrypt(self, encrypted_data: bytes) -> bytes:
    """Decrypt a vault payload produced by _encrypt.

    Args:
        encrypted_data: Raw file bytes (salt prefix + Fernet payload).

    Returns:
        Decrypted plaintext bytes.

    Raises:
        ValueError: Wrong master password or corrupted vault data.
    """
    if len(encrypted_data) < SALT_SIZE:
        raise ValueError("Fichier vault corrompu (trop court)")

    salt = encrypted_data[:SALT_SIZE]
    ciphertext = encrypted_data[SALT_SIZE:]
    key = self._derive_key(self._master_password, salt)

    if not _HAS_FERNET:
        # Insecure development fallback. Note binascii.Error subclasses
        # ValueError, so corrupted data still surfaces as ValueError here.
        return base64.urlsafe_b64decode(ciphertext)

    try:
        return Fernet(key).decrypt(ciphertext)
    except InvalidToken as exc:
        # Chain the cause so debugging keeps the original traceback.
        raise ValueError(
            "Mot de passe maître incorrect ou fichier vault corrompu"
        ) from exc
|
||||
|
||||
# =========================================================================
|
||||
# Chargement
|
||||
# =========================================================================
|
||||
|
||||
def _load(self) -> None:
    """Read, decrypt and parse the vault file into self._data."""
    try:
        raw = self._vault_path.read_bytes()
        self._data = json.loads(self._decrypt(raw).decode("utf-8"))
        logger.info(
            "Vault chargé : %s (%d apps)",
            self._vault_path,
            len(self._data.get("credentials", {})),
        )
    except (ValueError, json.JSONDecodeError) as e:
        # Wrap both decryption and parsing failures in a single error type.
        raise ValueError(f"Impossible de charger le vault : {e}") from e
|
||||
213
core/auth/manage_vault.py
Normal file
213
core/auth/manage_vault.py
Normal file
@@ -0,0 +1,213 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
CLI de gestion du coffre-fort de credentials (vault).
|
||||
|
||||
Usage :
|
||||
# Ajouter un login
|
||||
python -m core.auth.manage_vault --vault /path/to/vault.enc --action add \
|
||||
--app "DPI_Crossway" --type login \
|
||||
--username "robot_lea" --password "xxx"
|
||||
|
||||
# Ajouter un seed TOTP
|
||||
python -m core.auth.manage_vault --vault /path/to/vault.enc --action add \
|
||||
--app "DPI_Crossway" --type totp_seed \
|
||||
--secret "JBSWY3DPEHPK3PXP"
|
||||
|
||||
# Lister les applications configurées
|
||||
python -m core.auth.manage_vault --vault /path/to/vault.enc --action list
|
||||
|
||||
# Générer un code TOTP
|
||||
python -m core.auth.manage_vault --vault /path/to/vault.enc --action generate-totp \
|
||||
--app "DPI_Crossway"
|
||||
|
||||
# Supprimer un credential
|
||||
python -m core.auth.manage_vault --vault /path/to/vault.enc --action remove \
|
||||
--app "DPI_Crossway" --type login
|
||||
|
||||
Le mot de passe maître est demandé interactivement via getpass.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import getpass
|
||||
import sys
|
||||
|
||||
from .credential_vault import CredentialVault
|
||||
from .totp_generator import TOTPGenerator
|
||||
|
||||
|
||||
def _fail(message: str) -> None:
    """Print an error message to stderr and exit with status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the vault CLI."""
    parser = argparse.ArgumentParser(
        description="Gestionnaire de coffre-fort de credentials pour Léa.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--vault",
        required=True,
        help="Chemin du fichier vault chiffré",
    )
    parser.add_argument(
        "--action",
        required=True,
        choices=["add", "list", "remove", "generate-totp", "show"],
        help="Action à effectuer",
    )
    parser.add_argument("--app", help="Nom de l'application")
    parser.add_argument(
        "--type",
        dest="cred_type",
        choices=["login", "totp_seed", "session_token", "certificate"],
        help="Type de credential",
    )
    # Fields for the "login" type
    parser.add_argument("--username", help="Nom d'utilisateur (type login)")
    parser.add_argument("--password", help="Mot de passe (type login)")
    parser.add_argument("--domain", help="Domaine Windows (type login, optionnel)")
    # Fields for the "totp_seed" type
    parser.add_argument("--secret", help="Secret base32 (type totp_seed)")
    parser.add_argument(
        "--digits", type=int, default=6, help="Nombre de chiffres TOTP (défaut: 6)"
    )
    parser.add_argument(
        "--interval", type=int, default=30, help="Intervalle TOTP en secondes (défaut: 30)"
    )
    parser.add_argument(
        "--algorithm", default="SHA1", help="Algorithme HMAC (défaut: SHA1)"
    )
    return parser


def _action_list(vault: CredentialVault, args: argparse.Namespace) -> None:
    """List configured applications and their credential types."""
    apps = vault.list_apps()
    if not apps:
        print("Vault vide — aucune application configurée.")
        return
    print(f"Applications configurées ({len(apps)}) :")
    for app in apps:
        types = vault.list_credential_types(app)
        print(f" {app} : {', '.join(types)}")


def _collect_add_data(args: argparse.Namespace) -> dict:
    """Build the credential payload for 'add', prompting for missing fields."""
    if args.cred_type == "login":
        if not args.username:
            args.username = input("Username : ")
        if not args.password:
            # Never echo the password on the terminal
            args.password = getpass.getpass("Password : ")
        data = {"username": args.username, "password": args.password}
        if args.domain:
            data["domain"] = args.domain
        return data

    if args.cred_type == "totp_seed":
        if not args.secret:
            args.secret = input("Secret base32 : ")
        return {
            "secret": args.secret,
            "digits": args.digits,
            "interval": args.interval,
            "algorithm": args.algorithm,
        }

    if args.cred_type == "session_token":
        return {"token": input("Token de session : ")}

    if args.cred_type == "certificate":
        cert_path = input("Chemin du certificat : ")
        key_path = input("Chemin de la clé privée : ")
        return {"cert_path": cert_path, "key_path": key_path}

    # Unreachable with the current argparse choices; kept defensively.
    _fail(f"Type non géré : {args.cred_type}")


def _action_add(vault: CredentialVault, args: argparse.Namespace) -> None:
    """Add a credential to the vault and persist it."""
    if not args.app:
        _fail("Erreur : --app requis pour l'action 'add'.")
    if not args.cred_type:
        _fail("Erreur : --type requis pour l'action 'add'.")
    data = _collect_add_data(args)
    vault.add_credential(args.app, args.cred_type, data)
    vault.save()
    print(f"Credential ajouté : {args.app} / {args.cred_type}")


def _action_remove(vault: CredentialVault, args: argparse.Namespace) -> None:
    """Remove one credential type of one application."""
    if not args.app or not args.cred_type:
        _fail("Erreur : --app et --type requis pour l'action 'remove'.")
    removed = vault.remove_credential(args.app, args.cred_type)
    if removed:
        vault.save()
        print(f"Credential supprimé : {args.app} / {args.cred_type}")
    else:
        print(f"Credential non trouvé : {args.app} / {args.cred_type}")


def _action_generate_totp(vault: CredentialVault, args: argparse.Namespace) -> None:
    """Generate and print the current TOTP code for an application."""
    if not args.app:
        _fail("Erreur : --app requis pour l'action 'generate-totp'.")
    totp_creds = vault.get_credential(args.app, "totp_seed")
    if not totp_creds:
        _fail(f"Pas de seed TOTP configuré pour '{args.app}'.")

    totp = TOTPGenerator(
        secret=totp_creds["secret"],
        digits=totp_creds.get("digits", 6),
        interval=totp_creds.get("interval", 30),
        algorithm=totp_creds.get("algorithm", "SHA1"),
    )
    print(f"Code TOTP : {totp.generate()}")
    print(f"Expire dans : {totp.time_remaining()}s")


def _mask_secret_value(value) -> str:
    """Masked display form: first 3 chars + '***' (just '***' when short).

    BUGFIX: the original sliced the raw value (``v[:3]``) while length-checking
    ``str(v)``; a non-string value longer than 3 chars would crash. Stringify
    once and use it consistently.
    """
    text = str(value)
    return text[:3] + "***" if len(text) > 3 else "***"


def _action_show(vault: CredentialVault, args: argparse.Namespace) -> None:
    """Show an application's credentials with secrets masked."""
    if not args.app:
        _fail("Erreur : --app requis pour l'action 'show'.")
    types = vault.list_credential_types(args.app)
    if not types:
        print(f"Aucun credential pour '{args.app}'.")
        return
    print(f"Credentials pour '{args.app}' :")
    for cred_type in types:
        cred = vault.get_credential(args.app, cred_type)
        # Mask passwords and secrets before displaying
        display = {
            k: _mask_secret_value(v) if k in ("password", "secret", "token") else v
            for k, v in (cred or {}).items()
        }
        print(f" {cred_type} : {display}")


# Dispatch table: action name -> handler(vault, args)
_ACTIONS = {
    "list": _action_list,
    "add": _action_add,
    "remove": _action_remove,
    "generate-totp": _action_generate_totp,
    "show": _action_show,
}


def main():
    """CLI entry point: parse arguments, unlock the vault, dispatch the action."""
    args = _build_parser().parse_args()

    # The master password is always asked interactively (never on the
    # command line, where it would leak into shell history / process lists).
    master_password = getpass.getpass("Mot de passe maître : ")
    if not master_password:
        _fail("Erreur : mot de passe maître requis.")

    try:
        vault = CredentialVault(args.vault, master_password)
    except ValueError as e:
        _fail(f"Erreur d'ouverture du vault : {e}")

    _ACTIONS[args.action](vault, args)


if __name__ == "__main__":
    main()
|
||||
183
core/auth/totp_generator.py
Normal file
183
core/auth/totp_generator.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""
|
||||
Générateur TOTP (Time-based One-Time Password) pour l'authentification 2FA.
|
||||
|
||||
Implémente RFC 6238 directement, sans dépendance externe.
|
||||
Compatible avec FreeOTP, Google Authenticator, Microsoft Authenticator.
|
||||
|
||||
Algorithme (RFC 6238 / RFC 4226) :
|
||||
1. Décoder le secret partagé depuis base32
|
||||
2. Calculer le compteur temporel T = floor(unix_time / interval)
|
||||
3. Encoder T en big-endian 8 bytes
|
||||
4. Calculer HMAC-SHA1(secret, T) (ou SHA-256/SHA-512 selon config)
|
||||
5. Extraction dynamique (dynamic truncation) :
|
||||
- offset = dernier octet du HMAC & 0x0F
|
||||
- extraire 4 bytes à partir de offset
|
||||
- masquer le bit de signe (& 0x7FFFFFFF)
|
||||
- modulo 10^digits pour obtenir le code
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import struct
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Mapping des algorithmes supportés
|
||||
_HASH_ALGORITHMS = {
|
||||
"SHA1": hashlib.sha1,
|
||||
"SHA256": hashlib.sha256,
|
||||
"SHA512": hashlib.sha512,
|
||||
}
|
||||
|
||||
|
||||
class TOTPGenerator:
|
||||
"""Générateur de codes TOTP conformes à la RFC 6238.
|
||||
|
||||
Usage :
|
||||
totp = TOTPGenerator("JBSWY3DPEHPK3PXP")
|
||||
code = totp.generate() # "492039"
|
||||
remaining = totp.time_remaining() # 17 (secondes)
|
||||
valid = totp.verify("492039") # True
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
secret: str,
|
||||
digits: int = 6,
|
||||
interval: int = 30,
|
||||
algorithm: str = "SHA1",
|
||||
):
|
||||
"""Initialise le générateur TOTP.
|
||||
|
||||
Args:
|
||||
secret: Clé secrète encodée en base32 (standard TOTP).
|
||||
digits: Nombre de chiffres du code (6 ou 8, défaut 6).
|
||||
interval: Intervalle en secondes entre deux codes (défaut 30).
|
||||
algorithm: Algorithme HMAC ("SHA1", "SHA256", "SHA512").
|
||||
|
||||
Raises:
|
||||
ValueError: Si le secret n'est pas du base32 valide ou l'algorithme inconnu.
|
||||
"""
|
||||
# Normaliser et décoder le secret base32
|
||||
# Les secrets TOTP peuvent contenir des espaces pour la lisibilité
|
||||
clean_secret = secret.upper().replace(" ", "")
|
||||
# Ajouter du padding base32 si nécessaire
|
||||
padding = (8 - len(clean_secret) % 8) % 8
|
||||
clean_secret += "=" * padding
|
||||
|
||||
try:
|
||||
self._secret_bytes = base64.b32decode(clean_secret)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Secret base32 invalide : {e}") from e
|
||||
|
||||
if algorithm.upper() not in _HASH_ALGORITHMS:
|
||||
raise ValueError(
|
||||
f"Algorithme non supporté : {algorithm!r}. "
|
||||
f"Valeurs acceptées : {list(_HASH_ALGORITHMS.keys())}"
|
||||
)
|
||||
|
||||
self._digits = digits
|
||||
self._interval = interval
|
||||
self._algorithm = algorithm.upper()
|
||||
|
||||
def generate(self, timestamp: float | None = None) -> str:
|
||||
"""Génère le code TOTP pour l'instant présent (ou un timestamp donné).
|
||||
|
||||
Args:
|
||||
timestamp: Timestamp Unix optionnel (pour les tests). Si None, utilise time.time().
|
||||
|
||||
Returns:
|
||||
Code TOTP sous forme de chaîne zero-padded (ex: "003271").
|
||||
"""
|
||||
if timestamp is None:
|
||||
timestamp = time.time()
|
||||
|
||||
counter = int(timestamp) // self._interval
|
||||
return self._generate_hotp(counter)
|
||||
|
||||
def time_remaining(self) -> int:
|
||||
"""Nombre de secondes avant expiration du code actuel.
|
||||
|
||||
Returns:
|
||||
Secondes restantes (entre 1 et interval).
|
||||
"""
|
||||
return self._interval - (int(time.time()) % self._interval)
|
||||
|
||||
def verify(self, code: str, timestamp: float | None = None, window: int = 1) -> bool:
|
||||
"""Vérifie un code TOTP avec une fenêtre de tolérance.
|
||||
|
||||
La fenêtre permet de compenser le décalage d'horloge entre client et serveur.
|
||||
Avec window=1, on vérifie le code actuel, le précédent et le suivant.
|
||||
|
||||
Args:
|
||||
code: Code TOTP à vérifier.
|
||||
timestamp: Timestamp Unix optionnel.
|
||||
window: Nombre d'intervalles de tolérance de chaque côté (défaut 1).
|
||||
|
||||
Returns:
|
||||
True si le code correspond à un intervalle dans la fenêtre.
|
||||
"""
|
||||
if timestamp is None:
|
||||
timestamp = time.time()
|
||||
|
||||
counter = int(timestamp) // self._interval
|
||||
|
||||
for offset in range(-window, window + 1):
|
||||
check_counter = counter + offset
|
||||
if check_counter < 0:
|
||||
continue # Compteur négatif impossible
|
||||
expected = self._generate_hotp(check_counter)
|
||||
# Comparaison en temps constant pour éviter les timing attacks
|
||||
if hmac.compare_digest(code, expected):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
# =========================================================================
|
||||
# Implémentation interne HOTP (RFC 4226)
|
||||
# =========================================================================
|
||||
|
||||
def _generate_hotp(self, counter: int) -> str:
|
||||
"""Génère un code HOTP pour un compteur donné.
|
||||
|
||||
Implémentation conforme à la RFC 4226 section 5.3 :
|
||||
1. Encoder le compteur en big-endian 8 bytes
|
||||
2. HMAC avec l'algorithme configuré
|
||||
3. Truncation dynamique
|
||||
4. Réduction modulo 10^digits
|
||||
|
||||
Args:
|
||||
counter: Valeur du compteur (entier 64 bits).
|
||||
|
||||
Returns:
|
||||
Code HOTP zero-padded.
|
||||
"""
|
||||
# Étape 1 : Compteur en big-endian 8 bytes
|
||||
counter_bytes = struct.pack(">Q", counter)
|
||||
|
||||
# Étape 2 : HMAC
|
||||
hash_func = _HASH_ALGORITHMS[self._algorithm]
|
||||
hmac_digest = hmac.new(
|
||||
self._secret_bytes, counter_bytes, hash_func
|
||||
).digest()
|
||||
|
||||
# Étape 3 : Truncation dynamique (RFC 4226 section 5.4)
|
||||
# L'offset est déterminé par les 4 bits de poids faible du dernier octet
|
||||
offset = hmac_digest[-1] & 0x0F
|
||||
|
||||
# Extraire 4 bytes à partir de l'offset et masquer le bit de signe
|
||||
truncated = (
|
||||
((hmac_digest[offset] & 0x7F) << 24)
|
||||
| ((hmac_digest[offset + 1] & 0xFF) << 16)
|
||||
| ((hmac_digest[offset + 2] & 0xFF) << 8)
|
||||
| (hmac_digest[offset + 3] & 0xFF)
|
||||
)
|
||||
|
||||
# Étape 4 : Réduction modulo pour obtenir le nombre de chiffres voulu
|
||||
code = truncated % (10 ** self._digits)
|
||||
|
||||
# Zero-padding pour garantir la longueur
|
||||
return str(code).zfill(self._digits)
|
||||
@@ -26,7 +26,7 @@ class OllamaClient:
|
||||
def __init__(self,
|
||||
endpoint: str = "http://localhost:11434",
|
||||
model: str = "qwen3-vl:8b",
|
||||
timeout: int = 60):
|
||||
timeout: int = 180):
|
||||
"""
|
||||
Initialiser le client Ollama
|
||||
|
||||
@@ -63,14 +63,21 @@ class OllamaClient:
|
||||
system_prompt: Optional[str] = None,
|
||||
temperature: float = 0.1,
|
||||
max_tokens: int = 500,
|
||||
force_json: bool = False) -> Dict[str, Any]:
|
||||
force_json: bool = False,
|
||||
assistant_prefill: Optional[str] = None,
|
||||
num_ctx: Optional[int] = None,
|
||||
extra_images_b64: Optional[List[str]] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Générer une réponse du VLM via l'API chat d'Ollama.
|
||||
|
||||
Note: On utilise /api/chat au lieu de /api/generate car qwen3-vl
|
||||
avec /api/generate consomme tous les tokens en thinking interne
|
||||
et retourne une réponse vide. L'API chat gère correctement
|
||||
le mode /no_think et sépare thinking/réponse.
|
||||
Pour les modèles thinking (qwen3-vl), on utilise la technique du
|
||||
"assistant prefill" : un message assistant pré-rempli est ajouté
|
||||
après le message user, forçant le modèle à continuer directement
|
||||
sans phase de thinking. Cela résout le bug Ollama 0.18.x où
|
||||
think=false est ignoré par le renderer qwen3-vl-thinking.
|
||||
|
||||
Sans prefill : le modèle pense 500+ tokens puis répond (~180s)
|
||||
Avec prefill : le modèle répond directement (~1-5s)
|
||||
|
||||
Args:
|
||||
prompt: Prompt textuel
|
||||
@@ -80,6 +87,11 @@ class OllamaClient:
|
||||
temperature: Température de génération
|
||||
max_tokens: Nombre max de tokens
|
||||
force_json: Forcer la sortie JSON (non recommandé pour qwen3-vl)
|
||||
assistant_prefill: Début de réponse pré-rempli (auto-détecté si None)
|
||||
num_ctx: Context window (défaut 2048, augmenter pour batch)
|
||||
extra_images_b64: Images supplémentaires en base64 à envoyer avec le prompt.
|
||||
Ajoutées après l'image principale. Utile pour le VLM multi-image
|
||||
(ex: screenshot + crop de référence).
|
||||
|
||||
Returns:
|
||||
Dict avec 'response', 'success', 'error'
|
||||
@@ -93,17 +105,19 @@ class OllamaClient:
|
||||
image_data = self._encode_image_from_pil(image)
|
||||
|
||||
# Nettoyer le prompt — retirer /no_think et /nothink du texte
|
||||
# car le mode thinking est contrôlé via le paramètre think=false
|
||||
# de l'API chat. Les préfixes /no_think dans le prompt causent
|
||||
# paradoxalement PLUS de thinking interne chez qwen3-vl.
|
||||
effective_prompt = prompt.replace("/no_think\n", "").replace("/no_think", "")
|
||||
effective_prompt = effective_prompt.replace("/nothink ", "").replace("/nothink", "")
|
||||
effective_prompt = effective_prompt.strip()
|
||||
|
||||
# Construire le message utilisateur
|
||||
user_message = {"role": "user", "content": effective_prompt}
|
||||
all_images = []
|
||||
if image_data:
|
||||
user_message["images"] = [image_data]
|
||||
all_images.append(image_data)
|
||||
if extra_images_b64:
|
||||
all_images.extend(extra_images_b64)
|
||||
if all_images:
|
||||
user_message["images"] = all_images
|
||||
|
||||
# Construire les messages
|
||||
messages = []
|
||||
@@ -111,9 +125,37 @@ class OllamaClient:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append(user_message)
|
||||
|
||||
# Déterminer si le modèle supporte le thinking
|
||||
# Déterminer si le modèle est un modèle thinking (qwen3)
|
||||
is_thinking_model = "qwen3" in self.model.lower()
|
||||
|
||||
# WORKAROUND Ollama 0.18.x : think=false est ignoré par le
|
||||
# renderer qwen3-vl-thinking. On utilise un assistant prefill
|
||||
# pour forcer le modèle à skip le thinking et répondre directement.
|
||||
# Le prefill est choisi en fonction du format attendu.
|
||||
# IMPORTANT : avec image, sans prefill le thinking dépasse 180s.
|
||||
prefill_used = None
|
||||
if is_thinking_model:
|
||||
if assistant_prefill is not None:
|
||||
prefill_used = assistant_prefill
|
||||
elif force_json:
|
||||
prefill_used = "{"
|
||||
elif all_images:
|
||||
# Avec image(s), le thinking est catastrophique (>180s).
|
||||
# Prefill générique pour forcer une réponse directe.
|
||||
prefill_used = "Based on the image,"
|
||||
|
||||
if prefill_used is not None:
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": prefill_used
|
||||
})
|
||||
|
||||
# num_ctx par défaut à 2048 (correspondant au default du modèle
|
||||
# chargé en mémoire). Changer num_ctx force un rechargement du
|
||||
# KV cache (~30s de pénalité), donc ne l'augmenter que pour les
|
||||
# requêtes batch qui dépassent la limite (image + prompt long).
|
||||
effective_num_ctx = num_ctx or 2048
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
@@ -121,13 +163,13 @@ class OllamaClient:
|
||||
"options": {
|
||||
"temperature": temperature,
|
||||
"num_predict": max_tokens,
|
||||
"num_ctx": 2048,
|
||||
"num_ctx": effective_num_ctx,
|
||||
"top_k": 1
|
||||
}
|
||||
}
|
||||
|
||||
# Désactiver le thinking pour les modèles qui le supportent
|
||||
# Cela réduit drastiquement la consommation de tokens et le temps
|
||||
# Garder think=false au cas où une future version d'Ollama le
|
||||
# corrige — le prefill reste le mécanisme principal
|
||||
if is_thinking_model:
|
||||
payload["think"] = False
|
||||
|
||||
@@ -144,6 +186,11 @@ class OllamaClient:
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
content = result.get("message", {}).get("content", "")
|
||||
|
||||
# Reconstituer la réponse complète en ajoutant le prefill
|
||||
if prefill_used and content:
|
||||
content = prefill_used + content
|
||||
|
||||
return {
|
||||
"response": content,
|
||||
"success": True,
|
||||
@@ -181,8 +228,11 @@ For each element, provide:
|
||||
- Semantic role (primary_action, cancel, submit, form_input, search_field, navigation, settings, close)
|
||||
|
||||
Format your response as JSON."""
|
||||
|
||||
result = self.generate(prompt, image_path=image_path, temperature=0.1)
|
||||
|
||||
result = self.generate(
|
||||
prompt, image_path=image_path, temperature=0.1,
|
||||
assistant_prefill="[",
|
||||
)
|
||||
|
||||
if result["success"]:
|
||||
try:
|
||||
@@ -214,14 +264,21 @@ Format your response as JSON."""
|
||||
Choose ONLY ONE from: {types_list}
|
||||
|
||||
Respond with just the type name, nothing else."""
|
||||
|
||||
|
||||
if context:
|
||||
prompt += f"\n\nContext: {context}"
|
||||
|
||||
result = self.generate(prompt, image=element_image, temperature=0.1)
|
||||
|
||||
result = self.generate(
|
||||
prompt, image=element_image, temperature=0.1,
|
||||
assistant_prefill="The type is:",
|
||||
)
|
||||
|
||||
if result["success"]:
|
||||
element_type = result["response"].strip().lower()
|
||||
# Retirer le prefill du début pour extraire le type
|
||||
raw = result["response"]
|
||||
if raw.startswith("The type is:"):
|
||||
raw = raw[len("The type is:"):]
|
||||
element_type = raw.strip().lower()
|
||||
# Valider que c'est un type connu
|
||||
valid_types = types_list.split(", ")
|
||||
if element_type in valid_types:
|
||||
@@ -255,14 +312,21 @@ Respond with just the type name, nothing else."""
|
||||
Choose ONLY ONE from: {roles_list}
|
||||
|
||||
Respond with just the role name, nothing else."""
|
||||
|
||||
|
||||
if context:
|
||||
prompt += f"\n\nContext: {context}"
|
||||
|
||||
result = self.generate(prompt, image=element_image, temperature=0.1)
|
||||
|
||||
result = self.generate(
|
||||
prompt, image=element_image, temperature=0.1,
|
||||
assistant_prefill="The role is:",
|
||||
)
|
||||
|
||||
if result["success"]:
|
||||
role = result["response"].strip().lower()
|
||||
# Retirer le prefill du début pour extraire le rôle
|
||||
raw = result["response"]
|
||||
if raw.startswith("The role is:"):
|
||||
raw = raw[len("The role is:"):]
|
||||
role = raw.strip().lower()
|
||||
# Valider que c'est un rôle connu
|
||||
valid_roles = roles_list.split(", ")
|
||||
if role in valid_roles:
|
||||
@@ -286,12 +350,19 @@ Respond with just the role name, nothing else."""
|
||||
Dict avec 'text' extrait
|
||||
"""
|
||||
prompt = "Extract all visible text from this image. Return only the text, nothing else."
|
||||
|
||||
result = self.generate(prompt, image=image, temperature=0.1)
|
||||
|
||||
result = self.generate(
|
||||
prompt, image=image, temperature=0.1,
|
||||
assistant_prefill="Text:",
|
||||
)
|
||||
|
||||
if result["success"]:
|
||||
return {"text": result["response"].strip(), "success": True}
|
||||
|
||||
# Retirer le prefill du début pour extraire le texte
|
||||
raw = result["response"]
|
||||
if raw.startswith("Text:"):
|
||||
raw = raw[len("Text:"):]
|
||||
return {"text": raw.strip(), "success": True}
|
||||
|
||||
return {"text": "", "success": False, "error": result["error"]}
|
||||
|
||||
# Taille minimum pour une classification fiable par le VLM
|
||||
@@ -346,7 +417,8 @@ Your answer:"""
|
||||
system_prompt=system_prompt,
|
||||
temperature=0.1,
|
||||
max_tokens=300,
|
||||
force_json=False
|
||||
force_json=False,
|
||||
assistant_prefill="{"
|
||||
)
|
||||
|
||||
if not result["success"]:
|
||||
|
||||
@@ -220,7 +220,7 @@ class UIDetector:
|
||||
# des centaines d'appels VLM inutiles (~2-3s chacun).
|
||||
# On garde max 80 candidats — suffisant pour obtenir ~50 éléments
|
||||
# après filtrage par confiance, tout en gardant un temps raisonnable.
|
||||
max_candidates = 30 # 30 suffisent pour les éléments principaux (~6min/screenshot au lieu de 17)
|
||||
max_candidates = 10 # 10 régions : compact, rapide (~5-10s avec prefill)
|
||||
if len(regions) > max_candidates:
|
||||
# Trier par confiance décroissante, puis par surface décroissante
|
||||
regions.sort(key=lambda r: (r.confidence, r.w * r.h), reverse=True)
|
||||
@@ -489,32 +489,18 @@ class UIDetector:
|
||||
if not self.vlm_client or not regions:
|
||||
return None
|
||||
|
||||
# Construire la description des régions pour le prompt
|
||||
# Construire une description compacte des régions (économise les tokens)
|
||||
regions_desc_lines = []
|
||||
for i, r in enumerate(regions):
|
||||
regions_desc_lines.append(
|
||||
f" #{i}: position=({r.x},{r.y}), size={r.w}x{r.h}, source={r.source}"
|
||||
)
|
||||
regions_description = "\n".join(regions_desc_lines)
|
||||
regions_desc_lines.append(f"#{i}:({r.x},{r.y},{r.w}x{r.h})")
|
||||
regions_description = " ".join(regions_desc_lines)
|
||||
|
||||
prompt = f"""Analyze this screenshot. I have detected UI elements at these positions:
|
||||
{regions_description}
|
||||
prompt = f"""Classify UI elements at: {regions_description}
|
||||
Types: button,text_input,checkbox,radio,dropdown,tab,link,icon,table_row,menu_item
|
||||
Roles: primary_action,cancel,submit,form_input,search_field,navigation,settings,close,delete,edit,save
|
||||
JSON array: [{{"id":0,"type":"...","role":"...","text":"..."}}]"""
|
||||
|
||||
For each element, classify it as a JSON array. Each entry must have:
|
||||
- "id": the element number (matching # above)
|
||||
- "type": one of button, text_input, checkbox, radio, dropdown, tab, link, icon, table_row, menu_item
|
||||
- "role": one of primary_action, cancel, submit, form_input, search_field, navigation, settings, close, delete, edit, save
|
||||
- "text": visible text on the element (empty string if none)
|
||||
|
||||
Return ONLY the JSON array, nothing else. Example:
|
||||
[{{"id": 0, "type": "button", "role": "submit", "text": "OK"}}, {{"id": 1, "type": "text_input", "role": "form_input", "text": ""}}]
|
||||
|
||||
Your answer:"""
|
||||
|
||||
system_prompt = (
|
||||
"You are a JSON-only UI classifier. No thinking. No explanation. "
|
||||
"Output a raw JSON array only."
|
||||
)
|
||||
system_prompt = "JSON-only UI classifier. No explanation."
|
||||
|
||||
# Appel VLM unique avec le screenshot complet
|
||||
for attempt in range(2):
|
||||
@@ -523,8 +509,10 @@ Your answer:"""
|
||||
image=pil_image,
|
||||
system_prompt=system_prompt,
|
||||
temperature=0.1,
|
||||
max_tokens=2000, # Plus de tokens car réponse groupée
|
||||
max_tokens=1500, # ~100 tokens/element * 10 elements + marge
|
||||
force_json=False,
|
||||
assistant_prefill="[", # Force JSON array direct, skip thinking
|
||||
num_ctx=2048, # 2048 suffit pour 10 régions compactes + image
|
||||
)
|
||||
|
||||
if not result["success"]:
|
||||
|
||||
622
core/detection/ui_detector_old.py.bak
Normal file
622
core/detection/ui_detector_old.py.bak
Normal file
@@ -0,0 +1,622 @@
|
||||
"""
|
||||
UIDetector - Détection Sémantique d'Éléments UI avec VLM
|
||||
|
||||
Utilise un Vision-Language Model (VLM) pour détecter et classifier
|
||||
les éléments UI avec leurs types et rôles sémantiques.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import json
|
||||
import re
|
||||
|
||||
from ..models.ui_element import UIElement, UIElementEmbeddings, VisualFeatures
|
||||
from .ollama_client import OllamaClient, check_ollama_available
|
||||
|
||||
|
||||
@dataclass
class DetectionConfig:
    """UI detection configuration (all fields have sensible defaults)."""
    vlm_model: str = "qwen3-vl:8b"  # VLM model to use (qwen3-vl:8b recommended)
    vlm_endpoint: str = "http://localhost:11434"  # Ollama endpoint
    confidence_threshold: float = 0.7  # Minimum confidence threshold
    max_elements: int = 50  # Max number of elements to detect
    detect_regions: bool = True  # Detect regions of interest first
    use_embeddings: bool = True  # Generate dual (image + text) embeddings
||||
|
||||
|
||||
class UIDetector:
|
||||
"""
|
||||
Détecteur d'éléments UI sémantique
|
||||
|
||||
Utilise un VLM (Vision-Language Model) pour :
|
||||
1. Détecter les régions d'intérêt dans un screenshot
|
||||
2. Classifier le type de chaque élément UI
|
||||
3. Déterminer le rôle sémantique
|
||||
4. Extraire les features visuelles
|
||||
5. Générer des embeddings duaux (image + texte)
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[DetectionConfig] = None):
    """
    Initialize the detector.

    Args:
        config: Detection configuration (defaults are used if None)
    """
    self.config = config or DetectionConfig()
    # Set by _initialize_vlm(); stays None when Ollama is unreachable,
    # in which case detection runs in simulation mode.
    self.vlm_client = None
    self._initialize_vlm()
|
||||
|
||||
def _initialize_vlm(self) -> None:
    """Initialize the VLM client (Ollama); degrade to simulation mode on failure."""
    try:
        # Check that the Ollama endpoint responds before creating a client
        if check_ollama_available(self.config.vlm_endpoint):
            self.vlm_client = OllamaClient(
                endpoint=self.config.vlm_endpoint,
                model=self.config.vlm_model
            )
            print(f"✓ VLM initialized: {self.config.vlm_model} at {self.config.vlm_endpoint}")
        else:
            print(f"⚠ Ollama not available at {self.config.vlm_endpoint}, using simulation mode")
            self.vlm_client = None
    except Exception as e:
        # Broad catch is deliberate: initialization must never crash the
        # detector — any backend failure falls back to simulation mode.
        print(f"⚠ Failed to initialize VLM: {e}, using simulation mode")
        self.vlm_client = None
|
||||
|
||||
def detect(self,
           screenshot_path: str,
           window_context: Optional[Dict[str, Any]] = None) -> List[UIElement]:
    """
    Detect all UI elements in a screenshot.

    Args:
        screenshot_path: Path to the screenshot file
        window_context: Window context (title, process, etc.)

    Returns:
        List of detected UIElements, filtered by confidence_threshold and
        capped at config.max_elements (empty list if the image fails to load)
    """
    # Load the image; bail out early on failure
    image = self._load_image(screenshot_path)
    if image is None:
        return []

    # Detect regions of interest if enabled
    if self.config.detect_regions:
        regions = self._detect_regions_of_interest(image, window_context)
    else:
        # Use the whole image as a single region
        regions = [{"bbox": (0, 0, image.width, image.height), "confidence": 1.0}]

    # Detect UI elements within each region
    ui_elements = []
    for region in regions:
        elements = self._detect_elements_in_region(
            image,
            region,
            screenshot_path,
            window_context
        )
        ui_elements.extend(elements)

    # Filter by confidence
    ui_elements = [
        el for el in ui_elements
        if el.confidence >= self.config.confidence_threshold
    ]

    # Cap the number of elements
    if len(ui_elements) > self.config.max_elements:
        # Sort by confidence and keep the best ones
        ui_elements.sort(key=lambda x: x.confidence, reverse=True)
        ui_elements = ui_elements[:self.config.max_elements]

    return ui_elements
|
||||
|
||||
def _load_image(self, screenshot_path: str) -> Optional[Image.Image]:
    """Load an image from a file; return None (and print the error) on failure."""
    try:
        return Image.open(screenshot_path)
    except Exception as e:
        # Any PIL/IO failure is non-fatal: the caller treats None as "skip"
        print(f"Error loading image {screenshot_path}: {e}")
        return None
|
||||
|
||||
def _detect_regions_of_interest(self,
                                image: Image.Image,
                                window_context: Optional[Dict] = None) -> List[Dict]:
    """
    Detect regions of interest in the image.

    Uses the VLM to identify zones containing UI elements.

    Args:
        image: PIL image
        window_context: Window context

    Returns:
        List of regions {bbox: (x, y, w, h), confidence: float}
    """
    if self.vlm_client is None:
        # Simulation mode: split the image into a grid
        return self._simulate_region_detection(image)

    # VLM-backed region detection.
    # For now the full image is used as a single region (simpler and
    # effective); window_context is currently unused on this path.
    width, height = image.size
    return [{
        "bbox": (0, 0, width, height),
        "confidence": 1.0
    }]
|
||||
|
||||
def _simulate_region_detection(self, image: Image.Image) -> List[Dict]:
|
||||
"""Simulation de détection de régions (pour développement)"""
|
||||
width, height = image.size
|
||||
|
||||
# Diviser en grille 3x3 pour simulation
|
||||
regions = []
|
||||
grid_size = 3
|
||||
cell_w = width // grid_size
|
||||
cell_h = height // grid_size
|
||||
|
||||
for i in range(grid_size):
|
||||
for j in range(grid_size):
|
||||
regions.append({
|
||||
"bbox": (j * cell_w, i * cell_h, cell_w, cell_h),
|
||||
"confidence": 0.8
|
||||
})
|
||||
|
||||
return regions
|
||||
|
||||
def _detect_elements_in_region(self,
|
||||
image: Image.Image,
|
||||
region: Dict,
|
||||
screenshot_path: str,
|
||||
window_context: Optional[Dict] = None) -> List[UIElement]:
|
||||
"""
|
||||
Détecter éléments UI dans une région spécifique
|
||||
|
||||
Args:
|
||||
image: Image complète
|
||||
region: Région à analyser
|
||||
screenshot_path: Chemin du screenshot
|
||||
window_context: Contexte de la fenêtre
|
||||
|
||||
Returns:
|
||||
Liste d'UIElements dans cette région
|
||||
"""
|
||||
bbox = region["bbox"]
|
||||
x, y, w, h = bbox
|
||||
|
||||
# Extraire crop de la région
|
||||
region_image = image.crop((x, y, x + w, y + h))
|
||||
|
||||
# Détecter éléments avec VLM
|
||||
if self.vlm_client is None:
|
||||
# Mode simulation
|
||||
return self._simulate_element_detection(
|
||||
region_image, bbox, screenshot_path, window_context
|
||||
)
|
||||
|
||||
# Vraie détection avec VLM !
|
||||
return self._detect_with_vlm(
|
||||
region_image, bbox, screenshot_path, window_context
|
||||
)
|
||||
|
||||
def _detect_with_vlm(self,
|
||||
region_image: Image.Image,
|
||||
region_bbox: Tuple[int, int, int, int],
|
||||
screenshot_path: str,
|
||||
window_context: Optional[Dict] = None) -> List[UIElement]:
|
||||
"""
|
||||
Détecter éléments UI avec le VLM (vraie détection)
|
||||
|
||||
Args:
|
||||
region_image: Image de la région
|
||||
region_bbox: Bbox de la région (x, y, w, h)
|
||||
screenshot_path: Chemin du screenshot
|
||||
window_context: Contexte de la fenêtre
|
||||
|
||||
Returns:
|
||||
Liste d'UIElements détectés
|
||||
"""
|
||||
x_offset, y_offset, w, h = region_bbox
|
||||
|
||||
# Construire le prompt pour le VLM
|
||||
context_str = ""
|
||||
if window_context:
|
||||
context_str = f"\nWindow context: {window_context.get('title', 'Unknown')}"
|
||||
|
||||
# Approche simplifiée : demander une description structurée
|
||||
prompt = f"""List all interactive UI elements in this screenshot.{context_str}
|
||||
|
||||
For each element, provide:
|
||||
- type (button, text_input, checkbox, link, etc.)
|
||||
- label (visible text)
|
||||
- approximate position (top/middle/bottom, left/center/right)
|
||||
|
||||
Format as JSON array:
|
||||
[{{"type": "button", "label": "Submit", "position": "middle-center"}}]
|
||||
|
||||
Return ONLY the JSON array, no other text."""
|
||||
|
||||
# Appeler le VLM
|
||||
# Note: Utiliser le chemin du screenshot complet plutôt que le crop
|
||||
# car certains VLM gèrent mieux les fichiers que les images PIL
|
||||
result = self.vlm_client.generate(
|
||||
prompt=prompt,
|
||||
image_path=screenshot_path, # Utiliser le chemin au lieu de l'image PIL
|
||||
temperature=0.1,
|
||||
max_tokens=1000
|
||||
)
|
||||
|
||||
if not result["success"]:
|
||||
print(f"❌ VLM detection failed: {result.get('error', 'Unknown error')}")
|
||||
return []
|
||||
|
||||
if not result["response"] or len(result["response"].strip()) == 0:
|
||||
print(f"⚠ VLM returned empty response")
|
||||
return []
|
||||
|
||||
# Parser la réponse JSON
|
||||
elements = self._parse_vlm_response(
|
||||
result["response"],
|
||||
region_bbox,
|
||||
screenshot_path,
|
||||
window_context
|
||||
)
|
||||
|
||||
return elements
|
||||
|
||||
def _parse_vlm_response(self,
                        response: str,
                        region_bbox: Tuple[int, int, int, int],
                        screenshot_path: str,
                        window_context: Optional[Dict] = None) -> List[UIElement]:
    """
    Parse the VLM's JSON response into UIElements.

    Two element-position formats are accepted:
      - percentage coordinates ("x"/"y"/"width"/"height" as % of the region)
      - textual position ("top/middle/bottom" x "left/center/right"),
        in which case the element size is inferred from its type.

    All coordinates are mapped back into full-screenshot space using the
    region's offset.

    Args:
        response: Raw text answer from the VLM (JSON array, possibly
            surrounded by extra text).
        region_bbox: Region bbox (x, y, w, h) in screenshot coordinates.
        screenshot_path: Path of the screenshot (stored in metadata).
        window_context: Optional window context (currently unused here).

    Returns:
        Parsed UIElements; empty list when no JSON array can be extracted.
        Individual malformed entries are skipped, not fatal.
    """
    x_offset, y_offset, region_w, region_h = region_bbox

    try:
        # Extract the JSON from the response (the model may wrap the
        # array in extra prose despite the prompt instructions).
        json_match = re.search(r'\[.*\]', response, re.DOTALL)
        if not json_match:
            print(f"No JSON array found in VLM response")
            print(f"VLM response was: {response[:500]}...")
            return []

        elements_data = json.loads(json_match.group(0))

        if not isinstance(elements_data, list):
            print(f"VLM response is not a JSON array")
            return []

        elements = []
        for i, elem_data in enumerate(elements_data):
            try:
                # Handle both position formats (percentages or textual).
                if 'x' in elem_data and 'y' in elem_data:
                    # Percentage format: values are % of the region size.
                    x_pct = float(elem_data.get('x', 0))
                    y_pct = float(elem_data.get('y', 0))
                    w_pct = float(elem_data.get('width', 10))
                    h_pct = float(elem_data.get('height', 5))

                    # Map back into full-screenshot coordinates.
                    elem_x = x_offset + int(region_w * x_pct / 100)
                    elem_y = y_offset + int(region_h * y_pct / 100)
                    elem_w = int(region_w * w_pct / 100)
                    elem_h = int(region_h * h_pct / 100)
                else:
                    # Textual format (top/middle/bottom, left/center/right):
                    # place the element at the quarter/half/three-quarter
                    # mark of the region.
                    position = elem_data.get('position', 'middle-center').lower()

                    # Vertical placement.
                    if 'top' in position:
                        elem_y = y_offset + region_h // 4
                    elif 'bottom' in position:
                        elem_y = y_offset + 3 * region_h // 4
                    else:  # middle
                        elem_y = y_offset + region_h // 2

                    # Horizontal placement.
                    if 'left' in position:
                        elem_x = x_offset + region_w // 4
                    elif 'right' in position:
                        elem_x = x_offset + 3 * region_w // 4
                    else:  # center
                        elem_x = x_offset + region_w // 2

                    # No explicit size in this format: use a default size
                    # per element type.
                    elem_type = elem_data.get('type', 'button')
                    if elem_type == 'button':
                        elem_w, elem_h = 100, 40
                    elif elem_type == 'text_input':
                        elem_w, elem_h = 200, 35
                    elif elem_type == 'checkbox':
                        elem_w, elem_h = 25, 25
                    else:
                        elem_w, elem_h = 80, 30

                # Build the UIElement with flat default confidences
                # (the textual VLM output carries no per-element score).
                element = UIElement(
                    element_id=f"vlm_{elem_x}_{elem_y}",
                    type=elem_data.get('type', 'unknown'),
                    role=elem_data.get('role', 'unknown'),
                    bbox=(elem_x, elem_y, elem_w, elem_h),
                    center=(elem_x + elem_w // 2, elem_y + elem_h // 2),
                    label=elem_data.get('label', ''),
                    label_confidence=0.85,  # default confidence for VLM output
                    embeddings=UIElementEmbeddings(),
                    visual_features=VisualFeatures(
                        dominant_color="rgb(128, 128, 128)",
                        has_icon=elem_data.get('type') == 'icon',
                        shape="rectangle",
                        size_category="medium"
                    ),
                    confidence=0.85,  # default confidence for VLM output
                    metadata={
                        "detected_by": "vlm",
                        "model": self.config.vlm_model,
                        "screenshot_path": screenshot_path
                    }
                )

                elements.append(element)

            except (KeyError, ValueError, TypeError) as e:
                # One bad entry should not discard the whole response.
                print(f"Error parsing element {i}: {e}")
                continue

        return elements

    except json.JSONDecodeError as e:
        print(f"Failed to parse VLM JSON response: {e}")
        print(f"Response was: {response[:200]}...")
        return []
|
||||
|
||||
def _simulate_element_detection(self,
                                region_image: Image.Image,
                                region_bbox: Tuple[int, int, int, int],
                                screenshot_path: str,
                                window_context: Optional[Dict] = None) -> List[UIElement]:
    """Development-mode stub: fabricate 2-3 random UI elements in the region."""
    x_offset, y_offset, w, h = region_bbox

    # Candidate types and roles for the fake elements.
    type_choices = ["button", "text_input", "checkbox", "link", "icon"]
    role_choices = ["primary_action", "cancel", "submit", "form_input", "navigation"]

    fake_elements = []
    for idx in range(np.random.randint(2, 4)):
        # Random size first, then a random placement kept inside the region.
        elem_w = np.random.randint(50, 150)
        elem_h = np.random.randint(20, 60)
        elem_x = x_offset + np.random.randint(0, max(1, w - elem_w))
        elem_y = y_offset + np.random.randint(0, max(1, h - elem_h))

        fake_elements.append(UIElement(
            element_id=f"elem_{elem_x}_{elem_y}",
            type=np.random.choice(type_choices),
            role=np.random.choice(role_choices),
            bbox=(elem_x, elem_y, elem_w, elem_h),
            center=(elem_x + elem_w // 2, elem_y + elem_h // 2),
            label=f"Element {idx}",
            label_confidence=np.random.uniform(0.7, 0.95),
            embeddings=UIElementEmbeddings(),  # empty embeddings
            visual_features=VisualFeatures(
                dominant_color="rgb(128, 128, 128)",
                has_icon=np.random.choice([True, False]),
                shape="rectangle",
                size_category="medium"
            ),
            confidence=np.random.uniform(0.7, 0.95),
            metadata={"simulated": True, "screenshot_path": screenshot_path}
        ))

    return fake_elements
|
||||
|
||||
def classify_type(self,
                  element_image: Image.Image,
                  context: Optional[Dict] = None) -> Tuple[str, float]:
    """
    Classify the type of a UI element.

    Args:
        element_image: Image of the element.
        context: Extra context for the classifier.

    Returns:
        (type, confidence); ("unknown", 0.0) when classification fails.
    """
    if self.vlm_client is None:
        # Simulation mode: pick a plausible type at random.
        candidates = ["button", "text_input", "checkbox", "radio", "dropdown",
                      "tab", "link", "icon", "table_row", "menu_item"]
        return np.random.choice(candidates), np.random.uniform(0.7, 0.95)

    # Real classification through the VLM client.
    result = self.vlm_client.classify_element_type(element_image, context)
    if not result["success"]:
        return "unknown", 0.0
    return result["type"], result["confidence"]
|
||||
|
||||
def classify_role(self,
                  element_image: Image.Image,
                  element_type: str,
                  context: Optional[Dict] = None) -> Tuple[str, float]:
    """
    Classify the semantic role of a UI element.

    Args:
        element_image: Image of the element.
        element_type: Previously-detected element type.
        context: Extra context for the classifier.

    Returns:
        (role, confidence); ("unknown", 0.0) when classification fails.
    """
    if self.vlm_client is None:
        # Simulation mode: pick a plausible role at random.
        candidates = ["primary_action", "cancel", "submit", "form_input",
                      "search_field", "navigation", "settings", "close"]
        return np.random.choice(candidates), np.random.uniform(0.7, 0.95)

    # Real classification through the VLM client.
    result = self.vlm_client.classify_element_role(
        element_image,
        element_type,
        context
    )
    if not result["success"]:
        return "unknown", 0.0
    return result["role"], result["confidence"]
|
||||
|
||||
def extract_visual_features(self,
                            element_image: Image.Image) -> VisualFeatures:
    """
    Extract visual features from an element image.

    Args:
        element_image: Image of the element.

    Returns:
        VisualFeatures (dominant colour, shape, size category, icon flag).
    """
    # Dominant colour: per-channel mean for colour images, neutral grey
    # for anything without a channel axis.
    pixels = np.array(element_image)
    if pixels.ndim == 3:
        dominant_color = tuple(pixels.mean(axis=(0, 1)).astype(int).tolist())
    else:
        dominant_color = (128, 128, 128)

    # Shape from the aspect ratio (simplified heuristic).
    width, height = element_image.size
    aspect_ratio = width / height if height > 0 else 1.0
    if aspect_ratio > 3:
        shape = "horizontal_bar"
    elif aspect_ratio < 0.33:
        shape = "vertical_bar"
    elif 0.8 <= aspect_ratio <= 1.2:
        shape = "square"
    else:
        shape = "rectangle"

    # Size bucket from the pixel area.
    area = width * height
    if area < 1000:
        size_category = "small"
    elif area < 10000:
        size_category = "medium"
    else:
        size_category = "large"

    # Icon heuristic: small and roughly square (simplified).
    has_icon = width < 100 and height < 100 and 0.8 <= aspect_ratio <= 1.2

    return VisualFeatures(
        dominant_color=dominant_color,
        has_icon=has_icon,
        shape=shape,
        size_category=size_category
    )
|
||||
|
||||
def generate_embeddings(self,
                        element_image: Image.Image,
                        element_label: str,
                        embedder: Optional[Any] = None) -> Optional[UIElementEmbeddings]:
    """
    Generate dual (image + text) embeddings for an element.

    Args:
        element_image: Image of the element.
        element_label: Textual label of the element.
        embedder: Embedder to use (optional).

    Returns:
        UIElementEmbeddings, or None when embeddings are disabled, no
        embedder is given, or nothing could be embedded.
    """
    if not self.config.use_embeddings or embedder is None:
        return None

    try:
        image_embedding_id = None
        text_embedding_id = None

        # Image embedding.
        if hasattr(embedder, 'embed_image'):
            # TODO: persist the crop to a temp file and embed it.
            pass

        # Text embedding.
        if element_label and hasattr(embedder, 'embed_text'):
            # TODO: implement text embedding.
            pass

        # Only build the container when at least one embedding was produced.
        if image_embedding_id or text_embedding_id:
            return UIElementEmbeddings(
                image_embedding_id=image_embedding_id,
                text_embedding_id=text_embedding_id,
                provider="openclip_ViT-B-32",
                dimensions=512
            )
    except Exception as e:
        print(f"Warning: Failed to generate embeddings: {e}")

    return None
|
||||
|
||||
def set_vlm_client(self, client: Any) -> None:
    """Attach the VLM client used for detection and classification."""
    self.vlm_client = client
|
||||
|
||||
def get_config(self) -> DetectionConfig:
    """Return the active detection configuration."""
    return self.config
|
||||
|
||||
|
||||
# ============================================================================
# Utility functions
# ============================================================================
|
||||
|
||||
def create_detector(vlm_model: str = "qwen3-vl:8b",
                    confidence_threshold: float = 0.7) -> UIDetector:
    """
    Build a UIDetector with a custom configuration.

    Args:
        vlm_model: VLM model identifier to use.
        confidence_threshold: Minimum confidence for detected elements.

    Returns:
        A configured UIDetector instance.
    """
    detector_config = DetectionConfig(
        vlm_model=vlm_model,
        confidence_threshold=confidence_threshold,
    )
    return UIDetector(detector_config)
|
||||
@@ -125,18 +125,32 @@ class FusionEngine:
|
||||
weights: Dict[str, float]) -> np.ndarray:
|
||||
"""
|
||||
Fusion pondérée simple : somme pondérée des vecteurs
|
||||
|
||||
|
||||
fused = w1*v1 + w2*v2 + w3*v3 + w4*v4
|
||||
|
||||
Les poids sont renormalisés en fonction des modalités effectivement
|
||||
présentes, pour que la somme des poids effectifs = 1.0.
|
||||
Exemple : si seuls image (0.5) et text (0.3) sont fournis,
|
||||
les poids deviennent image=0.625, text=0.375.
|
||||
"""
|
||||
# Initialiser vecteur résultat
|
||||
first_vector = next(iter(embeddings.values()))
|
||||
fused = np.zeros_like(first_vector, dtype=np.float32)
|
||||
|
||||
# Somme pondérée
|
||||
|
||||
# Calculer la somme des poids des modalités présentes pour renormaliser
|
||||
present_weight_sum = sum(
|
||||
weights.get(modality, 0.0) for modality in embeddings
|
||||
)
|
||||
|
||||
# Somme pondérée avec renormalisation
|
||||
for modality, vector in embeddings.items():
|
||||
weight = weights.get(modality, 0.0)
|
||||
fused += weight * vector
|
||||
|
||||
raw_weight = weights.get(modality, 0.0)
|
||||
if present_weight_sum > 1e-10:
|
||||
effective_weight = raw_weight / present_weight_sum
|
||||
else:
|
||||
effective_weight = 1.0 / len(embeddings)
|
||||
fused += effective_weight * vector
|
||||
|
||||
return fused
|
||||
|
||||
def _fuse_concat_projection(self,
|
||||
|
||||
@@ -112,7 +112,7 @@ class StateEmbeddingBuilder:
|
||||
metadata={
|
||||
"screen_state_id": screen_state.screen_state_id,
|
||||
"timestamp": screen_state.timestamp.isoformat(),
|
||||
"window_title": getattr(screen_state.window, 'title', ''),
|
||||
"window_title": getattr(screen_state.window, 'window_title', ''),
|
||||
"created_at": datetime.now().isoformat()
|
||||
}
|
||||
)
|
||||
@@ -160,15 +160,16 @@ class StateEmbeddingBuilder:
|
||||
if ui_emb is not None:
|
||||
embeddings["ui"] = ui_emb
|
||||
|
||||
# Si aucun embedding calculé, créer des vecteurs par défaut
|
||||
# Si aucun embedding calculé, retourner un vecteur zéro unique
|
||||
# (sera ignoré par DBSCAN → noise, comportement correct)
|
||||
if not embeddings:
|
||||
# Utiliser dimensions par défaut (512)
|
||||
default_dim = 512
|
||||
logger.warning(
|
||||
"Aucun embedding calculé pour ce ScreenState — "
|
||||
"retour d'un vecteur zéro (sera traité comme noise par DBSCAN)"
|
||||
)
|
||||
embeddings = {
|
||||
"image": np.random.randn(default_dim).astype(np.float32),
|
||||
"text": np.random.randn(default_dim).astype(np.float32),
|
||||
"title": np.random.randn(default_dim).astype(np.float32),
|
||||
"ui": np.random.randn(default_dim).astype(np.float32)
|
||||
"image": np.zeros(default_dim, dtype=np.float32)
|
||||
}
|
||||
|
||||
return embeddings
|
||||
@@ -243,7 +244,7 @@ class StateEmbeddingBuilder:
|
||||
|
||||
try:
|
||||
embedder = self.embedders["title"]
|
||||
title = getattr(screen_state.window, 'title', '')
|
||||
title = getattr(screen_state.window, 'window_title', '')
|
||||
|
||||
if not title:
|
||||
return None
|
||||
|
||||
24
core/federation/__init__.py
Normal file
24
core/federation/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
core.federation — Fédération des apprentissages entre clients.
|
||||
|
||||
Exporte les connaissances anonymisées (Learning Packs) de chaque site client,
|
||||
les fusionne sur un serveur central, et redistribue le modèle enrichi.
|
||||
|
||||
Modules :
|
||||
learning_pack — Format d'export, exportation, fusion
|
||||
faiss_global — Index FAISS global multi-clients
|
||||
"""
|
||||
|
||||
from .learning_pack import (
|
||||
LearningPack,
|
||||
LearningPackExporter,
|
||||
LearningPackMerger,
|
||||
)
|
||||
from .faiss_global import GlobalFAISSIndex
|
||||
|
||||
__all__ = [
|
||||
"LearningPack",
|
||||
"LearningPackExporter",
|
||||
"LearningPackMerger",
|
||||
"GlobalFAISSIndex",
|
||||
]
|
||||
354
core/federation/faiss_global.py
Normal file
354
core/federation/faiss_global.py
Normal file
@@ -0,0 +1,354 @@
|
||||
"""
|
||||
GlobalFAISSIndex — Index FAISS global fédérant les prototypes de tous les clients.
|
||||
|
||||
Construit un index de recherche vectorielle à partir des Learning Packs
|
||||
reçus de multiples sites clients. Chaque vecteur indexé porte des métadonnées
|
||||
permettant de retrouver le pack source, le workflow et l'application d'origine.
|
||||
|
||||
Cet index est utilisé par le serveur central (DGX Spark) pour :
|
||||
- Reconnaître instantanément un écran déjà vu chez un autre client
|
||||
- Proposer des workflows existants quand un nouveau client rencontre un écran familier
|
||||
- Mesurer la couverture applicative globale de Léa
|
||||
|
||||
Auteur : Dom, Claude — 19 mars 2026
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .learning_pack import LearningPack, ScreenPrototype
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Dimensions par défaut des embeddings CLIP (ViT-B-32)
|
||||
DEFAULT_DIMENSIONS = 512
|
||||
|
||||
try:
|
||||
import faiss
|
||||
FAISS_AVAILABLE = True
|
||||
except ImportError:
|
||||
FAISS_AVAILABLE = False
|
||||
logger.warning("FAISS non installé — GlobalFAISSIndex désactivé. pip install faiss-cpu")
|
||||
|
||||
|
||||
@dataclass
class GlobalSearchResult:
    """A single hit from a search against the global index."""
    prototype_id: str                 # ID of the matched screen prototype
    similarity: float                 # cosine similarity (inner product of L2-normalised vectors)
    pack_source_hash: str             # anonymous hash of the source client's pack
    workflow_skeleton_id: str         # workflow the prototype came from
    node_name: str                    # node (screen) name within that workflow
    app_name: str                     # application name, may be ""
    metadata: Dict[str, Any] = field(default_factory=dict)  # full raw metadata entry
|
||||
|
||||
|
||||
class GlobalFAISSIndex:
    """
    Global FAISS index holding the screen prototypes of every client.

    Each indexed vector carries metadata:
        - pack_source_hash: hash of the source client
        - workflow_skeleton_id: ID of the originating workflow
        - node_name: name of the node (screen) within that workflow
        - app_name: application name

    Usage:
        >>> index = GlobalFAISSIndex()
        >>> index.build_from_packs([pack_a, pack_b])
        >>> results = index.search(query_vector, k=5)
        >>> index.save(Path("global/faiss_index"))
    """

    def __init__(self, dimensions: int = DEFAULT_DIMENSIONS):
        """
        Initialise the global index.

        Args:
            dimensions: Vector dimensionality (512 for CLIP ViT-B-32).

        Raises:
            ImportError: when FAISS is not installed.
        """
        if not FAISS_AVAILABLE:
            raise ImportError(
                "FAISS est requis pour GlobalFAISSIndex. "
                "Installer avec : pip install faiss-cpu"
            )

        self.dimensions = dimensions
        self.index: Optional["faiss.IndexFlatIP"] = None
        self._metadata: List[Dict[str, Any]] = []
        self._rebuild_index()

    def _rebuild_index(self) -> None:
        """Create (or recreate) an empty FAISS index, dropping all metadata."""
        # IndexFlatIP gives cosine similarity once vectors are L2-normalised.
        self.index = faiss.IndexFlatIP(self.dimensions)
        self._metadata = []

    @property
    def total_vectors(self) -> int:
        """Number of vectors currently stored in the index."""
        return self.index.ntotal if self.index is not None else 0

    # ------------------------------------------------------------------
    # Construction from Learning Packs
    # ------------------------------------------------------------------

    def _collect_entries(self, pack):
        """
        Convert a pack's prototypes into parallel (vectors, metadata) lists.

        Shared by build_from_packs() and add_pack() so the metadata schema
        stays identical on both code paths. Prototypes without a valid
        vector are silently skipped.

        Returns:
            (vectors, metadata_list) — lists of equal length.
        """
        vectors: List[np.ndarray] = []
        metadata_list: List[Dict[str, Any]] = []

        for proto in pack.screen_prototypes:
            vec = self._proto_to_vector(proto)
            if vec is None:
                continue

            vectors.append(vec)
            metadata_list.append({
                "prototype_id": proto.prototype_id,
                "pack_source_hash": pack.source_hash,
                "workflow_skeleton_id": self._extract_skeleton_id(proto),
                "node_name": self._extract_node_name(proto),
                "app_name": proto.app_name or "",
            })

        return vectors, metadata_list

    def build_from_packs(self, packs: List[LearningPack]) -> int:
        """
        Build the index from a list of Learning Packs.

        Replaces any existing content of the index.

        Args:
            packs: LearningPacks to index.

        Returns:
            Number of vectors added to the index.
        """
        self._rebuild_index()

        vectors: List[np.ndarray] = []
        metadata_list: List[Dict[str, Any]] = []
        for pack in packs:
            pack_vectors, pack_meta = self._collect_entries(pack)
            vectors.extend(pack_vectors)
            metadata_list.extend(pack_meta)

        if not vectors:
            logger.info("Aucun vecteur valide trouvé dans les packs.")
            return 0

        # Stack and L2-normalise so inner product == cosine similarity.
        matrix = np.array(vectors, dtype=np.float32)
        faiss.normalize_L2(matrix)

        self.index.add(matrix)
        self._metadata = metadata_list

        logger.info(
            "Index global construit : %d vecteurs depuis %d packs",
            len(vectors), len(packs),
        )
        return len(vectors)

    def add_pack(self, pack: LearningPack) -> int:
        """
        Add one pack's prototypes to the existing index (incremental).

        Args:
            pack: LearningPack to add.

        Returns:
            Number of vectors added.
        """
        vectors, metadata_list = self._collect_entries(pack)
        if not vectors:
            return 0

        # Same normalisation as build_from_packs().
        matrix = np.array(vectors, dtype=np.float32)
        faiss.normalize_L2(matrix)

        self.index.add(matrix)
        self._metadata.extend(metadata_list)

        logger.info(
            "Pack ajouté à l'index global : +%d vecteurs (total=%d)",
            len(vectors), self.total_vectors,
        )
        return len(vectors)

    # ------------------------------------------------------------------
    # Search
    # ------------------------------------------------------------------

    def search(
        self, query_vector: np.ndarray, k: int = 5
    ) -> List[GlobalSearchResult]:
        """
        Find the k most similar screens in the global index.

        Args:
            query_vector: Query vector (same dimensionality as the index).
            k: Number of results to return.

        Returns:
            GlobalSearchResults sorted by decreasing similarity.
        """
        if self.total_vectors == 0:
            return []

        # Normalise the query so the inner product is a cosine similarity.
        q = np.array(query_vector, dtype=np.float32).reshape(1, -1)
        faiss.normalize_L2(q)

        k = min(k, self.total_vectors)
        distances, indices = self.index.search(q, k)

        results = []
        for dist, idx in zip(distances[0], indices[0]):
            # FAISS pads missing results with -1 indices; skip them, and
            # guard against any metadata/index desynchronisation.
            if idx < 0 or idx >= len(self._metadata):
                continue

            meta = self._metadata[int(idx)]
            results.append(GlobalSearchResult(
                prototype_id=meta["prototype_id"],
                similarity=float(dist),
                pack_source_hash=meta["pack_source_hash"],
                workflow_skeleton_id=meta["workflow_skeleton_id"],
                node_name=meta["node_name"],
                app_name=meta["app_name"],
                metadata=meta,
            ))

        return results

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save(self, path: Path) -> None:
        """
        Save the index and its metadata.

        Creates two files:
            - ``{path}.faiss`` — binary FAISS index
            - ``{path}.meta.json`` — JSON metadata

        Note: ``with_suffix`` replaces any existing suffix on ``path``,
        so pass a base path without extension.

        Args:
            path: Base path (without extension).
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        index_path = path.with_suffix(".faiss")
        meta_path = path.with_suffix(".meta.json")

        faiss.write_index(self.index, str(index_path))

        meta_data = {
            "dimensions": self.dimensions,
            "total_vectors": self.total_vectors,
            "entries": self._metadata,
        }
        with open(meta_path, "w", encoding="utf-8") as fh:
            json.dump(meta_data, fh, indent=2, ensure_ascii=False)

        logger.info(
            "Index global sauvegardé : %s (%d vecteurs)",
            index_path, self.total_vectors,
        )

    @classmethod
    def load(cls, path: Path) -> "GlobalFAISSIndex":
        """
        Load an index from disk.

        Args:
            path: Base path (without extension), as passed to save().

        Returns:
            GlobalFAISSIndex loaded and ready to use.

        Raises:
            ImportError: when FAISS is not installed.
        """
        if not FAISS_AVAILABLE:
            raise ImportError("FAISS requis pour charger l'index global.")

        path = Path(path)
        index_path = path.with_suffix(".faiss")
        meta_path = path.with_suffix(".meta.json")

        with open(meta_path, "r", encoding="utf-8") as fh:
            meta_data = json.load(fh)

        # Bypass __init__ (which would build an empty index) and restore
        # the state directly from the saved files.
        dimensions = meta_data.get("dimensions", DEFAULT_DIMENSIONS)
        instance = cls.__new__(cls)
        instance.dimensions = dimensions
        instance.index = faiss.read_index(str(index_path))
        instance._metadata = meta_data.get("entries", [])

        logger.info(
            "Index global chargé : %s (%d vecteurs, %dd)",
            index_path, instance.total_vectors, dimensions,
        )
        return instance

    def get_stats(self) -> Dict[str, Any]:
        """Aggregate statistics of the global index (sources, apps, size)."""
        source_hashes = set()
        app_names = set()
        for meta in self._metadata:
            source_hashes.add(meta.get("pack_source_hash", ""))
            app_name = meta.get("app_name", "")
            if app_name:
                app_names.add(app_name)

        return {
            "dimensions": self.dimensions,
            "total_vectors": self.total_vectors,
            "unique_sources": len(source_hashes),
            "unique_apps": sorted(app_names),
        }

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _proto_to_vector(self, proto: ScreenPrototype) -> Optional[np.ndarray]:
        """Convert a ScreenPrototype to a numpy vector, or None if absent/invalid."""
        if proto.vector is None or len(proto.vector) == 0:
            return None

        vec = np.array(proto.vector, dtype=np.float32)
        if vec.shape[0] != self.dimensions:
            # Wrong dimensionality would corrupt the index: skip and warn.
            logger.warning(
                "Prototype %s : dimensions incorrectes (%d != %d), ignoré",
                proto.prototype_id, vec.shape[0], self.dimensions,
            )
            return None
        return vec

    @staticmethod
    def _extract_skeleton_id(proto: ScreenPrototype) -> str:
        """Extract the workflow_id from prototype_id (format: workflow_id__node_id)."""
        # str.split always returns at least one element, so indexing is safe.
        return proto.prototype_id.split("__", 1)[0]

    @staticmethod
    def _extract_node_name(proto: ScreenPrototype) -> str:
        """Extract the node_id from prototype_id; fall back to the full id."""
        parts = proto.prototype_id.split("__", 1)
        return parts[1] if len(parts) >= 2 else proto.prototype_id
|
||||
961
core/federation/learning_pack.py
Normal file
961
core/federation/learning_pack.py
Normal file
@@ -0,0 +1,961 @@
|
||||
"""
|
||||
Learning Pack — Format d'export anonymisé des apprentissages.
|
||||
|
||||
Un LearningPack contient les connaissances extraites des workflows
|
||||
d'un client, sans aucune donnée personnelle ou sensible.
|
||||
|
||||
Ce qu'on exporte (anonymisé) :
|
||||
- Embeddings CLIP des prototypes d'écran (vecteurs 512d — pas réversibles)
|
||||
- ScreenTemplates (contraintes UI : titres fenêtres, rôles éléments)
|
||||
- Structure des workflows (nodes/edges, actions, contraintes)
|
||||
- Patterns d'erreur rencontrés
|
||||
- Signatures d'applications (app_name, version)
|
||||
|
||||
Ce qu'on N'exporte PAS :
|
||||
- Screenshots bruts
|
||||
- Textes OCR bruts (données patient potentielles)
|
||||
- Événements clavier bruts (mots de passe potentiels)
|
||||
- machine_id, hostname, IP (identification du client)
|
||||
|
||||
Structure JSON :
|
||||
{
|
||||
"version": "1.0",
|
||||
"created_at": "2026-03-19T...",
|
||||
"source_hash": "abc123...", # SHA-256 anonyme du client
|
||||
"pack_id": "lp_xxx",
|
||||
"stats": { ... },
|
||||
"app_signatures": [ ... ],
|
||||
"screen_prototypes": [ ... ],
|
||||
"workflow_skeletons": [ ... ],
|
||||
"ui_patterns": [ ... ],
|
||||
"error_patterns": [ ... ],
|
||||
"edge_statistics": [ ... ],
|
||||
}
|
||||
|
||||
Auteur : Dom, Claude — 19 mars 2026
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Learning Pack format version (bump on incompatible schema changes)
LEARNING_PACK_VERSION = "1.0"

# Cosine-similarity threshold above which two screen prototypes are
# considered the same screen during merge deduplication
DEDUP_COSINE_THRESHOLD = 0.95

# Maximum text length before a string is treated as sensitive OCR data
MAX_SAFE_TEXT_LENGTH = 120

# Metadata keys excluded from export (sensitive / identifying data)
_SENSITIVE_METADATA_KEYS = frozenset({
    "screenshot_path", "screenshot", "ocr_text", "ocr_raw",
    "raw_text", "keyboard_events", "key_events", "input_text",
    "machine_id", "hostname", "ip_address", "user", "username",
    "patient", "patient_id", "dossier", "nip", "ipp",
})
||||
|
||||
# ============================================================================
|
||||
# Structures de données du Learning Pack
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
class AppSignature:
    """One observed application: name, version, and known window-title patterns."""
    app_name: str
    version: Optional[str] = None
    window_title_patterns: List[str] = field(default_factory=list)
    observation_count: int = 1

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict (keys in field order)."""
        keys = ("app_name", "version", "window_title_patterns", "observation_count")
        return {key: getattr(self, key) for key in keys}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AppSignature":
        """Rebuild from a dict produced by ``to_dict``; only app_name is required."""
        return cls(
            data["app_name"],
            data.get("version"),
            data.get("window_title_patterns", []),
            data.get("observation_count", 1),
        )
@dataclass
class ScreenPrototype:
    """Anonymized screen prototype: embedding vector plus UI constraints."""
    prototype_id: str
    vector: Optional[List[float]] = None  # 512-d vector serialized as a plain list
    provider: str = "openclip_ViT-B-32"
    app_name: Optional[str] = None
    window_constraints: Optional[Dict[str, Any]] = None
    text_constraints: Optional[Dict[str, Any]] = None
    ui_constraints: Optional[Dict[str, Any]] = None
    sample_count: int = 1
    source_hashes: List[str] = field(default_factory=list)  # packs of origin

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict (keys in field order)."""
        keys = (
            "prototype_id", "vector", "provider", "app_name",
            "window_constraints", "text_constraints", "ui_constraints",
            "sample_count", "source_hashes",
        )
        return {key: getattr(self, key) for key in keys}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ScreenPrototype":
        """Rebuild from a dict produced by ``to_dict``; only prototype_id is required."""
        proto = cls(prototype_id=data["prototype_id"])
        proto.vector = data.get("vector")
        proto.provider = data.get("provider", "openclip_ViT-B-32")
        proto.app_name = data.get("app_name")
        proto.window_constraints = data.get("window_constraints")
        proto.text_constraints = data.get("text_constraints")
        proto.ui_constraints = data.get("ui_constraints")
        proto.sample_count = data.get("sample_count", 1)
        proto.source_hashes = data.get("source_hashes", [])
        return proto
@dataclass
class WorkflowSkeleton:
    """Anonymized workflow structure (names and topology, no sensitive data)."""
    skeleton_id: str
    name: str
    description: str
    learning_state: str
    node_names: List[str]
    edge_summaries: List[Dict[str, Any]]  # from_node, to_node, action_type, target_role
    entry_nodes: List[str]
    end_nodes: List[str]
    node_count: int = 0
    edge_count: int = 0
    app_names: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict (keys in field order)."""
        result: Dict[str, Any] = {}
        for key in ("skeleton_id", "name", "description", "learning_state",
                    "node_names", "edge_summaries", "entry_nodes", "end_nodes",
                    "node_count", "edge_count", "app_names"):
            result[key] = getattr(self, key)
        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "WorkflowSkeleton":
        """Rebuild from a dict; skeleton_id and name are required, the rest defaults."""
        return cls(
            data["skeleton_id"],
            data["name"],
            data.get("description", ""),
            data.get("learning_state", "OBSERVATION"),
            data.get("node_names", []),
            data.get("edge_summaries", []),
            data.get("entry_nodes", []),
            data.get("end_nodes", []),
            node_count=data.get("node_count", 0),
            edge_count=data.get("edge_count", 0),
            app_names=data.get("app_names", []),
        )
@dataclass
class UIPattern:
    """Universal UI pattern (e.g. a Save button or a File menu)."""
    pattern_id: str
    role: str  # button, textfield, menu, etc.
    context_description: str  # description of the surrounding context
    window_title_patterns: List[str] = field(default_factory=list)
    observation_count: int = 1
    cross_client_count: int = 1  # number of distinct clients that observed it
    confidence: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict (keys in field order)."""
        return {key: getattr(self, key) for key in (
            "pattern_id", "role", "context_description",
            "window_title_patterns", "observation_count",
            "cross_client_count", "confidence",
        )}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "UIPattern":
        """Rebuild from a dict produced by ``to_dict``; only pattern_id is required."""
        pattern = cls(
            pattern_id=data["pattern_id"],
            role=data.get("role", "unknown"),
            context_description=data.get("context_description", ""),
        )
        pattern.window_title_patterns = data.get("window_title_patterns", [])
        pattern.observation_count = data.get("observation_count", 1)
        pattern.cross_client_count = data.get("cross_client_count", 1)
        pattern.confidence = data.get("confidence", 0.0)
        return pattern
@dataclass
class ErrorPattern:
    """Observed error pattern (error text, context, frequency)."""
    pattern_id: str
    error_text: str
    kind: str = "text_present"  # kind of the source PostConditionCheck
    app_name: Optional[str] = None
    observation_count: int = 1
    cross_client_count: int = 1

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict (keys in field order)."""
        return {key: getattr(self, key) for key in (
            "pattern_id", "error_text", "kind", "app_name",
            "observation_count", "cross_client_count",
        )}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ErrorPattern":
        """Rebuild from a dict; pattern_id and error_text are required."""
        return cls(
            data["pattern_id"],
            data["error_text"],
            data.get("kind", "text_present"),
            data.get("app_name"),
            data.get("observation_count", 1),
            data.get("cross_client_count", 1),
        )
@dataclass
class EdgeStatistic:
    """Anonymized statistics for one screen-to-screen transition."""
    from_node_name: str
    to_node_name: str
    action_type: str
    target_role: Optional[str] = None
    execution_count: int = 0
    success_rate: float = 0.0
    avg_execution_time_ms: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict (keys in field order)."""
        return {key: getattr(self, key) for key in (
            "from_node_name", "to_node_name", "action_type", "target_role",
            "execution_count", "success_rate", "avg_execution_time_ms",
        )}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "EdgeStatistic":
        """Rebuild from a dict; node names and action_type are required."""
        return cls(
            data["from_node_name"],
            data["to_node_name"],
            data["action_type"],
            target_role=data.get("target_role"),
            execution_count=data.get("execution_count", 0),
            success_rate=data.get("success_rate", 0.0),
            avg_execution_time_ms=data.get("avg_execution_time_ms", 0.0),
        )
# ============================================================================
|
||||
# LearningPack — conteneur principal
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
class LearningPack:
    """
    Anonymized learning pack, ready to be exchanged between sites.

    Serializable to JSON (``to_dict`` / ``from_dict``) and persistable
    to disk (``save`` / ``load``).
    """

    version: str = LEARNING_PACK_VERSION
    created_at: str = ""
    source_hash: str = ""
    pack_id: str = ""
    stats: Dict[str, Any] = field(default_factory=dict)
    app_signatures: List[AppSignature] = field(default_factory=list)
    screen_prototypes: List[ScreenPrototype] = field(default_factory=list)
    workflow_skeletons: List[WorkflowSkeleton] = field(default_factory=list)
    ui_patterns: List[UIPattern] = field(default_factory=list)
    error_patterns: List[ErrorPattern] = field(default_factory=list)
    edge_statistics: List[EdgeStatistic] = field(default_factory=list)

    # --- Serialization -------------------------------------------------------

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a JSON-serializable dictionary."""
        def dump_all(items):
            # Each collection element knows how to serialize itself.
            return [item.to_dict() for item in items]

        return {
            "version": self.version,
            "created_at": self.created_at,
            "source_hash": self.source_hash,
            "pack_id": self.pack_id,
            "stats": self.stats,
            "app_signatures": dump_all(self.app_signatures),
            "screen_prototypes": dump_all(self.screen_prototypes),
            "workflow_skeletons": dump_all(self.workflow_skeletons),
            "ui_patterns": dump_all(self.ui_patterns),
            "error_patterns": dump_all(self.error_patterns),
            "edge_statistics": dump_all(self.edge_statistics),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LearningPack":
        """Rebuild a pack from a dictionary; missing keys fall back to defaults."""
        def load_all(key, factory):
            return [factory(item) for item in data.get(key, [])]

        return cls(
            version=data.get("version", LEARNING_PACK_VERSION),
            created_at=data.get("created_at", ""),
            source_hash=data.get("source_hash", ""),
            pack_id=data.get("pack_id", ""),
            stats=data.get("stats", {}),
            app_signatures=load_all("app_signatures", AppSignature.from_dict),
            screen_prototypes=load_all("screen_prototypes", ScreenPrototype.from_dict),
            workflow_skeletons=load_all("workflow_skeletons", WorkflowSkeleton.from_dict),
            ui_patterns=load_all("ui_patterns", UIPattern.from_dict),
            error_patterns=load_all("error_patterns", ErrorPattern.from_dict),
            edge_statistics=load_all("edge_statistics", EdgeStatistic.from_dict),
        )

    # --- File persistence -----------------------------------------------------

    def save(self, path: Path) -> None:
        """Write the pack as indented UTF-8 JSON, creating parent dirs as needed."""
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, "w", encoding="utf-8") as fh:
            json.dump(self.to_dict(), fh, indent=2, ensure_ascii=False)
        logger.info("Learning pack sauvegardé : %s (%d prototypes, %d skeletons)",
                    target, len(self.screen_prototypes), len(self.workflow_skeletons))

    @classmethod
    def load(cls, path: Path) -> "LearningPack":
        """Load a pack from a JSON file written by ``save``."""
        source = Path(path)
        with open(source, "r", encoding="utf-8") as fh:
            pack = cls.from_dict(json.load(fh))
        logger.info("Learning pack chargé : %s (v%s, %d prototypes)",
                    source, pack.version, len(pack.screen_prototypes))
        return pack
||||
# ============================================================================
|
||||
# Fonctions utilitaires d'anonymisation
|
||||
# ============================================================================
|
||||
|
||||
def _hash_client_id(client_id: str) -> str:
|
||||
"""Hacher un identifiant client via SHA-256 (irréversible)."""
|
||||
return hashlib.sha256(client_id.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _sanitize_text(text: str) -> Optional[str]:
    """
    Clean a text value before export.

    Returns None when the text is empty, too long (likely sensitive OCR
    data), or contains markers that look like patient/record identifiers.
    """
    if not text or len(text) > MAX_SAFE_TEXT_LENGTH:
        return None
    lowered = text.lower()
    # Reject anything that looks like a patient identifier.
    markers = ("patient", "nip:", "ipp:", "dossier n", "numéro de")
    if any(marker in lowered for marker in markers):
        return None
    return text
||||
|
||||
def _clean_metadata(metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Strip sensitive keys from a metadata dict (case-insensitive match)."""
    cleaned: Dict[str, Any] = {}
    for key, value in metadata.items():
        if key.lower() not in _SENSITIVE_METADATA_KEYS:
            cleaned[key] = value
    return cleaned
||||
|
||||
def _extract_prototype_vector(node) -> Optional[List[float]]:
|
||||
"""
|
||||
Extraire le vecteur prototype d'un WorkflowNode.
|
||||
|
||||
Cherche dans ``node.metadata["_prototype_vector"]`` (numpy array ou liste)
|
||||
puis tente de charger depuis le fichier .npy référencé par le template.
|
||||
"""
|
||||
# 1. Vecteur directement stocké dans les métadonnées
|
||||
vec = node.metadata.get("_prototype_vector")
|
||||
if vec is not None:
|
||||
if isinstance(vec, np.ndarray):
|
||||
return vec.tolist()
|
||||
if isinstance(vec, list):
|
||||
return vec
|
||||
|
||||
# 2. Fichier .npy référencé par le template embedding
|
||||
vector_id = node.template.embedding.vector_id
|
||||
if vector_id:
|
||||
npy_path = Path(vector_id)
|
||||
if npy_path.exists() and npy_path.suffix == ".npy":
|
||||
try:
|
||||
arr = np.load(str(npy_path))
|
||||
return arr.tolist()
|
||||
except Exception as exc:
|
||||
logger.debug("Impossible de charger %s : %s", npy_path, exc)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# LearningPackExporter
|
||||
# ============================================================================
|
||||
|
||||
class LearningPackExporter:
    """
    Produces an anonymized LearningPack from a list of Workflows.

    Usage:
        >>> from core.models.workflow_graph import Workflow
        >>> exporter = LearningPackExporter()
        >>> pack = exporter.export(workflows, client_id="CHU-Lyon-001")
        >>> pack.save(Path("export/chu_lyon.json"))
    """

    def export(self, workflows, client_id: str) -> LearningPack:
        """
        Export a client's workflows as an anonymized LearningPack.

        Args:
            workflows: List of ``Workflow`` objects (core.models.workflow_graph).
            client_id: Clear-text client identifier (hashed before export).

        Returns:
            LearningPack ready to be saved or sent to the central server.
        """
        # The clear-text id never leaves this method: only its SHA-256 is exported.
        source_hash = _hash_client_id(client_id)
        pack_id = f"lp_{uuid.uuid4().hex[:12]}"

        app_sigs: Dict[str, AppSignature] = {}
        prototypes: List[ScreenPrototype] = []
        skeletons: List[WorkflowSkeleton] = []
        ui_patterns_map: Dict[str, UIPattern] = {}
        error_patterns_map: Dict[str, ErrorPattern] = {}
        edge_stats: List[EdgeStatistic] = []

        total_nodes = 0
        total_edges = 0

        for wf in workflows:
            # --- Skeleton ---
            skeleton = self._extract_skeleton(wf)
            skeletons.append(skeleton)

            total_nodes += len(wf.nodes)
            total_edges += len(wf.edges)

            # --- Nodes: prototypes + app signatures + UI patterns ---
            for node in wf.nodes:
                proto = self._extract_prototype(node, source_hash, wf.workflow_id)
                if proto is not None:
                    prototypes.append(proto)

                self._collect_app_signature(node, app_sigs)
                self._collect_ui_patterns(node, ui_patterns_map)

            # --- Edges: actions + error patterns + stats ---
            for edge in wf.edges:
                self._collect_error_patterns(edge, error_patterns_map, wf)
                stat = self._extract_edge_statistic(edge, wf)
                if stat is not None:
                    edge_stats.append(stat)

        apps_seen = sorted(app_sigs.keys())

        pack = LearningPack(
            version=LEARNING_PACK_VERSION,
            # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
            # datetime.now(timezone.utc) is the modern equivalent.
            created_at=datetime.utcnow().isoformat(),
            source_hash=source_hash,
            pack_id=pack_id,
            stats={
                "workflows_count": len(workflows),
                "total_nodes": total_nodes,
                "total_edges": total_edges,
                "apps_seen": apps_seen,
                "prototypes_exported": len(prototypes),
            },
            app_signatures=list(app_sigs.values()),
            screen_prototypes=prototypes,
            workflow_skeletons=skeletons,
            ui_patterns=list(ui_patterns_map.values()),
            error_patterns=list(error_patterns_map.values()),
            edge_statistics=edge_stats,
        )

        logger.info(
            "Learning pack exporté : %s — %d workflows, %d prototypes, %d error patterns",
            pack_id, len(workflows), len(prototypes), len(error_patterns_map),
        )
        return pack

    # ------------------------------------------------------------------
    # Unit extraction helpers
    # ------------------------------------------------------------------

    def _extract_skeleton(self, wf) -> WorkflowSkeleton:
        """Extract the anonymized skeleton of a workflow (names and topology only)."""
        node_names = [n.name for n in wf.nodes]
        app_names = set()

        edge_summaries = []
        for edge in wf.edges:
            summary: Dict[str, Any] = {
                "from_node": edge.from_node,
                "to_node": edge.to_node,
                "action_type": edge.action.type,
                "target_role": edge.action.target.by_role,
            }
            edge_summaries.append(summary)

        for node in wf.nodes:
            proc = node.template.window.process_name
            if proc:
                app_names.add(proc)

        return WorkflowSkeleton(
            skeleton_id=wf.workflow_id,
            name=wf.name,
            description=wf.description,
            learning_state=wf.learning_state,
            node_names=node_names,
            edge_summaries=edge_summaries,
            entry_nodes=wf.entry_nodes,
            end_nodes=wf.end_nodes,
            node_count=len(wf.nodes),
            edge_count=len(wf.edges),
            app_names=sorted(app_names),
        )

    def _extract_prototype(
        self, node, source_hash: str, workflow_id: str
    ) -> Optional[ScreenPrototype]:
        """Extract an anonymized ScreenPrototype from a WorkflowNode."""
        vector = _extract_prototype_vector(node)
        # Export even without a vector: the UI constraints are valuable on their own
        app_name = node.template.window.process_name

        # Build the sanitized constraints
        window_constraints = node.template.window.to_dict()
        text_constraints = self._sanitize_text_constraints(node.template.text.to_dict())
        ui_constraints = node.template.ui.to_dict()

        return ScreenPrototype(
            prototype_id=f"{workflow_id}__{node.node_id}",
            vector=vector,
            provider=node.template.embedding.provider,
            app_name=app_name,
            window_constraints=window_constraints,
            text_constraints=text_constraints,
            ui_constraints=ui_constraints,
            sample_count=node.template.embedding.sample_count,
            source_hashes=[source_hash],
        )

    def _sanitize_text_constraints(self, text_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Drop text constraints that are too long or look sensitive (see _sanitize_text)."""
        required = [
            t for t in text_dict.get("required_texts", [])
            if _sanitize_text(t) is not None
        ]
        forbidden = [
            t for t in text_dict.get("forbidden_texts", [])
            if _sanitize_text(t) is not None
        ]
        return {"required_texts": required, "forbidden_texts": forbidden}

    def _collect_app_signature(
        self, node, app_sigs: Dict[str, AppSignature]
    ) -> None:
        """Record a node's application signature into *app_sigs* (mutated in place)."""
        proc = node.template.window.process_name
        if not proc:
            return

        if proc in app_sigs:
            app_sigs[proc].observation_count += 1
        else:
            title_pattern = node.template.window.title_pattern
            patterns = [title_pattern] if title_pattern else []
            app_sigs[proc] = AppSignature(
                app_name=proc,
                window_title_patterns=patterns,
            )

        # Add the title pattern if it is new.
        # NOTE(review): for a freshly created signature this re-checks the pattern
        # that was just inserted above — harmless, but redundant.
        title_pattern = node.template.window.title_pattern
        if title_pattern and title_pattern not in app_sigs[proc].window_title_patterns:
            app_sigs[proc].window_title_patterns.append(title_pattern)

    def _collect_ui_patterns(
        self, node, patterns: Dict[str, UIPattern]
    ) -> None:
        """Record UI patterns implied by a node's required roles (mutated in place)."""
        for role in node.template.ui.required_roles:
            key = role
            if key in patterns:
                # NOTE(review): unlike _collect_app_signature, window title patterns
                # are not merged for already-known roles — confirm this is intended.
                patterns[key].observation_count += 1
            else:
                title_pattern = node.template.window.title_pattern
                title_patterns = [title_pattern] if title_pattern else []
                patterns[key] = UIPattern(
                    pattern_id=f"uip_{role}",
                    role=role,
                    context_description=f"Rôle UI requis : {role}",
                    window_title_patterns=title_patterns,
                )

    def _collect_error_patterns(
        self, edge, patterns: Dict[str, ErrorPattern], wf
    ) -> None:
        """Extract error patterns from the edge's PostConditions.fail_fast checks."""
        for check in edge.post_conditions.fail_fast:
            # Only keep short, non-sensitive error texts (see _sanitize_text).
            if check.value and _sanitize_text(check.value) is not None:
                key = check.value
                if key in patterns:
                    patterns[key].observation_count += 1
                else:
                    # Find the app_name of the source node
                    source_node = wf.get_node(edge.from_node)
                    app_name = None
                    if source_node:
                        app_name = source_node.template.window.process_name

                    patterns[key] = ErrorPattern(
                        pattern_id=f"err_{hashlib.md5(key.encode()).hexdigest()[:8]}",
                        error_text=check.value,
                        kind=check.kind,
                        app_name=app_name,
                    )

    def _extract_edge_statistic(self, edge, wf) -> Optional[EdgeStatistic]:
        """Extract anonymized execution statistics for an edge.

        NOTE(review): despite the Optional return type, this implementation
        always returns an EdgeStatistic.
        """
        source_node = wf.get_node(edge.from_node)
        target_node = wf.get_node(edge.to_node)

        # Fall back to the raw node ids when a node cannot be resolved.
        from_name = source_node.name if source_node else edge.from_node
        to_name = target_node.name if target_node else edge.to_node

        return EdgeStatistic(
            from_node_name=from_name,
            to_node_name=to_name,
            action_type=edge.action.type,
            target_role=edge.action.target.by_role,
            execution_count=edge.stats.execution_count,
            success_rate=edge.stats.success_rate,
            avg_execution_time_ms=edge.stats.avg_execution_time_ms,
        )
|
||||
|
||||
# ============================================================================
|
||||
# LearningPackMerger
|
||||
# ============================================================================
|
||||
|
||||
class LearningPackMerger:
|
||||
"""
|
||||
Fusionne plusieurs LearningPacks en un seul pack consolidé.
|
||||
|
||||
La fusion :
|
||||
- Déduplique les prototypes similaires (cosine > 0.95 = même écran)
|
||||
- Fusionne les signatures d'application (union)
|
||||
- Fusionne les patterns d'erreur (union, comptage cross-clients)
|
||||
- Calcule les occurrences cross-clients (haute confiance si vu par N clients)
|
||||
|
||||
Usage :
|
||||
>>> merger = LearningPackMerger()
|
||||
>>> merged = merger.merge([pack_a, pack_b, pack_c])
|
||||
>>> merged.save(Path("global/merged_pack.json"))
|
||||
"""
|
||||
|
||||
def __init__(self, dedup_threshold: float = DEDUP_COSINE_THRESHOLD):
    """Create a merger; *dedup_threshold* is the cosine cutoff for prototype dedup."""
    self.dedup_threshold = dedup_threshold
||||
def merge(self, packs: List[LearningPack]) -> LearningPack:
    """
    Merge several packs into a single consolidated global pack.

    Args:
        packs: LearningPacks to merge.

    Returns:
        Consolidated LearningPack with deduplication and cross-client counts.
    """
    if not packs:
        # Empty input: return an empty pack with a fresh merged id.
        return LearningPack(
            created_at=datetime.utcnow().isoformat(),
            pack_id=f"lp_merged_{uuid.uuid4().hex[:8]}",
        )

    if len(packs) == 1:
        # Single pack: return it (deep-copied via a dict round-trip) with a new pack_id
        merged = LearningPack.from_dict(packs[0].to_dict())
        merged.pack_id = f"lp_merged_{uuid.uuid4().hex[:8]}"
        return merged

    merged_id = f"lp_merged_{uuid.uuid4().hex[:8]}"
    # NOTE(review): built from a set, so this list's order is nondeterministic;
    # only the source_hash field below is sorted, the stats entry is not.
    source_hashes = list({p.source_hash for p in packs if p.source_hash})

    # Merge each category
    app_sigs = self._merge_app_signatures(packs)
    prototypes = self._merge_prototypes(packs)
    skeletons = self._merge_skeletons(packs)
    ui_patterns = self._merge_ui_patterns(packs)
    error_patterns = self._merge_error_patterns(packs)
    edge_stats = self._merge_edge_statistics(packs)

    # Compute the global stats
    total_wf = sum(p.stats.get("workflows_count", 0) for p in packs)
    total_nodes = sum(p.stats.get("total_nodes", 0) for p in packs)
    total_edges = sum(p.stats.get("total_edges", 0) for p in packs)
    all_apps = set()
    for p in packs:
        all_apps.update(p.stats.get("apps_seen", []))

    return LearningPack(
        version=LEARNING_PACK_VERSION,
        created_at=datetime.utcnow().isoformat(),
        source_hash=",".join(sorted(source_hashes)),
        pack_id=merged_id,
        stats={
            "workflows_count": total_wf,
            "total_nodes": total_nodes,
            "total_edges": total_edges,
            "apps_seen": sorted(all_apps),
            "prototypes_exported": len(prototypes),
            "source_packs_count": len(packs),
            "source_hashes": source_hashes,
        },
        app_signatures=app_sigs,
        screen_prototypes=prototypes,
        workflow_skeletons=skeletons,
        ui_patterns=ui_patterns,
        error_patterns=error_patterns,
        edge_statistics=edge_stats,
    )
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Fusion par catégorie
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _merge_app_signatures(self, packs: List[LearningPack]) -> List[AppSignature]:
    """Union of application signatures; observation counters are accumulated."""
    by_name: Dict[str, AppSignature] = {}
    for pack in packs:
        for sig in pack.app_signatures:
            known = by_name.get(sig.app_name)
            if known is None:
                # Deep-copy via a dict round-trip so source packs stay untouched.
                by_name[sig.app_name] = AppSignature.from_dict(sig.to_dict())
            else:
                known.observation_count += sig.observation_count
                for title_pattern in sig.window_title_patterns:
                    if title_pattern not in known.window_title_patterns:
                        known.window_title_patterns.append(title_pattern)
    return list(by_name.values())
|
||||
def _merge_prototypes(self, packs: List[LearningPack]) -> List[ScreenPrototype]:
    """
    Merge prototypes with cosine-similarity deduplication.

    Two prototypes with cosine > ``self.dedup_threshold`` are treated as
    the same screen. The one with the most samples is kept and the
    source_hashes of the group are merged.
    """
    all_protos: List[ScreenPrototype] = []
    for pack in packs:
        all_protos.extend(pack.screen_prototypes)

    if not all_protos:
        return []

    # Split prototypes into those with and without an embedding vector
    with_vec: List[Tuple[ScreenPrototype, np.ndarray]] = []
    without_vec: List[ScreenPrototype] = []

    for proto in all_protos:
        if proto.vector is not None and len(proto.vector) > 0:
            vec = np.array(proto.vector, dtype=np.float32)
            # L2-normalize so the dot product below is the cosine similarity
            norm = np.linalg.norm(vec)
            if norm > 0:
                vec = vec / norm
            with_vec.append((proto, vec))
        else:
            without_vec.append(proto)

    # Greedy O(n²) deduplication by cosine similarity
    merged: List[ScreenPrototype] = []
    used = [False] * len(with_vec)

    for i, (proto_i, vec_i) in enumerate(with_vec):
        if used[i]:
            continue
        used[i] = True

        # Collect the group of prototypes similar to proto_i
        group_sources = set(proto_i.source_hashes)
        best_sample_count = proto_i.sample_count
        best_proto = proto_i

        for j in range(i + 1, len(with_vec)):
            if used[j]:
                continue
            proto_j, vec_j = with_vec[j]
            cosine_sim = float(np.dot(vec_i, vec_j))

            if cosine_sim >= self.dedup_threshold:
                used[j] = True
                group_sources.update(proto_j.source_hashes)
                # Keep the representative with the most samples
                if proto_j.sample_count > best_sample_count:
                    best_sample_count = proto_j.sample_count
                    best_proto = proto_j

        # Build the consolidated prototype (deep copy via dict round-trip)
        consolidated = ScreenPrototype.from_dict(best_proto.to_dict())
        consolidated.source_hashes = sorted(group_sources)
        consolidated.sample_count = best_sample_count
        merged.append(consolidated)

    # Append vectorless prototypes (no deduplication possible for them)
    merged.extend(without_vec)

    logger.info(
        "Fusion prototypes : %d entrées → %d après déduplication (seuil=%.2f)",
        len(all_protos), len(merged), self.dedup_threshold,
    )
    return merged
|
||||
def _merge_skeletons(self, packs: List[LearningPack]) -> List[WorkflowSkeleton]:
    """Union of workflow skeletons, deduplicated by skeleton_id (first one wins)."""
    seen: Dict[str, WorkflowSkeleton] = {}
    for pack in packs:
        for skeleton in pack.workflow_skeletons:
            # setdefault keeps the first occurrence, matching first-wins semantics.
            seen.setdefault(skeleton.skeleton_id, skeleton)
    return list(seen.values())
|
||||
def _merge_ui_patterns(self, packs: List[LearningPack]) -> List[UIPattern]:
    """Merge UI patterns across packs with cross-client counting.

    Patterns are keyed by role. Observation counts are summed, window-title
    patterns are unioned (order-preserving), and each merged pattern's
    confidence is the fraction of distinct clients (pack source hashes)
    that observed it.

    Args:
        packs: Learning packs to merge; packs without a source_hash still
            contribute counts but not to cross-client statistics.

    Returns:
        The merged UIPattern list (copies; source packs are not mutated).
    """
    merged: Dict[str, UIPattern] = {}
    # Track which source_hashes have seen each pattern.
    pattern_sources: Dict[str, set] = {}

    for pack in packs:
        for pattern in pack.ui_patterns:
            key = pattern.role
            if key in merged:
                merged[key].observation_count += pattern.observation_count
                for pat in pattern.window_title_patterns:
                    if pat not in merged[key].window_title_patterns:
                        merged[key].window_title_patterns.append(pat)
            else:
                # Round-trip copy so the source pack's pattern is never mutated.
                merged[key] = UIPattern.from_dict(pattern.to_dict())
                pattern_sources[key] = set()
            if pack.source_hash:
                pattern_sources.setdefault(key, set()).add(pack.source_hash)

    # Total distinct contributing clients — loop-invariant, so computed once
    # instead of once per merged pattern as before.
    total_clients = len({p.source_hash for p in packs if p.source_hash})

    # Update cross_client_count and confidence on each merged pattern.
    for key, pattern in merged.items():
        sources = pattern_sources.get(key, set())
        pattern.cross_client_count = len(sources)
        # Confidence = proportion of clients that observed the pattern.
        pattern.confidence = (
            len(sources) / total_clients if total_clients > 0 else 0.0
        )

    return list(merged.values())
|
||||
|
||||
def _merge_error_patterns(self, packs: List[LearningPack]) -> List[ErrorPattern]:
    """Merge error patterns across packs with cross-client counting.

    Patterns are keyed by their error_text; observation counts are summed
    and cross_client_count reflects how many distinct source hashes saw
    each pattern.
    """
    combined: Dict[str, ErrorPattern] = {}
    sources_by_key: Dict[str, set] = {}

    for pack in packs:
        for pattern in pack.error_patterns:
            key = pattern.error_text
            if key not in combined:
                # Round-trip copy keeps the source pack untouched.
                combined[key] = ErrorPattern.from_dict(pattern.to_dict())
                sources_by_key[key] = set()
            else:
                combined[key].observation_count += pattern.observation_count
            if pack.source_hash:
                sources_by_key.setdefault(key, set()).add(pack.source_hash)

    for key, pattern in combined.items():
        pattern.cross_client_count = len(sources_by_key.get(key, set()))

    return list(combined.values())
|
||||
|
||||
def _merge_edge_statistics(
    self, packs: List[LearningPack]
) -> List[EdgeStatistic]:
    """Merge transition statistics across packs.

    Entries are keyed by (from_node, to_node, action_type); success rate and
    average execution time are combined as execution-count-weighted means.
    """
    combined: Dict[str, EdgeStatistic] = {}

    for pack in packs:
        for stat in pack.edge_statistics:
            key = f"{stat.from_node_name}→{stat.to_node_name}→{stat.action_type}"
            current = combined.get(key)
            if current is None:
                # Round-trip copy so the source pack is never mutated.
                combined[key] = EdgeStatistic.from_dict(stat.to_dict())
                continue
            total = current.execution_count + stat.execution_count
            if total > 0:
                # Weighted mean of success rate over all executions.
                current.success_rate = (
                    current.success_rate * current.execution_count
                    + stat.success_rate * stat.execution_count
                ) / total
                # Weighted mean of execution time; counts must be updated
                # only after both means are computed.
                current.avg_execution_time_ms = (
                    current.avg_execution_time_ms * current.execution_count
                    + stat.avg_execution_time_ms * stat.execution_count
                ) / total
                current.execution_count = total

    return list(combined.values())
|
||||
File diff suppressed because it is too large
Load Diff
@@ -135,27 +135,48 @@ class ContextLevel:
|
||||
|
||||
@dataclass
|
||||
class WindowContext:
|
||||
"""Contexte de fenêtre"""
|
||||
"""Contexte de fenêtre avec métadonnées d'environnement graphique"""
|
||||
app_name: str
|
||||
window_title: str
|
||||
screen_resolution: List[int]
|
||||
workspace: str = "main"
|
||||
|
||||
monitor_index: int = 0 # Index du moniteur (0 = principal)
|
||||
dpi_scale: int = 100 # Facteur DPI en % (100 = normal, 150 = haute résolution)
|
||||
window_bounds: Optional[List[int]] = None # [x, y, width, height] de la fenêtre
|
||||
monitors: Optional[List[Dict[str, int]]] = None # Liste des moniteurs [{width, height, x, y}]
|
||||
os_theme: str = "unknown" # "light", "dark", "unknown"
|
||||
os_language: str = "unknown" # Code langue (fr, en, de...)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
result = {
|
||||
"app_name": self.app_name,
|
||||
"window_title": self.window_title,
|
||||
"screen_resolution": self.screen_resolution,
|
||||
"workspace": self.workspace
|
||||
"workspace": self.workspace,
|
||||
"monitor_index": self.monitor_index,
|
||||
"dpi_scale": self.dpi_scale,
|
||||
"os_theme": self.os_theme,
|
||||
"os_language": self.os_language,
|
||||
}
|
||||
|
||||
if self.window_bounds is not None:
|
||||
result["window_bounds"] = self.window_bounds
|
||||
if self.monitors is not None:
|
||||
result["monitors"] = self.monitors
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'WindowContext':
|
||||
return cls(
|
||||
app_name=data["app_name"],
|
||||
window_title=data["window_title"],
|
||||
screen_resolution=data["screen_resolution"],
|
||||
workspace=data.get("workspace", "main")
|
||||
workspace=data.get("workspace", "main"),
|
||||
monitor_index=data.get("monitor_index", 0),
|
||||
dpi_scale=data.get("dpi_scale", 100),
|
||||
window_bounds=data.get("window_bounds"),
|
||||
monitors=data.get("monitors"),
|
||||
os_theme=data.get("os_theme", "unknown"),
|
||||
os_language=data.get("os_language", "unknown"),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -304,7 +304,7 @@ class ScreenTemplate:
|
||||
|
||||
# Vérifier contraintes de fenêtre
|
||||
if hasattr(screen_state, 'window'):
|
||||
window_title = getattr(screen_state.window, 'title', '')
|
||||
window_title = getattr(screen_state.window, 'window_title', '')
|
||||
process = getattr(screen_state.window, 'process', '')
|
||||
if not self.window.matches(window_title, process):
|
||||
return False, 0.0
|
||||
@@ -672,24 +672,94 @@ class Action:
|
||||
|
||||
@dataclass
|
||||
class EdgeConstraints:
|
||||
"""Contraintes pour l'exécution d'un edge"""
|
||||
"""Contraintes pour l'exécution d'un edge (pré-conditions avant l'action)"""
|
||||
pre_conditions: Dict[str, Any] = field(default_factory=dict)
|
||||
required_confidence: float = 0.8
|
||||
max_wait_time_ms: int = 5000
|
||||
|
||||
|
||||
# Contraintes enrichies extraites du node source
|
||||
window: Optional[WindowConstraint] = None
|
||||
text: Optional[TextConstraint] = None
|
||||
min_source_similarity: float = 0.80
|
||||
required_app_name: Optional[str] = None
|
||||
required_window_title: Optional[str] = None
|
||||
|
||||
def check_preconditions(
|
||||
self, window_title: str = "", app_name: str = "",
|
||||
detected_texts: Optional[List[str]] = None,
|
||||
source_similarity: float = 1.0,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
Vérifier si toutes les pré-conditions sont satisfaites.
|
||||
|
||||
Returns:
|
||||
(ok: bool, reason: str)
|
||||
"""
|
||||
# Vérifier similarité minimale avec le node source
|
||||
if source_similarity < self.min_source_similarity:
|
||||
return False, (
|
||||
f"Similarité source insuffisante: {source_similarity:.2f} "
|
||||
f"< {self.min_source_similarity:.2f}"
|
||||
)
|
||||
|
||||
# Vérifier titre de fenêtre
|
||||
if self.required_window_title and window_title:
|
||||
if self.required_window_title not in window_title:
|
||||
return False, (
|
||||
f"Titre de fenêtre incorrect: '{window_title}' "
|
||||
f"ne contient pas '{self.required_window_title}'"
|
||||
)
|
||||
|
||||
# Vérifier nom d'application
|
||||
if self.required_app_name and app_name:
|
||||
if self.required_app_name.lower() not in app_name.lower():
|
||||
return False, (
|
||||
f"Application incorrecte: '{app_name}' "
|
||||
f"ne correspond pas à '{self.required_app_name}'"
|
||||
)
|
||||
|
||||
# Vérifier contrainte de fenêtre (objet WindowConstraint)
|
||||
if self.window:
|
||||
if not self.window.matches(window_title, app_name):
|
||||
return False, f"Contrainte de fenêtre non satisfaite"
|
||||
|
||||
# Vérifier contrainte de texte (objet TextConstraint)
|
||||
if self.text and detected_texts is not None:
|
||||
if not self.text.matches(detected_texts):
|
||||
return False, f"Contrainte de texte non satisfaite"
|
||||
|
||||
return True, "OK"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"pre_conditions": self.pre_conditions,
|
||||
"required_confidence": self.required_confidence,
|
||||
"max_wait_time_ms": self.max_wait_time_ms
|
||||
"max_wait_time_ms": self.max_wait_time_ms,
|
||||
"window": self.window.to_dict() if self.window else None,
|
||||
"text": self.text.to_dict() if self.text else None,
|
||||
"min_source_similarity": self.min_source_similarity,
|
||||
"required_app_name": self.required_app_name,
|
||||
"required_window_title": self.required_window_title,
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'EdgeConstraints':
|
||||
window = None
|
||||
if data.get("window"):
|
||||
window = WindowConstraint.from_dict(data["window"])
|
||||
text = None
|
||||
if data.get("text"):
|
||||
text = TextConstraint.from_dict(data["text"])
|
||||
|
||||
return cls(
|
||||
pre_conditions=data.get("pre_conditions", {}),
|
||||
required_confidence=data.get("required_confidence", 0.8),
|
||||
max_wait_time_ms=data.get("max_wait_time_ms", 5000)
|
||||
max_wait_time_ms=data.get("max_wait_time_ms", 5000),
|
||||
window=window,
|
||||
text=text,
|
||||
min_source_similarity=data.get("min_source_similarity", 0.80),
|
||||
required_app_name=data.get("required_app_name"),
|
||||
required_window_title=data.get("required_window_title"),
|
||||
)
|
||||
|
||||
|
||||
@@ -709,23 +779,101 @@ class PostConditionCheck:
|
||||
@dataclass
|
||||
class PostConditions:
|
||||
"""Post-conditions attendues après exécution - Fiche #9"""
|
||||
# (garde tes champs existants si tu en as déjà, et ajoute ceux-ci)
|
||||
|
||||
|
||||
success_mode: str = "all" # "all" | "any"
|
||||
timeout_ms: int = 2500
|
||||
poll_ms: int = 200
|
||||
|
||||
|
||||
success: List[PostConditionCheck] = field(default_factory=list)
|
||||
fail_fast: List[PostConditionCheck] = field(default_factory=list)
|
||||
|
||||
|
||||
retries: int = 2 # nb de tentatives après échec post-conditions
|
||||
backoff_ms: int = 150 # 150, 300, 600...
|
||||
|
||||
|
||||
# Contraintes enrichies extraites du node cible
|
||||
expected_window_title: Optional[str] = None
|
||||
expected_app_name: Optional[str] = None
|
||||
min_target_similarity: float = 0.80
|
||||
|
||||
# Legacy fields (garde compatibilité)
|
||||
expected_node: Optional[str] = None # Node attendu après action
|
||||
window_change_expected: bool = False
|
||||
new_ui_elements_expected: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
def check_postconditions(
|
||||
self, window_title: str = "", app_name: str = "",
|
||||
detected_texts: Optional[List[str]] = None,
|
||||
target_similarity: float = 1.0,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
Vérifier si les post-conditions sont satisfaites après l'action.
|
||||
|
||||
Returns:
|
||||
(ok: bool, reason: str)
|
||||
"""
|
||||
# Vérifier similarité minimale avec le node cible
|
||||
if target_similarity < self.min_target_similarity:
|
||||
return False, (
|
||||
f"Similarité cible insuffisante: {target_similarity:.2f} "
|
||||
f"< {self.min_target_similarity:.2f}"
|
||||
)
|
||||
|
||||
# Vérifier titre de fenêtre attendu
|
||||
if self.expected_window_title and window_title:
|
||||
if self.expected_window_title not in window_title:
|
||||
return False, (
|
||||
f"Titre de fenêtre post-action incorrect: '{window_title}' "
|
||||
f"ne contient pas '{self.expected_window_title}'"
|
||||
)
|
||||
|
||||
# Vérifier application attendue
|
||||
if self.expected_app_name and app_name:
|
||||
if self.expected_app_name.lower() not in app_name.lower():
|
||||
return False, (
|
||||
f"Application post-action incorrecte: '{app_name}' "
|
||||
f"ne correspond pas à '{self.expected_app_name}'"
|
||||
)
|
||||
|
||||
# Vérifier les checks de succès (PostConditionCheck)
|
||||
if self.success:
|
||||
results = []
|
||||
for check in self.success:
|
||||
ok = self._evaluate_check(check, window_title, detected_texts or [])
|
||||
results.append(ok)
|
||||
|
||||
if self.success_mode == "all" and not all(results):
|
||||
return False, "Certaines post-conditions de succès non satisfaites"
|
||||
if self.success_mode == "any" and not any(results):
|
||||
return False, "Aucune post-condition de succès satisfaite"
|
||||
|
||||
# Vérifier fail_fast (si un pattern d'erreur est détecté, échec immédiat)
|
||||
if self.fail_fast and detected_texts:
|
||||
for check in self.fail_fast:
|
||||
if self._evaluate_check(check, window_title, detected_texts):
|
||||
return False, (
|
||||
f"Condition d'échec détectée: {check.kind}={check.value}"
|
||||
)
|
||||
|
||||
return True, "OK"
|
||||
|
||||
@staticmethod
|
||||
def _evaluate_check(
|
||||
check: PostConditionCheck,
|
||||
window_title: str,
|
||||
detected_texts: List[str],
|
||||
) -> bool:
|
||||
"""Évaluer un PostConditionCheck individuel."""
|
||||
texts_lower = [t.lower() for t in detected_texts]
|
||||
|
||||
if check.kind == "text_present":
|
||||
return any(check.value.lower() in t for t in texts_lower) if check.value else False
|
||||
elif check.kind == "text_absent":
|
||||
return not any(check.value.lower() in t for t in texts_lower) if check.value else True
|
||||
elif check.kind == "window_title_contains":
|
||||
return check.value.lower() in window_title.lower() if check.value else False
|
||||
# Autres types de checks non gérés ici → considérés comme OK
|
||||
return True
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"success_mode": self.success_mode,
|
||||
@@ -735,24 +883,28 @@ class PostConditions:
|
||||
"fail_fast": [{"kind": c.kind, "value": c.value, "target": c.target.to_dict() if c.target else None} for c in self.fail_fast],
|
||||
"retries": self.retries,
|
||||
"backoff_ms": self.backoff_ms,
|
||||
# Contraintes enrichies
|
||||
"expected_window_title": self.expected_window_title,
|
||||
"expected_app_name": self.expected_app_name,
|
||||
"min_target_similarity": self.min_target_similarity,
|
||||
# Legacy
|
||||
"expected_node": self.expected_node,
|
||||
"window_change_expected": self.window_change_expected,
|
||||
"new_ui_elements_expected": self.new_ui_elements_expected
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'PostConditions':
|
||||
success_checks = []
|
||||
for c in data.get("success", []):
|
||||
target = TargetSpec.from_dict(c["target"]) if c.get("target") else None
|
||||
success_checks.append(PostConditionCheck(kind=c["kind"], value=c.get("value"), target=target))
|
||||
|
||||
|
||||
fail_fast_checks = []
|
||||
for c in data.get("fail_fast", []):
|
||||
target = TargetSpec.from_dict(c["target"]) if c.get("target") else None
|
||||
fail_fast_checks.append(PostConditionCheck(kind=c["kind"], value=c.get("value"), target=target))
|
||||
|
||||
|
||||
return cls(
|
||||
success_mode=data.get("success_mode", "all"),
|
||||
timeout_ms=data.get("timeout_ms", 2500),
|
||||
@@ -761,6 +913,10 @@ class PostConditions:
|
||||
fail_fast=fail_fast_checks,
|
||||
retries=data.get("retries", 2),
|
||||
backoff_ms=data.get("backoff_ms", 150),
|
||||
# Contraintes enrichies
|
||||
expected_window_title=data.get("expected_window_title"),
|
||||
expected_app_name=data.get("expected_app_name"),
|
||||
min_target_similarity=data.get("min_target_similarity", 0.80),
|
||||
# Legacy
|
||||
expected_node=data.get("expected_node"),
|
||||
window_change_expected=data.get("window_change_expected", False),
|
||||
|
||||
@@ -321,6 +321,12 @@ class ScreenAnalyzer:
|
||||
window_title=window_info.get("title", "Unknown"),
|
||||
screen_resolution=window_info.get("screen_resolution", [1920, 1080]),
|
||||
workspace=window_info.get("workspace", "main"),
|
||||
monitor_index=window_info.get("monitor_index", 0),
|
||||
dpi_scale=window_info.get("dpi_scale", 100),
|
||||
window_bounds=window_info.get("window_bounds"),
|
||||
monitors=window_info.get("monitors"),
|
||||
os_theme=window_info.get("os_theme", "unknown"),
|
||||
os_language=window_info.get("os_language", "unknown"),
|
||||
)
|
||||
return WindowContext(
|
||||
app_name="unknown",
|
||||
|
||||
Reference in New Issue
Block a user