diff --git a/src/control/cpam_response.py b/src/control/cpam_response.py
index 183dbed..f95b1e6 100644
--- a/src/control/cpam_response.py
+++ b/src/control/cpam_response.py
@@ -4,6 +4,14 @@ Orchestrateur principal — délègue aux sous-modules :
   - cpam_rag : _search_rag_for_control(), _search_rag_queries()
   - cpam_context : _build_cpam_prompt(), _build_tagged_context(), _build_bio_summary(), etc.
   - cpam_validation : _validate_adversarial(), _validate_grounding(), _format_response(), etc.
+
+The batched flow (generate_cpam_responses_batched) groups LLM calls by model
+to limit VRAM swaps on a single GPU:
+  Phase 1 (cpam model)       : extraction + argumentation of ALL controls
+  Phase 2 (validation model) : adversarial validation of all results (1 swap)
+  Phase 3 (both models)      : correction + re-validation of failed drafts
+  Finalisation               : quality tier + formatting (no model swap)
+Phases 1-2 cost one swap in total; phase 3 alternates models per corrected draft.
 """
 
 from __future__ import annotations
@@ -12,6 +20,7 @@ import json
 import logging
 import os
 import time
+from dataclasses import dataclass, field
 from datetime import datetime
 from pathlib import Path
 
@@ -42,6 +51,34 @@ from .cpam_validation import (
 logger = logging.getLogger(__name__)
 
 
+# ---------------------------------------------------------------------------
+# Intermediate structure for the batched flow
+# ---------------------------------------------------------------------------
+
+@dataclass
+class _CpamDraft:
+    """Intermediate state of one CPAM control while it moves through the batch."""
+    controle: ControleCPAM
+    # Outputs of the successive phases
+    extraction: dict | None = None
+    degraded_pass1: bool = False
+    sources: list[dict] = field(default_factory=list)
+    rag_sources: list[RAGSource] = field(default_factory=list)
+    prompt: str = ""
+    tag_map: dict[str, str] = field(default_factory=dict)
+    result: dict | None = None
+    # Validation
+    ref_warnings: list[str] = field(default_factory=list)
+    grounding_warnings: list[str] = field(default_factory=list)
+    code_warnings: list[str] = field(default_factory=list)
+    adversarial_warnings: list[str] = field(default_factory=list)
+    validation: dict | None = None
+    # Final outcome (written by _phase_finalize)
+    needs_correction: bool = False
+    text: str = ""
+    response_data: dict | None = None
+
+
 def _save_version(
     dossier: DossierMedical,
     controle: ControleCPAM,
@@ -358,3 +395,398 @@ def generate_cpam_response(
     logger.info(" Contre-argumentation générée (%d caractères)", len(text))
 
     return text, result, rag_sources
+
+
+# ---------------------------------------------------------------------------
+# Phase functions (used by the batched flow)
+# ---------------------------------------------------------------------------
+
+# Quality flags set by this module that describe the pipeline run rather than
+# one LLM answer; they must survive the replacement of draft.result.
+_STICKY_QUALITY_FLAGS = (
+    "cpam_pass1_failed",
+    "degraded_mode",
+    "adversarial_disabled_same_model",
+)
+
+
+def _max_corrections() -> int:
+    """Return the correction-attempt budget read from T2A_CPAM_MAX_CORRECTIONS."""
+    return int(os.environ.get("T2A_CPAM_MAX_CORRECTIONS", "2"))
+
+
+def _should_correct(validation: dict | None) -> bool:
+    """Tell whether an adversarial verdict calls for a correction attempt.
+
+    True only for an incoherent verdict scored 5 or lower while the
+    RULE-CPAM-CORRECTION-LOOP rule is enabled.
+    """
+    return bool(
+        validation
+        and not validation.get("coherent", True)
+        and validation.get("score_confiance", 10) <= 5
+        and rule_enabled("RULE-CPAM-CORRECTION-LOOP")
+    )
+
+
+def _adversarial_warnings(validation: dict | None) -> list[str]:
+    """Turn an incoherent adversarial verdict into warning lines.
+
+    Returns an empty list when the verdict is missing, coherent, or carries no
+    non-blank string error; otherwise one line per error plus the score.
+    """
+    if not validation or validation.get("coherent", True):
+        return []
+    warnings = [
+        f"Incohérence détectée : {e}"
+        for e in validation.get("erreurs", [])
+        if isinstance(e, str) and e.strip()
+    ]
+    if warnings:
+        score = validation.get("score_confiance", "?")
+        warnings.append(f"Score de confiance : {score}/10")
+    return warnings
+
+
+def _carry_quality_flags(previous: dict, corrected: dict) -> None:
+    """Copy the sticky quality flags of a replaced answer onto its successor."""
+    old_flags = previous.get("quality_flags") or {}
+    kept = {k: old_flags[k] for k in _STICKY_QUALITY_FLAGS if k in old_flags}
+    if kept:
+        corrected.setdefault("quality_flags", {}).update(kept)
+
+
+def _phase_generate(
+    dossier: DossierMedical,
+    draft: _CpamDraft,
+) -> None:
+    """Phase 1 — extraction + RAG + argumentation (role=cpam).
+
+    Fills draft.extraction, draft.sources, draft.rag_sources, draft.prompt,
+    draft.tag_map, the deterministic warning lists and draft.result (None
+    when no LLM answered). Makes no call to the validation model.
+    """
+    controle = draft.controle
+    logger.info("CPAM batch : génération pour OGC %s — %s",
+                controle.numero_ogc, controle.titre)
+
+    # 0. Versioning
+    _save_version(dossier, controle)
+
+    # 1. Structured extraction (role=cpam)
+    draft.extraction = _extraction_pass(dossier, controle)
+    draft.degraded_pass1 = draft.extraction is None
+    if draft.degraded_pass1:
+        dossier.alertes_codage.append(
+            "CPAM: passe 1 (extraction structurée) échouée → mode dégradé"
+        )
+
+    # 2. RAG search (no LLM)
+    draft.sources = _search_rag_for_control(controle, dossier)
+    logger.info(" RAG : %d sources trouvées", len(draft.sources))
+
+    # 3. Prompt construction
+    draft.prompt, draft.tag_map = _build_cpam_prompt(
+        dossier, controle, draft.sources, draft.extraction
+    )
+
+    # 4. Argumentation (role=cpam), Ollama first then Anthropic as fallback
+    t_gen = time.time()
+    result = call_ollama(draft.prompt, temperature=0.1, max_tokens=8000, role="cpam")
+    if result is not None:
+        logger.info(" [CPAM-GEN] %.1fs — OGC %s — contre-argumentation via Ollama",
+                    time.time() - t_gen, controle.numero_ogc)
+    else:
+        logger.info(" Ollama indisponible → fallback Anthropic Haiku")
+        result = call_anthropic(draft.prompt, temperature=0.1, max_tokens=8000)
+        if result is not None:
+            logger.info(" [CPAM-GEN] %.1fs — OGC %s — contre-argumentation via Anthropic",
+                        time.time() - t_gen, controle.numero_ogc)
+
+    # 5. RAG source conversion (extract truncated to 200 characters)
+    draft.rag_sources = [
+        RAGSource(
+            document=s.get("document", ""),
+            page=s.get("page"),
+            code=s.get("code"),
+            extrait=(s.get("extrait") or "")[:200],
+        )
+        for s in draft.sources
+    ]
+
+    if result is None:
+        logger.warning(" LLM non disponible — contre-argumentation non générée")
+        draft.result = None
+        return
+
+    # 5b. Degraded mode: flag the answer when the extraction pass failed
+    if draft.degraded_pass1:
+        result.setdefault("quality_flags", {})
+        result["quality_flags"]["cpam_pass1_failed"] = True
+        result["quality_flags"]["degraded_mode"] = True
+
+    # 6. Deterministic sanitising + guardian (no LLM)
+    sanitized = _sanitize_unauthorized_codes(result, dossier, controle)
+    if sanitized:
+        logger.info(" CPAM : %d code(s) hors périmètre supprimé(s) : %s",
+                    len(sanitized), ", ".join(sanitized))
+    result = _guardian_deterministic(result, dossier, controle, draft.tag_map)
+
+    # 7. Deterministic validations (no LLM)
+    draft.ref_warnings = _validate_references(result, draft.sources)
+    if draft.ref_warnings:
+        logger.warning(" CPAM : %d référence(s) non vérifiable(s)", len(draft.ref_warnings))
+
+    draft.grounding_warnings = _validate_grounding(result, draft.tag_map)
+    if draft.grounding_warnings:
+        logger.warning(" CPAM : %d preuve(s) non traçable(s)", len(draft.grounding_warnings))
+
+    draft.code_warnings = _validate_codes_in_response(result, dossier, controle)
+    if draft.code_warnings:
+        logger.warning(" CPAM : %d code(s) hors périmètre", len(draft.code_warnings))
+
+    draft.result = result
+
+
+def _phase_validate(
+    dossier: DossierMedical,
+    draft: _CpamDraft,
+) -> None:
+    """Phase 2 — adversarial validation (role=validation).
+
+    Fills draft.validation, draft.adversarial_warnings and
+    draft.needs_correction. Does nothing when the draft has no result.
+    """
+    if draft.result is None:
+        return
+
+    controle = draft.controle
+
+    # LOGIC-3: identical models?
+    from ..config import check_adversarial_model_config
+    same_model, _model_msg = check_adversarial_model_config()
+    if same_model:
+        draft.result.setdefault("quality_flags", {})
+        draft.result["quality_flags"]["adversarial_disabled_same_model"] = True
+        # Every control of the batch shares this dossier: record the alert once.
+        alert = "Validation adversariale désactivée (modèles identiques)"
+        if alert not in dossier.alertes_codage:
+            dossier.alertes_codage.append(alert)
+
+    # NOTE(review): validation still runs when same_model is true; presumably
+    # _validate_adversarial handles that case itself — confirm.
+    draft.validation = _validate_adversarial(draft.result, draft.tag_map, controle)
+    draft.adversarial_warnings = _adversarial_warnings(draft.validation)
+
+    # Decide whether the correction phase must process this draft
+    draft.needs_correction = (
+        _max_corrections() > 0 and _should_correct(draft.validation)
+    )
+
+
+def _phase_correct(
+    dossier: DossierMedical,
+    draft: _CpamDraft,
+) -> None:
+    """Phase 3 — correction (role=cpam) followed by re-validation.
+
+    Both steps are coupled, so they alternate inside one loop per draft. A
+    corrected answer replaces the current one only when its adversarial score
+    is strictly higher; a failed or rejected attempt ends the loop. Does
+    nothing unless the draft has a result and needs_correction is set.
+    """
+    if draft.result is None or not draft.needs_correction:
+        return
+
+    controle = draft.controle
+    max_corrections = _max_corrections()
+
+    for attempt in range(max_corrections):
+        if not _should_correct(draft.validation):
+            break
+
+        erreurs_v = draft.validation.get("erreurs", [])
+        logger.warning(" Score adversarial %s/10 — correction %d/%d (%d erreur(s))",
+                       draft.validation.get("score_confiance"),
+                       attempt + 1, max_corrections, len(erreurs_v))
+
+        correction_prompt = _build_correction_prompt(
+            draft.prompt, draft.result, draft.validation
+        )
+        t_corr = time.time()
+        corrected = call_ollama(
+            correction_prompt, temperature=0.0, max_tokens=16000, role="cpam"
+        )
+        if corrected is None:
+            corrected = call_anthropic(
+                correction_prompt, temperature=0.0, max_tokens=16000
+            )
+        logger.info(" [CPAM-CORR] %.1fs — OGC %s — correction %d/%d",
+                    time.time() - t_corr, controle.numero_ogc,
+                    attempt + 1, max_corrections)
+
+        if not corrected:
+            break
+
+        validation2 = _validate_adversarial(corrected, draft.tag_map, controle)
+        score2 = validation2.get("score_confiance", 0) if validation2 else 0
+        score1 = draft.validation.get("score_confiance", 0)
+
+        if score2 > score1:
+            logger.info(" Correction %d acceptée (score %s → %s)",
+                        attempt + 1, score1, score2)
+            # The new answer is a fresh LLM output: keep the flags that
+            # describe the pipeline run (degraded mode, same-model alert).
+            _carry_quality_flags(draft.result, corrected)
+            draft.result = corrected
+            draft.validation = validation2
+            # Re-run the deterministic checks on the accepted answer
+            _sanitize_unauthorized_codes(draft.result, dossier, controle)
+            draft.result = _guardian_deterministic(
+                draft.result, dossier, controle, draft.tag_map
+            )
+            draft.ref_warnings = _validate_references(draft.result, draft.sources)
+            draft.grounding_warnings = _validate_grounding(draft.result, draft.tag_map)
+            draft.code_warnings = _validate_codes_in_response(
+                draft.result, dossier, controle
+            )
+            draft.adversarial_warnings = _adversarial_warnings(draft.validation)
+        else:
+            logger.warning(" Correction %d rejetée (score %s → %s)",
+                           attempt + 1, score1, score2)
+            break
+
+
+def _phase_finalize(
+    dossier: DossierMedical,
+    draft: _CpamDraft,
+) -> tuple[str, dict | None, list[RAGSource]]:
+    """Final phase — quality tier + formatting (deterministic, no LLM).
+
+    Also sets quality_tier, requires_review and quality_warnings on the
+    control, and stores the outcome in draft.text / draft.response_data.
+
+    Returns:
+        (text, response data, RAG sources), like generate_cpam_response();
+        ("", None, sources) when the draft has no result.
+    """
+    controle = draft.controle
+
+    if draft.result is None:
+        return "", None, draft.rag_sources
+
+    all_warnings = (
+        draft.ref_warnings + draft.grounding_warnings
+        + draft.code_warnings + draft.adversarial_warnings
+    )
+
+    strength = _assess_dossier_strength(dossier)
+    if strength["is_weak"]:
+        logger.info(" Dossier à preuves limitées (score %d/10) : %s",
+                    strength["score"], ", ".join(strength["missing"]))
+
+    tier, needs_review, cat_warnings = _assess_quality_tier(
+        draft.result, draft.ref_warnings, draft.grounding_warnings,
+        draft.code_warnings, draft.validation,
+        is_weak_dossier=strength["is_weak"],
+    )
+    controle.quality_tier = tier
+    controle.requires_review = needs_review
+    controle.quality_warnings = cat_warnings
+    logger.info(" Qualité CPAM : tier %s, requires_review=%s, %d warnings",
+                tier, needs_review, len(cat_warnings))
+
+    text = _format_response(
+        draft.result,
+        ref_warnings=all_warnings,
+        quality_tier=tier,
+        categorized_warnings=cat_warnings,
+    )
+    logger.info(" Contre-argumentation générée (%d caractères)", len(text))
+
+    # Keep the draft's final-result fields in sync with what is returned
+    draft.text = text
+    draft.response_data = draft.result
+
+    return text, draft.result, draft.rag_sources
+
+
+# ---------------------------------------------------------------------------
+# Batched orchestrator — groups LLM calls per model to limit VRAM swaps
+# ---------------------------------------------------------------------------
+
+def generate_cpam_responses_batched(
+    dossier: DossierMedical,
+    controles: list[ControleCPAM],
+) -> None:
+    """Generate the counter-arguments for ALL CPAM controls in one batch.
+
+    LLM calls are grouped by model to limit VRAM swaps:
+      Phase 1 (cpam model)       : extraction + argumentation of every control
+      Phase 2 (validation model) : adversarial validation of every result
+      Phase 3 (both models)      : correction + re-validation of the drafts
+                                   flagged by phase 2 (alternates per draft)
+      Finalisation               : quality tier + formatting (deterministic)
+
+    Phases 1 and 2 need a single model swap whatever the number of controls;
+    phase 3 alternates models for each draft it corrects.
+
+    Args:
+        dossier: The medical record (merged or single).
+        controles: CPAM controls to process.
+
+    Side effects:
+        Sets controle.contre_argumentation, controle.response_data and
+        controle.sources_reponse on every control; the phases also append
+        alerts to dossier.alertes_codage.
+    """
+    if not controles:
+        return
+
+    t_total = time.time()
+    n = len(controles)
+    logger.info("CPAM batch : %d contrôle(s) — début du traitement batché", n)
+
+    # One draft per control
+    drafts = [_CpamDraft(controle=ctrl) for ctrl in controles]
+
+    # --- Phase 1: generation (cpam model) ---
+    t_phase = time.time()
+    logger.info("CPAM batch phase 1/3 : génération (%d contrôles, role=cpam)", n)
+    for draft in drafts:
+        _phase_generate(dossier, draft)
+    logger.info("CPAM batch phase 1/3 terminée : %.1fs", time.time() - t_phase)
+
+    # --- Phase 2: adversarial validation (validation model — 1 swap) ---
+    drafts_to_validate = [d for d in drafts if d.result is not None]
+    if drafts_to_validate:
+        t_phase = time.time()
+        logger.info("CPAM batch phase 2/3 : validation (%d contrôles, role=validation)",
+                    len(drafts_to_validate))
+        for draft in drafts_to_validate:
+            _phase_validate(dossier, draft)
+        logger.info("CPAM batch phase 2/3 terminée : %.1fs", time.time() - t_phase)
+
+    # --- Phase 3: correction + re-validation (cpam + validation models) ---
+    drafts_to_correct = [d for d in drafts if d.needs_correction]
+    if drafts_to_correct:
+        t_phase = time.time()
+        logger.info("CPAM batch phase 3/3 : correction (%d contrôles nécessitent correction)",
+                    len(drafts_to_correct))
+        for draft in drafts_to_correct:
+            _phase_correct(dossier, draft)
+        logger.info("CPAM batch phase 3/3 terminée : %.1fs", time.time() - t_phase)
+
+    # --- Finalisation (deterministic, no swap) ---
+    for draft in drafts:
+        text, response_data, sources = _phase_finalize(dossier, draft)
+        draft.controle.contre_argumentation = text
+        draft.controle.response_data = response_data
+        draft.controle.sources_reponse = sources
+
+    elapsed_total = time.time() - t_total
+    # n >= 1 here: the empty list returned early
+    logger.info(
+        "CPAM batch terminé : %d contrôle(s) en %.1fs (%.1fs/contrôle)",
+        n, elapsed_total, elapsed_total / n,
+    )
diff --git a/src/main.py b/src/main.py
index 6da3d75..e32a9f1 100644
--- a/src/main.py
+++ b/src/main.py
@@ -610,20 +610,18 @@ def main(input_path: str | None = None) -> None:
             merged = None
 
         # Contrôle CPAM : enrichir le dossier principal (fusionné ou dernier)
+        # Use the batched mode so LLM calls are grouped per model,
+        # which limits VRAM model swaps
         if cpam_data and subdir:
             try:
                 from .control.cpam_parser import match_dossier_ogc
                 controles = match_dossier_ogc(subdir, cpam_data)
                 if controles:
-                    from .control.cpam_response import generate_cpam_response
+                    from .control.cpam_response import generate_cpam_responses_batched
                     target = merged if merged else (group_dossiers[-1] if group_dossiers else None)
                     if target:
                         logger.info(" CPAM : %d contrôle(s) pour %s", len(controles), subdir)
-                        for ctrl in controles:
-                            text, response_data, sources = generate_cpam_response(target, ctrl)
-                            ctrl.contre_argumentation = text
-                            ctrl.response_data = response_data
-                            ctrl.sources_reponse = sources
+                        generate_cpam_responses_batched(target, controles)
                         target.controles_cpam = controles
             except Exception:
                 logger.exception("Erreur CPAM pour %s", subdir)