feat: pass LLM hybride pour DAS + interface admin référentiels RAG
Chantier 1 — Extraction DAS par LLM : - Nouveau prompt expert DIM dans rag_search.py (extract_das_llm) - Phase 4 dans cim10_extractor.py : détection DAS supplémentaires avant enrichissement RAG - Cache persistant (clé hash du texte), validation CIM-10, déduplication - Activé uniquement avec use_rag=True (--no-rag le désactive) Chantier 2 — Admin référentiels : - Config : REFERENTIELS_DIR, UPLOAD_MAX_SIZE_MB, ALLOWED_EXTENSIONS - Chunking générique (PDF/CSV/Excel/TXT) + ajout incrémental FAISS dans rag_index.py - ReferentielManager CRUD dans viewer/referentiels.py - 5 routes Flask (listing, upload, indexation, suppression, rebuild) - Template admin avec tableau interactif + lien sidebar Fix : if cache → if cache is not None (OllamaCache vide évaluait à False) 410 tests passent (27 nouveaux, 0 régression). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
213
tests/test_das_llm.py
Normal file
213
tests/test_das_llm.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""Tests pour le pass LLM d'extraction de DAS supplémentaires."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from src.config import Diagnostic, DossierMedical, Sejour
|
||||
from src.medical.ollama_cache import OllamaCache
|
||||
|
||||
|
||||
class TestExtractDasLlm:
|
||||
"""Tests pour extract_das_llm() dans rag_search.py."""
|
||||
|
||||
def test_returns_das_from_llm(self):
|
||||
"""Le pass LLM retourne des DAS supplémentaires."""
|
||||
from src.medical.rag_search import extract_das_llm
|
||||
|
||||
mock_result = {
|
||||
"diagnostics_supplementaires": [
|
||||
{
|
||||
"texte": "Hypertension artérielle",
|
||||
"code_cim10": "I10",
|
||||
"justification": "HTA mentionnée dans le texte",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with patch("src.medical.rag_search.call_ollama", return_value=mock_result):
|
||||
result = extract_das_llm(
|
||||
text="Patient hypertendu sous traitement",
|
||||
contexte={"sexe": "M", "age": 65},
|
||||
existing_das=["Diabète de type 2 (E11.9)"],
|
||||
dp_texte="Pancréatite aiguë biliaire",
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["code_cim10"] == "I10"
|
||||
assert result[0]["texte"] == "Hypertension artérielle"
|
||||
|
||||
def test_returns_empty_when_ollama_unavailable(self):
|
||||
"""Retourne une liste vide si Ollama est indisponible."""
|
||||
from src.medical.rag_search import extract_das_llm
|
||||
|
||||
with patch("src.medical.rag_search.call_ollama", return_value=None):
|
||||
result = extract_das_llm(
|
||||
text="Texte médical",
|
||||
contexte={},
|
||||
existing_das=[],
|
||||
dp_texte="",
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_returns_empty_on_bad_format(self):
|
||||
"""Retourne une liste vide si le format de réponse est inattendu."""
|
||||
from src.medical.rag_search import extract_das_llm
|
||||
|
||||
with patch("src.medical.rag_search.call_ollama", return_value={"other_key": "value"}):
|
||||
result = extract_das_llm(
|
||||
text="Texte médical",
|
||||
contexte={},
|
||||
existing_das=[],
|
||||
dp_texte="",
|
||||
)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_cache_hit(self, tmp_path):
|
||||
"""Le cache est utilisé quand disponible."""
|
||||
from src.medical.rag_search import extract_das_llm
|
||||
|
||||
cache = OllamaCache(tmp_path / "cache.json", "test-model")
|
||||
|
||||
mock_result = {
|
||||
"diagnostics_supplementaires": [
|
||||
{"texte": "Anémie", "code_cim10": "D64.9", "justification": "test"},
|
||||
]
|
||||
}
|
||||
|
||||
# Premier appel : cache miss, appelle Ollama
|
||||
with patch("src.medical.rag_search.call_ollama", return_value=mock_result) as mock_call:
|
||||
result1 = extract_das_llm(
|
||||
text="Patient anémique Hb basse",
|
||||
contexte={},
|
||||
existing_das=[],
|
||||
dp_texte="",
|
||||
cache=cache,
|
||||
)
|
||||
assert mock_call.call_count == 1
|
||||
assert len(result1) == 1
|
||||
|
||||
# Vérifier que le cache contient bien l'entrée
|
||||
assert len(cache) > 0
|
||||
|
||||
# Deuxième appel : cache hit, pas d'appel Ollama
|
||||
with patch("src.medical.rag_search.call_ollama") as mock_call:
|
||||
result2 = extract_das_llm(
|
||||
text="Patient anémique Hb basse",
|
||||
contexte={},
|
||||
existing_das=[],
|
||||
dp_texte="",
|
||||
cache=cache,
|
||||
)
|
||||
mock_call.assert_not_called()
|
||||
|
||||
assert len(result2) == 1
|
||||
assert result2[0]["code_cim10"] == "D64.9"
|
||||
|
||||
def test_prompt_includes_context(self):
|
||||
"""Le prompt contient le contexte patient et les DAS existants."""
|
||||
from src.medical.rag_search import _build_prompt_das_extraction
|
||||
|
||||
prompt = _build_prompt_das_extraction(
|
||||
text="Patient hypertendu diabétique",
|
||||
contexte={"sexe": "F", "age": 72, "duree_sejour": 5},
|
||||
existing_das=["Diabète de type 2 (E11.9)", "Obésité (E66.0)"],
|
||||
dp_texte="Pancréatite aiguë biliaire",
|
||||
)
|
||||
|
||||
assert "Pancréatite aiguë biliaire" in prompt
|
||||
assert "Diabète de type 2 (E11.9)" in prompt
|
||||
assert "Obésité (E66.0)" in prompt
|
||||
assert "Patient hypertendu diabétique" in prompt
|
||||
|
||||
|
||||
class TestExtractDasLlmIntegration:
|
||||
"""Tests d'intégration pour le pass LLM DAS dans cim10_extractor.py."""
|
||||
|
||||
def test_das_llm_called_when_use_rag_true(self):
|
||||
"""Le pass LLM DAS est appelé quand use_rag=True."""
|
||||
from src.medical.cim10_extractor import extract_medical_info
|
||||
|
||||
parsed = {
|
||||
"type": "CRH",
|
||||
"patient": {"sexe": "M"},
|
||||
"sejour": {},
|
||||
"diagnostics": [
|
||||
{"libelle": "Pancréatite aiguë biliaire", "code_cim10": "K85.1", "type": "principal"},
|
||||
],
|
||||
}
|
||||
|
||||
with patch("src.medical.cim10_extractor._extract_das_llm") as mock_llm, \
|
||||
patch("src.medical.cim10_extractor._enrich_with_rag"):
|
||||
extract_medical_info(parsed, "texte médical", use_rag=True)
|
||||
mock_llm.assert_called_once()
|
||||
|
||||
def test_das_llm_not_called_when_use_rag_false(self):
|
||||
"""Le pass LLM DAS n'est PAS appelé quand use_rag=False."""
|
||||
from src.medical.cim10_extractor import extract_medical_info
|
||||
|
||||
parsed = {
|
||||
"type": "CRH",
|
||||
"patient": {"sexe": "M"},
|
||||
"sejour": {},
|
||||
"diagnostics": [
|
||||
{"libelle": "Pancréatite aiguë biliaire", "code_cim10": "K85.1", "type": "principal"},
|
||||
],
|
||||
}
|
||||
|
||||
with patch("src.medical.cim10_extractor._extract_das_llm") as mock_llm:
|
||||
extract_medical_info(parsed, "texte médical", use_rag=False)
|
||||
mock_llm.assert_not_called()
|
||||
|
||||
def test_das_llm_filters_invalid_codes(self):
|
||||
"""Les codes CIM-10 invalides sont filtrés lors de l'intégration."""
|
||||
from src.medical.cim10_extractor import _extract_das_llm
|
||||
|
||||
dossier = DossierMedical()
|
||||
dossier.sejour = Sejour(sexe="M", age=50)
|
||||
dossier.diagnostic_principal = Diagnostic(
|
||||
texte="Pancréatite aiguë", cim10_suggestion="K85.9",
|
||||
)
|
||||
|
||||
mock_result = [
|
||||
{"texte": "Hypertension artérielle", "code_cim10": "I10", "justification": "ok"},
|
||||
{"texte": "Diagnostic bidon", "code_cim10": "ZZZ.99", "justification": "invalide"},
|
||||
]
|
||||
|
||||
with patch("src.medical.rag_search.extract_das_llm", return_value=mock_result):
|
||||
_extract_das_llm("texte médical", dossier)
|
||||
|
||||
# I10 est valide → ajouté ; ZZZ.99 est invalide → filtré
|
||||
codes = [d.cim10_suggestion for d in dossier.diagnostics_associes]
|
||||
assert "I10" in codes
|
||||
assert "ZZZ.99" not in codes
|
||||
|
||||
def test_das_llm_deduplicates(self):
|
||||
"""Les codes déjà présents dans les DAS ne sont pas dupliqués."""
|
||||
from src.medical.cim10_extractor import _extract_das_llm
|
||||
|
||||
dossier = DossierMedical()
|
||||
dossier.sejour = Sejour(sexe="M", age=50)
|
||||
dossier.diagnostic_principal = Diagnostic(
|
||||
texte="Pancréatite aiguë", cim10_suggestion="K85.9",
|
||||
)
|
||||
dossier.diagnostics_associes = [
|
||||
Diagnostic(texte="Hypertension artérielle", cim10_suggestion="I10"),
|
||||
]
|
||||
|
||||
mock_result = [
|
||||
{"texte": "HTA essentielle", "code_cim10": "I10", "justification": "doublon"},
|
||||
{"texte": "Obésité", "code_cim10": "E66.0", "justification": "nouveau"},
|
||||
]
|
||||
|
||||
with patch("src.medical.rag_search.extract_das_llm", return_value=mock_result):
|
||||
_extract_das_llm("texte médical", dossier)
|
||||
|
||||
codes = [d.cim10_suggestion for d in dossier.diagnostics_associes]
|
||||
assert codes.count("I10") == 1 # Pas de doublon
|
||||
assert "E66.0" in codes # Nouveau ajouté
|
||||
179
tests/test_referentiels.py
Normal file
179
tests/test_referentiels.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""Tests pour le gestionnaire de référentiels et les routes Flask associées."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from src.viewer.referentiels import ReferentielManager
|
||||
from src.config import ALLOWED_EXTENSIONS, UPLOAD_MAX_SIZE_MB
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests ReferentielManager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReferentielManager:
|
||||
|
||||
@pytest.fixture
|
||||
def manager(self, tmp_path):
|
||||
return ReferentielManager(tmp_path / "refs")
|
||||
|
||||
def test_add_file(self, manager):
|
||||
ref = manager.add_file("guide.pdf", b"fake pdf content")
|
||||
assert ref["filename"] == "guide.pdf"
|
||||
assert ref["extension"] == ".pdf"
|
||||
assert ref["status"] == "uploaded"
|
||||
assert ref["size_bytes"] == len(b"fake pdf content")
|
||||
assert ref["chunks_count"] == 0
|
||||
|
||||
def test_list_all(self, manager):
|
||||
manager.add_file("a.txt", b"hello")
|
||||
manager.add_file("b.csv", b"col1,col2")
|
||||
assert len(manager.list_all()) == 2
|
||||
|
||||
def test_get(self, manager):
|
||||
ref = manager.add_file("guide.pdf", b"content")
|
||||
found = manager.get(ref["id"])
|
||||
assert found is not None
|
||||
assert found["filename"] == "guide.pdf"
|
||||
|
||||
def test_get_not_found(self, manager):
|
||||
assert manager.get("nonexistent") is None
|
||||
|
||||
def test_remove(self, manager):
|
||||
ref = manager.add_file("guide.pdf", b"content")
|
||||
assert manager.remove(ref["id"]) is True
|
||||
assert len(manager.list_all()) == 0
|
||||
assert manager.get(ref["id"]) is None
|
||||
|
||||
def test_remove_not_found(self, manager):
|
||||
assert manager.remove("nonexistent") is False
|
||||
|
||||
def test_add_file_invalid_extension(self, manager):
|
||||
with pytest.raises(ValueError, match="Extension"):
|
||||
manager.add_file("malware.exe", b"evil")
|
||||
|
||||
def test_add_file_too_large(self, manager):
|
||||
big_data = b"x" * (UPLOAD_MAX_SIZE_MB * 1024 * 1024 + 1)
|
||||
with pytest.raises(ValueError, match="volumineux"):
|
||||
manager.add_file("big.pdf", big_data)
|
||||
|
||||
def test_persistence(self, tmp_path):
|
||||
"""L'index persiste entre les instances."""
|
||||
dir_path = tmp_path / "refs"
|
||||
m1 = ReferentielManager(dir_path)
|
||||
m1.add_file("a.txt", b"hello")
|
||||
|
||||
m2 = ReferentielManager(dir_path)
|
||||
assert len(m2.list_all()) == 1
|
||||
assert m2.list_all()[0]["filename"] == "a.txt"
|
||||
|
||||
def test_file_stored_on_disk(self, manager, tmp_path):
|
||||
ref = manager.add_file("test.txt", b"file content here")
|
||||
stored_path = manager._dir / ref["stored_name"]
|
||||
assert stored_path.exists()
|
||||
assert stored_path.read_bytes() == b"file content here"
|
||||
|
||||
def test_remove_deletes_file(self, manager):
|
||||
ref = manager.add_file("test.txt", b"content")
|
||||
stored_path = manager._dir / ref["stored_name"]
|
||||
assert stored_path.exists()
|
||||
manager.remove(ref["id"])
|
||||
assert not stored_path.exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests chunking générique
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestChunking:
|
||||
|
||||
def test_chunk_txt(self, tmp_path):
|
||||
from src.medical.rag_index import chunk_user_file
|
||||
|
||||
txt_file = tmp_path / "test.txt"
|
||||
txt_file.write_text(
|
||||
"Premier paragraphe avec assez de mots pour le seuil.\n\n"
|
||||
"Deuxième paragraphe avec encore plus de mots pour dépasser le minimum.\n\n"
|
||||
"Court\n\n"
|
||||
"Troisième paragraphe qui devrait aussi être un chunk valide.",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
chunks = chunk_user_file(txt_file, "test_doc")
|
||||
assert len(chunks) >= 2 # au moins 2 paragraphes assez longs
|
||||
assert all(c.document == "test_doc" for c in chunks)
|
||||
|
||||
def test_chunk_csv(self, tmp_path):
|
||||
from src.medical.rag_index import chunk_user_file
|
||||
|
||||
csv_file = tmp_path / "test.csv"
|
||||
csv_file.write_text(
|
||||
"code,description,note\n"
|
||||
"K85.1,Pancréatite aiguë biliaire,diagnostic fréquent\n"
|
||||
"I10,Hypertension essentielle,comorbidité courante\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
chunks = chunk_user_file(csv_file, "csv_doc")
|
||||
assert len(chunks) == 2
|
||||
assert "K85.1" in chunks[0].text
|
||||
assert "I10" in chunks[1].text
|
||||
|
||||
def test_chunk_unsupported_extension(self, tmp_path):
|
||||
from src.medical.rag_index import chunk_user_file
|
||||
|
||||
bad_file = tmp_path / "test.xyz"
|
||||
bad_file.write_text("content")
|
||||
|
||||
chunks = chunk_user_file(bad_file, "bad")
|
||||
assert chunks == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests routes Flask
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReferentielRoutes:
|
||||
|
||||
@pytest.fixture
|
||||
def app(self, tmp_path):
|
||||
"""Crée une app Flask de test avec un manager temporaire."""
|
||||
from src.viewer.app import create_app
|
||||
app = create_app()
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def client(self, app):
|
||||
return app.test_client()
|
||||
|
||||
def test_admin_page_loads(self, client):
|
||||
resp = client.get("/admin/referentiels")
|
||||
assert resp.status_code == 200
|
||||
assert "Référentiels RAG" in resp.data.decode()
|
||||
|
||||
def test_upload_no_file(self, client):
|
||||
resp = client.post("/admin/referentiels/upload")
|
||||
assert resp.status_code == 400
|
||||
data = resp.get_json()
|
||||
assert "error" in data
|
||||
|
||||
def test_upload_valid_file(self, client):
|
||||
from io import BytesIO
|
||||
data = {
|
||||
"file": (BytesIO(b"test content"), "doc.txt"),
|
||||
}
|
||||
resp = client.post("/admin/referentiels/upload", data=data, content_type="multipart/form-data")
|
||||
result = resp.get_json()
|
||||
assert resp.status_code == 200
|
||||
assert result["ok"] is True
|
||||
assert result["referentiel"]["filename"] == "doc.txt"
|
||||
|
||||
def test_delete_nonexistent(self, client):
|
||||
resp = client.delete("/admin/referentiels/nonexistent")
|
||||
assert resp.status_code == 404
|
||||
Reference in New Issue
Block a user