feat: rééquilibrage dataset LoRA — raisonnement DIM vs mémorisation

Passe de 95/3/2 (lookups/raisonnement/règles) à ~31/49/20.
Dataset cible ~16K exemples denses (vs 66K de lookups avant).

Modifiés :
- 03_convert_cache.py : cache complet 1840 entrées (actuel + backup)
- 04_build_dataset.py : subsampling agressif (CIM-10 1.5K, CCAM 1.5K,
  CoCoA 2K) + sélection intelligente priorisant le raisonnement
- 12_generate_pipeline_examples.py : 3 templates (court + long + CPAM),
  cache actuel, cible ~2800 exemples

Créés :
- 13_generate_fascicule_reasoning.py : parsing 10 fascicules ATIH,
  génération Q&A raisonnement via Claude Opus 4.6 (~450 exemples)
- 14_generate_negative_examples.py : 1000 exemples négatifs
  (symptômes/DP, redondances sémantiques, DAS non significatifs)
- 15_generate_discrimination.py : 800 exercices de discrimination
  entre codes siblings CIM-10 via Claude Opus 4.6
- 16_parse_guide_metho.py : extraction Guide Méthodologique MCO 2026,
  Q&A directes + raisonnement via Claude Opus 4.6 (~500 exemples)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
dom
2026-02-16 19:42:33 +01:00
commit 06100df236
21 changed files with 6106 additions and 0 deletions

View File

