feat(corrections): Add automatic COACHING integration for Correction Packs
- Add CorrectionPackIntegration class to bridge learning components
- Modify TrainingDataCollector to auto-propagate corrections to packs
- Modify FeedbackProcessor to capture corrections on INCORRECT/PARTIAL feedback
- Add convenience functions: get_correction_pack_integration(), capture_coaching_correction()
- Add 19 integration tests (all passing)

Corrections made during COACHING mode are now automatically captured into a
dedicated "auto_captured_corrections" pack for cross-workflow reuse.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -16,6 +16,11 @@ from .models import (
|
|||||||
from .correction_repository import CorrectionRepository
|
from .correction_repository import CorrectionRepository
|
||||||
from .aggregator import CorrectionAggregator
|
from .aggregator import CorrectionAggregator
|
||||||
from .correction_pack_service import CorrectionPackService
|
from .correction_pack_service import CorrectionPackService
|
||||||
|
from .integration import (
|
||||||
|
CorrectionPackIntegration,
|
||||||
|
get_correction_pack_integration,
|
||||||
|
capture_coaching_correction
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'CorrectionKey',
|
'CorrectionKey',
|
||||||
@@ -26,5 +31,9 @@ __all__ = [
|
|||||||
'CorrectionPackMetadata',
|
'CorrectionPackMetadata',
|
||||||
'CorrectionRepository',
|
'CorrectionRepository',
|
||||||
'CorrectionAggregator',
|
'CorrectionAggregator',
|
||||||
'CorrectionPackService'
|
'CorrectionPackService',
|
||||||
|
# Integration
|
||||||
|
'CorrectionPackIntegration',
|
||||||
|
'get_correction_pack_integration',
|
||||||
|
'capture_coaching_correction'
|
||||||
]
|
]
|
||||||
|
|||||||
288
core/corrections/integration.py
Normal file
288
core/corrections/integration.py
Normal file
@@ -0,0 +1,288 @@
|
|||||||
|
"""
|
||||||
|
Correction Pack Integration - Bridge between learning components and Correction Packs.
|
||||||
|
|
||||||
|
Automatically captures corrections from COACHING mode and propagates them
|
||||||
|
to the Correction Pack system for cross-workflow learning.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Any, Optional, Callable
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from .correction_pack_service import CorrectionPackService
|
||||||
|
from .models import CorrectionType
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CorrectionPackIntegration:
    """
    Integration layer between learning components and Correction Packs.

    Captures corrections from TrainingDataCollector and FeedbackProcessor
    and automatically forwards them to the Correction Pack system.
    """

    # Default pack name for auto-captured corrections
    DEFAULT_PACK_NAME = "auto_captured_corrections"
    DEFAULT_PACK_DESCRIPTION = "Corrections automatiquement capturées pendant les sessions COACHING"

    def __init__(
        self,
        service: Optional[CorrectionPackService] = None,
        auto_create_pack: bool = True,
        default_pack_id: Optional[str] = None
    ):
        """
        Initialize the integration.

        Args:
            service: CorrectionPackService instance (lazy-loaded if None)
            auto_create_pack: Create default pack if it doesn't exist
            default_pack_id: ID of pack to use (auto-detected if None)
        """
        self._service = service
        self._auto_create_pack = auto_create_pack
        self._default_pack_id = default_pack_id
        self._initialized = False

    @staticmethod
    def _get_field(obj: Any, field: str) -> Any:
        """Read *field* from a service result that may be a dict or an object.

        The underlying CorrectionPackService returns either dicts or model
        objects; every lookup goes through this single helper instead of
        repeating the isinstance check at each call site.
        """
        return obj.get(field) if isinstance(obj, dict) else getattr(obj, field)

    @property
    def service(self) -> CorrectionPackService:
        """Lazy-load the service on first access."""
        if self._service is None:
            self._service = CorrectionPackService()
        return self._service

    def _ensure_initialized(self) -> str:
        """
        Ensure the default pack exists and return its ID.

        Returns:
            Pack ID to use for storing corrections

        Raises:
            RuntimeError: If no pack is available and auto-creation is disabled.
        """
        if self._initialized and self._default_pack_id:
            return self._default_pack_id

        # If specific pack ID provided, verify it exists
        if self._default_pack_id:
            pack = self.service.get_pack(self._default_pack_id)
            if pack:
                self._initialized = True
                return self._default_pack_id
            else:
                logger.warning(f"Pack {self._default_pack_id} not found, creating default")

        # Search for existing auto-capture pack
        for pack in self.service.list_packs():
            if self._get_field(pack, 'name') == self.DEFAULT_PACK_NAME:
                self._default_pack_id = self._get_field(pack, 'id')
                self._initialized = True
                logger.info(f"Using existing auto-capture pack: {self._default_pack_id}")
                return self._default_pack_id

        # Create new pack if auto_create enabled
        if self._auto_create_pack:
            pack = self.service.create_pack(
                name=self.DEFAULT_PACK_NAME,
                description=self.DEFAULT_PACK_DESCRIPTION,
                tags=["auto", "coaching"],
                category="learning"
            )
            self._default_pack_id = self._get_field(pack, 'id')
            self._initialized = True
            logger.info(f"Created auto-capture pack: {self._default_pack_id}")
            return self._default_pack_id

        raise RuntimeError("No default pack available and auto_create_pack is disabled")

    def capture_correction(
        self,
        correction_data: Dict[str, Any],
        session_id: Optional[str] = None,
        workflow_id: Optional[str] = None,
        node_id: Optional[str] = None,
        pack_id: Optional[str] = None
    ) -> Optional[str]:
        """
        Capture a correction and add it to a Correction Pack.

        Args:
            correction_data: Correction dict (from TrainingDataCollector format)
            session_id: Source session ID
            workflow_id: Source workflow ID
            node_id: Source node ID
            pack_id: Specific pack to add to (uses default if None)

        Returns:
            Correction ID if successfully added, None otherwise. Errors are
            logged, never raised, so callers in the recording path are safe.
        """
        try:
            target_pack_id = pack_id or self._ensure_initialized()

            # Enrich correction data with source info; an explicit node_id
            # argument wins over whatever the dict already carries.
            enriched = {
                **correction_data,
                'node_id': node_id or correction_data.get('node_id')
            }

            # Explicit arguments take precedence over values embedded in the dict
            final_session_id = session_id or correction_data.get('session_id', 'unknown')
            final_workflow_id = workflow_id or correction_data.get('workflow_id')

            # Add to pack
            correction = self.service.add_correction_from_session(
                pack_id=target_pack_id,
                session_correction=enriched,
                session_id=final_session_id,
                workflow_id=final_workflow_id
            )

            if correction:
                correction_id = self._get_field(correction, 'id')
                logger.info(
                    f"Captured correction {correction_id} from "
                    f"session={final_session_id}, workflow={final_workflow_id}"
                )
                return correction_id
            else:
                logger.warning("Failed to add correction to pack")
                return None

        except Exception as e:
            logger.error(f"Error capturing correction: {e}")
            return None

    def capture_feedback_correction(
        self,
        workflow_id: str,
        execution_id: str,
        corrections: Dict[str, Any],
        context: Optional[Dict[str, Any]] = None
    ) -> Optional[str]:
        """
        Capture a correction from FeedbackProcessor.

        Args:
            workflow_id: Workflow ID
            execution_id: Execution ID (used as session_id)
            corrections: Correction dict from feedback
            context: Additional context (failure_reason, element_type, etc.)

        Returns:
            Correction ID if successfully added
        """
        if not corrections:
            return None

        # Normalize once instead of testing `if context` at every lookup
        context = context or {}

        # Build correction data in standard format; each field falls back
        # through the aliases used by the various feedback producers.
        correction_data = {
            'action_type': corrections.get('action_type', corrections.get('type', 'unknown')),
            'element_type': corrections.get('element_type', context.get('element_type', 'unknown')),
            'failure_reason': corrections.get('failure_reason', context.get('failure_reason', '')),
            'correction_type': self._infer_correction_type(corrections),
            'original_target': corrections.get('original_target'),
            'corrected_target': corrections.get('corrected_target', corrections.get('new_target')),
            'original_params': corrections.get('original_params'),
            'corrected_params': corrections.get('corrected_params', corrections.get('new_params')),
            'description': corrections.get('description', f"Correction from feedback {execution_id}")
        }

        return self.capture_correction(
            correction_data=correction_data,
            session_id=execution_id,
            workflow_id=workflow_id
        )

    def _infer_correction_type(self, corrections: Dict[str, Any]) -> str:
        """Infer the correction type from the correction data.

        An explicit 'correction_type' always wins; otherwise the type is
        deduced from which corrected fields are present, with a generic
        fallback of OTHER.
        """
        if corrections.get('correction_type'):
            return corrections['correction_type']

        # Infer from data
        if corrections.get('corrected_target') or corrections.get('new_target'):
            return CorrectionType.TARGET_CHANGE.value
        if corrections.get('corrected_params') or corrections.get('new_params'):
            return CorrectionType.PARAMETER_CHANGE.value
        if corrections.get('wait_time') or corrections.get('timing'):
            return CorrectionType.TIMING_ADJUST.value
        if corrections.get('coordinates') or corrections.get('offset'):
            return CorrectionType.COORDINATES_ADJUST.value

        return CorrectionType.OTHER.value

    def create_hook_for_collector(self) -> Callable[[Dict[str, Any]], None]:
        """
        Create a hook function for TrainingDataCollector.

        Returns:
            Callback function that captures corrections

        Usage:
            integration = CorrectionPackIntegration()
            collector.on_correction = integration.create_hook_for_collector()
        """
        def hook(correction: Dict[str, Any]) -> None:
            self.capture_correction(correction)
        return hook

    def get_statistics(self) -> Dict[str, Any]:
        """Get statistics about captured corrections.

        Returns:
            The default pack's statistics dict, or {'error': ...} on failure.
        """
        try:
            pack_id = self._ensure_initialized()
            return self.service.get_pack_statistics(pack_id)
        except Exception as e:
            logger.error(f"Error getting statistics: {e}")
            return {'error': str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
# Process-wide singleton, created lazily by get_correction_pack_integration()
_global_integration: Optional[CorrectionPackIntegration] = None


def get_correction_pack_integration() -> CorrectionPackIntegration:
    """Return the global CorrectionPackIntegration, creating it on first use.

    Returns:
        The shared CorrectionPackIntegration instance.
    """
    global _global_integration
    if _global_integration is None:
        _global_integration = CorrectionPackIntegration()
    return _global_integration
|
||||||
|
|
||||||
|
|
||||||
|
def capture_coaching_correction(
    correction_data: Dict[str, Any],
    session_id: Optional[str] = None,
    workflow_id: Optional[str] = None,
    node_id: Optional[str] = None
) -> Optional[str]:
    """Capture a single correction via the global integration (convenience).

    Args:
        correction_data: Correction dict
        session_id: Source session ID
        workflow_id: Source workflow ID
        node_id: Source node ID

    Returns:
        Correction ID if successful
    """
    return get_correction_pack_integration().capture_correction(
        correction_data=correction_data,
        session_id=session_id,
        workflow_id=workflow_id,
        node_id=node_id
    )
|
||||||
176
core/learning/feedback_processor.py
Normal file
176
core/learning/feedback_processor.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
"""Feedback Processor - Processes user feedback to improve workflows"""
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Lazy import for correction pack integration
# NOTE(review): this module-level variable appears unused — FeedbackProcessor
# instances keep their own self._correction_integration. Confirm and remove if dead.
_correction_integration = None
|
||||||
|
|
||||||
|
class FeedbackType(str, Enum):
    """Types of user feedback on a workflow execution."""
    CORRECT = "correct"      # execution matched the user's expectations
    INCORRECT = "incorrect"  # execution was wrong; may come with corrections
    PARTIAL = "partial"      # some steps succeeded, some failed
    SKIP = "skip"            # user declined to rate this execution
|
||||||
|
|
||||||
|
@dataclass
class Feedback:
    """User feedback on a single workflow execution."""
    workflow_id: str                    # workflow the feedback refers to
    execution_id: str                   # specific execution being rated
    feedback_type: FeedbackType         # CORRECT / INCORRECT / PARTIAL / SKIP
    timestamp: datetime                 # when the feedback was recorded
    confidence_before: float            # workflow confidence at feedback time
    user_comment: Optional[str] = None  # free-text comment, if any
    corrections: Optional[Dict] = None  # user-supplied corrections, if any
|
||||||
|
|
||||||
|
class FeedbackProcessor:
    """Processes user feedback to improve workflows.

    Records every piece of feedback in memory, optionally propagates
    corrections (on INCORRECT/PARTIAL feedback) to the Correction Pack
    system, and produces improvement suggestions.
    """

    def __init__(self, auto_integrate_corrections: bool = True):
        """
        Initialize the feedback processor.

        Args:
            auto_integrate_corrections: Auto-propagate corrections to Correction Packs
        """
        self.feedback_history: List[Feedback] = []
        self._auto_integrate = auto_integrate_corrections
        self._correction_integration = None  # lazy-loaded on first propagation
        logger.info("FeedbackProcessor initialized")

    def process_feedback(
        self,
        workflow_id: str,
        execution_id: str,
        feedback_type: FeedbackType,
        confidence: float,
        comment: Optional[str] = None,
        corrections: Optional[Dict] = None,
        context: Optional[Dict[str, Any]] = None
    ) -> Dict:
        """Process user feedback and return improvement suggestions.

        Args:
            workflow_id: Workflow the feedback applies to.
            execution_id: Execution being rated.
            feedback_type: CORRECT / INCORRECT / PARTIAL / SKIP.
            confidence: Workflow confidence at the time of feedback.
            comment: Optional free-text user comment.
            corrections: Optional corrections dict; propagated to the
                Correction Pack system when feedback is INCORRECT or PARTIAL.
            context: Extra context forwarded to the correction integration.

        Returns:
            Dict with 'feedback_recorded', 'suggestions',
            'should_update_workflow', and 'correction_pack_id' keys.
        """
        feedback = Feedback(
            workflow_id=workflow_id,
            execution_id=execution_id,
            feedback_type=feedback_type,
            timestamp=datetime.now(),
            confidence_before=confidence,
            user_comment=comment,
            corrections=corrections
        )

        self.feedback_history.append(feedback)

        logger.info(
            f"Feedback processed: workflow={workflow_id}, "
            f"type={feedback_type.value}, confidence={confidence:.2f}"
        )

        # Propagate corrections to Correction Packs if applicable
        correction_id = None
        if corrections and feedback_type in [FeedbackType.INCORRECT, FeedbackType.PARTIAL]:
            correction_id = self._propagate_corrections(
                workflow_id, execution_id, corrections, context
            )

        # Generate improvement suggestions
        suggestions = self._generate_suggestions(feedback)

        return {
            'feedback_recorded': True,
            'suggestions': suggestions,
            'should_update_workflow': feedback_type in [FeedbackType.INCORRECT, FeedbackType.PARTIAL],
            'correction_pack_id': correction_id
        }

    def _propagate_corrections(
        self,
        workflow_id: str,
        execution_id: str,
        corrections: Dict,
        context: Optional[Dict[str, Any]] = None
    ) -> Optional[str]:
        """Propagate corrections to the Correction Pack system.

        Best-effort: failures are logged and swallowed so feedback
        recording never breaks when the corrections package is unavailable.
        """
        if not self._auto_integrate:
            return None

        try:
            if self._correction_integration is None:
                # Imported lazily to avoid a hard dependency / import cycle
                from core.corrections import get_correction_pack_integration
                self._correction_integration = get_correction_pack_integration()

            return self._correction_integration.capture_feedback_correction(
                workflow_id=workflow_id,
                execution_id=execution_id,
                corrections=corrections,
                context=context
            )
        except ImportError:
            logger.debug("Correction pack integration not available")
            return None
        except Exception as e:
            logger.warning(f"Error propagating corrections: {e}")
            return None

    def _generate_suggestions(self, feedback: Feedback) -> List[str]:
        """Generate improvement suggestions based on the feedback type."""
        suggestions = []

        if feedback.feedback_type == FeedbackType.INCORRECT:
            suggestions.append("Review target resolution strategy")
            suggestions.append("Check if UI elements changed")
            suggestions.append("Verify action sequence")

            if feedback.corrections:
                suggestions.append(f"Apply user corrections: {feedback.corrections}")

        elif feedback.feedback_type == FeedbackType.PARTIAL:
            suggestions.append("Some steps succeeded - identify failing step")
            suggestions.append("Consider splitting workflow into smaller parts")

        elif feedback.feedback_type == FeedbackType.CORRECT:
            suggestions.append("Workflow performing well - increase confidence")

        return suggestions

    def get_feedback_stats(self, workflow_id: str) -> Dict:
        """Get feedback statistics for a workflow.

        Returns zeroed counts and accuracy 0.0 when no feedback exists.
        """
        workflow_feedback = [f for f in self.feedback_history if f.workflow_id == workflow_id]

        # Single pass over the feedback instead of one generator pass per type
        counts = {t: 0 for t in FeedbackType}
        for f in workflow_feedback:
            counts[f.feedback_type] += 1

        total = len(workflow_feedback)
        correct = counts[FeedbackType.CORRECT]
        return {
            'total': total,
            'correct': correct,
            'incorrect': counts[FeedbackType.INCORRECT],
            'partial': counts[FeedbackType.PARTIAL],
            'skip': counts[FeedbackType.SKIP],
            'accuracy': correct / total if total > 0 else 0.0
        }

    def get_recent_feedback(self, workflow_id: str, limit: int = 10) -> List[Feedback]:
        """Get the most recent feedback entries for a workflow (newest first)."""
        workflow_feedback = [f for f in self.feedback_history if f.workflow_id == workflow_id]
        return sorted(workflow_feedback, key=lambda f: f.timestamp, reverse=True)[:limit]
|
||||||
246
core/training/training_data_collector.py
Normal file
246
core/training/training_data_collector.py
Normal file
@@ -0,0 +1,246 @@
|
|||||||
|
"""Training Data Collector - Collect clean training data during real usage"""
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
from typing import List, Dict, Any, Optional, Callable
|
||||||
|
from dataclasses import dataclass, field, asdict
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Type for correction callback: (correction, session_id, workflow_id) -> None,
# matching the arguments passed by _propagate_to_correction_packs.
# NOTE(review): callers can pass session_id=None when no session is active,
# so Optional[str] for the second parameter would be more accurate — confirm.
CorrectionCallback = Callable[[Dict[str, Any], str, Optional[str]], None]
|
||||||
|
|
||||||
|
@dataclass
class TrainingSession:
    """Single training session record (one recorded run of real usage)."""
    session_id: str                # unique identifier for this recording
    workflow_id: Optional[str]     # workflow being exercised, if any
    timestamp: datetime            # when the session started
    screenshots: List[str]  # Paths to screenshots
    actions: List[Dict[str, Any]]  # recorded actions, each stamped with a 'timestamp'
    embeddings: List[str]  # Paths to embedding files
    success: bool                  # set when the session is ended
    user_corrections: List[Dict[str, Any]] = field(default_factory=list)  # corrections recorded mid-session
    metadata: Dict[str, Any] = field(default_factory=dict)  # free-form extras merged at session end
|
||||||
|
|
||||||
|
@dataclass
class WorkflowPattern:
    """Detected workflow pattern (a recurring state/action sequence)."""
    pattern_id: str        # unique pattern identifier
    name: str              # human-readable pattern name
    frequency: int         # how many times the pattern was observed
    avg_confidence: float  # mean confidence across observations
    states: List[str]  # State IDs
    actions: List[str]  # Action types
    success_rate: float    # fraction of observations that succeeded
|
||||||
|
|
||||||
|
class TrainingDataCollector:
    """Collect structured training data during real usage.

    Records screenshots, actions, embeddings and user corrections per
    session, persists each session to JSON, and can export the whole
    dataset as one training set.
    """

    def __init__(
        self,
        output_dir: str = "training_data",
        correction_callback: Optional[CorrectionCallback] = None,
        auto_integrate_corrections: bool = True
    ):
        """
        Initialize the training data collector.

        Args:
            output_dir: Directory for storing training data
            correction_callback: Optional callback for correction propagation;
                when set it takes precedence over the built-in auto-integration.
            auto_integrate_corrections: Auto-enable correction pack integration
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.sessions: List[TrainingSession] = []
        self.patterns: List[WorkflowPattern] = []
        self.current_session: Optional[TrainingSession] = None

        # Correction pack integration
        self._correction_callback = correction_callback
        self._auto_integrate = auto_integrate_corrections
        self._correction_integration = None  # lazy-loaded on first use

        logger.info(f"TrainingDataCollector initialized (output_dir={output_dir})")

    def start_session(self, session_id: str, workflow_id: Optional[str] = None) -> None:
        """Start recording a new training session.

        Note: silently replaces any session already in progress.
        """
        self.current_session = TrainingSession(
            session_id=session_id,
            workflow_id=workflow_id,
            timestamp=datetime.now(),
            screenshots=[],
            actions=[],
            embeddings=[],
            success=False
        )
        logger.info(f"Started training session: {session_id}")

    def record_screenshot(self, screenshot_path: str) -> None:
        """Record a screenshot path (no-op when no session is active)."""
        if self.current_session:
            self.current_session.screenshots.append(screenshot_path)

    def record_action(self, action: Dict[str, Any]) -> None:
        """Record an action, stamped with the current time (no-op when idle)."""
        if self.current_session:
            self.current_session.actions.append({
                **action,
                'timestamp': datetime.now().isoformat()
            })

    def record_embedding(self, embedding_path: str) -> None:
        """Record an embedding file path (no-op when no session is active)."""
        if self.current_session:
            self.current_session.embeddings.append(embedding_path)

    def record_correction(self, correction: Dict[str, Any]) -> None:
        """Record a user correction and propagate it to Correction Packs."""
        if self.current_session:
            timestamped_correction = {
                **correction,
                'timestamp': datetime.now().isoformat()
            }
            self.current_session.user_corrections.append(timestamped_correction)
            logger.info(f"User correction recorded: {correction.get('type', 'unknown')}")

            # Propagate to Correction Packs
            self._propagate_to_correction_packs(timestamped_correction)

    def _propagate_to_correction_packs(self, correction: Dict[str, Any]) -> None:
        """Propagate a correction to the Correction Pack system.

        Best-effort: all failures are logged and swallowed so recording
        never breaks the calling workflow. The explicit callback, when set,
        takes precedence over the built-in auto-integration.
        """
        session_id = self.current_session.session_id if self.current_session else None
        workflow_id = self.current_session.workflow_id if self.current_session else None

        # Use explicit callback if provided
        if self._correction_callback:
            try:
                self._correction_callback(correction, session_id, workflow_id)
            except Exception as e:
                logger.warning(f"Error in correction callback: {e}")

        # Auto-integrate if enabled
        elif self._auto_integrate:
            try:
                if self._correction_integration is None:
                    # Imported lazily to avoid a hard dependency / import cycle
                    from core.corrections import get_correction_pack_integration
                    self._correction_integration = get_correction_pack_integration()

                self._correction_integration.capture_correction(
                    correction_data=correction,
                    session_id=session_id,
                    workflow_id=workflow_id
                )
            except ImportError:
                logger.debug("Correction pack integration not available")
            except Exception as e:
                logger.warning(f"Error propagating to correction packs: {e}")

    def end_session(self, success: bool, metadata: Optional[Dict] = None) -> None:
        """End the current session, persist it to disk, and reset state."""
        if not self.current_session:
            logger.warning("No active session to end")
            return

        self.current_session.success = success
        if metadata:
            self.current_session.metadata.update(metadata)

        self.sessions.append(self.current_session)

        # Save session to disk
        self._save_session(self.current_session)

        logger.info(
            f"Session ended: {self.current_session.session_id} "
            f"(success={success}, actions={len(self.current_session.actions)})"
        )

        self.current_session = None

    def _save_session(self, session: TrainingSession) -> None:
        """Save a session to a JSON file in the output directory."""
        session_file = self.output_dir / f"session_{session.session_id}.json"
        # Reuse the single serialization path so the per-session files and
        # the exported training set cannot drift apart.
        with open(session_file, 'w') as f:
            json.dump(self._serialize_session(session), f, indent=2)

    def add_pattern(self, pattern: WorkflowPattern) -> None:
        """Add a detected workflow pattern."""
        self.patterns.append(pattern)
        logger.info(f"Pattern added: {pattern.name} (frequency={pattern.frequency})")

    def export_training_set(self, output_file: str = "training_set.json") -> Dict[str, Any]:
        """Export the complete training dataset.

        Args:
            output_file: File name (inside output_dir) to write to.

        Returns:
            The exported training-set dict (also written to disk).
        """
        training_set = {
            'metadata': {
                'export_date': datetime.now().isoformat(),
                'total_sessions': len(self.sessions),
                'total_patterns': len(self.patterns),
                'success_rate': self._calculate_success_rate()
            },
            'sessions': [self._serialize_session(s) for s in self.sessions],
            'patterns': [asdict(p) for p in self.patterns],
            'statistics': self._calculate_statistics()
        }

        output_path = self.output_dir / output_file
        with open(output_path, 'w') as f:
            json.dump(training_set, f, indent=2)

        logger.info(f"Training set exported: {output_path} ({len(self.sessions)} sessions)")
        return training_set

    def _calculate_success_rate(self) -> float:
        """Overall fraction of successful sessions (0.0 when none recorded)."""
        if not self.sessions:
            return 0.0
        successful = sum(1 for s in self.sessions if s.success)
        return successful / len(self.sessions)

    def _calculate_statistics(self) -> Dict[str, Any]:
        """Aggregate statistics over all sessions ({} when none recorded)."""
        if not self.sessions:
            return {}

        total_actions = sum(len(s.actions) for s in self.sessions)
        total_corrections = sum(len(s.user_corrections) for s in self.sessions)

        return {
            'total_sessions': len(self.sessions),
            'successful_sessions': sum(1 for s in self.sessions if s.success),
            'total_actions': total_actions,
            'total_corrections': total_corrections,
            'avg_actions_per_session': total_actions / len(self.sessions),
            'correction_rate': total_corrections / total_actions if total_actions > 0 else 0
        }

    def get_sessions_by_workflow(self, workflow_id: str) -> List[TrainingSession]:
        """Get all sessions for a specific workflow."""
        return [s for s in self.sessions if s.workflow_id == workflow_id]

    def get_successful_sessions(self) -> List[TrainingSession]:
        """Get only the sessions that ended successfully."""
        return [s for s in self.sessions if s.success]

    def clear_data(self) -> None:
        """Clear all collected data (sessions, patterns, active session)."""
        self.sessions.clear()
        self.patterns.clear()
        self.current_session = None
        logger.info("Training data cleared")

    def _serialize_session(self, session: TrainingSession) -> Dict:
        """Serialize a session to a JSON-safe dict (datetime -> ISO string)."""
        data = asdict(session)
        data['timestamp'] = session.timestamp.isoformat()
        return data
|
||||||
486
tests/test_correction_pack_integration.py
Normal file
486
tests/test_correction_pack_integration.py
Normal file
@@ -0,0 +1,486 @@
|
|||||||
|
"""
|
||||||
|
Tests for Correction Pack Integration with learning components.
|
||||||
|
|
||||||
|
Tests the automatic capture of corrections from COACHING mode into
|
||||||
|
the Correction Pack system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import tempfile
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
from core.corrections import (
|
||||||
|
CorrectionPackIntegration,
|
||||||
|
CorrectionPackService,
|
||||||
|
get_correction_pack_integration,
|
||||||
|
capture_coaching_correction
|
||||||
|
)
|
||||||
|
from core.corrections.models import CorrectionType
|
||||||
|
from core.training.training_data_collector import TrainingDataCollector
|
||||||
|
from core.learning.feedback_processor import FeedbackProcessor, FeedbackType
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def temp_storage():
    """Yield a fresh temporary directory; best-effort removal afterwards."""
    workdir = Path(tempfile.mkdtemp())
    yield workdir
    shutil.rmtree(workdir, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def service(temp_storage):
    """Provide a CorrectionPackService backed by the temp directory."""
    svc = CorrectionPackService(storage_path=temp_storage)
    return svc
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def integration(service):
    """Provide a CorrectionPackIntegration wired to the test service."""
    bridge = CorrectionPackIntegration(service=service, auto_create_pack=True)
    return bridge
|
||||||
|
|
||||||
|
|
||||||
|
class TestCorrectionPackIntegration:
    """Tests for CorrectionPackIntegration class.

    NOTE(review): these are white-box tests — they reach into private
    attributes (``_service``, ``_initialized``, ``_default_pack_id``) and
    the private helpers ``_ensure_initialized`` / ``_infer_correction_type``.
    They will need updating if the integration's internals change.
    """

    def test_init(self, service):
        """Test integration initialization (defaults: auto-create on, lazy init)."""
        integration = CorrectionPackIntegration(service=service)
        assert integration._service is service
        assert integration._auto_create_pack is True
        assert integration._initialized is False

    def test_ensure_initialized_creates_pack(self, integration):
        """Test auto-creation of default pack."""
        pack_id = integration._ensure_initialized()

        assert pack_id is not None
        assert integration._initialized is True
        assert integration._default_pack_id == pack_id

        # Verify pack exists
        pack = integration.service.get_pack(pack_id)
        assert pack is not None
        # Handle dict format from service
        # NOTE(review): get_pack apparently may return either a dict or an
        # object — this dual-format guard repeats throughout the module.
        pack_name = pack.get('name') if isinstance(pack, dict) else pack.name
        assert pack_name == CorrectionPackIntegration.DEFAULT_PACK_NAME

    def test_capture_correction(self, integration):
        """Test capturing a fully-specified correction into the default pack."""
        correction_data = {
            'action_type': 'click',
            'element_type': 'button',
            'failure_reason': 'element_not_found',
            'correction_type': 'target_change',
            'original_target': {'text': 'Submit'},
            'corrected_target': {'text': 'Envoyer'}
        }

        correction_id = integration.capture_correction(
            correction_data=correction_data,
            session_id='test_session_1',
            workflow_id='wf_001'
        )

        assert correction_id is not None

        # Verify in pack
        pack = integration.service.get_pack(integration._default_pack_id)
        corrections = pack.get('corrections') if isinstance(pack, dict) else pack.corrections
        assert len(corrections) == 1

    def test_capture_correction_with_minimal_data(self, integration):
        """Test capturing correction with minimal data (missing keys are tolerated)."""
        correction_data = {
            'type': 'input',
            'new_target': {'xpath': '//input[@name="email"]'}
        }

        correction_id = integration.capture_correction(
            correction_data=correction_data,
            session_id='session_2'
        )

        assert correction_id is not None

    def test_capture_feedback_correction(self, integration):
        """Test capturing a correction that originates from user feedback."""
        corrections = {
            'action_type': 'type',
            'element_type': 'input',
            'original_target': {'id': 'username'},
            'corrected_target': {'name': 'user_name'}
        }
        context = {
            'failure_reason': 'id_changed'
        }

        correction_id = integration.capture_feedback_correction(
            workflow_id='wf_feedback_test',
            execution_id='exec_001',
            corrections=corrections,
            context=context
        )

        assert correction_id is not None

    def test_infer_correction_type(self, integration):
        """Test correction type inference from the keys present in the data."""
        # Target change
        assert integration._infer_correction_type({
            'corrected_target': {'id': 'new'}
        }) == CorrectionType.TARGET_CHANGE.value

        # Parameter change
        assert integration._infer_correction_type({
            'corrected_params': {'timeout': 5}
        }) == CorrectionType.PARAMETER_CHANGE.value

        # Timing adjust
        assert integration._infer_correction_type({
            'wait_time': 2.0
        }) == CorrectionType.TIMING_ADJUST.value

        # Coordinates adjust
        assert integration._infer_correction_type({
            'coordinates': {'x': 100, 'y': 200}
        }) == CorrectionType.COORDINATES_ADJUST.value

        # Other — empty data falls through to the catch-all type
        assert integration._infer_correction_type({}) == CorrectionType.OTHER.value

    def test_create_hook_for_collector(self, integration):
        """Test that the collector hook is callable and captures into the pack."""
        hook = integration.create_hook_for_collector()

        assert callable(hook)

        # Call hook with correction
        hook({
            'action_type': 'click',
            'element_type': 'button',
            'correction_type': 'target_change'
        })

        # Verify captured
        pack = integration.service.get_pack(integration._default_pack_id)
        corrections = pack.get('corrections') if isinstance(pack, dict) else pack.corrections
        assert len(corrections) == 1

    def test_get_statistics(self, integration):
        """Test that statistics reflect the number of captured corrections."""
        # Add some corrections
        for i in range(3):
            integration.capture_correction(
                correction_data={
                    'action_type': f'action_{i}',
                    'element_type': 'button'
                },
                session_id=f'session_{i}'
            )

        stats = integration.get_statistics()
        assert stats['total_corrections'] == 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestGlobalIntegration:
    """Tests for global integration functions.

    Both tests patch the module-level ``_global_integration`` singleton to
    ``None`` so each test starts from an uninitialized global state.
    """

    def test_get_correction_pack_integration(self, temp_storage):
        """Test getting global integration instance (singleton behavior)."""
        # NOTE(review): temp_storage is requested but unused — the global
        # integration presumably falls back to its default storage path;
        # confirm this does not write outside the test sandbox.
        with patch('core.corrections.integration._global_integration', None):
            integration1 = get_correction_pack_integration()
            integration2 = get_correction_pack_integration()

            # Same instance returned
            assert integration1 is integration2

    def test_capture_coaching_correction(self, temp_storage):
        """Test convenience capture function."""
        with patch('core.corrections.integration._global_integration', None):
            correction_id = capture_coaching_correction(
                correction_data={
                    'action_type': 'click',
                    'element_type': 'button'
                },
                session_id='test_session'
            )

            # Should succeed (creates pack automatically)
            assert correction_id is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestTrainingDataCollectorIntegration:
    """Tests for TrainingDataCollector integration.

    Covers the three wiring modes: automatic pack integration, a
    user-supplied callback, and integration fully disabled.
    """

    def test_collector_with_auto_integration(self, temp_storage, service):
        """Test collector with automatic correction pack integration."""
        # Create collector
        collector = TrainingDataCollector(
            output_dir=str(temp_storage / 'training'),
            auto_integrate_corrections=True
        )

        # Inject our test service
        # NOTE(review): overrides a private attribute so the collector uses
        # the temp-backed service instead of the global integration.
        integration = CorrectionPackIntegration(service=service)
        collector._correction_integration = integration

        # Start session
        collector.start_session('session_auto_test', workflow_id='wf_auto')

        # Record correction
        collector.record_correction({
            'action_type': 'click',
            'element_type': 'button',
            'failure_reason': 'element_moved',
            'correction_type': 'coordinates_adjust',
            'original_target': {'x': 100, 'y': 100},
            'corrected_target': {'x': 150, 'y': 120}
        })

        # End session
        collector.end_session(success=True)

        # Verify correction captured in pack
        pack = service.get_pack(integration._default_pack_id)
        assert pack is not None
        corrections = pack.get('corrections') if isinstance(pack, dict) else pack.corrections
        assert len(corrections) == 1

    def test_collector_with_custom_callback(self, temp_storage):
        """Test collector with custom callback."""
        captured = []

        # Callback receives (correction, session_id, workflow_id) per record.
        def custom_callback(correction, session_id, workflow_id):
            captured.append({
                'correction': correction,
                'session_id': session_id,
                'workflow_id': workflow_id
            })

        collector = TrainingDataCollector(
            output_dir=str(temp_storage / 'training'),
            correction_callback=custom_callback,
            auto_integrate_corrections=False
        )

        collector.start_session('session_callback', workflow_id='wf_callback')
        collector.record_correction({'type': 'test_correction'})

        assert len(captured) == 1
        assert captured[0]['session_id'] == 'session_callback'
        assert captured[0]['workflow_id'] == 'wf_callback'

    def test_collector_disabled_integration(self, temp_storage):
        """Test collector with disabled integration."""
        collector = TrainingDataCollector(
            output_dir=str(temp_storage / 'training'),
            auto_integrate_corrections=False
        )

        collector.start_session('session_disabled')

        # Should not fail even without integration
        collector.record_correction({'type': 'test'})

        # Still recorded in session
        assert len(collector.current_session.user_corrections) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestFeedbackProcessorIntegration:
    """Tests for FeedbackProcessor integration.

    Verifies which feedback types propagate corrections into the pack:
    INCORRECT and PARTIAL do; CORRECT does not; disabled integration never does.
    """

    def test_processor_with_auto_integration(self, service):
        """Test feedback processor with auto integration (INCORRECT propagates)."""
        processor = FeedbackProcessor(auto_integrate_corrections=True)

        # Inject our test service
        integration = CorrectionPackIntegration(service=service)
        processor._correction_integration = integration

        # Process feedback with corrections
        result = processor.process_feedback(
            workflow_id='wf_feedback',
            execution_id='exec_001',
            feedback_type=FeedbackType.INCORRECT,
            confidence=0.5,
            corrections={
                'action_type': 'click',
                'element_type': 'link',
                'original_target': {'text': 'More'},
                'corrected_target': {'text': 'Plus'}
            },
            context={'failure_reason': 'text_changed'}
        )

        assert result['feedback_recorded'] is True
        assert result['correction_pack_id'] is not None

        # Verify in pack
        pack = service.get_pack(integration._default_pack_id)
        corrections = pack.get('corrections') if isinstance(pack, dict) else pack.corrections
        assert len(corrections) == 1

    def test_processor_correct_feedback_no_propagation(self, service):
        """Test that CORRECT feedback doesn't propagate corrections."""
        processor = FeedbackProcessor(auto_integrate_corrections=True)

        integration = CorrectionPackIntegration(service=service)
        processor._correction_integration = integration

        result = processor.process_feedback(
            workflow_id='wf_correct',
            execution_id='exec_002',
            feedback_type=FeedbackType.CORRECT,  # Correct, no correction needed
            confidence=0.95,
            corrections={'some': 'data'}  # Even with data, shouldn't propagate
        )

        assert result['correction_pack_id'] is None

    def test_processor_partial_feedback_propagates(self, service):
        """Test that PARTIAL feedback propagates corrections."""
        processor = FeedbackProcessor(auto_integrate_corrections=True)

        integration = CorrectionPackIntegration(service=service)
        processor._correction_integration = integration

        result = processor.process_feedback(
            workflow_id='wf_partial',
            execution_id='exec_003',
            feedback_type=FeedbackType.PARTIAL,
            confidence=0.7,
            corrections={
                'action_type': 'fill',
                'corrected_params': {'value': 'corrected_value'}
            }
        )

        assert result['correction_pack_id'] is not None

    def test_processor_disabled_integration(self):
        """Test processor with disabled integration."""
        processor = FeedbackProcessor(auto_integrate_corrections=False)

        result = processor.process_feedback(
            workflow_id='wf_disabled',
            execution_id='exec_004',
            feedback_type=FeedbackType.INCORRECT,
            confidence=0.3,
            corrections={'action_type': 'test'}
        )

        # Feedback recorded but no correction pack ID
        assert result['feedback_recorded'] is True
        assert result['correction_pack_id'] is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestEndToEndFlow:
    """End-to-end tests for the complete COACHING flow.

    Exercises collector + processor sharing a single integration instance,
    so corrections from both sources land in the same default pack.
    """

    def test_coaching_correction_flow(self, temp_storage, service):
        """Test complete COACHING correction flow.

        Two corrections during the session (via the collector) plus one
        correction from end-of-run feedback (via the processor) should
        yield three entries in the shared pack.
        """
        # Setup components
        integration = CorrectionPackIntegration(service=service)

        collector = TrainingDataCollector(
            output_dir=str(temp_storage / 'training'),
            auto_integrate_corrections=True
        )
        collector._correction_integration = integration

        processor = FeedbackProcessor(auto_integrate_corrections=True)
        processor._correction_integration = integration

        # Simulate COACHING session
        collector.start_session('coaching_session_001', workflow_id='wf_coaching')

        # User makes corrections during execution
        collector.record_correction({
            'action_type': 'click',
            'element_type': 'button',
            'failure_reason': 'element_not_found',
            'correction_type': 'target_change',
            'original_target': {'text': 'OK'},
            'corrected_target': {'text': 'Valider'}
        })

        collector.record_correction({
            'action_type': 'type',
            'element_type': 'input',
            'failure_reason': 'wrong_field',
            'correction_type': 'target_change',
            'original_target': {'id': 'email'},
            'corrected_target': {'name': 'user_email'}
        })

        collector.end_session(success=True)

        # User provides feedback at end
        processor.process_feedback(
            workflow_id='wf_coaching',
            execution_id='coaching_session_001',
            feedback_type=FeedbackType.PARTIAL,
            confidence=0.75,
            corrections={
                'action_type': 'submit',
                'element_type': 'form',
                'corrected_params': {'wait_after': 2.0}
            }
        )

        # Verify all corrections captured
        pack = service.get_pack(integration._default_pack_id)
        corrections = pack.get('corrections') if isinstance(pack, dict) else pack.corrections
        assert len(corrections) == 3

        # Verify statistics
        pack_id = pack.get('id') if isinstance(pack, dict) else pack.id
        stats = service.get_pack_statistics(pack_id)
        assert stats['total_corrections'] == 3

    def test_multiple_sessions_aggregation(self, temp_storage, service):
        """Test corrections from multiple sessions aggregate correctly.

        Two independent collectors share one integration; identical
        corrections from different sessions must both be kept.
        """
        integration = CorrectionPackIntegration(service=service)

        # Session 1
        collector1 = TrainingDataCollector(
            output_dir=str(temp_storage / 'training1')
        )
        collector1._correction_integration = integration

        collector1.start_session('session_001', workflow_id='wf_multi')
        collector1.record_correction({
            'action_type': 'click',
            'element_type': 'button',
            'failure_reason': 'timeout',
            'correction_type': 'timing_adjust'
        })
        collector1.end_session(success=True)

        # Session 2 - same type of correction
        collector2 = TrainingDataCollector(
            output_dir=str(temp_storage / 'training2')
        )
        collector2._correction_integration = integration

        collector2.start_session('session_002', workflow_id='wf_multi')
        collector2.record_correction({
            'action_type': 'click',
            'element_type': 'button',
            'failure_reason': 'timeout',
            'correction_type': 'timing_adjust'
        })
        collector2.end_session(success=True)

        # Both corrections captured
        pack = service.get_pack(integration._default_pack_id)
        corrections = pack.get('corrections') if isinstance(pack, dict) else pack.corrections
        assert len(corrections) == 2
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this test module directly (python test_file.py) in
# addition to the normal pytest CLI invocation.
if __name__ == '__main__':
    pytest.main([__file__, '-v'])
|
||||||
Reference in New Issue
Block a user