From 5cf7d74fa37b297d5e3582381421bf5d0c759bd5 Mon Sep 17 00:00:00 2001
From: dom
Date: Fri, 20 Feb 2026 01:30:51 +0100
Subject: [PATCH] feat: parallelize the pipeline with --workers N (ThreadPoolExecutor)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Fix FAISS index thread-safety (Lock + double-check on _loaded)
- Fix reranker thread-safety (Lock + double-check on _reranker_model)
- main.py: --workers flag, extracted _process_group(), ThreadPoolExecutor
- benchmark_quality.py: --workers flag, subprocesses run in parallel
- Validated on 10 gold-standard dossiers with --workers 3: 0 crashes, identical codes

Co-Authored-By: Claude Opus 4.6
---
 scripts/benchmark_quality.py | 61 +++++++++++++++++++++++++++---------
 src/main.py                  | 27 +++++++++++++++-
 src/medical/rag_index.py     | 48 ++++++++++++++++------------
 src/medical/rag_search.py    | 10 ++++--
 4 files changed, 109 insertions(+), 37 deletions(-)

diff --git a/scripts/benchmark_quality.py b/scripts/benchmark_quality.py
index f634b5b..c6fb564 100644
--- a/scripts/benchmark_quality.py
+++ b/scripts/benchmark_quality.py
@@ -594,6 +594,7 @@ def main():
     parser.add_argument("--no-reprocess", action="store_true", help="Analyser les outputs existants sans relancer le pipeline")
     parser.add_argument("--clean", action="store_true", help="Supprimer les outputs avant retraitement")
     parser.add_argument("--seed", type=int, default=42, help="Seed pour la sélection aléatoire")
+    parser.add_argument("--workers", type=int, default=1, help="Nombre de dossiers traités en parallèle")
     args = parser.parse_args()
 
     # Sélection dossiers
@@ -632,23 +633,55 @@ def main():
 
     # Traitement
     per_dossier = []
-    for i, dossier_id in enumerate(dossiers, 1):
-        print(f" [{i}/{len(dossiers)}] {dossier_id}", end="", flush=True)
+    total = len(dossiers)
 
-        if args.no_reprocess:
-            duration = 0.0
-            success = find_merged_json(dossier_id) is not None
-            if not success:
-                print(" — pas de JSON")
+    if args.workers > 1 and not args.no_reprocess:
+        # Mode parallèle : exécuter les pipelines en parallèle puis analyser
+        from concurrent.futures import ThreadPoolExecutor, as_completed
+        print(f" Mode parallèle : {args.workers} workers")
+        pipeline_results: dict[str, tuple[float, bool]] = {}
+        done = 0
+        with ThreadPoolExecutor(max_workers=args.workers) as executor:
+            futures = {
+                executor.submit(run_pipeline, dossier_id, args.clean): dossier_id
+                for dossier_id in dossiers
+            }
+            for future in as_completed(futures):
+                dossier_id = futures[future]
+                try:
+                    duration, success = future.result()
+                except Exception as e:
+                    print(f" EXCEPTION {dossier_id}: {e}")
+                    duration, success = 0.0, False
+                pipeline_results[dossier_id] = (duration, success)
+                done += 1
+                mark = "✓" if success else "✗"
+                print(f" [{done}/{total}] {dossier_id} — {duration:.1f}s {mark}")
+
+        # Analyse séquentielle (ordre stable)
+        for dossier_id in dossiers:
+            duration, success = pipeline_results[dossier_id]
+            metrics = analyze_dossier(dossier_id, cim10, duration)
+            per_dossier.append(metrics)
+    else:
+        # Mode séquentiel (ou --no-reprocess)
+        for i, dossier_id in enumerate(dossiers, 1):
+            print(f" [{i}/{total}] {dossier_id}", end="", flush=True)
+
+            if args.no_reprocess:
+                duration = 0.0
+                success = find_merged_json(dossier_id) is not None
+                if not success:
+                    print(" — pas de JSON")
+                else:
+                    print(" — analyse existant")
             else:
-                print(" — analyse existant")
-        else:
-            print(" — traitement...", end="", flush=True)
-            duration, success = run_pipeline(dossier_id, args.clean)
-            print(f" {duration:.1f}s {'✓' if success else '✗'}")
+                print(" — traitement...", end="", flush=True)
+                duration, success = run_pipeline(dossier_id, args.clean)
+                print(f" {duration:.1f}s {'✓' if success else '✗'}")
 
-        metrics = analyze_dossier(dossier_id, cim10, duration)
-        per_dossier.append(metrics)
+            metrics = analyze_dossier(dossier_id, cim10, duration)
+            per_dossier.append(metrics)
 
     # Agrégation
     agg = compute_aggregate(per_dossier)
diff --git a/src/main.py b/src/main.py
index 6f86f24..cb66880 100644
--- a/src/main.py
+++ b/src/main.py
@@ -399,6 +399,12 @@ def main(input_path: str | None = None) -> None:
         metavar="PATH",
         help="Fichier Excel de contrôle CPAM (enrichit les dossiers avec contre-argumentation)",
     )
+    parser.add_argument(
+        "--workers",
+        type=int,
+        default=1,
+        help="Nombre de dossiers traités en parallèle (défaut: 1)",
+    )
     args = parser.parse_args()
 
     if args.build_dict:
@@ -501,7 +507,8 @@ def main(input_path: str | None = None) -> None:
 
     logger.info("Traitement de %d PDF(s)...", total)
 
-    for pdfs, subdir in groups:
+    def _process_group(pdfs: list[Path], subdir: str | None) -> None:
+        """Traite un groupe de PDFs (un dossier patient)."""
         if subdir:
             logger.info("--- Dossier %s (%d PDFs) ---", subdir, len(pdfs))
 
@@ -633,6 +640,24 @@ def main(input_path: str | None = None) -> None:
         except Exception:
             logger.exception("Erreur écriture dossier fusionné %s", subdir)
 
+    # Exécution séquentielle ou parallèle selon --workers
+    if args.workers > 1:
+        from concurrent.futures import ThreadPoolExecutor, as_completed
+        logger.info("Mode parallèle : %d workers", args.workers)
+        with ThreadPoolExecutor(max_workers=args.workers) as executor:
+            futures = {
+                executor.submit(_process_group, pdfs, subdir): subdir
+                for pdfs, subdir in groups
+            }
+            for future in as_completed(futures):
+                try:
+                    future.result()
+                except Exception:
+                    logger.exception("Erreur groupe %s", futures[future])
+    else:
+        for pdfs, subdir in groups:
+            _process_group(pdfs, subdir)
+
     logger.info("Terminé.")
 
 
diff --git a/src/medical/rag_index.py b/src/medical/rag_index.py
index fa37576..fabeebf 100644
--- a/src/medical/rag_index.py
+++ b/src/medical/rag_index.py
@@ -14,6 +14,7 @@ from __future__ import annotations
 import json
 import logging
 import re
+import threading
 from dataclasses import dataclass, asdict
 from pathlib import Path
 from typing import Optional
@@ -26,6 +27,7 @@ logger = logging.getLogger(__name__)
 
 # Singletons pour les index chargés en mémoire
 _loaded: dict[str, tuple] = {}
+_loaded_lock = threading.Lock()
 
 
 @dataclass
@@ -577,30 +579,35 @@ def get_index(kind: str = "ref") -> tuple | None:
     if kind in _loaded:
         return _loaded[kind]
 
-    index_path, meta_path = _paths(kind)
+    with _loaded_lock:
+        # Double-check après acquisition du lock
+        if kind in _loaded:
+            return _loaded[kind]
 
-    # Backwards compat : si ref/proc absent, fallback sur all
-    if kind in ("ref", "proc") and (not index_path.exists() or not meta_path.exists()):
-        legacy_idx, legacy_meta = _paths("all")
-        if legacy_idx.exists() and legacy_meta.exists():
-            logger.warning("Index %s absent — fallback legacy faiss.index", kind)
-            index_path, meta_path = legacy_idx, legacy_meta
-        else:
-            logger.warning("Index FAISS non trouvé dans %s — lancez build_index() d'abord", RAG_INDEX_DIR)
+        index_path, meta_path = _paths(kind)
+
+        # Backwards compat : si ref/proc absent, fallback sur all
+        if kind in ("ref", "proc") and (not index_path.exists() or not meta_path.exists()):
+            legacy_idx, legacy_meta = _paths("all")
+            if legacy_idx.exists() and legacy_meta.exists():
+                logger.warning("Index %s absent — fallback legacy faiss.index", kind)
+                index_path, meta_path = legacy_idx, legacy_meta
+            else:
+                logger.warning("Index FAISS non trouvé dans %s — lancez build_index() d'abord", RAG_INDEX_DIR)
+                return None
+
+        if not index_path.exists() or not meta_path.exists():
+            logger.warning("Index FAISS non trouvé (%s) dans %s — lancez build_index() d'abord", kind, RAG_INDEX_DIR)
             return None
 
-    if not index_path.exists() or not meta_path.exists():
-        logger.warning("Index FAISS non trouvé (%s) dans %s — lancez build_index() d'abord", kind, RAG_INDEX_DIR)
-        return None
+        import faiss
 
-    import faiss
+        faiss_index = faiss.read_index(str(index_path))
+        metadata = json.loads(meta_path.read_text(encoding="utf-8"))
 
-    faiss_index = faiss.read_index(str(index_path))
-    metadata = json.loads(meta_path.read_text(encoding="utf-8"))
-
-    logger.info("Index FAISS chargé (%s) : %d vecteurs", kind, faiss_index.ntotal)
-    _loaded[kind] = (faiss_index, metadata)
-    return _loaded[kind]
+        logger.info("Index FAISS chargé (%s) : %d vecteurs", kind, faiss_index.ntotal)
+        _loaded[kind] = (faiss_index, metadata)
+        return _loaded[kind]
 
 
 # ---------------------------------------------------------------------------
@@ -800,4 +807,5 @@ def add_chunks_to_index(chunks: list[Chunk]) -> int:
 
 def reset_index() -> None:
     """Invalide les singletons FAISS pour forcer le rechargement au prochain accès."""
-    _loaded.clear()
+    with _loaded_lock:
+        _loaded.clear()
diff --git a/src/medical/rag_search.py b/src/medical/rag_search.py
index 21cbcd6..a3b207b 100644
--- a/src/medical/rag_search.py
+++ b/src/medical/rag_search.py
@@ -28,6 +28,7 @@ _embed_failed = False  # Sentinelle pour éviter les retries infinis
 
 # Singleton pour le cross-encoder de re-ranking (CPU uniquement)
 _reranker_model = None
+_reranker_lock = threading.Lock()
 
 # Score minimum de similarité FAISS pour retenir un résultat
 _MIN_SCORE = 0.3
@@ -84,12 +85,17 @@ def _get_embed_model():
 
 
 def _get_reranker():
-    """Charge le cross-encoder de re-ranking (singleton, CPU uniquement).
+    """Charge le cross-encoder de re-ranking (singleton thread-safe, CPU uniquement).
 
     Forcé sur CPU pour ne pas interférer avec Ollama sur GPU.
     """
     global _reranker_model
-    if _reranker_model is None:
+    if _reranker_model is not None:
+        return _reranker_model
+    with _reranker_lock:
+        # Double-check après acquisition du lock
+        if _reranker_model is not None:
+            return _reranker_model
         from sentence_transformers import CrossEncoder
         logger.info("Chargement du cross-encoder de re-ranking (cpu)...")
         _reranker_model = CrossEncoder(RERANKER_MODEL, device="cpu")
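
Note: both thread-safety fixes above use the same double-checked locking idiom: a lock-free fast path, then a re-check after acquiring the lock so that only one thread performs the expensive load. For reference, a minimal standalone sketch of that pattern; the names (_resource, get_resource) are generic placeholders, not identifiers from this codebase:

    import threading

    _resource = None                   # lazily-created singleton (stands in for the FAISS index or CrossEncoder)
    _resource_lock = threading.Lock()


    def get_resource():
        """Return the shared resource, creating it at most once even with concurrent callers."""
        global _resource
        # Fast path: already initialised, no locking needed.
        if _resource is not None:
            return _resource
        with _resource_lock:
            # Double-check: another worker may have initialised it while we waited for the lock.
            if _resource is not None:
                return _resource
            _resource = object()       # stand-in for the expensive load
            return _resource

Without the second check inside the lock, two workers that both observe None before either acquires the lock would each perform the load; with it, the common already-loaded case stays lock-free and the load happens exactly once.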