Bench 5 modèles × 5 scénarios × cold+warm sur RTX 5070 : - gemma4:latest : warm 2.9s, JSON 92%, détection 46% → gagnant - qwen2.5vl:7b : warm 6.6s, détection 23% (trop lent) - qwen2.5vl:3b : warm 2.0s, détection 8% (vérifie pour vérifier) - medgemma:4b : warm 0.5s, détection 0% (refuse de signaler) → mauvais défaut initial, corrigé - qwen3-vl:8b : 0% JSON valide (ignore format=json Ollama) → écarté Modifications safety_checks_provider.py : - RPA_SAFETY_CHECKS_LLM_MODEL défaut: medgemma:4b → gemma4:latest - RPA_SAFETY_CHECKS_LLM_TIMEOUT_S défaut: 5 → 7 (warm 2.9s + marge) Doc complète : docs/BENCH_SAFETY_CHECKS_2026-05-06.md Script : tools/bench_safety_checks_models.py (reproductible, ~10-15 min) Limite assumée : 46% de détection. À présenter en démo comme aide médecin, pas certification. Amélioration V2 = prompt plus dirigé sur champs à vérifier. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
438 lines
17 KiB
Python
Executable File
438 lines
17 KiB
Python
Executable File
#!/usr/bin/env python3
"""Rigorous bench of candidate models for contextual QW4 safety checks.

Methodology:
- 5 synthetic screenshots, each containing a different clinical anomaly
- 5 candidate models (gemma4:latest, qwen3-vl:8b, qwen2.5vl:7b, qwen2.5vl:3b,
  medgemma:4b), all served by the primary Ollama instance
- For each model:
  1. Unload ALL models currently resident in VRAM (keep_alive=0)
  2. 1st call = timed cold start (1st screenshot)
  3. 12 warm calls = (4 remaining screenshots × 3 runs)
  4. Measure: cold start, warm avg/p95, detection rate, valid-JSON rate

Usage: .venv/bin/python tools/bench_safety_checks_models.py
"""
|
||
|
||
from __future__ import annotations

import base64
import json
import math
import os
import statistics
import time
from dataclasses import dataclass, field
from typing import Any

import requests
from PIL import Image, ImageDraw, ImageFont
|
||
|
||
|
||
# Ollama endpoints. All current candidates run on the primary instance; the
# secondary instance is still queried when unloading VRAM, in case it holds models.
OLLAMA_PRIMARY = os.environ.get("OLLAMA_URL", "http://localhost:11434")
OLLAMA_SECONDARY = os.environ.get("GEMMA4_URL", "http://localhost:11435")

# Candidate configuration: (model name, Ollama base URL, profile tag).
CANDIDATES = [
    ("gemma4:latest", OLLAMA_PRIMARY, "vlm_default"),
    ("qwen3-vl:8b", OLLAMA_PRIMARY, "vision_qwen3_8b"),
    ("qwen2.5vl:7b", OLLAMA_PRIMARY, "vision_qwen25_7b"),
    ("qwen2.5vl:3b", OLLAMA_PRIMARY, "vision_qwen25_3b"),
    ("medgemma:4b", OLLAMA_PRIMARY, "medical_4b"),
]

TIMEOUT_S = int(os.environ.get("BENCH_TIMEOUT", "60"))  # generous, so no call is lost
MAX_CHECKS = 3  # upper bound on additional checks requested per screenshot
WORKFLOW_MESSAGE = "Validation T2A avant codage UHCD"  # workflow context injected into the prompt
EXISTING_LABELS: list[str] = []  # checks already requested (empty for the bench)
WARM_RUNS_PER_SCREENSHOT = 3  # warm = 4 remaining screenshots × 3 runs = 12 samples
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Scénarios : 5 screenshots avec anomalies différentes
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@dataclass
class Scenario:
    """One synthetic patient-record screenshot with a known clinical anomaly."""

    label: str  # short scenario name (used in file names and logs)
    rows: list[tuple[str, str]]  # (field label, field value) pairs rendered on the image
    anomaly_keywords: list[str]  # words indicating the anomaly was spotted
