diff --git a/agent_v0/server_v1/stream_processor.py b/agent_v0/server_v1/stream_processor.py index 9cc534bcd..c17ef733f 100644 --- a/agent_v0/server_v1/stream_processor.py +++ b/agent_v0/server_v1/stream_processor.py @@ -2461,6 +2461,10 @@ class StreamProcessor: # Workflows construits (pour le matching) self._workflows: Dict[str, Any] = {} + # P1.1 : learner continu branché uniquement sur de vraies observations + # cross-session, après matching offline contre un workflow existant. + self._continuous_learner = None + # Shadow learning : dernier pattern UI détecté par session # Stocke {session_id: {"pattern": str, "ocr_text": str, "screen_state": obj, "shot_id": str}} self._pending_ui_patterns: Dict[str, Dict[str, Any]] = {} @@ -3005,6 +3009,14 @@ class StreamProcessor: except Exception as e2: return {"error": f"Erreur RawSession: {e2}"} + session_machine_id = getattr(session, "machine_id", None) + cross_learning = self._run_cross_session_learning( + session_id=session_id, + states=states, + embeddings=embeddings, + machine_id=session_machine_id, + ) + # Construire le workflow via GraphBuilder try: from core.graph.graph_builder import GraphBuilder @@ -3086,6 +3098,7 @@ class StreamProcessor: "embeddings_indexed": len(embeddings), "saved_path": str(saved_path) if saved_path else None, "app_context": app_context, + "cross_session_learning": cross_learning, } logger.info( @@ -3130,6 +3143,257 @@ class StreamProcessor: return None + def _get_continuous_learner(self): + """Lazy init du ContinuousLearner existant.""" + if self._continuous_learner is not None: + return self._continuous_learner + from core.learning.continuous_learner import ContinuousLearner + self._continuous_learner = ContinuousLearner() + return self._continuous_learner + + def _run_cross_session_learning( + self, + session_id: str, + states: List[Any], + embeddings: List[np.ndarray], + machine_id: Optional[str] = None, + *, + update_threshold: float = 0.85, + drift_min_confidence: float = 0.50, + ) -> Dict[str, Any]: + """Matcher une session observée contre les workflows existants. + + P1.1 Option A : le learner ne reçoit jamais le prototype du node + existant. Il reçoit uniquement un embedding observé de la session + courante, si cet embedding matche fortement un node d'un workflow déjà + connu. Les identifiants écrits par le learner sont hashés pour éviter + de propager un nom de workflow potentiellement métier dans les chemins. + """ + stats: Dict[str, Any] = { + "status": "skipped", + "states_seen": len(states), + "embeddings_seen": len(embeddings), + "candidate_workflows": 0, + "matches": 0, + "updates": 0, + "drift_checks": 0, + "skips": {}, + } + + def _skip(reason: str) -> Dict[str, Any]: + stats["skips"][reason] = stats["skips"].get(reason, 0) + 1 + return stats + + if not states: + return _skip("no_states") + if len(states) != len(embeddings): + return _skip("embedding_count_mismatch") + + with self._data_lock: + workflows = list(self._workflows.values()) + + if not workflows: + return _skip("no_existing_workflow") + + stats["candidate_workflows"] = sum( + 1 + for workflow in workflows + if not ( + machine_id + and getattr(workflow, "_machine_id", None) + and getattr(workflow, "_machine_id", None) != machine_id + ) + ) + if stats["candidate_workflows"] == 0: + return _skip("no_candidate_workflow_for_machine") + + learner = self._get_continuous_learner() + stats["status"] = "processed" + + for state, observed_embedding in zip(states, embeddings): + observed_vector = self._normalise_vector(observed_embedding) + if observed_vector is None: + _skip("invalid_observed_embedding") + continue + + match = self._find_best_cross_session_match( + state=state, + observed_vector=observed_vector, + workflows=workflows, + machine_id=machine_id, + min_confidence=drift_min_confidence, + ) + if match is None: + _skip("no_match") + continue + + stats["matches"] += 1 + node_key = self._learning_node_key( + workflow_id=match["workflow_id"], + node_id=match["node_id"], + ) + + confidence = float(match["confidence"]) + learner.detect_drift(node_key, [confidence]) + stats["drift_checks"] += 1 + + if confidence < update_threshold: + _skip("below_update_threshold") + continue + + prototype = match.get("prototype") + if prototype is not None and np.allclose( + observed_vector, prototype, atol=1e-6 + ): + _skip("same_as_existing_prototype") + continue + + # Signal réel : observation acceptée uniquement parce qu'un match + # cross-session dépasse le seuil d'update. Les confidences faibles + # alimentent seulement la détection de drift ci-dessus. + execution_success = confidence >= update_threshold + if not execution_success: + _skip("no_success_signal") + continue + + learner.update_prototype( + node_key, + observed_vector.copy(), + execution_success=execution_success, + ) + stats["updates"] += 1 + + logger.info( + "P1.1 cross-session learning: states=%d workflows=%d matches=%d " + "updates=%d drift_checks=%d skips=%s", + stats["states_seen"], + stats["candidate_workflows"], + stats["matches"], + stats["updates"], + stats["drift_checks"], + stats["skips"], + ) + return stats + + def _find_best_cross_session_match( + self, + state: Any, + observed_vector: np.ndarray, + workflows: List[Any], + machine_id: Optional[str], + min_confidence: float, + ) -> Optional[Dict[str, Any]]: + """Retour le meilleur node existant pour un embedding observé.""" + best: Optional[Dict[str, Any]] = None + for workflow in workflows: + workflow_machine = getattr(workflow, "_machine_id", None) + if machine_id and workflow_machine and workflow_machine != machine_id: + continue + + workflow_id = getattr(workflow, "workflow_id", "") + for node in getattr(workflow, "nodes", []) or []: + prototype = self._extract_node_prototype(node) + if prototype is None or prototype.shape != observed_vector.shape: + continue + + confidence = self._cosine_similarity(observed_vector, prototype) + if confidence < min_confidence: + continue + if not self._template_accepts_observation(node, state, confidence): + continue + + if best is None or confidence > best["confidence"]: + best = { + "workflow_id": workflow_id, + "node_id": getattr(node, "node_id", ""), + "confidence": confidence, + "prototype": prototype, + } + + return best + + def _extract_node_prototype(self, node: Any) -> Optional[np.ndarray]: + """Extraire le prototype d'un node sans dépendre de FAISS.""" + meta = getattr(node, "metadata", {}) or {} + proto_list = meta.get("_prototype_vector") + if isinstance(proto_list, list): + return self._normalise_vector(proto_list) + + template = getattr(node, "template", None) + embedding = getattr(template, "embedding", None) if template else None + vector_id = getattr(embedding, "vector_id", None) if embedding else None + if vector_id: + try: + path = Path(vector_id) + if path.exists(): + return self._normalise_vector(np.load(path)) + except Exception as exc: + logger.debug("Prototype node illisible, skip: %s", exc) + return None + + @staticmethod + def _normalise_vector(vector: Any) -> Optional[np.ndarray]: + try: + arr = np.asarray(vector, dtype=np.float32) + except Exception: + return None + if arr.ndim != 1 or arr.size == 0: + return None + norm = float(np.linalg.norm(arr)) + if norm <= 0: + return None + return arr / norm + + @staticmethod + def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float: + return float(np.clip(np.dot(a, b), -1.0, 1.0)) + + @staticmethod + def _template_accepts_observation( + node: Any, + state: Any, + confidence: float, + ) -> bool: + template = getattr(node, "template", None) + if template is None: + return True + try: + window = getattr(template, "window", None) + if window and hasattr(state, "window"): + window_title = getattr(state.window, "window_title", "") + process = getattr(state.window, "process", "") + app_name = getattr(state.window, "app_name", "") + if not window.matches(window_title, process or app_name): + return False + + text = getattr(template, "text", None) + if text and hasattr(state, "perception"): + detected = getattr(state.perception, "detected_text", []) + if not text.matches(detected): + return False + + ui = getattr(template, "ui", None) + if ui and hasattr(state, "ui_elements"): + if not ui.matches(getattr(state, "ui_elements", [])): + return False + + return True + except Exception as exc: + logger.debug("Template match offline impossible, skip node: %s", exc) + return False + + @staticmethod + def _learning_node_key(workflow_id: str, node_id: str) -> str: + """Clé learner stable sans fuite de nom workflow potentiellement métier.""" + digest = hashlib.sha256(f"{workflow_id}:{node_id}".encode("utf-8")).hexdigest() + safe_node = "".join( + ch if ch.isascii() and (ch.isalnum() or ch in "._-") else "_" + for ch in str(node_id) + ).strip("._-") + if not safe_node: + safe_node = "node" + return f"wf_{digest[:16]}__{safe_node[:64]}" + # ========================================================================= # Enrichissement VLM des workflows (target_spec sur chaque edge) # ========================================================================= diff --git a/tests/unit/test_stream_processor_cross_session_learning.py b/tests/unit/test_stream_processor_cross_session_learning.py new file mode 100644 index 000000000..c727f896c --- /dev/null +++ b/tests/unit/test_stream_processor_cross_session_learning.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +from datetime import datetime +from unittest.mock import MagicMock + +import numpy as np + +from agent_v0.server_v1.stream_processor import StreamProcessor +from core.models.screen_state import ( + ContextLevel, + EmbeddingRef, + PerceptionLevel, + RawLevel, + ScreenState, + WindowContext, +) +from core.models.workflow_graph import ( + EmbeddingPrototype, + LearningConfig, + SafetyRules, + ScreenTemplate, + TextConstraint, + UIConstraint, + WindowConstraint, + Workflow, + WorkflowNode, + WorkflowStats, +) + + +def _unit(vector) -> np.ndarray: + arr = np.asarray(vector, dtype=np.float32) + return arr / np.linalg.norm(arr) + + +def _state(window_title: str = "Application") -> ScreenState: + return ScreenState( + screen_state_id="state_obs", + timestamp=datetime.now(), + session_id="sess_obs", + window=WindowContext( + app_name="app", + window_title=window_title, + screen_resolution=[1920, 1080], + ), + raw=RawLevel( + screenshot_path="", + capture_method="test", + file_size_bytes=0, + ), + perception=PerceptionLevel( + embedding=EmbeddingRef(provider="test", vector_id="", dimensions=2), + detected_text=[], + text_detection_method="test", + confidence_avg=1.0, + ), + context=ContextLevel(), + ui_elements=[], + ) + + +def _workflow(workflow_id: str, prototype) -> Workflow: + embedding = EmbeddingPrototype( + provider="test", + vector_id="", + min_cosine_similarity=0.85, + sample_count=1, + ) + template = ScreenTemplate( + window=WindowConstraint(), + text=TextConstraint(), + ui=UIConstraint(), + embedding=embedding, + ) + node = WorkflowNode( + node_id="node_000", + name="Node 000", + description="test node", + template=template, + metadata={"_prototype_vector": _unit(prototype).tolist()}, + ) + return Workflow( + workflow_id=workflow_id, + name="Existing workflow", + description="test workflow", + version=1, + learning_state="OBSERVATION", + created_at=datetime.now(), + updated_at=datetime.now(), + entry_nodes=["node_000"], + end_nodes=[], + nodes=[node], + edges=[], + safety_rules=SafetyRules(), + stats=WorkflowStats(), + learning=LearningConfig(), + ) + + +def _processor(tmp_path, workflow: Workflow | None = None) -> StreamProcessor: + processor = StreamProcessor(data_dir=str(tmp_path / "training")) + processor._initialized = True + processor._continuous_learner = MagicMock() + processor._workflows = {} + if workflow is not None: + processor._workflows[workflow.workflow_id] = workflow + return processor + + +def test_cross_session_learning_updates_with_observed_embedding(tmp_path): + workflow = _workflow("wf_patient_name_must_not_leak", [1.0, 0.0]) + processor = _processor(tmp_path, workflow) + observed = _unit([0.9, 0.2]) + + stats = processor._run_cross_session_learning( + session_id="sess_repeat", + states=[_state()], + embeddings=[observed], + ) + + assert stats["updates"] == 1 + processor._continuous_learner.update_prototype.assert_called_once() + node_key, updated_vector = processor._continuous_learner.update_prototype.call_args.args + kwargs = processor._continuous_learner.update_prototype.call_args.kwargs + assert kwargs["execution_success"] is True + assert np.allclose(updated_vector, observed) + assert not np.allclose(updated_vector, _unit([1.0, 0.0])) + assert "patient" not in node_key + assert "wf_patient" not in node_key + assert node_key.endswith("__node_000") + + +def test_cross_session_learning_skips_without_existing_workflow(tmp_path): + processor = _processor(tmp_path, workflow=None) + + stats = processor._run_cross_session_learning( + session_id="sess_repeat", + states=[_state()], + embeddings=[_unit([0.9, 0.2])], + ) + + assert stats["updates"] == 0 + assert stats["skips"]["no_existing_workflow"] == 1 + processor._continuous_learner.update_prototype.assert_not_called() + + +def test_cross_session_learning_skips_embedding_count_mismatch(tmp_path): + workflow = _workflow("wf_existing", [1.0, 0.0]) + processor = _processor(tmp_path, workflow) + + stats = processor._run_cross_session_learning( + session_id="sess_repeat", + states=[_state(), _state()], + embeddings=[_unit([0.9, 0.2])], + ) + + assert stats["updates"] == 0 + assert stats["skips"]["embedding_count_mismatch"] == 1 + processor._continuous_learner.update_prototype.assert_not_called() + + +def test_cross_session_learning_skips_exact_prototype_noop(tmp_path): + workflow = _workflow("wf_existing", [1.0, 0.0]) + processor = _processor(tmp_path, workflow) + + stats = processor._run_cross_session_learning( + session_id="sess_repeat", + states=[_state()], + embeddings=[_unit([1.0, 0.0])], + ) + + assert stats["updates"] == 0 + assert stats["matches"] == 1 + assert stats["skips"]["same_as_existing_prototype"] == 1 + processor._continuous_learner.update_prototype.assert_not_called() + + +def test_cross_session_learning_uses_low_confidence_for_drift_only(tmp_path): + workflow = _workflow("wf_existing", [1.0, 0.0]) + processor = _processor(tmp_path, workflow) + observed = _unit([0.6, 0.8]) + + stats = processor._run_cross_session_learning( + session_id="sess_repeat", + states=[_state()], + embeddings=[observed], + ) + + assert stats["updates"] == 0 + assert stats["drift_checks"] == 1 + assert stats["skips"]["below_update_threshold"] == 1 + processor._continuous_learner.detect_drift.assert_called_once() + processor._continuous_learner.update_prototype.assert_not_called() + + +def test_learning_node_key_is_safe_and_does_not_leak_workflow_name(): + key = StreamProcessor._learning_node_key( + workflow_id="DPI patient DUPONT né 1970", + node_id="node/écran principal", + ) + + assert "DUPONT" not in key + assert "patient" not in key + assert "/" not in key + assert key.startswith("wf_") + assert "__node_" in key + assert key.isascii()