@@ -0,0 +1,460 @@
#!/usr/bin/env python3
"""
Phase 3 — Génération de triplets (anchor, positive, negative) pour fine-tuning d'embedding.
Sources :
1. Cache Ollama (1840 paires diagnostic → code CIM-10)
2. Métadonnées FAISS (12444 entrées CIM-10 avec extraits)
3. CoCoA (synonymes, exclusions)
4. Données CIM-10 FHIR (descriptions → codes)
Format de sortie : JSONL avec {anchor, positive, negative}
- anchor : texte clinique / diagnostic tel qu'écrit par le médecin
- positive : extrait de référence CIM-10 correspondant au bon code
- negative : extrait de référence CIM-10 d'un code similaire mais incorrect (hard negative)
Usage :
python scripts/09_build_embedding_triplets.py [--output data/datasets/triplets.jsonl]
"""
import argparse
import json
import random
from collections import defaultdict
from pathlib import Path
random.seed(42)
BASE = Path(__file__).resolve().parent.parent
T2A = BASE.parent / "t2a"
# Sources
CACHE_PATH = T2A / "data" / "ollama_cache.json"
METADATA_PATH = T2A / "data" / "rag_index" / "metadata.json"
COCOA_PATH = BASE / "data" / "processed" / "cocoa_entries_debug.json"
CIM10_CHATML = BASE / "data" / "processed" / "cim10_chatml.jsonl"
# Output
DEFAULT_OUTPUT = BASE / "data" / "datasets" / "triplets.jsonl"
def load_faiss_metadata() -> tuple[dict[str, list[dict]], list[dict]]:
"""Charge les métadonnées FAISS et indexe par code CIM-10.
Returns:
(code_to_entries, all_cim10_entries)
"""
with open(METADATA_PATH) as f:
metadata = json.load(f)
code_to_entries = defaultdict(list)
all_cim10 = []
for m in metadata:
if m.get("document") not in ("cim10", "cim10_alpha"):
continue
code = m.get("code", "").strip()
if not code:
continue
all_cim10.append(m)
# Normaliser le code (avec et sans point)
code_to_entries[code].append(m)
# Aussi indexer par préfixe 3 caractères pour les hard negatives
# Ex: I10 → chapitre I, catégorie I10
print(f"FAISS metadata : {len(all_cim10)} entrées CIM-10, {len(code_to_entries)} codes uniques")
return code_to_entries, all_cim10
def get_chapter(code: str) -> str:
"""Extrait le chapitre CIM-10 (première lettre)."""
return code[0] if code else ""
def get_category(code: str) -> str:
"""Extrait la catégorie CIM-10 (3 premiers caractères, ex: I10, K80)."""
clean = code.replace(".", "")
return clean[:3] if len(clean) >= 3 else code
def find_hard_negatives(
code: str,
code_to_entries: dict[str, list[dict]],
all_cim10: list[dict],
n: int = 3,
) -> list[dict]:
"""Trouve des hard negatives : codes proches mais différents.
Stratégie par priorité :
1. Même catégorie (ex: K80.0 vs K80.2) — le plus dur
2. Même chapitre (ex: K80 vs K81) — difficile
3. Random d'un autre chapitre — facile (pour contraste)
"""
category = get_category(code)
chapter = get_chapter(code)
negatives = []
# 1. Même catégorie (siblings)
siblings = []
for c, entries in code_to_entries.items():
if c != code and get_category(c) == category:
siblings.extend(entries)
if siblings:
random.shuffle(siblings)
negatives.extend(siblings[:n])
# 2. Même chapitre si pas assez
if len(negatives) < n:
chapter_entries = []
for c, entries in code_to_entries.items():
if c != code and get_chapter(c) == chapter and get_category(c) != category:
chapter_entries.extend(entries)
if chapter_entries:
random.shuffle(chapter_entries)
negatives.extend(chapter_entries[: n - len(negatives)])
# 3. Random si toujours pas assez
if len(negatives) < n:
others = [m for m in all_cim10 if get_category(m.get("code", "")) != category]
if others:
random.shuffle(others)
negatives.extend(others[: n - len(negatives)])
return negatives[:n]
def format_entry_text(entry: dict) -> str:
"""Formate une entrée FAISS en texte lisible pour l'embedding."""
code = entry.get("code", "")
extrait = entry.get("extrait", "").strip()
if code and extrait:
# S'assurer que le code est dans le texte
if not extrait.startswith(code):
return f"{code} {extrait}"
return extrait
def triplets_from_cache(
code_to_entries: dict[str, list[dict]],
all_cim10: list[dict],
) -> list[dict]:
"""Génère des triplets à partir du cache Ollama."""
with open(CACHE_PATH) as f:
cache = json.load(f)
entries = cache.get("entries", cache)
if isinstance(entries, str):
return []
triplets = []
skipped = {"no_code": 0, "no_match": 0, "no_neg": 0}
for key, val in entries.items():
if not isinstance(val, dict):
continue
code = val.get("code", "")
if not code:
skipped["no_code"] += 1
continue
# Extraire le texte du diagnostic depuis la clé
# Format: "das::hypertension artérielle" ou "dp::pancréatite aiguë"
parts = key.split("::", 1)
if len(parts) != 2:
continue
diag_type, diag_text = parts
if diag_type == "das_llm":
continue # Format différent, skip
# Normaliser le code
code_clean = code.strip().upper()
# Trouver le positive dans FAISS
positive_entries = code_to_entries.get(code_clean, [])
if not positive_entries:
# Essayer sans le point
code_no_dot = code_clean.replace(".", "")
for c, ents in code_to_entries.items():
if c.replace(".", "") == code_no_dot:
positive_entries = ents
break
if not positive_entries:
skipped["no_match"] += 1
continue
positive = random.choice(positive_entries)
positive_text = format_entry_text(positive)
# Hard negatives
negs = find_hard_negatives(code_clean, code_to_entries, all_cim10, n=2)
if not negs:
skipped["no_neg"] += 1
continue
for neg in negs:
neg_text = format_entry_text(neg)
if neg_text and positive_text and neg_text != positive_text:
triplets.append({
"anchor": diag_text,
"positive": positive_text,
"negative": neg_text,
"source": "cache_ollama",
"code_pos": code_clean,
"code_neg": neg.get("code", ""),
})
print(f"Cache Ollama → {len(triplets)} triplets "
f"(skip: {skipped['no_code']} sans code, {skipped['no_match']} pas dans FAISS, "
f"{skipped['no_neg']} sans négatif)")
return triplets
def triplets_from_cocoa(
code_to_entries: dict[str, list[dict]],
all_cim10: list[dict],
) -> list[dict]:
"""Génère des triplets à partir de CoCoA (synonymes + exclusions)."""
if not COCOA_PATH.exists():
print("CoCoA debug file non trouvé, skip")
return []
with open(COCOA_PATH) as f:
cocoa = json.load(f)
# Le fichier debug est un dict {code: entry}, pas une liste
if isinstance(cocoa, dict):
cocoa = list(cocoa.values())
triplets = []
for entry in cocoa:
code = entry.get("code", "").strip()
desc = entry.get("description", "").strip()
synonymes = entry.get("synonyms", [])
exclusions = entry.get("exclusions", [])
if not code or not desc:
continue
# Positive = entrée FAISS pour ce code
positive_entries = code_to_entries.get(code, [])
if not positive_entries:
code_no_dot = code.replace(".", "")
for c, ents in code_to_entries.items():
if c.replace(".", "") == code_no_dot:
positive_entries = ents
break
if not positive_entries:
continue
positive = random.choice(positive_entries)
positive_text = format_entry_text(positive)
# Synonymes comme anchors supplémentaires
anchors = [desc] + [s for s in synonymes if len(s) > 5]
# Exclusions comme hard negatives explicites
explicit_negs = []
for excl in exclusions:
# Essayer d'extraire un code de l'exclusion (ex: "Diabète de type 2 (E11)")
import re
code_match = re.search(r"\(([A-Z]\d{2}(?:\.\d{1,2})?)\)", excl)
if code_match:
neg_code = code_match.group(1)
neg_entries = code_to_entries.get(neg_code, [])
if neg_entries:
explicit_negs.append(random.choice(neg_entries))
# Hard negatives auto si pas assez d'explicites
auto_negs = find_hard_negatives(code, code_to_entries, all_cim10, n=2)
all_negs = explicit_negs + auto_negs
if not all_negs:
continue
for anchor in anchors[:3]: # Max 3 anchors par code
for neg in all_negs[:2]: # Max 2 négatifs par anchor
neg_text = format_entry_text(neg)
if neg_text and positive_text and neg_text != positive_text:
triplets.append({
"anchor": anchor,
"positive": positive_text,
"negative": neg_text,
"source": "cocoa",
"code_pos": code,
"code_neg": neg.get("code", ""),
})
print(f"CoCoA → {len(triplets)} triplets")
return triplets
def triplets_from_cim10_fhir(
code_to_entries: dict[str, list[dict]],
all_cim10: list[dict],
) -> list[dict]:
"""Génère des triplets à partir du CIM-10 FHIR ChatML (description → code)."""
if not CIM10_CHATML.exists():
print("CIM-10 ChatML non trouvé, skip")
return []
triplets = []
seen_codes = set()
with open(CIM10_CHATML) as f:
for line in f:
example = json.loads(line.strip())
messages = example.get("messages", [])
# Extraire la question (user) et la réponse (assistant)
user_msg = ""
assistant_msg = ""
for m in messages:
if m["role"] == "user":
user_msg = m["content"]
elif m["role"] == "assistant":
assistant_msg = m["content"]
if not user_msg or not assistant_msg:
continue
# Parser le code de la réponse
try:
resp = json.loads(assistant_msg)
code = resp.get("code", "")
except json.JSONDecodeError:
continue
if not code or code in seen_codes:
continue
seen_codes.add(code)
# Extraire le texte du diagnostic depuis la question
# Format: "Quel est le code CIM-10 pour : Description"
# ou "Codage CIM-10 du diagnostic : Description"
import re
match = re.search(r"(?:pour|diagnostic)\s*:\s*(.+)", user_msg)
if not match:
continue
anchor = match.group(1).strip()
if len(anchor) < 5:
continue
# Positive
positive_entries = code_to_entries.get(code, [])
if not positive_entries:
continue
positive = random.choice(positive_entries)
positive_text = format_entry_text(positive)
# Hard negative
negs = find_hard_negatives(code, code_to_entries, all_cim10, n=1)
if not negs:
continue
neg = negs[0]
neg_text = format_entry_text(neg)
if neg_text and positive_text and neg_text != positive_text:
triplets.append({
"anchor": anchor,
"positive": positive_text,
"negative": neg_text,
"source": "cim10_fhir",
"code_pos": code,
"code_neg": neg.get("code", ""),
})
print(f"CIM-10 FHIR → {len(triplets)} triplets")
return triplets
def deduplicate(triplets: list[dict]) -> list[dict]:
"""Déduplique par (anchor, positive, negative)."""
seen = set()
deduped = []
for t in triplets:
key = (t["anchor"][:50], t["code_pos"], t["code_neg"])
if key not in seen:
seen.add(key)
deduped.append(t)
return deduped
def analyze_triplets(triplets: list[dict]) -> None:
"""Statistiques sur les triplets générés."""
sources = defaultdict(int)
hard_neg_types = {"same_category": 0, "same_chapter": 0, "diff_chapter": 0}
for t in triplets:
sources[t["source"]] += 1
code_pos = t["code_pos"]
code_neg = t["code_neg"]
if get_category(code_pos) == get_category(code_neg):
hard_neg_types["same_category"] += 1
elif get_chapter(code_pos) == get_chapter(code_neg):
hard_neg_types["same_chapter"] += 1
else:
hard_neg_types["diff_chapter"] += 1
print(f"\n=== Statistiques ===")
print(f"Total triplets : {len(triplets)}")
print(f"Par source :")
for src, count in sorted(sources.items()):
print(f" {src}: {count}")
print(f"Difficulté des négatifs :")
for neg_type, count in hard_neg_types.items():
pct = 100 * count / len(triplets) if triplets else 0
print(f" {neg_type}: {count} ({pct:.1f}%)")
def main():
parser = argparse.ArgumentParser(description="Génération de triplets pour embedding fine-tuning")
parser.add_argument("--output", type=Path, default=DEFAULT_OUTPUT,
help="Fichier de sortie JSONL")
args = parser.parse_args()
print("=== Génération de triplets pour embedding fine-tuning ===\n")
# Charger FAISS metadata
code_to_entries, all_cim10 = load_faiss_metadata()
# Générer les triplets depuis chaque source
all_triplets = []
all_triplets.extend(triplets_from_cache(code_to_entries, all_cim10))
all_triplets.extend(triplets_from_cocoa(code_to_entries, all_cim10))
all_triplets.extend(triplets_from_cim10_fhir(code_to_entries, all_cim10))
# Dédupliquer
all_triplets = deduplicate(all_triplets)
random.shuffle(all_triplets)
# Statistiques
analyze_triplets(all_triplets)
# Sauvegarder
args.output.parent.mkdir(parents=True, exist_ok=True)
with open(args.output, "w") as f:
for t in all_triplets:
# Format simplifié pour l'entraînement : anchor, positive, negative
out = {
"anchor": t["anchor"],
"positive": t["positive"],
"negative": t["negative"],
}
f.write(json.dumps(out, ensure_ascii=False) + "\n")
size_mo = args.output.stat().st_size / 1024 / 1024
print(f"\nSauvegardé : {args.output} ({len(all_triplets)} triplets, {size_mo:.1f} Mo)")
# Aussi sauvegarder version détaillée pour debug
debug_path = args.output.with_suffix(".debug.jsonl")
with open(debug_path, "w") as f:
for t in all_triplets[:50]:
f.write(json.dumps(t, ensure_ascii=False, indent=2) + "\n---\n")
print(f"Debug (50 premiers) : {debug_path}")
if __name__ == "__main__":
main()