feat(agent): add learn action flow and grounding guards

This commit is contained in:
Dom
2026-06-02 16:24:10 +02:00
parent 86b3c8f7e7
commit d38f0b0f2f
39 changed files with 5901 additions and 212 deletions

View File

@@ -0,0 +1,254 @@
"""Tests integration pour agent_chat.handlers.learn_action.
Mocks HTTP uniquement — pas de lancement du streaming server réel.
"""
from __future__ import annotations
import json
from unittest.mock import MagicMock, patch
import pytest
from agent_chat.handlers.learn_action import (
LearnActionOrchestrator,
LearnIntentParser,
LearnState,
StateStore,
StreamingClient,
)
@pytest.fixture
def fake_http_client():
    """Provide a ``MagicMock`` standing in for an ``httpx.Client``.

    Tests program its ``request`` attribute (``return_value`` or
    ``side_effect``) to script the HTTP exchanges they need.
    """
    return MagicMock()
def _mk_response(status: int = 200, body: dict | None = None):
resp = MagicMock()
resp.status_code = status
resp.json.return_value = body or {}
resp.text = json.dumps(body or {})
return resp
class TestStreamingClient:
    """Checks of ``StreamingClient`` against a mocked HTTP client."""

    @staticmethod
    def _client(http_client, *, token, retries):
        """Create a ``StreamingClient`` bound to the mocked *http_client*."""
        return StreamingClient(
            base_url="http://stream:5005",
            token=token,
            http_client=http_client,
            retries=retries,
        )

    def test_shadow_start_calls_correct_endpoint(self, fake_http_client):
        """``shadow_start`` POSTs to the shadow/start URL with a bearer token."""
        fake_http_client.request.return_value = _mk_response(200, {"ok": True})
        streaming = self._client(fake_http_client, token="abc", retries=0)

        result = streaming.shadow_start("sid_xyz", user_id="dom")

        assert result == {"ok": True}
        # call_args unpacks into (positional args, keyword args).
        positional, keywords = fake_http_client.request.call_args
        assert positional[0] == "POST"
        assert positional[1] == "http://stream:5005/api/v1/shadow/start"
        assert keywords["json"]["session_id"] == "sid_xyz"
        sent_headers = keywords["headers"]
        assert "Authorization" in sent_headers
        assert sent_headers["Authorization"] == "Bearer abc"

    def test_retry_on_failure(self, fake_http_client):
        """One failed attempt followed by a success yields the success body."""
        # First attempt raises, second attempt succeeds.
        scripted = [Exception("conn refused"), _mk_response(200, {"ok": True})]
        fake_http_client.request.side_effect = scripted
        streaming = self._client(fake_http_client, token="", retries=1)

        result = streaming.shadow_stop("sid")

        assert result == {"ok": True}
        assert fake_http_client.request.call_count == 2

    def test_retry_exhausted_raises(self, fake_http_client):
        """When every attempt fails, a ``RuntimeError`` mentioning "unreachable" is raised."""
        fake_http_client.request.side_effect = Exception("boom")
        streaming = self._client(fake_http_client, token="", retries=2)

        with pytest.raises(RuntimeError, match="unreachable"):
            streaming.shadow_stop("sid")
