Initial commit
This commit is contained in:
921
tests/test_rag_engine.py
Normal file
921
tests/test_rag_engine.py
Normal file
@@ -0,0 +1,921 @@
|
||||
"""
|
||||
Tests unitaires pour le RAGEngine.
|
||||
|
||||
Ce module teste:
|
||||
- La recherche BM25
|
||||
- La recherche vectorielle
|
||||
- La fusion RRF (Reciprocal Rank Fusion)
|
||||
- La recherche hybride CIM-10 et CCAM
|
||||
- L'extraction des critères d'éligibilité
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from pipeline_mco_pmsi.rag.rag_engine import (
|
||||
CodeCandidate,
|
||||
EligibilityCriteria,
|
||||
RAGEngine,
|
||||
)
|
||||
from pipeline_mco_pmsi.rag.referentiels_manager import Chunk, ReferentielsManager
|
||||
|
||||
|
||||
@pytest.fixture
def temp_data_dir():
    """Yield a throwaway directory as a Path; it is deleted after the test."""
    with tempfile.TemporaryDirectory() as tmp:
        yield Path(tmp)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_referentiels_manager():
    """Build a spec'd ReferentielsManager mock exposing only data_dir."""
    mocked = Mock(spec=ReferentielsManager)
    mocked.data_dir = Path("data/referentiels")
    return mocked
|
||||
|
||||
|
||||
@pytest.fixture
def sample_chunks_cim10():
    """Provide three CIM-10 test chunks: two cholera codes and one gastritis code."""
    entries = [
        ("A00.0 Choléra dû à Vibrio cholerae 01, biovar cholerae\nInclus: choléra classique", "Chapitre I"),
        ("A00.1 Choléra dû à Vibrio cholerae 01, biovar El Tor\nInclus: choléra El Tor", "Chapitre I"),
        ("K29.7 Gastrite, sans précision\nInclus: gastrite SAI", "Chapitre XI"),
    ]
    return [
        Chunk(
            chunk_id=f"cim10_2026_{i}",
            referentiel_type="cim10",
            referentiel_version="2026",
            content=text,
            metadata={"chunk_type": "code_block", "chapter": chapter},
            chunk_index=i,
        )
        for i, (text, chapter) in enumerate(entries)
    ]
|
||||
|
||||
|
||||
@pytest.fixture
def sample_chunks_ccam():
    """Provide two CCAM test chunks: one plain code, one with an ATIH extension."""
    contents = [
        "YYYY001 Appendicectomie par laparotomie\nNote: Ablation de l'appendice",
        "YYYY002+ABC Appendicectomie par cœlioscopie\nNote: Ablation de l'appendice par voie cœlioscopique",
    ]
    return [
        Chunk(
            chunk_id=f"ccam_2025_{i}",
            referentiel_type="ccam",
            referentiel_version="2025",
            content=text,
            metadata={"chunk_type": "acte", "section": "Section 1"},
            chunk_index=i,
        )
        for i, text in enumerate(contents)
    ]
|
||||
|
||||
|
||||
@pytest.fixture
def sample_chunks_guide():
    """Provide a single Guide MCO chunk describing DP eligibility criteria."""
    criteria_text = """Critères d'éligibilité DP
- Le DP doit être le diagnostic ayant mobilisé l'essentiel des ressources
- Exclut: les diagnostics niés
- Exclut: les antécédents
- Hiérarchisation: privilégier le diagnostic le plus grave"""
    return [
        Chunk(
            chunk_id="guide_mco_2026_0",
            referentiel_type="guide_mco",
            referentiel_version="2026",
            content=criteria_text,
            metadata={"chunk_type": "section", "section": "Chapitre 1"},
            chunk_index=0,
        ),
    ]
|
||||
|
||||
|
||||
@pytest.fixture
def rag_engine(mock_referentiels_manager, temp_data_dir):
    """Provide a RAGEngine wired to the mocked manager and the temp data dir."""
    return RAGEngine(mock_referentiels_manager, data_dir=temp_data_dir)
|
||||
|
||||
|
||||
class TestRAGEngineInit:
    """Tests for RAGEngine construction."""

    def test_init_creates_engine(self, mock_referentiels_manager, temp_data_dir):
        """A fresh engine stores its collaborators and starts with empty caches."""
        engine = RAGEngine(mock_referentiels_manager, data_dir=temp_data_dir)

        assert engine.referentiels_manager == mock_referentiels_manager
        assert engine.data_dir == temp_data_dir
        # Every lazy cache starts empty; the embedding model loads on demand.
        for cache in (engine._bm25_indexes, engine._faiss_indexes, engine._chunks_cache):
            assert cache == {}
        assert engine._embeddings_model is None
