#!/usr/bin/env python3
"""CPAM quality comparison: multiple models on 3 dossiers.

Generates the CPAM counter-argumentation with several models and compares:
- argument length and density
- presence of the 3 axes (medical, information asymmetry, regulatory)
- citations of evidence from the dossier
- references to the RAG sources
- information-asymmetry keywords
"""

import json
import sys
import time
from pathlib import Path

import requests

STRUCTURED_DIR = Path("output/structured")
OLLAMA_URL = "http://localhost:11434"

MODELS = ["gemma3:12b-v2"]  # 12b with the new, more nuanced prompt

TIMEOUTS = {
    "gemma3:12b": 120,
    "gemma3:27b": 300,
    "qwen3:14b": 180,
    "mistral-small3.2:24b": 300,
}

# 3 varied dossiers: DP+DA, long DAS, short DP
TEST_DOSSIERS = [
    "183_23087212",  # DP+DA contested
    "228_23176885",  # DAS only, long argument (1921 chars)
    "153_23102610",  # DP only, short argument
]


def load_dossier(dossier_name: str) -> dict | None:
    """Load the structured JSON of a dossier (prefer the merged CIM-10 file)."""
    dossier_dir = STRUCTURED_DIR / dossier_name
    if not dossier_dir.exists():
        return None
    for f in list(dossier_dir.glob("*_fusionne_cim10.json")) + sorted(dossier_dir.glob("*_cim10.json")):
        return json.loads(f.read_text())
    return None


def build_prompt(data: dict, controle: dict, sources: list[dict]) -> str:
    """Rebuild the CPAM prompt (identical to the pipeline's)."""
    # Import the real builder to guarantee consistency with the pipeline
    sys.path.insert(0, str(Path(__file__).parent))
    from src.config import ControleCPAM, DossierMedical
    from src.control.cpam_response import _build_cpam_prompt

    dossier = DossierMedical.model_validate(data)
    ctrl = ControleCPAM.model_validate(controle)
    return _build_cpam_prompt(dossier, ctrl, sources)


# Models incompatible with Ollama's format:json (thinking mode)
NO_FORMAT_JSON_MODELS = {"qwen3:14b", "qwen3:8b", "qwen3:32b"}


def _parse_json_from_text(raw: str) -> dict | None:
    """Parse JSON from a raw response (with or without a markdown fence)."""
    text = raw.strip()
    # Strip a ```json ... ``` markdown block
    if text.startswith("```"):
        first_nl = text.find("\n")
        if first_nl != -1:
            text = text[first_nl + 1:]
        if text.rstrip().endswith("```"):
            text = text.rstrip()[:-3]
        text = text.strip()
    # Try the text as-is
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    # Fall back to the span between the first "{" and the last "}"
    brace_start = text.find("{")
    brace_end = text.rfind("}")
    if brace_start != -1 and brace_end > brace_start:
        try:
            return json.loads(text[brace_start:brace_end + 1])
        except json.JSONDecodeError:
            pass
    return None


def call_ollama(prompt: str, model: str) -> tuple[dict | None, float, str]:
    """Call Ollama and return (parsed_json, duration_s, raw_text)."""
    timeout = TIMEOUTS.get(model, 180)
    use_format_json = model not in NO_FORMAT_JSON_MODELS

    # For Qwen3: append /no_think to disable thinking mode
    actual_prompt = prompt
    if model in NO_FORMAT_JSON_MODELS:
        actual_prompt = prompt + "\n/no_think"

    payload = {
        "model": model,
        "prompt": actual_prompt,
        "stream": False,
        "options": {
            "temperature": 0.1,
            "num_predict": 4000,
        },
    }
    if use_format_json:
        payload["format"] = "json"

    t0 = time.time()
    try:
        response = requests.post(
            f"{OLLAMA_URL}/api/generate",
            json=payload,
            timeout=timeout,
        )
        response.raise_for_status()
        duration = time.time() - t0
        raw = response.json().get("response", "")
        parsed = _parse_json_from_text(raw)
        return parsed, duration, raw
    except json.JSONDecodeError:
        # Ollama returned a non-JSON body: keep the raw text for debugging
        duration = time.time() - t0
        return None, duration, response.text
    except Exception as e:
        duration = time.time() - t0
        return None, duration, str(e)


def compute_metrics(parsed: dict | None) -> dict:
    """Compute the quality metrics."""
    if parsed is None:
        return {"valid_json": False}

    full_text = json.dumps(parsed, ensure_ascii=False)

    # Are the 3 axes present?
    has_med = bool(parsed.get("contre_arguments_medicaux"))
    has_asym = bool(parsed.get("contre_arguments_asymetrie"))
    has_regl = bool(parsed.get("contre_arguments_reglementaires"))
    has_3axes = has_med and has_asym and has_regl

    # Length per axis
    len_med = len(str(parsed.get("contre_arguments_medicaux", "")))
    len_asym = len(str(parsed.get("contre_arguments_asymetrie", "")))
    len_regl = len(str(parsed.get("contre_arguments_reglementaires", "")))
    len_total_args = len_med + len_asym + len_regl
    # Fallback for the legacy single-field format
    if not has_3axes:
        len_total_args = max(len_total_args, len(str(parsed.get("contre_arguments", ""))))

    # Evidence quoted from the dossier
    preuves = parsed.get("preuves_dossier", [])
    n_preuves = len(preuves) if isinstance(preuves, list) else 0

    # Structured references
    refs = parsed.get("references", [])
    n_refs = len(refs) if isinstance(refs, list) else 0

    # References carrying a verbatim citation
    n_refs_citation = 0
    if isinstance(refs, list):
        for r in refs:
            if isinstance(r, dict) and r.get("citation") and len(str(r["citation"])) > 20:
                n_refs_citation += 1

    # Information-asymmetry keywords (kept in French, matching the model output)
    full_lower = full_text.lower()
    asymetrie_kw = [
        "biologie", "imagerie", "scanner", "irm", "échographie",
        "traitement", "médicament", "posologie", "asymétrie",
        "non transmis", "n'avait pas", "n'a pas eu accès",
        "imc", "antécédent", "crp", "hémoglobine", "leucocytes",
    ]
    n_asymetrie = sum(1 for kw in asymetrie_kw if kw in full_lower)

    # Genuine points of agreement
    accord = str(parsed.get("points_accord", ""))
    accord_real = bool(accord) and accord.lower().strip() not in ("aucun", "aucun.", "n/a", "")

    # Non-trivial conclusion
    conclusion = str(parsed.get("conclusion", ""))
    has_conclusion = len(conclusion) > 20

    return {
        "valid_json": True,
        "has_3axes": has_3axes,
        "len_med": len_med,
        "len_asym": len_asym,
        "len_regl": len_regl,
        "len_total_args": len_total_args,
        "n_preuves": n_preuves,
        "n_refs": n_refs,
        "n_refs_citation": n_refs_citation,
        "n_asymetrie": n_asymetrie,
        "accord_real": accord_real,
        "has_conclusion": has_conclusion,
        "total_len": len(full_text),
    }
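# Illustrative only: a minimal sketch of the JSON shape that compute_metrics()
# expects from the model, inferred from the key lookups above. This is not the
# pipeline's authoritative schema and the placeholder values are hypothetical;
# responses that still use the legacy single field "contre_arguments" are
# handled by the fallback in compute_metrics(). Kept as documentation, unused.
_EXAMPLE_CPAM_RESPONSE = {
    "contre_arguments_medicaux": "…",        # medical axis (free text)
    "contre_arguments_asymetrie": "…",       # information-asymmetry axis
    "contre_arguments_reglementaires": "…",  # regulatory axis
    "preuves_dossier": ["…"],                # evidence quoted from the dossier
    "references": [{"citation": "…"}],       # RAG references, counted as verbatim when > 20 chars
    "points_accord": "…",                    # genuine agreement ("Aucun"/"N/A" counts as none)
    "conclusion": "…",                       # counts as supported when longer than 20 chars
}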
def model_key(model: str) -> str:
    """Short key for a model (e.g. 'gemma3:12b' → 'gemma3_12b')."""
    return model.replace(":", "_").replace(".", "_")


def print_multi_model(results: list[dict], models: list[str]):
    """Print the multi-model comparison."""
    W = 140
    col_w = 18
    print("\n" + "=" * W)
    print(f"COMPARAISON CPAM : {' vs '.join(models)}")
    print("=" * W)

    metric_labels = [
        ("Durée (s)", "duration", True),
        ("3 axes", "has_3axes", False),
        ("Args médicaux", "len_med", False),
        ("Args asymétrie", "len_asym", False),
        ("Args réglementaires", "len_regl", False),
        ("Total args (car.)", "len_total_args", False),
        ("Preuves structurées", "n_preuves", False),
        ("Références RAG", "n_refs", False),
        ("Refs verbatim", "n_refs_citation", False),
        ("Mots-clés asymétrie", "n_asymetrie", False),
        ("Points d'accord", "accord_real", False),
        ("Conclusion étayée", "has_conclusion", False),
        ("Longueur totale", "total_len", False),
    ]

    for r in results:
        print(f"\n{'─' * W}")
        print(f" {r['dossier']} / OGC {r['ogc']} — {r['titre']}")
        print(f" Argument CPAM : {r['arg_len']} car. | Prompt : {r['prompt_len']} car.")
        print(f"{'─' * W}")

        # Check that every model produced valid JSON for this dossier
        all_valid = True
        for m in models:
            mk = model_key(m)
            metrics = r.get(f"metrics_{mk}", {})
            if not metrics.get("valid_json", False):
                dur = r.get(f"duration_{mk}", 0)
                print(f" {m} : JSON INVALIDE ({dur:.1f}s)")
                all_valid = False
        if not all_valid:
            continue

        # Header
        header = f" {'Métrique':<25}"
        for m in models:
            short = (m.split(":")[0][:6] + ":" + m.split(":")[-1]) if ":" in m else m[:col_w]
            header += f" {short:>{col_w}}"
        print(header)
        print(f" {'─' * (25 + (col_w + 1) * len(models))}")

        for label, key, is_duration in metric_labels:
            row = f" {label:<25}"
            for m in models:
                mk = model_key(m)
                if is_duration:
                    val = r.get(f"duration_{mk}", 0)
                    row += f" {val:>{col_w - 1}.1f}s"
                else:
                    metrics = r.get(f"metrics_{mk}", {})
                    val = metrics.get(key, 0)
                    if isinstance(val, bool):
                        row += f" {'Oui' if val else 'Non':>{col_w}}"
                    else:
                        row += f" {val:>{col_w}}"
            print(row)

    # Global synthesis
    print(f"\n{'=' * W}")
    print("SYNTHÈSE GLOBALE")
    print(f"{'=' * W}")

    # Keep only the dossiers that are valid for every model
    valid = []
    for r in results:
        all_ok = all(r.get(f"metrics_{model_key(m)}", {}).get("valid_json", False) for m in models)
        if all_ok:
            valid.append(r)

    if not valid:
        print(" Aucun résultat valide pour tous les modèles.")
        return

    n = len(valid)
    print(f" Dossiers comparés : {n}")

    # Synthesis header
    header = f"\n {'Métrique':<25}"
    for m in models:
        short = (m.split(":")[0][:6] + ":" + m.split(":")[-1]) if ":" in m else m[:col_w]
        header += f" {short:>{col_w}}"
    header += f" {'Meilleur':>{col_w}}"
    print(header)
    print(f" {'─' * (25 + (col_w + 1) * (len(models) + 1))}")

    # Average duration (lower is better)
    row = f" {'Durée moy. (s)':<25}"
    dur_vals = {}
    for m in models:
        mk = model_key(m)
        avg_dur = sum(r.get(f"duration_{mk}", 0) for r in valid) / n
        dur_vals[m] = avg_dur
        row += f" {avg_dur:>{col_w - 1}.1f}s"
    best = min(dur_vals, key=dur_vals.get)
    row += f" {best:>{col_w}}"
    print(row)

    # Numeric metrics (higher is better)
    for label, key in [
        ("Total args (car.)", "len_total_args"),
        ("Preuves structurées", "n_preuves"),
        ("Références RAG", "n_refs"),
        ("Refs verbatim", "n_refs_citation"),
        ("Mots-clés asymétrie", "n_asymetrie"),
    ]:
        row = f" {label:<25}"
        vals = {}
        for m in models:
            mk = model_key(m)
            avg_val = sum(r.get(f"metrics_{mk}", {}).get(key, 0) for r in valid) / n
            vals[m] = avg_val
            row += f" {avg_val:>{col_w}.1f}"
        best = max(vals, key=vals.get)
        row += f" {best:>{col_w}}"
        print(row)

    # Booleans (count of True)
    for label, key in [
        ("3 axes", "has_3axes"),
        ("Points d'accord", "accord_real"),
    ]:
        row = f" {label:<25}"
        vals = {}
        for m in models:
            mk = model_key(m)
            cnt = sum(1 for r in valid if r.get(f"metrics_{mk}", {}).get(key, False))
            vals[m] = cnt
            row += f" {f'{cnt}/{n}':>{col_w}}"
        best = max(vals, key=vals.get)
        row += f" {best:>{col_w}}"
        print(row)

    # Total durations
    print()
    fastest = min(models, key=lambda m: sum(r.get(f"duration_{model_key(m)}", 0) for r in valid))
    fastest_dur = sum(r.get(f"duration_{model_key(fastest)}", 0) for r in valid)
    for m in models:
        mk = model_key(m)
        total = sum(r.get(f"duration_{mk}", 0) for r in valid)
        ratio = total / fastest_dur if fastest_dur > 0 else 0
        print(f" {m:<25} total={total:.0f}s (x{ratio:.1f})")
    print()


def main():
    # Load the previous results (all_models run)
    prev_file = Path("output/compare_cpam_all_models.json")
    prev_data = {}
    if prev_file.exists():
        for entry in json.loads(prev_file.read_text()):
            prev_data[entry["dossier"]] = entry

    # Compare the old 12b (old prompt) with the new 12b-v2 (new, more nuanced prompt),
    # plus 27b as the reference for nuance
    ref_models = ["gemma3:12b", "gemma3:27b"]
    all_models = ref_models + MODELS

    print("=" * 100)
    print(f"Comparaison qualité CPAM : {' / '.join(all_models)}")
    print(f"Dossiers : {', '.join(TEST_DOSSIERS)}")
    print("Test : gemma3:12b avec NOUVEAU prompt nuancé (v2)")
    print(f"Résultats précédents : {'oui' if prev_data else 'non'}")
    print("=" * 100)

    results = []
    for dossier_name in TEST_DOSSIERS:
        data = load_dossier(dossier_name)
        if not data:
            print(f"\nERREUR : {dossier_name} non trouvé")
            continue

        controles = [c for c in data.get("controles_cpam", []) if c.get("arg_ucr")]
        if not controles:
            print(f"\nERREUR : {dossier_name} — pas de contrôle CPAM")
            continue
        controle = controles[0]

        sources = [
            {
                "document": s.get("document", ""),
                "page": s.get("page"),
                "code": s.get("code"),
                "extrait": s.get("extrait", ""),
            }
            for s in controle.get("sources_reponse", [])
        ]

        prompt = build_prompt(data, controle, sources)
        print(f"\n[{dossier_name}] OGC {controle['numero_ogc']} — {controle.get('titre', '')}")
        print(f" Prompt : {len(prompt)} car. | Arg CPAM : {len(controle.get('arg_ucr', ''))} car.")

        result = {
            "dossier": dossier_name,
            "ogc": controle["numero_ogc"],
            "titre": controle.get("titre", ""),
            "arg_len": len(controle.get("arg_ucr", "")),
            "prompt_len": len(prompt),
        }

        # Reuse the previous results for the reference models
        prev = prev_data.get(dossier_name)
        if prev:
            for old_model in ref_models:
                mk = model_key(old_model)
                result[f"duration_{mk}"] = prev.get(f"duration_{mk}", 0)
                result[f"metrics_{mk}"] = prev.get(f"metrics_{mk}", {})
                result[f"response_{mk}"] = prev.get(f"response_{mk}")
                dur = result[f"duration_{mk}"]
                is_valid = result[f"metrics_{mk}"].get("valid_json", False)
                print(f" → {old_model} ... (précédent) {'OK' if is_valid else 'FAIL'} ({dur:.1f}s)")

        # Run 12b-v2 (new prompt): calls gemma3:12b with the modified prompt
        for model_label in MODELS:
            mk = model_key(model_label)
            actual_model = "gemma3:12b"  # same model, new prompt
            print(f" → {model_label} (nouveau prompt) ...", end=" ", flush=True)
            parsed, dur, raw = call_ollama(prompt, actual_model)
            status = "OK" if parsed else "FAIL"
            print(f"{status} ({dur:.1f}s)")
            result[f"duration_{mk}"] = dur
            result[f"metrics_{mk}"] = compute_metrics(parsed)
            result[f"response_{mk}"] = parsed

        results.append(result)

    # Display
    print_multi_model(results, all_models)

    # Save
    output_file = Path("output/compare_cpam_prompt_v2.json")
    output_file.parent.mkdir(parents=True, exist_ok=True)
    save_data = []
    for r in results:
        entry = {
            "dossier": r["dossier"],
            "ogc": r["ogc"],
            "titre": r["titre"],
        }
        for m in all_models:
            mk = model_key(m)
            entry[f"duration_{mk}"] = r.get(f"duration_{mk}", 0)
            entry[f"metrics_{mk}"] = r.get(f"metrics_{mk}", {})
            entry[f"response_{mk}"] = r.get(f"response_{mk}")
        save_data.append(entry)
    output_file.write_text(json.dumps(save_data, ensure_ascii=False, indent=2))
    print(f"Résultats sauvegardés dans {output_file}")


if __name__ == "__main__":
    main()