feat: rééquilibrage dataset LoRA — raisonnement DIM vs mémorisation
Passe de 95/3/2 (lookups/raisonnement/règles) à ~31/49/20. Dataset cible ~16K exemples denses (vs 66K de lookups avant). Modifiés : - 03_convert_cache.py : cache complet 1840 entrées (actuel + backup) - 04_build_dataset.py : subsampling agressif (CIM-10 1.5K, CCAM 1.5K, CoCoA 2K) + sélection intelligente priorisant le raisonnement - 12_generate_pipeline_examples.py : 3 templates (court + long + CPAM), cache actuel, cible ~2800 exemples Créés : - 13_generate_fascicule_reasoning.py : parsing 10 fascicules ATIH, génération Q&A raisonnement via Claude Opus 4.6 (~450 exemples) - 14_generate_negative_examples.py : 1000 exemples négatifs (symptômes/DP, redondances sémantiques, DAS non significatifs) - 15_generate_discrimination.py : 800 exercices de discrimination entre codes siblings CIM-10 via Claude Opus 4.6 - 16_parse_guide_metho.py : extraction Guide Méthodologique MCO 2026, Q&A directes + raisonnement via Claude Opus 4.6 (~500 exemples) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
460
scripts/09_build_embedding_triplets.py
Normal file
460
scripts/09_build_embedding_triplets.py
Normal file
@@ -0,0 +1,460 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Phase 3 — Génération de triplets (anchor, positive, negative) pour fine-tuning d'embedding.
|
||||
|
||||
Sources :
|
||||
1. Cache Ollama (1840 paires diagnostic → code CIM-10)
|
||||
2. Métadonnées FAISS (12444 entrées CIM-10 avec extraits)
|
||||
3. CoCoA (synonymes, exclusions)
|
||||
4. Données CIM-10 FHIR (descriptions → codes)
|
||||
|
||||
Format de sortie : JSONL avec {anchor, positive, negative}
|
||||
- anchor : texte clinique / diagnostic tel qu'écrit par le médecin
|
||||
- positive : extrait de référence CIM-10 correspondant au bon code
|
||||
- negative : extrait de référence CIM-10 d'un code similaire mais incorrect (hard negative)
|
||||
|
||||
Usage :
|
||||
python scripts/09_build_embedding_triplets.py [--output data/datasets/triplets.jsonl]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
# Deterministic sampling/shuffling so repeated runs produce the same dataset.
random.seed(42)

# Repository root (this script lives in <repo>/scripts/).
BASE = Path(__file__).resolve().parent.parent
# Sibling "t2a" project, expected to sit next to this repository.
T2A = BASE.parent / "t2a"

# Sources
CACHE_PATH = T2A / "data" / "ollama_cache.json"  # diagnosis text → CIM-10 code pairs
METADATA_PATH = T2A / "data" / "rag_index" / "metadata.json"  # FAISS index metadata
COCOA_PATH = BASE / "data" / "processed" / "cocoa_entries_debug.json"  # CoCoA synonyms/exclusions
CIM10_CHATML = BASE / "data" / "processed" / "cim10_chatml.jsonl"  # CIM-10 FHIR ChatML examples

# Output
DEFAULT_OUTPUT = BASE / "data" / "datasets" / "triplets.jsonl"
|
||||
|
||||
|
||||
def load_faiss_metadata() -> tuple[dict[str, list[dict]], list[dict]]:
    """Load the FAISS metadata file and index CIM-10 entries by code.

    Returns:
        (code_to_entries, all_cim10_entries)
    """
    with open(METADATA_PATH) as fh:
        raw_entries = json.load(fh)

    by_code: dict[str, list[dict]] = defaultdict(list)
    cim10_entries: list[dict] = []

    for entry in raw_entries:
        # Keep only CIM-10 entries that actually carry a code.
        is_cim10 = entry.get("document") in ("cim10", "cim10_alpha")
        code = entry.get("code", "").strip()
        if not (is_cim10 and code):
            continue
        cim10_entries.append(entry)
        by_code[code].append(entry)

    print(f"FAISS metadata : {len(cim10_entries)} entrées CIM-10, {len(by_code)} codes uniques")
    return by_code, cim10_entries
|
||||
|
||||
|
||||
def get_chapter(code: str) -> str:
    """Return the CIM-10 chapter letter (first character), or "" for an empty code."""
    if not code:
        return ""
    return code[0]
|
||||
|
||||
|
||||
def get_category(code: str) -> str:
    """Return the CIM-10 category (first 3 characters without the dot, e.g. I10, K80)."""
    without_dot = code.replace(".", "")
    if len(without_dot) < 3:
        # Too short to carry a full category: fall back to the raw code.
        return code
    return without_dot[:3]
