feat(agent): add learn action flow and grounding guards
This commit is contained in:
526
tests/unit/test_agent_chat_learn_action.py
Normal file
526
tests/unit/test_agent_chat_learn_action.py
Normal file
@@ -0,0 +1,526 @@
|
||||
"""Tests unit pour agent_chat.handlers.learn_action.
|
||||
|
||||
Couvre :
|
||||
- LearnIntentParser (regex)
|
||||
- OptionCFormatter
|
||||
- StateStore (write atomique + reprise)
|
||||
- LearnActionOrchestrator (transitions, garde-fous, persistance)
|
||||
- PersistPayloadBuilder
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_chat.handlers.learn_action import (
|
||||
LearnActionOrchestrator,
|
||||
LearnIntent,
|
||||
LearnIntentParser,
|
||||
LearnState,
|
||||
OptionCFormatter,
|
||||
PersistPayloadBuilder,
|
||||
SessionState,
|
||||
StateStore,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# LearnIntentParser
|
||||
# ============================================================
|
||||
class TestLearnIntentParser:
    """Tests for ``LearnIntentParser.parse``.

    ``setup_method`` builds a parser with ``use_llm_fallback=False`` so the
    tests that use ``self.parser`` do not involve the LLM fallback. The
    fallback test builds its own parser with the fallback enabled and stubbed.
    """

    def setup_method(self):
        # Disable the LLM fallback to isolate the regex tests.
        self.parser = LearnIntentParser(use_llm_fallback=False)

    @pytest.mark.parametrize(
        "msg",
        [
            "apprends-moi",
            "Apprends moi",
            "regarde-moi faire",
            "observe",
            "enregistre",
            "on apprend",
            "tu vas apprendre",
            "Léa apprends",
        ],
    )
    def test_start_observe(self, msg):
        """Start phrases sent in IDLE yield START_OBSERVE with confidence >= 0.9."""
        r = self.parser.parse(msg, current_state=LearnState.IDLE)
        assert r.intent == LearnIntent.START_OBSERVE
        assert r.confidence >= 0.9

    @pytest.mark.parametrize(
        "msg",
        [
            "stop",
            "c'est bon",
            "j'ai fini",
            "voilà c'est tout",
            "fini",
            "arrête",
            "termine",
        ],
    )
    def test_user_stop_observe(self, msg):
        """Stop phrases sent in WAITING_USER_STOP yield USER_STOP_OBSERVE."""
        r = self.parser.parse(msg, current_state=LearnState.WAITING_USER_STOP)
        assert r.intent == LearnIntent.USER_STOP_OBSERVE

    def test_correct_step_with_index(self):
        """A correction message exposes the step index and the new intent text."""
        r = self.parser.parse(
            "Corrige l'étape 3 : il faut cliquer sur Valider",
            current_state=LearnState.ITERATING_FEEDBACK,
        )
        assert r.intent == LearnIntent.CORRECT_STEP
        assert r.step_index == 3
        assert "valider" in (r.extra.get("new_intent") or "").lower()

    def test_undo_step(self):
        """A removal message yields UNDO_STEP with the referenced step index."""
        r = self.parser.parse(
            "Retire l'étape 2", current_state=LearnState.ITERATING_FEEDBACK
        )
        assert r.intent == LearnIntent.UNDO_STEP
        assert r.step_index == 2

    def test_merge_next(self):
        """A merge message yields MERGE_NEXT."""
        r = self.parser.parse(
            "Fusionne avec la suivante", current_state=LearnState.ITERATING_FEEDBACK
        )
        assert r.intent == LearnIntent.MERGE_NEXT

    def test_split_step(self):
        """A split message yields SPLIT_STEP with the referenced step index."""
        r = self.parser.parse(
            "Coupe l'étape 4", current_state=LearnState.ITERATING_FEEDBACK
        )
        assert r.intent == LearnIntent.SPLIT_STEP
        assert r.step_index == 4

    def test_cancel(self):
        """A cancel message sent in LISTENING yields CANCEL."""
        r = self.parser.parse("annule tout", current_state=LearnState.LISTENING)
        assert r.intent == LearnIntent.CANCEL

    def test_validate_in_iterating(self):
        """An approval sent in ITERATING_FEEDBACK yields VALIDATE_STEP."""
        r = self.parser.parse(
            "c'est parfait", current_state=LearnState.ITERATING_FEEDBACK
        )
        assert r.intent == LearnIntent.VALIDATE_STEP

    def test_mark_parameter_variable(self):
        """'Changes every time' marks the value as a parameter."""
        r = self.parser.parse(
            "ça change à chaque fois", current_state=LearnState.NAMING
        )
        assert r.intent == LearnIntent.MARK_PARAMETER
        assert r.extra.get("is_parameter") is True

    def test_mark_parameter_constant(self):
        """'Always the same' marks the value as a constant (not a parameter)."""
        r = self.parser.parse(
            "toujours pareil", current_state=LearnState.NAMING
        )
        assert r.intent == LearnIntent.MARK_PARAMETER
        assert r.extra.get("is_parameter") is False

    def test_name_competence_when_naming(self):
        """Free text sent in NAMING is taken as the competence name."""
        r = self.parser.parse(
            "facturation urgences", current_state=LearnState.NAMING
        )
        assert r.intent == LearnIntent.NAME_COMPETENCE
        assert "facturation" in (r.extra.get("name") or "")

    def test_unknown_in_idle(self):
        """Unrecognised text sent in IDLE yields UNKNOWN."""
        r = self.parser.parse(
            "blabla random", current_state=LearnState.IDLE
        )
        assert r.intent == LearnIntent.UNKNOWN

    def test_llm_fallback_disabled_after_failure(self, monkeypatch):
        """With the LLM fallback enabled but yielding nothing, parse gives UNKNOWN."""
        parser = LearnIntentParser(use_llm_fallback=True)
        # Stub the LLM call so it returns no result (no network involved).
        # monkeypatch scopes the patch to this test and, by default, raises if
        # `_parse_llm` no longer exists, so the stub cannot silently go stale.
        monkeypatch.setattr(parser, "_parse_llm", lambda *args, **kwargs: None)
        r = parser.parse("zorglub blabla truc", current_state=LearnState.IDLE)
        # Must fall back gracefully to UNKNOWN without crashing.
        assert r.intent == LearnIntent.UNKNOWN
|
||||
|
||||
|
||||
# ============================================================
|
||||
# OptionCFormatter
|
||||
# ============================================================
|
||||
class TestOptionCFormatter:
    """Tests for the text produced by ``OptionCFormatter``."""

    def setup_method(self):
        # A fresh formatter for every test.
        self.fmt = OptionCFormatter()

    def _render(self, *steps):
        """Format the given understanding steps and return the resulting text."""
        return self.fmt.format(list(steps))

    def test_empty(self):
        """An empty understanding list is reported as having no step."""
        rendered = self._render()
        assert "aucune étape" in rendered

    def test_simple_click(self):
        """A click step is numbered, quotes its label and uses the click verb."""
        rendered = self._render(
            {"action_type": "click", "target_label": "Valider", "widget_type": "Bouton"}
        )
        for fragment in ("1.", "« Valider »", "cliqué"):
            assert fragment in rendered

    def test_type_with_value(self):
        """A type step quotes both the field label and the typed value."""
        step = dict(
            action_type="type",
            target_label="IPP",
            widget_type="Champ",
            value="25003284",
        )
        rendered = self._render(step)
        for fragment in ("« IPP »", "« 25003284 »", "saisi"):
            assert fragment in rendered

    def test_low_confidence_suffix(self):
        """A step with confidence_ocr of 0.4 gets the 'to be confirmed' suffix."""
        step = dict(
            action_type="click",
            target_label="Patient",
            widget_type="Fenêtre",
            confidence_ocr=0.4,
        )
        rendered = self._render(step)
        assert "(à confirmer)" in rendered

    def test_unknown_action_fallback(self):
        """An unrecognised action type falls back to a generic verb."""
        rendered = self._render({"action_type": "wibble", "target_label": "X"})
        assert "effectuée" in rendered

    def test_closing_question(self):
        """The closing question asks whether a mistake was made."""
        question = self.fmt.closing_question()
        # Accept the accented spelling as-is, or any casing once de-accented.
        deaccented = question.lower().replace("é", "e")
        assert "trompée" in question or "trompee" in deaccented
|
||||
|
||||
|
||||
# ============================================================
|
||||
# StateStore
|
||||
# ============================================================
|
||||
class TestStateStore:
    """Tests for ``StateStore`` persistence (save / load / list / delete)."""

    def test_save_and_load(self, tmp_path):
        """A saved session is loaded back with the same id, user and state."""
        store = StateStore(tmp_path)
        st = SessionState(
            session_id="abc123",
            user_id="dom",
            state=LearnState.ITERATING_FEEDBACK,
        )
        store.save(st)
        loaded = store.load("abc123")
        assert loaded is not None
        assert loaded.session_id == "abc123"
        assert loaded.user_id == "dom"
        assert loaded.state == LearnState.ITERATING_FEEDBACK

    def test_atomic_write_no_partial(self, tmp_path):
        """After a save, no ``*.tmp`` file is left in the store directory."""
        store = StateStore(tmp_path)
        store.save(SessionState(session_id="atomic1"))
        assert list(tmp_path.glob("*.tmp")) == []

    def test_list_active_filters_done(self, tmp_path):
        """``list_active`` leaves out sessions in DONE or ABORTED state."""
        store = StateStore(tmp_path)
        store.save(SessionState(session_id="s1", state=LearnState.ITERATING_FEEDBACK))
        store.save(SessionState(session_id="s2", state=LearnState.DONE))
        store.save(SessionState(session_id="s3", state=LearnState.ABORTED))
        active = store.list_active()
        ids = {s.session_id for s in active}
        assert ids == {"s1"}

    def test_session_id_sanitized(self, tmp_path):
        """A path-traversal session id must not write outside the store directory."""
        # Nest the store two levels below tmp_path: an unsanitised "../../"
        # would then still land inside tmp_path, where the recursive scan below
        # can see it. Globbing only the store directory could never detect an
        # escaped file.
        root = tmp_path / "outer" / "store"
        root.mkdir(parents=True)
        store = StateStore(root)
        store.save(SessionState(session_id="../../etc/passwd"))
        json_files = list(tmp_path.rglob("*.json"))
        # Exactly one state file, located directly in the store directory.
        assert len(json_files) == 1
        assert json_files[0].parent == root

    def test_delete(self, tmp_path):
        """A deleted session can no longer be loaded."""
        store = StateStore(tmp_path)
        store.save(SessionState(session_id="del_me"))
        store.delete("del_me")
        assert store.load("del_me") is None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# PersistPayloadBuilder
|
||||
# ============================================================
|
||||
class TestPersistPayloadBuilder:
    """Tests for the payload dict returned by ``PersistPayloadBuilder.build``."""

    def test_build_with_parameters(self):
        """Identity fields are copied and only flagged parameters are exported."""
        variable_param = {
            "step_index": 3,
            "is_parameter": True,
            "name": "ipp",
            "example_value": "25003284",
            "field_label": "IPP",
        }
        constant_param = {
            "step_index": 4,
            "is_parameter": False,
            "name": "type",
            "example_value": "C2",
            "field_label": "Type",
        }
        session = SessionState(
            session_id="sX",
            competence_name="Test compétence",
            user_id="dom",
            parameters_marked=[variable_param, constant_param],
        )
        builder = PersistPayloadBuilder()
        payload = builder.build(session)
        assert payload["name"] == "Test compétence"
        assert payload["session_id"] == "sX"
        assert payload["user_id"] == "dom"
        # Only the entry flagged is_parameter=True must show up.
        exported = payload["parameters"]
        assert len(exported) == 1
        assert exported[0]["name"] == "ipp"

    def test_persist_payload_includes_machine_id(self):
        """Fix #1: the payload must carry the session's machine_id."""
        session = SessionState(
            session_id="sM",
            competence_name="X",
            machine_id="DESKTOP-58D5CAC_windows",
        )
        builder = PersistPayloadBuilder()
        payload = builder.build(session)
        assert "machine_id" in payload
        assert payload["machine_id"] == "DESKTOP-58D5CAC_windows"

    def test_persist_payload_machine_id_none_when_absent(self):
        """When not supplied, machine_id is still a payload key, set to None."""
        session = SessionState(session_id="sM2", competence_name="X")
        builder = PersistPayloadBuilder()
        payload = builder.build(session)
        assert "machine_id" in payload
        assert payload["machine_id"] is None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# LearnActionOrchestrator (avec StreamingClient mocké)
