chore: add .gitignore
This commit is contained in:
472
compare_cpam_models.py
Normal file
472
compare_cpam_models.py
Normal file
@@ -0,0 +1,472 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Comparaison qualité CPAM : multi-modèles sur 3 dossiers.
|
||||
|
||||
Génère la contre-argumentation CPAM avec plusieurs modèles et compare :
|
||||
- Longueur et densité des arguments
|
||||
- Présence des 3 axes (médical, asymétrie, réglementaire)
|
||||
- Citations de preuves du dossier
|
||||
- Références aux sources RAG
|
||||
- Mots-clés d'asymétrie d'information
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
||||
# Directory holding one sub-directory of structured JSON per dossier.
STRUCTURED_DIR = Path("output/structured")
# Local Ollama HTTP endpoint.
OLLAMA_URL = "http://localhost:11434"
MODELS = ["gemma3:12b-v2"]  # 12b with the new nuanced prompt
# Per-model request timeouts in seconds (default 180 for unlisted models).
TIMEOUTS = {
    "gemma3:12b": 120,
    "gemma3:27b": 300,
    "qwen3:14b": 180,
    "mistral-small3.2:24b": 300,
}

# 3 varied dossiers: DP+DA, long DAS, short DP
TEST_DOSSIERS = [
    "183_23087212",  # DP+DA contested
    "228_23176885",  # DAS only, long argument (1921 chars)
    "153_23102610",  # DP only, short argument
]
|
||||
|
||||
|
||||
def load_dossier(dossier_name: str) -> dict | None:
|
||||
dossier_dir = STRUCTURED_DIR / dossier_name
|
||||
if not dossier_dir.exists():
|
||||
return None
|
||||
for f in list(dossier_dir.glob("*_fusionne_cim10.json")) + sorted(dossier_dir.glob("*_cim10.json")):
|
||||
return json.loads(f.read_text())
|
||||
return None
|
||||
|
||||
|
||||
def build_prompt(data: dict, controle: dict, sources: list[dict]) -> str:
    """Rebuild the CPAM prompt (identical to the pipeline's).

    Delegates to the real pipeline builder so this comparison uses
    exactly the prompt production code would generate.

    Args:
        data: structured dossier JSON (validated into ``DossierMedical``).
        controle: one CPAM control entry (validated into ``ControleCPAM``).
        sources: RAG sources forwarded to the prompt builder.

    Returns:
        The full prompt string.
    """
    # Make the project root importable; guard against inserting a
    # duplicate sys.path entry on every call (the original inserted
    # unconditionally, growing sys.path once per dossier).
    root = str(Path(__file__).parent)
    if root not in sys.path:
        sys.path.insert(0, root)
    from src.config import ControleCPAM, DossierMedical
    from src.control.cpam_response import _build_cpam_prompt

    dossier = DossierMedical.model_validate(data)
    ctrl = ControleCPAM.model_validate(controle)
    return _build_cpam_prompt(dossier, ctrl, sources)
