Force input_ids et attention_mask en int64 avant l'inférence CamemBERT ONNX, pour éviter les erreurs de dtype selon les tokenizers/environnements Windows. Test cible : test_camembert_manager_cache.py.
99 lines
3.0 KiB
Python
99 lines
3.0 KiB
Python
import json
|
|
|
|
import numpy as np
|
|
|
|
|
|
def test_camembert_load_is_idempotent_and_reuses_process_session(tmp_path, monkeypatch):
    """Check that repeated ``load()`` calls build a single inference session.

    A throwaway model directory (``model.onnx``, ``config.json`` and a sibling
    ``VERSION.json``) is created under ``tmp_path``. The ONNX runtime and
    tokenizer entry points of the manager module are swapped for stubs that
    record every session they are asked to create. One manager is loaded
    twice, a second manager pointing at the same directory is loaded once,
    and the test then requires that exactly one session was created and that
    both managers hold the very same session object.
    """
    import camembert_ner_manager as module

    # On-disk layout:
    #   <tmp>/camembert-bio-deid/onnx/model.onnx
    #   <tmp>/camembert-bio-deid/onnx/config.json
    #   <tmp>/camembert-bio-deid/VERSION.json
    onnx_dir = tmp_path / "camembert-bio-deid" / "onnx"
    onnx_dir.mkdir(parents=True)

    weights_file = onnx_dir / "model.onnx"
    weights_file.write_bytes(b"fake")

    config_payload = {"id2label": {"0": "O", "1": "B-PER"}}
    config_file = onnx_dir / "config.json"
    config_file.write_text(json.dumps(config_payload), encoding="utf-8")

    version_payload = {
        "current_version": "v-test",
        "versions": {"v-test": {"f1": 1, "recall": 1}},
    }
    version_file = onnx_dir.parent / "VERSION.json"
    version_file.write_text(json.dumps(version_payload), encoding="utf-8")

    # Every session handed out by the stub runtime is recorded here so the
    # number of constructions can be checked at the end.
    sessions_built = []

    def build_session(path, sess_options=None, providers=None):
        record = {"path": path, "providers": providers}
        sessions_built.append(record)
        return record

    class StubSessionOptions:
        inter_op_num_threads = 0
        intra_op_num_threads = 0

    class StubOrt:
        SessionOptions = StubSessionOptions
        InferenceSession = staticmethod(build_session)

    class StubTokenizer:
        from_pretrained = staticmethod(lambda path: {"tokenizer_path": path})

    replacements = (
        ("_ORT_AVAILABLE", True),
        ("_TOKENIZERS_AVAILABLE", True),
        ("ort", StubOrt),
        ("AutoTokenizer", StubTokenizer),
    )
    for attr_name, stand_in in replacements:
        monkeypatch.setattr(module, attr_name, stand_in)

    # Start from an empty process-level cache so earlier tests cannot
    # satisfy the load from a previously stored entry.
    module._PROCESS_CACHE.clear()

    manager_a = module.CamembertNerManager(onnx_dir)
    manager_a.load()
    manager_a.load()  # second call on the same instance

    manager_b = module.CamembertNerManager(onnx_dir)
    manager_b.load()

    assert len(sessions_built) == 1
    assert manager_a.is_loaded()
    assert manager_b.is_loaded()
    assert manager_a._session is manager_b._session
|
|
|
|
|
|
def test_camembert_predict_casts_tokenizer_inputs_to_int64():
    """Check that ``predict`` feeds int64 ``input_ids``/``attention_mask`` to the session.

    The stub tokenizer deliberately returns int32 arrays for ``input_ids``
    and ``attention_mask``. The stub session records the inputs it receives
    and answers with fixed logits. The test asserts that the recorded arrays
    are int64 and that a single ``PER`` entity covering characters 0-5
    ("Alice") is returned.
    """
    import camembert_ner_manager as module

    # Filled by the stub session with whatever feed dict predict() passes in.
    session_feed = {}

    class Int32Tokenizer:
        def __call__(self, text, **kwargs):
            token_ids = np.array([[5, 42, 6]], dtype=np.int32)
            mask = np.array([[1, 1, 1]], dtype=np.int32)
            offsets = np.array([[[0, 0], [0, 5], [0, 0]]], dtype=np.int64)
            return {
                "input_ids": token_ids,
                "attention_mask": mask,
                "offset_mapping": offsets,
            }

    class RecordingSession:
        def run(self, output_names, inputs):
            session_feed.update(inputs)
            # One sequence of three tokens, two classes: only the middle
            # token favours index 1, which _id2label below maps to "B-PER".
            per_token_logits = [[8.0, 0.0], [0.0, 8.0], [8.0, 0.0]]
            return [np.array([per_token_logits], dtype=np.float32)]

    # Wire the stubs straight into the manager instead of calling load().
    manager = module.CamembertNerManager()
    manager._loaded = True
    manager._tokenizer = Int32Tokenizer()
    manager._session = RecordingSession()
    manager._id2label = {0: "O", 1: "B-PER"}

    entities = manager.predict("Alice")

    for feed_name in ("input_ids", "attention_mask"):
        assert session_feed[feed_name].dtype == np.int64

    assert len(entities) == 1
    entity = entities[0]
    expected_fields = {
        "word": "Alice",
        "label": "PER",
        "bio_label": "B-PER",
        "start": 0,
        "end": 5,
    }
    for field, expected in expected_fields.items():
        assert entity[field] == expected
    assert entity["score"] > 0.99
|