"""
Unit tests for the GoldSetValidator.

These tests cover gold-set loading, pipeline execution over the gold set,
metric computation, and release validation.
"""

import json
import tempfile
from pathlib import Path
from unittest.mock import Mock

import pytest

from pipeline_mco_pmsi.validation import GoldSetValidator
from pipeline_mco_pmsi.validation.gold_set_validator import (
    GoldSetMetrics,
    GoldStayResult,
)
from pipeline_mco_pmsi.models.coding import Code


@pytest.fixture
def temp_gold_dir():
    """Create a temporary directory for the gold set."""
    with tempfile.TemporaryDirectory() as tmpdir:
        yield Path(tmpdir)


@pytest.fixture
def sample_gold_set():
    """Create a test gold set with 200 stays."""
    gold_stays = []

    for i in range(200):
        stay = {
            "stay_id": f"SEJ{i:03d}",
            "documents": [
                {
                    "document_id": f"DOC{i:03d}",
                    "content": f"Patient {i} avec diagnostic test",
                    "document_type": "CRO"
                }
            ],
            "expected_codes": {
                "dp": f"I{i%10:02d}.{i%10}",
                "das": [f"E{i%5:02d}.{i%5}", f"K{i%3:02d}.{i%3}"],
                "ccam": [f"YYYY{i%100:03d}"]
            }
        }
        gold_stays.append(stay)

    return gold_stays


class TestGoldSetValidatorInit:
    """Initialization tests for GoldSetValidator."""

    def test_init_with_defaults(self, temp_gold_dir):
        """Test initialization with default values."""
        validator = GoldSetValidator(gold_set_path=temp_gold_dir)

        assert validator.gold_set_path == temp_gold_dir
        assert validator.min_dp_accuracy == 0.70
        assert validator.min_das_f1 == 0.60
        assert validator.min_ccam_f1 == 0.65
        assert validator.max_degradation == 0.05

    def test_init_with_custom_thresholds(self, temp_gold_dir):
        """Test initialization with custom thresholds."""
        validator = GoldSetValidator(
            gold_set_path=temp_gold_dir,
            min_dp_accuracy=0.80,
            min_das_f1=0.70,
            min_ccam_f1=0.75,
            max_degradation=0.03
        )

        assert validator.min_dp_accuracy == 0.80
        assert validator.min_das_f1 == 0.70
        assert validator.min_ccam_f1 == 0.75
        assert validator.max_degradation == 0.03


class TestLoadGoldSet:
    """Tests for loading the gold set."""

    def test_load_gold_set_success(self, temp_gold_dir, sample_gold_set):
        """Test that a valid gold set loads successfully."""
        # Create the gold file
        gold_file = temp_gold_dir / "gold_set.json"
        with open(gold_file, "w", encoding="utf-8") as f:
            json.dump(sample_gold_set, f)

        validator = GoldSetValidator(gold_set_path=temp_gold_dir)
        gold_stays = validator.load_gold_set()

        assert len(gold_stays) == 200
        assert gold_stays[0]["stay_id"] == "SEJ000"
        assert "expected_codes" in gold_stays[0]

    def test_load_gold_set_file_not_found(self, temp_gold_dir):
        """Test that load_gold_set fails when the file does not exist."""
        validator = GoldSetValidator(gold_set_path=temp_gold_dir)

        with pytest.raises(FileNotFoundError, match="Fichier jeu gold introuvable"):
            validator.load_gold_set()

    def test_load_gold_set_too_few_stays(self, temp_gold_dir):
        """Test that load_gold_set rejects a gold set that is too small."""
        # Create a gold set with only 50 stays
        small_gold_set = [
            {
                "stay_id": f"SEJ{i:03d}",
                "documents": [],
                "expected_codes": {"dp": "I21.0", "das": [], "ccam": []}
            }
            for i in range(50)
        ]

        gold_file = temp_gold_dir / "gold_set.json"
        with open(gold_file, "w", encoding="utf-8") as f:
            json.dump(small_gold_set, f)

        validator = GoldSetValidator(gold_set_path=temp_gold_dir)

        with pytest.raises(ValueError, match="au moins 200 séjours"):
            validator.load_gold_set()
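

# The three tests above pin down load_gold_set's contract: read
# gold_set.json from gold_set_path, raise FileNotFoundError (message starting
# with "Fichier jeu gold introuvable") when the file is missing, and raise
# ValueError (message containing "au moins 200 séjours") when the set holds
# fewer than 200 stays. A sketch inferred from those assertions, for
# illustration only (not the actual implementation):
def _sketch_load_gold_set(gold_set_path):
    gold_file = gold_set_path / "gold_set.json"
    if not gold_file.exists():
        raise FileNotFoundError(f"Fichier jeu gold introuvable: {gold_file}")
    with open(gold_file, "r", encoding="utf-8") as f:
        gold_stays = json.load(f)
    if len(gold_stays) < 200:
        # Only the quoted substring is pinned down by the test's regex.
        raise ValueError(f"Le jeu gold doit contenir au moins 200 séjours (reçu: {len(gold_stays)})")
    return gold_stays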


class TestRunGoldSet:
    """Tests for running the pipeline over the gold set."""

    def test_run_gold_set_success(self, temp_gold_dir):
        """Test a successful pipeline run over the gold set."""
        from pipeline_mco_pmsi.models.clinical import Evidence, Span

        validator = GoldSetValidator(gold_set_path=temp_gold_dir)

        # Build a mocked pipeline
        mock_pipeline = Mock()
        mock_result = Mock()
        mock_result.proposed_codes = [
            Code(
                code="I21.0",
                type="dp",
                label="Infarctus",
                confidence=0.9,
                reasoning="Test",
                evidence=[Evidence(document_id="DOC001", span=Span(start=0, end=10), text="test")],
                referentiel_version="2026"
            ),
            Code(
                code="I10",
                type="das",
                label="HTA",
                confidence=0.8,
                reasoning="Test",
                evidence=[Evidence(document_id="DOC001", span=Span(start=0, end=10), text="test")],
                referentiel_version="2026"
            ),
            Code(
                code="YYYY001",
                type="ccam",
                label="Acte",
                confidence=0.85,
                reasoning="Test",
                evidence=[Evidence(document_id="DOC001", span=Span(start=0, end=10), text="test")],
                referentiel_version="2026"
            )
        ]
        mock_pipeline.process_stay.return_value = mock_result

        # Minimal gold set (3 stays for this test)
        gold_stays = [
            {
                "stay_id": "SEJ001",
                "documents": [],
                "expected_codes": {"dp": "I21.0", "das": ["I10"], "ccam": ["YYYY001"]}
            },
            {
                "stay_id": "SEJ002",
                "documents": [],
                "expected_codes": {"dp": "I21.0", "das": ["I10", "E11.9"], "ccam": ["YYYY001"]}
            },
            {
                "stay_id": "SEJ003",
                "documents": [],
                "expected_codes": {"dp": "I21.0", "das": [], "ccam": []}
            }
        ]

        results = validator.run_gold_set(mock_pipeline, gold_stays)

        assert len(results) == 3
        assert all(isinstance(r, GoldStayResult) for r in results)
        assert results[0].stay_id == "SEJ001"
        assert results[0].dp_correct is True
        assert results[0].dp_predicted == "I21.0"

    def test_run_gold_set_with_errors(self, temp_gold_dir):
        """Test a run in which the pipeline raises errors."""
        validator = GoldSetValidator(gold_set_path=temp_gold_dir)

        # Mock pipeline that raises an exception
        mock_pipeline = Mock()
        mock_pipeline.process_stay.side_effect = Exception("Erreur de traitement")

        gold_stays = [
            {
                "stay_id": "SEJ001",
                "documents": [],
                "expected_codes": {"dp": "I21.0", "das": [], "ccam": []}
            }
        ]

        results = validator.run_gold_set(mock_pipeline, gold_stays)

        assert len(results) == 1
        assert results[0].dp_correct is False
        assert len(results[0].errors) > 0
        assert "Erreur de traitement" in results[0].errors[0]
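

# run_gold_set, as exercised above, yields one GoldStayResult per stay, maps
# proposed_codes onto DP/DAS/CCAM slots via their `type` field, and records
# pipeline exceptions in result.errors (dp_correct stays False) instead of
# raising. A rough sketch of that loop, inferred from the assertions; the
# process_stay call signature and the timing are assumptions:
def _sketch_run_gold_set(validator, pipeline, gold_stays):
    import time
    results = []
    for stay in gold_stays:
        expected = stay["expected_codes"]
        start = time.perf_counter()
        try:
            output = pipeline.process_stay(stay)
            dp = next((c.code for c in output.proposed_codes if c.type == "dp"), None)
            das = [c.code for c in output.proposed_codes if c.type == "das"]
            ccam = [c.code for c in output.proposed_codes if c.type == "ccam"]
            errors = []
        except Exception as exc:  # recorded on the result, never raised
            dp, das, ccam, errors = None, [], [], [str(exc)]
        das_p, das_r, das_f1 = validator._calculate_metrics(das, expected["das"])
        ccam_p, ccam_r, ccam_f1 = validator._calculate_metrics(ccam, expected["ccam"])
        results.append(GoldStayResult(
            stay_id=stay["stay_id"],
            dp_correct=(dp == expected["dp"]),
            dp_predicted=dp,
            dp_expected=expected["dp"],
            das_predicted=das,
            das_expected=expected["das"],
            das_precision=das_p,
            das_recall=das_r,
            das_f1=das_f1,
            ccam_predicted=ccam,
            ccam_expected=expected["ccam"],
            ccam_precision=ccam_p,
            ccam_recall=ccam_r,
            ccam_f1=ccam_f1,
            processing_time_seconds=time.perf_counter() - start,
            errors=errors,
        ))
    return results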


class TestCalculateMetrics:
    """Tests for metric computation."""

    def test_calculate_metrics_perfect_score(self, temp_gold_dir):
        """Test metric computation with a perfect score."""
        validator = GoldSetValidator(gold_set_path=temp_gold_dir)

        # Perfect results
        results = [
            GoldStayResult(
                stay_id=f"SEJ{i:03d}",
                dp_correct=True,
                dp_predicted=f"I{i}.0",
                dp_expected=f"I{i}.0",
                das_predicted=["I10", "E11.9"],
                das_expected=["I10", "E11.9"],
                das_precision=1.0,
                das_recall=1.0,
                das_f1=1.0,
                ccam_predicted=["YYYY001"],
                ccam_expected=["YYYY001"],
                ccam_precision=1.0,
                ccam_recall=1.0,
                ccam_f1=1.0,
                processing_time_seconds=1.5,
                errors=[]
            )
            for i in range(10)
        ]

        metrics = validator.calculate_metrics(results)

        assert metrics.total_stays == 10
        assert metrics.dp_accuracy == 1.0
        assert metrics.das_f1 == 1.0
        assert metrics.ccam_f1 == 1.0
        assert metrics.error_rate == 0.0
        assert metrics.avg_processing_time == 1.5

    def test_calculate_metrics_partial_score(self, temp_gold_dir):
        """Test metric computation with a partial score."""
        validator = GoldSetValidator(gold_set_path=temp_gold_dir)

        results = [
            # 50% DP correct
            GoldStayResult(
                stay_id="SEJ001",
                dp_correct=True,
                dp_predicted="I21.0",
                dp_expected="I21.0",
                das_predicted=["I10"],
                das_expected=["I10", "E11.9"],
                das_precision=1.0,
                das_recall=0.5,
                das_f1=0.67,
                ccam_predicted=["YYYY001"],
                ccam_expected=["YYYY001"],
                ccam_precision=1.0,
                ccam_recall=1.0,
                ccam_f1=1.0,
                processing_time_seconds=2.0,
                errors=[]
            ),
            GoldStayResult(
                stay_id="SEJ002",
                dp_correct=False,
                dp_predicted="I22.0",
                dp_expected="I21.0",
                das_predicted=["I10", "E11.9"],
                das_expected=["I10"],
                das_precision=0.5,
                das_recall=1.0,
                das_f1=0.67,
                ccam_predicted=[],
                ccam_expected=["YYYY001"],
                ccam_precision=0.0,
                ccam_recall=0.0,
                ccam_f1=0.0,
                processing_time_seconds=1.5,
                errors=[]
            )
        ]

        metrics = validator.calculate_metrics(results)

        assert metrics.total_stays == 2
        assert metrics.dp_accuracy == 0.5  # 1/2
        assert 0.6 < metrics.das_f1 < 0.7  # mean of 0.67 and 0.67
        assert metrics.ccam_f1 == 0.5  # mean of 1.0 and 0.0
        assert metrics.avg_processing_time == 1.75  # mean of 2.0 and 1.5
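

# calculate_metrics aggregates per-stay results: the assertions fix
# dp_accuracy as the fraction of stays with dp_correct, das_f1 / ccam_f1 as
# plain means, avg_processing_time as the mean duration, and error_rate as
# the share of stays with errors. A sketch of that aggregation (illustrative;
# the real method returns a GoldSetMetrics and presumably averages the
# precision and recall fields the same way):
def _sketch_aggregate(results):
    n = len(results)
    return {
        "total_stays": n,
        "dp_accuracy": sum(1 for r in results if r.dp_correct) / n,
        "das_f1": sum(r.das_f1 for r in results) / n,
        "ccam_f1": sum(r.ccam_f1 for r in results) / n,
        "avg_processing_time": sum(r.processing_time_seconds for r in results) / n,
        "error_rate": sum(1 for r in results if r.errors) / n,
    }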


class TestCompareMetrics:
    """Tests for metric comparison."""

    def test_compare_metrics_improvement(self, temp_gold_dir):
        """Test comparison when the metrics improve."""
        validator = GoldSetValidator(gold_set_path=temp_gold_dir)

        before = GoldSetMetrics(
            total_stays=200,
            dp_accuracy=0.70,
            das_precision=0.65,
            das_recall=0.60,
            das_f1=0.62,
            ccam_precision=0.70,
            ccam_recall=0.68,
            ccam_f1=0.69,
            avg_processing_time=25.0,
            error_rate=0.05
        )

        after = GoldSetMetrics(
            total_stays=200,
            dp_accuracy=0.75,
            das_precision=0.70,
            das_recall=0.65,
            das_f1=0.67,
            ccam_precision=0.75,
            ccam_recall=0.73,
            ccam_f1=0.74,
            avg_processing_time=23.0,
            error_rate=0.03
        )

        differences = validator.compare_metrics(before, after)

        assert differences["dp_accuracy"] == pytest.approx(0.05)  # improvement
        assert differences["das_f1"] == pytest.approx(0.05)  # improvement
        assert differences["ccam_f1"] == pytest.approx(0.05)  # improvement
        assert differences["error_rate"] == pytest.approx(-0.02)  # improvement (fewer errors)

    def test_compare_metrics_degradation(self, temp_gold_dir):
        """Test comparison when the metrics degrade."""
        validator = GoldSetValidator(gold_set_path=temp_gold_dir)

        before = GoldSetMetrics(
            total_stays=200,
            dp_accuracy=0.75,
            das_precision=0.70,
            das_recall=0.65,
            das_f1=0.67,
            ccam_precision=0.75,
            ccam_recall=0.73,
            ccam_f1=0.74,
            avg_processing_time=23.0,
            error_rate=0.03
        )

        after = GoldSetMetrics(
            total_stays=200,
            dp_accuracy=0.68,
            das_precision=0.63,
            das_recall=0.58,
            das_f1=0.60,
            ccam_precision=0.68,
            ccam_recall=0.66,
            ccam_f1=0.67,
            avg_processing_time=25.0,
            error_rate=0.06
        )

        differences = validator.compare_metrics(before, after)

        assert differences["dp_accuracy"] == pytest.approx(-0.07)  # degradation
        assert differences["das_f1"] == pytest.approx(-0.07)  # degradation
        assert differences["ccam_f1"] == pytest.approx(-0.07)  # degradation
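

# compare_metrics evidently returns a dict of per-field deltas computed as
# after - before, so a negative error_rate delta means fewer errors. A
# one-line sketch over the fields the tests read (the real method presumably
# covers every GoldSetMetrics field):
def _sketch_compare_metrics(before, after):
    fields = ("dp_accuracy", "das_f1", "ccam_f1", "error_rate")
    return {name: getattr(after, name) - getattr(before, name) for name in fields}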


class TestValidateRelease:
    """Tests for release validation."""

    def test_validate_release_success(self, temp_gold_dir):
        """Test a release that validates successfully."""
        validator = GoldSetValidator(gold_set_path=temp_gold_dir)

        before = GoldSetMetrics(
            total_stays=200,
            dp_accuracy=0.72,
            das_precision=0.65,
            das_recall=0.60,
            das_f1=0.62,
            ccam_precision=0.70,
            ccam_recall=0.68,
            ccam_f1=0.69,
            avg_processing_time=25.0,
            error_rate=0.05
        )

        after = GoldSetMetrics(
            total_stays=200,
            dp_accuracy=0.75,
            das_precision=0.68,
            das_recall=0.63,
            das_f1=0.65,
            ccam_precision=0.73,
            ccam_recall=0.71,
            ccam_f1=0.72,
            avg_processing_time=23.0,
            error_rate=0.03
        )

        release_ok, reasons = validator.validate_release(before, after)

        assert release_ok is True
        assert len(reasons) == 0

    def test_validate_release_below_threshold(self, temp_gold_dir):
        """Test that a release is blocked below the minimum thresholds."""
        validator = GoldSetValidator(
            gold_set_path=temp_gold_dir,
            min_dp_accuracy=0.70,
            min_das_f1=0.60,
            min_ccam_f1=0.65
        )

        before = GoldSetMetrics(
            total_stays=200,
            dp_accuracy=0.72,
            das_precision=0.65,
            das_recall=0.60,
            das_f1=0.62,
            ccam_precision=0.70,
            ccam_recall=0.68,
            ccam_f1=0.69,
            avg_processing_time=25.0,
            error_rate=0.05
        )

        # Metrics below the thresholds
        after = GoldSetMetrics(
            total_stays=200,
            dp_accuracy=0.65,  # < 0.70
            das_precision=0.60,
            das_recall=0.55,
            das_f1=0.57,  # < 0.60
            ccam_precision=0.63,
            ccam_recall=0.61,
            ccam_f1=0.62,  # < 0.65
            avg_processing_time=23.0,
            error_rate=0.03
        )

        release_ok, reasons = validator.validate_release(before, after)

        assert release_ok is False
        # Expect 6 reasons: 3 for the minimum thresholds + 3 for the degradations
        assert len(reasons) == 6
        assert any("DP accuracy" in r and "seuil minimum" in r for r in reasons)
        assert any("DAS F1" in r and "seuil minimum" in r for r in reasons)
        assert any("CCAM F1" in r and "seuil minimum" in r for r in reasons)
        assert any("Dégradation DP" in r for r in reasons)
        assert any("Dégradation DAS" in r for r in reasons)
        assert any("Dégradation CCAM" in r for r in reasons)

    def test_validate_release_excessive_degradation(self, temp_gold_dir):
        """Test that a release is blocked on excessive degradation."""
        validator = GoldSetValidator(
            gold_set_path=temp_gold_dir,
            max_degradation=0.05
        )

        before = GoldSetMetrics(
            total_stays=200,
            dp_accuracy=0.75,
            das_precision=0.70,
            das_recall=0.65,
            das_f1=0.67,
            ccam_precision=0.75,
            ccam_recall=0.73,
            ccam_f1=0.74,
            avg_processing_time=23.0,
            error_rate=0.03
        )

        # 8% degradation (> 5%)
        after = GoldSetMetrics(
            total_stays=200,
            dp_accuracy=0.67,  # -8%
            das_precision=0.62,
            das_recall=0.57,
            das_f1=0.59,  # -8%
            ccam_precision=0.67,
            ccam_recall=0.65,
            ccam_f1=0.66,  # -8%
            avg_processing_time=25.0,
            error_rate=0.06
        )

        release_ok, reasons = validator.validate_release(before, after)

        assert release_ok is False
        assert len(reasons) >= 3
        assert any("Dégradation DP" in r for r in reasons)
        assert any("Dégradation DAS" in r for r in reasons)
        assert any("Dégradation CCAM" in r for r in reasons)
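

# validate_release, per the tests, combines two gates: absolute floors
# (min_dp_accuracy / min_das_f1 / min_ccam_f1, checked on `after`) and a
# relative regression cap (max_degradation, checked on the drop from
# `before`); the below-threshold test expects both gates to fire, giving
# exactly 6 reasons. A sketch of that logic, inferred from the assertions;
# the message wording and the boundary comparison are assumptions, only the
# quoted substrings are pinned down by the tests:
def _sketch_validate_release(validator, before, after):
    reasons = []
    for label, value, floor in (
        ("DP accuracy", after.dp_accuracy, validator.min_dp_accuracy),
        ("DAS F1", after.das_f1, validator.min_das_f1),
        ("CCAM F1", after.ccam_f1, validator.min_ccam_f1),
    ):
        if value < floor:
            reasons.append(f"{label} {value:.2f} sous le seuil minimum {floor:.2f}")
    for label, drop in (
        ("DP", before.dp_accuracy - after.dp_accuracy),
        ("DAS", before.das_f1 - after.das_f1),
        ("CCAM", before.ccam_f1 - after.ccam_f1),
    ):
        if drop > validator.max_degradation:
            reasons.append(f"Dégradation {label}: -{drop:.2f} (max autorisé: {validator.max_degradation:.2f})")
    return len(reasons) == 0, reasons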


class TestSaveMetrics:
    """Tests for saving metrics."""

    def test_save_metrics(self, temp_gold_dir):
        """Test that metrics are saved to disk."""
        validator = GoldSetValidator(gold_set_path=temp_gold_dir)

        metrics = GoldSetMetrics(
            total_stays=200,
            dp_accuracy=0.75,
            das_precision=0.70,
            das_recall=0.65,
            das_f1=0.67,
            ccam_precision=0.75,
            ccam_recall=0.73,
            ccam_f1=0.74,
            avg_processing_time=23.0,
            error_rate=0.03
        )

        output_path = temp_gold_dir / "metrics" / "test_metrics.json"
        validator.save_metrics(metrics, output_path)

        assert output_path.exists()

        # Check the saved content
        with open(output_path, "r", encoding="utf-8") as f:
            saved_data = json.load(f)

        assert saved_data["total_stays"] == 200
        assert saved_data["dp_accuracy"] == 0.75
        assert saved_data["das_f1"] == 0.67
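

# save_metrics, per the test, creates missing parent directories (the
# "metrics" folder does not exist beforehand) and serialises the metrics to
# JSON keyed by field name. A sketch assuming GoldSetMetrics is a dataclass
# (a pydantic model would use .model_dump() instead):
def _sketch_save_metrics(metrics, output_path):
    from dataclasses import asdict  # local import: only this sketch needs it
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(asdict(metrics), f, ensure_ascii=False, indent=2)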


class TestCalculateMetricsHelper:
    """Tests for the _calculate_metrics helper."""

    def test_calculate_metrics_perfect_match(self, temp_gold_dir):
        """Test with a perfect match."""
        validator = GoldSetValidator(gold_set_path=temp_gold_dir)

        predicted = ["I10", "E11.9", "K29.7"]
        expected = ["I10", "E11.9", "K29.7"]

        precision, recall, f1 = validator._calculate_metrics(predicted, expected)

        assert precision == 1.0
        assert recall == 1.0
        assert f1 == 1.0

    def test_calculate_metrics_partial_match(self, temp_gold_dir):
        """Test with a partial match."""
        validator = GoldSetValidator(gold_set_path=temp_gold_dir)

        predicted = ["I10", "E11.9"]
        expected = ["I10", "E11.9", "K29.7"]

        precision, recall, f1 = validator._calculate_metrics(predicted, expected)

        assert precision == 1.0  # 2/2 predicted codes are correct
        assert recall == 2/3  # 2 of 3 expected codes were found
        assert 0.79 < f1 < 0.81  # 2 * (1.0 * 2/3) / (1.0 + 2/3) = 0.80

    def test_calculate_metrics_no_match(self, temp_gold_dir):
        """Test with no match at all."""
        validator = GoldSetValidator(gold_set_path=temp_gold_dir)

        predicted = ["I10", "E11.9"]
        expected = ["K29.7", "J44.0"]

        precision, recall, f1 = validator._calculate_metrics(predicted, expected)

        assert precision == 0.0
        assert recall == 0.0
        assert f1 == 0.0

    def test_calculate_metrics_empty_lists(self, temp_gold_dir):
        """Test with empty lists."""
        validator = GoldSetValidator(gold_set_path=temp_gold_dir)

        # Both empty = perfect match
        precision, recall, f1 = validator._calculate_metrics([], [])
        assert precision == 1.0
        assert recall == 1.0
        assert f1 == 1.0

        # Predicted empty, expected non-empty
        precision, recall, f1 = validator._calculate_metrics([], ["I10"])
        assert precision == 0.0
        assert recall == 0.0
        assert f1 == 0.0

        # Predicted non-empty, expected empty
        precision, recall, f1 = validator._calculate_metrics(["I10"], [])
        assert precision == 0.0
        assert recall == 0.0
        assert f1 == 0.0
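

# The four tests above fully specify the helper's edge cases: set-based
# overlap, both-empty returns a perfect 1.0/1.0/1.0, and one-sided emptiness
# returns 0.0 across the board. A reference sketch of exactly those rules
# (illustrative; the actual helper may differ in detail):
def _sketch_calculate_metrics(predicted, expected):
    pred, exp = set(predicted), set(expected)
    if not pred and not exp:
        return 1.0, 1.0, 1.0  # nothing expected and nothing predicted: perfect
    tp = len(pred & exp)  # true positives: codes present in both lists
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(exp) if exp else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1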