|
||||
# ============================================================
|
||||
@pytest.fixture
def mock_streaming():
    """Return a ``MagicMock`` StreamingClient with canned return values.

    ``shadow_understanding`` returns two steps: a click on "Patient" and a
    type step entering "25003284" into "IPP". ``competence_persist`` returns
    a slug; the other stubbed calls return ``{"ok": True}``.
    """
    observed_steps = [
        {"action_type": "click", "target_label": "Patient", "widget_type": "Fenêtre"},
        {
            "action_type": "type",
            "target_label": "IPP",
            "widget_type": "Champ",
            "value": "25003284",
        },
    ]
    canned_replies = {
        "shadow_start": {"ok": True},
        "shadow_stop": {"ok": True},
        "shadow_understanding": {"understanding": observed_steps},
        "shadow_feedback": {"ok": True},
        "shadow_build": {"ok": True},
        "competence_persist": {"slug": "facturation_urgences"},
    }
    client = MagicMock()
    # Attach each canned reply to the client method of the same name.
    for method_name, canned in canned_replies.items():
        getattr(client, method_name).return_value = canned
    return client
|
||||
|
||||
|
||||
@pytest.fixture
def orchestrator(tmp_path, mock_streaming):
    """Return a ``LearnActionOrchestrator`` wired with test doubles.

    It uses the mocked streaming client, a parser with the LLM fallback
    disabled, a ``StateStore`` rooted at ``tmp_path`` and a ``MagicMock`` emit
    callback.
    """
    return LearnActionOrchestrator(
        streaming_client=mock_streaming,
        intent_parser=LearnIntentParser(use_llm_fallback=False),
        state_store=StateStore(tmp_path),
        emit=MagicMock(),
    )
