def fallback_parent_code(code: str) -> str | None:
    """Try to repair an invalid code by falling back to its parent code.

    The LLM frequently hallucinates sub-codes (.8, .9) on standalone
    three-character codes (e.g. D71.9 -> D71, R69.8 -> R69).

    Args:
        code: The code to repair. It is passed through ``normalize_code``
            before being split on the dot.

    Returns:
        The part of the normalized code that precedes the first dot when
        ``validate_code`` accepts it; None when the normalized code has no
        dot or when that parent part is not valid.
    """
    parent, dot, _subcode = normalize_code(code).partition(".")
    if not dot:
        # No sub-code part, so there is no parent to climb up to.
        return None
    accepted, _ = validate_code(parent)
    return parent if accepted else None
Diagnostic(texte=matched.capitalize(), cim10_suggestion=code, source="regex") return None @@ -444,7 +448,7 @@ def _find_diagnostics_associes( # Patterns DAS for pat, label, code in _DAS_PATTERNS: if re.search(pat, text_norm) and code not in existing_codes: - das.append(Diagnostic(texte=label, cim10_suggestion=code)) + das.append(Diagnostic(texte=label, cim10_suggestion=code, source="regex")) existing_codes.add(code) # Obésité (IMC >= 30) — pattern spécial avec extraction de valeur @@ -452,7 +456,7 @@ def _find_diagnostics_associes( if m: imc_val = float(m.group(1).replace(",", ".")) if imc_val >= 30 and "E66.0" not in existing_codes: - das.append(Diagnostic(texte=f"Obésité (IMC {imc_val})", cim10_suggestion="E66.0")) + das.append(Diagnostic(texte=f"Obésité (IMC {imc_val})", cim10_suggestion="E66.0", source="regex")) existing_codes.add("E66.0") return das diff --git a/src/medical/das_filter.py b/src/medical/das_filter.py index 5023bc1..af61e4b 100644 --- a/src/medical/das_filter.py +++ b/src/medical/das_filter.py @@ -83,6 +83,20 @@ def is_valid_diagnostic_text(text: str) -> bool: if re.match(r'^(Dans |La |Le |Les |Au |Aux )', t) and len(t) < 30: return False + # 12. 
En-têtes de systèmes anatomiques (catégories sans pathologie) + _ANATOMICAL_HEADERS = { + "musculaire", "squelettique", "cardiovasculaire", "pulmonaire", + "neurologique", "digestif", "digestive", "hépatique", "rénal", + "rénale", "urinaire", "cutané", "cutanée", "articulaire", + "osseux", "osseuse", "gastrique", "intestinal", "intestinale", + "cérébral", "thoracique", "abdominal", "abdominale", + } + if len(words) == 1 and t.lower() in _ANATOMICAL_HEADERS: + return False + # Catégorie + description vague : "Musculaire - masse musculaire" + if re.match(r'^[A-ZÀ-Ú][a-zà-ÿ]+ - (masse|zone|région|état|bilan)', t, re.IGNORECASE): + return False + return True diff --git a/src/medical/rag_search.py b/src/medical/rag_search.py index 1a86b01..ff9cc61 100644 --- a/src/medical/rag_search.py +++ b/src/medical/rag_search.py @@ -10,7 +10,7 @@ from ..config import ( OLLAMA_CACHE_PATH, OLLAMA_MAX_PARALLEL, OLLAMA_MODEL, EMBEDDING_MODEL, RERANKER_MODEL, ) -from .cim10_dict import normalize_code, validate_code as cim10_validate +from .cim10_dict import normalize_code, validate_code as cim10_validate, fallback_parent_code from .cim10_extractor import BIO_NORMALS from .ccam_dict import validate_code as ccam_validate from .ollama_client import call_ollama, parse_json_response @@ -478,10 +478,19 @@ def _apply_llm_result_diagnostic(diagnostic: Diagnostic, llm_result: dict) -> No if is_valid: diagnostic.cim10_suggestion = code else: - logger.warning( - "RAG : code Ollama %s invalide pour « %s », code ignoré", - code, diagnostic.texte, - ) + # Tenter fallback vers le code parent (D71.9 → D71) + parent = fallback_parent_code(code) + if parent: + logger.info( + "RAG : code Ollama %s invalide → fallback parent %s pour « %s »", + code, parent, diagnostic.texte, + ) + diagnostic.cim10_suggestion = parent + else: + logger.warning( + "RAG : code Ollama %s invalide pour « %s », code ignoré", + code, diagnostic.texte, + ) if confidence in ("high", "medium", "low"): diagnostic.cim10_confidence = 
confidence if justification: diff --git a/tests/test_das_filter.py b/tests/test_das_filter.py index 009a1d1..9c0913b 100644 --- a/tests/test_das_filter.py +++ b/tests/test_das_filter.py @@ -161,6 +161,29 @@ class TestIsValidDiagnosticText: """Un fragment long commençant par 'Dans' peut être légitime.""" assert is_valid_diagnostic_text("Dans le cadre d'une insuffisance rénale chronique terminale") + # --- Règle 12 : en-têtes anatomiques / catégories vagues --- + def test_reject_musculaire_alone(self): + assert not is_valid_diagnostic_text("Musculaire") + + def test_reject_musculaire_masse(self): + assert not is_valid_diagnostic_text("Musculaire - masse musculaire") + + def test_reject_digestif_alone(self): + assert not is_valid_diagnostic_text("Digestif") + + def test_reject_hepatique_alone(self): + assert not is_valid_diagnostic_text("Hépatique") + + def test_reject_category_bilan(self): + assert not is_valid_diagnostic_text("Rénal - bilan rénal") + + def test_accept_real_diagnostic_musculaire(self): + """Un vrai diagnostic contenant 'musculaire' est accepté.""" + assert is_valid_diagnostic_text("Dystrophie musculaire de Duchenne") + + def test_accept_real_diagnostic_hepatique(self): + assert is_valid_diagnostic_text("Insuffisance hépatique aiguë") + class TestCorrectKnownMiscodes: """Tests pour la correction des codes systématiquement mal attribués.""" diff --git a/tests/test_rag.py b/tests/test_rag.py index 184ec6b..b4b70d9 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -504,7 +504,7 @@ class TestRAGSearchMocked: {"document": "cim10", "page": 496, "code": "K85", "extrait": "K85", "score": 0.9}, ] mock_llm = { - "code": "X99.99", # code invalide + "code": "QQ9.99", # code invalide (pas de parent valide) "confidence": "high", "justification": "Hallucination", } @@ -516,6 +516,26 @@ class TestRAGSearchMocked: # Le code original est conservé (pas remplacé par le code invalide) assert diag.cim10_suggestion == "K85.9" + def 
class TestFallbackParentCode:
    """Checks for ``fallback_parent_code`` (invalid sub-code -> parent code)."""

    def test_d71_9_to_d71(self):
        """D71.9 (invalid) falls back to D71 (valid standalone code)."""
        from src.medical.cim10_dict import fallback_parent_code

        repaired = fallback_parent_code("D71.9")
        assert repaired == "D71"

    def test_r69_8_to_r69(self):
        """R69.8 (invalid) falls back to R69 (valid standalone code)."""
        from src.medical.cim10_dict import fallback_parent_code

        repaired = fallback_parent_code("R69.8")
        assert repaired == "R69"

    def test_valid_code_no_fallback(self):
        """An already valid code still yields its parent when the parent is valid."""
        from src.medical.cim10_dict import fallback_parent_code

        # K85.1 is valid, so callers are not expected to request a fallback
        # for it; if they do, K85 is valid as well and is what comes back.
        parent = fallback_parent_code("K85.1")
        assert parent == "K85"  # the parent code is valid

    def test_truly_invalid_no_fallback(self):
        """A code whose parent is not valid either returns None."""
        from src.medical.cim10_dict import fallback_parent_code

        outcome = fallback_parent_code("QQ9.99")
        assert outcome is None

    def test_three_char_code_no_fallback(self):
        """A three-character code without a dot has no parent to fall back to."""
        from src.medical.cim10_dict import fallback_parent_code

        outcome = fallback_parent_code("QQ9")
        assert outcome is None