fix(cli): avoid duplicate ONNX native load in Windows frozen builds
This commit is contained in:
@@ -18,6 +18,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@@ -41,6 +42,9 @@ except ImportError:
|
||||
|
||||
DEFAULT_MODEL_DIR = Path(__file__).parent / "models" / "camembert-bio-deid" / "onnx"
|
||||
|
||||
_LOAD_LOCK = threading.RLock()
|
||||
_PROCESS_CACHE: Dict[Path, Dict[str, Any]] = {}
|
||||
|
||||
# Mapping labels BIO du modèle → clés PLACEHOLDERS (anonymizer_core)
|
||||
CAMEMBERT_LABEL_MAP: Dict[str, str] = {
|
||||
"PER": "NOM",
|
||||
@@ -79,6 +83,9 @@ class CamembertNerManager:
|
||||
|
||||
def load(self) -> None:
|
||||
"""Charge le modèle ONNX et le tokenizer."""
|
||||
if self._loaded and self._session is not None and self._tokenizer is not None:
|
||||
return
|
||||
|
||||
if not _ORT_AVAILABLE:
|
||||
raise RuntimeError("onnxruntime non disponible. Installez : pip install onnxruntime")
|
||||
if not _TOKENIZERS_AVAILABLE:
|
||||
@@ -88,44 +95,65 @@ class CamembertNerManager:
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"Modèle ONNX non trouvé: {model_path}")
|
||||
|
||||
self.unload()
|
||||
cache_key = self._model_dir.resolve()
|
||||
with _LOAD_LOCK:
|
||||
cached = _PROCESS_CACHE.get(cache_key)
|
||||
if cached is not None:
|
||||
self._session = cached["session"]
|
||||
self._tokenizer = cached["tokenizer"]
|
||||
self._id2label = dict(cached["id2label"])
|
||||
self._version = cached.get("version", "?")
|
||||
self._loaded = True
|
||||
log.info(f"CamemBERT-bio ONNX réutilisé: {self._model_dir} ({len(self._id2label)} labels)")
|
||||
return
|
||||
|
||||
# Charger id2label depuis config.json
|
||||
config_path = self._model_dir / "config.json"
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
cfg = json.load(f)
|
||||
self._id2label = {int(k): v for k, v in cfg.get("id2label", {}).items()}
|
||||
self.unload()
|
||||
|
||||
# Session ONNX (CPU)
|
||||
opts = ort.SessionOptions()
|
||||
opts.inter_op_num_threads = 2
|
||||
opts.intra_op_num_threads = 4
|
||||
self._session = ort.InferenceSession(
|
||||
str(model_path),
|
||||
sess_options=opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
# Charger id2label depuis config.json
|
||||
config_path = self._model_dir / "config.json"
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
cfg = json.load(f)
|
||||
self._id2label = {int(k): v for k, v in cfg.get("id2label", {}).items()}
|
||||
|
||||
# Tokenizer
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(str(self._model_dir))
|
||||
self._loaded = True
|
||||
# Session ONNX (CPU). Une seule session CamemBERT par process et par
|
||||
# dossier modèle : certains runtimes Windows/PyInstaller refusent de
|
||||
# recharger le module natif plus d'une fois dans le même process.
|
||||
opts = ort.SessionOptions()
|
||||
opts.inter_op_num_threads = 2
|
||||
opts.intra_op_num_threads = 4
|
||||
self._session = ort.InferenceSession(
|
||||
str(model_path),
|
||||
sess_options=opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
# Lire la version depuis VERSION.json (si disponible)
|
||||
self._version = "?"
|
||||
version_path = self._model_dir.parent / "VERSION.json"
|
||||
if version_path.exists():
|
||||
try:
|
||||
with open(version_path, encoding="utf-8") as vf:
|
||||
vinfo = json.load(vf)
|
||||
self._version = vinfo.get("current_version", "?")
|
||||
v_meta = vinfo.get("versions", {}).get(self._version, {})
|
||||
f1 = v_meta.get("f1", "?")
|
||||
recall = v_meta.get("recall", "?")
|
||||
log.info(f"CamemBERT-bio ONNX {self._version} chargé (F1={f1}, R={recall}, {len(self._id2label)} labels)")
|
||||
except Exception:
|
||||
# Tokenizer
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(str(self._model_dir))
|
||||
self._loaded = True
|
||||
|
||||
# Lire la version depuis VERSION.json (si disponible)
|
||||
self._version = "?"
|
||||
version_path = self._model_dir.parent / "VERSION.json"
|
||||
if version_path.exists():
|
||||
try:
|
||||
with open(version_path, encoding="utf-8") as vf:
|
||||
vinfo = json.load(vf)
|
||||
self._version = vinfo.get("current_version", "?")
|
||||
v_meta = vinfo.get("versions", {}).get(self._version, {})
|
||||
f1 = v_meta.get("f1", "?")
|
||||
recall = v_meta.get("recall", "?")
|
||||
log.info(f"CamemBERT-bio ONNX {self._version} chargé (F1={f1}, R={recall}, {len(self._id2label)} labels)")
|
||||
except Exception:
|
||||
log.info(f"CamemBERT-bio ONNX chargé: {self._model_dir} ({len(self._id2label)} labels)")
|
||||
else:
|
||||
log.info(f"CamemBERT-bio ONNX chargé: {self._model_dir} ({len(self._id2label)} labels)")
|
||||
else:
|
||||
log.info(f"CamemBERT-bio ONNX chargé: {self._model_dir} ({len(self._id2label)} labels)")
|
||||
|
||||
_PROCESS_CACHE[cache_key] = {
|
||||
"session": self._session,
|
||||
"tokenizer": self._tokenizer,
|
||||
"id2label": dict(self._id2label),
|
||||
"version": self._version,
|
||||
}
|
||||
|
||||
def unload(self) -> None:
|
||||
self._session = None
|
||||
|
||||
Reference in New Issue
Block a user