diff --git a/.gitignore b/.gitignore index af65819..9ab3ebf 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ __pycache__/ output/ input/ *.egg-info/ +data/ diff --git a/requirements.txt b/requirements.txt index 54774ad..89e9a4f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,12 @@ pdfplumber>=0.10.0 -transformers>=4.35.0 +transformers>=4.35.0,<5.0.0 torch>=2.1.0 +protobuf>=3.20.0,<4.0.0 regex>=2023.0 pydantic>=2.5.0 pytest>=7.4.0 sentencepiece>=0.1.99,<0.2.0 edsnlp[ml]>=0.17.0 +faiss-cpu>=1.7.0 +sentence-transformers>=2.2.0 +requests>=2.28.0 diff --git a/src/config.py b/src/config.py index 083ed3a..1a5c033 100644 --- a/src/config.py +++ b/src/config.py @@ -28,9 +28,24 @@ NER_MODEL = "Jean-Baptiste/camembert-ner" NER_CONFIDENCE_THRESHOLD = 0.80 +# --- Configuration RAG --- + +RAG_INDEX_DIR = BASE_DIR / "data" / "rag_index" +CIM10_PDF = Path("/home/dom/ai/aivanov_CIM/cim-10-fr_2026_a_usage_pmsi_version_provisoire_111225.pdf") +GUIDE_METHODO_PDF = Path("/home/dom/ai/aivanov_CIM/guide_methodo_mco_2026_version_provisoire.pdf") +CCAM_PDF = Path("/home/dom/ai/aivanov_CIM/actualisation_ccam_descriptive_a_usage_pmsi_v4_2025.pdf") + + # --- Modèles de données CIM-10 --- +class RAGSource(BaseModel): + document: str + page: Optional[int] = None + code: Optional[str] = None + extrait: Optional[str] = None + + class Sejour(BaseModel): sexe: Optional[str] = None age: Optional[int] = None @@ -47,6 +62,9 @@ class Sejour(BaseModel): class Diagnostic(BaseModel): texte: str cim10_suggestion: Optional[str] = None + cim10_confidence: Optional[str] = None + justification: Optional[str] = None + sources_rag: list[RAGSource] = Field(default_factory=list) class ActeCCAM(BaseModel): diff --git a/src/main.py b/src/main.py index 51739c3..65ac81d 100644 --- a/src/main.py +++ b/src/main.py @@ -22,8 +22,9 @@ logging.basicConfig( ) logger = logging.getLogger(__name__) -# Flag global pour désactiver edsnlp +# Flags globaux _use_edsnlp = True +_use_rag = True def 
process_pdf(pdf_path: Path) -> tuple[str, DossierMedical, AnonymizationReport]: @@ -63,7 +64,7 @@ def process_pdf(pdf_path: Path) -> tuple[str, DossierMedical, AnonymizationRepor edsnlp_result = _run_edsnlp(anonymized_text) # 6. Extraction médicale CIM-10 - dossier = extract_medical_info(parsed, anonymized_text, edsnlp_result) + dossier = extract_medical_info(parsed, anonymized_text, edsnlp_result, use_rag=_use_rag) dossier.source_file = pdf_path.name dossier.document_type = doc_type logger.info(" DP : %s", dossier.diagnostic_principal) @@ -123,7 +124,7 @@ def write_outputs( def main(input_path: str | None = None) -> None: """Point d'entrée principal.""" - global _use_edsnlp + global _use_edsnlp, _use_rag parser = argparse.ArgumentParser( description="Anonymisation de documents médicaux PDF et extraction CIM-10", @@ -144,6 +145,11 @@ def main(input_path: str | None = None) -> None: action="store_true", help="Désactiver l'analyse edsnlp (mode regex seul)", ) + parser.add_argument( + "--no-rag", + action="store_true", + help="Désactiver l'enrichissement RAG (FAISS + Ollama)", + ) args = parser.parse_args() if args.no_ner: @@ -154,6 +160,9 @@ def main(input_path: str | None = None) -> None: if args.no_edsnlp: _use_edsnlp = False + if args.no_rag: + _use_rag = False + input_p = Path(args.input) if input_p.is_file(): pdfs = [input_p] diff --git a/src/medical/cim10_extractor.py b/src/medical/cim10_extractor.py index b8adc85..1b85391 100644 --- a/src/medical/cim10_extractor.py +++ b/src/medical/cim10_extractor.py @@ -2,10 +2,13 @@ from __future__ import annotations +import logging import re from datetime import datetime from typing import Optional +logger = logging.getLogger(__name__) + from ..config import ( ActeCCAM, BiologieCle, @@ -91,6 +94,7 @@ def extract_medical_info( parsed_data: dict, anonymized_text: str, edsnlp_result: Optional[EdsnlpResult] = None, + use_rag: bool = False, ) -> DossierMedical: """Extrait les informations médicales structurées depuis les 
données parsées et le texte.""" dossier = DossierMedical() @@ -105,9 +109,23 @@ def extract_medical_info( _extract_imagerie(anonymized_text, dossier) _extract_complications(anonymized_text, dossier, edsnlp_result) + if use_rag: + _enrich_with_rag(dossier) + return dossier +def _enrich_with_rag(dossier: DossierMedical) -> None: + """Enrichit les diagnostics via le RAG (FAISS + Ollama).""" + try: + from .rag_search import enrich_dossier + enrich_dossier(dossier) + except ImportError: + logger.warning("Module RAG non disponible (faiss-cpu ou sentence-transformers manquant)") + except Exception: + logger.warning("Erreur lors de l'enrichissement RAG", exc_info=True) + + def _extract_sejour(parsed: dict, dossier: DossierMedical) -> None: """Extrait les informations de séjour.""" patient = parsed.get("patient", {}) diff --git a/src/medical/rag_index.py b/src/medical/rag_index.py new file mode 100644 index 0000000..a59e323 --- /dev/null +++ b/src/medical/rag_index.py @@ -0,0 +1,352 @@ +"""Indexation FAISS des documents de référence CIM-10 / Guide métho / CCAM.""" + +from __future__ import annotations + +import json +import logging +import re +from dataclasses import dataclass, asdict +from pathlib import Path +from typing import Optional + +import pdfplumber + +from ..config import RAG_INDEX_DIR, CIM10_PDF, GUIDE_METHODO_PDF, CCAM_PDF + +logger = logging.getLogger(__name__) + +# Singleton pour l'index chargé en mémoire +_faiss_index = None +_metadata: list[dict] = [] + + +@dataclass +class Chunk: + text: str + document: str # "cim10", "guide_methodo", "ccam" + page: Optional[int] = None + code: Optional[str] = None + + +# --------------------------------------------------------------------------- +# Chunking CIM-10 +# --------------------------------------------------------------------------- + +def _chunk_cim10(pdf_path: Path) -> list[Chunk]: + """Découpe le PDF CIM-10 en chunks par code 3 caractères (ex: K80, K85).""" + chunks: list[Chunk] = [] + current_code: str | None 
= None + current_text: list[str] = [] + current_page: int | None = None + + # Pattern pour détecter un code CIM-10 à 3 caractères en début de ligne + code3_pattern = re.compile(r"^([A-Z]\d{2})\s+(.+)") + # Pattern pour les sous-codes (ex: K80.0, K80.1) + subcode_pattern = re.compile(r"^([A-Z]\d{2}\.\d+)\s+(.+)") + + logger.info("Extraction des chunks CIM-10 depuis %s", pdf_path.name) + + with pdfplumber.open(pdf_path) as pdf: + for page_num, page in enumerate(pdf.pages, start=1): + text = page.extract_text() + if not text: + continue + + for line in text.split("\n"): + line = line.strip() + if not line: + continue + + m = code3_pattern.match(line) + if m and not subcode_pattern.match(line): + # Nouveau code 3-char → sauvegarder le chunk précédent + if current_code and current_text: + chunk_text = "\n".join(current_text) + if len(chunk_text.split()) >= 5: + chunks.append(Chunk( + text=chunk_text, + document="cim10", + page=current_page, + code=current_code, + )) + current_code = m.group(1) + current_text = [line] + current_page = page_num + else: + if current_code: + current_text.append(line) + + # Dernier chunk + if current_code and current_text: + chunk_text = "\n".join(current_text) + if len(chunk_text.split()) >= 5: + chunks.append(Chunk( + text=chunk_text, + document="cim10", + page=current_page, + code=current_code, + )) + + logger.info("CIM-10 : %d chunks extraits", len(chunks)) + return chunks + + +# --------------------------------------------------------------------------- +# Chunking Guide Méthodologique MCO +# --------------------------------------------------------------------------- + +def _chunk_guide_methodo(pdf_path: Path) -> list[Chunk]: + """Découpe le Guide Méthodologique MCO par sections/titres.""" + chunks: list[Chunk] = [] + current_title: str | None = None + current_text: list[str] = [] + current_page: int | None = None + + # Patterns de titres de sections (chapitres, sous-chapitres) + title_patterns = [ + 
re.compile(r"^((?:CHAPITRE|TITRE|PARTIE)\s+[IVXLCDM0-9]+.*)$", re.IGNORECASE), + re.compile(r"^(\d+\.\d*\s+[A-ZÉÈÊÀÂÔÙÛÜ].{5,})$"), + re.compile(r"^([A-ZÉÈÊÀÂÔÙÛÜ][A-ZÉÈÊÀÂÔÙÛÜ\s]{10,})$"), + ] + + logger.info("Extraction des chunks Guide Métho depuis %s", pdf_path.name) + + with pdfplumber.open(pdf_path) as pdf: + for page_num, page in enumerate(pdf.pages, start=1): + text = page.extract_text() + if not text: + continue + + for line in text.split("\n"): + line = line.strip() + if not line: + continue + + is_title = False + for pat in title_patterns: + if pat.match(line): + is_title = True + break + + if is_title and len(line) > 8: + # Sauvegarder le chunk précédent + if current_title and current_text: + chunk_text = current_title + "\n" + "\n".join(current_text) + if len(chunk_text.split()) >= 20: + chunks.append(Chunk( + text=chunk_text, + document="guide_methodo", + page=current_page, + )) + current_title = line + current_text = [] + current_page = page_num + else: + current_text.append(line) + + # Dernier chunk + if current_title and current_text: + chunk_text = current_title + "\n" + "\n".join(current_text) + if len(chunk_text.split()) >= 20: + chunks.append(Chunk( + text=chunk_text, + document="guide_methodo", + page=current_page, + )) + + # Si trop peu de chunks (le PDF ne suit pas les patterns de titre), + # fallback : découper par pages groupées par 3 + if len(chunks) < 10: + logger.info("Guide Métho : fallback découpe par pages (peu de titres détectés)") + chunks = [] + with pdfplumber.open(pdf_path) as pdf: + page_texts: list[str] = [] + start_page = 1 + for page_num, page in enumerate(pdf.pages, start=1): + text = page.extract_text() + if text: + page_texts.append(text) + if len(page_texts) >= 3: + combined = "\n".join(page_texts) + if len(combined.split()) >= 20: + chunks.append(Chunk( + text=combined, + document="guide_methodo", + page=start_page, + )) + page_texts = [] + start_page = page_num + 1 + if page_texts: + combined = "\n".join(page_texts) + 
if len(combined.split()) >= 20: + chunks.append(Chunk( + text=combined, + document="guide_methodo", + page=start_page, + )) + + logger.info("Guide Métho : %d chunks extraits", len(chunks)) + return chunks + + +# --------------------------------------------------------------------------- +# Chunking CCAM +# --------------------------------------------------------------------------- + +def _chunk_ccam(pdf_path: Path) -> list[Chunk]: + """Découpe le PDF CCAM en chunks par code d'acte.""" + chunks: list[Chunk] = [] + ccam_pattern = re.compile(r"([A-Z]{4}\d{3})\s+(.*)") + + logger.info("Extraction des chunks CCAM depuis %s", pdf_path.name) + + with pdfplumber.open(pdf_path) as pdf: + for page_num, page in enumerate(pdf.pages, start=1): + text = page.extract_text() + if not text: + continue + + current_code: str | None = None + current_lines: list[str] = [] + + for line in text.split("\n"): + line = line.strip() + if not line: + continue + + m = ccam_pattern.match(line) + if m: + if current_code and current_lines: + chunks.append(Chunk( + text="\n".join(current_lines), + document="ccam", + page=page_num, + code=current_code, + )) + current_code = m.group(1) + current_lines = [line] + elif current_code: + current_lines.append(line) + + if current_code and current_lines: + chunks.append(Chunk( + text="\n".join(current_lines), + document="ccam", + page=page_num, + code=current_code, + )) + + # Fallback : si aucun code CCAM détecté, indexer par page + if not chunks: + logger.info("CCAM : aucun code détecté, fallback par page") + with pdfplumber.open(pdf_path) as pdf: + for page_num, page in enumerate(pdf.pages, start=1): + text = page.extract_text() + if text and len(text.split()) >= 10: + chunks.append(Chunk( + text=text, + document="ccam", + page=page_num, + )) + + logger.info("CCAM : %d chunks extraits", len(chunks)) + return chunks + + +# --------------------------------------------------------------------------- +# Construction de l'index FAISS +# 
--------------------------------------------------------------------------- + +def build_index(force: bool = False) -> None: + """Construit l'index FAISS à partir des 3 PDFs de référence. + + Args: + force: Si True, reconstruit même si l'index existe déjà. + """ + import faiss + import numpy as np + from sentence_transformers import SentenceTransformer + + index_path = RAG_INDEX_DIR / "faiss.index" + meta_path = RAG_INDEX_DIR / "metadata.json" + + if not force and index_path.exists() and meta_path.exists(): + logger.info("Index FAISS déjà existant dans %s (use force=True pour reconstruire)", RAG_INDEX_DIR) + return + + # Collecter tous les chunks + all_chunks: list[Chunk] = [] + + for pdf_path, chunk_fn in [ + (CIM10_PDF, _chunk_cim10), + (GUIDE_METHODO_PDF, _chunk_guide_methodo), + (CCAM_PDF, _chunk_ccam), + ]: + if pdf_path.exists(): + all_chunks.extend(chunk_fn(pdf_path)) + else: + logger.warning("PDF non trouvé : %s", pdf_path) + + if not all_chunks: + logger.error("Aucun chunk extrait — vérifiez les chemins des PDFs") + return + + logger.info("Total : %d chunks à indexer", len(all_chunks)) + + # Embeddings — forcer CPU pour éviter les bugs CUDA avec ce modèle + logger.info("Chargement du modèle d'embedding dangvantuan/sentence-camembert-large (CPU)...") + model = SentenceTransformer("dangvantuan/sentence-camembert-large", device="cpu") + model.max_seq_length = 512 # CamemBERT max position embeddings + + texts = [c.text[:2000] for c in all_chunks] # Tronquer les chunks trop longs + logger.info("Calcul des embeddings pour %d chunks...", len(texts)) + embeddings = model.encode( + texts, show_progress_bar=True, normalize_embeddings=True, batch_size=64, + ) + embeddings = np.array(embeddings, dtype=np.float32) + + # Index FAISS (IndexFlatIP = cosine similarity avec vecteurs normalisés) + dim = embeddings.shape[1] + index = faiss.IndexFlatIP(dim) + index.add(embeddings) + + # Sauvegarder + RAG_INDEX_DIR.mkdir(parents=True, exist_ok=True) + faiss.write_index(index, 
str(index_path)) + + metadata = [asdict(c) for c in all_chunks] + # Ne pas sauvegarder le texte complet dans metadata (trop lourd), + # garder un extrait de 500 chars + for m in metadata: + m["extrait"] = m.pop("text")[:500] + + meta_path.write_text(json.dumps(metadata, ensure_ascii=False, indent=2), encoding="utf-8") + + logger.info("Index FAISS sauvegardé : %s (%d vecteurs, dim=%d)", index_path, len(all_chunks), dim) + + +def get_index() -> tuple | None: + """Charge l'index FAISS et les métadonnées (singleton lazy-loaded). + + Returns: + Tuple (faiss_index, metadata_list) ou None si l'index n'existe pas. + """ + global _faiss_index, _metadata + + if _faiss_index is not None: + return _faiss_index, _metadata + + index_path = RAG_INDEX_DIR / "faiss.index" + meta_path = RAG_INDEX_DIR / "metadata.json" + + if not index_path.exists() or not meta_path.exists(): + logger.warning("Index FAISS non trouvé dans %s — lancez build_index() d'abord", RAG_INDEX_DIR) + return None + + import faiss + + _faiss_index = faiss.read_index(str(index_path)) + _metadata = json.loads(meta_path.read_text(encoding="utf-8")) + + logger.info("Index FAISS chargé : %d vecteurs", _faiss_index.ntotal) + return _faiss_index, _metadata diff --git a/src/medical/rag_search.py b/src/medical/rag_search.py new file mode 100644 index 0000000..ee30964 --- /dev/null +++ b/src/medical/rag_search.py @@ -0,0 +1,208 @@ +"""Recherche RAG (FAISS) + génération via Ollama pour le codage CIM-10.""" + +from __future__ import annotations + +import json +import logging +from typing import Optional + +import requests + +from ..config import Diagnostic, DossierMedical, RAGSource + +logger = logging.getLogger(__name__) + +# Configuration Ollama +OLLAMA_URL = "http://localhost:11434/api/generate" +OLLAMA_MODEL = "mistral-small3.2:24b" +OLLAMA_TIMEOUT = 120 # secondes + +# Singleton pour le modèle d'embedding (chargé une seule fois) +_embed_model = None + + +def _get_embed_model(): + """Charge le modèle d'embedding 
(singleton).""" + global _embed_model + if _embed_model is None: + from sentence_transformers import SentenceTransformer + logger.info("Chargement du modèle d'embedding pour la recherche...") + _embed_model = SentenceTransformer("dangvantuan/sentence-camembert-large", device="cpu") + _embed_model.max_seq_length = 512 + return _embed_model + + +def search_similar(query: str, top_k: int = 5) -> list[dict]: + """Recherche les passages les plus similaires dans l'index FAISS. + + Args: + query: Texte du diagnostic à rechercher. + top_k: Nombre de résultats à retourner. + + Returns: + Liste de dicts avec les métadonnées + score de similarité. + """ + from .rag_index import get_index + import numpy as np + + result = get_index() + if result is None: + logger.warning("Index FAISS non disponible") + return [] + + faiss_index, metadata = result + + model = _get_embed_model() + query_vec = model.encode([query], normalize_embeddings=True) + query_vec = np.array(query_vec, dtype=np.float32) + + scores, indices = faiss_index.search(query_vec, min(top_k, faiss_index.ntotal)) + + results = [] + for score, idx in zip(scores[0], indices[0]): + if idx < 0: + continue + meta = metadata[idx].copy() + meta["score"] = float(score) + results.append(meta) + + return results + + +def _build_prompt(texte: str, sources: list[dict], contexte: dict) -> str: + """Construit le prompt pour Ollama.""" + sources_text = "" + for i, src in enumerate(sources, 1): + doc_name = { + "cim10": "CIM-10 FR 2026", + "guide_methodo": "Guide Méthodologique MCO 2026", + "ccam": "CCAM PMSI V4 2025", + }.get(src["document"], src["document"]) + + code_info = f" (code: {src['code']})" if src.get("code") else "" + page_info = f" [page {src['page']}]" if src.get("page") else "" + + sources_text += f"--- Source {i}: {doc_name}{code_info}{page_info} ---\n" + sources_text += (src.get("extrait", "")[:800]) + "\n\n" + + ctx_parts = [] + if contexte.get("sexe"): + ctx_parts.append(f"sexe: {contexte['sexe']}") + if 
contexte.get("age") is not None: + ctx_parts.append(f"âge: {contexte['age']} ans") + ctx_str = ", ".join(ctx_parts) if ctx_parts else "non précisé" + + return f"""Tu es un expert en codage CIM-10 pour le PMSI en France. Suggère le code CIM-10 le plus précis pour le diagnostic suivant, en te basant UNIQUEMENT sur les sources officielles fournies. + +Diagnostic à coder : "{texte}" +Contexte patient : {ctx_str} + +Sources de référence : +{sources_text} +Réponds UNIQUEMENT au format JSON suivant, sans texte avant ou après : +{{"code": "X99.9", "confidence": "high|medium|low", "justification": "explication courte en français"}}""" + + +def _call_ollama(prompt: str) -> dict | None: + """Appelle Ollama et parse la réponse JSON.""" + try: + response = requests.post( + OLLAMA_URL, + json={ + "model": OLLAMA_MODEL, + "prompt": prompt, + "stream": False, + "options": { + "temperature": 0.1, + "num_predict": 300, + }, + }, + timeout=OLLAMA_TIMEOUT, + ) + response.raise_for_status() + raw = response.json().get("response", "") + + # Extraire le JSON de la réponse (peut contenir du texte autour) + json_match = None + # Chercher un bloc JSON entre accolades + brace_start = raw.find("{") + brace_end = raw.rfind("}") + if brace_start != -1 and brace_end > brace_start: + json_match = raw[brace_start:brace_end + 1] + + if json_match: + return json.loads(json_match) + else: + logger.warning("Ollama : réponse sans JSON valide : %s", raw[:200]) + return None + + except requests.ConnectionError: + logger.warning("Ollama non disponible (connexion refusée)") + return None + except requests.Timeout: + logger.warning("Ollama timeout après %ds", OLLAMA_TIMEOUT) + return None + except (requests.RequestException, json.JSONDecodeError) as e: + logger.warning("Ollama erreur : %s", e) + return None + + +def enrich_diagnostic( + diagnostic: Diagnostic, + contexte: dict, +) -> None: + """Enrichit un Diagnostic avec le RAG (FAISS + Ollama). + + Modifie le diagnostic en place. 
Fallback gracieux si FAISS ou Ollama échouent. + """ + # 1. Recherche FAISS + sources = search_similar(diagnostic.texte, top_k=5) + + if not sources: + logger.debug("Aucune source RAG trouvée pour : %s", diagnostic.texte) + return + + # 2. Stocker les sources RAG + diagnostic.sources_rag = [ + RAGSource( + document=s["document"], + page=s.get("page"), + code=s.get("code"), + extrait=s.get("extrait", "")[:200], + ) + for s in sources + ] + + # 3. Appel Ollama pour justification + prompt = _build_prompt(diagnostic.texte, sources, contexte) + llm_result = _call_ollama(prompt) + + if llm_result: + code = llm_result.get("code") + confidence = llm_result.get("confidence") + justification = llm_result.get("justification") + + if code: + diagnostic.cim10_suggestion = code + if confidence in ("high", "medium", "low"): + diagnostic.cim10_confidence = confidence + if justification: + diagnostic.justification = justification + else: + logger.info("Ollama non disponible — sources FAISS conservées sans justification LLM") + + +def enrich_dossier(dossier: DossierMedical) -> None: + """Enrichit le DP et tous les DAS d'un dossier via le RAG.""" + contexte = { + "sexe": dossier.sejour.sexe, + "age": dossier.sejour.age, + } + + if dossier.diagnostic_principal: + logger.info("RAG enrichissement DP : %s", dossier.diagnostic_principal.texte) + enrich_diagnostic(dossier.diagnostic_principal, contexte) + + for das in dossier.diagnostics_associes: + logger.info("RAG enrichissement DAS : %s", das.texte) + enrich_diagnostic(das, contexte) diff --git a/tests/test_rag.py b/tests/test_rag.py new file mode 100644 index 0000000..1b9fc0c --- /dev/null +++ b/tests/test_rag.py @@ -0,0 +1,271 @@ +"""Tests pour le RAG CIM-10 (modèles, chunking, intégration).""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import patch, MagicMock + +import pytest + +from src.config import RAGSource, Diagnostic, DossierMedical, CIM10_PDF, GUIDE_METHODO_PDF, CCAM_PDF + + +class 
TestRAGSource: + def test_create_minimal(self): + src = RAGSource(document="cim10") + assert src.document == "cim10" + assert src.page is None + assert src.code is None + assert src.extrait is None + + def test_create_full(self): + src = RAGSource( + document="guide_methodo", + page=42, + code="K85", + extrait="Pancréatite aiguë biliaire...", + ) + assert src.document == "guide_methodo" + assert src.page == 42 + assert src.code == "K85" + assert src.extrait == "Pancréatite aiguë biliaire..." + + def test_serialization(self): + src = RAGSource(document="ccam", page=1, code="HMFC004") + data = src.model_dump(exclude_none=True) + assert data == {"document": "ccam", "page": 1, "code": "HMFC004"} + + +class TestDiagnosticExtended: + def test_backward_compatible(self): + """Les nouveaux champs sont optionnels — rétrocompatible.""" + d = Diagnostic(texte="Pancréatite aiguë", cim10_suggestion="K85.9") + assert d.texte == "Pancréatite aiguë" + assert d.cim10_suggestion == "K85.9" + assert d.cim10_confidence is None + assert d.justification is None + assert d.sources_rag == [] + + def test_with_rag_fields(self): + d = Diagnostic( + texte="Lithiase cholédoque", + cim10_suggestion="K80.5", + cim10_confidence="high", + justification="Code K80.5 correspond à la lithiase du cholédoque", + sources_rag=[ + RAGSource(document="cim10", page=480, code="K80"), + ], + ) + assert d.cim10_confidence == "high" + assert d.justification is not None + assert len(d.sources_rag) == 1 + assert d.sources_rag[0].code == "K80" + + def test_serialization_exclude_none(self): + """Vérifier que le JSON n'inclut pas les champs None.""" + d = Diagnostic(texte="Test", cim10_suggestion="K85.9") + data = d.model_dump(exclude_none=True) + assert "cim10_confidence" not in data + assert "justification" not in data + assert "sources_rag" in data # list vide incluse + + def test_dossier_with_extended_diagnostic(self): + """Un DossierMedical avec des diagnostics enrichis par le RAG.""" + dossier = DossierMedical( 
+ diagnostic_principal=Diagnostic( + texte="Pancréatite aiguë biliaire", + cim10_suggestion="K85.1", + cim10_confidence="high", + justification="Confirmé par CIM-10 FR 2026", + sources_rag=[ + RAGSource(document="cim10", page=496, code="K85"), + RAGSource(document="guide_methodo", page=30), + ], + ), + ) + assert dossier.diagnostic_principal.cim10_confidence == "high" + assert len(dossier.diagnostic_principal.sources_rag) == 2 + + +class TestExtractMedicalInfoRAGFlag: + def test_use_rag_false_no_change(self): + """use_rag=False ne modifie pas le comportement existant.""" + from src.medical.cim10_extractor import extract_medical_info + + parsed = { + "type": "crh", + "patient": {"sexe": "M"}, + "sejour": {}, + "diagnostics": [], + } + text = "Pancréatite aiguë biliaire.\nTTT de sortie :\nParacétamol\n\nDevenir : retour." + + dossier = extract_medical_info(parsed, text, use_rag=False) + assert dossier.diagnostic_principal is not None + assert dossier.diagnostic_principal.cim10_suggestion == "K85.1" + # Pas de sources RAG + assert dossier.diagnostic_principal.sources_rag == [] + assert dossier.diagnostic_principal.justification is None + + def test_use_rag_true_calls_enrich(self): + """use_rag=True appelle _enrich_with_rag (mocké).""" + from src.medical.cim10_extractor import extract_medical_info + + parsed = { + "type": "crh", + "patient": {"sexe": "M"}, + "sejour": {}, + "diagnostics": [], + } + text = "Pancréatite aiguë biliaire.\nTTT de sortie :\nParacétamol\n\nDevenir : retour." + + with patch("src.medical.cim10_extractor._enrich_with_rag") as mock_enrich: + dossier = extract_medical_info(parsed, text, use_rag=True) + mock_enrich.assert_called_once_with(dossier) + + def test_use_rag_default_false(self): + """Par défaut, use_rag=False.""" + from src.medical.cim10_extractor import extract_medical_info + + parsed = { + "type": "crh", + "patient": {"sexe": "M"}, + "sejour": {}, + "diagnostics": [], + } + text = "Test simple." 
+ + with patch("src.medical.cim10_extractor._enrich_with_rag") as mock_enrich: + extract_medical_info(parsed, text) + mock_enrich.assert_not_called() + + +class TestChunkingCIM10: + @pytest.mark.skipif( + not CIM10_PDF.exists(), + reason=f"PDF CIM-10 non trouvé : {CIM10_PDF}", + ) + def test_chunks_contain_known_codes(self): + from src.medical.rag_index import _chunk_cim10 + + chunks = _chunk_cim10(CIM10_PDF) + assert len(chunks) > 100, f"Trop peu de chunks : {len(chunks)}" + + codes = {c.code for c in chunks if c.code} + assert "K85" in codes, "K85 (pancréatite) non trouvé" + assert "K80" in codes, "K80 (lithiase biliaire) non trouvé" + assert "E66" in codes, "E66 (obésité) non trouvé" + + @pytest.mark.skipif( + not CIM10_PDF.exists(), + reason=f"PDF CIM-10 non trouvé : {CIM10_PDF}", + ) + def test_chunk_content(self): + from src.medical.rag_index import _chunk_cim10 + + chunks = _chunk_cim10(CIM10_PDF) + k85_chunks = [c for c in chunks if c.code == "K85"] + assert len(k85_chunks) >= 1 + assert "pancréatite" in k85_chunks[0].text.lower() or "pancreatite" in k85_chunks[0].text.lower() + + +class TestChunkingGuideMethodo: + @pytest.mark.skipif( + not GUIDE_METHODO_PDF.exists(), + reason=f"PDF Guide Métho non trouvé : {GUIDE_METHODO_PDF}", + ) + def test_chunks_extracted(self): + from src.medical.rag_index import _chunk_guide_methodo + + chunks = _chunk_guide_methodo(GUIDE_METHODO_PDF) + assert len(chunks) >= 10, f"Trop peu de chunks : {len(chunks)}" + assert all(c.document == "guide_methodo" for c in chunks) + + +class TestChunkingCCAM: + @pytest.mark.skipif( + not CCAM_PDF.exists(), + reason=f"PDF CCAM non trouvé : {CCAM_PDF}", + ) + def test_chunks_extracted(self): + from src.medical.rag_index import _chunk_ccam + + chunks = _chunk_ccam(CCAM_PDF) + assert len(chunks) >= 1, "Aucun chunk CCAM extrait" + assert all(c.document == "ccam" for c in chunks) + + +class TestRAGSearchMocked: + def test_search_similar_no_index(self): + """search_similar retourne une liste 
vide si l'index n'existe pas.""" + from src.medical.rag_search import search_similar + + with patch("src.medical.rag_index.get_index", return_value=None): + results = search_similar("pancréatite aiguë") + assert results == [] + + def test_enrich_diagnostic_no_sources(self): + """enrich_diagnostic ne plante pas si aucune source trouvée.""" + from src.medical.rag_search import enrich_diagnostic + + diag = Diagnostic(texte="test quelconque", cim10_suggestion="Z99.9") + + with patch("src.medical.rag_search.search_similar", return_value=[]): + enrich_diagnostic(diag, {"sexe": "M", "age": 50}) + + assert diag.sources_rag == [] + assert diag.justification is None + + def test_enrich_diagnostic_with_sources_no_ollama(self): + """Enrichissement avec sources FAISS mais sans Ollama.""" + from src.medical.rag_search import enrich_diagnostic + + diag = Diagnostic(texte="Pancréatite aiguë", cim10_suggestion="K85.9") + mock_sources = [ + { + "document": "cim10", + "page": 496, + "code": "K85", + "extrait": "K85 Pancréatite aiguë...", + "score": 0.92, + }, + ] + + with patch("src.medical.rag_search.search_similar", return_value=mock_sources), \ + patch("src.medical.rag_search._call_ollama", return_value=None): + enrich_diagnostic(diag, {"sexe": "M", "age": 50}) + + assert len(diag.sources_rag) == 1 + assert diag.sources_rag[0].document == "cim10" + assert diag.sources_rag[0].code == "K85" + # Pas de justification (Ollama non disponible) + assert diag.justification is None + + def test_enrich_diagnostic_with_ollama(self): + """Enrichissement complet avec sources + Ollama.""" + from src.medical.rag_search import enrich_diagnostic + + diag = Diagnostic(texte="Pancréatite aiguë biliaire") + mock_sources = [ + { + "document": "cim10", + "page": 496, + "code": "K85", + "extrait": "K85 Pancréatite aiguë...", + "score": 0.95, + }, + ] + mock_llm = { + "code": "K85.1", + "confidence": "high", + "justification": "Pancréatite aiguë d'origine biliaire = K85.1", + } + + with 
patch("src.medical.rag_search.search_similar", return_value=mock_sources), \ + patch("src.medical.rag_search._call_ollama", return_value=mock_llm): + enrich_diagnostic(diag, {"sexe": "F", "age": 43}) + + assert diag.cim10_suggestion == "K85.1" + assert diag.cim10_confidence == "high" + assert diag.justification == "Pancréatite aiguë d'origine biliaire = K85.1" + assert len(diag.sources_rag) == 1