feat(corrections): Add automatic COACHING integration for Correction Packs

- Add CorrectionPackIntegration class to bridge learning components
- Modify TrainingDataCollector to auto-propagate corrections to packs
- Modify FeedbackProcessor to capture corrections on INCORRECT/PARTIAL feedback
- Add convenience functions: get_correction_pack_integration(), capture_coaching_correction()
- Add 19 integration tests (all passing)

Corrections made during COACHING mode are now automatically captured
into a dedicated "auto_captured_corrections" pack for cross-workflow reuse.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Dom
2026-01-18 19:06:09 +01:00
parent d8756883c5
commit efb184fdb9
5 changed files with 1206 additions and 1 deletions

View File

@@ -0,0 +1,246 @@
"""Training Data Collector - Collect clean training data during real usage"""
import logging
import json
from typing import List, Dict, Any, Optional, Callable
from dataclasses import dataclass, field, asdict
from datetime import datetime
from pathlib import Path
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Type of a correction callback. The collector invokes it as
# callback(correction, session_id, workflow_id) and ignores the return value.
# NOTE(review): the second argument is declared as str, but the collector's
# propagation helper passes None when no session is active - confirm whether
# Optional[str] is intended.
CorrectionCallback = Callable[[Dict[str, Any], str, Optional[str]], None]
@dataclass
class TrainingSession:
    """Record of a single training session.

    Attributes:
        session_id: Identifier of the session.
        workflow_id: Identifier of the associated workflow, or None.
        timestamp: Timestamp attached to the session.
        screenshots: Paths to screenshots.
        actions: Action dictionaries recorded for the session.
        embeddings: Paths to embedding files.
        success: Whether the session is marked as successful.
        user_corrections: Correction dictionaries recorded for the session;
            defaults to a fresh empty list per instance.
        metadata: Free-form metadata; defaults to a fresh empty dict per
            instance.
    """

    session_id: str
    workflow_id: Optional[str]
    timestamp: datetime
    screenshots: List[str]
    actions: List[Dict[str, Any]]
    embeddings: List[str]
    success: bool
    # default_factory gives every instance its own container instead of a
    # shared mutable default.
    user_corrections: List[Dict[str, Any]] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class WorkflowPattern:
    """Detected workflow pattern.

    Attributes:
        pattern_id: Identifier of the pattern.
        name: Name of the pattern.
        frequency: Frequency recorded for the pattern.
        avg_confidence: Average confidence recorded for the pattern.
        states: State IDs that make up the pattern.
        actions: Action types that make up the pattern.
        success_rate: Success rate recorded for the pattern.
    """

    pattern_id: str
    name: str
    frequency: int
    avg_confidence: float
    states: List[str]
    actions: List[str]
    success_rate: float
class TrainingDataCollector:
    """Collect structured training data during real usage.

    The collector tracks at most one active session (``current_session``).
    Ended sessions are kept in ``sessions`` and written to ``output_dir`` as
    one JSON file each. Corrections recorded during a session are forwarded
    either to an explicit callback or, when enabled, to the correction pack
    integration.

    Project-local types are written as string annotations so that defining
    this class does not depend on definition order within the module.
    """

    def __init__(
        self,
        output_dir: str = "training_data",
        correction_callback: Optional["CorrectionCallback"] = None,
        auto_integrate_corrections: bool = True
    ):
        """
        Initialize the training data collector.

        Creates ``output_dir`` (including parents) if it does not exist.

        Args:
            output_dir: Directory for storing training data
            correction_callback: Optional callback for correction propagation;
                when set it takes precedence over auto-integration
            auto_integrate_corrections: Auto-enable correction pack integration
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.sessions: List["TrainingSession"] = []
        self.patterns: List["WorkflowPattern"] = []
        self.current_session: Optional["TrainingSession"] = None

        # Correction pack integration
        self._correction_callback = correction_callback
        self._auto_integrate = auto_integrate_corrections
        # Resolved lazily on the first propagated correction.
        self._correction_integration = None

        logger.info(f"TrainingDataCollector initialized (output_dir={output_dir})")

    def start_session(self, session_id: str, workflow_id: Optional[str] = None) -> None:
        """Start recording a new training session.

        Any session that is still active is replaced; a warning is logged in
        that case because its data has not been saved.
        """
        if self.current_session is not None:
            logger.warning(
                f"Starting session {session_id} while session "
                f"{self.current_session.session_id} is still active; "
                "the unsaved session is discarded"
            )

        self.current_session = TrainingSession(
            session_id=session_id,
            workflow_id=workflow_id,
            timestamp=datetime.now(),
            screenshots=[],
            actions=[],
            embeddings=[],
            success=False
        )
        logger.info(f"Started training session: {session_id}")

    def record_screenshot(self, screenshot_path: str) -> None:
        """Record a screenshot path; ignored when no session is active."""
        if self.current_session:
            self.current_session.screenshots.append(screenshot_path)

    def record_action(self, action: Dict[str, Any]) -> None:
        """Record an action with a timestamp; ignored when no session is active.

        The caller's dict is copied, not mutated; a ``'timestamp'`` key in
        ``action`` is overwritten by the recording time.
        """
        if self.current_session:
            self.current_session.actions.append({
                **action,
                'timestamp': datetime.now().isoformat()
            })

    def record_embedding(self, embedding_path: str) -> None:
        """Record an embedding file path; ignored when no session is active."""
        if self.current_session:
            self.current_session.embeddings.append(embedding_path)

    def record_correction(self, correction: Dict[str, Any]) -> None:
        """Record a user correction and propagate to Correction Packs.

        The correction is copied with a ``'timestamp'`` key added, stored on
        the active session and then propagated. Without an active session the
        correction is neither stored nor propagated and a warning is logged.
        """
        if not self.current_session:
            logger.warning("No active session; correction ignored")
            return

        timestamped_correction = {
            **correction,
            'timestamp': datetime.now().isoformat()
        }
        self.current_session.user_corrections.append(timestamped_correction)
        logger.info(f"User correction recorded: {correction.get('type', 'unknown')}")

        # Propagate to Correction Packs
        self._propagate_to_correction_packs(timestamped_correction)

    def _propagate_to_correction_packs(self, correction: Dict[str, Any]) -> None:
        """Propagate correction to Correction Pack system.

        An explicit callback, when configured, is used exclusively. Otherwise
        the correction pack integration is used if auto-integration is
        enabled. Failures are logged and never raised to the caller.
        """
        session = self.current_session
        session_id = session.session_id if session else None
        workflow_id = session.workflow_id if session else None

        # Use explicit callback if provided
        if self._correction_callback:
            try:
                self._correction_callback(correction, session_id, workflow_id)
            except Exception as e:
                logger.warning(f"Error in correction callback: {e}")
            return

        # Auto-integrate if enabled
        if not self._auto_integrate:
            return

        try:
            if self._correction_integration is None:
                # Imported lazily so the collector works without the
                # corrections package.
                from core.corrections import get_correction_pack_integration
                self._correction_integration = get_correction_pack_integration()

            self._correction_integration.capture_correction(
                correction_data=correction,
                session_id=session_id,
                workflow_id=workflow_id
            )
        except ImportError:
            logger.debug("Correction pack integration not available")
        except Exception as e:
            logger.warning(f"Error propagating to correction packs: {e}")

    def end_session(self, success: bool, metadata: Optional[Dict] = None) -> None:
        """End current session and save.

        Args:
            success: Outcome to store on the session
            metadata: Optional entries merged into the session metadata

        Raises:
            Whatever writing the session file raises (for example ``OSError``
            or ``TypeError`` for values JSON cannot encode). The session is
            already in ``sessions`` and no longer active at that point.
        """
        session = self.current_session
        if not session:
            logger.warning("No active session to end")
            return

        session.success = success
        if metadata:
            session.metadata.update(metadata)

        self.sessions.append(session)
        # Deactivate before touching the disk: if saving fails, a retry must
        # not append the same session a second time.
        self.current_session = None

        # Save session to disk
        self._save_session(session)

        logger.info(
            f"Session ended: {session.session_id} "
            f"(success={success}, actions={len(session.actions)})"
        )

    def _save_session(self, session: "TrainingSession") -> None:
        """Save session to ``session_<session_id>.json`` in ``output_dir``."""
        session_file = self.output_dir / f"session_{session.session_id}.json"

        with open(session_file, 'w', encoding='utf-8') as f:
            json.dump(self._serialize_session(session), f, indent=2)

    def add_pattern(self, pattern: "WorkflowPattern") -> None:
        """Add a detected workflow pattern."""
        self.patterns.append(pattern)
        logger.info(f"Pattern added: {pattern.name} (frequency={pattern.frequency})")

    def export_training_set(self, output_file: str = "training_set.json") -> Dict[str, Any]:
        """Export complete training dataset.

        Args:
            output_file: File name, resolved relative to ``output_dir``

        Returns:
            The exported dictionary with the keys ``'metadata'``,
            ``'sessions'``, ``'patterns'`` and ``'statistics'``.
        """
        training_set = {
            'metadata': {
                'export_date': datetime.now().isoformat(),
                'total_sessions': len(self.sessions),
                'total_patterns': len(self.patterns),
                'success_rate': self._calculate_success_rate()
            },
            'sessions': [self._serialize_session(s) for s in self.sessions],
            'patterns': [asdict(p) for p in self.patterns],
            'statistics': self._calculate_statistics()
        }

        output_path = self.output_dir / output_file
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(training_set, f, indent=2)

        logger.info(f"Training set exported: {output_path} ({len(self.sessions)} sessions)")
        return training_set

    def _calculate_success_rate(self) -> float:
        """Calculate overall success rate; 0.0 when there are no sessions."""
        if not self.sessions:
            return 0.0
        successful = sum(1 for s in self.sessions if s.success)
        return successful / len(self.sessions)

    def _calculate_statistics(self) -> Dict[str, Any]:
        """Calculate training data statistics; empty dict without sessions."""
        if not self.sessions:
            return {}

        total_actions = sum(len(s.actions) for s in self.sessions)
        total_corrections = sum(len(s.user_corrections) for s in self.sessions)

        return {
            'total_sessions': len(self.sessions),
            'successful_sessions': sum(1 for s in self.sessions if s.success),
            'total_actions': total_actions,
            'total_corrections': total_corrections,
            'avg_actions_per_session': total_actions / len(self.sessions),
            # Float fallback keeps the value's type independent of the data.
            'correction_rate': total_corrections / total_actions if total_actions > 0 else 0.0
        }

    def get_sessions_by_workflow(self, workflow_id: str) -> List["TrainingSession"]:
        """Get all sessions for a specific workflow."""
        return [s for s in self.sessions if s.workflow_id == workflow_id]

    def get_successful_sessions(self) -> List["TrainingSession"]:
        """Get only successful sessions."""
        return [s for s in self.sessions if s.success]

    def clear_data(self) -> None:
        """Clear all collected in-memory data, including the active session.

        Files already written to ``output_dir`` are not touched.
        """
        self.sessions.clear()
        self.patterns.clear()
        self.current_session = None
        logger.info("Training data cleared")

    def _serialize_session(self, session: "TrainingSession") -> Dict:
        """Serialize session to a dict with the timestamp as an ISO string."""
        data = asdict(session)
        data['timestamp'] = session.timestamp.isoformat()
        return data