|
||||
|
||||
|
||||
class TestLearnActionOrchestrator:
    """Tests for ``LearnActionOrchestrator``: transitions, guards, persistence.

    The ``orchestrator`` fixture supplies an instance backed by the
    ``mock_streaming`` client, so no real streaming call is made.
    """

    def test_start_session_transitions(self, orchestrator, mock_streaming):
        """Starting a session calls shadow_start and lands in WAITING_USER_STOP."""
        st, reply = orchestrator.start_session(user_id="dom", trigger_source="button")
        assert st.state == LearnState.WAITING_USER_STOP
        mock_streaming.shadow_start.assert_called_once()
        # "je te regarde" already contains "regarde", so one check covers both.
        assert "regarde" in reply.lower()

    def test_full_happy_path(self, orchestrator, mock_streaming):
        """Walk a session from start to DONE through stop, validate, name, mark."""
        st, _ = orchestrator.start_session(user_id="dom", machine_id="m1")
        sid = st.session_id

        # The user says stop.
        reply = orchestrator.handle_chat_message(sid, "c'est bon")
        assert reply is not None
        assert "j'ai compris" in reply.lower()
        assert orchestrator._sessions[sid].state == LearnState.ITERATING_FEEDBACK

        # The user validates the whole understanding -> NAMING.
        orchestrator.handle_chat_message(sid, "c'est parfait")
        assert orchestrator._sessions[sid].state == LearnState.NAMING

        # Naming: the reply must now ask about the IPP value.
        reply = orchestrator.handle_chat_message(sid, "facturation urgences")
        assert "25003284" in (reply or "")
        assert orchestrator._sessions[sid].competence_name == "facturation urgences"

        # Mark the value as variable; nothing is pending any more -> persist.
        orchestrator.handle_chat_message(sid, "ça change à chaque fois")
        mock_streaming.shadow_build.assert_called_once()
        mock_streaming.competence_persist.assert_called_once()
        assert orchestrator._sessions[sid].state == LearnState.DONE

    def test_emergency_exit_after_3_corrections(self, orchestrator, mock_streaming):
        """Three corrections keep iterating; the fourth aborts the session."""
        st, _ = orchestrator.start_session(user_id="dom")
        sid = st.session_id
        orchestrator.handle_chat_message(sid, "c'est bon")  # stop

        correction = "corrige l'étape 3 : clique sur Valider"
        for _ in range(3):
            orchestrator.handle_chat_message(sid, correction)
            assert orchestrator._sessions[sid].state == LearnState.ITERATING_FEEDBACK

        # 4th correction -> ABORTED, and the reply mentions step 3.
        reply = orchestrator.handle_chat_message(sid, correction)
        assert orchestrator._sessions[sid].state == LearnState.ABORTED
        assert "n°3" in (reply or "")

    def test_cancel_anywhere(self, orchestrator, mock_streaming):
        """A cancel message right after start aborts the session."""
        st, _ = orchestrator.start_session(user_id="dom")
        sid = st.session_id
        reply = orchestrator.handle_chat_message(sid, "annule tout")
        assert orchestrator._sessions[sid].state == LearnState.ABORTED
        assert "annule" in (reply or "").lower()

    def test_idle_message_returns_none(self, orchestrator):
        """With no open session, the message is not handled (None is returned)."""
        # None lets the normal chat flow deal with the message.
        assert orchestrator.handle_chat_message("nonexistent", "Bonjour") is None

    def test_state_persistence_across_reload(self, tmp_path, mock_streaming):
        """A second orchestrator on the same store resumes the saved session."""
        store = StateStore(tmp_path)
        parser = LearnIntentParser(use_llm_fallback=False)
        orch1 = LearnActionOrchestrator(
            streaming_client=mock_streaming,
            intent_parser=parser,
            state_store=store,
            emit=MagicMock(),
        )
        st, _ = orch1.start_session(user_id="dom")
        sid = st.session_id
        orch1.handle_chat_message(sid, "c'est bon")  # moves to ITERATING_FEEDBACK

        # Simulate a crash followed by a restart.
        orch2 = LearnActionOrchestrator(
            streaming_client=mock_streaming,
            intent_parser=parser,
            state_store=store,
            emit=MagicMock(),
        )
        resumed = orch2.resume_sessions()
        assert sid in resumed
        assert orch2._sessions[sid].state == LearnState.ITERATING_FEEDBACK

    def test_proactive_signal_cooldown(self, orchestrator):
        """An immediate second identical proactive signal is ignored."""
        r1 = orchestrator.handle_proactive_signal("action_repeat", {})
        assert r1 is not None
        # Second signal right away -> ignored.
        r2 = orchestrator.handle_proactive_signal("action_repeat", {})
        assert r2 is None

    def test_illegal_transition_ignored(self, orchestrator, mock_streaming):
        """An illegal transition leaves the session state unchanged."""
        st, _ = orchestrator.start_session(user_id="dom")
        session = orchestrator._sessions[st.session_id]
        prev = session.state
        # Attempt to jump straight from WAITING_USER_STOP to DONE.
        orchestrator._transition(session, LearnState.DONE)
        assert orchestrator._sessions[st.session_id].state == prev

    # ============================================================
    # Fixes P1-LEA-SHADOW 2026-06-01 (NO-GO Qwen)
    # ============================================================
    def test_start_session_stores_machine_id(self, orchestrator):
        """Fix #1: the machine_id given to start_session is stored."""
        st, _ = orchestrator.start_session(
            user_id="dom",
            trigger_source="windows_button",
            machine_id="DESKTOP-58D5CAC_windows",
        )
        assert st.machine_id == "DESKTOP-58D5CAC_windows"
        # The in-memory session carries it as well.
        assert (
            orchestrator._sessions[st.session_id].machine_id
            == "DESKTOP-58D5CAC_windows"
        )

    def test_persist_blocked_without_machine_id(self, orchestrator, mock_streaming):
        """Fix #1: without machine_id, persist is refused with a chat reply."""
        st, _ = orchestrator.start_session(user_id="dom")  # no machine_id
        sid = st.session_id
        orchestrator.handle_chat_message(sid, "c'est bon")  # -> ITERATING
        orchestrator.handle_chat_message(sid, "c'est parfait")  # -> NAMING
        orchestrator.handle_chat_message(sid, "ma competence")  # name
        # Marking the parameter triggers the persist attempt.
        reply = orchestrator.handle_chat_message(sid, "ça change à chaque fois")
        # competence_persist must NOT have been called.
        mock_streaming.competence_persist.assert_not_called()
        # The reply must explicitly mention the machine.
        assert reply is not None
        assert "machine" in reply.lower()

    def test_datetime_uses_timezone_aware(self):
        """Fix #2: created_at / last_transition_at are timezone-aware."""
        from datetime import datetime as _dt, timedelta as _td

        st = SessionState(session_id="tz1")
        # Both ISO strings must carry an offset, i.e. tzinfo is set once parsed.
        # NOTE(review): the original comment says a "Z" suffix is possible;
        # fromisoformat only accepts "Z" on Python 3.11+.
        parsed_created = _dt.fromisoformat(st.created_at)
        parsed_transition = _dt.fromisoformat(st.last_transition_at)
        assert parsed_created.tzinfo is not None
        assert parsed_transition.tzinfo is not None
        # Sanity check: it really is UTC. Compare the parsed offset rather
        # than matching substrings of the serialized form.
        assert parsed_created.utcoffset() == _td(0)

    def test_confirm_blocked_when_name_missing(self, orchestrator, mock_streaming):
        """Fix #3: CONFIRM in NAMING with competence_name=None stays in NAMING."""
        st, _ = orchestrator.start_session(
            user_id="dom", machine_id="machine_x"
        )
        sid = st.session_id
        orchestrator.handle_chat_message(sid, "c'est bon")
        orchestrator.handle_chat_message(sid, "c'est parfait")  # -> NAMING
        # Force competence_name to None, then send a CONFIRM.
        orchestrator._sessions[sid].competence_name = None
        reply = orchestrator.handle_chat_message(sid, "ok")  # CONFIRM
        assert orchestrator._sessions[sid].state == LearnState.NAMING
        assert reply is not None
        assert "nom" in reply.lower() or "appeler" in reply.lower()
        mock_streaming.competence_persist.assert_not_called()

    def test_confirm_blocked_when_name_empty(self, orchestrator, mock_streaming):
        """Fix #3: CONFIRM in NAMING with a blank competence_name stays in NAMING."""
        st, _ = orchestrator.start_session(
            user_id="dom", machine_id="machine_x"
        )
        sid = st.session_id
        orchestrator.handle_chat_message(sid, "c'est bon")
        orchestrator.handle_chat_message(sid, "c'est parfait")  # -> NAMING
        orchestrator._sessions[sid].competence_name = "   "  # empty once stripped
        reply = orchestrator.handle_chat_message(sid, "ok")
        assert orchestrator._sessions[sid].state == LearnState.NAMING
        assert reply is not None
        assert "nom" in reply.lower() or "appeler" in reply.lower()
        mock_streaming.competence_persist.assert_not_called()
|
||||
Reference in New Issue
Block a user