feat: rééquilibrage dataset LoRA — raisonnement DIM vs mémorisation

Passe de 95/3/2 (lookups/raisonnement/règles) à ~31/49/20.
Dataset cible ~16K exemples denses (vs 66K de lookups avant).

Modifiés :
- 03_convert_cache.py : cache complet 1840 entrées (actuel + backup)
- 04_build_dataset.py : subsampling agressif (CIM-10 1.5K, CCAM 1.5K,
  CoCoA 2K) + sélection intelligente priorisant le raisonnement
- 12_generate_pipeline_examples.py : 3 templates (court + long + CPAM),
  cache actuel, cible ~2800 exemples

Créés :
- 13_generate_fascicule_reasoning.py : parsing 10 fascicules ATIH,
  génération Q&A raisonnement via Claude Opus 4.6 (~450 exemples)
- 14_generate_negative_examples.py : 1000 exemples négatifs
  (symptômes/DP, redondances sémantiques, DAS non significatifs)
- 15_generate_discrimination.py : 800 exercices de discrimination
  entre codes siblings CIM-10 via Claude Opus 4.6
- 16_parse_guide_metho.py : extraction Guide Méthodologique MCO 2026,
  Q&A directes + raisonnement via Claude Opus 4.6 (~500 exemples)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
dom
2026-02-16 19:42:33 +01:00
commit 06100df236
21 changed files with 6106 additions and 0 deletions

View File

