feat(corrections): Add Correction Packs system for cross-workflow learning

Implement a complete system for capturing and reusing user corrections across
multiple workflows and sessions. This enables automatic application of learned
fixes when similar failures occur in different contexts.

New components:
- core/corrections/models.py: CorrectionKey, Correction, CorrectionPack models
- core/corrections/correction_repository.py: JSON storage with atomic writes
- core/corrections/aggregator.py: Aggregation by hash and quality filtering
- core/corrections/correction_pack_service.py: CRUD, export/import, versioning
- backend/api/correction_packs.py: REST API with 15 endpoints

Features:
- MD5-based key hashing for correction deduplication
- Export/import in JSON and YAML formats
- Version history with rollback support
- Cross-workflow pattern detection
- Integration with SelfHealingEngine for automatic application
- 29 unit tests (all passing)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Dom
2026-01-18 18:48:35 +01:00
parent fa57ecdbfd
commit d8756883c5
9 changed files with 4411 additions and 0 deletions

View File

@@ -0,0 +1,30 @@
"""
Correction Packs System - Cross-workflow correction capitalization.
This module provides a system for storing, aggregating, and applying
user corrections across multiple workflows and sessions.
"""
from .models import (
CorrectionKey,
CorrectionType,
CorrectionStatus,
Correction,
CorrectionPack,
CorrectionPackMetadata
)
from .correction_repository import CorrectionRepository
from .aggregator import CorrectionAggregator
from .correction_pack_service import CorrectionPackService
__all__ = [
'CorrectionKey',
'CorrectionType',
'CorrectionStatus',
'Correction',
'CorrectionPack',
'CorrectionPackMetadata',
'CorrectionRepository',
'CorrectionAggregator',
'CorrectionPackService'
]

View File

@@ -0,0 +1,545 @@
"""
Correction Aggregator - Aggregate corrections from multiple sessions.
Groups corrections by key hash and filters by quality thresholds.
"""
import logging
from collections import defaultdict
from datetime import datetime
from typing import Dict, List, Optional, Any, Tuple
from .models import (
Correction,
CorrectionKey,
CorrectionType,
CorrectionSource,
CorrectionStatus,
generate_correction_id
)
logger = logging.getLogger(__name__)
class CorrectionAggregator:
    """
    Aggregates corrections from multiple sources (sessions, workflows).

    Groups corrections by CorrectionKey hash and merges statistics.
    Filters out low-quality corrections based on configurable thresholds.
    """

    def __init__(
        self,
        min_occurrences: int = 2,
        min_success_rate: float = 0.5,
        min_confidence: float = 0.3
    ):
        """
        Initialize the aggregator.

        Args:
            min_occurrences: Minimum number of times a correction must occur
            min_success_rate: Minimum success rate to include correction
            min_confidence: Minimum confidence score to include correction
        """
        self.min_occurrences = min_occurrences
        self.min_success_rate = min_success_rate
        self.min_confidence = min_confidence

    def aggregate_corrections(
        self,
        corrections: List[Correction],
        merge_strategy: str = 'weighted_average'
    ) -> List[Correction]:
        """
        Aggregate a list of corrections by key hash.

        Args:
            corrections: List of corrections to aggregate
            merge_strategy: How to merge duplicate corrections:
                - 'weighted_average': Weight by usage count
                - 'best_confidence': Keep highest confidence
                - 'most_recent': Keep most recently used

        Returns:
            List of aggregated corrections
        """
        # Group by key hash so duplicates of the same
        # (action_type, element_type, failure_context) key merge together.
        grouped: Dict[str, List[Correction]] = defaultdict(list)
        for correction in corrections:
            key_hash = correction.key.to_hash()
            grouped[key_hash].append(correction)
        # Aggregate each group
        aggregated = []
        for key_hash, group in grouped.items():
            if len(group) == 1:
                # Single correction, just check thresholds.
                # NOTE: the original object is returned, not a copy.
                if self._passes_thresholds(group[0]):
                    aggregated.append(group[0])
            else:
                # Multiple corrections, merge
                merged = self._merge_corrections(group, merge_strategy)
                if merged and self._passes_thresholds(merged):
                    aggregated.append(merged)
        return aggregated

    def aggregate_from_sessions(
        self,
        session_corrections: List[Dict[str, Any]],
        workflow_id: Optional[str] = None
    ) -> List[Correction]:
        """
        Aggregate corrections from session data (user_corrections format).

        Args:
            session_corrections: List of correction dicts from TrainingSession
            workflow_id: Optional workflow ID to filter by

        Returns:
            List of aggregated Correction objects
        """
        # Convert session corrections to Correction objects
        corrections = []
        for sc in session_corrections:
            try:
                correction = self._session_correction_to_correction(sc)
                if correction:
                    # Filter by workflow if specified
                    if workflow_id is None or (correction.source and correction.source.workflow_id == workflow_id):
                        corrections.append(correction)
            except Exception as e:
                # Defensive: _session_correction_to_correction already catches
                # and logs its own errors (returning None), so this mostly
                # guards against failures in the filter expression above.
                logger.warning(f"Error converting session correction: {e}")
        # Aggregate
        return self.aggregate_corrections(corrections)

    def group_by_type(
        self,
        corrections: List[Correction]
    ) -> Dict[CorrectionType, List[Correction]]:
        """
        Group corrections by their type.

        Args:
            corrections: List of corrections

        Returns:
            Dict mapping correction type to list of corrections
        """
        grouped: Dict[CorrectionType, List[Correction]] = defaultdict(list)
        for correction in corrections:
            grouped[correction.correction_type].append(correction)
        return dict(grouped)

    def group_by_action_type(
        self,
        corrections: List[Correction]
    ) -> Dict[str, List[Correction]]:
        """
        Group corrections by the action type they apply to.

        Args:
            corrections: List of corrections

        Returns:
            Dict mapping action type to list of corrections
        """
        grouped: Dict[str, List[Correction]] = defaultdict(list)
        for correction in corrections:
            grouped[correction.key.action_type].append(correction)
        return dict(grouped)

    def filter_by_quality(
        self,
        corrections: List[Correction],
        min_success_rate: Optional[float] = None,
        min_confidence: Optional[float] = None,
        min_applications: Optional[int] = None
    ) -> List[Correction]:
        """
        Filter corrections by quality thresholds.

        Args:
            corrections: List of corrections to filter
            min_success_rate: Minimum success rate (overrides default)
            min_confidence: Minimum confidence (overrides default)
            min_applications: Minimum number of applications

        Returns:
            Filtered list of corrections
        """
        # Per-call overrides fall back to the instance defaults; note the
        # applications floor defaults to 1 here, NOT to self.min_occurrences.
        min_sr = min_success_rate if min_success_rate is not None else self.min_success_rate
        min_conf = min_confidence if min_confidence is not None else self.min_confidence
        min_apps = min_applications if min_applications is not None else 1
        return [
            c for c in corrections
            if c.success_rate >= min_sr
            and c.confidence_score >= min_conf
            and (c.success_count + c.failure_count) >= min_apps
        ]

    def rank_corrections(
        self,
        corrections: List[Correction],
        sort_by: str = 'confidence'
    ) -> List[Correction]:
        """
        Rank corrections by specified criteria.

        Args:
            corrections: List of corrections
            sort_by: Ranking criteria:
                - 'confidence': By confidence score
                - 'success_rate': By success rate
                - 'usage': By total applications
                - 'recent': By last application date

        Returns:
            Sorted list of corrections (input order on unknown sort_by)
        """
        if sort_by == 'confidence':
            return sorted(corrections, key=lambda c: c.confidence_score, reverse=True)
        elif sort_by == 'success_rate':
            return sorted(corrections, key=lambda c: c.success_rate, reverse=True)
        elif sort_by == 'usage':
            return sorted(corrections, key=lambda c: c.success_count + c.failure_count, reverse=True)
        elif sort_by == 'recent':
            # datetime.min sorts never-applied corrections last (descending).
            # NOTE(review): assumes last_applied is a naive datetime -- mixing
            # naive and aware values here would raise TypeError; confirm.
            return sorted(
                corrections,
                key=lambda c: c.last_applied or datetime.min,
                reverse=True
            )
        else:
            return corrections

    def calculate_statistics(
        self,
        corrections: List[Correction]
    ) -> Dict[str, Any]:
        """
        Calculate aggregate statistics for a list of corrections.

        Args:
            corrections: List of corrections

        Returns:
            Statistics dict (empty-shape dict when corrections is empty)
        """
        if not corrections:
            return {
                'count': 0,
                'total_applications': 0,
                'average_success_rate': 0.0,
                'average_confidence': 0.0,
                'by_type': {},
                'by_action': {}
            }
        total_apps = sum(c.success_count + c.failure_count for c in corrections)
        total_success = sum(c.success_count for c in corrections)
        # Group by type
        by_type = self.group_by_type(corrections)
        type_stats = {
            t.value: len(cs) for t, cs in by_type.items()
        }
        # Group by action
        by_action = self.group_by_action_type(corrections)
        action_stats = {
            a: len(cs) for a, cs in by_action.items()
        }
        return {
            'count': len(corrections),
            'total_applications': total_apps,
            'total_success': total_success,
            'total_failure': total_apps - total_success,
            # Guard against division by zero when nothing was ever applied.
            'average_success_rate': total_success / total_apps if total_apps > 0 else 0.0,
            'average_confidence': sum(c.confidence_score for c in corrections) / len(corrections),
            'by_type': type_stats,
            'by_action': action_stats
        }

    def _passes_thresholds(self, correction: Correction) -> bool:
        """Check if a correction passes quality thresholds."""
        total = correction.success_count + correction.failure_count
        # Always include if no applications yet (new correction) -- this
        # short-circuits BEFORE the min_occurrences check below.
        if total == 0:
            return True
        # Check minimum occurrences
        if total < self.min_occurrences:
            return False
        # Check success rate
        if correction.success_rate < self.min_success_rate:
            return False
        # Check confidence
        if correction.confidence_score < self.min_confidence:
            return False
        return True

    def _merge_corrections(
        self,
        corrections: List[Correction],
        strategy: str = 'weighted_average'
    ) -> Optional[Correction]:
        """
        Merge multiple corrections with the same key hash.

        Args:
            corrections: List of corrections to merge
            strategy: Merge strategy

        Returns:
            Merged correction, or None for an empty input.
            NOTE: for a single-element input the ORIGINAL object is returned,
            not a copy -- callers must not mutate the result in place.
        """
        if not corrections:
            return None
        if len(corrections) == 1:
            return corrections[0]
        if strategy == 'best_confidence':
            return max(corrections, key=lambda c: c.confidence_score)
        elif strategy == 'most_recent':
            return max(
                corrections,
                key=lambda c: c.last_applied or c.created_at
            )
        else:  # weighted_average (default)
            # Use the correction with highest usage as base; its content
            # fields (targets, params, description) win the merge.
            base = max(corrections, key=lambda c: c.success_count + c.failure_count)
            # Sum statistics
            total_success = sum(c.success_count for c in corrections)
            total_failure = sum(c.failure_count for c in corrections)
            # Find most recent application
            last_applied = None
            for c in corrections:
                if c.last_applied:
                    if last_applied is None or c.last_applied > last_applied:
                        last_applied = c.last_applied
            # Merge tags (deduplicated; original ordering is not preserved)
            all_tags = set()
            for c in corrections:
                all_tags.update(c.tags)
            # Create merged correction with a fresh id; earliest created_at
            # is kept so the merged record reflects the first observation.
            merged = Correction(
                id=generate_correction_id(),
                key=base.key,
                correction_type=base.correction_type,
                original_target=base.original_target,
                original_params=base.original_params,
                corrected_target=base.corrected_target,
                corrected_params=base.corrected_params,
                correction_description=base.correction_description,
                source=base.source,  # Keep first source as primary
                success_count=total_success,
                failure_count=total_failure,
                last_applied=last_applied,
                status=CorrectionStatus.ACTIVE,
                created_at=min(c.created_at for c in corrections),
                updated_at=datetime.now(),
                tags=list(all_tags)
            )
            return merged

    def _session_correction_to_correction(
        self,
        session_correction: Dict[str, Any]
    ) -> Optional[Correction]:
        """
        Convert a session correction dict to a Correction object.

        Args:
            session_correction: Dict from TrainingSession.user_corrections

        Returns:
            Correction object or None if invalid
        """
        try:
            # Extract key components, falling back to the legacy field names
            # ('type', 'reason') used by older session payloads.
            action_type = session_correction.get('action_type', session_correction.get('type', 'unknown'))
            element_type = session_correction.get('element_type', 'unknown')
            failure_context = session_correction.get('failure_reason', session_correction.get('reason', ''))
            key = CorrectionKey(action_type, element_type, failure_context)
            # Determine correction type; unknown strings degrade to OTHER.
            correction_type_str = session_correction.get('correction_type', 'other')
            try:
                correction_type = CorrectionType(correction_type_str)
            except ValueError:
                correction_type = CorrectionType.OTHER
            # Create source
            source = CorrectionSource(
                session_id=session_correction.get('session_id', ''),
                workflow_id=session_correction.get('workflow_id'),
                node_id=session_correction.get('node_id')
            )
            # Timestamps may arrive as ISO strings or datetime objects.
            # NOTE(review): fromisoformat rejects a trailing 'Z' on
            # Python < 3.11 -- confirm the session writer's format.
            timestamp = session_correction.get('timestamp')
            if timestamp:
                if isinstance(timestamp, str):
                    source.timestamp = datetime.fromisoformat(timestamp)
                elif isinstance(timestamp, datetime):
                    source.timestamp = timestamp
            # Create correction: an 'applied' session correction (the default)
            # seeds one success, an unapplied one seeds one failure.
            return Correction(
                id=generate_correction_id(),
                key=key,
                correction_type=correction_type,
                original_target=session_correction.get('original_target'),
                original_params=session_correction.get('original_params'),
                corrected_target=session_correction.get('corrected_target', session_correction.get('new_target')),
                corrected_params=session_correction.get('corrected_params', session_correction.get('new_params')),
                correction_description=session_correction.get('description', ''),
                source=source,
                success_count=1 if session_correction.get('applied', True) else 0,
                failure_count=0 if session_correction.get('applied', True) else 1,
                tags=session_correction.get('tags', [])
            )
        except Exception as e:
            logger.warning(f"Error converting session correction: {e}")
            return None
