#!/usr/bin/env python3 """ Fine-tune CamemBERT-bio pour la désidentification clinique française. ===================================================================== Entraîne almanach/camembert-bio-base sur les annotations silver/gold exportées par export_silver_annotations.py. Usage: python scripts/finetune_camembert_bio.py [--epochs 5] [--batch-size 8] [--lr 2e-5] python scripts/finetune_camembert_bio.py --no-augment # Sans augmentation Prérequis: pip install transformers datasets seqeval accelerate Export ONNX post-training: python scripts/export_onnx.py """ import sys import json import subprocess import argparse import random from pathlib import Path from datetime import date from typing import Dict, List, Tuple from collections import Counter import numpy as np import torch from torch import nn # Vérifier les dépendances try: from transformers import ( AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer, DataCollatorForTokenClassification, ) from datasets import Dataset, DatasetDict import evaluate except ImportError as e: print(f"Dépendance manquante: {e}") print("Installez: pip install transformers datasets seqeval accelerate") sys.exit(1) # Labels BIO pour la désidentification LABEL_LIST = [ "O", "B-PER", "I-PER", "B-TEL", "I-TEL", "B-EMAIL", "I-EMAIL", "B-NIR", "I-NIR", "B-IPP", "I-IPP", "B-NDA", "I-NDA", "B-RPPS", "I-RPPS", "B-DATE_NAISSANCE", "I-DATE_NAISSANCE", "B-ADRESSE", "I-ADRESSE", "B-ZIP", "I-ZIP", "B-VILLE", "I-VILLE", "B-HOPITAL", "I-HOPITAL", "B-IBAN", "I-IBAN", "B-AGE", "I-AGE", ] LABEL2ID = {l: i for i, l in enumerate(LABEL_LIST)} ID2LABEL = {i: l for l, i in LABEL2ID.items()} MODEL_NAME = "almanach/camembert-bio-base" def load_bio_files(data_dir: Path, window_size: int = 200, stride: int = 100) -> Dict[str, List]: """Charge les fichiers .bio et découpe en fenêtres glissantes. Les documents cliniques sont très longs. On les découpe en fenêtres de ~window_size tokens avec un chevauchement de stride. On ne garde que les fenêtres contenant au moins une entité (pour l'équilibre des classes). """ tokens_list: List[List[str]] = [] labels_list: List[List[int]] = [] for bio_file in sorted(data_dir.glob("*.bio")): text = bio_file.read_text(encoding="utf-8") # Charger tous les tokens du document all_tokens: List[str] = [] all_labels: List[int] = [] for line in text.splitlines(): line = line.strip() if not line: continue parts = line.split("\t") if len(parts) != 2: continue token, label = parts label_id = LABEL2ID.get(label, LABEL2ID["O"]) all_tokens.append(token) all_labels.append(label_id) if not all_tokens: continue # Découper en fenêtres glissantes n = len(all_tokens) for start in range(0, n, stride): end = min(start + window_size, n) chunk_tokens = all_tokens[start:end] chunk_labels = all_labels[start:end] # Corriger les I- en début de fenêtre → B- if chunk_labels and chunk_labels[0] > 0: lbl_name = LABEL_LIST[chunk_labels[0]] if lbl_name.startswith("I-"): b_name = "B-" + lbl_name[2:] if b_name in LABEL2ID: chunk_labels[0] = LABEL2ID[b_name] # Garder les fenêtres avec entités + quelques fenêtres "O" (10%) has_entities = any(l != 0 for l in chunk_labels) if has_entities or (start % (stride * 10) == 0): tokens_list.append(chunk_tokens) labels_list.append(chunk_labels) if end >= n: break return {"tokens": tokens_list, "ner_tags": labels_list} # ───────────────────────────────────────────────────────────────────────────── # Data Augmentation : substitution de noms # ───────────────────────────────────────────────────────────────────────────── def _load_gazetteer(filepath: Path, max_entries: int = 10000) -> List[str]: """Charge un fichier de noms (un par ligne), en ignorant les lignes vides.""" if not filepath.exists(): print(f" [WARN] Fichier gazetteer absent : {filepath}") return [] names = [] with open(filepath, encoding="utf-8", errors="replace") as f: for line in f: name = line.strip() if name and len(name) >= 2: names.append(name) if len(names) >= max_entries: break return names def _match_case(replacement: str, original: str) -> str: """Applique la casse de l'original au remplacement.""" if original.isupper(): return replacement.upper() if original.istitle(): return replacement.title() if original.islower(): return replacement.lower() return replacement def _extract_per_entities(tokens: List[str], labels: List[int]) -> List[Tuple[int, int]]: """Extrait les spans (start, end exclusive) des entités PER dans une séquence BIO. Retourne une liste de tuples (start_idx, end_idx). """ entities = [] i = 0 b_per_id = LABEL2ID.get("B-PER", -1) i_per_id = LABEL2ID.get("I-PER", -1) while i < len(labels): if labels[i] == b_per_id: start = i i += 1 while i < len(labels) and labels[i] == i_per_id: i += 1 entities.append((start, i)) else: i += 1 return entities def augment_name_substitution( tokens_list: List[List[str]], labels_list: List[List[int]], prenoms_file: Path, noms_file: Path, num_augmented: int = 3, rng_seed: int = 42, ) -> Tuple[List[List[str]], List[List[int]]]: """Augmente les données en substituant les entités PER par des noms aléatoires. Pour chaque exemple contenant au moins une entité PER, crée num_augmented copies avec des noms de remplacement tirés des fichiers gazetteers INSEE. La casse (UPPER, Title, lower) est conservée. Retourne (aug_tokens_list, aug_labels_list) — les copies augmentées SANS les originaux (l'appelant les concatène). """ rng = random.Random(rng_seed) prenoms = _load_gazetteer(prenoms_file) noms = _load_gazetteer(noms_file) if not prenoms: print(" [WARN] Aucun prénom chargé — augmentation désactivée") return [], [] if not noms: print(" [WARN] Aucun nom de famille chargé — utilisation des prénoms uniquement") noms = prenoms # fallback gracieux print(f" Gazetteers : {len(prenoms)} prénoms, {len(noms)} noms de famille") b_per_id = LABEL2ID["B-PER"] aug_tokens: List[List[str]] = [] aug_labels: List[List[int]] = [] n_augmented_examples = 0 for tokens, labels in zip(tokens_list, labels_list): entities = _extract_per_entities(tokens, labels) if not entities: continue for _ in range(num_augmented): new_tokens = list(tokens) new_labels = list(labels) for start, end in entities: span_tokens = tokens[start:end] span_len = end - start # Générer un remplacement de même longueur en tokens replacements = [] for j, orig_tok in enumerate(span_tokens): if "-" in orig_tok: # Nom composé : JEAN-PIERRE → MARIE-CLAIRE parts = orig_tok.split("-") new_parts = [ _match_case(rng.choice(prenoms), p) for p in parts ] replacements.append("-".join(new_parts)) elif j == 0 and span_len >= 2: # Premier token d'une entité multi-tokens → prénom replacements.append(_match_case(rng.choice(prenoms), orig_tok)) elif span_len == 1: # Entité mono-token → 50/50 prénom ou nom pool = prenoms if rng.random() < 0.5 else noms replacements.append(_match_case(rng.choice(pool), orig_tok)) else: # Tokens suivants → nom de famille replacements.append(_match_case(rng.choice(noms), orig_tok)) for j, idx in enumerate(range(start, end)): new_tokens[idx] = replacements[j] aug_tokens.append(new_tokens) aug_labels.append(new_labels) n_augmented_examples += 1 print(f" Augmentation noms : {n_augmented_examples} exemples générés " f"(x{num_augmented} par exemple avec entités PER)") return aug_tokens, aug_labels # ───────────────────────────────────────────────────────────────────────────── # Hard Negative Mining : médicaments # ───────────────────────────────────────────────────────────────────────────── # Liste de secours si aucun fichier de médicaments n'est trouvé _FALLBACK_MEDICATIONS = [ "DOLIPRANE", "EFFERALGAN", "DAFALGAN", "IBUPROFENE", "ADVIL", "NUROFEN", "SPASFON", "SMECTA", "GAVISCON", "MOPRAL", "OMEPRAZOLE", "PANTOPRAZOLE", "LANSOPRAZOLE", "INEXIUM", "AUGMENTIN", "AMOXICILLINE", "CLAMOXYL", "METFORMINE", "GLUCOPHAGE", "DIAMICRON", "LANTUS", "NOVORAPID", "HUMALOG", "LEVOTHYROX", "EUTHYROX", "KARDEGIC", "ASPEGIC", "PLAVIX", "CLOPIDOGREL", "ELIQUIS", "XARELTO", "PRADAXA", "COUMADINE", "PREVISCAN", "LOVENOX", "RAMIPRIL", "TRIATEC", "PERINDOPRIL", "COVERSYL", "AMLODIPINE", "AMLOR", "BISOPROLOL", "ATENOLOL", "METOPROLOL", "FUROSEMIDE", "LASILIX", "TAHOR", "CRESTOR", "PRAVASTATINE", "SIMVASTATINE", ] def _load_medications(medications_file: Path) -> List[str]: """Charge une liste de noms de médicaments depuis le fichier BDPM ou le fallback.""" if medications_file.exists(): meds = [] with open(medications_file, encoding="utf-8", errors="replace") as f: for line in f: name = line.strip() if name and len(name) >= 3: meds.append(name) if meds: print(f" Médicaments chargés : {len(meds)} depuis {medications_file.name}") return meds print(f" [WARN] Fichier médicaments absent ou vide — utilisation de {len(_FALLBACK_MEDICATIONS)} médicaments courants") return list(_FALLBACK_MEDICATIONS) # Templates de phrases médicales où les médicaments ne sont PAS des noms de personnes _MEDICATION_TEMPLATES = [ "Traitement par {med} {dose}", "Prescription de {med} {forme}", "Relais par {med}", "Introduction de {med} {dose}", "Poursuite du traitement par {med}", "Arrêt du {med} en raison des effets secondaires", "Allergie connue au {med}", "Switch de {med} vers {med2}", "Patient sous {med} depuis 3 mois", "Posologie de {med} augmentée à {dose}", "Administration de {med} en IV", "Relais per os par {med} {dose}", ] _DOSE_VARIANTS = [ "500mg", "1000mg", "200mg", "100mg", "50mg", "250mg", "75mg", "20mg", "40mg", "10mg", "5mg", "1g", "2g", "80mg", "160mg", ] _FORME_VARIANTS = [ "comprimé", "gélule", "sachet", "injectable", "sirop", "pommade", "collyre", "suppositoire", "patch", "cp", ] def _load_quaero_entities(quaero_dir: Path) -> List[str]: """Charge les entités QUAERO (CHEM, DISO, PROC, ANAT) comme termes médicaux.""" entities = [] for filename in ["quaero_chem_entities.txt", "quaero_diso_entities.txt", "quaero_procedures.txt", "quaero_anatomie.txt"]: fpath = quaero_dir / filename if fpath.exists(): for line in fpath.read_text(encoding="utf-8").splitlines(): name = line.strip() if name and len(name) >= 3: entities.append(name) return entities # Templates supplémentaires pour termes médicaux QUAERO (maladies, procédures, anatomie) _MEDICAL_CONTEXT_TEMPLATES = [ "Diagnostic de {term} confirmé par imagerie", "Antécédent de {term}", "Suspicion de {term} évoquée", "Bilan pour {term}", "Prise en charge de {term} par le service", "Examen de {term} sans anomalie", "Surveillance de {term} en cours", "Indication opératoire pour {term}", ] def add_hard_negatives( tokens_list: List[List[str]], labels_list: List[List[int]], medications_file: Path, n_examples: int = 600, rng_seed: int = 42, ) -> Tuple[List[List[str]], List[List[int]]]: """Crée des exemples synthétiques de contextes médicaux (tout label O). Utilise les noms de médicaments (BDPM) et les termes médicaux QUAERO (CHEM, DISO, PROC, ANAT) pour enseigner au modèle que ces termes ne sont PAS des noms de personnes. Retourne (neg_tokens_list, neg_labels_list) — les négatifs à ajouter. """ rng = random.Random(rng_seed) medications = _load_medications(medications_file) # Charger aussi les entités QUAERO quaero_dir = medications_file.parent.parent / "quaero" quaero_terms = _load_quaero_entities(quaero_dir) if quaero_terms: print(f" Termes médicaux QUAERO : {len(quaero_terms)} (CHEM+DISO+PROC+ANAT)") if not medications: return [], [] neg_tokens: List[List[str]] = [] neg_labels: List[List[int]] = [] # 1. Hard negatives médicaments (BDPM) n_med = n_examples * 2 // 3 # 2/3 médicaments for i in range(n_med): template = rng.choice(_MEDICATION_TEMPLATES) med = rng.choice(medications) med2 = rng.choice(medications) dose = rng.choice(_DOSE_VARIANTS) forme = rng.choice(_FORME_VARIANTS) sentence = template.format(med=med, med2=med2, dose=dose, forme=forme) toks = sentence.split() labs = [LABEL2ID["O"]] * len(toks) neg_tokens.append(toks) neg_labels.append(labs) # 2. Hard negatives QUAERO (maladies, procédures, anatomie) n_quaero = n_examples - n_med # 1/3 termes QUAERO if quaero_terms: for i in range(n_quaero): template = rng.choice(_MEDICAL_CONTEXT_TEMPLATES) term = rng.choice(quaero_terms) sentence = template.format(term=term) toks = sentence.split() labs = [LABEL2ID["O"]] * len(toks) neg_tokens.append(toks) neg_labels.append(labs) print(f" Hard negatives : {n_med} médicaments + {n_quaero if quaero_terms else 0} QUAERO = {len(neg_tokens)} exemples") return neg_tokens, neg_labels def tokenize_and_align(examples, tokenizer): """Tokenize et aligne les labels avec les sous-tokens.""" tokenized = tokenizer( examples["tokens"], truncation=True, is_split_into_words=True, max_length=512, padding=False, ) all_labels = [] for i, labels in enumerate(examples["ner_tags"]): word_ids = tokenized.word_ids(batch_index=i) label_ids = [] prev_word_id = None for word_id in word_ids: if word_id is None: label_ids.append(-100) elif word_id != prev_word_id: label_ids.append(labels[word_id]) else: # Sous-token : I- si le premier est B-, sinon même label orig = labels[word_id] if orig > 0 and LABEL_LIST[orig].startswith("B-"): # Convertir B- en I- i_label = LABEL_LIST[orig].replace("B-", "I-") label_ids.append(LABEL2ID.get(i_label, orig)) else: label_ids.append(orig) prev_word_id = word_id all_labels.append(label_ids) tokenized["labels"] = all_labels return tokenized class WeightedNERTrainer(Trainer): """Trainer avec poids de classe pour contrer le déséquilibre O vs entités.""" def __init__(self, class_weights=None, **kwargs): super().__init__(**kwargs) if class_weights is not None: self.class_weights = class_weights.to(self.args.device) else: self.class_weights = None def compute_loss(self, model, inputs, return_outputs=False, **kwargs): labels = inputs.pop("labels") outputs = model(**inputs) logits = outputs.logits if self.class_weights is not None: loss_fct = nn.CrossEntropyLoss( weight=self.class_weights, ignore_index=-100, label_smoothing=0.1, ) else: loss_fct = nn.CrossEntropyLoss(ignore_index=-100, label_smoothing=0.1) loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1)) return (loss, outputs) if return_outputs else loss def compute_class_weights(raw_data: Dict, num_labels: int, max_weight: float = 10.0) -> torch.FloatTensor: """Calcule les poids inversement proportionnels à la fréquence, cappés après normalisation.""" counts = Counter() for labels in raw_data["ner_tags"]: for l in labels: counts[l] += 1 total = sum(counts.values()) weights = torch.ones(num_labels) for label_id, count in counts.items(): if count > 0: weights[label_id] = total / (num_labels * count) # Normaliser : O=1.0 if weights[0] > 0: scale = 1.0 / weights[0] weights *= scale # Capper APRÈS normalisation pour limiter le déséquilibre weights = torch.clamp(weights, max=max_weight) print(f" Class weights (O={weights[0]:.1f}, non-O moyen={weights[1:].mean():.1f}, max={weights[1:].max():.1f})") return weights def main(): parser = argparse.ArgumentParser(description="Fine-tune CamemBERT-bio pour désidentification") parser.add_argument("--data-dir", type=Path, default=Path(__file__).parent.parent / "data" / "silver_annotations", help="Répertoire des fichiers .bio") parser.add_argument("--output-dir", type=Path, default=Path(__file__).parent.parent / "models" / "camembert-bio-deid", help="Répertoire de sortie du modèle") parser.add_argument("--epochs", type=int, default=5) parser.add_argument("--batch-size", type=int, default=16, help="Batch size effectif (via gradient accumulation si nécessaire)") parser.add_argument("--gpu-batch-size", type=int, default=8, help="Batch size réel sur GPU (le reste via gradient accumulation)") parser.add_argument("--lr", type=float, default=2e-5) parser.add_argument("--val-split", type=float, default=0.15, help="Fraction pour validation") parser.add_argument("--no-augment", action="store_true", help="Désactiver l'augmentation de données (substitution noms + hard negatives)") parser.add_argument("--num-augmented", type=int, default=3, help="Nombre de copies augmentées par exemple PER (défaut: 3)") parser.add_argument("--num-hard-negatives", type=int, default=600, help="Nombre d'exemples hard negative médicaments (défaut: 600)") parser.add_argument("--augment-seed", type=int, default=42, help="Seed pour la reproductibilité de l'augmentation") args = parser.parse_args() # Chemins des gazetteers project_root = Path(__file__).parent.parent prenoms_file = project_root / "data" / "insee" / "prenoms_france.txt" noms_file = project_root / "data" / "insee" / "noms_famille_france.txt" medications_file = project_root / "data" / "bdpm" / "medication_names.txt" # Charger les données print(f"Chargement des données depuis {args.data_dir}...") raw_data = load_bio_files(args.data_dir) n_sentences = len(raw_data["tokens"]) n_entities = sum(1 for labels in raw_data["ner_tags"] for l in labels if l != 0) print(f" {n_sentences} phrases, {n_entities} entités annotées") if n_sentences < 10: print("ERREUR: pas assez de données. Lancez d'abord export_silver_annotations.py") sys.exit(1) # Split train/val AVANT augmentation (on n'augmente que le train) dataset = Dataset.from_dict(raw_data) split = dataset.train_test_split(test_size=args.val_split, seed=42) # Récupérer les listes Python du split train pour l'augmentation train_tokens = list(split["train"]["tokens"]) train_labels = list(split["train"]["ner_tags"]) val_tokens = list(split["test"]["tokens"]) val_labels = list(split["test"]["ner_tags"]) n_train_orig = len(train_tokens) # Augmentation (train uniquement) if not args.no_augment: print(f"\nAugmentation des données d'entraînement (seed={args.augment_seed})...") # 1. Substitution de noms aug_tok, aug_lab = augment_name_substitution( train_tokens, train_labels, prenoms_file=prenoms_file, noms_file=noms_file, num_augmented=args.num_augmented, rng_seed=args.augment_seed, ) train_tokens.extend(aug_tok) train_labels.extend(aug_lab) # 2. Hard negatives médicaments neg_tok, neg_lab = add_hard_negatives( train_tokens, train_labels, medications_file=medications_file, n_examples=args.num_hard_negatives, rng_seed=args.augment_seed, ) train_tokens.extend(neg_tok) train_labels.extend(neg_lab) print(f"\n Résumé augmentation :") print(f" Train original : {n_train_orig} exemples") print(f" + Substitution noms : {len(aug_tok)} exemples") print(f" + Hard negatives : {len(neg_tok)} exemples") print(f" = Train total : {len(train_tokens)} exemples") print(f" Validation (non augmentée) : {len(val_tokens)} exemples") else: print("\n Augmentation désactivée (--no-augment)") # Reconstruire les datasets train_data = {"tokens": train_tokens, "ner_tags": train_labels} val_data = {"tokens": val_tokens, "ner_tags": val_labels} datasets = DatasetDict({ "train": Dataset.from_dict(train_data), "validation": Dataset.from_dict(val_data), }) print(f" Train: {len(datasets['train'])}, Validation: {len(datasets['validation'])}") # Tokenizer + modèle print(f"\nChargement du modèle {MODEL_NAME}...") tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) model = AutoModelForTokenClassification.from_pretrained( MODEL_NAME, num_labels=len(LABEL_LIST), id2label=ID2LABEL, label2id=LABEL2ID, ) # Tokenization tokenized = datasets.map( lambda ex: tokenize_and_align(ex, tokenizer), batched=True, remove_columns=datasets["train"].column_names, ) # Métriques seqeval = evaluate.load("seqeval") def compute_metrics(eval_pred): logits, labels = eval_pred predictions = np.argmax(logits, axis=-1) true_labels = [] true_preds = [] for pred_seq, label_seq in zip(predictions, labels): t_labels = [] t_preds = [] for p, l in zip(pred_seq, label_seq): if l != -100: t_labels.append(LABEL_LIST[l]) t_preds.append(LABEL_LIST[p]) true_labels.append(t_labels) true_preds.append(t_preds) results = seqeval.compute(predictions=true_preds, references=true_labels) return { "precision": results["overall_precision"], "recall": results["overall_recall"], "f1": results["overall_f1"], } # Class weights pour contrer le déséquilibre 97% O (calculé sur le train augmenté) print("\nCalcul des poids de classe...") weights = compute_class_weights(train_data, len(LABEL_LIST)) # Training args.output_dir.mkdir(parents=True, exist_ok=True) # Gradient accumulation : si batch_size demandé > batch effectif GPU grad_accum = max(1, args.batch_size // args.gpu_batch_size) effective_batch = args.gpu_batch_size * grad_accum if effective_batch != args.batch_size: print(f" GPU batch={args.gpu_batch_size}, grad_accum={grad_accum} → effectif={effective_batch}") training_args = TrainingArguments( output_dir=str(args.output_dir), num_train_epochs=args.epochs, per_device_train_batch_size=args.gpu_batch_size, per_device_eval_batch_size=args.gpu_batch_size * 2, gradient_accumulation_steps=grad_accum, learning_rate=args.lr, weight_decay=0.01, warmup_ratio=0.1, eval_strategy="epoch", save_strategy="epoch", load_best_model_at_end=True, metric_for_best_model="f1", logging_steps=50, fp16=True, # GPU training avec mixed precision report_to="none", save_total_limit=2, ) data_collator = DataCollatorForTokenClassification(tokenizer) trainer = WeightedNERTrainer( class_weights=weights, model=model, args=training_args, train_dataset=tokenized["train"], eval_dataset=tokenized["validation"], data_collator=data_collator, compute_metrics=compute_metrics, tokenizer=tokenizer, ) print(f"\nDémarrage du fine-tuning ({args.epochs} epochs, batch={args.batch_size}, lr={args.lr})...") trainer.train() # Sauvegarder trainer.save_model(str(args.output_dir / "best")) tokenizer.save_pretrained(str(args.output_dir / "best")) print(f"\nModèle sauvegardé: {args.output_dir / 'best'}") # Évaluation finale results = trainer.evaluate() print(f"\nRésultats finaux:") print(f" Precision: {results['eval_precision']:.4f}") print(f" Recall: {results['eval_recall']:.4f}") print(f" F1: {results['eval_f1']:.4f}") # ── Export ONNX automatique ────────────────────────────────────────────── best_dir = args.output_dir / "best" onnx_dir = args.output_dir / "onnx" onnx_export_ok = False try: print(f"\nExport ONNX automatique...") print(f" Source : {best_dir}") print(f" Destination : {onnx_dir}") result = subprocess.run( [ sys.executable, "-m", "optimum.exporters.onnx", "--model", str(best_dir), "--task", "token-classification", str(onnx_dir), ], capture_output=True, text=True, timeout=600, ) if result.returncode == 0: onnx_export_ok = True print(f" Export ONNX réussi → {onnx_dir}") else: print(f" [ERREUR] Export ONNX échoué (code {result.returncode})") if result.stderr: # Afficher les dernières lignes d'erreur for line in result.stderr.strip().splitlines()[-10:]: print(f" {line}") print(f"\n Pour exporter manuellement :") print(f" python -m optimum.exporters.onnx --model {best_dir} --task token-classification {onnx_dir}") except FileNotFoundError: print(f" [WARN] optimum non installé — export ONNX ignoré") print(f" Pour exporter manuellement :") print(f" pip install optimum[exporters]") print(f" python -m optimum.exporters.onnx --model {best_dir} --task token-classification {onnx_dir}") except subprocess.TimeoutExpired: print(f" [ERREUR] Export ONNX timeout (>600s)") print(f" Pour exporter manuellement :") print(f" python -m optimum.exporters.onnx --model {best_dir} --task token-classification {onnx_dir}") except Exception as e: print(f" [ERREUR] Export ONNX inattendu : {e}") print(f" Pour exporter manuellement :") print(f" python -m optimum.exporters.onnx --model {best_dir} --task token-classification {onnx_dir}") # ── Mise à jour VERSION.json ───────────────────────────────────────────── version_file = args.output_dir / "VERSION.json" try: # Compter les documents d'entraînement (.bio files) n_bio_files = len(list(args.data_dir.glob("*.bio"))) # Déterminer le numéro de version if version_file.exists(): version_data = json.loads(version_file.read_text(encoding="utf-8")) else: version_data = { "model": "camembert-bio-deid", "base_model": MODEL_NAME, "versions": {}, "directories": {}, } # Incrémenter la version existing_versions = [ k for k in version_data.get("versions", {}).keys() if k.startswith("v") and k[1:].isdigit() ] if existing_versions: max_v = max(int(k[1:]) for k in existing_versions) new_version = f"v{max_v + 1}" else: new_version = "v1" # Trouver le best checkpoint (dernier sauvegardé par Trainer) best_checkpoint = None checkpoints = sorted(args.output_dir.glob("checkpoint-*")) if checkpoints: best_checkpoint = checkpoints[-1].name # Construire l'entrée de version version_entry = { "date": date.today().isoformat(), "training_docs": n_bio_files, "training_examples": len(train_tokens), "epochs": args.epochs, "batch_size": args.batch_size, "learning_rate": args.lr, "f1": round(results["eval_f1"], 4), "recall": round(results["eval_recall"], 4), "precision": round(results["eval_precision"], 4), "onnx_exported": onnx_export_ok, } if best_checkpoint: version_entry["best_checkpoint"] = best_checkpoint version_data["current_version"] = new_version version_data["versions"][new_version] = version_entry version_data["directories"] = { "onnx": f"Modèle ONNX actif ({new_version}) — utilisé en inférence CPU", f"best": f"Modèle PyTorch {new_version} (pour ré-export ONNX si besoin)", } version_file.write_text( json.dumps(version_data, indent=2, ensure_ascii=False) + "\n", encoding="utf-8", ) print(f"\n VERSION.json mis à jour → {new_version} (F1={results['eval_f1']:.4f})") except Exception as e: print(f"\n [WARN] Impossible de mettre à jour VERSION.json : {e}") if __name__ == "__main__": main()