"""NER via CamemBERT pour détecter les noms en texte libre.""" from __future__ import annotations import logging from typing import TYPE_CHECKING from ..config import NER_CONFIDENCE_THRESHOLD, NER_MODEL if TYPE_CHECKING: from transformers import Pipeline logger = logging.getLogger(__name__) _pipeline: Pipeline | None = None def _get_pipeline() -> Pipeline: """Charge le modèle NER (lazy loading).""" global _pipeline if _pipeline is None: logger.info("Chargement du modèle NER %s...", NER_MODEL) from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline tokenizer = AutoTokenizer.from_pretrained(NER_MODEL, local_files_only=True) model = AutoModelForTokenClassification.from_pretrained(NER_MODEL, local_files_only=True) _pipeline = pipeline( "ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple", ) logger.info("Modèle NER chargé.") return _pipeline def extract_person_entities(text: str) -> list[dict]: """Extrait les entités de type PER (personnes) du texte. Retourne une liste de dicts avec 'word', 'start', 'end', 'score'. """ pipe = _get_pipeline() # CamemBERT a une limite de tokens — découper en chunks chunks = _split_text(text, max_chars=500) entities: list[dict] = [] offset = 0 for chunk in chunks: results = pipe(chunk) for ent in results: if ent["entity_group"] == "PER" and ent["score"] >= NER_CONFIDENCE_THRESHOLD: word = ent["word"].strip() if len(word) >= 2: entities.append({ "word": word, "start": ent["start"] + offset, "end": ent["end"] + offset, "score": float(ent["score"]), }) offset += len(chunk) return _deduplicate(entities) def _split_text(text: str, max_chars: int = 500) -> list[str]: """Découpe le texte en chunks de taille raisonnable aux limites de phrases.""" if len(text) <= max_chars: return [text] chunks: list[str] = [] start = 0 while start < len(text): end = start + max_chars if end < len(text): # Chercher la fin de phrase la plus proche for sep in ["\n", ". ", ", ", " "]: pos = text.rfind(sep, start, end) if pos > start: end = pos + len(sep) break chunks.append(text[start:end]) start = end return chunks def _deduplicate(entities: list[dict]) -> list[dict]: """Déduplique les entités par position (supprime les chevauchements). Garde toutes les occurrences d'un même mot à des positions différentes, mais supprime les entités qui se chevauchent à la même position (garde celle avec le meilleur score). """ if not entities: return [] # Trier par position de début entities.sort(key=lambda e: e["start"]) result: list[dict] = [] for ent in entities: if result and ent["start"] < result[-1]["end"]: # Chevauchement : garder celle avec le meilleur score if ent["score"] > result[-1]["score"]: result[-1] = ent else: result.append(ent) return result