feat: ajout RAG CIM-10 avec FAISS + Ollama

Implémente un système RAG (Retrieval Augmented Generation) qui indexe
les documents de référence ATIH (CIM-10 FR 2026, Guide Métho MCO,
CCAM PMSI) et utilise Ollama (mistral-small3.2:24b) pour justifier
et valider le codage CIM-10 des diagnostics.

- Nouveaux modèles Pydantic : RAGSource, Diagnostic étendu (confidence,
  justification, sources_rag) — rétrocompatible
- Module rag_index.py : chunking des 3 PDFs, embedding sentence-camembert-large,
  index FAISS IndexFlatIP (3630 vecteurs)
- Module rag_search.py : recherche FAISS + appel Ollama avec fallback double
- Flag CLI --no-rag pour désactiver l'enrichissement RAG
- 18 nouveaux tests (88/88 passent)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
dom
2026-02-10 17:47:08 +01:00
parent 4a12cd2676
commit 4d6fbef2b9
8 changed files with 885 additions and 4 deletions

208
src/medical/rag_search.py Normal file
View File

@@ -0,0 +1,208 @@
"""Recherche RAG (FAISS) + génération via Ollama pour le codage CIM-10."""
from __future__ import annotations
import json
import logging
from typing import Optional
import requests
from ..config import Diagnostic, DossierMedical, RAGSource
logger = logging.getLogger(__name__)
# Configuration Ollama
OLLAMA_URL = "http://localhost:11434/api/generate"
OLLAMA_MODEL = "mistral-small3.2:24b"
OLLAMA_TIMEOUT = 120 # secondes
# Singleton pour le modèle d'embedding (chargé une seule fois)
_embed_model = None
def _get_embed_model():
"""Charge le modèle d'embedding (singleton)."""
global _embed_model
if _embed_model is None:
from sentence_transformers import SentenceTransformer
logger.info("Chargement du modèle d'embedding pour la recherche...")
_embed_model = SentenceTransformer("dangvantuan/sentence-camembert-large", device="cpu")
_embed_model.max_seq_length = 512
return _embed_model
def search_similar(query: str, top_k: int = 5) -> list[dict]:
"""Recherche les passages les plus similaires dans l'index FAISS.
Args:
query: Texte du diagnostic à rechercher.
top_k: Nombre de résultats à retourner.
Returns:
Liste de dicts avec les métadonnées + score de similarité.
"""
from .rag_index import get_index
import numpy as np
result = get_index()
if result is None:
logger.warning("Index FAISS non disponible")
return []
faiss_index, metadata = result
model = _get_embed_model()
query_vec = model.encode([query], normalize_embeddings=True)
query_vec = np.array(query_vec, dtype=np.float32)
scores, indices = faiss_index.search(query_vec, min(top_k, faiss_index.ntotal))
results = []
for score, idx in zip(scores[0], indices[0]):
if idx < 0:
continue
meta = metadata[idx].copy()
meta["score"] = float(score)
results.append(meta)
return results
def _build_prompt(texte: str, sources: list[dict], contexte: dict) -> str:
"""Construit le prompt pour Ollama."""
sources_text = ""
for i, src in enumerate(sources, 1):
doc_name = {
"cim10": "CIM-10 FR 2026",
"guide_methodo": "Guide Méthodologique MCO 2026",
"ccam": "CCAM PMSI V4 2025",
}.get(src["document"], src["document"])
code_info = f" (code: {src['code']})" if src.get("code") else ""
page_info = f" [page {src['page']}]" if src.get("page") else ""
sources_text += f"--- Source {i}: {doc_name}{code_info}{page_info} ---\n"
sources_text += (src.get("extrait", "")[:800]) + "\n\n"
ctx_parts = []
if contexte.get("sexe"):
ctx_parts.append(f"sexe: {contexte['sexe']}")
if contexte.get("age"):
ctx_parts.append(f"âge: {contexte['age']} ans")
ctx_str = ", ".join(ctx_parts) if ctx_parts else "non précisé"
return f"""Tu es un expert en codage CIM-10 pour le PMSI en France. Suggère le code CIM-10 le plus précis pour le diagnostic suivant, en te basant UNIQUEMENT sur les sources officielles fournies.
Diagnostic à coder : "{texte}"
Contexte patient : {ctx_str}
Sources de référence :
{sources_text}
Réponds UNIQUEMENT au format JSON suivant, sans texte avant ou après :
{{"code": "X99.9", "confidence": "high|medium|low", "justification": "explication courte en français"}}"""
def _call_ollama(prompt: str) -> dict | None:
"""Appelle Ollama et parse la réponse JSON."""
try:
response = requests.post(
OLLAMA_URL,
json={
"model": OLLAMA_MODEL,
"prompt": prompt,
"stream": False,
"options": {
"temperature": 0.1,
"num_predict": 300,
},
},
timeout=OLLAMA_TIMEOUT,
)
response.raise_for_status()
raw = response.json().get("response", "")
# Extraire le JSON de la réponse (peut contenir du texte autour)
json_match = None
# Chercher un bloc JSON entre accolades
brace_start = raw.find("{")
brace_end = raw.rfind("}")
if brace_start != -1 and brace_end != -1:
json_match = raw[brace_start:brace_end + 1]
if json_match:
return json.loads(json_match)
else:
logger.warning("Ollama : réponse sans JSON valide : %s", raw[:200])
return None
except requests.ConnectionError:
logger.warning("Ollama non disponible (connexion refusée)")
return None
except requests.Timeout:
logger.warning("Ollama timeout après %ds", OLLAMA_TIMEOUT)
return None
except (requests.RequestException, json.JSONDecodeError) as e:
logger.warning("Ollama erreur : %s", e)
return None
def enrich_diagnostic(
diagnostic: Diagnostic,
contexte: dict,
) -> None:
"""Enrichit un Diagnostic avec le RAG (FAISS + Ollama).
Modifie le diagnostic en place. Fallback gracieux si FAISS ou Ollama échouent.
"""
# 1. Recherche FAISS
sources = search_similar(diagnostic.texte, top_k=5)
if not sources:
logger.debug("Aucune source RAG trouvée pour : %s", diagnostic.texte)
return
# 2. Stocker les sources RAG
diagnostic.sources_rag = [
RAGSource(
document=s["document"],
page=s.get("page"),
code=s.get("code"),
extrait=s.get("extrait", "")[:200],
)
for s in sources
]
# 3. Appel Ollama pour justification
prompt = _build_prompt(diagnostic.texte, sources, contexte)
llm_result = _call_ollama(prompt)
if llm_result:
code = llm_result.get("code")
confidence = llm_result.get("confidence")
justification = llm_result.get("justification")
if code:
diagnostic.cim10_suggestion = code
if confidence in ("high", "medium", "low"):
diagnostic.cim10_confidence = confidence
if justification:
diagnostic.justification = justification
else:
logger.info("Ollama non disponible — sources FAISS conservées sans justification LLM")
def enrich_dossier(dossier: DossierMedical) -> None:
"""Enrichit le DP et tous les DAS d'un dossier via le RAG."""
contexte = {
"sexe": dossier.sejour.sexe,
"age": dossier.sejour.age,
}
if dossier.diagnostic_principal:
logger.info("RAG enrichissement DP : %s", dossier.diagnostic_principal.texte)
enrich_diagnostic(dossier.diagnostic_principal, contexte)
for das in dossier.diagnostics_associes:
logger.info("RAG enrichissement DAS : %s", das.texte)
enrich_diagnostic(das, contexte)