""" Correction Pack Integration - Bridge between learning components and Correction Packs. Automatically captures corrections from COACHING mode and propagates them to the Correction Pack system for cross-workflow learning. """ import logging from typing import Dict, Any, Optional, Callable from datetime import datetime from .correction_pack_service import CorrectionPackService from .models import CorrectionType logger = logging.getLogger(__name__) class CorrectionPackIntegration: """ Integration layer between learning components and Correction Packs. Captures corrections from TrainingDataCollector and FeedbackProcessor and automatically forwards them to the Correction Pack system. """ # Default pack name for auto-captured corrections DEFAULT_PACK_NAME = "auto_captured_corrections" DEFAULT_PACK_DESCRIPTION = "Corrections automatiquement capturées pendant les sessions COACHING" def __init__( self, service: Optional[CorrectionPackService] = None, auto_create_pack: bool = True, default_pack_id: Optional[str] = None ): """ Initialize the integration. Args: service: CorrectionPackService instance (lazy-loaded if None) auto_create_pack: Create default pack if it doesn't exist default_pack_id: ID of pack to use (auto-detected if None) """ self._service = service self._auto_create_pack = auto_create_pack self._default_pack_id = default_pack_id self._initialized = False @property def service(self) -> CorrectionPackService: """Lazy-load service.""" if self._service is None: self._service = CorrectionPackService() return self._service def _ensure_initialized(self) -> str: """ Ensure the default pack exists and return its ID. Returns: Pack ID to use for storing corrections """ if self._initialized and self._default_pack_id: return self._default_pack_id # If specific pack ID provided, verify it exists if self._default_pack_id: pack = self.service.get_pack(self._default_pack_id) if pack: self._initialized = True return self._default_pack_id else: logger.warning(f"Pack {self._default_pack_id} not found, creating default") # Search for existing auto-capture pack packs = self.service.list_packs() for pack in packs: # Handle both dict and object formats pack_name = pack.get('name') if isinstance(pack, dict) else pack.name pack_id = pack.get('id') if isinstance(pack, dict) else pack.id if pack_name == self.DEFAULT_PACK_NAME: self._default_pack_id = pack_id self._initialized = True logger.info(f"Using existing auto-capture pack: {pack_id}") return self._default_pack_id # Create new pack if auto_create enabled if self._auto_create_pack: pack = self.service.create_pack( name=self.DEFAULT_PACK_NAME, description=self.DEFAULT_PACK_DESCRIPTION, tags=["auto", "coaching"], category="learning" ) # Handle both dict and object formats pack_id = pack.get('id') if isinstance(pack, dict) else pack.id self._default_pack_id = pack_id self._initialized = True logger.info(f"Created auto-capture pack: {pack_id}") return self._default_pack_id raise RuntimeError("No default pack available and auto_create_pack is disabled") def capture_correction( self, correction_data: Dict[str, Any], session_id: Optional[str] = None, workflow_id: Optional[str] = None, node_id: Optional[str] = None, pack_id: Optional[str] = None ) -> Optional[str]: """ Capture a correction and add it to a Correction Pack. Args: correction_data: Correction dict (from TrainingDataCollector format) session_id: Source session ID workflow_id: Source workflow ID node_id: Source node ID pack_id: Specific pack to add to (uses default if None) Returns: Correction ID if successfully added, None otherwise """ try: target_pack_id = pack_id or self._ensure_initialized() # Enrich correction data with source info enriched = { **correction_data, 'node_id': node_id or correction_data.get('node_id') } # Determine session_id final_session_id = session_id or correction_data.get('session_id', 'unknown') final_workflow_id = workflow_id or correction_data.get('workflow_id') # Add to pack correction = self.service.add_correction_from_session( pack_id=target_pack_id, session_correction=enriched, session_id=final_session_id, workflow_id=final_workflow_id ) if correction: # Handle both dict and object formats correction_id = correction.get('id') if isinstance(correction, dict) else correction.id logger.info( f"Captured correction {correction_id} from " f"session={final_session_id}, workflow={final_workflow_id}" ) return correction_id else: logger.warning("Failed to add correction to pack") return None except Exception as e: logger.error(f"Error capturing correction: {e}") return None def capture_feedback_correction( self, workflow_id: str, execution_id: str, corrections: Dict[str, Any], context: Optional[Dict[str, Any]] = None ) -> Optional[str]: """ Capture a correction from FeedbackProcessor. Args: workflow_id: Workflow ID execution_id: Execution ID (used as session_id) corrections: Correction dict from feedback context: Additional context (failure_reason, element_type, etc.) Returns: Correction ID if successfully added """ if not corrections: return None # Build correction data in standard format correction_data = { 'action_type': corrections.get('action_type', corrections.get('type', 'unknown')), 'element_type': corrections.get('element_type', context.get('element_type', 'unknown') if context else 'unknown'), 'failure_reason': corrections.get('failure_reason', context.get('failure_reason', '') if context else ''), 'correction_type': self._infer_correction_type(corrections), 'original_target': corrections.get('original_target'), 'corrected_target': corrections.get('corrected_target', corrections.get('new_target')), 'original_params': corrections.get('original_params'), 'corrected_params': corrections.get('corrected_params', corrections.get('new_params')), 'description': corrections.get('description', f"Correction from feedback {execution_id}") } return self.capture_correction( correction_data=correction_data, session_id=execution_id, workflow_id=workflow_id ) def _infer_correction_type(self, corrections: Dict[str, Any]) -> str: """Infer correction type from correction data.""" if corrections.get('correction_type'): return corrections['correction_type'] # Infer from data if corrections.get('corrected_target') or corrections.get('new_target'): return CorrectionType.TARGET_CHANGE.value if corrections.get('corrected_params') or corrections.get('new_params'): return CorrectionType.PARAMETER_CHANGE.value if corrections.get('wait_time') or corrections.get('timing'): return CorrectionType.TIMING_ADJUST.value if corrections.get('coordinates') or corrections.get('offset'): return CorrectionType.COORDINATES_ADJUST.value return CorrectionType.OTHER.value def create_hook_for_collector(self) -> Callable[[Dict[str, Any]], None]: """ Create a hook function for TrainingDataCollector. Returns: Callback function that captures corrections Usage: integration = CorrectionPackIntegration() collector.on_correction = integration.create_hook_for_collector() """ def hook(correction: Dict[str, Any]) -> None: self.capture_correction(correction) return hook def get_statistics(self) -> Dict[str, Any]: """Get statistics about captured corrections.""" try: pack_id = self._ensure_initialized() return self.service.get_pack_statistics(pack_id) except Exception as e: logger.error(f"Error getting statistics: {e}") return {'error': str(e)} # Global instance for easy access _global_integration: Optional[CorrectionPackIntegration] = None def get_correction_pack_integration() -> CorrectionPackIntegration: """ Get the global CorrectionPackIntegration instance. Creates a new instance if one doesn't exist. Returns: CorrectionPackIntegration instance """ global _global_integration if _global_integration is None: _global_integration = CorrectionPackIntegration() return _global_integration def capture_coaching_correction( correction_data: Dict[str, Any], session_id: Optional[str] = None, workflow_id: Optional[str] = None, node_id: Optional[str] = None ) -> Optional[str]: """ Convenience function to capture a correction. Args: correction_data: Correction dict session_id: Source session ID workflow_id: Source workflow ID node_id: Source node ID Returns: Correction ID if successful """ integration = get_correction_pack_integration() return integration.capture_correction( correction_data=correction_data, session_id=session_id, workflow_id=workflow_id, node_id=node_id )