feat(p11): learn from offline cross-session matches
This commit is contained in:
@@ -2461,6 +2461,10 @@ class StreamProcessor:
|
||||
# Workflows construits (pour le matching)
|
||||
self._workflows: Dict[str, Any] = {}
|
||||
|
||||
# P1.1 : learner continu branché uniquement sur de vraies observations
|
||||
# cross-session, après matching offline contre un workflow existant.
|
||||
self._continuous_learner = None
|
||||
|
||||
# Shadow learning : dernier pattern UI détecté par session
|
||||
# Stocke {session_id: {"pattern": str, "ocr_text": str, "screen_state": obj, "shot_id": str}}
|
||||
self._pending_ui_patterns: Dict[str, Dict[str, Any]] = {}
|
||||
@@ -3005,6 +3009,14 @@ class StreamProcessor:
|
||||
except Exception as e2:
|
||||
return {"error": f"Erreur RawSession: {e2}"}
|
||||
|
||||
session_machine_id = getattr(session, "machine_id", None)
|
||||
cross_learning = self._run_cross_session_learning(
|
||||
session_id=session_id,
|
||||
states=states,
|
||||
embeddings=embeddings,
|
||||
machine_id=session_machine_id,
|
||||
)
|
||||
|
||||
# Construire le workflow via GraphBuilder
|
||||
try:
|
||||
from core.graph.graph_builder import GraphBuilder
|
||||
@@ -3086,6 +3098,7 @@ class StreamProcessor:
|
||||
"embeddings_indexed": len(embeddings),
|
||||
"saved_path": str(saved_path) if saved_path else None,
|
||||
"app_context": app_context,
|
||||
"cross_session_learning": cross_learning,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
@@ -3130,6 +3143,257 @@ class StreamProcessor:
|
||||
|
||||
return None
|
||||
|
||||
def _get_continuous_learner(self):
|
||||
"""Lazy init du ContinuousLearner existant."""
|
||||
if self._continuous_learner is not None:
|
||||
return self._continuous_learner
|
||||
from core.learning.continuous_learner import ContinuousLearner
|
||||
self._continuous_learner = ContinuousLearner()
|
||||
return self._continuous_learner
|
||||
|
||||
def _run_cross_session_learning(
|
||||
self,
|
||||
session_id: str,
|
||||
states: List[Any],
|
||||
embeddings: List[np.ndarray],
|
||||
machine_id: Optional[str] = None,
|
||||
*,
|
||||
update_threshold: float = 0.85,
|
||||
drift_min_confidence: float = 0.50,
|
||||
) -> Dict[str, Any]:
|
||||
"""Matcher une session observée contre les workflows existants.
|
||||
|
||||
P1.1 Option A : le learner ne reçoit jamais le prototype du node
|
||||
existant. Il reçoit uniquement un embedding observé de la session
|
||||
courante, si cet embedding matche fortement un node d'un workflow déjà
|
||||
connu. Les identifiants écrits par le learner sont hashés pour éviter
|
||||
de propager un nom de workflow potentiellement métier dans les chemins.
|
||||
"""
|
||||
stats: Dict[str, Any] = {
|
||||
"status": "skipped",
|
||||
"states_seen": len(states),
|
||||
"embeddings_seen": len(embeddings),
|
||||
"candidate_workflows": 0,
|
||||
"matches": 0,
|
||||
"updates": 0,
|
||||
"drift_checks": 0,
|
||||
"skips": {},
|
||||
}
|
||||
|
||||
def _skip(reason: str) -> Dict[str, Any]:
|
||||
stats["skips"][reason] = stats["skips"].get(reason, 0) + 1
|
||||
return stats
|
||||
|
||||
if not states:
|
||||
return _skip("no_states")
|
||||
if len(states) != len(embeddings):
|
||||
return _skip("embedding_count_mismatch")
|
||||
|
||||
with self._data_lock:
|
||||
workflows = list(self._workflows.values())
|
||||
|
||||
if not workflows:
|
||||
return _skip("no_existing_workflow")
|
||||
|
||||
stats["candidate_workflows"] = sum(
|
||||
1
|
||||
for workflow in workflows
|
||||
if not (
|
||||
machine_id
|
||||
and getattr(workflow, "_machine_id", None)
|
||||
and getattr(workflow, "_machine_id", None) != machine_id
|
||||
)
|
||||
)
|
||||
if stats["candidate_workflows"] == 0:
|
||||
return _skip("no_candidate_workflow_for_machine")
|
||||
|
||||
learner = self._get_continuous_learner()
|
||||
stats["status"] = "processed"
|
||||
|
||||
for state, observed_embedding in zip(states, embeddings):
|
||||
observed_vector = self._normalise_vector(observed_embedding)
|
||||
if observed_vector is None:
|
||||
_skip("invalid_observed_embedding")
|
||||
continue
|
||||
|
||||
match = self._find_best_cross_session_match(
|
||||
state=state,
|
||||
observed_vector=observed_vector,
|
||||
workflows=workflows,
|
||||
machine_id=machine_id,
|
||||
min_confidence=drift_min_confidence,
|
||||
)
|
||||
if match is None:
|
||||
_skip("no_match")
|
||||
continue
|
||||
|
||||
stats["matches"] += 1
|
||||
node_key = self._learning_node_key(
|
||||
workflow_id=match["workflow_id"],
|
||||
node_id=match["node_id"],
|
||||
)
|
||||
|
||||
confidence = float(match["confidence"])
|
||||
learner.detect_drift(node_key, [confidence])
|
||||
stats["drift_checks"] += 1
|
||||
|
||||
if confidence < update_threshold:
|
||||
_skip("below_update_threshold")
|
||||
continue
|
||||
|
||||
prototype = match.get("prototype")
|
||||
if prototype is not None and np.allclose(
|
||||
observed_vector, prototype, atol=1e-6
|
||||
):
|
||||
_skip("same_as_existing_prototype")
|
||||
continue
|
||||
|
||||
# Signal réel : observation acceptée uniquement parce qu'un match
|
||||
# cross-session dépasse le seuil d'update. Les confidences faibles
|
||||
# alimentent seulement la détection de drift ci-dessus.
|
||||
execution_success = confidence >= update_threshold
|
||||
if not execution_success:
|
||||
_skip("no_success_signal")
|
||||
continue
|
||||
|
||||
learner.update_prototype(
|
||||
node_key,
|
||||
observed_vector.copy(),
|
||||
execution_success=execution_success,
|
||||
)
|
||||
stats["updates"] += 1
|
||||
|
||||
logger.info(
|
||||
"P1.1 cross-session learning: states=%d workflows=%d matches=%d "
|
||||
"updates=%d drift_checks=%d skips=%s",
|
||||
stats["states_seen"],
|
||||
stats["candidate_workflows"],
|
||||
stats["matches"],
|
||||
stats["updates"],
|
||||
stats["drift_checks"],
|
||||
stats["skips"],
|
||||
)
|
||||
return stats
|
||||
|
||||
def _find_best_cross_session_match(
|
||||
self,
|
||||
state: Any,
|
||||
observed_vector: np.ndarray,
|
||||
workflows: List[Any],
|
||||
machine_id: Optional[str],
|
||||
min_confidence: float,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Retour le meilleur node existant pour un embedding observé."""
|
||||
best: Optional[Dict[str, Any]] = None
|
||||
for workflow in workflows:
|
||||
workflow_machine = getattr(workflow, "_machine_id", None)
|
||||
if machine_id and workflow_machine and workflow_machine != machine_id:
|
||||
continue
|
||||
|
||||
workflow_id = getattr(workflow, "workflow_id", "")
|
||||
for node in getattr(workflow, "nodes", []) or []:
|
||||
prototype = self._extract_node_prototype(node)
|
||||
if prototype is None or prototype.shape != observed_vector.shape:
|
||||
continue
|
||||
|
||||
confidence = self._cosine_similarity(observed_vector, prototype)
|
||||
if confidence < min_confidence:
|
||||
continue
|
||||
if not self._template_accepts_observation(node, state, confidence):
|
||||
continue
|
||||
|
||||
if best is None or confidence > best["confidence"]:
|
||||
best = {
|
||||
"workflow_id": workflow_id,
|
||||
"node_id": getattr(node, "node_id", ""),
|
||||
"confidence": confidence,
|
||||
"prototype": prototype,
|
||||
}
|
||||
|
||||
return best
|
||||
|
||||
def _extract_node_prototype(self, node: Any) -> Optional[np.ndarray]:
|
||||
"""Extraire le prototype d'un node sans dépendre de FAISS."""
|
||||
meta = getattr(node, "metadata", {}) or {}
|
||||
proto_list = meta.get("_prototype_vector")
|
||||
if isinstance(proto_list, list):
|
||||
return self._normalise_vector(proto_list)
|
||||
|
||||
template = getattr(node, "template", None)
|
||||
embedding = getattr(template, "embedding", None) if template else None
|
||||
vector_id = getattr(embedding, "vector_id", None) if embedding else None
|
||||
if vector_id:
|
||||
try:
|
||||
path = Path(vector_id)
|
||||
if path.exists():
|
||||
return self._normalise_vector(np.load(path))
|
||||
except Exception as exc:
|
||||
logger.debug("Prototype node illisible, skip: %s", exc)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _normalise_vector(vector: Any) -> Optional[np.ndarray]:
|
||||
try:
|
||||
arr = np.asarray(vector, dtype=np.float32)
|
||||
except Exception:
|
||||
return None
|
||||
if arr.ndim != 1 or arr.size == 0:
|
||||
return None
|
||||
norm = float(np.linalg.norm(arr))
|
||||
if norm <= 0:
|
||||
return None
|
||||
return arr / norm
|
||||
|
||||
@staticmethod
|
||||
def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
||||
return float(np.clip(np.dot(a, b), -1.0, 1.0))
|
||||
|
||||
@staticmethod
|
||||
def _template_accepts_observation(
|
||||
node: Any,
|
||||
state: Any,
|
||||
confidence: float,
|
||||
) -> bool:
|
||||
template = getattr(node, "template", None)
|
||||
if template is None:
|
||||
return True
|
||||
try:
|
||||
window = getattr(template, "window", None)
|
||||
if window and hasattr(state, "window"):
|
||||
window_title = getattr(state.window, "window_title", "")
|
||||
process = getattr(state.window, "process", "")
|
||||
app_name = getattr(state.window, "app_name", "")
|
||||
if not window.matches(window_title, process or app_name):
|
||||
return False
|
||||
|
||||
text = getattr(template, "text", None)
|
||||
if text and hasattr(state, "perception"):
|
||||
detected = getattr(state.perception, "detected_text", [])
|
||||
if not text.matches(detected):
|
||||
return False
|
||||
|
||||
ui = getattr(template, "ui", None)
|
||||
if ui and hasattr(state, "ui_elements"):
|
||||
if not ui.matches(getattr(state, "ui_elements", [])):
|
||||
return False
|
||||
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.debug("Template match offline impossible, skip node: %s", exc)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _learning_node_key(workflow_id: str, node_id: str) -> str:
|
||||
"""Clé learner stable sans fuite de nom workflow potentiellement métier."""
|
||||
digest = hashlib.sha256(f"{workflow_id}:{node_id}".encode("utf-8")).hexdigest()
|
||||
safe_node = "".join(
|
||||
ch if ch.isascii() and (ch.isalnum() or ch in "._-") else "_"
|
||||
for ch in str(node_id)
|
||||
).strip("._-")
|
||||
if not safe_node:
|
||||
safe_node = "node"
|
||||
return f"wf_{digest[:16]}__{safe_node[:64]}"
|
||||
|
||||
# =========================================================================
|
||||
# Enrichissement VLM des workflows (target_spec sur chaque edge)
|
||||
# =========================================================================
|
||||
|
||||
207
tests/unit/test_stream_processor_cross_session_learning.py
Normal file
207
tests/unit/test_stream_processor_cross_session_learning.py
Normal file
@@ -0,0 +1,207 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
|
||||
from agent_v0.server_v1.stream_processor import StreamProcessor
|
||||
from core.models.screen_state import (
|
||||
ContextLevel,
|
||||
EmbeddingRef,
|
||||
PerceptionLevel,
|
||||
RawLevel,
|
||||
ScreenState,
|
||||
WindowContext,
|
||||
)
|
||||
from core.models.workflow_graph import (
|
||||
EmbeddingPrototype,
|
||||
LearningConfig,
|
||||
SafetyRules,
|
||||
ScreenTemplate,
|
||||
TextConstraint,
|
||||
UIConstraint,
|
||||
WindowConstraint,
|
||||
Workflow,
|
||||
WorkflowNode,
|
||||
WorkflowStats,
|
||||
)
|
||||
|
||||
|
||||
def _unit(vector) -> np.ndarray:
|
||||
arr = np.asarray(vector, dtype=np.float32)
|
||||
return arr / np.linalg.norm(arr)
|
||||
|
||||
|
||||
def _state(window_title: str = "Application") -> ScreenState:
|
||||
return ScreenState(
|
||||
screen_state_id="state_obs",
|
||||
timestamp=datetime.now(),
|
||||
session_id="sess_obs",
|
||||
window=WindowContext(
|
||||
app_name="app",
|
||||
window_title=window_title,
|
||||
screen_resolution=[1920, 1080],
|
||||
),
|
||||
raw=RawLevel(
|
||||
screenshot_path="",
|
||||
capture_method="test",
|
||||
file_size_bytes=0,
|
||||
),
|
||||
perception=PerceptionLevel(
|
||||
embedding=EmbeddingRef(provider="test", vector_id="", dimensions=2),
|
||||
detected_text=[],
|
||||
text_detection_method="test",
|
||||
confidence_avg=1.0,
|
||||
),
|
||||
context=ContextLevel(),
|
||||
ui_elements=[],
|
||||
)
|
||||
|
||||
|
||||
def _workflow(workflow_id: str, prototype) -> Workflow:
|
||||
embedding = EmbeddingPrototype(
|
||||
provider="test",
|
||||
vector_id="",
|
||||
min_cosine_similarity=0.85,
|
||||
sample_count=1,
|
||||
)
|
||||
template = ScreenTemplate(
|
||||
window=WindowConstraint(),
|
||||
text=TextConstraint(),
|
||||
ui=UIConstraint(),
|
||||
embedding=embedding,
|
||||
)
|
||||
node = WorkflowNode(
|
||||
node_id="node_000",
|
||||
name="Node 000",
|
||||
description="test node",
|
||||
template=template,
|
||||
metadata={"_prototype_vector": _unit(prototype).tolist()},
|
||||
)
|
||||
return Workflow(
|
||||
workflow_id=workflow_id,
|
||||
name="Existing workflow",
|
||||
description="test workflow",
|
||||
version=1,
|
||||
learning_state="OBSERVATION",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
entry_nodes=["node_000"],
|
||||
end_nodes=[],
|
||||
nodes=[node],
|
||||
edges=[],
|
||||
safety_rules=SafetyRules(),
|
||||
stats=WorkflowStats(),
|
||||
learning=LearningConfig(),
|
||||
)
|
||||
|
||||
|
||||
def _processor(tmp_path, workflow: Workflow | None = None) -> StreamProcessor:
|
||||
processor = StreamProcessor(data_dir=str(tmp_path / "training"))
|
||||
processor._initialized = True
|
||||
processor._continuous_learner = MagicMock()
|
||||
processor._workflows = {}
|
||||
if workflow is not None:
|
||||
processor._workflows[workflow.workflow_id] = workflow
|
||||
return processor
|
||||
|
||||
|
||||
def test_cross_session_learning_updates_with_observed_embedding(tmp_path):
|
||||
workflow = _workflow("wf_patient_name_must_not_leak", [1.0, 0.0])
|
||||
processor = _processor(tmp_path, workflow)
|
||||
observed = _unit([0.9, 0.2])
|
||||
|
||||
stats = processor._run_cross_session_learning(
|
||||
session_id="sess_repeat",
|
||||
states=[_state()],
|
||||
embeddings=[observed],
|
||||
)
|
||||
|
||||
assert stats["updates"] == 1
|
||||
processor._continuous_learner.update_prototype.assert_called_once()
|
||||
node_key, updated_vector = processor._continuous_learner.update_prototype.call_args.args
|
||||
kwargs = processor._continuous_learner.update_prototype.call_args.kwargs
|
||||
assert kwargs["execution_success"] is True
|
||||
assert np.allclose(updated_vector, observed)
|
||||
assert not np.allclose(updated_vector, _unit([1.0, 0.0]))
|
||||
assert "patient" not in node_key
|
||||
assert "wf_patient" not in node_key
|
||||
assert node_key.endswith("__node_000")
|
||||
|
||||
|
||||
def test_cross_session_learning_skips_without_existing_workflow(tmp_path):
|
||||
processor = _processor(tmp_path, workflow=None)
|
||||
|
||||
stats = processor._run_cross_session_learning(
|
||||
session_id="sess_repeat",
|
||||
states=[_state()],
|
||||
embeddings=[_unit([0.9, 0.2])],
|
||||
)
|
||||
|
||||
assert stats["updates"] == 0
|
||||
assert stats["skips"]["no_existing_workflow"] == 1
|
||||
processor._continuous_learner.update_prototype.assert_not_called()
|
||||
|
||||
|
||||
def test_cross_session_learning_skips_embedding_count_mismatch(tmp_path):
|
||||
workflow = _workflow("wf_existing", [1.0, 0.0])
|
||||
processor = _processor(tmp_path, workflow)
|
||||
|
||||
stats = processor._run_cross_session_learning(
|
||||
session_id="sess_repeat",
|
||||
states=[_state(), _state()],
|
||||
embeddings=[_unit([0.9, 0.2])],
|
||||
)
|
||||
|
||||
assert stats["updates"] == 0
|
||||
assert stats["skips"]["embedding_count_mismatch"] == 1
|
||||
processor._continuous_learner.update_prototype.assert_not_called()
|
||||
|
||||
|
||||
def test_cross_session_learning_skips_exact_prototype_noop(tmp_path):
|
||||
workflow = _workflow("wf_existing", [1.0, 0.0])
|
||||
processor = _processor(tmp_path, workflow)
|
||||
|
||||
stats = processor._run_cross_session_learning(
|
||||
session_id="sess_repeat",
|
||||
states=[_state()],
|
||||
embeddings=[_unit([1.0, 0.0])],
|
||||
)
|
||||
|
||||
assert stats["updates"] == 0
|
||||
assert stats["matches"] == 1
|
||||
assert stats["skips"]["same_as_existing_prototype"] == 1
|
||||
processor._continuous_learner.update_prototype.assert_not_called()
|
||||
|
||||
|
||||
def test_cross_session_learning_uses_low_confidence_for_drift_only(tmp_path):
|
||||
workflow = _workflow("wf_existing", [1.0, 0.0])
|
||||
processor = _processor(tmp_path, workflow)
|
||||
observed = _unit([0.6, 0.8])
|
||||
|
||||
stats = processor._run_cross_session_learning(
|
||||
session_id="sess_repeat",
|
||||
states=[_state()],
|
||||
embeddings=[observed],
|
||||
)
|
||||
|
||||
assert stats["updates"] == 0
|
||||
assert stats["drift_checks"] == 1
|
||||
assert stats["skips"]["below_update_threshold"] == 1
|
||||
processor._continuous_learner.detect_drift.assert_called_once()
|
||||
processor._continuous_learner.update_prototype.assert_not_called()
|
||||
|
||||
|
||||
def test_learning_node_key_is_safe_and_does_not_leak_workflow_name():
|
||||
key = StreamProcessor._learning_node_key(
|
||||
workflow_id="DPI patient DUPONT né 1970",
|
||||
node_id="node/écran principal",
|
||||
)
|
||||
|
||||
assert "DUPONT" not in key
|
||||
assert "patient" not in key
|
||||
assert "/" not in key
|
||||
assert key.startswith("wf_")
|
||||
assert "__node_" in key
|
||||
assert key.isascii()
|
||||
Reference in New Issue
Block a user