class TestFullFlowIntegration:
    """Drive ``LearnActionOrchestrator`` through whole conversations over mocked HTTP."""

    @staticmethod
    def _orchestrator(http_client, store_dir, *, token, retries):
        """Wire an orchestrator to a ``StreamingClient`` backed by *http_client*."""
        streaming = StreamingClient(
            base_url="http://stream:5005",
            token=token,
            http_client=http_client,
            retries=retries,
        )
        return LearnActionOrchestrator(
            streaming_client=streaming,
            intent_parser=LearnIntentParser(use_llm_fallback=False),
            state_store=StateStore(store_dir),
            emit=MagicMock(),
        )

    def test_end_to_end_with_http_mock(self, tmp_path, fake_http_client):
        """Walk a session from start to DONE with every HTTP call scripted."""
        understanding_body = {
            "understanding": [
                {"action_type": "click", "target_label": "Patient", "widget_type": "Fenêtre"},
                {
                    "action_type": "type",
                    "target_label": "IPP",
                    "widget_type": "Champ",
                    "value": "25003284",
                },
            ]
        }
        # Scripted HTTP sequence: start, stop, understanding, build, persist.
        scripted_bodies = [
            {"ok": True},  # shadow_start
            {"ok": True},  # shadow_stop
            understanding_body,  # shadow_understanding
            {"ok": True},  # shadow_build
            {"slug": "facture_urg"},  # persist
        ]
        fake_http_client.request.side_effect = [
            _mk_response(200, body) for body in scripted_bodies
        ]
        orch = self._orchestrator(fake_http_client, tmp_path, token="t", retries=0)

        session, _ = orch.start_session(user_id="dom", machine_id="m1")
        sid = session.session_id
        assert session.state == LearnState.WAITING_USER_STOP

        # The user signals the recording is over.
        orch.handle_chat_message(sid, "c'est bon")
        assert orch._sessions[sid].state == LearnState.ITERATING_FEEDBACK

        # Global validation of the understanding.
        orch.handle_chat_message(sid, "parfait")
        assert orch._sessions[sid].state == LearnState.NAMING

        # Name of the learned action.
        orch.handle_chat_message(sid, "facturation urgences")

        # Mark IPP as a parameter.
        final_reply = orch.handle_chat_message(sid, "ça change à chaque fois")
        assert orch._sessions[sid].state == LearnState.DONE
        assert "facture_urg" in (final_reply or "")

    def test_streaming_down_during_stop(self, tmp_path, fake_http_client):
        """A ``shadow_stop`` that keeps failing produces a "retry" style reply."""
        # shadow_start succeeds, then every shadow_stop attempt raises.
        fake_http_client.request.side_effect = [
            _mk_response(200, {"ok": True}),  # shadow_start
            Exception("boom 1"),  # shadow_stop attempt 1
            Exception("boom 2"),  # shadow_stop attempt 2 (retry)
            Exception("boom 3"),  # shadow_stop attempt 3 (retry)
        ]
        orch = self._orchestrator(fake_http_client, tmp_path, token="", retries=2)

        session, _ = orch.start_session(user_id="dom")
        reply = orch.handle_chat_message(session.session_id, "stop")

        text = reply or ""
        assert "n'arrive pas à clôturer" in text or "réessaie" in text.lower()
