"""
Tests unitaires pour le RAGEngine.

Ce module teste:
- La recherche BM25
- La recherche vectorielle
- La fusion RRF (Reciprocal Rank Fusion)
- La recherche hybride CIM-10 et CCAM
- L'extraction des critères d'éligibilité
"""
|
|
|
|
import json
|
|
import tempfile
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock, Mock, patch
|
|
|
|
import faiss
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from pipeline_mco_pmsi.rag.rag_engine import (
|
|
CodeCandidate,
|
|
EligibilityCriteria,
|
|
RAGEngine,
|
|
)
|
|
from pipeline_mco_pmsi.rag.referentiels_manager import Chunk, ReferentielsManager
|
|
|
|
|
|
@pytest.fixture
def temp_data_dir():
    """Yield a throwaway directory (as a Path) that is deleted after the test."""
    with tempfile.TemporaryDirectory() as workdir:
        yield Path(workdir)
|
|
|
|
|
|
@pytest.fixture
def mock_referentiels_manager():
    """Build a ReferentielsManager mock exposing only what the engine reads."""
    fake_manager = Mock(spec=ReferentielsManager)
    fake_manager.data_dir = Path("data/referentiels")
    return fake_manager
|
|
|
|
|
|
@pytest.fixture
def sample_chunks_cim10():
    """Return three CIM-10 test chunks: two cholera codes and one gastritis code."""
    entries = [
        (
            "A00.0 Choléra dû à Vibrio cholerae 01, biovar cholerae\nInclus: choléra classique",
            "Chapitre I",
        ),
        (
            "A00.1 Choléra dû à Vibrio cholerae 01, biovar El Tor\nInclus: choléra El Tor",
            "Chapitre I",
        ),
        (
            "K29.7 Gastrite, sans précision\nInclus: gastrite SAI",
            "Chapitre XI",
        ),
    ]
    return [
        Chunk(
            chunk_id=f"cim10_2026_{position}",
            referentiel_type="cim10",
            referentiel_version="2026",
            content=text,
            metadata={"chunk_type": "code_block", "chapter": chapter},
            chunk_index=position,
        )
        for position, (text, chapter) in enumerate(entries)
    ]
|
|
|
|
|
|
@pytest.fixture
def sample_chunks_ccam():
    """Return two CCAM test chunks describing appendectomy procedures."""
    contents = [
        "YYYY001 Appendicectomie par laparotomie\nNote: Ablation de l'appendice",
        "YYYY002+ABC Appendicectomie par cœlioscopie\nNote: Ablation de l'appendice par voie cœlioscopique",
    ]
    return [
        Chunk(
            chunk_id=f"ccam_2025_{position}",
            referentiel_type="ccam",
            referentiel_version="2025",
            content=text,
            metadata={"chunk_type": "acte", "section": "Section 1"},
            chunk_index=position,
        )
        for position, text in enumerate(contents)
    ]
|
|
|
|
|
|
@pytest.fixture
def sample_chunks_guide():
    """Return a single MCO guide chunk listing DP eligibility criteria."""
    guide_text = """Critères d'éligibilité DP
- Le DP doit être le diagnostic ayant mobilisé l'essentiel des ressources
- Exclut: les diagnostics niés
- Exclut: les antécédents
- Hiérarchisation: privilégier le diagnostic le plus grave"""
    return [
        Chunk(
            chunk_id="guide_mco_2026_0",
            referentiel_type="guide_mco",
            referentiel_version="2026",
            content=guide_text,
            metadata={"chunk_type": "section", "section": "Chapitre 1"},
            chunk_index=0,
        )
    ]
|
|
|
|
|
|
@pytest.fixture
def rag_engine(mock_referentiels_manager, temp_data_dir):
    """Provide a RAGEngine wired to the mocked manager and a temporary data dir."""
    return RAGEngine(mock_referentiels_manager, data_dir=temp_data_dir)
