Files
anonymisation/scripts/finetune_camembert_bio.py
Domi31tls eb14cd219d feat(phase3): CamemBERT v3 + détection villes + initiales + texte espacé + docs réglementaires
Intégration du modèle CamemBERT-bio-deid v3 (F1=0.96, Recall=0.97, 1112 docs)
et corrections qualité issues de l'audit approfondi sur 29 fichiers.

Détection des villes en texte libre :
- Automate Aho-Corasick sur 33K communes INSEE + 11.6K villes FINESS
- Stratégie contextuelle : exige un contexte géographique (à, de, vers,
  habite, urgences de, etc.) sauf pour les villes composées (Saint-Palais)
- Blacklist de ~80 communes homonymes de mots courants (charge, signes, plan...)
- Normalisation SAINT↔ST pour les variantes orthographiques
- De 18 fuites de villes à 2 cas résiduels atypiques

Masquage des initiales de prénom :
- Post-traitement regex : "Dr T. [NOM]" → "Dr [NOM] [NOM]"
- Références initiales : "Ref : JF/VA" → "Ref : [NOM]/[NOM]"

Détection texte espacé d'en-tête :
- "C E N T R E  H O S P I T A L I E R" → [ETABLISSEMENT]

Autres corrections :
- Fix regex RE_EXTRACT_MME_MR (Mr?.? → Mr.?, \s+ → [ \t]+, * → {0,4})
- Stop words médicaux : lever, coucher, services hospitaliers (viscérale, etc.)
- CamemBERT NER manager : version tracking, propriété version, log F1/Recall
- Script finetune : export ONNX automatique + mise à jour VERSION.json
- Évaluateur qualité : exclusion stop words médicaux des alertes INSEE