# ============================================================
# POST /api/learn/start (Correction #4)
# ============================================================
class TestApiLearnStart:
    """Integration tests for the HTTP route ``POST /api/learn/start``."""

    def _make_orchestrator(self, tmp_path):
        """Build an orchestrator whose streaming calls all answer ``{"ok": True}``."""
        client_http = MagicMock()
        client_http.request.return_value = _mk_response(200, {"ok": True})
        stream = StreamingClient(
            base_url="http://stream:5005",
            token="",
            http_client=client_http,
            retries=0,
        )
        return LearnActionOrchestrator(
            streaming_client=stream,
            intent_parser=LearnIntentParser(use_llm_fallback=False),
            state_store=StateStore(tmp_path),
            emit=MagicMock(),
        )

    @staticmethod
    def _post_start(orchestrator, payload):
        """POST *payload* to ``/api/learn/start`` with *orchestrator* installed.

        ``app_module.learn_action_orchestrator`` is swapped only for the
        duration of the request. Whatever value was there before is put back
        afterwards (rather than being forced to ``None``), so these tests do
        not clobber module state that other tests may rely on.

        Args:
            orchestrator: Object to install, or ``None`` to simulate an
                uninitialized orchestrator.
            payload: JSON body of the request.

        Returns:
            The response returned by the app's test client.
        """
        from agent_chat import app as app_module

        previous = app_module.learn_action_orchestrator
        app_module.learn_action_orchestrator = orchestrator
        try:
            return app_module.app.test_client().post("/api/learn/start", json=payload)
        finally:
            app_module.learn_action_orchestrator = previous

    def test_api_learn_start_creates_session(self, tmp_path):
        """A valid payload returns 200 and registers the session."""
        orch = self._make_orchestrator(tmp_path)
        resp = self._post_start(
            orch,
            {
                "machine_id": "DESKTOP-58D5CAC_windows",
                "user_id": "dom",
                "trigger_source": "windows_button",
            },
        )
        assert resp.status_code == 200
        data = resp.get_json()
        assert "session_id" in data
        assert data["state"] == LearnState.WAITING_USER_STOP.value
        assert data["message"]
        # The session must also exist on the orchestrator side.
        sid = data["session_id"]
        assert orch._sessions[sid].machine_id == "DESKTOP-58D5CAC_windows"
        assert orch._sessions[sid].trigger_source == "windows_button"

    def test_api_learn_start_400_without_machine_id(self, tmp_path):
        """A payload without ``machine_id`` is rejected with a 400 naming the field."""
        orch = self._make_orchestrator(tmp_path)
        resp = self._post_start(orch, {"user_id": "dom"})
        assert resp.status_code == 400
        data = resp.get_json()
        assert "machine_id" in (data.get("error") or "").lower()

    def test_api_learn_start_400_with_empty_machine_id(self, tmp_path):
        """A whitespace-only ``machine_id`` is rejected with a 400."""
        orch = self._make_orchestrator(tmp_path)
        resp = self._post_start(orch, {"machine_id": " "})
        assert resp.status_code == 400

    def test_api_learn_start_503_if_orchestrator_not_initialized(self):
        """With no orchestrator installed the route answers 503."""
        resp = self._post_start(None, {"machine_id": "m1"})
        assert resp.status_code == 503

View File

@@ -15,8 +15,10 @@ garantit que l'env est defini AVANT tout import.
from __future__ import annotations
import os
import sqlite3
import sys
import tempfile
import time
from pathlib import Path
import pytest
@@ -273,6 +275,107 @@ def test_reenroll_after_uninstall_reactivates(agents_client):
assert agent["version"] == "1.1.0"
def test_reenroll_after_admin_revoke_is_forbidden(agents_client):
    """Enrolling again after an ``admin_revoke`` uninstall must answer 403."""
    client, token, _ = agents_client

    def post(path, body):
        # Every request carries freshly built auth headers.
        return client.post(path, json=body, headers=_auth_headers(token))

    post(
        "/api/v1/agents/enroll",
        {"machine_id": "revoked-001", "user_name": "Revoked"},
    )

    revoke = post(
        "/api/v1/agents/uninstall",
        {"machine_id": "revoked-001", "reason": "admin_revoke"},
    )
    assert revoke.status_code == 200

    resp = post(
        "/api/v1/agents/enroll",
        {"machine_id": "revoked-001", "user_name": "Revoked Again"},
    )
    assert resp.status_code == 403, resp.text

    detail = resp.json()["detail"]
    assert detail["error"] == "agent_revoked"
    existing = detail["existing"]
    assert existing["machine_id"] == "revoked-001"
    assert existing["uninstall_reason"] == "admin_revoke"
def test_revoked_agent_cannot_stream_or_poll(agents_client):
    """A revoked agent gets 403 ``agent_not_active`` on stream event and replay/next."""
    client, token, _ = agents_client
    machine = "revoked-runtime-001"
    session = "sess_revoked_runtime"

    # Enroll the agent, then revoke it.
    client.post(
        "/api/v1/agents/enroll",
        json={"machine_id": machine, "user_name": "Runtime"},
        headers=_auth_headers(token),
    )
    client.post(
        "/api/v1/agents/uninstall",
        json={"machine_id": machine, "reason": "admin_revoke"},
        headers=_auth_headers(token),
    )

    # Pushing an event must be refused.
    heartbeat = {
        "session_id": session,
        "timestamp": time.time(),
        "event": {"type": "heartbeat"},
        "machine_id": machine,
    }
    event_resp = client.post(
        "/api/v1/traces/stream/event",
        json=heartbeat,
        headers=_auth_headers(token),
    )
    assert event_resp.status_code == 403, event_resp.text
    assert event_resp.json()["detail"]["error"] == "agent_not_active"

    # Polling for the next replay action must be refused as well.
    next_resp = client.get(
        "/api/v1/traces/stream/replay/next",
        params={"session_id": session, "machine_id": machine},
        headers=_auth_headers(token),
    )
    assert next_resp.status_code == 403, next_resp.text
    assert next_resp.json()["detail"]["error"] == "agent_not_active"
def test_active_agent_stream_updates_last_seen(agents_client):
    """A stream event from an active agent must refresh its ``last_seen_at``.

    The column is first forced to a stale timestamp directly in the registry
    database, then a heartbeat event is posted and the stored value is
    expected to differ from the stale one.
    """
    from contextlib import closing

    client, token, registry = agents_client
    machine_id = "last-seen-001"
    client.post(
        "/api/v1/agents/enroll",
        json={"machine_id": machine_id, "user_name": "Seen"},
        headers=_auth_headers(token),
    )

    stale = "2000-01-01T00:00:00+00:00"
    # A sqlite3 connection used as a context manager only commits or rolls
    # back; it does not close. closing() makes sure the handle is released
    # before the API touches the same database file.
    with closing(sqlite3.connect(str(registry.db_path))) as conn:
        conn.execute(
            "UPDATE enrolled_agents SET last_seen_at = ? WHERE machine_id = ?",
            (stale, machine_id),
        )
        conn.commit()

    resp = client.post(
        "/api/v1/traces/stream/event",
        json={
            "session_id": "sess_last_seen",
            "timestamp": time.time(),
            "event": {"type": "heartbeat"},
            "machine_id": machine_id,
        },
        headers=_auth_headers(token),
    )
    assert resp.status_code == 200, resp.text

    row = registry.get(machine_id)
    assert row is not None
    assert row["last_seen_at"] != stale
# ---------------------------------------------------------------------------
# GET /api/v1/agents/fleet
# ---------------------------------------------------------------------------

View File

@@ -0,0 +1,198 @@
"""Mesure du gain perf RPA_SKIP_INTENTION_ENRICHMENT sur build_replay.
Harnais lecture seule : charge une fixture raw events réelle (smoke Bloc-notes
2026-05-20 - même session que replay_sess_e96e5822 18/18 du 2026-05-25) et
appelle directement build_replay_from_raw_events() sans déclencher dispatch
ni replay live.
Ne pas lancer en CI standard : test perf, run manuel uniquement.
Run :
.venv/bin/python -m pytest tests/integration/test_build_replay_perf.py \
-m performance -s -v
Référence : inbox_claude/2026-05-25_1244_codex-to-claude_recadrage-demo-1juin.md
(mission C2) et plan docs/plans/PLAN_STABILISATION_DEMO_2026-06-01.md
(P0 performance mesurable).
"""
from __future__ import annotations
import json
import sys
import time
from pathlib import Path
import pytest
ROOT = Path(__file__).resolve().parents[2]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
SESSION_DIR = (
ROOT
/ "data"
/ "training"
/ "live_sessions"
/ "DESKTOP-58D5CAC_windows"
/ "sess_20260520T102916_066851"
)
FIXTURE = SESSION_DIR / "live_events.jsonl"
def _load_raw_events() -> list:
    """Load the raw-events fixture, one JSON document per non-blank line.

    The original note describes the fixture as 55 raw events containing
    16 useful actions. The calling test is skipped when the file is absent.

    Returns:
        The decoded events, in file order.
    """
    if not FIXTURE.exists():
        pytest.skip(f"Fixture absente : {FIXTURE}")
    # Pin UTF-8 so decoding does not depend on the platform's locale encoding.
    with FIXTURE.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
