Files
rpa_vision_v3/tools/bench_safety_checks_models.py
Dom 0a02a6ec9c
Some checks failed
tests / Lint (ruff + black) (push) Successful in 15s
tests / Tests unitaires (sans GPU) (push) Failing after 14s
tests / Tests sécurité (critique) (push) Has been skipped
feat(qw4): bench rigoureux LLM safety_checks → gemma4:latest par défaut
Bench 5 modèles × 5 scénarios × cold+warm sur RTX 5070 :
- gemma4:latest : warm 2.9s, JSON 92%, détection 46% → gagnant
- qwen2.5vl:7b : warm 6.6s, détection 23% (trop lent)
- qwen2.5vl:3b : warm 2.0s, détection 8% (vérifie pour vérifier)
- medgemma:4b : warm 0.5s, détection 0% (refuse de signaler) → mauvais
  défaut initial, corrigé
- qwen3-vl:8b : 0% JSON valide (ignore format=json Ollama) → écarté

Modifications safety_checks_provider.py :
- RPA_SAFETY_CHECKS_LLM_MODEL défaut: medgemma:4b → gemma4:latest
- RPA_SAFETY_CHECKS_LLM_TIMEOUT_S défaut: 5 → 7 (warm 2.9s + marge)

Doc complète : docs/BENCH_SAFETY_CHECKS_2026-05-06.md
Script : tools/bench_safety_checks_models.py (reproductible, ~10-15 min)

Limite assumée : 46% de détection. À présenter en démo comme aide médecin,
pas certification. Amélioration V2 = prompt plus dirigé sur champs à vérifier.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 09:23:09 +02:00