|
||||
|
||||
|
||||
# Models incompatible with Ollama's format:json option (thinking mode).
NO_FORMAT_JSON_MODELS = {"qwen3:14b", "qwen3:8b", "qwen3:32b"}
|
||||
|
||||
|
||||
def _parse_json_from_text(raw: str) -> dict | None:
|
||||
"""Parse du JSON depuis une réponse brute (avec ou sans markdown)."""
|
||||
text = raw.strip()
|
||||
# Retirer bloc markdown ```json ... ```
|
||||
if text.startswith("```"):
|
||||
first_nl = text.find("\n")
|
||||
if first_nl != -1:
|
||||
text = text[first_nl + 1:]
|
||||
if text.rstrip().endswith("```"):
|
||||
text = text.rstrip()[:-3]
|
||||
text = text.strip()
|
||||
# Essayer tel quel
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
# Trouver le premier { ... dernier }
|
||||
brace_start = text.find("{")
|
||||
brace_end = text.rfind("}")
|
||||
if brace_start != -1 and brace_end > brace_start:
|
||||
try:
|
||||
return json.loads(text[brace_start:brace_end + 1])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def call_ollama(prompt: str, model: str) -> tuple[dict | None, float, str]:
    """Call Ollama's /api/generate and return (parsed_json, duration_s, raw_text).

    On failure ``parsed_json`` is ``None`` and ``raw_text`` holds the
    raw response body (or the exception message).
    """
    timeout = TIMEOUTS.get(model, 180)
    use_format_json = model not in NO_FORMAT_JSON_MODELS

    # Qwen3 family: append /no_think to disable the thinking mode,
    # which is incompatible with Ollama's format:json option.
    actual_prompt = prompt
    if model in NO_FORMAT_JSON_MODELS:
        actual_prompt = prompt + "\n/no_think"

    payload = {
        "model": model,
        "prompt": actual_prompt,
        "stream": False,
        "options": {
            "temperature": 0.1,
            "num_predict": 4000,
        },
    }
    if use_format_json:
        payload["format"] = "json"

    t0 = time.time()
    # BUGFIX: `raw` must be bound before the try block — when
    # response.json() itself raises (requests' JSONDecodeError subclasses
    # json.JSONDecodeError), the except branch below referenced an
    # unbound `raw` and crashed with NameError instead of returning.
    raw = ""
    try:
        response = requests.post(
            f"{OLLAMA_URL}/api/generate",
            json=payload,
            timeout=timeout,
        )
        response.raise_for_status()
        duration = time.time() - t0
        raw = response.json().get("response", "")

        parsed = _parse_json_from_text(raw)
        return parsed, duration, raw

    except json.JSONDecodeError:
        # Non-JSON body from Ollama; raw may still be "" here.
        duration = time.time() - t0
        return None, duration, raw
    except Exception as e:
        # Network errors, timeouts, HTTP errors — report the message.
        duration = time.time() - t0
        return None, duration, str(e)
|
||||
|
||||
|
||||
def compute_metrics(parsed: dict | None) -> dict:
|
||||
"""Calcule les métriques de qualité."""
|
||||
if parsed is None:
|
||||
return {"valid_json": False}
|
||||
|
||||
full_text = json.dumps(parsed, ensure_ascii=False)
|
||||
|
||||
# 3 axes présents ?
|
||||
has_med = bool(parsed.get("contre_arguments_medicaux"))
|
||||
has_asym = bool(parsed.get("contre_arguments_asymetrie"))
|
||||
has_regl = bool(parsed.get("contre_arguments_reglementaires"))
|
||||
has_3axes = has_med and has_asym and has_regl
|
||||
|
||||
# Longueurs par axe
|
||||
len_med = len(str(parsed.get("contre_arguments_medicaux", "")))
|
||||
len_asym = len(str(parsed.get("contre_arguments_asymetrie", "")))
|
||||
len_regl = len(str(parsed.get("contre_arguments_reglementaires", "")))
|
||||
len_total_args = len_med + len_asym + len_regl
|
||||
|
||||
# Fallback ancien format
|
||||
if not has_3axes:
|
||||
len_total_args = max(len_total_args, len(str(parsed.get("contre_arguments", ""))))
|
||||
|
||||
# Preuves du dossier
|
||||
preuves = parsed.get("preuves_dossier", [])
|
||||
n_preuves = len(preuves) if isinstance(preuves, list) else 0
|
||||
|
||||
# Références structurées
|
||||
refs = parsed.get("references", [])
|
||||
n_refs = len(refs) if isinstance(refs, list) else 0
|
||||
|
||||
# Références avec citation verbatim
|
||||
n_refs_citation = 0
|
||||
if isinstance(refs, list):
|
||||
for r in refs:
|
||||
if isinstance(r, dict) and r.get("citation") and len(str(r["citation"])) > 20:
|
||||
n_refs_citation += 1
|
||||
|
||||
# Mots-clés d'asymétrie
|
||||
full_lower = full_text.lower()
|
||||
asymetrie_kw = [
|
||||
"biologie", "imagerie", "scanner", "irm", "échographie",
|
||||
"traitement", "médicament", "posologie",
|
||||
"asymétrie", "non transmis", "n'avait pas", "n'a pas eu accès",
|
||||
"imc", "antécédent", "crp", "hémoglobine", "leucocytes",
|
||||
]
|
||||
n_asymetrie = sum(1 for kw in asymetrie_kw if kw in full_lower)
|
||||
|
||||
# Points d'accord réels
|
||||
accord = str(parsed.get("points_accord", ""))
|
||||
accord_real = bool(accord) and accord.lower().strip() not in ("aucun", "aucun.", "n/a", "")
|
||||
|
||||
# Conclusion non vide
|
||||
conclusion = str(parsed.get("conclusion", ""))
|
||||
has_conclusion = len(conclusion) > 20
|
||||
|
||||
return {
|
||||
"valid_json": True,
|
||||
"has_3axes": has_3axes,
|
||||
"len_med": len_med,
|
||||
"len_asym": len_asym,
|
||||
"len_regl": len_regl,
|
||||
"len_total_args": len_total_args,
|
||||
"n_preuves": n_preuves,
|
||||
"n_refs": n_refs,
|
||||
"n_refs_citation": n_refs_citation,
|
||||
"n_asymetrie": n_asymetrie,
|
||||
"accord_real": accord_real,
|
||||
"has_conclusion": has_conclusion,
|
||||
"total_len": len(full_text),
|
||||
}
|
||||
|
||||
|
||||
def model_key(model: str) -> str:
    """Short key for a model name (e.g. 'gemma3:12b' -> 'gemma3_12b')."""
    # Single-pass character mapping: ':' and '.' both become '_'.
    return model.translate(str.maketrans({":": "_", ".": "_"}))