|
||||
|
||||
|
||||
class TestBM25Search:
    """Tests for the BM25 keyword-search path."""

    @staticmethod
    def _dump_chunks(chunks, temp_data_dir):
        """Serialize chunks to the JSON file _bm25_search reads from disk.

        Extracted to remove the copy-pasted serialization block that both
        search tests previously duplicated.
        """
        chunks_path = temp_data_dir / "cim10_2026_chunks.json"
        with open(chunks_path, "w", encoding="utf-8") as f:
            chunks_data = [chunk.model_dump() for chunk in chunks]
            json.dump(chunks_data, f, ensure_ascii=False, default=str)

    def test_build_bm25_index(self, rag_engine, sample_chunks_cim10):
        """Building a BM25 index over chunks yields a non-empty index."""
        bm25_index = rag_engine._build_bm25_index(sample_chunks_cim10)

        assert bm25_index is not None
        assert len(bm25_index.doc_freqs) > 0

    def test_bm25_search_returns_results(
        self, rag_engine, sample_chunks_cim10, temp_data_dir
    ):
        """BM25 search returns (chunk_index, score) tuples with float scores."""
        self._dump_chunks(sample_chunks_cim10, temp_data_dir)

        results = rag_engine._bm25_search("choléra", "cim10", "2026", top_k=2)

        assert len(results) > 0
        assert all(isinstance(r, tuple) for r in results)
        assert all(len(r) == 2 for r in results)
        # Scores must be plain floats.
        assert all(isinstance(score, float) for _, score in results)

    def test_bm25_search_ranks_relevant_higher(
        self, rag_engine, sample_chunks_cim10, temp_data_dir
    ):
        """The chunk containing the query term is ranked first."""
        self._dump_chunks(sample_chunks_cim10, temp_data_dir)

        # Search for "gastrite" — chunk 2 is the only one mentioning it.
        results = rag_engine._bm25_search("gastrite", "cim10", "2026", top_k=3)

        # Guard against an empty result list before indexing (the original
        # test would have raised IndexError instead of failing cleanly).
        assert results, "expected at least one BM25 result"
        top_chunk_idx = results[0][0]
        # Idiomatic membership test (was `.find("gastrite") != -1`).
        assert "gastrite" in sample_chunks_cim10[top_chunk_idx].content.lower()
|
||||
|
||||
|
||||
class TestVectorSearch:
    """Tests for the dense (FAISS) retrieval path."""

    def test_vector_search_with_mock_index(
        self, rag_engine, sample_chunks_cim10, temp_data_dir
    ):
        """Vector search over a small random FAISS index returns scored tuples."""
        dim = 384  # typical MiniLM embedding size
        index = faiss.IndexFlatL2(dim)

        # Fill the index with L2-normalized random vectors (cosine-like scoring).
        doc_vectors = np.random.rand(len(sample_chunks_cim10), dim).astype(np.float32)
        faiss.normalize_L2(doc_vectors)
        index.add(doc_vectors)
        faiss.write_index(index, str(temp_data_dir / "cim10_2026_index.faiss"))

        # Persist the chunks next to the index, as the engine expects.
        payload = json.dumps(
            [chunk.model_dump() for chunk in sample_chunks_cim10],
            ensure_ascii=False,
            default=str,
        )
        (temp_data_dir / "cim10_2026_chunks.json").write_text(payload, encoding="utf-8")

        # Stub the embedding model with a normalized random query vector.
        query_vec = np.random.rand(dim).astype(np.float32)
        faiss.normalize_L2(query_vec.reshape(1, -1))  # in-place via the 2-D view
        embedder = Mock()
        embedder.encode.return_value = query_vec
        rag_engine._embeddings_model = embedder

        results = rag_engine._vector_search("test query", "cim10", "2026", top_k=2)

        assert len(results) > 0
        for item in results:
            assert isinstance(item, tuple)
            assert len(item) == 2
        # Similarity scores must stay within [0, 1].
        assert all(0.0 <= score <= 1.0 for _, score in results)
