"""Recherche RAG (FAISS) + génération via Ollama pour le codage CIM-10.""" from __future__ import annotations import logging import os import threading import time from concurrent.futures import ThreadPoolExecutor, as_completed from ..config import ( ActeCCAM, Diagnostic, DossierMedical, PreuveClinique, RAGSource, OLLAMA_CACHE_PATH, OLLAMA_MAX_PARALLEL, get_model, EMBEDDING_MODEL, RERANKER_MODEL, ) from .cim10_dict import normalize_code, validate_code as cim10_validate, fallback_parent_code from .bio_normals import BIO_NORMALS from .clinical_context import build_enriched_context, format_enriched_context from .ccam_dict import validate_code as ccam_validate from .ollama_client import call_ollama, parse_json_response from .ollama_cache import OllamaCache from ..prompts import CODING_CIM10, CODING_CCAM, DAS_EXTRACTION logger = logging.getLogger(__name__) # Singleton pour le modèle d'embedding (chargé une seule fois) _embed_model = None _embed_lock = threading.Lock() _embed_failed = False # Sentinelle pour éviter les retries infinis # Singleton pour le cross-encoder de re-ranking (CPU uniquement) _reranker_model = None _reranker_lock = threading.Lock() # Cache d'embeddings : évite de recalculer les vecteurs pour les mêmes textes _embedding_cache: dict[str, "numpy.ndarray"] = {} _embedding_cache_lock = threading.Lock() _EMBEDDING_CACHE_MAX = 5000 # Score minimum de similarité FAISS pour retenir un résultat _MIN_SCORE = 0.3 # Seuil rehaussé pour le contexte CPAM (filtrage plus agressif du bruit) _MIN_SCORE_CPAM = 0.40 def _get_embed_model(): """Charge le modèle d'embedding (singleton thread-safe). Tente CUDA d'abord, fallback CPU si OOM (Ollama peut occuper la VRAM). low_cpu_mem_usage=False évite les meta tensors (accelerate + sentence-transformers 5.x). Un Lock empêche les chargements concurrents depuis le ThreadPool. """ global _embed_model, _embed_failed if _embed_model is not None: return _embed_model if _embed_failed: raise RuntimeError("Modèle d'embedding indisponible (échec précédent)") with _embed_lock: # Double-check après acquisition du lock if _embed_model is not None: return _embed_model if _embed_failed: raise RuntimeError("Modèle d'embedding indisponible (échec précédent)") from sentence_transformers import SentenceTransformer import torch _device = "cpu" if os.environ.get("T2A_EMBED_CPU") else ("cuda" if torch.cuda.is_available() else "cpu") _model_kwargs = {"low_cpu_mem_usage": False} try: logger.info("Chargement du modèle d'embedding (%s)...", _device) _embed_model = SentenceTransformer( EMBEDDING_MODEL, device=_device, model_kwargs=_model_kwargs, ) except (torch.OutOfMemoryError, torch.cuda.CudaError, torch.AcceleratorError, RuntimeError, NotImplementedError) as exc: exc_msg = str(exc).lower() if _device == "cuda" and ("memory" in exc_msg or "meta tensor" in exc_msg): logger.warning("CUDA erreur pour l'embedding — fallback CPU : %s", exc) torch.cuda.empty_cache() try: _embed_model = SentenceTransformer( EMBEDDING_MODEL, device="cpu", model_kwargs=_model_kwargs, ) except Exception as exc2: logger.error("Fallback CPU aussi en échec : %s", exc2) _embed_failed = True raise else: _embed_failed = True raise _embed_model.max_seq_length = 512 return _embed_model def _get_reranker(): """Charge le cross-encoder de re-ranking (singleton thread-safe, CPU uniquement). Forcé sur CPU pour ne pas interférer avec Ollama sur GPU. """ global _reranker_model if _reranker_model is not None: return _reranker_model with _reranker_lock: # Double-check après acquisition du lock if _reranker_model is not None: return _reranker_model from sentence_transformers import CrossEncoder logger.info("Chargement du cross-encoder de re-ranking (cpu)...") _reranker_model = CrossEncoder(RERANKER_MODEL, device="cpu") return _reranker_model def _rerank(query: str, results: list[dict], top_k: int) -> list[dict]: """Re-classe les résultats FAISS via un cross-encoder. Args: query: Texte de la requête originale. results: Résultats FAISS avec clé 'extrait'. top_k: Nombre de résultats à retourner. Returns: Résultats re-classés par score cross-encoder, limités à top_k. """ if not results: return results passages = [r.get("extrait", "") for r in results] # Essayer le serveur distant d'abord ce_scores = None try: from .remote_embed import rerank_remote remote_scores = rerank_remote(query, passages) if remote_scores is not None: ce_scores = remote_scores except ImportError: pass # Fallback local if ce_scores is None: reranker = _get_reranker() pairs = [(query, p) for p in passages] ce_scores = reranker.predict(pairs) # Injecter le score cross-encoder et trier for r, ce_score in zip(results, ce_scores): r["score_faiss"] = r["score"] r["score"] = float(ce_score) results.sort(key=lambda r: r["score"], reverse=True) return results[:top_k] def _embed_cached(texts: list[str]) -> "numpy.ndarray": """Calcule les embeddings avec cache. Essaie le serveur distant d'abord.""" import numpy as np results = [None] * len(texts) to_compute: list[tuple[int, str]] = [] with _embedding_cache_lock: for i, t in enumerate(texts): cached = _embedding_cache.get(t) if cached is not None: results[i] = cached else: to_compute.append((i, t)) if to_compute: new_texts = [t for _, t in to_compute] # Essayer le serveur distant d'abord new_vecs = None try: from .remote_embed import embed_remote new_vecs = embed_remote(new_texts) except ImportError: pass # Fallback local if new_vecs is None: model = _get_embed_model() new_vecs = model.encode(new_texts, normalize_embeddings=True, batch_size=64) new_vecs = np.array(new_vecs, dtype=np.float32) with _embedding_cache_lock: for j, (i, t) in enumerate(to_compute): vec = new_vecs[j] results[i] = vec _embedding_cache[t] = vec # Eviction simple si trop d'entrées if len(_embedding_cache) > _EMBEDDING_CACHE_MAX: keys = list(_embedding_cache.keys()) for k in keys[:len(keys) // 5]: del _embedding_cache[k] return np.array(results, dtype=np.float32) def search_similar(query: str, top_k: int = 10) -> list[dict]: """Recherche les passages les plus similaires dans l'index FAISS. Args: query: Texte du diagnostic à rechercher. top_k: Nombre de résultats à retourner. Returns: Liste de dicts avec les métadonnées + score de similarité, filtrés par score minimum et priorisant les sources CIM-10. """ from .rag_index import get_index import numpy as np # Codage CIM-10 : on interroge l'index "ref" (pas le guide méthodo). result = get_index(kind="ref") if result is None: logger.warning("Index FAISS non disponible") return [] faiss_index, metadata = result query_vec = _embed_cached([query]) # Chercher plus de résultats que top_k pour pouvoir filtrer ensuite fetch_k = min(top_k * 2, faiss_index.ntotal) scores, indices = faiss_index.search(query_vec, fetch_k) raw_results = [] for score, idx in zip(scores[0], indices[0]): if idx < 0: continue if float(score) < _MIN_SCORE: continue meta = metadata[idx].copy() meta["score"] = float(score) raw_results.append(meta) # Codage : on garde uniquement CIM-10 + index alpha + éventuels référentiels uploadés en ref:... cim10_results = [r for r in raw_results if r.get("document") == "cim10"] alpha_results = [r for r in raw_results if r.get("document") == "cim10_alpha"] ref_uploads = [r for r in raw_results if str(r.get("document", "")).startswith("ref:")] cim10_results.sort(key=lambda r: r["score"], reverse=True) alpha_results.sort(key=lambda r: r["score"], reverse=True) ref_uploads.sort(key=lambda r: r["score"], reverse=True) # Quotas : on veut garder le codage ancré sur CIM-10, tout en gardant un peu d'alpha et de ref. q_cim10 = min(6, top_k) q_alpha = 2 if top_k >= 10 else (1 if top_k >= 8 else 0) q_alpha = min(q_alpha, max(0, top_k - q_cim10)) q_ref = max(0, top_k - q_cim10 - q_alpha) q_ref = min(q_ref, 2) # éviter que les uploads 'ref:' prennent tout l'espace contexte final: list[dict] = [] final.extend(cim10_results[:q_cim10]) final.extend(alpha_results[:q_alpha]) final.extend(ref_uploads[:q_ref]) # Compléter si on a moins que top_k (ex: pas assez d'alpha/ref) if len(final) < top_k: remaining = cim10_results[q_cim10:] + alpha_results[q_alpha:] + ref_uploads[q_ref:] remaining.sort(key=lambda r: r["score"], reverse=True) final.extend(remaining[: (top_k - len(final))]) return final def search_similar_ccam(query: str, top_k: int = 8) -> list[dict]: """Recherche les passages CCAM les plus similaires dans l'index FAISS. Même logique que search_similar() mais priorise les sources CCAM. """ from .rag_index import get_index import numpy as np # CCAM : index "ref". result = get_index(kind="ref") if result is None: logger.warning("Index FAISS non disponible") return [] faiss_index, metadata = result query_vec = _embed_cached([query]) fetch_k = min(top_k * 2, faiss_index.ntotal) scores, indices = faiss_index.search(query_vec, fetch_k) raw_results = [] for score, idx in zip(scores[0], indices[0]): if idx < 0: continue if float(score) < _MIN_SCORE: continue meta = metadata[idx].copy() meta["score"] = float(score) raw_results.append(meta) # Prioriser les sources CCAM (au moins 5 sur top_k) ccam_results = [r for r in raw_results if r["document"] == "ccam"] other_results = [r for r in raw_results if r["document"] != "ccam"] min_ccam = min(5, len(ccam_results)) final = ccam_results[:min_ccam] remaining_slots = top_k - len(final) remaining = ccam_results[min_ccam:] + other_results remaining.sort(key=lambda r: r["score"], reverse=True) final.extend(remaining[:remaining_slots]) return final def search_similar_cpam(query: str, top_k: int = 8) -> list[dict]: """Recherche RAG spécifique au contexte CPAM (contre-argumentation). Différences avec search_similar() : - Priorité Guide Méthodologique (min 3 résultats) plutôt que CIM-10 - Seuil de score rehaussé (0.40 vs 0.30) pour éliminer le bruit - Fetch élargi (top_k * 3) car filtrage plus agressif - Déduplication par code CIM-10 (garde le meilleur score par code) """ from .rag_index import get_index import numpy as np # Contexte CPAM : on veut des procédures (guide) + définitions référentielles (CIM-10). proc = get_index(kind="proc") ref = get_index(kind="ref") if proc is None and ref is None: logger.warning("Index FAISS non disponible") return [] query_vec = _embed_cached([query]) def _search_one(result_tuple, fetch_mult: int) -> list[dict]: if result_tuple is None: return [] faiss_index, metadata = result_tuple fetch_k = min(top_k * fetch_mult, faiss_index.ntotal) scores, indices = faiss_index.search(query_vec, fetch_k) out = [] for score, idx in zip(scores[0], indices[0]): if idx < 0: continue if float(score) < _MIN_SCORE_CPAM: continue meta = metadata[idx].copy() meta["score"] = float(score) out.append(meta) return out raw_proc = _search_one(proc, fetch_mult=3) raw_ref = _search_one(ref, fetch_mult=3) # Filtrer clairement : # - proc : guide_methodo + uploads proc: raw_proc = [r for r in raw_proc if r.get("document") == "guide_methodo" or str(r.get("document", "")).startswith("proc:")] # - ref : CIM-10 + index alpha + uploads ref: raw_ref = [r for r in raw_ref if r.get("document") in ("cim10", "cim10_alpha") or str(r.get("document", "")).startswith("ref:")] raw_results = raw_proc + raw_ref # Dédupliquer par code CIM-10 (garder meilleur score par code) seen_codes: dict[str, dict] = {} deduped = [] for r in raw_results: code = r.get("code") if code: if code in seen_codes: if r["score"] > seen_codes[code]["score"]: seen_codes[code] = r else: seen_codes[code] = r else: deduped.append(r) # pas de code → garder (guide_methodo, etc.) deduped.extend(seen_codes.values()) deduped.sort(key=lambda r: r["score"], reverse=True) # Re-ranking cross-encoder (CPU) pour affiner le classement reranked = _rerank(query, deduped, top_k=len(deduped)) # Prioriser le Guide Méthodologique (min 3 résultats) guide_results = [r for r in reranked if r.get("document") == "guide_methodo" or str(r.get("document", "")).startswith("proc:")] other_results = [ r for r in reranked if not (r.get("document") == "guide_methodo" or str(r.get("document", "")).startswith("proc:")) ] min_guide = min(3, len(guide_results)) final = guide_results[:min_guide] remaining_slots = top_k - len(final) remaining = guide_results[min_guide:] + other_results remaining.sort(key=lambda r: r["score"], reverse=True) final.extend(remaining[:remaining_slots]) return final def _format_contexte(contexte: dict) -> str: """Formate le contexte patient de manière structurée pour le prompt.""" lines = [] sexe = contexte.get("sexe") age = contexte.get("age") imc = contexte.get("imc") patient_parts = [] if sexe: patient_parts.append(sexe) if age: patient_parts.append(f"{age} ans") if imc: patient_parts.append(f"IMC {imc}") if patient_parts: lines.append(f"- Patient : {', '.join(patient_parts)}") duree = contexte.get("duree_sejour") if duree: lines.append(f"- Durée séjour : {duree} jours") antecedents = contexte.get("antecedents") if antecedents: lines.append(f"- Antécédents : {', '.join(antecedents[:5])}") biologie = contexte.get("biologie_cle") if biologie: bio_parts = [] for b in biologie: test, valeur, anomalie = b if isinstance(b, (list, tuple)) else (b.get("test"), b.get("valeur"), b.get("anomalie")) # Ajouter la plage de référence si connue norme_str = "" if test in BIO_NORMALS: lo, hi = BIO_NORMALS[test] lo_s = int(lo) if lo == int(lo) else lo hi_s = int(hi) if hi == int(hi) else hi norme_str = f" [N: {lo_s}-{hi_s}]" marker = " (\u2191)" if anomalie else "" bio_parts.append(f"{test} {valeur}{norme_str}{marker}") lines.append(f"- Biologie : {', '.join(bio_parts)}") imagerie = contexte.get("imagerie") if imagerie: for img in imagerie: img_type, conclusion = img if isinstance(img, (list, tuple)) else (img.get("type"), img.get("conclusion")) if conclusion: lines.append(f"- Imagerie : {img_type} — {conclusion[:200]}") complications = contexte.get("complications") if complications: lines.append(f"- Complications : {', '.join(complications)}") dp_texte = contexte.get("dp_texte") if dp_texte: lines.append(f"- DP du séjour : {dp_texte}") das_codes = contexte.get("das_codes_existants") if das_codes: lines.append(f"- DAS déjà codés : {', '.join(das_codes)}") return "\n".join(lines) if lines else "Non précisé" def _format_sources(sources: list[dict]) -> str: """Formate les sources RAG pour injection dans un prompt.""" sources_text = "" for i, src in enumerate(sources, 1): doc_raw = str(src.get("document", "")) if doc_raw.startswith("ref:"): doc_name = f"Référentiel uploadé : {doc_raw[4:]}" elif doc_raw.startswith("proc:"): doc_name = f"Procédure uploadée : {doc_raw[5:]}" else: doc_name = { "cim10": "CIM-10 FR 2026", "cim10_alpha": "CIM-10 Index Alphabétique 2026", "guide_methodo": "Guide Méthodologique MCO 2026", "ccam": "CCAM PMSI V4 2025", }.get(doc_raw, doc_raw) code_info = f" (code: {src['code']})" if src.get("code") else "" page_info = f" [page {src['page']}]" if src.get("page") else "" sources_text += f"--- Source {i}: {doc_name}{code_info}{page_info} ---\n" sources_text += (src.get("extrait", "")[:800]) + "\n\n" return sources_text def _build_prompt(texte: str, sources: list[dict], contexte: dict, est_dp: bool = True) -> str: """Construit le prompt expert DIM avec raisonnement structuré.""" type_diag = "DP (diagnostic principal)" if est_dp else "DAS (diagnostic associé significatif)" ctx_str = format_enriched_context(contexte) sources_text = _format_sources(sources) return CODING_CIM10.format( texte=texte, type_diag=type_diag, ctx_str=ctx_str, sources_text=sources_text, ) def _build_prompt_ccam(texte: str, sources: list[dict], contexte: dict) -> str: """Construit le prompt expert DIM pour le codage CCAM avec raisonnement structuré.""" ctx_str = format_enriched_context(contexte) sources_text = _format_sources(sources) return CODING_CCAM.format( texte=texte, ctx_str=ctx_str, sources_text=sources_text, ) def _parse_ollama_response(raw: str) -> dict | None: """Parse la réponse JSON d'Ollama et reconstitue le raisonnement structuré.""" parsed = parse_json_response(raw) if parsed is None: return None # Reconstituer le raisonnement à partir des champs structurés reasoning_parts = [] for key in ("analyse_clinique", "analyse_acte", "codes_candidats", "discrimination", "regle_pmsi"): val = parsed.pop(key, None) if val: titre = key.replace("_", " ").upper() reasoning_parts.append(f"{titre} :\n{val}") if reasoning_parts: parsed["raisonnement"] = "\n\n".join(reasoning_parts) return parsed def _call_ollama(prompt: str) -> dict | None: """Appelle Ollama (mode JSON) et parse la réponse avec reconstitution du raisonnement.""" result = call_ollama(prompt, temperature=0.1, max_tokens=2500, role="coding") if result is None: return None # Reconstituer le raisonnement structuré reasoning_parts = [] for key in ("analyse_clinique", "analyse_acte", "codes_candidats", "discrimination", "regle_pmsi"): val = result.pop(key, None) if val: titre = key.replace("_", " ").upper() reasoning_parts.append(f"{titre} :\n{val}") if reasoning_parts: result["raisonnement"] = "\n\n".join(reasoning_parts) return result def _apply_llm_result_diagnostic(diagnostic: Diagnostic, llm_result: dict) -> None: """Applique un résultat LLM (frais ou caché) à un Diagnostic.""" code = llm_result.get("code") confidence = llm_result.get("confidence") justification = llm_result.get("justification") raisonnement = llm_result.get("raisonnement") if code: code = normalize_code(code) # Garde-fou : rejeter un code sans raisonnement ni justification # (corrélation forte avec les hallucinations LLM) if not raisonnement and not justification: logger.warning( "RAG : code %s rejeté pour « %s » — raisonnement et justification vides", code, diagnostic.texte, ) code = None if code: is_valid, _ = cim10_validate(code) if is_valid: diagnostic.cim10_suggestion = code else: # Tenter fallback vers le code parent (D71.9 → D71) parent = fallback_parent_code(code) if parent: logger.info( "RAG : code Ollama %s invalide → fallback parent %s pour « %s »", code, parent, diagnostic.texte, ) diagnostic.cim10_suggestion = parent else: logger.warning( "RAG : code Ollama %s invalide pour « %s », code ignoré", code, diagnostic.texte, ) if confidence in ("high", "medium", "low"): diagnostic.cim10_confidence = confidence if justification: diagnostic.justification = justification if raisonnement: diagnostic.raisonnement = raisonnement # Stocker les preuves cliniques preuves = llm_result.get("preuves_cliniques", []) if preuves and isinstance(preuves, list): for p in preuves: if isinstance(p, dict) and p.get("element"): try: diagnostic.preuves_cliniques.append(PreuveClinique( type=p.get("type", "clinique"), element=p["element"], interpretation=p.get("interpretation", ""), )) except Exception: pass def enrich_diagnostic( diagnostic: Diagnostic, contexte: dict, est_dp: bool = True, cache: OllamaCache | None = None, ) -> None: """Enrichit un Diagnostic avec le RAG (FAISS + Ollama). Modifie le diagnostic en place. Fallback gracieux si FAISS ou Ollama échouent. """ t0 = time.monotonic() diag_type = "dp" if est_dp else "das" label = f"RAG-{diag_type.upper()}" # 1. Vérifier le cache cached = cache.get(diagnostic.texte, diag_type) if cache else None # 2. Recherche FAISS (toujours, pour les sources_rag fraîches) sources = search_similar(diagnostic.texte, top_k=10) if not sources: # Toujours initialiser sources_rag (même vide) pour traçabilité diagnostic.sources_rag = [] logger.debug("RAG: 0 résultat FAISS pour « %s »", diagnostic.texte) # Si un cache hit existe, appliquer le résultat LLM malgré l'absence de sources if cached is not None: logger.info("Cache hit (sans sources FAISS) pour %s : « %s »", diag_type.upper(), diagnostic.texte) _apply_llm_result_diagnostic(diagnostic, cached) elapsed = time.monotonic() - t0 logger.info("⏱ [%s] %.1fs — %s (no FAISS)", label, elapsed, diagnostic.texte[:60]) return # 3. Stocker les sources RAG diagnostic.sources_rag = [ RAGSource( document=s["document"], page=s.get("page"), code=s.get("code"), extrait=s.get("extrait", "")[:200], ) for s in sources ] # 4. Si cache hit, appliquer et court-circuiter Ollama if cached is not None: logger.info("Cache hit pour %s : « %s »", diag_type.upper(), diagnostic.texte) _apply_llm_result_diagnostic(diagnostic, cached) elapsed = time.monotonic() - t0 logger.info("⏱ [%s] %.1fs — %s (cache hit)", label, elapsed, diagnostic.texte[:60]) return # 5. Appel Ollama pour justification avec raisonnement structuré prompt = _build_prompt(diagnostic.texte, sources, contexte, est_dp=est_dp) t_llm = time.monotonic() llm_result = _call_ollama(prompt) llm_elapsed = time.monotonic() - t_llm if llm_result: _apply_llm_result_diagnostic(diagnostic, llm_result) if cache: cache.put(diagnostic.texte, diag_type, llm_result) else: logger.info("Ollama non disponible — sources FAISS conservées sans justification LLM") elapsed = time.monotonic() - t0 logger.info("⏱ [%s] %.1fs (LLM %.1fs) — %s", label, elapsed, llm_elapsed, diagnostic.texte[:60]) def _apply_llm_result_acte(acte: ActeCCAM, llm_result: dict) -> None: """Applique un résultat LLM (frais ou caché) à un ActeCCAM.""" code = llm_result.get("code") confidence = llm_result.get("confidence") justification = llm_result.get("justification") raisonnement = llm_result.get("raisonnement") if code: code = code.strip().upper() is_valid, _ = ccam_validate(code) if is_valid: acte.code_ccam_suggestion = code else: logger.warning( "RAG : code CCAM Ollama %s invalide pour « %s », code ignoré", code, acte.texte, ) if confidence in ("high", "medium", "low"): acte.ccam_confidence = confidence if justification: acte.justification = justification if raisonnement: acte.raisonnement = raisonnement def enrich_acte(acte: ActeCCAM, contexte: dict, cache: OllamaCache | None = None) -> None: """Enrichit un ActeCCAM avec le RAG (FAISS + Ollama). Modifie l'acte en place. Fallback gracieux si FAISS ou Ollama échouent. """ t0 = time.monotonic() # 1. Vérifier le cache cached = cache.get(acte.texte, "ccam") if cache else None # 2. Recherche FAISS (sources CCAM priorisées) sources = search_similar_ccam(acte.texte, top_k=8) if not sources: logger.debug("Aucune source RAG CCAM trouvée pour : %s", acte.texte) return # 3. Stocker les sources RAG acte.sources_rag = [ RAGSource( document=s["document"], page=s.get("page"), code=s.get("code"), extrait=s.get("extrait", "")[:200], ) for s in sources ] # 4. Si cache hit, appliquer et court-circuiter Ollama if cached is not None: logger.info("Cache hit pour CCAM : « %s »", acte.texte) _apply_llm_result_acte(acte, cached) elapsed = time.monotonic() - t0 logger.info("⏱ [RAG-CCAM] %.1fs — %s (cache hit)", elapsed, acte.texte[:60]) return # 5. Appel Ollama pour justification avec raisonnement structuré prompt = _build_prompt_ccam(acte.texte, sources, contexte) t_llm = time.monotonic() llm_result = _call_ollama(prompt) llm_elapsed = time.monotonic() - t_llm if llm_result: _apply_llm_result_acte(acte, llm_result) if cache: cache.put(acte.texte, "ccam", llm_result) else: logger.info("Ollama non disponible — sources FAISS CCAM conservées sans justification LLM") elapsed = time.monotonic() - t0 logger.info("⏱ [RAG-CCAM] %.1fs (LLM %.1fs) — %s", elapsed, llm_elapsed, acte.texte[:60]) def _smart_truncate(text: str, max_chars: int = 6000) -> str: """Troncature intelligente : garde le début + les sections finales importantes. Pour les textes longs, on garde : - Les premiers 60% de max_chars (début du document : identité, motif, histoire) - Les derniers 40% (conclusion, synthèse, diagnostic de sortie, TTT) Séparés par un marqueur [...] pour indiquer la troncature. """ if len(text) <= max_chars: return text head_size = int(max_chars * 0.6) tail_size = max_chars - head_size - 30 # 30 chars pour le séparateur # Chercher une fin de phrase propre pour le head head = text[:head_size] last_newline = head.rfind("\n") if last_newline > head_size * 0.8: head = head[:last_newline] # Chercher un début de ligne propre pour le tail tail_start = len(text) - tail_size first_newline = text.find("\n", tail_start) if first_newline > 0 and first_newline < tail_start + 200: tail_start = first_newline + 1 tail = text[tail_start:] return head + "\n\n[... texte tronqué ...]\n\n" + tail def _build_prompt_das_extraction(text: str, contexte: dict, existing_das: list[str], dp_texte: str) -> str: """Construit le prompt pour l'extraction LLM de DAS supplémentaires.""" ctx_str = format_enriched_context(contexte) existing_str = "\n".join(f"- {d}" for d in existing_das) if existing_das else "Aucun" return DAS_EXTRACTION.format( dp_texte=dp_texte or "Non identifié", existing_str=existing_str, ctx_str=ctx_str, text_medical=_smart_truncate(text, 6000), ) def extract_das_llm( text: str, contexte: dict, existing_das: list[str], dp_texte: str, cache: OllamaCache | None = None, ) -> list[dict]: """Extrait des DAS supplémentaires via un pass LLM. Args: text: Texte médical complet. contexte: Contexte patient (sexe, age, etc.). existing_das: Liste des DAS déjà codés (texte + code). dp_texte: Texte du diagnostic principal. cache: Cache Ollama optionnel. Returns: Liste de dicts {texte, code_cim10, justification} pour les DAS détectés. """ import hashlib t0 = time.monotonic() # Clé de cache basée sur le hash du texte text_hash = hashlib.md5(text[:4000].encode()).hexdigest()[:16] cache_key_text = f"das_extract::{text_hash}" # Vérifier le cache if cache is not None: cached = cache.get(cache_key_text, "das_llm") if cached is not None: elapsed = time.monotonic() - t0 logger.info("⏱ [DAS-LLM] %.1fs — cache hit", elapsed) return cached.get("diagnostics_supplementaires", []) # Construire le prompt et appeler Ollama prompt = _build_prompt_das_extraction(text, contexte, existing_das, dp_texte) t_llm = time.monotonic() result = call_ollama(prompt, temperature=0.1, max_tokens=2000, role="coding") llm_elapsed = time.monotonic() - t_llm if result is None: elapsed = time.monotonic() - t0 logger.warning("⏱ [DAS-LLM] %.1fs — Ollama non disponible", elapsed) return [] das_list = result.get("diagnostics_supplementaires", []) if not isinstance(das_list, list): logger.warning("Extraction DAS LLM : format inattendu") return [] # Stocker dans le cache if cache is not None: cache.put(cache_key_text, "das_llm", result) elapsed = time.monotonic() - t0 logger.info("⏱ [DAS-LLM] %.1fs (LLM %.1fs) — %d diagnostics détectés", elapsed, llm_elapsed, len(das_list)) return das_list def enrich_dp(dossier: DossierMedical, cache: OllamaCache | None = None) -> None: """Enrichit SEULEMENT le DP d'un dossier via le RAG (Phase 1). Peut être exécuté en parallèle avec d'autres tâches (DAS LLM, DP selector) car il ne dépend que du DP existant. Args: dossier: Dossier médical à enrichir (modifié en place). cache: Cache Ollama optionnel (créé si None). """ t0 = time.monotonic() if cache is None: cache = OllamaCache(OLLAMA_CACHE_PATH, get_model("coding")) contexte = build_enriched_context(dossier) if dossier.diagnostic_principal: logger.info("RAG enrichissement DP : %s", dossier.diagnostic_principal.texte) enrich_diagnostic(dossier.diagnostic_principal, contexte, est_dp=True, cache=cache) cache.save() elapsed = time.monotonic() - t0 logger.info("⏱ [RAG-DP-PHASE] %.1fs — enrichissement DP terminé", elapsed) def enrich_das_and_actes(dossier: DossierMedical, cache: OllamaCache | None = None) -> None: """Enrichit les DAS et actes CCAM d'un dossier via le RAG (Phase 2). Parallélise les appels Ollama (max_workers = OLLAMA_MAX_PARALLEL). Doit être appelé APRES enrich_dp() et après l'ajout des DAS LLM. Args: dossier: Dossier médical à enrichir (modifié en place). cache: Cache Ollama optionnel (créé si None). """ t0 = time.monotonic() if cache is None: cache = OllamaCache(OLLAMA_CACHE_PATH, get_model("coding")) contexte = build_enriched_context(dossier) # Mettre à jour le contexte avec le DP pour les DAS if dossier.diagnostic_principal: contexte["dp_texte"] = dossier.diagnostic_principal.texte contexte["das_codes_existants"] = [ f"{d.cim10_suggestion} ({d.texte})" for d in dossier.diagnostics_associes if d.cim10_suggestion ] das_list = dossier.diagnostics_associes actes_list = dossier.actes_ccam if das_list or actes_list: with ThreadPoolExecutor(max_workers=OLLAMA_MAX_PARALLEL) as executor: futures = [] for das in das_list: logger.info("RAG enrichissement DAS : %s", das.texte) futures.append(executor.submit(enrich_diagnostic, das, contexte, False, cache)) for acte in actes_list: logger.info("RAG enrichissement CCAM : %s", acte.texte) futures.append(executor.submit(enrich_acte, acte, contexte, cache)) for f in as_completed(futures): f.result() # propage les exceptions cache.save() elapsed = time.monotonic() - t0 n_das = len(das_list) if das_list else 0 n_actes = len(actes_list) if actes_list else 0 logger.info("⏱ [RAG-DAS-PHASE] %.1fs — %d DAS + %d actes enrichis", elapsed, n_das, n_actes) def enrich_dossier(dossier: DossierMedical) -> None: """Enrichit le DP et tous les DAS d'un dossier via le RAG. Wrapper rétro-compatible qui appelle enrich_dp() puis enrich_das_and_actes(). Utilise un cache persistant partagé entre les deux phases. """ cache = OllamaCache(OLLAMA_CACHE_PATH, get_model("coding")) enrich_dp(dossier, cache=cache) enrich_das_and_actes(dossier, cache=cache)