feat(p11): learn from offline cross-session matches
This commit is contained in:
207
tests/unit/test_stream_processor_cross_session_learning.py
Normal file
207
tests/unit/test_stream_processor_cross_session_learning.py
Normal file
@@ -0,0 +1,207 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
|
||||
from agent_v0.server_v1.stream_processor import StreamProcessor
|
||||
from core.models.screen_state import (
|
||||
ContextLevel,
|
||||
EmbeddingRef,
|
||||
PerceptionLevel,
|
||||
RawLevel,
|
||||
ScreenState,
|
||||
WindowContext,
|
||||
)
|
||||
from core.models.workflow_graph import (
|
||||
EmbeddingPrototype,
|
||||
LearningConfig,
|
||||
SafetyRules,
|
||||
ScreenTemplate,
|
||||
TextConstraint,
|
||||
UIConstraint,
|
||||
WindowConstraint,
|
||||
Workflow,
|
||||
WorkflowNode,
|
||||
WorkflowStats,
|
||||
)
|
||||
|
||||
|
||||
def _unit(vector) -> np.ndarray:
|
||||
arr = np.asarray(vector, dtype=np.float32)
|
||||
return arr / np.linalg.norm(arr)
|
||||
|
||||
|
||||
def _state(window_title: str = "Application") -> ScreenState:
|
||||
return ScreenState(
|
||||
screen_state_id="state_obs",
|
||||
timestamp=datetime.now(),
|
||||
session_id="sess_obs",
|
||||
window=WindowContext(
|
||||
app_name="app",
|
||||
window_title=window_title,
|
||||
screen_resolution=[1920, 1080],
|
||||
),
|
||||
raw=RawLevel(
|
||||
screenshot_path="",
|
||||
capture_method="test",
|
||||
file_size_bytes=0,
|
||||
),
|
||||
perception=PerceptionLevel(
|
||||
embedding=EmbeddingRef(provider="test", vector_id="", dimensions=2),
|
||||
detected_text=[],
|
||||
text_detection_method="test",
|
||||
confidence_avg=1.0,
|
||||
),
|
||||
context=ContextLevel(),
|
||||
ui_elements=[],
|
||||
)
|
||||
|
||||
|
||||
def _workflow(workflow_id: str, prototype) -> Workflow:
|
||||
embedding = EmbeddingPrototype(
|
||||
provider="test",
|
||||
vector_id="",
|
||||
min_cosine_similarity=0.85,
|
||||
sample_count=1,
|
||||
)
|
||||
template = ScreenTemplate(
|
||||
window=WindowConstraint(),
|
||||
text=TextConstraint(),
|
||||
ui=UIConstraint(),
|
||||
embedding=embedding,
|
||||
)
|
||||
node = WorkflowNode(
|
||||
node_id="node_000",
|
||||
name="Node 000",
|
||||
description="test node",
|
||||
template=template,
|
||||
metadata={"_prototype_vector": _unit(prototype).tolist()},
|
||||
)
|
||||
return Workflow(
|
||||
workflow_id=workflow_id,
|
||||
name="Existing workflow",
|
||||
description="test workflow",
|
||||
version=1,
|
||||
learning_state="OBSERVATION",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
entry_nodes=["node_000"],
|
||||
end_nodes=[],
|
||||
nodes=[node],
|
||||
edges=[],
|
||||
safety_rules=SafetyRules(),
|
||||
stats=WorkflowStats(),
|
||||
learning=LearningConfig(),
|
||||
)
|
||||
|
||||
|
||||
def _processor(tmp_path, workflow: Workflow | None = None) -> StreamProcessor:
|
||||
processor = StreamProcessor(data_dir=str(tmp_path / "training"))
|
||||
processor._initialized = True
|
||||
processor._continuous_learner = MagicMock()
|
||||
processor._workflows = {}
|
||||
if workflow is not None:
|
||||
processor._workflows[workflow.workflow_id] = workflow
|
||||
return processor
|
||||
|
||||
|
||||
def test_cross_session_learning_updates_with_observed_embedding(tmp_path):
|
||||
workflow = _workflow("wf_patient_name_must_not_leak", [1.0, 0.0])
|
||||
processor = _processor(tmp_path, workflow)
|
||||
observed = _unit([0.9, 0.2])
|
||||
|
||||
stats = processor._run_cross_session_learning(
|
||||
session_id="sess_repeat",
|
||||
states=[_state()],
|
||||
embeddings=[observed],
|
||||
)
|
||||
|
||||
assert stats["updates"] == 1
|
||||
processor._continuous_learner.update_prototype.assert_called_once()
|
||||
node_key, updated_vector = processor._continuous_learner.update_prototype.call_args.args
|
||||
kwargs = processor._continuous_learner.update_prototype.call_args.kwargs
|
||||
assert kwargs["execution_success"] is True
|
||||
assert np.allclose(updated_vector, observed)
|
||||
assert not np.allclose(updated_vector, _unit([1.0, 0.0]))
|
||||
assert "patient" not in node_key
|
||||
assert "wf_patient" not in node_key
|
||||
assert node_key.endswith("__node_000")
|
||||
|
||||
|
||||
def test_cross_session_learning_skips_without_existing_workflow(tmp_path):
|
||||
processor = _processor(tmp_path, workflow=None)
|
||||
|
||||
stats = processor._run_cross_session_learning(
|
||||
session_id="sess_repeat",
|
||||
states=[_state()],
|
||||
embeddings=[_unit([0.9, 0.2])],
|
||||
)
|
||||
|
||||
assert stats["updates"] == 0
|
||||
assert stats["skips"]["no_existing_workflow"] == 1
|
||||
processor._continuous_learner.update_prototype.assert_not_called()
|
||||
|
||||
|
||||
def test_cross_session_learning_skips_embedding_count_mismatch(tmp_path):
|
||||
workflow = _workflow("wf_existing", [1.0, 0.0])
|
||||
processor = _processor(tmp_path, workflow)
|
||||
|
||||
stats = processor._run_cross_session_learning(
|
||||
session_id="sess_repeat",
|
||||
states=[_state(), _state()],
|
||||
embeddings=[_unit([0.9, 0.2])],
|
||||
)
|
||||
|
||||
assert stats["updates"] == 0
|
||||
assert stats["skips"]["embedding_count_mismatch"] == 1
|
||||
processor._continuous_learner.update_prototype.assert_not_called()
|
||||
|
||||
|
||||
def test_cross_session_learning_skips_exact_prototype_noop(tmp_path):
|
||||
workflow = _workflow("wf_existing", [1.0, 0.0])
|
||||
processor = _processor(tmp_path, workflow)
|
||||
|
||||
stats = processor._run_cross_session_learning(
|
||||
session_id="sess_repeat",
|
||||
states=[_state()],
|
||||
embeddings=[_unit([1.0, 0.0])],
|
||||
)
|
||||
|
||||
assert stats["updates"] == 0
|
||||
assert stats["matches"] == 1
|
||||
assert stats["skips"]["same_as_existing_prototype"] == 1
|
||||
processor._continuous_learner.update_prototype.assert_not_called()
|
||||
|
||||
|
||||
def test_cross_session_learning_uses_low_confidence_for_drift_only(tmp_path):
|
||||
workflow = _workflow("wf_existing", [1.0, 0.0])
|
||||
processor = _processor(tmp_path, workflow)
|
||||
observed = _unit([0.6, 0.8])
|
||||
|
||||
stats = processor._run_cross_session_learning(
|
||||
session_id="sess_repeat",
|
||||
states=[_state()],
|
||||
embeddings=[observed],
|
||||
)
|
||||
|
||||
assert stats["updates"] == 0
|
||||
assert stats["drift_checks"] == 1
|
||||
assert stats["skips"]["below_update_threshold"] == 1
|
||||
processor._continuous_learner.detect_drift.assert_called_once()
|
||||
processor._continuous_learner.update_prototype.assert_not_called()
|
||||
|
||||
|
||||
def test_learning_node_key_is_safe_and_does_not_leak_workflow_name():
|
||||
key = StreamProcessor._learning_node_key(
|
||||
workflow_id="DPI patient DUPONT né 1970",
|
||||
node_id="node/écran principal",
|
||||
)
|
||||
|
||||
assert "DUPONT" not in key
|
||||
assert "patient" not in key
|
||||
assert "/" not in key
|
||||
assert key.startswith("wf_")
|
||||
assert "__node_" in key
|
||||
assert key.isascii()
|
||||
Reference in New Issue
Block a user