"""Tests évaluation gold CRH — logique tolérante sans mocks.
|
|
|
|
3 cas inline :
|
|
1. Strict match OK
|
|
2. Strict FAIL mais acceptable via family3
|
|
3. R* choisi avec allow_symptom_dp=false → symptom_not_allowed
|
|
"""

from __future__ import annotations

from src.eval.gold_models import (
    GoldCRHCase,
    GoldDPExpected,
    GoldEvidence,
    evaluate_dp,
    is_valid_cim10_format,
    cim10_family3,
    load_gold_jsonl,
)

# ---------------------------------------------------------------------------
# Validation helpers
# ---------------------------------------------------------------------------

class TestCIM10Format:
    """Format validation and 3-character family extraction for CIM-10 codes."""

    def test_valid_codes(self):
        # Well-formed codes, with and without a decimal part.
        for code in ("I26.9", "K81.0", "R06", "Z51.30"):
            assert is_valid_cim10_format(code)

    def test_invalid_codes(self):
        # Missing letter prefix, garbage, empty string, truncated code.
        for code in ("26.9", "INVALID", "", "I2"):
            assert not is_valid_cim10_format(code)

    def test_family3(self):
        # family3 keeps the first three characters (letter + two digits).
        cases = {"I26.9": "I26", "K81.0": "K81", "R06": "R06"}
        for code, family in cases.items():
            assert cim10_family3(code) == family
# ---------------------------------------------------------------------------
# GoldCRHCase model
# ---------------------------------------------------------------------------

class TestGoldCRHCase:
    """Construction and validation rules of the GoldCRHCase model."""

    def test_valid_case(self):
        expected = GoldDPExpected(code="I26.9", label="Embolie pulmonaire")
        case = GoldCRHCase(
            case_id="test_001",
            dp_expected=expected,
            dp_acceptable_codes=["I26.0"],
            dp_acceptable_family3=["I26"],
            confidence="certain",
        )

        assert case.case_id == "test_001"
        assert case.dp_expected.code == "I26.9"
        # allow_symptom_dp was not provided → defaults to False.
        assert case.allow_symptom_dp is False

    def test_invalid_confidence_rejected(self):
        import pytest

        # An unknown confidence value must be rejected by model validation.
        with pytest.raises(Exception):
            GoldCRHCase(
                case_id="test",
                dp_expected=GoldDPExpected(code="I26.9", label="Test"),
                confidence="invalid_value",
            )

    def test_invalid_code_rejected(self):
        import pytest

        # A code that fails the CIM-10 format check must be rejected.
        with pytest.raises(Exception):
            GoldDPExpected(code="INVALID", label="Test")

    def test_notes_max_length(self):
        import pytest

        # 401 characters is expected to exceed the max length of `notes`.
        with pytest.raises(Exception):
            GoldCRHCase(
                case_id="test",
                dp_expected=GoldDPExpected(code="I26.9", label="Test"),
                notes="x" * 401,
            )
# ---------------------------------------------------------------------------
# Tolerant evaluation — the 3 required scenarios
# ---------------------------------------------------------------------------

def _make_gold(
    code: str,
    label: str,
    acceptable_codes: list[str] | None = None,
    acceptable_family3: list[str] | None = None,
    allow_symptom: bool = False,
    confidence: str = "certain",
) -> GoldCRHCase:
    """Build a minimal GoldCRHCase fixture for the evaluation tests.

    Optional list arguments default to None (not []) to avoid the
    shared-mutable-default pitfall; they are replaced by fresh empty
    lists when omitted.
    """
    expected = GoldDPExpected(code=code, label=label)
    return GoldCRHCase(
        case_id="test_case",
        dp_expected=expected,
        dp_acceptable_codes=acceptable_codes or [],
        dp_acceptable_family3=acceptable_family3 or [],
        allow_symptom_dp=allow_symptom,
        confidence=confidence,
    )
