t2a/scripts/benchmark_models.py
#!/usr/bin/env python3
"""Benchmark A/B : gemma3:12b (base) vs pmsi-coder-v2 (fine-tuné).
Compare les codes CIM-10 produits par les deux modèles sur N dossiers.
Teste DP + DAS (échantillon) pour chaque dossier.
Usage: python scripts/benchmark_models.py [--n 50] [--das-max 5]
"""
from __future__ import annotations

import json
import random
import sys
import time
from pathlib import Path

import requests

# Make the project root importable when the script is run directly.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from src.config import STRUCTURED_DIR, OLLAMA_URL, DossierMedical
from src.medical.cim10_dict import load_dict, normalize_code, validate_code

MODEL_BASE = "gemma3:12b"
MODEL_FINETUNED = "pmsi-coder-v2"
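
# Both model tags are assumed to be already pulled or created in the local
# Ollama instance (check with `ollama list`); run_benchmark() probes both
# and fails fast otherwise.

# The prompt stays in French: both models are asked to code French clinical text.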
PROMPT_TEMPLATE = """Tu es un médecin DIM expert en codage PMSI.
Code le diagnostic suivant en CIM-10. Choisis le code le plus spécifique possible.
DIAGNOSTIC : "{texte}"
TYPE : {type_diag}
{contexte}
Réponds UNIQUEMENT avec un objet JSON :
{{"code": "X99.9", "confidence": "high|medium|low", "justification": "explication courte"}}"""
def call_model(prompt: str, model: str, timeout: int = 120) -> tuple[dict | None, float]:
    """Call an Ollama model; return (parsed JSON result or None, duration in seconds)."""
    t0 = time.time()
    try:
        resp = requests.post(
            f"{OLLAMA_URL}/api/generate",
            json={
                "model": model,
                "prompt": prompt,
                "stream": False,
                "format": "json",  # ask Ollama to constrain the output to valid JSON
                "options": {"temperature": 0.1, "num_predict": 500},
            },
            timeout=timeout,
        )
        resp.raise_for_status()
        raw = resp.json().get("response", "")
        duration = time.time() - t0
        try:
            return json.loads(raw), duration
        except json.JSONDecodeError:
            return None, duration
    except Exception:
        return None, time.time() - t0
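
# Usage sketch (the diagnosis text is hypothetical; assumes the server is up):
#   res, dt = call_model(
#       PROMPT_TEMPLATE.format(texte="pneumopathie bactérienne",
#                              type_diag="DP", contexte=""),
#       MODEL_BASE)
#   code = res.get("code") if res else None
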
def load_dossiers(n: int) -> list[dict]:
    """Load a diversified sample of N merged records (one per patient directory)."""
    dossiers = []
    for subdir in sorted(STRUCTURED_DIR.iterdir()):
        if not subdir.is_dir():
            continue
        for f in subdir.glob("*fusionne*.json"):
            if ".gemma_" in f.name or ".bak" in f.name:
                continue
            try:
                data = json.loads(f.read_text(encoding="utf-8"))
                d = DossierMedical.model_validate(data)
                if d.diagnostic_principal and d.diagnostic_principal.cim10_suggestion:
                    dossiers.append({
                        "name": subdir.name,
                        "dossier": d,
                        "path": str(f),
                    })
            except Exception:
                continue  # unreadable or invalid file: try the next one
            break  # inspect at most one merged file per directory
    random.seed(42)  # fixed seed so every run draws the same sample
    random.shuffle(dossiers)
    return dossiers[:n]
def build_contexte(d: DossierMedical) -> str:
    """Build a short clinical context summary (in French, matching the prompt)."""
    parts = []
    s = d.sejour
    if s.age is not None:
        parts.append(f"Patient {s.sexe or '?'}, {s.age} ans")
    if s.duree_sejour is not None:
        parts.append(f"Durée séjour : {s.duree_sejour}j")
    if d.diagnostic_principal:
        parts.append(f"DP : {d.diagnostic_principal.texte}")
    bio = [f"{b.test}={b.valeur}" for b in d.biologie_cle[:5] if b.valeur]
    if bio:
        parts.append(f"Bio : {', '.join(bio)}")
    return ("CONTEXTE : " + " | ".join(parts)) if parts else ""
def code_match_level(code_a: str, code_b: str) -> str:
    """Return the agreement level between two codes: "exact", "categorie"
    (same 3-character ICD-10 category) or "diff"."""
    if "ERREUR" in (code_a, code_b):
        return "diff"  # a failed call never counts as an agreement
    if code_a == code_b:
        return "exact"
    if code_a[:3] == code_b[:3]:
        return "categorie"
    return "diff"
def run_benchmark(n: int = 50, das_max: int = 5):
    print(f"=== A/B benchmark: {MODEL_BASE} vs {MODEL_FINETUNED} ===")
    print(f"  Records: {n}, max DAS per record: {das_max}\n")

    # Check that both models are available before starting the long run
    for model in [MODEL_BASE, MODEL_FINETUNED]:
        try:
            resp = requests.post(
                f"{OLLAMA_URL}/api/generate",
                json={"model": model, "prompt": "test", "stream": False,
                      "options": {"num_predict": 1}},
                timeout=60,
            )
            resp.raise_for_status()
            print(f"  {model}: OK")
        except Exception as e:
            print(f"  {model}: ERROR ({e})")
            sys.exit(1)

    dossiers = load_dossiers(n)
    print(f"\nRecords loaded: {len(dossiers)}\n")
    load_dict()  # warm the CIM-10 dictionary once up front (assumed to back validate_code)
    t_global_start = time.time()
    dp_results = []
    das_results = []
    for i, item in enumerate(dossiers, 1):
        d = item["dossier"]
        dp = d.diagnostic_principal
        name = item["name"]
        ctx = build_contexte(d)

        # === DP (principal diagnosis) ===
        prompt_dp = PROMPT_TEMPLATE.format(
            texte=dp.texte,
            type_diag="DP (diagnostic principal)",
            contexte=ctx,
        )
        res_base, t_base = call_model(prompt_dp, MODEL_BASE)
        res_ft, t_ft = call_model(prompt_dp, MODEL_FINETUNED)
        code_base = normalize_code(res_base.get("code", "")) if res_base else "ERREUR"
        code_ft = normalize_code(res_ft.get("code", "")) if res_ft else "ERREUR"
        conf_base = res_base.get("confidence", "?") if res_base else "?"
        conf_ft = res_ft.get("confidence", "?") if res_ft else "?"
        valid_base = validate_code(code_base)[0] if code_base != "ERREUR" else False
        valid_ft = validate_code(code_ft)[0] if code_ft != "ERREUR" else False
        pipeline_code = dp.cim10_suggestion
        match_level = code_match_level(code_base, code_ft)
        dp_result = {
            "dossier": name,
            "texte": dp.texte[:80],
            "pipeline": pipeline_code,
            "base": code_base,
            "ft": code_ft,
            "conf_base": conf_base,
            "conf_ft": conf_ft,
            "valid_base": valid_base,
            "valid_ft": valid_ft,
            "match": match_level,
            "t_base": round(t_base, 2),
            "t_ft": round(t_ft, 2),
        }
        dp_results.append(dp_result)
        tag = {"exact": "=", "categorie": "~", "diff": "X"}[match_level]
        print(f"  [{i:2d}/{len(dossiers)}] {name:<20s} DP=\"{dp.texte[:35]:<35s}\" "
              f"base={code_base:<7s} ft={code_ft:<7s} [{tag}] "
              f"({t_base:.1f}s / {t_ft:.1f}s)")
        # === DAS (sample) ===
        das_list = [das for das in d.diagnostics_associes
                    if das.texte and das.cim10_suggestion]
        if len(das_list) > das_max:
            # Seed with the directory name itself: built-in hash() is salted
            # per process, so it would not give a reproducible sample.
            random.seed(name)
            das_list = random.sample(das_list, das_max)
        for das in das_list:
            prompt_das = PROMPT_TEMPLATE.format(
                texte=das.texte,
                type_diag="DAS (diagnostic associé significatif)",
                contexte=ctx,
            )
            res_b, tb = call_model(prompt_das, MODEL_BASE)
            res_f, tf = call_model(prompt_das, MODEL_FINETUNED)
            cb = normalize_code(res_b.get("code", "")) if res_b else "ERREUR"
            cf = normalize_code(res_f.get("code", "")) if res_f else "ERREUR"
            vb = validate_code(cb)[0] if cb != "ERREUR" else False
            vf = validate_code(cf)[0] if cf != "ERREUR" else False
            das_results.append({
                "dossier": name,
                "texte": das.texte[:80],
                "pipeline": das.cim10_suggestion,
                "base": cb,
                "ft": cf,
                "conf_base": (res_b or {}).get("confidence", "?"),
                "conf_ft": (res_f or {}).get("confidence", "?"),
                "valid_base": vb,
                "valid_ft": vf,
                "match": code_match_level(cb, cf),
                "t_base": round(tb, 2),
                "t_ft": round(tf, 2),
            })
    t_global = time.time() - t_global_start

    # === SUMMARY ===
    print(f"\n{'='*75}")
    print(f"SUMMARY: {len(dp_results)} records, {len(das_results)} DAS tested")
    print(f"Total time: {t_global/60:.1f} min\n")
    for label, results in [("DP", dp_results), ("DAS", das_results)]:
        if not results:
            continue
        nt = len(results)
        n_exact = sum(1 for r in results if r["match"] == "exact")
        n_cat = sum(1 for r in results if r["match"] == "categorie")
        n_diff = sum(1 for r in results if r["match"] == "diff")
        n_vb = sum(1 for r in results if r["valid_base"])
        n_vf = sum(1 for r in results if r["valid_ft"])
        avg_tb = sum(r["t_base"] for r in results) / nt
        avg_tf = sum(r["t_ft"] for r in results) / nt
        # Confidence distribution per model
        conf_b = {}
        conf_f = {}
        for r in results:
            conf_b[r["conf_base"]] = conf_b.get(r["conf_base"], 0) + 1
            conf_f[r["conf_ft"]] = conf_f.get(r["conf_ft"], 0) + 1
        # Agreement with the pipeline codes (original gemma run)
        n_base_eq_pipe = sum(1 for r in results if r["base"] == r["pipeline"])
        n_ft_eq_pipe = sum(1 for r in results if r["ft"] == r["pipeline"])
        n_base_cat_pipe = sum(1 for r in results
                              if r["base"][:3] == r["pipeline"][:3])
        n_ft_cat_pipe = sum(1 for r in results
                            if r["ft"][:3] == r["pipeline"][:3])
        print(f"  --- {label} ({nt} diagnoses) ---")
        print("  base↔ft agreement:")
        print(f"    Exact              : {n_exact}/{nt} ({100*n_exact/nt:.0f}%)")
        print(f"    Category or better : {n_exact+n_cat}/{nt} ({100*(n_exact+n_cat)/nt:.0f}%)")
        print(f"    Different          : {n_diff}/{nt} ({100*n_diff/nt:.0f}%)")
        print("  Valid codes:")
        print(f"    base : {n_vb}/{nt} ({100*n_vb/nt:.0f}%)")
        print(f"    ft   : {n_vf}/{nt} ({100*n_vf/nt:.0f}%)")
        print("  vs pipeline (original gemma run):")
        print(f"    base=pipe : {n_base_eq_pipe}/{nt} exact, {n_base_cat_pipe}/{nt} category")
        print(f"    ft=pipe   : {n_ft_eq_pipe}/{nt} exact, {n_ft_cat_pipe}/{nt} category")
        # Δ = relative change of the fine-tuned model's mean latency vs the base model
        print(f"  Mean time: base={avg_tb:.2f}s ft={avg_tf:.2f}s (Δ={100*(avg_tf-avg_tb)/avg_tb:+.0f}%)")
        print(f"  Base confidence: {conf_b}")
        print(f"  Ft confidence  : {conf_f}")
        print()
    # List the DP disagreements
    diffs_dp = [r for r in dp_results if r["match"] == "diff"]
    if diffs_dp:
        print(f"  DP differences ({len(diffs_dp)}):")
        for r in diffs_dp:
            vb = "✓" if r["valid_base"] else "✗"  # code-validity marker
            vf = "✓" if r["valid_ft"] else "✗"
            print(f"    {r['dossier']:<18s} \"{r['texte'][:40]}\"")
            print(f"      base={r['base']:<7s}{vb} ft={r['ft']:<7s}{vf} pipe={r['pipeline']}")
    # Save detailed results
    out = {
        "meta": {
            "date": time.strftime("%Y-%m-%dT%H:%M:%S"),
            "model_base": MODEL_BASE,
            "model_ft": MODEL_FINETUNED,
            "n_dossiers": len(dp_results),
            "n_das": len(das_results),
            "duration_min": round(t_global / 60, 1),
        },
        "dp": dp_results,
        "das": das_results,
    }
    out_path = Path(__file__).resolve().parent.parent / "output" / "benchmark_ab.json"
    out_path.parent.mkdir(parents=True, exist_ok=True)  # output/ may not exist yet
    out_path.write_text(json.dumps(out, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"\nDetailed results: {out_path}")
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--n", type=int, default=50,
                        help="Number of records to test")
    parser.add_argument("--das-max", type=int, default=5,
                        help="Max DAS tested per record")
    args = parser.parse_args()
    run_benchmark(args.n, args.das_max)