Files
aivanov_CIM/tests/test_rag_engine.py
2026-03-05 01:20:14 +01:00

922 lines
35 KiB
Python

"""
Tests unitaires pour le RAGEngine.
Ce module teste:
- La recherche BM25
- La recherche vectorielle
- La fusion RRF (Reciprocal Rank Fusion)
- La recherche hybride CIM-10 et CCAM
- L'extraction des critères d'éligibilité
"""
import json
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, Mock, patch
import faiss
import numpy as np
import pytest
from pipeline_mco_pmsi.rag.rag_engine import (
CodeCandidate,
EligibilityCriteria,
RAGEngine,
)
from pipeline_mco_pmsi.rag.referentiels_manager import Chunk, ReferentielsManager
@pytest.fixture
def temp_data_dir():
    """Yield a fresh temporary directory as a ``Path``; it is deleted on teardown."""
    scratch = tempfile.TemporaryDirectory()
    try:
        yield Path(scratch.name)
    finally:
        # Same cleanup the context-manager form performs on exit.
        scratch.cleanup()
@pytest.fixture
def mock_referentiels_manager():
    """Return a ``Mock`` specced on ``ReferentielsManager`` with a preset ``data_dir``."""
    stub = Mock(spec=ReferentielsManager)
    stub.configure_mock(data_dir=Path("data/referentiels"))
    return stub
@pytest.fixture
def sample_chunks_cim10():
    """Return three CIM-10 ``code_block`` chunks: two cholera codes and one gastritis code."""
    # (chunk_id, content, chapter) -- chunk_index follows the row position.
    rows = [
        (
            "cim10_2026_0",
            "A00.0 Choléra dû à Vibrio cholerae 01, biovar cholerae\nInclus: choléra classique",
            "Chapitre I",
        ),
        (
            "cim10_2026_1",
            "A00.1 Choléra dû à Vibrio cholerae 01, biovar El Tor\nInclus: choléra El Tor",
            "Chapitre I",
        ),
        (
            "cim10_2026_2",
            "K29.7 Gastrite, sans précision\nInclus: gastrite SAI",
            "Chapitre XI",
        ),
    ]
    return [
        Chunk(
            chunk_id=chunk_id,
            referentiel_type="cim10",
            referentiel_version="2026",
            content=content,
            metadata={"chunk_type": "code_block", "chapter": chapter},
            chunk_index=position,
        )
        for position, (chunk_id, content, chapter) in enumerate(rows)
    ]
@pytest.fixture
def sample_chunks_ccam():
    """Return two CCAM ``acte`` chunks; the second code carries a ``+ABC`` extension."""
    # (chunk_id, content) -- chunk_index follows the row position.
    rows = [
        (
            "ccam_2025_0",
            "YYYY001 Appendicectomie par laparotomie\nNote: Ablation de l'appendice",
        ),
        (
            "ccam_2025_1",
            "YYYY002+ABC Appendicectomie par cœlioscopie\nNote: Ablation de l'appendice par voie cœlioscopique",
        ),
    ]
    return [
        Chunk(
            chunk_id=chunk_id,
            referentiel_type="ccam",
            referentiel_version="2025",
            content=content,
            metadata={"chunk_type": "acte", "section": "Section 1"},
            chunk_index=position,
        )
        for position, (chunk_id, content) in enumerate(rows)
    ]
@pytest.fixture
def sample_chunks_guide():
    """Return a single Guide MCO ``section`` chunk listing DP eligibility rules."""
    guide_text = """Critères d'éligibilité DP
- Le DP doit être le diagnostic ayant mobilisé l'essentiel des ressources
- Exclut: les diagnostics niés
- Exclut: les antécédents
- Hiérarchisation: privilégier le diagnostic le plus grave"""
    section_chunk = Chunk(
        chunk_id="guide_mco_2026_0",
        referentiel_type="guide_mco",
        referentiel_version="2026",
        content=guide_text,
        metadata={"chunk_type": "section", "section": "Chapitre 1"},
        chunk_index=0,
    )
    return [section_chunk]
@pytest.fixture
def rag_engine(mock_referentiels_manager, temp_data_dir):
    """Return a ``RAGEngine`` wired to the mocked manager and the temporary data dir."""
    return RAGEngine(mock_referentiels_manager, data_dir=temp_data_dir)
