feat(qw4): SafetyChecksProvider with hybrid declarative + contextual LLM checks

build_pause_payload(action, state, last_screenshot) → PausePayload
- Always include the declarative checks (workflow.parameters.safety_checks)
- If safety_level=medical_critical AND RPA_SAFETY_CHECKS_LLM_ENABLED=1:
    LLM call (medgemma:4b by default) in strict format=json, 5s timeout,
    up to 3 checks added (configurable via env vars)
- Every error path (timeout, HTTP, JSON parse, exception) logs
  and returns [] (safe fallback: declarative checks only)

Tests: 7 cases (declarative only, hybrid OK, timeout, invalid LLM response,
kill-switch, max_checks, empty declarative).
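
A minimal call sketch (hypothetical action dict; the shape follows the unit tests):

    action = {"type": "pause_for_human", "parameters": {
        "message": "Validation requise",
        "safety_level": "medical_critical",
        "safety_checks": [{"id": "c1", "label": "Vérifier IPP", "required": True}],
    }}
    payload = build_pause_payload(action, {}, last_screenshot=None)
    # payload.checks lists the declarative checks first, then any llm_contextual ones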

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Dom
2026-05-05 23:29:38 +02:00
parent ca0b436a61
commit 7c6945171e
2 changed files with 302 additions and 0 deletions

@@ -0,0 +1,191 @@
# agent_v0/server_v1/safety_checks_provider.py
"""SafetyChecksProvider — checks hybrides déclaratifs + LLM contextuels (QW4).
Pour une action pause_for_human :
- les checks déclaratifs (workflow) sont toujours inclus
- si safety_level == "medical_critical" et RPA_SAFETY_CHECKS_LLM_ENABLED=1,
un appel LLM (medgemma:4b par défaut) ajoute jusqu'à N checks contextuels
Tout échec côté LLM (timeout, exception, parse) → additional_checks=[] :
le replay continue avec uniquement les déclaratifs (fallback safe).
"""
import base64
import json
import logging
import os
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
@dataclass
class PausePayload:
checks: List[Dict[str, Any]] = field(default_factory=list)
pause_reason: str = ""
message: str = ""
def _env(name: str, default: str) -> str:
return os.environ.get(name, default).strip()
def _env_int(name: str, default: int) -> int:
try:
return int(os.environ.get(name, default))
except (TypeError, ValueError):
return default
def _env_bool_enabled(name: str) -> bool:
val = os.environ.get(name, "1").strip().lower()
return val not in ("0", "false", "no", "off", "")
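# Note: a variable explicitly set to the empty string counts as disabled here,
# while an unset variable falls back to the default "1" (enabled).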
def build_pause_payload(
action: Dict[str, Any],
replay_state: Dict[str, Any],
last_screenshot: Optional[str],
) -> PausePayload:
"""Construit le payload de pause enrichi pour une action pause_for_human."""
params = action.get("parameters") or {}
message = params.get("message", "Validation requise")
safety_level = params.get("safety_level")
declarative = params.get("safety_checks") or []
# Normalize the declarative checks
checks: List[Dict[str, Any]] = []
for d in declarative:
checks.append({
"id": d.get("id") or f"decl_{uuid.uuid4().hex[:6]}",
"label": d.get("label", "Validation"),
"required": bool(d.get("required", True)),
"source": "declarative",
"evidence": None,
})
# Add contextual LLM checks when applicable
if safety_level == "medical_critical" and _env_bool_enabled("RPA_SAFETY_CHECKS_LLM_ENABLED"):
try:
additional = _call_llm_for_contextual_checks(
action=action,
replay_state=replay_state,
last_screenshot=last_screenshot,
existing_labels=[c["label"] for c in checks],
)
except Exception as e:
logger.warning("safety_checks LLM exception (%s) — fallback safe", e)
additional = []
for a in additional:
checks.append({
"id": f"llm_{uuid.uuid4().hex[:6]}",
"label": a.get("label", ""),
"required": False, # checks LLM = informationnels, pas obligatoires V1
"source": "llm_contextual",
"evidence": a.get("evidence", ""),
})
return PausePayload(
checks=checks,
pause_reason="",
message=message,
)
def _call_llm_for_contextual_checks(
action: Dict[str, Any],
replay_state: Dict[str, Any],
last_screenshot: Optional[str],
existing_labels: List[str],
) -> List[Dict[str, str]]:
"""Appelle Ollama en mode JSON strict pour générer 0-N checks contextuels.
Returns:
List[{label, evidence}] (max RPA_SAFETY_CHECKS_LLM_MAX_CHECKS).
[] sur tout échec (timeout, JSON invalide, exception).
"""
import requests
model = _env("RPA_SAFETY_CHECKS_LLM_MODEL", "medgemma:4b")
timeout_s = _env_int("RPA_SAFETY_CHECKS_LLM_TIMEOUT_S", 5)
max_checks = _env_int("RPA_SAFETY_CHECKS_LLM_MAX_CHECKS", 3)
ollama_url = _env("OLLAMA_URL", "http://localhost:11434")
params = action.get("parameters") or {}
workflow_message = params.get("message", "")
existing = ", ".join(existing_labels) if existing_labels else "aucun"
prompt = f"""Tu es Léa, assistante médicale supervisée.
Avant de continuer le workflow, tu dois lister 0 à {max_checks} vérifications supplémentaires
que l'humain doit acquitter, en regardant l'écran actuel.
Contexte workflow : {workflow_message}
Checks déjà demandés : {existing}
NE répète PAS un check déjà demandé.
Si rien d'inhabituel à signaler, retourne {{"additional_checks": []}}.
Réponds UNIQUEMENT en JSON :
{{
"additional_checks": [
{{"label": "string court", "evidence": "ce que tu as vu d'inhabituel"}}
]
}}
"""
payload = {
"model": model,
"prompt": prompt,
"stream": False,
"format": "json",
"options": {"temperature": 0.1, "num_predict": 200},
}
if last_screenshot and os.path.isfile(last_screenshot):
try:
with open(last_screenshot, "rb") as f:
payload["images"] = [base64.b64encode(f.read()).decode("ascii")]
except Exception as e:
logger.debug("safety_checks: lecture screenshot échouée (%s) — appel sans image", e)
try:
response = requests.post(
f"{ollama_url}/api/generate",
json=payload,
timeout=timeout_s,
)
if response.status_code != 200:
logger.warning("safety_checks LLM HTTP %s", response.status_code)
return []
text = response.json().get("response", "").strip()
except requests.Timeout:
logger.warning("safety_checks LLM timeout (%ss)", timeout_s)
return []
except Exception as e:
logger.warning("safety_checks LLM erreur réseau: %s", e)
return []
# format=json should normally guarantee valid JSON
try:
parsed = json.loads(text)
except json.JSONDecodeError as e:
logger.warning("safety_checks LLM JSON invalide (%s) — fallback safe", e)
return []
additional = parsed.get("additional_checks") or []
if not isinstance(additional, list):
return []
# Filter and truncate to max_checks
valid = []
for item in additional[:max_checks]:
if isinstance(item, dict) and item.get("label"):
valid.append({
"label": str(item["label"])[:200],
"evidence": str(item.get("evidence", ""))[:300],
})
return valid

@@ -0,0 +1,111 @@
# tests/unit/test_safety_checks_provider.py
"""Tests unitaires SafetyChecksProvider (QW4)."""
from unittest.mock import patch
from agent_v0.server_v1.safety_checks_provider import build_pause_payload, PausePayload
def _action(safety_level=None, declarative_checks=None, message="Validation"):
params = {"message": message}
if safety_level:
params["safety_level"] = safety_level
if declarative_checks is not None:
params["safety_checks"] = declarative_checks
return {"type": "pause_for_human", "parameters": params}
def test_only_declarative_when_no_safety_level():
"""Pas de safety_level → uniquement les checks déclaratifs, pas d'appel LLM."""
decl = [{"id": "c1", "label": "Vérifier IPP", "required": True}]
with patch("agent_v0.server_v1.safety_checks_provider._call_llm_for_contextual_checks") as mock_llm:
payload = build_pause_payload(_action(declarative_checks=decl), {}, last_screenshot=None)
mock_llm.assert_not_called()
assert len(payload.checks) == 1
assert payload.checks[0]["source"] == "declarative"
def test_hybrid_appends_llm_checks_on_medical_critical():
"""safety_level=medical_critical → LLM called, checks concatenated."""
decl = [{"id": "c1", "label": "Vérifier IPP", "required": True}]
llm_resp = [{"label": "Nom patient suspect à l'écran", "evidence": "vu un nom différent"}]
with patch("agent_v0.server_v1.safety_checks_provider._call_llm_for_contextual_checks",
return_value=llm_resp) as mock_llm:
payload = build_pause_payload(
_action(safety_level="medical_critical", declarative_checks=decl),
{}, last_screenshot="/tmp/fake.png",
)
mock_llm.assert_called_once()
assert len(payload.checks) == 2
assert payload.checks[0]["source"] == "declarative"
assert payload.checks[1]["source"] == "llm_contextual"
assert payload.checks[1]["evidence"] == "vu un nom différent"
def test_llm_timeout_falls_back_to_declarative_only():
"""LLM timeout → additional_checks=[], pas de crash, déclaratifs gardés."""
decl = [{"id": "c1", "label": "Vérifier IPP", "required": True}]
with patch("agent_v0.server_v1.safety_checks_provider._call_llm_for_contextual_checks",
return_value=[]):
payload = build_pause_payload(
_action(safety_level="medical_critical", declarative_checks=decl),
{}, last_screenshot="/tmp/fake.png",
)
assert len(payload.checks) == 1
assert payload.checks[0]["source"] == "declarative"
def test_llm_invalid_response_falls_back():
"""Si _call_llm retourne [] (parse échoué en interne) → fallback safe."""
with patch("agent_v0.server_v1.safety_checks_provider._call_llm_for_contextual_checks",
return_value=[]):
payload = build_pause_payload(
_action(safety_level="medical_critical", declarative_checks=[]),
{}, last_screenshot="/tmp/fake.png",
)
assert payload.checks == []
def test_kill_switch_disables_llm_call(monkeypatch):
"""RPA_SAFETY_CHECKS_LLM_ENABLED=0 → LLM jamais appelé."""
monkeypatch.setenv("RPA_SAFETY_CHECKS_LLM_ENABLED", "0")
decl = [{"id": "c1", "label": "X", "required": True}]
with patch("agent_v0.server_v1.safety_checks_provider._call_llm_for_contextual_checks") as mock_llm:
payload = build_pause_payload(
_action(safety_level="medical_critical", declarative_checks=decl),
{}, last_screenshot="/tmp/fake.png",
)
mock_llm.assert_not_called()
assert len(payload.checks) == 1
def test_max_checks_respected(monkeypatch):
"""RPA_SAFETY_CHECKS_LLM_MAX_CHECKS=2 → max 2 checks LLM ajoutés."""
monkeypatch.setenv("RPA_SAFETY_CHECKS_LLM_MAX_CHECKS", "2")
decl = []
llm_resp = [
{"label": f"Check {i}", "evidence": f"e{i}"} for i in range(5)
]
with patch("agent_v0.server_v1.safety_checks_provider._call_llm_for_contextual_checks",
return_value=llm_resp[:2]): # the provider already truncates
payload = build_pause_payload(
_action(safety_level="medical_critical", declarative_checks=decl),
{}, last_screenshot="/tmp/fake.png",
)
assert len(payload.checks) == 2
def test_empty_declarative_with_llm_returns_only_llm():
"""Pas de déclaratif + LLM ajoute 2 checks → payload contient les 2."""
llm_resp = [{"label": "Vérifier date", "evidence": "date 1900 suspecte"},
{"label": "Vérifier devise", "evidence": "montant en USD au lieu d'EUR"}]
with patch("agent_v0.server_v1.safety_checks_provider._call_llm_for_contextual_checks",
return_value=llm_resp):
payload = build_pause_payload(
_action(safety_level="medical_critical", declarative_checks=[]),
{}, last_screenshot="/tmp/fake.png",
)
assert len(payload.checks) == 2
assert all(c["source"] == "llm_contextual" for c in payload.checks)