feat(phase2): Fine-tuning CamemBERT-bio v2 (F1=0.90) + enrichissement données
- Fine-tuning camembert-bio-base : F1=0.903, Recall=0.930 (vs 0.89/0.85) - Data augmentation : substitution noms INSEE (219K patronymes, x3 copies) - Hard negatives BDPM (5.7K médicaments) + QUAERO (1319 termes médicaux) - Annotations silver enrichies par gazetteers (+612 VILLE, +5 HOPITAL) - Export silver avec support multi-répertoires (--extra-dir) - Gazetteers QUAERO : CHEM, DISO, PROC, ANAT depuis DrBenchmark/QUAERO - Gazetteers INSEE : noms de famille fréquents (96K) et complets (219K) - Batch silver 1194 PDFs (run_batch_silver_export.py) pour dataset v3 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -7,14 +7,16 @@ exportées par export_silver_annotations.py.
|
||||
|
||||
Usage:
|
||||
python scripts/finetune_camembert_bio.py [--epochs 5] [--batch-size 8] [--lr 2e-5]
|
||||
python scripts/finetune_camembert_bio.py --no-augment # Sans augmentation
|
||||
|
||||
Prérequis: pip install transformers datasets seqeval accelerate
|
||||
Export ONNX post-training: python scripts/export_onnx.py
|
||||
"""
|
||||
import sys
|
||||
import argparse
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Tuple
|
||||
from collections import Counter
|
||||
|
||||
import numpy as np
|
||||
@@ -120,6 +122,290 @@ def load_bio_files(data_dir: Path, window_size: int = 200, stride: int = 100) ->
|
||||
return {"tokens": tokens_list, "ner_tags": labels_list}
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Data Augmentation : substitution de noms
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _load_gazetteer(filepath: Path, max_entries: int = 10000) -> List[str]:
|
||||
"""Charge un fichier de noms (un par ligne), en ignorant les lignes vides."""
|
||||
if not filepath.exists():
|
||||
print(f" [WARN] Fichier gazetteer absent : {filepath}")
|
||||
return []
|
||||
names = []
|
||||
with open(filepath, encoding="utf-8", errors="replace") as f:
|
||||
for line in f:
|
||||
name = line.strip()
|
||||
if name and len(name) >= 2:
|
||||
names.append(name)
|
||||
if len(names) >= max_entries:
|
||||
break
|
||||
return names
|
||||
|
||||
|
||||
def _match_case(replacement: str, original: str) -> str:
|
||||
"""Applique la casse de l'original au remplacement."""
|
||||
if original.isupper():
|
||||
return replacement.upper()
|
||||
if original.istitle():
|
||||
return replacement.title()
|
||||
if original.islower():
|
||||
return replacement.lower()
|
||||
return replacement
|
||||
|
||||
|
||||
def _extract_per_entities(tokens: List[str], labels: List[int]) -> List[Tuple[int, int]]:
    """Collect the (start, end-exclusive) index spans of PER entities in a BIO sequence.

    Only ``labels`` is inspected; ``tokens`` is accepted for interface symmetry.
    Returns a list of (start_idx, end_idx) tuples.
    """
    begin_id = LABEL2ID.get("B-PER", -1)
    inside_id = LABEL2ID.get("I-PER", -1)
    spans: List[Tuple[int, int]] = []
    cursor = 0
    total = len(labels)
    while cursor < total:
        if labels[cursor] != begin_id:
            cursor += 1
            continue
        # Found a B-PER: consume the run of following I-PER labels.
        span_start = cursor
        cursor += 1
        while cursor < total and labels[cursor] == inside_id:
            cursor += 1
        spans.append((span_start, cursor))
    return spans