|
||||
|
||||
|
||||
class TestReciprocalRankFusion:
    """Tests for Reciprocal Rank Fusion of BM25 and vector rankings."""

    def test_rrf_fusion_combines_results(self, rag_engine):
        """Fusion keeps the union of chunk indices, sorted by descending score."""
        lexical = [(0, 10.5), (1, 8.2), (2, 5.1)]
        dense = [(1, 0.95), (0, 0.88), (3, 0.75)]

        fused = rag_engine._reciprocal_rank_fusion(lexical, dense, k=60)

        # Every chunk index from either ranking appears exactly once.
        assert {idx for idx, _ in fused} == {0, 1, 2, 3}
        # Output is ordered best-first.
        scores = [score for _, score in fused]
        assert scores == sorted(scores, reverse=True)

    def test_rrf_boosts_common_results(self, rag_engine):
        """A chunk leading both rankings ends up on top after fusion."""
        lexical = [(0, 10.0), (1, 5.0), (2, 3.0)]
        dense = [(0, 0.95), (3, 0.80), (1, 0.70)]

        fused = rag_engine._reciprocal_rank_fusion(lexical, dense, k=60)

        assert fused[0][0] == 0

    def test_rrf_with_empty_lists(self, rag_engine):
        """Fusing two empty rankings yields an empty list."""
        assert rag_engine._reciprocal_rank_fusion([], [], k=60) == []

    def test_rrf_with_one_empty_list(self, rag_engine):
        """Fusion degrades gracefully when one ranking is empty."""
        lexical = [(0, 10.0), (1, 5.0)]

        fused = rag_engine._reciprocal_rank_fusion(lexical, [], k=60)

        assert len(fused) == 2
        assert all(idx in [0, 1] for idx, _ in fused)
|
||||
|
||||
|
||||
class TestSearchICD10:
    """Tests for the hybrid CIM-10 search."""

    def test_search_icd10_returns_candidates(
        self, rag_engine, sample_chunks_cim10, temp_data_dir
    ):
        """search_icd10 yields CodeCandidate instances."""
        self._setup_test_data(rag_engine, sample_chunks_cim10, temp_data_dir, "cim10")

        # Stub the cross-encoder so the real model is never loaded.
        fake_reranker = Mock()
        fake_reranker.predict.return_value = np.array([0.9, 0.8, 0.7])
        rag_engine._reranker_model = fake_reranker

        candidates = rag_engine.search_icd10("gastrite", top_k=2, version="2026")

        assert len(candidates) > 0
        assert all(isinstance(c, CodeCandidate) for c in candidates)

    def test_search_icd10_candidate_structure(
        self, rag_engine, sample_chunks_cim10, temp_data_dir
    ):
        """Each returned candidate carries all the mandatory fields."""
        self._setup_test_data(rag_engine, sample_chunks_cim10, temp_data_dir, "cim10")

        fake_reranker = Mock()
        fake_reranker.predict.return_value = np.array([0.9])
        rag_engine._reranker_model = fake_reranker

        candidates = rag_engine.search_icd10("choléra", top_k=1, version="2026")

        assert len(candidates) > 0
        first = candidates[0]
        assert first.code is not None
        assert first.label is not None
        assert 0.0 <= first.similarity_score <= 1.0
        assert first.source == "reranked"
        assert first.chunk_id is not None
        assert first.chunk_text is not None

    def test_search_icd10_respects_top_k(
        self, rag_engine, sample_chunks_cim10, temp_data_dir
    ):
        """The number of candidates never exceeds the requested top_k."""
        self._setup_test_data(rag_engine, sample_chunks_cim10, temp_data_dir, "cim10")

        fake_reranker = Mock()
        fake_reranker.predict.return_value = np.array([0.9, 0.8])
        rag_engine._reranker_model = fake_reranker

        candidates = rag_engine.search_icd10("test", top_k=2, version="2026")

        assert len(candidates) <= 2

    def _setup_test_data(self, rag_engine, chunks, temp_data_dir, ref_type):
        """Write the chunk JSON plus a random FAISS index, and stub the embedder."""
        version = "2026" if ref_type == "cim10" else "2025"
        payload = json.dumps(
            [chunk.model_dump() for chunk in chunks], ensure_ascii=False, default=str
        )
        chunks_path = temp_data_dir / f"{ref_type}_{version}_chunks.json"
        chunks_path.write_text(payload, encoding="utf-8")

        # One normalized random vector per chunk — enough for retrieval to run.
        dim = 384
        index = faiss.IndexFlatL2(dim)
        doc_vectors = np.random.rand(len(chunks), dim).astype(np.float32)
        faiss.normalize_L2(doc_vectors)
        index.add(doc_vectors)
        faiss.write_index(index, str(temp_data_dir / f"{ref_type}_{version}_index.faiss"))

        # Stub the embedding model with a normalized random query vector.
        query_vec = np.random.rand(dim).astype(np.float32)
        faiss.normalize_L2(query_vec.reshape(1, -1))
        embedder = Mock()
        embedder.encode.return_value = query_vec
        rag_engine._embeddings_model = embedder
