"""Tests pour le pass LLM d'extraction de DAS supplémentaires.""" from __future__ import annotations from pathlib import Path from unittest.mock import patch, MagicMock import pytest from src.config import Diagnostic, DossierMedical, Sejour from src.medical.ollama_cache import OllamaCache class TestExtractDasLlm: """Tests pour extract_das_llm() dans rag_search.py.""" def test_returns_das_from_llm(self): """Le pass LLM retourne des DAS supplémentaires.""" from src.medical.rag_search import extract_das_llm mock_result = { "diagnostics_supplementaires": [ { "texte": "Hypertension artérielle", "code_cim10": "I10", "justification": "HTA mentionnée dans le texte", }, ] } with patch("src.medical.rag_search.call_ollama", return_value=mock_result): result = extract_das_llm( text="Patient hypertendu sous traitement", contexte={"sexe": "M", "age": 65}, existing_das=["Diabète de type 2 (E11.9)"], dp_texte="Pancréatite aiguë biliaire", ) assert len(result) == 1 assert result[0]["code_cim10"] == "I10" assert result[0]["texte"] == "Hypertension artérielle" def test_returns_empty_when_ollama_unavailable(self): """Retourne une liste vide si Ollama est indisponible.""" from src.medical.rag_search import extract_das_llm with patch("src.medical.rag_search.call_ollama", return_value=None): result = extract_das_llm( text="Texte médical", contexte={}, existing_das=[], dp_texte="", ) assert result == [] def test_returns_empty_on_bad_format(self): """Retourne une liste vide si le format de réponse est inattendu.""" from src.medical.rag_search import extract_das_llm with patch("src.medical.rag_search.call_ollama", return_value={"other_key": "value"}): result = extract_das_llm( text="Texte médical", contexte={}, existing_das=[], dp_texte="", ) assert result == [] def test_cache_hit(self, tmp_path): """Le cache est utilisé quand disponible.""" from src.medical.rag_search import extract_das_llm cache = OllamaCache(tmp_path / "cache.json", "test-model") mock_result = { "diagnostics_supplementaires": [ {"texte": "Anémie", "code_cim10": "D64.9", "justification": "test"}, ] } # Premier appel : cache miss, appelle Ollama with patch("src.medical.rag_search.call_ollama", return_value=mock_result) as mock_call: result1 = extract_das_llm( text="Patient anémique Hb basse", contexte={}, existing_das=[], dp_texte="", cache=cache, ) assert mock_call.call_count == 1 assert len(result1) == 1 # Vérifier que le cache contient bien l'entrée assert len(cache) > 0 # Deuxième appel : cache hit, pas d'appel Ollama with patch("src.medical.rag_search.call_ollama") as mock_call: result2 = extract_das_llm( text="Patient anémique Hb basse", contexte={}, existing_das=[], dp_texte="", cache=cache, ) mock_call.assert_not_called() assert len(result2) == 1 assert result2[0]["code_cim10"] == "D64.9" def test_prompt_includes_context(self): """Le prompt contient le contexte patient et les DAS existants.""" from src.medical.rag_search import _build_prompt_das_extraction prompt = _build_prompt_das_extraction( text="Patient hypertendu diabétique", contexte={"sexe": "F", "age": 72, "duree_sejour": 5}, existing_das=["Diabète de type 2 (E11.9)", "Obésité (E66.0)"], dp_texte="Pancréatite aiguë biliaire", ) assert "Pancréatite aiguë biliaire" in prompt assert "Diabète de type 2 (E11.9)" in prompt assert "Obésité (E66.0)" in prompt assert "Patient hypertendu diabétique" in prompt class TestBioNormesInContext: """Tests pour l'inclusion des normes biologiques dans le contexte LLM.""" def test_format_contexte_includes_normes(self): """_format_contexte() affiche les normes [N: min-max] pour chaque résultat bio.""" from src.medical.rag_search import _format_contexte contexte = { "biologie_cle": [ ("Créatinine", "76", False), ("CRP", "250", True), ("Lipasémie", "1200", True), ], } result = _format_contexte(contexte) assert "[N: 50-120]" in result assert "[N: 0-5]" in result assert "[N: 0-60]" in result # Créatinine normale → pas de marqueur ↑ assert "Créatinine 76 [N: 50-120]" in result # CRP anormale → marqueur ↑ assert "CRP 250 [N: 0-5] (↑)" in result def test_format_contexte_no_norme_for_unknown_test(self): """Les tests sans norme connue n'affichent pas de [N: ...].""" from src.medical.rag_search import _format_contexte contexte = { "biologie_cle": [ ("Test inconnu", "42", None), ], } result = _format_contexte(contexte) assert "Test inconnu 42" in result assert "[N:" not in result def test_prompt_das_includes_bio_norms_rule(self): """Le prompt DAS contient la règle sur les normes biologiques.""" from src.medical.rag_search import _build_prompt_das_extraction prompt = _build_prompt_das_extraction( text="Patient avec créatinine normale", contexte={"biologie_cle": [("Créatinine", "76", False)]}, existing_das=[], dp_texte="Pancréatite aiguë", ) assert "ATTENTION aux valeurs biologiques" in prompt assert "[N: min-max]" in prompt def test_bio_normals_exported(self): """BIO_NORMALS est bien exporté depuis bio_normals.""" from src.medical.bio_normals import BIO_NORMALS assert "Créatinine" in BIO_NORMALS assert BIO_NORMALS["Créatinine"] == (50, 120) assert "CRP" in BIO_NORMALS assert BIO_NORMALS["CRP"] == (0, 5) class TestExtractDasLlmIntegration: """Tests d'intégration pour le pass LLM DAS dans cim10_extractor.py.""" def test_das_llm_called_when_use_rag_true(self): """Le pass LLM DAS est appelé quand use_rag=True.""" from src.medical.cim10_extractor import extract_medical_info parsed = { "type": "CRH", "patient": {"sexe": "M"}, "sejour": {}, "diagnostics": [ {"libelle": "Pancréatite aiguë biliaire", "code_cim10": "K85.1", "type": "principal"}, ], } with patch("src.medical.cim10_extractor._extract_das_llm") as mock_llm, \ patch("src.medical.cim10_extractor._enrich_with_rag"): extract_medical_info(parsed, "texte médical", use_rag=True) mock_llm.assert_called_once() def test_das_llm_not_called_when_use_rag_false(self): """Le pass LLM DAS n'est PAS appelé quand use_rag=False.""" from src.medical.cim10_extractor import extract_medical_info parsed = { "type": "CRH", "patient": {"sexe": "M"}, "sejour": {}, "diagnostics": [ {"libelle": "Pancréatite aiguë biliaire", "code_cim10": "K85.1", "type": "principal"}, ], } with patch("src.medical.cim10_extractor._extract_das_llm") as mock_llm: extract_medical_info(parsed, "texte médical", use_rag=False) mock_llm.assert_not_called() def test_das_llm_filters_invalid_codes(self): """Les codes CIM-10 invalides sont filtrés lors de l'intégration.""" from src.medical.cim10_extractor import _extract_das_llm dossier = DossierMedical() dossier.sejour = Sejour(sexe="M", age=50) dossier.diagnostic_principal = Diagnostic( texte="Pancréatite aiguë", cim10_suggestion="K85.9", ) mock_result = [ {"texte": "Hypertension artérielle", "code_cim10": "I10", "justification": "ok"}, {"texte": "Diagnostic bidon", "code_cim10": "ZZZ.99", "justification": "invalide"}, ] with patch("src.medical.rag_search.extract_das_llm", return_value=mock_result): _extract_das_llm("texte médical", dossier) # I10 est valide → ajouté ; ZZZ.99 est invalide → filtré codes = [d.cim10_suggestion for d in dossier.diagnostics_associes] assert "I10" in codes assert "ZZZ.99" not in codes def test_das_llm_deduplicates(self): """Les codes déjà présents dans les DAS ne sont pas dupliqués.""" from src.medical.cim10_extractor import _extract_das_llm dossier = DossierMedical() dossier.sejour = Sejour(sexe="M", age=50) dossier.diagnostic_principal = Diagnostic( texte="Pancréatite aiguë", cim10_suggestion="K85.9", ) dossier.diagnostics_associes = [ Diagnostic(texte="Hypertension artérielle", cim10_suggestion="I10"), ] mock_result = [ {"texte": "HTA essentielle", "code_cim10": "I10", "justification": "doublon"}, {"texte": "Obésité", "code_cim10": "E66.0", "justification": "nouveau"}, ] with patch("src.medical.rag_search.extract_das_llm", return_value=mock_result): _extract_das_llm("texte médical", dossier) codes = [d.cim10_suggestion for d in dossier.diagnostics_associes] assert codes.count("I10") == 1 # Pas de doublon assert "E66.0" in codes # Nouveau ajouté