From 59365e3af97cfdf8ba312381bdb34201334bc2a7 Mon Sep 17 00:00:00 2001 From: dom Date: Sun, 15 Feb 2026 11:16:58 +0100 Subject: [PATCH] feat: re-ranking cross-encoder CPU pour la recherche RAG CPAM MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Nouveau singleton _get_reranker() : CrossEncoder ms-marco-MiniLM-L-6-v2 forcé sur CPU pour ne pas interférer avec Ollama sur GPU - Fonction _rerank() : re-classe les résultats FAISS via cross-encoder, conserve le score FAISS original dans score_faiss - Intégré dans search_similar_cpam() après déduplication, avant priorisation - Config RERANKER_MODEL externalisée via T2A_RERANKER_MODEL (.env) - Fix fallback CUDA OOM : rattrapage de torch.AcceleratorError en plus de torch.OutOfMemoryError Latence : ~7-12s (incluant chargement one-time du modèle ~80Mo). Co-Authored-By: Claude Opus 4.6 --- src/config.py | 3 ++ src/medical/rag_search.py | 58 +++++++++++++++++++++++++++++++++++---- 2 files changed, 56 insertions(+), 5 deletions(-) diff --git a/src/config.py b/src/config.py index 2f7b98d..3170cd1 100644 --- a/src/config.py +++ b/src/config.py @@ -64,6 +64,9 @@ CCAM_PDF = Path(os.environ.get("T2A_CCAM_PDF", "/home/dom/ai/aivanov_CIM/actuali EMBEDDING_MODEL = os.environ.get("T2A_EMBEDDING_MODEL", "dangvantuan/sentence-camembert-large") +# --- Modèle de re-ranking (cross-encoder, CPU uniquement) --- + +RERANKER_MODEL = os.environ.get("T2A_RERANKER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2") # --- Modèles de données CIM-10 --- diff --git a/src/medical/rag_search.py b/src/medical/rag_search.py index 22395f6..1a86b01 100644 --- a/src/medical/rag_search.py +++ b/src/medical/rag_search.py @@ -8,7 +8,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from ..config import ( ActeCCAM, Diagnostic, DossierMedical, RAGSource, OLLAMA_CACHE_PATH, OLLAMA_MAX_PARALLEL, OLLAMA_MODEL, - EMBEDDING_MODEL, + EMBEDDING_MODEL, RERANKER_MODEL, ) from .cim10_dict import normalize_code, validate_code as cim10_validate from .cim10_extractor import BIO_NORMALS @@ -21,6 +21,9 @@ logger = logging.getLogger(__name__) # Singleton pour le modèle d'embedding (chargé une seule fois) _embed_model = None +# Singleton pour le cross-encoder de re-ranking (CPU uniquement) +_reranker_model = None + # Score minimum de similarité FAISS pour retenir un résultat _MIN_SCORE = 0.3 # Seuil rehaussé pour le contexte CPAM (filtrage plus agressif du bruit) @@ -40,8 +43,8 @@ def _get_embed_model(): try: logger.info("Chargement du modèle d'embedding (%s)...", _device) _embed_model = SentenceTransformer(EMBEDDING_MODEL, device=_device) - except torch.OutOfMemoryError: - if _device == "cuda": + except (torch.OutOfMemoryError, torch.cuda.CudaError, torch.AcceleratorError, RuntimeError) as exc: + if _device == "cuda" and "memory" in str(exc).lower(): logger.warning("CUDA OOM pour l'embedding — fallback CPU") torch.cuda.empty_cache() _embed_model = SentenceTransformer(EMBEDDING_MODEL, device="cpu") @@ -51,6 +54,48 @@ def _get_embed_model(): return _embed_model +def _get_reranker(): + """Charge le cross-encoder de re-ranking (singleton, CPU uniquement). + + Forcé sur CPU pour ne pas interférer avec Ollama sur GPU. + """ + global _reranker_model + if _reranker_model is None: + from sentence_transformers import CrossEncoder + logger.info("Chargement du cross-encoder de re-ranking (cpu)...") + _reranker_model = CrossEncoder(RERANKER_MODEL, device="cpu") + return _reranker_model + + +def _rerank(query: str, results: list[dict], top_k: int) -> list[dict]: + """Re-classe les résultats FAISS via un cross-encoder. + + Args: + query: Texte de la requête originale. + results: Résultats FAISS avec clé 'extrait'. + top_k: Nombre de résultats à retourner. + + Returns: + Résultats re-classés par score cross-encoder, limités à top_k. + """ + if not results: + return results + + reranker = _get_reranker() + + # Construire les paires (query, passage) pour le cross-encoder + pairs = [(query, r.get("extrait", "")) for r in results] + ce_scores = reranker.predict(pairs) + + # Injecter le score cross-encoder et trier + for r, ce_score in zip(results, ce_scores): + r["score_faiss"] = r["score"] + r["score"] = float(ce_score) + + results.sort(key=lambda r: r["score"], reverse=True) + return results[:top_k] + + def search_similar(query: str, top_k: int = 10) -> list[dict]: """Recherche les passages les plus similaires dans l'index FAISS. @@ -204,9 +249,12 @@ def search_similar_cpam(query: str, top_k: int = 8) -> list[dict]: deduped.extend(seen_codes.values()) deduped.sort(key=lambda r: r["score"], reverse=True) + # Re-ranking cross-encoder (CPU) pour affiner le classement + reranked = _rerank(query, deduped, top_k=len(deduped)) + # Prioriser le Guide Méthodologique (min 3 résultats) - guide_results = [r for r in deduped if r["document"] == "guide_methodo"] - other_results = [r for r in deduped if r["document"] != "guide_methodo"] + guide_results = [r for r in reranked if r["document"] == "guide_methodo"] + other_results = [r for r in reranked if r["document"] != "guide_methodo"] min_guide = min(3, len(guide_results)) final = guide_results[:min_guide]