"""Tests unitaires pour agent_chat.handlers.learn_action.

Couvre :
- LearnIntentParser (regex)
- OptionCFormatter
- StateStore (écriture atomique + reprise)
- LearnActionOrchestrator (transitions, garde-fous, persistance)
- PersistPayloadBuilder
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
from agent_chat.handlers.learn_action import (
|
|
LearnActionOrchestrator,
|
|
LearnIntent,
|
|
LearnIntentParser,
|
|
LearnState,
|
|
OptionCFormatter,
|
|
PersistPayloadBuilder,
|
|
SessionState,
|
|
StateStore,
|
|
)
|
|
|
|
|
|
# ============================================================
|
|
# LearnIntentParser
|
|
# ============================================================
|
|
class TestLearnIntentParser:
    """Tests for ``LearnIntentParser``.

    All tests except the LLM-fallback one use a parser built with
    ``use_llm_fallback=False`` so that only the regex path is exercised.
    """

    def setup_method(self):
        # Disable the LLM fallback to isolate the regex tests.
        self.parser = LearnIntentParser(use_llm_fallback=False)

    @pytest.mark.parametrize(
        "msg",
        [
            "apprends-moi",
            "Apprends moi",
            "regarde-moi faire",
            "observe",
            "enregistre",
            "on apprend",
            "tu vas apprendre",
            "Léa apprends",
        ],
    )
    def test_start_observe(self, msg):
        """Each start phrase sent in IDLE yields START_OBSERVE with confidence >= 0.9."""
        r = self.parser.parse(msg, current_state=LearnState.IDLE)
        assert r.intent == LearnIntent.START_OBSERVE
        assert r.confidence >= 0.9

    @pytest.mark.parametrize(
        "msg",
        [
            "stop",
            "c'est bon",
            "j'ai fini",
            "voilà c'est tout",
            "fini",
            "arrête",
            "termine",
        ],
    )
    def test_user_stop_observe(self, msg):
        """Each stop phrase sent in WAITING_USER_STOP yields USER_STOP_OBSERVE."""
        r = self.parser.parse(msg, current_state=LearnState.WAITING_USER_STOP)
        assert r.intent == LearnIntent.USER_STOP_OBSERVE

    def test_correct_step_with_index(self):
        """A correction message exposes the step index and the new intent text."""
        r = self.parser.parse(
            "Corrige l'étape 3 : il faut cliquer sur Valider",
            current_state=LearnState.ITERATING_FEEDBACK,
        )
        assert r.intent == LearnIntent.CORRECT_STEP
        assert r.step_index == 3
        assert "valider" in (r.extra.get("new_intent") or "").lower()

    def test_undo_step(self):
        """A removal message yields UNDO_STEP with the step index."""
        r = self.parser.parse(
            "Retire l'étape 2", current_state=LearnState.ITERATING_FEEDBACK
        )
        assert r.intent == LearnIntent.UNDO_STEP
        assert r.step_index == 2

    def test_merge_next(self):
        """A merge message yields MERGE_NEXT."""
        r = self.parser.parse(
            "Fusionne avec la suivante", current_state=LearnState.ITERATING_FEEDBACK
        )
        assert r.intent == LearnIntent.MERGE_NEXT

    def test_split_step(self):
        """A split message yields SPLIT_STEP with the step index."""
        r = self.parser.parse(
            "Coupe l'étape 4", current_state=LearnState.ITERATING_FEEDBACK
        )
        assert r.intent == LearnIntent.SPLIT_STEP
        assert r.step_index == 4

    def test_cancel(self):
        """A cancel message sent in LISTENING yields CANCEL."""
        r = self.parser.parse("annule tout", current_state=LearnState.LISTENING)
        assert r.intent == LearnIntent.CANCEL

    def test_validate_in_iterating(self):
        """An approval message sent in ITERATING_FEEDBACK yields VALIDATE_STEP."""
        r = self.parser.parse(
            "c'est parfait", current_state=LearnState.ITERATING_FEEDBACK
        )
        assert r.intent == LearnIntent.VALIDATE_STEP

    def test_mark_parameter_variable(self):
        """A "changes every time" message marks the value as a parameter."""
        r = self.parser.parse(
            "ça change à chaque fois", current_state=LearnState.NAMING
        )
        assert r.intent == LearnIntent.MARK_PARAMETER
        assert r.extra.get("is_parameter") is True

    def test_mark_parameter_constant(self):
        """An "always the same" message marks the value as not a parameter."""
        r = self.parser.parse(
            "toujours pareil", current_state=LearnState.NAMING
        )
        assert r.intent == LearnIntent.MARK_PARAMETER
        assert r.extra.get("is_parameter") is False

    def test_name_competence_when_naming(self):
        """Free text sent in NAMING is taken as the competence name."""
        r = self.parser.parse(
            "facturation urgences", current_state=LearnState.NAMING
        )
        assert r.intent == LearnIntent.NAME_COMPETENCE
        assert "facturation" in (r.extra.get("name") or "")

    def test_unknown_in_idle(self):
        """Unrecognised text sent in IDLE yields UNKNOWN."""
        r = self.parser.parse(
            "blabla random", current_state=LearnState.IDLE
        )
        assert r.intent == LearnIntent.UNKNOWN

    def test_llm_fallback_disabled_after_failure(self, monkeypatch):
        """With the LLM fallback enabled but yielding nothing, parsing gives UNKNOWN."""
        parser = LearnIntentParser(use_llm_fallback=True)
        # Stub the LLM call so it returns None: no network access is attempted
        # and the parser has to cope with the fallback producing no result.
        monkeypatch.setattr(parser, "_parse_llm", lambda *args, **kwargs: None)
        r = parser.parse("zorglub blabla truc", current_state=LearnState.IDLE)
        # Must degrade gracefully to UNKNOWN without crashing.
        assert r.intent == LearnIntent.UNKNOWN
|
|
|
|
|
|
# ============================================================
|
|
# OptionCFormatter
|
|
# ============================================================
|
|
class TestOptionCFormatter:
    """Tests for the text rendered by ``OptionCFormatter``."""

    def setup_method(self):
        self.fmt = OptionCFormatter()

    def test_empty(self):
        """An empty step list is rendered with the "aucune étape" wording."""
        rendered = self.fmt.format([])
        assert "aucune étape" in rendered

    def test_simple_click(self):
        """A click step is numbered and shows the quoted target and the click verb."""
        click_step = {
            "action_type": "click",
            "target_label": "Valider",
            "widget_type": "Bouton",
        }
        rendered = self.fmt.format([click_step])
        for fragment in ("1.", "« Valider »", "cliqué"):
            assert fragment in rendered

    def test_type_with_value(self):
        """A typing step shows the quoted field label, the quoted value and the typing verb."""
        typing_step = {
            "action_type": "type",
            "target_label": "IPP",
            "widget_type": "Champ",
            "value": "25003284",
        }
        rendered = self.fmt.format([typing_step])
        for fragment in ("« IPP »", "« 25003284 »", "saisi"):
            assert fragment in rendered

    def test_low_confidence_suffix(self):
        """A step with ``confidence_ocr`` of 0.4 carries the "(à confirmer)" suffix."""
        uncertain_step = {
            "action_type": "click",
            "target_label": "Patient",
            "widget_type": "Fenêtre",
            "confidence_ocr": 0.4,
        }
        rendered = self.fmt.format([uncertain_step])
        assert "(à confirmer)" in rendered

    def test_unknown_action_fallback(self):
        """An unrecognised action type is rendered with the generic "effectuée" wording."""
        rendered = self.fmt.format([{"action_type": "wibble", "target_label": "X"}])
        assert "effectuée" in rendered

    def test_closing_question(self):
        """The closing question contains "trompée", with or without the accent."""
        question = self.fmt.closing_question()
        accent_folded = question.lower().replace("é", "e")
        assert "trompée" in question or "trompee" in accent_folded
|
|
|
|
|
|
# ============================================================
|
|
# StateStore
|
|
# ============================================================
|
|
class TestStateStore:
    """Tests for ``StateStore`` persistence, each using a fresh ``tmp_path`` directory."""

    def test_save_and_load(self, tmp_path):
        """A saved session is loaded back with the same id, user and state."""
        store = StateStore(tmp_path)
        original = SessionState(
            session_id="abc123",
            user_id="dom",
            state=LearnState.ITERATING_FEEDBACK,
        )
        store.save(original)

        restored = store.load("abc123")
        assert restored is not None
        assert restored.session_id == "abc123"
        assert restored.user_id == "dom"
        assert restored.state == LearnState.ITERATING_FEEDBACK

    def test_atomic_write_no_partial(self, tmp_path):
        """Saving leaves no ``*.tmp`` file behind in the store directory."""
        store = StateStore(tmp_path)
        store.save(SessionState(session_id="atomic1"))
        leftovers = list(tmp_path.glob("*.tmp"))
        assert not leftovers

    def test_list_active_filters_done(self, tmp_path):
        """``list_active`` returns neither DONE nor ABORTED sessions."""
        store = StateStore(tmp_path)
        saved_states = (
            ("s1", LearnState.ITERATING_FEEDBACK),
            ("s2", LearnState.DONE),
            ("s3", LearnState.ABORTED),
        )
        for session_id, state in saved_states:
            store.save(SessionState(session_id=session_id, state=state))

        active_ids = {session.session_id for session in store.list_active()}
        assert active_ids == {"s1"}

    def test_session_id_sanitized(self, tmp_path):
        """A path-traversal session id still produces exactly one JSON file in ``tmp_path``."""
        store = StateStore(tmp_path)
        store.save(SessionState(session_id="../../etc/passwd"))

        # NOTE(review): the glob is non-recursive, so the parent check below
        # holds by construction; the file count is the effective assertion.
        json_files = list(tmp_path.glob("*.json"))
        assert len(json_files) == 1
        assert json_files[0].parent == tmp_path

    def test_delete(self, tmp_path):
        """A deleted session can no longer be loaded."""
        store = StateStore(tmp_path)
        store.save(SessionState(session_id="del_me"))
        store.delete("del_me")
        reloaded = store.load("del_me")
        assert reloaded is None
|
|
|
|
|
|
# ============================================================
|
|
# PersistPayloadBuilder
|
|
# ============================================================
|
|
class TestPersistPayloadBuilder:
    """Tests for the payload produced by ``PersistPayloadBuilder.build``."""

    def test_build_with_parameters(self):
        """Identity fields are copied and only ``is_parameter=True`` entries are kept."""
        variable_field = {
            "step_index": 3,
            "is_parameter": True,
            "name": "ipp",
            "example_value": "25003284",
            "field_label": "IPP",
        }
        constant_field = {
            "step_index": 4,
            "is_parameter": False,
            "name": "type",
            "example_value": "C2",
            "field_label": "Type",
        }
        session = SessionState(
            session_id="sX",
            competence_name="Test compétence",
            user_id="dom",
            parameters_marked=[variable_field, constant_field],
        )

        payload = PersistPayloadBuilder().build(session)

        assert payload["name"] == "Test compétence"
        assert payload["session_id"] == "sX"
        assert payload["user_id"] == "dom"
        # Only the entry flagged is_parameter=True may appear.
        kept = payload["parameters"]
        assert len(kept) == 1
        assert kept[0]["name"] == "ipp"

    def test_persist_payload_includes_machine_id(self):
        """Fix #1: the payload must include the session's machine_id."""
        session = SessionState(
            session_id="sM",
            competence_name="X",
            machine_id="DESKTOP-58D5CAC_windows",
        )
        payload = PersistPayloadBuilder().build(session)
        assert "machine_id" in payload
        assert payload["machine_id"] == "DESKTOP-58D5CAC_windows"

    def test_persist_payload_machine_id_none_when_absent(self):
        """When not provided, machine_id is still present in the payload, set to None."""
        session = SessionState(session_id="sM2", competence_name="X")
        payload = PersistPayloadBuilder().build(session)
        assert "machine_id" in payload
        assert payload["machine_id"] is None
