feat: rééquilibrage dataset LoRA — raisonnement DIM vs mémorisation

Passe de 95/3/2 (lookups/raisonnement/règles) à ~31/49/20.
Dataset cible ~16K exemples denses (vs 66K de lookups avant).

Modifiés :
- 03_convert_cache.py : cache complet 1840 entrées (actuel + backup)
- 04_build_dataset.py : subsampling agressif (CIM-10 1.5K, CCAM 1.5K,
  CoCoA 2K) + sélection intelligente priorisant le raisonnement
- 12_generate_pipeline_examples.py : 3 templates (court + long + CPAM),
  cache actuel, cible ~2800 exemples

Créés :
- 13_generate_fascicule_reasoning.py : parsing 10 fascicules ATIH,
  génération Q&A raisonnement via Claude Opus 4.6 (~450 exemples)
- 14_generate_negative_examples.py : 1000 exemples négatifs
  (symptômes/DP, redondances sémantiques, DAS non significatifs)
- 15_generate_discrimination.py : 800 exercices de discrimination
  entre codes siblings CIM-10 via Claude Opus 4.6
- 16_parse_guide_metho.py : extraction Guide Méthodologique MCO 2026,
  Q&A directes + raisonnement via Claude Opus 4.6 (~500 exemples)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
dom
2026-02-16 19:42:33 +01:00
commit 06100df236
21 changed files with 6106 additions and 0 deletions

View File

@@ -0,0 +1,263 @@
#!/usr/bin/env python3
"""
Phase 3 — Fine-tuning du modèle d'embedding pour le RAG médical.
Fine-tune sentence-camembert-large sur les triplets PMSI/CIM-10
pour améliorer la recherche FAISS dans le pipeline T2A.
Prérequis :
- python scripts/09_build_embedding_triplets.py
- data/datasets/triplets.jsonl (41K+ triplets)
Usage :
python scripts/10_train_embedding.py [--epochs 3] [--batch 32] [--lr 2e-5] [--device cpu]
--device cpu : entraîner sur CPU (si le GPU est occupé par le LoRA)
--device cuda : entraîner sur GPU (plus rapide, ~30 min)
Hardware :
- CPU : ~2-3h (Ryzen 9 9950X)
- GPU : ~30 min (RTX 5070, si disponible)
- RAM : ~4 Go (modèle ~1.3 Go + données)
"""
import argparse
import json
import random
from pathlib import Path
random.seed(42)
BASE = Path(__file__).resolve().parent.parent
TRIPLETS_PATH = BASE / "data" / "datasets" / "triplets.jsonl"
OUTPUT_DIR = BASE / "models" / "sentence-camembert-pmsi"
# Modèle de base (même que dans le pipeline T2A)
BASE_MODEL = "dangvantuan/sentence-camembert-large"
def load_triplets(path: Path, eval_ratio: float = 0.1):
"""Charge les triplets et split train/eval."""
from datasets import Dataset
anchors, positives, negatives = [], [], []
with open(path) as f:
for line in f:
t = json.loads(line.strip())
anchors.append(t["anchor"])
positives.append(t["positive"])
negatives.append(t["negative"])
total = len(anchors)
print(f"Triplets chargés : {total}")
# Shuffle déterministe
indices = list(range(total))
random.shuffle(indices)
split = int(total * (1 - eval_ratio))
train_idx = indices[:split]
eval_idx = indices[split:]
train_ds = Dataset.from_dict({
"anchor": [anchors[i] for i in train_idx],
"positive": [positives[i] for i in train_idx],
"negative": [negatives[i] for i in train_idx],
})
eval_ds = Dataset.from_dict({
"anchor": [anchors[i] for i in eval_idx],
"positive": [positives[i] for i in eval_idx],
"negative": [negatives[i] for i in eval_idx],
})
print(f" Train : {len(train_ds)}")
print(f" Eval : {len(eval_ds)}")
return train_ds, eval_ds
def main():
parser = argparse.ArgumentParser(description="Fine-tuning embedding PMSI")
parser.add_argument("--model", default=BASE_MODEL,
help="Modèle de base sentence-transformers")
parser.add_argument("--triplets", type=Path, default=TRIPLETS_PATH,
help="Fichier de triplets JSONL")
parser.add_argument("--output", type=Path, default=OUTPUT_DIR,
help="Répertoire de sortie")
# Entraînement
parser.add_argument("--epochs", type=int, default=3, help="Nombre d'epochs")
parser.add_argument("--batch", type=int, default=32, help="Batch size")
parser.add_argument("--lr", type=float, default=2e-5, help="Learning rate")
parser.add_argument("--warmup-ratio", type=float, default=0.1, help="Warmup ratio")
parser.add_argument("--max-seq-length", type=int, default=256,
help="Longueur max des séquences")
# Loss
parser.add_argument("--loss", choices=["mnrl", "triplet"], default="mnrl",
help="Fonction de loss (mnrl=MultipleNegativesRankingLoss, triplet=TripletLoss)")
# Device
parser.add_argument("--device", default=None,
help="Device (cuda/cpu). Auto-détection si non spécifié.")
args = parser.parse_args()
print("=" * 60)
print("Fine-tuning embedding pour RAG PMSI")
print("=" * 60)
# Vérifier les triplets
if not args.triplets.exists():
raise FileNotFoundError(
f"Triplets non trouvés : {args.triplets}\n"
f"Lancez d'abord : python scripts/09_build_embedding_triplets.py"
)
# Device
import torch
if args.device is None:
if torch.cuda.is_available():
# Vérifier si le GPU a assez de VRAM libre (~2 Go suffisent)
free = torch.cuda.mem_get_info()[0] / 1024**3
if free > 2.0:
args.device = "cuda"
print(f"GPU détecté : {torch.cuda.get_device_name(0)} ({free:.1f} Go libre)")
else:
args.device = "cpu"
print(f"GPU occupé ({free:.1f} Go libre) — fallback CPU")
else:
args.device = "cpu"
print("Pas de GPU — entraînement CPU")
# Charger le modèle
from sentence_transformers import SentenceTransformer
print(f"\nChargement du modèle : {args.model}")
model = SentenceTransformer(args.model, device=args.device)
model.max_seq_length = args.max_seq_length
print(f" Max seq length : {model.max_seq_length}")
print(f" Dimension embedding : {model.get_sentence_embedding_dimension()}")
# Charger les données
print()
train_ds, eval_ds = load_triplets(args.triplets)
# Loss
if args.loss == "mnrl":
from sentence_transformers.losses import MultipleNegativesRankingLoss
loss = MultipleNegativesRankingLoss(model)
print(f"\nLoss : MultipleNegativesRankingLoss (in-batch negatives + hard negatives)")
else:
from sentence_transformers.losses import TripletLoss
loss = TripletLoss(model)
print(f"\nLoss : TripletLoss")
# Evaluator
from sentence_transformers.evaluation import TripletEvaluator
evaluator = TripletEvaluator(
anchors=eval_ds["anchor"],
positives=eval_ds["positive"],
negatives=eval_ds["negative"],
name="pmsi-eval",
batch_size=args.batch,
)
# Évaluation baseline (avant fine-tuning)
print("\nÉvaluation baseline (avant fine-tuning)...")
baseline = evaluator(model)
print(f" Accuracy baseline : {baseline.get('pmsi-eval_cosine_accuracy', 'N/A')}")
# Training args
from sentence_transformers import SentenceTransformerTrainingArguments
n_steps = len(train_ds) * args.epochs // args.batch
eval_steps = max(n_steps // 10, 50) # Évaluer ~10 fois pendant l'entraînement
save_steps = max(n_steps // 5, 100) # Sauvegarder ~5 fois
training_args = SentenceTransformerTrainingArguments(
output_dir=str(args.output / "checkpoints"),
num_train_epochs=args.epochs,
per_device_train_batch_size=args.batch,
per_device_eval_batch_size=args.batch,
learning_rate=args.lr,
warmup_ratio=args.warmup_ratio,
lr_scheduler_type="cosine",
fp16=(args.device == "cuda"),
bf16=False,
logging_steps=50,
eval_strategy="steps",
eval_steps=eval_steps,
save_strategy="steps",
save_steps=save_steps,
save_total_limit=3,
load_best_model_at_end=True,
metric_for_best_model="pmsi-eval_cosine_accuracy",
greater_is_better=True,
seed=42,
report_to="aim",
)
print(f"\nConfiguration :")
print(f" Epochs : {args.epochs}")
print(f" Batch size : {args.batch}")
print(f" Learning rate : {args.lr}")
print(f" Steps estimés : ~{n_steps}")
print(f" Eval tous les : {eval_steps} steps")
print(f" Device : {args.device}")
# Trainer
from sentence_transformers import SentenceTransformerTrainer
trainer = SentenceTransformerTrainer(
model=model,
args=training_args,
train_dataset=train_ds,
eval_dataset=eval_ds,
loss=loss,
evaluator=evaluator,
)
# Entraîner
print(f"\nDémarrage de l'entraînement...")
trainer.train()
# Sauvegarder le modèle final
final_dir = args.output / "final"
model.save_pretrained(str(final_dir))
print(f"\nModèle sauvegardé : {final_dir}")
# Évaluation finale
print("\nÉvaluation finale (après fine-tuning)...")
final_results = evaluator(model)
final_acc = final_results.get("pmsi-eval_cosine_accuracy", 0)
baseline_acc = baseline.get("pmsi-eval_cosine_accuracy", 0)
print(f"\n{'=' * 60}")
print(f"Résultats :")
print(f" Accuracy baseline : {baseline_acc:.4f}")
print(f" Accuracy finale : {final_acc:.4f}")
if baseline_acc > 0:
delta = final_acc - baseline_acc
pct = 100 * delta / baseline_acc
print(f" Amélioration : {delta:+.4f} ({pct:+.1f}%)")
print(f"\nModèle final : {final_dir}")
print(f"{'=' * 60}")
# Instructions d'intégration
print(f"""
Pour utiliser dans le pipeline T2A, modifier src/medical/rag_search.py :
_embed_model = SentenceTransformer("{final_dir}", device=_device)
Puis reconstruire l'index FAISS avec le nouveau modèle :
python -m src.medical.rag_index --rebuild
""")
if __name__ == "__main__":
main()