Compare commits

..

9 Commits

Author SHA1 Message Date
dom
1d4a0e1128 Merge branch 'feat/speed-optimizations' — batch QC + pipeline CPU/GPU + QC Dell 2026-03-08 15:45:20 +01:00
dom
3a6e008269 feat: 3 optimisations vitesse — batch QC, pipeline CPU/GPU, QC sur Dell
1. Batch QC : la validation justifications (gemma3:12b) est différée et
   exécutée en une seule passe après tous les documents du groupe.
   Réduit les swaps modèle de 2 par document à 2 pour le groupe entier.

2. Pipeline CPU/GPU : prepare_document (extraction, anonymisation, edsnlp)
   tourne sur CPU pendant que le GPU traite le document précédent via
   un ThreadPoolExecutor(1) en prefetch.

3. QC sur Dell : si la machine secondaire (192.168.1.11) est disponible,
   les appels QC sont envoyés là-bas en parallèle du codage principal,
   éliminant tout swap modèle sur le GPU principal.

Refactoring associé :
- _postprocess_dossier() extrait la logique vetos/décisions/GHM/finalizer
- call_ollama() accepte ollama_url pour cibler un serveur spécifique
- _is_secondary_available() avec cache 60s pour éviter le polling

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-08 15:45:17 +01:00
dom
2fb7b46a7c Merge branch 'feat/parallel-rag-pipeline' — optimisation flux LLM 2026-03-08 14:35:54 +01:00
dom
b0aa83f664 fix: adapter tests RAG à la nouvelle parallélisation (enrich_dp_only + enrich_das_and_actes)
Les agents d'optimisation ont splitté _enrich_with_rag en _enrich_dp_only
et _enrich_das_and_actes mais n'ont pas mis à jour les mocks dans test_rag.py.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-08 14:35:39 +01:00
dom
f94d8496cb feat: monter OLLAMA_MAX_PARALLEL défaut à 4
Le défaut de 2 était sous-optimal pour la RTX 5070 (12 Go VRAM).
Ollama gère la concurrence interne et queue les requêtes
excédentaires. Un pool de 4 workers Python permet de mieux
saturer le GPU sur les appels DAS/actes parallèles.

Le .env peut toujours override cette valeur via OLLAMA_MAX_PARALLEL.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-08 14:17:24 +01:00
dom
355a33acde feat: paralléliser DAS_LLM + RAG DP + DP selector
Restructure le pipeline dans extract_medical_info() :

AVANT : DAS_LLM séquentiel → ThreadPool(RAG complet + DP_selector)

APRÈS :
  Groupe 1 (ThreadPool max_workers=3) :
    - DAS_LLM : extraction DAS supplémentaires
    - RAG DP : enrichissement DP seul (via enrich_dp)
    - DP selector : sélection NUKE-3
  Groupe 2 :
    - enrichissement DAS + actes (via enrich_das_and_actes)

Le RAG DP ne dépend pas du DAS_LLM, donc les deux peuvent
s'exécuter en parallèle. Le Groupe 2 attend le DAS_LLM car
il enrichit les DAS trouvés par celui-ci.

Ajoute aussi des timings sur les groupes et la validation QC.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-08 14:17:18 +01:00
dom
f5a6122495 feat: timings granulaires appels LLM (RAG, DAS, DP, QC)
Ajoute des mesures time.monotonic() autour de chaque appel LLM dans
rag_search.py : enrich_diagnostic, enrich_acte, extract_das_llm.
Format de log : logger.info("⏱ [RAG-DP] 14.2s — texte")

Découpe enrich_dossier() en 2 fonctions exportées :
- enrich_dp() : enrichit seulement le DP (parallélisable)
- enrich_das_and_actes() : enrichit DAS + actes en parallèle
L'ancienne enrich_dossier() reste comme wrapper rétro-compatible.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-08 14:16:38 +01:00
dom
0cafb47e8d feat: batching CPAM par modèle pour réduire les swaps VRAM
Restructure le flux CPAM pour regrouper les appels LLM par modèle :
- Phase 1 (cpam model) : extraction + argumentation de TOUS les contrôles
- Phase 2 (validation model, 1 swap) : validation adversariale de tous
- Phase 3 (cpam model, si correction nécessaire) : correction des échoués

Nouvelle fonction generate_cpam_responses_batched() qui orchestre les 3 phases.
L'ancienne generate_cpam_response() reste intacte (rétrocompatible, utilisée
par le viewer pour la régénération unitaire).

Structure intermédiaire _CpamDraft (dataclass) pour transporter l'état entre phases.
Fonctions de phase extraites : _phase_generate, _phase_validate, _phase_correct,
_phase_finalize.

Gain estimé : de ~4*N swaps de modèle à 2-4 swaps (indépendant de N contrôles).
Sur RTX 5070 avec des modèles 24-32b, chaque swap coûte ~10-15s.
Pour 3 contrôles : économie estimée ~60-90s.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-08 14:10:34 +01:00
dom
d6b4e48989 feat: timings appels LLM CPAM (génération, validation, correction)
Ajoute des mesures time.time() autour de chaque appel Ollama dans le flux CPAM :
- [CPAM-EXTRACT] : extraction structurée (passe 1, role=cpam)
- [CPAM-GEN] : génération argumentation (passe 2, role=cpam)
- [CPAM-VALID] : validation adversariale (role=validation)
- [CPAM-CORR] : correction post-validation (role=cpam)

Permet de mesurer le temps réel de chaque phase et d'identifier
les coûts de swap de modèle VRAM entre les rôles cpam/validation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-08 14:09:42 +01:00
9 changed files with 826 additions and 148 deletions

View File