|
||||
|
||||
|
||||
def augment_name_substitution(
    tokens_list: List[List[str]],
    labels_list: List[List[int]],
    prenoms_file: Path,
    noms_file: Path,
    num_augmented: int = 3,
    rng_seed: int = 42,
) -> Tuple[List[List[str]], List[List[int]]]:
    """Augment training data by substituting PER entities with random INSEE names.

    For each example containing at least one PER entity, creates ``num_augmented``
    copies in which every PER span is replaced by names drawn from the first-name
    and surname gazetteer files. The original casing (UPPER, Title, lower) of each
    replaced token is preserved via ``_match_case``.

    Args:
        tokens_list: token sequences, one list of tokens per example.
        labels_list: BIO label-id sequences aligned with ``tokens_list``.
        prenoms_file: first-names gazetteer (one name per line).
        noms_file: surnames gazetteer (one name per line).
        num_augmented: number of augmented copies per eligible example.
        rng_seed: seed for the local RNG, for reproducible augmentation.

    Returns:
        (aug_tokens_list, aug_labels_list) — ONLY the augmented copies, WITHOUT
        the originals (the caller concatenates them). Both lists are empty when
        no first names could be loaded.
    """
    rng = random.Random(rng_seed)
    prenoms = _load_gazetteer(prenoms_file)
    noms = _load_gazetteer(noms_file)

    if not prenoms:
        print(" [WARN] Aucun prénom chargé — augmentation désactivée")
        return [], []
    if not noms:
        print(" [WARN] Aucun nom de famille chargé — utilisation des prénoms uniquement")
        noms = prenoms  # graceful fallback: reuse first names as surnames

    print(f" Gazetteers : {len(prenoms)} prénoms, {len(noms)} noms de famille")

    # NOTE: the previous version computed an unused b_per_id = LABEL2ID["B-PER"]
    # here, which could also raise KeyError; span detection is fully delegated
    # to _extract_per_entities, so the lookup was removed.
    aug_tokens: List[List[str]] = []
    aug_labels: List[List[int]] = []
    n_augmented_examples = 0

    for tokens, labels in zip(tokens_list, labels_list):
        entities = _extract_per_entities(tokens, labels)
        if not entities:
            continue

        for _ in range(num_augmented):
            new_tokens = list(tokens)
            new_labels = list(labels)

            for start, end in entities:
                span_tokens = tokens[start:end]
                span_len = end - start

                # Build a replacement of the same length in tokens so the
                # label sequence stays aligned unchanged.
                replacements = []
                for j, orig_tok in enumerate(span_tokens):
                    if "-" in orig_tok:
                        # Hyphenated name: JEAN-PIERRE -> MARIE-CLAIRE
                        parts = orig_tok.split("-")
                        new_parts = [
                            _match_case(rng.choice(prenoms), p) for p in parts
                        ]
                        replacements.append("-".join(new_parts))
                    elif j == 0 and span_len >= 2:
                        # First token of a multi-token entity -> first name
                        replacements.append(_match_case(rng.choice(prenoms), orig_tok))
                    elif span_len == 1:
                        # Single-token entity -> 50/50 first name or surname
                        pool = prenoms if rng.random() < 0.5 else noms
                        replacements.append(_match_case(rng.choice(pool), orig_tok))
                    else:
                        # Subsequent tokens -> surname
                        replacements.append(_match_case(rng.choice(noms), orig_tok))

                for j, idx in enumerate(range(start, end)):
                    new_tokens[idx] = replacements[j]

            aug_tokens.append(new_tokens)
            aug_labels.append(new_labels)
            n_augmented_examples += 1

    print(f" Augmentation noms : {n_augmented_examples} exemples générés "
          f"(x{num_augmented} par exemple avec entités PER)")
    return aug_tokens, aug_labels
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Hard Negative Mining : médicaments
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
# Fallback list used when no medication file is found on disk.
# Common French drug names (brand names and INNs), uppercase as they
# typically appear in clinical reports; used only by _load_medications.
_FALLBACK_MEDICATIONS = [
    "DOLIPRANE", "EFFERALGAN", "DAFALGAN", "IBUPROFENE", "ADVIL", "NUROFEN",
    "SPASFON", "SMECTA", "GAVISCON", "MOPRAL", "OMEPRAZOLE", "PANTOPRAZOLE",
    "LANSOPRAZOLE", "INEXIUM", "AUGMENTIN", "AMOXICILLINE", "CLAMOXYL",
    "METFORMINE", "GLUCOPHAGE", "DIAMICRON", "LANTUS", "NOVORAPID", "HUMALOG",
    "LEVOTHYROX", "EUTHYROX", "KARDEGIC", "ASPEGIC", "PLAVIX", "CLOPIDOGREL",
    "ELIQUIS", "XARELTO", "PRADAXA", "COUMADINE", "PREVISCAN", "LOVENOX",
    "RAMIPRIL", "TRIATEC", "PERINDOPRIL", "COVERSYL", "AMLODIPINE", "AMLOR",
    "BISOPROLOL", "ATENOLOL", "METOPROLOL", "FUROSEMIDE", "LASILIX",
    "TAHOR", "CRESTOR", "PRAVASTATINE", "SIMVASTATINE",
]
|
||||
|
||||
|
||||
def _load_medications(medications_file: Path) -> List[str]:
|
||||
"""Charge une liste de noms de médicaments depuis le fichier BDPM ou le fallback."""
|
||||
if medications_file.exists():
|
||||
meds = []
|
||||
with open(medications_file, encoding="utf-8", errors="replace") as f:
|
||||
for line in f:
|
||||
name = line.strip()
|
||||
if name and len(name) >= 3:
|
||||
meds.append(name)
|
||||
if meds:
|
||||
print(f" Médicaments chargés : {len(meds)} depuis {medications_file.name}")
|
||||
return meds
|
||||
|
||||
print(f" [WARN] Fichier médicaments absent ou vide — utilisation de {len(_FALLBACK_MEDICATIONS)} médicaments courants")
|
||||
return list(_FALLBACK_MEDICATIONS)
|
||||
|
||||
|
||||
# Medical sentence templates in which medication names are NOT person names;
# used by add_hard_negatives to generate all-O (non-entity) training examples.
_MEDICATION_TEMPLATES = [
    "Traitement par {med} {dose}",
    "Prescription de {med} {forme}",
    "Relais par {med}",
    "Introduction de {med} {dose}",
    "Poursuite du traitement par {med}",
    "Arrêt du {med} en raison des effets secondaires",
    "Allergie connue au {med}",
    "Switch de {med} vers {med2}",
    "Patient sous {med} depuis 3 mois",
    "Posologie de {med} augmentée à {dose}",
    "Administration de {med} en IV",
    "Relais per os par {med} {dose}",
]

