#!/usr/bin/env python3 """ Phase 1 — Fusion de tous les sous-datasets en un dataset final. Lit tous les .jsonl dans data/processed/ et produit : - data/datasets/pmsi_train.jsonl (90%) - data/datasets/pmsi_eval.jsonl (10%) - data/datasets/stats.json (statistiques) V2 : Rééquilibrage — réduction drastique des lookups, priorité raisonnement. Ratio cible : ~35% lookups / 40% raisonnement / 25% règles (vs 95/3/2 avant). """ import json import random from pathlib import Path random.seed(42) BASE = Path(__file__).resolve().parent.parent PROCESSED = BASE / "data" / "processed" DATASETS = BASE / "data" / "datasets" DATASETS.mkdir(parents=True, exist_ok=True) EVAL_RATIO = 0.10 # 10% pour l'évaluation # V2 : Sous-échantillonnage agressif des lookups, garder tout le raisonnement SUBSAMPLE = { "cim10_chatml.jsonl": 1_500, # 50K → 1.5K (discrimination + inclusions/exclusions) "ccam_chatml.jsonl": 1_500, # 16.5K → 1.5K (codes les plus fréquents) "cocoa_chatml.jsonl": 2_000, # 30K → 2K (types clinical + tips prioritaires) "referentiels_chatml.jsonl": None, # Garder tout (5.3K, déjà dense en règles) } # Mots-clés indiquant du raisonnement structuré (pour sélection intelligente) REASONING_KEYWORDS = [ "analyse_clinique", "discrimination", "regle_pmsi", "règle_pmsi", "codes_candidats", "justification", "ne pas coder", "ne doit pas", "à l'exclusion", "comprend", "sévérité", "CMA", ] def count_tokens_approx(text): """Estimation grossière du nombre de tokens (~1.3 token par mot français).""" return int(len(text.split()) * 1.3) def _reasoning_score(example: dict) -> int: """Score de raisonnement d'un exemple (plus haut = plus précieux).""" text = " ".join(m.get("content", "") for m in example.get("messages", [])) text_lower = text.lower() score = 0 for kw in REASONING_KEYWORDS: if kw.lower() in text_lower: score += 1 # Bonus pour les réponses JSON structurées longues (raisonnement complet) assistant_msgs = [m["content"] for m in example.get("messages", []) if m["role"] == "assistant"] for msg in assistant_msgs: if len(msg) > 200: score += 2 if '"analyse_clinique"' in msg: score += 3 return score def _smart_subsample(examples: list[dict], target: int) -> list[dict]: """Sous-échantillonnage intelligent : prioriser les exemples avec raisonnement.""" # Trier par score de raisonnement (décroissant), puis shuffle pour diversité scored = [(ex, _reasoning_score(ex)) for ex in examples] scored.sort(key=lambda x: -x[1]) # Garder tous ceux avec score > 0, puis compléter avec du random high_value = [(ex, s) for ex, s in scored if s > 0] low_value = [(ex, s) for ex, s in scored if s == 0] random.shuffle(low_value) if len(high_value) >= target: # Trop d'exemples high-value : prendre les meilleurs + random parmi eux random.shuffle(high_value) selected = [ex for ex, _ in high_value[:target]] else: # Prendre tous les high-value + compléter avec low-value selected = [ex for ex, _ in high_value] remaining = target - len(selected) selected.extend(ex for ex, _ in low_value[:remaining]) random.shuffle(selected) return selected def main(): # Collecter tous les fichiers JSONL files = sorted(PROCESSED.glob("*.jsonl")) if not files: print("Aucun fichier .jsonl trouvé dans data/processed/") return all_examples = [] file_stats = {} for f in files: examples = [] with open(f) as fh: for line in fh: line = line.strip() if line: examples.append(json.loads(line)) original_count = len(examples) # Sous-échantillonner si configuré if f.name in SUBSAMPLE: target = SUBSAMPLE[f.name] if target is None or len(examples) <= target: print(f" {f.name}: {len(examples)} exemples (gardé tout)") else: examples = _smart_subsample(examples, target) print(f" {f.name}: {len(examples)} exemples (sous-échantillonné depuis {original_count})") else: print(f" {f.name}: {len(examples)} exemples") file_stats[f.name] = len(examples) all_examples.extend(examples) print(f"\nTotal brut : {len(all_examples)} exemples") # Mélanger random.shuffle(all_examples) # Split train / eval n_eval = max(1, int(len(all_examples) * EVAL_RATIO)) eval_set = all_examples[:n_eval] train_set = all_examples[n_eval:] # Écrire les datasets train_path = DATASETS / "pmsi_train.jsonl" eval_path = DATASETS / "pmsi_eval.jsonl" for path, dataset in [(train_path, train_set), (eval_path, eval_set)]: with open(path, "w") as fh: for ex in dataset: fh.write(json.dumps(ex, ensure_ascii=False) + "\n") # Statistiques def compute_stats(dataset): total_tokens = 0 max_tokens = 0 min_tokens = float("inf") for ex in dataset: text = " ".join(m["content"] for m in ex["messages"]) tokens = count_tokens_approx(text) total_tokens += tokens max_tokens = max(max_tokens, tokens) min_tokens = min(min_tokens, tokens) avg = total_tokens / len(dataset) if dataset else 0 return { "count": len(dataset), "total_tokens_approx": total_tokens, "avg_tokens": round(avg), "max_tokens": max_tokens, "min_tokens": min_tokens if min_tokens != float("inf") else 0, } train_stats = compute_stats(train_set) eval_stats = compute_stats(eval_set) stats = { "sources": file_stats, "total": len(all_examples), "train": train_stats, "eval": eval_stats, "eval_ratio": EVAL_RATIO, } stats_path = DATASETS / "stats.json" with open(stats_path, "w") as fh: json.dump(stats, fh, indent=2, ensure_ascii=False) # Calculer le ratio raisonnement n_reasoning = sum(1 for ex in all_examples if _reasoning_score(ex) >= 3) pct_reasoning = 100.0 * n_reasoning / len(all_examples) if all_examples else 0 # Affichage print(f"\n{'='*50}") print(f"Dataset final (V2 rééquilibré) :") print(f" Train : {train_stats['count']} exemples → {train_path.name}") print(f" Eval : {eval_stats['count']} exemples → {eval_path.name}") print(f"\nRatio raisonnement structuré (score≥3) : {n_reasoning}/{len(all_examples)} ({pct_reasoning:.0f}%)") print(f"\nTokens (estimation) :") print(f" Train : ~{train_stats['total_tokens_approx']:,} tokens (moy: {train_stats['avg_tokens']}, max: {train_stats['max_tokens']})") print(f" Eval : ~{eval_stats['total_tokens_approx']:,} tokens (moy: {eval_stats['avg_tokens']}, max: {eval_stats['max_tokens']})") print(f"\nSources :") for name, count in sorted(file_stats.items()): print(f" {name}: {count}") print(f"\nFichiers :") print(f" {train_path} ({train_path.stat().st_size / 1024 / 1024:.1f} Mo)") print(f" {eval_path} ({eval_path.stat().st_size / 1024 / 1024:.1f} Mo)") print(f" {stats_path}") if __name__ == "__main__": main()