|
||||
|
||||
|
||||
def print_multi_model(results: list[dict], models: list[str]):
    """Print the multi-model comparison tables.

    For each dossier result: a per-metric table with one column per
    model.  Then a global synthesis averaged over the dossiers whose
    JSON was valid for every model, with a "best model" column and a
    total-duration ratio against the fastest model.
    """
    W = 140
    col_w = 18
    print("\n" + "=" * W)
    print(f"COMPARAISON CPAM : {' vs '.join(models)}")
    print("=" * W)

    # (label, key, is_duration) — durations are read from the result
    # entry itself, every other key from the per-model metrics dict.
    metric_labels = [
        ("Durée (s)", "duration", True),
        ("3 axes", "has_3axes", False),
        ("Args médicaux", "len_med", False),
        ("Args asymétrie", "len_asym", False),
        ("Args réglementaires", "len_regl", False),
        ("Total args (car.)", "len_total_args", False),
        ("Preuves structurées", "n_preuves", False),
        ("Références RAG", "n_refs", False),
        ("Refs verbatim", "n_refs_citation", False),
        ("Mots-clés asymétrie", "n_asymetrie", False),
        ("Points d'accord", "accord_real", False),
        ("Conclusion étayée", "has_conclusion", False),
        ("Longueur totale", "total_len", False),
    ]

    for r in results:
        print(f"\n{'─' * W}")
        print(f" {r['dossier']} / OGC {r['ogc']} — {r['titre']}")
        print(f" Argument CPAM : {r['arg_len']} car. | Prompt : {r['prompt_len']} car.")
        print(f"{'─' * W}")

        # Skip the table when any model failed to produce valid JSON.
        all_valid = True
        for m in models:
            mk = model_key(m)
            metrics = r.get(f"metrics_{mk}", {})
            if not metrics.get("valid_json", False):
                dur = r.get(f"duration_{mk}", 0)
                print(f" {m} : JSON INVALIDE ({dur:.1f}s)")
                all_valid = False
        if not all_valid:
            continue

        # Column header: shortened model names.
        header = f" {'Métrique':<25}"
        for m in models:
            short = m.split(":")[0][:6] + ":" + m.split(":")[-1] if ":" in m else m[:col_w]
            header += f" {short:>{col_w}}"
        print(header)
        print(f" {'─' * (25 + (col_w + 1) * len(models))}")

        # One row per metric; booleans render as Oui/Non.
        for label, key, is_duration in metric_labels:
            row = f" {label:<25}"
            for m in models:
                mk = model_key(m)
                if is_duration:
                    val = r.get(f"duration_{mk}", 0)
                    row += f" {val:>{col_w - 1}.1f}s"
                else:
                    metrics = r.get(f"metrics_{mk}", {})
                    val = metrics.get(key, 0)
                    if isinstance(val, bool):
                        row += f" {'Oui' if val else 'Non':>{col_w}}"
                    else:
                        row += f" {val:>{col_w}}"
            print(row)

    # Global synthesis.
    print(f"\n{'=' * W}")
    print("SYNTHÈSE GLOBALE")
    print(f"{'=' * W}")

    # Keep only the results that are valid for every model.
    valid = []
    for r in results:
        all_ok = all(r.get(f"metrics_{model_key(m)}", {}).get("valid_json", False) for m in models)
        if all_ok:
            valid.append(r)

    if not valid:
        print(" Aucun résultat valide pour tous les modèles.")
        return

    n = len(valid)
    print(f" Dossiers comparés : {n}")

    # Synthesis header (one extra "best model" column).
    header = f"\n {'Métrique':<25}"
    for m in models:
        short = m.split(":")[0][:6] + ":" + m.split(":")[-1] if ":" in m else m[:col_w]
        header += f" {short:>{col_w}}"
    header += f" {'Meilleur':>{col_w}}"
    print(header)
    print(f" {'─' * (25 + (col_w + 1) * (len(models) + 1))}")

    # Average duration — lower is better.
    row = f" {'Durée moy. (s)':<25}"
    dur_vals = {}
    for m in models:
        mk = model_key(m)
        avg_dur = sum(r.get(f"duration_{mk}", 0) for r in valid) / n
        dur_vals[m] = avg_dur
        row += f" {avg_dur:>{col_w - 1}.1f}s"
    best = min(dur_vals, key=dur_vals.get)
    row += f" {best:>{col_w}}"
    print(row)

    # Numeric metrics — higher is better.
    for label, key in [
        ("Total args (car.)", "len_total_args"),
        ("Preuves structurées", "n_preuves"),
        ("Références RAG", "n_refs"),
        ("Refs verbatim", "n_refs_citation"),
        ("Mots-clés asymétrie", "n_asymetrie"),
    ]:
        row = f" {label:<25}"
        vals = {}
        for m in models:
            mk = model_key(m)
            avg_val = sum(r.get(f"metrics_{mk}", {}).get(key, 0) for r in valid) / n
            vals[m] = avg_val
            row += f" {avg_val:>{col_w}.1f}"
        best = max(vals, key=vals.get)
        row += f" {best:>{col_w}}"
        print(row)

    # Boolean metrics — count of True across the valid dossiers.
    for label, key in [
        ("3 axes", "has_3axes"),
        ("Points d'accord", "accord_real"),
    ]:
        row = f" {label:<25}"
        vals = {}
        for m in models:
            mk = model_key(m)
            cnt = sum(1 for r in valid if r.get(f"metrics_{mk}", {}).get(key, False))
            vals[m] = cnt
            row += f" {f'{cnt}/{n}':>{col_w}}"
        best = max(vals, key=vals.get)
        row += f" {best:>{col_w}}"
        print(row)

    # Total durations relative to the fastest model.
    print()
    fastest = min(models, key=lambda m: sum(r.get(f"duration_{model_key(m)}", 0) for r in valid))
    fastest_dur = sum(r.get(f"duration_{model_key(fastest)}", 0) for r in valid)
    for m in models:
        mk = model_key(m)
        total = sum(r.get(f"duration_{mk}", 0) for r in valid)
        ratio = total / fastest_dur if fastest_dur > 0 else 0
        print(f" {m:<25} total={total:.0f}s (x{ratio:.1f})")
    print()
