feat(corrections): Add Correction Packs system for cross-workflow learning

Implement a complete system for capturing and reusing user corrections across
multiple workflows and sessions. This enables automatic application of learned
fixes when similar failures occur in different contexts.

New components:
- core/corrections/models.py: CorrectionKey, Correction, CorrectionPack models
- core/corrections/correction_repository.py: JSON storage with atomic writes
- core/corrections/aggregator.py: Aggregation by hash and quality filtering
- core/corrections/correction_pack_service.py: CRUD, export/import, versioning
- backend/api/correction_packs.py: REST API with 15 endpoints

Features:
- MD5-based key hashing for correction deduplication
- Export/import in JSON and YAML formats
- Version history with rollback support
- Cross-workflow pattern detection
- Integration with SelfHealingEngine for automatic application
- 29 unit tests (all passing)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Dom
2026-01-18 18:48:35 +01:00
parent fa57ecdbfd
commit d8756883c5
9 changed files with 4411 additions and 0 deletions

View File

@@ -0,0 +1,30 @@
"""
Correction Packs System - Cross-workflow correction capitalization.
This module provides a system for storing, aggregating, and applying
user corrections across multiple workflows and sessions.
"""
from .models import (
CorrectionKey,
CorrectionType,
CorrectionStatus,
Correction,
CorrectionPack,
CorrectionPackMetadata
)
from .correction_repository import CorrectionRepository
from .aggregator import CorrectionAggregator
from .correction_pack_service import CorrectionPackService
__all__ = [
'CorrectionKey',
'CorrectionType',
'CorrectionStatus',
'Correction',
'CorrectionPack',
'CorrectionPackMetadata',
'CorrectionRepository',
'CorrectionAggregator',
'CorrectionPackService'
]

View File

@@ -0,0 +1,545 @@
"""
Correction Aggregator - Aggregate corrections from multiple sessions.
Groups corrections by key hash and filters by quality thresholds.
"""
import logging
from collections import defaultdict
from datetime import datetime
from typing import Dict, List, Optional, Any, Tuple
from .models import (
Correction,
CorrectionKey,
CorrectionType,
CorrectionSource,
CorrectionStatus,
generate_correction_id
)
logger = logging.getLogger(__name__)
class CorrectionAggregator:
    """
    Aggregates corrections from multiple sources (sessions, workflows).

    Corrections are grouped by their ``CorrectionKey`` hash, duplicate groups
    are merged according to a configurable strategy, and low-quality entries
    are filtered out using the thresholds supplied at construction time.
    """

    def __init__(
        self,
        min_occurrences: int = 2,
        min_success_rate: float = 0.5,
        min_confidence: float = 0.3
    ):
        """
        Initialize the aggregator.

        Args:
            min_occurrences: Minimum number of recorded applications a
                correction must have to pass filtering. Corrections that have
                never been applied are always kept (see _passes_thresholds).
            min_success_rate: Minimum success rate to include a correction.
            min_confidence: Minimum confidence score to include a correction.
        """
        self.min_occurrences = min_occurrences
        self.min_success_rate = min_success_rate
        self.min_confidence = min_confidence

    def aggregate_corrections(
        self,
        corrections: List[Correction],
        merge_strategy: str = 'weighted_average'
    ) -> List[Correction]:
        """
        Aggregate a list of corrections by key hash.

        Args:
            corrections: List of corrections to aggregate.
            merge_strategy: How to merge duplicate corrections:
                - 'weighted_average': sum statistics, most-used entry as base
                - 'best_confidence': keep highest confidence
                - 'most_recent': keep most recently used

        Returns:
            Aggregated corrections that pass the quality thresholds.
        """
        # Group by key hash so corrections for the same failure merge together.
        grouped: Dict[str, List[Correction]] = defaultdict(list)
        for correction in corrections:
            grouped[correction.key.to_hash()].append(correction)

        aggregated: List[Correction] = []
        for group in grouped.values():
            if len(group) == 1:
                # Single occurrence: nothing to merge, only threshold check.
                if self._passes_thresholds(group[0]):
                    aggregated.append(group[0])
            else:
                merged = self._merge_corrections(group, merge_strategy)
                if merged and self._passes_thresholds(merged):
                    aggregated.append(merged)
        return aggregated

    def aggregate_from_sessions(
        self,
        session_corrections: List[Dict[str, Any]],
        workflow_id: Optional[str] = None
    ) -> List[Correction]:
        """
        Aggregate corrections from session data (user_corrections format).

        Args:
            session_corrections: Correction dicts from TrainingSession.
            workflow_id: If given, only corrections whose source matches this
                workflow are kept.

        Returns:
            List of aggregated Correction objects.
        """
        corrections: List[Correction] = []
        for sc in session_corrections:
            try:
                correction = self._session_correction_to_correction(sc)
                if correction:
                    # Filter by workflow if one was specified.
                    if workflow_id is None or (correction.source and correction.source.workflow_id == workflow_id):
                        corrections.append(correction)
            except Exception as e:
                # Conversion is best-effort: skip malformed session entries.
                logger.warning(f"Error converting session correction: {e}")
        return self.aggregate_corrections(corrections)

    def group_by_type(
        self,
        corrections: List[Correction]
    ) -> Dict[CorrectionType, List[Correction]]:
        """
        Group corrections by their CorrectionType.

        Returns:
            Dict mapping correction type to the corrections of that type.
        """
        grouped: Dict[CorrectionType, List[Correction]] = defaultdict(list)
        for correction in corrections:
            grouped[correction.correction_type].append(correction)
        return dict(grouped)

    def group_by_action_type(
        self,
        corrections: List[Correction]
    ) -> Dict[str, List[Correction]]:
        """
        Group corrections by the action type (key.action_type) they apply to.

        Returns:
            Dict mapping action type to the corrections for that action.
        """
        grouped: Dict[str, List[Correction]] = defaultdict(list)
        for correction in corrections:
            grouped[correction.key.action_type].append(correction)
        return dict(grouped)

    def filter_by_quality(
        self,
        corrections: List[Correction],
        min_success_rate: Optional[float] = None,
        min_confidence: Optional[float] = None,
        min_applications: Optional[int] = None
    ) -> List[Correction]:
        """
        Filter corrections by quality thresholds.

        Args:
            corrections: Corrections to filter.
            min_success_rate: Minimum success rate (defaults to the
                instance-level threshold).
            min_confidence: Minimum confidence (defaults to the
                instance-level threshold).
            min_applications: Minimum number of applications (defaults to 1,
                so never-applied corrections are excluded by this method).

        Returns:
            Filtered list of corrections.
        """
        min_sr = self.min_success_rate if min_success_rate is None else min_success_rate
        min_conf = self.min_confidence if min_confidence is None else min_confidence
        min_apps = 1 if min_applications is None else min_applications
        return [
            c for c in corrections
            if c.success_rate >= min_sr
            and c.confidence_score >= min_conf
            and (c.success_count + c.failure_count) >= min_apps
        ]

    def rank_corrections(
        self,
        corrections: List[Correction],
        sort_by: str = 'confidence'
    ) -> List[Correction]:
        """
        Rank corrections by the given criterion (descending).

        Args:
            corrections: Corrections to rank.
            sort_by: 'confidence', 'success_rate', 'usage' (total
                applications) or 'recent' (last application date).

        Returns:
            New sorted list; the input list unchanged for an unknown criterion.
        """
        key_funcs = {
            'confidence': lambda c: c.confidence_score,
            'success_rate': lambda c: c.success_rate,
            'usage': lambda c: c.success_count + c.failure_count,
            # datetime.min sorts never-applied corrections last.
            'recent': lambda c: c.last_applied or datetime.min,
        }
        key_func = key_funcs.get(sort_by)
        if key_func is None:
            return corrections
        return sorted(corrections, key=key_func, reverse=True)

    def calculate_statistics(
        self,
        corrections: List[Correction]
    ) -> Dict[str, Any]:
        """
        Calculate aggregate statistics for a list of corrections.

        Returns:
            Dict with keys: count, total_applications, total_success,
            total_failure, average_success_rate, average_confidence,
            by_type, by_action. The empty-input case now returns the same
            key set as the non-empty case (total_success/total_failure were
            previously missing, breaking callers that index them).
        """
        if not corrections:
            return {
                'count': 0,
                'total_applications': 0,
                'total_success': 0,
                'total_failure': 0,
                'average_success_rate': 0.0,
                'average_confidence': 0.0,
                'by_type': {},
                'by_action': {}
            }
        total_apps = sum(c.success_count + c.failure_count for c in corrections)
        total_success = sum(c.success_count for c in corrections)
        by_type = self.group_by_type(corrections)
        by_action = self.group_by_action_type(corrections)
        return {
            'count': len(corrections),
            'total_applications': total_apps,
            'total_success': total_success,
            'total_failure': total_apps - total_success,
            'average_success_rate': total_success / total_apps if total_apps > 0 else 0.0,
            'average_confidence': sum(c.confidence_score for c in corrections) / len(corrections),
            'by_type': {t.value: len(cs) for t, cs in by_type.items()},
            'by_action': {a: len(cs) for a, cs in by_action.items()}
        }

    def _passes_thresholds(self, correction: Correction) -> bool:
        """Check whether a correction passes the quality thresholds."""
        total = correction.success_count + correction.failure_count
        # A brand-new correction has no track record yet; always keep it.
        if total == 0:
            return True
        if total < self.min_occurrences:
            return False
        if correction.success_rate < self.min_success_rate:
            return False
        if correction.confidence_score < self.min_confidence:
            return False
        return True

    def _merge_corrections(
        self,
        corrections: List[Correction],
        strategy: str = 'weighted_average'
    ) -> Optional[Correction]:
        """
        Merge multiple corrections that share the same key hash.

        Args:
            corrections: Corrections to merge.
            strategy: 'best_confidence', 'most_recent', or the default
                'weighted_average' (combine statistics, using the most-used
                correction as the base for target/params).

        Returns:
            Merged correction, or None for an empty input.
        """
        if not corrections:
            return None
        if len(corrections) == 1:
            return corrections[0]
        if strategy == 'best_confidence':
            return max(corrections, key=lambda c: c.confidence_score)
        if strategy == 'most_recent':
            return max(corrections, key=lambda c: c.last_applied or c.created_at)

        # weighted_average (default): combine statistics onto the most-used base.
        base = max(corrections, key=lambda c: c.success_count + c.failure_count)
        applied_dates = [c.last_applied for c in corrections if c.last_applied]
        all_tags = set()
        for c in corrections:
            all_tags.update(c.tags)
        return Correction(
            id=generate_correction_id(),
            key=base.key,
            correction_type=base.correction_type,
            original_target=base.original_target,
            original_params=base.original_params,
            corrected_target=base.corrected_target,
            corrected_params=base.corrected_params,
            correction_description=base.correction_description,
            source=base.source,  # keep the base's source as primary
            success_count=sum(c.success_count for c in corrections),
            failure_count=sum(c.failure_count for c in corrections),
            last_applied=max(applied_dates) if applied_dates else None,
            status=CorrectionStatus.ACTIVE,
            created_at=min(c.created_at for c in corrections),
            updated_at=datetime.now(),
            tags=list(all_tags)
        )

    def _session_correction_to_correction(
        self,
        session_correction: Dict[str, Any]
    ) -> Optional[Correction]:
        """
        Convert a session correction dict into a Correction object.

        Supports both the current field names (action_type, failure_reason,
        corrected_target, corrected_params) and the legacy ones (type,
        reason, new_target, new_params).

        Returns:
            Correction object, or None if the dict cannot be converted.
        """
        try:
            action_type = session_correction.get('action_type', session_correction.get('type', 'unknown'))
            element_type = session_correction.get('element_type', 'unknown')
            failure_context = session_correction.get('failure_reason', session_correction.get('reason', ''))
            key = CorrectionKey(action_type, element_type, failure_context)

            try:
                correction_type = CorrectionType(session_correction.get('correction_type', 'other'))
            except ValueError:
                # Unknown/legacy type string: fall back to the generic bucket.
                correction_type = CorrectionType.OTHER

            source = CorrectionSource(
                session_id=session_correction.get('session_id', ''),
                workflow_id=session_correction.get('workflow_id'),
                node_id=session_correction.get('node_id')
            )
            timestamp = session_correction.get('timestamp')
            if timestamp:
                if isinstance(timestamp, str):
                    source.timestamp = datetime.fromisoformat(timestamp)
                elif isinstance(timestamp, datetime):
                    source.timestamp = timestamp

            applied = session_correction.get('applied', True)
            return Correction(
                id=generate_correction_id(),
                key=key,
                correction_type=correction_type,
                original_target=session_correction.get('original_target'),
                original_params=session_correction.get('original_params'),
                corrected_target=session_correction.get('corrected_target', session_correction.get('new_target')),
                corrected_params=session_correction.get('corrected_params', session_correction.get('new_params')),
                correction_description=session_correction.get('description', ''),
                source=source,
                success_count=1 if applied else 0,
                failure_count=0 if applied else 1,
                tags=session_correction.get('tags', [])
            )
        except Exception as e:
            logger.warning(f"Error converting session correction: {e}")
            return None
class CrossWorkflowAggregator(CorrectionAggregator):
    """
    Extended aggregator for cross-workflow correction analysis.

    Adds helpers for spotting corrections that recur in several workflows
    and for summarizing recurring patterns across a correction set.
    """

    def find_common_corrections(
        self,
        corrections: List[Correction],
        min_workflows: int = 2
    ) -> List[Tuple[Correction, List[str]]]:
        """
        Find corrections that appear in at least *min_workflows* workflows.

        Args:
            corrections: Corrections to analyze.
            min_workflows: Minimum number of distinct source workflows.

        Returns:
            (merged correction, workflow_ids) tuples, most widespread first.
        """
        # Bucket by key hash so identical failures land together.
        buckets: Dict[str, List[Correction]] = defaultdict(list)
        for item in corrections:
            buckets[item.key.to_hash()].append(item)

        shared: List[Tuple[Correction, List[str]]] = []
        for bucket in buckets.values():
            wf_ids = {
                c.source.workflow_id
                for c in bucket
                if c.source and c.source.workflow_id
            }
            if len(wf_ids) < min_workflows:
                continue
            combined = self._merge_corrections(bucket)
            if combined:
                shared.append((combined, list(wf_ids)))

        # Most widely shared corrections come first.
        return sorted(shared, key=lambda pair: len(pair[1]), reverse=True)

    def detect_patterns(
        self,
        corrections: List[Correction]
    ) -> Dict[str, Any]:
        """
        Detect recurring patterns in a set of corrections.

        Returns:
            Dict with a 'patterns' list and a 'summary' of counts.
        """
        if not corrections:
            return {'patterns': [], 'summary': {}}

        found: List[Dict[str, Any]] = []

        # Pattern 1: action types corrected repeatedly (3+ occurrences).
        by_action = self.group_by_action_type(corrections)
        for action_name, group in by_action.items():
            if len(group) < 3:
                continue
            found.append({
                'type': 'frequent_action_correction',
                'action_type': action_name,
                'count': len(group),
                'avg_success_rate': sum(c.success_rate for c in group) / len(group)
            })

        # Pattern 2: element types that keep failing (3+ occurrences).
        by_element: Dict[str, List[Correction]] = defaultdict(list)
        for item in corrections:
            by_element[item.key.element_type].append(item)
        for element_name, group in by_element.items():
            if len(group) < 3:
                continue
            found.append({
                'type': 'element_type_issues',
                'element_type': element_name,
                'count': len(group),
                'common_fixes': self._find_common_fix_patterns(group)
            })

        # Pattern 3: spike of activity — majority applied within the last week.
        fresh = [
            c for c in corrections
            if c.last_applied and (datetime.now() - c.last_applied).days < 7
        ]
        if len(fresh) > len(corrections) * 0.5:
            found.append({
                'type': 'recent_activity_spike',
                'recent_count': len(fresh),
                'total_count': len(corrections)
            })

        return {
            'patterns': found,
            'summary': {
                'total_corrections': len(corrections),
                'unique_action_types': len(by_action),
                'unique_element_types': len(by_element),
                'pattern_count': len(found)
            }
        }

    def _find_common_fix_patterns(
        self,
        corrections: List[Correction]
    ) -> List[str]:
        """Return the (up to) three most frequent correction-type values."""
        counts: Dict[str, int] = defaultdict(int)
        for item in corrections:
            counts[item.correction_type.value] += 1
        ranked = sorted(counts.items(), key=lambda kv: kv[1], reverse=True)
        return [name for name, _ in ranked[:3]]

View File

@@ -0,0 +1,785 @@
"""
Correction Pack Service - Main service for correction pack management.
Combines repository, aggregator, and provides high-level operations
for CRUD, aggregation, export/import, and application of corrections.
"""
import copy
import json
import logging
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple
from .models import (
Correction,
CorrectionPack,
CorrectionKey,
CorrectionType,
CorrectionSource,
CorrectionStatus,
CorrectionPackMetadata,
generate_correction_id,
generate_pack_id
)
from .correction_repository import CorrectionRepository
from .aggregator import CorrectionAggregator, CrossWorkflowAggregator
logger = logging.getLogger(__name__)
class CorrectionPackService:
"""
Main service for managing correction packs.
Provides high-level operations combining the repository and aggregator.
"""
def __init__(
self,
storage_path: str = "data/correction_packs",
training_data_path: str = "training_data"
):
"""
Initialize the service.
Args:
storage_path: Path for correction pack storage
training_data_path: Path to training data (for session import)
"""
self.storage_path = Path(storage_path)
self.training_data_path = Path(training_data_path)
# Initialize components
self.repository = CorrectionRepository(storage_path)
self.aggregator = CorrectionAggregator()
self.cross_workflow_aggregator = CrossWorkflowAggregator()
logger.info(f"CorrectionPackService initialized (storage={storage_path})")
# ========== Pack CRUD Operations ==========
def list_packs(
self,
category: Optional[str] = None,
tags: Optional[List[str]] = None
) -> List[Dict[str, Any]]:
"""
List all correction packs with optional filtering.
Args:
category: Filter by category
tags: Filter by tags (match any)
Returns:
List of pack summaries
"""
packs = self.repository.list_packs()
# Filter by category
if category:
packs = [p for p in packs if p.metadata.category == category]
# Filter by tags
if tags:
packs = [
p for p in packs
if any(tag in p.metadata.tags for tag in tags)
]
# Return summaries
return [
{
'id': p.id,
'name': p.name,
'description': p.description,
'category': p.metadata.category,
'tags': p.metadata.tags,
'correction_count': p.correction_count,
'success_rate': p.overall_success_rate,
'created_at': p.created_at.isoformat(),
'updated_at': p.updated_at.isoformat()
}
for p in packs
]
def get_pack(self, pack_id: str) -> Optional[Dict[str, Any]]:
"""
Get a pack by ID with full details.
Args:
pack_id: Pack ID
Returns:
Pack dict or None
"""
pack = self.repository.get_pack(pack_id)
if not pack:
return None
return pack.to_dict()
def create_pack(
self,
name: str,
description: str = "",
category: str = "general",
tags: Optional[List[str]] = None,
author: str = ""
) -> Dict[str, Any]:
"""
Create a new correction pack.
Args:
name: Pack name
description: Pack description
category: Pack category
tags: Pack tags
author: Pack author
Returns:
Created pack dict
"""
metadata = {
'category': category,
'tags': tags or [],
'author': author
}
pack = self.repository.create_pack(name, description, metadata)
return pack.to_dict()
def update_pack(
self,
pack_id: str,
name: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None
) -> Optional[Dict[str, Any]]:
"""
Update a pack's properties.
Args:
pack_id: Pack ID
name: New name
description: New description
metadata: New metadata
Returns:
Updated pack dict or None
"""
updates = {}
if name is not None:
updates['name'] = name
if description is not None:
updates['description'] = description
if metadata is not None:
updates['metadata'] = metadata
pack = self.repository.update_pack(pack_id, updates)
if not pack:
return None
return pack.to_dict()
def delete_pack(self, pack_id: str) -> bool:
"""
Delete a pack.
Args:
pack_id: Pack ID
Returns:
True if deleted
"""
return self.repository.delete_pack(pack_id)
# ========== Correction Operations ==========
def add_correction(
self,
pack_id: str,
correction_data: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
"""
Add a correction to a pack.
Args:
pack_id: Pack ID
correction_data: Correction data dict
Returns:
Created correction dict or None
"""
try:
# Build correction key
key = CorrectionKey(
action_type=correction_data.get('action_type', 'unknown'),
element_type=correction_data.get('element_type', 'unknown'),
failure_context=correction_data.get('failure_context', '')
)
# Determine correction type
type_str = correction_data.get('correction_type', 'other')
try:
correction_type = CorrectionType(type_str)
except ValueError:
correction_type = CorrectionType.OTHER
# Build source if provided
source = None
if 'source' in correction_data:
source = CorrectionSource(
session_id=correction_data['source'].get('session_id', ''),
workflow_id=correction_data['source'].get('workflow_id'),
node_id=correction_data['source'].get('node_id')
)
# Create correction
correction = Correction(
id=generate_correction_id(),
key=key,
correction_type=correction_type,
original_target=correction_data.get('original_target'),
original_params=correction_data.get('original_params'),
corrected_target=correction_data.get('corrected_target'),
corrected_params=correction_data.get('corrected_params'),
correction_description=correction_data.get('description', ''),
source=source,
tags=correction_data.get('tags', [])
)
if self.repository.add_correction(pack_id, correction):
return correction.to_dict()
return None
except Exception as e:
logger.error(f"Error adding correction: {e}")
return None
def add_correction_from_session(
self,
pack_id: str,
session_correction: Dict[str, Any],
session_id: str,
workflow_id: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""
Add a correction from session user_corrections format.
Args:
pack_id: Pack ID
session_correction: Correction dict from TrainingSession.user_corrections
session_id: Session ID
workflow_id: Workflow ID
Returns:
Created correction dict or None
"""
# Enrich with source info
session_correction['source'] = {
'session_id': session_id,
'workflow_id': workflow_id,
'node_id': session_correction.get('node_id')
}
# Map old format fields to new format
if 'type' in session_correction and 'action_type' not in session_correction:
session_correction['action_type'] = session_correction['type']
if 'reason' in session_correction and 'failure_context' not in session_correction:
session_correction['failure_context'] = session_correction['reason']
if 'new_target' in session_correction and 'corrected_target' not in session_correction:
session_correction['corrected_target'] = session_correction['new_target']
if 'new_params' in session_correction and 'corrected_params' not in session_correction:
session_correction['corrected_params'] = session_correction['new_params']
return self.add_correction(pack_id, session_correction)
def update_correction(
self,
pack_id: str,
correction_id: str,
updates: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
"""
Update a correction.
Args:
pack_id: Pack ID
correction_id: Correction ID
updates: Fields to update
Returns:
Updated correction dict or None
"""
correction = self.repository.update_correction(pack_id, correction_id, updates)
if not correction:
return None
return correction.to_dict()
def remove_correction(self, pack_id: str, correction_id: str) -> bool:
"""
Remove a correction from a pack.
Args:
pack_id: Pack ID
correction_id: Correction ID
Returns:
True if removed
"""
return self.repository.remove_correction(pack_id, correction_id)
# ========== Aggregation Operations ==========
def aggregate_from_sessions(
self,
pack_id: str,
session_files: Optional[List[str]] = None,
workflow_id: Optional[str] = None,
min_occurrences: int = 1,
min_success_rate: float = 0.0
) -> Dict[str, Any]:
"""
Aggregate corrections from training sessions into a pack.
Args:
pack_id: Pack ID to add corrections to
session_files: Specific session files to process (None = all)
workflow_id: Filter by workflow ID
min_occurrences: Minimum occurrences to include
min_success_rate: Minimum success rate to include
Returns:
Aggregation result summary
"""
# Collect all session corrections
all_corrections = []
if session_files is None:
# Find all session files
session_files = list(self.training_data_path.glob("session_*.json"))
else:
session_files = [Path(f) for f in session_files]
sessions_processed = 0
for session_file in session_files:
try:
with open(session_file, 'r', encoding='utf-8') as f:
session_data = json.load(f)
# Filter by workflow if specified
if workflow_id and session_data.get('workflow_id') != workflow_id:
continue
user_corrections = session_data.get('user_corrections', [])
session_id = session_data.get('session_id', session_file.stem)
for uc in user_corrections:
uc['session_id'] = session_id
uc['workflow_id'] = session_data.get('workflow_id')
all_corrections.append(uc)
sessions_processed += 1
except Exception as e:
logger.warning(f"Error reading session file {session_file}: {e}")
# Configure aggregator
self.aggregator.min_occurrences = min_occurrences
self.aggregator.min_success_rate = min_success_rate
# Aggregate
aggregated = self.aggregator.aggregate_from_sessions(all_corrections, workflow_id)
# Add to pack
added_count = 0
for correction in aggregated:
if self.repository.add_correction(pack_id, correction):
added_count += 1
return {
'sessions_processed': sessions_processed,
'corrections_found': len(all_corrections),
'corrections_aggregated': len(aggregated),
'corrections_added': added_count,
'statistics': self.aggregator.calculate_statistics(aggregated)
}
def aggregate_cross_workflow(
self,
pack_id: str,
min_workflows: int = 2
) -> Dict[str, Any]:
"""
Find and aggregate corrections common across multiple workflows.
Args:
pack_id: Pack ID to add corrections to
min_workflows: Minimum number of workflows a correction must appear in
Returns:
Aggregation result summary
"""
# Get existing corrections from all packs
all_corrections = []
for pack in self.repository.list_packs():
all_corrections.extend(pack.corrections.values())
# Find common corrections
common = self.cross_workflow_aggregator.find_common_corrections(
all_corrections, min_workflows
)
# Add to pack
added_count = 0
for correction, workflow_ids in common:
# Tag with source workflows
correction.tags.extend([f"wf:{wid}" for wid in workflow_ids[:5]]) # Limit tags
correction.correction_description += f" (common in {len(workflow_ids)} workflows)"
if self.repository.add_correction(pack_id, correction):
added_count += 1
# Detect patterns
patterns = self.cross_workflow_aggregator.detect_patterns(all_corrections)
return {
'total_corrections_analyzed': len(all_corrections),
'common_corrections_found': len(common),
'corrections_added': added_count,
'patterns': patterns
}
# ========== Search and Apply Operations ==========
def find_applicable_corrections(
self,
action_type: str,
element_type: str,
failure_context: str,
pack_ids: Optional[List[str]] = None,
min_confidence: float = 0.3
) -> List[Dict[str, Any]]:
"""
Find corrections applicable to a failure context.
Args:
action_type: Type of action that failed
element_type: Type of element involved
failure_context: Description of failure
pack_ids: Specific packs to search (None = all)
min_confidence: Minimum confidence score
Returns:
List of applicable corrections with pack info
"""
results = self.repository.find_by_context(action_type, element_type, failure_context)
# Filter by pack IDs if specified
if pack_ids:
results = [(pid, c) for pid, c in results if pid in pack_ids]
# Filter by confidence
results = [(pid, c) for pid, c in results if c.confidence_score >= min_confidence]
# Format results
return [
{
'pack_id': pack_id,
'correction': correction.to_dict()
}
for pack_id, correction in results
]
def apply_correction(
self,
pack_id: str,
correction_id: str,
success: bool,
application_context: Optional[Dict[str, Any]] = None
) -> bool:
"""
Record an application of a correction.
Args:
pack_id: Pack ID
correction_id: Correction ID
success: Whether the application was successful
application_context: Optional context information
Returns:
True if recorded
"""
self.repository.record_application(pack_id, correction_id, success)
logger.info(
f"Recorded correction application: pack={pack_id}, "
f"correction={correction_id}, success={success}"
)
return True
def search_corrections(
self,
query: str,
pack_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
Search corrections by text query.
Args:
query: Search query
pack_id: Optional pack ID to search in
Returns:
List of matching corrections with pack info
"""
results = self.repository.search_corrections(query, pack_id)
return [
{
'pack_id': pid,
'correction': correction.to_dict()
}
for pid, correction in results
]
# ========== Export/Import Operations ==========
def export_pack(
self,
pack_id: str,
format: str = 'json',
include_stats: bool = True
) -> Optional[str]:
"""
Export a pack to JSON or YAML.
Args:
pack_id: Pack ID
format: Export format ('json' or 'yaml')
include_stats: Whether to include statistics
Returns:
Exported content string or None
"""
return self.repository.export_pack(pack_id, format)
def import_pack(
self,
content: str,
format: str = 'json',
merge_strategy: str = 'create_new'
) -> Optional[Dict[str, Any]]:
"""
Import a pack from JSON or YAML.
Args:
content: Content to import
format: Import format ('json' or 'yaml')
merge_strategy: How to handle existing pack
Returns:
Imported pack dict or None
"""
pack = self.repository.import_pack(content, format, merge_strategy)
if not pack:
return None
return pack.to_dict()
def export_pack_to_file(
self,
pack_id: str,
file_path: str,
format: str = 'json'
) -> bool:
"""
Export a pack to a file.
Args:
pack_id: Pack ID
file_path: Output file path
format: Export format
Returns:
True if successful
"""
content = self.export_pack(pack_id, format)
if not content:
return False
try:
with open(file_path, 'w', encoding='utf-8') as f:
f.write(content)
logger.info(f"Exported pack {pack_id} to {file_path}")
return True
except Exception as e:
logger.error(f"Error exporting pack to file: {e}")
return False
def import_pack_from_file(
self,
file_path: str,
format: Optional[str] = None,
merge_strategy: str = 'create_new'
) -> Optional[Dict[str, Any]]:
"""
Import a pack from a file.
Args:
file_path: Input file path
format: Import format (auto-detected if None)
merge_strategy: How to handle existing pack
Returns:
Imported pack dict or None
"""
try:
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
# Auto-detect format
if format is None:
if file_path.endswith('.yaml') or file_path.endswith('.yml'):
format = 'yaml'
else:
format = 'json'
result = self.import_pack(content, format, merge_strategy)
if result:
logger.info(f"Imported pack from {file_path}")
return result
except Exception as e:
logger.error(f"Error importing pack from file: {e}")
return None
# ========== Version Operations ==========
def create_version(
self,
pack_id: str,
description: str = ""
) -> int:
"""
Create a version snapshot of a pack.
Args:
pack_id: Pack ID
description: Version description
Returns:
Version number (-1 on error)
"""
return self.repository.create_version(pack_id, description)
def list_versions(self, pack_id: str) -> List[Dict[str, Any]]:
"""
List all versions of a pack.
Args:
pack_id: Pack ID
Returns:
List of version info dicts
"""
return self.repository.list_versions(pack_id)
def rollback_pack(self, pack_id: str, version: int) -> bool:
"""
Rollback a pack to a previous version.
Args:
pack_id: Pack ID
version: Version number
Returns:
True if successful
"""
return self.repository.rollback_pack(pack_id, version)
# ========== Statistics and Analysis ==========
def get_pack_statistics(self, pack_id: str) -> Optional[Dict[str, Any]]:
"""
Get detailed statistics for a pack.
Args:
pack_id: Pack ID
Returns:
Statistics dict or None
"""
pack = self.repository.get_pack(pack_id)
if not pack:
return None
return pack.get_statistics()
def get_global_statistics(self) -> Dict[str, Any]:
    """
    Aggregate statistics across every stored pack.

    Returns:
        Dict with totals, the overall success rate, a per-category pack
        count, and the ten highest-confidence corrections.
    """
    packs = self.repository.list_packs()
    correction_total = sum(p.correction_count for p in packs)
    application_total = sum(p.total_applications for p in packs)
    # Successes counted across every correction of every pack.
    success_total = sum(
        c.success_count for p in packs for c in p.corrections.values()
    )
    success_rate = (
        success_total / application_total if application_total > 0 else 0.0
    )
    return {
        'total_packs': len(packs),
        'total_corrections': correction_total,
        'total_applications': application_total,
        'overall_success_rate': success_rate,
        'packs_by_category': self._group_packs_by_category(packs),
        'top_corrections': self._get_top_corrections(packs, 10)
    }
def _group_packs_by_category(
    self,
    packs: List[CorrectionPack]
) -> Dict[str, int]:
    """Count how many packs fall into each metadata category."""
    counts: Dict[str, int] = {}
    for pack in packs:
        label = pack.metadata.category
        counts[label] = 1 + counts.get(label, 0)
    return counts
def _get_top_corrections(
    self,
    packs: List[CorrectionPack],
    limit: int = 10
) -> List[Dict[str, Any]]:
    """Return summaries of the *limit* highest-confidence corrections."""
    # Flatten every correction together with its owning pack.
    flattened = [
        (pack.id, pack.name, corr)
        for pack in packs
        for corr in pack.corrections.values()
    ]
    # Highest confidence first; sorted() is stable, matching list.sort().
    ranked = sorted(flattened, key=lambda t: t[2].confidence_score, reverse=True)
    return [
        {
            'pack_id': pid,
            'pack_name': pname,
            'correction_id': corr.id,
            'action_type': corr.key.action_type,
            'confidence': corr.confidence_score,
            'success_rate': corr.success_rate,
            'applications': corr.success_count + corr.failure_count
        }
        for pid, pname, corr in ranked[:limit]
    ]

View File

@@ -0,0 +1,643 @@
"""
Correction Repository - JSON storage for corrections and packs.
Provides atomic file operations for safe persistence.
"""
import json
import logging
import os
import shutil
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any
from .models import (
Correction,
CorrectionPack,
CorrectionKey,
generate_pack_id
)
logger = logging.getLogger(__name__)
class CorrectionRepository:
    """
    Repository for storing and retrieving correction packs.

    Uses atomic JSON writes (temp file + rename) for safety.
    Maintains an index by key hash for fast lookups.

    NOTE(review): `_key_index` maps key_hash -> [correction_ids] only; it
    does not record the owning pack, so the find_* methods still scan all
    packs. Confirm intent before relying on the index for lookups.
    """

    def __init__(self, storage_path: str = "data/correction_packs"):
        """
        Initialize the repository.

        Args:
            storage_path: Base path for storing correction data
        """
        self.storage_path = Path(storage_path)
        self.packs_file = self.storage_path / "packs.json"
        self.corrections_dir = self.storage_path / "corrections"
        self.versions_dir = self.storage_path / "versions"
        # Ensure directories exist
        self._ensure_directories()
        # In-memory cache; the single source of truth at runtime.
        self._packs: Dict[str, CorrectionPack] = {}
        self._key_index: Dict[str, List[str]] = {}  # key_hash -> [correction_ids]
        # Load existing data
        self._load_packs()
        self._rebuild_index()

    def _ensure_directories(self):
        """Ensure all required directories exist."""
        self.storage_path.mkdir(parents=True, exist_ok=True)
        self.corrections_dir.mkdir(parents=True, exist_ok=True)
        self.versions_dir.mkdir(parents=True, exist_ok=True)

    # ========== Pack Operations ==========

    def list_packs(self) -> List[CorrectionPack]:
        """Get all correction packs."""
        return list(self._packs.values())

    def get_pack(self, pack_id: str) -> Optional[CorrectionPack]:
        """Get a pack by ID."""
        return self._packs.get(pack_id)

    def create_pack(self, name: str, description: str = "",
                    metadata: Optional[Dict[str, Any]] = None) -> CorrectionPack:
        """
        Create a new correction pack.

        Args:
            name: Pack name
            description: Pack description
            metadata: Optional metadata dict

        Returns:
            Created pack
        """
        pack_id = generate_pack_id()
        pack = CorrectionPack(
            id=pack_id,
            name=name,
            description=description
        )
        if metadata:
            pack.metadata.version = metadata.get('version', '1.0.0')
            pack.metadata.author = metadata.get('author', '')
            pack.metadata.category = metadata.get('category', 'general')
            pack.metadata.tags = metadata.get('tags', [])
        self._packs[pack_id] = pack
        self._save_packs()
        logger.info(f"Created correction pack: {pack_id} ({name})")
        return pack

    def update_pack(self, pack_id: str, updates: Dict[str, Any]) -> Optional[CorrectionPack]:
        """
        Update a pack's properties.

        Args:
            pack_id: Pack ID
            updates: Dict of fields to update

        Returns:
            Updated pack or None if not found
        """
        pack = self._packs.get(pack_id)
        if not pack:
            return None
        # Update allowed fields only; unknown keys are ignored.
        if 'name' in updates:
            pack.name = updates['name']
        if 'description' in updates:
            pack.description = updates['description']
        if 'metadata' in updates:
            metadata = updates['metadata']
            if 'version' in metadata:
                pack.metadata.version = metadata['version']
            if 'author' in metadata:
                pack.metadata.author = metadata['author']
            if 'category' in metadata:
                pack.metadata.category = metadata['category']
            if 'tags' in metadata:
                pack.metadata.tags = metadata['tags']
        pack.updated_at = datetime.now()
        self._save_packs()
        logger.info(f"Updated correction pack: {pack_id}")
        return pack

    def delete_pack(self, pack_id: str) -> bool:
        """
        Delete a pack and all its corrections.

        Args:
            pack_id: Pack ID

        Returns:
            True if deleted, False if not found
        """
        if pack_id not in self._packs:
            return False
        pack = self._packs[pack_id]
        # Remove the pack's corrections from the key-hash index.
        for correction in pack.corrections.values():
            key_hash = correction.key.to_hash()
            if key_hash in self._key_index:
                if correction.id in self._key_index[key_hash]:
                    self._key_index[key_hash].remove(correction.id)
        del self._packs[pack_id]
        self._save_packs()
        # Remove version history
        version_dir = self.versions_dir / pack_id
        if version_dir.exists():
            shutil.rmtree(version_dir)
        logger.info(f"Deleted correction pack: {pack_id}")
        return True

    # ========== Correction Operations ==========

    def add_correction(self, pack_id: str, correction: Correction) -> bool:
        """
        Add a correction to a pack.

        Args:
            pack_id: Pack ID
            correction: Correction to add

        Returns:
            True if added, False if pack not found or correction exists
        """
        pack = self._packs.get(pack_id)
        if not pack:
            return False
        if not pack.add_correction(correction):
            return False
        # Update index
        key_hash = correction.key.to_hash()
        if key_hash not in self._key_index:
            self._key_index[key_hash] = []
        if correction.id not in self._key_index[key_hash]:
            self._key_index[key_hash].append(correction.id)
        self._save_packs()
        logger.info(f"Added correction {correction.id} to pack {pack_id}")
        return True

    def remove_correction(self, pack_id: str, correction_id: str) -> bool:
        """
        Remove a correction from a pack.

        Args:
            pack_id: Pack ID
            correction_id: Correction ID

        Returns:
            True if removed, False if not found
        """
        pack = self._packs.get(pack_id)
        if not pack:
            return False
        correction = pack.get_correction(correction_id)
        if not correction:
            return False
        # Update index
        key_hash = correction.key.to_hash()
        if key_hash in self._key_index:
            if correction_id in self._key_index[key_hash]:
                self._key_index[key_hash].remove(correction_id)
        pack.remove_correction(correction_id)
        self._save_packs()
        logger.info(f"Removed correction {correction_id} from pack {pack_id}")
        return True

    def update_correction(self, pack_id: str, correction_id: str,
                          updates: Dict[str, Any]) -> Optional[Correction]:
        """
        Update a correction's properties.

        Args:
            pack_id: Pack ID
            correction_id: Correction ID
            updates: Dict of fields to update

        Returns:
            Updated correction or None if not found
        """
        pack = self._packs.get(pack_id)
        if not pack:
            return None
        correction = pack.get_correction(correction_id)
        if not correction:
            return None
        # Update allowed fields only.
        if 'corrected_target' in updates:
            correction.corrected_target = updates['corrected_target']
        if 'corrected_params' in updates:
            correction.corrected_params = updates['corrected_params']
        if 'correction_description' in updates:
            correction.correction_description = updates['correction_description']
        if 'status' in updates:
            from .models import CorrectionStatus
            correction.status = CorrectionStatus(updates['status'])
        if 'tags' in updates:
            correction.tags = updates['tags']
        correction.updated_at = datetime.now()
        pack.updated_at = datetime.now()
        self._save_packs()
        logger.info(f"Updated correction {correction_id} in pack {pack_id}")
        return correction

    def record_application(self, pack_id: str, correction_id: str, success: bool):
        """
        Record an application of a correction.

        Silently no-ops when the pack or correction does not exist
        (application tracking is best-effort).

        Args:
            pack_id: Pack ID
            correction_id: Correction ID
            success: Whether the application was successful
        """
        pack = self._packs.get(pack_id)
        if not pack:
            return
        correction = pack.get_correction(correction_id)
        if not correction:
            return
        correction.record_application(success)
        pack.updated_at = datetime.now()
        self._save_packs()

    # ========== Search Operations ==========

    def find_by_key_hash(self, key_hash: str) -> List[tuple]:
        """
        Find all corrections matching a key hash.

        Args:
            key_hash: MD5 hash of correction key

        Returns:
            List of (pack_id, correction) tuples, best confidence first
        """
        results = []
        for pack_id, pack in self._packs.items():
            for correction in pack.find_by_key_hash(key_hash):
                results.append((pack_id, correction))
        # Sort by confidence score
        results.sort(key=lambda x: x[1].confidence_score, reverse=True)
        return results

    def find_by_context(self, action_type: str, element_type: str,
                        failure_context: str) -> List[tuple]:
        """
        Find corrections matching a failure context.

        Args:
            action_type: Type of action that failed
            element_type: Type of element involved
            failure_context: Description of failure

        Returns:
            List of (pack_id, correction) tuples
        """
        key = CorrectionKey(action_type, element_type, failure_context)
        return self.find_by_key_hash(key.to_hash())

    def search_corrections(self, query: str, pack_id: Optional[str] = None) -> List[tuple]:
        """
        Search corrections by text query.

        Args:
            query: Search query (searches in description, tags, etc.)
            pack_id: Optional pack ID to search in

        Returns:
            List of (pack_id, correction) tuples
        """
        results = []
        query_lower = query.lower()
        packs_to_search = [self._packs[pack_id]] if pack_id and pack_id in self._packs else self._packs.values()
        for pack in packs_to_search:
            for correction in pack.corrections.values():
                # Search in description
                if query_lower in correction.correction_description.lower():
                    results.append((pack.id, correction))
                    continue
                # Search in tags
                if any(query_lower in tag.lower() for tag in correction.tags):
                    results.append((pack.id, correction))
                    continue
                # Search in key components
                if (query_lower in correction.key.action_type.lower() or
                        query_lower in correction.key.element_type.lower() or
                        query_lower in correction.key.failure_context.lower()):
                    results.append((pack.id, correction))
        return results

    # ========== Version Operations ==========

    def create_version(self, pack_id: str, description: str = "") -> int:
        """
        Create a version snapshot of a pack.

        Args:
            pack_id: Pack ID
            description: Version description

        Returns:
            Version number (-1 when the pack does not exist)
        """
        pack = self._packs.get(pack_id)
        if not pack:
            return -1
        # Create version directory
        pack_versions = self.versions_dir / pack_id
        pack_versions.mkdir(parents=True, exist_ok=True)
        # Snapshot carries the current number; the pack then advances.
        version = pack.current_version
        pack.current_version += 1
        # Save snapshot
        version_file = pack_versions / f"v{version}.json"
        version_data = {
            'version': version,
            'description': description,
            'created_at': datetime.now().isoformat(),
            'pack_snapshot': pack.to_dict()
        }
        self._atomic_write(version_file, version_data)
        self._save_packs()
        logger.info(f"Created version {version} for pack {pack_id}")
        return version

    def list_versions(self, pack_id: str) -> List[Dict[str, Any]]:
        """
        List all versions of a pack, in ascending version order.

        Args:
            pack_id: Pack ID

        Returns:
            List of version info dicts
        """
        versions = []
        pack_versions = self.versions_dir / pack_id
        if not pack_versions.exists():
            return versions

        def version_number(path: Path) -> int:
            # Snapshot files are named "v<N>.json". Sort numerically:
            # a plain lexicographic sort would order v10 before v2.
            try:
                return int(path.stem[1:])
            except ValueError:
                return -1

        for version_file in sorted(pack_versions.glob("v*.json"), key=version_number):
            try:
                with open(version_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                versions.append({
                    'version': data.get('version'),
                    'description': data.get('description', ''),
                    'created_at': data.get('created_at'),
                    'correction_count': len(data.get('pack_snapshot', {}).get('corrections', {}))
                })
            except Exception as e:
                logger.warning(f"Error reading version file {version_file}: {e}")
        return versions

    def rollback_pack(self, pack_id: str, version: int) -> bool:
        """
        Rollback a pack to a previous version.

        Args:
            pack_id: Pack ID
            version: Version number to rollback to

        Returns:
            True if successful, False otherwise
        """
        version_file = self.versions_dir / pack_id / f"v{version}.json"
        if not version_file.exists():
            logger.warning(f"Version {version} not found for pack {pack_id}")
            return False
        try:
            with open(version_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Restore pack from snapshot
            snapshot = data.get('pack_snapshot', {})
            pack = CorrectionPack.from_dict(snapshot)
            # Continue version numbering from the live pack so a later
            # create_version() does not overwrite existing snapshots.
            if pack_id in self._packs:
                pack.current_version = self._packs[pack_id].current_version + 1
            self._packs[pack_id] = pack
            self._rebuild_index()
            self._save_packs()
            logger.info(f"Rolled back pack {pack_id} to version {version}")
            return True
        except Exception as e:
            logger.error(f"Error rolling back pack {pack_id} to version {version}: {e}")
            return False

    # ========== Persistence ==========

    def _load_packs(self):
        """Load all packs from storage; a corrupt pack is skipped with a warning."""
        if not self.packs_file.exists():
            return
        try:
            with open(self.packs_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            for pack_id, pack_data in data.items():
                try:
                    self._packs[pack_id] = CorrectionPack.from_dict(pack_data)
                except Exception as e:
                    logger.warning(f"Error loading pack {pack_id}: {e}")
            logger.info(f"Loaded {len(self._packs)} correction packs")
        except Exception as e:
            logger.error(f"Error loading packs file: {e}")

    def _save_packs(self):
        """Save all packs to storage using atomic write."""
        try:
            data = {
                pack_id: pack.to_dict()
                for pack_id, pack in self._packs.items()
            }
            self._atomic_write(self.packs_file, data)
        except Exception as e:
            logger.error(f"Error saving packs: {e}")

    def _atomic_write(self, file_path: Path, data: Dict[str, Any]):
        """
        Write data to file atomically (temp file + rename).

        Args:
            file_path: Target file path
            data: Data to write
        """
        temp_file = file_path.with_suffix('.tmp')
        try:
            with open(temp_file, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2, ensure_ascii=False)
            # Atomic rename
            temp_file.replace(file_path)
        except Exception:
            # Clean up temp file, then re-raise preserving the original
            # traceback (bare `raise`, not `raise e`).
            if temp_file.exists():
                temp_file.unlink()
            raise

    def _rebuild_index(self):
        """Rebuild the key hash index from all packs."""
        self._key_index.clear()
        for pack in self._packs.values():
            for correction in pack.corrections.values():
                key_hash = correction.key.to_hash()
                if key_hash not in self._key_index:
                    self._key_index[key_hash] = []
                if correction.id not in self._key_index[key_hash]:
                    self._key_index[key_hash].append(correction.id)

    # ========== Export/Import ==========

    def export_pack(self, pack_id: str, format: str = 'json') -> Optional[str]:
        """
        Export a pack to JSON or YAML format.

        Args:
            pack_id: Pack ID
            format: Export format ('json' or 'yaml')

        Returns:
            Exported content as string, or None if pack not found
        """
        pack = self._packs.get(pack_id)
        if not pack:
            return None
        export_data = {
            'export_format': 'correction_pack',
            'export_version': '1.0',
            'exported_at': datetime.now().isoformat(),
            'pack': pack.to_dict()
        }
        if format == 'yaml':
            try:
                import yaml
                return yaml.dump(export_data, default_flow_style=False, allow_unicode=True)
            except ImportError:
                # Degrade gracefully when PyYAML is absent.
                logger.warning("PyYAML not installed, falling back to JSON")
                return json.dumps(export_data, indent=2, ensure_ascii=False)
        else:
            return json.dumps(export_data, indent=2, ensure_ascii=False)

    def import_pack(self, content: str, format: str = 'json',
                    merge_strategy: str = 'create_new') -> Optional[CorrectionPack]:
        """
        Import a pack from JSON or YAML content.

        Args:
            content: Content to import
            format: Import format ('json' or 'yaml')
            merge_strategy: How to handle existing pack:
                - 'create_new': Always create new pack with new ID
                - 'merge': Merge into existing pack if same ID
                - 'replace': Replace existing pack if same ID

        Returns:
            Imported pack, or None on error
        """
        try:
            if format == 'yaml':
                try:
                    import yaml
                    data = yaml.safe_load(content)
                except ImportError:
                    logger.error("PyYAML not installed")
                    return None
            else:
                data = json.loads(content)
            # Validate format
            if data.get('export_format') != 'correction_pack':
                logger.warning("Invalid export format")
                return None
            pack_data = data.get('pack', {})
            imported_pack = CorrectionPack.from_dict(pack_data)
            # Handle merge strategy
            if merge_strategy == 'create_new':
                imported_pack.id = generate_pack_id()
                self._packs[imported_pack.id] = imported_pack
            elif merge_strategy == 'merge' and imported_pack.id in self._packs:
                existing_pack = self._packs[imported_pack.id]
                existing_pack.merge(imported_pack)
                imported_pack = existing_pack
            elif merge_strategy == 'replace':
                self._packs[imported_pack.id] = imported_pack
            else:
                # Default (unknown strategy, or 'merge' with no match):
                # create a new pack under a fresh ID.
                imported_pack.id = generate_pack_id()
                self._packs[imported_pack.id] = imported_pack
            self._rebuild_index()
            self._save_packs()
            logger.info(f"Imported pack: {imported_pack.id} ({imported_pack.name})")
            return imported_pack
        except Exception as e:
            logger.error(f"Error importing pack: {e}")
            return None

480
core/corrections/models.py Normal file
View File

@@ -0,0 +1,480 @@
"""
Data models for the Correction Packs system.
Defines CorrectionKey, Correction, CorrectionPack and related types.
"""
import hashlib
import uuid
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum
from typing import Dict, List, Any, Optional
class CorrectionType(Enum):
    """Types of corrections that can be applied.

    NOTE: the string values are persisted in pack JSON exports; renaming a
    value breaks deserialization of previously stored data.
    """
    TARGET_CHANGE = "target_change"  # Changed target selector
    PARAMETER_CHANGE = "parameter_change"  # Changed action parameters
    ACTION_REPLACE = "action_replace"  # Replaced action type
    WAIT_ADDED = "wait_added"  # Added wait before action
    RETRY_CONFIG = "retry_config"  # Changed retry configuration
    FALLBACK_ADDED = "fallback_added"  # Added fallback action
    COORDINATES_ADJUST = "coordinates_adjust"  # Adjusted click coordinates
    TIMING_ADJUST = "timing_adjust"  # Adjusted timing/delay
    OTHER = "other"  # Catch-all for anything not covered above
class CorrectionStatus(Enum):
    """Status of a correction.

    Only ACTIVE corrections are returned by key-hash lookups; values are
    persisted in JSON, so do not rename them.
    """
    ACTIVE = "active"  # Correction is active and can be applied
    DEPRECATED = "deprecated"  # Correction is deprecated
    DISABLED = "disabled"  # Correction is manually disabled
@dataclass
class CorrectionKey:
    """
    Identity of a correction across workflows.

    Two failures sharing the same action type, element type and failure
    context map to the same key, which is what lets a fix learned in one
    workflow be reused in another. The identity is a truncated MD5 digest
    (16 hex chars), mirroring the LearningRepository pattern-key scheme.
    """
    action_type: str
    element_type: str
    failure_context: str

    def to_hash(self) -> str:
        """Return the 16-character MD5 digest identifying this key."""
        joined = "|".join((self.action_type, self.element_type, self.failure_context))
        # MD5 is used for deduplication only, not security.
        return hashlib.md5(joined.encode()).hexdigest()[:16]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict, including the derived hash."""
        return {
            'action_type': self.action_type,
            'element_type': self.element_type,
            'failure_context': self.failure_context,
            'hash': self.to_hash()
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionKey':
        """Build a key from a dict; missing fields get neutral defaults."""
        return cls(
            data.get('action_type', 'unknown'),
            data.get('element_type', 'unknown'),
            data.get('failure_context', '')
        )
@dataclass
class CorrectionSource:
    """Provenance of a correction: which session/workflow/node produced it."""
    session_id: str
    workflow_id: Optional[str] = None
    node_id: Optional[str] = None
    timestamp: datetime = field(default_factory=datetime.now)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize with an ISO-8601 timestamp string."""
        return {
            'session_id': self.session_id,
            'workflow_id': self.workflow_id,
            'node_id': self.node_id,
            'timestamp': self.timestamp.isoformat()
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionSource':
        """Deserialize; an absent timestamp falls back to now()."""
        ts = data.get('timestamp')
        if isinstance(ts, str):
            # Round-trip of our own to_dict() output.
            ts = datetime.fromisoformat(ts)
        if ts is None:
            ts = datetime.now()
        return cls(
            session_id=data.get('session_id', ''),
            workflow_id=data.get('workflow_id'),
            node_id=data.get('node_id'),
            timestamp=ts
        )
@dataclass
class Correction:
    """
    Individual correction record.

    Contains the original context, the applied correction, source information,
    and statistics about success/failure.
    """
    id: str
    key: CorrectionKey
    correction_type: CorrectionType
    # Original context
    original_target: Optional[Dict[str, Any]] = None
    original_params: Optional[Dict[str, Any]] = None
    # Correction applied
    corrected_target: Optional[Dict[str, Any]] = None
    corrected_params: Optional[Dict[str, Any]] = None
    correction_description: str = ""
    # Source information
    source: Optional[CorrectionSource] = None
    # Statistics
    success_count: int = 0
    failure_count: int = 0
    last_applied: Optional[datetime] = None
    # Status
    status: CorrectionStatus = CorrectionStatus.ACTIVE
    # Metadata
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    tags: List[str] = field(default_factory=list)

    @property
    def success_rate(self) -> float:
        """Fraction of applications that succeeded (0.0 when never applied)."""
        total = self.success_count + self.failure_count
        if total == 0:
            return 0.0
        return self.success_count / total

    @property
    def confidence_score(self) -> float:
        """Calculate confidence score based on success rate and usage count.

        Confidence increases with more data points (saturates at 100 uses):
        score = success_rate * (0.5 + 0.5 * min(1, total/100)).
        """
        total = self.success_count + self.failure_count
        if total == 0:
            return 0.0
        usage_factor = min(1.0, total / 100)
        return self.success_rate * (0.5 + 0.5 * usage_factor)

    def record_application(self, success: bool):
        """Record an application of this correction and touch timestamps."""
        if success:
            self.success_count += 1
        else:
            self.failure_count += 1
        self.last_applied = datetime.now()
        self.updated_at = datetime.now()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-safe dict (includes derived rate/confidence)."""
        return {
            'id': self.id,
            'key': self.key.to_dict(),
            'correction_type': self.correction_type.value,
            'original_target': self.original_target,
            'original_params': self.original_params,
            'corrected_target': self.corrected_target,
            'corrected_params': self.corrected_params,
            'correction_description': self.correction_description,
            'source': self.source.to_dict() if self.source else None,
            'success_count': self.success_count,
            'failure_count': self.failure_count,
            'success_rate': self.success_rate,
            'confidence_score': self.confidence_score,
            'last_applied': self.last_applied.isoformat() if self.last_applied else None,
            'status': self.status.value,
            'created_at': self.created_at.isoformat(),
            'updated_at': self.updated_at.isoformat(),
            'tags': self.tags
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Correction':
        """Create from dictionary.

        Tolerant of data produced by other versions: unknown
        correction_type/status values fall back to OTHER/ACTIVE instead of
        raising, so one bad correction cannot abort a whole pack load.
        """
        def parse_dt(value, default=None):
            # ISO strings come from to_dict(); anything already parsed
            # (e.g. a datetime) passes through; None yields the default.
            if isinstance(value, str):
                return datetime.fromisoformat(value)
            if value is None:
                return default
            return value

        created_at = parse_dt(data.get('created_at'), datetime.now())
        updated_at = parse_dt(data.get('updated_at'), datetime.now())
        last_applied = parse_dt(data.get('last_applied'))

        # Parse nested objects
        key = CorrectionKey.from_dict(data.get('key', {}))
        source_data = data.get('source')
        source = CorrectionSource.from_dict(source_data) if source_data else None

        # Tolerant enum parsing (see docstring).
        try:
            correction_type = CorrectionType(data.get('correction_type', 'other'))
        except ValueError:
            correction_type = CorrectionType.OTHER
        try:
            status = CorrectionStatus(data.get('status', 'active'))
        except ValueError:
            status = CorrectionStatus.ACTIVE

        return cls(
            id=data.get('id', str(uuid.uuid4())),
            key=key,
            correction_type=correction_type,
            original_target=data.get('original_target'),
            original_params=data.get('original_params'),
            corrected_target=data.get('corrected_target'),
            corrected_params=data.get('corrected_params'),
            correction_description=data.get('correction_description', ''),
            source=source,
            success_count=data.get('success_count', 0),
            failure_count=data.get('failure_count', 0),
            last_applied=last_applied,
            status=status,
            created_at=created_at,
            updated_at=updated_at,
            tags=data.get('tags', [])
        )
@dataclass
class CorrectionPackMetadata:
    """Descriptive metadata for a correction pack (version, author, tags...)."""
    version: str = "1.0.0"
    author: str = ""
    category: str = "general"
    tags: List[str] = field(default_factory=list)
    description: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Serialize every metadata field to a plain dict."""
        return {
            name: getattr(self, name)
            for name in ('version', 'author', 'category', 'tags', 'description')
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionPackMetadata':
        """Deserialize, applying the field defaults for missing keys."""
        return cls(
            version=data.get('version', '1.0.0'),
            author=data.get('author', ''),
            category=data.get('category', 'general'),
            tags=data.get('tags', []),
            description=data.get('description', '')
        )
@dataclass
class CorrectionPack:
    """
    A pack of corrections that can be exported, imported, and shared.

    Holds corrections keyed by their ID, plus descriptive metadata,
    timestamps and a running version counter.
    """
    id: str
    name: str
    description: str = ""
    # Corrections in this pack, keyed by correction ID.
    corrections: Dict[str, Correction] = field(default_factory=dict)
    # Descriptive metadata.
    metadata: CorrectionPackMetadata = field(default_factory=CorrectionPackMetadata)
    # Timestamps.
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    # Next version number to be assigned by the repository.
    current_version: int = 1

    @property
    def correction_count(self) -> int:
        """Number of corrections in the pack."""
        return len(self.corrections)

    @property
    def total_applications(self) -> int:
        """Total number of recorded applications across all corrections."""
        return sum(c.success_count + c.failure_count for c in self.corrections.values())

    @property
    def overall_success_rate(self) -> float:
        """Pack-wide success rate (0.0 when nothing has been applied)."""
        successes = sum(c.success_count for c in self.corrections.values())
        attempts = self.total_applications
        return successes / attempts if attempts else 0.0

    @property
    def active_corrections(self) -> List[Correction]:
        """Only the corrections whose status is ACTIVE."""
        return [c for c in self.corrections.values() if c.status == CorrectionStatus.ACTIVE]

    def add_correction(self, correction: Correction) -> bool:
        """Add a correction; returns False when its ID is already present."""
        if correction.id in self.corrections:
            return False
        self.corrections[correction.id] = correction
        self.updated_at = datetime.now()
        return True

    def remove_correction(self, correction_id: str) -> bool:
        """Remove a correction by ID; returns False when absent."""
        if correction_id not in self.corrections:
            return False
        del self.corrections[correction_id]
        self.updated_at = datetime.now()
        return True

    def get_correction(self, correction_id: str) -> Optional[Correction]:
        """Look up a correction by ID (None when absent)."""
        return self.corrections.get(correction_id)

    def find_by_key_hash(self, key_hash: str) -> List[Correction]:
        """All ACTIVE corrections whose key hashes to *key_hash*."""
        return [
            corr for corr in self.corrections.values()
            if corr.status == CorrectionStatus.ACTIVE and corr.key.to_hash() == key_hash
        ]

    def merge(self, other: 'CorrectionPack', conflict_strategy: str = 'keep_higher_confidence') -> int:
        """
        Merge another pack into this one.

        Args:
            other: Pack to merge
            conflict_strategy: How to handle conflicts:
                - 'keep_higher_confidence': Keep correction with higher confidence
                - 'keep_newer': Keep the most recently updated correction
                - 'keep_existing': Keep existing, ignore new
                - 'replace': Always replace with new

        Returns:
            Number of corrections added/updated
        """
        def incoming_wins(new: Correction, old: Correction) -> bool:
            # Decide whether the incoming correction replaces the existing
            # one under the chosen strategy; unrecognized strategies
            # behave like 'keep_existing'.
            if conflict_strategy == 'keep_higher_confidence':
                return new.confidence_score > old.confidence_score
            if conflict_strategy == 'keep_newer':
                return new.updated_at > old.updated_at
            if conflict_strategy == 'replace':
                return True
            return False

        changed = 0
        for cid, incoming in other.corrections.items():
            existing = self.corrections.get(cid)
            if existing is None or incoming_wins(incoming, existing):
                self.corrections[cid] = incoming
                changed += 1
        if changed:
            self.updated_at = datetime.now()
        return changed

    def get_statistics(self) -> Dict[str, Any]:
        """Aggregate statistics for the pack (counts, rates, type breakdown)."""
        everything = list(self.corrections.values())
        active = [c for c in everything if c.status == CorrectionStatus.ACTIVE]
        by_type: Dict[str, int] = {}
        for corr in everything:
            label = corr.correction_type.value
            by_type[label] = 1 + by_type.get(label, 0)
        avg_confidence = (
            sum(c.confidence_score for c in active) / len(active) if active else 0.0
        )
        return {
            'total_corrections': len(everything),
            'active_corrections': len(active),
            'deprecated_corrections': sum(1 for c in everything if c.status == CorrectionStatus.DEPRECATED),
            'disabled_corrections': sum(1 for c in everything if c.status == CorrectionStatus.DISABLED),
            'total_applications': self.total_applications,
            'overall_success_rate': self.overall_success_rate,
            'corrections_by_type': by_type,
            'average_confidence': avg_confidence
        }

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the pack, embedding per-correction dicts and statistics."""
        return {
            'id': self.id,
            'name': self.name,
            'description': self.description,
            'corrections': {cid: corr.to_dict() for cid, corr in self.corrections.items()},
            'metadata': self.metadata.to_dict(),
            'created_at': self.created_at.isoformat(),
            'updated_at': self.updated_at.isoformat(),
            'current_version': self.current_version,
            'statistics': self.get_statistics()
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionPack':
        """Deserialize a pack; missing timestamps default to now()."""
        def as_datetime(raw):
            # ISO strings come from to_dict(); an already-parsed datetime
            # passes through; None becomes now().
            if isinstance(raw, str):
                return datetime.fromisoformat(raw)
            if raw is None:
                return datetime.now()
            return raw

        restored = {
            cid: Correction.from_dict(cdata)
            for cid, cdata in data.get('corrections', {}).items()
        }
        return cls(
            id=data.get('id', str(uuid.uuid4())),
            name=data.get('name', 'Unnamed Pack'),
            description=data.get('description', ''),
            corrections=restored,
            metadata=CorrectionPackMetadata.from_dict(data.get('metadata', {})),
            created_at=as_datetime(data.get('created_at')),
            updated_at=as_datetime(data.get('updated_at')),
            current_version=data.get('current_version', 1)
        )
def generate_correction_id() -> str:
    """Return a fresh correction ID of the form ``corr_<12 hex chars>``."""
    return "corr_" + uuid.uuid4().hex[:12]
def generate_pack_id() -> str:
    """Return a fresh pack ID of the form ``pack_<12 hex chars>``."""
    return "pack_" + uuid.uuid4().hex[:12]

View File

@@ -0,0 +1,443 @@
"""Main self-healing engine for workflow recovery."""
import time
import logging
from typing import List, Optional, Dict, Any, Tuple
from pathlib import Path
from .models import RecoveryContext, RecoveryResult, RecoverySuggestion
from .learning_repository import LearningRepository
from .confidence_scorer import ConfidenceScorer
from .strategies import (
RecoveryStrategy,
SemanticVariantStrategy,
SpatialFallbackStrategy,
TimingAdaptationStrategy,
FormatTransformationStrategy
)
logger = logging.getLogger(__name__)
class SelfHealingEngine:
"""Main engine for self-healing workflows."""
def __init__(
    self,
    learning_repo: Optional[LearningRepository] = None,
    storage_path: Optional[Path] = None,
    correction_packs_enabled: bool = True
):
    """
    Initialize the self-healing engine.

    Args:
        learning_repo: Learning repository instance (a default one is
            created under data/healing when omitted)
        storage_path: Path for storing learned patterns
        correction_packs_enabled: Whether to consult correction packs
    """
    # Use the injected repository when provided, otherwise build one
    # over the default (or supplied) storage path.
    if learning_repo:
        self.learning_repo = learning_repo
    else:
        self.learning_repo = LearningRepository(storage_path or Path('data/healing'))

    self.confidence_scorer = ConfidenceScorer()

    # Ordered list of strategies tried during recovery.
    self.recovery_strategies: List[RecoveryStrategy] = [
        SemanticVariantStrategy(),
        SpatialFallbackStrategy(),
        TimingAdaptationStrategy(),
        FormatTransformationStrategy()
    ]

    # Hard limits for a single recovery pass.
    self.max_recovery_time = 30.0  # seconds
    self.parallel_execution = True

    # Correction-pack integration is lazy: the service is only built on
    # first access of `correction_pack_service`.
    self.correction_packs_enabled = correction_packs_enabled
    self._correction_pack_service = None
@property
def correction_pack_service(self):
"""Lazy-load correction pack service."""
if self._correction_pack_service is None and self.correction_packs_enabled:
try:
from core.corrections import CorrectionPackService
self._correction_pack_service = CorrectionPackService()
logger.info("CorrectionPackService initialized for healing engine")
except ImportError as e:
logger.warning(f"CorrectionPackService not available: {e}")
self.correction_packs_enabled = False
return self._correction_pack_service
def attempt_recovery(self, context: RecoveryContext) -> RecoveryResult:
    """
    Attempt to recover from a workflow failure.

    Resolution order: correction packs first (when enabled), then each
    recovery strategy in learned-priority order. The first successful,
    sufficiently-confident result is returned and stored for reuse.

    Args:
        context: Recovery context with failure information

    Returns:
        RecoveryResult with outcome
    """
    start_time = time.time()
    # Give up immediately once the caller-defined retry budget is spent.
    if context.attempt_count >= context.max_attempts:
        return RecoveryResult(
            success=False,
            strategy_used='none',
            error_message=f'Max recovery attempts ({context.max_attempts}) exceeded'
        )
    # Try correction packs first (if enabled); a successful pack hit
    # short-circuits the strategy loop entirely.
    if self.correction_packs_enabled and self.correction_pack_service:
        pack_result = self._try_correction_packs(context)
        if pack_result and pack_result.success:
            return pack_result
    # Get learned patterns for this context
    learned_patterns = self.learning_repo.get_matching_patterns(context)
    # Order strategies so historically successful ones run first.
    strategies = self._prioritize_strategies(context, learned_patterns)
    # Try each strategy
    for strategy in strategies:
        # Stop trying once the overall time budget (max_recovery_time)
        # is exhausted.
        elapsed = time.time() - start_time
        if elapsed >= self.max_recovery_time:
            break
        # Skip if strategy can't handle this failure
        if not strategy.can_handle(context):
            continue
        # Attempt recovery
        result = strategy.attempt_recovery(context)
        # If successful, learn from it
        if result.success:
            # Blend the strategy's confidence with this context's
            # historical success rate.
            historical_success = self._get_historical_success_rate(
                strategy.name, context, learned_patterns
            )
            result.confidence_score = self.confidence_scorer.calculate_recovery_confidence(
                result.strategy_used,
                context,
                historical_success
            )
            # Data-modifying actions get a stricter safety bar.
            if self.confidence_scorer.is_safe_to_proceed(
                result.confidence_score,
                context.confidence_threshold,
                involves_data_modification=self._involves_data_modification(context)
            ):
                # Learn from success
                self.learn_from_success(context, result)
                return result
            else:
                # NOTE(review): result.success remains True here even
                # though confidence was too low — callers must check
                # requires_user_input before acting on the result.
                result.requires_user_input = True
                return result
    # No strategy produced an acceptable result (or time ran out).
    total_time = time.time() - start_time
    return RecoveryResult(
        success=False,
        strategy_used='all_failed',
        execution_time=total_time,
        error_message='All recovery strategies failed',
        requires_user_input=True
    )
def learn_from_success(self, context: RecoveryContext, result: RecoveryResult):
"""
Learn from successful recovery for future use.
Args:
context: Recovery context
result: Successful recovery result
"""
self.learning_repo.store_pattern(context, result)
def get_recovery_suggestions(self, context: RecoveryContext) -> List[RecoverySuggestion]:
"""
Get ranked recovery suggestions based on learned patterns.
Args:
context: Recovery context
Returns:
List of recovery suggestions sorted by confidence
"""
suggestions = []
# Get learned patterns
learned_patterns = self.learning_repo.get_matching_patterns(context)
# Get suggestions from each strategy
for strategy in self.recovery_strategies:
if not strategy.can_handle(context):
continue
# Calculate confidence based on historical success
historical_success = self._get_historical_success_rate(
strategy.name, context, learned_patterns
)
confidence = self.confidence_scorer.calculate_recovery_confidence(
strategy.name,
context,
historical_success
)
suggestion = RecoverySuggestion(
strategy=strategy.name,
confidence=confidence,
description=self._get_strategy_description(strategy),
estimated_time=self._estimate_strategy_time(strategy, context)
)
suggestions.append(suggestion)
# Sort by confidence
suggestions.sort(key=lambda s: s.confidence, reverse=True)
return suggestions
def prune_learned_patterns(
self,
max_age_days: int = 90,
min_confidence: float = 0.3
):
"""
Prune outdated learned patterns.
Args:
max_age_days: Maximum age for patterns
min_confidence: Minimum confidence threshold
"""
self.learning_repo.prune_outdated_patterns(max_age_days, min_confidence)
def _prioritize_strategies(
self,
context: RecoveryContext,
learned_patterns: List
) -> List[RecoveryStrategy]:
"""Prioritize strategies based on context and learned patterns."""
# Score each strategy
scored_strategies = []
for strategy in self.recovery_strategies:
# Base priority from strategy
priority = strategy.get_priority(context)
# Boost priority if we have successful patterns
historical_success = self._get_historical_success_rate(
strategy.name, context, learned_patterns
)
if historical_success > 0:
priority *= (1.0 + historical_success)
scored_strategies.append((priority, strategy))
# Sort by priority (highest first)
scored_strategies.sort(key=lambda x: x[0], reverse=True)
return [strategy for _, strategy in scored_strategies]
def _get_historical_success_rate(
self,
strategy_name: str,
context: RecoveryContext,
learned_patterns: List
) -> float:
"""Get historical success rate for a strategy."""
matching_patterns = [
p for p in learned_patterns
if p.recovery_strategy == strategy_name
]
if not matching_patterns:
return 0.0
# Return average success rate
return sum(p.success_rate for p in matching_patterns) / len(matching_patterns)
def _involves_data_modification(self, context: RecoveryContext) -> bool:
"""Check if the action involves data modification."""
data_modification_actions = [
'input', 'type', 'submit', 'delete', 'update', 'save'
]
return any(action in context.original_action.lower()
for action in data_modification_actions)
def _get_strategy_description(self, strategy: RecoveryStrategy) -> str:
"""Get human-readable description of strategy."""
descriptions = {
'SemanticVariantStrategy': 'Try semantic variants of the element (e.g., Submit → Send)',
'SpatialFallbackStrategy': 'Search in expanded area around original position',
'TimingAdaptationStrategy': 'Increase wait time for element to appear',
'FormatTransformationStrategy': 'Transform input format to match validation'
}
return descriptions.get(strategy.name, strategy.name)
def _estimate_strategy_time(
self,
strategy: RecoveryStrategy,
context: RecoveryContext
) -> float:
"""Estimate execution time for strategy."""
# Simple estimates in seconds
estimates = {
'SemanticVariantStrategy': 2.0,
'SpatialFallbackStrategy': 5.0,
'TimingAdaptationStrategy': 10.0,
'FormatTransformationStrategy': 1.0
}
return estimates.get(strategy.name, 3.0)
def _try_correction_packs(self, context: RecoveryContext) -> Optional[RecoveryResult]:
    """
    Try to find and apply a correction from correction packs.

    Candidates come back from the service already filtered by confidence;
    each is applied in turn until one sticks. Successful applications are
    recorded back to the pack and propagated to the learning repository.

    Args:
        context: Recovery context

    Returns:
        RecoveryResult if a correction was successfully applied, None otherwise
    """
    try:
        # 'unknown' still allows matching on action type and failure context.
        element_type = context.metadata.get('element_type', 'unknown')
        # Search for applicable corrections
        corrections = self.correction_pack_service.find_applicable_corrections(
            action_type=context.original_action,
            element_type=element_type,
            failure_context=context.failure_reason,
            min_confidence=0.5  # Only use high-confidence corrections
        )
        if not corrections:
            return None
        # Try corrections in order of confidence
        for correction_info in corrections:
            pack_id = correction_info['pack_id']
            correction = correction_info['correction']
            try:
                # Apply the correction (annotates context.metadata only;
                # see _apply_correction).
                applied = self._apply_correction(context, correction)
                if applied:
                    # Record successful application back to the pack's stats.
                    self.correction_pack_service.apply_correction(
                        pack_id=pack_id,
                        correction_id=correction['id'],
                        success=True
                    )
                    # NOTE(review): corrected_params is passed as
                    # modified_action — confirm RecoveryResult really
                    # expects the params dict in that field.
                    result = RecoveryResult(
                        success=True,
                        strategy_used=f"correction_pack:{correction['correction_type']}",
                        confidence_score=correction['confidence_score'],
                        modified_action=correction.get('corrected_params'),
                        modified_target=correction.get('corrected_target')
                    )
                    # Learn from this for future reference
                    self.learning_repo.store_pattern(context, result)
                    logger.info(
                        f"Applied correction from pack {pack_id}: "
                        f"{correction['id']} (confidence: {correction['confidence_score']:.2f})"
                    )
                    return result
            except Exception as e:
                logger.warning(f"Failed to apply correction {correction['id']}: {e}")
                # Record failed application, then fall through to the
                # next candidate.
                self.correction_pack_service.apply_correction(
                    pack_id=pack_id,
                    correction_id=correction['id'],
                    success=False
                )
        # No candidate could be applied.
        return None
    except Exception as e:
        # Pack lookup is best-effort: any failure here must not break
        # the main recovery flow, so log and signal "no pack result".
        logger.error(f"Error in correction pack lookup: {e}")
        return None
def _apply_correction(
self,
context: RecoveryContext,
correction: Dict[str, Any]
) -> bool:
"""
Apply a correction to the current context.
Args:
context: Recovery context
correction: Correction data dict
Returns:
True if correction was applied successfully
"""
correction_type = correction.get('correction_type', 'other')
# Handle different correction types
if correction_type == 'target_change':
# Update target in context
if correction.get('corrected_target'):
context.metadata['corrected_target'] = correction['corrected_target']
return True
elif correction_type == 'parameter_change':
# Update parameters in context
if correction.get('corrected_params'):
context.metadata['corrected_params'] = correction['corrected_params']
return True
elif correction_type == 'wait_added':
# Add wait time
wait_time = correction.get('corrected_params', {}).get('wait_time', 2.0)
context.metadata['additional_wait'] = wait_time
return True
elif correction_type == 'timing_adjust':
# Adjust timing
timing = correction.get('corrected_params', {}).get('timing_multiplier', 1.5)
context.metadata['timing_multiplier'] = timing
return True
elif correction_type == 'coordinates_adjust':
# Adjust coordinates
offset = correction.get('corrected_params', {}).get('offset', {})
context.metadata['coordinate_offset'] = offset
return True
# For other types, just mark as applied if we have any correction data
return bool(correction.get('corrected_target') or correction.get('corrected_params'))
def get_correction_pack_statistics(self) -> Optional[Dict[str, Any]]:
"""
Get statistics from correction packs.
Returns:
Statistics dict or None if not available
"""
if not self.correction_packs_enabled or not self.correction_pack_service:
return None
try:
return self.correction_pack_service.get_global_statistics()
except Exception as e:
logger.error(f"Error getting correction pack statistics: {e}")
return None