feat: parallélisation pipeline --workers N (ThreadPoolExecutor)

- Fix thread-safety FAISS index (Lock + double-check sur _loaded)
- Fix thread-safety reranker (Lock + double-check sur _reranker_model)
- main.py : flag --workers, extraction _process_group(), ThreadPoolExecutor
- benchmark_quality.py : flag --workers, subprocess en parallèle
- Validé sur 10 dossiers gold standard --workers 3 : 0 crash, codes identiques

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
dom
2026-02-20 01:30:51 +01:00
parent 0b94299975
commit 5cf7d74fa3
4 changed files with 109 additions and 37 deletions

View File

@@ -594,6 +594,7 @@ def main():
parser.add_argument("--no-reprocess", action="store_true", help="Analyser les outputs existants sans relancer le pipeline") parser.add_argument("--no-reprocess", action="store_true", help="Analyser les outputs existants sans relancer le pipeline")
parser.add_argument("--clean", action="store_true", help="Supprimer les outputs avant retraitement") parser.add_argument("--clean", action="store_true", help="Supprimer les outputs avant retraitement")
parser.add_argument("--seed", type=int, default=42, help="Seed pour la sélection aléatoire") parser.add_argument("--seed", type=int, default=42, help="Seed pour la sélection aléatoire")
parser.add_argument("--workers", type=int, default=1, help="Nombre de dossiers traités en parallèle")
args = parser.parse_args() args = parser.parse_args()
# Sélection dossiers # Sélection dossiers
@@ -632,23 +633,55 @@ def main():
# Traitement # Traitement
per_dossier = [] per_dossier = []
for i, dossier_id in enumerate(dossiers, 1): total = len(dossiers)
print(f" [{i}/{len(dossiers)}] {dossier_id}", end="", flush=True)
if args.no_reprocess: if args.workers > 1 and not args.no_reprocess:
duration = 0.0 # Mode parallèle : exécuter les pipelines en parallèle puis analyser
success = find_merged_json(dossier_id) is not None from concurrent.futures import ThreadPoolExecutor, as_completed
if not success: print(f" Mode parallèle : {args.workers} workers")
print(" — pas de JSON") pipeline_results: dict[str, tuple[float, bool]] = {}
done = 0
with ThreadPoolExecutor(max_workers=args.workers) as executor:
futures = {
executor.submit(run_pipeline, dossier_id, args.clean): dossier_id
for dossier_id in dossiers
}
for future in as_completed(futures):
dossier_id = futures[future]
try:
duration, success = future.result()
except Exception as e:
print(f" EXCEPTION {dossier_id}: {e}")
duration, success = 0.0, False
pipeline_results[dossier_id] = (duration, success)
done += 1
mark = "" if success else ""
print(f" [{done}/{total}] {dossier_id}{duration:.1f}s {mark}")
# Analyse séquentielle (ordre stable)
for dossier_id in dossiers:
duration, success = pipeline_results[dossier_id]
metrics = analyze_dossier(dossier_id, cim10, duration)
per_dossier.append(metrics)
else:
# Mode séquentiel (ou --no-reprocess)
for i, dossier_id in enumerate(dossiers, 1):
print(f" [{i}/{total}] {dossier_id}", end="", flush=True)
if args.no_reprocess:
duration = 0.0
success = find_merged_json(dossier_id) is not None
if not success:
print(" — pas de JSON")
else:
print(" — analyse existant")
else: else:
print("analyse existant") print("traitement...", end="", flush=True)
else: duration, success = run_pipeline(dossier_id, args.clean)
print(" — traitement...", end="", flush=True) print(f" {duration:.1f}s {'' if success else ''}")
duration, success = run_pipeline(dossier_id, args.clean)
print(f" {duration:.1f}s {'' if success else ''}")
metrics = analyze_dossier(dossier_id, cim10, duration) metrics = analyze_dossier(dossier_id, cim10, duration)
per_dossier.append(metrics) per_dossier.append(metrics)
# Agrégation # Agrégation
agg = compute_aggregate(per_dossier) agg = compute_aggregate(per_dossier)

View File

@@ -399,6 +399,12 @@ def main(input_path: str | None = None) -> None:
metavar="PATH", metavar="PATH",
help="Fichier Excel de contrôle CPAM (enrichit les dossiers avec contre-argumentation)", help="Fichier Excel de contrôle CPAM (enrichit les dossiers avec contre-argumentation)",
) )
parser.add_argument(
"--workers",
type=int,
default=1,
help="Nombre de dossiers traités en parallèle (défaut: 1)",
)
args = parser.parse_args() args = parser.parse_args()
if args.build_dict: if args.build_dict:
@@ -501,7 +507,8 @@ def main(input_path: str | None = None) -> None:
logger.info("Traitement de %d PDF(s)...", total) logger.info("Traitement de %d PDF(s)...", total)
for pdfs, subdir in groups: def _process_group(pdfs: list[Path], subdir: str | None) -> None:
"""Traite un groupe de PDFs (un dossier patient)."""
if subdir: if subdir:
logger.info("--- Dossier %s (%d PDFs) ---", subdir, len(pdfs)) logger.info("--- Dossier %s (%d PDFs) ---", subdir, len(pdfs))
@@ -633,6 +640,24 @@ def main(input_path: str | None = None) -> None:
except Exception: except Exception:
logger.exception("Erreur écriture dossier fusionné %s", subdir) logger.exception("Erreur écriture dossier fusionné %s", subdir)
# Exécution séquentielle ou parallèle selon --workers
if args.workers > 1:
from concurrent.futures import ThreadPoolExecutor, as_completed
logger.info("Mode parallèle : %d workers", args.workers)
with ThreadPoolExecutor(max_workers=args.workers) as executor:
futures = {
executor.submit(_process_group, pdfs, subdir): subdir
for pdfs, subdir in groups
}
for future in as_completed(futures):
try:
future.result()
except Exception:
logger.exception("Erreur groupe %s", futures[future])
else:
for pdfs, subdir in groups:
_process_group(pdfs, subdir)
logger.info("Terminé.") logger.info("Terminé.")

