feat: batching CPAM par modèle pour réduire les swaps VRAM

Restructure le flux CPAM pour regrouper les appels LLM par modèle :
- Phase 1 (cpam model) : extraction + argumentation de TOUS les contrôles
- Phase 2 (validation model, 1 swap) : validation adversariale de tous
- Phase 3 (cpam model, si correction nécessaire) : correction des échoués

Nouvelle fonction generate_cpam_responses_batched() qui orchestre les 3 phases.
L'ancienne generate_cpam_response() reste intacte (rétrocompatible, utilisée
par le viewer pour la régénération unitaire).

Structure intermédiaire _CpamDraft (dataclass) pour transporter l'état entre phases.
Fonctions de phase extraites : _phase_generate, _phase_validate, _phase_correct,
_phase_finalize.

Gain estimé : de ~4*N swaps de modèle à 2-4 swaps (indépendant de N contrôles).
Sur RTX 5070 avec des modèles 24-32b, chaque swap coûte ~10-15s.
Pour 3 contrôles : économie estimée ~60-90s.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
dom
2026-03-08 14:10:34 +01:00
parent d6b4e48989
commit 0cafb47e8d
2 changed files with 381 additions and 6 deletions

View File

@@ -4,6 +4,14 @@ Orchestrateur principal — délègue aux sous-modules :
- cpam_rag : _search_rag_for_control(), _search_rag_queries()
- cpam_context : _build_cpam_prompt(), _build_tagged_context(), _build_bio_summary(), etc.
- cpam_validation : _validate_adversarial(), _validate_grounding(), _format_response(), etc.
Le flux batché (generate_cpam_responses_batched) regroupe les appels LLM par modèle
pour minimiser les swaps VRAM sur GPU unique :
Phase 1 (cpam model) : extraction + argumentation de TOUS les contrôles
Phase 2 (validation model, 1 swap) : validation adversariale de tous
Phase 3 (cpam model, 1 swap) : correction des contrôles échoués
Phase 4 (validation model, 1 swap) : re-validation des corrigés
Gain : de ~4N swaps à ~2-4 swaps (indépendant du nombre de contrôles).
"""
from __future__ import annotations
@@ -12,6 +20,7 @@ import json
import logging
import os
import time
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
@@ -42,6 +51,34 @@ from .cpam_validation import (
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Structure intermédiaire pour le batching
# ---------------------------------------------------------------------------
@dataclass
class _CpamDraft:
"""État intermédiaire d'un contrôle CPAM en cours de traitement batché."""
controle: ControleCPAM
# Résultats des phases successives
extraction: dict | None = None
degraded_pass1: bool = False
sources: list[dict] = field(default_factory=list)
rag_sources: list[RAGSource] = field(default_factory=list)
prompt: str = ""
tag_map: dict[str, str] = field(default_factory=dict)
result: dict | None = None
# Validation
ref_warnings: list[str] = field(default_factory=list)
grounding_warnings: list[str] = field(default_factory=list)
code_warnings: list[str] = field(default_factory=list)
adversarial_warnings: list[str] = field(default_factory=list)
validation: dict | None = None
# Résultat final
needs_correction: bool = False
text: str = ""
response_data: dict | None = None
def _save_version(
dossier: DossierMedical,
controle: ControleCPAM,
@@ -358,3 +395,343 @@ def generate_cpam_response(
logger.info(" Contre-argumentation générée (%d caractères)", len(text))
return text, result, rag_sources
# ---------------------------------------------------------------------------
# Fonctions de phase (utilisées par le batching)
# ---------------------------------------------------------------------------
def _phase_generate(
dossier: DossierMedical,
draft: _CpamDraft,
) -> None:
"""Phase 1 — Extraction + RAG + argumentation (role=cpam).
Remplit draft.extraction, draft.sources, draft.prompt, draft.tag_map, draft.result.
Ne fait aucun appel au modèle de validation.
"""
controle = draft.controle
logger.info("CPAM batch : génération pour OGC %d%s",
controle.numero_ogc, controle.titre)
# 0. Versioning
_save_version(dossier, controle)
# 1. Extraction structurée (role=cpam)
draft.extraction = _extraction_pass(dossier, controle)
draft.degraded_pass1 = draft.extraction is None
if draft.degraded_pass1:
dossier.alertes_codage.append(
"CPAM: passe 1 (extraction structurée) échouée → mode dégradé"
)
# 2. Recherche RAG (pas de LLM)
draft.sources = _search_rag_for_control(controle, dossier)
logger.info(" RAG : %d sources trouvées", len(draft.sources))
# 3. Construction du prompt
draft.prompt, draft.tag_map = _build_cpam_prompt(
dossier, controle, draft.sources, draft.extraction
)
# 4. Argumentation (role=cpam)
t_gen = time.time()
result = call_ollama(draft.prompt, temperature=0.1, max_tokens=8000, role="cpam")
if result is not None:
logger.info(" [CPAM-GEN] %.1fs — OGC %s — contre-argumentation via Ollama",
time.time() - t_gen, controle.numero_ogc)
else:
logger.info(" Ollama indisponible → fallback Anthropic Haiku")
result = call_anthropic(draft.prompt, temperature=0.1, max_tokens=8000)
if result is not None:
logger.info(" [CPAM-GEN] %.1fs — OGC %s — contre-argumentation via Anthropic",
time.time() - t_gen, controle.numero_ogc)
# 5. Conversion sources RAG
draft.rag_sources = [
RAGSource(
document=s.get("document", ""),
page=s.get("page"),
code=s.get("code"),
extrait=s.get("extrait", "")[:200],
)
for s in draft.sources
]
if result is None:
logger.warning(" LLM non disponible — contre-argumentation non générée")
draft.result = None
return
# 5b. Mode dégradé
if draft.degraded_pass1:
result.setdefault("quality_flags", {})
result["quality_flags"]["cpam_pass1_failed"] = True
result["quality_flags"]["degraded_mode"] = True
# 6. Sanitisation + gardien déterministes (pas de LLM)
sanitized = _sanitize_unauthorized_codes(result, dossier, controle)
if sanitized:
logger.info(" CPAM : %d code(s) hors périmètre supprimé(s) : %s",
len(sanitized), ", ".join(sanitized))
result = _guardian_deterministic(result, dossier, controle, draft.tag_map)
# 7. Validations déterministes (pas de LLM)
draft.ref_warnings = _validate_references(result, draft.sources)
if draft.ref_warnings:
logger.warning(" CPAM : %d référence(s) non vérifiable(s)", len(draft.ref_warnings))
draft.grounding_warnings = _validate_grounding(result, draft.tag_map)
if draft.grounding_warnings:
logger.warning(" CPAM : %d preuve(s) non traçable(s)", len(draft.grounding_warnings))
draft.code_warnings = _validate_codes_in_response(result, dossier, controle)
if draft.code_warnings:
logger.warning(" CPAM : %d code(s) hors périmètre", len(draft.code_warnings))
draft.result = result
def _phase_validate(
dossier: DossierMedical,
draft: _CpamDraft,
) -> None:
"""Phase 2 — Validation adversariale (role=validation).
Remplit draft.validation et draft.needs_correction.
"""
if draft.result is None:
return
controle = draft.controle
# LOGIC-3 : modèles identiques ?
from ..config import check_adversarial_model_config
same_model, model_msg = check_adversarial_model_config()
if same_model:
draft.result.setdefault("quality_flags", {})
draft.result["quality_flags"]["adversarial_disabled_same_model"] = True
dossier.alertes_codage.append(
"Validation adversariale désactivée (modèles identiques)"
)
draft.validation = _validate_adversarial(draft.result, draft.tag_map, controle)
if draft.validation and not draft.validation.get("coherent", True):
erreurs = draft.validation.get("erreurs", [])
score = draft.validation.get("score_confiance", "?")
for e in erreurs:
if isinstance(e, str) and e.strip():
draft.adversarial_warnings.append(f"Incohérence détectée : {e}")
if draft.adversarial_warnings:
draft.adversarial_warnings.append(f"Score de confiance : {score}/10")
# Déterminer si une correction est nécessaire
max_corrections = int(os.environ.get("T2A_CPAM_MAX_CORRECTIONS", "2"))
draft.needs_correction = bool(
max_corrections > 0
and draft.validation
and not draft.validation.get("coherent", True)
and draft.validation.get("score_confiance", 10) <= 5
and rule_enabled("RULE-CPAM-CORRECTION-LOOP")
)
def _phase_correct(
dossier: DossierMedical,
draft: _CpamDraft,
) -> None:
"""Phase 3 — Correction (role=cpam) + re-validation (role=validation).
Boucle correction/re-validation intégrée car les deux sont couplées.
Seuls les drafts avec needs_correction=True entrent ici.
"""
if draft.result is None or not draft.needs_correction:
return
controle = draft.controle
max_corrections = int(os.environ.get("T2A_CPAM_MAX_CORRECTIONS", "2"))
for attempt in range(max_corrections):
if not (draft.validation
and not draft.validation.get("coherent", True)
and draft.validation.get("score_confiance", 10) <= 5
and rule_enabled("RULE-CPAM-CORRECTION-LOOP")):
break
erreurs_v = draft.validation.get("erreurs", [])
logger.warning(" Score adversarial %s/10 — correction %d/%d (%d erreur(s))",
draft.validation.get("score_confiance"),
attempt + 1, max_corrections, len(erreurs_v))
correction_prompt = _build_correction_prompt(
draft.prompt, draft.result, draft.validation
)
t_corr = time.time()
corrected = call_ollama(
correction_prompt, temperature=0.0, max_tokens=16000, role="cpam"
)
if corrected is None:
corrected = call_anthropic(
correction_prompt, temperature=0.0, max_tokens=16000
)
logger.info(" [CPAM-CORR] %.1fs — OGC %s — correction %d/%d",
time.time() - t_corr, controle.numero_ogc,
attempt + 1, max_corrections)
if not corrected:
break
validation2 = _validate_adversarial(corrected, draft.tag_map, controle)
score2 = validation2.get("score_confiance", 0) if validation2 else 0
score1 = draft.validation.get("score_confiance", 0)
if score2 > score1:
logger.info(" Correction %d acceptée (score %s%s)",
attempt + 1, score1, score2)
draft.result = corrected
draft.validation = validation2
_sanitize_unauthorized_codes(draft.result, dossier, controle)
draft.result = _guardian_deterministic(
draft.result, dossier, controle, draft.tag_map
)
draft.ref_warnings = _validate_references(draft.result, draft.sources)
draft.grounding_warnings = _validate_grounding(draft.result, draft.tag_map)
draft.code_warnings = _validate_codes_in_response(
draft.result, dossier, controle
)
draft.adversarial_warnings = []
if draft.validation and not draft.validation.get("coherent", True):
for e in draft.validation.get("erreurs", []):
if isinstance(e, str) and e.strip():
draft.adversarial_warnings.append(f"Incohérence détectée : {e}")
if draft.adversarial_warnings:
draft.adversarial_warnings.append(
f"Score de confiance : {draft.validation.get('score_confiance', '?')}/10"
)
else:
logger.warning(" Correction %d rejetée (score %s%s)",
attempt + 1, score1, score2)
break
def _phase_finalize(
dossier: DossierMedical,
draft: _CpamDraft,
) -> tuple[str, dict | None, list[RAGSource]]:
"""Phase finale — Qualité + formatage (déterministe, pas de LLM).
Returns:
Même signature que generate_cpam_response().
"""
controle = draft.controle
if draft.result is None:
return "", None, draft.rag_sources
all_warnings = (
draft.ref_warnings + draft.grounding_warnings
+ draft.code_warnings + draft.adversarial_warnings
)
strength = _assess_dossier_strength(dossier)
if strength["is_weak"]:
logger.info(" Dossier à preuves limitées (score %d/10) : %s",
strength["score"], ", ".join(strength["missing"]))
tier, needs_review, cat_warnings = _assess_quality_tier(
draft.result, draft.ref_warnings, draft.grounding_warnings,
draft.code_warnings, draft.validation,
is_weak_dossier=strength["is_weak"],
)
controle.quality_tier = tier
controle.requires_review = needs_review
controle.quality_warnings = cat_warnings
logger.info(" Qualité CPAM : tier %s, requires_review=%s, %d warnings",
tier, needs_review, len(cat_warnings))
text = _format_response(
draft.result,
ref_warnings=all_warnings,
quality_tier=tier,
categorized_warnings=cat_warnings,
)
logger.info(" Contre-argumentation générée (%d caractères)", len(text))
return text, draft.result, draft.rag_sources
# ---------------------------------------------------------------------------
# Orchestrateur batché — minimise les swaps de modèle VRAM
# ---------------------------------------------------------------------------
def generate_cpam_responses_batched(
dossier: DossierMedical,
controles: list[ControleCPAM],
) -> None:
"""Génère les contre-argumentations pour TOUS les contrôles CPAM en batch.
Regroupe les appels LLM par modèle pour minimiser les swaps VRAM :
Phase 1 (cpam model) : extraction + argumentation de tous les contrôles
Phase 2 (validation model) : validation adversariale de tous
Phase 3 (cpam model) : correction des contrôles échoués (si nécessaire)
Finalisation : qualité + formatage (déterministe)
Gain : de ~4*N swaps de modèle à 2-4 swaps, indépendant de N.
Args:
dossier: Le dossier médical (fusionné ou unique).
controles: Liste des contrôles CPAM à traiter.
Side effects:
Remplit controle.contre_argumentation, controle.response_data,
controle.sources_reponse pour chaque contrôle.
"""
if not controles:
return
t_total = time.time()
n = len(controles)
logger.info("CPAM batch : %d contrôle(s) — début du traitement batché", n)
# Créer les drafts
drafts = [_CpamDraft(controle=ctrl) for ctrl in controles]
# --- Phase 1 : génération (cpam model) ---
t_phase = time.time()
logger.info("CPAM batch phase 1/3 : génération (%d contrôles, role=cpam)", n)
for draft in drafts:
_phase_generate(dossier, draft)
logger.info("CPAM batch phase 1/3 terminée : %.1fs", time.time() - t_phase)
# --- Phase 2 : validation adversariale (validation model — 1 swap) ---
drafts_to_validate = [d for d in drafts if d.result is not None]
if drafts_to_validate:
t_phase = time.time()
logger.info("CPAM batch phase 2/3 : validation (%d contrôles, role=validation)",
len(drafts_to_validate))
for draft in drafts_to_validate:
_phase_validate(dossier, draft)
logger.info("CPAM batch phase 2/3 terminée : %.1fs", time.time() - t_phase)
# --- Phase 3 : correction + re-validation (cpam + validation models) ---
drafts_to_correct = [d for d in drafts if d.needs_correction]
if drafts_to_correct:
t_phase = time.time()
logger.info("CPAM batch phase 3/3 : correction (%d contrôles nécessitent correction)",
len(drafts_to_correct))
for draft in drafts_to_correct:
_phase_correct(dossier, draft)
logger.info("CPAM batch phase 3/3 terminée : %.1fs", time.time() - t_phase)
# --- Finalisation (déterministe, pas de swap) ---
for draft in drafts:
text, response_data, sources = _phase_finalize(dossier, draft)
draft.controle.contre_argumentation = text
draft.controle.response_data = response_data
draft.controle.sources_reponse = sources
elapsed_total = time.time() - t_total
logger.info(
"CPAM batch terminé : %d contrôle(s) en %.1fs (%.1fs/contrôle)",
n, elapsed_total, elapsed_total / n if n else 0,
)

View File

@@ -610,20 +610,18 @@ def main(input_path: str | None = None) -> None:
merged = None
# Contrôle CPAM : enrichir le dossier principal (fusionné ou dernier)
# Utilise le mode batché pour regrouper les appels LLM par modèle
# et minimiser les swaps VRAM (de ~4N swaps à ~2-4 swaps)
if cpam_data and subdir:
try:
from .control.cpam_parser import match_dossier_ogc
controles = match_dossier_ogc(subdir, cpam_data)
if controles:
from .control.cpam_response import generate_cpam_response
from .control.cpam_response import generate_cpam_responses_batched
target = merged if merged else (group_dossiers[-1] if group_dossiers else None)
if target:
logger.info(" CPAM : %d contrôle(s) pour %s", len(controles), subdir)
for ctrl in controles:
text, response_data, sources = generate_cpam_response(target, ctrl)
ctrl.contre_argumentation = text
ctrl.response_data = response_data
ctrl.sources_reponse = sources
generate_cpam_responses_batched(target, controles)
target.controles_cpam = controles
except Exception:
logger.exception("Erreur CPAM pour %s", subdir)