|
|
|
|
|
|
class TestRAGEngineInit:
    """Initialisation tests for RAGEngine."""

    def test_init_creates_engine(self, mock_referentiels_manager, temp_data_dir):
        """A fresh engine keeps its constructor args and starts with empty caches."""
        engine = RAGEngine(mock_referentiels_manager, data_dir=temp_data_dir)

        assert engine.referentiels_manager == mock_referentiels_manager
        assert engine.data_dir == temp_data_dir
        # Every lazy cache must start empty, and no embeddings model is loaded yet.
        for cache_attr in ("_bm25_indexes", "_faiss_indexes", "_chunks_cache"):
            assert getattr(engine, cache_attr) == {}
        assert engine._embeddings_model is None
|
|
|
|
|
|
class TestBM25Search:
    """BM25 search tests."""

    @staticmethod
    def _persist_chunks(chunks, target_dir):
        """Serialize chunks to the JSON file the engine loads for cim10/2026."""
        payload = json.dumps(
            [chunk.model_dump() for chunk in chunks], ensure_ascii=False, default=str
        )
        (target_dir / "cim10_2026_chunks.json").write_text(payload, encoding="utf-8")

    def test_build_bm25_index(self, rag_engine, sample_chunks_cim10):
        """Building a BM25 index from chunks yields a non-empty index."""
        bm25_index = rag_engine._build_bm25_index(sample_chunks_cim10)

        assert bm25_index is not None
        assert len(bm25_index.doc_freqs) > 0

    def test_bm25_search_returns_results(
        self, rag_engine, sample_chunks_cim10, temp_data_dir
    ):
        """A BM25 query returns (chunk_index, score) tuples with float scores."""
        self._persist_chunks(sample_chunks_cim10, temp_data_dir)

        hits = rag_engine._bm25_search("choléra", "cim10", "2026", top_k=2)

        assert hits
        for hit in hits:
            assert isinstance(hit, tuple)
            assert len(hit) == 2
            # Scores must be plain floats.
            assert isinstance(hit[1], float)

    def test_bm25_search_ranks_relevant_higher(
        self, rag_engine, sample_chunks_cim10, temp_data_dir
    ):
        """The chunk mentioning the query term must be ranked first."""
        self._persist_chunks(sample_chunks_cim10, temp_data_dir)

        # Searching "gastrite" should surface chunk 2 first.
        hits = rag_engine._bm25_search("gastrite", "cim10", "2026", top_k=3)

        best_idx = hits[0][0]
        assert "gastrite" in sample_chunks_cim10[best_idx].content.lower()
|
|
|
|
|
|
class TestVectorSearch:
    """Vector search tests."""

    def test_vector_search_with_mock_index(
        self, rag_engine, sample_chunks_cim10, temp_data_dir
    ):
        """Vector search against a small FAISS index returns scores in [0, 1]."""
        dimension = 384  # typical MiniLM embedding size

        # Build a FAISS index over random, L2-normalized vectors (cosine-style).
        index = faiss.IndexFlatL2(dimension)
        vectors = np.random.rand(len(sample_chunks_cim10), dimension).astype(np.float32)
        faiss.normalize_L2(vectors)
        index.add(vectors)
        faiss.write_index(index, str(temp_data_dir / "cim10_2026_index.faiss"))

        # Persist the chunks next to the index.
        serialized = json.dumps(
            [chunk.model_dump() for chunk in sample_chunks_cim10],
            ensure_ascii=False,
            default=str,
        )
        (temp_data_dir / "cim10_2026_chunks.json").write_text(
            serialized, encoding="utf-8"
        )

        # Stub the embeddings model with a normalized random query vector.
        query_vector = np.random.rand(dimension).astype(np.float32)
        faiss.normalize_L2(query_vector.reshape(1, -1))
        embeddings_stub = Mock()
        embeddings_stub.encode.return_value = query_vector
        rag_engine._embeddings_model = embeddings_stub

        hits = rag_engine._vector_search("test query", "cim10", "2026", top_k=2)

        assert hits
        for hit in hits:
            assert isinstance(hit, tuple) and len(hit) == 2
            # Similarity scores must be bounded in [0, 1].
            assert 0.0 <= hit[1] <= 1.0
