feat: re-ranking cross-encoder CPU pour la recherche RAG CPAM

- Nouveau singleton _get_reranker() : CrossEncoder ms-marco-MiniLM-L-6-v2
  forcé sur CPU pour ne pas interférer avec Ollama sur GPU
- Fonction _rerank() : re-classe les résultats FAISS via cross-encoder,
  conserve le score FAISS original dans score_faiss
- Intégré dans search_similar_cpam() après déduplication, avant priorisation
- Config RERANKER_MODEL externalisée via T2A_RERANKER_MODEL (.env)
- Fix fallback CUDA OOM : rattrapage de torch.AcceleratorError en plus
  de torch.OutOfMemoryError

Latence : ~7-12s (incluant chargement one-time du modèle ~80Mo).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
dom
2026-02-15 11:16:58 +01:00
parent 50a77c9f61
commit 59365e3af9
2 changed files with 56 additions and 5 deletions

View File

@@ -8,7 +8,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
from ..config import (
ActeCCAM, Diagnostic, DossierMedical, RAGSource,
OLLAMA_CACHE_PATH, OLLAMA_MAX_PARALLEL, OLLAMA_MODEL,
EMBEDDING_MODEL,
EMBEDDING_MODEL, RERANKER_MODEL,
)
from .cim10_dict import normalize_code, validate_code as cim10_validate
from .cim10_extractor import BIO_NORMALS
@@ -21,6 +21,9 @@ logger = logging.getLogger(__name__)
# Singleton pour le modèle d'embedding (chargé une seule fois)
_embed_model = None
# Singleton pour le cross-encoder de re-ranking (CPU uniquement)
_reranker_model = None
# Score minimum de similarité FAISS pour retenir un résultat
_MIN_SCORE = 0.3
# Seuil rehaussé pour le contexte CPAM (filtrage plus agressif du bruit)
@@ -40,8 +43,8 @@ def _get_embed_model():
try:
logger.info("Chargement du modèle d'embedding (%s)...", _device)
_embed_model = SentenceTransformer(EMBEDDING_MODEL, device=_device)
except torch.OutOfMemoryError:
if _device == "cuda":
except (torch.OutOfMemoryError, torch.cuda.CudaError, torch.AcceleratorError, RuntimeError) as exc:
if _device == "cuda" and "memory" in str(exc).lower():
logger.warning("CUDA OOM pour l'embedding — fallback CPU")
torch.cuda.empty_cache()
_embed_model = SentenceTransformer(EMBEDDING_MODEL, device="cpu")
@@ -51,6 +54,48 @@ def _get_embed_model():
return _embed_model
def _get_reranker():
"""Charge le cross-encoder de re-ranking (singleton, CPU uniquement).
Forcé sur CPU pour ne pas interférer avec Ollama sur GPU.
"""
global _reranker_model
if _reranker_model is None:
from sentence_transformers import CrossEncoder
logger.info("Chargement du cross-encoder de re-ranking (cpu)...")
_reranker_model = CrossEncoder(RERANKER_MODEL, device="cpu")
return _reranker_model
def _rerank(query: str, results: list[dict], top_k: int) -> list[dict]:
"""Re-classe les résultats FAISS via un cross-encoder.
Args:
query: Texte de la requête originale.
results: Résultats FAISS avec clé 'extrait'.
top_k: Nombre de résultats à retourner.
Returns:
Résultats re-classés par score cross-encoder, limités à top_k.
"""
if not results:
return results
reranker = _get_reranker()
# Construire les paires (query, passage) pour le cross-encoder
pairs = [(query, r.get("extrait", "")) for r in results]
ce_scores = reranker.predict(pairs)
# Injecter le score cross-encoder et trier
for r, ce_score in zip(results, ce_scores):
r["score_faiss"] = r["score"]
r["score"] = float(ce_score)
results.sort(key=lambda r: r["score"], reverse=True)
return results[:top_k]
def search_similar(query: str, top_k: int = 10) -> list[dict]:
"""Recherche les passages les plus similaires dans l'index FAISS.
@@ -204,9 +249,12 @@ def search_similar_cpam(query: str, top_k: int = 8) -> list[dict]:
deduped.extend(seen_codes.values())
deduped.sort(key=lambda r: r["score"], reverse=True)
# Re-ranking cross-encoder (CPU) pour affiner le classement
reranked = _rerank(query, deduped, top_k=len(deduped))
# Prioriser le Guide Méthodologique (min 3 résultats)
guide_results = [r for r in deduped if r["document"] == "guide_methodo"]
other_results = [r for r in deduped if r["document"] != "guide_methodo"]
guide_results = [r for r in reranked if r["document"] == "guide_methodo"]
other_results = [r for r in reranked if r["document"] != "guide_methodo"]
min_guide = min(3, len(guide_results))
final = guide_results[:min_guide]