feat(corrections): Add Correction Packs system for cross-workflow learning
Implement a complete system for capitalizing on user corrections across multiple workflows and sessions. This enables automatic application of learned fixes when similar failures occur in different contexts. New components: - core/corrections/models.py: CorrectionKey, Correction, CorrectionPack models - core/corrections/correction_repository.py: JSON storage with atomic writes - core/corrections/aggregator.py: Aggregation by hash and quality filtering - core/corrections/correction_pack_service.py: CRUD, export/import, versioning - backend/api/correction_packs.py: REST API with 15 endpoints Features: - MD5-based key hashing for correction deduplication - Export/import in JSON and YAML formats - Version history with rollback support - Cross-workflow pattern detection - Integration with SelfHealingEngine for automatic application - 29 unit tests (all passing) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
30
core/corrections/__init__.py
Normal file
30
core/corrections/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""
Correction Packs System - cross-workflow reuse of user corrections.

This module provides a system for storing, aggregating, and applying
user corrections across multiple workflows and sessions.
"""

from .models import (
    CorrectionKey,
    CorrectionType,
    CorrectionStatus,
    CorrectionSource,
    Correction,
    CorrectionPack,
    CorrectionPackMetadata
)
from .correction_repository import CorrectionRepository
from .aggregator import CorrectionAggregator
from .correction_pack_service import CorrectionPackService

# Public API of the package. CorrectionSource is exported alongside the
# other models because it is attached to every Correction produced by the
# aggregator and service, so downstream code needs it for type checks.
__all__ = [
    'CorrectionKey',
    'CorrectionType',
    'CorrectionStatus',
    'CorrectionSource',
    'Correction',
    'CorrectionPack',
    'CorrectionPackMetadata',
    'CorrectionRepository',
    'CorrectionAggregator',
    'CorrectionPackService'
]
|
||||||
545
core/corrections/aggregator.py
Normal file
545
core/corrections/aggregator.py
Normal file
@@ -0,0 +1,545 @@
|
|||||||
|
"""
|
||||||
|
Correction Aggregator - Aggregate corrections from multiple sessions.
|
||||||
|
|
||||||
|
Groups corrections by key hash and filters by quality thresholds.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, List, Optional, Any, Tuple
|
||||||
|
|
||||||
|
from .models import (
|
||||||
|
Correction,
|
||||||
|
CorrectionKey,
|
||||||
|
CorrectionType,
|
||||||
|
CorrectionSource,
|
||||||
|
CorrectionStatus,
|
||||||
|
generate_correction_id
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CorrectionAggregator:
    """
    Aggregates corrections collected from multiple sources (sessions, workflows).

    Corrections are grouped by their CorrectionKey hash; groups with duplicates
    are merged into a single correction with combined statistics, and entries
    falling below the configured quality thresholds are dropped.
    """

    def __init__(
        self,
        min_occurrences: int = 2,
        min_success_rate: float = 0.5,
        min_confidence: float = 0.3
    ):
        """
        Initialize the aggregator.

        Args:
            min_occurrences: Minimum number of times a correction must occur
            min_success_rate: Minimum success rate to include correction
            min_confidence: Minimum confidence score to include correction
        """
        self.min_occurrences = min_occurrences
        self.min_success_rate = min_success_rate
        self.min_confidence = min_confidence

    def aggregate_corrections(
        self,
        corrections: List[Correction],
        merge_strategy: str = 'weighted_average'
    ) -> List[Correction]:
        """
        Aggregate a list of corrections by key hash.

        Args:
            corrections: Corrections to aggregate
            merge_strategy: How duplicates sharing a key hash are merged:
                'weighted_average' (combine statistics, weighted by usage),
                'best_confidence' (keep the highest-confidence entry), or
                'most_recent' (keep the most recently used entry).

        Returns:
            Aggregated corrections that pass the quality thresholds.
        """
        buckets: Dict[str, List[Correction]] = defaultdict(list)
        for item in corrections:
            buckets[item.key.to_hash()].append(item)

        kept: List[Correction] = []
        for bucket in buckets.values():
            # A singleton bucket needs no merging, only the threshold check.
            if len(bucket) == 1:
                candidate = bucket[0]
            else:
                candidate = self._merge_corrections(bucket, merge_strategy)
            if candidate is not None and self._passes_thresholds(candidate):
                kept.append(candidate)
        return kept

    def aggregate_from_sessions(
        self,
        session_corrections: List[Dict[str, Any]],
        workflow_id: Optional[str] = None
    ) -> List[Correction]:
        """
        Aggregate corrections from session data (user_corrections format).

        Args:
            session_corrections: Correction dicts from TrainingSession
            workflow_id: Optional workflow ID; when given, only corrections
                whose source matches that workflow are kept.

        Returns:
            List of aggregated Correction objects.
        """
        converted: List[Correction] = []
        for raw in session_corrections:
            try:
                candidate = self._session_correction_to_correction(raw)
            except Exception as e:
                logger.warning(f"Error converting session correction: {e}")
                continue
            if candidate is None:
                continue
            if workflow_id is None or (candidate.source and candidate.source.workflow_id == workflow_id):
                converted.append(candidate)

        return self.aggregate_corrections(converted)

    def group_by_type(
        self,
        corrections: List[Correction]
    ) -> Dict[CorrectionType, List[Correction]]:
        """
        Group corrections by their correction type.

        Args:
            corrections: List of corrections

        Returns:
            Dict mapping correction type to list of corrections.
        """
        result: Dict[CorrectionType, List[Correction]] = {}
        for item in corrections:
            result.setdefault(item.correction_type, []).append(item)
        return result

    def group_by_action_type(
        self,
        corrections: List[Correction]
    ) -> Dict[str, List[Correction]]:
        """
        Group corrections by the action type they apply to.

        Args:
            corrections: List of corrections

        Returns:
            Dict mapping action type (from the correction key) to corrections.
        """
        result: Dict[str, List[Correction]] = {}
        for item in corrections:
            result.setdefault(item.key.action_type, []).append(item)
        return result

    def filter_by_quality(
        self,
        corrections: List[Correction],
        min_success_rate: Optional[float] = None,
        min_confidence: Optional[float] = None,
        min_applications: Optional[int] = None
    ) -> List[Correction]:
        """
        Filter corrections by quality thresholds.

        Args:
            corrections: Corrections to filter
            min_success_rate: Minimum success rate (overrides instance default)
            min_confidence: Minimum confidence (overrides instance default)
            min_applications: Minimum number of applications (defaults to 1)

        Returns:
            Corrections meeting every threshold.
        """
        threshold_sr = self.min_success_rate if min_success_rate is None else min_success_rate
        threshold_conf = self.min_confidence if min_confidence is None else min_confidence
        threshold_apps = 1 if min_applications is None else min_applications

        kept: List[Correction] = []
        for item in corrections:
            applications = item.success_count + item.failure_count
            if (item.success_rate >= threshold_sr
                    and item.confidence_score >= threshold_conf
                    and applications >= threshold_apps):
                kept.append(item)
        return kept

    def rank_corrections(
        self,
        corrections: List[Correction],
        sort_by: str = 'confidence'
    ) -> List[Correction]:
        """
        Rank corrections by the requested criterion.

        Args:
            corrections: Corrections to rank
            sort_by: One of 'confidence', 'success_rate', 'usage', 'recent'.
                Any other value returns the input unchanged.

        Returns:
            Corrections sorted best-first.
        """
        key_functions = {
            'confidence': lambda c: c.confidence_score,
            'success_rate': lambda c: c.success_rate,
            'usage': lambda c: c.success_count + c.failure_count,
            # Never-applied corrections sort last via datetime.min.
            'recent': lambda c: c.last_applied or datetime.min,
        }
        key_fn = key_functions.get(sort_by)
        if key_fn is None:
            return corrections
        return sorted(corrections, key=key_fn, reverse=True)

    def calculate_statistics(
        self,
        corrections: List[Correction]
    ) -> Dict[str, Any]:
        """
        Calculate aggregate statistics for a list of corrections.

        Args:
            corrections: Corrections to summarize

        Returns:
            Statistics dict (counts, averages, breakdowns by type/action).
        """
        if not corrections:
            return {
                'count': 0,
                'total_applications': 0,
                'average_success_rate': 0.0,
                'average_confidence': 0.0,
                'by_type': {},
                'by_action': {}
            }

        total_apps = sum(c.success_count + c.failure_count for c in corrections)
        total_success = sum(c.success_count for c in corrections)

        type_counts = {
            ctype.value: len(group)
            for ctype, group in self.group_by_type(corrections).items()
        }
        action_counts = {
            action: len(group)
            for action, group in self.group_by_action_type(corrections).items()
        }

        return {
            'count': len(corrections),
            'total_applications': total_apps,
            'total_success': total_success,
            'total_failure': total_apps - total_success,
            'average_success_rate': total_success / total_apps if total_apps > 0 else 0.0,
            'average_confidence': sum(c.confidence_score for c in corrections) / len(corrections),
            'by_type': type_counts,
            'by_action': action_counts
        }

    def _passes_thresholds(self, correction: Correction) -> bool:
        """Check if a correction passes the configured quality thresholds."""
        applications = correction.success_count + correction.failure_count

        # Brand-new corrections (never applied) are always kept.
        if applications == 0:
            return True

        return (
            applications >= self.min_occurrences
            and correction.success_rate >= self.min_success_rate
            and correction.confidence_score >= self.min_confidence
        )

    def _merge_corrections(
        self,
        corrections: List[Correction],
        strategy: str = 'weighted_average'
    ) -> Optional[Correction]:
        """
        Merge corrections sharing the same key hash into one.

        Args:
            corrections: Corrections to merge
            strategy: 'best_confidence', 'most_recent', or anything else
                for the default 'weighted_average' merge.

        Returns:
            Merged correction, or None for an empty input.
        """
        if not corrections:
            return None
        if len(corrections) == 1:
            return corrections[0]

        if strategy == 'best_confidence':
            return max(corrections, key=lambda c: c.confidence_score)
        if strategy == 'most_recent':
            return max(corrections, key=lambda c: c.last_applied or c.created_at)

        # weighted_average (default): take descriptive fields from the
        # most-used entry and sum the usage statistics across the group.
        base = max(corrections, key=lambda c: c.success_count + c.failure_count)
        total_success = sum(c.success_count for c in corrections)
        total_failure = sum(c.failure_count for c in corrections)

        applied_dates = [c.last_applied for c in corrections if c.last_applied]
        last_applied = max(applied_dates) if applied_dates else None

        merged_tags = set()
        for item in corrections:
            merged_tags.update(item.tags)

        return Correction(
            id=generate_correction_id(),
            key=base.key,
            correction_type=base.correction_type,
            original_target=base.original_target,
            original_params=base.original_params,
            corrected_target=base.corrected_target,
            corrected_params=base.corrected_params,
            correction_description=base.correction_description,
            source=base.source,  # Keep first source as primary
            success_count=total_success,
            failure_count=total_failure,
            last_applied=last_applied,
            status=CorrectionStatus.ACTIVE,
            created_at=min(c.created_at for c in corrections),
            updated_at=datetime.now(),
            tags=list(merged_tags)
        )

    def _session_correction_to_correction(
        self,
        session_correction: Dict[str, Any]
    ) -> Optional[Correction]:
        """
        Convert a session correction dict to a Correction object.

        Accepts both the new field names and the legacy session names
        ('type', 'reason', 'new_target', 'new_params').

        Args:
            session_correction: Dict from TrainingSession.user_corrections

        Returns:
            Correction object, or None if the dict cannot be converted.
        """
        try:
            key = CorrectionKey(
                session_correction.get('action_type', session_correction.get('type', 'unknown')),
                session_correction.get('element_type', 'unknown'),
                session_correction.get('failure_reason', session_correction.get('reason', ''))
            )

            raw_type = session_correction.get('correction_type', 'other')
            try:
                correction_type = CorrectionType(raw_type)
            except ValueError:
                correction_type = CorrectionType.OTHER

            source = CorrectionSource(
                session_id=session_correction.get('session_id', ''),
                workflow_id=session_correction.get('workflow_id'),
                node_id=session_correction.get('node_id')
            )

            ts = session_correction.get('timestamp')
            if ts:
                if isinstance(ts, str):
                    source.timestamp = datetime.fromisoformat(ts)
                elif isinstance(ts, datetime):
                    source.timestamp = ts

            # An 'applied' correction counts as one success; otherwise one failure.
            applied = session_correction.get('applied', True)

            return Correction(
                id=generate_correction_id(),
                key=key,
                correction_type=correction_type,
                original_target=session_correction.get('original_target'),
                original_params=session_correction.get('original_params'),
                corrected_target=session_correction.get('corrected_target', session_correction.get('new_target')),
                corrected_params=session_correction.get('corrected_params', session_correction.get('new_params')),
                correction_description=session_correction.get('description', ''),
                source=source,
                success_count=1 if applied else 0,
                failure_count=0 if applied else 1,
                tags=session_correction.get('tags', [])
            )

        except Exception as e:
            logger.warning(f"Error converting session correction: {e}")
            return None
|
||||||
|
|
||||||
|
|
||||||
|
class CrossWorkflowAggregator(CorrectionAggregator):
    """
    Extended aggregator for cross-workflow correction analysis.

    Adds methods for finding corrections shared by several workflows and
    for detecting recurring failure/fix patterns.
    """

    def find_common_corrections(
        self,
        corrections: List[Correction],
        min_workflows: int = 2
    ) -> List[Tuple[Correction, List[str]]]:
        """
        Find corrections that appear in multiple workflows.

        Args:
            corrections: Corrections to analyze
            min_workflows: Minimum number of distinct workflows required

        Returns:
            (merged correction, workflow IDs) tuples, sorted so the
            corrections spanning the most workflows come first.
        """
        buckets: Dict[str, List[Correction]] = defaultdict(list)
        for item in corrections:
            buckets[item.key.to_hash()].append(item)

        shared: List[Tuple[Correction, List[str]]] = []
        for bucket in buckets.values():
            workflows = {
                c.source.workflow_id
                for c in bucket
                if c.source and c.source.workflow_id
            }
            if len(workflows) < min_workflows:
                continue
            merged = self._merge_corrections(bucket)
            if merged is not None:
                shared.append((merged, list(workflows)))

        shared.sort(key=lambda entry: len(entry[1]), reverse=True)
        return shared

    def detect_patterns(
        self,
        corrections: List[Correction]
    ) -> Dict[str, Any]:
        """
        Detect recurring patterns in a set of corrections.

        Looks for frequently corrected action types, element types with
        repeated issues, and a recent spike of activity (last 7 days).

        Args:
            corrections: Corrections to analyze

        Returns:
            Dict with a 'patterns' list and a 'summary' section.
        """
        if not corrections:
            return {'patterns': [], 'summary': {}}

        patterns: List[Dict[str, Any]] = []

        # Pattern 1: action types corrected at least 3 times.
        by_action = self.group_by_action_type(corrections)
        for action, group in by_action.items():
            if len(group) >= 3:
                patterns.append({
                    'type': 'frequent_action_correction',
                    'action_type': action,
                    'count': len(group),
                    'avg_success_rate': sum(c.success_rate for c in group) / len(group)
                })

        # Pattern 2: element types failing at least 3 times.
        by_element: Dict[str, List[Correction]] = {}
        for item in corrections:
            by_element.setdefault(item.key.element_type, []).append(item)

        for element, group in by_element.items():
            if len(group) >= 3:
                patterns.append({
                    'type': 'element_type_issues',
                    'element_type': element,
                    'count': len(group),
                    'common_fixes': self._find_common_fix_patterns(group)
                })

        # Pattern 3: more than half of the corrections used in the last week.
        now = datetime.now()
        recent = [
            c for c in corrections
            if c.last_applied and (now - c.last_applied).days < 7
        ]
        if len(recent) > len(corrections) * 0.5:
            patterns.append({
                'type': 'recent_activity_spike',
                'recent_count': len(recent),
                'total_count': len(corrections)
            })

        return {
            'patterns': patterns,
            'summary': {
                'total_corrections': len(corrections),
                'unique_action_types': len(by_action),
                'unique_element_types': len(by_element),
                'pattern_count': len(patterns)
            }
        }

    def _find_common_fix_patterns(
        self,
        corrections: List[Correction]
    ) -> List[str]:
        """Return the up-to-3 most common correction-type values in the list."""
        counts: Dict[str, int] = {}
        for item in corrections:
            value = item.correction_type.value
            counts[value] = counts.get(value, 0) + 1

        ranked = sorted(counts.items(), key=lambda pair: pair[1], reverse=True)
        return [name for name, _ in ranked[:3]]
|
||||||
785
core/corrections/correction_pack_service.py
Normal file
785
core/corrections/correction_pack_service.py
Normal file
@@ -0,0 +1,785 @@
|
|||||||
|
"""
|
||||||
|
Correction Pack Service - Main service for correction pack management.
|
||||||
|
|
||||||
|
Combines repository, aggregator, and provides high-level operations
|
||||||
|
for CRUD, aggregation, export/import, and application of corrections.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Any, Tuple
|
||||||
|
|
||||||
|
from .models import (
|
||||||
|
Correction,
|
||||||
|
CorrectionPack,
|
||||||
|
CorrectionKey,
|
||||||
|
CorrectionType,
|
||||||
|
CorrectionSource,
|
||||||
|
CorrectionStatus,
|
||||||
|
CorrectionPackMetadata,
|
||||||
|
generate_correction_id,
|
||||||
|
generate_pack_id
|
||||||
|
)
|
||||||
|
from .correction_repository import CorrectionRepository
|
||||||
|
from .aggregator import CorrectionAggregator, CrossWorkflowAggregator
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CorrectionPackService:
|
||||||
|
"""
|
||||||
|
Main service for managing correction packs.
|
||||||
|
|
||||||
|
Provides high-level operations combining the repository and aggregator.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
storage_path: str = "data/correction_packs",
|
||||||
|
training_data_path: str = "training_data"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage_path: Path for correction pack storage
|
||||||
|
training_data_path: Path to training data (for session import)
|
||||||
|
"""
|
||||||
|
self.storage_path = Path(storage_path)
|
||||||
|
self.training_data_path = Path(training_data_path)
|
||||||
|
|
||||||
|
# Initialize components
|
||||||
|
self.repository = CorrectionRepository(storage_path)
|
||||||
|
self.aggregator = CorrectionAggregator()
|
||||||
|
self.cross_workflow_aggregator = CrossWorkflowAggregator()
|
||||||
|
|
||||||
|
logger.info(f"CorrectionPackService initialized (storage={storage_path})")
|
||||||
|
|
||||||
|
# ========== Pack CRUD Operations ==========
|
||||||
|
|
||||||
|
def list_packs(
    self,
    category: Optional[str] = None,
    tags: Optional[List[str]] = None
) -> List[Dict[str, Any]]:
    """
    List all correction packs with optional filtering.

    Args:
        category: Keep only packs whose metadata category matches
        tags: Keep packs having at least one of these tags

    Returns:
        List of pack summary dicts (id, name, counts, timestamps).
    """
    selected = self.repository.list_packs()

    if category:
        selected = [p for p in selected if p.metadata.category == category]

    if tags:
        # Match-any semantics: one shared tag is enough.
        wanted = set(tags)
        selected = [p for p in selected if any(t in wanted for t in p.metadata.tags)]

    summaries: List[Dict[str, Any]] = []
    for pack in selected:
        summaries.append({
            'id': pack.id,
            'name': pack.name,
            'description': pack.description,
            'category': pack.metadata.category,
            'tags': pack.metadata.tags,
            'correction_count': pack.correction_count,
            'success_rate': pack.overall_success_rate,
            'created_at': pack.created_at.isoformat(),
            'updated_at': pack.updated_at.isoformat()
        })
    return summaries
|
||||||
|
|
||||||
|
def get_pack(self, pack_id: str) -> Optional[Dict[str, Any]]:
    """
    Get a pack by ID with full details.

    Thin wrapper over the repository lookup; the full serialization
    (including corrections) comes from CorrectionPack.to_dict().

    Args:
        pack_id: Pack ID to fetch

    Returns:
        Pack dict or None if no pack with that ID exists
    """
    pack = self.repository.get_pack(pack_id)
    if not pack:
        return None
    return pack.to_dict()
|
||||||
|
|
||||||
|
def create_pack(
    self,
    name: str,
    description: str = "",
    category: str = "general",
    tags: Optional[List[str]] = None,
    author: str = ""
) -> Dict[str, Any]:
    """
    Create a new correction pack.

    Args:
        name: Pack name
        description: Pack description
        category: Pack category (defaults to "general")
        tags: Pack tags (None becomes an empty list)
        author: Pack author

    Returns:
        The created pack, serialized to a dict.
    """
    # Category/tags/author are carried in the pack metadata.
    new_pack = self.repository.create_pack(
        name,
        description,
        {
            'category': category,
            'tags': tags or [],
            'author': author
        }
    )
    return new_pack.to_dict()
|
||||||
|
|
||||||
|
def update_pack(
    self,
    pack_id: str,
    name: Optional[str] = None,
    description: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None
) -> Optional[Dict[str, Any]]:
    """
    Update a pack's properties.

    Only arguments that are not None are forwarded to the repository,
    so omitted fields keep their current values.

    Args:
        pack_id: Pack ID
        name: New name
        description: New description
        metadata: New metadata

    Returns:
        Updated pack dict, or None if the pack was not found.
    """
    changes: Dict[str, Any] = {}
    for field, value in (('name', name),
                         ('description', description),
                         ('metadata', metadata)):
        if value is not None:
            changes[field] = value

    updated = self.repository.update_pack(pack_id, changes)
    if not updated:
        return None
    return updated.to_dict()
|
||||||
|
|
||||||
|
def delete_pack(self, pack_id: str) -> bool:
    """
    Delete a pack.

    Thin wrapper around the repository; no extra validation here.

    Args:
        pack_id: Pack ID

    Returns:
        True if deleted
    """
    return self.repository.delete_pack(pack_id)
|
||||||
|
|
||||||
|
# ========== Correction Operations ==========
|
||||||
|
|
||||||
|
def add_correction(
    self,
    pack_id: str,
    correction_data: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
    """
    Add a correction to a pack.

    Builds a Correction from a plain dict (key fields, correction type,
    optional source block) and stores it through the repository.

    Args:
        pack_id: Pack ID
        correction_data: Correction data dict

    Returns:
        Created correction dict, or None on failure (the error is logged).
    """
    try:
        new_key = CorrectionKey(
            action_type=correction_data.get('action_type', 'unknown'),
            element_type=correction_data.get('element_type', 'unknown'),
            failure_context=correction_data.get('failure_context', '')
        )

        # Unknown correction-type strings fall back to OTHER.
        raw_type = correction_data.get('correction_type', 'other')
        try:
            resolved_type = CorrectionType(raw_type)
        except ValueError:
            resolved_type = CorrectionType.OTHER

        source = None
        if 'source' in correction_data:
            src = correction_data['source']
            source = CorrectionSource(
                session_id=src.get('session_id', ''),
                workflow_id=src.get('workflow_id'),
                node_id=src.get('node_id')
            )

        new_correction = Correction(
            id=generate_correction_id(),
            key=new_key,
            correction_type=resolved_type,
            original_target=correction_data.get('original_target'),
            original_params=correction_data.get('original_params'),
            corrected_target=correction_data.get('corrected_target'),
            corrected_params=correction_data.get('corrected_params'),
            correction_description=correction_data.get('description', ''),
            source=source,
            tags=correction_data.get('tags', [])
        )

        if not self.repository.add_correction(pack_id, new_correction):
            return None
        return new_correction.to_dict()

    except Exception as e:
        logger.error(f"Error adding correction: {e}")
        return None
|
||||||
|
|
||||||
|
def add_correction_from_session(
    self,
    pack_id: str,
    session_correction: Dict[str, Any],
    session_id: str,
    workflow_id: Optional[str] = None
) -> Optional[Dict[str, Any]]:
    """
    Add a correction from session user_corrections format.

    Maps legacy session field names ('type', 'reason', 'new_target',
    'new_params') onto the pack correction schema and attaches source
    info. The mapping is done on a shallow copy, so the caller's
    session record is never mutated.

    Args:
        pack_id: Pack ID
        session_correction: Correction dict from TrainingSession.user_corrections
        session_id: Session ID
        workflow_id: Workflow ID

    Returns:
        Created correction dict or None
    """
    # Fix: operate on a copy — the original implementation wrote
    # 'source' and remapped fields directly into the caller's dict.
    data = dict(session_correction)

    # Enrich with source info
    data['source'] = {
        'session_id': session_id,
        'workflow_id': workflow_id,
        'node_id': data.get('node_id')
    }

    # Map old format fields to new format (new names win if both present).
    legacy_to_new = {
        'type': 'action_type',
        'reason': 'failure_context',
        'new_target': 'corrected_target',
        'new_params': 'corrected_params'
    }
    for old_field, new_field in legacy_to_new.items():
        if old_field in data and new_field not in data:
            data[new_field] = data[old_field]

    return self.add_correction(pack_id, data)
|
||||||
|
|
||||||
|
def update_correction(
    self,
    pack_id: str,
    correction_id: str,
    updates: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
    """
    Update a correction.

    Args:
        pack_id: Pack ID
        correction_id: Correction ID
        updates: Fields to update

    Returns:
        Updated correction dict, or None if the correction was not found
    """
    updated = self.repository.update_correction(pack_id, correction_id, updates)
    return updated.to_dict() if updated else None
def remove_correction(self, pack_id: str, correction_id: str) -> bool:
    """
    Remove a correction from a pack.

    Thin wrapper that delegates to the repository.

    Args:
        pack_id: Pack ID
        correction_id: Correction ID

    Returns:
        True if removed
    """
    was_removed = self.repository.remove_correction(pack_id, correction_id)
    return was_removed
# ========== Aggregation Operations ==========
|
||||||
|
|
||||||
|
def aggregate_from_sessions(
    self,
    pack_id: str,
    session_files: Optional[List[str]] = None,
    workflow_id: Optional[str] = None,
    min_occurrences: int = 1,
    min_success_rate: float = 0.0
) -> Dict[str, Any]:
    """
    Aggregate corrections from training sessions into a pack.

    Reads session JSON files, collects their 'user_corrections' entries,
    feeds them through the aggregator and adds the aggregated corrections
    to the target pack. Unreadable session files are logged and skipped.

    Args:
        pack_id: Pack ID to add corrections to
        session_files: Specific session files to process (None = all
            "session_*.json" files under self.training_data_path)
        workflow_id: Filter by workflow ID
        min_occurrences: Minimum occurrences to include
        min_success_rate: Minimum success rate to include

    Returns:
        Aggregation result summary
    """
    # Collect all session corrections
    all_corrections = []

    if session_files is None:
        # Find all session files
        # (assumes self.training_data_path is a pathlib.Path — glob is used)
        session_files = list(self.training_data_path.glob("session_*.json"))
    else:
        session_files = [Path(f) for f in session_files]

    sessions_processed = 0
    for session_file in session_files:
        try:
            with open(session_file, 'r', encoding='utf-8') as f:
                session_data = json.load(f)

            # Filter by workflow if specified
            if workflow_id and session_data.get('workflow_id') != workflow_id:
                continue

            user_corrections = session_data.get('user_corrections', [])
            session_id = session_data.get('session_id', session_file.stem)

            # NOTE: the session's correction dicts are annotated in place
            # with their originating session/workflow ids before collection.
            for uc in user_corrections:
                uc['session_id'] = session_id
                uc['workflow_id'] = session_data.get('workflow_id')
                all_corrections.append(uc)

            sessions_processed += 1

        except Exception as e:
            logger.warning(f"Error reading session file {session_file}: {e}")

    # Configure aggregator
    # NOTE(review): this mutates the shared aggregator's thresholds; the
    # new values persist for later callers — confirm that is intended.
    self.aggregator.min_occurrences = min_occurrences
    self.aggregator.min_success_rate = min_success_rate

    # Aggregate
    aggregated = self.aggregator.aggregate_from_sessions(all_corrections, workflow_id)

    # Add to pack; add_correction returns falsy for rejected duplicates
    added_count = 0
    for correction in aggregated:
        if self.repository.add_correction(pack_id, correction):
            added_count += 1

    return {
        'sessions_processed': sessions_processed,
        'corrections_found': len(all_corrections),
        'corrections_aggregated': len(aggregated),
        'corrections_added': added_count,
        'statistics': self.aggregator.calculate_statistics(aggregated)
    }
def aggregate_cross_workflow(
    self,
    pack_id: str,
    min_workflows: int = 2
) -> Dict[str, Any]:
    """
    Find and aggregate corrections common across multiple workflows.

    Scans the corrections of *all* packs, asks the cross-workflow
    aggregator which of them recur in at least ``min_workflows``
    workflows, and adds those to the target pack.

    Args:
        pack_id: Pack ID to add corrections to
        min_workflows: Minimum number of workflows a correction must appear in

    Returns:
        Aggregation result summary with counts and detected patterns
    """
    # Get existing corrections from all packs
    all_corrections = []
    for pack in self.repository.list_packs():
        all_corrections.extend(pack.corrections.values())

    # Find common corrections
    common = self.cross_workflow_aggregator.find_common_corrections(
        all_corrections, min_workflows
    )

    # Add to pack
    added_count = 0
    for correction, workflow_ids in common:
        # Tag with source workflows
        # NOTE(review): this mutates the correction object in place; if
        # find_common_corrections returns corrections still owned by other
        # packs, their tags/description change too — confirm it returns copies.
        correction.tags.extend([f"wf:{wid}" for wid in workflow_ids[:5]])  # Limit tags
        correction.correction_description += f" (common in {len(workflow_ids)} workflows)"

        if self.repository.add_correction(pack_id, correction):
            added_count += 1

    # Detect patterns
    patterns = self.cross_workflow_aggregator.detect_patterns(all_corrections)

    return {
        'total_corrections_analyzed': len(all_corrections),
        'common_corrections_found': len(common),
        'corrections_added': added_count,
        'patterns': patterns
    }
def find_applicable_corrections(
    self,
    action_type: str,
    element_type: str,
    failure_context: str,
    pack_ids: Optional[List[str]] = None,
    min_confidence: float = 0.3
) -> List[Dict[str, Any]]:
    """
    Find corrections applicable to a failure context.

    Args:
        action_type: Type of action that failed
        element_type: Type of element involved
        failure_context: Description of failure
        pack_ids: Specific packs to search (None or empty = all)
        min_confidence: Minimum confidence score

    Returns:
        List of dicts, each with 'pack_id' and 'correction' keys
    """
    matches = self.repository.find_by_context(action_type, element_type, failure_context)

    allowed_packs = set(pack_ids) if pack_ids else None

    applicable = []
    for matched_pack_id, correction in matches:
        if allowed_packs is not None and matched_pack_id not in allowed_packs:
            continue
        if correction.confidence_score < min_confidence:
            continue
        applicable.append({
            'pack_id': matched_pack_id,
            'correction': correction.to_dict()
        })
    return applicable
def apply_correction(
    self,
    pack_id: str,
    correction_id: str,
    success: bool,
    application_context: Optional[Dict[str, Any]] = None
) -> bool:
    """
    Record an application of a correction.

    Args:
        pack_id: Pack ID
        correction_id: Correction ID
        success: Whether the application was successful
        application_context: Optional context information
            (currently accepted but not stored)

    Returns:
        True if recorded.
        NOTE(review): always returns True — the repository silently
        ignores unknown pack/correction ids, so callers cannot detect
        a miss.
    """
    self.repository.record_application(pack_id, correction_id, success)

    logger.info(
        f"Recorded correction application: pack={pack_id}, "
        f"correction={correction_id}, success={success}"
    )
    return True
def search_corrections(
    self,
    query: str,
    pack_id: Optional[str] = None
) -> List[Dict[str, Any]]:
    """
    Search corrections by text query.

    Args:
        query: Search query
        pack_id: Optional pack ID to restrict the search to

    Returns:
        List of matching corrections with pack info
    """
    matches = self.repository.search_corrections(query, pack_id)

    formatted = []
    for matched_pack_id, correction in matches:
        formatted.append({
            'pack_id': matched_pack_id,
            'correction': correction.to_dict()
        })
    return formatted
# ========== Export/Import Operations ==========
|
||||||
|
|
||||||
|
def export_pack(
    self,
    pack_id: str,
    format: str = 'json',
    include_stats: bool = True
) -> Optional[str]:
    """
    Export a pack to JSON or YAML.

    Args:
        pack_id: Pack ID
        format: Export format ('json' or 'yaml')
        include_stats: Whether to include statistics.
            NOTE(review): accepted but not forwarded to the repository —
            confirm whether repository.export_pack should receive it.

    Returns:
        Exported content string or None
    """
    exported = self.repository.export_pack(pack_id, format)
    return exported
def import_pack(
    self,
    content: str,
    format: str = 'json',
    merge_strategy: str = 'create_new'
) -> Optional[Dict[str, Any]]:
    """
    Import a pack from JSON or YAML content.

    Args:
        content: Content to import
        format: Import format ('json' or 'yaml')
        merge_strategy: How to handle an existing pack

    Returns:
        Imported pack dict, or None on failure
    """
    imported = self.repository.import_pack(content, format, merge_strategy)
    return imported.to_dict() if imported else None
def export_pack_to_file(
    self,
    pack_id: str,
    file_path: str,
    format: str = 'json'
) -> bool:
    """
    Export a pack to a file on disk.

    Args:
        pack_id: Pack ID
        file_path: Output file path
        format: Export format ('json' or 'yaml')

    Returns:
        True on success, False if the pack was not found or the write failed
    """
    content = self.export_pack(pack_id, format)
    if not content:
        return False

    try:
        with open(file_path, 'w', encoding='utf-8') as out_file:
            out_file.write(content)
        logger.info(f"Exported pack {pack_id} to {file_path}")
    except Exception as e:
        logger.error(f"Error exporting pack to file: {e}")
        return False
    return True
def import_pack_from_file(
    self,
    file_path: str,
    format: Optional[str] = None,
    merge_strategy: str = 'create_new'
) -> Optional[Dict[str, Any]]:
    """
    Import a pack from a file on disk.

    Args:
        file_path: Input file path
        format: Import format; auto-detected from the extension when None
        merge_strategy: How to handle an existing pack

    Returns:
        Imported pack dict, or None on failure
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as in_file:
            content = in_file.read()

        if format is None:
            # Auto-detect format from the file extension
            is_yaml = file_path.endswith('.yaml') or file_path.endswith('.yml')
            format = 'yaml' if is_yaml else 'json'

        result = self.import_pack(content, format, merge_strategy)
        if result:
            logger.info(f"Imported pack from {file_path}")
        return result

    except Exception as e:
        logger.error(f"Error importing pack from file: {e}")
        return None
# ========== Version Operations ==========
|
||||||
|
|
||||||
|
def create_version(
    self,
    pack_id: str,
    description: str = ""
) -> int:
    """
    Create a version snapshot of a pack.

    Args:
        pack_id: Pack ID
        description: Version description

    Returns:
        Version number, or -1 on error
    """
    new_version = self.repository.create_version(pack_id, description)
    return new_version
def list_versions(self, pack_id: str) -> List[Dict[str, Any]]:
    """
    List all versions of a pack.

    Args:
        pack_id: Pack ID

    Returns:
        List of version info dicts
    """
    version_infos = self.repository.list_versions(pack_id)
    return version_infos
def rollback_pack(self, pack_id: str, version: int) -> bool:
    """
    Rollback a pack to a previous version.

    Args:
        pack_id: Pack ID
        version: Version number

    Returns:
        True if successful
    """
    succeeded = self.repository.rollback_pack(pack_id, version)
    return succeeded
# ========== Statistics and Analysis ==========
|
||||||
|
|
||||||
|
def get_pack_statistics(self, pack_id: str) -> Optional[Dict[str, Any]]:
    """
    Get detailed statistics for a pack.

    Args:
        pack_id: Pack ID

    Returns:
        Statistics dict, or None if the pack is unknown
    """
    pack = self.repository.get_pack(pack_id)
    return pack.get_statistics() if pack else None
def get_global_statistics(self) -> Dict[str, Any]:
    """
    Get global statistics across all packs.

    Returns:
        Dict with pack/correction/application totals, the overall
        success rate, a per-category pack count and the top corrections.
    """
    packs = self.repository.list_packs()

    # Accumulate all totals in a single pass over the packs
    total_corrections = 0
    total_applications = 0
    total_success = 0
    for pack in packs:
        total_corrections += pack.correction_count
        total_applications += pack.total_applications
        for correction in pack.corrections.values():
            total_success += correction.success_count

    if total_applications > 0:
        overall_success_rate = total_success / total_applications
    else:
        overall_success_rate = 0.0

    return {
        'total_packs': len(packs),
        'total_corrections': total_corrections,
        'total_applications': total_applications,
        'overall_success_rate': overall_success_rate,
        'packs_by_category': self._group_packs_by_category(packs),
        'top_corrections': self._get_top_corrections(packs, 10)
    }
def _group_packs_by_category(
    self,
    packs: List[CorrectionPack]
) -> Dict[str, int]:
    """Count packs per metadata category."""
    counts: Dict[str, int] = {}
    for pack in packs:
        category = pack.metadata.category
        counts[category] = counts.get(category, 0) + 1
    return counts
def _get_top_corrections(
    self,
    packs: List[CorrectionPack],
    limit: int = 10
) -> List[Dict[str, Any]]:
    """Return the top ``limit`` corrections across packs, highest confidence first."""
    entries = []
    for pack in packs:
        for correction in pack.corrections.values():
            entries.append((pack.id, pack.name, correction))

    # Highest confidence first; ties keep pack/insertion order (stable sort)
    entries.sort(key=lambda entry: entry[2].confidence_score, reverse=True)

    top = []
    for found_pack_id, found_pack_name, correction in entries[:limit]:
        top.append({
            'pack_id': found_pack_id,
            'pack_name': found_pack_name,
            'correction_id': correction.id,
            'action_type': correction.key.action_type,
            'confidence': correction.confidence_score,
            'success_rate': correction.success_rate,
            'applications': correction.success_count + correction.failure_count
        })
    return top
||||||
643
core/corrections/correction_repository.py
Normal file
643
core/corrections/correction_repository.py
Normal file
@@ -0,0 +1,643 @@
|
|||||||
|
"""
|
||||||
|
Correction Repository - JSON storage for corrections and packs.
|
||||||
|
|
||||||
|
Provides atomic file operations for safe persistence.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
|
||||||
|
from .models import (
|
||||||
|
Correction,
|
||||||
|
CorrectionPack,
|
||||||
|
CorrectionKey,
|
||||||
|
generate_pack_id
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CorrectionRepository:
|
||||||
|
"""
|
||||||
|
Repository for storing and retrieving correction packs.
|
||||||
|
|
||||||
|
Uses atomic JSON writes (temp file + rename) for safety.
|
||||||
|
Maintains an index by key hash for fast lookups.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, storage_path: str = "data/correction_packs"):
    """
    Initialize the repository.

    Creates the on-disk directory layout if missing, then loads all
    persisted packs into the in-memory cache and rebuilds the key index.

    Args:
        storage_path: Base path for storing correction data
    """
    self.storage_path = Path(storage_path)
    # Single JSON file holding every pack (written atomically elsewhere)
    self.packs_file = self.storage_path / "packs.json"
    self.corrections_dir = self.storage_path / "corrections"
    self.versions_dir = self.storage_path / "versions"

    # Ensure directories exist
    self._ensure_directories()

    # In-memory cache
    self._packs: Dict[str, CorrectionPack] = {}
    self._key_index: Dict[str, List[str]] = {}  # key_hash -> [correction_ids]

    # Load existing data (index must be rebuilt after packs are loaded)
    self._load_packs()
    self._rebuild_index()
