feat(p11): learn from offline cross-session matches

This commit is contained in:
Dom
2026-06-02 17:46:15 +02:00
parent 4b3d5ce0d7
commit 5289f3de48
2 changed files with 471 additions and 0 deletions

View File

@@ -0,0 +1,207 @@
from __future__ import annotations
from datetime import datetime
from unittest.mock import MagicMock
import numpy as np
from agent_v0.server_v1.stream_processor import StreamProcessor
from core.models.screen_state import (
ContextLevel,
EmbeddingRef,
PerceptionLevel,
RawLevel,
ScreenState,
WindowContext,
)
from core.models.workflow_graph import (
EmbeddingPrototype,
LearningConfig,
SafetyRules,
ScreenTemplate,
TextConstraint,
UIConstraint,
WindowConstraint,
Workflow,
WorkflowNode,
WorkflowStats,
)
def _unit(vector) -> np.ndarray:
    """Return *vector* as a float32 array scaled to unit Euclidean norm.

    Args:
        vector: Any array-like accepted by ``np.asarray``.

    Returns:
        A ``float32`` array pointing in the same direction as *vector*.

    Raises:
        ValueError: If *vector* has zero norm. Dividing by zero would
            otherwise yield NaNs that silently poison every similarity
            computed from the fixture.
    """
    arr = np.asarray(vector, dtype=np.float32)
    norm = np.linalg.norm(arr)
    if norm == 0:
        raise ValueError("cannot normalize a zero-norm vector")
    return arr / norm
def _state(window_title: str = "Application") -> ScreenState:
    """Build a placeholder ``ScreenState`` for the cross-session tests.

    Only the window title is configurable; every other field is a fixed
    stub value (empty screenshot path, no detected text, no UI elements).

    Args:
        window_title: Title stored in the state's ``WindowContext``.

    Returns:
        A ``ScreenState`` with id ``"state_obs"`` in session ``"sess_obs"``.
    """
    captured_at = datetime.now()
    window_ctx = WindowContext(
        app_name="app",
        window_title=window_title,
        screen_resolution=[1920, 1080],
    )
    raw_level = RawLevel(
        screenshot_path="",
        capture_method="test",
        file_size_bytes=0,
    )
    # The embedding reference is a stub: the tests pass the actual vectors
    # separately through the ``embeddings`` argument.
    embedding_ref = EmbeddingRef(provider="test", vector_id="", dimensions=2)
    perception_level = PerceptionLevel(
        embedding=embedding_ref,
        detected_text=[],
        text_detection_method="test",
        confidence_avg=1.0,
    )
    return ScreenState(
        screen_state_id="state_obs",
        timestamp=captured_at,
        session_id="sess_obs",
        window=window_ctx,
        raw=raw_level,
        perception=perception_level,
        context=ContextLevel(),
        ui_elements=[],
    )
def _workflow(workflow_id: str, prototype) -> Workflow:
    """Build a one-node ``Workflow`` whose node carries a prototype vector.

    Args:
        workflow_id: Identifier assigned to the returned workflow.
        prototype: Array-like vector; it is normalized with ``_unit`` and
            stored as a plain list under the node's ``_prototype_vector``
            metadata key.

    Returns:
        A version-1 workflow in the ``"OBSERVATION"`` learning state with a
        single entry node ``"node_000"`` and no edges.
    """
    template = ScreenTemplate(
        window=WindowConstraint(),
        text=TextConstraint(),
        ui=UIConstraint(),
        embedding=EmbeddingPrototype(
            provider="test",
            vector_id="",
            min_cosine_similarity=0.85,
            sample_count=1,
        ),
    )
    node_metadata = {"_prototype_vector": _unit(prototype).tolist()}
    only_node = WorkflowNode(
        node_id="node_000",
        name="Node 000",
        description="test node",
        template=template,
        metadata=node_metadata,
    )
    return Workflow(
        workflow_id=workflow_id,
        name="Existing workflow",
        description="test workflow",
        version=1,
        learning_state="OBSERVATION",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        entry_nodes=["node_000"],
        end_nodes=[],
        nodes=[only_node],
        edges=[],
        safety_rules=SafetyRules(),
        stats=WorkflowStats(),
        learning=LearningConfig(),
    )
def _processor(tmp_path, workflow: Workflow | None = None) -> StreamProcessor:
    """Create a ``StreamProcessor`` wired up for isolated learning tests.

    The processor stores data under ``tmp_path / "training"``, has its
    ``_initialized`` flag forced to ``True``, and gets a ``MagicMock`` in
    place of its continuous learner so calls can be inspected.

    Args:
        tmp_path: Base directory (pytest's ``tmp_path`` fixture in the tests).
        workflow: Optional workflow to pre-register under its own id; when
            omitted the processor starts with no workflows.

    Returns:
        The configured processor.
    """
    proc = StreamProcessor(data_dir=str(tmp_path / "training"))
    proc._initialized = True
    proc._continuous_learner = MagicMock()
    proc._workflows = {} if workflow is None else {workflow.workflow_id: workflow}
    return proc
