Files
t2a_v2/src/medical/rag_search.py
dom 931b6c5d1c feat: embeddings sur GPU (CUDA) pour l'indexation et la recherche RAG
Détection automatique GPU/CPU avec fallback. Index FAISS reconstruit
en 1min (GPU) au lieu de 16min (CPU).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 23:42:46 +01:00

363 lines
13 KiB
Python

"""Recherche RAG (FAISS) + génération via Ollama pour le codage CIM-10."""
from __future__ import annotations
import json
import logging
import re
from typing import Optional
import requests
from ..config import Diagnostic, DossierMedical, RAGSource, OLLAMA_URL, OLLAMA_MODEL, OLLAMA_TIMEOUT
logger = logging.getLogger(__name__)
# Singleton pour le modèle d'embedding (chargé une seule fois)
_embed_model = None
# Score minimum de similarité FAISS pour retenir un résultat
_MIN_SCORE = 0.3
# Marqueur de fin de raisonnement dans la réponse Ollama
_RESULT_MARKER = "###RESULT###"
def _get_embed_model():
"""Charge le modèle d'embedding (singleton)."""
global _embed_model
if _embed_model is None:
from sentence_transformers import SentenceTransformer
logger.info("Chargement du modèle d'embedding pour la recherche...")
import torch
_device = "cuda" if torch.cuda.is_available() else "cpu"
_embed_model = SentenceTransformer("dangvantuan/sentence-camembert-large", device=_device)
_embed_model.max_seq_length = 512
return _embed_model
def search_similar(query: str, top_k: int = 10) -> list[dict]:
"""Recherche les passages les plus similaires dans l'index FAISS.
Args:
query: Texte du diagnostic à rechercher.
top_k: Nombre de résultats à retourner.
Returns:
Liste de dicts avec les métadonnées + score de similarité,
filtrés par score minimum et priorisant les sources CIM-10.
"""
from .rag_index import get_index
import numpy as np
result = get_index()
if result is None:
logger.warning("Index FAISS non disponible")
return []
faiss_index, metadata = result
model = _get_embed_model()
query_vec = model.encode([query], normalize_embeddings=True)
query_vec = np.array(query_vec, dtype=np.float32)
# Chercher plus de résultats que top_k pour pouvoir filtrer ensuite
fetch_k = min(top_k * 2, faiss_index.ntotal)
scores, indices = faiss_index.search(query_vec, fetch_k)
raw_results = []
for score, idx in zip(scores[0], indices[0]):
if idx < 0:
continue
if float(score) < _MIN_SCORE:
continue
meta = metadata[idx].copy()
meta["score"] = float(score)
raw_results.append(meta)
# Prioriser les sources CIM-10 (au moins 6 sur top_k)
cim10_results = [r for r in raw_results if r["document"] in ("cim10", "cim10_alpha")]
other_results = [r for r in raw_results if r["document"] not in ("cim10", "cim10_alpha")]
min_cim10 = min(6, len(cim10_results))
final = cim10_results[:min_cim10]
remaining_slots = top_k - len(final)
# Remplir le reste avec les meilleurs résultats (CIM-10 restants + autres)
remaining = cim10_results[min_cim10:] + other_results
remaining.sort(key=lambda r: r["score"], reverse=True)
final.extend(remaining[:remaining_slots])
return final
def _format_contexte(contexte: dict) -> str:
"""Formate le contexte patient de manière structurée pour le prompt."""
lines = []
sexe = contexte.get("sexe")
age = contexte.get("age")
imc = contexte.get("imc")
patient_parts = []
if sexe:
patient_parts.append(sexe)
if age:
patient_parts.append(f"{age} ans")
if imc:
patient_parts.append(f"IMC {imc}")
if patient_parts:
lines.append(f"- Patient : {', '.join(patient_parts)}")
duree = contexte.get("duree_sejour")
if duree:
lines.append(f"- Durée séjour : {duree} jours")
antecedents = contexte.get("antecedents")
if antecedents:
lines.append(f"- Antécédents : {', '.join(antecedents[:5])}")
biologie = contexte.get("biologie_cle")
if biologie:
bio_parts = []
for b in biologie:
test, valeur, anomalie = b if isinstance(b, (list, tuple)) else (b.get("test"), b.get("valeur"), b.get("anomalie"))
marker = " (\u2191)" if anomalie else ""
bio_parts.append(f"{test} {valeur}{marker}")
lines.append(f"- Biologie : {', '.join(bio_parts)}")
imagerie = contexte.get("imagerie")
if imagerie:
for img in imagerie:
img_type, conclusion = img if isinstance(img, (list, tuple)) else (img.get("type"), img.get("conclusion"))
if conclusion:
lines.append(f"- Imagerie : {img_type}{conclusion[:200]}")
complications = contexte.get("complications")
if complications:
lines.append(f"- Complications : {', '.join(complications)}")
dp_texte = contexte.get("dp_texte")
if dp_texte:
lines.append(f"- DP du séjour : {dp_texte}")
das_codes = contexte.get("das_codes_existants")
if das_codes:
lines.append(f"- DAS déjà codés : {', '.join(das_codes)}")
return "\n".join(lines) if lines else "Non précisé"
def _build_prompt(texte: str, sources: list[dict], contexte: dict, est_dp: bool = True) -> str:
"""Construit le prompt expert DIM avec raisonnement structuré."""
sources_text = ""
for i, src in enumerate(sources, 1):
doc_name = {
"cim10": "CIM-10 FR 2026",
"cim10_alpha": "CIM-10 Index Alphabétique 2026",
"guide_methodo": "Guide Méthodologique MCO 2026",
"ccam": "CCAM PMSI V4 2025",
}.get(src["document"], src["document"])
code_info = f" (code: {src['code']})" if src.get("code") else ""
page_info = f" [page {src['page']}]" if src.get("page") else ""
sources_text += f"--- Source {i}: {doc_name}{code_info}{page_info} ---\n"
sources_text += (src.get("extrait", "")[:800]) + "\n\n"
type_diag = "DP (diagnostic principal)" if est_dp else "DAS (diagnostic associé significatif)"
ctx_str = _format_contexte(contexte)
return f"""Tu es un médecin DIM (Département d'Information Médicale) expert en codage PMSI.
Tu dois coder le diagnostic suivant en respectant STRICTEMENT les règles de l'ATIH.
RÈGLES IMPÉRATIVES :
- Le code doit provenir UNIQUEMENT des sources CIM-10 fournies
- Distingue la DESCRIPTION CLINIQUE (ce que le médecin écrit) de la LOGIQUE DE CODAGE (ce que l'ATIH impose)
- Privilégie le code le plus SPÉCIFIQUE disponible (4e ou 5e caractère)
- Vérifie les notes d'inclusion/exclusion de chaque code candidat
- Si le diagnostic est un DP, il doit refléter le motif principal de prise en charge du séjour
- Si c'est un DAS, il doit avoir mobilisé des ressources supplémentaires pendant le séjour
- EXCLUSION SYMPTÔME : Si le diagnostic est un symptôme (R00-R99) et qu'un diagnostic précis (Chapitres I-XIV, A00-N99) expliquant ce symptôme est présent, le symptôme ne doit PAS être codé comme DAS
DIAGNOSTIC À CODER : "{texte}"
TYPE : {type_diag}
CONTEXTE CLINIQUE :
{ctx_str}
SOURCES CIM-10 :
{sources_text}
RAISONNE ÉTAPE PAR ÉTAPE :
1. ANALYSE CLINIQUE : Que signifie ce diagnostic sur le plan médical ?
2. CODES CANDIDATS : Quels codes des sources fournies sont compatibles ?
3. DISCRIMINATION : Pourquoi choisir un code plutôt qu'un autre ? (inclusions/exclusions, spécificité)
4. RÈGLE PMSI : Ce code est-il conforme pour un {type_diag} ? (guide méthodologique)
Après ton raisonnement, conclus OBLIGATOIREMENT par le JSON suivant sur une ligne séparée :
{_RESULT_MARKER}
{{"code": "X99.9", "confidence": "high|medium|low", "justification": "explication courte en français"}}"""
def _parse_ollama_response(raw: str) -> dict | None:
"""Parse la réponse Ollama en extrayant le JSON après le marqueur ###RESULT###.
Fallback sur la recherche d'accolades si le marqueur est absent.
Retourne un dict avec les clés code/confidence/justification + raisonnement.
"""
raisonnement = None
json_str = None
# Stratégie 1 : chercher le marqueur ###RESULT###
marker_pos = raw.find(_RESULT_MARKER)
if marker_pos != -1:
raisonnement = raw[:marker_pos].strip()
after_marker = raw[marker_pos + len(_RESULT_MARKER):]
brace_start = after_marker.find("{")
brace_end = after_marker.rfind("}")
if brace_start != -1 and brace_end != -1:
json_str = after_marker[brace_start:brace_end + 1]
else:
# Fallback : chercher le dernier bloc JSON dans la réponse
# (le raisonnement peut contenir des accolades intermédiaires)
last_brace = raw.rfind("}")
if last_brace != -1:
# Chercher l'accolade ouvrante correspondante en remontant
depth = 0
start = -1
for i in range(last_brace, -1, -1):
if raw[i] == "}":
depth += 1
elif raw[i] == "{":
depth -= 1
if depth == 0:
start = i
break
if start != -1:
json_str = raw[start:last_brace + 1]
raisonnement = raw[:start].strip()
if not json_str:
logger.warning("Ollama : réponse sans JSON valide : %s", raw[:200])
return None
try:
parsed = json.loads(json_str)
except json.JSONDecodeError:
logger.warning("Ollama : JSON invalide : %s", json_str[:200])
return None
if raisonnement:
parsed["raisonnement"] = raisonnement
return parsed
def _call_ollama(prompt: str) -> dict | None:
"""Appelle Ollama et parse la réponse JSON."""
try:
response = requests.post(
f"{OLLAMA_URL}/api/generate",
json={
"model": OLLAMA_MODEL,
"prompt": prompt,
"stream": False,
"options": {
"temperature": 0.1,
"num_predict": 1200,
},
},
timeout=OLLAMA_TIMEOUT,
)
response.raise_for_status()
raw = response.json().get("response", "")
return _parse_ollama_response(raw)
except requests.ConnectionError:
logger.warning("Ollama non disponible (connexion refusée)")
return None
except requests.Timeout:
logger.warning("Ollama timeout après %ds", OLLAMA_TIMEOUT)
return None
except (requests.RequestException, json.JSONDecodeError) as e:
logger.warning("Ollama erreur : %s", e)
return None
def enrich_diagnostic(
diagnostic: Diagnostic,
contexte: dict,
est_dp: bool = True,
) -> None:
"""Enrichit un Diagnostic avec le RAG (FAISS + Ollama).
Modifie le diagnostic en place. Fallback gracieux si FAISS ou Ollama échouent.
"""
# 1. Recherche FAISS
sources = search_similar(diagnostic.texte, top_k=10)
if not sources:
logger.debug("Aucune source RAG trouvée pour : %s", diagnostic.texte)
return
# 2. Stocker les sources RAG
diagnostic.sources_rag = [
RAGSource(
document=s["document"],
page=s.get("page"),
code=s.get("code"),
extrait=s.get("extrait", "")[:200],
)
for s in sources
]
# 3. Appel Ollama pour justification avec raisonnement structuré
prompt = _build_prompt(diagnostic.texte, sources, contexte, est_dp=est_dp)
llm_result = _call_ollama(prompt)
if llm_result:
code = llm_result.get("code")
confidence = llm_result.get("confidence")
justification = llm_result.get("justification")
raisonnement = llm_result.get("raisonnement")
if code:
diagnostic.cim10_suggestion = code
if confidence in ("high", "medium", "low"):
diagnostic.cim10_confidence = confidence
if justification:
diagnostic.justification = justification
if raisonnement:
diagnostic.raisonnement = raisonnement
else:
logger.info("Ollama non disponible — sources FAISS conservées sans justification LLM")
def enrich_dossier(dossier: DossierMedical) -> None:
"""Enrichit le DP et tous les DAS d'un dossier via le RAG."""
contexte = {
"sexe": dossier.sejour.sexe,
"age": dossier.sejour.age,
"duree_sejour": dossier.sejour.duree_sejour,
"imc": dossier.sejour.imc,
"antecedents": dossier.antecedents[:5],
"biologie_cle": [(b.test, b.valeur, b.anomalie) for b in dossier.biologie_cle],
"imagerie": [(i.type, (i.conclusion or "")[:200]) for i in dossier.imagerie],
"complications": dossier.complications,
}
if dossier.diagnostic_principal:
logger.info("RAG enrichissement DP : %s", dossier.diagnostic_principal.texte)
enrich_diagnostic(dossier.diagnostic_principal, contexte, est_dp=True)
# Pour les DAS, ajouter le DP et les DAS existants au contexte pour cohérence
if dossier.diagnostic_principal:
contexte["dp_texte"] = dossier.diagnostic_principal.texte
contexte["das_codes_existants"] = [
f"{d.cim10_suggestion} ({d.texte})"
for d in dossier.diagnostics_associes
if d.cim10_suggestion
]
for das in dossier.diagnostics_associes:
logger.info("RAG enrichissement DAS : %s", das.texte)
enrich_diagnostic(das, contexte, est_dp=False)