feat(corrections): Add Correction Packs system for cross-workflow learning
Implement a complete system for capturing and reusing user corrections across multiple workflows and sessions. This enables automatic application of learned fixes when similar failures occur in different contexts. New components: - core/corrections/models.py: CorrectionKey, Correction, CorrectionPack models - core/corrections/correction_repository.py: JSON storage with atomic writes - core/corrections/aggregator.py: Aggregation by hash and quality filtering - core/corrections/correction_pack_service.py: CRUD, export/import, versioning - backend/api/correction_packs.py: REST API with 15 endpoints Features: - MD5-based key hashing for correction deduplication - Export/import in JSON and YAML formats - Version history with rollback support - Cross-workflow pattern detection - Integration with SelfHealingEngine for automatic application - 29 unit tests (all passing) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
30
core/corrections/__init__.py
Normal file
30
core/corrections/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""
|
||||
Correction Packs System - Cross-workflow correction capitalization.
|
||||
|
||||
This module provides a system for storing, aggregating, and applying
|
||||
user corrections across multiple workflows and sessions.
|
||||
"""
|
||||
|
||||
from .models import (
|
||||
CorrectionKey,
|
||||
CorrectionType,
|
||||
CorrectionStatus,
|
||||
Correction,
|
||||
CorrectionPack,
|
||||
CorrectionPackMetadata
|
||||
)
|
||||
from .correction_repository import CorrectionRepository
|
||||
from .aggregator import CorrectionAggregator
|
||||
from .correction_pack_service import CorrectionPackService
|
||||
|
||||
__all__ = [
|
||||
'CorrectionKey',
|
||||
'CorrectionType',
|
||||
'CorrectionStatus',
|
||||
'Correction',
|
||||
'CorrectionPack',
|
||||
'CorrectionPackMetadata',
|
||||
'CorrectionRepository',
|
||||
'CorrectionAggregator',
|
||||
'CorrectionPackService'
|
||||
]
|
||||
545
core/corrections/aggregator.py
Normal file
545
core/corrections/aggregator.py
Normal file
@@ -0,0 +1,545 @@
|
||||
"""
|
||||
Correction Aggregator - Aggregate corrections from multiple sessions.
|
||||
|
||||
Groups corrections by key hash and filters by quality thresholds.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
|
||||
from .models import (
|
||||
Correction,
|
||||
CorrectionKey,
|
||||
CorrectionType,
|
||||
CorrectionSource,
|
||||
CorrectionStatus,
|
||||
generate_correction_id
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CorrectionAggregator:
|
||||
"""
|
||||
Aggregates corrections from multiple sources (sessions, workflows).
|
||||
|
||||
Groups corrections by CorrectionKey hash and merges statistics.
|
||||
Filters out low-quality corrections based on configurable thresholds.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
min_occurrences: int = 2,
|
||||
min_success_rate: float = 0.5,
|
||||
min_confidence: float = 0.3
|
||||
):
|
||||
"""
|
||||
Initialize the aggregator.
|
||||
|
||||
Args:
|
||||
min_occurrences: Minimum number of times a correction must occur
|
||||
min_success_rate: Minimum success rate to include correction
|
||||
min_confidence: Minimum confidence score to include correction
|
||||
"""
|
||||
self.min_occurrences = min_occurrences
|
||||
self.min_success_rate = min_success_rate
|
||||
self.min_confidence = min_confidence
|
||||
|
||||
def aggregate_corrections(
|
||||
self,
|
||||
corrections: List[Correction],
|
||||
merge_strategy: str = 'weighted_average'
|
||||
) -> List[Correction]:
|
||||
"""
|
||||
Aggregate a list of corrections by key hash.
|
||||
|
||||
Args:
|
||||
corrections: List of corrections to aggregate
|
||||
merge_strategy: How to merge duplicate corrections:
|
||||
- 'weighted_average': Weight by usage count
|
||||
- 'best_confidence': Keep highest confidence
|
||||
- 'most_recent': Keep most recently used
|
||||
|
||||
Returns:
|
||||
List of aggregated corrections
|
||||
"""
|
||||
# Group by key hash
|
||||
grouped: Dict[str, List[Correction]] = defaultdict(list)
|
||||
for correction in corrections:
|
||||
key_hash = correction.key.to_hash()
|
||||
grouped[key_hash].append(correction)
|
||||
|
||||
# Aggregate each group
|
||||
aggregated = []
|
||||
for key_hash, group in grouped.items():
|
||||
if len(group) == 1:
|
||||
# Single correction, just check thresholds
|
||||
if self._passes_thresholds(group[0]):
|
||||
aggregated.append(group[0])
|
||||
else:
|
||||
# Multiple corrections, merge
|
||||
merged = self._merge_corrections(group, merge_strategy)
|
||||
if merged and self._passes_thresholds(merged):
|
||||
aggregated.append(merged)
|
||||
|
||||
return aggregated
|
||||
|
||||
def aggregate_from_sessions(
|
||||
self,
|
||||
session_corrections: List[Dict[str, Any]],
|
||||
workflow_id: Optional[str] = None
|
||||
) -> List[Correction]:
|
||||
"""
|
||||
Aggregate corrections from session data (user_corrections format).
|
||||
|
||||
Args:
|
||||
session_corrections: List of correction dicts from TrainingSession
|
||||
workflow_id: Optional workflow ID to filter by
|
||||
|
||||
Returns:
|
||||
List of aggregated Correction objects
|
||||
"""
|
||||
# Convert session corrections to Correction objects
|
||||
corrections = []
|
||||
for sc in session_corrections:
|
||||
try:
|
||||
correction = self._session_correction_to_correction(sc)
|
||||
if correction:
|
||||
# Filter by workflow if specified
|
||||
if workflow_id is None or (correction.source and correction.source.workflow_id == workflow_id):
|
||||
corrections.append(correction)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error converting session correction: {e}")
|
||||
|
||||
# Aggregate
|
||||
return self.aggregate_corrections(corrections)
|
||||
|
||||
def group_by_type(
|
||||
self,
|
||||
corrections: List[Correction]
|
||||
) -> Dict[CorrectionType, List[Correction]]:
|
||||
"""
|
||||
Group corrections by their type.
|
||||
|
||||
Args:
|
||||
corrections: List of corrections
|
||||
|
||||
Returns:
|
||||
Dict mapping correction type to list of corrections
|
||||
"""
|
||||
grouped: Dict[CorrectionType, List[Correction]] = defaultdict(list)
|
||||
for correction in corrections:
|
||||
grouped[correction.correction_type].append(correction)
|
||||
return dict(grouped)
|
||||
|
||||
def group_by_action_type(
|
||||
self,
|
||||
corrections: List[Correction]
|
||||
) -> Dict[str, List[Correction]]:
|
||||
"""
|
||||
Group corrections by the action type they apply to.
|
||||
|
||||
Args:
|
||||
corrections: List of corrections
|
||||
|
||||
Returns:
|
||||
Dict mapping action type to list of corrections
|
||||
"""
|
||||
grouped: Dict[str, List[Correction]] = defaultdict(list)
|
||||
for correction in corrections:
|
||||
grouped[correction.key.action_type].append(correction)
|
||||
return dict(grouped)
|
||||
|
||||
def filter_by_quality(
|
||||
self,
|
||||
corrections: List[Correction],
|
||||
min_success_rate: Optional[float] = None,
|
||||
min_confidence: Optional[float] = None,
|
||||
min_applications: Optional[int] = None
|
||||
) -> List[Correction]:
|
||||
"""
|
||||
Filter corrections by quality thresholds.
|
||||
|
||||
Args:
|
||||
corrections: List of corrections to filter
|
||||
min_success_rate: Minimum success rate (overrides default)
|
||||
min_confidence: Minimum confidence (overrides default)
|
||||
min_applications: Minimum number of applications
|
||||
|
||||
Returns:
|
||||
Filtered list of corrections
|
||||
"""
|
||||
min_sr = min_success_rate if min_success_rate is not None else self.min_success_rate
|
||||
min_conf = min_confidence if min_confidence is not None else self.min_confidence
|
||||
min_apps = min_applications if min_applications is not None else 1
|
||||
|
||||
return [
|
||||
c for c in corrections
|
||||
if c.success_rate >= min_sr
|
||||
and c.confidence_score >= min_conf
|
||||
and (c.success_count + c.failure_count) >= min_apps
|
||||
]
|
||||
|
||||
def rank_corrections(
|
||||
self,
|
||||
corrections: List[Correction],
|
||||
sort_by: str = 'confidence'
|
||||
) -> List[Correction]:
|
||||
"""
|
||||
Rank corrections by specified criteria.
|
||||
|
||||
Args:
|
||||
corrections: List of corrections
|
||||
sort_by: Ranking criteria:
|
||||
- 'confidence': By confidence score
|
||||
- 'success_rate': By success rate
|
||||
- 'usage': By total applications
|
||||
- 'recent': By last application date
|
||||
|
||||
Returns:
|
||||
Sorted list of corrections
|
||||
"""
|
||||
if sort_by == 'confidence':
|
||||
return sorted(corrections, key=lambda c: c.confidence_score, reverse=True)
|
||||
elif sort_by == 'success_rate':
|
||||
return sorted(corrections, key=lambda c: c.success_rate, reverse=True)
|
||||
elif sort_by == 'usage':
|
||||
return sorted(corrections, key=lambda c: c.success_count + c.failure_count, reverse=True)
|
||||
elif sort_by == 'recent':
|
||||
return sorted(
|
||||
corrections,
|
||||
key=lambda c: c.last_applied or datetime.min,
|
||||
reverse=True
|
||||
)
|
||||
else:
|
||||
return corrections
|
||||
|
||||
def calculate_statistics(
|
||||
self,
|
||||
corrections: List[Correction]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculate aggregate statistics for a list of corrections.
|
||||
|
||||
Args:
|
||||
corrections: List of corrections
|
||||
|
||||
Returns:
|
||||
Statistics dict
|
||||
"""
|
||||
if not corrections:
|
||||
return {
|
||||
'count': 0,
|
||||
'total_applications': 0,
|
||||
'average_success_rate': 0.0,
|
||||
'average_confidence': 0.0,
|
||||
'by_type': {},
|
||||
'by_action': {}
|
||||
}
|
||||
|
||||
total_apps = sum(c.success_count + c.failure_count for c in corrections)
|
||||
total_success = sum(c.success_count for c in corrections)
|
||||
|
||||
# Group by type
|
||||
by_type = self.group_by_type(corrections)
|
||||
type_stats = {
|
||||
t.value: len(cs) for t, cs in by_type.items()
|
||||
}
|
||||
|
||||
# Group by action
|
||||
by_action = self.group_by_action_type(corrections)
|
||||
action_stats = {
|
||||
a: len(cs) for a, cs in by_action.items()
|
||||
}
|
||||
|
||||
return {
|
||||
'count': len(corrections),
|
||||
'total_applications': total_apps,
|
||||
'total_success': total_success,
|
||||
'total_failure': total_apps - total_success,
|
||||
'average_success_rate': total_success / total_apps if total_apps > 0 else 0.0,
|
||||
'average_confidence': sum(c.confidence_score for c in corrections) / len(corrections),
|
||||
'by_type': type_stats,
|
||||
'by_action': action_stats
|
||||
}
|
||||
|
||||
def _passes_thresholds(self, correction: Correction) -> bool:
|
||||
"""Check if a correction passes quality thresholds."""
|
||||
total = correction.success_count + correction.failure_count
|
||||
|
||||
# Always include if no applications yet (new correction)
|
||||
if total == 0:
|
||||
return True
|
||||
|
||||
# Check minimum occurrences
|
||||
if total < self.min_occurrences:
|
||||
return False
|
||||
|
||||
# Check success rate
|
||||
if correction.success_rate < self.min_success_rate:
|
||||
return False
|
||||
|
||||
# Check confidence
|
||||
if correction.confidence_score < self.min_confidence:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
    def _merge_corrections(
        self,
        corrections: List[Correction],
        strategy: str = 'weighted_average'
    ) -> Optional[Correction]:
        """
        Merge multiple corrections with the same key hash.

        Args:
            corrections: List of corrections to merge
            strategy: Merge strategy — 'best_confidence' and 'most_recent'
                pick one existing correction; any other value falls through
                to 'weighted_average', which builds a new merged correction

        Returns:
            Merged correction, or None when the input list is empty
        """
        if not corrections:
            return None

        if len(corrections) == 1:
            return corrections[0]

        if strategy == 'best_confidence':
            # Keep the single highest-confidence correction unchanged
            return max(corrections, key=lambda c: c.confidence_score)

        elif strategy == 'most_recent':
            # Keep the most recently applied; never-applied corrections
            # fall back to their creation time for comparison
            return max(
                corrections,
                key=lambda c: c.last_applied or c.created_at
            )

        else:  # weighted_average (default)
            # Use the correction with highest usage as base
            base = max(corrections, key=lambda c: c.success_count + c.failure_count)

            # Sum statistics across all duplicates
            total_success = sum(c.success_count for c in corrections)
            total_failure = sum(c.failure_count for c in corrections)

            # Find most recent application
            last_applied = None
            for c in corrections:
                if c.last_applied:
                    if last_applied is None or c.last_applied > last_applied:
                        last_applied = c.last_applied

            # Merge tags (deduplicated via a set; ordering is not preserved)
            all_tags = set()
            for c in corrections:
                all_tags.update(c.tags)

            # Create merged correction with a fresh ID.
            # NOTE(review): base.confidence_score is NOT carried over — the
            # merged Correction gets the model's default confidence; confirm
            # that is intended.
            merged = Correction(
                id=generate_correction_id(),
                key=base.key,
                correction_type=base.correction_type,
                original_target=base.original_target,
                original_params=base.original_params,
                corrected_target=base.corrected_target,
                corrected_params=base.corrected_params,
                correction_description=base.correction_description,
                source=base.source,  # Keep first source as primary
                success_count=total_success,
                failure_count=total_failure,
                last_applied=last_applied,
                status=CorrectionStatus.ACTIVE,
                created_at=min(c.created_at for c in corrections),
                updated_at=datetime.now(),
                tags=list(all_tags)
            )

            return merged
|
||||
|
||||
    def _session_correction_to_correction(
        self,
        session_correction: Dict[str, Any]
    ) -> Optional[Correction]:
        """
        Convert a session correction dict to a Correction object.

        Accepts both current field names (action_type, failure_reason,
        corrected_target, corrected_params) and legacy ones (type, reason,
        new_target, new_params) as fallbacks.

        Args:
            session_correction: Dict from TrainingSession.user_corrections

        Returns:
            Correction object, or None if conversion raised (the error is
            logged and swallowed)
        """
        try:
            # Extract key components; legacy keys act as fallbacks
            action_type = session_correction.get('action_type', session_correction.get('type', 'unknown'))
            element_type = session_correction.get('element_type', 'unknown')
            failure_context = session_correction.get('failure_reason', session_correction.get('reason', ''))

            key = CorrectionKey(action_type, element_type, failure_context)

            # Determine correction type; unrecognized values map to OTHER
            correction_type_str = session_correction.get('correction_type', 'other')
            try:
                correction_type = CorrectionType(correction_type_str)
            except ValueError:
                correction_type = CorrectionType.OTHER

            # Create source
            source = CorrectionSource(
                session_id=session_correction.get('session_id', ''),
                workflow_id=session_correction.get('workflow_id'),
                node_id=session_correction.get('node_id')
            )

            # Timestamp may arrive as an ISO string or a datetime; any other
            # type is silently ignored and the source keeps its default.
            timestamp = session_correction.get('timestamp')
            if timestamp:
                if isinstance(timestamp, str):
                    source.timestamp = datetime.fromisoformat(timestamp)
                elif isinstance(timestamp, datetime):
                    source.timestamp = timestamp

            # Create correction. A correction marked applied (default: True)
            # starts with one success; otherwise with one failure.
            return Correction(
                id=generate_correction_id(),
                key=key,
                correction_type=correction_type,
                original_target=session_correction.get('original_target'),
                original_params=session_correction.get('original_params'),
                corrected_target=session_correction.get('corrected_target', session_correction.get('new_target')),
                corrected_params=session_correction.get('corrected_params', session_correction.get('new_params')),
                correction_description=session_correction.get('description', ''),
                source=source,
                success_count=1 if session_correction.get('applied', True) else 0,
                failure_count=0 if session_correction.get('applied', True) else 1,
                tags=session_correction.get('tags', [])
            )

        except Exception as e:
            logger.warning(f"Error converting session correction: {e}")
            return None
|
||||
|
||||
|
||||
class CrossWorkflowAggregator(CorrectionAggregator):
|
||||
"""
|
||||
Extended aggregator for cross-workflow correction analysis.
|
||||
|
||||
Provides additional methods for analyzing patterns across workflows.
|
||||
"""
|
||||
|
||||
def find_common_corrections(
|
||||
self,
|
||||
corrections: List[Correction],
|
||||
min_workflows: int = 2
|
||||
) -> List[Tuple[Correction, List[str]]]:
|
||||
"""
|
||||
Find corrections that appear in multiple workflows.
|
||||
|
||||
Args:
|
||||
corrections: List of corrections
|
||||
min_workflows: Minimum number of workflows
|
||||
|
||||
Returns:
|
||||
List of (correction, workflow_ids) tuples
|
||||
"""
|
||||
# Group by key hash
|
||||
grouped: Dict[str, List[Correction]] = defaultdict(list)
|
||||
for correction in corrections:
|
||||
key_hash = correction.key.to_hash()
|
||||
grouped[key_hash].append(correction)
|
||||
|
||||
# Find common corrections
|
||||
common = []
|
||||
for key_hash, group in grouped.items():
|
||||
workflow_ids = set()
|
||||
for c in group:
|
||||
if c.source and c.source.workflow_id:
|
||||
workflow_ids.add(c.source.workflow_id)
|
||||
|
||||
if len(workflow_ids) >= min_workflows:
|
||||
# Merge the group
|
||||
merged = self._merge_corrections(group)
|
||||
if merged:
|
||||
common.append((merged, list(workflow_ids)))
|
||||
|
||||
# Sort by number of workflows
|
||||
common.sort(key=lambda x: len(x[1]), reverse=True)
|
||||
return common
|
||||
|
||||
def detect_patterns(
|
||||
self,
|
||||
corrections: List[Correction]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Detect patterns in corrections.
|
||||
|
||||
Args:
|
||||
corrections: List of corrections
|
||||
|
||||
Returns:
|
||||
Pattern analysis dict
|
||||
"""
|
||||
if not corrections:
|
||||
return {'patterns': [], 'summary': {}}
|
||||
|
||||
patterns = []
|
||||
|
||||
# Pattern 1: Common action type corrections
|
||||
by_action = self.group_by_action_type(corrections)
|
||||
for action, action_corrections in by_action.items():
|
||||
if len(action_corrections) >= 3:
|
||||
patterns.append({
|
||||
'type': 'frequent_action_correction',
|
||||
'action_type': action,
|
||||
'count': len(action_corrections),
|
||||
'avg_success_rate': sum(c.success_rate for c in action_corrections) / len(action_corrections)
|
||||
})
|
||||
|
||||
# Pattern 2: Common element type failures
|
||||
by_element: Dict[str, List[Correction]] = defaultdict(list)
|
||||
for c in corrections:
|
||||
by_element[c.key.element_type].append(c)
|
||||
|
||||
for element, element_corrections in by_element.items():
|
||||
if len(element_corrections) >= 3:
|
||||
patterns.append({
|
||||
'type': 'element_type_issues',
|
||||
'element_type': element,
|
||||
'count': len(element_corrections),
|
||||
'common_fixes': self._find_common_fix_patterns(element_corrections)
|
||||
})
|
||||
|
||||
# Pattern 3: Time-based patterns
|
||||
recent = [c for c in corrections if c.last_applied and
|
||||
(datetime.now() - c.last_applied).days < 7]
|
||||
if len(recent) > len(corrections) * 0.5:
|
||||
patterns.append({
|
||||
'type': 'recent_activity_spike',
|
||||
'recent_count': len(recent),
|
||||
'total_count': len(corrections)
|
||||
})
|
||||
|
||||
return {
|
||||
'patterns': patterns,
|
||||
'summary': {
|
||||
'total_corrections': len(corrections),
|
||||
'unique_action_types': len(by_action),
|
||||
'unique_element_types': len(by_element),
|
||||
'pattern_count': len(patterns)
|
||||
}
|
||||
}
|
||||
|
||||
def _find_common_fix_patterns(
|
||||
self,
|
||||
corrections: List[Correction]
|
||||
) -> List[str]:
|
||||
"""Find common fix patterns in a list of corrections."""
|
||||
fix_types = defaultdict(int)
|
||||
for c in corrections:
|
||||
fix_types[c.correction_type.value] += 1
|
||||
|
||||
# Return top 3 most common
|
||||
sorted_fixes = sorted(fix_types.items(), key=lambda x: x[1], reverse=True)
|
||||
return [fix[0] for fix in sorted_fixes[:3]]
|
||||
785
core/corrections/correction_pack_service.py
Normal file
785
core/corrections/correction_pack_service.py
Normal file
@@ -0,0 +1,785 @@
|
||||
"""
|
||||
Correction Pack Service - Main service for correction pack management.
|
||||
|
||||
Combines repository, aggregator, and provides high-level operations
|
||||
for CRUD, aggregation, export/import, and application of corrections.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
|
||||
from .models import (
|
||||
Correction,
|
||||
CorrectionPack,
|
||||
CorrectionKey,
|
||||
CorrectionType,
|
||||
CorrectionSource,
|
||||
CorrectionStatus,
|
||||
CorrectionPackMetadata,
|
||||
generate_correction_id,
|
||||
generate_pack_id
|
||||
)
|
||||
from .correction_repository import CorrectionRepository
|
||||
from .aggregator import CorrectionAggregator, CrossWorkflowAggregator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CorrectionPackService:
|
||||
"""
|
||||
Main service for managing correction packs.
|
||||
|
||||
Provides high-level operations combining the repository and aggregator.
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        storage_path: str = "data/correction_packs",
        training_data_path: str = "training_data"
    ):
        """
        Initialize the service.

        Args:
            storage_path: Path for correction pack storage
            training_data_path: Path to training data (for session import)
        """
        # Paths are kept as Path objects for later glob/IO use; the
        # repository receives the raw string (presumably it converts it
        # itself — confirm against CorrectionRepository).
        self.storage_path = Path(storage_path)
        self.training_data_path = Path(training_data_path)

        # Initialize components: persistence layer plus single-session and
        # cross-workflow aggregation strategies.
        self.repository = CorrectionRepository(storage_path)
        self.aggregator = CorrectionAggregator()
        self.cross_workflow_aggregator = CrossWorkflowAggregator()

        logger.info(f"CorrectionPackService initialized (storage={storage_path})")
|
||||
|
||||
# ========== Pack CRUD Operations ==========
|
||||
|
||||
def list_packs(
|
||||
self,
|
||||
category: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
List all correction packs with optional filtering.
|
||||
|
||||
Args:
|
||||
category: Filter by category
|
||||
tags: Filter by tags (match any)
|
||||
|
||||
Returns:
|
||||
List of pack summaries
|
||||
"""
|
||||
packs = self.repository.list_packs()
|
||||
|
||||
# Filter by category
|
||||
if category:
|
||||
packs = [p for p in packs if p.metadata.category == category]
|
||||
|
||||
# Filter by tags
|
||||
if tags:
|
||||
packs = [
|
||||
p for p in packs
|
||||
if any(tag in p.metadata.tags for tag in tags)
|
||||
]
|
||||
|
||||
# Return summaries
|
||||
return [
|
||||
{
|
||||
'id': p.id,
|
||||
'name': p.name,
|
||||
'description': p.description,
|
||||
'category': p.metadata.category,
|
||||
'tags': p.metadata.tags,
|
||||
'correction_count': p.correction_count,
|
||||
'success_rate': p.overall_success_rate,
|
||||
'created_at': p.created_at.isoformat(),
|
||||
'updated_at': p.updated_at.isoformat()
|
||||
}
|
||||
for p in packs
|
||||
]
|
||||
|
||||
def get_pack(self, pack_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get a pack by ID with full details.
|
||||
|
||||
Args:
|
||||
pack_id: Pack ID
|
||||
|
||||
Returns:
|
||||
Pack dict or None
|
||||
"""
|
||||
pack = self.repository.get_pack(pack_id)
|
||||
if not pack:
|
||||
return None
|
||||
return pack.to_dict()
|
||||
|
||||
def create_pack(
|
||||
self,
|
||||
name: str,
|
||||
description: str = "",
|
||||
category: str = "general",
|
||||
tags: Optional[List[str]] = None,
|
||||
author: str = ""
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a new correction pack.
|
||||
|
||||
Args:
|
||||
name: Pack name
|
||||
description: Pack description
|
||||
category: Pack category
|
||||
tags: Pack tags
|
||||
author: Pack author
|
||||
|
||||
Returns:
|
||||
Created pack dict
|
||||
"""
|
||||
metadata = {
|
||||
'category': category,
|
||||
'tags': tags or [],
|
||||
'author': author
|
||||
}
|
||||
|
||||
pack = self.repository.create_pack(name, description, metadata)
|
||||
return pack.to_dict()
|
||||
|
||||
def update_pack(
|
||||
self,
|
||||
pack_id: str,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Update a pack's properties.
|
||||
|
||||
Args:
|
||||
pack_id: Pack ID
|
||||
name: New name
|
||||
description: New description
|
||||
metadata: New metadata
|
||||
|
||||
Returns:
|
||||
Updated pack dict or None
|
||||
"""
|
||||
updates = {}
|
||||
if name is not None:
|
||||
updates['name'] = name
|
||||
if description is not None:
|
||||
updates['description'] = description
|
||||
if metadata is not None:
|
||||
updates['metadata'] = metadata
|
||||
|
||||
pack = self.repository.update_pack(pack_id, updates)
|
||||
if not pack:
|
||||
return None
|
||||
return pack.to_dict()
|
||||
|
||||
    def delete_pack(self, pack_id: str) -> bool:
        """
        Delete a pack.

        Thin delegation to the repository layer.

        Args:
            pack_id: Pack ID

        Returns:
            True if deleted
        """
        return self.repository.delete_pack(pack_id)
|
||||
|
||||
# ========== Correction Operations ==========
|
||||
|
||||
    def add_correction(
        self,
        pack_id: str,
        correction_data: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        Add a correction to a pack.

        Builds a Correction (key, type, optional source) from a plain dict
        and stores it through the repository.

        Args:
            pack_id: Pack ID
            correction_data: Correction data dict. Recognized keys:
                action_type, element_type, failure_context, correction_type,
                original_target, original_params, corrected_target,
                corrected_params, description, tags, and an optional 'source'
                dict with session_id / workflow_id / node_id.

        Returns:
            Created correction dict, or None when the repository rejects it
            or any exception occurs while building the correction (the error
            is logged)
        """
        try:
            # Build correction key; missing fields fall back to 'unknown'/''
            key = CorrectionKey(
                action_type=correction_data.get('action_type', 'unknown'),
                element_type=correction_data.get('element_type', 'unknown'),
                failure_context=correction_data.get('failure_context', '')
            )

            # Determine correction type; unrecognized strings map to OTHER
            type_str = correction_data.get('correction_type', 'other')
            try:
                correction_type = CorrectionType(type_str)
            except ValueError:
                correction_type = CorrectionType.OTHER

            # Build source if provided
            source = None
            if 'source' in correction_data:
                source = CorrectionSource(
                    session_id=correction_data['source'].get('session_id', ''),
                    workflow_id=correction_data['source'].get('workflow_id'),
                    node_id=correction_data['source'].get('node_id')
                )

            # Create correction with a fresh ID; success/failure counters
            # start at the model defaults.
            correction = Correction(
                id=generate_correction_id(),
                key=key,
                correction_type=correction_type,
                original_target=correction_data.get('original_target'),
                original_params=correction_data.get('original_params'),
                corrected_target=correction_data.get('corrected_target'),
                corrected_params=correction_data.get('corrected_params'),
                correction_description=correction_data.get('description', ''),
                source=source,
                tags=correction_data.get('tags', [])
            )

            if self.repository.add_correction(pack_id, correction):
                return correction.to_dict()
            return None

        except Exception as e:
            logger.error(f"Error adding correction: {e}")
            return None
|
||||
|
||||
def add_correction_from_session(
|
||||
self,
|
||||
pack_id: str,
|
||||
session_correction: Dict[str, Any],
|
||||
session_id: str,
|
||||
workflow_id: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Add a correction from session user_corrections format.
|
||||
|
||||
Args:
|
||||
pack_id: Pack ID
|
||||
session_correction: Correction dict from TrainingSession.user_corrections
|
||||
session_id: Session ID
|
||||
workflow_id: Workflow ID
|
||||
|
||||
Returns:
|
||||
Created correction dict or None
|
||||
"""
|
||||
# Enrich with source info
|
||||
session_correction['source'] = {
|
||||
'session_id': session_id,
|
||||
'workflow_id': workflow_id,
|
||||
'node_id': session_correction.get('node_id')
|
||||
}
|
||||
|
||||
# Map old format fields to new format
|
||||
if 'type' in session_correction and 'action_type' not in session_correction:
|
||||
session_correction['action_type'] = session_correction['type']
|
||||
if 'reason' in session_correction and 'failure_context' not in session_correction:
|
||||
session_correction['failure_context'] = session_correction['reason']
|
||||
if 'new_target' in session_correction and 'corrected_target' not in session_correction:
|
||||
session_correction['corrected_target'] = session_correction['new_target']
|
||||
if 'new_params' in session_correction and 'corrected_params' not in session_correction:
|
||||
session_correction['corrected_params'] = session_correction['new_params']
|
||||
|
||||
return self.add_correction(pack_id, session_correction)
|
||||
|
||||
def update_correction(
|
||||
self,
|
||||
pack_id: str,
|
||||
correction_id: str,
|
||||
updates: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Update a correction.
|
||||
|
||||
Args:
|
||||
pack_id: Pack ID
|
||||
correction_id: Correction ID
|
||||
updates: Fields to update
|
||||
|
||||
Returns:
|
||||
Updated correction dict or None
|
||||
"""
|
||||
correction = self.repository.update_correction(pack_id, correction_id, updates)
|
||||
if not correction:
|
||||
return None
|
||||
return correction.to_dict()
|
||||
|
||||
    def remove_correction(self, pack_id: str, correction_id: str) -> bool:
        """
        Remove a correction from a pack.

        Thin delegation to the repository layer.

        Args:
            pack_id: Pack ID
            correction_id: Correction ID

        Returns:
            True if removed
        """
        return self.repository.remove_correction(pack_id, correction_id)
|
||||
|
||||
# ========== Aggregation Operations ==========
|
||||
|
||||
    def aggregate_from_sessions(
        self,
        pack_id: str,
        session_files: Optional[List[str]] = None,
        workflow_id: Optional[str] = None,
        min_occurrences: int = 1,
        min_success_rate: float = 0.0
    ) -> Dict[str, Any]:
        """
        Aggregate corrections from training sessions into a pack.

        Reads session JSON files from the training data directory, collects
        their user_corrections, aggregates them, and stores the survivors in
        the target pack.

        Args:
            pack_id: Pack ID to add corrections to
            session_files: Specific session files to process (None = all
                files matching session_*.json under training_data_path)
            workflow_id: Filter by workflow ID
            min_occurrences: Minimum occurrences to include
            min_success_rate: Minimum success rate to include

        Returns:
            Aggregation result summary (sessions processed, corrections
            found/aggregated/added, and aggregate statistics)
        """
        # Collect all session corrections
        all_corrections = []

        if session_files is None:
            # Find all session files
            session_files = list(self.training_data_path.glob("session_*.json"))
        else:
            session_files = [Path(f) for f in session_files]

        sessions_processed = 0
        for session_file in session_files:
            try:
                with open(session_file, 'r', encoding='utf-8') as f:
                    session_data = json.load(f)

                # Filter by workflow if specified
                if workflow_id and session_data.get('workflow_id') != workflow_id:
                    continue

                user_corrections = session_data.get('user_corrections', [])
                session_id = session_data.get('session_id', session_file.stem)

                # Stamp each correction with its session/workflow of origin
                # so the aggregator can attribute and filter it.
                for uc in user_corrections:
                    uc['session_id'] = session_id
                    uc['workflow_id'] = session_data.get('workflow_id')
                    all_corrections.append(uc)

                sessions_processed += 1

            except Exception as e:
                # Unreadable files are skipped, not fatal.
                logger.warning(f"Error reading session file {session_file}: {e}")

        # Configure aggregator.
        # NOTE(review): this permanently overwrites the shared aggregator's
        # thresholds, affecting every later call on this service instance —
        # confirm this is intended rather than a per-call override.
        self.aggregator.min_occurrences = min_occurrences
        self.aggregator.min_success_rate = min_success_rate

        # Aggregate
        aggregated = self.aggregator.aggregate_from_sessions(all_corrections, workflow_id)

        # Add to pack
        added_count = 0
        for correction in aggregated:
            if self.repository.add_correction(pack_id, correction):
                added_count += 1

        return {
            'sessions_processed': sessions_processed,
            'corrections_found': len(all_corrections),
            'corrections_aggregated': len(aggregated),
            'corrections_added': added_count,
            'statistics': self.aggregator.calculate_statistics(aggregated)
        }
|
||||
|
||||
def aggregate_cross_workflow(
    self,
    pack_id: str,
    min_workflows: int = 2
) -> Dict[str, Any]:
    """
    Find and aggregate corrections common across multiple workflows.

    Corrections present in at least ``min_workflows`` distinct workflows are
    tagged with their source workflows and added to the target pack.
    Tagging is idempotent: re-running the aggregation no longer appends
    duplicate ``wf:`` tags or repeats the description suffix (previously
    both grew on every call because the corrections were mutated in place).

    Args:
        pack_id: Pack ID to add corrections to
        min_workflows: Minimum number of workflows a correction must appear in

    Returns:
        Aggregation result summary
    """
    # Get existing corrections from all packs
    all_corrections = []
    for pack in self.repository.list_packs():
        all_corrections.extend(pack.corrections.values())

    # Find common corrections
    common = self.cross_workflow_aggregator.find_common_corrections(
        all_corrections, min_workflows
    )

    # Add to pack
    added_count = 0
    for correction, workflow_ids in common:
        # Tag with source workflows (cap at 5 tags); skip tags that are
        # already present so repeated runs do not accumulate duplicates.
        for workflow_id in workflow_ids[:5]:
            tag = f"wf:{workflow_id}"
            if tag not in correction.tags:
                correction.tags.append(tag)

        # Annotate the description exactly once.
        suffix = f" (common in {len(workflow_ids)} workflows)"
        if not correction.correction_description.endswith(suffix):
            correction.correction_description += suffix

        if self.repository.add_correction(pack_id, correction):
            added_count += 1

    # Detect patterns
    patterns = self.cross_workflow_aggregator.detect_patterns(all_corrections)

    return {
        'total_corrections_analyzed': len(all_corrections),
        'common_corrections_found': len(common),
        'corrections_added': added_count,
        'patterns': patterns
    }
|
||||
|
||||
# ========== Search and Apply Operations ==========
|
||||
|
||||
def find_applicable_corrections(
    self,
    action_type: str,
    element_type: str,
    failure_context: str,
    pack_ids: Optional[List[str]] = None,
    min_confidence: float = 0.3
) -> List[Dict[str, Any]]:
    """
    Find corrections applicable to a failure context.

    Args:
        action_type: Type of action that failed
        element_type: Type of element involved
        failure_context: Description of failure
        pack_ids: Specific packs to search (None = all)
        min_confidence: Minimum confidence score

    Returns:
        List of {'pack_id', 'correction'} dicts
    """
    matches = self.repository.find_by_context(action_type, element_type, failure_context)

    # Optional pack whitelist; an empty/None list means "search everywhere".
    allowed = set(pack_ids) if pack_ids else None

    applicable = []
    for owner_id, correction in matches:
        if allowed is not None and owner_id not in allowed:
            continue
        if correction.confidence_score < min_confidence:
            continue
        applicable.append({
            'pack_id': owner_id,
            'correction': correction.to_dict()
        })
    return applicable
|
||||
|
||||
def apply_correction(
    self,
    pack_id: str,
    correction_id: str,
    success: bool,
    application_context: Optional[Dict[str, Any]] = None
) -> bool:
    """
    Record an application of a correction.

    Args:
        pack_id: Pack ID
        correction_id: Correction ID
        success: Whether the application was successful
        application_context: Optional context information
            (NOTE(review): currently unused by this method — confirm
            whether it should be persisted with the application record)

    Returns:
        True if recorded
    """
    # The repository owns the success/failure counters and persistence.
    self.repository.record_application(pack_id, correction_id, success)

    logger.info(
        f"Recorded correction application: pack={pack_id}, "
        f"correction={correction_id}, success={success}"
    )
    # Always True: record_application returns no failure signal, even for
    # unknown pack/correction ids (it silently ignores them).
    return True
|
||||
|
||||
def search_corrections(
    self,
    query: str,
    pack_id: Optional[str] = None
) -> List[Dict[str, Any]]:
    """
    Search corrections by text query.

    Args:
        query: Search query
        pack_id: Optional pack ID restricting the search

    Returns:
        List of {'pack_id', 'correction'} dicts
    """
    matches = self.repository.search_corrections(query, pack_id)

    formatted = []
    for owner_id, correction in matches:
        formatted.append({
            'pack_id': owner_id,
            'correction': correction.to_dict()
        })
    return formatted
|
||||
|
||||
# ========== Export/Import Operations ==========
|
||||
|
||||
def export_pack(
    self,
    pack_id: str,
    format: str = 'json',
    include_stats: bool = True
) -> Optional[str]:
    """
    Export a pack to JSON or YAML.

    Args:
        pack_id: Pack ID
        format: Export format ('json' or 'yaml')
        include_stats: Whether to include statistics
            (NOTE(review): not forwarded to the repository — this flag
            currently has no effect on the exported content; confirm
            intended behavior)

    Returns:
        Exported content string or None
    """
    return self.repository.export_pack(pack_id, format)
|
||||
|
||||
def import_pack(
    self,
    content: str,
    format: str = 'json',
    merge_strategy: str = 'create_new'
) -> Optional[Dict[str, Any]]:
    """
    Import a pack from JSON or YAML.

    Args:
        content: Content to import
        format: Import format ('json' or 'yaml')
        merge_strategy: How to handle an existing pack with the same ID

    Returns:
        Imported pack as a dict, or None if the import failed
    """
    imported = self.repository.import_pack(content, format, merge_strategy)
    return imported.to_dict() if imported else None
|
||||
|
||||
def export_pack_to_file(
    self,
    pack_id: str,
    file_path: str,
    format: str = 'json'
) -> bool:
    """
    Export a pack to a file.

    Args:
        pack_id: Pack ID
        file_path: Output file path
        format: Export format

    Returns:
        True if successful
    """
    # Serialize first; a falsy result means the pack does not exist
    # (or produced no content), so nothing is written.
    content = self.export_pack(pack_id, format)
    if not content:
        return False

    try:
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(content)
        logger.info(f"Exported pack {pack_id} to {file_path}")
        return True
    except Exception as e:
        # Best-effort: I/O failures are logged and reported via False.
        logger.error(f"Error exporting pack to file: {e}")
        return False
|
||||
|
||||
def import_pack_from_file(
    self,
    file_path: str,
    format: Optional[str] = None,
    merge_strategy: str = 'create_new'
) -> Optional[Dict[str, Any]]:
    """
    Import a pack from a file.

    Args:
        file_path: Input file path
        format: Import format (auto-detected from the file extension if None)
        merge_strategy: How to handle existing pack

    Returns:
        Imported pack dict, or None (unreadable file, or content that
        failed to import)
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()

        # Auto-detect format from the extension; anything that is not
        # .yaml/.yml is treated as JSON.
        if format is None:
            if file_path.endswith('.yaml') or file_path.endswith('.yml'):
                format = 'yaml'
            else:
                format = 'json'

        result = self.import_pack(content, format, merge_strategy)
        if result:
            logger.info(f"Imported pack from {file_path}")
        return result

    except Exception as e:
        logger.error(f"Error importing pack from file: {e}")
        return None
|
||||
|
||||
# ========== Version Operations ==========
|
||||
|
||||
def create_version(
    self,
    pack_id: str,
    description: str = ""
) -> int:
    """
    Create a version snapshot of a pack.

    Thin delegation: snapshot layout and numbering are owned by the
    repository layer.

    Args:
        pack_id: Pack ID
        description: Version description

    Returns:
        Version number (-1 on error)
    """
    return self.repository.create_version(pack_id, description)
|
||||
|
||||
def list_versions(self, pack_id: str) -> List[Dict[str, Any]]:
    """
    List all versions of a pack.

    Thin delegation to the repository's version store.

    Args:
        pack_id: Pack ID

    Returns:
        List of version info dicts
    """
    return self.repository.list_versions(pack_id)
|
||||
|
||||
def rollback_pack(self, pack_id: str, version: int) -> bool:
    """
    Rollback a pack to a previous version.

    Thin delegation: restore semantics live in the repository layer.

    Args:
        pack_id: Pack ID
        version: Version number

    Returns:
        True if successful
    """
    return self.repository.rollback_pack(pack_id, version)
|
||||
|
||||
# ========== Statistics and Analysis ==========
|
||||
|
||||
def get_pack_statistics(self, pack_id: str) -> Optional[Dict[str, Any]]:
    """
    Get detailed statistics for a pack.

    Args:
        pack_id: Pack ID

    Returns:
        Statistics dict, or None when the pack does not exist
    """
    pack = self.repository.get_pack(pack_id)
    return pack.get_statistics() if pack else None
|
||||
|
||||
def get_global_statistics(self) -> Dict[str, Any]:
    """
    Get global statistics across all packs.

    Returns:
        Global statistics dict
    """
    packs = self.repository.list_packs()

    # Single pass over all packs for the three totals.
    total_corrections = 0
    total_applications = 0
    total_success = 0
    for pack in packs:
        total_corrections += pack.correction_count
        total_applications += pack.total_applications
        for correction in pack.corrections.values():
            total_success += correction.success_count

    # Guard against division by zero when nothing was ever applied.
    if total_applications > 0:
        overall_rate = total_success / total_applications
    else:
        overall_rate = 0.0

    return {
        'total_packs': len(packs),
        'total_corrections': total_corrections,
        'total_applications': total_applications,
        'overall_success_rate': overall_rate,
        'packs_by_category': self._group_packs_by_category(packs),
        'top_corrections': self._get_top_corrections(packs, 10)
    }
|
||||
|
||||
def _group_packs_by_category(
|
||||
self,
|
||||
packs: List[CorrectionPack]
|
||||
) -> Dict[str, int]:
|
||||
"""Group packs by category."""
|
||||
by_category = {}
|
||||
for pack in packs:
|
||||
cat = pack.metadata.category
|
||||
by_category[cat] = by_category.get(cat, 0) + 1
|
||||
return by_category
|
||||
|
||||
def _get_top_corrections(
|
||||
self,
|
||||
packs: List[CorrectionPack],
|
||||
limit: int = 10
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get top corrections by confidence score."""
|
||||
all_corrections = []
|
||||
for pack in packs:
|
||||
for correction in pack.corrections.values():
|
||||
all_corrections.append({
|
||||
'pack_id': pack.id,
|
||||
'pack_name': pack.name,
|
||||
'correction': correction
|
||||
})
|
||||
|
||||
# Sort by confidence
|
||||
all_corrections.sort(key=lambda x: x['correction'].confidence_score, reverse=True)
|
||||
|
||||
# Return top N
|
||||
return [
|
||||
{
|
||||
'pack_id': item['pack_id'],
|
||||
'pack_name': item['pack_name'],
|
||||
'correction_id': item['correction'].id,
|
||||
'action_type': item['correction'].key.action_type,
|
||||
'confidence': item['correction'].confidence_score,
|
||||
'success_rate': item['correction'].success_rate,
|
||||
'applications': item['correction'].success_count + item['correction'].failure_count
|
||||
}
|
||||
for item in all_corrections[:limit]
|
||||
]
|
||||
643
core/corrections/correction_repository.py
Normal file
643
core/corrections/correction_repository.py
Normal file
@@ -0,0 +1,643 @@
|
||||
"""
|
||||
Correction Repository - JSON storage for corrections and packs.
|
||||
|
||||
Provides atomic file operations for safe persistence.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
from .models import (
|
||||
Correction,
|
||||
CorrectionPack,
|
||||
CorrectionKey,
|
||||
generate_pack_id
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CorrectionRepository:
    """
    Repository for storing and retrieving correction packs.

    Uses atomic JSON writes (temp file + rename) for safety.
    Maintains an index by key hash for fast lookups.
    """

    def __init__(self, storage_path: str = "data/correction_packs"):
        """
        Initialize the repository.

        Args:
            storage_path: Base path for storing correction data
        """
        self.storage_path = Path(storage_path)
        # Single JSON file holding every pack (written atomically).
        self.packs_file = self.storage_path / "packs.json"
        # NOTE(review): corrections_dir is created but never referenced
        # again in this class — confirm whether it is still needed.
        self.corrections_dir = self.storage_path / "corrections"
        # Per-pack version snapshots live under versions/<pack_id>/vN.json.
        self.versions_dir = self.storage_path / "versions"

        # Ensure directories exist
        self._ensure_directories()

        # In-memory cache
        self._packs: Dict[str, CorrectionPack] = {}
        self._key_index: Dict[str, List[str]] = {}  # key_hash -> [correction_ids]

        # Load existing data
        self._load_packs()
        self._rebuild_index()

    def _ensure_directories(self):
        """Ensure all required directories exist."""
        self.storage_path.mkdir(parents=True, exist_ok=True)
        self.corrections_dir.mkdir(parents=True, exist_ok=True)
        self.versions_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ========== Pack Operations ==========
|
||||
|
||||
def list_packs(self) -> List[CorrectionPack]:
    """Return every correction pack currently held in the in-memory cache."""
    return [pack for pack in self._packs.values()]
|
||||
|
||||
def get_pack(self, pack_id: str) -> Optional[CorrectionPack]:
    """Look up a pack by its ID; None when absent."""
    try:
        return self._packs[pack_id]
    except KeyError:
        return None
|
||||
|
||||
def create_pack(self, name: str, description: str = "",
                metadata: Optional[Dict[str, Any]] = None) -> CorrectionPack:
    """
    Create a new correction pack.

    Args:
        name: Pack name
        description: Pack description
        metadata: Optional metadata dict (recognized keys: 'version',
            'author', 'category', 'tags'; anything else is ignored)

    Returns:
        Created pack
    """
    pack_id = generate_pack_id()
    pack = CorrectionPack(
        id=pack_id,
        name=name,
        description=description
    )

    if metadata:
        # Only a fixed set of metadata keys is honored; missing keys
        # fall back to defaults.
        pack.metadata.version = metadata.get('version', '1.0.0')
        pack.metadata.author = metadata.get('author', '')
        pack.metadata.category = metadata.get('category', 'general')
        pack.metadata.tags = metadata.get('tags', [])

    self._packs[pack_id] = pack
    # Persist immediately so the new pack survives a process crash.
    self._save_packs()

    logger.info(f"Created correction pack: {pack_id} ({name})")
    return pack
|
||||
|
||||
def update_pack(self, pack_id: str, updates: Dict[str, Any]) -> Optional[CorrectionPack]:
    """
    Update a pack's properties.

    Only whitelisted fields are applied ('name', 'description' and the
    metadata sub-keys 'version', 'author', 'category', 'tags'); anything
    else in ``updates`` is silently ignored.

    Args:
        pack_id: Pack ID
        updates: Dict of fields to update

    Returns:
        Updated pack or None if not found
    """
    pack = self._packs.get(pack_id)
    if not pack:
        return None

    # Update allowed fields
    if 'name' in updates:
        pack.name = updates['name']
    if 'description' in updates:
        pack.description = updates['description']
    if 'metadata' in updates:
        metadata = updates['metadata']
        if 'version' in metadata:
            pack.metadata.version = metadata['version']
        if 'author' in metadata:
            pack.metadata.author = metadata['author']
        if 'category' in metadata:
            pack.metadata.category = metadata['category']
        if 'tags' in metadata:
            pack.metadata.tags = metadata['tags']

    # Timestamp is bumped and the store saved even when no key matched.
    pack.updated_at = datetime.now()
    self._save_packs()

    logger.info(f"Updated correction pack: {pack_id}")
    return pack
|
||||
|
||||
def delete_pack(self, pack_id: str) -> bool:
    """
    Delete a pack and all its corrections.

    Args:
        pack_id: Pack ID

    Returns:
        True if deleted, False if not found
    """
    if pack_id not in self._packs:
        return False

    pack = self._packs[pack_id]

    # Remove this pack's correction ids from the key-hash index.
    # NOTE(review): buckets emptied here remain as empty lists; harmless,
    # and _rebuild_index would drop them entirely.
    for correction in pack.corrections.values():
        key_hash = correction.key.to_hash()
        if key_hash in self._key_index:
            if correction.id in self._key_index[key_hash]:
                self._key_index[key_hash].remove(correction.id)

    del self._packs[pack_id]
    self._save_packs()

    # Remove version history
    version_dir = self.versions_dir / pack_id
    if version_dir.exists():
        shutil.rmtree(version_dir)

    logger.info(f"Deleted correction pack: {pack_id}")
    return True
|
||||
|
||||
# ========== Correction Operations ==========
|
||||
|
||||
def add_correction(self, pack_id: str, correction: Correction) -> bool:
    """
    Add a correction to a pack.

    Args:
        pack_id: Pack ID
        correction: Correction to add

    Returns:
        True if added, False if pack not found or correction exists
    """
    pack = self._packs.get(pack_id)
    if not pack:
        return False

    # The pack itself decides duplicate handling (False when it refuses).
    if not pack.add_correction(correction):
        return False

    # Update index: key hash -> correction ids (deduplicated append).
    key_hash = correction.key.to_hash()
    if key_hash not in self._key_index:
        self._key_index[key_hash] = []
    if correction.id not in self._key_index[key_hash]:
        self._key_index[key_hash].append(correction.id)

    self._save_packs()
    logger.info(f"Added correction {correction.id} to pack {pack_id}")
    return True
|
||||
|
||||
def remove_correction(self, pack_id: str, correction_id: str) -> bool:
    """
    Remove a correction from a pack.

    Args:
        pack_id: Pack ID
        correction_id: Correction ID

    Returns:
        True if removed, False if not found
    """
    pack = self._packs.get(pack_id)
    if not pack:
        return False

    correction = pack.get_correction(correction_id)
    if not correction:
        return False

    # Drop the id from the key-hash index before removing the correction.
    key_hash = correction.key.to_hash()
    if key_hash in self._key_index:
        if correction_id in self._key_index[key_hash]:
            self._key_index[key_hash].remove(correction_id)

    pack.remove_correction(correction_id)
    self._save_packs()

    logger.info(f"Removed correction {correction_id} from pack {pack_id}")
    return True
|
||||
|
||||
def update_correction(self, pack_id: str, correction_id: str,
                      updates: Dict[str, Any]) -> Optional[Correction]:
    """
    Update a correction's properties.

    Only whitelisted fields are applied. 'status' is parsed into the
    CorrectionStatus enum, so an unknown status string raises ValueError.

    Args:
        pack_id: Pack ID
        correction_id: Correction ID
        updates: Dict of fields to update

    Returns:
        Updated correction or None if not found
    """
    pack = self._packs.get(pack_id)
    if not pack:
        return None

    correction = pack.get_correction(correction_id)
    if not correction:
        return None

    # Update allowed fields
    if 'corrected_target' in updates:
        correction.corrected_target = updates['corrected_target']
    if 'corrected_params' in updates:
        correction.corrected_params = updates['corrected_params']
    if 'correction_description' in updates:
        correction.correction_description = updates['correction_description']
    if 'status' in updates:
        # CorrectionStatus is not part of the module-level models import,
        # so it is imported lazily here.
        from .models import CorrectionStatus
        correction.status = CorrectionStatus(updates['status'])
    if 'tags' in updates:
        correction.tags = updates['tags']

    # Touch both objects so callers can see when either last changed.
    correction.updated_at = datetime.now()
    pack.updated_at = datetime.now()
    self._save_packs()

    logger.info(f"Updated correction {correction_id} in pack {pack_id}")
    return correction
|
||||
|
||||
def record_application(self, pack_id: str, correction_id: str, success: bool):
    """
    Record an application of a correction.

    Unknown pack or correction ids are silently ignored (best-effort).

    Args:
        pack_id: Pack ID
        correction_id: Correction ID
        success: Whether the application was successful
    """
    pack = self._packs.get(pack_id)
    correction = pack.get_correction(correction_id) if pack else None
    if correction is None:
        return

    correction.record_application(success)
    pack.updated_at = datetime.now()
    self._save_packs()
|
||||
|
||||
# ========== Search Operations ==========
|
||||
|
||||
def find_by_key_hash(self, key_hash: str) -> List[tuple]:
    """
    Find all corrections matching a key hash.

    Args:
        key_hash: MD5 hash of correction key

    Returns:
        List of (pack_id, correction) tuples, best confidence first
    """
    matches = [
        (pid, correction)
        for pid, pack in self._packs.items()
        for correction in pack.find_by_key_hash(key_hash)
    ]
    # Most trustworthy corrections come first.
    matches.sort(key=lambda entry: entry[1].confidence_score, reverse=True)
    return matches
|
||||
|
||||
def find_by_context(self, action_type: str, element_type: str,
                    failure_context: str) -> List[tuple]:
    """
    Find corrections matching a failure context.

    Builds the canonical CorrectionKey for the context and reuses the
    hash lookup, so results are sorted best-confidence-first.

    Args:
        action_type: Type of action that failed
        element_type: Type of element involved
        failure_context: Description of failure

    Returns:
        List of (pack_id, correction) tuples
    """
    context_key = CorrectionKey(action_type, element_type, failure_context)
    key_hash = context_key.to_hash()
    return self.find_by_key_hash(key_hash)
|
||||
|
||||
def search_corrections(self, query: str, pack_id: Optional[str] = None) -> List[tuple]:
    """
    Search corrections by text query.

    A correction matches if the (case-insensitive) query occurs in its
    description, any of its tags, or any component of its key.

    Args:
        query: Search query
        pack_id: Optional pack ID to search in (unknown ids fall back to
            searching every pack, matching the original behavior)

    Returns:
        List of (pack_id, correction) tuples
    """
    needle = query.lower()

    def _matches(correction) -> bool:
        # Check description, then tags, then key components.
        if needle in correction.correction_description.lower():
            return True
        if any(needle in tag.lower() for tag in correction.tags):
            return True
        key = correction.key
        return (needle in key.action_type.lower()
                or needle in key.element_type.lower()
                or needle in key.failure_context.lower())

    if pack_id and pack_id in self._packs:
        candidates = [self._packs[pack_id]]
    else:
        candidates = list(self._packs.values())

    return [
        (pack.id, correction)
        for pack in candidates
        for correction in pack.corrections.values()
        if _matches(correction)
    ]
|
||||
|
||||
# ========== Version Operations ==========
|
||||
|
||||
def create_version(self, pack_id: str, description: str = "") -> int:
    """
    Create a version snapshot of a pack.

    Args:
        pack_id: Pack ID
        description: Version description

    Returns:
        Version number, or -1 when the pack does not exist
    """
    pack = self._packs.get(pack_id)
    if not pack:
        return -1

    # Create version directory
    pack_versions = self.versions_dir / pack_id
    pack_versions.mkdir(parents=True, exist_ok=True)

    # The snapshot is labeled with the pre-increment number, but the
    # counter is bumped before to_dict(), so the stored pack_snapshot
    # already carries current_version + 1.
    version = pack.current_version
    pack.current_version += 1

    # Save snapshot
    version_file = pack_versions / f"v{version}.json"
    version_data = {
        'version': version,
        'description': description,
        'created_at': datetime.now().isoformat(),
        'pack_snapshot': pack.to_dict()
    }

    self._atomic_write(version_file, version_data)
    self._save_packs()

    logger.info(f"Created version {version} for pack {pack_id}")
    return version
|
||||
|
||||
def list_versions(self, pack_id: str) -> List[Dict[str, Any]]:
    """
    List all versions of a pack.

    Versions are returned in ascending numeric order; the previous
    lexicographic sort of filenames ordered "v10" before "v2".

    Args:
        pack_id: Pack ID

    Returns:
        List of version info dicts
    """
    versions = []
    pack_versions = self.versions_dir / pack_id

    if not pack_versions.exists():
        return versions

    def _version_number(path: Path) -> int:
        # "v12.json" -> 12; malformed names sort first.
        suffix = path.stem[1:]
        return int(suffix) if suffix.isdigit() else -1

    for version_file in sorted(pack_versions.glob("v*.json"), key=_version_number):
        try:
            with open(version_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            versions.append({
                'version': data.get('version'),
                'description': data.get('description', ''),
                'created_at': data.get('created_at'),
                'correction_count': len(data.get('pack_snapshot', {}).get('corrections', {}))
            })
        except Exception as e:
            # A single corrupt snapshot should not hide the others.
            logger.warning(f"Error reading version file {version_file}: {e}")

    return versions
|
||||
|
||||
def rollback_pack(self, pack_id: str, version: int) -> bool:
    """
    Rollback a pack to a previous version.

    Args:
        pack_id: Pack ID
        version: Version number to rollback to

    Returns:
        True if successful, False otherwise
    """
    version_file = self.versions_dir / pack_id / f"v{version}.json"

    if not version_file.exists():
        logger.warning(f"Version {version} not found for pack {pack_id}")
        return False

    try:
        with open(version_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Restore pack from snapshot
        snapshot = data.get('pack_snapshot', {})
        pack = CorrectionPack.from_dict(snapshot)

        # Keep the version counter monotonic: continue numbering from the
        # live pack rather than from the (older) snapshot.
        if pack_id in self._packs:
            pack.current_version = self._packs[pack_id].current_version + 1

        self._packs[pack_id] = pack
        # The restored pack may contain different corrections, so the
        # key-hash index is rebuilt from scratch.
        self._rebuild_index()
        self._save_packs()

        logger.info(f"Rolled back pack {pack_id} to version {version}")
        return True

    except Exception as e:
        logger.error(f"Error rolling back pack {pack_id} to version {version}: {e}")
        return False
|
||||
|
||||
# ========== Persistence ==========
|
||||
|
||||
def _load_packs(self):
    """Load all packs from storage.

    A missing packs file is fine (fresh install). Individual malformed
    pack entries are skipped with a warning instead of failing the whole
    load.
    """
    if not self.packs_file.exists():
        return

    try:
        with open(self.packs_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        for pack_id, pack_data in data.items():
            try:
                self._packs[pack_id] = CorrectionPack.from_dict(pack_data)
            except Exception as e:
                logger.warning(f"Error loading pack {pack_id}: {e}")

        logger.info(f"Loaded {len(self._packs)} correction packs")

    except Exception as e:
        # Unreadable/corrupt file: start with an empty cache rather than
        # crashing construction.
        logger.error(f"Error loading packs file: {e}")
|
||||
|
||||
def _save_packs(self):
|
||||
"""Save all packs to storage using atomic write."""
|
||||
try:
|
||||
data = {
|
||||
pack_id: pack.to_dict()
|
||||
for pack_id, pack in self._packs.items()
|
||||
}
|
||||
self._atomic_write(self.packs_file, data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving packs: {e}")
|
||||
|
||||
def _atomic_write(self, file_path: Path, data: Dict[str, Any]):
|
||||
"""
|
||||
Write data to file atomically (temp file + rename).
|
||||
|
||||
Args:
|
||||
file_path: Target file path
|
||||
data: Data to write
|
||||
"""
|
||||
temp_file = file_path.with_suffix('.tmp')
|
||||
|
||||
try:
|
||||
with open(temp_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
# Atomic rename
|
||||
temp_file.replace(file_path)
|
||||
|
||||
except Exception as e:
|
||||
# Clean up temp file on error
|
||||
if temp_file.exists():
|
||||
temp_file.unlink()
|
||||
raise e
|
||||
|
||||
def _rebuild_index(self):
|
||||
"""Rebuild the key hash index from all packs."""
|
||||
self._key_index.clear()
|
||||
|
||||
for pack in self._packs.values():
|
||||
for correction in pack.corrections.values():
|
||||
key_hash = correction.key.to_hash()
|
||||
if key_hash not in self._key_index:
|
||||
self._key_index[key_hash] = []
|
||||
if correction.id not in self._key_index[key_hash]:
|
||||
self._key_index[key_hash].append(correction.id)
|
||||
|
||||
# ========== Export/Import ==========
|
||||
|
||||
def export_pack(self, pack_id: str, format: str = 'json') -> Optional[str]:
    """
    Export a pack to JSON or YAML format.

    Args:
        pack_id: Pack ID
        format: Export format ('json' or 'yaml'); when 'yaml' is requested
            but PyYAML is not installed, the output silently falls back
            to JSON (a warning is logged)

    Returns:
        Exported content as string, or None if pack not found
    """
    pack = self._packs.get(pack_id)
    if not pack:
        return None

    # Envelope with a format marker so import_pack can validate it.
    export_data = {
        'export_format': 'correction_pack',
        'export_version': '1.0',
        'exported_at': datetime.now().isoformat(),
        'pack': pack.to_dict()
    }

    if format == 'yaml':
        try:
            import yaml
            return yaml.dump(export_data, default_flow_style=False, allow_unicode=True)
        except ImportError:
            logger.warning("PyYAML not installed, falling back to JSON")
            return json.dumps(export_data, indent=2, ensure_ascii=False)
    else:
        return json.dumps(export_data, indent=2, ensure_ascii=False)
|
||||
|
||||
def import_pack(self, content: str, format: str = 'json',
                merge_strategy: str = 'create_new') -> Optional[CorrectionPack]:
    """
    Import a pack from JSON or YAML content.

    Args:
        content: Content to import
        format: Import format ('json' or 'yaml')
        merge_strategy: How to handle existing pack:
            - 'create_new': Always create new pack with new ID
            - 'merge': Merge into existing pack if same ID
            - 'replace': Replace existing pack if same ID

    Returns:
        Imported pack, or None on error
    """
    try:
        if format == 'yaml':
            try:
                import yaml
                # safe_load: never executes arbitrary YAML tags.
                data = yaml.safe_load(content)
            except ImportError:
                logger.error("PyYAML not installed")
                return None
        else:
            data = json.loads(content)

        # Validate the envelope written by export_pack.
        if data.get('export_format') != 'correction_pack':
            logger.warning("Invalid export format")
            return None

        pack_data = data.get('pack', {})
        imported_pack = CorrectionPack.from_dict(pack_data)

        # Handle merge strategy
        if merge_strategy == 'create_new':
            imported_pack.id = generate_pack_id()
            self._packs[imported_pack.id] = imported_pack
        elif merge_strategy == 'merge' and imported_pack.id in self._packs:
            existing_pack = self._packs[imported_pack.id]
            existing_pack.merge(imported_pack)
            imported_pack = existing_pack
        elif merge_strategy == 'replace':
            self._packs[imported_pack.id] = imported_pack
        else:
            # 'merge' with an unknown ID (and any unrecognized strategy)
            # falls through to create-new with a fresh ID.
            imported_pack.id = generate_pack_id()
            self._packs[imported_pack.id] = imported_pack

        self._rebuild_index()
        self._save_packs()

        logger.info(f"Imported pack: {imported_pack.id} ({imported_pack.name})")
        return imported_pack

    except Exception as e:
        logger.error(f"Error importing pack: {e}")
        return None
|
||||
480
core/corrections/models.py
Normal file
480
core/corrections/models.py
Normal file
@@ -0,0 +1,480 @@
|
||||
"""
|
||||
Data models for the Correction Packs system.
|
||||
|
||||
Defines CorrectionKey, Correction, CorrectionPack and related types.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import uuid
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
|
||||
class CorrectionType(Enum):
    """Categories of user corrections that can be stored and replayed."""

    TARGET_CHANGE = "target_change"            # Changed target selector
    PARAMETER_CHANGE = "parameter_change"      # Changed action parameters
    ACTION_REPLACE = "action_replace"          # Replaced action type
    WAIT_ADDED = "wait_added"                  # Added wait before action
    RETRY_CONFIG = "retry_config"              # Changed retry configuration
    FALLBACK_ADDED = "fallback_added"          # Added fallback action
    COORDINATES_ADJUST = "coordinates_adjust"  # Adjusted click coordinates
    TIMING_ADJUST = "timing_adjust"            # Adjusted timing/delay
    OTHER = "other"                            # Anything not covered above
|
||||
|
||||
|
||||
class CorrectionStatus(Enum):
    """Lifecycle status of a correction.

    Only ACTIVE corrections are returned by pack lookups
    (CorrectionPack.active_corrections / find_by_key_hash).
    """
    ACTIVE = "active"  # Correction is active and can be applied
    DEPRECATED = "deprecated"  # Correction is deprecated
    DISABLED = "disabled"  # Correction is manually disabled
||||
|
||||
|
||||
@dataclass
class CorrectionKey:
    """
    Unique key identifying similar corrections across workflows.

    The identity of a key is the triple (action_type, element_type,
    failure_context); to_hash() condenses it into a short MD5 digest,
    mirroring the LearningRepository pattern-key scheme.
    """
    action_type: str
    element_type: str
    failure_context: str

    def to_hash(self) -> str:
        """Return the 16-character MD5 digest of this key."""
        joined = "|".join((self.action_type, self.element_type, self.failure_context))
        return hashlib.md5(joined.encode()).hexdigest()[:16]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict; includes the derived 'hash' field."""
        return {
            'action_type': self.action_type,
            'element_type': self.element_type,
            'failure_context': self.failure_context,
            'hash': self.to_hash(),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionKey':
        """Rebuild a key from a dict produced by to_dict().

        Missing components fall back to 'unknown' / '' so deserialization
        never raises; the stored 'hash' field is ignored and recomputed.
        """
        return cls(
            action_type=data.get('action_type', 'unknown'),
            element_type=data.get('element_type', 'unknown'),
            failure_context=data.get('failure_context', ''),
        )
|
||||
|
||||
|
||||
@dataclass
class CorrectionSource:
    """Provenance of a correction: which session/workflow/node produced it, and when."""
    session_id: str
    workflow_id: Optional[str] = None
    node_id: Optional[str] = None
    timestamp: datetime = field(default_factory=datetime.now)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict with an ISO-8601 timestamp string."""
        return {
            'session_id': self.session_id,
            'workflow_id': self.workflow_id,
            'node_id': self.node_id,
            'timestamp': self.timestamp.isoformat(),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionSource':
        """Rebuild from a dict produced by to_dict().

        The timestamp may be an ISO string (parsed), None (defaults to now),
        or already a datetime (passed through unchanged).
        """
        raw_ts = data.get('timestamp')
        if isinstance(raw_ts, str):
            ts = datetime.fromisoformat(raw_ts)
        elif raw_ts is None:
            ts = datetime.now()
        else:
            ts = raw_ts
        return cls(
            session_id=data.get('session_id', ''),
            workflow_id=data.get('workflow_id'),
            node_id=data.get('node_id'),
            timestamp=ts,
        )
|
||||
|
||||
|
||||
@dataclass
class Correction:
    """
    Individual correction record.

    Contains the original context, the applied correction, source information,
    and statistics about success/failure. Serialization is via to_dict() /
    from_dict(); derived values (success_rate, confidence_score) are included
    in to_dict() output for consumers but recomputed, not parsed back.
    """
    id: str
    key: CorrectionKey
    correction_type: CorrectionType

    # Original context
    original_target: Optional[Dict[str, Any]] = None
    original_params: Optional[Dict[str, Any]] = None

    # Correction applied
    corrected_target: Optional[Dict[str, Any]] = None
    corrected_params: Optional[Dict[str, Any]] = None
    correction_description: str = ""

    # Source information
    source: Optional[CorrectionSource] = None

    # Statistics
    success_count: int = 0
    failure_count: int = 0
    last_applied: Optional[datetime] = None

    # Status
    status: CorrectionStatus = CorrectionStatus.ACTIVE

    # Metadata
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    tags: List[str] = field(default_factory=list)

    @property
    def success_rate(self) -> float:
        """Fraction of recorded applications that succeeded (0.0 if never applied)."""
        total = self.success_count + self.failure_count
        if total == 0:
            return 0.0
        return self.success_count / total

    @property
    def confidence_score(self) -> float:
        """Calculate confidence score based on success rate and usage count.

        Scales success_rate by a usage factor so that a correction with few
        data points scores at most half its raw success rate; full weight is
        reached at 100 recorded applications.
        """
        total = self.success_count + self.failure_count
        if total == 0:
            return 0.0
        # Confidence increases with more data points (up to 100 uses)
        usage_factor = min(1.0, total / 100)
        return self.success_rate * (0.5 + 0.5 * usage_factor)

    def record_application(self, success: bool):
        """Record an application of this correction and refresh timestamps."""
        if success:
            self.success_count += 1
        else:
            self.failure_count += 1
        self.last_applied = datetime.now()
        self.updated_at = datetime.now()

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a JSON-serializable dictionary (datetimes as ISO strings)."""
        return {
            'id': self.id,
            'key': self.key.to_dict(),
            'correction_type': self.correction_type.value,
            'original_target': self.original_target,
            'original_params': self.original_params,
            'corrected_target': self.corrected_target,
            'corrected_params': self.corrected_params,
            'correction_description': self.correction_description,
            'source': self.source.to_dict() if self.source else None,
            'success_count': self.success_count,
            'failure_count': self.failure_count,
            # Derived values exported for consumers; ignored by from_dict().
            'success_rate': self.success_rate,
            'confidence_score': self.confidence_score,
            'last_applied': self.last_applied.isoformat() if self.last_applied else None,
            'status': self.status.value,
            'created_at': self.created_at.isoformat(),
            'updated_at': self.updated_at.isoformat(),
            'tags': self.tags
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Correction':
        """Create from a dictionary produced by to_dict().

        Tolerates missing fields: timestamps default to now, a missing id is
        replaced with a fresh UUID, and counters default to zero.
        """
        # Parse datetime fields
        created_at = data.get('created_at')
        if isinstance(created_at, str):
            created_at = datetime.fromisoformat(created_at)
        elif created_at is None:
            created_at = datetime.now()

        updated_at = data.get('updated_at')
        if isinstance(updated_at, str):
            updated_at = datetime.fromisoformat(updated_at)
        elif updated_at is None:
            updated_at = datetime.now()

        # last_applied stays None when absent (correction never applied yet)
        last_applied = data.get('last_applied')
        if isinstance(last_applied, str):
            last_applied = datetime.fromisoformat(last_applied)

        # Parse nested objects
        key_data = data.get('key', {})
        key = CorrectionKey.from_dict(key_data)

        source_data = data.get('source')
        source = CorrectionSource.from_dict(source_data) if source_data else None

        return cls(
            id=data.get('id', str(uuid.uuid4())),
            key=key,
            correction_type=CorrectionType(data.get('correction_type', 'other')),
            original_target=data.get('original_target'),
            original_params=data.get('original_params'),
            corrected_target=data.get('corrected_target'),
            corrected_params=data.get('corrected_params'),
            correction_description=data.get('correction_description', ''),
            source=source,
            success_count=data.get('success_count', 0),
            failure_count=data.get('failure_count', 0),
            last_applied=last_applied,
            status=CorrectionStatus(data.get('status', 'active')),
            created_at=created_at,
            updated_at=updated_at,
            tags=data.get('tags', [])
        )
|
||||
|
||||
|
||||
@dataclass
class CorrectionPackMetadata:
    """Descriptive metadata attached to a correction pack (version, author, etc.)."""
    version: str = "1.0.0"
    author: str = ""
    category: str = "general"
    tags: List[str] = field(default_factory=list)
    description: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict."""
        return {
            'version': self.version,
            'author': self.author,
            'category': self.category,
            'tags': self.tags,
            'description': self.description,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionPackMetadata':
        """Rebuild from a dict; absent fields fall back to the dataclass defaults."""
        fallback = cls()
        return cls(
            version=data.get('version', fallback.version),
            author=data.get('author', fallback.author),
            category=data.get('category', fallback.category),
            tags=data.get('tags', []),
            description=data.get('description', fallback.description),
        )
|
||||
|
||||
|
||||
@dataclass
class CorrectionPack:
    """
    A pack of corrections that can be exported, imported, and shared.

    Contains multiple corrections (keyed by correction id) with metadata and
    aggregate statistics. Supports merging with configurable conflict
    resolution and round-trip serialization via to_dict() / from_dict().
    """
    id: str
    name: str
    description: str = ""

    # Corrections in this pack, keyed by Correction.id
    corrections: Dict[str, Correction] = field(default_factory=dict)

    # Metadata
    metadata: CorrectionPackMetadata = field(default_factory=CorrectionPackMetadata)

    # Timestamps
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)

    # Version history
    current_version: int = 1

    @property
    def correction_count(self) -> int:
        """Number of corrections in the pack (any status)."""
        return len(self.corrections)

    @property
    def total_applications(self) -> int:
        """Total number of times corrections were applied (successes + failures)."""
        return sum(c.success_count + c.failure_count for c in self.corrections.values())

    @property
    def overall_success_rate(self) -> float:
        """Overall success rate across all corrections (0.0 if never applied)."""
        total_success = sum(c.success_count for c in self.corrections.values())
        total_failure = sum(c.failure_count for c in self.corrections.values())
        total = total_success + total_failure
        if total == 0:
            return 0.0
        return total_success / total

    @property
    def active_corrections(self) -> List[Correction]:
        """Get only corrections whose status is ACTIVE."""
        return [c for c in self.corrections.values() if c.status == CorrectionStatus.ACTIVE]

    def add_correction(self, correction: Correction) -> bool:
        """
        Add a correction to the pack.

        Returns True if added, False if a correction with the same id
        already exists (the existing one is kept untouched).
        """
        if correction.id in self.corrections:
            return False
        self.corrections[correction.id] = correction
        self.updated_at = datetime.now()
        return True

    def remove_correction(self, correction_id: str) -> bool:
        """
        Remove a correction from the pack.

        Returns True if removed, False if not found.
        """
        if correction_id not in self.corrections:
            return False
        del self.corrections[correction_id]
        self.updated_at = datetime.now()
        return True

    def get_correction(self, correction_id: str) -> Optional[Correction]:
        """Get a correction by ID, or None if absent."""
        return self.corrections.get(correction_id)

    def find_by_key_hash(self, key_hash: str) -> List[Correction]:
        """Find ACTIVE corrections whose CorrectionKey hash matches key_hash."""
        return [
            c for c in self.corrections.values()
            if c.key.to_hash() == key_hash and c.status == CorrectionStatus.ACTIVE
        ]

    def merge(self, other: 'CorrectionPack', conflict_strategy: str = 'keep_higher_confidence') -> int:
        """
        Merge another pack into this one.

        Corrections unknown to this pack are always added; for id collisions
        the conflict_strategy decides.

        Args:
            other: Pack to merge
            conflict_strategy: How to handle conflicts:
                - 'keep_higher_confidence': Keep correction with higher confidence
                - 'keep_newer': Keep the most recently updated correction
                - 'keep_existing': Keep existing, ignore new
                - 'replace': Always replace with new

        Returns:
            Number of corrections added/updated
        """
        merged_count = 0

        for correction_id, correction in other.corrections.items():
            if correction_id in self.corrections:
                existing = self.corrections[correction_id]

                should_replace = False
                if conflict_strategy == 'keep_higher_confidence':
                    should_replace = correction.confidence_score > existing.confidence_score
                elif conflict_strategy == 'keep_newer':
                    should_replace = correction.updated_at > existing.updated_at
                elif conflict_strategy == 'replace':
                    should_replace = True
                # 'keep_existing' doesn't replace

                if should_replace:
                    self.corrections[correction_id] = correction
                    merged_count += 1
            else:
                # No collision: incoming correction is simply added.
                self.corrections[correction_id] = correction
                merged_count += 1

        # Only bump the timestamp when the merge actually changed something.
        if merged_count > 0:
            self.updated_at = datetime.now()

        return merged_count

    def get_statistics(self) -> Dict[str, Any]:
        """Get aggregate statistics for the pack (counts, rates, type breakdown)."""
        corrections = list(self.corrections.values())
        active = [c for c in corrections if c.status == CorrectionStatus.ACTIVE]

        # Group by correction type
        by_type = {}
        for c in corrections:
            type_name = c.correction_type.value
            if type_name not in by_type:
                by_type[type_name] = 0
            by_type[type_name] += 1

        return {
            'total_corrections': len(corrections),
            'active_corrections': len(active),
            'deprecated_corrections': len([c for c in corrections if c.status == CorrectionStatus.DEPRECATED]),
            'disabled_corrections': len([c for c in corrections if c.status == CorrectionStatus.DISABLED]),
            'total_applications': self.total_applications,
            'overall_success_rate': self.overall_success_rate,
            'corrections_by_type': by_type,
            # Average confidence over ACTIVE corrections only.
            'average_confidence': sum(c.confidence_score for c in active) / len(active) if active else 0.0
        }

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a JSON-serializable dictionary (includes derived statistics)."""
        return {
            'id': self.id,
            'name': self.name,
            'description': self.description,
            'corrections': {cid: c.to_dict() for cid, c in self.corrections.items()},
            'metadata': self.metadata.to_dict(),
            'created_at': self.created_at.isoformat(),
            'updated_at': self.updated_at.isoformat(),
            'current_version': self.current_version,
            'statistics': self.get_statistics()
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionPack':
        """Create from a dictionary produced by to_dict().

        The 'statistics' entry is ignored (it is derived and recomputed);
        missing timestamps default to now and a missing id gets a fresh UUID.
        """
        # Parse datetime fields
        created_at = data.get('created_at')
        if isinstance(created_at, str):
            created_at = datetime.fromisoformat(created_at)
        elif created_at is None:
            created_at = datetime.now()

        updated_at = data.get('updated_at')
        if isinstance(updated_at, str):
            updated_at = datetime.fromisoformat(updated_at)
        elif updated_at is None:
            updated_at = datetime.now()

        # Parse corrections
        corrections_data = data.get('corrections', {})
        corrections = {
            cid: Correction.from_dict(cdata)
            for cid, cdata in corrections_data.items()
        }

        # Parse metadata
        metadata_data = data.get('metadata', {})
        metadata = CorrectionPackMetadata.from_dict(metadata_data)

        return cls(
            id=data.get('id', str(uuid.uuid4())),
            name=data.get('name', 'Unnamed Pack'),
            description=data.get('description', ''),
            corrections=corrections,
            metadata=metadata,
            created_at=created_at,
            updated_at=updated_at,
            current_version=data.get('current_version', 1)
        )
|
||||
|
||||
|
||||
def generate_correction_id() -> str:
    """Generate a unique correction ID of the form ``corr_<12 hex chars>``."""
    suffix = uuid.uuid4().hex[:12]
    return "corr_" + suffix
|
||||
|
||||
|
||||
def generate_pack_id() -> str:
    """Generate a unique pack ID of the form ``pack_<12 hex chars>``."""
    suffix = uuid.uuid4().hex[:12]
    return "pack_" + suffix
|
||||
443
core/healing/healing_engine.py
Normal file
443
core/healing/healing_engine.py
Normal file
@@ -0,0 +1,443 @@
|
||||
"""Main self-healing engine for workflow recovery."""
|
||||
|
||||
import time
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
from pathlib import Path
|
||||
from .models import RecoveryContext, RecoveryResult, RecoverySuggestion
|
||||
from .learning_repository import LearningRepository
|
||||
from .confidence_scorer import ConfidenceScorer
|
||||
from .strategies import (
|
||||
RecoveryStrategy,
|
||||
SemanticVariantStrategy,
|
||||
SpatialFallbackStrategy,
|
||||
TimingAdaptationStrategy,
|
||||
FormatTransformationStrategy
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SelfHealingEngine:
    """Main engine for self-healing workflows.

    Orchestrates three recovery sources, tried in order inside
    attempt_recovery():
      1. correction packs (high-confidence, user-derived fixes),
      2. learned patterns from the LearningRepository,
      3. the built-in recovery strategies, prioritized by historical success.
    """

    def __init__(
        self,
        learning_repo: Optional[LearningRepository] = None,
        storage_path: Optional[Path] = None,
        correction_packs_enabled: bool = True
    ):
        """
        Initialize the self-healing engine.

        Args:
            learning_repo: Learning repository instance (a default one backed
                by ``data/healing`` is created when omitted)
            storage_path: Path for storing learned patterns
            correction_packs_enabled: Whether to consult correction packs
        """
        # Initialize learning repository
        if learning_repo:
            self.learning_repo = learning_repo
        else:
            storage_path = storage_path or Path('data/healing')
            self.learning_repo = LearningRepository(storage_path)

        # Initialize confidence scorer
        self.confidence_scorer = ConfidenceScorer()

        # Initialize recovery strategies (base set; order is re-ranked per
        # failure by _prioritize_strategies)
        self.recovery_strategies: List[RecoveryStrategy] = [
            SemanticVariantStrategy(),
            SpatialFallbackStrategy(),
            TimingAdaptationStrategy(),
            FormatTransformationStrategy()
        ]

        # Configuration
        self.max_recovery_time = 30.0  # seconds
        self.parallel_execution = True

        # Correction packs integration (service is lazy-loaded on first use)
        self.correction_packs_enabled = correction_packs_enabled
        self._correction_pack_service = None

    @property
    def correction_pack_service(self):
        """Lazy-load correction pack service.

        On import failure, correction packs are permanently disabled for this
        engine instance and None is returned.
        """
        if self._correction_pack_service is None and self.correction_packs_enabled:
            try:
                from core.corrections import CorrectionPackService
                self._correction_pack_service = CorrectionPackService()
                logger.info("CorrectionPackService initialized for healing engine")
            except ImportError as e:
                logger.warning(f"CorrectionPackService not available: {e}")
                self.correction_packs_enabled = False
        return self._correction_pack_service

    def attempt_recovery(self, context: RecoveryContext) -> RecoveryResult:
        """
        Attempt to recover from a workflow failure.

        Tries correction packs first, then each prioritized strategy until
        one succeeds with sufficient confidence, the time budget
        (max_recovery_time) runs out, or all strategies fail.

        Args:
            context: Recovery context with failure information

        Returns:
            RecoveryResult with outcome
        """
        start_time = time.time()

        # Check if we've exceeded max attempts
        if context.attempt_count >= context.max_attempts:
            return RecoveryResult(
                success=False,
                strategy_used='none',
                error_message=f'Max recovery attempts ({context.max_attempts}) exceeded'
            )

        # Try correction packs first (if enabled)
        if self.correction_packs_enabled and self.correction_pack_service:
            pack_result = self._try_correction_packs(context)
            if pack_result and pack_result.success:
                return pack_result

        # Get learned patterns for this context
        learned_patterns = self.learning_repo.get_matching_patterns(context)

        # Get prioritized strategies
        strategies = self._prioritize_strategies(context, learned_patterns)

        # Try each strategy
        for strategy in strategies:
            # Check time limit
            elapsed = time.time() - start_time
            if elapsed >= self.max_recovery_time:
                break

            # Skip if strategy can't handle this failure
            if not strategy.can_handle(context):
                continue

            # Attempt recovery
            result = strategy.attempt_recovery(context)

            # If successful, learn from it
            if result.success:
                # Calculate final confidence
                historical_success = self._get_historical_success_rate(
                    strategy.name, context, learned_patterns
                )
                result.confidence_score = self.confidence_scorer.calculate_recovery_confidence(
                    result.strategy_used,
                    context,
                    historical_success
                )

                # Check if safe to proceed
                if self.confidence_scorer.is_safe_to_proceed(
                    result.confidence_score,
                    context.confidence_threshold,
                    involves_data_modification=self._involves_data_modification(context)
                ):
                    # Learn from success
                    self.learn_from_success(context, result)
                    return result
                else:
                    # Confidence too low, mark as requiring user input
                    # (returned immediately; remaining strategies are not tried)
                    result.requires_user_input = True
                    return result

        # All strategies failed
        total_time = time.time() - start_time
        return RecoveryResult(
            success=False,
            strategy_used='all_failed',
            execution_time=total_time,
            error_message='All recovery strategies failed',
            requires_user_input=True
        )

    def learn_from_success(self, context: RecoveryContext, result: RecoveryResult):
        """
        Learn from successful recovery for future use.

        Args:
            context: Recovery context
            result: Successful recovery result
        """
        self.learning_repo.store_pattern(context, result)

    def get_recovery_suggestions(self, context: RecoveryContext) -> List[RecoverySuggestion]:
        """
        Get ranked recovery suggestions based on learned patterns.

        Args:
            context: Recovery context

        Returns:
            List of recovery suggestions sorted by confidence (descending)
        """
        suggestions = []

        # Get learned patterns
        learned_patterns = self.learning_repo.get_matching_patterns(context)

        # Get suggestions from each strategy
        for strategy in self.recovery_strategies:
            if not strategy.can_handle(context):
                continue

            # Calculate confidence based on historical success
            historical_success = self._get_historical_success_rate(
                strategy.name, context, learned_patterns
            )
            confidence = self.confidence_scorer.calculate_recovery_confidence(
                strategy.name,
                context,
                historical_success
            )

            suggestion = RecoverySuggestion(
                strategy=strategy.name,
                confidence=confidence,
                description=self._get_strategy_description(strategy),
                estimated_time=self._estimate_strategy_time(strategy, context)
            )
            suggestions.append(suggestion)

        # Sort by confidence
        suggestions.sort(key=lambda s: s.confidence, reverse=True)

        return suggestions

    def prune_learned_patterns(
        self,
        max_age_days: int = 90,
        min_confidence: float = 0.3
    ):
        """
        Prune outdated learned patterns.

        Args:
            max_age_days: Maximum age for patterns
            min_confidence: Minimum confidence threshold
        """
        self.learning_repo.prune_outdated_patterns(max_age_days, min_confidence)

    def _prioritize_strategies(
        self,
        context: RecoveryContext,
        learned_patterns: List
    ) -> List[RecoveryStrategy]:
        """Prioritize strategies based on context and learned patterns.

        A strategy's base priority is multiplied by (1 + historical success
        rate) so previously successful strategies are tried first.
        """
        # Score each strategy
        scored_strategies = []
        for strategy in self.recovery_strategies:
            # Base priority from strategy
            priority = strategy.get_priority(context)

            # Boost priority if we have successful patterns
            historical_success = self._get_historical_success_rate(
                strategy.name, context, learned_patterns
            )
            if historical_success > 0:
                priority *= (1.0 + historical_success)

            scored_strategies.append((priority, strategy))

        # Sort by priority (highest first)
        scored_strategies.sort(key=lambda x: x[0], reverse=True)

        return [strategy for _, strategy in scored_strategies]

    def _get_historical_success_rate(
        self,
        strategy_name: str,
        context: RecoveryContext,
        learned_patterns: List
    ) -> float:
        """Get historical success rate for a strategy (0.0 when no history)."""
        matching_patterns = [
            p for p in learned_patterns
            if p.recovery_strategy == strategy_name
        ]

        if not matching_patterns:
            return 0.0

        # Return average success rate
        return sum(p.success_rate for p in matching_patterns) / len(matching_patterns)

    def _involves_data_modification(self, context: RecoveryContext) -> bool:
        """Check if the action involves data modification.

        Substring match against the original action name; used to demand a
        higher safety bar in is_safe_to_proceed().
        """
        data_modification_actions = [
            'input', 'type', 'submit', 'delete', 'update', 'save'
        ]
        return any(action in context.original_action.lower()
                   for action in data_modification_actions)

    def _get_strategy_description(self, strategy: RecoveryStrategy) -> str:
        """Get human-readable description of strategy (falls back to its name)."""
        descriptions = {
            'SemanticVariantStrategy': 'Try semantic variants of the element (e.g., Submit → Send)',
            'SpatialFallbackStrategy': 'Search in expanded area around original position',
            'TimingAdaptationStrategy': 'Increase wait time for element to appear',
            'FormatTransformationStrategy': 'Transform input format to match validation'
        }
        return descriptions.get(strategy.name, strategy.name)

    def _estimate_strategy_time(
        self,
        strategy: RecoveryStrategy,
        context: RecoveryContext
    ) -> float:
        """Estimate execution time for strategy.

        Fixed per-strategy estimates; 3.0 s for unknown strategies.
        """
        # Simple estimates in seconds
        estimates = {
            'SemanticVariantStrategy': 2.0,
            'SpatialFallbackStrategy': 5.0,
            'TimingAdaptationStrategy': 10.0,
            'FormatTransformationStrategy': 1.0
        }
        return estimates.get(strategy.name, 3.0)

    def _try_correction_packs(self, context: RecoveryContext) -> Optional[RecoveryResult]:
        """
        Try to find and apply a correction from correction packs.

        Candidate corrections are attempted in the order returned by the
        service; each attempt's outcome is reported back via
        apply_correction() so pack statistics stay accurate.

        Args:
            context: Recovery context

        Returns:
            RecoveryResult if a correction was successfully applied, None otherwise
        """
        try:
            # Get element type from metadata
            element_type = context.metadata.get('element_type', 'unknown')

            # Search for applicable corrections
            corrections = self.correction_pack_service.find_applicable_corrections(
                action_type=context.original_action,
                element_type=element_type,
                failure_context=context.failure_reason,
                min_confidence=0.5  # Only use high-confidence corrections
            )

            if not corrections:
                return None

            # Try corrections in order of confidence
            for correction_info in corrections:
                pack_id = correction_info['pack_id']
                correction = correction_info['correction']

                try:
                    # Apply the correction
                    applied = self._apply_correction(context, correction)

                    if applied:
                        # Record successful application
                        self.correction_pack_service.apply_correction(
                            pack_id=pack_id,
                            correction_id=correction['id'],
                            success=True
                        )

                        # Propagate to learning repository
                        result = RecoveryResult(
                            success=True,
                            strategy_used=f"correction_pack:{correction['correction_type']}",
                            confidence_score=correction['confidence_score'],
                            modified_action=correction.get('corrected_params'),
                            modified_target=correction.get('corrected_target')
                        )

                        # Learn from this for future reference
                        self.learning_repo.store_pattern(context, result)

                        logger.info(
                            f"Applied correction from pack {pack_id}: "
                            f"{correction['id']} (confidence: {correction['confidence_score']:.2f})"
                        )
                        return result

                except Exception as e:
                    logger.warning(f"Failed to apply correction {correction['id']}: {e}")
                    # Record failed application
                    self.correction_pack_service.apply_correction(
                        pack_id=pack_id,
                        correction_id=correction['id'],
                        success=False
                    )

            return None

        except Exception as e:
            logger.error(f"Error in correction pack lookup: {e}")
            return None

    def _apply_correction(
        self,
        context: RecoveryContext,
        correction: Dict[str, Any]
    ) -> bool:
        """
        Apply a correction to the current context.

        The correction does not execute anything itself: it annotates
        context.metadata with the adjusted target/params/timing for the
        caller to act on.

        Args:
            context: Recovery context
            correction: Correction data dict

        Returns:
            True if correction was applied successfully
        """
        correction_type = correction.get('correction_type', 'other')

        # Handle different correction types
        if correction_type == 'target_change':
            # Update target in context
            if correction.get('corrected_target'):
                context.metadata['corrected_target'] = correction['corrected_target']
                return True

        elif correction_type == 'parameter_change':
            # Update parameters in context
            if correction.get('corrected_params'):
                context.metadata['corrected_params'] = correction['corrected_params']
                return True

        elif correction_type == 'wait_added':
            # Add wait time
            wait_time = correction.get('corrected_params', {}).get('wait_time', 2.0)
            context.metadata['additional_wait'] = wait_time
            return True

        elif correction_type == 'timing_adjust':
            # Adjust timing
            timing = correction.get('corrected_params', {}).get('timing_multiplier', 1.5)
            context.metadata['timing_multiplier'] = timing
            return True

        elif correction_type == 'coordinates_adjust':
            # Adjust coordinates
            offset = correction.get('corrected_params', {}).get('offset', {})
            context.metadata['coordinate_offset'] = offset
            return True

        # For other types, just mark as applied if we have any correction data
        return bool(correction.get('corrected_target') or correction.get('corrected_params'))

    def get_correction_pack_statistics(self) -> Optional[Dict[str, Any]]:
        """
        Get statistics from correction packs.

        Returns:
            Statistics dict or None if not available (packs disabled or
            service errored)
        """
        if not self.correction_packs_enabled or not self.correction_pack_service:
            return None

        try:
            return self.correction_pack_service.get_global_statistics()
        except Exception as e:
            logger.error(f"Error getting correction pack statistics: {e}")
            return None
|
||||
644
tests/test_correction_packs.py
Normal file
644
tests/test_correction_packs.py
Normal file
@@ -0,0 +1,644 @@
|
||||
"""
|
||||
Tests for the Correction Packs system.
|
||||
|
||||
Tests models, repository, aggregator, and service.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from core.corrections.models import (
|
||||
CorrectionKey,
|
||||
CorrectionType,
|
||||
CorrectionStatus,
|
||||
Correction,
|
||||
CorrectionPack,
|
||||
CorrectionPackMetadata,
|
||||
CorrectionSource,
|
||||
generate_correction_id,
|
||||
generate_pack_id
|
||||
)
|
||||
from core.corrections.correction_repository import CorrectionRepository
|
||||
from core.corrections.aggregator import CorrectionAggregator, CrossWorkflowAggregator
|
||||
from core.corrections.correction_pack_service import CorrectionPackService
|
||||
|
||||
|
||||
class TestCorrectionKey(unittest.TestCase):
    """Unit tests for the CorrectionKey model."""

    def test_to_hash(self):
        """Hash generation is deterministic and depends on the key fields."""
        first = CorrectionKey("click", "button", "element_not_found")
        duplicate = CorrectionKey("click", "button", "element_not_found")
        different = CorrectionKey("type", "input", "element_not_found")

        # Identical keys hash identically; distinct keys must not collide here.
        self.assertEqual(first.to_hash(), duplicate.to_hash())
        self.assertNotEqual(first.to_hash(), different.to_hash())
        # The hash is truncated to 16 characters.
        self.assertEqual(len(first.to_hash()), 16)

    def test_to_dict_from_dict(self):
        """Round-tripping through dict form preserves the key."""
        original = CorrectionKey("click", "button", "timeout")
        payload = original.to_dict()

        for field, expected in (
            ('action_type', "click"),
            ('element_type', "button"),
            ('failure_context', "timeout"),
        ):
            self.assertEqual(payload[field], expected)
        self.assertIn('hash', payload)

        rebuilt = CorrectionKey.from_dict(payload)
        self.assertEqual(original.to_hash(), rebuilt.to_hash())
|
||||
|
||||
|
||||
class TestCorrection(unittest.TestCase):
    """Unit tests for the Correction model."""

    def test_create_correction(self):
        """A freshly built correction exposes sane defaults."""
        ckey = CorrectionKey("click", "button", "not_visible")
        corr = Correction(
            id=generate_correction_id(),
            key=ckey,
            correction_type=CorrectionType.WAIT_ADDED,
            original_target={"selector": "#submit"},
            corrected_params={"wait_time": 2.0},
            correction_description="Added wait for button visibility"
        )

        # Generated ids carry the "corr_" prefix.
        self.assertTrue(corr.id.startswith("corr_"))
        self.assertEqual(corr.correction_type, CorrectionType.WAIT_ADDED)
        # Nothing applied yet -> rate is zero.
        self.assertEqual(corr.success_rate, 0.0)

    def test_success_rate_calculation(self):
        """success_rate is success_count over total applications."""
        corr = Correction(
            id=generate_correction_id(),
            key=CorrectionKey("click", "button", "not_visible"),
            correction_type=CorrectionType.WAIT_ADDED,
            success_count=7,
            failure_count=3
        )

        # 7 successes out of 10 recorded applications.
        self.assertEqual(corr.success_rate, 0.7)

    def test_record_application(self):
        """Recording an application updates counters and the timestamp."""
        corr = Correction(
            id=generate_correction_id(),
            key=CorrectionKey("click", "button", "not_visible"),
            correction_type=CorrectionType.WAIT_ADDED
        )

        corr.record_application(success=True)
        self.assertEqual(corr.success_count, 1)
        self.assertEqual(corr.failure_count, 0)
        self.assertIsNotNone(corr.last_applied)

        corr.record_application(success=False)
        self.assertEqual(corr.success_count, 1)
        self.assertEqual(corr.failure_count, 1)

    def test_to_dict_from_dict(self):
        """Serialization round-trip preserves every field."""
        ckey = CorrectionKey("type", "input", "validation_error")
        origin = CorrectionSource(
            session_id="sess_123",
            workflow_id="wf_456",
            node_id="node_789"
        )
        corr = Correction(
            id="corr_test123",
            key=ckey,
            correction_type=CorrectionType.PARAMETER_CHANGE,
            original_params={"value": "test"},
            corrected_params={"value": "corrected"},
            source=origin,
            success_count=5,
            failure_count=2,
            tags=["form", "validation"]
        )

        rebuilt = Correction.from_dict(corr.to_dict())

        self.assertEqual(rebuilt.id, corr.id)
        self.assertEqual(rebuilt.key.to_hash(), ckey.to_hash())
        self.assertEqual(rebuilt.correction_type, CorrectionType.PARAMETER_CHANGE)
        self.assertEqual(rebuilt.success_count, 5)
        self.assertEqual(rebuilt.tags, ["form", "validation"])
|
||||
|
||||
|
||||
class TestCorrectionPack(unittest.TestCase):
    """Tests for CorrectionPack model."""

    def test_create_pack(self):
        """Test creating a pack."""
        pack = CorrectionPack(
            id=generate_pack_id(),
            name="Test Pack",
            description="A test correction pack"
        )

        # Generated ids carry the "pack_" prefix; a new pack starts empty.
        self.assertTrue(pack.id.startswith("pack_"))
        self.assertEqual(pack.name, "Test Pack")
        self.assertEqual(pack.correction_count, 0)

    def test_add_remove_correction(self):
        """Test adding and removing corrections."""
        pack = CorrectionPack(
            id=generate_pack_id(),
            name="Test Pack"
        )

        key = CorrectionKey("click", "button", "not_found")
        correction = Correction(
            id="corr_test1",
            key=key,
            correction_type=CorrectionType.TARGET_CHANGE
        )

        # Add correction
        self.assertTrue(pack.add_correction(correction))
        self.assertEqual(pack.correction_count, 1)

        # Adding same correction should return False (dedup by id)
        self.assertFalse(pack.add_correction(correction))

        # Remove correction
        self.assertTrue(pack.remove_correction("corr_test1"))
        self.assertEqual(pack.correction_count, 0)

        # Removing non-existent should return False, not raise
        self.assertFalse(pack.remove_correction("corr_test1"))

    def test_find_by_key_hash(self):
        """Test finding corrections by key hash."""
        pack = CorrectionPack(id=generate_pack_id(), name="Test Pack")

        key1 = CorrectionKey("click", "button", "not_found")
        key2 = CorrectionKey("type", "input", "validation")

        correction1 = Correction(id="corr_1", key=key1, correction_type=CorrectionType.TARGET_CHANGE)
        correction2 = Correction(id="corr_2", key=key2, correction_type=CorrectionType.PARAMETER_CHANGE)

        pack.add_correction(correction1)
        pack.add_correction(correction2)

        # Find by key1 hash -- only the matching correction is returned
        found = pack.find_by_key_hash(key1.to_hash())
        self.assertEqual(len(found), 1)
        self.assertEqual(found[0].id, "corr_1")

    def test_merge_packs(self):
        """Test merging two packs."""
        pack1 = CorrectionPack(id="pack_1", name="Pack 1")
        pack2 = CorrectionPack(id="pack_2", name="Pack 2")

        key = CorrectionKey("click", "button", "error")
        corr1 = Correction(id="corr_1", key=key, correction_type=CorrectionType.TARGET_CHANGE)
        corr2 = Correction(id="corr_2", key=key, correction_type=CorrectionType.WAIT_ADDED)

        pack1.add_correction(corr1)
        pack2.add_correction(corr2)

        # Only corr_2 is new to pack1, so exactly one correction merges in
        merged_count = pack1.merge(pack2)
        self.assertEqual(merged_count, 1)
        self.assertEqual(pack1.correction_count, 2)

    def test_get_statistics(self):
        """Test getting pack statistics."""
        pack = CorrectionPack(id=generate_pack_id(), name="Test Pack")

        key = CorrectionKey("click", "button", "error")
        correction = Correction(
            id="corr_1",
            key=key,
            correction_type=CorrectionType.TARGET_CHANGE,
            success_count=10,
            failure_count=2
        )
        pack.add_correction(correction)

        stats = pack.get_statistics()
        self.assertEqual(stats['total_corrections'], 1)
        self.assertEqual(stats['active_corrections'], 1)
        # total_applications = success_count + failure_count (10 + 2)
        self.assertEqual(stats['total_applications'], 12)
|
||||
|
||||
|
||||
class TestCorrectionRepository(unittest.TestCase):
    """Tests for CorrectionRepository."""

    def setUp(self):
        """Create temporary directory for tests."""
        # Each test gets an isolated on-disk store.
        self.test_dir = tempfile.mkdtemp()
        self.repo = CorrectionRepository(self.test_dir)

    def tearDown(self):
        """Clean up temporary directory."""
        shutil.rmtree(self.test_dir)

    def test_create_and_get_pack(self):
        """Test creating and retrieving a pack."""
        pack = self.repo.create_pack("Test Pack", "Description", {
            'category': 'testing',
            'tags': ['unit', 'test']
        })

        self.assertIsNotNone(pack)
        self.assertEqual(pack.name, "Test Pack")

        # The pack must be readable back from storage.
        retrieved = self.repo.get_pack(pack.id)
        self.assertIsNotNone(retrieved)
        self.assertEqual(retrieved.name, "Test Pack")

    def test_list_packs(self):
        """Test listing packs."""
        self.repo.create_pack("Pack 1", "First pack")
        self.repo.create_pack("Pack 2", "Second pack")

        packs = self.repo.list_packs()
        self.assertEqual(len(packs), 2)

    def test_update_pack(self):
        """Test updating a pack."""
        pack = self.repo.create_pack("Original Name")

        updated = self.repo.update_pack(pack.id, {'name': 'New Name'})
        self.assertIsNotNone(updated)
        self.assertEqual(updated.name, 'New Name')

    def test_delete_pack(self):
        """Test deleting a pack."""
        pack = self.repo.create_pack("To Delete")

        self.assertTrue(self.repo.delete_pack(pack.id))
        self.assertIsNone(self.repo.get_pack(pack.id))
        # Deleting twice must report failure, not raise.
        self.assertFalse(self.repo.delete_pack(pack.id))

    def test_add_and_find_correction(self):
        """Test adding and finding corrections."""
        pack = self.repo.create_pack("Test Pack")

        key = CorrectionKey("click", "button", "timeout")
        correction = Correction(
            id=generate_correction_id(),
            key=key,
            correction_type=CorrectionType.TIMING_ADJUST
        )

        self.assertTrue(self.repo.add_correction(pack.id, correction))

        # Find by context
        # find_by_context returns (pack, correction) pairs, hence found[0][1].
        found = self.repo.find_by_context("click", "button", "timeout")
        self.assertEqual(len(found), 1)
        self.assertEqual(found[0][1].id, correction.id)

    def test_export_import_pack(self):
        """Test exporting and importing a pack."""
        # Create pack with correction
        pack = self.repo.create_pack("Export Test")
        key = CorrectionKey("click", "button", "error")
        correction = Correction(
            id="corr_export",
            key=key,
            correction_type=CorrectionType.TARGET_CHANGE
        )
        self.repo.add_correction(pack.id, correction)

        # Export
        exported = self.repo.export_pack(pack.id)
        self.assertIsNotNone(exported)

        # Import as new: must get a fresh id but keep the corrections.
        imported = self.repo.import_pack(exported, merge_strategy='create_new')
        self.assertIsNotNone(imported)
        self.assertNotEqual(imported.id, pack.id)
        self.assertEqual(imported.correction_count, 1)

    def test_version_and_rollback(self):
        """Test versioning and rollback."""
        pack = self.repo.create_pack("Version Test")

        # Add correction
        key = CorrectionKey("click", "button", "error")
        correction1 = Correction(id="corr_v1", key=key, correction_type=CorrectionType.TARGET_CHANGE)
        self.repo.add_correction(pack.id, correction1)

        # Create version -- first snapshot is numbered 1
        version = self.repo.create_version(pack.id, "First version")
        self.assertEqual(version, 1)

        # Add another correction
        correction2 = Correction(id="corr_v2", key=key, correction_type=CorrectionType.WAIT_ADDED)
        self.repo.add_correction(pack.id, correction2)

        # Verify 2 corrections
        pack = self.repo.get_pack(pack.id)
        self.assertEqual(pack.correction_count, 2)

        # Rollback to version 1
        self.assertTrue(self.repo.rollback_pack(pack.id, 1))

        # Should have 1 correction again
        pack = self.repo.get_pack(pack.id)
        self.assertEqual(pack.correction_count, 1)
|
||||
|
||||
|
||||
class TestCorrectionAggregator(unittest.TestCase):
    """Unit tests for CorrectionAggregator."""

    def setUp(self):
        """Build an aggregator with permissive quality thresholds."""
        self.aggregator = CorrectionAggregator(
            min_occurrences=2,
            min_success_rate=0.5,
            min_confidence=0.3
        )

    def test_aggregate_by_key(self):
        """Corrections sharing a key collapse into one; counters are summed."""
        shared_key = CorrectionKey("click", "button", "not_found")
        batch = [
            Correction(
                id=f"corr_{idx}",
                key=shared_key,
                correction_type=CorrectionType.WAIT_ADDED,
                success_count=5,
                failure_count=2
            )
            for idx in range(3)
        ]

        merged = self.aggregator.aggregate_corrections(batch)

        # Three duplicates become one entry; 3 * 5 successes are summed.
        self.assertEqual(len(merged), 1)
        self.assertEqual(merged[0].success_count, 15)

    def test_filter_by_quality(self):
        """Corrections under the success-rate floor are dropped."""
        shared_key = CorrectionKey("click", "button", "error")
        good = Correction(
            id="corr_high",
            key=shared_key,
            correction_type=CorrectionType.TARGET_CHANGE,
            success_count=9,
            failure_count=1
        )
        bad = Correction(
            id="corr_low",
            key=shared_key,
            correction_type=CorrectionType.TARGET_CHANGE,
            success_count=1,
            failure_count=9
        )

        kept = self.aggregator.filter_by_quality(
            [good, bad],
            min_success_rate=0.7
        )

        # Only the 90%-success correction survives a 0.7 floor.
        self.assertEqual(len(kept), 1)
        self.assertEqual(kept[0].id, "corr_high")

    def test_rank_corrections(self):
        """Ranking by success_rate orders best-first."""
        shared_key = CorrectionKey("click", "button", "error")
        pool = [
            Correction(id="corr_1", key=shared_key, correction_type=CorrectionType.TARGET_CHANGE, success_count=3, failure_count=7),
            Correction(id="corr_2", key=shared_key, correction_type=CorrectionType.TARGET_CHANGE, success_count=8, failure_count=2),
            Correction(id="corr_3", key=shared_key, correction_type=CorrectionType.TARGET_CHANGE, success_count=5, failure_count=5),
        ]

        ordered = self.aggregator.rank_corrections(pool, sort_by='success_rate')

        # Rates are 0.8, 0.5, 0.3 -> corr_2, corr_3, corr_1.
        self.assertEqual(ordered[0].id, "corr_2")
        self.assertEqual(ordered[1].id, "corr_3")
        self.assertEqual(ordered[2].id, "corr_1")
|
||||
|
||||
|
||||
class TestCorrectionPackService(unittest.TestCase):
    """Tests for CorrectionPackService."""

    def setUp(self):
        """Set up service for tests."""
        # Separate temp dirs for pack storage and training-session data.
        self.test_dir = tempfile.mkdtemp()
        self.training_dir = tempfile.mkdtemp()
        self.service = CorrectionPackService(self.test_dir, self.training_dir)

    def tearDown(self):
        """Clean up temporary directories."""
        shutil.rmtree(self.test_dir)
        shutil.rmtree(self.training_dir)

    def test_create_and_list_packs(self):
        """Test creating and listing packs."""
        pack = self.service.create_pack(
            name="Service Test Pack",
            description="Testing the service",
            category="testing",
            tags=["unit", "test"]
        )

        # The service layer returns plain dicts, not model objects.
        self.assertIsNotNone(pack)
        self.assertEqual(pack['name'], "Service Test Pack")

        packs = self.service.list_packs()
        self.assertEqual(len(packs), 1)
        self.assertEqual(packs[0]['name'], "Service Test Pack")

    def test_add_correction(self):
        """Test adding corrections via service."""
        pack = self.service.create_pack(name="Correction Test")

        # The service builds the Correction from a flat dict payload.
        correction = self.service.add_correction(pack['id'], {
            'action_type': 'click',
            'element_type': 'button',
            'failure_context': 'element_not_found',
            'correction_type': 'target_change',
            'original_target': {'selector': '#old'},
            'corrected_target': {'selector': '#new'},
            'description': 'Updated selector'
        })

        self.assertIsNotNone(correction)
        self.assertTrue(correction['id'].startswith('corr_'))

    def test_find_applicable_corrections(self):
        """Test finding applicable corrections."""
        pack = self.service.create_pack(name="Find Test")

        self.service.add_correction(pack['id'], {
            'action_type': 'click',
            'element_type': 'button',
            'failure_context': 'timeout',
            'correction_type': 'timing_adjust',
            'corrected_params': {'wait_time': 5}
        })

        # min_confidence=0.0 so even a never-applied correction matches.
        found = self.service.find_applicable_corrections(
            action_type='click',
            element_type='button',
            failure_context='timeout',
            min_confidence=0.0
        )

        self.assertEqual(len(found), 1)

    def test_export_import_file(self):
        """Test file-based export and import."""
        pack = self.service.create_pack(name="File Export Test")
        self.service.add_correction(pack['id'], {
            'action_type': 'type',
            'element_type': 'input',
            'failure_context': 'validation',
            'correction_type': 'parameter_change'
        })

        # Export to file
        export_file = os.path.join(self.test_dir, 'export.json')
        self.assertTrue(self.service.export_pack_to_file(pack['id'], export_file))
        self.assertTrue(os.path.exists(export_file))

        # Import from file -- corrections must survive the round-trip.
        imported = self.service.import_pack_from_file(export_file)
        self.assertIsNotNone(imported)
        self.assertEqual(len(imported['corrections']), 1)

    def test_version_management(self):
        """Test version creation and rollback."""
        pack = self.service.create_pack(name="Version Test")

        # Add correction
        self.service.add_correction(pack['id'], {
            'action_type': 'click',
            'element_type': 'button',
            'failure_context': 'error',
            'correction_type': 'target_change'
        })

        # Create version -- first snapshot is numbered 1
        version = self.service.create_version(pack['id'], "v1 snapshot")
        self.assertEqual(version, 1)

        # List versions
        versions = self.service.list_versions(pack['id'])
        self.assertEqual(len(versions), 1)

        # Add another correction
        self.service.add_correction(pack['id'], {
            'action_type': 'type',
            'element_type': 'input',
            'failure_context': 'error',
            'correction_type': 'parameter_change'
        })

        # Get pack - should have 2 corrections
        current_pack = self.service.get_pack(pack['id'])
        self.assertEqual(len(current_pack['corrections']), 2)

        # Rollback to the snapshot taken before the second correction
        self.assertTrue(self.service.rollback_pack(pack['id'], 1))

        # Get pack - should have 1 correction
        rolled_back = self.service.get_pack(pack['id'])
        self.assertEqual(len(rolled_back['corrections']), 1)

    def test_statistics(self):
        """Test getting statistics."""
        pack = self.service.create_pack(name="Stats Test")
        self.service.add_correction(pack['id'], {
            'action_type': 'click',
            'element_type': 'button',
            'failure_context': 'error',
            'correction_type': 'target_change'
        })

        # Pack statistics
        pack_stats = self.service.get_pack_statistics(pack['id'])
        self.assertIsNotNone(pack_stats)
        self.assertEqual(pack_stats['total_corrections'], 1)

        # Global statistics span all packs in the store
        global_stats = self.service.get_global_statistics()
        self.assertIsNotNone(global_stats)
        self.assertEqual(global_stats['total_packs'], 1)
|
||||
|
||||
|
||||
class TestCrossWorkflowAggregator(unittest.TestCase):
    """Unit tests for CrossWorkflowAggregator."""

    def test_find_common_corrections(self):
        """A fix observed in several workflows is reported as common."""
        agg = CrossWorkflowAggregator()
        shared_key = CorrectionKey("click", "submit_button", "timeout")

        # The same correction observed in three distinct workflows/sessions.
        observations = [
            Correction(
                id=f"corr_{idx}",
                key=shared_key,
                correction_type=CorrectionType.WAIT_ADDED,
                source=CorrectionSource(session_id=f"s{idx}", workflow_id=f"wf_{idx}")
            )
            for idx in (1, 2, 3)
        ]

        common = agg.find_common_corrections(observations, min_workflows=2)

        # One merged entry, tagged with all three source workflows.
        self.assertEqual(len(common), 1)
        merged, workflows = common[0]
        self.assertEqual(len(workflows), 3)

    def test_detect_patterns(self):
        """Pattern detection returns patterns plus a summary section."""
        agg = CrossWorkflowAggregator()
        shared_key = CorrectionKey("click", "button", "error")
        observations = [
            Correction(id=f"corr_{idx}", key=shared_key, correction_type=CorrectionType.WAIT_ADDED)
            for idx in range(5)
        ]

        report = agg.detect_patterns(observations)

        self.assertIn('patterns', report)
        self.assertIn('summary', report)
        self.assertEqual(report['summary']['total_corrections'], 5)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Run the whole suite when the file is executed directly.
    unittest.main()
|
||||
640
visual_workflow_builder/backend/api/correction_packs.py
Normal file
640
visual_workflow_builder/backend/api/correction_packs.py
Normal file
@@ -0,0 +1,640 @@
|
||||
"""
|
||||
API endpoints for Correction Packs.
|
||||
|
||||
Provides REST API for managing correction packs, corrections,
|
||||
aggregation, export/import, and versioning.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from flask import Blueprint, request, jsonify, Response
|
||||
from typing import Optional
|
||||
|
||||
# Add parent paths for imports
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))))
|
||||
|
||||
from core.corrections import CorrectionPackService
|
||||
|
||||
# Flask blueprint exposing all correction-pack REST endpoints.
correction_packs_bp = Blueprint('correction_packs', __name__)

# Initialize service (singleton pattern)
# Lazily created on first request by get_service(); module-global so every
# endpoint shares one service instance.
_service: Optional[CorrectionPackService] = None
|
||||
|
||||
|
||||
def get_service() -> CorrectionPackService:
    """Return the shared CorrectionPackService, creating it on first use.

    Storage lives under <project root>/data/correction_packs and training
    data under <project root>/training_data.
    """
    global _service
    if _service is None:
        # Walk five levels up from this file to reach the project root.
        # NOTE(review): the depth must match this file's location on disk --
        # re-verify if the module is ever moved.
        base_path = os.path.abspath(__file__)
        for _ in range(5):
            base_path = os.path.dirname(base_path)
        storage_path = os.path.join(base_path, "data", "correction_packs")
        training_path = os.path.join(base_path, "training_data")
        _service = CorrectionPackService(storage_path, training_path)
    return _service
|
||||
|
||||
|
||||
# ========== Pack CRUD Endpoints ==========
|
||||
|
||||
@correction_packs_bp.route('/correction-packs', methods=['GET'])
def list_packs():
    """
    List all correction packs.

    Query params:
        category: Filter by category
        tags: Filter by tags (comma-separated)

    Returns:
        JSON list of pack summaries
    """
    try:
        service = get_service()

        category = request.args.get('category')
        tags_param = request.args.get('tags')
        # Fix: tolerate whitespace around commas and drop empty entries,
        # so "a, b," filters on ["a", "b"] instead of ["a", " b", ""].
        tags = None
        if tags_param:
            tags = [t.strip() for t in tags_param.split(',') if t.strip()] or None

        packs = service.list_packs(category=category, tags=tags)
        return jsonify({'success': True, 'packs': packs})

    except Exception as e:
        return jsonify({'success': False, 'error': str(e)}), 500
|
||||
|
||||
|
||||
@correction_packs_bp.route('/correction-packs', methods=['POST'])
def create_pack():
    """
    Create a new correction pack.

    Request body:
        name: Pack name (required)
        description: Pack description
        category: Pack category
        tags: List of tags
        author: Pack author

    Returns:
        Created pack (201), or 400 when name is missing.
    """
    try:
        service = get_service()
        payload = request.get_json() or {}

        # name is the only mandatory field.
        if not payload.get('name'):
            return jsonify({'success': False, 'error': 'name is required'}), 400

        pack = service.create_pack(
            name=payload['name'],
            description=payload.get('description', ''),
            category=payload.get('category', 'general'),
            tags=payload.get('tags', []),
            author=payload.get('author', '')
        )
        return jsonify({'success': True, 'pack': pack}), 201
    except Exception as e:
        return jsonify({'success': False, 'error': str(e)}), 500
|
||||
|
||||
|
||||
@correction_packs_bp.route('/correction-packs/<pack_id>', methods=['GET'])
def get_pack(pack_id: str):
    """
    Get a pack by ID.

    Returns:
        Pack details including all corrections, or 404 if unknown.
    """
    try:
        pack = get_service().get_pack(pack_id)
        if not pack:
            return jsonify({'success': False, 'error': 'Pack not found'}), 404
        return jsonify({'success': True, 'pack': pack})
    except Exception as e:
        return jsonify({'success': False, 'error': str(e)}), 500
|
||||
|
||||
|
||||
@correction_packs_bp.route('/correction-packs/<pack_id>', methods=['PUT'])
def update_pack(pack_id: str):
    """
    Update a pack.

    Request body:
        name: New name
        description: New description
        metadata: New metadata dict

    Returns:
        Updated pack, or 404 if the pack does not exist.
    """
    try:
        body = request.get_json() or {}

        # Absent fields are passed as None and left untouched by the service.
        updated = get_service().update_pack(
            pack_id=pack_id,
            name=body.get('name'),
            description=body.get('description'),
            metadata=body.get('metadata')
        )

        if not updated:
            return jsonify({'success': False, 'error': 'Pack not found'}), 404
        return jsonify({'success': True, 'pack': updated})
    except Exception as e:
        return jsonify({'success': False, 'error': str(e)}), 500
|
||||
|
||||
|
||||
@correction_packs_bp.route('/correction-packs/<pack_id>', methods=['DELETE'])
def delete_pack(pack_id: str):
    """
    Delete a pack.

    Returns:
        Success status, or 404 if the pack does not exist.
    """
    try:
        deleted = get_service().delete_pack(pack_id)
        if deleted:
            return jsonify({'success': True, 'message': 'Pack deleted'})
        return jsonify({'success': False, 'error': 'Pack not found'}), 404
    except Exception as e:
        return jsonify({'success': False, 'error': str(e)}), 500
|
||||
|
||||
|
||||
# ========== Correction Endpoints ==========
|
||||
|
||||
@correction_packs_bp.route('/correction-packs/<pack_id>/corrections', methods=['POST'])
def add_correction(pack_id: str):
    """
    Add a correction to a pack.

    Request body:
        action_type: Type of action (required)
        element_type: Type of element
        failure_context: Description of failure
        correction_type: Type of correction
        original_target: Original target dict
        original_params: Original parameters dict
        corrected_target: Corrected target dict
        corrected_params: Corrected parameters dict
        description: Correction description
        tags: List of tags
        source: Source info dict

    Returns:
        Created correction (201), or 400 on missing/invalid input.
    """
    try:
        body = request.get_json() or {}

        # action_type is the only mandatory field; everything else is optional.
        if not body.get('action_type'):
            return jsonify({'success': False, 'error': 'action_type is required'}), 400

        created = get_service().add_correction(pack_id, body)
        if not created:
            # Service refused the payload (e.g. unknown pack).
            return jsonify({'success': False, 'error': 'Failed to add correction'}), 400
        return jsonify({'success': True, 'correction': created}), 201
    except Exception as e:
        return jsonify({'success': False, 'error': str(e)}), 500
|
||||
|
||||
|
||||
@correction_packs_bp.route('/correction-packs/<pack_id>/corrections/<correction_id>', methods=['PUT'])
def update_correction(pack_id: str, correction_id: str):
    """
    Update a correction.

    Request body:
        corrected_target: New corrected target
        corrected_params: New corrected params
        correction_description: New description
        status: New status
        tags: New tags

    Returns:
        Updated correction, or 404 if pack/correction is unknown.
    """
    try:
        body = request.get_json() or {}

        updated = get_service().update_correction(pack_id, correction_id, body)
        if not updated:
            return jsonify({'success': False, 'error': 'Correction not found'}), 404
        return jsonify({'success': True, 'correction': updated})
    except Exception as e:
        return jsonify({'success': False, 'error': str(e)}), 500
|
||||
|
||||
|
||||
@correction_packs_bp.route('/correction-packs/<pack_id>/corrections/<correction_id>', methods=['DELETE'])
def remove_correction(pack_id: str, correction_id: str):
    """
    Remove a correction from a pack.

    Returns:
        Success status, or 404 if pack/correction is unknown.
    """
    try:
        removed = get_service().remove_correction(pack_id, correction_id)
        if removed:
            return jsonify({'success': True, 'message': 'Correction removed'})
        return jsonify({'success': False, 'error': 'Correction not found'}), 404
    except Exception as e:
        return jsonify({'success': False, 'error': str(e)}), 500
|
||||
|
||||
|
||||
# ========== Aggregation Endpoints ==========
|
||||
|
||||
@correction_packs_bp.route('/correction-packs/<pack_id>/aggregate', methods=['POST'])
def aggregate_from_sessions(pack_id: str):
    """
    Aggregate corrections from training sessions into a pack.

    Request body:
        workflow_id: Filter by workflow ID
        min_occurrences: Minimum occurrences to include
        min_success_rate: Minimum success rate to include
        session_files: Specific session files to process

    Returns:
        Aggregation summary
    """
    try:
        body = request.get_json() or {}

        # Defaults are permissive: every occurrence qualifies.
        summary = get_service().aggregate_from_sessions(
            pack_id=pack_id,
            session_files=body.get('session_files'),
            workflow_id=body.get('workflow_id'),
            min_occurrences=body.get('min_occurrences', 1),
            min_success_rate=body.get('min_success_rate', 0.0)
        )
        return jsonify({'success': True, 'result': summary})
    except Exception as e:
        return jsonify({'success': False, 'error': str(e)}), 500
|
||||
|
||||
|
||||
@correction_packs_bp.route('/correction-packs/<pack_id>/aggregate-cross-workflow', methods=['POST'])
def aggregate_cross_workflow(pack_id: str):
    """
    Find and aggregate corrections common across multiple workflows.

    Request body:
        min_workflows: Minimum number of workflows (default: 2)

    Returns:
        Aggregation summary with patterns
    """
    try:
        body = request.get_json() or {}

        summary = get_service().aggregate_cross_workflow(
            pack_id=pack_id,
            min_workflows=body.get('min_workflows', 2)
        )
        return jsonify({'success': True, 'result': summary})
    except Exception as e:
        return jsonify({'success': False, 'error': str(e)}), 500
|
||||
|
||||
|
||||
# ========== Search and Apply Endpoints ==========
|
||||
|
||||
@correction_packs_bp.route('/correction-packs/find', methods=['POST'])
def find_corrections():
    """
    Find corrections applicable to a described failure context.

    Request body:
        action_type: type of the failed action (required)
        element_type: type of UI element involved (default 'unknown')
        failure_context: free-text description of the failure
        pack_ids: restrict the search to these packs
        min_confidence: confidence floor for results (default 0.3)

    Returns:
        JSON list of applicable corrections, 400 if action_type is missing.
    """
    try:
        payload = request.get_json() or {}

        # action_type is the only mandatory field for matching.
        if not payload.get('action_type'):
            return jsonify({'success': False, 'error': 'action_type is required'}), 400

        matches = get_service().find_applicable_corrections(
            action_type=payload['action_type'],
            element_type=payload.get('element_type', 'unknown'),
            failure_context=payload.get('failure_context', ''),
            pack_ids=payload.get('pack_ids'),
            min_confidence=payload.get('min_confidence', 0.3),
        )
        return jsonify({'success': True, 'corrections': matches})

    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||
|
||||
|
||||
@correction_packs_bp.route('/correction-packs/apply', methods=['POST'])
def apply_correction():
    """
    Record one application of a correction (success or failure).

    Request body:
        pack_id: pack identifier (required)
        correction_id: correction identifier (required)
        success: whether applying the correction worked (required)
        context: optional free-form application context

    Returns:
        JSON confirmation, 400 on missing required fields.
    """
    try:
        payload = request.get_json() or {}

        if not payload.get('pack_id') or not payload.get('correction_id'):
            return jsonify({'success': False, 'error': 'pack_id and correction_id are required'}), 400

        # 'success' may legitimately be False, so test for key presence.
        if 'success' not in payload:
            return jsonify({'success': False, 'error': 'success is required'}), 400

        get_service().apply_correction(
            pack_id=payload['pack_id'],
            correction_id=payload['correction_id'],
            success=payload['success'],
            application_context=payload.get('context'),
        )
        return jsonify({'success': True, 'message': 'Application recorded'})

    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||
|
||||
|
||||
@correction_packs_bp.route('/correction-packs/search', methods=['GET'])
def search_corrections():
    """
    Full-text search over corrections.

    Query params:
        q: search query (required)
        pack_id: optionally restrict the search to a single pack

    Returns:
        JSON list of matching corrections, 400 if 'q' is absent.
    """
    try:
        query = request.args.get('q')
        if not query:
            return jsonify({'success': False, 'error': 'q parameter is required'}), 400

        hits = get_service().search_corrections(query, request.args.get('pack_id'))
        return jsonify({'success': True, 'corrections': hits})

    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||
|
||||
|
||||
# ========== Export/Import Endpoints ==========
|
||||
|
||||
@correction_packs_bp.route('/correction-packs/<pack_id>/export', methods=['GET'])
def export_pack(pack_id: str):
    """
    Export a correction pack as a downloadable JSON or YAML file.

    Query params:
        format: export format, 'json' or 'yaml' (default 'json')

    Returns:
        The serialized pack as an attachment; 400 on invalid format,
        404 when the pack does not exist, 500 on unexpected errors.
    """
    try:
        service = get_service()

        format_type = request.args.get('format', 'json')
        if format_type not in ['json', 'yaml']:
            return jsonify({'success': False, 'error': 'Invalid format'}), 400

        content = service.export_pack(pack_id, format=format_type)

        if not content:
            return jsonify({'success': False, 'error': 'Pack not found'}), 404

        # Pick the MIME type and download filename matching the format.
        if format_type == 'yaml':
            content_type = 'application/x-yaml'
            filename = f'correction_pack_{pack_id}.yaml'
        else:
            content_type = 'application/json'
            filename = f'correction_pack_{pack_id}.json'

        return Response(
            content,
            mimetype=content_type,
            headers={
                # Fix: interpolate the computed filename; previously it was
                # never used and the header carried a broken placeholder.
                'Content-Disposition': f'attachment; filename={filename}'
            }
        )

    except Exception as e:
        return jsonify({'success': False, 'error': str(e)}), 500
|
||||
|
||||
|
||||
@correction_packs_bp.route('/correction-packs/import', methods=['POST'])
def import_pack():
    """
    Import a correction pack from an uploaded file or a JSON body.

    File upload (multipart): field 'file' with .json/.yaml/.yml content;
    form field 'merge_strategy' optional.

    JSON body:
        content: serialized pack content (required when no file)
        format: 'json' or 'yaml' (default 'json')
        merge_strategy: 'create_new' (default), 'merge', or 'replace'

    Returns:
        201 with the imported pack, 400 on missing content or import failure.
    """
    try:
        service = get_service()

        uploaded = request.files.get('file')
        if uploaded is not None:
            content = uploaded.read().decode('utf-8')

            # Infer the serialization format from the filename extension.
            is_yaml = bool(uploaded.filename) and uploaded.filename.endswith(('.yaml', '.yml'))
            format_type = 'yaml' if is_yaml else 'json'
            merge_strategy = request.form.get('merge_strategy', 'create_new')
        else:
            payload = request.get_json() or {}
            content = payload.get('content')

            if not content:
                return jsonify({'success': False, 'error': 'content or file is required'}), 400

            format_type = payload.get('format', 'json')
            merge_strategy = payload.get('merge_strategy', 'create_new')

        pack = service.import_pack(content, format=format_type, merge_strategy=merge_strategy)

        if not pack:
            return jsonify({'success': False, 'error': 'Failed to import pack'}), 400

        return jsonify({'success': True, 'pack': pack}), 201

    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||
|
||||
|
||||
# ========== Version Endpoints ==========
|
||||
|
||||
@correction_packs_bp.route('/correction-packs/<pack_id>/versions', methods=['GET'])
def list_versions(pack_id: str):
    """
    Return the version history recorded for a pack.

    Returns:
        JSON list of version info entries.
    """
    try:
        history = get_service().list_versions(pack_id)
        return jsonify({'success': True, 'versions': history})

    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||
|
||||
|
||||
@correction_packs_bp.route('/correction-packs/<pack_id>/versions', methods=['POST'])
def create_version(pack_id: str):
    """
    Snapshot the current state of a pack as a new version.

    Request body:
        description: optional human-readable version description

    Returns:
        201 with the new version number; 404 when the pack is unknown.
    """
    try:
        payload = request.get_json() or {}

        new_version = get_service().create_version(
            pack_id=pack_id,
            description=payload.get('description', ''),
        )

        # The service signals a missing pack with a negative version number.
        if new_version < 0:
            return jsonify({'success': False, 'error': 'Pack not found'}), 404

        return jsonify({'success': True, 'version': new_version}), 201

    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||
|
||||
|
||||
@correction_packs_bp.route('/correction-packs/<pack_id>/rollback', methods=['POST'])
def rollback_pack(pack_id: str):
    """
    Restore a pack to one of its previously saved versions.

    Request body:
        version: version number to roll back to (required)

    Returns:
        JSON confirmation; 400 when the version is missing or rollback fails.
    """
    try:
        payload = request.get_json() or {}

        # Version 0 is valid, so compare against None explicitly.
        target_version = payload.get('version')
        if target_version is None:
            return jsonify({'success': False, 'error': 'version is required'}), 400

        if not get_service().rollback_pack(pack_id, target_version):
            return jsonify({'success': False, 'error': 'Rollback failed'}), 400

        return jsonify({'success': True, 'message': f'Rolled back to version {target_version}'})

    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||
|
||||
|
||||
# ========== Statistics Endpoints ==========
|
||||
|
||||
@correction_packs_bp.route('/correction-packs/<pack_id>/statistics', methods=['GET'])
def get_pack_statistics(pack_id: str):
    """
    Return detailed statistics for a single pack.

    Returns:
        JSON statistics dict; 404 when the pack does not exist.
    """
    try:
        stats = get_service().get_pack_statistics(pack_id)

        if not stats:
            return jsonify({'success': False, 'error': 'Pack not found'}), 404

        return jsonify({'success': True, 'statistics': stats})

    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||
|
||||
|
||||
@correction_packs_bp.route('/correction-packs/statistics', methods=['GET'])
def get_global_statistics():
    """
    Return statistics aggregated across every correction pack.

    Returns:
        JSON global statistics dict.
    """
    try:
        stats = get_service().get_global_statistics()
        return jsonify({'success': True, 'statistics': stats})

    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
|
||||
201
visual_workflow_builder/backend/app.py
Normal file
201
visual_workflow_builder/backend/app.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
Visual Workflow Builder - Backend Flask Application
|
||||
|
||||
This is the main entry point for the Visual Workflow Builder backend API.
|
||||
It provides REST endpoints for workflow management and WebSocket support
|
||||
for real-time execution updates.
|
||||
"""
|
||||
|
||||
from flask import Flask
|
||||
from flask_cors import CORS
|
||||
from flask_socketio import SocketIO
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
from flask_caching import Cache
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Initialize Flask app
|
||||
app = Flask(__name__)
|
||||
|
||||
# Configuration
|
||||
app.config['SECRET_KEY'] = os.getenv('SECRET_KEY', 'dev-secret-key-change-in-production')
|
||||
app.config['SQLALCHEMY_DATABASE_URI'] = os.getenv('DATABASE_URL', 'sqlite:///workflows.db')
|
||||
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
|
||||
app.config['MAX_CONTENT_LENGTH'] = 10 * 1024 * 1024 # 10MB max upload
|
||||
app.config['CACHE_TYPE'] = 'redis' if os.getenv('REDIS_URL') else 'simple'
|
||||
app.config['CACHE_REDIS_URL'] = os.getenv('REDIS_URL', 'redis://localhost:6379/0')
|
||||
|
||||
# Initialize extensions
|
||||
db = SQLAlchemy(app)
|
||||
cache = Cache(app)
|
||||
socketio = SocketIO(
|
||||
app,
|
||||
cors_allowed_origins="*",
|
||||
async_mode='threading',
|
||||
logger=True,
|
||||
engineio_logger=True
|
||||
)
|
||||
|
||||
# Enable CORS
|
||||
CORS(app, resources={
|
||||
r"/api/*": {
|
||||
"origins": os.getenv('CORS_ORIGINS', 'http://localhost:3000').split(','),
|
||||
"methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
"allow_headers": ["Content-Type", "Authorization"]
|
||||
}
|
||||
})
|
||||
|
||||
# Import and register blueprints (minimal set)
|
||||
from api.workflows import workflows_bp
|
||||
from api.screen_capture import screen_capture_bp
|
||||
from api.real_demo import real_demo_bp
|
||||
from api.errors import error_response
|
||||
|
||||
app.register_blueprint(workflows_bp, url_prefix='/api/workflows')
|
||||
app.register_blueprint(screen_capture_bp, url_prefix='/api/screen-capture')
|
||||
app.register_blueprint(real_demo_bp)
|
||||
|
||||
# Optional / Phase 2+ blueprints (loaded only if modules are available)
|
||||
try:
|
||||
from api.self_healing import self_healing_bp
|
||||
app.register_blueprint(self_healing_bp)
|
||||
except ImportError as e:
|
||||
print(f"⚠️ Blueprint self_healing désactivé: {e}")
|
||||
|
||||
try:
|
||||
from api.visual_targets import visual_targets_bp, init_visual_target_manager
|
||||
app.register_blueprint(visual_targets_bp)
|
||||
VISUAL_TARGETS_BP_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
print(f"⚠️ Blueprint visual_targets désactivé: {e}")
|
||||
VISUAL_TARGETS_BP_AVAILABLE = False
|
||||
init_visual_target_manager = None
|
||||
|
||||
try:
|
||||
from api.element_detection import element_detection_bp, init_element_detection
|
||||
app.register_blueprint(element_detection_bp)
|
||||
ELEMENT_DETECTION_BP_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
print(f"⚠️ Blueprint element_detection désactivé: {e}")
|
||||
ELEMENT_DETECTION_BP_AVAILABLE = False
|
||||
init_element_detection = None
|
||||
|
||||
try:
|
||||
from api.analytics import analytics_bp
|
||||
app.register_blueprint(analytics_bp, url_prefix='/api/analytics')
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
# Register other blueprints (optional - depends on Phase 2+ services)
|
||||
try:
|
||||
from api.templates import templates_bp
|
||||
app.register_blueprint(templates_bp, url_prefix='/api/templates')
|
||||
except ImportError as e:
|
||||
print(f"⚠️ Blueprint templates désactivé: {e}")
|
||||
|
||||
from api.node_types import node_types_bp
|
||||
app.register_blueprint(node_types_bp, url_prefix='/api/node-types')
|
||||
|
||||
try:
|
||||
from api.executions import executions_bp
|
||||
app.register_blueprint(executions_bp, url_prefix='/api/executions')
|
||||
except ImportError as e:
|
||||
print(f"⚠️ Blueprint executions désactivé: {e}")
|
||||
|
||||
try:
|
||||
from api.import_export import import_export_bp
|
||||
app.register_blueprint(import_export_bp, url_prefix='/api')
|
||||
except ImportError as e:
|
||||
print(f"⚠️ Blueprint import_export désactivé: {e}")
|
||||
|
||||
try:
|
||||
from api.correction_packs import correction_packs_bp
|
||||
app.register_blueprint(correction_packs_bp, url_prefix='/api')
|
||||
print("✅ Blueprint correction_packs enregistré")
|
||||
except ImportError as e:
|
||||
print(f"⚠️ Blueprint correction_packs désactivé: {e}")
|
||||
|
||||
|
||||
# Import WebSocket handlers (optional)
|
||||
try:
|
||||
from api import websocket_handlers # noqa: F401
|
||||
except Exception as e:
|
||||
print(f"⚠️ WebSocket handlers désactivés: {e}")
|
||||
|
||||
# Global error handlers
|
||||
@app.errorhandler(404)
def not_found(_error):
    """Return a JSON 404 payload for unknown resources."""
    return error_response(404, "Resource not found")
|
||||
|
||||
@app.errorhandler(405)
def method_not_allowed(_error):
    """Return a JSON 405 payload for unsupported HTTP methods."""
    return error_response(405, "Method not allowed")
|
||||
|
||||
@app.errorhandler(500)
def internal_error(_error):
    """Return a JSON 500 payload for server-side failures."""
    return error_response(500, "Internal server error")
|
||||
|
||||
@app.errorhandler(Exception)
def handle_exception(exc):
    """Catch-all handler: print the traceback and answer with a JSON 500."""
    import traceback
    traceback.print_exc()
    return error_response(500, f"Unexpected error: {str(exc)}")
|
||||
|
||||
# Health check endpoint
|
||||
@app.route('/health')
def health_check():
    """Liveness probe used by external monitoring."""
    status_payload = {'status': 'healthy', 'version': '1.0.0'}
    return status_payload
|
||||
|
||||
# Create database tables.
# db.create_all() is idempotent: it only creates tables that do not exist yet.
with app.app_context():
    db.create_all()

# Initialize VisualTargetManager with RPA Vision V3 components (optional).
# The whole block is best-effort: a missing core package only disables the
# visual services, it never prevents the API from starting.
try:
    from core.capture.screen_capturer import ScreenCapturer
    from core.detection.ui_detector import UIDetector
    from core.embedding.fusion_engine import FusionEngine

    # Only initialize if the related blueprints were actually loaded
    if VISUAL_TARGETS_BP_AVAILABLE and init_visual_target_manager:
        screen_capturer = ScreenCapturer()
        ui_detector = UIDetector()
        fusion_engine = FusionEngine()
        init_visual_target_manager(screen_capturer, ui_detector, fusion_engine)

    if ELEMENT_DETECTION_BP_AVAILABLE and init_element_detection:
        # Reuse the same instances when possible.
        # The locals() checks detect whether the branch above already built them.
        if 'ui_detector' not in locals():
            ui_detector = UIDetector()
        if 'screen_capturer' not in locals():
            screen_capturer = ScreenCapturer()
        init_element_detection(ui_detector, screen_capturer)

    if (VISUAL_TARGETS_BP_AVAILABLE and init_visual_target_manager) or (ELEMENT_DETECTION_BP_AVAILABLE and init_element_detection):
        print("✅ Services visuels initialisés (VisualTargets / ElementDetection)")
except ImportError as e:
    # Core RPA package not installed — visual services stay disabled
    print(f"⚠️ Core RPA non disponible pour l'initialisation visuelle: {e}")
except Exception as e:
    # Any other init failure is reported but not fatal
    print(f"❌ Erreur lors de l'initialisation des services visuels: {e}")
|
||||
|
||||
if __name__ == '__main__':
    # Port and debug flag come from the environment; defaults suit local dev.
    port = int(os.getenv('PORT', 5002))
    debug = os.getenv('FLASK_ENV') == 'development'

    # Run through Flask-SocketIO (not app.run) so WebSocket support is active.
    socketio.run(
        app,
        host='0.0.0.0',
        port=port,
        debug=debug,
        use_reloader=debug,
        allow_unsafe_werkzeug=True  # For development only
    )
|
||||
Reference in New Issue
Block a user