# Dosage strings substituted into the {dose} template slot.
_DOSE_VARIANTS = [
    "500mg", "1000mg", "200mg", "100mg", "50mg", "250mg", "75mg",
    "20mg", "40mg", "10mg", "5mg", "1g", "2g", "80mg", "160mg",
]

# Pharmaceutical forms substituted into the {forme} template slot.
_FORME_VARIANTS = [
    "comprimé", "gélule", "sachet", "injectable", "sirop",
    "pommade", "collyre", "suppositoire", "patch", "cp",
]
|
||||
|
||||
|
||||
def _load_quaero_entities(quaero_dir: Path) -> List[str]:
|
||||
"""Charge les entités QUAERO (CHEM, DISO, PROC, ANAT) comme termes médicaux."""
|
||||
entities = []
|
||||
for filename in ["quaero_chem_entities.txt", "quaero_diso_entities.txt",
|
||||
"quaero_procedures.txt", "quaero_anatomie.txt"]:
|
||||
fpath = quaero_dir / filename
|
||||
if fpath.exists():
|
||||
for line in fpath.read_text(encoding="utf-8").splitlines():
|
||||
name = line.strip()
|
||||
if name and len(name) >= 3:
|
||||
entities.append(name)
|
||||
return entities
|
||||
|
||||
|
||||
# Additional templates for QUAERO medical terms (diseases, procedures, anatomy);
# used by add_hard_negatives to build all-O examples around the {term} slot.
_MEDICAL_CONTEXT_TEMPLATES = [
    "Diagnostic de {term} confirmé par imagerie",
    "Antécédent de {term}",
    "Suspicion de {term} évoquée",
    "Bilan pour {term}",
    "Prise en charge de {term} par le service",
    "Examen de {term} sans anomalie",
    "Surveillance de {term} en cours",
    "Indication opératoire pour {term}",
]
|
||||
|
||||
|
||||
def add_hard_negatives(
    tokens_list: List[List[str]],
    labels_list: List[List[int]],
    medications_file: Path,
    n_examples: int = 600,
    rng_seed: int = 42,
) -> Tuple[List[List[str]], List[List[int]]]:
    """Generate synthetic medical-context examples labelled entirely O.

    Fills sentence templates with medication names (BDPM) and QUAERO medical
    terms (CHEM, DISO, PROC, ANAT) to teach the model that these terms are NOT
    person names. Roughly 2/3 of the budget goes to medication templates and
    1/3 to QUAERO templates (when QUAERO files are available).

    Args:
        tokens_list: existing training tokens — currently unused; kept for
            interface symmetry with augment_name_substitution.
        labels_list: existing training labels — currently unused; see above.
        medications_file: BDPM medication-name file; a sibling "quaero"
            directory is looked up two levels above it.
        n_examples: total number of negative examples to generate.
        rng_seed: seed for the local RNG, for reproducible generation.

    Returns:
        (neg_tokens_list, neg_labels_list) — the synthetic negatives to append.
    """
    rng = random.Random(rng_seed)
    medications = _load_medications(medications_file)

    # QUAERO terms live next to the BDPM data: <data_root>/quaero/
    quaero_dir = medications_file.parent.parent / "quaero"
    quaero_terms = _load_quaero_entities(quaero_dir)
    if quaero_terms:
        print(f" Termes médicaux QUAERO : {len(quaero_terms)} (CHEM+DISO+PROC+ANAT)")

    if not medications:
        return [], []

    # Hoisted out of the loops: every token in these sentences is a non-entity.
    o_id = LABEL2ID["O"]
    neg_tokens: List[List[str]] = []
    neg_labels: List[List[int]] = []

    # 1. Medication hard negatives (BDPM) — 2/3 of the budget
    n_med = n_examples * 2 // 3
    for _ in range(n_med):
        template = rng.choice(_MEDICATION_TEMPLATES)
        med = rng.choice(medications)
        med2 = rng.choice(medications)
        dose = rng.choice(_DOSE_VARIANTS)
        forme = rng.choice(_FORME_VARIANTS)

        # Extra keyword args are ignored by templates that don't use them.
        sentence = template.format(med=med, med2=med2, dose=dose, forme=forme)
        toks = sentence.split()
        neg_tokens.append(toks)
        neg_labels.append([o_id] * len(toks))

    # 2. QUAERO hard negatives (diseases, procedures, anatomy) — remaining 1/3
    n_quaero = n_examples - n_med
    if quaero_terms:
        for _ in range(n_quaero):
            template = rng.choice(_MEDICAL_CONTEXT_TEMPLATES)
            term = rng.choice(quaero_terms)
            sentence = template.format(term=term)
            toks = sentence.split()
            neg_tokens.append(toks)
            neg_labels.append([o_id] * len(toks))

    print(f" Hard negatives : {n_med} médicaments + {n_quaero if quaero_terms else 0} QUAERO = {len(neg_tokens)} exemples")
    return neg_tokens, neg_labels
