diff --git a/core/graph/graph_builder.py b/core/graph/graph_builder.py index 4bc606c3b..949484515 100644 --- a/core/graph/graph_builder.py +++ b/core/graph/graph_builder.py @@ -173,10 +173,14 @@ class GraphBuilder: clustering_eps: float = 0.08, clustering_min_samples: int = 2, enable_quality_validation: bool = True, + ui_detector: Optional[Any] = None, + screen_analyzer: Optional[Any] = None, + enable_ui_enrichment: bool = True, + element_proximity_max_px: float = 50.0, ): """ Initialiser le GraphBuilder. - + Args: embedding_builder: Builder pour State Embeddings (créé si None) faiss_manager: Manager FAISS pour indexation (optionnel) @@ -185,6 +189,17 @@ class GraphBuilder: clustering_eps: Epsilon pour DBSCAN (distance max entre points) clustering_min_samples: Nombre minimum d'échantillons pour un cluster enable_quality_validation: Activer la validation de qualité + ui_detector: UIDetector optionnel. Si fourni, sera utilisé par + l'analyzer lazy-initialisé. Sinon, fallback sur le singleton + partagé (`get_screen_analyzer()`). + screen_analyzer: Instance ScreenAnalyzer à utiliser directement. + Si None, lazy init via le singleton partagé C1. + enable_ui_enrichment: Active l'enrichissement visuel des + ScreenStates lors de `_create_screen_states` (OCR + UIDetector). + False = comportement historique (ui_elements=[], detected_text=[]). + element_proximity_max_px: Distance maximale (en pixels) entre un + clic et le bbox le plus proche pour qu'un UIElement soit + considéré comme cible. Au-delà, le clic reste sans ancre. 
""" self.embedding_builder = embedding_builder or StateEmbeddingBuilder() self.faiss_manager = faiss_manager @@ -193,15 +208,65 @@ class GraphBuilder: self.clustering_eps = clustering_eps self.clustering_min_samples = clustering_min_samples self.enable_quality_validation = enable_quality_validation - self._screen_analyzer = None # ScreenAnalyzer (lazy import) - + self.enable_ui_enrichment = enable_ui_enrichment + self.element_proximity_max_px = element_proximity_max_px + # UIDetector explicite (optionnel) — injecté dans l'analyzer lazy. + self._ui_detector = ui_detector + # Instance ScreenAnalyzer. Si fournie, on l'utilise telle quelle ; + # sinon, on bascule sur le singleton partagé (lazy init). + self._screen_analyzer = screen_analyzer + logger.info( f"GraphBuilder initialized: " f"min_repetitions={min_pattern_repetitions}, " f"eps={clustering_eps}, " f"min_samples={clustering_min_samples}, " - f"quality_validation={enable_quality_validation}" + f"quality_validation={enable_quality_validation}, " + f"ui_enrichment={enable_ui_enrichment}" ) + + # ------------------------------------------------------------------ + # Résolution paresseuse du ScreenAnalyzer (singleton C1 par défaut) + # ------------------------------------------------------------------ + + def _get_screen_analyzer(self): + """ + Retourner l'instance ScreenAnalyzer à utiliser. + + Priorité : + 1. Instance injectée via le constructeur (`screen_analyzer=…`). + 2. Singleton partagé `get_screen_analyzer()` (C1) — évite le double + chargement GPU quand ExecutionLoop et stream_processor tournent. + 3. En dernier recours (import circulaire, tests), création locale. 
+ """ + if self._screen_analyzer is not None: + return self._screen_analyzer + + try: + from core.pipeline import get_screen_analyzer + + self._screen_analyzer = get_screen_analyzer( + ui_detector=self._ui_detector, + ) + return self._screen_analyzer + except Exception as e: + logger.warning( + f"Impossible d'obtenir le ScreenAnalyzer singleton " + f"({e}); fallback sur une instance locale." + ) + try: + from core.pipeline.screen_analyzer import ScreenAnalyzer + + self._screen_analyzer = ScreenAnalyzer( + ui_detector=self._ui_detector, + ) + return self._screen_analyzer + except Exception as e2: + logger.error( + f"Impossible d'instancier ScreenAnalyzer: {e2}. " + "Enrichissement UI désactivé." + ) + return None def build_from_session( self, @@ -209,6 +274,7 @@ class GraphBuilder: workflow_name: Optional[str] = None, precomputed_states: Optional[List["ScreenState"]] = None, precomputed_embeddings: Optional[List] = None, + sequential: bool = False, ) -> Workflow: """ Construire un Workflow complet depuis une RawSession. @@ -216,7 +282,7 @@ class GraphBuilder: Processus: 1. Créer ScreenStates depuis screenshots (ou utiliser precomputed_states) 2. Calculer embeddings pour chaque état (ou réutiliser precomputed_embeddings) - 3. Détecter patterns via clustering + 3. Détecter patterns via clustering (ou mode séquentiel) 4. Construire nodes depuis clusters 5. Construire edges depuis transitions @@ -228,6 +294,10 @@ class GraphBuilder: precomputed_embeddings: Embeddings déjà calculés (streaming). Si fourni et de la bonne longueur (= len(screen_states)), saute l'étape 2 (pas de recalcul CLIP). + sequential: Si True, crée un node par état d'écran (pas de + clustering DBSCAN). Approprié pour les enregistrements + single-pass d'un workflow — chaque screenshot est une étape + distincte avec ses actions associées. 
Returns: Workflow construit avec nodes et edges @@ -242,6 +312,7 @@ class GraphBuilder: f"Building workflow from session {session.session_id} " f"with {len(precomputed_states or session.screenshots)} " f"{'precomputed states' if precomputed_states else 'screenshots'}" + f"{' (mode séquentiel)' if sequential else ''}" ) # Étape 1: Créer ScreenStates (ou réutiliser ceux pré-calculés) @@ -266,16 +337,28 @@ class GraphBuilder: embeddings = self._compute_embeddings(screen_states) logger.debug(f"Computed {len(embeddings)} embeddings") - # Étape 3: Détecter patterns - clusters = self._detect_patterns(embeddings, screen_states) - logger.info(f"Detected {len(clusters)} patterns") + # Étape 3: Détecter patterns ou mode séquentiel + if sequential: + # Mode séquentiel : chaque état d'écran est un node distinct. + # Pas de clustering — essentiel pour les enregistrements single-pass + # où l'on veut reproduire fidèlement la séquence des actions. + clusters = {i: [i] for i in range(len(screen_states))} + logger.info( + f"Mode séquentiel: {len(clusters)} nodes (1 par état)" + ) + else: + clusters = self._detect_patterns(embeddings, screen_states) + logger.info(f"Detected {len(clusters)} patterns") # Étape 4: Construire nodes nodes = self._build_nodes(clusters, screen_states, embeddings) logger.info(f"Built {len(nodes)} workflow nodes") # Étape 5: Construire edges (passer les embeddings pour éviter recalcul) - edges = self._build_edges(nodes, screen_states, session, embeddings=embeddings) + edges = self._build_edges( + nodes, screen_states, session, embeddings=embeddings, + sequential=sequential, + ) logger.info(f"Built {len(edges)} workflow edges") # Créer Workflow @@ -388,18 +471,35 @@ class GraphBuilder: Liste de ScreenStates enrichis """ screen_states = [] - + # Créer un mapping screenshot_id -> événement screenshot_to_event = {} for event in session.events: if event.screenshot_id: screenshot_to_event[event.screenshot_id] = event - + + # Récupérer (une seule fois) l'analyzer 
partagé si l'enrichissement est actif. + # Le singleton C1 garantit qu'on ne recharge pas UIDetector/CLIP inutilement. + analyzer = None + if self.enable_ui_enrichment: + analyzer = self._get_screen_analyzer() + + # Cache partagé (C1) : réutiliser les analyses si même screenshot est + # repassé plusieurs fois (peu fréquent en construction, utile en tests). + try: + from core.pipeline import get_screen_state_cache + + state_cache = get_screen_state_cache() + except Exception as e: + logger.debug(f"ScreenStateCache indisponible ({e}); aucun cache utilisé.") + state_cache = None + + enriched_count = 0 for i, screenshot in enumerate(session.screenshots): # Trouver l'événement associé event = screenshot_to_event.get(screenshot.screenshot_id) - - # Créer WindowContext depuis l'événement + + # Construire WindowContext depuis l'événement (si dispo) screen_env = session.environment.get("screen", {}) screen_res = screen_env.get("primary_resolution", [1920, 1080]) if event and event.window: @@ -426,60 +526,128 @@ class GraphBuilder: os_theme=session.environment.get("os_theme", "unknown"), os_language=session.environment.get("os_language", "unknown"), ) - - # Créer RawLevel - # Construire chemin absolu : data/training/sessions/{session_id}/{session_id}/{relative_path} - screenshot_absolute_path = f"data/training/sessions/{session.session_id}/{session.session_id}/{screenshot.relative_path}" + + # Chemin absolu du screenshot + screenshot_absolute_path = ( + f"data/training/sessions/{session.session_id}/" + f"{session.session_id}/{screenshot.relative_path}" + ) screenshot_path = Path(screenshot_absolute_path) + + # Timestamp + if isinstance(screenshot.captured_at, str): + timestamp = datetime.fromisoformat( + screenshot.captured_at.replace('Z', '+00:00') + ) + else: + timestamp = screenshot.captured_at + + # ------------------------------------------------------------ + # Enrichissement visuel : déléguer au ScreenAnalyzer partagé + # 
------------------------------------------------------------ + # L'analyzer renvoie un ScreenState complet avec : + # - raw (image + file_size) + # - perception (OCR + embedding ref) + # - ui_elements (détection UIDetector) + # On récupère ces niveaux et on rebâtit un état final avec le + # WindowContext et les metadata issus de la session brute (les + # données "metier" que l'analyzer ignore). + # ------------------------------------------------------------ + detected_text: List[str] = [] + text_method = "none" + ui_elements: List = [] raw = RawLevel( screenshot_path=str(screenshot_path), capture_method="mss", - file_size_bytes=screenshot_path.stat().st_size if screenshot_path.exists() else 0 + file_size_bytes=( + screenshot_path.stat().st_size + if screenshot_path.exists() + else 0 + ), ) - - # Créer PerceptionLevel — enrichir avec OCR si le screenshot existe - detected_text = [] - text_method = "none" - if screenshot_path.exists(): + if analyzer is not None and screenshot_path.exists(): try: - if self._screen_analyzer is None: - from core.pipeline.screen_analyzer import ScreenAnalyzer - self._screen_analyzer = ScreenAnalyzer(session_id=session.session_id) - extracted = self._screen_analyzer._extract_text(str(screenshot_path)) - if extracted: - detected_text = extracted - text_method = self._screen_analyzer._get_ocr_method_name() - except Exception as e: - logger.debug(f"OCR échoué pour {screenshot_path}: {e}") + # Construire l'info fenêtre pour donner le contexte à + # l'UIDetector (certains détecteurs s'en servent pour + # filtrer hors-fenêtre). 
+ window_info = { + "app_name": window.app_name, + "title": window.window_title, + "screen_resolution": list(window.screen_resolution or []), + } + analyzed = analyzer.analyze( + str(screenshot_path), + window_info=window_info, + enable_ocr=True, + enable_ui_detection=True, + session_id=session.session_id, + ) + detected_text = list(analyzed.perception.detected_text or []) + text_method = ( + analyzed.perception.text_detection_method or "none" + ) + ui_elements = list(analyzed.ui_elements or []) + # Garder les métriques OCR/UI si présentes (debug) + analyzer_metadata = dict(analyzed.metadata or {}) + raw = analyzed.raw # conserver file_size réel mesuré + if ui_elements: + enriched_count += 1 + except Exception as e: + logger.warning( + f"Enrichissement visuel échoué pour {screenshot_path}: {e}. " + "Fallback sur ScreenState minimal." + ) + analyzer_metadata = {"analyzer_error": str(e)} + else: + analyzer_metadata = {} + if self.enable_ui_enrichment and not screenshot_path.exists(): + logger.debug( + f"Screenshot introuvable: {screenshot_path} " + "— ui_elements restera vide" + ) + + # PerceptionLevel : vector_id calculé de façon déterministe. 
perception = PerceptionLevel( embedding=EmbeddingRef( provider="openclip_ViT-B-32", - vector_id=f"data/embeddings/screens/{session.session_id}_state_{i:04d}.npy", - dimensions=512 + vector_id=( + f"data/embeddings/screens/" + f"{session.session_id}_state_{i:04d}.npy" + ), + dimensions=512, ), detected_text=detected_text, text_detection_method=text_method, - confidence_avg=0.85 if detected_text else 0.0 + confidence_avg=0.85 if detected_text else 0.0, ) - - # Créer ContextLevel + + # ContextLevel (métier) context = ContextLevel( current_workflow_candidate=None, workflow_step=i, user_id=session.user.get("id", "unknown"), - tags=list(session.context.get("tags", [])) if isinstance(session.context.get("tags"), list) else [], - business_variables={} + tags=( + list(session.context.get("tags", [])) + if isinstance(session.context.get("tags"), list) + else [] + ), + business_variables={}, ) - - # Parser timestamp - if isinstance(screenshot.captured_at, str): - timestamp = datetime.fromisoformat(screenshot.captured_at.replace('Z', '+00:00')) - else: - timestamp = screenshot.captured_at - - # Créer ScreenState complet + + # Metadata : on garde le lien événement/session + éventuels + # compteurs remontés par l'analyzer. + metadata = { + "screenshot_id": screenshot.screenshot_id, + "event_type": event.type if event else None, + "event_time": event.t if event else None, + } + # Propager les indicateurs utiles de l'analyzer sans écraser la base. 
+ for key in ("ocr_ms", "ui_ms", "analyzer_error"): + if key in analyzer_metadata: + metadata[key] = analyzer_metadata[key] + state = ScreenState( screen_state_id=f"{session.session_id}_state_{i:04d}", timestamp=timestamp, @@ -488,17 +656,17 @@ class GraphBuilder: raw=raw, perception=perception, context=context, - metadata={ - "screenshot_id": screenshot.screenshot_id, - "event_type": event.type if event else None, - "event_time": event.t if event else None - }, - ui_elements=[] # Sera rempli par UIDetector si disponible + metadata=metadata, + ui_elements=ui_elements, ) - + screen_states.append(state) - - logger.info(f"Created {len(screen_states)} enriched screen states") + + logger.info( + f"Created {len(screen_states)} enriched screen states " + f"({enriched_count} avec UI détectée, " + f"ui_enrichment={self.enable_ui_enrichment})" + ) return screen_states def _compute_embeddings( @@ -924,6 +1092,99 @@ class GraphBuilder: constraints.sort(key=lambda c: role_counts.get(c.get("role", ""), 0), reverse=True) return constraints[:8] + # ------------------------------------------------------------------ + # Association spatiale clic → UIElement + # ------------------------------------------------------------------ + + def _find_clicked_element( + self, + event: Event, + ui_elements: List[Any], + ) -> Optional[Any]: + """ + Identifier l'UIElement cible d'un clic par proximité spatiale. + + Règle : + 1. Si un bbox contient strictement la position du clic → match. + 2. Sinon, on prend le bbox le plus proche (distance euclidienne + au bord) sous réserve qu'il soit à <= `element_proximity_max_px`. + 3. Sinon, aucun ancrage possible → None. + + Cette association transforme un clic "aveugle" (coordonnées brutes) + en un clic "intelligent" (rôle + label), permettant au matcher de + retrouver l'élément même si la résolution ou la position change. + + Args: + event: Événement `mouse_click` (avec `data["pos"] = [x, y]`). 
+ ui_elements: Liste des UIElement détectés sur l'écran source. + + Returns: + UIElement le plus pertinent, ou None si rien ne correspond. + """ + if not ui_elements: + return None + if not event or event.type != "mouse_click": + return None + + pos = event.data.get("pos") if event.data else None + if not pos or len(pos) < 2: + return None + + try: + click_x = float(pos[0]) + click_y = float(pos[1]) + except (TypeError, ValueError): + return None + + best_contained = None + best_contained_area = None + best_near = None + best_near_distance = None + + for element in ui_elements: + bbox = getattr(element, "bbox", None) + if bbox is None: + continue + + # Extraction défensive des coordonnées (BBox Pydantic ou tuple) + try: + bx = int(getattr(bbox, "x", bbox[0])) + by = int(getattr(bbox, "y", bbox[1])) + bw = int(getattr(bbox, "width", bbox[2])) + bh = int(getattr(bbox, "height", bbox[3])) + except (AttributeError, IndexError, TypeError): + continue + + # Cas 1 : la position est strictement dans le bbox. + if bx <= click_x <= bx + bw and by <= click_y <= by + bh: + # Sélectionner le plus petit bbox qui contient (élément le plus spécifique) + area = max(1, bw * bh) + if best_contained is None or area < best_contained_area: + best_contained = element + best_contained_area = area + continue + + # Cas 2 : calculer la distance au bord le plus proche. 
+ dx = max(bx - click_x, 0, click_x - (bx + bw)) + dy = max(by - click_y, 0, click_y - (by + bh)) + distance = (dx * dx + dy * dy) ** 0.5 + + if best_near is None or distance < best_near_distance: + best_near = element + best_near_distance = distance + + if best_contained is not None: + return best_contained + + if ( + best_near is not None + and best_near_distance is not None + and best_near_distance <= self.element_proximity_max_px + ): + return best_near + + return None + # Patterns d'erreur courants pour la détection fail_fast _ERROR_PATTERNS = [ "erreur", "error", "échec", "failed", "impossible", @@ -937,12 +1198,14 @@ class GraphBuilder: screen_states: List[ScreenState], session: RawSession, embeddings: Optional[List[np.ndarray]] = None, + sequential: bool = False, ) -> List[WorkflowEdge]: """ Construire WorkflowEdges depuis les transitions observées. Algorithme: 1. Mapper chaque ScreenState vers son node (via embedding similarity) + En mode séquentiel, le mapping est direct (state i → node i). 2. Identifier les transitions (state_i -> state_j où node change) 3. Extraire l'action depuis l'événement entre les deux états 4. 
Créer WorkflowEdge avec action, pré-conditions et post-conditions @@ -960,6 +1223,7 @@ class GraphBuilder: screen_states: ScreenStates session: Session brute (pour événements) embeddings: Embeddings pré-calculés (évite un recalcul dans _map_states_to_nodes) + sequential: Mode séquentiel — chaque paire consécutive = transition Returns: Liste de WorkflowEdges @@ -975,7 +1239,19 @@ class GraphBuilder: node_by_id = {node.node_id: node for node in nodes} # Étape 1: Mapper chaque état vers son node - state_to_node = self._map_states_to_nodes(screen_states, nodes, embeddings=embeddings) + if sequential: + # Mode séquentiel : mapping direct state[i] → node[i] + state_to_node = {} + for i, state in enumerate(screen_states): + if i < len(nodes): + state_to_node[state.screen_state_id] = nodes[i].node_id + logger.debug( + f"Mode séquentiel: {len(state_to_node)} states mappés directement" + ) + else: + state_to_node = self._map_states_to_nodes( + screen_states, nodes, embeddings=embeddings + ) # Étape 2: Récupérer la résolution d'écran pour normaliser les coordonnées screen_env = session.environment.get("screen", {}) @@ -989,8 +1265,11 @@ class GraphBuilder: current_node_id = state_to_node.get(current_state.screen_state_id) next_node_id = state_to_node.get(next_state.screen_state_id) - # Si les deux états sont dans des nodes différents, c'est une transition - if current_node_id and next_node_id and current_node_id != next_node_id: + # En mode séquentiel, chaque paire consécutive est une transition + # En mode clustering, uniquement si les nodes sont différents + if current_node_id and next_node_id and ( + sequential or current_node_id != next_node_id + ): # Trouver TOUS les événements entre les deux états transition_events = self._find_transition_events( current_state, next_state, session.events @@ -1012,6 +1291,7 @@ class GraphBuilder: target_node=target_node, all_events=transition_events, screen_resolution=screen_resolution, + source_state=current_state, ) edges.append(edge) 
@@ -1094,6 +1374,32 @@ class GraphBuilder: return state_to_node + def _get_state_time(self, state: ScreenState, fallback: float = 0) -> float: + """Extraire le timestamp d'un ScreenState. + + Priorité : + 1. metadata['event_time'] (set par _create_screen_states) + 2. metadata['shot_timestamp'] (set par le reprocessing) + 3. state.timestamp converti en epoch si c'est un datetime + 4. fallback + + Note : event_time peut être 0.0 (timestamps relatifs), donc on + vérifie `is not None` et non `> 0`. + """ + if state.metadata: + et = state.metadata.get("event_time") + if et is not None: + return float(et) + st = state.metadata.get("shot_timestamp") + if st is not None: + return float(st) + if state.timestamp: + try: + return state.timestamp.timestamp() + except (AttributeError, OSError): + pass + return fallback + def _find_transition_events( self, current_state: ScreenState, @@ -1108,6 +1414,9 @@ class GraphBuilder: C'est essentiel pour le replay : une transition peut nécessiter plusieurs actions (ex: Win+R → taper "notepad" → Entrée). + Timestamps : utilise _get_state_time() qui supporte plusieurs + sources (event_time, shot_timestamp, datetime). + Args: current_state: État source next_state: État cible @@ -1117,8 +1426,8 @@ class GraphBuilder: Liste ordonnée (par timestamp) de tous les événements d'action entre les deux états. Peut être vide. """ - current_time = current_state.metadata.get("event_time", 0) - next_time = next_state.metadata.get("event_time", float('inf')) + current_time = self._get_state_time(current_state, fallback=0) + next_time = self._get_state_time(next_state, fallback=float('inf')) action_events = [] for event in events: @@ -1155,6 +1464,7 @@ class GraphBuilder: target_node: Optional[WorkflowNode] = None, all_events: Optional[List[Event]] = None, screen_resolution: Tuple[int, int] = (1920, 1080), + source_state: Optional[ScreenState] = None, ) -> WorkflowEdge: """ Créer un WorkflowEdge depuis une transition observée. 
@@ -1180,12 +1490,24 @@ class GraphBuilder: # Si on a plusieurs événements, créer une action compound events_to_use = all_events or ([event] if event else []) + # UIElements de l'écran source — sert à ancrer les clics sur un vrai + # élément UI (rôle, texte, bbox) plutôt que sur une coordonnée brute. + source_ui_elements = ( + list(source_state.ui_elements) + if source_state and source_state.ui_elements + else [] + ) + if len(events_to_use) > 1: action = self._build_compound_action( - events_to_use, screen_resolution + events_to_use, screen_resolution, + source_ui_elements=source_ui_elements, ) elif len(events_to_use) == 1: - action = self._build_single_action(events_to_use[0]) + action = self._build_single_action( + events_to_use[0], + source_ui_elements=source_ui_elements, + ) else: action = Action( type="unknown", @@ -1235,15 +1557,29 @@ class GraphBuilder: metadata=edge_metadata, ) - def _build_single_action(self, event: Event) -> Action: + def _build_single_action( + self, + event: Event, + source_ui_elements: Optional[List[Any]] = None, + ) -> Action: """ Construire une Action simple depuis un seul événement. - Rétrocompatible avec l'ancien format : un type d'action direct - (mouse_click, key_press, text_input) avec ses paramètres. + Pour un clic, si `source_ui_elements` est fourni, on tente d'ancrer + l'action sur l'UIElement le plus proche (par proximité spatiale). + Le TargetSpec devient alors discriminant : + - `by_role` = rôle sémantique de l'élément (ex: "primary_action") + - `by_text` = label détecté (ex: "Valider") + - `selection_policy` = "by_similarity" (laisse le matcher scorer) + - `context_hints["anchor_element_id"]` = traçabilité + - `context_hints["anchor_bbox"]` = invariant spatial debug + + À défaut d'ancrage (pas d'UIElement ou clic hors de toute bbox + proche), on retombe sur `by_role="unknown_element"` (legacy). 
""" action_type = event.type - action_params = {} + action_params: Dict[str, Any] = {} + target_spec: Optional[TargetSpec] = None if action_type == "mouse_click": action_params = { @@ -1251,39 +1587,111 @@ class GraphBuilder: "position": event.data.get("pos", [0, 0]), "wait_after_ms": 500, } - target_role = "unknown_element" + target_spec = self._build_click_target_spec( + event, source_ui_elements or [] + ) elif action_type == "key_press": action_params = { "keys": event.data.get("keys", []), "wait_after_ms": 200, } - target_role = "keyboard_input" + target_spec = TargetSpec( + by_role="keyboard_input", + selection_policy="first", + fallback_strategy="visual_similarity", + ) elif action_type == "text_input": action_params = { "text": event.data.get("text", ""), "wait_after_ms": 300, } - target_role = "text_field" + target_spec = TargetSpec( + by_role="text_field", + selection_policy="first", + fallback_strategy="visual_similarity", + ) else: action_params = {} - target_role = "unknown" + target_spec = TargetSpec( + by_role="unknown", + selection_policy="first", + fallback_strategy="visual_similarity", + ) return Action( type=action_type, - target=TargetSpec( - by_role=target_role, + target=target_spec, + parameters=action_params, + ) + + def _build_click_target_spec( + self, + event: Event, + source_ui_elements: List[Any], + ) -> TargetSpec: + """ + Construire un TargetSpec pour un clic, en essayant de l'ancrer à + un UIElement détecté sur l'écran source. + + Retourne toujours un TargetSpec valide : + - ancré (role + text + context_hints) si un élément proche existe ; + - fallback `unknown_element` sinon (comportement historique). + """ + clicked = self._find_clicked_element(event, source_ui_elements) + + if clicked is None: + return TargetSpec( + by_role="unknown_element", selection_policy="first", fallback_strategy="visual_similarity", - ), - parameters=action_params, + ) + + # Extraction défensive des attributs de l'élément. 
+ role = getattr(clicked, "role", None) or "unknown_element" + label = getattr(clicked, "label", None) or None + element_id = getattr(clicked, "element_id", None) + + # Contexte de traçabilité — `context_hints` est le seul dict libre + # disponible dans TargetSpec (pas de champ `metadata` dédié). + context_hints: Dict[str, Any] = {} + if element_id: + context_hints["anchor_element_id"] = str(element_id) + + bbox = getattr(clicked, "bbox", None) + if bbox is not None: + try: + context_hints["anchor_bbox"] = { + "x": int(getattr(bbox, "x", bbox[0])), + "y": int(getattr(bbox, "y", bbox[1])), + "width": int(getattr(bbox, "width", bbox[2])), + "height": int(getattr(bbox, "height", bbox[3])), + } + except (AttributeError, IndexError, TypeError): + pass + + # Center (utile comme ancre de fallback quand le matcher échoue) + center = getattr(clicked, "center", None) + if center is not None: + try: + context_hints["anchor_center"] = [int(center[0]), int(center[1])] + except (IndexError, TypeError): + pass + + return TargetSpec( + by_role=role, + by_text=label, + selection_policy="by_similarity", + fallback_strategy="visual_similarity", + context_hints=context_hints, ) def _build_compound_action( self, events: List[Event], screen_resolution: Tuple[int, int] = (1920, 1080), + source_ui_elements: Optional[List[Any]] = None, ) -> Action: """ Construire une Action compound (multi-étapes) depuis plusieurs événements. @@ -1360,21 +1768,33 @@ class GraphBuilder: # La cible du compound = cible de la dernière action (le clic final, etc.) last_event = events[-1] if last_event.type == "mouse_click": - target_role = "unknown_element" + # On tente d'ancrer le clic final aux UIElements détectés, + # comme dans _build_single_action. 
+ target_spec = self._build_click_target_spec( + last_event, source_ui_elements or [] + ) elif last_event.type == "text_input": - target_role = "text_field" + target_spec = TargetSpec( + by_role="text_field", + selection_policy="first", + fallback_strategy="visual_similarity", + ) elif last_event.type == "key_press": - target_role = "keyboard_input" + target_spec = TargetSpec( + by_role="keyboard_input", + selection_policy="first", + fallback_strategy="visual_similarity", + ) else: - target_role = "unknown" + target_spec = TargetSpec( + by_role="unknown", + selection_policy="first", + fallback_strategy="visual_similarity", + ) return Action( type="compound", - target=TargetSpec( - by_role=target_role, - selection_policy="first", - fallback_strategy="visual_similarity", - ), + target=target_spec, parameters={ "steps": steps, "step_count": len(steps), diff --git a/core/pipeline/workflow_pipeline.py b/core/pipeline/workflow_pipeline.py index 0010d4164..fba64a79b 100644 --- a/core/pipeline/workflow_pipeline.py +++ b/core/pipeline/workflow_pipeline.py @@ -137,10 +137,14 @@ class WorkflowPipeline: else: logger.warning(f"UI Detector not available: {e}") - # 6. Graph Builder + # 6. Graph Builder — reçoit l'UIDetector pour enrichir les + # ScreenStates avec ui_elements + OCR pendant _create_screen_states. + # Sans ça, les TargetSpec ne peuvent pas être ancrés (by_role=unknown). 
self.graph_builder = GraphBuilder( embedding_builder=self.embedding_builder, - faiss_manager=self.faiss_manager + faiss_manager=self.faiss_manager, + ui_detector=self.ui_detector, + enable_ui_enrichment=enable_ui_detection, ) logger.info("✓ Graph Builder initialized") diff --git a/tests/test_pipeline_e2e.py b/tests/test_pipeline_e2e.py index 50ccf8ceb..0920bb2df 100644 --- a/tests/test_pipeline_e2e.py +++ b/tests/test_pipeline_e2e.py @@ -143,13 +143,19 @@ def mock_embedding_builder(): @pytest.fixture def graph_builder(mock_embedding_builder): - """GraphBuilder configuré pour le test (validation qualité désactivée).""" + """GraphBuilder configuré pour le test (validation qualité désactivée). + + `enable_ui_enrichment=False` désactive l'analyzer GPU : ces tests + valident le pipeline DBSCAN + edges, pas la détection UI réelle + (couverte par tests/unit/test_graph_builder_ui_enrichment.py). + """ return GraphBuilder( embedding_builder=mock_embedding_builder, min_pattern_repetitions=3, clustering_eps=0.15, clustering_min_samples=2, enable_quality_validation=False, + enable_ui_enrichment=False, ) @@ -356,6 +362,7 @@ class TestQualityValidation: embedding_builder=mock_embedding_builder, min_pattern_repetitions=3, enable_quality_validation=True, + enable_ui_enrichment=False, ) workflow = builder.build_from_session(session) @@ -377,6 +384,7 @@ class TestQualityValidation: embedding_builder=mock_embedding_builder, min_pattern_repetitions=3, enable_quality_validation=True, + enable_ui_enrichment=False, ) workflow = builder.build_from_session(session) @@ -403,6 +411,7 @@ class TestEdgeCases: builder = GraphBuilder( embedding_builder=mock_embedding_builder, enable_quality_validation=False, + enable_ui_enrichment=False, ) with pytest.raises(ValueError, match="no screenshots"): @@ -456,6 +465,7 @@ class TestEdgeCases: embedding_builder=mock_embedding_builder, min_pattern_repetitions=3, enable_quality_validation=False, + enable_ui_enrichment=False, ) workflow = 
builder.build_from_session(session) diff --git a/tests/unit/test_graph_builder_ui_enrichment.py b/tests/unit/test_graph_builder_ui_enrichment.py new file mode 100644 index 000000000..c09f04251 --- /dev/null +++ b/tests/unit/test_graph_builder_ui_enrichment.py @@ -0,0 +1,513 @@ +""" +Tests unitaires de l'enrichissement visuel dans GraphBuilder (chantier C2). + +Couvre : + - `_create_screen_states` : enrichit `ui_elements` via ScreenAnalyzer + - `_find_clicked_element` : association spatiale clic → UIElement + - `_build_single_action` : TargetSpec avec `by_role`/`by_text` quand ancre + - Fallback `by_role="unknown_element"` quand aucun ancrage n'est possible + - `_extract_common_ui_elements` : required_roles extrait du cluster + - Analyzer qui crash → ScreenState vide, pas de propagation d'exception + - Singleton partagé entre deux GraphBuilder (C1) +""" + +from __future__ import annotations + +from datetime import datetime, timedelta +from pathlib import Path +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +from PIL import Image + +from core.graph.graph_builder import GraphBuilder +from core.models.base_models import BBox +from core.models.raw_session import ( + Event, + RawSession, + RawWindowContext, + Screenshot, +) +from core.models.screen_state import ( + ContextLevel, + EmbeddingRef, + PerceptionLevel, + RawLevel, + ScreenState, + WindowContext, +) +from core.models.ui_element import ( + UIElement, + UIElementEmbeddings, + VisualFeatures, +) +from core.pipeline import ( + reset_screen_analyzer, + reset_screen_state_cache, +) + + +# ----------------------------------------------------------------------------- +# Fixtures +# ----------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _reset_singletons(): + """Isole chaque test des singletons globaux.""" + reset_screen_analyzer() + reset_screen_state_cache() + yield + reset_screen_analyzer() + reset_screen_state_cache() + + 
def _make_click_event(pos, t: float = 1.0, button: str = "left") -> Event:
    """Minimal mouse_click Event (window is required by the dataclass)."""
    return Event(
        t=t,
        type="mouse_click",
        window=RawWindowContext(title="Test", app_name="test_app"),
        data={"button": button, "pos": list(pos)},
    )


def _make_key_event(t: float = 1.0, keys: list | None = None, text: str | None = None, ev_type: str = "key_press") -> Event:
    """Keyboard Event (key_press or text_input)."""
    # Only include the keys/text payload fields that were actually provided.
    data = {}
    if keys is not None:
        data["keys"] = keys
    if text is not None:
        data["text"] = text
    return Event(
        t=t,
        type=ev_type,
        window=RawWindowContext(title="Test", app_name="test_app"),
        data=data,
    )


def _make_ui_element(
    element_id: str,
    role: str,
    label: str,
    bbox: tuple,
    el_type: str = "button",
) -> UIElement:
    """Build a minimal UIElement for the tests.

    `bbox` is (x, y, width, height); `center` is derived from it.
    """
    return UIElement(
        element_id=element_id,
        type=el_type,
        role=role,
        bbox=BBox.from_tuple(bbox),
        center=(bbox[0] + bbox[2] // 2, bbox[1] + bbox[3] // 2),
        label=label,
        label_confidence=0.95,
        embeddings=UIElementEmbeddings(),
        visual_features=VisualFeatures(
            dominant_color="blue",
            has_icon=False,
            shape="rectangle",
            size_category="medium",
        ),
        confidence=0.9,
    )


def _make_screen_state(
    session_id: str,
    index: int,
    ui_elements: list,
    title: str = "Test App",
    detected_text: list | None = None,
) -> ScreenState:
    """Minimal ScreenState usable by _extract_common_ui_elements."""
    return ScreenState(
        screen_state_id=f"{session_id}_state_{index:04d}",
        timestamp=datetime(2026, 4, 13, 10, 0, index),
        session_id=session_id,
        window=WindowContext(
            app_name="test_app",
            window_title=title,
            screen_resolution=[1920, 1080],
        ),
        raw=RawLevel(
            screenshot_path=f"/tmp/shot_{index}.png",
            capture_method="mss",
            file_size_bytes=1024,
        ),
        perception=PerceptionLevel(
            embedding=EmbeddingRef(
                provider="test", vector_id=f"v_{index}", dimensions=512
            ),
            detected_text=detected_text or [],
            text_detection_method="test",
            confidence_avg=0.8,
        ),
        context=ContextLevel(),
        metadata={},
        ui_elements=ui_elements,
    )


@pytest.fixture
def synthetic_session(tmp_path):
    """Synthetic RawSession with 2 alternating screenshots.

    Writes 4 real PNG files under a data/training/sessions/<id>/<id>/
    layout rooted at tmp_path, so _create_screen_states can load them
    after the test chdir's into tmp_path.
    """
    session_id = "ui_enrich_session"
    screens_dir = (
        tmp_path / "data" / "training" / "sessions"
        / session_id / session_id / "screenshots"
    )
    screens_dir.mkdir(parents=True)

    screenshots = []
    events = []
    for i in range(4):
        ts = datetime(2026, 4, 13, 10, 0, i)
        # Alternate red/blue screens so embeddings of consecutive states differ.
        color = (200, 50, 50) if i % 2 == 0 else (50, 50, 200)
        img = Image.new("RGB", (400, 300), color)
        fname = f"screen_{i:03d}.png"
        img.save(str(screens_dir / fname))

        screenshots.append(Screenshot(
            screenshot_id=f"ss_{i:03d}",
            relative_path=f"screenshots/{fname}",
            captured_at=ts.isoformat(),
        ))
        events.append(Event(
            t=float(i),
            type="mouse_click",
            window=RawWindowContext(
                title="App A" if i % 2 == 0 else "App B",
                app_name="app",
            ),
            screenshot_id=f"ss_{i:03d}",
            data={"button": "left", "pos": [150, 120]},
        ))

    session = RawSession(
        session_id=session_id,
        agent_version="test",
        environment={"screen": {"primary_resolution": [1920, 1080]}},
        user={"id": "tester"},
        context={},
        started_at=datetime(2026, 4, 13, 10, 0, 0),
        events=events,
        screenshots=screenshots,
    )
    return session, tmp_path


# -----------------------------------------------------------------------------
# ScreenState enrichment via ScreenAnalyzer
# -----------------------------------------------------------------------------


class TestCreateScreenStatesEnrichment:
    """_create_screen_states must delegate to the ScreenAnalyzer."""

    def test_build_from_session_enriches_screen_states(
        self, synthetic_session, monkeypatch
    ):
        """With a mocked analyzer, ui_elements are propagated to the ScreenStates."""
        session, tmp_path = synthetic_session
        monkeypatch.chdir(tmp_path)

        # Mocked analyzer: returns a ScreenState with 3 canonical UIElements.
        fake_elements = [
            _make_ui_element("el_1", "primary_action", "Valider", (100, 100, 80, 30)),
            _make_ui_element("el_2", "cancel", "Annuler", (200, 100, 80, 30)),
            _make_ui_element("el_3", "form_input", "Nom", (100, 50, 200, 30)),
        ]

        def fake_analyze(path, **kwargs):
            # Return a ScreenState with the right number of elements + OCR.
            return _make_screen_state(
                session.session_id,
                index=0,
                ui_elements=list(fake_elements),
                detected_text=["Nom", "Valider", "Annuler"],
            )

        analyzer = MagicMock()
        analyzer.analyze.side_effect = fake_analyze

        builder = GraphBuilder(
            screen_analyzer=analyzer,
            enable_ui_enrichment=True,
            enable_quality_validation=False,
        )
        states = builder._create_screen_states(session)

        assert len(states) == 4
        for st in states:
            assert len(st.ui_elements) == 3
            roles = {e.role for e in st.ui_elements}
            assert {"primary_action", "cancel", "form_input"}.issubset(roles)
            assert "Valider" in st.perception.detected_text

    def test_enrichment_disabled_leaves_ui_elements_empty(
        self, synthetic_session, monkeypatch
    ):
        """enable_ui_enrichment=False -> empty ui_elements, analyzer never called."""
        session, tmp_path = synthetic_session
        monkeypatch.chdir(tmp_path)

        analyzer = MagicMock()
        builder = GraphBuilder(
            screen_analyzer=analyzer,
            enable_ui_enrichment=False,
            enable_quality_validation=False,
        )
        states = builder._create_screen_states(session)

        assert len(states) == 4
        for st in states:
            assert st.ui_elements == []
            assert st.perception.detected_text == []
        # The analyzer must not have been called.
        analyzer.analyze.assert_not_called()

    def test_analyzer_failure_falls_back_to_empty(
        self, synthetic_session, monkeypatch, caplog
    ):
        """A crashing analyzer -> empty ScreenState, warning log, no exception."""
        session, tmp_path = synthetic_session
        monkeypatch.chdir(tmp_path)

        analyzer = MagicMock()
        analyzer.analyze.side_effect = RuntimeError("boom (GPU OOM)")

        builder = GraphBuilder(
            screen_analyzer=analyzer,
            enable_ui_enrichment=True,
            enable_quality_validation=False,
        )
        with caplog.at_level("WARNING"):
            states = builder._create_screen_states(session)

        assert len(states) == 4
        for st in states:
            assert st.ui_elements == []
            # The metadata records the error for diagnostics.
            assert "analyzer_error" in st.metadata
        # A warning log was indeed emitted (message text is French in prod code).
        assert any("Enrichissement visuel échoué" in r.getMessage() for r in caplog.records)

    def test_shared_analyzer_singleton(self, monkeypatch):
        """Two GraphBuilders created without an explicit analyzer share the C1 singleton."""
        fake_analyzer = MagicMock(name="singleton_analyzer")
        # analyze is never called (no screenshots in this test).

        with patch(
            "core.pipeline.get_screen_analyzer", return_value=fake_analyzer
        ) as getter:
            b1 = GraphBuilder(enable_quality_validation=False)
            b2 = GraphBuilder(enable_quality_validation=False)
            a1 = b1._get_screen_analyzer()
            a2 = b2._get_screen_analyzer()

        assert a1 is fake_analyzer
        assert a2 is fake_analyzer
        # get_screen_analyzer may be called once per builder, but the real
        # sharing happens through C1's internal singleton.
        assert getter.call_count >= 1


# -----------------------------------------------------------------------------
# Spatial click -> UIElement association
# -----------------------------------------------------------------------------


class TestFindClickedElement:
    """Proximity logic of _find_clicked_element."""

    def _builder(self, max_px: float = 50.0) -> GraphBuilder:
        # Shared builder factory: enrichment off, proximity threshold configurable.
        return GraphBuilder(
            enable_quality_validation=False,
            enable_ui_enrichment=False,
            element_proximity_max_px=max_px,
        )

    def test_find_clicked_element_inside_bbox(self):
        """Click strictly inside a bbox -> exact match."""
        builder = self._builder()
        elements = [
            _make_ui_element("e1", "primary_action", "OK", (50, 50, 150, 150)),
            _make_ui_element("e2", "cancel", "Annuler", (300, 300, 100, 50)),
        ]
        event = _make_click_event([100, 100])
        result = builder._find_clicked_element(event, elements)
        assert result is not None
        assert result.element_id == "e1"

    def test_find_clicked_element_nearest_proximity(self):
        """Click outside every bbox but < 50px away -> matches the nearest."""
        builder = self._builder(max_px=50.0)
        elements = [
            # bbox at (50,50,100,40) -> right edge = 150, bottom edge = 90
            _make_ui_element("e_near", "primary_action", "Valider", (50, 50, 100, 40)),
            # far-away bbox (distance >> 50px from the click)
            _make_ui_element("e_far", "cancel", "Annuler", (500, 500, 80, 30)),
        ]
        # Click at (170, 70) -> right edge of e_near = 150, dx = 20, dy = 0 -> 20px
        event = _make_click_event([170, 70])
        result = builder._find_clicked_element(event, elements)
        assert result is not None
        assert result.element_id == "e_near"

    def test_find_clicked_element_too_far_returns_none(self):
        """Click > 50px away from the nearest bbox -> None."""
        builder = self._builder(max_px=50.0)
        elements = [
            _make_ui_element("e1", "primary_action", "OK", (50, 50, 100, 40)),
        ]
        # Click at (300, 300), bbox at (50,50,100,40) -> distance ~ 280px
        event = _make_click_event([300, 300])
        result = builder._find_clicked_element(event, elements)
        assert result is None

    def test_find_clicked_element_prefers_smallest_containing(self):
        """Two bboxes contain the click -> return the most specific (smallest)."""
        builder = self._builder()
        elements = [
            # Large container
            _make_ui_element(
                "container", "data_display", "Form", (0, 0, 800, 600),
                el_type="container",
            ),
            # Small button inside it
            _make_ui_element("btn", "primary_action", "OK", (100, 100, 80, 30)),
        ]
        event = _make_click_event([120, 110])
        result = builder._find_clicked_element(event, elements)
        assert result is not None
        assert result.element_id == "btn"

    def test_find_clicked_element_empty_list(self):
        # No candidates at all -> no anchor.
        builder = self._builder()
        event = _make_click_event([100, 100])
        assert builder._find_clicked_element(event, []) is None

    def test_find_clicked_element_non_click_event(self):
        """A non-click event -> None (no meaningful spatial anchoring)."""
        builder = self._builder()
        elements = [
            _make_ui_element("e1", "form_input", "Nom", (100, 100, 100, 30)),
        ]
        event = _make_key_event(keys=["Enter"])
        assert builder._find_clicked_element(event, elements) is None


# -----------------------------------------------------------------------------
# TargetSpec enriched by _build_single_action
# -----------------------------------------------------------------------------


class TestTargetSpecEnrichment:
    """_build_single_action must produce discriminating TargetSpecs."""

    def test_target_spec_uses_element_role(self):
        """Click anchored on an element -> by_role + by_text + context_hints."""
        builder = GraphBuilder(
            enable_quality_validation=False,
            enable_ui_enrichment=False,
        )
        elements = [
            _make_ui_element("el_ok", "primary_action", "Valider", (100, 100, 120, 40)),
        ]
        event = _make_click_event([150, 120])
        action = builder._build_single_action(event, source_ui_elements=elements)

        assert action.type == "mouse_click"
        assert action.target.by_role == "primary_action"
        assert action.target.by_text == "Valider"
        assert action.target.selection_policy == "by_similarity"
        # Traceability via context_hints
        assert action.target.context_hints.get("anchor_element_id") == "el_ok"
        assert "anchor_bbox" in action.target.context_hints
        assert action.target.context_hints["anchor_bbox"]["x"] == 100

    def test_target_spec_fallback_when_no_element(self):
        """No UIElement -> legacy by_role=unknown_element."""
        builder = GraphBuilder(
            enable_quality_validation=False,
            enable_ui_enrichment=False,
        )
        event = _make_click_event([400, 400])
        action = builder._build_single_action(event, source_ui_elements=[])
        assert action.target.by_role == "unknown_element"
        assert action.target.by_text is None
        # No anchoring context_hints
        assert not action.target.context_hints.get("anchor_element_id")

    def test_target_spec_fallback_when_click_too_far(self):
        """Click far from every bbox -> unknown_element fallback."""
        builder = GraphBuilder(
            enable_quality_validation=False,
            enable_ui_enrichment=False,
            element_proximity_max_px=30.0,
        )
        elements = [
            _make_ui_element("far", "cancel", "X", (50, 50, 20, 20)),
        ]
        event = _make_click_event([800, 800])
        action = builder._build_single_action(event, source_ui_elements=elements)
        assert action.target.by_role == "unknown_element"

    def test_keyboard_event_target_unchanged(self):
        """Non-click events keep their legacy target_role."""
        builder = GraphBuilder(
            enable_quality_validation=False,
            enable_ui_enrichment=False,
        )
        event = _make_key_event(text="hello", ev_type="text_input")
        action = builder._build_single_action(event, source_ui_elements=[])
        assert action.target.by_role == "text_field"


# -----------------------------------------------------------------------------
# UIConstraint.required_roles from _extract_common_ui_elements
# -----------------------------------------------------------------------------
class TestRequiredRolesExtraction:
    """required_roles must reflect the roles shared across the cluster."""

    def test_required_roles_extracted_from_common_elements(self):
        """3 ScreenStates with a common role -> required_roles contains it."""
        gb = GraphBuilder(
            enable_quality_validation=False,
            enable_ui_enrichment=False,
        )

        # Three screens: all carry "primary_action" (Valider); the first two
        # also carry "cancel", the last one "navigation" instead.
        screen_states = []
        for idx in range(3):
            elems = [
                _make_ui_element(
                    f"ok_{idx}", "primary_action", "Valider", (100, 100, 80, 30)
                )
            ]
            if idx < 2:
                elems.append(
                    _make_ui_element(
                        f"cancel_{idx}", "cancel", "Annuler", (200, 100, 80, 30)
                    )
                )
            else:
                elems.append(
                    _make_ui_element(
                        f"other_{idx}", "navigation", "Menu", (300, 100, 80, 30)
                    )
                )
            screen_states.append(
                _make_screen_state("sid", idx, ui_elements=elems)
            )

        proto = np.zeros(512, dtype=np.float32)
        proto[0] = 1.0
        template = gb._create_screen_template(screen_states, proto)

        assert template.ui is not None
        # primary_action present in 3/3 screens -> included
        assert "primary_action" in template.ui.required_roles
        # cancel present in 2/3 -> ratio 0.66 >= 0.5 -> included
        assert "cancel" in template.ui.required_roles
        # navigation present in 1/3 -> ratio 0.33 < 0.5 -> excluded
        assert "navigation" not in template.ui.required_roles