@pytest.fixture
def raw_events():
    """Expose the events returned by ``_load_raw_events`` as a fixture."""
    events = _load_raw_events()
    return events
@pytest.fixture
def session_dir() -> str:
    """Return ``SESSION_DIR`` as a string, skipping the test if it is missing.

    Per the original note, passing this directory is what triggers the
    gemma4 enrichment when it is present.
    """
    if SESSION_DIR.exists():
        return str(SESSION_DIR)
    # pytest.skip raises, so nothing is returned on this path.
    pytest.skip(f"Session dir absent : {SESSION_DIR}")
def _extract_perf_breakdown(caplog) -> list[tuple[str, float]]:
"""Extrait les spans [PERF] build.step* des logs capturés.
Format attendu : "[PERF] build.<step_name> session=<sid> elapsed_ms=<X>"
Retourne [(step_name, elapsed_ms)] dans l'ordre d'apparition.
"""
import re
pattern = re.compile(r"\[PERF\] build\.(\S+) session=\S+ elapsed_ms=([\d.]+)")
out = []
for record in caplog.records:
m = pattern.search(record.getMessage())
if m:
out.append((m.group(1), float(m.group(2))))
return out
@pytest.mark.performance
def test_build_replay_perf_skip_enrichment(monkeypatch, raw_events, session_dir, caplog):
    """Measure build_replay_from_raw_events with and without RPA_SKIP_INTENTION_ENRICHMENT.

    Asserts:
      - both modes produce the same number of actions
      - skip mode yields 0 actions with a non-empty intention
      - full mode yields at least 1 action with an intention
      - skip mode is at least 3x faster than full mode

    Prints an explicit [PERF] summary of both measurements (visible with -s).
    """
    import logging

    from agent_v0.server_v1.stream_processor import build_replay_from_raw_events

    # Capture the stream_processor INFO logs so the [PERF] spans can be read back.
    caplog.set_level(logging.INFO, logger="agent_v0.server_v1.stream_processor")

    # First run: enrichment enabled (both skip variables removed).
    monkeypatch.delenv("RPA_SKIP_INTENTION_ENRICHMENT", raising=False)
    monkeypatch.delenv("RPA_SKIP_ENRICHMENT", raising=False)
    t0 = time.perf_counter()
    actions_full = build_replay_from_raw_events(
        raw_events, session_id="perf_full", session_dir=session_dir
    )
    elapsed_full_ms = (time.perf_counter() - t0) * 1000
    breakdown_full = _extract_perf_breakdown(caplog)
    # Drop the first run's records so the second breakdown only sees its own.
    caplog.clear()

    # Second run: enrichment skipped.
    monkeypatch.setenv("RPA_SKIP_INTENTION_ENRICHMENT", "1")
    t0 = time.perf_counter()
    actions_skip = build_replay_from_raw_events(
        raw_events, session_id="perf_skip", session_dir=session_dir
    )
    elapsed_skip_ms = (time.perf_counter() - t0) * 1000
    breakdown_skip = _extract_perf_breakdown(caplog)

    # Floor the denominator at 1 ms to avoid a division by (near) zero.
    speedup = elapsed_full_ms / max(1.0, elapsed_skip_ms)
    intentions_full = sum(1 for a in actions_full if a.get("intention"))
    intentions_skip = sum(1 for a in actions_skip if a.get("intention"))

    print(
        f"\n[PERF] build_replay events={len(raw_events)} "
        f"actions_full={len(actions_full)} actions_skip={len(actions_skip)} "
        f"full_ms={elapsed_full_ms:.0f} skip_ms={elapsed_skip_ms:.0f} "
        f"speedup={speedup:.1f}x "
        f"intentions_full={intentions_full} intentions_skip={intentions_skip}"
    )

    # Per-step breakdown, useful to spot what still costs time once
    # enrichment is skipped.
    def _format_breakdown(label: str, b: list[tuple[str, float]]) -> str:
        if not b:
            return f" {label}: (aucun span [PERF] capturé)"
        lines = [f" {label}:"]
        for step, ms in b:
            # One '#' per 500 ms, at least one. The glyph used to be an empty
            # string, which made the bar invisible whatever the duration.
            bar = "#" * max(1, int(ms / 500))
            lines.append(f" {step:40s} {ms:>7.0f} ms {bar}")
        return "\n".join(lines)

    print(_format_breakdown("Décomposition FULL", breakdown_full))
    print(_format_breakdown("Décomposition SKIP", breakdown_skip))

    # Invariant: same number of actions, only the intention fields go away.
    assert len(actions_skip) == len(actions_full), (
        f"Le skip ne doit pas changer le nombre d'actions "
        f"(full={len(actions_full)}, skip={len(actions_skip)})"
    )

    # Skip mode: no action may carry an enriched intention.
    assert intentions_skip == 0, (
        f"Skip enrichment doit produire 0 intention non-vide "
        f"(observé : {intentions_skip})"
    )

    # Full mode: at least one action must carry an intention; zero means the
    # enrichment did not run or nothing in the fixture was eligible, so fail
    # loudly.
    assert intentions_full > 0, (
        f"Full enrichment doit produire au moins 1 intention non-vide "
        f"sur fixture {FIXTURE.name}. Si 0 → gemma4 indisponible ou fixture "
        f"non éligible (toutes les actions filtrées avant enrichissement)."
    )

    # Minimum speed-up: 3x.
    # Measurement recorded by the original author (2026-05-25, fixture with
    # 16 actions of which 9 enriched): full=93.8s, skip=24.1s, speedup=3.9x.
    # Skip mode is not instantaneous (~24s) because other steps still cost
    # time: anchor-crop extraction for visual_mode clicks, consolidation with
    # ReplayLearner, wait normalization, etc. Only gemma4 is skipped. The
    # initial 215x estimate assumed gemma4 was the only large cost, which the
    # measurement invalidated.
    assert speedup >= 3.0, (
        f"Gain insuffisant : {speedup:.1f}x (attendu ≥ 3x). "
        f"Soit gemma4 cache-hit, soit la fixture n'a pas d'action éligible, "
        f"soit Ollama indisponible (fallback rapide). full_ms={elapsed_full_ms:.0f}, "
        f"skip_ms={elapsed_skip_ms:.0f}."
    )
