Files
t2a_v2/src/medical/rag_search.py
dom f5a6122495 feat: timings granulaires appels LLM (RAG, DAS, DP, QC)
Ajoute des mesures time.monotonic() autour de chaque appel LLM dans
rag_search.py : enrich_diagnostic, enrich_acte, extract_das_llm.
Format de log : logger.info("⏱ [RAG-DP] 14.2s — texte")

Découpe enrich_dossier() en 2 fonctions exportées :
- enrich_dp() : enrichit seulement le DP (parallélisable)
- enrich_das_and_actes() : enrichit DAS + actes en parallèle
L'ancienne enrich_dossier() reste comme wrapper rétro-compatible.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-08 14:16:38 +01:00

934 lines
34 KiB
Python

"""Recherche RAG (FAISS) + génération via Ollama pour le codage CIM-10."""
from __future__ import annotations
import logging
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from ..config import (
ActeCCAM, Diagnostic, DossierMedical, PreuveClinique, RAGSource,
OLLAMA_CACHE_PATH, OLLAMA_MAX_PARALLEL, get_model,
EMBEDDING_MODEL, RERANKER_MODEL,
)
from .cim10_dict import normalize_code, validate_code as cim10_validate, fallback_parent_code
from .bio_normals import BIO_NORMALS
from .clinical_context import build_enriched_context, format_enriched_context
from .ccam_dict import validate_code as ccam_validate
from .ollama_client import call_ollama, parse_json_response
from .ollama_cache import OllamaCache
from ..prompts import CODING_CIM10, CODING_CCAM, DAS_EXTRACTION
logger = logging.getLogger(__name__)
# Singleton pour le modèle d'embedding (chargé une seule fois)
_embed_model = None
_embed_lock = threading.Lock()
_embed_failed = False # Sentinelle pour éviter les retries infinis
# Singleton pour le cross-encoder de re-ranking (CPU uniquement)
_reranker_model = None
_reranker_lock = threading.Lock()
# Cache d'embeddings : évite de recalculer les vecteurs pour les mêmes textes
_embedding_cache: dict[str, "numpy.ndarray"] = {}
_embedding_cache_lock = threading.Lock()
_EMBEDDING_CACHE_MAX = 5000
# Score minimum de similarité FAISS pour retenir un résultat
_MIN_SCORE = 0.3
# Seuil rehaussé pour le contexte CPAM (filtrage plus agressif du bruit)
_MIN_SCORE_CPAM = 0.40
def _get_embed_model():
"""Charge le modèle d'embedding (singleton thread-safe).
Tente CUDA d'abord, fallback CPU si OOM (Ollama peut occuper la VRAM).
low_cpu_mem_usage=False évite les meta tensors (accelerate + sentence-transformers 5.x).
Un Lock empêche les chargements concurrents depuis le ThreadPool.
"""
global _embed_model, _embed_failed
if _embed_model is not None:
return _embed_model
if _embed_failed:
raise RuntimeError("Modèle d'embedding indisponible (échec précédent)")
with _embed_lock:
# Double-check après acquisition du lock
if _embed_model is not None:
return _embed_model
if _embed_failed:
raise RuntimeError("Modèle d'embedding indisponible (échec précédent)")
from sentence_transformers import SentenceTransformer
import torch
_device = "cpu" if os.environ.get("T2A_EMBED_CPU") else ("cuda" if torch.cuda.is_available() else "cpu")
_model_kwargs = {"low_cpu_mem_usage": False}
try:
logger.info("Chargement du modèle d'embedding (%s)...", _device)
_embed_model = SentenceTransformer(
EMBEDDING_MODEL, device=_device, model_kwargs=_model_kwargs,
)
except (torch.OutOfMemoryError, torch.cuda.CudaError, torch.AcceleratorError,
RuntimeError, NotImplementedError) as exc:
exc_msg = str(exc).lower()
if _device == "cuda" and ("memory" in exc_msg or "meta tensor" in exc_msg):
logger.warning("CUDA erreur pour l'embedding — fallback CPU : %s", exc)
torch.cuda.empty_cache()
try:
_embed_model = SentenceTransformer(
EMBEDDING_MODEL, device="cpu", model_kwargs=_model_kwargs,
)
except Exception as exc2:
logger.error("Fallback CPU aussi en échec : %s", exc2)
_embed_failed = True
raise
else:
_embed_failed = True
raise
_embed_model.max_seq_length = 512
return _embed_model
def _get_reranker():
"""Charge le cross-encoder de re-ranking (singleton thread-safe, CPU uniquement).
Forcé sur CPU pour ne pas interférer avec Ollama sur GPU.
"""
global _reranker_model
if _reranker_model is not None:
return _reranker_model
with _reranker_lock:
# Double-check après acquisition du lock
if _reranker_model is not None:
return _reranker_model
from sentence_transformers import CrossEncoder
logger.info("Chargement du cross-encoder de re-ranking (cpu)...")
_reranker_model = CrossEncoder(RERANKER_MODEL, device="cpu")
return _reranker_model
def _rerank(query: str, results: list[dict], top_k: int) -> list[dict]:
"""Re-classe les résultats FAISS via un cross-encoder.
Args:
query: Texte de la requête originale.
results: Résultats FAISS avec clé 'extrait'.
top_k: Nombre de résultats à retourner.
Returns:
Résultats re-classés par score cross-encoder, limités à top_k.
"""
if not results:
return results
passages = [r.get("extrait", "") for r in results]
# Essayer le serveur distant d'abord
ce_scores = None
try:
from .remote_embed import rerank_remote
remote_scores = rerank_remote(query, passages)
if remote_scores is not None:
ce_scores = remote_scores
except ImportError:
pass
# Fallback local
if ce_scores is None:
reranker = _get_reranker()
pairs = [(query, p) for p in passages]
ce_scores = reranker.predict(pairs)
# Injecter le score cross-encoder et trier
for r, ce_score in zip(results, ce_scores):
r["score_faiss"] = r["score"]
r["score"] = float(ce_score)
results.sort(key=lambda r: r["score"], reverse=True)
return results[:top_k]
def _embed_cached(texts: list[str]) -> "numpy.ndarray":
"""Calcule les embeddings avec cache. Essaie le serveur distant d'abord."""
import numpy as np
results = [None] * len(texts)
to_compute: list[tuple[int, str]] = []
with _embedding_cache_lock:
for i, t in enumerate(texts):
cached = _embedding_cache.get(t)
if cached is not None:
results[i] = cached
else:
to_compute.append((i, t))
if to_compute:
new_texts = [t for _, t in to_compute]
# Essayer le serveur distant d'abord
new_vecs = None
try:
from .remote_embed import embed_remote
new_vecs = embed_remote(new_texts)
except ImportError:
pass
# Fallback local
if new_vecs is None:
model = _get_embed_model()
new_vecs = model.encode(new_texts, normalize_embeddings=True, batch_size=64)
new_vecs = np.array(new_vecs, dtype=np.float32)
with _embedding_cache_lock:
for j, (i, t) in enumerate(to_compute):
vec = new_vecs[j]
results[i] = vec
_embedding_cache[t] = vec
# Eviction simple si trop d'entrées
if len(_embedding_cache) > _EMBEDDING_CACHE_MAX:
keys = list(_embedding_cache.keys())
for k in keys[:len(keys) // 5]:
del _embedding_cache[k]
return np.array(results, dtype=np.float32)
def search_similar(query: str, top_k: int = 10) -> list[dict]:
"""Recherche les passages les plus similaires dans l'index FAISS.
Args:
query: Texte du diagnostic à rechercher.
top_k: Nombre de résultats à retourner.
Returns:
Liste de dicts avec les métadonnées + score de similarité,
filtrés par score minimum et priorisant les sources CIM-10.
"""
from .rag_index import get_index
import numpy as np
# Codage CIM-10 : on interroge l'index "ref" (pas le guide méthodo).
result = get_index(kind="ref")
if result is None:
logger.warning("Index FAISS non disponible")
return []
faiss_index, metadata = result
query_vec = _embed_cached([query])
# Chercher plus de résultats que top_k pour pouvoir filtrer ensuite
fetch_k = min(top_k * 2, faiss_index.ntotal)
scores, indices = faiss_index.search(query_vec, fetch_k)
raw_results = []
for score, idx in zip(scores[0], indices[0]):
if idx < 0:
continue
if float(score) < _MIN_SCORE:
continue
meta = metadata[idx].copy()
meta["score"] = float(score)
raw_results.append(meta)
# Codage : on garde uniquement CIM-10 + index alpha + éventuels référentiels uploadés en ref:...
cim10_results = [r for r in raw_results if r.get("document") == "cim10"]
alpha_results = [r for r in raw_results if r.get("document") == "cim10_alpha"]
ref_uploads = [r for r in raw_results if str(r.get("document", "")).startswith("ref:")]
cim10_results.sort(key=lambda r: r["score"], reverse=True)
alpha_results.sort(key=lambda r: r["score"], reverse=True)
ref_uploads.sort(key=lambda r: r["score"], reverse=True)
# Quotas : on veut garder le codage ancré sur CIM-10, tout en gardant un peu d'alpha et de ref.
q_cim10 = min(6, top_k)
q_alpha = 2 if top_k >= 10 else (1 if top_k >= 8 else 0)
q_alpha = min(q_alpha, max(0, top_k - q_cim10))
q_ref = max(0, top_k - q_cim10 - q_alpha)
q_ref = min(q_ref, 2) # éviter que les uploads 'ref:' prennent tout l'espace contexte
final: list[dict] = []
final.extend(cim10_results[:q_cim10])
final.extend(alpha_results[:q_alpha])
final.extend(ref_uploads[:q_ref])
# Compléter si on a moins que top_k (ex: pas assez d'alpha/ref)
if len(final) < top_k:
remaining = cim10_results[q_cim10:] + alpha_results[q_alpha:] + ref_uploads[q_ref:]
remaining.sort(key=lambda r: r["score"], reverse=True)
final.extend(remaining[: (top_k - len(final))])
return final
def search_similar_ccam(query: str, top_k: int = 8) -> list[dict]:
"""Recherche les passages CCAM les plus similaires dans l'index FAISS.
Même logique que search_similar() mais priorise les sources CCAM.
"""
from .rag_index import get_index
import numpy as np
# CCAM : index "ref".
result = get_index(kind="ref")
if result is None:
logger.warning("Index FAISS non disponible")
return []
faiss_index, metadata = result
query_vec = _embed_cached([query])
fetch_k = min(top_k * 2, faiss_index.ntotal)
scores, indices = faiss_index.search(query_vec, fetch_k)
raw_results = []
for score, idx in zip(scores[0], indices[0]):
if idx < 0:
continue
if float(score) < _MIN_SCORE:
continue
meta = metadata[idx].copy()
meta["score"] = float(score)
raw_results.append(meta)
# Prioriser les sources CCAM (au moins 5 sur top_k)
ccam_results = [r for r in raw_results if r["document"] == "ccam"]
other_results = [r for r in raw_results if r["document"] != "ccam"]
min_ccam = min(5, len(ccam_results))
final = ccam_results[:min_ccam]
remaining_slots = top_k - len(final)
remaining = ccam_results[min_ccam:] + other_results
remaining.sort(key=lambda r: r["score"], reverse=True)
final.extend(remaining[:remaining_slots])
return final
def search_similar_cpam(query: str, top_k: int = 8) -> list[dict]:
"""Recherche RAG spécifique au contexte CPAM (contre-argumentation).
Différences avec search_similar() :
- Priorité Guide Méthodologique (min 3 résultats) plutôt que CIM-10
- Seuil de score rehaussé (0.40 vs 0.30) pour éliminer le bruit
- Fetch élargi (top_k * 3) car filtrage plus agressif
- Déduplication par code CIM-10 (garde le meilleur score par code)
"""
from .rag_index import get_index
import numpy as np
# Contexte CPAM : on veut des procédures (guide) + définitions référentielles (CIM-10).
proc = get_index(kind="proc")
ref = get_index(kind="ref")
if proc is None and ref is None:
logger.warning("Index FAISS non disponible")
return []
query_vec = _embed_cached([query])
def _search_one(result_tuple, fetch_mult: int) -> list[dict]:
if result_tuple is None:
return []
faiss_index, metadata = result_tuple
fetch_k = min(top_k * fetch_mult, faiss_index.ntotal)
scores, indices = faiss_index.search(query_vec, fetch_k)
out = []
for score, idx in zip(scores[0], indices[0]):
if idx < 0:
continue
if float(score) < _MIN_SCORE_CPAM:
continue
meta = metadata[idx].copy()
meta["score"] = float(score)
out.append(meta)
return out
raw_proc = _search_one(proc, fetch_mult=3)
raw_ref = _search_one(ref, fetch_mult=3)
# Filtrer clairement :
# - proc : guide_methodo + uploads proc:
raw_proc = [r for r in raw_proc if r.get("document") == "guide_methodo" or str(r.get("document", "")).startswith("proc:")]
# - ref : CIM-10 + index alpha + uploads ref:
raw_ref = [r for r in raw_ref if r.get("document") in ("cim10", "cim10_alpha") or str(r.get("document", "")).startswith("ref:")]
raw_results = raw_proc + raw_ref
# Dédupliquer par code CIM-10 (garder meilleur score par code)
seen_codes: dict[str, dict] = {}
deduped = []
for r in raw_results:
code = r.get("code")
if code:
if code in seen_codes:
if r["score"] > seen_codes[code]["score"]:
seen_codes[code] = r
else:
seen_codes[code] = r
else:
deduped.append(r) # pas de code → garder (guide_methodo, etc.)
deduped.extend(seen_codes.values())
deduped.sort(key=lambda r: r["score"], reverse=True)
# Re-ranking cross-encoder (CPU) pour affiner le classement
reranked = _rerank(query, deduped, top_k=len(deduped))
# Prioriser le Guide Méthodologique (min 3 résultats)
guide_results = [r for r in reranked if r.get("document") == "guide_methodo" or str(r.get("document", "")).startswith("proc:")]
other_results = [
r for r in reranked
if not (r.get("document") == "guide_methodo" or str(r.get("document", "")).startswith("proc:"))
]
min_guide = min(3, len(guide_results))
final = guide_results[:min_guide]
remaining_slots = top_k - len(final)
remaining = guide_results[min_guide:] + other_results
remaining.sort(key=lambda r: r["score"], reverse=True)
final.extend(remaining[:remaining_slots])
return final
def _format_contexte(contexte: dict) -> str:
"""Formate le contexte patient de manière structurée pour le prompt."""
lines = []
sexe = contexte.get("sexe")
age = contexte.get("age")
imc = contexte.get("imc")
patient_parts = []
if sexe:
patient_parts.append(sexe)
if age:
patient_parts.append(f"{age} ans")
if imc:
patient_parts.append(f"IMC {imc}")
if patient_parts:
lines.append(f"- Patient : {', '.join(patient_parts)}")
duree = contexte.get("duree_sejour")
if duree:
lines.append(f"- Durée séjour : {duree} jours")
antecedents = contexte.get("antecedents")
if antecedents:
lines.append(f"- Antécédents : {', '.join(antecedents[:5])}")
biologie = contexte.get("biologie_cle")
if biologie:
bio_parts = []
for b in biologie:
test, valeur, anomalie = b if isinstance(b, (list, tuple)) else (b.get("test"), b.get("valeur"), b.get("anomalie"))
# Ajouter la plage de référence si connue
norme_str = ""
if test in BIO_NORMALS:
lo, hi = BIO_NORMALS[test]
lo_s = int(lo) if lo == int(lo) else lo
hi_s = int(hi) if hi == int(hi) else hi
norme_str = f" [N: {lo_s}-{hi_s}]"
marker = " (\u2191)" if anomalie else ""
bio_parts.append(f"{test} {valeur}{norme_str}{marker}")
lines.append(f"- Biologie : {', '.join(bio_parts)}")
imagerie = contexte.get("imagerie")
if imagerie:
for img in imagerie:
img_type, conclusion = img if isinstance(img, (list, tuple)) else (img.get("type"), img.get("conclusion"))
if conclusion:
lines.append(f"- Imagerie : {img_type}{conclusion[:200]}")
complications = contexte.get("complications")
if complications:
lines.append(f"- Complications : {', '.join(complications)}")
dp_texte = contexte.get("dp_texte")
if dp_texte:
lines.append(f"- DP du séjour : {dp_texte}")
das_codes = contexte.get("das_codes_existants")
if das_codes:
lines.append(f"- DAS déjà codés : {', '.join(das_codes)}")
return "\n".join(lines) if lines else "Non précisé"
def _format_sources(sources: list[dict]) -> str:
"""Formate les sources RAG pour injection dans un prompt."""
sources_text = ""
for i, src in enumerate(sources, 1):
doc_raw = str(src.get("document", ""))
if doc_raw.startswith("ref:"):
doc_name = f"Référentiel uploadé : {doc_raw[4:]}"
elif doc_raw.startswith("proc:"):
doc_name = f"Procédure uploadée : {doc_raw[5:]}"
else:
doc_name = {
"cim10": "CIM-10 FR 2026",
"cim10_alpha": "CIM-10 Index Alphabétique 2026",
"guide_methodo": "Guide Méthodologique MCO 2026",
"ccam": "CCAM PMSI V4 2025",
}.get(doc_raw, doc_raw)
code_info = f" (code: {src['code']})" if src.get("code") else ""
page_info = f" [page {src['page']}]" if src.get("page") else ""
sources_text += f"--- Source {i}: {doc_name}{code_info}{page_info} ---\n"
sources_text += (src.get("extrait", "")[:800]) + "\n\n"
return sources_text
def _build_prompt(texte: str, sources: list[dict], contexte: dict, est_dp: bool = True) -> str:
"""Construit le prompt expert DIM avec raisonnement structuré."""
type_diag = "DP (diagnostic principal)" if est_dp else "DAS (diagnostic associé significatif)"
ctx_str = format_enriched_context(contexte)
sources_text = _format_sources(sources)
return CODING_CIM10.format(
texte=texte,
type_diag=type_diag,
ctx_str=ctx_str,
sources_text=sources_text,
)
def _build_prompt_ccam(texte: str, sources: list[dict], contexte: dict) -> str:
"""Construit le prompt expert DIM pour le codage CCAM avec raisonnement structuré."""
ctx_str = format_enriched_context(contexte)
sources_text = _format_sources(sources)
return CODING_CCAM.format(
texte=texte,
ctx_str=ctx_str,
sources_text=sources_text,
)
def _parse_ollama_response(raw: str) -> dict | None:
"""Parse la réponse JSON d'Ollama et reconstitue le raisonnement structuré."""
parsed = parse_json_response(raw)
if parsed is None:
return None
# Reconstituer le raisonnement à partir des champs structurés
reasoning_parts = []
for key in ("analyse_clinique", "analyse_acte", "codes_candidats", "discrimination", "regle_pmsi"):
val = parsed.pop(key, None)
if val:
titre = key.replace("_", " ").upper()
reasoning_parts.append(f"{titre} :\n{val}")
if reasoning_parts:
parsed["raisonnement"] = "\n\n".join(reasoning_parts)
return parsed
def _call_ollama(prompt: str) -> dict | None:
"""Appelle Ollama (mode JSON) et parse la réponse avec reconstitution du raisonnement."""
result = call_ollama(prompt, temperature=0.1, max_tokens=2500, role="coding")
if result is None:
return None
# Reconstituer le raisonnement structuré
reasoning_parts = []
for key in ("analyse_clinique", "analyse_acte", "codes_candidats", "discrimination", "regle_pmsi"):
val = result.pop(key, None)
if val:
titre = key.replace("_", " ").upper()
reasoning_parts.append(f"{titre} :\n{val}")
if reasoning_parts:
result["raisonnement"] = "\n\n".join(reasoning_parts)
return result
def _apply_llm_result_diagnostic(diagnostic: Diagnostic, llm_result: dict) -> None:
"""Applique un résultat LLM (frais ou caché) à un Diagnostic."""
code = llm_result.get("code")
confidence = llm_result.get("confidence")
justification = llm_result.get("justification")
raisonnement = llm_result.get("raisonnement")
if code:
code = normalize_code(code)
# Garde-fou : rejeter un code sans raisonnement ni justification
# (corrélation forte avec les hallucinations LLM)
if not raisonnement and not justification:
logger.warning(
"RAG : code %s rejeté pour « %s » — raisonnement et justification vides",
code, diagnostic.texte,
)
code = None
if code:
is_valid, _ = cim10_validate(code)
if is_valid:
diagnostic.cim10_suggestion = code
else:
# Tenter fallback vers le code parent (D71.9 → D71)
parent = fallback_parent_code(code)
if parent:
logger.info(
"RAG : code Ollama %s invalide → fallback parent %s pour « %s »",
code, parent, diagnostic.texte,
)
diagnostic.cim10_suggestion = parent
else:
logger.warning(
"RAG : code Ollama %s invalide pour « %s », code ignoré",
code, diagnostic.texte,
)
if confidence in ("high", "medium", "low"):
diagnostic.cim10_confidence = confidence
if justification:
diagnostic.justification = justification
if raisonnement:
diagnostic.raisonnement = raisonnement
# Stocker les preuves cliniques
preuves = llm_result.get("preuves_cliniques", [])
if preuves and isinstance(preuves, list):
for p in preuves:
if isinstance(p, dict) and p.get("element"):
try:
diagnostic.preuves_cliniques.append(PreuveClinique(
type=p.get("type", "clinique"),
element=p["element"],
interpretation=p.get("interpretation", ""),
))
except Exception:
pass
def enrich_diagnostic(
diagnostic: Diagnostic,
contexte: dict,
est_dp: bool = True,
cache: OllamaCache | None = None,
) -> None:
"""Enrichit un Diagnostic avec le RAG (FAISS + Ollama).
Modifie le diagnostic en place. Fallback gracieux si FAISS ou Ollama échouent.
"""
t0 = time.monotonic()
diag_type = "dp" if est_dp else "das"
label = f"RAG-{diag_type.upper()}"
# 1. Vérifier le cache
cached = cache.get(diagnostic.texte, diag_type) if cache else None
# 2. Recherche FAISS (toujours, pour les sources_rag fraîches)
sources = search_similar(diagnostic.texte, top_k=10)
if not sources:
# Toujours initialiser sources_rag (même vide) pour traçabilité
diagnostic.sources_rag = []
logger.debug("RAG: 0 résultat FAISS pour « %s »", diagnostic.texte)
# Si un cache hit existe, appliquer le résultat LLM malgré l'absence de sources
if cached is not None:
logger.info("Cache hit (sans sources FAISS) pour %s : « %s »", diag_type.upper(), diagnostic.texte)
_apply_llm_result_diagnostic(diagnostic, cached)
elapsed = time.monotonic() - t0
logger.info("⏱ [%s] %.1fs — %s (no FAISS)", label, elapsed, diagnostic.texte[:60])
return
# 3. Stocker les sources RAG
diagnostic.sources_rag = [
RAGSource(
document=s["document"],
page=s.get("page"),
code=s.get("code"),
extrait=s.get("extrait", "")[:200],
)
for s in sources
]
# 4. Si cache hit, appliquer et court-circuiter Ollama
if cached is not None:
logger.info("Cache hit pour %s : « %s »", diag_type.upper(), diagnostic.texte)
_apply_llm_result_diagnostic(diagnostic, cached)
elapsed = time.monotonic() - t0
logger.info("⏱ [%s] %.1fs — %s (cache hit)", label, elapsed, diagnostic.texte[:60])
return
# 5. Appel Ollama pour justification avec raisonnement structuré
prompt = _build_prompt(diagnostic.texte, sources, contexte, est_dp=est_dp)
t_llm = time.monotonic()
llm_result = _call_ollama(prompt)
llm_elapsed = time.monotonic() - t_llm
if llm_result:
_apply_llm_result_diagnostic(diagnostic, llm_result)
if cache:
cache.put(diagnostic.texte, diag_type, llm_result)
else:
logger.info("Ollama non disponible — sources FAISS conservées sans justification LLM")
elapsed = time.monotonic() - t0
logger.info("⏱ [%s] %.1fs (LLM %.1fs) — %s", label, elapsed, llm_elapsed, diagnostic.texte[:60])
def _apply_llm_result_acte(acte: ActeCCAM, llm_result: dict) -> None:
"""Applique un résultat LLM (frais ou caché) à un ActeCCAM."""
code = llm_result.get("code")
confidence = llm_result.get("confidence")
justification = llm_result.get("justification")
raisonnement = llm_result.get("raisonnement")
if code:
code = code.strip().upper()
is_valid, _ = ccam_validate(code)
if is_valid:
acte.code_ccam_suggestion = code
else:
logger.warning(
"RAG : code CCAM Ollama %s invalide pour « %s », code ignoré",
code, acte.texte,
)
if confidence in ("high", "medium", "low"):
acte.ccam_confidence = confidence
if justification:
acte.justification = justification
if raisonnement:
acte.raisonnement = raisonnement
def enrich_acte(acte: ActeCCAM, contexte: dict, cache: OllamaCache | None = None) -> None:
"""Enrichit un ActeCCAM avec le RAG (FAISS + Ollama).
Modifie l'acte en place. Fallback gracieux si FAISS ou Ollama échouent.
"""
t0 = time.monotonic()
# 1. Vérifier le cache
cached = cache.get(acte.texte, "ccam") if cache else None
# 2. Recherche FAISS (sources CCAM priorisées)
sources = search_similar_ccam(acte.texte, top_k=8)
if not sources:
logger.debug("Aucune source RAG CCAM trouvée pour : %s", acte.texte)
return
# 3. Stocker les sources RAG
acte.sources_rag = [
RAGSource(
document=s["document"],
page=s.get("page"),
code=s.get("code"),
extrait=s.get("extrait", "")[:200],
)
for s in sources
]
# 4. Si cache hit, appliquer et court-circuiter Ollama
if cached is not None:
logger.info("Cache hit pour CCAM : « %s »", acte.texte)
_apply_llm_result_acte(acte, cached)
elapsed = time.monotonic() - t0
logger.info("⏱ [RAG-CCAM] %.1fs — %s (cache hit)", elapsed, acte.texte[:60])
return
# 5. Appel Ollama pour justification avec raisonnement structuré
prompt = _build_prompt_ccam(acte.texte, sources, contexte)
t_llm = time.monotonic()
llm_result = _call_ollama(prompt)
llm_elapsed = time.monotonic() - t_llm
if llm_result:
_apply_llm_result_acte(acte, llm_result)
if cache:
cache.put(acte.texte, "ccam", llm_result)
else:
logger.info("Ollama non disponible — sources FAISS CCAM conservées sans justification LLM")
elapsed = time.monotonic() - t0
logger.info("⏱ [RAG-CCAM] %.1fs (LLM %.1fs) — %s", elapsed, llm_elapsed, acte.texte[:60])
def _smart_truncate(text: str, max_chars: int = 6000) -> str:
"""Troncature intelligente : garde le début + les sections finales importantes.
Pour les textes longs, on garde :
- Les premiers 60% de max_chars (début du document : identité, motif, histoire)
- Les derniers 40% (conclusion, synthèse, diagnostic de sortie, TTT)
Séparés par un marqueur [...] pour indiquer la troncature.
"""
if len(text) <= max_chars:
return text
head_size = int(max_chars * 0.6)
tail_size = max_chars - head_size - 30 # 30 chars pour le séparateur
# Chercher une fin de phrase propre pour le head
head = text[:head_size]
last_newline = head.rfind("\n")
if last_newline > head_size * 0.8:
head = head[:last_newline]
# Chercher un début de ligne propre pour le tail
tail_start = len(text) - tail_size
first_newline = text.find("\n", tail_start)
if first_newline > 0 and first_newline < tail_start + 200:
tail_start = first_newline + 1
tail = text[tail_start:]
return head + "\n\n[... texte tronqué ...]\n\n" + tail
def _build_prompt_das_extraction(text: str, contexte: dict, existing_das: list[str], dp_texte: str) -> str:
"""Construit le prompt pour l'extraction LLM de DAS supplémentaires."""
ctx_str = format_enriched_context(contexte)
existing_str = "\n".join(f"- {d}" for d in existing_das) if existing_das else "Aucun"
return DAS_EXTRACTION.format(
dp_texte=dp_texte or "Non identifié",
existing_str=existing_str,
ctx_str=ctx_str,
text_medical=_smart_truncate(text, 6000),
)
def extract_das_llm(
text: str,
contexte: dict,
existing_das: list[str],
dp_texte: str,
cache: OllamaCache | None = None,
) -> list[dict]:
"""Extrait des DAS supplémentaires via un pass LLM.
Args:
text: Texte médical complet.
contexte: Contexte patient (sexe, age, etc.).
existing_das: Liste des DAS déjà codés (texte + code).
dp_texte: Texte du diagnostic principal.
cache: Cache Ollama optionnel.
Returns:
Liste de dicts {texte, code_cim10, justification} pour les DAS détectés.
"""
import hashlib
t0 = time.monotonic()
# Clé de cache basée sur le hash du texte
text_hash = hashlib.md5(text[:4000].encode()).hexdigest()[:16]
cache_key_text = f"das_extract::{text_hash}"
# Vérifier le cache
if cache is not None:
cached = cache.get(cache_key_text, "das_llm")
if cached is not None:
elapsed = time.monotonic() - t0
logger.info("⏱ [DAS-LLM] %.1fs — cache hit", elapsed)
return cached.get("diagnostics_supplementaires", [])
# Construire le prompt et appeler Ollama
prompt = _build_prompt_das_extraction(text, contexte, existing_das, dp_texte)
t_llm = time.monotonic()
result = call_ollama(prompt, temperature=0.1, max_tokens=2000, role="coding")
llm_elapsed = time.monotonic() - t_llm
if result is None:
elapsed = time.monotonic() - t0
logger.warning("⏱ [DAS-LLM] %.1fs — Ollama non disponible", elapsed)
return []
das_list = result.get("diagnostics_supplementaires", [])
if not isinstance(das_list, list):
logger.warning("Extraction DAS LLM : format inattendu")
return []
# Stocker dans le cache
if cache is not None:
cache.put(cache_key_text, "das_llm", result)
elapsed = time.monotonic() - t0
logger.info("⏱ [DAS-LLM] %.1fs (LLM %.1fs) — %d diagnostics détectés", elapsed, llm_elapsed, len(das_list))
return das_list
def enrich_dp(dossier: DossierMedical, cache: OllamaCache | None = None) -> None:
"""Enrichit SEULEMENT le DP d'un dossier via le RAG (Phase 1).
Peut être exécuté en parallèle avec d'autres tâches (DAS LLM, DP selector)
car il ne dépend que du DP existant.
Args:
dossier: Dossier médical à enrichir (modifié en place).
cache: Cache Ollama optionnel (créé si None).
"""
t0 = time.monotonic()
if cache is None:
cache = OllamaCache(OLLAMA_CACHE_PATH, get_model("coding"))
contexte = build_enriched_context(dossier)
if dossier.diagnostic_principal:
logger.info("RAG enrichissement DP : %s", dossier.diagnostic_principal.texte)
enrich_diagnostic(dossier.diagnostic_principal, contexte, est_dp=True, cache=cache)
cache.save()
elapsed = time.monotonic() - t0
logger.info("⏱ [RAG-DP-PHASE] %.1fs — enrichissement DP terminé", elapsed)
def enrich_das_and_actes(dossier: DossierMedical, cache: OllamaCache | None = None) -> None:
"""Enrichit les DAS et actes CCAM d'un dossier via le RAG (Phase 2).
Parallélise les appels Ollama (max_workers = OLLAMA_MAX_PARALLEL).
Doit être appelé APRES enrich_dp() et après l'ajout des DAS LLM.
Args:
dossier: Dossier médical à enrichir (modifié en place).
cache: Cache Ollama optionnel (créé si None).
"""
t0 = time.monotonic()
if cache is None:
cache = OllamaCache(OLLAMA_CACHE_PATH, get_model("coding"))
contexte = build_enriched_context(dossier)
# Mettre à jour le contexte avec le DP pour les DAS
if dossier.diagnostic_principal:
contexte["dp_texte"] = dossier.diagnostic_principal.texte
contexte["das_codes_existants"] = [
f"{d.cim10_suggestion} ({d.texte})"
for d in dossier.diagnostics_associes
if d.cim10_suggestion
]
das_list = dossier.diagnostics_associes
actes_list = dossier.actes_ccam
if das_list or actes_list:
with ThreadPoolExecutor(max_workers=OLLAMA_MAX_PARALLEL) as executor:
futures = []
for das in das_list:
logger.info("RAG enrichissement DAS : %s", das.texte)
futures.append(executor.submit(enrich_diagnostic, das, contexte, False, cache))
for acte in actes_list:
logger.info("RAG enrichissement CCAM : %s", acte.texte)
futures.append(executor.submit(enrich_acte, acte, contexte, cache))
for f in as_completed(futures):
f.result() # propage les exceptions
cache.save()
elapsed = time.monotonic() - t0
n_das = len(das_list) if das_list else 0
n_actes = len(actes_list) if actes_list else 0
logger.info("⏱ [RAG-DAS-PHASE] %.1fs — %d DAS + %d actes enrichis", elapsed, n_das, n_actes)
def enrich_dossier(dossier: DossierMedical) -> None:
"""Enrichit le DP et tous les DAS d'un dossier via le RAG.
Wrapper rétro-compatible qui appelle enrich_dp() puis enrich_das_and_actes().
Utilise un cache persistant partagé entre les deux phases.
"""
cache = OllamaCache(OLLAMA_CACHE_PATH, get_model("coding"))
enrich_dp(dossier, cache=cache)
enrich_das_and_actes(dossier, cache=cache)