diff --git a/core/corrections/__init__.py b/core/corrections/__init__.py
new file mode 100644
index 000000000..778def9c9
--- /dev/null
+++ b/core/corrections/__init__.py
@@ -0,0 +1,30 @@
+"""
+Correction Packs System - Cross-workflow correction capture and reuse.
+
+This module provides a system for storing, aggregating, and applying
+user corrections across multiple workflows and sessions.
+"""
+
+from .models import (
+    CorrectionKey,
+    CorrectionType,
+    CorrectionStatus,
+    Correction,
+    CorrectionPack,
+    CorrectionPackMetadata
+)
+from .correction_repository import CorrectionRepository
+from .aggregator import CorrectionAggregator
+from .correction_pack_service import CorrectionPackService
+
+__all__ = [
+    'CorrectionKey',
+    'CorrectionType',
+    'CorrectionStatus',
+    'Correction',
+    'CorrectionPack',
+    'CorrectionPackMetadata',
+    'CorrectionRepository',
+    'CorrectionAggregator',
+    'CorrectionPackService'
+]
diff --git a/core/corrections/aggregator.py b/core/corrections/aggregator.py
new file mode 100644
index 000000000..1185a32c6
--- /dev/null
+++ b/core/corrections/aggregator.py
@@ -0,0 +1,545 @@
+"""
+Correction Aggregator - Aggregate corrections from multiple sessions.
+
+Groups corrections by key hash and filters by quality thresholds.
+"""
+
+import logging
+from collections import defaultdict
+from datetime import datetime
+from typing import Dict, List, Optional, Any, Tuple
+
+from .models import (
+    Correction,
+    CorrectionKey,
+    CorrectionType,
+    CorrectionSource,
+    CorrectionStatus,
+    generate_correction_id
+)
+
+logger = logging.getLogger(__name__)
+
+
+class CorrectionAggregator:
+    """
+    Aggregates corrections from multiple sources (sessions, workflows).
+
+    Groups corrections by CorrectionKey hash and merges statistics.
+    Filters out low-quality corrections based on configurable thresholds.
+    """
+
+    def __init__(
+        self,
+        min_occurrences: int = 2,
+        min_success_rate: float = 0.5,
+        min_confidence: float = 0.3
+    ):
+        """
+        Initialize the aggregator.
+
+        Args:
+            min_occurrences: Minimum number of times a correction must occur
+            min_success_rate: Minimum success rate to include correction
+            min_confidence: Minimum confidence score to include correction
+        """
+        self.min_occurrences = min_occurrences
+        self.min_success_rate = min_success_rate
+        self.min_confidence = min_confidence
+
+    def aggregate_corrections(
+        self,
+        corrections: List[Correction],
+        merge_strategy: str = 'weighted_average'
+    ) -> List[Correction]:
+        """
+        Aggregate a list of corrections by key hash.
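+
+        Example (illustrative; ``corrections`` stands in for any list of
+        Correction objects collected from sessions)::
+
+            >>> aggregator = CorrectionAggregator(min_occurrences=2)
+            >>> merged = aggregator.aggregate_corrections(
+            ...     corrections, merge_strategy='best_confidence')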
+ + Args: + corrections: List of corrections to aggregate + merge_strategy: How to merge duplicate corrections: + - 'weighted_average': Weight by usage count + - 'best_confidence': Keep highest confidence + - 'most_recent': Keep most recently used + + Returns: + List of aggregated corrections + """ + # Group by key hash + grouped: Dict[str, List[Correction]] = defaultdict(list) + for correction in corrections: + key_hash = correction.key.to_hash() + grouped[key_hash].append(correction) + + # Aggregate each group + aggregated = [] + for key_hash, group in grouped.items(): + if len(group) == 1: + # Single correction, just check thresholds + if self._passes_thresholds(group[0]): + aggregated.append(group[0]) + else: + # Multiple corrections, merge + merged = self._merge_corrections(group, merge_strategy) + if merged and self._passes_thresholds(merged): + aggregated.append(merged) + + return aggregated + + def aggregate_from_sessions( + self, + session_corrections: List[Dict[str, Any]], + workflow_id: Optional[str] = None + ) -> List[Correction]: + """ + Aggregate corrections from session data (user_corrections format). + + Args: + session_corrections: List of correction dicts from TrainingSession + workflow_id: Optional workflow ID to filter by + + Returns: + List of aggregated Correction objects + """ + # Convert session corrections to Correction objects + corrections = [] + for sc in session_corrections: + try: + correction = self._session_correction_to_correction(sc) + if correction: + # Filter by workflow if specified + if workflow_id is None or (correction.source and correction.source.workflow_id == workflow_id): + corrections.append(correction) + except Exception as e: + logger.warning(f"Error converting session correction: {e}") + + # Aggregate + return self.aggregate_corrections(corrections) + + def group_by_type( + self, + corrections: List[Correction] + ) -> Dict[CorrectionType, List[Correction]]: + """ + Group corrections by their type. + + Args: + corrections: List of corrections + + Returns: + Dict mapping correction type to list of corrections + """ + grouped: Dict[CorrectionType, List[Correction]] = defaultdict(list) + for correction in corrections: + grouped[correction.correction_type].append(correction) + return dict(grouped) + + def group_by_action_type( + self, + corrections: List[Correction] + ) -> Dict[str, List[Correction]]: + """ + Group corrections by the action type they apply to. + + Args: + corrections: List of corrections + + Returns: + Dict mapping action type to list of corrections + """ + grouped: Dict[str, List[Correction]] = defaultdict(list) + for correction in corrections: + grouped[correction.key.action_type].append(correction) + return dict(grouped) + + def filter_by_quality( + self, + corrections: List[Correction], + min_success_rate: Optional[float] = None, + min_confidence: Optional[float] = None, + min_applications: Optional[int] = None + ) -> List[Correction]: + """ + Filter corrections by quality thresholds. 
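+
+        Example (illustrative; per-call thresholds override the instance
+        defaults)::
+
+            >>> reliable = aggregator.filter_by_quality(
+            ...     corrections, min_success_rate=0.8, min_applications=5)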
+ + Args: + corrections: List of corrections to filter + min_success_rate: Minimum success rate (overrides default) + min_confidence: Minimum confidence (overrides default) + min_applications: Minimum number of applications + + Returns: + Filtered list of corrections + """ + min_sr = min_success_rate if min_success_rate is not None else self.min_success_rate + min_conf = min_confidence if min_confidence is not None else self.min_confidence + min_apps = min_applications if min_applications is not None else 1 + + return [ + c for c in corrections + if c.success_rate >= min_sr + and c.confidence_score >= min_conf + and (c.success_count + c.failure_count) >= min_apps + ] + + def rank_corrections( + self, + corrections: List[Correction], + sort_by: str = 'confidence' + ) -> List[Correction]: + """ + Rank corrections by specified criteria. + + Args: + corrections: List of corrections + sort_by: Ranking criteria: + - 'confidence': By confidence score + - 'success_rate': By success rate + - 'usage': By total applications + - 'recent': By last application date + + Returns: + Sorted list of corrections + """ + if sort_by == 'confidence': + return sorted(corrections, key=lambda c: c.confidence_score, reverse=True) + elif sort_by == 'success_rate': + return sorted(corrections, key=lambda c: c.success_rate, reverse=True) + elif sort_by == 'usage': + return sorted(corrections, key=lambda c: c.success_count + c.failure_count, reverse=True) + elif sort_by == 'recent': + return sorted( + corrections, + key=lambda c: c.last_applied or datetime.min, + reverse=True + ) + else: + return corrections + + def calculate_statistics( + self, + corrections: List[Correction] + ) -> Dict[str, Any]: + """ + Calculate aggregate statistics for a list of corrections. + + Args: + corrections: List of corrections + + Returns: + Statistics dict + """ + if not corrections: + return { + 'count': 0, + 'total_applications': 0, + 'average_success_rate': 0.0, + 'average_confidence': 0.0, + 'by_type': {}, + 'by_action': {} + } + + total_apps = sum(c.success_count + c.failure_count for c in corrections) + total_success = sum(c.success_count for c in corrections) + + # Group by type + by_type = self.group_by_type(corrections) + type_stats = { + t.value: len(cs) for t, cs in by_type.items() + } + + # Group by action + by_action = self.group_by_action_type(corrections) + action_stats = { + a: len(cs) for a, cs in by_action.items() + } + + return { + 'count': len(corrections), + 'total_applications': total_apps, + 'total_success': total_success, + 'total_failure': total_apps - total_success, + 'average_success_rate': total_success / total_apps if total_apps > 0 else 0.0, + 'average_confidence': sum(c.confidence_score for c in corrections) / len(corrections), + 'by_type': type_stats, + 'by_action': action_stats + } + + def _passes_thresholds(self, correction: Correction) -> bool: + """Check if a correction passes quality thresholds.""" + total = correction.success_count + correction.failure_count + + # Always include if no applications yet (new correction) + if total == 0: + return True + + # Check minimum occurrences + if total < self.min_occurrences: + return False + + # Check success rate + if correction.success_rate < self.min_success_rate: + return False + + # Check confidence + if correction.confidence_score < self.min_confidence: + return False + + return True + + def _merge_corrections( + self, + corrections: List[Correction], + strategy: str = 'weighted_average' + ) -> Optional[Correction]: + """ + Merge multiple corrections with 
the same key hash. + + Args: + corrections: List of corrections to merge + strategy: Merge strategy + + Returns: + Merged correction + """ + if not corrections: + return None + + if len(corrections) == 1: + return corrections[0] + + if strategy == 'best_confidence': + return max(corrections, key=lambda c: c.confidence_score) + + elif strategy == 'most_recent': + return max( + corrections, + key=lambda c: c.last_applied or c.created_at + ) + + else: # weighted_average (default) + # Use the correction with highest usage as base + base = max(corrections, key=lambda c: c.success_count + c.failure_count) + + # Sum statistics + total_success = sum(c.success_count for c in corrections) + total_failure = sum(c.failure_count for c in corrections) + + # Find most recent application + last_applied = None + for c in corrections: + if c.last_applied: + if last_applied is None or c.last_applied > last_applied: + last_applied = c.last_applied + + # Merge tags + all_tags = set() + for c in corrections: + all_tags.update(c.tags) + + # Create merged correction + merged = Correction( + id=generate_correction_id(), + key=base.key, + correction_type=base.correction_type, + original_target=base.original_target, + original_params=base.original_params, + corrected_target=base.corrected_target, + corrected_params=base.corrected_params, + correction_description=base.correction_description, + source=base.source, # Keep first source as primary + success_count=total_success, + failure_count=total_failure, + last_applied=last_applied, + status=CorrectionStatus.ACTIVE, + created_at=min(c.created_at for c in corrections), + updated_at=datetime.now(), + tags=list(all_tags) + ) + + return merged + + def _session_correction_to_correction( + self, + session_correction: Dict[str, Any] + ) -> Optional[Correction]: + """ + Convert a session correction dict to a Correction object. 
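+
+        Input dict shape this method understands (values illustrative; the
+        legacy aliases ``type``, ``reason``, ``new_target`` and
+        ``new_params`` are also accepted)::
+
+            {
+                'action_type': 'click',
+                'element_type': 'button',
+                'failure_reason': 'element not found',
+                'corrected_target': {'selector': '#submit'},
+                'session_id': 'session_001',
+                'applied': True
+            }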
+ + Args: + session_correction: Dict from TrainingSession.user_corrections + + Returns: + Correction object or None if invalid + """ + try: + # Extract key components + action_type = session_correction.get('action_type', session_correction.get('type', 'unknown')) + element_type = session_correction.get('element_type', 'unknown') + failure_context = session_correction.get('failure_reason', session_correction.get('reason', '')) + + key = CorrectionKey(action_type, element_type, failure_context) + + # Determine correction type + correction_type_str = session_correction.get('correction_type', 'other') + try: + correction_type = CorrectionType(correction_type_str) + except ValueError: + correction_type = CorrectionType.OTHER + + # Create source + source = CorrectionSource( + session_id=session_correction.get('session_id', ''), + workflow_id=session_correction.get('workflow_id'), + node_id=session_correction.get('node_id') + ) + + timestamp = session_correction.get('timestamp') + if timestamp: + if isinstance(timestamp, str): + source.timestamp = datetime.fromisoformat(timestamp) + elif isinstance(timestamp, datetime): + source.timestamp = timestamp + + # Create correction + return Correction( + id=generate_correction_id(), + key=key, + correction_type=correction_type, + original_target=session_correction.get('original_target'), + original_params=session_correction.get('original_params'), + corrected_target=session_correction.get('corrected_target', session_correction.get('new_target')), + corrected_params=session_correction.get('corrected_params', session_correction.get('new_params')), + correction_description=session_correction.get('description', ''), + source=source, + success_count=1 if session_correction.get('applied', True) else 0, + failure_count=0 if session_correction.get('applied', True) else 1, + tags=session_correction.get('tags', []) + ) + + except Exception as e: + logger.warning(f"Error converting session correction: {e}") + return None + + +class CrossWorkflowAggregator(CorrectionAggregator): + """ + Extended aggregator for cross-workflow correction analysis. + + Provides additional methods for analyzing patterns across workflows. + """ + + def find_common_corrections( + self, + corrections: List[Correction], + min_workflows: int = 2 + ) -> List[Tuple[Correction, List[str]]]: + """ + Find corrections that appear in multiple workflows. + + Args: + corrections: List of corrections + min_workflows: Minimum number of workflows + + Returns: + List of (correction, workflow_ids) tuples + """ + # Group by key hash + grouped: Dict[str, List[Correction]] = defaultdict(list) + for correction in corrections: + key_hash = correction.key.to_hash() + grouped[key_hash].append(correction) + + # Find common corrections + common = [] + for key_hash, group in grouped.items(): + workflow_ids = set() + for c in group: + if c.source and c.source.workflow_id: + workflow_ids.add(c.source.workflow_id) + + if len(workflow_ids) >= min_workflows: + # Merge the group + merged = self._merge_corrections(group) + if merged: + common.append((merged, list(workflow_ids))) + + # Sort by number of workflows + common.sort(key=lambda x: len(x[1]), reverse=True) + return common + + def detect_patterns( + self, + corrections: List[Correction] + ) -> Dict[str, Any]: + """ + Detect patterns in corrections. 
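+
+        The returned dict has this shape (values illustrative)::
+
+            {
+                'patterns': [
+                    {'type': 'frequent_action_correction', 'action_type': 'click',
+                     'count': 5, 'avg_success_rate': 0.8}
+                ],
+                'summary': {'total_corrections': 12, 'unique_action_types': 3,
+                            'unique_element_types': 4, 'pattern_count': 1}
+            }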
+ + Args: + corrections: List of corrections + + Returns: + Pattern analysis dict + """ + if not corrections: + return {'patterns': [], 'summary': {}} + + patterns = [] + + # Pattern 1: Common action type corrections + by_action = self.group_by_action_type(corrections) + for action, action_corrections in by_action.items(): + if len(action_corrections) >= 3: + patterns.append({ + 'type': 'frequent_action_correction', + 'action_type': action, + 'count': len(action_corrections), + 'avg_success_rate': sum(c.success_rate for c in action_corrections) / len(action_corrections) + }) + + # Pattern 2: Common element type failures + by_element: Dict[str, List[Correction]] = defaultdict(list) + for c in corrections: + by_element[c.key.element_type].append(c) + + for element, element_corrections in by_element.items(): + if len(element_corrections) >= 3: + patterns.append({ + 'type': 'element_type_issues', + 'element_type': element, + 'count': len(element_corrections), + 'common_fixes': self._find_common_fix_patterns(element_corrections) + }) + + # Pattern 3: Time-based patterns + recent = [c for c in corrections if c.last_applied and + (datetime.now() - c.last_applied).days < 7] + if len(recent) > len(corrections) * 0.5: + patterns.append({ + 'type': 'recent_activity_spike', + 'recent_count': len(recent), + 'total_count': len(corrections) + }) + + return { + 'patterns': patterns, + 'summary': { + 'total_corrections': len(corrections), + 'unique_action_types': len(by_action), + 'unique_element_types': len(by_element), + 'pattern_count': len(patterns) + } + } + + def _find_common_fix_patterns( + self, + corrections: List[Correction] + ) -> List[str]: + """Find common fix patterns in a list of corrections.""" + fix_types = defaultdict(int) + for c in corrections: + fix_types[c.correction_type.value] += 1 + + # Return top 3 most common + sorted_fixes = sorted(fix_types.items(), key=lambda x: x[1], reverse=True) + return [fix[0] for fix in sorted_fixes[:3]] diff --git a/core/corrections/correction_pack_service.py b/core/corrections/correction_pack_service.py new file mode 100644 index 000000000..a065b6492 --- /dev/null +++ b/core/corrections/correction_pack_service.py @@ -0,0 +1,785 @@ +""" +Correction Pack Service - Main service for correction pack management. + +Combines repository, aggregator, and provides high-level operations +for CRUD, aggregation, export/import, and application of corrections. +""" + +import json +import logging +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional, Any, Tuple + +from .models import ( + Correction, + CorrectionPack, + CorrectionKey, + CorrectionType, + CorrectionSource, + CorrectionStatus, + CorrectionPackMetadata, + generate_correction_id, + generate_pack_id +) +from .correction_repository import CorrectionRepository +from .aggregator import CorrectionAggregator, CrossWorkflowAggregator + +logger = logging.getLogger(__name__) + + +class CorrectionPackService: + """ + Main service for managing correction packs. + + Provides high-level operations combining the repository and aggregator. + """ + + def __init__( + self, + storage_path: str = "data/correction_packs", + training_data_path: str = "training_data" + ): + """ + Initialize the service. 
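+
+        Quickstart (illustrative; pack name and category are assumptions)::
+
+            >>> service = CorrectionPackService()
+            >>> pack = service.create_pack('Login fixes', category='auth')
+            >>> service.aggregate_from_sessions(pack['id'])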
+ + Args: + storage_path: Path for correction pack storage + training_data_path: Path to training data (for session import) + """ + self.storage_path = Path(storage_path) + self.training_data_path = Path(training_data_path) + + # Initialize components + self.repository = CorrectionRepository(storage_path) + self.aggregator = CorrectionAggregator() + self.cross_workflow_aggregator = CrossWorkflowAggregator() + + logger.info(f"CorrectionPackService initialized (storage={storage_path})") + + # ========== Pack CRUD Operations ========== + + def list_packs( + self, + category: Optional[str] = None, + tags: Optional[List[str]] = None + ) -> List[Dict[str, Any]]: + """ + List all correction packs with optional filtering. + + Args: + category: Filter by category + tags: Filter by tags (match any) + + Returns: + List of pack summaries + """ + packs = self.repository.list_packs() + + # Filter by category + if category: + packs = [p for p in packs if p.metadata.category == category] + + # Filter by tags + if tags: + packs = [ + p for p in packs + if any(tag in p.metadata.tags for tag in tags) + ] + + # Return summaries + return [ + { + 'id': p.id, + 'name': p.name, + 'description': p.description, + 'category': p.metadata.category, + 'tags': p.metadata.tags, + 'correction_count': p.correction_count, + 'success_rate': p.overall_success_rate, + 'created_at': p.created_at.isoformat(), + 'updated_at': p.updated_at.isoformat() + } + for p in packs + ] + + def get_pack(self, pack_id: str) -> Optional[Dict[str, Any]]: + """ + Get a pack by ID with full details. + + Args: + pack_id: Pack ID + + Returns: + Pack dict or None + """ + pack = self.repository.get_pack(pack_id) + if not pack: + return None + return pack.to_dict() + + def create_pack( + self, + name: str, + description: str = "", + category: str = "general", + tags: Optional[List[str]] = None, + author: str = "" + ) -> Dict[str, Any]: + """ + Create a new correction pack. + + Args: + name: Pack name + description: Pack description + category: Pack category + tags: Pack tags + author: Pack author + + Returns: + Created pack dict + """ + metadata = { + 'category': category, + 'tags': tags or [], + 'author': author + } + + pack = self.repository.create_pack(name, description, metadata) + return pack.to_dict() + + def update_pack( + self, + pack_id: str, + name: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None + ) -> Optional[Dict[str, Any]]: + """ + Update a pack's properties. + + Args: + pack_id: Pack ID + name: New name + description: New description + metadata: New metadata + + Returns: + Updated pack dict or None + """ + updates = {} + if name is not None: + updates['name'] = name + if description is not None: + updates['description'] = description + if metadata is not None: + updates['metadata'] = metadata + + pack = self.repository.update_pack(pack_id, updates) + if not pack: + return None + return pack.to_dict() + + def delete_pack(self, pack_id: str) -> bool: + """ + Delete a pack. + + Args: + pack_id: Pack ID + + Returns: + True if deleted + """ + return self.repository.delete_pack(pack_id) + + # ========== Correction Operations ========== + + def add_correction( + self, + pack_id: str, + correction_data: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: + """ + Add a correction to a pack. 
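+
+        Recognized ``correction_data`` keys (all optional, with defaults
+        substituted; values illustrative)::
+
+            {
+                'action_type': 'click',
+                'element_type': 'button',
+                'failure_context': 'element not found',
+                'correction_type': 'target_change',
+                'corrected_target': {'selector': '#submit'},
+                'description': 'Use the submit button id instead',
+                'source': {'session_id': 'session_001'},
+                'tags': ['login']
+            }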
+ + Args: + pack_id: Pack ID + correction_data: Correction data dict + + Returns: + Created correction dict or None + """ + try: + # Build correction key + key = CorrectionKey( + action_type=correction_data.get('action_type', 'unknown'), + element_type=correction_data.get('element_type', 'unknown'), + failure_context=correction_data.get('failure_context', '') + ) + + # Determine correction type + type_str = correction_data.get('correction_type', 'other') + try: + correction_type = CorrectionType(type_str) + except ValueError: + correction_type = CorrectionType.OTHER + + # Build source if provided + source = None + if 'source' in correction_data: + source = CorrectionSource( + session_id=correction_data['source'].get('session_id', ''), + workflow_id=correction_data['source'].get('workflow_id'), + node_id=correction_data['source'].get('node_id') + ) + + # Create correction + correction = Correction( + id=generate_correction_id(), + key=key, + correction_type=correction_type, + original_target=correction_data.get('original_target'), + original_params=correction_data.get('original_params'), + corrected_target=correction_data.get('corrected_target'), + corrected_params=correction_data.get('corrected_params'), + correction_description=correction_data.get('description', ''), + source=source, + tags=correction_data.get('tags', []) + ) + + if self.repository.add_correction(pack_id, correction): + return correction.to_dict() + return None + + except Exception as e: + logger.error(f"Error adding correction: {e}") + return None + + def add_correction_from_session( + self, + pack_id: str, + session_correction: Dict[str, Any], + session_id: str, + workflow_id: Optional[str] = None + ) -> Optional[Dict[str, Any]]: + """ + Add a correction from session user_corrections format. + + Args: + pack_id: Pack ID + session_correction: Correction dict from TrainingSession.user_corrections + session_id: Session ID + workflow_id: Workflow ID + + Returns: + Created correction dict or None + """ + # Enrich with source info + session_correction['source'] = { + 'session_id': session_id, + 'workflow_id': workflow_id, + 'node_id': session_correction.get('node_id') + } + + # Map old format fields to new format + if 'type' in session_correction and 'action_type' not in session_correction: + session_correction['action_type'] = session_correction['type'] + if 'reason' in session_correction and 'failure_context' not in session_correction: + session_correction['failure_context'] = session_correction['reason'] + if 'new_target' in session_correction and 'corrected_target' not in session_correction: + session_correction['corrected_target'] = session_correction['new_target'] + if 'new_params' in session_correction and 'corrected_params' not in session_correction: + session_correction['corrected_params'] = session_correction['new_params'] + + return self.add_correction(pack_id, session_correction) + + def update_correction( + self, + pack_id: str, + correction_id: str, + updates: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: + """ + Update a correction. + + Args: + pack_id: Pack ID + correction_id: Correction ID + updates: Fields to update + + Returns: + Updated correction dict or None + """ + correction = self.repository.update_correction(pack_id, correction_id, updates) + if not correction: + return None + return correction.to_dict() + + def remove_correction(self, pack_id: str, correction_id: str) -> bool: + """ + Remove a correction from a pack. 
+ + Args: + pack_id: Pack ID + correction_id: Correction ID + + Returns: + True if removed + """ + return self.repository.remove_correction(pack_id, correction_id) + + # ========== Aggregation Operations ========== + + def aggregate_from_sessions( + self, + pack_id: str, + session_files: Optional[List[str]] = None, + workflow_id: Optional[str] = None, + min_occurrences: int = 1, + min_success_rate: float = 0.0 + ) -> Dict[str, Any]: + """ + Aggregate corrections from training sessions into a pack. + + Args: + pack_id: Pack ID to add corrections to + session_files: Specific session files to process (None = all) + workflow_id: Filter by workflow ID + min_occurrences: Minimum occurrences to include + min_success_rate: Minimum success rate to include + + Returns: + Aggregation result summary + """ + # Collect all session corrections + all_corrections = [] + + if session_files is None: + # Find all session files + session_files = list(self.training_data_path.glob("session_*.json")) + else: + session_files = [Path(f) for f in session_files] + + sessions_processed = 0 + for session_file in session_files: + try: + with open(session_file, 'r', encoding='utf-8') as f: + session_data = json.load(f) + + # Filter by workflow if specified + if workflow_id and session_data.get('workflow_id') != workflow_id: + continue + + user_corrections = session_data.get('user_corrections', []) + session_id = session_data.get('session_id', session_file.stem) + + for uc in user_corrections: + uc['session_id'] = session_id + uc['workflow_id'] = session_data.get('workflow_id') + all_corrections.append(uc) + + sessions_processed += 1 + + except Exception as e: + logger.warning(f"Error reading session file {session_file}: {e}") + + # Configure aggregator + self.aggregator.min_occurrences = min_occurrences + self.aggregator.min_success_rate = min_success_rate + + # Aggregate + aggregated = self.aggregator.aggregate_from_sessions(all_corrections, workflow_id) + + # Add to pack + added_count = 0 + for correction in aggregated: + if self.repository.add_correction(pack_id, correction): + added_count += 1 + + return { + 'sessions_processed': sessions_processed, + 'corrections_found': len(all_corrections), + 'corrections_aggregated': len(aggregated), + 'corrections_added': added_count, + 'statistics': self.aggregator.calculate_statistics(aggregated) + } + + def aggregate_cross_workflow( + self, + pack_id: str, + min_workflows: int = 2 + ) -> Dict[str, Any]: + """ + Find and aggregate corrections common across multiple workflows. 
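+
+        Example (illustrative)::
+
+            >>> result = service.aggregate_cross_workflow(pack['id'],
+            ...                                           min_workflows=2)
+            >>> result['common_corrections_found']  # seen in >= 2 workflows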
+ + Args: + pack_id: Pack ID to add corrections to + min_workflows: Minimum number of workflows a correction must appear in + + Returns: + Aggregation result summary + """ + # Get existing corrections from all packs + all_corrections = [] + for pack in self.repository.list_packs(): + all_corrections.extend(pack.corrections.values()) + + # Find common corrections + common = self.cross_workflow_aggregator.find_common_corrections( + all_corrections, min_workflows + ) + + # Add to pack + added_count = 0 + for correction, workflow_ids in common: + # Tag with source workflows + correction.tags.extend([f"wf:{wid}" for wid in workflow_ids[:5]]) # Limit tags + correction.correction_description += f" (common in {len(workflow_ids)} workflows)" + + if self.repository.add_correction(pack_id, correction): + added_count += 1 + + # Detect patterns + patterns = self.cross_workflow_aggregator.detect_patterns(all_corrections) + + return { + 'total_corrections_analyzed': len(all_corrections), + 'common_corrections_found': len(common), + 'corrections_added': added_count, + 'patterns': patterns + } + + # ========== Search and Apply Operations ========== + + def find_applicable_corrections( + self, + action_type: str, + element_type: str, + failure_context: str, + pack_ids: Optional[List[str]] = None, + min_confidence: float = 0.3 + ) -> List[Dict[str, Any]]: + """ + Find corrections applicable to a failure context. + + Args: + action_type: Type of action that failed + element_type: Type of element involved + failure_context: Description of failure + pack_ids: Specific packs to search (None = all) + min_confidence: Minimum confidence score + + Returns: + List of applicable corrections with pack info + """ + results = self.repository.find_by_context(action_type, element_type, failure_context) + + # Filter by pack IDs if specified + if pack_ids: + results = [(pid, c) for pid, c in results if pid in pack_ids] + + # Filter by confidence + results = [(pid, c) for pid, c in results if c.confidence_score >= min_confidence] + + # Format results + return [ + { + 'pack_id': pack_id, + 'correction': correction.to_dict() + } + for pack_id, correction in results + ] + + def apply_correction( + self, + pack_id: str, + correction_id: str, + success: bool, + application_context: Optional[Dict[str, Any]] = None + ) -> bool: + """ + Record an application of a correction. + + Args: + pack_id: Pack ID + correction_id: Correction ID + success: Whether the application was successful + application_context: Optional context information + + Returns: + True if recorded + """ + self.repository.record_application(pack_id, correction_id, success) + + logger.info( + f"Recorded correction application: pack={pack_id}, " + f"correction={correction_id}, success={success}" + ) + return True + + def search_corrections( + self, + query: str, + pack_id: Optional[str] = None + ) -> List[Dict[str, Any]]: + """ + Search corrections by text query. + + Args: + query: Search query + pack_id: Optional pack ID to search in + + Returns: + List of matching corrections with pack info + """ + results = self.repository.search_corrections(query, pack_id) + + return [ + { + 'pack_id': pid, + 'correction': correction.to_dict() + } + for pid, correction in results + ] + + # ========== Export/Import Operations ========== + + def export_pack( + self, + pack_id: str, + format: str = 'json', + include_stats: bool = True + ) -> Optional[str]: + """ + Export a pack to JSON or YAML. 
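+
+        Example round trip (illustrative)::
+
+            >>> content = service.export_pack(pack['id'], format='json')
+            >>> copy = service.import_pack(content, format='json',
+            ...                            merge_strategy='create_new')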
+ + Args: + pack_id: Pack ID + format: Export format ('json' or 'yaml') + include_stats: Whether to include statistics + + Returns: + Exported content string or None + """ + return self.repository.export_pack(pack_id, format) + + def import_pack( + self, + content: str, + format: str = 'json', + merge_strategy: str = 'create_new' + ) -> Optional[Dict[str, Any]]: + """ + Import a pack from JSON or YAML. + + Args: + content: Content to import + format: Import format ('json' or 'yaml') + merge_strategy: How to handle existing pack + + Returns: + Imported pack dict or None + """ + pack = self.repository.import_pack(content, format, merge_strategy) + if not pack: + return None + return pack.to_dict() + + def export_pack_to_file( + self, + pack_id: str, + file_path: str, + format: str = 'json' + ) -> bool: + """ + Export a pack to a file. + + Args: + pack_id: Pack ID + file_path: Output file path + format: Export format + + Returns: + True if successful + """ + content = self.export_pack(pack_id, format) + if not content: + return False + + try: + with open(file_path, 'w', encoding='utf-8') as f: + f.write(content) + logger.info(f"Exported pack {pack_id} to {file_path}") + return True + except Exception as e: + logger.error(f"Error exporting pack to file: {e}") + return False + + def import_pack_from_file( + self, + file_path: str, + format: Optional[str] = None, + merge_strategy: str = 'create_new' + ) -> Optional[Dict[str, Any]]: + """ + Import a pack from a file. + + Args: + file_path: Input file path + format: Import format (auto-detected if None) + merge_strategy: How to handle existing pack + + Returns: + Imported pack dict or None + """ + try: + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + + # Auto-detect format + if format is None: + if file_path.endswith('.yaml') or file_path.endswith('.yml'): + format = 'yaml' + else: + format = 'json' + + result = self.import_pack(content, format, merge_strategy) + if result: + logger.info(f"Imported pack from {file_path}") + return result + + except Exception as e: + logger.error(f"Error importing pack from file: {e}") + return None + + # ========== Version Operations ========== + + def create_version( + self, + pack_id: str, + description: str = "" + ) -> int: + """ + Create a version snapshot of a pack. + + Args: + pack_id: Pack ID + description: Version description + + Returns: + Version number (-1 on error) + """ + return self.repository.create_version(pack_id, description) + + def list_versions(self, pack_id: str) -> List[Dict[str, Any]]: + """ + List all versions of a pack. + + Args: + pack_id: Pack ID + + Returns: + List of version info dicts + """ + return self.repository.list_versions(pack_id) + + def rollback_pack(self, pack_id: str, version: int) -> bool: + """ + Rollback a pack to a previous version. + + Args: + pack_id: Pack ID + version: Version number + + Returns: + True if successful + """ + return self.repository.rollback_pack(pack_id, version) + + # ========== Statistics and Analysis ========== + + def get_pack_statistics(self, pack_id: str) -> Optional[Dict[str, Any]]: + """ + Get detailed statistics for a pack. + + Args: + pack_id: Pack ID + + Returns: + Statistics dict or None + """ + pack = self.repository.get_pack(pack_id) + if not pack: + return None + + return pack.get_statistics() + + def get_global_statistics(self) -> Dict[str, Any]: + """ + Get global statistics across all packs. 
+ + Returns: + Global statistics dict + """ + packs = self.repository.list_packs() + + total_corrections = sum(p.correction_count for p in packs) + total_applications = sum(p.total_applications for p in packs) + + # Aggregate success rate + total_success = sum( + sum(c.success_count for c in p.corrections.values()) + for p in packs + ) + + return { + 'total_packs': len(packs), + 'total_corrections': total_corrections, + 'total_applications': total_applications, + 'overall_success_rate': total_success / total_applications if total_applications > 0 else 0.0, + 'packs_by_category': self._group_packs_by_category(packs), + 'top_corrections': self._get_top_corrections(packs, 10) + } + + def _group_packs_by_category( + self, + packs: List[CorrectionPack] + ) -> Dict[str, int]: + """Group packs by category.""" + by_category = {} + for pack in packs: + cat = pack.metadata.category + by_category[cat] = by_category.get(cat, 0) + 1 + return by_category + + def _get_top_corrections( + self, + packs: List[CorrectionPack], + limit: int = 10 + ) -> List[Dict[str, Any]]: + """Get top corrections by confidence score.""" + all_corrections = [] + for pack in packs: + for correction in pack.corrections.values(): + all_corrections.append({ + 'pack_id': pack.id, + 'pack_name': pack.name, + 'correction': correction + }) + + # Sort by confidence + all_corrections.sort(key=lambda x: x['correction'].confidence_score, reverse=True) + + # Return top N + return [ + { + 'pack_id': item['pack_id'], + 'pack_name': item['pack_name'], + 'correction_id': item['correction'].id, + 'action_type': item['correction'].key.action_type, + 'confidence': item['correction'].confidence_score, + 'success_rate': item['correction'].success_rate, + 'applications': item['correction'].success_count + item['correction'].failure_count + } + for item in all_corrections[:limit] + ] diff --git a/core/corrections/correction_repository.py b/core/corrections/correction_repository.py new file mode 100644 index 000000000..24974b3c8 --- /dev/null +++ b/core/corrections/correction_repository.py @@ -0,0 +1,643 @@ +""" +Correction Repository - JSON storage for corrections and packs. + +Provides atomic file operations for safe persistence. +""" + +import json +import logging +import os +import shutil +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional, Any + +from .models import ( + Correction, + CorrectionPack, + CorrectionKey, + generate_pack_id +) + +logger = logging.getLogger(__name__) + + +class CorrectionRepository: + """ + Repository for storing and retrieving correction packs. + + Uses atomic JSON writes (temp file + rename) for safety. + Maintains an index by key hash for fast lookups. + """ + + def __init__(self, storage_path: str = "data/correction_packs"): + """ + Initialize the repository. 
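+
+        On-disk layout under ``storage_path``::
+
+            packs.json                     # all packs, written atomically
+            corrections/                   # reserved for per-correction files
+            versions/<pack_id>/v<N>.json   # version snapshots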
+ + Args: + storage_path: Base path for storing correction data + """ + self.storage_path = Path(storage_path) + self.packs_file = self.storage_path / "packs.json" + self.corrections_dir = self.storage_path / "corrections" + self.versions_dir = self.storage_path / "versions" + + # Ensure directories exist + self._ensure_directories() + + # In-memory cache + self._packs: Dict[str, CorrectionPack] = {} + self._key_index: Dict[str, List[str]] = {} # key_hash -> [correction_ids] + + # Load existing data + self._load_packs() + self._rebuild_index() + + def _ensure_directories(self): + """Ensure all required directories exist.""" + self.storage_path.mkdir(parents=True, exist_ok=True) + self.corrections_dir.mkdir(parents=True, exist_ok=True) + self.versions_dir.mkdir(parents=True, exist_ok=True) + + # ========== Pack Operations ========== + + def list_packs(self) -> List[CorrectionPack]: + """Get all correction packs.""" + return list(self._packs.values()) + + def get_pack(self, pack_id: str) -> Optional[CorrectionPack]: + """Get a pack by ID.""" + return self._packs.get(pack_id) + + def create_pack(self, name: str, description: str = "", + metadata: Optional[Dict[str, Any]] = None) -> CorrectionPack: + """ + Create a new correction pack. + + Args: + name: Pack name + description: Pack description + metadata: Optional metadata dict + + Returns: + Created pack + """ + pack_id = generate_pack_id() + pack = CorrectionPack( + id=pack_id, + name=name, + description=description + ) + + if metadata: + pack.metadata.version = metadata.get('version', '1.0.0') + pack.metadata.author = metadata.get('author', '') + pack.metadata.category = metadata.get('category', 'general') + pack.metadata.tags = metadata.get('tags', []) + + self._packs[pack_id] = pack + self._save_packs() + + logger.info(f"Created correction pack: {pack_id} ({name})") + return pack + + def update_pack(self, pack_id: str, updates: Dict[str, Any]) -> Optional[CorrectionPack]: + """ + Update a pack's properties. + + Args: + pack_id: Pack ID + updates: Dict of fields to update + + Returns: + Updated pack or None if not found + """ + pack = self._packs.get(pack_id) + if not pack: + return None + + # Update allowed fields + if 'name' in updates: + pack.name = updates['name'] + if 'description' in updates: + pack.description = updates['description'] + if 'metadata' in updates: + metadata = updates['metadata'] + if 'version' in metadata: + pack.metadata.version = metadata['version'] + if 'author' in metadata: + pack.metadata.author = metadata['author'] + if 'category' in metadata: + pack.metadata.category = metadata['category'] + if 'tags' in metadata: + pack.metadata.tags = metadata['tags'] + + pack.updated_at = datetime.now() + self._save_packs() + + logger.info(f"Updated correction pack: {pack_id}") + return pack + + def delete_pack(self, pack_id: str) -> bool: + """ + Delete a pack and all its corrections. 
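+
+        Note: the pack's version history under ``versions/<pack_id>`` is
+        removed as well.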
+ + Args: + pack_id: Pack ID + + Returns: + True if deleted, False if not found + """ + if pack_id not in self._packs: + return False + + pack = self._packs[pack_id] + + # Remove from index + for correction in pack.corrections.values(): + key_hash = correction.key.to_hash() + if key_hash in self._key_index: + if correction.id in self._key_index[key_hash]: + self._key_index[key_hash].remove(correction.id) + + del self._packs[pack_id] + self._save_packs() + + # Remove version history + version_dir = self.versions_dir / pack_id + if version_dir.exists(): + shutil.rmtree(version_dir) + + logger.info(f"Deleted correction pack: {pack_id}") + return True + + # ========== Correction Operations ========== + + def add_correction(self, pack_id: str, correction: Correction) -> bool: + """ + Add a correction to a pack. + + Args: + pack_id: Pack ID + correction: Correction to add + + Returns: + True if added, False if pack not found or correction exists + """ + pack = self._packs.get(pack_id) + if not pack: + return False + + if not pack.add_correction(correction): + return False + + # Update index + key_hash = correction.key.to_hash() + if key_hash not in self._key_index: + self._key_index[key_hash] = [] + if correction.id not in self._key_index[key_hash]: + self._key_index[key_hash].append(correction.id) + + self._save_packs() + logger.info(f"Added correction {correction.id} to pack {pack_id}") + return True + + def remove_correction(self, pack_id: str, correction_id: str) -> bool: + """ + Remove a correction from a pack. + + Args: + pack_id: Pack ID + correction_id: Correction ID + + Returns: + True if removed, False if not found + """ + pack = self._packs.get(pack_id) + if not pack: + return False + + correction = pack.get_correction(correction_id) + if not correction: + return False + + # Update index + key_hash = correction.key.to_hash() + if key_hash in self._key_index: + if correction_id in self._key_index[key_hash]: + self._key_index[key_hash].remove(correction_id) + + pack.remove_correction(correction_id) + self._save_packs() + + logger.info(f"Removed correction {correction_id} from pack {pack_id}") + return True + + def update_correction(self, pack_id: str, correction_id: str, + updates: Dict[str, Any]) -> Optional[Correction]: + """ + Update a correction's properties. + + Args: + pack_id: Pack ID + correction_id: Correction ID + updates: Dict of fields to update + + Returns: + Updated correction or None if not found + """ + pack = self._packs.get(pack_id) + if not pack: + return None + + correction = pack.get_correction(correction_id) + if not correction: + return None + + # Update allowed fields + if 'corrected_target' in updates: + correction.corrected_target = updates['corrected_target'] + if 'corrected_params' in updates: + correction.corrected_params = updates['corrected_params'] + if 'correction_description' in updates: + correction.correction_description = updates['correction_description'] + if 'status' in updates: + from .models import CorrectionStatus + correction.status = CorrectionStatus(updates['status']) + if 'tags' in updates: + correction.tags = updates['tags'] + + correction.updated_at = datetime.now() + pack.updated_at = datetime.now() + self._save_packs() + + logger.info(f"Updated correction {correction_id} in pack {pack_id}") + return correction + + def record_application(self, pack_id: str, correction_id: str, success: bool): + """ + Record an application of a correction. 
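+
+        Note: this bumps the correction's success/failure counters, which
+        drive ``success_rate`` and ``confidence_score``. For example, 8
+        successes and 2 failures give a success rate of 0.8 and, with the
+        usage factor ``min(1.0, 10 / 100) = 0.1``, a confidence of
+        ``0.8 * (0.5 + 0.5 * 0.1) = 0.44``.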
+ + Args: + pack_id: Pack ID + correction_id: Correction ID + success: Whether the application was successful + """ + pack = self._packs.get(pack_id) + if not pack: + return + + correction = pack.get_correction(correction_id) + if not correction: + return + + correction.record_application(success) + pack.updated_at = datetime.now() + self._save_packs() + + # ========== Search Operations ========== + + def find_by_key_hash(self, key_hash: str) -> List[tuple]: + """ + Find all corrections matching a key hash. + + Args: + key_hash: MD5 hash of correction key + + Returns: + List of (pack_id, correction) tuples + """ + results = [] + + for pack_id, pack in self._packs.items(): + for correction in pack.find_by_key_hash(key_hash): + results.append((pack_id, correction)) + + # Sort by confidence score + results.sort(key=lambda x: x[1].confidence_score, reverse=True) + return results + + def find_by_context(self, action_type: str, element_type: str, + failure_context: str) -> List[tuple]: + """ + Find corrections matching a failure context. + + Args: + action_type: Type of action that failed + element_type: Type of element involved + failure_context: Description of failure + + Returns: + List of (pack_id, correction) tuples + """ + key = CorrectionKey(action_type, element_type, failure_context) + return self.find_by_key_hash(key.to_hash()) + + def search_corrections(self, query: str, pack_id: Optional[str] = None) -> List[tuple]: + """ + Search corrections by text query. + + Args: + query: Search query (searches in description, tags, etc.) + pack_id: Optional pack ID to search in + + Returns: + List of (pack_id, correction) tuples + """ + results = [] + query_lower = query.lower() + + packs_to_search = [self._packs[pack_id]] if pack_id and pack_id in self._packs else self._packs.values() + + for pack in packs_to_search: + for correction in pack.corrections.values(): + # Search in description + if query_lower in correction.correction_description.lower(): + results.append((pack.id, correction)) + continue + + # Search in tags + if any(query_lower in tag.lower() for tag in correction.tags): + results.append((pack.id, correction)) + continue + + # Search in key components + if (query_lower in correction.key.action_type.lower() or + query_lower in correction.key.element_type.lower() or + query_lower in correction.key.failure_context.lower()): + results.append((pack.id, correction)) + + return results + + # ========== Version Operations ========== + + def create_version(self, pack_id: str, description: str = "") -> int: + """ + Create a version snapshot of a pack. + + Args: + pack_id: Pack ID + description: Version description + + Returns: + Version number + """ + pack = self._packs.get(pack_id) + if not pack: + return -1 + + # Create version directory + pack_versions = self.versions_dir / pack_id + pack_versions.mkdir(parents=True, exist_ok=True) + + # Increment version + version = pack.current_version + pack.current_version += 1 + + # Save snapshot + version_file = pack_versions / f"v{version}.json" + version_data = { + 'version': version, + 'description': description, + 'created_at': datetime.now().isoformat(), + 'pack_snapshot': pack.to_dict() + } + + self._atomic_write(version_file, version_data) + self._save_packs() + + logger.info(f"Created version {version} for pack {pack_id}") + return version + + def list_versions(self, pack_id: str) -> List[Dict[str, Any]]: + """ + List all versions of a pack. 
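+
+        Each entry has this shape (values illustrative)::
+
+            {'version': 1, 'description': 'before cleanup',
+             'created_at': '2024-01-01T12:00:00', 'correction_count': 12}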
+ + Args: + pack_id: Pack ID + + Returns: + List of version info dicts + """ + versions = [] + pack_versions = self.versions_dir / pack_id + + if not pack_versions.exists(): + return versions + + for version_file in sorted(pack_versions.glob("v*.json")): + try: + with open(version_file, 'r', encoding='utf-8') as f: + data = json.load(f) + versions.append({ + 'version': data.get('version'), + 'description': data.get('description', ''), + 'created_at': data.get('created_at'), + 'correction_count': len(data.get('pack_snapshot', {}).get('corrections', {})) + }) + except Exception as e: + logger.warning(f"Error reading version file {version_file}: {e}") + + return versions + + def rollback_pack(self, pack_id: str, version: int) -> bool: + """ + Rollback a pack to a previous version. + + Args: + pack_id: Pack ID + version: Version number to rollback to + + Returns: + True if successful, False otherwise + """ + version_file = self.versions_dir / pack_id / f"v{version}.json" + + if not version_file.exists(): + logger.warning(f"Version {version} not found for pack {pack_id}") + return False + + try: + with open(version_file, 'r', encoding='utf-8') as f: + data = json.load(f) + + # Restore pack from snapshot + snapshot = data.get('pack_snapshot', {}) + pack = CorrectionPack.from_dict(snapshot) + + # Update version number to continue from current + if pack_id in self._packs: + pack.current_version = self._packs[pack_id].current_version + 1 + + self._packs[pack_id] = pack + self._rebuild_index() + self._save_packs() + + logger.info(f"Rolled back pack {pack_id} to version {version}") + return True + + except Exception as e: + logger.error(f"Error rolling back pack {pack_id} to version {version}: {e}") + return False + + # ========== Persistence ========== + + def _load_packs(self): + """Load all packs from storage.""" + if not self.packs_file.exists(): + return + + try: + with open(self.packs_file, 'r', encoding='utf-8') as f: + data = json.load(f) + + for pack_id, pack_data in data.items(): + try: + self._packs[pack_id] = CorrectionPack.from_dict(pack_data) + except Exception as e: + logger.warning(f"Error loading pack {pack_id}: {e}") + + logger.info(f"Loaded {len(self._packs)} correction packs") + + except Exception as e: + logger.error(f"Error loading packs file: {e}") + + def _save_packs(self): + """Save all packs to storage using atomic write.""" + try: + data = { + pack_id: pack.to_dict() + for pack_id, pack in self._packs.items() + } + self._atomic_write(self.packs_file, data) + except Exception as e: + logger.error(f"Error saving packs: {e}") + + def _atomic_write(self, file_path: Path, data: Dict[str, Any]): + """ + Write data to file atomically (temp file + rename). 
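+
+        Note: ``Path.replace`` delegates to ``os.replace``, which is atomic
+        on POSIX (and for same-volume renames on Windows), so readers never
+        observe a partially written file.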
+ + Args: + file_path: Target file path + data: Data to write + """ + temp_file = file_path.with_suffix('.tmp') + + try: + with open(temp_file, 'w', encoding='utf-8') as f: + json.dump(data, f, indent=2, ensure_ascii=False) + + # Atomic rename + temp_file.replace(file_path) + + except Exception as e: + # Clean up temp file on error + if temp_file.exists(): + temp_file.unlink() + raise e + + def _rebuild_index(self): + """Rebuild the key hash index from all packs.""" + self._key_index.clear() + + for pack in self._packs.values(): + for correction in pack.corrections.values(): + key_hash = correction.key.to_hash() + if key_hash not in self._key_index: + self._key_index[key_hash] = [] + if correction.id not in self._key_index[key_hash]: + self._key_index[key_hash].append(correction.id) + + # ========== Export/Import ========== + + def export_pack(self, pack_id: str, format: str = 'json') -> Optional[str]: + """ + Export a pack to JSON or YAML format. + + Args: + pack_id: Pack ID + format: Export format ('json' or 'yaml') + + Returns: + Exported content as string, or None if pack not found + """ + pack = self._packs.get(pack_id) + if not pack: + return None + + export_data = { + 'export_format': 'correction_pack', + 'export_version': '1.0', + 'exported_at': datetime.now().isoformat(), + 'pack': pack.to_dict() + } + + if format == 'yaml': + try: + import yaml + return yaml.dump(export_data, default_flow_style=False, allow_unicode=True) + except ImportError: + logger.warning("PyYAML not installed, falling back to JSON") + return json.dumps(export_data, indent=2, ensure_ascii=False) + else: + return json.dumps(export_data, indent=2, ensure_ascii=False) + + def import_pack(self, content: str, format: str = 'json', + merge_strategy: str = 'create_new') -> Optional[CorrectionPack]: + """ + Import a pack from JSON or YAML content. 
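+
+        Example (illustrative; ``content`` is a string produced by
+        ``export_pack``)::
+
+            >>> imported = repository.import_pack(content,
+            ...                                   merge_strategy='merge')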
+ + Args: + content: Content to import + format: Import format ('json' or 'yaml') + merge_strategy: How to handle existing pack: + - 'create_new': Always create new pack with new ID + - 'merge': Merge into existing pack if same ID + - 'replace': Replace existing pack if same ID + + Returns: + Imported pack, or None on error + """ + try: + if format == 'yaml': + try: + import yaml + data = yaml.safe_load(content) + except ImportError: + logger.error("PyYAML not installed") + return None + else: + data = json.loads(content) + + # Validate format + if data.get('export_format') != 'correction_pack': + logger.warning("Invalid export format") + return None + + pack_data = data.get('pack', {}) + imported_pack = CorrectionPack.from_dict(pack_data) + + # Handle merge strategy + if merge_strategy == 'create_new': + imported_pack.id = generate_pack_id() + self._packs[imported_pack.id] = imported_pack + elif merge_strategy == 'merge' and imported_pack.id in self._packs: + existing_pack = self._packs[imported_pack.id] + existing_pack.merge(imported_pack) + imported_pack = existing_pack + elif merge_strategy == 'replace': + self._packs[imported_pack.id] = imported_pack + else: + # Default: create new + imported_pack.id = generate_pack_id() + self._packs[imported_pack.id] = imported_pack + + self._rebuild_index() + self._save_packs() + + logger.info(f"Imported pack: {imported_pack.id} ({imported_pack.name})") + return imported_pack + + except Exception as e: + logger.error(f"Error importing pack: {e}") + return None diff --git a/core/corrections/models.py b/core/corrections/models.py new file mode 100644 index 000000000..bb790038a --- /dev/null +++ b/core/corrections/models.py @@ -0,0 +1,480 @@ +""" +Data models for the Correction Packs system. + +Defines CorrectionKey, Correction, CorrectionPack and related types. +""" + +import hashlib +import uuid +from dataclasses import dataclass, field, asdict +from datetime import datetime +from enum import Enum +from typing import Dict, List, Any, Optional + + +class CorrectionType(Enum): + """Types of corrections that can be applied.""" + TARGET_CHANGE = "target_change" # Changed target selector + PARAMETER_CHANGE = "parameter_change" # Changed action parameters + ACTION_REPLACE = "action_replace" # Replaced action type + WAIT_ADDED = "wait_added" # Added wait before action + RETRY_CONFIG = "retry_config" # Changed retry configuration + FALLBACK_ADDED = "fallback_added" # Added fallback action + COORDINATES_ADJUST = "coordinates_adjust" # Adjusted click coordinates + TIMING_ADJUST = "timing_adjust" # Adjusted timing/delay + OTHER = "other" + + +class CorrectionStatus(Enum): + """Status of a correction.""" + ACTIVE = "active" # Correction is active and can be applied + DEPRECATED = "deprecated" # Correction is deprecated + DISABLED = "disabled" # Correction is manually disabled + + +@dataclass +class CorrectionKey: + """ + Unique key to identify similar corrections across workflows. + + Uses action_type + element_type + failure_context to generate MD5 hash. + Similar to LearningRepository pattern key generation. 
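+
+    Keys built from the same components hash identically, which is what
+    makes cross-session grouping possible::
+
+        >>> a = CorrectionKey('click', 'button', 'element not found')
+        >>> b = CorrectionKey('click', 'button', 'element not found')
+        >>> a.to_hash() == b.to_hash()
+        True
+        >>> len(a.to_hash())
+        16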
+ """ + action_type: str + element_type: str + failure_context: str + + def to_hash(self) -> str: + """Generate MD5 hash of the correction key (16 chars).""" + key_string = f"{self.action_type}|{self.element_type}|{self.failure_context}" + return hashlib.md5(key_string.encode()).hexdigest()[:16] + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary.""" + return { + 'action_type': self.action_type, + 'element_type': self.element_type, + 'failure_context': self.failure_context, + 'hash': self.to_hash() + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionKey': + """Create from dictionary.""" + return cls( + action_type=data.get('action_type', 'unknown'), + element_type=data.get('element_type', 'unknown'), + failure_context=data.get('failure_context', '') + ) + + +@dataclass +class CorrectionSource: + """Source information for a correction.""" + session_id: str + workflow_id: Optional[str] = None + node_id: Optional[str] = None + timestamp: datetime = field(default_factory=datetime.now) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary.""" + return { + 'session_id': self.session_id, + 'workflow_id': self.workflow_id, + 'node_id': self.node_id, + 'timestamp': self.timestamp.isoformat() + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionSource': + """Create from dictionary.""" + timestamp = data.get('timestamp') + if isinstance(timestamp, str): + timestamp = datetime.fromisoformat(timestamp) + elif timestamp is None: + timestamp = datetime.now() + + return cls( + session_id=data.get('session_id', ''), + workflow_id=data.get('workflow_id'), + node_id=data.get('node_id'), + timestamp=timestamp + ) + + +@dataclass +class Correction: + """ + Individual correction record. + + Contains the original context, the applied correction, source information, + and statistics about success/failure. 
+ """ + id: str + key: CorrectionKey + correction_type: CorrectionType + + # Original context + original_target: Optional[Dict[str, Any]] = None + original_params: Optional[Dict[str, Any]] = None + + # Correction applied + corrected_target: Optional[Dict[str, Any]] = None + corrected_params: Optional[Dict[str, Any]] = None + correction_description: str = "" + + # Source information + source: Optional[CorrectionSource] = None + + # Statistics + success_count: int = 0 + failure_count: int = 0 + last_applied: Optional[datetime] = None + + # Status + status: CorrectionStatus = CorrectionStatus.ACTIVE + + # Metadata + created_at: datetime = field(default_factory=datetime.now) + updated_at: datetime = field(default_factory=datetime.now) + tags: List[str] = field(default_factory=list) + + @property + def success_rate(self) -> float: + """Calculate success rate.""" + total = self.success_count + self.failure_count + if total == 0: + return 0.0 + return self.success_count / total + + @property + def confidence_score(self) -> float: + """Calculate confidence score based on success rate and usage count.""" + total = self.success_count + self.failure_count + if total == 0: + return 0.0 + # Confidence increases with more data points (up to 100 uses) + usage_factor = min(1.0, total / 100) + return self.success_rate * (0.5 + 0.5 * usage_factor) + + def record_application(self, success: bool): + """Record an application of this correction.""" + if success: + self.success_count += 1 + else: + self.failure_count += 1 + self.last_applied = datetime.now() + self.updated_at = datetime.now() + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary.""" + return { + 'id': self.id, + 'key': self.key.to_dict(), + 'correction_type': self.correction_type.value, + 'original_target': self.original_target, + 'original_params': self.original_params, + 'corrected_target': self.corrected_target, + 'corrected_params': self.corrected_params, + 'correction_description': self.correction_description, + 'source': self.source.to_dict() if self.source else None, + 'success_count': self.success_count, + 'failure_count': self.failure_count, + 'success_rate': self.success_rate, + 'confidence_score': self.confidence_score, + 'last_applied': self.last_applied.isoformat() if self.last_applied else None, + 'status': self.status.value, + 'created_at': self.created_at.isoformat(), + 'updated_at': self.updated_at.isoformat(), + 'tags': self.tags + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'Correction': + """Create from dictionary.""" + # Parse datetime fields + created_at = data.get('created_at') + if isinstance(created_at, str): + created_at = datetime.fromisoformat(created_at) + elif created_at is None: + created_at = datetime.now() + + updated_at = data.get('updated_at') + if isinstance(updated_at, str): + updated_at = datetime.fromisoformat(updated_at) + elif updated_at is None: + updated_at = datetime.now() + + last_applied = data.get('last_applied') + if isinstance(last_applied, str): + last_applied = datetime.fromisoformat(last_applied) + + # Parse nested objects + key_data = data.get('key', {}) + key = CorrectionKey.from_dict(key_data) + + source_data = data.get('source') + source = CorrectionSource.from_dict(source_data) if source_data else None + + return cls( + id=data.get('id', str(uuid.uuid4())), + key=key, + correction_type=CorrectionType(data.get('correction_type', 'other')), + original_target=data.get('original_target'), + original_params=data.get('original_params'), + 
corrected_target=data.get('corrected_target'), + corrected_params=data.get('corrected_params'), + correction_description=data.get('correction_description', ''), + source=source, + success_count=data.get('success_count', 0), + failure_count=data.get('failure_count', 0), + last_applied=last_applied, + status=CorrectionStatus(data.get('status', 'active')), + created_at=created_at, + updated_at=updated_at, + tags=data.get('tags', []) + ) + + +@dataclass +class CorrectionPackMetadata: + """Metadata for a correction pack.""" + version: str = "1.0.0" + author: str = "" + category: str = "general" + tags: List[str] = field(default_factory=list) + description: str = "" + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary.""" + return { + 'version': self.version, + 'author': self.author, + 'category': self.category, + 'tags': self.tags, + 'description': self.description + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionPackMetadata': + """Create from dictionary.""" + return cls( + version=data.get('version', '1.0.0'), + author=data.get('author', ''), + category=data.get('category', 'general'), + tags=data.get('tags', []), + description=data.get('description', '') + ) + + +@dataclass +class CorrectionPack: + """ + A pack of corrections that can be exported, imported, and shared. + + Contains multiple corrections with metadata and statistics. + """ + id: str + name: str + description: str = "" + + # Corrections in this pack + corrections: Dict[str, Correction] = field(default_factory=dict) + + # Metadata + metadata: CorrectionPackMetadata = field(default_factory=CorrectionPackMetadata) + + # Timestamps + created_at: datetime = field(default_factory=datetime.now) + updated_at: datetime = field(default_factory=datetime.now) + + # Version history + current_version: int = 1 + + @property + def correction_count(self) -> int: + """Number of corrections in the pack.""" + return len(self.corrections) + + @property + def total_applications(self) -> int: + """Total number of times corrections were applied.""" + return sum(c.success_count + c.failure_count for c in self.corrections.values()) + + @property + def overall_success_rate(self) -> float: + """Overall success rate across all corrections.""" + total_success = sum(c.success_count for c in self.corrections.values()) + total_failure = sum(c.failure_count for c in self.corrections.values()) + total = total_success + total_failure + if total == 0: + return 0.0 + return total_success / total + + @property + def active_corrections(self) -> List[Correction]: + """Get only active corrections.""" + return [c for c in self.corrections.values() if c.status == CorrectionStatus.ACTIVE] + + def add_correction(self, correction: Correction) -> bool: + """ + Add a correction to the pack. + + Returns True if added, False if already exists. + """ + if correction.id in self.corrections: + return False + self.corrections[correction.id] = correction + self.updated_at = datetime.now() + return True + + def remove_correction(self, correction_id: str) -> bool: + """ + Remove a correction from the pack. + + Returns True if removed, False if not found. 
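+
+        Example (assuming a correction 'corr_123' was added earlier):
+
+            pack.remove_correction('corr_123')   # True
+            pack.remove_correction('corr_123')   # False (already removed)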
+ """ + if correction_id not in self.corrections: + return False + del self.corrections[correction_id] + self.updated_at = datetime.now() + return True + + def get_correction(self, correction_id: str) -> Optional[Correction]: + """Get a correction by ID.""" + return self.corrections.get(correction_id) + + def find_by_key_hash(self, key_hash: str) -> List[Correction]: + """Find corrections matching a key hash.""" + return [ + c for c in self.corrections.values() + if c.key.to_hash() == key_hash and c.status == CorrectionStatus.ACTIVE + ] + + def merge(self, other: 'CorrectionPack', conflict_strategy: str = 'keep_higher_confidence') -> int: + """ + Merge another pack into this one. + + Args: + other: Pack to merge + conflict_strategy: How to handle conflicts: + - 'keep_higher_confidence': Keep correction with higher confidence + - 'keep_newer': Keep the most recently updated correction + - 'keep_existing': Keep existing, ignore new + - 'replace': Always replace with new + + Returns: + Number of corrections added/updated + """ + merged_count = 0 + + for correction_id, correction in other.corrections.items(): + if correction_id in self.corrections: + existing = self.corrections[correction_id] + + should_replace = False + if conflict_strategy == 'keep_higher_confidence': + should_replace = correction.confidence_score > existing.confidence_score + elif conflict_strategy == 'keep_newer': + should_replace = correction.updated_at > existing.updated_at + elif conflict_strategy == 'replace': + should_replace = True + # 'keep_existing' doesn't replace + + if should_replace: + self.corrections[correction_id] = correction + merged_count += 1 + else: + self.corrections[correction_id] = correction + merged_count += 1 + + if merged_count > 0: + self.updated_at = datetime.now() + + return merged_count + + def get_statistics(self) -> Dict[str, Any]: + """Get aggregate statistics for the pack.""" + corrections = list(self.corrections.values()) + active = [c for c in corrections if c.status == CorrectionStatus.ACTIVE] + + # Group by correction type + by_type = {} + for c in corrections: + type_name = c.correction_type.value + if type_name not in by_type: + by_type[type_name] = 0 + by_type[type_name] += 1 + + return { + 'total_corrections': len(corrections), + 'active_corrections': len(active), + 'deprecated_corrections': len([c for c in corrections if c.status == CorrectionStatus.DEPRECATED]), + 'disabled_corrections': len([c for c in corrections if c.status == CorrectionStatus.DISABLED]), + 'total_applications': self.total_applications, + 'overall_success_rate': self.overall_success_rate, + 'corrections_by_type': by_type, + 'average_confidence': sum(c.confidence_score for c in active) / len(active) if active else 0.0 + } + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary.""" + return { + 'id': self.id, + 'name': self.name, + 'description': self.description, + 'corrections': {cid: c.to_dict() for cid, c in self.corrections.items()}, + 'metadata': self.metadata.to_dict(), + 'created_at': self.created_at.isoformat(), + 'updated_at': self.updated_at.isoformat(), + 'current_version': self.current_version, + 'statistics': self.get_statistics() + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionPack': + """Create from dictionary.""" + # Parse datetime fields + created_at = data.get('created_at') + if isinstance(created_at, str): + created_at = datetime.fromisoformat(created_at) + elif created_at is None: + created_at = datetime.now() + + updated_at = data.get('updated_at') 
+ if isinstance(updated_at, str): + updated_at = datetime.fromisoformat(updated_at) + elif updated_at is None: + updated_at = datetime.now() + + # Parse corrections + corrections_data = data.get('corrections', {}) + corrections = { + cid: Correction.from_dict(cdata) + for cid, cdata in corrections_data.items() + } + + # Parse metadata + metadata_data = data.get('metadata', {}) + metadata = CorrectionPackMetadata.from_dict(metadata_data) + + return cls( + id=data.get('id', str(uuid.uuid4())), + name=data.get('name', 'Unnamed Pack'), + description=data.get('description', ''), + corrections=corrections, + metadata=metadata, + created_at=created_at, + updated_at=updated_at, + current_version=data.get('current_version', 1) + ) + + +def generate_correction_id() -> str: + """Generate a unique correction ID.""" + return f"corr_{uuid.uuid4().hex[:12]}" + + +def generate_pack_id() -> str: + """Generate a unique pack ID.""" + return f"pack_{uuid.uuid4().hex[:12]}" diff --git a/core/healing/healing_engine.py b/core/healing/healing_engine.py new file mode 100644 index 000000000..03c81880a --- /dev/null +++ b/core/healing/healing_engine.py @@ -0,0 +1,443 @@ +"""Main self-healing engine for workflow recovery.""" + +import time +import logging +from typing import List, Optional, Dict, Any, Tuple +from pathlib import Path +from .models import RecoveryContext, RecoveryResult, RecoverySuggestion +from .learning_repository import LearningRepository +from .confidence_scorer import ConfidenceScorer +from .strategies import ( + RecoveryStrategy, + SemanticVariantStrategy, + SpatialFallbackStrategy, + TimingAdaptationStrategy, + FormatTransformationStrategy +) + +logger = logging.getLogger(__name__) + + +class SelfHealingEngine: + """Main engine for self-healing workflows.""" + + def __init__( + self, + learning_repo: Optional[LearningRepository] = None, + storage_path: Optional[Path] = None, + correction_packs_enabled: bool = True + ): + """ + Initialize the self-healing engine. 
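+
+        Construction sketch (the storage path mirrors the default used below;
+        the context object is assumed to be a RecoveryContext):
+
+            engine = SelfHealingEngine(storage_path=Path('data/healing'))
+            result = engine.attempt_recovery(context)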
+ + Args: + learning_repo: Learning repository instance + storage_path: Path for storing learned patterns + correction_packs_enabled: Whether to consult correction packs + """ + # Initialize learning repository + if learning_repo: + self.learning_repo = learning_repo + else: + storage_path = storage_path or Path('data/healing') + self.learning_repo = LearningRepository(storage_path) + + # Initialize confidence scorer + self.confidence_scorer = ConfidenceScorer() + + # Initialize recovery strategies + self.recovery_strategies: List[RecoveryStrategy] = [ + SemanticVariantStrategy(), + SpatialFallbackStrategy(), + TimingAdaptationStrategy(), + FormatTransformationStrategy() + ] + + # Configuration + self.max_recovery_time = 30.0 # seconds + self.parallel_execution = True + + # Correction packs integration + self.correction_packs_enabled = correction_packs_enabled + self._correction_pack_service = None + + @property + def correction_pack_service(self): + """Lazy-load correction pack service.""" + if self._correction_pack_service is None and self.correction_packs_enabled: + try: + from core.corrections import CorrectionPackService + self._correction_pack_service = CorrectionPackService() + logger.info("CorrectionPackService initialized for healing engine") + except ImportError as e: + logger.warning(f"CorrectionPackService not available: {e}") + self.correction_packs_enabled = False + return self._correction_pack_service + + def attempt_recovery(self, context: RecoveryContext) -> RecoveryResult: + """ + Attempt to recover from a workflow failure. + + Args: + context: Recovery context with failure information + + Returns: + RecoveryResult with outcome + """ + start_time = time.time() + + # Check if we've exceeded max attempts + if context.attempt_count >= context.max_attempts: + return RecoveryResult( + success=False, + strategy_used='none', + error_message=f'Max recovery attempts ({context.max_attempts}) exceeded' + ) + + # Try correction packs first (if enabled) + if self.correction_packs_enabled and self.correction_pack_service: + pack_result = self._try_correction_packs(context) + if pack_result and pack_result.success: + return pack_result + + # Get learned patterns for this context + learned_patterns = self.learning_repo.get_matching_patterns(context) + + # Get prioritized strategies + strategies = self._prioritize_strategies(context, learned_patterns) + + # Try each strategy + for strategy in strategies: + # Check time limit + elapsed = time.time() - start_time + if elapsed >= self.max_recovery_time: + break + + # Skip if strategy can't handle this failure + if not strategy.can_handle(context): + continue + + # Attempt recovery + result = strategy.attempt_recovery(context) + + # If successful, learn from it + if result.success: + # Calculate final confidence + historical_success = self._get_historical_success_rate( + strategy.name, context, learned_patterns + ) + result.confidence_score = self.confidence_scorer.calculate_recovery_confidence( + result.strategy_used, + context, + historical_success + ) + + # Check if safe to proceed + if self.confidence_scorer.is_safe_to_proceed( + result.confidence_score, + context.confidence_threshold, + involves_data_modification=self._involves_data_modification(context) + ): + # Learn from success + self.learn_from_success(context, result) + return result + else: + # Confidence too low, mark as requiring user input + result.requires_user_input = True + return result + + # All strategies failed + total_time = time.time() - start_time + return 
RecoveryResult( + success=False, + strategy_used='all_failed', + execution_time=total_time, + error_message='All recovery strategies failed', + requires_user_input=True + ) + + def learn_from_success(self, context: RecoveryContext, result: RecoveryResult): + """ + Learn from successful recovery for future use. + + Args: + context: Recovery context + result: Successful recovery result + """ + self.learning_repo.store_pattern(context, result) + + def get_recovery_suggestions(self, context: RecoveryContext) -> List[RecoverySuggestion]: + """ + Get ranked recovery suggestions based on learned patterns. + + Args: + context: Recovery context + + Returns: + List of recovery suggestions sorted by confidence + """ + suggestions = [] + + # Get learned patterns + learned_patterns = self.learning_repo.get_matching_patterns(context) + + # Get suggestions from each strategy + for strategy in self.recovery_strategies: + if not strategy.can_handle(context): + continue + + # Calculate confidence based on historical success + historical_success = self._get_historical_success_rate( + strategy.name, context, learned_patterns + ) + confidence = self.confidence_scorer.calculate_recovery_confidence( + strategy.name, + context, + historical_success + ) + + suggestion = RecoverySuggestion( + strategy=strategy.name, + confidence=confidence, + description=self._get_strategy_description(strategy), + estimated_time=self._estimate_strategy_time(strategy, context) + ) + suggestions.append(suggestion) + + # Sort by confidence + suggestions.sort(key=lambda s: s.confidence, reverse=True) + + return suggestions + + def prune_learned_patterns( + self, + max_age_days: int = 90, + min_confidence: float = 0.3 + ): + """ + Prune outdated learned patterns. + + Args: + max_age_days: Maximum age for patterns + min_confidence: Minimum confidence threshold + """ + self.learning_repo.prune_outdated_patterns(max_age_days, min_confidence) + + def _prioritize_strategies( + self, + context: RecoveryContext, + learned_patterns: List + ) -> List[RecoveryStrategy]: + """Prioritize strategies based on context and learned patterns.""" + # Score each strategy + scored_strategies = [] + for strategy in self.recovery_strategies: + # Base priority from strategy + priority = strategy.get_priority(context) + + # Boost priority if we have successful patterns + historical_success = self._get_historical_success_rate( + strategy.name, context, learned_patterns + ) + if historical_success > 0: + priority *= (1.0 + historical_success) + + scored_strategies.append((priority, strategy)) + + # Sort by priority (highest first) + scored_strategies.sort(key=lambda x: x[0], reverse=True) + + return [strategy for _, strategy in scored_strategies] + + def _get_historical_success_rate( + self, + strategy_name: str, + context: RecoveryContext, + learned_patterns: List + ) -> float: + """Get historical success rate for a strategy.""" + matching_patterns = [ + p for p in learned_patterns + if p.recovery_strategy == strategy_name + ] + + if not matching_patterns: + return 0.0 + + # Return average success rate + return sum(p.success_rate for p in matching_patterns) / len(matching_patterns) + + def _involves_data_modification(self, context: RecoveryContext) -> bool: + """Check if the action involves data modification.""" + data_modification_actions = [ + 'input', 'type', 'submit', 'delete', 'update', 'save' + ] + return any(action in context.original_action.lower() + for action in data_modification_actions) + + def _get_strategy_description(self, strategy: 
RecoveryStrategy) -> str: + """Get human-readable description of strategy.""" + descriptions = { + 'SemanticVariantStrategy': 'Try semantic variants of the element (e.g., Submit → Send)', + 'SpatialFallbackStrategy': 'Search in expanded area around original position', + 'TimingAdaptationStrategy': 'Increase wait time for element to appear', + 'FormatTransformationStrategy': 'Transform input format to match validation' + } + return descriptions.get(strategy.name, strategy.name) + + def _estimate_strategy_time( + self, + strategy: RecoveryStrategy, + context: RecoveryContext + ) -> float: + """Estimate execution time for strategy.""" + # Simple estimates in seconds + estimates = { + 'SemanticVariantStrategy': 2.0, + 'SpatialFallbackStrategy': 5.0, + 'TimingAdaptationStrategy': 10.0, + 'FormatTransformationStrategy': 1.0 + } + return estimates.get(strategy.name, 3.0) + + def _try_correction_packs(self, context: RecoveryContext) -> Optional[RecoveryResult]: + """ + Try to find and apply a correction from correction packs. + + Args: + context: Recovery context + + Returns: + RecoveryResult if a correction was successfully applied, None otherwise + """ + try: + # Get element type from metadata + element_type = context.metadata.get('element_type', 'unknown') + + # Search for applicable corrections + corrections = self.correction_pack_service.find_applicable_corrections( + action_type=context.original_action, + element_type=element_type, + failure_context=context.failure_reason, + min_confidence=0.5 # Only use high-confidence corrections + ) + + if not corrections: + return None + + # Try corrections in order of confidence + for correction_info in corrections: + pack_id = correction_info['pack_id'] + correction = correction_info['correction'] + + try: + # Apply the correction + applied = self._apply_correction(context, correction) + + if applied: + # Record successful application + self.correction_pack_service.apply_correction( + pack_id=pack_id, + correction_id=correction['id'], + success=True + ) + + # Propagate to learning repository + result = RecoveryResult( + success=True, + strategy_used=f"correction_pack:{correction['correction_type']}", + confidence_score=correction['confidence_score'], + modified_action=correction.get('corrected_params'), + modified_target=correction.get('corrected_target') + ) + + # Learn from this for future reference + self.learning_repo.store_pattern(context, result) + + logger.info( + f"Applied correction from pack {pack_id}: " + f"{correction['id']} (confidence: {correction['confidence_score']:.2f})" + ) + return result + + except Exception as e: + logger.warning(f"Failed to apply correction {correction['id']}: {e}") + # Record failed application + self.correction_pack_service.apply_correction( + pack_id=pack_id, + correction_id=correction['id'], + success=False + ) + + return None + + except Exception as e: + logger.error(f"Error in correction pack lookup: {e}") + return None + + def _apply_correction( + self, + context: RecoveryContext, + correction: Dict[str, Any] + ) -> bool: + """ + Apply a correction to the current context. 
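+
+        Sketch of the expected input shape (values are hypothetical):
+
+            _apply_correction(context, {
+                'correction_type': 'wait_added',
+                'corrected_params': {'wait_time': 2.0}
+            })  # True; sets context.metadata['additional_wait'] = 2.0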
+ + Args: + context: Recovery context + correction: Correction data dict + + Returns: + True if correction was applied successfully + """ + correction_type = correction.get('correction_type', 'other') + + # Handle different correction types + if correction_type == 'target_change': + # Update target in context + if correction.get('corrected_target'): + context.metadata['corrected_target'] = correction['corrected_target'] + return True + + elif correction_type == 'parameter_change': + # Update parameters in context + if correction.get('corrected_params'): + context.metadata['corrected_params'] = correction['corrected_params'] + return True + + elif correction_type == 'wait_added': + # Add wait time + wait_time = correction.get('corrected_params', {}).get('wait_time', 2.0) + context.metadata['additional_wait'] = wait_time + return True + + elif correction_type == 'timing_adjust': + # Adjust timing + timing = correction.get('corrected_params', {}).get('timing_multiplier', 1.5) + context.metadata['timing_multiplier'] = timing + return True + + elif correction_type == 'coordinates_adjust': + # Adjust coordinates + offset = correction.get('corrected_params', {}).get('offset', {}) + context.metadata['coordinate_offset'] = offset + return True + + # For other types, just mark as applied if we have any correction data + return bool(correction.get('corrected_target') or correction.get('corrected_params')) + + def get_correction_pack_statistics(self) -> Optional[Dict[str, Any]]: + """ + Get statistics from correction packs. + + Returns: + Statistics dict or None if not available + """ + if not self.correction_packs_enabled or not self.correction_pack_service: + return None + + try: + return self.correction_pack_service.get_global_statistics() + except Exception as e: + logger.error(f"Error getting correction pack statistics: {e}") + return None diff --git a/tests/test_correction_packs.py b/tests/test_correction_packs.py new file mode 100644 index 000000000..8b8151fc5 --- /dev/null +++ b/tests/test_correction_packs.py @@ -0,0 +1,644 @@ +""" +Tests for the Correction Packs system. + +Tests models, repository, aggregator, and service. 
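+
+Run directly with `python -m unittest tests.test_correction_packs` (assumed
+invocation; the file also runs standalone via unittest.main()).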
+""" + +import json +import os +import shutil +import tempfile +import unittest +from datetime import datetime +from pathlib import Path + +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from core.corrections.models import ( + CorrectionKey, + CorrectionType, + CorrectionStatus, + Correction, + CorrectionPack, + CorrectionPackMetadata, + CorrectionSource, + generate_correction_id, + generate_pack_id +) +from core.corrections.correction_repository import CorrectionRepository +from core.corrections.aggregator import CorrectionAggregator, CrossWorkflowAggregator +from core.corrections.correction_pack_service import CorrectionPackService + + +class TestCorrectionKey(unittest.TestCase): + """Tests for CorrectionKey model.""" + + def test_to_hash(self): + """Test hash generation is consistent.""" + key1 = CorrectionKey("click", "button", "element_not_found") + key2 = CorrectionKey("click", "button", "element_not_found") + key3 = CorrectionKey("type", "input", "element_not_found") + + # Same keys should produce same hash + self.assertEqual(key1.to_hash(), key2.to_hash()) + + # Different keys should produce different hash + self.assertNotEqual(key1.to_hash(), key3.to_hash()) + + # Hash should be 16 characters + self.assertEqual(len(key1.to_hash()), 16) + + def test_to_dict_from_dict(self): + """Test serialization round-trip.""" + key = CorrectionKey("click", "button", "timeout") + key_dict = key.to_dict() + + self.assertEqual(key_dict['action_type'], "click") + self.assertEqual(key_dict['element_type'], "button") + self.assertEqual(key_dict['failure_context'], "timeout") + self.assertIn('hash', key_dict) + + restored = CorrectionKey.from_dict(key_dict) + self.assertEqual(key.to_hash(), restored.to_hash()) + + +class TestCorrection(unittest.TestCase): + """Tests for Correction model.""" + + def test_create_correction(self): + """Test creating a correction.""" + key = CorrectionKey("click", "button", "not_visible") + correction = Correction( + id=generate_correction_id(), + key=key, + correction_type=CorrectionType.WAIT_ADDED, + original_target={"selector": "#submit"}, + corrected_params={"wait_time": 2.0}, + correction_description="Added wait for button visibility" + ) + + self.assertTrue(correction.id.startswith("corr_")) + self.assertEqual(correction.correction_type, CorrectionType.WAIT_ADDED) + self.assertEqual(correction.success_rate, 0.0) + + def test_success_rate_calculation(self): + """Test success rate calculation.""" + key = CorrectionKey("click", "button", "not_visible") + correction = Correction( + id=generate_correction_id(), + key=key, + correction_type=CorrectionType.WAIT_ADDED, + success_count=7, + failure_count=3 + ) + + self.assertEqual(correction.success_rate, 0.7) + + def test_record_application(self): + """Test recording application updates stats.""" + key = CorrectionKey("click", "button", "not_visible") + correction = Correction( + id=generate_correction_id(), + key=key, + correction_type=CorrectionType.WAIT_ADDED + ) + + correction.record_application(success=True) + self.assertEqual(correction.success_count, 1) + self.assertEqual(correction.failure_count, 0) + self.assertIsNotNone(correction.last_applied) + + correction.record_application(success=False) + self.assertEqual(correction.success_count, 1) + self.assertEqual(correction.failure_count, 1) + + def test_to_dict_from_dict(self): + """Test serialization round-trip.""" + key = CorrectionKey("type", "input", "validation_error") + source = CorrectionSource( + 
session_id="sess_123", + workflow_id="wf_456", + node_id="node_789" + ) + correction = Correction( + id="corr_test123", + key=key, + correction_type=CorrectionType.PARAMETER_CHANGE, + original_params={"value": "test"}, + corrected_params={"value": "corrected"}, + source=source, + success_count=5, + failure_count=2, + tags=["form", "validation"] + ) + + correction_dict = correction.to_dict() + restored = Correction.from_dict(correction_dict) + + self.assertEqual(restored.id, correction.id) + self.assertEqual(restored.key.to_hash(), key.to_hash()) + self.assertEqual(restored.correction_type, CorrectionType.PARAMETER_CHANGE) + self.assertEqual(restored.success_count, 5) + self.assertEqual(restored.tags, ["form", "validation"]) + + +class TestCorrectionPack(unittest.TestCase): + """Tests for CorrectionPack model.""" + + def test_create_pack(self): + """Test creating a pack.""" + pack = CorrectionPack( + id=generate_pack_id(), + name="Test Pack", + description="A test correction pack" + ) + + self.assertTrue(pack.id.startswith("pack_")) + self.assertEqual(pack.name, "Test Pack") + self.assertEqual(pack.correction_count, 0) + + def test_add_remove_correction(self): + """Test adding and removing corrections.""" + pack = CorrectionPack( + id=generate_pack_id(), + name="Test Pack" + ) + + key = CorrectionKey("click", "button", "not_found") + correction = Correction( + id="corr_test1", + key=key, + correction_type=CorrectionType.TARGET_CHANGE + ) + + # Add correction + self.assertTrue(pack.add_correction(correction)) + self.assertEqual(pack.correction_count, 1) + + # Adding same correction should return False + self.assertFalse(pack.add_correction(correction)) + + # Remove correction + self.assertTrue(pack.remove_correction("corr_test1")) + self.assertEqual(pack.correction_count, 0) + + # Removing non-existent should return False + self.assertFalse(pack.remove_correction("corr_test1")) + + def test_find_by_key_hash(self): + """Test finding corrections by key hash.""" + pack = CorrectionPack(id=generate_pack_id(), name="Test Pack") + + key1 = CorrectionKey("click", "button", "not_found") + key2 = CorrectionKey("type", "input", "validation") + + correction1 = Correction(id="corr_1", key=key1, correction_type=CorrectionType.TARGET_CHANGE) + correction2 = Correction(id="corr_2", key=key2, correction_type=CorrectionType.PARAMETER_CHANGE) + + pack.add_correction(correction1) + pack.add_correction(correction2) + + # Find by key1 hash + found = pack.find_by_key_hash(key1.to_hash()) + self.assertEqual(len(found), 1) + self.assertEqual(found[0].id, "corr_1") + + def test_merge_packs(self): + """Test merging two packs.""" + pack1 = CorrectionPack(id="pack_1", name="Pack 1") + pack2 = CorrectionPack(id="pack_2", name="Pack 2") + + key = CorrectionKey("click", "button", "error") + corr1 = Correction(id="corr_1", key=key, correction_type=CorrectionType.TARGET_CHANGE) + corr2 = Correction(id="corr_2", key=key, correction_type=CorrectionType.WAIT_ADDED) + + pack1.add_correction(corr1) + pack2.add_correction(corr2) + + merged_count = pack1.merge(pack2) + self.assertEqual(merged_count, 1) + self.assertEqual(pack1.correction_count, 2) + + def test_get_statistics(self): + """Test getting pack statistics.""" + pack = CorrectionPack(id=generate_pack_id(), name="Test Pack") + + key = CorrectionKey("click", "button", "error") + correction = Correction( + id="corr_1", + key=key, + correction_type=CorrectionType.TARGET_CHANGE, + success_count=10, + failure_count=2 + ) + pack.add_correction(correction) + + stats = 
pack.get_statistics() + self.assertEqual(stats['total_corrections'], 1) + self.assertEqual(stats['active_corrections'], 1) + self.assertEqual(stats['total_applications'], 12) + + +class TestCorrectionRepository(unittest.TestCase): + """Tests for CorrectionRepository.""" + + def setUp(self): + """Create temporary directory for tests.""" + self.test_dir = tempfile.mkdtemp() + self.repo = CorrectionRepository(self.test_dir) + + def tearDown(self): + """Clean up temporary directory.""" + shutil.rmtree(self.test_dir) + + def test_create_and_get_pack(self): + """Test creating and retrieving a pack.""" + pack = self.repo.create_pack("Test Pack", "Description", { + 'category': 'testing', + 'tags': ['unit', 'test'] + }) + + self.assertIsNotNone(pack) + self.assertEqual(pack.name, "Test Pack") + + retrieved = self.repo.get_pack(pack.id) + self.assertIsNotNone(retrieved) + self.assertEqual(retrieved.name, "Test Pack") + + def test_list_packs(self): + """Test listing packs.""" + self.repo.create_pack("Pack 1", "First pack") + self.repo.create_pack("Pack 2", "Second pack") + + packs = self.repo.list_packs() + self.assertEqual(len(packs), 2) + + def test_update_pack(self): + """Test updating a pack.""" + pack = self.repo.create_pack("Original Name") + + updated = self.repo.update_pack(pack.id, {'name': 'New Name'}) + self.assertIsNotNone(updated) + self.assertEqual(updated.name, 'New Name') + + def test_delete_pack(self): + """Test deleting a pack.""" + pack = self.repo.create_pack("To Delete") + + self.assertTrue(self.repo.delete_pack(pack.id)) + self.assertIsNone(self.repo.get_pack(pack.id)) + self.assertFalse(self.repo.delete_pack(pack.id)) + + def test_add_and_find_correction(self): + """Test adding and finding corrections.""" + pack = self.repo.create_pack("Test Pack") + + key = CorrectionKey("click", "button", "timeout") + correction = Correction( + id=generate_correction_id(), + key=key, + correction_type=CorrectionType.TIMING_ADJUST + ) + + self.assertTrue(self.repo.add_correction(pack.id, correction)) + + # Find by context + found = self.repo.find_by_context("click", "button", "timeout") + self.assertEqual(len(found), 1) + self.assertEqual(found[0][1].id, correction.id) + + def test_export_import_pack(self): + """Test exporting and importing a pack.""" + # Create pack with correction + pack = self.repo.create_pack("Export Test") + key = CorrectionKey("click", "button", "error") + correction = Correction( + id="corr_export", + key=key, + correction_type=CorrectionType.TARGET_CHANGE + ) + self.repo.add_correction(pack.id, correction) + + # Export + exported = self.repo.export_pack(pack.id) + self.assertIsNotNone(exported) + + # Import as new + imported = self.repo.import_pack(exported, merge_strategy='create_new') + self.assertIsNotNone(imported) + self.assertNotEqual(imported.id, pack.id) + self.assertEqual(imported.correction_count, 1) + + def test_version_and_rollback(self): + """Test versioning and rollback.""" + pack = self.repo.create_pack("Version Test") + + # Add correction + key = CorrectionKey("click", "button", "error") + correction1 = Correction(id="corr_v1", key=key, correction_type=CorrectionType.TARGET_CHANGE) + self.repo.add_correction(pack.id, correction1) + + # Create version + version = self.repo.create_version(pack.id, "First version") + self.assertEqual(version, 1) + + # Add another correction + correction2 = Correction(id="corr_v2", key=key, correction_type=CorrectionType.WAIT_ADDED) + self.repo.add_correction(pack.id, correction2) + + # Verify 2 corrections + pack = 
self.repo.get_pack(pack.id) + self.assertEqual(pack.correction_count, 2) + + # Rollback to version 1 + self.assertTrue(self.repo.rollback_pack(pack.id, 1)) + + # Should have 1 correction again + pack = self.repo.get_pack(pack.id) + self.assertEqual(pack.correction_count, 1) + + +class TestCorrectionAggregator(unittest.TestCase): + """Tests for CorrectionAggregator.""" + + def setUp(self): + """Set up aggregator for tests.""" + self.aggregator = CorrectionAggregator( + min_occurrences=2, + min_success_rate=0.5, + min_confidence=0.3 + ) + + def test_aggregate_by_key(self): + """Test aggregating corrections by key.""" + key = CorrectionKey("click", "button", "not_found") + + # Create multiple corrections with same key + corrections = [ + Correction( + id=f"corr_{i}", + key=key, + correction_type=CorrectionType.WAIT_ADDED, + success_count=5, + failure_count=2 + ) + for i in range(3) + ] + + aggregated = self.aggregator.aggregate_corrections(corrections) + + # Should have 1 merged correction + self.assertEqual(len(aggregated), 1) + # Sum of success counts + self.assertEqual(aggregated[0].success_count, 15) + + def test_filter_by_quality(self): + """Test filtering by quality thresholds.""" + key = CorrectionKey("click", "button", "error") + + high_quality = Correction( + id="corr_high", + key=key, + correction_type=CorrectionType.TARGET_CHANGE, + success_count=9, + failure_count=1 + ) + + low_quality = Correction( + id="corr_low", + key=key, + correction_type=CorrectionType.TARGET_CHANGE, + success_count=1, + failure_count=9 + ) + + filtered = self.aggregator.filter_by_quality( + [high_quality, low_quality], + min_success_rate=0.7 + ) + + self.assertEqual(len(filtered), 1) + self.assertEqual(filtered[0].id, "corr_high") + + def test_rank_corrections(self): + """Test ranking corrections.""" + key = CorrectionKey("click", "button", "error") + + corrections = [ + Correction(id="corr_1", key=key, correction_type=CorrectionType.TARGET_CHANGE, success_count=3, failure_count=7), + Correction(id="corr_2", key=key, correction_type=CorrectionType.TARGET_CHANGE, success_count=8, failure_count=2), + Correction(id="corr_3", key=key, correction_type=CorrectionType.TARGET_CHANGE, success_count=5, failure_count=5), + ] + + ranked = self.aggregator.rank_corrections(corrections, sort_by='success_rate') + + self.assertEqual(ranked[0].id, "corr_2") + self.assertEqual(ranked[1].id, "corr_3") + self.assertEqual(ranked[2].id, "corr_1") + + +class TestCorrectionPackService(unittest.TestCase): + """Tests for CorrectionPackService.""" + + def setUp(self): + """Set up service for tests.""" + self.test_dir = tempfile.mkdtemp() + self.training_dir = tempfile.mkdtemp() + self.service = CorrectionPackService(self.test_dir, self.training_dir) + + def tearDown(self): + """Clean up temporary directories.""" + shutil.rmtree(self.test_dir) + shutil.rmtree(self.training_dir) + + def test_create_and_list_packs(self): + """Test creating and listing packs.""" + pack = self.service.create_pack( + name="Service Test Pack", + description="Testing the service", + category="testing", + tags=["unit", "test"] + ) + + self.assertIsNotNone(pack) + self.assertEqual(pack['name'], "Service Test Pack") + + packs = self.service.list_packs() + self.assertEqual(len(packs), 1) + self.assertEqual(packs[0]['name'], "Service Test Pack") + + def test_add_correction(self): + """Test adding corrections via service.""" + pack = self.service.create_pack(name="Correction Test") + + correction = self.service.add_correction(pack['id'], { + 'action_type': 
'click', + 'element_type': 'button', + 'failure_context': 'element_not_found', + 'correction_type': 'target_change', + 'original_target': {'selector': '#old'}, + 'corrected_target': {'selector': '#new'}, + 'description': 'Updated selector' + }) + + self.assertIsNotNone(correction) + self.assertTrue(correction['id'].startswith('corr_')) + + def test_find_applicable_corrections(self): + """Test finding applicable corrections.""" + pack = self.service.create_pack(name="Find Test") + + self.service.add_correction(pack['id'], { + 'action_type': 'click', + 'element_type': 'button', + 'failure_context': 'timeout', + 'correction_type': 'timing_adjust', + 'corrected_params': {'wait_time': 5} + }) + + found = self.service.find_applicable_corrections( + action_type='click', + element_type='button', + failure_context='timeout', + min_confidence=0.0 + ) + + self.assertEqual(len(found), 1) + + def test_export_import_file(self): + """Test file-based export and import.""" + pack = self.service.create_pack(name="File Export Test") + self.service.add_correction(pack['id'], { + 'action_type': 'type', + 'element_type': 'input', + 'failure_context': 'validation', + 'correction_type': 'parameter_change' + }) + + # Export to file + export_file = os.path.join(self.test_dir, 'export.json') + self.assertTrue(self.service.export_pack_to_file(pack['id'], export_file)) + self.assertTrue(os.path.exists(export_file)) + + # Import from file + imported = self.service.import_pack_from_file(export_file) + self.assertIsNotNone(imported) + self.assertEqual(len(imported['corrections']), 1) + + def test_version_management(self): + """Test version creation and rollback.""" + pack = self.service.create_pack(name="Version Test") + + # Add correction + self.service.add_correction(pack['id'], { + 'action_type': 'click', + 'element_type': 'button', + 'failure_context': 'error', + 'correction_type': 'target_change' + }) + + # Create version + version = self.service.create_version(pack['id'], "v1 snapshot") + self.assertEqual(version, 1) + + # List versions + versions = self.service.list_versions(pack['id']) + self.assertEqual(len(versions), 1) + + # Add another correction + self.service.add_correction(pack['id'], { + 'action_type': 'type', + 'element_type': 'input', + 'failure_context': 'error', + 'correction_type': 'parameter_change' + }) + + # Get pack - should have 2 corrections + current_pack = self.service.get_pack(pack['id']) + self.assertEqual(len(current_pack['corrections']), 2) + + # Rollback + self.assertTrue(self.service.rollback_pack(pack['id'], 1)) + + # Get pack - should have 1 correction + rolled_back = self.service.get_pack(pack['id']) + self.assertEqual(len(rolled_back['corrections']), 1) + + def test_statistics(self): + """Test getting statistics.""" + pack = self.service.create_pack(name="Stats Test") + self.service.add_correction(pack['id'], { + 'action_type': 'click', + 'element_type': 'button', + 'failure_context': 'error', + 'correction_type': 'target_change' + }) + + # Pack statistics + pack_stats = self.service.get_pack_statistics(pack['id']) + self.assertIsNotNone(pack_stats) + self.assertEqual(pack_stats['total_corrections'], 1) + + # Global statistics + global_stats = self.service.get_global_statistics() + self.assertIsNotNone(global_stats) + self.assertEqual(global_stats['total_packs'], 1) + + +class TestCrossWorkflowAggregator(unittest.TestCase): + """Tests for CrossWorkflowAggregator.""" + + def test_find_common_corrections(self): + """Test finding corrections common across workflows.""" + aggregator = 
CrossWorkflowAggregator()
+
+        key = CorrectionKey("click", "submit_button", "timeout")
+
+        # Corrections from different workflows
+        corrections = [
+            Correction(
+                id="corr_1",
+                key=key,
+                correction_type=CorrectionType.WAIT_ADDED,
+                source=CorrectionSource(session_id="s1", workflow_id="wf_1")
+            ),
+            Correction(
+                id="corr_2",
+                key=key,
+                correction_type=CorrectionType.WAIT_ADDED,
+                source=CorrectionSource(session_id="s2", workflow_id="wf_2")
+            ),
+            Correction(
+                id="corr_3",
+                key=key,
+                correction_type=CorrectionType.WAIT_ADDED,
+                source=CorrectionSource(session_id="s3", workflow_id="wf_3")
+            ),
+        ]
+
+        common = aggregator.find_common_corrections(corrections, min_workflows=2)
+
+        self.assertEqual(len(common), 1)
+        merged, workflows = common[0]
+        self.assertEqual(len(workflows), 3)
+
+    def test_detect_patterns(self):
+        """Test pattern detection."""
+        aggregator = CrossWorkflowAggregator()
+
+        key = CorrectionKey("click", "button", "error")
+        corrections = [
+            Correction(id=f"corr_{i}", key=key, correction_type=CorrectionType.WAIT_ADDED)
+            for i in range(5)
+        ]
+
+        patterns = aggregator.detect_patterns(corrections)
+
+        self.assertIn('patterns', patterns)
+        self.assertIn('summary', patterns)
+        self.assertEqual(patterns['summary']['total_corrections'], 5)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/visual_workflow_builder/backend/api/correction_packs.py b/visual_workflow_builder/backend/api/correction_packs.py
new file mode 100644
index 000000000..72e273372
--- /dev/null
+++ b/visual_workflow_builder/backend/api/correction_packs.py
@@ -0,0 +1,640 @@
+"""
+API endpoints for the Correction Packs.
+
+Provides REST API for managing correction packs, corrections,
+aggregation, export/import, and versioning.
+"""
+
+import os
+import sys
+from flask import Blueprint, request, jsonify, Response
+from typing import Optional
+
+# Add parent paths for imports
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))))
+
+from core.corrections import CorrectionPackService
+
+correction_packs_bp = Blueprint('correction_packs', __name__)
+
+# Initialize service (singleton pattern)
+_service: Optional[CorrectionPackService] = None
+
+
+def get_service() -> CorrectionPackService:
+    """Get or create the correction pack service."""
+    global _service
+    if _service is None:
+        # Use paths relative to the project root
+        base_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
+        storage_path = os.path.join(base_path, "data", "correction_packs")
+        training_path = os.path.join(base_path, "training_data")
+        _service = CorrectionPackService(storage_path, training_path)
+    return _service
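+
+# Registration sketch: this blueprint is mounted under the /api prefix by
+# app.py later in this patch, so list_packs() below serves
+# GET /api/correction-packs:
+#
+#     from api.correction_packs import correction_packs_bp
+#     app.register_blueprint(correction_packs_bp, url_prefix='/api')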
+
+
+# ========== Pack CRUD Endpoints ==========
+
+@correction_packs_bp.route('/correction-packs', methods=['GET'])
+def list_packs():
+    """
+    List all correction packs.
+
+    Query params:
+        category: Filter by category
+        tags: Filter by tags (comma-separated)
+
+    Returns:
+        JSON list of pack summaries
+    """
+    try:
+        service = get_service()
+
+        category = request.args.get('category')
+        tags_param = request.args.get('tags')
+        tags = tags_param.split(',') if tags_param else None
+
+        packs = service.list_packs(category=category, tags=tags)
+        return jsonify({'success': True, 'packs': packs})
+
+    except Exception as e:
+        return jsonify({'success': False, 'error': str(e)}), 500
+
+
+@correction_packs_bp.route('/correction-packs', methods=['POST'])
+def create_pack():
+    """
+    Create a new correction pack.
+
+    Request body:
+        name: Pack name (required)
+        description: Pack description
+        category: Pack category
+        tags: List of tags
+        author: Pack author
+
+    Returns:
+        Created pack
+    """
+    try:
+        service = get_service()
+        data = request.get_json() or {}
+
+        if not data.get('name'):
+            return jsonify({'success': False, 'error': 'name is required'}), 400
+
+        pack = service.create_pack(
+            name=data['name'],
+            description=data.get('description', ''),
+            category=data.get('category', 'general'),
+            tags=data.get('tags', []),
+            author=data.get('author', '')
+        )
+
+        return jsonify({'success': True, 'pack': pack}), 201
+
+    except Exception as e:
+        return jsonify({'success': False, 'error': str(e)}), 500
+
+
+@correction_packs_bp.route('/correction-packs/<pack_id>', methods=['GET'])
+def get_pack(pack_id: str):
+    """
+    Get a pack by ID.
+
+    Returns:
+        Pack details including all corrections
+    """
+    try:
+        service = get_service()
+        pack = service.get_pack(pack_id)
+
+        if not pack:
+            return jsonify({'success': False, 'error': 'Pack not found'}), 404
+
+        return jsonify({'success': True, 'pack': pack})
+
+    except Exception as e:
+        return jsonify({'success': False, 'error': str(e)}), 500
+
+
+@correction_packs_bp.route('/correction-packs/<pack_id>', methods=['PUT'])
+def update_pack(pack_id: str):
+    """
+    Update a pack.
+
+    Request body:
+        name: New name
+        description: New description
+        metadata: New metadata dict
+
+    Returns:
+        Updated pack
+    """
+    try:
+        service = get_service()
+        data = request.get_json() or {}
+
+        pack = service.update_pack(
+            pack_id=pack_id,
+            name=data.get('name'),
+            description=data.get('description'),
+            metadata=data.get('metadata')
+        )
+
+        if not pack:
+            return jsonify({'success': False, 'error': 'Pack not found'}), 404
+
+        return jsonify({'success': True, 'pack': pack})
+
+    except Exception as e:
+        return jsonify({'success': False, 'error': str(e)}), 500
+
+
+@correction_packs_bp.route('/correction-packs/<pack_id>', methods=['DELETE'])
+def delete_pack(pack_id: str):
+    """
+    Delete a pack.
+
+    Returns:
+        Success status
+    """
+    try:
+        service = get_service()
+
+        if not service.delete_pack(pack_id):
+            return jsonify({'success': False, 'error': 'Pack not found'}), 404
+
+        return jsonify({'success': True, 'message': 'Pack deleted'})
+
+    except Exception as e:
+        return jsonify({'success': False, 'error': str(e)}), 500
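+
+# Illustrative client-side round trip (not part of the API itself; assumes the
+# backend runs on localhost:5002 and the `requests` package is available):
+#
+#     import requests
+#     base = 'http://localhost:5002/api'
+#     pack = requests.post(f'{base}/correction-packs',
+#                          json={'name': 'Login fixes'}).json()['pack']
+#     requests.put(f'{base}/correction-packs/{pack["id"]}',
+#                  json={'description': 'Corrections for the login flow'})
+#     requests.delete(f'{base}/correction-packs/{pack["id"]}')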
+
+
+# ========== Correction Endpoints ==========
+
+@correction_packs_bp.route('/correction-packs/<pack_id>/corrections', methods=['POST'])
+def add_correction(pack_id: str):
+    """
+    Add a correction to a pack.
+
+    Request body:
+        action_type: Type of action (required)
+        element_type: Type of element
+        failure_context: Description of failure
+        correction_type: Type of correction
+        original_target: Original target dict
+        original_params: Original parameters dict
+        corrected_target: Corrected target dict
+        corrected_params: Corrected parameters dict
+        description: Correction description
+        tags: List of tags
+        source: Source info dict
+
+    Returns:
+        Created correction
+    """
+    try:
+        service = get_service()
+        data = request.get_json() or {}
+
+        if not data.get('action_type'):
+            return jsonify({'success': False, 'error': 'action_type is required'}), 400
+
+        correction = service.add_correction(pack_id, data)
+
+        if not correction:
+            return jsonify({'success': False, 'error': 'Failed to add correction'}), 400
+
+        return jsonify({'success': True, 'correction': correction}), 201
+
+    except Exception as e:
+        return jsonify({'success': False, 'error': str(e)}), 500
+
+
+@correction_packs_bp.route('/correction-packs/<pack_id>/corrections/<correction_id>', methods=['PUT'])
+def update_correction(pack_id: str, correction_id: str):
+    """
+    Update a correction.
+
+    Request body:
+        corrected_target: New corrected target
+        corrected_params: New corrected params
+        correction_description: New description
+        status: New status
+        tags: New tags
+
+    Returns:
+        Updated correction
+    """
+    try:
+        service = get_service()
+        data = request.get_json() or {}
+
+        correction = service.update_correction(pack_id, correction_id, data)
+
+        if not correction:
+            return jsonify({'success': False, 'error': 'Correction not found'}), 404
+
+        return jsonify({'success': True, 'correction': correction})
+
+    except Exception as e:
+        return jsonify({'success': False, 'error': str(e)}), 500
+
+
+@correction_packs_bp.route('/correction-packs/<pack_id>/corrections/<correction_id>', methods=['DELETE'])
+def remove_correction(pack_id: str, correction_id: str):
+    """
+    Remove a correction from a pack.
+
+    Returns:
+        Success status
+    """
+    try:
+        service = get_service()
+
+        if not service.remove_correction(pack_id, correction_id):
+            return jsonify({'success': False, 'error': 'Correction not found'}), 404
+
+        return jsonify({'success': True, 'message': 'Correction removed'})
+
+    except Exception as e:
+        return jsonify({'success': False, 'error': str(e)}), 500
+
+
+# ========== Aggregation Endpoints ==========
+
+@correction_packs_bp.route('/correction-packs/<pack_id>/aggregate', methods=['POST'])
+def aggregate_from_sessions(pack_id: str):
+    """
+    Aggregate corrections from training sessions into a pack.
+
+    Request body:
+        workflow_id: Filter by workflow ID
+        min_occurrences: Minimum occurrences to include
+        min_success_rate: Minimum success rate to include
+        session_files: Specific session files to process
+
+    Returns:
+        Aggregation summary
+    """
+    try:
+        service = get_service()
+        data = request.get_json() or {}
+
+        result = service.aggregate_from_sessions(
+            pack_id=pack_id,
+            session_files=data.get('session_files'),
+            workflow_id=data.get('workflow_id'),
+            min_occurrences=data.get('min_occurrences', 1),
+            min_success_rate=data.get('min_success_rate', 0.0)
+        )
+
+        return jsonify({'success': True, 'result': result})
+
+    except Exception as e:
+        return jsonify({'success': False, 'error': str(e)}), 500
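+
+# Illustrative request (pack ID and payload values are hypothetical):
+#
+#     POST /api/correction-packs/pack_abc123/aggregate
+#     {"workflow_id": "wf_login", "min_occurrences": 2, "min_success_rate": 0.5}
+#
+# The service scans stored training sessions, groups corrections by key hash,
+# and adds the corrections that pass the thresholds to the pack.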
+
+
+@correction_packs_bp.route('/correction-packs/<pack_id>/aggregate-cross-workflow', methods=['POST'])
+def aggregate_cross_workflow(pack_id: str):
+    """
+    Find and aggregate corrections common across multiple workflows.
+
+    Request body:
+        min_workflows: Minimum number of workflows (default: 2)
+
+    Returns:
+        Aggregation summary with patterns
+    """
+    try:
+        service = get_service()
+        data = request.get_json() or {}
+
+        result = service.aggregate_cross_workflow(
+            pack_id=pack_id,
+            min_workflows=data.get('min_workflows', 2)
+        )
+
+        return jsonify({'success': True, 'result': result})
+
+    except Exception as e:
+        return jsonify({'success': False, 'error': str(e)}), 500
+
+
+# ========== Search and Apply Endpoints ==========
+
+@correction_packs_bp.route('/correction-packs/find', methods=['POST'])
+def find_corrections():
+    """
+    Find corrections applicable to a failure context.
+
+    Request body:
+        action_type: Type of action that failed (required)
+        element_type: Type of element involved
+        failure_context: Description of failure
+        pack_ids: Specific packs to search
+        min_confidence: Minimum confidence score
+
+    Returns:
+        List of applicable corrections
+    """
+    try:
+        service = get_service()
+        data = request.get_json() or {}
+
+        if not data.get('action_type'):
+            return jsonify({'success': False, 'error': 'action_type is required'}), 400
+
+        corrections = service.find_applicable_corrections(
+            action_type=data['action_type'],
+            element_type=data.get('element_type', 'unknown'),
+            failure_context=data.get('failure_context', ''),
+            pack_ids=data.get('pack_ids'),
+            min_confidence=data.get('min_confidence', 0.3)
+        )
+
+        return jsonify({'success': True, 'corrections': corrections})
+
+    except Exception as e:
+        return jsonify({'success': False, 'error': str(e)}), 500
+
+
+@correction_packs_bp.route('/correction-packs/apply', methods=['POST'])
+def apply_correction():
+    """
+    Record an application of a correction.
+
+    Request body:
+        pack_id: Pack ID (required)
+        correction_id: Correction ID (required)
+        success: Whether the application was successful (required)
+        context: Optional application context
+
+    Returns:
+        Success status
+    """
+    try:
+        service = get_service()
+        data = request.get_json() or {}
+
+        if not data.get('pack_id') or not data.get('correction_id'):
+            return jsonify({'success': False, 'error': 'pack_id and correction_id are required'}), 400
+
+        if 'success' not in data:
+            return jsonify({'success': False, 'error': 'success is required'}), 400
+
+        service.apply_correction(
+            pack_id=data['pack_id'],
+            correction_id=data['correction_id'],
+            success=data['success'],
+            application_context=data.get('context')
+        )
+
+        return jsonify({'success': True, 'message': 'Application recorded'})
+
+    except Exception as e:
+        return jsonify({'success': False, 'error': str(e)}), 500
+
+
+@correction_packs_bp.route('/correction-packs/search', methods=['GET'])
+def search_corrections():
+    """
+    Search corrections by text query.
+
+    Query params:
+        q: Search query (required)
+        pack_id: Optional pack ID to search in
+
+    Returns:
+        List of matching corrections
+    """
+    try:
+        service = get_service()
+
+        query = request.args.get('q')
+        if not query:
+            return jsonify({'success': False, 'error': 'q parameter is required'}), 400
+
+        pack_id = request.args.get('pack_id')
+
+        corrections = service.search_corrections(query, pack_id)
+        return jsonify({'success': True, 'corrections': corrections})
+
+    except Exception as e:
+        return jsonify({'success': False, 'error': str(e)}), 500
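+
+# Typical healing loop (sketch; IDs and payloads are illustrative):
+#
+#     1. POST /api/correction-packs/find
+#        {"action_type": "click", "element_type": "button",
+#         "failure_context": "timeout", "min_confidence": 0.5}
+#     2. Retry the failed action with the returned correction.
+#     3. POST /api/correction-packs/apply
+#        {"pack_id": "pack_abc123", "correction_id": "corr_xyz", "success": true}
+#
+# Step 3 feeds success/failure back into the correction's statistics, which is
+# the same loop SelfHealingEngine._try_correction_packs runs in-process.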
+
+
+# ========== Export/Import Endpoints ==========
+
+@correction_packs_bp.route('/correction-packs/<pack_id>/export', methods=['GET'])
+def export_pack(pack_id: str):
+    """
+    Export a pack to JSON or YAML.
+
+    Query params:
+        format: Export format ('json' or 'yaml', default: 'json')
+
+    Returns:
+        Exported content
+    """
+    try:
+        service = get_service()
+
+        format_type = request.args.get('format', 'json')
+        if format_type not in ['json', 'yaml']:
+            return jsonify({'success': False, 'error': 'Invalid format'}), 400
+
+        content = service.export_pack(pack_id, format=format_type)
+
+        if not content:
+            return jsonify({'success': False, 'error': 'Pack not found'}), 404
+
+        # Determine content type
+        if format_type == 'yaml':
+            content_type = 'application/x-yaml'
+            filename = f'correction_pack_{pack_id}.yaml'
+        else:
+            content_type = 'application/json'
+            filename = f'correction_pack_{pack_id}.json'
+
+        return Response(
+            content,
+            mimetype=content_type,
+            headers={
+                'Content-Disposition': f'attachment; filename={filename}'
+            }
+        )
+
+    except Exception as e:
+        return jsonify({'success': False, 'error': str(e)}), 500
+
+
+@correction_packs_bp.route('/correction-packs/import', methods=['POST'])
+def import_pack():
+    """
+    Import a pack from JSON or YAML.
+
+    Request body or file upload:
+        content: Content to import (if JSON body)
+        format: Import format ('json' or 'yaml')
+        merge_strategy: How to handle existing pack
+            - 'create_new': Always create new pack
+            - 'merge': Merge into existing pack
+            - 'replace': Replace existing pack
+
+    Returns:
+        Imported pack
+    """
+    try:
+        service = get_service()
+
+        # Check for file upload
+        if 'file' in request.files:
+            file = request.files['file']
+            content = file.read().decode('utf-8')
+            filename = file.filename
+
+            # Auto-detect format from filename
+            if filename and (filename.endswith('.yaml') or filename.endswith('.yml')):
+                format_type = 'yaml'
+            else:
+                format_type = 'json'
+
+            merge_strategy = request.form.get('merge_strategy', 'create_new')
+        else:
+            # JSON body
+            data = request.get_json() or {}
+            content = data.get('content')
+
+            if not content:
+                return jsonify({'success': False, 'error': 'content or file is required'}), 400
+
+            format_type = data.get('format', 'json')
+            merge_strategy = data.get('merge_strategy', 'create_new')
+
+        pack = service.import_pack(content, format=format_type, merge_strategy=merge_strategy)
+
+        if not pack:
+            return jsonify({'success': False, 'error': 'Failed to import pack'}), 400
+
+        return jsonify({'success': True, 'pack': pack}), 201
+
+    except Exception as e:
+        return jsonify({'success': False, 'error': str(e)}), 500
+
+
+# ========== Version Endpoints ==========
+
+@correction_packs_bp.route('/correction-packs/<pack_id>/versions', methods=['GET'])
+def list_versions(pack_id: str):
+    """
+    List all versions of a pack.
+
+    Returns:
+        List of version info
+    """
+    try:
+        service = get_service()
+        versions = service.list_versions(pack_id)
+        return jsonify({'success': True, 'versions': versions})
+
+    except Exception as e:
+        return jsonify({'success': False, 'error': str(e)}), 500
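+
+# Versioning sketch (pack ID and version numbers are illustrative):
+#
+#     POST /api/correction-packs/pack_abc123/versions  {"description": "pre-cleanup"}
+#         -> {"success": true, "version": 1}
+#     ... edit the pack ...
+#     POST /api/correction-packs/pack_abc123/rollback  {"version": 1}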
+
+
+@correction_packs_bp.route('/correction-packs/<pack_id>/versions', methods=['POST'])
+def create_version(pack_id: str):
+    """
+    Create a version snapshot of a pack.
+
+    Request body:
+        description: Version description
+
+    Returns:
+        Created version number
+    """
+    try:
+        service = get_service()
+        data = request.get_json() or {}
+
+        version = service.create_version(
+            pack_id=pack_id,
+            description=data.get('description', '')
+        )
+
+        if version < 0:
+            return jsonify({'success': False, 'error': 'Pack not found'}), 404
+
+        return jsonify({'success': True, 'version': version}), 201
+
+    except Exception as e:
+        return jsonify({'success': False, 'error': str(e)}), 500
+
+
+@correction_packs_bp.route('/correction-packs/<pack_id>/rollback', methods=['POST'])
+def rollback_pack(pack_id: str):
+    """
+    Rollback a pack to a previous version.
+
+    Request body:
+        version: Version number to rollback to (required)
+
+    Returns:
+        Success status
+    """
+    try:
+        service = get_service()
+        data = request.get_json() or {}
+
+        version = data.get('version')
+        if version is None:
+            return jsonify({'success': False, 'error': 'version is required'}), 400
+
+        if not service.rollback_pack(pack_id, version):
+            return jsonify({'success': False, 'error': 'Rollback failed'}), 400
+
+        return jsonify({'success': True, 'message': f'Rolled back to version {version}'})
+
+    except Exception as e:
+        return jsonify({'success': False, 'error': str(e)}), 500
+
+
+# ========== Statistics Endpoints ==========
+
+@correction_packs_bp.route('/correction-packs/<pack_id>/statistics', methods=['GET'])
+def get_pack_statistics(pack_id: str):
+    """
+    Get detailed statistics for a pack.
+
+    Returns:
+        Statistics dict
+    """
+    try:
+        service = get_service()
+        stats = service.get_pack_statistics(pack_id)
+
+        if not stats:
+            return jsonify({'success': False, 'error': 'Pack not found'}), 404
+
+        return jsonify({'success': True, 'statistics': stats})
+
+    except Exception as e:
+        return jsonify({'success': False, 'error': str(e)}), 500
+
+
+@correction_packs_bp.route('/correction-packs/statistics', methods=['GET'])
+def get_global_statistics():
+    """
+    Get global statistics across all packs.
+
+    Returns:
+        Global statistics dict
+    """
+    try:
+        service = get_service()
+        stats = service.get_global_statistics()
+        return jsonify({'success': True, 'statistics': stats})
+
+    except Exception as e:
+        return jsonify({'success': False, 'error': str(e)}), 500
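+
+# Example response shape for the per-pack endpoint (values are illustrative;
+# keys follow CorrectionPack.get_statistics in core/corrections/models.py):
+#
+#     {"success": true,
+#      "statistics": {"total_corrections": 12, "active_corrections": 10,
+#                     "total_applications": 87, "overall_success_rate": 0.82,
+#                     "corrections_by_type": {"wait_added": 5, "target_change": 7},
+#                     "average_confidence": 0.61}}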
+""" + +from flask import Flask +from flask_cors import CORS +from flask_socketio import SocketIO +from flask_sqlalchemy import SQLAlchemy +from flask_caching import Cache +import os +from dotenv import load_dotenv + +# Load environment variables +load_dotenv() + +# Initialize Flask app +app = Flask(__name__) + +# Configuration +app.config['SECRET_KEY'] = os.getenv('SECRET_KEY', 'dev-secret-key-change-in-production') +app.config['SQLALCHEMY_DATABASE_URI'] = os.getenv('DATABASE_URL', 'sqlite:///workflows.db') +app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False +app.config['MAX_CONTENT_LENGTH'] = 10 * 1024 * 1024 # 10MB max upload +app.config['CACHE_TYPE'] = 'redis' if os.getenv('REDIS_URL') else 'simple' +app.config['CACHE_REDIS_URL'] = os.getenv('REDIS_URL', 'redis://localhost:6379/0') + +# Initialize extensions +db = SQLAlchemy(app) +cache = Cache(app) +socketio = SocketIO( + app, + cors_allowed_origins="*", + async_mode='threading', + logger=True, + engineio_logger=True +) + +# Enable CORS +CORS(app, resources={ + r"/api/*": { + "origins": os.getenv('CORS_ORIGINS', 'http://localhost:3000').split(','), + "methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"], + "allow_headers": ["Content-Type", "Authorization"] + } +}) + +# Import and register blueprints (minimal set) +from api.workflows import workflows_bp +from api.screen_capture import screen_capture_bp +from api.real_demo import real_demo_bp +from api.errors import error_response + +app.register_blueprint(workflows_bp, url_prefix='/api/workflows') +app.register_blueprint(screen_capture_bp, url_prefix='/api/screen-capture') +app.register_blueprint(real_demo_bp) + +# Optional / Phase 2+ blueprints (loaded only if modules are available) +try: + from api.self_healing import self_healing_bp + app.register_blueprint(self_healing_bp) +except ImportError as e: + print(f"⚠️ Blueprint self_healing désactivé: {e}") + +try: + from api.visual_targets import visual_targets_bp, init_visual_target_manager + app.register_blueprint(visual_targets_bp) + VISUAL_TARGETS_BP_AVAILABLE = True +except ImportError as e: + print(f"⚠️ Blueprint visual_targets désactivé: {e}") + VISUAL_TARGETS_BP_AVAILABLE = False + init_visual_target_manager = None + +try: + from api.element_detection import element_detection_bp, init_element_detection + app.register_blueprint(element_detection_bp) + ELEMENT_DETECTION_BP_AVAILABLE = True +except ImportError as e: + print(f"⚠️ Blueprint element_detection désactivé: {e}") + ELEMENT_DETECTION_BP_AVAILABLE = False + init_element_detection = None + +try: + from api.analytics import analytics_bp + app.register_blueprint(analytics_bp, url_prefix='/api/analytics') +except ImportError: + pass + + +# Register other blueprints (optional - depends on Phase 2+ services) +try: + from api.templates import templates_bp + app.register_blueprint(templates_bp, url_prefix='/api/templates') +except ImportError as e: + print(f"⚠️ Blueprint templates désactivé: {e}") + +from api.node_types import node_types_bp +app.register_blueprint(node_types_bp, url_prefix='/api/node-types') + +try: + from api.executions import executions_bp + app.register_blueprint(executions_bp, url_prefix='/api/executions') +except ImportError as e: + print(f"⚠️ Blueprint executions désactivé: {e}") + +try: + from api.import_export import import_export_bp + app.register_blueprint(import_export_bp, url_prefix='/api') +except ImportError as e: + print(f"⚠️ Blueprint import_export désactivé: {e}") + +try: + from api.correction_packs import correction_packs_bp + 
+
+try:
+    from api.correction_packs import correction_packs_bp
+    app.register_blueprint(correction_packs_bp, url_prefix='/api')
+    print("✅ Blueprint correction_packs registered")
+except ImportError as e:
+    print(f"⚠️ Blueprint correction_packs disabled: {e}")
+
+
+# Import WebSocket handlers (optional)
+try:
+    from api import websocket_handlers  # noqa: F401
+except Exception as e:
+    print(f"⚠️ WebSocket handlers disabled: {e}")
+
+# Global error handlers
+@app.errorhandler(404)
+def not_found(error):
+    """Handle 404 errors"""
+    return error_response(404, "Resource not found")
+
+@app.errorhandler(405)
+def method_not_allowed(error):
+    """Handle 405 errors"""
+    return error_response(405, "Method not allowed")
+
+@app.errorhandler(500)
+def internal_error(error):
+    """Handle 500 errors"""
+    return error_response(500, "Internal server error")
+
+@app.errorhandler(Exception)
+def handle_exception(error):
+    """Handle all otherwise-unhandled exceptions"""
+    import traceback
+    traceback.print_exc()
+    return error_response(500, f"Unexpected error: {str(error)}")
+
+# Health check endpoint
+@app.route('/health')
+def health_check():
+    """Health check endpoint for monitoring"""
+    return {'status': 'healthy', 'version': '1.0.0'}
+
+# Create database tables
+with app.app_context():
+    db.create_all()
+
+# Initialize VisualTargetManager with RPA Vision V3 components (optional)
+try:
+    from core.capture.screen_capturer import ScreenCapturer
+    from core.detection.ui_detector import UIDetector
+    from core.embedding.fusion_engine import FusionEngine
+
+    # Only initialize if the related blueprints were actually loaded
+    if VISUAL_TARGETS_BP_AVAILABLE and init_visual_target_manager:
+        screen_capturer = ScreenCapturer()
+        ui_detector = UIDetector()
+        fusion_engine = FusionEngine()
+        init_visual_target_manager(screen_capturer, ui_detector, fusion_engine)
+
+    if ELEMENT_DETECTION_BP_AVAILABLE and init_element_detection:
+        # Reuse the instances created above when possible
+        if 'ui_detector' not in locals():
+            ui_detector = UIDetector()
+        if 'screen_capturer' not in locals():
+            screen_capturer = ScreenCapturer()
+        init_element_detection(ui_detector, screen_capturer)
+
+    if (VISUAL_TARGETS_BP_AVAILABLE and init_visual_target_manager) or (ELEMENT_DETECTION_BP_AVAILABLE and init_element_detection):
+        print("✅ Visual services initialized (VisualTargets / ElementDetection)")
+except ImportError as e:
+    print(f"⚠️ RPA core unavailable for visual initialization: {e}")
+except Exception as e:
+    print(f"❌ Error initializing visual services: {e}")
+
+if __name__ == '__main__':
+    port = int(os.getenv('PORT', 5002))
+    debug = os.getenv('FLASK_ENV') == 'development'
+
+    socketio.run(
+        app,
+        host='0.0.0.0',
+        port=port,
+        debug=debug,
+        use_reloader=debug,
+        allow_unsafe_werkzeug=True  # Development only; never enable in production
+    )
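+
+# Illustrative .env for local development (each key is read via os.getenv
+# above; values mirror the in-code defaults, except SECRET_KEY, which must
+# be replaced for any real deployment):
+#
+#   SECRET_KEY=replace-me
+#   DATABASE_URL=sqlite:///workflows.db
+#   REDIS_URL=redis://localhost:6379/0
+#   CORS_ORIGINS=http://localhost:3000
+#   FLASK_ENV=development
+#   PORT=5002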