def test_cross_session_learning_updates_with_observed_embedding(tmp_path):
    """A close match triggers one prototype update using the observed vector.

    Also checks that the key handed to the learner does not expose the
    workflow id text and ends with the node suffix.
    """
    existing = _workflow("wf_patient_name_must_not_leak", [1.0, 0.0])
    proc = _processor(tmp_path, existing)
    seen_vector = _unit([0.9, 0.2])

    result = proc._run_cross_session_learning(
        session_id="sess_repeat",
        states=[_state()],
        embeddings=[seen_vector],
    )

    assert result["updates"] == 1
    update_mock = proc._continuous_learner.update_prototype
    update_mock.assert_called_once()

    recorded_call = update_mock.call_args
    node_key, updated_vector = recorded_call.args
    assert recorded_call.kwargs["execution_success"] is True

    # The learner must receive the observation itself, not the stored prototype.
    assert np.allclose(updated_vector, seen_vector)
    assert not np.allclose(updated_vector, _unit([1.0, 0.0]))

    assert "patient" not in node_key
    assert "wf_patient" not in node_key
    assert node_key.endswith("__node_000")
def test_cross_session_learning_skips_without_existing_workflow(tmp_path):
    """With no registered workflow, nothing is updated and the skip is counted."""
    proc = _processor(tmp_path, workflow=None)
    observations = [_unit([0.9, 0.2])]

    result = proc._run_cross_session_learning(
        session_id="sess_repeat",
        states=[_state()],
        embeddings=observations,
    )

    assert result["updates"] == 0
    assert result["skips"]["no_existing_workflow"] == 1
    proc._continuous_learner.update_prototype.assert_not_called()
def test_cross_session_learning_skips_embedding_count_mismatch(tmp_path):
    """Two states with a single embedding are skipped as a count mismatch."""
    proc = _processor(tmp_path, _workflow("wf_existing", [1.0, 0.0]))
    two_states = [_state(), _state()]
    one_embedding = [_unit([0.9, 0.2])]

    result = proc._run_cross_session_learning(
        session_id="sess_repeat",
        states=two_states,
        embeddings=one_embedding,
    )

    assert result["updates"] == 0
    assert result["skips"]["embedding_count_mismatch"] == 1
    proc._continuous_learner.update_prototype.assert_not_called()
def test_cross_session_learning_skips_exact_prototype_noop(tmp_path):
    """An observation equal to the stored prototype matches but is not re-learned."""
    proc = _processor(tmp_path, _workflow("wf_existing", [1.0, 0.0]))
    same_as_prototype = _unit([1.0, 0.0])

    result = proc._run_cross_session_learning(
        session_id="sess_repeat",
        states=[_state()],
        embeddings=[same_as_prototype],
    )

    assert result["updates"] == 0
    assert result["matches"] == 1
    assert result["skips"]["same_as_existing_prototype"] == 1
    proc._continuous_learner.update_prototype.assert_not_called()
def test_cross_session_learning_uses_low_confidence_for_drift_only(tmp_path):
    """A distant observation feeds drift detection but never a prototype update."""
    proc = _processor(tmp_path, _workflow("wf_existing", [1.0, 0.0]))
    # Cosine similarity with the [1, 0] prototype is 0.6 for this vector.
    distant_vector = _unit([0.6, 0.8])

    result = proc._run_cross_session_learning(
        session_id="sess_repeat",
        states=[_state()],
        embeddings=[distant_vector],
    )

    assert result["updates"] == 0
    assert result["drift_checks"] == 1
    assert result["skips"]["below_update_threshold"] == 1

    learner = proc._continuous_learner
    learner.detect_drift.assert_called_once()
    learner.update_prototype.assert_not_called()
def test_learning_node_key_is_safe_and_does_not_leak_workflow_name():
    """The learning key hides identifying text and stays ASCII and slash-free."""
    sensitive_workflow_id = "DPI patient DUPONT né 1970"
    awkward_node_id = "node/écran principal"

    key = StreamProcessor._learning_node_key(
        workflow_id=sensitive_workflow_id,
        node_id=awkward_node_id,
    )

    # None of the identifying fragments or the path separator may survive.
    for forbidden in ("DUPONT", "patient", "/"):
        assert forbidden not in key

    assert key.startswith("wf_")
    assert "__node_" in key
    assert key.isascii()