feat(corrections): Add automatic COACHING integration for Correction Packs
- Add CorrectionPackIntegration class to bridge the learning components
- Modify TrainingDataCollector to auto-propagate corrections to packs
- Modify FeedbackProcessor to capture corrections on INCORRECT/PARTIAL feedback
- Add convenience functions: get_correction_pack_integration(), capture_coaching_correction()
- Add 19 integration tests (all passing)

Corrections made during COACHING mode are now automatically captured into a
dedicated "auto_captured_corrections" pack for cross-workflow reuse.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -16,6 +16,11 @@ from .models import (
|
||||
from .correction_repository import CorrectionRepository
|
||||
from .aggregator import CorrectionAggregator
|
||||
from .correction_pack_service import CorrectionPackService
|
||||
from .integration import (
|
||||
CorrectionPackIntegration,
|
||||
get_correction_pack_integration,
|
||||
capture_coaching_correction
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'CorrectionKey',
|
||||
@@ -26,5 +31,9 @@ __all__ = [
|
||||
'CorrectionPackMetadata',
|
||||
'CorrectionRepository',
|
||||
'CorrectionAggregator',
|
||||
'CorrectionPackService'
|
||||
'CorrectionPackService',
|
||||
# Integration
|
||||
'CorrectionPackIntegration',
|
||||
'get_correction_pack_integration',
|
||||
'capture_coaching_correction'
|
||||
]
|
||||
|
||||
288
core/corrections/integration.py
Normal file
288
core/corrections/integration.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
Correction Pack Integration - Bridge between learning components and Correction Packs.
|
||||
|
||||
Automatically captures corrections from COACHING mode and propagates them
|
||||
to the Correction Pack system for cross-workflow learning.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, Callable
|
||||
from datetime import datetime
|
||||
|
||||
from .correction_pack_service import CorrectionPackService
|
||||
from .models import CorrectionType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CorrectionPackIntegration:
    """
    Integration layer between learning components and Correction Packs.

    Captures corrections from TrainingDataCollector and FeedbackProcessor
    and automatically forwards them to the Correction Pack system.
    """

    # Default pack name for auto-captured corrections
    DEFAULT_PACK_NAME = "auto_captured_corrections"
    DEFAULT_PACK_DESCRIPTION = "Corrections automatiquement capturées pendant les sessions COACHING"

    def __init__(
        self,
        service: Optional[CorrectionPackService] = None,
        auto_create_pack: bool = True,
        default_pack_id: Optional[str] = None
    ):
        """
        Initialize the integration.

        Args:
            service: CorrectionPackService instance (lazy-loaded if None)
            auto_create_pack: Create default pack if it doesn't exist
            default_pack_id: ID of pack to use (auto-detected if None)
        """
        self._service = service
        self._auto_create_pack = auto_create_pack
        self._default_pack_id = default_pack_id
        self._initialized = False

    @staticmethod
    def _field(record: Any, name: str) -> Any:
        """
        Read field *name* from a record that may be a dict or an object.

        The pack service returns plain dicts in some code paths and model
        objects in others; this helper centralizes the dual-format access
        that was previously duplicated at every call site.
        """
        return record.get(name) if isinstance(record, dict) else getattr(record, name)

    @property
    def service(self) -> CorrectionPackService:
        """Lazily instantiate and return the underlying pack service."""
        if self._service is None:
            self._service = CorrectionPackService()
        return self._service

    def _ensure_initialized(self) -> str:
        """
        Ensure the default pack exists and return its ID.

        Resolution order: cached ID -> explicitly provided ID (if it still
        exists) -> existing pack named DEFAULT_PACK_NAME -> newly created
        pack (when auto_create_pack is enabled).

        Returns:
            Pack ID to use for storing corrections

        Raises:
            RuntimeError: If no pack is available and auto-creation is off.
        """
        if self._initialized and self._default_pack_id:
            return self._default_pack_id

        # If specific pack ID provided, verify it exists
        if self._default_pack_id:
            pack = self.service.get_pack(self._default_pack_id)
            if pack:
                self._initialized = True
                return self._default_pack_id
            else:
                logger.warning(f"Pack {self._default_pack_id} not found, creating default")

        # Search for an existing auto-capture pack by name
        for pack in self.service.list_packs():
            if self._field(pack, 'name') == self.DEFAULT_PACK_NAME:
                self._default_pack_id = self._field(pack, 'id')
                self._initialized = True
                logger.info(f"Using existing auto-capture pack: {self._default_pack_id}")
                return self._default_pack_id

        # Create new pack if auto_create enabled
        if self._auto_create_pack:
            pack = self.service.create_pack(
                name=self.DEFAULT_PACK_NAME,
                description=self.DEFAULT_PACK_DESCRIPTION,
                tags=["auto", "coaching"],
                category="learning"
            )
            self._default_pack_id = self._field(pack, 'id')
            self._initialized = True
            logger.info(f"Created auto-capture pack: {self._default_pack_id}")
            return self._default_pack_id

        raise RuntimeError("No default pack available and auto_create_pack is disabled")

    def capture_correction(
        self,
        correction_data: Dict[str, Any],
        session_id: Optional[str] = None,
        workflow_id: Optional[str] = None,
        node_id: Optional[str] = None,
        pack_id: Optional[str] = None
    ) -> Optional[str]:
        """
        Capture a correction and add it to a Correction Pack.

        Args:
            correction_data: Correction dict (from TrainingDataCollector format)
            session_id: Source session ID
            workflow_id: Source workflow ID
            node_id: Source node ID
            pack_id: Specific pack to add to (uses default if None)

        Returns:
            Correction ID if successfully added, None otherwise
        """
        try:
            target_pack_id = pack_id or self._ensure_initialized()

            # Enrich correction data with source info; explicit node_id wins
            enriched = {
                **correction_data,
                'node_id': node_id or correction_data.get('node_id')
            }

            # Explicit arguments take precedence over values embedded in the data
            final_session_id = session_id or correction_data.get('session_id', 'unknown')
            final_workflow_id = workflow_id or correction_data.get('workflow_id')

            # Add to pack
            correction = self.service.add_correction_from_session(
                pack_id=target_pack_id,
                session_correction=enriched,
                session_id=final_session_id,
                workflow_id=final_workflow_id
            )

            if correction:
                correction_id = self._field(correction, 'id')
                logger.info(
                    f"Captured correction {correction_id} from "
                    f"session={final_session_id}, workflow={final_workflow_id}"
                )
                return correction_id

            logger.warning("Failed to add correction to pack")
            return None

        except Exception as e:
            # Capture is best-effort: never let it break the calling workflow
            logger.error(f"Error capturing correction: {e}")
            return None

    def capture_feedback_correction(
        self,
        workflow_id: str,
        execution_id: str,
        corrections: Dict[str, Any],
        context: Optional[Dict[str, Any]] = None
    ) -> Optional[str]:
        """
        Capture a correction from FeedbackProcessor.

        Args:
            workflow_id: Workflow ID
            execution_id: Execution ID (used as session_id)
            corrections: Correction dict from feedback
            context: Additional context (failure_reason, element_type, etc.)

        Returns:
            Correction ID if successfully added
        """
        if not corrections:
            return None

        ctx = context or {}

        # Normalize the feedback payload into the standard correction format
        correction_data = {
            'action_type': corrections.get('action_type', corrections.get('type', 'unknown')),
            'element_type': corrections.get('element_type', ctx.get('element_type', 'unknown')),
            'failure_reason': corrections.get('failure_reason', ctx.get('failure_reason', '')),
            'correction_type': self._infer_correction_type(corrections),
            'original_target': corrections.get('original_target'),
            'corrected_target': corrections.get('corrected_target', corrections.get('new_target')),
            'original_params': corrections.get('original_params'),
            'corrected_params': corrections.get('corrected_params', corrections.get('new_params')),
            'description': corrections.get('description', f"Correction from feedback {execution_id}")
        }

        return self.capture_correction(
            correction_data=correction_data,
            session_id=execution_id,
            workflow_id=workflow_id
        )

    def _infer_correction_type(self, corrections: Dict[str, Any]) -> str:
        """
        Infer the correction type from the correction data.

        An explicit 'correction_type' field wins; otherwise the first
        matching payload key (target, params, timing, coordinates) decides.
        """
        if corrections.get('correction_type'):
            return corrections['correction_type']

        # Infer from data
        if corrections.get('corrected_target') or corrections.get('new_target'):
            return CorrectionType.TARGET_CHANGE.value
        if corrections.get('corrected_params') or corrections.get('new_params'):
            return CorrectionType.PARAMETER_CHANGE.value
        if corrections.get('wait_time') or corrections.get('timing'):
            return CorrectionType.TIMING_ADJUST.value
        if corrections.get('coordinates') or corrections.get('offset'):
            return CorrectionType.COORDINATES_ADJUST.value

        return CorrectionType.OTHER.value

    def create_hook_for_collector(self) -> Callable[[Dict[str, Any]], None]:
        """
        Create a hook function for TrainingDataCollector.

        Returns:
            Callback function that captures corrections

        Usage:
            integration = CorrectionPackIntegration()
            collector.on_correction = integration.create_hook_for_collector()
        """
        def hook(correction: Dict[str, Any]) -> None:
            self.capture_correction(correction)
        return hook

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get statistics about captured corrections.

        Returns:
            Pack statistics dict, or {'error': ...} if the pack could not
            be resolved.
        """
        try:
            pack_id = self._ensure_initialized()
            return self.service.get_pack_statistics(pack_id)
        except Exception as e:
            logger.error(f"Error getting statistics: {e}")
            return {'error': str(e)}
|
||||
|
||||
|
||||
# Module-level singleton shared by the convenience helpers below.
_global_integration: Optional[CorrectionPackIntegration] = None


def get_correction_pack_integration() -> CorrectionPackIntegration:
    """
    Return the process-wide CorrectionPackIntegration instance.

    The instance is created lazily on first call and reused afterwards.

    Returns:
        CorrectionPackIntegration instance
    """
    global _global_integration
    if _global_integration is None:
        _global_integration = CorrectionPackIntegration()
    return _global_integration
|
||||
|
||||
|
||||
def capture_coaching_correction(
    correction_data: Dict[str, Any],
    session_id: Optional[str] = None,
    workflow_id: Optional[str] = None,
    node_id: Optional[str] = None
) -> Optional[str]:
    """
    Convenience wrapper: capture a correction via the shared integration.

    Args:
        correction_data: Correction dict
        session_id: Source session ID
        workflow_id: Source workflow ID
        node_id: Source node ID

    Returns:
        Correction ID if successful
    """
    return get_correction_pack_integration().capture_correction(
        correction_data=correction_data,
        session_id=session_id,
        workflow_id=workflow_id,
        node_id=node_id,
    )
|
||||
176
core/learning/feedback_processor.py
Normal file
176
core/learning/feedback_processor.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""Feedback Processor - Processes user feedback to improve workflows"""
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Lazy import for correction pack integration
|
||||
_correction_integration = None
|
||||
|
||||
class FeedbackType(str, Enum):
    """Possible verdicts a user can give on a workflow execution."""
    CORRECT = "correct"      # execution matched expectations
    INCORRECT = "incorrect"  # execution was wrong
    PARTIAL = "partial"      # some steps right, some wrong
    SKIP = "skip"            # user declined to judge
|
||||
|
||||
@dataclass
class Feedback:
    """A single piece of user feedback about one workflow execution."""
    workflow_id: str                    # workflow that was executed
    execution_id: str                   # specific run being judged
    feedback_type: FeedbackType         # user's verdict
    timestamp: datetime                 # when the feedback was recorded
    confidence_before: float            # confidence prior to this feedback
    user_comment: Optional[str] = None  # optional free-form remark
    corrections: Optional[Dict] = None  # structured corrections, if any
|
||||
|
||||
class FeedbackProcessor:
    """Processes user feedback to improve workflows"""

    def __init__(self, auto_integrate_corrections: bool = True):
        """
        Initialize the feedback processor.

        Args:
            auto_integrate_corrections: Auto-propagate corrections to Correction Packs
        """
        self.feedback_history: List[Feedback] = []
        self._auto_integrate = auto_integrate_corrections
        self._correction_integration = None  # lazy-loaded integration bridge
        logger.info("FeedbackProcessor initialized")

    def process_feedback(
        self,
        workflow_id: str,
        execution_id: str,
        feedback_type: FeedbackType,
        confidence: float,
        comment: Optional[str] = None,
        corrections: Optional[Dict] = None,
        context: Optional[Dict[str, Any]] = None
    ) -> Dict:
        """Process user feedback and return improvement suggestions"""
        entry = Feedback(
            workflow_id=workflow_id,
            execution_id=execution_id,
            feedback_type=feedback_type,
            timestamp=datetime.now(),
            confidence_before=confidence,
            user_comment=comment,
            corrections=corrections
        )
        self.feedback_history.append(entry)

        logger.info(
            f"Feedback processed: workflow={workflow_id}, "
            f"type={feedback_type.value}, confidence={confidence:.2f}"
        )

        # A negative verdict with attached corrections feeds the pack system
        needs_update = feedback_type in (FeedbackType.INCORRECT, FeedbackType.PARTIAL)
        correction_id = (
            self._propagate_corrections(workflow_id, execution_id, corrections, context)
            if corrections and needs_update
            else None
        )

        return {
            'feedback_recorded': True,
            'suggestions': self._generate_suggestions(entry),
            'should_update_workflow': needs_update,
            'correction_pack_id': correction_id  # NOTE: holds a correction ID, not a pack ID
        }

    def _propagate_corrections(
        self,
        workflow_id: str,
        execution_id: str,
        corrections: Dict,
        context: Optional[Dict[str, Any]] = None
    ) -> Optional[str]:
        """Propagate corrections to Correction Pack system."""
        if not self._auto_integrate:
            return None

        try:
            bridge = self._correction_integration
            if bridge is None:
                # Imported lazily so the learning package works without
                # the corrections package installed.
                from core.corrections import get_correction_pack_integration
                bridge = get_correction_pack_integration()
                self._correction_integration = bridge

            return bridge.capture_feedback_correction(
                workflow_id=workflow_id,
                execution_id=execution_id,
                corrections=corrections,
                context=context
            )
        except ImportError:
            logger.debug("Correction pack integration not available")
            return None
        except Exception as e:
            logger.warning(f"Error propagating corrections: {e}")
            return None

    def _generate_suggestions(self, feedback: Feedback) -> List[str]:
        """Generate improvement suggestions based on feedback"""
        verdict = feedback.feedback_type

        if verdict == FeedbackType.INCORRECT:
            tips = [
                "Review target resolution strategy",
                "Check if UI elements changed",
                "Verify action sequence",
            ]
            if feedback.corrections:
                tips.append(f"Apply user corrections: {feedback.corrections}")
            return tips

        if verdict == FeedbackType.PARTIAL:
            return [
                "Some steps succeeded - identify failing step",
                "Consider splitting workflow into smaller parts",
            ]

        if verdict == FeedbackType.CORRECT:
            return ["Workflow performing well - increase confidence"]

        return []

    def get_feedback_stats(self, workflow_id: str) -> Dict:
        """Get feedback statistics for a workflow"""
        # Single pass over the history: count each verdict kind.
        counts = {'correct': 0, 'incorrect': 0, 'partial': 0, 'skip': 0}
        total = 0
        for fb in self.feedback_history:
            if fb.workflow_id != workflow_id:
                continue
            total += 1
            counts[fb.feedback_type.value] += 1

        if total == 0:
            return {
                'total': 0,
                'correct': 0,
                'incorrect': 0,
                'partial': 0,
                'skip': 0,
                'accuracy': 0.0
            }

        return {
            'total': total,
            'correct': counts['correct'],
            'incorrect': counts['incorrect'],
            'partial': counts['partial'],
            'skip': counts['skip'],
            'accuracy': counts['correct'] / total
        }

    def get_recent_feedback(self, workflow_id: str, limit: int = 10) -> List[Feedback]:
        """Get recent feedback for a workflow"""
        matching = (f for f in self.feedback_history if f.workflow_id == workflow_id)
        return sorted(matching, key=lambda f: f.timestamp, reverse=True)[:limit]
|
||||
246
core/training/training_data_collector.py
Normal file
246
core/training/training_data_collector.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""Training Data Collector - Collect clean training data during real usage"""
|
||||
import logging
|
||||
import json
|
||||
from typing import List, Dict, Any, Optional, Callable
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type for correction callback
|
||||
CorrectionCallback = Callable[[Dict[str, Any], str, Optional[str]], None]
|
||||
|
||||
@dataclass
class TrainingSession:
    """One recorded usage session: captured inputs plus the user's corrections."""
    session_id: str
    workflow_id: Optional[str]
    timestamp: datetime
    screenshots: List[str]         # paths to screenshot files
    actions: List[Dict[str, Any]]  # recorded actions, in order
    embeddings: List[str]          # paths to embedding files
    success: bool                  # whether the session completed successfully
    user_corrections: List[Dict[str, Any]] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@dataclass
class WorkflowPattern:
    """A recurring workflow shape detected across training sessions."""
    pattern_id: str
    name: str
    frequency: int         # how many times the pattern was observed
    avg_confidence: float  # mean confidence across occurrences
    states: List[str]      # state IDs making up the pattern
    actions: List[str]     # action types making up the pattern
    success_rate: float    # fraction of occurrences that succeeded
|
||||
|
||||
class TrainingDataCollector:
    """Collect structured training data during real usage"""

    def __init__(
        self,
        output_dir: str = "training_data",
        correction_callback: Optional[CorrectionCallback] = None,
        auto_integrate_corrections: bool = True
    ):
        """
        Initialize the training data collector.

        Args:
            output_dir: Directory for storing training data
            correction_callback: Optional callback for correction propagation
            auto_integrate_corrections: Auto-enable correction pack integration
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.sessions: List[TrainingSession] = []
        self.patterns: List[WorkflowPattern] = []
        self.current_session: Optional[TrainingSession] = None

        # Correction pack integration
        self._correction_callback = correction_callback
        self._auto_integrate = auto_integrate_corrections
        self._correction_integration = None  # lazy-loaded integration bridge

        logger.info(f"TrainingDataCollector initialized (output_dir={output_dir})")

    def start_session(self, session_id: str, workflow_id: Optional[str] = None) -> None:
        """Start recording a new training session"""
        self.current_session = TrainingSession(
            session_id=session_id,
            workflow_id=workflow_id,
            timestamp=datetime.now(),
            screenshots=[],
            actions=[],
            embeddings=[],
            success=False
        )
        logger.info(f"Started training session: {session_id}")

    def record_screenshot(self, screenshot_path: str) -> None:
        """Record a screenshot (no-op when no session is active)"""
        if self.current_session:
            self.current_session.screenshots.append(screenshot_path)

    def record_action(self, action: Dict[str, Any]) -> None:
        """Record an action, stamped with the current time"""
        if self.current_session:
            self.current_session.actions.append({
                **action,
                'timestamp': datetime.now().isoformat()
            })

    def record_embedding(self, embedding_path: str) -> None:
        """Record an embedding file path"""
        if self.current_session:
            self.current_session.embeddings.append(embedding_path)

    def record_correction(self, correction: Dict[str, Any]) -> None:
        """Record a user correction and propagate to Correction Packs."""
        if self.current_session:
            timestamped_correction = {
                **correction,
                'timestamp': datetime.now().isoformat()
            }
            self.current_session.user_corrections.append(timestamped_correction)
            logger.info(f"User correction recorded: {correction.get('type', 'unknown')}")

            # Propagate to Correction Packs
            self._propagate_to_correction_packs(timestamped_correction)

    def _propagate_to_correction_packs(self, correction: Dict[str, Any]) -> None:
        """
        Propagate a correction to the Correction Pack system.

        An explicit callback (if configured) takes precedence over the
        built-in auto-integration; propagation is best-effort and never
        raises into the caller.
        """
        # Session context is the same for both propagation paths
        session_id = self.current_session.session_id if self.current_session else None
        workflow_id = self.current_session.workflow_id if self.current_session else None

        # Use explicit callback if provided
        if self._correction_callback:
            try:
                self._correction_callback(correction, session_id, workflow_id)
            except Exception as e:
                logger.warning(f"Error in correction callback: {e}")

        # Auto-integrate if enabled
        elif self._auto_integrate:
            try:
                if self._correction_integration is None:
                    # Imported lazily so training works without the
                    # corrections package installed.
                    from core.corrections import get_correction_pack_integration
                    self._correction_integration = get_correction_pack_integration()

                self._correction_integration.capture_correction(
                    correction_data=correction,
                    session_id=session_id,
                    workflow_id=workflow_id
                )
            except ImportError:
                logger.debug("Correction pack integration not available")
            except Exception as e:
                logger.warning(f"Error propagating to correction packs: {e}")

    def end_session(self, success: bool, metadata: Optional[Dict] = None) -> None:
        """End current session, persist it to disk, and clear it"""
        if not self.current_session:
            logger.warning("No active session to end")
            return

        self.current_session.success = success
        if metadata:
            self.current_session.metadata.update(metadata)

        self.sessions.append(self.current_session)

        # Save session to disk
        self._save_session(self.current_session)

        logger.info(
            f"Session ended: {self.current_session.session_id} "
            f"(success={success}, actions={len(self.current_session.actions)})"
        )

        self.current_session = None

    def _save_session(self, session: TrainingSession) -> None:
        """Save one session to its own JSON file"""
        session_file = self.output_dir / f"session_{session.session_id}.json"

        # Reuse the canonical serializer instead of duplicating the
        # asdict + timestamp conversion here.
        with open(session_file, 'w', encoding='utf-8') as f:
            json.dump(self._serialize_session(session), f, indent=2, ensure_ascii=False)

    def add_pattern(self, pattern: WorkflowPattern) -> None:
        """Add a detected workflow pattern"""
        self.patterns.append(pattern)
        logger.info(f"Pattern added: {pattern.name} (frequency={pattern.frequency})")

    def export_training_set(self, output_file: str = "training_set.json") -> Dict[str, Any]:
        """
        Export the complete training dataset to a JSON file.

        Args:
            output_file: File name, created inside output_dir

        Returns:
            The exported dataset as a dict
        """
        training_set = {
            'metadata': {
                'export_date': datetime.now().isoformat(),
                'total_sessions': len(self.sessions),
                'total_patterns': len(self.patterns),
                'success_rate': self._calculate_success_rate()
            },
            'sessions': [self._serialize_session(s) for s in self.sessions],
            'patterns': [asdict(p) for p in self.patterns],
            'statistics': self._calculate_statistics()
        }

        output_path = self.output_dir / output_file
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(training_set, f, indent=2, ensure_ascii=False)

        logger.info(f"Training set exported: {output_path} ({len(self.sessions)} sessions)")
        return training_set

    def _calculate_success_rate(self) -> float:
        """Calculate overall success rate (0.0 when no sessions exist)"""
        if not self.sessions:
            return 0.0
        successful = sum(1 for s in self.sessions if s.success)
        return successful / len(self.sessions)

    def _calculate_statistics(self) -> Dict[str, Any]:
        """Calculate training data statistics ({} when no sessions exist)"""
        if not self.sessions:
            return {}

        total_actions = sum(len(s.actions) for s in self.sessions)
        total_corrections = sum(len(s.user_corrections) for s in self.sessions)

        return {
            'total_sessions': len(self.sessions),
            'successful_sessions': sum(1 for s in self.sessions if s.success),
            'total_actions': total_actions,
            'total_corrections': total_corrections,
            'avg_actions_per_session': total_actions / len(self.sessions),
            'correction_rate': total_corrections / total_actions if total_actions > 0 else 0
        }

    def get_sessions_by_workflow(self, workflow_id: str) -> List[TrainingSession]:
        """Get all sessions for a specific workflow"""
        return [s for s in self.sessions if s.workflow_id == workflow_id]

    def get_successful_sessions(self) -> List[TrainingSession]:
        """Get only successful sessions"""
        return [s for s in self.sessions if s.success]

    def clear_data(self) -> None:
        """Clear all collected data"""
        self.sessions.clear()
        self.patterns.clear()
        self.current_session = None
        logger.info("Training data cleared")

    def _serialize_session(self, session: TrainingSession) -> Dict:
        """Serialize a session to a plain dict, ISO-formatting the timestamp"""
        data = asdict(session)
        data['timestamp'] = session.timestamp.isoformat()
        return data
|
||||
Reference in New Issue
Block a user