class TestRAGEngineInit:
    """Initialisation tests for ``RAGEngine``."""

    def test_init_creates_engine(self, mock_referentiels_manager, temp_data_dir):
        """A new engine keeps its constructor arguments and starts with empty caches."""
        fresh = RAGEngine(mock_referentiels_manager, data_dir=temp_data_dir)

        assert fresh.referentiels_manager == mock_referentiels_manager
        assert fresh.data_dir == temp_data_dir
        # Every lazily-filled cache must start out as an empty dict.
        for cache_attr in ("_bm25_indexes", "_faiss_indexes", "_chunks_cache"):
            assert getattr(fresh, cache_attr) == {}
        assert fresh._embeddings_model is None
class TestBM25Search:
    """BM25 search tests."""

    @staticmethod
    def _write_cim10_chunks(chunks, data_dir):
        """Dump *chunks* as JSON to ``cim10_2026_chunks.json`` under *data_dir*.

        Presumably this is the file ``_bm25_search("...", "cim10", "2026")``
        reads its chunks from -- the tests below rely on that.
        """
        chunks_path = data_dir / "cim10_2026_chunks.json"
        with open(chunks_path, "w", encoding="utf-8") as f:
            chunks_data = [chunk.model_dump() for chunk in chunks]
            json.dump(chunks_data, f, ensure_ascii=False, default=str)

    def test_build_bm25_index(self, rag_engine, sample_chunks_cim10):
        """Building a BM25 index yields an object with non-empty ``doc_freqs``."""
        bm25_index = rag_engine._build_bm25_index(sample_chunks_cim10)

        assert bm25_index is not None
        assert len(bm25_index.doc_freqs) > 0

    def test_bm25_search_returns_results(
        self, rag_engine, sample_chunks_cim10, temp_data_dir
    ):
        """BM25 search returns a non-empty list of ``(chunk_index, float_score)`` pairs."""
        self._write_cim10_chunks(sample_chunks_cim10, temp_data_dir)

        results = rag_engine._bm25_search("choléra", "cim10", "2026", top_k=2)

        assert len(results) > 0
        assert all(isinstance(r, tuple) for r in results)
        assert all(len(r) == 2 for r in results)
        # Scores must be plain floats.
        assert all(isinstance(r[1], float) for r in results)

    def test_bm25_search_ranks_relevant_higher(
        self, rag_engine, sample_chunks_cim10, temp_data_dir
    ):
        """The top BM25 hit for "gastrite" is a chunk that mentions "gastrite"."""
        self._write_cim10_chunks(sample_chunks_cim10, temp_data_dir)

        results = rag_engine._bm25_search("gastrite", "cim10", "2026", top_k=3)

        # Fail with a readable assertion instead of an IndexError on results[0].
        assert results, "BM25 returned no result for 'gastrite'"
        top_chunk_idx = results[0][0]
        assert "gastrite" in sample_chunks_cim10[top_chunk_idx].content.lower()
class TestVectorSearch:
    """Vector search tests."""

    def test_vector_search_with_mock_index(
        self, rag_engine, sample_chunks_cim10, temp_data_dir
    ):
        """Vector search over an on-disk FAISS index returns ``(index, score)`` pairs.

        The index and the query vector are drawn from a seeded generator so
        that the test is reproducible from run to run.
        """
        dimension = 384  # typical embedding size for MiniLM
        rng = np.random.default_rng(0)

        # One L2-normalised vector per chunk, stored in a flat L2 index.
        vectors = rng.random((len(sample_chunks_cim10), dimension), dtype=np.float32)
        faiss.normalize_L2(vectors)
        index = faiss.IndexFlatL2(dimension)
        index.add(vectors)
        index_path = temp_data_dir / "cim10_2026_index.faiss"
        faiss.write_index(index, str(index_path))

        # Chunks file matching the index rows.
        chunks_path = temp_data_dir / "cim10_2026_chunks.json"
        with open(chunks_path, "w", encoding="utf-8") as f:
            chunks_data = [chunk.model_dump() for chunk in sample_chunks_cim10]
            json.dump(chunks_data, f, ensure_ascii=False, default=str)

        # Fake embeddings model: encode() returns a normalised 1-D query vector.
        query_vector = rng.random((1, dimension), dtype=np.float32)
        faiss.normalize_L2(query_vector)
        mock_model = Mock()
        mock_model.encode.return_value = query_vector[0]
        rag_engine._embeddings_model = mock_model

        results = rag_engine._vector_search("test query", "cim10", "2026", top_k=2)

        assert len(results) > 0
        assert all(isinstance(r, tuple) for r in results)
        assert all(len(r) == 2 for r in results)
        # Similarity scores are expected to lie in [0, 1].
        assert all(0.0 <= r[1] <= 1.0 for r in results)
