feat: ajout RAG CIM-10 avec FAISS + Ollama
Implémente un système RAG (Retrieval Augmented Generation) qui indexe les documents de référence ATIH (CIM-10 FR 2026, Guide Métho MCO, CCAM PMSI) et utilise Ollama (mistral-small3.2:24b) pour justifier et valider le codage CIM-10 des diagnostics. - Nouveaux modèles Pydantic : RAGSource, Diagnostic étendu (confidence, justification, sources_rag) — rétrocompatible - Module rag_index.py : chunking des 3 PDFs, embedding sentence-camembert-large, index FAISS IndexFlatIP (3630 vecteurs) - Module rag_search.py : recherche FAISS + appel Ollama avec fallback double - Flag CLI --no-rag pour désactiver l'enrichissement RAG - 18 nouveaux tests (88/88 passent) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -6,3 +6,4 @@ __pycache__/
|
|||||||
output/
|
output/
|
||||||
input/
|
input/
|
||||||
*.egg-info/
|
*.egg-info/
|
||||||
|
data/
|
||||||
|
|||||||
@@ -1,8 +1,12 @@
|
|||||||
pdfplumber>=0.10.0
|
pdfplumber>=0.10.0
|
||||||
transformers>=4.35.0
|
transformers>=4.35.0,<5.0.0
|
||||||
torch>=2.1.0
|
torch>=2.1.0
|
||||||
|
protobuf>=3.20.0,<4.0.0
|
||||||
regex>=2023.0
|
regex>=2023.0
|
||||||
pydantic>=2.5.0
|
pydantic>=2.5.0
|
||||||
pytest>=7.4.0
|
pytest>=7.4.0
|
||||||
sentencepiece>=0.1.99,<0.2.0
|
sentencepiece>=0.1.99,<0.2.0
|
||||||
edsnlp[ml]>=0.17.0
|
edsnlp[ml]>=0.17.0
|
||||||
|
faiss-cpu>=1.7.0
|
||||||
|
sentence-transformers>=2.2.0
|
||||||
|
requests>=2.28.0
|
||||||
|
|||||||
@@ -28,9 +28,24 @@ NER_MODEL = "Jean-Baptiste/camembert-ner"
|
|||||||
NER_CONFIDENCE_THRESHOLD = 0.80
|
NER_CONFIDENCE_THRESHOLD = 0.80
|
||||||
|
|
||||||
|
|
||||||
|
# --- Configuration RAG ---
|
||||||
|
|
||||||
|
RAG_INDEX_DIR = BASE_DIR / "data" / "rag_index"
|
||||||
|
CIM10_PDF = Path("/home/dom/ai/aivanov_CIM/cim-10-fr_2026_a_usage_pmsi_version_provisoire_111225.pdf")
|
||||||
|
GUIDE_METHODO_PDF = Path("/home/dom/ai/aivanov_CIM/guide_methodo_mco_2026_version_provisoire.pdf")
|
||||||
|
CCAM_PDF = Path("/home/dom/ai/aivanov_CIM/actualisation_ccam_descriptive_a_usage_pmsi_v4_2025.pdf")
|
||||||
|
|
||||||
|
|
||||||
# --- Modèles de données CIM-10 ---
|
# --- Modèles de données CIM-10 ---
|
||||||
|
|
||||||
|
|
||||||
|
class RAGSource(BaseModel):
|
||||||
|
document: str
|
||||||
|
page: Optional[int] = None
|
||||||
|
code: Optional[str] = None
|
||||||
|
extrait: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class Sejour(BaseModel):
|
class Sejour(BaseModel):
|
||||||
sexe: Optional[str] = None
|
sexe: Optional[str] = None
|
||||||
age: Optional[int] = None
|
age: Optional[int] = None
|
||||||
@@ -47,6 +62,9 @@ class Sejour(BaseModel):
|
|||||||
class Diagnostic(BaseModel):
|
class Diagnostic(BaseModel):
|
||||||
texte: str
|
texte: str
|
||||||
cim10_suggestion: Optional[str] = None
|
cim10_suggestion: Optional[str] = None
|
||||||
|
cim10_confidence: Optional[str] = None
|
||||||
|
justification: Optional[str] = None
|
||||||
|
sources_rag: list[RAGSource] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class ActeCCAM(BaseModel):
|
class ActeCCAM(BaseModel):
|
||||||
|
|||||||
15
src/main.py
15
src/main.py
@@ -22,8 +22,9 @@ logging.basicConfig(
|
|||||||
)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Flag global pour désactiver edsnlp
|
# Flags globaux
|
||||||
_use_edsnlp = True
|
_use_edsnlp = True
|
||||||
|
_use_rag = True
|
||||||
|
|
||||||
|
|
||||||
def process_pdf(pdf_path: Path) -> tuple[str, DossierMedical, AnonymizationReport]:
|
def process_pdf(pdf_path: Path) -> tuple[str, DossierMedical, AnonymizationReport]:
|
||||||
@@ -63,7 +64,7 @@ def process_pdf(pdf_path: Path) -> tuple[str, DossierMedical, AnonymizationRepor
|
|||||||
edsnlp_result = _run_edsnlp(anonymized_text)
|
edsnlp_result = _run_edsnlp(anonymized_text)
|
||||||
|
|
||||||
# 6. Extraction médicale CIM-10
|
# 6. Extraction médicale CIM-10
|
||||||
dossier = extract_medical_info(parsed, anonymized_text, edsnlp_result)
|
dossier = extract_medical_info(parsed, anonymized_text, edsnlp_result, use_rag=_use_rag)
|
||||||
dossier.source_file = pdf_path.name
|
dossier.source_file = pdf_path.name
|
||||||
dossier.document_type = doc_type
|
dossier.document_type = doc_type
|
||||||
logger.info(" DP : %s", dossier.diagnostic_principal)
|
logger.info(" DP : %s", dossier.diagnostic_principal)
|
||||||
@@ -123,7 +124,7 @@ def write_outputs(
|
|||||||
|
|
||||||
def main(input_path: str | None = None) -> None:
|
def main(input_path: str | None = None) -> None:
|
||||||
"""Point d'entrée principal."""
|
"""Point d'entrée principal."""
|
||||||
global _use_edsnlp
|
global _use_edsnlp, _use_rag
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Anonymisation de documents médicaux PDF et extraction CIM-10",
|
description="Anonymisation de documents médicaux PDF et extraction CIM-10",
|
||||||
@@ -144,6 +145,11 @@ def main(input_path: str | None = None) -> None:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Désactiver l'analyse edsnlp (mode regex seul)",
|
help="Désactiver l'analyse edsnlp (mode regex seul)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-rag",
|
||||||
|
action="store_true",
|
||||||
|
help="Désactiver l'enrichissement RAG (FAISS + Ollama)",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.no_ner:
|
if args.no_ner:
|
||||||
@@ -154,6 +160,9 @@ def main(input_path: str | None = None) -> None:
|
|||||||
if args.no_edsnlp:
|
if args.no_edsnlp:
|
||||||
_use_edsnlp = False
|
_use_edsnlp = False
|
||||||
|
|
||||||
|
if args.no_rag:
|
||||||
|
_use_rag = False
|
||||||
|
|
||||||
input_p = Path(args.input)
|
input_p = Path(args.input)
|
||||||
if input_p.is_file():
|
if input_p.is_file():
|
||||||
pdfs = [input_p]
|
pdfs = [input_p]
|
||||||
|
|||||||
@@ -2,10 +2,13 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from ..config import (
|
from ..config import (
|
||||||
ActeCCAM,
|
ActeCCAM,
|
||||||
BiologieCle,
|
BiologieCle,
|
||||||
@@ -91,6 +94,7 @@ def extract_medical_info(
|
|||||||
parsed_data: dict,
|
parsed_data: dict,
|
||||||
anonymized_text: str,
|
anonymized_text: str,
|
||||||
edsnlp_result: Optional[EdsnlpResult] = None,
|
edsnlp_result: Optional[EdsnlpResult] = None,
|
||||||
|
use_rag: bool = False,
|
||||||
) -> DossierMedical:
|
) -> DossierMedical:
|
||||||
"""Extrait les informations médicales structurées depuis les données parsées et le texte."""
|
"""Extrait les informations médicales structurées depuis les données parsées et le texte."""
|
||||||
dossier = DossierMedical()
|
dossier = DossierMedical()
|
||||||
@@ -105,9 +109,23 @@ def extract_medical_info(
|
|||||||
_extract_imagerie(anonymized_text, dossier)
|
_extract_imagerie(anonymized_text, dossier)
|
||||||
_extract_complications(anonymized_text, dossier, edsnlp_result)
|
_extract_complications(anonymized_text, dossier, edsnlp_result)
|
||||||
|
|
||||||
|
if use_rag:
|
||||||
|
_enrich_with_rag(dossier)
|
||||||
|
|
||||||
return dossier
|
return dossier
|
||||||
|
|
||||||
|
|
||||||
|
def _enrich_with_rag(dossier: DossierMedical) -> None:
    """Best-effort RAG enrichment of the diagnoses (FAISS + Ollama).

    Never raises: a missing optional dependency or any runtime failure is
    logged and the dossier is left unchanged.
    """
    try:
        # Imported lazily so the pipeline still runs without the RAG extras.
        from .rag_search import enrich_dossier

        enrich_dossier(dossier)
    except ImportError:
        logger.warning("Module RAG non disponible (faiss-cpu ou sentence-transformers manquant)")
    except Exception:
        logger.warning("Erreur lors de l'enrichissement RAG", exc_info=True)
||||||
|
|
||||||
|
|
||||||
def _extract_sejour(parsed: dict, dossier: DossierMedical) -> None:
|
def _extract_sejour(parsed: dict, dossier: DossierMedical) -> None:
|
||||||
"""Extrait les informations de séjour."""
|
"""Extrait les informations de séjour."""
|
||||||
patient = parsed.get("patient", {})
|
patient = parsed.get("patient", {})
|
||||||
|
|||||||
352
src/medical/rag_index.py
Normal file
352
src/medical/rag_index.py
Normal file
@@ -0,0 +1,352 @@
|
|||||||
|
"""Indexation FAISS des documents de référence CIM-10 / Guide métho / CCAM."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, asdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pdfplumber
|
||||||
|
|
||||||
|
from ..config import RAG_INDEX_DIR, CIM10_PDF, GUIDE_METHODO_PDF, CCAM_PDF
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Singleton pour l'index chargé en mémoire
|
||||||
|
_faiss_index = None
|
||||||
|
_metadata: list[dict] = []
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Chunk:
    """One indexable passage extracted from a reference PDF."""

    text: str
    document: str  # one of "cim10", "guide_methodo", "ccam"
    page: Optional[int] = None
    code: Optional[str] = None
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Chunking CIM-10
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _chunk_cim10(pdf_path: Path) -> list[Chunk]:
    """Split the CIM-10 PDF into chunks, one per 3-character code (e.g. K80, K85)."""
    chunks: list[Chunk] = []
    state = {"code": None, "lines": [], "page": None}

    # A 3-char code at the start of a line opens a new chunk; sub-codes
    # (K80.0, K80.1, ...) stay inside the current 3-char chunk.
    code3_re = re.compile(r"^([A-Z]\d{2})\s+(.+)")
    subcode_re = re.compile(r"^([A-Z]\d{2}\.\d+)\s+(.+)")

    def flush() -> None:
        # Keep only chunks with at least 5 words: shorter ones are noise.
        if state["code"] and state["lines"]:
            body = "\n".join(state["lines"])
            if len(body.split()) >= 5:
                chunks.append(Chunk(
                    text=body,
                    document="cim10",
                    page=state["page"],
                    code=state["code"],
                ))

    logger.info("Extraction des chunks CIM-10 depuis %s", pdf_path.name)

    with pdfplumber.open(pdf_path) as pdf:
        for page_num, page in enumerate(pdf.pages, start=1):
            raw = page.extract_text()
            if not raw:
                continue
            for line in raw.split("\n"):
                line = line.strip()
                if not line:
                    continue
                m = code3_re.match(line)
                if m and not subcode_re.match(line):
                    # New 3-char code: save the previous chunk first.
                    flush()
                    state["code"] = m.group(1)
                    state["lines"] = [line]
                    state["page"] = page_num
                elif state["code"]:
                    state["lines"].append(line)

    # Save the trailing chunk.
    flush()

    logger.info("CIM-10 : %d chunks extraits", len(chunks))
    return chunks
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Chunking Guide Méthodologique MCO
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _chunk_guide_methodo(pdf_path: Path) -> list[Chunk]:
    """Split the MCO methodological guide into chunks by section headings.

    Falls back to a 3-pages-per-chunk split when too few headings are found.
    """
    # Heading shapes: chapter markers, numbered sections, all-caps titles.
    heading_res = [
        re.compile(r"^((?:CHAPITRE|TITRE|PARTIE)\s+[IVXLCDM0-9]+.*)$", re.IGNORECASE),
        re.compile(r"^(\d+\.\d*\s+[A-ZÉÈÊÀÂÔÙÛÜ].{5,})$"),
        re.compile(r"^([A-ZÉÈÊÀÂÔÙÛÜ][A-ZÉÈÊÀÂÔÙÛÜ\s]{10,})$"),
    ]

    logger.info("Extraction des chunks Guide Métho depuis %s", pdf_path.name)

    chunks: list[Chunk] = []
    title: str | None = None
    body: list[str] = []
    page_of_title: int | None = None

    def flush() -> None:
        # Discard fragments shorter than 20 words (headers/footers, stray lines).
        if title and body:
            combined = title + "\n" + "\n".join(body)
            if len(combined.split()) >= 20:
                chunks.append(Chunk(
                    text=combined,
                    document="guide_methodo",
                    page=page_of_title,
                ))

    with pdfplumber.open(pdf_path) as pdf:
        for page_num, page in enumerate(pdf.pages, start=1):
            raw = page.extract_text()
            if not raw:
                continue
            for line in raw.split("\n"):
                line = line.strip()
                if not line:
                    continue
                looks_like_title = any(pat.match(line) for pat in heading_res)
                if looks_like_title and len(line) > 8:
                    flush()
                    title = line
                    body = []
                    page_of_title = page_num
                else:
                    body.append(line)

    flush()

    # The PDF may not follow the heading patterns at all; in that case
    # re-split it into groups of 3 pages.
    if len(chunks) < 10:
        logger.info("Guide Métho : fallback découpe par pages (peu de titres détectés)")
        chunks = []
        with pdfplumber.open(pdf_path) as pdf:
            buffer: list[str] = []
            first_page = 1
            for page_num, page in enumerate(pdf.pages, start=1):
                raw = page.extract_text()
                if raw:
                    buffer.append(raw)
                if len(buffer) >= 3:
                    merged = "\n".join(buffer)
                    if len(merged.split()) >= 20:
                        chunks.append(Chunk(
                            text=merged,
                            document="guide_methodo",
                            page=first_page,
                        ))
                    buffer = []
                    first_page = page_num + 1
            if buffer:
                merged = "\n".join(buffer)
                if len(merged.split()) >= 20:
                    chunks.append(Chunk(
                        text=merged,
                        document="guide_methodo",
                        page=first_page,
                    ))

    logger.info("Guide Métho : %d chunks extraits", len(chunks))
    return chunks
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Chunking CCAM
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _chunk_ccam(pdf_path: Path) -> list[Chunk]:
    """Split the CCAM PDF into one chunk per procedure code (e.g. HMFC004)."""
    chunks: list[Chunk] = []
    act_re = re.compile(r"([A-Z]{4}\d{3})\s+(.*)")

    logger.info("Extraction des chunks CCAM depuis %s", pdf_path.name)

    with pdfplumber.open(pdf_path) as pdf:
        for page_num, page in enumerate(pdf.pages, start=1):
            raw = page.extract_text()
            if not raw:
                continue

            # State is reset on every page: a code's text never spans pages.
            code: str | None = None
            buf: list[str] = []

            def flush() -> None:
                if code and buf:
                    chunks.append(Chunk(
                        text="\n".join(buf),
                        document="ccam",
                        page=page_num,
                        code=code,
                    ))

            for line in raw.split("\n"):
                line = line.strip()
                if not line:
                    continue
                m = act_re.match(line)
                if m:
                    flush()
                    code = m.group(1)
                    buf = [line]
                elif code:
                    buf.append(line)

            flush()

    # Fallback: no procedure code matched at all — index whole pages instead.
    if not chunks:
        logger.info("CCAM : aucun code détecté, fallback par page")
        with pdfplumber.open(pdf_path) as pdf:
            for page_num, page in enumerate(pdf.pages, start=1):
                raw = page.extract_text()
                if raw and len(raw.split()) >= 10:
                    chunks.append(Chunk(
                        text=raw,
                        document="ccam",
                        page=page_num,
                    ))

    logger.info("CCAM : %d chunks extraits", len(chunks))
    return chunks
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Construction de l'index FAISS
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def build_index(force: bool = False) -> None:
    """Build the FAISS index from the three reference PDFs.

    Args:
        force: When True, rebuild even if an index already exists on disk.
    """
    import faiss
    import numpy as np
    from sentence_transformers import SentenceTransformer

    index_path = RAG_INDEX_DIR / "faiss.index"
    meta_path = RAG_INDEX_DIR / "metadata.json"

    if not force and index_path.exists() and meta_path.exists():
        logger.info("Index FAISS déjà existant dans %s (use force=True pour reconstruire)", RAG_INDEX_DIR)
        return

    # Gather chunks from every reference document that is actually present.
    sources = (
        (CIM10_PDF, _chunk_cim10),
        (GUIDE_METHODO_PDF, _chunk_guide_methodo),
        (CCAM_PDF, _chunk_ccam),
    )
    all_chunks: list[Chunk] = []
    for pdf_path, chunker in sources:
        if pdf_path.exists():
            all_chunks.extend(chunker(pdf_path))
        else:
            logger.warning("PDF non trouvé : %s", pdf_path)

    if not all_chunks:
        logger.error("Aucun chunk extrait — vérifiez les chemins des PDFs")
        return

    logger.info("Total : %d chunks à indexer", len(all_chunks))

    # Embeddings — force CPU to avoid CUDA issues with this model.
    logger.info("Chargement du modèle d'embedding dangvantuan/sentence-camembert-large (CPU)...")
    model = SentenceTransformer("dangvantuan/sentence-camembert-large", device="cpu")
    model.max_seq_length = 512  # CamemBERT max position embeddings

    texts = [c.text[:2000] for c in all_chunks]  # cap overly long chunks
    logger.info("Calcul des embeddings pour %d chunks...", len(texts))
    vectors = model.encode(
        texts, show_progress_bar=True, normalize_embeddings=True, batch_size=64,
    )
    vectors = np.array(vectors, dtype=np.float32)

    # IndexFlatIP over normalized vectors == cosine similarity.
    dim = vectors.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(vectors)

    # Persist index + metadata.
    RAG_INDEX_DIR.mkdir(parents=True, exist_ok=True)
    faiss.write_index(index, str(index_path))

    # Metadata keeps only a 500-char excerpt instead of the full chunk text.
    metadata = [asdict(c) for c in all_chunks]
    for m in metadata:
        m["extrait"] = m.pop("text")[:500]
    meta_path.write_text(json.dumps(metadata, ensure_ascii=False, indent=2), encoding="utf-8")

    logger.info("Index FAISS sauvegardé : %s (%d vecteurs, dim=%d)", index_path, len(all_chunks), dim)
||||||
|
|
||||||
|
|
||||||
|
def get_index() -> tuple | None:
    """Return the (faiss_index, metadata) pair, loading it lazily once.

    Returns:
        Tuple (faiss_index, metadata_list), or None when no index is on disk.
    """
    global _faiss_index, _metadata

    # Already loaded: reuse the module-level singleton.
    if _faiss_index is not None:
        return _faiss_index, _metadata

    index_path = RAG_INDEX_DIR / "faiss.index"
    meta_path = RAG_INDEX_DIR / "metadata.json"

    if not (index_path.exists() and meta_path.exists()):
        logger.warning("Index FAISS non trouvé dans %s — lancez build_index() d'abord", RAG_INDEX_DIR)
        return None

    import faiss

    _faiss_index = faiss.read_index(str(index_path))
    _metadata = json.loads(meta_path.read_text(encoding="utf-8"))

    logger.info("Index FAISS chargé : %d vecteurs", _faiss_index.ntotal)
    return _faiss_index, _metadata
||||||
208
src/medical/rag_search.py
Normal file
208
src/medical/rag_search.py
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
"""Recherche RAG (FAISS) + génération via Ollama pour le codage CIM-10."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from ..config import Diagnostic, DossierMedical, RAGSource
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Configuration Ollama
|
||||||
|
OLLAMA_URL = "http://localhost:11434/api/generate"
|
||||||
|
OLLAMA_MODEL = "mistral-small3.2:24b"
|
||||||
|
OLLAMA_TIMEOUT = 120 # secondes
|
||||||
|
|
||||||
|
# Singleton pour le modèle d'embedding (chargé une seule fois)
|
||||||
|
_embed_model = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_embed_model():
    """Return the sentence-embedding model, loading it on first use (singleton)."""
    global _embed_model
    if _embed_model is not None:
        return _embed_model

    from sentence_transformers import SentenceTransformer

    logger.info("Chargement du modèle d'embedding pour la recherche...")
    _embed_model = SentenceTransformer("dangvantuan/sentence-camembert-large", device="cpu")
    _embed_model.max_seq_length = 512  # CamemBERT position-embedding limit
    return _embed_model
||||||
|
|
||||||
|
|
||||||
|
def search_similar(query: str, top_k: int = 5) -> list[dict]:
    """Search the FAISS index for the passages most similar to *query*.

    Args:
        query: Diagnosis text to look up.
        top_k: Maximum number of results to return.

    Returns:
        List of metadata dicts, each augmented with a "score" value
        (inner product over normalized vectors, i.e. cosine similarity).
        Empty list when the index is unavailable or empty.
    """
    import numpy as np

    from .rag_index import get_index

    loaded = get_index()
    if loaded is None:
        logger.warning("Index FAISS non disponible")
        return []
    faiss_index, metadata = loaded

    # Fix: guard against an empty index — searching with k=0 is pointless
    # and faiss can error out on it.
    if faiss_index.ntotal == 0:
        return []

    model = _get_embed_model()
    query_vec = model.encode([query], normalize_embeddings=True)
    query_vec = np.array(query_vec, dtype=np.float32)

    scores, indices = faiss_index.search(query_vec, min(top_k, faiss_index.ntotal))

    results: list[dict] = []
    for score, idx in zip(scores[0], indices[0]):
        if idx < 0:  # FAISS pads missing results with -1
            continue
        hit = metadata[idx].copy()
        hit["score"] = float(score)
        results.append(hit)

    return results
||||||
|
|
||||||
|
|
||||||
|
def _build_prompt(texte: str, sources: list[dict], contexte: dict) -> str:
|
||||||
|
"""Construit le prompt pour Ollama."""
|
||||||
|
sources_text = ""
|
||||||
|
for i, src in enumerate(sources, 1):
|
||||||
|
doc_name = {
|
||||||
|
"cim10": "CIM-10 FR 2026",
|
||||||
|
"guide_methodo": "Guide Méthodologique MCO 2026",
|
||||||
|
"ccam": "CCAM PMSI V4 2025",
|
||||||
|
}.get(src["document"], src["document"])
|
||||||
|
|
||||||
|
code_info = f" (code: {src['code']})" if src.get("code") else ""
|
||||||
|
page_info = f" [page {src['page']}]" if src.get("page") else ""
|
||||||
|
|
||||||
|
sources_text += f"--- Source {i}: {doc_name}{code_info}{page_info} ---\n"
|
||||||
|
sources_text += (src.get("extrait", "")[:800]) + "\n\n"
|
||||||
|
|
||||||
|
ctx_parts = []
|
||||||
|
if contexte.get("sexe"):
|
||||||
|
ctx_parts.append(f"sexe: {contexte['sexe']}")
|
||||||
|
if contexte.get("age"):
|
||||||
|
ctx_parts.append(f"âge: {contexte['age']} ans")
|
||||||
|
ctx_str = ", ".join(ctx_parts) if ctx_parts else "non précisé"
|
||||||
|
|
||||||
|
return f"""Tu es un expert en codage CIM-10 pour le PMSI en France. Suggère le code CIM-10 le plus précis pour le diagnostic suivant, en te basant UNIQUEMENT sur les sources officielles fournies.
|
||||||
|
|
||||||
|
Diagnostic à coder : "{texte}"
|
||||||
|
Contexte patient : {ctx_str}
|
||||||
|
|
||||||
|
Sources de référence :
|
||||||
|
{sources_text}
|
||||||
|
Réponds UNIQUEMENT au format JSON suivant, sans texte avant ou après :
|
||||||
|
{{"code": "X99.9", "confidence": "high|medium|low", "justification": "explication courte en français"}}"""
|
||||||
|
|
||||||
|
|
||||||
|
def _call_ollama(prompt: str) -> dict | None:
    """POST the prompt to Ollama and parse the JSON object out of its reply.

    Returns the parsed dict, or None on connection/timeout/parse failure.
    """
    payload = {
        "model": OLLAMA_MODEL,
        "prompt": prompt,
        "stream": False,
        "options": {
            "temperature": 0.1,
            "num_predict": 300,
        },
    }
    try:
        response = requests.post(OLLAMA_URL, json=payload, timeout=OLLAMA_TIMEOUT)
        response.raise_for_status()
        raw = response.json().get("response", "")

        # The model may wrap the JSON in prose: keep the outermost {...} span.
        start = raw.find("{")
        end = raw.rfind("}")
        if start == -1 or end == -1:
            logger.warning("Ollama : réponse sans JSON valide : %s", raw[:200])
            return None
        return json.loads(raw[start:end + 1])

    except requests.ConnectionError:
        logger.warning("Ollama non disponible (connexion refusée)")
        return None
    except requests.Timeout:
        logger.warning("Ollama timeout après %ds", OLLAMA_TIMEOUT)
        return None
    except (requests.RequestException, json.JSONDecodeError) as e:
        logger.warning("Ollama erreur : %s", e)
        return None
||||||
|
|
||||||
|
|
||||||
|
def enrich_diagnostic(
    diagnostic: Diagnostic,
    contexte: dict,
) -> None:
    """Enrich one Diagnostic in place via RAG (FAISS retrieval + Ollama).

    Degrades gracefully: with no retrieved sources nothing changes; with
    Ollama down, the FAISS sources are kept without an LLM justification.
    """
    # Step 1 — retrieve supporting passages.
    sources = search_similar(diagnostic.texte, top_k=5)
    if not sources:
        logger.debug("Aucune source RAG trouvée pour : %s", diagnostic.texte)
        return

    # Step 2 — attach the retrieved sources (excerpt capped at 200 chars).
    diagnostic.sources_rag = [
        RAGSource(
            document=s["document"],
            page=s.get("page"),
            code=s.get("code"),
            extrait=s.get("extrait", "")[:200],
        )
        for s in sources
    ]

    # Step 3 — ask the LLM for a code + justification.
    llm_result = _call_ollama(_build_prompt(diagnostic.texte, sources, contexte))
    if not llm_result:
        logger.info("Ollama non disponible — sources FAISS conservées sans justification LLM")
        return

    if llm_result.get("code"):
        diagnostic.cim10_suggestion = llm_result["code"]
    if llm_result.get("confidence") in ("high", "medium", "low"):
        diagnostic.cim10_confidence = llm_result["confidence"]
    if llm_result.get("justification"):
        diagnostic.justification = llm_result["justification"]
||||||
|
|
||||||
|
|
||||||
|
def enrich_dossier(dossier: DossierMedical) -> None:
    """Run RAG enrichment on the main diagnosis (DP) and every associated one (DAS)."""
    contexte = {
        "sexe": dossier.sejour.sexe,
        "age": dossier.sejour.age,
    }

    dp = dossier.diagnostic_principal
    if dp:
        logger.info("RAG enrichissement DP : %s", dp.texte)
        enrich_diagnostic(dp, contexte)

    for das in dossier.diagnostics_associes:
        logger.info("RAG enrichissement DAS : %s", das.texte)
        enrich_diagnostic(das, contexte)
||||||
271
tests/test_rag.py
Normal file
271
tests/test_rag.py
Normal file
@@ -0,0 +1,271 @@
|
|||||||
|
"""Tests pour le RAG CIM-10 (modèles, chunking, intégration)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.config import RAGSource, Diagnostic, DossierMedical, CIM10_PDF, GUIDE_METHODO_PDF, CCAM_PDF
|
||||||
|
|
||||||
|
|
||||||
|
class TestRAGSource:
    """Unit tests for the RAGSource pydantic model."""

    def test_create_minimal(self):
        source = RAGSource(document="cim10")
        assert source.document == "cim10"
        assert source.page is None
        assert source.code is None
        assert source.extrait is None

    def test_create_full(self):
        source = RAGSource(
            document="guide_methodo",
            page=42,
            code="K85",
            extrait="Pancréatite aiguë biliaire...",
        )
        assert source.document == "guide_methodo"
        assert source.page == 42
        assert source.code == "K85"
        assert source.extrait == "Pancréatite aiguë biliaire..."

    def test_serialization(self):
        source = RAGSource(document="ccam", page=1, code="HMFC004")
        dumped = source.model_dump(exclude_none=True)
        assert dumped == {"document": "ccam", "page": 1, "code": "HMFC004"}
|
|
||||||
|
|
||||||
|
class TestDiagnosticExtended:
    """Tests for the RAG fields added to the Diagnostic model."""

    def test_backward_compatible(self):
        """The new fields are optional — existing call sites keep working."""
        diag = Diagnostic(texte="Pancréatite aiguë", cim10_suggestion="K85.9")
        assert diag.texte == "Pancréatite aiguë"
        assert diag.cim10_suggestion == "K85.9"
        assert diag.cim10_confidence is None
        assert diag.justification is None
        assert diag.sources_rag == []

    def test_with_rag_fields(self):
        diag = Diagnostic(
            texte="Lithiase cholédoque",
            cim10_suggestion="K80.5",
            cim10_confidence="high",
            justification="Code K80.5 correspond à la lithiase du cholédoque",
            sources_rag=[
                RAGSource(document="cim10", page=480, code="K80"),
            ],
        )
        assert diag.cim10_confidence == "high"
        assert diag.justification is not None
        assert len(diag.sources_rag) == 1
        assert diag.sources_rag[0].code == "K80"

    def test_serialization_exclude_none(self):
        """exclude_none drops unset optional fields from the dump."""
        diag = Diagnostic(texte="Test", cim10_suggestion="K85.9")
        dumped = diag.model_dump(exclude_none=True)
        assert "cim10_confidence" not in dumped
        assert "justification" not in dumped
        assert "sources_rag" in dumped  # empty list is still included

    def test_dossier_with_extended_diagnostic(self):
        """A DossierMedical carrying RAG-enriched diagnoses."""
        dossier = DossierMedical(
            diagnostic_principal=Diagnostic(
                texte="Pancréatite aiguë biliaire",
                cim10_suggestion="K85.1",
                cim10_confidence="high",
                justification="Confirmé par CIM-10 FR 2026",
                sources_rag=[
                    RAGSource(document="cim10", page=496, code="K85"),
                    RAGSource(document="guide_methodo", page=30),
                ],
            ),
        )
        assert dossier.diagnostic_principal.cim10_confidence == "high"
        assert len(dossier.diagnostic_principal.sources_rag) == 2
||||||
|
|
||||||
|
|
||||||
|
class TestExtractMedicalInfoRAGFlag:
    """Tests for the use_rag flag of extract_medical_info."""

    @staticmethod
    def _minimal_parsed():
        """Smallest parsed-document payload accepted by the extractor."""
        return {
            "type": "crh",
            "patient": {"sexe": "M"},
            "sejour": {},
            "diagnostics": [],
        }

    def test_use_rag_false_no_change(self):
        """use_rag=False leaves the existing behaviour untouched."""
        from src.medical.cim10_extractor import extract_medical_info

        text = "Pancréatite aiguë biliaire.\nTTT de sortie :\nParacétamol\n\nDevenir : retour."
        dossier = extract_medical_info(self._minimal_parsed(), text, use_rag=False)

        assert dossier.diagnostic_principal is not None
        assert dossier.diagnostic_principal.cim10_suggestion == "K85.1"
        # No RAG enrichment happened: no sources and no justification.
        assert dossier.diagnostic_principal.sources_rag == []
        assert dossier.diagnostic_principal.justification is None

    def test_use_rag_true_calls_enrich(self):
        """use_rag=True routes the dossier through _enrich_with_rag (mocked)."""
        from src.medical.cim10_extractor import extract_medical_info

        text = "Pancréatite aiguë biliaire.\nTTT de sortie :\nParacétamol\n\nDevenir : retour."
        with patch("src.medical.cim10_extractor._enrich_with_rag") as mock_enrich:
            dossier = extract_medical_info(self._minimal_parsed(), text, use_rag=True)
            mock_enrich.assert_called_once_with(dossier)

    def test_use_rag_default_false(self):
        """When the flag is omitted, RAG enrichment must not run."""
        from src.medical.cim10_extractor import extract_medical_info

        with patch("src.medical.cim10_extractor._enrich_with_rag") as mock_enrich:
            extract_medical_info(self._minimal_parsed(), "Test simple.")
            mock_enrich.assert_not_called()
||||||
|
class TestChunkingCIM10:
    """Tests for CIM-10 PDF chunking (require the source PDF on disk)."""

    @pytest.mark.skipif(
        not CIM10_PDF.exists(),
        reason=f"PDF CIM-10 non trouvé : {CIM10_PDF}",
    )
    def test_chunks_contain_known_codes(self):
        """Chunking must surface well-known CIM-10 codes."""
        from src.medical.rag_index import _chunk_cim10

        chunks = _chunk_cim10(CIM10_PDF)
        assert len(chunks) > 100, f"Trop peu de chunks : {len(chunks)}"

        extracted_codes = {chunk.code for chunk in chunks if chunk.code}
        for code, label in (
            ("K85", "pancréatite"),
            ("K80", "lithiase biliaire"),
            ("E66", "obésité"),
        ):
            assert code in extracted_codes, f"{code} ({label}) non trouvé"

    @pytest.mark.skipif(
        not CIM10_PDF.exists(),
        reason=f"PDF CIM-10 non trouvé : {CIM10_PDF}",
    )
    def test_chunk_content(self):
        """The K85 chunk text actually mentions pancreatitis."""
        from src.medical.rag_index import _chunk_cim10

        k85_chunks = [c for c in _chunk_cim10(CIM10_PDF) if c.code == "K85"]
        assert len(k85_chunks) >= 1
        lowered = k85_chunks[0].text.lower()
        assert "pancréatite" in lowered or "pancreatite" in lowered
||||||
|
class TestChunkingGuideMethodo:
    """Tests for the methodological guide chunking (requires the PDF)."""

    @pytest.mark.skipif(
        not GUIDE_METHODO_PDF.exists(),
        reason=f"PDF Guide Métho non trouvé : {GUIDE_METHODO_PDF}",
    )
    def test_chunks_extracted(self):
        """Enough chunks come out and all carry the right document tag."""
        from src.medical.rag_index import _chunk_guide_methodo

        chunks = _chunk_guide_methodo(GUIDE_METHODO_PDF)
        assert len(chunks) >= 10, f"Trop peu de chunks : {len(chunks)}"
        assert all(chunk.document == "guide_methodo" for chunk in chunks)
||||||
|
class TestChunkingCCAM:
    """Tests for CCAM PDF chunking (requires the source PDF on disk)."""

    @pytest.mark.skipif(
        not CCAM_PDF.exists(),
        reason=f"PDF CCAM non trouvé : {CCAM_PDF}",
    )
    def test_chunks_extracted(self):
        """At least one chunk is produced and all are tagged as 'ccam'."""
        from src.medical.rag_index import _chunk_ccam

        chunks = _chunk_ccam(CCAM_PDF)
        # Fix: the assert message was an f-string without placeholders (F541);
        # a plain string literal yields the exact same message.
        assert len(chunks) >= 1, "Aucun chunk CCAM extrait"
        assert all(c.document == "ccam" for c in chunks)
||||||
|
class TestRAGSearchMocked:
    """Tests for rag_search with the FAISS index and Ollama fully mocked."""

    def test_search_similar_no_index(self):
        """search_similar returns an empty list when no index exists."""
        from src.medical.rag_search import search_similar

        with patch("src.medical.rag_index.get_index", return_value=None):
            assert search_similar("pancréatite aiguë") == []

    def test_enrich_diagnostic_no_sources(self):
        """enrich_diagnostic must not crash when no source is found."""
        from src.medical.rag_search import enrich_diagnostic

        diag = Diagnostic(texte="test quelconque", cim10_suggestion="Z99.9")
        with patch("src.medical.rag_search.search_similar", return_value=[]):
            enrich_diagnostic(diag, {"sexe": "M", "age": 50})

        assert diag.sources_rag == []
        assert diag.justification is None

    def test_enrich_diagnostic_with_sources_no_ollama(self):
        """FAISS sources are attached even when Ollama is unavailable."""
        from src.medical.rag_search import enrich_diagnostic

        diag = Diagnostic(texte="Pancréatite aiguë", cim10_suggestion="K85.9")
        faiss_hits = [
            {
                "document": "cim10",
                "page": 496,
                "code": "K85",
                "extrait": "K85 Pancréatite aiguë...",
                "score": 0.92,
            },
        ]

        with patch("src.medical.rag_search.search_similar", return_value=faiss_hits), \
             patch("src.medical.rag_search._call_ollama", return_value=None):
            enrich_diagnostic(diag, {"sexe": "M", "age": 50})

        assert len(diag.sources_rag) == 1
        assert diag.sources_rag[0].document == "cim10"
        assert diag.sources_rag[0].code == "K85"
        # Ollama returned nothing, so no justification was produced.
        assert diag.justification is None

    def test_enrich_diagnostic_with_ollama(self):
        """Full enrichment: FAISS sources plus an Ollama answer."""
        from src.medical.rag_search import enrich_diagnostic

        diag = Diagnostic(texte="Pancréatite aiguë biliaire")
        faiss_hits = [
            {
                "document": "cim10",
                "page": 496,
                "code": "K85",
                "extrait": "K85 Pancréatite aiguë...",
                "score": 0.95,
            },
        ]
        llm_answer = {
            "code": "K85.1",
            "confidence": "high",
            "justification": "Pancréatite aiguë d'origine biliaire = K85.1",
        }

        with patch("src.medical.rag_search.search_similar", return_value=faiss_hits), \
             patch("src.medical.rag_search._call_ollama", return_value=llm_answer):
            enrich_diagnostic(diag, {"sexe": "F", "age": 43})

        assert diag.cim10_suggestion == "K85.1"
        assert diag.cim10_confidence == "high"
        assert diag.justification == "Pancréatite aiguë d'origine biliaire = K85.1"
        assert len(diag.sources_rag) == 1
||||||
Reference in New Issue
Block a user