#!/usr/bin/env python3
"""CPAM quality comparison: multiple models on 3 dossiers.

Generates the CPAM counter-argumentation with several models and compares:
- argument length and density
- presence of the 3 axes (medical, information asymmetry, regulatory)
- citations of evidence from the dossier
- references to RAG sources
- information-asymmetry keywords
"""

import json
import sys
import time
from pathlib import Path

import requests

STRUCTURED_DIR = Path("output/structured")
OLLAMA_URL = "http://localhost:11434"
MODELS = ["gemma3:12b-v2"]  # 12b with the new, more nuanced prompt
TIMEOUTS = {
    "gemma3:12b": 120,
    "gemma3:27b": 300,
    "qwen3:14b": 180,
    "mistral-small3.2:24b": 300,
}

# 3 varied dossiers: DP+DA, long DAS, short DP
TEST_DOSSIERS = [
    "183_23087212",  # contested DP+DA
    "228_23176885",  # DAS only, long argument (1921 chars)
    "153_23102610",  # DP only, short argument
]


def load_dossier(dossier_name: str) -> dict | None:
    """Return the first structured JSON for a dossier, preferring the fused CIM-10 file."""
    dossier_dir = STRUCTURED_DIR / dossier_name
    if not dossier_dir.exists():
        return None
    for f in list(dossier_dir.glob("*_fusionne_cim10.json")) + sorted(dossier_dir.glob("*_cim10.json")):
        return json.loads(f.read_text())
    return None
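# Illustrative layout assumed by load_dossier (file names below are invented examples):
#   output/structured/183_23087212/dossier_fusionne_cim10.json   <- preferred, fused file
#   output/structured/153_23102610/dossier_cim10.json            <- fallback glob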


def build_prompt(data: dict, controle: dict, sources: list[dict]) -> str:
    """Rebuild the CPAM prompt (identical to the pipeline's)."""
    # Import the real builder to guarantee consistency with the pipeline.
    # Assumes this script lives next to the src/ package.
    sys.path.insert(0, str(Path(__file__).parent))
    from src.config import ControleCPAM, DossierMedical
    from src.control.cpam_response import _build_cpam_prompt

    dossier = DossierMedical.model_validate(data)
    ctrl = ControleCPAM.model_validate(controle)
    return _build_cpam_prompt(dossier, ctrl, sources)


# Models incompatible with Ollama's format:json (thinking mode)
NO_FORMAT_JSON_MODELS = {"qwen3:14b", "qwen3:8b", "qwen3:32b"}


def _parse_json_from_text(raw: str) -> dict | None:
    """Parse JSON from a raw response (with or without a Markdown fence)."""
    text = raw.strip()
    # Strip a ```json ... ``` Markdown block
    if text.startswith("```"):
        first_nl = text.find("\n")
        if first_nl != -1:
            text = text[first_nl + 1:]
        if text.rstrip().endswith("```"):
            text = text.rstrip()[:-3]
        text = text.strip()
    # Try the text as-is
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    # Fall back to the span from the first { to the last }
    brace_start = text.find("{")
    brace_end = text.rfind("}")
    if brace_start != -1 and brace_end > brace_start:
        try:
            return json.loads(text[brace_start:brace_end + 1])
        except json.JSONDecodeError:
            pass
    return None
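# Example (illustrative): a fenced model reply such as
#   '```json\n{"conclusion": "ok"}\n```'
# has its Markdown fence stripped and parses to {"conclusion": "ok"};
# a reply with leading prose falls back to the first-'{' / last-'}' scan instead.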


def call_ollama(prompt: str, model: str) -> tuple[dict | None, float, str]:
    """Call Ollama and return (parsed_json, duration_s, raw_text)."""
    timeout = TIMEOUTS.get(model, 180)
    use_format_json = model not in NO_FORMAT_JSON_MODELS

    # For Qwen3: append /no_think to disable thinking mode
    actual_prompt = prompt
    if model in NO_FORMAT_JSON_MODELS:
        actual_prompt = prompt + "\n/no_think"

    payload = {
        "model": model,
        "prompt": actual_prompt,
        "stream": False,
        "options": {
            "temperature": 0.1,
            "num_predict": 4000,
        },
    }
    if use_format_json:
        payload["format"] = "json"

    t0 = time.time()
    try:
        response = requests.post(
            f"{OLLAMA_URL}/api/generate",
            json=payload,
            timeout=timeout,
        )
        response.raise_for_status()
        duration = time.time() - t0
        raw = response.json().get("response", "")

        parsed = _parse_json_from_text(raw)
        return parsed, duration, raw

    except json.JSONDecodeError:
        # The HTTP body itself was not valid JSON; return it verbatim.
        duration = time.time() - t0
        return None, duration, response.text
    except Exception as e:
        duration = time.time() - t0
        return None, duration, str(e)
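# For reference, the non-streaming /api/generate reply is a JSON object whose
# "response" field carries the generated text (abridged, illustrative):
#   {"model": "gemma3:12b", "response": "{...}", "done": true, ...}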


def compute_metrics(parsed: dict | None) -> dict:
    """Compute the quality metrics."""
    if parsed is None:
        return {"valid_json": False}

    full_text = json.dumps(parsed, ensure_ascii=False)

    # Are the 3 axes present?
    has_med = bool(parsed.get("contre_arguments_medicaux"))
    has_asym = bool(parsed.get("contre_arguments_asymetrie"))
    has_regl = bool(parsed.get("contre_arguments_reglementaires"))
    has_3axes = has_med and has_asym and has_regl

    # Length per axis
    len_med = len(str(parsed.get("contre_arguments_medicaux", "")))
    len_asym = len(str(parsed.get("contre_arguments_asymetrie", "")))
    len_regl = len(str(parsed.get("contre_arguments_reglementaires", "")))
    len_total_args = len_med + len_asym + len_regl

    # Fallback for the old single-field format
    if not has_3axes:
        len_total_args = max(len_total_args, len(str(parsed.get("contre_arguments", ""))))

    # Evidence drawn from the dossier
    preuves = parsed.get("preuves_dossier", [])
    n_preuves = len(preuves) if isinstance(preuves, list) else 0

    # Structured references
    refs = parsed.get("references", [])
    n_refs = len(refs) if isinstance(refs, list) else 0

    # References carrying a verbatim citation
    n_refs_citation = 0
    if isinstance(refs, list):
        for r in refs:
            if isinstance(r, dict) and r.get("citation") and len(str(r["citation"])) > 20:
                n_refs_citation += 1

    # Information-asymmetry keywords
    full_lower = full_text.lower()
    asymetrie_kw = [
        "biologie", "imagerie", "scanner", "irm", "échographie",
        "traitement", "médicament", "posologie",
        "asymétrie", "non transmis", "n'avait pas", "n'a pas eu accès",
        "imc", "antécédent", "crp", "hémoglobine", "leucocytes",
    ]
    n_asymetrie = sum(1 for kw in asymetrie_kw if kw in full_lower)

    # Genuine points of agreement
    accord = str(parsed.get("points_accord", ""))
    accord_real = bool(accord) and accord.lower().strip() not in ("aucun", "aucun.", "n/a", "")

    # Non-empty conclusion
    conclusion = str(parsed.get("conclusion", ""))
    has_conclusion = len(conclusion) > 20

    return {
        "valid_json": True,
        "has_3axes": has_3axes,
        "len_med": len_med,
        "len_asym": len_asym,
        "len_regl": len_regl,
        "len_total_args": len_total_args,
        "n_preuves": n_preuves,
        "n_refs": n_refs,
        "n_refs_citation": n_refs_citation,
        "n_asymetrie": n_asymetrie,
        "accord_real": accord_real,
        "has_conclusion": has_conclusion,
        "total_len": len(full_text),
    }
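# Shape of the model output the metrics above read (illustrative sketch, values invented):
#   {
#     "contre_arguments_medicaux": "...",
#     "contre_arguments_asymetrie": "...",
#     "contre_arguments_reglementaires": "...",   # or the legacy single field "contre_arguments"
#     "preuves_dossier": ["..."],
#     "references": [{"citation": "...", ...}],
#     "points_accord": "...",
#     "conclusion": "..."
#   }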


def model_key(model: str) -> str:
    """Short key for a model (e.g. 'gemma3:12b' → 'gemma3_12b')."""
    return model.replace(":", "_").replace(".", "_")


def print_multi_model(results: list[dict], models: list[str]):
    """Print the multi-model comparison."""
    W = 140
    col_w = 18
    print("\n" + "=" * W)
    print(f"COMPARAISON CPAM : {' vs '.join(models)}")
    print("=" * W)

    metric_labels = [
        ("Durée (s)", "duration", True),
        ("3 axes", "has_3axes", False),
        ("Args médicaux", "len_med", False),
        ("Args asymétrie", "len_asym", False),
        ("Args réglementaires", "len_regl", False),
        ("Total args (car.)", "len_total_args", False),
        ("Preuves structurées", "n_preuves", False),
        ("Références RAG", "n_refs", False),
        ("Refs verbatim", "n_refs_citation", False),
        ("Mots-clés asymétrie", "n_asymetrie", False),
        ("Points d'accord", "accord_real", False),
        ("Conclusion étayée", "has_conclusion", False),
        ("Longueur totale", "total_len", False),
    ]

    for r in results:
        print(f"\n{'─' * W}")
        print(f" {r['dossier']} / OGC {r['ogc']} — {r['titre']}")
        print(f" Argument CPAM : {r['arg_len']} car. | Prompt : {r['prompt_len']} car.")
        print(f"{'─' * W}")

        # Check validity for every model
        all_valid = True
        for m in models:
            mk = model_key(m)
            metrics = r.get(f"metrics_{mk}", {})
            if not metrics.get("valid_json", False):
                dur = r.get(f"duration_{mk}", 0)
                print(f" {m} : JSON INVALIDE ({dur:.1f}s)")
                all_valid = False
        if not all_valid:
            continue

        # Header
        header = f" {'Métrique':<25}"
        for m in models:
            short = m.split(":")[0][:6] + ":" + m.split(":")[-1] if ":" in m else m[:col_w]
            header += f" {short:>{col_w}}"
        print(header)
        print(f" {'─' * (25 + (col_w + 1) * len(models))}")

        for label, key, is_duration in metric_labels:
            row = f" {label:<25}"
            for m in models:
                mk = model_key(m)
                if is_duration:
                    val = r.get(f"duration_{mk}", 0)
                    row += f" {val:>{col_w - 1}.1f}s"
                else:
                    metrics = r.get(f"metrics_{mk}", {})
                    val = metrics.get(key, 0)
                    if isinstance(val, bool):
                        row += f" {'Oui' if val else 'Non':>{col_w}}"
                    else:
                        row += f" {val:>{col_w}}"
            print(row)

    # Global summary
    print(f"\n{'=' * W}")
    print("SYNTHÈSE GLOBALE")
    print(f"{'=' * W}")

    # Keep only the results that are valid for every model
    valid = []
    for r in results:
        all_ok = all(r.get(f"metrics_{model_key(m)}", {}).get("valid_json", False) for m in models)
        if all_ok:
            valid.append(r)

    if not valid:
        print(" Aucun résultat valide pour tous les modèles.")
        return

    n = len(valid)
    print(f" Dossiers comparés : {n}")

    # Summary header
    header = f"\n {'Métrique':<25}"
    for m in models:
        short = m.split(":")[0][:6] + ":" + m.split(":")[-1] if ":" in m else m[:col_w]
        header += f" {short:>{col_w}}"
    header += f" {'Meilleur':>{col_w}}"
    print(header)
    print(f" {'─' * (25 + (col_w + 1) * (len(models) + 1))}")

    # Duration
    row = f" {'Durée moy. (s)':<25}"
    dur_vals = {}
    for m in models:
        mk = model_key(m)
        avg_dur = sum(r.get(f"duration_{mk}", 0) for r in valid) / n
        dur_vals[m] = avg_dur
        row += f" {avg_dur:>{col_w - 1}.1f}s"
    best = min(dur_vals, key=dur_vals.get)
    row += f" {best:>{col_w}}"
    print(row)

    # Numeric metrics (higher is better)
    for label, key in [
        ("Total args (car.)", "len_total_args"),
        ("Preuves structurées", "n_preuves"),
        ("Références RAG", "n_refs"),
        ("Refs verbatim", "n_refs_citation"),
        ("Mots-clés asymétrie", "n_asymetrie"),
    ]:
        row = f" {label:<25}"
        vals = {}
        for m in models:
            mk = model_key(m)
            avg_val = sum(r.get(f"metrics_{mk}", {}).get(key, 0) for r in valid) / n
            vals[m] = avg_val
            row += f" {avg_val:>{col_w}.1f}"
        best = max(vals, key=vals.get)
        row += f" {best:>{col_w}}"
        print(row)

    # Boolean metrics (count of True)
    for label, key in [
        ("3 axes", "has_3axes"),
        ("Points d'accord", "accord_real"),
    ]:
        row = f" {label:<25}"
        vals = {}
        for m in models:
            mk = model_key(m)
            cnt = sum(1 for r in valid if r.get(f"metrics_{mk}", {}).get(key, False))
            vals[m] = cnt
            row += f" {f'{cnt}/{n}':>{col_w}}"
        best = max(vals, key=vals.get)
        row += f" {best:>{col_w}}"
        print(row)

    # Total durations
    print()
    fastest = min(models, key=lambda m: sum(r.get(f"duration_{model_key(m)}", 0) for r in valid))
    fastest_dur = sum(r.get(f"duration_{model_key(fastest)}", 0) for r in valid)
    for m in models:
        mk = model_key(m)
        total = sum(r.get(f"duration_{mk}", 0) for r in valid)
        ratio = total / fastest_dur if fastest_dur > 0 else 0
        print(f" {m:<25} total={total:.0f}s (x{ratio:.1f})")
    print()


def main():
    # Load the previous results (the all_models run)
    prev_file = Path("output/compare_cpam_all_models.json")
    prev_data = {}
    if prev_file.exists():
        for entry in json.loads(prev_file.read_text()):
            prev_data[entry["dossier"]] = entry

    # Compare the old 12b (old prompt) against the new 12b-v2 (new, more nuanced prompt),
    # with 27b as the nuance reference
    ref_models = ["gemma3:12b", "gemma3:27b"]
    all_models = ref_models + MODELS
    print("=" * 100)
    print(f"Comparaison qualité CPAM : {' / '.join(all_models)}")
    print(f"Dossiers : {', '.join(TEST_DOSSIERS)}")
    print("Test : gemma3:12b avec NOUVEAU prompt nuancé (v2)")
    print(f"Résultats précédents : {'oui' if prev_data else 'non'}")
    print("=" * 100)

    results = []

    for dossier_name in TEST_DOSSIERS:
        data = load_dossier(dossier_name)
        if not data:
            print(f"\nERREUR : {dossier_name} non trouvé")
            continue

        controles = [c for c in data.get("controles_cpam", []) if c.get("arg_ucr")]
        if not controles:
            print(f"\nERREUR : {dossier_name} — pas de contrôle CPAM")
            continue

        controle = controles[0]
        sources = [
            {
                "document": s.get("document", ""),
                "page": s.get("page"),
                "code": s.get("code"),
                "extrait": s.get("extrait", ""),
            }
            for s in controle.get("sources_reponse", [])
        ]

        prompt = build_prompt(data, controle, sources)

        print(f"\n[{dossier_name}] OGC {controle['numero_ogc']} — {controle.get('titre', '')}")
        print(f" Prompt : {len(prompt)} car. | Arg CPAM : {len(controle.get('arg_ucr', ''))} car.")

        result = {
            "dossier": dossier_name,
            "ogc": controle["numero_ogc"],
            "titre": controle.get("titre", ""),
            "arg_len": len(controle.get("arg_ucr", "")),
            "prompt_len": len(prompt),
        }

        # Reuse the previous results for the reference models
        prev = prev_data.get(dossier_name)
        if prev:
            for old_model in ref_models:
                mk = model_key(old_model)
                result[f"duration_{mk}"] = prev.get(f"duration_{mk}", 0)
                result[f"metrics_{mk}"] = prev.get(f"metrics_{mk}", {})
                result[f"response_{mk}"] = prev.get(f"response_{mk}")
                dur = result[f"duration_{mk}"]
                is_valid = result[f"metrics_{mk}"].get("valid_json", False)
                print(f" → {old_model} ... (précédent) {'OK' if is_valid else 'FAIL'} ({dur:.1f}s)")

        # Test 12b-v2 (new prompt): calls gemma3:12b with the modified prompt
        for model_label in MODELS:
            mk = model_key(model_label)
            actual_model = "gemma3:12b"  # same model, new prompt
            print(f" → {model_label} (nouveau prompt) ...", end=" ", flush=True)
            parsed, dur, raw = call_ollama(prompt, actual_model)
            status = "OK" if parsed else "FAIL"
            print(f"{status} ({dur:.1f}s)")

            result[f"duration_{mk}"] = dur
            result[f"metrics_{mk}"] = compute_metrics(parsed)
            result[f"response_{mk}"] = parsed

        results.append(result)

    # Display
    print_multi_model(results, all_models)

    # Save
    output_file = Path("output/compare_cpam_prompt_v2.json")
    output_file.parent.mkdir(parents=True, exist_ok=True)
    save_data = []
    for r in results:
        entry = {
            "dossier": r["dossier"],
            "ogc": r["ogc"],
            "titre": r["titre"],
        }
        for m in all_models:
            mk = model_key(m)
            entry[f"duration_{mk}"] = r.get(f"duration_{mk}", 0)
            entry[f"metrics_{mk}"] = r.get(f"metrics_{mk}", {})
            entry[f"response_{mk}"] = r.get(f"response_{mk}")
        save_data.append(entry)
    output_file.write_text(json.dumps(save_data, ensure_ascii=False, indent=2))
    print(f"Résultats sauvegardés dans {output_file}")


if __name__ == "__main__":
    main()
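# Running this comparison end to end assumes an Ollama server reachable at OLLAMA_URL
# with the gemma3:12b and gemma3:27b models pulled, the structured dossiers listed in
# TEST_DOSSIERS present under output/structured/, and (optionally) the previous results
# file output/compare_cpam_all_models.json to fill the reference-model columns.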