|
||||
|
||||
|
||||
def tokenize_and_align(examples, tokenizer):
|
||||
"""Tokenize et aligne les labels avec les sous-tokens."""
|
||||
tokenized = tokenizer(
|
||||
@@ -218,11 +504,28 @@ def main():
|
||||
default=Path(__file__).parent.parent / "models" / "camembert-bio-deid",
|
||||
help="Répertoire de sortie du modèle")
|
||||
parser.add_argument("--epochs", type=int, default=5)
|
||||
parser.add_argument("--batch-size", type=int, default=8)
|
||||
parser.add_argument("--batch-size", type=int, default=16,
|
||||
help="Batch size effectif (via gradient accumulation si nécessaire)")
|
||||
parser.add_argument("--gpu-batch-size", type=int, default=8,
|
||||
help="Batch size réel sur GPU (le reste via gradient accumulation)")
|
||||
parser.add_argument("--lr", type=float, default=2e-5)
|
||||
parser.add_argument("--val-split", type=float, default=0.15, help="Fraction pour validation")
|
||||
parser.add_argument("--no-augment", action="store_true",
|
||||
help="Désactiver l'augmentation de données (substitution noms + hard negatives)")
|
||||
parser.add_argument("--num-augmented", type=int, default=3,
|
||||
help="Nombre de copies augmentées par exemple PER (défaut: 3)")
|
||||
parser.add_argument("--num-hard-negatives", type=int, default=600,
|
||||
help="Nombre d'exemples hard negative médicaments (défaut: 600)")
|
||||
parser.add_argument("--augment-seed", type=int, default=42,
|
||||
help="Seed pour la reproductibilité de l'augmentation")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Chemins des gazetteers
|
||||
project_root = Path(__file__).parent.parent
|
||||
prenoms_file = project_root / "data" / "insee" / "prenoms_france.txt"
|
||||
noms_file = project_root / "data" / "insee" / "noms_famille_france.txt"
|
||||
medications_file = project_root / "data" / "bdpm" / "medication_names.txt"
|
||||
|
||||
# Charger les données
|
||||
print(f"Chargement des données depuis {args.data_dir}...")
|
||||
raw_data = load_bio_files(args.data_dir)
|
||||
@@ -234,10 +537,59 @@ def main():
|
||||
print("ERREUR: pas assez de données. Lancez d'abord export_silver_annotations.py")
|
||||
sys.exit(1)
|
||||
|
||||
# Split train/val
|
||||
# Split train/val AVANT augmentation (on n'augmente que le train)
|
||||
dataset = Dataset.from_dict(raw_data)
|
||||
split = dataset.train_test_split(test_size=args.val_split, seed=42)
|
||||
datasets = DatasetDict({"train": split["train"], "validation": split["test"]})
|
||||
|
||||
# Récupérer les listes Python du split train pour l'augmentation
|
||||
train_tokens = list(split["train"]["tokens"])
|
||||
train_labels = list(split["train"]["ner_tags"])
|
||||
val_tokens = list(split["test"]["tokens"])
|
||||
val_labels = list(split["test"]["ner_tags"])
|
||||
|
||||
n_train_orig = len(train_tokens)
|
||||
|
||||
# Augmentation (train uniquement)
|
||||
if not args.no_augment:
|
||||
print(f"\nAugmentation des données d'entraînement (seed={args.augment_seed})...")
|
||||
|
||||
# 1. Substitution de noms
|
||||
aug_tok, aug_lab = augment_name_substitution(
|
||||
train_tokens, train_labels,
|
||||
prenoms_file=prenoms_file,
|
||||
noms_file=noms_file,
|
||||
num_augmented=args.num_augmented,
|
||||
rng_seed=args.augment_seed,
|
||||
)
|
||||
train_tokens.extend(aug_tok)
|
||||
train_labels.extend(aug_lab)
|
||||
|
||||
# 2. Hard negatives médicaments
|
||||
neg_tok, neg_lab = add_hard_negatives(
|
||||
train_tokens, train_labels,
|
||||
medications_file=medications_file,
|
||||
n_examples=args.num_hard_negatives,
|
||||
rng_seed=args.augment_seed,
|
||||
)
|
||||
train_tokens.extend(neg_tok)
|
||||
train_labels.extend(neg_lab)
|
||||
|
||||
print(f"\n Résumé augmentation :")
|
||||
print(f" Train original : {n_train_orig} exemples")
|
||||
print(f" + Substitution noms : {len(aug_tok)} exemples")
|
||||
print(f" + Hard negatives : {len(neg_tok)} exemples")
|
||||
print(f" = Train total : {len(train_tokens)} exemples")
|
||||
print(f" Validation (non augmentée) : {len(val_tokens)} exemples")
|
||||
else:
|
||||
print("\n Augmentation désactivée (--no-augment)")
|
||||
|
||||
# Reconstruire les datasets
|
||||
train_data = {"tokens": train_tokens, "ner_tags": train_labels}
|
||||
val_data = {"tokens": val_tokens, "ner_tags": val_labels}
|
||||
datasets = DatasetDict({
|
||||
"train": Dataset.from_dict(train_data),
|
||||
"validation": Dataset.from_dict(val_data),
|
||||
})
|
||||
print(f" Train: {len(datasets['train'])}, Validation: {len(datasets['validation'])}")
|
||||
|
||||
# Tokenizer + modèle
|
||||
@@ -281,17 +633,24 @@ def main():
|
||||
"f1": results["overall_f1"],
|
||||
}
|
||||
|
||||
# Class weights pour contrer le déséquilibre 97% O
|
||||
# Class weights pour contrer le déséquilibre 97% O (calculé sur le train augmenté)
|
||||
print("\nCalcul des poids de classe...")
|
||||
weights = compute_class_weights(raw_data, len(LABEL_LIST))
|
||||
weights = compute_class_weights(train_data, len(LABEL_LIST))
|
||||
|
||||
# Training
|
||||
args.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
# Gradient accumulation : si batch_size demandé > batch effectif GPU
|
||||
grad_accum = max(1, args.batch_size // args.gpu_batch_size)
|
||||
effective_batch = args.gpu_batch_size * grad_accum
|
||||
if effective_batch != args.batch_size:
|
||||
print(f" GPU batch={args.gpu_batch_size}, grad_accum={grad_accum} → effectif={effective_batch}")
|
||||
|
||||
training_args = TrainingArguments(
|
||||
output_dir=str(args.output_dir),
|
||||
num_train_epochs=args.epochs,
|
||||
per_device_train_batch_size=args.batch_size,
|
||||
per_device_eval_batch_size=args.batch_size * 2,
|
||||
per_device_train_batch_size=args.gpu_batch_size,
|
||||
per_device_eval_batch_size=args.gpu_batch_size * 2,
|
||||
gradient_accumulation_steps=grad_accum,
|
||||
learning_rate=args.lr,
|
||||
weight_decay=0.01,
|
||||
warmup_ratio=0.1,
|
||||
|
||||
Reference in New Issue
Block a user