|
|
|
|
|
|
# ============================================================
|
|
# LearnActionOrchestrator (avec StreamingClient mocké)
|
|
# ============================================================
|
|
@pytest.fixture
def mock_streaming():
    """Simulated StreamingClient with canned return values.

    The shadow start/stop/feedback/build calls return ``{"ok": True}``,
    ``shadow_understanding`` returns a two-step understanding (a click, then
    a typed IPP value) and ``competence_persist`` returns a slug.
    """
    observed_steps = [
        {"action_type": "click", "target_label": "Patient", "widget_type": "Fenêtre"},
        {
            "action_type": "type",
            "target_label": "IPP",
            "widget_type": "Champ",
            "value": "25003284",
        },
    ]

    client = MagicMock()
    for endpoint in ("shadow_start", "shadow_stop", "shadow_feedback", "shadow_build"):
        # A new dict per endpoint, so the canned responses are not shared.
        getattr(client, endpoint).return_value = {"ok": True}
    client.shadow_understanding.return_value = {"understanding": observed_steps}
    client.competence_persist.return_value = {"slug": "facturation_urgences"}
    return client
|
|
|
|
|
|
@pytest.fixture
def orchestrator(tmp_path, mock_streaming):
    """Orchestrator wired to the mocked streaming client.

    Uses a regex-only intent parser, a ``StateStore`` rooted at ``tmp_path``
    and a ``MagicMock`` as the ``emit`` callback.
    """
    dependencies = {
        "streaming_client": mock_streaming,
        "intent_parser": LearnIntentParser(use_llm_fallback=False),
        "state_store": StateStore(tmp_path),
        "emit": MagicMock(),
    }
    return LearnActionOrchestrator(**dependencies)
