feat: batching CPAM par modèle pour réduire les swaps VRAM

Restructure le flux CPAM pour regrouper les appels LLM par modèle :
- Phase 1 (cpam model) : extraction + argumentation de TOUS les contrôles
- Phase 2 (validation model, 1 swap) : validation adversariale de tous
- Phase 3 (cpam model, si correction nécessaire) : correction des échoués

Nouvelle fonction generate_cpam_responses_batched() qui orchestre les 3 phases.
L'ancienne generate_cpam_response() reste intacte (rétrocompatible, utilisée
par le viewer pour la régénération unitaire).

Structure intermédiaire _CpamDraft (dataclass) pour transporter l'état entre phases.
Fonctions de phase extraites : _phase_generate, _phase_validate, _phase_correct,
_phase_finalize.

Gain estimé : de ~4*N swaps de modèle à 2-4 swaps (indépendant de N contrôles).
Sur RTX 5070 avec des modèles 24-32b, chaque swap coûte ~10-15s.
Pour 3 contrôles : économie estimée ~60-90s.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
dom
2026-03-08 14:10:34 +01:00
parent d6b4e48989
commit 0cafb47e8d
2 changed files with 381 additions and 6 deletions

View File

@@ -4,6 +4,14 @@ Orchestrateur principal — délègue aux sous-modules :
- cpam_rag : _search_rag_for_control(), _search_rag_queries() - cpam_rag : _search_rag_for_control(), _search_rag_queries()
- cpam_context : _build_cpam_prompt(), _build_tagged_context(), _build_bio_summary(), etc. - cpam_context : _build_cpam_prompt(), _build_tagged_context(), _build_bio_summary(), etc.
- cpam_validation : _validate_adversarial(), _validate_grounding(), _format_response(), etc. - cpam_validation : _validate_adversarial(), _validate_grounding(), _format_response(), etc.
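Le flux batché (generate_cpam_responses_batched) regroupe les appels LLM par modèle
pour minimiser les swaps VRAM sur GPU unique :
Phase 1 (cpam model) : extraction + argumentation de TOUS les contrôles
Phase 2 (validation model, 1 swap) : validation adversariale de tous
Phase 3 (cpam + validation models) : correction puis re-validation des contrôles
échoués, contrôle par contrôle (la boucle correction/re-validation est couplée ;
les swaps de cette phase dépendent donc du nombre de corrections effectuées)
Finalisation (déterministe, sans LLM) : qualité + formatage
Gain : les phases 1 et 2 ne coûtent qu'un seul swap quel que soit le nombre de
contrôles (contre ~4N swaps en traitement unitaire) ; seule la phase 3 ajoute des swaps.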
""" """
from __future__ import annotations from __future__ import annotations
@@ -12,6 +20,7 @@ import json
import logging import logging
import os import os
import time import time
from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
@@ -42,6 +51,34 @@ from .cpam_validation import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Structure intermédiaire pour le batching
# ---------------------------------------------------------------------------
@dataclass
class _CpamDraft:
    """Intermediate state of one CPAM control while it goes through the batched flow.

    One instance is created per control by generate_cpam_responses_batched() and
    is mutated in place by the _phase_* functions.
    """
    # The control being processed (the only field without a default).
    controle: ControleCPAM
    # --- Results of the successive phases ---
    # Value returned by _extraction_pass(); None when that pass failed.
    extraction: dict | None = None
    # True when extraction is None (phase 1 then runs in degraded mode).
    degraded_pass1: bool = False
    # Raw sources returned by _search_rag_for_control().
    sources: list[dict] = field(default_factory=list)
    # The same sources converted to RAGSource objects.
    rag_sources: list[RAGSource] = field(default_factory=list)
    # Prompt and tag map returned together by _build_cpam_prompt().
    prompt: str = ""
    tag_map: dict[str, str] = field(default_factory=dict)
    # LLM response after the deterministic post-processing; None when no LLM answered.
    result: dict | None = None
    # --- Validation ---
    # Output of _validate_references().
    ref_warnings: list[str] = field(default_factory=list)
    # Output of _validate_grounding().
    grounding_warnings: list[str] = field(default_factory=list)
    # Output of _validate_codes_in_response().
    code_warnings: list[str] = field(default_factory=list)
    # Messages built from the "erreurs" / "score_confiance" keys of `validation`.
    adversarial_warnings: list[str] = field(default_factory=list)
    # Value returned by _validate_adversarial().
    validation: dict | None = None
    # --- Final result ---
    # Set by _phase_validate(); selects the drafts that enter _phase_correct().
    needs_correction: bool = False
    # NOTE(review): `text` and `response_data` are not written by any of the phase
    # functions visible in this part of the file - confirm whether they are still needed.
    text: str = ""
    response_data: dict | None = None
def _save_version( def _save_version(
dossier: DossierMedical, dossier: DossierMedical,
controle: ControleCPAM, controle: ControleCPAM,
@@ -358,3 +395,343 @@ def generate_cpam_response(
logger.info(" Contre-argumentation générée (%d caractères)", len(text)) logger.info(" Contre-argumentation générée (%d caractères)", len(text))
return text, result, rag_sources return text, result, rag_sources
# ---------------------------------------------------------------------------
# Fonctions de phase (utilisées par le batching)
# ---------------------------------------------------------------------------
def _phase_generate(
    dossier: DossierMedical,
    draft: _CpamDraft,
) -> None:
    """Phase 1 - structured extraction, RAG lookup and argumentation (role=cpam).

    Fills draft.extraction, draft.degraded_pass1, draft.sources, draft.prompt,
    draft.tag_map, draft.rag_sources, the three deterministic warning lists
    (ref / grounding / code) and draft.result. draft.result stays None when
    neither Ollama nor the Anthropic fallback returned a response.
    _validate_adversarial() is never called from this phase.

    Args:
        dossier: Medical record the control refers to. An alert is appended to
            dossier.alertes_codage when the extraction pass fails.
        draft: Per-control state, updated in place.
    """
    controle = draft.controle
    # %s (not %d) so the call cannot fail on a non-integer OGC number, and an
    # explicit separator so the number and the title do not run together.
    logger.info("CPAM batch : génération pour OGC %s — %s",
                controle.numero_ogc, controle.titre)

    # 0. Versioning
    _save_version(dossier, controle)

    # 1. Structured extraction (role=cpam)
    draft.extraction = _extraction_pass(dossier, controle)
    draft.degraded_pass1 = draft.extraction is None
    if draft.degraded_pass1:
        dossier.alertes_codage.append(
            "CPAM: passe 1 (extraction structurée) échouée → mode dégradé"
        )

    # 2. RAG search (no LLM)
    draft.sources = _search_rag_for_control(controle, dossier)
    logger.info(" RAG : %d sources trouvées", len(draft.sources))

    # 3. Prompt construction
    draft.prompt, draft.tag_map = _build_cpam_prompt(
        dossier, controle, draft.sources, draft.extraction
    )

    # 4. Argumentation (role=cpam), with Anthropic as fallback when Ollama returns None
    t_gen = time.time()
    result = call_ollama(draft.prompt, temperature=0.1, max_tokens=8000, role="cpam")
    if result is not None:
        logger.info(" [CPAM-GEN] %.1fs — OGC %s — contre-argumentation via Ollama",
                    time.time() - t_gen, controle.numero_ogc)
    else:
        logger.info(" Ollama indisponible → fallback Anthropic Haiku")
        result = call_anthropic(draft.prompt, temperature=0.1, max_tokens=8000)
        if result is not None:
            logger.info(" [CPAM-GEN] %.1fs — OGC %s — contre-argumentation via Anthropic",
                        time.time() - t_gen, controle.numero_ogc)

    # 5. RAG source conversion. Done before the early return below so that the
    #    sources are available even when no argumentation could be generated.
    draft.rag_sources = [
        RAGSource(
            document=s.get("document", ""),
            page=s.get("page"),
            code=s.get("code"),
            # `or ""` also covers an "extrait" key that is present but None.
            extrait=(s.get("extrait") or "")[:200],
        )
        for s in draft.sources
    ]

    if result is None:
        logger.warning(" LLM non disponible — contre-argumentation non générée")
        draft.result = None
        return

    # 5b. Degraded mode: record that the extraction pass failed
    if draft.degraded_pass1:
        result.setdefault("quality_flags", {})
        result["quality_flags"]["cpam_pass1_failed"] = True
        result["quality_flags"]["degraded_mode"] = True

    # 6. Deterministic sanitisation + guardian (no LLM)
    sanitized = _sanitize_unauthorized_codes(result, dossier, controle)
    if sanitized:
        logger.info(" CPAM : %d code(s) hors périmètre supprimé(s) : %s",
                    len(sanitized), ", ".join(sanitized))
    result = _guardian_deterministic(result, dossier, controle, draft.tag_map)

    # 7. Deterministic validations (no LLM)
    draft.ref_warnings = _validate_references(result, draft.sources)
    if draft.ref_warnings:
        logger.warning(" CPAM : %d référence(s) non vérifiable(s)", len(draft.ref_warnings))
    draft.grounding_warnings = _validate_grounding(result, draft.tag_map)
    if draft.grounding_warnings:
        logger.warning(" CPAM : %d preuve(s) non traçable(s)", len(draft.grounding_warnings))
    draft.code_warnings = _validate_codes_in_response(result, dossier, controle)
    if draft.code_warnings:
        logger.warning(" CPAM : %d code(s) hors périmètre", len(draft.code_warnings))

    draft.result = result
def _phase_validate(
    dossier: DossierMedical,
    draft: _CpamDraft,
) -> None:
    """Phase 2 - adversarial validation (role=validation).

    Fills draft.validation, draft.adversarial_warnings and draft.needs_correction.
    Does nothing when phase 1 produced no result.

    Args:
        dossier: Medical record; receives at most one "validation disabled" alert
            in dossier.alertes_codage when both roles use the same model.
        draft: Per-control state, updated in place.
    """
    if draft.result is None:
        return
    controle = draft.controle

    # LOGIC-3: are the generation and validation models identical?
    from ..config import check_adversarial_model_config
    same_model, model_msg = check_adversarial_model_config()
    if same_model:
        draft.result.setdefault("quality_flags", {})
        draft.result["quality_flags"]["adversarial_disabled_same_model"] = True
        # Surface the diagnostic returned by the config check instead of dropping it.
        if model_msg:
            logger.warning(" %s", model_msg)
        # The condition is configuration-wide, not per control: record the alert
        # once per dossier instead of once per control of the batch.
        alert = "Validation adversariale désactivée (modèles identiques)"
        if alert not in dossier.alertes_codage:
            dossier.alertes_codage.append(alert)
        # NOTE(review): _validate_adversarial() is still called below in this
        # case - confirm that it handles the identical-model situation itself.

    draft.validation = _validate_adversarial(draft.result, draft.tag_map, controle)
    if draft.validation and not draft.validation.get("coherent", True):
        erreurs = draft.validation.get("erreurs", [])
        score = draft.validation.get("score_confiance", "?")
        for e in erreurs:
            if isinstance(e, str) and e.strip():
                draft.adversarial_warnings.append(f"Incohérence détectée : {e}")
        # The score line is only added when at least one error was reported.
        if draft.adversarial_warnings:
            draft.adversarial_warnings.append(f"Score de confiance : {score}/10")

    # Decide whether a correction pass is needed: corrections must be allowed,
    # the validation must be incoherent with a score of 5 or less, and the
    # correction-loop rule must be enabled.
    max_corrections = int(os.environ.get("T2A_CPAM_MAX_CORRECTIONS", "2"))
    draft.needs_correction = bool(
        max_corrections > 0
        and draft.validation
        and not draft.validation.get("coherent", True)
        and draft.validation.get("score_confiance", 10) <= 5
        and rule_enabled("RULE-CPAM-CORRECTION-LOOP")
    )
def _phase_correct(
    dossier: DossierMedical,
    draft: _CpamDraft,
) -> None:
    """Phase 3 - correction (role=cpam) followed by re-validation (role=validation).

    Correction and re-validation are coupled, so each attempt performs both for
    this one control. A correction is kept only when its confidence score is
    strictly higher than the current one; otherwise the loop stops and the
    previous result is left untouched. Only drafts with needs_correction=True
    do any work here.

    Args:
        dossier: Medical record, passed to the deterministic post-processing.
        draft: Per-control state; result, validation and the warning lists are
            replaced in place when a correction is accepted.
    """
    if draft.result is None or not draft.needs_correction:
        return
    controle = draft.controle
    max_corrections = int(os.environ.get("T2A_CPAM_MAX_CORRECTIONS", "2"))

    # Flags written on the result by phases 1 and 2. A corrected response is a
    # brand-new dict coming from the LLM, so these must be copied over explicitly
    # or the degraded-mode / disabled-validation markers would be lost.
    carried_flag_names = (
        "cpam_pass1_failed",
        "degraded_mode",
        "adversarial_disabled_same_model",
    )

    for attempt in range(max_corrections):
        # Same criteria as in _phase_validate, re-evaluated after each accepted correction.
        if not (draft.validation
                and not draft.validation.get("coherent", True)
                and draft.validation.get("score_confiance", 10) <= 5
                and rule_enabled("RULE-CPAM-CORRECTION-LOOP")):
            break

        erreurs_v = draft.validation.get("erreurs", [])
        logger.warning(" Score adversarial %s/10 — correction %d/%d (%d erreur(s))",
                       draft.validation.get("score_confiance"),
                       attempt + 1, max_corrections, len(erreurs_v))

        correction_prompt = _build_correction_prompt(
            draft.prompt, draft.result, draft.validation
        )
        t_corr = time.time()
        corrected = call_ollama(
            correction_prompt, temperature=0.0, max_tokens=16000, role="cpam"
        )
        if corrected is None:
            corrected = call_anthropic(
                correction_prompt, temperature=0.0, max_tokens=16000
            )
        logger.info(" [CPAM-CORR] %.1fs — OGC %s — correction %d/%d",
                    time.time() - t_corr, controle.numero_ogc,
                    attempt + 1, max_corrections)
        if not corrected:
            break

        validation2 = _validate_adversarial(corrected, draft.tag_map, controle)
        score2 = validation2.get("score_confiance", 0) if validation2 else 0
        score1 = draft.validation.get("score_confiance", 0)

        if score2 > score1:
            logger.info(" Correction %d acceptée (score %s → %s)",
                        attempt + 1, score1, score2)

            # Carry the phase-1/phase-2 quality flags over to the new result.
            previous_flags = draft.result.get("quality_flags")
            if isinstance(previous_flags, dict):
                carried = {
                    name: previous_flags[name]
                    for name in carried_flag_names
                    if name in previous_flags
                }
                if carried:
                    new_flags = corrected.get("quality_flags")
                    if not isinstance(new_flags, dict):
                        new_flags = {}
                        corrected["quality_flags"] = new_flags
                    for name, value in carried.items():
                        new_flags.setdefault(name, value)

            draft.result = corrected
            draft.validation = validation2

            # Re-run the deterministic post-processing and validations on the new result.
            _sanitize_unauthorized_codes(draft.result, dossier, controle)
            draft.result = _guardian_deterministic(
                draft.result, dossier, controle, draft.tag_map
            )
            draft.ref_warnings = _validate_references(draft.result, draft.sources)
            draft.grounding_warnings = _validate_grounding(draft.result, draft.tag_map)
            draft.code_warnings = _validate_codes_in_response(
                draft.result, dossier, controle
            )

            # Rebuild the adversarial warnings from the new validation.
            draft.adversarial_warnings = []
            if draft.validation and not draft.validation.get("coherent", True):
                for e in draft.validation.get("erreurs", []):
                    if isinstance(e, str) and e.strip():
                        draft.adversarial_warnings.append(f"Incohérence détectée : {e}")
                if draft.adversarial_warnings:
                    draft.adversarial_warnings.append(
                        f"Score de confiance : {draft.validation.get('score_confiance', '?')}/10"
                    )
        else:
            logger.warning(" Correction %d rejetée (score %s → %s)",
                           attempt + 1, score1, score2)
            break
def _phase_finalize(
    dossier: DossierMedical,
    draft: _CpamDraft,
) -> tuple[str, dict | None, list[RAGSource]]:
    """Final phase - quality assessment and formatting (no LLM call is made here).

    Stores the quality tier, the review flag and the categorised warnings on the
    control, then formats the response text.

    Returns:
        (text, result, rag_sources), the same shape as generate_cpam_response().
        When the draft has no result the tuple is ("", None, draft.rag_sources).
    """
    ctrl = draft.controle
    response = draft.result
    if response is None:
        return "", None, draft.rag_sources

    # Every warning category, in the order ref -> grounding -> code -> adversarial.
    merged_warnings = [
        *draft.ref_warnings,
        *draft.grounding_warnings,
        *draft.code_warnings,
        *draft.adversarial_warnings,
    ]

    dossier_strength = _assess_dossier_strength(dossier)
    weak = dossier_strength["is_weak"]
    if weak:
        logger.info(" Dossier à preuves limitées (score %d/10) : %s",
                    dossier_strength["score"], ", ".join(dossier_strength["missing"]))

    quality = _assess_quality_tier(
        response,
        draft.ref_warnings,
        draft.grounding_warnings,
        draft.code_warnings,
        draft.validation,
        is_weak_dossier=weak,
    )
    ctrl.quality_tier, ctrl.requires_review, ctrl.quality_warnings = quality
    tier, review_required, categorized = quality
    logger.info(" Qualité CPAM : tier %s, requires_review=%s, %d warnings",
                tier, review_required, len(categorized))

    formatted = _format_response(
        response,
        ref_warnings=merged_warnings,
        quality_tier=tier,
        categorized_warnings=categorized,
    )
    logger.info(" Contre-argumentation générée (%d caractères)", len(formatted))
    return formatted, response, draft.rag_sources
# ---------------------------------------------------------------------------
# Orchestrateur batché — minimise les swaps de modèle VRAM
# ---------------------------------------------------------------------------
def generate_cpam_responses_batched(
    dossier: DossierMedical,
    controles: list[ControleCPAM],
) -> None:
    """Generate the counter-arguments for all CPAM controls of a dossier in one batch.

    The work is ordered by phase instead of by control:
      1. _phase_generate on every control,
      2. _phase_validate on every control that produced a result,
      3. _phase_correct on every control flagged needs_correction,
      4. _phase_finalize on every control (quality + formatting).

    Args:
        dossier: The medical record (merged or single).
        controles: CPAM controls to process; an empty list is a no-op.

    Side effects:
        Sets contre_argumentation, response_data and sources_reponse on each control.
    """
    if not controles:
        return

    started = time.time()
    total = len(controles)
    logger.info("CPAM batch : %d contrôle(s) — début du traitement batché", total)

    # One draft per control, in input order.
    drafts = [_CpamDraft(controle=c) for c in controles]

    # --- Phase 1: generation ---
    phase_started = time.time()
    logger.info("CPAM batch phase 1/3 : génération (%d contrôles, role=cpam)", total)
    for d in drafts:
        _phase_generate(dossier, d)
    logger.info("CPAM batch phase 1/3 terminée : %.1fs", time.time() - phase_started)

    # --- Phase 2: adversarial validation, only for drafts that have a result ---
    with_result = [d for d in drafts if d.result is not None]
    if with_result:
        phase_started = time.time()
        logger.info("CPAM batch phase 2/3 : validation (%d contrôles, role=validation)",
                    len(with_result))
        for d in with_result:
            _phase_validate(dossier, d)
        logger.info("CPAM batch phase 2/3 terminée : %.1fs", time.time() - phase_started)

    # --- Phase 3: correction + re-validation, only for flagged drafts ---
    flagged = [d for d in drafts if d.needs_correction]
    if flagged:
        phase_started = time.time()
        logger.info("CPAM batch phase 3/3 : correction (%d contrôles nécessitent correction)",
                    len(flagged))
        for d in flagged:
            _phase_correct(dossier, d)
        logger.info("CPAM batch phase 3/3 terminée : %.1fs", time.time() - phase_started)

    # --- Finalisation: write the outcome back onto each control ---
    for d in drafts:
        outcome = _phase_finalize(dossier, d)
        target = d.controle
        target.contre_argumentation, target.response_data, target.sources_reponse = outcome

    duration = time.time() - started
    # `total` is at least 1 here because of the early return on an empty list.
    logger.info(
        "CPAM batch terminé : %d contrôle(s) en %.1fs (%.1fs/contrôle)",
        total, duration, duration / total,
    )

View File

@@ -610,20 +610,18 @@ def main(input_path: str | None = None) -> None:
merged = None merged = None
# Contrôle CPAM : enrichir le dossier principal (fusionné ou dernier) # Contrôle CPAM : enrichir le dossier principal (fusionné ou dernier)
# Utilise le mode batché pour regrouper les appels LLM par modèle
# et minimiser les swaps VRAM (de ~4N swaps à ~2-4 swaps)
if cpam_data and subdir: if cpam_data and subdir:
try: try:
from .control.cpam_parser import match_dossier_ogc from .control.cpam_parser import match_dossier_ogc
controles = match_dossier_ogc(subdir, cpam_data) controles = match_dossier_ogc(subdir, cpam_data)
if controles: if controles:
from .control.cpam_response import generate_cpam_response from .control.cpam_response import generate_cpam_responses_batched
target = merged if merged else (group_dossiers[-1] if group_dossiers else None) target = merged if merged else (group_dossiers[-1] if group_dossiers else None)
if target: if target:
logger.info(" CPAM : %d contrôle(s) pour %s", len(controles), subdir) logger.info(" CPAM : %d contrôle(s) pour %s", len(controles), subdir)
for ctrl in controles: generate_cpam_responses_batched(target, controles)
text, response_data, sources = generate_cpam_response(target, ctrl)
ctrl.contre_argumentation = text
ctrl.response_data = response_data
ctrl.sources_reponse = sources
target.controles_cpam = controles target.controles_cpam = controles
except Exception: except Exception:
logger.exception("Erreur CPAM pour %s", subdir) logger.exception("Erreur CPAM pour %s", subdir)