#!/usr/bin/env python3 """ Phase 3 — Génération de triplets (anchor, positive, negative) pour fine-tuning d'embedding. Sources : 1. Cache Ollama (1840 paires diagnostic → code CIM-10) 2. Métadonnées FAISS (12444 entrées CIM-10 avec extraits) 3. CoCoA (synonymes, exclusions) 4. Données CIM-10 FHIR (descriptions → codes) Format de sortie : JSONL avec {anchor, positive, negative} - anchor : texte clinique / diagnostic tel qu'écrit par le médecin - positive : extrait de référence CIM-10 correspondant au bon code - negative : extrait de référence CIM-10 d'un code similaire mais incorrect (hard negative) Usage : python scripts/09_build_embedding_triplets.py [--output data/datasets/triplets.jsonl] """ import argparse import json import random from collections import defaultdict from pathlib import Path random.seed(42) BASE = Path(__file__).resolve().parent.parent T2A = BASE.parent / "t2a" # Sources CACHE_PATH = T2A / "data" / "ollama_cache.json" METADATA_PATH = T2A / "data" / "rag_index" / "metadata.json" COCOA_PATH = BASE / "data" / "processed" / "cocoa_entries_debug.json" CIM10_CHATML = BASE / "data" / "processed" / "cim10_chatml.jsonl" # Output DEFAULT_OUTPUT = BASE / "data" / "datasets" / "triplets.jsonl" def load_faiss_metadata() -> tuple[dict[str, list[dict]], list[dict]]: """Charge les métadonnées FAISS et indexe par code CIM-10. Returns: (code_to_entries, all_cim10_entries) """ with open(METADATA_PATH) as f: metadata = json.load(f) code_to_entries = defaultdict(list) all_cim10 = [] for m in metadata: if m.get("document") not in ("cim10", "cim10_alpha"): continue code = m.get("code", "").strip() if not code: continue all_cim10.append(m) # Normaliser le code (avec et sans point) code_to_entries[code].append(m) # Aussi indexer par préfixe 3 caractères pour les hard negatives # Ex: I10 → chapitre I, catégorie I10 print(f"FAISS metadata : {len(all_cim10)} entrées CIM-10, {len(code_to_entries)} codes uniques") return code_to_entries, all_cim10 def get_chapter(code: str) -> str: """Extrait le chapitre CIM-10 (première lettre).""" return code[0] if code else "" def get_category(code: str) -> str: """Extrait la catégorie CIM-10 (3 premiers caractères, ex: I10, K80).""" clean = code.replace(".", "") return clean[:3] if len(clean) >= 3 else code def find_hard_negatives( code: str, code_to_entries: dict[str, list[dict]], all_cim10: list[dict], n: int = 3, ) -> list[dict]: """Trouve des hard negatives : codes proches mais différents. Stratégie par priorité : 1. Même catégorie (ex: K80.0 vs K80.2) — le plus dur 2. Même chapitre (ex: K80 vs K81) — difficile 3. Random d'un autre chapitre — facile (pour contraste) """ category = get_category(code) chapter = get_chapter(code) negatives = [] # 1. Même catégorie (siblings) siblings = [] for c, entries in code_to_entries.items(): if c != code and get_category(c) == category: siblings.extend(entries) if siblings: random.shuffle(siblings) negatives.extend(siblings[:n]) # 2. Même chapitre si pas assez if len(negatives) < n: chapter_entries = [] for c, entries in code_to_entries.items(): if c != code and get_chapter(c) == chapter and get_category(c) != category: chapter_entries.extend(entries) if chapter_entries: random.shuffle(chapter_entries) negatives.extend(chapter_entries[: n - len(negatives)]) # 3. Random si toujours pas assez if len(negatives) < n: others = [m for m in all_cim10 if get_category(m.get("code", "")) != category] if others: random.shuffle(others) negatives.extend(others[: n - len(negatives)]) return negatives[:n] def format_entry_text(entry: dict) -> str: """Formate une entrée FAISS en texte lisible pour l'embedding.""" code = entry.get("code", "") extrait = entry.get("extrait", "").strip() if code and extrait: # S'assurer que le code est dans le texte if not extrait.startswith(code): return f"{code} {extrait}" return extrait def triplets_from_cache( code_to_entries: dict[str, list[dict]], all_cim10: list[dict], ) -> list[dict]: """Génère des triplets à partir du cache Ollama.""" with open(CACHE_PATH) as f: cache = json.load(f) entries = cache.get("entries", cache) if isinstance(entries, str): return [] triplets = [] skipped = {"no_code": 0, "no_match": 0, "no_neg": 0} for key, val in entries.items(): if not isinstance(val, dict): continue code = val.get("code", "") if not code: skipped["no_code"] += 1 continue # Extraire le texte du diagnostic depuis la clé # Format: "das::hypertension artérielle" ou "dp::pancréatite aiguë" parts = key.split("::", 1) if len(parts) != 2: continue diag_type, diag_text = parts if diag_type == "das_llm": continue # Format différent, skip # Normaliser le code code_clean = code.strip().upper() # Trouver le positive dans FAISS positive_entries = code_to_entries.get(code_clean, []) if not positive_entries: # Essayer sans le point code_no_dot = code_clean.replace(".", "") for c, ents in code_to_entries.items(): if c.replace(".", "") == code_no_dot: positive_entries = ents break if not positive_entries: skipped["no_match"] += 1 continue positive = random.choice(positive_entries) positive_text = format_entry_text(positive) # Hard negatives negs = find_hard_negatives(code_clean, code_to_entries, all_cim10, n=2) if not negs: skipped["no_neg"] += 1 continue for neg in negs: neg_text = format_entry_text(neg) if neg_text and positive_text and neg_text != positive_text: triplets.append({ "anchor": diag_text, "positive": positive_text, "negative": neg_text, "source": "cache_ollama", "code_pos": code_clean, "code_neg": neg.get("code", ""), }) print(f"Cache Ollama → {len(triplets)} triplets " f"(skip: {skipped['no_code']} sans code, {skipped['no_match']} pas dans FAISS, " f"{skipped['no_neg']} sans négatif)") return triplets def triplets_from_cocoa( code_to_entries: dict[str, list[dict]], all_cim10: list[dict], ) -> list[dict]: """Génère des triplets à partir de CoCoA (synonymes + exclusions).""" if not COCOA_PATH.exists(): print("CoCoA debug file non trouvé, skip") return [] with open(COCOA_PATH) as f: cocoa = json.load(f) # Le fichier debug est un dict {code: entry}, pas une liste if isinstance(cocoa, dict): cocoa = list(cocoa.values()) triplets = [] for entry in cocoa: code = entry.get("code", "").strip() desc = entry.get("description", "").strip() synonymes = entry.get("synonyms", []) exclusions = entry.get("exclusions", []) if not code or not desc: continue # Positive = entrée FAISS pour ce code positive_entries = code_to_entries.get(code, []) if not positive_entries: code_no_dot = code.replace(".", "") for c, ents in code_to_entries.items(): if c.replace(".", "") == code_no_dot: positive_entries = ents break if not positive_entries: continue positive = random.choice(positive_entries) positive_text = format_entry_text(positive) # Synonymes comme anchors supplémentaires anchors = [desc] + [s for s in synonymes if len(s) > 5] # Exclusions comme hard negatives explicites explicit_negs = [] for excl in exclusions: # Essayer d'extraire un code de l'exclusion (ex: "Diabète de type 2 (E11)") import re code_match = re.search(r"\(([A-Z]\d{2}(?:\.\d{1,2})?)\)", excl) if code_match: neg_code = code_match.group(1) neg_entries = code_to_entries.get(neg_code, []) if neg_entries: explicit_negs.append(random.choice(neg_entries)) # Hard negatives auto si pas assez d'explicites auto_negs = find_hard_negatives(code, code_to_entries, all_cim10, n=2) all_negs = explicit_negs + auto_negs if not all_negs: continue for anchor in anchors[:3]: # Max 3 anchors par code for neg in all_negs[:2]: # Max 2 négatifs par anchor neg_text = format_entry_text(neg) if neg_text and positive_text and neg_text != positive_text: triplets.append({ "anchor": anchor, "positive": positive_text, "negative": neg_text, "source": "cocoa", "code_pos": code, "code_neg": neg.get("code", ""), }) print(f"CoCoA → {len(triplets)} triplets") return triplets def triplets_from_cim10_fhir( code_to_entries: dict[str, list[dict]], all_cim10: list[dict], ) -> list[dict]: """Génère des triplets à partir du CIM-10 FHIR ChatML (description → code).""" if not CIM10_CHATML.exists(): print("CIM-10 ChatML non trouvé, skip") return [] triplets = [] seen_codes = set() with open(CIM10_CHATML) as f: for line in f: example = json.loads(line.strip()) messages = example.get("messages", []) # Extraire la question (user) et la réponse (assistant) user_msg = "" assistant_msg = "" for m in messages: if m["role"] == "user": user_msg = m["content"] elif m["role"] == "assistant": assistant_msg = m["content"] if not user_msg or not assistant_msg: continue # Parser le code de la réponse try: resp = json.loads(assistant_msg) code = resp.get("code", "") except json.JSONDecodeError: continue if not code or code in seen_codes: continue seen_codes.add(code) # Extraire le texte du diagnostic depuis la question # Format: "Quel est le code CIM-10 pour : Description" # ou "Codage CIM-10 du diagnostic : Description" import re match = re.search(r"(?:pour|diagnostic)\s*:\s*(.+)", user_msg) if not match: continue anchor = match.group(1).strip() if len(anchor) < 5: continue # Positive positive_entries = code_to_entries.get(code, []) if not positive_entries: continue positive = random.choice(positive_entries) positive_text = format_entry_text(positive) # Hard negative negs = find_hard_negatives(code, code_to_entries, all_cim10, n=1) if not negs: continue neg = negs[0] neg_text = format_entry_text(neg) if neg_text and positive_text and neg_text != positive_text: triplets.append({ "anchor": anchor, "positive": positive_text, "negative": neg_text, "source": "cim10_fhir", "code_pos": code, "code_neg": neg.get("code", ""), }) print(f"CIM-10 FHIR → {len(triplets)} triplets") return triplets def deduplicate(triplets: list[dict]) -> list[dict]: """Déduplique par (anchor, positive, negative).""" seen = set() deduped = [] for t in triplets: key = (t["anchor"][:50], t["code_pos"], t["code_neg"]) if key not in seen: seen.add(key) deduped.append(t) return deduped def analyze_triplets(triplets: list[dict]) -> None: """Statistiques sur les triplets générés.""" sources = defaultdict(int) hard_neg_types = {"same_category": 0, "same_chapter": 0, "diff_chapter": 0} for t in triplets: sources[t["source"]] += 1 code_pos = t["code_pos"] code_neg = t["code_neg"] if get_category(code_pos) == get_category(code_neg): hard_neg_types["same_category"] += 1 elif get_chapter(code_pos) == get_chapter(code_neg): hard_neg_types["same_chapter"] += 1 else: hard_neg_types["diff_chapter"] += 1 print(f"\n=== Statistiques ===") print(f"Total triplets : {len(triplets)}") print(f"Par source :") for src, count in sorted(sources.items()): print(f" {src}: {count}") print(f"Difficulté des négatifs :") for neg_type, count in hard_neg_types.items(): pct = 100 * count / len(triplets) if triplets else 0 print(f" {neg_type}: {count} ({pct:.1f}%)") def main(): parser = argparse.ArgumentParser(description="Génération de triplets pour embedding fine-tuning") parser.add_argument("--output", type=Path, default=DEFAULT_OUTPUT, help="Fichier de sortie JSONL") args = parser.parse_args() print("=== Génération de triplets pour embedding fine-tuning ===\n") # Charger FAISS metadata code_to_entries, all_cim10 = load_faiss_metadata() # Générer les triplets depuis chaque source all_triplets = [] all_triplets.extend(triplets_from_cache(code_to_entries, all_cim10)) all_triplets.extend(triplets_from_cocoa(code_to_entries, all_cim10)) all_triplets.extend(triplets_from_cim10_fhir(code_to_entries, all_cim10)) # Dédupliquer all_triplets = deduplicate(all_triplets) random.shuffle(all_triplets) # Statistiques analyze_triplets(all_triplets) # Sauvegarder args.output.parent.mkdir(parents=True, exist_ok=True) with open(args.output, "w") as f: for t in all_triplets: # Format simplifié pour l'entraînement : anchor, positive, negative out = { "anchor": t["anchor"], "positive": t["positive"], "negative": t["negative"], } f.write(json.dumps(out, ensure_ascii=False) + "\n") size_mo = args.output.stat().st_size / 1024 / 1024 print(f"\nSauvegardé : {args.output} ({len(all_triplets)} triplets, {size_mo:.1f} Mo)") # Aussi sauvegarder version détaillée pour debug debug_path = args.output.with_suffix(".debug.jsonl") with open(debug_path, "w") as f: for t in all_triplets[:50]: f.write(json.dumps(t, ensure_ascii=False, indent=2) + "\n---\n") print(f"Debug (50 premiers) : {debug_path}") if __name__ == "__main__": main()