|
||||
|
||||
|
||||
class TestSearchCCAM:
    """Tests for the hybrid CCAM search."""

    def test_search_ccam_returns_candidates(
        self, rag_engine, sample_chunks_ccam, temp_data_dir
    ):
        """search_ccam yields CodeCandidate instances."""
        self._setup_test_data(rag_engine, sample_chunks_ccam, temp_data_dir, "ccam")

        fake_reranker = Mock()
        fake_reranker.predict.return_value = np.array([0.9, 0.8])
        rag_engine._reranker_model = fake_reranker

        candidates = rag_engine.search_ccam("appendicectomie", top_k=2, version="2025")

        assert len(candidates) > 0
        assert all(isinstance(c, CodeCandidate) for c in candidates)

    def test_search_ccam_extracts_code_with_extension(
        self, rag_engine, sample_chunks_ccam, temp_data_dir
    ):
        """Codes carrying an ATIH extension are handled without crashing."""
        self._setup_test_data(rag_engine, sample_chunks_ccam, temp_data_dir, "ccam")

        fake_reranker = Mock()
        fake_reranker.predict.return_value = np.array([0.9, 0.8])
        rag_engine._reranker_model = fake_reranker

        candidates = rag_engine.search_ccam("cœlioscopie", top_k=2, version="2025")

        # At least one candidate code must come back.
        codes = [c.code for c in candidates]
        assert len(codes) > 0

    def _setup_test_data(self, rag_engine, chunks, temp_data_dir, ref_type):
        """Write the chunk JSON plus a random FAISS index, and stub the embedder."""
        version = "2025"
        payload = json.dumps(
            [chunk.model_dump() for chunk in chunks], ensure_ascii=False, default=str
        )
        (temp_data_dir / f"{ref_type}_{version}_chunks.json").write_text(
            payload, encoding="utf-8"
        )

        dim = 384
        index = faiss.IndexFlatL2(dim)
        doc_vectors = np.random.rand(len(chunks), dim).astype(np.float32)
        faiss.normalize_L2(doc_vectors)
        index.add(doc_vectors)
        faiss.write_index(index, str(temp_data_dir / f"{ref_type}_{version}_index.faiss"))

        query_vec = np.random.rand(dim).astype(np.float32)
        faiss.normalize_L2(query_vec.reshape(1, -1))
        embedder = Mock()
        embedder.encode.return_value = query_vec
        rag_engine._embeddings_model = embedder
|
||||
|
||||
|
||||
class TestExtractCodeAndLabel:
    """Tests for extracting a code and its label from raw chunk text."""

    def test_extract_cim10_code_and_label(self, rag_engine):
        """A CIM-10 code at the start of the chunk is split from its label."""
        text = "A00.0 Choléra dû à Vibrio cholerae 01, biovar cholerae\nInclus: choléra classique"

        code, label = rag_engine._extract_code_and_label(text, "cim10")

        assert code == "A00.0"
        assert "Choléra" in label

    def test_extract_ccam_code_and_label(self, rag_engine):
        """A plain CCAM code is split from its label."""
        text = "YYYY001 Appendicectomie par laparotomie\nNote: Ablation de l'appendice"

        code, label = rag_engine._extract_code_and_label(text, "ccam")

        assert code == "YYYY001"
        assert "Appendicectomie" in label

    def test_extract_ccam_code_with_extension(self, rag_engine):
        """A CCAM code with an ATIH extension is kept intact (code+ext)."""
        text = "YYYY002+ABC Appendicectomie par cœlioscopie"

        code, label = rag_engine._extract_code_and_label(text, "ccam")

        assert code == "YYYY002+ABC"
        assert "Appendicectomie" in label

    def test_extract_returns_unknown_for_invalid_format(self, rag_engine):
        """Text without a recognizable code yields the UNKNOWN sentinel."""
        text = "Texte sans code valide"

        code, label = rag_engine._extract_code_and_label(text, "cim10")

        assert code == "UNKNOWN"
        assert len(label) > 0
