Merge branch 'feat/speed-optimizations' — batch QC + pipeline CPU/GPU + QC Dell

This commit is contained in:
dom
2026-03-08 15:45:20 +01:00
5 changed files with 249 additions and 99 deletions

View File

@@ -57,6 +57,7 @@ NER_CONFIDENCE_THRESHOLD = float(os.environ.get("T2A_NER_THRESHOLD", "0.80"))
# --- Configuration Ollama ---
OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://localhost:11434")
OLLAMA_URL_SECONDARY = os.environ.get("OLLAMA_URL_SECONDARY", "http://192.168.1.11:11434")
OLLAMA_MODEL = os.environ.get("OLLAMA_MODEL", "gemma3:27b")
OLLAMA_TIMEOUT = int(os.environ.get("OLLAMA_TIMEOUT", "600"))
OLLAMA_CACHE_PATH = BASE_DIR / "data" / "ollama_cache.json"
@@ -203,7 +204,7 @@ EMBEDDING_MODEL = os.environ.get("T2A_EMBEDDING_MODEL", "dangvantuan/sentence-ca
# --- Modèle de re-ranking (cross-encoder, CPU uniquement) ---
RERANKER_MODEL = os.environ.get("T2A_RERANKER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")
RERANKER_MODEL = os.environ.get("T2A_RERANKER_MODEL", "ncbi/MedCPT-Cross-Encoder")
# --- Références biologiques (fallback) ---

View File

@@ -143,11 +143,142 @@ _use_edsnlp = True
_use_rag = True
def process_document(file_path: Path) -> list[tuple[str, DossierMedical, AnonymizationReport]]:
# Type alias pour les données CPU préparées
PreparedChunk = tuple # (parsed, anonymized_text, report, edsnlp_result, chunk_text)
PreparedDoc = tuple # (file_path, doc_type, raw_text, page_tracker, extraction_stats, list[PreparedChunk])
def prepare_document(file_path: Path) -> PreparedDoc:
"""Phase CPU : extraction → splitting → parsing → anonymisation → edsnlp.
Séparé de process_document pour permettre le pipelining CPU/GPU :
pendant que le GPU traite le document N, le CPU prépare le document N+1.
"""
logger.info("Préparation de %s (CPU)", file_path.name)
# 1. Extraction texte
raw_text, page_tracker, extraction_stats = extract_document_with_pages(file_path)
logger.info(" Texte extrait : %d caractères (%d pages, format=%s)",
len(raw_text), extraction_stats.total_pages, extraction_stats.source_format)
# 2. Classification
doc_type = classify(raw_text)
logger.info(" Type de document : %s", doc_type)
# 3. Splitting
chunks = split_documents(raw_text, doc_type)
if len(chunks) > 1:
logger.info(" Découpage : %d dossiers détectés dans %s", len(chunks), file_path.name)
prepared_chunks: list[PreparedChunk] = []
for i, chunk_text in enumerate(chunks):
part_label = f" [part {i+1}/{len(chunks)}]" if len(chunks) > 1 else ""
# 4. Parsing
if doc_type == "trackare":
parsed = parse_trackare(chunk_text)
else:
parsed = parse_crh(chunk_text)
# 5. Anonymisation
anonymizer = Anonymizer(parsed_data=parsed)
anonymized_text = anonymizer.anonymize(chunk_text)
report = anonymizer.report
report.source_file = file_path.name
logger.info(" Anonymisation%s : %d remplacements (regex=%d, ner=%d, sweep=%d)",
part_label, report.total_replacements, report.regex_replacements,
report.ner_replacements, report.sweep_replacements)
# 6. edsnlp
edsnlp_result = None
if _use_edsnlp:
edsnlp_result = _run_edsnlp(anonymized_text)
prepared_chunks.append((parsed, anonymized_text, report, edsnlp_result, chunk_text))
return (file_path, doc_type, raw_text, page_tracker, extraction_stats, prepared_chunks)
def _postprocess_dossier(dossier: DossierMedical, parsed: dict, subdir: str | None = None) -> None:
"""Post-processing déterministe : vetos, décisions, GHM, complétude, finalizer DP."""
# Routage des règles
rules_token = None
try:
rules_ctx = build_rules_runtime_context(dossier)
dossier.rules_runtime = rules_ctx
rules_token = set_rules_runtime(rules_ctx)
except Exception:
logger.error(" Routage règles : erreur", exc_info=True)
dossier.quality_flags["rules_routing"] = "error"
# Vetos
veto = None
try:
veto = apply_vetos(dossier)
dossier.veto_report = veto
except Exception:
logger.error(" Vetos : erreur lors du contrôle", exc_info=True)
dossier.quality_flags["veto_engine"] = "error"
dossier.alertes_codage.append("QUALITE DEGRADEE : moteur de vetos en erreur")
# Décisions
try:
apply_decisions(dossier)
_inject_decision_alerts(dossier, scope="PDF")
if veto is not None:
_inject_veto_alerts(dossier, veto, scope="PDF")
except Exception:
logger.error(" Décisions : erreur lors du post-traitement", exc_info=True)
dossier.quality_flags["decision_engine"] = "error"
finally:
if rules_token is not None:
reset_rules_runtime(rules_token)
# Complétude
try:
dossier.completude = build_completude_checklist(dossier)
except Exception:
logger.error(" Complétude : erreur lors du contrôle", exc_info=True)
dossier.quality_flags["completude"] = "error"
# GHM + métriques
try:
metrics = _compute_metrics(dossier)
ghm = estimate_ghm(dossier)
dossier.ghm_estimation = ghm
logger.info(
" DAS : actifs=%d / total=%d (écartés=%d, removed=%d) | Actes : %d (avec code=%d)",
metrics.das_active, metrics.das_total, metrics.das_excluded,
metrics.das_removed, metrics.actes_total, metrics.actes_with_code,
)
logger.info(" GHM : CMD=%s, Type=%s, Sévérité=%d%s",
ghm.cmd or "?", ghm.type_ghm or "?", ghm.severite, ghm.ghm_approx or "?")
except Exception:
logger.error(" Erreur estimation GHM/metrics", exc_info=True)
dossier.quality_flags["ghm_estimation"] = "error"
dossier.alertes_codage.append("QUALITE DEGRADEE : estimation GHM en erreur")
# Finalizer DP
try:
from .medical.dp_finalizer import finalize_dp
finalize_dp(dossier)
except Exception:
logger.error(" Finalizer DP : erreur", exc_info=True)
dossier.quality_flags["dp_finalizer"] = "error"
def process_document(
file_path: Path,
defer_qc: bool = False,
) -> list[tuple[str, DossierMedical, AnonymizationReport]]:
"""Traite un document : extraction → splitting → parsing → anonymisation → extraction CIM-10.
Supporte PDF, images (JPEG/PNG/TIFF) et DOCX via le router d'extraction.
Args:
file_path: Chemin du document à traiter.
defer_qc: Si True, ne pas exécuter la validation QC (sera faite en batch plus tard).
Retourne une liste de (texte_anonymisé, dossier, rapport) — un par dossier détecté.
"""
t0 = time.time()
@@ -200,6 +331,7 @@ def process_document(file_path: Path) -> list[tuple[str, DossierMedical, Anonymi
dossier = extract_medical_info(
parsed, anonymized_text, edsnlp_result, use_rag=_use_rag,
page_tracker=page_tracker, raw_text=raw_text,
defer_qc=defer_qc,
)
dossier.source_file = file_path.name
dossier.document_type = doc_type
@@ -213,86 +345,8 @@ def process_document(file_path: Path) -> list[tuple[str, DossierMedical, Anonymi
if extraction_alert:
dossier.alertes_codage.append(extraction_alert)
# 8. Vetos (contestabilité) + décisions (post-traitement)
# Routage des règles (packs) : par défaut, on garde le socle vetos/decisions,
# et on active des packs additionnels selon les signaux du dossier (codes/labs/extraits).
rules_token = None
try:
rules_ctx = build_rules_runtime_context(dossier)
dossier.rules_runtime = rules_ctx
rules_token = set_rules_runtime(rules_ctx)
packs = ",".join(rules_ctx.get("enabled_packs", []))
if packs:
logger.info(" Règles%s : packs=%s", part_label, packs)
if rules_ctx.get("triggers_fired"):
logger.info(" Règles%s : triggers=%s", part_label, ",".join(rules_ctx["triggers_fired"]))
except Exception:
logger.error(" Routage règles : erreur", exc_info=True)
dossier.quality_flags["rules_routing"] = "error"
veto = None
try:
veto = apply_vetos(dossier)
dossier.veto_report = veto
except Exception:
logger.error(" Vetos : erreur lors du contrôle", exc_info=True)
dossier.quality_flags["veto_engine"] = "error"
dossier.alertes_codage.append("QUALITE DEGRADEE : moteur de vetos en erreur")
try:
apply_decisions(dossier)
_inject_decision_alerts(dossier, scope="PDF")
if veto is not None:
_inject_veto_alerts(dossier, veto, scope="PDF")
except Exception:
logger.error(" Décisions : erreur lors du post-traitement", exc_info=True)
dossier.quality_flags["decision_engine"] = "error"
finally:
if rules_token is not None:
reset_rules_runtime(rules_token)
try:
dossier.completude = build_completude_checklist(dossier)
except Exception:
logger.error(" Complétude : erreur lors du contrôle", exc_info=True)
dossier.quality_flags["completude"] = "error"
# 9. Estimation GHM (sur codes finaux) + métriques (actifs vs écartés)
try:
metrics = _compute_metrics(dossier)
ghm = estimate_ghm(dossier)
dossier.ghm_estimation = ghm
logger.info(
" DAS : actifs=%d / total=%d (écartés=%d, removed=%d, no_code=%d) | Actes : %d (avec code=%d)",
metrics.das_active,
metrics.das_total,
metrics.das_excluded,
metrics.das_removed,
metrics.das_no_code,
metrics.actes_total,
metrics.actes_with_code,
)
logger.info(
" GHM : CMD=%s, Type=%s, Sévérité=%d%s",
ghm.cmd or "?",
ghm.type_ghm or "?",
ghm.severite,
ghm.ghm_approx or "?",
)
except Exception:
logger.error(" Erreur estimation GHM/metrics", exc_info=True)
dossier.quality_flags["ghm_estimation"] = "error"
dossier.alertes_codage.append("QUALITE DEGRADEE : estimation GHM en erreur")
# 10. Finalizer DP (arbitrage Trackare vs CRH, traçabilité)
try:
from .medical.dp_finalizer import finalize_dp
finalize_dp(dossier)
except Exception:
logger.error(" Finalizer DP : erreur", exc_info=True)
dossier.quality_flags["dp_finalizer"] = "error"
# 8-10. Vetos + Décisions + GHM + Finalizer DP
_postprocess_dossier(dossier, parsed)
dossier.processing_time_s = round(time.time() - t0, 2)
results.append((anonymized_text, dossier, report))
@@ -569,22 +623,82 @@ def main(input_path: str | None = None) -> None:
logger.info("Traitement de %d document(s)...", total)
def _process_group(docs: list[Path], subdir: str | None) -> None:
"""Traite un groupe de documents (un dossier patient)."""
"""Traite un groupe de documents (un dossier patient).
Optimisations appliquées :
- defer_qc=True : la validation QC est différée et batchée après tous les documents
du groupe (1 swap modèle au lieu de 2 par document)
- Si la machine secondaire Dell est disponible, le QC est envoyé là-bas
(0 swap sur le GPU principal)
"""
if subdir:
logger.info("--- Dossier %s (%d documents) ---", subdir, len(docs))
group_dossiers: list[DossierMedical] = []
for doc_path in docs:
try:
doc_results = process_document(doc_path)
stem = doc_path.stem.replace(" ", "_")
multi = len(doc_results) > 1
for part_idx, (anonymized_text, dossier, report) in enumerate(doc_results):
part_stem = f"{stem}_part{part_idx + 1}" if multi else stem
write_outputs(part_stem, anonymized_text, dossier, report, subdir=subdir, export_rum_flag=export_rum_flag)
group_dossiers.append(dossier)
except Exception:
logger.exception("Erreur lors du traitement de %s", doc_path.name)
# Pipeline CPU/GPU : préparer le document N+1 (CPU) pendant que le GPU traite le N
from concurrent.futures import ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=1, thread_name_prefix="cpu-prep") as cpu_pool:
# Soumettre la préparation du premier document
prep_future = cpu_pool.submit(prepare_document, docs[0]) if docs else None
for i, doc_path in enumerate(docs):
try:
# Attendre la préparation du document courant
prep = prep_future.result()
# Lancer la préparation du document suivant EN PARALLÈLE du GPU ci-dessous
if i + 1 < len(docs):
prep_future = cpu_pool.submit(prepare_document, docs[i + 1])
# Phase GPU : extract_medical_info + post-processing
t0 = time.time()
file_path_p, doc_type, raw_text, page_tracker, extraction_stats, prepared_chunks = prep
stem = doc_path.stem.replace(" ", "_")
multi = len(prepared_chunks) > 1
for part_idx, (parsed, anonymized_text, report, edsnlp_result, chunk_text) in enumerate(prepared_chunks):
part_stem = f"{stem}_part{part_idx + 1}" if multi else stem
dossier = extract_medical_info(
parsed, anonymized_text, edsnlp_result, use_rag=_use_rag,
page_tracker=page_tracker, raw_text=raw_text, defer_qc=True,
)
dossier.source_file = doc_path.name
dossier.document_type = doc_type
logger.info(" DP : %s", dossier.diagnostic_principal)
# Injection des stats d'extraction
extraction_flags = extraction_stats.to_flags()
if extraction_flags:
dossier.quality_flags.update(extraction_flags)
extraction_alert = extraction_stats.to_alert()
if extraction_alert:
dossier.alertes_codage.append(extraction_alert)
# Vetos + Décisions + Complétude + GHM + Finalizer
_postprocess_dossier(dossier, parsed, subdir=subdir)
dossier.processing_time_s = round(time.time() - t0, 2)
write_outputs(part_stem, anonymized_text, dossier, report, subdir=subdir, export_rum_flag=export_rum_flag)
group_dossiers.append(dossier)
logger.info(" Temps total %s : %.2fs", doc_path.name, time.time() - t0)
except Exception:
logger.exception("Erreur lors du traitement de %s", doc_path.name)
# Batch QC : validation justifications pour tous les dossiers du groupe en une seule passe
# Évite les swaps coding ↔ qc entre chaque document
if _use_rag and group_dossiers:
t_qc = time.time()
from .medical.validation_pipeline import _validate_justifications
for d in group_dossiers:
try:
_validate_justifications(d)
except Exception:
logger.warning("QC batch : erreur pour %s", d.source_file, exc_info=True)
logger.info(" ⏱ [QC-BATCH] %.1fs — %d dossiers validés", time.time() - t_qc, len(group_dossiers))
# Fusion multi-PDFs si plusieurs documents dans le même groupe
merged = None

View File

@@ -65,12 +65,14 @@ def extract_medical_info(
use_rag: bool = False,
page_tracker=None,
raw_text: str | None = None,
defer_qc: bool = False,
) -> DossierMedical:
"""Extrait les informations médicales structurées depuis les données parsées et le texte.
Args:
page_tracker: PageTracker pour la traçabilité page/extrait (optionnel).
raw_text: Texte brut avant anonymisation (pour recherche page source).
defer_qc: Si True, ne pas exécuter la validation QC (sera faite en batch plus tard).
"""
dossier = DossierMedical()
dossier.document_type = parsed_data.get("type", "")
@@ -270,8 +272,9 @@ def extract_medical_info(
except Exception:
logger.error("NUKE-3 reselect après vetos échouée", exc_info=True)
# Post-processing : validation justifications (QC batch)
if use_rag:
# Post-processing : validation justifications (QC)
# Si defer_qc=True, le QC sera fait en batch après tous les dossiers (évite les swaps modèle)
if use_rag and not defer_qc:
t_qc = time.monotonic()
_validate_justifications(dossier)
elapsed_qc = time.monotonic() - t_qc

View File

@@ -9,7 +9,7 @@ import time
import requests
from ..config import OLLAMA_URL, OLLAMA_MODEL, OLLAMA_TIMEOUT, get_model
from ..config import OLLAMA_URL, OLLAMA_URL_SECONDARY, OLLAMA_MODEL, OLLAMA_TIMEOUT, get_model
logger = logging.getLogger(__name__)
@@ -146,6 +146,24 @@ def parse_json_response(raw: str) -> dict | None:
return None
def _is_secondary_available() -> bool:
"""Vérifie si la machine Ollama secondaire est accessible (cache 60s)."""
now = time.monotonic()
if hasattr(_is_secondary_available, "_cache"):
cached_time, cached_result = _is_secondary_available._cache
if now - cached_time < 60:
return cached_result
try:
r = requests.get(f"{OLLAMA_URL_SECONDARY}/api/tags", timeout=2)
result = r.status_code == 200
except Exception:
result = False
_is_secondary_available._cache = (now, result)
if result:
logger.info("Ollama secondaire disponible : %s", OLLAMA_URL_SECONDARY)
return result
def call_ollama(
prompt: str,
temperature: float = 0.1,
@@ -153,6 +171,7 @@ def call_ollama(
model: str | None = None,
timeout: int | None = None,
role: str | None = None,
ollama_url: str | None = None,
) -> dict | None:
"""Appelle Ollama en mode JSON natif, avec fallback Anthropic si indisponible.
@@ -163,12 +182,14 @@ def call_ollama(
model: Modèle Ollama à utiliser (prioritaire sur role).
timeout: Timeout en secondes (défaut: OLLAMA_TIMEOUT global).
role: Rôle LLM (coding, cpam, validation, qc) → résolu via get_model().
ollama_url: URL Ollama à utiliser (prioritaire sur OLLAMA_URL global).
Returns:
Le dict JSON parsé, ou None en cas d'erreur.
"""
use_model = model or (get_model(role) if role else OLLAMA_MODEL)
use_timeout = timeout or OLLAMA_TIMEOUT
use_url = ollama_url or OLLAMA_URL
messages: list[dict] = [{"role": "user", "content": prompt}]
@@ -186,7 +207,7 @@ def call_ollama(
},
}
response = requests.post(
f"{OLLAMA_URL}/api/chat",
f"{use_url}/api/chat",
json=payload,
timeout=use_timeout,
)

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import logging
import time
from .cim10_dict import lookup as dict_lookup, normalize_code, validate_code as cim10_validate
from .ccam_dict import validate_code as ccam_validate
@@ -408,8 +409,18 @@ def _validate_justifications(dossier: DossierMedical) -> None:
from ..prompts import QC_VALIDATION
prompt = QC_VALIDATION.format(ctx_str=ctx_str, codes_section=codes_section)
# Si la machine secondaire est disponible, envoyer le QC là-bas
# pour éviter un swap de modèle sur le GPU principal
from .ollama_client import _is_secondary_available
from ..config import OLLAMA_URL_SECONDARY
qc_url = OLLAMA_URL_SECONDARY if _is_secondary_available() else None
try:
result = call_ollama(prompt, temperature=0.1, max_tokens=2500, role="qc")
t0 = time.time()
result = call_ollama(prompt, temperature=0.1, max_tokens=2500, role="qc", ollama_url=qc_url)
elapsed = time.time() - t0
target = "secondaire" if qc_url else "local"
logger.info("⏱ [QC] %.1fs — validation QC (%s)", elapsed, target)
except Exception:
logger.warning("Erreur lors de l'appel Ollama pour validation QC", exc_info=True)
return