|
|
|
|
|
|
class TestReciprocalRankFusion:
    """RRF fusion tests."""

    def test_rrf_fusion_combines_results(self, rag_engine):
        """Fusion keeps every distinct chunk and sorts by descending score."""
        bm25_hits = [(0, 10.5), (1, 8.2), (2, 5.1)]
        vector_hits = [(1, 0.95), (0, 0.88), (3, 0.75)]

        fused = rag_engine._reciprocal_rank_fusion(bm25_hits, vector_hits, k=60)

        # The union of both rankings must survive fusion.
        assert {idx for idx, _ in fused} == {0, 1, 2, 3}
        # And scores must come out sorted, best first.
        fused_scores = [score for _, score in fused]
        assert fused_scores == sorted(fused_scores, reverse=True)

    def test_rrf_boosts_common_results(self, rag_engine):
        """A chunk leading both rankings must come out on top."""
        bm25_hits = [(0, 10.0), (1, 5.0), (2, 3.0)]
        vector_hits = [(0, 0.95), (3, 0.80), (1, 0.70)]

        fused = rag_engine._reciprocal_rank_fusion(bm25_hits, vector_hits, k=60)

        assert fused[0][0] == 0

    def test_rrf_with_empty_lists(self, rag_engine):
        """Fusing two empty rankings yields an empty list."""
        assert rag_engine._reciprocal_rank_fusion([], [], k=60) == []

    def test_rrf_with_one_empty_list(self, rag_engine):
        """Fusing with one empty ranking keeps the other ranking's chunks."""
        fused = rag_engine._reciprocal_rank_fusion([(0, 10.0), (1, 5.0)], [], k=60)

        assert len(fused) == 2
        assert all(idx in [0, 1] for idx, _ in fused)
|
|
|
|
|
|
class TestSearchICD10:
    """CIM-10 search tests."""

    def _setup_test_data(self, rag_engine, chunks, temp_data_dir, ref_type):
        """Write chunks plus a random FAISS index and stub the embeddings model."""
        version = "2026" if ref_type == "cim10" else "2025"
        dimension = 384

        # Persist the chunks as JSON.
        serialized = json.dumps(
            [chunk.model_dump() for chunk in chunks], ensure_ascii=False, default=str
        )
        (temp_data_dir / f"{ref_type}_{version}_chunks.json").write_text(
            serialized, encoding="utf-8"
        )

        # Build and persist a small index over normalized random vectors.
        index = faiss.IndexFlatL2(dimension)
        vectors = np.random.rand(len(chunks), dimension).astype(np.float32)
        faiss.normalize_L2(vectors)
        index.add(vectors)
        faiss.write_index(
            index, str(temp_data_dir / f"{ref_type}_{version}_index.faiss")
        )

        # Stub the embeddings model with a normalized random query vector.
        query_vector = np.random.rand(dimension).astype(np.float32)
        faiss.normalize_L2(query_vector.reshape(1, -1))
        embeddings_stub = Mock()
        embeddings_stub.encode.return_value = query_vector
        rag_engine._embeddings_model = embeddings_stub

    def test_search_icd10_returns_candidates(
        self, rag_engine, sample_chunks_cim10, temp_data_dir
    ):
        """search_icd10 yields CodeCandidate objects."""
        self._setup_test_data(rag_engine, sample_chunks_cim10, temp_data_dir, "cim10")

        # Stub the reranker so the real cross-encoder is never loaded.
        reranker_stub = Mock()
        reranker_stub.predict.return_value = np.array([0.9, 0.8, 0.7])
        rag_engine._reranker_model = reranker_stub

        candidates = rag_engine.search_icd10("gastrite", top_k=2, version="2026")

        assert candidates
        assert all(isinstance(c, CodeCandidate) for c in candidates)

    def test_search_icd10_candidate_structure(
        self, rag_engine, sample_chunks_cim10, temp_data_dir
    ):
        """Each candidate carries code, label, bounded score and chunk provenance."""
        self._setup_test_data(rag_engine, sample_chunks_cim10, temp_data_dir, "cim10")

        reranker_stub = Mock()
        reranker_stub.predict.return_value = np.array([0.9])
        rag_engine._reranker_model = reranker_stub

        candidates = rag_engine.search_icd10("choléra", top_k=1, version="2026")

        assert candidates
        first = candidates[0]

        # Mandatory candidate fields.
        assert first.code is not None
        assert first.label is not None
        assert 0.0 <= first.similarity_score <= 1.0
        assert first.source == "reranked"
        assert first.chunk_id is not None
        assert first.chunk_text is not None

    def test_search_icd10_respects_top_k(
        self, rag_engine, sample_chunks_cim10, temp_data_dir
    ):
        """search_icd10 never returns more than top_k candidates."""
        self._setup_test_data(rag_engine, sample_chunks_cim10, temp_data_dir, "cim10")

        reranker_stub = Mock()
        reranker_stub.predict.return_value = np.array([0.9, 0.8])
        rag_engine._reranker_model = reranker_stub

        candidates = rag_engine.search_icd10("test", top_k=2, version="2026")

        assert len(candidates) <= 2