class CrossWorkflowAggregator(CorrectionAggregator):
    """
    Aggregator specialized for cross-workflow correction analysis.

    Adds helpers on top of CorrectionAggregator to spot corrections shared
    by several workflows and to surface recurring failure patterns.
    """

    def find_common_corrections(
        self,
        corrections: List[Correction],
        min_workflows: int = 2
    ) -> List[Tuple[Correction, List[str]]]:
        """
        Find corrections that appear in multiple workflows.

        Args:
            corrections: List of corrections
            min_workflows: Minimum number of distinct workflows required

        Returns:
            List of (merged correction, workflow_ids) tuples, most
            widespread first.
        """
        # Bucket by key hash so identical keys land together.
        buckets: Dict[str, List[Correction]] = defaultdict(list)
        for item in corrections:
            buckets[item.key.to_hash()].append(item)

        shared: List[Tuple[Correction, List[str]]] = []
        for bucket in buckets.values():
            # Distinct workflows this key was observed in (sourceless
            # corrections contribute nothing).
            workflows = set()
            for item in bucket:
                if item.source and item.source.workflow_id:
                    workflows.add(item.source.workflow_id)
            if len(workflows) < min_workflows:
                continue
            combined = self._merge_corrections(bucket)
            if combined:
                shared.append((combined, list(workflows)))

        # Most widespread corrections first.
        return sorted(shared, key=lambda pair: len(pair[1]), reverse=True)

    def detect_patterns(
        self,
        corrections: List[Correction]
    ) -> Dict[str, Any]:
        """
        Detect recurring patterns in a set of corrections.

        Args:
            corrections: List of corrections

        Returns:
            Dict with a 'patterns' list and a 'summary' dict.
        """
        if not corrections:
            return {'patterns': [], 'summary': {}}

        found: List[Dict[str, Any]] = []

        # Pattern 1: action types corrected repeatedly (3+ occurrences).
        by_action = self.group_by_action_type(corrections)
        for action_name, group in by_action.items():
            if len(group) < 3:
                continue
            found.append({
                'type': 'frequent_action_correction',
                'action_type': action_name,
                'count': len(group),
                'avg_success_rate': sum(c.success_rate for c in group) / len(group)
            })

        # Pattern 2: element types repeatedly involved in failures.
        by_element: Dict[str, List[Correction]] = defaultdict(list)
        for item in corrections:
            by_element[item.key.element_type].append(item)
        for element_name, group in by_element.items():
            if len(group) < 3:
                continue
            found.append({
                'type': 'element_type_issues',
                'element_type': element_name,
                'count': len(group),
                'common_fixes': self._find_common_fix_patterns(group)
            })

        # Pattern 3: more than half of the corrections applied in the last week.
        recent = [
            c for c in corrections
            if c.last_applied and (datetime.now() - c.last_applied).days < 7
        ]
        if len(recent) > len(corrections) * 0.5:
            found.append({
                'type': 'recent_activity_spike',
                'recent_count': len(recent),
                'total_count': len(corrections)
            })

        return {
            'patterns': found,
            'summary': {
                'total_corrections': len(corrections),
                'unique_action_types': len(by_action),
                'unique_element_types': len(by_element),
                'pattern_count': len(found)
            }
        }

    def _find_common_fix_patterns(
        self,
        corrections: List[Correction]
    ) -> List[str]:
        """Return the (up to) three most frequent fix types in *corrections*."""
        counts: Dict[str, int] = defaultdict(int)
        for item in corrections:
            counts[item.correction_type.value] += 1
        ranked = sorted(counts.items(), key=lambda kv: kv[1], reverse=True)
        return [name for name, _ in ranked[:3]]

View File