|
||||
|
||||
|
||||
class TestRetrieveEligibilityCriteria:
    """Tests for retrieving DP eligibility criteria from the Guide MCO."""

    def test_retrieve_criteria_returns_eligibility(
        self, rag_engine, sample_chunks_guide, temp_data_dir
    ):
        """retrieve_eligibility_criteria produces an EligibilityCriteria object."""
        self._setup_guide_data(rag_engine, sample_chunks_guide, temp_data_dir)

        criteria = rag_engine.retrieve_eligibility_criteria("K29.7", "dp")

        assert criteria is not None
        assert isinstance(criteria, EligibilityCriteria)

    def test_retrieve_criteria_structure(
        self, rag_engine, sample_chunks_guide, temp_data_dir
    ):
        """The returned criteria object carries all the expected fields."""
        self._setup_guide_data(rag_engine, sample_chunks_guide, temp_data_dir)

        criteria = rag_engine.retrieve_eligibility_criteria("A00.0", "dp")

        assert criteria is not None
        assert criteria.code == "A00.0"
        assert criteria.code_type == "dp"
        assert criteria.criteria_text is not None
        assert isinstance(criteria.exclusion_rules, list)
        assert isinstance(criteria.hierarchization_rules, list)
        assert criteria.guide_section is not None

    def test_extract_exclusion_rules(self, rag_engine):
        """All three exclusion phrasings are recognized and extracted."""
        text = """Critères DP
- Exclut: les diagnostics niés
- À l'exclusion de: les antécédents
- Ne pas coder: les suspicions"""

        rules = rag_engine._extract_exclusion_rules(text)

        assert len(rules) == 3
        assert any("niés" in rule for rule in rules)
        assert any("antécédents" in rule for rule in rules)
        assert any("suspicions" in rule for rule in rules)

    def test_extract_hierarchization_rules(self, rag_engine):
        """Both hierarchization phrasings are recognized and extracted."""
        text = """Critères DP
- Hiérarchisation: privilégier le diagnostic le plus grave
- Priorité: diagnostic principal avant associé"""

        rules = rag_engine._extract_hierarchization_rules(text)

        assert len(rules) == 2
        assert any("grave" in rule for rule in rules)
        assert any("principal" in rule for rule in rules)

    def _setup_guide_data(self, rag_engine, chunks, temp_data_dir):
        """Write guide chunks + a random FAISS index, and stub the embedder."""
        version = "2026"
        payload = json.dumps(
            [chunk.model_dump() for chunk in chunks], ensure_ascii=False, default=str
        )
        (temp_data_dir / f"guide_mco_{version}_chunks.json").write_text(
            payload, encoding="utf-8"
        )

        dim = 384
        index = faiss.IndexFlatL2(dim)
        doc_vectors = np.random.rand(len(chunks), dim).astype(np.float32)
        faiss.normalize_L2(doc_vectors)
        index.add(doc_vectors)
        faiss.write_index(index, str(temp_data_dir / f"guide_mco_{version}_index.faiss"))

        query_vec = np.random.rand(dim).astype(np.float32)
        faiss.normalize_L2(query_vec.reshape(1, -1))
        embedder = Mock()
        embedder.encode.return_value = query_vec
        rag_engine._embeddings_model = embedder
|
||||
|
||||
|
||||
class TestCaching:
    """Tests for in-memory caching of chunks and FAISS indexes."""

    def test_chunks_are_cached(self, rag_engine, sample_chunks_cim10, temp_data_dir):
        """A second _load_chunks call hands back the exact cached object."""
        payload = json.dumps(
            [chunk.model_dump() for chunk in sample_chunks_cim10],
            ensure_ascii=False,
            default=str,
        )
        (temp_data_dir / "cim10_2026_chunks.json").write_text(payload, encoding="utf-8")

        first = rag_engine._load_chunks("cim10", "2026")
        second = rag_engine._load_chunks("cim10", "2026")

        # Identity, not equality: the cache must return the same object.
        assert first is second

    def test_faiss_index_is_cached(self, rag_engine, temp_data_dir):
        """A second _load_faiss_index call hands back the exact cached object."""
        dim = 384
        index = faiss.IndexFlatL2(dim)
        index.add(np.random.rand(3, dim).astype(np.float32))
        faiss.write_index(index, str(temp_data_dir / "cim10_2026_index.faiss"))

        first = rag_engine._load_faiss_index("cim10", "2026")
        second = rag_engine._load_faiss_index("cim10", "2026")

        assert first is second