|
|
|
|
|
|
class TestSearchCCAM:
    """CCAM search tests."""

    def _setup_test_data(self, rag_engine, chunks, temp_data_dir, ref_type):
        """Write chunks plus a random FAISS index and stub the embeddings model."""
        version = "2025"
        dimension = 384

        serialized = json.dumps(
            [chunk.model_dump() for chunk in chunks], ensure_ascii=False, default=str
        )
        (temp_data_dir / f"{ref_type}_{version}_chunks.json").write_text(
            serialized, encoding="utf-8"
        )

        index = faiss.IndexFlatL2(dimension)
        vectors = np.random.rand(len(chunks), dimension).astype(np.float32)
        faiss.normalize_L2(vectors)
        index.add(vectors)
        faiss.write_index(
            index, str(temp_data_dir / f"{ref_type}_{version}_index.faiss")
        )

        query_vector = np.random.rand(dimension).astype(np.float32)
        faiss.normalize_L2(query_vector.reshape(1, -1))
        embeddings_stub = Mock()
        embeddings_stub.encode.return_value = query_vector
        rag_engine._embeddings_model = embeddings_stub

    def test_search_ccam_returns_candidates(
        self, rag_engine, sample_chunks_ccam, temp_data_dir
    ):
        """search_ccam yields CodeCandidate objects."""
        self._setup_test_data(rag_engine, sample_chunks_ccam, temp_data_dir, "ccam")

        reranker_stub = Mock()
        reranker_stub.predict.return_value = np.array([0.9, 0.8])
        rag_engine._reranker_model = reranker_stub

        candidates = rag_engine.search_ccam("appendicectomie", top_k=2, version="2025")

        assert candidates
        assert all(isinstance(c, CodeCandidate) for c in candidates)

    def test_search_ccam_extracts_code_with_extension(
        self, rag_engine, sample_chunks_ccam, temp_data_dir
    ):
        """Codes carrying an ATIH extension are still extracted."""
        self._setup_test_data(rag_engine, sample_chunks_ccam, temp_data_dir, "ccam")

        reranker_stub = Mock()
        reranker_stub.predict.return_value = np.array([0.9, 0.8])
        rag_engine._reranker_model = reranker_stub

        candidates = rag_engine.search_ccam("cœlioscopie", top_k=2, version="2025")

        # At least one candidate code must come back.
        codes = [c.code for c in candidates]
        assert len(codes) > 0
|
|
|
|
|
|
class TestExtractCodeAndLabel:
    """Code/label extraction tests."""

    def test_extract_cim10_code_and_label(self, rag_engine):
        """A CIM-10 block yields its leading code and its description."""
        text = "A00.0 Choléra dû à Vibrio cholerae 01, biovar cholerae\nInclus: choléra classique"

        code, label = rag_engine._extract_code_and_label(text, "cim10")

        assert code == "A00.0"
        assert "Choléra" in label

    def test_extract_ccam_code_and_label(self, rag_engine):
        """A CCAM block yields its leading code and its description."""
        text = "YYYY001 Appendicectomie par laparotomie\nNote: Ablation de l'appendice"

        code, label = rag_engine._extract_code_and_label(text, "ccam")

        assert code == "YYYY001"
        assert "Appendicectomie" in label

    def test_extract_ccam_code_with_extension(self, rag_engine):
        """An ATIH extension stays attached to the extracted CCAM code."""
        code, label = rag_engine._extract_code_and_label(
            "YYYY002+ABC Appendicectomie par cœlioscopie", "ccam"
        )

        assert code == "YYYY002+ABC"
        assert "Appendicectomie" in label

    def test_extract_returns_unknown_for_invalid_format(self, rag_engine):
        """Text with no recognizable code falls back to UNKNOWN plus a label."""
        code, label = rag_engine._extract_code_and_label(
            "Texte sans code valide", "cim10"
        )

        assert code == "UNKNOWN"
        assert len(label) > 0
|
|
|
|
|
|
class TestRetrieveEligibilityCriteria:
    """Eligibility-criteria retrieval tests."""

    def _setup_guide_data(self, rag_engine, chunks, temp_data_dir):
        """Write guide chunks plus a random FAISS index and stub the embeddings model."""
        version = "2026"
        dimension = 384

        serialized = json.dumps(
            [chunk.model_dump() for chunk in chunks], ensure_ascii=False, default=str
        )
        (temp_data_dir / f"guide_mco_{version}_chunks.json").write_text(
            serialized, encoding="utf-8"
        )

        index = faiss.IndexFlatL2(dimension)
        vectors = np.random.rand(len(chunks), dimension).astype(np.float32)
        faiss.normalize_L2(vectors)
        index.add(vectors)
        faiss.write_index(
            index, str(temp_data_dir / f"guide_mco_{version}_index.faiss")
        )

        query_vector = np.random.rand(dimension).astype(np.float32)
        faiss.normalize_L2(query_vector.reshape(1, -1))
        embeddings_stub = Mock()
        embeddings_stub.encode.return_value = query_vector
        rag_engine._embeddings_model = embeddings_stub

    def test_retrieve_criteria_returns_eligibility(
        self, rag_engine, sample_chunks_guide, temp_data_dir
    ):
        """retrieve_eligibility_criteria returns an EligibilityCriteria object."""
        self._setup_guide_data(rag_engine, sample_chunks_guide, temp_data_dir)

        criteria = rag_engine.retrieve_eligibility_criteria("K29.7", "dp")

        assert criteria is not None
        assert isinstance(criteria, EligibilityCriteria)

    def test_retrieve_criteria_structure(
        self, rag_engine, sample_chunks_guide, temp_data_dir
    ):
        """The returned criteria echo the code/type and expose rule lists."""
        self._setup_guide_data(rag_engine, sample_chunks_guide, temp_data_dir)

        criteria = rag_engine.retrieve_eligibility_criteria("A00.0", "dp")

        assert criteria is not None
        assert criteria.code == "A00.0"
        assert criteria.code_type == "dp"
        assert criteria.criteria_text is not None
        assert isinstance(criteria.exclusion_rules, list)
        assert isinstance(criteria.hierarchization_rules, list)
        assert criteria.guide_section is not None

    def test_extract_exclusion_rules(self, rag_engine):
        """Every exclusion phrasing variant is picked up as a rule."""
        text = """Critères DP
- Exclut: les diagnostics niés
- À l'exclusion de: les antécédents
- Ne pas coder: les suspicions"""

        rules = rag_engine._extract_exclusion_rules(text)

        assert len(rules) == 3
        for expected in ("niés", "antécédents", "suspicions"):
            assert any(expected in rule for rule in rules)

    def test_extract_hierarchization_rules(self, rag_engine):
        """Hierarchization and priority statements are both extracted."""
        text = """Critères DP
- Hiérarchisation: privilégier le diagnostic le plus grave
- Priorité: diagnostic principal avant associé"""

        rules = rag_engine._extract_hierarchization_rules(text)

        assert len(rules) == 2
        for expected in ("grave", "principal"):
            assert any(expected in rule for rule in rules)
|
|
|
|
|
|
class TestCaching:
    """Cache behaviour tests."""

    def test_chunks_are_cached(self, rag_engine, sample_chunks_cim10, temp_data_dir):
        """A second _load_chunks call returns the cached object."""
        serialized = json.dumps(
            [chunk.model_dump() for chunk in sample_chunks_cim10],
            ensure_ascii=False,
            default=str,
        )
        (temp_data_dir / "cim10_2026_chunks.json").write_text(
            serialized, encoding="utf-8"
        )

        first_load = rag_engine._load_chunks("cim10", "2026")
        second_load = rag_engine._load_chunks("cim10", "2026")

        # Identity, not equality: the second call must hit the in-memory cache.
        assert first_load is second_load

    def test_faiss_index_is_cached(self, rag_engine, temp_data_dir):
        """A second _load_faiss_index call returns the cached index."""
        dimension = 384
        index = faiss.IndexFlatL2(dimension)
        index.add(np.random.rand(3, dimension).astype(np.float32))
        faiss.write_index(index, str(temp_data_dir / "cim10_2026_index.faiss"))

        first_load = rag_engine._load_faiss_index("cim10", "2026")
        second_load = rag_engine._load_faiss_index("cim10", "2026")

        # Same in-memory object on both loads.
        assert first_load is second_load
|
|
|
|
|
|
class TestErrorHandling:
    """Error-handling tests."""

    def test_load_chunks_raises_on_missing_file(self, rag_engine):
        """_load_chunks raises FileNotFoundError for an unknown version."""
        with pytest.raises(FileNotFoundError):
            rag_engine._load_chunks("cim10", "9999")

    def test_load_faiss_index_raises_on_missing_file(self, rag_engine):
        """_load_faiss_index raises FileNotFoundError for an unknown version."""
        with pytest.raises(FileNotFoundError):
            rag_engine._load_faiss_index("cim10", "9999")

    def test_search_icd10_handles_invalid_chunk_index(
        self, rag_engine, sample_chunks_cim10, temp_data_dir
    ):
        """An index holding more vectors than chunks must not crash the search."""
        serialized = json.dumps(
            [chunk.model_dump() for chunk in sample_chunks_cim10],
            ensure_ascii=False,
            default=str,
        )
        (temp_data_dir / "cim10_2026_chunks.json").write_text(
            serialized, encoding="utf-8"
        )

        dimension = 384
        index = faiss.IndexFlatL2(dimension)
        # Deliberately add more vectors than there are chunks so retrieval
        # produces out-of-range chunk indices.
        vectors = np.random.rand(10, dimension).astype(np.float32)
        faiss.normalize_L2(vectors)
        index.add(vectors)
        faiss.write_index(index, str(temp_data_dir / "cim10_2026_index.faiss"))

        query_vector = np.random.rand(dimension).astype(np.float32)
        faiss.normalize_L2(query_vector.reshape(1, -1))
        embeddings_stub = Mock()
        embeddings_stub.encode.return_value = query_vector
        rag_engine._embeddings_model = embeddings_stub

        # The reranker stub only scores the valid chunks.
        reranker_stub = Mock()
        reranker_stub.predict.return_value = np.array([0.9, 0.8, 0.7])
        rag_engine._reranker_model = reranker_stub

        # The search must not crash on the bogus indices.
        candidates = rag_engine.search_icd10("test", top_k=5, version="2026")

        # Candidates still come back, each with usable data.
        assert len(candidates) > 0
        assert all(c.code != "UNKNOWN" or len(c.chunk_text) > 0 for c in candidates)
|
|
|
|
|
|
class TestReranking:
    """Cross-encoder reranking tests."""

    @staticmethod
    def _install_reranker(rag_engine, scores):
        """Attach a cross-encoder stub returning the given scores; return the stub."""
        stub = Mock()
        stub.predict.return_value = np.array(scores)
        rag_engine._reranker_model = stub
        return stub

    def test_rerank_results_returns_reranked_list(
        self, rag_engine, sample_chunks_cim10
    ):
        """_rerank_results returns (index, score) tuples and calls the model once."""
        initial = [(0, 0.5), (1, 0.4), (2, 0.3)]
        stub = self._install_reranker(rag_engine, [0.8, 0.6, 0.9])

        reranked = rag_engine._rerank_results(
            "test query", initial, sample_chunks_cim10, top_k=3
        )

        assert reranked
        for item in reranked:
            assert isinstance(item, tuple)
            assert len(item) == 2

        # The cross-encoder must have been invoked exactly once.
        stub.predict.assert_called_once()

    def test_rerank_results_sorts_by_score(self, rag_engine, sample_chunks_cim10):
        """Results are ordered by descending cross-encoder score."""
        initial = [(0, 0.5), (1, 0.4), (2, 0.3)]
        # Chunk 2 scores best, then chunk 0, then chunk 1.
        self._install_reranker(rag_engine, [0.5, 0.3, 0.9])

        reranked = rag_engine._rerank_results(
            "test query", initial, sample_chunks_cim10, top_k=3
        )

        ordered_scores = [score for _, score in reranked]
        assert ordered_scores == sorted(ordered_scores, reverse=True)
        # Chunk 2 (score 0.9) must lead.
        assert reranked[0][0] == 2

    def test_rerank_results_boosts_alphabetical_index(self, rag_engine):
        """Alphabetical-index chunks get a +0.1 bonus over equal raw scores."""
        chunks = [
            Chunk(
                chunk_id="test_0",
                referentiel_type="cim10",
                referentiel_version="2026",
                content="A00.0 Test code",
                metadata={"chunk_type": "code_block"},
                chunk_index=0,
            ),
            Chunk(
                chunk_id="test_1",
                referentiel_type="cim10",
                referentiel_version="2026",
                content="Gastrite -> K29.7",
                metadata={"chunk_type": "alphabetical_index"},
                chunk_index=1,
            ),
        ]
        initial = [(0, 0.5), (1, 0.4)]
        # Identical raw scores: only the bonus can break the tie.
        self._install_reranker(rag_engine, [0.7, 0.7])

        reranked = rag_engine._rerank_results("test query", initial, chunks, top_k=2)

        scores_by_idx = dict(reranked)
        # Chunk 1 (alphabetical index) must beat chunk 0 and rank first.
        assert scores_by_idx[1] > scores_by_idx[0]
        assert reranked[0][0] == 1

    def test_rerank_results_respects_top_k(self, rag_engine, sample_chunks_cim10):
        """No more than top_k results are returned."""
        initial = [(0, 0.5), (1, 0.4), (2, 0.3)]
        self._install_reranker(rag_engine, [0.8, 0.6, 0.9])

        reranked = rag_engine._rerank_results(
            "test query", initial, sample_chunks_cim10, top_k=2
        )

        assert len(reranked) == 2

    def test_rerank_results_handles_empty_candidates(
        self, rag_engine, sample_chunks_cim10
    ):
        """Empty candidate input yields an empty output."""
        reranked = rag_engine._rerank_results(
            "test query", [], sample_chunks_cim10, top_k=10
        )

        assert reranked == []

    def test_rerank_results_handles_invalid_chunk_index(
        self, rag_engine, sample_chunks_cim10
    ):
        """Out-of-range chunk indices are dropped before reranking."""
        initial = [(0, 0.5), (999, 0.4)]
        # The cross-encoder only ever sees the single valid chunk.
        self._install_reranker(rag_engine, [0.8])

        reranked = rag_engine._rerank_results(
            "test query", initial, sample_chunks_cim10, top_k=10
        )

        # Only the valid chunk survives.
        assert len(reranked) == 1
        assert reranked[0][0] == 0

    def test_rerank_results_handles_reranker_error(
        self, rag_engine, sample_chunks_cim10
    ):
        """A failing cross-encoder falls back to the original candidate order."""
        initial = [(0, 0.5), (1, 0.4), (2, 0.3)]

        failing_stub = Mock()
        failing_stub.predict.side_effect = Exception("Reranker error")
        rag_engine._reranker_model = failing_stub

        reranked = rag_engine._rerank_results(
            "test query", initial, sample_chunks_cim10, top_k=3
        )

        # The original top_k candidates are returned unchanged.
        assert len(reranked) == 3
        assert reranked == initial[:3]

    def test_get_reranker_model_loads_model(self, rag_engine):
        """_get_reranker_model instantiates the CrossEncoder."""
        with patch("sentence_transformers.CrossEncoder") as cross_encoder_cls:
            loaded_stub = Mock()
            cross_encoder_cls.return_value = loaded_stub

            model = rag_engine._get_reranker_model()

            assert model == loaded_stub
            cross_encoder_cls.assert_called_once()

    def test_get_reranker_model_caches_model(self, rag_engine):
        """Repeated calls reuse one CrossEncoder instance."""
        with patch("sentence_transformers.CrossEncoder") as cross_encoder_cls:
            cross_encoder_cls.return_value = Mock()

            first = rag_engine._get_reranker_model()
            second = rag_engine._get_reranker_model()

            assert first is second
            # A single instantiation despite two lookups.
            cross_encoder_cls.assert_called_once()
