- A (wired, imports project modules): e2e_map_roles, anonymize_demo, grounding_e2e_resolve_engine
- B (orphan projection, standalone benches): enrichment_eval_multi, extract_easily_bench_cases, extract_record_bench_cases, grounding_eval_multi
113 lines
3.9 KiB
Python
113 lines
3.9 KiB
Python
#!/usr/bin/env python3
|
|
"""E2E — valide le MODULE `core.extraction.role_mapper` en conditions réelles.
|
|
|
|
Remplace le POC ad hoc (`poc_lecture_ecran.py`) : au lieu de logique inline, on
|
|
appelle la brique TESTÉE `map_roles` avec un vrai client vLLM. Prouve la parité
|
|
module ↔ POC sur un vrai écran DGX.
|
|
|
|
Pipeline : extract_grid_from_image (OCR) → tokens_from_grid → map_roles(client réel).
|
|
Sortie masquée (PII) ; détail complet dumpé dans /tmp (reste sur le DGX).
|
|
"""
|
|
import argparse
|
|
import base64
|
|
import json
|
|
import re
|
|
import sys
|
|
import time
|
|
from io import BytesIO
|
|
from pathlib import Path
|
|
|
|
import requests
|
|
from PIL import Image
|
|
|
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
from core.llm.ocr_extractor import extract_grid_from_image # noqa: E402
|
|
from core.extraction.role_mapper import tokens_from_grid, map_roles # noqa: E402
|
|
|
|
# Chat-completions endpoint of the local vLLM server (port 8001) queried by make_client().
VLLM_URL: str = "http://localhost:8001/v1/chat/completions"
# Model identifier sent in the "model" field of every request.
MODEL: str = "Qwen/Qwen3-VL-4B-Instruct"
|
|
|
|
|
|
def _img_data_url(path, max_w=1280):
    """Encode an image file as a PNG ``data:`` URL, downscaled to at most *max_w* px wide.

    The image is converted to RGB. If it is wider than *max_w*, it is resized to
    exactly *max_w* pixels wide with its aspect ratio preserved (LANCZOS
    resampling). Narrower images keep their original size.

    Args:
        path: path of an image file readable by ``PIL.Image.open``.
        max_w: maximum output width in pixels.

    Returns:
        A ``"data:image/png;base64,..."`` string, as used for the ``image_url``
        content part of the chat request built in ``make_client``.
    """
    # The context manager closes the underlying file handle. convert() returns
    # a new, fully loaded image, so it stays usable after the file is closed.
    with Image.open(path) as src:
        img = src.convert("RGB")
    if img.width > max_w:
        # Preserve the aspect ratio, but never request a 0-px height: resize()
        # rejects it for very wide, very short images.
        h = max(1, int(img.height * max_w / img.width))
        img = img.resize((max_w, h), Image.LANCZOS)
    buf = BytesIO()
    img.save(buf, format="PNG")
    return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()
|
|
|
|
|
|
def make_client(max_tokens=1500, max_w=1280, *, timeout=120):
    """Build a VLM client ``client(image_path, prompt) -> str`` backed by vLLM on port 8001.

    Each call sends one chat-completions request to ``VLLM_URL`` for model
    ``MODEL``. A single user message carries the image (inlined as a PNG data
    URL via ``_img_data_url``) followed by the prompt text. The request uses
    ``temperature`` 0.0 and sets ``chat_template_kwargs.enable_thinking`` to
    False.

    Args:
        max_tokens: value sent as ``max_tokens`` in every request.
        max_w: maximum image width in pixels before downscaling.
        timeout: HTTP timeout in seconds for each request. Keyword-only; it
            defaults to 120, the previously hard-coded value.

    Returns:
        The ``client`` callable. It returns the assistant message content of
        the first choice and raises ``RuntimeError`` on any non-200 response
        (with up to 300 characters of the response body).
    """
    def client(image_path, prompt):
        content = [
            {"type": "image_url", "image_url": {"url": _img_data_url(image_path, max_w)}},
            {"type": "text", "text": prompt},
        ]
        body = {
            "model": MODEL,
            "messages": [{"role": "user", "content": content}],
            "temperature": 0.0,
            "max_tokens": max_tokens,
            # Ask the chat template not to enable thinking mode.
            "chat_template_kwargs": {"enable_thinking": False},
        }
        r = requests.post(VLLM_URL, json=body, timeout=timeout)
        if r.status_code != 200:
            raise RuntimeError(f"vLLM {r.status_code}: {r.text[:300]}")
        return r.json()["choices"][0]["message"]["content"]

    return client
