From 931b6c5d1c5453f5e2781149ddde640a1671fe3c Mon Sep 17 00:00:00 2001 From: dom Date: Wed, 11 Feb 2026 23:42:46 +0100 Subject: [PATCH] feat: embeddings sur GPU (CUDA) pour l'indexation et la recherche RAG MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Détection automatique GPU/CPU avec fallback. Index FAISS reconstruit en 1min (GPU) au lieu de 16min (CPU). Co-Authored-By: Claude Opus 4.6 --- src/medical/rag_index.py | 8 +++++--- src/medical/rag_search.py | 4 +++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/medical/rag_index.py b/src/medical/rag_index.py index 6b7fbd8..9691f07 100644 --- a/src/medical/rag_index.py +++ b/src/medical/rag_index.py @@ -423,9 +423,11 @@ def build_index(force: bool = False) -> None: logger.info("Total : %d chunks à indexer", len(all_chunks)) - # Embeddings — forcer CPU pour éviter les bugs CUDA avec ce modèle - logger.info("Chargement du modèle d'embedding dangvantuan/sentence-camembert-large (CPU)...") - model = SentenceTransformer("dangvantuan/sentence-camembert-large", device="cpu") + # Embeddings — GPU si disponible + import torch + _device = "cuda" if torch.cuda.is_available() else "cpu" + logger.info("Chargement du modèle d'embedding dangvantuan/sentence-camembert-large (%s)...", _device) + model = SentenceTransformer("dangvantuan/sentence-camembert-large", device=_device) model.max_seq_length = 512 # CamemBERT max position embeddings texts = [c.text[:2000] for c in all_chunks] # Tronquer les chunks trop longs diff --git a/src/medical/rag_search.py b/src/medical/rag_search.py index 1b42acc..dbb8fd9 100644 --- a/src/medical/rag_search.py +++ b/src/medical/rag_search.py @@ -29,7 +29,9 @@ def _get_embed_model(): if _embed_model is None: from sentence_transformers import SentenceTransformer logger.info("Chargement du modèle d'embedding pour la recherche...") - _embed_model = SentenceTransformer("dangvantuan/sentence-camembert-large", device="cpu") + import torch + _device = "cuda" if torch.cuda.is_available() else "cpu" + _embed_model = SentenceTransformer("dangvantuan/sentence-camembert-large", device=_device) _embed_model.max_seq_length = 512 return _embed_model