From efb184fdb941f95a316ff2f556a803cb2a43c64e Mon Sep 17 00:00:00 2001 From: Dom Date: Sun, 18 Jan 2026 19:06:09 +0100 Subject: [PATCH] feat(corrections): Add automatic COACHING integration for Correction Packs - Add CorrectionPackIntegration class to bridge learning components - Modify TrainingDataCollector to auto-propagate corrections to packs - Modify FeedbackProcessor to capture corrections on INCORRECT/PARTIAL feedback - Add convenience functions: get_correction_pack_integration(), capture_coaching_correction() - Add 19 integration tests (all passing) Corrections made during COACHING mode are now automatically captured into a dedicated "auto_captured_corrections" pack for cross-workflow reuse. Co-Authored-By: Claude Opus 4.5 --- core/corrections/__init__.py | 11 +- core/corrections/integration.py | 288 +++++++++++++ core/learning/feedback_processor.py | 176 ++++++++ core/training/training_data_collector.py | 246 +++++++++++ tests/test_correction_pack_integration.py | 486 ++++++++++++++++++++++ 5 files changed, 1206 insertions(+), 1 deletion(-) create mode 100644 core/corrections/integration.py create mode 100644 core/learning/feedback_processor.py create mode 100644 core/training/training_data_collector.py create mode 100644 tests/test_correction_pack_integration.py diff --git a/core/corrections/__init__.py b/core/corrections/__init__.py index 778def9c9..0bc293cb7 100644 --- a/core/corrections/__init__.py +++ b/core/corrections/__init__.py @@ -16,6 +16,11 @@ from .models import ( from .correction_repository import CorrectionRepository from .aggregator import CorrectionAggregator from .correction_pack_service import CorrectionPackService +from .integration import ( + CorrectionPackIntegration, + get_correction_pack_integration, + capture_coaching_correction +) __all__ = [ 'CorrectionKey', @@ -26,5 +31,9 @@ __all__ = [ 'CorrectionPackMetadata', 'CorrectionRepository', 'CorrectionAggregator', - 'CorrectionPackService' + 'CorrectionPackService', + # Integration + 
"""
Correction Pack Integration - Bridge between learning components and Correction Packs.

Automatically captures corrections from COACHING mode and propagates them
to the Correction Pack system for cross-workflow learning.
"""

import logging
from typing import Dict, Any, Optional, Callable
from datetime import datetime

from .correction_pack_service import CorrectionPackService
from .models import CorrectionType

logger = logging.getLogger(__name__)


class CorrectionPackIntegration:
    """
    Integration layer between learning components and Correction Packs.

    Receives corrections produced by TrainingDataCollector / FeedbackProcessor
    and forwards them into a Correction Pack so they can be reused across
    workflows.
    """

    # Name/description of the pack that collects auto-captured corrections.
    DEFAULT_PACK_NAME = "auto_captured_corrections"
    DEFAULT_PACK_DESCRIPTION = "Corrections automatiquement capturées pendant les sessions COACHING"

    def __init__(
        self,
        service: Optional[CorrectionPackService] = None,
        auto_create_pack: bool = True,
        default_pack_id: Optional[str] = None
    ):
        """
        Set up the bridge.

        Args:
            service: CorrectionPackService instance (lazy-loaded if None)
            auto_create_pack: Create the default pack if it doesn't exist
            default_pack_id: ID of the pack to use (auto-detected if None)
        """
        self._service = service
        self._auto_create_pack = auto_create_pack
        self._default_pack_id = default_pack_id
        self._initialized = False

    @staticmethod
    def _field(record: Any, key: str) -> Any:
        """Read *key* from a record that may be a dict or an object.

        The service layer returns either format, so every access goes
        through this single helper.
        """
        return record.get(key) if isinstance(record, dict) else getattr(record, key)

    @property
    def service(self) -> CorrectionPackService:
        """Service accessor; instantiates a default service on first use."""
        if self._service is None:
            self._service = CorrectionPackService()
        return self._service

    def _ensure_initialized(self) -> str:
        """
        Resolve (and cache) the ID of the pack used for auto-captured corrections.

        Resolution order: cached ID, explicitly configured ID, an existing
        pack named DEFAULT_PACK_NAME, then a freshly created pack.

        Returns:
            Pack ID to store corrections into.

        Raises:
            RuntimeError: no pack available and auto-creation is disabled.
        """
        if self._initialized and self._default_pack_id:
            return self._default_pack_id

        # An explicitly configured pack ID is honoured only if it resolves.
        if self._default_pack_id:
            if self.service.get_pack(self._default_pack_id):
                self._initialized = True
                return self._default_pack_id
            logger.warning(f"Pack {self._default_pack_id} not found, creating default")

        # Reuse an existing auto-capture pack when one is present.
        for candidate in self.service.list_packs():
            if self._field(candidate, 'name') == self.DEFAULT_PACK_NAME:
                self._default_pack_id = self._field(candidate, 'id')
                self._initialized = True
                logger.info(f"Using existing auto-capture pack: {self._default_pack_id}")
                return self._default_pack_id

        if not self._auto_create_pack:
            raise RuntimeError("No default pack available and auto_create_pack is disabled")

        created = self.service.create_pack(
            name=self.DEFAULT_PACK_NAME,
            description=self.DEFAULT_PACK_DESCRIPTION,
            tags=["auto", "coaching"],
            category="learning"
        )
        self._default_pack_id = self._field(created, 'id')
        self._initialized = True
        logger.info(f"Created auto-capture pack: {self._default_pack_id}")
        return self._default_pack_id

    def capture_correction(
        self,
        correction_data: Dict[str, Any],
        session_id: Optional[str] = None,
        workflow_id: Optional[str] = None,
        node_id: Optional[str] = None,
        pack_id: Optional[str] = None
    ) -> Optional[str]:
        """
        Capture one correction into a Correction Pack.

        Args:
            correction_data: Correction dict (TrainingDataCollector format)
            session_id: Source session ID
            workflow_id: Source workflow ID
            node_id: Source node ID
            pack_id: Specific target pack (default pack when None)

        Returns:
            Correction ID on success, None otherwise (errors are logged,
            never raised — capture is best-effort).
        """
        try:
            destination = pack_id or self._ensure_initialized()

            # Enrich with source info; explicit arguments win over payload keys.
            payload = dict(correction_data)
            payload['node_id'] = node_id or correction_data.get('node_id')

            src_session = session_id or correction_data.get('session_id', 'unknown')
            src_workflow = workflow_id or correction_data.get('workflow_id')

            stored = self.service.add_correction_from_session(
                pack_id=destination,
                session_correction=payload,
                session_id=src_session,
                workflow_id=src_workflow
            )

            if not stored:
                logger.warning("Failed to add correction to pack")
                return None

            stored_id = self._field(stored, 'id')
            logger.info(
                f"Captured correction {stored_id} from "
                f"session={src_session}, workflow={src_workflow}"
            )
            return stored_id

        except Exception as e:
            logger.error(f"Error capturing correction: {e}")
            return None

    def capture_feedback_correction(
        self,
        workflow_id: str,
        execution_id: str,
        corrections: Dict[str, Any],
        context: Optional[Dict[str, Any]] = None
    ) -> Optional[str]:
        """
        Capture a correction coming from FeedbackProcessor.

        Args:
            workflow_id: Workflow ID
            execution_id: Execution ID (reused as the session ID)
            corrections: Correction dict from the feedback payload
            context: Extra context (failure_reason, element_type, ...)

        Returns:
            Correction ID if successfully added, None otherwise.
        """
        if not corrections:
            return None

        ctx = context or {}
        # Normalize the loose feedback payload into the standard correction format.
        data = {
            'action_type': corrections.get('action_type', corrections.get('type', 'unknown')),
            'element_type': corrections.get('element_type', ctx.get('element_type', 'unknown')),
            'failure_reason': corrections.get('failure_reason', ctx.get('failure_reason', '')),
            'correction_type': self._infer_correction_type(corrections),
            'original_target': corrections.get('original_target'),
            'corrected_target': corrections.get('corrected_target', corrections.get('new_target')),
            'original_params': corrections.get('original_params'),
            'corrected_params': corrections.get('corrected_params', corrections.get('new_params')),
            'description': corrections.get('description', f"Correction from feedback {execution_id}")
        }

        return self.capture_correction(
            correction_data=data,
            session_id=execution_id,
            workflow_id=workflow_id
        )

    def _infer_correction_type(self, corrections: Dict[str, Any]) -> str:
        """Infer the correction type from the keys present in the payload."""
        explicit = corrections.get('correction_type')
        if explicit:
            return explicit

        # Heuristics checked in priority order: first matching key wins.
        heuristics = (
            (('corrected_target', 'new_target'), CorrectionType.TARGET_CHANGE),
            (('corrected_params', 'new_params'), CorrectionType.PARAMETER_CHANGE),
            (('wait_time', 'timing'), CorrectionType.TIMING_ADJUST),
            (('coordinates', 'offset'), CorrectionType.COORDINATES_ADJUST),
        )
        for keys, ctype in heuristics:
            if any(corrections.get(k) for k in keys):
                return ctype.value

        return CorrectionType.OTHER.value

    def create_hook_for_collector(self) -> Callable[[Dict[str, Any]], None]:
        """
        Build a callback suitable for TrainingDataCollector.

        Usage:
            integration = CorrectionPackIntegration()
            collector.on_correction = integration.create_hook_for_collector()

        Returns:
            Callable that captures each correction dict it receives.
        """
        def _on_correction(correction: Dict[str, Any]) -> None:
            self.capture_correction(correction)

        return _on_correction

    def get_statistics(self) -> Dict[str, Any]:
        """Return pack statistics, or {'error': ...} if retrieval fails."""
        try:
            return self.service.get_pack_statistics(self._ensure_initialized())
        except Exception as e:
            logger.error(f"Error getting statistics: {e}")
            return {'error': str(e)}


# Process-wide singleton, created lazily by get_correction_pack_integration().
_global_integration: Optional[CorrectionPackIntegration] = None


def get_correction_pack_integration() -> CorrectionPackIntegration:
    """
    Return the global CorrectionPackIntegration, creating it on first call.

    Returns:
        Shared CorrectionPackIntegration instance.
    """
    global _global_integration
    if _global_integration is None:
        _global_integration = CorrectionPackIntegration()
    return _global_integration


def capture_coaching_correction(
    correction_data: Dict[str, Any],
    session_id: Optional[str] = None,
    workflow_id: Optional[str] = None,
    node_id: Optional[str] = None
) -> Optional[str]:
    """
    Convenience wrapper: capture one correction via the global integration.

    Args:
        correction_data: Correction dict
        session_id: Source session ID
        workflow_id: Source workflow ID
        node_id: Source node ID

    Returns:
        Correction ID if successful, None otherwise.
    """
    return get_correction_pack_integration().capture_correction(
        correction_data=correction_data,
        session_id=session_id,
        workflow_id=workflow_id,
        node_id=node_id
    )
"""Feedback Processor - Processes user feedback to improve workflows"""
import logging
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from datetime import datetime
from enum import Enum

logger = logging.getLogger(__name__)


class FeedbackType(str, Enum):
    """Kinds of user feedback on a workflow execution."""
    CORRECT = "correct"
    INCORRECT = "incorrect"
    PARTIAL = "partial"
    SKIP = "skip"


@dataclass
class Feedback:
    """One recorded piece of user feedback on a workflow execution."""
    workflow_id: str
    execution_id: str
    feedback_type: FeedbackType
    timestamp: datetime
    confidence_before: float
    user_comment: Optional[str] = None
    corrections: Optional[Dict] = None


class FeedbackProcessor:
    """Processes user feedback to improve workflows."""

    def __init__(self, auto_integrate_corrections: bool = True):
        """
        Set up the processor.

        Args:
            auto_integrate_corrections: Auto-propagate corrections to
                Correction Packs on INCORRECT/PARTIAL feedback.
        """
        self.feedback_history: List[Feedback] = []
        self._auto_integrate = auto_integrate_corrections
        # Lazily resolved CorrectionPackIntegration (kept None until needed).
        self._correction_integration = None
        logger.info("FeedbackProcessor initialized")

    def process_feedback(
        self,
        workflow_id: str,
        execution_id: str,
        feedback_type: FeedbackType,
        confidence: float,
        comment: Optional[str] = None,
        corrections: Optional[Dict] = None,
        context: Optional[Dict[str, Any]] = None
    ) -> Dict:
        """
        Record one feedback entry and return improvement suggestions.

        Returns:
            Dict with feedback_recorded, suggestions, should_update_workflow
            and correction_pack_id (None when nothing was propagated).
        """
        entry = Feedback(
            workflow_id=workflow_id,
            execution_id=execution_id,
            feedback_type=feedback_type,
            timestamp=datetime.now(),
            confidence_before=confidence,
            user_comment=comment,
            corrections=corrections
        )
        self.feedback_history.append(entry)

        logger.info(
            f"Feedback processed: workflow={workflow_id}, "
            f"type={feedback_type.value}, confidence={confidence:.2f}"
        )

        # Only negative/partial feedback warrants a workflow update,
        # and only then are corrections forwarded to Correction Packs.
        needs_update = feedback_type in [FeedbackType.INCORRECT, FeedbackType.PARTIAL]
        correction_id = None
        if corrections and needs_update:
            correction_id = self._propagate_corrections(
                workflow_id, execution_id, corrections, context
            )

        return {
            'feedback_recorded': True,
            'suggestions': self._generate_suggestions(entry),
            'should_update_workflow': needs_update,
            'correction_pack_id': correction_id
        }

    def _propagate_corrections(
        self,
        workflow_id: str,
        execution_id: str,
        corrections: Dict,
        context: Optional[Dict[str, Any]] = None
    ) -> Optional[str]:
        """Forward corrections to the Correction Pack system (best effort)."""
        if not self._auto_integrate:
            return None

        try:
            if self._correction_integration is None:
                # Imported lazily so the processor works without the
                # corrections package installed.
                from core.corrections import get_correction_pack_integration
                self._correction_integration = get_correction_pack_integration()

            return self._correction_integration.capture_feedback_correction(
                workflow_id=workflow_id,
                execution_id=execution_id,
                corrections=corrections,
                context=context
            )
        except ImportError:
            logger.debug("Correction pack integration not available")
            return None
        except Exception as e:
            logger.warning(f"Error propagating corrections: {e}")
            return None

    def _generate_suggestions(self, feedback: Feedback) -> List[str]:
        """Build improvement suggestions for one feedback entry."""
        tips: List[str] = []
        kind = feedback.feedback_type

        if kind == FeedbackType.INCORRECT:
            tips.extend([
                "Review target resolution strategy",
                "Check if UI elements changed",
                "Verify action sequence",
            ])
            if feedback.corrections:
                tips.append(f"Apply user corrections: {feedback.corrections}")
        elif kind == FeedbackType.PARTIAL:
            tips.extend([
                "Some steps succeeded - identify failing step",
                "Consider splitting workflow into smaller parts",
            ])
        elif kind == FeedbackType.CORRECT:
            tips.append("Workflow performing well - increase confidence")

        return tips

    def get_feedback_stats(self, workflow_id: str) -> Dict:
        """Aggregate feedback counts and accuracy for one workflow."""
        relevant = [f for f in self.feedback_history if f.workflow_id == workflow_id]

        counts = {ft: 0 for ft in FeedbackType}
        for fb in relevant:
            counts[fb.feedback_type] += 1

        total = len(relevant)
        return {
            'total': total,
            'correct': counts[FeedbackType.CORRECT],
            'incorrect': counts[FeedbackType.INCORRECT],
            'partial': counts[FeedbackType.PARTIAL],
            'skip': counts[FeedbackType.SKIP],
            'accuracy': counts[FeedbackType.CORRECT] / total if total > 0 else 0.0
        }

    def get_recent_feedback(self, workflow_id: str, limit: int = 10) -> List[Feedback]:
        """Return up to *limit* most recent feedback entries for a workflow."""
        mine = (f for f in self.feedback_history if f.workflow_id == workflow_id)
        return sorted(mine, key=lambda fb: fb.timestamp, reverse=True)[:limit]
""" + Initialize the training data collector. + + Args: + output_dir: Directory for storing training data + correction_callback: Optional callback for correction propagation + auto_integrate_corrections: Auto-enable correction pack integration + """ + self.output_dir = Path(output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + + self.sessions: List[TrainingSession] = [] + self.patterns: List[WorkflowPattern] = [] + self.current_session: Optional[TrainingSession] = None + + # Correction pack integration + self._correction_callback = correction_callback + self._auto_integrate = auto_integrate_corrections + self._correction_integration = None + + logger.info(f"TrainingDataCollector initialized (output_dir={output_dir})") + + def start_session(self, session_id: str, workflow_id: Optional[str] = None) -> None: + """Start recording a new training session""" + self.current_session = TrainingSession( + session_id=session_id, + workflow_id=workflow_id, + timestamp=datetime.now(), + screenshots=[], + actions=[], + embeddings=[], + success=False + ) + logger.info(f"Started training session: {session_id}") + + def record_screenshot(self, screenshot_path: str) -> None: + """Record a screenshot""" + if self.current_session: + self.current_session.screenshots.append(screenshot_path) + + def record_action(self, action: Dict[str, Any]) -> None: + """Record an action""" + if self.current_session: + self.current_session.actions.append({ + **action, + 'timestamp': datetime.now().isoformat() + }) + + def record_embedding(self, embedding_path: str) -> None: + """Record an embedding file path""" + if self.current_session: + self.current_session.embeddings.append(embedding_path) + + def record_correction(self, correction: Dict[str, Any]) -> None: + """Record a user correction and propagate to Correction Packs.""" + if self.current_session: + timestamped_correction = { + **correction, + 'timestamp': datetime.now().isoformat() + } + 
self.current_session.user_corrections.append(timestamped_correction) + logger.info(f"User correction recorded: {correction.get('type', 'unknown')}") + + # Propagate to Correction Packs + self._propagate_to_correction_packs(timestamped_correction) + + def _propagate_to_correction_packs(self, correction: Dict[str, Any]) -> None: + """Propagate correction to Correction Pack system.""" + # Use explicit callback if provided + if self._correction_callback: + try: + session_id = self.current_session.session_id if self.current_session else None + workflow_id = self.current_session.workflow_id if self.current_session else None + self._correction_callback(correction, session_id, workflow_id) + except Exception as e: + logger.warning(f"Error in correction callback: {e}") + + # Auto-integrate if enabled + elif self._auto_integrate: + try: + if self._correction_integration is None: + from core.corrections import get_correction_pack_integration + self._correction_integration = get_correction_pack_integration() + + session_id = self.current_session.session_id if self.current_session else None + workflow_id = self.current_session.workflow_id if self.current_session else None + + self._correction_integration.capture_correction( + correction_data=correction, + session_id=session_id, + workflow_id=workflow_id + ) + except ImportError: + logger.debug("Correction pack integration not available") + except Exception as e: + logger.warning(f"Error propagating to correction packs: {e}") + + def end_session(self, success: bool, metadata: Optional[Dict] = None) -> None: + """End current session and save""" + if not self.current_session: + logger.warning("No active session to end") + return + + self.current_session.success = success + if metadata: + self.current_session.metadata.update(metadata) + + self.sessions.append(self.current_session) + + # Save session to disk + self._save_session(self.current_session) + + logger.info( + f"Session ended: {self.current_session.session_id} " + 
f"(success={success}, actions={len(self.current_session.actions)})" + ) + + self.current_session = None + + def _save_session(self, session: TrainingSession) -> None: + """Save session to JSON file""" + session_file = self.output_dir / f"session_{session.session_id}.json" + + # Convert to dict with datetime serialization + session_dict = asdict(session) + session_dict['timestamp'] = session.timestamp.isoformat() + + with open(session_file, 'w') as f: + json.dump(session_dict, f, indent=2) + + def add_pattern(self, pattern: WorkflowPattern) -> None: + """Add a detected workflow pattern""" + self.patterns.append(pattern) + logger.info(f"Pattern added: {pattern.name} (frequency={pattern.frequency})") + + def export_training_set(self, output_file: str = "training_set.json") -> Dict[str, Any]: + """Export complete training dataset""" + training_set = { + 'metadata': { + 'export_date': datetime.now().isoformat(), + 'total_sessions': len(self.sessions), + 'total_patterns': len(self.patterns), + 'success_rate': self._calculate_success_rate() + }, + 'sessions': [self._serialize_session(s) for s in self.sessions], + 'patterns': [asdict(p) for p in self.patterns], + 'statistics': self._calculate_statistics() + } + + output_path = self.output_dir / output_file + with open(output_path, 'w') as f: + json.dump(training_set, f, indent=2) + + logger.info(f"Training set exported: {output_path} ({len(self.sessions)} sessions)") + return training_set + + def _calculate_success_rate(self) -> float: + """Calculate overall success rate""" + if not self.sessions: + return 0.0 + successful = sum(1 for s in self.sessions if s.success) + return successful / len(self.sessions) + + def _calculate_statistics(self) -> Dict[str, Any]: + """Calculate training data statistics""" + if not self.sessions: + return {} + + total_actions = sum(len(s.actions) for s in self.sessions) + total_corrections = sum(len(s.user_corrections) for s in self.sessions) + + return { + 'total_sessions': 
len(self.sessions), + 'successful_sessions': sum(1 for s in self.sessions if s.success), + 'total_actions': total_actions, + 'total_corrections': total_corrections, + 'avg_actions_per_session': total_actions / len(self.sessions), + 'correction_rate': total_corrections / total_actions if total_actions > 0 else 0 + } + + def get_sessions_by_workflow(self, workflow_id: str) -> List[TrainingSession]: + """Get all sessions for a specific workflow""" + return [s for s in self.sessions if s.workflow_id == workflow_id] + + def get_successful_sessions(self) -> List[TrainingSession]: + """Get only successful sessions""" + return [s for s in self.sessions if s.success] + + def clear_data(self) -> None: + """Clear all collected data""" + self.sessions.clear() + self.patterns.clear() + self.current_session = None + logger.info("Training data cleared") + + def _serialize_session(self, session: TrainingSession) -> Dict: + """Serialize session with datetime conversion""" + data = asdict(session) + data['timestamp'] = session.timestamp.isoformat() + return data diff --git a/tests/test_correction_pack_integration.py b/tests/test_correction_pack_integration.py new file mode 100644 index 000000000..90d93f9b2 --- /dev/null +++ b/tests/test_correction_pack_integration.py @@ -0,0 +1,486 @@ +""" +Tests for Correction Pack Integration with learning components. + +Tests the automatic capture of corrections from COACHING mode into +the Correction Pack system. 
+""" + +import pytest +import tempfile +import shutil +from pathlib import Path +from datetime import datetime +from unittest.mock import patch, MagicMock + +from core.corrections import ( + CorrectionPackIntegration, + CorrectionPackService, + get_correction_pack_integration, + capture_coaching_correction +) +from core.corrections.models import CorrectionType +from core.training.training_data_collector import TrainingDataCollector +from core.learning.feedback_processor import FeedbackProcessor, FeedbackType + + +@pytest.fixture +def temp_storage(): + """Create temporary storage directory.""" + temp_dir = tempfile.mkdtemp() + yield Path(temp_dir) + shutil.rmtree(temp_dir, ignore_errors=True) + + +@pytest.fixture +def service(temp_storage): + """Create a CorrectionPackService with temp storage.""" + return CorrectionPackService(storage_path=temp_storage) + + +@pytest.fixture +def integration(service): + """Create a CorrectionPackIntegration with test service.""" + return CorrectionPackIntegration(service=service, auto_create_pack=True) + + +class TestCorrectionPackIntegration: + """Tests for CorrectionPackIntegration class.""" + + def test_init(self, service): + """Test integration initialization.""" + integration = CorrectionPackIntegration(service=service) + assert integration._service is service + assert integration._auto_create_pack is True + assert integration._initialized is False + + def test_ensure_initialized_creates_pack(self, integration): + """Test auto-creation of default pack.""" + pack_id = integration._ensure_initialized() + + assert pack_id is not None + assert integration._initialized is True + assert integration._default_pack_id == pack_id + + # Verify pack exists + pack = integration.service.get_pack(pack_id) + assert pack is not None + # Handle dict format from service + pack_name = pack.get('name') if isinstance(pack, dict) else pack.name + assert pack_name == CorrectionPackIntegration.DEFAULT_PACK_NAME + + def test_capture_correction(self, 
integration): + """Test capturing a correction.""" + correction_data = { + 'action_type': 'click', + 'element_type': 'button', + 'failure_reason': 'element_not_found', + 'correction_type': 'target_change', + 'original_target': {'text': 'Submit'}, + 'corrected_target': {'text': 'Envoyer'} + } + + correction_id = integration.capture_correction( + correction_data=correction_data, + session_id='test_session_1', + workflow_id='wf_001' + ) + + assert correction_id is not None + + # Verify in pack + pack = integration.service.get_pack(integration._default_pack_id) + corrections = pack.get('corrections') if isinstance(pack, dict) else pack.corrections + assert len(corrections) == 1 + + def test_capture_correction_with_minimal_data(self, integration): + """Test capturing correction with minimal data.""" + correction_data = { + 'type': 'input', + 'new_target': {'xpath': '//input[@name="email"]'} + } + + correction_id = integration.capture_correction( + correction_data=correction_data, + session_id='session_2' + ) + + assert correction_id is not None + + def test_capture_feedback_correction(self, integration): + """Test capturing correction from feedback.""" + corrections = { + 'action_type': 'type', + 'element_type': 'input', + 'original_target': {'id': 'username'}, + 'corrected_target': {'name': 'user_name'} + } + context = { + 'failure_reason': 'id_changed' + } + + correction_id = integration.capture_feedback_correction( + workflow_id='wf_feedback_test', + execution_id='exec_001', + corrections=corrections, + context=context + ) + + assert correction_id is not None + + def test_infer_correction_type(self, integration): + """Test correction type inference.""" + # Target change + assert integration._infer_correction_type({ + 'corrected_target': {'id': 'new'} + }) == CorrectionType.TARGET_CHANGE.value + + # Parameter change + assert integration._infer_correction_type({ + 'corrected_params': {'timeout': 5} + }) == CorrectionType.PARAMETER_CHANGE.value + + # Timing adjust + 
assert integration._infer_correction_type({ + 'wait_time': 2.0 + }) == CorrectionType.TIMING_ADJUST.value + + # Coordinates adjust + assert integration._infer_correction_type({ + 'coordinates': {'x': 100, 'y': 200} + }) == CorrectionType.COORDINATES_ADJUST.value + + # Other + assert integration._infer_correction_type({}) == CorrectionType.OTHER.value + + def test_create_hook_for_collector(self, integration): + """Test creating hook function.""" + hook = integration.create_hook_for_collector() + + assert callable(hook) + + # Call hook with correction + hook({ + 'action_type': 'click', + 'element_type': 'button', + 'correction_type': 'target_change' + }) + + # Verify captured + pack = integration.service.get_pack(integration._default_pack_id) + corrections = pack.get('corrections') if isinstance(pack, dict) else pack.corrections + assert len(corrections) == 1 + + def test_get_statistics(self, integration): + """Test getting statistics.""" + # Add some corrections + for i in range(3): + integration.capture_correction( + correction_data={ + 'action_type': f'action_{i}', + 'element_type': 'button' + }, + session_id=f'session_{i}' + ) + + stats = integration.get_statistics() + assert stats['total_corrections'] == 3 + + +class TestGlobalIntegration: + """Tests for global integration functions.""" + + def test_get_correction_pack_integration(self, temp_storage): + """Test getting global integration instance.""" + with patch('core.corrections.integration._global_integration', None): + integration1 = get_correction_pack_integration() + integration2 = get_correction_pack_integration() + + # Same instance returned + assert integration1 is integration2 + + def test_capture_coaching_correction(self, temp_storage): + """Test convenience capture function.""" + with patch('core.corrections.integration._global_integration', None): + correction_id = capture_coaching_correction( + correction_data={ + 'action_type': 'click', + 'element_type': 'button' + }, + session_id='test_session' + 
) + + # Should succeed (creates pack automatically) + assert correction_id is not None + + +class TestTrainingDataCollectorIntegration: + """Tests for TrainingDataCollector integration.""" + + def test_collector_with_auto_integration(self, temp_storage, service): + """Test collector with automatic correction pack integration.""" + # Create collector + collector = TrainingDataCollector( + output_dir=str(temp_storage / 'training'), + auto_integrate_corrections=True + ) + + # Inject our test service + integration = CorrectionPackIntegration(service=service) + collector._correction_integration = integration + + # Start session + collector.start_session('session_auto_test', workflow_id='wf_auto') + + # Record correction + collector.record_correction({ + 'action_type': 'click', + 'element_type': 'button', + 'failure_reason': 'element_moved', + 'correction_type': 'coordinates_adjust', + 'original_target': {'x': 100, 'y': 100}, + 'corrected_target': {'x': 150, 'y': 120} + }) + + # End session + collector.end_session(success=True) + + # Verify correction captured in pack + pack = service.get_pack(integration._default_pack_id) + assert pack is not None + corrections = pack.get('corrections') if isinstance(pack, dict) else pack.corrections + assert len(corrections) == 1 + + def test_collector_with_custom_callback(self, temp_storage): + """Test collector with custom callback.""" + captured = [] + + def custom_callback(correction, session_id, workflow_id): + captured.append({ + 'correction': correction, + 'session_id': session_id, + 'workflow_id': workflow_id + }) + + collector = TrainingDataCollector( + output_dir=str(temp_storage / 'training'), + correction_callback=custom_callback, + auto_integrate_corrections=False + ) + + collector.start_session('session_callback', workflow_id='wf_callback') + collector.record_correction({'type': 'test_correction'}) + + assert len(captured) == 1 + assert captured[0]['session_id'] == 'session_callback' + assert captured[0]['workflow_id'] == 
'wf_callback' + + def test_collector_disabled_integration(self, temp_storage): + """Test collector with disabled integration.""" + collector = TrainingDataCollector( + output_dir=str(temp_storage / 'training'), + auto_integrate_corrections=False + ) + + collector.start_session('session_disabled') + + # Should not fail even without integration + collector.record_correction({'type': 'test'}) + + # Still recorded in session + assert len(collector.current_session.user_corrections) == 1 + + +class TestFeedbackProcessorIntegration: + """Tests for FeedbackProcessor integration.""" + + def test_processor_with_auto_integration(self, service): + """Test feedback processor with auto integration.""" + processor = FeedbackProcessor(auto_integrate_corrections=True) + + # Inject our test service + integration = CorrectionPackIntegration(service=service) + processor._correction_integration = integration + + # Process feedback with corrections + result = processor.process_feedback( + workflow_id='wf_feedback', + execution_id='exec_001', + feedback_type=FeedbackType.INCORRECT, + confidence=0.5, + corrections={ + 'action_type': 'click', + 'element_type': 'link', + 'original_target': {'text': 'More'}, + 'corrected_target': {'text': 'Plus'} + }, + context={'failure_reason': 'text_changed'} + ) + + assert result['feedback_recorded'] is True + assert result['correction_pack_id'] is not None + + # Verify in pack + pack = service.get_pack(integration._default_pack_id) + corrections = pack.get('corrections') if isinstance(pack, dict) else pack.corrections + assert len(corrections) == 1 + + def test_processor_correct_feedback_no_propagation(self, service): + """Test that CORRECT feedback doesn't propagate corrections.""" + processor = FeedbackProcessor(auto_integrate_corrections=True) + + integration = CorrectionPackIntegration(service=service) + processor._correction_integration = integration + + result = processor.process_feedback( + workflow_id='wf_correct', + execution_id='exec_002', + 
feedback_type=FeedbackType.CORRECT, # Correct, no correction needed + confidence=0.95, + corrections={'some': 'data'} # Even with data, shouldn't propagate + ) + + assert result['correction_pack_id'] is None + + def test_processor_partial_feedback_propagates(self, service): + """Test that PARTIAL feedback propagates corrections.""" + processor = FeedbackProcessor(auto_integrate_corrections=True) + + integration = CorrectionPackIntegration(service=service) + processor._correction_integration = integration + + result = processor.process_feedback( + workflow_id='wf_partial', + execution_id='exec_003', + feedback_type=FeedbackType.PARTIAL, + confidence=0.7, + corrections={ + 'action_type': 'fill', + 'corrected_params': {'value': 'corrected_value'} + } + ) + + assert result['correction_pack_id'] is not None + + def test_processor_disabled_integration(self): + """Test processor with disabled integration.""" + processor = FeedbackProcessor(auto_integrate_corrections=False) + + result = processor.process_feedback( + workflow_id='wf_disabled', + execution_id='exec_004', + feedback_type=FeedbackType.INCORRECT, + confidence=0.3, + corrections={'action_type': 'test'} + ) + + # Feedback recorded but no correction pack ID + assert result['feedback_recorded'] is True + assert result['correction_pack_id'] is None + + +class TestEndToEndFlow: + """End-to-end tests for the complete COACHING flow.""" + + def test_coaching_correction_flow(self, temp_storage, service): + """Test complete COACHING correction flow.""" + # Setup components + integration = CorrectionPackIntegration(service=service) + + collector = TrainingDataCollector( + output_dir=str(temp_storage / 'training'), + auto_integrate_corrections=True + ) + collector._correction_integration = integration + + processor = FeedbackProcessor(auto_integrate_corrections=True) + processor._correction_integration = integration + + # Simulate COACHING session + collector.start_session('coaching_session_001', workflow_id='wf_coaching') 
+ + # User makes corrections during execution + collector.record_correction({ + 'action_type': 'click', + 'element_type': 'button', + 'failure_reason': 'element_not_found', + 'correction_type': 'target_change', + 'original_target': {'text': 'OK'}, + 'corrected_target': {'text': 'Valider'} + }) + + collector.record_correction({ + 'action_type': 'type', + 'element_type': 'input', + 'failure_reason': 'wrong_field', + 'correction_type': 'target_change', + 'original_target': {'id': 'email'}, + 'corrected_target': {'name': 'user_email'} + }) + + collector.end_session(success=True) + + # User provides feedback at end + processor.process_feedback( + workflow_id='wf_coaching', + execution_id='coaching_session_001', + feedback_type=FeedbackType.PARTIAL, + confidence=0.75, + corrections={ + 'action_type': 'submit', + 'element_type': 'form', + 'corrected_params': {'wait_after': 2.0} + } + ) + + # Verify all corrections captured + pack = service.get_pack(integration._default_pack_id) + corrections = pack.get('corrections') if isinstance(pack, dict) else pack.corrections + assert len(corrections) == 3 + + # Verify statistics + pack_id = pack.get('id') if isinstance(pack, dict) else pack.id + stats = service.get_pack_statistics(pack_id) + assert stats['total_corrections'] == 3 + + def test_multiple_sessions_aggregation(self, temp_storage, service): + """Test corrections from multiple sessions aggregate correctly.""" + integration = CorrectionPackIntegration(service=service) + + # Session 1 + collector1 = TrainingDataCollector( + output_dir=str(temp_storage / 'training1') + ) + collector1._correction_integration = integration + + collector1.start_session('session_001', workflow_id='wf_multi') + collector1.record_correction({ + 'action_type': 'click', + 'element_type': 'button', + 'failure_reason': 'timeout', + 'correction_type': 'timing_adjust' + }) + collector1.end_session(success=True) + + # Session 2 - same type of correction + collector2 = TrainingDataCollector( + 
output_dir=str(temp_storage / 'training2') + ) + collector2._correction_integration = integration + + collector2.start_session('session_002', workflow_id='wf_multi') + collector2.record_correction({ + 'action_type': 'click', + 'element_type': 'button', + 'failure_reason': 'timeout', + 'correction_type': 'timing_adjust' + }) + collector2.end_session(success=True) + + # Both corrections captured + pack = service.get_pack(integration._default_pack_id) + corrections = pack.get('corrections') if isinstance(pack, dict) else pack.corrections + assert len(corrections) == 2 + + +if __name__ == '__main__': + pytest.main([__file__, '-v'])