- Add CorrectionPackIntegration class to bridge learning components - Modify TrainingDataCollector to auto-propagate corrections to packs - Modify FeedbackProcessor to capture corrections on INCORRECT/PARTIAL feedback - Add convenience functions: get_correction_pack_integration(), capture_coaching_correction() - Add 19 integration tests (all passing) Corrections made during COACHING mode are now automatically captured into a dedicated "auto_captured_corrections" pack for cross-workflow reuse. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
289 lines
10 KiB
Python
289 lines
10 KiB
Python
"""
|
|
Correction Pack Integration - Bridge between learning components and Correction Packs.
|
|
|
|
Automatically captures corrections from COACHING mode and propagates them
|
|
to the Correction Pack system for cross-workflow learning.
|
|
"""
|
|
|
|
import logging
|
|
from typing import Dict, Any, Optional, Callable
|
|
from datetime import datetime
|
|
|
|
from .correction_pack_service import CorrectionPackService
|
|
from .models import CorrectionType
|
|
|
|
# Module-level logger named after this module (``__name__``), so its output
# can be filtered/configured through the standard logging hierarchy.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class CorrectionPackIntegration:
    """
    Integration layer between learning components and Correction Packs.

    Captures corrections from TrainingDataCollector and FeedbackProcessor
    and automatically forwards them to the Correction Pack system.

    ``capture_correction`` is best-effort: any error raised while resolving
    the target pack or talking to the CorrectionPackService is logged (with
    traceback) and reported to the caller as ``None`` instead of being
    raised, so a pack-storage problem never breaks the calling component.
    """

    # Default pack name for auto-captured corrections
    DEFAULT_PACK_NAME = "auto_captured_corrections"
    DEFAULT_PACK_DESCRIPTION = "Corrections automatiquement capturées pendant les sessions COACHING"

    def __init__(
        self,
        service: Optional[CorrectionPackService] = None,
        auto_create_pack: bool = True,
        default_pack_id: Optional[str] = None
    ):
        """
        Initialize the integration.

        Args:
            service: CorrectionPackService instance (lazy-loaded if None)
            auto_create_pack: Create the default pack if it doesn't exist
            default_pack_id: ID of pack to use (auto-detected if None)
        """
        self._service = service
        self._auto_create_pack = auto_create_pack
        self._default_pack_id = default_pack_id
        # Set once a usable pack ID has been resolved; see _ensure_initialized.
        self._initialized = False

    @staticmethod
    def _get_field(obj: Any, name: str) -> Any:
        """Read ``name`` from a service result that may be a dict or an object."""
        # The service layer may hand back plain dicts or model objects; accept both.
        if isinstance(obj, dict):
            return obj.get(name)
        return getattr(obj, name)

    @property
    def service(self) -> CorrectionPackService:
        """Return the CorrectionPackService, creating a default one on first access."""
        if self._service is None:
            self._service = CorrectionPackService()
        return self._service

    def _ensure_initialized(self) -> str:
        """
        Ensure the default pack exists and return its ID.

        Resolution order:
          1. the configured ``default_pack_id``, if the service still finds it;
          2. an existing pack named ``DEFAULT_PACK_NAME``;
          3. a newly created pack, if ``auto_create_pack`` is enabled.

        Returns:
            Pack ID to use for storing corrections

        Raises:
            RuntimeError: if no pack exists and auto-creation is disabled, or
                if the service fails to create a pack with a usable ID.
        """
        if self._initialized and self._default_pack_id:
            return self._default_pack_id

        # If a specific pack ID was provided, verify it still exists.
        if self._default_pack_id:
            if self.service.get_pack(self._default_pack_id):
                self._initialized = True
                return self._default_pack_id
            logger.warning("Pack %s not found, creating default", self._default_pack_id)

        # Reuse an existing auto-capture pack instead of creating duplicates.
        for pack in self.service.list_packs() or []:
            if self._get_field(pack, 'name') == self.DEFAULT_PACK_NAME:
                self._default_pack_id = self._get_field(pack, 'id')
                self._initialized = True
                logger.info("Using existing auto-capture pack: %s", self._default_pack_id)
                return self._default_pack_id

        if not self._auto_create_pack:
            raise RuntimeError("No default pack available and auto_create_pack is disabled")

        pack = self.service.create_pack(
            name=self.DEFAULT_PACK_NAME,
            description=self.DEFAULT_PACK_DESCRIPTION,
            tags=["auto", "coaching"],
            category="learning"
        )
        pack_id = self._get_field(pack, 'id') if pack else None
        if not pack_id:
            # Don't cache a missing ID; the next call will retry resolution.
            raise RuntimeError("Failed to create default correction pack")

        self._default_pack_id = pack_id
        self._initialized = True
        logger.info("Created auto-capture pack: %s", pack_id)
        return self._default_pack_id

    def capture_correction(
        self,
        correction_data: Dict[str, Any],
        session_id: Optional[str] = None,
        workflow_id: Optional[str] = None,
        node_id: Optional[str] = None,
        pack_id: Optional[str] = None
    ) -> Optional[str]:
        """
        Capture a correction and add it to a Correction Pack.

        Explicit arguments take precedence over the same keys inside
        ``correction_data``.

        Args:
            correction_data: Correction dict (from TrainingDataCollector format)
            session_id: Source session ID ('unknown' if neither this nor
                correction_data provides one)
            workflow_id: Source workflow ID
            node_id: Source node ID
            pack_id: Specific pack to add to (uses default if None)

        Returns:
            Correction ID if successfully added, None otherwise (including on
            any error, which is logged rather than raised)
        """
        try:
            target_pack_id = pack_id or self._ensure_initialized()

            # Enrich correction data with source info.
            enriched = {
                **correction_data,
                'node_id': node_id or correction_data.get('node_id'),
            }

            # `or 'unknown'` rather than a .get() default, so an explicit
            # None stored in the dict still falls back to 'unknown'.
            final_session_id = session_id or correction_data.get('session_id') or 'unknown'
            final_workflow_id = workflow_id or correction_data.get('workflow_id')

            correction = self.service.add_correction_from_session(
                pack_id=target_pack_id,
                session_correction=enriched,
                session_id=final_session_id,
                workflow_id=final_workflow_id
            )

            if not correction:
                logger.warning("Failed to add correction to pack")
                return None

            correction_id = self._get_field(correction, 'id')
            logger.info(
                "Captured correction %s from session=%s, workflow=%s",
                correction_id, final_session_id, final_workflow_id
            )
            return correction_id

        except Exception as e:
            # Best-effort boundary: never propagate pack-storage failures into
            # the learning component that reported the correction.
            logger.exception("Error capturing correction: %s", e)
            return None

    def capture_feedback_correction(
        self,
        workflow_id: str,
        execution_id: str,
        corrections: Dict[str, Any],
        context: Optional[Dict[str, Any]] = None
    ) -> Optional[str]:
        """
        Capture a correction from FeedbackProcessor.

        Args:
            workflow_id: Workflow ID
            execution_id: Execution ID (used as session_id)
            corrections: Correction dict from feedback
            context: Additional context (failure_reason, element_type,
                node_id, ...); used only as a fallback for keys missing
                from ``corrections``

        Returns:
            Correction ID if successfully added, None if ``corrections`` is
            empty or the capture failed
        """
        if not corrections:
            return None

        ctx = context or {}

        # Build correction data in standard format; accept both the
        # corrected_* and new_* key spellings.
        correction_data = {
            'action_type': corrections.get('action_type', corrections.get('type', 'unknown')),
            'element_type': corrections.get('element_type', ctx.get('element_type', 'unknown')),
            'failure_reason': corrections.get('failure_reason', ctx.get('failure_reason', '')),
            'correction_type': self._infer_correction_type(corrections),
            'original_target': corrections.get('original_target'),
            'corrected_target': corrections.get('corrected_target', corrections.get('new_target')),
            'original_params': corrections.get('original_params'),
            'corrected_params': corrections.get('corrected_params', corrections.get('new_params')),
            'description': corrections.get('description', f"Correction from feedback {execution_id}")
        }

        return self.capture_correction(
            correction_data=correction_data,
            session_id=execution_id,
            workflow_id=workflow_id,
            # Forward the node if feedback carries one (None keeps old behavior).
            node_id=corrections.get('node_id') or ctx.get('node_id')
        )

    def _infer_correction_type(self, corrections: Dict[str, Any]) -> str:
        """
        Infer the correction type value from correction data.

        An explicit ``correction_type`` wins; otherwise the first matching
        group of truthy keys decides (target, params, timing, coordinates),
        falling back to ``CorrectionType.OTHER``.
        """
        if corrections.get('correction_type'):
            return corrections['correction_type']

        if corrections.get('corrected_target') or corrections.get('new_target'):
            return CorrectionType.TARGET_CHANGE.value
        if corrections.get('corrected_params') or corrections.get('new_params'):
            return CorrectionType.PARAMETER_CHANGE.value
        if corrections.get('wait_time') or corrections.get('timing'):
            return CorrectionType.TIMING_ADJUST.value
        if corrections.get('coordinates') or corrections.get('offset'):
            return CorrectionType.COORDINATES_ADJUST.value

        return CorrectionType.OTHER.value

    def create_hook_for_collector(self) -> Callable[[Dict[str, Any]], None]:
        """
        Create a hook function for TrainingDataCollector.

        Returns:
            Callback that forwards each correction dict to capture_correction
            and discards the result

        Usage:
            integration = CorrectionPackIntegration()
            collector.on_correction = integration.create_hook_for_collector()
        """
        def hook(correction: Dict[str, Any]) -> None:
            self.capture_correction(correction)
        return hook

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get statistics about the auto-capture pack.

        Returns:
            The service's ``get_pack_statistics`` result for the default pack,
            or ``{'error': <message>}`` if the pack cannot be resolved or the
            service call fails.
        """
        try:
            pack_id = self._ensure_initialized()
            return self.service.get_pack_statistics(pack_id)
        except Exception as e:
            logger.exception("Error getting statistics: %s", e)
            return {'error': str(e)}
|
|
|
|
|
|
# Global instance for easy access.
# Module-level singleton; stays None until get_correction_pack_integration()
# creates it on first call.
_global_integration: Optional[CorrectionPackIntegration] = None
|
|
|
|
|
|
def get_correction_pack_integration() -> CorrectionPackIntegration:
    """
    Return the module-level CorrectionPackIntegration singleton.

    The instance is built on first use with default settings (service
    lazy-loaded, default pack auto-created when needed) and the same
    object is returned on every later call.

    Returns:
        The shared CorrectionPackIntegration instance
    """
    global _global_integration
    if _global_integration is not None:
        return _global_integration
    _global_integration = CorrectionPackIntegration()
    return _global_integration
|
|
|
|
|
|
def capture_coaching_correction(
    correction_data: Dict[str, Any],
    session_id: Optional[str] = None,
    workflow_id: Optional[str] = None,
    node_id: Optional[str] = None,
    pack_id: Optional[str] = None
) -> Optional[str]:
    """
    Convenience function to capture a correction via the global integration.

    Args:
        correction_data: Correction dict
        session_id: Source session ID
        workflow_id: Source workflow ID
        node_id: Source node ID
        pack_id: Specific pack to add to (the integration's default pack
            if None)

    Returns:
        Correction ID if successful, None otherwise
    """
    integration = get_correction_pack_integration()
    return integration.capture_correction(
        correction_data=correction_data,
        session_id=session_id,
        workflow_id=workflow_id,
        node_id=node_id,
        pack_id=pack_id
    )
|