feat: add CIM-10 RAG with FAISS + Ollama

Implements a RAG (Retrieval Augmented Generation) system that indexes the ATIH reference documents (CIM-10 FR 2026, Guide Métho MCO, CCAM PMSI) and uses Ollama (mistral-small3.2:24b) to justify and validate the CIM-10 coding of diagnoses.

- New Pydantic models: RAGSource, extended Diagnostic (confidence, justification, sources_rag), backward compatible
- Module rag_index.py: chunking of the 3 PDFs, sentence-camembert-large embeddings, FAISS IndexFlatIP index (3630 vectors)
- Module rag_search.py: FAISS search + Ollama call with a two-level fallback
- CLI flag --no-rag to disable RAG enrichment
- 18 new tests (88/88 passing)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
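For orientation, a minimal usage sketch. `build_index()` and the `--no-rag` flag are introduced by this commit; invoking the entry point as `python -m src.main` is an assumption about how `src/main.py` is exposed.

```python
# Minimal usage sketch (paths illustrative).
from src.medical.rag_index import build_index

build_index()             # one-off: chunks the 3 ATIH PDFs, writes data/rag_index/
build_index(force=True)   # rebuild even if faiss.index / metadata.json already exist

# Then, per the argparse wiring in src/main.py (invocation form assumed):
#   python -m src.main input/dossier.pdf            # RAG enrichment on by default
#   python -m src.main input/dossier.pdf --no-rag   # FAISS + Ollama disabled
```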
.gitignore (vendored, 1 change)
@@ -6,3 +6,4 @@ __pycache__/
 output/
 input/
 *.egg-info/
+data/
requirements.txt

@@ -1,8 +1,12 @@
 pdfplumber>=0.10.0
-transformers>=4.35.0
+transformers>=4.35.0,<5.0.0
 torch>=2.1.0
 protobuf>=3.20.0,<4.0.0
 regex>=2023.0
 pydantic>=2.5.0
 pytest>=7.4.0
 sentencepiece>=0.1.99,<0.2.0
+edsnlp[ml]>=0.17.0
+faiss-cpu>=1.7.0
+sentence-transformers>=2.2.0
+requests>=2.28.0
src/config.py

@@ -28,9 +28,24 @@ NER_MODEL = "Jean-Baptiste/camembert-ner"
 NER_CONFIDENCE_THRESHOLD = 0.80


+# --- RAG configuration ---
+
+RAG_INDEX_DIR = BASE_DIR / "data" / "rag_index"
+CIM10_PDF = Path("/home/dom/ai/aivanov_CIM/cim-10-fr_2026_a_usage_pmsi_version_provisoire_111225.pdf")
+GUIDE_METHODO_PDF = Path("/home/dom/ai/aivanov_CIM/guide_methodo_mco_2026_version_provisoire.pdf")
+CCAM_PDF = Path("/home/dom/ai/aivanov_CIM/actualisation_ccam_descriptive_a_usage_pmsi_v4_2025.pdf")
+
+
 # --- CIM-10 data models ---


+class RAGSource(BaseModel):
+    document: str
+    page: Optional[int] = None
+    code: Optional[str] = None
+    extrait: Optional[str] = None
+
+
 class Sejour(BaseModel):
     sexe: Optional[str] = None
     age: Optional[int] = None

@@ -47,6 +62,9 @@ class Sejour(BaseModel):
 class Diagnostic(BaseModel):
     texte: str
     cim10_suggestion: Optional[str] = None
+    cim10_confidence: Optional[str] = None
+    justification: Optional[str] = None
+    sources_rag: list[RAGSource] = Field(default_factory=list)


 class ActeCCAM(BaseModel):
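The new fields are optional with None/empty defaults, so pre-RAG payloads keep their old shape. A quick sketch of an enriched diagnosis and its serialization (the values are illustrative, mirroring the tests below):

```python
from src.config import Diagnostic, RAGSource

d = Diagnostic(
    texte="Pancréatite aiguë biliaire",
    cim10_suggestion="K85.1",
    cim10_confidence="high",
    justification="Confirmé par CIM-10 FR 2026",
    sources_rag=[RAGSource(document="cim10", page=496, code="K85")],
)
# exclude_none drops unset Optional fields, so documents produced without RAG
# serialize exactly as before this commit.
print(d.model_dump(exclude_none=True))
```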
src/main.py (15 changes)
@@ -22,8 +22,9 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)

-# Global flag to disable edsnlp
+# Global flags
 _use_edsnlp = True
+_use_rag = True


 def process_pdf(pdf_path: Path) -> tuple[str, DossierMedical, AnonymizationReport]:

@@ -63,7 +64,7 @@ def process_pdf(pdf_path: Path) -> tuple[str, DossierMedical, AnonymizationRepor
     edsnlp_result = _run_edsnlp(anonymized_text)

     # 6. CIM-10 medical extraction
-    dossier = extract_medical_info(parsed, anonymized_text, edsnlp_result)
+    dossier = extract_medical_info(parsed, anonymized_text, edsnlp_result, use_rag=_use_rag)
     dossier.source_file = pdf_path.name
     dossier.document_type = doc_type
     logger.info(" DP : %s", dossier.diagnostic_principal)

@@ -123,7 +124,7 @@ def write_outputs(

 def main(input_path: str | None = None) -> None:
     """Main entry point."""
-    global _use_edsnlp
+    global _use_edsnlp, _use_rag

     parser = argparse.ArgumentParser(
         description="Anonymisation de documents médicaux PDF et extraction CIM-10",

@@ -144,6 +145,11 @@ def main(input_path: str | None = None) -> None:
         action="store_true",
         help="Désactiver l'analyse edsnlp (mode regex seul)",
     )
+    parser.add_argument(
+        "--no-rag",
+        action="store_true",
+        help="Désactiver l'enrichissement RAG (FAISS + Ollama)",
+    )
     args = parser.parse_args()

     if args.no_ner:

@@ -154,6 +160,9 @@ def main(input_path: str | None = None) -> None:
     if args.no_edsnlp:
         _use_edsnlp = False

+    if args.no_rag:
+        _use_rag = False
+
     input_p = Path(args.input)
     if input_p.is_file():
         pdfs = [input_p]
src/medical/cim10_extractor.py

@@ -2,10 +2,13 @@

 from __future__ import annotations

+import logging
 import re
 from datetime import datetime
 from typing import Optional

+logger = logging.getLogger(__name__)
+
 from ..config import (
     ActeCCAM,
     BiologieCle,

@@ -91,6 +94,7 @@ def extract_medical_info(
     parsed_data: dict,
     anonymized_text: str,
     edsnlp_result: Optional[EdsnlpResult] = None,
+    use_rag: bool = False,
 ) -> DossierMedical:
     """Extracts structured medical information from the parsed data and the text."""
     dossier = DossierMedical()

@@ -105,9 +109,23 @@
     _extract_imagerie(anonymized_text, dossier)
     _extract_complications(anonymized_text, dossier, edsnlp_result)

+    if use_rag:
+        _enrich_with_rag(dossier)
+
     return dossier


+def _enrich_with_rag(dossier: DossierMedical) -> None:
+    """Enriches the diagnoses via the RAG (FAISS + Ollama)."""
+    try:
+        from .rag_search import enrich_dossier
+        enrich_dossier(dossier)
+    except ImportError:
+        logger.warning("Module RAG non disponible (faiss-cpu ou sentence-transformers manquant)")
+    except Exception:
+        logger.warning("Erreur lors de l'enrichissement RAG", exc_info=True)
+
+
 def _extract_sejour(parsed: dict, dossier: DossierMedical) -> None:
     """Extracts the stay information."""
     patient = parsed.get("patient", {})
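Enrichment stays opt-in and fails soft, as the two except clauses above show. A small sketch of the call-site behaviour (the parsed dict is the minimal shape used by the tests below):

```python
from src.medical.cim10_extractor import extract_medical_info

parsed = {"type": "crh", "patient": {"sexe": "M"}, "sejour": {}, "diagnostics": []}
text = "Pancréatite aiguë biliaire."

# use_rag defaults to False: behaviour identical to the pre-RAG pipeline.
dossier = extract_medical_info(parsed, text)

# With use_rag=True, a missing faiss-cpu install or a broken index only logs a
# warning; the dossier built by the regex/NER extractors is returned unchanged.
dossier = extract_medical_info(parsed, text, use_rag=True)
```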
src/medical/rag_index.py (new file, 352 lines)
@@ -0,0 +1,352 @@
"""FAISS indexing of the CIM-10 / methodology guide / CCAM reference documents."""

from __future__ import annotations

import json
import logging
import re
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Optional

import pdfplumber

from ..config import RAG_INDEX_DIR, CIM10_PDF, GUIDE_METHODO_PDF, CCAM_PDF

logger = logging.getLogger(__name__)

# Singleton for the index loaded in memory
_faiss_index = None
_metadata: list[dict] = []


@dataclass
class Chunk:
    text: str
    document: str  # "cim10", "guide_methodo", "ccam"
    page: Optional[int] = None
    code: Optional[str] = None


# ---------------------------------------------------------------------------
# CIM-10 chunking
# ---------------------------------------------------------------------------

def _chunk_cim10(pdf_path: Path) -> list[Chunk]:
    """Splits the CIM-10 PDF into chunks per 3-character code (e.g. K80, K85)."""
    chunks: list[Chunk] = []
    current_code: str | None = None
    current_text: list[str] = []
    current_page: int | None = None

    # Pattern to detect a 3-character CIM-10 code at the start of a line
    code3_pattern = re.compile(r"^([A-Z]\d{2})\s+(.+)")
    # Pattern for subcodes (e.g. K80.0, K80.1)
    subcode_pattern = re.compile(r"^([A-Z]\d{2}\.\d+)\s+(.+)")

    logger.info("Extraction des chunks CIM-10 depuis %s", pdf_path.name)

    with pdfplumber.open(pdf_path) as pdf:
        for page_num, page in enumerate(pdf.pages, start=1):
            text = page.extract_text()
            if not text:
                continue

            for line in text.split("\n"):
                line = line.strip()
                if not line:
                    continue

                m = code3_pattern.match(line)
                if m and not subcode_pattern.match(line):
                    # New 3-char code: save the previous chunk
                    if current_code and current_text:
                        chunk_text = "\n".join(current_text)
                        if len(chunk_text.split()) >= 5:
                            chunks.append(Chunk(
                                text=chunk_text,
                                document="cim10",
                                page=current_page,
                                code=current_code,
                            ))
                    current_code = m.group(1)
                    current_text = [line]
                    current_page = page_num
                else:
                    if current_code:
                        current_text.append(line)

    # Last chunk
    if current_code and current_text:
        chunk_text = "\n".join(current_text)
        if len(chunk_text.split()) >= 5:
            chunks.append(Chunk(
                text=chunk_text,
                document="cim10",
                page=current_page,
                code=current_code,
            ))

    logger.info("CIM-10 : %d chunks extraits", len(chunks))
    return chunks
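A standalone check of the line classifier on illustrative lines. Note that the `\s+` in code3_pattern already rejects subcode lines such as "K85.1 ...", so the explicit subcode_pattern test acts as a defensive guard for oddly extracted lines:

```python
import re

code3_pattern = re.compile(r"^([A-Z]\d{2})\s+(.+)")
subcode_pattern = re.compile(r"^([A-Z]\d{2}\.\d+)\s+(.+)")

def starts_new_chunk(line: str) -> bool:
    # Same test as in _chunk_cim10
    return bool(code3_pattern.match(line)) and not subcode_pattern.match(line)

print(starts_new_chunk("K85 Pancréatite aiguë"))       # True: opens the K85 chunk
print(starts_new_chunk("K85.1 Pancréatite biliaire"))  # False: stays inside K85
print(starts_new_chunk("Sans code en tête de ligne"))  # False: continuation line
```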
# ---------------------------------------------------------------------------
# MCO Methodology Guide chunking
# ---------------------------------------------------------------------------

def _chunk_guide_methodo(pdf_path: Path) -> list[Chunk]:
    """Splits the MCO Methodology Guide by sections/headings."""
    chunks: list[Chunk] = []
    current_title: str | None = None
    current_text: list[str] = []
    current_page: int | None = None

    # Section heading patterns (chapters, sub-chapters)
    title_patterns = [
        re.compile(r"^((?:CHAPITRE|TITRE|PARTIE)\s+[IVXLCDM0-9]+.*)$", re.IGNORECASE),
        re.compile(r"^(\d+\.\d*\s+[A-ZÉÈÊÀÂÔÙÛÜ].{5,})$"),
        re.compile(r"^([A-ZÉÈÊÀÂÔÙÛÜ][A-ZÉÈÊÀÂÔÙÛÜ\s]{10,})$"),
    ]

    logger.info("Extraction des chunks Guide Métho depuis %s", pdf_path.name)

    with pdfplumber.open(pdf_path) as pdf:
        for page_num, page in enumerate(pdf.pages, start=1):
            text = page.extract_text()
            if not text:
                continue

            for line in text.split("\n"):
                line = line.strip()
                if not line:
                    continue

                is_title = False
                for pat in title_patterns:
                    if pat.match(line):
                        is_title = True
                        break

                if is_title and len(line) > 8:
                    # Save the previous chunk
                    if current_title and current_text:
                        chunk_text = current_title + "\n" + "\n".join(current_text)
                        if len(chunk_text.split()) >= 20:
                            chunks.append(Chunk(
                                text=chunk_text,
                                document="guide_methodo",
                                page=current_page,
                            ))
                    current_title = line
                    current_text = []
                    current_page = page_num
                else:
                    current_text.append(line)

    # Last chunk
    if current_title and current_text:
        chunk_text = current_title + "\n" + "\n".join(current_text)
        if len(chunk_text.split()) >= 20:
            chunks.append(Chunk(
                text=chunk_text,
                document="guide_methodo",
                page=current_page,
            ))

    # If there are too few chunks (the PDF does not follow the heading patterns),
    # fall back to splitting by groups of 3 pages
    if len(chunks) < 10:
        logger.info("Guide Métho : fallback découpe par pages (peu de titres détectés)")
        chunks = []
        with pdfplumber.open(pdf_path) as pdf:
            page_texts: list[str] = []
            start_page = 1
            for page_num, page in enumerate(pdf.pages, start=1):
                text = page.extract_text()
                if text:
                    page_texts.append(text)
                if len(page_texts) >= 3:
                    combined = "\n".join(page_texts)
                    if len(combined.split()) >= 20:
                        chunks.append(Chunk(
                            text=combined,
                            document="guide_methodo",
                            page=start_page,
                        ))
                    page_texts = []
                    start_page = page_num + 1
            if page_texts:
                combined = "\n".join(page_texts)
                if len(combined.split()) >= 20:
                    chunks.append(Chunk(
                        text=combined,
                        document="guide_methodo",
                        page=start_page,
                    ))

    logger.info("Guide Métho : %d chunks extraits", len(chunks))
    return chunks
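The three patterns catch, in order: CHAPITRE/TITRE/PARTIE headings, decimal-numbered sub-headings, and long ALL-CAPS lines. A standalone check on illustrative lines:

```python
import re

title_patterns = [
    re.compile(r"^((?:CHAPITRE|TITRE|PARTIE)\s+[IVXLCDM0-9]+.*)$", re.IGNORECASE),
    re.compile(r"^(\d+\.\d*\s+[A-ZÉÈÊÀÂÔÙÛÜ].{5,})$"),
    re.compile(r"^([A-ZÉÈÊÀÂÔÙÛÜ][A-ZÉÈÊÀÂÔÙÛÜ\s]{10,})$"),
]

samples = [
    "CHAPITRE IV Diagnostic principal",  # matches pattern 1
    "3.2 Règles de codage du DP",        # matches pattern 2
    "DIAGNOSTIC PRINCIPAL",              # matches pattern 3 (all caps, > 10 chars)
    "Une phrase ordinaire du guide.",    # no match: treated as body text
]
for line in samples:
    print(line, "->", any(p.match(line) for p in title_patterns))
```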
# ---------------------------------------------------------------------------
# CCAM chunking
# ---------------------------------------------------------------------------

def _chunk_ccam(pdf_path: Path) -> list[Chunk]:
    """Splits the CCAM PDF into chunks per procedure code."""
    chunks: list[Chunk] = []
    ccam_pattern = re.compile(r"([A-Z]{4}\d{3})\s+(.*)")

    logger.info("Extraction des chunks CCAM depuis %s", pdf_path.name)

    with pdfplumber.open(pdf_path) as pdf:
        for page_num, page in enumerate(pdf.pages, start=1):
            text = page.extract_text()
            if not text:
                continue

            current_code: str | None = None
            current_lines: list[str] = []

            for line in text.split("\n"):
                line = line.strip()
                if not line:
                    continue

                m = ccam_pattern.match(line)
                if m:
                    if current_code and current_lines:
                        chunks.append(Chunk(
                            text="\n".join(current_lines),
                            document="ccam",
                            page=page_num,
                            code=current_code,
                        ))
                    current_code = m.group(1)
                    current_lines = [line]
                elif current_code:
                    current_lines.append(line)

            if current_code and current_lines:
                chunks.append(Chunk(
                    text="\n".join(current_lines),
                    document="ccam",
                    page=page_num,
                    code=current_code,
                ))

    # Fallback: if no CCAM code was detected, index page by page
    if not chunks:
        logger.info("CCAM : aucun code détecté, fallback par page")
        with pdfplumber.open(pdf_path) as pdf:
            for page_num, page in enumerate(pdf.pages, start=1):
                text = page.extract_text()
                if text and len(text.split()) >= 10:
                    chunks.append(Chunk(
                        text=text,
                        document="ccam",
                        page=page_num,
                    ))

    logger.info("CCAM : %d chunks extraits", len(chunks))
    return chunks


# ---------------------------------------------------------------------------
# FAISS index construction
# ---------------------------------------------------------------------------

def build_index(force: bool = False) -> None:
    """Builds the FAISS index from the 3 reference PDFs.

    Args:
        force: If True, rebuilds even if the index already exists.
    """
    import faiss
    import numpy as np
    from sentence_transformers import SentenceTransformer

    index_path = RAG_INDEX_DIR / "faiss.index"
    meta_path = RAG_INDEX_DIR / "metadata.json"

    if not force and index_path.exists() and meta_path.exists():
        logger.info("Index FAISS déjà existant dans %s (use force=True pour reconstruire)", RAG_INDEX_DIR)
        return

    # Collect all the chunks
    all_chunks: list[Chunk] = []

    for pdf_path, chunk_fn in [
        (CIM10_PDF, _chunk_cim10),
        (GUIDE_METHODO_PDF, _chunk_guide_methodo),
        (CCAM_PDF, _chunk_ccam),
    ]:
        if pdf_path.exists():
            all_chunks.extend(chunk_fn(pdf_path))
        else:
            logger.warning("PDF non trouvé : %s", pdf_path)

    if not all_chunks:
        logger.error("Aucun chunk extrait — vérifiez les chemins des PDFs")
        return

    logger.info("Total : %d chunks à indexer", len(all_chunks))

    # Embeddings: force CPU to avoid CUDA bugs with this model
    logger.info("Chargement du modèle d'embedding dangvantuan/sentence-camembert-large (CPU)...")
    model = SentenceTransformer("dangvantuan/sentence-camembert-large", device="cpu")
    model.max_seq_length = 512  # CamemBERT max position embeddings

    texts = [c.text[:2000] for c in all_chunks]  # Truncate overly long chunks
    logger.info("Calcul des embeddings pour %d chunks...", len(texts))
    embeddings = model.encode(
        texts, show_progress_bar=True, normalize_embeddings=True, batch_size=64,
    )
    embeddings = np.array(embeddings, dtype=np.float32)

    # FAISS index (IndexFlatIP = cosine similarity with normalized vectors)
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings)

    # Save
    RAG_INDEX_DIR.mkdir(parents=True, exist_ok=True)
    faiss.write_index(index, str(index_path))

    metadata = [asdict(c) for c in all_chunks]
    # Do not store the full text in the metadata (too heavy),
    # keep a 500-char excerpt
    for m in metadata:
        m["extrait"] = m.pop("text")[:500]

    meta_path.write_text(json.dumps(metadata, ensure_ascii=False, indent=2), encoding="utf-8")

    logger.info("Index FAISS sauvegardé : %s (%d vecteurs, dim=%d)", index_path, len(all_chunks), dim)
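Because the embeddings are L2-normalized, inner product equals cosine similarity, which is what makes IndexFlatIP appropriate here. A self-contained sanity check with random vectors (no model download required):

```python
import faiss
import numpy as np

rng = np.random.default_rng(0)
vecs = rng.normal(size=(100, 8)).astype(np.float32)
faiss.normalize_L2(vecs)      # in-place L2 norm, like normalize_embeddings=True

index = faiss.IndexFlatIP(8)  # exact inner-product search
index.add(vecs)

query = vecs[:1].copy()
scores, ids = index.search(query, 3)
print(ids[0], scores[0])      # ids[0][0] == 0, scores[0][0] ~ 1.0 (self-match)
```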
def get_index() -> tuple | None:
    """Loads the FAISS index and the metadata (lazy-loaded singleton).

    Returns:
        Tuple (faiss_index, metadata_list), or None if the index does not exist.
    """
    global _faiss_index, _metadata

    if _faiss_index is not None:
        return _faiss_index, _metadata

    index_path = RAG_INDEX_DIR / "faiss.index"
    meta_path = RAG_INDEX_DIR / "metadata.json"

    if not index_path.exists() or not meta_path.exists():
        logger.warning("Index FAISS non trouvé dans %s — lancez build_index() d'abord", RAG_INDEX_DIR)
        return None

    import faiss

    _faiss_index = faiss.read_index(str(index_path))
    _metadata = json.loads(meta_path.read_text(encoding="utf-8"))

    logger.info("Index FAISS chargé : %d vecteurs", _faiss_index.ntotal)
    return _faiss_index, _metadata
src/medical/rag_search.py (new file, 208 lines)
@@ -0,0 +1,208 @@
"""RAG search (FAISS) + generation via Ollama for CIM-10 coding."""

from __future__ import annotations

import json
import logging
from typing import Optional

import requests

from ..config import Diagnostic, DossierMedical, RAGSource

logger = logging.getLogger(__name__)

# Ollama configuration
OLLAMA_URL = "http://localhost:11434/api/generate"
OLLAMA_MODEL = "mistral-small3.2:24b"
OLLAMA_TIMEOUT = 120  # seconds

# Singleton for the embedding model (loaded only once)
_embed_model = None


def _get_embed_model():
    """Loads the embedding model (singleton)."""
    global _embed_model
    if _embed_model is None:
        from sentence_transformers import SentenceTransformer
        logger.info("Chargement du modèle d'embedding pour la recherche...")
        _embed_model = SentenceTransformer("dangvantuan/sentence-camembert-large", device="cpu")
        _embed_model.max_seq_length = 512
    return _embed_model


def search_similar(query: str, top_k: int = 5) -> list[dict]:
    """Searches the FAISS index for the most similar passages.

    Args:
        query: Diagnosis text to search for.
        top_k: Number of results to return.

    Returns:
        List of dicts with the metadata + similarity score.
    """
    from .rag_index import get_index
    import numpy as np

    result = get_index()
    if result is None:
        logger.warning("Index FAISS non disponible")
        return []

    faiss_index, metadata = result

    model = _get_embed_model()
    query_vec = model.encode([query], normalize_embeddings=True)
    query_vec = np.array(query_vec, dtype=np.float32)

    scores, indices = faiss_index.search(query_vec, min(top_k, faiss_index.ntotal))

    results = []
    for score, idx in zip(scores[0], indices[0]):
        if idx < 0:
            continue
        meta = metadata[idx].copy()
        meta["score"] = float(score)
        results.append(meta)

    return results
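Each hit is plain chunk metadata plus the similarity score, so callers never handle FAISS types directly. A hedged usage sketch (assumes the index has been built; the printed values are illustrative):

```python
from src.medical.rag_search import search_similar

for hit in search_similar("pancréatite aiguë biliaire", top_k=3):
    # keys: document, page, code, extrait (500-char excerpt), score
    print(f"{hit['score']:.2f}  {hit['document']}  {hit.get('code')}  p.{hit.get('page')}")
# e.g. 0.92  cim10  K85  p.496
```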
def _build_prompt(texte: str, sources: list[dict], contexte: dict) -> str:
    """Builds the prompt for Ollama."""
    sources_text = ""
    for i, src in enumerate(sources, 1):
        doc_name = {
            "cim10": "CIM-10 FR 2026",
            "guide_methodo": "Guide Méthodologique MCO 2026",
            "ccam": "CCAM PMSI V4 2025",
        }.get(src["document"], src["document"])

        code_info = f" (code: {src['code']})" if src.get("code") else ""
        page_info = f" [page {src['page']}]" if src.get("page") else ""

        sources_text += f"--- Source {i}: {doc_name}{code_info}{page_info} ---\n"
        sources_text += (src.get("extrait", "")[:800]) + "\n\n"

    ctx_parts = []
    if contexte.get("sexe"):
        ctx_parts.append(f"sexe: {contexte['sexe']}")
    if contexte.get("age"):
        ctx_parts.append(f"âge: {contexte['age']} ans")
    ctx_str = ", ".join(ctx_parts) if ctx_parts else "non précisé"

    return f"""Tu es un expert en codage CIM-10 pour le PMSI en France. Suggère le code CIM-10 le plus précis pour le diagnostic suivant, en te basant UNIQUEMENT sur les sources officielles fournies.

Diagnostic à coder : "{texte}"
Contexte patient : {ctx_str}

Sources de référence :
{sources_text}
Réponds UNIQUEMENT au format JSON suivant, sans texte avant ou après :
{{"code": "X99.9", "confidence": "high|medium|low", "justification": "explication courte en français"}}"""


def _call_ollama(prompt: str) -> dict | None:
    """Calls Ollama and parses the JSON response."""
    try:
        response = requests.post(
            OLLAMA_URL,
            json={
                "model": OLLAMA_MODEL,
                "prompt": prompt,
                "stream": False,
                "options": {
                    "temperature": 0.1,
                    "num_predict": 300,
                },
            },
            timeout=OLLAMA_TIMEOUT,
        )
        response.raise_for_status()
        raw = response.json().get("response", "")

        # Extract the JSON from the response (there may be text around it)
        json_match = None
        # Look for a JSON block between braces
        brace_start = raw.find("{")
        brace_end = raw.rfind("}")
        if brace_start != -1 and brace_end != -1:
            json_match = raw[brace_start:brace_end + 1]

        if json_match:
            return json.loads(json_match)
        else:
            logger.warning("Ollama : réponse sans JSON valide : %s", raw[:200])
            return None

    except requests.ConnectionError:
        logger.warning("Ollama non disponible (connexion refusée)")
        return None
    except requests.Timeout:
        logger.warning("Ollama timeout après %ds", OLLAMA_TIMEOUT)
        return None
    except (requests.RequestException, json.JSONDecodeError) as e:
        logger.warning("Ollama erreur : %s", e)
        return None
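The brace-slicing step tolerates models that wrap the JSON in prose. A standalone illustration with a made-up model reply (a stray closing brace in trailing prose would still defeat the slice, which is why json.JSONDecodeError is caught above):

```python
import json

raw = ('Voici ma réponse :\n'
       '{"code": "K85.1", "confidence": "high", "justification": "origine biliaire"}\n'
       'Bonne journée.')

brace_start = raw.find("{")
brace_end = raw.rfind("}")
if brace_start != -1 and brace_end != -1:
    parsed = json.loads(raw[brace_start:brace_end + 1])
    print(parsed["code"], parsed["confidence"])  # K85.1 high
```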
def enrich_diagnostic(
    diagnostic: Diagnostic,
    contexte: dict,
) -> None:
    """Enriches a Diagnostic via the RAG (FAISS + Ollama).

    Modifies the diagnostic in place. Graceful fallback if FAISS or Ollama fail.
    """
    # 1. FAISS search
    sources = search_similar(diagnostic.texte, top_k=5)

    if not sources:
        logger.debug("Aucune source RAG trouvée pour : %s", diagnostic.texte)
        return

    # 2. Store the RAG sources
    diagnostic.sources_rag = [
        RAGSource(
            document=s["document"],
            page=s.get("page"),
            code=s.get("code"),
            extrait=s.get("extrait", "")[:200],
        )
        for s in sources
    ]

    # 3. Ollama call for the justification
    prompt = _build_prompt(diagnostic.texte, sources, contexte)
    llm_result = _call_ollama(prompt)

    if llm_result:
        code = llm_result.get("code")
        confidence = llm_result.get("confidence")
        justification = llm_result.get("justification")

        if code:
            diagnostic.cim10_suggestion = code
        if confidence in ("high", "medium", "low"):
            diagnostic.cim10_confidence = confidence
        if justification:
            diagnostic.justification = justification
    else:
        logger.info("Ollama non disponible — sources FAISS conservées sans justification LLM")


def enrich_dossier(dossier: DossierMedical) -> None:
    """Enriches the DP and all the DAS of a dossier via the RAG."""
    contexte = {
        "sexe": dossier.sejour.sexe,
        "age": dossier.sejour.age,
    }

    if dossier.diagnostic_principal:
        logger.info("RAG enrichissement DP : %s", dossier.diagnostic_principal.texte)
        enrich_diagnostic(dossier.diagnostic_principal, contexte)

    for das in dossier.diagnostics_associes:
        logger.info("RAG enrichissement DAS : %s", das.texte)
        enrich_diagnostic(das, contexte)
tests/test_rag.py (new file, 271 lines)
@@ -0,0 +1,271 @@
"""Tests for the CIM-10 RAG (models, chunking, integration)."""

from __future__ import annotations

from pathlib import Path
from unittest.mock import patch, MagicMock

import pytest

from src.config import RAGSource, Diagnostic, DossierMedical, CIM10_PDF, GUIDE_METHODO_PDF, CCAM_PDF


class TestRAGSource:
    def test_create_minimal(self):
        src = RAGSource(document="cim10")
        assert src.document == "cim10"
        assert src.page is None
        assert src.code is None
        assert src.extrait is None

    def test_create_full(self):
        src = RAGSource(
            document="guide_methodo",
            page=42,
            code="K85",
            extrait="Pancréatite aiguë biliaire...",
        )
        assert src.document == "guide_methodo"
        assert src.page == 42
        assert src.code == "K85"
        assert src.extrait == "Pancréatite aiguë biliaire..."

    def test_serialization(self):
        src = RAGSource(document="ccam", page=1, code="HMFC004")
        data = src.model_dump(exclude_none=True)
        assert data == {"document": "ccam", "page": 1, "code": "HMFC004"}


class TestDiagnosticExtended:
    def test_backward_compatible(self):
        """The new fields are optional: backward compatible."""
        d = Diagnostic(texte="Pancréatite aiguë", cim10_suggestion="K85.9")
        assert d.texte == "Pancréatite aiguë"
        assert d.cim10_suggestion == "K85.9"
        assert d.cim10_confidence is None
        assert d.justification is None
        assert d.sources_rag == []

    def test_with_rag_fields(self):
        d = Diagnostic(
            texte="Lithiase cholédoque",
            cim10_suggestion="K80.5",
            cim10_confidence="high",
            justification="Code K80.5 correspond à la lithiase du cholédoque",
            sources_rag=[
                RAGSource(document="cim10", page=480, code="K80"),
            ],
        )
        assert d.cim10_confidence == "high"
        assert d.justification is not None
        assert len(d.sources_rag) == 1
        assert d.sources_rag[0].code == "K80"

    def test_serialization_exclude_none(self):
        """Check that the JSON does not include None fields."""
        d = Diagnostic(texte="Test", cim10_suggestion="K85.9")
        data = d.model_dump(exclude_none=True)
        assert "cim10_confidence" not in data
        assert "justification" not in data
        assert "sources_rag" in data  # empty list is included

    def test_dossier_with_extended_diagnostic(self):
        """A DossierMedical with diagnoses enriched by the RAG."""
        dossier = DossierMedical(
            diagnostic_principal=Diagnostic(
                texte="Pancréatite aiguë biliaire",
                cim10_suggestion="K85.1",
                cim10_confidence="high",
                justification="Confirmé par CIM-10 FR 2026",
                sources_rag=[
                    RAGSource(document="cim10", page=496, code="K85"),
                    RAGSource(document="guide_methodo", page=30),
                ],
            ),
        )
        assert dossier.diagnostic_principal.cim10_confidence == "high"
        assert len(dossier.diagnostic_principal.sources_rag) == 2


class TestExtractMedicalInfoRAGFlag:
    def test_use_rag_false_no_change(self):
        """use_rag=False does not change the existing behaviour."""
        from src.medical.cim10_extractor import extract_medical_info

        parsed = {
            "type": "crh",
            "patient": {"sexe": "M"},
            "sejour": {},
            "diagnostics": [],
        }
        text = "Pancréatite aiguë biliaire.\nTTT de sortie :\nParacétamol\n\nDevenir : retour."

        dossier = extract_medical_info(parsed, text, use_rag=False)
        assert dossier.diagnostic_principal is not None
        assert dossier.diagnostic_principal.cim10_suggestion == "K85.1"
        # No RAG sources
        assert dossier.diagnostic_principal.sources_rag == []
        assert dossier.diagnostic_principal.justification is None

    def test_use_rag_true_calls_enrich(self):
        """use_rag=True calls _enrich_with_rag (mocked)."""
        from src.medical.cim10_extractor import extract_medical_info

        parsed = {
            "type": "crh",
            "patient": {"sexe": "M"},
            "sejour": {},
            "diagnostics": [],
        }
        text = "Pancréatite aiguë biliaire.\nTTT de sortie :\nParacétamol\n\nDevenir : retour."

        with patch("src.medical.cim10_extractor._enrich_with_rag") as mock_enrich:
            dossier = extract_medical_info(parsed, text, use_rag=True)
            mock_enrich.assert_called_once_with(dossier)

    def test_use_rag_default_false(self):
        """By default, use_rag=False."""
        from src.medical.cim10_extractor import extract_medical_info

        parsed = {
            "type": "crh",
            "patient": {"sexe": "M"},
            "sejour": {},
            "diagnostics": [],
        }
        text = "Test simple."

        with patch("src.medical.cim10_extractor._enrich_with_rag") as mock_enrich:
            extract_medical_info(parsed, text)
            mock_enrich.assert_not_called()


class TestChunkingCIM10:
    @pytest.mark.skipif(
        not CIM10_PDF.exists(),
        reason=f"PDF CIM-10 non trouvé : {CIM10_PDF}",
    )
    def test_chunks_contain_known_codes(self):
        from src.medical.rag_index import _chunk_cim10

        chunks = _chunk_cim10(CIM10_PDF)
        assert len(chunks) > 100, f"Trop peu de chunks : {len(chunks)}"

        codes = {c.code for c in chunks if c.code}
        assert "K85" in codes, "K85 (pancréatite) non trouvé"
        assert "K80" in codes, "K80 (lithiase biliaire) non trouvé"
        assert "E66" in codes, "E66 (obésité) non trouvé"

    @pytest.mark.skipif(
        not CIM10_PDF.exists(),
        reason=f"PDF CIM-10 non trouvé : {CIM10_PDF}",
    )
    def test_chunk_content(self):
        from src.medical.rag_index import _chunk_cim10

        chunks = _chunk_cim10(CIM10_PDF)
        k85_chunks = [c for c in chunks if c.code == "K85"]
        assert len(k85_chunks) >= 1
        assert "pancréatite" in k85_chunks[0].text.lower() or "pancreatite" in k85_chunks[0].text.lower()


class TestChunkingGuideMethodo:
    @pytest.mark.skipif(
        not GUIDE_METHODO_PDF.exists(),
        reason=f"PDF Guide Métho non trouvé : {GUIDE_METHODO_PDF}",
    )
    def test_chunks_extracted(self):
        from src.medical.rag_index import _chunk_guide_methodo

        chunks = _chunk_guide_methodo(GUIDE_METHODO_PDF)
        assert len(chunks) >= 10, f"Trop peu de chunks : {len(chunks)}"
        assert all(c.document == "guide_methodo" for c in chunks)


class TestChunkingCCAM:
    @pytest.mark.skipif(
        not CCAM_PDF.exists(),
        reason=f"PDF CCAM non trouvé : {CCAM_PDF}",
    )
    def test_chunks_extracted(self):
        from src.medical.rag_index import _chunk_ccam

        chunks = _chunk_ccam(CCAM_PDF)
        assert len(chunks) >= 1, "Aucun chunk CCAM extrait"
        assert all(c.document == "ccam" for c in chunks)


class TestRAGSearchMocked:
    def test_search_similar_no_index(self):
        """search_similar returns an empty list if the index does not exist."""
        from src.medical.rag_search import search_similar

        with patch("src.medical.rag_index.get_index", return_value=None):
            results = search_similar("pancréatite aiguë")
            assert results == []

    def test_enrich_diagnostic_no_sources(self):
        """enrich_diagnostic does not crash when no source is found."""
        from src.medical.rag_search import enrich_diagnostic

        diag = Diagnostic(texte="test quelconque", cim10_suggestion="Z99.9")

        with patch("src.medical.rag_search.search_similar", return_value=[]):
            enrich_diagnostic(diag, {"sexe": "M", "age": 50})

        assert diag.sources_rag == []
        assert diag.justification is None

    def test_enrich_diagnostic_with_sources_no_ollama(self):
        """Enrichment with FAISS sources but without Ollama."""
        from src.medical.rag_search import enrich_diagnostic

        diag = Diagnostic(texte="Pancréatite aiguë", cim10_suggestion="K85.9")
        mock_sources = [
            {
                "document": "cim10",
                "page": 496,
                "code": "K85",
                "extrait": "K85 Pancréatite aiguë...",
                "score": 0.92,
            },
        ]

        with patch("src.medical.rag_search.search_similar", return_value=mock_sources), \
             patch("src.medical.rag_search._call_ollama", return_value=None):
            enrich_diagnostic(diag, {"sexe": "M", "age": 50})

        assert len(diag.sources_rag) == 1
        assert diag.sources_rag[0].document == "cim10"
        assert diag.sources_rag[0].code == "K85"
        # No justification (Ollama unavailable)
        assert diag.justification is None

    def test_enrich_diagnostic_with_ollama(self):
        """Full enrichment with sources + Ollama."""
        from src.medical.rag_search import enrich_diagnostic

        diag = Diagnostic(texte="Pancréatite aiguë biliaire")
        mock_sources = [
            {
                "document": "cim10",
                "page": 496,
                "code": "K85",
                "extrait": "K85 Pancréatite aiguë...",
                "score": 0.95,
            },
        ]
        mock_llm = {
            "code": "K85.1",
            "confidence": "high",
            "justification": "Pancréatite aiguë d'origine biliaire = K85.1",
        }

        with patch("src.medical.rag_search.search_similar", return_value=mock_sources), \
             patch("src.medical.rag_search._call_ollama", return_value=mock_llm):
            enrich_diagnostic(diag, {"sexe": "F", "age": 43})

        assert diag.cim10_suggestion == "K85.1"
        assert diag.cim10_confidence == "high"
        assert diag.justification == "Pancréatite aiguë d'origine biliaire = K85.1"
        assert len(diag.sources_rag) == 1