feat(phase2): Fine-tune CamemBERT-bio v2 (F1=0.90) + data enrichment

- Fine-tuning of camembert-bio-base: F1=0.903, Recall=0.930 (vs 0.89/0.85)
- Data augmentation: INSEE name substitution (219K family names, x3 copies)
- Hard negatives from BDPM (5.7K drug names) + QUAERO (1319 medical terms)
- Silver annotations enriched via gazetteers (+612 VILLE, +5 HOPITAL)
- Silver export with multi-directory support (--extra-dir)
- QUAERO gazetteers: CHEM, DISO, PROC, ANAT from DrBenchmark/QUAERO
- INSEE gazetteers: frequent family names (96K) and full list (219K)
- Silver batch over 1194 PDFs (run_batch_silver_export.py) for dataset v3

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
scripts/build_noms_famille_gazetteer.py (new file, 118 lines)

@@ -0,0 +1,118 @@
#!/usr/bin/env python3
"""
build_noms_famille_gazetteer.py

Construit deux gazetteers de noms de famille français à partir du fichier INSEE.

Source : /home/dom/ai/anonymisation/data/insee/noms2008nat_txt.txt
(TSV : NOM, puis effectifs par décennie 1891-2000)

Sorties dans /home/dom/ai/anonymisation/data/insee/ :
- noms_famille_france.txt : TOUS les noms (219K), un par ligne, majuscules
- noms_famille_frequents.txt : noms avec total >= 100 ET longueur >= 3 caractères
"""
import os
from collections import Counter

INPUT_FILE = "/home/dom/ai/anonymisation/data/insee/noms2008nat_txt.txt"
OUTPUT_DIR = "/home/dom/ai/anonymisation/data/insee"
OUTPUT_ALL = os.path.join(OUTPUT_DIR, "noms_famille_france.txt")
OUTPUT_FREQ = os.path.join(OUTPUT_DIR, "noms_famille_frequents.txt")

MIN_TOTAL = 100   # seuil de fréquence pour le fichier "fréquents"
MIN_LENGTH = 3    # longueur minimale pour le fichier "fréquents"
SKIP_NAMES = {"AUTRES NOMS"}


def main():
    all_names = []
    freq_names = []
    frequencies = []  # pour les stats de distribution

    with open(INPUT_FILE, "r", encoding="utf-8") as f:
        header = f.readline()  # skip header
        print(f"En-tête : {header.strip()}")

        for line in f:
            line = line.strip()
            if not line:
                continue

            parts = line.split("\t")
            nom = parts[0].strip()

            if nom in SKIP_NAMES:
                continue

            # Calculer la fréquence totale sur toutes les décennies
            total = 0
            for val in parts[1:]:
                val = val.strip()
                try:
                    total += int(val)
                except ValueError:
                    pass

            # Tous les noms (uppercase, déjà le cas dans le fichier)
            nom_upper = nom.upper()
            all_names.append(nom_upper)
            frequencies.append(total)

            # Filtre pour les noms fréquents
            if total >= MIN_TOTAL and len(nom_upper) >= MIN_LENGTH:
                freq_names.append(nom_upper)

    # Écriture des fichiers
    with open(OUTPUT_ALL, "w", encoding="utf-8") as f:
        f.write("\n".join(all_names) + "\n")

    with open(OUTPUT_FREQ, "w", encoding="utf-8") as f:
        f.write("\n".join(freq_names) + "\n")

    # --- Stats ---
    print(f"\n{'='*60}")
    print(f"STATISTIQUES")
    print(f"{'='*60}")
    print(f"Fichier source : {INPUT_FILE}")
    print(f"Noms totaux lus : {len(all_names):>10,}")
    print(f" → {OUTPUT_ALL}")
    print(f"Noms fréquents : {len(freq_names):>10,} (total >= {MIN_TOTAL}, len >= {MIN_LENGTH})")
    print(f" → {OUTPUT_FREQ}")
    print(f"Noms exclus (filtre): {len(all_names) - len(freq_names):>10,}")

    # Distribution de fréquence
    print(f"\n--- Distribution des fréquences ---")
    buckets = [
        (0, 0, "0 (hapax)"),
        (1, 9, "1-9"),
        (10, 49, "10-49"),
        (50, 99, "50-99"),
        (100, 499, "100-499"),
        (500, 999, "500-999"),
        (1000, 4999, "1 000-4 999"),
        (5000, 9999, "5 000-9 999"),
        (10000, 49999, "10 000-49 999"),
        (50000, 99999, "50 000-99 999"),
        (100000, float("inf"), "100 000+"),
    ]
    for lo, hi, label in buckets:
        count = sum(1 for freq in frequencies if lo <= freq <= hi)
        if count > 0:
            print(f" {label:>20s} : {count:>8,} noms")

    # Top 20
    print(f"\n--- Top 20 noms les plus fréquents ---")
    indexed = list(zip(all_names, frequencies))
    indexed.sort(key=lambda x: x[1], reverse=True)
    for i, (nom, freq) in enumerate(indexed[:20], 1):
        print(f" {i:>2}. {nom:<25s} {freq:>10,}")

    # Quelques stats globales
    total_naissances = sum(frequencies)
    print(f"\nTotal naissances couvertes : {total_naissances:>12,}")
    print(f"Fréquence médiane : {sorted(frequencies)[len(frequencies)//2]:>12,}")
    print(f"Fréquence moyenne : {total_naissances / len(frequencies):>12,.1f}")


if __name__ == "__main__":
    main()
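For reference, a minimal sketch (editorial, not part of the commit) of how the generated gazetteer can be consumed downstream; the path and the one-uppercase-name-per-line format follow the script above:

from pathlib import Path

# Load the frequent-surnames gazetteer into a set for O(1) membership checks.
FREQ_PATH = Path("/home/dom/ai/anonymisation/data/insee/noms_famille_frequents.txt")
frequent_surnames = {
    line.strip()
    for line in FREQ_PATH.read_text(encoding="utf-8").splitlines()
    if line.strip()
}
print(len(frequent_surnames), "MARTIN" in frequent_surnames)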
scripts/export_silver_annotations.py

@@ -16,8 +16,9 @@ import sys
import re
import difflib
import argparse
import unicodedata
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List, Set, Tuple

sys.path.insert(0, str(Path(__file__).parent.parent))

@@ -47,6 +48,178 @@ RE_PLACEHOLDER = re.compile(r"^\[([A-Z_]+)\]$")
SRC = Path("/home/dom/Téléchargements/II-1 Ctrl_T2A_2025_CHCB_DocJustificatifs (1)")
AUDIT_DIR = SRC / "anonymise_audit_30"

# --- Gazetteer paths ---
GAZETTEERS_DIR = Path(__file__).parent.parent / "data"
VILLES_FINESS_PATH = GAZETTEERS_DIR / "finess" / "villes_finess.txt"
COMMUNES_INSEE_PATH = GAZETTEERS_DIR / "insee" / "communes_france.txt"
ETABLISSEMENTS_PATH = GAZETTEERS_DIR / "finess" / "etablissements_distinctifs.txt"

# Mots de contexte indiquant qu'un token VILLE est bien un lieu (pas un nom commun)
VILLE_CONTEXT_WORDS = {
    "à", "a", "de", "né", "née", "nee", "ne", "résid", "resid",
    "hospitalis", "transféré", "transfere", "transferée", "transferee",
    "domicilié", "domicilie", "domiciliée", "domiciliee",
    "habite", "habitant", "demeurant", "originaire", "ville",
    "commune", "cedex",
}


def _strip_accents(s: str) -> str:
    """Supprime les accents d'une chaîne (é→e, à→a, etc.)."""
    nfkd = unicodedata.normalize("NFKD", s)
    return "".join(c for c in nfkd if not unicodedata.combining(c))


def _normalize_gaz(s: str) -> str:
    """Normalise pour comparaison gazetteer : minuscule, sans accents, stripped."""
    return _strip_accents(s.lower().strip())


def load_gazetteers() -> dict:
    """Charge les gazetteers depuis les fichiers, avec fallback gracieux.

    Retourne un dict avec:
    - "villes": set de tuples de tokens normalisés (ex: ("saint", "palais"))
    - "hopitaux": set de tuples de tokens normalisés (ex: ("ch", "argentan"))
    """
    villes: Set[Tuple[str, ...]] = set()
    hopitaux: Set[Tuple[str, ...]] = set()

    # --- Villes FINESS (UPPERCASE, une par ligne) ---
    if VILLES_FINESS_PATH.exists():
        for line in VILLES_FINESS_PATH.read_text(encoding="utf-8").splitlines():
            line = line.strip()
            if not line:
                continue
            # Filtrer les entrées avec "CEDEX" (adresses postales, pas des villes)
            tokens = tuple(_normalize_gaz(t) for t in line.split() if t != "CEDEX")
            if tokens and len(tokens[0]) >= 2:  # Ignorer entrées trop courtes
                villes.add(tokens)

    # --- Communes INSEE (UPPERCASE, une par ligne) ---
    if COMMUNES_INSEE_PATH.exists():
        for line in COMMUNES_INSEE_PATH.read_text(encoding="utf-8").splitlines():
            line = line.strip()
            if not line:
                continue
            tokens = tuple(_normalize_gaz(t) for t in line.split())
            if tokens and len(tokens[0]) >= 2:
                villes.add(tokens)

    # --- Établissements (format "- nom normalisé", minuscule) ---
    if ETABLISSEMENTS_PATH.exists():
        for line in ETABLISSEMENTS_PATH.read_text(encoding="utf-8").splitlines():
            line = line.strip()
            if not line.startswith("- "):
                continue
            name = line[2:].strip()
            if not name:
                continue
            tokens = tuple(_normalize_gaz(t) for t in name.split())
            # Ignorer les entrées trop courtes (1 token de 3 chars ou moins)
            if tokens and (len(tokens) > 1 or len(tokens[0]) > 3):
                hopitaux.add(tokens)

    # Retirer les villes d'un seul token très court (risque élevé de faux positifs)
    villes = {v for v in villes if len(v) > 1 or len(v[0]) >= 3}

    return {"villes": villes, "hopitaux": hopitaux}


def _has_ville_context(tokens: List[str], labels: List[str], pos: int,
                       window: int = 3) -> bool:
    """Vérifie si un token à la position `pos` a un contexte indiquant un lieu.

    Regarde les `window` tokens précédents pour des mots-clés de contexte.
    """
    start = max(0, pos - window)
    for i in range(start, pos):
        tok_norm = _normalize_gaz(tokens[i].strip(".,;:!?()[]{}\"'"))
        # Vérifier correspondance exacte ou préfixe (ex: "résid" matche "résidence")
        for ctx in VILLE_CONTEXT_WORDS:
            if tok_norm == ctx or tok_norm.startswith(ctx):
                return True
    return False


def enrich_with_gazetteers(
    bio_tokens: List[Tuple[str, str]],
    gazetteers: dict,
) -> Tuple[List[Tuple[str, str]], int, int]:
    """Enrichit les annotations BIO avec les gazetteers.

    Ne modifie JAMAIS un label existant (non-O). Ajoute uniquement des labels
    sur les tokens actuellement "O".

    Retourne: (bio_tokens_enrichis, nb_villes_ajoutées, nb_hopitaux_ajoutés)
    """
    tokens = [t for t, _ in bio_tokens]
    labels = [l for _, l in bio_tokens]
    n = len(tokens)

    added_villes = 0
    added_hopitaux = 0

    # Pré-calculer les tokens normalisés (sans ponctuation, sans accents, lowercase)
    tokens_norm = [
        _normalize_gaz(t.strip(".,;:!?()[]{}\"'"))
        for t in tokens
    ]

    # --- Enrichissement HOPITAL (multi-token, sans contrainte de contexte) ---
    # On traite d'abord les hôpitaux car ils sont plus spécifiques
    for gaz_tokens in gazetteers.get("hopitaux", set()):
        gaz_len = len(gaz_tokens)
        if gaz_len == 0:
            continue
        i = 0
        while i <= n - gaz_len:
            # Vérifier si la séquence de tokens matche
            match = True
            for k in range(gaz_len):
                if tokens_norm[i + k] != gaz_tokens[k]:
                    match = False
                    break
            if match:
                # Vérifier que TOUS les tokens sont actuellement "O"
                all_o = all(labels[i + k] == "O" for k in range(gaz_len))
                if all_o:
                    labels[i] = "B-HOPITAL"
                    for k in range(1, gaz_len):
                        labels[i + k] = "I-HOPITAL"
                    added_hopitaux += 1
                    i += gaz_len
                    continue
            i += 1

    # --- Enrichissement VILLE (avec contexte obligatoire) ---
    for gaz_tokens in gazetteers.get("villes", set()):
        gaz_len = len(gaz_tokens)
        if gaz_len == 0:
            continue
        i = 0
        while i <= n - gaz_len:
            # Vérifier si la séquence de tokens matche
            match = True
            for k in range(gaz_len):
                if tokens_norm[i + k] != gaz_tokens[k]:
                    match = False
                    break
            if match:
                # Vérifier que TOUS les tokens sont actuellement "O"
                all_o = all(labels[i + k] == "O" for k in range(gaz_len))
                if all_o and _has_ville_context(tokens, labels, i):
                    labels[i] = "B-VILLE"
                    for k in range(1, gaz_len):
                        labels[i + k] = "I-VILLE"
                    added_villes += 1
                    i += gaz_len
                    continue
            i += 1

    enriched = list(zip(tokens, labels))
    return enriched, added_villes, added_hopitaux


def extract_original_text(pdf_path: Path) -> str:
    """Extrait le texte brut d'un PDF (même méthode que le pipeline)."""
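For reference, a minimal usage sketch (editorial, not part of the diff) of enrich_with_gazetteers on a toy BIO sequence, with a hand-built gazetteers dict instead of load_gazetteers():

# "Bayonne" is preceded by "né à", so _has_ville_context() passes;
# the existing B-PER label on "DUPONT" is never overwritten.
bio = [("Patient", "O"), ("né", "O"), ("à", "O"), ("Bayonne", "O"),
       ("DUPONT", "B-PER")]
gaz = {"villes": {("bayonne",)}, "hopitaux": set()}
enriched, n_villes, n_hopitaux = enrich_with_gazetteers(bio, gaz)
# enriched[3] == ("Bayonne", "B-VILLE"); n_villes == 1, n_hopitaux == 0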
@@ -173,20 +346,37 @@ def align_and_annotate(original_text: str, pseudo_text: str) -> List[Tuple[str,
    return bio_tokens


def export_document(pdf_path: Path, pseudo_path: Path, out_dir: Path) -> Tuple[int, int]:
    """Exporte un document en format BIO. Retourne (nb_tokens, nb_entités)."""
def export_document(
    pdf_path: Path,
    pseudo_path: Path,
    out_dir: Path,
    gazetteers: dict | None = None,
) -> Tuple[int, int, int, int]:
    """Exporte un document en format BIO.

    Retourne (nb_tokens, nb_entités_diff, nb_villes_gaz, nb_hopitaux_gaz).
    """
    # Extraire le texte original
    original_text = extract_original_text(pdf_path)
    if not original_text.strip():
        return 0, 0
        return 0, 0, 0, 0

    # Lire le texte pseudonymisé
    pseudo_text = pseudo_path.read_text(encoding="utf-8")
    if not pseudo_text.strip():
        return 0, 0
        return 0, 0, 0, 0

    # Aligner et annoter
    # Aligner et annoter (diff-based)
    bio_tokens = align_and_annotate(original_text, pseudo_text)
    n_ents_diff = sum(1 for _, l in bio_tokens if l.startswith("B-"))

    # Enrichissement gazetteer (post-processing)
    added_villes = 0
    added_hopitaux = 0
    if gazetteers:
        bio_tokens, added_villes, added_hopitaux = enrich_with_gazetteers(
            bio_tokens, gazetteers
        )

    # Écrire en format CoNLL
    out_name = pdf_path.stem + ".bio"
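For orientation, a plausible excerpt of a generated .bio file in CoNLL style (editorial illustration; the exact column separator is not shown in this hunk, so the single-space layout below is an assumption):

Monsieur O
Jean B-PER
DUPONT I-PER
hospitalisé O
à O
Bayonne B-VILLE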
@@ -202,8 +392,7 @@ def export_document(pdf_path: Path, pseudo_path: Path, out_dir: Path) -> Tuple[i

    out_path.write_text("\n".join(lines), encoding="utf-8")

    n_ents = sum(1 for _, l in bio_tokens if l.startswith("B-"))
    return len(bio_tokens), n_ents
    return len(bio_tokens), n_ents_diff, added_villes, added_hopitaux


def main():
@@ -212,12 +401,42 @@ def main():
                        default=Path(__file__).parent.parent / "data" / "silver_annotations",
                        help="Répertoire de sortie")
    parser.add_argument("--limit", type=int, default=0, help="Limiter à N fichiers (0=tous)")
    parser.add_argument("--no-gazetteers", action="store_true",
                        help="Désactiver l'enrichissement par gazetteers")
    parser.add_argument("--extra-dir", type=Path, nargs="*", default=[],
                        help="Répertoires supplémentaires contenant des .pseudonymise.txt")
    args = parser.parse_args()

    args.out_dir.mkdir(parents=True, exist_ok=True)

    # Trouver les paires PDF + pseudo
    pseudo_files = sorted(AUDIT_DIR.glob("*.pseudonymise.txt"))
    # Charger les gazetteers
    gazetteers = None
    if not args.no_gazetteers:
        gazetteers = load_gazetteers()
        n_villes = len(gazetteers["villes"])
        n_hop = len(gazetteers["hopitaux"])
        print(f"Gazetteers chargés: {n_villes} villes, {n_hop} établissements")
    else:
        print("Gazetteers désactivés")

    # Trouver les paires PDF + pseudo (audit_30 + extra dirs)
    search_dirs = [AUDIT_DIR] + list(args.extra_dir)
    pseudo_files = []
    for sdir in search_dirs:
        if sdir.exists():
            pseudo_files.extend(sorted(sdir.glob("*.pseudonymise.txt")))
            print(f" {sdir.name}: {len(list(sdir.glob('*.pseudonymise.txt')))} fichiers pseudo")

    # Dédupliquer par nom de base
    seen_bases = set()
    unique_pseudo = []
    for pf in pseudo_files:
        base = pf.name.replace(".pseudonymise.txt", "")
        if base not in seen_bases:
            seen_bases.add(base)
            unique_pseudo.append(pf)
    pseudo_files = unique_pseudo

    pairs = []
    for pseudo_path in pseudo_files:
        # Retrouver le PDF source
@@ -233,17 +452,35 @@ def main():
    print(f"Export silver annotations: {len(pairs)} documents → {args.out_dir}")

    total_tokens = 0
    total_entities = 0
    total_ents_diff = 0
    total_villes_gaz = 0
    total_hop_gaz = 0
    for pdf_path, pseudo_path in pairs:
        try:
            n_tok, n_ent = export_document(pdf_path, pseudo_path, args.out_dir)
            n_tok, n_diff, n_vgaz, n_hgaz = export_document(
                pdf_path, pseudo_path, args.out_dir, gazetteers
            )
            total_tokens += n_tok
            total_entities += n_ent
            print(f" {pdf_path.name}: {n_tok} tokens, {n_ent} entités")
            total_ents_diff += n_diff
            total_villes_gaz += n_vgaz
            total_hop_gaz += n_hgaz
            gaz_info = ""
            if n_vgaz or n_hgaz:
                gaz_info = f" (+{n_vgaz} villes, +{n_hgaz} hôpitaux gaz.)"
            print(f" {pdf_path.name}: {n_tok} tokens, {n_diff} entités diff{gaz_info}")
        except Exception as e:
            print(f" {pdf_path.name}: ERREUR {e}")

    print(f"\nTotal: {total_tokens} tokens, {total_entities} entités B-")
    total_gaz = total_villes_gaz + total_hop_gaz
    total_all = total_ents_diff + total_gaz
    print(f"\n{'='*60}")
    print(f"Total tokens: {total_tokens}")
    print(f"Entités diff-based: {total_ents_diff} B-")
    print(f"Entités gazetteers: +{total_gaz} ({total_villes_gaz} VILLE, {total_hop_gaz} HOPITAL)")
    print(f"Total entités: {total_all} B-")
    if total_ents_diff > 0:
        pct = 100 * total_gaz / total_ents_diff
        print(f"Enrichissement: +{pct:.1f}% par gazetteers")
    print(f"Sortie: {args.out_dir}")

scripts/finetune_camembert_bio.py

@@ -7,14 +7,16 @@ exportées par export_silver_annotations.py.

Usage:
    python scripts/finetune_camembert_bio.py [--epochs 5] [--batch-size 8] [--lr 2e-5]
    python scripts/finetune_camembert_bio.py --no-augment   # Sans augmentation

Prérequis: pip install transformers datasets seqeval accelerate
Export ONNX post-training: python scripts/export_onnx.py
"""
import sys
import argparse
import random
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, Tuple
from collections import Counter

import numpy as np
@@ -120,6 +122,290 @@ def load_bio_files(data_dir: Path, window_size: int = 200, stride: int = 100) ->
    return {"tokens": tokens_list, "ner_tags": labels_list}


# ─────────────────────────────────────────────────────────────────────────────
# Data Augmentation : substitution de noms
# ─────────────────────────────────────────────────────────────────────────────

def _load_gazetteer(filepath: Path, max_entries: int = 10000) -> List[str]:
    """Charge un fichier de noms (un par ligne), en ignorant les lignes vides."""
    if not filepath.exists():
        print(f" [WARN] Fichier gazetteer absent : {filepath}")
        return []
    names = []
    with open(filepath, encoding="utf-8", errors="replace") as f:
        for line in f:
            name = line.strip()
            if name and len(name) >= 2:
                names.append(name)
            if len(names) >= max_entries:
                break
    return names


def _match_case(replacement: str, original: str) -> str:
    """Applique la casse de l'original au remplacement."""
    if original.isupper():
        return replacement.upper()
    if original.istitle():
        return replacement.title()
    if original.islower():
        return replacement.lower()
    return replacement


def _extract_per_entities(tokens: List[str], labels: List[int]) -> List[Tuple[int, int]]:
    """Extrait les spans (start, end exclusive) des entités PER dans une séquence BIO.

    Retourne une liste de tuples (start_idx, end_idx).
    """
    entities = []
    i = 0
    b_per_id = LABEL2ID.get("B-PER", -1)
    i_per_id = LABEL2ID.get("I-PER", -1)

    while i < len(labels):
        if labels[i] == b_per_id:
            start = i
            i += 1
            while i < len(labels) and labels[i] == i_per_id:
                i += 1
            entities.append((start, i))
        else:
            i += 1
    return entities


def augment_name_substitution(
    tokens_list: List[List[str]],
    labels_list: List[List[int]],
    prenoms_file: Path,
    noms_file: Path,
    num_augmented: int = 3,
    rng_seed: int = 42,
) -> Tuple[List[List[str]], List[List[int]]]:
    """Augmente les données en substituant les entités PER par des noms aléatoires.

    Pour chaque exemple contenant au moins une entité PER, crée num_augmented
    copies avec des noms de remplacement tirés des fichiers gazetteers INSEE.
    La casse (UPPER, Title, lower) est conservée.

    Retourne (aug_tokens_list, aug_labels_list) — les copies augmentées SANS
    les originaux (l'appelant les concatène).
    """
    rng = random.Random(rng_seed)
    prenoms = _load_gazetteer(prenoms_file)
    noms = _load_gazetteer(noms_file)

    if not prenoms:
        print(" [WARN] Aucun prénom chargé — augmentation désactivée")
        return [], []
    if not noms:
        print(" [WARN] Aucun nom de famille chargé — utilisation des prénoms uniquement")
        noms = prenoms  # fallback gracieux

    print(f" Gazetteers : {len(prenoms)} prénoms, {len(noms)} noms de famille")

    b_per_id = LABEL2ID["B-PER"]
    aug_tokens: List[List[str]] = []
    aug_labels: List[List[int]] = []
    n_augmented_examples = 0

    for tokens, labels in zip(tokens_list, labels_list):
        entities = _extract_per_entities(tokens, labels)
        if not entities:
            continue

        for _ in range(num_augmented):
            new_tokens = list(tokens)
            new_labels = list(labels)

            for start, end in entities:
                span_tokens = tokens[start:end]
                span_len = end - start

                # Générer un remplacement de même longueur en tokens
                replacements = []
                for j, orig_tok in enumerate(span_tokens):
                    if "-" in orig_tok:
                        # Nom composé : JEAN-PIERRE → MARIE-CLAIRE
                        parts = orig_tok.split("-")
                        new_parts = [
                            _match_case(rng.choice(prenoms), p) for p in parts
                        ]
                        replacements.append("-".join(new_parts))
                    elif j == 0 and span_len >= 2:
                        # Premier token d'une entité multi-tokens → prénom
                        replacements.append(_match_case(rng.choice(prenoms), orig_tok))
                    elif span_len == 1:
                        # Entité mono-token → 50/50 prénom ou nom
                        pool = prenoms if rng.random() < 0.5 else noms
                        replacements.append(_match_case(rng.choice(pool), orig_tok))
                    else:
                        # Tokens suivants → nom de famille
                        replacements.append(_match_case(rng.choice(noms), orig_tok))

                for j, idx in enumerate(range(start, end)):
                    new_tokens[idx] = replacements[j]

            aug_tokens.append(new_tokens)
            aug_labels.append(new_labels)
            n_augmented_examples += 1

    print(f" Augmentation noms : {n_augmented_examples} exemples générés "
          f"(x{num_augmented} par exemple avec entités PER)")
    return aug_tokens, aug_labels


# ─────────────────────────────────────────────────────────────────────────────
# Hard Negative Mining : médicaments
# ─────────────────────────────────────────────────────────────────────────────

# Liste de secours si aucun fichier de médicaments n'est trouvé
_FALLBACK_MEDICATIONS = [
    "DOLIPRANE", "EFFERALGAN", "DAFALGAN", "IBUPROFENE", "ADVIL", "NUROFEN",
    "SPASFON", "SMECTA", "GAVISCON", "MOPRAL", "OMEPRAZOLE", "PANTOPRAZOLE",
    "LANSOPRAZOLE", "INEXIUM", "AUGMENTIN", "AMOXICILLINE", "CLAMOXYL",
    "METFORMINE", "GLUCOPHAGE", "DIAMICRON", "LANTUS", "NOVORAPID", "HUMALOG",
    "LEVOTHYROX", "EUTHYROX", "KARDEGIC", "ASPEGIC", "PLAVIX", "CLOPIDOGREL",
    "ELIQUIS", "XARELTO", "PRADAXA", "COUMADINE", "PREVISCAN", "LOVENOX",
    "RAMIPRIL", "TRIATEC", "PERINDOPRIL", "COVERSYL", "AMLODIPINE", "AMLOR",
    "BISOPROLOL", "ATENOLOL", "METOPROLOL", "FUROSEMIDE", "LASILIX",
    "TAHOR", "CRESTOR", "PRAVASTATINE", "SIMVASTATINE",
]


def _load_medications(medications_file: Path) -> List[str]:
    """Charge une liste de noms de médicaments depuis le fichier BDPM ou le fallback."""
    if medications_file.exists():
        meds = []
        with open(medications_file, encoding="utf-8", errors="replace") as f:
            for line in f:
                name = line.strip()
                if name and len(name) >= 3:
                    meds.append(name)
        if meds:
            print(f" Médicaments chargés : {len(meds)} depuis {medications_file.name}")
            return meds

    print(f" [WARN] Fichier médicaments absent ou vide — utilisation de {len(_FALLBACK_MEDICATIONS)} médicaments courants")
    return list(_FALLBACK_MEDICATIONS)


# Templates de phrases médicales où les médicaments ne sont PAS des noms de personnes
_MEDICATION_TEMPLATES = [
    "Traitement par {med} {dose}",
    "Prescription de {med} {forme}",
    "Relais par {med}",
    "Introduction de {med} {dose}",
    "Poursuite du traitement par {med}",
    "Arrêt du {med} en raison des effets secondaires",
    "Allergie connue au {med}",
    "Switch de {med} vers {med2}",
    "Patient sous {med} depuis 3 mois",
    "Posologie de {med} augmentée à {dose}",
    "Administration de {med} en IV",
    "Relais per os par {med} {dose}",
]

_DOSE_VARIANTS = [
    "500mg", "1000mg", "200mg", "100mg", "50mg", "250mg", "75mg",
    "20mg", "40mg", "10mg", "5mg", "1g", "2g", "80mg", "160mg",
]

_FORME_VARIANTS = [
    "comprimé", "gélule", "sachet", "injectable", "sirop",
    "pommade", "collyre", "suppositoire", "patch", "cp",
]


def _load_quaero_entities(quaero_dir: Path) -> List[str]:
    """Charge les entités QUAERO (CHEM, DISO, PROC, ANAT) comme termes médicaux."""
    entities = []
    for filename in ["quaero_chem_entities.txt", "quaero_diso_entities.txt",
                     "quaero_procedures.txt", "quaero_anatomie.txt"]:
        fpath = quaero_dir / filename
        if fpath.exists():
            for line in fpath.read_text(encoding="utf-8").splitlines():
                name = line.strip()
                if name and len(name) >= 3:
                    entities.append(name)
    return entities


# Templates supplémentaires pour termes médicaux QUAERO (maladies, procédures, anatomie)
_MEDICAL_CONTEXT_TEMPLATES = [
    "Diagnostic de {term} confirmé par imagerie",
    "Antécédent de {term}",
    "Suspicion de {term} évoquée",
    "Bilan pour {term}",
    "Prise en charge de {term} par le service",
    "Examen de {term} sans anomalie",
    "Surveillance de {term} en cours",
    "Indication opératoire pour {term}",
]


def add_hard_negatives(
    tokens_list: List[List[str]],
    labels_list: List[List[int]],
    medications_file: Path,
    n_examples: int = 600,
    rng_seed: int = 42,
) -> Tuple[List[List[str]], List[List[int]]]:
    """Crée des exemples synthétiques de contextes médicaux (tout label O).

    Utilise les noms de médicaments (BDPM) et les termes médicaux QUAERO
    (CHEM, DISO, PROC, ANAT) pour enseigner au modèle que ces termes
    ne sont PAS des noms de personnes.

    Retourne (neg_tokens_list, neg_labels_list) — les négatifs à ajouter.
    """
    rng = random.Random(rng_seed)
    medications = _load_medications(medications_file)

    # Charger aussi les entités QUAERO
    quaero_dir = medications_file.parent.parent / "quaero"
    quaero_terms = _load_quaero_entities(quaero_dir)
    if quaero_terms:
        print(f" Termes médicaux QUAERO : {len(quaero_terms)} (CHEM+DISO+PROC+ANAT)")

    if not medications:
        return [], []

    neg_tokens: List[List[str]] = []
    neg_labels: List[List[int]] = []

    # 1. Hard negatives médicaments (BDPM)
    n_med = n_examples * 2 // 3  # 2/3 médicaments
    for i in range(n_med):
        template = rng.choice(_MEDICATION_TEMPLATES)
        med = rng.choice(medications)
        med2 = rng.choice(medications)
        dose = rng.choice(_DOSE_VARIANTS)
        forme = rng.choice(_FORME_VARIANTS)

        sentence = template.format(med=med, med2=med2, dose=dose, forme=forme)
        toks = sentence.split()
        labs = [LABEL2ID["O"]] * len(toks)
        neg_tokens.append(toks)
        neg_labels.append(labs)

    # 2. Hard negatives QUAERO (maladies, procédures, anatomie)
    n_quaero = n_examples - n_med  # 1/3 termes QUAERO
    if quaero_terms:
        for i in range(n_quaero):
            template = rng.choice(_MEDICAL_CONTEXT_TEMPLATES)
            term = rng.choice(quaero_terms)
            sentence = template.format(term=term)
            toks = sentence.split()
            labs = [LABEL2ID["O"]] * len(toks)
            neg_tokens.append(toks)
            neg_labels.append(labs)

    print(f" Hard negatives : {n_med} médicaments + {n_quaero if quaero_terms else 0} QUAERO = {len(neg_tokens)} exemples")
    return neg_tokens, neg_labels


def tokenize_and_align(examples, tokenizer):
    """Tokenize et aligne les labels avec les sous-tokens."""
    tokenized = tokenizer(
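For reference, an editorial sketch (not part of the diff) of what the two augmentation strategies above produce; the replacement names are hypothetical draws from the gazetteers, and labels are shown symbolically rather than as LABEL2ID integers:

# augment_name_substitution keeps the labels and swaps only the PER span:
#   in : ["Monsieur", "Jean", "DUPONT"]     ["O", "B-PER", "I-PER"]
#   out: ["Monsieur", "Lucas", "MOREAU"]    ["O", "B-PER", "I-PER"]
# (first token of a multi-token PER gets a first name, later tokens a surname,
#  with the original casing preserved by _match_case)
#
# add_hard_negatives emits fully-O synthetic sentences, e.g.:
#   "Traitement par DOLIPRANE 1000mg".split()  ->  ["O", "O", "O", "O"]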
@@ -218,11 +504,28 @@ def main():
                        default=Path(__file__).parent.parent / "models" / "camembert-bio-deid",
                        help="Répertoire de sortie du modèle")
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--batch-size", type=int, default=16,
                        help="Batch size effectif (via gradient accumulation si nécessaire)")
    parser.add_argument("--gpu-batch-size", type=int, default=8,
                        help="Batch size réel sur GPU (le reste via gradient accumulation)")
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--val-split", type=float, default=0.15, help="Fraction pour validation")
    parser.add_argument("--no-augment", action="store_true",
                        help="Désactiver l'augmentation de données (substitution noms + hard negatives)")
    parser.add_argument("--num-augmented", type=int, default=3,
                        help="Nombre de copies augmentées par exemple PER (défaut: 3)")
    parser.add_argument("--num-hard-negatives", type=int, default=600,
                        help="Nombre d'exemples hard negative médicaments (défaut: 600)")
    parser.add_argument("--augment-seed", type=int, default=42,
                        help="Seed pour la reproductibilité de l'augmentation")
    args = parser.parse_args()

    # Chemins des gazetteers
    project_root = Path(__file__).parent.parent
    prenoms_file = project_root / "data" / "insee" / "prenoms_france.txt"
    noms_file = project_root / "data" / "insee" / "noms_famille_france.txt"
    medications_file = project_root / "data" / "bdpm" / "medication_names.txt"

    # Charger les données
    print(f"Chargement des données depuis {args.data_dir}...")
    raw_data = load_bio_files(args.data_dir)
@@ -234,10 +537,59 @@ def main():
        print("ERREUR: pas assez de données. Lancez d'abord export_silver_annotations.py")
        sys.exit(1)

    # Split train/val
    # Split train/val AVANT augmentation (on n'augmente que le train)
    dataset = Dataset.from_dict(raw_data)
    split = dataset.train_test_split(test_size=args.val_split, seed=42)
    datasets = DatasetDict({"train": split["train"], "validation": split["test"]})

    # Récupérer les listes Python du split train pour l'augmentation
    train_tokens = list(split["train"]["tokens"])
    train_labels = list(split["train"]["ner_tags"])
    val_tokens = list(split["test"]["tokens"])
    val_labels = list(split["test"]["ner_tags"])

    n_train_orig = len(train_tokens)

    # Augmentation (train uniquement)
    if not args.no_augment:
        print(f"\nAugmentation des données d'entraînement (seed={args.augment_seed})...")

        # 1. Substitution de noms
        aug_tok, aug_lab = augment_name_substitution(
            train_tokens, train_labels,
            prenoms_file=prenoms_file,
            noms_file=noms_file,
            num_augmented=args.num_augmented,
            rng_seed=args.augment_seed,
        )
        train_tokens.extend(aug_tok)
        train_labels.extend(aug_lab)

        # 2. Hard negatives médicaments
        neg_tok, neg_lab = add_hard_negatives(
            train_tokens, train_labels,
            medications_file=medications_file,
            n_examples=args.num_hard_negatives,
            rng_seed=args.augment_seed,
        )
        train_tokens.extend(neg_tok)
        train_labels.extend(neg_lab)

        print(f"\n Résumé augmentation :")
        print(f" Train original : {n_train_orig} exemples")
        print(f" + Substitution noms : {len(aug_tok)} exemples")
        print(f" + Hard negatives : {len(neg_tok)} exemples")
        print(f" = Train total : {len(train_tokens)} exemples")
        print(f" Validation (non augmentée) : {len(val_tokens)} exemples")
    else:
        print("\n Augmentation désactivée (--no-augment)")

    # Reconstruire les datasets
    train_data = {"tokens": train_tokens, "ner_tags": train_labels}
    val_data = {"tokens": val_tokens, "ner_tags": val_labels}
    datasets = DatasetDict({
        "train": Dataset.from_dict(train_data),
        "validation": Dataset.from_dict(val_data),
    })
    print(f" Train: {len(datasets['train'])}, Validation: {len(datasets['validation'])}")

    # Tokenizer + modèle
@@ -281,17 +633,24 @@ def main():
            "f1": results["overall_f1"],
        }

    # Class weights pour contrer le déséquilibre 97% O
    # Class weights pour contrer le déséquilibre 97% O (calculé sur le train augmenté)
    print("\nCalcul des poids de classe...")
    weights = compute_class_weights(raw_data, len(LABEL_LIST))
    weights = compute_class_weights(train_data, len(LABEL_LIST))

    # Training
    args.output_dir.mkdir(parents=True, exist_ok=True)
    # Gradient accumulation : si batch_size demandé > batch effectif GPU
    grad_accum = max(1, args.batch_size // args.gpu_batch_size)
    effective_batch = args.gpu_batch_size * grad_accum
    if effective_batch != args.batch_size:
        print(f" GPU batch={args.gpu_batch_size}, grad_accum={grad_accum} → effectif={effective_batch}")

    training_args = TrainingArguments(
        output_dir=str(args.output_dir),
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size * 2,
        per_device_train_batch_size=args.gpu_batch_size,
        per_device_eval_batch_size=args.gpu_batch_size * 2,
        gradient_accumulation_steps=grad_accum,
        learning_rate=args.lr,
        weight_decay=0.01,
        warmup_ratio=0.1,

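To make the batch-size arithmetic above concrete, a small worked example using the argument defaults (--batch-size 16, --gpu-batch-size 8):

batch_size, gpu_batch_size = 16, 8
grad_accum = max(1, batch_size // gpu_batch_size)    # 2 accumulation steps
effective_batch = gpu_batch_size * grad_accum        # 16, matches --batch-size
print(grad_accum, effective_batch)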