def _ensure_directories(self):
    """Create the storage, corrections and versions directories if absent."""
    for directory in (self.storage_path, self.corrections_dir, self.versions_dir):
        directory.mkdir(parents=True, exist_ok=True)
# ========== Pack Operations ==========
|
||||||
|
|
||||||
|
def list_packs(self) -> List[CorrectionPack]:
    """Return all cached correction packs as a new list (insertion order)."""
    return list(self._packs.values())
def get_pack(self, pack_id: str) -> Optional[CorrectionPack]:
    """Return the pack with the given id from the cache, or None if unknown."""
    return self._packs.get(pack_id)
def create_pack(self, name: str, description: str = "",
                metadata: Optional[Dict[str, Any]] = None) -> CorrectionPack:
    """
    Create a new correction pack and persist it.

    Args:
        name: Pack name
        description: Pack description
        metadata: Optional metadata dict; recognised keys are
            'version', 'author', 'category' and 'tags'

    Returns:
        Created pack
    """
    pack_id = generate_pack_id()
    new_pack = CorrectionPack(
        id=pack_id,
        name=name,
        description=description
    )

    if metadata:
        meta = new_pack.metadata
        meta.version = metadata.get('version', '1.0.0')
        meta.author = metadata.get('author', '')
        meta.category = metadata.get('category', 'general')
        meta.tags = metadata.get('tags', [])

    self._packs[pack_id] = new_pack
    self._save_packs()

    logger.info(f"Created correction pack: {pack_id} ({name})")
    return new_pack
def update_pack(self, pack_id: str, updates: Dict[str, Any]) -> Optional[CorrectionPack]:
    """
    Update a pack's mutable properties and persist the change.

    Only 'name', 'description' and the metadata fields 'version',
    'author', 'category' and 'tags' may be updated.

    Args:
        pack_id: Pack ID
        updates: Dict of fields to update

    Returns:
        Updated pack, or None if not found
    """
    pack = self._packs.get(pack_id)
    if pack is None:
        return None

    if 'name' in updates:
        pack.name = updates['name']
    if 'description' in updates:
        pack.description = updates['description']
    if 'metadata' in updates:
        meta_updates = updates['metadata']
        for field in ('version', 'author', 'category', 'tags'):
            if field in meta_updates:
                setattr(pack.metadata, field, meta_updates[field])

    pack.updated_at = datetime.now()
    self._save_packs()

    logger.info(f"Updated correction pack: {pack_id}")
    return pack
def delete_pack(self, pack_id: str) -> bool:
    """
    Delete a pack, its key-index entries and its version history.

    Args:
        pack_id: Pack ID

    Returns:
        True if deleted, False if not found
    """
    pack = self._packs.get(pack_id)
    if pack is None:
        return False

    # Drop every correction of this pack from the key index
    for correction in pack.corrections.values():
        indexed_ids = self._key_index.get(correction.key.to_hash())
        if indexed_ids and correction.id in indexed_ids:
            indexed_ids.remove(correction.id)

    del self._packs[pack_id]
    self._save_packs()

    # Remove version history
    version_dir = self.versions_dir / pack_id
    if version_dir.exists():
        shutil.rmtree(version_dir)

    logger.info(f"Deleted correction pack: {pack_id}")
    return True
# ========== Correction Operations ==========
|
||||||
|
|
||||||
|
def add_correction(self, pack_id: str, correction: Correction) -> bool:
    """
    Add a correction to a pack and index it by key hash.

    Args:
        pack_id: Pack ID
        correction: Correction to add

    Returns:
        True if added, False if the pack is unknown or the pack
        rejected the correction
    """
    pack = self._packs.get(pack_id)
    if pack is None or not pack.add_correction(correction):
        return False

    # Update the key-hash index, avoiding duplicate correction ids
    bucket = self._key_index.setdefault(correction.key.to_hash(), [])
    if correction.id not in bucket:
        bucket.append(correction.id)

    self._save_packs()
    logger.info(f"Added correction {correction.id} to pack {pack_id}")
    return True
def remove_correction(self, pack_id: str, correction_id: str) -> bool:
    """
    Remove a correction from a pack and de-index it.

    Args:
        pack_id: Pack ID
        correction_id: Correction ID

    Returns:
        True if removed, False if the pack or correction is unknown
    """
    pack = self._packs.get(pack_id)
    correction = pack.get_correction(correction_id) if pack else None
    if correction is None:
        return False

    # De-index before removing from the pack
    indexed_ids = self._key_index.get(correction.key.to_hash())
    if indexed_ids and correction_id in indexed_ids:
        indexed_ids.remove(correction_id)

    pack.remove_correction(correction_id)
    self._save_packs()

    logger.info(f"Removed correction {correction_id} from pack {pack_id}")
    return True
def update_correction(self, pack_id: str, correction_id: str,
                      updates: Dict[str, Any]) -> Optional[Correction]:
    """
    Update a correction's mutable properties and persist the change.

    Updatable fields: 'corrected_target', 'corrected_params',
    'correction_description', 'tags', and 'status' (converted to the
    CorrectionStatus enum).

    Args:
        pack_id: Pack ID
        correction_id: Correction ID
        updates: Dict of fields to update

    Returns:
        Updated correction, or None if not found
    """
    pack = self._packs.get(pack_id)
    correction = pack.get_correction(correction_id) if pack else None
    if correction is None:
        return None

    # Plain attribute updates
    for field in ('corrected_target', 'corrected_params',
                  'correction_description', 'tags'):
        if field in updates:
            setattr(correction, field, updates[field])

    # Status strings are converted to the enum
    if 'status' in updates:
        from .models import CorrectionStatus
        correction.status = CorrectionStatus(updates['status'])

    correction.updated_at = datetime.now()
    pack.updated_at = datetime.now()
    self._save_packs()

    logger.info(f"Updated correction {correction_id} in pack {pack_id}")
    return correction
def record_application(self, pack_id: str, correction_id: str, success: bool):
    """
    Record one application of a correction and persist the new counts.

    Unknown pack or correction ids are silently ignored.

    Args:
        pack_id: Pack ID
        correction_id: Correction ID
        success: Whether the application was successful
    """
    pack = self._packs.get(pack_id)
    correction = pack.get_correction(correction_id) if pack else None
    if correction is None:
        return

    correction.record_application(success)
    pack.updated_at = datetime.now()
    self._save_packs()
# ========== Search Operations ==========
|
||||||
|
|
||||||
|
def find_by_key_hash(self, key_hash: str) -> List[tuple]:
    """
    Find all corrections matching a key hash, best first.

    Args:
        key_hash: MD5 hash of a correction key

    Returns:
        List of (pack_id, correction) tuples sorted by descending
        confidence score
    """
    matches = [
        (pack_id, correction)
        for pack_id, pack in self._packs.items()
        for correction in pack.find_by_key_hash(key_hash)
    ]
    matches.sort(key=lambda match: match[1].confidence_score, reverse=True)
    return matches
def find_by_context(self, action_type: str, element_type: str,
                    failure_context: str) -> List[tuple]:
    """
    Find corrections matching a failure context.

    Builds a CorrectionKey from the three components and looks up its hash.

    Args:
        action_type: Type of action that failed
        element_type: Type of element involved
        failure_context: Description of failure

    Returns:
        List of (pack_id, correction) tuples
    """
    lookup_key = CorrectionKey(action_type, element_type, failure_context)
    return self.find_by_key_hash(lookup_key.to_hash())
def search_corrections(self, query: str, pack_id: Optional[str] = None) -> List[tuple]:
    """
    Case-insensitive text search over corrections.

    A correction matches when the query appears in its description, in
    any of its tags, or in any component of its key.

    Args:
        query: Search query
        pack_id: Optional pack ID to restrict the search to

    Returns:
        List of (pack_id, correction) tuples
    """
    needle = query.lower()

    if pack_id and pack_id in self._packs:
        packs_to_search = [self._packs[pack_id]]
    else:
        packs_to_search = self._packs.values()

    matches = []
    for pack in packs_to_search:
        for correction in pack.corrections.values():
            key = correction.key
            haystacks = [correction.correction_description,
                         key.action_type, key.element_type, key.failure_context]
            haystacks.extend(correction.tags)
            if any(needle in text.lower() for text in haystacks):
                matches.append((pack.id, correction))

    return matches
# ========== Version Operations ==========
|
||||||
|
|
||||||
|
def create_version(self, pack_id: str, description: str = "") -> int:
    """
    Create a version snapshot of a pack.

    The snapshot is written under the pack's *current* version number and
    the pack's ``current_version`` counter is then advanced, so the next
    snapshot gets a fresh number.

    Args:
        pack_id: Pack ID
        description: Version description

    Returns:
        Version number of the snapshot, or -1 if the pack is unknown
    """
    pack = self._packs.get(pack_id)
    if not pack:
        return -1

    # Create version directory
    pack_versions = self.versions_dir / pack_id
    pack_versions.mkdir(parents=True, exist_ok=True)

    # Increment version (the snapshot keeps the pre-increment number)
    version = pack.current_version
    pack.current_version += 1

    # Save snapshot
    version_file = pack_versions / f"v{version}.json"
    version_data = {
        'version': version,
        'description': description,
        'created_at': datetime.now().isoformat(),
        'pack_snapshot': pack.to_dict()
    }

    # Atomic write so a crash cannot leave a half-written snapshot;
    # _save_packs persists the bumped current_version counter
    self._atomic_write(version_file, version_data)
    self._save_packs()

    logger.info(f"Created version {version} for pack {pack_id}")
    return version