@@ -57,10 +57,11 @@ NER_CONFIDENCE_THRESHOLD = float(os.environ.get("T2A_NER_THRESHOLD", "0.80"))
# --- Configuration Ollama ---
OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://localhost:11434")
OLLAMA_URL_SECONDARY = os.environ.get("OLLAMA_URL_SECONDARY", "http://192.168.1.11:11434")
OLLAMA_MODEL = os.environ.get("OLLAMA_MODEL", "gemma3:27b")
OLLAMA_TIMEOUT = int(os.environ.get("OLLAMA_TIMEOUT", "600"))
OLLAMA_CACHE_PATH = BASE_DIR / "data" / "ollama_cache.json"
OLLAMA_MAX_PARALLEL = int(os.environ.get("OLLAMA_MAX_PARALLEL", "2"))
OLLAMA_MAX_PARALLEL = int(os.environ.get("OLLAMA_MAX_PARALLEL", "4"))
# --- Modèles par rôle LLM ---
@@ -203,7 +204,7 @@ EMBEDDING_MODEL = os.environ.get("T2A_EMBEDDING_MODEL", "dangvantuan/sentence-ca
# --- Modèle de re-ranking (cross-encoder, CPU uniquement) ---
RERANKER_MODEL = os.environ.get("T2A_RERANKER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")
RERANKER_MODEL = os.environ.get("T2A_RERANKER_MODEL", "ncbi/MedCPT-Cross-Encoder")
# --- Références biologiques (fallback) ---

View File

@@ -4,6 +4,14 @@ Orchestrateur principal — délègue aux sous-modules :
- cpam_rag : _search_rag_for_control(), _search_rag_queries()
- cpam_context : _build_cpam_prompt(), _build_tagged_context(), _build_bio_summary(), etc.
- cpam_validation : _validate_adversarial(), _validate_grounding(), _format_response(), etc.
Le flux batché (generate_cpam_responses_batched) regroupe les appels LLM par modèle
pour minimiser les swaps VRAM sur GPU unique :
Phase 1 (cpam model) : extraction + argumentation de TOUS les contrôles
Phase 2 (validation model, 1 swap) : validation adversariale de tous
Phase 3 (cpam model, 1 swap) : correction des contrôles échoués
Phase 4 (validation model, 1 swap) : re-validation des corrigés
Gain : de ~4N swaps à ~2-4 swaps (indépendant du nombre de contrôles).
"""
from __future__ import annotations
@@ -11,6 +19,8 @@ from __future__ import annotations
import json
import logging
import os
import time
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
@@ -41,6 +51,34 @@ from .cpam_validation import (
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Structure intermédiaire pour le batching
# ---------------------------------------------------------------------------
@dataclass
class _CpamDraft:
"""État intermédiaire d'un contrôle CPAM en cours de traitement batché."""
controle: ControleCPAM
# Résultats des phases successives
extraction: dict | None = None
degraded_pass1: bool = False
sources: list[dict] = field(default_factory=list)
rag_sources: list[RAGSource] = field(default_factory=list)
prompt: str = ""
tag_map: dict[str, str] = field(default_factory=dict)
result: dict | None = None
# Validation
ref_warnings: list[str] = field(default_factory=list)
grounding_warnings: list[str] = field(default_factory=list)
code_warnings: list[str] = field(default_factory=list)
adversarial_warnings: list[str] = field(default_factory=list)
validation: dict | None = None
# Résultat final
needs_correction: bool = False
text: str = ""
response_data: dict | None = None
def _save_version(
dossier: DossierMedical,
controle: ControleCPAM,
@@ -149,14 +187,18 @@ def _extraction_pass(
)
logger.debug(" Passe 1 — extraction structurée")
t0 = time.time()
result = call_ollama(prompt, temperature=0.0, max_tokens=3000, role="cpam")
if result is None:
result = call_anthropic(prompt, temperature=0.0, max_tokens=3000)
elapsed = time.time() - t0
if result is not None:
logger.info(" Passe 1 OK : %d éléments cliniques extraits",
logger.info(" [CPAM-EXTRACT] %.1fs — OGC %s %d éléments cliniques extraits",
elapsed, controle.numero_ogc,
len(result.get("elements_cliniques_pertinents", [])))
else:
logger.warning(" Passe 1 échouée — fallback single-pass")
logger.warning(" [CPAM-EXTRACT] %.1fs — OGC %s — passe 1 échouée",
elapsed, controle.numero_ogc)
return result
@@ -195,14 +237,17 @@ def generate_cpam_response(
prompt, tag_map = _build_cpam_prompt(dossier, controle, sources, extraction)
# 4. Appel LLM — Ollama (rôle cpam) > Haiku fallback
t_gen = time.time()
result = call_ollama(prompt, temperature=0.1, max_tokens=8000, role="cpam")
if result is not None:
logger.info(" Contre-argumentation via Ollama")
logger.info(" [CPAM-GEN] %.1fs — OGC %s — contre-argumentation via Ollama",
time.time() - t_gen, controle.numero_ogc)
else:
logger.info(" Ollama indisponible → fallback Anthropic Haiku")
result = call_anthropic(prompt, temperature=0.1, max_tokens=8000)
if result is not None:
logger.info(" Contre-argumentation via Anthropic Haiku")
logger.info(" [CPAM-GEN] %.1fs — OGC %s — contre-argumentation via Anthropic",
time.time() - t_gen, controle.numero_ogc)
# 5. Conversion des sources RAG
rag_sources = [
@@ -285,9 +330,12 @@ def generate_cpam_response(
validation.get("score_confiance"), attempt + 1, max_corrections, len(erreurs_v))
correction_prompt = _build_correction_prompt(prompt, result, validation)
t_corr = time.time()
corrected = call_ollama(correction_prompt, temperature=0.0, max_tokens=16000, role="cpam")
if corrected is None:
corrected = call_anthropic(correction_prompt, temperature=0.0, max_tokens=16000)
logger.info(" [CPAM-CORR] %.1fs — OGC %s — correction %d/%d",
time.time() - t_corr, controle.numero_ogc, attempt + 1, max_corrections)
if not corrected:
break
@@ -347,3 +395,343 @@ def generate_cpam_response(
logger.info(" Contre-argumentation générée (%d caractères)", len(text))
return text, result, rag_sources
# ---------------------------------------------------------------------------
# Fonctions de phase (utilisées par le batching)
# ---------------------------------------------------------------------------
def _phase_generate(
dossier: DossierMedical,
draft: _CpamDraft,
) -> None:
"""Phase 1 — Extraction + RAG + argumentation (role=cpam).
Remplit draft.extraction, draft.sources, draft.prompt, draft.tag_map, draft.result.
Ne fait aucun appel au modèle de validation.
"""
controle = draft.controle
logger.info("CPAM batch : génération pour OGC %d%s",
controle.numero_ogc, controle.titre)
# 0. Versioning
_save_version(dossier, controle)
# 1. Extraction structurée (role=cpam)
draft.extraction = _extraction_pass(dossier, controle)
draft.degraded_pass1 = draft.extraction is None
if draft.degraded_pass1:
dossier.alertes_codage.append(
"CPAM: passe 1 (extraction structurée) échouée → mode dégradé"
)
# 2. Recherche RAG (pas de LLM)
draft.sources = _search_rag_for_control(controle, dossier)
logger.info(" RAG : %d sources trouvées", len(draft.sources))
# 3. Construction du prompt
draft.prompt, draft.tag_map = _build_cpam_prompt(
dossier, controle, draft.sources, draft.extraction
)
# 4. Argumentation (role=cpam)
t_gen = time.time()
result = call_ollama(draft.prompt, temperature=0.1, max_tokens=8000, role="cpam")
if result is not None:
logger.info(" [CPAM-GEN] %.1fs — OGC %s — contre-argumentation via Ollama",
time.time() - t_gen, controle.numero_ogc)
else:
logger.info(" Ollama indisponible → fallback Anthropic Haiku")
result = call_anthropic(draft.prompt, temperature=0.1, max_tokens=8000)
if result is not None:
logger.info(" [CPAM-GEN] %.1fs — OGC %s — contre-argumentation via Anthropic",
time.time() - t_gen, controle.numero_ogc)
# 5. Conversion sources RAG
draft.rag_sources = [
RAGSource(
document=s.get("document", ""),
page=s.get("page"),
code=s.get("code"),
extrait=s.get("extrait", "")[:200],
)
for s in draft.sources
]
if result is None:
logger.warning(" LLM non disponible — contre-argumentation non générée")
draft.result = None
return
# 5b. Mode dégradé
if draft.degraded_pass1:
result.setdefault("quality_flags", {})
result["quality_flags"]["cpam_pass1_failed"] = True
result["quality_flags"]["degraded_mode"] = True
# 6. Sanitisation + gardien déterministes (pas de LLM)
sanitized = _sanitize_unauthorized_codes(result, dossier, controle)
if sanitized:
logger.info(" CPAM : %d code(s) hors périmètre supprimé(s) : %s",
len(sanitized), ", ".join(sanitized))
result = _guardian_deterministic(result, dossier, controle, draft.tag_map)
# 7. Validations déterministes (pas de LLM)
draft.ref_warnings = _validate_references(result, draft.sources)
if draft.ref_warnings:
logger.warning(" CPAM : %d référence(s) non vérifiable(s)", len(draft.ref_warnings))
draft.grounding_warnings = _validate_grounding(result, draft.tag_map)
if draft.grounding_warnings:
logger.warning(" CPAM : %d preuve(s) non traçable(s)", len(draft.grounding_warnings))
draft.code_warnings = _validate_codes_in_response(result, dossier, controle)
if draft.code_warnings:
logger.warning(" CPAM : %d code(s) hors périmètre", len(draft.code_warnings))
draft.result = result
def _phase_validate(
dossier: DossierMedical,
draft: _CpamDraft,
) -> None:
"""Phase 2 — Validation adversariale (role=validation).
Remplit draft.validation et draft.needs_correction.
"""
if draft.result is None:
return
controle = draft.controle
# LOGIC-3 : modèles identiques ?
from ..config import check_adversarial_model_config
same_model, model_msg = check_adversarial_model_config()
if same_model:
draft.result.setdefault("quality_flags", {})
draft.result["quality_flags"]["adversarial_disabled_same_model"] = True
dossier.alertes_codage.append(
"Validation adversariale désactivée (modèles identiques)"
)
draft.validation = _validate_adversarial(draft.result, draft.tag_map, controle)
if draft.validation and not draft.validation.get("coherent", True):
erreurs = draft.validation.get("erreurs", [])
score = draft.validation.get("score_confiance", "?")
for e in erreurs:
if isinstance(e, str) and e.strip():
draft.adversarial_warnings.append(f"Incohérence détectée : {e}")
if draft.adversarial_warnings:
draft.adversarial_warnings.append(f"Score de confiance : {score}/10")
# Déterminer si une correction est nécessaire
max_corrections = int(os.environ.get("T2A_CPAM_MAX_CORRECTIONS", "2"))
draft.needs_correction = bool(
max_corrections > 0
and draft.validation
and not draft.validation.get("coherent", True)
and draft.validation.get("score_confiance", 10) <= 5
and rule_enabled("RULE-CPAM-CORRECTION-LOOP")
)
def _phase_correct(
dossier: DossierMedical,
draft: _CpamDraft,
) -> None:
"""Phase 3 — Correction (role=cpam) + re-validation (role=validation).
Boucle correction/re-validation intégrée car les deux sont couplées.
Seuls les drafts avec needs_correction=True entrent ici.
"""
if draft.result is None or not draft.needs_correction:
return
controle = draft.controle
max_corrections = int(os.environ.get("T2A_CPAM_MAX_CORRECTIONS", "2"))
for attempt in range(max_corrections):
if not (draft.validation
and not draft.validation.get("coherent", True)
and draft.validation.get("score_confiance", 10) <= 5
and rule_enabled("RULE-CPAM-CORRECTION-LOOP")):
break
erreurs_v = draft.validation.get("erreurs", [])
logger.warning(" Score adversarial %s/10 — correction %d/%d (%d erreur(s))",
draft.validation.get("score_confiance"),
attempt + 1, max_corrections, len(erreurs_v))
correction_prompt = _build_correction_prompt(
draft.prompt, draft.result, draft.validation
)
t_corr = time.time()
corrected = call_ollama(
correction_prompt, temperature=0.0, max_tokens=16000, role="cpam"
)
if corrected is None:
corrected = call_anthropic(
correction_prompt, temperature=0.0, max_tokens=16000
)
logger.info(" [CPAM-CORR] %.1fs — OGC %s — correction %d/%d",
time.time() - t_corr, controle.numero_ogc,
attempt + 1, max_corrections)
if not corrected:
break
validation2 = _validate_adversarial(corrected, draft.tag_map, controle)
score2 = validation2.get("score_confiance", 0) if validation2 else 0
score1 = draft.validation.get("score_confiance", 0)
if score2 > score1:
logger.info(" Correction %d acceptée (score %s%s)",
attempt + 1, score1, score2)
draft.result = corrected
draft.validation = validation2
_sanitize_unauthorized_codes(draft.result, dossier, controle)
draft.result = _guardian_deterministic(
draft.result, dossier, controle, draft.tag_map
)
draft.ref_warnings = _validate_references(draft.result, draft.sources)
draft.grounding_warnings = _validate_grounding(draft.result, draft.tag_map)
draft.code_warnings = _validate_codes_in_response(
draft.result, dossier, controle
)
draft.adversarial_warnings = []
if draft.validation and not draft.validation.get("coherent", True):
for e in draft.validation.get("erreurs", []):
if isinstance(e, str) and e.strip():
draft.adversarial_warnings.append(f"Incohérence détectée : {e}")
if draft.adversarial_warnings:
draft.adversarial_warnings.append(
f"Score de confiance : {draft.validation.get('score_confiance', '?')}/10"
)
else:
logger.warning(" Correction %d rejetée (score %s%s)",
attempt + 1, score1, score2)
break
def _phase_finalize(
dossier: DossierMedical,
draft: _CpamDraft,
) -> tuple[str, dict | None, list[RAGSource]]:
"""Phase finale — Qualité + formatage (déterministe, pas de LLM).
Returns:
Même signature que generate_cpam_response().
"""
controle = draft.controle
if draft.result is None:
return "", None, draft.rag_sources
all_warnings = (
draft.ref_warnings + draft.grounding_warnings
+ draft.code_warnings + draft.adversarial_warnings
)
strength = _assess_dossier_strength(dossier)
if strength["is_weak"]:
logger.info(" Dossier à preuves limitées (score %d/10) : %s",
strength["score"], ", ".join(strength["missing"]))
tier, needs_review, cat_warnings = _assess_quality_tier(
draft.result, draft.ref_warnings, draft.grounding_warnings,
draft.code_warnings, draft.validation,
is_weak_dossier=strength["is_weak"],
)
controle.quality_tier = tier
controle.requires_review = needs_review
controle.quality_warnings = cat_warnings
logger.info(" Qualité CPAM : tier %s, requires_review=%s, %d warnings",
tier, needs_review, len(cat_warnings))
text = _format_response(
draft.result,
ref_warnings=all_warnings,
quality_tier=tier,
categorized_warnings=cat_warnings,
)
logger.info(" Contre-argumentation générée (%d caractères)", len(text))
return text, draft.result, draft.rag_sources
# ---------------------------------------------------------------------------
# Orchestrateur batché — minimise les swaps de modèle VRAM
# ---------------------------------------------------------------------------
def generate_cpam_responses_batched(
dossier: DossierMedical,
controles: list[ControleCPAM],
) -> None:
"""Génère les contre-argumentations pour TOUS les contrôles CPAM en batch.
Regroupe les appels LLM par modèle pour minimiser les swaps VRAM :
Phase 1 (cpam model) : extraction + argumentation de tous les contrôles
Phase 2 (validation model) : validation adversariale de tous
Phase 3 (cpam model) : correction des contrôles échoués (si nécessaire)
Finalisation : qualité + formatage (déterministe)
Gain : de ~4*N swaps de modèle à 2-4 swaps, indépendant de N.
Args:
dossier: Le dossier médical (fusionné ou unique).
controles: Liste des contrôles CPAM à traiter.
Side effects:
Remplit controle.contre_argumentation, controle.response_data,
controle.sources_reponse pour chaque contrôle.
"""
if not controles:
return
t_total = time.time()
n = len(controles)
logger.info("CPAM batch : %d contrôle(s) — début du traitement batché", n)
# Créer les drafts
drafts = [_CpamDraft(controle=ctrl) for ctrl in controles]
# --- Phase 1 : génération (cpam model) ---
t_phase = time.time()
logger.info("CPAM batch phase 1/3 : génération (%d contrôles, role=cpam)", n)
for draft in drafts:
_phase_generate(dossier, draft)
logger.info("CPAM batch phase 1/3 terminée : %.1fs", time.time() - t_phase)
# --- Phase 2 : validation adversariale (validation model — 1 swap) ---
drafts_to_validate = [d for d in drafts if d.result is not None]
if drafts_to_validate:
t_phase = time.time()
logger.info("CPAM batch phase 2/3 : validation (%d contrôles, role=validation)",
len(drafts_to_validate))
for draft in drafts_to_validate:
_phase_validate(dossier, draft)
logger.info("CPAM batch phase 2/3 terminée : %.1fs", time.time() - t_phase)
# --- Phase 3 : correction + re-validation (cpam + validation models) ---
drafts_to_correct = [d for d in drafts if d.needs_correction]
if drafts_to_correct:
t_phase = time.time()
logger.info("CPAM batch phase 3/3 : correction (%d contrôles nécessitent correction)",
len(drafts_to_correct))
for draft in drafts_to_correct:
_phase_correct(dossier, draft)
logger.info("CPAM batch phase 3/3 terminée : %.1fs", time.time() - t_phase)
# --- Finalisation (déterministe, pas de swap) ---
for draft in drafts:
text, response_data, sources = _phase_finalize(dossier, draft)
draft.controle.contre_argumentation = text
draft.controle.response_data = response_data
draft.controle.sources_reponse = sources
elapsed_total = time.time() - t_total
logger.info(
"CPAM batch terminé : %d contrôle(s) en %.1fs (%.1fs/contrôle)",
n, elapsed_total, elapsed_total / n if n else 0,
)

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import logging
import re
import time
from ..config import ControleCPAM, DossierMedical
from ..medical.bio_normals import BIO_NORMALS
@@ -477,11 +478,14 @@ def _validate_adversarial(
)
logger.debug(" Validation adversariale")
t_val = time.time()
result = call_ollama(prompt, temperature=0.0, max_tokens=6000, role="validation")
if result is None:
result = call_anthropic(prompt, temperature=0.0, max_tokens=6000)
elapsed = time.time() - t_val
if result is None:
logger.warning(" Validation adversariale échouée — LLM indisponible")
logger.warning(" [CPAM-VALID] %.1fs — OGC %s — validation adversariale échouée",
elapsed, controle.numero_ogc)
return None
coherent = result.get("coherent", True)
@@ -489,12 +493,13 @@ def _validate_adversarial(
score = result.get("score_confiance", -1)
if not coherent and erreurs:
logger.warning(" Validation adversariale : %d incohérence(s) détectée(s) (score %s/10)",
len(erreurs), score)
logger.warning(" [CPAM-VALID] %.1fs — OGC %s %d incohérence(s) (score %s/10)",
elapsed, controle.numero_ogc, len(erreurs), score)
for e in erreurs:
logger.warning(" - %s", e)
else:
logger.info(" Validation adversariale OK (score %s/10)", score)
logger.info(" [CPAM-VALID] %.1fs — OGC %s OK (score %s/10)",
elapsed, controle.numero_ogc, score)
return result

View File

@@ -143,11 +143,142 @@ _use_edsnlp = True
_use_rag = True
def process_document(file_path: Path) -> list[tuple[str, DossierMedical, AnonymizationReport]]:
# Type alias pour les données CPU préparées
PreparedChunk = tuple # (parsed, anonymized_text, report, edsnlp_result, chunk_text)
PreparedDoc = tuple # (file_path, doc_type, raw_text, page_tracker, extraction_stats, list[PreparedChunk])
def prepare_document(file_path: Path) -> PreparedDoc:
"""Phase CPU : extraction → splitting → parsing → anonymisation → edsnlp.
Séparé de process_document pour permettre le pipelining CPU/GPU :
pendant que le GPU traite le document N, le CPU prépare le document N+1.
"""
logger.info("Préparation de %s (CPU)", file_path.name)
# 1. Extraction texte
raw_text, page_tracker, extraction_stats = extract_document_with_pages(file_path)
logger.info(" Texte extrait : %d caractères (%d pages, format=%s)",
len(raw_text), extraction_stats.total_pages, extraction_stats.source_format)
# 2. Classification
doc_type = classify(raw_text)
logger.info(" Type de document : %s", doc_type)
# 3. Splitting
chunks = split_documents(raw_text, doc_type)
if len(chunks) > 1:
logger.info(" Découpage : %d dossiers détectés dans %s", len(chunks), file_path.name)
prepared_chunks: list[PreparedChunk] = []
for i, chunk_text in enumerate(chunks):
part_label = f" [part {i+1}/{len(chunks)}]" if len(chunks) > 1 else ""
# 4. Parsing
if doc_type == "trackare":
parsed = parse_trackare(chunk_text)
else:
parsed = parse_crh(chunk_text)
# 5. Anonymisation
anonymizer = Anonymizer(parsed_data=parsed)
anonymized_text = anonymizer.anonymize(chunk_text)
report = anonymizer.report
report.source_file = file_path.name
logger.info(" Anonymisation%s : %d remplacements (regex=%d, ner=%d, sweep=%d)",
part_label, report.total_replacements, report.regex_replacements,
report.ner_replacements, report.sweep_replacements)
# 6. edsnlp
edsnlp_result = None
if _use_edsnlp:
edsnlp_result = _run_edsnlp(anonymized_text)
prepared_chunks.append((parsed, anonymized_text, report, edsnlp_result, chunk_text))
return (file_path, doc_type, raw_text, page_tracker, extraction_stats, prepared_chunks)
def _postprocess_dossier(dossier: DossierMedical, parsed: dict, subdir: str | None = None) -> None:
"""Post-processing déterministe : vetos, décisions, GHM, complétude, finalizer DP."""
# Routage des règles
rules_token = None
try:
rules_ctx = build_rules_runtime_context(dossier)
dossier.rules_runtime = rules_ctx
rules_token = set_rules_runtime(rules_ctx)
except Exception:
logger.error(" Routage règles : erreur", exc_info=True)
dossier.quality_flags["rules_routing"] = "error"
# Vetos
veto = None
try:
veto = apply_vetos(dossier)
dossier.veto_report = veto
except Exception:
logger.error(" Vetos : erreur lors du contrôle", exc_info=True)
dossier.quality_flags["veto_engine"] = "error"
dossier.alertes_codage.append("QUALITE DEGRADEE : moteur de vetos en erreur")
# Décisions
try:
apply_decisions(dossier)
_inject_decision_alerts(dossier, scope="PDF")
if veto is not None:
_inject_veto_alerts(dossier, veto, scope="PDF")
except Exception:
logger.error(" Décisions : erreur lors du post-traitement", exc_info=True)
dossier.quality_flags["decision_engine"] = "error"
finally:
if rules_token is not None:
reset_rules_runtime(rules_token)
# Complétude
try:
dossier.completude = build_completude_checklist(dossier)
except Exception:
logger.error(" Complétude : erreur lors du contrôle", exc_info=True)
dossier.quality_flags["completude"] = "error"
# GHM + métriques
try:
metrics = _compute_metrics(dossier)
ghm = estimate_ghm(dossier)
dossier.ghm_estimation = ghm
logger.info(
" DAS : actifs=%d / total=%d (écartés=%d, removed=%d) | Actes : %d (avec code=%d)",
metrics.das_active, metrics.das_total, metrics.das_excluded,
metrics.das_removed, metrics.actes_total, metrics.actes_with_code,
)
logger.info(" GHM : CMD=%s, Type=%s, Sévérité=%d%s",
ghm.cmd or "?", ghm.type_ghm or "?", ghm.severite, ghm.ghm_approx or "?")
except Exception:
logger.error(" Erreur estimation GHM/metrics", exc_info=True)
dossier.quality_flags["ghm_estimation"] = "error"
dossier.alertes_codage.append("QUALITE DEGRADEE : estimation GHM en erreur")
# Finalizer DP
try:
from .medical.dp_finalizer import finalize_dp
finalize_dp(dossier)
except Exception:
logger.error(" Finalizer DP : erreur", exc_info=True)
dossier.quality_flags["dp_finalizer"] = "error"
def process_document(
file_path: Path,
defer_qc: bool = False,
) -> list[tuple[str, DossierMedical, AnonymizationReport]]:
"""Traite un document : extraction → splitting → parsing → anonymisation → extraction CIM-10.
Supporte PDF, images (JPEG/PNG/TIFF) et DOCX via le router d'extraction.
Args:
file_path: Chemin du document à traiter.
defer_qc: Si True, ne pas exécuter la validation QC (sera faite en batch plus tard).
Retourne une liste de (texte_anonymisé, dossier, rapport) — un par dossier détecté.
"""
t0 = time.time()
@@ -200,6 +331,7 @@ def process_document(file_path: Path) -> list[tuple[str, DossierMedical, Anonymi
dossier = extract_medical_info(
parsed, anonymized_text, edsnlp_result, use_rag=_use_rag,
page_tracker=page_tracker, raw_text=raw_text,
defer_qc=defer_qc,
)
dossier.source_file = file_path.name
dossier.document_type = doc_type
@@ -213,86 +345,8 @@ def process_document(file_path: Path) -> list[tuple[str, DossierMedical, Anonymi
if extraction_alert:
dossier.alertes_codage.append(extraction_alert)
# 8. Vetos (contestabilité) + décisions (post-traitement)
# Routage des règles (packs) : par défaut, on garde le socle vetos/decisions,
# et on active des packs additionnels selon les signaux du dossier (codes/labs/extraits).
rules_token = None
try:
rules_ctx = build_rules_runtime_context(dossier)
dossier.rules_runtime = rules_ctx
rules_token = set_rules_runtime(rules_ctx)
packs = ",".join(rules_ctx.get("enabled_packs", []))
if packs:
logger.info(" Règles%s : packs=%s", part_label, packs)
if rules_ctx.get("triggers_fired"):
logger.info(" Règles%s : triggers=%s", part_label, ",".join(rules_ctx["triggers_fired"]))
except Exception:
logger.error(" Routage règles : erreur", exc_info=True)
dossier.quality_flags["rules_routing"] = "error"
veto = None
try:
veto = apply_vetos(dossier)
dossier.veto_report = veto
except Exception:
logger.error(" Vetos : erreur lors du contrôle", exc_info=True)
dossier.quality_flags["veto_engine"] = "error"
dossier.alertes_codage.append("QUALITE DEGRADEE : moteur de vetos en erreur")
try:
apply_decisions(dossier)
_inject_decision_alerts(dossier, scope="PDF")
if veto is not None:
_inject_veto_alerts(dossier, veto, scope="PDF")
except Exception:
logger.error(" Décisions : erreur lors du post-traitement", exc_info=True)
dossier.quality_flags["decision_engine"] = "error"
finally:
if rules_token is not None:
reset_rules_runtime(rules_token)
try:
dossier.completude = build_completude_checklist(dossier)
except Exception:
logger.error(" Complétude : erreur lors du contrôle", exc_info=True)
dossier.quality_flags["completude"] = "error"
# 9. Estimation GHM (sur codes finaux) + métriques (actifs vs écartés)
try:
metrics = _compute_metrics(dossier)
ghm = estimate_ghm(dossier)
dossier.ghm_estimation = ghm
logger.info(
" DAS : actifs=%d / total=%d (écartés=%d, removed=%d, no_code=%d) | Actes : %d (avec code=%d)",
metrics.das_active,
metrics.das_total,
metrics.das_excluded,
metrics.das_removed,
metrics.das_no_code,
metrics.actes_total,
metrics.actes_with_code,
)
logger.info(
" GHM : CMD=%s, Type=%s, Sévérité=%d%s",
ghm.cmd or "?",
ghm.type_ghm or "?",
ghm.severite,
ghm.ghm_approx or "?",
)
except Exception:
logger.error(" Erreur estimation GHM/metrics", exc_info=True)
dossier.quality_flags["ghm_estimation"] = "error"
dossier.alertes_codage.append("QUALITE DEGRADEE : estimation GHM en erreur")
# 10. Finalizer DP (arbitrage Trackare vs CRH, traçabilité)
try:
from .medical.dp_finalizer import finalize_dp
finalize_dp(dossier)
except Exception:
logger.error(" Finalizer DP : erreur", exc_info=True)
dossier.quality_flags["dp_finalizer"] = "error"
# 8-10. Vetos + Décisions + GHM + Finalizer DP
_postprocess_dossier(dossier, parsed)
dossier.processing_time_s = round(time.time() - t0, 2)
results.append((anonymized_text, dossier, report))
@@ -569,22 +623,82 @@ def main(input_path: str | None = None) -> None:
logger.info("Traitement de %d document(s)...", total)
def _process_group(docs: list[Path], subdir: str | None) -> None:
"""Traite un groupe de documents (un dossier patient)."""
"""Traite un groupe de documents (un dossier patient).
Optimisations appliquées :
- defer_qc=True : la validation QC est différée et batchée après tous les documents
du groupe (1 swap modèle au lieu de 2 par document)
- Si la machine secondaire Dell est disponible, le QC est envoyé là-bas
(0 swap sur le GPU principal)
"""
if subdir:
logger.info("--- Dossier %s (%d documents) ---", subdir, len(docs))
group_dossiers: list[DossierMedical] = []
for doc_path in docs:
try:
doc_results = process_document(doc_path)
stem = doc_path.stem.replace(" ", "_")
multi = len(doc_results) > 1
for part_idx, (anonymized_text, dossier, report) in enumerate(doc_results):
part_stem = f"{stem}_part{part_idx + 1}" if multi else stem
write_outputs(part_stem, anonymized_text, dossier, report, subdir=subdir, export_rum_flag=export_rum_flag)
group_dossiers.append(dossier)
except Exception:
logger.exception("Erreur lors du traitement de %s", doc_path.name)
# Pipeline CPU/GPU : préparer le document N+1 (CPU) pendant que le GPU traite le N
from concurrent.futures import ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=1, thread_name_prefix="cpu-prep") as cpu_pool:
# Soumettre la préparation du premier document
prep_future = cpu_pool.submit(prepare_document, docs[0]) if docs else None
for i, doc_path in enumerate(docs):
try:
# Attendre la préparation du document courant
prep = prep_future.result()
# Lancer la préparation du document suivant EN PARALLÈLE du GPU ci-dessous
if i + 1 < len(docs):
prep_future = cpu_pool.submit(prepare_document, docs[i + 1])
# Phase GPU : extract_medical_info + post-processing
t0 = time.time()
file_path_p, doc_type, raw_text, page_tracker, extraction_stats, prepared_chunks = prep
stem = doc_path.stem.replace(" ", "_")
multi = len(prepared_chunks) > 1
for part_idx, (parsed, anonymized_text, report, edsnlp_result, chunk_text) in enumerate(prepared_chunks):
part_stem = f"{stem}_part{part_idx + 1}" if multi else stem
dossier = extract_medical_info(
parsed, anonymized_text, edsnlp_result, use_rag=_use_rag,
page_tracker=page_tracker, raw_text=raw_text, defer_qc=True,
)
dossier.source_file = doc_path.name
dossier.document_type = doc_type
logger.info(" DP : %s", dossier.diagnostic_principal)
# Injection des stats d'extraction
extraction_flags = extraction_stats.to_flags()
if extraction_flags:
dossier.quality_flags.update(extraction_flags)
extraction_alert = extraction_stats.to_alert()
if extraction_alert:
dossier.alertes_codage.append(extraction_alert)
# Vetos + Décisions + Complétude + GHM + Finalizer
_postprocess_dossier(dossier, parsed, subdir=subdir)
dossier.processing_time_s = round(time.time() - t0, 2)
write_outputs(part_stem, anonymized_text, dossier, report, subdir=subdir, export_rum_flag=export_rum_flag)
group_dossiers.append(dossier)
logger.info(" Temps total %s : %.2fs", doc_path.name, time.time() - t0)
except Exception:
logger.exception("Erreur lors du traitement de %s", doc_path.name)
# Batch QC : validation justifications pour tous les dossiers du groupe en une seule passe
# Évite les swaps coding ↔ qc entre chaque document
if _use_rag and group_dossiers:
t_qc = time.time()
from .medical.validation_pipeline import _validate_justifications
for d in group_dossiers:
try:
_validate_justifications(d)
except Exception:
logger.warning("QC batch : erreur pour %s", d.source_file, exc_info=True)
logger.info(" ⏱ [QC-BATCH] %.1fs — %d dossiers validés", time.time() - t_qc, len(group_dossiers))
# Fusion multi-PDFs si plusieurs documents dans le même groupe
merged = None
@@ -610,20 +724,18 @@ def main(input_path: str | None = None) -> None:
merged = None
# Contrôle CPAM : enrichir le dossier principal (fusionné ou dernier)
# Utilise le mode batché pour regrouper les appels LLM par modèle
# et minimiser les swaps VRAM (de ~4N swaps à ~2-4 swaps)
if cpam_data and subdir:
try:
from .control.cpam_parser import match_dossier_ogc
controles = match_dossier_ogc(subdir, cpam_data)
if controles:
from .control.cpam_response import generate_cpam_response
from .control.cpam_response import generate_cpam_responses_batched
target = merged if merged else (group_dossiers[-1] if group_dossiers else None)
if target:
logger.info(" CPAM : %d contrôle(s) pour %s", len(controles), subdir)
for ctrl in controles:
text, response_data, sources = generate_cpam_response(target, ctrl)
ctrl.contre_argumentation = text
ctrl.response_data = response_data
ctrl.sources_reponse = sources
generate_cpam_responses_batched(target, controles)
target.controles_cpam = controles
except Exception:
logger.exception("Erreur CPAM pour %s", subdir)

View File

@@ -11,6 +11,7 @@ from __future__ import annotations
import logging
import re
import time
from datetime import datetime
from typing import Optional
@@ -64,12 +65,14 @@ def extract_medical_info(
use_rag: bool = False,
page_tracker=None,
raw_text: str | None = None,
defer_qc: bool = False,
) -> DossierMedical:
"""Extrait les informations médicales structurées depuis les données parsées et le texte.
Args:
page_tracker: PageTracker pour la traçabilité page/extrait (optionnel).
raw_text: Texte brut avant anonymisation (pour recherche page source).
defer_qc: Si True, ne pas exécuter la validation QC (sera faite en batch plus tard).
"""
dossier = DossierMedical()
dossier.document_type = parsed_data.get("type", "")
@@ -87,43 +90,71 @@ def extract_medical_info(
_extract_imagerie(anonymized_text, dossier)
_extract_complications(anonymized_text, dossier, edsnlp_result)
# Phase 4 : pass LLM pour détecter des DAS supplémentaires
# Phase 4 : DAS LLM + RAG DP + DP selector en parallèle
_dp_selection_needed = dossier.document_type != "trackare"
if use_rag:
_extract_das_llm(anonymized_text, dossier)
# Optimisation #1 : paralléliser enrichissement RAG et sélection DP
_dp_selection_needed = use_rag and dossier.document_type != "trackare"
if use_rag or _dp_selection_needed:
from concurrent.futures import ThreadPoolExecutor, as_completed
def _task_enrich():
if use_rag:
_enrich_with_rag(dossier)
# --- Groupe 1 : 3 tâches indépendantes en parallèle ---
# - DAS LLM : détecte des DAS supplémentaires (ne dépend pas du RAG DP)
# - RAG DP : enrichit seulement le DP (ne dépend pas du DAS LLM)
# - DP selector : sélectionne le DP optimal (indépendant des deux autres)
def _task_das_llm():
t0 = time.monotonic()
_extract_das_llm(anonymized_text, dossier)
elapsed = time.monotonic() - t0
logger.info("⏱ [DAS-LLM-TASK] %.1fs — extraction DAS LLM terminée", elapsed)
def _task_rag_dp():
t0 = time.monotonic()
_enrich_dp_only(dossier)
elapsed = time.monotonic() - t0
logger.info("⏱ [RAG-DP-TASK] %.1fs — enrichissement DP terminé", elapsed)
def _task_select_dp():
if not _dp_selection_needed:
return None
t0 = time.monotonic()
from .dp_selector import select_dp, build_synthese
synthese = build_synthese(dossier, parsed_data)
return select_dp(dossier, synthese, config={"llm_enabled": use_rag})
result = select_dp(dossier, synthese, config={"llm_enabled": use_rag})
elapsed = time.monotonic() - t0
logger.info("⏱ [DP-SELECT] %.1fs — sélection DP terminée", elapsed)
return result
dp_selection_result = None
with ThreadPoolExecutor(max_workers=2) as pool:
fut_enrich = pool.submit(_task_enrich)
t_group1 = time.monotonic()
with ThreadPoolExecutor(max_workers=3) as pool:
fut_das = pool.submit(_task_das_llm)
fut_rag_dp = pool.submit(_task_rag_dp)
fut_dp = pool.submit(_task_select_dp)
# Attendre les deux tâches
for fut in as_completed([fut_enrich, fut_dp]):
for fut in as_completed([fut_das, fut_rag_dp, fut_dp]):
exc = fut.exception()
if exc and fut is fut_dp:
logger.error("NUKE-3: erreur sélection DP", exc_info=exc)
dossier.quality_flags["dp_selection_status"] = "error"
dossier.alertes_codage.append("QUALITE DEGRADEE : sélection DP (NUKE-3) en erreur")
elif exc:
logger.error("RAG enrichissement échoué", exc_info=exc)
elif exc and fut is fut_rag_dp:
logger.error("RAG enrichissement DP échoué", exc_info=exc)
elif exc and fut is fut_das:
logger.error("DAS LLM extraction échouée", exc_info=exc)
if not fut_dp.exception():
dp_selection_result = fut_dp.result()
elapsed_group1 = time.monotonic() - t_group1
logger.info("⏱ [GROUPE-1] %.1fs — DAS_LLM + RAG_DP + DP_SELECT terminés", elapsed_group1)
# --- Groupe 2 : enrichir les DAS (existants + nouveaux du DAS LLM) + actes ---
t_group2 = time.monotonic()
_enrich_das_and_actes(dossier)
elapsed_group2 = time.monotonic() - t_group2
logger.info("⏱ [GROUPE-2] %.1fs — enrichissement DAS + actes terminé", elapsed_group2)
# Appliquer la sélection DP après parallélisation
if dp_selection_result is not None:
selection = dp_selection_result
@@ -151,12 +182,15 @@ def extract_medical_info(
dossier.alertes_codage.append(
f"NUKE-3 REVIEW: DP ambigu — {selection.reason}"
)
elif dossier.document_type != "trackare":
elif _dp_selection_needed:
# Fallback sans RAG : sélection DP seule
try:
t_dp_norag = time.monotonic()
from .dp_selector import select_dp, build_synthese
synthese = build_synthese(dossier, parsed_data)
selection = select_dp(dossier, synthese, config={"llm_enabled": False})
elapsed_dp = time.monotonic() - t_dp_norag
logger.info("⏱ [DP-SELECT] %.1fs — sélection DP (sans RAG)", elapsed_dp)
dossier.dp_selection = selection
if selection.chosen_code:
current_code = (
@@ -238,9 +272,13 @@ def extract_medical_info(
except Exception:
logger.error("NUKE-3 reselect après vetos échouée", exc_info=True)
# Post-processing : validation justifications (QC batch)
if use_rag:
# Post-processing : validation justifications (QC)
# Si defer_qc=True, le QC sera fait en batch après tous les dossiers (évite les swaps modèle)
if use_rag and not defer_qc:
t_qc = time.monotonic()
_validate_justifications(dossier)
elapsed_qc = time.monotonic() - t_qc
logger.info("⏱ [QC-VALIDATION] %.1fs — validation justifications terminée", elapsed_qc)
# Post-processing : traçabilité source (page + extrait)
if page_tracker:
@@ -457,7 +495,7 @@ def _extract_das_llm(text: str, dossier: DossierMedical) -> None:
def _enrich_with_rag(dossier: DossierMedical) -> None:
"""Enrichit les diagnostics via le RAG (FAISS + Ollama)."""
"""Enrichit les diagnostics via le RAG (FAISS + Ollama) — wrapper rétro-compatible."""
try:
from .rag_search import enrich_dossier
enrich_dossier(dossier)
@@ -471,6 +509,34 @@ def _enrich_with_rag(dossier: DossierMedical) -> None:
dossier.alertes_codage.append("QUALITE DEGRADEE : erreur RAG — codage sans référentiels")
def _enrich_dp_only(dossier: DossierMedical) -> None:
"""Enrichit SEULEMENT le DP via le RAG (Phase 1, parallélisable)."""
try:
from .rag_search import enrich_dp
enrich_dp(dossier)
except ImportError:
logger.error("RAG INDISPONIBLE : faiss-cpu ou sentence-transformers manquant")
dossier.quality_flags["rag_status"] = "unavailable"
dossier.alertes_codage.append("QUALITE DEGRADEE : RAG indisponible — codage sans référentiels")
except Exception:
logger.error("RAG EN ERREUR : enrichissement DP échoué", exc_info=True)
dossier.quality_flags["rag_status"] = "error"
dossier.alertes_codage.append("QUALITE DEGRADEE : erreur RAG DP — codage sans référentiels")
def _enrich_das_and_actes(dossier: DossierMedical) -> None:
"""Enrichit les DAS et actes CCAM via le RAG (Phase 2, après DAS LLM)."""
try:
from .rag_search import enrich_das_and_actes
enrich_das_and_actes(dossier)
except ImportError:
logger.error("RAG INDISPONIBLE : faiss-cpu ou sentence-transformers manquant")
dossier.quality_flags["rag_status"] = "unavailable"
except Exception:
logger.error("RAG EN ERREUR : enrichissement DAS/actes échoué", exc_info=True)
dossier.quality_flags["rag_status"] = "error"
def _extract_sejour(parsed: dict, dossier: DossierMedical) -> None:
"""Extrait les informations de séjour."""
patient = parsed.get("patient", {})

View File

@@ -9,7 +9,7 @@ import time
import requests
from ..config import OLLAMA_URL, OLLAMA_MODEL, OLLAMA_TIMEOUT, get_model
from ..config import OLLAMA_URL, OLLAMA_URL_SECONDARY, OLLAMA_MODEL, OLLAMA_TIMEOUT, get_model
logger = logging.getLogger(__name__)
@@ -146,6 +146,24 @@ def parse_json_response(raw: str) -> dict | None:
return None
def _is_secondary_available() -> bool:
"""Vérifie si la machine Ollama secondaire est accessible (cache 60s)."""
now = time.monotonic()
if hasattr(_is_secondary_available, "_cache"):
cached_time, cached_result = _is_secondary_available._cache
if now - cached_time < 60:
return cached_result
try:
r = requests.get(f"{OLLAMA_URL_SECONDARY}/api/tags", timeout=2)
result = r.status_code == 200
except Exception:
result = False
_is_secondary_available._cache = (now, result)
if result:
logger.info("Ollama secondaire disponible : %s", OLLAMA_URL_SECONDARY)
return result
def call_ollama(
prompt: str,
temperature: float = 0.1,
@@ -153,6 +171,7 @@ def call_ollama(
model: str | None = None,
timeout: int | None = None,
role: str | None = None,
ollama_url: str | None = None,
) -> dict | None:
"""Appelle Ollama en mode JSON natif, avec fallback Anthropic si indisponible.
@@ -163,12 +182,14 @@ def call_ollama(
model: Modèle Ollama à utiliser (prioritaire sur role).
timeout: Timeout en secondes (défaut: OLLAMA_TIMEOUT global).
role: Rôle LLM (coding, cpam, validation, qc) → résolu via get_model().
ollama_url: URL Ollama à utiliser (prioritaire sur OLLAMA_URL global).
Returns:
Le dict JSON parsé, ou None en cas d'erreur.
"""
use_model = model or (get_model(role) if role else OLLAMA_MODEL)
use_timeout = timeout or OLLAMA_TIMEOUT
use_url = ollama_url or OLLAMA_URL
messages: list[dict] = [{"role": "user", "content": prompt}]
@@ -186,7 +207,7 @@ def call_ollama(
},
}
response = requests.post(
f"{OLLAMA_URL}/api/chat",
f"{use_url}/api/chat",
json=payload,
timeout=use_timeout,
)

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import logging
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from ..config import (
@@ -610,7 +611,9 @@ def enrich_diagnostic(
Modifie le diagnostic en place. Fallback gracieux si FAISS ou Ollama échouent.
"""
t0 = time.monotonic()
diag_type = "dp" if est_dp else "das"
label = f"RAG-{diag_type.upper()}"
# 1. Vérifier le cache
cached = cache.get(diagnostic.texte, diag_type) if cache else None
@@ -626,6 +629,8 @@ def enrich_diagnostic(
if cached is not None:
logger.info("Cache hit (sans sources FAISS) pour %s : « %s »", diag_type.upper(), diagnostic.texte)
_apply_llm_result_diagnostic(diagnostic, cached)
elapsed = time.monotonic() - t0
logger.info("⏱ [%s] %.1fs — %s (no FAISS)", label, elapsed, diagnostic.texte[:60])
return
# 3. Stocker les sources RAG
@@ -643,11 +648,15 @@ def enrich_diagnostic(
if cached is not None:
logger.info("Cache hit pour %s : « %s »", diag_type.upper(), diagnostic.texte)
_apply_llm_result_diagnostic(diagnostic, cached)
elapsed = time.monotonic() - t0
logger.info("⏱ [%s] %.1fs — %s (cache hit)", label, elapsed, diagnostic.texte[:60])
return
# 5. Appel Ollama pour justification avec raisonnement structuré
prompt = _build_prompt(diagnostic.texte, sources, contexte, est_dp=est_dp)
t_llm = time.monotonic()
llm_result = _call_ollama(prompt)
llm_elapsed = time.monotonic() - t_llm
if llm_result:
_apply_llm_result_diagnostic(diagnostic, llm_result)
@@ -656,6 +665,9 @@ def enrich_diagnostic(
else:
logger.info("Ollama non disponible — sources FAISS conservées sans justification LLM")
elapsed = time.monotonic() - t0
logger.info("⏱ [%s] %.1fs (LLM %.1fs) — %s", label, elapsed, llm_elapsed, diagnostic.texte[:60])
def _apply_llm_result_acte(acte: ActeCCAM, llm_result: dict) -> None:
"""Applique un résultat LLM (frais ou caché) à un ActeCCAM."""
@@ -687,6 +699,8 @@ def enrich_acte(acte: ActeCCAM, contexte: dict, cache: OllamaCache | None = None
Modifie l'acte en place. Fallback gracieux si FAISS ou Ollama échouent.
"""
t0 = time.monotonic()
# 1. Vérifier le cache
cached = cache.get(acte.texte, "ccam") if cache else None
@@ -712,11 +726,15 @@ def enrich_acte(acte: ActeCCAM, contexte: dict, cache: OllamaCache | None = None
if cached is not None:
logger.info("Cache hit pour CCAM : « %s »", acte.texte)
_apply_llm_result_acte(acte, cached)
elapsed = time.monotonic() - t0
logger.info("⏱ [RAG-CCAM] %.1fs — %s (cache hit)", elapsed, acte.texte[:60])
return
# 5. Appel Ollama pour justification avec raisonnement structuré
prompt = _build_prompt_ccam(acte.texte, sources, contexte)
t_llm = time.monotonic()
llm_result = _call_ollama(prompt)
llm_elapsed = time.monotonic() - t_llm
if llm_result:
_apply_llm_result_acte(acte, llm_result)
@@ -725,6 +743,9 @@ def enrich_acte(acte: ActeCCAM, contexte: dict, cache: OllamaCache | None = None
else:
logger.info("Ollama non disponible — sources FAISS CCAM conservées sans justification LLM")
elapsed = time.monotonic() - t0
logger.info("⏱ [RAG-CCAM] %.1fs (LLM %.1fs) — %s", elapsed, llm_elapsed, acte.texte[:60])
def _smart_truncate(text: str, max_chars: int = 6000) -> str:
"""Troncature intelligente : garde le début + les sections finales importantes.
@@ -790,6 +811,8 @@ def extract_das_llm(
"""
import hashlib
t0 = time.monotonic()
# Clé de cache basée sur le hash du texte
text_hash = hashlib.md5(text[:4000].encode()).hexdigest()[:16]
cache_key_text = f"das_extract::{text_hash}"
@@ -798,15 +821,19 @@ def extract_das_llm(
if cache is not None:
cached = cache.get(cache_key_text, "das_llm")
if cached is not None:
logger.info("Cache hit pour extraction DAS LLM")
elapsed = time.monotonic() - t0
logger.info("⏱ [DAS-LLM] %.1fs — cache hit", elapsed)
return cached.get("diagnostics_supplementaires", [])
# Construire le prompt et appeler Ollama
prompt = _build_prompt_das_extraction(text, contexte, existing_das, dp_texte)
t_llm = time.monotonic()
result = call_ollama(prompt, temperature=0.1, max_tokens=2000, role="coding")
llm_elapsed = time.monotonic() - t_llm
if result is None:
logger.warning("Extraction DAS LLM : Ollama non disponible")
elapsed = time.monotonic() - t0
logger.warning("⏱ [DAS-LLM] %.1fs — Ollama non disponible", elapsed)
return []
das_list = result.get("diagnostics_supplementaires", [])
@@ -818,25 +845,52 @@ def extract_das_llm(
if cache is not None:
cache.put(cache_key_text, "das_llm", result)
logger.info("Extraction DAS LLM : %d diagnostics supplémentaires détectés", len(das_list))
elapsed = time.monotonic() - t0
logger.info("⏱ [DAS-LLM] %.1fs (LLM %.1fs) — %d diagnostics détectés", elapsed, llm_elapsed, len(das_list))
return das_list
def enrich_dossier(dossier: DossierMedical) -> None:
"""Enrichit le DP et tous les DAS d'un dossier via le RAG.
def enrich_dp(dossier: DossierMedical, cache: OllamaCache | None = None) -> None:
"""Enrichit SEULEMENT le DP d'un dossier via le RAG (Phase 1).
Utilise un cache persistant et parallélise les appels Ollama
pour les DAS et actes CCAM (max_workers = OLLAMA_MAX_PARALLEL).
Peut être exécuté en parallèle avec d'autres tâches (DAS LLM, DP selector)
car il ne dépend que du DP existant.
Args:
dossier: Dossier médical à enrichir (modifié en place).
cache: Cache Ollama optionnel (créé si None).
"""
cache = OllamaCache(OLLAMA_CACHE_PATH, get_model("coding"))
t0 = time.monotonic()
if cache is None:
cache = OllamaCache(OLLAMA_CACHE_PATH, get_model("coding"))
contexte = build_enriched_context(dossier)
# Phase 1 : DP seul (le contexte DAS en dépend)
if dossier.diagnostic_principal:
logger.info("RAG enrichissement DP : %s", dossier.diagnostic_principal.texte)
enrich_diagnostic(dossier.diagnostic_principal, contexte, est_dp=True, cache=cache)
cache.save()
elapsed = time.monotonic() - t0
logger.info("⏱ [RAG-DP-PHASE] %.1fs — enrichissement DP terminé", elapsed)
def enrich_das_and_actes(dossier: DossierMedical, cache: OllamaCache | None = None) -> None:
"""Enrichit les DAS et actes CCAM d'un dossier via le RAG (Phase 2).
Parallélise les appels Ollama (max_workers = OLLAMA_MAX_PARALLEL).
Doit être appelé APRES enrich_dp() et après l'ajout des DAS LLM.
Args:
dossier: Dossier médical à enrichir (modifié en place).
cache: Cache Ollama optionnel (créé si None).
"""
t0 = time.monotonic()
if cache is None:
cache = OllamaCache(OLLAMA_CACHE_PATH, get_model("coding"))
contexte = build_enriched_context(dossier)
# Mettre à jour le contexte avec le DP pour les DAS
if dossier.diagnostic_principal:
contexte["dp_texte"] = dossier.diagnostic_principal.texte
@@ -846,7 +900,6 @@ def enrich_dossier(dossier: DossierMedical) -> None:
if d.cim10_suggestion
]
# Phase 2 : DAS + Actes en parallèle
das_list = dossier.diagnostics_associes
actes_list = dossier.actes_ccam
@@ -863,3 +916,18 @@ def enrich_dossier(dossier: DossierMedical) -> None:
f.result() # propage les exceptions
cache.save()
elapsed = time.monotonic() - t0
n_das = len(das_list) if das_list else 0
n_actes = len(actes_list) if actes_list else 0
logger.info("⏱ [RAG-DAS-PHASE] %.1fs — %d DAS + %d actes enrichis", elapsed, n_das, n_actes)
def enrich_dossier(dossier: DossierMedical) -> None:
"""Enrichit le DP et tous les DAS d'un dossier via le RAG.
Wrapper rétro-compatible qui appelle enrich_dp() puis enrich_das_and_actes().
Utilise un cache persistant partagé entre les deux phases.
"""
cache = OllamaCache(OLLAMA_CACHE_PATH, get_model("coding"))
enrich_dp(dossier, cache=cache)
enrich_das_and_actes(dossier, cache=cache)

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import logging
import time
from .cim10_dict import lookup as dict_lookup, normalize_code, validate_code as cim10_validate
from .ccam_dict import validate_code as ccam_validate
@@ -408,8 +409,18 @@ def _validate_justifications(dossier: DossierMedical) -> None:
from ..prompts import QC_VALIDATION
prompt = QC_VALIDATION.format(ctx_str=ctx_str, codes_section=codes_section)
# Si la machine secondaire est disponible, envoyer le QC là-bas
# pour éviter un swap de modèle sur le GPU principal
from .ollama_client import _is_secondary_available
from ..config import OLLAMA_URL_SECONDARY
qc_url = OLLAMA_URL_SECONDARY if _is_secondary_available() else None
try:
result = call_ollama(prompt, temperature=0.1, max_tokens=2500, role="qc")
t0 = time.time()
result = call_ollama(prompt, temperature=0.1, max_tokens=2500, role="qc", ollama_url=qc_url)
elapsed = time.time() - t0
target = "secondaire" if qc_url else "local"
logger.info("⏱ [QC] %.1fs — validation QC (%s)", elapsed, target)
except Exception:
logger.warning("Erreur lors de l'appel Ollama pour validation QC", exc_info=True)
return

View File

@@ -116,7 +116,7 @@ class TestExtractMedicalInfoRAGFlag:
assert dossier.diagnostic_principal.justification is None
def test_use_rag_true_calls_enrich(self):
"""use_rag=True appelle _enrich_with_rag (mocké)."""
"""use_rag=True appelle _enrich_dp_only et _enrich_das_and_actes (mockés)."""
from src.medical.cim10_extractor import extract_medical_info
parsed = {
@@ -127,9 +127,13 @@ class TestExtractMedicalInfoRAGFlag:
}
text = "Pancréatite aiguë biliaire.\nTTT de sortie :\nParacétamol\n\nDevenir : retour."
with patch("src.medical.cim10_extractor._enrich_with_rag") as mock_enrich:
with patch("src.medical.cim10_extractor._enrich_dp_only") as mock_dp, \
patch("src.medical.cim10_extractor._enrich_das_and_actes") as mock_das, \
patch("src.medical.cim10_extractor._extract_das_llm"), \
patch("src.medical.cim10_extractor._validate_justifications"):
dossier = extract_medical_info(parsed, text, use_rag=True)
mock_enrich.assert_called_once_with(dossier)
mock_dp.assert_called_once_with(dossier)
mock_das.assert_called_once_with(dossier)
def test_use_rag_default_false(self):
"""Par défaut, use_rag=False."""
@@ -143,9 +147,11 @@ class TestExtractMedicalInfoRAGFlag:
}
text = "Test simple."
with patch("src.medical.cim10_extractor._enrich_with_rag") as mock_enrich:
with patch("src.medical.cim10_extractor._enrich_dp_only") as mock_dp, \
patch("src.medical.cim10_extractor._enrich_das_and_actes") as mock_das:
extract_medical_info(parsed, text)
mock_enrich.assert_not_called()
mock_dp.assert_not_called()
mock_das.assert_not_called()
class TestChunkingCIM10: