"""Tests unitaires pour le cache Ollama persistant.""" import json import threading import pytest from src.medical.ollama_cache import OllamaCache class TestOllamaCache: def test_get_miss(self, tmp_path): cache = OllamaCache(tmp_path / "cache.json", "gemma3:12b") assert cache.get("HTA", "das") is None def test_put_and_get(self, tmp_path): cache = OllamaCache(tmp_path / "cache.json", "gemma3:12b") result = {"code": "I10", "confidence": "high", "justification": "HTA essentielle"} cache.put("HTA", "das", result) assert cache.get("HTA", "das") == result def test_key_normalization(self, tmp_path): cache = OllamaCache(tmp_path / "cache.json", "gemma3:12b") result = {"code": "I10", "confidence": "high"} cache.put(" HTA ", "das", result) assert cache.get("hta", "das") == result def test_different_types_different_keys(self, tmp_path): cache = OllamaCache(tmp_path / "cache.json", "gemma3:12b") cache.put("Diabète", "dp", {"code": "E11.9"}) cache.put("Diabète", "das", {"code": "E11.8"}) assert cache.get("Diabète", "dp")["code"] == "E11.9" assert cache.get("Diabète", "das")["code"] == "E11.8" def test_save_and_reload(self, tmp_path): path = tmp_path / "cache.json" cache = OllamaCache(path, "gemma3:12b") cache.put("HTA", "das", {"code": "I10"}) cache.save() assert path.exists() cache2 = OllamaCache(path, "gemma3:12b") assert cache2.get("HTA", "das") == {"code": "I10"} def test_save_no_write_if_clean(self, tmp_path): path = tmp_path / "cache.json" cache = OllamaCache(path, "gemma3:12b") cache.save() assert not path.exists() def test_model_change_invalidates(self, tmp_path): path = tmp_path / "cache.json" cache = OllamaCache(path, "gemma3:12b") cache.put("HTA", "das", {"code": "I10"}) cache.save() cache2 = OllamaCache(path, "llama3:8b") assert cache2.get("HTA", "das") is None assert len(cache2) == 0 def test_corrupted_file(self, tmp_path): path = tmp_path / "cache.json" path.write_text("not valid json", encoding="utf-8") cache = OllamaCache(path, "gemma3:12b") assert len(cache) == 0 assert cache.get("HTA", "das") is None def test_len(self, tmp_path): cache = OllamaCache(tmp_path / "cache.json", "gemma3:12b") assert len(cache) == 0 cache.put("HTA", "das", {"code": "I10"}) assert len(cache) == 1 cache.put("Diabète", "dp", {"code": "E11.9"}) assert len(cache) == 2 def test_thread_safety(self, tmp_path): """Écriture concurrente depuis plusieurs threads.""" cache = OllamaCache(tmp_path / "cache.json", "gemma3:12b") errors = [] def writer(i): try: cache.put(f"diag_{i}", "das", {"code": f"X{i:02d}"}) except Exception as e: errors.append(e) threads = [threading.Thread(target=writer, args=(i,)) for i in range(20)] for t in threads: t.start() for t in threads: t.join() assert not errors assert len(cache) == 20 def test_json_format(self, tmp_path): """Le fichier JSON contient le modèle et les entrées.""" path = tmp_path / "cache.json" cache = OllamaCache(path, "gemma3:12b") cache.put("HTA", "das", {"code": "I10"}) cache.save() raw = json.loads(path.read_text(encoding="utf-8")) assert raw["model"] == "gemma3:12b" assert "entries" in raw assert len(raw["entries"]) == 1