#!/usr/bin/env python3
"""Phase 3 — Fine-tune the embedding model for the medical RAG.

Fine-tunes sentence-camembert-large on PMSI/CIM-10 triplets to improve
FAISS retrieval quality in the T2A pipeline.

Prerequisites:
    - python scripts/09_build_embedding_triplets.py
    - data/datasets/triplets.jsonl (41K+ triplets)

Usage:
    python scripts/10_train_embedding.py [--epochs 3] [--batch 32] [--lr 2e-5] [--device cpu]

    --device cpu  : train on CPU (when the GPU is busy with the LoRA run)
    --device cuda : train on GPU (faster, ~30 min)

Hardware:
    - CPU : ~2-3h (Ryzen 9 9950X)
    - GPU : ~30 min (RTX 5070, if available)
    - RAM : ~4 GB (model ~1.3 GB + data)
"""

import argparse
import json
import random
from pathlib import Path

# Deterministic shuffling for the train/eval split.
random.seed(42)

BASE = Path(__file__).resolve().parent.parent
TRIPLETS_PATH = BASE / "data" / "datasets" / "triplets.jsonl"
OUTPUT_DIR = BASE / "models" / "sentence-camembert-pmsi"

# Base model (same one used by the T2A pipeline).
BASE_MODEL = "dangvantuan/sentence-camembert-large"


def load_triplets(path: Path, eval_ratio: float = 0.1):
    """Load (anchor, positive, negative) triplets and split them train/eval.

    Args:
        path: JSONL file where each line has "anchor", "positive" and
            "negative" string fields.
        eval_ratio: fraction of triplets held out for evaluation.

    Returns:
        (train_ds, eval_ds) tuple of ``datasets.Dataset`` objects, each with
        columns "anchor", "positive", "negative".

    Raises:
        KeyError: if a JSON line is missing one of the three fields.
    """
    from datasets import Dataset

    anchors, positives, negatives = [], [], []
    # FIX: force UTF-8 — the triplets contain accented French medical text,
    # and the platform default encoding (e.g. cp1252 on Windows) would break.
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            # FIX: tolerate blank lines (e.g. trailing newline at EOF);
            # json.loads("") would raise otherwise.
            if not line:
                continue
            t = json.loads(line)
            anchors.append(t["anchor"])
            positives.append(t["positive"])
            negatives.append(t["negative"])

    total = len(anchors)
    print(f"Triplets chargés : {total}")

    # Deterministic shuffle (module-level seed), then split.
    indices = list(range(total))
    random.shuffle(indices)
    split = int(total * (1 - eval_ratio))
    train_idx = indices[:split]
    eval_idx = indices[split:]

    train_ds = Dataset.from_dict({
        "anchor": [anchors[i] for i in train_idx],
        "positive": [positives[i] for i in train_idx],
        "negative": [negatives[i] for i in train_idx],
    })
    eval_ds = Dataset.from_dict({
        "anchor": [anchors[i] for i in eval_idx],
        "positive": [positives[i] for i in eval_idx],
        "negative": [negatives[i] for i in eval_idx],
    })
    print(f"  Train : {len(train_ds)}")
    print(f"  Eval  : {len(eval_ds)}")
    return train_ds, eval_ds


def main():
    """CLI entry point: parse args, fine-tune, evaluate and save the model."""
    parser = argparse.ArgumentParser(description="Fine-tuning embedding PMSI")
    parser.add_argument("--model", default=BASE_MODEL,
                        help="Modèle de base sentence-transformers")
    parser.add_argument("--triplets", type=Path, default=TRIPLETS_PATH,
                        help="Fichier de triplets JSONL")
    parser.add_argument("--output", type=Path, default=OUTPUT_DIR,
                        help="Répertoire de sortie")
    # Training hyper-parameters.
    parser.add_argument("--epochs", type=int, default=3, help="Nombre d'epochs")
    parser.add_argument("--batch", type=int, default=32, help="Batch size")
    parser.add_argument("--lr", type=float, default=2e-5, help="Learning rate")
    parser.add_argument("--warmup-ratio", type=float, default=0.1, help="Warmup ratio")
    parser.add_argument("--max-seq-length", type=int, default=256,
                        help="Longueur max des séquences")
    # Loss function.
    parser.add_argument("--loss", choices=["mnrl", "triplet"], default="mnrl",
                        help="Fonction de loss (mnrl=MultipleNegativesRankingLoss, triplet=TripletLoss)")
    # Device selection.
    parser.add_argument("--device", default=None,
                        help="Device (cuda/cpu). Auto-détection si non spécifié.")
    args = parser.parse_args()

    print("=" * 60)
    print("Fine-tuning embedding pour RAG PMSI")
    print("=" * 60)

    # Fail early if the triplets file has not been generated yet.
    if not args.triplets.exists():
        raise FileNotFoundError(
            f"Triplets non trouvés : {args.triplets}\n"
            f"Lancez d'abord : python scripts/09_build_embedding_triplets.py"
        )

    # Device auto-detection: use CUDA only if enough free VRAM remains
    # (the GPU may be occupied by a concurrent LoRA training run).
    import torch
    if args.device is None:
        if torch.cuda.is_available():
            # ~2 GB of free VRAM is enough for this model.
            free = torch.cuda.mem_get_info()[0] / 1024**3
            if free > 2.0:
                args.device = "cuda"
                print(f"GPU détecté : {torch.cuda.get_device_name(0)} ({free:.1f} Go libre)")
            else:
                args.device = "cpu"
                print(f"GPU occupé ({free:.1f} Go libre) — fallback CPU")
        else:
            args.device = "cpu"
            print("Pas de GPU — entraînement CPU")

    # Load the base model.
    from sentence_transformers import SentenceTransformer
    print(f"\nChargement du modèle : {args.model}")
    model = SentenceTransformer(args.model, device=args.device)
    model.max_seq_length = args.max_seq_length
    print(f"  Max seq length : {model.max_seq_length}")
    print(f"  Dimension embedding : {model.get_sentence_embedding_dimension()}")

    # Load the data.
    print()
    train_ds, eval_ds = load_triplets(args.triplets)

    # Loss: MNRL benefits from in-batch negatives on top of the hard negatives.
    if args.loss == "mnrl":
        from sentence_transformers.losses import MultipleNegativesRankingLoss
        loss = MultipleNegativesRankingLoss(model)
        print(f"\nLoss : MultipleNegativesRankingLoss (in-batch negatives + hard negatives)")
    else:
        from sentence_transformers.losses import TripletLoss
        loss = TripletLoss(model)
        print(f"\nLoss : TripletLoss")

    # Evaluator on the held-out split.
    from sentence_transformers.evaluation import TripletEvaluator
    evaluator = TripletEvaluator(
        anchors=eval_ds["anchor"],
        positives=eval_ds["positive"],
        negatives=eval_ds["negative"],
        name="pmsi-eval",
        batch_size=args.batch,
    )

    # Baseline evaluation (before fine-tuning) to quantify the improvement.
    print("\nÉvaluation baseline (avant fine-tuning)...")
    baseline = evaluator(model)
    print(f"  Accuracy baseline : {baseline.get('pmsi-eval_cosine_accuracy', 'N/A')}")

    # Training arguments.
    from sentence_transformers import SentenceTransformerTrainingArguments
    n_steps = len(train_ds) * args.epochs // args.batch
    eval_steps = max(n_steps // 10, 50)  # evaluate ~10 times during training
    save_steps = max(n_steps // 5, 100)  # save ~5 checkpoints
    training_args = SentenceTransformerTrainingArguments(
        output_dir=str(args.output / "checkpoints"),
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch,
        per_device_eval_batch_size=args.batch,
        learning_rate=args.lr,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type="cosine",
        fp16=(args.device == "cuda"),  # mixed precision only on GPU
        bf16=False,
        logging_steps=50,
        eval_strategy="steps",
        eval_steps=eval_steps,
        save_strategy="steps",
        save_steps=save_steps,
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="pmsi-eval_cosine_accuracy",
        greater_is_better=True,
        seed=42,
        report_to="aim",
    )

    print(f"\nConfiguration :")
    print(f"  Epochs        : {args.epochs}")
    print(f"  Batch size    : {args.batch}")
    print(f"  Learning rate : {args.lr}")
    print(f"  Steps estimés : ~{n_steps}")
    print(f"  Eval tous les : {eval_steps} steps")
    print(f"  Device        : {args.device}")

    # Trainer.
    from sentence_transformers import SentenceTransformerTrainer
    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        loss=loss,
        evaluator=evaluator,
    )

    # Train.
    print(f"\nDémarrage de l'entraînement...")
    trainer.train()

    # Save the final model.
    final_dir = args.output / "final"
    model.save_pretrained(str(final_dir))
    print(f"\nModèle sauvegardé : {final_dir}")

    # Final evaluation (after fine-tuning).
    print("\nÉvaluation finale (après fine-tuning)...")
    final_results = evaluator(model)
    final_acc = final_results.get("pmsi-eval_cosine_accuracy", 0)
    baseline_acc = baseline.get("pmsi-eval_cosine_accuracy", 0)

    print(f"\n{'=' * 60}")
    print(f"Résultats :")
    print(f"  Accuracy baseline : {baseline_acc:.4f}")
    print(f"  Accuracy finale   : {final_acc:.4f}")
    if baseline_acc > 0:
        delta = final_acc - baseline_acc
        pct = 100 * delta / baseline_acc
        print(f"  Amélioration      : {delta:+.4f} ({pct:+.1f}%)")
    print(f"\nModèle final : {final_dir}")
    print(f"{'=' * 60}")

    # Integration instructions for the T2A pipeline.
    print(f"""
Pour utiliser dans le pipeline T2A, modifier src/medical/rag_search.py :

    _embed_model = SentenceTransformer("{final_dir}", device=_device)

Puis reconstruire l'index FAISS avec le nouveau modèle :

    python -m src.medical.rag_index --rebuild
""")


if __name__ == "__main__":
    main()