Merge branch 'feat/parallel-rag-pipeline' — optimisation flux LLM

This commit is contained in:
dom
2026-03-08 14:35:54 +01:00
7 changed files with 577 additions and 49 deletions

View File

@@ -60,7 +60,7 @@ OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://localhost:11434")
OLLAMA_MODEL = os.environ.get("OLLAMA_MODEL", "gemma3:27b") OLLAMA_MODEL = os.environ.get("OLLAMA_MODEL", "gemma3:27b")
OLLAMA_TIMEOUT = int(os.environ.get("OLLAMA_TIMEOUT", "600")) OLLAMA_TIMEOUT = int(os.environ.get("OLLAMA_TIMEOUT", "600"))
OLLAMA_CACHE_PATH = BASE_DIR / "data" / "ollama_cache.json" OLLAMA_CACHE_PATH = BASE_DIR / "data" / "ollama_cache.json"
OLLAMA_MAX_PARALLEL = int(os.environ.get("OLLAMA_MAX_PARALLEL", "2")) OLLAMA_MAX_PARALLEL = int(os.environ.get("OLLAMA_MAX_PARALLEL", "4"))
# --- Modèles par rôle LLM --- # --- Modèles par rôle LLM ---

View File

@@ -4,6 +4,14 @@ Orchestrateur principal — délègue aux sous-modules :
- cpam_rag : _search_rag_for_control(), _search_rag_queries() - cpam_rag : _search_rag_for_control(), _search_rag_queries()
- cpam_context : _build_cpam_prompt(), _build_tagged_context(), _build_bio_summary(), etc. - cpam_context : _build_cpam_prompt(), _build_tagged_context(), _build_bio_summary(), etc.
- cpam_validation : _validate_adversarial(), _validate_grounding(), _format_response(), etc. - cpam_validation : _validate_adversarial(), _validate_grounding(), _format_response(), etc.
Le flux batché (generate_cpam_responses_batched) regroupe les appels LLM par modèle
pour minimiser les swaps VRAM sur GPU unique :
Phase 1 (cpam model) : extraction + argumentation de TOUS les contrôles
Phase 2 (validation model, 1 swap) : validation adversariale de tous
Phase 3 (cpam model, 1 swap) : correction des contrôles échoués
Phase 4 (validation model, 1 swap) : re-validation des corrigés
Gain : de ~4N swaps à ~2-4 swaps (indépendant du nombre de contrôles).
""" """
from __future__ import annotations from __future__ import annotations
@@ -11,6 +19,8 @@ from __future__ import annotations
import json import json
import logging import logging
import os import os
import time
from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
@@ -41,6 +51,34 @@ from .cpam_validation import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Structure intermédiaire pour le batching
# ---------------------------------------------------------------------------
@dataclass
class _CpamDraft:
    """Intermediate state of one CPAM control moving through the batched pipeline.

    One draft is created per ControleCPAM; each phase function reads the
    fields filled by the previous phase and writes its own.
    """
    controle: ControleCPAM
    # Results of the successive phases
    extraction: dict | None = None          # pass-1 structured extraction (None = pass 1 failed)
    degraded_pass1: bool = False            # True when extraction failed → degraded mode
    sources: list[dict] = field(default_factory=list)       # raw RAG hits for the control
    rag_sources: list[RAGSource] = field(default_factory=list)  # converted, trimmed RAG sources
    prompt: str = ""                        # full argumentation prompt (reused by correction)
    tag_map: dict[str, str] = field(default_factory=dict)   # evidence tags for grounding checks
    result: dict | None = None              # LLM argumentation result (None = generation failed)
    # Validation
    ref_warnings: list[str] = field(default_factory=list)       # unverifiable references
    grounding_warnings: list[str] = field(default_factory=list) # untraceable evidence
    code_warnings: list[str] = field(default_factory=list)      # out-of-scope codes
    adversarial_warnings: list[str] = field(default_factory=list)  # adversarial incoherences
    validation: dict | None = None          # adversarial validation verdict
    # Final outcome
    needs_correction: bool = False          # set by phase 2, consumed by phase 3
    text: str = ""
    response_data: dict | None = None
