#!/usr/bin/env python3 """Bench rigoureux des modèles candidats pour QW4 safety_checks contextuels. Méthodologie : - 5 screenshots synthétiques avec différentes anomalies cliniques - 4 modèles candidats (gemma4:e4b sur :11435, qwen2.5vl:7b/3b et medgemma:4b sur :11434) - Pour chaque modèle : 1. Décharger TOUS les modèles déjà en VRAM (keep_alive=0) 2. 1er appel = cold start chronométré (1er screenshot) 3. 12 appels warm = (4 autres screenshots × 3 runs) 4. Mesurer : cold_start, warm avg/p95, taux détection, JSON valide Usage : .venv/bin/python tools/bench_safety_checks_models.py """ from __future__ import annotations import base64 import json import os import statistics import time from dataclasses import dataclass, field from typing import Any import requests from PIL import Image, ImageDraw, ImageFont OLLAMA_PRIMARY = os.environ.get("OLLAMA_URL", "http://localhost:11434") OLLAMA_SECONDARY = os.environ.get("GEMMA4_URL", "http://localhost:11435") # Configuration des candidats : (nom, url, type) CANDIDATES = [ ("gemma4:latest", OLLAMA_PRIMARY, "vlm_default"), ("qwen3-vl:8b", OLLAMA_PRIMARY, "vision_qwen3_8b"), ("qwen2.5vl:7b", OLLAMA_PRIMARY, "vision_qwen25_7b"), ("qwen2.5vl:3b", OLLAMA_PRIMARY, "vision_qwen25_3b"), ("medgemma:4b", OLLAMA_PRIMARY, "medical_4b"), ] TIMEOUT_S = int(os.environ.get("BENCH_TIMEOUT", "60")) # large pour ne rien rater MAX_CHECKS = 3 WORKFLOW_MESSAGE = "Validation T2A avant codage UHCD" EXISTING_LABELS: list[str] = [] WARM_RUNS_PER_SCREENSHOT = 3 # warm = 4 autres screenshots × 3 runs = 12 mesures # --------------------------------------------------------------------------- # Scénarios : 5 screenshots avec anomalies différentes # --------------------------------------------------------------------------- @dataclass class Scenario: label: str # nom court rows: list[tuple[str, str]] anomaly_keywords: list[str] # mots indiquant que l'anomalie est repérée SCENARIOS = [ Scenario( label="ddn_aberrante", rows=[ ("Nom :", "DUPONT Marie"), ("IPP :", "25003284"), ("Date de naissance :", "1900-01-01"), # ANOMALIE ("Sexe :", "F"), ("Date d'admission :", "2026-05-05 14:32"), ("Service :", "URGENCES"), ("Motif :", "Douleur abdominale aiguë"), ("Diagnostic principal :", "K35.8 - Appendicite aiguë"), ("Forfait facturation :", "UHCD - Forfait 24h"), ], anomaly_keywords=["1900", "naissance", "ddn", "date"], ), Scenario( label="ipp_incoherent", rows=[ ("Nom :", "MARTIN Paul"), ("IPP :", "ABC@@##XYZ"), # ANOMALIE : non numérique ("Date de naissance :", "1965-04-12"), ("Sexe :", "M"), ("Date d'admission :", "2026-05-06 09:15"), ("Service :", "URGENCES"), ("Motif :", "Chute mécanique"), ("Diagnostic principal :", "S52.5 - Fracture du radius distal"), ("Forfait facturation :", "UHCD - Forfait 24h"), ], anomaly_keywords=["ipp", "abc", "format", "incohérent", "incoherent", "invalide"], ), Scenario( label="diagnostic_vide", rows=[ ("Nom :", "BERNARD Sophie"), ("IPP :", "25004191"), ("Date de naissance :", "1972-11-08"), ("Sexe :", "F"), ("Date d'admission :", "2026-05-06 10:42"), ("Service :", "URGENCES"), ("Motif :", "Céphalées"), ("Diagnostic principal :", ""), # ANOMALIE : vide ("Forfait facturation :", "UHCD - Forfait 24h"), ], anomaly_keywords=["diagnostic", "vide", "blanc", "absent", "manque", "non renseigné", "non renseigne"], ), Scenario( label="cim_inadapte_age", rows=[ ("Nom :", "PETIT Lucas"), ("IPP :", "25004222"), ("Date de naissance :", "2025-11-01"), # nourrisson 6 mois ("Sexe :", "M"), ("Date d'admission :", "2026-05-06 11:00"), ("Service :", "URGENCES PEDIATRIQUES"), ("Motif :", "Pleurs persistants"), ("Diagnostic principal :", "M19.9 - Arthrose, sans précision"), # ANOMALIE ("Forfait facturation :", "UHCD - Forfait 24h"), ], anomaly_keywords=["arthrose", "âge", "age", "nourrisson", "incohérent", "incoherent", "m19", "incompatible"], ), Scenario( label="forfait_incoherent_duree", rows=[ ("Nom :", "ROUSSEAU Jean"), ("IPP :", "25004317"), ("Date de naissance :", "1958-03-22"), ("Sexe :", "M"), ("Date d'admission :", "2026-05-06 08:00"), ("Date de sortie :", "2026-05-06 09:00"), # 1h ("Service :", "URGENCES"), ("Motif :", "Bilan biologique"), ("Diagnostic principal :", "Z00.0 - Examen médical général"), ("Forfait facturation :", "UHCD - Forfait 24h"), # ANOMALIE : 1h ≠ UHCD 24h ], anomaly_keywords=["forfait", "uhcd", "durée", "duree", "1h", "incohérent", "incoherent", "24h"], ), ] # --------------------------------------------------------------------------- # Génération des screenshots # --------------------------------------------------------------------------- def make_screenshot(scenario: Scenario, path: str) -> None: """Crée un PNG du dossier patient pour un scénario donné.""" img = Image.new("RGB", (1024, 600), color="white") draw = ImageDraw.Draw(img) try: font_title = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 22) font_body = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 18) except OSError: font_title = ImageFont.load_default() font_body = ImageFont.load_default() draw.text((20, 20), "DOSSIER PATIENT - URGENCES UHCD", fill="black", font=font_title) draw.line([(20, 55), (1004, 55)], fill="black", width=2) y = 80 for label, value in scenario.rows: draw.text((30, y), label, fill="black", font=font_body) draw.text((280, y), value, fill="#1f2937", font=font_body) y += 35 img.save(path, format="PNG") def encode_image(path: str) -> str: with open(path, "rb") as f: return base64.b64encode(f.read()).decode("ascii") def build_prompt() -> str: existing = ", ".join(EXISTING_LABELS) if EXISTING_LABELS else "aucun" return f"""Tu es Léa, assistante médicale supervisée. Avant de continuer le workflow, tu dois lister 0 à {MAX_CHECKS} vérifications supplémentaires que l'humain doit acquitter, en regardant l'écran actuel. Contexte workflow : {WORKFLOW_MESSAGE} Checks déjà demandés : {existing} NE répète PAS un check déjà demandé. Si rien d'inhabituel à signaler, retourne {{"additional_checks": []}}. Réponds UNIQUEMENT en JSON : {{ "additional_checks": [ {{"label": "string court", "evidence": "ce que tu as vu d'inhabituel"}} ] }} """ # --------------------------------------------------------------------------- # Gestion VRAM Ollama (déchargement) # --------------------------------------------------------------------------- def list_loaded_models(url: str) -> list[str]: """Retourne la liste des modèles actuellement en VRAM sur cet Ollama.""" try: resp = requests.get(f"{url}/api/ps", timeout=5) if resp.status_code == 200: data = resp.json() return [m["name"] for m in data.get("models", [])] except Exception: pass return [] def unload_all_models() -> None: """Décharge tous les modèles en VRAM sur les 2 Ollama (keep_alive=0).""" for url in (OLLAMA_PRIMARY, OLLAMA_SECONDARY): loaded = list_loaded_models(url) for model_name in loaded: try: requests.post( f"{url}/api/generate", json={"model": model_name, "prompt": "", "keep_alive": 0, "stream": False}, timeout=10, ) except Exception: pass # Petit temps pour laisser le GC GPU faire son travail time.sleep(2) # --------------------------------------------------------------------------- # Appel modèle + parsing # --------------------------------------------------------------------------- @dataclass class CallResult: elapsed_s: float error: str = "" raw: str = "" checks: list[dict] = field(default_factory=list) def call_model(model: str, url: str, prompt: str, image_b64: str) -> CallResult: payload = { "model": model, "prompt": prompt, "stream": False, "format": "json", "options": {"temperature": 0.1, "num_predict": 250}, "images": [image_b64], } t0 = time.perf_counter() try: resp = requests.post(f"{url}/api/generate", json=payload, timeout=TIMEOUT_S) elapsed = time.perf_counter() - t0 except requests.Timeout: return CallResult(elapsed_s=TIMEOUT_S, error="TIMEOUT") except Exception as e: return CallResult(elapsed_s=time.perf_counter() - t0, error=f"NETWORK:{type(e).__name__}") if resp.status_code != 200: return CallResult(elapsed_s=elapsed, error=f"HTTP_{resp.status_code}", raw=resp.text[:200]) raw = resp.json().get("response", "").strip() try: parsed = json.loads(raw) checks = parsed.get("additional_checks") or [] if not isinstance(checks, list): checks = [] return CallResult(elapsed_s=elapsed, raw=raw[:300], checks=checks) except json.JSONDecodeError as e: return CallResult(elapsed_s=elapsed, error=f"JSON:{type(e).__name__}", raw=raw[:200]) def detects_anomaly(scenario: Scenario, checks: list[dict]) -> bool: blob = " ".join( f"{c.get('label', '')} {c.get('evidence', '')}".lower() for c in checks ) return any(pat.lower() in blob for pat in scenario.anomaly_keywords) # --------------------------------------------------------------------------- # Bench main # --------------------------------------------------------------------------- @dataclass class ModelStats: model: str cold_s: float = 0.0 warm_times: list[float] = field(default_factory=list) detection_count: int = 0 detection_total: int = 0 json_valid_count: int = 0 json_valid_total: int = 0 errors: list[str] = field(default_factory=list) sample_checks: list[tuple[str, list[dict]]] = field(default_factory=list) # (scenario_label, checks) def run_bench_for_model(model: str, url: str, screenshots: list[tuple[Scenario, str]]) -> ModelStats: print(f"\n══════════════════════════════════════════════════════════") print(f" MODEL: {model} ({url})") print(f"══════════════════════════════════════════════════════════") # Décharger tout print(f" [1/3] Déchargement VRAM...", end=" ", flush=True) unload_all_models() loaded_after = list_loaded_models(OLLAMA_PRIMARY) + list_loaded_models(OLLAMA_SECONDARY) print(f"OK (loaded={loaded_after if loaded_after else 'aucun'})") stats = ModelStats(model=model) prompt = build_prompt() # Cold start sur le 1er screenshot scen0, path0 = screenshots[0] img_b64 = encode_image(path0) print(f" [2/3] Cold start ({scen0.label})...", end=" ", flush=True) r0 = call_model(model, url, prompt, img_b64) stats.cold_s = r0.elapsed_s if r0.error: print(f"❌ {r0.error} ({r0.elapsed_s:.1f}s)") stats.errors.append(f"cold:{scen0.label}:{r0.error}") else: det = detects_anomaly(scen0, r0.checks) stats.detection_count += int(det) stats.detection_total += 1 stats.json_valid_count += 1 stats.json_valid_total += 1 stats.sample_checks.append((scen0.label, r0.checks)) print(f"{'✅' if det else '⚠️'} {len(r0.checks)} check(s) en {r0.elapsed_s:.1f}s (det={det})") # Warm runs sur les 4 autres screenshots × N runs print(f" [3/3] Warm runs ({len(screenshots)-1} scenarios × {WARM_RUNS_PER_SCREENSHOT} runs)...") for scen, path in screenshots[1:]: img_b64 = encode_image(path) for run_idx in range(WARM_RUNS_PER_SCREENSHOT): r = call_model(model, url, prompt, img_b64) if r.error: stats.errors.append(f"{scen.label}:run{run_idx}:{r.error}") stats.json_valid_total += 1 stats.detection_total += 1 print(f" {scen.label} run{run_idx}: ❌ {r.error}") continue stats.warm_times.append(r.elapsed_s) stats.json_valid_count += 1 stats.json_valid_total += 1 det = detects_anomaly(scen, r.checks) stats.detection_count += int(det) stats.detection_total += 1 if run_idx == 0: stats.sample_checks.append((scen.label, r.checks)) print(f" {scen.label} run{run_idx}: {'✅' if det else '⚠️'} {len(r.checks)} check(s) en {r.elapsed_s:.1f}s") return stats def print_summary_table(all_stats: list[ModelStats]) -> None: print("\n\n══════════════════════════════════════════════════════════") print(" SYNTHÈSE") print("══════════════════════════════════════════════════════════\n") print("| Modèle | Cold (s) | Warm avg (s) | Warm p95 (s) | JSON | Détection | Notes |") print("|---|---:|---:|---:|---:|---:|---|") for s in all_stats: if s.warm_times: warm_avg = statistics.mean(s.warm_times) warm_p95 = sorted(s.warm_times)[int(len(s.warm_times) * 0.95) - 1] if len(s.warm_times) > 1 else s.warm_times[0] else: warm_avg = warm_p95 = 0.0 json_pct = (s.json_valid_count / s.json_valid_total * 100) if s.json_valid_total else 0 det_pct = (s.detection_count / s.detection_total * 100) if s.detection_total else 0 notes = f"{len(s.errors)} err" if s.errors else "OK" print(f"| `{s.model}` | {s.cold_s:.1f} | {warm_avg:.1f} | {warm_p95:.1f} | " f"{json_pct:.0f}% ({s.json_valid_count}/{s.json_valid_total}) | " f"{det_pct:.0f}% ({s.detection_count}/{s.detection_total}) | {notes} |") print("\n## Détail des checks par scénario\n") for s in all_stats: print(f"\n### `{s.model}`") if s.errors: print(f"_Erreurs ({len(s.errors)})_ : {s.errors[:5]}{'...' if len(s.errors) > 5 else ''}") for label, checks in s.sample_checks: if not checks: print(f"- **{label}** : _aucun check_") else: for c in checks[:2]: print(f"- **{label}** : {c.get('label', '?')} — _{c.get('evidence', '?')[:120]}_") def pick_winner(all_stats: list[ModelStats]) -> ModelStats | None: """Le gagnant : meilleur taux détection, départage par warm avg.""" valid = [s for s in all_stats if s.warm_times] if not valid: return None # Tri : détection desc puis warm avg asc valid.sort(key=lambda s: (-(s.detection_count / max(s.detection_total, 1)), statistics.mean(s.warm_times))) return valid[0] def main() -> int: # Génération des 5 screenshots print("📸 Génération des 5 screenshots synthétiques :") screenshots: list[tuple[Scenario, str]] = [] for scen in SCENARIOS: path = f"/tmp/bench_safety_{scen.label}.png" make_screenshot(scen, path) print(f" - {scen.label} → {path}") screenshots.append((scen, path)) print(f"\n⏱ Timeout par appel : {TIMEOUT_S}s") print(f"🔄 Warm runs par scénario : {WARM_RUNS_PER_SCREENSHOT}") print(f"📊 Total mesures par modèle : 1 cold + {(len(SCENARIOS)-1) * WARM_RUNS_PER_SCREENSHOT} warm = " f"{1 + (len(SCENARIOS)-1) * WARM_RUNS_PER_SCREENSHOT}") print(f"🤖 Candidats : {[c[0] for c in CANDIDATES]}") all_stats: list[ModelStats] = [] for model, url, _ in CANDIDATES: try: stats = run_bench_for_model(model, url, screenshots) all_stats.append(stats) except KeyboardInterrupt: print(f"\n⚠️ Interrompu pendant {model}, on saute le reste") break except Exception as e: print(f"\n❌ Crash bench {model}: {e}") all_stats.append(ModelStats(model=model, errors=[f"crash:{e}"])) print_summary_table(all_stats) winner = pick_winner(all_stats) print("\n## Recommandation\n") if winner is None: print("⚠️ Aucun modèle exploitable. Décision manuelle nécessaire.") return 1 det_pct = winner.detection_count / max(winner.detection_total, 1) * 100 warm_avg = statistics.mean(winner.warm_times) print(f"🏆 **{winner.model}** : détection {det_pct:.0f}%, warm avg {warm_avg:.1f}s, cold {winner.cold_s:.1f}s") print(f"\nPour fixer en production :") print(f"```bash\nsudo systemctl edit rpa-streaming") print(f"# [Service]\n# Environment=RPA_SAFETY_CHECKS_LLM_MODEL={winner.model}") print(f"sudo systemctl restart rpa-streaming\n```") return 0 if __name__ == "__main__": raise SystemExit(main())