|
||||
|
||||
|
||||
class TestErrorHandling:
    """Tests for error handling on missing files and inconsistent indexes."""

    def test_load_chunks_raises_on_missing_file(self, rag_engine):
        """_load_chunks raises FileNotFoundError for an unknown version."""
        with pytest.raises(FileNotFoundError):
            rag_engine._load_chunks("cim10", "9999")

    def test_load_faiss_index_raises_on_missing_file(self, rag_engine):
        """_load_faiss_index raises FileNotFoundError for an unknown version."""
        with pytest.raises(FileNotFoundError):
            rag_engine._load_faiss_index("cim10", "9999")

    def test_search_icd10_handles_invalid_chunk_index(
        self, rag_engine, sample_chunks_cim10, temp_data_dir
    ):
        """search_icd10 survives an index holding more vectors than chunks."""
        # Persist the chunks.
        payload = json.dumps(
            [chunk.model_dump() for chunk in sample_chunks_cim10],
            ensure_ascii=False,
            default=str,
        )
        (temp_data_dir / "cim10_2026_chunks.json").write_text(payload, encoding="utf-8")

        # Deliberately store MORE vectors than chunks so FAISS can return
        # chunk indices that do not exist in the chunk list.
        dim = 384
        index = faiss.IndexFlatL2(dim)
        doc_vectors = np.random.rand(10, dim).astype(np.float32)
        faiss.normalize_L2(doc_vectors)
        index.add(doc_vectors)
        faiss.write_index(index, str(temp_data_dir / "cim10_2026_index.faiss"))

        # Stub the embedder.
        query_vec = np.random.rand(dim).astype(np.float32)
        faiss.normalize_L2(query_vec.reshape(1, -1))
        embedder = Mock()
        embedder.encode.return_value = query_vec
        rag_engine._embeddings_model = embedder

        # Stub the reranker: scores only for the valid chunks.
        fake_reranker = Mock()
        fake_reranker.predict.return_value = np.array([0.9, 0.8, 0.7])
        rag_engine._reranker_model = fake_reranker

        # The search must not crash even with out-of-range chunk indices.
        candidates = rag_engine.search_icd10("test", top_k=5, version="2026")

        assert len(candidates) > 0
        # Every surviving candidate is usable: a real code or at least text.
        assert all(c.code != "UNKNOWN" or len(c.chunk_text) > 0 for c in candidates)