|
||
|
||
|
||
# The five bench scenarios. Each embeds exactly one planted anomaly; the
# keyword lists below decide whether a model's answer "detected" it.
SCENARIOS = [
    Scenario(
        label="ddn_aberrante",
        rows=[
            ("Nom :", "DUPONT Marie"),
            ("IPP :", "25003284"),
            ("Date de naissance :", "1900-01-01"),  # ANOMALY: implausible birth date
            ("Sexe :", "F"),
            ("Date d'admission :", "2026-05-05 14:32"),
            ("Service :", "URGENCES"),
            ("Motif :", "Douleur abdominale aiguë"),
            ("Diagnostic principal :", "K35.8 - Appendicite aiguë"),
            ("Forfait facturation :", "UHCD - Forfait 24h"),
        ],
        anomaly_keywords=["1900", "naissance", "ddn", "date"],
    ),
    Scenario(
        label="ipp_incoherent",
        rows=[
            ("Nom :", "MARTIN Paul"),
            ("IPP :", "ABC@@##XYZ"),  # ANOMALY: non-numeric patient identifier
            ("Date de naissance :", "1965-04-12"),
            ("Sexe :", "M"),
            ("Date d'admission :", "2026-05-06 09:15"),
            ("Service :", "URGENCES"),
            ("Motif :", "Chute mécanique"),
            ("Diagnostic principal :", "S52.5 - Fracture du radius distal"),
            ("Forfait facturation :", "UHCD - Forfait 24h"),
        ],
        anomaly_keywords=["ipp", "abc", "format", "incohérent", "incoherent", "invalide"],
    ),
    Scenario(
        label="diagnostic_vide",
        rows=[
            ("Nom :", "BERNARD Sophie"),
            ("IPP :", "25004191"),
            ("Date de naissance :", "1972-11-08"),
            ("Sexe :", "F"),
            ("Date d'admission :", "2026-05-06 10:42"),
            ("Service :", "URGENCES"),
            ("Motif :", "Céphalées"),
            ("Diagnostic principal :", ""),  # ANOMALY: empty main diagnosis
            ("Forfait facturation :", "UHCD - Forfait 24h"),
        ],
        anomaly_keywords=["diagnostic", "vide", "blanc", "absent", "manque", "non renseigné", "non renseigne"],
    ),
    Scenario(
        label="cim_inadapte_age",
        rows=[
            ("Nom :", "PETIT Lucas"),
            ("IPP :", "25004222"),
            ("Date de naissance :", "2025-11-01"),  # 6-month-old infant
            ("Sexe :", "M"),
            ("Date d'admission :", "2026-05-06 11:00"),
            ("Service :", "URGENCES PEDIATRIQUES"),
            ("Motif :", "Pleurs persistants"),
            ("Diagnostic principal :", "M19.9 - Arthrose, sans précision"),  # ANOMALY: osteoarthritis vs infant age
            ("Forfait facturation :", "UHCD - Forfait 24h"),
        ],
        anomaly_keywords=["arthrose", "âge", "age", "nourrisson", "incohérent", "incoherent", "m19", "incompatible"],
    ),
    Scenario(
        label="forfait_incoherent_duree",
        rows=[
            ("Nom :", "ROUSSEAU Jean"),
            ("IPP :", "25004317"),
            ("Date de naissance :", "1958-03-22"),
            ("Sexe :", "M"),
            ("Date d'admission :", "2026-05-06 08:00"),
            ("Date de sortie :", "2026-05-06 09:00"),  # 1-hour stay
            ("Service :", "URGENCES"),
            ("Motif :", "Bilan biologique"),
            ("Diagnostic principal :", "Z00.0 - Examen médical général"),
            ("Forfait facturation :", "UHCD - Forfait 24h"),  # ANOMALY: 1-hour stay vs 24h UHCD package
        ],
        anomaly_keywords=["forfait", "uhcd", "durée", "duree", "1h", "incohérent", "incoherent", "24h"],
    ),
]
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Génération des screenshots
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def make_screenshot(scenario: Scenario, path: str) -> None:
    """Render the patient-record screenshot for one scenario as a PNG file."""
    canvas = Image.new("RGB", (1024, 600), color="white")
    pen = ImageDraw.Draw(canvas)

    # Prefer the DejaVu TTF fonts; fall back to PIL's bitmap default when absent.
    try:
        title_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 22)
        body_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 18)
    except OSError:
        title_font = ImageFont.load_default()
        body_font = ImageFont.load_default()

    pen.text((20, 20), "DOSSIER PATIENT - URGENCES UHCD", fill="black", font=title_font)
    pen.line([(20, 55), (1004, 55)], fill="black", width=2)

    # One 35px-high row per (label, value) pair, starting at y=80.
    for row_index, (row_label, row_value) in enumerate(scenario.rows):
        y = 80 + 35 * row_index
        pen.text((30, y), row_label, fill="black", font=body_font)
        pen.text((280, y), row_value, fill="#1f2937", font=body_font)

    canvas.save(path, format="PNG")
