diff --git a/src/config.py b/src/config.py index 738e581..1dde987 100644 --- a/src/config.py +++ b/src/config.py @@ -60,7 +60,7 @@ OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://localhost:11434") OLLAMA_MODEL = os.environ.get("OLLAMA_MODEL", "gemma3:27b") OLLAMA_TIMEOUT = int(os.environ.get("OLLAMA_TIMEOUT", "600")) OLLAMA_CACHE_PATH = BASE_DIR / "data" / "ollama_cache.json" -OLLAMA_MAX_PARALLEL = int(os.environ.get("OLLAMA_MAX_PARALLEL", "2")) +OLLAMA_MAX_PARALLEL = int(os.environ.get("OLLAMA_MAX_PARALLEL", "4")) # --- Modèles par rôle LLM --- diff --git a/src/control/cpam_response.py b/src/control/cpam_response.py index 62a85fa..f95b1e6 100644 --- a/src/control/cpam_response.py +++ b/src/control/cpam_response.py @@ -4,6 +4,14 @@ Orchestrateur principal — délègue aux sous-modules : - cpam_rag : _search_rag_for_control(), _search_rag_queries() - cpam_context : _build_cpam_prompt(), _build_tagged_context(), _build_bio_summary(), etc. - cpam_validation : _validate_adversarial(), _validate_grounding(), _format_response(), etc. + +Le flux batché (generate_cpam_responses_batched) regroupe les appels LLM par modèle +pour minimiser les swaps VRAM sur GPU unique : + Phase 1 (cpam model) : extraction + argumentation de TOUS les contrôles + Phase 2 (validation model, 1 swap) : validation adversariale de tous + Phase 3 (cpam model, 1 swap) : correction des contrôles échoués + Phase 4 (validation model, 1 swap) : re-validation des corrigés +Gain : de ~4N swaps à ~2-4 swaps (indépendant du nombre de contrôles). """ from __future__ import annotations @@ -11,6 +19,8 @@ from __future__ import annotations import json import logging import os +import time +from dataclasses import dataclass, field from datetime import datetime from pathlib import Path @@ -41,6 +51,34 @@ from .cpam_validation import ( logger = logging.getLogger(__name__) +# --------------------------------------------------------------------------- +# Structure intermédiaire pour le batching +# --------------------------------------------------------------------------- + +@dataclass +class _CpamDraft: + """État intermédiaire d'un contrôle CPAM en cours de traitement batché.""" + controle: ControleCPAM + # Résultats des phases successives + extraction: dict | None = None + degraded_pass1: bool = False + sources: list[dict] = field(default_factory=list) + rag_sources: list[RAGSource] = field(default_factory=list) + prompt: str = "" + tag_map: dict[str, str] = field(default_factory=dict) + result: dict | None = None + # Validation + ref_warnings: list[str] = field(default_factory=list) + grounding_warnings: list[str] = field(default_factory=list) + code_warnings: list[str] = field(default_factory=list) + adversarial_warnings: list[str] = field(default_factory=list) + validation: dict | None = None + # Résultat final + needs_correction: bool = False + text: str = "" + response_data: dict | None = None + + def _save_version( dossier: DossierMedical, controle: ControleCPAM, @@ -149,14 +187,18 @@ def _extraction_pass( ) logger.debug(" Passe 1 — extraction structurée") + t0 = time.time() result = call_ollama(prompt, temperature=0.0, max_tokens=3000, role="cpam") if result is None: result = call_anthropic(prompt, temperature=0.0, max_tokens=3000) + elapsed = time.time() - t0 if result is not None: - logger.info(" Passe 1 OK : %d éléments cliniques extraits", + logger.info(" [CPAM-EXTRACT] %.1fs — OGC %s — %d éléments cliniques extraits", + elapsed, controle.numero_ogc, len(result.get("elements_cliniques_pertinents", []))) else: - logger.warning(" Passe 1 échouée — fallback single-pass") + logger.warning(" [CPAM-EXTRACT] %.1fs — OGC %s — passe 1 échouée", + elapsed, controle.numero_ogc) return result @@ -195,14 +237,17 @@ def generate_cpam_response( prompt, tag_map = _build_cpam_prompt(dossier, controle, sources, extraction) # 4. Appel LLM — Ollama (rôle cpam) > Haiku fallback + t_gen = time.time() result = call_ollama(prompt, temperature=0.1, max_tokens=8000, role="cpam") if result is not None: - logger.info(" Contre-argumentation via Ollama") + logger.info(" [CPAM-GEN] %.1fs — OGC %s — contre-argumentation via Ollama", + time.time() - t_gen, controle.numero_ogc) else: logger.info(" Ollama indisponible → fallback Anthropic Haiku") result = call_anthropic(prompt, temperature=0.1, max_tokens=8000) if result is not None: - logger.info(" Contre-argumentation via Anthropic Haiku") + logger.info(" [CPAM-GEN] %.1fs — OGC %s — contre-argumentation via Anthropic", + time.time() - t_gen, controle.numero_ogc) # 5. Conversion des sources RAG rag_sources = [ @@ -285,9 +330,12 @@ def generate_cpam_response( validation.get("score_confiance"), attempt + 1, max_corrections, len(erreurs_v)) correction_prompt = _build_correction_prompt(prompt, result, validation) + t_corr = time.time() corrected = call_ollama(correction_prompt, temperature=0.0, max_tokens=16000, role="cpam") if corrected is None: corrected = call_anthropic(correction_prompt, temperature=0.0, max_tokens=16000) + logger.info(" [CPAM-CORR] %.1fs — OGC %s — correction %d/%d", + time.time() - t_corr, controle.numero_ogc, attempt + 1, max_corrections) if not corrected: break @@ -347,3 +395,343 @@ def generate_cpam_response( logger.info(" Contre-argumentation générée (%d caractères)", len(text)) return text, result, rag_sources + + +# --------------------------------------------------------------------------- +# Fonctions de phase (utilisées par le batching) +# --------------------------------------------------------------------------- + +def _phase_generate( + dossier: DossierMedical, + draft: _CpamDraft, +) -> None: + """Phase 1 — Extraction + RAG + argumentation (role=cpam). + + Remplit draft.extraction, draft.sources, draft.prompt, draft.tag_map, draft.result. + Ne fait aucun appel au modèle de validation. + """ + controle = draft.controle + logger.info("CPAM batch : génération pour OGC %d — %s", + controle.numero_ogc, controle.titre) + + # 0. Versioning + _save_version(dossier, controle) + + # 1. Extraction structurée (role=cpam) + draft.extraction = _extraction_pass(dossier, controle) + draft.degraded_pass1 = draft.extraction is None + if draft.degraded_pass1: + dossier.alertes_codage.append( + "CPAM: passe 1 (extraction structurée) échouée → mode dégradé" + ) + + # 2. Recherche RAG (pas de LLM) + draft.sources = _search_rag_for_control(controle, dossier) + logger.info(" RAG : %d sources trouvées", len(draft.sources)) + + # 3. Construction du prompt + draft.prompt, draft.tag_map = _build_cpam_prompt( + dossier, controle, draft.sources, draft.extraction + ) + + # 4. Argumentation (role=cpam) + t_gen = time.time() + result = call_ollama(draft.prompt, temperature=0.1, max_tokens=8000, role="cpam") + if result is not None: + logger.info(" [CPAM-GEN] %.1fs — OGC %s — contre-argumentation via Ollama", + time.time() - t_gen, controle.numero_ogc) + else: + logger.info(" Ollama indisponible → fallback Anthropic Haiku") + result = call_anthropic(draft.prompt, temperature=0.1, max_tokens=8000) + if result is not None: + logger.info(" [CPAM-GEN] %.1fs — OGC %s — contre-argumentation via Anthropic", + time.time() - t_gen, controle.numero_ogc) + + # 5. Conversion sources RAG + draft.rag_sources = [ + RAGSource( + document=s.get("document", ""), + page=s.get("page"), + code=s.get("code"), + extrait=s.get("extrait", "")[:200], + ) + for s in draft.sources + ] + + if result is None: + logger.warning(" LLM non disponible — contre-argumentation non générée") + draft.result = None + return + + # 5b. Mode dégradé + if draft.degraded_pass1: + result.setdefault("quality_flags", {}) + result["quality_flags"]["cpam_pass1_failed"] = True + result["quality_flags"]["degraded_mode"] = True + + # 6. Sanitisation + gardien déterministes (pas de LLM) + sanitized = _sanitize_unauthorized_codes(result, dossier, controle) + if sanitized: + logger.info(" CPAM : %d code(s) hors périmètre supprimé(s) : %s", + len(sanitized), ", ".join(sanitized)) + result = _guardian_deterministic(result, dossier, controle, draft.tag_map) + + # 7. Validations déterministes (pas de LLM) + draft.ref_warnings = _validate_references(result, draft.sources) + if draft.ref_warnings: + logger.warning(" CPAM : %d référence(s) non vérifiable(s)", len(draft.ref_warnings)) + + draft.grounding_warnings = _validate_grounding(result, draft.tag_map) + if draft.grounding_warnings: + logger.warning(" CPAM : %d preuve(s) non traçable(s)", len(draft.grounding_warnings)) + + draft.code_warnings = _validate_codes_in_response(result, dossier, controle) + if draft.code_warnings: + logger.warning(" CPAM : %d code(s) hors périmètre", len(draft.code_warnings)) + + draft.result = result + + +def _phase_validate( + dossier: DossierMedical, + draft: _CpamDraft, +) -> None: + """Phase 2 — Validation adversariale (role=validation). + + Remplit draft.validation et draft.needs_correction. + """ + if draft.result is None: + return + + controle = draft.controle + + # LOGIC-3 : modèles identiques ? + from ..config import check_adversarial_model_config + same_model, model_msg = check_adversarial_model_config() + if same_model: + draft.result.setdefault("quality_flags", {}) + draft.result["quality_flags"]["adversarial_disabled_same_model"] = True + dossier.alertes_codage.append( + "Validation adversariale désactivée (modèles identiques)" + ) + + draft.validation = _validate_adversarial(draft.result, draft.tag_map, controle) + if draft.validation and not draft.validation.get("coherent", True): + erreurs = draft.validation.get("erreurs", []) + score = draft.validation.get("score_confiance", "?") + for e in erreurs: + if isinstance(e, str) and e.strip(): + draft.adversarial_warnings.append(f"Incohérence détectée : {e}") + if draft.adversarial_warnings: + draft.adversarial_warnings.append(f"Score de confiance : {score}/10") + + # Déterminer si une correction est nécessaire + max_corrections = int(os.environ.get("T2A_CPAM_MAX_CORRECTIONS", "2")) + draft.needs_correction = bool( + max_corrections > 0 + and draft.validation + and not draft.validation.get("coherent", True) + and draft.validation.get("score_confiance", 10) <= 5 + and rule_enabled("RULE-CPAM-CORRECTION-LOOP") + ) + + +def _phase_correct( + dossier: DossierMedical, + draft: _CpamDraft, +) -> None: + """Phase 3 — Correction (role=cpam) + re-validation (role=validation). + + Boucle correction/re-validation intégrée car les deux sont couplées. + Seuls les drafts avec needs_correction=True entrent ici. + """ + if draft.result is None or not draft.needs_correction: + return + + controle = draft.controle + max_corrections = int(os.environ.get("T2A_CPAM_MAX_CORRECTIONS", "2")) + + for attempt in range(max_corrections): + if not (draft.validation + and not draft.validation.get("coherent", True) + and draft.validation.get("score_confiance", 10) <= 5 + and rule_enabled("RULE-CPAM-CORRECTION-LOOP")): + break + + erreurs_v = draft.validation.get("erreurs", []) + logger.warning(" Score adversarial %s/10 — correction %d/%d (%d erreur(s))", + draft.validation.get("score_confiance"), + attempt + 1, max_corrections, len(erreurs_v)) + + correction_prompt = _build_correction_prompt( + draft.prompt, draft.result, draft.validation + ) + t_corr = time.time() + corrected = call_ollama( + correction_prompt, temperature=0.0, max_tokens=16000, role="cpam" + ) + if corrected is None: + corrected = call_anthropic( + correction_prompt, temperature=0.0, max_tokens=16000 + ) + logger.info(" [CPAM-CORR] %.1fs — OGC %s — correction %d/%d", + time.time() - t_corr, controle.numero_ogc, + attempt + 1, max_corrections) + + if not corrected: + break + + validation2 = _validate_adversarial(corrected, draft.tag_map, controle) + score2 = validation2.get("score_confiance", 0) if validation2 else 0 + score1 = draft.validation.get("score_confiance", 0) + + if score2 > score1: + logger.info(" Correction %d acceptée (score %s → %s)", + attempt + 1, score1, score2) + draft.result = corrected + draft.validation = validation2 + _sanitize_unauthorized_codes(draft.result, dossier, controle) + draft.result = _guardian_deterministic( + draft.result, dossier, controle, draft.tag_map + ) + draft.ref_warnings = _validate_references(draft.result, draft.sources) + draft.grounding_warnings = _validate_grounding(draft.result, draft.tag_map) + draft.code_warnings = _validate_codes_in_response( + draft.result, dossier, controle + ) + draft.adversarial_warnings = [] + if draft.validation and not draft.validation.get("coherent", True): + for e in draft.validation.get("erreurs", []): + if isinstance(e, str) and e.strip(): + draft.adversarial_warnings.append(f"Incohérence détectée : {e}") + if draft.adversarial_warnings: + draft.adversarial_warnings.append( + f"Score de confiance : {draft.validation.get('score_confiance', '?')}/10" + ) + else: + logger.warning(" Correction %d rejetée (score %s → %s)", + attempt + 1, score1, score2) + break + + +def _phase_finalize( + dossier: DossierMedical, + draft: _CpamDraft, +) -> tuple[str, dict | None, list[RAGSource]]: + """Phase finale — Qualité + formatage (déterministe, pas de LLM). + + Returns: + Même signature que generate_cpam_response(). + """ + controle = draft.controle + + if draft.result is None: + return "", None, draft.rag_sources + + all_warnings = ( + draft.ref_warnings + draft.grounding_warnings + + draft.code_warnings + draft.adversarial_warnings + ) + + strength = _assess_dossier_strength(dossier) + if strength["is_weak"]: + logger.info(" Dossier à preuves limitées (score %d/10) : %s", + strength["score"], ", ".join(strength["missing"])) + + tier, needs_review, cat_warnings = _assess_quality_tier( + draft.result, draft.ref_warnings, draft.grounding_warnings, + draft.code_warnings, draft.validation, + is_weak_dossier=strength["is_weak"], + ) + controle.quality_tier = tier + controle.requires_review = needs_review + controle.quality_warnings = cat_warnings + logger.info(" Qualité CPAM : tier %s, requires_review=%s, %d warnings", + tier, needs_review, len(cat_warnings)) + + text = _format_response( + draft.result, + ref_warnings=all_warnings, + quality_tier=tier, + categorized_warnings=cat_warnings, + ) + logger.info(" Contre-argumentation générée (%d caractères)", len(text)) + + return text, draft.result, draft.rag_sources + + +# --------------------------------------------------------------------------- +# Orchestrateur batché — minimise les swaps de modèle VRAM +# --------------------------------------------------------------------------- + +def generate_cpam_responses_batched( + dossier: DossierMedical, + controles: list[ControleCPAM], +) -> None: + """Génère les contre-argumentations pour TOUS les contrôles CPAM en batch. + + Regroupe les appels LLM par modèle pour minimiser les swaps VRAM : + Phase 1 (cpam model) : extraction + argumentation de tous les contrôles + Phase 2 (validation model) : validation adversariale de tous + Phase 3 (cpam model) : correction des contrôles échoués (si nécessaire) + Finalisation : qualité + formatage (déterministe) + + Gain : de ~4*N swaps de modèle à 2-4 swaps, indépendant de N. + + Args: + dossier: Le dossier médical (fusionné ou unique). + controles: Liste des contrôles CPAM à traiter. + + Side effects: + Remplit controle.contre_argumentation, controle.response_data, + controle.sources_reponse pour chaque contrôle. + """ + if not controles: + return + + t_total = time.time() + n = len(controles) + logger.info("CPAM batch : %d contrôle(s) — début du traitement batché", n) + + # Créer les drafts + drafts = [_CpamDraft(controle=ctrl) for ctrl in controles] + + # --- Phase 1 : génération (cpam model) --- + t_phase = time.time() + logger.info("CPAM batch phase 1/3 : génération (%d contrôles, role=cpam)", n) + for draft in drafts: + _phase_generate(dossier, draft) + logger.info("CPAM batch phase 1/3 terminée : %.1fs", time.time() - t_phase) + + # --- Phase 2 : validation adversariale (validation model — 1 swap) --- + drafts_to_validate = [d for d in drafts if d.result is not None] + if drafts_to_validate: + t_phase = time.time() + logger.info("CPAM batch phase 2/3 : validation (%d contrôles, role=validation)", + len(drafts_to_validate)) + for draft in drafts_to_validate: + _phase_validate(dossier, draft) + logger.info("CPAM batch phase 2/3 terminée : %.1fs", time.time() - t_phase) + + # --- Phase 3 : correction + re-validation (cpam + validation models) --- + drafts_to_correct = [d for d in drafts if d.needs_correction] + if drafts_to_correct: + t_phase = time.time() + logger.info("CPAM batch phase 3/3 : correction (%d contrôles nécessitent correction)", + len(drafts_to_correct)) + for draft in drafts_to_correct: + _phase_correct(dossier, draft) + logger.info("CPAM batch phase 3/3 terminée : %.1fs", time.time() - t_phase) + + # --- Finalisation (déterministe, pas de swap) --- + for draft in drafts: + text, response_data, sources = _phase_finalize(dossier, draft) + draft.controle.contre_argumentation = text + draft.controle.response_data = response_data + draft.controle.sources_reponse = sources + + elapsed_total = time.time() - t_total + logger.info( + "CPAM batch terminé : %d contrôle(s) en %.1fs (%.1fs/contrôle)", + n, elapsed_total, elapsed_total / n if n else 0, + ) diff --git a/src/control/cpam_validation.py b/src/control/cpam_validation.py index 888edb2..106ce8b 100644 --- a/src/control/cpam_validation.py +++ b/src/control/cpam_validation.py @@ -4,6 +4,7 @@ from __future__ import annotations import logging import re +import time from ..config import ControleCPAM, DossierMedical from ..medical.bio_normals import BIO_NORMALS @@ -477,11 +478,14 @@ def _validate_adversarial( ) logger.debug(" Validation adversariale") + t_val = time.time() result = call_ollama(prompt, temperature=0.0, max_tokens=6000, role="validation") if result is None: result = call_anthropic(prompt, temperature=0.0, max_tokens=6000) + elapsed = time.time() - t_val if result is None: - logger.warning(" Validation adversariale échouée — LLM indisponible") + logger.warning(" [CPAM-VALID] %.1fs — OGC %s — validation adversariale échouée", + elapsed, controle.numero_ogc) return None coherent = result.get("coherent", True) @@ -489,12 +493,13 @@ def _validate_adversarial( score = result.get("score_confiance", -1) if not coherent and erreurs: - logger.warning(" Validation adversariale : %d incohérence(s) détectée(s) (score %s/10)", - len(erreurs), score) + logger.warning(" [CPAM-VALID] %.1fs — OGC %s — %d incohérence(s) (score %s/10)", + elapsed, controle.numero_ogc, len(erreurs), score) for e in erreurs: logger.warning(" - %s", e) else: - logger.info(" Validation adversariale OK (score %s/10)", score) + logger.info(" [CPAM-VALID] %.1fs — OGC %s — OK (score %s/10)", + elapsed, controle.numero_ogc, score) return result diff --git a/src/main.py b/src/main.py index 6da3d75..e32a9f1 100644 --- a/src/main.py +++ b/src/main.py @@ -610,20 +610,18 @@ def main(input_path: str | None = None) -> None: merged = None # Contrôle CPAM : enrichir le dossier principal (fusionné ou dernier) + # Utilise le mode batché pour regrouper les appels LLM par modèle + # et minimiser les swaps VRAM (de ~4N swaps à ~2-4 swaps) if cpam_data and subdir: try: from .control.cpam_parser import match_dossier_ogc controles = match_dossier_ogc(subdir, cpam_data) if controles: - from .control.cpam_response import generate_cpam_response + from .control.cpam_response import generate_cpam_responses_batched target = merged if merged else (group_dossiers[-1] if group_dossiers else None) if target: logger.info(" CPAM : %d contrôle(s) pour %s", len(controles), subdir) - for ctrl in controles: - text, response_data, sources = generate_cpam_response(target, ctrl) - ctrl.contre_argumentation = text - ctrl.response_data = response_data - ctrl.sources_reponse = sources + generate_cpam_responses_batched(target, controles) target.controles_cpam = controles except Exception: logger.exception("Erreur CPAM pour %s", subdir) diff --git a/src/medical/cim10_extractor.py b/src/medical/cim10_extractor.py index 57f7745..4c6f3d5 100644 --- a/src/medical/cim10_extractor.py +++ b/src/medical/cim10_extractor.py @@ -11,6 +11,7 @@ from __future__ import annotations import logging import re +import time from datetime import datetime from typing import Optional @@ -87,43 +88,71 @@ def extract_medical_info( _extract_imagerie(anonymized_text, dossier) _extract_complications(anonymized_text, dossier, edsnlp_result) - # Phase 4 : pass LLM pour détecter des DAS supplémentaires + # Phase 4 : DAS LLM + RAG DP + DP selector en parallèle + _dp_selection_needed = dossier.document_type != "trackare" + if use_rag: - _extract_das_llm(anonymized_text, dossier) - - # Optimisation #1 : paralléliser enrichissement RAG et sélection DP - _dp_selection_needed = use_rag and dossier.document_type != "trackare" - - if use_rag or _dp_selection_needed: from concurrent.futures import ThreadPoolExecutor, as_completed - def _task_enrich(): - if use_rag: - _enrich_with_rag(dossier) + # --- Groupe 1 : 3 tâches indépendantes en parallèle --- + # - DAS LLM : détecte des DAS supplémentaires (ne dépend pas du RAG DP) + # - RAG DP : enrichit seulement le DP (ne dépend pas du DAS LLM) + # - DP selector : sélectionne le DP optimal (indépendant des deux autres) + + def _task_das_llm(): + t0 = time.monotonic() + _extract_das_llm(anonymized_text, dossier) + elapsed = time.monotonic() - t0 + logger.info("⏱ [DAS-LLM-TASK] %.1fs — extraction DAS LLM terminée", elapsed) + + def _task_rag_dp(): + t0 = time.monotonic() + _enrich_dp_only(dossier) + elapsed = time.monotonic() - t0 + logger.info("⏱ [RAG-DP-TASK] %.1fs — enrichissement DP terminé", elapsed) def _task_select_dp(): if not _dp_selection_needed: return None + t0 = time.monotonic() from .dp_selector import select_dp, build_synthese synthese = build_synthese(dossier, parsed_data) - return select_dp(dossier, synthese, config={"llm_enabled": use_rag}) + result = select_dp(dossier, synthese, config={"llm_enabled": use_rag}) + elapsed = time.monotonic() - t0 + logger.info("⏱ [DP-SELECT] %.1fs — sélection DP terminée", elapsed) + return result dp_selection_result = None - with ThreadPoolExecutor(max_workers=2) as pool: - fut_enrich = pool.submit(_task_enrich) + t_group1 = time.monotonic() + + with ThreadPoolExecutor(max_workers=3) as pool: + fut_das = pool.submit(_task_das_llm) + fut_rag_dp = pool.submit(_task_rag_dp) fut_dp = pool.submit(_task_select_dp) - # Attendre les deux tâches - for fut in as_completed([fut_enrich, fut_dp]): + + for fut in as_completed([fut_das, fut_rag_dp, fut_dp]): exc = fut.exception() if exc and fut is fut_dp: logger.error("NUKE-3: erreur sélection DP", exc_info=exc) dossier.quality_flags["dp_selection_status"] = "error" dossier.alertes_codage.append("QUALITE DEGRADEE : sélection DP (NUKE-3) en erreur") - elif exc: - logger.error("RAG enrichissement échoué", exc_info=exc) + elif exc and fut is fut_rag_dp: + logger.error("RAG enrichissement DP échoué", exc_info=exc) + elif exc and fut is fut_das: + logger.error("DAS LLM extraction échouée", exc_info=exc) + if not fut_dp.exception(): dp_selection_result = fut_dp.result() + elapsed_group1 = time.monotonic() - t_group1 + logger.info("⏱ [GROUPE-1] %.1fs — DAS_LLM + RAG_DP + DP_SELECT terminés", elapsed_group1) + + # --- Groupe 2 : enrichir les DAS (existants + nouveaux du DAS LLM) + actes --- + t_group2 = time.monotonic() + _enrich_das_and_actes(dossier) + elapsed_group2 = time.monotonic() - t_group2 + logger.info("⏱ [GROUPE-2] %.1fs — enrichissement DAS + actes terminé", elapsed_group2) + # Appliquer la sélection DP après parallélisation if dp_selection_result is not None: selection = dp_selection_result @@ -151,12 +180,15 @@ def extract_medical_info( dossier.alertes_codage.append( f"NUKE-3 REVIEW: DP ambigu — {selection.reason}" ) - elif dossier.document_type != "trackare": + elif _dp_selection_needed: # Fallback sans RAG : sélection DP seule try: + t_dp_norag = time.monotonic() from .dp_selector import select_dp, build_synthese synthese = build_synthese(dossier, parsed_data) selection = select_dp(dossier, synthese, config={"llm_enabled": False}) + elapsed_dp = time.monotonic() - t_dp_norag + logger.info("⏱ [DP-SELECT] %.1fs — sélection DP (sans RAG)", elapsed_dp) dossier.dp_selection = selection if selection.chosen_code: current_code = ( @@ -240,7 +272,10 @@ def extract_medical_info( # Post-processing : validation justifications (QC batch) if use_rag: + t_qc = time.monotonic() _validate_justifications(dossier) + elapsed_qc = time.monotonic() - t_qc + logger.info("⏱ [QC-VALIDATION] %.1fs — validation justifications terminée", elapsed_qc) # Post-processing : traçabilité source (page + extrait) if page_tracker: @@ -457,7 +492,7 @@ def _extract_das_llm(text: str, dossier: DossierMedical) -> None: def _enrich_with_rag(dossier: DossierMedical) -> None: - """Enrichit les diagnostics via le RAG (FAISS + Ollama).""" + """Enrichit les diagnostics via le RAG (FAISS + Ollama) — wrapper rétro-compatible.""" try: from .rag_search import enrich_dossier enrich_dossier(dossier) @@ -471,6 +506,34 @@ def _enrich_with_rag(dossier: DossierMedical) -> None: dossier.alertes_codage.append("QUALITE DEGRADEE : erreur RAG — codage sans référentiels") +def _enrich_dp_only(dossier: DossierMedical) -> None: + """Enrichit SEULEMENT le DP via le RAG (Phase 1, parallélisable).""" + try: + from .rag_search import enrich_dp + enrich_dp(dossier) + except ImportError: + logger.error("RAG INDISPONIBLE : faiss-cpu ou sentence-transformers manquant") + dossier.quality_flags["rag_status"] = "unavailable" + dossier.alertes_codage.append("QUALITE DEGRADEE : RAG indisponible — codage sans référentiels") + except Exception: + logger.error("RAG EN ERREUR : enrichissement DP échoué", exc_info=True) + dossier.quality_flags["rag_status"] = "error" + dossier.alertes_codage.append("QUALITE DEGRADEE : erreur RAG DP — codage sans référentiels") + + +def _enrich_das_and_actes(dossier: DossierMedical) -> None: + """Enrichit les DAS et actes CCAM via le RAG (Phase 2, après DAS LLM).""" + try: + from .rag_search import enrich_das_and_actes + enrich_das_and_actes(dossier) + except ImportError: + logger.error("RAG INDISPONIBLE : faiss-cpu ou sentence-transformers manquant") + dossier.quality_flags["rag_status"] = "unavailable" + except Exception: + logger.error("RAG EN ERREUR : enrichissement DAS/actes échoué", exc_info=True) + dossier.quality_flags["rag_status"] = "error" + + def _extract_sejour(parsed: dict, dossier: DossierMedical) -> None: """Extrait les informations de séjour.""" patient = parsed.get("patient", {}) diff --git a/src/medical/rag_search.py b/src/medical/rag_search.py index 3617f9c..ffa7600 100644 --- a/src/medical/rag_search.py +++ b/src/medical/rag_search.py @@ -5,6 +5,7 @@ from __future__ import annotations import logging import os import threading +import time from concurrent.futures import ThreadPoolExecutor, as_completed from ..config import ( @@ -610,7 +611,9 @@ def enrich_diagnostic( Modifie le diagnostic en place. Fallback gracieux si FAISS ou Ollama échouent. """ + t0 = time.monotonic() diag_type = "dp" if est_dp else "das" + label = f"RAG-{diag_type.upper()}" # 1. Vérifier le cache cached = cache.get(diagnostic.texte, diag_type) if cache else None @@ -626,6 +629,8 @@ def enrich_diagnostic( if cached is not None: logger.info("Cache hit (sans sources FAISS) pour %s : « %s »", diag_type.upper(), diagnostic.texte) _apply_llm_result_diagnostic(diagnostic, cached) + elapsed = time.monotonic() - t0 + logger.info("⏱ [%s] %.1fs — %s (no FAISS)", label, elapsed, diagnostic.texte[:60]) return # 3. Stocker les sources RAG @@ -643,11 +648,15 @@ def enrich_diagnostic( if cached is not None: logger.info("Cache hit pour %s : « %s »", diag_type.upper(), diagnostic.texte) _apply_llm_result_diagnostic(diagnostic, cached) + elapsed = time.monotonic() - t0 + logger.info("⏱ [%s] %.1fs — %s (cache hit)", label, elapsed, diagnostic.texte[:60]) return # 5. Appel Ollama pour justification avec raisonnement structuré prompt = _build_prompt(diagnostic.texte, sources, contexte, est_dp=est_dp) + t_llm = time.monotonic() llm_result = _call_ollama(prompt) + llm_elapsed = time.monotonic() - t_llm if llm_result: _apply_llm_result_diagnostic(diagnostic, llm_result) @@ -656,6 +665,9 @@ def enrich_diagnostic( else: logger.info("Ollama non disponible — sources FAISS conservées sans justification LLM") + elapsed = time.monotonic() - t0 + logger.info("⏱ [%s] %.1fs (LLM %.1fs) — %s", label, elapsed, llm_elapsed, diagnostic.texte[:60]) + def _apply_llm_result_acte(acte: ActeCCAM, llm_result: dict) -> None: """Applique un résultat LLM (frais ou caché) à un ActeCCAM.""" @@ -687,6 +699,8 @@ def enrich_acte(acte: ActeCCAM, contexte: dict, cache: OllamaCache | None = None Modifie l'acte en place. Fallback gracieux si FAISS ou Ollama échouent. """ + t0 = time.monotonic() + # 1. Vérifier le cache cached = cache.get(acte.texte, "ccam") if cache else None @@ -712,11 +726,15 @@ def enrich_acte(acte: ActeCCAM, contexte: dict, cache: OllamaCache | None = None if cached is not None: logger.info("Cache hit pour CCAM : « %s »", acte.texte) _apply_llm_result_acte(acte, cached) + elapsed = time.monotonic() - t0 + logger.info("⏱ [RAG-CCAM] %.1fs — %s (cache hit)", elapsed, acte.texte[:60]) return # 5. Appel Ollama pour justification avec raisonnement structuré prompt = _build_prompt_ccam(acte.texte, sources, contexte) + t_llm = time.monotonic() llm_result = _call_ollama(prompt) + llm_elapsed = time.monotonic() - t_llm if llm_result: _apply_llm_result_acte(acte, llm_result) @@ -725,6 +743,9 @@ def enrich_acte(acte: ActeCCAM, contexte: dict, cache: OllamaCache | None = None else: logger.info("Ollama non disponible — sources FAISS CCAM conservées sans justification LLM") + elapsed = time.monotonic() - t0 + logger.info("⏱ [RAG-CCAM] %.1fs (LLM %.1fs) — %s", elapsed, llm_elapsed, acte.texte[:60]) + def _smart_truncate(text: str, max_chars: int = 6000) -> str: """Troncature intelligente : garde le début + les sections finales importantes. @@ -790,6 +811,8 @@ def extract_das_llm( """ import hashlib + t0 = time.monotonic() + # Clé de cache basée sur le hash du texte text_hash = hashlib.md5(text[:4000].encode()).hexdigest()[:16] cache_key_text = f"das_extract::{text_hash}" @@ -798,15 +821,19 @@ def extract_das_llm( if cache is not None: cached = cache.get(cache_key_text, "das_llm") if cached is not None: - logger.info("Cache hit pour extraction DAS LLM") + elapsed = time.monotonic() - t0 + logger.info("⏱ [DAS-LLM] %.1fs — cache hit", elapsed) return cached.get("diagnostics_supplementaires", []) # Construire le prompt et appeler Ollama prompt = _build_prompt_das_extraction(text, contexte, existing_das, dp_texte) + t_llm = time.monotonic() result = call_ollama(prompt, temperature=0.1, max_tokens=2000, role="coding") + llm_elapsed = time.monotonic() - t_llm if result is None: - logger.warning("Extraction DAS LLM : Ollama non disponible") + elapsed = time.monotonic() - t0 + logger.warning("⏱ [DAS-LLM] %.1fs — Ollama non disponible", elapsed) return [] das_list = result.get("diagnostics_supplementaires", []) @@ -818,25 +845,52 @@ def extract_das_llm( if cache is not None: cache.put(cache_key_text, "das_llm", result) - logger.info("Extraction DAS LLM : %d diagnostics supplémentaires détectés", len(das_list)) + elapsed = time.monotonic() - t0 + logger.info("⏱ [DAS-LLM] %.1fs (LLM %.1fs) — %d diagnostics détectés", elapsed, llm_elapsed, len(das_list)) return das_list -def enrich_dossier(dossier: DossierMedical) -> None: - """Enrichit le DP et tous les DAS d'un dossier via le RAG. +def enrich_dp(dossier: DossierMedical, cache: OllamaCache | None = None) -> None: + """Enrichit SEULEMENT le DP d'un dossier via le RAG (Phase 1). - Utilise un cache persistant et parallélise les appels Ollama - pour les DAS et actes CCAM (max_workers = OLLAMA_MAX_PARALLEL). + Peut être exécuté en parallèle avec d'autres tâches (DAS LLM, DP selector) + car il ne dépend que du DP existant. + + Args: + dossier: Dossier médical à enrichir (modifié en place). + cache: Cache Ollama optionnel (créé si None). """ - cache = OllamaCache(OLLAMA_CACHE_PATH, get_model("coding")) + t0 = time.monotonic() + if cache is None: + cache = OllamaCache(OLLAMA_CACHE_PATH, get_model("coding")) contexte = build_enriched_context(dossier) - # Phase 1 : DP seul (le contexte DAS en dépend) if dossier.diagnostic_principal: logger.info("RAG enrichissement DP : %s", dossier.diagnostic_principal.texte) enrich_diagnostic(dossier.diagnostic_principal, contexte, est_dp=True, cache=cache) + cache.save() + elapsed = time.monotonic() - t0 + logger.info("⏱ [RAG-DP-PHASE] %.1fs — enrichissement DP terminé", elapsed) + + +def enrich_das_and_actes(dossier: DossierMedical, cache: OllamaCache | None = None) -> None: + """Enrichit les DAS et actes CCAM d'un dossier via le RAG (Phase 2). + + Parallélise les appels Ollama (max_workers = OLLAMA_MAX_PARALLEL). + Doit être appelé APRES enrich_dp() et après l'ajout des DAS LLM. + + Args: + dossier: Dossier médical à enrichir (modifié en place). + cache: Cache Ollama optionnel (créé si None). + """ + t0 = time.monotonic() + if cache is None: + cache = OllamaCache(OLLAMA_CACHE_PATH, get_model("coding")) + + contexte = build_enriched_context(dossier) + # Mettre à jour le contexte avec le DP pour les DAS if dossier.diagnostic_principal: contexte["dp_texte"] = dossier.diagnostic_principal.texte @@ -846,7 +900,6 @@ def enrich_dossier(dossier: DossierMedical) -> None: if d.cim10_suggestion ] - # Phase 2 : DAS + Actes en parallèle das_list = dossier.diagnostics_associes actes_list = dossier.actes_ccam @@ -863,3 +916,18 @@ def enrich_dossier(dossier: DossierMedical) -> None: f.result() # propage les exceptions cache.save() + elapsed = time.monotonic() - t0 + n_das = len(das_list) if das_list else 0 + n_actes = len(actes_list) if actes_list else 0 + logger.info("⏱ [RAG-DAS-PHASE] %.1fs — %d DAS + %d actes enrichis", elapsed, n_das, n_actes) + + +def enrich_dossier(dossier: DossierMedical) -> None: + """Enrichit le DP et tous les DAS d'un dossier via le RAG. + + Wrapper rétro-compatible qui appelle enrich_dp() puis enrich_das_and_actes(). + Utilise un cache persistant partagé entre les deux phases. + """ + cache = OllamaCache(OLLAMA_CACHE_PATH, get_model("coding")) + enrich_dp(dossier, cache=cache) + enrich_das_and_actes(dossier, cache=cache) diff --git a/tests/test_rag.py b/tests/test_rag.py index b4b70d9..3022d32 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -116,7 +116,7 @@ class TestExtractMedicalInfoRAGFlag: assert dossier.diagnostic_principal.justification is None def test_use_rag_true_calls_enrich(self): - """use_rag=True appelle _enrich_with_rag (mocké).""" + """use_rag=True appelle _enrich_dp_only et _enrich_das_and_actes (mockés).""" from src.medical.cim10_extractor import extract_medical_info parsed = { @@ -127,9 +127,13 @@ class TestExtractMedicalInfoRAGFlag: } text = "Pancréatite aiguë biliaire.\nTTT de sortie :\nParacétamol\n\nDevenir : retour." - with patch("src.medical.cim10_extractor._enrich_with_rag") as mock_enrich: + with patch("src.medical.cim10_extractor._enrich_dp_only") as mock_dp, \ + patch("src.medical.cim10_extractor._enrich_das_and_actes") as mock_das, \ + patch("src.medical.cim10_extractor._extract_das_llm"), \ + patch("src.medical.cim10_extractor._validate_justifications"): dossier = extract_medical_info(parsed, text, use_rag=True) - mock_enrich.assert_called_once_with(dossier) + mock_dp.assert_called_once_with(dossier) + mock_das.assert_called_once_with(dossier) def test_use_rag_default_false(self): """Par défaut, use_rag=False.""" @@ -143,9 +147,11 @@ class TestExtractMedicalInfoRAGFlag: } text = "Test simple." - with patch("src.medical.cim10_extractor._enrich_with_rag") as mock_enrich: + with patch("src.medical.cim10_extractor._enrich_dp_only") as mock_dp, \ + patch("src.medical.cim10_extractor._enrich_das_and_actes") as mock_das: extract_medical_info(parsed, text) - mock_enrich.assert_not_called() + mock_dp.assert_not_called() + mock_das.assert_not_called() class TestChunkingCIM10: