feat(phase2): Fine-tuning CamemBERT-bio v2 (F1=0.90) + enrichissement données

- Fine-tuning camembert-bio-base : F1=0.903, Recall=0.930 (vs 0.89/0.85)
- Data augmentation : substitution noms INSEE (219K patronymes, x3 copies)
- Hard negatives BDPM (5.7K médicaments) + QUAERO (1319 termes médicaux)
- Annotations silver enrichies par gazetteers (+612 VILLE, +5 HOPITAL)
- Export silver avec support multi-répertoires (--extra-dir)
- Gazetteers QUAERO : CHEM, DISO, PROC, ANAT depuis DrBenchmark/QUAERO
- Gazetteers INSEE : noms de famille fréquents (96K) et complets (219K)
- Batch silver 1194 PDFs (run_batch_silver_export.py) pour dataset v3

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-10 02:06:08 +01:00
parent 274e2fa586
commit c9572c383a
38 changed files with 318811 additions and 1406 deletions

View File

@@ -0,0 +1,118 @@
#!/usr/bin/env python3
"""
build_noms_famille_gazetteer.py
Construit deux gazetteers de noms de famille français à partir du fichier INSEE.
Source : /home/dom/ai/anonymisation/data/insee/noms2008nat_txt.txt
(TSV : NOM, puis effectifs par décennie 1891-2000)
Sorties dans /home/dom/ai/anonymisation/data/insee/ :
- noms_famille_france.txt : TOUS les noms (219K), un par ligne, majuscules
- noms_famille_frequents.txt : noms avec total >= 100 ET longueur >= 3 caractères
"""
import os
from collections import Counter
INPUT_FILE = "/home/dom/ai/anonymisation/data/insee/noms2008nat_txt.txt"
OUTPUT_DIR = "/home/dom/ai/anonymisation/data/insee"
OUTPUT_ALL = os.path.join(OUTPUT_DIR, "noms_famille_france.txt")
OUTPUT_FREQ = os.path.join(OUTPUT_DIR, "noms_famille_frequents.txt")
MIN_TOTAL = 100 # seuil de fréquence pour le fichier "fréquents"
MIN_LENGTH = 3 # longueur minimale pour le fichier "fréquents"
SKIP_NAMES = {"AUTRES NOMS"}
def main():
all_names = []
freq_names = []
frequencies = [] # pour les stats de distribution
with open(INPUT_FILE, "r", encoding="utf-8") as f:
header = f.readline() # skip header
print(f"En-tête : {header.strip()}")
for line in f:
line = line.strip()
if not line:
continue
parts = line.split("\t")
nom = parts[0].strip()
if nom in SKIP_NAMES:
continue
# Calculer la fréquence totale sur toutes les décennies
total = 0
for val in parts[1:]:
val = val.strip()
try:
total += int(val)
except ValueError:
pass
# Tous les noms (uppercase, déjà le cas dans le fichier)
nom_upper = nom.upper()
all_names.append(nom_upper)
frequencies.append(total)
# Filtre pour les noms fréquents
if total >= MIN_TOTAL and len(nom_upper) >= MIN_LENGTH:
freq_names.append(nom_upper)
# Écriture des fichiers
with open(OUTPUT_ALL, "w", encoding="utf-8") as f:
f.write("\n".join(all_names) + "\n")
with open(OUTPUT_FREQ, "w", encoding="utf-8") as f:
f.write("\n".join(freq_names) + "\n")
# --- Stats ---
print(f"\n{'='*60}")
print(f"STATISTIQUES")
print(f"{'='*60}")
print(f"Fichier source : {INPUT_FILE}")
print(f"Noms totaux lus : {len(all_names):>10,}")
print(f"{OUTPUT_ALL}")
print(f"Noms fréquents : {len(freq_names):>10,} (total >= {MIN_TOTAL}, len >= {MIN_LENGTH})")
print(f"{OUTPUT_FREQ}")
print(f"Noms exclus (filtre): {len(all_names) - len(freq_names):>10,}")
# Distribution de fréquence
print(f"\n--- Distribution des fréquences ---")
buckets = [
(0, 0, "0 (hapax)"),
(1, 9, "1-9"),
(10, 49, "10-49"),
(50, 99, "50-99"),
(100, 499, "100-499"),
(500, 999, "500-999"),
(1000, 4999, "1 000-4 999"),
(5000, 9999, "5 000-9 999"),
(10000, 49999, "10 000-49 999"),
(50000, 99999, "50 000-99 999"),
(100000, float("inf"), "100 000+"),
]
for lo, hi, label in buckets:
count = sum(1 for freq in frequencies if lo <= freq <= hi)
if count > 0:
print(f" {label:>20s} : {count:>8,} noms")
# Top 20
print(f"\n--- Top 20 noms les plus fréquents ---")
indexed = list(zip(all_names, frequencies))
indexed.sort(key=lambda x: x[1], reverse=True)
for i, (nom, freq) in enumerate(indexed[:20], 1):
print(f" {i:>2}. {nom:<25s} {freq:>10,}")
# Quelques stats globales
total_naissances = sum(frequencies)
print(f"\nTotal naissances couvertes : {total_naissances:>12,}")
print(f"Fréquence médiane : {sorted(frequencies)[len(frequencies)//2]:>12,}")
print(f"Fréquence moyenne : {total_naissances / len(frequencies):>12,.1f}")
if __name__ == "__main__":
main()

