feat: ajout RAG CIM-10 avec FAISS + Ollama
Implémente un système RAG (Retrieval Augmented Generation) qui indexe les documents de référence ATIH (CIM-10 FR 2026, Guide Métho MCO, CCAM PMSI) et utilise Ollama (mistral-small3.2:24b) pour justifier et valider le codage CIM-10 des diagnostics. - Nouveaux modèles Pydantic : RAGSource, Diagnostic étendu (confidence, justification, sources_rag) — rétrocompatible - Module rag_index.py : chunking des 3 PDFs, embedding sentence-camembert-large, index FAISS IndexFlatIP (3630 vecteurs) - Module rag_search.py : recherche FAISS + appel Ollama avec fallback double - Flag CLI --no-rag pour désactiver l'enrichissement RAG - 18 nouveaux tests (88/88 passent) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
208
src/medical/rag_search.py
Normal file
208
src/medical/rag_search.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""Recherche RAG (FAISS) + génération via Ollama pour le codage CIM-10."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
from ..config import Diagnostic, DossierMedical, RAGSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration Ollama
|
||||
OLLAMA_URL = "http://localhost:11434/api/generate"
|
||||
OLLAMA_MODEL = "mistral-small3.2:24b"
|
||||
OLLAMA_TIMEOUT = 120 # secondes
|
||||
|
||||
# Singleton pour le modèle d'embedding (chargé une seule fois)
|
||||
_embed_model = None
|
||||
|
||||
|
||||
def _get_embed_model():
|
||||
"""Charge le modèle d'embedding (singleton)."""
|
||||
global _embed_model
|
||||
if _embed_model is None:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
logger.info("Chargement du modèle d'embedding pour la recherche...")
|
||||
_embed_model = SentenceTransformer("dangvantuan/sentence-camembert-large", device="cpu")
|
||||
_embed_model.max_seq_length = 512
|
||||
return _embed_model
|
||||
|
||||
|
||||
def search_similar(query: str, top_k: int = 5) -> list[dict]:
|
||||
"""Recherche les passages les plus similaires dans l'index FAISS.
|
||||
|
||||
Args:
|
||||
query: Texte du diagnostic à rechercher.
|
||||
top_k: Nombre de résultats à retourner.
|
||||
|
||||
Returns:
|
||||
Liste de dicts avec les métadonnées + score de similarité.
|
||||
"""
|
||||
from .rag_index import get_index
|
||||
import numpy as np
|
||||
|
||||
result = get_index()
|
||||
if result is None:
|
||||
logger.warning("Index FAISS non disponible")
|
||||
return []
|
||||
|
||||
faiss_index, metadata = result
|
||||
|
||||
model = _get_embed_model()
|
||||
query_vec = model.encode([query], normalize_embeddings=True)
|
||||
query_vec = np.array(query_vec, dtype=np.float32)
|
||||
|
||||
scores, indices = faiss_index.search(query_vec, min(top_k, faiss_index.ntotal))
|
||||
|
||||
results = []
|
||||
for score, idx in zip(scores[0], indices[0]):
|
||||
if idx < 0:
|
||||
continue
|
||||
meta = metadata[idx].copy()
|
||||
meta["score"] = float(score)
|
||||
results.append(meta)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _build_prompt(texte: str, sources: list[dict], contexte: dict) -> str:
|
||||
"""Construit le prompt pour Ollama."""
|
||||
sources_text = ""
|
||||
for i, src in enumerate(sources, 1):
|
||||
doc_name = {
|
||||
"cim10": "CIM-10 FR 2026",
|
||||
"guide_methodo": "Guide Méthodologique MCO 2026",
|
||||
"ccam": "CCAM PMSI V4 2025",
|
||||
}.get(src["document"], src["document"])
|
||||
|
||||
code_info = f" (code: {src['code']})" if src.get("code") else ""
|
||||
page_info = f" [page {src['page']}]" if src.get("page") else ""
|
||||
|
||||
sources_text += f"--- Source {i}: {doc_name}{code_info}{page_info} ---\n"
|
||||
sources_text += (src.get("extrait", "")[:800]) + "\n\n"
|
||||
|
||||
ctx_parts = []
|
||||
if contexte.get("sexe"):
|
||||
ctx_parts.append(f"sexe: {contexte['sexe']}")
|
||||
if contexte.get("age"):
|
||||
ctx_parts.append(f"âge: {contexte['age']} ans")
|
||||
ctx_str = ", ".join(ctx_parts) if ctx_parts else "non précisé"
|
||||
|
||||
return f"""Tu es un expert en codage CIM-10 pour le PMSI en France. Suggère le code CIM-10 le plus précis pour le diagnostic suivant, en te basant UNIQUEMENT sur les sources officielles fournies.
|
||||
|
||||
Diagnostic à coder : "{texte}"
|
||||
Contexte patient : {ctx_str}
|
||||
|
||||
Sources de référence :
|
||||
{sources_text}
|
||||
Réponds UNIQUEMENT au format JSON suivant, sans texte avant ou après :
|
||||
{{"code": "X99.9", "confidence": "high|medium|low", "justification": "explication courte en français"}}"""
|
||||
|
||||
|
||||
def _call_ollama(prompt: str) -> dict | None:
|
||||
"""Appelle Ollama et parse la réponse JSON."""
|
||||
try:
|
||||
response = requests.post(
|
||||
OLLAMA_URL,
|
||||
json={
|
||||
"model": OLLAMA_MODEL,
|
||||
"prompt": prompt,
|
||||
"stream": False,
|
||||
"options": {
|
||||
"temperature": 0.1,
|
||||
"num_predict": 300,
|
||||
},
|
||||
},
|
||||
timeout=OLLAMA_TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
raw = response.json().get("response", "")
|
||||
|
||||
# Extraire le JSON de la réponse (peut contenir du texte autour)
|
||||
json_match = None
|
||||
# Chercher un bloc JSON entre accolades
|
||||
brace_start = raw.find("{")
|
||||
brace_end = raw.rfind("}")
|
||||
if brace_start != -1 and brace_end != -1:
|
||||
json_match = raw[brace_start:brace_end + 1]
|
||||
|
||||
if json_match:
|
||||
return json.loads(json_match)
|
||||
else:
|
||||
logger.warning("Ollama : réponse sans JSON valide : %s", raw[:200])
|
||||
return None
|
||||
|
||||
except requests.ConnectionError:
|
||||
logger.warning("Ollama non disponible (connexion refusée)")
|
||||
return None
|
||||
except requests.Timeout:
|
||||
logger.warning("Ollama timeout après %ds", OLLAMA_TIMEOUT)
|
||||
return None
|
||||
except (requests.RequestException, json.JSONDecodeError) as e:
|
||||
logger.warning("Ollama erreur : %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def enrich_diagnostic(
|
||||
diagnostic: Diagnostic,
|
||||
contexte: dict,
|
||||
) -> None:
|
||||
"""Enrichit un Diagnostic avec le RAG (FAISS + Ollama).
|
||||
|
||||
Modifie le diagnostic en place. Fallback gracieux si FAISS ou Ollama échouent.
|
||||
"""
|
||||
# 1. Recherche FAISS
|
||||
sources = search_similar(diagnostic.texte, top_k=5)
|
||||
|
||||
if not sources:
|
||||
logger.debug("Aucune source RAG trouvée pour : %s", diagnostic.texte)
|
||||
return
|
||||
|
||||
# 2. Stocker les sources RAG
|
||||
diagnostic.sources_rag = [
|
||||
RAGSource(
|
||||
document=s["document"],
|
||||
page=s.get("page"),
|
||||
code=s.get("code"),
|
||||
extrait=s.get("extrait", "")[:200],
|
||||
)
|
||||
for s in sources
|
||||
]
|
||||
|
||||
# 3. Appel Ollama pour justification
|
||||
prompt = _build_prompt(diagnostic.texte, sources, contexte)
|
||||
llm_result = _call_ollama(prompt)
|
||||
|
||||
if llm_result:
|
||||
code = llm_result.get("code")
|
||||
confidence = llm_result.get("confidence")
|
||||
justification = llm_result.get("justification")
|
||||
|
||||
if code:
|
||||
diagnostic.cim10_suggestion = code
|
||||
if confidence in ("high", "medium", "low"):
|
||||
diagnostic.cim10_confidence = confidence
|
||||
if justification:
|
||||
diagnostic.justification = justification
|
||||
else:
|
||||
logger.info("Ollama non disponible — sources FAISS conservées sans justification LLM")
|
||||
|
||||
|
||||
def enrich_dossier(dossier: DossierMedical) -> None:
|
||||
"""Enrichit le DP et tous les DAS d'un dossier via le RAG."""
|
||||
contexte = {
|
||||
"sexe": dossier.sejour.sexe,
|
||||
"age": dossier.sejour.age,
|
||||
}
|
||||
|
||||
if dossier.diagnostic_principal:
|
||||
logger.info("RAG enrichissement DP : %s", dossier.diagnostic_principal.texte)
|
||||
enrich_diagnostic(dossier.diagnostic_principal, contexte)
|
||||
|
||||
for das in dossier.diagnostics_associes:
|
||||
logger.info("RAG enrichissement DAS : %s", das.texte)
|
||||
enrich_diagnostic(das, contexte)
|
||||
Reference in New Issue
Block a user