"""Recherche RAG (FAISS) + génération via Ollama pour le codage CIM-10.""" from __future__ import annotations import json import logging import re from typing import Optional import requests from ..config import Diagnostic, DossierMedical, RAGSource, OLLAMA_URL, OLLAMA_MODEL, OLLAMA_TIMEOUT logger = logging.getLogger(__name__) # Singleton pour le modèle d'embedding (chargé une seule fois) _embed_model = None # Score minimum de similarité FAISS pour retenir un résultat _MIN_SCORE = 0.3 # Marqueur de fin de raisonnement dans la réponse Ollama _RESULT_MARKER = "###RESULT###" def _get_embed_model(): """Charge le modèle d'embedding (singleton).""" global _embed_model if _embed_model is None: from sentence_transformers import SentenceTransformer logger.info("Chargement du modèle d'embedding pour la recherche...") import torch _device = "cuda" if torch.cuda.is_available() else "cpu" _embed_model = SentenceTransformer("dangvantuan/sentence-camembert-large", device=_device) _embed_model.max_seq_length = 512 return _embed_model def search_similar(query: str, top_k: int = 10) -> list[dict]: """Recherche les passages les plus similaires dans l'index FAISS. Args: query: Texte du diagnostic à rechercher. top_k: Nombre de résultats à retourner. Returns: Liste de dicts avec les métadonnées + score de similarité, filtrés par score minimum et priorisant les sources CIM-10. """ from .rag_index import get_index import numpy as np result = get_index() if result is None: logger.warning("Index FAISS non disponible") return [] faiss_index, metadata = result model = _get_embed_model() query_vec = model.encode([query], normalize_embeddings=True) query_vec = np.array(query_vec, dtype=np.float32) # Chercher plus de résultats que top_k pour pouvoir filtrer ensuite fetch_k = min(top_k * 2, faiss_index.ntotal) scores, indices = faiss_index.search(query_vec, fetch_k) raw_results = [] for score, idx in zip(scores[0], indices[0]): if idx < 0: continue if float(score) < _MIN_SCORE: continue meta = metadata[idx].copy() meta["score"] = float(score) raw_results.append(meta) # Prioriser les sources CIM-10 (au moins 6 sur top_k) cim10_results = [r for r in raw_results if r["document"] in ("cim10", "cim10_alpha")] other_results = [r for r in raw_results if r["document"] not in ("cim10", "cim10_alpha")] min_cim10 = min(6, len(cim10_results)) final = cim10_results[:min_cim10] remaining_slots = top_k - len(final) # Remplir le reste avec les meilleurs résultats (CIM-10 restants + autres) remaining = cim10_results[min_cim10:] + other_results remaining.sort(key=lambda r: r["score"], reverse=True) final.extend(remaining[:remaining_slots]) return final def _format_contexte(contexte: dict) -> str: """Formate le contexte patient de manière structurée pour le prompt.""" lines = [] sexe = contexte.get("sexe") age = contexte.get("age") imc = contexte.get("imc") patient_parts = [] if sexe: patient_parts.append(sexe) if age: patient_parts.append(f"{age} ans") if imc: patient_parts.append(f"IMC {imc}") if patient_parts: lines.append(f"- Patient : {', '.join(patient_parts)}") duree = contexte.get("duree_sejour") if duree: lines.append(f"- Durée séjour : {duree} jours") antecedents = contexte.get("antecedents") if antecedents: lines.append(f"- Antécédents : {', '.join(antecedents[:5])}") biologie = contexte.get("biologie_cle") if biologie: bio_parts = [] for b in biologie: test, valeur, anomalie = b if isinstance(b, (list, tuple)) else (b.get("test"), b.get("valeur"), b.get("anomalie")) marker = " (\u2191)" if anomalie else "" bio_parts.append(f"{test} {valeur}{marker}") lines.append(f"- Biologie : {', '.join(bio_parts)}") imagerie = contexte.get("imagerie") if imagerie: for img in imagerie: img_type, conclusion = img if isinstance(img, (list, tuple)) else (img.get("type"), img.get("conclusion")) if conclusion: lines.append(f"- Imagerie : {img_type} — {conclusion[:200]}") complications = contexte.get("complications") if complications: lines.append(f"- Complications : {', '.join(complications)}") dp_texte = contexte.get("dp_texte") if dp_texte: lines.append(f"- DP du séjour : {dp_texte}") das_codes = contexte.get("das_codes_existants") if das_codes: lines.append(f"- DAS déjà codés : {', '.join(das_codes)}") return "\n".join(lines) if lines else "Non précisé" def _build_prompt(texte: str, sources: list[dict], contexte: dict, est_dp: bool = True) -> str: """Construit le prompt expert DIM avec raisonnement structuré.""" sources_text = "" for i, src in enumerate(sources, 1): doc_name = { "cim10": "CIM-10 FR 2026", "cim10_alpha": "CIM-10 Index Alphabétique 2026", "guide_methodo": "Guide Méthodologique MCO 2026", "ccam": "CCAM PMSI V4 2025", }.get(src["document"], src["document"]) code_info = f" (code: {src['code']})" if src.get("code") else "" page_info = f" [page {src['page']}]" if src.get("page") else "" sources_text += f"--- Source {i}: {doc_name}{code_info}{page_info} ---\n" sources_text += (src.get("extrait", "")[:800]) + "\n\n" type_diag = "DP (diagnostic principal)" if est_dp else "DAS (diagnostic associé significatif)" ctx_str = _format_contexte(contexte) return f"""Tu es un médecin DIM (Département d'Information Médicale) expert en codage PMSI. Tu dois coder le diagnostic suivant en respectant STRICTEMENT les règles de l'ATIH. RÈGLES IMPÉRATIVES : - Le code doit provenir UNIQUEMENT des sources CIM-10 fournies - Distingue la DESCRIPTION CLINIQUE (ce que le médecin écrit) de la LOGIQUE DE CODAGE (ce que l'ATIH impose) - Privilégie le code le plus SPÉCIFIQUE disponible (4e ou 5e caractère) - Vérifie les notes d'inclusion/exclusion de chaque code candidat - Si le diagnostic est un DP, il doit refléter le motif principal de prise en charge du séjour - Si c'est un DAS, il doit avoir mobilisé des ressources supplémentaires pendant le séjour - EXCLUSION SYMPTÔME : Si le diagnostic est un symptôme (R00-R99) et qu'un diagnostic précis (Chapitres I-XIV, A00-N99) expliquant ce symptôme est présent, le symptôme ne doit PAS être codé comme DAS DIAGNOSTIC À CODER : "{texte}" TYPE : {type_diag} CONTEXTE CLINIQUE : {ctx_str} SOURCES CIM-10 : {sources_text} RAISONNE ÉTAPE PAR ÉTAPE : 1. ANALYSE CLINIQUE : Que signifie ce diagnostic sur le plan médical ? 2. CODES CANDIDATS : Quels codes des sources fournies sont compatibles ? 3. DISCRIMINATION : Pourquoi choisir un code plutôt qu'un autre ? (inclusions/exclusions, spécificité) 4. RÈGLE PMSI : Ce code est-il conforme pour un {type_diag} ? (guide méthodologique) Après ton raisonnement, conclus OBLIGATOIREMENT par le JSON suivant sur une ligne séparée : {_RESULT_MARKER} {{"code": "X99.9", "confidence": "high|medium|low", "justification": "explication courte en français"}}""" def _parse_ollama_response(raw: str) -> dict | None: """Parse la réponse Ollama en extrayant le JSON après le marqueur ###RESULT###. Fallback sur la recherche d'accolades si le marqueur est absent. Retourne un dict avec les clés code/confidence/justification + raisonnement. """ raisonnement = None json_str = None # Stratégie 1 : chercher le marqueur ###RESULT### marker_pos = raw.find(_RESULT_MARKER) if marker_pos != -1: raisonnement = raw[:marker_pos].strip() after_marker = raw[marker_pos + len(_RESULT_MARKER):] brace_start = after_marker.find("{") brace_end = after_marker.rfind("}") if brace_start != -1 and brace_end != -1: json_str = after_marker[brace_start:brace_end + 1] else: # Fallback : chercher le dernier bloc JSON dans la réponse # (le raisonnement peut contenir des accolades intermédiaires) last_brace = raw.rfind("}") if last_brace != -1: # Chercher l'accolade ouvrante correspondante en remontant depth = 0 start = -1 for i in range(last_brace, -1, -1): if raw[i] == "}": depth += 1 elif raw[i] == "{": depth -= 1 if depth == 0: start = i break if start != -1: json_str = raw[start:last_brace + 1] raisonnement = raw[:start].strip() if not json_str: logger.warning("Ollama : réponse sans JSON valide : %s", raw[:200]) return None try: parsed = json.loads(json_str) except json.JSONDecodeError: logger.warning("Ollama : JSON invalide : %s", json_str[:200]) return None if raisonnement: parsed["raisonnement"] = raisonnement return parsed def _call_ollama(prompt: str) -> dict | None: """Appelle Ollama et parse la réponse JSON.""" try: response = requests.post( f"{OLLAMA_URL}/api/generate", json={ "model": OLLAMA_MODEL, "prompt": prompt, "stream": False, "options": { "temperature": 0.1, "num_predict": 1200, }, }, timeout=OLLAMA_TIMEOUT, ) response.raise_for_status() raw = response.json().get("response", "") return _parse_ollama_response(raw) except requests.ConnectionError: logger.warning("Ollama non disponible (connexion refusée)") return None except requests.Timeout: logger.warning("Ollama timeout après %ds", OLLAMA_TIMEOUT) return None except (requests.RequestException, json.JSONDecodeError) as e: logger.warning("Ollama erreur : %s", e) return None def enrich_diagnostic( diagnostic: Diagnostic, contexte: dict, est_dp: bool = True, ) -> None: """Enrichit un Diagnostic avec le RAG (FAISS + Ollama). Modifie le diagnostic en place. Fallback gracieux si FAISS ou Ollama échouent. """ # 1. Recherche FAISS sources = search_similar(diagnostic.texte, top_k=10) if not sources: logger.debug("Aucune source RAG trouvée pour : %s", diagnostic.texte) return # 2. Stocker les sources RAG diagnostic.sources_rag = [ RAGSource( document=s["document"], page=s.get("page"), code=s.get("code"), extrait=s.get("extrait", "")[:200], ) for s in sources ] # 3. Appel Ollama pour justification avec raisonnement structuré prompt = _build_prompt(diagnostic.texte, sources, contexte, est_dp=est_dp) llm_result = _call_ollama(prompt) if llm_result: code = llm_result.get("code") confidence = llm_result.get("confidence") justification = llm_result.get("justification") raisonnement = llm_result.get("raisonnement") if code: diagnostic.cim10_suggestion = code if confidence in ("high", "medium", "low"): diagnostic.cim10_confidence = confidence if justification: diagnostic.justification = justification if raisonnement: diagnostic.raisonnement = raisonnement else: logger.info("Ollama non disponible — sources FAISS conservées sans justification LLM") def enrich_dossier(dossier: DossierMedical) -> None: """Enrichit le DP et tous les DAS d'un dossier via le RAG.""" contexte = { "sexe": dossier.sejour.sexe, "age": dossier.sejour.age, "duree_sejour": dossier.sejour.duree_sejour, "imc": dossier.sejour.imc, "antecedents": dossier.antecedents[:5], "biologie_cle": [(b.test, b.valeur, b.anomalie) for b in dossier.biologie_cle], "imagerie": [(i.type, (i.conclusion or "")[:200]) for i in dossier.imagerie], "complications": dossier.complications, } if dossier.diagnostic_principal: logger.info("RAG enrichissement DP : %s", dossier.diagnostic_principal.texte) enrich_diagnostic(dossier.diagnostic_principal, contexte, est_dp=True) # Pour les DAS, ajouter le DP et les DAS existants au contexte pour cohérence if dossier.diagnostic_principal: contexte["dp_texte"] = dossier.diagnostic_principal.texte contexte["das_codes_existants"] = [ f"{d.cim10_suggestion} ({d.texte})" for d in dossier.diagnostics_associes if d.cim10_suggestion ] for das in dossier.diagnostics_associes: logger.info("RAG enrichissement DAS : %s", das.texte) enrich_diagnostic(das, contexte, est_dp=False)