#!/usr/bin/env python3
"""A/B benchmark: gemma3:12b (base) vs pmsi-coder-v2 (fine-tuned).

Compares the CIM-10 codes produced by both models on N dossiers.
For each dossier, the DP plus a sample of DAS are tested.

Usage:
    python scripts/benchmark_models.py [--n 50] [--das-max 5]
"""
from __future__ import annotations

import json
import random
import sys
import time
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from src.config import STRUCTURED_DIR, OLLAMA_URL, DossierMedical
from src.medical.cim10_dict import load_dict, normalize_code, validate_code

import requests

MODEL_BASE = "gemma3:12b"
MODEL_FINETUNED = "pmsi-coder-v2"

# The prompt is kept in French: both models are queried on French clinical text for PMSI/CIM-10 coding.
PROMPT_TEMPLATE = """Tu es un médecin DIM expert en codage PMSI. Code le diagnostic suivant en CIM-10.
Choisis le code le plus spécifique possible.

DIAGNOSTIC : "{texte}"
TYPE : {type_diag}
{contexte}

Réponds UNIQUEMENT avec un objet JSON :
{{"code": "X99.9", "confidence": "high|medium|low", "justification": "explication courte"}}"""


def call_model(prompt: str, model: str, timeout: int = 120) -> tuple[dict | None, float]:
    """Call an Ollama model and return (parsed JSON result or None, duration in seconds)."""
    t0 = time.time()
    try:
        resp = requests.post(
            f"{OLLAMA_URL}/api/generate",
            json={
                "model": model,
                "prompt": prompt,
                "stream": False,
                "format": "json",
                "options": {"temperature": 0.1, "num_predict": 500},
            },
            timeout=timeout,
        )
        resp.raise_for_status()
        raw = resp.json().get("response", "")
        duration = time.time() - t0
        try:
            return json.loads(raw), duration
        except json.JSONDecodeError:
            return None, duration
    except Exception:
        # Network/HTTP error or timeout: report a failed call with the elapsed time.
        return None, time.time() - t0


def load_dossiers(n: int) -> list[dict]:
    """Load up to N merged dossiers (shuffled with a fixed seed) that have a coded DP."""
    dossiers = []
    for subdir in sorted(STRUCTURED_DIR.iterdir()):
        if not subdir.is_dir():
            continue
        for f in subdir.glob("*fusionne*.json"):
            if ".gemma_" in f.name or ".bak" in f.name:
                continue
            try:
                data = json.loads(f.read_text(encoding="utf-8"))
                d = DossierMedical.model_validate(data)
                if d.diagnostic_principal and d.diagnostic_principal.cim10_suggestion:
                    dossiers.append({
                        "name": subdir.name,
                        "dossier": d,
                        "path": str(f),
                    })
            except Exception:
                continue
            # Only the first parsable merged file per case folder is considered.
            break
    random.seed(42)
    random.shuffle(dossiers)
    return dossiers[:n]


def build_contexte(d: DossierMedical) -> str:
    """Build a short clinical context string injected into the prompt."""
    parts = []
    s = d.sejour
    if s.age is not None:
        parts.append(f"Patient {s.sexe or '?'}, {s.age} ans")
    if s.duree_sejour is not None:
        parts.append(f"Durée séjour : {s.duree_sejour}j")
    if d.diagnostic_principal:
        parts.append(f"DP : {d.diagnostic_principal.texte}")
    bio = [f"{b.test}={b.valeur}" for b in d.biologie_cle[:5] if b.valeur]
    if bio:
        parts.append(f"Bio : {', '.join(bio)}")
    return ("CONTEXTE : " + " | ".join(parts)) if parts else ""


def code_match_level(code_a: str, code_b: str) -> str:
    """Return the agreement level between two codes: "exact", "categorie" (same first 3 chars) or "diff"."""
    if code_a == code_b:
        return "exact"
    if code_a[:3] == code_b[:3]:
        return "categorie"
    return "diff"


def run_benchmark(n: int = 50, das_max: int = 5):
    """Run the A/B benchmark on n dossiers, testing the DP and up to das_max DAS per dossier."""
    print(f"=== A/B benchmark: {MODEL_BASE} vs {MODEL_FINETUNED} ===")
    print(f"    Dossiers: {n}, max DAS per dossier: {das_max}\n")

    # Check that both models are available in Ollama before starting.
    for model in [MODEL_BASE, MODEL_FINETUNED]:
        try:
            resp = requests.post(
                f"{OLLAMA_URL}/api/generate",
                json={"model": model, "prompt": "test", "stream": False, "options": {"num_predict": 1}},
                timeout=60,
            )
            resp.raise_for_status()
            print(f"  {model}: OK")
        except Exception as e:
            print(f"  {model}: ERROR - {e}")
            sys.exit(1)

    dossiers = load_dossiers(n)
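    # Each dossier yields one DP comparison and up to --das-max DAS comparisons,
    # each diagnosis being sent to both models with the same prompt and clinical context.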
print(f"\nDossiers chargés : {len(dossiers)}\n") cim10 = load_dict() t_global_start = time.time() dp_results = [] das_results = [] for i, item in enumerate(dossiers, 1): d = item["dossier"] dp = d.diagnostic_principal name = item["name"] ctx = build_contexte(d) # === DP === prompt_dp = PROMPT_TEMPLATE.format( texte=dp.texte, type_diag="DP (diagnostic principal)", contexte=ctx, ) res_base, t_base = call_model(prompt_dp, MODEL_BASE) res_ft, t_ft = call_model(prompt_dp, MODEL_FINETUNED) code_base = normalize_code(res_base.get("code", "")) if res_base else "ERREUR" code_ft = normalize_code(res_ft.get("code", "")) if res_ft else "ERREUR" conf_base = res_base.get("confidence", "?") if res_base else "?" conf_ft = res_ft.get("confidence", "?") if res_ft else "?" valid_base = validate_code(code_base)[0] if code_base != "ERREUR" else False valid_ft = validate_code(code_ft)[0] if code_ft != "ERREUR" else False pipeline_code = dp.cim10_suggestion match_level = code_match_level(code_base, code_ft) dp_result = { "dossier": name, "texte": dp.texte[:80], "pipeline": pipeline_code, "base": code_base, "ft": code_ft, "conf_base": conf_base, "conf_ft": conf_ft, "valid_base": valid_base, "valid_ft": valid_ft, "match": match_level, "t_base": round(t_base, 2), "t_ft": round(t_ft, 2), } dp_results.append(dp_result) tag = {"exact": "=", "categorie": "~", "diff": "X"}[match_level] print(f" [{i:2d}/{len(dossiers)}] {name:<20s} DP=\"{dp.texte[:35]:<35s}\" " f"base={code_base:<7s} ft={code_ft:<7s} [{tag}] " f"({t_base:.1f}s / {t_ft:.1f}s)") # === DAS (échantillon) === das_list = [das for das in d.diagnostics_associes if das.texte and das.cim10_suggestion] if len(das_list) > das_max: random.seed(hash(name)) das_list = random.sample(das_list, das_max) for das in das_list: prompt_das = PROMPT_TEMPLATE.format( texte=das.texte, type_diag="DAS (diagnostic associé significatif)", contexte=ctx, ) res_b, tb = call_model(prompt_das, MODEL_BASE) res_f, tf = call_model(prompt_das, MODEL_FINETUNED) cb = normalize_code(res_b.get("code", "")) if res_b else "ERREUR" cf = normalize_code(res_f.get("code", "")) if res_f else "ERREUR" vb = validate_code(cb)[0] if cb != "ERREUR" else False vf = validate_code(cf)[0] if cf != "ERREUR" else False das_results.append({ "dossier": name, "texte": das.texte[:80], "pipeline": das.cim10_suggestion, "base": cb, "ft": cf, "conf_base": (res_b or {}).get("confidence", "?"), "conf_ft": (res_f or {}).get("confidence", "?"), "valid_base": vb, "valid_ft": vf, "match": code_match_level(cb, cf), "t_base": round(tb, 2), "t_ft": round(tf, 2), }) t_global = time.time() - t_global_start # === RÉSUMÉ === print(f"\n{'='*75}") print(f"RÉSUMÉ — {len(dp_results)} dossiers, {len(das_results)} DAS testés") print(f"Durée totale : {t_global/60:.1f} min\n") for label, results in [("DP", dp_results), ("DAS", das_results)]: if not results: continue nt = len(results) n_exact = sum(1 for r in results if r["match"] == "exact") n_cat = sum(1 for r in results if r["match"] == "categorie") n_diff = sum(1 for r in results if r["match"] == "diff") n_vb = sum(1 for r in results if r["valid_base"]) n_vf = sum(1 for r in results if r["valid_ft"]) avg_tb = sum(r["t_base"] for r in results) / nt avg_tf = sum(r["t_ft"] for r in results) / nt # Confiance conf_b = {} conf_f = {} for r in results: conf_b[r["conf_base"]] = conf_b.get(r["conf_base"], 0) + 1 conf_f[r["conf_ft"]] = conf_f.get(r["conf_ft"], 0) + 1 # Concordance avec pipeline (gemma run original) n_base_eq_pipe = sum(1 for r in results if r["base"] == r["pipeline"]) n_ft_eq_pipe 
        n_base_cat_pipe = sum(1 for r in results if r["base"][:3] == r["pipeline"][:3])
        n_ft_cat_pipe = sum(1 for r in results if r["ft"][:3] == r["pipeline"][:3])

        print(f"  --- {label} ({nt} diagnostics) ---")
        print("  Agreement base↔ft:")
        print(f"    Exact     : {n_exact}/{nt} ({100*n_exact/nt:.0f}%)")
        print(f"    Category  : {n_exact+n_cat}/{nt} ({100*(n_exact+n_cat)/nt:.0f}%)")
        print(f"    Different : {n_diff}/{nt} ({100*n_diff/nt:.0f}%)")
        print("  Valid codes:")
        print(f"    base : {n_vb}/{nt} ({100*n_vb/nt:.0f}%)")
        print(f"    ft   : {n_vf}/{nt} ({100*n_vf/nt:.0f}%)")
        print("  vs pipeline (original gemma run):")
        print(f"    base=pipe : {n_base_eq_pipe}/{nt} exact, {n_base_cat_pipe}/{nt} category")
        print(f"    ft=pipe   : {n_ft_eq_pipe}/{nt} exact, {n_ft_cat_pipe}/{nt} category")
        print(f"  Mean time : base={avg_tb:.2f}s  ft={avg_tf:.2f}s  (Δ={100*(avg_tf-avg_tb)/avg_tb:+.0f}%)")
        print(f"  Confidence base : {conf_b}")
        print(f"  Confidence ft   : {conf_f}")
        print()

    # List the DP disagreements between the two models
    diffs_dp = [r for r in dp_results if r["match"] == "diff"]
    if diffs_dp:
        print(f"  DP disagreements ({len(diffs_dp)}):")
        for r in diffs_dp:
            vb = "✓" if r["valid_base"] else "✗"
            vf = "✓" if r["valid_ft"] else "✗"
            print(f"    {r['dossier']:<18s} \"{r['texte'][:40]}\"")
            print(f"      base={r['base']:<7s}{vb} ft={r['ft']:<7s}{vf} pipe={r['pipeline']}")

    # Save detailed results
    out = {
        "meta": {
            "date": time.strftime("%Y-%m-%dT%H:%M:%S"),
            "model_base": MODEL_BASE,
            "model_ft": MODEL_FINETUNED,
            "n_dossiers": len(dp_results),
            "n_das": len(das_results),
            "duration_min": round(t_global / 60, 1),
        },
        "dp": dp_results,
        "das": das_results,
    }
    out_path = Path(__file__).parent.parent / "output" / "benchmark_ab.json"
    out_path.write_text(json.dumps(out, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"\nDetailed results: {out_path}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--n", type=int, default=50, help="Number of dossiers to test")
    parser.add_argument("--das-max", type=int, default=5, help="Max DAS tested per dossier")
    args = parser.parse_args()
    run_benchmark(args.n, args.das_max)
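# Example run (assumes both models have been pulled/created in the local Ollama instance):
#   python scripts/benchmark_models.py --n 20 --das-max 3
# Prints a console summary and writes output/benchmark_ab.json under the project root.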