|
||||
|
||||
|
||||
def find_hard_negatives(
    code: str,
    code_to_entries: dict[str, list[dict]],
    all_cim10: list[dict],
    n: int = 3,
) -> list[dict]:
    """Pick hard negatives: codes close to *code* but distinct.

    Priority tiers:
    1. Same category (e.g. K80.0 vs K80.2) — hardest
    2. Same chapter (e.g. K80 vs K81) — hard
    3. Random entry from another chapter — easy (for contrast)
    """
    category = get_category(code)
    chapter = get_chapter(code)
    picked: list[dict] = []

    # Tier 1: siblings within the same category.
    same_category = [
        entry
        for c, entries in code_to_entries.items()
        if c != code and get_category(c) == category
        for entry in entries
    ]
    if same_category:
        random.shuffle(same_category)
        picked.extend(same_category[:n])

    # Tier 2: same chapter but different category, to fill the quota.
    if len(picked) < n:
        same_chapter = [
            entry
            for c, entries in code_to_entries.items()
            if c != code and get_chapter(c) == chapter and get_category(c) != category
            for entry in entries
        ]
        if same_chapter:
            random.shuffle(same_chapter)
            picked.extend(same_chapter[: n - len(picked)])

    # Tier 3: anything outside the category, as a last resort.
    if len(picked) < n:
        fallback = [m for m in all_cim10 if get_category(m.get("code", "")) != category]
        if fallback:
            random.shuffle(fallback)
            picked.extend(fallback[: n - len(picked)])

    return picked[:n]
|
||||
|
||||
|
||||
def format_entry_text(entry: dict) -> str:
    """Render a FAISS entry as embedding-ready text, code-prefixed when needed."""
    code = entry.get("code", "")
    excerpt = entry.get("extrait", "").strip()
    # Prepend the code unless the excerpt already starts with it.
    needs_prefix = bool(code and excerpt) and not excerpt.startswith(code)
    return f"{code} {excerpt}" if needs_prefix else excerpt
|
||||
|
||||
|
||||
def triplets_from_cache(
    code_to_entries: dict[str, list[dict]],
    all_cim10: list[dict],
) -> list[dict]:
    """Build triplets from the Ollama cache (diagnosis text → CIM-10 code)."""
    with open(CACHE_PATH) as fh:
        cache = json.load(fh)

    cache_entries = cache.get("entries", cache)
    if isinstance(cache_entries, str):
        return []

    def lookup_entries(clean_code: str) -> list[dict]:
        """FAISS entries for a code, tolerating a missing dot in either side."""
        found = code_to_entries.get(clean_code, [])
        if found:
            return found
        undotted = clean_code.replace(".", "")
        for candidate, ents in code_to_entries.items():
            if candidate.replace(".", "") == undotted:
                return ents
        return []

    triplets: list[dict] = []
    skipped = {"no_code": 0, "no_match": 0, "no_neg": 0}

    for key, val in cache_entries.items():
        if not isinstance(val, dict):
            continue

        code = val.get("code", "")
        if not code:
            skipped["no_code"] += 1
            continue

        # The cache key embeds the diagnosis text,
        # e.g. "das::hypertension artérielle" or "dp::pancréatite aiguë".
        parts = key.split("::", 1)
        if len(parts) != 2:
            continue
        diag_type, diag_text = parts
        if diag_type == "das_llm":
            continue  # Different key format, skip

        code_clean = code.strip().upper()

        positive_entries = lookup_entries(code_clean)
        if not positive_entries:
            skipped["no_match"] += 1
            continue

        positive_text = format_entry_text(random.choice(positive_entries))

        negatives = find_hard_negatives(code_clean, code_to_entries, all_cim10, n=2)
        if not negatives:
            skipped["no_neg"] += 1
            continue

        for neg in negatives:
            neg_text = format_entry_text(neg)
            if neg_text and positive_text and neg_text != positive_text:
                triplets.append({
                    "anchor": diag_text,
                    "positive": positive_text,
                    "negative": neg_text,
                    "source": "cache_ollama",
                    "code_pos": code_clean,
                    "code_neg": neg.get("code", ""),
                })

    print(f"Cache Ollama → {len(triplets)} triplets "
          f"(skip: {skipped['no_code']} sans code, {skipped['no_match']} pas dans FAISS, "
          f"{skipped['no_neg']} sans négatif)")
    return triplets
|
||||
|
||||
|
||||
def triplets_from_cocoa(
    code_to_entries: dict[str, list[dict]],
    all_cim10: list[dict],
) -> list[dict]:
    """Build triplets from CoCoA entries (synonyms as anchors, exclusions as negatives).

    Args:
        code_to_entries: FAISS entries indexed by CIM-10 code.
        all_cim10: All CIM-10 FAISS entries (fallback negative pool).

    Returns:
        Triplet dicts with anchor/positive/negative plus provenance fields.
    """
    import re  # fix: was re-imported inside the per-exclusion loop

    if not COCOA_PATH.exists():
        print("CoCoA debug file non trouvé, skip")
        return []

    with open(COCOA_PATH) as f:
        cocoa = json.load(f)

    # The debug file is a dict {code: entry}, not a list.
    if isinstance(cocoa, dict):
        cocoa = list(cocoa.values())

    # Pattern for a CIM-10 code embedded in an exclusion label,
    # e.g. "Diabète de type 2 (E11)". Compiled once instead of per exclusion.
    excl_code_re = re.compile(r"\(([A-Z]\d{2}(?:\.\d{1,2})?)\)")

    triplets = []

    for entry in cocoa:
        code = entry.get("code", "").strip()
        desc = entry.get("description", "").strip()
        synonymes = entry.get("synonyms", [])
        exclusions = entry.get("exclusions", [])

        if not code or not desc:
            continue

        # Positive = FAISS entry for this code (dot-insensitive lookup).
        positive_entries = code_to_entries.get(code, [])
        if not positive_entries:
            code_no_dot = code.replace(".", "")
            for c, ents in code_to_entries.items():
                if c.replace(".", "") == code_no_dot:
                    positive_entries = ents
                    break
        if not positive_entries:
            continue

        positive = random.choice(positive_entries)
        positive_text = format_entry_text(positive)

        # Description + long-enough synonyms serve as alternative anchors.
        anchors = [desc] + [s for s in synonymes if len(s) > 5]

        # Exclusions give explicit hard negatives when they cite a code.
        explicit_negs = []
        for excl in exclusions:
            code_match = excl_code_re.search(excl)
            if code_match:
                neg_code = code_match.group(1)
                neg_entries = code_to_entries.get(neg_code, [])
                if neg_entries:
                    explicit_negs.append(random.choice(neg_entries))

        # Automatically mined hard negatives complement the explicit ones.
        auto_negs = find_hard_negatives(code, code_to_entries, all_cim10, n=2)
        all_negs = explicit_negs + auto_negs

        if not all_negs:
            continue

        for anchor in anchors[:3]:  # Max 3 anchors per code
            for neg in all_negs[:2]:  # Max 2 negatives per anchor
                neg_text = format_entry_text(neg)
                if neg_text and positive_text and neg_text != positive_text:
                    triplets.append({
                        "anchor": anchor,
                        "positive": positive_text,
                        "negative": neg_text,
                        "source": "cocoa",
                        "code_pos": code,
                        "code_neg": neg.get("code", ""),
                    })

    print(f"CoCoA → {len(triplets)} triplets")
    return triplets
|
||||
|
||||
|
||||
def triplets_from_cim10_fhir(
    code_to_entries: dict[str, list[dict]],
    all_cim10: list[dict],
) -> list[dict]:
    """Build triplets from the CIM-10 FHIR ChatML file (description → code).

    Args:
        code_to_entries: FAISS entries indexed by CIM-10 code.
        all_cim10: All CIM-10 FAISS entries (fallback negative pool).

    Returns:
        One triplet per previously unseen code whose question text parses.
    """
    import re  # fix: was re-imported on every line of the input file

    if not CIM10_CHATML.exists():
        print("CIM-10 ChatML non trouvé, skip")
        return []

    # The question embeds the diagnosis text, e.g.
    # "Quel est le code CIM-10 pour : Description" or
    # "Codage CIM-10 du diagnostic : Description". Compiled once.
    question_re = re.compile(r"(?:pour|diagnostic)\s*:\s*(.+)")

    triplets = []
    seen_codes = set()  # keep only the first example per code

    with open(CIM10_CHATML) as f:
        for line in f:
            example = json.loads(line.strip())
            messages = example.get("messages", [])

            # Pull the question (user) and the answer (assistant).
            user_msg = ""
            assistant_msg = ""
            for m in messages:
                if m["role"] == "user":
                    user_msg = m["content"]
                elif m["role"] == "assistant":
                    assistant_msg = m["content"]

            if not user_msg or not assistant_msg:
                continue

            # The assistant answer is JSON carrying the code.
            try:
                resp = json.loads(assistant_msg)
                code = resp.get("code", "")
            except json.JSONDecodeError:
                continue

            if not code or code in seen_codes:
                continue
            seen_codes.add(code)

            # Extract the diagnosis text from the question.
            match = question_re.search(user_msg)
            if not match:
                continue
            anchor = match.group(1).strip()
            if len(anchor) < 5:
                continue

            # Positive
            positive_entries = code_to_entries.get(code, [])
            if not positive_entries:
                continue
            positive = random.choice(positive_entries)
            positive_text = format_entry_text(positive)

            # Hard negative
            negs = find_hard_negatives(code, code_to_entries, all_cim10, n=1)
            if not negs:
                continue

            neg = negs[0]
            neg_text = format_entry_text(neg)
            if neg_text and positive_text and neg_text != positive_text:
                triplets.append({
                    "anchor": anchor,
                    "positive": positive_text,
                    "negative": neg_text,
                    "source": "cim10_fhir",
                    "code_pos": code,
                    "code_neg": neg.get("code", ""),
                })

    print(f"CIM-10 FHIR → {len(triplets)} triplets")
    return triplets