class TestReciprocalRankFusion:
    """Reciprocal Rank Fusion (RRF) tests."""

    def test_rrf_fusion_combines_results(self, rag_engine):
        """Fusion keeps every distinct chunk index and sorts by descending score."""
        lexical = [(0, 10.5), (1, 8.2), (2, 5.1)]
        semantic = [(1, 0.95), (0, 0.88), (3, 0.75)]

        merged = rag_engine._reciprocal_rank_fusion(lexical, semantic, k=60)

        # Union of both input lists.
        assert {idx for idx, _ in merged} == {0, 1, 2, 3}
        # Non-increasing scores from first to last.
        merged_scores = [score for _, score in merged]
        assert all(a >= b for a, b in zip(merged_scores, merged_scores[1:]))

    def test_rrf_boosts_common_results(self, rag_engine):
        """A chunk ranked first in both lists ends up first after fusion."""
        lexical = [(0, 10.0), (1, 5.0), (2, 3.0)]
        semantic = [(0, 0.95), (3, 0.80), (1, 0.70)]

        merged = rag_engine._reciprocal_rank_fusion(lexical, semantic, k=60)

        best_idx, _ = merged[0]
        assert best_idx == 0

    def test_rrf_with_empty_lists(self, rag_engine):
        """Fusing two empty lists gives an empty list."""
        assert rag_engine._reciprocal_rank_fusion([], [], k=60) == []

    def test_rrf_with_one_empty_list(self, rag_engine):
        """With one empty input, the fused list holds exactly the other input's chunks."""
        lexical = [(0, 10.0), (1, 5.0)]

        merged = rag_engine._reciprocal_rank_fusion(lexical, [], k=60)

        assert len(merged) == 2
        assert {idx for idx, _ in merged} <= {0, 1}
class TestSearchICD10:
    """CIM-10 search tests."""

    def test_search_icd10_returns_candidates(
        self, rag_engine, sample_chunks_cim10, temp_data_dir
    ):
        """``search_icd10`` returns a non-empty list of ``CodeCandidate`` objects."""
        self._setup_test_data(rag_engine, sample_chunks_cim10, temp_data_dir, "cim10")
        # Stub the reranker so the real cross-encoder model is never loaded.
        mock_reranker = Mock()
        mock_reranker.predict.return_value = np.array([0.9, 0.8, 0.7])
        rag_engine._reranker_model = mock_reranker

        candidates = rag_engine.search_icd10("gastrite", top_k=2, version="2026")

        assert len(candidates) > 0
        assert all(isinstance(c, CodeCandidate) for c in candidates)

    def test_search_icd10_candidate_structure(
        self, rag_engine, sample_chunks_cim10, temp_data_dir
    ):
        """The first candidate exposes all mandatory fields with sane values."""
        self._setup_test_data(rag_engine, sample_chunks_cim10, temp_data_dir, "cim10")
        mock_reranker = Mock()
        mock_reranker.predict.return_value = np.array([0.9])
        rag_engine._reranker_model = mock_reranker

        candidates = rag_engine.search_icd10("choléra", top_k=1, version="2026")

        assert len(candidates) > 0
        candidate = candidates[0]
        # Mandatory fields.
        assert candidate.code is not None
        assert candidate.label is not None
        assert 0.0 <= candidate.similarity_score <= 1.0
        assert candidate.source == "reranked"
        assert candidate.chunk_id is not None
        assert candidate.chunk_text is not None

    def test_search_icd10_respects_top_k(
        self, rag_engine, sample_chunks_cim10, temp_data_dir
    ):
        """``search_icd10`` never returns more than ``top_k`` candidates."""
        self._setup_test_data(rag_engine, sample_chunks_cim10, temp_data_dir, "cim10")
        mock_reranker = Mock()
        mock_reranker.predict.return_value = np.array([0.9, 0.8])
        rag_engine._reranker_model = mock_reranker

        candidates = rag_engine.search_icd10("test", top_k=2, version="2026")

        assert len(candidates) <= 2

    def _setup_test_data(self, rag_engine, chunks, temp_data_dir, ref_type):
        """Write chunks and a FAISS index to disk and install a fake embeddings model.

        Files are named ``{ref_type}_{version}_chunks.json`` and
        ``{ref_type}_{version}_index.faiss``; the version is "2026" for
        "cim10" and "2025" otherwise. Vectors come from a seeded generator so
        that every run sees the same index and query.
        """
        version = "2026" if ref_type == "cim10" else "2025"
        chunks_path = temp_data_dir / f"{ref_type}_{version}_chunks.json"
        with open(chunks_path, "w", encoding="utf-8") as f:
            chunks_data = [chunk.model_dump() for chunk in chunks]
            json.dump(chunks_data, f, ensure_ascii=False, default=str)

        # One normalised vector per chunk in a flat L2 index.
        dimension = 384
        rng = np.random.default_rng(0)
        vectors = rng.random((len(chunks), dimension), dtype=np.float32)
        faiss.normalize_L2(vectors)
        index = faiss.IndexFlatL2(dimension)
        index.add(vectors)
        index_path = temp_data_dir / f"{ref_type}_{version}_index.faiss"
        faiss.write_index(index, str(index_path))

        # Fake embeddings model: encode() returns a normalised 1-D query vector.
        query_vector = rng.random((1, dimension), dtype=np.float32)
        faiss.normalize_L2(query_vector)
        mock_model = Mock()
        mock_model.encode.return_value = query_vector[0]
        rag_engine._embeddings_model = mock_model