|
||
|
||
|
||
def encode_image(path: str) -> str:
    """Return the file at *path* base64-encoded as an ASCII string."""
    with open(path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode("ascii")
|
||
|
||
|
||
def build_prompt() -> str:
    """Assemble the French instruction prompt sent to the model with each screenshot."""
    # Already-requested checks are rendered comma-separated, or "aucun" when empty.
    if EXISTING_LABELS:
        existing = ", ".join(EXISTING_LABELS)
    else:
        existing = "aucun"
    return f"""Tu es Léa, assistante médicale supervisée.
Avant de continuer le workflow, tu dois lister 0 à {MAX_CHECKS} vérifications supplémentaires
que l'humain doit acquitter, en regardant l'écran actuel.

Contexte workflow : {WORKFLOW_MESSAGE}
Checks déjà demandés : {existing}

NE répète PAS un check déjà demandé.
Si rien d'inhabituel à signaler, retourne {{"additional_checks": []}}.

Réponds UNIQUEMENT en JSON :
{{
"additional_checks": [
{{"label": "string court", "evidence": "ce que tu as vu d'inhabituel"}}
]
}}
"""
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Gestion VRAM Ollama (déchargement)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def list_loaded_models(url: str) -> list[str]:
    """Return the names of models currently loaded in VRAM on this Ollama instance.

    Best-effort: any network or parsing failure yields an empty list, so the
    bench never aborts because an Ollama instance is unreachable. The previous
    bare ``except Exception`` is narrowed to the failures that can actually
    occur here (request errors, non-JSON body); a missing ``name`` key is
    skipped explicitly instead of being swallowed by the broad handler.
    """
    try:
        resp = requests.get(f"{url}/api/ps", timeout=5)
        if resp.status_code == 200:
            data = resp.json()  # raises ValueError on a non-JSON body
            return [m["name"] for m in data.get("models", []) if "name" in m]
    except (requests.RequestException, ValueError):
        pass  # best-effort: unreachable or malformed endpoint → "nothing loaded"
    return []
|
||
|
||
|
||
def unload_all_models() -> None:
    """Unload every model resident in VRAM on both Ollama instances (keep_alive=0)."""
    for url in (OLLAMA_PRIMARY, OLLAMA_SECONDARY):
        for model_name in list_loaded_models(url):
            payload = {"model": model_name, "prompt": "", "keep_alive": 0, "stream": False}
            try:
                requests.post(f"{url}/api/generate", json=payload, timeout=10)
            except Exception:
                pass  # best-effort: a failed unload must not abort the bench
    # Give the GPU memory manager a moment to actually release the VRAM.
    time.sleep(2)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Appel modèle + parsing
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@dataclass
class CallResult:
    """Outcome of a single model call: latency, error tag, raw text, parsed checks."""

    elapsed_s: float  # wall-clock duration of the HTTP call (TIMEOUT_S when it timed out)
    error: str = ""  # empty on success; otherwise TIMEOUT / NETWORK:* / HTTP_* / JSON:*
    raw: str = ""  # truncated raw model response, kept for debugging
    checks: list[dict] = field(default_factory=list)  # parsed "additional_checks" entries
|
||
|
||
|
||
def call_model(model: str, url: str, prompt: str, image_b64: str) -> CallResult:
    """Send one /api/generate request with the screenshot and parse the model's JSON answer.

    Returns a CallResult whose ``error`` field is non-empty on timeout, network
    failure, HTTP error, a non-JSON Ollama envelope, or invalid/ill-typed JSON
    in the model response. Two previously-uncaught crashes are handled:
    ``resp.json()`` raising ValueError on a non-JSON body, and the model
    answering with a JSON array/scalar (``parsed.get`` used to raise
    AttributeError and crash the whole bench run).
    """
    payload = {
        "model": model,
        "prompt": prompt,
        "stream": False,
        "format": "json",  # ask Ollama to constrain the output to valid JSON
        "options": {"temperature": 0.1, "num_predict": 250},
        "images": [image_b64],
    }
    t0 = time.perf_counter()
    try:
        resp = requests.post(f"{url}/api/generate", json=payload, timeout=TIMEOUT_S)
        elapsed = time.perf_counter() - t0
    except requests.Timeout:
        return CallResult(elapsed_s=TIMEOUT_S, error="TIMEOUT")
    except Exception as e:
        return CallResult(elapsed_s=time.perf_counter() - t0, error=f"NETWORK:{type(e).__name__}")

    if resp.status_code != 200:
        return CallResult(elapsed_s=elapsed, error=f"HTTP_{resp.status_code}", raw=resp.text[:200])

    # The Ollama envelope itself may be malformed; don't let that crash the bench.
    try:
        raw = resp.json().get("response", "").strip()
    except ValueError:
        return CallResult(elapsed_s=elapsed, error="BODY_NOT_JSON", raw=resp.text[:200])

    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError as e:
        return CallResult(elapsed_s=elapsed, error=f"JSON:{type(e).__name__}", raw=raw[:200])

    # Models sometimes answer with a bare list or scalar instead of the object.
    if not isinstance(parsed, dict):
        return CallResult(elapsed_s=elapsed, error="JSON:NotADict", raw=raw[:200])
    checks = parsed.get("additional_checks") or []
    if not isinstance(checks, list):
        checks = []
    return CallResult(elapsed_s=elapsed, raw=raw[:300], checks=checks)
|
||
|
||
|
||
def detects_anomaly(scenario: Scenario, checks: list[dict]) -> bool:
    """True if any returned check mentions one of the scenario's anomaly keywords.

    Both the check text and the keywords are compared case-insensitively.
    """
    fragments = [f"{c.get('label', '')} {c.get('evidence', '')}".lower() for c in checks]
    haystack = " ".join(fragments)
    for keyword in scenario.anomaly_keywords:
        if keyword.lower() in haystack:
            return True
    return False
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Bench main
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@dataclass
class ModelStats:
    """Aggregated bench metrics for one model."""

    model: str  # model name, e.g. "gemma4:latest"
    cold_s: float = 0.0  # cold-start latency in seconds (first call after VRAM unload)
    warm_times: list[float] = field(default_factory=list)  # latencies of successful warm calls
    detection_count: int = 0  # calls where the planted anomaly was detected
    detection_total: int = 0  # calls counted toward the detection-rate denominator
    json_valid_count: int = 0  # calls that returned valid JSON
    json_valid_total: int = 0  # calls counted toward the JSON-validity denominator
    errors: list[str] = field(default_factory=list)  # "context:error" tags for failed calls
    sample_checks: list[tuple[str, list[dict]]] = field(default_factory=list)  # (scenario_label, checks)
|
||
|
||
|
||
def run_bench_for_model(model: str, url: str, screenshots: list[tuple[Scenario, str]]) -> ModelStats:
    """Bench one model: unload VRAM, timed cold start, then warm runs.

    Error accounting fix: a failed COLD call now increments
    ``json_valid_total`` and ``detection_total`` exactly like a failed warm
    call does; previously cold failures were excluded from the denominators,
    inflating the reported JSON and detection percentages.
    """
    print(f"\n══════════════════════════════════════════════════════════")
    print(f" MODEL: {model} ({url})")
    print(f"══════════════════════════════════════════════════════════")

    # Unload everything so the first call measures a true cold start.
    print(f" [1/3] Déchargement VRAM...", end=" ", flush=True)
    unload_all_models()
    loaded_after = list_loaded_models(OLLAMA_PRIMARY) + list_loaded_models(OLLAMA_SECONDARY)
    print(f"OK (loaded={loaded_after if loaded_after else 'aucun'})")

    stats = ModelStats(model=model)
    prompt = build_prompt()

    # Cold start on the first screenshot.
    scen0, path0 = screenshots[0]
    img_b64 = encode_image(path0)
    print(f" [2/3] Cold start ({scen0.label})...", end=" ", flush=True)
    r0 = call_model(model, url, prompt, img_b64)
    stats.cold_s = r0.elapsed_s
    if r0.error:
        print(f"❌ {r0.error} ({r0.elapsed_s:.1f}s)")
        stats.errors.append(f"cold:{scen0.label}:{r0.error}")
        # FIX: failed cold calls must count in the denominators like warm ones.
        stats.json_valid_total += 1
        stats.detection_total += 1
    else:
        det = detects_anomaly(scen0, r0.checks)
        stats.detection_count += int(det)
        stats.detection_total += 1
        stats.json_valid_count += 1
        stats.json_valid_total += 1
        stats.sample_checks.append((scen0.label, r0.checks))
        print(f"{'✅' if det else '⚠️'} {len(r0.checks)} check(s) en {r0.elapsed_s:.1f}s (det={det})")

    # Warm runs on the remaining screenshots × N runs each.
    print(f" [3/3] Warm runs ({len(screenshots)-1} scenarios × {WARM_RUNS_PER_SCREENSHOT} runs)...")
    for scen, path in screenshots[1:]:
        img_b64 = encode_image(path)
        for run_idx in range(WARM_RUNS_PER_SCREENSHOT):
            r = call_model(model, url, prompt, img_b64)
            if r.error:
                stats.errors.append(f"{scen.label}:run{run_idx}:{r.error}")
                stats.json_valid_total += 1
                stats.detection_total += 1
                print(f"    {scen.label} run{run_idx}: ❌ {r.error}")
                continue
            stats.warm_times.append(r.elapsed_s)
            stats.json_valid_count += 1
            stats.json_valid_total += 1
            det = detects_anomaly(scen, r.checks)
            stats.detection_count += int(det)
            stats.detection_total += 1
            # Keep one sample answer per scenario for the detailed report.
            if run_idx == 0:
                stats.sample_checks.append((scen.label, r.checks))
            print(f"    {scen.label} run{run_idx}: {'✅' if det else '⚠️'} {len(r.checks)} check(s) en {r.elapsed_s:.1f}s")
    return stats
|
||
|
||
|
||
def print_summary_table(all_stats: list[ModelStats]) -> None:
    """Print the Markdown summary table plus per-scenario check details.

    p95 fix: nearest-rank p95 is the element at index ``ceil(0.95 * n) - 1``
    of the sorted warm times; the previous ``int(n * 0.95) - 1`` under-shot
    the rank (e.g. it reported the ~p92 sample out of 12 warm measurements).
    """
    print("\n\n══════════════════════════════════════════════════════════")
    print(" SYNTHÈSE")
    print("══════════════════════════════════════════════════════════\n")
    print("| Modèle | Cold (s) | Warm avg (s) | Warm p95 (s) | JSON | Détection | Notes |")
    print("|---|---:|---:|---:|---:|---:|---|")
    for s in all_stats:
        if s.warm_times:
            warm_avg = statistics.mean(s.warm_times)
            # Nearest-rank p95, clamped to the last element for tiny samples.
            ordered = sorted(s.warm_times)
            warm_p95 = ordered[min(len(ordered) - 1, math.ceil(0.95 * len(ordered)) - 1)]
        else:
            warm_avg = warm_p95 = 0.0
        json_pct = (s.json_valid_count / s.json_valid_total * 100) if s.json_valid_total else 0
        det_pct = (s.detection_count / s.detection_total * 100) if s.detection_total else 0
        notes = f"{len(s.errors)} err" if s.errors else "OK"
        print(f"| `{s.model}` | {s.cold_s:.1f} | {warm_avg:.1f} | {warm_p95:.1f} | "
              f"{json_pct:.0f}% ({s.json_valid_count}/{s.json_valid_total}) | "
              f"{det_pct:.0f}% ({s.detection_count}/{s.detection_total}) | {notes} |")

    print("\n## Détail des checks par scénario\n")
    for s in all_stats:
        print(f"\n### `{s.model}`")
        if s.errors:
            # Show at most the first five error tags to keep the report short.
            print(f"_Erreurs ({len(s.errors)})_ : {s.errors[:5]}{'...' if len(s.errors) > 5 else ''}")
        for label, checks in s.sample_checks:
            if not checks:
                print(f"- **{label}** : _aucun check_")
            else:
                for c in checks[:2]:
                    print(f"- **{label}** : {c.get('label', '?')} — _{c.get('evidence', '?')[:120]}_")
|
||
|
||
|
||
def pick_winner(all_stats: list[ModelStats]) -> ModelStats | None:
    """Winner = highest detection rate; ties broken by lowest warm average."""
    candidates = [s for s in all_stats if s.warm_times]
    if not candidates:
        return None

    def rank(s: ModelStats) -> tuple[float, float]:
        # Negate the rate so that "higher detection" sorts first.
        detection_rate = s.detection_count / max(s.detection_total, 1)
        return (-detection_rate, statistics.mean(s.warm_times))

    return min(candidates, key=rank)
|
||
|
||
|
||
def main() -> int:
    """Generate the screenshots, bench every candidate, print the summary.

    Returns 0 when a winner could be picked, 1 when no model was usable.
    """
    # Generate the 5 synthetic screenshots under /tmp.
    print("📸 Génération des 5 screenshots synthétiques :")
    screenshots: list[tuple[Scenario, str]] = []
    for scen in SCENARIOS:
        path = f"/tmp/bench_safety_{scen.label}.png"
        make_screenshot(scen, path)
        print(f" - {scen.label} → {path}")
        screenshots.append((scen, path))

    print(f"\n⏱ Timeout par appel : {TIMEOUT_S}s")
    print(f"🔄 Warm runs par scénario : {WARM_RUNS_PER_SCREENSHOT}")
    print(f"📊 Total mesures par modèle : 1 cold + {(len(SCENARIOS)-1) * WARM_RUNS_PER_SCREENSHOT} warm = "
          f"{1 + (len(SCENARIOS)-1) * WARM_RUNS_PER_SCREENSHOT}")
    print(f"🤖 Candidats : {[c[0] for c in CANDIDATES]}")

    all_stats: list[ModelStats] = []
    for model, url, _ in CANDIDATES:
        try:
            stats = run_bench_for_model(model, url, screenshots)
            all_stats.append(stats)
        except KeyboardInterrupt:
            # Ctrl-C mid-bench: keep results gathered so far, skip remaining models.
            print(f"\n⚠️ Interrompu pendant {model}, on saute le reste")
            break
        except Exception as e:
            # A crashed model must not abort the whole bench; record it as errored.
            print(f"\n❌ Crash bench {model}: {e}")
            all_stats.append(ModelStats(model=model, errors=[f"crash:{e}"]))

    print_summary_table(all_stats)

    winner = pick_winner(all_stats)
    print("\n## Recommandation\n")
    if winner is None:
        print("⚠️ Aucun modèle exploitable. Décision manuelle nécessaire.")
        return 1
    det_pct = winner.detection_count / max(winner.detection_total, 1) * 100
    warm_avg = statistics.mean(winner.warm_times)
    print(f"🏆 **{winner.model}** : détection {det_pct:.0f}%, warm avg {warm_avg:.1f}s, cold {winner.cold_s:.1f}s")
    print(f"\nPour fixer en production :")
    print(f"```bash\nsudo systemctl edit rpa-streaming")
    print(f"# [Service]\n# Environment=RPA_SAFETY_CHECKS_LLM_MODEL={winner.model}")
    print(f"sudo systemctl restart rpa-streaming\n```")
    return 0
|
||
|
||
|
||
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit code.
    raise SystemExit(main())
|