diff --git a/src/medical/rag_index.py b/src/medical/rag_index.py index 6b7fbd8..9691f07 100644 --- a/src/medical/rag_index.py +++ b/src/medical/rag_index.py @@ -423,9 +423,11 @@ def build_index(force: bool = False) -> None: logger.info("Total : %d chunks à indexer", len(all_chunks)) - # Embeddings — forcer CPU pour éviter les bugs CUDA avec ce modèle - logger.info("Chargement du modèle d'embedding dangvantuan/sentence-camembert-large (CPU)...") - model = SentenceTransformer("dangvantuan/sentence-camembert-large", device="cpu") + # Embeddings — GPU si disponible + import torch + _device = "cuda" if torch.cuda.is_available() else "cpu" + logger.info("Chargement du modèle d'embedding dangvantuan/sentence-camembert-large (%s)...", _device) + model = SentenceTransformer("dangvantuan/sentence-camembert-large", device=_device) model.max_seq_length = 512 # CamemBERT max position embeddings texts = [c.text[:2000] for c in all_chunks] # Tronquer les chunks trop longs diff --git a/src/medical/rag_search.py b/src/medical/rag_search.py index 1b42acc..dbb8fd9 100644 --- a/src/medical/rag_search.py +++ b/src/medical/rag_search.py @@ -29,7 +29,9 @@ def _get_embed_model(): if _embed_model is None: from sentence_transformers import SentenceTransformer logger.info("Chargement du modèle d'embedding pour la recherche...") - _embed_model = SentenceTransformer("dangvantuan/sentence-camembert-large", device="cpu") + import torch + _device = "cuda" if torch.cuda.is_available() else "cpu" + _embed_model = SentenceTransformer("dangvantuan/sentence-camembert-large", device=_device) _embed_model.max_seq_length = 512 return _embed_model