def list_versions(self, pack_id: str) -> List[Dict[str, Any]]:
    """
    List all version snapshots of a pack in ascending version order.

    Version files are sorted numerically ("v2" before "v10"); the
    previous lexicographic sort misordered versions once they reached
    double digits. Unreadable version files are logged and skipped.

    Args:
        pack_id: Pack ID

    Returns:
        List of version info dicts (empty when the pack has no versions)
    """
    pack_versions = self.versions_dir / pack_id
    if not pack_versions.exists():
        return []

    def numeric_version(path):
        # "v12.json" -> 12; non-numeric names sort last
        try:
            return int(path.stem[1:])
        except ValueError:
            return float('inf')

    versions = []
    for version_file in sorted(pack_versions.glob("v*.json"), key=numeric_version):
        try:
            with open(version_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except Exception as e:
            logger.warning(f"Error reading version file {version_file}: {e}")
            continue

        versions.append({
            'version': data.get('version'),
            'description': data.get('description', ''),
            'created_at': data.get('created_at'),
            'correction_count': len(data.get('pack_snapshot', {}).get('corrections', {}))
        })

    return versions
def rollback_pack(self, pack_id: str, version: int) -> bool:
|
||||||
|
"""
|
||||||
|
Rollback a pack to a previous version.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pack_id: Pack ID
|
||||||
|
version: Version number to rollback to
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful, False otherwise
|
||||||
|
"""
|
||||||
|
version_file = self.versions_dir / pack_id / f"v{version}.json"
|
||||||
|
|
||||||
|
if not version_file.exists():
|
||||||
|
logger.warning(f"Version {version} not found for pack {pack_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(version_file, 'r', encoding='utf-8') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
# Restore pack from snapshot
|
||||||
|
snapshot = data.get('pack_snapshot', {})
|
||||||
|
pack = CorrectionPack.from_dict(snapshot)
|
||||||
|
|
||||||
|
# Update version number to continue from current
|
||||||
|
if pack_id in self._packs:
|
||||||
|
pack.current_version = self._packs[pack_id].current_version + 1
|
||||||
|
|
||||||
|
self._packs[pack_id] = pack
|
||||||
|
self._rebuild_index()
|
||||||
|
self._save_packs()
|
||||||
|
|
||||||
|
logger.info(f"Rolled back pack {pack_id} to version {version}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error rolling back pack {pack_id} to version {version}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# ========== Persistence ==========
|
||||||
|
|
||||||
|
def _load_packs(self):
|
||||||
|
"""Load all packs from storage."""
|
||||||
|
if not self.packs_file.exists():
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(self.packs_file, 'r', encoding='utf-8') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
for pack_id, pack_data in data.items():
|
||||||
|
try:
|
||||||
|
self._packs[pack_id] = CorrectionPack.from_dict(pack_data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error loading pack {pack_id}: {e}")
|
||||||
|
|
||||||
|
logger.info(f"Loaded {len(self._packs)} correction packs")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading packs file: {e}")
|
||||||
|
|
||||||
|
def _save_packs(self):
|
||||||
|
"""Save all packs to storage using atomic write."""
|
||||||
|
try:
|
||||||
|
data = {
|
||||||
|
pack_id: pack.to_dict()
|
||||||
|
for pack_id, pack in self._packs.items()
|
||||||
|
}
|
||||||
|
self._atomic_write(self.packs_file, data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error saving packs: {e}")
|
||||||
|
|
||||||
|
def _atomic_write(self, file_path: Path, data: Dict[str, Any]):
|
||||||
|
"""
|
||||||
|
Write data to file atomically (temp file + rename).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Target file path
|
||||||
|
data: Data to write
|
||||||
|
"""
|
||||||
|
temp_file = file_path.with_suffix('.tmp')
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(temp_file, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
# Atomic rename
|
||||||
|
temp_file.replace(file_path)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Clean up temp file on error
|
||||||
|
if temp_file.exists():
|
||||||
|
temp_file.unlink()
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def _rebuild_index(self):
|
||||||
|
"""Rebuild the key hash index from all packs."""
|
||||||
|
self._key_index.clear()
|
||||||
|
|
||||||
|
for pack in self._packs.values():
|
||||||
|
for correction in pack.corrections.values():
|
||||||
|
key_hash = correction.key.to_hash()
|
||||||
|
if key_hash not in self._key_index:
|
||||||
|
self._key_index[key_hash] = []
|
||||||
|
if correction.id not in self._key_index[key_hash]:
|
||||||
|
self._key_index[key_hash].append(correction.id)
|
||||||
|
|
||||||
|
# ========== Export/Import ==========
|
||||||
|
|
||||||
|
def export_pack(self, pack_id: str, format: str = 'json') -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Export a pack to JSON or YAML format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pack_id: Pack ID
|
||||||
|
format: Export format ('json' or 'yaml')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Exported content as string, or None if pack not found
|
||||||
|
"""
|
||||||
|
pack = self._packs.get(pack_id)
|
||||||
|
if not pack:
|
||||||
|
return None
|
||||||
|
|
||||||
|
export_data = {
|
||||||
|
'export_format': 'correction_pack',
|
||||||
|
'export_version': '1.0',
|
||||||
|
'exported_at': datetime.now().isoformat(),
|
||||||
|
'pack': pack.to_dict()
|
||||||
|
}
|
||||||
|
|
||||||
|
if format == 'yaml':
|
||||||
|
try:
|
||||||
|
import yaml
|
||||||
|
return yaml.dump(export_data, default_flow_style=False, allow_unicode=True)
|
||||||
|
except ImportError:
|
||||||
|
logger.warning("PyYAML not installed, falling back to JSON")
|
||||||
|
return json.dumps(export_data, indent=2, ensure_ascii=False)
|
||||||
|
else:
|
||||||
|
return json.dumps(export_data, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
def import_pack(self, content: str, format: str = 'json',
|
||||||
|
merge_strategy: str = 'create_new') -> Optional[CorrectionPack]:
|
||||||
|
"""
|
||||||
|
Import a pack from JSON or YAML content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Content to import
|
||||||
|
format: Import format ('json' or 'yaml')
|
||||||
|
merge_strategy: How to handle existing pack:
|
||||||
|
- 'create_new': Always create new pack with new ID
|
||||||
|
- 'merge': Merge into existing pack if same ID
|
||||||
|
- 'replace': Replace existing pack if same ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Imported pack, or None on error
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if format == 'yaml':
|
||||||
|
try:
|
||||||
|
import yaml
|
||||||
|
data = yaml.safe_load(content)
|
||||||
|
except ImportError:
|
||||||
|
logger.error("PyYAML not installed")
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
data = json.loads(content)
|
||||||
|
|
||||||
|
# Validate format
|
||||||
|
if data.get('export_format') != 'correction_pack':
|
||||||
|
logger.warning("Invalid export format")
|
||||||
|
return None
|
||||||
|
|
||||||
|
pack_data = data.get('pack', {})
|
||||||
|
imported_pack = CorrectionPack.from_dict(pack_data)
|
||||||
|
|
||||||
|
# Handle merge strategy
|
||||||
|
if merge_strategy == 'create_new':
|
||||||
|
imported_pack.id = generate_pack_id()
|
||||||
|
self._packs[imported_pack.id] = imported_pack
|
||||||
|
elif merge_strategy == 'merge' and imported_pack.id in self._packs:
|
||||||
|
existing_pack = self._packs[imported_pack.id]
|
||||||
|
existing_pack.merge(imported_pack)
|
||||||
|
imported_pack = existing_pack
|
||||||
|
elif merge_strategy == 'replace':
|
||||||
|
self._packs[imported_pack.id] = imported_pack
|
||||||
|
else:
|
||||||
|
# Default: create new
|
||||||
|
imported_pack.id = generate_pack_id()
|
||||||
|
self._packs[imported_pack.id] = imported_pack
|
||||||
|
|
||||||
|
self._rebuild_index()
|
||||||
|
self._save_packs()
|
||||||
|
|
||||||
|
logger.info(f"Imported pack: {imported_pack.id} ({imported_pack.name})")
|
||||||
|
return imported_pack
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error importing pack: {e}")
|
||||||
|
return None
|
||||||
480
core/corrections/models.py
Normal file
480
core/corrections/models.py
Normal file
@@ -0,0 +1,480 @@
|
|||||||
|
"""
|
||||||
|
Data models for the Correction Packs system.
|
||||||
|
|
||||||
|
Defines CorrectionKey, Correction, CorrectionPack and related types.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field, asdict
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Dict, List, Any, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class CorrectionType(Enum):
    """Types of corrections that can be applied.

    The value strings are persisted in pack JSON (see Correction.to_dict),
    so they must remain stable across releases.
    """
    TARGET_CHANGE = "target_change"  # Changed target selector
    PARAMETER_CHANGE = "parameter_change"  # Changed action parameters
    ACTION_REPLACE = "action_replace"  # Replaced action type
    WAIT_ADDED = "wait_added"  # Added wait before action
    RETRY_CONFIG = "retry_config"  # Changed retry configuration
    FALLBACK_ADDED = "fallback_added"  # Added fallback action
    COORDINATES_ADJUST = "coordinates_adjust"  # Adjusted click coordinates
    TIMING_ADJUST = "timing_adjust"  # Adjusted timing/delay
    OTHER = "other"  # Catch-all / default used by Correction.from_dict
|
||||||
|
|
||||||
|
|
||||||
|
class CorrectionStatus(Enum):
    """Status of a correction.

    Only ACTIVE corrections are matched by CorrectionPack.find_by_key_hash
    and included in active_corrections.
    """
    ACTIVE = "active"  # Correction is active and can be applied
    DEPRECATED = "deprecated"  # Correction is deprecated
    DISABLED = "disabled"  # Correction is manually disabled
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CorrectionKey:
    """
    Unique key to identify similar corrections across workflows.

    The triple (action_type, element_type, failure_context) is hashed with
    MD5 into a short stable identifier, mirroring the pattern-key scheme of
    LearningRepository.
    """
    action_type: str
    element_type: str
    failure_context: str

    def to_hash(self) -> str:
        """Generate MD5 hash of the correction key (16 chars)."""
        parts = (self.action_type, self.element_type, self.failure_context)
        return hashlib.md5("|".join(parts).encode()).hexdigest()[:16]

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary (includes the derived 'hash' field)."""
        return {
            'action_type': self.action_type,
            'element_type': self.element_type,
            'failure_context': self.failure_context,
            'hash': self.to_hash(),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionKey':
        """Create from dictionary; missing fields fall back to defaults."""
        return cls(
            action_type=data.get('action_type', 'unknown'),
            element_type=data.get('element_type', 'unknown'),
            failure_context=data.get('failure_context', ''),
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CorrectionSource:
    """Where a correction originated: session, workflow, node, and time."""
    session_id: str
    workflow_id: Optional[str] = None
    node_id: Optional[str] = None
    timestamp: datetime = field(default_factory=datetime.now)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary (timestamp serialized as ISO-8601)."""
        return {
            'session_id': self.session_id,
            'workflow_id': self.workflow_id,
            'node_id': self.node_id,
            'timestamp': self.timestamp.isoformat(),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionSource':
        """Create from dictionary; a missing timestamp defaults to now."""
        raw_ts = data.get('timestamp')
        if isinstance(raw_ts, str):
            ts = datetime.fromisoformat(raw_ts)
        elif raw_ts is None:
            ts = datetime.now()
        else:
            # Already a datetime (or other object): pass through unchanged.
            ts = raw_ts
        return cls(
            session_id=data.get('session_id', ''),
            workflow_id=data.get('workflow_id'),
            node_id=data.get('node_id'),
            timestamp=ts,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Correction:
    """
    Individual correction record.

    Contains the original context, the applied correction, source information,
    and statistics about success/failure. Serialized via to_dict/from_dict;
    the dict key names are part of the on-disk pack format.
    """
    id: str  # unique identifier (see generate_correction_id())
    key: CorrectionKey  # dedup key shared by similar failures across workflows
    correction_type: CorrectionType

    # Original context
    original_target: Optional[Dict[str, Any]] = None
    original_params: Optional[Dict[str, Any]] = None

    # Correction applied
    corrected_target: Optional[Dict[str, Any]] = None
    corrected_params: Optional[Dict[str, Any]] = None
    correction_description: str = ""

    # Source information
    source: Optional[CorrectionSource] = None

    # Statistics
    success_count: int = 0
    failure_count: int = 0
    last_applied: Optional[datetime] = None

    # Status
    status: CorrectionStatus = CorrectionStatus.ACTIVE

    # Metadata
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    tags: List[str] = field(default_factory=list)

    @property
    def success_rate(self) -> float:
        """Fraction of applications that succeeded; 0.0 when never applied."""
        total = self.success_count + self.failure_count
        if total == 0:
            return 0.0
        return self.success_count / total

    @property
    def confidence_score(self) -> float:
        """Calculate confidence score based on success rate and usage count.

        Scaled so a correction starts at half its success rate and reaches
        the full rate only once it has ~100 recorded applications.
        """
        total = self.success_count + self.failure_count
        if total == 0:
            return 0.0
        # Confidence increases with more data points (up to 100 uses)
        usage_factor = min(1.0, total / 100)
        return self.success_rate * (0.5 + 0.5 * usage_factor)

    def record_application(self, success: bool):
        """Record an application of this correction and refresh timestamps."""
        if success:
            self.success_count += 1
        else:
            self.failure_count += 1
        self.last_applied = datetime.now()
        self.updated_at = datetime.now()

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary.

        success_rate and confidence_score are derived values included for
        consumers; from_dict ignores them on the way back in.
        """
        return {
            'id': self.id,
            'key': self.key.to_dict(),
            'correction_type': self.correction_type.value,
            'original_target': self.original_target,
            'original_params': self.original_params,
            'corrected_target': self.corrected_target,
            'corrected_params': self.corrected_params,
            'correction_description': self.correction_description,
            'source': self.source.to_dict() if self.source else None,
            'success_count': self.success_count,
            'failure_count': self.failure_count,
            'success_rate': self.success_rate,
            'confidence_score': self.confidence_score,
            'last_applied': self.last_applied.isoformat() if self.last_applied else None,
            'status': self.status.value,
            'created_at': self.created_at.isoformat(),
            'updated_at': self.updated_at.isoformat(),
            'tags': self.tags
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Correction':
        """Create from dictionary; missing fields get safe defaults."""
        # Parse datetime fields (stored as ISO-8601 strings)
        created_at = data.get('created_at')
        if isinstance(created_at, str):
            created_at = datetime.fromisoformat(created_at)
        elif created_at is None:
            created_at = datetime.now()

        updated_at = data.get('updated_at')
        if isinstance(updated_at, str):
            updated_at = datetime.fromisoformat(updated_at)
        elif updated_at is None:
            updated_at = datetime.now()

        last_applied = data.get('last_applied')
        if isinstance(last_applied, str):
            last_applied = datetime.fromisoformat(last_applied)

        # Parse nested objects
        key_data = data.get('key', {})
        key = CorrectionKey.from_dict(key_data)

        source_data = data.get('source')
        source = CorrectionSource.from_dict(source_data) if source_data else None

        return cls(
            id=data.get('id', str(uuid.uuid4())),
            key=key,
            # Unknown/missing type falls back to CorrectionType.OTHER
            correction_type=CorrectionType(data.get('correction_type', 'other')),
            original_target=data.get('original_target'),
            original_params=data.get('original_params'),
            corrected_target=data.get('corrected_target'),
            corrected_params=data.get('corrected_params'),
            correction_description=data.get('correction_description', ''),
            source=source,
            success_count=data.get('success_count', 0),
            failure_count=data.get('failure_count', 0),
            last_applied=last_applied,
            status=CorrectionStatus(data.get('status', 'active')),
            created_at=created_at,
            updated_at=updated_at,
            tags=data.get('tags', [])
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CorrectionPackMetadata:
    """Descriptive metadata attached to a correction pack."""
    version: str = "1.0.0"
    author: str = ""
    category: str = "general"
    tags: List[str] = field(default_factory=list)
    description: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary."""
        return {
            'version': self.version,
            'author': self.author,
            'category': self.category,
            'tags': self.tags,
            'description': self.description,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionPackMetadata':
        """Create from dictionary, filling missing fields with the defaults."""
        defaults = cls()
        return cls(
            version=data.get('version', defaults.version),
            author=data.get('author', defaults.author),
            category=data.get('category', defaults.category),
            tags=data.get('tags', []),
            description=data.get('description', defaults.description),
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CorrectionPack:
    """
    A pack of corrections that can be exported, imported, and shared.

    Contains multiple corrections keyed by correction id, plus metadata,
    timestamps and a version counter used by the versioning/rollback flow.
    """
    id: str  # unique pack identifier (see generate_pack_id())
    name: str
    description: str = ""

    # Corrections in this pack, keyed by correction id
    corrections: Dict[str, Correction] = field(default_factory=dict)

    # Metadata
    metadata: CorrectionPackMetadata = field(default_factory=CorrectionPackMetadata)

    # Timestamps
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)

    # Version history counter (incremented by the pack service on snapshot)
    current_version: int = 1

    @property
    def correction_count(self) -> int:
        """Number of corrections in the pack."""
        return len(self.corrections)

    @property
    def total_applications(self) -> int:
        """Total number of times corrections were applied (success + failure)."""
        return sum(c.success_count + c.failure_count for c in self.corrections.values())

    @property
    def overall_success_rate(self) -> float:
        """Overall success rate across all corrections; 0.0 if never applied."""
        total_success = sum(c.success_count for c in self.corrections.values())
        total_failure = sum(c.failure_count for c in self.corrections.values())
        total = total_success + total_failure
        if total == 0:
            return 0.0
        return total_success / total

    @property
    def active_corrections(self) -> List[Correction]:
        """Get only corrections whose status is ACTIVE."""
        return [c for c in self.corrections.values() if c.status == CorrectionStatus.ACTIVE]

    def add_correction(self, correction: Correction) -> bool:
        """
        Add a correction to the pack.

        Returns True if added, False if a correction with the same id
        already exists (the existing one is kept untouched).
        """
        if correction.id in self.corrections:
            return False
        self.corrections[correction.id] = correction
        self.updated_at = datetime.now()
        return True

    def remove_correction(self, correction_id: str) -> bool:
        """
        Remove a correction from the pack.

        Returns True if removed, False if not found.
        """
        if correction_id not in self.corrections:
            return False
        del self.corrections[correction_id]
        self.updated_at = datetime.now()
        return True

    def get_correction(self, correction_id: str) -> Optional[Correction]:
        """Get a correction by ID, or None if absent."""
        return self.corrections.get(correction_id)

    def find_by_key_hash(self, key_hash: str) -> List[Correction]:
        """Find ACTIVE corrections whose CorrectionKey hash matches key_hash."""
        return [
            c for c in self.corrections.values()
            if c.key.to_hash() == key_hash and c.status == CorrectionStatus.ACTIVE
        ]

    def merge(self, other: 'CorrectionPack', conflict_strategy: str = 'keep_higher_confidence') -> int:
        """
        Merge another pack into this one.

        Corrections with ids not present here are always added; id collisions
        are resolved by conflict_strategy.

        Args:
            other: Pack to merge
            conflict_strategy: How to handle conflicts:
                - 'keep_higher_confidence': Keep correction with higher confidence
                - 'keep_newer': Keep the most recently updated correction
                - 'keep_existing': Keep existing, ignore new
                - 'replace': Always replace with new

        Returns:
            Number of corrections added/updated
        """
        merged_count = 0

        for correction_id, correction in other.corrections.items():
            if correction_id in self.corrections:
                existing = self.corrections[correction_id]

                should_replace = False
                if conflict_strategy == 'keep_higher_confidence':
                    should_replace = correction.confidence_score > existing.confidence_score
                elif conflict_strategy == 'keep_newer':
                    should_replace = correction.updated_at > existing.updated_at
                elif conflict_strategy == 'replace':
                    should_replace = True
                # 'keep_existing' doesn't replace

                if should_replace:
                    self.corrections[correction_id] = correction
                    merged_count += 1
            else:
                # New id: always added regardless of strategy.
                self.corrections[correction_id] = correction
                merged_count += 1

        if merged_count > 0:
            self.updated_at = datetime.now()

        return merged_count

    def get_statistics(self) -> Dict[str, Any]:
        """Get aggregate statistics for the pack (also embedded by to_dict)."""
        corrections = list(self.corrections.values())
        active = [c for c in corrections if c.status == CorrectionStatus.ACTIVE]

        # Group by correction type
        by_type = {}
        for c in corrections:
            type_name = c.correction_type.value
            if type_name not in by_type:
                by_type[type_name] = 0
            by_type[type_name] += 1

        return {
            'total_corrections': len(corrections),
            'active_corrections': len(active),
            'deprecated_corrections': len([c for c in corrections if c.status == CorrectionStatus.DEPRECATED]),
            'disabled_corrections': len([c for c in corrections if c.status == CorrectionStatus.DISABLED]),
            'total_applications': self.total_applications,
            'overall_success_rate': self.overall_success_rate,
            'corrections_by_type': by_type,
            # Average confidence over ACTIVE corrections only
            'average_confidence': sum(c.confidence_score for c in active) / len(active) if active else 0.0
        }

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary.

        'statistics' is a derived snapshot; from_dict ignores it.
        """
        return {
            'id': self.id,
            'name': self.name,
            'description': self.description,
            'corrections': {cid: c.to_dict() for cid, c in self.corrections.items()},
            'metadata': self.metadata.to_dict(),
            'created_at': self.created_at.isoformat(),
            'updated_at': self.updated_at.isoformat(),
            'current_version': self.current_version,
            'statistics': self.get_statistics()
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionPack':
        """Create from dictionary; missing fields get safe defaults."""
        # Parse datetime fields (stored as ISO-8601 strings)
        created_at = data.get('created_at')
        if isinstance(created_at, str):
            created_at = datetime.fromisoformat(created_at)
        elif created_at is None:
            created_at = datetime.now()

        updated_at = data.get('updated_at')
        if isinstance(updated_at, str):
            updated_at = datetime.fromisoformat(updated_at)
        elif updated_at is None:
            updated_at = datetime.now()

        # Parse corrections
        corrections_data = data.get('corrections', {})
        corrections = {
            cid: Correction.from_dict(cdata)
            for cid, cdata in corrections_data.items()
        }

        # Parse metadata
        metadata_data = data.get('metadata', {})
        metadata = CorrectionPackMetadata.from_dict(metadata_data)

        return cls(
            id=data.get('id', str(uuid.uuid4())),
            name=data.get('name', 'Unnamed Pack'),
            description=data.get('description', ''),
            corrections=corrections,
            metadata=metadata,
            created_at=created_at,
            updated_at=updated_at,
            current_version=data.get('current_version', 1)
        )
|
||||||
|
|
||||||
|
|
||||||
|
def generate_correction_id() -> str:
    """Generate a unique correction ID of the form 'corr_' + 12 hex chars."""
    suffix = uuid.uuid4().hex[:12]
    return "corr_" + suffix
|
||||||
|
|
||||||
|
|
||||||
|
def generate_pack_id() -> str:
    """Generate a unique pack ID of the form 'pack_' + 12 hex chars."""
    suffix = uuid.uuid4().hex[:12]
    return "pack_" + suffix
|
||||||
443
core/healing/healing_engine.py
Normal file
443
core/healing/healing_engine.py
Normal file
@@ -0,0 +1,443 @@
|
|||||||
|
"""Main self-healing engine for workflow recovery."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from typing import List, Optional, Dict, Any, Tuple
|
||||||
|
from pathlib import Path
|
||||||
|
from .models import RecoveryContext, RecoveryResult, RecoverySuggestion
|
||||||
|
from .learning_repository import LearningRepository
|
||||||
|
from .confidence_scorer import ConfidenceScorer
|
||||||
|
from .strategies import (
|
||||||
|
RecoveryStrategy,
|
||||||
|
SemanticVariantStrategy,
|
||||||
|
SpatialFallbackStrategy,
|
||||||
|
TimingAdaptationStrategy,
|
||||||
|
FormatTransformationStrategy
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SelfHealingEngine:
|
||||||
|
"""Main engine for self-healing workflows."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
learning_repo: Optional[LearningRepository] = None,
|
||||||
|
storage_path: Optional[Path] = None,
|
||||||
|
correction_packs_enabled: bool = True
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the self-healing engine.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
learning_repo: Learning repository instance
|
||||||
|
storage_path: Path for storing learned patterns
|
||||||
|
correction_packs_enabled: Whether to consult correction packs
|
||||||
|
"""
|
||||||
|
# Initialize learning repository
|
||||||
|
if learning_repo:
|
||||||
|
self.learning_repo = learning_repo
|
||||||
|
else:
|
||||||
|
storage_path = storage_path or Path('data/healing')
|
||||||
|
self.learning_repo = LearningRepository(storage_path)
|
||||||
|
|
||||||
|
# Initialize confidence scorer
|
||||||
|
self.confidence_scorer = ConfidenceScorer()
|
||||||
|
|
||||||
|
# Initialize recovery strategies
|
||||||
|
self.recovery_strategies: List[RecoveryStrategy] = [
|
||||||
|
SemanticVariantStrategy(),
|
||||||
|
SpatialFallbackStrategy(),
|
||||||
|
TimingAdaptationStrategy(),
|
||||||
|
FormatTransformationStrategy()
|
||||||
|
]
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
self.max_recovery_time = 30.0 # seconds
|
||||||
|
self.parallel_execution = True
|
||||||
|
|
||||||
|
# Correction packs integration
|
||||||
|
self.correction_packs_enabled = correction_packs_enabled
|
||||||
|
self._correction_pack_service = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def correction_pack_service(self):
|
||||||
|
"""Lazy-load correction pack service."""
|
||||||
|
if self._correction_pack_service is None and self.correction_packs_enabled:
|
||||||
|
try:
|
||||||
|
from core.corrections import CorrectionPackService
|
||||||
|
self._correction_pack_service = CorrectionPackService()
|
||||||
|
logger.info("CorrectionPackService initialized for healing engine")
|
||||||
|
except ImportError as e:
|
||||||
|
logger.warning(f"CorrectionPackService not available: {e}")
|
||||||
|
self.correction_packs_enabled = False
|
||||||
|
return self._correction_pack_service
|
||||||
|
|
||||||
|
    def attempt_recovery(self, context: RecoveryContext) -> RecoveryResult:
        """
        Attempt to recover from a workflow failure.

        Order of attack: (1) bail out if max attempts exhausted, (2) try
        previously learned correction packs, (3) walk the prioritized
        strategies under a wall-clock budget (max_recovery_time), returning
        on the first success. A success with confidence below the safety
        threshold is returned with requires_user_input=True instead of
        being applied blindly.

        Args:
            context: Recovery context with failure information

        Returns:
            RecoveryResult with outcome
        """
        start_time = time.time()

        # Check if we've exceeded max attempts
        if context.attempt_count >= context.max_attempts:
            return RecoveryResult(
                success=False,
                strategy_used='none',
                error_message=f'Max recovery attempts ({context.max_attempts}) exceeded'
            )

        # Try correction packs first (if enabled) — cross-workflow fixes
        # learned from user corrections take precedence over strategies.
        if self.correction_packs_enabled and self.correction_pack_service:
            pack_result = self._try_correction_packs(context)
            if pack_result and pack_result.success:
                return pack_result

        # Get learned patterns for this context
        learned_patterns = self.learning_repo.get_matching_patterns(context)

        # Get prioritized strategies
        strategies = self._prioritize_strategies(context, learned_patterns)

        # Try each strategy
        for strategy in strategies:
            # Check time limit (total budget across all strategies)
            elapsed = time.time() - start_time
            if elapsed >= self.max_recovery_time:
                break

            # Skip if strategy can't handle this failure
            if not strategy.can_handle(context):
                continue

            # Attempt recovery
            result = strategy.attempt_recovery(context)

            # If successful, learn from it
            if result.success:
                # Calculate final confidence from history + context
                historical_success = self._get_historical_success_rate(
                    strategy.name, context, learned_patterns
                )
                result.confidence_score = self.confidence_scorer.calculate_recovery_confidence(
                    result.strategy_used,
                    context,
                    historical_success
                )

                # Check if safe to proceed (stricter when data is modified)
                if self.confidence_scorer.is_safe_to_proceed(
                    result.confidence_score,
                    context.confidence_threshold,
                    involves_data_modification=self._involves_data_modification(context)
                ):
                    # Learn from success
                    self.learn_from_success(context, result)
                    return result
                else:
                    # Confidence too low, mark as requiring user input.
                    # NOTE: returns immediately — remaining strategies are
                    # not tried once any strategy reports success.
                    result.requires_user_input = True
                    return result

        # All strategies failed (or budget exhausted)
        total_time = time.time() - start_time
        return RecoveryResult(
            success=False,
            strategy_used='all_failed',
            execution_time=total_time,
            error_message='All recovery strategies failed',
            requires_user_input=True
        )
|
||||||
|
|
||||||
|
def learn_from_success(self, context: RecoveryContext, result: RecoveryResult):
    """
    Persist a successful recovery so it can inform future attempts.

    Args:
        context: Recovery context that was recovered from
        result: Successful recovery result to record
    """
    # All learning state lives in the repository; this is a thin facade.
    self.learning_repo.store_pattern(context, result)
|
||||||
|
|
||||||
|
def get_recovery_suggestions(self, context: RecoveryContext) -> List[RecoverySuggestion]:
    """
    Build a confidence-ranked list of recovery suggestions.

    Each strategy that can handle the failure is scored against
    historical success data from the learning repository.

    Args:
        context: Recovery context describing the failure

    Returns:
        Recovery suggestions ordered from most to least confident
    """
    patterns = self.learning_repo.get_matching_patterns(context)

    ranked = []
    for candidate in self.recovery_strategies:
        # Strategies that cannot address this failure are skipped outright.
        if not candidate.can_handle(context):
            continue

        # Confidence blends the scorer's estimate with past success.
        past_rate = self._get_historical_success_rate(
            candidate.name, context, patterns
        )
        score = self.confidence_scorer.calculate_recovery_confidence(
            candidate.name,
            context,
            past_rate
        )

        ranked.append(RecoverySuggestion(
            strategy=candidate.name,
            confidence=score,
            description=self._get_strategy_description(candidate),
            estimated_time=self._estimate_strategy_time(candidate, context)
        ))

    # Most promising suggestion first.
    ranked.sort(key=lambda item: item.confidence, reverse=True)

    return ranked
|
||||||
|
|
||||||
|
def prune_learned_patterns(
    self,
    max_age_days: int = 90,
    min_confidence: float = 0.3
):
    """
    Drop stale or low-confidence learned patterns.

    Args:
        max_age_days: Patterns older than this many days are removed
        min_confidence: Patterns scoring below this threshold are removed
    """
    # Pruning is delegated wholesale to the repository layer.
    self.learning_repo.prune_outdated_patterns(max_age_days, min_confidence)
|
||||||
|
|
||||||
|
def _prioritize_strategies(
    self,
    context: RecoveryContext,
    learned_patterns: List
) -> List[RecoveryStrategy]:
    """Order strategies by context priority, boosted by historical success."""

    def _weight(strategy):
        # Start from the strategy's own context-aware priority.
        score = strategy.get_priority(context)
        past_rate = self._get_historical_success_rate(
            strategy.name, context, learned_patterns
        )
        # Strategies with a proven track record get up to a 2x boost.
        if past_rate > 0:
            score *= (1.0 + past_rate)
        return score

    # Stable sort on the computed weight, highest first.
    return sorted(self.recovery_strategies, key=_weight, reverse=True)
|
||||||
|
|
||||||
|
def _get_historical_success_rate(
    self,
    strategy_name: str,
    context: RecoveryContext,
    learned_patterns: List
) -> float:
    """Average success rate of past patterns recorded for this strategy."""
    rates = [
        pattern.success_rate
        for pattern in learned_patterns
        if pattern.recovery_strategy == strategy_name
    ]

    # No history for this strategy yet: neutral score.
    if not rates:
        return 0.0

    return sum(rates) / len(rates)
|
||||||
|
|
||||||
|
def _involves_data_modification(self, context: RecoveryContext) -> bool:
    """Heuristic: does the original action look like it mutates data?"""
    mutating_keywords = (
        'input', 'type', 'submit', 'delete', 'update', 'save'
    )
    # Lower-case once; substring match keeps this robust to naming variants.
    action = context.original_action.lower()
    return any(keyword in action for keyword in mutating_keywords)
|
||||||
|
|
||||||
|
def _get_strategy_description(self, strategy: RecoveryStrategy) -> str:
    """Human-readable one-line description of a recovery strategy."""
    blurbs = {
        'SemanticVariantStrategy': 'Try semantic variants of the element (e.g., Submit → Send)',
        'SpatialFallbackStrategy': 'Search in expanded area around original position',
        'TimingAdaptationStrategy': 'Increase wait time for element to appear',
        'FormatTransformationStrategy': 'Transform input format to match validation'
    }
    # Unknown strategies fall back to their raw class name.
    return blurbs.get(strategy.name, strategy.name)
|
||||||
|
|
||||||
|
def _estimate_strategy_time(
    self,
    strategy: RecoveryStrategy,
    context: RecoveryContext
) -> float:
    """Rough wall-clock estimate, in seconds, for running a strategy."""
    # Fixed per-strategy estimates; context is currently unused.
    known_costs = {
        'SemanticVariantStrategy': 2.0,
        'SpatialFallbackStrategy': 5.0,
        'TimingAdaptationStrategy': 10.0,
        'FormatTransformationStrategy': 1.0
    }
    # Unlisted strategies get a middle-of-the-road default.
    return known_costs.get(strategy.name, 3.0)
|
||||||
|
|
||||||
|
def _try_correction_packs(self, context: RecoveryContext) -> Optional[RecoveryResult]:
    """
    Try to find and apply a correction from correction packs.

    Queries the correction pack service for corrections matching the
    failed action/element/failure triple and applies the first one that
    takes effect, recording the outcome back to the pack and to the
    learning repository.

    Args:
        context: Recovery context

    Returns:
        RecoveryResult if a correction was successfully applied, None otherwise
    """
    try:
        # Get element type from metadata
        element_type = context.metadata.get('element_type', 'unknown')

        # Search for applicable corrections
        corrections = self.correction_pack_service.find_applicable_corrections(
            action_type=context.original_action,
            element_type=element_type,
            failure_context=context.failure_reason,
            min_confidence=0.5  # Only use high-confidence corrections
        )

        if not corrections:
            return None

        # Try corrections in order of confidence
        for correction_info in corrections:
            pack_id = correction_info['pack_id']
            correction = correction_info['correction']

            try:
                # Apply the correction
                applied = self._apply_correction(context, correction)

                if applied:
                    # Record successful application
                    self.correction_pack_service.apply_correction(
                        pack_id=pack_id,
                        correction_id=correction['id'],
                        success=True
                    )

                    # Propagate to learning repository
                    result = RecoveryResult(
                        success=True,
                        strategy_used=f"correction_pack:{correction['correction_type']}",
                        confidence_score=correction['confidence_score'],
                        modified_action=correction.get('corrected_params'),
                        modified_target=correction.get('corrected_target')
                    )

                    # Learn from this for future reference
                    self.learning_repo.store_pattern(context, result)

                    logger.info(
                        f"Applied correction from pack {pack_id}: "
                        f"{correction['id']} (confidence: {correction['confidence_score']:.2f})"
                    )
                    # First correction that sticks wins; stop searching.
                    return result

            except Exception as e:
                logger.warning(f"Failed to apply correction {correction['id']}: {e}")
                # Record failed application
                # NOTE(review): assumes apply_correction(success=False) itself
                # cannot raise here — verify, or the loop aborts early.
                self.correction_pack_service.apply_correction(
                    pack_id=pack_id,
                    correction_id=correction['id'],
                    success=False
                )

        # No correction took effect.
        return None

    except Exception as e:
        # Correction packs are best-effort: lookup failures must never
        # break the surrounding recovery flow.
        logger.error(f"Error in correction pack lookup: {e}")
        return None
|
||||||
|
|
||||||
|
def _apply_correction(
    self,
    context: RecoveryContext,
    correction: Dict[str, Any]
) -> bool:
    """
    Apply a correction to the current context.

    The correction never mutates the action directly; it records hints in
    ``context.metadata`` that downstream execution is expected to honor.

    Args:
        context: Recovery context whose metadata is updated in place
        correction: Correction data dict ('correction_type',
            'corrected_target', 'corrected_params', ...)

    Returns:
        True if correction was applied successfully
    """
    correction_type = correction.get('correction_type', 'other')
    # Guard against an explicit None stored under 'corrected_params':
    # dict.get(key, {}) returns None in that case, and the original
    # chained .get() would then raise AttributeError.
    params = correction.get('corrected_params') or {}

    # Handle different correction types
    if correction_type == 'target_change':
        # Point the action at a new target element.
        if correction.get('corrected_target'):
            context.metadata['corrected_target'] = correction['corrected_target']
            return True

    elif correction_type == 'parameter_change':
        # Replace the action's parameters wholesale.
        if correction.get('corrected_params'):
            context.metadata['corrected_params'] = correction['corrected_params']
            return True

    elif correction_type == 'wait_added':
        # Add an extra wait before retrying the action.
        context.metadata['additional_wait'] = params.get('wait_time', 2.0)
        return True

    elif correction_type == 'timing_adjust':
        # Scale timeouts/waits by a multiplier.
        context.metadata['timing_multiplier'] = params.get('timing_multiplier', 1.5)
        return True

    elif correction_type == 'coordinates_adjust':
        # Nudge coordinates by a recorded offset.
        context.metadata['coordinate_offset'] = params.get('offset', {})
        return True

    # For other types, just mark as applied if we have any correction data
    return bool(correction.get('corrected_target') or correction.get('corrected_params'))
|
||||||
|
|
||||||
|
def get_correction_pack_statistics(self) -> Optional[Dict[str, Any]]:
    """
    Fetch aggregate statistics across correction packs.

    Returns:
        Statistics dict, or None when packs are disabled or unavailable
    """
    service = self.correction_pack_service
    # Feature may be switched off, or the service may never have been set up.
    if not (self.correction_packs_enabled and service):
        return None

    try:
        return service.get_global_statistics()
    except Exception as e:
        # Statistics are informational only; swallow and report nothing.
        logger.error(f"Error getting correction pack statistics: {e}")
        return None
|
||||||
644
tests/test_correction_packs.py
Normal file
644
tests/test_correction_packs.py
Normal file
@@ -0,0 +1,644 @@
|
|||||||
|
"""
|
||||||
|
Tests for the Correction Packs system.
|
||||||
|
|
||||||
|
Tests models, repository, aggregator, and service.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
from core.corrections.models import (
|
||||||
|
CorrectionKey,
|
||||||
|
CorrectionType,
|
||||||
|
CorrectionStatus,
|
||||||
|
Correction,
|
||||||
|
CorrectionPack,
|
||||||
|
CorrectionPackMetadata,
|
||||||
|
CorrectionSource,
|
||||||
|
generate_correction_id,
|
||||||
|
generate_pack_id
|
||||||
|
)
|
||||||
|
from core.corrections.correction_repository import CorrectionRepository
|
||||||
|
from core.corrections.aggregator import CorrectionAggregator, CrossWorkflowAggregator
|
||||||
|
from core.corrections.correction_pack_service import CorrectionPackService
|
||||||
|
|
||||||
|
|
||||||
|
class TestCorrectionKey(unittest.TestCase):
    """Unit tests for the CorrectionKey model."""

    def test_to_hash(self):
        """Hashing is deterministic, key-sensitive, and 16 chars long."""
        first = CorrectionKey("click", "button", "element_not_found")
        twin = CorrectionKey("click", "button", "element_not_found")
        other = CorrectionKey("type", "input", "element_not_found")

        # Identical keys hash identically; distinct keys do not.
        self.assertEqual(first.to_hash(), twin.to_hash())
        self.assertNotEqual(first.to_hash(), other.to_hash())

        # The hash is truncated to 16 characters.
        self.assertEqual(len(first.to_hash()), 16)

    def test_to_dict_from_dict(self):
        """A key survives a to_dict/from_dict round-trip."""
        original = CorrectionKey("click", "button", "timeout")
        payload = original.to_dict()

        self.assertEqual(payload['action_type'], "click")
        self.assertEqual(payload['element_type'], "button")
        self.assertEqual(payload['failure_context'], "timeout")
        self.assertIn('hash', payload)

        rebuilt = CorrectionKey.from_dict(payload)
        self.assertEqual(original.to_hash(), rebuilt.to_hash())
|
||||||
|
|
||||||
|
|
||||||
|
class TestCorrection(unittest.TestCase):
    """Unit tests for the Correction model."""

    def test_create_correction(self):
        """A freshly created correction carries sane defaults."""
        corr = Correction(
            id=generate_correction_id(),
            key=CorrectionKey("click", "button", "not_visible"),
            correction_type=CorrectionType.WAIT_ADDED,
            original_target={"selector": "#submit"},
            corrected_params={"wait_time": 2.0},
            correction_description="Added wait for button visibility"
        )

        self.assertTrue(corr.id.startswith("corr_"))
        self.assertEqual(corr.correction_type, CorrectionType.WAIT_ADDED)
        # Nothing applied yet, so the rate starts at zero.
        self.assertEqual(corr.success_rate, 0.0)

    def test_success_rate_calculation(self):
        """success_rate is successes over total applications."""
        corr = Correction(
            id=generate_correction_id(),
            key=CorrectionKey("click", "button", "not_visible"),
            correction_type=CorrectionType.WAIT_ADDED,
            success_count=7,
            failure_count=3
        )

        # 7 successes out of 10 applications.
        self.assertEqual(corr.success_rate, 0.7)

    def test_record_application(self):
        """record_application updates counters and the last-applied stamp."""
        corr = Correction(
            id=generate_correction_id(),
            key=CorrectionKey("click", "button", "not_visible"),
            correction_type=CorrectionType.WAIT_ADDED
        )

        corr.record_application(success=True)
        self.assertEqual(corr.success_count, 1)
        self.assertEqual(corr.failure_count, 0)
        self.assertIsNotNone(corr.last_applied)

        corr.record_application(success=False)
        self.assertEqual(corr.success_count, 1)
        self.assertEqual(corr.failure_count, 1)

    def test_to_dict_from_dict(self):
        """A correction survives a serialization round-trip."""
        key = CorrectionKey("type", "input", "validation_error")
        corr = Correction(
            id="corr_test123",
            key=key,
            correction_type=CorrectionType.PARAMETER_CHANGE,
            original_params={"value": "test"},
            corrected_params={"value": "corrected"},
            source=CorrectionSource(
                session_id="sess_123",
                workflow_id="wf_456",
                node_id="node_789"
            ),
            success_count=5,
            failure_count=2,
            tags=["form", "validation"]
        )

        rebuilt = Correction.from_dict(corr.to_dict())

        self.assertEqual(rebuilt.id, corr.id)
        self.assertEqual(rebuilt.key.to_hash(), key.to_hash())
        self.assertEqual(rebuilt.correction_type, CorrectionType.PARAMETER_CHANGE)
        self.assertEqual(rebuilt.success_count, 5)
        self.assertEqual(rebuilt.tags, ["form", "validation"])
|
||||||
|
|
||||||
|
|
||||||
|
class TestCorrectionPack(unittest.TestCase):
    """Unit tests for the CorrectionPack container."""

    def test_create_pack(self):
        """A new pack starts empty with the given name."""
        pack = CorrectionPack(
            id=generate_pack_id(),
            name="Test Pack",
            description="A test correction pack"
        )

        self.assertTrue(pack.id.startswith("pack_"))
        self.assertEqual(pack.name, "Test Pack")
        self.assertEqual(pack.correction_count, 0)

    def test_add_remove_correction(self):
        """Add/remove update the count; duplicates and misses are rejected."""
        pack = CorrectionPack(
            id=generate_pack_id(),
            name="Test Pack"
        )
        corr = Correction(
            id="corr_test1",
            key=CorrectionKey("click", "button", "not_found"),
            correction_type=CorrectionType.TARGET_CHANGE
        )

        # First add succeeds; re-adding the same correction does not.
        self.assertTrue(pack.add_correction(corr))
        self.assertEqual(pack.correction_count, 1)
        self.assertFalse(pack.add_correction(corr))

        # First remove succeeds; removing again does not.
        self.assertTrue(pack.remove_correction("corr_test1"))
        self.assertEqual(pack.correction_count, 0)
        self.assertFalse(pack.remove_correction("corr_test1"))

    def test_find_by_key_hash(self):
        """Lookup by key hash returns only matching corrections."""
        pack = CorrectionPack(id=generate_pack_id(), name="Test Pack")

        key_click = CorrectionKey("click", "button", "not_found")
        key_type = CorrectionKey("type", "input", "validation")

        pack.add_correction(
            Correction(id="corr_1", key=key_click, correction_type=CorrectionType.TARGET_CHANGE))
        pack.add_correction(
            Correction(id="corr_2", key=key_type, correction_type=CorrectionType.PARAMETER_CHANGE))

        hits = pack.find_by_key_hash(key_click.to_hash())
        self.assertEqual(len(hits), 1)
        self.assertEqual(hits[0].id, "corr_1")

    def test_merge_packs(self):
        """Merging pulls in only corrections absent from the target pack."""
        target = CorrectionPack(id="pack_1", name="Pack 1")
        donor = CorrectionPack(id="pack_2", name="Pack 2")

        shared_key = CorrectionKey("click", "button", "error")
        target.add_correction(
            Correction(id="corr_1", key=shared_key, correction_type=CorrectionType.TARGET_CHANGE))
        donor.add_correction(
            Correction(id="corr_2", key=shared_key, correction_type=CorrectionType.WAIT_ADDED))

        self.assertEqual(target.merge(donor), 1)
        self.assertEqual(target.correction_count, 2)

    def test_get_statistics(self):
        """Statistics aggregate counts across the pack's corrections."""
        pack = CorrectionPack(id=generate_pack_id(), name="Test Pack")
        pack.add_correction(Correction(
            id="corr_1",
            key=CorrectionKey("click", "button", "error"),
            correction_type=CorrectionType.TARGET_CHANGE,
            success_count=10,
            failure_count=2
        ))

        stats = pack.get_statistics()
        self.assertEqual(stats['total_corrections'], 1)
        self.assertEqual(stats['active_corrections'], 1)
        # 10 successes + 2 failures.
        self.assertEqual(stats['total_applications'], 12)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCorrectionRepository(unittest.TestCase):
    """Unit tests for the JSON-backed CorrectionRepository."""

    def setUp(self):
        """Give each test a repository rooted in a throwaway directory."""
        self.test_dir = tempfile.mkdtemp()
        self.repo = CorrectionRepository(self.test_dir)

    def tearDown(self):
        """Remove the throwaway directory."""
        shutil.rmtree(self.test_dir)

    def test_create_and_get_pack(self):
        """A created pack can be fetched back intact."""
        created = self.repo.create_pack("Test Pack", "Description", {
            'category': 'testing',
            'tags': ['unit', 'test']
        })

        self.assertIsNotNone(created)
        self.assertEqual(created.name, "Test Pack")

        fetched = self.repo.get_pack(created.id)
        self.assertIsNotNone(fetched)
        self.assertEqual(fetched.name, "Test Pack")

    def test_list_packs(self):
        """list_packs reports every stored pack."""
        self.repo.create_pack("Pack 1", "First pack")
        self.repo.create_pack("Pack 2", "Second pack")

        self.assertEqual(len(self.repo.list_packs()), 2)

    def test_update_pack(self):
        """update_pack persists field changes."""
        pack = self.repo.create_pack("Original Name")

        renamed = self.repo.update_pack(pack.id, {'name': 'New Name'})
        self.assertIsNotNone(renamed)
        self.assertEqual(renamed.name, 'New Name')

    def test_delete_pack(self):
        """delete_pack removes the pack and is not repeatable."""
        pack = self.repo.create_pack("To Delete")

        self.assertTrue(self.repo.delete_pack(pack.id))
        self.assertIsNone(self.repo.get_pack(pack.id))
        # Deleting an already-deleted pack fails cleanly.
        self.assertFalse(self.repo.delete_pack(pack.id))

    def test_add_and_find_correction(self):
        """Corrections are retrievable by their action/element/failure triple."""
        pack = self.repo.create_pack("Test Pack")

        corr = Correction(
            id=generate_correction_id(),
            key=CorrectionKey("click", "button", "timeout"),
            correction_type=CorrectionType.TIMING_ADJUST
        )
        self.assertTrue(self.repo.add_correction(pack.id, corr))

        # Lookup keyed by the same context triple.
        matches = self.repo.find_by_context("click", "button", "timeout")
        self.assertEqual(len(matches), 1)
        self.assertEqual(matches[0][1].id, corr.id)

    def test_export_import_pack(self):
        """An exported pack can be re-imported as a brand-new pack."""
        pack = self.repo.create_pack("Export Test")
        self.repo.add_correction(pack.id, Correction(
            id="corr_export",
            key=CorrectionKey("click", "button", "error"),
            correction_type=CorrectionType.TARGET_CHANGE
        ))

        payload = self.repo.export_pack(pack.id)
        self.assertIsNotNone(payload)

        # 'create_new' must mint a fresh pack id while keeping the contents.
        clone = self.repo.import_pack(payload, merge_strategy='create_new')
        self.assertIsNotNone(clone)
        self.assertNotEqual(clone.id, pack.id)
        self.assertEqual(clone.correction_count, 1)

    def test_version_and_rollback(self):
        """Rolling back restores the snapshot taken at create_version."""
        pack = self.repo.create_pack("Version Test")
        key = CorrectionKey("click", "button", "error")

        self.repo.add_correction(pack.id, Correction(
            id="corr_v1", key=key, correction_type=CorrectionType.TARGET_CHANGE))

        # Snapshot taken with a single correction.
        self.assertEqual(self.repo.create_version(pack.id, "First version"), 1)

        self.repo.add_correction(pack.id, Correction(
            id="corr_v2", key=key, correction_type=CorrectionType.WAIT_ADDED))
        self.assertEqual(self.repo.get_pack(pack.id).correction_count, 2)

        # Rollback drops the post-snapshot correction.
        self.assertTrue(self.repo.rollback_pack(pack.id, 1))
        self.assertEqual(self.repo.get_pack(pack.id).correction_count, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCorrectionAggregator(unittest.TestCase):
    """Unit tests for CorrectionAggregator."""

    def setUp(self):
        """Use low thresholds so the small fixtures pass quality gates."""
        self.aggregator = CorrectionAggregator(
            min_occurrences=2,
            min_success_rate=0.5,
            min_confidence=0.3
        )

    def test_aggregate_by_key(self):
        """Corrections sharing a key collapse into one with summed counts."""
        shared_key = CorrectionKey("click", "button", "not_found")

        batch = [
            Correction(
                id=f"corr_{n}",
                key=shared_key,
                correction_type=CorrectionType.WAIT_ADDED,
                success_count=5,
                failure_count=2
            )
            for n in range(3)
        ]

        merged = self.aggregator.aggregate_corrections(batch)

        self.assertEqual(len(merged), 1)
        # Three corrections at 5 successes each.
        self.assertEqual(merged[0].success_count, 15)

    def test_filter_by_quality(self):
        """Corrections below the success-rate floor are dropped."""
        key = CorrectionKey("click", "button", "error")

        strong = Correction(
            id="corr_high",
            key=key,
            correction_type=CorrectionType.TARGET_CHANGE,
            success_count=9,
            failure_count=1
        )
        weak = Correction(
            id="corr_low",
            key=key,
            correction_type=CorrectionType.TARGET_CHANGE,
            success_count=1,
            failure_count=9
        )

        kept = self.aggregator.filter_by_quality(
            [strong, weak],
            min_success_rate=0.7
        )

        self.assertEqual(len(kept), 1)
        self.assertEqual(kept[0].id, "corr_high")

    def test_rank_corrections(self):
        """Ranking by success_rate is descending."""
        key = CorrectionKey("click", "button", "error")

        pool = [
            Correction(id="corr_1", key=key, correction_type=CorrectionType.TARGET_CHANGE, success_count=3, failure_count=7),
            Correction(id="corr_2", key=key, correction_type=CorrectionType.TARGET_CHANGE, success_count=8, failure_count=2),
            Correction(id="corr_3", key=key, correction_type=CorrectionType.TARGET_CHANGE, success_count=5, failure_count=5),
        ]

        ordered = self.aggregator.rank_corrections(pool, sort_by='success_rate')

        self.assertEqual([c.id for c in ordered], ["corr_2", "corr_3", "corr_1"])
|
||||||
|
|
||||||
|
|
||||||
|
class TestCorrectionPackService(unittest.TestCase):
|
||||||
|
"""Tests for CorrectionPackService."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up service for tests."""
|
||||||
|
self.test_dir = tempfile.mkdtemp()
|
||||||
|
self.training_dir = tempfile.mkdtemp()
|
||||||
|
self.service = CorrectionPackService(self.test_dir, self.training_dir)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
"""Clean up temporary directories."""
|
||||||
|
shutil.rmtree(self.test_dir)
|
||||||
|
shutil.rmtree(self.training_dir)
|
||||||
|
|
||||||
|
def test_create_and_list_packs(self):
|
||||||
|
"""Test creating and listing packs."""
|
||||||
|
pack = self.service.create_pack(
|
||||||
|
name="Service Test Pack",
|
||||||
|
description="Testing the service",
|
||||||
|
category="testing",
|
||||||
|
tags=["unit", "test"]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertIsNotNone(pack)
|
||||||
|
self.assertEqual(pack['name'], "Service Test Pack")
|
||||||
|
|
||||||
|
packs = self.service.list_packs()
|
||||||
|
self.assertEqual(len(packs), 1)
|
||||||
|
self.assertEqual(packs[0]['name'], "Service Test Pack")
|
||||||
|
|
||||||
|
def test_add_correction(self):
|
||||||
|
"""Test adding corrections via service."""
|
||||||
|
pack = self.service.create_pack(name="Correction Test")
|
||||||
|
|
||||||
|
correction = self.service.add_correction(pack['id'], {
|
||||||
|
'action_type': 'click',
|
||||||
|
'element_type': 'button',
|
||||||
|
'failure_context': 'element_not_found',
|
||||||
|
'correction_type': 'target_change',
|
||||||
|
'original_target': {'selector': '#old'},
|
||||||
|
'corrected_target': {'selector': '#new'},
|
||||||
|
'description': 'Updated selector'
|
||||||
|
})
|
||||||
|
|
||||||
|
self.assertIsNotNone(correction)
|
||||||
|
self.assertTrue(correction['id'].startswith('corr_'))
|
||||||
|
|
||||||
|
def test_find_applicable_corrections(self):
|
||||||
|
"""Test finding applicable corrections."""
|
||||||
|
pack = self.service.create_pack(name="Find Test")
|
||||||
|
|
||||||
|
self.service.add_correction(pack['id'], {
|
||||||
|
'action_type': 'click',
|
||||||
|
'element_type': 'button',
|
||||||
|
'failure_context': 'timeout',
|
||||||
|
'correction_type': 'timing_adjust',
|
||||||
|
'corrected_params': {'wait_time': 5}
|
||||||
|
})
|
||||||
|
|
||||||
|
found = self.service.find_applicable_corrections(
|
||||||
|
action_type='click',
|
||||||
|
element_type='button',
|
||||||
|
failure_context='timeout',
|
||||||
|
min_confidence=0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(found), 1)
|
||||||
|
|
||||||
|
def test_export_import_file(self):
|
||||||
|
"""Test file-based export and import."""
|
||||||
|
pack = self.service.create_pack(name="File Export Test")
|
||||||
|
self.service.add_correction(pack['id'], {
|
||||||
|
'action_type': 'type',
|
||||||
|
'element_type': 'input',
|
||||||
|
'failure_context': 'validation',
|
||||||
|
'correction_type': 'parameter_change'
|
||||||
|
})
|
||||||
|
|
||||||
|
# Export to file
|
||||||
|
export_file = os.path.join(self.test_dir, 'export.json')
|
||||||
|
self.assertTrue(self.service.export_pack_to_file(pack['id'], export_file))
|
||||||
|
self.assertTrue(os.path.exists(export_file))
|
||||||
|
|
||||||
|
# Import from file
|
||||||
|
imported = self.service.import_pack_from_file(export_file)
|
||||||
|
self.assertIsNotNone(imported)
|
||||||
|
self.assertEqual(len(imported['corrections']), 1)
|
||||||
|
|
||||||
|
def test_version_management(self):
    """Snapshots can be created, listed, and rolled back to."""
    pack = self.service.create_pack(name="Version Test")
    pack_id = pack['id']

    # One correction in the pack before the snapshot.
    self.service.add_correction(pack_id, {
        'action_type': 'click',
        'element_type': 'button',
        'failure_context': 'error',
        'correction_type': 'target_change',
    })

    # Snapshot -> first version number should be 1.
    self.assertEqual(self.service.create_version(pack_id, "v1 snapshot"), 1)
    self.assertEqual(len(self.service.list_versions(pack_id)), 1)

    # Mutate the pack after the snapshot.
    self.service.add_correction(pack_id, {
        'action_type': 'type',
        'element_type': 'input',
        'failure_context': 'error',
        'correction_type': 'parameter_change',
    })
    self.assertEqual(len(self.service.get_pack(pack_id)['corrections']), 2)

    # Rolling back to v1 must restore the single-correction state.
    self.assertTrue(self.service.rollback_pack(pack_id, 1))
    self.assertEqual(len(self.service.get_pack(pack_id)['corrections']), 1)
|
||||||
|
|
||||||
|
def test_statistics(self):
    """Per-pack and global statistics reflect stored corrections."""
    pack = self.service.create_pack(name="Stats Test")
    self.service.add_correction(pack['id'], {
        'action_type': 'click',
        'element_type': 'button',
        'failure_context': 'error',
        'correction_type': 'target_change',
    })

    # Pack-level statistics
    per_pack = self.service.get_pack_statistics(pack['id'])
    self.assertIsNotNone(per_pack)
    self.assertEqual(per_pack['total_corrections'], 1)

    # Global statistics across all packs
    overall = self.service.get_global_statistics()
    self.assertIsNotNone(overall)
    self.assertEqual(overall['total_packs'], 1)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCrossWorkflowAggregator(unittest.TestCase):
    """Tests for CrossWorkflowAggregator.

    Exercises cross-workflow aggregation: finding corrections that recur
    across distinct workflows, and summarizing patterns over a correction set.
    """

    def test_find_common_corrections(self):
        """Test finding corrections common across workflows."""
        aggregator = CrossWorkflowAggregator()

        # All three corrections share the same key (action/element/failure),
        # so they should be merged into a single common entry.
        key = CorrectionKey("click", "submit_button", "timeout")

        # Corrections from different workflows
        corrections = [
            Correction(
                id="corr_1",
                key=key,
                correction_type=CorrectionType.WAIT_ADDED,
                source=CorrectionSource(session_id="s1", workflow_id="wf_1")
            ),
            Correction(
                id="corr_2",
                key=key,
                correction_type=CorrectionType.WAIT_ADDED,
                source=CorrectionSource(session_id="s2", workflow_id="wf_2")
            ),
            Correction(
                id="corr_3",
                key=key,
                correction_type=CorrectionType.WAIT_ADDED,
                source=CorrectionSource(session_id="s3", workflow_id="wf_3")
            ),
        ]

        common = aggregator.find_common_corrections(corrections, min_workflows=2)

        # One merged correction, observed in all three workflows.
        self.assertEqual(len(common), 1)
        merged, workflows = common[0]
        self.assertEqual(len(workflows), 3)

    def test_detect_patterns(self):
        """Test pattern detection."""
        aggregator = CrossWorkflowAggregator()

        # Five identical-key corrections (no explicit source set).
        key = CorrectionKey("click", "button", "error")
        corrections = [
            Correction(id=f"corr_{i}", key=key, correction_type=CorrectionType.WAIT_ADDED)
            for i in range(5)
        ]

        patterns = aggregator.detect_patterns(corrections)

        # The result exposes both a 'patterns' list and a 'summary' section;
        # the summary must count every input correction.
        self.assertIn('patterns', patterns)
        self.assertIn('summary', patterns)
        self.assertEqual(patterns['summary']['total_corrections'], 5)
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this test module directly (e.g. `python test_corrections.py`).
if __name__ == '__main__':
    unittest.main()
|
||||||
640
visual_workflow_builder/backend/api/correction_packs.py
Normal file
640
visual_workflow_builder/backend/api/correction_packs.py
Normal file
@@ -0,0 +1,640 @@
|
|||||||
|
"""
|
||||||
|
API endpoints pour les Correction Packs.
|
||||||
|
|
||||||
|
Provides REST API for managing correction packs, corrections,
|
||||||
|
aggregation, export/import, and versioning.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from flask import Blueprint, request, jsonify, Response
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
# Add parent paths for imports
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))))
|
||||||
|
|
||||||
|
from core.corrections import CorrectionPackService
|
||||||
|
|
||||||
|
# Blueprint under which every correction-pack endpoint below is registered.
correction_packs_bp = Blueprint('correction_packs', __name__)

# Initialize service (singleton pattern): created lazily on first use by
# get_service() so importing this module stays side-effect free.
_service: Optional[CorrectionPackService] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_service() -> CorrectionPackService:
    """Get or create the module-wide CorrectionPackService singleton."""
    global _service
    if _service is None:
        # Walk five directory levels up from this file to the project root
        # (same depth as the sys.path hack at the top of the module).
        root = os.path.abspath(__file__)
        for _ in range(5):
            root = os.path.dirname(root)
        _service = CorrectionPackService(
            os.path.join(root, "data", "correction_packs"),
            os.path.join(root, "training_data"),
        )
    return _service
|
||||||
|
|
||||||
|
|
||||||
|
# ========== Pack CRUD Endpoints ==========
|
||||||
|
|
||||||
|
@correction_packs_bp.route('/correction-packs', methods=['GET'])
def list_packs():
    """List all correction packs.

    Query params:
        category: optional category filter
        tags: optional comma-separated tag filter
    """
    try:
        svc = get_service()
        tags_param = request.args.get('tags')
        packs = svc.list_packs(
            category=request.args.get('category'),
            tags=tags_param.split(',') if tags_param else None,
        )
        return jsonify({'success': True, 'packs': packs})
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@correction_packs_bp.route('/correction-packs', methods=['POST'])
def create_pack():
    """Create a new correction pack from the JSON body.

    Body fields: name (required), description, category, tags, author.
    Returns the created pack with HTTP 201.
    """
    try:
        svc = get_service()
        payload = request.get_json() or {}

        # `name` is the only mandatory field.
        if not payload.get('name'):
            return jsonify({'success': False, 'error': 'name is required'}), 400

        pack = svc.create_pack(
            name=payload['name'],
            description=payload.get('description', ''),
            category=payload.get('category', 'general'),
            tags=payload.get('tags', []),
            author=payload.get('author', ''),
        )
        return jsonify({'success': True, 'pack': pack}), 201
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@correction_packs_bp.route('/correction-packs/<pack_id>', methods=['GET'])
def get_pack(pack_id: str):
    """Return a single pack (including all of its corrections) by ID."""
    try:
        pack = get_service().get_pack(pack_id)
        if not pack:
            return jsonify({'success': False, 'error': 'Pack not found'}), 404
        return jsonify({'success': True, 'pack': pack})
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@correction_packs_bp.route('/correction-packs/<pack_id>', methods=['PUT'])
def update_pack(pack_id: str):
    """Update a pack's name, description, and/or metadata.

    Absent body fields are passed through as None (i.e. left unchanged
    by the service layer).
    """
    try:
        payload = request.get_json() or {}
        updated = get_service().update_pack(
            pack_id=pack_id,
            name=payload.get('name'),
            description=payload.get('description'),
            metadata=payload.get('metadata'),
        )
        if not updated:
            return jsonify({'success': False, 'error': 'Pack not found'}), 404
        return jsonify({'success': True, 'pack': updated})
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@correction_packs_bp.route('/correction-packs/<pack_id>', methods=['DELETE'])
def delete_pack(pack_id: str):
    """Delete a pack by ID; 404 when it does not exist."""
    try:
        deleted = get_service().delete_pack(pack_id)
        if not deleted:
            return jsonify({'success': False, 'error': 'Pack not found'}), 404
        return jsonify({'success': True, 'message': 'Pack deleted'})
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
# ========== Correction Endpoints ==========
|
||||||
|
|
||||||
|
@correction_packs_bp.route('/correction-packs/<pack_id>/corrections', methods=['POST'])
def add_correction(pack_id: str):
    """Add a correction to a pack.

    Body fields: action_type (required), element_type, failure_context,
    correction_type, original_target/params, corrected_target/params,
    description, tags, source. Returns the created correction with 201.
    """
    try:
        payload = request.get_json() or {}

        # action_type is the only mandatory field of a correction.
        if not payload.get('action_type'):
            return jsonify({'success': False, 'error': 'action_type is required'}), 400

        created = get_service().add_correction(pack_id, payload)
        if not created:
            # Service refused the correction (e.g. unknown pack).
            return jsonify({'success': False, 'error': 'Failed to add correction'}), 400

        return jsonify({'success': True, 'correction': created}), 201
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@correction_packs_bp.route('/correction-packs/<pack_id>/corrections/<correction_id>', methods=['PUT'])
def update_correction(pack_id: str, correction_id: str):
    """Update a correction (corrected_target/params, description, status, tags)."""
    try:
        payload = request.get_json() or {}
        updated = get_service().update_correction(pack_id, correction_id, payload)
        if not updated:
            return jsonify({'success': False, 'error': 'Correction not found'}), 404
        return jsonify({'success': True, 'correction': updated})
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@correction_packs_bp.route('/correction-packs/<pack_id>/corrections/<correction_id>', methods=['DELETE'])
def remove_correction(pack_id: str, correction_id: str):
    """Remove a correction from a pack; 404 when it does not exist."""
    try:
        removed = get_service().remove_correction(pack_id, correction_id)
        if not removed:
            return jsonify({'success': False, 'error': 'Correction not found'}), 404
        return jsonify({'success': True, 'message': 'Correction removed'})
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
# ========== Aggregation Endpoints ==========
|
||||||
|
|
||||||
|
@correction_packs_bp.route('/correction-packs/<pack_id>/aggregate', methods=['POST'])
def aggregate_from_sessions(pack_id: str):
    """Aggregate corrections from recorded training sessions into a pack.

    Body fields: session_files, workflow_id, min_occurrences (default 1),
    min_success_rate (default 0.0). Returns the aggregation summary.
    """
    try:
        payload = request.get_json() or {}
        summary = get_service().aggregate_from_sessions(
            pack_id=pack_id,
            session_files=payload.get('session_files'),
            workflow_id=payload.get('workflow_id'),
            min_occurrences=payload.get('min_occurrences', 1),
            min_success_rate=payload.get('min_success_rate', 0.0),
        )
        return jsonify({'success': True, 'result': summary})
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@correction_packs_bp.route('/correction-packs/<pack_id>/aggregate-cross-workflow', methods=['POST'])
def aggregate_cross_workflow(pack_id: str):
    """Aggregate corrections that recur across multiple workflows.

    Body field: min_workflows (default 2). Returns summary with patterns.
    """
    try:
        payload = request.get_json() or {}
        summary = get_service().aggregate_cross_workflow(
            pack_id=pack_id,
            min_workflows=payload.get('min_workflows', 2),
        )
        return jsonify({'success': True, 'result': summary})
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
# ========== Search and Apply Endpoints ==========
|
||||||
|
|
||||||
|
@correction_packs_bp.route('/correction-packs/find', methods=['POST'])
def find_corrections():
    """Find corrections applicable to a failure context.

    Body fields: action_type (required), element_type, failure_context,
    pack_ids, min_confidence (default 0.3).
    """
    try:
        payload = request.get_json() or {}

        if not payload.get('action_type'):
            return jsonify({'success': False, 'error': 'action_type is required'}), 400

        matches = get_service().find_applicable_corrections(
            action_type=payload['action_type'],
            element_type=payload.get('element_type', 'unknown'),
            failure_context=payload.get('failure_context', ''),
            pack_ids=payload.get('pack_ids'),
            min_confidence=payload.get('min_confidence', 0.3),
        )
        return jsonify({'success': True, 'corrections': matches})
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@correction_packs_bp.route('/correction-packs/apply', methods=['POST'])
def apply_correction():
    """Record that a correction was applied (and whether it succeeded).

    Body fields: pack_id (required), correction_id (required),
    success (required boolean), context (optional dict).
    """
    try:
        payload = request.get_json() or {}

        # Both identifiers are required to locate the correction.
        if not payload.get('pack_id') or not payload.get('correction_id'):
            return jsonify({'success': False, 'error': 'pack_id and correction_id are required'}), 400

        # `success` may legitimately be False, so test for presence, not truth.
        if 'success' not in payload:
            return jsonify({'success': False, 'error': 'success is required'}), 400

        get_service().apply_correction(
            pack_id=payload['pack_id'],
            correction_id=payload['correction_id'],
            success=payload['success'],
            application_context=payload.get('context'),
        )
        return jsonify({'success': True, 'message': 'Application recorded'})
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@correction_packs_bp.route('/correction-packs/search', methods=['GET'])
def search_corrections():
    """Full-text search over corrections.

    Query params: q (required), pack_id (optional scope).
    """
    try:
        query = request.args.get('q')
        if not query:
            return jsonify({'success': False, 'error': 'q parameter is required'}), 400

        hits = get_service().search_corrections(query, request.args.get('pack_id'))
        return jsonify({'success': True, 'corrections': hits})
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
# ========== Export/Import Endpoints ==========
|
||||||
|
|
||||||
|
@correction_packs_bp.route('/correction-packs/<pack_id>/export', methods=['GET'])
def export_pack(pack_id: str):
    """Export a pack to JSON or YAML.

    Query params:
        format: Export format ('json' or 'yaml', default: 'json')

    Returns:
        The serialized pack as a downloadable attachment, or a JSON error
        (400 on invalid format, 404 on unknown pack).
    """
    try:
        service = get_service()

        format_type = request.args.get('format', 'json')
        if format_type not in ['json', 'yaml']:
            return jsonify({'success': False, 'error': 'Invalid format'}), 400

        content = service.export_pack(pack_id, format=format_type)
        if not content:
            return jsonify({'success': False, 'error': 'Pack not found'}), 404

        # Determine content type and download filename for the chosen format.
        if format_type == 'yaml':
            content_type = 'application/x-yaml'
            filename = f'correction_pack_{pack_id}.yaml'
        else:
            content_type = 'application/json'
            filename = f'correction_pack_{pack_id}.json'

        # Bug fix: `filename` was computed but never interpolated into the
        # Content-Disposition header, so browsers received a bogus download
        # name. Interpolate it properly here.
        return Response(
            content,
            mimetype=content_type,
            headers={
                'Content-Disposition': f'attachment; filename={filename}'
            }
        )

    except Exception as e:
        return jsonify({'success': False, 'error': str(e)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@correction_packs_bp.route('/correction-packs/import', methods=['POST'])
def import_pack():
    """Import a pack from an uploaded file or a JSON request body.

    File uploads auto-detect YAML vs JSON from the filename extension;
    JSON bodies carry explicit `content`, `format`, and `merge_strategy`
    ('create_new' | 'merge' | 'replace') fields.
    """
    try:
        svc = get_service()

        if 'file' in request.files:
            # Multipart upload path.
            upload = request.files['file']
            content = upload.read().decode('utf-8')
            name = upload.filename

            # Auto-detect format from the uploaded filename.
            if name and (name.endswith('.yaml') or name.endswith('.yml')):
                format_type = 'yaml'
            else:
                format_type = 'json'

            merge_strategy = request.form.get('merge_strategy', 'create_new')
        else:
            # Plain JSON body path.
            payload = request.get_json() or {}
            content = payload.get('content')
            if not content:
                return jsonify({'success': False, 'error': 'content or file is required'}), 400
            format_type = payload.get('format', 'json')
            merge_strategy = payload.get('merge_strategy', 'create_new')

        pack = svc.import_pack(content, format=format_type, merge_strategy=merge_strategy)
        if not pack:
            return jsonify({'success': False, 'error': 'Failed to import pack'}), 400

        return jsonify({'success': True, 'pack': pack}), 201
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
# ========== Version Endpoints ==========
|
||||||
|
|
||||||
|
@correction_packs_bp.route('/correction-packs/<pack_id>/versions', methods=['GET'])
def list_versions(pack_id: str):
    """List all saved version snapshots of a pack."""
    try:
        history = get_service().list_versions(pack_id)
        return jsonify({'success': True, 'versions': history})
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@correction_packs_bp.route('/correction-packs/<pack_id>/versions', methods=['POST'])
def create_version(pack_id: str):
    """Create a version snapshot of a pack.

    Body field: description. Returns the new version number with 201;
    a negative number from the service means the pack does not exist.
    """
    try:
        payload = request.get_json() or {}
        version = get_service().create_version(
            pack_id=pack_id,
            description=payload.get('description', ''),
        )
        if version < 0:
            return jsonify({'success': False, 'error': 'Pack not found'}), 404
        return jsonify({'success': True, 'version': version}), 201
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@correction_packs_bp.route('/correction-packs/<pack_id>/rollback', methods=['POST'])
def rollback_pack(pack_id: str):
    """Rollback a pack to a previously saved version.

    Body field: version (required integer).
    """
    try:
        payload = request.get_json() or {}

        target_version = payload.get('version')
        # Version 0 would be falsy, so test for None explicitly.
        if target_version is None:
            return jsonify({'success': False, 'error': 'version is required'}), 400

        if not get_service().rollback_pack(pack_id, target_version):
            return jsonify({'success': False, 'error': 'Rollback failed'}), 400

        return jsonify({'success': True, 'message': f'Rolled back to version {target_version}'})
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
# ========== Statistics Endpoints ==========
|
||||||
|
|
||||||
|
@correction_packs_bp.route('/correction-packs/<pack_id>/statistics', methods=['GET'])
def get_pack_statistics(pack_id: str):
    """Return detailed statistics for one pack; 404 when it does not exist."""
    try:
        stats = get_service().get_pack_statistics(pack_id)
        if not stats:
            return jsonify({'success': False, 'error': 'Pack not found'}), 404
        return jsonify({'success': True, 'statistics': stats})
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@correction_packs_bp.route('/correction-packs/statistics', methods=['GET'])
def get_global_statistics():
    """Return aggregate statistics across every correction pack."""
    try:
        overall = get_service().get_global_statistics()
        return jsonify({'success': True, 'statistics': overall})
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||||
201
visual_workflow_builder/backend/app.py
Normal file
201
visual_workflow_builder/backend/app.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
"""
|
||||||
|
Visual Workflow Builder - Backend Flask Application
|
||||||
|
|
||||||
|
This is the main entry point for the Visual Workflow Builder backend API.
|
||||||
|
It provides REST endpoints for workflow management and WebSocket support
|
||||||
|
for real-time execution updates.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from flask import Flask
|
||||||
|
from flask_cors import CORS
|
||||||
|
from flask_socketio import SocketIO
|
||||||
|
from flask_sqlalchemy import SQLAlchemy
|
||||||
|
from flask_caching import Cache
|
||||||
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# Load environment variables
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# Initialize Flask app
|
||||||
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
app.config['SECRET_KEY'] = os.getenv('SECRET_KEY', 'dev-secret-key-change-in-production')
|
||||||
|
app.config['SQLALCHEMY_DATABASE_URI'] = os.getenv('DATABASE_URL', 'sqlite:///workflows.db')
|
||||||
|
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
|
||||||
|
app.config['MAX_CONTENT_LENGTH'] = 10 * 1024 * 1024 # 10MB max upload
|
||||||
|
app.config['CACHE_TYPE'] = 'redis' if os.getenv('REDIS_URL') else 'simple'
|
||||||
|
app.config['CACHE_REDIS_URL'] = os.getenv('REDIS_URL', 'redis://localhost:6379/0')
|
||||||
|
|
||||||
|
# Initialize extensions
|
||||||
|
db = SQLAlchemy(app)
|
||||||
|
cache = Cache(app)
|
||||||
|
socketio = SocketIO(
|
||||||
|
app,
|
||||||
|
cors_allowed_origins="*",
|
||||||
|
async_mode='threading',
|
||||||
|
logger=True,
|
||||||
|
engineio_logger=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Enable CORS
|
||||||
|
CORS(app, resources={
|
||||||
|
r"/api/*": {
|
||||||
|
"origins": os.getenv('CORS_ORIGINS', 'http://localhost:3000').split(','),
|
||||||
|
"methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||||
|
"allow_headers": ["Content-Type", "Authorization"]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
# Import and register blueprints (minimal set)
|
||||||
|
from api.workflows import workflows_bp
|
||||||
|
from api.screen_capture import screen_capture_bp
|
||||||
|
from api.real_demo import real_demo_bp
|
||||||
|
from api.errors import error_response
|
||||||
|
|
||||||
|
app.register_blueprint(workflows_bp, url_prefix='/api/workflows')
|
||||||
|
app.register_blueprint(screen_capture_bp, url_prefix='/api/screen-capture')
|
||||||
|
app.register_blueprint(real_demo_bp)
|
||||||
|
|
||||||
|
# Optional / Phase 2+ blueprints (loaded only if modules are available)
|
||||||
|
try:
|
||||||
|
from api.self_healing import self_healing_bp
|
||||||
|
app.register_blueprint(self_healing_bp)
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"⚠️ Blueprint self_healing désactivé: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from api.visual_targets import visual_targets_bp, init_visual_target_manager
|
||||||
|
app.register_blueprint(visual_targets_bp)
|
||||||
|
VISUAL_TARGETS_BP_AVAILABLE = True
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"⚠️ Blueprint visual_targets désactivé: {e}")
|
||||||
|
VISUAL_TARGETS_BP_AVAILABLE = False
|
||||||
|
init_visual_target_manager = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from api.element_detection import element_detection_bp, init_element_detection
|
||||||
|
app.register_blueprint(element_detection_bp)
|
||||||
|
ELEMENT_DETECTION_BP_AVAILABLE = True
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"⚠️ Blueprint element_detection désactivé: {e}")
|
||||||
|
ELEMENT_DETECTION_BP_AVAILABLE = False
|
||||||
|
init_element_detection = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from api.analytics import analytics_bp
|
||||||
|
app.register_blueprint(analytics_bp, url_prefix='/api/analytics')
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Register other blueprints (optional - depends on Phase 2+ services)
|
||||||
|
try:
|
||||||
|
from api.templates import templates_bp
|
||||||
|
app.register_blueprint(templates_bp, url_prefix='/api/templates')
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"⚠️ Blueprint templates désactivé: {e}")
|
||||||
|
|
||||||
|
from api.node_types import node_types_bp
|
||||||
|
app.register_blueprint(node_types_bp, url_prefix='/api/node-types')
|
||||||
|
|
||||||
|
try:
|
||||||
|
from api.executions import executions_bp
|
||||||
|
app.register_blueprint(executions_bp, url_prefix='/api/executions')
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"⚠️ Blueprint executions désactivé: {e}")
|
||||||
|
|
# Import/export blueprint (optional).
try:
    from api.import_export import import_export_bp
except ImportError as e:
    print(f"⚠️ Blueprint import_export désactivé: {e}")
else:
    app.register_blueprint(import_export_bp, url_prefix='/api')
# Correction-packs blueprint (optional): cross-workflow correction REST API.
try:
    from api.correction_packs import correction_packs_bp
except ImportError as e:
    print(f"⚠️ Blueprint correction_packs désactivé: {e}")
else:
    app.register_blueprint(correction_packs_bp, url_prefix='/api')
    print("✅ Blueprint correction_packs enregistré")
# Import WebSocket handlers (optional). Importing the module is what wires the
# handlers up (side-effect import), hence the noqa on the unused name.
# NOTE: broad `except Exception` is deliberate here — any failure during the
# side-effect import should degrade to a warning, not crash startup.
try:
    from api import websocket_handlers  # noqa: F401
except Exception as e:
    print(f"⚠️ WebSocket handlers désactivés: {e}")
# Global error handlers
@app.errorhandler(404)
def not_found(error):
    """Return the standard JSON error payload for unknown routes (HTTP 404)."""
    return error_response(404, "Resource not found")
@app.errorhandler(405)
def method_not_allowed(error):
    """Return the standard JSON error payload for bad HTTP verbs (HTTP 405)."""
    return error_response(405, "Method not allowed")
@app.errorhandler(500)
def internal_error(error):
    """Return the standard JSON error payload for server faults (HTTP 500)."""
    return error_response(500, "Internal server error")
@app.errorhandler(Exception)
def handle_exception(error):
    """Last-resort handler: dump the traceback, answer with a JSON 500."""
    # Print the full traceback so the failure is diagnosable from the console.
    import traceback
    traceback.print_exc()
    return error_response(500, f"Unexpected error: {str(error)}")
# Health check endpoint
@app.route('/health')
def health_check():
    """Lightweight liveness probe for external monitoring."""
    payload = {'status': 'healthy', 'version': '1.0.0'}
    return payload
# Create database tables
# db.create_all() requires an active application context, hence the `with`.
with app.app_context():
    db.create_all()
# Initialize VisualTargetManager with RPA Vision V3 components (optional)
try:
    from core.capture.screen_capturer import ScreenCapturer
    from core.detection.ui_detector import UIDetector
    from core.embedding.fusion_engine import FusionEngine

    # Explicit None sentinels replace the previous `'name' not in locals()`
    # probing, which is fragile and hard to read; behavior is unchanged.
    screen_capturer = None
    ui_detector = None
    fusion_engine = None

    # Only initialize if the related blueprints were actually loaded
    if VISUAL_TARGETS_BP_AVAILABLE and init_visual_target_manager:
        screen_capturer = ScreenCapturer()
        ui_detector = UIDetector()
        fusion_engine = FusionEngine()
        init_visual_target_manager(screen_capturer, ui_detector, fusion_engine)

    if ELEMENT_DETECTION_BP_AVAILABLE and init_element_detection:
        # Reuse the same instances when possible
        if ui_detector is None:
            ui_detector = UIDetector()
        if screen_capturer is None:
            screen_capturer = ScreenCapturer()
        init_element_detection(ui_detector, screen_capturer)

    if (VISUAL_TARGETS_BP_AVAILABLE and init_visual_target_manager) or (ELEMENT_DETECTION_BP_AVAILABLE and init_element_detection):
        print("✅ Services visuels initialisés (VisualTargets / ElementDetection)")
except ImportError as e:
    print(f"⚠️ Core RPA non disponible pour l'initialisation visuelle: {e}")
except Exception as e:
    print(f"❌ Erreur lors de l'initialisation des services visuels: {e}")
if __name__ == '__main__':
    # Entry point: serve the app through Flask-SocketIO's development server.
    listen_port = int(os.getenv('PORT', 5002))
    dev_mode = os.getenv('FLASK_ENV') == 'development'

    socketio.run(
        app,
        host='0.0.0.0',
        port=listen_port,
        debug=dev_mode,
        use_reloader=dev_mode,
        allow_unsafe_werkzeug=True,  # For development only
    )
|
||||||
Reference in New Issue
Block a user