@pytest.mark.performance
def test_build_replay_skip_alias_works(monkeypatch, raw_events, session_dir):
    """Check that setting RPA_SKIP_ENRICHMENT alone also yields no intentions."""
    from agent_v0.server_v1.stream_processor import build_replay_from_raw_events

    # Only the alias variable is set; the primary one is removed.
    monkeypatch.delenv("RPA_SKIP_INTENTION_ENRICHMENT", raising=False)
    monkeypatch.setenv("RPA_SKIP_ENRICHMENT", "1")

    produced = build_replay_from_raw_events(
        raw_events, session_id="perf_alias", session_dir=session_dir
    )
    with_intention = [action for action in produced if action.get("intention")]
    intentions = len(with_intention)

    print(f"\n[PERF] alias RPA_SKIP_ENRICHMENT actions={len(produced)} intentions={intentions}")
    assert intentions == 0, (
        f"L'alias RPA_SKIP_ENRICHMENT doit aussi désactiver l'enrichissement "
        f"(observé : {intentions} intentions)"
    )

View File

@@ -65,7 +65,7 @@ def test_tpl_need_confirm_extracts_action_description():
def test_tpl_need_confirm_fallback():
    """Check the title produced by ``cw._tpl_need_confirm`` for an empty payload.

    NOTE(review): this view looks like a diff whose +/- markers were lost;
    confirm that both substring checks are meant to coexist.
    """
    _, _, heading = cw._tpl_need_confirm({})
    for fragment in ("Validation", "accord"):
        assert fragment in heading