class TestSearchCCAM:
    """CCAM search tests."""

    def test_search_ccam_returns_candidates(
        self, rag_engine, sample_chunks_ccam, temp_data_dir
    ):
        """``search_ccam`` returns a non-empty list of ``CodeCandidate`` objects."""
        self._setup_test_data(rag_engine, sample_chunks_ccam, temp_data_dir, "ccam")
        # Stub the reranker so the real cross-encoder model is never loaded.
        mock_reranker = Mock()
        mock_reranker.predict.return_value = np.array([0.9, 0.8])
        rag_engine._reranker_model = mock_reranker

        candidates = rag_engine.search_ccam("appendicectomie", top_k=2, version="2025")

        assert len(candidates) > 0
        assert all(isinstance(c, CodeCandidate) for c in candidates)

    def test_search_ccam_extracts_code_with_extension(
        self, rag_engine, sample_chunks_ccam, temp_data_dir
    ):
        """``search_ccam`` surfaces a code that still carries its ``+`` extension."""
        self._setup_test_data(rag_engine, sample_chunks_ccam, temp_data_dir, "ccam")
        mock_reranker = Mock()
        mock_reranker.predict.return_value = np.array([0.9, 0.8])
        rag_engine._reranker_model = mock_reranker

        candidates = rag_engine.search_ccam("cœlioscopie", top_k=2, version="2025")

        codes = [c.code for c in candidates]
        assert len(codes) > 0
        # The fixture holds "YYYY002+ABC ... cœlioscopie"; checking only that
        # some code exists would not exercise extension handling at all.
        assert any("+" in code for code in codes)

    def _setup_test_data(self, rag_engine, chunks, temp_data_dir, ref_type):
        """Write chunks and a FAISS index (version "2025") and install a fake embedder.

        Vectors come from a seeded generator so that every run sees the same
        index and query.
        """
        version = "2025"
        chunks_path = temp_data_dir / f"{ref_type}_{version}_chunks.json"
        with open(chunks_path, "w", encoding="utf-8") as f:
            chunks_data = [chunk.model_dump() for chunk in chunks]
            json.dump(chunks_data, f, ensure_ascii=False, default=str)

        # One normalised vector per chunk in a flat L2 index.
        dimension = 384
        rng = np.random.default_rng(0)
        vectors = rng.random((len(chunks), dimension), dtype=np.float32)
        faiss.normalize_L2(vectors)
        index = faiss.IndexFlatL2(dimension)
        index.add(vectors)
        index_path = temp_data_dir / f"{ref_type}_{version}_index.faiss"
        faiss.write_index(index, str(index_path))

        # Fake embeddings model: encode() returns a normalised 1-D query vector.
        query_vector = rng.random((1, dimension), dtype=np.float32)
        faiss.normalize_L2(query_vector)
        mock_model = Mock()
        mock_model.encode.return_value = query_vector[0]
        rag_engine._embeddings_model = mock_model