@@ -0,0 +1,785 @@
"""
Correction Pack Service - Main service for correction pack management.
Combines repository, aggregator, and provides high-level operations
for CRUD, aggregation, export/import, and application of corrections.
"""
import json
import logging
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple
from .models import (
Correction,
CorrectionPack,
CorrectionKey,
CorrectionType,
CorrectionSource,
CorrectionStatus,
CorrectionPackMetadata,
generate_correction_id,
generate_pack_id
)
from .correction_repository import CorrectionRepository
from .aggregator import CorrectionAggregator, CrossWorkflowAggregator
logger = logging.getLogger(__name__)
class CorrectionPackService:
    """
    Main service for managing correction packs.

    Provides high-level operations combining the repository (persistence)
    and the aggregators (session import, cross-workflow analysis).
    """

    def __init__(
        self,
        storage_path: str = "data/correction_packs",
        training_data_path: str = "training_data"
    ):
        """
        Initialize the service.

        Args:
            storage_path: Path for correction pack storage
            training_data_path: Path to training data (for session import)
        """
        self.storage_path = Path(storage_path)
        self.training_data_path = Path(training_data_path)
        # Initialize components
        self.repository = CorrectionRepository(storage_path)
        self.aggregator = CorrectionAggregator()
        self.cross_workflow_aggregator = CrossWorkflowAggregator()
        logger.info(f"CorrectionPackService initialized (storage={storage_path})")

    # ========== Pack CRUD Operations ==========

    def list_packs(
        self,
        category: Optional[str] = None,
        tags: Optional[List[str]] = None
    ) -> List[Dict[str, Any]]:
        """
        List all correction packs with optional filtering.

        Args:
            category: Filter by category
            tags: Filter by tags (match any)

        Returns:
            List of pack summaries (not full pack dicts)
        """
        packs = self.repository.list_packs()
        # Filter by category
        if category:
            packs = [p for p in packs if p.metadata.category == category]
        # Filter by tags (a pack matches if it carries ANY requested tag)
        if tags:
            packs = [
                p for p in packs
                if any(tag in p.metadata.tags for tag in tags)
            ]
        # Return summaries
        return [
            {
                'id': p.id,
                'name': p.name,
                'description': p.description,
                'category': p.metadata.category,
                'tags': p.metadata.tags,
                'correction_count': p.correction_count,
                'success_rate': p.overall_success_rate,
                'created_at': p.created_at.isoformat(),
                'updated_at': p.updated_at.isoformat()
            }
            for p in packs
        ]

    def get_pack(self, pack_id: str) -> Optional[Dict[str, Any]]:
        """
        Get a pack by ID with full details.

        Args:
            pack_id: Pack ID

        Returns:
            Pack dict or None
        """
        pack = self.repository.get_pack(pack_id)
        if not pack:
            return None
        return pack.to_dict()

    def create_pack(
        self,
        name: str,
        description: str = "",
        category: str = "general",
        tags: Optional[List[str]] = None,
        author: str = ""
    ) -> Dict[str, Any]:
        """
        Create a new correction pack.

        Args:
            name: Pack name
            description: Pack description
            category: Pack category
            tags: Pack tags
            author: Pack author

        Returns:
            Created pack dict
        """
        metadata = {
            'category': category,
            'tags': tags or [],
            'author': author
        }
        pack = self.repository.create_pack(name, description, metadata)
        return pack.to_dict()

    def update_pack(
        self,
        pack_id: str,
        name: Optional[str] = None,
        description: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Update a pack's properties.

        Args:
            pack_id: Pack ID
            name: New name
            description: New description
            metadata: New metadata

        Returns:
            Updated pack dict or None
        """
        # Only explicitly-provided fields are forwarded, so None means
        # "leave unchanged" rather than "clear".
        updates = {}
        if name is not None:
            updates['name'] = name
        if description is not None:
            updates['description'] = description
        if metadata is not None:
            updates['metadata'] = metadata
        pack = self.repository.update_pack(pack_id, updates)
        if not pack:
            return None
        return pack.to_dict()

    def delete_pack(self, pack_id: str) -> bool:
        """
        Delete a pack.

        Args:
            pack_id: Pack ID

        Returns:
            True if deleted
        """
        return self.repository.delete_pack(pack_id)

    # ========== Correction Operations ==========

    def add_correction(
        self,
        pack_id: str,
        correction_data: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        Add a correction to a pack.

        Args:
            pack_id: Pack ID
            correction_data: Correction data dict

        Returns:
            Created correction dict or None (also None on any error)
        """
        try:
            # Build correction key
            key = CorrectionKey(
                action_type=correction_data.get('action_type', 'unknown'),
                element_type=correction_data.get('element_type', 'unknown'),
                failure_context=correction_data.get('failure_context', '')
            )
            # Determine correction type; unknown strings degrade to OTHER
            type_str = correction_data.get('correction_type', 'other')
            try:
                correction_type = CorrectionType(type_str)
            except ValueError:
                correction_type = CorrectionType.OTHER
            # Build source if provided
            source = None
            if 'source' in correction_data:
                source = CorrectionSource(
                    session_id=correction_data['source'].get('session_id', ''),
                    workflow_id=correction_data['source'].get('workflow_id'),
                    node_id=correction_data['source'].get('node_id')
                )
            # Create correction
            correction = Correction(
                id=generate_correction_id(),
                key=key,
                correction_type=correction_type,
                original_target=correction_data.get('original_target'),
                original_params=correction_data.get('original_params'),
                corrected_target=correction_data.get('corrected_target'),
                corrected_params=correction_data.get('corrected_params'),
                correction_description=correction_data.get('description', ''),
                source=source,
                tags=correction_data.get('tags', [])
            )
            if self.repository.add_correction(pack_id, correction):
                return correction.to_dict()
            return None
        except Exception as e:
            logger.error(f"Error adding correction: {e}")
            return None

    def add_correction_from_session(
        self,
        pack_id: str,
        session_correction: Dict[str, Any],
        session_id: str,
        workflow_id: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Add a correction from session user_corrections format.

        Args:
            pack_id: Pack ID
            session_correction: Correction dict from TrainingSession.user_corrections
            session_id: Session ID
            workflow_id: Workflow ID

        Returns:
            Created correction dict or None
        """
        # NOTE: the caller's dict is enriched in place (source + renamed
        # legacy fields); callers should not rely on it staying unchanged.
        session_correction['source'] = {
            'session_id': session_id,
            'workflow_id': workflow_id,
            'node_id': session_correction.get('node_id')
        }
        # Map old format fields to new format (legacy names win only when
        # the new name is absent)
        if 'type' in session_correction and 'action_type' not in session_correction:
            session_correction['action_type'] = session_correction['type']
        if 'reason' in session_correction and 'failure_context' not in session_correction:
            session_correction['failure_context'] = session_correction['reason']
        if 'new_target' in session_correction and 'corrected_target' not in session_correction:
            session_correction['corrected_target'] = session_correction['new_target']
        if 'new_params' in session_correction and 'corrected_params' not in session_correction:
            session_correction['corrected_params'] = session_correction['new_params']
        return self.add_correction(pack_id, session_correction)

    def update_correction(
        self,
        pack_id: str,
        correction_id: str,
        updates: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        Update a correction.

        Args:
            pack_id: Pack ID
            correction_id: Correction ID
            updates: Fields to update

        Returns:
            Updated correction dict or None
        """
        correction = self.repository.update_correction(pack_id, correction_id, updates)
        if not correction:
            return None
        return correction.to_dict()

    def remove_correction(self, pack_id: str, correction_id: str) -> bool:
        """
        Remove a correction from a pack.

        Args:
            pack_id: Pack ID
            correction_id: Correction ID

        Returns:
            True if removed
        """
        return self.repository.remove_correction(pack_id, correction_id)

    # ========== Aggregation Operations ==========

    def aggregate_from_sessions(
        self,
        pack_id: str,
        session_files: Optional[List[str]] = None,
        workflow_id: Optional[str] = None,
        min_occurrences: int = 1,
        min_success_rate: float = 0.0
    ) -> Dict[str, Any]:
        """
        Aggregate corrections from training sessions into a pack.

        Args:
            pack_id: Pack ID to add corrections to
            session_files: Specific session files to process (None = all)
            workflow_id: Filter by workflow ID
            min_occurrences: Minimum occurrences to include
            min_success_rate: Minimum success rate to include

        Returns:
            Aggregation result summary
        """
        # Collect all session corrections
        all_corrections = []
        if session_files is None:
            # Find all session files
            session_files = list(self.training_data_path.glob("session_*.json"))
        else:
            session_files = [Path(f) for f in session_files]
        sessions_processed = 0
        for session_file in session_files:
            try:
                with open(session_file, 'r', encoding='utf-8') as f:
                    session_data = json.load(f)
                # Filter by workflow if specified
                if workflow_id and session_data.get('workflow_id') != workflow_id:
                    continue
                user_corrections = session_data.get('user_corrections', [])
                session_id = session_data.get('session_id', session_file.stem)
                # Stamp each correction with its session/workflow provenance
                for uc in user_corrections:
                    uc['session_id'] = session_id
                    uc['workflow_id'] = session_data.get('workflow_id')
                    all_corrections.append(uc)
                sessions_processed += 1
            except Exception as e:
                # Best-effort: skip unreadable/corrupt session files
                logger.warning(f"Error reading session file {session_file}: {e}")
        # Configure aggregator.
        # NOTE(review): this permanently mutates the shared aggregator's
        # thresholds; later calls that rely on the defaults will see these
        # values instead -- confirm whether that is intended.
        self.aggregator.min_occurrences = min_occurrences
        self.aggregator.min_success_rate = min_success_rate
        # Aggregate
        aggregated = self.aggregator.aggregate_from_sessions(all_corrections, workflow_id)
        # Add to pack
        added_count = 0
        for correction in aggregated:
            if self.repository.add_correction(pack_id, correction):
                added_count += 1
        return {
            'sessions_processed': sessions_processed,
            'corrections_found': len(all_corrections),
            'corrections_aggregated': len(aggregated),
            'corrections_added': added_count,
            'statistics': self.aggregator.calculate_statistics(aggregated)
        }

    def aggregate_cross_workflow(
        self,
        pack_id: str,
        min_workflows: int = 2
    ) -> Dict[str, Any]:
        """
        Find and aggregate corrections common across multiple workflows.

        Args:
            pack_id: Pack ID to add corrections to
            min_workflows: Minimum number of workflows a correction must appear in

        Returns:
            Aggregation result summary
        """
        import copy  # local import: only needed for the defensive copies below
        # Get existing corrections from all packs
        all_corrections = []
        for pack in self.repository.list_packs():
            all_corrections.extend(pack.corrections.values())
        # Find common corrections
        common = self.cross_workflow_aggregator.find_common_corrections(
            all_corrections, min_workflows
        )
        # Add to pack
        added_count = 0
        for correction, workflow_ids in common:
            # BUGFIX: _merge_corrections() returns the ORIGINAL object for
            # single-element groups, so mutating `correction` in place would
            # corrupt the correction stored in its source pack (and repeated
            # runs would keep appending tags/suffixes). Work on a shallow
            # copy with its own tags list instead.
            correction = copy.copy(correction)
            # Tag with source workflows (capped at 5 to limit tag growth)
            correction.tags = correction.tags + [f"wf:{wid}" for wid in workflow_ids[:5]]
            correction.correction_description = (
                correction.correction_description
                + f" (common in {len(workflow_ids)} workflows)"
            )
            if self.repository.add_correction(pack_id, correction):
                added_count += 1
        # Detect patterns
        patterns = self.cross_workflow_aggregator.detect_patterns(all_corrections)
        return {
            'total_corrections_analyzed': len(all_corrections),
            'common_corrections_found': len(common),
            'corrections_added': added_count,
            'patterns': patterns
        }

    # ========== Search and Apply Operations ==========

    def find_applicable_corrections(
        self,
        action_type: str,
        element_type: str,
        failure_context: str,
        pack_ids: Optional[List[str]] = None,
        min_confidence: float = 0.3
    ) -> List[Dict[str, Any]]:
        """
        Find corrections applicable to a failure context.

        Args:
            action_type: Type of action that failed
            element_type: Type of element involved
            failure_context: Description of failure
            pack_ids: Specific packs to search (None = all)
            min_confidence: Minimum confidence score

        Returns:
            List of applicable corrections with pack info
        """
        results = self.repository.find_by_context(action_type, element_type, failure_context)
        # Filter by pack IDs if specified
        if pack_ids:
            results = [(pid, c) for pid, c in results if pid in pack_ids]
        # Filter by confidence
        results = [(pid, c) for pid, c in results if c.confidence_score >= min_confidence]
        # Format results
        return [
            {
                'pack_id': pack_id,
                'correction': correction.to_dict()
            }
            for pack_id, correction in results
        ]

    def apply_correction(
        self,
        pack_id: str,
        correction_id: str,
        success: bool,
        application_context: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Record an application of a correction.

        Args:
            pack_id: Pack ID
            correction_id: Correction ID
            success: Whether the application was successful
            application_context: Optional context information (currently unused)

        Returns:
            True if recorded
        """
        # NOTE(review): the repository's return value is ignored, so this
        # always reports True even if the pack/correction was not found --
        # confirm whether record_application's result should be propagated.
        self.repository.record_application(pack_id, correction_id, success)
        logger.info(
            f"Recorded correction application: pack={pack_id}, "
            f"correction={correction_id}, success={success}"
        )
        return True

    def search_corrections(
        self,
        query: str,
        pack_id: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """
        Search corrections by text query.

        Args:
            query: Search query
            pack_id: Optional pack ID to search in

        Returns:
            List of matching corrections with pack info
        """
        results = self.repository.search_corrections(query, pack_id)
        return [
            {
                'pack_id': pid,
                'correction': correction.to_dict()
            }
            for pid, correction in results
        ]

    # ========== Export/Import Operations ==========

    def export_pack(
        self,
        pack_id: str,
        format: str = 'json',
        include_stats: bool = True
    ) -> Optional[str]:
        """
        Export a pack to JSON or YAML.

        Args:
            pack_id: Pack ID
            format: Export format ('json' or 'yaml')
            include_stats: Whether to include statistics
                (NOTE(review): currently not forwarded to the repository,
                so it has no effect -- confirm intended behavior)

        Returns:
            Exported content string or None
        """
        return self.repository.export_pack(pack_id, format)

    def import_pack(
        self,
        content: str,
        format: str = 'json',
        merge_strategy: str = 'create_new'
    ) -> Optional[Dict[str, Any]]:
        """
        Import a pack from JSON or YAML.

        Args:
            content: Content to import
            format: Import format ('json' or 'yaml')
            merge_strategy: How to handle existing pack

        Returns:
            Imported pack dict or None
        """
        pack = self.repository.import_pack(content, format, merge_strategy)
        if not pack:
            return None
        return pack.to_dict()

    def export_pack_to_file(
        self,
        pack_id: str,
        file_path: str,
        format: str = 'json'
    ) -> bool:
        """
        Export a pack to a file.

        Args:
            pack_id: Pack ID
            file_path: Output file path
            format: Export format

        Returns:
            True if successful
        """
        content = self.export_pack(pack_id, format)
        if not content:
            return False
        try:
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(content)
            logger.info(f"Exported pack {pack_id} to {file_path}")
            return True
        except Exception as e:
            logger.error(f"Error exporting pack to file: {e}")
            return False

    def import_pack_from_file(
        self,
        file_path: str,
        format: Optional[str] = None,
        merge_strategy: str = 'create_new'
    ) -> Optional[Dict[str, Any]]:
        """
        Import a pack from a file.

        Args:
            file_path: Input file path
            format: Import format (auto-detected from the extension if None)
            merge_strategy: How to handle existing pack

        Returns:
            Imported pack dict or None
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
            # Auto-detect format from the file extension
            if format is None:
                if file_path.endswith('.yaml') or file_path.endswith('.yml'):
                    format = 'yaml'
                else:
                    format = 'json'
            result = self.import_pack(content, format, merge_strategy)
            if result:
                logger.info(f"Imported pack from {file_path}")
            return result
        except Exception as e:
            logger.error(f"Error importing pack from file: {e}")
            return None

    # ========== Version Operations ==========

    def create_version(
        self,
        pack_id: str,
        description: str = ""
    ) -> int:
        """
        Create a version snapshot of a pack.

        Args:
            pack_id: Pack ID
            description: Version description

        Returns:
            Version number (-1 on error)
        """
        return self.repository.create_version(pack_id, description)

    def list_versions(self, pack_id: str) -> List[Dict[str, Any]]:
        """
        List all versions of a pack.

        Args:
            pack_id: Pack ID

        Returns:
            List of version info dicts
        """
        return self.repository.list_versions(pack_id)

    def rollback_pack(self, pack_id: str, version: int) -> bool:
        """
        Rollback a pack to a previous version.

        Args:
            pack_id: Pack ID
            version: Version number

        Returns:
            True if successful
        """
        return self.repository.rollback_pack(pack_id, version)

    # ========== Statistics and Analysis ==========

    def get_pack_statistics(self, pack_id: str) -> Optional[Dict[str, Any]]:
        """
        Get detailed statistics for a pack.

        Args:
            pack_id: Pack ID

        Returns:
            Statistics dict or None
        """
        pack = self.repository.get_pack(pack_id)
        if not pack:
            return None
        return pack.get_statistics()

    def get_global_statistics(self) -> Dict[str, Any]:
        """
        Get global statistics across all packs.

        Returns:
            Global statistics dict
        """
        packs = self.repository.list_packs()
        total_corrections = sum(p.correction_count for p in packs)
        total_applications = sum(p.total_applications for p in packs)
        # Aggregate success rate across all corrections of all packs
        total_success = sum(
            sum(c.success_count for c in p.corrections.values())
            for p in packs
        )
        return {
            'total_packs': len(packs),
            'total_corrections': total_corrections,
            'total_applications': total_applications,
            # Guard against division by zero when nothing was ever applied
            'overall_success_rate': total_success / total_applications if total_applications > 0 else 0.0,
            'packs_by_category': self._group_packs_by_category(packs),
            'top_corrections': self._get_top_corrections(packs, 10)
        }

    def _group_packs_by_category(
        self,
        packs: List[CorrectionPack]
    ) -> Dict[str, int]:
        """Count packs per category."""
        by_category = {}
        for pack in packs:
            cat = pack.metadata.category
            by_category[cat] = by_category.get(cat, 0) + 1
        return by_category

    def _get_top_corrections(
        self,
        packs: List[CorrectionPack],
        limit: int = 10
    ) -> List[Dict[str, Any]]:
        """Get the top *limit* corrections across packs by confidence score."""
        all_corrections = []
        for pack in packs:
            for correction in pack.corrections.values():
                all_corrections.append({
                    'pack_id': pack.id,
                    'pack_name': pack.name,
                    'correction': correction
                })
        # Sort by confidence
        all_corrections.sort(key=lambda x: x['correction'].confidence_score, reverse=True)
        # Return top N as lightweight summaries
        return [
            {
                'pack_id': item['pack_id'],
                'pack_name': item['pack_name'],
                'correction_id': item['correction'].id,
                'action_type': item['correction'].key.action_type,
                'confidence': item['correction'].confidence_score,
                'success_rate': item['correction'].success_rate,
                'applications': item['correction'].success_count + item['correction'].failure_count
            }
            for item in all_corrections[:limit]
        ]

View File

@@ -0,0 +1,643 @@
"""
Correction Repository - JSON storage for corrections and packs.
Provides atomic file operations for safe persistence.
"""
import json
import logging
import os
import shutil
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any
from .models import (
Correction,
CorrectionPack,
CorrectionKey,
generate_pack_id
)
logger = logging.getLogger(__name__)
class CorrectionRepository:
    """
    Repository for storing and retrieving correction packs.

    All packs live in a single ``packs.json`` written atomically
    (temp file + rename). A per-key-hash index of correction IDs is
    maintained in memory and rebuilt on load.
    """
    def __init__(self, storage_path: str = "data/correction_packs"):
        """
        Initialize the repository.
        Args:
            storage_path: Base path for storing correction data
        """
        self.storage_path = Path(storage_path)
        self.packs_file = self.storage_path / "packs.json"
        self.corrections_dir = self.storage_path / "corrections"
        self.versions_dir = self.storage_path / "versions"
        # Ensure directories exist
        self._ensure_directories()
        # In-memory cache of every pack, keyed by pack ID.
        self._packs: Dict[str, CorrectionPack] = {}
        self._key_index: Dict[str, List[str]] = {}  # key_hash -> [correction_ids]
        # Load existing data
        self._load_packs()
        self._rebuild_index()
    def _ensure_directories(self):
        """Ensure all required directories exist."""
        self.storage_path.mkdir(parents=True, exist_ok=True)
        self.corrections_dir.mkdir(parents=True, exist_ok=True)
        self.versions_dir.mkdir(parents=True, exist_ok=True)
    # ========== Pack Operations ==========
    def list_packs(self) -> List[CorrectionPack]:
        """Get all correction packs."""
        return list(self._packs.values())
    def get_pack(self, pack_id: str) -> Optional[CorrectionPack]:
        """Get a pack by ID."""
        return self._packs.get(pack_id)
    def create_pack(self, name: str, description: str = "",
                    metadata: Optional[Dict[str, Any]] = None) -> CorrectionPack:
        """
        Create a new correction pack and persist it immediately.
        Args:
            name: Pack name
            description: Pack description
            metadata: Optional metadata dict (version/author/category/tags)
        Returns:
            Created pack
        """
        pack_id = generate_pack_id()
        pack = CorrectionPack(
            id=pack_id,
            name=name,
            description=description
        )
        if metadata:
            # NOTE(review): metadata.get('description') is never copied into
            # pack.metadata.description (here nor in update_pack) — confirm
            # whether that omission is intentional.
            pack.metadata.version = metadata.get('version', '1.0.0')
            pack.metadata.author = metadata.get('author', '')
            pack.metadata.category = metadata.get('category', 'general')
            pack.metadata.tags = metadata.get('tags', [])
        self._packs[pack_id] = pack
        self._save_packs()
        logger.info(f"Created correction pack: {pack_id} ({name})")
        return pack
    def update_pack(self, pack_id: str, updates: Dict[str, Any]) -> Optional[CorrectionPack]:
        """
        Update a pack's properties.
        Args:
            pack_id: Pack ID
            updates: Dict of fields to update (name, description, metadata)
        Returns:
            Updated pack or None if not found
        """
        pack = self._packs.get(pack_id)
        if not pack:
            return None
        # Update allowed fields only; unknown keys are silently ignored.
        if 'name' in updates:
            pack.name = updates['name']
        if 'description' in updates:
            pack.description = updates['description']
        if 'metadata' in updates:
            metadata = updates['metadata']
            if 'version' in metadata:
                pack.metadata.version = metadata['version']
            if 'author' in metadata:
                pack.metadata.author = metadata['author']
            if 'category' in metadata:
                pack.metadata.category = metadata['category']
            if 'tags' in metadata:
                pack.metadata.tags = metadata['tags']
        pack.updated_at = datetime.now()
        self._save_packs()
        logger.info(f"Updated correction pack: {pack_id}")
        return pack
    def delete_pack(self, pack_id: str) -> bool:
        """
        Delete a pack, its index entries, and its version history.
        Args:
            pack_id: Pack ID
        Returns:
            True if deleted, False if not found
        """
        if pack_id not in self._packs:
            return False
        pack = self._packs[pack_id]
        # Remove this pack's corrections from the key index
        # (empty hash buckets are left in place — harmless).
        for correction in pack.corrections.values():
            key_hash = correction.key.to_hash()
            if key_hash in self._key_index:
                if correction.id in self._key_index[key_hash]:
                    self._key_index[key_hash].remove(correction.id)
        del self._packs[pack_id]
        self._save_packs()
        # Remove version history
        version_dir = self.versions_dir / pack_id
        if version_dir.exists():
            shutil.rmtree(version_dir)
        logger.info(f"Deleted correction pack: {pack_id}")
        return True
    # ========== Correction Operations ==========
    def add_correction(self, pack_id: str, correction: Correction) -> bool:
        """
        Add a correction to a pack and index it by key hash.
        Args:
            pack_id: Pack ID
            correction: Correction to add
        Returns:
            True if added, False if pack not found or correction exists
        """
        pack = self._packs.get(pack_id)
        if not pack:
            return False
        if not pack.add_correction(correction):
            return False
        # Update index
        key_hash = correction.key.to_hash()
        if key_hash not in self._key_index:
            self._key_index[key_hash] = []
        if correction.id not in self._key_index[key_hash]:
            self._key_index[key_hash].append(correction.id)
        self._save_packs()
        logger.info(f"Added correction {correction.id} to pack {pack_id}")
        return True
    def remove_correction(self, pack_id: str, correction_id: str) -> bool:
        """
        Remove a correction from a pack and from the key index.
        Args:
            pack_id: Pack ID
            correction_id: Correction ID
        Returns:
            True if removed, False if not found
        """
        pack = self._packs.get(pack_id)
        if not pack:
            return False
        correction = pack.get_correction(correction_id)
        if not correction:
            return False
        # Update index before the correction object goes away.
        key_hash = correction.key.to_hash()
        if key_hash in self._key_index:
            if correction_id in self._key_index[key_hash]:
                self._key_index[key_hash].remove(correction_id)
        pack.remove_correction(correction_id)
        self._save_packs()
        logger.info(f"Removed correction {correction_id} from pack {pack_id}")
        return True
    def update_correction(self, pack_id: str, correction_id: str,
                          updates: Dict[str, Any]) -> Optional[Correction]:
        """
        Update a correction's properties.
        Args:
            pack_id: Pack ID
            correction_id: Correction ID
            updates: Dict of fields to update (corrected_target,
                corrected_params, correction_description, status, tags)
        Returns:
            Updated correction or None if not found
        """
        pack = self._packs.get(pack_id)
        if not pack:
            return None
        correction = pack.get_correction(correction_id)
        if not correction:
            return None
        # Update allowed fields
        if 'corrected_target' in updates:
            correction.corrected_target = updates['corrected_target']
        if 'corrected_params' in updates:
            correction.corrected_params = updates['corrected_params']
        if 'correction_description' in updates:
            correction.correction_description = updates['correction_description']
        if 'status' in updates:
            # Local import keeps module-level imports minimal; raises
            # ValueError on an unknown status string.
            from .models import CorrectionStatus
            correction.status = CorrectionStatus(updates['status'])
        if 'tags' in updates:
            correction.tags = updates['tags']
        correction.updated_at = datetime.now()
        pack.updated_at = datetime.now()
        self._save_packs()
        logger.info(f"Updated correction {correction_id} in pack {pack_id}")
        return correction
    def record_application(self, pack_id: str, correction_id: str, success: bool):
        """
        Record an application of a correction and persist the new counters.

        Silently does nothing when the pack or correction is unknown.
        Args:
            pack_id: Pack ID
            correction_id: Correction ID
            success: Whether the application was successful
        """
        pack = self._packs.get(pack_id)
        if not pack:
            return
        correction = pack.get_correction(correction_id)
        if not correction:
            return
        correction.record_application(success)
        pack.updated_at = datetime.now()
        self._save_packs()
    # ========== Search Operations ==========
    def find_by_key_hash(self, key_hash: str) -> List[tuple]:
        """
        Find all corrections matching a key hash, best confidence first.

        NOTE(review): this scans every pack instead of consulting
        self._key_index (which maps hash -> correction IDs only, with no
        pack reference) — the index is maintained but never used for
        lookups; confirm whether it should be extended or dropped.
        Args:
            key_hash: MD5 hash of correction key
        Returns:
            List of (pack_id, correction) tuples
        """
        results = []
        for pack_id, pack in self._packs.items():
            for correction in pack.find_by_key_hash(key_hash):
                results.append((pack_id, correction))
        # Sort by confidence score
        results.sort(key=lambda x: x[1].confidence_score, reverse=True)
        return results
    def find_by_context(self, action_type: str, element_type: str,
                        failure_context: str) -> List[tuple]:
        """
        Find corrections matching a failure context (hashes the triple
        and delegates to find_by_key_hash).
        Args:
            action_type: Type of action that failed
            element_type: Type of element involved
            failure_context: Description of failure
        Returns:
            List of (pack_id, correction) tuples
        """
        key = CorrectionKey(action_type, element_type, failure_context)
        return self.find_by_key_hash(key.to_hash())
    def search_corrections(self, query: str, pack_id: Optional[str] = None) -> List[tuple]:
        """
        Case-insensitive substring search over corrections.

        Matches against the description, tags, and key components;
        each correction appears at most once in the result.
        Args:
            query: Search query (searches in description, tags, etc.)
            pack_id: Optional pack ID to restrict the search to
        Returns:
            List of (pack_id, correction) tuples
        """
        results = []
        query_lower = query.lower()
        packs_to_search = [self._packs[pack_id]] if pack_id and pack_id in self._packs else self._packs.values()
        for pack in packs_to_search:
            for correction in pack.corrections.values():
                # Search in description
                if query_lower in correction.correction_description.lower():
                    results.append((pack.id, correction))
                    continue
                # Search in tags
                if any(query_lower in tag.lower() for tag in correction.tags):
                    results.append((pack.id, correction))
                    continue
                # Search in key components
                if (query_lower in correction.key.action_type.lower() or
                    query_lower in correction.key.element_type.lower() or
                    query_lower in correction.key.failure_context.lower()):
                    results.append((pack.id, correction))
        return results
    # ========== Version Operations ==========
    def create_version(self, pack_id: str, description: str = "") -> int:
        """
        Create a version snapshot of a pack.

        The snapshot is written under the *current* version number and
        the pack's counter is then advanced, so the returned number is
        the one the snapshot file carries.
        Args:
            pack_id: Pack ID
            description: Version description
        Returns:
            Version number, or -1 if the pack does not exist
        """
        pack = self._packs.get(pack_id)
        if not pack:
            return -1
        # Create version directory
        pack_versions = self.versions_dir / pack_id
        pack_versions.mkdir(parents=True, exist_ok=True)
        # Increment version
        version = pack.current_version
        pack.current_version += 1
        # Save snapshot
        version_file = pack_versions / f"v{version}.json"
        version_data = {
            'version': version,
            'description': description,
            'created_at': datetime.now().isoformat(),
            'pack_snapshot': pack.to_dict()
        }
        self._atomic_write(version_file, version_data)
        self._save_packs()
        logger.info(f"Created version {version} for pack {pack_id}")
        return version
    def list_versions(self, pack_id: str) -> List[Dict[str, Any]]:
        """
        List all versions of a pack by reading its snapshot files.

        Unreadable snapshot files are logged and skipped.
        Args:
            pack_id: Pack ID
        Returns:
            List of version info dicts
        """
        versions = []
        pack_versions = self.versions_dir / pack_id
        if not pack_versions.exists():
            return versions
        # NOTE(review): glob sorting is lexicographic, so v10 sorts
        # before v2 once versions reach double digits — confirm whether
        # callers rely on numeric ordering.
        for version_file in sorted(pack_versions.glob("v*.json")):
            try:
                with open(version_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                versions.append({
                    'version': data.get('version'),
                    'description': data.get('description', ''),
                    'created_at': data.get('created_at'),
                    'correction_count': len(data.get('pack_snapshot', {}).get('corrections', {}))
                })
            except Exception as e:
                logger.warning(f"Error reading version file {version_file}: {e}")
        return versions
    def rollback_pack(self, pack_id: str, version: int) -> bool:
        """
        Rollback a pack to a previous version.

        The restored pack's version counter continues from the current
        pack's counter (it does not reset to the snapshot's value).
        Args:
            pack_id: Pack ID
            version: Version number to rollback to
        Returns:
            True if successful, False otherwise
        """
        version_file = self.versions_dir / pack_id / f"v{version}.json"
        if not version_file.exists():
            logger.warning(f"Version {version} not found for pack {pack_id}")
            return False
        try:
            with open(version_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Restore pack from snapshot
            snapshot = data.get('pack_snapshot', {})
            pack = CorrectionPack.from_dict(snapshot)
            # Update version number to continue from current
            if pack_id in self._packs:
                pack.current_version = self._packs[pack_id].current_version + 1
            self._packs[pack_id] = pack
            self._rebuild_index()
            self._save_packs()
            logger.info(f"Rolled back pack {pack_id} to version {version}")
            return True
        except Exception as e:
            logger.error(f"Error rolling back pack {pack_id} to version {version}: {e}")
            return False
    # ========== Persistence ==========
    def _load_packs(self):
        """Load all packs from packs.json; individual bad packs are skipped."""
        if not self.packs_file.exists():
            return
        try:
            with open(self.packs_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            for pack_id, pack_data in data.items():
                try:
                    self._packs[pack_id] = CorrectionPack.from_dict(pack_data)
                except Exception as e:
                    logger.warning(f"Error loading pack {pack_id}: {e}")
            logger.info(f"Loaded {len(self._packs)} correction packs")
        except Exception as e:
            logger.error(f"Error loading packs file: {e}")
    def _save_packs(self):
        """Save all packs to storage using atomic write; errors are logged only."""
        try:
            data = {
                pack_id: pack.to_dict()
                for pack_id, pack in self._packs.items()
            }
            self._atomic_write(self.packs_file, data)
        except Exception as e:
            logger.error(f"Error saving packs: {e}")
    def _atomic_write(self, file_path: Path, data: Dict[str, Any]):
        """
        Write data to file atomically (temp file + rename).

        A same-directory .tmp file is written first, then renamed over
        the target so readers never see a partial file.
        Args:
            file_path: Target file path
            data: Data to write
        """
        temp_file = file_path.with_suffix('.tmp')
        try:
            with open(temp_file, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2, ensure_ascii=False)
            # Atomic rename
            temp_file.replace(file_path)
        except Exception as e:
            # Clean up temp file on error
            if temp_file.exists():
                temp_file.unlink()
            # NOTE(review): a bare `raise` would preserve the original
            # traceback origin more faithfully than `raise e`.
            raise e
    def _rebuild_index(self):
        """Rebuild the key hash index from all packs (run after load/rollback)."""
        self._key_index.clear()
        for pack in self._packs.values():
            for correction in pack.corrections.values():
                key_hash = correction.key.to_hash()
                if key_hash not in self._key_index:
                    self._key_index[key_hash] = []
                if correction.id not in self._key_index[key_hash]:
                    self._key_index[key_hash].append(correction.id)
    # ========== Export/Import ==========
    def export_pack(self, pack_id: str, format: str = 'json') -> Optional[str]:
        """
        Export a pack to JSON or YAML format.

        Falls back to JSON when PyYAML is not installed.
        Args:
            pack_id: Pack ID
            format: Export format ('json' or 'yaml')
        Returns:
            Exported content as string, or None if pack not found
        """
        pack = self._packs.get(pack_id)
        if not pack:
            return None
        export_data = {
            'export_format': 'correction_pack',
            'export_version': '1.0',
            'exported_at': datetime.now().isoformat(),
            'pack': pack.to_dict()
        }
        if format == 'yaml':
            try:
                import yaml
                return yaml.dump(export_data, default_flow_style=False, allow_unicode=True)
            except ImportError:
                logger.warning("PyYAML not installed, falling back to JSON")
                return json.dumps(export_data, indent=2, ensure_ascii=False)
        else:
            return json.dumps(export_data, indent=2, ensure_ascii=False)
    def import_pack(self, content: str, format: str = 'json',
                    merge_strategy: str = 'create_new') -> Optional[CorrectionPack]:
        """
        Import a pack from JSON or YAML content.
        Args:
            content: Content to import
            format: Import format ('json' or 'yaml')
            merge_strategy: How to handle existing pack:
                - 'create_new': Always create new pack with new ID
                - 'merge': Merge into existing pack if same ID
                - 'replace': Replace existing pack if same ID
        Returns:
            Imported pack, or None on error
        """
        try:
            if format == 'yaml':
                try:
                    # safe_load: never executes arbitrary YAML tags.
                    import yaml
                    data = yaml.safe_load(content)
                except ImportError:
                    logger.error("PyYAML not installed")
                    return None
            else:
                data = json.loads(content)
            # Validate format
            if data.get('export_format') != 'correction_pack':
                logger.warning("Invalid export format")
                return None
            pack_data = data.get('pack', {})
            imported_pack = CorrectionPack.from_dict(pack_data)
            # Handle merge strategy; 'merge' with an unknown pack ID and
            # any unrecognized strategy both fall through to 'create_new'.
            if merge_strategy == 'create_new':
                imported_pack.id = generate_pack_id()
                self._packs[imported_pack.id] = imported_pack
            elif merge_strategy == 'merge' and imported_pack.id in self._packs:
                existing_pack = self._packs[imported_pack.id]
                existing_pack.merge(imported_pack)
                imported_pack = existing_pack
            elif merge_strategy == 'replace':
                self._packs[imported_pack.id] = imported_pack
            else:
                # Default: create new
                imported_pack.id = generate_pack_id()
                self._packs[imported_pack.id] = imported_pack
            self._rebuild_index()
            self._save_packs()
            logger.info(f"Imported pack: {imported_pack.id} ({imported_pack.name})")
            return imported_pack
        except Exception as e:
            logger.error(f"Error importing pack: {e}")
            return None

480
core/corrections/models.py Normal file
View File

@@ -0,0 +1,480 @@
"""
Data models for the Correction Packs system.
Defines CorrectionKey, Correction, CorrectionPack and related types.
"""
import hashlib
import uuid
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum
from typing import Dict, List, Any, Optional
class CorrectionType(Enum):
    """Closed set of correction kinds; values are the serialized form."""
    TARGET_CHANGE = "target_change"  # Changed target selector
    PARAMETER_CHANGE = "parameter_change"  # Changed action parameters
    ACTION_REPLACE = "action_replace"  # Replaced action type
    WAIT_ADDED = "wait_added"  # Added wait before action
    RETRY_CONFIG = "retry_config"  # Changed retry configuration
    FALLBACK_ADDED = "fallback_added"  # Added fallback action
    COORDINATES_ADJUST = "coordinates_adjust"  # Adjusted click coordinates
    TIMING_ADJUST = "timing_adjust"  # Adjusted timing/delay
    OTHER = "other"  # Catch-all; also the from_dict default
class CorrectionStatus(Enum):
    """Lifecycle status of a correction; only ACTIVE ones are applied."""
    ACTIVE = "active"  # Correction is active and can be applied
    DEPRECATED = "deprecated"  # Correction is deprecated
    DISABLED = "disabled"  # Correction is manually disabled
@dataclass
class CorrectionKey:
    """
    Identity key grouping similar corrections across workflows.

    The (action_type, element_type, failure_context) triple is hashed
    with MD5 (truncated to 16 hex chars) for deduplication — the hash
    is an identifier, not a security measure.
    """
    action_type: str
    element_type: str
    failure_context: str

    def to_hash(self) -> str:
        """Return the 16-character MD5 digest of this key."""
        raw = "|".join((self.action_type, self.element_type, self.failure_context))
        return hashlib.md5(raw.encode()).hexdigest()[:16]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the key fields plus the derived hash."""
        payload: Dict[str, Any] = {
            'action_type': self.action_type,
            'element_type': self.element_type,
            'failure_context': self.failure_context,
        }
        payload['hash'] = self.to_hash()
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionKey':
        """Rebuild a key from a dict; missing fields get safe defaults."""
        return cls(
            data.get('action_type', 'unknown'),
            data.get('element_type', 'unknown'),
            data.get('failure_context', ''),
        )
@dataclass
class CorrectionSource:
    """Provenance of a correction: which session/workflow/node produced it."""
    session_id: str
    workflow_id: Optional[str] = None
    node_id: Optional[str] = None
    timestamp: datetime = field(default_factory=datetime.now)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (timestamp as ISO-8601)."""
        return {
            'session_id': self.session_id,
            'workflow_id': self.workflow_id,
            'node_id': self.node_id,
            'timestamp': self.timestamp.isoformat()
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionSource':
        """Deserialize; a missing/None timestamp falls back to now()."""
        raw_ts = data.get('timestamp')
        if isinstance(raw_ts, str):
            ts = datetime.fromisoformat(raw_ts)
        else:
            ts = raw_ts if raw_ts is not None else datetime.now()
        return cls(
            session_id=data.get('session_id', ''),
            workflow_id=data.get('workflow_id'),
            node_id=data.get('node_id'),
            timestamp=ts
        )
@dataclass
class Correction:
    """
    Individual correction record.

    Holds the failed original context, the replacement that fixed it,
    provenance, and success/failure counters from which a success rate
    and a usage-damped confidence score are derived.
    """
    id: str
    key: CorrectionKey
    correction_type: CorrectionType
    # Original context (what failed)
    original_target: Optional[Dict[str, Any]] = None
    original_params: Optional[Dict[str, Any]] = None
    # Correction applied
    corrected_target: Optional[Dict[str, Any]] = None
    corrected_params: Optional[Dict[str, Any]] = None
    correction_description: str = ""
    # Provenance
    source: Optional[CorrectionSource] = None
    # Application statistics
    success_count: int = 0
    failure_count: int = 0
    last_applied: Optional[datetime] = None
    # Lifecycle status
    status: CorrectionStatus = CorrectionStatus.ACTIVE
    # Record metadata
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    tags: List[str] = field(default_factory=list)

    @property
    def success_rate(self) -> float:
        """Fraction of applications that succeeded (0.0 when never applied)."""
        applications = self.success_count + self.failure_count
        return self.success_count / applications if applications else 0.0

    @property
    def confidence_score(self) -> float:
        """Success rate damped by usage volume; saturates at 100 applications."""
        applications = self.success_count + self.failure_count
        if not applications:
            return 0.0
        saturation = min(1.0, applications / 100)
        return self.success_rate * (0.5 + 0.5 * saturation)

    def record_application(self, success: bool):
        """Bump the matching counter and refresh the usage timestamps."""
        if success:
            self.success_count += 1
        else:
            self.failure_count += 1
        self.last_applied = datetime.now()
        self.updated_at = datetime.now()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict, including derived metrics."""
        return {
            'id': self.id,
            'key': self.key.to_dict(),
            'correction_type': self.correction_type.value,
            'original_target': self.original_target,
            'original_params': self.original_params,
            'corrected_target': self.corrected_target,
            'corrected_params': self.corrected_params,
            'correction_description': self.correction_description,
            'source': None if self.source is None else self.source.to_dict(),
            'success_count': self.success_count,
            'failure_count': self.failure_count,
            # Derived values are embedded for consumers; from_dict ignores them.
            'success_rate': self.success_rate,
            'confidence_score': self.confidence_score,
            'last_applied': None if self.last_applied is None else self.last_applied.isoformat(),
            'status': self.status.value,
            'created_at': self.created_at.isoformat(),
            'updated_at': self.updated_at.isoformat(),
            'tags': self.tags
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Correction':
        """
        Deserialize a correction.

        Timestamps accept ISO strings or datetime objects; missing
        created_at/updated_at default to now(), last_applied stays None.
        """
        def parse_dt(raw, default=None):
            # ISO string -> datetime; None -> default; datetime passes through.
            if isinstance(raw, str):
                return datetime.fromisoformat(raw)
            return default if raw is None else raw

        source_data = data.get('source')
        return cls(
            id=data.get('id', str(uuid.uuid4())),
            key=CorrectionKey.from_dict(data.get('key', {})),
            correction_type=CorrectionType(data.get('correction_type', 'other')),
            original_target=data.get('original_target'),
            original_params=data.get('original_params'),
            corrected_target=data.get('corrected_target'),
            corrected_params=data.get('corrected_params'),
            correction_description=data.get('correction_description', ''),
            source=CorrectionSource.from_dict(source_data) if source_data else None,
            success_count=data.get('success_count', 0),
            failure_count=data.get('failure_count', 0),
            last_applied=parse_dt(data.get('last_applied')),
            status=CorrectionStatus(data.get('status', 'active')),
            created_at=parse_dt(data.get('created_at'), datetime.now()),
            updated_at=parse_dt(data.get('updated_at'), datetime.now()),
            tags=data.get('tags', [])
        )
@dataclass
class CorrectionPackMetadata:
    """Descriptive metadata attached to a correction pack."""
    version: str = "1.0.0"
    author: str = ""
    category: str = "general"
    tags: List[str] = field(default_factory=list)
    description: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Serialize every metadata field verbatim."""
        return dict(
            version=self.version,
            author=self.author,
            category=self.category,
            tags=self.tags,
            description=self.description,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionPackMetadata':
        """Build metadata from a dict, defaulting any missing field."""
        return cls(
            data.get('version', '1.0.0'),
            data.get('author', ''),
            data.get('category', 'general'),
            data.get('tags', []),
            data.get('description', ''),
        )
@dataclass
class CorrectionPack:
    """
    A pack of corrections that can be exported, imported, and shared.

    Corrections are stored in a dict keyed by correction ID; aggregate
    statistics are derived on demand and embedded into to_dict() output.
    """
    id: str
    name: str
    description: str = ""
    # Corrections in this pack, keyed by correction ID.
    corrections: Dict[str, Correction] = field(default_factory=dict)
    # Metadata
    metadata: CorrectionPackMetadata = field(default_factory=CorrectionPackMetadata)
    # Timestamps
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    # Version history counter (advanced by the repository on snapshot).
    current_version: int = 1
    @property
    def correction_count(self) -> int:
        """Number of corrections in the pack."""
        return len(self.corrections)
    @property
    def total_applications(self) -> int:
        """Total number of times corrections were applied (success + failure)."""
        return sum(c.success_count + c.failure_count for c in self.corrections.values())
    @property
    def overall_success_rate(self) -> float:
        """Overall success rate across all corrections (0.0 when unused)."""
        total_success = sum(c.success_count for c in self.corrections.values())
        total_failure = sum(c.failure_count for c in self.corrections.values())
        total = total_success + total_failure
        if total == 0:
            return 0.0
        return total_success / total
    @property
    def active_corrections(self) -> List[Correction]:
        """Get only ACTIVE corrections (deprecated/disabled excluded)."""
        return [c for c in self.corrections.values() if c.status == CorrectionStatus.ACTIVE]
    def add_correction(self, correction: Correction) -> bool:
        """
        Add a correction to the pack.
        Returns True if added, False if a correction with that ID already exists.
        """
        if correction.id in self.corrections:
            return False
        self.corrections[correction.id] = correction
        self.updated_at = datetime.now()
        return True
    def remove_correction(self, correction_id: str) -> bool:
        """
        Remove a correction from the pack.
        Returns True if removed, False if not found.
        """
        if correction_id not in self.corrections:
            return False
        del self.corrections[correction_id]
        self.updated_at = datetime.now()
        return True
    def get_correction(self, correction_id: str) -> Optional[Correction]:
        """Get a correction by ID."""
        return self.corrections.get(correction_id)
    def find_by_key_hash(self, key_hash: str) -> List[Correction]:
        """Find ACTIVE corrections whose key hashes to key_hash."""
        return [
            c for c in self.corrections.values()
            if c.key.to_hash() == key_hash and c.status == CorrectionStatus.ACTIVE
        ]
    def merge(self, other: 'CorrectionPack', conflict_strategy: str = 'keep_higher_confidence') -> int:
        """
        Merge another pack into this one.

        Conflicts are detected by correction ID only; an unknown strategy
        behaves like 'keep_existing' for conflicting IDs.
        Args:
            other: Pack to merge
            conflict_strategy: How to handle conflicts:
                - 'keep_higher_confidence': Keep correction with higher confidence
                - 'keep_newer': Keep the most recently updated correction
                - 'keep_existing': Keep existing, ignore new
                - 'replace': Always replace with new
        Returns:
            Number of corrections added/updated
        """
        merged_count = 0
        for correction_id, correction in other.corrections.items():
            if correction_id in self.corrections:
                existing = self.corrections[correction_id]
                should_replace = False
                if conflict_strategy == 'keep_higher_confidence':
                    should_replace = correction.confidence_score > existing.confidence_score
                elif conflict_strategy == 'keep_newer':
                    should_replace = correction.updated_at > existing.updated_at
                elif conflict_strategy == 'replace':
                    should_replace = True
                # 'keep_existing' doesn't replace
                if should_replace:
                    self.corrections[correction_id] = correction
                    merged_count += 1
            else:
                # New ID: always adopted regardless of strategy.
                self.corrections[correction_id] = correction
                merged_count += 1
        if merged_count > 0:
            self.updated_at = datetime.now()
        return merged_count
    def get_statistics(self) -> Dict[str, Any]:
        """Get aggregate statistics for the pack (counts, rates, type breakdown)."""
        corrections = list(self.corrections.values())
        active = [c for c in corrections if c.status == CorrectionStatus.ACTIVE]
        # Group by correction type
        by_type = {}
        for c in corrections:
            type_name = c.correction_type.value
            if type_name not in by_type:
                by_type[type_name] = 0
            by_type[type_name] += 1
        return {
            'total_corrections': len(corrections),
            'active_corrections': len(active),
            'deprecated_corrections': len([c for c in corrections if c.status == CorrectionStatus.DEPRECATED]),
            'disabled_corrections': len([c for c in corrections if c.status == CorrectionStatus.DISABLED]),
            'total_applications': self.total_applications,
            'overall_success_rate': self.overall_success_rate,
            'corrections_by_type': by_type,
            'average_confidence': sum(c.confidence_score for c in active) / len(active) if active else 0.0
        }
    def to_dict(self) -> Dict[str, Any]:
        """Serialize the pack; embeds a computed 'statistics' snapshot."""
        return {
            'id': self.id,
            'name': self.name,
            'description': self.description,
            'corrections': {cid: c.to_dict() for cid, c in self.corrections.items()},
            'metadata': self.metadata.to_dict(),
            'created_at': self.created_at.isoformat(),
            'updated_at': self.updated_at.isoformat(),
            'current_version': self.current_version,
            'statistics': self.get_statistics()
        }
    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionPack':
        """Deserialize a pack; the embedded 'statistics' key is ignored."""
        # Parse datetime fields (ISO strings or datetime pass-through).
        created_at = data.get('created_at')
        if isinstance(created_at, str):
            created_at = datetime.fromisoformat(created_at)
        elif created_at is None:
            created_at = datetime.now()
        updated_at = data.get('updated_at')
        if isinstance(updated_at, str):
            updated_at = datetime.fromisoformat(updated_at)
        elif updated_at is None:
            updated_at = datetime.now()
        # Parse corrections
        corrections_data = data.get('corrections', {})
        corrections = {
            cid: Correction.from_dict(cdata)
            for cid, cdata in corrections_data.items()
        }
        # Parse metadata
        metadata_data = data.get('metadata', {})
        metadata = CorrectionPackMetadata.from_dict(metadata_data)
        return cls(
            id=data.get('id', str(uuid.uuid4())),
            name=data.get('name', 'Unnamed Pack'),
            description=data.get('description', ''),
            corrections=corrections,
            metadata=metadata,
            created_at=created_at,
            updated_at=updated_at,
            current_version=data.get('current_version', 1)
        )
def generate_correction_id() -> str:
    """Generate a unique correction ID of the form ``corr_<12 hex chars>``."""
    suffix = uuid.uuid4().hex[:12]
    return "corr_" + suffix
def generate_pack_id() -> str:
    """Generate a unique pack ID of the form ``pack_<12 hex chars>``."""
    suffix = uuid.uuid4().hex[:12]
    return "pack_" + suffix

View File

@@ -0,0 +1,443 @@
"""Main self-healing engine for workflow recovery."""
import time
import logging
from typing import List, Optional, Dict, Any, Tuple
from pathlib import Path
from .models import RecoveryContext, RecoveryResult, RecoverySuggestion
from .learning_repository import LearningRepository
from .confidence_scorer import ConfidenceScorer
from .strategies import (
RecoveryStrategy,
SemanticVariantStrategy,
SpatialFallbackStrategy,
TimingAdaptationStrategy,
FormatTransformationStrategy
)
logger = logging.getLogger(__name__)
class SelfHealingEngine:
    """Main engine for self-healing workflows.

    Attempts to recover from workflow failures by consulting, in order:
    (1) correction packs (cross-workflow, user-validated fixes),
    (2) patterns learned from previous successful recoveries, and
    (3) built-in recovery strategies, re-ranked per failure context.
    """

    def __init__(
        self,
        learning_repo: Optional[LearningRepository] = None,
        storage_path: Optional[Path] = None,
        correction_packs_enabled: bool = True
    ):
        """
        Initialize the self-healing engine.

        Args:
            learning_repo: Learning repository instance; if omitted, one is
                created under ``storage_path`` (default ``data/healing``)
            storage_path: Path for storing learned patterns
            correction_packs_enabled: Whether to consult correction packs
        """
        # Initialize learning repository
        if learning_repo:
            self.learning_repo = learning_repo
        else:
            storage_path = storage_path or Path('data/healing')
            self.learning_repo = LearningRepository(storage_path)
        # Initialize confidence scorer
        self.confidence_scorer = ConfidenceScorer()
        # Initialize recovery strategies. This list order is only a default;
        # the effective order is recomputed per failure in _prioritize_strategies().
        self.recovery_strategies: List[RecoveryStrategy] = [
            SemanticVariantStrategy(),
            SpatialFallbackStrategy(),
            TimingAdaptationStrategy(),
            FormatTransformationStrategy()
        ]
        # Configuration
        self.max_recovery_time = 30.0  # seconds; total time budget per attempt_recovery() call
        self.parallel_execution = True  # NOTE(review): never read in this class — confirm intended use
        # Correction packs integration (service is lazy-loaded; see property below)
        self.correction_packs_enabled = correction_packs_enabled
        self._correction_pack_service = None

    @property
    def correction_pack_service(self):
        """Lazy-load correction pack service.

        On first access, imports and instantiates CorrectionPackService.
        If the import fails, correction packs are permanently disabled
        for this engine instance and None is returned.
        """
        if self._correction_pack_service is None and self.correction_packs_enabled:
            try:
                from core.corrections import CorrectionPackService
                self._correction_pack_service = CorrectionPackService()
                logger.info("CorrectionPackService initialized for healing engine")
            except ImportError as e:
                logger.warning(f"CorrectionPackService not available: {e}")
                self.correction_packs_enabled = False
        return self._correction_pack_service

    def attempt_recovery(self, context: RecoveryContext) -> RecoveryResult:
        """
        Attempt to recover from a workflow failure.

        Tries correction packs first, then each prioritized strategy within
        the ``max_recovery_time`` budget. Successful recoveries above the
        confidence threshold are stored back into the learning repository.

        Args:
            context: Recovery context with failure information

        Returns:
            RecoveryResult with outcome; ``requires_user_input`` is set when
            confidence is too low or every strategy failed
        """
        start_time = time.time()
        # Check if we've exceeded max attempts
        if context.attempt_count >= context.max_attempts:
            return RecoveryResult(
                success=False,
                strategy_used='none',
                error_message=f'Max recovery attempts ({context.max_attempts}) exceeded'
            )
        # Try correction packs first (if enabled) — these are user-validated
        # fixes, so they take precedence over heuristic strategies.
        if self.correction_packs_enabled and self.correction_pack_service:
            pack_result = self._try_correction_packs(context)
            if pack_result and pack_result.success:
                return pack_result
        # Get learned patterns for this context
        learned_patterns = self.learning_repo.get_matching_patterns(context)
        # Get prioritized strategies
        strategies = self._prioritize_strategies(context, learned_patterns)
        # Try each strategy
        for strategy in strategies:
            # Check time limit (budget covers the whole recovery attempt)
            elapsed = time.time() - start_time
            if elapsed >= self.max_recovery_time:
                break
            # Skip if strategy can't handle this failure
            if not strategy.can_handle(context):
                continue
            # Attempt recovery
            result = strategy.attempt_recovery(context)
            # If successful, learn from it
            if result.success:
                # Calculate final confidence
                historical_success = self._get_historical_success_rate(
                    strategy.name, context, learned_patterns
                )
                result.confidence_score = self.confidence_scorer.calculate_recovery_confidence(
                    result.strategy_used,
                    context,
                    historical_success
                )
                # Check if safe to proceed (stricter when the original action
                # looks like it modifies data, e.g. input/submit/delete)
                if self.confidence_scorer.is_safe_to_proceed(
                    result.confidence_score,
                    context.confidence_threshold,
                    involves_data_modification=self._involves_data_modification(context)
                ):
                    # Learn from success
                    self.learn_from_success(context, result)
                    return result
                else:
                    # Confidence too low, mark as requiring user input
                    result.requires_user_input = True
                    return result
        # All strategies failed (or time budget exhausted)
        total_time = time.time() - start_time
        return RecoveryResult(
            success=False,
            strategy_used='all_failed',
            execution_time=total_time,
            error_message='All recovery strategies failed',
            requires_user_input=True
        )

    def learn_from_success(self, context: RecoveryContext, result: RecoveryResult) -> None:
        """
        Learn from successful recovery for future use.

        Args:
            context: Recovery context
            result: Successful recovery result
        """
        self.learning_repo.store_pattern(context, result)

    def get_recovery_suggestions(self, context: RecoveryContext) -> List[RecoverySuggestion]:
        """
        Get ranked recovery suggestions based on learned patterns.

        Does not execute anything — only scores each applicable strategy.

        Args:
            context: Recovery context

        Returns:
            List of recovery suggestions sorted by confidence (descending)
        """
        suggestions = []
        # Get learned patterns
        learned_patterns = self.learning_repo.get_matching_patterns(context)
        # Get suggestions from each strategy
        for strategy in self.recovery_strategies:
            if not strategy.can_handle(context):
                continue
            # Calculate confidence based on historical success
            historical_success = self._get_historical_success_rate(
                strategy.name, context, learned_patterns
            )
            confidence = self.confidence_scorer.calculate_recovery_confidence(
                strategy.name,
                context,
                historical_success
            )
            suggestion = RecoverySuggestion(
                strategy=strategy.name,
                confidence=confidence,
                description=self._get_strategy_description(strategy),
                estimated_time=self._estimate_strategy_time(strategy, context)
            )
            suggestions.append(suggestion)
        # Sort by confidence
        suggestions.sort(key=lambda s: s.confidence, reverse=True)
        return suggestions

    def prune_learned_patterns(
        self,
        max_age_days: int = 90,
        min_confidence: float = 0.3
    ) -> None:
        """
        Prune outdated learned patterns.

        Args:
            max_age_days: Maximum age for patterns
            min_confidence: Minimum confidence threshold
        """
        self.learning_repo.prune_outdated_patterns(max_age_days, min_confidence)

    def _prioritize_strategies(
        self,
        context: RecoveryContext,
        learned_patterns: List
    ) -> List[RecoveryStrategy]:
        """Prioritize strategies based on context and learned patterns."""
        # Score each strategy
        scored_strategies = []
        for strategy in self.recovery_strategies:
            # Base priority from strategy
            priority = strategy.get_priority(context)
            # Boost priority if we have successful patterns
            # (multiplier is in [1.0, 2.0] since success rate is a fraction)
            historical_success = self._get_historical_success_rate(
                strategy.name, context, learned_patterns
            )
            if historical_success > 0:
                priority *= (1.0 + historical_success)
            scored_strategies.append((priority, strategy))
        # Sort by priority (highest first)
        scored_strategies.sort(key=lambda x: x[0], reverse=True)
        return [strategy for _, strategy in scored_strategies]

    def _get_historical_success_rate(
        self,
        strategy_name: str,
        context: RecoveryContext,
        learned_patterns: List
    ) -> float:
        """Get historical success rate for a strategy (0.0 if no history)."""
        matching_patterns = [
            p for p in learned_patterns
            if p.recovery_strategy == strategy_name
        ]
        if not matching_patterns:
            return 0.0
        # Return average success rate
        return sum(p.success_rate for p in matching_patterns) / len(matching_patterns)

    def _involves_data_modification(self, context: RecoveryContext) -> bool:
        """Check if the action involves data modification.

        Substring match on the original action name — e.g. 'update_profile'
        matches 'update'.
        """
        data_modification_actions = [
            'input', 'type', 'submit', 'delete', 'update', 'save'
        ]
        return any(action in context.original_action.lower()
                   for action in data_modification_actions)

    def _get_strategy_description(self, strategy: RecoveryStrategy) -> str:
        """Get human-readable description of strategy (falls back to its name)."""
        descriptions = {
            'SemanticVariantStrategy': 'Try semantic variants of the element (e.g., Submit → Send)',
            'SpatialFallbackStrategy': 'Search in expanded area around original position',
            'TimingAdaptationStrategy': 'Increase wait time for element to appear',
            'FormatTransformationStrategy': 'Transform input format to match validation'
        }
        return descriptions.get(strategy.name, strategy.name)

    def _estimate_strategy_time(
        self,
        strategy: RecoveryStrategy,
        context: RecoveryContext
    ) -> float:
        """Estimate execution time for strategy (fixed per-strategy guesses)."""
        # Simple estimates in seconds
        estimates = {
            'SemanticVariantStrategy': 2.0,
            'SpatialFallbackStrategy': 5.0,
            'TimingAdaptationStrategy': 10.0,
            'FormatTransformationStrategy': 1.0
        }
        return estimates.get(strategy.name, 3.0)

    def _try_correction_packs(self, context: RecoveryContext) -> Optional[RecoveryResult]:
        """
        Try to find and apply a correction from correction packs.

        Searches packs for corrections matching (action, element type,
        failure reason); applies candidates in order until one sticks.
        Each application outcome is reported back to the pack service so
        confidence scores stay current.

        Args:
            context: Recovery context

        Returns:
            RecoveryResult if a correction was successfully applied, None otherwise
        """
        try:
            # Get element type from metadata
            element_type = context.metadata.get('element_type', 'unknown')
            # Search for applicable corrections
            corrections = self.correction_pack_service.find_applicable_corrections(
                action_type=context.original_action,
                element_type=element_type,
                failure_context=context.failure_reason,
                min_confidence=0.5  # Only use high-confidence corrections
            )
            if not corrections:
                return None
            # Try corrections in order of confidence
            for correction_info in corrections:
                pack_id = correction_info['pack_id']
                correction = correction_info['correction']
                try:
                    # Apply the correction
                    applied = self._apply_correction(context, correction)
                    if applied:
                        # Record successful application
                        self.correction_pack_service.apply_correction(
                            pack_id=pack_id,
                            correction_id=correction['id'],
                            success=True
                        )
                        # Propagate to learning repository
                        result = RecoveryResult(
                            success=True,
                            strategy_used=f"correction_pack:{correction['correction_type']}",
                            confidence_score=correction['confidence_score'],
                            modified_action=correction.get('corrected_params'),
                            modified_target=correction.get('corrected_target')
                        )
                        # Learn from this for future reference
                        self.learning_repo.store_pattern(context, result)
                        logger.info(
                            f"Applied correction from pack {pack_id}: "
                            f"{correction['id']} (confidence: {correction['confidence_score']:.2f})"
                        )
                        return result
                except Exception as e:
                    logger.warning(f"Failed to apply correction {correction['id']}: {e}")
                    # Record failed application, then fall through to the next candidate
                    self.correction_pack_service.apply_correction(
                        pack_id=pack_id,
                        correction_id=correction['id'],
                        success=False
                    )
            return None
        except Exception as e:
            # Best-effort: pack lookup failures must never block the
            # regular strategy-based recovery path.
            logger.error(f"Error in correction pack lookup: {e}")
            return None

    def _apply_correction(
        self,
        context: RecoveryContext,
        correction: Dict[str, Any]
    ) -> bool:
        """
        Apply a correction to the current context.

        Applying here only means stashing the corrected values into
        ``context.metadata``; the workflow executor is expected to read them.
        Note: when a typed branch matches but carries no data, control
        falls through to the generic data-presence check at the bottom.

        Args:
            context: Recovery context
            correction: Correction data dict

        Returns:
            True if correction was applied successfully
        """
        correction_type = correction.get('correction_type', 'other')
        # Handle different correction types
        if correction_type == 'target_change':
            # Update target in context
            if correction.get('corrected_target'):
                context.metadata['corrected_target'] = correction['corrected_target']
                return True
        elif correction_type == 'parameter_change':
            # Update parameters in context
            if correction.get('corrected_params'):
                context.metadata['corrected_params'] = correction['corrected_params']
                return True
        elif correction_type == 'wait_added':
            # Add wait time (defaults to 2.0s when unspecified)
            wait_time = correction.get('corrected_params', {}).get('wait_time', 2.0)
            context.metadata['additional_wait'] = wait_time
            return True
        elif correction_type == 'timing_adjust':
            # Adjust timing (defaults to 1.5x when unspecified)
            timing = correction.get('corrected_params', {}).get('timing_multiplier', 1.5)
            context.metadata['timing_multiplier'] = timing
            return True
        elif correction_type == 'coordinates_adjust':
            # Adjust coordinates
            offset = correction.get('corrected_params', {}).get('offset', {})
            context.metadata['coordinate_offset'] = offset
            return True
        # For other types, just mark as applied if we have any correction data
        return bool(correction.get('corrected_target') or correction.get('corrected_params'))

    def get_correction_pack_statistics(self) -> Optional[Dict[str, Any]]:
        """
        Get statistics from correction packs.

        Returns:
            Statistics dict or None if packs are disabled/unavailable
        """
        if not self.correction_packs_enabled or not self.correction_pack_service:
            return None
        try:
            return self.correction_pack_service.get_global_statistics()
        except Exception as e:
            logger.error(f"Error getting correction pack statistics: {e}")
            return None

View File

@@ -0,0 +1,644 @@
"""
Tests for the Correction Packs system.
Tests models, repository, aggregator, and service.
"""
import json
import os
import shutil
import tempfile
import unittest
from datetime import datetime
from pathlib import Path
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from core.corrections.models import (
CorrectionKey,
CorrectionType,
CorrectionStatus,
Correction,
CorrectionPack,
CorrectionPackMetadata,
CorrectionSource,
generate_correction_id,
generate_pack_id
)
from core.corrections.correction_repository import CorrectionRepository
from core.corrections.aggregator import CorrectionAggregator, CrossWorkflowAggregator
from core.corrections.correction_pack_service import CorrectionPackService
class TestCorrectionKey(unittest.TestCase):
    """Unit tests for the CorrectionKey model."""

    def test_to_hash(self):
        """Hashing must be deterministic, key-sensitive, and 16 chars long."""
        first = CorrectionKey("click", "button", "element_not_found")
        twin = CorrectionKey("click", "button", "element_not_found")
        other = CorrectionKey("type", "input", "element_not_found")
        # Identical keys hash identically; distinct keys do not.
        self.assertEqual(first.to_hash(), twin.to_hash())
        self.assertNotEqual(first.to_hash(), other.to_hash())
        # Fixed-width 16-character digest.
        self.assertEqual(len(first.to_hash()), 16)

    def test_to_dict_from_dict(self):
        """Round-trip through dict serialization preserves the key hash."""
        original = CorrectionKey("click", "button", "timeout")
        as_dict = original.to_dict()
        self.assertEqual(as_dict['action_type'], "click")
        self.assertEqual(as_dict['element_type'], "button")
        self.assertEqual(as_dict['failure_context'], "timeout")
        self.assertIn('hash', as_dict)
        rebuilt = CorrectionKey.from_dict(as_dict)
        self.assertEqual(original.to_hash(), rebuilt.to_hash())
class TestCorrection(unittest.TestCase):
    """Unit tests for the Correction model."""

    def test_create_correction(self):
        """A freshly-built correction has an id prefix and a zero success rate."""
        failure_key = CorrectionKey("click", "button", "not_visible")
        built = Correction(
            id=generate_correction_id(),
            key=failure_key,
            correction_type=CorrectionType.WAIT_ADDED,
            original_target={"selector": "#submit"},
            corrected_params={"wait_time": 2.0},
            correction_description="Added wait for button visibility"
        )
        self.assertTrue(built.id.startswith("corr_"))
        self.assertEqual(built.correction_type, CorrectionType.WAIT_ADDED)
        self.assertEqual(built.success_rate, 0.0)

    def test_success_rate_calculation(self):
        """success_rate is success_count / (success_count + failure_count)."""
        failure_key = CorrectionKey("click", "button", "not_visible")
        seven_of_ten = Correction(
            id=generate_correction_id(),
            key=failure_key,
            correction_type=CorrectionType.WAIT_ADDED,
            success_count=7,
            failure_count=3
        )
        self.assertEqual(seven_of_ten.success_rate, 0.7)

    def test_record_application(self):
        """Recording applications updates counters and last_applied."""
        failure_key = CorrectionKey("click", "button", "not_visible")
        tracked = Correction(
            id=generate_correction_id(),
            key=failure_key,
            correction_type=CorrectionType.WAIT_ADDED
        )
        tracked.record_application(success=True)
        self.assertEqual(tracked.success_count, 1)
        self.assertEqual(tracked.failure_count, 0)
        self.assertIsNotNone(tracked.last_applied)
        tracked.record_application(success=False)
        self.assertEqual(tracked.success_count, 1)
        self.assertEqual(tracked.failure_count, 1)

    def test_to_dict_from_dict(self):
        """Full serialization round-trip preserves identity, key, and stats."""
        failure_key = CorrectionKey("type", "input", "validation_error")
        origin = CorrectionSource(
            session_id="sess_123",
            workflow_id="wf_456",
            node_id="node_789"
        )
        original = Correction(
            id="corr_test123",
            key=failure_key,
            correction_type=CorrectionType.PARAMETER_CHANGE,
            original_params={"value": "test"},
            corrected_params={"value": "corrected"},
            source=origin,
            success_count=5,
            failure_count=2,
            tags=["form", "validation"]
        )
        rebuilt = Correction.from_dict(original.to_dict())
        self.assertEqual(rebuilt.id, original.id)
        self.assertEqual(rebuilt.key.to_hash(), failure_key.to_hash())
        self.assertEqual(rebuilt.correction_type, CorrectionType.PARAMETER_CHANGE)
        self.assertEqual(rebuilt.success_count, 5)
        self.assertEqual(rebuilt.tags, ["form", "validation"])
class TestCorrectionPack(unittest.TestCase):
    """Unit tests for the CorrectionPack model."""

    def test_create_pack(self):
        """New packs get a pack_ id and start empty."""
        fresh = CorrectionPack(
            id=generate_pack_id(),
            name="Test Pack",
            description="A test correction pack"
        )
        self.assertTrue(fresh.id.startswith("pack_"))
        self.assertEqual(fresh.name, "Test Pack")
        self.assertEqual(fresh.correction_count, 0)

    def test_add_remove_correction(self):
        """add/remove are idempotent and report duplicates/missing via bool."""
        pack = CorrectionPack(
            id=generate_pack_id(),
            name="Test Pack"
        )
        entry = Correction(
            id="corr_test1",
            key=CorrectionKey("click", "button", "not_found"),
            correction_type=CorrectionType.TARGET_CHANGE
        )
        # First add succeeds, duplicate add is rejected.
        self.assertTrue(pack.add_correction(entry))
        self.assertEqual(pack.correction_count, 1)
        self.assertFalse(pack.add_correction(entry))
        # First remove succeeds, second remove is rejected.
        self.assertTrue(pack.remove_correction("corr_test1"))
        self.assertEqual(pack.correction_count, 0)
        self.assertFalse(pack.remove_correction("corr_test1"))

    def test_find_by_key_hash(self):
        """Lookup by key hash returns only the matching corrections."""
        pack = CorrectionPack(id=generate_pack_id(), name="Test Pack")
        click_key = CorrectionKey("click", "button", "not_found")
        type_key = CorrectionKey("type", "input", "validation")
        pack.add_correction(
            Correction(id="corr_1", key=click_key, correction_type=CorrectionType.TARGET_CHANGE))
        pack.add_correction(
            Correction(id="corr_2", key=type_key, correction_type=CorrectionType.PARAMETER_CHANGE))
        hits = pack.find_by_key_hash(click_key.to_hash())
        self.assertEqual(len(hits), 1)
        self.assertEqual(hits[0].id, "corr_1")

    def test_merge_packs(self):
        """Merging pulls the other pack's new corrections in."""
        destination = CorrectionPack(id="pack_1", name="Pack 1")
        incoming = CorrectionPack(id="pack_2", name="Pack 2")
        shared_key = CorrectionKey("click", "button", "error")
        destination.add_correction(
            Correction(id="corr_1", key=shared_key, correction_type=CorrectionType.TARGET_CHANGE))
        incoming.add_correction(
            Correction(id="corr_2", key=shared_key, correction_type=CorrectionType.WAIT_ADDED))
        self.assertEqual(destination.merge(incoming), 1)
        self.assertEqual(destination.correction_count, 2)

    def test_get_statistics(self):
        """Statistics aggregate counts and total applications."""
        pack = CorrectionPack(id=generate_pack_id(), name="Test Pack")
        pack.add_correction(Correction(
            id="corr_1",
            key=CorrectionKey("click", "button", "error"),
            correction_type=CorrectionType.TARGET_CHANGE,
            success_count=10,
            failure_count=2
        ))
        stats = pack.get_statistics()
        self.assertEqual(stats['total_corrections'], 1)
        self.assertEqual(stats['active_corrections'], 1)
        self.assertEqual(stats['total_applications'], 12)
class TestCorrectionRepository(unittest.TestCase):
    """Unit tests for CorrectionRepository (JSON storage backend)."""

    def setUp(self):
        """Each test gets a repository rooted in a throwaway directory."""
        self.workdir = tempfile.mkdtemp()
        self.repository = CorrectionRepository(self.workdir)

    def tearDown(self):
        """Remove the throwaway directory."""
        shutil.rmtree(self.workdir)

    def test_create_and_get_pack(self):
        """A created pack can be read back by id."""
        created = self.repository.create_pack("Test Pack", "Description", {
            'category': 'testing',
            'tags': ['unit', 'test']
        })
        self.assertIsNotNone(created)
        self.assertEqual(created.name, "Test Pack")
        fetched = self.repository.get_pack(created.id)
        self.assertIsNotNone(fetched)
        self.assertEqual(fetched.name, "Test Pack")

    def test_list_packs(self):
        """Listing returns every stored pack."""
        self.repository.create_pack("Pack 1", "First pack")
        self.repository.create_pack("Pack 2", "Second pack")
        self.assertEqual(len(self.repository.list_packs()), 2)

    def test_update_pack(self):
        """Updating a pack persists the new name."""
        created = self.repository.create_pack("Original Name")
        renamed = self.repository.update_pack(created.id, {'name': 'New Name'})
        self.assertIsNotNone(renamed)
        self.assertEqual(renamed.name, 'New Name')

    def test_delete_pack(self):
        """Deletion removes the pack; repeated deletion is rejected."""
        created = self.repository.create_pack("To Delete")
        self.assertTrue(self.repository.delete_pack(created.id))
        self.assertIsNone(self.repository.get_pack(created.id))
        self.assertFalse(self.repository.delete_pack(created.id))

    def test_add_and_find_correction(self):
        """A stored correction is discoverable via context search."""
        created = self.repository.create_pack("Test Pack")
        stored = Correction(
            id=generate_correction_id(),
            key=CorrectionKey("click", "button", "timeout"),
            correction_type=CorrectionType.TIMING_ADJUST
        )
        self.assertTrue(self.repository.add_correction(created.id, stored))
        # find_by_context yields (pack, correction) pairs.
        hits = self.repository.find_by_context("click", "button", "timeout")
        self.assertEqual(len(hits), 1)
        self.assertEqual(hits[0][1].id, stored.id)

    def test_export_import_pack(self):
        """Importing an export with 'create_new' yields a distinct copy."""
        source = self.repository.create_pack("Export Test")
        self.repository.add_correction(source.id, Correction(
            id="corr_export",
            key=CorrectionKey("click", "button", "error"),
            correction_type=CorrectionType.TARGET_CHANGE
        ))
        payload = self.repository.export_pack(source.id)
        self.assertIsNotNone(payload)
        clone = self.repository.import_pack(payload, merge_strategy='create_new')
        self.assertIsNotNone(clone)
        self.assertNotEqual(clone.id, source.id)
        self.assertEqual(clone.correction_count, 1)

    def test_version_and_rollback(self):
        """Rollback restores the pack to a snapshotted version."""
        created = self.repository.create_pack("Version Test")
        key = CorrectionKey("click", "button", "error")
        self.repository.add_correction(
            created.id,
            Correction(id="corr_v1", key=key, correction_type=CorrectionType.TARGET_CHANGE))
        # Snapshot with one correction.
        self.assertEqual(self.repository.create_version(created.id, "First version"), 1)
        self.repository.add_correction(
            created.id,
            Correction(id="corr_v2", key=key, correction_type=CorrectionType.WAIT_ADDED))
        self.assertEqual(self.repository.get_pack(created.id).correction_count, 2)
        # Roll back and verify the second correction is gone.
        self.assertTrue(self.repository.rollback_pack(created.id, 1))
        self.assertEqual(self.repository.get_pack(created.id).correction_count, 1)
class TestCorrectionAggregator(unittest.TestCase):
    """Unit tests for CorrectionAggregator."""

    def setUp(self):
        """Build an aggregator with the thresholds the tests rely on."""
        self.agg = CorrectionAggregator(
            min_occurrences=2,
            min_success_rate=0.5,
            min_confidence=0.3
        )

    def test_aggregate_by_key(self):
        """Same-key corrections collapse into one with summed counters."""
        shared_key = CorrectionKey("click", "button", "not_found")
        batch = []
        for idx in range(3):
            batch.append(Correction(
                id=f"corr_{idx}",
                key=shared_key,
                correction_type=CorrectionType.WAIT_ADDED,
                success_count=5,
                failure_count=2
            ))
        merged = self.agg.aggregate_corrections(batch)
        self.assertEqual(len(merged), 1)
        # 3 x 5 successes merged together.
        self.assertEqual(merged[0].success_count, 15)

    def test_filter_by_quality(self):
        """Low-success-rate corrections are dropped by the quality filter."""
        shared_key = CorrectionKey("click", "button", "error")
        strong = Correction(
            id="corr_high",
            key=shared_key,
            correction_type=CorrectionType.TARGET_CHANGE,
            success_count=9,
            failure_count=1
        )
        weak = Correction(
            id="corr_low",
            key=shared_key,
            correction_type=CorrectionType.TARGET_CHANGE,
            success_count=1,
            failure_count=9
        )
        survivors = self.agg.filter_by_quality([strong, weak], min_success_rate=0.7)
        self.assertEqual(len(survivors), 1)
        self.assertEqual(survivors[0].id, "corr_high")

    def test_rank_corrections(self):
        """Ranking by success_rate orders best-first."""
        shared_key = CorrectionKey("click", "button", "error")
        unordered = [
            Correction(id="corr_1", key=shared_key, correction_type=CorrectionType.TARGET_CHANGE, success_count=3, failure_count=7),
            Correction(id="corr_2", key=shared_key, correction_type=CorrectionType.TARGET_CHANGE, success_count=8, failure_count=2),
            Correction(id="corr_3", key=shared_key, correction_type=CorrectionType.TARGET_CHANGE, success_count=5, failure_count=5),
        ]
        ordered = self.agg.rank_corrections(unordered, sort_by='success_rate')
        self.assertEqual([c.id for c in ordered], ["corr_2", "corr_3", "corr_1"])
class TestCorrectionPackService(unittest.TestCase):
    """Unit tests for the high-level CorrectionPackService facade."""

    def setUp(self):
        """Point the service at throwaway storage and training directories."""
        self.storage_dir = tempfile.mkdtemp()
        self.training_dir = tempfile.mkdtemp()
        self.service = CorrectionPackService(self.storage_dir, self.training_dir)

    def tearDown(self):
        """Remove both throwaway directories."""
        shutil.rmtree(self.storage_dir)
        shutil.rmtree(self.training_dir)

    def test_create_and_list_packs(self):
        """Created packs show up in list_packs with their metadata."""
        created = self.service.create_pack(
            name="Service Test Pack",
            description="Testing the service",
            category="testing",
            tags=["unit", "test"]
        )
        self.assertIsNotNone(created)
        self.assertEqual(created['name'], "Service Test Pack")
        listing = self.service.list_packs()
        self.assertEqual(len(listing), 1)
        self.assertEqual(listing[0]['name'], "Service Test Pack")

    def test_add_correction(self):
        """Corrections added via the service get a corr_ id."""
        created = self.service.create_pack(name="Correction Test")
        added = self.service.add_correction(created['id'], {
            'action_type': 'click',
            'element_type': 'button',
            'failure_context': 'element_not_found',
            'correction_type': 'target_change',
            'original_target': {'selector': '#old'},
            'corrected_target': {'selector': '#new'},
            'description': 'Updated selector'
        })
        self.assertIsNotNone(added)
        self.assertTrue(added['id'].startswith('corr_'))

    def test_find_applicable_corrections(self):
        """A stored correction is found by its (action, element, failure) key."""
        created = self.service.create_pack(name="Find Test")
        self.service.add_correction(created['id'], {
            'action_type': 'click',
            'element_type': 'button',
            'failure_context': 'timeout',
            'correction_type': 'timing_adjust',
            'corrected_params': {'wait_time': 5}
        })
        matches = self.service.find_applicable_corrections(
            action_type='click',
            element_type='button',
            failure_context='timeout',
            min_confidence=0.0
        )
        self.assertEqual(len(matches), 1)

    def test_export_import_file(self):
        """A pack survives a file export/import round-trip."""
        created = self.service.create_pack(name="File Export Test")
        self.service.add_correction(created['id'], {
            'action_type': 'type',
            'element_type': 'input',
            'failure_context': 'validation',
            'correction_type': 'parameter_change'
        })
        target_path = os.path.join(self.storage_dir, 'export.json')
        self.assertTrue(self.service.export_pack_to_file(created['id'], target_path))
        self.assertTrue(os.path.exists(target_path))
        round_tripped = self.service.import_pack_from_file(target_path)
        self.assertIsNotNone(round_tripped)
        self.assertEqual(len(round_tripped['corrections']), 1)

    def test_version_management(self):
        """Versions can be created, listed, and rolled back to."""
        created = self.service.create_pack(name="Version Test")
        self.service.add_correction(created['id'], {
            'action_type': 'click',
            'element_type': 'button',
            'failure_context': 'error',
            'correction_type': 'target_change'
        })
        # Snapshot v1 with one correction.
        self.assertEqual(self.service.create_version(created['id'], "v1 snapshot"), 1)
        self.assertEqual(len(self.service.list_versions(created['id'])), 1)
        self.service.add_correction(created['id'], {
            'action_type': 'type',
            'element_type': 'input',
            'failure_context': 'error',
            'correction_type': 'parameter_change'
        })
        self.assertEqual(len(self.service.get_pack(created['id'])['corrections']), 2)
        # Rollback restores the v1 state.
        self.assertTrue(self.service.rollback_pack(created['id'], 1))
        self.assertEqual(len(self.service.get_pack(created['id'])['corrections']), 1)

    def test_statistics(self):
        """Both per-pack and global statistics reflect stored data."""
        created = self.service.create_pack(name="Stats Test")
        self.service.add_correction(created['id'], {
            'action_type': 'click',
            'element_type': 'button',
            'failure_context': 'error',
            'correction_type': 'target_change'
        })
        per_pack = self.service.get_pack_statistics(created['id'])
        self.assertIsNotNone(per_pack)
        self.assertEqual(per_pack['total_corrections'], 1)
        overall = self.service.get_global_statistics()
        self.assertIsNotNone(overall)
        self.assertEqual(overall['total_packs'], 1)
class TestCrossWorkflowAggregator(unittest.TestCase):
    """Unit tests for CrossWorkflowAggregator."""

    def test_find_common_corrections(self):
        """Corrections seen in enough distinct workflows are reported as common."""
        cross_agg = CrossWorkflowAggregator()
        shared_key = CorrectionKey("click", "submit_button", "timeout")
        # Same key, one correction per distinct workflow.
        per_workflow = [
            Correction(
                id=f"corr_{i}",
                key=shared_key,
                correction_type=CorrectionType.WAIT_ADDED,
                source=CorrectionSource(session_id=f"s{i}", workflow_id=f"wf_{i}")
            )
            for i in (1, 2, 3)
        ]
        common = cross_agg.find_common_corrections(per_workflow, min_workflows=2)
        self.assertEqual(len(common), 1)
        merged, workflow_ids = common[0]
        self.assertEqual(len(workflow_ids), 3)

    def test_detect_patterns(self):
        """Pattern detection reports patterns plus a summary with totals."""
        cross_agg = CrossWorkflowAggregator()
        shared_key = CorrectionKey("click", "button", "error")
        repeated = [
            Correction(id=f"corr_{i}", key=shared_key, correction_type=CorrectionType.WAIT_ADDED)
            for i in range(5)
        ]
        report = cross_agg.detect_patterns(repeated)
        self.assertIn('patterns', report)
        self.assertIn('summary', report)
        self.assertEqual(report['summary']['total_corrections'], 5)
# Allow running this test module directly; unittest discovers and runs
# every TestCase defined above.
if __name__ == '__main__':
    unittest.main()

View File

@@ -0,0 +1,640 @@
"""
API endpoints for the Correction Packs.
Provides REST API for managing correction packs, corrections,
aggregation, export/import, and versioning.
"""
import os
import sys
from flask import Blueprint, request, jsonify, Response
from typing import Optional
# Add parent paths for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))))
from core.corrections import CorrectionPackService
correction_packs_bp = Blueprint('correction_packs', __name__)
# Initialize service (singleton pattern)
_service: Optional[CorrectionPackService] = None
def get_service() -> CorrectionPackService:
    """Get or create the correction pack service (module-level singleton)."""
    global _service
    if _service is None:
        # Resolve the project root: five directory levels above this file.
        root = os.path.abspath(__file__)
        for _ in range(5):
            root = os.path.dirname(root)
        _service = CorrectionPackService(
            os.path.join(root, "data", "correction_packs"),
            os.path.join(root, "training_data")
        )
    return _service
# ========== Pack CRUD Endpoints ==========
@correction_packs_bp.route('/correction-packs', methods=['GET'])
def list_packs():
    """
    List all correction packs.

    Query params:
        category: Filter by category
        tags: Filter by tags (comma-separated)

    Returns:
        JSON list of pack summaries
    """
    try:
        raw_tags = request.args.get('tags')
        summaries = get_service().list_packs(
            category=request.args.get('category'),
            tags=raw_tags.split(',') if raw_tags else None
        )
        return jsonify({'success': True, 'packs': summaries})
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
@correction_packs_bp.route('/correction-packs', methods=['POST'])
def create_pack():
    """
    Create a new correction pack.

    Request body:
        name: Pack name (required)
        description: Pack description
        category: Pack category
        tags: List of tags
        author: Pack author

    Returns:
        Created pack (HTTP 201)
    """
    try:
        payload = request.get_json() or {}
        # 'name' is the only mandatory field.
        if not payload.get('name'):
            return jsonify({'success': False, 'error': 'name is required'}), 400
        new_pack = get_service().create_pack(
            name=payload['name'],
            description=payload.get('description', ''),
            category=payload.get('category', 'general'),
            tags=payload.get('tags', []),
            author=payload.get('author', '')
        )
        return jsonify({'success': True, 'pack': new_pack}), 201
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
@correction_packs_bp.route('/correction-packs/<pack_id>', methods=['GET'])
def get_pack(pack_id: str):
    """
    Get a pack by ID.

    Returns:
        Pack details including all corrections; 404 when unknown
    """
    try:
        found = get_service().get_pack(pack_id)
        if found:
            return jsonify({'success': True, 'pack': found})
        return jsonify({'success': False, 'error': 'Pack not found'}), 404
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
@correction_packs_bp.route('/correction-packs/<pack_id>', methods=['PUT'])
def update_pack(pack_id: str):
    """
    Update a pack.

    Request body:
        name: New name
        description: New description
        metadata: New metadata dict

    Returns:
        Updated pack; 404 when unknown
    """
    try:
        payload = request.get_json() or {}
        updated = get_service().update_pack(
            pack_id=pack_id,
            name=payload.get('name'),
            description=payload.get('description'),
            metadata=payload.get('metadata')
        )
        if updated:
            return jsonify({'success': True, 'pack': updated})
        return jsonify({'success': False, 'error': 'Pack not found'}), 404
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
@correction_packs_bp.route('/correction-packs/<pack_id>', methods=['DELETE'])
def delete_pack(pack_id: str):
    """
    Delete a pack.

    Returns:
        Success status; 404 when the pack does not exist
    """
    try:
        deleted = get_service().delete_pack(pack_id)
        if deleted:
            return jsonify({'success': True, 'message': 'Pack deleted'})
        return jsonify({'success': False, 'error': 'Pack not found'}), 404
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
# ========== Correction Endpoints ==========
@correction_packs_bp.route('/correction-packs/<pack_id>/corrections', methods=['POST'])
def add_correction(pack_id: str):
    """Append a correction to a pack.

    The JSON body must contain action_type; other fields (element_type,
    failure_context, correction_type, original/corrected target and params,
    description, tags, source) are passed through to the service verbatim.
    Returns the created correction with 201, or 400 on validation failure.
    """
    try:
        svc = get_service()
        payload = request.get_json() or {}
        if not payload.get('action_type'):
            return jsonify({'success': False, 'error': 'action_type is required'}), 400
        created = svc.add_correction(pack_id, payload)
        if created:
            return jsonify({'success': True, 'correction': created}), 201
        return jsonify({'success': False, 'error': 'Failed to add correction'}), 400
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
@correction_packs_bp.route('/correction-packs/<pack_id>/corrections/<correction_id>', methods=['PUT'])
def update_correction(pack_id: str, correction_id: str):
    """Update fields of an existing correction.

    Accepted JSON fields: corrected_target, corrected_params,
    correction_description, status, tags. Returns 404 when either the
    pack or the correction is unknown to the service.
    """
    try:
        svc = get_service()
        payload = request.get_json() or {}
        updated = svc.update_correction(pack_id, correction_id, payload)
        if updated:
            return jsonify({'success': True, 'correction': updated})
        return jsonify({'success': False, 'error': 'Correction not found'}), 404
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
@correction_packs_bp.route('/correction-packs/<pack_id>/corrections/<correction_id>', methods=['DELETE'])
def remove_correction(pack_id: str, correction_id: str):
    """Remove a single correction from a pack; 404 when not found."""
    try:
        svc = get_service()
        removed = svc.remove_correction(pack_id, correction_id)
        if removed:
            return jsonify({'success': True, 'message': 'Correction removed'})
        return jsonify({'success': False, 'error': 'Correction not found'}), 404
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
# ========== Aggregation Endpoints ==========
@correction_packs_bp.route('/correction-packs/<pack_id>/aggregate', methods=['POST'])
def aggregate_from_sessions(pack_id: str):
    """Aggregate corrections from training sessions into a pack.

    Optional JSON filters: workflow_id, min_occurrences (default 1),
    min_success_rate (default 0.0), session_files. Returns the
    aggregation summary produced by the service.
    """
    try:
        svc = get_service()
        payload = request.get_json() or {}
        summary = svc.aggregate_from_sessions(
            pack_id=pack_id,
            session_files=payload.get('session_files'),
            workflow_id=payload.get('workflow_id'),
            min_occurrences=payload.get('min_occurrences', 1),
            min_success_rate=payload.get('min_success_rate', 0.0)
        )
        return jsonify({'success': True, 'result': summary})
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
@correction_packs_bp.route('/correction-packs/<pack_id>/aggregate-cross-workflow', methods=['POST'])
def aggregate_cross_workflow(pack_id: str):
    """Aggregate corrections that recur across several workflows.

    JSON body: min_workflows — the minimum number of distinct workflows a
    correction must appear in (default 2). Returns the summary with
    detected patterns.
    """
    try:
        svc = get_service()
        payload = request.get_json() or {}
        summary = svc.aggregate_cross_workflow(
            pack_id=pack_id,
            min_workflows=payload.get('min_workflows', 2)
        )
        return jsonify({'success': True, 'result': summary})
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
# ========== Search and Apply Endpoints ==========
@correction_packs_bp.route('/correction-packs/find', methods=['POST'])
def find_corrections():
    """Find corrections applicable to a failure context.

    JSON body: action_type (required), element_type, failure_context,
    pack_ids (restricts the search), min_confidence (default 0.3).
    Returns the matching corrections.
    """
    try:
        svc = get_service()
        payload = request.get_json() or {}
        if not payload.get('action_type'):
            return jsonify({'success': False, 'error': 'action_type is required'}), 400
        matches = svc.find_applicable_corrections(
            action_type=payload['action_type'],
            element_type=payload.get('element_type', 'unknown'),
            failure_context=payload.get('failure_context', ''),
            pack_ids=payload.get('pack_ids'),
            min_confidence=payload.get('min_confidence', 0.3)
        )
        return jsonify({'success': True, 'corrections': matches})
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
@correction_packs_bp.route('/correction-packs/apply', methods=['POST'])
def apply_correction():
    """Record one application of a correction (success or failure).

    JSON body: pack_id (required), correction_id (required), success
    (required boolean), context (optional). The outcome feeds the
    correction's usage statistics.
    """
    try:
        svc = get_service()
        payload = request.get_json() or {}
        if not payload.get('pack_id') or not payload.get('correction_id'):
            return jsonify({'success': False, 'error': 'pack_id and correction_id are required'}), 400
        # 'success' may legitimately be False, so test for key presence.
        if 'success' not in payload:
            return jsonify({'success': False, 'error': 'success is required'}), 400
        svc.apply_correction(
            pack_id=payload['pack_id'],
            correction_id=payload['correction_id'],
            success=payload['success'],
            application_context=payload.get('context')
        )
        return jsonify({'success': True, 'message': 'Application recorded'})
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
@correction_packs_bp.route('/correction-packs/search', methods=['GET'])
def search_corrections():
    """Free-text search over corrections.

    Query params: q (required search text), pack_id (optional scope).
    Returns the matching corrections.
    """
    try:
        svc = get_service()
        query = request.args.get('q')
        if not query:
            return jsonify({'success': False, 'error': 'q parameter is required'}), 400
        matches = svc.search_corrections(query, request.args.get('pack_id'))
        return jsonify({'success': True, 'corrections': matches})
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
# ========== Export/Import Endpoints ==========
@correction_packs_bp.route('/correction-packs/<pack_id>/export', methods=['GET'])
def export_pack(pack_id: str):
    """
    Export a pack to JSON or YAML as a file download.

    Query params:
        format: Export format ('json' or 'yaml', default: 'json')

    Returns:
        The serialized pack content with an attachment Content-Disposition
        header, 400 for an unknown format, 404 when the pack is missing.
    """
    try:
        service = get_service()
        format_type = request.args.get('format', 'json')
        if format_type not in ['json', 'yaml']:
            return jsonify({'success': False, 'error': 'Invalid format'}), 400
        content = service.export_pack(pack_id, format=format_type)
        if not content:
            return jsonify({'success': False, 'error': 'Pack not found'}), 404
        # Choose MIME type and download filename matching the export format.
        if format_type == 'yaml':
            content_type = 'application/x-yaml'
            filename = f'correction_pack_{pack_id}.yaml'
        else:
            content_type = 'application/json'
            filename = f'correction_pack_{pack_id}.json'
        return Response(
            content,
            mimetype=content_type,
            headers={
                # BUG FIX: the computed filename was never interpolated here —
                # a literal placeholder was sent, so browsers saved the file
                # under a bogus name. Use the filename built above.
                'Content-Disposition': f'attachment; filename={filename}'
            }
        )
    except Exception as e:
        return jsonify({'success': False, 'error': str(e)}), 500
@correction_packs_bp.route('/correction-packs/import', methods=['POST'])
def import_pack():
    """Import a pack from JSON or YAML.

    Two input modes:
      * multipart upload under 'file' — format is auto-detected from the
        filename extension (.yaml/.yml -> yaml, otherwise json) and
        merge_strategy is read from the form data;
      * JSON body with 'content', 'format' and 'merge_strategy'
        ('create_new' | 'merge' | 'replace').
    Returns the imported pack with 201, or 400 on failure.
    """
    try:
        svc = get_service()
        upload = request.files.get('file')
        if upload is not None:
            content = upload.read().decode('utf-8')
            name = upload.filename
            # Infer YAML from the uploaded file's extension; default to JSON.
            is_yaml = bool(name) and (name.endswith('.yaml') or name.endswith('.yml'))
            format_type = 'yaml' if is_yaml else 'json'
            merge_strategy = request.form.get('merge_strategy', 'create_new')
        else:
            payload = request.get_json() or {}
            content = payload.get('content')
            if not content:
                return jsonify({'success': False, 'error': 'content or file is required'}), 400
            format_type = payload.get('format', 'json')
            merge_strategy = payload.get('merge_strategy', 'create_new')
        imported = svc.import_pack(content, format=format_type, merge_strategy=merge_strategy)
        if imported:
            return jsonify({'success': True, 'pack': imported}), 201
        return jsonify({'success': False, 'error': 'Failed to import pack'}), 400
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
# ========== Version Endpoints ==========
@correction_packs_bp.route('/correction-packs/<pack_id>/versions', methods=['GET'])
def list_versions(pack_id: str):
    """List the version history of a pack."""
    try:
        svc = get_service()
        history = svc.list_versions(pack_id)
        return jsonify({'success': True, 'versions': history})
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
@correction_packs_bp.route('/correction-packs/<pack_id>/versions', methods=['POST'])
def create_version(pack_id: str):
    """Snapshot the current state of a pack as a new version.

    JSON body: description (optional). The service signals an unknown
    pack by returning a negative version number.
    """
    try:
        svc = get_service()
        payload = request.get_json() or {}
        new_version = svc.create_version(
            pack_id=pack_id,
            description=payload.get('description', '')
        )
        if new_version < 0:
            return jsonify({'success': False, 'error': 'Pack not found'}), 404
        return jsonify({'success': True, 'version': new_version}), 201
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
@correction_packs_bp.route('/correction-packs/<pack_id>/rollback', methods=['POST'])
def rollback_pack(pack_id: str):
    """Restore a pack to an earlier version.

    JSON body: version (required version number). Returns 400 when the
    version is missing or the service reports a failed rollback.
    """
    try:
        svc = get_service()
        payload = request.get_json() or {}
        target_version = payload.get('version')
        # None check (not truthiness): version 0 is a valid target.
        if target_version is None:
            return jsonify({'success': False, 'error': 'version is required'}), 400
        if not svc.rollback_pack(pack_id, target_version):
            return jsonify({'success': False, 'error': 'Rollback failed'}), 400
        return jsonify({'success': True, 'message': f'Rolled back to version {target_version}'})
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
# ========== Statistics Endpoints ==========
@correction_packs_bp.route('/correction-packs/<pack_id>/statistics', methods=['GET'])
def get_pack_statistics(pack_id: str):
    """Return detailed statistics for one pack; 404 when unknown."""
    try:
        svc = get_service()
        stats = svc.get_pack_statistics(pack_id)
        if stats:
            return jsonify({'success': True, 'statistics': stats})
        return jsonify({'success': False, 'error': 'Pack not found'}), 404
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500
@correction_packs_bp.route('/correction-packs/statistics', methods=['GET'])
def get_global_statistics():
    """Return aggregate statistics across every pack."""
    try:
        svc = get_service()
        return jsonify({'success': True, 'statistics': svc.get_global_statistics()})
    except Exception as exc:
        return jsonify({'success': False, 'error': str(exc)}), 500

View File

@@ -0,0 +1,201 @@
"""
Visual Workflow Builder - Backend Flask Application
This is the main entry point for the Visual Workflow Builder backend API.
It provides REST endpoints for workflow management and WebSocket support
for real-time execution updates.
"""
from flask import Flask
from flask_cors import CORS
from flask_socketio import SocketIO
from flask_sqlalchemy import SQLAlchemy
from flask_caching import Cache
import os
from dotenv import load_dotenv
# Load environment variables from a local .env file (no-op if absent).
load_dotenv()
# Initialize Flask app
app = Flask(__name__)
# Configuration — every value can be overridden via environment variables.
# NOTE(review): the SECRET_KEY fallback is a well-known dev value; deployments
# must set SECRET_KEY in the environment.
app.config['SECRET_KEY'] = os.getenv('SECRET_KEY', 'dev-secret-key-change-in-production')
# Defaults to a local SQLite file when DATABASE_URL is not provided.
app.config['SQLALCHEMY_DATABASE_URI'] = os.getenv('DATABASE_URL', 'sqlite:///workflows.db')
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
app.config['MAX_CONTENT_LENGTH'] = 10 * 1024 * 1024  # 10MB max upload
# Use Redis-backed caching only when REDIS_URL is configured; otherwise an
# in-process 'simple' cache.
app.config['CACHE_TYPE'] = 'redis' if os.getenv('REDIS_URL') else 'simple'
app.config['CACHE_REDIS_URL'] = os.getenv('REDIS_URL', 'redis://localhost:6379/0')
# Initialize extensions (DB, cache, Socket.IO for real-time updates).
db = SQLAlchemy(app)
cache = Cache(app)
# NOTE(review): Socket.IO allows all origins and has verbose logging enabled —
# presumably for development; tighten for production.
socketio = SocketIO(
    app,
    cors_allowed_origins="*",
    async_mode='threading',
    logger=True,
    engineio_logger=True
)
# Enable CORS on /api/* routes only; allowed origins come from CORS_ORIGINS
# (comma-separated), defaulting to the local frontend dev server.
CORS(app, resources={
    r"/api/*": {
        "origins": os.getenv('CORS_ORIGINS', 'http://localhost:3000').split(','),
        "methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
        "allow_headers": ["Content-Type", "Authorization"]
    }
})
# Import and register blueprints (minimal set) — these are mandatory:
# a failed import here aborts startup.
from api.workflows import workflows_bp
from api.screen_capture import screen_capture_bp
from api.real_demo import real_demo_bp
from api.errors import error_response
app.register_blueprint(workflows_bp, url_prefix='/api/workflows')
app.register_blueprint(screen_capture_bp, url_prefix='/api/screen-capture')
app.register_blueprint(real_demo_bp)
# Optional / Phase 2+ blueprints (loaded only if modules are available).
# Each block degrades gracefully: on ImportError the feature is disabled,
# a warning is printed, and startup continues.
try:
    from api.self_healing import self_healing_bp
    app.register_blueprint(self_healing_bp)
except ImportError as e:
    print(f"⚠️ Blueprint self_healing désactivé: {e}")
try:
    from api.visual_targets import visual_targets_bp, init_visual_target_manager
    app.register_blueprint(visual_targets_bp)
    VISUAL_TARGETS_BP_AVAILABLE = True
except ImportError as e:
    print(f"⚠️ Blueprint visual_targets désactivé: {e}")
    # Flag + init function are defined on both paths so the later
    # initialization code can test them unconditionally.
    VISUAL_TARGETS_BP_AVAILABLE = False
    init_visual_target_manager = None
try:
    from api.element_detection import element_detection_bp, init_element_detection
    app.register_blueprint(element_detection_bp)
    ELEMENT_DETECTION_BP_AVAILABLE = True
except ImportError as e:
    print(f"⚠️ Blueprint element_detection désactivé: {e}")
    ELEMENT_DETECTION_BP_AVAILABLE = False
    init_element_detection = None
try:
    from api.analytics import analytics_bp
    app.register_blueprint(analytics_bp, url_prefix='/api/analytics')
except ImportError:
    # Analytics is fully optional; silently skipped when unavailable.
    pass
# Register other blueprints (optional - depends on Phase 2+ services)
try:
    from api.templates import templates_bp
    app.register_blueprint(templates_bp, url_prefix='/api/templates')
except ImportError as e:
    print(f"⚠️ Blueprint templates désactivé: {e}")
# node_types is mandatory (not wrapped in try/except).
from api.node_types import node_types_bp
app.register_blueprint(node_types_bp, url_prefix='/api/node-types')
try:
    from api.executions import executions_bp
    app.register_blueprint(executions_bp, url_prefix='/api/executions')
except ImportError as e:
    print(f"⚠️ Blueprint executions désactivé: {e}")
try:
    from api.import_export import import_export_bp
    app.register_blueprint(import_export_bp, url_prefix='/api')
except ImportError as e:
    print(f"⚠️ Blueprint import_export désactivé: {e}")
# Correction Packs REST API (15 endpoints under /api/correction-packs).
try:
    from api.correction_packs import correction_packs_bp
    app.register_blueprint(correction_packs_bp, url_prefix='/api')
    print("✅ Blueprint correction_packs enregistré")
except ImportError as e:
    print(f"⚠️ Blueprint correction_packs désactivé: {e}")
# Import WebSocket handlers (optional) — importing the module registers its
# Socket.IO event handlers as a side effect.
try:
    from api import websocket_handlers # noqa: F401
except Exception as e:
    print(f"⚠️ WebSocket handlers désactivés: {e}")
# Global error handlers
@app.errorhandler(404)
def not_found(error):
    """Answer requests that match no route with a JSON 404 payload."""
    status = 404
    return error_response(status, "Resource not found")
@app.errorhandler(405)
def method_not_allowed(error):
    """Answer requests using an unsupported HTTP verb with a JSON 405."""
    status = 405
    return error_response(status, "Method not allowed")
@app.errorhandler(500)
def internal_error(error):
    """Answer server-side failures with a generic JSON 500 payload."""
    status = 500
    return error_response(status, "Internal server error")
@app.errorhandler(Exception)
def handle_exception(error):
    """Catch-all: dump the traceback to stderr and answer with a JSON 500."""
    import traceback
    traceback.print_exc()
    message = f"Unexpected error: {str(error)}"
    return error_response(500, message)
# Health check endpoint
@app.route('/health')
def health_check():
    """Liveness probe for monitoring: reports service status and version."""
    report = {}
    report['status'] = 'healthy'
    report['version'] = '1.0.0'
    return report
# Create database tables for all SQLAlchemy models at import time.
with app.app_context():
    db.create_all()
# Initialize VisualTargetManager with RPA Vision V3 components (optional).
# Skipped entirely (with a warning) when the core vision modules are missing.
try:
    from core.capture.screen_capturer import ScreenCapturer
    from core.detection.ui_detector import UIDetector
    from core.embedding.fusion_engine import FusionEngine
    # Only initialize if the related blueprints were actually loaded
    if VISUAL_TARGETS_BP_AVAILABLE and init_visual_target_manager:
        screen_capturer = ScreenCapturer()
        ui_detector = UIDetector()
        fusion_engine = FusionEngine()
        init_visual_target_manager(screen_capturer, ui_detector, fusion_engine)
    if ELEMENT_DETECTION_BP_AVAILABLE and init_element_detection:
        # Reuse the same instances when possible.
        # NOTE(review): at module level locals() is the module namespace, so
        # these presence checks work, but the idiom is fragile — confirm if
        # this code ever moves inside a function.
        if 'ui_detector' not in locals():
            ui_detector = UIDetector()
        if 'screen_capturer' not in locals():
            screen_capturer = ScreenCapturer()
        init_element_detection(ui_detector, screen_capturer)
    if (VISUAL_TARGETS_BP_AVAILABLE and init_visual_target_manager) or (ELEMENT_DETECTION_BP_AVAILABLE and init_element_detection):
        print("✅ Services visuels initialisés (VisualTargets / ElementDetection)")
except ImportError as e:
    # Core vision package not installed — feature disabled, app still boots.
    print(f"⚠️ Core RPA non disponible pour l'initialisation visuelle: {e}")
except Exception as e:
    # Any other initialization failure is logged but non-fatal.
    print(f"❌ Erreur lors de l'initialisation des services visuels: {e}")
if __name__ == '__main__':
    # Port and debug mode are environment-driven; debug also enables the
    # auto-reloader below.
    port = int(os.getenv('PORT', 5002))
    debug = os.getenv('FLASK_ENV') == 'development'
    # Run through Socket.IO's server wrapper so WebSocket support works.
    # Binds on all interfaces (0.0.0.0).
    socketio.run(
        app,
        host='0.0.0.0',
        port=port,
        debug=debug,
        use_reloader=debug,
        allow_unsafe_werkzeug=True # For development only — do not ship with the Werkzeug dev server
    )