diff --git a/core/extraction/role_mapper.py b/core/extraction/role_mapper.py index 15fb6ea10..fe3da8049 100644 --- a/core/extraction/role_mapper.py +++ b/core/extraction/role_mapper.py @@ -12,8 +12,9 @@ est orchestré ailleurs (et mockable), pour rester testable hors-ligne. """ from __future__ import annotations +import json from dataclasses import dataclass -from typing import List, Optional, Sequence, Tuple +from typing import Callable, List, Optional, Sequence, Tuple BBox = Tuple[int, int, int, int] # (x_min, y_min, x_max, y_max) @@ -121,3 +122,86 @@ def reconstruct_fields( invalid_ids=invalid, )) return out + + +# --- Orchestration VLM (client injectable pour rester testable hors-ligne) --- + +# Un client VLM est un callable (image_path, prompt) -> texte de réponse. +VlmClient = Callable[[str, str], str] + + +def build_role_prompt( + tokens: Sequence[OcrToken], + roles: Optional[Sequence[str]] = None, +) -> str: + """Construit le prompt d'attribution de rôles (ancrage strict par ids). + + Mode *guidé* si `roles` est fourni (rôles attendus de l'écran), sinon *libre* + (le VLM nomme lui-même les champs). Dans les deux cas le VLM ne renvoie que + des `value_ids` — jamais de texte recopié. + """ + ocr_list = [{"id": t.id, "text": t.text} for t in tokens] + if roles: + roles_line = ( + "Rôles attendus sur cet écran (associe chacun s'il est présent) : " + + ", ".join(roles) + ".\n" + ) + else: + roles_line = ( + "Identifie librement les champs présents — le 'label' est le rôle du champ.\n" + ) + return ( + "Tu reçois une capture d'écran d'un dossier patient et la liste des tokens " + "détectés par OCR (chaque token : id, text).\n" + + roles_line + + "Pour chaque champ, désigne les tokens OCR qui composent sa VALEUR.\n" + "RÈGLES STRICTES :\n" + "- Tu ne recopies AUCUN texte. Tu renvoies seulement 'value_ids' : la liste " + "des id de tokens OCR (dans l'ordre de lecture) qui forment la valeur.\n" + "- 'label' = le rôle du champ. N'invente aucun champ.\n" + "- Réponds UNIQUEMENT en JSON PLAT :\n" + '{"ecran":"","champs":[{"label":"...","value_ids":[,...]}]}\n\n' + "Tokens OCR :\n" + json.dumps(ocr_list, ensure_ascii=False) + ) + + +def parse_vlm_json(text: str) -> dict: + """Extrait le 1er objet JSON d'une réponse VLM (tolère les fences ```json). + + Robuste : renvoie `{}` si la réponse n'est pas du JSON exploitable (pas de + crash en batch). + """ + if not text: + return {} + s = text.strip() + if "```" in s: + parts = s.split("```") + if len(parts) >= 2: + s = parts[1] + if s.lstrip().lower().startswith("json"): + s = s.lstrip()[4:] + a, b = s.find("{"), s.rfind("}") + if a < 0 or b <= a: + return {} + try: + return json.loads(s[a:b + 1]) + except (ValueError, TypeError): + return {} + + +def map_roles( + image_path: str, + tokens: Sequence[OcrToken], + vlm_client: VlmClient, + roles: Optional[Sequence[str]] = None, +) -> List[MappedField]: + """Orchestre l'attribution de rôles : prompt → VLM → parse → reconstruction ancrée. + + `vlm_client` est injecté (testable hors-ligne). Le résultat est toujours + ancré sur l'OCR via `reconstruct_fields`. + """ + prompt = build_role_prompt(tokens, roles) + raw = vlm_client(image_path, prompt) + data = parse_vlm_json(raw) + vlm_fields = data.get("champs", []) if isinstance(data, dict) else [] + return reconstruct_fields(tokens, vlm_fields) diff --git a/tests/unit/test_role_mapper.py b/tests/unit/test_role_mapper.py index 3376efac2..40601f443 100644 --- a/tests/unit/test_role_mapper.py +++ b/tests/unit/test_role_mapper.py @@ -8,6 +8,8 @@ import pytest from core.extraction.role_mapper import ( OcrToken, + build_role_prompt, + map_roles, reconstruct_fields, tokens_from_grid, ) @@ -91,3 +93,60 @@ def test_tokens_from_grid_indexe_et_normalise_bbox(): assert [t.id for t in tokens] == [0, 1] assert tokens[0].text == "Nom" assert tokens[1].bbox == (20, 0, 60, 8) + + +# --- map_roles : orchestrateur (client VLM injectable, donc testable hors-ligne) --- + +def _fake_client(response, capture=None): + """Faux client VLM : enregistre éventuellement le prompt reçu, renvoie une réponse fixe.""" + def client(image_path, prompt): + if capture is not None: + capture["prompt"] = prompt + capture["image_path"] = image_path + return response + return client + + +def test_map_roles_reconstruit_via_client_injecte(): + tokens = [_tok(0, "DUPONT"), _tok(1, "Jean")] + client = _fake_client('{"champs":[{"label":"Nom complet","value_ids":[0,1]}]}') + fields = map_roles("img.png", tokens, client) + assert len(fields) == 1 + assert fields[0].label == "Nom complet" + assert fields[0].value == "DUPONT Jean" + + +def test_map_roles_tolere_les_fences_json(): + tokens = [_tok(0, "DUPONT")] + client = _fake_client('```json\n{"champs":[{"label":"Nom","value_ids":[0]}]}\n```') + fields = map_roles("img.png", tokens, client) + assert fields[0].value == "DUPONT" + + +def test_map_roles_json_invalide_retourne_liste_vide(): + # robustesse batch : une réponse VLM non-JSON ne doit pas crasher. + tokens = [_tok(0, "DUPONT")] + client = _fake_client("désolé, je n'ai pas compris") + fields = map_roles("img.png", tokens, client) + assert fields == [] + + +def test_build_role_prompt_inclut_les_tokens_avec_ids(): + tokens = [_tok(0, "Poids"), _tok(1, "72")] + prompt = build_role_prompt(tokens) + assert "Poids" in prompt and "72" in prompt + assert "value_ids" in prompt # on demande bien des ids, pas du texte recopié + + +def test_build_role_prompt_guide_liste_les_roles_attendus(): + tokens = [_tok(0, "X")] + prompt = build_role_prompt(tokens, roles=["Nom", "IPP", "Poids"]) + assert "Nom" in prompt and "IPP" in prompt and "Poids" in prompt + + +def test_map_roles_passe_les_roles_au_prompt(): + tokens = [_tok(0, "X")] + cap = {} + client = _fake_client('{"champs":[]}', capture=cap) + map_roles("img.png", tokens, client, roles=["Diagnostic", "GEMSA"]) + assert "Diagnostic" in cap["prompt"] and "GEMSA" in cap["prompt"]