class TestExtractCodeAndLabel:
    """Tests for code and label extraction from chunk text."""

    def test_extract_cim10_code_and_label(self, rag_engine):
        """A CIM-10 chunk yields its code and a label containing the disease name."""
        extracted = rag_engine._extract_code_and_label(
            "A00.0 Choléra dû à Vibrio cholerae 01, biovar cholerae\nInclus: choléra classique",
            "cim10",
        )
        code, label = extracted
        assert code == "A00.0"
        assert "Choléra" in label

    def test_extract_ccam_code_and_label(self, rag_engine):
        """A CCAM chunk yields its code and a label containing the act name."""
        extracted = rag_engine._extract_code_and_label(
            "YYYY001 Appendicectomie par laparotomie\nNote: Ablation de l'appendice",
            "ccam",
        )
        code, label = extracted
        assert code == "YYYY001"
        assert "Appendicectomie" in label

    def test_extract_ccam_code_with_extension(self, rag_engine):
        """A CCAM code with a ``+ABC`` extension is returned with the extension intact."""
        extracted = rag_engine._extract_code_and_label(
            "YYYY002+ABC Appendicectomie par cœlioscopie", "ccam"
        )
        code, label = extracted
        assert code == "YYYY002+ABC"
        assert "Appendicectomie" in label

    def test_extract_returns_unknown_for_invalid_format(self, rag_engine):
        """Text without a recognisable code gives "UNKNOWN" and a non-empty label."""
        extracted = rag_engine._extract_code_and_label("Texte sans code valide", "cim10")
        code, label = extracted
        assert code == "UNKNOWN"
        assert len(label) > 0
class TestRetrieveEligibilityCriteria:
    """Tests for eligibility-criteria retrieval."""

    def test_retrieve_criteria_returns_eligibility(
        self, rag_engine, sample_chunks_guide, temp_data_dir
    ):
        """``retrieve_eligibility_criteria`` returns an ``EligibilityCriteria`` instance."""
        self._setup_guide_data(rag_engine, sample_chunks_guide, temp_data_dir)

        criteria = rag_engine.retrieve_eligibility_criteria("K29.7", "dp")

        assert criteria is not None
        assert isinstance(criteria, EligibilityCriteria)

    def test_retrieve_criteria_structure(
        self, rag_engine, sample_chunks_guide, temp_data_dir
    ):
        """The returned criteria echo the request and expose list-typed rule fields."""
        self._setup_guide_data(rag_engine, sample_chunks_guide, temp_data_dir)

        criteria = rag_engine.retrieve_eligibility_criteria("A00.0", "dp")

        assert criteria is not None
        assert criteria.code == "A00.0"
        assert criteria.code_type == "dp"
        assert criteria.criteria_text is not None
        assert isinstance(criteria.exclusion_rules, list)
        assert isinstance(criteria.hierarchization_rules, list)
        assert criteria.guide_section is not None

    def test_extract_exclusion_rules(self, rag_engine):
        """Three differently-worded exclusion lines give three exclusion rules."""
        text = """Critères DP
- Exclut: les diagnostics niés
- À l'exclusion de: les antécédents
- Ne pas coder: les suspicions"""

        rules = rag_engine._extract_exclusion_rules(text)

        assert len(rules) == 3
        assert any("niés" in rule for rule in rules)
        assert any("antécédents" in rule for rule in rules)
        assert any("suspicions" in rule for rule in rules)

    def test_extract_hierarchization_rules(self, rag_engine):
        """Two hierarchisation/priority lines give two hierarchisation rules."""
        text = """Critères DP
- Hiérarchisation: privilégier le diagnostic le plus grave
- Priorité: diagnostic principal avant associé"""

        rules = rag_engine._extract_hierarchization_rules(text)

        assert len(rules) == 2
        assert any("grave" in rule for rule in rules)
        assert any("principal" in rule for rule in rules)

    def _setup_guide_data(self, rag_engine, chunks, temp_data_dir):
        """Write Guide MCO chunks and a FAISS index (version "2026"); install a fake embedder.

        Vectors come from a seeded generator so that every run sees the same
        index and query.
        """
        version = "2026"
        chunks_path = temp_data_dir / f"guide_mco_{version}_chunks.json"
        with open(chunks_path, "w", encoding="utf-8") as f:
            chunks_data = [chunk.model_dump() for chunk in chunks]
            json.dump(chunks_data, f, ensure_ascii=False, default=str)

        # One normalised vector per chunk in a flat L2 index.
        dimension = 384
        rng = np.random.default_rng(0)
        vectors = rng.random((len(chunks), dimension), dtype=np.float32)
        faiss.normalize_L2(vectors)
        index = faiss.IndexFlatL2(dimension)
        index.add(vectors)
        index_path = temp_data_dir / f"guide_mco_{version}_index.faiss"
        faiss.write_index(index, str(index_path))

        # Fake embeddings model: encode() returns a normalised 1-D query vector.
        query_vector = rng.random((1, dimension), dtype=np.float32)
        faiss.normalize_L2(query_vector)
        mock_model = Mock()
        mock_model.encode.return_value = query_vector[0]
        rag_engine._embeddings_model = mock_model