|
||||
|
||||
|
||||
class TestReranking:
    """Tests for cross-encoder reranking."""

    def test_rerank_results_returns_reranked_list(
        self, rag_engine, sample_chunks_cim10
    ):
        """_rerank_results returns (index, score) tuples and calls the model."""
        fused_candidates = [(0, 0.5), (1, 0.4), (2, 0.3)]

        fake_reranker = Mock()
        fake_reranker.predict.return_value = np.array([0.8, 0.6, 0.9])
        rag_engine._reranker_model = fake_reranker

        reranked = rag_engine._rerank_results(
            "test query", fused_candidates, sample_chunks_cim10, top_k=3
        )

        assert len(reranked) > 0
        for item in reranked:
            assert isinstance(item, tuple)
            assert len(item) == 2
        # The cross-encoder must have been invoked exactly once.
        fake_reranker.predict.assert_called_once()

    def test_rerank_results_sorts_by_score(self, rag_engine, sample_chunks_cim10):
        """Reranked output is ordered by descending cross-encoder score."""
        fused_candidates = [(0, 0.5), (1, 0.4), (2, 0.3)]

        # Chunk 2 gets the best score, then 0, then 1.
        fake_reranker = Mock()
        fake_reranker.predict.return_value = np.array([0.5, 0.3, 0.9])
        rag_engine._reranker_model = fake_reranker

        reranked = rag_engine._rerank_results(
            "test query", fused_candidates, sample_chunks_cim10, top_k=3
        )

        scores = [score for _, score in reranked]
        assert scores == sorted(scores, reverse=True)
        # Chunk 2 (score 0.9) must come first.
        assert reranked[0][0] == 2

    def test_rerank_results_boosts_alphabetical_index(self, rag_engine):
        """Chunks from the alphabetical index get a score bonus."""
        # One regular code block and one alphabetical-index chunk.
        chunks = [
            Chunk(
                chunk_id="test_0",
                referentiel_type="cim10",
                referentiel_version="2026",
                content="A00.0 Test code",
                metadata={"chunk_type": "code_block"},
                chunk_index=0,
            ),
            Chunk(
                chunk_id="test_1",
                referentiel_type="cim10",
                referentiel_version="2026",
                content="Gastrite -> K29.7",
                metadata={"chunk_type": "alphabetical_index"},
                chunk_index=1,
            ),
        ]
        fused_candidates = [(0, 0.5), (1, 0.4)]

        # Identical raw scores so only the bonus can break the tie.
        fake_reranker = Mock()
        fake_reranker.predict.return_value = np.array([0.7, 0.7])
        rag_engine._reranker_model = fake_reranker

        reranked = rag_engine._rerank_results("test query", fused_candidates, chunks, top_k=2)

        # The alphabetical-index chunk wins thanks to its 0.1 bonus.
        score_by_idx = dict(reranked)
        assert score_by_idx[1] > score_by_idx[0]
        assert reranked[0][0] == 1

    def test_rerank_results_respects_top_k(self, rag_engine, sample_chunks_cim10):
        """No more than top_k results come back."""
        fused_candidates = [(0, 0.5), (1, 0.4), (2, 0.3)]

        fake_reranker = Mock()
        fake_reranker.predict.return_value = np.array([0.8, 0.6, 0.9])
        rag_engine._reranker_model = fake_reranker

        reranked = rag_engine._rerank_results(
            "test query", fused_candidates, sample_chunks_cim10, top_k=2
        )

        assert len(reranked) == 2

    def test_rerank_results_handles_empty_candidates(
        self, rag_engine, sample_chunks_cim10
    ):
        """An empty candidate list reranks to an empty list."""
        reranked = rag_engine._rerank_results(
            "test query", [], sample_chunks_cim10, top_k=10
        )

        assert reranked == []

    def test_rerank_results_handles_invalid_chunk_index(
        self, rag_engine, sample_chunks_cim10
    ):
        """Out-of-range chunk indices are dropped before reranking."""
        fused_candidates = [(0, 0.5), (999, 0.4)]

        # Only the valid chunk reaches the cross-encoder.
        fake_reranker = Mock()
        fake_reranker.predict.return_value = np.array([0.8])
        rag_engine._reranker_model = fake_reranker

        reranked = rag_engine._rerank_results(
            "test query", fused_candidates, sample_chunks_cim10, top_k=10
        )

        assert len(reranked) == 1
        assert reranked[0][0] == 0

    def test_rerank_results_handles_reranker_error(
        self, rag_engine, sample_chunks_cim10
    ):
        """A failing cross-encoder falls back to the original candidate order."""
        fused_candidates = [(0, 0.5), (1, 0.4), (2, 0.3)]

        fake_reranker = Mock()
        fake_reranker.predict.side_effect = Exception("Reranker error")
        rag_engine._reranker_model = fake_reranker

        reranked = rag_engine._rerank_results(
            "test query", fused_candidates, sample_chunks_cim10, top_k=3
        )

        # Fallback: the first top_k original candidates, unchanged.
        assert len(reranked) == 3
        assert reranked == fused_candidates[:3]

    def test_get_reranker_model_loads_model(self, rag_engine):
        """_get_reranker_model instantiates a CrossEncoder on first use."""
        with patch("sentence_transformers.CrossEncoder") as mock_ce:
            fake_model = Mock()
            mock_ce.return_value = fake_model

            loaded = rag_engine._get_reranker_model()

            assert loaded == fake_model
            mock_ce.assert_called_once()

    def test_get_reranker_model_caches_model(self, rag_engine):
        """Subsequent calls reuse the cached CrossEncoder instance."""
        with patch("sentence_transformers.CrossEncoder") as mock_ce:
            fake_model = Mock()
            mock_ce.return_value = fake_model

            first = rag_engine._get_reranker_model()
            second = rag_engine._get_reranker_model()

            assert first is second
            # The constructor must only ever run once.
            mock_ce.assert_called_once()
