""" Data models for the Correction Packs system. Defines CorrectionKey, Correction, CorrectionPack and related types. """ import hashlib import uuid from dataclasses import dataclass, field, asdict from datetime import datetime from enum import Enum from typing import Dict, List, Any, Optional class CorrectionType(Enum): """Types of corrections that can be applied.""" TARGET_CHANGE = "target_change" # Changed target selector PARAMETER_CHANGE = "parameter_change" # Changed action parameters ACTION_REPLACE = "action_replace" # Replaced action type WAIT_ADDED = "wait_added" # Added wait before action RETRY_CONFIG = "retry_config" # Changed retry configuration FALLBACK_ADDED = "fallback_added" # Added fallback action COORDINATES_ADJUST = "coordinates_adjust" # Adjusted click coordinates TIMING_ADJUST = "timing_adjust" # Adjusted timing/delay OTHER = "other" class CorrectionStatus(Enum): """Status of a correction.""" ACTIVE = "active" # Correction is active and can be applied DEPRECATED = "deprecated" # Correction is deprecated DISABLED = "disabled" # Correction is manually disabled @dataclass class CorrectionKey: """ Unique key to identify similar corrections across workflows. Uses action_type + element_type + failure_context to generate MD5 hash. Similar to LearningRepository pattern key generation. """ action_type: str element_type: str failure_context: str def to_hash(self) -> str: """Generate MD5 hash of the correction key (16 chars).""" key_string = f"{self.action_type}|{self.element_type}|{self.failure_context}" return hashlib.md5(key_string.encode()).hexdigest()[:16] def to_dict(self) -> Dict[str, Any]: """Convert to dictionary.""" return { 'action_type': self.action_type, 'element_type': self.element_type, 'failure_context': self.failure_context, 'hash': self.to_hash() } @classmethod def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionKey': """Create from dictionary.""" return cls( action_type=data.get('action_type', 'unknown'), element_type=data.get('element_type', 'unknown'), failure_context=data.get('failure_context', '') ) @dataclass class CorrectionSource: """Source information for a correction.""" session_id: str workflow_id: Optional[str] = None node_id: Optional[str] = None timestamp: datetime = field(default_factory=datetime.now) def to_dict(self) -> Dict[str, Any]: """Convert to dictionary.""" return { 'session_id': self.session_id, 'workflow_id': self.workflow_id, 'node_id': self.node_id, 'timestamp': self.timestamp.isoformat() } @classmethod def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionSource': """Create from dictionary.""" timestamp = data.get('timestamp') if isinstance(timestamp, str): timestamp = datetime.fromisoformat(timestamp) elif timestamp is None: timestamp = datetime.now() return cls( session_id=data.get('session_id', ''), workflow_id=data.get('workflow_id'), node_id=data.get('node_id'), timestamp=timestamp ) @dataclass class Correction: """ Individual correction record. Contains the original context, the applied correction, source information, and statistics about success/failure. """ id: str key: CorrectionKey correction_type: CorrectionType # Original context original_target: Optional[Dict[str, Any]] = None original_params: Optional[Dict[str, Any]] = None # Correction applied corrected_target: Optional[Dict[str, Any]] = None corrected_params: Optional[Dict[str, Any]] = None correction_description: str = "" # Source information source: Optional[CorrectionSource] = None # Statistics success_count: int = 0 failure_count: int = 0 last_applied: Optional[datetime] = None # Status status: CorrectionStatus = CorrectionStatus.ACTIVE # Metadata created_at: datetime = field(default_factory=datetime.now) updated_at: datetime = field(default_factory=datetime.now) tags: List[str] = field(default_factory=list) @property def success_rate(self) -> float: """Calculate success rate.""" total = self.success_count + self.failure_count if total == 0: return 0.0 return self.success_count / total @property def confidence_score(self) -> float: """Calculate confidence score based on success rate and usage count.""" total = self.success_count + self.failure_count if total == 0: return 0.0 # Confidence increases with more data points (up to 100 uses) usage_factor = min(1.0, total / 100) return self.success_rate * (0.5 + 0.5 * usage_factor) def record_application(self, success: bool): """Record an application of this correction.""" if success: self.success_count += 1 else: self.failure_count += 1 self.last_applied = datetime.now() self.updated_at = datetime.now() def to_dict(self) -> Dict[str, Any]: """Convert to dictionary.""" return { 'id': self.id, 'key': self.key.to_dict(), 'correction_type': self.correction_type.value, 'original_target': self.original_target, 'original_params': self.original_params, 'corrected_target': self.corrected_target, 'corrected_params': self.corrected_params, 'correction_description': self.correction_description, 'source': self.source.to_dict() if self.source else None, 'success_count': self.success_count, 'failure_count': self.failure_count, 'success_rate': self.success_rate, 'confidence_score': self.confidence_score, 'last_applied': self.last_applied.isoformat() if self.last_applied else None, 'status': self.status.value, 'created_at': self.created_at.isoformat(), 'updated_at': self.updated_at.isoformat(), 'tags': self.tags } @classmethod def from_dict(cls, data: Dict[str, Any]) -> 'Correction': """Create from dictionary.""" # Parse datetime fields created_at = data.get('created_at') if isinstance(created_at, str): created_at = datetime.fromisoformat(created_at) elif created_at is None: created_at = datetime.now() updated_at = data.get('updated_at') if isinstance(updated_at, str): updated_at = datetime.fromisoformat(updated_at) elif updated_at is None: updated_at = datetime.now() last_applied = data.get('last_applied') if isinstance(last_applied, str): last_applied = datetime.fromisoformat(last_applied) # Parse nested objects key_data = data.get('key', {}) key = CorrectionKey.from_dict(key_data) source_data = data.get('source') source = CorrectionSource.from_dict(source_data) if source_data else None return cls( id=data.get('id', str(uuid.uuid4())), key=key, correction_type=CorrectionType(data.get('correction_type', 'other')), original_target=data.get('original_target'), original_params=data.get('original_params'), corrected_target=data.get('corrected_target'), corrected_params=data.get('corrected_params'), correction_description=data.get('correction_description', ''), source=source, success_count=data.get('success_count', 0), failure_count=data.get('failure_count', 0), last_applied=last_applied, status=CorrectionStatus(data.get('status', 'active')), created_at=created_at, updated_at=updated_at, tags=data.get('tags', []) ) @dataclass class CorrectionPackMetadata: """Metadata for a correction pack.""" version: str = "1.0.0" author: str = "" category: str = "general" tags: List[str] = field(default_factory=list) description: str = "" def to_dict(self) -> Dict[str, Any]: """Convert to dictionary.""" return { 'version': self.version, 'author': self.author, 'category': self.category, 'tags': self.tags, 'description': self.description } @classmethod def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionPackMetadata': """Create from dictionary.""" return cls( version=data.get('version', '1.0.0'), author=data.get('author', ''), category=data.get('category', 'general'), tags=data.get('tags', []), description=data.get('description', '') ) @dataclass class CorrectionPack: """ A pack of corrections that can be exported, imported, and shared. Contains multiple corrections with metadata and statistics. """ id: str name: str description: str = "" # Corrections in this pack corrections: Dict[str, Correction] = field(default_factory=dict) # Metadata metadata: CorrectionPackMetadata = field(default_factory=CorrectionPackMetadata) # Timestamps created_at: datetime = field(default_factory=datetime.now) updated_at: datetime = field(default_factory=datetime.now) # Version history current_version: int = 1 @property def correction_count(self) -> int: """Number of corrections in the pack.""" return len(self.corrections) @property def total_applications(self) -> int: """Total number of times corrections were applied.""" return sum(c.success_count + c.failure_count for c in self.corrections.values()) @property def overall_success_rate(self) -> float: """Overall success rate across all corrections.""" total_success = sum(c.success_count for c in self.corrections.values()) total_failure = sum(c.failure_count for c in self.corrections.values()) total = total_success + total_failure if total == 0: return 0.0 return total_success / total @property def active_corrections(self) -> List[Correction]: """Get only active corrections.""" return [c for c in self.corrections.values() if c.status == CorrectionStatus.ACTIVE] def add_correction(self, correction: Correction) -> bool: """ Add a correction to the pack. Returns True if added, False if already exists. """ if correction.id in self.corrections: return False self.corrections[correction.id] = correction self.updated_at = datetime.now() return True def remove_correction(self, correction_id: str) -> bool: """ Remove a correction from the pack. Returns True if removed, False if not found. """ if correction_id not in self.corrections: return False del self.corrections[correction_id] self.updated_at = datetime.now() return True def get_correction(self, correction_id: str) -> Optional[Correction]: """Get a correction by ID.""" return self.corrections.get(correction_id) def find_by_key_hash(self, key_hash: str) -> List[Correction]: """Find corrections matching a key hash.""" return [ c for c in self.corrections.values() if c.key.to_hash() == key_hash and c.status == CorrectionStatus.ACTIVE ] def merge(self, other: 'CorrectionPack', conflict_strategy: str = 'keep_higher_confidence') -> int: """ Merge another pack into this one. Args: other: Pack to merge conflict_strategy: How to handle conflicts: - 'keep_higher_confidence': Keep correction with higher confidence - 'keep_newer': Keep the most recently updated correction - 'keep_existing': Keep existing, ignore new - 'replace': Always replace with new Returns: Number of corrections added/updated """ merged_count = 0 for correction_id, correction in other.corrections.items(): if correction_id in self.corrections: existing = self.corrections[correction_id] should_replace = False if conflict_strategy == 'keep_higher_confidence': should_replace = correction.confidence_score > existing.confidence_score elif conflict_strategy == 'keep_newer': should_replace = correction.updated_at > existing.updated_at elif conflict_strategy == 'replace': should_replace = True # 'keep_existing' doesn't replace if should_replace: self.corrections[correction_id] = correction merged_count += 1 else: self.corrections[correction_id] = correction merged_count += 1 if merged_count > 0: self.updated_at = datetime.now() return merged_count def get_statistics(self) -> Dict[str, Any]: """Get aggregate statistics for the pack.""" corrections = list(self.corrections.values()) active = [c for c in corrections if c.status == CorrectionStatus.ACTIVE] # Group by correction type by_type = {} for c in corrections: type_name = c.correction_type.value if type_name not in by_type: by_type[type_name] = 0 by_type[type_name] += 1 return { 'total_corrections': len(corrections), 'active_corrections': len(active), 'deprecated_corrections': len([c for c in corrections if c.status == CorrectionStatus.DEPRECATED]), 'disabled_corrections': len([c for c in corrections if c.status == CorrectionStatus.DISABLED]), 'total_applications': self.total_applications, 'overall_success_rate': self.overall_success_rate, 'corrections_by_type': by_type, 'average_confidence': sum(c.confidence_score for c in active) / len(active) if active else 0.0 } def to_dict(self) -> Dict[str, Any]: """Convert to dictionary.""" return { 'id': self.id, 'name': self.name, 'description': self.description, 'corrections': {cid: c.to_dict() for cid, c in self.corrections.items()}, 'metadata': self.metadata.to_dict(), 'created_at': self.created_at.isoformat(), 'updated_at': self.updated_at.isoformat(), 'current_version': self.current_version, 'statistics': self.get_statistics() } @classmethod def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionPack': """Create from dictionary.""" # Parse datetime fields created_at = data.get('created_at') if isinstance(created_at, str): created_at = datetime.fromisoformat(created_at) elif created_at is None: created_at = datetime.now() updated_at = data.get('updated_at') if isinstance(updated_at, str): updated_at = datetime.fromisoformat(updated_at) elif updated_at is None: updated_at = datetime.now() # Parse corrections corrections_data = data.get('corrections', {}) corrections = { cid: Correction.from_dict(cdata) for cid, cdata in corrections_data.items() } # Parse metadata metadata_data = data.get('metadata', {}) metadata = CorrectionPackMetadata.from_dict(metadata_data) return cls( id=data.get('id', str(uuid.uuid4())), name=data.get('name', 'Unnamed Pack'), description=data.get('description', ''), corrections=corrections, metadata=metadata, created_at=created_at, updated_at=updated_at, current_version=data.get('current_version', 1) ) def generate_correction_id() -> str: """Generate a unique correction ID.""" return f"corr_{uuid.uuid4().hex[:12]}" def generate_pack_id() -> str: """Generate a unique pack ID.""" return f"pack_{uuid.uuid4().hex[:12]}"