class TestCaching:
    """Tests for the engine's caching of chunks and FAISS indexes."""

    def test_chunks_are_cached(self, rag_engine, sample_chunks_cim10, temp_data_dir):
        """Loading the same chunks twice returns the very same object."""
        serialised = json.dumps(
            [chunk.model_dump() for chunk in sample_chunks_cim10],
            ensure_ascii=False,
            default=str,
        )
        (temp_data_dir / "cim10_2026_chunks.json").write_text(serialised, encoding="utf-8")

        first_load = rag_engine._load_chunks("cim10", "2026")
        second_load = rag_engine._load_chunks("cim10", "2026")  # expected to hit the cache

        assert first_load is second_load

    def test_faiss_index_is_cached(self, rag_engine, temp_data_dir):
        """Loading the same FAISS index twice returns the very same object."""
        dimension = 384
        flat_index = faiss.IndexFlatL2(dimension)
        flat_index.add(np.random.rand(3, dimension).astype(np.float32))
        faiss.write_index(flat_index, str(temp_data_dir / "cim10_2026_index.faiss"))

        first_load = rag_engine._load_faiss_index("cim10", "2026")
        second_load = rag_engine._load_faiss_index("cim10", "2026")  # expected to hit the cache

        assert first_load is second_load
class TestErrorHandling:
    """Error-handling tests."""

    def test_load_chunks_raises_on_missing_file(self, rag_engine):
        """``_load_chunks`` raises ``FileNotFoundError`` when the chunks file is absent."""
        with pytest.raises(FileNotFoundError):
            rag_engine._load_chunks("cim10", "9999")

    def test_load_faiss_index_raises_on_missing_file(self, rag_engine):
        """``_load_faiss_index`` raises ``FileNotFoundError`` when the index file is absent."""
        with pytest.raises(FileNotFoundError):
            rag_engine._load_faiss_index("cim10", "9999")

    def test_search_icd10_handles_invalid_chunk_index(
        self, rag_engine, sample_chunks_cim10, temp_data_dir
    ):
        """``search_icd10`` survives a FAISS index that has more rows than chunks.

        The index holds 10 vectors while only 3 chunks exist, so some nearest
        neighbours point at non-existent chunks. Vectors come from a seeded
        generator: which rows are nearest to the query would otherwise change
        from run to run.
        """
        chunks_path = temp_data_dir / "cim10_2026_chunks.json"
        with open(chunks_path, "w", encoding="utf-8") as f:
            chunks_data = [chunk.model_dump() for chunk in sample_chunks_cim10]
            json.dump(chunks_data, f, ensure_ascii=False, default=str)

        # Deliberately more vectors than chunks to simulate an inconsistent index.
        dimension = 384
        rng = np.random.default_rng(0)
        vectors = rng.random((10, dimension), dtype=np.float32)
        faiss.normalize_L2(vectors)
        index = faiss.IndexFlatL2(dimension)
        index.add(vectors)
        index_path = temp_data_dir / "cim10_2026_index.faiss"
        faiss.write_index(index, str(index_path))

        # Fake embeddings model: encode() returns a normalised 1-D query vector.
        query_vector = rng.random((1, dimension), dtype=np.float32)
        faiss.normalize_L2(query_vector)
        mock_model = Mock()
        mock_model.encode.return_value = query_vector[0]
        rag_engine._embeddings_model = mock_model

        # Reranker stub: one score per valid chunk only.
        mock_reranker = Mock()
        mock_reranker.predict.return_value = np.array([0.9, 0.8, 0.7])
        rag_engine._reranker_model = mock_reranker

        # Must not crash even though some retrieved indexes are out of range.
        candidates = rag_engine.search_icd10("test", top_k=5, version="2026")

        assert len(candidates) > 0
        # Each candidate either has a recognised code or at least carries chunk text.
        assert all(c.code != "UNKNOWN" or len(c.chunk_text) > 0 for c in candidates)
