feat: parallélisation pipeline --workers N (ThreadPoolExecutor)
- Fix thread-safety FAISS index (Lock + double-check sur _loaded)
- Fix thread-safety reranker (Lock + double-check sur _reranker_model)
- main.py : flag --workers, extraction _process_group(), ThreadPoolExecutor
- benchmark_quality.py : flag --workers, subprocess en parallèle
- Validé sur 10 dossiers gold standard --workers 3 : 0 crash, codes identiques

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -594,6 +594,7 @@ def main():
|
|||||||
parser.add_argument("--no-reprocess", action="store_true", help="Analyser les outputs existants sans relancer le pipeline")
|
parser.add_argument("--no-reprocess", action="store_true", help="Analyser les outputs existants sans relancer le pipeline")
|
||||||
parser.add_argument("--clean", action="store_true", help="Supprimer les outputs avant retraitement")
|
parser.add_argument("--clean", action="store_true", help="Supprimer les outputs avant retraitement")
|
||||||
parser.add_argument("--seed", type=int, default=42, help="Seed pour la sélection aléatoire")
|
parser.add_argument("--seed", type=int, default=42, help="Seed pour la sélection aléatoire")
|
||||||
|
parser.add_argument("--workers", type=int, default=1, help="Nombre de dossiers traités en parallèle")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Sélection dossiers
|
# Sélection dossiers
|
||||||
@@ -632,23 +633,55 @@ def main():
|
|||||||
|
|
||||||
# Traitement
|
# Traitement
|
||||||
per_dossier = []
|
per_dossier = []
|
||||||
for i, dossier_id in enumerate(dossiers, 1):
|
total = len(dossiers)
|
||||||
print(f" [{i}/{len(dossiers)}] {dossier_id}", end="", flush=True)
|
|
||||||
|
|
||||||
if args.no_reprocess:
|
if args.workers > 1 and not args.no_reprocess:
|
||||||
duration = 0.0
|
# Mode parallèle : exécuter les pipelines en parallèle puis analyser
|
||||||
success = find_merged_json(dossier_id) is not None
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
if not success:
|
print(f" Mode parallèle : {args.workers} workers")
|
||||||
print(" — pas de JSON")
|
pipeline_results: dict[str, tuple[float, bool]] = {}
|
||||||
|
done = 0
|
||||||
|
with ThreadPoolExecutor(max_workers=args.workers) as executor:
|
||||||
|
futures = {
|
||||||
|
executor.submit(run_pipeline, dossier_id, args.clean): dossier_id
|
||||||
|
for dossier_id in dossiers
|
||||||
|
}
|
||||||
|
for future in as_completed(futures):
|
||||||
|
dossier_id = futures[future]
|
||||||
|
try:
|
||||||
|
duration, success = future.result()
|
||||||
|
except Exception as e:
|
||||||
|
print(f" EXCEPTION {dossier_id}: {e}")
|
||||||
|
duration, success = 0.0, False
|
||||||
|
pipeline_results[dossier_id] = (duration, success)
|
||||||
|
done += 1
|
||||||
|
mark = "✓" if success else "✗"
|
||||||
|
print(f" [{done}/{total}] {dossier_id} — {duration:.1f}s {mark}")
|
||||||
|
|
||||||
|
# Analyse séquentielle (ordre stable)
|
||||||
|
for dossier_id in dossiers:
|
||||||
|
duration, success = pipeline_results[dossier_id]
|
||||||
|
metrics = analyze_dossier(dossier_id, cim10, duration)
|
||||||
|
per_dossier.append(metrics)
|
||||||
|
else:
|
||||||
|
# Mode séquentiel (ou --no-reprocess)
|
||||||
|
for i, dossier_id in enumerate(dossiers, 1):
|
||||||
|
print(f" [{i}/{total}] {dossier_id}", end="", flush=True)
|
||||||
|
|
||||||
|
if args.no_reprocess:
|
||||||
|
duration = 0.0
|
||||||
|
success = find_merged_json(dossier_id) is not None
|
||||||
|
if not success:
|
||||||
|
print(" — pas de JSON")
|
||||||
|
else:
|
||||||
|
print(" — analyse existant")
|
||||||
else:
|
else:
|
||||||
print(" — analyse existant")
|
print(" — traitement...", end="", flush=True)
|
||||||
else:
|
duration, success = run_pipeline(dossier_id, args.clean)
|
||||||
print(" — traitement...", end="", flush=True)
|
print(f" {duration:.1f}s {'✓' if success else '✗'}")
|
||||||
duration, success = run_pipeline(dossier_id, args.clean)
|
|
||||||
print(f" {duration:.1f}s {'✓' if success else '✗'}")
|
|
||||||
|
|
||||||
metrics = analyze_dossier(dossier_id, cim10, duration)
|
metrics = analyze_dossier(dossier_id, cim10, duration)
|
||||||
per_dossier.append(metrics)
|
per_dossier.append(metrics)
|
||||||
|
|
||||||
# Agrégation
|
# Agrégation
|
||||||
agg = compute_aggregate(per_dossier)
|
agg = compute_aggregate(per_dossier)
|
||||||
|
|||||||
27
src/main.py
27
src/main.py
@@ -399,6 +399,12 @@ def main(input_path: str | None = None) -> None:
|
|||||||
metavar="PATH",
|
metavar="PATH",
|
||||||
help="Fichier Excel de contrôle CPAM (enrichit les dossiers avec contre-argumentation)",
|
help="Fichier Excel de contrôle CPAM (enrichit les dossiers avec contre-argumentation)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--workers",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Nombre de dossiers traités en parallèle (défaut: 1)",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.build_dict:
|
if args.build_dict:
|
||||||
@@ -501,7 +507,8 @@ def main(input_path: str | None = None) -> None:
|
|||||||
|
|
||||||
logger.info("Traitement de %d PDF(s)...", total)
|
logger.info("Traitement de %d PDF(s)...", total)
|
||||||
|
|
||||||
for pdfs, subdir in groups:
|
def _process_group(pdfs: list[Path], subdir: str | None) -> None:
|
||||||
|
"""Traite un groupe de PDFs (un dossier patient)."""
|
||||||
if subdir:
|
if subdir:
|
||||||
logger.info("--- Dossier %s (%d PDFs) ---", subdir, len(pdfs))
|
logger.info("--- Dossier %s (%d PDFs) ---", subdir, len(pdfs))
|
||||||
|
|
||||||
@@ -633,6 +640,24 @@ def main(input_path: str | None = None) -> None:
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Erreur écriture dossier fusionné %s", subdir)
|
logger.exception("Erreur écriture dossier fusionné %s", subdir)
|
||||||
|
|
||||||
|
# Exécution séquentielle ou parallèle selon --workers
|
||||||
|
if args.workers > 1:
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
logger.info("Mode parallèle : %d workers", args.workers)
|
||||||
|
with ThreadPoolExecutor(max_workers=args.workers) as executor:
|
||||||
|
futures = {
|
||||||
|
executor.submit(_process_group, pdfs, subdir): subdir
|
||||||
|
for pdfs, subdir in groups
|
||||||
|
}
|
||||||
|
for future in as_completed(futures):
|
||||||
|
try:
|
||||||
|
future.result()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Erreur groupe %s", futures[future])
|
||||||
|
else:
|
||||||
|
for pdfs, subdir in groups:
|
||||||
|
_process_group(pdfs, subdir)
|
||||||
|
|
||||||
logger.info("Terminé.")
|
logger.info("Terminé.")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
import threading
|
||||||
from dataclasses import dataclass, asdict
|
from dataclasses import dataclass, asdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -26,6 +27,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# Singletons pour les index chargés en mémoire
|
# Singletons pour les index chargés en mémoire
|
||||||
_loaded: dict[str, tuple] = {}
|
_loaded: dict[str, tuple] = {}
|
||||||
|
_loaded_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -577,30 +579,35 @@ def get_index(kind: str = "ref") -> tuple | None:
|
|||||||
if kind in _loaded:
|
if kind in _loaded:
|
||||||
return _loaded[kind]
|
return _loaded[kind]
|
||||||
|
|
||||||
index_path, meta_path = _paths(kind)
|
with _loaded_lock:
|
||||||
|
# Double-check après acquisition du lock
|
||||||
|
if kind in _loaded:
|
||||||
|
return _loaded[kind]
|
||||||
|
|
||||||
# Backwards compat : si ref/proc absent, fallback sur all
|
index_path, meta_path = _paths(kind)
|
||||||
if kind in ("ref", "proc") and (not index_path.exists() or not meta_path.exists()):
|
|
||||||
legacy_idx, legacy_meta = _paths("all")
|
# Backwards compat : si ref/proc absent, fallback sur all
|
||||||
if legacy_idx.exists() and legacy_meta.exists():
|
if kind in ("ref", "proc") and (not index_path.exists() or not meta_path.exists()):
|
||||||
logger.warning("Index %s absent — fallback legacy faiss.index", kind)
|
legacy_idx, legacy_meta = _paths("all")
|
||||||
index_path, meta_path = legacy_idx, legacy_meta
|
if legacy_idx.exists() and legacy_meta.exists():
|
||||||
else:
|
logger.warning("Index %s absent — fallback legacy faiss.index", kind)
|
||||||
logger.warning("Index FAISS non trouvé dans %s — lancez build_index() d'abord", RAG_INDEX_DIR)
|
index_path, meta_path = legacy_idx, legacy_meta
|
||||||
|
else:
|
||||||
|
logger.warning("Index FAISS non trouvé dans %s — lancez build_index() d'abord", RAG_INDEX_DIR)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not index_path.exists() or not meta_path.exists():
|
||||||
|
logger.warning("Index FAISS non trouvé (%s) dans %s — lancez build_index() d'abord", kind, RAG_INDEX_DIR)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if not index_path.exists() or not meta_path.exists():
|
import faiss
|
||||||
logger.warning("Index FAISS non trouvé (%s) dans %s — lancez build_index() d'abord", kind, RAG_INDEX_DIR)
|
|
||||||
return None
|
|
||||||
|
|
||||||
import faiss
|
faiss_index = faiss.read_index(str(index_path))
|
||||||
|
metadata = json.loads(meta_path.read_text(encoding="utf-8"))
|
||||||
|
|
||||||
faiss_index = faiss.read_index(str(index_path))
|
logger.info("Index FAISS chargé (%s) : %d vecteurs", kind, faiss_index.ntotal)
|
||||||
metadata = json.loads(meta_path.read_text(encoding="utf-8"))
|
_loaded[kind] = (faiss_index, metadata)
|
||||||
|
return _loaded[kind]
|
||||||
logger.info("Index FAISS chargé (%s) : %d vecteurs", kind, faiss_index.ntotal)
|
|
||||||
_loaded[kind] = (faiss_index, metadata)
|
|
||||||
return _loaded[kind]
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -800,4 +807,5 @@ def add_chunks_to_index(chunks: list[Chunk]) -> int:
|
|||||||
|
|
||||||
def reset_index() -> None:
|
def reset_index() -> None:
|
||||||
"""Invalide les singletons FAISS pour forcer le rechargement au prochain accès."""
|
"""Invalide les singletons FAISS pour forcer le rechargement au prochain accès."""
|
||||||
_loaded.clear()
|
with _loaded_lock:
|
||||||
|
_loaded.clear()
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ _embed_failed = False # Sentinelle pour éviter les retries infinis
|
|||||||
|
|
||||||
# Singleton pour le cross-encoder de re-ranking (CPU uniquement)
|
# Singleton pour le cross-encoder de re-ranking (CPU uniquement)
|
||||||
_reranker_model = None
|
_reranker_model = None
|
||||||
|
_reranker_lock = threading.Lock()
|
||||||
|
|
||||||
# Score minimum de similarité FAISS pour retenir un résultat
|
# Score minimum de similarité FAISS pour retenir un résultat
|
||||||
_MIN_SCORE = 0.3
|
_MIN_SCORE = 0.3
|
||||||
@@ -84,12 +85,17 @@ def _get_embed_model():
|
|||||||
|
|
||||||
|
|
||||||
def _get_reranker():
|
def _get_reranker():
|
||||||
"""Charge le cross-encoder de re-ranking (singleton, CPU uniquement).
|
"""Charge le cross-encoder de re-ranking (singleton thread-safe, CPU uniquement).
|
||||||
|
|
||||||
Forcé sur CPU pour ne pas interférer avec Ollama sur GPU.
|
Forcé sur CPU pour ne pas interférer avec Ollama sur GPU.
|
||||||
"""
|
"""
|
||||||
global _reranker_model
|
global _reranker_model
|
||||||
if _reranker_model is None:
|
if _reranker_model is not None:
|
||||||
|
return _reranker_model
|
||||||
|
with _reranker_lock:
|
||||||
|
# Double-check après acquisition du lock
|
||||||
|
if _reranker_model is not None:
|
||||||
|
return _reranker_model
|
||||||
from sentence_transformers import CrossEncoder
|
from sentence_transformers import CrossEncoder
|
||||||
logger.info("Chargement du cross-encoder de re-ranking (cpu)...")
|
logger.info("Chargement du cross-encoder de re-ranking (cpu)...")
|
||||||
_reranker_model = CrossEncoder(RERANKER_MODEL, device="cpu")
|
_reranker_model = CrossEncoder(RERANKER_MODEL, device="cpu")
|
||||||
|
|||||||
Reference in New Issue
Block a user