feat: ajout RAG CIM-10 avec FAISS + Ollama

Implémente un système RAG (Retrieval Augmented Generation) qui indexe
les documents de référence ATIH (CIM-10 FR 2026, Guide Métho MCO,
CCAM PMSI) et utilise Ollama (mistral-small3.2:24b) pour justifier
et valider le codage CIM-10 des diagnostics.

- Nouveaux modèles Pydantic : RAGSource, Diagnostic étendu (confidence,
  justification, sources_rag) — rétrocompatible
- Module rag_index.py : chunking des 3 PDFs, embedding sentence-camembert-large,
  index FAISS IndexFlatIP (3630 vecteurs)
- Module rag_search.py : recherche FAISS + appel Ollama avec fallback double
- Flag CLI --no-rag pour désactiver l'enrichissement RAG
- 18 nouveaux tests (88/88 passent)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
dom
2026-02-10 17:47:08 +01:00
parent 4a12cd2676
commit 4d6fbef2b9
8 changed files with 885 additions and 4 deletions

View File

@@ -28,9 +28,24 @@ NER_MODEL = "Jean-Baptiste/camembert-ner"
NER_CONFIDENCE_THRESHOLD = 0.80
# --- Configuration RAG ---
RAG_INDEX_DIR = BASE_DIR / "data" / "rag_index"
CIM10_PDF = Path("/home/dom/ai/aivanov_CIM/cim-10-fr_2026_a_usage_pmsi_version_provisoire_111225.pdf")
GUIDE_METHODO_PDF = Path("/home/dom/ai/aivanov_CIM/guide_methodo_mco_2026_version_provisoire.pdf")
CCAM_PDF = Path("/home/dom/ai/aivanov_CIM/actualisation_ccam_descriptive_a_usage_pmsi_v4_2025.pdf")
# --- Modèles de données CIM-10 ---
class RAGSource(BaseModel):
document: str
page: Optional[int] = None
code: Optional[str] = None
extrait: Optional[str] = None
class Sejour(BaseModel):
sexe: Optional[str] = None
age: Optional[int] = None
@@ -47,6 +62,9 @@ class Sejour(BaseModel):
class Diagnostic(BaseModel):
texte: str
cim10_suggestion: Optional[str] = None
cim10_confidence: Optional[str] = None
justification: Optional[str] = None
sources_rag: list[RAGSource] = Field(default_factory=list)
class ActeCCAM(BaseModel):

View File