class TestReranking:
    """Tests for cross-encoder reranking."""

    @staticmethod
    def _install_reranker(engine, scores):
        """Attach to *engine* a fake reranker whose ``predict`` returns *scores*."""
        fake = Mock()
        fake.predict.return_value = np.array(scores)
        engine._reranker_model = fake
        return fake

    def test_rerank_results_returns_reranked_list(
        self, rag_engine, sample_chunks_cim10
    ):
        """Reranking returns ``(index, score)`` pairs and calls the cross-encoder once."""
        fake = self._install_reranker(rag_engine, [0.8, 0.6, 0.9])
        pool = [(0, 0.5), (1, 0.4), (2, 0.3)]

        ordered = rag_engine._rerank_results(
            "test query", pool, sample_chunks_cim10, top_k=3
        )

        assert len(ordered) > 0
        assert all(isinstance(pair, tuple) for pair in ordered)
        assert all(len(pair) == 2 for pair in ordered)
        fake.predict.assert_called_once()

    def test_rerank_results_sorts_by_score(self, rag_engine, sample_chunks_cim10):
        """Output is sorted by descending reranker score."""
        # Chunk 2 gets the best score, then chunk 0, then chunk 1.
        self._install_reranker(rag_engine, [0.5, 0.3, 0.9])
        pool = [(0, 0.5), (1, 0.4), (2, 0.3)]

        ordered = rag_engine._rerank_results(
            "test query", pool, sample_chunks_cim10, top_k=3
        )

        ordered_scores = [score for _, score in ordered]
        assert ordered_scores == sorted(ordered_scores, reverse=True)
        # Chunk 2 (score 0.9) comes first.
        assert ordered[0][0] == 2

    def test_rerank_results_boosts_alphabetical_index(self, rag_engine):
        """With equal reranker scores, an ``alphabetical_index`` chunk outranks a code block."""
        code_block = Chunk(
            chunk_id="test_0",
            referentiel_type="cim10",
            referentiel_version="2026",
            content="A00.0 Test code",
            metadata={"chunk_type": "code_block"},
            chunk_index=0,
        )
        alpha_entry = Chunk(
            chunk_id="test_1",
            referentiel_type="cim10",
            referentiel_version="2026",
            content="Gastrite -> K29.7",
            metadata={"chunk_type": "alphabetical_index"},
            chunk_index=1,
        )
        chunks = [code_block, alpha_entry]
        # Identical raw scores, so only the bonus can separate the two.
        self._install_reranker(rag_engine, [0.7, 0.7])
        pool = [(0, 0.5), (1, 0.4)]

        ordered = rag_engine._rerank_results("test query", pool, chunks, top_k=2)

        # The alphabetical-index chunk is expected to receive a 0.1 bonus.
        alpha_score = next(score for idx, score in ordered if idx == 1)
        block_score = next(score for idx, score in ordered if idx == 0)
        assert alpha_score > block_score
        assert ordered[0][0] == 1

    def test_rerank_results_respects_top_k(self, rag_engine, sample_chunks_cim10):
        """Only ``top_k`` results are returned."""
        self._install_reranker(rag_engine, [0.8, 0.6, 0.9])
        pool = [(0, 0.5), (1, 0.4), (2, 0.3)]

        ordered = rag_engine._rerank_results(
            "test query", pool, sample_chunks_cim10, top_k=2
        )

        assert len(ordered) == 2

    def test_rerank_results_handles_empty_candidates(
        self, rag_engine, sample_chunks_cim10
    ):
        """An empty candidate list is returned as an empty list."""
        ordered = rag_engine._rerank_results(
            "test query", [], sample_chunks_cim10, top_k=10
        )

        assert ordered == []

    def test_rerank_results_handles_invalid_chunk_index(
        self, rag_engine, sample_chunks_cim10
    ):
        """Candidates whose chunk index is out of range are dropped."""
        # The cross-encoder is only expected to score the single valid chunk.
        self._install_reranker(rag_engine, [0.8])
        pool = [(0, 0.5), (999, 0.4)]

        ordered = rag_engine._rerank_results(
            "test query", pool, sample_chunks_cim10, top_k=10
        )

        assert len(ordered) == 1
        assert ordered[0][0] == 0

    def test_rerank_results_handles_reranker_error(
        self, rag_engine, sample_chunks_cim10
    ):
        """When the cross-encoder raises, the original candidates are returned."""
        failing = Mock()
        failing.predict.side_effect = Exception("Reranker error")
        rag_engine._reranker_model = failing
        pool = [(0, 0.5), (1, 0.4), (2, 0.3)]

        ordered = rag_engine._rerank_results(
            "test query", pool, sample_chunks_cim10, top_k=3
        )

        # Fallback: the first top_k original candidates, untouched.
        assert len(ordered) == 3
        assert ordered == pool[:3]

    def test_get_reranker_model_loads_model(self, rag_engine):
        """``_get_reranker_model`` builds the model through ``CrossEncoder``."""
        with patch("sentence_transformers.CrossEncoder") as cross_encoder_cls:
            sentinel_model = Mock()
            cross_encoder_cls.return_value = sentinel_model

            loaded = rag_engine._get_reranker_model()

            assert loaded == sentinel_model
            cross_encoder_cls.assert_called_once()

    def test_get_reranker_model_caches_model(self, rag_engine):
        """A second call reuses the model instead of constructing a new one."""
        with patch("sentence_transformers.CrossEncoder") as cross_encoder_cls:
            cross_encoder_cls.return_value = Mock()

            first = rag_engine._get_reranker_model()
            second = rag_engine._get_reranker_model()

            assert first is second
            # CrossEncoder must have been instantiated only once.
            cross_encoder_cls.assert_called_once()