|
||||
|
||||
|
||||
def deduplicate(triplets: list[dict]) -> list[dict]:
    """Drop triplets repeating an (anchor prefix, code_pos, code_neg) combination."""
    unique: list[dict] = []
    seen: set[tuple[str, str, str]] = set()
    for triplet in triplets:
        # Anchors are compared on their first 50 chars to fold near-duplicates.
        fingerprint = (triplet["anchor"][:50], triplet["code_pos"], triplet["code_neg"])
        if fingerprint in seen:
            continue
        seen.add(fingerprint)
        unique.append(triplet)
    return unique
|
||||
|
||||
|
||||
def analyze_triplets(triplets: list[dict]) -> None:
    """Print summary statistics: per-source counts and negative-difficulty mix."""
    per_source: dict[str, int] = defaultdict(int)
    difficulty = {"same_category": 0, "same_chapter": 0, "diff_chapter": 0}

    for triplet in triplets:
        per_source[triplet["source"]] += 1
        pos, neg = triplet["code_pos"], triplet["code_neg"]
        # Classify how "hard" the negative is relative to the positive.
        if get_category(pos) == get_category(neg):
            bucket = "same_category"
        elif get_chapter(pos) == get_chapter(neg):
            bucket = "same_chapter"
        else:
            bucket = "diff_chapter"
        difficulty[bucket] += 1

    total = len(triplets)
    print("\n=== Statistiques ===")
    print(f"Total triplets : {total}")
    print("Par source :")
    for src, count in sorted(per_source.items()):
        print(f"  {src}: {count}")
    print("Difficulté des négatifs :")
    for neg_type, count in difficulty.items():
        pct = 100 * count / total if total else 0
        print(f"  {neg_type}: {count} ({pct:.1f}%)")
|
||||
|
||||
|
||||
def main():
    """CLI entry point: build, deduplicate, analyze and save embedding triplets."""
    parser = argparse.ArgumentParser(description="Génération de triplets pour embedding fine-tuning")
    parser.add_argument("--output", type=Path, default=DEFAULT_OUTPUT,
                        help="Fichier de sortie JSONL")
    args = parser.parse_args()

    print("=== Génération de triplets pour embedding fine-tuning ===\n")

    # FAISS metadata is shared by every triplet generator.
    code_to_entries, all_cim10 = load_faiss_metadata()

    # Collect triplets from each source in turn, then clean up the pool.
    triplets = []
    for generate in (triplets_from_cache, triplets_from_cocoa, triplets_from_cim10_fhir):
        triplets.extend(generate(code_to_entries, all_cim10))

    triplets = deduplicate(triplets)
    random.shuffle(triplets)

    analyze_triplets(triplets)

    # Write the training JSONL, keeping only anchor/positive/negative.
    args.output.parent.mkdir(parents=True, exist_ok=True)
    with open(args.output, "w") as f:
        for triplet in triplets:
            minimal = {
                "anchor": triplet["anchor"],
                "positive": triplet["positive"],
                "negative": triplet["negative"],
            }
            f.write(json.dumps(minimal, ensure_ascii=False) + "\n")

    size_mo = args.output.stat().st_size / 1024 / 1024
    print(f"\nSauvegardé : {args.output} ({len(triplets)} triplets, {size_mo:.1f} Mo)")

    # Also dump a detailed sample (with provenance fields) for debugging.
    debug_path = args.output.with_suffix(".debug.jsonl")
    with open(debug_path, "w") as f:
        for triplet in triplets[:50]:
            f.write(json.dumps(triplet, ensure_ascii=False, indent=2) + "\n---\n")
    print(f"Debug (50 premiers) : {debug_path}")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user