feat(corrections): Add automatic COACHING integration for Correction Packs
- Add CorrectionPackIntegration class to bridge learning components
- Modify TrainingDataCollector to auto-propagate corrections to packs
- Modify FeedbackProcessor to capture corrections on INCORRECT/PARTIAL feedback
- Add convenience functions: get_correction_pack_integration(), capture_coaching_correction()
- Add 19 integration tests (all passing)

Corrections made during COACHING mode are now automatically captured into a dedicated "auto_captured_corrections" pack for cross-workflow reuse.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
246
core/training/training_data_collector.py
Normal file
246
core/training/training_data_collector.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""Training Data Collector - Collect clean training data during real usage"""
|
||||
import logging
|
||||
import json
|
||||
from typing import List, Dict, Any, Optional, Callable
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type for correction callback
|
||||
CorrectionCallback = Callable[[Dict[str, Any], str, Optional[str]], None]
|
||||
|
||||
@dataclass
class TrainingSession:
    """Single training session record.

    Created by TrainingDataCollector.start_session(), filled in by the
    record_* methods, and persisted to a JSON file by end_session().
    """
    # Caller-chosen unique identifier, also used in the saved filename.
    session_id: str
    # Workflow this session belongs to, if any.
    workflow_id: Optional[str]
    # Session start time (naive local time from datetime.now()).
    timestamp: datetime
    screenshots: List[str]  # Paths to screenshots
    # Recorded actions; the collector adds an ISO-format 'timestamp' key to each.
    actions: List[Dict[str, Any]]
    embeddings: List[str]  # Paths to embedding files
    # Outcome flag; starts False and is set when the session ends.
    success: bool
    # User corrections recorded during the session, each stamped with an ISO timestamp.
    user_corrections: List[Dict[str, Any]] = field(default_factory=list)
    # Free-form extra data merged in at end_session().
    metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@dataclass
class WorkflowPattern:
    """Detected workflow pattern.

    Aggregate statistics for a recurring state/action sequence, registered
    via TrainingDataCollector.add_pattern() and exported with the training set.
    """
    # Unique identifier for the pattern.
    pattern_id: str
    # Human-readable name (used in logs).
    name: str
    # How many times this pattern was observed.
    frequency: int
    # Mean confidence across observations of this pattern.
    avg_confidence: float
    states: List[str]  # State IDs
    actions: List[str]  # Action types
    # Fraction of occurrences that succeeded (presumably 0.0-1.0; not enforced here).
    success_rate: float
|
||||
|
||||
class TrainingDataCollector:
    """Collect structured training data during real usage.

    Records screenshots, actions, embeddings and user corrections into a
    current TrainingSession, persists each finished session to JSON under
    ``output_dir``, and can export all sessions plus detected workflow
    patterns as a single training-set file. Corrections are optionally
    propagated to the Correction Pack system, either through an explicit
    callback or via lazy auto-integration.
    """

    def __init__(
        self,
        output_dir: str = "training_data",
        correction_callback: Optional[CorrectionCallback] = None,
        auto_integrate_corrections: bool = True
    ):
        """
        Initialize the training data collector.

        Args:
            output_dir: Directory for storing training data (created if missing)
            correction_callback: Optional callback for correction propagation;
                when provided it takes precedence over auto-integration
            auto_integrate_corrections: Auto-enable correction pack integration
                when no explicit callback is set
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.sessions: List[TrainingSession] = []
        self.patterns: List[WorkflowPattern] = []
        self.current_session: Optional[TrainingSession] = None

        # Correction pack integration. The integration object is resolved
        # lazily on first use so the import cost (and possible ImportError)
        # is deferred until a correction is actually recorded.
        self._correction_callback = correction_callback
        self._auto_integrate = auto_integrate_corrections
        self._correction_integration = None

        logger.info("TrainingDataCollector initialized (output_dir=%s)", output_dir)

    def start_session(self, session_id: str, workflow_id: Optional[str] = None) -> None:
        """Start recording a new training session.

        Note: any session already in progress is replaced without being
        saved; call end_session() first to persist it.
        """
        self.current_session = TrainingSession(
            session_id=session_id,
            workflow_id=workflow_id,
            timestamp=datetime.now(),
            screenshots=[],
            actions=[],
            embeddings=[],
            success=False
        )
        logger.info("Started training session: %s", session_id)

    def record_screenshot(self, screenshot_path: str) -> None:
        """Record a screenshot path (no-op when no session is active)."""
        if self.current_session:
            self.current_session.screenshots.append(screenshot_path)

    def record_action(self, action: Dict[str, Any]) -> None:
        """Record an action, stamping it with the current ISO timestamp.

        No-op when no session is active.
        """
        if self.current_session:
            self.current_session.actions.append({
                **action,
                'timestamp': datetime.now().isoformat()
            })

    def record_embedding(self, embedding_path: str) -> None:
        """Record an embedding file path (no-op when no session is active)."""
        if self.current_session:
            self.current_session.embeddings.append(embedding_path)

    def record_correction(self, correction: Dict[str, Any]) -> None:
        """Record a user correction and propagate to Correction Packs.

        The correction dict is copied with an ISO 'timestamp' key added.
        No-op when no session is active.
        """
        if self.current_session:
            timestamped_correction = {
                **correction,
                'timestamp': datetime.now().isoformat()
            }
            self.current_session.user_corrections.append(timestamped_correction)
            logger.info("User correction recorded: %s", correction.get('type', 'unknown'))

            # Propagate to Correction Packs (best-effort; never raises).
            self._propagate_to_correction_packs(timestamped_correction)

    def _propagate_to_correction_packs(self, correction: Dict[str, Any]) -> None:
        """Propagate a correction to the Correction Pack system.

        An explicit callback takes precedence over auto-integration; the
        integration singleton is imported lazily on first use. Every
        failure path is logged rather than raised, so recording training
        data never breaks on an integration error.
        """
        session_id = self.current_session.session_id if self.current_session else None
        workflow_id = self.current_session.workflow_id if self.current_session else None

        if self._correction_callback:
            try:
                self._correction_callback(correction, session_id, workflow_id)
            except Exception as e:
                logger.warning("Error in correction callback: %s", e)
        elif self._auto_integrate:
            try:
                if self._correction_integration is None:
                    from core.corrections import get_correction_pack_integration
                    self._correction_integration = get_correction_pack_integration()

                self._correction_integration.capture_correction(
                    correction_data=correction,
                    session_id=session_id,
                    workflow_id=workflow_id
                )
            except ImportError:
                # Integration module is optional; degrade silently.
                logger.debug("Correction pack integration not available")
            except Exception as e:
                logger.warning("Error propagating to correction packs: %s", e)

    def end_session(self, success: bool, metadata: Optional[Dict] = None) -> None:
        """End the current session, record its outcome, and save it to disk.

        Args:
            success: Whether the session completed successfully
            metadata: Optional extra data merged into the session's metadata
        """
        if not self.current_session:
            logger.warning("No active session to end")
            return

        self.current_session.success = success
        if metadata:
            self.current_session.metadata.update(metadata)

        self.sessions.append(self.current_session)

        # Save session to disk
        self._save_session(self.current_session)

        logger.info(
            "Session ended: %s (success=%s, actions=%d)",
            self.current_session.session_id,
            success,
            len(self.current_session.actions)
        )

        self.current_session = None

    def _save_session(self, session: TrainingSession) -> None:
        """Persist a single session as JSON under the output directory."""
        session_file = self.output_dir / f"session_{session.session_id}.json"

        # Reuse the shared serializer so datetime handling stays consistent
        # between per-session files and the exported training set.
        with open(session_file, 'w', encoding='utf-8') as f:
            json.dump(self._serialize_session(session), f, indent=2, ensure_ascii=False)

    def add_pattern(self, pattern: WorkflowPattern) -> None:
        """Add a detected workflow pattern."""
        self.patterns.append(pattern)
        logger.info("Pattern added: %s (frequency=%d)", pattern.name, pattern.frequency)

    def export_training_set(self, output_file: str = "training_set.json") -> Dict[str, Any]:
        """Export the complete training dataset to a JSON file.

        Args:
            output_file: Filename (relative to output_dir) for the export

        Returns:
            The exported training set as a dict (metadata, sessions,
            patterns, statistics).
        """
        training_set = {
            'metadata': {
                'export_date': datetime.now().isoformat(),
                'total_sessions': len(self.sessions),
                'total_patterns': len(self.patterns),
                'success_rate': self._calculate_success_rate()
            },
            'sessions': [self._serialize_session(s) for s in self.sessions],
            'patterns': [asdict(p) for p in self.patterns],
            'statistics': self._calculate_statistics()
        }

        output_path = self.output_dir / output_file
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(training_set, f, indent=2, ensure_ascii=False)

        logger.info("Training set exported: %s (%d sessions)", output_path, len(self.sessions))
        return training_set

    def _calculate_success_rate(self) -> float:
        """Return the fraction of ended sessions marked successful (0.0 if none)."""
        if not self.sessions:
            return 0.0
        successful = sum(1 for s in self.sessions if s.success)
        return successful / len(self.sessions)

    def _calculate_statistics(self) -> Dict[str, Any]:
        """Return aggregate statistics over ended sessions ({} if none)."""
        if not self.sessions:
            return {}

        total_actions = sum(len(s.actions) for s in self.sessions)
        total_corrections = sum(len(s.user_corrections) for s in self.sessions)

        return {
            'total_sessions': len(self.sessions),
            'successful_sessions': sum(1 for s in self.sessions if s.success),
            'total_actions': total_actions,
            'total_corrections': total_corrections,
            'avg_actions_per_session': total_actions / len(self.sessions),
            # Guard against division by zero when no actions were recorded.
            'correction_rate': total_corrections / total_actions if total_actions > 0 else 0
        }

    def get_sessions_by_workflow(self, workflow_id: str) -> List[TrainingSession]:
        """Get all ended sessions for a specific workflow."""
        return [s for s in self.sessions if s.workflow_id == workflow_id]

    def get_successful_sessions(self) -> List[TrainingSession]:
        """Get only the ended sessions marked successful."""
        return [s for s in self.sessions if s.success]

    def clear_data(self) -> None:
        """Clear all collected sessions, patterns and the active session.

        Files already written to output_dir are NOT removed.
        """
        self.sessions.clear()
        self.patterns.clear()
        self.current_session = None
        logger.info("Training data cleared")

    def _serialize_session(self, session: TrainingSession) -> Dict:
        """Serialize a session to a plain dict, converting its datetime to ISO format."""
        data = asdict(session)
        data['timestamp'] = session.timestamp.isoformat()
        return data
|
||||
Reference in New Issue
Block a user