def test_tpl_step_result_ok():

View File

@@ -24,15 +24,19 @@ class TestReplayResumePreservesOriginalAction:
monkeypatch.setattr(api_stream_mod, "API_TOKEN", self._TEST_API_TOKEN)
@pytest.fixture
def client(self, monkeypatch):
def client(self, monkeypatch, tmp_path):
from fastapi.testclient import TestClient
from agent_v0.server_v1 import api_stream
from agent_v0.server_v1.agent_registry import AgentRegistry
monkeypatch.setattr(api_stream, "API_TOKEN", self._TEST_API_TOKEN)
saved_states = dict(api_stream._replay_states)
saved_queues = dict(api_stream._replay_queues)
saved_retry = dict(api_stream._retry_pending)
original_registry = api_stream.agent_registry
empty_registry = AgentRegistry(db_path=str(tmp_path / "empty_agents.db"))
monkeypatch.setattr(api_stream, "agent_registry", empty_registry)
api_stream._replay_states.clear()
api_stream._replay_queues.clear()
@@ -47,6 +51,7 @@ class TestReplayResumePreservesOriginalAction:
api_stream._replay_queues.update(saved_queues)
api_stream._retry_pending.clear()
api_stream._retry_pending.update(saved_retry)
monkeypatch.setattr(api_stream, "agent_registry", original_registry)
def test_resume_reinjects_full_original_action_from_failed_action(self, client):
http_client, api_stream, token = client
@@ -144,6 +149,7 @@ class TestReplayResumePreservesOriginalAction:
next_resp = http_client.get(
"/api/v1/traces/stream/replay/next",
params={"session_id": "sess_resume_watchdog", "machine_id": "pc-watchdog"},
headers={"Authorization": f"Bearer {token}"},
)
assert next_resp.status_code == 200

View File

@@ -104,11 +104,12 @@ def test_replay_session_pipeline_skips_redundant_tab_switch(tmp_path):
# 1) Setup auto reconnaît Notepad et génère ses actions
assert app_info.get("primary_app") == "Notepad.exe"
assert app_info.get("has_neutral_window_title") is True
setup_actions = _generate_setup_actions(app_info, setup_id_prefix="setup_sess")
assert setup_actions, "le setup auto doit injecter des actions Notepad"
action_ids = {a.get("action_id", "") for a in setup_actions}
assert any("click_start" in aid for aid in action_ids)
assert any("click_result" in aid for aid in action_ids)
setup_steps = [a.get("_setup_step", "") for a in setup_actions]
assert "open_run_dialog" in setup_steps
assert "ensure_fresh_document" in setup_steps
# 2) Trim : le clic intra-Notepad redondant doit disparaître
trimmed = _trim_redundant_setup_events(raw_events, app_info)

View File

@@ -213,6 +213,24 @@ def test_edge_to_action_extract_text():
assert a["parameters"]["paragraph"] is True
def test_edge_to_action_extract_table_accepts_tesseract_engine_and_variable_name():
    """An ``extract_table`` edge keeps pattern/engine and maps ``variable_name`` to ``output_var``."""
    pattern = r"^25\d{6}$"
    raw_parameters = {
        "variable_name": "t_extraction_liste",
        "pattern": pattern,
        "engine": "tesseract",
    }
    edge = _FakeEdge(_FakeAction("extract_table", parameters=raw_parameters))

    actions = _edge_to_normalized_actions(edge, params={})

    assert len(actions) == 1
    a = actions[0]
    assert a["type"] == "extract_table"
    normalized = a["parameters"]
    assert normalized["output_var"] == "t_extraction_liste"
    assert normalized["pattern"] == r"^25\d{6}$"
    assert normalized["engine"] == "tesseract"
def test_edge_to_action_t2a_decision():
edge = _FakeEdge(_FakeAction(
"t2a_decision",