"""Recherche RAG (FAISS) + génération via Ollama pour le codage CIM-10.""" from __future__ import annotations import logging from concurrent.futures import ThreadPoolExecutor, as_completed from ..config import ( ActeCCAM, Diagnostic, DossierMedical, RAGSource, OLLAMA_CACHE_PATH, OLLAMA_MAX_PARALLEL, OLLAMA_MODEL, ) from .cim10_dict import normalize_code, validate_code as cim10_validate from .ccam_dict import validate_code as ccam_validate from .ollama_client import call_ollama, parse_json_response from .ollama_cache import OllamaCache logger = logging.getLogger(__name__) # Singleton pour le modèle d'embedding (chargé une seule fois) _embed_model = None # Score minimum de similarité FAISS pour retenir un résultat _MIN_SCORE = 0.3 def _get_embed_model(): """Charge le modèle d'embedding (singleton).""" global _embed_model if _embed_model is None: from sentence_transformers import SentenceTransformer logger.info("Chargement du modèle d'embedding pour la recherche...") import torch _device = "cuda" if torch.cuda.is_available() else "cpu" _embed_model = SentenceTransformer("dangvantuan/sentence-camembert-large", device=_device) _embed_model.max_seq_length = 512 return _embed_model def search_similar(query: str, top_k: int = 10) -> list[dict]: """Recherche les passages les plus similaires dans l'index FAISS. Args: query: Texte du diagnostic à rechercher. top_k: Nombre de résultats à retourner. Returns: Liste de dicts avec les métadonnées + score de similarité, filtrés par score minimum et priorisant les sources CIM-10. """ from .rag_index import get_index import numpy as np result = get_index() if result is None: logger.warning("Index FAISS non disponible") return [] faiss_index, metadata = result model = _get_embed_model() query_vec = model.encode([query], normalize_embeddings=True) query_vec = np.array(query_vec, dtype=np.float32) # Chercher plus de résultats que top_k pour pouvoir filtrer ensuite fetch_k = min(top_k * 2, faiss_index.ntotal) scores, indices = faiss_index.search(query_vec, fetch_k) raw_results = [] for score, idx in zip(scores[0], indices[0]): if idx < 0: continue if float(score) < _MIN_SCORE: continue meta = metadata[idx].copy() meta["score"] = float(score) raw_results.append(meta) # Prioriser les sources CIM-10 (au moins 6 sur top_k) cim10_results = [r for r in raw_results if r["document"] in ("cim10", "cim10_alpha")] other_results = [r for r in raw_results if r["document"] not in ("cim10", "cim10_alpha")] min_cim10 = min(6, len(cim10_results)) final = cim10_results[:min_cim10] remaining_slots = top_k - len(final) # Remplir le reste avec les meilleurs résultats (CIM-10 restants + autres) remaining = cim10_results[min_cim10:] + other_results remaining.sort(key=lambda r: r["score"], reverse=True) final.extend(remaining[:remaining_slots]) return final def search_similar_ccam(query: str, top_k: int = 8) -> list[dict]: """Recherche les passages CCAM les plus similaires dans l'index FAISS. Même logique que search_similar() mais priorise les sources CCAM. """ from .rag_index import get_index import numpy as np result = get_index() if result is None: logger.warning("Index FAISS non disponible") return [] faiss_index, metadata = result model = _get_embed_model() query_vec = model.encode([query], normalize_embeddings=True) query_vec = np.array(query_vec, dtype=np.float32) fetch_k = min(top_k * 2, faiss_index.ntotal) scores, indices = faiss_index.search(query_vec, fetch_k) raw_results = [] for score, idx in zip(scores[0], indices[0]): if idx < 0: continue if float(score) < _MIN_SCORE: continue meta = metadata[idx].copy() meta["score"] = float(score) raw_results.append(meta) # Prioriser les sources CCAM (au moins 5 sur top_k) ccam_results = [r for r in raw_results if r["document"] == "ccam"] other_results = [r for r in raw_results if r["document"] != "ccam"] min_ccam = min(5, len(ccam_results)) final = ccam_results[:min_ccam] remaining_slots = top_k - len(final) remaining = ccam_results[min_ccam:] + other_results remaining.sort(key=lambda r: r["score"], reverse=True) final.extend(remaining[:remaining_slots]) return final def _format_contexte(contexte: dict) -> str: """Formate le contexte patient de manière structurée pour le prompt.""" lines = [] sexe = contexte.get("sexe") age = contexte.get("age") imc = contexte.get("imc") patient_parts = [] if sexe: patient_parts.append(sexe) if age: patient_parts.append(f"{age} ans") if imc: patient_parts.append(f"IMC {imc}") if patient_parts: lines.append(f"- Patient : {', '.join(patient_parts)}") duree = contexte.get("duree_sejour") if duree: lines.append(f"- Durée séjour : {duree} jours") antecedents = contexte.get("antecedents") if antecedents: lines.append(f"- Antécédents : {', '.join(antecedents[:5])}") biologie = contexte.get("biologie_cle") if biologie: bio_parts = [] for b in biologie: test, valeur, anomalie = b if isinstance(b, (list, tuple)) else (b.get("test"), b.get("valeur"), b.get("anomalie")) marker = " (\u2191)" if anomalie else "" bio_parts.append(f"{test} {valeur}{marker}") lines.append(f"- Biologie : {', '.join(bio_parts)}") imagerie = contexte.get("imagerie") if imagerie: for img in imagerie: img_type, conclusion = img if isinstance(img, (list, tuple)) else (img.get("type"), img.get("conclusion")) if conclusion: lines.append(f"- Imagerie : {img_type} — {conclusion[:200]}") complications = contexte.get("complications") if complications: lines.append(f"- Complications : {', '.join(complications)}") dp_texte = contexte.get("dp_texte") if dp_texte: lines.append(f"- DP du séjour : {dp_texte}") das_codes = contexte.get("das_codes_existants") if das_codes: lines.append(f"- DAS déjà codés : {', '.join(das_codes)}") return "\n".join(lines) if lines else "Non précisé" def _build_prompt(texte: str, sources: list[dict], contexte: dict, est_dp: bool = True) -> str: """Construit le prompt expert DIM avec raisonnement structuré.""" sources_text = "" for i, src in enumerate(sources, 1): doc_name = { "cim10": "CIM-10 FR 2026", "cim10_alpha": "CIM-10 Index Alphabétique 2026", "guide_methodo": "Guide Méthodologique MCO 2026", "ccam": "CCAM PMSI V4 2025", }.get(src["document"], src["document"]) code_info = f" (code: {src['code']})" if src.get("code") else "" page_info = f" [page {src['page']}]" if src.get("page") else "" sources_text += f"--- Source {i}: {doc_name}{code_info}{page_info} ---\n" sources_text += (src.get("extrait", "")[:800]) + "\n\n" type_diag = "DP (diagnostic principal)" if est_dp else "DAS (diagnostic associé significatif)" ctx_str = _format_contexte(contexte) return f"""Tu es un médecin DIM (Département d'Information Médicale) expert en codage PMSI. Tu dois coder le diagnostic suivant en respectant STRICTEMENT les règles de l'ATIH. RÈGLES IMPÉRATIVES : - Le code doit provenir UNIQUEMENT des sources CIM-10 fournies - Distingue la DESCRIPTION CLINIQUE (ce que le médecin écrit) de la LOGIQUE DE CODAGE (ce que l'ATIH impose) - Privilégie le code le plus SPÉCIFIQUE disponible (4e ou 5e caractère) - Vérifie les notes d'inclusion/exclusion de chaque code candidat - Si le diagnostic est un DP, il doit refléter le motif principal de prise en charge du séjour - Si c'est un DAS, il doit avoir mobilisé des ressources supplémentaires pendant le séjour - EXCLUSION SYMPTÔME : Si le diagnostic est un symptôme (R00-R99) et qu'un diagnostic précis (Chapitres I-XIV, A00-N99) expliquant ce symptôme est présent, le symptôme ne doit PAS être codé comme DAS DIAGNOSTIC À CODER : "{texte}" TYPE : {type_diag} CONTEXTE CLINIQUE : {ctx_str} SOURCES CIM-10 : {sources_text} Réponds UNIQUEMENT avec un objet JSON au format suivant, sans aucun texte avant ou après : {{ "analyse_clinique": "que signifie ce diagnostic sur le plan médical", "codes_candidats": "quels codes CIM-10 des sources sont compatibles", "discrimination": "pourquoi choisir ce code plutôt qu'un autre (inclusions/exclusions, spécificité)", "regle_pmsi": "conformité aux règles PMSI pour un {type_diag} (guide méthodologique)", "code": "X99.9", "confidence": "high ou medium ou low", "justification": "explication courte en français" }}""" def _build_prompt_ccam(texte: str, sources: list[dict], contexte: dict) -> str: """Construit le prompt expert DIM pour le codage CCAM avec raisonnement structuré.""" sources_text = "" for i, src in enumerate(sources, 1): doc_name = { "cim10": "CIM-10 FR 2026", "cim10_alpha": "CIM-10 Index Alphabétique 2026", "guide_methodo": "Guide Méthodologique MCO 2026", "ccam": "CCAM PMSI V4 2025", }.get(src["document"], src["document"]) code_info = f" (code: {src['code']})" if src.get("code") else "" page_info = f" [page {src['page']}]" if src.get("page") else "" sources_text += f"--- Source {i}: {doc_name}{code_info}{page_info} ---\n" sources_text += (src.get("extrait", "")[:800]) + "\n\n" ctx_str = _format_contexte(contexte) return f"""Tu es un médecin DIM (Département d'Information Médicale) expert en codage CCAM PMSI. Tu dois coder l'acte chirurgical/médical suivant en respectant STRICTEMENT la nomenclature CCAM. RÈGLES IMPÉRATIVES : - Le code doit provenir UNIQUEMENT des sources CCAM fournies - Un code CCAM est composé de 4 lettres + 3 chiffres (ex: HMFC004) - Vérifie l'activité (1=acte technique, 4=anesthésie) et le regroupement - Tiens compte du tarif secteur 1 pour valider la cohérence - Si plusieurs codes sont possibles, choisis le plus spécifique à l'acte décrit - En cas de doute, indique confidence "low" plutôt que de proposer un code inadapté ACTE À CODER : "{texte}" CONTEXTE CLINIQUE : {ctx_str} SOURCES CCAM : {sources_text} Réponds UNIQUEMENT avec un objet JSON au format suivant, sans aucun texte avant ou après : {{ "analyse_acte": "que décrit cet acte sur le plan technique/chirurgical", "codes_candidats": "quels codes CCAM des sources sont compatibles", "discrimination": "pourquoi choisir ce code plutôt qu'un autre (activité, regroupement, tarif)", "code": "ABCD123", "confidence": "high ou medium ou low", "justification": "explication courte en français" }}""" def _parse_ollama_response(raw: str) -> dict | None: """Parse la réponse JSON d'Ollama et reconstitue le raisonnement structuré.""" parsed = parse_json_response(raw) if parsed is None: return None # Reconstituer le raisonnement à partir des champs structurés reasoning_parts = [] for key in ("analyse_clinique", "analyse_acte", "codes_candidats", "discrimination", "regle_pmsi"): val = parsed.pop(key, None) if val: titre = key.replace("_", " ").upper() reasoning_parts.append(f"{titre} :\n{val}") if reasoning_parts: parsed["raisonnement"] = "\n\n".join(reasoning_parts) return parsed def _call_ollama(prompt: str) -> dict | None: """Appelle Ollama (mode JSON) et parse la réponse avec reconstitution du raisonnement.""" result = call_ollama(prompt, temperature=0.1, max_tokens=2500) if result is None: return None # Reconstituer le raisonnement structuré reasoning_parts = [] for key in ("analyse_clinique", "analyse_acte", "codes_candidats", "discrimination", "regle_pmsi"): val = result.pop(key, None) if val: titre = key.replace("_", " ").upper() reasoning_parts.append(f"{titre} :\n{val}") if reasoning_parts: result["raisonnement"] = "\n\n".join(reasoning_parts) return result def _apply_llm_result_diagnostic(diagnostic: Diagnostic, llm_result: dict) -> None: """Applique un résultat LLM (frais ou caché) à un Diagnostic.""" code = llm_result.get("code") confidence = llm_result.get("confidence") justification = llm_result.get("justification") raisonnement = llm_result.get("raisonnement") if code: code = normalize_code(code) is_valid, _ = cim10_validate(code) if is_valid: diagnostic.cim10_suggestion = code else: logger.warning( "RAG : code Ollama %s invalide pour « %s », code ignoré", code, diagnostic.texte, ) if confidence in ("high", "medium", "low"): diagnostic.cim10_confidence = confidence if justification: diagnostic.justification = justification if raisonnement: diagnostic.raisonnement = raisonnement def enrich_diagnostic( diagnostic: Diagnostic, contexte: dict, est_dp: bool = True, cache: OllamaCache | None = None, ) -> None: """Enrichit un Diagnostic avec le RAG (FAISS + Ollama). Modifie le diagnostic en place. Fallback gracieux si FAISS ou Ollama échouent. """ diag_type = "dp" if est_dp else "das" # 1. Vérifier le cache cached = cache.get(diagnostic.texte, diag_type) if cache else None # 2. Recherche FAISS (toujours, pour les sources_rag fraîches) sources = search_similar(diagnostic.texte, top_k=10) if not sources: logger.debug("Aucune source RAG trouvée pour : %s", diagnostic.texte) return # 3. Stocker les sources RAG diagnostic.sources_rag = [ RAGSource( document=s["document"], page=s.get("page"), code=s.get("code"), extrait=s.get("extrait", "")[:200], ) for s in sources ] # 4. Si cache hit, appliquer et court-circuiter Ollama if cached is not None: logger.info("Cache hit pour %s : « %s »", diag_type.upper(), diagnostic.texte) _apply_llm_result_diagnostic(diagnostic, cached) return # 5. Appel Ollama pour justification avec raisonnement structuré prompt = _build_prompt(diagnostic.texte, sources, contexte, est_dp=est_dp) llm_result = _call_ollama(prompt) if llm_result: _apply_llm_result_diagnostic(diagnostic, llm_result) if cache: cache.put(diagnostic.texte, diag_type, llm_result) else: logger.info("Ollama non disponible — sources FAISS conservées sans justification LLM") def _apply_llm_result_acte(acte: ActeCCAM, llm_result: dict) -> None: """Applique un résultat LLM (frais ou caché) à un ActeCCAM.""" code = llm_result.get("code") confidence = llm_result.get("confidence") justification = llm_result.get("justification") raisonnement = llm_result.get("raisonnement") if code: code = code.strip().upper() is_valid, _ = ccam_validate(code) if is_valid: acte.code_ccam_suggestion = code else: logger.warning( "RAG : code CCAM Ollama %s invalide pour « %s », code ignoré", code, acte.texte, ) if confidence in ("high", "medium", "low"): acte.ccam_confidence = confidence if justification: acte.justification = justification if raisonnement: acte.raisonnement = raisonnement def enrich_acte(acte: ActeCCAM, contexte: dict, cache: OllamaCache | None = None) -> None: """Enrichit un ActeCCAM avec le RAG (FAISS + Ollama). Modifie l'acte en place. Fallback gracieux si FAISS ou Ollama échouent. """ # 1. Vérifier le cache cached = cache.get(acte.texte, "ccam") if cache else None # 2. Recherche FAISS (sources CCAM priorisées) sources = search_similar_ccam(acte.texte, top_k=8) if not sources: logger.debug("Aucune source RAG CCAM trouvée pour : %s", acte.texte) return # 3. Stocker les sources RAG acte.sources_rag = [ RAGSource( document=s["document"], page=s.get("page"), code=s.get("code"), extrait=s.get("extrait", "")[:200], ) for s in sources ] # 4. Si cache hit, appliquer et court-circuiter Ollama if cached is not None: logger.info("Cache hit pour CCAM : « %s »", acte.texte) _apply_llm_result_acte(acte, cached) return # 5. Appel Ollama pour justification avec raisonnement structuré prompt = _build_prompt_ccam(acte.texte, sources, contexte) llm_result = _call_ollama(prompt) if llm_result: _apply_llm_result_acte(acte, llm_result) if cache: cache.put(acte.texte, "ccam", llm_result) else: logger.info("Ollama non disponible — sources FAISS CCAM conservées sans justification LLM") def enrich_dossier(dossier: DossierMedical) -> None: """Enrichit le DP et tous les DAS d'un dossier via le RAG. Utilise un cache persistant et parallélise les appels Ollama pour les DAS et actes CCAM (max_workers = OLLAMA_MAX_PARALLEL). """ cache = OllamaCache(OLLAMA_CACHE_PATH, OLLAMA_MODEL) contexte = { "sexe": dossier.sejour.sexe, "age": dossier.sejour.age, "duree_sejour": dossier.sejour.duree_sejour, "imc": dossier.sejour.imc, "antecedents": dossier.antecedents[:5], "biologie_cle": [(b.test, b.valeur, b.anomalie) for b in dossier.biologie_cle], "imagerie": [(i.type, (i.conclusion or "")[:200]) for i in dossier.imagerie], "complications": dossier.complications, } # Phase 1 : DP seul (le contexte DAS en dépend) if dossier.diagnostic_principal: logger.info("RAG enrichissement DP : %s", dossier.diagnostic_principal.texte) enrich_diagnostic(dossier.diagnostic_principal, contexte, est_dp=True, cache=cache) # Mettre à jour le contexte avec le DP pour les DAS if dossier.diagnostic_principal: contexte["dp_texte"] = dossier.diagnostic_principal.texte contexte["das_codes_existants"] = [ f"{d.cim10_suggestion} ({d.texte})" for d in dossier.diagnostics_associes if d.cim10_suggestion ] # Phase 2 : DAS + Actes en parallèle das_list = dossier.diagnostics_associes actes_list = dossier.actes_ccam if das_list or actes_list: with ThreadPoolExecutor(max_workers=OLLAMA_MAX_PARALLEL) as executor: futures = [] for das in das_list: logger.info("RAG enrichissement DAS : %s", das.texte) futures.append(executor.submit(enrich_diagnostic, das, contexte, False, cache)) for acte in actes_list: logger.info("RAG enrichissement CCAM : %s", acte.texte) futures.append(executor.submit(enrich_acte, acte, contexte, cache)) for f in as_completed(futures): f.result() # propage les exceptions cache.save()