|
||||
|
||||
|
||||
class TestSearchWithReranking:
    """Exercises the public search entry points with the reranking stage wired in."""

    def test_search_icd10_uses_reranking(
        self, rag_engine, sample_chunks_cim10, temp_data_dir
    ):
        """``search_icd10`` must pass its candidates through the cross-encoder."""
        self._setup_test_data(rag_engine, sample_chunks_cim10, temp_data_dir, "cim10")

        # Deterministic cross-encoder stub: three descending scores.
        fake_reranker = Mock()
        fake_reranker.predict.return_value = np.array([0.9, 0.8, 0.7])
        rag_engine._reranker_model = fake_reranker

        results = rag_engine.search_icd10("gastrite", top_k=3, version="2026")

        # The reranker ran exactly once over the candidate set.
        fake_reranker.predict.assert_called_once()

        # Candidates came back, and each one is flagged as reranked.
        assert len(results) > 0
        assert all(candidate.source == "reranked" for candidate in results)

    def test_search_ccam_uses_reranking(
        self, rag_engine, sample_chunks_ccam, temp_data_dir
    ):
        """``search_ccam`` must pass its candidates through the cross-encoder."""
        self._setup_test_data(rag_engine, sample_chunks_ccam, temp_data_dir, "ccam")

        # Deterministic cross-encoder stub: two descending scores.
        fake_reranker = Mock()
        fake_reranker.predict.return_value = np.array([0.9, 0.8])
        rag_engine._reranker_model = fake_reranker

        results = rag_engine.search_ccam("appendicectomie", top_k=2, version="2025")

        # The reranker ran exactly once over the candidate set.
        fake_reranker.predict.assert_called_once()

        # Candidates came back, and each one is flagged as reranked.
        assert len(results) > 0
        assert all(candidate.source == "reranked" for candidate in results)

    def test_search_icd10_alphabetical_index_prioritized(self, rag_engine, temp_data_dir):
        """Alphabetical-index chunks win ties thanks to their score bonus."""
        # One regular code block and one alphabetical-index entry for the
        # same code, so only the chunk_type bonus can break the tie.
        tied_chunks = [
            Chunk(
                chunk_id="cim10_2026_0",
                referentiel_type="cim10",
                referentiel_version="2026",
                content="K29.7 Gastrite, sans précision",
                metadata={"chunk_type": "code_block"},
                chunk_index=0,
            ),
            Chunk(
                chunk_id="cim10_2026_1",
                referentiel_type="cim10",
                referentiel_version="2026",
                content="Gastrite -> K29.7",
                metadata={"chunk_type": "alphabetical_index"},
                chunk_index=1,
            ),
        ]

        self._setup_test_data(rag_engine, tied_chunks, temp_data_dir, "cim10")

        # Identical raw scores: without the bonus the ordering would be arbitrary.
        fake_reranker = Mock()
        fake_reranker.predict.return_value = np.array([0.7, 0.7])
        rag_engine._reranker_model = fake_reranker

        results = rag_engine.search_icd10("gastrite", top_k=2, version="2026")

        # The alphabetical-index chunk must come out on top
        # (it receives a 0.1 bonus over the plain code block).
        assert len(results) >= 1
        assert results[0].chunk_id == "cim10_2026_1"

    def _setup_test_data(self, rag_engine, chunks, temp_data_dir, ref_type):
        """Write chunks and a random FAISS index to disk, and stub the embedder.

        Mirrors the on-disk layout the engine expects:
        ``{ref_type}_{version}_chunks.json`` plus ``{ref_type}_{version}_index.faiss``.
        """
        version = "2026" if ref_type == "cim10" else "2025"

        # Persist the chunks as JSON, the way the engine reads them back.
        serialized = [chunk.model_dump() for chunk in chunks]
        chunks_path = temp_data_dir / f"{ref_type}_{version}_chunks.json"
        with open(chunks_path, "w", encoding="utf-8") as f:
            json.dump(serialized, f, ensure_ascii=False, default=str)

        # Build a small L2 index over random (normalized) embeddings.
        dimension = 384
        embeddings = np.random.rand(len(chunks), dimension).astype(np.float32)
        faiss.normalize_L2(embeddings)
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings)

        index_path = temp_data_dir / f"{ref_type}_{version}_index.faiss"
        faiss.write_index(index, str(index_path))

        # Stub the embeddings model with a fixed, normalized query vector.
        query_vector = np.random.rand(dimension).astype(np.float32)
        faiss.normalize_L2(query_vector.reshape(1, -1))
        embedder = Mock()
        embedder.encode.return_value = query_vector
        rag_engine._embeddings_model = embedder
Reference in New Issue
Block a user