|
||||
|
||||
|
||||
def main():
    """Run the prompt-v2 comparison over TEST_DOSSIERS and save results."""
    # Load previous results (all-models run) so the reference models do
    # not need to be re-run here.
    prev_file = Path("output/compare_cpam_all_models.json")
    prev_data = {}
    if prev_file.exists():
        for entry in json.loads(prev_file.read_text()):
            prev_data[entry["dossier"]] = entry

    # Compare the old 12b (old prompt) against the new 12b-v2
    # (new nuanced prompt), plus 27b as a nuance reference.
    ref_models = ["gemma3:12b", "gemma3:27b"]
    all_models = ref_models + MODELS
    print("=" * 100)
    print(f"Comparaison qualité CPAM : {' / '.join(all_models)}")
    print(f"Dossiers : {', '.join(TEST_DOSSIERS)}")
    print(f"Test : gemma3:12b avec NOUVEAU prompt nuancé (v2)")
    print(f"Résultats précédents : {'oui' if prev_data else 'non'}")
    print("=" * 100)

    results = []

    for dossier_name in TEST_DOSSIERS:
        data = load_dossier(dossier_name)
        if not data:
            print(f"\nERREUR : {dossier_name} non trouvé")
            continue

        # Keep only controls that actually carry a CPAM argument.
        controles = [c for c in data.get("controles_cpam", []) if c.get("arg_ucr")]
        if not controles:
            print(f"\nERREUR : {dossier_name} — pas de contrôle CPAM")
            continue

        # Only the first control of each dossier is compared.
        controle = controles[0]
        sources = [
            {
                "document": s.get("document", ""),
                "page": s.get("page"),
                "code": s.get("code"),
                "extrait": s.get("extrait", ""),
            }
            for s in controle.get("sources_reponse", [])
        ]

        prompt = build_prompt(data, controle, sources)

        print(f"\n[{dossier_name}] OGC {controle['numero_ogc']} — {controle.get('titre', '')}")
        print(f" Prompt : {len(prompt)} car. | Arg CPAM : {len(controle.get('arg_ucr', ''))} car.")

        result = {
            "dossier": dossier_name,
            "ogc": controle["numero_ogc"],
            "titre": controle.get("titre", ""),
            "arg_len": len(controle.get("arg_ucr", "")),
            "prompt_len": len(prompt),
        }

        # Reuse the previous run's results for the reference models.
        prev = prev_data.get(dossier_name)
        if prev:
            for old_model in ref_models:
                mk = model_key(old_model)
                result[f"duration_{mk}"] = prev.get(f"duration_{mk}", 0)
                result[f"metrics_{mk}"] = prev.get(f"metrics_{mk}", {})
                result[f"response_{mk}"] = prev.get(f"response_{mk}")
                dur = result[f"duration_{mk}"]
                is_valid = result[f"metrics_{mk}"].get("valid_json", False)
                print(f" → {old_model} ... (précédent) {'OK' if is_valid else 'FAIL'} ({dur:.1f}s)")

        # Test 12b-v2 (new prompt) — calls gemma3:12b with the modified prompt.
        for model_label in MODELS:
            mk = model_key(model_label)
            actual_model = "gemma3:12b"  # same model, new prompt
            print(f" → {model_label} (nouveau prompt) ...", end=" ", flush=True)
            parsed, dur, raw = call_ollama(prompt, actual_model)
            status = "OK" if parsed else "FAIL"
            print(f"{status} ({dur:.1f}s)")

            result[f"duration_{mk}"] = dur
            result[f"metrics_{mk}"] = compute_metrics(parsed)
            result[f"response_{mk}"] = parsed

        results.append(result)

    # Display the comparison tables.
    print_multi_model(results, all_models)

    # Save per-dossier durations/metrics/responses for later reuse.
    output_file = Path("output/compare_cpam_prompt_v2.json")
    output_file.parent.mkdir(parents=True, exist_ok=True)
    save_data = []
    for r in results:
        entry = {
            "dossier": r["dossier"],
            "ogc": r["ogc"],
            "titre": r["titre"],
        }
        for m in all_models:
            mk = model_key(m)
            entry[f"duration_{mk}"] = r.get(f"duration_{mk}", 0)
            entry[f"metrics_{mk}"] = r.get(f"metrics_{mk}", {})
            entry[f"response_{mk}"] = r.get(f"response_{mk}")
        save_data.append(entry)
    output_file.write_text(json.dumps(save_data, ensure_ascii=False, indent=2))
    print(f"Résultats sauvegardés dans {output_file}")
|
||||
|
||||
|
||||
# Script entry point: run the comparison when executed directly.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user