class TestSearchWithReranking:
    """Tests for searches with the reranking step wired in."""

    def test_search_icd10_uses_reranking(
        self, rag_engine, sample_chunks_cim10, temp_data_dir
    ):
        """``search_icd10`` calls the reranker once and tags candidates "reranked"."""
        self._setup_test_data(rag_engine, sample_chunks_cim10, temp_data_dir, "cim10")
        mock_reranker = Mock()
        mock_reranker.predict.return_value = np.array([0.9, 0.8, 0.7])
        rag_engine._reranker_model = mock_reranker

        candidates = rag_engine.search_icd10("gastrite", top_k=3, version="2026")

        mock_reranker.predict.assert_called_once()
        assert len(candidates) > 0
        assert all(c.source == "reranked" for c in candidates)

    def test_search_ccam_uses_reranking(
        self, rag_engine, sample_chunks_ccam, temp_data_dir
    ):
        """``search_ccam`` calls the reranker once and tags candidates "reranked"."""
        self._setup_test_data(rag_engine, sample_chunks_ccam, temp_data_dir, "ccam")
        mock_reranker = Mock()
        mock_reranker.predict.return_value = np.array([0.9, 0.8])
        rag_engine._reranker_model = mock_reranker

        candidates = rag_engine.search_ccam("appendicectomie", top_k=2, version="2025")

        mock_reranker.predict.assert_called_once()
        assert len(candidates) > 0
        assert all(c.source == "reranked" for c in candidates)

    def test_search_icd10_alphabetical_index_prioritized(self, rag_engine, temp_data_dir):
        """With equal reranker scores, the alphabetical-index chunk is returned first."""
        chunks = [
            Chunk(
                chunk_id="cim10_2026_0",
                referentiel_type="cim10",
                referentiel_version="2026",
                content="K29.7 Gastrite, sans précision",
                metadata={"chunk_type": "code_block"},
                chunk_index=0,
            ),
            Chunk(
                chunk_id="cim10_2026_1",
                referentiel_type="cim10",
                referentiel_version="2026",
                content="Gastrite -> K29.7",
                metadata={"chunk_type": "alphabetical_index"},
                chunk_index=1,
            ),
        ]
        self._setup_test_data(rag_engine, chunks, temp_data_dir, "cim10")
        # Identical raw scores, so only the alphabetical-index bonus can decide.
        mock_reranker = Mock()
        mock_reranker.predict.return_value = np.array([0.7, 0.7])
        rag_engine._reranker_model = mock_reranker

        candidates = rag_engine.search_icd10("gastrite", top_k=2, version="2026")

        # The first candidate is expected to be the alphabetical-index chunk
        # (thanks to the 0.1 bonus).
        assert len(candidates) >= 1
        assert candidates[0].chunk_id == "cim10_2026_1"

    def _setup_test_data(self, rag_engine, chunks, temp_data_dir, ref_type):
        """Write chunks and a FAISS index to disk and install a fake embeddings model.

        Files are named ``{ref_type}_{version}_chunks.json`` and
        ``{ref_type}_{version}_index.faiss``; the version is "2026" for
        "cim10" and "2025" otherwise. Vectors come from a seeded generator so
        that every run sees the same index and query.
        """
        version = "2026" if ref_type == "cim10" else "2025"
        chunks_path = temp_data_dir / f"{ref_type}_{version}_chunks.json"
        with open(chunks_path, "w", encoding="utf-8") as f:
            chunks_data = [chunk.model_dump() for chunk in chunks]
            json.dump(chunks_data, f, ensure_ascii=False, default=str)

        # One normalised vector per chunk in a flat L2 index.
        dimension = 384
        rng = np.random.default_rng(0)
        vectors = rng.random((len(chunks), dimension), dtype=np.float32)
        faiss.normalize_L2(vectors)
        index = faiss.IndexFlatL2(dimension)
        index.add(vectors)
        index_path = temp_data_dir / f"{ref_type}_{version}_index.faiss"
        faiss.write_index(index, str(index_path))

        # Fake embeddings model: encode() returns a normalised 1-D query vector.
        query_vector = rng.random((1, dimension), dtype=np.float32)
        faiss.normalize_L2(query_vector)
        mock_model = Mock()
        mock_model.encode.return_value = query_vector[0]
        rag_engine._embeddings_model = mock_model