t2a_v2/benchmark_cpam_models.py

#!/usr/bin/env python3
"""CPAM TIM benchmark — full multi-model test on real dossiers.

Runs generate_cpam_response() with each candidate local model and evaluates:
JSON validity, TIM compliance, biological coherence, invented CIM-10 codes.

Usage:
    python benchmark_cpam_models.py [dossier_name]
"""
import json
import logging
import os
import sys
import time
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-5s %(name)s %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger("benchmark_cpam")

# Local models to test (no cloud models)
MODELS_TO_TEST = [
    "gemma3:27b",
    "gemma3:27b-it-qat",
    "qwen3:32b",
    "qwen3:14b",
    "mistral-small3.2:24b",
    "llama3.3:70b",
]

# Default test dossier
DEFAULT_DOSSIER = "183_23087212"

# Known biology thresholds (ground truth for verification)
BIO_GROUND_TRUTH = {
    "Créatinine": {"valeur": 84, "norme_min": 50, "norme_max": 120, "status": "NORMAL"},
    "Sodium": {"valeur": 140, "norme_min": 135, "norme_max": 145, "status": "NORMAL"},
    "Potassium": {"valeur": 3.9, "norme_min": 3.5, "norme_max": 5.0, "status": "NORMAL"},
    "Hémoglobine": {"valeur": 12.6, "norme_min": 12, "norme_max": 17, "status": "NORMAL"},
    "Plaquettes": {"valeur": 268, "norme_min": 150, "norme_max": 400, "status": "NORMAL"},
    "Glycémie": {"valeur": 4.8, "norme_min": 3.9, "norme_max": 5.5, "status": "NORMAL"},
}
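# NOTE: these reference values appear to describe the default dossier (183_23087212);
# the hard-coded checks in check_bio_coherence() (créatinine 84, codes N17/N19) rely on
# them, so benchmarking another dossier would likely require adapting both.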


def load_dossier(name: str):
    """Load a dossier JSON from output/structured/."""
    from src.config import DossierMedical

    base = Path(__file__).parent / "output" / "structured" / name
    fusionne = list(base.glob("*_fusionne_cim10.json"))
    json_files = fusionne if fusionne else sorted(base.glob("*.json"))
    if not json_files:
        logger.error("Aucun JSON trouvé pour %s", name)
        return None
    with open(json_files[0], encoding="utf-8") as f:
        data = json.load(f)
    return DossierMedical(**data)


def set_model(model_name: str):
    """Force the CPAM model in the config at runtime."""
    import src.config as cfg
    import src.medical.ollama_client as oc

    cfg.OLLAMA_MODELS["cpam"] = model_name
    # Timeout suited to large local models (600 s = 10 min)
    cfg.OLLAMA_TIMEOUT = 600
    oc.OLLAMA_TIMEOUT = 600  # Direct propagation (the module imported the value by value)
    logger.info("Modèle CPAM forcé → %s (timeout=600s)", model_name)


def check_model_available(model_name: str) -> bool:
    """Check whether the model is available locally in Ollama."""
    import requests

    try:
        ollama_url = os.environ.get("OLLAMA_URL", "http://localhost:11434")
        resp = requests.get(f"{ollama_url}/api/tags", timeout=5)
        if resp.status_code == 200:
            models = [m["name"] for m in resp.json().get("models", [])]
            for m in models:
                # Exact match, or match with the implicit :latest tag
                if m == model_name or m == f"{model_name}:latest":
                    return True
                # Loose match: the requested name is contained in an installed tag
                if model_name in m:
                    return True
        return False
    except Exception:
        return False


def is_tim_format(result: dict) -> bool:
    """Check whether the result uses the TIM format."""
    return isinstance(result, dict) and "moyens_defense" in result
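

# A minimal sketch of the TIM structure the checks below expect (illustrative only —
# field contents are placeholders; the checks inspect the keys and a few specific values):
#   {
#     "objet": "...", "rappel_faits": "...",
#     "moyens_defense": [{"titre": "...", "argument": "...", "preuves": [{"ref": "..."}]}],
#     "confrontation_bio": [{"test": "Créatinine", "valeur": 84, "verdict": "..."}],
#     "asymetrie_information": "...", "reponse_points_cpam": "...",
#     "codes_non_defendables": [{"code": "..."}],
#     "references": ["..."], "conclusion_dispositive": "... maintien ..."
#   }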


def check_bio_coherence(result: dict) -> list[dict]:
    """Check biology/diagnosis coherence in the model output.

    Returns:
        List of detected errors with details.
    """
    errors = []
    if not isinstance(result, dict):
        return errors

    # Serialise the whole result as text to search for errors
    full_text = json.dumps(result, ensure_ascii=False).lower()

    # Check 1: creatinine 84 described as abnormal
    creat_patterns = [
        "insuffisance rénale",
        "ira", "irc",
        "fonction rénale altérée", "fonction rénale dégradée",
        "créatinine élevée", "creatinine élevée",
        "créatinine augmentée", "hypercréatininémie",
    ]
    # Look for creatinine 84 being tied to a renal-failure diagnosis
    if "84" in full_text and "créatinine" in full_text:
        # Search the arguments and supporting evidence
        for pattern in creat_patterns:
            if pattern in full_text:
                errors.append({
                    "type": "BIO_HALLUCINATION",
                    "severity": "CRITICAL",
                    "detail": f"Créatinine 84 µmol/L (NORMAL 50-120) qualifiée comme '{pattern}'",
                    "ground_truth": "Créatinine 84 = NORMAL",
                })
                break

    # Check 2: confrontation_bio consistency
    confrontation = result.get("confrontation_bio", [])
    for entry in confrontation:
        if not isinstance(entry, dict):
            continue
        verdict = str(entry.get("verdict", "")).upper()
        test = str(entry.get("test", "")).lower()
        # Compare against the ground truth
        for gt_test, gt_data in BIO_GROUND_TRUTH.items():
            if gt_test.lower() in test:
                if gt_data["status"] == "NORMAL" and "confirmé" in verdict.lower():
                    errors.append({
                        "type": "CONFRONTATION_ERROR",
                        "severity": "CRITICAL",
                        "detail": f"{gt_test} = {gt_data['valeur']} (NORMAL) mais verdict = {verdict}",
                        "ground_truth": f"{gt_test} norme [{gt_data['norme_min']}-{gt_data['norme_max']}]",
                    })

    # Check 3: codes_non_defendables
    codes_nd = result.get("codes_non_defendables", [])
    if isinstance(codes_nd, list):
        # N17.9 (acute renal failure) should be flagged as non-defendable,
        # since creatinine 84 is NORMAL
        nd_codes = [c.get("code", "") for c in codes_nd if isinstance(c, dict)]
        # Look for the model defending N17/N19 despite normal biology
        moyens = result.get("moyens_defense", [])
        for m in moyens:
            if not isinstance(m, dict):
                continue
            titre = str(m.get("titre", "")).upper()
            argument = str(m.get("argument", "")).upper()
            for code in ["N17", "N19"]:
                if code in titre or code in argument:
                    # The model defends a renal-failure code — check it against the creatinine
                    if code not in " ".join(nd_codes):
                        errors.append({
                            "type": "DEFENDS_UNDEFENDABLE",
                            "severity": "HIGH",
                            "detail": f"Code {code} (IRA/IR) défendu dans moyens_defense malgré créatinine 84 (NORMAL)",
                            "ground_truth": "Créatinine 84 = NORMAL → N17/N19 non défendable sur base bio",
                        })
    return errors


def check_code_validity(result: dict) -> list[dict]:
    """Check that the CIM-10 codes used are plausible."""
    import re

    errors = []
    if not isinstance(result, dict):
        return errors

    full_text = json.dumps(result, ensure_ascii=False)
    # Extract every CIM-10 code mentioned
    codes = set(re.findall(r"\b([A-Z]\d{2}(?:\.\d{1,2})?)\b", full_text))
    # Known suspicious codes
    suspicious_codes = {
        "Q61.9": "Maladie polykystique — probablement inventé pour Bricker fragile",
        "Z45.80": "Code Z45.8 existe mais Z45.80 est suspect (vérifier)",
    }
    for code in codes:
        if code in suspicious_codes:
            errors.append({
                "type": "SUSPICIOUS_CODE",
                "severity": "MEDIUM",
                "detail": f"Code {code}: {suspicious_codes[code]}",
            })
    return errors


def evaluate_tim_structure(result: dict) -> dict:
    """Evaluate the completeness of the TIM structure."""
    scores = {}
    if not is_tim_format(result):
        return {"format": "LEGACY", "tim_compliant": False}

    scores["format"] = "TIM"
    scores["tim_compliant"] = True

    # Mandatory TIM fields
    required_fields = [
        "objet", "rappel_faits", "moyens_defense", "confrontation_bio",
        "asymetrie_information", "reponse_points_cpam", "codes_non_defendables",
        "references", "conclusion_dispositive",
    ]
    present = []
    missing = []
    for field in required_fields:
        if result.get(field):
            present.append(field)
        else:
            missing.append(field)
    scores["fields_present"] = len(present)
    scores["fields_total"] = len(required_fields)
    scores["fields_missing"] = missing

    # Quality of the defence arguments
    moyens = result.get("moyens_defense", [])
    scores["moyens_count"] = len(moyens)
    total_preuves = 0
    preuves_with_ref = 0
    for m in moyens:
        if isinstance(m, dict):
            for p in m.get("preuves", []):
                if isinstance(p, dict):
                    total_preuves += 1
                    if p.get("ref"):
                        preuves_with_ref += 1
    scores["preuves_count"] = total_preuves
    scores["preuves_with_ref"] = preuves_with_ref

    # Biology confrontation
    confrontation = result.get("confrontation_bio", [])
    scores["confrontation_count"] = len(confrontation) if isinstance(confrontation, list) else 0

    # Non-defendable codes
    codes_nd = result.get("codes_non_defendables", [])
    scores["codes_nd_count"] = len(codes_nd) if isinstance(codes_nd, list) else 0

    # References
    refs = result.get("references", [])
    scores["refs_count"] = len(refs) if isinstance(refs, list) else 0

    # Dispositive conclusion
    conclusion = result.get("conclusion_dispositive") or ""
    scores["conclusion_len"] = len(conclusion)
    scores["has_maintien"] = "maintien" in conclusion.lower() if conclusion else False
    return scores


def run_benchmark_for_model(model_name: str, dossier_name: str) -> dict:
    """Run the full CPAM pipeline for a given model."""
    from src.control.cpam_response import generate_cpam_response

    result_data = {
        "model": model_name,
        "dossier": dossier_name,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
    }

    # Load the dossier
    dossier = load_dossier(dossier_name)
    if not dossier:
        result_data["error"] = "Dossier non trouvé"
        return result_data
    if not dossier.controles_cpam:
        result_data["error"] = "Pas de contrôle CPAM"
        return result_data
    controle = dossier.controles_cpam[0]
    result_data["ogc"] = controle.numero_ogc
    result_data["titre"] = controle.titre

    # Force the model
    set_model(model_name)

    # Run the full pipeline
    logger.info("=" * 70)
    logger.info("BENCHMARK : %s → dossier %s", model_name, dossier_name)
    logger.info("=" * 70)
    t0 = time.time()
    try:
        text, parsed, rag_sources = generate_cpam_response(dossier, controle)
        elapsed = time.time() - t0
    except Exception as e:
        elapsed = time.time() - t0
        result_data["error"] = str(e)
        result_data["elapsed_s"] = round(elapsed, 1)
        logger.exception("Erreur pipeline pour %s", model_name)
        return result_data

    result_data["elapsed_s"] = round(elapsed, 1)
    result_data["text_len"] = len(text)
    result_data["rag_sources"] = len(rag_sources)
    result_data["quality_tier"] = controle.quality_tier or "?"
    result_data["requires_review"] = controle.requires_review

    if parsed is None:
        result_data["error"] = "LLM a retourné None"
        result_data["json_valid"] = False
        return result_data

    result_data["json_valid"] = True
    result_data["is_tim"] = is_tim_format(parsed)

    # TIM structure evaluation
    tim_eval = evaluate_tim_structure(parsed)
    result_data["tim_eval"] = tim_eval

    # Biology coherence check
    bio_errors = check_bio_coherence(parsed)
    result_data["bio_errors"] = bio_errors
    result_data["bio_errors_count"] = len(bio_errors)
    result_data["bio_critical_count"] = len([e for e in bio_errors if e["severity"] == "CRITICAL"])

    # Code plausibility check
    code_errors = check_code_validity(parsed)
    result_data["code_errors"] = code_errors
    result_data["code_errors_count"] = len(code_errors)

    # Keep the raw output
    result_data["parsed_response"] = parsed
    result_data["text_output"] = text[:3000]  # Truncated for readability
    return result_data
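

# A quick usage sketch (illustrative): benchmark a single model against the default
# dossier without going through main(); the model name below is simply one of the
# candidates listed in MODELS_TO_TEST.
#   res = run_benchmark_for_model("gemma3:27b", DEFAULT_DOSSIER)
#   print(res.get("tim_eval"), res.get("bio_errors_count"))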


def print_summary(results: list[dict]):
    """Print a comparative summary table."""
    print("\n" + "=" * 100)
    print("BENCHMARK CPAM TIM — RÉSUMÉ COMPARATIF")
    print("=" * 100)

    # Header
    header = (
        f"{'Modèle':<25} {'JSON':>4} {'TIM':>4} {'Tier':>4} {'Temps':>7} "
        f"{'Moyens':>6} {'Bio':>4} {'ND':>3} {'Refs':>4} {'Chars':>6} "
        f"{'BioErr':>6} {'CritE':>5}"
    )
    print(header)
    print("-" * 100)
    for r in results:
        if "error" in r and r.get("json_valid") is None:
            print(f"{r['model']:<25} ERREUR: {r['error']}")
            continue
        tim_eval = r.get("tim_eval", {})
        print(
            f"{r['model']:<25} "
            f"{'OK' if r.get('json_valid') else 'FAIL':>4} "
            f"{'OK' if r.get('is_tim') else 'NO':>4} "
            f"{r.get('quality_tier', '?'):>4} "
            f"{r.get('elapsed_s', 0):>6.0f}s "
            f"{tim_eval.get('moyens_count', 0):>6} "
            f"{tim_eval.get('confrontation_count', 0):>4} "
            f"{tim_eval.get('codes_nd_count', 0):>3} "
            f"{tim_eval.get('refs_count', 0):>4} "
            f"{r.get('text_len', 0):>6} "
            f"{r.get('bio_errors_count', 0):>6} "
            f"{r.get('bio_critical_count', 0):>5}"
        )

    # Per-model detail of biology errors
    print("\n" + "=" * 100)
    print("DÉTAIL DES ERREURS BIOLOGIQUES")
    print("=" * 100)
    for r in results:
        errors = r.get("bio_errors", [])
        if not errors:
            print(f"\n{r['model']}: ✓ Aucune erreur bio détectée")
            continue
        print(f"\n{r['model']}: ✗ {len(errors)} erreur(s)")
        for e in errors:
            severity_icon = "🔴" if e["severity"] == "CRITICAL" else "🟡" if e["severity"] == "HIGH" else ""
            print(f"  {severity_icon} [{e['severity']}] {e['type']}: {e['detail']}")
            if "ground_truth" in e:
                print(f"    Vérité terrain: {e['ground_truth']}")

    # Suspicious code details
    print("\n" + "=" * 100)
    print("CODES CIM-10 SUSPECTS")
    print("=" * 100)
    for r in results:
        code_errors = r.get("code_errors", [])
        if not code_errors:
            print(f"\n{r['model']}: ✓ Aucun code suspect")
            continue
        print(f"\n{r['model']}: ✗ {len(code_errors)} code(s) suspect(s)")
        for e in code_errors:
            print(f"  {e['detail']}")

    # Missing TIM fields
    print("\n" + "=" * 100)
    print("COMPLIANCE FORMAT TIM")
    print("=" * 100)
    for r in results:
        tim_eval = r.get("tim_eval", {})
        if not tim_eval:
            print(f"\n{r['model']}: N/A")
            continue
        missing = tim_eval.get("fields_missing", [])
        total = tim_eval.get("fields_total", 9)
        present = tim_eval.get("fields_present", 0)
        status = "✓ COMPLET" if not missing else f"{present}/{total} champs"
        print(f"\n{r['model']}: {status}")
        if missing:
            print(f"  Manquants: {', '.join(missing)}")
        if tim_eval.get("has_maintien"):
            print("  ✓ Conclusion dispositive avec demande de maintien")
        elif tim_eval.get("conclusion_len", 0) > 0:
            print(f"  ⚠ Conclusion présente ({tim_eval['conclusion_len']} chars) mais sans 'maintien'")
        else:
            print("  ✗ Pas de conclusion dispositive")


def main():
    dossier_name = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_DOSSIER

    # Check which models are available
    available = []
    unavailable = []
    for model in MODELS_TO_TEST:
        if check_model_available(model):
            available.append(model)
        else:
            unavailable.append(model)

    print(f"Modèles disponibles: {len(available)}/{len(MODELS_TO_TEST)}")
    for m in available:
        print(f"  {m}")
    for m in unavailable:
        print(f"  {m} (non trouvé)")
    if not available:
        print("ERREUR: Aucun modèle local disponible")
        sys.exit(1)

    print(f"\nDossier de test: {dossier_name}")
    print("Début du benchmark...\n")

    results = []
    output_path = Path(__file__).parent / "output" / "benchmark_cpam_tim.json"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    for model in available:
        try:
            result = run_benchmark_for_model(model, dossier_name)
            results.append(result)
            # Save intermediate results after each model
            with open(output_path, "w", encoding="utf-8") as f:
                json.dump(results, f, ensure_ascii=False, indent=2, default=str)
        except Exception as e:
            logger.exception("Erreur fatale pour %s", model)
            results.append({"model": model, "error": str(e)})

    # Comparative summary
    print_summary(results)

    # Save the final results
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2, default=str)
    print(f"\nRésultats détaillés sauvegardés dans: {output_path}")


if __name__ == "__main__":
    main()