feat: rééquilibrage dataset LoRA — raisonnement DIM vs mémorisation
Passe de 95/3/2 (lookups/raisonnement/règles) à ~31/49/20. Dataset cible ~16K exemples denses (vs 66K de lookups avant). Modifiés : - 03_convert_cache.py : cache complet 1840 entrées (actuel + backup) - 04_build_dataset.py : subsampling agressif (CIM-10 1.5K, CCAM 1.5K, CoCoA 2K) + sélection intelligente priorisant le raisonnement - 12_generate_pipeline_examples.py : 3 templates (court + long + CPAM), cache actuel, cible ~2800 exemples Créés : - 13_generate_fascicule_reasoning.py : parsing 10 fascicules ATIH, génération Q&A raisonnement via Claude Opus 4.6 (~450 exemples) - 14_generate_negative_examples.py : 1000 exemples négatifs (symptômes/DP, redondances sémantiques, DAS non significatifs) - 15_generate_discrimination.py : 800 exercices de discrimination entre codes siblings CIM-10 via Claude Opus 4.6 - 16_parse_guide_metho.py : extraction Guide Méthodologique MCO 2026, Q&A directes + raisonnement via Claude Opus 4.6 (~500 exemples) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
202
scripts/04_build_dataset.py
Normal file
202
scripts/04_build_dataset.py
Normal file
@@ -0,0 +1,202 @@
|
||||
#!/usr/bin/env python3
"""
Phase 1 — Merge all sub-datasets into one final dataset.

Reads every .jsonl file under data/processed/ and produces:
- data/datasets/pmsi_train.jsonl (90%)
- data/datasets/pmsi_eval.jsonl (10%)
- data/datasets/stats.json (statistics)

V2: Rebalancing — drastic reduction of lookup examples, reasoning prioritized.
Target ratio: ~35% lookups / 40% reasoning / 25% rules (vs 95/3/2 before).
"""

import json
import random
from pathlib import Path

# Fixed seed so shuffling/subsampling is reproducible across runs.
random.seed(42)

# Repository root (this script lives in scripts/, so go up two levels).
BASE = Path(__file__).resolve().parent.parent
PROCESSED = BASE / "data" / "processed"
DATASETS = BASE / "data" / "datasets"
DATASETS.mkdir(parents=True, exist_ok=True)

EVAL_RATIO = 0.10  # fraction of examples held out for evaluation

# V2: aggressive subsampling of lookup-heavy sources, keep all reasoning.
# A value of None means "keep every example from that file".
SUBSAMPLE = {
    "cim10_chatml.jsonl": 1_500,  # 50K → 1.5K (discrimination + inclusions/exclusions)
    "ccam_chatml.jsonl": 1_500,  # 16.5K → 1.5K (most frequent codes)
    "cocoa_chatml.jsonl": 2_000,  # 30K → 2K (clinical + tips types prioritized)
    "referentiels_chatml.jsonl": None,  # keep all (5.3K, already rule-dense)
}

# Keywords that signal structured reasoning (used for smart selection).
# NOTE: these are matched against the example text at runtime — do not translate.
REASONING_KEYWORDS = [
    "analyse_clinique", "discrimination", "regle_pmsi", "règle_pmsi",
    "codes_candidats", "justification", "ne pas coder", "ne doit pas",
    "à l'exclusion", "comprend", "sévérité", "CMA",
]
|
||||
|
||||
|
||||
def count_tokens_approx(text):
    """Rough token-count estimate (~1.3 tokens per French word)."""
    word_count = len(text.split())
    return int(word_count * 1.3)
|
||||
|
||||
|
||||
def _reasoning_score(example: dict) -> int:
    """Heuristic reasoning score of an example (higher = more valuable).

    Scoring:
    - +1 per REASONING_KEYWORDS entry found (case-insensitive) anywhere
      in the concatenated message contents;
    - +2 per assistant message longer than 200 characters (full answers);
    - +3 per assistant message containing the '"analyse_clinique"' JSON key
      (structured reasoning trace).

    Robustness fix: the original indexed m["role"] / m["content"] directly
    and raised KeyError on malformed messages, while the keyword pass
    already defended with .get(); .get() is now used consistently.
    """
    messages = example.get("messages", [])
    text_lower = " ".join(m.get("content", "") for m in messages).lower()
    score = sum(1 for kw in REASONING_KEYWORDS if kw.lower() in text_lower)

    # Bonus for long, structured JSON answers (complete reasoning).
    for m in messages:
        if m.get("role") != "assistant":
            continue
        content = m.get("content", "")
        if len(content) > 200:
            score += 2
        if '"analyse_clinique"' in content:
            score += 3
    return score
|
||||
|
||||
|
||||
def _smart_subsample(examples: list[dict], target: int) -> list[dict]:
    """Smart subsampling: keep reasoning-rich examples, pad with random ones.

    Every example with a positive reasoning score is preferred over
    zero-score ones. When there are more positive-score examples than
    ``target``, ``target`` of them are picked uniformly at random.

    Fix vs original: the initial global sort by score was dead work — both
    partitions were subsequently either taken in full or shuffled, so the
    sort never influenced the result; ``low_value`` was also shuffled even
    when no low-value example was needed. Both are removed.
    """
    # Partition by reasoning score; ordering inside each partition is
    # irrelevant (selection below is "take all" or uniform random).
    high_value: list[dict] = []
    low_value: list[dict] = []
    for ex in examples:
        (high_value if _reasoning_score(ex) > 0 else low_value).append(ex)

    if len(high_value) >= target:
        # More high-value examples than needed: pick `target` at random.
        random.shuffle(high_value)
        selected = high_value[:target]
    else:
        # Keep every high-value example, pad with random low-value ones.
        random.shuffle(low_value)
        selected = high_value + low_value[: target - len(high_value)]

    # Final shuffle so training order does not correlate with score.
    random.shuffle(selected)
    return selected
|
||||
|
||||
|
||||
def main():
    """Merge every .jsonl in data/processed/ into train/eval splits + stats.

    Reads all source files, subsamples the lookup-heavy ones per SUBSAMPLE,
    shuffles, splits 90/10, then writes pmsi_train.jsonl, pmsi_eval.jsonl
    and stats.json under data/datasets/.

    Fix vs original: every open() call now passes encoding="utf-8" — the
    data is French text and json.dumps(..., ensure_ascii=False) emits
    non-ASCII, so relying on the platform default encoding (e.g. cp1252
    on Windows) could corrupt reads or crash writes.
    """
    # Collect all JSONL source files.
    files = sorted(PROCESSED.glob("*.jsonl"))
    if not files:
        print("Aucun fichier .jsonl trouvé dans data/processed/")
        return

    all_examples = []
    file_stats = {}

    for f in files:
        with open(f, encoding="utf-8") as fh:
            examples = [json.loads(s) for line in fh if (s := line.strip())]

        original_count = len(examples)

        # Subsample if this source is configured for it.
        if f.name in SUBSAMPLE:
            target = SUBSAMPLE[f.name]
            if target is None or len(examples) <= target:
                print(f" {f.name}: {len(examples)} exemples (gardé tout)")
            else:
                examples = _smart_subsample(examples, target)
                print(f" {f.name}: {len(examples)} exemples (sous-échantillonné depuis {original_count})")
        else:
            print(f" {f.name}: {len(examples)} exemples")

        file_stats[f.name] = len(examples)
        all_examples.extend(examples)

    print(f"\nTotal brut : {len(all_examples)} exemples")

    # Shuffle before splitting so both splits draw from all sources.
    random.shuffle(all_examples)

    # Train / eval split (at least one eval example).
    n_eval = max(1, int(len(all_examples) * EVAL_RATIO))
    eval_set = all_examples[:n_eval]
    train_set = all_examples[n_eval:]

    # Write the datasets.
    train_path = DATASETS / "pmsi_train.jsonl"
    eval_path = DATASETS / "pmsi_eval.jsonl"

    for path, dataset in [(train_path, train_set), (eval_path, eval_set)]:
        with open(path, "w", encoding="utf-8") as fh:
            for ex in dataset:
                fh.write(json.dumps(ex, ensure_ascii=False) + "\n")

    def compute_stats(dataset):
        """Approximate token statistics (count/total/avg/max/min) for a split."""
        total_tokens = 0
        max_tokens = 0
        min_tokens = float("inf")
        for ex in dataset:
            text = " ".join(m["content"] for m in ex["messages"])
            tokens = count_tokens_approx(text)
            total_tokens += tokens
            max_tokens = max(max_tokens, tokens)
            min_tokens = min(min_tokens, tokens)
        avg = total_tokens / len(dataset) if dataset else 0
        return {
            "count": len(dataset),
            "total_tokens_approx": total_tokens,
            "avg_tokens": round(avg),
            "max_tokens": max_tokens,
            # inf sentinel means the split was empty — report 0 instead.
            "min_tokens": min_tokens if min_tokens != float("inf") else 0,
        }

    train_stats = compute_stats(train_set)
    eval_stats = compute_stats(eval_set)

    stats = {
        "sources": file_stats,
        "total": len(all_examples),
        "train": train_stats,
        "eval": eval_stats,
        "eval_ratio": EVAL_RATIO,
    }

    stats_path = DATASETS / "stats.json"
    with open(stats_path, "w", encoding="utf-8") as fh:
        json.dump(stats, fh, indent=2, ensure_ascii=False)

    # Reasoning ratio: examples whose heuristic score reaches 3.
    n_reasoning = sum(1 for ex in all_examples if _reasoning_score(ex) >= 3)
    pct_reasoning = 100.0 * n_reasoning / len(all_examples) if all_examples else 0

    # Summary report.
    print(f"\n{'='*50}")
    print(f"Dataset final (V2 rééquilibré) :")
    print(f" Train : {train_stats['count']} exemples → {train_path.name}")
    print(f" Eval : {eval_stats['count']} exemples → {eval_path.name}")
    print(f"\nRatio raisonnement structuré (score≥3) : {n_reasoning}/{len(all_examples)} ({pct_reasoning:.0f}%)")
    print(f"\nTokens (estimation) :")
    print(f" Train : ~{train_stats['total_tokens_approx']:,} tokens (moy: {train_stats['avg_tokens']}, max: {train_stats['max_tokens']})")
    print(f" Eval : ~{eval_stats['total_tokens_approx']:,} tokens (moy: {eval_stats['avg_tokens']}, max: {eval_stats['max_tokens']})")
    print(f"\nSources :")
    for name, count in sorted(file_stats.items()):
        print(f" {name}: {count}")
    print(f"\nFichiers :")
    print(f" {train_path} ({train_path.stat().st_size / 1024 / 1024:.1f} Mo)")
    print(f" {eval_path} ({eval_path.stat().st_size / 1024 / 1024:.1f} Mo)")
    print(f" {stats_path}")
|
||||
# Script entry point: build the merged dataset when run directly.
if __name__ == "__main__":
    main()
Reference in New Issue
Block a user