def _save_version( def _save_version(
dossier: DossierMedical, dossier: DossierMedical,
controle: ControleCPAM, controle: ControleCPAM,
@@ -149,14 +187,18 @@ def _extraction_pass(
) )
logger.debug(" Passe 1 — extraction structurée") logger.debug(" Passe 1 — extraction structurée")
t0 = time.time()
result = call_ollama(prompt, temperature=0.0, max_tokens=3000, role="cpam") result = call_ollama(prompt, temperature=0.0, max_tokens=3000, role="cpam")
if result is None: if result is None:
result = call_anthropic(prompt, temperature=0.0, max_tokens=3000) result = call_anthropic(prompt, temperature=0.0, max_tokens=3000)
elapsed = time.time() - t0
if result is not None: if result is not None:
logger.info(" Passe 1 OK : %d éléments cliniques extraits", logger.info(" [CPAM-EXTRACT] %.1fs — OGC %s %d éléments cliniques extraits",
elapsed, controle.numero_ogc,
len(result.get("elements_cliniques_pertinents", []))) len(result.get("elements_cliniques_pertinents", [])))
else: else:
logger.warning(" Passe 1 échouée — fallback single-pass") logger.warning(" [CPAM-EXTRACT] %.1fs — OGC %s — passe 1 échouée",
elapsed, controle.numero_ogc)
return result return result
@@ -195,14 +237,17 @@ def generate_cpam_response(
prompt, tag_map = _build_cpam_prompt(dossier, controle, sources, extraction) prompt, tag_map = _build_cpam_prompt(dossier, controle, sources, extraction)
# 4. Appel LLM — Ollama (rôle cpam) > Haiku fallback # 4. Appel LLM — Ollama (rôle cpam) > Haiku fallback
t_gen = time.time()
result = call_ollama(prompt, temperature=0.1, max_tokens=8000, role="cpam") result = call_ollama(prompt, temperature=0.1, max_tokens=8000, role="cpam")
if result is not None: if result is not None:
logger.info(" Contre-argumentation via Ollama") logger.info(" [CPAM-GEN] %.1fs — OGC %s — contre-argumentation via Ollama",
time.time() - t_gen, controle.numero_ogc)
else: else:
logger.info(" Ollama indisponible → fallback Anthropic Haiku") logger.info(" Ollama indisponible → fallback Anthropic Haiku")
result = call_anthropic(prompt, temperature=0.1, max_tokens=8000) result = call_anthropic(prompt, temperature=0.1, max_tokens=8000)
if result is not None: if result is not None:
logger.info(" Contre-argumentation via Anthropic Haiku") logger.info(" [CPAM-GEN] %.1fs — OGC %s — contre-argumentation via Anthropic",
time.time() - t_gen, controle.numero_ogc)
# 5. Conversion des sources RAG # 5. Conversion des sources RAG
rag_sources = [ rag_sources = [
@@ -285,9 +330,12 @@ def generate_cpam_response(
validation.get("score_confiance"), attempt + 1, max_corrections, len(erreurs_v)) validation.get("score_confiance"), attempt + 1, max_corrections, len(erreurs_v))
correction_prompt = _build_correction_prompt(prompt, result, validation) correction_prompt = _build_correction_prompt(prompt, result, validation)
t_corr = time.time()
corrected = call_ollama(correction_prompt, temperature=0.0, max_tokens=16000, role="cpam") corrected = call_ollama(correction_prompt, temperature=0.0, max_tokens=16000, role="cpam")
if corrected is None: if corrected is None:
corrected = call_anthropic(correction_prompt, temperature=0.0, max_tokens=16000) corrected = call_anthropic(correction_prompt, temperature=0.0, max_tokens=16000)
logger.info(" [CPAM-CORR] %.1fs — OGC %s — correction %d/%d",
time.time() - t_corr, controle.numero_ogc, attempt + 1, max_corrections)
if not corrected: if not corrected:
break break
@@ -347,3 +395,343 @@ def generate_cpam_response(
logger.info(" Contre-argumentation générée (%d caractères)", len(text)) logger.info(" Contre-argumentation générée (%d caractères)", len(text))
return text, result, rag_sources return text, result, rag_sources
# ---------------------------------------------------------------------------
# Fonctions de phase (utilisées par le batching)
# ---------------------------------------------------------------------------
def _phase_generate(
    dossier: DossierMedical,
    draft: _CpamDraft,
) -> None:
    """Phase 1 — structured extraction + RAG search + argumentation (role=cpam).

    Fills draft.extraction, draft.sources, draft.prompt, draft.tag_map,
    draft.rag_sources and draft.result.  Makes no call to the validation
    model, so all drafts can run this phase back-to-back on the cpam model
    without a VRAM swap.
    """
    controle = draft.controle
    # BUGFIX: format was "OGC %d%s" — number and title ran together with no
    # separator, and numero_ogc is rendered with %s everywhere else in this
    # module (it may be a string), so %d could raise TypeError at log time.
    logger.info("CPAM batch : génération pour OGC %s — %s",
                controle.numero_ogc, controle.titre)

    # 0. Versioning — snapshot any previous response before regenerating
    _save_version(dossier, controle)

    # 1. Structured extraction (role=cpam); a failure degrades gracefully
    draft.extraction = _extraction_pass(dossier, controle)
    draft.degraded_pass1 = draft.extraction is None
    if draft.degraded_pass1:
        dossier.alertes_codage.append(
            "CPAM: passe 1 (extraction structurée) échouée → mode dégradé"
        )

    # 2. RAG search (deterministic, no LLM call)
    draft.sources = _search_rag_for_control(controle, dossier)
    logger.info(" RAG : %d sources trouvées", len(draft.sources))

    # 3. Build the argumentation prompt (also reused by the correction phase)
    draft.prompt, draft.tag_map = _build_cpam_prompt(
        dossier, controle, draft.sources, draft.extraction
    )

    # 4. Argumentation (role=cpam), Anthropic Haiku as fallback
    t_gen = time.time()
    result = call_ollama(draft.prompt, temperature=0.1, max_tokens=8000, role="cpam")
    if result is not None:
        logger.info(" [CPAM-GEN] %.1fs — OGC %s — contre-argumentation via Ollama",
                    time.time() - t_gen, controle.numero_ogc)
    else:
        logger.info(" Ollama indisponible → fallback Anthropic Haiku")
        result = call_anthropic(draft.prompt, temperature=0.1, max_tokens=8000)
        if result is not None:
            logger.info(" [CPAM-GEN] %.1fs — OGC %s — contre-argumentation via Anthropic",
                        time.time() - t_gen, controle.numero_ogc)

    # 5. Convert raw RAG hits into the public RAGSource structure
    draft.rag_sources = [
        RAGSource(
            document=s.get("document", ""),
            page=s.get("page"),
            code=s.get("code"),
            extrait=s.get("extrait", "")[:200],
        )
        for s in draft.sources
    ]

    if result is None:
        logger.warning(" LLM non disponible — contre-argumentation non générée")
        draft.result = None
        return

    # 5b. Flag degraded mode so downstream quality tiers can account for it
    if draft.degraded_pass1:
        result.setdefault("quality_flags", {})
        result["quality_flags"]["cpam_pass1_failed"] = True
        result["quality_flags"]["degraded_mode"] = True

    # 6. Deterministic sanitisation + guardian (no LLM)
    sanitized = _sanitize_unauthorized_codes(result, dossier, controle)
    if sanitized:
        logger.info(" CPAM : %d code(s) hors périmètre supprimé(s) : %s",
                    len(sanitized), ", ".join(sanitized))
    result = _guardian_deterministic(result, dossier, controle, draft.tag_map)

    # 7. Deterministic validations (no LLM)
    draft.ref_warnings = _validate_references(result, draft.sources)
    if draft.ref_warnings:
        logger.warning(" CPAM : %d référence(s) non vérifiable(s)", len(draft.ref_warnings))
    draft.grounding_warnings = _validate_grounding(result, draft.tag_map)
    if draft.grounding_warnings:
        logger.warning(" CPAM : %d preuve(s) non traçable(s)", len(draft.grounding_warnings))
    draft.code_warnings = _validate_codes_in_response(result, dossier, controle)
    if draft.code_warnings:
        logger.warning(" CPAM : %d code(s) hors périmètre", len(draft.code_warnings))

    draft.result = result
def _phase_validate(
    dossier: DossierMedical,
    draft: _CpamDraft,
) -> None:
    """Phase 2 — adversarial validation (role=validation).

    Fills draft.validation and draft.needs_correction.  Drafts whose
    generation phase produced no result are skipped entirely.
    """
    if draft.result is None:
        return

    ctrl = draft.controle

    # LOGIC-3: flag the case where the cpam and validation roles resolve to
    # the same model (the adversarial check loses its independence).
    from ..config import check_adversarial_model_config
    identical, _msg = check_adversarial_model_config()
    if identical:
        flags = draft.result.setdefault("quality_flags", {})
        flags["adversarial_disabled_same_model"] = True
        dossier.alertes_codage.append(
            "Validation adversariale désactivée (modèles identiques)"
        )

    verdict = _validate_adversarial(draft.result, draft.tag_map, ctrl)
    draft.validation = verdict

    if verdict and not verdict.get("coherent", True):
        confidence = verdict.get("score_confiance", "?")
        draft.adversarial_warnings.extend(
            f"Incohérence détectée : {err}"
            for err in verdict.get("erreurs", [])
            if isinstance(err, str) and err.strip()
        )
        if draft.adversarial_warnings:
            draft.adversarial_warnings.append(f"Score de confiance : {confidence}/10")

    # Decide whether the correction loop (phase 3) must run for this draft.
    correction_budget = int(os.environ.get("T2A_CPAM_MAX_CORRECTIONS", "2"))
    draft.needs_correction = bool(
        correction_budget > 0
        and verdict
        and not verdict.get("coherent", True)
        and verdict.get("score_confiance", 10) <= 5
        and rule_enabled("RULE-CPAM-CORRECTION-LOOP")
    )
def _phase_correct(
    dossier: DossierMedical,
    draft: _CpamDraft,
) -> None:
    """Phase 3 — correction (role=cpam) + re-validation (role=validation).

    The correction/re-validation loop is kept together because the two are
    coupled: each corrected draft must be re-scored before deciding whether
    to keep it or attempt again.  Only drafts with needs_correction=True
    enter this phase.
    """
    if draft.result is None or not draft.needs_correction:
        return
    controle = draft.controle
    max_corrections = int(os.environ.get("T2A_CPAM_MAX_CORRECTIONS", "2"))

    for attempt in range(max_corrections):
        # Stop as soon as the current validation no longer warrants a retry.
        if not (draft.validation
                and not draft.validation.get("coherent", True)
                and draft.validation.get("score_confiance", 10) <= 5
                and rule_enabled("RULE-CPAM-CORRECTION-LOOP")):
            break

        erreurs_v = draft.validation.get("erreurs", [])
        logger.warning(" Score adversarial %s/10 — correction %d/%d (%d erreur(s))",
                       draft.validation.get("score_confiance"),
                       attempt + 1, max_corrections, len(erreurs_v))

        correction_prompt = _build_correction_prompt(
            draft.prompt, draft.result, draft.validation
        )
        t_corr = time.time()
        corrected = call_ollama(
            correction_prompt, temperature=0.0, max_tokens=16000, role="cpam"
        )
        if corrected is None:
            corrected = call_anthropic(
                correction_prompt, temperature=0.0, max_tokens=16000
            )
        logger.info(" [CPAM-CORR] %.1fs — OGC %s — correction %d/%d",
                    time.time() - t_corr, controle.numero_ogc,
                    attempt + 1, max_corrections)
        if not corrected:
            break

        # Re-validate the corrected draft and keep it only if the score improves.
        validation2 = _validate_adversarial(corrected, draft.tag_map, controle)
        score2 = validation2.get("score_confiance", 0) if validation2 else 0
        score1 = draft.validation.get("score_confiance", 0)
        if score2 > score1:
            # BUGFIX: format was "(score %s%s)" — the two scores ran together
            # in the log ("34" for 3 vs 4); add an explicit arrow separator.
            logger.info(" Correction %d acceptée (score %s → %s)",
                        attempt + 1, score1, score2)
            draft.result = corrected
            draft.validation = validation2
            # Re-run the deterministic pipeline on the corrected result.
            _sanitize_unauthorized_codes(draft.result, dossier, controle)
            draft.result = _guardian_deterministic(
                draft.result, dossier, controle, draft.tag_map
            )
            draft.ref_warnings = _validate_references(draft.result, draft.sources)
            draft.grounding_warnings = _validate_grounding(draft.result, draft.tag_map)
            draft.code_warnings = _validate_codes_in_response(
                draft.result, dossier, controle
            )
            draft.adversarial_warnings = []
            if draft.validation and not draft.validation.get("coherent", True):
                for e in draft.validation.get("erreurs", []):
                    if isinstance(e, str) and e.strip():
                        draft.adversarial_warnings.append(f"Incohérence détectée : {e}")
                if draft.adversarial_warnings:
                    draft.adversarial_warnings.append(
                        f"Score de confiance : {draft.validation.get('score_confiance', '?')}/10"
                    )
        else:
            # BUGFIX: same "%s%s" run-together format fixed here as well.
            logger.warning(" Correction %d rejetée (score %s → %s)",
                           attempt + 1, score1, score2)
            break
def _phase_finalize(
    dossier: DossierMedical,
    draft: _CpamDraft,
) -> tuple[str, dict | None, list[RAGSource]]:
    """Final phase — quality assessment + formatting (deterministic, no LLM).

    Returns:
        The same triple as generate_cpam_response(): (text, result, rag_sources).
    """
    ctrl = draft.controle
    result = draft.result
    if result is None:
        return "", None, draft.rag_sources

    combined_warnings = (
        draft.ref_warnings
        + draft.grounding_warnings
        + draft.code_warnings
        + draft.adversarial_warnings
    )

    strength = _assess_dossier_strength(dossier)
    if strength["is_weak"]:
        logger.info(" Dossier à preuves limitées (score %d/10) : %s",
                    strength["score"], ", ".join(strength["missing"]))

    tier, needs_review, cat_warnings = _assess_quality_tier(
        result, draft.ref_warnings, draft.grounding_warnings,
        draft.code_warnings, draft.validation,
        is_weak_dossier=strength["is_weak"],
    )
    # Persist the quality verdict on the control itself.
    ctrl.quality_tier = tier
    ctrl.requires_review = needs_review
    ctrl.quality_warnings = cat_warnings
    logger.info(" Qualité CPAM : tier %s, requires_review=%s, %d warnings",
                tier, needs_review, len(cat_warnings))

    text = _format_response(
        result,
        ref_warnings=combined_warnings,
        quality_tier=tier,
        categorized_warnings=cat_warnings,
    )
    logger.info(" Contre-argumentation générée (%d caractères)", len(text))
    return text, result, draft.rag_sources
# ---------------------------------------------------------------------------
# Orchestrateur batché — minimise les swaps de modèle VRAM
# ---------------------------------------------------------------------------
def generate_cpam_responses_batched(
    dossier: DossierMedical,
    controles: list[ControleCPAM],
) -> None:
    """Generate counter-argumentations for ALL CPAM controls in one batch.

    Groups LLM calls by model to minimise VRAM swaps on a single GPU:
      Phase 1 (cpam model)       : extraction + argumentation for every control
      Phase 2 (validation model) : adversarial validation of every control
      Phase 3 (cpam model)       : correction of failed controls, if any
      Finalisation               : quality + formatting (deterministic)
    Swap count drops from ~4*N to 2-4, independent of N.

    Args:
        dossier: The medical record (merged or single).
        controles: The CPAM controls to process.

    Side effects:
        Fills controle.contre_argumentation, controle.response_data and
        controle.sources_reponse on every control.
    """
    if not controles:
        return

    start = time.time()
    total = len(controles)
    logger.info("CPAM batch : %d contrôle(s) — début du traitement batché", total)

    drafts = [_CpamDraft(controle=c) for c in controles]

    # Phase 1 — generation on the cpam model, all drafts in a row.
    phase_start = time.time()
    logger.info("CPAM batch phase 1/3 : génération (%d contrôles, role=cpam)", total)
    for d in drafts:
        _phase_generate(dossier, d)
    logger.info("CPAM batch phase 1/3 terminée : %.1fs", time.time() - phase_start)

    # Phase 2 — adversarial validation (one swap to the validation model).
    validatable = [d for d in drafts if d.result is not None]
    if validatable:
        phase_start = time.time()
        logger.info("CPAM batch phase 2/3 : validation (%d contrôles, role=validation)",
                    len(validatable))
        for d in validatable:
            _phase_validate(dossier, d)
        logger.info("CPAM batch phase 2/3 terminée : %.1fs", time.time() - phase_start)

    # Phase 3 — correction + re-validation for the drafts that need it.
    correctable = [d for d in drafts if d.needs_correction]
    if correctable:
        phase_start = time.time()
        logger.info("CPAM batch phase 3/3 : correction (%d contrôles nécessitent correction)",
                    len(correctable))
        for d in correctable:
            _phase_correct(dossier, d)
        logger.info("CPAM batch phase 3/3 terminée : %.1fs", time.time() - phase_start)

    # Finalisation — deterministic, no model swap.
    for d in drafts:
        text, data, sources = _phase_finalize(dossier, d)
        d.controle.contre_argumentation = text
        d.controle.response_data = data
        d.controle.sources_reponse = sources

    total_elapsed = time.time() - start
    logger.info(
        "CPAM batch terminé : %d contrôle(s) en %.1fs (%.1fs/contrôle)",
        total, total_elapsed, total_elapsed / total if total else 0,
    )

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import logging import logging
import re import re
import time
from ..config import ControleCPAM, DossierMedical from ..config import ControleCPAM, DossierMedical
from ..medical.bio_normals import BIO_NORMALS from ..medical.bio_normals import BIO_NORMALS
@@ -477,11 +478,14 @@ def _validate_adversarial(
) )
logger.debug(" Validation adversariale") logger.debug(" Validation adversariale")
t_val = time.time()
result = call_ollama(prompt, temperature=0.0, max_tokens=6000, role="validation") result = call_ollama(prompt, temperature=0.0, max_tokens=6000, role="validation")
if result is None: if result is None:
result = call_anthropic(prompt, temperature=0.0, max_tokens=6000) result = call_anthropic(prompt, temperature=0.0, max_tokens=6000)
elapsed = time.time() - t_val
if result is None: if result is None:
logger.warning(" Validation adversariale échouée — LLM indisponible") logger.warning(" [CPAM-VALID] %.1fs — OGC %s — validation adversariale échouée",
elapsed, controle.numero_ogc)
return None return None
coherent = result.get("coherent", True) coherent = result.get("coherent", True)
@@ -489,12 +493,13 @@ def _validate_adversarial(
score = result.get("score_confiance", -1) score = result.get("score_confiance", -1)
if not coherent and erreurs: if not coherent and erreurs:
logger.warning(" Validation adversariale : %d incohérence(s) détectée(s) (score %s/10)", logger.warning(" [CPAM-VALID] %.1fs — OGC %s %d incohérence(s) (score %s/10)",
len(erreurs), score) elapsed, controle.numero_ogc, len(erreurs), score)
for e in erreurs: for e in erreurs:
logger.warning(" - %s", e) logger.warning(" - %s", e)
else: else:
logger.info(" Validation adversariale OK (score %s/10)", score) logger.info(" [CPAM-VALID] %.1fs — OGC %s OK (score %s/10)",
elapsed, controle.numero_ogc, score)
return result return result

View File

@@ -610,20 +610,18 @@ def main(input_path: str | None = None) -> None:
merged = None merged = None
# Contrôle CPAM : enrichir le dossier principal (fusionné ou dernier) # Contrôle CPAM : enrichir le dossier principal (fusionné ou dernier)
# Utilise le mode batché pour regrouper les appels LLM par modèle
# et minimiser les swaps VRAM (de ~4N swaps à ~2-4 swaps)
if cpam_data and subdir: if cpam_data and subdir:
try: try:
from .control.cpam_parser import match_dossier_ogc from .control.cpam_parser import match_dossier_ogc
controles = match_dossier_ogc(subdir, cpam_data) controles = match_dossier_ogc(subdir, cpam_data)
if controles: if controles:
from .control.cpam_response import generate_cpam_response from .control.cpam_response import generate_cpam_responses_batched
target = merged if merged else (group_dossiers[-1] if group_dossiers else None) target = merged if merged else (group_dossiers[-1] if group_dossiers else None)
if target: if target:
logger.info(" CPAM : %d contrôle(s) pour %s", len(controles), subdir) logger.info(" CPAM : %d contrôle(s) pour %s", len(controles), subdir)
for ctrl in controles: generate_cpam_responses_batched(target, controles)
text, response_data, sources = generate_cpam_response(target, ctrl)
ctrl.contre_argumentation = text
ctrl.response_data = response_data
ctrl.sources_reponse = sources
target.controles_cpam = controles target.controles_cpam = controles
except Exception: except Exception:
logger.exception("Erreur CPAM pour %s", subdir) logger.exception("Erreur CPAM pour %s", subdir)

View File

@@ -11,6 +11,7 @@ from __future__ import annotations
import logging import logging
import re import re
import time
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
@@ -87,43 +88,71 @@ def extract_medical_info(
_extract_imagerie(anonymized_text, dossier) _extract_imagerie(anonymized_text, dossier)
_extract_complications(anonymized_text, dossier, edsnlp_result) _extract_complications(anonymized_text, dossier, edsnlp_result)
# Phase 4 : pass LLM pour détecter des DAS supplémentaires # Phase 4 : DAS LLM + RAG DP + DP selector en parallèle
_dp_selection_needed = dossier.document_type != "trackare"
if use_rag: if use_rag:
_extract_das_llm(anonymized_text, dossier)
# Optimisation #1 : paralléliser enrichissement RAG et sélection DP
_dp_selection_needed = use_rag and dossier.document_type != "trackare"
if use_rag or _dp_selection_needed:
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
def _task_enrich(): # --- Groupe 1 : 3 tâches indépendantes en parallèle ---
if use_rag: # - DAS LLM : détecte des DAS supplémentaires (ne dépend pas du RAG DP)
_enrich_with_rag(dossier) # - RAG DP : enrichit seulement le DP (ne dépend pas du DAS LLM)
# - DP selector : sélectionne le DP optimal (indépendant des deux autres)
def _task_das_llm():
t0 = time.monotonic()
_extract_das_llm(anonymized_text, dossier)
elapsed = time.monotonic() - t0
logger.info("⏱ [DAS-LLM-TASK] %.1fs — extraction DAS LLM terminée", elapsed)
def _task_rag_dp():
t0 = time.monotonic()
_enrich_dp_only(dossier)
elapsed = time.monotonic() - t0
logger.info("⏱ [RAG-DP-TASK] %.1fs — enrichissement DP terminé", elapsed)
def _task_select_dp(): def _task_select_dp():
if not _dp_selection_needed: if not _dp_selection_needed:
return None return None
t0 = time.monotonic()
from .dp_selector import select_dp, build_synthese from .dp_selector import select_dp, build_synthese
synthese = build_synthese(dossier, parsed_data) synthese = build_synthese(dossier, parsed_data)
return select_dp(dossier, synthese, config={"llm_enabled": use_rag}) result = select_dp(dossier, synthese, config={"llm_enabled": use_rag})
elapsed = time.monotonic() - t0
logger.info("⏱ [DP-SELECT] %.1fs — sélection DP terminée", elapsed)
return result
dp_selection_result = None dp_selection_result = None
with ThreadPoolExecutor(max_workers=2) as pool: t_group1 = time.monotonic()
fut_enrich = pool.submit(_task_enrich)
with ThreadPoolExecutor(max_workers=3) as pool:
fut_das = pool.submit(_task_das_llm)
fut_rag_dp = pool.submit(_task_rag_dp)
fut_dp = pool.submit(_task_select_dp) fut_dp = pool.submit(_task_select_dp)
# Attendre les deux tâches
for fut in as_completed([fut_enrich, fut_dp]): for fut in as_completed([fut_das, fut_rag_dp, fut_dp]):
exc = fut.exception() exc = fut.exception()
if exc and fut is fut_dp: if exc and fut is fut_dp:
logger.error("NUKE-3: erreur sélection DP", exc_info=exc) logger.error("NUKE-3: erreur sélection DP", exc_info=exc)
dossier.quality_flags["dp_selection_status"] = "error" dossier.quality_flags["dp_selection_status"] = "error"
dossier.alertes_codage.append("QUALITE DEGRADEE : sélection DP (NUKE-3) en erreur") dossier.alertes_codage.append("QUALITE DEGRADEE : sélection DP (NUKE-3) en erreur")
elif exc: elif exc and fut is fut_rag_dp:
logger.error("RAG enrichissement échoué", exc_info=exc) logger.error("RAG enrichissement DP échoué", exc_info=exc)
elif exc and fut is fut_das:
logger.error("DAS LLM extraction échouée", exc_info=exc)
if not fut_dp.exception(): if not fut_dp.exception():
dp_selection_result = fut_dp.result() dp_selection_result = fut_dp.result()
elapsed_group1 = time.monotonic() - t_group1
logger.info("⏱ [GROUPE-1] %.1fs — DAS_LLM + RAG_DP + DP_SELECT terminés", elapsed_group1)
# --- Groupe 2 : enrichir les DAS (existants + nouveaux du DAS LLM) + actes ---
t_group2 = time.monotonic()
_enrich_das_and_actes(dossier)
elapsed_group2 = time.monotonic() - t_group2
logger.info("⏱ [GROUPE-2] %.1fs — enrichissement DAS + actes terminé", elapsed_group2)
# Appliquer la sélection DP après parallélisation # Appliquer la sélection DP après parallélisation
if dp_selection_result is not None: if dp_selection_result is not None:
selection = dp_selection_result selection = dp_selection_result
@@ -151,12 +180,15 @@ def extract_medical_info(
dossier.alertes_codage.append( dossier.alertes_codage.append(
f"NUKE-3 REVIEW: DP ambigu — {selection.reason}" f"NUKE-3 REVIEW: DP ambigu — {selection.reason}"
) )
elif dossier.document_type != "trackare": elif _dp_selection_needed:
# Fallback sans RAG : sélection DP seule # Fallback sans RAG : sélection DP seule
try: try:
t_dp_norag = time.monotonic()
from .dp_selector import select_dp, build_synthese from .dp_selector import select_dp, build_synthese
synthese = build_synthese(dossier, parsed_data) synthese = build_synthese(dossier, parsed_data)
selection = select_dp(dossier, synthese, config={"llm_enabled": False}) selection = select_dp(dossier, synthese, config={"llm_enabled": False})
elapsed_dp = time.monotonic() - t_dp_norag
logger.info("⏱ [DP-SELECT] %.1fs — sélection DP (sans RAG)", elapsed_dp)
dossier.dp_selection = selection dossier.dp_selection = selection
if selection.chosen_code: if selection.chosen_code:
current_code = ( current_code = (
@@ -240,7 +272,10 @@ def extract_medical_info(
# Post-processing : validation justifications (QC batch) # Post-processing : validation justifications (QC batch)
if use_rag: if use_rag:
t_qc = time.monotonic()
_validate_justifications(dossier) _validate_justifications(dossier)
elapsed_qc = time.monotonic() - t_qc
logger.info("⏱ [QC-VALIDATION] %.1fs — validation justifications terminée", elapsed_qc)
# Post-processing : traçabilité source (page + extrait) # Post-processing : traçabilité source (page + extrait)
if page_tracker: if page_tracker:
@@ -457,7 +492,7 @@ def _extract_das_llm(text: str, dossier: DossierMedical) -> None:
def _enrich_with_rag(dossier: DossierMedical) -> None: def _enrich_with_rag(dossier: DossierMedical) -> None:
"""Enrichit les diagnostics via le RAG (FAISS + Ollama).""" """Enrichit les diagnostics via le RAG (FAISS + Ollama) — wrapper rétro-compatible."""
try: try:
from .rag_search import enrich_dossier from .rag_search import enrich_dossier
enrich_dossier(dossier) enrich_dossier(dossier)
@@ -471,6 +506,34 @@ def _enrich_with_rag(dossier: DossierMedical) -> None:
dossier.alertes_codage.append("QUALITE DEGRADEE : erreur RAG — codage sans référentiels") dossier.alertes_codage.append("QUALITE DEGRADEE : erreur RAG — codage sans référentiels")
def _enrich_dp_only(dossier: DossierMedical) -> None:
    """Enrich ONLY the DP through the RAG (phase 1 — safe to parallelise)."""
    try:
        from .rag_search import enrich_dp as _enrich
        _enrich(dossier)
    except ImportError:
        # Missing optional dependencies: degrade instead of failing the run.
        dossier.quality_flags["rag_status"] = "unavailable"
        dossier.alertes_codage.append("QUALITE DEGRADEE : RAG indisponible — codage sans référentiels")
        logger.error("RAG INDISPONIBLE : faiss-cpu ou sentence-transformers manquant")
    except Exception:
        dossier.quality_flags["rag_status"] = "error"
        dossier.alertes_codage.append("QUALITE DEGRADEE : erreur RAG DP — codage sans référentiels")
        logger.error("RAG EN ERREUR : enrichissement DP échoué", exc_info=True)
def _enrich_das_and_actes(dossier: DossierMedical) -> None:
    """Enrich the DAS and CCAM acts through the RAG (phase 2, after DAS LLM).

    NOTE(review): unlike _enrich_dp_only, failures here set quality flags but
    add no alertes_codage entry — presumably intentional; confirm.
    """
    try:
        from .rag_search import enrich_das_and_actes as _enrich
        _enrich(dossier)
    except ImportError:
        dossier.quality_flags["rag_status"] = "unavailable"
        logger.error("RAG INDISPONIBLE : faiss-cpu ou sentence-transformers manquant")
    except Exception:
        dossier.quality_flags["rag_status"] = "error"
        logger.error("RAG EN ERREUR : enrichissement DAS/actes échoué", exc_info=True)
def _extract_sejour(parsed: dict, dossier: DossierMedical) -> None: def _extract_sejour(parsed: dict, dossier: DossierMedical) -> None:
"""Extrait les informations de séjour.""" """Extrait les informations de séjour."""
patient = parsed.get("patient", {}) patient = parsed.get("patient", {})

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import logging import logging
import os import os
import threading import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from ..config import ( from ..config import (
@@ -610,7 +611,9 @@ def enrich_diagnostic(
Modifie le diagnostic en place. Fallback gracieux si FAISS ou Ollama échouent. Modifie le diagnostic en place. Fallback gracieux si FAISS ou Ollama échouent.
""" """
t0 = time.monotonic()
diag_type = "dp" if est_dp else "das" diag_type = "dp" if est_dp else "das"
label = f"RAG-{diag_type.upper()}"
# 1. Vérifier le cache # 1. Vérifier le cache
cached = cache.get(diagnostic.texte, diag_type) if cache else None cached = cache.get(diagnostic.texte, diag_type) if cache else None
@@ -626,6 +629,8 @@ def enrich_diagnostic(
if cached is not None: if cached is not None:
logger.info("Cache hit (sans sources FAISS) pour %s : « %s »", diag_type.upper(), diagnostic.texte) logger.info("Cache hit (sans sources FAISS) pour %s : « %s »", diag_type.upper(), diagnostic.texte)
_apply_llm_result_diagnostic(diagnostic, cached) _apply_llm_result_diagnostic(diagnostic, cached)
elapsed = time.monotonic() - t0
logger.info("⏱ [%s] %.1fs — %s (no FAISS)", label, elapsed, diagnostic.texte[:60])
return return
# 3. Stocker les sources RAG # 3. Stocker les sources RAG
@@ -643,11 +648,15 @@ def enrich_diagnostic(
if cached is not None: if cached is not None:
logger.info("Cache hit pour %s : « %s »", diag_type.upper(), diagnostic.texte) logger.info("Cache hit pour %s : « %s »", diag_type.upper(), diagnostic.texte)
_apply_llm_result_diagnostic(diagnostic, cached) _apply_llm_result_diagnostic(diagnostic, cached)
elapsed = time.monotonic() - t0
logger.info("⏱ [%s] %.1fs — %s (cache hit)", label, elapsed, diagnostic.texte[:60])
return return
# 5. Appel Ollama pour justification avec raisonnement structuré # 5. Appel Ollama pour justification avec raisonnement structuré
prompt = _build_prompt(diagnostic.texte, sources, contexte, est_dp=est_dp) prompt = _build_prompt(diagnostic.texte, sources, contexte, est_dp=est_dp)
t_llm = time.monotonic()
llm_result = _call_ollama(prompt) llm_result = _call_ollama(prompt)
llm_elapsed = time.monotonic() - t_llm
if llm_result: if llm_result:
_apply_llm_result_diagnostic(diagnostic, llm_result) _apply_llm_result_diagnostic(diagnostic, llm_result)
@@ -656,6 +665,9 @@ def enrich_diagnostic(
else: else:
logger.info("Ollama non disponible — sources FAISS conservées sans justification LLM") logger.info("Ollama non disponible — sources FAISS conservées sans justification LLM")
elapsed = time.monotonic() - t0
logger.info("⏱ [%s] %.1fs (LLM %.1fs) — %s", label, elapsed, llm_elapsed, diagnostic.texte[:60])
def _apply_llm_result_acte(acte: ActeCCAM, llm_result: dict) -> None: def _apply_llm_result_acte(acte: ActeCCAM, llm_result: dict) -> None:
"""Applique un résultat LLM (frais ou caché) à un ActeCCAM.""" """Applique un résultat LLM (frais ou caché) à un ActeCCAM."""
@@ -687,6 +699,8 @@ def enrich_acte(acte: ActeCCAM, contexte: dict, cache: OllamaCache | None = None
Modifie l'acte en place. Fallback gracieux si FAISS ou Ollama échouent. Modifie l'acte en place. Fallback gracieux si FAISS ou Ollama échouent.
""" """
t0 = time.monotonic()
# 1. Vérifier le cache # 1. Vérifier le cache
cached = cache.get(acte.texte, "ccam") if cache else None cached = cache.get(acte.texte, "ccam") if cache else None
@@ -712,11 +726,15 @@ def enrich_acte(acte: ActeCCAM, contexte: dict, cache: OllamaCache | None = None
if cached is not None: if cached is not None:
logger.info("Cache hit pour CCAM : « %s »", acte.texte) logger.info("Cache hit pour CCAM : « %s »", acte.texte)
_apply_llm_result_acte(acte, cached) _apply_llm_result_acte(acte, cached)
elapsed = time.monotonic() - t0
logger.info("⏱ [RAG-CCAM] %.1fs — %s (cache hit)", elapsed, acte.texte[:60])
return return
# 5. Appel Ollama pour justification avec raisonnement structuré # 5. Appel Ollama pour justification avec raisonnement structuré
prompt = _build_prompt_ccam(acte.texte, sources, contexte) prompt = _build_prompt_ccam(acte.texte, sources, contexte)
t_llm = time.monotonic()
llm_result = _call_ollama(prompt) llm_result = _call_ollama(prompt)
llm_elapsed = time.monotonic() - t_llm
if llm_result: if llm_result:
_apply_llm_result_acte(acte, llm_result) _apply_llm_result_acte(acte, llm_result)
@@ -725,6 +743,9 @@ def enrich_acte(acte: ActeCCAM, contexte: dict, cache: OllamaCache | None = None
else: else:
logger.info("Ollama non disponible — sources FAISS CCAM conservées sans justification LLM") logger.info("Ollama non disponible — sources FAISS CCAM conservées sans justification LLM")
elapsed = time.monotonic() - t0
logger.info("⏱ [RAG-CCAM] %.1fs (LLM %.1fs) — %s", elapsed, llm_elapsed, acte.texte[:60])
def _smart_truncate(text: str, max_chars: int = 6000) -> str: def _smart_truncate(text: str, max_chars: int = 6000) -> str:
"""Troncature intelligente : garde le début + les sections finales importantes. """Troncature intelligente : garde le début + les sections finales importantes.
@@ -790,6 +811,8 @@ def extract_das_llm(
""" """
import hashlib import hashlib
t0 = time.monotonic()
# Clé de cache basée sur le hash du texte # Clé de cache basée sur le hash du texte
text_hash = hashlib.md5(text[:4000].encode()).hexdigest()[:16] text_hash = hashlib.md5(text[:4000].encode()).hexdigest()[:16]
cache_key_text = f"das_extract::{text_hash}" cache_key_text = f"das_extract::{text_hash}"
@@ -798,15 +821,19 @@ def extract_das_llm(
if cache is not None: if cache is not None:
cached = cache.get(cache_key_text, "das_llm") cached = cache.get(cache_key_text, "das_llm")
if cached is not None: if cached is not None:
logger.info("Cache hit pour extraction DAS LLM") elapsed = time.monotonic() - t0
logger.info("⏱ [DAS-LLM] %.1fs — cache hit", elapsed)
return cached.get("diagnostics_supplementaires", []) return cached.get("diagnostics_supplementaires", [])
# Construire le prompt et appeler Ollama # Construire le prompt et appeler Ollama
prompt = _build_prompt_das_extraction(text, contexte, existing_das, dp_texte) prompt = _build_prompt_das_extraction(text, contexte, existing_das, dp_texte)
t_llm = time.monotonic()
result = call_ollama(prompt, temperature=0.1, max_tokens=2000, role="coding") result = call_ollama(prompt, temperature=0.1, max_tokens=2000, role="coding")
llm_elapsed = time.monotonic() - t_llm
if result is None: if result is None:
logger.warning("Extraction DAS LLM : Ollama non disponible") elapsed = time.monotonic() - t0
logger.warning("⏱ [DAS-LLM] %.1fs — Ollama non disponible", elapsed)
return [] return []
das_list = result.get("diagnostics_supplementaires", []) das_list = result.get("diagnostics_supplementaires", [])
@@ -818,25 +845,52 @@ def extract_das_llm(
if cache is not None: if cache is not None:
cache.put(cache_key_text, "das_llm", result) cache.put(cache_key_text, "das_llm", result)
logger.info("Extraction DAS LLM : %d diagnostics supplémentaires détectés", len(das_list)) elapsed = time.monotonic() - t0
logger.info("⏱ [DAS-LLM] %.1fs (LLM %.1fs) — %d diagnostics détectés", elapsed, llm_elapsed, len(das_list))
return das_list return das_list
def enrich_dossier(dossier: DossierMedical) -> None: def enrich_dp(dossier: DossierMedical, cache: OllamaCache | None = None) -> None:
"""Enrichit le DP et tous les DAS d'un dossier via le RAG. """Enrichit SEULEMENT le DP d'un dossier via le RAG (Phase 1).
Utilise un cache persistant et parallélise les appels Ollama Peut être exécuté en parallèle avec d'autres tâches (DAS LLM, DP selector)
pour les DAS et actes CCAM (max_workers = OLLAMA_MAX_PARALLEL). car il ne dépend que du DP existant.
Args:
dossier: Dossier médical à enrichir (modifié en place).
cache: Cache Ollama optionnel (créé si None).
""" """
cache = OllamaCache(OLLAMA_CACHE_PATH, get_model("coding")) t0 = time.monotonic()
if cache is None:
cache = OllamaCache(OLLAMA_CACHE_PATH, get_model("coding"))
contexte = build_enriched_context(dossier) contexte = build_enriched_context(dossier)
# Phase 1 : DP seul (le contexte DAS en dépend)
if dossier.diagnostic_principal: if dossier.diagnostic_principal:
logger.info("RAG enrichissement DP : %s", dossier.diagnostic_principal.texte) logger.info("RAG enrichissement DP : %s", dossier.diagnostic_principal.texte)
enrich_diagnostic(dossier.diagnostic_principal, contexte, est_dp=True, cache=cache) enrich_diagnostic(dossier.diagnostic_principal, contexte, est_dp=True, cache=cache)
cache.save()
elapsed = time.monotonic() - t0
logger.info("⏱ [RAG-DP-PHASE] %.1fs — enrichissement DP terminé", elapsed)
def enrich_das_and_actes(dossier: DossierMedical, cache: OllamaCache | None = None) -> None:
"""Enrichit les DAS et actes CCAM d'un dossier via le RAG (Phase 2).
Parallélise les appels Ollama (max_workers = OLLAMA_MAX_PARALLEL).
Doit être appelé APRES enrich_dp() et après l'ajout des DAS LLM.
Args:
dossier: Dossier médical à enrichir (modifié en place).
cache: Cache Ollama optionnel (créé si None).
"""
t0 = time.monotonic()
if cache is None:
cache = OllamaCache(OLLAMA_CACHE_PATH, get_model("coding"))
contexte = build_enriched_context(dossier)
# Mettre à jour le contexte avec le DP pour les DAS # Mettre à jour le contexte avec le DP pour les DAS
if dossier.diagnostic_principal: if dossier.diagnostic_principal:
contexte["dp_texte"] = dossier.diagnostic_principal.texte contexte["dp_texte"] = dossier.diagnostic_principal.texte
@@ -846,7 +900,6 @@ def enrich_dossier(dossier: DossierMedical) -> None:
if d.cim10_suggestion if d.cim10_suggestion
] ]
# Phase 2 : DAS + Actes en parallèle
das_list = dossier.diagnostics_associes das_list = dossier.diagnostics_associes
actes_list = dossier.actes_ccam actes_list = dossier.actes_ccam
@@ -863,3 +916,18 @@ def enrich_dossier(dossier: DossierMedical) -> None:
f.result() # propage les exceptions f.result() # propage les exceptions
cache.save() cache.save()
elapsed = time.monotonic() - t0
n_das = len(das_list) if das_list else 0
n_actes = len(actes_list) if actes_list else 0
logger.info("⏱ [RAG-DAS-PHASE] %.1fs — %d DAS + %d actes enrichis", elapsed, n_das, n_actes)
def enrich_dossier(dossier: DossierMedical) -> None:
"""Enrichit le DP et tous les DAS d'un dossier via le RAG.
Wrapper rétro-compatible qui appelle enrich_dp() puis enrich_das_and_actes().
Utilise un cache persistant partagé entre les deux phases.
"""
cache = OllamaCache(OLLAMA_CACHE_PATH, get_model("coding"))
enrich_dp(dossier, cache=cache)
enrich_das_and_actes(dossier, cache=cache)

View File

@@ -116,7 +116,7 @@ class TestExtractMedicalInfoRAGFlag:
assert dossier.diagnostic_principal.justification is None assert dossier.diagnostic_principal.justification is None
def test_use_rag_true_calls_enrich(self): def test_use_rag_true_calls_enrich(self):
"""use_rag=True appelle _enrich_with_rag (mocké).""" """use_rag=True appelle _enrich_dp_only et _enrich_das_and_actes (mockés)."""
from src.medical.cim10_extractor import extract_medical_info from src.medical.cim10_extractor import extract_medical_info
parsed = { parsed = {
@@ -127,9 +127,13 @@ class TestExtractMedicalInfoRAGFlag:
} }
text = "Pancréatite aiguë biliaire.\nTTT de sortie :\nParacétamol\n\nDevenir : retour." text = "Pancréatite aiguë biliaire.\nTTT de sortie :\nParacétamol\n\nDevenir : retour."
with patch("src.medical.cim10_extractor._enrich_with_rag") as mock_enrich: with patch("src.medical.cim10_extractor._enrich_dp_only") as mock_dp, \
patch("src.medical.cim10_extractor._enrich_das_and_actes") as mock_das, \
patch("src.medical.cim10_extractor._extract_das_llm"), \
patch("src.medical.cim10_extractor._validate_justifications"):
dossier = extract_medical_info(parsed, text, use_rag=True) dossier = extract_medical_info(parsed, text, use_rag=True)
mock_enrich.assert_called_once_with(dossier) mock_dp.assert_called_once_with(dossier)
mock_das.assert_called_once_with(dossier)
def test_use_rag_default_false(self): def test_use_rag_default_false(self):
"""Par défaut, use_rag=False.""" """Par défaut, use_rag=False."""
@@ -143,9 +147,11 @@ class TestExtractMedicalInfoRAGFlag:
} }
text = "Test simple." text = "Test simple."
with patch("src.medical.cim10_extractor._enrich_with_rag") as mock_enrich: with patch("src.medical.cim10_extractor._enrich_dp_only") as mock_dp, \
patch("src.medical.cim10_extractor._enrich_das_and_actes") as mock_das:
extract_medical_info(parsed, text) extract_medical_info(parsed, text)
mock_enrich.assert_not_called() mock_dp.assert_not_called()
mock_das.assert_not_called()
class TestChunkingCIM10: class TestChunkingCIM10: