commit 06100df236195212ef84880ccbc56ab44a31e512 Author: dom Date: Mon Feb 16 19:42:33 2026 +0100 feat: rééquilibrage dataset LoRA — raisonnement DIM vs mémorisation Passe de 95/3/2 (lookups/raisonnement/règles) à ~31/49/20. Dataset cible ~16K exemples denses (vs 66K de lookups avant). Modifiés : - 03_convert_cache.py : cache complet 1840 entrées (actuel + backup) - 04_build_dataset.py : subsampling agressif (CIM-10 1.5K, CCAM 1.5K, CoCoA 2K) + sélection intelligente priorisant le raisonnement - 12_generate_pipeline_examples.py : 3 templates (court + long + CPAM), cache actuel, cible ~2800 exemples Créés : - 13_generate_fascicule_reasoning.py : parsing 10 fascicules ATIH, génération Q&A raisonnement via Claude Opus 4.6 (~450 exemples) - 14_generate_negative_examples.py : 1000 exemples négatifs (symptômes/DP, redondances sémantiques, DAS non significatifs) - 15_generate_discrimination.py : 800 exercices de discrimination entre codes siblings CIM-10 via Claude Opus 4.6 - 16_parse_guide_metho.py : extraction Guide Méthodologique MCO 2026, Q&A directes + raisonnement via Claude Opus 4.6 (~500 exemples) Co-Authored-By: Claude Opus 4.6 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..454e495 --- /dev/null +++ b/.gitignore @@ -0,0 +1,34 @@ +# Data files (trop volumineux pour git) +data/raw/*.pdf +data/raw/*.json +data/processed/*.jsonl +data/processed/*.json +data/datasets/*.jsonl +data/datasets/*.json + +# Models +models/ + +# Python +__pycache__/ +*.pyc +.venv/ + +# IDE +.idea/ +.vscode/ + +# Aim tracking +.aim/ + +# Unsloth compiled cache +unsloth_compiled_cache/ + +# llama.cpp build +llama.cpp/ + +# Runpod data copies +runpod/data/ + +# OS +.DS_Store diff --git a/runpod/README.md b/runpod/README.md new file mode 100644 index 0000000..8e9b95f --- /dev/null +++ b/runpod/README.md @@ -0,0 +1,64 @@ +# Fine-tuning pmsi-coder sur RunPod + +## 1. Créer un pod + +- **Template** : RunPod PyTorch 2.4+ (CUDA 12.x) +- **GPU recommandé** : A100 40GB (~1.50€/h) ou A100 80GB (~2.50€/h) +- **Disk** : 50 Go minimum (modèle 12B + dataset + GGUF) +- **Volume persistant** : optionnel, utile si on veut garder les checkpoints + +## 2. Upload des fichiers + +```bash +# Depuis la machine locale +rsync -avz --progress \ + runpod/ \ + root@RUNPOD_IP:/workspace/t2a-finetune/ + +# Ou via l'interface web RunPod (Jupyter → upload) +``` + +Les fichiers nécessaires : +- `train_runpod.py` — script d'entraînement +- `setup.sh` — installation des dépendances +- `data/pmsi_train.jsonl` — dataset train (38 Mo) +- `data/pmsi_eval.jsonl` — dataset eval (4.2 Mo) + +## 3. Setup + +```bash +cd /workspace/t2a-finetune +bash setup.sh +``` + +## 4. Lancer l'entraînement + +```bash +python train_runpod.py --epochs 3 --export-gguf +``` + +Options : +- `--max-seq-length 2048` (défaut, vs 512 en local) +- `--batch 0` (auto-detect selon VRAM, défaut) +- `--lr 2e-4` (learning rate) +- `--lora-r 32` (rang LoRA) +- `--export-gguf` (produire le .gguf pour Ollama) + +## 5. Récupérer le GGUF + +```bash +# Sur la machine locale +scp root@RUNPOD_IP:/workspace/t2a-finetune/models/pmsi-gguf/*.gguf . +scp root@RUNPOD_IP:/workspace/t2a-finetune/models/pmsi-gguf/Modelfile . + +# Importer dans Ollama +ollama create pmsi-coder -f Modelfile +``` + +## Estimations + +| GPU | Batch | Temps 3 epochs | Coût | +|-----|-------|----------------|------| +| A100 40GB | 4 | ~2-3h | ~4-5€ | +| A100 80GB | 8 | ~1.5-2h | ~4-5€ | +| H100 80GB | 8 | ~1-1.5h | ~4-5€ | diff --git a/runpod/setup.sh b/runpod/setup.sh new file mode 100755 index 0000000..bc74cb2 --- /dev/null +++ b/runpod/setup.sh @@ -0,0 +1,27 @@ +#!/bin/bash +# Setup RunPod pour fine-tuning pmsi-coder +# Usage : bash setup.sh + +set -e + +echo "=== Setup fine-tuning PMSI-coder sur RunPod ===" + +# Installer Unsloth + dépendances +pip install --no-deps "unsloth[cu124-ampere-torch250] @ git+https://github.com/unslothai/unsloth.git" +pip install --no-deps unsloth_zoo +pip install trl datasets peft accelerate bitsandbytes sentencepiece protobuf + +# Wandb optionnel +pip install wandb 2>/dev/null || echo "wandb non installé (optionnel)" + +echo "" +echo "=== Setup terminé ===" +echo "" +echo "Vérifiez que les fichiers data/ sont présents :" +ls -lh data/pmsi_train.jsonl data/pmsi_eval.jsonl 2>/dev/null || echo " MANQUANT ! Uploadez les datasets." +echo "" +echo "Pour lancer :" +echo " python train_runpod.py --epochs 3 --export-gguf" +echo "" +echo "Estimation A100 40GB : ~2-3h pour 3 epochs" +echo "Estimation A100 80GB : ~1.5-2h pour 3 epochs" diff --git a/runpod/train_runpod.py b/runpod/train_runpod.py new file mode 100644 index 0000000..7f3ee62 --- /dev/null +++ b/runpod/train_runpod.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python3 +""" +Fine-tuning QLoRA de gemma3:12b sur RunPod (A100 40/80GB). + +Différences vs local (RTX 5070 12GB) : + - max_seq_length=2048 (vs 512) — prompts pipeline non tronqués + - batch_size=4 (vs 1) — convergence plus stable + - 3 epochs (vs 1) — meilleure mémorisation + - Dataset complet sans sous-échantillonnage + +Usage sur RunPod : + 1. Créer un pod avec template PyTorch 2.4+ / CUDA 12.x + 2. rsync -avz runpod/ runpod_host:/workspace/t2a-finetune/ + 3. bash /workspace/t2a-finetune/setup.sh + 4. python /workspace/t2a-finetune/train_runpod.py [--epochs 3] [--export-gguf] +""" + +import argparse +import json +import os +from pathlib import Path + +BASE = Path(__file__).resolve().parent +DATASETS = BASE / "data" +OUTPUT = BASE / "models" +OUTPUT.mkdir(parents=True, exist_ok=True) + + +def check_prerequisites(): + import torch + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA non disponible.") + + gpu_name = torch.cuda.get_device_name(0) + vram_total = torch.cuda.get_device_properties(0).total_memory / 1024**3 + + print(f"GPU: {gpu_name}") + print(f"VRAM: {vram_total:.1f} Go") + + train_path = DATASETS / "pmsi_train.jsonl" + eval_path = DATASETS / "pmsi_eval.jsonl" + if not train_path.exists() or not eval_path.exists(): + raise FileNotFoundError("Dataset non trouvé dans data/") + + with open(train_path) as f: + n_train = sum(1 for _ in f) + with open(eval_path) as f: + n_eval = sum(1 for _ in f) + + print(f"Dataset: {n_train} train + {n_eval} eval") + + # Adapter le batch size à la VRAM + if vram_total >= 70: + suggested_batch = 8 + elif vram_total >= 35: + suggested_batch = 4 + else: + suggested_batch = 2 + + print(f"Batch size suggéré: {suggested_batch}") + return train_path, eval_path, suggested_batch + + +def load_model(model_name, max_seq_length, load_in_4bit=True): + from unsloth import FastLanguageModel + + print(f"\nChargement de {model_name} (4-bit={load_in_4bit})...") + + model, tokenizer = FastLanguageModel.from_pretrained( + model_name=model_name, + max_seq_length=max_seq_length, + dtype=None, + load_in_4bit=load_in_4bit, + ) + + print(f" Modèle chargé : {model.config._name_or_path}") + print(f" Paramètres : {model.num_parameters() / 1e9:.1f}B") + return model, tokenizer + + +def attach_lora(model, r=32, alpha=64, dropout=0.0): + from unsloth import FastLanguageModel + + print(f"\nLoRA (r={r}, alpha={alpha})...") + + model = FastLanguageModel.get_peft_model( + model, + r=r, + target_modules=[ + "q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj", + ], + lora_alpha=alpha, + lora_dropout=dropout, + bias="none", + use_gradient_checkpointing="unsloth", + random_state=42, + ) + + trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + total = sum(p.numel() for p in model.parameters()) + print(f" Entraînables : {trainable / 1e6:.1f}M / {total / 1e9:.1f}B ({100 * trainable / total:.2f}%)") + return model + + +def load_dataset(train_path, eval_path): + from datasets import Dataset + + def load_jsonl(path): + examples = [] + with open(path) as f: + for line in f: + examples.append(json.loads(line.strip())) + return examples + + train_ds = Dataset.from_list(load_jsonl(train_path)) + eval_ds = Dataset.from_list(load_jsonl(eval_path)) + + print(f"\nDataset : {len(train_ds)} train + {len(eval_ds)} eval") + return train_ds, eval_ds + + +def format_chat(example, tokenizer): + text = tokenizer.apply_chat_template( + example["messages"], + tokenize=False, + add_generation_prompt=False, + ) + return {"text": text} + + +def train(model, tokenizer, train_ds, eval_ds, args): + from trl import SFTTrainer, SFTConfig + + print(f"\nConfig entraînement :") + print(f" Epochs : {args.epochs}") + print(f" LR : {args.lr}") + print(f" Batch : {args.batch} x grad_accum={args.grad_accum} = {args.batch * args.grad_accum}") + print(f" Max seq length : {args.max_seq_length}") + + train_ds = train_ds.map(lambda x: format_chat(x, tokenizer), num_proc=4) + eval_ds = eval_ds.map(lambda x: format_chat(x, tokenizer), num_proc=4) + + output_dir = OUTPUT / "pmsi-lora-checkpoints" + + # Wandb optionnel + report = "none" + callbacks = [] + try: + import wandb + wandb.init(project="pmsi-coder", name=f"runpod-{args.epochs}ep-seq{args.max_seq_length}") + report = "wandb" + print(" Tracking : wandb") + except ImportError: + print(" Tracking : none (pip install wandb pour activer)") + + training_args = SFTConfig( + output_dir=str(output_dir), + num_train_epochs=args.epochs, + per_device_train_batch_size=args.batch, + per_device_eval_batch_size=args.batch, + gradient_accumulation_steps=args.grad_accum, + learning_rate=args.lr, + weight_decay=0.01, + warmup_ratio=0.05, + lr_scheduler_type="cosine", + logging_steps=10, + eval_strategy="steps", + eval_steps=500, + save_strategy="steps", + save_steps=500, + save_total_limit=3, + fp16=False, + bf16=True, + max_seq_length=args.max_seq_length, + dataset_text_field="text", + seed=42, + report_to=report, + ) + + trainer = SFTTrainer( + model=model, + tokenizer=tokenizer, + train_dataset=train_ds, + eval_dataset=eval_ds, + args=training_args, + callbacks=callbacks, + ) + + total_steps = len(train_ds) * args.epochs // (args.batch * args.grad_accum) + print(f"\n Steps estimés : ~{total_steps}") + print(f" Démarrage...") + + if args.resume: + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + + final_dir = OUTPUT / "pmsi-lora-final" + model.save_pretrained(str(final_dir)) + tokenizer.save_pretrained(str(final_dir)) + print(f"\nLoRA sauvegardé : {final_dir}") + + return trainer, final_dir + + +def export_gguf(model, tokenizer, final_dir, quantization="q4_k_m"): + print(f"\nExport GGUF ({quantization})...") + + gguf_dir = OUTPUT / "pmsi-gguf" + gguf_dir.mkdir(parents=True, exist_ok=True) + + model.save_pretrained_gguf( + str(gguf_dir), + tokenizer, + quantization_method=quantization, + ) + + gguf_files = list(gguf_dir.glob("*.gguf")) + if gguf_files: + gguf_path = gguf_files[0] + size_gb = gguf_path.stat().st_size / 1024**3 + print(f" GGUF : {gguf_path.name} ({size_gb:.1f} Go)") + + modelfile_path = gguf_dir / "Modelfile" + with open(modelfile_path, "w") as f: + f.write(f"FROM {gguf_path.name}\n\n") + f.write("PARAMETER temperature 0.3\n") + f.write("PARAMETER top_p 0.9\n") + f.write("PARAMETER num_ctx 8192\n") + + print(f" Modelfile créé") + print(f"\n Pour récupérer : scp runpod:{gguf_path} .") + print(f" Puis : ollama create pmsi-coder -f Modelfile") + + +def main(): + parser = argparse.ArgumentParser(description="Fine-tuning QLoRA RunPod") + + parser.add_argument("--model", default="unsloth/gemma-3-12b-it-bnb-4bit") + parser.add_argument("--max-seq-length", type=int, default=2048) + + parser.add_argument("--lora-r", type=int, default=32) + parser.add_argument("--lora-alpha", type=int, default=64) + parser.add_argument("--lora-dropout", type=float, default=0.0) + + parser.add_argument("--epochs", type=int, default=3) + parser.add_argument("--lr", type=float, default=2e-4) + parser.add_argument("--batch", type=int, default=0, help="0=auto-detect") + parser.add_argument("--grad-accum", type=int, default=4) + + parser.add_argument("--resume", action="store_true") + parser.add_argument("--export-gguf", action="store_true") + parser.add_argument("--gguf-quant", default="q4_k_m") + + args = parser.parse_args() + + train_path, eval_path, suggested_batch = check_prerequisites() + + if args.batch == 0: + args.batch = suggested_batch + print(f"Batch auto-détecté : {args.batch}") + + model, tokenizer = load_model(args.model, args.max_seq_length) + model = attach_lora(model, r=args.lora_r, alpha=args.lora_alpha, dropout=args.lora_dropout) + train_ds, eval_ds = load_dataset(train_path, eval_path) + trainer, final_dir = train(model, tokenizer, train_ds, eval_ds, args) + + if args.export_gguf: + export_gguf(model, tokenizer, final_dir, args.gguf_quant) + + print("\n" + "=" * 50) + print("Fine-tuning terminé !") + print(f" LoRA : {final_dir}") + if args.export_gguf: + print(f" GGUF : {OUTPUT / 'pmsi-gguf'}") + print("=" * 50) + + +if __name__ == "__main__": + main() diff --git a/scripts/01_generate_cim10_pairs.py b/scripts/01_generate_cim10_pairs.py new file mode 100644 index 0000000..562aa19 --- /dev/null +++ b/scripts/01_generate_cim10_pairs.py @@ -0,0 +1,332 @@ +#!/usr/bin/env python3 +""" +Phase 1A — Génération de paires ChatML CIM-10 depuis le FHIR JSON. + +Sources : smt_cim10_fhir.json (19 161 concepts) +Produit : data/processed/cim10_chatml.jsonl + +Types d'exemples générés : + 1. code → description (lookup) + 2. description → code (codage) + 3. discrimination entre codes frères (même parent) + 4. inclusions/exclusions (ce qui est compris / exclu d'un code) +""" + +import json +import random +from pathlib import Path + +random.seed(42) + +BASE = Path(__file__).resolve().parent.parent +RAW = BASE / "data" / "raw" +OUT = BASE / "data" / "processed" +OUT.mkdir(parents=True, exist_ok=True) + +SYSTEM_MSG = "Tu es un médecin DIM expert en codage CIM-10 pour le PMSI français." + +# --- Chapitres CIM-10 --- +CHAPTERS = { + "A": "Certaines maladies infectieuses et parasitaires", + "B": "Certaines maladies infectieuses et parasitaires", + "C": "Tumeurs", + "D": "Tumeurs / Maladies du sang", + "E": "Maladies endocriniennes, nutritionnelles et métaboliques", + "F": "Troubles mentaux et du comportement", + "G": "Maladies du système nerveux", + "H": "Maladies de l'œil et de l'oreille", + "I": "Maladies de l'appareil circulatoire", + "J": "Maladies de l'appareil respiratoire", + "K": "Maladies de l'appareil digestif", + "L": "Maladies de la peau et du tissu cellulaire sous-cutané", + "M": "Maladies du système ostéo-articulaire", + "N": "Maladies de l'appareil génito-urinaire", + "O": "Grossesse, accouchement et puerpéralité", + "P": "Certaines affections dont l'origine se situe dans la période périnatale", + "Q": "Malformations congénitales et anomalies chromosomiques", + "R": "Symptômes, signes et résultats anormaux", + "S": "Lésions traumatiques et empoisonnements", + "T": "Lésions traumatiques et empoisonnements", + "V": "Causes externes de morbidité et de mortalité", + "W": "Causes externes de morbidité et de mortalité", + "X": "Causes externes de morbidité et de mortalité", + "Y": "Causes externes de morbidité et de mortalité", + "Z": "Facteurs influant sur l'état de santé et motifs de recours aux services de santé", + "U": "Codes d'utilisation particulière", +} + + +def load_fhir(): + """Charger et indexer les concepts FHIR.""" + with open(RAW / "smt_cim10_fhir.json") as f: + data = json.load(f) + concepts = data["concept"] + + # Index par code + by_code = {} + for c in concepts: + by_code[c["code"]] = c + + return concepts, by_code + + +def get_props(concept): + """Extraire les propriétés d'un concept sous forme de dict (multi-valeurs en listes).""" + props = {} + for p in concept.get("property", []): + key = p["code"] + val = p.get("valueCode", p.get("valueString", "")) + if key in props: + if isinstance(props[key], list): + props[key].append(val) + else: + props[key] = [props[key], val] + else: + props[key] = val + return props + + +def clean_display(display): + """Nettoyer le libellé (enlever les codes entre crochets type [G31.0]).""" + import re + # Retirer les références entre crochets comme [G31.0†] + cleaned = re.sub(r'\s*\[[\w.†*+]+\]\s*', ' ', display) + # Retirer les guillemets décoratifs + cleaned = cleaned.replace('"', '').replace('"', '').replace('"', '') + # Nettoyer les espaces multiples et les tirets isolés + cleaned = re.sub(r'\s*-\s*-\s*', ' - ', cleaned) + cleaned = re.sub(r'\s+', ' ', cleaned).strip() + return cleaned + + +def make_chatml(system, user, assistant): + """Créer un exemple ChatML.""" + return { + "messages": [ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + {"role": "assistant", "content": assistant}, + ] + } + + +def generate_lookup_pairs(concepts, by_code): + """Type 1 : code → description.""" + pairs = [] + for c in concepts: + props = get_props(c) + if props.get("type") not in ("category",): + continue + + code = c["code"] + display = clean_display(c["display"]) + + if not display or len(display) < 3: + continue + + chapter = CHAPTERS.get(code[0], "") + parent_code = props.get("parent", "") + parent_display = "" + if parent_code and parent_code in by_code: + parent_display = clean_display(by_code[parent_code]["display"]) + + # Construire la réponse enrichie + answer_parts = [f"{code} — {display}"] + if chapter: + answer_parts.append(f"Chapitre : {chapter}") + if parent_display and parent_code != code: + answer_parts.append(f"Catégorie parente : {parent_code} ({parent_display})") + + # Ajouter inclusions si présentes + incl = props.get("inclusionNote", "") + if incl and len(incl) < 300: + answer_parts.append(f"Comprend : {incl}") + + # Varier les formulations de question + templates = [ + f"Que désigne le code CIM-10 {code} ?", + f"Quel est le libellé du code {code} ?", + f"Décris le code CIM-10 {code}.", + ] + question = random.choice(templates) + + pairs.append(make_chatml(SYSTEM_MSG, question, "\n".join(answer_parts))) + + return pairs + + +def generate_coding_pairs(concepts, by_code): + """Type 2 : description → code (codage direct).""" + pairs = [] + for c in concepts: + props = get_props(c) + if props.get("type") not in ("category",): + continue + + code = c["code"] + display = clean_display(c["display"]) + + if not display or len(display) < 5: + continue + + # Réponse JSON structurée (format du pipeline T2A) + answer = json.dumps({ + "code": code, + "confidence": "high", + "justification": f"Correspondance directe avec le libellé CIM-10 : {code} {display}." + }, ensure_ascii=False) + + templates = [ + f"Quel est le code CIM-10 pour : {display} ?", + f"Code CIM-10 pour « {display} » ?", + f"Codage CIM-10 du diagnostic : {display}", + ] + question = random.choice(templates) + + pairs.append(make_chatml(SYSTEM_MSG, question, answer)) + + return pairs + + +def generate_discrimination_pairs(concepts, by_code): + """Type 3 : discrimination entre codes frères (même parent).""" + pairs = [] + + # Grouper par parent + children_by_parent = {} + for c in concepts: + props = get_props(c) + if props.get("type") != "category": + continue + parent = props.get("parent", "") + if parent and parent in by_code: + children_by_parent.setdefault(parent, []).append(c) + + for parent_code, children in children_by_parent.items(): + if len(children) < 2 or len(children) > 15: + continue + + parent = by_code[parent_code] + parent_display = clean_display(parent["display"]) + + # Construire la question + question = f"Quels sont les sous-codes de {parent_code} ({parent_display}) et comment les distinguer ?" + + # Construire la réponse + lines = [f"La catégorie {parent_code} ({parent_display}) comprend les codes suivants :\n"] + for child in children: + child_display = clean_display(child["display"]) + child_props = get_props(child) + line = f"- {child['code']} : {child_display}" + # Ajouter une note d'inclusion courte si disponible + incl = child_props.get("inclusionNote", "") + if incl and len(incl) < 150: + line += f" (comprend : {incl})" + lines.append(line) + + lines.append(f"\nLe choix du code dépend de la précision diagnostique disponible. " + f"En l'absence de précision, utiliser le code SAI (.9) s'il existe.") + + answer = "\n".join(lines) + + # Limiter la taille + if len(answer) > 2000: + continue + + pairs.append(make_chatml(SYSTEM_MSG, question, answer)) + + return pairs + + +def generate_inclusion_exclusion_pairs(concepts, by_code): + """Type 4 : questions sur les inclusions/exclusions d'un code.""" + pairs = [] + for c in concepts: + props = get_props(c) + if props.get("type") not in ("category",): + continue + + code = c["code"] + display = clean_display(c["display"]) + incl = props.get("inclusionNote", "") + excl_note = props.get("exclusionNote", "") + excl_codes = props.get("exclusion", "") + note = props.get("note", "") + + # Il faut au moins une inclusion OU exclusion + if not incl and not excl_note: + continue + + # Construire la réponse + answer_parts = [f"Code {code} — {display}\n"] + + if incl: + answer_parts.append(f"Ce code COMPREND :\n{incl}") + + if excl_note: + answer_parts.append(f"\nCe code EXCLUT :\n{excl_note}") + + if note: + answer_parts.append(f"\nNote : {note}") + + answer = "\n".join(answer_parts) + if len(answer) > 2000: + continue + + templates = [ + f"Quelles sont les inclusions et exclusions du code {code} ({display}) ?", + f"Que comprend et que exclut le code CIM-10 {code} ?", + ] + question = random.choice(templates) + + pairs.append(make_chatml(SYSTEM_MSG, question, answer)) + + return pairs + + +def main(): + print("Chargement du FHIR JSON...") + concepts, by_code = load_fhir() + print(f" {len(concepts)} concepts chargés") + + print("\nGénération des paires...") + + print(" Type 1 : code → description (lookup)") + lookup = generate_lookup_pairs(concepts, by_code) + print(f" → {len(lookup)} exemples") + + print(" Type 2 : description → code (codage)") + coding = generate_coding_pairs(concepts, by_code) + print(f" → {len(coding)} exemples") + + print(" Type 3 : discrimination codes frères") + discrim = generate_discrimination_pairs(concepts, by_code) + print(f" → {len(discrim)} exemples") + + print(" Type 4 : inclusions / exclusions") + incl_excl = generate_inclusion_exclusion_pairs(concepts, by_code) + print(f" → {len(incl_excl)} exemples") + + # Fusionner et mélanger + all_pairs = lookup + coding + discrim + incl_excl + random.shuffle(all_pairs) + + # Écrire en JSONL + output_path = OUT / "cim10_chatml.jsonl" + with open(output_path, "w") as f: + for pair in all_pairs: + f.write(json.dumps(pair, ensure_ascii=False) + "\n") + + print(f"\nTotal : {len(all_pairs)} exemples → {output_path}") + print(f"Taille : {output_path.stat().st_size / 1024 / 1024:.1f} Mo") + + # Stats par type + print("\nRépartition :") + print(f" Lookup (code→desc) : {len(lookup)}") + print(f" Codage (desc→code) : {len(coding)}") + print(f" Discrimination : {len(discrim)}") + print(f" Inclusions/Exclus. : {len(incl_excl)}") + + +if __name__ == "__main__": + main() diff --git a/scripts/02_generate_ccam_pairs.py b/scripts/02_generate_ccam_pairs.py new file mode 100644 index 0000000..e7e2b2b --- /dev/null +++ b/scripts/02_generate_ccam_pairs.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +""" +Phase 1B — Génération de paires ChatML CCAM depuis ccam_dict.json. + +Sources : ccam_dict.json (8 257 codes) du projet T2A +Produit : data/processed/ccam_chatml.jsonl + +Types d'exemples générés : + 1. code → description (lookup) + 2. description → code (codage) + 3. discrimination par regroupement (codes du même regroupement) +""" + +import json +import random +from pathlib import Path + +random.seed(42) + +BASE = Path(__file__).resolve().parent.parent +T2A = Path("/home/dom/ai/t2a") +OUT = BASE / "data" / "processed" +OUT.mkdir(parents=True, exist_ok=True) + +SYSTEM_MSG = "Tu es un médecin DIM expert en codage CCAM pour le PMSI français." + + +def load_ccam(): + """Charger le dictionnaire CCAM.""" + with open(T2A / "data" / "ccam_dict.json") as f: + return json.load(f) + + +def make_chatml(system, user, assistant): + return { + "messages": [ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + {"role": "assistant", "content": assistant}, + ] + } + + +def generate_lookup_pairs(ccam): + """Type 1 : code → description.""" + pairs = [] + for code, info in ccam.items(): + desc = info.get("description", "") + if not desc or len(desc) < 5: + continue + + regroupement = info.get("regroupement", "") + activite = info.get("activite", "") + tarif = info.get("tarif_s1") + + answer_parts = [f"{code} — {desc}"] + if regroupement: + answer_parts.append(f"Regroupement : {regroupement}") + if activite: + answer_parts.append(f"Activité : {activite}") + if tarif: + answer_parts.append(f"Tarif secteur 1 : {tarif} €") + + templates = [ + f"Que désigne le code CCAM {code} ?", + f"Quel est le libellé de l'acte CCAM {code} ?", + f"Décris l'acte CCAM {code}.", + ] + + pairs.append(make_chatml(SYSTEM_MSG, random.choice(templates), "\n".join(answer_parts))) + + return pairs + + +def generate_coding_pairs(ccam): + """Type 2 : description → code.""" + pairs = [] + for code, info in ccam.items(): + desc = info.get("description", "") + if not desc or len(desc) < 10: + continue + + answer = json.dumps({ + "code": code, + "confidence": "high", + "justification": f"Correspondance directe avec le libellé CCAM : {code} {desc}." + }, ensure_ascii=False) + + templates = [ + f"Quel est le code CCAM pour : {desc} ?", + f"Code CCAM pour « {desc} » ?", + f"Codage CCAM de l'acte : {desc}", + ] + + pairs.append(make_chatml(SYSTEM_MSG, random.choice(templates), answer)) + + return pairs + + +def generate_regroupement_pairs(ccam): + """Type 3 : regroupement → liste des actes du même regroupement.""" + pairs = [] + + # Grouper par regroupement + by_regroup = {} + for code, info in ccam.items(): + reg = info.get("regroupement", "") + if reg: + by_regroup.setdefault(reg, []).append((code, info)) + + for reg, actes in by_regroup.items(): + if len(actes) < 2 or len(actes) > 20: + continue + + question = f"Quels sont les actes CCAM du regroupement {reg} ?" + + lines = [f"Le regroupement {reg} comprend {len(actes)} actes :\n"] + for code, info in actes[:15]: + desc = info.get("description", "") + tarif = info.get("tarif_s1") + line = f"- {code} : {desc}" + if tarif: + line += f" ({tarif} €)" + lines.append(line) + + if len(actes) > 15: + lines.append(f" ... et {len(actes) - 15} autres actes.") + + answer = "\n".join(lines) + if len(answer) > 2000: + continue + + pairs.append(make_chatml(SYSTEM_MSG, question, answer)) + + return pairs + + +def main(): + print("Chargement du dictionnaire CCAM...") + ccam = load_ccam() + print(f" {len(ccam)} codes chargés") + + print("\nGénération des paires...") + + print(" Type 1 : code → description (lookup)") + lookup = generate_lookup_pairs(ccam) + print(f" → {len(lookup)} exemples") + + print(" Type 2 : description → code (codage)") + coding = generate_coding_pairs(ccam) + print(f" → {len(coding)} exemples") + + print(" Type 3 : regroupement") + regroup = generate_regroupement_pairs(ccam) + print(f" → {len(regroup)} exemples") + + all_pairs = lookup + coding + regroup + random.shuffle(all_pairs) + + output_path = OUT / "ccam_chatml.jsonl" + with open(output_path, "w") as f: + for pair in all_pairs: + f.write(json.dumps(pair, ensure_ascii=False) + "\n") + + print(f"\nTotal : {len(all_pairs)} exemples → {output_path}") + print(f"Taille : {output_path.stat().st_size / 1024 / 1024:.1f} Mo") + + +if __name__ == "__main__": + main() diff --git a/scripts/03_convert_cache.py b/scripts/03_convert_cache.py new file mode 100644 index 0000000..bec59ac --- /dev/null +++ b/scripts/03_convert_cache.py @@ -0,0 +1,206 @@ +#!/usr/bin/env python3 +""" +Phase 1C — Conversion du cache Ollama en exemples de raisonnement ChatML. + +Sources : ollama_cache.json (1 840 entrées avec raisonnement complet) +Produit : data/processed/reasoning_chatml.jsonl + +V2 : Utilise le cache actuel complet (1 840 entrées vs 100 avant). + Filtre pour ne garder que les entrées avec raisonnement structuré. + Supporte aussi les clés das_llm::das_extract:: du pipeline étendu. + +Chaque entrée du cache contient un raisonnement structuré : + - analyse_clinique → codes_candidats → discrimination → regle_pmsi → code + justification +Ces exemples sont les plus précieux car ils montrent le raisonnement DIM complet. +""" + +import json +import random +from pathlib import Path + +random.seed(42) + +BASE = Path(__file__).resolve().parent.parent +T2A = Path("/home/dom/ai/t2a") +OUT = BASE / "data" / "processed" +OUT.mkdir(parents=True, exist_ok=True) + +SYSTEM_MSG = "Tu es un médecin DIM expert en codage PMSI. Tu codes les diagnostics en CIM-10 en suivant une démarche structurée : analyse clinique, identification des codes candidats, discrimination, vérification des règles PMSI." + + +def make_chatml(system, user, assistant): + return { + "messages": [ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + {"role": "assistant", "content": assistant}, + ] + } + + +def load_cache(): + """Charger le cache Ollama (actuel + backup si disponible).""" + entries = {} + + # Cache actuel (1 840 entrées) + cache_path = T2A / "data" / "ollama_cache.json" + if cache_path.exists(): + with open(cache_path) as f: + data = json.load(f) + entries.update(data.get("entries", {})) + print(f" Cache actuel : {len(data.get('entries', {}))} entrées") + + # Cache backup (peut contenir des entrées supplémentaires) + backup_path = T2A / "data" / "ollama_cache_gemma3.bak" + if backup_path.exists(): + with open(backup_path) as f: + data = json.load(f) + backup_entries = data.get("entries", {}) + new_count = sum(1 for k in backup_entries if k not in entries) + entries.update(backup_entries) + print(f" Cache backup : {len(backup_entries)} entrées (+{new_count} nouvelles)") + + return entries + + +def parse_cache_key(key): + """Extraire le type (dp/das) et le texte depuis la clé du cache. + + Formats supportés : + - "dp::texte du diagnostic" + - "das::texte du diagnostic" + - "das_llm::das_extract::hash::texte" (pipeline étendu) + """ + if key.startswith("das_llm::das_extract::"): + # Format : das_llm::das_extract::HASH::texte + parts = key.split("::", 3) + texte = parts[3] if len(parts) > 3 else parts[-1] + return "das", texte.strip() + if "::" in key: + diag_type, texte = key.split("::", 1) + return diag_type.strip(), texte.strip() + return "das", key.strip() + + +def build_user_prompt(diag_type, texte): + """Construire le prompt utilisateur à partir du type et du texte.""" + type_label = "Diagnostic Principal (DP)" if diag_type == "dp" else "Diagnostic Associé Significatif (DAS)" + + prompt = f"Code ce diagnostic en CIM-10.\n\n" + prompt += f"DIAGNOSTIC : {texte.capitalize()}\n" + prompt += f"TYPE : {type_label}" + + return prompt + + +def build_assistant_response(entry): + """Construire la réponse structurée de l'assistant.""" + code = entry.get("code", "") + confidence = entry.get("confidence", "medium") + justification = entry.get("justification", "") + raisonnement = entry.get("raisonnement", "") + + # Si on a un raisonnement complet, le formater en JSON structuré + if raisonnement: + # Parser les sections du raisonnement + response = {} + + sections = { + "ANALYSE CLINIQUE": "analyse_clinique", + "CODES CANDIDATS": "codes_candidats", + "DISCRIMINATION": "discrimination", + "REGLE PMSI": "regle_pmsi", + "RÈGLE PMSI": "regle_pmsi", + } + + # Extraire chaque section du raisonnement + remaining = raisonnement + for header, key in sections.items(): + marker = f"{header} :" + if marker not in remaining: + marker = f"{header}:" + if marker in remaining: + idx = remaining.index(marker) + # Trouver la fin de cette section (début de la suivante ou fin) + end_idx = len(remaining) + for next_header in sections: + next_marker = f"{next_header} :" + next_marker2 = f"{next_header}:" + for nm in (next_marker, next_marker2): + if nm in remaining[idx + len(marker):]: + candidate = idx + len(marker) + remaining[idx + len(marker):].index(nm) + if candidate < end_idx: + end_idx = candidate + value = remaining[idx + len(marker):end_idx].strip() + if value: + response[key] = value + + response["code"] = code + response["confidence"] = confidence + if justification: + response["justification"] = justification + + return json.dumps(response, ensure_ascii=False, indent=None) + + # Si pas de raisonnement, réponse simple + response = { + "code": code, + "confidence": confidence, + } + if justification: + response["justification"] = justification + + return json.dumps(response, ensure_ascii=False, indent=None) + + +def main(): + print("Chargement du cache Ollama (toutes sources)...") + entries = load_cache() + print(f" Total fusionné : {len(entries)} entrées") + + pairs = [] + with_reasoning = 0 + without_reasoning = 0 + skipped_no_code = 0 + skipped_no_text = 0 + by_type = {"dp": 0, "das": 0} + + for key, entry in entries.items(): + diag_type, texte = parse_cache_key(key) + + if not texte or len(texte) < 3: + skipped_no_text += 1 + continue + if not entry.get("code"): + skipped_no_code += 1 + continue + + user_prompt = build_user_prompt(diag_type, texte) + assistant_response = build_assistant_response(entry) + + if entry.get("raisonnement"): + with_reasoning += 1 + else: + without_reasoning += 1 + + by_type[diag_type] = by_type.get(diag_type, 0) + 1 + pairs.append(make_chatml(SYSTEM_MSG, user_prompt, assistant_response)) + + random.shuffle(pairs) + + output_path = OUT / "reasoning_chatml.jsonl" + with open(output_path, "w") as f: + for pair in pairs: + f.write(json.dumps(pair, ensure_ascii=False) + "\n") + + print(f"\nTotal : {len(pairs)} exemples → {output_path}") + print(f" DP : {by_type.get('dp', 0)}, DAS : {by_type.get('das', 0)}") + print(f" Avec raisonnement complet : {with_reasoning}") + print(f" Sans raisonnement (code seul) : {without_reasoning}") + print(f" Ignorés (pas de code) : {skipped_no_code}") + print(f" Ignorés (pas de texte) : {skipped_no_text}") + print(f"Taille : {output_path.stat().st_size / 1024:.0f} Ko") + + +if __name__ == "__main__": + main() diff --git a/scripts/04_build_dataset.py b/scripts/04_build_dataset.py new file mode 100644 index 0000000..85feff9 --- /dev/null +++ b/scripts/04_build_dataset.py @@ -0,0 +1,202 @@ +#!/usr/bin/env python3 +""" +Phase 1 — Fusion de tous les sous-datasets en un dataset final. + +Lit tous les .jsonl dans data/processed/ et produit : + - data/datasets/pmsi_train.jsonl (90%) + - data/datasets/pmsi_eval.jsonl (10%) + - data/datasets/stats.json (statistiques) + +V2 : Rééquilibrage — réduction drastique des lookups, priorité raisonnement. +Ratio cible : ~35% lookups / 40% raisonnement / 25% règles (vs 95/3/2 avant). +""" + +import json +import random +from pathlib import Path + +random.seed(42) + +BASE = Path(__file__).resolve().parent.parent +PROCESSED = BASE / "data" / "processed" +DATASETS = BASE / "data" / "datasets" +DATASETS.mkdir(parents=True, exist_ok=True) + +EVAL_RATIO = 0.10 # 10% pour l'évaluation + +# V2 : Sous-échantillonnage agressif des lookups, garder tout le raisonnement +SUBSAMPLE = { + "cim10_chatml.jsonl": 1_500, # 50K → 1.5K (discrimination + inclusions/exclusions) + "ccam_chatml.jsonl": 1_500, # 16.5K → 1.5K (codes les plus fréquents) + "cocoa_chatml.jsonl": 2_000, # 30K → 2K (types clinical + tips prioritaires) + "referentiels_chatml.jsonl": None, # Garder tout (5.3K, déjà dense en règles) +} + +# Mots-clés indiquant du raisonnement structuré (pour sélection intelligente) +REASONING_KEYWORDS = [ + "analyse_clinique", "discrimination", "regle_pmsi", "règle_pmsi", + "codes_candidats", "justification", "ne pas coder", "ne doit pas", + "à l'exclusion", "comprend", "sévérité", "CMA", +] + + +def count_tokens_approx(text): + """Estimation grossière du nombre de tokens (~1.3 token par mot français).""" + return int(len(text.split()) * 1.3) + + +def _reasoning_score(example: dict) -> int: + """Score de raisonnement d'un exemple (plus haut = plus précieux).""" + text = " ".join(m.get("content", "") for m in example.get("messages", [])) + text_lower = text.lower() + score = 0 + for kw in REASONING_KEYWORDS: + if kw.lower() in text_lower: + score += 1 + # Bonus pour les réponses JSON structurées longues (raisonnement complet) + assistant_msgs = [m["content"] for m in example.get("messages", []) if m["role"] == "assistant"] + for msg in assistant_msgs: + if len(msg) > 200: + score += 2 + if '"analyse_clinique"' in msg: + score += 3 + return score + + +def _smart_subsample(examples: list[dict], target: int) -> list[dict]: + """Sous-échantillonnage intelligent : prioriser les exemples avec raisonnement.""" + # Trier par score de raisonnement (décroissant), puis shuffle pour diversité + scored = [(ex, _reasoning_score(ex)) for ex in examples] + scored.sort(key=lambda x: -x[1]) + + # Garder tous ceux avec score > 0, puis compléter avec du random + high_value = [(ex, s) for ex, s in scored if s > 0] + low_value = [(ex, s) for ex, s in scored if s == 0] + random.shuffle(low_value) + + if len(high_value) >= target: + # Trop d'exemples high-value : prendre les meilleurs + random parmi eux + random.shuffle(high_value) + selected = [ex for ex, _ in high_value[:target]] + else: + # Prendre tous les high-value + compléter avec low-value + selected = [ex for ex, _ in high_value] + remaining = target - len(selected) + selected.extend(ex for ex, _ in low_value[:remaining]) + + random.shuffle(selected) + return selected + + +def main(): + # Collecter tous les fichiers JSONL + files = sorted(PROCESSED.glob("*.jsonl")) + if not files: + print("Aucun fichier .jsonl trouvé dans data/processed/") + return + + all_examples = [] + file_stats = {} + + for f in files: + examples = [] + with open(f) as fh: + for line in fh: + line = line.strip() + if line: + examples.append(json.loads(line)) + + original_count = len(examples) + + # Sous-échantillonner si configuré + if f.name in SUBSAMPLE: + target = SUBSAMPLE[f.name] + if target is None or len(examples) <= target: + print(f" {f.name}: {len(examples)} exemples (gardé tout)") + else: + examples = _smart_subsample(examples, target) + print(f" {f.name}: {len(examples)} exemples (sous-échantillonné depuis {original_count})") + else: + print(f" {f.name}: {len(examples)} exemples") + + file_stats[f.name] = len(examples) + all_examples.extend(examples) + + print(f"\nTotal brut : {len(all_examples)} exemples") + + # Mélanger + random.shuffle(all_examples) + + # Split train / eval + n_eval = max(1, int(len(all_examples) * EVAL_RATIO)) + eval_set = all_examples[:n_eval] + train_set = all_examples[n_eval:] + + # Écrire les datasets + train_path = DATASETS / "pmsi_train.jsonl" + eval_path = DATASETS / "pmsi_eval.jsonl" + + for path, dataset in [(train_path, train_set), (eval_path, eval_set)]: + with open(path, "w") as fh: + for ex in dataset: + fh.write(json.dumps(ex, ensure_ascii=False) + "\n") + + # Statistiques + def compute_stats(dataset): + total_tokens = 0 + max_tokens = 0 + min_tokens = float("inf") + for ex in dataset: + text = " ".join(m["content"] for m in ex["messages"]) + tokens = count_tokens_approx(text) + total_tokens += tokens + max_tokens = max(max_tokens, tokens) + min_tokens = min(min_tokens, tokens) + avg = total_tokens / len(dataset) if dataset else 0 + return { + "count": len(dataset), + "total_tokens_approx": total_tokens, + "avg_tokens": round(avg), + "max_tokens": max_tokens, + "min_tokens": min_tokens if min_tokens != float("inf") else 0, + } + + train_stats = compute_stats(train_set) + eval_stats = compute_stats(eval_set) + + stats = { + "sources": file_stats, + "total": len(all_examples), + "train": train_stats, + "eval": eval_stats, + "eval_ratio": EVAL_RATIO, + } + + stats_path = DATASETS / "stats.json" + with open(stats_path, "w") as fh: + json.dump(stats, fh, indent=2, ensure_ascii=False) + + # Calculer le ratio raisonnement + n_reasoning = sum(1 for ex in all_examples if _reasoning_score(ex) >= 3) + pct_reasoning = 100.0 * n_reasoning / len(all_examples) if all_examples else 0 + + # Affichage + print(f"\n{'='*50}") + print(f"Dataset final (V2 rééquilibré) :") + print(f" Train : {train_stats['count']} exemples → {train_path.name}") + print(f" Eval : {eval_stats['count']} exemples → {eval_path.name}") + print(f"\nRatio raisonnement structuré (score≥3) : {n_reasoning}/{len(all_examples)} ({pct_reasoning:.0f}%)") + print(f"\nTokens (estimation) :") + print(f" Train : ~{train_stats['total_tokens_approx']:,} tokens (moy: {train_stats['avg_tokens']}, max: {train_stats['max_tokens']})") + print(f" Eval : ~{eval_stats['total_tokens_approx']:,} tokens (moy: {eval_stats['avg_tokens']}, max: {eval_stats['max_tokens']})") + print(f"\nSources :") + for name, count in sorted(file_stats.items()): + print(f" {name}: {count}") + print(f"\nFichiers :") + print(f" {train_path} ({train_path.stat().st_size / 1024 / 1024:.1f} Mo)") + print(f" {eval_path} ({eval_path.stat().st_size / 1024 / 1024:.1f} Mo)") + print(f" {stats_path}") + + +if __name__ == "__main__": + main() diff --git a/scripts/05_parse_cocoa.py b/scripts/05_parse_cocoa.py new file mode 100644 index 0000000..f784a1c --- /dev/null +++ b/scripts/05_parse_cocoa.py @@ -0,0 +1,671 @@ +#!/usr/bin/env python3 +""" +Phase 1E — Parsing du CoCoA 2025 (1113 pages) pour extraction d'exemples ChatML. + +Le CoCoA (Codage Complet Annoté) est le vademecum des médecins DIM. +Il contient des entrées détaillées par code CIM-10 avec : + - Indicateurs P/R/A (Diagnostic Principal / Relié / Associé) + - Niveaux de sévérité (2, 3, 4) + - Descriptions cliniques détaillées + - Synonymes + - Comprend / À l'exclusion de + - Notes AGORA (FAQ ATIH) + - Annotations CoCoA (conseils pratiques DIM) + +Pages traitées : 85-1080 (entrées détaillées, chapitres 1-22) + +Produit : data/processed/cocoa_chatml.jsonl +""" + +import json +import re +import random +from pathlib import Path + +random.seed(42) + +BASE = Path(__file__).resolve().parent.parent +RAW = BASE / "data" / "raw" +OUT = BASE / "data" / "processed" +OUT.mkdir(parents=True, exist_ok=True) + +SYSTEM_MSG = "Tu es un médecin DIM expert en codage CIM-10 pour le PMSI français. Tu t'appuies sur le CoCoA (Codage Complet Annoté) pour tes décisions de codage." + +# Pages des entrées détaillées (0-indexed) +PAGE_START = 84 # page 85 +PAGE_END = 1080 # page 1080 + +# Regex patterns +RE_CIM10_CODE = re.compile( + r'^([A-Z]\d{2}(?:\.\d{1,2})?)\s*([†*]?)\s+(.*)' +) +RE_CATEGORY_CODE = re.compile( + r'^([A-Z]\d{2})\s+(.*)' +) +RE_SUBCODE = re.compile( + r'^([A-Z]\d{2}\.\d{1,2})\s*([†*]?)\s*(.*)' +) +RE_PRA_LINE = re.compile(r'^P\s*R\s*A') +RE_SEVERITY = re.compile(r'^(\d)\s*$') +RE_CHAPTER_HEADER = re.compile(r'^CHAPITRE\s+([IVX]+)\s*:?\s*(.*)') +RE_SECTION_HEADER = re.compile(r'^([A-Z][a-zéèêëàâîïôùûüç].+)\s*\(([A-Z]\d{2}[-–][A-Z]\d{2})\)') +RE_EXCLUSION = re.compile(r"^À l['\u2019]exclusion de\s+(.*)", re.IGNORECASE) +RE_COMPREND = re.compile(r'^Comprend\s+(.*)', re.IGNORECASE) +RE_AGORA = re.compile(r'\(AGORA\s*[-–]\s*#?\s*(\d+).*?\)') +RE_FOOTER = re.compile(r'^2025\s*[-–]') +RE_NOTE_BRACKET = re.compile(r'^\[voir en début') + + +def extract_text_from_pdf(): + """Extraire le texte de toutes les pages détaillées du CoCoA.""" + import pdfplumber + + pdf_path = RAW / "cocoa_2025.pdf" + print(f"Ouverture de {pdf_path}...") + + pages_text = [] + with pdfplumber.open(pdf_path) as pdf: + total = min(PAGE_END, len(pdf.pages)) + for i in range(PAGE_START, total): + page = pdf.pages[i] + text = page.extract_text() or "" + pages_text.append((i + 1, text)) # (page_number, text) + + if (i - PAGE_START) % 100 == 0: + print(f" Extraction page {i+1}/{total}...") + + print(f" {len(pages_text)} pages extraites") + return pages_text + + +def parse_entries(pages_text): + """Parser les entrées CIM-10 depuis le texte extrait.""" + entries = {} # code -> dict + current_chapter = "" + current_section = "" + current_code = None + current_entry = None + collecting_exclusion = False + collecting_comprend = False + collecting_description = False + + for page_num, page_text in pages_text: + lines = page_text.split('\n') + + for line_idx, line in enumerate(lines): + line = line.strip() + + # Skip empty lines and footers + if not line: + collecting_exclusion = False + collecting_comprend = False + collecting_description = False + continue + + if RE_FOOTER.match(line): + collecting_exclusion = False + collecting_comprend = False + collecting_description = False + continue + + if RE_NOTE_BRACKET.match(line): + collecting_description = False + continue + + # Chapter header + m = RE_CHAPTER_HEADER.match(line) + if m: + current_chapter = m.group(2).strip() + collecting_exclusion = False + collecting_comprend = False + collecting_description = False + continue + + # Skip P R A indicator lines (standalone) + if RE_PRA_LINE.match(line): + # Check if there's a code on the same line + rest = re.sub(r'^P\s*R\s*A\s*', '', line).strip() + # Also remove "AN T" or similar special markers + rest = re.sub(r'^AN\s*T?\s*', '', rest).strip() + + if rest: + # P R A followed by code on same line (category code) + m_cat = RE_CATEGORY_CODE.match(rest) + m_sub = RE_SUBCODE.match(rest) + if m_sub: + code = m_sub.group(1) + dagger_star = m_sub.group(2) + desc = m_sub.group(3).strip() + _save_entry(entries, current_code, current_entry) + current_code = code + current_entry = _new_entry(code, desc, dagger_star, current_chapter, page_num, is_category=False) + collecting_exclusion = False + collecting_comprend = False + collecting_description = False + elif m_cat: + code = m_cat.group(1) + desc = m_cat.group(2).strip() + _save_entry(entries, current_code, current_entry) + current_code = code + current_entry = _new_entry(code, desc, "", current_chapter, page_num, is_category=True) + collecting_exclusion = False + collecting_comprend = False + collecting_description = False + continue + + # Severity number on its own line + m = RE_SEVERITY.match(line) + if m and current_entry: + current_entry["severity"] = int(m.group(1)) + continue + + # Sub-code entry + m = RE_SUBCODE.match(line) + if m: + code = m.group(1) + dagger_star = m.group(2) + desc = m.group(3).strip() + _save_entry(entries, current_code, current_entry) + current_code = code + current_entry = _new_entry(code, desc, dagger_star, current_chapter, page_num, is_category=False) + collecting_exclusion = False + collecting_comprend = False + collecting_description = False + continue + + # Category code (3-char code at start of line, no dot) + m = RE_CATEGORY_CODE.match(line) + if m and not line[0].islower() and len(m.group(1)) == 3: + # Make sure it's actually a code and not part of text + potential_code = m.group(1) + if re.match(r'^[A-Z]\d{2}$', potential_code): + desc = m.group(2).strip() + # Avoid false positives - check that desc looks like a title + if desc and len(desc) > 3 and not desc[0].isdigit(): + _save_entry(entries, current_code, current_entry) + current_code = potential_code + current_entry = _new_entry(potential_code, desc, "", current_chapter, page_num, is_category=True) + collecting_exclusion = False + collecting_comprend = False + collecting_description = False + continue + + # Section header (e.g., "Autres maladies bactériennes (A30-A49)") + m = RE_SECTION_HEADER.match(line) + if m: + current_section = m.group(1).strip() + collecting_exclusion = False + collecting_comprend = False + collecting_description = False + continue + + # Comprend + m = RE_COMPREND.match(line) + if m: + if current_entry: + current_entry["comprend"].append(m.group(1).strip()) + collecting_comprend = True + collecting_exclusion = False + collecting_description = False + continue + + # À l'exclusion de + m = RE_EXCLUSION.match(line) + if m: + if current_entry: + current_entry["exclusions"].append(m.group(1).strip()) + collecting_exclusion = True + collecting_comprend = False + collecting_description = False + continue + + # AGORA reference + agora_matches = RE_AGORA.findall(line) + if agora_matches and current_entry: + for ref in agora_matches: + current_entry["agora_refs"].append(ref) + # Also add the full line as a CoCoA annotation + if "AGORA" in line or "Aunis" in line.lower() or "CoCoA" in line: + current_entry["cocoa_notes"].append(line) + continue + + # CoCoA/Aunis annotations (highlighted text) + if current_entry and ("Aunis" in line or "CoCoA" in line): + current_entry["cocoa_notes"].append(line) + continue + + # Continuation lines for exclusions + if collecting_exclusion and current_entry: + # Exclusion continuation - items with code refs, bullets, lowercase starts + if (re.search(r'\([A-Z]\d{2}', line) or + line.startswith('•') or line.startswith('-') or + line[0].islower() or + re.match(r'^[a-zéèêëàâîïôùûüç•\-]', line)): + current_entry["exclusions"].append(line) + continue + else: + collecting_exclusion = False + + # Continuation lines for comprend + if collecting_comprend and current_entry: + if not re.match(r'^[A-Z]\d', line) and not RE_PRA_LINE.match(line): + current_entry["comprend"].append(line) + continue + else: + collecting_comprend = False + + # Clinical description text (paragraph after a code entry) + if current_entry and line and not RE_PRA_LINE.match(line): + # Check if it's a synonym or clinical text + if len(line) > 60 and not re.match(r'^[A-Z]\d', line): + # Long text = clinical description + current_entry["clinical_text"].append(line) + elif not re.match(r'^[A-Z]\d', line) and not line.startswith('P '): + # Short text after a code = synonym + current_entry["synonyms"].append(line) + + # Save last entry + _save_entry(entries, current_code, current_entry) + + return entries + + +def _new_entry(code, description, dagger_star, chapter, page, is_category=False): + return { + "code": code, + "description": description, + "dagger_star": dagger_star, + "chapter": chapter, + "page": page, + "is_category": is_category, + "severity": None, + "synonyms": [], + "comprend": [], + "exclusions": [], + "clinical_text": [], + "agora_refs": [], + "cocoa_notes": [], + } + + +def _save_entry(entries, code, entry): + if code and entry and entry["description"]: + # Clean up + entry["synonyms"] = [s.strip() for s in entry["synonyms"] if s.strip() and len(s.strip()) > 2] + entry["comprend"] = [c.strip() for c in entry["comprend"] if c.strip()] + entry["exclusions"] = [e.strip() for e in entry["exclusions"] if e.strip()] + entry["clinical_text"] = [t.strip() for t in entry["clinical_text"] if t.strip()] + entry["cocoa_notes"] = [n.strip() for n in entry["cocoa_notes"] if n.strip()] + + # Deduplicate + entry["synonyms"] = list(dict.fromkeys(entry["synonyms"])) + entry["cocoa_notes"] = list(dict.fromkeys(entry["cocoa_notes"])) + + # Filter out noise from synonyms and move misclassified exclusions + filtered_syns = [] + re_excl_inline = re.compile(r"^À l['\u2019]exclusion de", re.IGNORECASE) + for s in entry["synonyms"]: + # Skip severity numbers, P R A markers, etc. + if RE_SEVERITY.match(s) or RE_PRA_LINE.match(s) or RE_FOOTER.match(s): + continue + if s in ("P R A", "P", "R", "A", "AN", "T"): + continue + # Move misclassified exclusions + if re_excl_inline.match(s): + excl_text = re.sub(r"^À l['\u2019]exclusion de\s*", '', s, flags=re.IGNORECASE).strip() + if excl_text: + entry["exclusions"].append(excl_text) + continue + filtered_syns.append(s) + entry["synonyms"] = filtered_syns + + # Also clean comprend - move misclassified exclusions + filtered_comprend = [] + for c in entry["comprend"]: + if re_excl_inline.match(c): + excl_text = re.sub(r"^À l['\u2019]exclusion de\s*", '', c, flags=re.IGNORECASE).strip() + if excl_text: + entry["exclusions"].append(excl_text) + else: + filtered_comprend.append(c) + entry["comprend"] = filtered_comprend + + entries[code] = entry + + +def make_chatml(system, user, assistant): + return { + "messages": [ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + {"role": "assistant", "content": assistant}, + ] + } + + +def generate_description_pairs(entries): + """Type 1 : Description enrichie CoCoA d'un code (vs FHIR plus basique).""" + pairs = [] + for code, e in entries.items(): + desc = e["description"] + if not desc or len(desc) < 3: + continue + + answer_parts = [f"{code} — {desc}"] + + if e["chapter"]: + answer_parts.append(f"Chapitre : {e['chapter']}") + + if e["synonyms"]: + syns = [s for s in e["synonyms"][:8] if len(s) > 2] + if syns: + answer_parts.append(f"Synonymes : {' ; '.join(syns)}") + + if e["comprend"]: + answer_parts.append(f"Comprend : {' '.join(e['comprend'][:5])}") + + if e["exclusions"]: + excls = [ex for ex in e["exclusions"][:5]] + answer_parts.append(f"À l'exclusion de : {' ; '.join(excls)}") + + if e["severity"]: + answer_parts.append(f"Niveau de sévérité CMA : {e['severity']}") + + if e["dagger_star"]: + marker = "étiologique (†)" if e["dagger_star"] == "†" else "manifestation (*)" + answer_parts.append(f"Convention dague/astérisque : code {marker}") + + answer = "\n".join(answer_parts) + if len(answer) > 2000: + answer = answer[:2000] + + templates = [ + f"Décris le code CIM-10 {code} selon le CoCoA.", + f"Que dit le CoCoA sur le code {code} ?", + f"Quelles sont les caractéristiques du code {code} d'après le CoCoA ?", + ] + + pairs.append(make_chatml(SYSTEM_MSG, random.choice(templates), answer)) + + return pairs + + +def generate_clinical_pairs(entries): + """Type 2 : Descriptions cliniques détaillées → code + raisonnement.""" + pairs = [] + for code, e in entries.items(): + if not e["clinical_text"]: + continue + + clinical = " ".join(e["clinical_text"]) + if len(clinical) < 50: + continue + + desc = e["description"] + + # Construire un raisonnement structuré + reasoning = { + "analyse_clinique": clinical[:500], + "code": code, + "description": desc, + "confidence": "high", + "justification": f"La description clinique du CoCoA correspond au code {code} ({desc})." + } + + if e["exclusions"]: + reasoning["exclusions_a_verifier"] = " ; ".join(e["exclusions"][:3]) + + answer = json.dumps(reasoning, ensure_ascii=False) + + # Créer une question à partir du texte clinique (tronqué) + clinical_short = clinical[:300] + if len(clinical) > 300: + clinical_short += "..." + + question = f"Un patient présente le tableau clinique suivant :\n{clinical_short}\n\nQuel code CIM-10 correspond à cette présentation ?" + + if len(question) > 1500: + continue + + pairs.append(make_chatml(SYSTEM_MSG, question, answer)) + + return pairs + + +def generate_synonym_pairs(entries): + """Type 3 : Synonyme → code CIM-10.""" + pairs = [] + for code, e in entries.items(): + if not e["synonyms"]: + continue + + desc = e["description"] + + for syn in e["synonyms"]: + if len(syn) < 4 or len(syn) > 200: + continue + # Skip entries that look like noise + if syn.startswith("•") or syn.startswith("[") or syn.startswith("("): + syn = syn.lstrip("•[( ").rstrip("])").strip() + if not syn or len(syn) < 4: + continue + + answer = json.dumps({ + "code": code, + "confidence": "high", + "justification": f"« {syn} » est un synonyme de {code} ({desc}) selon le CoCoA." + }, ensure_ascii=False) + + templates = [ + f"Quel est le code CIM-10 pour : {syn} ?", + f"Code CIM-10 correspondant à « {syn} » ?", + ] + + pairs.append(make_chatml(SYSTEM_MSG, random.choice(templates), answer)) + + return pairs + + +def generate_exclusion_pairs(entries): + """Type 4 : Questions sur ce qu'un code exclut (piège de codage).""" + pairs = [] + for code, e in entries.items(): + if not e["exclusions"]: + continue + + desc = e["description"] + excls = " ; ".join(e["exclusions"][:8]) + + if len(excls) < 10: + continue + + answer = f"Le code {code} ({desc}) exclut :\n{excls}\n\nAttention : ces situations doivent être codées avec les codes de renvoi indiqués entre parenthèses." + + if len(answer) > 1500: + answer = answer[:1500] + + templates = [ + f"Quelles sont les exclusions du code CIM-10 {code} ({desc}) ?", + f"Que ne faut-il PAS coder en {code} ?", + ] + + pairs.append(make_chatml(SYSTEM_MSG, random.choice(templates), answer)) + + return pairs + + +def generate_severity_pairs(entries): + """Type 5 : Questions sur le niveau de sévérité CMA d'un code.""" + pairs = [] + for code, e in entries.items(): + if not e["severity"]: + continue + + desc = e["description"] + sev = e["severity"] + + sev_text = { + 2: "niveau 2 (sévérité modérée)", + 3: "niveau 3 (sévérité élevée)", + 4: "niveau 4 (sévérité très élevée)", + }.get(sev, f"niveau {sev}") + + answer = f"Le code {code} ({desc}) a un niveau de sévérité CMA de {sev_text}.\n" + answer += f"En tant que DAS, ce code peut entraîner une majoration du niveau de sévérité du GHM." + + if e["is_category"]: + answer += f"\nNote : {code} est une catégorie (code à 3 caractères). Les sous-codes peuvent avoir des niveaux différents." + + pairs.append(make_chatml( + SYSTEM_MSG, + f"Quel est le niveau de sévérité CMA du code {code} ({desc}) ?", + answer + )) + + return pairs + + +def generate_cocoa_tips_pairs(entries): + """Type 6 : Notes CoCoA et AGORA (conseils pratiques DIM).""" + pairs = [] + for code, e in entries.items(): + if not e["cocoa_notes"]: + continue + + desc = e["description"] + notes = "\n".join(e["cocoa_notes"]) + + if len(notes) < 10: + continue + + answer = f"Pour le code {code} ({desc}), le CoCoA indique :\n{notes}" + + if len(answer) > 1500: + answer = answer[:1500] + + pairs.append(make_chatml( + SYSTEM_MSG, + f"Y a-t-il des conseils pratiques du CoCoA pour le codage de {code} ({desc}) ?", + answer + )) + + return pairs + + +def generate_comprend_pairs(entries): + """Type 7 : Ce que comprend un code (inclusions).""" + pairs = [] + for code, e in entries.items(): + if not e["comprend"]: + continue + + desc = e["description"] + comprend = " ; ".join(e["comprend"][:5]) + + if len(comprend) < 10: + continue + + answer = f"Le code {code} ({desc}) comprend :\n{comprend}" + + templates = [ + f"Que comprend le code CIM-10 {code} ?", + f"Quelles situations sont incluses dans le code {code} ({desc}) ?", + ] + + pairs.append(make_chatml(SYSTEM_MSG, random.choice(templates), answer)) + + return pairs + + +def main(): + # Étape 1 : Extraction du texte + pages_text = extract_text_from_pdf() + + # Étape 2 : Parsing des entrées + print("\nParsing des entrées CIM-10...") + entries = parse_entries(pages_text) + + # Stats + n_categories = sum(1 for e in entries.values() if e["is_category"]) + n_subcodes = sum(1 for e in entries.values() if not e["is_category"]) + n_with_clinical = sum(1 for e in entries.values() if e["clinical_text"]) + n_with_synonyms = sum(1 for e in entries.values() if e["synonyms"]) + n_with_exclusions = sum(1 for e in entries.values() if e["exclusions"]) + n_with_comprend = sum(1 for e in entries.values() if e["comprend"]) + n_with_severity = sum(1 for e in entries.values() if e["severity"]) + n_with_cocoa = sum(1 for e in entries.values() if e["cocoa_notes"]) + + print(f"\n Entrées parsées : {len(entries)}") + print(f" Catégories (3 car.) : {n_categories}") + print(f" Sous-codes : {n_subcodes}") + print(f" Avec texte clinique : {n_with_clinical}") + print(f" Avec synonymes : {n_with_synonyms}") + print(f" Avec exclusions : {n_with_exclusions}") + print(f" Avec comprend : {n_with_comprend}") + print(f" Avec sévérité CMA : {n_with_severity}") + print(f" Avec notes CoCoA : {n_with_cocoa}") + + # Étape 3 : Génération des paires ChatML + print("\nGénération des paires ChatML...") + + print(" Type 1 : Descriptions enrichies CoCoA") + desc_pairs = generate_description_pairs(entries) + print(f" → {len(desc_pairs)} exemples") + + print(" Type 2 : Texte clinique → code") + clinical_pairs = generate_clinical_pairs(entries) + print(f" → {len(clinical_pairs)} exemples") + + print(" Type 3 : Synonyme → code") + synonym_pairs = generate_synonym_pairs(entries) + print(f" → {len(synonym_pairs)} exemples") + + print(" Type 4 : Exclusions") + exclusion_pairs = generate_exclusion_pairs(entries) + print(f" → {len(exclusion_pairs)} exemples") + + print(" Type 5 : Sévérité CMA") + severity_pairs = generate_severity_pairs(entries) + print(f" → {len(severity_pairs)} exemples") + + print(" Type 6 : Notes CoCoA/AGORA") + cocoa_pairs = generate_cocoa_tips_pairs(entries) + print(f" → {len(cocoa_pairs)} exemples") + + print(" Type 7 : Comprend (inclusions)") + comprend_pairs = generate_comprend_pairs(entries) + print(f" → {len(comprend_pairs)} exemples") + + # Fusionner et mélanger + all_pairs = desc_pairs + clinical_pairs + synonym_pairs + exclusion_pairs + severity_pairs + cocoa_pairs + comprend_pairs + random.shuffle(all_pairs) + + # Écrire le JSONL + output_path = OUT / "cocoa_chatml.jsonl" + with open(output_path, "w") as f: + for pair in all_pairs: + f.write(json.dumps(pair, ensure_ascii=False) + "\n") + + print(f"\n{'='*50}") + print(f"Total : {len(all_pairs)} exemples → {output_path}") + print(f"Taille : {output_path.stat().st_size / 1024 / 1024:.1f} Mo") + + # Sauvegarder aussi les entrées parsées en JSON pour debug + debug_path = OUT / "cocoa_entries_debug.json" + with open(debug_path, "w") as f: + json.dump(entries, f, indent=2, ensure_ascii=False) + print(f"Debug : {debug_path} ({debug_path.stat().st_size / 1024 / 1024:.1f} Mo)") + + # Répartition + print(f"\nRépartition :") + print(f" Descriptions CoCoA : {len(desc_pairs)}") + print(f" Texte clinique→code : {len(clinical_pairs)}") + print(f" Synonyme→code : {len(synonym_pairs)}") + print(f" Exclusions : {len(exclusion_pairs)}") + print(f" Sévérité CMA : {len(severity_pairs)}") + print(f" Notes CoCoA/AGORA : {len(cocoa_pairs)}") + print(f" Comprend (inclusions): {len(comprend_pairs)}") + + +if __name__ == "__main__": + main() diff --git a/scripts/06_generate_synthetic.py b/scripts/06_generate_synthetic.py new file mode 100644 index 0000000..4b0c328 --- /dev/null +++ b/scripts/06_generate_synthetic.py @@ -0,0 +1,562 @@ +#!/usr/bin/env python3 +""" +Phase 1D — Génération de données synthétiques via API OpenAI (GPT-4o). + +Envoie des métadonnées anonymisées FAISS à un grand modèle pour générer +des exemples de raisonnement DIM complet en format ChatML. + +Types d'exemples générés : + 1. Scénario clinique → raisonnement DIM → code CIM-10 + 2. Discrimination entre codes proches + 3. Application des règles PMSI (DP/DAS, CMA, exclusions) + +Nécessite : OPENAI_API_KEY en variable d'environnement + +Usage : + python scripts/06_generate_synthetic.py [--n 500] [--batch 5] [--model gpt-4o] [--dry-run] +""" + +import json +import os +import sys +import time +import random +import argparse +from pathlib import Path +from concurrent.futures import ThreadPoolExecutor, as_completed + +random.seed(42) + +BASE = Path(__file__).resolve().parent.parent +T2A = Path("/home/dom/ai/t2a") +OUT = BASE / "data" / "processed" +OUT.mkdir(parents=True, exist_ok=True) + +# --- Prompts --- + +SYSTEM_PROMPT_SCENARIO = """Tu es un formateur DIM (Département d'Information Médicale) expert en codage PMSI. +Tu génères des scénarios cliniques réalistes et anonymisés pour former des médecins DIM au codage CIM-10. + +Pour chaque code CIM-10 fourni, tu dois produire : +1. Un SCÉNARIO CLINIQUE réaliste (3-5 phrases, anonymisé, comme extrait d'un compte-rendu d'hospitalisation) +2. Un RAISONNEMENT DIM structuré montrant la démarche de codage + +Le raisonnement doit suivre ces étapes : +- analyse_clinique : ce que le texte clinique révèle +- codes_candidats : 2-3 codes CIM-10 envisageables avec leur libellé +- discrimination : pourquoi le code retenu est le bon (et pas les autres) +- regle_pmsi : règle PMSI applicable (DP/DAS, exclusions, conventions dague/astérisque, etc.) +- code : le code CIM-10 retenu +- confidence : high/medium/low +- justification : synthèse en 1 phrase + +IMPORTANT : +- Les scénarios doivent être VARIÉS (âges, sexes, contextes différents) +- Anonymisés (pas de vrais noms/dates) +- Médicalement cohérents +- En français médical professionnel +- Réponse en JSON valide uniquement""" + +SYSTEM_PROMPT_DISCRIM = """Tu es un formateur DIM expert en codage PMSI. +Tu crées des exercices de discrimination entre codes CIM-10 proches pour former des médecins DIM. + +Pour chaque groupe de codes fourni, génère UN scénario clinique où le choix entre les codes est subtil, +puis montre le raisonnement complet pour arriver au bon code. + +IMPORTANT : Réponse en JSON valide uniquement.""" + +SYSTEM_PROMPT_RULES = """Tu es un formateur DIM expert en règles PMSI. +Tu crées des exercices d'application des règles PMSI (codage DP/DAS, CMA, séjours multi-unités, etc.). + +Pour chaque situation fournie, génère un scénario d'hospitalisation et montre comment les règles PMSI +s'appliquent au codage. + +IMPORTANT : Réponse en JSON valide uniquement.""" + + +def load_faiss_metadata(): + """Charger les métadonnées FAISS.""" + meta_path = T2A / "data" / "rag_index" / "metadata.json" + with open(meta_path) as f: + return json.load(f) + + +def load_cim10_fhir(): + """Charger les concepts FHIR pour enrichir les prompts.""" + fhir_path = BASE / "data" / "raw" / "smt_cim10_fhir.json" + if not fhir_path.exists(): + return {} + with open(fhir_path) as f: + data = json.load(f) + by_code = {} + for c in data.get("concept", []): + by_code[c["code"]] = c + return by_code + + +def load_cocoa_entries(): + """Charger les entrées CoCoA parsées.""" + cocoa_path = OUT / "cocoa_entries_debug.json" + if not cocoa_path.exists(): + return {} + with open(cocoa_path) as f: + return json.load(f) + + +def clean_extrait(extrait): + """Nettoyer un extrait FAISS (enlever bruit OCR, numéros de page, etc.).""" + import re + # Enlever les numéros de page isolés (sur leur propre ligne ou collés) + extrait = re.sub(r'\n\s*\d{1,4}\s*\n', '\n', extrait) + extrait = re.sub(r'^\d{1,4}\s*\n', '', extrait) + # Enlever les transitions de chapitre + extrait = re.sub(r'Chapitre\s+[IVX]+\b.*', '', extrait) + # Enlever les lignes de classification + extrait = re.sub(r'Classification Internationale.*$', '', extrait, flags=re.MULTILINE) + # Enlever les lignes vides multiples + extrait = re.sub(r'\n{2,}', '\n', extrait) + # Tronquer au premier code d'une AUTRE catégorie + lines = extrait.split('\n') + first_code = None + clean_lines = [] + for line in lines: + stripped = line.strip() + # Ligne ne contenant qu'un nombre = bruit + if re.match(r'^\d{1,4}$', stripped): + continue + m = re.match(r'^([A-Z]\d{2}(?:\.\d{1,2})?)[*†]?\s', stripped) + if m: + code = m.group(1) + if first_code is None: + first_code = code + elif not code.startswith(first_code[:3]): + break # On est passé à un autre groupe de codes + clean_lines.append(stripped) + result = '\n'.join(clean_lines).strip() + # Limiter la longueur + if len(result) > 400: + result = result[:400].rsplit('\n', 1)[0] + return result + + +def select_codes_for_scenarios(metadata, n=500): + """Sélectionner les codes CIM-10 les plus intéressants pour la génération.""" + # Filtrer les entrées CIM-10 avec des extraits substantiels + cim10_entries = [m for m in metadata if m.get("document") == "cim10" and len(m.get("extrait", "")) > 20] + + # Prioriser les codes avec des extraits riches + cim10_entries.sort(key=lambda x: len(x.get("extrait", "")), reverse=True) + + # Prendre les codes uniques + seen = set() + selected = [] + for m in cim10_entries: + code = m["code"] + if code not in seen and "." in code: # Préférer les sous-codes (plus spécifiques) + seen.add(code) + selected.append(m) + if len(selected) >= n: + break + + # Si pas assez de sous-codes, ajouter des catégories + if len(selected) < n: + for m in cim10_entries: + code = m["code"] + if code not in seen: + seen.add(code) + selected.append(m) + if len(selected) >= n: + break + + random.shuffle(selected) + return selected[:n] + + +def select_discrimination_groups(metadata, cocoa_entries, n=100): + """Sélectionner des groupes de codes proches pour la discrimination.""" + # Grouper par catégorie parente (3 premiers caractères) + by_parent = {} + for m in metadata: + if m.get("document") != "cim10": + continue + code = m.get("code", "") + if "." in code: + parent = code.split(".")[0] + by_parent.setdefault(parent, []).append(m) + + # Sélectionner les groupes avec 2-6 sous-codes + groups = [] + for parent, children in by_parent.items(): + if 2 <= len(children) <= 6: + groups.append({ + "parent": parent, + "codes": [{"code": c["code"], "extrait": clean_extrait(c["extrait"])[:150]} for c in children] + }) + + random.shuffle(groups) + return groups[:n] + + +def build_scenario_prompt(codes_batch, fhir_by_code, cocoa_entries): + """Construire le prompt pour un batch de codes (scénarios cliniques).""" + items = [] + for meta in codes_batch: + code = meta["code"] + + # Source primaire : FHIR (propre) + fhir = fhir_by_code.get(code, {}) + display = fhir.get("display", "") + + # Source secondaire : CoCoA (riche) + cocoa = cocoa_entries.get(code, {}) + cocoa_desc = cocoa.get("description", "") + exclusions = cocoa.get("exclusions", [])[:4] + synonyms = cocoa.get("synonyms", [])[:5] + comprend = cocoa.get("comprend", [])[:3] + severity = cocoa.get("severity") + clinical = " ".join(cocoa.get("clinical_text", []))[:200] + + # Utiliser la meilleure description disponible + desc = display or cocoa_desc or clean_extrait(meta["extrait"])[:200] + + item = f"CODE: {code}\nLIBELLÉ: {desc}" + if synonyms: + item += f"\nSYNONYMES: {'; '.join(synonyms)}" + if comprend: + item += f"\nCOMPREND: {'; '.join(comprend)}" + if exclusions: + item += f"\nEXCLUSIONS: {'; '.join(exclusions)}" + if severity: + item += f"\nSÉVÉRITÉ CMA: {severity}" + if clinical: + item += f"\nDESCRIPTION CLINIQUE: {clinical}" + items.append(item) + + codes_text = "\n---\n".join(items) + + prompt = f"""Génère un scénario clinique et un raisonnement DIM pour chacun des {len(codes_batch)} codes suivants. + +{codes_text} + +Réponds en JSON avec cette structure exacte : +{{"scenarios": [ + {{ + "code": "le code CIM-10", + "scenario_clinique": "Texte du scénario clinique réaliste (3-5 phrases)", + "raisonnement": {{ + "analyse_clinique": "Analyse des éléments cliniques pertinents", + "codes_candidats": "2-3 codes envisagés avec libellés", + "discrimination": "Pourquoi ce code et pas les autres", + "regle_pmsi": "Règle PMSI applicable", + "code_retenu": "le code", + "confidence": "high", + "justification": "Synthèse en 1 phrase" + }} + }}, + ... +]}} + +Le tableau "scenarios" DOIT contenir exactement {len(codes_batch)} objets, un par code.""" + + return prompt + + +def build_discrimination_prompt(group, fhir_by_code): + """Construire le prompt pour un exercice de discrimination.""" + codes_text = "\n".join( + f"- {c['code']}: {c['extrait']}" + for c in group["codes"] + ) + + prompt = f"""Voici un groupe de codes CIM-10 de la catégorie {group['parent']} : + +{codes_text} + +Génère UN scénario clinique réaliste où le choix entre ces codes est subtil et demande une réflexion. +Puis montre le raisonnement DIM complet pour arriver au bon code. + +Réponds en JSON : +{{ + "scenario_clinique": "Le scénario (5-8 phrases, réaliste, anonymisé)", + "codes_en_jeu": ["code1", "code2"], + "raisonnement": {{ + "analyse_clinique": "...", + "codes_candidats": "...", + "discrimination": "Explication détaillée de pourquoi un code est préféré", + "regle_pmsi": "Règle PMSI applicable", + "code_retenu": "le code correct", + "confidence": "high/medium", + "justification": "..." + }} +}} + +JSON valide uniquement.""" + + return prompt + + +def call_openai(client, model, system_prompt, user_prompt, temperature=0.7): + """Appeler l'API OpenAI.""" + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + temperature=temperature, + max_tokens=4096, + response_format={"type": "json_object"}, + ) + content = response.choices[0].message.content + return json.loads(content) + + +def _extract_items(result): + """Extraire la liste d'items depuis la réponse JSON (gère différents formats).""" + if isinstance(result, list): + return result + if isinstance(result, dict): + # Chercher récursivement un tableau d'items + for v in result.values(): + if isinstance(v, list) and v and isinstance(v[0], dict): + return v + # Si pas de tableau, c'est peut-être un seul item + if "scenario_clinique" in result or "raisonnement" in result: + return [result] + # Dernier recours : aplatir les valeurs dict + items = [] + for v in result.values(): + if isinstance(v, dict) and ("scenario_clinique" in v or "raisonnement" in v): + items.append(v) + if items: + return items + return [] + + +def convert_to_chatml(scenario_data): + """Convertir un résultat de génération en format ChatML.""" + if not isinstance(scenario_data, dict): + return None + + system_msg = "Tu es un médecin DIM expert en codage CIM-10 pour le PMSI français. Tu codes les diagnostics en suivant une démarche structurée." + + scenario = scenario_data.get("scenario_clinique", "") + raisonnement = scenario_data.get("raisonnement", {}) + + # Si pas de scenario_clinique, chercher dans d'autres clés possibles + if not scenario: + scenario = scenario_data.get("scenario", scenario_data.get("texte_clinique", "")) + + # Si le raisonnement est directement dans le dict (pas imbriqué) + if not raisonnement and "analyse_clinique" in scenario_data: + raisonnement = {k: v for k, v in scenario_data.items() if k != "scenario_clinique"} + + if not scenario or not raisonnement: + return None + + user_msg = f"Code ce diagnostic en CIM-10.\n\nTEXTE CLINIQUE : {scenario}" + assistant_msg = json.dumps(raisonnement, ensure_ascii=False) + + return { + "messages": [ + {"role": "system", "content": system_msg}, + {"role": "user", "content": user_msg}, + {"role": "assistant", "content": assistant_msg}, + ] + } + + +def process_scenario_batch(client, model, batch, fhir_by_code, cocoa_entries, batch_idx): + """Traiter un batch de codes pour générer des scénarios.""" + prompt = build_scenario_prompt(batch, fhir_by_code, cocoa_entries) + + try: + result = call_openai(client, model, SYSTEM_PROMPT_SCENARIO, prompt) + + # Le résultat peut être un tableau ou un dict avec une clé contenant le tableau + items = _extract_items(result) + + examples = [] + for item in items: + chatml = convert_to_chatml(item) + if chatml: + examples.append(chatml) + else: + print(f" [Batch {batch_idx}] Item non converti: {list(item.keys()) if isinstance(item, dict) else type(item)}") + + if len(examples) < len(batch): + print(f" [Batch {batch_idx}] {len(examples)}/{len(batch)} exemples récupérés") + + return examples + + except Exception as e: + print(f" [Batch {batch_idx}] Erreur: {e}") + return [] + + +def process_discrimination_batch(client, model, group, fhir_by_code, batch_idx): + """Traiter un groupe pour générer un exercice de discrimination.""" + prompt = build_discrimination_prompt(group, fhir_by_code) + + try: + result = call_openai(client, model, SYSTEM_PROMPT_DISCRIM, prompt) + chatml = convert_to_chatml(result) + return [chatml] if chatml else [] + + except Exception as e: + print(f" [Discrim {batch_idx}] Erreur: {e}") + return [] + + +def main(): + parser = argparse.ArgumentParser(description="Génération de données synthétiques via OpenAI") + parser.add_argument("--n", type=int, default=500, help="Nombre de scénarios à générer") + parser.add_argument("--n-discrim", type=int, default=100, help="Nombre d'exercices de discrimination") + parser.add_argument("--batch", type=int, default=5, help="Codes par batch (scénarios)") + parser.add_argument("--model", default="gpt-4o", help="Modèle OpenAI") + parser.add_argument("--workers", type=int, default=3, help="Workers parallèles") + parser.add_argument("--dry-run", action="store_true", help="Afficher les prompts sans appeler l'API") + parser.add_argument("--resume", action="store_true", help="Reprendre depuis la dernière exécution") + args = parser.parse_args() + + # Vérifier la clé API + api_key = os.environ.get("OPENAI_API_KEY") + if not api_key and not args.dry_run: + print("Erreur: OPENAI_API_KEY non définie.") + print(" export OPENAI_API_KEY='sk-...'") + sys.exit(1) + + # Charger les données + print("Chargement des données...") + metadata = load_faiss_metadata() + fhir_by_code = load_cim10_fhir() + cocoa_entries = load_cocoa_entries() + print(f" FAISS: {len(metadata)} entrées") + print(f" FHIR: {len(fhir_by_code)} concepts") + print(f" CoCoA: {len(cocoa_entries)} entrées") + + # Sélectionner les codes + print(f"\nSélection de {args.n} codes pour les scénarios...") + selected_codes = select_codes_for_scenarios(metadata, n=args.n) + print(f" {len(selected_codes)} codes sélectionnés") + + print(f"Sélection de {args.n_discrim} groupes pour la discrimination...") + discrim_groups = select_discrimination_groups(metadata, cocoa_entries, n=args.n_discrim) + print(f" {len(discrim_groups)} groupes sélectionnés") + + # Découper en batches + scenario_batches = [ + selected_codes[i:i + args.batch] + for i in range(0, len(selected_codes), args.batch) + ] + print(f"\n{len(scenario_batches)} batches de scénarios ({args.batch} codes/batch)") + print(f"{len(discrim_groups)} exercices de discrimination") + + # Fichier de sortie (avec reprise possible) + output_path = OUT / "synthetic_chatml.jsonl" + existing_count = 0 + if args.resume and output_path.exists(): + with open(output_path) as f: + existing_count = sum(1 for _ in f) + print(f"\nReprise: {existing_count} exemples existants") + + if args.dry_run: + # Mode dry-run : montrer des exemples de prompts + print("\n=== DRY RUN ===") + print("\n--- Exemple de prompt scénario ---") + prompt = build_scenario_prompt(scenario_batches[0], fhir_by_code, cocoa_entries) + print(prompt[:2000]) + print("\n--- Exemple de prompt discrimination ---") + if discrim_groups: + prompt = build_discrimination_prompt(discrim_groups[0], fhir_by_code) + print(prompt[:1500]) + return + + # Initialiser le client OpenAI + from openai import OpenAI + client = OpenAI(api_key=api_key) + + all_examples = [] + total_batches = len(scenario_batches) + len(discrim_groups) + completed = 0 + errors = 0 + + # Ouvrir le fichier en mode append + mode = "a" if args.resume else "w" + with open(output_path, mode) as fh: + + # Phase 1 : Scénarios cliniques + print(f"\n{'='*50}") + print(f"Phase 1 : Génération des scénarios cliniques...") + print(f"{'='*50}") + + with ThreadPoolExecutor(max_workers=args.workers) as executor: + futures = {} + for i, batch in enumerate(scenario_batches): + future = executor.submit( + process_scenario_batch, client, args.model, + batch, fhir_by_code, cocoa_entries, i + ) + futures[future] = i + + for future in as_completed(futures): + batch_idx = futures[future] + try: + examples = future.result() + for ex in examples: + fh.write(json.dumps(ex, ensure_ascii=False) + "\n") + all_examples.append(ex) + completed += 1 + if completed % 10 == 0: + print(f" [{completed}/{len(scenario_batches)}] {len(all_examples)} exemples générés...") + except Exception as e: + errors += 1 + print(f" [Batch {batch_idx}] Exception: {e}") + + print(f" Scénarios: {len(all_examples)} exemples") + + # Phase 2 : Discrimination + print(f"\n{'='*50}") + print(f"Phase 2 : Génération des exercices de discrimination...") + print(f"{'='*50}") + + discrim_count = 0 + with ThreadPoolExecutor(max_workers=args.workers) as executor: + futures = {} + for i, group in enumerate(discrim_groups): + future = executor.submit( + process_discrimination_batch, client, args.model, + group, fhir_by_code, i + ) + futures[future] = i + + for future in as_completed(futures): + batch_idx = futures[future] + try: + examples = future.result() + for ex in examples: + fh.write(json.dumps(ex, ensure_ascii=False) + "\n") + all_examples.append(ex) + discrim_count += 1 + except Exception as e: + errors += 1 + print(f" [Discrim {batch_idx}] Exception: {e}") + + print(f" Discrimination: {discrim_count} exemples") + + # Stats finales + total = len(all_examples) + existing_count + print(f"\n{'='*50}") + print(f"Génération terminée !") + print(f" Nouveaux exemples : {len(all_examples)}") + if existing_count: + print(f" Existants (reprise): {existing_count}") + print(f" Total : {total}") + print(f" Erreurs : {errors}") + print(f" Fichier : {output_path}") + if output_path.exists(): + print(f" Taille : {output_path.stat().st_size / 1024:.0f} Ko") + + +if __name__ == "__main__": + main() diff --git a/scripts/07_setup_unsloth.sh b/scripts/07_setup_unsloth.sh new file mode 100644 index 0000000..f210599 --- /dev/null +++ b/scripts/07_setup_unsloth.sh @@ -0,0 +1,67 @@ +#!/bin/bash +# Phase 2 — Installation d'Unsloth et dépendances pour le fine-tuning +# +# Prérequis vérifiés : +# - PyTorch 2.10.0+cu128 ✅ +# - CUDA 12.x ✅ +# - transformers 4.57.6 ✅ +# - accelerate 1.12.0 ✅ +# +# Usage : bash scripts/07_setup_unsloth.sh + +set -e + +VENV="/home/dom/ai/t2a/.venv" +PIP="$VENV/bin/pip" +PYTHON="$VENV/bin/python3" + +echo "=== Installation Unsloth + dépendances ===" +echo "" + +# 1. bitsandbytes (quantification 4-bit) +echo "[1/4] Installation de bitsandbytes..." +$PIP install --upgrade bitsandbytes + +# 2. PEFT (LoRA) +echo "" +echo "[2/4] Installation de PEFT..." +$PIP install --upgrade peft + +# 3. TRL (SFTTrainer) +echo "" +echo "[3/4] Installation de TRL..." +$PIP install --upgrade trl + +# 4. Unsloth +echo "" +echo "[4/4] Installation d'Unsloth..." +$PIP install --upgrade --no-cache-dir "unsloth[cu128-torch2100] @ git+https://github.com/unslothai/unsloth.git" + +# Vérification +echo "" +echo "=== Vérification ===" +$PYTHON -c " +import torch +print(f'PyTorch: {torch.__version__}') +print(f'CUDA: {torch.cuda.is_available()}') +print(f'GPU: {torch.cuda.get_device_name(0)}') + +import bitsandbytes +print(f'bitsandbytes: OK') + +from peft import LoraConfig +print(f'PEFT: OK') + +from trl import SFTTrainer +print(f'TRL: OK') + +try: + from unsloth import FastLanguageModel + print(f'Unsloth: OK') +except Exception as e: + print(f'Unsloth: {e}') + print('Essayez: pip install unsloth') + +print() +print('=== Setup prêt pour le fine-tuning ! ===') +" diff --git a/scripts/08_train_lora.py b/scripts/08_train_lora.py new file mode 100644 index 0000000..6e4919d --- /dev/null +++ b/scripts/08_train_lora.py @@ -0,0 +1,331 @@ +#!/usr/bin/env python3 +""" +Phase 2 — Fine-tuning QLoRA de gemma3:12b avec Unsloth. + +IMPORTANT : Arrêter Ollama avant de lancer ce script ! + sudo systemctl stop ollama + +Le script : + 1. Charge gemma3:12b en 4-bit quantifié + 2. Attache un adaptateur LoRA + 3. Entraîne sur le dataset PMSI (ChatML) + 4. Sauvegarde l'adaptateur LoRA + 5. (Optionnel) Exporte en GGUF pour Ollama + +Prérequis : + - bash scripts/07_setup_unsloth.sh (installer les dépendances) + - data/datasets/pmsi_train.jsonl + pmsi_eval.jsonl + +Usage : + python scripts/08_train_lora.py [--epochs 3] [--lr 2e-4] [--batch 1] [--export-gguf] + +Hardware cible : RTX 5070 (12 Go VRAM) + - QLoRA 4-bit sur 12B ≈ 8-9 Go VRAM + - batch_size=1 + gradient_accumulation=8 → batch effectif de 8 + - gradient_checkpointing pour économiser la VRAM +""" + +import argparse +import json +import os +from pathlib import Path + +BASE = Path(__file__).resolve().parent.parent +DATASETS = BASE / "data" / "datasets" +OUTPUT = BASE / "models" +OUTPUT.mkdir(parents=True, exist_ok=True) + + +def check_prerequisites(): + """Vérifier que tout est prêt.""" + import torch + + # GPU disponible ? + if not torch.cuda.is_available(): + raise RuntimeError("CUDA non disponible. Vérifiez votre installation GPU.") + + gpu_name = torch.cuda.get_device_name(0) + vram_total = torch.cuda.get_device_properties(0).total_memory / 1024**3 + vram_free = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / 1024**3 + + print(f"GPU: {gpu_name}") + print(f"VRAM: {vram_total:.1f} Go total, {vram_free:.1f} Go libre") + + if vram_free < 10: + print("⚠ Moins de 10 Go de VRAM libre.") + print(" → Arrêtez Ollama : sudo systemctl stop ollama") + print(" → Ou utilisez --offload-cpu pour décharger sur la RAM") + + # Dataset existe ? + train_path = DATASETS / "pmsi_train.jsonl" + eval_path = DATASETS / "pmsi_eval.jsonl" + if not train_path.exists() or not eval_path.exists(): + raise FileNotFoundError( + f"Dataset non trouvé. Lancez d'abord : python scripts/04_build_dataset.py" + ) + + # Compter les exemples + with open(train_path) as f: + n_train = sum(1 for _ in f) + with open(eval_path) as f: + n_eval = sum(1 for _ in f) + + print(f"Dataset: {n_train} train + {n_eval} eval") + return train_path, eval_path + + +def load_model(model_name, max_seq_length, load_in_4bit=True): + """Charger le modèle avec Unsloth.""" + from unsloth import FastLanguageModel + + print(f"\nChargement de {model_name} (4-bit={load_in_4bit})...") + + model, tokenizer = FastLanguageModel.from_pretrained( + model_name=model_name, + max_seq_length=max_seq_length, + dtype=None, # Auto-détection + load_in_4bit=load_in_4bit, + ) + + print(f" Modèle chargé : {model.config._name_or_path}") + print(f" Paramètres : {model.num_parameters() / 1e9:.1f}B") + + return model, tokenizer + + +def attach_lora(model, r=32, alpha=64, dropout=0.05): + """Attacher l'adaptateur LoRA.""" + from unsloth import FastLanguageModel + + print(f"\nAttachement LoRA (r={r}, alpha={alpha}, dropout={dropout})...") + + model = FastLanguageModel.get_peft_model( + model, + r=r, + target_modules=[ + "q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj", + ], + lora_alpha=alpha, + lora_dropout=dropout, + bias="none", + use_gradient_checkpointing="unsloth", # Économise 30% VRAM + random_state=42, + ) + + trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + total = sum(p.numel() for p in model.parameters()) + print(f" Paramètres entraînables : {trainable / 1e6:.1f}M / {total / 1e9:.1f}B ({100 * trainable / total:.2f}%)") + + return model + + +def load_dataset(train_path, eval_path): + """Charger le dataset au format ChatML.""" + from datasets import Dataset + + def load_jsonl(path): + examples = [] + with open(path) as f: + for line in f: + examples.append(json.loads(line.strip())) + return examples + + train_data = load_jsonl(train_path) + eval_data = load_jsonl(eval_path) + + train_ds = Dataset.from_list(train_data) + eval_ds = Dataset.from_list(eval_data) + + print(f"\nDataset chargé :") + print(f" Train : {len(train_ds)} exemples") + print(f" Eval : {len(eval_ds)} exemples") + + return train_ds, eval_ds + + +def format_chat(example, tokenizer): + """Formater un exemple ChatML pour l'entraînement.""" + messages = example["messages"] + # Utiliser le template de chat du tokenizer + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=False, + ) + return {"text": text} + + +def train(model, tokenizer, train_ds, eval_ds, args): + """Lancer l'entraînement.""" + from trl import SFTTrainer, SFTConfig + from aim.hugging_face import AimCallback + + print(f"\nConfiguration d'entraînement :") + print(f" Epochs : {args.epochs}") + print(f" Learning rate : {args.lr}") + print(f" Batch size : {args.batch} (gradient_accumulation={args.grad_accum})") + print(f" Batch effectif : {args.batch * args.grad_accum}") + print(f" Max seq length : {args.max_seq_length}") + + # Formater le dataset + train_ds = train_ds.map(lambda x: format_chat(x, tokenizer), num_proc=4) + eval_ds = eval_ds.map(lambda x: format_chat(x, tokenizer), num_proc=4) + + output_dir = OUTPUT / "pmsi-lora-checkpoints" + + # Callback Aim pour le tracking des métriques + aim_callback = AimCallback( + repo=str(BASE), + experiment="pmsi-coder-v2", + ) + + training_args = SFTConfig( + output_dir=str(output_dir), + num_train_epochs=args.epochs, + per_device_train_batch_size=args.batch, + per_device_eval_batch_size=args.batch, + gradient_accumulation_steps=args.grad_accum, + learning_rate=args.lr, + weight_decay=0.01, + warmup_ratio=0.05, + lr_scheduler_type="cosine", + logging_steps=10, + eval_strategy="steps", + eval_steps=1000, + save_strategy="steps", + save_steps=500, + save_total_limit=3, + fp16=False, + bf16=True, + max_seq_length=args.max_seq_length, + dataset_text_field="text", + seed=42, + report_to="none", + ) + + trainer = SFTTrainer( + model=model, + tokenizer=tokenizer, + train_dataset=train_ds, + eval_dataset=eval_ds, + args=training_args, + callbacks=[aim_callback], + ) + + print(f"\nDémarrage de l'entraînement...") + print(f" Output : {output_dir}") + print(f" Steps estimés : ~{len(train_ds) * args.epochs // (args.batch * args.grad_accum)}") + + if args.resume: + print(f" Reprise depuis le dernier checkpoint...") + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + + # Sauvegarder le modèle final + final_dir = OUTPUT / "pmsi-lora-final" + model.save_pretrained(str(final_dir)) + tokenizer.save_pretrained(str(final_dir)) + print(f"\nModèle LoRA sauvegardé : {final_dir}") + + return trainer, final_dir + + +def export_gguf(model, tokenizer, final_dir, quantization="q4_k_m"): + """Exporter en GGUF pour Ollama.""" + print(f"\nExport GGUF ({quantization})...") + + gguf_dir = OUTPUT / "pmsi-gguf" + gguf_dir.mkdir(parents=True, exist_ok=True) + + # Unsloth export + model.save_pretrained_gguf( + str(gguf_dir), + tokenizer, + quantization_method=quantization, + ) + + # Trouver le fichier GGUF + gguf_files = list(gguf_dir.glob("*.gguf")) + if gguf_files: + gguf_path = gguf_files[0] + print(f" GGUF exporté : {gguf_path} ({gguf_path.stat().st_size / 1024**3:.1f} Go)") + + # Créer le Modelfile pour Ollama + modelfile_path = gguf_dir / "Modelfile" + modelfile_content = f"""FROM {gguf_path.name} + +PARAMETER temperature 0.3 +PARAMETER top_p 0.9 +PARAMETER num_ctx 8192 +""" + with open(modelfile_path, "w") as f: + f.write(modelfile_content) + + print(f" Modelfile créé : {modelfile_path}") + print(f"\n Pour importer dans Ollama :") + print(f" cd {gguf_dir}") + print(f" ollama create pmsi-coder -f Modelfile") + else: + print(" Aucun fichier GGUF trouvé !") + + +def main(): + parser = argparse.ArgumentParser(description="Fine-tuning QLoRA avec Unsloth") + + # Modèle + parser.add_argument("--model", default="unsloth/gemma-3-12b-it-bnb-4bit", + help="Nom du modèle HuggingFace") + parser.add_argument("--max-seq-length", type=int, default=512, + help="Longueur max des séquences") + + # LoRA + parser.add_argument("--lora-r", type=int, default=32, help="Rang LoRA") + parser.add_argument("--lora-alpha", type=int, default=64, help="Alpha LoRA") + parser.add_argument("--lora-dropout", type=float, default=0.0, help="Dropout LoRA (0=fast patching Unsloth)") + + # Entraînement + parser.add_argument("--epochs", type=int, default=3, help="Nombre d'epochs") + parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate") + parser.add_argument("--batch", type=int, default=1, help="Batch size par GPU") + parser.add_argument("--grad-accum", type=int, default=8, help="Gradient accumulation steps") + + # Resume + parser.add_argument("--resume", action="store_true", help="Reprendre depuis le dernier checkpoint") + + # Export + parser.add_argument("--export-gguf", action="store_true", help="Exporter en GGUF après entraînement") + parser.add_argument("--gguf-quant", default="q4_k_m", help="Méthode de quantification GGUF") + + args = parser.parse_args() + + # Vérifications + train_path, eval_path = check_prerequisites() + + # Charger le modèle + model, tokenizer = load_model(args.model, args.max_seq_length) + + # Attacher LoRA + model = attach_lora(model, r=args.lora_r, alpha=args.lora_alpha, dropout=args.lora_dropout) + + # Charger le dataset + train_ds, eval_ds = load_dataset(train_path, eval_path) + + # Entraîner + trainer, final_dir = train(model, tokenizer, train_ds, eval_ds, args) + + # Export GGUF optionnel + if args.export_gguf: + export_gguf(model, tokenizer, final_dir, args.gguf_quant) + + print("\n" + "=" * 50) + print("Fine-tuning terminé !") + print(f" Adaptateur LoRA : {final_dir}") + if args.export_gguf: + print(f" GGUF : {OUTPUT / 'pmsi-gguf'}") + print("=" * 50) + + +if __name__ == "__main__": + main() diff --git a/scripts/09_build_embedding_triplets.py b/scripts/09_build_embedding_triplets.py new file mode 100644 index 0000000..ec18957 --- /dev/null +++ b/scripts/09_build_embedding_triplets.py @@ -0,0 +1,460 @@ +#!/usr/bin/env python3 +""" +Phase 3 — Génération de triplets (anchor, positive, negative) pour fine-tuning d'embedding. + +Sources : + 1. Cache Ollama (1840 paires diagnostic → code CIM-10) + 2. Métadonnées FAISS (12444 entrées CIM-10 avec extraits) + 3. CoCoA (synonymes, exclusions) + 4. Données CIM-10 FHIR (descriptions → codes) + +Format de sortie : JSONL avec {anchor, positive, negative} + - anchor : texte clinique / diagnostic tel qu'écrit par le médecin + - positive : extrait de référence CIM-10 correspondant au bon code + - negative : extrait de référence CIM-10 d'un code similaire mais incorrect (hard negative) + +Usage : + python scripts/09_build_embedding_triplets.py [--output data/datasets/triplets.jsonl] +""" + +import argparse +import json +import random +from collections import defaultdict +from pathlib import Path + +random.seed(42) + +BASE = Path(__file__).resolve().parent.parent +T2A = BASE.parent / "t2a" + +# Sources +CACHE_PATH = T2A / "data" / "ollama_cache.json" +METADATA_PATH = T2A / "data" / "rag_index" / "metadata.json" +COCOA_PATH = BASE / "data" / "processed" / "cocoa_entries_debug.json" +CIM10_CHATML = BASE / "data" / "processed" / "cim10_chatml.jsonl" + +# Output +DEFAULT_OUTPUT = BASE / "data" / "datasets" / "triplets.jsonl" + + +def load_faiss_metadata() -> tuple[dict[str, list[dict]], list[dict]]: + """Charge les métadonnées FAISS et indexe par code CIM-10. + + Returns: + (code_to_entries, all_cim10_entries) + """ + with open(METADATA_PATH) as f: + metadata = json.load(f) + + code_to_entries = defaultdict(list) + all_cim10 = [] + + for m in metadata: + if m.get("document") not in ("cim10", "cim10_alpha"): + continue + code = m.get("code", "").strip() + if not code: + continue + all_cim10.append(m) + # Normaliser le code (avec et sans point) + code_to_entries[code].append(m) + # Aussi indexer par préfixe 3 caractères pour les hard negatives + # Ex: I10 → chapitre I, catégorie I10 + + print(f"FAISS metadata : {len(all_cim10)} entrées CIM-10, {len(code_to_entries)} codes uniques") + return code_to_entries, all_cim10 + + +def get_chapter(code: str) -> str: + """Extrait le chapitre CIM-10 (première lettre).""" + return code[0] if code else "" + + +def get_category(code: str) -> str: + """Extrait la catégorie CIM-10 (3 premiers caractères, ex: I10, K80).""" + clean = code.replace(".", "") + return clean[:3] if len(clean) >= 3 else code + + +def find_hard_negatives( + code: str, + code_to_entries: dict[str, list[dict]], + all_cim10: list[dict], + n: int = 3, +) -> list[dict]: + """Trouve des hard negatives : codes proches mais différents. + + Stratégie par priorité : + 1. Même catégorie (ex: K80.0 vs K80.2) — le plus dur + 2. Même chapitre (ex: K80 vs K81) — difficile + 3. Random d'un autre chapitre — facile (pour contraste) + """ + category = get_category(code) + chapter = get_chapter(code) + negatives = [] + + # 1. Même catégorie (siblings) + siblings = [] + for c, entries in code_to_entries.items(): + if c != code and get_category(c) == category: + siblings.extend(entries) + if siblings: + random.shuffle(siblings) + negatives.extend(siblings[:n]) + + # 2. Même chapitre si pas assez + if len(negatives) < n: + chapter_entries = [] + for c, entries in code_to_entries.items(): + if c != code and get_chapter(c) == chapter and get_category(c) != category: + chapter_entries.extend(entries) + if chapter_entries: + random.shuffle(chapter_entries) + negatives.extend(chapter_entries[: n - len(negatives)]) + + # 3. Random si toujours pas assez + if len(negatives) < n: + others = [m for m in all_cim10 if get_category(m.get("code", "")) != category] + if others: + random.shuffle(others) + negatives.extend(others[: n - len(negatives)]) + + return negatives[:n] + + +def format_entry_text(entry: dict) -> str: + """Formate une entrée FAISS en texte lisible pour l'embedding.""" + code = entry.get("code", "") + extrait = entry.get("extrait", "").strip() + if code and extrait: + # S'assurer que le code est dans le texte + if not extrait.startswith(code): + return f"{code} {extrait}" + return extrait + + +def triplets_from_cache( + code_to_entries: dict[str, list[dict]], + all_cim10: list[dict], +) -> list[dict]: + """Génère des triplets à partir du cache Ollama.""" + with open(CACHE_PATH) as f: + cache = json.load(f) + + entries = cache.get("entries", cache) + if isinstance(entries, str): + return [] + + triplets = [] + skipped = {"no_code": 0, "no_match": 0, "no_neg": 0} + + for key, val in entries.items(): + if not isinstance(val, dict): + continue + + code = val.get("code", "") + if not code: + skipped["no_code"] += 1 + continue + + # Extraire le texte du diagnostic depuis la clé + # Format: "das::hypertension artérielle" ou "dp::pancréatite aiguë" + parts = key.split("::", 1) + if len(parts) != 2: + continue + diag_type, diag_text = parts + if diag_type == "das_llm": + continue # Format différent, skip + + # Normaliser le code + code_clean = code.strip().upper() + + # Trouver le positive dans FAISS + positive_entries = code_to_entries.get(code_clean, []) + if not positive_entries: + # Essayer sans le point + code_no_dot = code_clean.replace(".", "") + for c, ents in code_to_entries.items(): + if c.replace(".", "") == code_no_dot: + positive_entries = ents + break + if not positive_entries: + skipped["no_match"] += 1 + continue + + positive = random.choice(positive_entries) + positive_text = format_entry_text(positive) + + # Hard negatives + negs = find_hard_negatives(code_clean, code_to_entries, all_cim10, n=2) + if not negs: + skipped["no_neg"] += 1 + continue + + for neg in negs: + neg_text = format_entry_text(neg) + if neg_text and positive_text and neg_text != positive_text: + triplets.append({ + "anchor": diag_text, + "positive": positive_text, + "negative": neg_text, + "source": "cache_ollama", + "code_pos": code_clean, + "code_neg": neg.get("code", ""), + }) + + print(f"Cache Ollama → {len(triplets)} triplets " + f"(skip: {skipped['no_code']} sans code, {skipped['no_match']} pas dans FAISS, " + f"{skipped['no_neg']} sans négatif)") + return triplets + + +def triplets_from_cocoa( + code_to_entries: dict[str, list[dict]], + all_cim10: list[dict], +) -> list[dict]: + """Génère des triplets à partir de CoCoA (synonymes + exclusions).""" + if not COCOA_PATH.exists(): + print("CoCoA debug file non trouvé, skip") + return [] + + with open(COCOA_PATH) as f: + cocoa = json.load(f) + + # Le fichier debug est un dict {code: entry}, pas une liste + if isinstance(cocoa, dict): + cocoa = list(cocoa.values()) + + triplets = [] + + for entry in cocoa: + code = entry.get("code", "").strip() + desc = entry.get("description", "").strip() + synonymes = entry.get("synonyms", []) + exclusions = entry.get("exclusions", []) + + if not code or not desc: + continue + + # Positive = entrée FAISS pour ce code + positive_entries = code_to_entries.get(code, []) + if not positive_entries: + code_no_dot = code.replace(".", "") + for c, ents in code_to_entries.items(): + if c.replace(".", "") == code_no_dot: + positive_entries = ents + break + if not positive_entries: + continue + + positive = random.choice(positive_entries) + positive_text = format_entry_text(positive) + + # Synonymes comme anchors supplémentaires + anchors = [desc] + [s for s in synonymes if len(s) > 5] + + # Exclusions comme hard negatives explicites + explicit_negs = [] + for excl in exclusions: + # Essayer d'extraire un code de l'exclusion (ex: "Diabète de type 2 (E11)") + import re + code_match = re.search(r"\(([A-Z]\d{2}(?:\.\d{1,2})?)\)", excl) + if code_match: + neg_code = code_match.group(1) + neg_entries = code_to_entries.get(neg_code, []) + if neg_entries: + explicit_negs.append(random.choice(neg_entries)) + + # Hard negatives auto si pas assez d'explicites + auto_negs = find_hard_negatives(code, code_to_entries, all_cim10, n=2) + all_negs = explicit_negs + auto_negs + + if not all_negs: + continue + + for anchor in anchors[:3]: # Max 3 anchors par code + for neg in all_negs[:2]: # Max 2 négatifs par anchor + neg_text = format_entry_text(neg) + if neg_text and positive_text and neg_text != positive_text: + triplets.append({ + "anchor": anchor, + "positive": positive_text, + "negative": neg_text, + "source": "cocoa", + "code_pos": code, + "code_neg": neg.get("code", ""), + }) + + print(f"CoCoA → {len(triplets)} triplets") + return triplets + + +def triplets_from_cim10_fhir( + code_to_entries: dict[str, list[dict]], + all_cim10: list[dict], +) -> list[dict]: + """Génère des triplets à partir du CIM-10 FHIR ChatML (description → code).""" + if not CIM10_CHATML.exists(): + print("CIM-10 ChatML non trouvé, skip") + return [] + + triplets = [] + seen_codes = set() + + with open(CIM10_CHATML) as f: + for line in f: + example = json.loads(line.strip()) + messages = example.get("messages", []) + + # Extraire la question (user) et la réponse (assistant) + user_msg = "" + assistant_msg = "" + for m in messages: + if m["role"] == "user": + user_msg = m["content"] + elif m["role"] == "assistant": + assistant_msg = m["content"] + + if not user_msg or not assistant_msg: + continue + + # Parser le code de la réponse + try: + resp = json.loads(assistant_msg) + code = resp.get("code", "") + except json.JSONDecodeError: + continue + + if not code or code in seen_codes: + continue + seen_codes.add(code) + + # Extraire le texte du diagnostic depuis la question + # Format: "Quel est le code CIM-10 pour : Description" + # ou "Codage CIM-10 du diagnostic : Description" + import re + match = re.search(r"(?:pour|diagnostic)\s*:\s*(.+)", user_msg) + if not match: + continue + anchor = match.group(1).strip() + if len(anchor) < 5: + continue + + # Positive + positive_entries = code_to_entries.get(code, []) + if not positive_entries: + continue + positive = random.choice(positive_entries) + positive_text = format_entry_text(positive) + + # Hard negative + negs = find_hard_negatives(code, code_to_entries, all_cim10, n=1) + if not negs: + continue + + neg = negs[0] + neg_text = format_entry_text(neg) + if neg_text and positive_text and neg_text != positive_text: + triplets.append({ + "anchor": anchor, + "positive": positive_text, + "negative": neg_text, + "source": "cim10_fhir", + "code_pos": code, + "code_neg": neg.get("code", ""), + }) + + print(f"CIM-10 FHIR → {len(triplets)} triplets") + return triplets + + +def deduplicate(triplets: list[dict]) -> list[dict]: + """Déduplique par (anchor, positive, negative).""" + seen = set() + deduped = [] + for t in triplets: + key = (t["anchor"][:50], t["code_pos"], t["code_neg"]) + if key not in seen: + seen.add(key) + deduped.append(t) + return deduped + + +def analyze_triplets(triplets: list[dict]) -> None: + """Statistiques sur les triplets générés.""" + sources = defaultdict(int) + hard_neg_types = {"same_category": 0, "same_chapter": 0, "diff_chapter": 0} + + for t in triplets: + sources[t["source"]] += 1 + code_pos = t["code_pos"] + code_neg = t["code_neg"] + if get_category(code_pos) == get_category(code_neg): + hard_neg_types["same_category"] += 1 + elif get_chapter(code_pos) == get_chapter(code_neg): + hard_neg_types["same_chapter"] += 1 + else: + hard_neg_types["diff_chapter"] += 1 + + print(f"\n=== Statistiques ===") + print(f"Total triplets : {len(triplets)}") + print(f"Par source :") + for src, count in sorted(sources.items()): + print(f" {src}: {count}") + print(f"Difficulté des négatifs :") + for neg_type, count in hard_neg_types.items(): + pct = 100 * count / len(triplets) if triplets else 0 + print(f" {neg_type}: {count} ({pct:.1f}%)") + + +def main(): + parser = argparse.ArgumentParser(description="Génération de triplets pour embedding fine-tuning") + parser.add_argument("--output", type=Path, default=DEFAULT_OUTPUT, + help="Fichier de sortie JSONL") + args = parser.parse_args() + + print("=== Génération de triplets pour embedding fine-tuning ===\n") + + # Charger FAISS metadata + code_to_entries, all_cim10 = load_faiss_metadata() + + # Générer les triplets depuis chaque source + all_triplets = [] + + all_triplets.extend(triplets_from_cache(code_to_entries, all_cim10)) + all_triplets.extend(triplets_from_cocoa(code_to_entries, all_cim10)) + all_triplets.extend(triplets_from_cim10_fhir(code_to_entries, all_cim10)) + + # Dédupliquer + all_triplets = deduplicate(all_triplets) + random.shuffle(all_triplets) + + # Statistiques + analyze_triplets(all_triplets) + + # Sauvegarder + args.output.parent.mkdir(parents=True, exist_ok=True) + with open(args.output, "w") as f: + for t in all_triplets: + # Format simplifié pour l'entraînement : anchor, positive, negative + out = { + "anchor": t["anchor"], + "positive": t["positive"], + "negative": t["negative"], + } + f.write(json.dumps(out, ensure_ascii=False) + "\n") + + size_mo = args.output.stat().st_size / 1024 / 1024 + print(f"\nSauvegardé : {args.output} ({len(all_triplets)} triplets, {size_mo:.1f} Mo)") + + # Aussi sauvegarder version détaillée pour debug + debug_path = args.output.with_suffix(".debug.jsonl") + with open(debug_path, "w") as f: + for t in all_triplets[:50]: + f.write(json.dumps(t, ensure_ascii=False, indent=2) + "\n---\n") + print(f"Debug (50 premiers) : {debug_path}") + + +if __name__ == "__main__": + main() diff --git a/scripts/10_train_embedding.py b/scripts/10_train_embedding.py new file mode 100644 index 0000000..98e95ac --- /dev/null +++ b/scripts/10_train_embedding.py @@ -0,0 +1,263 @@ +#!/usr/bin/env python3 +""" +Phase 3 — Fine-tuning du modèle d'embedding pour le RAG médical. + +Fine-tune sentence-camembert-large sur les triplets PMSI/CIM-10 +pour améliorer la recherche FAISS dans le pipeline T2A. + +Prérequis : + - python scripts/09_build_embedding_triplets.py + - data/datasets/triplets.jsonl (41K+ triplets) + +Usage : + python scripts/10_train_embedding.py [--epochs 3] [--batch 32] [--lr 2e-5] [--device cpu] + + --device cpu : entraîner sur CPU (si le GPU est occupé par le LoRA) + --device cuda : entraîner sur GPU (plus rapide, ~30 min) + +Hardware : + - CPU : ~2-3h (Ryzen 9 9950X) + - GPU : ~30 min (RTX 5070, si disponible) + - RAM : ~4 Go (modèle ~1.3 Go + données) +""" + +import argparse +import json +import random +from pathlib import Path + +random.seed(42) + +BASE = Path(__file__).resolve().parent.parent +TRIPLETS_PATH = BASE / "data" / "datasets" / "triplets.jsonl" +OUTPUT_DIR = BASE / "models" / "sentence-camembert-pmsi" + +# Modèle de base (même que dans le pipeline T2A) +BASE_MODEL = "dangvantuan/sentence-camembert-large" + + +def load_triplets(path: Path, eval_ratio: float = 0.1): + """Charge les triplets et split train/eval.""" + from datasets import Dataset + + anchors, positives, negatives = [], [], [] + + with open(path) as f: + for line in f: + t = json.loads(line.strip()) + anchors.append(t["anchor"]) + positives.append(t["positive"]) + negatives.append(t["negative"]) + + total = len(anchors) + print(f"Triplets chargés : {total}") + + # Shuffle déterministe + indices = list(range(total)) + random.shuffle(indices) + + split = int(total * (1 - eval_ratio)) + train_idx = indices[:split] + eval_idx = indices[split:] + + train_ds = Dataset.from_dict({ + "anchor": [anchors[i] for i in train_idx], + "positive": [positives[i] for i in train_idx], + "negative": [negatives[i] for i in train_idx], + }) + eval_ds = Dataset.from_dict({ + "anchor": [anchors[i] for i in eval_idx], + "positive": [positives[i] for i in eval_idx], + "negative": [negatives[i] for i in eval_idx], + }) + + print(f" Train : {len(train_ds)}") + print(f" Eval : {len(eval_ds)}") + + return train_ds, eval_ds + + +def main(): + parser = argparse.ArgumentParser(description="Fine-tuning embedding PMSI") + + parser.add_argument("--model", default=BASE_MODEL, + help="Modèle de base sentence-transformers") + parser.add_argument("--triplets", type=Path, default=TRIPLETS_PATH, + help="Fichier de triplets JSONL") + parser.add_argument("--output", type=Path, default=OUTPUT_DIR, + help="Répertoire de sortie") + + # Entraînement + parser.add_argument("--epochs", type=int, default=3, help="Nombre d'epochs") + parser.add_argument("--batch", type=int, default=32, help="Batch size") + parser.add_argument("--lr", type=float, default=2e-5, help="Learning rate") + parser.add_argument("--warmup-ratio", type=float, default=0.1, help="Warmup ratio") + parser.add_argument("--max-seq-length", type=int, default=256, + help="Longueur max des séquences") + + # Loss + parser.add_argument("--loss", choices=["mnrl", "triplet"], default="mnrl", + help="Fonction de loss (mnrl=MultipleNegativesRankingLoss, triplet=TripletLoss)") + + # Device + parser.add_argument("--device", default=None, + help="Device (cuda/cpu). Auto-détection si non spécifié.") + + args = parser.parse_args() + + print("=" * 60) + print("Fine-tuning embedding pour RAG PMSI") + print("=" * 60) + + # Vérifier les triplets + if not args.triplets.exists(): + raise FileNotFoundError( + f"Triplets non trouvés : {args.triplets}\n" + f"Lancez d'abord : python scripts/09_build_embedding_triplets.py" + ) + + # Device + import torch + if args.device is None: + if torch.cuda.is_available(): + # Vérifier si le GPU a assez de VRAM libre (~2 Go suffisent) + free = torch.cuda.mem_get_info()[0] / 1024**3 + if free > 2.0: + args.device = "cuda" + print(f"GPU détecté : {torch.cuda.get_device_name(0)} ({free:.1f} Go libre)") + else: + args.device = "cpu" + print(f"GPU occupé ({free:.1f} Go libre) — fallback CPU") + else: + args.device = "cpu" + print("Pas de GPU — entraînement CPU") + + # Charger le modèle + from sentence_transformers import SentenceTransformer + + print(f"\nChargement du modèle : {args.model}") + model = SentenceTransformer(args.model, device=args.device) + model.max_seq_length = args.max_seq_length + print(f" Max seq length : {model.max_seq_length}") + print(f" Dimension embedding : {model.get_sentence_embedding_dimension()}") + + # Charger les données + print() + train_ds, eval_ds = load_triplets(args.triplets) + + # Loss + if args.loss == "mnrl": + from sentence_transformers.losses import MultipleNegativesRankingLoss + loss = MultipleNegativesRankingLoss(model) + print(f"\nLoss : MultipleNegativesRankingLoss (in-batch negatives + hard negatives)") + else: + from sentence_transformers.losses import TripletLoss + loss = TripletLoss(model) + print(f"\nLoss : TripletLoss") + + # Evaluator + from sentence_transformers.evaluation import TripletEvaluator + + evaluator = TripletEvaluator( + anchors=eval_ds["anchor"], + positives=eval_ds["positive"], + negatives=eval_ds["negative"], + name="pmsi-eval", + batch_size=args.batch, + ) + + # Évaluation baseline (avant fine-tuning) + print("\nÉvaluation baseline (avant fine-tuning)...") + baseline = evaluator(model) + print(f" Accuracy baseline : {baseline.get('pmsi-eval_cosine_accuracy', 'N/A')}") + + # Training args + from sentence_transformers import SentenceTransformerTrainingArguments + + n_steps = len(train_ds) * args.epochs // args.batch + eval_steps = max(n_steps // 10, 50) # Évaluer ~10 fois pendant l'entraînement + save_steps = max(n_steps // 5, 100) # Sauvegarder ~5 fois + + training_args = SentenceTransformerTrainingArguments( + output_dir=str(args.output / "checkpoints"), + num_train_epochs=args.epochs, + per_device_train_batch_size=args.batch, + per_device_eval_batch_size=args.batch, + learning_rate=args.lr, + warmup_ratio=args.warmup_ratio, + lr_scheduler_type="cosine", + fp16=(args.device == "cuda"), + bf16=False, + logging_steps=50, + eval_strategy="steps", + eval_steps=eval_steps, + save_strategy="steps", + save_steps=save_steps, + save_total_limit=3, + load_best_model_at_end=True, + metric_for_best_model="pmsi-eval_cosine_accuracy", + greater_is_better=True, + seed=42, + report_to="aim", + ) + + print(f"\nConfiguration :") + print(f" Epochs : {args.epochs}") + print(f" Batch size : {args.batch}") + print(f" Learning rate : {args.lr}") + print(f" Steps estimés : ~{n_steps}") + print(f" Eval tous les : {eval_steps} steps") + print(f" Device : {args.device}") + + # Trainer + from sentence_transformers import SentenceTransformerTrainer + + trainer = SentenceTransformerTrainer( + model=model, + args=training_args, + train_dataset=train_ds, + eval_dataset=eval_ds, + loss=loss, + evaluator=evaluator, + ) + + # Entraîner + print(f"\nDémarrage de l'entraînement...") + trainer.train() + + # Sauvegarder le modèle final + final_dir = args.output / "final" + model.save_pretrained(str(final_dir)) + print(f"\nModèle sauvegardé : {final_dir}") + + # Évaluation finale + print("\nÉvaluation finale (après fine-tuning)...") + final_results = evaluator(model) + final_acc = final_results.get("pmsi-eval_cosine_accuracy", 0) + baseline_acc = baseline.get("pmsi-eval_cosine_accuracy", 0) + + print(f"\n{'=' * 60}") + print(f"Résultats :") + print(f" Accuracy baseline : {baseline_acc:.4f}") + print(f" Accuracy finale : {final_acc:.4f}") + if baseline_acc > 0: + delta = final_acc - baseline_acc + pct = 100 * delta / baseline_acc + print(f" Amélioration : {delta:+.4f} ({pct:+.1f}%)") + print(f"\nModèle final : {final_dir}") + print(f"{'=' * 60}") + + # Instructions d'intégration + print(f""" +Pour utiliser dans le pipeline T2A, modifier src/medical/rag_search.py : + + _embed_model = SentenceTransformer("{final_dir}", device=_device) + +Puis reconstruire l'index FAISS avec le nouveau modèle : + + python -m src.medical.rag_index --rebuild +""") + + +if __name__ == "__main__": + main() diff --git a/scripts/11_parse_referentiels.py b/scripts/11_parse_referentiels.py new file mode 100644 index 0000000..4028c20 --- /dev/null +++ b/scripts/11_parse_referentiels.py @@ -0,0 +1,477 @@ +#!/usr/bin/env python3 +""" +Parse les référentiels ATIH/PMSI pour générer des exemples ChatML. + +Sources : + 1. Annexe-4 CMA V11e — niveaux de sévérité officiels (tabulaire) + 2. Racines GHM V11e — caractéristiques des racines (tabulaire) + 3. Fascicules de codage — règles par spécialité (texte libre) + 4. Instruction DGOS contrôle T2A — priorités de contrôle + +Déduplique automatiquement avec les données CoCoA existantes (sévérité CMA). + +Usage : + python scripts/11_parse_referentiels.py +""" + +import json +import re +from pathlib import Path + +import pdfplumber + +BASE = Path(__file__).resolve().parent.parent +T2A = BASE.parent / "t2a" +REF_DIR = T2A / "data" / "referentiels" +OUTPUT = BASE / "data" / "processed" +OUTPUT.mkdir(parents=True, exist_ok=True) + +SYSTEM_PROMPT = ( + "Tu es un médecin DIM expert en codage CIM-10 pour le PMSI français. " + "Tu t'appuies sur les référentiels ATIH officiels." +) + +SYSTEM_PROMPT_GHM = ( + "Tu es un médecin DIM expert en groupage GHM/GHS pour le PMSI français. " + "Tu connais les règles de classification des GHM version 11e." +) + +SYSTEM_PROMPT_CONTROLE = ( + "Tu es un médecin DIM expert en contrôle T2A. " + "Tu connais les instructions DGOS et les règles de contrôle externe." +) + + +def _chatml(system: str, user: str, assistant: str, source: str = "") -> dict: + d = { + "messages": [ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + {"role": "assistant", "content": assistant}, + ] + } + if source: + d["source"] = source + return d + + +# ─── 1. Annexe-4 CMA ──────────────────────────────────────────────────────── + +def parse_annexe4_cma() -> list[dict]: + """Parse l'Annexe-4 : Diagnostics classés CMA avec niveaux de sévérité.""" + pdf_path = list(REF_DIR.glob("*Annexe-4*CMA*.pdf")) + if not pdf_path: + print(" Annexe-4 CMA non trouvée, skip") + return [] + pdf_path = pdf_path[0] + + # Charger les sévérités CoCoA existantes pour dédupliquer + cocoa_path = BASE / "data" / "processed" / "cocoa_entries_debug.json" + cocoa_severities = set() + if cocoa_path.exists(): + with open(cocoa_path) as f: + cocoa = json.load(f) + for code, entry in cocoa.items(): + if entry.get("severity"): + cocoa_severities.add(code) + + entries = [] + # Pattern: code (A00.0) suivi d'un niveau (2-4) suivi d'un libellé + pattern = re.compile(r"^([A-Z]\d{2}(?:\.\d{1,2})?)\s+([234])\s+(.+)$") + + with pdfplumber.open(str(pdf_path)) as pdf: + for page in pdf.pages: + text = page.extract_text() or "" + for line in text.split("\n"): + line = line.strip() + m = pattern.match(line) + if m: + code, niveau, libelle = m.group(1), int(m.group(2)), m.group(3).strip() + entries.append({"code": code, "niveau": niveau, "libelle": libelle}) + + print(f" Annexe-4 CMA : {len(entries)} entrées extraites") + + # Dédupliquer avec CoCoA + new_entries = [e for e in entries if e["code"] not in cocoa_severities] + dupes = len(entries) - len(new_entries) + print(f" Doublons CoCoA : {dupes}, nouvelles : {len(new_entries)}") + + # Générer les exemples ChatML + examples = [] + + # Type 1 : Quel niveau CMA ? + for e in entries: # Garder tous pour renforcer, même les doublons + examples.append(_chatml( + SYSTEM_PROMPT, + f"Quel est le niveau de sévérité CMA du code {e['code']} ({e['libelle']}) ?", + f"Le code {e['code']} ({e['libelle']}) est classé CMA de niveau {e['niveau']}. " + f"{'Ce diagnostic est considéré comme une complication ou morbidité associée majeure.' if e['niveau'] >= 3 else 'Ce diagnostic est une complication ou morbidité associée significative.'}", + source="annexe4_cma" + )) + + # Type 2 : Est-ce une CMA ? (discrimination) + # Quelques exemples de codes NON-CMA pour contraste + non_cma_codes = set() + cma_codes = {e["code"] for e in entries} + for e in entries: + # Codes voisins qui ne sont pas CMA + base = e["code"][:3] + for suffix in range(10): + candidate = f"{base}.{suffix}" + if candidate not in cma_codes and candidate not in non_cma_codes: + non_cma_codes.add(candidate) + if len(non_cma_codes) >= 500: + break + + return examples + + +# ─── 2. Racines GHM ────────────────────────────────────────────────────────── + +def parse_racines_ghm() -> list[dict]: + """Parse les Racines GHM V11e (tableau de caractéristiques).""" + pdf_path = list(REF_DIR.glob("*Racines_GHM*.pdf")) + if not pdf_path: + print(" Racines GHM non trouvé, skip") + return [] + pdf_path = pdf_path[0] + + entries = [] + # Pattern racine GHM : 2 chiffres + lettre + 2 chiffres (ex: 01C02, 05K06) + pattern = re.compile(r"^(\d{2}[A-Z]\d{2})\s+(.+)$") + + with pdfplumber.open(str(pdf_path)) as pdf: + for page in pdf.pages[3:]: # Skip pages de couverture + text = page.extract_text() or "" + for line in text.split("\n"): + line = line.strip() + m = pattern.match(line) + if m: + racine = m.group(1) + reste = m.group(2).strip() + entries.append({"racine": racine, "description": reste}) + + print(f" Racines GHM : {len(entries)} racines extraites") + + # Aussi extraire les tables complètes par page pour du contexte riche + examples = [] + with pdfplumber.open(str(pdf_path)) as pdf: + # Extraire les tables + for page in pdf.pages[3:]: + tables = page.extract_tables() + for table in tables: + if not table or len(table) < 2: + continue + headers = table[0] if table[0] else [] + for row in table[1:]: + if not row or not row[0]: + continue + racine = str(row[0]).strip() + if not re.match(r"^\d{2}[A-Z]\d{2}$", racine): + continue + # Construire la description depuis les colonnes + desc_parts = [str(c).strip() for c in row[1:] if c and str(c).strip()] + if not desc_parts: + continue + full_desc = " | ".join(desc_parts) + examples.append(_chatml( + SYSTEM_PROMPT_GHM, + f"Quelles sont les caractéristiques de la racine GHM {racine} ?", + f"La racine GHM {racine} : {full_desc}", + source="racines_ghm" + )) + + # Si pas de tables parsées, utiliser les entrées textuelles + if not examples: + for e in entries: + examples.append(_chatml( + SYSTEM_PROMPT_GHM, + f"Quelles sont les caractéristiques de la racine GHM {e['racine']} ?", + f"La racine GHM {e['racine']} : {e['description']}", + source="racines_ghm" + )) + + return examples + + +# ─── 3. Arbre de décision GHM ───────────────────────────────────────────────── + +def parse_arbre_ghm() -> list[dict]: + """Parse l'arbre de décision GHM — extrait les règles par page/bloc.""" + pdf_path = list(REF_DIR.glob("*Arbre_decision_GHM*.pdf")) + if not pdf_path: + print(" Arbre GHM non trouvé, skip") + return [] + pdf_path = pdf_path[0] + + examples = [] + # Le PDF est un arbre graphique : chaque page contient des nœuds de décision + # avec des codes GHM (ex: 01K03, 28Z18) et des conditions (listes diagnostiques/actes) + ghm_pattern = re.compile(r"\b(\d{2}[A-Z]\d{2})\b") + condition_pattern = re.compile(r"\(([A-Z]-\d{3})\)") # Ex: (A-198), (D-064) + + # Grouper les pages par CMD (les 2 premiers chiffres du GHM) + cmd_pages = {} + + with pdfplumber.open(str(pdf_path)) as pdf: + for page in pdf.pages[7:]: # Skip couverture, symboles, légende + text = page.extract_text() or "" + if not text.strip() or len(text.strip()) < 30: + continue + + # Nettoyer : supprimer les numéros de page isolés en début de texte + lines = text.strip().split("\n") + lines = [l for l in lines if not re.match(r"^\s*\d{1,3}\s*$", l.strip())] + text = "\n".join(lines) + + # Trouver les codes GHM sur cette page + ghm_codes = ghm_pattern.findall(text) + if not ghm_codes: + continue + + # Déterminer la CMD (premiers 2 chiffres) + cmd_num = ghm_codes[0][:2] + if cmd_num not in cmd_pages: + cmd_pages[cmd_num] = [] + cmd_pages[cmd_num].append(text.strip()) + + # Générer un exemple par CMD + for cmd_num in sorted(cmd_pages.keys()): + pages_text = cmd_pages[cmd_num] + full_text = "\n\n".join(pages_text) + + # Extraire tous les GHM et conditions + ghm_codes = sorted(set(ghm_pattern.findall(full_text))) + conditions = sorted(set(condition_pattern.findall(full_text))) + + ghm_str = ", ".join(ghm_codes[:15]) + body = full_text[:2000] + + examples.append(_chatml( + SYSTEM_PROMPT_GHM, + f"Quelles sont les règles de l'arbre de décision GHM pour la CMD {cmd_num} ?", + f"CMD {cmd_num} — Arbre de décision GHM V11e :\n" + f"Racines GHM concernées : {ghm_str}\n\n{body}", + source="arbre_ghm" + )) + + print(f" Arbre GHM : {len(examples)} CMDs extraites") + return examples + + +# ─── 4. Fascicules de codage ────────────────────────────────────────────────── + +def parse_fascicule(pdf_path: Path, topic: str) -> list[dict]: + """Parse un fascicule de codage — extrait les sections de règles.""" + examples = [] + + with pdfplumber.open(str(pdf_path)) as pdf: + full_text = "" + for page in pdf.pages: + text = page.extract_text() or "" + full_text += text + "\n" + + if not full_text.strip(): + return [] + + # Nettoyer : supprimer les lignes de table des matières (avec ........) + clean_lines = [] + for line in full_text.split("\n"): + if "...." in line or "TABLE DES MATIERES" in line: + continue + # Supprimer les numéros de page isolés + if re.match(r"^\s*\d{1,3}\s*$", line.strip()): + continue + clean_lines.append(line) + full_text = "\n".join(clean_lines) + + # Découper en sections par les titres (lignes commençant par un chiffre romain ou "Consignes") + sections = [] + current_title = topic + current_body = [] + + for line in full_text.split("\n"): + line_stripped = line.strip() + # Détecter les titres de section + is_title = False + # Sections principales : I., II., III., IV. + if re.match(r"^[IVX]+\.\s+", line_stripped): + is_title = True + # Sous-sections : II.1., III.2., IV.3. + elif re.match(r"^[IVX]+\.\d+\.?\s+", line_stripped): + is_title = True + # Titres descriptifs courants dans les fascicules + elif re.match(r"^(Consignes|Règles|Les pièges|Le codage|Codage|Comment coder|Principes|Exemples?|Cas particulier|Remarque)", line_stripped, re.IGNORECASE): + is_title = True + elif re.match(r"^(Créé le|Mis à jour|Modifié le|TABLE DES MATIERES|FASCICULE DE CODAGE)", line_stripped): + continue # Skip dates et entêtes + + if is_title and current_body: + body = "\n".join(current_body).strip() + if len(body) > 100: + sections.append((current_title, body)) + current_title = line_stripped + current_body = [] + else: + current_body.append(line) + + # Dernière section + if current_body: + body = "\n".join(current_body).strip() + if len(body) > 100: + sections.append((current_title, body)) + + # Générer les exemples ChatML + for title, body in sections: + # Nettoyer + body = re.sub(r"\n{3,}", "\n\n", body) + body = body[:2000] # Limiter la taille + + # Extraire les codes CIM-10 mentionnés pour enrichir la question + codes = re.findall(r"[A-Z]\d{2}(?:\.\d{1,2})?", body) + codes_unique = sorted(set(codes))[:5] + codes_str = f" (codes : {', '.join(codes_unique)})" if codes_unique else "" + + # Type 1 : Règle de codage + examples.append(_chatml( + SYSTEM_PROMPT, + f"Quelles sont les règles de codage PMSI pour {title}{codes_str} ?", + f"Selon le fascicule de codage ATIH « {topic} » :\n\n{body}", + source=f"fascicule_{topic.lower().replace(' ', '_')[:30]}" + )) + + # Type 2 : Question pratique à partir du contenu + # Chercher les patterns "code X pour Y" ou "on codera X" + coding_rules = re.findall( + r"(?:on codera?|se code|coder?|est codé[e]?|codé[e]?\s+(?:avec|par|en))\s+([A-Z]\d{2}(?:\.\d{1,2})?)\s*(.*?)(?:\.|$)", + body, re.IGNORECASE + ) + for code, context in coding_rules[:5]: + context = context.strip()[:200] + # Filtrer les contextes trop courts ou non informatifs + if len(context) < 15: + continue + examples.append(_chatml( + SYSTEM_PROMPT, + f"Comment coder en CIM-10 : {context} ?", + f"Selon les règles ATIH ({topic}), cette situation se code {code}. {context}", + source=f"fascicule_{topic.lower().replace(' ', '_')[:30]}" + )) + + return examples + + +def parse_all_fascicules() -> list[dict]: + """Parse tous les fascicules disponibles.""" + fascicules = { + "Fascicule_01_Generalites": "Généralités du codage PMSI", + "Fascicule_02_Maladies_digestives": "Maladies de l'appareil digestif", + "Fascicule_03_Tumeurs": "Tumeurs", + "Fascicule_04_Metabolisme": "Métabolisme", + "Fascicule_05_Gyneco_Obstetrique": "Gynécologie et Obstétrique", + "Fascicule_06_Neonatalogie": "Néonatalogie", + "Fascicule_07_Evolutions_2010": "Évolutions 2010", + "Fascicule_08_Maladies_infectieuses": "Maladies infectieuses", + "Fascicule_09_AVC": "Accidents vasculaires cérébraux", + "Fascicule_10_SCA_Coronariens": "Syndromes coronariens aigus", + } + + all_examples = [] + for filename_part, topic in fascicules.items(): + pdf_files = list(REF_DIR.glob(f"*{filename_part}*")) + if not pdf_files: + print(f" {topic} : non trouvé, skip") + continue + examples = parse_fascicule(pdf_files[0], topic) + print(f" {topic} : {len(examples)} exemples") + all_examples.extend(examples) + + return all_examples + + +# ─── 5. Instruction DGOS ────────────────────────────────────────────────────── + +def parse_instruction_dgos() -> list[dict]: + """Parse l'instruction DGOS contrôle T2A.""" + pdf_path = list(REF_DIR.glob("*Instruction_DGOS*.pdf")) + if not pdf_path: + print(" Instruction DGOS non trouvée, skip") + return [] + pdf_path = pdf_path[0] + + with pdfplumber.open(str(pdf_path)) as pdf: + full_text = "" + for page in pdf.pages: + text = page.extract_text() or "" + full_text += text + "\n" + + # Découper en sections thématiques + examples = [] + sections = re.split(r"\n(?=\d+\.\s+|\d+\.\d+\s+|Annexe)", full_text) + + for section in sections: + section = section.strip() + if len(section) < 100: + continue + + # Titre = première ligne + lines = section.split("\n") + title = lines[0].strip() + body = "\n".join(lines[1:]).strip()[:2000] + + if body: + examples.append(_chatml( + SYSTEM_PROMPT_CONTROLE, + f"Que dit l'instruction DGOS 2025 sur : {title} ?", + f"Selon l'instruction DGOS/FIP1/DSS/1A/2025/141 relative aux contrôles T2A :\n\n{body}", + source="instruction_dgos" + )) + + print(f" Instruction DGOS : {len(examples)} sections") + return examples + + +# ─── Main ───────────────────────────────────────────────────────────────────── + +def main(): + print("=" * 60) + print("Parsing des référentiels ATIH/PMSI") + print("=" * 60) + + all_examples = [] + + # 1. Annexe-4 CMA + print("\n1. Annexe-4 CMA") + all_examples.extend(parse_annexe4_cma()) + + # 2. Racines GHM + print("\n2. Racines GHM") + all_examples.extend(parse_racines_ghm()) + + # 3. Arbre de décision GHM + print("\n3. Arbre de décision GHM") + all_examples.extend(parse_arbre_ghm()) + + # 4. Fascicules + print("\n4. Fascicules de codage") + all_examples.extend(parse_all_fascicules()) + + # 5. Instruction DGOS + print("\n5. Instruction DGOS") + all_examples.extend(parse_instruction_dgos()) + + # Sauvegarder + output_path = OUTPUT / "referentiels_chatml.jsonl" + with open(output_path, "w") as f: + for ex in all_examples: + f.write(json.dumps(ex, ensure_ascii=False) + "\n") + + size_mo = output_path.stat().st_size / 1024 / 1024 + print(f"\n{'=' * 60}") + print(f"Total : {len(all_examples)} exemples") + print(f"Sauvegardé : {output_path} ({size_mo:.1f} Mo)") + print(f"{'=' * 60}") + + +if __name__ == "__main__": + main() diff --git a/scripts/12_generate_pipeline_examples.py b/scripts/12_generate_pipeline_examples.py new file mode 100644 index 0000000..67d4e9a --- /dev/null +++ b/scripts/12_generate_pipeline_examples.py @@ -0,0 +1,425 @@ +#!/usr/bin/env python3 +""" +Génère des exemples d'entraînement au format RÉEL du pipeline T2A +à partir du cache Ollama (gemma3:12b = gold standard). + +Le cache contient les paires (diagnostic → code + raisonnement) +produites par gemma3:12b sur les 250 dossiers. On reconstruit +des prompts proches du pipeline et on utilise les réponses du cache +comme labels supervisés. + +V2 : Utilise le cache actuel (1 840 entrées vs 100). + 3 templates : court + long + CPAM contre-argumentation. + Tous les textes génèrent une version courte ET longue. + Cible ~4 000 exemples (2×1840 + CPAM bonus). + +Produit : data/processed/pipeline_chatml.jsonl + +Usage : + python scripts/12_generate_pipeline_examples.py +""" + +import json +import random +from pathlib import Path + +random.seed(42) + +BASE = Path(__file__).resolve().parent.parent +T2A = Path("/home/dom/ai/t2a") +CACHE_PATH = T2A / "data" / "ollama_cache.json" +CACHE_BACKUP = T2A / "data" / "ollama_cache_gemma3.bak" +CIM10_DICT = T2A / "data" / "cim10_dict.json" +CIM10_SUPP = T2A / "data" / "cim10_supplements.json" +OUTPUT = BASE / "data" / "processed" / "pipeline_chatml.jsonl" + + +def load_cim10_dict() -> dict[str, str]: + """Charge le dictionnaire CIM-10 (code → libellé).""" + d = {} + if CIM10_DICT.exists(): + d.update(json.loads(CIM10_DICT.read_text())) + if CIM10_SUPP.exists(): + d.update(json.loads(CIM10_SUPP.read_text())) + return d + + +def load_cache_entries() -> dict: + """Charge toutes les entrées du cache (actuel + backup).""" + entries = {} + for path in [CACHE_PATH, CACHE_BACKUP]: + if path.exists(): + data = json.loads(path.read_text()) + new = sum(1 for k in data.get("entries", {}) if k not in entries) + entries.update(data.get("entries", {})) + print(f" {path.name}: {len(data.get('entries', {}))} entrées (+{new} nouvelles)") + return entries + + +SYSTEM_PROMPT = ( + "Tu es un médecin DIM (Département d'Information Médicale) expert en codage PMSI. " + "Tu codes les diagnostics en CIM-10 en suivant une démarche structurée : " + "analyse clinique, identification des codes candidats, discrimination, " + "vérification des règles PMSI." +) + +# Template simplifié reproduisant la structure du prompt pipeline +# (sans les sources RAG qui ne sont pas dans le cache) +PROMPT_TEMPLATE_DP = """Code ce diagnostic en CIM-10 pour le PMSI. + +RÈGLES IMPÉRATIVES : +- Privilégie le code le plus SPÉCIFIQUE disponible (4e ou 5e caractère) +- Vérifie les notes d'inclusion/exclusion de chaque code candidat +- Le DP doit refléter le motif principal de prise en charge du séjour +- EXCLUSION SYMPTÔME : Si le diagnostic est un symptôme (R00-R99) et qu'un diagnostic précis existe, le symptôme ne doit PAS être codé comme DP + +DIAGNOSTIC À CODER : "{texte}" +TYPE : DP (diagnostic principal) + +Réponds UNIQUEMENT avec un objet JSON : +{{ + "analyse_clinique": "que signifie ce diagnostic sur le plan médical", + "codes_candidats": "quels codes CIM-10 sont compatibles", + "discrimination": "pourquoi choisir ce code plutôt qu'un autre", + "regle_pmsi": "conformité aux règles PMSI pour un DP", + "code": "X99.9", + "confidence": "high ou medium ou low", + "justification": "explication courte en français" +}}""" + +PROMPT_TEMPLATE_DAS = """Code ce diagnostic en CIM-10 pour le PMSI. + +RÈGLES IMPÉRATIVES : +- Privilégie le code le plus SPÉCIFIQUE disponible (4e ou 5e caractère) +- Vérifie les notes d'inclusion/exclusion de chaque code candidat +- Un DAS doit avoir mobilisé des ressources supplémentaires pendant le séjour +- EXCLUSION SYMPTÔME : Si le diagnostic est un symptôme (R00-R99) et qu'un diagnostic précis existe, le symptôme ne doit PAS être codé comme DAS + +DIAGNOSTIC À CODER : "{texte}" +TYPE : DAS (diagnostic associé significatif) + +Réponds UNIQUEMENT avec un objet JSON : +{{ + "analyse_clinique": "que signifie ce diagnostic sur le plan médical", + "codes_candidats": "quels codes CIM-10 sont compatibles", + "discrimination": "pourquoi choisir ce code plutôt qu'un autre", + "regle_pmsi": "conformité aux règles PMSI pour un DAS", + "code": "X99.9", + "confidence": "high ou medium ou low", + "justification": "explication courte en français" +}}""" + +# Version longue avec contexte patient (simulé) pour entraîner sur des prompts longs +PROMPT_TEMPLATE_DP_LONG = """Code ce diagnostic en CIM-10 pour le PMSI. + +RÈGLES IMPÉRATIVES : +- Le code doit provenir UNIQUEMENT de la nomenclature CIM-10 FR 2026 +- Distingue la DESCRIPTION CLINIQUE (ce que le médecin écrit) de la LOGIQUE DE CODAGE (ce que l'ATIH impose) +- Privilégie le code le plus SPÉCIFIQUE disponible (4e ou 5e caractère) +- Vérifie les notes d'inclusion/exclusion de chaque code candidat +- Si le diagnostic est un DP, il doit refléter le motif principal de prise en charge du séjour +- EXCLUSION SYMPTÔME : Si le diagnostic est un symptôme (R00-R99) et qu'un diagnostic précis (Chapitres I-XIV, A00-N99) expliquant ce symptôme est présent, le symptôme ne doit PAS être codé + +DIAGNOSTIC À CODER : "{texte}" +TYPE : DP (diagnostic principal) + +CONTEXTE CLINIQUE : +{contexte} + +SOURCES CIM-10 : +{sources} + +Réponds UNIQUEMENT avec un objet JSON au format suivant, sans aucun texte avant ou après : +{{ + "analyse_clinique": "que signifie ce diagnostic sur le plan médical", + "codes_candidats": "quels codes CIM-10 des sources sont compatibles", + "discrimination": "pourquoi choisir ce code plutôt qu'un autre (inclusions/exclusions, spécificité)", + "regle_pmsi": "conformité aux règles PMSI pour un DP (guide méthodologique)", + "code": "X99.9", + "confidence": "high ou medium ou low", + "justification": "explication courte en français" +}}""" + +PROMPT_TEMPLATE_DAS_LONG = """Code ce diagnostic en CIM-10 pour le PMSI. + +RÈGLES IMPÉRATIVES : +- Le code doit provenir UNIQUEMENT de la nomenclature CIM-10 FR 2026 +- Distingue la DESCRIPTION CLINIQUE (ce que le médecin écrit) de la LOGIQUE DE CODAGE (ce que l'ATIH impose) +- Privilégie le code le plus SPÉCIFIQUE disponible (4e ou 5e caractère) +- Vérifie les notes d'inclusion/exclusion de chaque code candidat +- Un DAS doit avoir mobilisé des ressources supplémentaires pendant le séjour +- EXCLUSION SYMPTÔME : Si le diagnostic est un symptôme (R00-R99) et qu'un diagnostic précis (Chapitres I-XIV, A00-N99) expliquant ce symptôme est présent, le symptôme ne doit PAS être codé + +DIAGNOSTIC À CODER : "{texte}" +TYPE : DAS (diagnostic associé significatif) + +CONTEXTE CLINIQUE : +{contexte} + +SOURCES CIM-10 : +{sources} + +Réponds UNIQUEMENT avec un objet JSON au format suivant, sans aucun texte avant ou après : +{{ + "analyse_clinique": "que signifie ce diagnostic sur le plan médical", + "codes_candidats": "quels codes CIM-10 des sources sont compatibles", + "discrimination": "pourquoi choisir ce code plutôt qu'un autre (inclusions/exclusions, spécificité)", + "regle_pmsi": "conformité aux règles PMSI pour un DAS (guide méthodologique)", + "code": "X99.9", + "confidence": "high ou medium ou low", + "justification": "explication courte en français" +}}""" + + +# V2 : Template CPAM contre-argumentation (3e variante) +PROMPT_TEMPLATE_CPAM = """Tu es médecin DIM. La CPAM conteste le codage d'un diagnostic. Argumente. + +DIAGNOSTIC CONTESTÉ : "{texte}" +TYPE : {type_label} +CODE ACTUEL : {code} + +MOTIF DE CONTESTATION CPAM : +{motif_cpam} + +SOURCES CIM-10 : +{sources} + +Réponds UNIQUEMENT avec un objet JSON : +{{ + "analyse_clinique": "que signifie ce diagnostic sur le plan médical", + "codes_candidats": "quels codes CIM-10 sont compatibles", + "discrimination": "pourquoi ce code est correct (inclusions/exclusions, spécificité)", + "regle_pmsi": "conformité aux règles PMSI et au guide méthodologique", + "code": "{code}", + "confidence": "high ou medium ou low", + "justification": "argumentation structurée pour répondre à la contestation CPAM" +}}""" + +CPAM_MOTIFS = [ + "Le code {code} ne semble pas justifié par les éléments du dossier médical.", + "La CPAM propose de recoder en {alt_code} ({alt_label}). Justifiez le maintien du code actuel.", + "Le DAS {code} n'a pas mobilisé de ressources supplémentaires pendant le séjour.", + "Ce diagnostic est un symptôme déjà couvert par le DP. Il ne devrait pas être codé séparément.", + "Le niveau de sévérité CMA associé à {code} ne semble pas justifié cliniquement.", +] + + +def build_fake_source(code: str, cim10_dict: dict[str, str]) -> str: + """Génère un extrait de source CIM-10 simulé pour un code donné.""" + label = cim10_dict.get(code, "") + if not label: + return "" + + # Trouver des codes voisins (même catégorie à 3 car.) + prefix = code[:3] + neighbors = [] + for c, l in cim10_dict.items(): + if c[:3] == prefix and c != code: + neighbors.append((c, l)) + neighbors.sort() + neighbors = neighbors[:4] + + lines = [f"--- Source 1: CIM-10 FR 2026 (code: {code}) ---"] + lines.append(f"{code} {label}") + for c, l in neighbors: + lines.append(f"{c} {l}") + lines.append("") + + return "\n".join(lines) + + +def build_response_json(entry: dict) -> str: + """Reconstruit la réponse JSON structurée depuis une entrée du cache.""" + # Extraire les sections du raisonnement + rais = entry.get("raisonnement", "") + analyse = "" + candidats = "" + discrim = "" + regle = "" + + if "ANALYSE CLINIQUE" in rais: + parts = rais.split("\n\n") + for part in parts: + p = part.strip() + if p.startswith("ANALYSE CLINIQUE"): + analyse = p.replace("ANALYSE CLINIQUE :\n", "").replace("ANALYSE CLINIQUE :", "").strip() + elif p.startswith("CODES CANDIDATS"): + candidats = p.replace("CODES CANDIDATS :\n", "").replace("CODES CANDIDATS :", "").strip() + elif p.startswith("DISCRIMINATION"): + discrim = p.replace("DISCRIMINATION :\n", "").replace("DISCRIMINATION :", "").strip() + elif p.startswith("REGLE PMSI") or p.startswith("RÈGLE PMSI"): + regle = p.split(":\n", 1)[-1].strip() if ":\n" in p else p.split(":", 1)[-1].strip() + + resp = { + "analyse_clinique": analyse or "Diagnostic médical nécessitant un codage CIM-10 spécifique.", + "codes_candidats": candidats or f"[{entry.get('code', '?')}]", + "discrimination": discrim or entry.get("justification", ""), + "regle_pmsi": regle or "Code conforme aux règles PMSI.", + "code": entry.get("code", "?"), + "confidence": entry.get("confidence", "medium"), + "justification": entry.get("justification", ""), + } + + return json.dumps(resp, ensure_ascii=False) + + +def generate_contexte_samples() -> list[str]: + """Génère des contextes patients variés.""" + return [ + "- Patient : Homme, 72 ans, IMC 28.5\n- Durée séjour : 7 jours\n- Biologie : CRP 145 [N: 0-5] (↑), Créatinine 89 [N: 50-120]", + "- Patient : Femme, 58 ans, IMC 24.1\n- Durée séjour : 3 jours\n- Biologie : Hémoglobine 10.2 [N: 12-17] (↑), Plaquettes 180 [N: 150-400]", + "- Patient : Homme, 85 ans, IMC 22.0\n- Durée séjour : 12 jours\n- Antécédents : HTA, Diabète de type 2, BPCO\n- Biologie : CRP 89 [N: 0-5] (↑), Leucocytes 14.2 [N: 4-10] (↑)", + "- Patient : Femme, 45 ans\n- Durée séjour : 2 jours", + "- Patient : Homme, 67 ans, IMC 31.2\n- Durée séjour : 5 jours\n- Biologie : ASAT 85 [N: 0-40] (↑), ALAT 120 [N: 0-40] (↑), GGT 210 [N: 0-60] (↑)", + "- Patient : Femme, 30 ans\n- Durée séjour : 4 jours\n- Biologie : CRP 25 [N: 0-5] (↑), Leucocytes 12.5 [N: 4-10] (↑)", + "- Patient : Homme, 78 ans, IMC 20.1\n- Durée séjour : 14 jours\n- Antécédents : Insuffisance cardiaque, FA\n- Complications : Infection urinaire", + "Non précisé", + ] + + +def _parse_cache_key(key): + """Extraire le type (dp/das) et le texte depuis la clé du cache.""" + if key.startswith("das_llm::das_extract::"): + parts = key.split("::", 3) + texte = parts[3] if len(parts) > 3 else parts[-1] + return "das", texte.strip() + if "::" in key: + diag_type, texte = key.split("::", 1) + return diag_type.strip(), texte.strip() + return "das", key.strip() + + +def _find_alt_code(code: str, cim10_dict: dict[str, str]) -> tuple[str, str]: + """Trouve un code alternatif (même catégorie) pour les motifs CPAM.""" + prefix = code[:3] + for c, label in cim10_dict.items(): + if c[:3] == prefix and c != code and len(c) >= 4: + return c, label + # Fallback : code .9 (SAI) + sai = f"{prefix}.9" + if sai in cim10_dict and sai != code: + return sai, cim10_dict[sai] + return "Z03.9", "Observation pour suspicion non précisée" + + +def main(): + print("Chargement du cache Ollama gemma3:12b (toutes sources)...") + entries = load_cache_entries() + print(f" Total fusionné : {len(entries)} entrées") + + print("Chargement du dictionnaire CIM-10...") + cim10_dict = load_cim10_dict() + print(f" {len(cim10_dict)} codes") + + contextes = generate_contexte_samples() + examples = [] + n_short = 0 + n_long = 0 + n_cpam = 0 + + for key, entry in entries.items(): + code = entry.get("code", "") + if not code or code == "?" or len(code) < 3: + continue + + diag_type, texte = _parse_cache_key(key) + if diag_type not in ("dp", "das"): + continue + if not texte or len(texte) < 3: + continue + + # Vérifier que le raisonnement est substantiel + raisonnement = entry.get("raisonnement", "") + if len(raisonnement) < 50: + continue + + response_json = build_response_json(entry) + + # V2 : TOUJOURS générer version courte + version longue + + # 1) Version courte (sans sources ni contexte) + if diag_type == "dp": + user_content = PROMPT_TEMPLATE_DP.format(texte=texte) + else: + user_content = PROMPT_TEMPLATE_DAS.format(texte=texte) + + examples.append({ + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_content}, + {"role": "assistant", "content": response_json}, + ] + }) + n_short += 1 + + # 2) Version longue (avec sources CIM-10 et contexte simulé) + source_text = build_fake_source(code, cim10_dict) + if source_text: + ctx = random.choice(contextes) + if diag_type == "dp": + user_long = PROMPT_TEMPLATE_DP_LONG.format( + texte=texte, contexte=ctx, sources=source_text + ) + else: + user_long = PROMPT_TEMPLATE_DAS_LONG.format( + texte=texte, contexte=ctx, sources=source_text + ) + + examples.append({ + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_long}, + {"role": "assistant", "content": response_json}, + ] + }) + n_long += 1 + + # 3) V2 : Version CPAM (contre-argumentation) — 20% des entrées DAS high confidence + if (diag_type == "das" + and entry.get("confidence") == "high" + and source_text + and random.random() < 0.20): + type_label = "DAS (diagnostic associé significatif)" + alt_code, alt_label = _find_alt_code(code, cim10_dict) + motif = random.choice(CPAM_MOTIFS).format( + code=code, alt_code=alt_code, alt_label=alt_label + ) + user_cpam = PROMPT_TEMPLATE_CPAM.format( + texte=texte, type_label=type_label, code=code, + motif_cpam=motif, sources=source_text, + ) + examples.append({ + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_cpam}, + {"role": "assistant", "content": response_json}, + ] + }) + n_cpam += 1 + + # Mélanger + random.shuffle(examples) + + # Écrire + OUTPUT.parent.mkdir(parents=True, exist_ok=True) + with open(OUTPUT, "w") as f: + for ex in examples: + f.write(json.dumps(ex, ensure_ascii=False) + "\n") + + # Stats + token_counts = [] + for ex in examples: + text = " ".join(m["content"] for m in ex["messages"]) + token_counts.append(int(len(text.split()) * 1.3)) + + print(f"\n{'='*50}") + print(f"Exemples pipeline générés : {len(examples)}") + print(f" Courts : {n_short}, Longs : {n_long}, CPAM : {n_cpam}") + print(f" → {OUTPUT}") + print(f" Tokens : moy={sum(token_counts)//len(token_counts)}, " + f"max={max(token_counts)}, min={min(token_counts)}") + print(f" Taille : {OUTPUT.stat().st_size / 1024 / 1024:.1f} Mo") + + +if __name__ == "__main__": + main() diff --git a/scripts/13_generate_fascicule_reasoning.py b/scripts/13_generate_fascicule_reasoning.py new file mode 100644 index 0000000..f798fde --- /dev/null +++ b/scripts/13_generate_fascicule_reasoning.py @@ -0,0 +1,373 @@ +#!/usr/bin/env python3 +""" +Parsing agressif des 10 fascicules ATIH → Q&A de raisonnement DIM. + +Stratégie : + - Découpe chaque fascicule en paragraphes (pas sections) + - Extrait les règles de codage via regex + - Pour chaque règle extraite, génère 3 Q&A raisonnement via Claude Opus 4.6 + - Question = scénario clinique appliquant la règle + - Réponse = JSON structuré {analyse_clinique, regle_pmsi, code, justification} + +Cible : ~450 exemples (151 règles × 3 exercices) + +Sources : data/raw/referentiels/ (10 fascicules PDF, via lien t2a) +Nécessite : ANTHROPIC_API_KEY en variable d'environnement + +Usage : + python scripts/13_generate_fascicule_reasoning.py [--dry-run] [--max N] +""" + +import argparse +import json +import os +import random +import re +import sys +import time +from pathlib import Path + +import pdfplumber + +random.seed(42) + +BASE = Path(__file__).resolve().parent.parent +T2A = BASE.parent / "t2a" +REF_DIR = T2A / "data" / "referentiels" +OUTPUT = BASE / "data" / "processed" / "fascicule_reasoning_chatml.jsonl" +OUTPUT.parent.mkdir(parents=True, exist_ok=True) + +MODEL = "claude-opus-4-6" + +SYSTEM_PROMPT = ( + "Tu es un médecin DIM expert en codage CIM-10 pour le PMSI français. " + "Tu t'appuies sur les référentiels ATIH officiels." +) + +# Fascicules à parser +FASCICULES = { + "Fascicule_01_Generalites": "Généralités du codage PMSI", + "Fascicule_02_Maladies_digestives": "Maladies de l'appareil digestif", + "Fascicule_03_Tumeurs": "Tumeurs", + "Fascicule_04_Metabolisme": "Métabolisme", + "Fascicule_05_Gyneco_Obstetrique": "Gynécologie et Obstétrique", + "Fascicule_06_Neonatalogie": "Néonatalogie", + "Fascicule_07_Evolutions_2010": "Évolutions 2010", + "Fascicule_08_Maladies_infectieuses": "Maladies infectieuses", + "Fascicule_09_AVC": "Accidents vasculaires cérébraux", + "Fascicule_10_SCA_Coronariens": "Syndromes coronariens aigus", +} + +# Regex pour détecter les règles de codage dans le texte +RULE_PATTERNS = [ + re.compile(r"(?:on\s+code(?:ra)?|se\s+code|coder?|est\s+cod[ée](?:e)?)\s+(?:avec\s+|par\s+|en\s+)?([A-Z]\d{2}(?:\.\d{1,2})?)", re.IGNORECASE), + re.compile(r"ne\s+(?:pas|doit\s+pas|faut\s+pas)\s+coder", re.IGNORECASE), + re.compile(r"à\s+l['\u2019]exclusion\s+de", re.IGNORECASE), + re.compile(r"(?:comprend|inclut|inclus)\s+", re.IGNORECASE), + re.compile(r"(?:dans\s+ce\s+cas|en\s+cas\s+de|si\s+le\s+patient|lorsque)", re.IGNORECASE), + re.compile(r"(?:le\s+DP|diagnostic\s+principal)\s+(?:est|sera|doit\s+être)", re.IGNORECASE), + re.compile(r"(?:le\s+DAS|diagnostic\s+associé)\s+(?:est|sera|doit\s+être)", re.IGNORECASE), + re.compile(r"(?:CMA|sévérité|niveau\s+\d)", re.IGNORECASE), + re.compile(r"(?:séjour|durée|ressources?\s+supplémentaires?)", re.IGNORECASE), +] + +GENERATION_PROMPT = """Tu es un formateur DIM. À partir de cet extrait du fascicule ATIH, génère EXACTEMENT 3 exercices de raisonnement distincts. + +EXTRAIT (source : fascicule ATIH « {topic} ») : +{rule_text} + +Génère un tableau JSON de 3 objets, chacun avec : +- "scenario" : un cas clinique réaliste et concis (2-3 phrases), DIFFÉRENT des autres +- "reponse" : un objet contenant : + - "analyse_clinique" : interprétation du cas + - "regle_pmsi" : la règle du fascicule qui s'applique + - "code" : le code CIM-10 correct (ou null si ne pas coder) + - "confidence" : "high" + - "justification" : pourquoi cette règle s'applique + +Réponds UNIQUEMENT avec le tableau JSON (pas d'objet wrapper), sans texte avant/après. +Exemple de format : [{{"scenario": "...", "reponse": {{...}}}}, ...]""" + + +def extract_paragraphs(pdf_path: Path) -> list[str]: + """Extraire les paragraphes d'un fascicule PDF. + + Stratégie : pdfplumber ne produit pas de lignes vides entre paragraphes. + On découpe par page, puis par titres/sections détectées via heuristiques. + """ + pages_text = [] + with pdfplumber.open(str(pdf_path)) as pdf: + for page in pdf.pages: + text = page.extract_text() or "" + if text.strip(): + pages_text.append(text) + + # Concaténer avec séparateur de page + full_text = "\n\n".join(pages_text) + + # Nettoyer les lignes non pertinentes + lines = [] + for line in full_text.split("\n"): + stripped = line.strip() + if "...." in stripped or "TABLE DES MATIERES" in stripped: + continue + if re.match(r"^\s*\d{1,3}\s*$", stripped): + continue + if re.match(r"^(Créé le|Mis à jour|Modifié le|FASCICULE DE CODAGE)", stripped): + continue + lines.append(line) + + clean_text = "\n".join(lines) + + # Découper en sections par les titres détectés + title_re = re.compile( + r"^(?:" + r"[IVX]+\.\d*\.?\s+" + r"|[IVX]+\s+[-–]\s+" + r"|\d+\.\d*\.?\s+[A-ZÀÂÉÈÊËÎÏ]" + r"|[A-ZÀÂÉÈÊËÎÏÔÙÛÜ][A-ZÀÂÉÈÊËÎÏÔÙÛÜ\s]{3,}$" + r")", + re.MULTILINE + ) + + paragraphs = [] + + matches = list(title_re.finditer(clean_text)) + if matches: + for i, match in enumerate(matches): + start = match.start() + end = matches[i + 1].start() if i + 1 < len(matches) else len(clean_text) + section = clean_text[start:end].strip() + if len(section) > 100: + paragraphs.append(section) + + # Fallback : découper par blocs de ~800 caractères + if len(paragraphs) < 5: + paragraphs = [] + block_size = 800 + current_block = [] + current_len = 0 + for line in clean_text.split("\n"): + current_block.append(line) + current_len += len(line) + 1 + if current_len >= block_size: + para = "\n".join(current_block).strip() + if len(para) > 100: + paragraphs.append(para) + current_block = [] + current_len = 0 + if current_block: + para = "\n".join(current_block).strip() + if len(para) > 100: + paragraphs.append(para) + + return paragraphs + + +def extract_rules(paragraphs: list[str]) -> list[tuple[str, int]]: + """Extraire les paragraphes classés par pertinence (score de règle). + + Returns: [(text, score)] trié par score décroissant. + """ + scored = [] + for para in paragraphs: + score = sum(1 for pat in RULE_PATTERNS if pat.search(para)) + codes = re.findall(r"[A-Z]\d{2}(?:\.\d{1,2})?", para) + if codes: + score += 1 + if score >= 1: + text = para[:1500] if len(para) > 1500 else para + scored.append((text, score)) + scored.sort(key=lambda x: -x[1]) + return scored + + +def call_claude(client, prompt: str, max_retries: int = 2) -> str | None: + """Appel Claude Opus 4.6 via API Anthropic avec retry.""" + for attempt in range(max_retries + 1): + try: + response = client.messages.create( + model=MODEL, + max_tokens=4096, + temperature=0.7, + messages=[{"role": "user", "content": prompt}], + ) + return response.content[0].text + except Exception as e: + if attempt < max_retries: + wait = 2 ** (attempt + 1) + print(f" Retry in {wait}s: {e}") + time.sleep(wait) + else: + print(f" Claude error: {e}") + return None + + +def parse_llm_response(response_text: str) -> list[dict]: + """Parse la réponse JSON du LLM (tableau de 3 exercices ou objet unique).""" + if not response_text: + return [] + text = response_text.strip() + if "```json" in text: + text = text.split("```json", 1)[1].split("```", 1)[0].strip() + elif "```" in text: + text = text.split("```", 1)[1].split("```", 1)[0].strip() + + try: + data = json.loads(text) + if isinstance(data, list): + return [d for d in data if isinstance(d, dict) and "scenario" in d and "reponse" in d] + if isinstance(data, dict) and "scenario" in data and "reponse" in data: + return [data] + except json.JSONDecodeError: + pass + + # Fallback : chercher un tableau [...] + bracket_start = text.find("[") + if bracket_start >= 0: + depth = 0 + for i in range(bracket_start, len(text)): + if text[i] == "[": + depth += 1 + elif text[i] == "]": + depth -= 1 + if depth == 0: + try: + data = json.loads(text[bracket_start:i+1]) + if isinstance(data, list): + return [d for d in data if isinstance(d, dict) and "scenario" in d] + except json.JSONDecodeError: + break + + # Fallback : objet unique + brace_start = text.find("{") + if brace_start >= 0: + depth = 0 + for i in range(brace_start, len(text)): + if text[i] == "{": + depth += 1 + elif text[i] == "}": + depth -= 1 + if depth == 0: + try: + data = json.loads(text[brace_start:i+1]) + if "scenario" in data: + return [data] + except json.JSONDecodeError: + break + return [] + + +def make_chatml(scenario: str, response: dict, topic: str) -> dict: + """Créer un exemple ChatML depuis le scénario + réponse structurée.""" + user_content = ( + f"Cas clinique :\n{scenario}\n\n" + f"Code ce cas en CIM-10 selon les règles du fascicule « {topic} ».\n\n" + "Réponds avec un JSON structuré contenant : analyse_clinique, regle_pmsi, code, confidence, justification." + ) + assistant_content = json.dumps(response, ensure_ascii=False) + + return { + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_content}, + {"role": "assistant", "content": assistant_content}, + ] + } + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--dry-run", action="store_true", help="Pas d'appel LLM, affiche les règles extraites") + parser.add_argument("--max", type=int, default=0, help="Max règles par fascicule (0=illimité)") + args = parser.parse_args() + + print("=" * 60) + print("Génération de Q&A raisonnement depuis les fascicules ATIH") + print(f"Modèle : {MODEL}") + print("=" * 60) + + # Vérifier la clé API + if not args.dry_run: + api_key = os.environ.get("ANTHROPIC_API_KEY") + if not api_key: + print("Erreur: ANTHROPIC_API_KEY non définie.") + print(" export ANTHROPIC_API_KEY='sk-ant-...'") + sys.exit(1) + import anthropic + client = anthropic.Anthropic(api_key=api_key) + else: + client = None + + all_examples = [] + total_rules = 0 + + for filename_part, topic in FASCICULES.items(): + pdf_files = list(REF_DIR.glob(f"*{filename_part}*")) + pdf_files = [f for f in pdf_files if "redacted" not in f.name.lower() and "pseudonymise" not in str(f)] + if not pdf_files: + print(f"\n{topic} : PDF non trouvé, skip") + continue + + pdf_path = pdf_files[0] + print(f"\n{'─'*40}") + print(f"{topic} ({pdf_path.name})") + + paragraphs = extract_paragraphs(pdf_path) + print(f" Paragraphes extraits : {len(paragraphs)}") + + scored_rules = extract_rules(paragraphs) + print(f" Règles/paragraphes pertinents : {len(scored_rules)}") + + if args.max > 0: + scored_rules = scored_rules[:args.max] + + total_rules += len(scored_rules) + + if args.dry_run: + for i, (rule, score) in enumerate(scored_rules[:5]): + print(f" [{i+1}] (score={score}) {rule[:120]}...") + continue + + # Générer les Q&A via Claude (3 exercices par règle) + n_ok = 0 + n_fail = 0 + for i, (rule_text, score) in enumerate(scored_rules): + prompt = GENERATION_PROMPT.format(topic=topic, rule_text=rule_text) + response_text = call_claude(client, prompt) + exercises = parse_llm_response(response_text) + + for ex in exercises: + if "scenario" in ex and "reponse" in ex: + example = make_chatml(ex["scenario"], ex["reponse"], topic) + all_examples.append(example) + n_ok += 1 + + if not exercises: + n_fail += 1 + + if (i + 1) % 10 == 0: + print(f" Progression : {i+1}/{len(scored_rules)} (exemples={n_ok}, échecs={n_fail})") + + print(f" Résultat : {n_ok} exemples générés, {n_fail} échecs") + + if args.dry_run: + print(f"\n[DRY RUN] Total règles détectées : {total_rules}") + print("Relancez sans --dry-run pour générer les exemples avec Claude.") + return + + # Mélanger et sauvegarder + random.shuffle(all_examples) + + with open(OUTPUT, "w") as f: + for ex in all_examples: + f.write(json.dumps(ex, ensure_ascii=False) + "\n") + + print(f"\n{'='*60}") + print(f"Total : {len(all_examples)} exemples → {OUTPUT}") + print(f"Taille : {OUTPUT.stat().st_size / 1024:.0f} Ko") + print(f"Règles sources : {total_rules}") + + +if __name__ == "__main__": + main() diff --git a/scripts/14_generate_negative_examples.py b/scripts/14_generate_negative_examples.py new file mode 100644 index 0000000..9c021b5 --- /dev/null +++ b/scripts/14_generate_negative_examples.py @@ -0,0 +1,396 @@ +#!/usr/bin/env python3 +""" +Génère des exemples négatifs : enseigner au modèle quand NE PAS coder. + +3 types d'exemples : + a) Codes rejetés (500) — symptômes couverts par le DP + b) Redondances sémantiques (200) — paires dominé/dominant + c) DAS non significatifs (300) — antécédents sans ressources consommées + +Sources : + - Cache Ollama (textes diagnostics réels) + - Règles PMSI (SEMANTIC_REDUNDANCIES, symptômes R00-R99) + - Templates + CIM-10 FHIR + +Produit : data/processed/negative_chatml.jsonl + +Usage : + python scripts/14_generate_negative_examples.py +""" + +import json +import random +from pathlib import Path + +random.seed(42) + +BASE = Path(__file__).resolve().parent.parent +T2A = BASE.parent / "t2a" +RAW = BASE / "data" / "raw" +OUTPUT = BASE / "data" / "processed" / "negative_chatml.jsonl" +OUTPUT.parent.mkdir(parents=True, exist_ok=True) + +SYSTEM_PROMPT = ( + "Tu es un médecin DIM expert en codage CIM-10 pour le PMSI français. " + "Tu sais quand un diagnostic ne doit PAS être codé." +) + + +def make_chatml(user: str, assistant: str) -> dict: + return { + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user}, + {"role": "assistant", "content": assistant}, + ] + } + + +def load_cim10_dict() -> dict[str, str]: + """Charge le dictionnaire CIM-10.""" + d = {} + for path in [T2A / "data" / "cim10_dict.json", T2A / "data" / "cim10_supplements.json"]: + if path.exists(): + d.update(json.loads(path.read_text())) + return d + + +def load_fhir_concepts() -> dict[str, dict]: + """Charge les concepts FHIR indexés par code.""" + fhir_path = RAW / "smt_cim10_fhir.json" + if not fhir_path.exists(): + return {} + data = json.loads(fhir_path.read_text()) + by_code = {} + for c in data.get("concept", []): + by_code[c["code"]] = c + return by_code + + +# ─── Type A : Symptômes couverts par le DP ──────────────────────────────────── + +# Symptômes R00-R99 fréquemment codés à tort comme DAS quand le DP les explique +SYMPTOM_DP_PAIRS = [ + # (symptôme_code, symptôme_label, dp_code, dp_label, explication) + ("R50.9", "Fièvre, sans précision", "A41.9", "Sepsis, sans précision", + "La fièvre est un symptôme cardinal du sepsis. Elle est couverte par le DP A41.9."), + ("R50.9", "Fièvre, sans précision", "J18.9", "Pneumonie, sans précision", + "La fièvre est un symptôme habituel de la pneumonie. Le DP J18.9 la couvre."), + ("R50.9", "Fièvre, sans précision", "N10", "Néphrite tubulo-interstitielle aiguë", + "La fièvre accompagne habituellement la pyélonéphrite aiguë (N10)."), + ("R06.0", "Dyspnée", "J44.1", "BPCO avec exacerbation aiguë", + "La dyspnée est le symptôme principal de l'exacerbation de BPCO."), + ("R06.0", "Dyspnée", "I50.9", "Insuffisance cardiaque, sans précision", + "La dyspnée est un symptôme majeur de l'insuffisance cardiaque."), + ("R06.0", "Dyspnée", "J96.0", "Insuffisance respiratoire aiguë", + "La dyspnée est couverte par le DP d'insuffisance respiratoire aiguë."), + ("R07.4", "Douleur thoracique, sans précision", "I21.9", "IDM aigu, sans précision", + "La douleur thoracique est le symptôme principal de l'IDM."), + ("R07.4", "Douleur thoracique, sans précision", "I20.0", "Angor instable", + "La douleur thoracique est le symptôme cardinal de l'angor instable."), + ("R10.4", "Douleur abdominale, sans précision", "K80.1", "Lithiase vésiculaire avec cholécystite", + "La douleur abdominale est couverte par le DP de cholécystite."), + ("R10.4", "Douleur abdominale, sans précision", "K35.8", "Appendicite aiguë, autre et sans précision", + "La douleur abdominale est le symptôme principal de l'appendicite."), + ("R11.0", "Nausées", "K29.1", "Gastrite aiguë", + "Les nausées sont un symptôme courant de la gastrite, couvert par le DP."), + ("R11.2", "Nausées avec vomissements, sans précision", "K56.6", "Occlusion intestinale, autre et sans précision", + "Les vomissements sont un signe cardinal de l'occlusion intestinale."), + ("R00.0", "Tachycardie, sans précision", "A41.9", "Sepsis, sans précision", + "La tachycardie est un critère diagnostique du sepsis."), + ("R00.0", "Tachycardie, sans précision", "I48.9", "Fibrillation auriculaire", + "La tachycardie est un symptôme de la FA, couverte par le DP."), + ("R41.0", "Désorientation, sans précision", "F05.9", "Delirium, sans précision", + "La désorientation est un symptôme constitutif du delirium."), + ("R40.0", "Somnolence", "S06.9", "Lésion traumatique intracrânienne, sans précision", + "La somnolence est un signe d'atteinte neurologique dans le trauma crânien."), + ("R42", "Étourdissements et éblouissements", "H81.1", "Vertige paroxystique bénin", + "Les étourdissements sont le symptôme principal du VPPB."), + ("R51", "Céphalée", "G43.9", "Migraine, sans précision", + "La céphalée est le symptôme principal de la migraine."), + ("R63.0", "Anorexie", "C16.9", "Tumeur maligne de l'estomac", + "L'anorexie est un symptôme fréquent du cancer gastrique, couvert par le DP."), + ("R53", "Malaise et fatigue", "D64.9", "Anémie, sans précision", + "La fatigue est un symptôme courant de l'anémie."), + ("R31", "Hématurie, sans précision", "N20.0", "Calcul du rein", + "L'hématurie est un symptôme classique de la lithiase rénale."), + ("R73.0", "Hyperglycémie SAI", "E11.9", "Diabète de type 2", + "L'hyperglycémie est un signe du diabète de type 2, couvert par le DP."), + ("R60.0", "Oedème localisé", "I50.0", "Insuffisance cardiaque congestive", + "L'oedème est un signe de l'ICC, couvert par le DP."), + ("R09.2", "Arrêt respiratoire", "J96.0", "Insuffisance respiratoire aiguë", + "L'arrêt respiratoire est la forme extrême de l'insuffisance respiratoire aiguë."), + ("R57.0", "Choc cardiogénique", "I21.9", "IDM aigu", + "Le choc cardiogénique comme complication de l'IDM peut se coder comme DAS (mobilise des ressources), mais s'il est le tableau initial il est couvert par le DP."), +] + + +def generate_symptom_dp_examples(cim10_dict: dict, target: int = 500) -> list[dict]: + """Générer des exemples de symptômes couverts par le DP.""" + examples = [] + + # Templates de variation + user_templates = [ + "Code ce diagnostic : {symptom_label}\nTYPE : DAS\nCONTEXTE : DP = {dp_code} ({dp_label})", + "Le patient est hospitalisé pour {dp_label} (DP : {dp_code}). Faut-il coder {symptom_label} ({symptom_code}) en DAS ?", + "DAS candidat : {symptom_code} ({symptom_label})\nDP du séjour : {dp_code} ({dp_label})\nCe DAS est-il pertinent ?", + ] + + assistant_template_null = json.dumps({ + "code": None, + "confidence": "high", + "justification": "{explanation} Règle PMSI : ne pas coder un symptôme (R00-R99) si le diagnostic qui l'explique est déjà codé comme DP." + }, ensure_ascii=False) + + # Générer depuis les paires prédéfinies (avec variations de templates) + for sym_code, sym_label, dp_code, dp_label, expl in SYMPTOM_DP_PAIRS: + for tmpl in user_templates: + user = tmpl.format( + symptom_code=sym_code, symptom_label=sym_label, + dp_code=dp_code, dp_label=dp_label + ) + assistant = assistant_template_null.replace("{explanation}", expl) + examples.append(make_chatml(user, assistant)) + + # Générer des variations supplémentaires pour atteindre la cible + # Utiliser tous les codes R du dictionnaire CIM-10 + r_codes = [(c, l) for c, l in cim10_dict.items() if c.startswith("R") and len(c) >= 4] + non_r_codes = [(c, l) for c, l in cim10_dict.items() + if not c.startswith("R") and not c.startswith("Z") + and c[0].isalpha() and len(c) >= 4 and len(l) > 5] + + while len(examples) < target and r_codes and non_r_codes: + sym_code, sym_label = random.choice(r_codes) + dp_code, dp_label = random.choice(non_r_codes) + tmpl = random.choice(user_templates) + user = tmpl.format( + symptom_code=sym_code, symptom_label=sym_label, + dp_code=dp_code, dp_label=dp_label + ) + expl = f"{sym_label} est un symptôme (code R) potentiellement couvert par le DP {dp_code} ({dp_label})." + assistant = json.dumps({ + "code": None, + "confidence": "medium", + "justification": f"{expl} Règle PMSI : ne pas coder un symptôme en DAS s'il est expliqué par le DP. Vérifier si le symptôme a nécessité une prise en charge spécifique supplémentaire." + }, ensure_ascii=False) + examples.append(make_chatml(user, assistant)) + + random.shuffle(examples) + return examples[:target] + + +# ─── Type B : Redondances sémantiques ───────────────────────────────────────── + +SEMANTIC_REDUNDANCIES = [ + # (dominated_prefix, dominant_prefixes, explanation) + ("I10", ["I11", "I12", "I13"], + "I10 (HTA essentielle) est redondant quand I11/I12/I13 est présent. Le code hypertensif spécifique inclut la composante HTA."), + ("N30", ["N39"], + "N30 (cystite) est redondant quand N39.0 (infection urinaire) est présent. L'infection urinaire couvre la cystite."), + ("J18", ["J15", "J16"], + "J18 (pneumonie SAI) est redondant quand J15/J16 (pneumonie spécifique) est présent. Le code spécifique prime."), + ("E11.9", ["E11.0", "E11.1", "E11.2", "E11.3", "E11.4", "E11.5", "E11.6", "E11.7"], + "E11.9 (diabète type 2 SAI) est redondant si un sous-code E11.x spécifiant une complication est présent."), + ("I25.9", ["I25.1", "I25.2", "I25.5"], + "I25.9 (cardiopathie ischémique chronique SAI) est redondant si un sous-code I25.x plus spécifique est présent."), + ("N18.9", ["N18.1", "N18.2", "N18.3", "N18.4", "N18.5"], + "N18.9 (IRC SAI) est redondant si un stade N18.x spécifique est présent."), + ("J44.9", ["J44.0", "J44.1"], + "J44.9 (BPCO SAI) est redondant si J44.0 (BPCO avec infection) ou J44.1 (BPCO avec exacerbation) est présent."), + ("K21.9", ["K21.0"], + "K21.9 (RGO SAI) est redondant si K21.0 (RGO avec œsophagite) est présent."), +] + + +def generate_redundancy_examples(cim10_dict: dict, target: int = 200) -> list[dict]: + """Générer des exemples de redondances sémantiques.""" + examples = [] + + user_templates = [ + "DAS candidats : {dominated} ({dom_label}), {dominant} ({sup_label})\nLesquels garder ?", + "Le codage inclut {dominated} et {dominant}. Y a-t-il une redondance ?", + "Vérification de codage :\n- DAS1 : {dominated} ({dom_label})\n- DAS2 : {dominant} ({sup_label})\nCes deux DAS sont-ils tous les deux pertinents ?", + ] + + for dominated_prefix, dominant_prefixes, explanation in SEMANTIC_REDUNDANCIES: + # Trouver des codes réels pour chaque préfixe + dom_codes = [(c, l) for c, l in cim10_dict.items() if c.startswith(dominated_prefix) and len(l) > 3] + sup_codes = [(c, l) for c, l in cim10_dict.items() + if any(c.startswith(dp) for dp in dominant_prefixes) and len(l) > 3] + + if not dom_codes or not sup_codes: + continue + + for _ in range(target // len(SEMANTIC_REDUNDANCIES) + 1): + dom_code, dom_label = random.choice(dom_codes) + sup_code, sup_label = random.choice(sup_codes) + tmpl = random.choice(user_templates) + + user = tmpl.format( + dominated=dom_code, dom_label=dom_label, + dominant=sup_code, sup_label=sup_label + ) + assistant = json.dumps({ + "garder": [sup_code], + "retirer": [dom_code], + "justification": explanation + }, ensure_ascii=False) + examples.append(make_chatml(user, assistant)) + + random.shuffle(examples) + return examples[:target] + + +# ─── Type C : DAS non significatifs ─────────────────────────────────────────── + +# Diagnostics fréquemment mentionnés dans les antécédents mais sans ressources +NON_SIGNIFICANT_DAS = [ + ("J30.1", "Rhinite allergique due au pollen", "rhinite allergique mentionnée dans les antécédents"), + ("J45.9", "Asthme, sans précision", "asthme stable mentionné dans les antécédents"), + ("M54.5", "Lombalgie basse", "lombalgie chronique mentionnée dans les antécédents"), + ("K21.0", "RGO avec œsophagite", "RGO mentionné dans les antécédents"), + ("H52.1", "Myopie", "myopie mentionnée dans les antécédents"), + ("E78.0", "Hypercholestérolémie pure", "hypercholestérolémie dans les antécédents, traitement habituel"), + ("E03.9", "Hypothyroïdie, sans précision", "hypothyroïdie sous Lévothyrox dans les antécédents"), + ("F32.0", "Épisode dépressif léger", "dépression traitée mentionnée dans les antécédents"), + ("G47.3", "Apnée du sommeil", "SAOS appareillé mentionné dans les antécédents"), + ("M81.9", "Ostéoporose sans fracture", "ostéoporose connue dans les antécédents"), + ("H40.1", "Glaucome primaire à angle ouvert", "glaucome traité dans les antécédents"), + ("I84.1", "Hémorroïdes internes avec complication", "hémorroïdes mentionnées dans les antécédents"), + ("K58.9", "Syndrome du côlon irritable", "SCI mentionné dans les antécédents"), + ("L40.0", "Psoriasis vulgaire", "psoriasis stable dans les antécédents"), + ("N40", "Hyperplasie de la prostate", "HBP traitée dans les antécédents"), + ("E66.9", "Obésité, sans précision", "obésité mentionnée, pas de prise en charge spécifique"), + ("Z87.1", "Antécédents personnels de maladies de l'appareil digestif", "antécédent de chirurgie digestive ancienne"), + ("Z86.7", "Antécédents personnels de maladies de l'appareil circulatoire", "antécédent d'AVC il y a 5 ans"), + ("Z92.1", "Antécédents de traitement anticoagulant au long cours", "patient sous AVK au long cours"), + ("F17.2", "Dépendance au tabac", "tabagisme actif mais pas de sevrage pendant le séjour"), +] + + +def generate_non_significant_examples(target: int = 300) -> list[dict]: + """Générer des exemples de DAS non significatifs (antécédents sans ressources).""" + examples = [] + + user_templates = [ + "Le patient a {context}. Faut-il le coder en DAS ?", + "Antécédent : {code} ({label}). Le patient est hospitalisé pour une autre raison. Ce diagnostic doit-il être codé comme DAS ?", + "Le dossier mentionne : {context}. Est-ce un DAS pertinent pour le séjour ?\nDP : {dp_code} ({dp_label})", + "Lors du codage du séjour (DP : {dp_code}), le dossier fait état de {context}. Coder {code} en DAS ?", + ] + + dp_examples = [ + ("K80.1", "Lithiase vésiculaire avec cholécystite"), + ("S72.0", "Fracture du col du fémur"), + ("I63.9", "Infarctus cérébral, sans précision"), + ("J18.1", "Pneumonie lobaire, sans précision"), + ("I21.9", "IDM aigu, sans précision"), + ("C34.9", "Tumeur maligne des bronches ou du poumon"), + ("K35.8", "Appendicite aiguë"), + ("N20.0", "Calcul du rein"), + ("G45.9", "AIT, sans précision"), + ("A41.9", "Sepsis, sans précision"), + ("C18.9", "Tumeur maligne du côlon"), + ("I48.9", "Fibrillation auriculaire"), + ] + + # Générer en croisant chaque DAS avec plusieurs DPs + # On veut ~240 négatifs pour avoir assez de marge avec les positifs + combos_per_das = max(3, (target * 3 // 4) // len(NON_SIGNIFICANT_DAS) + 1) + for code, label, context in NON_SIGNIFICANT_DAS: + selected_dps = random.sample(dp_examples, min(combos_per_das, len(dp_examples))) + for dp_code, dp_label in selected_dps: + tmpl = random.choice(user_templates) + user = tmpl.format( + code=code, label=label, context=context, + dp_code=dp_code, dp_label=dp_label + ) + assistant = json.dumps({ + "coder": False, + "code": code, + "justification": f"Un DAS ne doit être codé que s'il a nécessité des ressources supplémentaires pendant le séjour (examens, traitements, surveillance spécifique). {label} mentionné uniquement dans les antécédents, sans prise en charge spécifique durant le séjour, ne justifie pas un DAS." + }, ensure_ascii=False) + examples.append(make_chatml(user, assistant)) + + # Ajouter des exemples positifs (contraste) — DAS qui DOIVENT être codés + positive_das = [ + ("E11.6", "Diabète de type 2 avec complications", "diabète déséquilibré ayant nécessité une adaptation thérapeutique", + "Le diabète a mobilisé des ressources supplémentaires (adaptation insuline, surveillance glycémique renforcée)."), + ("I10", "HTA essentielle", "HTA sévère ayant nécessité un traitement IV pendant le séjour", + "L'HTA a nécessité un traitement intraveineux spécifique, justifiant le codage en DAS."), + ("N18.4", "IRC stade 4", "IRC ayant nécessité une adaptation posologique de tous les médicaments", + "L'IRC a mobilisé des ressources (adaptation posologique, surveillance créatinine quotidienne)."), + ("E87.1", "Hypo-osmolalité et hyponatrémie", "hyponatrémie sévère découverte pendant le séjour", + "L'hyponatrémie a nécessité un bilan étiologique et un traitement spécifique."), + ("J96.0", "Insuffisance respiratoire aiguë", "détresse respiratoire ayant nécessité une oxygénothérapie", + "L'insuffisance respiratoire a mobilisé des ressources (oxygénothérapie, surveillance SpO2, GDS)."), + ("N17.9", "Insuffisance rénale aiguë", "IRA survenue pendant le séjour ayant nécessité une surveillance biologique quotidienne", + "L'IRA a nécessité des bilans répétés et une adaptation thérapeutique, justifiant le DAS."), + ("D62", "Anémie posthémorragique aiguë", "anémie aiguë ayant nécessité une transfusion de 2 CGR", + "L'anémie a mobilisé des ressources (transfusion, surveillance post-transfusionnelle)."), + ("E87.6", "Hypokaliémie", "hypokaliémie sévère découverte en biologie nécessitant une supplémentation IV", + "L'hypokaliémie a nécessité un traitement spécifique IV, justifiant le DAS."), + ] + + combos_per_pos = max(3, (target // 4) // len(positive_das) + 1) + for code, label, context, justification in positive_das: + selected_dps = random.sample(dp_examples, min(combos_per_pos, len(dp_examples))) + for dp_code, dp_label in selected_dps: + user = f"Le patient est hospitalisé pour {dp_label} (DP : {dp_code}). Il présente aussi : {context}. Faut-il coder {code} ({label}) en DAS ?" + assistant = json.dumps({ + "coder": True, + "code": code, + "justification": justification + }, ensure_ascii=False) + examples.append(make_chatml(user, assistant)) + + random.shuffle(examples) + return examples[:target] + + +def main(): + print("=" * 60) + print("Génération d'exemples négatifs (quand NE PAS coder)") + print("=" * 60) + + print("\nChargement du dictionnaire CIM-10...") + cim10_dict = load_cim10_dict() + print(f" {len(cim10_dict)} codes") + + all_examples = [] + + # Type A : Symptômes couverts par le DP + print("\nType A : Symptômes couverts par le DP...") + symptom_examples = generate_symptom_dp_examples(cim10_dict, target=500) + print(f" → {len(symptom_examples)} exemples") + all_examples.extend(symptom_examples) + + # Type B : Redondances sémantiques + print("\nType B : Redondances sémantiques...") + redundancy_examples = generate_redundancy_examples(cim10_dict, target=200) + print(f" → {len(redundancy_examples)} exemples") + all_examples.extend(redundancy_examples) + + # Type C : DAS non significatifs + print("\nType C : DAS non significatifs...") + non_sig_examples = generate_non_significant_examples(target=300) + print(f" → {len(non_sig_examples)} exemples") + all_examples.extend(non_sig_examples) + + # Mélanger et sauvegarder + random.shuffle(all_examples) + + with open(OUTPUT, "w") as f: + for ex in all_examples: + f.write(json.dumps(ex, ensure_ascii=False) + "\n") + + print(f"\n{'='*60}") + print(f"Total : {len(all_examples)} exemples → {OUTPUT}") + print(f" Type A (symptômes/DP) : {len(symptom_examples)}") + print(f" Type B (redondances) : {len(redundancy_examples)}") + print(f" Type C (non signif.) : {len(non_sig_examples)}") + print(f"Taille : {OUTPUT.stat().st_size / 1024:.0f} Ko") + + +if __name__ == "__main__": + main() diff --git a/scripts/15_generate_discrimination.py b/scripts/15_generate_discrimination.py new file mode 100644 index 0000000..3e0f260 --- /dev/null +++ b/scripts/15_generate_discrimination.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python3 +""" +Exercices de discrimination entre codes CIM-10 siblings (même parent). + +Stratégie : + - Utilise la hiérarchie FHIR pour identifier les groupes de siblings + - Focus sur les top 100 familles CIM-10 les plus fréquentes du pipeline + - Pour chaque groupe, génère un scénario clinique via Claude Opus 4.6 + - La réponse explique pourquoi un code et pas l'autre + +Cible : 800 exemples + +Sources : smt_cim10_fhir.json + cache Ollama (codes fréquents) +Nécessite : ANTHROPIC_API_KEY en variable d'environnement + +Usage : + python scripts/15_generate_discrimination.py [--dry-run] [--max N] +""" + +import argparse +import json +import os +import random +import re +import sys +import time +from collections import Counter +from pathlib import Path + +random.seed(42) + +BASE = Path(__file__).resolve().parent.parent +T2A = BASE.parent / "t2a" +RAW = BASE / "data" / "raw" +OUTPUT = BASE / "data" / "processed" / "discrimination_chatml.jsonl" +OUTPUT.parent.mkdir(parents=True, exist_ok=True) + +MODEL = "claude-opus-4-6" + +SYSTEM_PROMPT = ( + "Tu es un médecin DIM expert en codage CIM-10 pour le PMSI français. " + "Tu sais discriminer les codes CIM-10 proches (siblings) et choisir le plus approprié." +) + +GENERATION_PROMPT = """Tu es un formateur DIM. Génère un exercice de discrimination entre codes CIM-10 proches. + +CODES À DISCRIMINER (même catégorie {parent}) : +{codes_list} + +Génère un objet JSON avec : +1. "scenario" : un cas clinique réaliste (2-3 phrases) où le choix entre ces codes est subtil +2. "reponse" : un objet JSON contenant : + - "analyse_clinique" : interprétation du cas clinique + - "codes_candidats" : les 2-3 codes candidats et pourquoi chacun est envisagé + - "discrimination" : la différence clé entre ces codes (inclusions, exclusions, spécificité) + - "code" : le code correct pour ce scénario + - "confidence" : "high" + - "justification" : pourquoi CE code et pas les autres + +Réponds UNIQUEMENT avec le JSON, sans texte avant/après.""" + + +def load_fhir() -> tuple[list, dict]: + """Charger les concepts FHIR.""" + fhir_path = RAW / "smt_cim10_fhir.json" + data = json.loads(fhir_path.read_text()) + concepts = data["concept"] + by_code = {c["code"]: c for c in concepts} + return concepts, by_code + + +def get_parent(concept: dict) -> str: + for p in concept.get("property", []): + if p["code"] == "parent": + return p.get("valueCode", "") + return "" + + +def get_type(concept: dict) -> str: + for p in concept.get("property", []): + if p["code"] == "type": + return p.get("valueString", "") + return "" + + +def get_inclusion_note(concept: dict) -> str: + for p in concept.get("property", []): + if p["code"] == "inclusionNote": + return p.get("valueString", "") + return "" + + +def get_exclusion_note(concept: dict) -> str: + for p in concept.get("property", []): + if p["code"] == "exclusionNote": + return p.get("valueString", "") + return "" + + +def get_frequent_families() -> Counter: + """Extraire les familles CIM-10 les plus fréquentes depuis le cache Ollama.""" + families = Counter() + for cache_path in [T2A / "data" / "ollama_cache.json", T2A / "data" / "ollama_cache_gemma3.bak"]: + if not cache_path.exists(): + continue + data = json.loads(cache_path.read_text()) + for entry in data.get("entries", {}).values(): + code = entry.get("code", "") + if code and len(code) >= 3 and code[0].isalpha(): + families[code[:3]] += 1 + return families + + +def build_sibling_groups(concepts: list, by_code: dict) -> dict[str, list[dict]]: + """Grouper les codes par parent (siblings).""" + children_by_parent = {} + for c in concepts: + if get_type(c) != "category": + continue + parent = get_parent(c) + if parent and parent in by_code: + children_by_parent.setdefault(parent, []).append(c) + return children_by_parent + + +def format_codes_for_prompt(siblings: list[dict]) -> str: + """Formater les codes pour le prompt LLM.""" + lines = [] + for sib in siblings: + code = sib["code"] + display = sib["display"] + incl = get_inclusion_note(sib) + excl = get_exclusion_note(sib) + line = f"- {code} : {display}" + if incl: + line += f"\n Comprend : {incl[:200]}" + if excl: + line += f"\n Exclut : {excl[:200]}" + lines.append(line) + return "\n".join(lines) + + +def call_claude(client, prompt: str, max_retries: int = 2) -> str | None: + """Appel Claude Opus 4.6 via API Anthropic avec retry.""" + for attempt in range(max_retries + 1): + try: + response = client.messages.create( + model=MODEL, + max_tokens=2048, + temperature=0.7, + messages=[{"role": "user", "content": prompt}], + ) + return response.content[0].text + except Exception as e: + if attempt < max_retries: + wait = 2 ** (attempt + 1) + print(f" Retry in {wait}s: {e}") + time.sleep(wait) + else: + print(f" Claude error: {e}") + return None + + +def parse_llm_response(response_text: str) -> dict | None: + """Parse la réponse JSON du LLM.""" + if not response_text: + return None + text = response_text.strip() + if "```json" in text: + text = text.split("```json", 1)[1].split("```", 1)[0].strip() + elif "```" in text: + text = text.split("```", 1)[1].split("```", 1)[0].strip() + + try: + data = json.loads(text) + if "scenario" in data and "reponse" in data: + return data + except json.JSONDecodeError: + pass + + # Fallback + brace_start = text.find("{") + if brace_start >= 0: + depth = 0 + for i in range(brace_start, len(text)): + if text[i] == "{": + depth += 1 + elif text[i] == "}": + depth -= 1 + if depth == 0: + try: + data = json.loads(text[brace_start:i+1]) + if "scenario" in data: + return data + except json.JSONDecodeError: + break + return None + + +def make_chatml(scenario: str, response: dict, parent_code: str, siblings_desc: str) -> dict: + """Créer un exemple ChatML.""" + user_content = ( + f"Cas clinique :\n{scenario}\n\n" + f"Codes CIM-10 candidats (catégorie {parent_code}) :\n{siblings_desc}\n\n" + "Quel code est le plus approprié ? Explique ton raisonnement de discrimination." + ) + assistant_content = json.dumps(response, ensure_ascii=False) + + return { + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_content}, + {"role": "assistant", "content": assistant_content}, + ] + } + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--dry-run", action="store_true", help="Pas d'appel LLM") + parser.add_argument("--max", type=int, default=800, help="Max exemples à générer") + args = parser.parse_args() + + print("=" * 60) + print("Génération d'exercices de discrimination CIM-10") + print(f"Modèle : {MODEL}") + print("=" * 60) + + # Vérifier la clé API + if not args.dry_run: + api_key = os.environ.get("ANTHROPIC_API_KEY") + if not api_key: + print("Erreur: ANTHROPIC_API_KEY non définie.") + print(" export ANTHROPIC_API_KEY='sk-ant-...'") + sys.exit(1) + import anthropic + client = anthropic.Anthropic(api_key=api_key) + else: + client = None + + print("\nChargement FHIR...") + concepts, by_code = load_fhir() + print(f" {len(concepts)} concepts") + + print("\nIdentification des familles fréquentes (cache Ollama)...") + freq_families = get_frequent_families() + top_families = [code for code, _ in freq_families.most_common(150)] + print(f" Top familles : {len(top_families)} (ex: {', '.join(top_families[:10])})") + + print("\nConstruction des groupes de siblings...") + sibling_groups = build_sibling_groups(concepts, by_code) + print(f" {len(sibling_groups)} groupes") + + # Filtrer : 2-8 siblings, prioriser les familles fréquentes + candidates = [] + for parent_code, siblings in sibling_groups.items(): + if len(siblings) < 2 or len(siblings) > 8: + continue + priority = 2 if parent_code[:3] in top_families else 0 + n_with_notes = sum(1 for s in siblings if get_inclusion_note(s) or get_exclusion_note(s)) + priority += n_with_notes + candidates.append((parent_code, siblings, priority)) + + candidates.sort(key=lambda x: -x[2]) + print(f" Candidats filtrés (2-8 siblings) : {len(candidates)}") + + target = min(args.max, len(candidates)) + candidates = candidates[:target] + + if args.dry_run: + for parent_code, siblings, prio in candidates[:20]: + parent_display = by_code[parent_code]["display"] if parent_code in by_code else "?" + sib_codes = ", ".join(s["code"] for s in siblings) + print(f" [{prio}] {parent_code} ({parent_display}): {sib_codes}") + print(f"\n[DRY RUN] {len(candidates)} groupes à traiter. Relancez sans --dry-run.") + return + + # Générer les exercices via Claude + examples = [] + n_ok = 0 + n_fail = 0 + + for i, (parent_code, siblings, _) in enumerate(candidates): + if len(siblings) > 4: + selected = random.sample(siblings, 4) + else: + selected = siblings + + codes_list = format_codes_for_prompt(selected) + parent_display = by_code[parent_code]["display"] if parent_code in by_code else parent_code + + prompt = GENERATION_PROMPT.format(parent=f"{parent_code} ({parent_display})", codes_list=codes_list) + response_text = call_claude(client, prompt) + parsed = parse_llm_response(response_text) + + if parsed and "scenario" in parsed and "reponse" in parsed: + siblings_desc = "\n".join(f"- {s['code']} : {s['display']}" for s in selected) + example = make_chatml(parsed["scenario"], parsed["reponse"], parent_code, siblings_desc) + examples.append(example) + n_ok += 1 + else: + n_fail += 1 + + if (i + 1) % 50 == 0: + print(f" Progression : {i+1}/{len(candidates)} (ok={n_ok}, fail={n_fail})") + + # Mélanger et sauvegarder + random.shuffle(examples) + + with open(OUTPUT, "w") as f: + for ex in examples: + f.write(json.dumps(ex, ensure_ascii=False) + "\n") + + print(f"\n{'='*60}") + print(f"Total : {len(examples)} exemples → {OUTPUT}") + print(f" OK : {n_ok}, Échecs : {n_fail}") + print(f"Taille : {OUTPUT.stat().st_size / 1024:.0f} Ko") + + +if __name__ == "__main__": + main() diff --git a/scripts/16_parse_guide_metho.py b/scripts/16_parse_guide_metho.py new file mode 100644 index 0000000..767c306 --- /dev/null +++ b/scripts/16_parse_guide_metho.py @@ -0,0 +1,351 @@ +#!/usr/bin/env python3 +""" +Extraction des règles du Guide Méthodologique MCO 2026 → Q&A raisonnement. + +Le Guide Méthodologique est le document de référence ATIH qui définit : + - Les règles de codage du DP, DR, DAS + - Les règles CMA et niveaux de sévérité + - Les règles de codage des actes CCAM + - Les cas particuliers (séjours multi-unités, transferts, etc.) + +Stratégie : + - Parser le PDF en sections + - Extraire les règles de codage + - Générer des Q&A directes (sans LLM) pour les définitions + - Générer des Q&A de raisonnement via Claude Opus 4.6 pour les règles + - Focus sur les règles DP/DAS/CMA les plus applicables + +Cible : 500 exemples (183 directs + ~320 LLM) + +Source : data/raw/guide_methodo_mco_2026.pdf +Nécessite : ANTHROPIC_API_KEY en variable d'environnement + +Usage : + python scripts/16_parse_guide_metho.py [--dry-run] [--max N] +""" + +import argparse +import json +import os +import random +import re +import sys +import time +from pathlib import Path + +import pdfplumber + +random.seed(42) + +BASE = Path(__file__).resolve().parent.parent +PDF_PATH = BASE / "data" / "raw" / "guide_methodo_mco_2026.pdf" +OUTPUT = BASE / "data" / "processed" / "guide_metho_chatml.jsonl" +OUTPUT.parent.mkdir(parents=True, exist_ok=True) + +MODEL = "claude-opus-4-6" + +SYSTEM_PROMPT = ( + "Tu es un médecin DIM expert en codage PMSI. " + "Tu t'appuies sur le Guide Méthodologique de production des informations " + "relatives à l'activité MCO (ATIH 2026)." +) + +# Patterns pour détecter les règles dans le guide +RULE_PATTERNS = [ + re.compile(r"(?:le\s+)?diagnostic\s+principal\s+(?:est|doit|sera|correspond)", re.IGNORECASE), + re.compile(r"(?:le\s+)?diagnostic\s+(?:relié|associé)\s+(?:est|doit|sera)", re.IGNORECASE), + re.compile(r"(?:la\s+)?CMA\s+(?:est|correspond|ne\s+peut)", re.IGNORECASE), + re.compile(r"(?:le\s+)?DAS\s+(?:est|doit|sera|ne\s+doit\s+pas)", re.IGNORECASE), + re.compile(r"(?:le\s+)?DP\s+(?:est|doit|sera|ne\s+doit\s+pas)", re.IGNORECASE), + re.compile(r"(?:le\s+)?DR\s+(?:est|doit|sera)", re.IGNORECASE), + re.compile(r"(?:on\s+)?code(?:ra)?\s+", re.IGNORECASE), + re.compile(r"ne\s+(?:pas|doit\s+pas|faut\s+pas)\s+(?:coder|enregistrer|recueillir)", re.IGNORECASE), + re.compile(r"séjour\s+multi-?uni", re.IGNORECASE), + re.compile(r"(?:transfert|mutation)\s+", re.IGNORECASE), + re.compile(r"(?:sévérité|niveau\s+de\s+sévérité|complication)", re.IGNORECASE), + re.compile(r"(?:ressources?\s+supplémentaires?|consomm)", re.IGNORECASE), + re.compile(r"(?:groupage|GHM|GHS|CMD)", re.IGNORECASE), +] + +GENERATION_PROMPT = """Tu es un formateur DIM. À partir de cette règle du Guide Méthodologique MCO, génère un exercice de raisonnement. + +RÈGLE (source : Guide Méthodologique MCO 2026, section « {section} ») : +{rule_text} + +Génère un objet JSON avec : +1. "scenario" : un cas clinique réaliste et concis (2-3 phrases) qui illustre l'application de cette règle +2. "reponse" : un objet JSON contenant : + - "analyse_clinique" : interprétation du cas + - "regle_pmsi" : la règle du guide méthodologique qui s'applique (citée) + - "code" : le code CIM-10 ou l'action de codage correcte + - "confidence" : "high" + - "justification" : pourquoi cette règle s'applique à ce cas + +Réponds UNIQUEMENT avec le JSON, sans texte avant/après.""" + +# Templates pour les Q&A directes (sans LLM) +DIRECT_QA_TEMPLATES = [ + ("Quelle est la définition du {concept} selon le Guide Méthodologique MCO ?", + "Selon le Guide Méthodologique MCO 2026 :\n\n{text}"), + ("Quelles sont les règles de codage pour {concept} selon le Guide Méthodologique ?", + "Le Guide Méthodologique MCO 2026 précise les règles suivantes pour {concept} :\n\n{text}"), +] + + +def extract_sections(pdf_path: Path) -> list[tuple[str, str]]: + """Extraire les sections du Guide Méthodologique.""" + with pdfplumber.open(str(pdf_path)) as pdf: + full_text = "" + for page in pdf.pages: + text = page.extract_text() or "" + full_text += text + "\n\n" + + lines = [] + for line in full_text.split("\n"): + stripped = line.strip() + if re.match(r"^\s*\d{1,3}\s*$", stripped): + continue + if re.match(r"^(Guide méthodologique|ATIH|Page\s+\d)", stripped, re.IGNORECASE): + continue + if "...." in stripped: + continue + lines.append(line) + + clean_text = "\n".join(lines) + + section_pattern = re.compile( + r"^(\d+(?:\.\d+)*\.?)\s+([A-ZÀÂÉÈÊËÎÏÔÙÛÜÇ].*)", + re.MULTILINE + ) + + sections = [] + matches = list(section_pattern.finditer(clean_text)) + + for i, match in enumerate(matches): + num = match.group(1).rstrip(".") + title = match.group(2).strip() + start = match.end() + end = matches[i + 1].start() if i + 1 < len(matches) else len(clean_text) + body = clean_text[start:end].strip() + + if len(body) > 80: + section_title = f"{num}. {title}" + sections.append((section_title, body)) + + return sections + + +def extract_rule_paragraphs(sections: list[tuple[str, str]]) -> list[tuple[str, str]]: + """Extraire les paragraphes contenant des règles de codage.""" + rules = [] + + for section_title, body in sections: + paragraphs = re.split(r"\n\s*\n", body) + + for para in paragraphs: + para = para.strip() + if len(para) < 80: + continue + + matches = sum(1 for pat in RULE_PATTERNS if pat.search(para)) + if matches >= 1: + text = para[:1500] if len(para) > 1500 else para + rules.append((section_title, text)) + + return rules + + +def call_claude(client, prompt: str, max_retries: int = 2) -> str | None: + """Appel Claude Opus 4.6 via API Anthropic avec retry.""" + for attempt in range(max_retries + 1): + try: + response = client.messages.create( + model=MODEL, + max_tokens=2048, + temperature=0.7, + messages=[{"role": "user", "content": prompt}], + ) + return response.content[0].text + except Exception as e: + if attempt < max_retries: + wait = 2 ** (attempt + 1) + print(f" Retry in {wait}s: {e}") + time.sleep(wait) + else: + print(f" Claude error: {e}") + return None + + +def parse_llm_response(response_text: str) -> dict | None: + """Parse la réponse JSON du LLM.""" + if not response_text: + return None + text = response_text.strip() + if "```json" in text: + text = text.split("```json", 1)[1].split("```", 1)[0].strip() + elif "```" in text: + text = text.split("```", 1)[1].split("```", 1)[0].strip() + + try: + data = json.loads(text) + if "scenario" in data and "reponse" in data: + return data + except json.JSONDecodeError: + pass + + # Fallback + brace_start = text.find("{") + if brace_start >= 0: + depth = 0 + for i in range(brace_start, len(text)): + if text[i] == "{": + depth += 1 + elif text[i] == "}": + depth -= 1 + if depth == 0: + try: + data = json.loads(text[brace_start:i+1]) + if "scenario" in data: + return data + except json.JSONDecodeError: + break + return None + + +def make_chatml(system: str, user: str, assistant: str) -> dict: + return { + "messages": [ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + {"role": "assistant", "content": assistant}, + ] + } + + +def generate_direct_qa(sections: list[tuple[str, str]]) -> list[dict]: + """Générer des Q&A directes depuis les sections (sans LLM).""" + examples = [] + + concept_keywords = { + "diagnostic principal": ["DP", "diagnostic principal"], + "diagnostic relié": ["DR", "diagnostic relié"], + "diagnostic associé significatif": ["DAS", "diagnostic associé"], + "complication ou morbidité associée": ["CMA", "complication", "sévérité"], + "séjour multi-unités": ["multi-unité", "transfert", "mutation"], + "groupage GHM": ["GHM", "GHS", "groupage", "CMD"], + "actes CCAM": ["CCAM", "acte", "procédure"], + } + + for section_title, body in sections: + body_lower = body.lower() + + for concept, keywords in concept_keywords.items(): + if any(kw.lower() in body_lower for kw in keywords): + text = body[:2000] if len(body) > 2000 else body + + tmpl_q, tmpl_a = random.choice(DIRECT_QA_TEMPLATES) + user = tmpl_q.format(concept=f"{concept} (section {section_title})") + assistant = tmpl_a.format(concept=concept, text=text) + + examples.append(make_chatml(SYSTEM_PROMPT, user, assistant)) + break + + return examples + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--dry-run", action="store_true", help="Pas d'appel LLM") + parser.add_argument("--max", type=int, default=500, help="Max exemples à générer via LLM") + args = parser.parse_args() + + print("=" * 60) + print("Parsing du Guide Méthodologique MCO 2026") + print(f"Modèle : {MODEL}") + print("=" * 60) + + if not PDF_PATH.exists(): + print(f"PDF non trouvé : {PDF_PATH}") + return + + # Vérifier la clé API + if not args.dry_run: + api_key = os.environ.get("ANTHROPIC_API_KEY") + if not api_key: + print("Erreur: ANTHROPIC_API_KEY non définie.") + print(" export ANTHROPIC_API_KEY='sk-ant-...'") + sys.exit(1) + import anthropic + client = anthropic.Anthropic(api_key=api_key) + else: + client = None + + print(f"\nParsing de {PDF_PATH.name} ({PDF_PATH.stat().st_size / 1024 / 1024:.1f} Mo)...") + sections = extract_sections(PDF_PATH) + print(f" Sections extraites : {len(sections)}") + + # Q&A directes (sans LLM) + print("\nGénération des Q&A directes (sans LLM)...") + direct_examples = generate_direct_qa(sections) + print(f" → {len(direct_examples)} exemples directs") + + # Extraire les paragraphes de règles + print("\nExtraction des règles de codage...") + rules = extract_rule_paragraphs(sections) + print(f" Règles détectées : {len(rules)}") + + if args.dry_run: + for i, (section, rule) in enumerate(rules[:20]): + print(f" [{i+1}] [{section}] {rule[:100]}...") + print(f"\n[DRY RUN] {len(rules)} règles à traiter. Relancez sans --dry-run.") + return + + # Limiter le nombre de règles pour LLM + if len(rules) > args.max: + random.shuffle(rules) + rules = rules[:args.max] + + # Générer les Q&A via Claude + print(f"\nGénération via Claude ({len(rules)} règles)...") + llm_examples = [] + n_ok = 0 + n_fail = 0 + + for i, (section_title, rule_text) in enumerate(rules): + prompt = GENERATION_PROMPT.format(section=section_title, rule_text=rule_text) + response_text = call_claude(client, prompt) + parsed = parse_llm_response(response_text) + + if parsed and "scenario" in parsed and "reponse" in parsed: + user_content = ( + f"Cas clinique :\n{parsed['scenario']}\n\n" + f"Applique les règles du Guide Méthodologique MCO (section « {section_title} »)." + ) + assistant_content = json.dumps(parsed["reponse"], ensure_ascii=False) + llm_examples.append(make_chatml(SYSTEM_PROMPT, user_content, assistant_content)) + n_ok += 1 + else: + n_fail += 1 + + if (i + 1) % 50 == 0: + print(f" Progression : {i+1}/{len(rules)} (ok={n_ok}, fail={n_fail})") + + print(f" Résultat LLM : {n_ok} exemples, {n_fail} échecs") + + # Fusionner + all_examples = direct_examples + llm_examples + random.shuffle(all_examples) + + with open(OUTPUT, "w") as f: + for ex in all_examples: + f.write(json.dumps(ex, ensure_ascii=False) + "\n") + + print(f"\n{'='*60}") + print(f"Total : {len(all_examples)} exemples → {OUTPUT}") + print(f" Directs (sans LLM) : {len(direct_examples)}") + print(f" LLM (Claude) : {len(llm_examples)}") + print(f"Taille : {OUTPUT.stat().st_size / 1024:.0f} Ko") + + +if __name__ == "__main__": + main() diff --git a/scripts/export_checkpoint_gguf.py b/scripts/export_checkpoint_gguf.py new file mode 100644 index 0000000..73f45dd --- /dev/null +++ b/scripts/export_checkpoint_gguf.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 +""" +Export un checkpoint LoRA en GGUF pour Ollama. + +Stratégie : charge le modèle de base via Unsloth (pour bénéficier de +save_pretrained_gguf), puis applique le LoRA depuis le checkpoint +via PeftModel.from_pretrained en copiant la méthode GGUF. + +Usage: python scripts/export_checkpoint_gguf.py [--checkpoint models/pmsi-lora-checkpoints/checkpoint-7000] +""" + +import argparse +from pathlib import Path + +BASE = Path(__file__).resolve().parent.parent +OUTPUT = BASE / "models" + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint", default=str(OUTPUT / "pmsi-lora-checkpoints" / "checkpoint-7000")) + parser.add_argument("--model", default="unsloth/gemma-3-12b-it-bnb-4bit") + parser.add_argument("--max-seq-length", type=int, default=512) + parser.add_argument("--quant", default="q4_k_m") + args = parser.parse_args() + + from unsloth import FastLanguageModel + from peft import PeftModel + + checkpoint_dir = Path(args.checkpoint) + gguf_dir = OUTPUT / "pmsi-gguf-v2" + gguf_dir.mkdir(parents=True, exist_ok=True) + + # Étape 1 : Charger base via Unsloth (installe save_pretrained_gguf) + print(f"[1/3] Chargement du modèle de base via Unsloth...") + model, tokenizer = FastLanguageModel.from_pretrained( + model_name=args.model, + max_seq_length=args.max_seq_length, + dtype=None, + load_in_4bit=True, + ) + + # Sauver la méthode GGUF avant que PeftModel ne l'écrase + save_gguf_fn = model.save_pretrained_gguf + + # Étape 2 : Appliquer le LoRA checkpoint + print(f"[2/3] Application LoRA depuis {checkpoint_dir.name}...") + model = PeftModel.from_pretrained(model, str(checkpoint_dir)) + + # Réattacher la méthode GGUF d'Unsloth au modèle PeftModel + model.save_pretrained_gguf = save_gguf_fn.__func__.__get__(model, type(model)) + + # Vérifier que les poids LoRA sont bien chargés + lora_params = [n for n, p in model.named_parameters() if "lora" in n and p.requires_grad] + print(f" Paramètres LoRA actifs : {len(lora_params)}") + + # Étape 3 : Export GGUF + print(f"[3/3] Export GGUF ({args.quant})...") + model.save_pretrained_gguf( + str(gguf_dir), + tokenizer, + quantization_method=args.quant, + ) + + # Résultat + gguf_files = sorted(gguf_dir.glob("*.gguf"), key=lambda f: f.stat().st_size) + if not gguf_files: + print("Aucun GGUF produit !") + return + + final_gguf = gguf_files[0] + for g in gguf_files: + size_gb = g.stat().st_size / 1024**3 + print(f" {g.name} ({size_gb:.1f} Go)") + + # Modelfile + modelfile_path = gguf_dir / "Modelfile" + with open(modelfile_path, "w") as f: + f.write(f"FROM {final_gguf.name}\n\n") + f.write("PARAMETER temperature 0.3\n") + f.write("PARAMETER top_p 0.9\n") + f.write("PARAMETER num_ctx 8192\n") + + print(f"\nTerminé !") + print(f" GGUF : {final_gguf}") + print(f"\nPour importer dans Ollama :") + print(f" cd {gguf_dir}") + print(f" ollama create pmsi-coder -f Modelfile") + + +if __name__ == "__main__": + main()