anonymisation/scripts/finetune_camembert_bio.py
Domi31tls 26b210607c feat(phase2): FINESS gazetteers with 102K facilities + CamemBERT-bio fine-tuning, F1=89%
FINESS gazetteers (data.gouv.fr open data):
- 102K FINESS numbers → detected via exact lookup in _mask_admin_label + selective_rescan (see the sketch after this list)
- 122K facility names, 113K phone numbers, 76K addresses (available)
- A 9-digit number matching a real FINESS entry is masked even without a "FINESS" label
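
A minimal sketch of the exact-lookup idea (illustrative only: the actual implementation lives in _mask_admin_label / selective_rescan, which are not part of this file; the function names, CSV path and column name below are assumptions):

    import csv
    import re
    from pathlib import Path
    from typing import Set

    def load_finess_numbers(csv_path: Path) -> Set[str]:
        """Load the ~102K FINESS numbers from the data.gouv.fr export into a set."""
        numbers: Set[str] = set()
        with csv_path.open(encoding="utf-8") as fh:
            for row in csv.DictReader(fh, delimiter=";"):
                num = row.get("nofinesset", "").strip()  # column name is an assumption
                if num:
                    numbers.add(num)
        return numbers

    FINESS_NUMBERS = load_finess_numbers(Path("data/finess_etablissements.csv"))

    def mask_if_finess(text: str) -> str:
        """Mask any standalone 9-digit number that is a known FINESS identifier."""
        def repl(match: re.Match) -> str:
            return "[FINESS]" if match.group(0) in FINESS_NUMBERS else match.group(0)
        return re.sub(r"\b\d{9}\b", repl, text)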

CamemBERT-bio fine-tuning (almanach/camembert-bio-base):
- Silver annotation export rewritten: original↔pseudonymized alignment via difflib (see the sketch after this list)
  → 6862 B- entities (vs 3344 with the old audit-only export) over 222K tokens
- Sliding windows (200 tokens, stride 100) for long documents
- WeightedNERTrainer with capped class weights (max 10x) + label smoothing
- Result: Precision=88.1%, Recall=89.8%, F1=88.9% (20 epochs, lr=1e-5)
- Model saved to models/camembert-bio-deid/best (not committed)
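
A minimal sketch of the difflib alignment idea (illustrative only: the rewritten exporter is export_silver_annotations.py, not shown here; bio_labels_from_pseudo and infer_entity_type are hypothetical names):

    import difflib
    from typing import List

    def bio_labels_from_pseudo(orig_tokens: List[str], pseudo_tokens: List[str]) -> List[str]:
        """Tokens that changed during pseudonymization become B-/I- entities; the rest stay O."""
        labels = ["O"] * len(orig_tokens)
        sm = difflib.SequenceMatcher(a=orig_tokens, b=pseudo_tokens, autojunk=False)
        for op, i1, i2, j1, j2 in sm.get_opcodes():
            if op in ("replace", "delete"):
                # Entity type inferred from the replacement placeholder, e.g. "<NOM_1>" -> "PER"
                ent = infer_entity_type(pseudo_tokens[j1:j2])  # hypothetical helper
                for k in range(i1, i2):
                    labels[k] = ("B-" if k == i1 else "I-") + ent
        return labels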

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 13:27:37 +01:00


#!/usr/bin/env python3
"""
Fine-tune CamemBERT-bio for French clinical de-identification.
==============================================================
Trains almanach/camembert-bio-base on the silver/gold annotations
exported by export_silver_annotations.py.

Usage:
    python scripts/finetune_camembert_bio.py [--epochs 5] [--batch-size 8] [--lr 2e-5]

Prerequisites: pip install transformers datasets seqeval accelerate
Post-training ONNX export: python scripts/export_onnx.py
"""
import sys
import argparse
from pathlib import Path
from typing import Dict, List
from collections import Counter
import numpy as np
import torch
from torch import nn

# Check dependencies
try:
    from transformers import (
        AutoTokenizer,
        AutoModelForTokenClassification,
        TrainingArguments,
        Trainer,
        DataCollatorForTokenClassification,
    )
    from datasets import Dataset, DatasetDict
    import evaluate
except ImportError as e:
    print(f"Missing dependency: {e}")
    print("Install with: pip install transformers datasets seqeval accelerate")
    sys.exit(1)

# BIO labels for de-identification
LABEL_LIST = [
"O",
"B-PER", "I-PER",
"B-TEL", "I-TEL",
"B-EMAIL", "I-EMAIL",
"B-NIR", "I-NIR",
"B-IPP", "I-IPP",
"B-NDA", "I-NDA",
"B-RPPS", "I-RPPS",
"B-DATE_NAISSANCE", "I-DATE_NAISSANCE",
"B-ADRESSE", "I-ADRESSE",
"B-ZIP", "I-ZIP",
"B-VILLE", "I-VILLE",
"B-HOPITAL", "I-HOPITAL",
"B-IBAN", "I-IBAN",
"B-AGE", "I-AGE",
]
LABEL2ID = {l: i for i, l in enumerate(LABEL_LIST)}
ID2LABEL = {i: l for l, i in LABEL2ID.items()}
MODEL_NAME = "almanach/camembert-bio-base"


def load_bio_files(data_dir: Path, window_size: int = 200, stride: int = 100) -> Dict[str, List]:
    """Load the .bio files and split them into sliding windows.

    Clinical documents are very long, so they are split into windows of
    ~window_size tokens with an overlap of `stride` tokens. Windows containing
    at least one entity are kept, plus a small fraction of entity-free windows
    (for class balance).
    """
    tokens_list: List[List[str]] = []
    labels_list: List[List[int]] = []
    for bio_file in sorted(data_dir.glob("*.bio")):
        text = bio_file.read_text(encoding="utf-8")
        # Load every token of the document
        all_tokens: List[str] = []
        all_labels: List[int] = []
        for line in text.splitlines():
            line = line.strip()
            if not line:
                continue
            parts = line.split("\t")
            if len(parts) != 2:
                continue
            token, label = parts
            label_id = LABEL2ID.get(label, LABEL2ID["O"])
            all_tokens.append(token)
            all_labels.append(label_id)
        if not all_tokens:
            continue
        # Split into sliding windows
        n = len(all_tokens)
        for start in range(0, n, stride):
            end = min(start + window_size, n)
            chunk_tokens = all_tokens[start:end]
            chunk_labels = all_labels[start:end]
            # A window may start in the middle of an entity: turn a leading I- into B-
            if chunk_labels and chunk_labels[0] > 0:
                lbl_name = LABEL_LIST[chunk_labels[0]]
                if lbl_name.startswith("I-"):
                    b_name = "B-" + lbl_name[2:]
                    if b_name in LABEL2ID:
                        chunk_labels[0] = LABEL2ID[b_name]
            # Keep windows with entities, plus roughly 10% of the entity-free ones
            has_entities = any(l != 0 for l in chunk_labels)
            if has_entities or (start % (stride * 10) == 0):
                tokens_list.append(chunk_tokens)
                labels_list.append(chunk_labels)
            if end >= n:
                break
    return {"tokens": tokens_list, "ner_tags": labels_list}


def tokenize_and_align(examples, tokenizer):
    """Tokenize with the CamemBERT tokenizer and align word-level labels to sub-tokens."""
    tokenized = tokenizer(
        examples["tokens"],
        truncation=True,
        is_split_into_words=True,
        max_length=512,
        padding=False,
    )
    all_labels = []
    for i, labels in enumerate(examples["ner_tags"]):
        word_ids = tokenized.word_ids(batch_index=i)
        label_ids = []
        prev_word_id = None
        for word_id in word_ids:
            if word_id is None:
                label_ids.append(-100)
            elif word_id != prev_word_id:
                label_ids.append(labels[word_id])
            else:
                # Continuation sub-token: I- if the first one is B-, otherwise same label
                orig = labels[word_id]
                if orig > 0 and LABEL_LIST[orig].startswith("B-"):
                    # Convert B- to I-
                    i_label = LABEL_LIST[orig].replace("B-", "I-")
                    label_ids.append(LABEL2ID.get(i_label, orig))
                else:
                    label_ids.append(orig)
            prev_word_id = word_id
        all_labels.append(label_ids)
    tokenized["labels"] = all_labels
    return tokenized


class WeightedNERTrainer(Trainer):
    """Trainer with class weights to counter the O vs entity imbalance."""

    def __init__(self, class_weights=None, **kwargs):
        super().__init__(**kwargs)
        if class_weights is not None:
            self.class_weights = class_weights.to(self.args.device)
        else:
            self.class_weights = None

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        if self.class_weights is not None:
            loss_fct = nn.CrossEntropyLoss(
                weight=self.class_weights,
                ignore_index=-100,
                label_smoothing=0.1,
            )
        else:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100, label_smoothing=0.1)
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss


def compute_class_weights(raw_data: Dict, num_labels: int, max_weight: float = 10.0) -> torch.FloatTensor:
    """Compute class weights inversely proportional to frequency, capped after normalization."""
    counts = Counter()
    for labels in raw_data["ner_tags"]:
        for l in labels:
            counts[l] += 1
    total = sum(counts.values())
    weights = torch.ones(num_labels)
    for label_id, count in counts.items():
        if count > 0:
            weights[label_id] = total / (num_labels * count)
    # Normalize so that O = 1.0
    if weights[0] > 0:
        scale = 1.0 / weights[0]
        weights *= scale
    # Cap AFTER normalization to limit the imbalance
    weights = torch.clamp(weights, max=max_weight)
    print(f"  Class weights (O={weights[0]:.1f}, non-O mean={weights[1:].mean():.1f}, max={weights[1:].max():.1f})")
    return weights


def main():
    parser = argparse.ArgumentParser(description="Fine-tune CamemBERT-bio for de-identification")
    parser.add_argument(
        "--data-dir", type=Path,
        default=Path(__file__).parent.parent / "data" / "silver_annotations",
        help="Directory containing the .bio files",
    )
    parser.add_argument(
        "--output-dir", type=Path,
        default=Path(__file__).parent.parent / "models" / "camembert-bio-deid",
        help="Model output directory",
    )
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--val-split", type=float, default=0.15, help="Validation fraction")
    args = parser.parse_args()

    # Load the data
    print(f"Loading data from {args.data_dir}...")
    raw_data = load_bio_files(args.data_dir)
    n_sentences = len(raw_data["tokens"])
    n_entities = sum(1 for labels in raw_data["ner_tags"] for l in labels if l != 0)
    print(f"  {n_sentences} windows, {n_entities} annotated entity tokens")
    if n_sentences < 10:
        print("ERROR: not enough data. Run export_silver_annotations.py first")
        sys.exit(1)

    # Train/validation split
    dataset = Dataset.from_dict(raw_data)
    split = dataset.train_test_split(test_size=args.val_split, seed=42)
    datasets = DatasetDict({"train": split["train"], "validation": split["test"]})
    print(f"  Train: {len(datasets['train'])}, Validation: {len(datasets['validation'])}")
    # Tokenizer + model
    print(f"\nLoading model {MODEL_NAME}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForTokenClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(LABEL_LIST),
        id2label=ID2LABEL,
        label2id=LABEL2ID,
    )

    # Tokenization
    tokenized = datasets.map(
        lambda ex: tokenize_and_align(ex, tokenizer),
        batched=True,
        remove_columns=datasets["train"].column_names,
    )

    # Metrics
    seqeval = evaluate.load("seqeval")

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        true_labels = []
        true_preds = []
        for pred_seq, label_seq in zip(predictions, labels):
            t_labels = []
            t_preds = []
            for p, l in zip(pred_seq, label_seq):
                if l != -100:
                    t_labels.append(LABEL_LIST[l])
                    t_preds.append(LABEL_LIST[p])
            true_labels.append(t_labels)
            true_preds.append(t_preds)
        results = seqeval.compute(predictions=true_preds, references=true_labels)
        return {
            "precision": results["overall_precision"],
            "recall": results["overall_recall"],
            "f1": results["overall_f1"],
        }

    # Class weights to counter the ~97% "O" imbalance
    print("\nComputing class weights...")
    weights = compute_class_weights(raw_data, len(LABEL_LIST))
    # Training
    args.output_dir.mkdir(parents=True, exist_ok=True)
    training_args = TrainingArguments(
        output_dir=str(args.output_dir),
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size * 2,
        learning_rate=args.lr,
        weight_decay=0.01,
        warmup_ratio=0.1,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        logging_steps=50,
        fp16=True,  # mixed-precision training on GPU
        report_to="none",
        save_total_limit=2,
    )
    data_collator = DataCollatorForTokenClassification(tokenizer)
    trainer = WeightedNERTrainer(
        class_weights=weights,
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["validation"],
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
    )
    print(f"\nStarting fine-tuning ({args.epochs} epochs, batch={args.batch_size}, lr={args.lr})...")
    trainer.train()

    # Save the best model
    trainer.save_model(str(args.output_dir / "best"))
    tokenizer.save_pretrained(str(args.output_dir / "best"))
    print(f"\nModel saved to: {args.output_dir / 'best'}")

    # Final evaluation
    results = trainer.evaluate()
    print("\nFinal results:")
    print(f"  Precision: {results['eval_precision']:.4f}")
    print(f"  Recall: {results['eval_recall']:.4f}")
    print(f"  F1: {results['eval_f1']:.4f}")
    print("\nTo export to ONNX:")
    print(f"  python -m optimum.exporters.onnx --model {args.output_dir / 'best'} {args.output_dir / 'onnx'}")


if __name__ == "__main__":
    main()
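
As a usage note (not part of the committed script): a minimal sketch of loading the saved checkpoint for inference with the transformers token-classification pipeline; the model path and example sentence are illustrative.

    from transformers import pipeline

    ner = pipeline(
        "token-classification",
        model="models/camembert-bio-deid/best",
        aggregation_strategy="simple",  # merge B-/I- sub-tokens into entity spans
    )
    print(ner("Patient vu par le Dr Exemple le 01/02/1980 à Toulouse."))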