View File

@@ -16,8 +16,9 @@ import sys
import re
import difflib
import argparse
import unicodedata
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List, Set, Tuple
sys.path.insert(0, str(Path(__file__).parent.parent))
@@ -47,6 +48,178 @@ RE_PLACEHOLDER = re.compile(r"^\[([A-Z_]+)\]$")
SRC = Path("/home/dom/Téléchargements/II-1 Ctrl_T2A_2025_CHCB_DocJustificatifs (1)")
AUDIT_DIR = SRC / "anonymise_audit_30"
# --- Gazetteer paths ---
GAZETTEERS_DIR = Path(__file__).parent.parent / "data"
VILLES_FINESS_PATH = GAZETTEERS_DIR / "finess" / "villes_finess.txt"
COMMUNES_INSEE_PATH = GAZETTEERS_DIR / "insee" / "communes_france.txt"
ETABLISSEMENTS_PATH = GAZETTEERS_DIR / "finess" / "etablissements_distinctifs.txt"
# Mots de contexte indiquant qu'un token VILLE est bien un lieu (pas un nom commun)
VILLE_CONTEXT_WORDS = {
"à", "a", "de", "", "née", "nee", "ne", "résid", "resid",
"hospitalis", "transféré", "transfere", "transferée", "transferee",
"domicilié", "domicilie", "domiciliée", "domiciliee",
"habite", "habitant", "demeurant", "originaire", "ville",
"commune", "cedex",
}
def _strip_accents(s: str) -> str:
"""Supprime les accents d'une chaîne (é→e, à→a, etc.)."""
nfkd = unicodedata.normalize("NFKD", s)
return "".join(c for c in nfkd if not unicodedata.combining(c))
def _normalize_gaz(s: str) -> str:
"""Normalise pour comparaison gazetteer : minuscule, sans accents, stripped."""
return _strip_accents(s.lower().strip())
def load_gazetteers() -> dict:
"""Charge les gazetteers depuis les fichiers, avec fallback gracieux.
Retourne un dict avec:
- "villes": set de tuples de tokens normalisés (ex: ("saint", "palais"))
- "hopitaux": set de tuples de tokens normalisés (ex: ("ch", "argentan"))
"""
villes: Set[Tuple[str, ...]] = set()
hopitaux: Set[Tuple[str, ...]] = set()
# --- Villes FINESS (UPPERCASE, une par ligne) ---
if VILLES_FINESS_PATH.exists():
for line in VILLES_FINESS_PATH.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line:
continue
# Filtrer les entrées avec "CEDEX" (adresses postales, pas des villes)
tokens = tuple(_normalize_gaz(t) for t in line.split() if t != "CEDEX")
if tokens and len(tokens[0]) >= 2: # Ignorer entrées trop courtes
villes.add(tokens)
# --- Communes INSEE (UPPERCASE, une par ligne) ---
if COMMUNES_INSEE_PATH.exists():
for line in COMMUNES_INSEE_PATH.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line:
continue
tokens = tuple(_normalize_gaz(t) for t in line.split())
if tokens and len(tokens[0]) >= 2:
villes.add(tokens)
# --- Établissements (format "- nom normalisé", minuscule) ---
if ETABLISSEMENTS_PATH.exists():
for line in ETABLISSEMENTS_PATH.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line.startswith("- "):
continue
name = line[2:].strip()
if not name:
continue
tokens = tuple(_normalize_gaz(t) for t in name.split())
# Ignorer les entrées trop courtes (1 token de 3 chars ou moins)
if tokens and (len(tokens) > 1 or len(tokens[0]) > 3):
hopitaux.add(tokens)
# Retirer les villes d'un seul token très court (risque élevé de faux positifs)
villes = {v for v in villes if len(v) > 1 or len(v[0]) >= 3}
return {"villes": villes, "hopitaux": hopitaux}
def _has_ville_context(tokens: List[str], labels: List[str], pos: int,
window: int = 3) -> bool:
"""Vérifie si un token à la position `pos` a un contexte indiquant un lieu.
Regarde les `window` tokens précédents pour des mots-clés de contexte.
"""
start = max(0, pos - window)
for i in range(start, pos):
tok_norm = _normalize_gaz(tokens[i].strip(".,;:!?()[]{}\"'"))
# Vérifier correspondance exacte ou préfixe (ex: "résid" matche "résidence")
for ctx in VILLE_CONTEXT_WORDS:
if tok_norm == ctx or tok_norm.startswith(ctx):
return True
return False
def enrich_with_gazetteers(
bio_tokens: List[Tuple[str, str]],
gazetteers: dict,
) -> Tuple[List[Tuple[str, str]], int, int]:
"""Enrichit les annotations BIO avec les gazetteers.
Ne modifie JAMAIS un label existant (non-O). Ajoute uniquement des labels
sur les tokens actuellement "O".
Retourne: (bio_tokens_enrichis, nb_villes_ajoutées, nb_hopitaux_ajoutés)
"""
tokens = [t for t, _ in bio_tokens]
labels = [l for _, l in bio_tokens]
n = len(tokens)
added_villes = 0
added_hopitaux = 0
# Pré-calculer les tokens normalisés (sans ponctuation, sans accents, lowercase)
tokens_norm = [
_normalize_gaz(t.strip(".,;:!?()[]{}\"'"))
for t in tokens
]
# --- Enrichissement HOPITAL (multi-token, sans contrainte de contexte) ---
# On traite d'abord les hôpitaux car ils sont plus spécifiques
for gaz_tokens in gazetteers.get("hopitaux", set()):
gaz_len = len(gaz_tokens)
if gaz_len == 0:
continue
i = 0
while i <= n - gaz_len:
# Vérifier si la séquence de tokens matche
match = True
for k in range(gaz_len):
if tokens_norm[i + k] != gaz_tokens[k]:
match = False
break
if match:
# Vérifier que TOUS les tokens sont actuellement "O"
all_o = all(labels[i + k] == "O" for k in range(gaz_len))
if all_o:
labels[i] = "B-HOPITAL"
for k in range(1, gaz_len):
labels[i + k] = "I-HOPITAL"
added_hopitaux += 1
i += gaz_len
continue
i += 1
# --- Enrichissement VILLE (avec contexte obligatoire) ---
for gaz_tokens in gazetteers.get("villes", set()):
gaz_len = len(gaz_tokens)
if gaz_len == 0:
continue
i = 0
while i <= n - gaz_len:
# Vérifier si la séquence de tokens matche
match = True
for k in range(gaz_len):
if tokens_norm[i + k] != gaz_tokens[k]:
match = False
break
if match:
# Vérifier que TOUS les tokens sont actuellement "O"
all_o = all(labels[i + k] == "O" for k in range(gaz_len))
if all_o and _has_ville_context(tokens, labels, i):
labels[i] = "B-VILLE"
for k in range(1, gaz_len):
labels[i + k] = "I-VILLE"
added_villes += 1
i += gaz_len
continue
i += 1
enriched = list(zip(tokens, labels))
return enriched, added_villes, added_hopitaux
def extract_original_text(pdf_path: Path) -> str:
"""Extrait le texte brut d'un PDF (même méthode que le pipeline)."""
@@ -173,20 +346,37 @@ def align_and_annotate(original_text: str, pseudo_text: str) -> List[Tuple[str,
return bio_tokens
def export_document(pdf_path: Path, pseudo_path: Path, out_dir: Path) -> Tuple[int, int]:
"""Exporte un document en format BIO. Retourne (nb_tokens, nb_entités)."""
def export_document(
pdf_path: Path,
pseudo_path: Path,
out_dir: Path,
gazetteers: dict | None = None,
) -> Tuple[int, int, int, int]:
"""Exporte un document en format BIO.
Retourne (nb_tokens, nb_entités_diff, nb_villes_gaz, nb_hopitaux_gaz).
"""
# Extraire le texte original
original_text = extract_original_text(pdf_path)
if not original_text.strip():
return 0, 0
return 0, 0, 0, 0
# Lire le texte pseudonymisé
pseudo_text = pseudo_path.read_text(encoding="utf-8")
if not pseudo_text.strip():
return 0, 0
return 0, 0, 0, 0
# Aligner et annoter
# Aligner et annoter (diff-based)
bio_tokens = align_and_annotate(original_text, pseudo_text)
n_ents_diff = sum(1 for _, l in bio_tokens if l.startswith("B-"))
# Enrichissement gazetteer (post-processing)
added_villes = 0
added_hopitaux = 0
if gazetteers:
bio_tokens, added_villes, added_hopitaux = enrich_with_gazetteers(
bio_tokens, gazetteers
)
# Écrire en format CoNLL
out_name = pdf_path.stem + ".bio"
@@ -202,8 +392,7 @@ def export_document(pdf_path: Path, pseudo_path: Path, out_dir: Path) -> Tuple[i
out_path.write_text("\n".join(lines), encoding="utf-8")
n_ents = sum(1 for _, l in bio_tokens if l.startswith("B-"))
return len(bio_tokens), n_ents
return len(bio_tokens), n_ents_diff, added_villes, added_hopitaux
def main():
@@ -212,12 +401,42 @@ def main():
default=Path(__file__).parent.parent / "data" / "silver_annotations",
help="Répertoire de sortie")
parser.add_argument("--limit", type=int, default=0, help="Limiter à N fichiers (0=tous)")
parser.add_argument("--no-gazetteers", action="store_true",
help="Désactiver l'enrichissement par gazetteers")
parser.add_argument("--extra-dir", type=Path, nargs="*", default=[],
help="Répertoires supplémentaires contenant des .pseudonymise.txt")
args = parser.parse_args()
args.out_dir.mkdir(parents=True, exist_ok=True)
# Trouver les paires PDF + pseudo
pseudo_files = sorted(AUDIT_DIR.glob("*.pseudonymise.txt"))
# Charger les gazetteers
gazetteers = None
if not args.no_gazetteers:
gazetteers = load_gazetteers()
n_villes = len(gazetteers["villes"])
n_hop = len(gazetteers["hopitaux"])
print(f"Gazetteers chargés: {n_villes} villes, {n_hop} établissements")
else:
print("Gazetteers désactivés")
# Trouver les paires PDF + pseudo (audit_30 + extra dirs)
search_dirs = [AUDIT_DIR] + list(args.extra_dir)
pseudo_files = []
for sdir in search_dirs:
if sdir.exists():
pseudo_files.extend(sorted(sdir.glob("*.pseudonymise.txt")))
print(f" {sdir.name}: {len(list(sdir.glob('*.pseudonymise.txt')))} fichiers pseudo")
# Dédupliquer par nom de base
seen_bases = set()
unique_pseudo = []
for pf in pseudo_files:
base = pf.name.replace(".pseudonymise.txt", "")
if base not in seen_bases:
seen_bases.add(base)
unique_pseudo.append(pf)
pseudo_files = unique_pseudo
pairs = []
for pseudo_path in pseudo_files:
# Retrouver le PDF source
@@ -233,17 +452,35 @@ def main():
print(f"Export silver annotations: {len(pairs)} documents → {args.out_dir}")
total_tokens = 0
total_entities = 0
total_ents_diff = 0
total_villes_gaz = 0
total_hop_gaz = 0
for pdf_path, pseudo_path in pairs:
try:
n_tok, n_ent = export_document(pdf_path, pseudo_path, args.out_dir)
n_tok, n_diff, n_vgaz, n_hgaz = export_document(
pdf_path, pseudo_path, args.out_dir, gazetteers
)
total_tokens += n_tok
total_entities += n_ent
print(f" {pdf_path.name}: {n_tok} tokens, {n_ent} entités")
total_ents_diff += n_diff
total_villes_gaz += n_vgaz
total_hop_gaz += n_hgaz
gaz_info = ""
if n_vgaz or n_hgaz:
gaz_info = f" (+{n_vgaz} villes, +{n_hgaz} hôpitaux gaz.)"
print(f" {pdf_path.name}: {n_tok} tokens, {n_diff} entités diff{gaz_info}")
except Exception as e:
print(f" {pdf_path.name}: ERREUR {e}")
print(f"\nTotal: {total_tokens} tokens, {total_entities} entités B-")
total_gaz = total_villes_gaz + total_hop_gaz
total_all = total_ents_diff + total_gaz
print(f"\n{'='*60}")
print(f"Total tokens: {total_tokens}")
print(f"Entités diff-based: {total_ents_diff} B-")
print(f"Entités gazetteers: +{total_gaz} ({total_villes_gaz} VILLE, {total_hop_gaz} HOPITAL)")
print(f"Total entités: {total_all} B-")
if total_ents_diff > 0:
pct = 100 * total_gaz / total_ents_diff
print(f"Enrichissement: +{pct:.1f}% par gazetteers")
print(f"Sortie: {args.out_dir}")

View File

@@ -7,14 +7,16 @@ exportées par export_silver_annotations.py.
Usage:
python scripts/finetune_camembert_bio.py [--epochs 5] [--batch-size 8] [--lr 2e-5]
python scripts/finetune_camembert_bio.py --no-augment # Sans augmentation
Prérequis: pip install transformers datasets seqeval accelerate
Export ONNX post-training: python scripts/export_onnx.py
"""
import sys
import argparse
import random
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, Tuple
from collections import Counter
import numpy as np
@@ -120,6 +122,290 @@ def load_bio_files(data_dir: Path, window_size: int = 200, stride: int = 100) ->
return {"tokens": tokens_list, "ner_tags": labels_list}
# ─────────────────────────────────────────────────────────────────────────────
# Data Augmentation : substitution de noms
# ─────────────────────────────────────────────────────────────────────────────
def _load_gazetteer(filepath: Path, max_entries: int = 10000) -> List[str]:
"""Charge un fichier de noms (un par ligne), en ignorant les lignes vides."""
if not filepath.exists():
print(f" [WARN] Fichier gazetteer absent : {filepath}")
return []
names = []
with open(filepath, encoding="utf-8", errors="replace") as f:
for line in f:
name = line.strip()
if name and len(name) >= 2:
names.append(name)
if len(names) >= max_entries:
break
return names
def _match_case(replacement: str, original: str) -> str:
"""Applique la casse de l'original au remplacement."""
if original.isupper():
return replacement.upper()
if original.istitle():
return replacement.title()
if original.islower():
return replacement.lower()
return replacement
def _extract_per_entities(tokens: List[str], labels: List[int]) -> List[Tuple[int, int]]:
"""Extrait les spans (start, end exclusive) des entités PER dans une séquence BIO.
Retourne une liste de tuples (start_idx, end_idx).
"""
entities = []
i = 0
b_per_id = LABEL2ID.get("B-PER", -1)
i_per_id = LABEL2ID.get("I-PER", -1)
while i < len(labels):
if labels[i] == b_per_id:
start = i
i += 1
while i < len(labels) and labels[i] == i_per_id:
i += 1
entities.append((start, i))
else:
i += 1
return entities
def augment_name_substitution(
tokens_list: List[List[str]],
labels_list: List[List[int]],
prenoms_file: Path,
noms_file: Path,
num_augmented: int = 3,
rng_seed: int = 42,
) -> Tuple[List[List[str]], List[List[int]]]:
"""Augmente les données en substituant les entités PER par des noms aléatoires.
Pour chaque exemple contenant au moins une entité PER, crée num_augmented
copies avec des noms de remplacement tirés des fichiers gazetteers INSEE.
La casse (UPPER, Title, lower) est conservée.
Retourne (aug_tokens_list, aug_labels_list) — les copies augmentées SANS
les originaux (l'appelant les concatène).
"""
rng = random.Random(rng_seed)
prenoms = _load_gazetteer(prenoms_file)
noms = _load_gazetteer(noms_file)
if not prenoms:
print(" [WARN] Aucun prénom chargé — augmentation désactivée")
return [], []
if not noms:
print(" [WARN] Aucun nom de famille chargé — utilisation des prénoms uniquement")
noms = prenoms # fallback gracieux
print(f" Gazetteers : {len(prenoms)} prénoms, {len(noms)} noms de famille")
b_per_id = LABEL2ID["B-PER"]
aug_tokens: List[List[str]] = []
aug_labels: List[List[int]] = []
n_augmented_examples = 0
for tokens, labels in zip(tokens_list, labels_list):
entities = _extract_per_entities(tokens, labels)
if not entities:
continue
for _ in range(num_augmented):
new_tokens = list(tokens)
new_labels = list(labels)
for start, end in entities:
span_tokens = tokens[start:end]
span_len = end - start
# Générer un remplacement de même longueur en tokens
replacements = []
for j, orig_tok in enumerate(span_tokens):
if "-" in orig_tok:
# Nom composé : JEAN-PIERRE → MARIE-CLAIRE
parts = orig_tok.split("-")
new_parts = [
_match_case(rng.choice(prenoms), p) for p in parts
]
replacements.append("-".join(new_parts))
elif j == 0 and span_len >= 2:
# Premier token d'une entité multi-tokens → prénom
replacements.append(_match_case(rng.choice(prenoms), orig_tok))
elif span_len == 1:
# Entité mono-token → 50/50 prénom ou nom
pool = prenoms if rng.random() < 0.5 else noms
replacements.append(_match_case(rng.choice(pool), orig_tok))
else:
# Tokens suivants → nom de famille
replacements.append(_match_case(rng.choice(noms), orig_tok))
for j, idx in enumerate(range(start, end)):
new_tokens[idx] = replacements[j]
aug_tokens.append(new_tokens)
aug_labels.append(new_labels)
n_augmented_examples += 1
print(f" Augmentation noms : {n_augmented_examples} exemples générés "
f"(x{num_augmented} par exemple avec entités PER)")
return aug_tokens, aug_labels
# ─────────────────────────────────────────────────────────────────────────────
# Hard Negative Mining : médicaments
# ─────────────────────────────────────────────────────────────────────────────
# Liste de secours si aucun fichier de médicaments n'est trouvé
_FALLBACK_MEDICATIONS = [
"DOLIPRANE", "EFFERALGAN", "DAFALGAN", "IBUPROFENE", "ADVIL", "NUROFEN",
"SPASFON", "SMECTA", "GAVISCON", "MOPRAL", "OMEPRAZOLE", "PANTOPRAZOLE",
"LANSOPRAZOLE", "INEXIUM", "AUGMENTIN", "AMOXICILLINE", "CLAMOXYL",
"METFORMINE", "GLUCOPHAGE", "DIAMICRON", "LANTUS", "NOVORAPID", "HUMALOG",
"LEVOTHYROX", "EUTHYROX", "KARDEGIC", "ASPEGIC", "PLAVIX", "CLOPIDOGREL",
"ELIQUIS", "XARELTO", "PRADAXA", "COUMADINE", "PREVISCAN", "LOVENOX",
"RAMIPRIL", "TRIATEC", "PERINDOPRIL", "COVERSYL", "AMLODIPINE", "AMLOR",
"BISOPROLOL", "ATENOLOL", "METOPROLOL", "FUROSEMIDE", "LASILIX",
"TAHOR", "CRESTOR", "PRAVASTATINE", "SIMVASTATINE",
]
def _load_medications(medications_file: Path) -> List[str]:
"""Charge une liste de noms de médicaments depuis le fichier BDPM ou le fallback."""
if medications_file.exists():
meds = []
with open(medications_file, encoding="utf-8", errors="replace") as f:
for line in f:
name = line.strip()
if name and len(name) >= 3:
meds.append(name)
if meds:
print(f" Médicaments chargés : {len(meds)} depuis {medications_file.name}")
return meds
print(f" [WARN] Fichier médicaments absent ou vide — utilisation de {len(_FALLBACK_MEDICATIONS)} médicaments courants")
return list(_FALLBACK_MEDICATIONS)
# Templates de phrases médicales où les médicaments ne sont PAS des noms de personnes
_MEDICATION_TEMPLATES = [
"Traitement par {med} {dose}",
"Prescription de {med} {forme}",
"Relais par {med}",
"Introduction de {med} {dose}",
"Poursuite du traitement par {med}",
"Arrêt du {med} en raison des effets secondaires",
"Allergie connue au {med}",
"Switch de {med} vers {med2}",
"Patient sous {med} depuis 3 mois",
"Posologie de {med} augmentée à {dose}",
"Administration de {med} en IV",
"Relais per os par {med} {dose}",
]
_DOSE_VARIANTS = [
"500mg", "1000mg", "200mg", "100mg", "50mg", "250mg", "75mg",
"20mg", "40mg", "10mg", "5mg", "1g", "2g", "80mg", "160mg",
]
_FORME_VARIANTS = [
"comprimé", "gélule", "sachet", "injectable", "sirop",
"pommade", "collyre", "suppositoire", "patch", "cp",
]
def _load_quaero_entities(quaero_dir: Path) -> List[str]:
"""Charge les entités QUAERO (CHEM, DISO, PROC, ANAT) comme termes médicaux."""
entities = []
for filename in ["quaero_chem_entities.txt", "quaero_diso_entities.txt",
"quaero_procedures.txt", "quaero_anatomie.txt"]:
fpath = quaero_dir / filename
if fpath.exists():
for line in fpath.read_text(encoding="utf-8").splitlines():
name = line.strip()
if name and len(name) >= 3:
entities.append(name)
return entities
# Templates supplémentaires pour termes médicaux QUAERO (maladies, procédures, anatomie)
_MEDICAL_CONTEXT_TEMPLATES = [
"Diagnostic de {term} confirmé par imagerie",
"Antécédent de {term}",
"Suspicion de {term} évoquée",
"Bilan pour {term}",
"Prise en charge de {term} par le service",
"Examen de {term} sans anomalie",
"Surveillance de {term} en cours",
"Indication opératoire pour {term}",
]
def add_hard_negatives(
tokens_list: List[List[str]],
labels_list: List[List[int]],
medications_file: Path,
n_examples: int = 600,
rng_seed: int = 42,
) -> Tuple[List[List[str]], List[List[int]]]:
"""Crée des exemples synthétiques de contextes médicaux (tout label O).
Utilise les noms de médicaments (BDPM) et les termes médicaux QUAERO
(CHEM, DISO, PROC, ANAT) pour enseigner au modèle que ces termes
ne sont PAS des noms de personnes.
Retourne (neg_tokens_list, neg_labels_list) — les négatifs à ajouter.
"""
rng = random.Random(rng_seed)
medications = _load_medications(medications_file)
# Charger aussi les entités QUAERO
quaero_dir = medications_file.parent.parent / "quaero"
quaero_terms = _load_quaero_entities(quaero_dir)
if quaero_terms:
print(f" Termes médicaux QUAERO : {len(quaero_terms)} (CHEM+DISO+PROC+ANAT)")
if not medications:
return [], []
neg_tokens: List[List[str]] = []
neg_labels: List[List[int]] = []
# 1. Hard negatives médicaments (BDPM)
n_med = n_examples * 2 // 3 # 2/3 médicaments
for i in range(n_med):
template = rng.choice(_MEDICATION_TEMPLATES)
med = rng.choice(medications)
med2 = rng.choice(medications)
dose = rng.choice(_DOSE_VARIANTS)
forme = rng.choice(_FORME_VARIANTS)
sentence = template.format(med=med, med2=med2, dose=dose, forme=forme)
toks = sentence.split()
labs = [LABEL2ID["O"]] * len(toks)
neg_tokens.append(toks)
neg_labels.append(labs)
# 2. Hard negatives QUAERO (maladies, procédures, anatomie)
n_quaero = n_examples - n_med # 1/3 termes QUAERO
if quaero_terms:
for i in range(n_quaero):
template = rng.choice(_MEDICAL_CONTEXT_TEMPLATES)
term = rng.choice(quaero_terms)
sentence = template.format(term=term)
toks = sentence.split()
labs = [LABEL2ID["O"]] * len(toks)
neg_tokens.append(toks)
neg_labels.append(labs)
print(f" Hard negatives : {n_med} médicaments + {n_quaero if quaero_terms else 0} QUAERO = {len(neg_tokens)} exemples")
return neg_tokens, neg_labels
def tokenize_and_align(examples, tokenizer):
"""Tokenize et aligne les labels avec les sous-tokens."""
tokenized = tokenizer(
@@ -218,11 +504,28 @@ def main():
default=Path(__file__).parent.parent / "models" / "camembert-bio-deid",
help="Répertoire de sortie du modèle")
parser.add_argument("--epochs", type=int, default=5)
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--batch-size", type=int, default=16,
help="Batch size effectif (via gradient accumulation si nécessaire)")
parser.add_argument("--gpu-batch-size", type=int, default=8,
help="Batch size réel sur GPU (le reste via gradient accumulation)")
parser.add_argument("--lr", type=float, default=2e-5)
parser.add_argument("--val-split", type=float, default=0.15, help="Fraction pour validation")
parser.add_argument("--no-augment", action="store_true",
help="Désactiver l'augmentation de données (substitution noms + hard negatives)")
parser.add_argument("--num-augmented", type=int, default=3,
help="Nombre de copies augmentées par exemple PER (défaut: 3)")
parser.add_argument("--num-hard-negatives", type=int, default=600,
help="Nombre d'exemples hard negative médicaments (défaut: 600)")
parser.add_argument("--augment-seed", type=int, default=42,
help="Seed pour la reproductibilité de l'augmentation")
args = parser.parse_args()
# Chemins des gazetteers
project_root = Path(__file__).parent.parent
prenoms_file = project_root / "data" / "insee" / "prenoms_france.txt"
noms_file = project_root / "data" / "insee" / "noms_famille_france.txt"
medications_file = project_root / "data" / "bdpm" / "medication_names.txt"
# Charger les données
print(f"Chargement des données depuis {args.data_dir}...")
raw_data = load_bio_files(args.data_dir)
@@ -234,10 +537,59 @@ def main():
print("ERREUR: pas assez de données. Lancez d'abord export_silver_annotations.py")
sys.exit(1)
# Split train/val
# Split train/val AVANT augmentation (on n'augmente que le train)
dataset = Dataset.from_dict(raw_data)
split = dataset.train_test_split(test_size=args.val_split, seed=42)
datasets = DatasetDict({"train": split["train"], "validation": split["test"]})
# Récupérer les listes Python du split train pour l'augmentation
train_tokens = list(split["train"]["tokens"])
train_labels = list(split["train"]["ner_tags"])
val_tokens = list(split["test"]["tokens"])
val_labels = list(split["test"]["ner_tags"])
n_train_orig = len(train_tokens)
# Augmentation (train uniquement)
if not args.no_augment:
print(f"\nAugmentation des données d'entraînement (seed={args.augment_seed})...")
# 1. Substitution de noms
aug_tok, aug_lab = augment_name_substitution(
train_tokens, train_labels,
prenoms_file=prenoms_file,
noms_file=noms_file,
num_augmented=args.num_augmented,
rng_seed=args.augment_seed,
)
train_tokens.extend(aug_tok)
train_labels.extend(aug_lab)
# 2. Hard negatives médicaments
neg_tok, neg_lab = add_hard_negatives(
train_tokens, train_labels,
medications_file=medications_file,
n_examples=args.num_hard_negatives,
rng_seed=args.augment_seed,
)
train_tokens.extend(neg_tok)
train_labels.extend(neg_lab)
print(f"\n Résumé augmentation :")
print(f" Train original : {n_train_orig} exemples")
print(f" + Substitution noms : {len(aug_tok)} exemples")
print(f" + Hard negatives : {len(neg_tok)} exemples")
print(f" = Train total : {len(train_tokens)} exemples")
print(f" Validation (non augmentée) : {len(val_tokens)} exemples")
else:
print("\n Augmentation désactivée (--no-augment)")
# Reconstruire les datasets
train_data = {"tokens": train_tokens, "ner_tags": train_labels}
val_data = {"tokens": val_tokens, "ner_tags": val_labels}
datasets = DatasetDict({
"train": Dataset.from_dict(train_data),
"validation": Dataset.from_dict(val_data),
})
print(f" Train: {len(datasets['train'])}, Validation: {len(datasets['validation'])}")
# Tokenizer + modèle
@@ -281,17 +633,24 @@ def main():
"f1": results["overall_f1"],
}
# Class weights pour contrer le déséquilibre 97% O
# Class weights pour contrer le déséquilibre 97% O (calculé sur le train augmenté)
print("\nCalcul des poids de classe...")
weights = compute_class_weights(raw_data, len(LABEL_LIST))
weights = compute_class_weights(train_data, len(LABEL_LIST))
# Training
args.output_dir.mkdir(parents=True, exist_ok=True)
# Gradient accumulation : si batch_size demandé > batch effectif GPU
grad_accum = max(1, args.batch_size // args.gpu_batch_size)
effective_batch = args.gpu_batch_size * grad_accum
if effective_batch != args.batch_size:
print(f" GPU batch={args.gpu_batch_size}, grad_accum={grad_accum} → effectif={effective_batch}")
training_args = TrainingArguments(
output_dir=str(args.output_dir),
num_train_epochs=args.epochs,
per_device_train_batch_size=args.batch_size,
per_device_eval_batch_size=args.batch_size * 2,
per_device_train_batch_size=args.gpu_batch_size,
per_device_eval_batch_size=args.gpu_batch_size * 2,
gradient_accumulation_steps=grad_accum,
learning_rate=args.lr,
weight_decay=0.01,
warmup_ratio=0.1,