fix(ner): convertir les entrees ONNX en int64
Force input_ids et attention_mask en int64 avant inference CamemBERT ONNX, pour eviter les erreurs de dtype selon les tokenizers/environnements Windows. Test cible: test_camembert_manager_cache.py.
This commit is contained in:
@@ -183,8 +183,16 @@ class CamembertNerManager:
|
||||
)
|
||||
offsets = encoding.pop("offset_mapping")[0] # (seq_len, 2)
|
||||
|
||||
# Inférence
|
||||
inputs = {k: v for k, v in encoding.items() if k in ("input_ids", "attention_mask")}
|
||||
# Inférence. Certains tokenizers renvoient des tableaux int32 sous
|
||||
# Windows, alors que le graphe CamemBERT ONNX attend des int64.
|
||||
inputs = {}
|
||||
for key, value in encoding.items():
|
||||
if key not in ("input_ids", "attention_mask"):
|
||||
continue
|
||||
array = np.asarray(value)
|
||||
if array.dtype != np.int64:
|
||||
array = array.astype(np.int64)
|
||||
inputs[key] = array
|
||||
outputs = self._session.run(None, inputs)
|
||||
logits = outputs[0][0] # (seq_len, num_labels)
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def test_camembert_load_is_idempotent_and_reuses_process_session(tmp_path, monkeypatch):
|
||||
import camembert_ner_manager as module
|
||||
@@ -53,3 +55,44 @@ def test_camembert_load_is_idempotent_and_reuses_process_session(tmp_path, monke
|
||||
assert first.is_loaded()
|
||||
assert second.is_loaded()
|
||||
assert first._session is second._session
|
||||
|
||||
|
||||
def test_camembert_predict_casts_tokenizer_inputs_to_int64():
|
||||
import camembert_ner_manager as module
|
||||
|
||||
captured_inputs = {}
|
||||
|
||||
class FakeTokenizer:
|
||||
def __call__(self, text, **kwargs):
|
||||
return {
|
||||
"input_ids": np.array([[5, 42, 6]], dtype=np.int32),
|
||||
"attention_mask": np.array([[1, 1, 1]], dtype=np.int32),
|
||||
"offset_mapping": np.array([[[0, 0], [0, 5], [0, 0]]], dtype=np.int64),
|
||||
}
|
||||
|
||||
class FakeSession:
|
||||
def run(self, output_names, inputs):
|
||||
captured_inputs.update(inputs)
|
||||
logits = np.array(
|
||||
[[[8.0, 0.0], [0.0, 8.0], [8.0, 0.0]]],
|
||||
dtype=np.float32,
|
||||
)
|
||||
return [logits]
|
||||
|
||||
manager = module.CamembertNerManager()
|
||||
manager._loaded = True
|
||||
manager._tokenizer = FakeTokenizer()
|
||||
manager._session = FakeSession()
|
||||
manager._id2label = {0: "O", 1: "B-PER"}
|
||||
|
||||
entities = manager.predict("Alice")
|
||||
|
||||
assert captured_inputs["input_ids"].dtype == np.int64
|
||||
assert captured_inputs["attention_mask"].dtype == np.int64
|
||||
assert len(entities) == 1
|
||||
assert entities[0]["word"] == "Alice"
|
||||
assert entities[0]["label"] == "PER"
|
||||
assert entities[0]["bio_label"] == "B-PER"
|
||||
assert entities[0]["start"] == 0
|
||||
assert entities[0]["end"] == 5
|
||||
assert entities[0]["score"] > 0.99
|
||||
|
||||
Reference in New Issue
Block a user