Documentation :
- Spécifications techniques CamemBERT-bio-deid v3
- Conformité RGPD + AI Act (caviardage PDF raster)
- AIPD (Analyse d'Impact Protection des Données)

Score qualité : 97.0/100 (Grade A), Leak score 100/100

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-11 12:16:13 +01:00

809 lines
31 KiB
Python

#!/usr/bin/env python3
"""
Fine-tune CamemBERT-bio pour la désidentification clinique française.
=====================================================================
Entraîne almanach/camembert-bio-base sur les annotations silver/gold
exportées par export_silver_annotations.py.
Usage:
python scripts/finetune_camembert_bio.py [--epochs 5] [--batch-size 8] [--lr 2e-5]
python scripts/finetune_camembert_bio.py --no-augment # Sans augmentation
Prérequis: pip install transformers datasets seqeval accelerate
Export ONNX post-training: python scripts/export_onnx.py
"""
import sys
import json
import subprocess
import argparse
import random
from pathlib import Path
from datetime import date
from typing import Dict, List, Tuple
from collections import Counter
import numpy as np
import torch
from torch import nn
# Vérifier les dépendances
try:
from transformers import (
AutoTokenizer,
AutoModelForTokenClassification,
TrainingArguments,
Trainer,
DataCollatorForTokenClassification,
)
from datasets import Dataset, DatasetDict
import evaluate
except ImportError as e:
print(f"Dépendance manquante: {e}")
print("Installez: pip install transformers datasets seqeval accelerate")
sys.exit(1)
# Labels BIO pour la désidentification
LABEL_LIST = [
"O",
"B-PER", "I-PER",
"B-TEL", "I-TEL",
"B-EMAIL", "I-EMAIL",
"B-NIR", "I-NIR",
"B-IPP", "I-IPP",
"B-NDA", "I-NDA",
"B-RPPS", "I-RPPS",
"B-DATE_NAISSANCE", "I-DATE_NAISSANCE",
"B-ADRESSE", "I-ADRESSE",
"B-ZIP", "I-ZIP",
"B-VILLE", "I-VILLE",
"B-HOPITAL", "I-HOPITAL",
"B-IBAN", "I-IBAN",
"B-AGE", "I-AGE",
]
LABEL2ID = {l: i for i, l in enumerate(LABEL_LIST)}
ID2LABEL = {i: l for l, i in LABEL2ID.items()}
MODEL_NAME = "almanach/camembert-bio-base"
def load_bio_files(data_dir: Path, window_size: int = 200, stride: int = 100) -> Dict[str, List]:
"""Charge les fichiers .bio et découpe en fenêtres glissantes.
Les documents cliniques sont très longs. On les découpe en fenêtres de
~window_size tokens avec un chevauchement de stride. On ne garde que les
fenêtres contenant au moins une entité (pour l'équilibre des classes).
"""
tokens_list: List[List[str]] = []
labels_list: List[List[int]] = []
for bio_file in sorted(data_dir.glob("*.bio")):
text = bio_file.read_text(encoding="utf-8")
# Charger tous les tokens du document
all_tokens: List[str] = []
all_labels: List[int] = []
for line in text.splitlines():
line = line.strip()
if not line:
continue
parts = line.split("\t")
if len(parts) != 2:
continue
token, label = parts
label_id = LABEL2ID.get(label, LABEL2ID["O"])
all_tokens.append(token)
all_labels.append(label_id)
if not all_tokens:
continue
# Découper en fenêtres glissantes
n = len(all_tokens)
for start in range(0, n, stride):
end = min(start + window_size, n)
chunk_tokens = all_tokens[start:end]
chunk_labels = all_labels[start:end]
# Corriger les I- en début de fenêtre → B-
if chunk_labels and chunk_labels[0] > 0:
lbl_name = LABEL_LIST[chunk_labels[0]]
if lbl_name.startswith("I-"):
b_name = "B-" + lbl_name[2:]
if b_name in LABEL2ID:
chunk_labels[0] = LABEL2ID[b_name]
# Garder les fenêtres avec entités + quelques fenêtres "O" (10%)
has_entities = any(l != 0 for l in chunk_labels)
if has_entities or (start % (stride * 10) == 0):
tokens_list.append(chunk_tokens)
labels_list.append(chunk_labels)
if end >= n:
break
return {"tokens": tokens_list, "ner_tags": labels_list}
# ─────────────────────────────────────────────────────────────────────────────
# Data Augmentation : substitution de noms
# ─────────────────────────────────────────────────────────────────────────────
def _load_gazetteer(filepath: Path, max_entries: int = 10000) -> List[str]:
"""Charge un fichier de noms (un par ligne), en ignorant les lignes vides."""
if not filepath.exists():
print(f" [WARN] Fichier gazetteer absent : {filepath}")
return []
names = []
with open(filepath, encoding="utf-8", errors="replace") as f:
for line in f:
name = line.strip()
if name and len(name) >= 2:
names.append(name)
if len(names) >= max_entries:
break
return names
def _match_case(replacement: str, original: str) -> str:
"""Applique la casse de l'original au remplacement."""
if original.isupper():
return replacement.upper()
if original.istitle():
return replacement.title()
if original.islower():
return replacement.lower()
return replacement
def _extract_per_entities(tokens: List[str], labels: List[int]) -> List[Tuple[int, int]]:
"""Extrait les spans (start, end exclusive) des entités PER dans une séquence BIO.
Retourne une liste de tuples (start_idx, end_idx).
"""
entities = []
i = 0
b_per_id = LABEL2ID.get("B-PER", -1)
i_per_id = LABEL2ID.get("I-PER", -1)
while i < len(labels):
if labels[i] == b_per_id:
start = i
i += 1
while i < len(labels) and labels[i] == i_per_id:
i += 1
entities.append((start, i))
else:
i += 1
return entities
def augment_name_substitution(
tokens_list: List[List[str]],
labels_list: List[List[int]],
prenoms_file: Path,
noms_file: Path,
num_augmented: int = 3,
rng_seed: int = 42,
) -> Tuple[List[List[str]], List[List[int]]]:
"""Augmente les données en substituant les entités PER par des noms aléatoires.
Pour chaque exemple contenant au moins une entité PER, crée num_augmented
copies avec des noms de remplacement tirés des fichiers gazetteers INSEE.
La casse (UPPER, Title, lower) est conservée.
Retourne (aug_tokens_list, aug_labels_list) — les copies augmentées SANS
les originaux (l'appelant les concatène).
"""
rng = random.Random(rng_seed)
prenoms = _load_gazetteer(prenoms_file)
noms = _load_gazetteer(noms_file)
if not prenoms:
print(" [WARN] Aucun prénom chargé — augmentation désactivée")
return [], []
if not noms:
print(" [WARN] Aucun nom de famille chargé — utilisation des prénoms uniquement")
noms = prenoms # fallback gracieux
print(f" Gazetteers : {len(prenoms)} prénoms, {len(noms)} noms de famille")
b_per_id = LABEL2ID["B-PER"]
aug_tokens: List[List[str]] = []
aug_labels: List[List[int]] = []
n_augmented_examples = 0
for tokens, labels in zip(tokens_list, labels_list):
entities = _extract_per_entities(tokens, labels)
if not entities:
continue
for _ in range(num_augmented):
new_tokens = list(tokens)
new_labels = list(labels)
for start, end in entities:
span_tokens = tokens[start:end]
span_len = end - start
# Générer un remplacement de même longueur en tokens
replacements = []
for j, orig_tok in enumerate(span_tokens):
if "-" in orig_tok:
# Nom composé : JEAN-PIERRE → MARIE-CLAIRE
parts = orig_tok.split("-")
new_parts = [
_match_case(rng.choice(prenoms), p) for p in parts
]
replacements.append("-".join(new_parts))
elif j == 0 and span_len >= 2:
# Premier token d'une entité multi-tokens → prénom
replacements.append(_match_case(rng.choice(prenoms), orig_tok))
elif span_len == 1:
# Entité mono-token → 50/50 prénom ou nom
pool = prenoms if rng.random() < 0.5 else noms
replacements.append(_match_case(rng.choice(pool), orig_tok))
else:
# Tokens suivants → nom de famille
replacements.append(_match_case(rng.choice(noms), orig_tok))
for j, idx in enumerate(range(start, end)):
new_tokens[idx] = replacements[j]
aug_tokens.append(new_tokens)
aug_labels.append(new_labels)
n_augmented_examples += 1
print(f" Augmentation noms : {n_augmented_examples} exemples générés "
f"(x{num_augmented} par exemple avec entités PER)")
return aug_tokens, aug_labels
# ─────────────────────────────────────────────────────────────────────────────
# Hard Negative Mining : médicaments
# ─────────────────────────────────────────────────────────────────────────────
# Liste de secours si aucun fichier de médicaments n'est trouvé
_FALLBACK_MEDICATIONS = [
"DOLIPRANE", "EFFERALGAN", "DAFALGAN", "IBUPROFENE", "ADVIL", "NUROFEN",
"SPASFON", "SMECTA", "GAVISCON", "MOPRAL", "OMEPRAZOLE", "PANTOPRAZOLE",
"LANSOPRAZOLE", "INEXIUM", "AUGMENTIN", "AMOXICILLINE", "CLAMOXYL",
"METFORMINE", "GLUCOPHAGE", "DIAMICRON", "LANTUS", "NOVORAPID", "HUMALOG",
"LEVOTHYROX", "EUTHYROX", "KARDEGIC", "ASPEGIC", "PLAVIX", "CLOPIDOGREL",
"ELIQUIS", "XARELTO", "PRADAXA", "COUMADINE", "PREVISCAN", "LOVENOX",
"RAMIPRIL", "TRIATEC", "PERINDOPRIL", "COVERSYL", "AMLODIPINE", "AMLOR",
"BISOPROLOL", "ATENOLOL", "METOPROLOL", "FUROSEMIDE", "LASILIX",
"TAHOR", "CRESTOR", "PRAVASTATINE", "SIMVASTATINE",
]
def _load_medications(medications_file: Path) -> List[str]:
"""Charge une liste de noms de médicaments depuis le fichier BDPM ou le fallback."""
if medications_file.exists():
meds = []
with open(medications_file, encoding="utf-8", errors="replace") as f:
for line in f:
name = line.strip()
if name and len(name) >= 3:
meds.append(name)
if meds:
print(f" Médicaments chargés : {len(meds)} depuis {medications_file.name}")
return meds
print(f" [WARN] Fichier médicaments absent ou vide — utilisation de {len(_FALLBACK_MEDICATIONS)} médicaments courants")
return list(_FALLBACK_MEDICATIONS)
# Templates de phrases médicales où les médicaments ne sont PAS des noms de personnes
_MEDICATION_TEMPLATES = [
"Traitement par {med} {dose}",
"Prescription de {med} {forme}",
"Relais par {med}",
"Introduction de {med} {dose}",
"Poursuite du traitement par {med}",
"Arrêt du {med} en raison des effets secondaires",
"Allergie connue au {med}",
"Switch de {med} vers {med2}",
"Patient sous {med} depuis 3 mois",
"Posologie de {med} augmentée à {dose}",
"Administration de {med} en IV",
"Relais per os par {med} {dose}",
]
_DOSE_VARIANTS = [
"500mg", "1000mg", "200mg", "100mg", "50mg", "250mg", "75mg",
"20mg", "40mg", "10mg", "5mg", "1g", "2g", "80mg", "160mg",
]
_FORME_VARIANTS = [
"comprimé", "gélule", "sachet", "injectable", "sirop",
"pommade", "collyre", "suppositoire", "patch", "cp",
]
def _load_quaero_entities(quaero_dir: Path) -> List[str]:
"""Charge les entités QUAERO (CHEM, DISO, PROC, ANAT) comme termes médicaux."""
entities = []
for filename in ["quaero_chem_entities.txt", "quaero_diso_entities.txt",
"quaero_procedures.txt", "quaero_anatomie.txt"]:
fpath = quaero_dir / filename
if fpath.exists():
for line in fpath.read_text(encoding="utf-8").splitlines():
name = line.strip()
if name and len(name) >= 3:
entities.append(name)
return entities
# Templates supplémentaires pour termes médicaux QUAERO (maladies, procédures, anatomie)
_MEDICAL_CONTEXT_TEMPLATES = [
"Diagnostic de {term} confirmé par imagerie",
"Antécédent de {term}",
"Suspicion de {term} évoquée",
"Bilan pour {term}",
"Prise en charge de {term} par le service",
"Examen de {term} sans anomalie",
"Surveillance de {term} en cours",
"Indication opératoire pour {term}",
]
def add_hard_negatives(
tokens_list: List[List[str]],
labels_list: List[List[int]],
medications_file: Path,
n_examples: int = 600,
rng_seed: int = 42,
) -> Tuple[List[List[str]], List[List[int]]]:
"""Crée des exemples synthétiques de contextes médicaux (tout label O).
Utilise les noms de médicaments (BDPM) et les termes médicaux QUAERO
(CHEM, DISO, PROC, ANAT) pour enseigner au modèle que ces termes
ne sont PAS des noms de personnes.
Retourne (neg_tokens_list, neg_labels_list) — les négatifs à ajouter.
"""
rng = random.Random(rng_seed)
medications = _load_medications(medications_file)
# Charger aussi les entités QUAERO
quaero_dir = medications_file.parent.parent / "quaero"
quaero_terms = _load_quaero_entities(quaero_dir)
if quaero_terms:
print(f" Termes médicaux QUAERO : {len(quaero_terms)} (CHEM+DISO+PROC+ANAT)")
if not medications:
return [], []
neg_tokens: List[List[str]] = []
neg_labels: List[List[int]] = []
# 1. Hard negatives médicaments (BDPM)
n_med = n_examples * 2 // 3 # 2/3 médicaments
for i in range(n_med):
template = rng.choice(_MEDICATION_TEMPLATES)
med = rng.choice(medications)
med2 = rng.choice(medications)
dose = rng.choice(_DOSE_VARIANTS)
forme = rng.choice(_FORME_VARIANTS)
sentence = template.format(med=med, med2=med2, dose=dose, forme=forme)
toks = sentence.split()
labs = [LABEL2ID["O"]] * len(toks)
neg_tokens.append(toks)
neg_labels.append(labs)
# 2. Hard negatives QUAERO (maladies, procédures, anatomie)
n_quaero = n_examples - n_med # 1/3 termes QUAERO
if quaero_terms:
for i in range(n_quaero):
template = rng.choice(_MEDICAL_CONTEXT_TEMPLATES)
term = rng.choice(quaero_terms)
sentence = template.format(term=term)
toks = sentence.split()
labs = [LABEL2ID["O"]] * len(toks)
neg_tokens.append(toks)
neg_labels.append(labs)
print(f" Hard negatives : {n_med} médicaments + {n_quaero if quaero_terms else 0} QUAERO = {len(neg_tokens)} exemples")
return neg_tokens, neg_labels
def tokenize_and_align(examples, tokenizer):
"""Tokenize et aligne les labels avec les sous-tokens."""
tokenized = tokenizer(
examples["tokens"],
truncation=True,
is_split_into_words=True,
max_length=512,
padding=False,
)
all_labels = []
for i, labels in enumerate(examples["ner_tags"]):
word_ids = tokenized.word_ids(batch_index=i)
label_ids = []
prev_word_id = None
for word_id in word_ids:
if word_id is None:
label_ids.append(-100)
elif word_id != prev_word_id:
label_ids.append(labels[word_id])
else:
# Sous-token : I- si le premier est B-, sinon même label
orig = labels[word_id]
if orig > 0 and LABEL_LIST[orig].startswith("B-"):
# Convertir B- en I-
i_label = LABEL_LIST[orig].replace("B-", "I-")
label_ids.append(LABEL2ID.get(i_label, orig))
else:
label_ids.append(orig)
prev_word_id = word_id
all_labels.append(label_ids)
tokenized["labels"] = all_labels
return tokenized
class WeightedNERTrainer(Trainer):
"""Trainer avec poids de classe pour contrer le déséquilibre O vs entités."""
def __init__(self, class_weights=None, **kwargs):
super().__init__(**kwargs)
if class_weights is not None:
self.class_weights = class_weights.to(self.args.device)
else:
self.class_weights = None
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
labels = inputs.pop("labels")
outputs = model(**inputs)
logits = outputs.logits
if self.class_weights is not None:
loss_fct = nn.CrossEntropyLoss(
weight=self.class_weights,
ignore_index=-100,
label_smoothing=0.1,
)
else:
loss_fct = nn.CrossEntropyLoss(ignore_index=-100, label_smoothing=0.1)
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
return (loss, outputs) if return_outputs else loss
def compute_class_weights(raw_data: Dict, num_labels: int, max_weight: float = 10.0) -> torch.FloatTensor:
"""Calcule les poids inversement proportionnels à la fréquence, cappés après normalisation."""
counts = Counter()
for labels in raw_data["ner_tags"]:
for l in labels:
counts[l] += 1
total = sum(counts.values())
weights = torch.ones(num_labels)
for label_id, count in counts.items():
if count > 0:
weights[label_id] = total / (num_labels * count)
# Normaliser : O=1.0
if weights[0] > 0:
scale = 1.0 / weights[0]
weights *= scale
# Capper APRÈS normalisation pour limiter le déséquilibre
weights = torch.clamp(weights, max=max_weight)
print(f" Class weights (O={weights[0]:.1f}, non-O moyen={weights[1:].mean():.1f}, max={weights[1:].max():.1f})")
return weights
def main():
parser = argparse.ArgumentParser(description="Fine-tune CamemBERT-bio pour désidentification")
parser.add_argument("--data-dir", type=Path,
default=Path(__file__).parent.parent / "data" / "silver_annotations",
help="Répertoire des fichiers .bio")
parser.add_argument("--output-dir", type=Path,
default=Path(__file__).parent.parent / "models" / "camembert-bio-deid",
help="Répertoire de sortie du modèle")
parser.add_argument("--epochs", type=int, default=5)
parser.add_argument("--batch-size", type=int, default=16,
help="Batch size effectif (via gradient accumulation si nécessaire)")
parser.add_argument("--gpu-batch-size", type=int, default=8,
help="Batch size réel sur GPU (le reste via gradient accumulation)")
parser.add_argument("--lr", type=float, default=2e-5)
parser.add_argument("--val-split", type=float, default=0.15, help="Fraction pour validation")
parser.add_argument("--no-augment", action="store_true",
help="Désactiver l'augmentation de données (substitution noms + hard negatives)")
parser.add_argument("--num-augmented", type=int, default=3,
help="Nombre de copies augmentées par exemple PER (défaut: 3)")
parser.add_argument("--num-hard-negatives", type=int, default=600,
help="Nombre d'exemples hard negative médicaments (défaut: 600)")
parser.add_argument("--augment-seed", type=int, default=42,
help="Seed pour la reproductibilité de l'augmentation")
args = parser.parse_args()
# Chemins des gazetteers
project_root = Path(__file__).parent.parent
prenoms_file = project_root / "data" / "insee" / "prenoms_france.txt"
noms_file = project_root / "data" / "insee" / "noms_famille_france.txt"
medications_file = project_root / "data" / "bdpm" / "medication_names.txt"
# Charger les données
print(f"Chargement des données depuis {args.data_dir}...")
raw_data = load_bio_files(args.data_dir)
n_sentences = len(raw_data["tokens"])
n_entities = sum(1 for labels in raw_data["ner_tags"] for l in labels if l != 0)
print(f" {n_sentences} phrases, {n_entities} entités annotées")
if n_sentences < 10:
print("ERREUR: pas assez de données. Lancez d'abord export_silver_annotations.py")
sys.exit(1)
# Split train/val AVANT augmentation (on n'augmente que le train)
dataset = Dataset.from_dict(raw_data)
split = dataset.train_test_split(test_size=args.val_split, seed=42)
# Récupérer les listes Python du split train pour l'augmentation
train_tokens = list(split["train"]["tokens"])
train_labels = list(split["train"]["ner_tags"])
val_tokens = list(split["test"]["tokens"])
val_labels = list(split["test"]["ner_tags"])
n_train_orig = len(train_tokens)
# Augmentation (train uniquement)
if not args.no_augment:
print(f"\nAugmentation des données d'entraînement (seed={args.augment_seed})...")
# 1. Substitution de noms
aug_tok, aug_lab = augment_name_substitution(
train_tokens, train_labels,
prenoms_file=prenoms_file,
noms_file=noms_file,
num_augmented=args.num_augmented,
rng_seed=args.augment_seed,
)
train_tokens.extend(aug_tok)
train_labels.extend(aug_lab)
# 2. Hard negatives médicaments
neg_tok, neg_lab = add_hard_negatives(
train_tokens, train_labels,
medications_file=medications_file,
n_examples=args.num_hard_negatives,
rng_seed=args.augment_seed,
)
train_tokens.extend(neg_tok)
train_labels.extend(neg_lab)
print(f"\n Résumé augmentation :")
print(f" Train original : {n_train_orig} exemples")
print(f" + Substitution noms : {len(aug_tok)} exemples")
print(f" + Hard negatives : {len(neg_tok)} exemples")
print(f" = Train total : {len(train_tokens)} exemples")
print(f" Validation (non augmentée) : {len(val_tokens)} exemples")
else:
print("\n Augmentation désactivée (--no-augment)")
# Reconstruire les datasets
train_data = {"tokens": train_tokens, "ner_tags": train_labels}
val_data = {"tokens": val_tokens, "ner_tags": val_labels}
datasets = DatasetDict({
"train": Dataset.from_dict(train_data),
"validation": Dataset.from_dict(val_data),
})
print(f" Train: {len(datasets['train'])}, Validation: {len(datasets['validation'])}")
# Tokenizer + modèle
print(f"\nChargement du modèle {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForTokenClassification.from_pretrained(
MODEL_NAME,
num_labels=len(LABEL_LIST),
id2label=ID2LABEL,
label2id=LABEL2ID,
)
# Tokenization
tokenized = datasets.map(
lambda ex: tokenize_and_align(ex, tokenizer),
batched=True,
remove_columns=datasets["train"].column_names,
)
# Métriques
seqeval = evaluate.load("seqeval")
def compute_metrics(eval_pred):
logits, labels = eval_pred
predictions = np.argmax(logits, axis=-1)
true_labels = []
true_preds = []
for pred_seq, label_seq in zip(predictions, labels):
t_labels = []
t_preds = []
for p, l in zip(pred_seq, label_seq):
if l != -100:
t_labels.append(LABEL_LIST[l])
t_preds.append(LABEL_LIST[p])
true_labels.append(t_labels)
true_preds.append(t_preds)
results = seqeval.compute(predictions=true_preds, references=true_labels)
return {
"precision": results["overall_precision"],
"recall": results["overall_recall"],
"f1": results["overall_f1"],
}
# Class weights pour contrer le déséquilibre 97% O (calculé sur le train augmenté)
print("\nCalcul des poids de classe...")
weights = compute_class_weights(train_data, len(LABEL_LIST))
# Training
args.output_dir.mkdir(parents=True, exist_ok=True)
# Gradient accumulation : si batch_size demandé > batch effectif GPU
grad_accum = max(1, args.batch_size // args.gpu_batch_size)
effective_batch = args.gpu_batch_size * grad_accum
if effective_batch != args.batch_size:
print(f" GPU batch={args.gpu_batch_size}, grad_accum={grad_accum} → effectif={effective_batch}")
training_args = TrainingArguments(
output_dir=str(args.output_dir),
num_train_epochs=args.epochs,
per_device_train_batch_size=args.gpu_batch_size,
per_device_eval_batch_size=args.gpu_batch_size * 2,
gradient_accumulation_steps=grad_accum,
learning_rate=args.lr,
weight_decay=0.01,
warmup_ratio=0.1,
eval_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="f1",
logging_steps=50,
fp16=True, # GPU training avec mixed precision
report_to="none",
save_total_limit=2,
)
data_collator = DataCollatorForTokenClassification(tokenizer)
trainer = WeightedNERTrainer(
class_weights=weights,
model=model,
args=training_args,
train_dataset=tokenized["train"],
eval_dataset=tokenized["validation"],
data_collator=data_collator,
compute_metrics=compute_metrics,
tokenizer=tokenizer,
)
print(f"\nDémarrage du fine-tuning ({args.epochs} epochs, batch={args.batch_size}, lr={args.lr})...")
trainer.train()
# Sauvegarder
trainer.save_model(str(args.output_dir / "best"))
tokenizer.save_pretrained(str(args.output_dir / "best"))
print(f"\nModèle sauvegardé: {args.output_dir / 'best'}")
# Évaluation finale
results = trainer.evaluate()
print(f"\nRésultats finaux:")
print(f" Precision: {results['eval_precision']:.4f}")
print(f" Recall: {results['eval_recall']:.4f}")
print(f" F1: {results['eval_f1']:.4f}")
# ── Export ONNX automatique ──────────────────────────────────────────────
best_dir = args.output_dir / "best"
onnx_dir = args.output_dir / "onnx"
onnx_export_ok = False
try:
print(f"\nExport ONNX automatique...")
print(f" Source : {best_dir}")
print(f" Destination : {onnx_dir}")
result = subprocess.run(
[
sys.executable, "-m", "optimum.exporters.onnx",
"--model", str(best_dir),
"--task", "token-classification",
str(onnx_dir),
],
capture_output=True,
text=True,
timeout=600,
)
if result.returncode == 0:
onnx_export_ok = True
print(f" Export ONNX réussi → {onnx_dir}")
else:
print(f" [ERREUR] Export ONNX échoué (code {result.returncode})")
if result.stderr:
# Afficher les dernières lignes d'erreur
for line in result.stderr.strip().splitlines()[-10:]:
print(f" {line}")
print(f"\n Pour exporter manuellement :")
print(f" python -m optimum.exporters.onnx --model {best_dir} --task token-classification {onnx_dir}")
except FileNotFoundError:
print(f" [WARN] optimum non installé — export ONNX ignoré")
print(f" Pour exporter manuellement :")
print(f" pip install optimum[exporters]")
print(f" python -m optimum.exporters.onnx --model {best_dir} --task token-classification {onnx_dir}")
except subprocess.TimeoutExpired:
print(f" [ERREUR] Export ONNX timeout (>600s)")
print(f" Pour exporter manuellement :")
print(f" python -m optimum.exporters.onnx --model {best_dir} --task token-classification {onnx_dir}")
except Exception as e:
print(f" [ERREUR] Export ONNX inattendu : {e}")
print(f" Pour exporter manuellement :")
print(f" python -m optimum.exporters.onnx --model {best_dir} --task token-classification {onnx_dir}")
# ── Mise à jour VERSION.json ─────────────────────────────────────────────
version_file = args.output_dir / "VERSION.json"
try:
# Compter les documents d'entraînement (.bio files)
n_bio_files = len(list(args.data_dir.glob("*.bio")))
# Déterminer le numéro de version
if version_file.exists():
version_data = json.loads(version_file.read_text(encoding="utf-8"))
else:
version_data = {
"model": "camembert-bio-deid",
"base_model": MODEL_NAME,
"versions": {},
"directories": {},
}
# Incrémenter la version
existing_versions = [
k for k in version_data.get("versions", {}).keys()
if k.startswith("v") and k[1:].isdigit()
]
if existing_versions:
max_v = max(int(k[1:]) for k in existing_versions)
new_version = f"v{max_v + 1}"
else:
new_version = "v1"
# Trouver le best checkpoint (dernier sauvegardé par Trainer)
best_checkpoint = None
checkpoints = sorted(args.output_dir.glob("checkpoint-*"))
if checkpoints:
best_checkpoint = checkpoints[-1].name
# Construire l'entrée de version
version_entry = {
"date": date.today().isoformat(),
"training_docs": n_bio_files,
"training_examples": len(train_tokens),
"epochs": args.epochs,
"batch_size": args.batch_size,
"learning_rate": args.lr,
"f1": round(results["eval_f1"], 4),
"recall": round(results["eval_recall"], 4),
"precision": round(results["eval_precision"], 4),
"onnx_exported": onnx_export_ok,
}
if best_checkpoint:
version_entry["best_checkpoint"] = best_checkpoint
version_data["current_version"] = new_version
version_data["versions"][new_version] = version_entry
version_data["directories"] = {
"onnx": f"Modèle ONNX actif ({new_version}) — utilisé en inférence CPU",
f"best": f"Modèle PyTorch {new_version} (pour ré-export ONNX si besoin)",
}
version_file.write_text(
json.dumps(version_data, indent=2, ensure_ascii=False) + "\n",
encoding="utf-8",
)
print(f"\n VERSION.json mis à jour → {new_version} (F1={results['eval_f1']:.4f})")
except Exception as e:
print(f"\n [WARN] Impossible de mettre à jour VERSION.json : {e}")
if __name__ == "__main__":
main()