fix(ner): convertir les entrees ONNX en int64
Force input_ids et attention_mask en int64 avant inference CamemBERT ONNX, pour eviter les erreurs de dtype selon les tokenizers/environnements Windows. Test cible: test_camembert_manager_cache.py.
This commit is contained in:
@@ -183,8 +183,16 @@ class CamembertNerManager:
|
|||||||
)
|
)
|
||||||
offsets = encoding.pop("offset_mapping")[0] # (seq_len, 2)
|
offsets = encoding.pop("offset_mapping")[0] # (seq_len, 2)
|
||||||
|
|
||||||
# Inférence
|
# Inférence. Certains tokenizers renvoient des tableaux int32 sous
|
||||||
inputs = {k: v for k, v in encoding.items() if k in ("input_ids", "attention_mask")}
|
# Windows, alors que le graphe CamemBERT ONNX attend des int64.
|
||||||
|
inputs = {}
|
||||||
|
for key, value in encoding.items():
|
||||||
|
if key not in ("input_ids", "attention_mask"):
|
||||||
|
continue
|
||||||
|
array = np.asarray(value)
|
||||||
|
if array.dtype != np.int64:
|
||||||
|
array = array.astype(np.int64)
|
||||||
|
inputs[key] = array
|
||||||
outputs = self._session.run(None, inputs)
|
outputs = self._session.run(None, inputs)
|
||||||
logits = outputs[0][0] # (seq_len, num_labels)
|
logits = outputs[0][0] # (seq_len, num_labels)
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def test_camembert_load_is_idempotent_and_reuses_process_session(tmp_path, monkeypatch):
|
def test_camembert_load_is_idempotent_and_reuses_process_session(tmp_path, monkeypatch):
|
||||||
import camembert_ner_manager as module
|
import camembert_ner_manager as module
|
||||||
@@ -53,3 +55,44 @@ def test_camembert_load_is_idempotent_and_reuses_process_session(tmp_path, monke
|
|||||||
assert first.is_loaded()
|
assert first.is_loaded()
|
||||||
assert second.is_loaded()
|
assert second.is_loaded()
|
||||||
assert first._session is second._session
|
assert first._session is second._session
|
||||||
|
|
||||||
|
|
||||||
|
def test_camembert_predict_casts_tokenizer_inputs_to_int64():
|
||||||
|
import camembert_ner_manager as module
|
||||||
|
|
||||||
|
captured_inputs = {}
|
||||||
|
|
||||||
|
class FakeTokenizer:
|
||||||
|
def __call__(self, text, **kwargs):
|
||||||
|
return {
|
||||||
|
"input_ids": np.array([[5, 42, 6]], dtype=np.int32),
|
||||||
|
"attention_mask": np.array([[1, 1, 1]], dtype=np.int32),
|
||||||
|
"offset_mapping": np.array([[[0, 0], [0, 5], [0, 0]]], dtype=np.int64),
|
||||||
|
}
|
||||||
|
|
||||||
|
class FakeSession:
|
||||||
|
def run(self, output_names, inputs):
|
||||||
|
captured_inputs.update(inputs)
|
||||||
|
logits = np.array(
|
||||||
|
[[[8.0, 0.0], [0.0, 8.0], [8.0, 0.0]]],
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
return [logits]
|
||||||
|
|
||||||
|
manager = module.CamembertNerManager()
|
||||||
|
manager._loaded = True
|
||||||
|
manager._tokenizer = FakeTokenizer()
|
||||||
|
manager._session = FakeSession()
|
||||||
|
manager._id2label = {0: "O", 1: "B-PER"}
|
||||||
|
|
||||||
|
entities = manager.predict("Alice")
|
||||||
|
|
||||||
|
assert captured_inputs["input_ids"].dtype == np.int64
|
||||||
|
assert captured_inputs["attention_mask"].dtype == np.int64
|
||||||
|
assert len(entities) == 1
|
||||||
|
assert entities[0]["word"] == "Alice"
|
||||||
|
assert entities[0]["label"] == "PER"
|
||||||
|
assert entities[0]["bio_label"] == "B-PER"
|
||||||
|
assert entities[0]["start"] == 0
|
||||||
|
assert entities[0]["end"] == 5
|
||||||
|
assert entities[0]["score"] > 0.99
|
||||||
|
|||||||
Reference in New Issue
Block a user