|
|
|
|
|
|
def _mask(v):
|
|
v = str(v)
|
|
if not v:
|
|
return "<vide>"
|
|
if re.fullmatch(r"[\d .,/:%€-]+", v):
|
|
k = "num/date"
|
|
elif len(v.split()) >= 4:
|
|
k = "texte"
|
|
else:
|
|
k = "court"
|
|
return f"<{k}:{len(v)}c>"
|
|
|
|
|
|
def main():
    """Run the end-to-end role-mapping check on one screenshot.

    Pipeline: ``extract_grid_from_image`` (OCR) -> ``tokens_from_grid`` ->
    ``map_roles`` with a live vLLM client from ``make_client``.

    The full result, including raw values (potential PII), is written as UTF-8
    JSON to ``/tmp/e2e_<image stem>.json``. Stdout shows only masked values and
    timing/anchoring statistics.

    CLI:
        --extract PATH   image to process (required).
        --roles A,B,...  expected roles, comma-separated (guided mode). If it
                         is omitted or yields no non-blank role, free mode is
                         used (``roles=None``).
    """
    import statistics  # local import: only used for the OCR confidence summary

    ap = argparse.ArgumentParser()
    ap.add_argument("--extract", required=True)
    ap.add_argument("--roles", default="", help="rôles attendus, séparés par des virgules (mode guidé)")
    a = ap.parse_args()
    roles = [r.strip() for r in a.roles.split(",") if r.strip()] or None
    image = Path(a.extract)

    # perf_counter is monotonic and intended for measuring elapsed time.
    t0 = time.perf_counter()
    grid = extract_grid_from_image(a.extract)
    t_ocr = time.perf_counter() - t0
    tokens = tokens_from_grid(grid)
    confs = [t.confidence for t in tokens]
    # True median: averages the two middle values for an even count. Falls back
    # to 0.0 when OCR produced no tokens, since median() raises on empty input.
    med = statistics.median(confs) if confs else 0.0

    client = make_client()
    t1 = time.perf_counter()
    fields = map_roles(a.extract, tokens, client, roles)
    t_vlm = time.perf_counter() - t1

    out = Path(f"/tmp/e2e_{image.stem}.json")
    records = [
        {"label": f.label, "value": f.value, "confidence": f.confidence,
         "anchored": f.anchored, "value_ids": f.value_ids}
        for f in fields
    ]
    # ensure_ascii=False emits raw non-ASCII characters (accents, €). Write UTF-8
    # explicitly instead of relying on the locale's default encoding.
    out.write_text(json.dumps(records, ensure_ascii=False, indent=2), encoding="utf-8")

    anc = sum(1 for f in fields if f.anchored)
    print(f"# Image : {image.name}")
    print(f"# Mode : {'guidé ' + str(roles) if roles else 'libre'}")
    print(f"# OCR : {len(tokens)} tokens, conf médiane {med:.2f}, {t_ocr:.1f}s")
    print(f"# VLM : {t_vlm:.1f}s | via map_roles (module testé)")
    print(f"# Champs : {len(fields)} (ancrés OCR: {anc})")
    for f in fields:
        # "·" = value anchored to OCR tokens, "∅" = not anchored.
        flag = "·" if f.anchored else "∅"
        print(f" {flag} {str(f.label)[:28]:28s} = {_mask(f.value)}")
    print(f"# Ancrage strict : {anc}/{len(fields)} | détail PII -> {out} (DGX, NE PAS rapatrier)")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the E2E check only when executed as a script, not when imported.
    main()
|