|
|
|
|
|
|
class TestLearnActionOrchestrator:
    """Tests for ``LearnActionOrchestrator`` driven through a mocked streaming client.

    The ``orchestrator`` and ``mock_streaming`` fixtures supply the object
    under test and its fake backend; conversations are driven with
    ``handle_chat_message``.
    """

    @staticmethod
    def _build_orchestrator(streaming_client, parser, store):
        """Return a new orchestrator over the given client, parser and store."""
        return LearnActionOrchestrator(
            streaming_client=streaming_client,
            intent_parser=parser,
            state_store=store,
            emit=MagicMock(),
        )

    @staticmethod
    def _assert_confirm_blocked(orchestrator, mock_streaming, forced_name):
        """Drive a session to NAMING, force ``competence_name`` and check CONFIRM is refused."""
        st, _ = orchestrator.start_session(
            user_id="dom", machine_id="machine_x"
        )
        sid = st.session_id
        orchestrator.handle_chat_message(sid, "c'est bon")
        orchestrator.handle_chat_message(sid, "c'est parfait")  # -> NAMING
        # Force the unusable name, then send a CONFIRM.
        orchestrator._sessions[sid].competence_name = forced_name
        reply = orchestrator.handle_chat_message(sid, "ok")  # CONFIRM
        assert orchestrator._sessions[sid].state == LearnState.NAMING
        assert reply is not None
        assert "nom" in reply.lower() or "appeler" in reply.lower()
        mock_streaming.competence_persist.assert_not_called()

    def test_start_session_transitions(self, orchestrator, mock_streaming):
        """Starting a session calls shadow_start once and waits for the user's stop."""
        st, reply = orchestrator.start_session(user_id="dom", trigger_source="button")
        assert st.state == LearnState.WAITING_USER_STOP
        mock_streaming.shadow_start.assert_called_once()
        # "regarde" also covers the longer "je te regarde" wording.
        assert "regarde" in reply.lower()

    def test_full_happy_path(self, orchestrator, mock_streaming):
        """Stop, validate, name and mark the parameter: the competence is built and persisted."""
        st, _ = orchestrator.start_session(user_id="dom", machine_id="m1")
        sid = st.session_id

        # The user says stop.
        reply = orchestrator.handle_chat_message(sid, "c'est bon")
        assert reply is not None
        assert "j'ai compris" in reply.lower()
        assert orchestrator._sessions[sid].state == LearnState.ITERATING_FEEDBACK

        # The user validates everything -> NAMING.
        orchestrator.handle_chat_message(sid, "c'est parfait")
        assert orchestrator._sessions[sid].state == LearnState.NAMING

        # Naming.
        reply = orchestrator.handle_chat_message(sid, "facturation urgences")
        # Léa must now ask about the IPP parameter.
        assert "25003284" in (reply or "")
        assert orchestrator._sessions[sid].competence_name == "facturation urgences"

        # Mark the parameter as variable.
        orchestrator.handle_chat_message(sid, "ça change à chaque fois")
        # Nothing left pending -> persist.
        mock_streaming.shadow_build.assert_called_once()
        mock_streaming.competence_persist.assert_called_once()
        assert orchestrator._sessions[sid].state == LearnState.DONE

    def test_emergency_exit_after_3_corrections(self, orchestrator, mock_streaming):
        """Three corrections of one step are tolerated; the fourth aborts the session."""
        st, _ = orchestrator.start_session(user_id="dom")
        sid = st.session_id
        orchestrator.handle_chat_message(sid, "c'est bon")  # stop

        for _ in range(3):
            orchestrator.handle_chat_message(
                sid, "corrige l'étape 3 : clique sur Valider"
            )
            assert orchestrator._sessions[sid].state == LearnState.ITERATING_FEEDBACK

        # 4th correction -> ABORTED.
        r = orchestrator.handle_chat_message(
            sid, "corrige l'étape 3 : clique sur Valider"
        )
        assert orchestrator._sessions[sid].state == LearnState.ABORTED
        assert "n°3" in (r or "")

    def test_cancel_anywhere(self, orchestrator, mock_streaming):
        """A cancel message right after start aborts the session."""
        st, _ = orchestrator.start_session(user_id="dom")
        sid = st.session_id
        reply = orchestrator.handle_chat_message(sid, "annule tout")
        assert orchestrator._sessions[sid].state == LearnState.ABORTED
        assert "annule" in (reply or "").lower()

    def test_idle_message_returns_none(self, orchestrator):
        """A message for an unknown session returns None."""
        # No open session -> None (let the normal chat flow handle it).
        r = orchestrator.handle_chat_message("nonexistent", "Bonjour")
        assert r is None

    def test_state_persistence_across_reload(self, tmp_path, mock_streaming):
        """A second orchestrator over the same store resumes the session in its last state."""
        store = StateStore(tmp_path)
        parser = LearnIntentParser(use_llm_fallback=False)
        orch1 = self._build_orchestrator(mock_streaming, parser, store)
        st, _ = orch1.start_session(user_id="dom")
        sid = st.session_id
        orch1.handle_chat_message(sid, "c'est bon")  # moves to ITERATING_FEEDBACK

        # Simulate a crash followed by a restart.
        orch2 = self._build_orchestrator(mock_streaming, parser, store)
        resumed = orch2.resume_sessions()
        assert sid in resumed
        assert orch2._sessions[sid].state == LearnState.ITERATING_FEEDBACK

    def test_proactive_signal_cooldown(self, orchestrator):
        """A second proactive signal sent immediately after the first is ignored."""
        r1 = orchestrator.handle_proactive_signal("action_repeat", {})
        assert r1 is not None
        # Second signal right away -> ignored.
        r2 = orchestrator.handle_proactive_signal("action_repeat", {})
        assert r2 is None

    def test_illegal_transition_ignored(self, orchestrator, mock_streaming):
        """A direct WAITING_USER_STOP -> DONE transition leaves the state unchanged."""
        st, _ = orchestrator.start_session(user_id="dom")
        session = orchestrator._sessions[st.session_id]
        # Attempt to jump straight from WAITING_USER_STOP to DONE.
        prev = session.state
        orchestrator._transition(session, LearnState.DONE)
        assert orchestrator._sessions[st.session_id].state == prev

    # ============================================================
    # Fixes P1-LEA-SHADOW 2026-06-01 (NO-GO Qwen)
    # ============================================================
    def test_start_session_stores_machine_id(self, orchestrator):
        """Fix #1: the machine_id passed to start_session is stored."""
        st, _ = orchestrator.start_session(
            user_id="dom",
            trigger_source="windows_button",
            machine_id="DESKTOP-58D5CAC_windows",
        )
        assert st.machine_id == "DESKTOP-58D5CAC_windows"
        # The in-memory session carries it as well.
        assert (
            orchestrator._sessions[st.session_id].machine_id
            == "DESKTOP-58D5CAC_windows"
        )

    def test_persist_blocked_without_machine_id(self, orchestrator, mock_streaming):
        """Fix #1: without a machine_id, persist is refused in the conversation."""
        st, _ = orchestrator.start_session(user_id="dom")  # no machine_id
        sid = st.session_id
        orchestrator.handle_chat_message(sid, "c'est bon")  # -> ITERATING
        orchestrator.handle_chat_message(sid, "c'est parfait")  # -> NAMING
        orchestrator.handle_chat_message(sid, "ma competence")  # name
        # Mark the parameter -> persist attempt.
        reply = orchestrator.handle_chat_message(sid, "ça change à chaque fois")
        # competence_persist must NOT have been called.
        mock_streaming.competence_persist.assert_not_called()
        # Léa gives an explicit message mentioning the machine.
        assert reply is not None
        assert "machine" in reply.lower()

    def test_datetime_uses_timezone_aware(self):
        """Fix #2: created_at / last_transition_at are timezone-aware."""
        st = SessionState(session_id="tz1")
        # The ISO string must carry an offset (+00:00 or Z), so tzinfo is set
        # after re-parsing with fromisoformat (Python 3.11+).
        from datetime import datetime as _dt

        parsed_created = _dt.fromisoformat(st.created_at)
        parsed_transition = _dt.fromisoformat(st.last_transition_at)
        assert parsed_created.tzinfo is not None
        assert parsed_transition.tzinfo is not None
        # Sanity check: the offset really is UTC.
        assert "+00:00" in st.created_at or st.created_at.endswith("Z")

    def test_confirm_blocked_when_name_missing(self, orchestrator, mock_streaming):
        """Fix #3: CONFIRM in NAMING with competence_name=None stays in NAMING."""
        self._assert_confirm_blocked(orchestrator, mock_streaming, None)

    def test_confirm_blocked_when_name_empty(self, orchestrator, mock_streaming):
        """Fix #3: CONFIRM in NAMING with a whitespace-only competence_name stays in NAMING."""
        # " " is empty once stripped.
        self._assert_confirm_blocked(orchestrator, mock_streaming, " ")
|