Détection automatique GPU/CPU avec fallback. Index FAISS reconstruit en 1min (GPU) au lieu de 16min (CPU). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
363 lines
13 KiB
Python
363 lines
13 KiB
Python
"""Recherche RAG (FAISS) + génération via Ollama pour le codage CIM-10."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
from typing import Optional
|
|
|
|
import requests
|
|
|
|
from ..config import Diagnostic, DossierMedical, RAGSource, OLLAMA_URL, OLLAMA_MODEL, OLLAMA_TIMEOUT
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Singleton pour le modèle d'embedding (chargé une seule fois)
|
|
_embed_model = None
|
|
|
|
# Score minimum de similarité FAISS pour retenir un résultat
|
|
_MIN_SCORE = 0.3
|
|
|
|
# Marqueur de fin de raisonnement dans la réponse Ollama
|
|
_RESULT_MARKER = "###RESULT###"
|
|
|
|
|
|
def _get_embed_model():
|
|
"""Charge le modèle d'embedding (singleton)."""
|
|
global _embed_model
|
|
if _embed_model is None:
|
|
from sentence_transformers import SentenceTransformer
|
|
logger.info("Chargement du modèle d'embedding pour la recherche...")
|
|
import torch
|
|
_device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
_embed_model = SentenceTransformer("dangvantuan/sentence-camembert-large", device=_device)
|
|
_embed_model.max_seq_length = 512
|
|
return _embed_model
|
|
|
|
|
|
def search_similar(query: str, top_k: int = 10) -> list[dict]:
|
|
"""Recherche les passages les plus similaires dans l'index FAISS.
|
|
|
|
Args:
|
|
query: Texte du diagnostic à rechercher.
|
|
top_k: Nombre de résultats à retourner.
|
|
|
|
Returns:
|
|
Liste de dicts avec les métadonnées + score de similarité,
|
|
filtrés par score minimum et priorisant les sources CIM-10.
|
|
"""
|
|
from .rag_index import get_index
|
|
import numpy as np
|
|
|
|
result = get_index()
|
|
if result is None:
|
|
logger.warning("Index FAISS non disponible")
|
|
return []
|
|
|
|
faiss_index, metadata = result
|
|
|
|
model = _get_embed_model()
|
|
query_vec = model.encode([query], normalize_embeddings=True)
|
|
query_vec = np.array(query_vec, dtype=np.float32)
|
|
|
|
# Chercher plus de résultats que top_k pour pouvoir filtrer ensuite
|
|
fetch_k = min(top_k * 2, faiss_index.ntotal)
|
|
scores, indices = faiss_index.search(query_vec, fetch_k)
|
|
|
|
raw_results = []
|
|
for score, idx in zip(scores[0], indices[0]):
|
|
if idx < 0:
|
|
continue
|
|
if float(score) < _MIN_SCORE:
|
|
continue
|
|
meta = metadata[idx].copy()
|
|
meta["score"] = float(score)
|
|
raw_results.append(meta)
|
|
|
|
# Prioriser les sources CIM-10 (au moins 6 sur top_k)
|
|
cim10_results = [r for r in raw_results if r["document"] in ("cim10", "cim10_alpha")]
|
|
other_results = [r for r in raw_results if r["document"] not in ("cim10", "cim10_alpha")]
|
|
|
|
min_cim10 = min(6, len(cim10_results))
|
|
final = cim10_results[:min_cim10]
|
|
remaining_slots = top_k - len(final)
|
|
# Remplir le reste avec les meilleurs résultats (CIM-10 restants + autres)
|
|
remaining = cim10_results[min_cim10:] + other_results
|
|
remaining.sort(key=lambda r: r["score"], reverse=True)
|
|
final.extend(remaining[:remaining_slots])
|
|
|
|
return final
|
|
|
|
|
|
def _format_contexte(contexte: dict) -> str:
|
|
"""Formate le contexte patient de manière structurée pour le prompt."""
|
|
lines = []
|
|
|
|
sexe = contexte.get("sexe")
|
|
age = contexte.get("age")
|
|
imc = contexte.get("imc")
|
|
patient_parts = []
|
|
if sexe:
|
|
patient_parts.append(sexe)
|
|
if age:
|
|
patient_parts.append(f"{age} ans")
|
|
if imc:
|
|
patient_parts.append(f"IMC {imc}")
|
|
if patient_parts:
|
|
lines.append(f"- Patient : {', '.join(patient_parts)}")
|
|
|
|
duree = contexte.get("duree_sejour")
|
|
if duree:
|
|
lines.append(f"- Durée séjour : {duree} jours")
|
|
|
|
antecedents = contexte.get("antecedents")
|
|
if antecedents:
|
|
lines.append(f"- Antécédents : {', '.join(antecedents[:5])}")
|
|
|
|
biologie = contexte.get("biologie_cle")
|
|
if biologie:
|
|
bio_parts = []
|
|
for b in biologie:
|
|
test, valeur, anomalie = b if isinstance(b, (list, tuple)) else (b.get("test"), b.get("valeur"), b.get("anomalie"))
|
|
marker = " (\u2191)" if anomalie else ""
|
|
bio_parts.append(f"{test} {valeur}{marker}")
|
|
lines.append(f"- Biologie : {', '.join(bio_parts)}")
|
|
|
|
imagerie = contexte.get("imagerie")
|
|
if imagerie:
|
|
for img in imagerie:
|
|
img_type, conclusion = img if isinstance(img, (list, tuple)) else (img.get("type"), img.get("conclusion"))
|
|
if conclusion:
|
|
lines.append(f"- Imagerie : {img_type} — {conclusion[:200]}")
|
|
|
|
complications = contexte.get("complications")
|
|
if complications:
|
|
lines.append(f"- Complications : {', '.join(complications)}")
|
|
|
|
dp_texte = contexte.get("dp_texte")
|
|
if dp_texte:
|
|
lines.append(f"- DP du séjour : {dp_texte}")
|
|
|
|
das_codes = contexte.get("das_codes_existants")
|
|
if das_codes:
|
|
lines.append(f"- DAS déjà codés : {', '.join(das_codes)}")
|
|
|
|
return "\n".join(lines) if lines else "Non précisé"
|
|
|
|
|
|
def _build_prompt(texte: str, sources: list[dict], contexte: dict, est_dp: bool = True) -> str:
|
|
"""Construit le prompt expert DIM avec raisonnement structuré."""
|
|
sources_text = ""
|
|
for i, src in enumerate(sources, 1):
|
|
doc_name = {
|
|
"cim10": "CIM-10 FR 2026",
|
|
"cim10_alpha": "CIM-10 Index Alphabétique 2026",
|
|
"guide_methodo": "Guide Méthodologique MCO 2026",
|
|
"ccam": "CCAM PMSI V4 2025",
|
|
}.get(src["document"], src["document"])
|
|
|
|
code_info = f" (code: {src['code']})" if src.get("code") else ""
|
|
page_info = f" [page {src['page']}]" if src.get("page") else ""
|
|
|
|
sources_text += f"--- Source {i}: {doc_name}{code_info}{page_info} ---\n"
|
|
sources_text += (src.get("extrait", "")[:800]) + "\n\n"
|
|
|
|
type_diag = "DP (diagnostic principal)" if est_dp else "DAS (diagnostic associé significatif)"
|
|
ctx_str = _format_contexte(contexte)
|
|
|
|
return f"""Tu es un médecin DIM (Département d'Information Médicale) expert en codage PMSI.
|
|
Tu dois coder le diagnostic suivant en respectant STRICTEMENT les règles de l'ATIH.
|
|
|
|
RÈGLES IMPÉRATIVES :
|
|
- Le code doit provenir UNIQUEMENT des sources CIM-10 fournies
|
|
- Distingue la DESCRIPTION CLINIQUE (ce que le médecin écrit) de la LOGIQUE DE CODAGE (ce que l'ATIH impose)
|
|
- Privilégie le code le plus SPÉCIFIQUE disponible (4e ou 5e caractère)
|
|
- Vérifie les notes d'inclusion/exclusion de chaque code candidat
|
|
- Si le diagnostic est un DP, il doit refléter le motif principal de prise en charge du séjour
|
|
- Si c'est un DAS, il doit avoir mobilisé des ressources supplémentaires pendant le séjour
|
|
- EXCLUSION SYMPTÔME : Si le diagnostic est un symptôme (R00-R99) et qu'un diagnostic précis (Chapitres I-XIV, A00-N99) expliquant ce symptôme est présent, le symptôme ne doit PAS être codé comme DAS
|
|
|
|
DIAGNOSTIC À CODER : "{texte}"
|
|
TYPE : {type_diag}
|
|
|
|
CONTEXTE CLINIQUE :
|
|
{ctx_str}
|
|
|
|
SOURCES CIM-10 :
|
|
{sources_text}
|
|
RAISONNE ÉTAPE PAR ÉTAPE :
|
|
1. ANALYSE CLINIQUE : Que signifie ce diagnostic sur le plan médical ?
|
|
2. CODES CANDIDATS : Quels codes des sources fournies sont compatibles ?
|
|
3. DISCRIMINATION : Pourquoi choisir un code plutôt qu'un autre ? (inclusions/exclusions, spécificité)
|
|
4. RÈGLE PMSI : Ce code est-il conforme pour un {type_diag} ? (guide méthodologique)
|
|
|
|
Après ton raisonnement, conclus OBLIGATOIREMENT par le JSON suivant sur une ligne séparée :
|
|
{_RESULT_MARKER}
|
|
{{"code": "X99.9", "confidence": "high|medium|low", "justification": "explication courte en français"}}"""
|
|
|
|
|
|
def _parse_ollama_response(raw: str) -> dict | None:
|
|
"""Parse la réponse Ollama en extrayant le JSON après le marqueur ###RESULT###.
|
|
|
|
Fallback sur la recherche d'accolades si le marqueur est absent.
|
|
Retourne un dict avec les clés code/confidence/justification + raisonnement.
|
|
"""
|
|
raisonnement = None
|
|
json_str = None
|
|
|
|
# Stratégie 1 : chercher le marqueur ###RESULT###
|
|
marker_pos = raw.find(_RESULT_MARKER)
|
|
if marker_pos != -1:
|
|
raisonnement = raw[:marker_pos].strip()
|
|
after_marker = raw[marker_pos + len(_RESULT_MARKER):]
|
|
brace_start = after_marker.find("{")
|
|
brace_end = after_marker.rfind("}")
|
|
if brace_start != -1 and brace_end != -1:
|
|
json_str = after_marker[brace_start:brace_end + 1]
|
|
else:
|
|
# Fallback : chercher le dernier bloc JSON dans la réponse
|
|
# (le raisonnement peut contenir des accolades intermédiaires)
|
|
last_brace = raw.rfind("}")
|
|
if last_brace != -1:
|
|
# Chercher l'accolade ouvrante correspondante en remontant
|
|
depth = 0
|
|
start = -1
|
|
for i in range(last_brace, -1, -1):
|
|
if raw[i] == "}":
|
|
depth += 1
|
|
elif raw[i] == "{":
|
|
depth -= 1
|
|
if depth == 0:
|
|
start = i
|
|
break
|
|
if start != -1:
|
|
json_str = raw[start:last_brace + 1]
|
|
raisonnement = raw[:start].strip()
|
|
|
|
if not json_str:
|
|
logger.warning("Ollama : réponse sans JSON valide : %s", raw[:200])
|
|
return None
|
|
|
|
try:
|
|
parsed = json.loads(json_str)
|
|
except json.JSONDecodeError:
|
|
logger.warning("Ollama : JSON invalide : %s", json_str[:200])
|
|
return None
|
|
|
|
if raisonnement:
|
|
parsed["raisonnement"] = raisonnement
|
|
|
|
return parsed
|
|
|
|
|
|
def _call_ollama(prompt: str) -> dict | None:
|
|
"""Appelle Ollama et parse la réponse JSON."""
|
|
try:
|
|
response = requests.post(
|
|
f"{OLLAMA_URL}/api/generate",
|
|
json={
|
|
"model": OLLAMA_MODEL,
|
|
"prompt": prompt,
|
|
"stream": False,
|
|
"options": {
|
|
"temperature": 0.1,
|
|
"num_predict": 1200,
|
|
},
|
|
},
|
|
timeout=OLLAMA_TIMEOUT,
|
|
)
|
|
response.raise_for_status()
|
|
raw = response.json().get("response", "")
|
|
return _parse_ollama_response(raw)
|
|
|
|
except requests.ConnectionError:
|
|
logger.warning("Ollama non disponible (connexion refusée)")
|
|
return None
|
|
except requests.Timeout:
|
|
logger.warning("Ollama timeout après %ds", OLLAMA_TIMEOUT)
|
|
return None
|
|
except (requests.RequestException, json.JSONDecodeError) as e:
|
|
logger.warning("Ollama erreur : %s", e)
|
|
return None
|
|
|
|
|
|
def enrich_diagnostic(
|
|
diagnostic: Diagnostic,
|
|
contexte: dict,
|
|
est_dp: bool = True,
|
|
) -> None:
|
|
"""Enrichit un Diagnostic avec le RAG (FAISS + Ollama).
|
|
|
|
Modifie le diagnostic en place. Fallback gracieux si FAISS ou Ollama échouent.
|
|
"""
|
|
# 1. Recherche FAISS
|
|
sources = search_similar(diagnostic.texte, top_k=10)
|
|
|
|
if not sources:
|
|
logger.debug("Aucune source RAG trouvée pour : %s", diagnostic.texte)
|
|
return
|
|
|
|
# 2. Stocker les sources RAG
|
|
diagnostic.sources_rag = [
|
|
RAGSource(
|
|
document=s["document"],
|
|
page=s.get("page"),
|
|
code=s.get("code"),
|
|
extrait=s.get("extrait", "")[:200],
|
|
)
|
|
for s in sources
|
|
]
|
|
|
|
# 3. Appel Ollama pour justification avec raisonnement structuré
|
|
prompt = _build_prompt(diagnostic.texte, sources, contexte, est_dp=est_dp)
|
|
llm_result = _call_ollama(prompt)
|
|
|
|
if llm_result:
|
|
code = llm_result.get("code")
|
|
confidence = llm_result.get("confidence")
|
|
justification = llm_result.get("justification")
|
|
raisonnement = llm_result.get("raisonnement")
|
|
|
|
if code:
|
|
diagnostic.cim10_suggestion = code
|
|
if confidence in ("high", "medium", "low"):
|
|
diagnostic.cim10_confidence = confidence
|
|
if justification:
|
|
diagnostic.justification = justification
|
|
if raisonnement:
|
|
diagnostic.raisonnement = raisonnement
|
|
else:
|
|
logger.info("Ollama non disponible — sources FAISS conservées sans justification LLM")
|
|
|
|
|
|
def enrich_dossier(dossier: DossierMedical) -> None:
|
|
"""Enrichit le DP et tous les DAS d'un dossier via le RAG."""
|
|
contexte = {
|
|
"sexe": dossier.sejour.sexe,
|
|
"age": dossier.sejour.age,
|
|
"duree_sejour": dossier.sejour.duree_sejour,
|
|
"imc": dossier.sejour.imc,
|
|
"antecedents": dossier.antecedents[:5],
|
|
"biologie_cle": [(b.test, b.valeur, b.anomalie) for b in dossier.biologie_cle],
|
|
"imagerie": [(i.type, (i.conclusion or "")[:200]) for i in dossier.imagerie],
|
|
"complications": dossier.complications,
|
|
}
|
|
|
|
if dossier.diagnostic_principal:
|
|
logger.info("RAG enrichissement DP : %s", dossier.diagnostic_principal.texte)
|
|
enrich_diagnostic(dossier.diagnostic_principal, contexte, est_dp=True)
|
|
|
|
# Pour les DAS, ajouter le DP et les DAS existants au contexte pour cohérence
|
|
if dossier.diagnostic_principal:
|
|
contexte["dp_texte"] = dossier.diagnostic_principal.texte
|
|
contexte["das_codes_existants"] = [
|
|
f"{d.cim10_suggestion} ({d.texte})"
|
|
for d in dossier.diagnostics_associes
|
|
if d.cim10_suggestion
|
|
]
|
|
|
|
for das in dossier.diagnostics_associes:
|
|
logger.info("RAG enrichissement DAS : %s", das.texte)
|
|
enrich_diagnostic(das, contexte, est_dp=False)
|