View File

@@ -14,6 +14,7 @@ from __future__ import annotations
import json import json
import logging import logging
import re import re
import threading
from dataclasses import dataclass, asdict from dataclasses import dataclass, asdict
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@@ -26,6 +27,7 @@ logger = logging.getLogger(__name__)
# Singletons pour les index chargés en mémoire # Singletons pour les index chargés en mémoire
_loaded: dict[str, tuple] = {} _loaded: dict[str, tuple] = {}
_loaded_lock = threading.Lock()
@dataclass @dataclass
@@ -577,30 +579,35 @@ def get_index(kind: str = "ref") -> tuple | None:
if kind in _loaded: if kind in _loaded:
return _loaded[kind] return _loaded[kind]
index_path, meta_path = _paths(kind) with _loaded_lock:
# Double-check après acquisition du lock
if kind in _loaded:
return _loaded[kind]
# Backwards compat : si ref/proc absent, fallback sur all index_path, meta_path = _paths(kind)
if kind in ("ref", "proc") and (not index_path.exists() or not meta_path.exists()):
legacy_idx, legacy_meta = _paths("all") # Backwards compat : si ref/proc absent, fallback sur all
if legacy_idx.exists() and legacy_meta.exists(): if kind in ("ref", "proc") and (not index_path.exists() or not meta_path.exists()):
logger.warning("Index %s absent — fallback legacy faiss.index", kind) legacy_idx, legacy_meta = _paths("all")
index_path, meta_path = legacy_idx, legacy_meta if legacy_idx.exists() and legacy_meta.exists():
else: logger.warning("Index %s absent — fallback legacy faiss.index", kind)
logger.warning("Index FAISS non trouvé dans %s — lancez build_index() d'abord", RAG_INDEX_DIR) index_path, meta_path = legacy_idx, legacy_meta
else:
logger.warning("Index FAISS non trouvé dans %s — lancez build_index() d'abord", RAG_INDEX_DIR)
return None
if not index_path.exists() or not meta_path.exists():
logger.warning("Index FAISS non trouvé (%s) dans %s — lancez build_index() d'abord", kind, RAG_INDEX_DIR)
return None return None
if not index_path.exists() or not meta_path.exists(): import faiss
logger.warning("Index FAISS non trouvé (%s) dans %s — lancez build_index() d'abord", kind, RAG_INDEX_DIR)
return None
import faiss faiss_index = faiss.read_index(str(index_path))
metadata = json.loads(meta_path.read_text(encoding="utf-8"))
faiss_index = faiss.read_index(str(index_path)) logger.info("Index FAISS chargé (%s) : %d vecteurs", kind, faiss_index.ntotal)
metadata = json.loads(meta_path.read_text(encoding="utf-8")) _loaded[kind] = (faiss_index, metadata)
return _loaded[kind]
logger.info("Index FAISS chargé (%s) : %d vecteurs", kind, faiss_index.ntotal)
_loaded[kind] = (faiss_index, metadata)
return _loaded[kind]
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -800,4 +807,5 @@ def add_chunks_to_index(chunks: list[Chunk]) -> int:
def reset_index() -> None: def reset_index() -> None:
"""Invalide les singletons FAISS pour forcer le rechargement au prochain accès.""" """Invalide les singletons FAISS pour forcer le rechargement au prochain accès."""
_loaded.clear() with _loaded_lock:
_loaded.clear()

View File

@@ -28,6 +28,7 @@ _embed_failed = False # Sentinelle pour éviter les retries infinis
# Singleton pour le cross-encoder de re-ranking (CPU uniquement) # Singleton pour le cross-encoder de re-ranking (CPU uniquement)
_reranker_model = None _reranker_model = None
_reranker_lock = threading.Lock()
# Score minimum de similarité FAISS pour retenir un résultat # Score minimum de similarité FAISS pour retenir un résultat
_MIN_SCORE = 0.3 _MIN_SCORE = 0.3
@@ -84,12 +85,17 @@ def _get_embed_model():
def _get_reranker(): def _get_reranker():
"""Charge le cross-encoder de re-ranking (singleton, CPU uniquement). """Charge le cross-encoder de re-ranking (singleton thread-safe, CPU uniquement).
Forcé sur CPU pour ne pas interférer avec Ollama sur GPU. Forcé sur CPU pour ne pas interférer avec Ollama sur GPU.
""" """
global _reranker_model global _reranker_model
if _reranker_model is None: if _reranker_model is not None:
return _reranker_model
with _reranker_lock:
# Double-check après acquisition du lock
if _reranker_model is not None:
return _reranker_model
from sentence_transformers import CrossEncoder from sentence_transformers import CrossEncoder
logger.info("Chargement du cross-encoder de re-ranking (cpu)...") logger.info("Chargement du cross-encoder de re-ranking (cpu)...")
_reranker_model = CrossEncoder(RERANKER_MODEL, device="cpu") _reranker_model = CrossEncoder(RERANKER_MODEL, device="cpu")