|
|
|
|
|
|
class TestSearchWithReranking:
    """Integration tests: search with reranking."""

    def _setup_test_data(self, rag_engine, chunks, temp_data_dir, ref_type):
        """Write chunks plus a random FAISS index and stub the embeddings model."""
        version = "2026" if ref_type == "cim10" else "2025"
        dimension = 384

        serialized = json.dumps(
            [chunk.model_dump() for chunk in chunks], ensure_ascii=False, default=str
        )
        (temp_data_dir / f"{ref_type}_{version}_chunks.json").write_text(
            serialized, encoding="utf-8"
        )

        index = faiss.IndexFlatL2(dimension)
        vectors = np.random.rand(len(chunks), dimension).astype(np.float32)
        faiss.normalize_L2(vectors)
        index.add(vectors)
        faiss.write_index(
            index, str(temp_data_dir / f"{ref_type}_{version}_index.faiss")
        )

        query_vector = np.random.rand(dimension).astype(np.float32)
        faiss.normalize_L2(query_vector.reshape(1, -1))
        embeddings_stub = Mock()
        embeddings_stub.encode.return_value = query_vector
        rag_engine._embeddings_model = embeddings_stub

    def test_search_icd10_uses_reranking(
        self, rag_engine, sample_chunks_cim10, temp_data_dir
    ):
        """search_icd10 routes candidates through the cross-encoder."""
        self._setup_test_data(rag_engine, sample_chunks_cim10, temp_data_dir, "cim10")

        reranker_stub = Mock()
        reranker_stub.predict.return_value = np.array([0.9, 0.8, 0.7])
        rag_engine._reranker_model = reranker_stub

        candidates = rag_engine.search_icd10("gastrite", top_k=3, version="2026")

        # The cross-encoder was invoked and every candidate is tagged as reranked.
        reranker_stub.predict.assert_called_once()
        assert candidates
        assert all(c.source == "reranked" for c in candidates)

    def test_search_ccam_uses_reranking(
        self, rag_engine, sample_chunks_ccam, temp_data_dir
    ):
        """search_ccam routes candidates through the cross-encoder."""
        self._setup_test_data(rag_engine, sample_chunks_ccam, temp_data_dir, "ccam")

        reranker_stub = Mock()
        reranker_stub.predict.return_value = np.array([0.9, 0.8])
        rag_engine._reranker_model = reranker_stub

        candidates = rag_engine.search_ccam("appendicectomie", top_k=2, version="2025")

        reranker_stub.predict.assert_called_once()
        assert candidates
        assert all(c.source == "reranked" for c in candidates)

    def test_search_icd10_alphabetical_index_prioritized(
        self, rag_engine, temp_data_dir
    ):
        """With equal scores, the alphabetical-index chunk wins thanks to its bonus."""
        chunks = [
            Chunk(
                chunk_id="cim10_2026_0",
                referentiel_type="cim10",
                referentiel_version="2026",
                content="K29.7 Gastrite, sans précision",
                metadata={"chunk_type": "code_block"},
                chunk_index=0,
            ),
            Chunk(
                chunk_id="cim10_2026_1",
                referentiel_type="cim10",
                referentiel_version="2026",
                content="Gastrite -> K29.7",
                metadata={"chunk_type": "alphabetical_index"},
                chunk_index=1,
            ),
        ]
        self._setup_test_data(rag_engine, chunks, temp_data_dir, "cim10")

        # Equal raw scores: only the +0.1 alphabetical-index bonus discriminates.
        reranker_stub = Mock()
        reranker_stub.predict.return_value = np.array([0.7, 0.7])
        rag_engine._reranker_model = reranker_stub

        candidates = rag_engine.search_icd10("gastrite", top_k=2, version="2026")

        assert len(candidates) >= 1
        # The top candidate must come from the alphabetical-index chunk.
        assert candidates[0].chunk_id == "cim10_2026_1"
|