""" Tests unitaires pour le RAGEngine. Ce module teste: - La recherche BM25 - La recherche vectorielle - La fusion RRF (Reciprocal Rank Fusion) - La recherche hybride CIM-10 et CCAM - L'extraction des critères d'éligibilité """ import json import tempfile from pathlib import Path from unittest.mock import MagicMock, Mock, patch import faiss import numpy as np import pytest from pipeline_mco_pmsi.rag.rag_engine import ( CodeCandidate, EligibilityCriteria, RAGEngine, ) from pipeline_mco_pmsi.rag.referentiels_manager import Chunk, ReferentielsManager @pytest.fixture def temp_data_dir(): """Crée un répertoire temporaire pour les tests.""" with tempfile.TemporaryDirectory() as tmpdir: yield Path(tmpdir) @pytest.fixture def mock_referentiels_manager(): """Crée un mock du ReferentielsManager.""" manager = Mock(spec=ReferentielsManager) manager.data_dir = Path("data/referentiels") return manager @pytest.fixture def sample_chunks_cim10(): """Crée des chunks CIM-10 de test.""" return [ Chunk( chunk_id="cim10_2026_0", referentiel_type="cim10", referentiel_version="2026", content="A00.0 Choléra dû à Vibrio cholerae 01, biovar cholerae\nInclus: choléra classique", metadata={"chunk_type": "code_block", "chapter": "Chapitre I"}, chunk_index=0, ), Chunk( chunk_id="cim10_2026_1", referentiel_type="cim10", referentiel_version="2026", content="A00.1 Choléra dû à Vibrio cholerae 01, biovar El Tor\nInclus: choléra El Tor", metadata={"chunk_type": "code_block", "chapter": "Chapitre I"}, chunk_index=1, ), Chunk( chunk_id="cim10_2026_2", referentiel_type="cim10", referentiel_version="2026", content="K29.7 Gastrite, sans précision\nInclus: gastrite SAI", metadata={"chunk_type": "code_block", "chapter": "Chapitre XI"}, chunk_index=2, ), ] @pytest.fixture def sample_chunks_ccam(): """Crée des chunks CCAM de test.""" return [ Chunk( chunk_id="ccam_2025_0", referentiel_type="ccam", referentiel_version="2025", content="YYYY001 Appendicectomie par laparotomie\nNote: Ablation de l'appendice", metadata={"chunk_type": "acte", "section": "Section 1"}, chunk_index=0, ), Chunk( chunk_id="ccam_2025_1", referentiel_type="ccam", referentiel_version="2025", content="YYYY002+ABC Appendicectomie par cœlioscopie\nNote: Ablation de l'appendice par voie cœlioscopique", metadata={"chunk_type": "acte", "section": "Section 1"}, chunk_index=1, ), ] @pytest.fixture def sample_chunks_guide(): """Crée des chunks du Guide MCO de test.""" return [ Chunk( chunk_id="guide_mco_2026_0", referentiel_type="guide_mco", referentiel_version="2026", content="""Critères d'éligibilité DP - Le DP doit être le diagnostic ayant mobilisé l'essentiel des ressources - Exclut: les diagnostics niés - Exclut: les antécédents - Hiérarchisation: privilégier le diagnostic le plus grave""", metadata={"chunk_type": "section", "section": "Chapitre 1"}, chunk_index=0, ), ] @pytest.fixture def rag_engine(mock_referentiels_manager, temp_data_dir): """Crée une instance de RAGEngine pour les tests.""" engine = RAGEngine(mock_referentiels_manager, data_dir=temp_data_dir) return engine class TestRAGEngineInit: """Tests d'initialisation du RAGEngine.""" def test_init_creates_engine(self, mock_referentiels_manager, temp_data_dir): """Test que le RAGEngine s'initialise correctement.""" engine = RAGEngine(mock_referentiels_manager, data_dir=temp_data_dir) assert engine.referentiels_manager == mock_referentiels_manager assert engine.data_dir == temp_data_dir assert engine._bm25_indexes == {} assert engine._faiss_indexes == {} assert engine._chunks_cache == {} assert engine._embeddings_model is None class TestBM25Search: """Tests de la recherche BM25.""" def test_build_bm25_index(self, rag_engine, sample_chunks_cim10): """Test la construction d'un index BM25.""" bm25_index = rag_engine._build_bm25_index(sample_chunks_cim10) assert bm25_index is not None assert len(bm25_index.doc_freqs) > 0 def test_bm25_search_returns_results( self, rag_engine, sample_chunks_cim10, temp_data_dir ): """Test que la recherche BM25 retourne des résultats.""" # Sauvegarder les chunks chunks_path = temp_data_dir / "cim10_2026_chunks.json" with open(chunks_path, "w", encoding="utf-8") as f: chunks_data = [chunk.model_dump() for chunk in sample_chunks_cim10] json.dump(chunks_data, f, ensure_ascii=False, default=str) # Effectuer la recherche results = rag_engine._bm25_search("choléra", "cim10", "2026", top_k=2) assert len(results) > 0 assert all(isinstance(r, tuple) for r in results) assert all(len(r) == 2 for r in results) # Vérifier que les scores sont des floats assert all(isinstance(r[1], float) for r in results) def test_bm25_search_ranks_relevant_higher( self, rag_engine, sample_chunks_cim10, temp_data_dir ): """Test que BM25 classe les résultats pertinents plus haut.""" # Sauvegarder les chunks chunks_path = temp_data_dir / "cim10_2026_chunks.json" with open(chunks_path, "w", encoding="utf-8") as f: chunks_data = [chunk.model_dump() for chunk in sample_chunks_cim10] json.dump(chunks_data, f, ensure_ascii=False, default=str) # Rechercher "gastrite" - devrait trouver le chunk 2 en premier results = rag_engine._bm25_search("gastrite", "cim10", "2026", top_k=3) # Le premier résultat devrait être le chunk contenant "gastrite" top_chunk_idx = results[0][0] assert sample_chunks_cim10[top_chunk_idx].content.lower().find("gastrite") != -1 class TestVectorSearch: """Tests de la recherche vectorielle.""" def test_vector_search_with_mock_index( self, rag_engine, sample_chunks_cim10, temp_data_dir ): """Test la recherche vectorielle avec un index FAISS mocké.""" # Créer un index FAISS simple pour le test dimension = 384 # Dimension typique pour MiniLM index = faiss.IndexFlatL2(dimension) # Ajouter des vecteurs aléatoires vectors = np.random.rand(len(sample_chunks_cim10), dimension).astype(np.float32) # Normaliser pour cosine similarity faiss.normalize_L2(vectors) index.add(vectors) # Sauvegarder l'index index_path = temp_data_dir / "cim10_2026_index.faiss" faiss.write_index(index, str(index_path)) # Sauvegarder les chunks chunks_path = temp_data_dir / "cim10_2026_chunks.json" with open(chunks_path, "w", encoding="utf-8") as f: chunks_data = [chunk.model_dump() for chunk in sample_chunks_cim10] json.dump(chunks_data, f, ensure_ascii=False, default=str) # Mocker le modèle d'embeddings mock_model = Mock() mock_query_vector = np.random.rand(dimension).astype(np.float32) faiss.normalize_L2(mock_query_vector.reshape(1, -1)) mock_model.encode.return_value = mock_query_vector rag_engine._embeddings_model = mock_model # Effectuer la recherche results = rag_engine._vector_search("test query", "cim10", "2026", top_k=2) assert len(results) > 0 assert all(isinstance(r, tuple) for r in results) assert all(len(r) == 2 for r in results) # Vérifier que les scores de similarité sont dans [0, 1] assert all(0.0 <= r[1] <= 1.0 for r in results) class TestReciprocalRankFusion: """Tests de la fusion RRF.""" def test_rrf_fusion_combines_results(self, rag_engine): """Test que RRF fusionne correctement les résultats.""" bm25_results = [(0, 10.5), (1, 8.2), (2, 5.1)] vector_results = [(1, 0.95), (0, 0.88), (3, 0.75)] fused = rag_engine._reciprocal_rank_fusion(bm25_results, vector_results, k=60) # Vérifier que tous les chunks uniques sont présents chunk_indices = [idx for idx, _ in fused] assert set(chunk_indices) == {0, 1, 2, 3} # Vérifier que les résultats sont triés par score décroissant scores = [score for _, score in fused] assert scores == sorted(scores, reverse=True) def test_rrf_boosts_common_results(self, rag_engine): """Test que RRF booste les résultats présents dans les deux listes.""" # Chunk 0 est en tête des deux listes bm25_results = [(0, 10.0), (1, 5.0), (2, 3.0)] vector_results = [(0, 0.95), (3, 0.80), (1, 0.70)] fused = rag_engine._reciprocal_rank_fusion(bm25_results, vector_results, k=60) # Le chunk 0 devrait avoir le score le plus élevé top_chunk = fused[0][0] assert top_chunk == 0 def test_rrf_with_empty_lists(self, rag_engine): """Test RRF avec des listes vides.""" fused = rag_engine._reciprocal_rank_fusion([], [], k=60) assert fused == [] def test_rrf_with_one_empty_list(self, rag_engine): """Test RRF avec une liste vide.""" bm25_results = [(0, 10.0), (1, 5.0)] fused = rag_engine._reciprocal_rank_fusion(bm25_results, [], k=60) assert len(fused) == 2 assert all(idx in [0, 1] for idx, _ in fused) class TestSearchICD10: """Tests de la recherche CIM-10.""" def test_search_icd10_returns_candidates( self, rag_engine, sample_chunks_cim10, temp_data_dir ): """Test que search_icd10 retourne des candidats.""" # Préparer les données self._setup_test_data(rag_engine, sample_chunks_cim10, temp_data_dir, "cim10") # Mocker le reranker pour éviter de charger le modèle réel mock_reranker = Mock() mock_reranker.predict.return_value = np.array([0.9, 0.8, 0.7]) rag_engine._reranker_model = mock_reranker # Effectuer la recherche candidates = rag_engine.search_icd10("gastrite", top_k=2, version="2026") assert len(candidates) > 0 assert all(isinstance(c, CodeCandidate) for c in candidates) def test_search_icd10_candidate_structure( self, rag_engine, sample_chunks_cim10, temp_data_dir ): """Test la structure des candidats retournés.""" self._setup_test_data(rag_engine, sample_chunks_cim10, temp_data_dir, "cim10") # Mocker le reranker mock_reranker = Mock() mock_reranker.predict.return_value = np.array([0.9]) rag_engine._reranker_model = mock_reranker candidates = rag_engine.search_icd10("choléra", top_k=1, version="2026") assert len(candidates) > 0 candidate = candidates[0] # Vérifier les champs obligatoires assert candidate.code is not None assert candidate.label is not None assert 0.0 <= candidate.similarity_score <= 1.0 assert candidate.source == "reranked" assert candidate.chunk_id is not None assert candidate.chunk_text is not None def test_search_icd10_respects_top_k( self, rag_engine, sample_chunks_cim10, temp_data_dir ): """Test que search_icd10 respecte le paramètre top_k.""" self._setup_test_data(rag_engine, sample_chunks_cim10, temp_data_dir, "cim10") # Mocker le reranker mock_reranker = Mock() mock_reranker.predict.return_value = np.array([0.9, 0.8]) rag_engine._reranker_model = mock_reranker candidates = rag_engine.search_icd10("test", top_k=2, version="2026") assert len(candidates) <= 2 def _setup_test_data(self, rag_engine, chunks, temp_data_dir, ref_type): """Helper pour préparer les données de test.""" # Sauvegarder les chunks version = "2026" if ref_type == "cim10" else "2025" chunks_path = temp_data_dir / f"{ref_type}_{version}_chunks.json" with open(chunks_path, "w", encoding="utf-8") as f: chunks_data = [chunk.model_dump() for chunk in chunks] json.dump(chunks_data, f, ensure_ascii=False, default=str) # Créer un index FAISS simple dimension = 384 index = faiss.IndexFlatL2(dimension) vectors = np.random.rand(len(chunks), dimension).astype(np.float32) faiss.normalize_L2(vectors) index.add(vectors) index_path = temp_data_dir / f"{ref_type}_{version}_index.faiss" faiss.write_index(index, str(index_path)) # Mocker le modèle d'embeddings mock_model = Mock() mock_query_vector = np.random.rand(dimension).astype(np.float32) faiss.normalize_L2(mock_query_vector.reshape(1, -1)) mock_model.encode.return_value = mock_query_vector rag_engine._embeddings_model = mock_model class TestSearchCCAM: """Tests de la recherche CCAM.""" def test_search_ccam_returns_candidates( self, rag_engine, sample_chunks_ccam, temp_data_dir ): """Test que search_ccam retourne des candidats.""" self._setup_test_data(rag_engine, sample_chunks_ccam, temp_data_dir, "ccam") # Mocker le reranker mock_reranker = Mock() mock_reranker.predict.return_value = np.array([0.9, 0.8]) rag_engine._reranker_model = mock_reranker candidates = rag_engine.search_ccam("appendicectomie", top_k=2, version="2025") assert len(candidates) > 0 assert all(isinstance(c, CodeCandidate) for c in candidates) def test_search_ccam_extracts_code_with_extension( self, rag_engine, sample_chunks_ccam, temp_data_dir ): """Test que search_ccam extrait correctement les codes avec extension ATIH.""" self._setup_test_data(rag_engine, sample_chunks_ccam, temp_data_dir, "ccam") # Mocker le reranker mock_reranker = Mock() mock_reranker.predict.return_value = np.array([0.9, 0.8]) rag_engine._reranker_model = mock_reranker candidates = rag_engine.search_ccam("cœlioscopie", top_k=2, version="2025") # Vérifier qu'au moins un candidat a un code avec extension codes = [c.code for c in candidates] # Au moins un code devrait être présent assert len(codes) > 0 def _setup_test_data(self, rag_engine, chunks, temp_data_dir, ref_type): """Helper pour préparer les données de test.""" version = "2025" chunks_path = temp_data_dir / f"{ref_type}_{version}_chunks.json" with open(chunks_path, "w", encoding="utf-8") as f: chunks_data = [chunk.model_dump() for chunk in chunks] json.dump(chunks_data, f, ensure_ascii=False, default=str) dimension = 384 index = faiss.IndexFlatL2(dimension) vectors = np.random.rand(len(chunks), dimension).astype(np.float32) faiss.normalize_L2(vectors) index.add(vectors) index_path = temp_data_dir / f"{ref_type}_{version}_index.faiss" faiss.write_index(index, str(index_path)) mock_model = Mock() mock_query_vector = np.random.rand(dimension).astype(np.float32) faiss.normalize_L2(mock_query_vector.reshape(1, -1)) mock_model.encode.return_value = mock_query_vector rag_engine._embeddings_model = mock_model class TestExtractCodeAndLabel: """Tests de l'extraction de code et libellé.""" def test_extract_cim10_code_and_label(self, rag_engine): """Test l'extraction d'un code CIM-10 et son libellé.""" chunk_text = "A00.0 Choléra dû à Vibrio cholerae 01, biovar cholerae\nInclus: choléra classique" code, label = rag_engine._extract_code_and_label(chunk_text, "cim10") assert code == "A00.0" assert "Choléra" in label def test_extract_ccam_code_and_label(self, rag_engine): """Test l'extraction d'un code CCAM et son libellé.""" chunk_text = "YYYY001 Appendicectomie par laparotomie\nNote: Ablation de l'appendice" code, label = rag_engine._extract_code_and_label(chunk_text, "ccam") assert code == "YYYY001" assert "Appendicectomie" in label def test_extract_ccam_code_with_extension(self, rag_engine): """Test l'extraction d'un code CCAM avec extension ATIH.""" chunk_text = "YYYY002+ABC Appendicectomie par cœlioscopie" code, label = rag_engine._extract_code_and_label(chunk_text, "ccam") assert code == "YYYY002+ABC" assert "Appendicectomie" in label def test_extract_returns_unknown_for_invalid_format(self, rag_engine): """Test que l'extraction retourne UNKNOWN pour un format invalide.""" chunk_text = "Texte sans code valide" code, label = rag_engine._extract_code_and_label(chunk_text, "cim10") assert code == "UNKNOWN" assert len(label) > 0 class TestRetrieveEligibilityCriteria: """Tests de la récupération des critères d'éligibilité.""" def test_retrieve_criteria_returns_eligibility( self, rag_engine, sample_chunks_guide, temp_data_dir ): """Test que retrieve_eligibility_criteria retourne des critères.""" self._setup_guide_data(rag_engine, sample_chunks_guide, temp_data_dir) criteria = rag_engine.retrieve_eligibility_criteria("K29.7", "dp") assert criteria is not None assert isinstance(criteria, EligibilityCriteria) def test_retrieve_criteria_structure( self, rag_engine, sample_chunks_guide, temp_data_dir ): """Test la structure des critères d'éligibilité.""" self._setup_guide_data(rag_engine, sample_chunks_guide, temp_data_dir) criteria = rag_engine.retrieve_eligibility_criteria("A00.0", "dp") assert criteria is not None assert criteria.code == "A00.0" assert criteria.code_type == "dp" assert criteria.criteria_text is not None assert isinstance(criteria.exclusion_rules, list) assert isinstance(criteria.hierarchization_rules, list) assert criteria.guide_section is not None def test_extract_exclusion_rules(self, rag_engine): """Test l'extraction des règles d'exclusion.""" text = """Critères DP - Exclut: les diagnostics niés - À l'exclusion de: les antécédents - Ne pas coder: les suspicions""" rules = rag_engine._extract_exclusion_rules(text) assert len(rules) == 3 assert any("niés" in rule for rule in rules) assert any("antécédents" in rule for rule in rules) assert any("suspicions" in rule for rule in rules) def test_extract_hierarchization_rules(self, rag_engine): """Test l'extraction des règles de hiérarchisation.""" text = """Critères DP - Hiérarchisation: privilégier le diagnostic le plus grave - Priorité: diagnostic principal avant associé""" rules = rag_engine._extract_hierarchization_rules(text) assert len(rules) == 2 assert any("grave" in rule for rule in rules) assert any("principal" in rule for rule in rules) def _setup_guide_data(self, rag_engine, chunks, temp_data_dir): """Helper pour préparer les données du guide.""" version = "2026" chunks_path = temp_data_dir / f"guide_mco_{version}_chunks.json" with open(chunks_path, "w", encoding="utf-8") as f: chunks_data = [chunk.model_dump() for chunk in chunks] json.dump(chunks_data, f, ensure_ascii=False, default=str) dimension = 384 index = faiss.IndexFlatL2(dimension) vectors = np.random.rand(len(chunks), dimension).astype(np.float32) faiss.normalize_L2(vectors) index.add(vectors) index_path = temp_data_dir / f"guide_mco_{version}_index.faiss" faiss.write_index(index, str(index_path)) mock_model = Mock() mock_query_vector = np.random.rand(dimension).astype(np.float32) faiss.normalize_L2(mock_query_vector.reshape(1, -1)) mock_model.encode.return_value = mock_query_vector rag_engine._embeddings_model = mock_model class TestCaching: """Tests du système de cache.""" def test_chunks_are_cached(self, rag_engine, sample_chunks_cim10, temp_data_dir): """Test que les chunks sont mis en cache.""" chunks_path = temp_data_dir / "cim10_2026_chunks.json" with open(chunks_path, "w", encoding="utf-8") as f: chunks_data = [chunk.model_dump() for chunk in sample_chunks_cim10] json.dump(chunks_data, f, ensure_ascii=False, default=str) # Premier chargement chunks1 = rag_engine._load_chunks("cim10", "2026") # Deuxième chargement (devrait utiliser le cache) chunks2 = rag_engine._load_chunks("cim10", "2026") assert chunks1 is chunks2 # Même objet en mémoire def test_faiss_index_is_cached(self, rag_engine, temp_data_dir): """Test que l'index FAISS est mis en cache.""" # Créer un index dimension = 384 index = faiss.IndexFlatL2(dimension) vectors = np.random.rand(3, dimension).astype(np.float32) index.add(vectors) index_path = temp_data_dir / "cim10_2026_index.faiss" faiss.write_index(index, str(index_path)) # Premier chargement index1 = rag_engine._load_faiss_index("cim10", "2026") # Deuxième chargement (devrait utiliser le cache) index2 = rag_engine._load_faiss_index("cim10", "2026") assert index1 is index2 # Même objet en mémoire class TestErrorHandling: """Tests de la gestion d'erreurs.""" def test_load_chunks_raises_on_missing_file(self, rag_engine): """Test que _load_chunks lève une erreur si le fichier n'existe pas.""" with pytest.raises(FileNotFoundError): rag_engine._load_chunks("cim10", "9999") def test_load_faiss_index_raises_on_missing_file(self, rag_engine): """Test que _load_faiss_index lève une erreur si le fichier n'existe pas.""" with pytest.raises(FileNotFoundError): rag_engine._load_faiss_index("cim10", "9999") def test_search_icd10_handles_invalid_chunk_index( self, rag_engine, sample_chunks_cim10, temp_data_dir ): """Test que search_icd10 gère les index de chunk invalides.""" # Préparer les données chunks_path = temp_data_dir / "cim10_2026_chunks.json" with open(chunks_path, "w", encoding="utf-8") as f: chunks_data = [chunk.model_dump() for chunk in sample_chunks_cim10] json.dump(chunks_data, f, ensure_ascii=False, default=str) dimension = 384 index = faiss.IndexFlatL2(dimension) # Ajouter plus de vecteurs que de chunks (pour simuler un index invalide) vectors = np.random.rand(10, dimension).astype(np.float32) faiss.normalize_L2(vectors) index.add(vectors) index_path = temp_data_dir / "cim10_2026_index.faiss" faiss.write_index(index, str(index_path)) mock_model = Mock() mock_query_vector = np.random.rand(dimension).astype(np.float32) faiss.normalize_L2(mock_query_vector.reshape(1, -1)) mock_model.encode.return_value = mock_query_vector rag_engine._embeddings_model = mock_model # Mocker le reranker mock_reranker = Mock() # Retourner des scores pour les chunks valides seulement mock_reranker.predict.return_value = np.array([0.9, 0.8, 0.7]) rag_engine._reranker_model = mock_reranker # La recherche ne devrait pas crasher candidates = rag_engine.search_icd10("test", top_k=5, version="2026") # La recherche devrait retourner des candidats sans crasher # (même si certains index sont invalides) assert len(candidates) > 0 # Tous les candidats retournés doivent avoir des codes valides assert all(c.code != "UNKNOWN" or len(c.chunk_text) > 0 for c in candidates) class TestReranking: """Tests du reranking avec cross-encoder.""" def test_rerank_results_returns_reranked_list( self, rag_engine, sample_chunks_cim10 ): """Test que _rerank_results retourne une liste reclassée.""" # Préparer des candidats candidates = [(0, 0.5), (1, 0.4), (2, 0.3)] # Mocker le cross-encoder mock_reranker = Mock() mock_reranker.predict.return_value = np.array([0.8, 0.6, 0.9]) rag_engine._reranker_model = mock_reranker # Effectuer le reranking reranked = rag_engine._rerank_results( "test query", candidates, sample_chunks_cim10, top_k=3 ) assert len(reranked) > 0 assert all(isinstance(r, tuple) for r in reranked) assert all(len(r) == 2 for r in reranked) # Vérifier que le cross-encoder a été appelé mock_reranker.predict.assert_called_once() def test_rerank_results_sorts_by_score(self, rag_engine, sample_chunks_cim10): """Test que _rerank_results trie par score décroissant.""" candidates = [(0, 0.5), (1, 0.4), (2, 0.3)] # Mocker le cross-encoder avec des scores spécifiques mock_reranker = Mock() # Chunk 2 a le meilleur score, puis 0, puis 1 mock_reranker.predict.return_value = np.array([0.5, 0.3, 0.9]) rag_engine._reranker_model = mock_reranker reranked = rag_engine._rerank_results( "test query", candidates, sample_chunks_cim10, top_k=3 ) # Vérifier que les résultats sont triés par score décroissant scores = [score for _, score in reranked] assert scores == sorted(scores, reverse=True) # Le chunk 2 devrait être en premier (score 0.9) assert reranked[0][0] == 2 def test_rerank_results_boosts_alphabetical_index(self, rag_engine): """Test que _rerank_results booste les résultats de l'index alphabétique.""" # Créer des chunks avec et sans index alphabétique chunks = [ Chunk( chunk_id="test_0", referentiel_type="cim10", referentiel_version="2026", content="A00.0 Test code", metadata={"chunk_type": "code_block"}, chunk_index=0, ), Chunk( chunk_id="test_1", referentiel_type="cim10", referentiel_version="2026", content="Gastrite -> K29.7", metadata={"chunk_type": "alphabetical_index"}, chunk_index=1, ), ] candidates = [(0, 0.5), (1, 0.4)] # Mocker le cross-encoder avec des scores identiques mock_reranker = Mock() mock_reranker.predict.return_value = np.array([0.7, 0.7]) rag_engine._reranker_model = mock_reranker reranked = rag_engine._rerank_results("test query", candidates, chunks, top_k=2) # Le chunk 1 (index alphabétique) devrait avoir un score plus élevé # grâce au bonus de 0.1 chunk_1_score = next(score for idx, score in reranked if idx == 1) chunk_0_score = next(score for idx, score in reranked if idx == 0) assert chunk_1_score > chunk_0_score # Le chunk 1 devrait être en premier assert reranked[0][0] == 1 def test_rerank_results_respects_top_k(self, rag_engine, sample_chunks_cim10): """Test que _rerank_results respecte le paramètre top_k.""" candidates = [(0, 0.5), (1, 0.4), (2, 0.3)] mock_reranker = Mock() mock_reranker.predict.return_value = np.array([0.8, 0.6, 0.9]) rag_engine._reranker_model = mock_reranker reranked = rag_engine._rerank_results( "test query", candidates, sample_chunks_cim10, top_k=2 ) assert len(reranked) == 2 def test_rerank_results_handles_empty_candidates( self, rag_engine, sample_chunks_cim10 ): """Test que _rerank_results gère les candidats vides.""" reranked = rag_engine._rerank_results( "test query", [], sample_chunks_cim10, top_k=10 ) assert reranked == [] def test_rerank_results_handles_invalid_chunk_index( self, rag_engine, sample_chunks_cim10 ): """Test que _rerank_results gère les index de chunk invalides.""" # Candidat avec un index invalide candidates = [(0, 0.5), (999, 0.4)] mock_reranker = Mock() # Le cross-encoder ne sera appelé qu'avec le chunk valide mock_reranker.predict.return_value = np.array([0.8]) rag_engine._reranker_model = mock_reranker reranked = rag_engine._rerank_results( "test query", candidates, sample_chunks_cim10, top_k=10 ) # Seul le chunk valide devrait être retourné assert len(reranked) == 1 assert reranked[0][0] == 0 def test_rerank_results_handles_reranker_error( self, rag_engine, sample_chunks_cim10 ): """Test que _rerank_results gère les erreurs du cross-encoder.""" candidates = [(0, 0.5), (1, 0.4), (2, 0.3)] # Mocker le cross-encoder pour lever une erreur mock_reranker = Mock() mock_reranker.predict.side_effect = Exception("Reranker error") rag_engine._reranker_model = mock_reranker # Le reranking devrait retourner les candidats originaux en cas d'erreur reranked = rag_engine._rerank_results( "test query", candidates, sample_chunks_cim10, top_k=3 ) # Devrait retourner les candidats originaux (top_k) assert len(reranked) == 3 assert reranked == candidates[:3] def test_get_reranker_model_loads_model(self, rag_engine): """Test que _get_reranker_model charge le modèle cross-encoder.""" with patch("sentence_transformers.CrossEncoder") as mock_ce: mock_model = Mock() mock_ce.return_value = mock_model model = rag_engine._get_reranker_model() assert model == mock_model mock_ce.assert_called_once() def test_get_reranker_model_caches_model(self, rag_engine): """Test que _get_reranker_model met en cache le modèle.""" with patch("sentence_transformers.CrossEncoder") as mock_ce: mock_model = Mock() mock_ce.return_value = mock_model # Premier appel model1 = rag_engine._get_reranker_model() # Deuxième appel model2 = rag_engine._get_reranker_model() assert model1 is model2 # CrossEncoder ne devrait être appelé qu'une fois mock_ce.assert_called_once() class TestSearchWithReranking: """Tests de la recherche avec reranking intégré.""" def test_search_icd10_uses_reranking( self, rag_engine, sample_chunks_cim10, temp_data_dir ): """Test que search_icd10 utilise le reranking.""" # Préparer les données self._setup_test_data(rag_engine, sample_chunks_cim10, temp_data_dir, "cim10") # Mocker le reranker mock_reranker = Mock() mock_reranker.predict.return_value = np.array([0.9, 0.8, 0.7]) rag_engine._reranker_model = mock_reranker # Effectuer la recherche candidates = rag_engine.search_icd10("gastrite", top_k=3, version="2026") # Vérifier que le reranker a été appelé mock_reranker.predict.assert_called_once() # Vérifier que les candidats sont retournés assert len(candidates) > 0 assert all(c.source == "reranked" for c in candidates) def test_search_ccam_uses_reranking( self, rag_engine, sample_chunks_ccam, temp_data_dir ): """Test que search_ccam utilise le reranking.""" # Préparer les données self._setup_test_data(rag_engine, sample_chunks_ccam, temp_data_dir, "ccam") # Mocker le reranker mock_reranker = Mock() mock_reranker.predict.return_value = np.array([0.9, 0.8]) rag_engine._reranker_model = mock_reranker # Effectuer la recherche candidates = rag_engine.search_ccam("appendicectomie", top_k=2, version="2025") # Vérifier que le reranker a été appelé mock_reranker.predict.assert_called_once() # Vérifier que les candidats sont retournés assert len(candidates) > 0 assert all(c.source == "reranked" for c in candidates) def test_search_icd10_alphabetical_index_prioritized(self, rag_engine, temp_data_dir): """Test que search_icd10 priorise les résultats de l'index alphabétique.""" # Créer des chunks avec index alphabétique chunks = [ Chunk( chunk_id="cim10_2026_0", referentiel_type="cim10", referentiel_version="2026", content="K29.7 Gastrite, sans précision", metadata={"chunk_type": "code_block"}, chunk_index=0, ), Chunk( chunk_id="cim10_2026_1", referentiel_type="cim10", referentiel_version="2026", content="Gastrite -> K29.7", metadata={"chunk_type": "alphabetical_index"}, chunk_index=1, ), ] # Préparer les données self._setup_test_data(rag_engine, chunks, temp_data_dir, "cim10") # Mocker le reranker avec des scores identiques mock_reranker = Mock() mock_reranker.predict.return_value = np.array([0.7, 0.7]) rag_engine._reranker_model = mock_reranker # Effectuer la recherche candidates = rag_engine.search_icd10("gastrite", top_k=2, version="2026") # Le premier candidat devrait provenir de l'index alphabétique # (grâce au bonus de 0.1) assert len(candidates) >= 1 # Vérifier que le chunk_id du premier candidat est celui de l'index alphabétique assert candidates[0].chunk_id == "cim10_2026_1" def _setup_test_data(self, rag_engine, chunks, temp_data_dir, ref_type): """Helper pour préparer les données de test.""" version = "2026" if ref_type == "cim10" else "2025" chunks_path = temp_data_dir / f"{ref_type}_{version}_chunks.json" with open(chunks_path, "w", encoding="utf-8") as f: chunks_data = [chunk.model_dump() for chunk in chunks] json.dump(chunks_data, f, ensure_ascii=False, default=str) dimension = 384 index = faiss.IndexFlatL2(dimension) vectors = np.random.rand(len(chunks), dimension).astype(np.float32) faiss.normalize_L2(vectors) index.add(vectors) index_path = temp_data_dir / f"{ref_type}_{version}_index.faiss" faiss.write_index(index, str(index_path)) mock_model = Mock() mock_query_vector = np.random.rand(dimension).astype(np.float32) faiss.normalize_L2(mock_query_vector.reshape(1, -1)) mock_model.encode.return_value = mock_query_vector rag_engine._embeddings_model = mock_model