@@ -0,0 +1,332 @@
#!/usr/bin/env python3
"""
Phase 1A — Génération de paires ChatML CIM-10 depuis le FHIR JSON.
Sources : smt_cim10_fhir.json (19 161 concepts)
Produit : data/processed/cim10_chatml.jsonl
Types d'exemples générés :
1. code → description (lookup)
2. description → code (codage)
3. discrimination entre codes frères (même parent)
4. inclusions/exclusions (ce qui est compris / exclu d'un code)
"""
import json
import random
from pathlib import Path
random.seed(42)
BASE = Path(__file__).resolve().parent.parent
RAW = BASE / "data" / "raw"
OUT = BASE / "data" / "processed"
OUT.mkdir(parents=True, exist_ok=True)
SYSTEM_MSG = "Tu es un médecin DIM expert en codage CIM-10 pour le PMSI français."
# --- Chapitres CIM-10 ---
CHAPTERS = {
"A": "Certaines maladies infectieuses et parasitaires",
"B": "Certaines maladies infectieuses et parasitaires",
"C": "Tumeurs",
"D": "Tumeurs / Maladies du sang",
"E": "Maladies endocriniennes, nutritionnelles et métaboliques",
"F": "Troubles mentaux et du comportement",
"G": "Maladies du système nerveux",
"H": "Maladies de l'œil et de l'oreille",
"I": "Maladies de l'appareil circulatoire",
"J": "Maladies de l'appareil respiratoire",
"K": "Maladies de l'appareil digestif",
"L": "Maladies de la peau et du tissu cellulaire sous-cutané",
"M": "Maladies du système ostéo-articulaire",
"N": "Maladies de l'appareil génito-urinaire",
"O": "Grossesse, accouchement et puerpéralité",
"P": "Certaines affections dont l'origine se situe dans la période périnatale",
"Q": "Malformations congénitales et anomalies chromosomiques",
"R": "Symptômes, signes et résultats anormaux",
"S": "Lésions traumatiques et empoisonnements",
"T": "Lésions traumatiques et empoisonnements",
"V": "Causes externes de morbidité et de mortalité",
"W": "Causes externes de morbidité et de mortalité",
"X": "Causes externes de morbidité et de mortalité",
"Y": "Causes externes de morbidité et de mortalité",
"Z": "Facteurs influant sur l'état de santé et motifs de recours aux services de santé",
"U": "Codes d'utilisation particulière",
}
def load_fhir():
"""Charger et indexer les concepts FHIR."""
with open(RAW / "smt_cim10_fhir.json") as f:
data = json.load(f)
concepts = data["concept"]
# Index par code
by_code = {}
for c in concepts:
by_code[c["code"]] = c
return concepts, by_code
def get_props(concept):
"""Extraire les propriétés d'un concept sous forme de dict (multi-valeurs en listes)."""
props = {}
for p in concept.get("property", []):
key = p["code"]
val = p.get("valueCode", p.get("valueString", ""))
if key in props:
if isinstance(props[key], list):
props[key].append(val)
else:
props[key] = [props[key], val]
else:
props[key] = val
return props
def clean_display(display):
"""Nettoyer le libellé (enlever les codes entre crochets type [G31.0])."""
import re
# Retirer les références entre crochets comme [G31.0†]
cleaned = re.sub(r'\s*\[[\w.†*+]+\]\s*', ' ', display)
# Retirer les guillemets décoratifs
cleaned = cleaned.replace('"', '').replace('"', '').replace('"', '')
# Nettoyer les espaces multiples et les tirets isolés
cleaned = re.sub(r'\s*-\s*-\s*', ' - ', cleaned)
cleaned = re.sub(r'\s+', ' ', cleaned).strip()
return cleaned
def make_chatml(system, user, assistant):
"""Créer un exemple ChatML."""
return {
"messages": [
{"role": "system", "content": system},
{"role": "user", "content": user},
{"role": "assistant", "content": assistant},
]
}
def generate_lookup_pairs(concepts, by_code):
"""Type 1 : code → description."""
pairs = []
for c in concepts:
props = get_props(c)
if props.get("type") not in ("category",):
continue
code = c["code"]
display = clean_display(c["display"])
if not display or len(display) < 3:
continue
chapter = CHAPTERS.get(code[0], "")
parent_code = props.get("parent", "")
parent_display = ""
if parent_code and parent_code in by_code:
parent_display = clean_display(by_code[parent_code]["display"])
# Construire la réponse enrichie
answer_parts = [f"{code}{display}"]
if chapter:
answer_parts.append(f"Chapitre : {chapter}")
if parent_display and parent_code != code:
answer_parts.append(f"Catégorie parente : {parent_code} ({parent_display})")
# Ajouter inclusions si présentes
incl = props.get("inclusionNote", "")
if incl and len(incl) < 300:
answer_parts.append(f"Comprend : {incl}")
# Varier les formulations de question
templates = [
f"Que désigne le code CIM-10 {code} ?",
f"Quel est le libellé du code {code} ?",
f"Décris le code CIM-10 {code}.",
]
question = random.choice(templates)
pairs.append(make_chatml(SYSTEM_MSG, question, "\n".join(answer_parts)))
return pairs
def generate_coding_pairs(concepts, by_code):
"""Type 2 : description → code (codage direct)."""
pairs = []
for c in concepts:
props = get_props(c)
if props.get("type") not in ("category",):
continue
code = c["code"]
display = clean_display(c["display"])
if not display or len(display) < 5:
continue
# Réponse JSON structurée (format du pipeline T2A)
answer = json.dumps({
"code": code,
"confidence": "high",
"justification": f"Correspondance directe avec le libellé CIM-10 : {code} {display}."
}, ensure_ascii=False)
templates = [
f"Quel est le code CIM-10 pour : {display} ?",
f"Code CIM-10 pour « {display} » ?",
f"Codage CIM-10 du diagnostic : {display}",
]
question = random.choice(templates)
pairs.append(make_chatml(SYSTEM_MSG, question, answer))
return pairs
def generate_discrimination_pairs(concepts, by_code):
"""Type 3 : discrimination entre codes frères (même parent)."""
pairs = []
# Grouper par parent
children_by_parent = {}
for c in concepts:
props = get_props(c)
if props.get("type") != "category":
continue
parent = props.get("parent", "")
if parent and parent in by_code:
children_by_parent.setdefault(parent, []).append(c)
for parent_code, children in children_by_parent.items():
if len(children) < 2 or len(children) > 15:
continue
parent = by_code[parent_code]
parent_display = clean_display(parent["display"])
# Construire la question
question = f"Quels sont les sous-codes de {parent_code} ({parent_display}) et comment les distinguer ?"
# Construire la réponse
lines = [f"La catégorie {parent_code} ({parent_display}) comprend les codes suivants :\n"]
for child in children:
child_display = clean_display(child["display"])
child_props = get_props(child)
line = f"- {child['code']} : {child_display}"
# Ajouter une note d'inclusion courte si disponible
incl = child_props.get("inclusionNote", "")
if incl and len(incl) < 150:
line += f" (comprend : {incl})"
lines.append(line)
lines.append(f"\nLe choix du code dépend de la précision diagnostique disponible. "
f"En l'absence de précision, utiliser le code SAI (.9) s'il existe.")
answer = "\n".join(lines)
# Limiter la taille
if len(answer) > 2000:
continue
pairs.append(make_chatml(SYSTEM_MSG, question, answer))
return pairs
def generate_inclusion_exclusion_pairs(concepts, by_code):
"""Type 4 : questions sur les inclusions/exclusions d'un code."""
pairs = []
for c in concepts:
props = get_props(c)
if props.get("type") not in ("category",):
continue
code = c["code"]
display = clean_display(c["display"])
incl = props.get("inclusionNote", "")
excl_note = props.get("exclusionNote", "")
excl_codes = props.get("exclusion", "")
note = props.get("note", "")
# Il faut au moins une inclusion OU exclusion
if not incl and not excl_note:
continue
# Construire la réponse
answer_parts = [f"Code {code}{display}\n"]
if incl:
answer_parts.append(f"Ce code COMPREND :\n{incl}")
if excl_note:
answer_parts.append(f"\nCe code EXCLUT :\n{excl_note}")
if note:
answer_parts.append(f"\nNote : {note}")
answer = "\n".join(answer_parts)
if len(answer) > 2000:
continue
templates = [
f"Quelles sont les inclusions et exclusions du code {code} ({display}) ?",
f"Que comprend et que exclut le code CIM-10 {code} ?",
]
question = random.choice(templates)
pairs.append(make_chatml(SYSTEM_MSG, question, answer))
return pairs
def main():
print("Chargement du FHIR JSON...")
concepts, by_code = load_fhir()
print(f" {len(concepts)} concepts chargés")
print("\nGénération des paires...")
print(" Type 1 : code → description (lookup)")
lookup = generate_lookup_pairs(concepts, by_code)
print(f"{len(lookup)} exemples")
print(" Type 2 : description → code (codage)")
coding = generate_coding_pairs(concepts, by_code)
print(f"{len(coding)} exemples")
print(" Type 3 : discrimination codes frères")
discrim = generate_discrimination_pairs(concepts, by_code)
print(f"{len(discrim)} exemples")
print(" Type 4 : inclusions / exclusions")
incl_excl = generate_inclusion_exclusion_pairs(concepts, by_code)
print(f"{len(incl_excl)} exemples")
# Fusionner et mélanger
all_pairs = lookup + coding + discrim + incl_excl
random.shuffle(all_pairs)
# Écrire en JSONL
output_path = OUT / "cim10_chatml.jsonl"
with open(output_path, "w") as f:
for pair in all_pairs:
f.write(json.dumps(pair, ensure_ascii=False) + "\n")
print(f"\nTotal : {len(all_pairs)} exemples → {output_path}")
print(f"Taille : {output_path.stat().st_size / 1024 / 1024:.1f} Mo")
# Stats par type
print("\nRépartition :")
print(f" Lookup (code→desc) : {len(lookup)}")
print(f" Codage (desc→code) : {len(coding)}")
print(f" Discrimination : {len(discrim)}")
print(f" Inclusions/Exclus. : {len(incl_excl)}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,170 @@
#!/usr/bin/env python3
"""
Phase 1B — Génération de paires ChatML CCAM depuis ccam_dict.json.
Sources : ccam_dict.json (8 257 codes) du projet T2A
Produit : data/processed/ccam_chatml.jsonl
Types d'exemples générés :
1. code → description (lookup)
2. description → code (codage)
3. discrimination par regroupement (codes du même regroupement)
"""
import json
import random
from pathlib import Path
random.seed(42)
BASE = Path(__file__).resolve().parent.parent
T2A = Path("/home/dom/ai/t2a")
OUT = BASE / "data" / "processed"
OUT.mkdir(parents=True, exist_ok=True)
SYSTEM_MSG = "Tu es un médecin DIM expert en codage CCAM pour le PMSI français."
def load_ccam():
"""Charger le dictionnaire CCAM."""
with open(T2A / "data" / "ccam_dict.json") as f:
return json.load(f)
def make_chatml(system, user, assistant):
return {
"messages": [
{"role": "system", "content": system},
{"role": "user", "content": user},
{"role": "assistant", "content": assistant},
]
}
def generate_lookup_pairs(ccam):
"""Type 1 : code → description."""
pairs = []
for code, info in ccam.items():
desc = info.get("description", "")
if not desc or len(desc) < 5:
continue
regroupement = info.get("regroupement", "")
activite = info.get("activite", "")
tarif = info.get("tarif_s1")
answer_parts = [f"{code}{desc}"]
if regroupement:
answer_parts.append(f"Regroupement : {regroupement}")
if activite:
answer_parts.append(f"Activité : {activite}")
if tarif:
answer_parts.append(f"Tarif secteur 1 : {tarif}")
templates = [
f"Que désigne le code CCAM {code} ?",
f"Quel est le libellé de l'acte CCAM {code} ?",
f"Décris l'acte CCAM {code}.",
]
pairs.append(make_chatml(SYSTEM_MSG, random.choice(templates), "\n".join(answer_parts)))
return pairs
def generate_coding_pairs(ccam):
"""Type 2 : description → code."""
pairs = []
for code, info in ccam.items():
desc = info.get("description", "")
if not desc or len(desc) < 10:
continue
answer = json.dumps({
"code": code,
"confidence": "high",
"justification": f"Correspondance directe avec le libellé CCAM : {code} {desc}."
}, ensure_ascii=False)
templates = [
f"Quel est le code CCAM pour : {desc} ?",
f"Code CCAM pour « {desc} » ?",
f"Codage CCAM de l'acte : {desc}",
]
pairs.append(make_chatml(SYSTEM_MSG, random.choice(templates), answer))
return pairs
def generate_regroupement_pairs(ccam):
"""Type 3 : regroupement → liste des actes du même regroupement."""
pairs = []
# Grouper par regroupement
by_regroup = {}
for code, info in ccam.items():
reg = info.get("regroupement", "")
if reg:
by_regroup.setdefault(reg, []).append((code, info))
for reg, actes in by_regroup.items():
if len(actes) < 2 or len(actes) > 20:
continue
question = f"Quels sont les actes CCAM du regroupement {reg} ?"
lines = [f"Le regroupement {reg} comprend {len(actes)} actes :\n"]
for code, info in actes[:15]:
desc = info.get("description", "")
tarif = info.get("tarif_s1")
line = f"- {code} : {desc}"
if tarif:
line += f" ({tarif} €)"
lines.append(line)
if len(actes) > 15:
lines.append(f" ... et {len(actes) - 15} autres actes.")
answer = "\n".join(lines)
if len(answer) > 2000:
continue
pairs.append(make_chatml(SYSTEM_MSG, question, answer))
return pairs
def main():
print("Chargement du dictionnaire CCAM...")
ccam = load_ccam()
print(f" {len(ccam)} codes chargés")
print("\nGénération des paires...")
print(" Type 1 : code → description (lookup)")
lookup = generate_lookup_pairs(ccam)
print(f"{len(lookup)} exemples")
print(" Type 2 : description → code (codage)")
coding = generate_coding_pairs(ccam)
print(f"{len(coding)} exemples")
print(" Type 3 : regroupement")
regroup = generate_regroupement_pairs(ccam)
print(f"{len(regroup)} exemples")
all_pairs = lookup + coding + regroup
random.shuffle(all_pairs)
output_path = OUT / "ccam_chatml.jsonl"
with open(output_path, "w") as f:
for pair in all_pairs:
f.write(json.dumps(pair, ensure_ascii=False) + "\n")
print(f"\nTotal : {len(all_pairs)} exemples → {output_path}")
print(f"Taille : {output_path.stat().st_size / 1024 / 1024:.1f} Mo")
if __name__ == "__main__":
main()

206
scripts/03_convert_cache.py Normal file
View File

@@ -0,0 +1,206 @@
#!/usr/bin/env python3
"""
Phase 1C — Conversion du cache Ollama en exemples de raisonnement ChatML.
Sources : ollama_cache.json (1 840 entrées avec raisonnement complet)
Produit : data/processed/reasoning_chatml.jsonl
V2 : Utilise le cache actuel complet (1 840 entrées vs 100 avant).
Filtre pour ne garder que les entrées avec raisonnement structuré.
Supporte aussi les clés das_llm::das_extract:: du pipeline étendu.
Chaque entrée du cache contient un raisonnement structuré :
- analyse_clinique → codes_candidats → discrimination → regle_pmsi → code + justification
Ces exemples sont les plus précieux car ils montrent le raisonnement DIM complet.
"""
import json
import random
from pathlib import Path
random.seed(42)
BASE = Path(__file__).resolve().parent.parent
T2A = Path("/home/dom/ai/t2a")
OUT = BASE / "data" / "processed"
OUT.mkdir(parents=True, exist_ok=True)
SYSTEM_MSG = "Tu es un médecin DIM expert en codage PMSI. Tu codes les diagnostics en CIM-10 en suivant une démarche structurée : analyse clinique, identification des codes candidats, discrimination, vérification des règles PMSI."
def make_chatml(system, user, assistant):
return {
"messages": [
{"role": "system", "content": system},
{"role": "user", "content": user},
{"role": "assistant", "content": assistant},
]
}
def load_cache():
"""Charger le cache Ollama (actuel + backup si disponible)."""
entries = {}
# Cache actuel (1 840 entrées)
cache_path = T2A / "data" / "ollama_cache.json"
if cache_path.exists():
with open(cache_path) as f:
data = json.load(f)
entries.update(data.get("entries", {}))
print(f" Cache actuel : {len(data.get('entries', {}))} entrées")
# Cache backup (peut contenir des entrées supplémentaires)
backup_path = T2A / "data" / "ollama_cache_gemma3.bak"
if backup_path.exists():
with open(backup_path) as f:
data = json.load(f)
backup_entries = data.get("entries", {})
new_count = sum(1 for k in backup_entries if k not in entries)
entries.update(backup_entries)
print(f" Cache backup : {len(backup_entries)} entrées (+{new_count} nouvelles)")
return entries
def parse_cache_key(key):
"""Extraire le type (dp/das) et le texte depuis la clé du cache.
Formats supportés :
- "dp::texte du diagnostic"
- "das::texte du diagnostic"
- "das_llm::das_extract::hash::texte" (pipeline étendu)
"""
if key.startswith("das_llm::das_extract::"):
# Format : das_llm::das_extract::HASH::texte
parts = key.split("::", 3)
texte = parts[3] if len(parts) > 3 else parts[-1]
return "das", texte.strip()
if "::" in key:
diag_type, texte = key.split("::", 1)
return diag_type.strip(), texte.strip()
return "das", key.strip()
def build_user_prompt(diag_type, texte):
"""Construire le prompt utilisateur à partir du type et du texte."""
type_label = "Diagnostic Principal (DP)" if diag_type == "dp" else "Diagnostic Associé Significatif (DAS)"
prompt = f"Code ce diagnostic en CIM-10.\n\n"
prompt += f"DIAGNOSTIC : {texte.capitalize()}\n"
prompt += f"TYPE : {type_label}"
return prompt
def build_assistant_response(entry):
"""Construire la réponse structurée de l'assistant."""
code = entry.get("code", "")
confidence = entry.get("confidence", "medium")
justification = entry.get("justification", "")
raisonnement = entry.get("raisonnement", "")
# Si on a un raisonnement complet, le formater en JSON structuré
if raisonnement:
# Parser les sections du raisonnement
response = {}
sections = {
"ANALYSE CLINIQUE": "analyse_clinique",
"CODES CANDIDATS": "codes_candidats",
"DISCRIMINATION": "discrimination",
"REGLE PMSI": "regle_pmsi",
"RÈGLE PMSI": "regle_pmsi",
}
# Extraire chaque section du raisonnement
remaining = raisonnement
for header, key in sections.items():
marker = f"{header} :"
if marker not in remaining:
marker = f"{header}:"
if marker in remaining:
idx = remaining.index(marker)
# Trouver la fin de cette section (début de la suivante ou fin)
end_idx = len(remaining)
for next_header in sections:
next_marker = f"{next_header} :"
next_marker2 = f"{next_header}:"
for nm in (next_marker, next_marker2):
if nm in remaining[idx + len(marker):]:
candidate = idx + len(marker) + remaining[idx + len(marker):].index(nm)
if candidate < end_idx:
end_idx = candidate
value = remaining[idx + len(marker):end_idx].strip()
if value:
response[key] = value
response["code"] = code
response["confidence"] = confidence
if justification:
response["justification"] = justification
return json.dumps(response, ensure_ascii=False, indent=None)
# Si pas de raisonnement, réponse simple
response = {
"code": code,
"confidence": confidence,
}
if justification:
response["justification"] = justification
return json.dumps(response, ensure_ascii=False, indent=None)
def main():
print("Chargement du cache Ollama (toutes sources)...")
entries = load_cache()
print(f" Total fusionné : {len(entries)} entrées")
pairs = []
with_reasoning = 0
without_reasoning = 0
skipped_no_code = 0
skipped_no_text = 0
by_type = {"dp": 0, "das": 0}
for key, entry in entries.items():
diag_type, texte = parse_cache_key(key)
if not texte or len(texte) < 3:
skipped_no_text += 1
continue
if not entry.get("code"):
skipped_no_code += 1
continue
user_prompt = build_user_prompt(diag_type, texte)
assistant_response = build_assistant_response(entry)
if entry.get("raisonnement"):
with_reasoning += 1
else:
without_reasoning += 1
by_type[diag_type] = by_type.get(diag_type, 0) + 1
pairs.append(make_chatml(SYSTEM_MSG, user_prompt, assistant_response))
random.shuffle(pairs)
output_path = OUT / "reasoning_chatml.jsonl"
with open(output_path, "w") as f:
for pair in pairs:
f.write(json.dumps(pair, ensure_ascii=False) + "\n")
print(f"\nTotal : {len(pairs)} exemples → {output_path}")
print(f" DP : {by_type.get('dp', 0)}, DAS : {by_type.get('das', 0)}")
print(f" Avec raisonnement complet : {with_reasoning}")
print(f" Sans raisonnement (code seul) : {without_reasoning}")
print(f" Ignorés (pas de code) : {skipped_no_code}")
print(f" Ignorés (pas de texte) : {skipped_no_text}")
print(f"Taille : {output_path.stat().st_size / 1024:.0f} Ko")
if __name__ == "__main__":
main()

202
scripts/04_build_dataset.py Normal file
View File

@@ -0,0 +1,202 @@
#!/usr/bin/env python3
"""
Phase 1 — Fusion de tous les sous-datasets en un dataset final.
Lit tous les .jsonl dans data/processed/ et produit :
- data/datasets/pmsi_train.jsonl (90%)
- data/datasets/pmsi_eval.jsonl (10%)
- data/datasets/stats.json (statistiques)
V2 : Rééquilibrage — réduction drastique des lookups, priorité raisonnement.
Ratio cible : ~35% lookups / 40% raisonnement / 25% règles (vs 95/3/2 avant).
"""
import json
import random
from pathlib import Path
random.seed(42)
BASE = Path(__file__).resolve().parent.parent
PROCESSED = BASE / "data" / "processed"
DATASETS = BASE / "data" / "datasets"
DATASETS.mkdir(parents=True, exist_ok=True)
EVAL_RATIO = 0.10 # 10% pour l'évaluation
# V2 : Sous-échantillonnage agressif des lookups, garder tout le raisonnement
SUBSAMPLE = {
"cim10_chatml.jsonl": 1_500, # 50K → 1.5K (discrimination + inclusions/exclusions)
"ccam_chatml.jsonl": 1_500, # 16.5K → 1.5K (codes les plus fréquents)
"cocoa_chatml.jsonl": 2_000, # 30K → 2K (types clinical + tips prioritaires)
"referentiels_chatml.jsonl": None, # Garder tout (5.3K, déjà dense en règles)
}
# Mots-clés indiquant du raisonnement structuré (pour sélection intelligente)
REASONING_KEYWORDS = [
"analyse_clinique", "discrimination", "regle_pmsi", "règle_pmsi",
"codes_candidats", "justification", "ne pas coder", "ne doit pas",
"à l'exclusion", "comprend", "sévérité", "CMA",
]
def count_tokens_approx(text):
"""Estimation grossière du nombre de tokens (~1.3 token par mot français)."""
return int(len(text.split()) * 1.3)
def _reasoning_score(example: dict) -> int:
"""Score de raisonnement d'un exemple (plus haut = plus précieux)."""
text = " ".join(m.get("content", "") for m in example.get("messages", []))
text_lower = text.lower()
score = 0
for kw in REASONING_KEYWORDS:
if kw.lower() in text_lower:
score += 1
# Bonus pour les réponses JSON structurées longues (raisonnement complet)
assistant_msgs = [m["content"] for m in example.get("messages", []) if m["role"] == "assistant"]
for msg in assistant_msgs:
if len(msg) > 200:
score += 2
if '"analyse_clinique"' in msg:
score += 3
return score
def _smart_subsample(examples: list[dict], target: int) -> list[dict]:
"""Sous-échantillonnage intelligent : prioriser les exemples avec raisonnement."""
# Trier par score de raisonnement (décroissant), puis shuffle pour diversité
scored = [(ex, _reasoning_score(ex)) for ex in examples]
scored.sort(key=lambda x: -x[1])
# Garder tous ceux avec score > 0, puis compléter avec du random
high_value = [(ex, s) for ex, s in scored if s > 0]
low_value = [(ex, s) for ex, s in scored if s == 0]
random.shuffle(low_value)
if len(high_value) >= target:
# Trop d'exemples high-value : prendre les meilleurs + random parmi eux
random.shuffle(high_value)
selected = [ex for ex, _ in high_value[:target]]
else:
# Prendre tous les high-value + compléter avec low-value
selected = [ex for ex, _ in high_value]
remaining = target - len(selected)
selected.extend(ex for ex, _ in low_value[:remaining])
random.shuffle(selected)
return selected
def main():
# Collecter tous les fichiers JSONL
files = sorted(PROCESSED.glob("*.jsonl"))
if not files:
print("Aucun fichier .jsonl trouvé dans data/processed/")
return
all_examples = []
file_stats = {}
for f in files:
examples = []
with open(f) as fh:
for line in fh:
line = line.strip()
if line:
examples.append(json.loads(line))
original_count = len(examples)
# Sous-échantillonner si configuré
if f.name in SUBSAMPLE:
target = SUBSAMPLE[f.name]
if target is None or len(examples) <= target:
print(f" {f.name}: {len(examples)} exemples (gardé tout)")
else:
examples = _smart_subsample(examples, target)
print(f" {f.name}: {len(examples)} exemples (sous-échantillonné depuis {original_count})")
else:
print(f" {f.name}: {len(examples)} exemples")
file_stats[f.name] = len(examples)
all_examples.extend(examples)
print(f"\nTotal brut : {len(all_examples)} exemples")
# Mélanger
random.shuffle(all_examples)
# Split train / eval
n_eval = max(1, int(len(all_examples) * EVAL_RATIO))
eval_set = all_examples[:n_eval]
train_set = all_examples[n_eval:]
# Écrire les datasets
train_path = DATASETS / "pmsi_train.jsonl"
eval_path = DATASETS / "pmsi_eval.jsonl"
for path, dataset in [(train_path, train_set), (eval_path, eval_set)]:
with open(path, "w") as fh:
for ex in dataset:
fh.write(json.dumps(ex, ensure_ascii=False) + "\n")
# Statistiques
def compute_stats(dataset):
total_tokens = 0
max_tokens = 0
min_tokens = float("inf")
for ex in dataset:
text = " ".join(m["content"] for m in ex["messages"])
tokens = count_tokens_approx(text)
total_tokens += tokens
max_tokens = max(max_tokens, tokens)
min_tokens = min(min_tokens, tokens)
avg = total_tokens / len(dataset) if dataset else 0
return {
"count": len(dataset),
"total_tokens_approx": total_tokens,
"avg_tokens": round(avg),
"max_tokens": max_tokens,
"min_tokens": min_tokens if min_tokens != float("inf") else 0,
}
train_stats = compute_stats(train_set)
eval_stats = compute_stats(eval_set)
stats = {
"sources": file_stats,
"total": len(all_examples),
"train": train_stats,
"eval": eval_stats,
"eval_ratio": EVAL_RATIO,
}
stats_path = DATASETS / "stats.json"
with open(stats_path, "w") as fh:
json.dump(stats, fh, indent=2, ensure_ascii=False)
# Calculer le ratio raisonnement
n_reasoning = sum(1 for ex in all_examples if _reasoning_score(ex) >= 3)
pct_reasoning = 100.0 * n_reasoning / len(all_examples) if all_examples else 0
# Affichage
print(f"\n{'='*50}")
print(f"Dataset final (V2 rééquilibré) :")
print(f" Train : {train_stats['count']} exemples → {train_path.name}")
print(f" Eval : {eval_stats['count']} exemples → {eval_path.name}")
print(f"\nRatio raisonnement structuré (score≥3) : {n_reasoning}/{len(all_examples)} ({pct_reasoning:.0f}%)")
print(f"\nTokens (estimation) :")
print(f" Train : ~{train_stats['total_tokens_approx']:,} tokens (moy: {train_stats['avg_tokens']}, max: {train_stats['max_tokens']})")
print(f" Eval : ~{eval_stats['total_tokens_approx']:,} tokens (moy: {eval_stats['avg_tokens']}, max: {eval_stats['max_tokens']})")
print(f"\nSources :")
for name, count in sorted(file_stats.items()):
print(f" {name}: {count}")
print(f"\nFichiers :")
print(f" {train_path} ({train_path.stat().st_size / 1024 / 1024:.1f} Mo)")
print(f" {eval_path} ({eval_path.stat().st_size / 1024 / 1024:.1f} Mo)")
print(f" {stats_path}")
if __name__ == "__main__":
main()

671
scripts/05_parse_cocoa.py Normal file
View File

@@ -0,0 +1,671 @@
#!/usr/bin/env python3
"""
Phase 1E — Parsing du CoCoA 2025 (1113 pages) pour extraction d'exemples ChatML.
Le CoCoA (Codage Complet Annoté) est le vademecum des médecins DIM.
Il contient des entrées détaillées par code CIM-10 avec :
- Indicateurs P/R/A (Diagnostic Principal / Relié / Associé)
- Niveaux de sévérité (2, 3, 4)
- Descriptions cliniques détaillées
- Synonymes
- Comprend / À l'exclusion de
- Notes AGORA (FAQ ATIH)
- Annotations CoCoA (conseils pratiques DIM)
Pages traitées : 85-1080 (entrées détaillées, chapitres 1-22)
Produit : data/processed/cocoa_chatml.jsonl
"""
import json
import re
import random
from pathlib import Path
random.seed(42)
BASE = Path(__file__).resolve().parent.parent
RAW = BASE / "data" / "raw"
OUT = BASE / "data" / "processed"
OUT.mkdir(parents=True, exist_ok=True)
SYSTEM_MSG = "Tu es un médecin DIM expert en codage CIM-10 pour le PMSI français. Tu t'appuies sur le CoCoA (Codage Complet Annoté) pour tes décisions de codage."
# Pages des entrées détaillées (0-indexed)
PAGE_START = 84 # page 85
PAGE_END = 1080 # page 1080
# Regex patterns
RE_CIM10_CODE = re.compile(
r'^([A-Z]\d{2}(?:\.\d{1,2})?)\s*([†*]?)\s+(.*)'
)
RE_CATEGORY_CODE = re.compile(
r'^([A-Z]\d{2})\s+(.*)'
)
RE_SUBCODE = re.compile(
r'^([A-Z]\d{2}\.\d{1,2})\s*([†*]?)\s*(.*)'
)
RE_PRA_LINE = re.compile(r'^P\s*R\s*A')
RE_SEVERITY = re.compile(r'^(\d)\s*$')
RE_CHAPTER_HEADER = re.compile(r'^CHAPITRE\s+([IVX]+)\s*:?\s*(.*)')
RE_SECTION_HEADER = re.compile(r'^([A-Z][a-zéèêëàâîïôùûüç].+)\s*\(([A-Z]\d{2}[-][A-Z]\d{2})\)')
RE_EXCLUSION = re.compile(r"^À l['\u2019]exclusion de\s+(.*)", re.IGNORECASE)
RE_COMPREND = re.compile(r'^Comprend\s+(.*)', re.IGNORECASE)
RE_AGORA = re.compile(r'\(AGORA\s*[-]\s*#?\s*(\d+).*?\)')
RE_FOOTER = re.compile(r'^2025\s*[-]')
RE_NOTE_BRACKET = re.compile(r'^\[voir en début')
def extract_text_from_pdf():
"""Extraire le texte de toutes les pages détaillées du CoCoA."""
import pdfplumber
pdf_path = RAW / "cocoa_2025.pdf"
print(f"Ouverture de {pdf_path}...")
pages_text = []
with pdfplumber.open(pdf_path) as pdf:
total = min(PAGE_END, len(pdf.pages))
for i in range(PAGE_START, total):
page = pdf.pages[i]
text = page.extract_text() or ""
pages_text.append((i + 1, text)) # (page_number, text)
if (i - PAGE_START) % 100 == 0:
print(f" Extraction page {i+1}/{total}...")
print(f" {len(pages_text)} pages extraites")
return pages_text
def parse_entries(pages_text):
"""Parser les entrées CIM-10 depuis le texte extrait."""
entries = {} # code -> dict
current_chapter = ""
current_section = ""
current_code = None
current_entry = None
collecting_exclusion = False
collecting_comprend = False
collecting_description = False
for page_num, page_text in pages_text:
lines = page_text.split('\n')
for line_idx, line in enumerate(lines):
line = line.strip()
# Skip empty lines and footers
if not line:
collecting_exclusion = False
collecting_comprend = False
collecting_description = False
continue
if RE_FOOTER.match(line):
collecting_exclusion = False
collecting_comprend = False
collecting_description = False
continue
if RE_NOTE_BRACKET.match(line):
collecting_description = False
continue
# Chapter header
m = RE_CHAPTER_HEADER.match(line)
if m:
current_chapter = m.group(2).strip()
collecting_exclusion = False
collecting_comprend = False
collecting_description = False
continue
# Skip P R A indicator lines (standalone)
if RE_PRA_LINE.match(line):
# Check if there's a code on the same line
rest = re.sub(r'^P\s*R\s*A\s*', '', line).strip()
# Also remove "AN T" or similar special markers
rest = re.sub(r'^AN\s*T?\s*', '', rest).strip()
if rest:
# P R A followed by code on same line (category code)
m_cat = RE_CATEGORY_CODE.match(rest)
m_sub = RE_SUBCODE.match(rest)
if m_sub:
code = m_sub.group(1)
dagger_star = m_sub.group(2)
desc = m_sub.group(3).strip()
_save_entry(entries, current_code, current_entry)
current_code = code
current_entry = _new_entry(code, desc, dagger_star, current_chapter, page_num, is_category=False)
collecting_exclusion = False
collecting_comprend = False
collecting_description = False
elif m_cat:
code = m_cat.group(1)
desc = m_cat.group(2).strip()
_save_entry(entries, current_code, current_entry)
current_code = code
current_entry = _new_entry(code, desc, "", current_chapter, page_num, is_category=True)
collecting_exclusion = False
collecting_comprend = False
collecting_description = False
continue
# Severity number on its own line
m = RE_SEVERITY.match(line)
if m and current_entry:
current_entry["severity"] = int(m.group(1))
continue
# Sub-code entry
m = RE_SUBCODE.match(line)
if m:
code = m.group(1)
dagger_star = m.group(2)
desc = m.group(3).strip()
_save_entry(entries, current_code, current_entry)
current_code = code
current_entry = _new_entry(code, desc, dagger_star, current_chapter, page_num, is_category=False)
collecting_exclusion = False
collecting_comprend = False
collecting_description = False
continue
# Category code (3-char code at start of line, no dot)
m = RE_CATEGORY_CODE.match(line)
if m and not line[0].islower() and len(m.group(1)) == 3:
# Make sure it's actually a code and not part of text
potential_code = m.group(1)
if re.match(r'^[A-Z]\d{2}$', potential_code):
desc = m.group(2).strip()
# Avoid false positives - check that desc looks like a title
if desc and len(desc) > 3 and not desc[0].isdigit():
_save_entry(entries, current_code, current_entry)
current_code = potential_code
current_entry = _new_entry(potential_code, desc, "", current_chapter, page_num, is_category=True)
collecting_exclusion = False
collecting_comprend = False
collecting_description = False
continue
# Section header (e.g., "Autres maladies bactériennes (A30-A49)")
m = RE_SECTION_HEADER.match(line)
if m:
current_section = m.group(1).strip()
collecting_exclusion = False
collecting_comprend = False
collecting_description = False
continue
# Comprend
m = RE_COMPREND.match(line)
if m:
if current_entry:
current_entry["comprend"].append(m.group(1).strip())
collecting_comprend = True
collecting_exclusion = False
collecting_description = False
continue
# À l'exclusion de
m = RE_EXCLUSION.match(line)
if m:
if current_entry:
current_entry["exclusions"].append(m.group(1).strip())
collecting_exclusion = True
collecting_comprend = False
collecting_description = False
continue
# AGORA reference
agora_matches = RE_AGORA.findall(line)
if agora_matches and current_entry:
for ref in agora_matches:
current_entry["agora_refs"].append(ref)
# Also add the full line as a CoCoA annotation
if "AGORA" in line or "Aunis" in line.lower() or "CoCoA" in line:
current_entry["cocoa_notes"].append(line)
continue
# CoCoA/Aunis annotations (highlighted text)
if current_entry and ("Aunis" in line or "CoCoA" in line):
current_entry["cocoa_notes"].append(line)
continue
# Continuation lines for exclusions
if collecting_exclusion and current_entry:
# Exclusion continuation - items with code refs, bullets, lowercase starts
if (re.search(r'\([A-Z]\d{2}', line) or
line.startswith('') or line.startswith('-') or
line[0].islower() or
re.match(r'^[a-zéèêëàâîïôùûüç•\-]', line)):
current_entry["exclusions"].append(line)
continue
else:
collecting_exclusion = False
# Continuation lines for comprend
if collecting_comprend and current_entry:
if not re.match(r'^[A-Z]\d', line) and not RE_PRA_LINE.match(line):
current_entry["comprend"].append(line)
continue
else:
collecting_comprend = False
# Clinical description text (paragraph after a code entry)
if current_entry and line and not RE_PRA_LINE.match(line):
# Check if it's a synonym or clinical text
if len(line) > 60 and not re.match(r'^[A-Z]\d', line):
# Long text = clinical description
current_entry["clinical_text"].append(line)
elif not re.match(r'^[A-Z]\d', line) and not line.startswith('P '):
# Short text after a code = synonym
current_entry["synonyms"].append(line)
# Save last entry
_save_entry(entries, current_code, current_entry)
return entries
def _new_entry(code, description, dagger_star, chapter, page, is_category=False):
return {
"code": code,
"description": description,
"dagger_star": dagger_star,
"chapter": chapter,
"page": page,
"is_category": is_category,
"severity": None,
"synonyms": [],
"comprend": [],
"exclusions": [],
"clinical_text": [],
"agora_refs": [],
"cocoa_notes": [],
}
def _save_entry(entries, code, entry):
if code and entry and entry["description"]:
# Clean up
entry["synonyms"] = [s.strip() for s in entry["synonyms"] if s.strip() and len(s.strip()) > 2]
entry["comprend"] = [c.strip() for c in entry["comprend"] if c.strip()]
entry["exclusions"] = [e.strip() for e in entry["exclusions"] if e.strip()]
entry["clinical_text"] = [t.strip() for t in entry["clinical_text"] if t.strip()]
entry["cocoa_notes"] = [n.strip() for n in entry["cocoa_notes"] if n.strip()]
# Deduplicate
entry["synonyms"] = list(dict.fromkeys(entry["synonyms"]))
entry["cocoa_notes"] = list(dict.fromkeys(entry["cocoa_notes"]))
# Filter out noise from synonyms and move misclassified exclusions
filtered_syns = []
re_excl_inline = re.compile(r"^À l['\u2019]exclusion de", re.IGNORECASE)
for s in entry["synonyms"]:
# Skip severity numbers, P R A markers, etc.
if RE_SEVERITY.match(s) or RE_PRA_LINE.match(s) or RE_FOOTER.match(s):
continue
if s in ("P R A", "P", "R", "A", "AN", "T"):
continue
# Move misclassified exclusions
if re_excl_inline.match(s):
excl_text = re.sub(r"^À l['\u2019]exclusion de\s*", '', s, flags=re.IGNORECASE).strip()
if excl_text:
entry["exclusions"].append(excl_text)
continue
filtered_syns.append(s)
entry["synonyms"] = filtered_syns
# Also clean comprend - move misclassified exclusions
filtered_comprend = []
for c in entry["comprend"]:
if re_excl_inline.match(c):
excl_text = re.sub(r"^À l['\u2019]exclusion de\s*", '', c, flags=re.IGNORECASE).strip()
if excl_text:
entry["exclusions"].append(excl_text)
else:
filtered_comprend.append(c)
entry["comprend"] = filtered_comprend
entries[code] = entry
def make_chatml(system, user, assistant):
return {
"messages": [
{"role": "system", "content": system},
{"role": "user", "content": user},
{"role": "assistant", "content": assistant},
]
}
def generate_description_pairs(entries):
"""Type 1 : Description enrichie CoCoA d'un code (vs FHIR plus basique)."""
pairs = []
for code, e in entries.items():
desc = e["description"]
if not desc or len(desc) < 3:
continue
answer_parts = [f"{code}{desc}"]
if e["chapter"]:
answer_parts.append(f"Chapitre : {e['chapter']}")
if e["synonyms"]:
syns = [s for s in e["synonyms"][:8] if len(s) > 2]
if syns:
answer_parts.append(f"Synonymes : {' ; '.join(syns)}")
if e["comprend"]:
answer_parts.append(f"Comprend : {' '.join(e['comprend'][:5])}")
if e["exclusions"]:
excls = [ex for ex in e["exclusions"][:5]]
answer_parts.append(f"À l'exclusion de : {' ; '.join(excls)}")
if e["severity"]:
answer_parts.append(f"Niveau de sévérité CMA : {e['severity']}")
if e["dagger_star"]:
marker = "étiologique (†)" if e["dagger_star"] == "" else "manifestation (*)"
answer_parts.append(f"Convention dague/astérisque : code {marker}")
answer = "\n".join(answer_parts)
if len(answer) > 2000:
answer = answer[:2000]
templates = [
f"Décris le code CIM-10 {code} selon le CoCoA.",
f"Que dit le CoCoA sur le code {code} ?",
f"Quelles sont les caractéristiques du code {code} d'après le CoCoA ?",
]
pairs.append(make_chatml(SYSTEM_MSG, random.choice(templates), answer))
return pairs
def generate_clinical_pairs(entries):
"""Type 2 : Descriptions cliniques détaillées → code + raisonnement."""
pairs = []
for code, e in entries.items():
if not e["clinical_text"]:
continue
clinical = " ".join(e["clinical_text"])
if len(clinical) < 50:
continue
desc = e["description"]
# Construire un raisonnement structuré
reasoning = {
"analyse_clinique": clinical[:500],
"code": code,
"description": desc,
"confidence": "high",
"justification": f"La description clinique du CoCoA correspond au code {code} ({desc})."
}
if e["exclusions"]:
reasoning["exclusions_a_verifier"] = " ; ".join(e["exclusions"][:3])
answer = json.dumps(reasoning, ensure_ascii=False)
# Créer une question à partir du texte clinique (tronqué)
clinical_short = clinical[:300]
if len(clinical) > 300:
clinical_short += "..."
question = f"Un patient présente le tableau clinique suivant :\n{clinical_short}\n\nQuel code CIM-10 correspond à cette présentation ?"
if len(question) > 1500:
continue
pairs.append(make_chatml(SYSTEM_MSG, question, answer))
return pairs
def generate_synonym_pairs(entries):
"""Type 3 : Synonyme → code CIM-10."""
pairs = []
for code, e in entries.items():
if not e["synonyms"]:
continue
desc = e["description"]
for syn in e["synonyms"]:
if len(syn) < 4 or len(syn) > 200:
continue
# Skip entries that look like noise
if syn.startswith("") or syn.startswith("[") or syn.startswith("("):
syn = syn.lstrip("•[( ").rstrip("])").strip()
if not syn or len(syn) < 4:
continue
answer = json.dumps({
"code": code,
"confidence": "high",
"justification": f"« {syn} » est un synonyme de {code} ({desc}) selon le CoCoA."
}, ensure_ascii=False)
templates = [
f"Quel est le code CIM-10 pour : {syn} ?",
f"Code CIM-10 correspondant à « {syn} » ?",
]
pairs.append(make_chatml(SYSTEM_MSG, random.choice(templates), answer))
return pairs
def generate_exclusion_pairs(entries):
"""Type 4 : Questions sur ce qu'un code exclut (piège de codage)."""
pairs = []
for code, e in entries.items():
if not e["exclusions"]:
continue
desc = e["description"]
excls = " ; ".join(e["exclusions"][:8])
if len(excls) < 10:
continue
answer = f"Le code {code} ({desc}) exclut :\n{excls}\n\nAttention : ces situations doivent être codées avec les codes de renvoi indiqués entre parenthèses."
if len(answer) > 1500:
answer = answer[:1500]
templates = [
f"Quelles sont les exclusions du code CIM-10 {code} ({desc}) ?",
f"Que ne faut-il PAS coder en {code} ?",
]
pairs.append(make_chatml(SYSTEM_MSG, random.choice(templates), answer))
return pairs
def generate_severity_pairs(entries):
"""Type 5 : Questions sur le niveau de sévérité CMA d'un code."""
pairs = []
for code, e in entries.items():
if not e["severity"]:
continue
desc = e["description"]
sev = e["severity"]
sev_text = {
2: "niveau 2 (sévérité modérée)",
3: "niveau 3 (sévérité élevée)",
4: "niveau 4 (sévérité très élevée)",
}.get(sev, f"niveau {sev}")
answer = f"Le code {code} ({desc}) a un niveau de sévérité CMA de {sev_text}.\n"
answer += f"En tant que DAS, ce code peut entraîner une majoration du niveau de sévérité du GHM."
if e["is_category"]:
answer += f"\nNote : {code} est une catégorie (code à 3 caractères). Les sous-codes peuvent avoir des niveaux différents."
pairs.append(make_chatml(
SYSTEM_MSG,
f"Quel est le niveau de sévérité CMA du code {code} ({desc}) ?",
answer
))
return pairs
def generate_cocoa_tips_pairs(entries):
"""Type 6 : Notes CoCoA et AGORA (conseils pratiques DIM)."""
pairs = []
for code, e in entries.items():
if not e["cocoa_notes"]:
continue
desc = e["description"]
notes = "\n".join(e["cocoa_notes"])
if len(notes) < 10:
continue
answer = f"Pour le code {code} ({desc}), le CoCoA indique :\n{notes}"
if len(answer) > 1500:
answer = answer[:1500]
pairs.append(make_chatml(
SYSTEM_MSG,
f"Y a-t-il des conseils pratiques du CoCoA pour le codage de {code} ({desc}) ?",
answer
))
return pairs
def generate_comprend_pairs(entries):
"""Type 7 : Ce que comprend un code (inclusions)."""
pairs = []
for code, e in entries.items():
if not e["comprend"]:
continue
desc = e["description"]
comprend = " ; ".join(e["comprend"][:5])
if len(comprend) < 10:
continue
answer = f"Le code {code} ({desc}) comprend :\n{comprend}"
templates = [
f"Que comprend le code CIM-10 {code} ?",
f"Quelles situations sont incluses dans le code {code} ({desc}) ?",
]
pairs.append(make_chatml(SYSTEM_MSG, random.choice(templates), answer))
return pairs
def main():
# Étape 1 : Extraction du texte
pages_text = extract_text_from_pdf()
# Étape 2 : Parsing des entrées
print("\nParsing des entrées CIM-10...")
entries = parse_entries(pages_text)
# Stats
n_categories = sum(1 for e in entries.values() if e["is_category"])
n_subcodes = sum(1 for e in entries.values() if not e["is_category"])
n_with_clinical = sum(1 for e in entries.values() if e["clinical_text"])
n_with_synonyms = sum(1 for e in entries.values() if e["synonyms"])
n_with_exclusions = sum(1 for e in entries.values() if e["exclusions"])
n_with_comprend = sum(1 for e in entries.values() if e["comprend"])
n_with_severity = sum(1 for e in entries.values() if e["severity"])
n_with_cocoa = sum(1 for e in entries.values() if e["cocoa_notes"])
print(f"\n Entrées parsées : {len(entries)}")
print(f" Catégories (3 car.) : {n_categories}")
print(f" Sous-codes : {n_subcodes}")
print(f" Avec texte clinique : {n_with_clinical}")
print(f" Avec synonymes : {n_with_synonyms}")
print(f" Avec exclusions : {n_with_exclusions}")
print(f" Avec comprend : {n_with_comprend}")
print(f" Avec sévérité CMA : {n_with_severity}")
print(f" Avec notes CoCoA : {n_with_cocoa}")
# Étape 3 : Génération des paires ChatML
print("\nGénération des paires ChatML...")
print(" Type 1 : Descriptions enrichies CoCoA")
desc_pairs = generate_description_pairs(entries)
print(f"{len(desc_pairs)} exemples")
print(" Type 2 : Texte clinique → code")
clinical_pairs = generate_clinical_pairs(entries)
print(f"{len(clinical_pairs)} exemples")
print(" Type 3 : Synonyme → code")
synonym_pairs = generate_synonym_pairs(entries)
print(f"{len(synonym_pairs)} exemples")
print(" Type 4 : Exclusions")
exclusion_pairs = generate_exclusion_pairs(entries)
print(f"{len(exclusion_pairs)} exemples")
print(" Type 5 : Sévérité CMA")
severity_pairs = generate_severity_pairs(entries)
print(f"{len(severity_pairs)} exemples")
print(" Type 6 : Notes CoCoA/AGORA")
cocoa_pairs = generate_cocoa_tips_pairs(entries)
print(f"{len(cocoa_pairs)} exemples")
print(" Type 7 : Comprend (inclusions)")
comprend_pairs = generate_comprend_pairs(entries)
print(f"{len(comprend_pairs)} exemples")
# Fusionner et mélanger
all_pairs = desc_pairs + clinical_pairs + synonym_pairs + exclusion_pairs + severity_pairs + cocoa_pairs + comprend_pairs
random.shuffle(all_pairs)
# Écrire le JSONL
output_path = OUT / "cocoa_chatml.jsonl"
with open(output_path, "w") as f:
for pair in all_pairs:
f.write(json.dumps(pair, ensure_ascii=False) + "\n")
print(f"\n{'='*50}")
print(f"Total : {len(all_pairs)} exemples → {output_path}")
print(f"Taille : {output_path.stat().st_size / 1024 / 1024:.1f} Mo")
# Sauvegarder aussi les entrées parsées en JSON pour debug
debug_path = OUT / "cocoa_entries_debug.json"
with open(debug_path, "w") as f:
json.dump(entries, f, indent=2, ensure_ascii=False)
print(f"Debug : {debug_path} ({debug_path.stat().st_size / 1024 / 1024:.1f} Mo)")
# Répartition
print(f"\nRépartition :")
print(f" Descriptions CoCoA : {len(desc_pairs)}")
print(f" Texte clinique→code : {len(clinical_pairs)}")
print(f" Synonyme→code : {len(synonym_pairs)}")
print(f" Exclusions : {len(exclusion_pairs)}")
print(f" Sévérité CMA : {len(severity_pairs)}")
print(f" Notes CoCoA/AGORA : {len(cocoa_pairs)}")
print(f" Comprend (inclusions): {len(comprend_pairs)}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,562 @@
#!/usr/bin/env python3
"""
Phase 1D — Génération de données synthétiques via API OpenAI (GPT-4o).
Envoie des métadonnées anonymisées FAISS à un grand modèle pour générer
des exemples de raisonnement DIM complet en format ChatML.
Types d'exemples générés :
1. Scénario clinique → raisonnement DIM → code CIM-10
2. Discrimination entre codes proches
3. Application des règles PMSI (DP/DAS, CMA, exclusions)
Nécessite : OPENAI_API_KEY en variable d'environnement
Usage :
python scripts/06_generate_synthetic.py [--n 500] [--batch 5] [--model gpt-4o] [--dry-run]
"""
import json
import os
import sys
import time
import random
import argparse
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
random.seed(42)
BASE = Path(__file__).resolve().parent.parent
T2A = Path("/home/dom/ai/t2a")
OUT = BASE / "data" / "processed"
OUT.mkdir(parents=True, exist_ok=True)
# --- Prompts ---
SYSTEM_PROMPT_SCENARIO = """Tu es un formateur DIM (Département d'Information Médicale) expert en codage PMSI.
Tu génères des scénarios cliniques réalistes et anonymisés pour former des médecins DIM au codage CIM-10.
Pour chaque code CIM-10 fourni, tu dois produire :
1. Un SCÉNARIO CLINIQUE réaliste (3-5 phrases, anonymisé, comme extrait d'un compte-rendu d'hospitalisation)
2. Un RAISONNEMENT DIM structuré montrant la démarche de codage
Le raisonnement doit suivre ces étapes :
- analyse_clinique : ce que le texte clinique révèle
- codes_candidats : 2-3 codes CIM-10 envisageables avec leur libellé
- discrimination : pourquoi le code retenu est le bon (et pas les autres)
- regle_pmsi : règle PMSI applicable (DP/DAS, exclusions, conventions dague/astérisque, etc.)
- code : le code CIM-10 retenu
- confidence : high/medium/low
- justification : synthèse en 1 phrase
IMPORTANT :
- Les scénarios doivent être VARIÉS (âges, sexes, contextes différents)
- Anonymisés (pas de vrais noms/dates)
- Médicalement cohérents
- En français médical professionnel
- Réponse en JSON valide uniquement"""
SYSTEM_PROMPT_DISCRIM = """Tu es un formateur DIM expert en codage PMSI.
Tu crées des exercices de discrimination entre codes CIM-10 proches pour former des médecins DIM.
Pour chaque groupe de codes fourni, génère UN scénario clinique où le choix entre les codes est subtil,
puis montre le raisonnement complet pour arriver au bon code.
IMPORTANT : Réponse en JSON valide uniquement."""
SYSTEM_PROMPT_RULES = """Tu es un formateur DIM expert en règles PMSI.
Tu crées des exercices d'application des règles PMSI (codage DP/DAS, CMA, séjours multi-unités, etc.).
Pour chaque situation fournie, génère un scénario d'hospitalisation et montre comment les règles PMSI
s'appliquent au codage.
IMPORTANT : Réponse en JSON valide uniquement."""
def load_faiss_metadata():
"""Charger les métadonnées FAISS."""
meta_path = T2A / "data" / "rag_index" / "metadata.json"
with open(meta_path) as f:
return json.load(f)
def load_cim10_fhir():
"""Charger les concepts FHIR pour enrichir les prompts."""
fhir_path = BASE / "data" / "raw" / "smt_cim10_fhir.json"
if not fhir_path.exists():
return {}
with open(fhir_path) as f:
data = json.load(f)
by_code = {}
for c in data.get("concept", []):
by_code[c["code"]] = c
return by_code
def load_cocoa_entries():
"""Charger les entrées CoCoA parsées."""
cocoa_path = OUT / "cocoa_entries_debug.json"
if not cocoa_path.exists():
return {}
with open(cocoa_path) as f:
return json.load(f)
def clean_extrait(extrait):
"""Nettoyer un extrait FAISS (enlever bruit OCR, numéros de page, etc.)."""
import re
# Enlever les numéros de page isolés (sur leur propre ligne ou collés)
extrait = re.sub(r'\n\s*\d{1,4}\s*\n', '\n', extrait)
extrait = re.sub(r'^\d{1,4}\s*\n', '', extrait)
# Enlever les transitions de chapitre
extrait = re.sub(r'Chapitre\s+[IVX]+\b.*', '', extrait)
# Enlever les lignes de classification
extrait = re.sub(r'Classification Internationale.*$', '', extrait, flags=re.MULTILINE)
# Enlever les lignes vides multiples
extrait = re.sub(r'\n{2,}', '\n', extrait)
# Tronquer au premier code d'une AUTRE catégorie
lines = extrait.split('\n')
first_code = None
clean_lines = []
for line in lines:
stripped = line.strip()
# Ligne ne contenant qu'un nombre = bruit
if re.match(r'^\d{1,4}$', stripped):
continue
m = re.match(r'^([A-Z]\d{2}(?:\.\d{1,2})?)[*†]?\s', stripped)
if m:
code = m.group(1)
if first_code is None:
first_code = code
elif not code.startswith(first_code[:3]):
break # On est passé à un autre groupe de codes
clean_lines.append(stripped)
result = '\n'.join(clean_lines).strip()
# Limiter la longueur
if len(result) > 400:
result = result[:400].rsplit('\n', 1)[0]
return result
def select_codes_for_scenarios(metadata, n=500):
"""Sélectionner les codes CIM-10 les plus intéressants pour la génération."""
# Filtrer les entrées CIM-10 avec des extraits substantiels
cim10_entries = [m for m in metadata if m.get("document") == "cim10" and len(m.get("extrait", "")) > 20]
# Prioriser les codes avec des extraits riches
cim10_entries.sort(key=lambda x: len(x.get("extrait", "")), reverse=True)
# Prendre les codes uniques
seen = set()
selected = []
for m in cim10_entries:
code = m["code"]
if code not in seen and "." in code: # Préférer les sous-codes (plus spécifiques)
seen.add(code)
selected.append(m)
if len(selected) >= n:
break
# Si pas assez de sous-codes, ajouter des catégories
if len(selected) < n:
for m in cim10_entries:
code = m["code"]
if code not in seen:
seen.add(code)
selected.append(m)
if len(selected) >= n:
break
random.shuffle(selected)
return selected[:n]
def select_discrimination_groups(metadata, cocoa_entries, n=100):
"""Sélectionner des groupes de codes proches pour la discrimination."""
# Grouper par catégorie parente (3 premiers caractères)
by_parent = {}
for m in metadata:
if m.get("document") != "cim10":
continue
code = m.get("code", "")
if "." in code:
parent = code.split(".")[0]
by_parent.setdefault(parent, []).append(m)
# Sélectionner les groupes avec 2-6 sous-codes
groups = []
for parent, children in by_parent.items():
if 2 <= len(children) <= 6:
groups.append({
"parent": parent,
"codes": [{"code": c["code"], "extrait": clean_extrait(c["extrait"])[:150]} for c in children]
})
random.shuffle(groups)
return groups[:n]
def build_scenario_prompt(codes_batch, fhir_by_code, cocoa_entries):
"""Construire le prompt pour un batch de codes (scénarios cliniques)."""
items = []
for meta in codes_batch:
code = meta["code"]
# Source primaire : FHIR (propre)
fhir = fhir_by_code.get(code, {})
display = fhir.get("display", "")
# Source secondaire : CoCoA (riche)
cocoa = cocoa_entries.get(code, {})
cocoa_desc = cocoa.get("description", "")
exclusions = cocoa.get("exclusions", [])[:4]
synonyms = cocoa.get("synonyms", [])[:5]
comprend = cocoa.get("comprend", [])[:3]
severity = cocoa.get("severity")
clinical = " ".join(cocoa.get("clinical_text", []))[:200]
# Utiliser la meilleure description disponible
desc = display or cocoa_desc or clean_extrait(meta["extrait"])[:200]
item = f"CODE: {code}\nLIBELLÉ: {desc}"
if synonyms:
item += f"\nSYNONYMES: {'; '.join(synonyms)}"
if comprend:
item += f"\nCOMPREND: {'; '.join(comprend)}"
if exclusions:
item += f"\nEXCLUSIONS: {'; '.join(exclusions)}"
if severity:
item += f"\nSÉVÉRITÉ CMA: {severity}"
if clinical:
item += f"\nDESCRIPTION CLINIQUE: {clinical}"
items.append(item)
codes_text = "\n---\n".join(items)
prompt = f"""Génère un scénario clinique et un raisonnement DIM pour chacun des {len(codes_batch)} codes suivants.
{codes_text}
Réponds en JSON avec cette structure exacte :
{{"scenarios": [
{{
"code": "le code CIM-10",
"scenario_clinique": "Texte du scénario clinique réaliste (3-5 phrases)",
"raisonnement": {{
"analyse_clinique": "Analyse des éléments cliniques pertinents",
"codes_candidats": "2-3 codes envisagés avec libellés",
"discrimination": "Pourquoi ce code et pas les autres",
"regle_pmsi": "Règle PMSI applicable",
"code_retenu": "le code",
"confidence": "high",
"justification": "Synthèse en 1 phrase"
}}
}},
...
]}}
Le tableau "scenarios" DOIT contenir exactement {len(codes_batch)} objets, un par code."""
return prompt
def build_discrimination_prompt(group, fhir_by_code):
"""Construire le prompt pour un exercice de discrimination."""
codes_text = "\n".join(
f"- {c['code']}: {c['extrait']}"
for c in group["codes"]
)
prompt = f"""Voici un groupe de codes CIM-10 de la catégorie {group['parent']} :
{codes_text}
Génère UN scénario clinique réaliste où le choix entre ces codes est subtil et demande une réflexion.
Puis montre le raisonnement DIM complet pour arriver au bon code.
Réponds en JSON :
{{
"scenario_clinique": "Le scénario (5-8 phrases, réaliste, anonymisé)",
"codes_en_jeu": ["code1", "code2"],
"raisonnement": {{
"analyse_clinique": "...",
"codes_candidats": "...",
"discrimination": "Explication détaillée de pourquoi un code est préféré",
"regle_pmsi": "Règle PMSI applicable",
"code_retenu": "le code correct",
"confidence": "high/medium",
"justification": "..."
}}
}}
JSON valide uniquement."""
return prompt
def call_openai(client, model, system_prompt, user_prompt, temperature=0.7):
"""Appeler l'API OpenAI."""
response = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
temperature=temperature,
max_tokens=4096,
response_format={"type": "json_object"},
)
content = response.choices[0].message.content
return json.loads(content)
def _extract_items(result):
"""Extraire la liste d'items depuis la réponse JSON (gère différents formats)."""
if isinstance(result, list):
return result
if isinstance(result, dict):
# Chercher récursivement un tableau d'items
for v in result.values():
if isinstance(v, list) and v and isinstance(v[0], dict):
return v
# Si pas de tableau, c'est peut-être un seul item
if "scenario_clinique" in result or "raisonnement" in result:
return [result]
# Dernier recours : aplatir les valeurs dict
items = []
for v in result.values():
if isinstance(v, dict) and ("scenario_clinique" in v or "raisonnement" in v):
items.append(v)
if items:
return items
return []
def convert_to_chatml(scenario_data):
"""Convertir un résultat de génération en format ChatML."""
if not isinstance(scenario_data, dict):
return None
system_msg = "Tu es un médecin DIM expert en codage CIM-10 pour le PMSI français. Tu codes les diagnostics en suivant une démarche structurée."
scenario = scenario_data.get("scenario_clinique", "")
raisonnement = scenario_data.get("raisonnement", {})
# Si pas de scenario_clinique, chercher dans d'autres clés possibles
if not scenario:
scenario = scenario_data.get("scenario", scenario_data.get("texte_clinique", ""))
# Si le raisonnement est directement dans le dict (pas imbriqué)
if not raisonnement and "analyse_clinique" in scenario_data:
raisonnement = {k: v for k, v in scenario_data.items() if k != "scenario_clinique"}
if not scenario or not raisonnement:
return None
user_msg = f"Code ce diagnostic en CIM-10.\n\nTEXTE CLINIQUE : {scenario}"
assistant_msg = json.dumps(raisonnement, ensure_ascii=False)
return {
"messages": [
{"role": "system", "content": system_msg},
{"role": "user", "content": user_msg},
{"role": "assistant", "content": assistant_msg},
]
}
def process_scenario_batch(client, model, batch, fhir_by_code, cocoa_entries, batch_idx):
"""Traiter un batch de codes pour générer des scénarios."""
prompt = build_scenario_prompt(batch, fhir_by_code, cocoa_entries)
try:
result = call_openai(client, model, SYSTEM_PROMPT_SCENARIO, prompt)
# Le résultat peut être un tableau ou un dict avec une clé contenant le tableau
items = _extract_items(result)
examples = []
for item in items:
chatml = convert_to_chatml(item)
if chatml:
examples.append(chatml)
else:
print(f" [Batch {batch_idx}] Item non converti: {list(item.keys()) if isinstance(item, dict) else type(item)}")
if len(examples) < len(batch):
print(f" [Batch {batch_idx}] {len(examples)}/{len(batch)} exemples récupérés")
return examples
except Exception as e:
print(f" [Batch {batch_idx}] Erreur: {e}")
return []
def process_discrimination_batch(client, model, group, fhir_by_code, batch_idx):
"""Traiter un groupe pour générer un exercice de discrimination."""
prompt = build_discrimination_prompt(group, fhir_by_code)
try:
result = call_openai(client, model, SYSTEM_PROMPT_DISCRIM, prompt)
chatml = convert_to_chatml(result)
return [chatml] if chatml else []
except Exception as e:
print(f" [Discrim {batch_idx}] Erreur: {e}")
return []
def main():
parser = argparse.ArgumentParser(description="Génération de données synthétiques via OpenAI")
parser.add_argument("--n", type=int, default=500, help="Nombre de scénarios à générer")
parser.add_argument("--n-discrim", type=int, default=100, help="Nombre d'exercices de discrimination")
parser.add_argument("--batch", type=int, default=5, help="Codes par batch (scénarios)")
parser.add_argument("--model", default="gpt-4o", help="Modèle OpenAI")
parser.add_argument("--workers", type=int, default=3, help="Workers parallèles")
parser.add_argument("--dry-run", action="store_true", help="Afficher les prompts sans appeler l'API")
parser.add_argument("--resume", action="store_true", help="Reprendre depuis la dernière exécution")
args = parser.parse_args()
# Vérifier la clé API
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key and not args.dry_run:
print("Erreur: OPENAI_API_KEY non définie.")
print(" export OPENAI_API_KEY='sk-...'")
sys.exit(1)
# Charger les données
print("Chargement des données...")
metadata = load_faiss_metadata()
fhir_by_code = load_cim10_fhir()
cocoa_entries = load_cocoa_entries()
print(f" FAISS: {len(metadata)} entrées")
print(f" FHIR: {len(fhir_by_code)} concepts")
print(f" CoCoA: {len(cocoa_entries)} entrées")
# Sélectionner les codes
print(f"\nSélection de {args.n} codes pour les scénarios...")
selected_codes = select_codes_for_scenarios(metadata, n=args.n)
print(f" {len(selected_codes)} codes sélectionnés")
print(f"Sélection de {args.n_discrim} groupes pour la discrimination...")
discrim_groups = select_discrimination_groups(metadata, cocoa_entries, n=args.n_discrim)
print(f" {len(discrim_groups)} groupes sélectionnés")
# Découper en batches
scenario_batches = [
selected_codes[i:i + args.batch]
for i in range(0, len(selected_codes), args.batch)
]
print(f"\n{len(scenario_batches)} batches de scénarios ({args.batch} codes/batch)")
print(f"{len(discrim_groups)} exercices de discrimination")
# Fichier de sortie (avec reprise possible)
output_path = OUT / "synthetic_chatml.jsonl"
existing_count = 0
if args.resume and output_path.exists():
with open(output_path) as f:
existing_count = sum(1 for _ in f)
print(f"\nReprise: {existing_count} exemples existants")
if args.dry_run:
# Mode dry-run : montrer des exemples de prompts
print("\n=== DRY RUN ===")
print("\n--- Exemple de prompt scénario ---")
prompt = build_scenario_prompt(scenario_batches[0], fhir_by_code, cocoa_entries)
print(prompt[:2000])
print("\n--- Exemple de prompt discrimination ---")
if discrim_groups:
prompt = build_discrimination_prompt(discrim_groups[0], fhir_by_code)
print(prompt[:1500])
return
# Initialiser le client OpenAI
from openai import OpenAI
client = OpenAI(api_key=api_key)
all_examples = []
total_batches = len(scenario_batches) + len(discrim_groups)
completed = 0
errors = 0
# Ouvrir le fichier en mode append
mode = "a" if args.resume else "w"
with open(output_path, mode) as fh:
# Phase 1 : Scénarios cliniques
print(f"\n{'='*50}")
print(f"Phase 1 : Génération des scénarios cliniques...")
print(f"{'='*50}")
with ThreadPoolExecutor(max_workers=args.workers) as executor:
futures = {}
for i, batch in enumerate(scenario_batches):
future = executor.submit(
process_scenario_batch, client, args.model,
batch, fhir_by_code, cocoa_entries, i
)
futures[future] = i
for future in as_completed(futures):
batch_idx = futures[future]
try:
examples = future.result()
for ex in examples:
fh.write(json.dumps(ex, ensure_ascii=False) + "\n")
all_examples.append(ex)
completed += 1
if completed % 10 == 0:
print(f" [{completed}/{len(scenario_batches)}] {len(all_examples)} exemples générés...")
except Exception as e:
errors += 1
print(f" [Batch {batch_idx}] Exception: {e}")
print(f" Scénarios: {len(all_examples)} exemples")
# Phase 2 : Discrimination
print(f"\n{'='*50}")
print(f"Phase 2 : Génération des exercices de discrimination...")
print(f"{'='*50}")
discrim_count = 0
with ThreadPoolExecutor(max_workers=args.workers) as executor:
futures = {}
for i, group in enumerate(discrim_groups):
future = executor.submit(
process_discrimination_batch, client, args.model,
group, fhir_by_code, i
)
futures[future] = i
for future in as_completed(futures):
batch_idx = futures[future]
try:
examples = future.result()
for ex in examples:
fh.write(json.dumps(ex, ensure_ascii=False) + "\n")
all_examples.append(ex)
discrim_count += 1
except Exception as e:
errors += 1
print(f" [Discrim {batch_idx}] Exception: {e}")
print(f" Discrimination: {discrim_count} exemples")
# Stats finales
total = len(all_examples) + existing_count
print(f"\n{'='*50}")
print(f"Génération terminée !")
print(f" Nouveaux exemples : {len(all_examples)}")
if existing_count:
print(f" Existants (reprise): {existing_count}")
print(f" Total : {total}")
print(f" Erreurs : {errors}")
print(f" Fichier : {output_path}")
if output_path.exists():
print(f" Taille : {output_path.stat().st_size / 1024:.0f} Ko")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,67 @@
#!/bin/bash
# Phase 2 — Installation d'Unsloth et dépendances pour le fine-tuning
#
# Prérequis vérifiés :
# - PyTorch 2.10.0+cu128 ✅
# - CUDA 12.x ✅
# - transformers 4.57.6 ✅
# - accelerate 1.12.0 ✅
#
# Usage : bash scripts/07_setup_unsloth.sh
set -e
VENV="/home/dom/ai/t2a/.venv"
PIP="$VENV/bin/pip"
PYTHON="$VENV/bin/python3"
echo "=== Installation Unsloth + dépendances ==="
echo ""
# 1. bitsandbytes (quantification 4-bit)
echo "[1/4] Installation de bitsandbytes..."
$PIP install --upgrade bitsandbytes
# 2. PEFT (LoRA)
echo ""
echo "[2/4] Installation de PEFT..."
$PIP install --upgrade peft
# 3. TRL (SFTTrainer)
echo ""
echo "[3/4] Installation de TRL..."
$PIP install --upgrade trl
# 4. Unsloth
echo ""
echo "[4/4] Installation d'Unsloth..."
$PIP install --upgrade --no-cache-dir "unsloth[cu128-torch2100] @ git+https://github.com/unslothai/unsloth.git"
# Vérification
echo ""
echo "=== Vérification ==="
$PYTHON -c "
import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA: {torch.cuda.is_available()}')
print(f'GPU: {torch.cuda.get_device_name(0)}')
import bitsandbytes
print(f'bitsandbytes: OK')
from peft import LoraConfig
print(f'PEFT: OK')
from trl import SFTTrainer
print(f'TRL: OK')
try:
from unsloth import FastLanguageModel
print(f'Unsloth: OK')
except Exception as e:
print(f'Unsloth: {e}')
print('Essayez: pip install unsloth')
print()
print('=== Setup prêt pour le fine-tuning ! ===')
"

331
scripts/08_train_lora.py Normal file
View File

@@ -0,0 +1,331 @@
#!/usr/bin/env python3
"""
Phase 2 — Fine-tuning QLoRA de gemma3:12b avec Unsloth.
IMPORTANT : Arrêter Ollama avant de lancer ce script !
sudo systemctl stop ollama
Le script :
1. Charge gemma3:12b en 4-bit quantifié
2. Attache un adaptateur LoRA
3. Entraîne sur le dataset PMSI (ChatML)
4. Sauvegarde l'adaptateur LoRA
5. (Optionnel) Exporte en GGUF pour Ollama
Prérequis :
- bash scripts/07_setup_unsloth.sh (installer les dépendances)
- data/datasets/pmsi_train.jsonl + pmsi_eval.jsonl
Usage :
python scripts/08_train_lora.py [--epochs 3] [--lr 2e-4] [--batch 1] [--export-gguf]
Hardware cible : RTX 5070 (12 Go VRAM)
- QLoRA 4-bit sur 12B ≈ 8-9 Go VRAM
- batch_size=1 + gradient_accumulation=8 → batch effectif de 8
- gradient_checkpointing pour économiser la VRAM
"""
import argparse
import json
import os
from pathlib import Path
BASE = Path(__file__).resolve().parent.parent
DATASETS = BASE / "data" / "datasets"
OUTPUT = BASE / "models"
OUTPUT.mkdir(parents=True, exist_ok=True)
def check_prerequisites():
"""Vérifier que tout est prêt."""
import torch
# GPU disponible ?
if not torch.cuda.is_available():
raise RuntimeError("CUDA non disponible. Vérifiez votre installation GPU.")
gpu_name = torch.cuda.get_device_name(0)
vram_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
vram_free = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / 1024**3
print(f"GPU: {gpu_name}")
print(f"VRAM: {vram_total:.1f} Go total, {vram_free:.1f} Go libre")
if vram_free < 10:
print("⚠ Moins de 10 Go de VRAM libre.")
print(" → Arrêtez Ollama : sudo systemctl stop ollama")
print(" → Ou utilisez --offload-cpu pour décharger sur la RAM")
# Dataset existe ?
train_path = DATASETS / "pmsi_train.jsonl"
eval_path = DATASETS / "pmsi_eval.jsonl"
if not train_path.exists() or not eval_path.exists():
raise FileNotFoundError(
f"Dataset non trouvé. Lancez d'abord : python scripts/04_build_dataset.py"
)
# Compter les exemples
with open(train_path) as f:
n_train = sum(1 for _ in f)
with open(eval_path) as f:
n_eval = sum(1 for _ in f)
print(f"Dataset: {n_train} train + {n_eval} eval")
return train_path, eval_path
def load_model(model_name, max_seq_length, load_in_4bit=True):
"""Charger le modèle avec Unsloth."""
from unsloth import FastLanguageModel
print(f"\nChargement de {model_name} (4-bit={load_in_4bit})...")
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=model_name,
max_seq_length=max_seq_length,
dtype=None, # Auto-détection
load_in_4bit=load_in_4bit,
)
print(f" Modèle chargé : {model.config._name_or_path}")
print(f" Paramètres : {model.num_parameters() / 1e9:.1f}B")
return model, tokenizer
def attach_lora(model, r=32, alpha=64, dropout=0.05):
"""Attacher l'adaptateur LoRA."""
from unsloth import FastLanguageModel
print(f"\nAttachement LoRA (r={r}, alpha={alpha}, dropout={dropout})...")
model = FastLanguageModel.get_peft_model(
model,
r=r,
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
],
lora_alpha=alpha,
lora_dropout=dropout,
bias="none",
use_gradient_checkpointing="unsloth", # Économise 30% VRAM
random_state=42,
)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f" Paramètres entraînables : {trainable / 1e6:.1f}M / {total / 1e9:.1f}B ({100 * trainable / total:.2f}%)")
return model
def load_dataset(train_path, eval_path):
"""Charger le dataset au format ChatML."""
from datasets import Dataset
def load_jsonl(path):
examples = []
with open(path) as f:
for line in f:
examples.append(json.loads(line.strip()))
return examples
train_data = load_jsonl(train_path)
eval_data = load_jsonl(eval_path)
train_ds = Dataset.from_list(train_data)
eval_ds = Dataset.from_list(eval_data)
print(f"\nDataset chargé :")
print(f" Train : {len(train_ds)} exemples")
print(f" Eval : {len(eval_ds)} exemples")
return train_ds, eval_ds
def format_chat(example, tokenizer):
"""Formater un exemple ChatML pour l'entraînement."""
messages = example["messages"]
# Utiliser le template de chat du tokenizer
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False,
)
return {"text": text}
def train(model, tokenizer, train_ds, eval_ds, args):
"""Lancer l'entraînement."""
from trl import SFTTrainer, SFTConfig
from aim.hugging_face import AimCallback
print(f"\nConfiguration d'entraînement :")
print(f" Epochs : {args.epochs}")
print(f" Learning rate : {args.lr}")
print(f" Batch size : {args.batch} (gradient_accumulation={args.grad_accum})")
print(f" Batch effectif : {args.batch * args.grad_accum}")
print(f" Max seq length : {args.max_seq_length}")
# Formater le dataset
train_ds = train_ds.map(lambda x: format_chat(x, tokenizer), num_proc=4)
eval_ds = eval_ds.map(lambda x: format_chat(x, tokenizer), num_proc=4)
output_dir = OUTPUT / "pmsi-lora-checkpoints"
# Callback Aim pour le tracking des métriques
aim_callback = AimCallback(
repo=str(BASE),
experiment="pmsi-coder-v2",
)
training_args = SFTConfig(
output_dir=str(output_dir),
num_train_epochs=args.epochs,
per_device_train_batch_size=args.batch,
per_device_eval_batch_size=args.batch,
gradient_accumulation_steps=args.grad_accum,
learning_rate=args.lr,
weight_decay=0.01,
warmup_ratio=0.05,
lr_scheduler_type="cosine",
logging_steps=10,
eval_strategy="steps",
eval_steps=1000,
save_strategy="steps",
save_steps=500,
save_total_limit=3,
fp16=False,
bf16=True,
max_seq_length=args.max_seq_length,
dataset_text_field="text",
seed=42,
report_to="none",
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=train_ds,
eval_dataset=eval_ds,
args=training_args,
callbacks=[aim_callback],
)
print(f"\nDémarrage de l'entraînement...")
print(f" Output : {output_dir}")
print(f" Steps estimés : ~{len(train_ds) * args.epochs // (args.batch * args.grad_accum)}")
if args.resume:
print(f" Reprise depuis le dernier checkpoint...")
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
# Sauvegarder le modèle final
final_dir = OUTPUT / "pmsi-lora-final"
model.save_pretrained(str(final_dir))
tokenizer.save_pretrained(str(final_dir))
print(f"\nModèle LoRA sauvegardé : {final_dir}")
return trainer, final_dir
def export_gguf(model, tokenizer, final_dir, quantization="q4_k_m"):
"""Exporter en GGUF pour Ollama."""
print(f"\nExport GGUF ({quantization})...")
gguf_dir = OUTPUT / "pmsi-gguf"
gguf_dir.mkdir(parents=True, exist_ok=True)
# Unsloth export
model.save_pretrained_gguf(
str(gguf_dir),
tokenizer,
quantization_method=quantization,
)
# Trouver le fichier GGUF
gguf_files = list(gguf_dir.glob("*.gguf"))
if gguf_files:
gguf_path = gguf_files[0]
print(f" GGUF exporté : {gguf_path} ({gguf_path.stat().st_size / 1024**3:.1f} Go)")
# Créer le Modelfile pour Ollama
modelfile_path = gguf_dir / "Modelfile"
modelfile_content = f"""FROM {gguf_path.name}
PARAMETER temperature 0.3
PARAMETER top_p 0.9
PARAMETER num_ctx 8192
"""
with open(modelfile_path, "w") as f:
f.write(modelfile_content)
print(f" Modelfile créé : {modelfile_path}")
print(f"\n Pour importer dans Ollama :")
print(f" cd {gguf_dir}")
print(f" ollama create pmsi-coder -f Modelfile")
else:
print(" Aucun fichier GGUF trouvé !")
def main():
parser = argparse.ArgumentParser(description="Fine-tuning QLoRA avec Unsloth")
# Modèle
parser.add_argument("--model", default="unsloth/gemma-3-12b-it-bnb-4bit",
help="Nom du modèle HuggingFace")
parser.add_argument("--max-seq-length", type=int, default=512,
help="Longueur max des séquences")
# LoRA
parser.add_argument("--lora-r", type=int, default=32, help="Rang LoRA")
parser.add_argument("--lora-alpha", type=int, default=64, help="Alpha LoRA")
parser.add_argument("--lora-dropout", type=float, default=0.0, help="Dropout LoRA (0=fast patching Unsloth)")
# Entraînement
parser.add_argument("--epochs", type=int, default=3, help="Nombre d'epochs")
parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate")
parser.add_argument("--batch", type=int, default=1, help="Batch size par GPU")
parser.add_argument("--grad-accum", type=int, default=8, help="Gradient accumulation steps")
# Resume
parser.add_argument("--resume", action="store_true", help="Reprendre depuis le dernier checkpoint")
# Export
parser.add_argument("--export-gguf", action="store_true", help="Exporter en GGUF après entraînement")
parser.add_argument("--gguf-quant", default="q4_k_m", help="Méthode de quantification GGUF")
args = parser.parse_args()
# Vérifications
train_path, eval_path = check_prerequisites()
# Charger le modèle
model, tokenizer = load_model(args.model, args.max_seq_length)
# Attacher LoRA
model = attach_lora(model, r=args.lora_r, alpha=args.lora_alpha, dropout=args.lora_dropout)
# Charger le dataset
train_ds, eval_ds = load_dataset(train_path, eval_path)
# Entraîner
trainer, final_dir = train(model, tokenizer, train_ds, eval_ds, args)
# Export GGUF optionnel
if args.export_gguf:
export_gguf(model, tokenizer, final_dir, args.gguf_quant)
print("\n" + "=" * 50)
print("Fine-tuning terminé !")
print(f" Adaptateur LoRA : {final_dir}")
if args.export_gguf:
print(f" GGUF : {OUTPUT / 'pmsi-gguf'}")
print("=" * 50)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,460 @@
#!/usr/bin/env python3
"""
Phase 3 — Génération de triplets (anchor, positive, negative) pour fine-tuning d'embedding.
Sources :
1. Cache Ollama (1840 paires diagnostic → code CIM-10)
2. Métadonnées FAISS (12444 entrées CIM-10 avec extraits)
3. CoCoA (synonymes, exclusions)
4. Données CIM-10 FHIR (descriptions → codes)
Format de sortie : JSONL avec {anchor, positive, negative}
- anchor : texte clinique / diagnostic tel qu'écrit par le médecin
- positive : extrait de référence CIM-10 correspondant au bon code
- negative : extrait de référence CIM-10 d'un code similaire mais incorrect (hard negative)
Usage :
python scripts/09_build_embedding_triplets.py [--output data/datasets/triplets.jsonl]
"""
import argparse
import json
import random
from collections import defaultdict
from pathlib import Path
random.seed(42)
BASE = Path(__file__).resolve().parent.parent
T2A = BASE.parent / "t2a"
# Sources
CACHE_PATH = T2A / "data" / "ollama_cache.json"
METADATA_PATH = T2A / "data" / "rag_index" / "metadata.json"
COCOA_PATH = BASE / "data" / "processed" / "cocoa_entries_debug.json"
CIM10_CHATML = BASE / "data" / "processed" / "cim10_chatml.jsonl"
# Output
DEFAULT_OUTPUT = BASE / "data" / "datasets" / "triplets.jsonl"
def load_faiss_metadata() -> tuple[dict[str, list[dict]], list[dict]]:
"""Charge les métadonnées FAISS et indexe par code CIM-10.
Returns:
(code_to_entries, all_cim10_entries)
"""
with open(METADATA_PATH) as f:
metadata = json.load(f)
code_to_entries = defaultdict(list)
all_cim10 = []
for m in metadata:
if m.get("document") not in ("cim10", "cim10_alpha"):
continue
code = m.get("code", "").strip()
if not code:
continue
all_cim10.append(m)
# Normaliser le code (avec et sans point)
code_to_entries[code].append(m)
# Aussi indexer par préfixe 3 caractères pour les hard negatives
# Ex: I10 → chapitre I, catégorie I10
print(f"FAISS metadata : {len(all_cim10)} entrées CIM-10, {len(code_to_entries)} codes uniques")
return code_to_entries, all_cim10
def get_chapter(code: str) -> str:
"""Extrait le chapitre CIM-10 (première lettre)."""
return code[0] if code else ""
def get_category(code: str) -> str:
"""Extrait la catégorie CIM-10 (3 premiers caractères, ex: I10, K80)."""
clean = code.replace(".", "")
return clean[:3] if len(clean) >= 3 else code
def find_hard_negatives(
code: str,
code_to_entries: dict[str, list[dict]],
all_cim10: list[dict],
n: int = 3,
) -> list[dict]:
"""Trouve des hard negatives : codes proches mais différents.
Stratégie par priorité :
1. Même catégorie (ex: K80.0 vs K80.2) — le plus dur
2. Même chapitre (ex: K80 vs K81) — difficile
3. Random d'un autre chapitre — facile (pour contraste)
"""
category = get_category(code)
chapter = get_chapter(code)
negatives = []
# 1. Même catégorie (siblings)
siblings = []
for c, entries in code_to_entries.items():
if c != code and get_category(c) == category:
siblings.extend(entries)
if siblings:
random.shuffle(siblings)
negatives.extend(siblings[:n])
# 2. Même chapitre si pas assez
if len(negatives) < n:
chapter_entries = []
for c, entries in code_to_entries.items():
if c != code and get_chapter(c) == chapter and get_category(c) != category:
chapter_entries.extend(entries)
if chapter_entries:
random.shuffle(chapter_entries)
negatives.extend(chapter_entries[: n - len(negatives)])
# 3. Random si toujours pas assez
if len(negatives) < n:
others = [m for m in all_cim10 if get_category(m.get("code", "")) != category]
if others:
random.shuffle(others)
negatives.extend(others[: n - len(negatives)])
return negatives[:n]
def format_entry_text(entry: dict) -> str:
"""Formate une entrée FAISS en texte lisible pour l'embedding."""
code = entry.get("code", "")
extrait = entry.get("extrait", "").strip()
if code and extrait:
# S'assurer que le code est dans le texte
if not extrait.startswith(code):
return f"{code} {extrait}"
return extrait
def triplets_from_cache(
code_to_entries: dict[str, list[dict]],
all_cim10: list[dict],
) -> list[dict]:
"""Génère des triplets à partir du cache Ollama."""
with open(CACHE_PATH) as f:
cache = json.load(f)
entries = cache.get("entries", cache)
if isinstance(entries, str):
return []
triplets = []
skipped = {"no_code": 0, "no_match": 0, "no_neg": 0}
for key, val in entries.items():
if not isinstance(val, dict):
continue
code = val.get("code", "")
if not code:
skipped["no_code"] += 1
continue
# Extraire le texte du diagnostic depuis la clé
# Format: "das::hypertension artérielle" ou "dp::pancréatite aiguë"
parts = key.split("::", 1)
if len(parts) != 2:
continue
diag_type, diag_text = parts
if diag_type == "das_llm":
continue # Format différent, skip
# Normaliser le code
code_clean = code.strip().upper()
# Trouver le positive dans FAISS
positive_entries = code_to_entries.get(code_clean, [])
if not positive_entries:
# Essayer sans le point
code_no_dot = code_clean.replace(".", "")
for c, ents in code_to_entries.items():
if c.replace(".", "") == code_no_dot:
positive_entries = ents
break
if not positive_entries:
skipped["no_match"] += 1
continue
positive = random.choice(positive_entries)
positive_text = format_entry_text(positive)
# Hard negatives
negs = find_hard_negatives(code_clean, code_to_entries, all_cim10, n=2)
if not negs:
skipped["no_neg"] += 1
continue
for neg in negs:
neg_text = format_entry_text(neg)
if neg_text and positive_text and neg_text != positive_text:
triplets.append({
"anchor": diag_text,
"positive": positive_text,
"negative": neg_text,
"source": "cache_ollama",
"code_pos": code_clean,
"code_neg": neg.get("code", ""),
})
print(f"Cache Ollama → {len(triplets)} triplets "
f"(skip: {skipped['no_code']} sans code, {skipped['no_match']} pas dans FAISS, "
f"{skipped['no_neg']} sans négatif)")
return triplets
def triplets_from_cocoa(
code_to_entries: dict[str, list[dict]],
all_cim10: list[dict],
) -> list[dict]:
"""Génère des triplets à partir de CoCoA (synonymes + exclusions)."""
if not COCOA_PATH.exists():
print("CoCoA debug file non trouvé, skip")
return []
with open(COCOA_PATH) as f:
cocoa = json.load(f)
# Le fichier debug est un dict {code: entry}, pas une liste
if isinstance(cocoa, dict):
cocoa = list(cocoa.values())
triplets = []
for entry in cocoa:
code = entry.get("code", "").strip()
desc = entry.get("description", "").strip()
synonymes = entry.get("synonyms", [])
exclusions = entry.get("exclusions", [])
if not code or not desc:
continue
# Positive = entrée FAISS pour ce code
positive_entries = code_to_entries.get(code, [])
if not positive_entries:
code_no_dot = code.replace(".", "")
for c, ents in code_to_entries.items():
if c.replace(".", "") == code_no_dot:
positive_entries = ents
break
if not positive_entries:
continue
positive = random.choice(positive_entries)
positive_text = format_entry_text(positive)
# Synonymes comme anchors supplémentaires
anchors = [desc] + [s for s in synonymes if len(s) > 5]
# Exclusions comme hard negatives explicites
explicit_negs = []
for excl in exclusions:
# Essayer d'extraire un code de l'exclusion (ex: "Diabète de type 2 (E11)")
import re
code_match = re.search(r"\(([A-Z]\d{2}(?:\.\d{1,2})?)\)", excl)
if code_match:
neg_code = code_match.group(1)
neg_entries = code_to_entries.get(neg_code, [])
if neg_entries:
explicit_negs.append(random.choice(neg_entries))
# Hard negatives auto si pas assez d'explicites
auto_negs = find_hard_negatives(code, code_to_entries, all_cim10, n=2)
all_negs = explicit_negs + auto_negs
if not all_negs:
continue
for anchor in anchors[:3]: # Max 3 anchors par code
for neg in all_negs[:2]: # Max 2 négatifs par anchor
neg_text = format_entry_text(neg)
if neg_text and positive_text and neg_text != positive_text:
triplets.append({
"anchor": anchor,
"positive": positive_text,
"negative": neg_text,
"source": "cocoa",
"code_pos": code,
"code_neg": neg.get("code", ""),
})
print(f"CoCoA → {len(triplets)} triplets")
return triplets
def triplets_from_cim10_fhir(
code_to_entries: dict[str, list[dict]],
all_cim10: list[dict],
) -> list[dict]:
"""Génère des triplets à partir du CIM-10 FHIR ChatML (description → code)."""
if not CIM10_CHATML.exists():
print("CIM-10 ChatML non trouvé, skip")
return []
triplets = []
seen_codes = set()
with open(CIM10_CHATML) as f:
for line in f:
example = json.loads(line.strip())
messages = example.get("messages", [])
# Extraire la question (user) et la réponse (assistant)
user_msg = ""
assistant_msg = ""
for m in messages:
if m["role"] == "user":
user_msg = m["content"]
elif m["role"] == "assistant":
assistant_msg = m["content"]
if not user_msg or not assistant_msg:
continue
# Parser le code de la réponse
try:
resp = json.loads(assistant_msg)
code = resp.get("code", "")
except json.JSONDecodeError:
continue
if not code or code in seen_codes:
continue
seen_codes.add(code)
# Extraire le texte du diagnostic depuis la question
# Format: "Quel est le code CIM-10 pour : Description"
# ou "Codage CIM-10 du diagnostic : Description"
import re
match = re.search(r"(?:pour|diagnostic)\s*:\s*(.+)", user_msg)
if not match:
continue
anchor = match.group(1).strip()
if len(anchor) < 5:
continue
# Positive
positive_entries = code_to_entries.get(code, [])
if not positive_entries:
continue
positive = random.choice(positive_entries)
positive_text = format_entry_text(positive)
# Hard negative
negs = find_hard_negatives(code, code_to_entries, all_cim10, n=1)
if not negs:
continue
neg = negs[0]
neg_text = format_entry_text(neg)
if neg_text and positive_text and neg_text != positive_text:
triplets.append({
"anchor": anchor,
"positive": positive_text,
"negative": neg_text,
"source": "cim10_fhir",
"code_pos": code,
"code_neg": neg.get("code", ""),
})
print(f"CIM-10 FHIR → {len(triplets)} triplets")
return triplets
def deduplicate(triplets: list[dict]) -> list[dict]:
"""Déduplique par (anchor, positive, negative)."""
seen = set()
deduped = []
for t in triplets:
key = (t["anchor"][:50], t["code_pos"], t["code_neg"])
if key not in seen:
seen.add(key)
deduped.append(t)
return deduped
def analyze_triplets(triplets: list[dict]) -> None:
"""Statistiques sur les triplets générés."""
sources = defaultdict(int)
hard_neg_types = {"same_category": 0, "same_chapter": 0, "diff_chapter": 0}
for t in triplets:
sources[t["source"]] += 1
code_pos = t["code_pos"]
code_neg = t["code_neg"]
if get_category(code_pos) == get_category(code_neg):
hard_neg_types["same_category"] += 1
elif get_chapter(code_pos) == get_chapter(code_neg):
hard_neg_types["same_chapter"] += 1
else:
hard_neg_types["diff_chapter"] += 1
print(f"\n=== Statistiques ===")
print(f"Total triplets : {len(triplets)}")
print(f"Par source :")
for src, count in sorted(sources.items()):
print(f" {src}: {count}")
print(f"Difficulté des négatifs :")
for neg_type, count in hard_neg_types.items():
pct = 100 * count / len(triplets) if triplets else 0
print(f" {neg_type}: {count} ({pct:.1f}%)")
def main():
parser = argparse.ArgumentParser(description="Génération de triplets pour embedding fine-tuning")
parser.add_argument("--output", type=Path, default=DEFAULT_OUTPUT,
help="Fichier de sortie JSONL")
args = parser.parse_args()
print("=== Génération de triplets pour embedding fine-tuning ===\n")
# Charger FAISS metadata
code_to_entries, all_cim10 = load_faiss_metadata()
# Générer les triplets depuis chaque source
all_triplets = []
all_triplets.extend(triplets_from_cache(code_to_entries, all_cim10))
all_triplets.extend(triplets_from_cocoa(code_to_entries, all_cim10))
all_triplets.extend(triplets_from_cim10_fhir(code_to_entries, all_cim10))
# Dédupliquer
all_triplets = deduplicate(all_triplets)
random.shuffle(all_triplets)
# Statistiques
analyze_triplets(all_triplets)
# Sauvegarder
args.output.parent.mkdir(parents=True, exist_ok=True)
with open(args.output, "w") as f:
for t in all_triplets:
# Format simplifié pour l'entraînement : anchor, positive, negative
out = {
"anchor": t["anchor"],
"positive": t["positive"],
"negative": t["negative"],
}
f.write(json.dumps(out, ensure_ascii=False) + "\n")
size_mo = args.output.stat().st_size / 1024 / 1024
print(f"\nSauvegardé : {args.output} ({len(all_triplets)} triplets, {size_mo:.1f} Mo)")
# Aussi sauvegarder version détaillée pour debug
debug_path = args.output.with_suffix(".debug.jsonl")
with open(debug_path, "w") as f:
for t in all_triplets[:50]:
f.write(json.dumps(t, ensure_ascii=False, indent=2) + "\n---\n")
print(f"Debug (50 premiers) : {debug_path}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,263 @@
#!/usr/bin/env python3
"""
Phase 3 — Fine-tuning du modèle d'embedding pour le RAG médical.
Fine-tune sentence-camembert-large sur les triplets PMSI/CIM-10
pour améliorer la recherche FAISS dans le pipeline T2A.
Prérequis :
- python scripts/09_build_embedding_triplets.py
- data/datasets/triplets.jsonl (41K+ triplets)
Usage :
python scripts/10_train_embedding.py [--epochs 3] [--batch 32] [--lr 2e-5] [--device cpu]
--device cpu : entraîner sur CPU (si le GPU est occupé par le LoRA)
--device cuda : entraîner sur GPU (plus rapide, ~30 min)
Hardware :
- CPU : ~2-3h (Ryzen 9 9950X)
- GPU : ~30 min (RTX 5070, si disponible)
- RAM : ~4 Go (modèle ~1.3 Go + données)
"""
import argparse
import json
import random
from pathlib import Path
random.seed(42)
BASE = Path(__file__).resolve().parent.parent
TRIPLETS_PATH = BASE / "data" / "datasets" / "triplets.jsonl"
OUTPUT_DIR = BASE / "models" / "sentence-camembert-pmsi"
# Modèle de base (même que dans le pipeline T2A)
BASE_MODEL = "dangvantuan/sentence-camembert-large"
def load_triplets(path: Path, eval_ratio: float = 0.1):
"""Charge les triplets et split train/eval."""
from datasets import Dataset
anchors, positives, negatives = [], [], []
with open(path) as f:
for line in f:
t = json.loads(line.strip())
anchors.append(t["anchor"])
positives.append(t["positive"])
negatives.append(t["negative"])
total = len(anchors)
print(f"Triplets chargés : {total}")
# Shuffle déterministe
indices = list(range(total))
random.shuffle(indices)
split = int(total * (1 - eval_ratio))
train_idx = indices[:split]
eval_idx = indices[split:]
train_ds = Dataset.from_dict({
"anchor": [anchors[i] for i in train_idx],
"positive": [positives[i] for i in train_idx],
"negative": [negatives[i] for i in train_idx],
})
eval_ds = Dataset.from_dict({
"anchor": [anchors[i] for i in eval_idx],
"positive": [positives[i] for i in eval_idx],
"negative": [negatives[i] for i in eval_idx],
})
print(f" Train : {len(train_ds)}")
print(f" Eval : {len(eval_ds)}")
return train_ds, eval_ds
def main():
parser = argparse.ArgumentParser(description="Fine-tuning embedding PMSI")
parser.add_argument("--model", default=BASE_MODEL,
help="Modèle de base sentence-transformers")
parser.add_argument("--triplets", type=Path, default=TRIPLETS_PATH,
help="Fichier de triplets JSONL")
parser.add_argument("--output", type=Path, default=OUTPUT_DIR,
help="Répertoire de sortie")
# Entraînement
parser.add_argument("--epochs", type=int, default=3, help="Nombre d'epochs")
parser.add_argument("--batch", type=int, default=32, help="Batch size")
parser.add_argument("--lr", type=float, default=2e-5, help="Learning rate")
parser.add_argument("--warmup-ratio", type=float, default=0.1, help="Warmup ratio")
parser.add_argument("--max-seq-length", type=int, default=256,
help="Longueur max des séquences")
# Loss
parser.add_argument("--loss", choices=["mnrl", "triplet"], default="mnrl",
help="Fonction de loss (mnrl=MultipleNegativesRankingLoss, triplet=TripletLoss)")
# Device
parser.add_argument("--device", default=None,
help="Device (cuda/cpu). Auto-détection si non spécifié.")
args = parser.parse_args()
print("=" * 60)
print("Fine-tuning embedding pour RAG PMSI")
print("=" * 60)
# Vérifier les triplets
if not args.triplets.exists():
raise FileNotFoundError(
f"Triplets non trouvés : {args.triplets}\n"
f"Lancez d'abord : python scripts/09_build_embedding_triplets.py"
)
# Device
import torch
if args.device is None:
if torch.cuda.is_available():
# Vérifier si le GPU a assez de VRAM libre (~2 Go suffisent)
free = torch.cuda.mem_get_info()[0] / 1024**3
if free > 2.0:
args.device = "cuda"
print(f"GPU détecté : {torch.cuda.get_device_name(0)} ({free:.1f} Go libre)")
else:
args.device = "cpu"
print(f"GPU occupé ({free:.1f} Go libre) — fallback CPU")
else:
args.device = "cpu"
print("Pas de GPU — entraînement CPU")
# Charger le modèle
from sentence_transformers import SentenceTransformer
print(f"\nChargement du modèle : {args.model}")
model = SentenceTransformer(args.model, device=args.device)
model.max_seq_length = args.max_seq_length
print(f" Max seq length : {model.max_seq_length}")
print(f" Dimension embedding : {model.get_sentence_embedding_dimension()}")
# Charger les données
print()
train_ds, eval_ds = load_triplets(args.triplets)
# Loss
if args.loss == "mnrl":
from sentence_transformers.losses import MultipleNegativesRankingLoss
loss = MultipleNegativesRankingLoss(model)
print(f"\nLoss : MultipleNegativesRankingLoss (in-batch negatives + hard negatives)")
else:
from sentence_transformers.losses import TripletLoss
loss = TripletLoss(model)
print(f"\nLoss : TripletLoss")
# Evaluator
from sentence_transformers.evaluation import TripletEvaluator
evaluator = TripletEvaluator(
anchors=eval_ds["anchor"],
positives=eval_ds["positive"],
negatives=eval_ds["negative"],
name="pmsi-eval",
batch_size=args.batch,
)
# Évaluation baseline (avant fine-tuning)
print("\nÉvaluation baseline (avant fine-tuning)...")
baseline = evaluator(model)
print(f" Accuracy baseline : {baseline.get('pmsi-eval_cosine_accuracy', 'N/A')}")
# Training args
from sentence_transformers import SentenceTransformerTrainingArguments
n_steps = len(train_ds) * args.epochs // args.batch
eval_steps = max(n_steps // 10, 50) # Évaluer ~10 fois pendant l'entraînement
save_steps = max(n_steps // 5, 100) # Sauvegarder ~5 fois
training_args = SentenceTransformerTrainingArguments(
output_dir=str(args.output / "checkpoints"),
num_train_epochs=args.epochs,
per_device_train_batch_size=args.batch,
per_device_eval_batch_size=args.batch,
learning_rate=args.lr,
warmup_ratio=args.warmup_ratio,
lr_scheduler_type="cosine",
fp16=(args.device == "cuda"),
bf16=False,
logging_steps=50,
eval_strategy="steps",
eval_steps=eval_steps,
save_strategy="steps",
save_steps=save_steps,
save_total_limit=3,
load_best_model_at_end=True,
metric_for_best_model="pmsi-eval_cosine_accuracy",
greater_is_better=True,
seed=42,
report_to="aim",
)
print(f"\nConfiguration :")
print(f" Epochs : {args.epochs}")
print(f" Batch size : {args.batch}")
print(f" Learning rate : {args.lr}")
print(f" Steps estimés : ~{n_steps}")
print(f" Eval tous les : {eval_steps} steps")
print(f" Device : {args.device}")
# Trainer
from sentence_transformers import SentenceTransformerTrainer
trainer = SentenceTransformerTrainer(
model=model,
args=training_args,
train_dataset=train_ds,
eval_dataset=eval_ds,
loss=loss,
evaluator=evaluator,
)
# Entraîner
print(f"\nDémarrage de l'entraînement...")
trainer.train()
# Sauvegarder le modèle final
final_dir = args.output / "final"
model.save_pretrained(str(final_dir))
print(f"\nModèle sauvegardé : {final_dir}")
# Évaluation finale
print("\nÉvaluation finale (après fine-tuning)...")
final_results = evaluator(model)
final_acc = final_results.get("pmsi-eval_cosine_accuracy", 0)
baseline_acc = baseline.get("pmsi-eval_cosine_accuracy", 0)
print(f"\n{'=' * 60}")
print(f"Résultats :")
print(f" Accuracy baseline : {baseline_acc:.4f}")
print(f" Accuracy finale : {final_acc:.4f}")
if baseline_acc > 0:
delta = final_acc - baseline_acc
pct = 100 * delta / baseline_acc
print(f" Amélioration : {delta:+.4f} ({pct:+.1f}%)")
print(f"\nModèle final : {final_dir}")
print(f"{'=' * 60}")
# Instructions d'intégration
print(f"""
Pour utiliser dans le pipeline T2A, modifier src/medical/rag_search.py :
_embed_model = SentenceTransformer("{final_dir}", device=_device)
Puis reconstruire l'index FAISS avec le nouveau modèle :
python -m src.medical.rag_index --rebuild
""")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,477 @@
#!/usr/bin/env python3
"""
Parse les référentiels ATIH/PMSI pour générer des exemples ChatML.
Sources :
1. Annexe-4 CMA V11e — niveaux de sévérité officiels (tabulaire)
2. Racines GHM V11e — caractéristiques des racines (tabulaire)
3. Fascicules de codage — règles par spécialité (texte libre)
4. Instruction DGOS contrôle T2A — priorités de contrôle
Déduplique automatiquement avec les données CoCoA existantes (sévérité CMA).
Usage :
python scripts/11_parse_referentiels.py
"""
import json
import re
from pathlib import Path
import pdfplumber
BASE = Path(__file__).resolve().parent.parent
T2A = BASE.parent / "t2a"
REF_DIR = T2A / "data" / "referentiels"
OUTPUT = BASE / "data" / "processed"
OUTPUT.mkdir(parents=True, exist_ok=True)
SYSTEM_PROMPT = (
"Tu es un médecin DIM expert en codage CIM-10 pour le PMSI français. "
"Tu t'appuies sur les référentiels ATIH officiels."
)
SYSTEM_PROMPT_GHM = (
"Tu es un médecin DIM expert en groupage GHM/GHS pour le PMSI français. "
"Tu connais les règles de classification des GHM version 11e."
)
SYSTEM_PROMPT_CONTROLE = (
"Tu es un médecin DIM expert en contrôle T2A. "
"Tu connais les instructions DGOS et les règles de contrôle externe."
)
def _chatml(system: str, user: str, assistant: str, source: str = "") -> dict:
d = {
"messages": [
{"role": "system", "content": system},
{"role": "user", "content": user},
{"role": "assistant", "content": assistant},
]
}
if source:
d["source"] = source
return d
# ─── 1. Annexe-4 CMA ────────────────────────────────────────────────────────
def parse_annexe4_cma() -> list[dict]:
"""Parse l'Annexe-4 : Diagnostics classés CMA avec niveaux de sévérité."""
pdf_path = list(REF_DIR.glob("*Annexe-4*CMA*.pdf"))
if not pdf_path:
print(" Annexe-4 CMA non trouvée, skip")
return []
pdf_path = pdf_path[0]
# Charger les sévérités CoCoA existantes pour dédupliquer
cocoa_path = BASE / "data" / "processed" / "cocoa_entries_debug.json"
cocoa_severities = set()
if cocoa_path.exists():
with open(cocoa_path) as f:
cocoa = json.load(f)
for code, entry in cocoa.items():
if entry.get("severity"):
cocoa_severities.add(code)
entries = []
# Pattern: code (A00.0) suivi d'un niveau (2-4) suivi d'un libellé
pattern = re.compile(r"^([A-Z]\d{2}(?:\.\d{1,2})?)\s+([234])\s+(.+)$")
with pdfplumber.open(str(pdf_path)) as pdf:
for page in pdf.pages:
text = page.extract_text() or ""
for line in text.split("\n"):
line = line.strip()
m = pattern.match(line)
if m:
code, niveau, libelle = m.group(1), int(m.group(2)), m.group(3).strip()
entries.append({"code": code, "niveau": niveau, "libelle": libelle})
print(f" Annexe-4 CMA : {len(entries)} entrées extraites")
# Dédupliquer avec CoCoA
new_entries = [e for e in entries if e["code"] not in cocoa_severities]
dupes = len(entries) - len(new_entries)
print(f" Doublons CoCoA : {dupes}, nouvelles : {len(new_entries)}")
# Générer les exemples ChatML
examples = []
# Type 1 : Quel niveau CMA ?
for e in entries: # Garder tous pour renforcer, même les doublons
examples.append(_chatml(
SYSTEM_PROMPT,
f"Quel est le niveau de sévérité CMA du code {e['code']} ({e['libelle']}) ?",
f"Le code {e['code']} ({e['libelle']}) est classé CMA de niveau {e['niveau']}. "
f"{'Ce diagnostic est considéré comme une complication ou morbidité associée majeure.' if e['niveau'] >= 3 else 'Ce diagnostic est une complication ou morbidité associée significative.'}",
source="annexe4_cma"
))
# Type 2 : Est-ce une CMA ? (discrimination)
# Quelques exemples de codes NON-CMA pour contraste
non_cma_codes = set()
cma_codes = {e["code"] for e in entries}
for e in entries:
# Codes voisins qui ne sont pas CMA
base = e["code"][:3]
for suffix in range(10):
candidate = f"{base}.{suffix}"
if candidate not in cma_codes and candidate not in non_cma_codes:
non_cma_codes.add(candidate)
if len(non_cma_codes) >= 500:
break
return examples
# ─── 2. Racines GHM ──────────────────────────────────────────────────────────
def parse_racines_ghm() -> list[dict]:
"""Parse les Racines GHM V11e (tableau de caractéristiques)."""
pdf_path = list(REF_DIR.glob("*Racines_GHM*.pdf"))
if not pdf_path:
print(" Racines GHM non trouvé, skip")
return []
pdf_path = pdf_path[0]
entries = []
# Pattern racine GHM : 2 chiffres + lettre + 2 chiffres (ex: 01C02, 05K06)
pattern = re.compile(r"^(\d{2}[A-Z]\d{2})\s+(.+)$")
with pdfplumber.open(str(pdf_path)) as pdf:
for page in pdf.pages[3:]: # Skip pages de couverture
text = page.extract_text() or ""
for line in text.split("\n"):
line = line.strip()
m = pattern.match(line)
if m:
racine = m.group(1)
reste = m.group(2).strip()
entries.append({"racine": racine, "description": reste})
print(f" Racines GHM : {len(entries)} racines extraites")
# Aussi extraire les tables complètes par page pour du contexte riche
examples = []
with pdfplumber.open(str(pdf_path)) as pdf:
# Extraire les tables
for page in pdf.pages[3:]:
tables = page.extract_tables()
for table in tables:
if not table or len(table) < 2:
continue
headers = table[0] if table[0] else []
for row in table[1:]:
if not row or not row[0]:
continue
racine = str(row[0]).strip()
if not re.match(r"^\d{2}[A-Z]\d{2}$", racine):
continue
# Construire la description depuis les colonnes
desc_parts = [str(c).strip() for c in row[1:] if c and str(c).strip()]
if not desc_parts:
continue
full_desc = " | ".join(desc_parts)
examples.append(_chatml(
SYSTEM_PROMPT_GHM,
f"Quelles sont les caractéristiques de la racine GHM {racine} ?",
f"La racine GHM {racine} : {full_desc}",
source="racines_ghm"
))
# Si pas de tables parsées, utiliser les entrées textuelles
if not examples:
for e in entries:
examples.append(_chatml(
SYSTEM_PROMPT_GHM,
f"Quelles sont les caractéristiques de la racine GHM {e['racine']} ?",
f"La racine GHM {e['racine']} : {e['description']}",
source="racines_ghm"
))
return examples
# ─── 3. Arbre de décision GHM ─────────────────────────────────────────────────
def parse_arbre_ghm() -> list[dict]:
"""Parse l'arbre de décision GHM — extrait les règles par page/bloc."""
pdf_path = list(REF_DIR.glob("*Arbre_decision_GHM*.pdf"))
if not pdf_path:
print(" Arbre GHM non trouvé, skip")
return []
pdf_path = pdf_path[0]
examples = []
# Le PDF est un arbre graphique : chaque page contient des nœuds de décision
# avec des codes GHM (ex: 01K03, 28Z18) et des conditions (listes diagnostiques/actes)
ghm_pattern = re.compile(r"\b(\d{2}[A-Z]\d{2})\b")
condition_pattern = re.compile(r"\(([A-Z]-\d{3})\)") # Ex: (A-198), (D-064)
# Grouper les pages par CMD (les 2 premiers chiffres du GHM)
cmd_pages = {}
with pdfplumber.open(str(pdf_path)) as pdf:
for page in pdf.pages[7:]: # Skip couverture, symboles, légende
text = page.extract_text() or ""
if not text.strip() or len(text.strip()) < 30:
continue
# Nettoyer : supprimer les numéros de page isolés en début de texte
lines = text.strip().split("\n")
lines = [l for l in lines if not re.match(r"^\s*\d{1,3}\s*$", l.strip())]
text = "\n".join(lines)
# Trouver les codes GHM sur cette page
ghm_codes = ghm_pattern.findall(text)
if not ghm_codes:
continue
# Déterminer la CMD (premiers 2 chiffres)
cmd_num = ghm_codes[0][:2]
if cmd_num not in cmd_pages:
cmd_pages[cmd_num] = []
cmd_pages[cmd_num].append(text.strip())
# Générer un exemple par CMD
for cmd_num in sorted(cmd_pages.keys()):
pages_text = cmd_pages[cmd_num]
full_text = "\n\n".join(pages_text)
# Extraire tous les GHM et conditions
ghm_codes = sorted(set(ghm_pattern.findall(full_text)))
conditions = sorted(set(condition_pattern.findall(full_text)))
ghm_str = ", ".join(ghm_codes[:15])
body = full_text[:2000]
examples.append(_chatml(
SYSTEM_PROMPT_GHM,
f"Quelles sont les règles de l'arbre de décision GHM pour la CMD {cmd_num} ?",
f"CMD {cmd_num} — Arbre de décision GHM V11e :\n"
f"Racines GHM concernées : {ghm_str}\n\n{body}",
source="arbre_ghm"
))
print(f" Arbre GHM : {len(examples)} CMDs extraites")
return examples
# ─── 4. Fascicules de codage ──────────────────────────────────────────────────
def parse_fascicule(pdf_path: Path, topic: str) -> list[dict]:
"""Parse un fascicule de codage — extrait les sections de règles."""
examples = []
with pdfplumber.open(str(pdf_path)) as pdf:
full_text = ""
for page in pdf.pages:
text = page.extract_text() or ""
full_text += text + "\n"
if not full_text.strip():
return []
# Nettoyer : supprimer les lignes de table des matières (avec ........)
clean_lines = []
for line in full_text.split("\n"):
if "...." in line or "TABLE DES MATIERES" in line:
continue
# Supprimer les numéros de page isolés
if re.match(r"^\s*\d{1,3}\s*$", line.strip()):
continue
clean_lines.append(line)
full_text = "\n".join(clean_lines)
# Découper en sections par les titres (lignes commençant par un chiffre romain ou "Consignes")
sections = []
current_title = topic
current_body = []
for line in full_text.split("\n"):
line_stripped = line.strip()
# Détecter les titres de section
is_title = False
# Sections principales : I., II., III., IV.
if re.match(r"^[IVX]+\.\s+", line_stripped):
is_title = True
# Sous-sections : II.1., III.2., IV.3.
elif re.match(r"^[IVX]+\.\d+\.?\s+", line_stripped):
is_title = True
# Titres descriptifs courants dans les fascicules
elif re.match(r"^(Consignes|Règles|Les pièges|Le codage|Codage|Comment coder|Principes|Exemples?|Cas particulier|Remarque)", line_stripped, re.IGNORECASE):
is_title = True
elif re.match(r"^(Créé le|Mis à jour|Modifié le|TABLE DES MATIERES|FASCICULE DE CODAGE)", line_stripped):
continue # Skip dates et entêtes
if is_title and current_body:
body = "\n".join(current_body).strip()
if len(body) > 100:
sections.append((current_title, body))
current_title = line_stripped
current_body = []
else:
current_body.append(line)
# Dernière section
if current_body:
body = "\n".join(current_body).strip()
if len(body) > 100:
sections.append((current_title, body))
# Générer les exemples ChatML
for title, body in sections:
# Nettoyer
body = re.sub(r"\n{3,}", "\n\n", body)
body = body[:2000] # Limiter la taille
# Extraire les codes CIM-10 mentionnés pour enrichir la question
codes = re.findall(r"[A-Z]\d{2}(?:\.\d{1,2})?", body)
codes_unique = sorted(set(codes))[:5]
codes_str = f" (codes : {', '.join(codes_unique)})" if codes_unique else ""
# Type 1 : Règle de codage
examples.append(_chatml(
SYSTEM_PROMPT,
f"Quelles sont les règles de codage PMSI pour {title}{codes_str} ?",
f"Selon le fascicule de codage ATIH « {topic} » :\n\n{body}",
source=f"fascicule_{topic.lower().replace(' ', '_')[:30]}"
))
# Type 2 : Question pratique à partir du contenu
# Chercher les patterns "code X pour Y" ou "on codera X"
coding_rules = re.findall(
r"(?:on codera?|se code|coder?|est codé[e]?|codé[e]?\s+(?:avec|par|en))\s+([A-Z]\d{2}(?:\.\d{1,2})?)\s*(.*?)(?:\.|$)",
body, re.IGNORECASE
)
for code, context in coding_rules[:5]:
context = context.strip()[:200]
# Filtrer les contextes trop courts ou non informatifs
if len(context) < 15:
continue
examples.append(_chatml(
SYSTEM_PROMPT,
f"Comment coder en CIM-10 : {context} ?",
f"Selon les règles ATIH ({topic}), cette situation se code {code}. {context}",
source=f"fascicule_{topic.lower().replace(' ', '_')[:30]}"
))
return examples
def parse_all_fascicules() -> list[dict]:
"""Parse tous les fascicules disponibles."""
fascicules = {
"Fascicule_01_Generalites": "Généralités du codage PMSI",
"Fascicule_02_Maladies_digestives": "Maladies de l'appareil digestif",
"Fascicule_03_Tumeurs": "Tumeurs",
"Fascicule_04_Metabolisme": "Métabolisme",
"Fascicule_05_Gyneco_Obstetrique": "Gynécologie et Obstétrique",
"Fascicule_06_Neonatalogie": "Néonatalogie",
"Fascicule_07_Evolutions_2010": "Évolutions 2010",
"Fascicule_08_Maladies_infectieuses": "Maladies infectieuses",
"Fascicule_09_AVC": "Accidents vasculaires cérébraux",
"Fascicule_10_SCA_Coronariens": "Syndromes coronariens aigus",
}
all_examples = []
for filename_part, topic in fascicules.items():
pdf_files = list(REF_DIR.glob(f"*{filename_part}*"))
if not pdf_files:
print(f" {topic} : non trouvé, skip")
continue
examples = parse_fascicule(pdf_files[0], topic)
print(f" {topic} : {len(examples)} exemples")
all_examples.extend(examples)
return all_examples
# ─── 5. Instruction DGOS ──────────────────────────────────────────────────────
def parse_instruction_dgos() -> list[dict]:
"""Parse l'instruction DGOS contrôle T2A."""
pdf_path = list(REF_DIR.glob("*Instruction_DGOS*.pdf"))
if not pdf_path:
print(" Instruction DGOS non trouvée, skip")
return []
pdf_path = pdf_path[0]
with pdfplumber.open(str(pdf_path)) as pdf:
full_text = ""
for page in pdf.pages:
text = page.extract_text() or ""
full_text += text + "\n"
# Découper en sections thématiques
examples = []
sections = re.split(r"\n(?=\d+\.\s+|\d+\.\d+\s+|Annexe)", full_text)
for section in sections:
section = section.strip()
if len(section) < 100:
continue
# Titre = première ligne
lines = section.split("\n")
title = lines[0].strip()
body = "\n".join(lines[1:]).strip()[:2000]
if body:
examples.append(_chatml(
SYSTEM_PROMPT_CONTROLE,
f"Que dit l'instruction DGOS 2025 sur : {title} ?",
f"Selon l'instruction DGOS/FIP1/DSS/1A/2025/141 relative aux contrôles T2A :\n\n{body}",
source="instruction_dgos"
))
print(f" Instruction DGOS : {len(examples)} sections")
return examples
# ─── Main ─────────────────────────────────────────────────────────────────────
def main():
print("=" * 60)
print("Parsing des référentiels ATIH/PMSI")
print("=" * 60)
all_examples = []
# 1. Annexe-4 CMA
print("\n1. Annexe-4 CMA")
all_examples.extend(parse_annexe4_cma())
# 2. Racines GHM
print("\n2. Racines GHM")
all_examples.extend(parse_racines_ghm())
# 3. Arbre de décision GHM
print("\n3. Arbre de décision GHM")
all_examples.extend(parse_arbre_ghm())
# 4. Fascicules
print("\n4. Fascicules de codage")
all_examples.extend(parse_all_fascicules())
# 5. Instruction DGOS
print("\n5. Instruction DGOS")
all_examples.extend(parse_instruction_dgos())
# Sauvegarder
output_path = OUTPUT / "referentiels_chatml.jsonl"
with open(output_path, "w") as f:
for ex in all_examples:
f.write(json.dumps(ex, ensure_ascii=False) + "\n")
size_mo = output_path.stat().st_size / 1024 / 1024
print(f"\n{'=' * 60}")
print(f"Total : {len(all_examples)} exemples")
print(f"Sauvegardé : {output_path} ({size_mo:.1f} Mo)")
print(f"{'=' * 60}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,425 @@
#!/usr/bin/env python3
"""
Génère des exemples d'entraînement au format RÉEL du pipeline T2A
à partir du cache Ollama (gemma3:12b = gold standard).
Le cache contient les paires (diagnostic → code + raisonnement)
produites par gemma3:12b sur les 250 dossiers. On reconstruit
des prompts proches du pipeline et on utilise les réponses du cache
comme labels supervisés.
V2 : Utilise le cache actuel (1 840 entrées vs 100).
3 templates : court + long + CPAM contre-argumentation.
Tous les textes génèrent une version courte ET longue.
Cible ~4 000 exemples (2×1840 + CPAM bonus).
Produit : data/processed/pipeline_chatml.jsonl
Usage :
python scripts/12_generate_pipeline_examples.py
"""
import json
import random
from pathlib import Path
random.seed(42)
BASE = Path(__file__).resolve().parent.parent
T2A = Path("/home/dom/ai/t2a")
CACHE_PATH = T2A / "data" / "ollama_cache.json"
CACHE_BACKUP = T2A / "data" / "ollama_cache_gemma3.bak"
CIM10_DICT = T2A / "data" / "cim10_dict.json"
CIM10_SUPP = T2A / "data" / "cim10_supplements.json"
OUTPUT = BASE / "data" / "processed" / "pipeline_chatml.jsonl"
def load_cim10_dict() -> dict[str, str]:
"""Charge le dictionnaire CIM-10 (code → libellé)."""
d = {}
if CIM10_DICT.exists():
d.update(json.loads(CIM10_DICT.read_text()))
if CIM10_SUPP.exists():
d.update(json.loads(CIM10_SUPP.read_text()))
return d
def load_cache_entries() -> dict:
"""Charge toutes les entrées du cache (actuel + backup)."""
entries = {}
for path in [CACHE_PATH, CACHE_BACKUP]:
if path.exists():
data = json.loads(path.read_text())
new = sum(1 for k in data.get("entries", {}) if k not in entries)
entries.update(data.get("entries", {}))
print(f" {path.name}: {len(data.get('entries', {}))} entrées (+{new} nouvelles)")
return entries
SYSTEM_PROMPT = (
"Tu es un médecin DIM (Département d'Information Médicale) expert en codage PMSI. "
"Tu codes les diagnostics en CIM-10 en suivant une démarche structurée : "
"analyse clinique, identification des codes candidats, discrimination, "
"vérification des règles PMSI."
)
# Template simplifié reproduisant la structure du prompt pipeline
# (sans les sources RAG qui ne sont pas dans le cache)
PROMPT_TEMPLATE_DP = """Code ce diagnostic en CIM-10 pour le PMSI.
RÈGLES IMPÉRATIVES :
- Privilégie le code le plus SPÉCIFIQUE disponible (4e ou 5e caractère)
- Vérifie les notes d'inclusion/exclusion de chaque code candidat
- Le DP doit refléter le motif principal de prise en charge du séjour
- EXCLUSION SYMPTÔME : Si le diagnostic est un symptôme (R00-R99) et qu'un diagnostic précis existe, le symptôme ne doit PAS être codé comme DP
DIAGNOSTIC À CODER : "{texte}"
TYPE : DP (diagnostic principal)
Réponds UNIQUEMENT avec un objet JSON :
{{
"analyse_clinique": "que signifie ce diagnostic sur le plan médical",
"codes_candidats": "quels codes CIM-10 sont compatibles",
"discrimination": "pourquoi choisir ce code plutôt qu'un autre",
"regle_pmsi": "conformité aux règles PMSI pour un DP",
"code": "X99.9",
"confidence": "high ou medium ou low",
"justification": "explication courte en français"
}}"""
PROMPT_TEMPLATE_DAS = """Code ce diagnostic en CIM-10 pour le PMSI.
RÈGLES IMPÉRATIVES :
- Privilégie le code le plus SPÉCIFIQUE disponible (4e ou 5e caractère)
- Vérifie les notes d'inclusion/exclusion de chaque code candidat
- Un DAS doit avoir mobilisé des ressources supplémentaires pendant le séjour
- EXCLUSION SYMPTÔME : Si le diagnostic est un symptôme (R00-R99) et qu'un diagnostic précis existe, le symptôme ne doit PAS être codé comme DAS
DIAGNOSTIC À CODER : "{texte}"
TYPE : DAS (diagnostic associé significatif)
Réponds UNIQUEMENT avec un objet JSON :
{{
"analyse_clinique": "que signifie ce diagnostic sur le plan médical",
"codes_candidats": "quels codes CIM-10 sont compatibles",
"discrimination": "pourquoi choisir ce code plutôt qu'un autre",
"regle_pmsi": "conformité aux règles PMSI pour un DAS",
"code": "X99.9",
"confidence": "high ou medium ou low",
"justification": "explication courte en français"
}}"""
# Version longue avec contexte patient (simulé) pour entraîner sur des prompts longs
PROMPT_TEMPLATE_DP_LONG = """Code ce diagnostic en CIM-10 pour le PMSI.
RÈGLES IMPÉRATIVES :
- Le code doit provenir UNIQUEMENT de la nomenclature CIM-10 FR 2026
- Distingue la DESCRIPTION CLINIQUE (ce que le médecin écrit) de la LOGIQUE DE CODAGE (ce que l'ATIH impose)
- Privilégie le code le plus SPÉCIFIQUE disponible (4e ou 5e caractère)
- Vérifie les notes d'inclusion/exclusion de chaque code candidat
- Si le diagnostic est un DP, il doit refléter le motif principal de prise en charge du séjour
- EXCLUSION SYMPTÔME : Si le diagnostic est un symptôme (R00-R99) et qu'un diagnostic précis (Chapitres I-XIV, A00-N99) expliquant ce symptôme est présent, le symptôme ne doit PAS être codé
DIAGNOSTIC À CODER : "{texte}"
TYPE : DP (diagnostic principal)
CONTEXTE CLINIQUE :
{contexte}
SOURCES CIM-10 :
{sources}
Réponds UNIQUEMENT avec un objet JSON au format suivant, sans aucun texte avant ou après :
{{
"analyse_clinique": "que signifie ce diagnostic sur le plan médical",
"codes_candidats": "quels codes CIM-10 des sources sont compatibles",
"discrimination": "pourquoi choisir ce code plutôt qu'un autre (inclusions/exclusions, spécificité)",
"regle_pmsi": "conformité aux règles PMSI pour un DP (guide méthodologique)",
"code": "X99.9",
"confidence": "high ou medium ou low",
"justification": "explication courte en français"
}}"""
PROMPT_TEMPLATE_DAS_LONG = """Code ce diagnostic en CIM-10 pour le PMSI.
RÈGLES IMPÉRATIVES :
- Le code doit provenir UNIQUEMENT de la nomenclature CIM-10 FR 2026
- Distingue la DESCRIPTION CLINIQUE (ce que le médecin écrit) de la LOGIQUE DE CODAGE (ce que l'ATIH impose)
- Privilégie le code le plus SPÉCIFIQUE disponible (4e ou 5e caractère)
- Vérifie les notes d'inclusion/exclusion de chaque code candidat
- Un DAS doit avoir mobilisé des ressources supplémentaires pendant le séjour
- EXCLUSION SYMPTÔME : Si le diagnostic est un symptôme (R00-R99) et qu'un diagnostic précis (Chapitres I-XIV, A00-N99) expliquant ce symptôme est présent, le symptôme ne doit PAS être codé
DIAGNOSTIC À CODER : "{texte}"
TYPE : DAS (diagnostic associé significatif)
CONTEXTE CLINIQUE :
{contexte}
SOURCES CIM-10 :
{sources}
Réponds UNIQUEMENT avec un objet JSON au format suivant, sans aucun texte avant ou après :
{{
"analyse_clinique": "que signifie ce diagnostic sur le plan médical",
"codes_candidats": "quels codes CIM-10 des sources sont compatibles",
"discrimination": "pourquoi choisir ce code plutôt qu'un autre (inclusions/exclusions, spécificité)",
"regle_pmsi": "conformité aux règles PMSI pour un DAS (guide méthodologique)",
"code": "X99.9",
"confidence": "high ou medium ou low",
"justification": "explication courte en français"
}}"""
# V2 : Template CPAM contre-argumentation (3e variante)
PROMPT_TEMPLATE_CPAM = """Tu es médecin DIM. La CPAM conteste le codage d'un diagnostic. Argumente.
DIAGNOSTIC CONTESTÉ : "{texte}"
TYPE : {type_label}
CODE ACTUEL : {code}
MOTIF DE CONTESTATION CPAM :
{motif_cpam}
SOURCES CIM-10 :
{sources}
Réponds UNIQUEMENT avec un objet JSON :
{{
"analyse_clinique": "que signifie ce diagnostic sur le plan médical",
"codes_candidats": "quels codes CIM-10 sont compatibles",
"discrimination": "pourquoi ce code est correct (inclusions/exclusions, spécificité)",
"regle_pmsi": "conformité aux règles PMSI et au guide méthodologique",
"code": "{code}",
"confidence": "high ou medium ou low",
"justification": "argumentation structurée pour répondre à la contestation CPAM"
}}"""
CPAM_MOTIFS = [
"Le code {code} ne semble pas justifié par les éléments du dossier médical.",
"La CPAM propose de recoder en {alt_code} ({alt_label}). Justifiez le maintien du code actuel.",
"Le DAS {code} n'a pas mobilisé de ressources supplémentaires pendant le séjour.",
"Ce diagnostic est un symptôme déjà couvert par le DP. Il ne devrait pas être codé séparément.",
"Le niveau de sévérité CMA associé à {code} ne semble pas justifié cliniquement.",
]
def build_fake_source(code: str, cim10_dict: dict[str, str]) -> str:
"""Génère un extrait de source CIM-10 simulé pour un code donné."""
label = cim10_dict.get(code, "")
if not label:
return ""
# Trouver des codes voisins (même catégorie à 3 car.)
prefix = code[:3]
neighbors = []
for c, l in cim10_dict.items():
if c[:3] == prefix and c != code:
neighbors.append((c, l))
neighbors.sort()
neighbors = neighbors[:4]
lines = [f"--- Source 1: CIM-10 FR 2026 (code: {code}) ---"]
lines.append(f"{code} {label}")
for c, l in neighbors:
lines.append(f"{c} {l}")
lines.append("")
return "\n".join(lines)
def build_response_json(entry: dict) -> str:
"""Reconstruit la réponse JSON structurée depuis une entrée du cache."""
# Extraire les sections du raisonnement
rais = entry.get("raisonnement", "")
analyse = ""
candidats = ""
discrim = ""
regle = ""
if "ANALYSE CLINIQUE" in rais:
parts = rais.split("\n\n")
for part in parts:
p = part.strip()
if p.startswith("ANALYSE CLINIQUE"):
analyse = p.replace("ANALYSE CLINIQUE :\n", "").replace("ANALYSE CLINIQUE :", "").strip()
elif p.startswith("CODES CANDIDATS"):
candidats = p.replace("CODES CANDIDATS :\n", "").replace("CODES CANDIDATS :", "").strip()
elif p.startswith("DISCRIMINATION"):
discrim = p.replace("DISCRIMINATION :\n", "").replace("DISCRIMINATION :", "").strip()
elif p.startswith("REGLE PMSI") or p.startswith("RÈGLE PMSI"):
regle = p.split(":\n", 1)[-1].strip() if ":\n" in p else p.split(":", 1)[-1].strip()
resp = {
"analyse_clinique": analyse or "Diagnostic médical nécessitant un codage CIM-10 spécifique.",
"codes_candidats": candidats or f"[{entry.get('code', '?')}]",
"discrimination": discrim or entry.get("justification", ""),
"regle_pmsi": regle or "Code conforme aux règles PMSI.",
"code": entry.get("code", "?"),
"confidence": entry.get("confidence", "medium"),
"justification": entry.get("justification", ""),
}
return json.dumps(resp, ensure_ascii=False)
def generate_contexte_samples() -> list[str]:
"""Génère des contextes patients variés."""
return [
"- Patient : Homme, 72 ans, IMC 28.5\n- Durée séjour : 7 jours\n- Biologie : CRP 145 [N: 0-5] (↑), Créatinine 89 [N: 50-120]",
"- Patient : Femme, 58 ans, IMC 24.1\n- Durée séjour : 3 jours\n- Biologie : Hémoglobine 10.2 [N: 12-17] (↑), Plaquettes 180 [N: 150-400]",
"- Patient : Homme, 85 ans, IMC 22.0\n- Durée séjour : 12 jours\n- Antécédents : HTA, Diabète de type 2, BPCO\n- Biologie : CRP 89 [N: 0-5] (↑), Leucocytes 14.2 [N: 4-10] (↑)",
"- Patient : Femme, 45 ans\n- Durée séjour : 2 jours",
"- Patient : Homme, 67 ans, IMC 31.2\n- Durée séjour : 5 jours\n- Biologie : ASAT 85 [N: 0-40] (↑), ALAT 120 [N: 0-40] (↑), GGT 210 [N: 0-60] (↑)",
"- Patient : Femme, 30 ans\n- Durée séjour : 4 jours\n- Biologie : CRP 25 [N: 0-5] (↑), Leucocytes 12.5 [N: 4-10] (↑)",
"- Patient : Homme, 78 ans, IMC 20.1\n- Durée séjour : 14 jours\n- Antécédents : Insuffisance cardiaque, FA\n- Complications : Infection urinaire",
"Non précisé",
]
def _parse_cache_key(key):
"""Extraire le type (dp/das) et le texte depuis la clé du cache."""
if key.startswith("das_llm::das_extract::"):
parts = key.split("::", 3)
texte = parts[3] if len(parts) > 3 else parts[-1]
return "das", texte.strip()
if "::" in key:
diag_type, texte = key.split("::", 1)
return diag_type.strip(), texte.strip()
return "das", key.strip()
def _find_alt_code(code: str, cim10_dict: dict[str, str]) -> tuple[str, str]:
"""Trouve un code alternatif (même catégorie) pour les motifs CPAM."""
prefix = code[:3]
for c, label in cim10_dict.items():
if c[:3] == prefix and c != code and len(c) >= 4:
return c, label
# Fallback : code .9 (SAI)
sai = f"{prefix}.9"
if sai in cim10_dict and sai != code:
return sai, cim10_dict[sai]
return "Z03.9", "Observation pour suspicion non précisée"
def main():
print("Chargement du cache Ollama gemma3:12b (toutes sources)...")
entries = load_cache_entries()
print(f" Total fusionné : {len(entries)} entrées")
print("Chargement du dictionnaire CIM-10...")
cim10_dict = load_cim10_dict()
print(f" {len(cim10_dict)} codes")
contextes = generate_contexte_samples()
examples = []
n_short = 0
n_long = 0
n_cpam = 0
for key, entry in entries.items():
code = entry.get("code", "")
if not code or code == "?" or len(code) < 3:
continue
diag_type, texte = _parse_cache_key(key)
if diag_type not in ("dp", "das"):
continue
if not texte or len(texte) < 3:
continue
# Vérifier que le raisonnement est substantiel
raisonnement = entry.get("raisonnement", "")
if len(raisonnement) < 50:
continue
response_json = build_response_json(entry)
# V2 : TOUJOURS générer version courte + version longue
# 1) Version courte (sans sources ni contexte)
if diag_type == "dp":
user_content = PROMPT_TEMPLATE_DP.format(texte=texte)
else:
user_content = PROMPT_TEMPLATE_DAS.format(texte=texte)
examples.append({
"messages": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user_content},
{"role": "assistant", "content": response_json},
]
})
n_short += 1
# 2) Version longue (avec sources CIM-10 et contexte simulé)
source_text = build_fake_source(code, cim10_dict)
if source_text:
ctx = random.choice(contextes)
if diag_type == "dp":
user_long = PROMPT_TEMPLATE_DP_LONG.format(
texte=texte, contexte=ctx, sources=source_text
)
else:
user_long = PROMPT_TEMPLATE_DAS_LONG.format(
texte=texte, contexte=ctx, sources=source_text
)
examples.append({
"messages": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user_long},
{"role": "assistant", "content": response_json},
]
})
n_long += 1
# 3) V2 : Version CPAM (contre-argumentation) — 20% des entrées DAS high confidence
if (diag_type == "das"
and entry.get("confidence") == "high"
and source_text
and random.random() < 0.20):
type_label = "DAS (diagnostic associé significatif)"
alt_code, alt_label = _find_alt_code(code, cim10_dict)
motif = random.choice(CPAM_MOTIFS).format(
code=code, alt_code=alt_code, alt_label=alt_label
)
user_cpam = PROMPT_TEMPLATE_CPAM.format(
texte=texte, type_label=type_label, code=code,
motif_cpam=motif, sources=source_text,
)
examples.append({
"messages": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user_cpam},
{"role": "assistant", "content": response_json},
]
})
n_cpam += 1
# Mélanger
random.shuffle(examples)
# Écrire
OUTPUT.parent.mkdir(parents=True, exist_ok=True)
with open(OUTPUT, "w") as f:
for ex in examples:
f.write(json.dumps(ex, ensure_ascii=False) + "\n")
# Stats
token_counts = []
for ex in examples:
text = " ".join(m["content"] for m in ex["messages"])
token_counts.append(int(len(text.split()) * 1.3))
print(f"\n{'='*50}")
print(f"Exemples pipeline générés : {len(examples)}")
print(f" Courts : {n_short}, Longs : {n_long}, CPAM : {n_cpam}")
print(f"{OUTPUT}")
print(f" Tokens : moy={sum(token_counts)//len(token_counts)}, "
f"max={max(token_counts)}, min={min(token_counts)}")
print(f" Taille : {OUTPUT.stat().st_size / 1024 / 1024:.1f} Mo")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,373 @@
#!/usr/bin/env python3
"""
Parsing agressif des 10 fascicules ATIH → Q&A de raisonnement DIM.
Stratégie :
- Découpe chaque fascicule en paragraphes (pas sections)
- Extrait les règles de codage via regex
- Pour chaque règle extraite, génère 3 Q&A raisonnement via Claude Opus 4.6
- Question = scénario clinique appliquant la règle
- Réponse = JSON structuré {analyse_clinique, regle_pmsi, code, justification}
Cible : ~450 exemples (151 règles × 3 exercices)
Sources : data/raw/referentiels/ (10 fascicules PDF, via lien t2a)
Nécessite : ANTHROPIC_API_KEY en variable d'environnement
Usage :
python scripts/13_generate_fascicule_reasoning.py [--dry-run] [--max N]
"""
import argparse
import json
import os
import random
import re
import sys
import time
from pathlib import Path
import pdfplumber
random.seed(42)
BASE = Path(__file__).resolve().parent.parent
T2A = BASE.parent / "t2a"
REF_DIR = T2A / "data" / "referentiels"
OUTPUT = BASE / "data" / "processed" / "fascicule_reasoning_chatml.jsonl"
OUTPUT.parent.mkdir(parents=True, exist_ok=True)
MODEL = "claude-opus-4-6"
SYSTEM_PROMPT = (
"Tu es un médecin DIM expert en codage CIM-10 pour le PMSI français. "
"Tu t'appuies sur les référentiels ATIH officiels."
)
# Fascicules à parser
FASCICULES = {
"Fascicule_01_Generalites": "Généralités du codage PMSI",
"Fascicule_02_Maladies_digestives": "Maladies de l'appareil digestif",
"Fascicule_03_Tumeurs": "Tumeurs",
"Fascicule_04_Metabolisme": "Métabolisme",
"Fascicule_05_Gyneco_Obstetrique": "Gynécologie et Obstétrique",
"Fascicule_06_Neonatalogie": "Néonatalogie",
"Fascicule_07_Evolutions_2010": "Évolutions 2010",
"Fascicule_08_Maladies_infectieuses": "Maladies infectieuses",
"Fascicule_09_AVC": "Accidents vasculaires cérébraux",
"Fascicule_10_SCA_Coronariens": "Syndromes coronariens aigus",
}
# Regex pour détecter les règles de codage dans le texte
RULE_PATTERNS = [
re.compile(r"(?:on\s+code(?:ra)?|se\s+code|coder?|est\s+cod[ée](?:e)?)\s+(?:avec\s+|par\s+|en\s+)?([A-Z]\d{2}(?:\.\d{1,2})?)", re.IGNORECASE),
re.compile(r"ne\s+(?:pas|doit\s+pas|faut\s+pas)\s+coder", re.IGNORECASE),
re.compile(r"à\s+l['\u2019]exclusion\s+de", re.IGNORECASE),
re.compile(r"(?:comprend|inclut|inclus)\s+", re.IGNORECASE),
re.compile(r"(?:dans\s+ce\s+cas|en\s+cas\s+de|si\s+le\s+patient|lorsque)", re.IGNORECASE),
re.compile(r"(?:le\s+DP|diagnostic\s+principal)\s+(?:est|sera|doit\s+être)", re.IGNORECASE),
re.compile(r"(?:le\s+DAS|diagnostic\s+associé)\s+(?:est|sera|doit\s+être)", re.IGNORECASE),
re.compile(r"(?:CMA|sévérité|niveau\s+\d)", re.IGNORECASE),
re.compile(r"(?:séjour|durée|ressources?\s+supplémentaires?)", re.IGNORECASE),
]
GENERATION_PROMPT = """Tu es un formateur DIM. À partir de cet extrait du fascicule ATIH, génère EXACTEMENT 3 exercices de raisonnement distincts.
EXTRAIT (source : fascicule ATIH « {topic} ») :
{rule_text}
Génère un tableau JSON de 3 objets, chacun avec :
- "scenario" : un cas clinique réaliste et concis (2-3 phrases), DIFFÉRENT des autres
- "reponse" : un objet contenant :
- "analyse_clinique" : interprétation du cas
- "regle_pmsi" : la règle du fascicule qui s'applique
- "code" : le code CIM-10 correct (ou null si ne pas coder)
- "confidence" : "high"
- "justification" : pourquoi cette règle s'applique
Réponds UNIQUEMENT avec le tableau JSON (pas d'objet wrapper), sans texte avant/après.
Exemple de format : [{{"scenario": "...", "reponse": {{...}}}}, ...]"""
def extract_paragraphs(pdf_path: Path) -> list[str]:
"""Extraire les paragraphes d'un fascicule PDF.
Stratégie : pdfplumber ne produit pas de lignes vides entre paragraphes.
On découpe par page, puis par titres/sections détectées via heuristiques.
"""
pages_text = []
with pdfplumber.open(str(pdf_path)) as pdf:
for page in pdf.pages:
text = page.extract_text() or ""
if text.strip():
pages_text.append(text)
# Concaténer avec séparateur de page
full_text = "\n\n".join(pages_text)
# Nettoyer les lignes non pertinentes
lines = []
for line in full_text.split("\n"):
stripped = line.strip()
if "...." in stripped or "TABLE DES MATIERES" in stripped:
continue
if re.match(r"^\s*\d{1,3}\s*$", stripped):
continue
if re.match(r"^(Créé le|Mis à jour|Modifié le|FASCICULE DE CODAGE)", stripped):
continue
lines.append(line)
clean_text = "\n".join(lines)
# Découper en sections par les titres détectés
title_re = re.compile(
r"^(?:"
r"[IVX]+\.\d*\.?\s+"
r"|[IVX]+\s+[-]\s+"
r"|\d+\.\d*\.?\s+[A-ZÀÂÉÈÊËÎÏ]"
r"|[A-ZÀÂÉÈÊËÎÏÔÙÛÜ][A-ZÀÂÉÈÊËÎÏÔÙÛÜ\s]{3,}$"
r")",
re.MULTILINE
)
paragraphs = []
matches = list(title_re.finditer(clean_text))
if matches:
for i, match in enumerate(matches):
start = match.start()
end = matches[i + 1].start() if i + 1 < len(matches) else len(clean_text)
section = clean_text[start:end].strip()
if len(section) > 100:
paragraphs.append(section)
# Fallback : découper par blocs de ~800 caractères
if len(paragraphs) < 5:
paragraphs = []
block_size = 800
current_block = []
current_len = 0
for line in clean_text.split("\n"):
current_block.append(line)
current_len += len(line) + 1
if current_len >= block_size:
para = "\n".join(current_block).strip()
if len(para) > 100:
paragraphs.append(para)
current_block = []
current_len = 0
if current_block:
para = "\n".join(current_block).strip()
if len(para) > 100:
paragraphs.append(para)
return paragraphs
def extract_rules(paragraphs: list[str]) -> list[tuple[str, int]]:
"""Extraire les paragraphes classés par pertinence (score de règle).
Returns: [(text, score)] trié par score décroissant.
"""
scored = []
for para in paragraphs:
score = sum(1 for pat in RULE_PATTERNS if pat.search(para))
codes = re.findall(r"[A-Z]\d{2}(?:\.\d{1,2})?", para)
if codes:
score += 1
if score >= 1:
text = para[:1500] if len(para) > 1500 else para
scored.append((text, score))
scored.sort(key=lambda x: -x[1])
return scored
def call_claude(client, prompt: str, max_retries: int = 2) -> str | None:
"""Appel Claude Opus 4.6 via API Anthropic avec retry."""
for attempt in range(max_retries + 1):
try:
response = client.messages.create(
model=MODEL,
max_tokens=4096,
temperature=0.7,
messages=[{"role": "user", "content": prompt}],
)
return response.content[0].text
except Exception as e:
if attempt < max_retries:
wait = 2 ** (attempt + 1)
print(f" Retry in {wait}s: {e}")
time.sleep(wait)
else:
print(f" Claude error: {e}")
return None
def parse_llm_response(response_text: str) -> list[dict]:
"""Parse la réponse JSON du LLM (tableau de 3 exercices ou objet unique)."""
if not response_text:
return []
text = response_text.strip()
if "```json" in text:
text = text.split("```json", 1)[1].split("```", 1)[0].strip()
elif "```" in text:
text = text.split("```", 1)[1].split("```", 1)[0].strip()
try:
data = json.loads(text)
if isinstance(data, list):
return [d for d in data if isinstance(d, dict) and "scenario" in d and "reponse" in d]
if isinstance(data, dict) and "scenario" in data and "reponse" in data:
return [data]
except json.JSONDecodeError:
pass
# Fallback : chercher un tableau [...]
bracket_start = text.find("[")
if bracket_start >= 0:
depth = 0
for i in range(bracket_start, len(text)):
if text[i] == "[":
depth += 1
elif text[i] == "]":
depth -= 1
if depth == 0:
try:
data = json.loads(text[bracket_start:i+1])
if isinstance(data, list):
return [d for d in data if isinstance(d, dict) and "scenario" in d]
except json.JSONDecodeError:
break
# Fallback : objet unique
brace_start = text.find("{")
if brace_start >= 0:
depth = 0
for i in range(brace_start, len(text)):
if text[i] == "{":
depth += 1
elif text[i] == "}":
depth -= 1
if depth == 0:
try:
data = json.loads(text[brace_start:i+1])
if "scenario" in data:
return [data]
except json.JSONDecodeError:
break
return []
def make_chatml(scenario: str, response: dict, topic: str) -> dict:
"""Créer un exemple ChatML depuis le scénario + réponse structurée."""
user_content = (
f"Cas clinique :\n{scenario}\n\n"
f"Code ce cas en CIM-10 selon les règles du fascicule « {topic} ».\n\n"
"Réponds avec un JSON structuré contenant : analyse_clinique, regle_pmsi, code, confidence, justification."
)
assistant_content = json.dumps(response, ensure_ascii=False)
return {
"messages": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user_content},
{"role": "assistant", "content": assistant_content},
]
}
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--dry-run", action="store_true", help="Pas d'appel LLM, affiche les règles extraites")
parser.add_argument("--max", type=int, default=0, help="Max règles par fascicule (0=illimité)")
args = parser.parse_args()
print("=" * 60)
print("Génération de Q&A raisonnement depuis les fascicules ATIH")
print(f"Modèle : {MODEL}")
print("=" * 60)
# Vérifier la clé API
if not args.dry_run:
api_key = os.environ.get("ANTHROPIC_API_KEY")
if not api_key:
print("Erreur: ANTHROPIC_API_KEY non définie.")
print(" export ANTHROPIC_API_KEY='sk-ant-...'")
sys.exit(1)
import anthropic
client = anthropic.Anthropic(api_key=api_key)
else:
client = None
all_examples = []
total_rules = 0
for filename_part, topic in FASCICULES.items():
pdf_files = list(REF_DIR.glob(f"*{filename_part}*"))
pdf_files = [f for f in pdf_files if "redacted" not in f.name.lower() and "pseudonymise" not in str(f)]
if not pdf_files:
print(f"\n{topic} : PDF non trouvé, skip")
continue
pdf_path = pdf_files[0]
print(f"\n{''*40}")
print(f"{topic} ({pdf_path.name})")
paragraphs = extract_paragraphs(pdf_path)
print(f" Paragraphes extraits : {len(paragraphs)}")
scored_rules = extract_rules(paragraphs)
print(f" Règles/paragraphes pertinents : {len(scored_rules)}")
if args.max > 0:
scored_rules = scored_rules[:args.max]
total_rules += len(scored_rules)
if args.dry_run:
for i, (rule, score) in enumerate(scored_rules[:5]):
print(f" [{i+1}] (score={score}) {rule[:120]}...")
continue
# Générer les Q&A via Claude (3 exercices par règle)
n_ok = 0
n_fail = 0
for i, (rule_text, score) in enumerate(scored_rules):
prompt = GENERATION_PROMPT.format(topic=topic, rule_text=rule_text)
response_text = call_claude(client, prompt)
exercises = parse_llm_response(response_text)
for ex in exercises:
if "scenario" in ex and "reponse" in ex:
example = make_chatml(ex["scenario"], ex["reponse"], topic)
all_examples.append(example)
n_ok += 1
if not exercises:
n_fail += 1
if (i + 1) % 10 == 0:
print(f" Progression : {i+1}/{len(scored_rules)} (exemples={n_ok}, échecs={n_fail})")
print(f" Résultat : {n_ok} exemples générés, {n_fail} échecs")
if args.dry_run:
print(f"\n[DRY RUN] Total règles détectées : {total_rules}")
print("Relancez sans --dry-run pour générer les exemples avec Claude.")
return
# Mélanger et sauvegarder
random.shuffle(all_examples)
with open(OUTPUT, "w") as f:
for ex in all_examples:
f.write(json.dumps(ex, ensure_ascii=False) + "\n")
print(f"\n{'='*60}")
print(f"Total : {len(all_examples)} exemples → {OUTPUT}")
print(f"Taille : {OUTPUT.stat().st_size / 1024:.0f} Ko")
print(f"Règles sources : {total_rules}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,396 @@
#!/usr/bin/env python3
"""
Génère des exemples négatifs : enseigner au modèle quand NE PAS coder.
3 types d'exemples :
a) Codes rejetés (500) — symptômes couverts par le DP
b) Redondances sémantiques (200) — paires dominé/dominant
c) DAS non significatifs (300) — antécédents sans ressources consommées
Sources :
- Cache Ollama (textes diagnostics réels)
- Règles PMSI (SEMANTIC_REDUNDANCIES, symptômes R00-R99)
- Templates + CIM-10 FHIR
Produit : data/processed/negative_chatml.jsonl
Usage :
python scripts/14_generate_negative_examples.py
"""
import json
import random
from pathlib import Path
random.seed(42)
BASE = Path(__file__).resolve().parent.parent
T2A = BASE.parent / "t2a"
RAW = BASE / "data" / "raw"
OUTPUT = BASE / "data" / "processed" / "negative_chatml.jsonl"
OUTPUT.parent.mkdir(parents=True, exist_ok=True)
SYSTEM_PROMPT = (
"Tu es un médecin DIM expert en codage CIM-10 pour le PMSI français. "
"Tu sais quand un diagnostic ne doit PAS être codé."
)
def make_chatml(user: str, assistant: str) -> dict:
return {
"messages": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user},
{"role": "assistant", "content": assistant},
]
}
def load_cim10_dict() -> dict[str, str]:
"""Charge le dictionnaire CIM-10."""
d = {}
for path in [T2A / "data" / "cim10_dict.json", T2A / "data" / "cim10_supplements.json"]:
if path.exists():
d.update(json.loads(path.read_text()))
return d
def load_fhir_concepts() -> dict[str, dict]:
"""Charge les concepts FHIR indexés par code."""
fhir_path = RAW / "smt_cim10_fhir.json"
if not fhir_path.exists():
return {}
data = json.loads(fhir_path.read_text())
by_code = {}
for c in data.get("concept", []):
by_code[c["code"]] = c
return by_code
# ─── Type A : Symptômes couverts par le DP ────────────────────────────────────
# Symptômes R00-R99 fréquemment codés à tort comme DAS quand le DP les explique
SYMPTOM_DP_PAIRS = [
# (symptôme_code, symptôme_label, dp_code, dp_label, explication)
("R50.9", "Fièvre, sans précision", "A41.9", "Sepsis, sans précision",
"La fièvre est un symptôme cardinal du sepsis. Elle est couverte par le DP A41.9."),
("R50.9", "Fièvre, sans précision", "J18.9", "Pneumonie, sans précision",
"La fièvre est un symptôme habituel de la pneumonie. Le DP J18.9 la couvre."),
("R50.9", "Fièvre, sans précision", "N10", "Néphrite tubulo-interstitielle aiguë",
"La fièvre accompagne habituellement la pyélonéphrite aiguë (N10)."),
("R06.0", "Dyspnée", "J44.1", "BPCO avec exacerbation aiguë",
"La dyspnée est le symptôme principal de l'exacerbation de BPCO."),
("R06.0", "Dyspnée", "I50.9", "Insuffisance cardiaque, sans précision",
"La dyspnée est un symptôme majeur de l'insuffisance cardiaque."),
("R06.0", "Dyspnée", "J96.0", "Insuffisance respiratoire aiguë",
"La dyspnée est couverte par le DP d'insuffisance respiratoire aiguë."),
("R07.4", "Douleur thoracique, sans précision", "I21.9", "IDM aigu, sans précision",
"La douleur thoracique est le symptôme principal de l'IDM."),
("R07.4", "Douleur thoracique, sans précision", "I20.0", "Angor instable",
"La douleur thoracique est le symptôme cardinal de l'angor instable."),
("R10.4", "Douleur abdominale, sans précision", "K80.1", "Lithiase vésiculaire avec cholécystite",
"La douleur abdominale est couverte par le DP de cholécystite."),
("R10.4", "Douleur abdominale, sans précision", "K35.8", "Appendicite aiguë, autre et sans précision",
"La douleur abdominale est le symptôme principal de l'appendicite."),
("R11.0", "Nausées", "K29.1", "Gastrite aiguë",
"Les nausées sont un symptôme courant de la gastrite, couvert par le DP."),
("R11.2", "Nausées avec vomissements, sans précision", "K56.6", "Occlusion intestinale, autre et sans précision",
"Les vomissements sont un signe cardinal de l'occlusion intestinale."),
("R00.0", "Tachycardie, sans précision", "A41.9", "Sepsis, sans précision",
"La tachycardie est un critère diagnostique du sepsis."),
("R00.0", "Tachycardie, sans précision", "I48.9", "Fibrillation auriculaire",
"La tachycardie est un symptôme de la FA, couverte par le DP."),
("R41.0", "Désorientation, sans précision", "F05.9", "Delirium, sans précision",
"La désorientation est un symptôme constitutif du delirium."),
("R40.0", "Somnolence", "S06.9", "Lésion traumatique intracrânienne, sans précision",
"La somnolence est un signe d'atteinte neurologique dans le trauma crânien."),
("R42", "Étourdissements et éblouissements", "H81.1", "Vertige paroxystique bénin",
"Les étourdissements sont le symptôme principal du VPPB."),
("R51", "Céphalée", "G43.9", "Migraine, sans précision",
"La céphalée est le symptôme principal de la migraine."),
("R63.0", "Anorexie", "C16.9", "Tumeur maligne de l'estomac",
"L'anorexie est un symptôme fréquent du cancer gastrique, couvert par le DP."),
("R53", "Malaise et fatigue", "D64.9", "Anémie, sans précision",
"La fatigue est un symptôme courant de l'anémie."),
("R31", "Hématurie, sans précision", "N20.0", "Calcul du rein",
"L'hématurie est un symptôme classique de la lithiase rénale."),
("R73.0", "Hyperglycémie SAI", "E11.9", "Diabète de type 2",
"L'hyperglycémie est un signe du diabète de type 2, couvert par le DP."),
("R60.0", "Oedème localisé", "I50.0", "Insuffisance cardiaque congestive",
"L'oedème est un signe de l'ICC, couvert par le DP."),
("R09.2", "Arrêt respiratoire", "J96.0", "Insuffisance respiratoire aiguë",
"L'arrêt respiratoire est la forme extrême de l'insuffisance respiratoire aiguë."),
("R57.0", "Choc cardiogénique", "I21.9", "IDM aigu",
"Le choc cardiogénique comme complication de l'IDM peut se coder comme DAS (mobilise des ressources), mais s'il est le tableau initial il est couvert par le DP."),
]
def generate_symptom_dp_examples(cim10_dict: dict, target: int = 500) -> list[dict]:
"""Générer des exemples de symptômes couverts par le DP."""
examples = []
# Templates de variation
user_templates = [
"Code ce diagnostic : {symptom_label}\nTYPE : DAS\nCONTEXTE : DP = {dp_code} ({dp_label})",
"Le patient est hospitalisé pour {dp_label} (DP : {dp_code}). Faut-il coder {symptom_label} ({symptom_code}) en DAS ?",
"DAS candidat : {symptom_code} ({symptom_label})\nDP du séjour : {dp_code} ({dp_label})\nCe DAS est-il pertinent ?",
]
assistant_template_null = json.dumps({
"code": None,
"confidence": "high",
"justification": "{explanation} Règle PMSI : ne pas coder un symptôme (R00-R99) si le diagnostic qui l'explique est déjà codé comme DP."
}, ensure_ascii=False)
# Générer depuis les paires prédéfinies (avec variations de templates)
for sym_code, sym_label, dp_code, dp_label, expl in SYMPTOM_DP_PAIRS:
for tmpl in user_templates:
user = tmpl.format(
symptom_code=sym_code, symptom_label=sym_label,
dp_code=dp_code, dp_label=dp_label
)
assistant = assistant_template_null.replace("{explanation}", expl)
examples.append(make_chatml(user, assistant))
# Générer des variations supplémentaires pour atteindre la cible
# Utiliser tous les codes R du dictionnaire CIM-10
r_codes = [(c, l) for c, l in cim10_dict.items() if c.startswith("R") and len(c) >= 4]
non_r_codes = [(c, l) for c, l in cim10_dict.items()
if not c.startswith("R") and not c.startswith("Z")
and c[0].isalpha() and len(c) >= 4 and len(l) > 5]
while len(examples) < target and r_codes and non_r_codes:
sym_code, sym_label = random.choice(r_codes)
dp_code, dp_label = random.choice(non_r_codes)
tmpl = random.choice(user_templates)
user = tmpl.format(
symptom_code=sym_code, symptom_label=sym_label,
dp_code=dp_code, dp_label=dp_label
)
expl = f"{sym_label} est un symptôme (code R) potentiellement couvert par le DP {dp_code} ({dp_label})."
assistant = json.dumps({
"code": None,
"confidence": "medium",
"justification": f"{expl} Règle PMSI : ne pas coder un symptôme en DAS s'il est expliqué par le DP. Vérifier si le symptôme a nécessité une prise en charge spécifique supplémentaire."
}, ensure_ascii=False)
examples.append(make_chatml(user, assistant))
random.shuffle(examples)
return examples[:target]
# ─── Type B : Redondances sémantiques ─────────────────────────────────────────
SEMANTIC_REDUNDANCIES = [
# (dominated_prefix, dominant_prefixes, explanation)
("I10", ["I11", "I12", "I13"],
"I10 (HTA essentielle) est redondant quand I11/I12/I13 est présent. Le code hypertensif spécifique inclut la composante HTA."),
("N30", ["N39"],
"N30 (cystite) est redondant quand N39.0 (infection urinaire) est présent. L'infection urinaire couvre la cystite."),
("J18", ["J15", "J16"],
"J18 (pneumonie SAI) est redondant quand J15/J16 (pneumonie spécifique) est présent. Le code spécifique prime."),
("E11.9", ["E11.0", "E11.1", "E11.2", "E11.3", "E11.4", "E11.5", "E11.6", "E11.7"],
"E11.9 (diabète type 2 SAI) est redondant si un sous-code E11.x spécifiant une complication est présent."),
("I25.9", ["I25.1", "I25.2", "I25.5"],
"I25.9 (cardiopathie ischémique chronique SAI) est redondant si un sous-code I25.x plus spécifique est présent."),
("N18.9", ["N18.1", "N18.2", "N18.3", "N18.4", "N18.5"],
"N18.9 (IRC SAI) est redondant si un stade N18.x spécifique est présent."),
("J44.9", ["J44.0", "J44.1"],
"J44.9 (BPCO SAI) est redondant si J44.0 (BPCO avec infection) ou J44.1 (BPCO avec exacerbation) est présent."),
("K21.9", ["K21.0"],
"K21.9 (RGO SAI) est redondant si K21.0 (RGO avec œsophagite) est présent."),
]
def generate_redundancy_examples(cim10_dict: dict, target: int = 200) -> list[dict]:
"""Générer des exemples de redondances sémantiques."""
examples = []
user_templates = [
"DAS candidats : {dominated} ({dom_label}), {dominant} ({sup_label})\nLesquels garder ?",
"Le codage inclut {dominated} et {dominant}. Y a-t-il une redondance ?",
"Vérification de codage :\n- DAS1 : {dominated} ({dom_label})\n- DAS2 : {dominant} ({sup_label})\nCes deux DAS sont-ils tous les deux pertinents ?",
]
for dominated_prefix, dominant_prefixes, explanation in SEMANTIC_REDUNDANCIES:
# Trouver des codes réels pour chaque préfixe
dom_codes = [(c, l) for c, l in cim10_dict.items() if c.startswith(dominated_prefix) and len(l) > 3]
sup_codes = [(c, l) for c, l in cim10_dict.items()
if any(c.startswith(dp) for dp in dominant_prefixes) and len(l) > 3]
if not dom_codes or not sup_codes:
continue
for _ in range(target // len(SEMANTIC_REDUNDANCIES) + 1):
dom_code, dom_label = random.choice(dom_codes)
sup_code, sup_label = random.choice(sup_codes)
tmpl = random.choice(user_templates)
user = tmpl.format(
dominated=dom_code, dom_label=dom_label,
dominant=sup_code, sup_label=sup_label
)
assistant = json.dumps({
"garder": [sup_code],
"retirer": [dom_code],
"justification": explanation
}, ensure_ascii=False)
examples.append(make_chatml(user, assistant))
random.shuffle(examples)
return examples[:target]
# ─── Type C : DAS non significatifs ───────────────────────────────────────────
# Diagnostics fréquemment mentionnés dans les antécédents mais sans ressources
NON_SIGNIFICANT_DAS = [
("J30.1", "Rhinite allergique due au pollen", "rhinite allergique mentionnée dans les antécédents"),
("J45.9", "Asthme, sans précision", "asthme stable mentionné dans les antécédents"),
("M54.5", "Lombalgie basse", "lombalgie chronique mentionnée dans les antécédents"),
("K21.0", "RGO avec œsophagite", "RGO mentionné dans les antécédents"),
("H52.1", "Myopie", "myopie mentionnée dans les antécédents"),
("E78.0", "Hypercholestérolémie pure", "hypercholestérolémie dans les antécédents, traitement habituel"),
("E03.9", "Hypothyroïdie, sans précision", "hypothyroïdie sous Lévothyrox dans les antécédents"),
("F32.0", "Épisode dépressif léger", "dépression traitée mentionnée dans les antécédents"),
("G47.3", "Apnée du sommeil", "SAOS appareillé mentionné dans les antécédents"),
("M81.9", "Ostéoporose sans fracture", "ostéoporose connue dans les antécédents"),
("H40.1", "Glaucome primaire à angle ouvert", "glaucome traité dans les antécédents"),
("I84.1", "Hémorroïdes internes avec complication", "hémorroïdes mentionnées dans les antécédents"),
("K58.9", "Syndrome du côlon irritable", "SCI mentionné dans les antécédents"),
("L40.0", "Psoriasis vulgaire", "psoriasis stable dans les antécédents"),
("N40", "Hyperplasie de la prostate", "HBP traitée dans les antécédents"),
("E66.9", "Obésité, sans précision", "obésité mentionnée, pas de prise en charge spécifique"),
("Z87.1", "Antécédents personnels de maladies de l'appareil digestif", "antécédent de chirurgie digestive ancienne"),
("Z86.7", "Antécédents personnels de maladies de l'appareil circulatoire", "antécédent d'AVC il y a 5 ans"),
("Z92.1", "Antécédents de traitement anticoagulant au long cours", "patient sous AVK au long cours"),
("F17.2", "Dépendance au tabac", "tabagisme actif mais pas de sevrage pendant le séjour"),
]
def generate_non_significant_examples(target: int = 300) -> list[dict]:
"""Générer des exemples de DAS non significatifs (antécédents sans ressources)."""
examples = []
user_templates = [
"Le patient a {context}. Faut-il le coder en DAS ?",
"Antécédent : {code} ({label}). Le patient est hospitalisé pour une autre raison. Ce diagnostic doit-il être codé comme DAS ?",
"Le dossier mentionne : {context}. Est-ce un DAS pertinent pour le séjour ?\nDP : {dp_code} ({dp_label})",
"Lors du codage du séjour (DP : {dp_code}), le dossier fait état de {context}. Coder {code} en DAS ?",
]
dp_examples = [
("K80.1", "Lithiase vésiculaire avec cholécystite"),
("S72.0", "Fracture du col du fémur"),
("I63.9", "Infarctus cérébral, sans précision"),
("J18.1", "Pneumonie lobaire, sans précision"),
("I21.9", "IDM aigu, sans précision"),
("C34.9", "Tumeur maligne des bronches ou du poumon"),
("K35.8", "Appendicite aiguë"),
("N20.0", "Calcul du rein"),
("G45.9", "AIT, sans précision"),
("A41.9", "Sepsis, sans précision"),
("C18.9", "Tumeur maligne du côlon"),
("I48.9", "Fibrillation auriculaire"),
]
# Générer en croisant chaque DAS avec plusieurs DPs
# On veut ~240 négatifs pour avoir assez de marge avec les positifs
combos_per_das = max(3, (target * 3 // 4) // len(NON_SIGNIFICANT_DAS) + 1)
for code, label, context in NON_SIGNIFICANT_DAS:
selected_dps = random.sample(dp_examples, min(combos_per_das, len(dp_examples)))
for dp_code, dp_label in selected_dps:
tmpl = random.choice(user_templates)
user = tmpl.format(
code=code, label=label, context=context,
dp_code=dp_code, dp_label=dp_label
)
assistant = json.dumps({
"coder": False,
"code": code,
"justification": f"Un DAS ne doit être codé que s'il a nécessité des ressources supplémentaires pendant le séjour (examens, traitements, surveillance spécifique). {label} mentionné uniquement dans les antécédents, sans prise en charge spécifique durant le séjour, ne justifie pas un DAS."
}, ensure_ascii=False)
examples.append(make_chatml(user, assistant))
# Ajouter des exemples positifs (contraste) — DAS qui DOIVENT être codés
positive_das = [
("E11.6", "Diabète de type 2 avec complications", "diabète déséquilibré ayant nécessité une adaptation thérapeutique",
"Le diabète a mobilisé des ressources supplémentaires (adaptation insuline, surveillance glycémique renforcée)."),
("I10", "HTA essentielle", "HTA sévère ayant nécessité un traitement IV pendant le séjour",
"L'HTA a nécessité un traitement intraveineux spécifique, justifiant le codage en DAS."),
("N18.4", "IRC stade 4", "IRC ayant nécessité une adaptation posologique de tous les médicaments",
"L'IRC a mobilisé des ressources (adaptation posologique, surveillance créatinine quotidienne)."),
("E87.1", "Hypo-osmolalité et hyponatrémie", "hyponatrémie sévère découverte pendant le séjour",
"L'hyponatrémie a nécessité un bilan étiologique et un traitement spécifique."),
("J96.0", "Insuffisance respiratoire aiguë", "détresse respiratoire ayant nécessité une oxygénothérapie",
"L'insuffisance respiratoire a mobilisé des ressources (oxygénothérapie, surveillance SpO2, GDS)."),
("N17.9", "Insuffisance rénale aiguë", "IRA survenue pendant le séjour ayant nécessité une surveillance biologique quotidienne",
"L'IRA a nécessité des bilans répétés et une adaptation thérapeutique, justifiant le DAS."),
("D62", "Anémie posthémorragique aiguë", "anémie aiguë ayant nécessité une transfusion de 2 CGR",
"L'anémie a mobilisé des ressources (transfusion, surveillance post-transfusionnelle)."),
("E87.6", "Hypokaliémie", "hypokaliémie sévère découverte en biologie nécessitant une supplémentation IV",
"L'hypokaliémie a nécessité un traitement spécifique IV, justifiant le DAS."),
]
combos_per_pos = max(3, (target // 4) // len(positive_das) + 1)
for code, label, context, justification in positive_das:
selected_dps = random.sample(dp_examples, min(combos_per_pos, len(dp_examples)))
for dp_code, dp_label in selected_dps:
user = f"Le patient est hospitalisé pour {dp_label} (DP : {dp_code}). Il présente aussi : {context}. Faut-il coder {code} ({label}) en DAS ?"
assistant = json.dumps({
"coder": True,
"code": code,
"justification": justification
}, ensure_ascii=False)
examples.append(make_chatml(user, assistant))
random.shuffle(examples)
return examples[:target]
def main():
print("=" * 60)
print("Génération d'exemples négatifs (quand NE PAS coder)")
print("=" * 60)
print("\nChargement du dictionnaire CIM-10...")
cim10_dict = load_cim10_dict()
print(f" {len(cim10_dict)} codes")
all_examples = []
# Type A : Symptômes couverts par le DP
print("\nType A : Symptômes couverts par le DP...")
symptom_examples = generate_symptom_dp_examples(cim10_dict, target=500)
print(f"{len(symptom_examples)} exemples")
all_examples.extend(symptom_examples)
# Type B : Redondances sémantiques
print("\nType B : Redondances sémantiques...")
redundancy_examples = generate_redundancy_examples(cim10_dict, target=200)
print(f"{len(redundancy_examples)} exemples")
all_examples.extend(redundancy_examples)
# Type C : DAS non significatifs
print("\nType C : DAS non significatifs...")
non_sig_examples = generate_non_significant_examples(target=300)
print(f"{len(non_sig_examples)} exemples")
all_examples.extend(non_sig_examples)
# Mélanger et sauvegarder
random.shuffle(all_examples)
with open(OUTPUT, "w") as f:
for ex in all_examples:
f.write(json.dumps(ex, ensure_ascii=False) + "\n")
print(f"\n{'='*60}")
print(f"Total : {len(all_examples)} exemples → {OUTPUT}")
print(f" Type A (symptômes/DP) : {len(symptom_examples)}")
print(f" Type B (redondances) : {len(redundancy_examples)}")
print(f" Type C (non signif.) : {len(non_sig_examples)}")
print(f"Taille : {OUTPUT.stat().st_size / 1024:.0f} Ko")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,321 @@
#!/usr/bin/env python3
"""
Exercices de discrimination entre codes CIM-10 siblings (même parent).
Stratégie :
- Utilise la hiérarchie FHIR pour identifier les groupes de siblings
- Focus sur les top 100 familles CIM-10 les plus fréquentes du pipeline
- Pour chaque groupe, génère un scénario clinique via Claude Opus 4.6
- La réponse explique pourquoi un code et pas l'autre
Cible : 800 exemples
Sources : smt_cim10_fhir.json + cache Ollama (codes fréquents)
Nécessite : ANTHROPIC_API_KEY en variable d'environnement
Usage :
python scripts/15_generate_discrimination.py [--dry-run] [--max N]
"""
import argparse
import json
import os
import random
import re
import sys
import time
from collections import Counter
from pathlib import Path
random.seed(42)
BASE = Path(__file__).resolve().parent.parent
T2A = BASE.parent / "t2a"
RAW = BASE / "data" / "raw"
OUTPUT = BASE / "data" / "processed" / "discrimination_chatml.jsonl"
OUTPUT.parent.mkdir(parents=True, exist_ok=True)
MODEL = "claude-opus-4-6"
SYSTEM_PROMPT = (
"Tu es un médecin DIM expert en codage CIM-10 pour le PMSI français. "
"Tu sais discriminer les codes CIM-10 proches (siblings) et choisir le plus approprié."
)
GENERATION_PROMPT = """Tu es un formateur DIM. Génère un exercice de discrimination entre codes CIM-10 proches.
CODES À DISCRIMINER (même catégorie {parent}) :
{codes_list}
Génère un objet JSON avec :
1. "scenario" : un cas clinique réaliste (2-3 phrases) où le choix entre ces codes est subtil
2. "reponse" : un objet JSON contenant :
- "analyse_clinique" : interprétation du cas clinique
- "codes_candidats" : les 2-3 codes candidats et pourquoi chacun est envisagé
- "discrimination" : la différence clé entre ces codes (inclusions, exclusions, spécificité)
- "code" : le code correct pour ce scénario
- "confidence" : "high"
- "justification" : pourquoi CE code et pas les autres
Réponds UNIQUEMENT avec le JSON, sans texte avant/après."""
def load_fhir() -> tuple[list, dict]:
"""Charger les concepts FHIR."""
fhir_path = RAW / "smt_cim10_fhir.json"
data = json.loads(fhir_path.read_text())
concepts = data["concept"]
by_code = {c["code"]: c for c in concepts}
return concepts, by_code
def get_parent(concept: dict) -> str:
for p in concept.get("property", []):
if p["code"] == "parent":
return p.get("valueCode", "")
return ""
def get_type(concept: dict) -> str:
for p in concept.get("property", []):
if p["code"] == "type":
return p.get("valueString", "")
return ""
def get_inclusion_note(concept: dict) -> str:
for p in concept.get("property", []):
if p["code"] == "inclusionNote":
return p.get("valueString", "")
return ""
def get_exclusion_note(concept: dict) -> str:
for p in concept.get("property", []):
if p["code"] == "exclusionNote":
return p.get("valueString", "")
return ""
def get_frequent_families() -> Counter:
"""Extraire les familles CIM-10 les plus fréquentes depuis le cache Ollama."""
families = Counter()
for cache_path in [T2A / "data" / "ollama_cache.json", T2A / "data" / "ollama_cache_gemma3.bak"]:
if not cache_path.exists():
continue
data = json.loads(cache_path.read_text())
for entry in data.get("entries", {}).values():
code = entry.get("code", "")
if code and len(code) >= 3 and code[0].isalpha():
families[code[:3]] += 1
return families
def build_sibling_groups(concepts: list, by_code: dict) -> dict[str, list[dict]]:
"""Grouper les codes par parent (siblings)."""
children_by_parent = {}
for c in concepts:
if get_type(c) != "category":
continue
parent = get_parent(c)
if parent and parent in by_code:
children_by_parent.setdefault(parent, []).append(c)
return children_by_parent
def format_codes_for_prompt(siblings: list[dict]) -> str:
"""Formater les codes pour le prompt LLM."""
lines = []
for sib in siblings:
code = sib["code"]
display = sib["display"]
incl = get_inclusion_note(sib)
excl = get_exclusion_note(sib)
line = f"- {code} : {display}"
if incl:
line += f"\n Comprend : {incl[:200]}"
if excl:
line += f"\n Exclut : {excl[:200]}"
lines.append(line)
return "\n".join(lines)
def call_claude(client, prompt: str, max_retries: int = 2) -> str | None:
"""Appel Claude Opus 4.6 via API Anthropic avec retry."""
for attempt in range(max_retries + 1):
try:
response = client.messages.create(
model=MODEL,
max_tokens=2048,
temperature=0.7,
messages=[{"role": "user", "content": prompt}],
)
return response.content[0].text
except Exception as e:
if attempt < max_retries:
wait = 2 ** (attempt + 1)
print(f" Retry in {wait}s: {e}")
time.sleep(wait)
else:
print(f" Claude error: {e}")
return None
def parse_llm_response(response_text: str) -> dict | None:
"""Parse la réponse JSON du LLM."""
if not response_text:
return None
text = response_text.strip()
if "```json" in text:
text = text.split("```json", 1)[1].split("```", 1)[0].strip()
elif "```" in text:
text = text.split("```", 1)[1].split("```", 1)[0].strip()
try:
data = json.loads(text)
if "scenario" in data and "reponse" in data:
return data
except json.JSONDecodeError:
pass
# Fallback
brace_start = text.find("{")
if brace_start >= 0:
depth = 0
for i in range(brace_start, len(text)):
if text[i] == "{":
depth += 1
elif text[i] == "}":
depth -= 1
if depth == 0:
try:
data = json.loads(text[brace_start:i+1])
if "scenario" in data:
return data
except json.JSONDecodeError:
break
return None
def make_chatml(scenario: str, response: dict, parent_code: str, siblings_desc: str) -> dict:
"""Créer un exemple ChatML."""
user_content = (
f"Cas clinique :\n{scenario}\n\n"
f"Codes CIM-10 candidats (catégorie {parent_code}) :\n{siblings_desc}\n\n"
"Quel code est le plus approprié ? Explique ton raisonnement de discrimination."
)
assistant_content = json.dumps(response, ensure_ascii=False)
return {
"messages": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user_content},
{"role": "assistant", "content": assistant_content},
]
}
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--dry-run", action="store_true", help="Pas d'appel LLM")
parser.add_argument("--max", type=int, default=800, help="Max exemples à générer")
args = parser.parse_args()
print("=" * 60)
print("Génération d'exercices de discrimination CIM-10")
print(f"Modèle : {MODEL}")
print("=" * 60)
# Vérifier la clé API
if not args.dry_run:
api_key = os.environ.get("ANTHROPIC_API_KEY")
if not api_key:
print("Erreur: ANTHROPIC_API_KEY non définie.")
print(" export ANTHROPIC_API_KEY='sk-ant-...'")
sys.exit(1)
import anthropic
client = anthropic.Anthropic(api_key=api_key)
else:
client = None
print("\nChargement FHIR...")
concepts, by_code = load_fhir()
print(f" {len(concepts)} concepts")
print("\nIdentification des familles fréquentes (cache Ollama)...")
freq_families = get_frequent_families()
top_families = [code for code, _ in freq_families.most_common(150)]
print(f" Top familles : {len(top_families)} (ex: {', '.join(top_families[:10])})")
print("\nConstruction des groupes de siblings...")
sibling_groups = build_sibling_groups(concepts, by_code)
print(f" {len(sibling_groups)} groupes")
# Filtrer : 2-8 siblings, prioriser les familles fréquentes
candidates = []
for parent_code, siblings in sibling_groups.items():
if len(siblings) < 2 or len(siblings) > 8:
continue
priority = 2 if parent_code[:3] in top_families else 0
n_with_notes = sum(1 for s in siblings if get_inclusion_note(s) or get_exclusion_note(s))
priority += n_with_notes
candidates.append((parent_code, siblings, priority))
candidates.sort(key=lambda x: -x[2])
print(f" Candidats filtrés (2-8 siblings) : {len(candidates)}")
target = min(args.max, len(candidates))
candidates = candidates[:target]
if args.dry_run:
for parent_code, siblings, prio in candidates[:20]:
parent_display = by_code[parent_code]["display"] if parent_code in by_code else "?"
sib_codes = ", ".join(s["code"] for s in siblings)
print(f" [{prio}] {parent_code} ({parent_display}): {sib_codes}")
print(f"\n[DRY RUN] {len(candidates)} groupes à traiter. Relancez sans --dry-run.")
return
# Générer les exercices via Claude
examples = []
n_ok = 0
n_fail = 0
for i, (parent_code, siblings, _) in enumerate(candidates):
if len(siblings) > 4:
selected = random.sample(siblings, 4)
else:
selected = siblings
codes_list = format_codes_for_prompt(selected)
parent_display = by_code[parent_code]["display"] if parent_code in by_code else parent_code
prompt = GENERATION_PROMPT.format(parent=f"{parent_code} ({parent_display})", codes_list=codes_list)
response_text = call_claude(client, prompt)
parsed = parse_llm_response(response_text)
if parsed and "scenario" in parsed and "reponse" in parsed:
siblings_desc = "\n".join(f"- {s['code']} : {s['display']}" for s in selected)
example = make_chatml(parsed["scenario"], parsed["reponse"], parent_code, siblings_desc)
examples.append(example)
n_ok += 1
else:
n_fail += 1
if (i + 1) % 50 == 0:
print(f" Progression : {i+1}/{len(candidates)} (ok={n_ok}, fail={n_fail})")
# Mélanger et sauvegarder
random.shuffle(examples)
with open(OUTPUT, "w") as f:
for ex in examples:
f.write(json.dumps(ex, ensure_ascii=False) + "\n")
print(f"\n{'='*60}")
print(f"Total : {len(examples)} exemples → {OUTPUT}")
print(f" OK : {n_ok}, Échecs : {n_fail}")
print(f"Taille : {OUTPUT.stat().st_size / 1024:.0f} Ko")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,351 @@
#!/usr/bin/env python3
"""
Extraction des règles du Guide Méthodologique MCO 2026 → Q&A raisonnement.
Le Guide Méthodologique est le document de référence ATIH qui définit :
- Les règles de codage du DP, DR, DAS
- Les règles CMA et niveaux de sévérité
- Les règles de codage des actes CCAM
- Les cas particuliers (séjours multi-unités, transferts, etc.)
Stratégie :
- Parser le PDF en sections
- Extraire les règles de codage
- Générer des Q&A directes (sans LLM) pour les définitions
- Générer des Q&A de raisonnement via Claude Opus 4.6 pour les règles
- Focus sur les règles DP/DAS/CMA les plus applicables
Cible : 500 exemples (183 directs + ~320 LLM)
Source : data/raw/guide_methodo_mco_2026.pdf
Nécessite : ANTHROPIC_API_KEY en variable d'environnement
Usage :
python scripts/16_parse_guide_metho.py [--dry-run] [--max N]
"""
import argparse
import json
import os
import random
import re
import sys
import time
from pathlib import Path
import pdfplumber
random.seed(42)
BASE = Path(__file__).resolve().parent.parent
PDF_PATH = BASE / "data" / "raw" / "guide_methodo_mco_2026.pdf"
OUTPUT = BASE / "data" / "processed" / "guide_metho_chatml.jsonl"
OUTPUT.parent.mkdir(parents=True, exist_ok=True)
MODEL = "claude-opus-4-6"
SYSTEM_PROMPT = (
"Tu es un médecin DIM expert en codage PMSI. "
"Tu t'appuies sur le Guide Méthodologique de production des informations "
"relatives à l'activité MCO (ATIH 2026)."
)
# Patterns pour détecter les règles dans le guide
RULE_PATTERNS = [
re.compile(r"(?:le\s+)?diagnostic\s+principal\s+(?:est|doit|sera|correspond)", re.IGNORECASE),
re.compile(r"(?:le\s+)?diagnostic\s+(?:relié|associé)\s+(?:est|doit|sera)", re.IGNORECASE),
re.compile(r"(?:la\s+)?CMA\s+(?:est|correspond|ne\s+peut)", re.IGNORECASE),
re.compile(r"(?:le\s+)?DAS\s+(?:est|doit|sera|ne\s+doit\s+pas)", re.IGNORECASE),
re.compile(r"(?:le\s+)?DP\s+(?:est|doit|sera|ne\s+doit\s+pas)", re.IGNORECASE),
re.compile(r"(?:le\s+)?DR\s+(?:est|doit|sera)", re.IGNORECASE),
re.compile(r"(?:on\s+)?code(?:ra)?\s+", re.IGNORECASE),
re.compile(r"ne\s+(?:pas|doit\s+pas|faut\s+pas)\s+(?:coder|enregistrer|recueillir)", re.IGNORECASE),
re.compile(r"séjour\s+multi-?uni", re.IGNORECASE),
re.compile(r"(?:transfert|mutation)\s+", re.IGNORECASE),
re.compile(r"(?:sévérité|niveau\s+de\s+sévérité|complication)", re.IGNORECASE),
re.compile(r"(?:ressources?\s+supplémentaires?|consomm)", re.IGNORECASE),
re.compile(r"(?:groupage|GHM|GHS|CMD)", re.IGNORECASE),
]
GENERATION_PROMPT = """Tu es un formateur DIM. À partir de cette règle du Guide Méthodologique MCO, génère un exercice de raisonnement.
RÈGLE (source : Guide Méthodologique MCO 2026, section « {section} ») :
{rule_text}
Génère un objet JSON avec :
1. "scenario" : un cas clinique réaliste et concis (2-3 phrases) qui illustre l'application de cette règle
2. "reponse" : un objet JSON contenant :
- "analyse_clinique" : interprétation du cas
- "regle_pmsi" : la règle du guide méthodologique qui s'applique (citée)
- "code" : le code CIM-10 ou l'action de codage correcte
- "confidence" : "high"
- "justification" : pourquoi cette règle s'applique à ce cas
Réponds UNIQUEMENT avec le JSON, sans texte avant/après."""
# Templates pour les Q&A directes (sans LLM)
DIRECT_QA_TEMPLATES = [
("Quelle est la définition du {concept} selon le Guide Méthodologique MCO ?",
"Selon le Guide Méthodologique MCO 2026 :\n\n{text}"),
("Quelles sont les règles de codage pour {concept} selon le Guide Méthodologique ?",
"Le Guide Méthodologique MCO 2026 précise les règles suivantes pour {concept} :\n\n{text}"),
]
def extract_sections(pdf_path: Path) -> list[tuple[str, str]]:
"""Extraire les sections du Guide Méthodologique."""
with pdfplumber.open(str(pdf_path)) as pdf:
full_text = ""
for page in pdf.pages:
text = page.extract_text() or ""
full_text += text + "\n\n"
lines = []
for line in full_text.split("\n"):
stripped = line.strip()
if re.match(r"^\s*\d{1,3}\s*$", stripped):
continue
if re.match(r"^(Guide méthodologique|ATIH|Page\s+\d)", stripped, re.IGNORECASE):
continue
if "...." in stripped:
continue
lines.append(line)
clean_text = "\n".join(lines)
section_pattern = re.compile(
r"^(\d+(?:\.\d+)*\.?)\s+([A-ZÀÂÉÈÊËÎÏÔÙÛÜÇ].*)",
re.MULTILINE
)
sections = []
matches = list(section_pattern.finditer(clean_text))
for i, match in enumerate(matches):
num = match.group(1).rstrip(".")
title = match.group(2).strip()
start = match.end()
end = matches[i + 1].start() if i + 1 < len(matches) else len(clean_text)
body = clean_text[start:end].strip()
if len(body) > 80:
section_title = f"{num}. {title}"
sections.append((section_title, body))
return sections
def extract_rule_paragraphs(sections: list[tuple[str, str]]) -> list[tuple[str, str]]:
"""Extraire les paragraphes contenant des règles de codage."""
rules = []
for section_title, body in sections:
paragraphs = re.split(r"\n\s*\n", body)
for para in paragraphs:
para = para.strip()
if len(para) < 80:
continue
matches = sum(1 for pat in RULE_PATTERNS if pat.search(para))
if matches >= 1:
text = para[:1500] if len(para) > 1500 else para
rules.append((section_title, text))
return rules
def call_claude(client, prompt: str, max_retries: int = 2) -> str | None:
"""Appel Claude Opus 4.6 via API Anthropic avec retry."""
for attempt in range(max_retries + 1):
try:
response = client.messages.create(
model=MODEL,
max_tokens=2048,
temperature=0.7,
messages=[{"role": "user", "content": prompt}],
)
return response.content[0].text
except Exception as e:
if attempt < max_retries:
wait = 2 ** (attempt + 1)
print(f" Retry in {wait}s: {e}")
time.sleep(wait)
else:
print(f" Claude error: {e}")
return None
def parse_llm_response(response_text: str) -> dict | None:
"""Parse la réponse JSON du LLM."""
if not response_text:
return None
text = response_text.strip()
if "```json" in text:
text = text.split("```json", 1)[1].split("```", 1)[0].strip()
elif "```" in text:
text = text.split("```", 1)[1].split("```", 1)[0].strip()
try:
data = json.loads(text)
if "scenario" in data and "reponse" in data:
return data
except json.JSONDecodeError:
pass
# Fallback
brace_start = text.find("{")
if brace_start >= 0:
depth = 0
for i in range(brace_start, len(text)):
if text[i] == "{":
depth += 1
elif text[i] == "}":
depth -= 1
if depth == 0:
try:
data = json.loads(text[brace_start:i+1])
if "scenario" in data:
return data
except json.JSONDecodeError:
break
return None
def make_chatml(system: str, user: str, assistant: str) -> dict:
return {
"messages": [
{"role": "system", "content": system},
{"role": "user", "content": user},
{"role": "assistant", "content": assistant},
]
}
def generate_direct_qa(sections: list[tuple[str, str]]) -> list[dict]:
"""Générer des Q&A directes depuis les sections (sans LLM)."""
examples = []
concept_keywords = {
"diagnostic principal": ["DP", "diagnostic principal"],
"diagnostic relié": ["DR", "diagnostic relié"],
"diagnostic associé significatif": ["DAS", "diagnostic associé"],
"complication ou morbidité associée": ["CMA", "complication", "sévérité"],
"séjour multi-unités": ["multi-unité", "transfert", "mutation"],
"groupage GHM": ["GHM", "GHS", "groupage", "CMD"],
"actes CCAM": ["CCAM", "acte", "procédure"],
}
for section_title, body in sections:
body_lower = body.lower()
for concept, keywords in concept_keywords.items():
if any(kw.lower() in body_lower for kw in keywords):
text = body[:2000] if len(body) > 2000 else body
tmpl_q, tmpl_a = random.choice(DIRECT_QA_TEMPLATES)
user = tmpl_q.format(concept=f"{concept} (section {section_title})")
assistant = tmpl_a.format(concept=concept, text=text)
examples.append(make_chatml(SYSTEM_PROMPT, user, assistant))
break
return examples
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--dry-run", action="store_true", help="Pas d'appel LLM")
parser.add_argument("--max", type=int, default=500, help="Max exemples à générer via LLM")
args = parser.parse_args()
print("=" * 60)
print("Parsing du Guide Méthodologique MCO 2026")
print(f"Modèle : {MODEL}")
print("=" * 60)
if not PDF_PATH.exists():
print(f"PDF non trouvé : {PDF_PATH}")
return
# Vérifier la clé API
if not args.dry_run:
api_key = os.environ.get("ANTHROPIC_API_KEY")
if not api_key:
print("Erreur: ANTHROPIC_API_KEY non définie.")
print(" export ANTHROPIC_API_KEY='sk-ant-...'")
sys.exit(1)
import anthropic
client = anthropic.Anthropic(api_key=api_key)
else:
client = None
print(f"\nParsing de {PDF_PATH.name} ({PDF_PATH.stat().st_size / 1024 / 1024:.1f} Mo)...")
sections = extract_sections(PDF_PATH)
print(f" Sections extraites : {len(sections)}")
# Q&A directes (sans LLM)
print("\nGénération des Q&A directes (sans LLM)...")
direct_examples = generate_direct_qa(sections)
print(f"{len(direct_examples)} exemples directs")
# Extraire les paragraphes de règles
print("\nExtraction des règles de codage...")
rules = extract_rule_paragraphs(sections)
print(f" Règles détectées : {len(rules)}")
if args.dry_run:
for i, (section, rule) in enumerate(rules[:20]):
print(f" [{i+1}] [{section}] {rule[:100]}...")
print(f"\n[DRY RUN] {len(rules)} règles à traiter. Relancez sans --dry-run.")
return
# Limiter le nombre de règles pour LLM
if len(rules) > args.max:
random.shuffle(rules)
rules = rules[:args.max]
# Générer les Q&A via Claude
print(f"\nGénération via Claude ({len(rules)} règles)...")
llm_examples = []
n_ok = 0
n_fail = 0
for i, (section_title, rule_text) in enumerate(rules):
prompt = GENERATION_PROMPT.format(section=section_title, rule_text=rule_text)
response_text = call_claude(client, prompt)
parsed = parse_llm_response(response_text)
if parsed and "scenario" in parsed and "reponse" in parsed:
user_content = (
f"Cas clinique :\n{parsed['scenario']}\n\n"
f"Applique les règles du Guide Méthodologique MCO (section « {section_title} »)."
)
assistant_content = json.dumps(parsed["reponse"], ensure_ascii=False)
llm_examples.append(make_chatml(SYSTEM_PROMPT, user_content, assistant_content))
n_ok += 1
else:
n_fail += 1
if (i + 1) % 50 == 0:
print(f" Progression : {i+1}/{len(rules)} (ok={n_ok}, fail={n_fail})")
print(f" Résultat LLM : {n_ok} exemples, {n_fail} échecs")
# Fusionner
all_examples = direct_examples + llm_examples
random.shuffle(all_examples)
with open(OUTPUT, "w") as f:
for ex in all_examples:
f.write(json.dumps(ex, ensure_ascii=False) + "\n")
print(f"\n{'='*60}")
print(f"Total : {len(all_examples)} exemples → {OUTPUT}")
print(f" Directs (sans LLM) : {len(direct_examples)}")
print(f" LLM (Claude) : {len(llm_examples)}")
print(f"Taille : {OUTPUT.stat().st_size / 1024:.0f} Ko")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,92 @@
#!/usr/bin/env python3
"""
Export un checkpoint LoRA en GGUF pour Ollama.
Stratégie : charge le modèle de base via Unsloth (pour bénéficier de
save_pretrained_gguf), puis applique le LoRA depuis le checkpoint
via PeftModel.from_pretrained en copiant la méthode GGUF.
Usage: python scripts/export_checkpoint_gguf.py [--checkpoint models/pmsi-lora-checkpoints/checkpoint-7000]
"""
import argparse
from pathlib import Path
BASE = Path(__file__).resolve().parent.parent
OUTPUT = BASE / "models"
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", default=str(OUTPUT / "pmsi-lora-checkpoints" / "checkpoint-7000"))
parser.add_argument("--model", default="unsloth/gemma-3-12b-it-bnb-4bit")
parser.add_argument("--max-seq-length", type=int, default=512)
parser.add_argument("--quant", default="q4_k_m")
args = parser.parse_args()
from unsloth import FastLanguageModel
from peft import PeftModel
checkpoint_dir = Path(args.checkpoint)
gguf_dir = OUTPUT / "pmsi-gguf-v2"
gguf_dir.mkdir(parents=True, exist_ok=True)
# Étape 1 : Charger base via Unsloth (installe save_pretrained_gguf)
print(f"[1/3] Chargement du modèle de base via Unsloth...")
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=args.model,
max_seq_length=args.max_seq_length,
dtype=None,
load_in_4bit=True,
)
# Sauver la méthode GGUF avant que PeftModel ne l'écrase
save_gguf_fn = model.save_pretrained_gguf
# Étape 2 : Appliquer le LoRA checkpoint
print(f"[2/3] Application LoRA depuis {checkpoint_dir.name}...")
model = PeftModel.from_pretrained(model, str(checkpoint_dir))
# Réattacher la méthode GGUF d'Unsloth au modèle PeftModel
model.save_pretrained_gguf = save_gguf_fn.__func__.__get__(model, type(model))
# Vérifier que les poids LoRA sont bien chargés
lora_params = [n for n, p in model.named_parameters() if "lora" in n and p.requires_grad]
print(f" Paramètres LoRA actifs : {len(lora_params)}")
# Étape 3 : Export GGUF
print(f"[3/3] Export GGUF ({args.quant})...")
model.save_pretrained_gguf(
str(gguf_dir),
tokenizer,
quantization_method=args.quant,
)
# Résultat
gguf_files = sorted(gguf_dir.glob("*.gguf"), key=lambda f: f.stat().st_size)
if not gguf_files:
print("Aucun GGUF produit !")
return
final_gguf = gguf_files[0]
for g in gguf_files:
size_gb = g.stat().st_size / 1024**3
print(f" {g.name} ({size_gb:.1f} Go)")
# Modelfile
modelfile_path = gguf_dir / "Modelfile"
with open(modelfile_path, "w") as f:
f.write(f"FROM {final_gguf.name}\n\n")
f.write("PARAMETER temperature 0.3\n")
f.write("PARAMETER top_p 0.9\n")
f.write("PARAMETER num_ctx 8192\n")
print(f"\nTerminé !")
print(f" GGUF : {final_gguf}")
print(f"\nPour importer dans Ollama :")
print(f" cd {gguf_dir}")
print(f" ollama create pmsi-coder -f Modelfile")
if __name__ == "__main__":
main()