diff --git a/core/llm/__init__.py b/core/llm/__init__.py
index 50242000b..15773c121 100644
--- a/core/llm/__init__.py
+++ b/core/llm/__init__.py
@@ -8,6 +8,7 @@ from .t2a_decision import (
 )
 from .ocr_extractor import (
     extract_digits_tesseract_from_image,
+    extract_grid_from_image,
     extract_table_from_image,
     extract_text_from_image,
 )
@@ -19,5 +20,6 @@ __all__ = [
     "build_dpi_enriched",
     "extract_text_from_image",
     "extract_table_from_image",
+    "extract_grid_from_image",
     "extract_digits_tesseract_from_image",
 ]
diff --git a/core/llm/ocr_extractor.py b/core/llm/ocr_extractor.py
index e8fbd0f08..39b08fc1d 100644
--- a/core/llm/ocr_extractor.py
+++ b/core/llm/ocr_extractor.py
@@ -243,3 +243,122 @@ def extract_table_from_image(
     except Exception as e:
         logger.warning("extract_table échoué sur %s : %s", image_path, e)
         return []
+
+
+def _cluster_1d(centers: List[float], tol: float) -> List[Tuple[float, int]]:
+    """Group 1D positions into clusters by proximity.
+
+    Positions are visited in ascending order; a new cluster starts whenever
+    the gap to the previously visited position exceeds ``tol``. The gap is
+    measured against the previous position, not the cluster mean, so a run
+    of closely spaced points can chain into one cluster wider than ``tol``.
+
+    Args:
+        centers: positions to cluster, in any order.
+        tol: largest gap between sorted neighbours that keeps them together.
+
+    Returns:
+        One ``(cluster_mean, cluster_index)`` pair per input position, in the
+        input order. Cluster indices start at 0 and grow with the position,
+        so they can be used directly as row (y) or column (x) numbers.
+        An empty input yields an empty list.
+    """
+    order = sorted(range(len(centers)), key=lambda i: centers[i])
+    cluster_of = [0] * len(centers)
+    members: List[List[float]] = []
+    prev = None
+    for i in order:
+        c = centers[i]
+        if prev is None or (c - prev) > tol:
+            members.append([])
+        members[-1].append(c)
+        cluster_of[i] = len(members) - 1
+        prev = c
+    means = [sum(g) / len(g) for g in members]
+    return [(means[k], k) for k in cluster_of]
+
+
+def extract_grid_from_image(
+    image_path: str,
+    region: Optional[Tuple[int, int, int, int]] = None,
+    row_tol: float = 12.0,
+    col_tol: float = 25.0,
+) -> List[List[dict]]:
+    """Extract a STRUCTURED table (rows AND columns) with EasyOCR.
+
+    Unlike `extract_table_from_image` (flat list sorted by y, x discarded),
+    the x coordinate is kept so a grid can be rebuilt. Rows are clustered by
+    the proximity of token y centers, columns by the proximity of x centers.
+
+    Args:
+        image_path: path of the image file on disk.
+        region: (x, y, w, h) crop applied before OCR. None = whole image.
+        row_tol: max vertical gap (px) between two tokens of the same row.
+        col_tol: max horizontal gap (px) between two tokens of the same column.
+
+    Returns:
+        Grid `List[List[cell]]`, rows top→bottom, cells left→right, with
+        `cell = {"text", "bbox", "confidence", "row", "col"}`. `bbox` is the
+        reader's own value, so with `region` it is relative to the crop.
+        A row only lists the tokens found in it: read the column from
+        `cell["col"]`, not from the position in the list.
+        Returns [] on error or when no token is found.
+    """
+    path = Path(image_path)
+    if not path.exists():
+        logger.warning("extract_grid: fichier introuvable %s", image_path)
+        return []
+
+    try:
+        from PIL import Image
+        import numpy as np
+
+        # Read the pixels inside the `with` so the file handle is released.
+        with Image.open(path) as img:
+            if region:
+                x, y, w, h = region
+                pixels = np.array(img.crop((x, y, x + w, y + h)))
+            else:
+                pixels = np.array(img)
+
+        reader = _get_reader()
+        results = reader.readtext(pixels, detail=1, paragraph=False)
+
+        toks = []
+        for bbox, text, conf in results:
+            t = str(text).strip()
+            if not t:
+                continue
+            xs = [p[0] for p in bbox]
+            ys = [p[1] for p in bbox]
+            toks.append({
+                "text": t,
+                "bbox": bbox,
+                "confidence": conf,
+                "xc": sum(xs) / len(xs),
+                "yc": sum(ys) / len(ys),
+            })
+        if not toks:
+            return []
+
+        rows_cl = _cluster_1d([tk["yc"] for tk in toks], row_tol)
+        cols_cl = _cluster_1d([tk["xc"] for tk in toks], col_tol)
+        for tk, (_yc, r), (_xc, c) in zip(toks, rows_cl, cols_cl):
+            tk["row"], tk["col"] = r, c
+
+        n_rows = max(tk["row"] for tk in toks) + 1
+        grid: List[List[dict]] = [[] for _ in range(n_rows)]
+        # Sort on the x center too: tokens sharing a (row, col) cell would
+        # otherwise keep the arbitrary order in which the reader emitted them.
+        for tk in sorted(toks, key=lambda tok: (tok["col"], tok["xc"])):
+            grid[tk["row"]].append({
+                "text": tk["text"],
+                "bbox": tk["bbox"],
+                "confidence": tk["confidence"],
+                "row": tk["row"],
+                "col": tk["col"],
+            })
+        return grid
+    except Exception as e:
+        logger.warning("extract_grid échoué sur %s : %s", image_path, e)
+        return []
diff --git a/tests/unit/test_extract_grid.py b/tests/unit/test_extract_grid.py
new file mode 100644
index 000000000..fa4220576
--- /dev/null
+++ b/tests/unit/test_extract_grid.py
@@ -0,0 +1,97 @@
+"""Tests for extract_grid_from_image — STRUCTURED table reading.
+
+Unlike extract_table_from_image (which drops x and returns a flat list
+sorted by y), extract_grid_from_image rebuilds a real grid
+List[List[cell]]: rows clustered by y proximity, columns by x proximity.
+bbox and confidence are kept on every cell.
+
+OCR tokens are injected (the EasyOCR reader is mocked) → no real
+screenshot, no GPU.
+"""
+from pathlib import Path
+from types import SimpleNamespace
+
+from PIL import Image
+
+import core.llm.ocr_extractor as ocr_extractor
+
+
+def _blank_png(path: Path) -> None:
+    Image.new("RGB", (300, 120), "white").save(path)
+
+
+def _bbox(x0: float, y0: float, x1: float, y1: float):
+    """EasyOCR bbox = 4 points [tl, tr, br, bl], each point (x, y)."""
+    return [[x0, y0], [x1, y0], [x1, y1], [x0, y1]]
+
+
+def _fake_reader(tokens):
+    """Fake reader: readtext() returns the given (bbox, text, conf) list."""
+    return SimpleNamespace(readtext=lambda *a, **k: tokens)
+
+
+def test_extract_grid_2x3(tmp_path, monkeypatch):
+    image_path = tmp_path / "table.png"
+    _blank_png(image_path)
+
+    # 2 rows (top y≈10 and y≈60) × 3 columns (left x≈10, x≈110, x≈210).
+    # Deliberately shuffled in OCR order to check the sorting.
+    tokens = [
+        (_bbox(110, 58, 160, 78), "B2", 0.97),
+        (_bbox(10, 10, 60, 30), "A1", 0.91),
+        (_bbox(210, 12, 260, 32), "C1", 0.88),
+        (_bbox(210, 60, 260, 80), "C2", 0.95),
+        (_bbox(10, 60, 60, 80), "A2", 0.90),
+        (_bbox(110, 8, 160, 28), "B1", 0.93),
+    ]
+    monkeypatch.setattr(ocr_extractor, "_get_reader", lambda: _fake_reader(tokens))
+
+    grid = ocr_extractor.extract_grid_from_image(str(image_path))
+
+    # Ordered 2×3 grid
+    assert len(grid) == 2, "doit détecter 2 lignes"
+    assert all(len(row) == 3 for row in grid), "chaque ligne doit avoir 3 colonnes"
+
+    texts = [[cell["text"] for cell in row] for row in grid]
+    assert texts == [["A1", "B1", "C1"], ["A2", "B2", "C2"]]
+
+    # Metadata preserved + consistent row/col indices
+    cell = grid[0][2]
+    assert cell["text"] == "C1"
+    assert cell["confidence"] == 0.88
+    assert cell["bbox"] == _bbox(210, 12, 260, 32)
+    assert cell["row"] == 0
+    assert cell["col"] == 2
+    assert grid[1][0]["row"] == 1 and grid[1][0]["col"] == 0
+
+
+def test_extract_grid_empty_when_no_tokens(tmp_path, monkeypatch):
+    image_path = tmp_path / "blank.png"
+    _blank_png(image_path)
+    monkeypatch.setattr(ocr_extractor, "_get_reader", lambda: _fake_reader([]))
+
+    grid = ocr_extractor.extract_grid_from_image(str(image_path))
+    assert grid == []
+
+
+def test_extract_grid_missing_file_returns_empty():
+    grid = ocr_extractor.extract_grid_from_image("/no/such/file.png")
+    assert grid == []
+
+
+def test_extract_grid_orders_tokens_sharing_a_cell(tmp_path, monkeypatch):
+    image_path = tmp_path / "cell.png"
+    _blank_png(image_path)
+
+    # Two words of one cell: same row, x centers 20 px apart (< col_tol),
+    # emitted right-to-left by the reader.
+    tokens = [
+        (_bbox(30, 10, 50, 30), "world", 0.90),
+        (_bbox(10, 10, 30, 30), "hello", 0.92),
+    ]
+    monkeypatch.setattr(ocr_extractor, "_get_reader", lambda: _fake_reader(tokens))
+
+    grid = ocr_extractor.extract_grid_from_image(str(image_path))
+
+    assert [[cell["text"] for cell in row] for row in grid] == [["hello", "world"]]
+    assert all(cell["col"] == 0 for cell in grid[0])