feat(qw4): SafetyChecksProvider with hybrid declarative + contextual LLM checks

build_pause_payload(action, state, last_screenshot) → PausePayload

- Always include the declarative checks (workflow.parameters.safety_checks)
- If safety_level=medical_critical AND RPA_SAFETY_CHECKS_LLM_ENABLED=1:
  call the LLM (medgemma:4b by default) with strict format=json, 5 s timeout,
  at most 3 added checks (configurable via env vars)
- Every error path (timeout, HTTP, JSON parse, exception) logs the failure
  and returns [] (safe fallback: declarative checks only)

Tests: 7 cases (declarative only, hybrid OK, timeout, invalid LLM output,
kill-switch, max_checks, empty declarative).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
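For reviewers, these are the environment knobs read by the new provider. The names and defaults below are taken from the code in this diff; how they get set in a real deployment (.env file, service unit, container env) is outside this commit, so treat this as an illustrative sketch:

import os

os.environ["RPA_SAFETY_CHECKS_LLM_ENABLED"] = "1"          # kill-switch: "0", "false", "no", "off" or "" disables the LLM call
os.environ["RPA_SAFETY_CHECKS_LLM_MODEL"] = "medgemma:4b"   # Ollama model used for contextual checks (default)
os.environ["RPA_SAFETY_CHECKS_LLM_TIMEOUT_S"] = "5"         # HTTP timeout for the Ollama call, in seconds (default)
os.environ["RPA_SAFETY_CHECKS_LLM_MAX_CHECKS"] = "3"        # upper bound on LLM-added checks (default)
os.environ["OLLAMA_URL"] = "http://localhost:11434"         # Ollama endpoint (default)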
agent_v0/server_v1/safety_checks_provider.py (new file, 191 lines)
@@ -0,0 +1,191 @@
# agent_v0/server_v1/safety_checks_provider.py
"""SafetyChecksProvider: hybrid declarative + contextual LLM checks (QW4).

For a pause_for_human action:
- the declarative checks (from the workflow) are always included
- if safety_level == "medical_critical" and RPA_SAFETY_CHECKS_LLM_ENABLED=1,
  an LLM call (medgemma:4b by default) adds up to N contextual checks

Any LLM-side failure (timeout, exception, parse error) → additional_checks=[]:
the replay continues with only the declarative checks (safe fallback).
"""

import base64
import io
import json
import logging
import os
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


@dataclass
class PausePayload:
    checks: List[Dict[str, Any]] = field(default_factory=list)
    pause_reason: str = ""
    message: str = ""


def _env(name: str, default: str) -> str:
    return os.environ.get(name, default).strip()


def _env_int(name: str, default: int) -> int:
    try:
        return int(os.environ.get(name, default))
    except (TypeError, ValueError):
        return default


def _env_bool_enabled(name: str) -> bool:
    val = os.environ.get(name, "1").strip().lower()
    return val not in ("0", "false", "no", "off", "")


def build_pause_payload(
    action: Dict[str, Any],
    replay_state: Dict[str, Any],
    last_screenshot: Optional[str],
) -> PausePayload:
    """Build the enriched pause payload for a pause_for_human action."""
    params = action.get("parameters") or {}
    message = params.get("message", "Validation requise")
    safety_level = params.get("safety_level")
    declarative = params.get("safety_checks") or []

    # Normalize the declarative checks
    checks: List[Dict[str, Any]] = []
    for d in declarative:
        checks.append({
            "id": d.get("id") or f"decl_{uuid.uuid4().hex[:6]}",
            "label": d.get("label", "Validation"),
            "required": bool(d.get("required", True)),
            "source": "declarative",
            "evidence": None,
        })

    # Add contextual LLM checks when applicable
    if safety_level == "medical_critical" and _env_bool_enabled("RPA_SAFETY_CHECKS_LLM_ENABLED"):
        try:
            additional = _call_llm_for_contextual_checks(
                action=action,
                replay_state=replay_state,
                last_screenshot=last_screenshot,
                existing_labels=[c["label"] for c in checks],
            )
        except Exception as e:
            logger.warning("safety_checks LLM exception (%s) — fallback safe", e)
            additional = []

        for a in additional:
            checks.append({
                "id": f"llm_{uuid.uuid4().hex[:6]}",
                "label": a.get("label", ""),
                "required": False,  # LLM checks are informational, not mandatory in V1
                "source": "llm_contextual",
                "evidence": a.get("evidence", ""),
            })

    return PausePayload(
        checks=checks,
        pause_reason="",
        message=message,
    )


def _call_llm_for_contextual_checks(
    action: Dict[str, Any],
    replay_state: Dict[str, Any],
    last_screenshot: Optional[str],
    existing_labels: List[str],
) -> List[Dict[str, str]]:
    """Call Ollama in strict JSON mode to generate 0-N contextual checks.

    Returns:
        List[{label, evidence}] (at most RPA_SAFETY_CHECKS_LLM_MAX_CHECKS).
        [] on any failure (timeout, invalid JSON, exception).
    """
    import requests

    model = _env("RPA_SAFETY_CHECKS_LLM_MODEL", "medgemma:4b")
    timeout_s = _env_int("RPA_SAFETY_CHECKS_LLM_TIMEOUT_S", 5)
    max_checks = _env_int("RPA_SAFETY_CHECKS_LLM_MAX_CHECKS", 3)
    ollama_url = _env("OLLAMA_URL", "http://localhost:11434")

    params = action.get("parameters") or {}
    workflow_message = params.get("message", "")
    existing = ", ".join(existing_labels) if existing_labels else "aucun"

    prompt = f"""Tu es Léa, assistante médicale supervisée.
Avant de continuer le workflow, tu dois lister 0 à {max_checks} vérifications supplémentaires
que l'humain doit acquitter, en regardant l'écran actuel.

Contexte workflow : {workflow_message}
Checks déjà demandés : {existing}

NE répète PAS un check déjà demandé.
Si rien d'inhabituel à signaler, retourne {{"additional_checks": []}}.

Réponds UNIQUEMENT en JSON :
{{
  "additional_checks": [
    {{"label": "string court", "evidence": "ce que tu as vu d'inhabituel"}}
  ]
}}
"""

    payload = {
        "model": model,
        "prompt": prompt,
        "stream": False,
        "format": "json",
        "options": {"temperature": 0.1, "num_predict": 200},
    }

    if last_screenshot and os.path.isfile(last_screenshot):
        try:
            with open(last_screenshot, "rb") as f:
                payload["images"] = [base64.b64encode(f.read()).decode("ascii")]
        except Exception as e:
            logger.debug("safety_checks: lecture screenshot échouée (%s) — appel sans image", e)

    try:
        response = requests.post(
            f"{ollama_url}/api/generate",
            json=payload,
            timeout=timeout_s,
        )
        if response.status_code != 200:
            logger.warning("safety_checks LLM HTTP %s", response.status_code)
            return []
        text = response.json().get("response", "").strip()
    except requests.Timeout:
        logger.warning("safety_checks LLM timeout (%ss)", timeout_s)
        return []
    except Exception as e:
        logger.warning("safety_checks LLM erreur réseau: %s", e)
        return []

    # format=json normally guarantees valid JSON
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError as e:
        logger.warning("safety_checks LLM JSON invalide (%s) — fallback safe", e)
        return []

    additional = parsed.get("additional_checks") or []
    if not isinstance(additional, list):
        return []

    # Filter + truncate
    valid = []
    for item in additional[:max_checks]:
        if isinstance(item, dict) and item.get("label"):
            valid.append({
                "label": str(item["label"])[:200],
                "evidence": str(item.get("evidence", ""))[:300],
            })
    return valid
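A minimal call-site sketch, assuming the replay engine invokes the provider when it reaches a pause_for_human action (that caller is not part of this commit, and the values below are illustrative; the action dict shape follows the unit tests that follow):

from agent_v0.server_v1.safety_checks_provider import build_pause_payload

action = {
    "type": "pause_for_human",
    "parameters": {
        "message": "Vérifier l'identité du patient",   # illustrative workflow message
        "safety_level": "medical_critical",            # triggers the contextual LLM call when enabled
        "safety_checks": [                             # declarative checks, always kept
            {"id": "c1", "label": "Vérifier IPP", "required": True},
        ],
    },
}

payload = build_pause_payload(action, replay_state={}, last_screenshot="/tmp/last_screenshot.png")
# payload.checks lists the declarative checks first (source="declarative", required as declared),
# then at most RPA_SAFETY_CHECKS_LLM_MAX_CHECKS LLM checks (source="llm_contextual", required=False).
# On any LLM failure, only the declarative checks remain.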
tests/unit/test_safety_checks_provider.py (new file, 111 lines)
@@ -0,0 +1,111 @@
# tests/unit/test_safety_checks_provider.py
"""Unit tests for SafetyChecksProvider (QW4)."""
import json
import pytest
from unittest.mock import patch, MagicMock

from agent_v0.server_v1.safety_checks_provider import build_pause_payload, PausePayload


def _action(safety_level=None, declarative_checks=None, message="Validation"):
    params = {"message": message}
    if safety_level:
        params["safety_level"] = safety_level
    if declarative_checks is not None:
        params["safety_checks"] = declarative_checks
    return {"type": "pause_for_human", "parameters": params}


def test_only_declarative_when_no_safety_level():
    """No safety_level → only the declarative checks, no LLM call."""
    decl = [{"id": "c1", "label": "Vérifier IPP", "required": True}]
    with patch("agent_v0.server_v1.safety_checks_provider._call_llm_for_contextual_checks") as mock_llm:
        payload = build_pause_payload(_action(declarative_checks=decl), {}, last_screenshot=None)
    mock_llm.assert_not_called()
    assert len(payload.checks) == 1
    assert payload.checks[0]["source"] == "declarative"


def test_hybrid_appends_llm_checks_on_medical_critical(monkeypatch):
    """safety_level=medical_critical → LLM called, checks concatenated."""
    decl = [{"id": "c1", "label": "Vérifier IPP", "required": True}]
    llm_resp = [{"label": "Nom patient suspect à l'écran", "evidence": "vu un nom différent"}]

    with patch("agent_v0.server_v1.safety_checks_provider._call_llm_for_contextual_checks",
               return_value=llm_resp) as mock_llm:
        payload = build_pause_payload(
            _action(safety_level="medical_critical", declarative_checks=decl),
            {}, last_screenshot="/tmp/fake.png",
        )
    mock_llm.assert_called_once()
    assert len(payload.checks) == 2
    assert payload.checks[0]["source"] == "declarative"
    assert payload.checks[1]["source"] == "llm_contextual"
    assert payload.checks[1]["evidence"] == "vu un nom différent"


def test_llm_timeout_falls_back_to_declarative_only():
    """LLM timeout → additional_checks=[], no crash, declarative checks kept."""
    decl = [{"id": "c1", "label": "Vérifier IPP", "required": True}]
    with patch("agent_v0.server_v1.safety_checks_provider._call_llm_for_contextual_checks",
               return_value=[]) as mock_llm:
        payload = build_pause_payload(
            _action(safety_level="medical_critical", declarative_checks=decl),
            {}, last_screenshot="/tmp/fake.png",
        )
    assert len(payload.checks) == 1
    assert payload.checks[0]["source"] == "declarative"


def test_llm_invalid_response_falls_back():
    """If _call_llm returns [] (parsing failed internally) → safe fallback."""
    with patch("agent_v0.server_v1.safety_checks_provider._call_llm_for_contextual_checks",
               return_value=[]):
        payload = build_pause_payload(
            _action(safety_level="medical_critical", declarative_checks=[]),
            {}, last_screenshot="/tmp/fake.png",
        )
    assert payload.checks == []


def test_kill_switch_disables_llm_call(monkeypatch):
    """RPA_SAFETY_CHECKS_LLM_ENABLED=0 → LLM never called."""
    monkeypatch.setenv("RPA_SAFETY_CHECKS_LLM_ENABLED", "0")
    decl = [{"id": "c1", "label": "X", "required": True}]
    with patch("agent_v0.server_v1.safety_checks_provider._call_llm_for_contextual_checks") as mock_llm:
        payload = build_pause_payload(
            _action(safety_level="medical_critical", declarative_checks=decl),
            {}, last_screenshot="/tmp/fake.png",
        )
    mock_llm.assert_not_called()
    assert len(payload.checks) == 1


def test_max_checks_respected(monkeypatch):
    """RPA_SAFETY_CHECKS_LLM_MAX_CHECKS=2 → at most 2 LLM checks added."""
    monkeypatch.setenv("RPA_SAFETY_CHECKS_LLM_MAX_CHECKS", "2")
    decl = []
    llm_resp = [
        {"label": f"Check {i}", "evidence": f"e{i}"} for i in range(5)
    ]
    with patch("agent_v0.server_v1.safety_checks_provider._call_llm_for_contextual_checks",
               return_value=llm_resp[:2]):  # the provider already truncates
        payload = build_pause_payload(
            _action(safety_level="medical_critical", declarative_checks=decl),
            {}, last_screenshot="/tmp/fake.png",
        )
    assert len(payload.checks) == 2


def test_empty_declarative_with_llm_returns_only_llm():
    """No declarative checks + LLM adds 2 checks → payload contains both."""
    llm_resp = [{"label": "Vérifier date", "evidence": "date 1900 suspecte"},
                {"label": "Vérifier devise", "evidence": "montant en USD au lieu d'EUR"}]
    with patch("agent_v0.server_v1.safety_checks_provider._call_llm_for_contextual_checks",
               return_value=llm_resp):
        payload = build_pause_payload(
            _action(safety_level="medical_critical", declarative_checks=[]),
            {}, last_screenshot="/tmp/fake.png",
        )
    assert len(payload.checks) == 2
    assert all(c["source"] == "llm_contextual" for c in payload.checks)
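A possible follow-up test, not part of this commit: the seven tests above stub _call_llm_for_contextual_checks itself, so its HTTP and JSON-parsing path is only exercised indirectly. A sketch of how that path could be covered by patching requests.post; the fake response mirrors the Ollama /api/generate shape the provider already reads (model text under the "response" key):

import json
from unittest.mock import MagicMock, patch

from agent_v0.server_v1.safety_checks_provider import _call_llm_for_contextual_checks


def test_call_llm_parses_ollama_response():
    # 200 response whose "response" field holds the strict-JSON checks payload
    fake_response = MagicMock(status_code=200)
    fake_response.json.return_value = {
        "response": json.dumps({"additional_checks": [
            {"label": "Vérifier la date", "evidence": "date 1900 affichée"},
        ]})
    }
    with patch("requests.post", return_value=fake_response):
        checks = _call_llm_for_contextual_checks(
            action={"parameters": {"message": "Validation"}},
            replay_state={},
            last_screenshot=None,   # no screenshot: the image branch is skipped
            existing_labels=[],
        )
    assert checks == [{"label": "Vérifier la date", "evidence": "date 1900 affichée"}]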