diff --git a/src/medical/rag_search.py b/src/medical/rag_search.py
index b914105..6ccd9c0 100644
--- a/src/medical/rag_search.py
+++ b/src/medical/rag_search.py
@@ -35,21 +35,27 @@ def _get_embed_model():
     """Load the embedding model (singleton).
 
     Tries CUDA first, falls back to CPU on OOM (Ollama may be holding the VRAM).
+    low_cpu_mem_usage=False avoids meta tensors (accelerate + sentence-transformers 5.x).
     """
     global _embed_model
     if _embed_model is None:
         from sentence_transformers import SentenceTransformer
         import torch
         _device = "cuda" if torch.cuda.is_available() else "cpu"
+        _model_kwargs = {"low_cpu_mem_usage": False}
         try:
             logger.info("Loading the embedding model (%s)...", _device)
-            _embed_model = SentenceTransformer(EMBEDDING_MODEL, device=_device)
+            _embed_model = SentenceTransformer(
+                EMBEDDING_MODEL, device=_device, model_kwargs=_model_kwargs,
+            )
         except (torch.OutOfMemoryError, torch.cuda.CudaError, torch.AcceleratorError, RuntimeError) as exc:
             exc_msg = str(exc).lower()
             if _device == "cuda" and ("memory" in exc_msg or "meta tensor" in exc_msg):
                 logger.warning("CUDA error while loading the embedding model, falling back to CPU: %s", exc)
                 torch.cuda.empty_cache()
-                _embed_model = SentenceTransformer(EMBEDDING_MODEL, device="cpu")
+                _embed_model = SentenceTransformer(
+                    EMBEDDING_MODEL, device="cpu", model_kwargs=_model_kwargs,
+                )
             else:
                 raise
     _embed_model.max_seq_length = 512
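
For reference, a minimal standalone sketch of the load-with-fallback pattern this hunk implements. The EMBEDDING_MODEL value and the logger setup below are illustrative placeholders, not the module's real ones; model_kwargs is forwarded by SentenceTransformer to the underlying transformers from_pretrained() call, which is where low_cpu_mem_usage takes effect.

import logging

logger = logging.getLogger(__name__)

# Placeholder model name for illustration; the real value is the
# EMBEDDING_MODEL constant defined in src/medical/rag_search.py.
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"


def load_embedder():
    """Try CUDA first, fall back to CPU if the GPU is full or errors out."""
    from sentence_transformers import SentenceTransformer
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Forwarded to transformers' from_pretrained(); False forces real weight
    # allocation instead of meta tensors when accelerate is installed.
    model_kwargs = {"low_cpu_mem_usage": False}
    try:
        return SentenceTransformer(EMBEDDING_MODEL, device=device, model_kwargs=model_kwargs)
    except RuntimeError as exc:  # CUDA OOM errors subclass RuntimeError
        if device == "cuda" and "memory" in str(exc).lower():
            logger.warning("CUDA error, retrying on CPU: %s", exc)
            torch.cuda.empty_cache()  # release whatever the failed attempt allocated
            return SentenceTransformer(EMBEDDING_MODEL, device="cpu", model_kwargs=model_kwargs)
        raise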