feat(corrections): Add automatic COACHING integration for Correction Packs

- Add CorrectionPackIntegration class to bridge learning components
- Modify TrainingDataCollector to auto-propagate corrections to packs
- Modify FeedbackProcessor to capture corrections on INCORRECT/PARTIAL feedback
- Add convenience functions: get_correction_pack_integration(), capture_coaching_correction()
- Add 19 integration tests (all passing)

Corrections made during COACHING mode are now automatically captured
into a dedicated "auto_captured_corrections" pack for cross-workflow reuse.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Dom
2026-01-18 19:06:09 +01:00
parent d8756883c5
commit efb184fdb9
5 changed files with 1206 additions and 1 deletions

View File

@@ -16,6 +16,11 @@ from .models import (
from .correction_repository import CorrectionRepository from .correction_repository import CorrectionRepository
from .aggregator import CorrectionAggregator from .aggregator import CorrectionAggregator
from .correction_pack_service import CorrectionPackService from .correction_pack_service import CorrectionPackService
from .integration import (
CorrectionPackIntegration,
get_correction_pack_integration,
capture_coaching_correction
)
__all__ = [ __all__ = [
'CorrectionKey', 'CorrectionKey',
@@ -26,5 +31,9 @@ __all__ = [
'CorrectionPackMetadata', 'CorrectionPackMetadata',
'CorrectionRepository', 'CorrectionRepository',
'CorrectionAggregator', 'CorrectionAggregator',
'CorrectionPackService' 'CorrectionPackService',
# Integration
'CorrectionPackIntegration',
'get_correction_pack_integration',
'capture_coaching_correction'
] ]

View File

@@ -0,0 +1,288 @@
"""
Correction Pack Integration - Bridge between learning components and Correction Packs.
Automatically captures corrections from COACHING mode and propagates them
to the Correction Pack system for cross-workflow learning.
"""
import logging
from typing import Dict, Any, Optional, Callable
from datetime import datetime
from .correction_pack_service import CorrectionPackService
from .models import CorrectionType
logger = logging.getLogger(__name__)
class CorrectionPackIntegration:
    """
    Integration layer between learning components and Correction Packs.

    Captures corrections coming from TrainingDataCollector and
    FeedbackProcessor and forwards them to a CorrectionPackService. Unless a
    specific pack is requested, corrections are stored in a default pack
    named ``DEFAULT_PACK_NAME``.
    """

    # Default pack name for auto-captured corrections
    DEFAULT_PACK_NAME = "auto_captured_corrections"
    DEFAULT_PACK_DESCRIPTION = "Corrections automatiquement capturées pendant les sessions COACHING"

    def __init__(
        self,
        service: Optional[CorrectionPackService] = None,
        auto_create_pack: bool = True,
        default_pack_id: Optional[str] = None
    ):
        """
        Initialize the integration.

        Args:
            service: CorrectionPackService instance (lazy-loaded if None)
            auto_create_pack: Create default pack if it doesn't exist
            default_pack_id: ID of pack to use (auto-detected if None)
        """
        self._service = service
        self._auto_create_pack = auto_create_pack
        self._default_pack_id = default_pack_id
        self._initialized = False

    @staticmethod
    def _field(record: Any, name: str) -> Any:
        """Read ``name`` from a record that is either a dict or an object."""
        if isinstance(record, dict):
            return record.get(name)
        return getattr(record, name)

    @property
    def service(self) -> CorrectionPackService:
        """Return the service, constructing a default one on first access."""
        if self._service is None:
            self._service = CorrectionPackService()
        return self._service

    def _ensure_initialized(self) -> str:
        """
        Ensure the default pack exists and return its ID.

        Resolution order: the configured ``default_pack_id`` (if the service
        knows it), then an existing pack named ``DEFAULT_PACK_NAME``, then a
        newly created pack when ``auto_create_pack`` is enabled.

        Returns:
            Pack ID to use for storing corrections

        Raises:
            RuntimeError: No pack could be resolved and auto-creation is
                disabled.
        """
        if self._initialized and self._default_pack_id:
            return self._default_pack_id

        # If specific pack ID provided, verify it exists
        if self._default_pack_id:
            if self.service.get_pack(self._default_pack_id):
                self._initialized = True
                return self._default_pack_id
            logger.warning(f"Pack {self._default_pack_id} not found, creating default")

        # Search for existing auto-capture pack (entries may be dicts or objects)
        for pack in self.service.list_packs():
            if self._field(pack, 'name') == self.DEFAULT_PACK_NAME:
                pack_id = self._field(pack, 'id')
                self._default_pack_id = pack_id
                self._initialized = True
                logger.info(f"Using existing auto-capture pack: {pack_id}")
                return self._default_pack_id

        # Create new pack if auto_create enabled
        if self._auto_create_pack:
            pack = self.service.create_pack(
                name=self.DEFAULT_PACK_NAME,
                description=self.DEFAULT_PACK_DESCRIPTION,
                tags=["auto", "coaching"],
                category="learning"
            )
            pack_id = self._field(pack, 'id')
            self._default_pack_id = pack_id
            self._initialized = True
            logger.info(f"Created auto-capture pack: {pack_id}")
            return self._default_pack_id

        raise RuntimeError("No default pack available and auto_create_pack is disabled")

    def capture_correction(
        self,
        correction_data: Dict[str, Any],
        session_id: Optional[str] = None,
        workflow_id: Optional[str] = None,
        node_id: Optional[str] = None,
        pack_id: Optional[str] = None
    ) -> Optional[str]:
        """
        Capture a correction and add it to a Correction Pack.

        Explicit arguments take precedence over the same-named keys found in
        ``correction_data``. Any exception is logged and swallowed so that a
        failing pack backend never interrupts the caller.

        Args:
            correction_data: Correction dict (from TrainingDataCollector format)
            session_id: Source session ID
            workflow_id: Source workflow ID
            node_id: Source node ID
            pack_id: Specific pack to add to (uses default if None)

        Returns:
            Correction ID if successfully added, None otherwise
        """
        try:
            target_pack_id = pack_id or self._ensure_initialized()

            # Enrich correction data with source info
            enriched = {
                **correction_data,
                'node_id': node_id or correction_data.get('node_id')
            }

            # Fall back to 'unknown' even when the key is present but empty/None.
            final_session_id = session_id or correction_data.get('session_id') or 'unknown'
            final_workflow_id = workflow_id or correction_data.get('workflow_id')

            correction = self.service.add_correction_from_session(
                pack_id=target_pack_id,
                session_correction=enriched,
                session_id=final_session_id,
                workflow_id=final_workflow_id
            )
            if not correction:
                logger.warning("Failed to add correction to pack")
                return None

            correction_id = self._field(correction, 'id')
            logger.info(
                f"Captured correction {correction_id} from "
                f"session={final_session_id}, workflow={final_workflow_id}"
            )
            return correction_id
        except Exception as e:
            # Best-effort boundary: keep the traceback in the log, return None.
            logger.error(f"Error capturing correction: {e}", exc_info=True)
            return None

    def capture_feedback_correction(
        self,
        workflow_id: str,
        execution_id: str,
        corrections: Dict[str, Any],
        context: Optional[Dict[str, Any]] = None
    ) -> Optional[str]:
        """
        Capture a correction from FeedbackProcessor.

        The feedback dict is normalised into the standard correction format
        (accepting the ``type`` / ``new_target`` / ``new_params`` aliases)
        before being handed to ``capture_correction``.

        Args:
            workflow_id: Workflow ID
            execution_id: Execution ID (used as session_id)
            corrections: Correction dict from feedback
            context: Additional context (failure_reason, element_type, etc.)

        Returns:
            Correction ID if successfully added, None if ``corrections`` is
            empty or the capture failed
        """
        if not corrections:
            return None

        ctx = context or {}

        # Build correction data in standard format
        correction_data = {
            'action_type': corrections.get('action_type', corrections.get('type', 'unknown')),
            'element_type': corrections.get('element_type', ctx.get('element_type', 'unknown')),
            'failure_reason': corrections.get('failure_reason', ctx.get('failure_reason', '')),
            'correction_type': self._infer_correction_type(corrections),
            'original_target': corrections.get('original_target'),
            'corrected_target': corrections.get('corrected_target', corrections.get('new_target')),
            'original_params': corrections.get('original_params'),
            'corrected_params': corrections.get('corrected_params', corrections.get('new_params')),
            'description': corrections.get('description', f"Correction from feedback {execution_id}")
        }
        return self.capture_correction(
            correction_data=correction_data,
            session_id=execution_id,
            workflow_id=workflow_id
        )

    def _infer_correction_type(self, corrections: Dict[str, Any]) -> str:
        """
        Infer correction type from correction data.

        An explicit truthy ``correction_type`` wins; otherwise the first rule
        whose keys carry a truthy value decides, and ``OTHER`` is the fallback.
        """
        explicit = corrections.get('correction_type')
        if explicit:
            return explicit

        # Ordered: earlier rules take precedence when several keys are present.
        rules = (
            (('corrected_target', 'new_target'), CorrectionType.TARGET_CHANGE),
            (('corrected_params', 'new_params'), CorrectionType.PARAMETER_CHANGE),
            (('wait_time', 'timing'), CorrectionType.TIMING_ADJUST),
            (('coordinates', 'offset'), CorrectionType.COORDINATES_ADJUST),
        )
        for keys, correction_type in rules:
            if any(corrections.get(key) for key in keys):
                return correction_type.value
        return CorrectionType.OTHER.value

    def create_hook_for_collector(self) -> Callable[..., None]:
        """
        Create a hook function for TrainingDataCollector.

        The hook accepts ``(correction, session_id, workflow_id)``, the
        argument list TrainingDataCollector uses when invoking its
        ``correction_callback``. The last two are optional, so calling the
        hook with only the correction dict keeps working.

        Returns:
            Callback function that captures corrections

        Usage:
            integration = CorrectionPackIntegration()
            collector = TrainingDataCollector(
                correction_callback=integration.create_hook_for_collector()
            )
        """
        def hook(
            correction: Dict[str, Any],
            session_id: Optional[str] = None,
            workflow_id: Optional[str] = None
        ) -> None:
            self.capture_correction(
                correction,
                session_id=session_id,
                workflow_id=workflow_id
            )
        return hook

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get statistics about captured corrections.

        Returns:
            Whatever the service reports for the default pack, or
            ``{'error': <message>}`` if resolving the pack or querying the
            service raised.
        """
        try:
            pack_id = self._ensure_initialized()
            return self.service.get_pack_statistics(pack_id)
        except Exception as e:
            logger.error(f"Error getting statistics: {e}", exc_info=True)
            return {'error': str(e)}
# Global instance for easy access
_global_integration: Optional[CorrectionPackIntegration] = None
def get_correction_pack_integration() -> CorrectionPackIntegration:
    """
    Return the module-wide CorrectionPackIntegration singleton.

    The instance is built with default arguments the first time this
    function is called and reused on every later call.

    Returns:
        CorrectionPackIntegration instance
    """
    global _global_integration
    if _global_integration is not None:
        return _global_integration
    _global_integration = CorrectionPackIntegration()
    return _global_integration
def capture_coaching_correction(
    correction_data: Dict[str, Any],
    session_id: Optional[str] = None,
    workflow_id: Optional[str] = None,
    node_id: Optional[str] = None
) -> Optional[str]:
    """
    Capture a correction through the global integration instance.

    Args:
        correction_data: Correction dict
        session_id: Source session ID
        workflow_id: Source workflow ID
        node_id: Source node ID

    Returns:
        Correction ID if successful
    """
    # Delegates to the shared singleton; arguments are passed through as-is.
    return get_correction_pack_integration().capture_correction(
        correction_data=correction_data,
        session_id=session_id,
        workflow_id=workflow_id,
        node_id=node_id
    )

View File

@@ -0,0 +1,176 @@
"""Feedback Processor - Processes user feedback to improve workflows"""
import logging
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
logger = logging.getLogger(__name__)
# Lazy import for correction pack integration
_correction_integration = None
class FeedbackType(str, Enum):
    """Closed set of verdicts a user can give on a workflow execution."""

    # Members are also plain strings (str mixin), so they compare equal
    # to their lowercase values.
    CORRECT = "correct"
    INCORRECT = "incorrect"
    PARTIAL = "partial"
    SKIP = "skip"
@dataclass
class Feedback:
    """One piece of user feedback attached to a workflow execution."""

    # Identification of the run being judged
    workflow_id: str
    execution_id: str

    # The verdict, when it was recorded, and the confidence passed in with it
    feedback_type: FeedbackType
    timestamp: datetime
    confidence_before: float

    # Optional free-text comment and correction payload
    user_comment: Optional[str] = None
    corrections: Optional[Dict] = None
class FeedbackProcessor:
    """Records user feedback on workflow runs and derives follow-up suggestions."""

    def __init__(self, auto_integrate_corrections: bool = True):
        """
        Initialize the feedback processor.

        Args:
            auto_integrate_corrections: Auto-propagate corrections to Correction Packs
        """
        self._auto_integrate = auto_integrate_corrections
        # Resolved lazily on the first propagation (or injected by callers).
        self._correction_integration = None
        self.feedback_history: List[Feedback] = []
        logger.info("FeedbackProcessor initialized")

    def process_feedback(
        self,
        workflow_id: str,
        execution_id: str,
        feedback_type: FeedbackType,
        confidence: float,
        comment: Optional[str] = None,
        corrections: Optional[Dict] = None,
        context: Optional[Dict[str, Any]] = None
    ) -> Dict:
        """
        Record one feedback entry and return improvement suggestions.

        Corrections are forwarded to the Correction Pack system only for
        INCORRECT or PARTIAL feedback that actually carries corrections.

        Returns:
            Dict with ``feedback_recorded``, ``suggestions``,
            ``should_update_workflow`` and ``correction_pack_id`` (the value
            returned by the propagation step, or None).
        """
        entry = Feedback(
            workflow_id=workflow_id,
            execution_id=execution_id,
            feedback_type=feedback_type,
            timestamp=datetime.now(),
            confidence_before=confidence,
            user_comment=comment,
            corrections=corrections
        )
        self.feedback_history.append(entry)
        logger.info(
            f"Feedback processed: workflow={workflow_id}, "
            f"type={feedback_type.value}, confidence={confidence:.2f}"
        )

        needs_update = feedback_type in [FeedbackType.INCORRECT, FeedbackType.PARTIAL]

        # Propagate corrections to Correction Packs if applicable
        if corrections and needs_update:
            correction_id = self._propagate_corrections(
                workflow_id, execution_id, corrections, context
            )
        else:
            correction_id = None

        return {
            'feedback_recorded': True,
            'suggestions': self._generate_suggestions(entry),
            'should_update_workflow': needs_update,
            'correction_pack_id': correction_id
        }

    def _propagate_corrections(
        self,
        workflow_id: str,
        execution_id: str,
        corrections: Dict,
        context: Optional[Dict[str, Any]] = None
    ) -> Optional[str]:
        """
        Forward corrections to the Correction Pack system.

        Returns None when auto-integration is disabled, when the integration
        module cannot be imported, or when the capture raises.
        """
        if not self._auto_integrate:
            return None

        try:
            if self._correction_integration is None:
                # Imported here so this module loads without core.corrections.
                from core.corrections import get_correction_pack_integration
                self._correction_integration = get_correction_pack_integration()
            captured_id = self._correction_integration.capture_feedback_correction(
                workflow_id=workflow_id,
                execution_id=execution_id,
                corrections=corrections,
                context=context
            )
            return captured_id
        except ImportError:
            logger.debug("Correction pack integration not available")
            return None
        except Exception as e:
            logger.warning(f"Error propagating corrections: {e}")
            return None

    def _generate_suggestions(self, feedback: Feedback) -> List[str]:
        """Return the canned suggestions matching the feedback verdict."""
        verdict = feedback.feedback_type

        if verdict == FeedbackType.INCORRECT:
            tips = [
                "Review target resolution strategy",
                "Check if UI elements changed",
                "Verify action sequence",
            ]
            if feedback.corrections:
                tips.append(f"Apply user corrections: {feedback.corrections}")
            return tips

        if verdict == FeedbackType.PARTIAL:
            return [
                "Some steps succeeded - identify failing step",
                "Consider splitting workflow into smaller parts",
            ]

        if verdict == FeedbackType.CORRECT:
            return ["Workflow performing well - increase confidence"]

        # Any other verdict (e.g. SKIP) yields no suggestion.
        return []

    def get_feedback_stats(self, workflow_id: str) -> Dict:
        """
        Summarise the recorded feedback of one workflow.

        Returns:
            Counts per verdict plus ``accuracy`` (correct / total, 0.0 when
            the workflow has no feedback).
        """
        entries = [f for f in self.feedback_history if f.workflow_id == workflow_id]

        def tally(verdict: FeedbackType) -> int:
            return sum(1 for f in entries if f.feedback_type == verdict)

        total = len(entries)
        correct = tally(FeedbackType.CORRECT)
        return {
            'total': total,
            'correct': correct,
            'incorrect': tally(FeedbackType.INCORRECT),
            'partial': tally(FeedbackType.PARTIAL),
            'skip': tally(FeedbackType.SKIP),
            'accuracy': correct / total if total > 0 else 0.0
        }

    def get_recent_feedback(self, workflow_id: str, limit: int = 10) -> List[Feedback]:
        """Return up to ``limit`` entries for a workflow, newest first."""
        matching = [f for f in self.feedback_history if f.workflow_id == workflow_id]
        matching.sort(key=lambda f: f.timestamp, reverse=True)
        return matching[:limit]

View File

@@ -0,0 +1,246 @@
"""Training Data Collector - Collect clean training data during real usage"""
import logging
import json
from typing import List, Dict, Any, Optional, Callable
from dataclasses import dataclass, field, asdict
from datetime import datetime
from pathlib import Path
logger = logging.getLogger(__name__)
# Type for correction callback
CorrectionCallback = Callable[[Dict[str, Any], str, Optional[str]], None]
@dataclass
class TrainingSession:
    """Record of one training session, as collected and later saved to JSON."""

    session_id: str
    workflow_id: Optional[str]
    timestamp: datetime
    # Paths to screenshots
    screenshots: List[str]
    actions: List[Dict[str, Any]]
    # Paths to embedding files
    embeddings: List[str]
    success: bool
    # Each instance gets its own fresh containers for the fields below.
    user_corrections: List[Dict[str, Any]] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class WorkflowPattern:
    """A workflow pattern detected in the collected data."""

    pattern_id: str
    name: str
    frequency: int
    avg_confidence: float
    # State IDs
    states: List[str]
    # Action types
    actions: List[str]
    success_rate: float
class TrainingDataCollector:
    """Collects structured training data (sessions and patterns) during real usage."""

    def __init__(
        self,
        output_dir: str = "training_data",
        correction_callback: Optional[CorrectionCallback] = None,
        auto_integrate_corrections: bool = True
    ):
        """
        Initialize the training data collector.

        The output directory is created immediately if it does not exist.

        Args:
            output_dir: Directory for storing training data
            correction_callback: Optional callback for correction propagation
            auto_integrate_corrections: Auto-enable correction pack integration
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.current_session: Optional[TrainingSession] = None
        self.sessions: List[TrainingSession] = []
        self.patterns: List[WorkflowPattern] = []

        # Correction pack integration
        self._auto_integrate = auto_integrate_corrections
        self._correction_callback = correction_callback
        self._correction_integration = None

        logger.info(f"TrainingDataCollector initialized (output_dir={output_dir})")

    def start_session(self, session_id: str, workflow_id: Optional[str] = None) -> None:
        """Begin a new training session, replacing any session still open."""
        fresh = TrainingSession(
            session_id=session_id,
            workflow_id=workflow_id,
            timestamp=datetime.now(),
            screenshots=[],
            actions=[],
            embeddings=[],
            success=False
        )
        self.current_session = fresh
        logger.info(f"Started training session: {session_id}")

    def record_screenshot(self, screenshot_path: str) -> None:
        """Append a screenshot path to the active session (no-op without one)."""
        if not self.current_session:
            return
        self.current_session.screenshots.append(screenshot_path)

    def record_action(self, action: Dict[str, Any]) -> None:
        """Append a timestamped copy of an action to the active session."""
        if not self.current_session:
            return
        stamped = dict(action)
        stamped['timestamp'] = datetime.now().isoformat()
        self.current_session.actions.append(stamped)

    def record_embedding(self, embedding_path: str) -> None:
        """Append an embedding file path to the active session (no-op without one)."""
        if not self.current_session:
            return
        self.current_session.embeddings.append(embedding_path)

    def record_correction(self, correction: Dict[str, Any]) -> None:
        """Record a user correction and propagate it to Correction Packs."""
        if not self.current_session:
            return
        stamped = dict(correction)
        stamped['timestamp'] = datetime.now().isoformat()
        self.current_session.user_corrections.append(stamped)
        logger.info(f"User correction recorded: {correction.get('type', 'unknown')}")

        # Propagate to Correction Packs
        self._propagate_to_correction_packs(stamped)

    def _propagate_to_correction_packs(self, correction: Dict[str, Any]) -> None:
        """
        Hand a correction to the Correction Pack system.

        An explicit callback, when configured, is used instead of the
        automatic integration. Failures are logged and never raised.
        """
        active = self.current_session
        session_id = active.session_id if active else None
        workflow_id = active.workflow_id if active else None

        # Use explicit callback if provided
        if self._correction_callback:
            try:
                self._correction_callback(correction, session_id, workflow_id)
            except Exception as e:
                logger.warning(f"Error in correction callback: {e}")
            return

        # Auto-integrate if enabled
        if not self._auto_integrate:
            return
        try:
            if self._correction_integration is None:
                # Imported here so this module loads without core.corrections.
                from core.corrections import get_correction_pack_integration
                self._correction_integration = get_correction_pack_integration()
            self._correction_integration.capture_correction(
                correction_data=correction,
                session_id=session_id,
                workflow_id=workflow_id
            )
        except ImportError:
            logger.debug("Correction pack integration not available")
        except Exception as e:
            logger.warning(f"Error propagating to correction packs: {e}")

    def end_session(self, success: bool, metadata: Optional[Dict] = None) -> None:
        """Close the active session, keep it in memory and write it to disk."""
        session = self.current_session
        if not session:
            logger.warning("No active session to end")
            return

        session.success = success
        if metadata:
            session.metadata.update(metadata)
        self.sessions.append(session)

        # Save session to disk
        self._save_session(session)
        logger.info(
            f"Session ended: {session.session_id} "
            f"(success={success}, actions={len(session.actions)})"
        )
        self.current_session = None

    def _save_session(self, session: TrainingSession) -> None:
        """Write one session to ``session_<session_id>.json`` in the output directory."""
        target = self.output_dir / f"session_{session.session_id}.json"
        # Convert to dict with datetime serialization
        payload = self._serialize_session(session)
        with open(target, 'w') as f:
            json.dump(payload, f, indent=2)

    def add_pattern(self, pattern: WorkflowPattern) -> None:
        """Register a detected workflow pattern."""
        self.patterns.append(pattern)
        logger.info(f"Pattern added: {pattern.name} (frequency={pattern.frequency})")

    def export_training_set(self, output_file: str = "training_set.json") -> Dict[str, Any]:
        """
        Write every session, pattern and aggregate statistic to one JSON file.

        Args:
            output_file: File name, created inside the output directory

        Returns:
            The exported dataset as a dict
        """
        header = {
            'export_date': datetime.now().isoformat(),
            'total_sessions': len(self.sessions),
            'total_patterns': len(self.patterns),
            'success_rate': self._calculate_success_rate()
        }
        training_set = {
            'metadata': header,
            'sessions': [self._serialize_session(s) for s in self.sessions],
            'patterns': [asdict(p) for p in self.patterns],
            'statistics': self._calculate_statistics()
        }

        output_path = self.output_dir / output_file
        with open(output_path, 'w') as f:
            json.dump(training_set, f, indent=2)

        logger.info(f"Training set exported: {output_path} ({len(self.sessions)} sessions)")
        return training_set

    def _calculate_success_rate(self) -> float:
        """Return the fraction of successful sessions (0.0 when there are none)."""
        session_count = len(self.sessions)
        if session_count == 0:
            return 0.0
        return sum(1 for s in self.sessions if s.success) / session_count

    def _calculate_statistics(self) -> Dict[str, Any]:
        """Return aggregate counters over all sessions ({} when there are none)."""
        if not self.sessions:
            return {}

        session_count = len(self.sessions)
        total_actions = sum(len(s.actions) for s in self.sessions)
        total_corrections = sum(len(s.user_corrections) for s in self.sessions)
        if total_actions > 0:
            correction_rate = total_corrections / total_actions
        else:
            correction_rate = 0

        return {
            'total_sessions': session_count,
            'successful_sessions': sum(1 for s in self.sessions if s.success),
            'total_actions': total_actions,
            'total_corrections': total_corrections,
            'avg_actions_per_session': total_actions / session_count,
            'correction_rate': correction_rate
        }

    def get_sessions_by_workflow(self, workflow_id: str) -> List[TrainingSession]:
        """Return the stored sessions whose workflow_id matches."""
        return [s for s in self.sessions if s.workflow_id == workflow_id]

    def get_successful_sessions(self) -> List[TrainingSession]:
        """Return the stored sessions marked successful."""
        return [s for s in self.sessions if s.success]

    def clear_data(self) -> None:
        """Drop all in-memory sessions, patterns and the active session."""
        self.sessions.clear()
        self.patterns.clear()
        self.current_session = None
        logger.info("Training data cleared")

    def _serialize_session(self, session: TrainingSession) -> Dict:
        """Convert a session to a dict, with the timestamp as an ISO string."""
        data = asdict(session)
        data['timestamp'] = session.timestamp.isoformat()
        return data

View File

@@ -0,0 +1,486 @@
"""
Tests for Correction Pack Integration with learning components.
Tests the automatic capture of corrections from COACHING mode into
the Correction Pack system.
"""
import pytest
import tempfile
import shutil
from pathlib import Path
from datetime import datetime
from unittest.mock import patch, MagicMock
from core.corrections import (
CorrectionPackIntegration,
CorrectionPackService,
get_correction_pack_integration,
capture_coaching_correction
)
from core.corrections.models import CorrectionType
from core.training.training_data_collector import TrainingDataCollector
from core.learning.feedback_processor import FeedbackProcessor, FeedbackType
@pytest.fixture
def temp_storage():
    """Yield a fresh temporary directory and remove it after the test."""
    storage_root = Path(tempfile.mkdtemp())
    yield storage_root
    # Best-effort cleanup: removal errors are ignored.
    shutil.rmtree(storage_root, ignore_errors=True)
@pytest.fixture
def service(temp_storage):
    """Provide a CorrectionPackService backed by the temporary directory."""
    pack_service = CorrectionPackService(storage_path=temp_storage)
    return pack_service
@pytest.fixture
def integration(service):
    """Provide a CorrectionPackIntegration wired to the test service."""
    return CorrectionPackIntegration(
        service=service,
        auto_create_pack=True,
    )
class TestCorrectionPackIntegration:
    """Tests for CorrectionPackIntegration class."""

    @staticmethod
    def _read(record, name):
        """Read a field from a record returned as either a dict or an object."""
        return record.get(name) if isinstance(record, dict) else getattr(record, name)

    def _default_pack_corrections(self, integration):
        """Return the corrections stored in the integration's default pack."""
        pack = integration.service.get_pack(integration._default_pack_id)
        return self._read(pack, 'corrections')

    def test_init(self, service):
        """A new integration keeps the given service and starts uninitialized."""
        subject = CorrectionPackIntegration(service=service)
        assert subject._service is service
        assert subject._auto_create_pack is True
        assert subject._initialized is False

    def test_ensure_initialized_creates_pack(self, integration):
        """The default pack is created on demand under the default name."""
        pack_id = integration._ensure_initialized()
        assert pack_id is not None
        assert integration._initialized is True
        assert integration._default_pack_id == pack_id

        # Verify pack exists
        pack = integration.service.get_pack(pack_id)
        assert pack is not None
        assert self._read(pack, 'name') == CorrectionPackIntegration.DEFAULT_PACK_NAME

    def test_capture_correction(self, integration):
        """A fully specified correction ends up in the default pack."""
        payload = {
            'action_type': 'click',
            'element_type': 'button',
            'failure_reason': 'element_not_found',
            'correction_type': 'target_change',
            'original_target': {'text': 'Submit'},
            'corrected_target': {'text': 'Envoyer'}
        }
        correction_id = integration.capture_correction(
            correction_data=payload,
            session_id='test_session_1',
            workflow_id='wf_001'
        )
        assert correction_id is not None

        # Verify in pack
        assert len(self._default_pack_corrections(integration)) == 1

    def test_capture_correction_with_minimal_data(self, integration):
        """A correction carrying only alias keys is still accepted."""
        payload = {
            'type': 'input',
            'new_target': {'xpath': '//input[@name="email"]'}
        }
        correction_id = integration.capture_correction(
            correction_data=payload,
            session_id='session_2'
        )
        assert correction_id is not None

    def test_capture_feedback_correction(self, integration):
        """Feedback-style corrections plus context produce a correction ID."""
        correction_id = integration.capture_feedback_correction(
            workflow_id='wf_feedback_test',
            execution_id='exec_001',
            corrections={
                'action_type': 'type',
                'element_type': 'input',
                'original_target': {'id': 'username'},
                'corrected_target': {'name': 'user_name'}
            },
            context={'failure_reason': 'id_changed'}
        )
        assert correction_id is not None

    def test_infer_correction_type(self, integration):
        """Each kind of payload maps to the expected correction type."""
        cases = [
            # Target change
            ({'corrected_target': {'id': 'new'}}, CorrectionType.TARGET_CHANGE),
            # Parameter change
            ({'corrected_params': {'timeout': 5}}, CorrectionType.PARAMETER_CHANGE),
            # Timing adjust
            ({'wait_time': 2.0}, CorrectionType.TIMING_ADJUST),
            # Coordinates adjust
            ({'coordinates': {'x': 100, 'y': 200}}, CorrectionType.COORDINATES_ADJUST),
            # Other
            ({}, CorrectionType.OTHER),
        ]
        for payload, expected in cases:
            assert integration._infer_correction_type(payload) == expected.value

    def test_create_hook_for_collector(self, integration):
        """The generated hook is callable and stores what it receives."""
        hook = integration.create_hook_for_collector()
        assert callable(hook)

        # Call hook with correction
        hook({
            'action_type': 'click',
            'element_type': 'button',
            'correction_type': 'target_change'
        })

        # Verify captured
        assert len(self._default_pack_corrections(integration)) == 1

    def test_get_statistics(self, integration):
        """Statistics report every captured correction."""
        # Add some corrections
        for index in range(3):
            integration.capture_correction(
                correction_data={
                    'action_type': f'action_{index}',
                    'element_type': 'button'
                },
                session_id=f'session_{index}'
            )
        stats = integration.get_statistics()
        assert stats['total_corrections'] == 3
class TestGlobalIntegration:
    """Tests for global integration functions."""

    def test_get_correction_pack_integration(self, temp_storage):
        """The accessor builds the singleton once and then reuses it."""
        # Start from an empty global; patch restores the previous value on exit.
        with patch('core.corrections.integration._global_integration', None):
            integration1 = get_correction_pack_integration()
            integration2 = get_correction_pack_integration()
            # Same instance returned
            assert integration1 is integration2

    def test_capture_coaching_correction(self, service):
        """The convenience function captures through the global instance."""
        # The global integration lazily calls CorrectionPackService() with
        # default arguments. Route that construction to the temp-backed
        # service so the test never writes outside its temporary directory.
        with patch('core.corrections.integration._global_integration', None), \
                patch('core.corrections.integration.CorrectionPackService',
                      return_value=service):
            correction_id = capture_coaching_correction(
                correction_data={
                    'action_type': 'click',
                    'element_type': 'button'
                },
                session_id='test_session'
            )
            # Should succeed (creates pack automatically)
            assert correction_id is not None
class TestTrainingDataCollectorIntegration:
    """Tests for TrainingDataCollector integration."""

    @staticmethod
    def _training_dir(temp_storage):
        """Return the collector output directory used by these tests."""
        return str(temp_storage / 'training')

    def test_collector_with_auto_integration(self, temp_storage, service):
        """A recorded correction reaches the pack through auto-integration."""
        collector = TrainingDataCollector(
            output_dir=self._training_dir(temp_storage),
            auto_integrate_corrections=True
        )
        # Inject our test service
        integration = CorrectionPackIntegration(service=service)
        collector._correction_integration = integration

        # Full session lifecycle: start, record one correction, end
        collector.start_session('session_auto_test', workflow_id='wf_auto')
        collector.record_correction({
            'action_type': 'click',
            'element_type': 'button',
            'failure_reason': 'element_moved',
            'correction_type': 'coordinates_adjust',
            'original_target': {'x': 100, 'y': 100},
            'corrected_target': {'x': 150, 'y': 120}
        })
        collector.end_session(success=True)

        # Verify correction captured in pack
        pack = service.get_pack(integration._default_pack_id)
        assert pack is not None
        stored = pack.get('corrections') if isinstance(pack, dict) else pack.corrections
        assert len(stored) == 1

    def test_collector_with_custom_callback(self, temp_storage):
        """A custom callback receives the correction with session and workflow IDs."""
        received = []

        def custom_callback(correction, session_id, workflow_id):
            received.append({
                'correction': correction,
                'session_id': session_id,
                'workflow_id': workflow_id
            })

        collector = TrainingDataCollector(
            output_dir=self._training_dir(temp_storage),
            correction_callback=custom_callback,
            auto_integrate_corrections=False
        )
        collector.start_session('session_callback', workflow_id='wf_callback')
        collector.record_correction({'type': 'test_correction'})

        assert len(received) == 1
        only_call = received[0]
        assert only_call['session_id'] == 'session_callback'
        assert only_call['workflow_id'] == 'wf_callback'

    def test_collector_disabled_integration(self, temp_storage):
        """With integration disabled the correction is still kept in the session."""
        collector = TrainingDataCollector(
            output_dir=self._training_dir(temp_storage),
            auto_integrate_corrections=False
        )
        collector.start_session('session_disabled')

        # Should not fail even without integration
        collector.record_correction({'type': 'test'})

        # Still recorded in session
        assert len(collector.current_session.user_corrections) == 1
class TestFeedbackProcessorIntegration:
    """Tests for FeedbackProcessor integration."""

    @staticmethod
    def _wired_processor(service):
        """Build a processor whose integration is backed by the test service."""
        processor = FeedbackProcessor(auto_integrate_corrections=True)
        integration = CorrectionPackIntegration(service=service)
        processor._correction_integration = integration
        return processor, integration

    def test_processor_with_auto_integration(self, service):
        """INCORRECT feedback with corrections is stored in the pack."""
        processor, integration = self._wired_processor(service)

        # Process feedback with corrections
        outcome = processor.process_feedback(
            workflow_id='wf_feedback',
            execution_id='exec_001',
            feedback_type=FeedbackType.INCORRECT,
            confidence=0.5,
            corrections={
                'action_type': 'click',
                'element_type': 'link',
                'original_target': {'text': 'More'},
                'corrected_target': {'text': 'Plus'}
            },
            context={'failure_reason': 'text_changed'}
        )
        assert outcome['feedback_recorded'] is True
        assert outcome['correction_pack_id'] is not None

        # Verify in pack
        pack = service.get_pack(integration._default_pack_id)
        stored = pack.get('corrections') if isinstance(pack, dict) else pack.corrections
        assert len(stored) == 1

    def test_processor_correct_feedback_no_propagation(self, service):
        """CORRECT feedback never propagates, even when corrections are given."""
        processor, _ = self._wired_processor(service)

        outcome = processor.process_feedback(
            workflow_id='wf_correct',
            execution_id='exec_002',
            feedback_type=FeedbackType.CORRECT,  # Correct, no correction needed
            confidence=0.95,
            corrections={'some': 'data'}  # Even with data, shouldn't propagate
        )
        assert outcome['correction_pack_id'] is None

    def test_processor_partial_feedback_propagates(self, service):
        """PARTIAL feedback with corrections is propagated."""
        processor, _ = self._wired_processor(service)

        outcome = processor.process_feedback(
            workflow_id='wf_partial',
            execution_id='exec_003',
            feedback_type=FeedbackType.PARTIAL,
            confidence=0.7,
            corrections={
                'action_type': 'fill',
                'corrected_params': {'value': 'corrected_value'}
            }
        )
        assert outcome['correction_pack_id'] is not None

    def test_processor_disabled_integration(self):
        """With integration disabled, feedback is recorded but nothing is propagated."""
        processor = FeedbackProcessor(auto_integrate_corrections=False)

        outcome = processor.process_feedback(
            workflow_id='wf_disabled',
            execution_id='exec_004',
            feedback_type=FeedbackType.INCORRECT,
            confidence=0.3,
            corrections={'action_type': 'test'}
        )

        # Feedback recorded but no correction pack ID
        assert outcome['feedback_recorded'] is True
        assert outcome['correction_pack_id'] is None
class TestEndToEndFlow:
    """Full COACHING scenarios combining collector, processor and pack service."""

    @staticmethod
    def _pack_field(pack, key):
        """Read *key* from a pack given as either a dict or an attribute object."""
        if isinstance(pack, dict):
            return pack.get(key)
        return getattr(pack, key)

    def test_coaching_correction_flow(self, temp_storage, service):
        """Two in-session corrections plus one feedback correction share a pack."""
        # A single integration instance is shared by both learning components.
        integration = CorrectionPackIntegration(service=service)

        collector = TrainingDataCollector(
            auto_integrate_corrections=True,
            output_dir=str(temp_storage / 'training'),
        )
        collector._correction_integration = integration

        processor = FeedbackProcessor(auto_integrate_corrections=True)
        processor._correction_integration = integration

        # Corrections the user makes while the COACHING session is running.
        in_session_corrections = [
            {
                'action_type': 'click',
                'element_type': 'button',
                'failure_reason': 'element_not_found',
                'correction_type': 'target_change',
                'original_target': {'text': 'OK'},
                'corrected_target': {'text': 'Valider'},
            },
            {
                'action_type': 'type',
                'element_type': 'input',
                'failure_reason': 'wrong_field',
                'correction_type': 'target_change',
                'original_target': {'id': 'email'},
                'corrected_target': {'name': 'user_email'},
            },
        ]
        collector.start_session('coaching_session_001', workflow_id='wf_coaching')
        for correction in in_session_corrections:
            collector.record_correction(correction)
        collector.end_session(success=True)

        # Closing feedback from the user adds a third correction.
        feedback_correction = {
            'action_type': 'submit',
            'element_type': 'form',
            'corrected_params': {'wait_after': 2.0},
        }
        processor.process_feedback(
            workflow_id='wf_coaching',
            execution_id='coaching_session_001',
            feedback_type=FeedbackType.PARTIAL,
            confidence=0.75,
            corrections=feedback_correction,
        )

        # All three corrections must be present in the default pack...
        pack = service.get_pack(integration._default_pack_id)
        assert len(self._pack_field(pack, 'corrections')) == 3

        # ...and the pack statistics must report the same total.
        stats = service.get_pack_statistics(self._pack_field(pack, 'id'))
        assert stats['total_corrections'] == 3

    def test_multiple_sessions_aggregation(self, temp_storage, service):
        """Identical corrections from two sessions are both kept in the pack."""
        integration = CorrectionPackIntegration(service=service)

        # Each session gets its own collector and output directory, but both
        # record the same kind of correction through the shared integration.
        sessions = (
            ('session_001', 'training1'),
            ('session_002', 'training2'),
        )
        collectors = []  # keep every collector alive until the assertions run
        for session_id, subdir in sessions:
            collector = TrainingDataCollector(
                output_dir=str(temp_storage / subdir)
            )
            collector._correction_integration = integration
            collector.start_session(session_id, workflow_id='wf_multi')
            collector.record_correction({
                'action_type': 'click',
                'element_type': 'button',
                'failure_reason': 'timeout',
                'correction_type': 'timing_adjust',
            })
            collector.end_session(success=True)
            collectors.append(collector)

        # One entry per session is expected in the default pack.
        pack = service.get_pack(integration._default_pack_id)
        assert len(self._pack_field(pack, 'corrections')) == 2
if __name__ == '__main__':
    # Propagate pytest's exit status so that running this file directly
    # exits non-zero when a test fails, instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__, '-v']))