@@ -22,8 +22,9 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__)
# Flag global pour désactiver edsnlp
# Flags globaux
_use_edsnlp = True
_use_rag = True
def process_pdf(pdf_path: Path) -> tuple[str, DossierMedical, AnonymizationReport]:
@@ -63,7 +64,7 @@ def process_pdf(pdf_path: Path) -> tuple[str, DossierMedical, AnonymizationRepor
edsnlp_result = _run_edsnlp(anonymized_text)
# 6. Extraction médicale CIM-10
dossier = extract_medical_info(parsed, anonymized_text, edsnlp_result)
dossier = extract_medical_info(parsed, anonymized_text, edsnlp_result, use_rag=_use_rag)
dossier.source_file = pdf_path.name
dossier.document_type = doc_type
logger.info(" DP : %s", dossier.diagnostic_principal)
@@ -123,7 +124,7 @@ def write_outputs(
def main(input_path: str | None = None) -> None:
"""Point d'entrée principal."""
global _use_edsnlp
global _use_edsnlp, _use_rag
parser = argparse.ArgumentParser(
description="Anonymisation de documents médicaux PDF et extraction CIM-10",
@@ -144,6 +145,11 @@ def main(input_path: str | None = None) -> None:
action="store_true",
help="Désactiver l'analyse edsnlp (mode regex seul)",
)
parser.add_argument(
"--no-rag",
action="store_true",
help="Désactiver l'enrichissement RAG (FAISS + Ollama)",
)
args = parser.parse_args()
if args.no_ner:
@@ -154,6 +160,9 @@ def main(input_path: str | None = None) -> None:
if args.no_edsnlp:
_use_edsnlp = False
if args.no_rag:
_use_rag = False
input_p = Path(args.input)
if input_p.is_file():
pdfs = [input_p]

View File

@@ -2,10 +2,13 @@
from __future__ import annotations
import logging
import re
from datetime import datetime
from typing import Optional
logger = logging.getLogger(__name__)
from ..config import (
ActeCCAM,
BiologieCle,
@@ -91,6 +94,7 @@ def extract_medical_info(
parsed_data: dict,
anonymized_text: str,
edsnlp_result: Optional[EdsnlpResult] = None,
use_rag: bool = False,
) -> DossierMedical:
"""Extrait les informations médicales structurées depuis les données parsées et le texte."""
dossier = DossierMedical()
@@ -105,9 +109,23 @@ def extract_medical_info(
_extract_imagerie(anonymized_text, dossier)
_extract_complications(anonymized_text, dossier, edsnlp_result)
if use_rag:
_enrich_with_rag(dossier)
return dossier
def _enrich_with_rag(dossier: DossierMedical) -> None:
"""Enrichit les diagnostics via le RAG (FAISS + Ollama)."""
try:
from .rag_search import enrich_dossier
enrich_dossier(dossier)
except ImportError:
logger.warning("Module RAG non disponible (faiss-cpu ou sentence-transformers manquant)")
except Exception:
logger.warning("Erreur lors de l'enrichissement RAG", exc_info=True)
def _extract_sejour(parsed: dict, dossier: DossierMedical) -> None:
"""Extrait les informations de séjour."""
patient = parsed.get("patient", {})

352
src/medical/rag_index.py Normal file
View File

@@ -0,0 +1,352 @@
"""Indexation FAISS des documents de référence CIM-10 / Guide métho / CCAM."""
from __future__ import annotations
import json
import logging
import re
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Optional
import pdfplumber
from ..config import RAG_INDEX_DIR, CIM10_PDF, GUIDE_METHODO_PDF, CCAM_PDF
logger = logging.getLogger(__name__)
# Singleton pour l'index chargé en mémoire
_faiss_index = None
_metadata: list[dict] = []
@dataclass
class Chunk:
text: str
document: str # "cim10", "guide_methodo", "ccam"
page: Optional[int] = None
code: Optional[str] = None
# ---------------------------------------------------------------------------
# Chunking CIM-10
# ---------------------------------------------------------------------------
def _chunk_cim10(pdf_path: Path) -> list[Chunk]:
"""Découpe le PDF CIM-10 en chunks par code 3 caractères (ex: K80, K85)."""
chunks: list[Chunk] = []
current_code: str | None = None
current_text: list[str] = []
current_page: int | None = None
# Pattern pour détecter un code CIM-10 à 3 caractères en début de ligne
code3_pattern = re.compile(r"^([A-Z]\d{2})\s+(.+)")
# Pattern pour les sous-codes (ex: K80.0, K80.1)
subcode_pattern = re.compile(r"^([A-Z]\d{2}\.\d+)\s+(.+)")
logger.info("Extraction des chunks CIM-10 depuis %s", pdf_path.name)
with pdfplumber.open(pdf_path) as pdf:
for page_num, page in enumerate(pdf.pages, start=1):
text = page.extract_text()
if not text:
continue
for line in text.split("\n"):
line = line.strip()
if not line:
continue
m = code3_pattern.match(line)
if m and not subcode_pattern.match(line):
# Nouveau code 3-char → sauvegarder le chunk précédent
if current_code and current_text:
chunk_text = "\n".join(current_text)
if len(chunk_text.split()) >= 5:
chunks.append(Chunk(
text=chunk_text,
document="cim10",
page=current_page,
code=current_code,
))
current_code = m.group(1)
current_text = [line]
current_page = page_num
else:
if current_code:
current_text.append(line)
# Dernier chunk
if current_code and current_text:
chunk_text = "\n".join(current_text)
if len(chunk_text.split()) >= 5:
chunks.append(Chunk(
text=chunk_text,
document="cim10",
page=current_page,
code=current_code,
))
logger.info("CIM-10 : %d chunks extraits", len(chunks))
return chunks
# ---------------------------------------------------------------------------
# Chunking Guide Méthodologique MCO
# ---------------------------------------------------------------------------
def _chunk_guide_methodo(pdf_path: Path) -> list[Chunk]:
"""Découpe le Guide Méthodologique MCO par sections/titres."""
chunks: list[Chunk] = []
current_title: str | None = None
current_text: list[str] = []
current_page: int | None = None
# Patterns de titres de sections (chapitres, sous-chapitres)
title_patterns = [
re.compile(r"^((?:CHAPITRE|TITRE|PARTIE)\s+[IVXLCDM0-9]+.*)$", re.IGNORECASE),
re.compile(r"^(\d+\.\d*\s+[A-ZÉÈÊÀÂÔÙÛÜ].{5,})$"),
re.compile(r"^([A-ZÉÈÊÀÂÔÙÛÜ][A-ZÉÈÊÀÂÔÙÛÜ\s]{10,})$"),
]
logger.info("Extraction des chunks Guide Métho depuis %s", pdf_path.name)
with pdfplumber.open(pdf_path) as pdf:
for page_num, page in enumerate(pdf.pages, start=1):
text = page.extract_text()
if not text:
continue
for line in text.split("\n"):
line = line.strip()
if not line:
continue
is_title = False
for pat in title_patterns:
if pat.match(line):
is_title = True
break
if is_title and len(line) > 8:
# Sauvegarder le chunk précédent
if current_title and current_text:
chunk_text = current_title + "\n" + "\n".join(current_text)
if len(chunk_text.split()) >= 20:
chunks.append(Chunk(
text=chunk_text,
document="guide_methodo",
page=current_page,
))
current_title = line
current_text = []
current_page = page_num
else:
current_text.append(line)
# Dernier chunk
if current_title and current_text:
chunk_text = current_title + "\n" + "\n".join(current_text)
if len(chunk_text.split()) >= 20:
chunks.append(Chunk(
text=chunk_text,
document="guide_methodo",
page=current_page,
))
# Si trop peu de chunks (le PDF ne suit pas les patterns de titre),
# fallback : découper par pages groupées par 3
if len(chunks) < 10:
logger.info("Guide Métho : fallback découpe par pages (peu de titres détectés)")
chunks = []
with pdfplumber.open(pdf_path) as pdf:
page_texts: list[str] = []
start_page = 1
for page_num, page in enumerate(pdf.pages, start=1):
text = page.extract_text()
if text:
page_texts.append(text)
if len(page_texts) >= 3:
combined = "\n".join(page_texts)
if len(combined.split()) >= 20:
chunks.append(Chunk(
text=combined,
document="guide_methodo",
page=start_page,
))
page_texts = []
start_page = page_num + 1
if page_texts:
combined = "\n".join(page_texts)
if len(combined.split()) >= 20:
chunks.append(Chunk(
text=combined,
document="guide_methodo",
page=start_page,
))
logger.info("Guide Métho : %d chunks extraits", len(chunks))
return chunks
# ---------------------------------------------------------------------------
# Chunking CCAM
# ---------------------------------------------------------------------------
def _chunk_ccam(pdf_path: Path) -> list[Chunk]:
"""Découpe le PDF CCAM en chunks par code d'acte."""
chunks: list[Chunk] = []
ccam_pattern = re.compile(r"([A-Z]{4}\d{3})\s+(.*)")
logger.info("Extraction des chunks CCAM depuis %s", pdf_path.name)
with pdfplumber.open(pdf_path) as pdf:
for page_num, page in enumerate(pdf.pages, start=1):
text = page.extract_text()
if not text:
continue
current_code: str | None = None
current_lines: list[str] = []
for line in text.split("\n"):
line = line.strip()
if not line:
continue
m = ccam_pattern.match(line)
if m:
if current_code and current_lines:
chunks.append(Chunk(
text="\n".join(current_lines),
document="ccam",
page=page_num,
code=current_code,
))
current_code = m.group(1)
current_lines = [line]
elif current_code:
current_lines.append(line)
if current_code and current_lines:
chunks.append(Chunk(
text="\n".join(current_lines),
document="ccam",
page=page_num,
code=current_code,
))
# Fallback : si aucun code CCAM détecté, indexer par page
if not chunks:
logger.info("CCAM : aucun code détecté, fallback par page")
with pdfplumber.open(pdf_path) as pdf:
for page_num, page in enumerate(pdf.pages, start=1):
text = page.extract_text()
if text and len(text.split()) >= 10:
chunks.append(Chunk(
text=text,
document="ccam",
page=page_num,
))
logger.info("CCAM : %d chunks extraits", len(chunks))
return chunks
# ---------------------------------------------------------------------------
# Construction de l'index FAISS
# ---------------------------------------------------------------------------
def build_index(force: bool = False) -> None:
"""Construit l'index FAISS à partir des 3 PDFs de référence.
Args:
force: Si True, reconstruit même si l'index existe déjà.
"""
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
index_path = RAG_INDEX_DIR / "faiss.index"
meta_path = RAG_INDEX_DIR / "metadata.json"
if not force and index_path.exists() and meta_path.exists():
logger.info("Index FAISS déjà existant dans %s (use force=True pour reconstruire)", RAG_INDEX_DIR)
return
# Collecter tous les chunks
all_chunks: list[Chunk] = []
for pdf_path, chunk_fn in [
(CIM10_PDF, _chunk_cim10),
(GUIDE_METHODO_PDF, _chunk_guide_methodo),
(CCAM_PDF, _chunk_ccam),
]:
if pdf_path.exists():
all_chunks.extend(chunk_fn(pdf_path))
else:
logger.warning("PDF non trouvé : %s", pdf_path)
if not all_chunks:
logger.error("Aucun chunk extrait — vérifiez les chemins des PDFs")
return
logger.info("Total : %d chunks à indexer", len(all_chunks))
# Embeddings — forcer CPU pour éviter les bugs CUDA avec ce modèle
logger.info("Chargement du modèle d'embedding dangvantuan/sentence-camembert-large (CPU)...")
model = SentenceTransformer("dangvantuan/sentence-camembert-large", device="cpu")
model.max_seq_length = 512 # CamemBERT max position embeddings
texts = [c.text[:2000] for c in all_chunks] # Tronquer les chunks trop longs
logger.info("Calcul des embeddings pour %d chunks...", len(texts))
embeddings = model.encode(
texts, show_progress_bar=True, normalize_embeddings=True, batch_size=64,
)
embeddings = np.array(embeddings, dtype=np.float32)
# Index FAISS (IndexFlatIP = cosine similarity avec vecteurs normalisés)
dim = embeddings.shape[1]
index = faiss.IndexFlatIP(dim)
index.add(embeddings)
# Sauvegarder
RAG_INDEX_DIR.mkdir(parents=True, exist_ok=True)
faiss.write_index(index, str(index_path))
metadata = [asdict(c) for c in all_chunks]
# Ne pas sauvegarder le texte complet dans metadata (trop lourd),
# garder un extrait de 500 chars
for m in metadata:
m["extrait"] = m.pop("text")[:500]
meta_path.write_text(json.dumps(metadata, ensure_ascii=False, indent=2), encoding="utf-8")
logger.info("Index FAISS sauvegardé : %s (%d vecteurs, dim=%d)", index_path, len(all_chunks), dim)
def get_index() -> tuple | None:
"""Charge l'index FAISS et les métadonnées (singleton lazy-loaded).
Returns:
Tuple (faiss_index, metadata_list) ou None si l'index n'existe pas.
"""
global _faiss_index, _metadata
if _faiss_index is not None:
return _faiss_index, _metadata
index_path = RAG_INDEX_DIR / "faiss.index"
meta_path = RAG_INDEX_DIR / "metadata.json"
if not index_path.exists() or not meta_path.exists():
logger.warning("Index FAISS non trouvé dans %s — lancez build_index() d'abord", RAG_INDEX_DIR)
return None
import faiss
_faiss_index = faiss.read_index(str(index_path))
_metadata = json.loads(meta_path.read_text(encoding="utf-8"))
logger.info("Index FAISS chargé : %d vecteurs", _faiss_index.ntotal)
return _faiss_index, _metadata

208
src/medical/rag_search.py Normal file
View File

@@ -0,0 +1,208 @@
"""Recherche RAG (FAISS) + génération via Ollama pour le codage CIM-10."""
from __future__ import annotations
import json
import logging
from typing import Optional
import requests
from ..config import Diagnostic, DossierMedical, RAGSource
logger = logging.getLogger(__name__)
# Configuration Ollama
OLLAMA_URL = "http://localhost:11434/api/generate"
OLLAMA_MODEL = "mistral-small3.2:24b"
OLLAMA_TIMEOUT = 120 # secondes
# Singleton pour le modèle d'embedding (chargé une seule fois)
_embed_model = None
def _get_embed_model():
"""Charge le modèle d'embedding (singleton)."""
global _embed_model
if _embed_model is None:
from sentence_transformers import SentenceTransformer
logger.info("Chargement du modèle d'embedding pour la recherche...")
_embed_model = SentenceTransformer("dangvantuan/sentence-camembert-large", device="cpu")
_embed_model.max_seq_length = 512
return _embed_model
def search_similar(query: str, top_k: int = 5) -> list[dict]:
"""Recherche les passages les plus similaires dans l'index FAISS.
Args:
query: Texte du diagnostic à rechercher.
top_k: Nombre de résultats à retourner.
Returns:
Liste de dicts avec les métadonnées + score de similarité.
"""
from .rag_index import get_index
import numpy as np
result = get_index()
if result is None:
logger.warning("Index FAISS non disponible")
return []
faiss_index, metadata = result
model = _get_embed_model()
query_vec = model.encode([query], normalize_embeddings=True)
query_vec = np.array(query_vec, dtype=np.float32)
scores, indices = faiss_index.search(query_vec, min(top_k, faiss_index.ntotal))
results = []
for score, idx in zip(scores[0], indices[0]):
if idx < 0:
continue
meta = metadata[idx].copy()
meta["score"] = float(score)
results.append(meta)
return results
def _build_prompt(texte: str, sources: list[dict], contexte: dict) -> str:
"""Construit le prompt pour Ollama."""
sources_text = ""
for i, src in enumerate(sources, 1):
doc_name = {
"cim10": "CIM-10 FR 2026",
"guide_methodo": "Guide Méthodologique MCO 2026",
"ccam": "CCAM PMSI V4 2025",
}.get(src["document"], src["document"])
code_info = f" (code: {src['code']})" if src.get("code") else ""
page_info = f" [page {src['page']}]" if src.get("page") else ""
sources_text += f"--- Source {i}: {doc_name}{code_info}{page_info} ---\n"
sources_text += (src.get("extrait", "")[:800]) + "\n\n"
ctx_parts = []
if contexte.get("sexe"):
ctx_parts.append(f"sexe: {contexte['sexe']}")
if contexte.get("age"):
ctx_parts.append(f"âge: {contexte['age']} ans")
ctx_str = ", ".join(ctx_parts) if ctx_parts else "non précisé"
return f"""Tu es un expert en codage CIM-10 pour le PMSI en France. Suggère le code CIM-10 le plus précis pour le diagnostic suivant, en te basant UNIQUEMENT sur les sources officielles fournies.
Diagnostic à coder : "{texte}"
Contexte patient : {ctx_str}
Sources de référence :
{sources_text}
Réponds UNIQUEMENT au format JSON suivant, sans texte avant ou après :
{{"code": "X99.9", "confidence": "high|medium|low", "justification": "explication courte en français"}}"""
def _call_ollama(prompt: str) -> dict | None:
"""Appelle Ollama et parse la réponse JSON."""
try:
response = requests.post(
OLLAMA_URL,
json={
"model": OLLAMA_MODEL,
"prompt": prompt,
"stream": False,
"options": {
"temperature": 0.1,
"num_predict": 300,
},
},
timeout=OLLAMA_TIMEOUT,
)
response.raise_for_status()
raw = response.json().get("response", "")
# Extraire le JSON de la réponse (peut contenir du texte autour)
json_match = None
# Chercher un bloc JSON entre accolades
brace_start = raw.find("{")
brace_end = raw.rfind("}")
if brace_start != -1 and brace_end != -1:
json_match = raw[brace_start:brace_end + 1]
if json_match:
return json.loads(json_match)
else:
logger.warning("Ollama : réponse sans JSON valide : %s", raw[:200])
return None
except requests.ConnectionError:
logger.warning("Ollama non disponible (connexion refusée)")
return None
except requests.Timeout:
logger.warning("Ollama timeout après %ds", OLLAMA_TIMEOUT)
return None
except (requests.RequestException, json.JSONDecodeError) as e:
logger.warning("Ollama erreur : %s", e)
return None
def enrich_diagnostic(
diagnostic: Diagnostic,
contexte: dict,
) -> None:
"""Enrichit un Diagnostic avec le RAG (FAISS + Ollama).
Modifie le diagnostic en place. Fallback gracieux si FAISS ou Ollama échouent.
"""
# 1. Recherche FAISS
sources = search_similar(diagnostic.texte, top_k=5)
if not sources:
logger.debug("Aucune source RAG trouvée pour : %s", diagnostic.texte)
return
# 2. Stocker les sources RAG
diagnostic.sources_rag = [
RAGSource(
document=s["document"],
page=s.get("page"),
code=s.get("code"),
extrait=s.get("extrait", "")[:200],
)
for s in sources
]
# 3. Appel Ollama pour justification
prompt = _build_prompt(diagnostic.texte, sources, contexte)
llm_result = _call_ollama(prompt)
if llm_result:
code = llm_result.get("code")
confidence = llm_result.get("confidence")
justification = llm_result.get("justification")
if code:
diagnostic.cim10_suggestion = code
if confidence in ("high", "medium", "low"):
diagnostic.cim10_confidence = confidence
if justification:
diagnostic.justification = justification
else:
logger.info("Ollama non disponible — sources FAISS conservées sans justification LLM")
def enrich_dossier(dossier: DossierMedical) -> None:
"""Enrichit le DP et tous les DAS d'un dossier via le RAG."""
contexte = {
"sexe": dossier.sejour.sexe,
"age": dossier.sejour.age,
}
if dossier.diagnostic_principal:
logger.info("RAG enrichissement DP : %s", dossier.diagnostic_principal.texte)
enrich_diagnostic(dossier.diagnostic_principal, contexte)
for das in dossier.diagnostics_associes:
logger.info("RAG enrichissement DAS : %s", das.texte)
enrich_diagnostic(das, contexte)