Merge branch 'feat/parallel-rag-pipeline' — optimisation flux LLM

This commit is contained in:
dom
2026-03-08 14:35:54 +01:00
7 changed files with 577 additions and 49 deletions

View File

@@ -60,7 +60,7 @@ OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://localhost:11434")
OLLAMA_MODEL = os.environ.get("OLLAMA_MODEL", "gemma3:27b")
OLLAMA_TIMEOUT = int(os.environ.get("OLLAMA_TIMEOUT", "600"))
OLLAMA_CACHE_PATH = BASE_DIR / "data" / "ollama_cache.json"
OLLAMA_MAX_PARALLEL = int(os.environ.get("OLLAMA_MAX_PARALLEL", "2"))
OLLAMA_MAX_PARALLEL = int(os.environ.get("OLLAMA_MAX_PARALLEL", "4"))
# --- Modèles par rôle LLM ---

View File

@@ -4,6 +4,14 @@ Orchestrateur principal — délègue aux sous-modules :
- cpam_rag : _search_rag_for_control(), _search_rag_queries()
- cpam_context : _build_cpam_prompt(), _build_tagged_context(), _build_bio_summary(), etc.
- cpam_validation : _validate_adversarial(), _validate_grounding(), _format_response(), etc.
Le flux batché (generate_cpam_responses_batched) regroupe les appels LLM par modèle
pour minimiser les swaps VRAM sur GPU unique :
Phase 1 (cpam model) : extraction + argumentation de TOUS les contrôles
Phase 2 (validation model, 1 swap) : validation adversariale de tous
Phase 3 (cpam model, 1 swap) : correction des contrôles échoués
Phase 4 (validation model, 1 swap) : re-validation des corrigés
Gain : de ~4N swaps à ~2-4 swaps (indépendant du nombre de contrôles).
"""
from __future__ import annotations
@@ -11,6 +19,8 @@ from __future__ import annotations
import json
import logging
import os
import time
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
@@ -41,6 +51,34 @@ from .cpam_validation import (
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Structure intermédiaire pour le batching
# ---------------------------------------------------------------------------
@dataclass
class _CpamDraft:
    """Intermediate state of one CPAM control while it moves through the
    batched pipeline (generate → validate → correct → finalize).

    One draft is created per control; each phase function fills in the
    fields it owns and later phases read them.
    """
    controle: ControleCPAM
    # Results of the successive phases
    extraction: dict | None = None
    degraded_pass1: bool = False
    sources: list[dict] = field(default_factory=list)
    rag_sources: list[RAGSource] = field(default_factory=list)
    prompt: str = ""
    tag_map: dict[str, str] = field(default_factory=dict)
    result: dict | None = None
    # Validation outputs (deterministic checks + adversarial pass)
    ref_warnings: list[str] = field(default_factory=list)
    grounding_warnings: list[str] = field(default_factory=list)
    code_warnings: list[str] = field(default_factory=list)
    adversarial_warnings: list[str] = field(default_factory=list)
    validation: dict | None = None
    # Final result
    needs_correction: bool = False
    text: str = ""
    response_data: dict | None = None
def _save_version(
dossier: DossierMedical,
controle: ControleCPAM,
@@ -149,14 +187,18 @@ def _extraction_pass(
)
logger.debug(" Passe 1 — extraction structurée")
t0 = time.time()
result = call_ollama(prompt, temperature=0.0, max_tokens=3000, role="cpam")
if result is None:
result = call_anthropic(prompt, temperature=0.0, max_tokens=3000)
elapsed = time.time() - t0
if result is not None:
logger.info(" Passe 1 OK : %d éléments cliniques extraits",
logger.info(" [CPAM-EXTRACT] %.1fs — OGC %s %d éléments cliniques extraits",
elapsed, controle.numero_ogc,
len(result.get("elements_cliniques_pertinents", [])))
else:
logger.warning(" Passe 1 échouée — fallback single-pass")
logger.warning(" [CPAM-EXTRACT] %.1fs — OGC %s — passe 1 échouée",
elapsed, controle.numero_ogc)
return result
@@ -195,14 +237,17 @@ def generate_cpam_response(
prompt, tag_map = _build_cpam_prompt(dossier, controle, sources, extraction)
# 4. Appel LLM — Ollama (rôle cpam) > Haiku fallback
t_gen = time.time()
result = call_ollama(prompt, temperature=0.1, max_tokens=8000, role="cpam")
if result is not None:
logger.info(" Contre-argumentation via Ollama")
logger.info(" [CPAM-GEN] %.1fs — OGC %s — contre-argumentation via Ollama",
time.time() - t_gen, controle.numero_ogc)
else:
logger.info(" Ollama indisponible → fallback Anthropic Haiku")
result = call_anthropic(prompt, temperature=0.1, max_tokens=8000)
if result is not None:
logger.info(" Contre-argumentation via Anthropic Haiku")
logger.info(" [CPAM-GEN] %.1fs — OGC %s — contre-argumentation via Anthropic",
time.time() - t_gen, controle.numero_ogc)
# 5. Conversion des sources RAG
rag_sources = [
@@ -285,9 +330,12 @@ def generate_cpam_response(
validation.get("score_confiance"), attempt + 1, max_corrections, len(erreurs_v))
correction_prompt = _build_correction_prompt(prompt, result, validation)
t_corr = time.time()
corrected = call_ollama(correction_prompt, temperature=0.0, max_tokens=16000, role="cpam")
if corrected is None:
corrected = call_anthropic(correction_prompt, temperature=0.0, max_tokens=16000)
logger.info(" [CPAM-CORR] %.1fs — OGC %s — correction %d/%d",
time.time() - t_corr, controle.numero_ogc, attempt + 1, max_corrections)
if not corrected:
break
@@ -347,3 +395,343 @@ def generate_cpam_response(
logger.info(" Contre-argumentation générée (%d caractères)", len(text))
return text, result, rag_sources
# ---------------------------------------------------------------------------
# Fonctions de phase (utilisées par le batching)
# ---------------------------------------------------------------------------
def _phase_generate(
    dossier: DossierMedical,
    draft: _CpamDraft,
) -> None:
    """Phase 1 — extraction + RAG + argumentation (role=cpam).

    Fills draft.extraction, draft.sources, draft.prompt, draft.tag_map and
    draft.result.  Makes no call to the validation model, so all drafts can
    run back-to-back on the "cpam" model without a VRAM swap.
    """
    controle = draft.controle
    # FIX: format string was "OGC %d%s" — the OGC number and the title ran
    # together with no separator; also use %s for numero_ogc like every
    # other log line in this module.
    logger.info("CPAM batch : génération pour OGC %s — %s",
                controle.numero_ogc, controle.titre)
    # 0. Versioning
    _save_version(dossier, controle)
    # 1. Structured extraction (role=cpam)
    draft.extraction = _extraction_pass(dossier, controle)
    draft.degraded_pass1 = draft.extraction is None
    if draft.degraded_pass1:
        dossier.alertes_codage.append(
            "CPAM: passe 1 (extraction structurée) échouée → mode dégradé"
        )
    # 2. RAG search (deterministic, no LLM call)
    draft.sources = _search_rag_for_control(controle, dossier)
    logger.info(" RAG : %d sources trouvées", len(draft.sources))
    # 3. Prompt construction
    draft.prompt, draft.tag_map = _build_cpam_prompt(
        dossier, controle, draft.sources, draft.extraction
    )
    # 4. Argumentation (role=cpam) with Anthropic fallback
    t_gen = time.time()
    result = call_ollama(draft.prompt, temperature=0.1, max_tokens=8000, role="cpam")
    if result is not None:
        logger.info(" [CPAM-GEN] %.1fs — OGC %s — contre-argumentation via Ollama",
                    time.time() - t_gen, controle.numero_ogc)
    else:
        logger.info(" Ollama indisponible → fallback Anthropic Haiku")
        result = call_anthropic(draft.prompt, temperature=0.1, max_tokens=8000)
        if result is not None:
            logger.info(" [CPAM-GEN] %.1fs — OGC %s — contre-argumentation via Anthropic",
                        time.time() - t_gen, controle.numero_ogc)
    # 5. RAG source conversion — done even when the LLM failed so the
    # sources remain traceable on the control.
    draft.rag_sources = [
        RAGSource(
            document=s.get("document", ""),
            page=s.get("page"),
            code=s.get("code"),
            extrait=s.get("extrait", "")[:200],
        )
        for s in draft.sources
    ]
    if result is None:
        logger.warning(" LLM non disponible — contre-argumentation non générée")
        draft.result = None
        return
    # 5b. Degraded mode: flag results produced without structured extraction
    if draft.degraded_pass1:
        result.setdefault("quality_flags", {})
        result["quality_flags"]["cpam_pass1_failed"] = True
        result["quality_flags"]["degraded_mode"] = True
    # 6. Deterministic sanitisation + guardian (no LLM)
    sanitized = _sanitize_unauthorized_codes(result, dossier, controle)
    if sanitized:
        logger.info(" CPAM : %d code(s) hors périmètre supprimé(s) : %s",
                    len(sanitized), ", ".join(sanitized))
    result = _guardian_deterministic(result, dossier, controle, draft.tag_map)
    # 7. Deterministic validations (no LLM)
    draft.ref_warnings = _validate_references(result, draft.sources)
    if draft.ref_warnings:
        logger.warning(" CPAM : %d référence(s) non vérifiable(s)", len(draft.ref_warnings))
    draft.grounding_warnings = _validate_grounding(result, draft.tag_map)
    if draft.grounding_warnings:
        logger.warning(" CPAM : %d preuve(s) non traçable(s)", len(draft.grounding_warnings))
    draft.code_warnings = _validate_codes_in_response(result, dossier, controle)
    if draft.code_warnings:
        logger.warning(" CPAM : %d code(s) hors périmètre", len(draft.code_warnings))
    draft.result = result
def _phase_validate(
    dossier: DossierMedical,
    draft: _CpamDraft,
) -> None:
    """Phase 2 — adversarial validation (role=validation).

    Fills draft.validation and draft.needs_correction; skips drafts whose
    generation phase produced nothing.
    """
    if draft.result is None:
        return
    ctrl = draft.controle
    # LOGIC-3: adversarial validation is meaningless when the generation
    # and validation models are the same — flag it on the result.
    from ..config import check_adversarial_model_config
    same_model, model_msg = check_adversarial_model_config()
    if same_model:
        flags = draft.result.setdefault("quality_flags", {})
        flags["adversarial_disabled_same_model"] = True
        dossier.alertes_codage.append(
            "Validation adversariale désactivée (modèles identiques)"
        )
    verdict = _validate_adversarial(draft.result, draft.tag_map, ctrl)
    draft.validation = verdict
    if verdict and not verdict.get("coherent", True):
        score = verdict.get("score_confiance", "?")
        draft.adversarial_warnings.extend(
            f"Incohérence détectée : {err}"
            for err in verdict.get("erreurs", [])
            if isinstance(err, str) and err.strip()
        )
        if draft.adversarial_warnings:
            draft.adversarial_warnings.append(f"Score de confiance : {score}/10")
    # A correction pass is needed only for an incoherent answer with a low
    # confidence score, and only while the correction-loop rule is enabled.
    retry_budget = int(os.environ.get("T2A_CPAM_MAX_CORRECTIONS", "2"))
    draft.needs_correction = bool(
        retry_budget > 0
        and verdict
        and not verdict.get("coherent", True)
        and verdict.get("score_confiance", 10) <= 5
        and rule_enabled("RULE-CPAM-CORRECTION-LOOP")
    )
def _phase_correct(
    dossier: DossierMedical,
    draft: _CpamDraft,
) -> None:
    """Phase 3 — correction (role=cpam) + re-validation (role=validation).

    The correction/re-validation loop is kept in one function because the
    two steps are coupled.  Only drafts with needs_correction=True enter.
    """
    if draft.result is None or not draft.needs_correction:
        return
    controle = draft.controle
    max_corrections = int(os.environ.get("T2A_CPAM_MAX_CORRECTIONS", "2"))
    for attempt in range(max_corrections):
        # Stop once the latest validation is acceptable (coherent, or score
        # above threshold) or the correction rule was disabled mid-run.
        if not (draft.validation
                and not draft.validation.get("coherent", True)
                and draft.validation.get("score_confiance", 10) <= 5
                and rule_enabled("RULE-CPAM-CORRECTION-LOOP")):
            break
        erreurs_v = draft.validation.get("erreurs", [])
        logger.warning(" Score adversarial %s/10 — correction %d/%d (%d erreur(s))",
                       draft.validation.get("score_confiance"),
                       attempt + 1, max_corrections, len(erreurs_v))
        correction_prompt = _build_correction_prompt(
            draft.prompt, draft.result, draft.validation
        )
        t_corr = time.time()
        corrected = call_ollama(
            correction_prompt, temperature=0.0, max_tokens=16000, role="cpam"
        )
        if corrected is None:
            corrected = call_anthropic(
                correction_prompt, temperature=0.0, max_tokens=16000
            )
        logger.info(" [CPAM-CORR] %.1fs — OGC %s — correction %d/%d",
                    time.time() - t_corr, controle.numero_ogc,
                    attempt + 1, max_corrections)
        if not corrected:
            break
        # Re-validate the corrected answer (role=validation)
        validation2 = _validate_adversarial(corrected, draft.tag_map, controle)
        score2 = validation2.get("score_confiance", 0) if validation2 else 0
        score1 = draft.validation.get("score_confiance", 0)
        if score2 > score1:
            # FIX: format string was "(score %s%s)" — old and new scores ran
            # together ("score 37" for 3 → 7); add an explicit arrow.
            logger.info(" Correction %d acceptée (score %s → %s)",
                        attempt + 1, score1, score2)
            draft.result = corrected
            draft.validation = validation2
            # Re-run the deterministic pipeline on the corrected result
            _sanitize_unauthorized_codes(draft.result, dossier, controle)
            draft.result = _guardian_deterministic(
                draft.result, dossier, controle, draft.tag_map
            )
            draft.ref_warnings = _validate_references(draft.result, draft.sources)
            draft.grounding_warnings = _validate_grounding(draft.result, draft.tag_map)
            draft.code_warnings = _validate_codes_in_response(
                draft.result, dossier, controle
            )
            draft.adversarial_warnings = []
            if draft.validation and not draft.validation.get("coherent", True):
                for e in draft.validation.get("erreurs", []):
                    if isinstance(e, str) and e.strip():
                        draft.adversarial_warnings.append(f"Incohérence détectée : {e}")
                if draft.adversarial_warnings:
                    draft.adversarial_warnings.append(
                        f"Score de confiance : {draft.validation.get('score_confiance', '?')}/10"
                    )
        else:
            # FIX: same missing-separator format string as the accepted branch.
            logger.warning(" Correction %d rejetée (score %s → %s)",
                           attempt + 1, score1, score2)
            break
def _phase_finalize(
    dossier: DossierMedical,
    draft: _CpamDraft,
) -> tuple[str, dict | None, list[RAGSource]]:
    """Final phase — quality assessment + formatting (deterministic, no LLM).

    Returns:
        Same signature as generate_cpam_response().
    """
    ctrl = draft.controle
    if draft.result is None:
        # Nothing generated; still return the RAG sources for traceability.
        return "", None, draft.rag_sources
    combined_warnings = (
        draft.ref_warnings
        + draft.grounding_warnings
        + draft.code_warnings
        + draft.adversarial_warnings
    )
    strength = _assess_dossier_strength(dossier)
    if strength["is_weak"]:
        logger.info(" Dossier à preuves limitées (score %d/10) : %s",
                    strength["score"], ", ".join(strength["missing"]))
    tier, needs_review, cat_warnings = _assess_quality_tier(
        draft.result, draft.ref_warnings, draft.grounding_warnings,
        draft.code_warnings, draft.validation,
        is_weak_dossier=strength["is_weak"],
    )
    # Persist the quality verdict on the control itself
    ctrl.quality_tier = tier
    ctrl.requires_review = needs_review
    ctrl.quality_warnings = cat_warnings
    logger.info(" Qualité CPAM : tier %s, requires_review=%s, %d warnings",
                tier, needs_review, len(cat_warnings))
    formatted = _format_response(
        draft.result,
        ref_warnings=combined_warnings,
        quality_tier=tier,
        categorized_warnings=cat_warnings,
    )
    logger.info(" Contre-argumentation générée (%d caractères)", len(formatted))
    return formatted, draft.result, draft.rag_sources
# ---------------------------------------------------------------------------
# Orchestrateur batché — minimise les swaps de modèle VRAM
# ---------------------------------------------------------------------------
def generate_cpam_responses_batched(
    dossier: DossierMedical,
    controles: list[ControleCPAM],
) -> None:
    """Generate counter-argumentations for ALL CPAM controls as one batch.

    Groups LLM calls by model to minimise VRAM model swaps:
      Phase 1 (cpam model)       : extraction + argumentation for every control
      Phase 2 (validation model) : adversarial validation for every control
      Phase 3 (cpam model)       : correction of the failed controls, if any
      Finalisation               : quality + formatting (deterministic)
    Gain: from ~4*N model swaps down to 2-4 swaps, independent of N.

    Args:
        dossier: The medical record (merged or single).
        controles: List of CPAM controls to process.

    Side effects:
        Fills controle.contre_argumentation, controle.response_data and
        controle.sources_reponse on every control.
    """
    if not controles:
        return
    started = time.time()
    total = len(controles)
    logger.info("CPAM batch : %d contrôle(s) — début du traitement batché", total)
    drafts = [_CpamDraft(controle=c) for c in controles]

    # --- Phase 1: generation (cpam model) ---
    phase_started = time.time()
    logger.info("CPAM batch phase 1/3 : génération (%d contrôles, role=cpam)", total)
    for d in drafts:
        _phase_generate(dossier, d)
    logger.info("CPAM batch phase 1/3 terminée : %.1fs", time.time() - phase_started)

    # --- Phase 2: adversarial validation (validation model — one swap) ---
    pending_validation = [d for d in drafts if d.result is not None]
    if pending_validation:
        phase_started = time.time()
        logger.info("CPAM batch phase 2/3 : validation (%d contrôles, role=validation)",
                    len(pending_validation))
        for d in pending_validation:
            _phase_validate(dossier, d)
        logger.info("CPAM batch phase 2/3 terminée : %.1fs", time.time() - phase_started)

    # --- Phase 3: correction + re-validation (cpam + validation models) ---
    pending_correction = [d for d in drafts if d.needs_correction]
    if pending_correction:
        phase_started = time.time()
        logger.info("CPAM batch phase 3/3 : correction (%d contrôles nécessitent correction)",
                    len(pending_correction))
        for d in pending_correction:
            _phase_correct(dossier, d)
        logger.info("CPAM batch phase 3/3 terminée : %.1fs", time.time() - phase_started)

    # --- Finalisation (deterministic, no model swap) ---
    for d in drafts:
        text, response_data, sources = _phase_finalize(dossier, d)
        d.controle.contre_argumentation = text
        d.controle.response_data = response_data
        d.controle.sources_reponse = sources

    elapsed = time.time() - started
    logger.info(
        "CPAM batch terminé : %d contrôle(s) en %.1fs (%.1fs/contrôle)",
        total, elapsed, elapsed / total if total else 0,
    )

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import logging
import re
import time
from ..config import ControleCPAM, DossierMedical
from ..medical.bio_normals import BIO_NORMALS
@@ -477,11 +478,14 @@ def _validate_adversarial(
)
logger.debug(" Validation adversariale")
t_val = time.time()
result = call_ollama(prompt, temperature=0.0, max_tokens=6000, role="validation")
if result is None:
result = call_anthropic(prompt, temperature=0.0, max_tokens=6000)
elapsed = time.time() - t_val
if result is None:
logger.warning(" Validation adversariale échouée — LLM indisponible")
logger.warning(" [CPAM-VALID] %.1fs — OGC %s — validation adversariale échouée",
elapsed, controle.numero_ogc)
return None
coherent = result.get("coherent", True)
@@ -489,12 +493,13 @@ def _validate_adversarial(
score = result.get("score_confiance", -1)
if not coherent and erreurs:
logger.warning(" Validation adversariale : %d incohérence(s) détectée(s) (score %s/10)",
len(erreurs), score)
logger.warning(" [CPAM-VALID] %.1fs — OGC %s %d incohérence(s) (score %s/10)",
elapsed, controle.numero_ogc, len(erreurs), score)
for e in erreurs:
logger.warning(" - %s", e)
else:
logger.info(" Validation adversariale OK (score %s/10)", score)
logger.info(" [CPAM-VALID] %.1fs — OGC %s OK (score %s/10)",
elapsed, controle.numero_ogc, score)
return result

View File

@@ -610,20 +610,18 @@ def main(input_path: str | None = None) -> None:
merged = None
# Contrôle CPAM : enrichir le dossier principal (fusionné ou dernier)
# Utilise le mode batché pour regrouper les appels LLM par modèle
# et minimiser les swaps VRAM (de ~4N swaps à ~2-4 swaps)
if cpam_data and subdir:
try:
from .control.cpam_parser import match_dossier_ogc
controles = match_dossier_ogc(subdir, cpam_data)
if controles:
from .control.cpam_response import generate_cpam_response
from .control.cpam_response import generate_cpam_responses_batched
target = merged if merged else (group_dossiers[-1] if group_dossiers else None)
if target:
logger.info(" CPAM : %d contrôle(s) pour %s", len(controles), subdir)
for ctrl in controles:
text, response_data, sources = generate_cpam_response(target, ctrl)
ctrl.contre_argumentation = text
ctrl.response_data = response_data
ctrl.sources_reponse = sources
generate_cpam_responses_batched(target, controles)
target.controles_cpam = controles
except Exception:
logger.exception("Erreur CPAM pour %s", subdir)

View File

@@ -11,6 +11,7 @@ from __future__ import annotations
import logging
import re
import time
from datetime import datetime
from typing import Optional
@@ -87,43 +88,71 @@ def extract_medical_info(
_extract_imagerie(anonymized_text, dossier)
_extract_complications(anonymized_text, dossier, edsnlp_result)
# Phase 4 : pass LLM pour détecter des DAS supplémentaires
# Phase 4 : DAS LLM + RAG DP + DP selector en parallèle
_dp_selection_needed = dossier.document_type != "trackare"
if use_rag:
_extract_das_llm(anonymized_text, dossier)
# Optimisation #1 : paralléliser enrichissement RAG et sélection DP
_dp_selection_needed = use_rag and dossier.document_type != "trackare"
if use_rag or _dp_selection_needed:
from concurrent.futures import ThreadPoolExecutor, as_completed
def _task_enrich():
if use_rag:
_enrich_with_rag(dossier)
# --- Groupe 1 : 3 tâches indépendantes en parallèle ---
# - DAS LLM : détecte des DAS supplémentaires (ne dépend pas du RAG DP)
# - RAG DP : enrichit seulement le DP (ne dépend pas du DAS LLM)
# - DP selector : sélectionne le DP optimal (indépendant des deux autres)
def _task_das_llm():
t0 = time.monotonic()
_extract_das_llm(anonymized_text, dossier)
elapsed = time.monotonic() - t0
logger.info("⏱ [DAS-LLM-TASK] %.1fs — extraction DAS LLM terminée", elapsed)
def _task_rag_dp():
t0 = time.monotonic()
_enrich_dp_only(dossier)
elapsed = time.monotonic() - t0
logger.info("⏱ [RAG-DP-TASK] %.1fs — enrichissement DP terminé", elapsed)
def _task_select_dp():
if not _dp_selection_needed:
return None
t0 = time.monotonic()
from .dp_selector import select_dp, build_synthese
synthese = build_synthese(dossier, parsed_data)
return select_dp(dossier, synthese, config={"llm_enabled": use_rag})
result = select_dp(dossier, synthese, config={"llm_enabled": use_rag})
elapsed = time.monotonic() - t0
logger.info("⏱ [DP-SELECT] %.1fs — sélection DP terminée", elapsed)
return result
dp_selection_result = None
with ThreadPoolExecutor(max_workers=2) as pool:
fut_enrich = pool.submit(_task_enrich)
t_group1 = time.monotonic()
with ThreadPoolExecutor(max_workers=3) as pool:
fut_das = pool.submit(_task_das_llm)
fut_rag_dp = pool.submit(_task_rag_dp)
fut_dp = pool.submit(_task_select_dp)
# Attendre les deux tâches
for fut in as_completed([fut_enrich, fut_dp]):
for fut in as_completed([fut_das, fut_rag_dp, fut_dp]):
exc = fut.exception()
if exc and fut is fut_dp:
logger.error("NUKE-3: erreur sélection DP", exc_info=exc)
dossier.quality_flags["dp_selection_status"] = "error"
dossier.alertes_codage.append("QUALITE DEGRADEE : sélection DP (NUKE-3) en erreur")
elif exc:
logger.error("RAG enrichissement échoué", exc_info=exc)
elif exc and fut is fut_rag_dp:
logger.error("RAG enrichissement DP échoué", exc_info=exc)
elif exc and fut is fut_das:
logger.error("DAS LLM extraction échouée", exc_info=exc)
if not fut_dp.exception():
dp_selection_result = fut_dp.result()
elapsed_group1 = time.monotonic() - t_group1
logger.info("⏱ [GROUPE-1] %.1fs — DAS_LLM + RAG_DP + DP_SELECT terminés", elapsed_group1)
# --- Groupe 2 : enrichir les DAS (existants + nouveaux du DAS LLM) + actes ---
t_group2 = time.monotonic()
_enrich_das_and_actes(dossier)
elapsed_group2 = time.monotonic() - t_group2
logger.info("⏱ [GROUPE-2] %.1fs — enrichissement DAS + actes terminé", elapsed_group2)
# Appliquer la sélection DP après parallélisation
if dp_selection_result is not None:
selection = dp_selection_result
@@ -151,12 +180,15 @@ def extract_medical_info(
dossier.alertes_codage.append(
f"NUKE-3 REVIEW: DP ambigu — {selection.reason}"
)
elif dossier.document_type != "trackare":
elif _dp_selection_needed:
# Fallback sans RAG : sélection DP seule
try:
t_dp_norag = time.monotonic()
from .dp_selector import select_dp, build_synthese
synthese = build_synthese(dossier, parsed_data)
selection = select_dp(dossier, synthese, config={"llm_enabled": False})
elapsed_dp = time.monotonic() - t_dp_norag
logger.info("⏱ [DP-SELECT] %.1fs — sélection DP (sans RAG)", elapsed_dp)
dossier.dp_selection = selection
if selection.chosen_code:
current_code = (
@@ -240,7 +272,10 @@ def extract_medical_info(
# Post-processing : validation justifications (QC batch)
if use_rag:
t_qc = time.monotonic()
_validate_justifications(dossier)
elapsed_qc = time.monotonic() - t_qc
logger.info("⏱ [QC-VALIDATION] %.1fs — validation justifications terminée", elapsed_qc)
# Post-processing : traçabilité source (page + extrait)
if page_tracker:
@@ -457,7 +492,7 @@ def _extract_das_llm(text: str, dossier: DossierMedical) -> None:
def _enrich_with_rag(dossier: DossierMedical) -> None:
"""Enrichit les diagnostics via le RAG (FAISS + Ollama)."""
"""Enrichit les diagnostics via le RAG (FAISS + Ollama) — wrapper rétro-compatible."""
try:
from .rag_search import enrich_dossier
enrich_dossier(dossier)
@@ -471,6 +506,34 @@ def _enrich_with_rag(dossier: DossierMedical) -> None:
dossier.alertes_codage.append("QUALITE DEGRADEE : erreur RAG — codage sans référentiels")
def _enrich_dp_only(dossier: DossierMedical) -> None:
    """Enrich ONLY the DP (main diagnosis) via the RAG (phase 1, parallelisable).

    Deliberately degrades gracefully: a missing RAG stack or a runtime
    failure is logged and flagged on the dossier, never propagated, so the
    rest of the pipeline keeps running without référentiels.
    """
    try:
        from .rag_search import enrich_dp
        enrich_dp(dossier)
    except ImportError:
        # RAG dependencies not installed — continue in degraded mode.
        logger.error("RAG INDISPONIBLE : faiss-cpu ou sentence-transformers manquant")
        dossier.quality_flags["rag_status"] = "unavailable"
        dossier.alertes_codage.append("QUALITE DEGRADEE : RAG indisponible — codage sans référentiels")
    except Exception:
        # Any other RAG failure: flag quality and continue.
        logger.error("RAG EN ERREUR : enrichissement DP échoué", exc_info=True)
        dossier.quality_flags["rag_status"] = "error"
        dossier.alertes_codage.append("QUALITE DEGRADEE : erreur RAG DP — codage sans référentiels")
def _enrich_das_and_actes(dossier: DossierMedical) -> None:
    """Enrich the DAS and CCAM acts via the RAG (phase 2, after DAS LLM).

    Must run after the DP enrichment and after LLM-detected DAS are added.
    Like the DP variant, failures are flagged but never propagated; note
    that unlike _enrich_dp_only no alertes_codage entry is appended here.
    """
    try:
        from .rag_search import enrich_das_and_actes
        enrich_das_and_actes(dossier)
    except ImportError:
        # RAG dependencies not installed — continue in degraded mode.
        logger.error("RAG INDISPONIBLE : faiss-cpu ou sentence-transformers manquant")
        dossier.quality_flags["rag_status"] = "unavailable"
    except Exception:
        logger.error("RAG EN ERREUR : enrichissement DAS/actes échoué", exc_info=True)
        dossier.quality_flags["rag_status"] = "error"
def _extract_sejour(parsed: dict, dossier: DossierMedical) -> None:
"""Extrait les informations de séjour."""
patient = parsed.get("patient", {})

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import logging
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from ..config import (
@@ -610,7 +611,9 @@ def enrich_diagnostic(
Modifie le diagnostic en place. Fallback gracieux si FAISS ou Ollama échouent.
"""
t0 = time.monotonic()
diag_type = "dp" if est_dp else "das"
label = f"RAG-{diag_type.upper()}"
# 1. Vérifier le cache
cached = cache.get(diagnostic.texte, diag_type) if cache else None
@@ -626,6 +629,8 @@ def enrich_diagnostic(
if cached is not None:
logger.info("Cache hit (sans sources FAISS) pour %s : « %s »", diag_type.upper(), diagnostic.texte)
_apply_llm_result_diagnostic(diagnostic, cached)
elapsed = time.monotonic() - t0
logger.info("⏱ [%s] %.1fs — %s (no FAISS)", label, elapsed, diagnostic.texte[:60])
return
# 3. Stocker les sources RAG
@@ -643,11 +648,15 @@ def enrich_diagnostic(
if cached is not None:
logger.info("Cache hit pour %s : « %s »", diag_type.upper(), diagnostic.texte)
_apply_llm_result_diagnostic(diagnostic, cached)
elapsed = time.monotonic() - t0
logger.info("⏱ [%s] %.1fs — %s (cache hit)", label, elapsed, diagnostic.texte[:60])
return
# 5. Appel Ollama pour justification avec raisonnement structuré
prompt = _build_prompt(diagnostic.texte, sources, contexte, est_dp=est_dp)
t_llm = time.monotonic()
llm_result = _call_ollama(prompt)
llm_elapsed = time.monotonic() - t_llm
if llm_result:
_apply_llm_result_diagnostic(diagnostic, llm_result)
@@ -656,6 +665,9 @@ def enrich_diagnostic(
else:
logger.info("Ollama non disponible — sources FAISS conservées sans justification LLM")
elapsed = time.monotonic() - t0
logger.info("⏱ [%s] %.1fs (LLM %.1fs) — %s", label, elapsed, llm_elapsed, diagnostic.texte[:60])
def _apply_llm_result_acte(acte: ActeCCAM, llm_result: dict) -> None:
"""Applique un résultat LLM (frais ou caché) à un ActeCCAM."""
@@ -687,6 +699,8 @@ def enrich_acte(acte: ActeCCAM, contexte: dict, cache: OllamaCache | None = None
Modifie l'acte en place. Fallback gracieux si FAISS ou Ollama échouent.
"""
t0 = time.monotonic()
# 1. Vérifier le cache
cached = cache.get(acte.texte, "ccam") if cache else None
@@ -712,11 +726,15 @@ def enrich_acte(acte: ActeCCAM, contexte: dict, cache: OllamaCache | None = None
if cached is not None:
logger.info("Cache hit pour CCAM : « %s »", acte.texte)
_apply_llm_result_acte(acte, cached)
elapsed = time.monotonic() - t0
logger.info("⏱ [RAG-CCAM] %.1fs — %s (cache hit)", elapsed, acte.texte[:60])
return
# 5. Appel Ollama pour justification avec raisonnement structuré
prompt = _build_prompt_ccam(acte.texte, sources, contexte)
t_llm = time.monotonic()
llm_result = _call_ollama(prompt)
llm_elapsed = time.monotonic() - t_llm
if llm_result:
_apply_llm_result_acte(acte, llm_result)
@@ -725,6 +743,9 @@ def enrich_acte(acte: ActeCCAM, contexte: dict, cache: OllamaCache | None = None
else:
logger.info("Ollama non disponible — sources FAISS CCAM conservées sans justification LLM")
elapsed = time.monotonic() - t0
logger.info("⏱ [RAG-CCAM] %.1fs (LLM %.1fs) — %s", elapsed, llm_elapsed, acte.texte[:60])
def _smart_truncate(text: str, max_chars: int = 6000) -> str:
"""Troncature intelligente : garde le début + les sections finales importantes.
@@ -790,6 +811,8 @@ def extract_das_llm(
"""
import hashlib
t0 = time.monotonic()
# Clé de cache basée sur le hash du texte
text_hash = hashlib.md5(text[:4000].encode()).hexdigest()[:16]
cache_key_text = f"das_extract::{text_hash}"
@@ -798,15 +821,19 @@ def extract_das_llm(
if cache is not None:
cached = cache.get(cache_key_text, "das_llm")
if cached is not None:
logger.info("Cache hit pour extraction DAS LLM")
elapsed = time.monotonic() - t0
logger.info("⏱ [DAS-LLM] %.1fs — cache hit", elapsed)
return cached.get("diagnostics_supplementaires", [])
# Construire le prompt et appeler Ollama
prompt = _build_prompt_das_extraction(text, contexte, existing_das, dp_texte)
t_llm = time.monotonic()
result = call_ollama(prompt, temperature=0.1, max_tokens=2000, role="coding")
llm_elapsed = time.monotonic() - t_llm
if result is None:
logger.warning("Extraction DAS LLM : Ollama non disponible")
elapsed = time.monotonic() - t0
logger.warning("⏱ [DAS-LLM] %.1fs — Ollama non disponible", elapsed)
return []
das_list = result.get("diagnostics_supplementaires", [])
@@ -818,25 +845,52 @@ def extract_das_llm(
if cache is not None:
cache.put(cache_key_text, "das_llm", result)
logger.info("Extraction DAS LLM : %d diagnostics supplémentaires détectés", len(das_list))
elapsed = time.monotonic() - t0
logger.info("⏱ [DAS-LLM] %.1fs (LLM %.1fs) — %d diagnostics détectés", elapsed, llm_elapsed, len(das_list))
return das_list
def enrich_dp(dossier: DossierMedical, cache: OllamaCache | None = None) -> None:
    """Enrich ONLY the DP of a dossier via the RAG (phase 1).

    Can run concurrently with other tasks (DAS LLM, DP selector) because it
    only depends on the already-extracted DP.

    Args:
        dossier: Medical record to enrich (modified in place).
        cache: Optional Ollama cache (created if None).
    """
    # FIX: the definition was corrupted by leftover lines of the removed
    # enrich_dossier docstring interleaved into it; this is the clean
    # reconstruction of the added function.
    t0 = time.monotonic()
    if cache is None:
        cache = OllamaCache(OLLAMA_CACHE_PATH, get_model("coding"))
    contexte = build_enriched_context(dossier)
    # Phase 1: DP alone (the DAS context depends on it)
    if dossier.diagnostic_principal:
        logger.info("RAG enrichissement DP : %s", dossier.diagnostic_principal.texte)
        enrich_diagnostic(dossier.diagnostic_principal, contexte, est_dp=True, cache=cache)
    cache.save()
    elapsed = time.monotonic() - t0
    logger.info("⏱ [RAG-DP-PHASE] %.1fs — enrichissement DP terminé", elapsed)
def enrich_das_and_actes(dossier: DossierMedical, cache: OllamaCache | None = None) -> None:
"""Enrichit les DAS et actes CCAM d'un dossier via le RAG (Phase 2).
Parallélise les appels Ollama (max_workers = OLLAMA_MAX_PARALLEL).
Doit être appelé APRES enrich_dp() et après l'ajout des DAS LLM.
Args:
dossier: Dossier médical à enrichir (modifié en place).
cache: Cache Ollama optionnel (créé si None).
"""
t0 = time.monotonic()
if cache is None:
cache = OllamaCache(OLLAMA_CACHE_PATH, get_model("coding"))
contexte = build_enriched_context(dossier)
# Mettre à jour le contexte avec le DP pour les DAS
if dossier.diagnostic_principal:
contexte["dp_texte"] = dossier.diagnostic_principal.texte
@@ -846,7 +900,6 @@ def enrich_dossier(dossier: DossierMedical) -> None:
if d.cim10_suggestion
]
# Phase 2 : DAS + Actes en parallèle
das_list = dossier.diagnostics_associes
actes_list = dossier.actes_ccam
@@ -863,3 +916,18 @@ def enrich_dossier(dossier: DossierMedical) -> None:
f.result() # propage les exceptions
cache.save()
elapsed = time.monotonic() - t0
n_das = len(das_list) if das_list else 0
n_actes = len(actes_list) if actes_list else 0
logger.info("⏱ [RAG-DAS-PHASE] %.1fs — %d DAS + %d actes enrichis", elapsed, n_das, n_actes)
def enrich_dossier(dossier: DossierMedical) -> None:
    """Enrich the DP and every DAS of a dossier via the RAG.

    Backward-compatible wrapper: runs enrich_dp() then
    enrich_das_and_actes(), sharing a single persistent cache
    between the two phases.
    """
    shared_cache = OllamaCache(OLLAMA_CACHE_PATH, get_model("coding"))
    enrich_dp(dossier, cache=shared_cache)
    enrich_das_and_actes(dossier, cache=shared_cache)

View File

@@ -116,7 +116,7 @@ class TestExtractMedicalInfoRAGFlag:
assert dossier.diagnostic_principal.justification is None
def test_use_rag_true_calls_enrich(self):
"""use_rag=True appelle _enrich_with_rag (mocké)."""
"""use_rag=True appelle _enrich_dp_only et _enrich_das_and_actes (mockés)."""
from src.medical.cim10_extractor import extract_medical_info
parsed = {
@@ -127,9 +127,13 @@ class TestExtractMedicalInfoRAGFlag:
}
text = "Pancréatite aiguë biliaire.\nTTT de sortie :\nParacétamol\n\nDevenir : retour."
with patch("src.medical.cim10_extractor._enrich_with_rag") as mock_enrich:
with patch("src.medical.cim10_extractor._enrich_dp_only") as mock_dp, \
patch("src.medical.cim10_extractor._enrich_das_and_actes") as mock_das, \
patch("src.medical.cim10_extractor._extract_das_llm"), \
patch("src.medical.cim10_extractor._validate_justifications"):
dossier = extract_medical_info(parsed, text, use_rag=True)
mock_enrich.assert_called_once_with(dossier)
mock_dp.assert_called_once_with(dossier)
mock_das.assert_called_once_with(dossier)
def test_use_rag_default_false(self):
"""Par défaut, use_rag=False."""
@@ -143,9 +147,11 @@ class TestExtractMedicalInfoRAGFlag:
}
text = "Test simple."
with patch("src.medical.cim10_extractor._enrich_with_rag") as mock_enrich:
with patch("src.medical.cim10_extractor._enrich_dp_only") as mock_dp, \
patch("src.medical.cim10_extractor._enrich_das_and_actes") as mock_das:
extract_medical_info(parsed, text)
mock_enrich.assert_not_called()
mock_dp.assert_not_called()
mock_das.assert_not_called()
class TestChunkingCIM10: