""" Harness de comparaison medgemma:4b vs baselines internes. Usage : python3 tools/benchmark_medgemma_demo.py [--models m1,m2,...] [--out report.md] Tâches évaluées : 1. Codage CIM-10 (5 vignettes, gold connu) — match exact + match famille (3 chars) 2. Résumé de dossier (3 CRH anonymisés) — qualitatif, longueur, latence 3. Extraction structurée JSON (mêmes 3 CRH) — conformité schéma + remplissage Métriques : latence, longueur sortie, score CIM-10. Sortie : rapport markdown + JSON brut pour relecture. """ from __future__ import annotations import argparse import json import re import time from pathlib import Path from typing import Any import requests OLLAMA_URL = "http://localhost:11434/api/generate" TIMEOUT = 240 # un appel de 4min max sur les gros modèles DEFAULT_MODELS = [ "medgemma:4b", "pmsi-coder-v2:latest", "qwen2.5:7b", "gemma4:latest", ] T2A_ANON = Path("/home/dom/ai/t2a_v2/output/anonymized") # 5 vignettes CIM-10 — gold construit à partir de cas typiques CIM10_VIGNETTES = [ { "id": "v1_idm_inferieur", "text": ( "Patient de 65 ans, douleur thoracique constrictive irradiant dans " "le bras gauche depuis 2h. ECG : sus-décalage ST en DII, DIII et " "aVF. Troponine I : 4,8 ng/mL (N<0,04). Coronarographie : " "occlusion de la coronaire droite proximale, stent posé." ), "expected_exact": "I21.1", "expected_family3": "I21", "label": "Infarctus du myocarde inférieur", }, { "id": "v2_pneumopathie", "text": ( "Femme 72 ans, fièvre 39°C, toux productive, dyspnée. Examen : " "crépitants base droite. Radio : foyer alvéolaire lobaire moyen " "droit. Antigénurie pneumocoque positive. Antibiothérapie par " "amoxicilline IV 6g/j." ), "expected_exact": "J13", "expected_family3": "J13", "label": "Pneumonie à pneumocoque", }, { "id": "v3_avc_ischemique", "text": ( "Homme 78 ans amené aux urgences pour hémiplégie droite et aphasie " "d'installation brutale 1h auparavant. NIHSS 14. Scanner cérébral " "sans injection : pas d'hémorragie. IRM diffusion : restriction " "sylvienne gauche. Thrombolyse IV par altéplase." ), "expected_exact": "I63.5", "expected_family3": "I63", "label": "AVC ischémique sylvien gauche", }, { "id": "v4_decompensation_cardiaque", "text": ( "Patiente 84 ans, antécédents d'HTA et de cardiopathie ischémique. " "Dyspnée d'aggravation progressive sur 48h, orthopnée, OMI. " "Auscultation : crépitants bilatéraux. BNP 2400 pg/mL. Radio : " "syndrome alvéolo-interstitiel bilatéral, cardiomégalie. " "Diurétiques IV." ), "expected_exact": "I50.1", "expected_family3": "I50", "label": "Insuffisance cardiaque gauche décompensée", }, { "id": "v5_dyspnee_symptome", "text": ( "Patient 56 ans aux urgences pour dyspnée aiguë sans étiologie " "retrouvée après bilan complet (D-dimères négatifs, scanner " "thoracique sans embolie ni foyer, ECG normal, BNP normal). " "Évolution favorable spontanément. Sortie après 48h." ), "expected_exact": "R06.0", "expected_family3": "R06", "label": "Dyspnée (symptôme isolé, étiologie non retrouvée)", }, ] # 3 CRH anonymisés réels pour résumé + extraction CRH_FILES = [ T2A_ANON / "67_23001636/crh_67_23108642_anonymized.txt", T2A_ANON / "103_23056749/CRH 23056749_anonymized.txt", T2A_ANON / "407_23116460/407_crh_anonymized.txt", ] CIM10_PROMPT = """Tu es un médecin codeur PMSI expert en CIM-10. Vignette clinique : {text} Donne UNIQUEMENT le diagnostic principal en CIM-10 au format JSON strict : {{"code": "X00.0", "label": "libellé court"}} Aucun texte autour, juste le JSON.""" SUMMARY_PROMPT = """Tu es un médecin résumant un compte-rendu d'hospitalisation pour passage de relais. Compte-rendu : {text} Résume en 5 puces concises (un point par ligne, format `- ...`) : 1. Motif d'admission 2. Antécédents pertinents 3. Diagnostic(s) retenu(s) 4. Traitements engagés 5. Évolution / orientation Pas de phrases d'introduction. Juste les 5 puces.""" EXTRACTION_PROMPT = """Extrait les informations structurées du compte-rendu suivant. Compte-rendu : {text} Réponds UNIQUEMENT par un JSON strict de ce schéma : {{ "motif_admission": "string court", "diagnostics": ["liste de diagnostics retenus"], "antecedents": ["liste d'antécédents notables"], "traitements": ["traitements engagés pendant le séjour"], "date_admission": "JJ/MM/AAAA ou null", "date_sortie": "JJ/MM/AAAA ou null", "duree_sejour_jours": null }} Si une info est absente, mets null ou liste vide. Aucun texte autour du JSON.""" def call_ollama(model: str, prompt: str) -> tuple[str, float, dict[str, Any]]: """Renvoie (output, latency_s, meta).""" payload = { "model": model, "prompt": prompt, "stream": False, "options": {"temperature": 0.1, "num_ctx": 8192}, } t0 = time.time() try: r = requests.post(OLLAMA_URL, json=payload, timeout=TIMEOUT) r.raise_for_status() data = r.json() latency = time.time() - t0 return data.get("response", ""), latency, { "eval_count": data.get("eval_count"), "eval_duration_ns": data.get("eval_duration"), "load_duration_ns": data.get("load_duration"), } except Exception as e: latency = time.time() - t0 return f"[ERROR: {e}]", latency, {"error": str(e)} def extract_json(text: str) -> dict | None: """Extrait le premier objet JSON d'une chaîne, tolérant aux fences markdown.""" if not text: return None # Nettoyer fences ```json ... ``` cleaned = re.sub(r"```(?:json)?\s*", "", text) cleaned = cleaned.replace("```", "") # Trouver le premier { ... } équilibré start = cleaned.find("{") if start < 0: return None depth = 0 for i in range(start, len(cleaned)): if cleaned[i] == "{": depth += 1 elif cleaned[i] == "}": depth -= 1 if depth == 0: try: return json.loads(cleaned[start:i + 1]) except json.JSONDecodeError: return None return None def score_cim10(predicted_code: str | None, gold_exact: str, gold_family: str) -> str: """Renvoie 'exact', 'family', 'wrong' ou 'parse_error'.""" if not predicted_code: return "parse_error" code = predicted_code.upper().strip().replace(" ", "") if code == gold_exact: return "exact" if code[:3] == gold_family: return "family" return "wrong" def run_cim10_task(models: list[str]) -> list[dict]: results = [] for vig in CIM10_VIGNETTES: for model in models: print(f" [CIM-10] {vig['id']:30s} {model:35s}", end=" ", flush=True) output, latency, meta = call_ollama(model, CIM10_PROMPT.format(text=vig["text"])) parsed = extract_json(output) pred_code = parsed.get("code") if parsed else None score = score_cim10(pred_code, vig["expected_exact"], vig["expected_family3"]) print(f"→ {pred_code or '?'} ({score}) {latency:.1f}s") results.append({ "task": "cim10", "case_id": vig["id"], "model": model, "expected_exact": vig["expected_exact"], "expected_family": vig["expected_family3"], "predicted": pred_code, "score": score, "latency_s": round(latency, 2), "raw_output": output[:500], }) return results def run_summary_task(models: list[str], crh_texts: list[tuple[str, str]]) -> list[dict]: results = [] for crh_id, crh_text in crh_texts: for model in models: print(f" [SUMMARY] {crh_id:30s} {model:35s}", end=" ", flush=True) output, latency, meta = call_ollama(model, SUMMARY_PROMPT.format(text=crh_text)) n_bullets = sum(1 for line in output.splitlines() if line.strip().startswith(("-", "•", "*"))) print(f"→ {n_bullets} puces, {len(output)} car., {latency:.1f}s") results.append({ "task": "summary", "case_id": crh_id, "model": model, "n_bullets": n_bullets, "n_chars": len(output), "latency_s": round(latency, 2), "output": output, }) return results def run_extraction_task(models: list[str], crh_texts: list[tuple[str, str]]) -> list[dict]: expected_keys = {"motif_admission", "diagnostics", "antecedents", "traitements", "date_admission", "date_sortie", "duree_sejour_jours"} results = [] for crh_id, crh_text in crh_texts: for model in models: print(f" [EXTRACT] {crh_id:30s} {model:35s}", end=" ", flush=True) output, latency, meta = call_ollama(model, EXTRACTION_PROMPT.format(text=crh_text)) parsed = extract_json(output) if parsed is None: conformity = "parse_error" filled = 0 else: missing = expected_keys - set(parsed.keys()) extras = set(parsed.keys()) - expected_keys conformity = "conforme" if not missing else f"manque:{','.join(sorted(missing))}" filled = sum(1 for k in expected_keys if parsed.get(k) not in (None, "", [], "null")) print(f"→ {conformity}, {filled}/7 rempli, {latency:.1f}s") results.append({ "task": "extraction", "case_id": crh_id, "model": model, "conformity": conformity, "filled_fields": filled, "parsed": parsed, "latency_s": round(latency, 2), "raw_output": output[:800], }) return results def render_report(all_results: list[dict], out_path: Path) -> str: lines = ["# Benchmark medgemma:4b — démo médicale", ""] lines.append(f"_Généré le {time.strftime('%Y-%m-%d %H:%M:%S')}_") lines.append("") # ---- CIM-10 ---- lines.append("## 1. Codage CIM-10 (5 vignettes)") lines.append("") cim_rows = [r for r in all_results if r["task"] == "cim10"] models = sorted({r["model"] for r in cim_rows}) lines.append("| Modèle | Exact | Famille | Faux | Parse error | Latence moy. |") lines.append("|---|---:|---:|---:|---:|---:|") for m in models: rows = [r for r in cim_rows if r["model"] == m] n_exact = sum(1 for r in rows if r["score"] == "exact") n_fam = sum(1 for r in rows if r["score"] == "family") n_wrong = sum(1 for r in rows if r["score"] == "wrong") n_perr = sum(1 for r in rows if r["score"] == "parse_error") avg_lat = sum(r["latency_s"] for r in rows) / max(len(rows), 1) lines.append(f"| `{m}` | {n_exact}/5 | {n_fam}/5 | {n_wrong}/5 | {n_perr}/5 | {avg_lat:.1f}s |") lines.append("") lines.append("### Détail par vignette") for vig in CIM10_VIGNETTES: lines.append(f"\n**{vig['id']}** — attendu `{vig['expected_exact']}` ({vig['label']})") lines.append("") lines.append("| Modèle | Prédit | Score | Latence |") lines.append("|---|---|---|---:|") for r in [x for x in cim_rows if x["case_id"] == vig["id"]]: lines.append(f"| `{r['model']}` | `{r['predicted'] or '—'}` | {r['score']} | {r['latency_s']}s |") # ---- Résumé ---- lines.append("\n## 2. Résumé de CRH (3 dossiers anonymisés)") lines.append("") sum_rows = [r for r in all_results if r["task"] == "summary"] lines.append("| Modèle | Latence moy. | Longueur moy. | Puces moy. |") lines.append("|---|---:|---:|---:|") for m in models: rows = [r for r in sum_rows if r["model"] == m] if not rows: continue avg_lat = sum(r["latency_s"] for r in rows) / len(rows) avg_len = sum(r["n_chars"] for r in rows) / len(rows) avg_bul = sum(r["n_bullets"] for r in rows) / len(rows) lines.append(f"| `{m}` | {avg_lat:.1f}s | {avg_len:.0f} car. | {avg_bul:.1f} |") lines.append("") lines.append("### Sortie complète par modèle (à juger qualitativement)") for r in sum_rows: lines.append(f"\n#### {r['case_id']} — `{r['model']}` ({r['latency_s']}s)") lines.append("```") lines.append(r["output"][:1500]) lines.append("```") # ---- Extraction ---- lines.append("\n## 3. Extraction structurée JSON") lines.append("") ext_rows = [r for r in all_results if r["task"] == "extraction"] lines.append("| Modèle | Conformes | Champs remplis moy. | Latence moy. |") lines.append("|---|---:|---:|---:|") for m in models: rows = [r for r in ext_rows if r["model"] == m] if not rows: continue n_conforme = sum(1 for r in rows if r["conformity"] == "conforme") avg_filled = sum(r["filled_fields"] for r in rows) / len(rows) avg_lat = sum(r["latency_s"] for r in rows) / len(rows) lines.append(f"| `{m}` | {n_conforme}/{len(rows)} | {avg_filled:.1f}/7 | {avg_lat:.1f}s |") lines.append("") lines.append("### Détail JSON parsé par cas") for r in ext_rows: lines.append(f"\n#### {r['case_id']} — `{r['model']}` ({r['conformity']}, {r['latency_s']}s)") if r["parsed"]: lines.append("```json") lines.append(json.dumps(r["parsed"], indent=2, ensure_ascii=False)[:1500]) lines.append("```") else: lines.append(f"_Parse error._ Brut : `{r['raw_output'][:300]}`") out_path.write_text("\n".join(lines), encoding="utf-8") return "\n".join(lines) def main(): ap = argparse.ArgumentParser() ap.add_argument("--models", default=",".join(DEFAULT_MODELS), help="Liste de modèles séparés par virgule") ap.add_argument("--out", default="docs/BENCH_MEDGEMMA.md") ap.add_argument("--skip-summary", action="store_true") ap.add_argument("--skip-extraction", action="store_true") ap.add_argument("--skip-cim10", action="store_true") args = ap.parse_args() models = [m.strip() for m in args.models.split(",") if m.strip()] print(f"Modèles testés : {models}") # Charger CRH crh_texts = [] for path in CRH_FILES: if path.exists(): crh_texts.append((path.parent.name, path.read_text(encoding="utf-8"))) else: print(f" [WARN] CRH absent : {path}") all_results = [] if not args.skip_cim10: print("\n=== Tâche 1 : Codage CIM-10 ===") all_results.extend(run_cim10_task(models)) if not args.skip_summary and crh_texts: print("\n=== Tâche 2 : Résumé de CRH ===") all_results.extend(run_summary_task(models, crh_texts)) if not args.skip_extraction and crh_texts: print("\n=== Tâche 3 : Extraction structurée ===") all_results.extend(run_extraction_task(models, crh_texts)) # Sauvegarde out_md = Path(args.out) out_md.parent.mkdir(parents=True, exist_ok=True) out_json = out_md.with_suffix(".json") out_json.write_text(json.dumps(all_results, indent=2, ensure_ascii=False), encoding="utf-8") render_report(all_results, out_md) print(f"\n✅ Rapport : {out_md}") print(f"✅ Résultats bruts : {out_json}") if __name__ == "__main__": main()