feat: batching CPAM par modèle pour réduire les swaps VRAM
Restructure le flux CPAM pour regrouper les appels LLM par modèle : - Phase 1 (cpam model) : extraction + argumentation de TOUS les contrôles - Phase 2 (validation model, 1 swap) : validation adversariale de tous - Phase 3 (cpam model, si correction nécessaire) : correction des échoués Nouvelle fonction generate_cpam_responses_batched() qui orchestre les 3 phases. L'ancienne generate_cpam_response() reste intacte (rétrocompatible, utilisée par le viewer pour la régénération unitaire). Structure intermédiaire _CpamDraft (dataclass) pour transporter l'état entre phases. Fonctions de phase extraites : _phase_generate, _phase_validate, _phase_correct, _phase_finalize. Gain estimé : de ~4*N swaps de modèle à 2-4 swaps (indépendant de N contrôles). Sur RTX 5070 avec des modèles 24-32b, chaque swap coûte ~10-15s. Pour 3 contrôles : économie estimée ~60-90s. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -4,6 +4,14 @@ Orchestrateur principal — délègue aux sous-modules :
|
||||
- cpam_rag : _search_rag_for_control(), _search_rag_queries()
|
||||
- cpam_context : _build_cpam_prompt(), _build_tagged_context(), _build_bio_summary(), etc.
|
||||
- cpam_validation : _validate_adversarial(), _validate_grounding(), _format_response(), etc.
|
||||
|
||||
Le flux batché (generate_cpam_responses_batched) regroupe les appels LLM par modèle
|
||||
pour minimiser les swaps VRAM sur GPU unique :
|
||||
Phase 1 (cpam model) : extraction + argumentation de TOUS les contrôles
|
||||
Phase 2 (validation model, 1 swap) : validation adversariale de tous
|
||||
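Phase 3 (cpam model + validation model) : correction des contrôles échoués ;
chaque tentative de correction est re-validée immédiatement, contrôle par
contrôle (voir _phase_correct), et non dans une phase 4 séparée
Gain : de ~4N swaps à ~2-4 swaps quand peu de contrôles nécessitent une
correction ; la phase 3 alterne les deux modèles pour chaque contrôle à
corriger, son coût croît donc avec le nombre de contrôles échoués.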
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -12,6 +20,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
@@ -42,6 +51,34 @@ from .cpam_validation import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Structure intermédiaire pour le batching
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class _CpamDraft:
    """Mutable per-control scratchpad carried through the batched CPAM phases.

    One draft is created per ``ControleCPAM``. Each phase function reads the
    fields written by the previous phase and fills in its own, so the
    orchestrator can run a given phase for every control before moving on.
    """

    # The CPAM control this draft belongs to (the only required field).
    controle: ControleCPAM

    # --- Filled by the generation phase (_phase_generate) ---
    # Structured extraction; stays None when the extraction pass failed.
    extraction: dict | None = None
    # True when the extraction pass returned None.
    degraded_pass1: bool = False
    # Raw RAG hits, and the same hits converted to RAGSource objects.
    sources: list[dict] = field(default_factory=list)
    rag_sources: list[RAGSource] = field(default_factory=list)
    # Argumentation prompt and the tag map built alongside it.
    prompt: str = ""
    tag_map: dict[str, str] = field(default_factory=dict)
    # LLM answer; None means no backend produced a response.
    result: dict | None = None

    # --- Validation outcome ---
    ref_warnings: list[str] = field(default_factory=list)
    grounding_warnings: list[str] = field(default_factory=list)
    code_warnings: list[str] = field(default_factory=list)
    adversarial_warnings: list[str] = field(default_factory=list)
    # Adversarial validator verdict (read for "coherent", "erreurs",
    # "score_confiance" by the phase functions).
    validation: dict | None = None

    # --- Final state ---
    # Set by _phase_validate; selects the drafts sent to _phase_correct.
    needs_correction: bool = False
    # NOTE(review): `text` and `response_data` are not written by any phase
    # function visible in this module section — confirm they are still needed.
    text: str = ""
    response_data: dict | None = None
|
||||
|
||||
|
||||
def _save_version(
|
||||
dossier: DossierMedical,
|
||||
controle: ControleCPAM,
|
||||
@@ -358,3 +395,343 @@ def generate_cpam_response(
|
||||
logger.info(" Contre-argumentation générée (%d caractères)", len(text))
|
||||
|
||||
return text, result, rag_sources
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fonctions de phase (utilisées par le batching)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _phase_generate(
    dossier: DossierMedical,
    draft: _CpamDraft,
) -> None:
    """Phase 1 — extraction, RAG lookup and argumentation (role=cpam).

    Fills ``draft.extraction``, ``draft.degraded_pass1``, ``draft.sources``,
    ``draft.prompt``, ``draft.tag_map``, ``draft.rag_sources``, the three
    deterministic warning lists and finally ``draft.result``.

    ``draft.result`` is left at ``None`` when neither Ollama nor the
    Anthropic fallback returned an answer. The adversarial validator is not
    called from this phase.

    Args:
        dossier: Medical dossier the control applies to; an alert is appended
            to ``dossier.alertes_codage`` when the extraction pass fails.
        draft: Per-control state, updated in place.
    """
    ctrl = draft.controle
    logger.info("CPAM batch : génération pour OGC %d — %s",
                ctrl.numero_ogc, ctrl.titre)

    # Snapshot the current state before anything is regenerated.
    _save_version(dossier, ctrl)

    # Structured extraction (role=cpam); a None result switches to degraded mode.
    extraction = _extraction_pass(dossier, ctrl)
    draft.extraction = extraction
    draft.degraded_pass1 = extraction is None
    if draft.degraded_pass1:
        dossier.alertes_codage.append(
            "CPAM: passe 1 (extraction structurée) échouée → mode dégradé"
        )

    # RAG lookup.
    draft.sources = _search_rag_for_control(ctrl, dossier)
    logger.info(" RAG : %d sources trouvées", len(draft.sources))

    # Prompt construction.
    draft.prompt, draft.tag_map = _build_cpam_prompt(
        dossier, ctrl, draft.sources, draft.extraction
    )

    # Argumentation: Ollama first, Anthropic as fallback.
    started = time.time()
    answer = call_ollama(draft.prompt, temperature=0.1, max_tokens=8000, role="cpam")
    if answer is None:
        logger.info(" Ollama indisponible → fallback Anthropic Haiku")
        answer = call_anthropic(draft.prompt, temperature=0.1, max_tokens=8000)
        success_msg = " [CPAM-GEN] %.1fs — OGC %s — contre-argumentation via Anthropic"
    else:
        success_msg = " [CPAM-GEN] %.1fs — OGC %s — contre-argumentation via Ollama"
    if answer is not None:
        logger.info(success_msg, time.time() - started, ctrl.numero_ogc)

    # Convert the raw RAG hits; done even when generation failed so the
    # caller still receives the sources.
    converted = []
    for hit in draft.sources:
        converted.append(
            RAGSource(
                document=hit.get("document", ""),
                page=hit.get("page"),
                code=hit.get("code"),
                extrait=hit.get("extrait", "")[:200],
            )
        )
    draft.rag_sources = converted

    if answer is None:
        logger.warning(" LLM non disponible — contre-argumentation non générée")
        draft.result = None
        return

    # Record degraded mode on the answer itself.
    if draft.degraded_pass1:
        flags = answer.setdefault("quality_flags", {})
        flags["cpam_pass1_failed"] = True
        flags["degraded_mode"] = True

    # Deterministic sanitisation and guardian (no LLM call).
    removed_codes = _sanitize_unauthorized_codes(answer, dossier, ctrl)
    if removed_codes:
        logger.info(" CPAM : %d code(s) hors périmètre supprimé(s) : %s",
                    len(removed_codes), ", ".join(removed_codes))
    answer = _guardian_deterministic(answer, dossier, ctrl, draft.tag_map)

    # Deterministic validations (no LLM call).
    draft.ref_warnings = _validate_references(answer, draft.sources)
    if draft.ref_warnings:
        logger.warning(" CPAM : %d référence(s) non vérifiable(s)", len(draft.ref_warnings))

    draft.grounding_warnings = _validate_grounding(answer, draft.tag_map)
    if draft.grounding_warnings:
        logger.warning(" CPAM : %d preuve(s) non traçable(s)", len(draft.grounding_warnings))

    draft.code_warnings = _validate_codes_in_response(answer, dossier, ctrl)
    if draft.code_warnings:
        logger.warning(" CPAM : %d code(s) hors périmètre", len(draft.code_warnings))

    # Published last: an exception above leaves draft.result at None.
    draft.result = answer
|
||||
|
||||
|
||||
def _phase_validate(
    dossier: DossierMedical,
    draft: _CpamDraft,
) -> None:
    """Phase 2 — adversarial validation (role=validation).

    Fills ``draft.validation``, appends to ``draft.adversarial_warnings`` and
    sets ``draft.needs_correction``. Does nothing when ``draft.result`` is
    ``None`` (generation failed).

    Differences from a naive per-control implementation:
      * the "adversarial validation disabled" alert is added to the dossier
        at most once, because this function runs once per control in a batch
        and the alert is dossier-wide;
      * a missing, null or non-numeric ``score_confiance`` is treated like an
        absent score (10, i.e. no correction) instead of raising
        ``TypeError`` on the ``<= 5`` comparison.

    Args:
        dossier: Medical dossier; may receive one alert in
            ``dossier.alertes_codage``.
        draft: Per-control state, updated in place.
    """
    if draft.result is None:
        return

    controle = draft.controle

    # LOGIC-3: are the generation and validation models identical?
    from ..config import check_adversarial_model_config
    same_model, _model_msg = check_adversarial_model_config()
    if same_model:
        draft.result.setdefault("quality_flags", {})
        draft.result["quality_flags"]["adversarial_disabled_same_model"] = True
        alert = "Validation adversariale désactivée (modèles identiques)"
        # Dossier-level alert: do not repeat it for every control of the batch.
        if alert not in dossier.alertes_codage:
            dossier.alertes_codage.append(alert)

    draft.validation = _validate_adversarial(draft.result, draft.tag_map, controle)
    validation = draft.validation
    incoherent = bool(validation) and not validation.get("coherent", True)

    if incoherent:
        # `or []` guards against an explicit null in the validator output.
        erreurs = validation.get("erreurs") or []
        score = validation.get("score_confiance", "?")
        for e in erreurs:
            if isinstance(e, str) and e.strip():
                draft.adversarial_warnings.append(f"Incohérence détectée : {e}")
        if draft.adversarial_warnings:
            draft.adversarial_warnings.append(f"Score de confiance : {score}/10")

    # The validator output comes from an LLM: coerce the score defensively.
    score_value = 10.0
    if validation:
        try:
            score_value = float(validation.get("score_confiance", 10))
        except (TypeError, ValueError):
            score_value = 10.0

    # Decide whether a correction pass is required.
    max_corrections = int(os.environ.get("T2A_CPAM_MAX_CORRECTIONS", "2"))
    draft.needs_correction = bool(
        max_corrections > 0
        and incoherent
        and score_value <= 5
        and rule_enabled("RULE-CPAM-CORRECTION-LOOP")
    )
|
||||
|
||||
|
||||
def _phase_correct(
    dossier: DossierMedical,
    draft: _CpamDraft,
) -> None:
    """Phase 3 — correction (role=cpam) followed by re-validation.

    Runs up to ``T2A_CPAM_MAX_CORRECTIONS`` attempts (default 2). Each attempt
    asks the generation model for a corrected answer, re-runs the adversarial
    validator on it, and keeps the correction only if its confidence score is
    strictly higher than the current one. The loop stops at the first
    rejected or missing correction, or as soon as the current validation no
    longer meets the correction criteria.

    Correction and re-validation are coupled inside the loop, so this
    function calls both the generation and the validation helpers for each
    draft it processes.

    Fixes relative to the previous version:
      * flags set on the answer by the earlier phases (``cpam_pass1_failed``,
        ``degraded_mode``, ``adversarial_disabled_same_model``) are copied
        onto an accepted correction instead of being dropped with the
        replaced answer;
      * non-numeric / null ``score_confiance`` values no longer raise
        ``TypeError`` in the comparisons;
      * codes removed by the sanitiser are logged, as in the generation phase.

    Args:
        dossier: Medical dossier, passed to the deterministic checks.
        draft: Per-control state, updated in place. Ignored unless
            ``draft.needs_correction`` is true and ``draft.result`` is set.
    """
    if draft.result is None or not draft.needs_correction:
        return

    controle = draft.controle
    max_corrections = int(os.environ.get("T2A_CPAM_MAX_CORRECTIONS", "2"))

    # Flags written by _phase_generate / _phase_validate that describe the
    # pipeline run, not one particular answer.
    pipeline_flag_keys = (
        "cpam_pass1_failed",
        "degraded_mode",
        "adversarial_disabled_same_model",
    )

    def _score(validation, default):
        """Return the validator confidence score as a float, or *default*."""
        if not validation:
            return float(default)
        try:
            return float(validation.get("score_confiance", default))
        except (TypeError, ValueError):
            return float(default)

    for attempt in range(max_corrections):
        validation = draft.validation
        if not (validation
                and not validation.get("coherent", True)
                and _score(validation, 10) <= 5
                and rule_enabled("RULE-CPAM-CORRECTION-LOOP")):
            break

        erreurs_v = validation.get("erreurs") or []
        logger.warning(" Score adversarial %s/10 — correction %d/%d (%d erreur(s))",
                       validation.get("score_confiance"),
                       attempt + 1, max_corrections, len(erreurs_v))

        correction_prompt = _build_correction_prompt(
            draft.prompt, draft.result, validation
        )
        t_corr = time.time()
        corrected = call_ollama(
            correction_prompt, temperature=0.0, max_tokens=16000, role="cpam"
        )
        if corrected is None:
            corrected = call_anthropic(
                correction_prompt, temperature=0.0, max_tokens=16000
            )
        logger.info(" [CPAM-CORR] %.1fs — OGC %s — correction %d/%d",
                    time.time() - t_corr, controle.numero_ogc,
                    attempt + 1, max_corrections)

        if not corrected:
            break

        validation2 = _validate_adversarial(corrected, draft.tag_map, controle)
        # Raw values are kept for the log lines; floats drive the decision.
        raw_score2 = validation2.get("score_confiance", 0) if validation2 else 0
        raw_score1 = validation.get("score_confiance", 0)

        if _score(validation2, 0) <= _score(validation, 0):
            logger.warning(" Correction %d rejetée (score %s → %s)",
                           attempt + 1, raw_score1, raw_score2)
            break

        logger.info(" Correction %d acceptée (score %s → %s)",
                    attempt + 1, raw_score1, raw_score2)

        # Carry the pipeline flags over to the replacement answer.
        previous_flags = draft.result.get("quality_flags")
        if isinstance(previous_flags, dict) and isinstance(corrected, dict):
            kept = {
                key: previous_flags[key]
                for key in pipeline_flag_keys
                if key in previous_flags
            }
            if kept:
                new_flags = corrected.get("quality_flags")
                if not isinstance(new_flags, dict):
                    new_flags = {}
                    corrected["quality_flags"] = new_flags
                new_flags.update(kept)

        draft.result = corrected
        draft.validation = validation2

        # Re-run the deterministic checks on the accepted correction.
        sanitized = _sanitize_unauthorized_codes(draft.result, dossier, controle)
        if sanitized:
            logger.info(" CPAM : %d code(s) hors périmètre supprimé(s) : %s",
                        len(sanitized), ", ".join(sanitized))
        draft.result = _guardian_deterministic(
            draft.result, dossier, controle, draft.tag_map
        )
        draft.ref_warnings = _validate_references(draft.result, draft.sources)
        draft.grounding_warnings = _validate_grounding(draft.result, draft.tag_map)
        draft.code_warnings = _validate_codes_in_response(
            draft.result, dossier, controle
        )

        # Rebuild the adversarial warnings from the new verdict only.
        draft.adversarial_warnings = []
        if draft.validation and not draft.validation.get("coherent", True):
            for e in draft.validation.get("erreurs") or []:
                if isinstance(e, str) and e.strip():
                    draft.adversarial_warnings.append(f"Incohérence détectée : {e}")
            if draft.adversarial_warnings:
                draft.adversarial_warnings.append(
                    f"Score de confiance : {draft.validation.get('score_confiance', '?')}/10"
                )
|
||||
|
||||
|
||||
def _phase_finalize(
    dossier: DossierMedical,
    draft: _CpamDraft,
) -> tuple[str, dict | None, list[RAGSource]]:
    """Final phase — quality assessment and formatting (no LLM call).

    Writes ``quality_tier``, ``requires_review`` and ``quality_warnings`` on
    the draft's control, then formats the answer.

    Args:
        dossier: Medical dossier, used to assess evidence strength.
        draft: Per-control state produced by the earlier phases.

    Returns:
        ``(text, result, rag_sources)`` — the same shape as
        ``generate_cpam_response()``. When no answer was generated the tuple
        is ``("", None, draft.rag_sources)`` and the control is left
        untouched.
    """
    ctrl = draft.controle
    answer = draft.result

    if answer is None:
        return "", None, draft.rag_sources

    # Every warning family, in the order they are shown to the reader.
    combined_warnings = [
        *draft.ref_warnings,
        *draft.grounding_warnings,
        *draft.code_warnings,
        *draft.adversarial_warnings,
    ]

    strength = _assess_dossier_strength(dossier)
    weak_dossier = strength["is_weak"]
    if weak_dossier:
        logger.info(" Dossier à preuves limitées (score %d/10) : %s",
                    strength["score"], ", ".join(strength["missing"]))

    tier, needs_review, cat_warnings = _assess_quality_tier(
        answer,
        draft.ref_warnings,
        draft.grounding_warnings,
        draft.code_warnings,
        draft.validation,
        is_weak_dossier=weak_dossier,
    )
    ctrl.quality_tier = tier
    ctrl.requires_review = needs_review
    ctrl.quality_warnings = cat_warnings
    logger.info(" Qualité CPAM : tier %s, requires_review=%s, %d warnings",
                tier, needs_review, len(cat_warnings))

    formatted = _format_response(
        answer,
        ref_warnings=combined_warnings,
        quality_tier=tier,
        categorized_warnings=cat_warnings,
    )
    logger.info(" Contre-argumentation générée (%d caractères)", len(formatted))

    return formatted, answer, draft.rag_sources
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Orchestrateur batché — minimise les swaps de modèle VRAM
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def generate_cpam_responses_batched(
    dossier: DossierMedical,
    controles: list[ControleCPAM],
) -> None:
    """Generate the counter-arguments for all CPAM controls, phase by phase.

    Instead of running the whole pipeline control by control, each phase is
    run for every control before the next phase starts:

      1. ``_phase_generate`` for every control (role=cpam);
      2. ``_phase_validate`` for every control that produced an answer
         (role=validation);
      3. ``_phase_correct`` for every control flagged ``needs_correction``
         (that function couples correction and re-validation per control);
      4. ``_phase_finalize`` for every control (no LLM call).

    Args:
        dossier: The medical dossier (merged or single).
        controles: CPAM controls to process; an empty list is a no-op.

    Side effects:
        Sets ``contre_argumentation``, ``response_data`` and
        ``sources_reponse`` on each control.
    """
    if not controles:
        return

    started = time.time()
    total = len(controles)
    logger.info("CPAM batch : %d contrôle(s) — début du traitement batché", total)

    # One draft per control, in input order.
    drafts = []
    for ctrl in controles:
        drafts.append(_CpamDraft(controle=ctrl))

    # --- Phase 1: generation (cpam model) ---
    phase_started = time.time()
    logger.info("CPAM batch phase 1/3 : génération (%d contrôles, role=cpam)", total)
    for draft in drafts:
        _phase_generate(dossier, draft)
    logger.info("CPAM batch phase 1/3 terminée : %.1fs", time.time() - phase_started)

    # --- Phase 2: adversarial validation (validation model) ---
    generated = [item for item in drafts if item.result is not None]
    if generated:
        phase_started = time.time()
        logger.info("CPAM batch phase 2/3 : validation (%d contrôles, role=validation)",
                    len(generated))
        for draft in generated:
            _phase_validate(dossier, draft)
        logger.info("CPAM batch phase 2/3 terminée : %.1fs", time.time() - phase_started)

    # --- Phase 3: correction + re-validation of the flagged drafts ---
    flagged = [item for item in drafts if item.needs_correction]
    if flagged:
        phase_started = time.time()
        logger.info("CPAM batch phase 3/3 : correction (%d contrôles nécessitent correction)",
                    len(flagged))
        for draft in flagged:
            _phase_correct(dossier, draft)
        logger.info("CPAM batch phase 3/3 terminée : %.1fs", time.time() - phase_started)

    # --- Finalisation (deterministic) ---
    for draft in drafts:
        outcome = _phase_finalize(dossier, draft)
        target = draft.controle
        target.contre_argumentation = outcome[0]
        target.response_data = outcome[1]
        target.sources_reponse = outcome[2]

    elapsed = time.time() - started
    # `total` is at least 1 here: the empty case returned early.
    logger.info(
        "CPAM batch terminé : %d contrôle(s) en %.1fs (%.1fs/contrôle)",
        total, elapsed, elapsed / total,
    )
|
||||
|
||||
10
src/main.py
10
src/main.py
@@ -610,20 +610,18 @@ def main(input_path: str | None = None) -> None:
|
||||
merged = None
|
||||
|
||||
# Contrôle CPAM : enrichir le dossier principal (fusionné ou dernier)
|
||||
# Utilise le mode batché pour regrouper les appels LLM par modèle
|
||||
# et minimiser les swaps VRAM (de ~4N swaps à ~2-4 swaps)
|
||||
if cpam_data and subdir:
|
||||
try:
|
||||
from .control.cpam_parser import match_dossier_ogc
|
||||
controles = match_dossier_ogc(subdir, cpam_data)
|
||||
if controles:
|
||||
from .control.cpam_response import generate_cpam_response
|
||||
from .control.cpam_response import generate_cpam_responses_batched
|
||||
target = merged if merged else (group_dossiers[-1] if group_dossiers else None)
|
||||
if target:
|
||||
logger.info(" CPAM : %d contrôle(s) pour %s", len(controles), subdir)
|
||||
for ctrl in controles:
|
||||
text, response_data, sources = generate_cpam_response(target, ctrl)
|
||||
ctrl.contre_argumentation = text
|
||||
ctrl.response_data = response_data
|
||||
ctrl.sources_reponse = sources
|
||||
generate_cpam_responses_batched(target, controles)
|
||||
target.controles_cpam = controles
|
||||
except Exception:
|
||||
logger.exception("Erreur CPAM pour %s", subdir)
|
||||
|
||||
Reference in New Issue
Block a user