#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
ONNX NER Model Manager (CamemBERT family)
-----------------------------------------
- Lazy loading (deferred until after the app has started)
- Support for published ONNX models (model.onnx / model_quantized.onnx)
- Fallback: on-the-fly ONNX export when only a PyTorch model is available
- Per-paragraph prediction (token-classification), 'simple' aggregation

Dependencies:
    pip install onnxruntime optimum transformers sentencepiece
"""
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

from transformers import AutoTokenizer, pipeline

try:
    from optimum.onnxruntime import ORTModelForTokenClassification
except Exception:
    ORTModelForTokenClassification = None  # type: ignore

try:
    # main_export is optimum's high-level "export from a model id" entry point.
    from optimum.exporters.onnx import main_export
except Exception:
    main_export = None  # type: ignore

DEFAULT_MODELS = {
    # Fast & lightweight (quantized weights used when published)
    "DistilCamemBERT-NER (ONNX)": "cmarkea/distilcamembert-base-ner",
    # Robust & widely adopted
    "CamemBERT-NER (ONNX)": "Jean-Baptiste/camembert-ner",
}

SUPPORTED_PER_TAGS = {"PER", "PERSON"}
SUPPORTED_LOC_TAGS = {"LOC"}
SUPPORTED_ORG_TAGS = {"ORG"}
SUPPORTED_DATE_TAGS = {"DATE"}


@dataclass
class NerThresholds:
    per: float = 0.90
    org: float = 0.90
    loc: float = 0.90
    date: float = 0.85
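
# ---------------------------------------------------------------------------
# Illustrative sketch (not called anywhere in this module): how a
# model_quantized.onnx file, such as the ones the manager below prefers when
# prefer_quantized=True, can be produced with optimum's ORTQuantizer. The
# directory arguments are placeholders; dynamic quantization (is_static=False)
# needs no calibration data.
# ---------------------------------------------------------------------------
def quantize_onnx_dir(model_dir: str, save_dir: str) -> None:
    from optimum.onnxruntime import ORTQuantizer
    from optimum.onnxruntime.configuration import AutoQuantizationConfig

    quantizer = ORTQuantizer.from_pretrained(model_dir)  # expects a model.onnx inside
    qconfig = AutoQuantizationConfig.avx2(is_static=False, per_channel=False)
    # file_suffix defaults to "quantized", i.e. this writes model_quantized.onnx.
    quantizer.quantize(save_dir=save_dir, quantization_config=qconfig)
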
class NerModelManager:
    def __init__(self,
                 cache_dir: Optional[Path] = None,
                 prefer_quantized: bool = True,
                 providers: Optional[List[str]] = None):
        self.cache_dir = Path(cache_dir) if cache_dir else None
        self.prefer_quantized = prefer_quantized
        self.providers = providers or ["CPUExecutionProvider"]
        self.model_id: Optional[str] = None
        self._pipe = None
        self._tokenizer = None
        self._loaded = False

    # ------------------ public API ------------------
    def is_loaded(self) -> bool:
        return self._loaded and self._pipe is not None

    def load(self, model_id_or_path: str, try_export_if_missing_onnx: bool = True) -> None:
        """Load an ONNX model; if none is found and export is allowed, export from PyTorch.

        Accepts a local directory (containing model.onnx) or a Hugging Face repo id.
        """
        if ORTModelForTokenClassification is None:
            raise RuntimeError(
                "optimum.onnxruntime not found. Install 'optimum' and 'onnxruntime'."
            )
        self.unload()
        self.model_id = model_id_or_path
        cache = str(self.cache_dir) if self.cache_dir else None

        # 1) Try the quantized ONNX file first (when preferred), then the regular one.
        candidates = []
        if self.prefer_quantized:
            candidates.append("model_quantized.onnx")
        candidates.append("model.onnx")

        loaded = False
        last_err: Optional[Exception] = None
        for fname in candidates:
            try:
                model = ORTModelForTokenClassification.from_pretrained(
                    self.model_id,
                    file_name=fname,
                    cache_dir=cache,
                    provider=self.providers[0],
                )
                tokenizer = AutoTokenizer.from_pretrained(self.model_id, cache_dir=cache, use_fast=True)
                self._pipe = pipeline(
                    task="token-classification",
                    model=model,
                    tokenizer=tokenizer,
                    aggregation_strategy="simple",
                )
                self._tokenizer = tokenizer
                loaded = True
                break
            except Exception as e:
                last_err = e
                continue

        # 2) Fallback: export to ONNX on the fly if requested.
        if not loaded and try_export_if_missing_onnx:
            if main_export is None:
                raise RuntimeError("Cannot export to ONNX (optimum.exporters is unavailable).")
            try:
                # One export directory per model id, so switching models never
                # picks up a stale export.
                tmp_dir = Path(cache or ".") / ".onnx_export" / self.model_id.replace("/", "__")
                tmp_dir.mkdir(parents=True, exist_ok=True)
                main_export(
                    model_name_or_path=self.model_id,
                    output=tmp_dir,
                    task="token-classification",
                    opset=17,
                )
                model = ORTModelForTokenClassification.from_pretrained(
                    str(tmp_dir),
                    file_name="model.onnx",
                    provider=self.providers[0],
                )
                tokenizer = AutoTokenizer.from_pretrained(self.model_id, cache_dir=cache, use_fast=True)
                self._pipe = pipeline(
                    task="token-classification",
                    model=model,
                    tokenizer=tokenizer,
                    aggregation_strategy="simple",
                )
                self._tokenizer = tokenizer
                loaded = True
            except Exception as e:
                last_err = e

        if not loaded:
            raise RuntimeError(f"ONNX load/export failed for '{self.model_id}': {last_err}")
        self._loaded = True

    def unload(self) -> None:
        self._pipe = None
        self._tokenizer = None
        self._loaded = False

    def models_catalog(self) -> Dict[str, str]:
        return dict(DEFAULT_MODELS)

    # ------------------ inference ------------------
    def infer_paragraphs(self,
                         paragraphs: List[str],
                         thresholds: Optional[NerThresholds] = None,
                         max_length: int = 384,
                         stride: int = 128) -> List[List[Dict[str, Any]]]:
        """Return, for each paragraph, a list of aggregated entities.

        Each entity carries the keys: entity_group, score, word, start, end.
        `stride` is kept for API compatibility; the simple truncation below
        does not window long paragraphs.
        """
        if not self.is_loaded():
            return [[] for _ in paragraphs]
        th = thresholds or NerThresholds()
        out: List[List[Dict[str, Any]]] = []
        limit = min(max_length, 512)  # CamemBERT's absolute position limit is 512
        for para in paragraphs:
            if not para.strip():
                out.append([])
                continue
            # Truncate manually if needed (compatible with recent transformers versions).
            input_text = para
            if self._tokenizer:
                tok_len = len(self._tokenizer.encode(para, add_special_tokens=True))
                if tok_len > limit:
                    # Leave room for the two special tokens added around the sequence.
                    tokens = self._tokenizer.encode(para, add_special_tokens=False)[: limit - 2]
                    input_text = self._tokenizer.decode(tokens)
            # aggregation_strategy="simple" was set when the pipeline was built.
            ents = self._pipe(input_text)
            # Per-label threshold filtering.
            filtered: List[Dict[str, Any]] = []
            for e in ents:
                grp = (e.get("entity_group") or e.get("entity") or "").upper()
                sc = float(e.get("score", 0.0))
                if grp in SUPPORTED_PER_TAGS and sc >= th.per:
                    filtered.append(e)
                elif grp in SUPPORTED_ORG_TAGS and sc >= th.org:
                    filtered.append(e)
                elif grp in SUPPORTED_LOC_TAGS and sc >= th.loc:
                    filtered.append(e)
                elif grp in SUPPORTED_DATE_TAGS and sc >= th.date:
                    filtered.append(e)
            out.append(filtered)
        return out
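

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative): load one of the catalog models and tag
# two paragraphs. Assumes network access to Hugging Face on first run and that
# the repo ships ONNX weights; otherwise the manager falls back to an
# on-the-fly export. The sample sentences are made up for the demo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    manager = NerModelManager(prefer_quantized=True)
    manager.load("cmarkea/distilcamembert-base-ner")

    paragraphs = [
        "Marie Curie est née à Varsovie en 1867.",
        "L'ONU siège à New York.",
    ]
    # Loosen the PER threshold slightly for the demo.
    results = manager.infer_paragraphs(paragraphs, thresholds=NerThresholds(per=0.80))
    for para, ents in zip(paragraphs, results):
        print(para)
        for ent in ents:
            print(f"  {ent['entity_group']:>5} {ent['score']:.2f} {ent['word']!r}")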