diff --git a/camembert_ner_manager.py b/camembert_ner_manager.py
index 627c990..155fd95 100644
--- a/camembert_ner_manager.py
+++ b/camembert_ner_manager.py
@@ -183,8 +183,14 @@ class CamembertNerManager:
         )
         offsets = encoding.pop("offset_mapping")[0]  # (seq_len, 2)
 
-        # Inférence
-        inputs = {k: v for k, v in encoding.items() if k in ("input_ids", "attention_mask")}
+        # Inference. Some tokenizers return int32 arrays on Windows, whereas
+        # the CamemBERT ONNX graph expects int64. np.asarray with an explicit
+        # dtype converts only when needed and leaves int64 arrays untouched.
+        inputs = {
+            key: np.asarray(value, dtype=np.int64)
+            for key, value in encoding.items()
+            if key in ("input_ids", "attention_mask")
+        }
         outputs = self._session.run(None, inputs)
         logits = outputs[0][0]  # (seq_len, num_labels)
 
diff --git a/tests/unit/test_camembert_manager_cache.py b/tests/unit/test_camembert_manager_cache.py
index 1035db1..3e9264a 100644
--- a/tests/unit/test_camembert_manager_cache.py
+++ b/tests/unit/test_camembert_manager_cache.py
@@ -1,5 +1,7 @@
 import json
 
+import numpy as np
+
 
 def test_camembert_load_is_idempotent_and_reuses_process_session(tmp_path, monkeypatch):
     import camembert_ner_manager as module
@@ -53,3 +55,48 @@ def test_camembert_load_is_idempotent_and_reuses_process_session(tmp_path, monke
     assert first.is_loaded()
     assert second.is_loaded()
     assert first._session is second._session
+
+
+def test_camembert_predict_casts_tokenizer_inputs_to_int64():
+    """predict() feeds int64 arrays to the session even when the tokenizer emits int32."""
+    import camembert_ner_manager as module
+
+    captured_inputs = {}
+
+    class FakeTokenizer:
+        def __call__(self, text, **kwargs):
+            return {
+                "input_ids": np.array([[5, 42, 6]], dtype=np.int32),
+                "attention_mask": np.array([[1, 1, 1]], dtype=np.int32),
+                # Extra key that predict() must not forward to the session.
+                "token_type_ids": np.array([[0, 0, 0]], dtype=np.int32),
+                "offset_mapping": np.array([[[0, 0], [0, 5], [0, 0]]], dtype=np.int64),
+            }
+
+    class FakeSession:
+        def run(self, output_names, inputs):
+            captured_inputs.update(inputs)
+            logits = np.array(
+                [[[8.0, 0.0], [0.0, 8.0], [8.0, 0.0]]],
+                dtype=np.float32,
+            )
+            return [logits]
+
+    manager = module.CamembertNerManager()
+    manager._loaded = True
+    manager._tokenizer = FakeTokenizer()
+    manager._session = FakeSession()
+    manager._id2label = {0: "O", 1: "B-PER"}
+
+    entities = manager.predict("Alice")
+
+    assert set(captured_inputs) == {"input_ids", "attention_mask"}
+    assert captured_inputs["input_ids"].dtype == np.int64
+    assert captured_inputs["attention_mask"].dtype == np.int64
+    assert len(entities) == 1
+    assert entities[0]["word"] == "Alice"
+    assert entities[0]["label"] == "PER"
+    assert entities[0]["bio_label"] == "B-PER"
+    assert entities[0]["start"] == 0
+    assert entities[0]["end"] == 5
+    assert entities[0]["score"] > 0.99