diff --git a/agent_v0/server_v1/safety_checks_provider.py b/agent_v0/server_v1/safety_checks_provider.py
new file mode 100644
index 000000000..a7bc4148f
--- /dev/null
+++ b/agent_v0/server_v1/safety_checks_provider.py
@@ -0,0 +1,191 @@
+# agent_v0/server_v1/safety_checks_provider.py
+"""SafetyChecksProvider: hybrid declarative + contextual LLM checks (QW4).
+
+For a pause_for_human action:
+- the declarative checks (from the workflow) are always included
+- if safety_level == "medical_critical" and RPA_SAFETY_CHECKS_LLM_ENABLED=1,
+  an LLM call (medgemma:4b by default) adds up to N contextual checks
+
+Any failure on the LLM side (timeout, exception, parse) → additional_checks=[]:
+the replay continues with the declarative checks only (safe fallback).
+"""
+
+import base64
+import json
+import logging
+import os
+import uuid
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class PausePayload:
+    checks: List[Dict[str, Any]] = field(default_factory=list)
+    pause_reason: str = ""
+    message: str = ""
+
+
+def _env(name: str, default: str) -> str:
+    return os.environ.get(name, default).strip()
+
+
+def _env_int(name: str, default: int) -> int:
+    try:
+        return int(os.environ.get(name, default))
+    except (TypeError, ValueError):
+        return default
+
+
+def _env_bool_enabled(name: str) -> bool:
+    val = os.environ.get(name, "1").strip().lower()
+    return val not in ("0", "false", "no", "off", "")
+
+
+def build_pause_payload(
+    action: Dict[str, Any],
+    replay_state: Dict[str, Any],
+    last_screenshot: Optional[str],
+) -> PausePayload:
+    """Build the enriched pause payload for a pause_for_human action."""
+    params = action.get("parameters") or {}
+    message = params.get("message", "Validation required")
+    safety_level = params.get("safety_level")
+    declarative = params.get("safety_checks") or []
+
+    # Normalize the declarative checks
+    checks: List[Dict[str, Any]] = []
+    for d in declarative:
+        checks.append({
+            "id": d.get("id") or f"decl_{uuid.uuid4().hex[:6]}",
+            "label": d.get("label", "Validation"),
+            "required": bool(d.get("required", True)),
+            "source": "declarative",
+            "evidence": None,
+        })
+
+    # Add contextual LLM checks when applicable
+    if safety_level == "medical_critical" and _env_bool_enabled("RPA_SAFETY_CHECKS_LLM_ENABLED"):
+        try:
+            additional = _call_llm_for_contextual_checks(
+                action=action,
+                replay_state=replay_state,
+                last_screenshot=last_screenshot,
+                existing_labels=[c["label"] for c in checks],
+            )
+        except Exception as e:
+            logger.warning("safety_checks LLM exception (%s), safe fallback", e)
+            additional = []
+
+        for a in additional:
+            checks.append({
+                "id": f"llm_{uuid.uuid4().hex[:6]}",
+                "label": a.get("label", ""),
+                "required": False,  # LLM checks are informational, not mandatory in V1
+                "source": "llm_contextual",
+                "evidence": a.get("evidence", ""),
+            })
+
+    return PausePayload(
+        checks=checks,
+        pause_reason="",
+        message=message,
+    )
+
+
+def _call_llm_for_contextual_checks(
+    action: Dict[str, Any],
+    replay_state: Dict[str, Any],
+    last_screenshot: Optional[str],
+    existing_labels: List[str],
+) -> List[Dict[str, str]]:
+    """Call Ollama in strict JSON mode to generate 0-N contextual checks.
+
+    Returns:
+        List[{label, evidence}] (at most RPA_SAFETY_CHECKS_LLM_MAX_CHECKS items).
+        [] on any failure (timeout, invalid JSON, exception).
+    """
+    # Lazy import: the module stays importable when requests is absent,
+    # as long as the LLM path is never taken.
+    import requests
+
+    model = _env("RPA_SAFETY_CHECKS_LLM_MODEL", "medgemma:4b")
+    timeout_s = _env_int("RPA_SAFETY_CHECKS_LLM_TIMEOUT_S", 5)
+    max_checks = _env_int("RPA_SAFETY_CHECKS_LLM_MAX_CHECKS", 3)
+    ollama_url = _env("OLLAMA_URL", "http://localhost:11434")
+
+    params = action.get("parameters") or {}
+    workflow_message = params.get("message", "")
+    existing = ", ".join(existing_labels) if existing_labels else "none"
+
+    prompt = f"""You are Léa, a supervised medical assistant.
+Before the workflow continues, you must list 0 to {max_checks} additional checks
+that the human has to acknowledge, based on the current screen.
+
+Workflow context: {workflow_message}
+Checks already requested: {existing}
+
+Do NOT repeat a check that was already requested.
+If there is nothing unusual to report, return {{"additional_checks": []}}.
+
+Respond ONLY in JSON:
+{{
+  "additional_checks": [
+    {{"label": "short string", "evidence": "what looked unusual on screen"}}
+  ]
+}}
+"""
+
+    payload = {
+        "model": model,
+        "prompt": prompt,
+        "stream": False,
+        "format": "json",
+        "options": {"temperature": 0.1, "num_predict": 200},
+    }
+
+    if last_screenshot and os.path.isfile(last_screenshot):
+        try:
+            with open(last_screenshot, "rb") as f:
+                payload["images"] = [base64.b64encode(f.read()).decode("ascii")]
+        except Exception as e:
+            logger.debug("safety_checks: screenshot read failed (%s), calling without image", e)
+
+    try:
+        response = requests.post(
+            f"{ollama_url}/api/generate",
+            json=payload,
+            timeout=timeout_s,
+        )
+        if response.status_code != 200:
+            logger.warning("safety_checks LLM HTTP %s", response.status_code)
+            return []
+        text = response.json().get("response", "").strip()
+    except requests.Timeout:
+        logger.warning("safety_checks LLM timeout (%ss)", timeout_s)
+        return []
+    except Exception as e:
+        logger.warning("safety_checks LLM network error: %s", e)
+        return []
+
+    # format=json normally guarantees valid JSON, but parse defensively
+    try:
+        parsed = json.loads(text)
+    except json.JSONDecodeError as e:
+        logger.warning("safety_checks LLM invalid JSON (%s), safe fallback", e)
+        return []
+
+    additional = parsed.get("additional_checks") or []
+    if not isinstance(additional, list):
+        return []
+
+    # Filter and truncate
+    valid = []
+    for item in additional[:max_checks]:
+        if isinstance(item, dict) and item.get("label"):
+            valid.append({
+                "label": str(item["label"])[:200],
+                "evidence": str(item.get("evidence", ""))[:300],
+            })
+    return valid
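
Usage note (not part of the patch): a minimal sketch of how a replay engine might invoke build_pause_payload when it reaches a pause_for_human action. The action contents, the empty replay_state, and the absent screenshot below are hypothetical; only the parameter names are taken from what the module above actually reads.

    import os
    from agent_v0.server_v1.safety_checks_provider import build_pause_payload

    # Disable the LLM path so this sketch runs without a live Ollama instance.
    os.environ["RPA_SAFETY_CHECKS_LLM_ENABLED"] = "0"

    # Hypothetical action, shaped the way build_pause_payload reads it.
    action = {
        "type": "pause_for_human",
        "parameters": {
            "message": "Confirm the patient record before submitting",
            "safety_level": "medical_critical",
            "safety_checks": [
                {"id": "c1", "label": "Verify IPP", "required": True},
            ],
        },
    }

    payload = build_pause_payload(action, replay_state={}, last_screenshot=None)
    for check in payload.checks:
        # source is "declarative" or "llm_contextual"; LLM checks are never required in V1
        print(check["source"], check["required"], check["label"])
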
+ """ + import requests + + model = _env("RPA_SAFETY_CHECKS_LLM_MODEL", "medgemma:4b") + timeout_s = _env_int("RPA_SAFETY_CHECKS_LLM_TIMEOUT_S", 5) + max_checks = _env_int("RPA_SAFETY_CHECKS_LLM_MAX_CHECKS", 3) + ollama_url = _env("OLLAMA_URL", "http://localhost:11434") + + params = action.get("parameters") or {} + workflow_message = params.get("message", "") + existing = ", ".join(existing_labels) if existing_labels else "aucun" + + prompt = f"""Tu es Léa, assistante médicale supervisée. +Avant de continuer le workflow, tu dois lister 0 à {max_checks} vérifications supplémentaires +que l'humain doit acquitter, en regardant l'écran actuel. + +Contexte workflow : {workflow_message} +Checks déjà demandés : {existing} + +NE répète PAS un check déjà demandé. +Si rien d'inhabituel à signaler, retourne {{"additional_checks": []}}. + +Réponds UNIQUEMENT en JSON : +{{ + "additional_checks": [ + {{"label": "string court", "evidence": "ce que tu as vu d'inhabituel"}} + ] +}} +""" + + payload = { + "model": model, + "prompt": prompt, + "stream": False, + "format": "json", + "options": {"temperature": 0.1, "num_predict": 200}, + } + + if last_screenshot and os.path.isfile(last_screenshot): + try: + with open(last_screenshot, "rb") as f: + payload["images"] = [base64.b64encode(f.read()).decode("ascii")] + except Exception as e: + logger.debug("safety_checks: lecture screenshot échouée (%s) — appel sans image", e) + + try: + response = requests.post( + f"{ollama_url}/api/generate", + json=payload, + timeout=timeout_s, + ) + if response.status_code != 200: + logger.warning("safety_checks LLM HTTP %s", response.status_code) + return [] + text = response.json().get("response", "").strip() + except requests.Timeout: + logger.warning("safety_checks LLM timeout (%ss)", timeout_s) + return [] + except Exception as e: + logger.warning("safety_checks LLM erreur réseau: %s", e) + return [] + + # format=json garantit normalement du JSON valide + try: + parsed = json.loads(text) + except json.JSONDecodeError as e: + logger.warning("safety_checks LLM JSON invalide (%s) — fallback safe", e) + return [] + + additional = parsed.get("additional_checks") or [] + if not isinstance(additional, list): + return [] + + # Filtre + tronc + valid = [] + for item in additional[:max_checks]: + if isinstance(item, dict) and item.get("label"): + valid.append({ + "label": str(item["label"])[:200], + "evidence": str(item.get("evidence", ""))[:300], + }) + return valid diff --git a/tests/unit/test_safety_checks_provider.py b/tests/unit/test_safety_checks_provider.py new file mode 100644 index 000000000..e87a57274 --- /dev/null +++ b/tests/unit/test_safety_checks_provider.py @@ -0,0 +1,111 @@ +# tests/unit/test_safety_checks_provider.py +"""Tests unitaires SafetyChecksProvider (QW4).""" +import json +import pytest +from unittest.mock import patch, MagicMock + +from agent_v0.server_v1.safety_checks_provider import build_pause_payload, PausePayload + + +def _action(safety_level=None, declarative_checks=None, message="Validation"): + params = {"message": message} + if safety_level: + params["safety_level"] = safety_level + if declarative_checks is not None: + params["safety_checks"] = declarative_checks + return {"type": "pause_for_human", "parameters": params} + + +def test_only_declarative_when_no_safety_level(): + """Pas de safety_level → uniquement les checks déclaratifs, pas d'appel LLM.""" + decl = [{"id": "c1", "label": "Vérifier IPP", "required": True}] + with 
patch("agent_v0.server_v1.safety_checks_provider._call_llm_for_contextual_checks") as mock_llm: + payload = build_pause_payload(_action(declarative_checks=decl), {}, last_screenshot=None) + mock_llm.assert_not_called() + assert len(payload.checks) == 1 + assert payload.checks[0]["source"] == "declarative" + + +def test_hybrid_appends_llm_checks_on_medical_critical(monkeypatch): + """safety_level=medical_critical → LLM appelé, checks concaténés.""" + decl = [{"id": "c1", "label": "Vérifier IPP", "required": True}] + llm_resp = [{"label": "Nom patient suspect à l'écran", "evidence": "vu un nom différent"}] + + with patch("agent_v0.server_v1.safety_checks_provider._call_llm_for_contextual_checks", + return_value=llm_resp) as mock_llm: + payload = build_pause_payload( + _action(safety_level="medical_critical", declarative_checks=decl), + {}, last_screenshot="/tmp/fake.png", + ) + mock_llm.assert_called_once() + assert len(payload.checks) == 2 + assert payload.checks[0]["source"] == "declarative" + assert payload.checks[1]["source"] == "llm_contextual" + assert payload.checks[1]["evidence"] == "vu un nom différent" + + +def test_llm_timeout_falls_back_to_declarative_only(): + """LLM timeout → additional_checks=[], pas de crash, déclaratifs gardés.""" + decl = [{"id": "c1", "label": "Vérifier IPP", "required": True}] + with patch("agent_v0.server_v1.safety_checks_provider._call_llm_for_contextual_checks", + return_value=[]) as mock_llm: + payload = build_pause_payload( + _action(safety_level="medical_critical", declarative_checks=decl), + {}, last_screenshot="/tmp/fake.png", + ) + assert len(payload.checks) == 1 + assert payload.checks[0]["source"] == "declarative" + + +def test_llm_invalid_response_falls_back(): + """Si _call_llm retourne [] (parse échoué en interne) → fallback safe.""" + with patch("agent_v0.server_v1.safety_checks_provider._call_llm_for_contextual_checks", + return_value=[]): + payload = build_pause_payload( + _action(safety_level="medical_critical", declarative_checks=[]), + {}, last_screenshot="/tmp/fake.png", + ) + assert payload.checks == [] + + +def test_kill_switch_disables_llm_call(monkeypatch): + """RPA_SAFETY_CHECKS_LLM_ENABLED=0 → LLM jamais appelé.""" + monkeypatch.setenv("RPA_SAFETY_CHECKS_LLM_ENABLED", "0") + decl = [{"id": "c1", "label": "X", "required": True}] + with patch("agent_v0.server_v1.safety_checks_provider._call_llm_for_contextual_checks") as mock_llm: + payload = build_pause_payload( + _action(safety_level="medical_critical", declarative_checks=decl), + {}, last_screenshot="/tmp/fake.png", + ) + mock_llm.assert_not_called() + assert len(payload.checks) == 1 + + +def test_max_checks_respected(monkeypatch): + """RPA_SAFETY_CHECKS_LLM_MAX_CHECKS=2 → max 2 checks LLM ajoutés.""" + monkeypatch.setenv("RPA_SAFETY_CHECKS_LLM_MAX_CHECKS", "2") + decl = [] + llm_resp = [ + {"label": f"Check {i}", "evidence": f"e{i}"} for i in range(5) + ] + with patch("agent_v0.server_v1.safety_checks_provider._call_llm_for_contextual_checks", + return_value=llm_resp[:2]): # provider tronque déjà + payload = build_pause_payload( + _action(safety_level="medical_critical", declarative_checks=decl), + {}, last_screenshot="/tmp/fake.png", + ) + assert len(payload.checks) == 2 + + +def test_empty_declarative_with_llm_returns_only_llm(): + """Pas de déclaratif + LLM ajoute 2 checks → payload contient les 2.""" + llm_resp = [{"label": "Vérifier date", "evidence": "date 1900 suspecte"}, + {"label": "Vérifier devise", "evidence": "montant en USD au lieu d'EUR"}] + with 
patch("agent_v0.server_v1.safety_checks_provider._call_llm_for_contextual_checks", + return_value=llm_resp): + payload = build_pause_payload( + _action(safety_level="medical_critical", declarative_checks=[]), + {}, last_screenshot="/tmp/fake.png", + ) + assert len(payload.checks) == 2 + assert all(c["source"] == "llm_contextual" for c in payload.checks)