From e77c10da7dfab1dd33511a2a39ea8eded820f470 Mon Sep 17 00:00:00 2001 From: dom Date: Fri, 20 Feb 2026 13:33:39 +0100 Subject: [PATCH] =?UTF-8?q?fix:=20r=C3=A9paration=20JSON=20tronqu=C3=A9=20?= =?UTF-8?q?+=20retry=20429=20+=20whitelist=20codes=20CPAM=20anti-hallucina?= =?UTF-8?q?tion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - parse_json_response : réparation JSON tronqué par max_tokens (fermeture auto des structures ouvertes), meilleur stripping des blocs fencés avec texte superflu après la fermeture ``` - call_ollama : retry avec backoff exponentiel (1s/2s/4s) pour les erreurs 429 rate limit, 3 tentatives au lieu de 2 - Validation adversariale : max_tokens 800 → 1500 - Prompt CPAM : whitelist PÉRIMÈTRE DE CODES AUTORISÉS (dossier DP+DAS + UCR) avec interdiction explicite des codes hors périmètre - Tests : 19 tests parse_json/_repair_truncated_json, 6 tests whitelist Co-Authored-By: Claude Opus 4.6 --- src/control/cpam_context.py | 35 +++++++++ src/control/cpam_validation.py | 4 +- src/medical/ollama_client.py | 85 ++++++++++++++++++++-- src/prompts/templates.py | 3 +- tests/test_cpam_response.py | 46 ++++++++++++ tests/test_ollama_client.py | 128 +++++++++++++++++++++++++++++++++ 6 files changed, 291 insertions(+), 10 deletions(-) create mode 100644 tests/test_ollama_client.py diff --git a/src/control/cpam_context.py b/src/control/cpam_context.py index b1d9033..2d744a3 100644 --- a/src/control/cpam_context.py +++ b/src/control/cpam_context.py @@ -511,6 +511,40 @@ def _build_cpam_prompt( # Définitions CIM-10 déterministes (tous les codes en jeu) definitions_str = _get_cim10_definitions(dossier, controle) + # Whitelist explicite des codes autorisés (anti-hallucination) + _all_codes: list[str] = [] + if dossier.diagnostic_principal and dossier.diagnostic_principal.cim10_suggestion: + _all_codes.append(dossier.diagnostic_principal.cim10_suggestion) + for das in dossier.diagnostics_associes: + if das.cim10_suggestion: + _all_codes.append(das.cim10_suggestion) + for field in (controle.dp_ucr, controle.da_ucr, controle.dr_ucr): + if field: + for raw in re.split(r"[,;\s]+", field.strip()): + raw = raw.strip() + if raw: + _all_codes.append(raw) + # Dédupliquer en normalisant + _seen_norm: set[str] = set() + _unique_codes: list[str] = [] + for c in _all_codes: + norm = normalize_code(c) + if norm and norm not in _seen_norm: + _seen_norm.add(norm) + is_valid, label = validate_code(norm) + _unique_codes.append(f"{norm} — {label}" if is_valid and label else norm) + if _unique_codes: + codes_autorises_str = ( + "\nPÉRIMÈTRE DE CODES AUTORISÉS (liste EXHAUSTIVE) :\n" + + "\n".join(f" {c}" for c in _unique_codes) + + "\n\nINTERDICTION : Ne mentionne AUCUN code CIM-10 qui ne figure pas " + "dans cette liste. Si un code supplémentaire te semble cliniquement " + "pertinent, signale-le en toutes lettres dans la conclusion SANS " + "citer le code CIM-10." + ) + else: + codes_autorises_str = "" + # Contexte clinique tagué pour le grounding tagged_context, tag_map = _build_tagged_context(dossier) if tagged_context: @@ -591,6 +625,7 @@ def _build_cpam_prompt( decision_ucr=controle.decision_ucr, codes_str=codes_str, definitions_str=definitions_str, + codes_autorises_str=codes_autorises_str, sources_text=sources_text, extraction_str=extraction_str, ) diff --git a/src/control/cpam_validation.py b/src/control/cpam_validation.py index 98ed781..8915efa 100644 --- a/src/control/cpam_validation.py +++ b/src/control/cpam_validation.py @@ -232,9 +232,9 @@ def _validate_adversarial( ) logger.debug(" Validation adversariale") - result = call_ollama(prompt, temperature=0.0, max_tokens=800, role="validation") + result = call_ollama(prompt, temperature=0.0, max_tokens=1500, role="validation") if result is None: - result = call_anthropic(prompt, temperature=0.0, max_tokens=800) + result = call_anthropic(prompt, temperature=0.0, max_tokens=1500) if result is None: logger.warning(" Validation adversariale échouée — LLM indisponible") return None diff --git a/src/medical/ollama_client.py b/src/medical/ollama_client.py index 5ed1e8c..f3d47de 100644 --- a/src/medical/ollama_client.py +++ b/src/medical/ollama_client.py @@ -5,6 +5,7 @@ from __future__ import annotations import json import logging import os +import time import requests @@ -60,22 +61,85 @@ def call_anthropic( return None +def _repair_truncated_json(text: str) -> dict | None: + """Tente de réparer un JSON tronqué (réponse LLM coupée par max_tokens). + + Stratégie : fermer les chaînes, tableaux et objets ouverts puis réessayer. + """ + # Étape 1 : détecter si on est dans une chaîne non fermée + in_string = False + escaped = False + for ch in text: + if escaped: + escaped = False + continue + if ch == "\\": + escaped = True + continue + if ch == '"': + in_string = not in_string + if in_string: + text += '"' + + # Étape 2 : compter les ouvreurs/fermeurs non appariés + in_str = False + esc = False + stack: list[str] = [] + for ch in text: + if esc: + esc = False + continue + if ch == "\\": + esc = True + continue + if ch == '"': + in_str = not in_str + continue + if in_str: + continue + if ch in ("{", "["): + stack.append(ch) + elif ch == "}" and stack and stack[-1] == "{": + stack.pop() + elif ch == "]" and stack and stack[-1] == "[": + stack.pop() + + # Fermer en ordre inverse + for opener in reversed(stack): + text += "}" if opener == "{" else "]" + + try: + return json.loads(text) + except json.JSONDecodeError: + return None + + def parse_json_response(raw: str) -> dict | None: - """Parse une réponse JSON, en gérant les blocs markdown.""" + """Parse une réponse JSON, en gérant les blocs markdown et le JSON tronqué.""" text = raw.strip() if text.startswith("```"): first_nl = text.find("\n") if first_nl != -1: text = text[first_nl + 1:] - if text.rstrip().endswith("```"): - text = text.rstrip()[:-3] + # Trouver la fermeture ``` (peut être suivie de texte superflu du LLM) + closing_idx = text.find("```") + if closing_idx != -1: + text = text[:closing_idx] text = text.strip() try: return json.loads(text) except json.JSONDecodeError: - logger.warning("LLM : JSON invalide : %s", raw[:200]) - return None + pass + + # Tentative de réparation (JSON tronqué par max_tokens) + repaired = _repair_truncated_json(text) + if repaired is not None: + logger.info("LLM : JSON tronqué réparé (%d chars)", len(text)) + return repaired + + logger.warning("LLM : JSON invalide : %s", raw[:200]) + return None def call_ollama( @@ -101,7 +165,7 @@ def call_ollama( """ use_model = model or (get_model(role) if role else OLLAMA_MODEL) use_timeout = timeout or OLLAMA_TIMEOUT - for attempt in range(2): + for attempt in range(3): try: response = requests.post( f"{OLLAMA_URL}/api/generate", @@ -117,12 +181,19 @@ def call_ollama( }, timeout=use_timeout, ) + # 429 rate limit → retry avec backoff exponentiel + if response.status_code == 429: + delay = 2 ** attempt # 1s, 2s, 4s + logger.warning("Ollama 429 (rate limit) — retry dans %ds (tentative %d/3)", + delay, attempt + 1) + time.sleep(delay) + continue response.raise_for_status() raw = response.json().get("response", "") result = parse_json_response(raw) if result is not None: return result - if attempt == 0: + if attempt < 2: logger.info("Ollama (%s) : retry après échec de parsing", use_model) except requests.ConnectionError: logger.info("Ollama indisponible → fallback Anthropic (%s)", _ANTHROPIC_MODEL) diff --git a/src/prompts/templates.py b/src/prompts/templates.py index 3725939..4760443 100644 --- a/src/prompts/templates.py +++ b/src/prompts/templates.py @@ -14,7 +14,7 @@ Variables par template : decision_ucr, dp_ucr_line, da_ucr_line CPAM_ARGUMENTATION : dossier_str, asymetrie_str, tagged_str, titre, arg_ucr, decision_ucr, codes_str, definitions_str, - sources_text, extraction_str + codes_autorises_str, sources_text, extraction_str CPAM_ADVERSARIAL : response_json, factual_section, normes_section, dp_ucr_line, da_ucr_line """ @@ -247,6 +247,7 @@ DÉCISION UCR : {decision_ucr} CODES CONTESTÉS : {codes_str} {definitions_str} +{codes_autorises_str} SOURCES RÉGLEMENTAIRES (Guide méthodologique, CIM-10) : {sources_text} diff --git a/tests/test_cpam_response.py b/tests/test_cpam_response.py index 217b469..0fd9ddb 100644 --- a/tests/test_cpam_response.py +++ b/tests/test_cpam_response.py @@ -1950,3 +1950,49 @@ class TestCheckDasBioCoherenceExtended: ) warnings = _check_das_bio_coherence(dossier) assert len(warnings) >= 1 + + +class TestCodesAutorisesWhitelist: + """Tests pour la whitelist de codes autorisés (anti-hallucination).""" + + def test_whitelist_in_prompt(self): + """Le prompt contient la section PÉRIMÈTRE DE CODES AUTORISÉS.""" + dossier = _make_dossier() # DP K81.0, DAS K56.0 + controle = _make_controle() # dp_ucr=K80.1, da_ucr=K56.0 + prompt, _ = _build_cpam_prompt(dossier, controle, []) + assert "PÉRIMÈTRE DE CODES AUTORISÉS" in prompt + assert "INTERDICTION" in prompt + + def test_whitelist_contains_dossier_codes(self): + """Tous les codes du dossier sont dans la whitelist.""" + dossier = _make_dossier() # DP K81.0, DAS K56.0 + controle = _make_controle() + prompt, _ = _build_cpam_prompt(dossier, controle, []) + assert "K81.0" in prompt + assert "K56.0" in prompt + + def test_whitelist_contains_ucr_codes(self): + """Tous les codes UCR sont dans la whitelist.""" + dossier = _make_dossier() + controle = _make_controle() + controle.dp_ucr = "K80.1" + prompt, _ = _build_cpam_prompt(dossier, controle, []) + assert "K80.1" in prompt + + def test_whitelist_dedup(self): + """Les codes en double (dossier + UCR) ne sont listés qu'une fois.""" + dossier = _make_dossier() # K56.0 en DAS + controle = _make_controle() # da_ucr=K56.0 + prompt, _ = _build_cpam_prompt(dossier, controle, []) + # K56.0 apparaît dans PÉRIMÈTRE mais une seule fois dans cette section + perimetre_idx = prompt.index("PÉRIMÈTRE DE CODES AUTORISÉS") + interdit_idx = prompt.index("INTERDICTION") + perimetre_section = prompt[perimetre_idx:interdit_idx] + assert perimetre_section.count("K56.0") == 1 + + def test_whitelist_prohibition_message(self): + """Le message d'interdiction est clair et complet.""" + dossier = _make_dossier() + controle = _make_controle() + prompt, _ = _build_cpam_prompt(dossier, controle, []) + assert "Ne mentionne AUCUN code CIM-10 qui ne figure pas" in prompt diff --git a/tests/test_ollama_client.py b/tests/test_ollama_client.py new file mode 100644 index 0000000..9e0b88c --- /dev/null +++ b/tests/test_ollama_client.py @@ -0,0 +1,128 @@ +"""Tests unitaires pour le client Ollama (parsing JSON, réparation tronqué).""" + +import pytest + +from src.medical.ollama_client import parse_json_response, _repair_truncated_json + + +class TestParseJsonResponse: + """Tests de parse_json_response().""" + + def test_valid_json(self): + result = parse_json_response('{"key": "value"}') + assert result == {"key": "value"} + + def test_fenced_json(self): + raw = '```json\n{"key": "value"}\n```' + assert parse_json_response(raw) == {"key": "value"} + + def test_fenced_no_closing(self): + raw = '```json\n{"key": "value"}' + assert parse_json_response(raw) == {"key": "value"} + + def test_whitespace(self): + assert parse_json_response(' \n {"a": 1} \n') == {"a": 1} + + def test_invalid_json_returns_none(self): + assert parse_json_response("pas du json") is None + + def test_fenced_with_trailing_text(self): + """JSON fencé suivi de texte superflu du LLM après la fermeture.""" + raw = '```json\n{"coherent": true, "erreurs": [], "score_confiance": 9}\n```\n\n**Justification de la vérification :**\n1. OK' + result = parse_json_response(raw) + assert result is not None + assert result["coherent"] is True + assert result["score_confiance"] == 9 + + def test_empty_string(self): + assert parse_json_response("") is None + + +class TestRepairTruncatedJson: + """Tests de _repair_truncated_json() — réparation JSON tronqué par max_tokens.""" + + def test_truncated_object(self): + """Objet principal non fermé.""" + text = '{"coherent": false, "erreurs": ["erreur 1"]' + result = _repair_truncated_json(text) + assert result is not None + assert result["coherent"] is False + assert result["erreurs"] == ["erreur 1"] + + def test_truncated_array_and_object(self): + """Array et objet non fermés.""" + text = '{"coherent": false, "erreurs": ["erreur 1", "erreur 2"' + result = _repair_truncated_json(text) + assert result is not None + assert result["coherent"] is False + assert len(result["erreurs"]) == 2 + + def test_truncated_string_in_array(self): + """Chaîne tronquée à l'intérieur d'un array.""" + text = '{"coherent": false, "erreurs": ["erreur longue qui se term' + result = _repair_truncated_json(text) + assert result is not None + assert result["coherent"] is False + assert len(result["erreurs"]) == 1 + assert "erreur longue" in result["erreurs"][0] + + def test_deeply_nested_truncation(self): + """Troncation dans un objet imbriqué.""" + text = '{"data": {"inner": [1, 2' + result = _repair_truncated_json(text) + assert result is not None + assert result["data"]["inner"] == [1, 2] + + def test_valid_json_passthrough(self): + """JSON déjà valide → retourné tel quel.""" + text = '{"a": 1}' + result = _repair_truncated_json(text) + assert result == {"a": 1} + + def test_complete_adversarial_format(self): + """Format exact de la validation adversariale.""" + text = '{"coherent": false, "erreurs": ["Incohérence bio CRP"], "score_confiance": 4}' + result = _repair_truncated_json(text) + assert result is not None + assert result["score_confiance"] == 4 + + def test_adversarial_truncated_at_score(self): + """Troncation juste avant score_confiance.""" + text = '{"coherent": false, "erreurs": ["Incohérence bio"]' + result = _repair_truncated_json(text) + assert result is not None + assert result["coherent"] is False + # score_confiance absent → -1 par défaut dans le code appelant + + def test_hopelessly_broken(self): + """Texte vraiment non réparable.""" + assert _repair_truncated_json("juste du texte libre") is None + + def test_escaped_quotes(self): + """Chaînes avec des guillemets échappés.""" + text = '{"msg": "il a dit \\"bonjour\\""}' + result = _repair_truncated_json(text) + assert result is not None + assert "bonjour" in result["msg"] + + def test_truncated_after_escaped_quote(self): + """Troncation après un guillemet échappé dans une chaîne.""" + text = '{"msg": "valeur avec \\"guillemet' + result = _repair_truncated_json(text) + assert result is not None + + def test_parse_json_uses_repair(self): + """parse_json_response() utilise la réparation en fallback.""" + # JSON tronqué (objet non fermé) + raw = '{"coherent": true, "erreurs": [], "score_confiance": 8' + result = parse_json_response(raw) + assert result is not None + assert result["coherent"] is True + assert result["score_confiance"] == 8 + + def test_parse_json_repair_fenced_truncated(self): + """JSON fencé ET tronqué.""" + raw = '```json\n{"coherent": false, "erreurs": ["erreur"' + result = parse_json_response(raw) + assert result is not None + assert result["coherent"] is False