class TestEvaluateDP:
    """The 3 required scenarios plus edge cases for evaluate_dp."""

    @staticmethod
    def _assert_flags(result, **expected):
        """Assert each named boolean flag of an evaluation result."""
        for key, value in expected.items():
            assert result[key] is value, key

    def test_strict_match_ok(self):
        """Case 1 — strict match: chosen code == expected code."""
        gold = _make_gold("I26.9", "Embolie pulmonaire", ["I26.0"], ["I26"])
        self._assert_flags(
            evaluate_dp("I26.9", gold),
            exact_match_strict=True,
            exact_match_tolerant_codes=True,
            family3_match_tolerant=True,
            acceptable_match=True,
            symptom_not_allowed=False,
        )

    def test_strict_fail_family3_ok(self):
        """Case 2 — strict FAIL, but acceptable via the 3-char family."""
        gold = _make_gold("I25.1", "SCA", ["I25.5"], ["I25"])
        self._assert_flags(
            evaluate_dp("I25.8", gold),
            exact_match_strict=False,
            # I25.8 is not among [I25.1, I25.5]...
            exact_match_tolerant_codes=False,
            # ...but its family "I25" is listed in acceptable_family3.
            family3_match_tolerant=True,
            acceptable_match=True,
        )

    def test_symptom_not_allowed(self):
        """Case 3 — R* chosen while allow_symptom_dp=False → penalty."""
        gold = _make_gold(
            "I25.1", "SCA", acceptable_family3=["I25"], allow_symptom=False
        )
        self._assert_flags(
            evaluate_dp("R10.4", gold),
            exact_match_strict=False,
            acceptable_match=False,
            symptom_not_allowed=True,
        )

    def test_symptom_allowed(self):
        """R* chosen with allow_symptom_dp=True → no penalty."""
        gold = _make_gold("R06.0", "Dyspnée", allow_symptom=True)
        self._assert_flags(
            evaluate_dp("R06.0", gold),
            exact_match_strict=True,
            symptom_not_allowed=False,
        )

    def test_no_chosen_code(self):
        """No chosen code → every flag is False."""
        gold = _make_gold("I26.9", "EP")
        self._assert_flags(
            evaluate_dp(None, gold),
            exact_match_strict=False,
            acceptable_match=False,
            symptom_not_allowed=False,
        )

    def test_tolerant_codes_match(self):
        """Chosen code is in dp_acceptable_codes but is not dp_expected."""
        gold = _make_gold("I26.9", "EP", acceptable_codes=["I26.0"])
        self._assert_flags(
            evaluate_dp("I26.0", gold),
            exact_match_strict=False,
            exact_match_tolerant_codes=True,
            acceptable_match=True,
        )

    def test_case_insensitive(self):
        """Lowercase chosen codes still match."""
        gold = _make_gold("I26.9", "EP")
        assert evaluate_dp("i26.9", gold)["exact_match_strict"] is True
# ---------------------------------------------------------------------------
# JSONL loading
# ---------------------------------------------------------------------------

class TestLoadGold:
    """Loading gold cases from JSONL files."""

    def test_load_nonexistent_raises(self):
        import pytest

        with pytest.raises(FileNotFoundError):
            load_gold_jsonl("/nonexistent/path.jsonl")

    def test_load_valid_jsonl(self, tmp_path):
        import json

        payload = {
            "case_id": "test_001",
            "dp_expected": {"code": "I26.9", "label": "EP"},
            "confidence": "certain",
        }
        path = tmp_path / "test.jsonl"
        path.write_text(json.dumps(payload) + "\n", encoding="utf-8")

        loaded = load_gold_jsonl(path)

        assert len(loaded) == 1
        first = loaded[0]
        assert first.case_id == "test_001"
        assert first.dp_expected.code == "I26.9"

    def test_load_invalid_line_raises(self, tmp_path):
        import pytest

        bad = tmp_path / "bad.jsonl"
        bad.write_text('{"case_id": "x", "dp_expected": {"code": "INVALID"}}\n')

        # The loader is expected to wrap validation failures in a
        # ValueError whose message contains "erreur".
        with pytest.raises(ValueError, match="erreur"):
            load_gold_jsonl(bad)