diff --git a/src/config.py b/src/config.py index 1dde987..0fb0ef2 100644 --- a/src/config.py +++ b/src/config.py @@ -57,6 +57,7 @@ NER_CONFIDENCE_THRESHOLD = float(os.environ.get("T2A_NER_THRESHOLD", "0.80")) # --- Configuration Ollama --- OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://localhost:11434") +OLLAMA_URL_SECONDARY = os.environ.get("OLLAMA_URL_SECONDARY", "http://192.168.1.11:11434") OLLAMA_MODEL = os.environ.get("OLLAMA_MODEL", "gemma3:27b") OLLAMA_TIMEOUT = int(os.environ.get("OLLAMA_TIMEOUT", "600")) OLLAMA_CACHE_PATH = BASE_DIR / "data" / "ollama_cache.json" @@ -203,7 +204,7 @@ EMBEDDING_MODEL = os.environ.get("T2A_EMBEDDING_MODEL", "dangvantuan/sentence-ca # --- Modèle de re-ranking (cross-encoder, CPU uniquement) --- -RERANKER_MODEL = os.environ.get("T2A_RERANKER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2") +RERANKER_MODEL = os.environ.get("T2A_RERANKER_MODEL", "ncbi/MedCPT-Cross-Encoder") # --- Références biologiques (fallback) --- diff --git a/src/main.py b/src/main.py index e32a9f1..82ec7cc 100644 --- a/src/main.py +++ b/src/main.py @@ -143,11 +143,142 @@ _use_edsnlp = True _use_rag = True -def process_document(file_path: Path) -> list[tuple[str, DossierMedical, AnonymizationReport]]: +# Type alias pour les données CPU préparées +PreparedChunk = tuple # (parsed, anonymized_text, report, edsnlp_result, chunk_text) +PreparedDoc = tuple # (file_path, doc_type, raw_text, page_tracker, extraction_stats, list[PreparedChunk]) + + +def prepare_document(file_path: Path) -> PreparedDoc: + """Phase CPU : extraction → splitting → parsing → anonymisation → edsnlp. + + Séparé de process_document pour permettre le pipelining CPU/GPU : + pendant que le GPU traite le document N, le CPU prépare le document N+1. + """ + logger.info("Préparation de %s (CPU)", file_path.name) + + # 1. Extraction texte + raw_text, page_tracker, extraction_stats = extract_document_with_pages(file_path) + logger.info(" Texte extrait : %d caractères (%d pages, format=%s)", + len(raw_text), extraction_stats.total_pages, extraction_stats.source_format) + + # 2. Classification + doc_type = classify(raw_text) + logger.info(" Type de document : %s", doc_type) + + # 3. Splitting + chunks = split_documents(raw_text, doc_type) + if len(chunks) > 1: + logger.info(" Découpage : %d dossiers détectés dans %s", len(chunks), file_path.name) + + prepared_chunks: list[PreparedChunk] = [] + for i, chunk_text in enumerate(chunks): + part_label = f" [part {i+1}/{len(chunks)}]" if len(chunks) > 1 else "" + + # 4. Parsing + if doc_type == "trackare": + parsed = parse_trackare(chunk_text) + else: + parsed = parse_crh(chunk_text) + + # 5. Anonymisation + anonymizer = Anonymizer(parsed_data=parsed) + anonymized_text = anonymizer.anonymize(chunk_text) + report = anonymizer.report + report.source_file = file_path.name + logger.info(" Anonymisation%s : %d remplacements (regex=%d, ner=%d, sweep=%d)", + part_label, report.total_replacements, report.regex_replacements, + report.ner_replacements, report.sweep_replacements) + + # 6. edsnlp + edsnlp_result = None + if _use_edsnlp: + edsnlp_result = _run_edsnlp(anonymized_text) + + prepared_chunks.append((parsed, anonymized_text, report, edsnlp_result, chunk_text)) + + return (file_path, doc_type, raw_text, page_tracker, extraction_stats, prepared_chunks) + + +def _postprocess_dossier(dossier: DossierMedical, parsed: dict, subdir: str | None = None) -> None: + """Post-processing déterministe : vetos, décisions, GHM, complétude, finalizer DP.""" + # Routage des règles + rules_token = None + try: + rules_ctx = build_rules_runtime_context(dossier) + dossier.rules_runtime = rules_ctx + rules_token = set_rules_runtime(rules_ctx) + except Exception: + logger.error(" Routage règles : erreur", exc_info=True) + dossier.quality_flags["rules_routing"] = "error" + + # Vetos + veto = None + try: + veto = apply_vetos(dossier) + dossier.veto_report = veto + except Exception: + logger.error(" Vetos : erreur lors du contrôle", exc_info=True) + dossier.quality_flags["veto_engine"] = "error" + dossier.alertes_codage.append("QUALITE DEGRADEE : moteur de vetos en erreur") + + # Décisions + try: + apply_decisions(dossier) + _inject_decision_alerts(dossier, scope="PDF") + if veto is not None: + _inject_veto_alerts(dossier, veto, scope="PDF") + except Exception: + logger.error(" Décisions : erreur lors du post-traitement", exc_info=True) + dossier.quality_flags["decision_engine"] = "error" + finally: + if rules_token is not None: + reset_rules_runtime(rules_token) + + # Complétude + try: + dossier.completude = build_completude_checklist(dossier) + except Exception: + logger.error(" Complétude : erreur lors du contrôle", exc_info=True) + dossier.quality_flags["completude"] = "error" + + # GHM + métriques + try: + metrics = _compute_metrics(dossier) + ghm = estimate_ghm(dossier) + dossier.ghm_estimation = ghm + logger.info( + " DAS : actifs=%d / total=%d (écartés=%d, removed=%d) | Actes : %d (avec code=%d)", + metrics.das_active, metrics.das_total, metrics.das_excluded, + metrics.das_removed, metrics.actes_total, metrics.actes_with_code, + ) + logger.info(" GHM : CMD=%s, Type=%s, Sévérité=%d → %s", + ghm.cmd or "?", ghm.type_ghm or "?", ghm.severite, ghm.ghm_approx or "?") + except Exception: + logger.error(" Erreur estimation GHM/metrics", exc_info=True) + dossier.quality_flags["ghm_estimation"] = "error" + dossier.alertes_codage.append("QUALITE DEGRADEE : estimation GHM en erreur") + + # Finalizer DP + try: + from .medical.dp_finalizer import finalize_dp + finalize_dp(dossier) + except Exception: + logger.error(" Finalizer DP : erreur", exc_info=True) + dossier.quality_flags["dp_finalizer"] = "error" + + +def process_document( + file_path: Path, + defer_qc: bool = False, +) -> list[tuple[str, DossierMedical, AnonymizationReport]]: """Traite un document : extraction → splitting → parsing → anonymisation → extraction CIM-10. Supporte PDF, images (JPEG/PNG/TIFF) et DOCX via le router d'extraction. + Args: + file_path: Chemin du document à traiter. + defer_qc: Si True, ne pas exécuter la validation QC (sera faite en batch plus tard). + Retourne une liste de (texte_anonymisé, dossier, rapport) — un par dossier détecté. """ t0 = time.time() @@ -200,6 +331,7 @@ def process_document(file_path: Path) -> list[tuple[str, DossierMedical, Anonymi dossier = extract_medical_info( parsed, anonymized_text, edsnlp_result, use_rag=_use_rag, page_tracker=page_tracker, raw_text=raw_text, + defer_qc=defer_qc, ) dossier.source_file = file_path.name dossier.document_type = doc_type @@ -213,86 +345,8 @@ def process_document(file_path: Path) -> list[tuple[str, DossierMedical, Anonymi if extraction_alert: dossier.alertes_codage.append(extraction_alert) - # 8. Vetos (contestabilité) + décisions (post-traitement) - # Routage des règles (packs) : par défaut, on garde le socle vetos/decisions, - # et on active des packs additionnels selon les signaux du dossier (codes/labs/extraits). - rules_token = None - try: - rules_ctx = build_rules_runtime_context(dossier) - dossier.rules_runtime = rules_ctx - rules_token = set_rules_runtime(rules_ctx) - - packs = ",".join(rules_ctx.get("enabled_packs", [])) - if packs: - logger.info(" Règles%s : packs=%s", part_label, packs) - if rules_ctx.get("triggers_fired"): - logger.info(" Règles%s : triggers=%s", part_label, ",".join(rules_ctx["triggers_fired"])) - except Exception: - logger.error(" Routage règles : erreur", exc_info=True) - dossier.quality_flags["rules_routing"] = "error" - - veto = None - try: - veto = apply_vetos(dossier) - dossier.veto_report = veto - except Exception: - logger.error(" Vetos : erreur lors du contrôle", exc_info=True) - dossier.quality_flags["veto_engine"] = "error" - dossier.alertes_codage.append("QUALITE DEGRADEE : moteur de vetos en erreur") - - try: - apply_decisions(dossier) - _inject_decision_alerts(dossier, scope="PDF") - if veto is not None: - _inject_veto_alerts(dossier, veto, scope="PDF") - except Exception: - logger.error(" Décisions : erreur lors du post-traitement", exc_info=True) - dossier.quality_flags["decision_engine"] = "error" - finally: - if rules_token is not None: - reset_rules_runtime(rules_token) - - try: - dossier.completude = build_completude_checklist(dossier) - except Exception: - logger.error(" Complétude : erreur lors du contrôle", exc_info=True) - dossier.quality_flags["completude"] = "error" - - # 9. Estimation GHM (sur codes finaux) + métriques (actifs vs écartés) - try: - metrics = _compute_metrics(dossier) - ghm = estimate_ghm(dossier) - dossier.ghm_estimation = ghm - - logger.info( - " DAS : actifs=%d / total=%d (écartés=%d, removed=%d, no_code=%d) | Actes : %d (avec code=%d)", - metrics.das_active, - metrics.das_total, - metrics.das_excluded, - metrics.das_removed, - metrics.das_no_code, - metrics.actes_total, - metrics.actes_with_code, - ) - logger.info( - " GHM : CMD=%s, Type=%s, Sévérité=%d → %s", - ghm.cmd or "?", - ghm.type_ghm or "?", - ghm.severite, - ghm.ghm_approx or "?", - ) - except Exception: - logger.error(" Erreur estimation GHM/metrics", exc_info=True) - dossier.quality_flags["ghm_estimation"] = "error" - dossier.alertes_codage.append("QUALITE DEGRADEE : estimation GHM en erreur") - - # 10. Finalizer DP (arbitrage Trackare vs CRH, traçabilité) - try: - from .medical.dp_finalizer import finalize_dp - finalize_dp(dossier) - except Exception: - logger.error(" Finalizer DP : erreur", exc_info=True) - dossier.quality_flags["dp_finalizer"] = "error" + # 8-10. Vetos + Décisions + GHM + Finalizer DP + _postprocess_dossier(dossier, parsed) dossier.processing_time_s = round(time.time() - t0, 2) results.append((anonymized_text, dossier, report)) @@ -569,22 +623,82 @@ def main(input_path: str | None = None) -> None: logger.info("Traitement de %d document(s)...", total) def _process_group(docs: list[Path], subdir: str | None) -> None: - """Traite un groupe de documents (un dossier patient).""" + """Traite un groupe de documents (un dossier patient). + + Optimisations appliquées : + - defer_qc=True : la validation QC est différée et batchée après tous les documents + du groupe (1 swap modèle au lieu de 2 par document) + - Si la machine secondaire Dell est disponible, le QC est envoyé là-bas + (0 swap sur le GPU principal) + """ if subdir: logger.info("--- Dossier %s (%d documents) ---", subdir, len(docs)) group_dossiers: list[DossierMedical] = [] - for doc_path in docs: - try: - doc_results = process_document(doc_path) - stem = doc_path.stem.replace(" ", "_") - multi = len(doc_results) > 1 - for part_idx, (anonymized_text, dossier, report) in enumerate(doc_results): - part_stem = f"{stem}_part{part_idx + 1}" if multi else stem - write_outputs(part_stem, anonymized_text, dossier, report, subdir=subdir, export_rum_flag=export_rum_flag) - group_dossiers.append(dossier) - except Exception: - logger.exception("Erreur lors du traitement de %s", doc_path.name) + + # Pipeline CPU/GPU : préparer le document N+1 (CPU) pendant que le GPU traite le N + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor(max_workers=1, thread_name_prefix="cpu-prep") as cpu_pool: + # Soumettre la préparation du premier document + prep_future = cpu_pool.submit(prepare_document, docs[0]) if docs else None + + for i, doc_path in enumerate(docs): + try: + # Attendre la préparation du document courant + prep = prep_future.result() + + # Lancer la préparation du document suivant EN PARALLÈLE du GPU ci-dessous + if i + 1 < len(docs): + prep_future = cpu_pool.submit(prepare_document, docs[i + 1]) + + # Phase GPU : extract_medical_info + post-processing + t0 = time.time() + file_path_p, doc_type, raw_text, page_tracker, extraction_stats, prepared_chunks = prep + stem = doc_path.stem.replace(" ", "_") + multi = len(prepared_chunks) > 1 + + for part_idx, (parsed, anonymized_text, report, edsnlp_result, chunk_text) in enumerate(prepared_chunks): + part_stem = f"{stem}_part{part_idx + 1}" if multi else stem + + dossier = extract_medical_info( + parsed, anonymized_text, edsnlp_result, use_rag=_use_rag, + page_tracker=page_tracker, raw_text=raw_text, defer_qc=True, + ) + dossier.source_file = doc_path.name + dossier.document_type = doc_type + logger.info(" DP : %s", dossier.diagnostic_principal) + + # Injection des stats d'extraction + extraction_flags = extraction_stats.to_flags() + if extraction_flags: + dossier.quality_flags.update(extraction_flags) + extraction_alert = extraction_stats.to_alert() + if extraction_alert: + dossier.alertes_codage.append(extraction_alert) + + # Vetos + Décisions + Complétude + GHM + Finalizer + _postprocess_dossier(dossier, parsed, subdir=subdir) + + dossier.processing_time_s = round(time.time() - t0, 2) + write_outputs(part_stem, anonymized_text, dossier, report, subdir=subdir, export_rum_flag=export_rum_flag) + group_dossiers.append(dossier) + + logger.info(" Temps total %s : %.2fs", doc_path.name, time.time() - t0) + except Exception: + logger.exception("Erreur lors du traitement de %s", doc_path.name) + + # Batch QC : validation justifications pour tous les dossiers du groupe en une seule passe + # Évite les swaps coding ↔ qc entre chaque document + if _use_rag and group_dossiers: + t_qc = time.time() + from .medical.validation_pipeline import _validate_justifications + for d in group_dossiers: + try: + _validate_justifications(d) + except Exception: + logger.warning("QC batch : erreur pour %s", d.source_file, exc_info=True) + logger.info(" ⏱ [QC-BATCH] %.1fs — %d dossiers validés", time.time() - t_qc, len(group_dossiers)) # Fusion multi-PDFs si plusieurs documents dans le même groupe merged = None diff --git a/src/medical/cim10_extractor.py b/src/medical/cim10_extractor.py index 4c6f3d5..7375d48 100644 --- a/src/medical/cim10_extractor.py +++ b/src/medical/cim10_extractor.py @@ -65,12 +65,14 @@ def extract_medical_info( use_rag: bool = False, page_tracker=None, raw_text: str | None = None, + defer_qc: bool = False, ) -> DossierMedical: """Extrait les informations médicales structurées depuis les données parsées et le texte. Args: page_tracker: PageTracker pour la traçabilité page/extrait (optionnel). raw_text: Texte brut avant anonymisation (pour recherche page source). + defer_qc: Si True, ne pas exécuter la validation QC (sera faite en batch plus tard). """ dossier = DossierMedical() dossier.document_type = parsed_data.get("type", "") @@ -270,8 +272,9 @@ def extract_medical_info( except Exception: logger.error("NUKE-3 reselect après vetos échouée", exc_info=True) - # Post-processing : validation justifications (QC batch) - if use_rag: + # Post-processing : validation justifications (QC) + # Si defer_qc=True, le QC sera fait en batch après tous les dossiers (évite les swaps modèle) + if use_rag and not defer_qc: t_qc = time.monotonic() _validate_justifications(dossier) elapsed_qc = time.monotonic() - t_qc diff --git a/src/medical/ollama_client.py b/src/medical/ollama_client.py index b315af7..5b2c70f 100644 --- a/src/medical/ollama_client.py +++ b/src/medical/ollama_client.py @@ -9,7 +9,7 @@ import time import requests -from ..config import OLLAMA_URL, OLLAMA_MODEL, OLLAMA_TIMEOUT, get_model +from ..config import OLLAMA_URL, OLLAMA_URL_SECONDARY, OLLAMA_MODEL, OLLAMA_TIMEOUT, get_model logger = logging.getLogger(__name__) @@ -146,6 +146,24 @@ def parse_json_response(raw: str) -> dict | None: return None +def _is_secondary_available() -> bool: + """Vérifie si la machine Ollama secondaire est accessible (cache 60s).""" + now = time.monotonic() + if hasattr(_is_secondary_available, "_cache"): + cached_time, cached_result = _is_secondary_available._cache + if now - cached_time < 60: + return cached_result + try: + r = requests.get(f"{OLLAMA_URL_SECONDARY}/api/tags", timeout=2) + result = r.status_code == 200 + except Exception: + result = False + _is_secondary_available._cache = (now, result) + if result: + logger.info("Ollama secondaire disponible : %s", OLLAMA_URL_SECONDARY) + return result + + def call_ollama( prompt: str, temperature: float = 0.1, @@ -153,6 +171,7 @@ def call_ollama( model: str | None = None, timeout: int | None = None, role: str | None = None, + ollama_url: str | None = None, ) -> dict | None: """Appelle Ollama en mode JSON natif, avec fallback Anthropic si indisponible. @@ -163,12 +182,14 @@ def call_ollama( model: Modèle Ollama à utiliser (prioritaire sur role). timeout: Timeout en secondes (défaut: OLLAMA_TIMEOUT global). role: Rôle LLM (coding, cpam, validation, qc) → résolu via get_model(). + ollama_url: URL Ollama à utiliser (prioritaire sur OLLAMA_URL global). Returns: Le dict JSON parsé, ou None en cas d'erreur. """ use_model = model or (get_model(role) if role else OLLAMA_MODEL) use_timeout = timeout or OLLAMA_TIMEOUT + use_url = ollama_url or OLLAMA_URL messages: list[dict] = [{"role": "user", "content": prompt}] @@ -186,7 +207,7 @@ def call_ollama( }, } response = requests.post( - f"{OLLAMA_URL}/api/chat", + f"{use_url}/api/chat", json=payload, timeout=use_timeout, ) diff --git a/src/medical/validation_pipeline.py b/src/medical/validation_pipeline.py index 6cc294e..6fd06f0 100644 --- a/src/medical/validation_pipeline.py +++ b/src/medical/validation_pipeline.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +import time from .cim10_dict import lookup as dict_lookup, normalize_code, validate_code as cim10_validate from .ccam_dict import validate_code as ccam_validate @@ -408,8 +409,18 @@ def _validate_justifications(dossier: DossierMedical) -> None: from ..prompts import QC_VALIDATION prompt = QC_VALIDATION.format(ctx_str=ctx_str, codes_section=codes_section) + # Si la machine secondaire est disponible, envoyer le QC là-bas + # pour éviter un swap de modèle sur le GPU principal + from .ollama_client import _is_secondary_available + from ..config import OLLAMA_URL_SECONDARY + qc_url = OLLAMA_URL_SECONDARY if _is_secondary_available() else None + try: - result = call_ollama(prompt, temperature=0.1, max_tokens=2500, role="qc") + t0 = time.time() + result = call_ollama(prompt, temperature=0.1, max_tokens=2500, role="qc", ollama_url=qc_url) + elapsed = time.time() - t0 + target = "secondaire" if qc_url else "local" + logger.info("⏱ [QC] %.1fs — validation QC (%s)", elapsed, target) except Exception: logger.warning("Erreur lors de l'appel Ollama pour validation QC", exc_info=True) return