438 lines
17 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""Bench rigoureux des modèles candidats pour QW4 safety_checks contextuels.
Méthodologie :
- 5 screenshots synthétiques avec différentes anomalies cliniques
- 5 modèles candidats (gemma4:latest, qwen3-vl:8b, qwen2.5vl:7b/3b et medgemma:4b sur :11434)
- Pour chaque modèle :
1. Décharger TOUS les modèles déjà en VRAM (keep_alive=0)
2. 1er appel = cold start chronométré (1er screenshot)
3. 12 appels warm = (4 autres screenshots × 3 runs)
4. Mesurer : cold_start, warm avg/p95, taux détection, JSON valide
Usage : .venv/bin/python tools/bench_safety_checks_models.py
"""
from __future__ import annotations

import base64
import json
import math
import os
import statistics
import time
from dataclasses import dataclass, field
from typing import Any

import requests
from PIL import Image, ImageDraw, ImageFont
OLLAMA_PRIMARY = os.environ.get("OLLAMA_URL", "http://localhost:11434")
OLLAMA_SECONDARY = os.environ.get("GEMMA4_URL", "http://localhost:11435")
# Configuration des candidats : (nom, url, type)
CANDIDATES = [
("gemma4:latest", OLLAMA_PRIMARY, "vlm_default"),
("qwen3-vl:8b", OLLAMA_PRIMARY, "vision_qwen3_8b"),
("qwen2.5vl:7b", OLLAMA_PRIMARY, "vision_qwen25_7b"),
("qwen2.5vl:3b", OLLAMA_PRIMARY, "vision_qwen25_3b"),
("medgemma:4b", OLLAMA_PRIMARY, "medical_4b"),
]
TIMEOUT_S = int(os.environ.get("BENCH_TIMEOUT", "60")) # large pour ne rien rater
MAX_CHECKS = 3
WORKFLOW_MESSAGE = "Validation T2A avant codage UHCD"
EXISTING_LABELS: list[str] = []
WARM_RUNS_PER_SCREENSHOT = 3 # warm = 4 autres screenshots × 3 runs = 12 mesures
# ---------------------------------------------------------------------------
# Scénarios : 5 screenshots avec anomalies différentes
# ---------------------------------------------------------------------------
@dataclass
class Scenario:
label: str # nom court
rows: list[tuple[str, str]]
anomaly_keywords: list[str] # mots indiquant que l'anomalie est repérée
SCENARIOS = [
Scenario(
label="ddn_aberrante",
rows=[
("Nom :", "DUPONT Marie"),
("IPP :", "25003284"),
("Date de naissance :", "1900-01-01"), # ANOMALIE
("Sexe :", "F"),
("Date d'admission :", "2026-05-05 14:32"),
("Service :", "URGENCES"),
("Motif :", "Douleur abdominale aiguë"),
("Diagnostic principal :", "K35.8 - Appendicite aiguë"),
("Forfait facturation :", "UHCD - Forfait 24h"),
],
anomaly_keywords=["1900", "naissance", "ddn", "date"],
),
Scenario(
label="ipp_incoherent",
rows=[
("Nom :", "MARTIN Paul"),
("IPP :", "ABC@@##XYZ"), # ANOMALIE : non numérique
("Date de naissance :", "1965-04-12"),
("Sexe :", "M"),
("Date d'admission :", "2026-05-06 09:15"),
("Service :", "URGENCES"),
("Motif :", "Chute mécanique"),
("Diagnostic principal :", "S52.5 - Fracture du radius distal"),
("Forfait facturation :", "UHCD - Forfait 24h"),
],
anomaly_keywords=["ipp", "abc", "format", "incohérent", "incoherent", "invalide"],
),
Scenario(
label="diagnostic_vide",
rows=[
("Nom :", "BERNARD Sophie"),
("IPP :", "25004191"),
("Date de naissance :", "1972-11-08"),
("Sexe :", "F"),
("Date d'admission :", "2026-05-06 10:42"),
("Service :", "URGENCES"),
("Motif :", "Céphalées"),
("Diagnostic principal :", ""), # ANOMALIE : vide
("Forfait facturation :", "UHCD - Forfait 24h"),
],
anomaly_keywords=["diagnostic", "vide", "blanc", "absent", "manque", "non renseigné", "non renseigne"],
),
Scenario(
label="cim_inadapte_age",
rows=[
("Nom :", "PETIT Lucas"),
("IPP :", "25004222"),
("Date de naissance :", "2025-11-01"), # nourrisson 6 mois
("Sexe :", "M"),
("Date d'admission :", "2026-05-06 11:00"),
("Service :", "URGENCES PEDIATRIQUES"),
("Motif :", "Pleurs persistants"),
("Diagnostic principal :", "M19.9 - Arthrose, sans précision"), # ANOMALIE
("Forfait facturation :", "UHCD - Forfait 24h"),
],
anomaly_keywords=["arthrose", "âge", "age", "nourrisson", "incohérent", "incoherent", "m19", "incompatible"],
),
Scenario(
label="forfait_incoherent_duree",
rows=[
("Nom :", "ROUSSEAU Jean"),
("IPP :", "25004317"),
("Date de naissance :", "1958-03-22"),
("Sexe :", "M"),
("Date d'admission :", "2026-05-06 08:00"),
("Date de sortie :", "2026-05-06 09:00"), # 1h
("Service :", "URGENCES"),
("Motif :", "Bilan biologique"),
("Diagnostic principal :", "Z00.0 - Examen médical général"),
("Forfait facturation :", "UHCD - Forfait 24h"), # ANOMALIE : 1h ≠ UHCD 24h
],
anomaly_keywords=["forfait", "uhcd", "durée", "duree", "1h", "incohérent", "incoherent", "24h"],
),
]
# ---------------------------------------------------------------------------
# Génération des screenshots
# ---------------------------------------------------------------------------
def make_screenshot(scenario: Scenario, path: str) -> None:
"""Crée un PNG du dossier patient pour un scénario donné."""
img = Image.new("RGB", (1024, 600), color="white")
draw = ImageDraw.Draw(img)
try:
font_title = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 22)
font_body = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 18)
except OSError:
font_title = ImageFont.load_default()
font_body = ImageFont.load_default()
draw.text((20, 20), "DOSSIER PATIENT - URGENCES UHCD", fill="black", font=font_title)
draw.line([(20, 55), (1004, 55)], fill="black", width=2)
y = 80
for label, value in scenario.rows:
draw.text((30, y), label, fill="black", font=font_body)
draw.text((280, y), value, fill="#1f2937", font=font_body)
y += 35
img.save(path, format="PNG")
def encode_image(path: str) -> str:
with open(path, "rb") as f:
return base64.b64encode(f.read()).decode("ascii")
def build_prompt() -> str:
existing = ", ".join(EXISTING_LABELS) if EXISTING_LABELS else "aucun"
return f"""Tu es Léa, assistante médicale supervisée.
Avant de continuer le workflow, tu dois lister 0 à {MAX_CHECKS} vérifications supplémentaires
que l'humain doit acquitter, en regardant l'écran actuel.
Contexte workflow : {WORKFLOW_MESSAGE}
Checks déjà demandés : {existing}
NE répète PAS un check déjà demandé.
Si rien d'inhabituel à signaler, retourne {{"additional_checks": []}}.
Réponds UNIQUEMENT en JSON :
{{
"additional_checks": [
{{"label": "string court", "evidence": "ce que tu as vu d'inhabituel"}}
]
}}
"""
# ---------------------------------------------------------------------------
# Gestion VRAM Ollama (déchargement)
# ---------------------------------------------------------------------------
def list_loaded_models(url: str) -> list[str]:
"""Retourne la liste des modèles actuellement en VRAM sur cet Ollama."""
try:
resp = requests.get(f"{url}/api/ps", timeout=5)
if resp.status_code == 200:
data = resp.json()
return [m["name"] for m in data.get("models", [])]
except Exception:
pass
return []
def unload_all_models() -> None:
"""Décharge tous les modèles en VRAM sur les 2 Ollama (keep_alive=0)."""
for url in (OLLAMA_PRIMARY, OLLAMA_SECONDARY):
loaded = list_loaded_models(url)
for model_name in loaded:
try:
requests.post(
f"{url}/api/generate",
json={"model": model_name, "prompt": "", "keep_alive": 0, "stream": False},
timeout=10,
)
except Exception:
pass
# Petit temps pour laisser le GC GPU faire son travail
time.sleep(2)
# ---------------------------------------------------------------------------
# Appel modèle + parsing
# ---------------------------------------------------------------------------
@dataclass
class CallResult:
elapsed_s: float
error: str = ""
raw: str = ""
checks: list[dict] = field(default_factory=list)
def call_model(model: str, url: str, prompt: str, image_b64: str) -> CallResult:
payload = {
"model": model,
"prompt": prompt,
"stream": False,
"format": "json",
"options": {"temperature": 0.1, "num_predict": 250},
"images": [image_b64],
}
t0 = time.perf_counter()
try:
resp = requests.post(f"{url}/api/generate", json=payload, timeout=TIMEOUT_S)
elapsed = time.perf_counter() - t0
except requests.Timeout:
return CallResult(elapsed_s=TIMEOUT_S, error="TIMEOUT")
except Exception as e:
return CallResult(elapsed_s=time.perf_counter() - t0, error=f"NETWORK:{type(e).__name__}")
if resp.status_code != 200:
return CallResult(elapsed_s=elapsed, error=f"HTTP_{resp.status_code}", raw=resp.text[:200])
raw = resp.json().get("response", "").strip()
try:
parsed = json.loads(raw)
checks = parsed.get("additional_checks") or []
if not isinstance(checks, list):
checks = []
return CallResult(elapsed_s=elapsed, raw=raw[:300], checks=checks)
except json.JSONDecodeError as e:
return CallResult(elapsed_s=elapsed, error=f"JSON:{type(e).__name__}", raw=raw[:200])
def detects_anomaly(scenario: Scenario, checks: list[dict]) -> bool:
blob = " ".join(
f"{c.get('label', '')} {c.get('evidence', '')}".lower()
for c in checks
)
return any(pat.lower() in blob for pat in scenario.anomaly_keywords)
# ---------------------------------------------------------------------------
# Bench main
# ---------------------------------------------------------------------------
@dataclass
class ModelStats:
model: str
cold_s: float = 0.0
warm_times: list[float] = field(default_factory=list)
detection_count: int = 0
detection_total: int = 0
json_valid_count: int = 0
json_valid_total: int = 0
errors: list[str] = field(default_factory=list)
sample_checks: list[tuple[str, list[dict]]] = field(default_factory=list) # (scenario_label, checks)
def run_bench_for_model(model: str, url: str, screenshots: list[tuple[Scenario, str]]) -> ModelStats:
print(f"\n══════════════════════════════════════════════════════════")
print(f" MODEL: {model} ({url})")
print(f"══════════════════════════════════════════════════════════")
# Décharger tout
print(f" [1/3] Déchargement VRAM...", end=" ", flush=True)
unload_all_models()
loaded_after = list_loaded_models(OLLAMA_PRIMARY) + list_loaded_models(OLLAMA_SECONDARY)
print(f"OK (loaded={loaded_after if loaded_after else 'aucun'})")
stats = ModelStats(model=model)
prompt = build_prompt()
# Cold start sur le 1er screenshot
scen0, path0 = screenshots[0]
img_b64 = encode_image(path0)
print(f" [2/3] Cold start ({scen0.label})...", end=" ", flush=True)
r0 = call_model(model, url, prompt, img_b64)
stats.cold_s = r0.elapsed_s
if r0.error:
print(f"{r0.error} ({r0.elapsed_s:.1f}s)")
stats.errors.append(f"cold:{scen0.label}:{r0.error}")
else:
det = detects_anomaly(scen0, r0.checks)
stats.detection_count += int(det)
stats.detection_total += 1
stats.json_valid_count += 1
stats.json_valid_total += 1
stats.sample_checks.append((scen0.label, r0.checks))
print(f"{'' if det else '⚠️'} {len(r0.checks)} check(s) en {r0.elapsed_s:.1f}s (det={det})")
# Warm runs sur les 4 autres screenshots × N runs
print(f" [3/3] Warm runs ({len(screenshots)-1} scenarios × {WARM_RUNS_PER_SCREENSHOT} runs)...")
for scen, path in screenshots[1:]:
img_b64 = encode_image(path)
for run_idx in range(WARM_RUNS_PER_SCREENSHOT):
r = call_model(model, url, prompt, img_b64)
if r.error:
stats.errors.append(f"{scen.label}:run{run_idx}:{r.error}")
stats.json_valid_total += 1
stats.detection_total += 1
print(f" {scen.label} run{run_idx}: ❌ {r.error}")
continue
stats.warm_times.append(r.elapsed_s)
stats.json_valid_count += 1
stats.json_valid_total += 1
det = detects_anomaly(scen, r.checks)
stats.detection_count += int(det)
stats.detection_total += 1
if run_idx == 0:
stats.sample_checks.append((scen.label, r.checks))
print(f" {scen.label} run{run_idx}: {'' if det else '⚠️'} {len(r.checks)} check(s) en {r.elapsed_s:.1f}s")
return stats
def print_summary_table(all_stats: list[ModelStats]) -> None:
print("\n\n══════════════════════════════════════════════════════════")
print(" SYNTHÈSE")
print("══════════════════════════════════════════════════════════\n")
print("| Modèle | Cold (s) | Warm avg (s) | Warm p95 (s) | JSON | Détection | Notes |")
print("|---|---:|---:|---:|---:|---:|---|")
for s in all_stats:
if s.warm_times:
warm_avg = statistics.mean(s.warm_times)
warm_p95 = sorted(s.warm_times)[int(len(s.warm_times) * 0.95) - 1] if len(s.warm_times) > 1 else s.warm_times[0]
else:
warm_avg = warm_p95 = 0.0
json_pct = (s.json_valid_count / s.json_valid_total * 100) if s.json_valid_total else 0
det_pct = (s.detection_count / s.detection_total * 100) if s.detection_total else 0
notes = f"{len(s.errors)} err" if s.errors else "OK"
print(f"| `{s.model}` | {s.cold_s:.1f} | {warm_avg:.1f} | {warm_p95:.1f} | "
f"{json_pct:.0f}% ({s.json_valid_count}/{s.json_valid_total}) | "
f"{det_pct:.0f}% ({s.detection_count}/{s.detection_total}) | {notes} |")
print("\n## Détail des checks par scénario\n")
for s in all_stats:
print(f"\n### `{s.model}`")
if s.errors:
print(f"_Erreurs ({len(s.errors)})_ : {s.errors[:5]}{'...' if len(s.errors) > 5 else ''}")
for label, checks in s.sample_checks:
if not checks:
print(f"- **{label}** : _aucun check_")
else:
for c in checks[:2]:
print(f"- **{label}** : {c.get('label', '?')} — _{c.get('evidence', '?')[:120]}_")
def pick_winner(all_stats: list[ModelStats]) -> ModelStats | None:
"""Le gagnant : meilleur taux détection, départage par warm avg."""
valid = [s for s in all_stats if s.warm_times]
if not valid:
return None
# Tri : détection desc puis warm avg asc
valid.sort(key=lambda s: (-(s.detection_count / max(s.detection_total, 1)), statistics.mean(s.warm_times)))
return valid[0]
def main() -> int:
# Génération des 5 screenshots
print("📸 Génération des 5 screenshots synthétiques :")
screenshots: list[tuple[Scenario, str]] = []
for scen in SCENARIOS:
path = f"/tmp/bench_safety_{scen.label}.png"
make_screenshot(scen, path)
print(f" - {scen.label}{path}")
screenshots.append((scen, path))
print(f"\n⏱ Timeout par appel : {TIMEOUT_S}s")
print(f"🔄 Warm runs par scénario : {WARM_RUNS_PER_SCREENSHOT}")
print(f"📊 Total mesures par modèle : 1 cold + {(len(SCENARIOS)-1) * WARM_RUNS_PER_SCREENSHOT} warm = "
f"{1 + (len(SCENARIOS)-1) * WARM_RUNS_PER_SCREENSHOT}")
print(f"🤖 Candidats : {[c[0] for c in CANDIDATES]}")
all_stats: list[ModelStats] = []
for model, url, _ in CANDIDATES:
try:
stats = run_bench_for_model(model, url, screenshots)
all_stats.append(stats)
except KeyboardInterrupt:
print(f"\n⚠️ Interrompu pendant {model}, on saute le reste")
break
except Exception as e:
print(f"\n❌ Crash bench {model}: {e}")
all_stats.append(ModelStats(model=model, errors=[f"crash:{e}"]))
print_summary_table(all_stats)
winner = pick_winner(all_stats)
print("\n## Recommandation\n")
if winner is None:
print("⚠️ Aucun modèle exploitable. Décision manuelle nécessaire.")
return 1
det_pct = winner.detection_count / max(winner.detection_total, 1) * 100
warm_avg = statistics.mean(winner.warm_times)
print(f"🏆 **{winner.model}** : détection {det_pct:.0f}%, warm avg {warm_avg:.1f}s, cold {winner.cold_s:.1f}s")
print(f"\nPour fixer en production :")
print(f"```bash\nsudo systemctl edit rpa-streaming")
print(f"# [Service]\n# Environment=RPA_SAFETY_CHECKS_LLM_MODEL={winner.model}")
print(f"sudo systemctl restart rpa-streaming\n```")
return 0
if __name__ == "__main__":
raise SystemExit(main())