diff --git a/src/config.py b/src/config.py index 373d56b..bfbaa8e 100644 --- a/src/config.py +++ b/src/config.py @@ -38,6 +38,7 @@ OLLAMA_TIMEOUT = 120 # --- Configuration RAG --- RAG_INDEX_DIR = BASE_DIR / "data" / "rag_index" +CIM10_DICT_PATH = BASE_DIR / "data" / "cim10_dict.json" CIM10_PDF = Path("/home/dom/ai/aivanov_CIM/cim-10-fr_2026_a_usage_pmsi_version_provisoire_111225.pdf") GUIDE_METHODO_PDF = Path("/home/dom/ai/aivanov_CIM/guide_methodo_mco_2026_version_provisoire.pdf") CCAM_PDF = Path("/home/dom/ai/aivanov_CIM/actualisation_ccam_descriptive_a_usage_pmsi_v4_2025.pdf") @@ -71,6 +72,7 @@ class Diagnostic(BaseModel): cim10_suggestion: Optional[str] = None cim10_confidence: Optional[str] = None justification: Optional[str] = None + raisonnement: Optional[str] = None sources_rag: list[RAGSource] = Field(default_factory=list) diff --git a/src/extraction/document_classifier.py b/src/extraction/document_classifier.py index c0de561..7616188 100644 --- a/src/extraction/document_classifier.py +++ b/src/extraction/document_classifier.py @@ -2,44 +2,93 @@ from __future__ import annotations +from dataclasses import dataclass + + +@dataclass +class ClassificationResult: + """Résultat de classification avec score de confiance.""" + doc_type: str + confidence: float + scores: dict[str, float] + + +# Marqueurs pondérés : (texte, poids) +_TRACKARE_MARKERS: list[tuple[str, int]] = [ + ("ipp:", 3), + ("episode no:", 3), + ("dossier patient", 2), + ("détails des patients", 2), + ("détails épisode", 2), + ("liste des contacts", 1), + ("notes paramédicales", 1), + ("signes vitaux", 1), + ("traitements médicamenteux", 1), + ("observations médicales", 1), + ("constantes", 1), + ("prescriptions", 1), + ("presc. de sortie", 2), + ("type de note", 1), +] + +_CRH_MARKERS: list[tuple[str, int]] = [ + ("mon cher confrère", 3), + ("cher confrère", 3), + ("chère consœur", 3), + ("compte rendu d'hospitalisation", 3), + ("compte-rendu", 2), + ("service de gastro", 2), + ("service de chirurgie", 2), + ("service de médecine", 2), + ("pôle spécialités", 1), + ("votre patient", 2), + ("votre patiente", 2), + ("au total", 1), + ("ttt de sortie", 1), + ("devenir", 1), + ("cordialement", 1), +] + +_SCAN_LENGTH = 5000 + + +def classify_with_confidence(text: str) -> ClassificationResult: + """Classifie un document avec un score de confiance. + + Retourne un ClassificationResult avec le type, la confiance (0.0-1.0), + et les scores détaillés. + """ + text_lower = text[:_SCAN_LENGTH].lower() + + trackare_score = sum(weight for marker, weight in _TRACKARE_MARKERS if marker in text_lower) + crh_score = sum(weight for marker, weight in _CRH_MARKERS if marker in text_lower) + + total = trackare_score + crh_score + if total == 0: + return ClassificationResult(doc_type="crh", confidence=0.5, scores={"trackare": 0, "crh": 0}) + + if trackare_score > crh_score: + confidence = trackare_score / total + doc_type = "trackare" + elif crh_score > trackare_score: + confidence = crh_score / total + doc_type = "crh" + else: + # Égalité — défaut CRH + confidence = 0.5 + doc_type = "crh" + + return ClassificationResult( + doc_type=doc_type, + confidence=round(confidence, 2), + scores={"trackare": trackare_score, "crh": crh_score}, + ) + def classify(text: str) -> str: """Classifie un document extrait en CRH ou Trackare. Retourne "crh" ou "trackare". + Signature inchangée pour rétrocompatibilité. """ - text_lower = text[:3000].lower() - - trackare_markers = [ - "dossier patient", - "détails des patients", - "détails épisode", - "liste des contacts", - "notes paramédicales", - "signes vitaux", - "traitements médicamenteux", - "observations médicales", - ] - trackare_score = sum(1 for m in trackare_markers if m in text_lower) - - crh_markers = [ - "mon cher confrère", - "cher confrère", - "compte rendu d'hospitalisation", - "compte-rendu", - "service de gastro", - "pôle spécialités", - "votre patient", - ] - crh_score = sum(1 for m in crh_markers if m in text_lower) - - if trackare_score >= 2: - return "trackare" - if crh_score >= 2: - return "crh" - - # Heuristique : Trackare contient des tableaux avec IPP - if "ipp:" in text_lower or "episode no:" in text_lower: - return "trackare" - - return "crh" + return classify_with_confidence(text).doc_type diff --git a/src/main.py b/src/main.py index de5d349..21eb82e 100644 --- a/src/main.py +++ b/src/main.py @@ -163,8 +163,18 @@ def main(input_path: str | None = None) -> None: action="store_true", help="Désactiver l'enrichissement RAG (FAISS + Ollama)", ) + parser.add_argument( + "--build-dict", + action="store_true", + help="Générer le dictionnaire CIM-10 depuis metadata.json et quitter", + ) args = parser.parse_args() + if args.build_dict: + from .medical.cim10_dict import build_dict + build_dict() + return + if args.no_ner: # Monkey-patch pour désactiver NER from .anonymization import ner_anonymizer diff --git a/src/medical/cim10_dict.py b/src/medical/cim10_dict.py new file mode 100644 index 0000000..e6ad4b2 --- /dev/null +++ b/src/medical/cim10_dict.py @@ -0,0 +1,180 @@ +"""Dictionnaire CIM-10 complet extrait depuis les métadonnées FAISS. + +Fournit un lookup intelligent avec normalisation Unicode pour la recherche +de codes CIM-10 à partir de textes médicaux en français. +""" + +from __future__ import annotations + +import json +import logging +import re +import unicodedata +from pathlib import Path +from typing import Optional + +from ..config import CIM10_DICT_PATH, RAG_INDEX_DIR + +logger = logging.getLogger(__name__) + +# Singleton : dictionnaire chargé une seule fois +_dict_cache: dict[str, str] | None = None +# Cache des labels normalisés pour le substring matching +_normalized_cache: list[tuple[str, str, str]] | None = None + + +def normalize_text(text: str) -> str: + """Normalise un texte : accent folding, lowercase, collapse whitespace. + + Utilise unicodedata pour supprimer les accents (NFD → suppression des + combining marks), puis met en minuscules et collapse les espaces multiples. + """ + # Normaliser les apostrophes Unicode → ASCII + text = text.replace("\u2019", "'").replace("\u2018", "'").replace("\u02BC", "'") + # NFD decomposition puis suppression des combining marks (accents) + nfkd = unicodedata.normalize("NFKD", text) + stripped = "".join(c for c in nfkd if unicodedata.category(c) != "Mn") + # Lowercase + collapse whitespace + return re.sub(r"\s+", " ", stripped.lower()).strip() + + +def build_dict() -> dict[str, str]: + """Construit le dictionnaire CIM-10 depuis metadata.json et l'écrit dans data/cim10_dict.json. + + Extrait le code et le label (première ligne de l'extrait, sans le préfixe code) + depuis chaque entrée CIM-10 du metadata.json existant. + + Returns: + Le dictionnaire code → label. + """ + metadata_path = RAG_INDEX_DIR / "metadata.json" + if not metadata_path.exists(): + logger.error("metadata.json non trouvé : %s", metadata_path) + return {} + + with open(metadata_path, encoding="utf-8") as f: + metadata = json.load(f) + + result: dict[str, str] = {} + for entry in metadata: + if entry.get("document") != "cim10": + continue + code = entry.get("code") + extrait = entry.get("extrait", "") + if not code or not extrait: + continue + + # Extraire le label : première ligne, sans le préfixe "CODE " + first_line = extrait.split("\n")[0].strip() + # Retirer le préfixe code (ex: "K85.1 Pancréatite aigüe...") + prefix = f"{code} " + if first_line.startswith(prefix): + label = first_line[len(prefix):] + else: + label = first_line + + # Garder l'entrée la plus spécifique (avec point > sans point) + if code not in result or not label: + result[code] = label + + # Écrire le fichier JSON + CIM10_DICT_PATH.parent.mkdir(parents=True, exist_ok=True) + with open(CIM10_DICT_PATH, "w", encoding="utf-8") as f: + json.dump(result, f, ensure_ascii=False, indent=2) + + logger.info("Dictionnaire CIM-10 généré : %d codes → %s", len(result), CIM10_DICT_PATH) + return result + + +def load_dict() -> dict[str, str]: + """Charge le dictionnaire CIM-10 (singleton lazy-loaded). + + Si le fichier JSON n'existe pas, tente de le construire depuis metadata.json. + """ + global _dict_cache + if _dict_cache is not None: + return _dict_cache + + if CIM10_DICT_PATH.exists(): + with open(CIM10_DICT_PATH, encoding="utf-8") as f: + _dict_cache = json.load(f) + else: + logger.info("Dictionnaire CIM-10 absent, construction depuis metadata.json...") + _dict_cache = build_dict() + + return _dict_cache + + +def _get_normalized_entries() -> list[tuple[str, str, str]]: + """Retourne une liste de (code, label_original, label_normalisé) triée par spécificité. + + Les codes avec point (sous-codes, plus spécifiques) sont en premier. + """ + global _normalized_cache + if _normalized_cache is not None: + return _normalized_cache + + d = load_dict() + entries = [] + for code, label in d.items(): + norm = normalize_text(label) + entries.append((code, label, norm)) + + # Trier : sous-codes (avec point) d'abord, puis par longueur de label décroissante + # pour préférer les matchs les plus spécifiques + entries.sort(key=lambda e: (0 if "." in e[0] else 1, -len(e[2]))) + _normalized_cache = entries + return _normalized_cache + + +def lookup( + text: str, + domain_overrides: dict[str, str] | None = None, +) -> str | None: + """Recherche un code CIM-10 pour un texte donné. + + Stratégie en 3 niveaux : + 1. Match substring dans domain_overrides (prioritaire, ex: CIM10_MAP existant) + 2. Match exact normalisé dans le dictionnaire complet + 3. Match substring normalisé avec scoring par spécificité (préfère sous-codes) + + Args: + text: Le texte médical à rechercher. + domain_overrides: Dictionnaire terme→code prioritaire (ex: CIM10_MAP). + + Returns: + Le code CIM-10 trouvé ou None. + """ + if not text: + return None + + text_norm = normalize_text(text) + + # Niveau 1 : domain overrides (substring match) + if domain_overrides: + for terme, code in domain_overrides.items(): + if normalize_text(terme) in text_norm: + return code + + # Niveau 2 : match exact normalisé dans le dictionnaire complet + d = load_dict() + for code, label in d.items(): + if normalize_text(label) == text_norm: + return code + + # Niveau 3 : substring match normalisé (plus spécifique d'abord) + entries = _get_normalized_entries() + for code, _label, norm_label in entries: + if not norm_label or len(norm_label) < 4: + continue + if norm_label in text_norm: + return code + + return None + + +def reset_cache() -> None: + """Réinitialise les caches (utile pour les tests).""" + global _dict_cache, _normalized_cache + _dict_cache = None + _normalized_cache = None diff --git a/src/medical/cim10_extractor.py b/src/medical/cim10_extractor.py index 1b85391..6478e82 100644 --- a/src/medical/cim10_extractor.py +++ b/src/medical/cim10_extractor.py @@ -9,6 +9,7 @@ from typing import Optional logger = logging.getLogger(__name__) +from .cim10_dict import lookup as dict_lookup, normalize_text from ..config import ( ActeCCAM, BiologieCle, @@ -253,33 +254,77 @@ def _extract_diagnostics( def _find_diagnostic_principal(text_lower: str, conclusion: str) -> Diagnostic | None: - """Trouve le diagnostic principal dans le texte.""" - conclusion_lower = conclusion.lower() + """Trouve le diagnostic principal dans le texte. - # Chercher dans la conclusion d'abord + Normalise le texte avant matching pour gérer les variations d'accents/casse. + """ + conclusion_norm = normalize_text(conclusion) + + # Chercher dans la conclusion d'abord via CIM10_MAP (domain override) for terme, code in CIM10_MAP.items(): - if terme in conclusion_lower: + if normalize_text(terme) in conclusion_norm: return Diagnostic(texte=terme.capitalize(), cim10_suggestion=code) - # Patterns courants pour le DP + text_norm = normalize_text(text_lower) + + # Patterns courants pour le DP (normalisés, sans accents) dp_patterns = [ - r"pancréatite\s+aigu[eë]\s+(?:d'origine\s+)?lithiasique", - r"pancréatite\s+aigu[eë]\s+biliaire", - r"pancréatite\s+aigu[eë]", + r"pancreatite\s+aigue\s+(?:d'origine\s+)?lithiasique", + r"pancreatite\s+aigue\s+biliaire", + r"pancreatite\s+aigue", ] for pat in dp_patterns: - if re.search(pat, text_lower): - matched = re.search(pat, text_lower).group(0) + m = re.search(pat, text_norm) + if m: + matched = m.group(0) code = _lookup_cim10(matched) return Diagnostic(texte=matched.capitalize(), cim10_suggestion=code) return None +# Patterns DAS : (pattern_normalisé, label, code_fallback) +# Les patterns sont appliqués sur du texte normalisé (sans accents, lowercase) +_DAS_PATTERNS: list[tuple[str, str, str]] = [ + # Lithiases biliaires + (r"lithiase\s+(?:du\s+)?(?:bas\s+)?choledoque", "Lithiase du cholédoque", "K80.5"), + (r"vesicule\s+lithiasique|lithiases?\s+vesiculaire", "Lithiase vésiculaire", "K80.2"), + # Inflammation biliaire + (r"cholecystite\s+aigue", "Cholécystite aiguë", "K81.0"), + (r"angiocholite|cholangite", "Angiocholite", "K83.0"), + # Réactions médicamenteuses + (r"eruption\s+cutanee|toxidermie|reaction\s+au\s+tramadol", "Éruption cutanée médicamenteuse", "L27.0"), + # Cardiovasculaire + (r"hypertension\s+arterielle|\bhta\b", "Hypertension artérielle", "I10"), + (r"fibrillation\s+auriculaire|\bfa\b(?:\s+paroxystique)?|\bacfa\b", "Fibrillation auriculaire", "I48.9"), + (r"embolie\s+pulmonaire", "Embolie pulmonaire", "I26.9"), + (r"thrombose\s+veineuse\s+profonde|\btvp\b", "Thrombose veineuse profonde", "I80.2"), + # Métabolique + (r"diabete\s+(?:sucre\s+)?(?:de\s+)?type\s+2|diabete\s+type\s*2", "Diabète de type 2", "E11.9"), + (r"diabete\s+(?:sucre\s+)?(?:de\s+)?type\s+1|diabete\s+type\s*1", "Diabète de type 1", "E10.9"), + (r"dyslipidemie|hypercholesterolemie", "Dyslipidémie", "E78.5"), + (r"denutrition|malnutrition", "Dénutrition", "E46"), + # Infectieux + (r"pneumopathie|pneumonie", "Pneumopathie", "J18.9"), + (r"infection\s+urinaire|pyelonephrite", "Infection urinaire", "N39.0"), + (r"\bsepsis\b|septicemie|choc\s+septique", "Sepsis", "A41.9"), + # Rénal + (r"insuffisance\s+renale", "Insuffisance rénale", "N19"), + # Hématologique + (r"anemie", "Anémie", "D64.9"), + # Addictions + (r"tabagisme|tabac\s+actif", "Tabagisme", "F17.2"), + (r"ethylisme|alcoolisme|intoxication\s+ethylique", "Éthylisme", "F10.1"), +] + + def _find_diagnostics_associes( text_lower: str, conclusion: str, dossier: DossierMedical ) -> list[Diagnostic]: - """Trouve les diagnostics associés.""" + """Trouve les diagnostics associés. + + Utilise des patterns normalisés (sans accents) pour une détection robuste. + """ das: list[Diagnostic] = [] existing_codes = set() if dossier.diagnostic_principal: @@ -287,32 +332,21 @@ def _find_diagnostics_associes( for d in dossier.diagnostics_associes: existing_codes.add(d.cim10_suggestion) - # Lithiase cholédoque - if re.search(r"lithiase\s+(?:du\s+)?(?:bas\s+)?cholédoque", text_lower): - if "K80.5" not in existing_codes: - das.append(Diagnostic(texte="Lithiase du cholédoque", cim10_suggestion="K80.5")) - existing_codes.add("K80.5") + text_norm = normalize_text(text_lower) - # Éruption médicamenteuse - if re.search(r"éruption\s+cutanée|eruption\s+cutanée|toxidermie|réaction\s+au\s+tramadol", text_lower): - if "L27.0" not in existing_codes: - das.append(Diagnostic(texte="Éruption cutanée médicamenteuse", cim10_suggestion="L27.0")) - existing_codes.add("L27.0") + # Patterns DAS + for pat, label, code in _DAS_PATTERNS: + if re.search(pat, text_norm) and code not in existing_codes: + das.append(Diagnostic(texte=label, cim10_suggestion=code)) + existing_codes.add(code) - # Obésité (IMC >= 30) - if re.search(r"imc\s*[:=]?\s*(\d{2,3}[.,]\d+)", text_lower): - m = re.search(r"imc\s*[:=]?\s*(\d{2,3}[.,]\d+)", text_lower) - if m: - imc_val = float(m.group(1).replace(",", ".")) - if imc_val >= 30 and "E66.0" not in existing_codes: - das.append(Diagnostic(texte=f"Obésité (IMC {imc_val})", cim10_suggestion="E66.0")) - existing_codes.add("E66.0") - - # Lithiases vésiculaires - if re.search(r"vésicule\s+lithiasique|lithiases?\s+vésiculaire", text_lower): - if "K80.2" not in existing_codes: - das.append(Diagnostic(texte="Lithiase vésiculaire", cim10_suggestion="K80.2")) - existing_codes.add("K80.2") + # Obésité (IMC >= 30) — pattern spécial avec extraction de valeur + m = re.search(r"imc\s*[:=]?\s*(\d{2,3}[.,]\d+)", text_norm) + if m: + imc_val = float(m.group(1).replace(",", ".")) + if imc_val >= 30 and "E66.0" not in existing_codes: + das.append(Diagnostic(texte=f"Obésité (IMC {imc_val})", cim10_suggestion="E66.0")) + existing_codes.add("E66.0") return das @@ -399,7 +433,7 @@ def _extract_traitements( if not drug.negation and drug.code_atc: drug_atc[drug.texte.lower()] = drug.code_atc - # Depuis le texte — section "TTT de sortie" (limiter à quelques lignes) + # Depuis le texte — section "TTT de sortie" (sans limite de lignes) m = re.search( r"(?:TTT|Traitement)\s+de\s+sortie\s*[::]?\s*\n?(.*?)(?=\n\s*(?:Devenir|Rédigé|Cordialement|Patient:|Episode|Le \d{2}/\d{2}|\n\n)|$)", text, @@ -408,17 +442,28 @@ def _extract_traitements( if m: block = m.group(1).strip() lines = block.split("\n") - for line in lines[:10]: # Limiter à 10 lignes max + for line in lines: line = line.strip().lstrip("- •") if not line or len(line) <= 2: continue - # Ignorer les footers et lignes non-médicament - if re.match(r"^(Patient|Episode|Le \d|Page|V\d)", line): + # Conditions d'arrêt : footers, signatures, metadata + if re.match( + r"^(Patient|Episode|Le \d|Page\s+\d|V\d|Rédigé|Cordialement|Dr\s|Docteur|Signature|Date|Fait\s+le)", + line, + re.IGNORECASE, + ): break med = line poso = None - # Séparer médicament et posologie - poso_match = re.search(r"\s+(si besoin|matin|soir|midi|\d+\s*(?:mg|cp|gel).*)", line, re.IGNORECASE) + # Séparer médicament et posologie (pattern élargi) + poso_match = re.search( + r"\s+(si besoin|matin|soir|midi|" + r"\d+\s*(?:mg|cp|gel|sachet|comprim[ée]|g[ée]lule).*|" + r"\d+\s*(?:x|fois)\s*/?\s*(?:j(?:our)?|semaine)|" + r"pendant\s+\d+\s*jours?)", + line, + re.IGNORECASE, + ) if poso_match: med = line[:poso_match.start()].strip() poso = poso_match.group(1).strip() @@ -460,16 +505,24 @@ def _match_drug_atc(med_name: str, drug_atc: dict[str, str]) -> Optional[str]: def _extract_biologie(text: str, dossier: DossierMedical) -> None: - """Extrait les résultats biologiques clés.""" + """Extrait les résultats biologiques clés. + + Supporte les aliases (TGO/TGP, Hb), variantes d'unités (UI/L, µmol/L, g/dL), + et des tests additionnels (hémoglobine, plaquettes, leucocytes, créatinine). + """ bio_patterns = [ - (r"[Ll]ipas[ée]mie\s*(?:[àa=:])?\s*(\d+)", "Lipasémie", None), - (r"CRP\s*[=:à]?\s*(\d+(?:[.,]\d+)?)", "CRP", None), - (r"ASAT\s*[=:à]?\s*([\d.,]+)\s*(?:N|U/L)?", "ASAT", None), - (r"ALAT\s*[=:à]?\s*([\d.,]+)\s*(?:N|U/L)?", "ALAT", None), - (r"GGT\s*[=:à]?\s*(\d+)\s*(?:U/L)?", "GGT", None), - (r"PAL\s*[=:à]?\s*(\d+)\s*(?:U/L)?", "PAL", None), - (r"[Bb]ilirubine\s+(?:totale\s+)?[àa=:]\s*(\d+)\s*(?:µmol/L)?", "Bilirubine totale", None), - (r"troponine\s+(négative|positive|normale)", "Troponine", None), + (r"[Ll]ipas[ée]mie\s*(?:[àa=:])?\s*(\d+)\s*(?:UI/L|U/L)?", "Lipasémie", None), + (r"CRP\s*[=:àa]?\s*(\d+(?:[.,]\d+)?)\s*(?:mg/[Ll])?", "CRP", None), + (r"(?:ASAT|TGO)\s*[=:àa]?\s*([\d.,]+)\s*(?:N|U(?:I)?/L)?", "ASAT", None), + (r"(?:ALAT|TGP)\s*[=:àa]?\s*([\d.,]+)\s*(?:N|U(?:I)?/L)?", "ALAT", None), + (r"GGT\s*[=:àa]?\s*(\d+)\s*(?:U(?:I)?/L)?", "GGT", None), + (r"PAL\s*[=:àa]?\s*(\d+)\s*(?:U(?:I)?/L)?", "PAL", None), + (r"[Bb]ilirubine\s+(?:totale\s+)?[àa=:]\s*(\d+(?:[.,]\d+)?)\s*(?:µmol/L|mg/dL)?", "Bilirubine totale", None), + (r"[Tt]roponine\s+(?:us\s+)?(n[ée]gative|positive|normale)", "Troponine", None), + (r"(?:[Hh][ée]moglobine|Hb)\s*[=:àa]?\s*(\d+(?:[.,]\d+)?)\s*(?:g/dL|g/L)?", "Hémoglobine", None), + (r"[Pp]laquettes?\s*[=:àa]?\s*(\d+(?:\s*000)?)\s*(?:/mm3|G/L)?", "Plaquettes", None), + (r"[Ll]eucocytes?\s*[=:àa]?\s*(\d+(?:\s*000)?)\s*(?:/mm3|G/L)?", "Leucocytes", None), + (r"[Cc]r[ée]atinine?\s*[=:àa]?\s*(\d+(?:[.,]\d+)?)\s*(?:µmol/L|mg/dL)?", "Créatinine", None), ] for pattern, test_name, _ in bio_patterns: @@ -589,12 +642,11 @@ def _find_act_date(text: str, act_pattern: str) -> str | None: def _lookup_cim10(text: str) -> str | None: - """Cherche un code CIM-10 pour un texte donné.""" - text_lower = text.lower().strip() - for terme, code in CIM10_MAP.items(): - if terme in text_lower: - return code - return None + """Cherche un code CIM-10 pour un texte donné. + + Utilise le dictionnaire complet (10 893 codes) avec CIM10_MAP en override prioritaire. + """ + return dict_lookup(text, domain_overrides=CIM10_MAP) def _is_abnormal(test: str, value: str) -> bool | None: @@ -616,6 +668,10 @@ def _is_abnormal(test: str, value: str) -> bool | None: "GGT": (0, 60), "PAL": (0, 150), "Bilirubine totale": (0, 17), + "Hémoglobine": (12, 17), + "Plaquettes": (150, 400), + "Leucocytes": (4, 10), + "Créatinine": (50, 120), } if test in normals: diff --git a/tests/test_medical.py b/tests/test_medical.py index 2d89be1..201b4b7 100644 --- a/tests/test_medical.py +++ b/tests/test_medical.py @@ -8,6 +8,8 @@ from src.medical.cim10_extractor import ( _lookup_cim10, _is_abnormal, ) +from src.medical.cim10_dict import normalize_text, load_dict, lookup, reset_cache +from src.extraction.document_classifier import classify, classify_with_confidence class TestCIM10Lookup: @@ -236,3 +238,221 @@ Devenir : sortie le 03/03.""" complication_terms = [c.lower() for c in dossier.complications] assert "fièvre" not in complication_terms assert "infection" not in complication_terms + + +# === Nouveaux tests : dictionnaire CIM-10, normalisation, robustesse === + + +class TestCIM10Dict: + """Tests pour le chargement du dictionnaire CIM-10 complet.""" + + def test_load_dict_not_empty(self): + d = load_dict() + assert len(d) > 10000 + + def test_known_codes_present(self): + d = load_dict() + assert "K85.1" in d + assert "K80.5" in d + assert "I10" in d + assert "E66.0" in d + assert "L27.0" in d + + def test_labels_non_empty(self): + d = load_dict() + for code, label in list(d.items())[:100]: + assert label, f"Label vide pour {code}" + + +class TestNormalizeText: + """Tests pour normalize_text : accents, casse, whitespace.""" + + def test_accents_removed(self): + assert normalize_text("Pancréatite") == "pancreatite" + + def test_lowercase(self): + assert normalize_text("PANCRÉATITE AIGUË") == "pancreatite aigue" + + def test_whitespace_collapsed(self): + assert normalize_text(" pancréatite aiguë ") == "pancreatite aigue" + + def test_trema(self): + assert normalize_text("aigüe") == "aigue" + + def test_mixed(self): + assert normalize_text("Éruption Cutanée Médicamenteuse") == "eruption cutanee medicamenteuse" + + +class TestDictLookup: + """Tests pour lookup : priorité domain override, match exact, substring.""" + + def test_domain_override_priority(self): + """CIM10_MAP (override) a priorité sur le dictionnaire complet.""" + override = {"pancréatite aiguë biliaire": "K85.1"} + result = lookup("pancréatite aiguë biliaire", domain_overrides=override) + assert result == "K85.1" + + def test_exact_normalized_match(self): + """Match exact normalisé dans le dictionnaire complet.""" + # "Hypertension essentielle (primitive)" est le label exact de I10 + result = lookup("Hypertension essentielle (primitive)") + assert result == "I10" + + def test_substring_match(self): + """Match substring normalisé.""" + result = lookup("patient avec cholécystite aiguë sévère") + assert result == "K81.0" + + def test_unknown_returns_none(self): + result = lookup("texte complètement inconnu xyz123") + assert result is None + + def test_accent_insensitive(self): + """La recherche ignore les accents.""" + result = lookup("pancreatite aigue d'origine biliaire") + assert result == "K85.1" + + +class TestDiagnosticAccentVariations: + """Tests pour la détection de diagnostics avec variations d'accents.""" + + def _extract(self, text: str) -> DossierMedical: + parsed = { + "type": "crh", + "patient": {"sexe": "M"}, + "sejour": {}, + "diagnostics": [], + } + return extract_medical_info(parsed, text) + + def test_pancreatite_sans_accents(self): + dossier = self._extract("Pancreatite aigue biliaire.\nDevenir : retour.") + assert dossier.diagnostic_principal is not None + assert dossier.diagnostic_principal.cim10_suggestion == "K85.1" + + def test_pancreatite_trema(self): + dossier = self._extract("Pancréatite aigüe biliaire.\nDevenir : retour.") + assert dossier.diagnostic_principal is not None + assert dossier.diagnostic_principal.cim10_suggestion == "K85.1" + + def test_pancreatite_majuscules(self): + dossier = self._extract("PANCREATITE AIGUE BILIAIRE.\nDevenir : retour.") + assert dossier.diagnostic_principal is not None + assert dossier.diagnostic_principal.cim10_suggestion == "K85.1" + + def test_hta_as_das(self): + """HTA détectée comme DAS même sans accent.""" + dossier = self._extract("Douleur abdominale.\nhypertension arterielle connue.\nDevenir : retour.") + codes = {d.cim10_suggestion for d in dossier.diagnostics_associes} + assert "I10" in codes + + +class TestBiologieEdgeCases: + """Tests pour l'extraction biologie avec variantes.""" + + def _extract_bio(self, text: str) -> list: + parsed = { + "type": "crh", + "patient": {"sexe": "M"}, + "sejour": {}, + "diagnostics": [], + } + dossier = extract_medical_info(parsed, text) + return dossier.biologie_cle + + def test_crp_with_unit(self): + bio = self._extract_bio("CRP=45 mg/L") + assert any(b.test == "CRP" and b.valeur == "45" for b in bio) + + def test_lipasemie_ui_l(self): + bio = self._extract_bio("Lipasémie à 850 UI/L") + assert any(b.test == "Lipasémie" and b.valeur == "850" for b in bio) + + def test_troponine_us(self): + bio = self._extract_bio("Troponine us négative") + assert any(b.test == "Troponine" and b.valeur == "négative" for b in bio) + + def test_hb_shorthand(self): + bio = self._extract_bio("Hb = 11.5 g/dL") + assert any(b.test == "Hémoglobine" and b.valeur == "11.5" for b in bio) + + def test_tgo_alias(self): + bio = self._extract_bio("TGO = 120 UI/L") + assert any(b.test == "ASAT" and b.valeur == "120" for b in bio) + + def test_creatinine(self): + bio = self._extract_bio("Créatinine à 95 µmol/L") + assert any(b.test == "Créatinine" and b.valeur == "95" for b in bio) + + +class TestTraitementEdgeCases: + """Tests pour l'extraction des traitements.""" + + def _extract_ttt(self, text: str) -> list: + parsed = { + "type": "crh", + "patient": {"sexe": "M"}, + "sejour": {}, + "diagnostics": [], + } + dossier = extract_medical_info(parsed, text) + return dossier.traitements_sortie + + def test_more_than_10_medications(self): + """Vérifie que la limite de 10 est supprimée.""" + meds = "\n".join(f"Médicament{i} 100mg matin" for i in range(15)) + text = f"TTT de sortie :\n{meds}\n\nDevenir : retour." + ttt = self._extract_ttt(text) + assert len(ttt) >= 15 + + def test_posologie_sachet(self): + text = "TTT de sortie :\nMovicol 1 sachet matin\n\nDevenir : retour." + ttt = self._extract_ttt(text) + assert len(ttt) >= 1 + + def test_posologie_x_par_jour(self): + text = "TTT de sortie :\nParacétamol 1g 3x/jour\n\nDevenir : retour." + ttt = self._extract_ttt(text) + assert len(ttt) >= 1 + assert ttt[0].posologie is not None + + def test_stop_on_footer(self): + text = "TTT de sortie :\nParacétamol\nDoliprane\nDr Martin signature\nAutre médicament\n\nDevenir : retour." + ttt = self._extract_ttt(text) + meds = [t.medicament for t in ttt] + assert "Autre médicament" not in meds + + def test_pendant_x_jours(self): + text = "TTT de sortie :\nAmoxicilline 1g pendant 7 jours\n\nDevenir : retour." + ttt = self._extract_ttt(text) + assert len(ttt) >= 1 + assert ttt[0].posologie is not None + assert "7 jours" in ttt[0].posologie + + +class TestClassifierConfidence: + """Tests pour classify_with_confidence.""" + + def test_high_confidence_trackare(self): + text = "Dossier Patient\nIPP: 12345\nDétails épisode\nEpisode No: 67890\nSignes vitaux\n" + result = classify_with_confidence(text) + assert result.doc_type == "trackare" + assert result.confidence >= 0.7 + + def test_high_confidence_crh(self): + text = "Mon cher confrère,\nCompte rendu d'hospitalisation\nVotre patient a été admis dans le service de gastro\n" + result = classify_with_confidence(text) + assert result.doc_type == "crh" + assert result.confidence >= 0.7 + + def test_ambiguous_case(self): + text = "Document médical quelconque sans marqueurs spécifiques." + result = classify_with_confidence(text) + assert result.confidence <= 0.6 + + def test_backward_compatible(self): + """classify() retourne toujours une string.""" + text = "Dossier Patient\nIPP: 12345\n" + result = classify(text) + assert isinstance(result, str) + assert result in ("crh", "trackare")