feat: re-ranking cross-encoder CPU pour la recherche RAG CPAM
- Nouveau singleton _get_reranker() : CrossEncoder ms-marco-MiniLM-L-6-v2 forcé sur CPU pour ne pas interférer avec Ollama sur GPU - Fonction _rerank() : re-classe les résultats FAISS via cross-encoder, conserve le score FAISS original dans score_faiss - Intégré dans search_similar_cpam() après déduplication, avant priorisation - Config RERANKER_MODEL externalisée via T2A_RERANKER_MODEL (.env) - Fix fallback CUDA OOM : rattrapage de torch.AcceleratorError en plus de torch.OutOfMemoryError Latence : ~7-12s (incluant chargement one-time du modèle ~80Mo). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -64,6 +64,9 @@ CCAM_PDF = Path(os.environ.get("T2A_CCAM_PDF", "/home/dom/ai/aivanov_CIM/actuali
|
|||||||
|
|
||||||
EMBEDDING_MODEL = os.environ.get("T2A_EMBEDDING_MODEL", "dangvantuan/sentence-camembert-large")
|
EMBEDDING_MODEL = os.environ.get("T2A_EMBEDDING_MODEL", "dangvantuan/sentence-camembert-large")
|
||||||
|
|
||||||
|
# --- Modèle de re-ranking (cross-encoder, CPU uniquement) ---
|
||||||
|
|
||||||
|
RERANKER_MODEL = os.environ.get("T2A_RERANKER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")
|
||||||
|
|
||||||
# --- Modèles de données CIM-10 ---
|
# --- Modèles de données CIM-10 ---
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
|
|||||||
from ..config import (
|
from ..config import (
|
||||||
ActeCCAM, Diagnostic, DossierMedical, RAGSource,
|
ActeCCAM, Diagnostic, DossierMedical, RAGSource,
|
||||||
OLLAMA_CACHE_PATH, OLLAMA_MAX_PARALLEL, OLLAMA_MODEL,
|
OLLAMA_CACHE_PATH, OLLAMA_MAX_PARALLEL, OLLAMA_MODEL,
|
||||||
EMBEDDING_MODEL,
|
EMBEDDING_MODEL, RERANKER_MODEL,
|
||||||
)
|
)
|
||||||
from .cim10_dict import normalize_code, validate_code as cim10_validate
|
from .cim10_dict import normalize_code, validate_code as cim10_validate
|
||||||
from .cim10_extractor import BIO_NORMALS
|
from .cim10_extractor import BIO_NORMALS
|
||||||
@@ -21,6 +21,9 @@ logger = logging.getLogger(__name__)
|
|||||||
# Singleton pour le modèle d'embedding (chargé une seule fois)
|
# Singleton pour le modèle d'embedding (chargé une seule fois)
|
||||||
_embed_model = None
|
_embed_model = None
|
||||||
|
|
||||||
|
# Singleton pour le cross-encoder de re-ranking (CPU uniquement)
|
||||||
|
_reranker_model = None
|
||||||
|
|
||||||
# Score minimum de similarité FAISS pour retenir un résultat
|
# Score minimum de similarité FAISS pour retenir un résultat
|
||||||
_MIN_SCORE = 0.3
|
_MIN_SCORE = 0.3
|
||||||
# Seuil rehaussé pour le contexte CPAM (filtrage plus agressif du bruit)
|
# Seuil rehaussé pour le contexte CPAM (filtrage plus agressif du bruit)
|
||||||
@@ -40,8 +43,8 @@ def _get_embed_model():
|
|||||||
try:
|
try:
|
||||||
logger.info("Chargement du modèle d'embedding (%s)...", _device)
|
logger.info("Chargement du modèle d'embedding (%s)...", _device)
|
||||||
_embed_model = SentenceTransformer(EMBEDDING_MODEL, device=_device)
|
_embed_model = SentenceTransformer(EMBEDDING_MODEL, device=_device)
|
||||||
except torch.OutOfMemoryError:
|
except (torch.OutOfMemoryError, torch.cuda.CudaError, torch.AcceleratorError, RuntimeError) as exc:
|
||||||
if _device == "cuda":
|
if _device == "cuda" and "memory" in str(exc).lower():
|
||||||
logger.warning("CUDA OOM pour l'embedding — fallback CPU")
|
logger.warning("CUDA OOM pour l'embedding — fallback CPU")
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
_embed_model = SentenceTransformer(EMBEDDING_MODEL, device="cpu")
|
_embed_model = SentenceTransformer(EMBEDDING_MODEL, device="cpu")
|
||||||
@@ -51,6 +54,48 @@ def _get_embed_model():
|
|||||||
return _embed_model
|
return _embed_model
|
||||||
|
|
||||||
|
|
||||||
|
def _get_reranker():
|
||||||
|
"""Charge le cross-encoder de re-ranking (singleton, CPU uniquement).
|
||||||
|
|
||||||
|
Forcé sur CPU pour ne pas interférer avec Ollama sur GPU.
|
||||||
|
"""
|
||||||
|
global _reranker_model
|
||||||
|
if _reranker_model is None:
|
||||||
|
from sentence_transformers import CrossEncoder
|
||||||
|
logger.info("Chargement du cross-encoder de re-ranking (cpu)...")
|
||||||
|
_reranker_model = CrossEncoder(RERANKER_MODEL, device="cpu")
|
||||||
|
return _reranker_model
|
||||||
|
|
||||||
|
|
||||||
|
def _rerank(query: str, results: list[dict], top_k: int) -> list[dict]:
|
||||||
|
"""Re-classe les résultats FAISS via un cross-encoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Texte de la requête originale.
|
||||||
|
results: Résultats FAISS avec clé 'extrait'.
|
||||||
|
top_k: Nombre de résultats à retourner.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Résultats re-classés par score cross-encoder, limités à top_k.
|
||||||
|
"""
|
||||||
|
if not results:
|
||||||
|
return results
|
||||||
|
|
||||||
|
reranker = _get_reranker()
|
||||||
|
|
||||||
|
# Construire les paires (query, passage) pour le cross-encoder
|
||||||
|
pairs = [(query, r.get("extrait", "")) for r in results]
|
||||||
|
ce_scores = reranker.predict(pairs)
|
||||||
|
|
||||||
|
# Injecter le score cross-encoder et trier
|
||||||
|
for r, ce_score in zip(results, ce_scores):
|
||||||
|
r["score_faiss"] = r["score"]
|
||||||
|
r["score"] = float(ce_score)
|
||||||
|
|
||||||
|
results.sort(key=lambda r: r["score"], reverse=True)
|
||||||
|
return results[:top_k]
|
||||||
|
|
||||||
|
|
||||||
def search_similar(query: str, top_k: int = 10) -> list[dict]:
|
def search_similar(query: str, top_k: int = 10) -> list[dict]:
|
||||||
"""Recherche les passages les plus similaires dans l'index FAISS.
|
"""Recherche les passages les plus similaires dans l'index FAISS.
|
||||||
|
|
||||||
@@ -204,9 +249,12 @@ def search_similar_cpam(query: str, top_k: int = 8) -> list[dict]:
|
|||||||
deduped.extend(seen_codes.values())
|
deduped.extend(seen_codes.values())
|
||||||
deduped.sort(key=lambda r: r["score"], reverse=True)
|
deduped.sort(key=lambda r: r["score"], reverse=True)
|
||||||
|
|
||||||
|
# Re-ranking cross-encoder (CPU) pour affiner le classement
|
||||||
|
reranked = _rerank(query, deduped, top_k=len(deduped))
|
||||||
|
|
||||||
# Prioriser le Guide Méthodologique (min 3 résultats)
|
# Prioriser le Guide Méthodologique (min 3 résultats)
|
||||||
guide_results = [r for r in deduped if r["document"] == "guide_methodo"]
|
guide_results = [r for r in reranked if r["document"] == "guide_methodo"]
|
||||||
other_results = [r for r in deduped if r["document"] != "guide_methodo"]
|
other_results = [r for r in reranked if r["document"] != "guide_methodo"]
|
||||||
|
|
||||||
min_guide = min(3, len(guide_results))
|
min_guide = min(3, len(guide_results))
|
||||||
final = guide_results[:min_guide]
|
final = guide_results[:min_guide]
|
||||||
|
|||||||
Reference in New Issue
Block a user