Implement a complete system for capitalizing on user corrections across multiple workflows and sessions. This enables learned fixes to be applied automatically when similar failures occur in different contexts.

New components:
- core/corrections/models.py: CorrectionKey, Correction, CorrectionPack models
- core/corrections/correction_repository.py: JSON storage with atomic writes
- core/corrections/aggregator.py: Aggregation by hash and quality filtering
- core/corrections/correction_pack_service.py: CRUD, export/import, versioning
- backend/api/correction_packs.py: REST API with 15 endpoints

Features:
- MD5-based key hashing for correction deduplication
- Export/import in JSON and YAML formats
- Version history with rollback support
- Cross-workflow pattern detection
- Integration with SelfHealingEngine for automatic application
- 29 unit tests (all passing)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
481 lines
16 KiB
Python
481 lines
16 KiB
Python
"""
|
|
Data models for the Correction Packs system.
|
|
|
|
Defines CorrectionKey, Correction, CorrectionPack and related types.
|
|
"""
|
|
|
|
import hashlib
|
|
import uuid
|
|
from dataclasses import dataclass, field, asdict
|
|
from datetime import datetime
|
|
from enum import Enum
|
|
from typing import Dict, List, Any, Optional
|
|
|
|
|
|
class CorrectionType(Enum):
    """Kind of fix that a user correction applies to a failing action."""

    # The target selector of the action was changed.
    TARGET_CHANGE = "target_change"
    # The action's parameters were changed.
    PARAMETER_CHANGE = "parameter_change"
    # The action type itself was replaced.
    ACTION_REPLACE = "action_replace"
    # A wait was inserted before the action.
    WAIT_ADDED = "wait_added"
    # The retry configuration was changed.
    RETRY_CONFIG = "retry_config"
    # A fallback action was added.
    FALLBACK_ADDED = "fallback_added"
    # Click coordinates were adjusted.
    COORDINATES_ADJUST = "coordinates_adjust"
    # Timing / delay was adjusted.
    TIMING_ADJUST = "timing_adjust"
    # Anything not covered by the kinds above.
    OTHER = "other"
|
|
|
|
|
|
class CorrectionStatus(Enum):
    """Lifecycle state of a correction."""

    # Eligible to be applied.
    ACTIVE = "active"
    # Superseded / no longer recommended.
    DEPRECATED = "deprecated"
    # Switched off manually.
    DISABLED = "disabled"
|
|
|
|
|
|
@dataclass
class CorrectionKey:
    """
    Grouping key identifying "the same kind" of correction across workflows.

    Corrections share a key when their ``action_type``, ``element_type`` and
    ``failure_context`` all match. :meth:`to_hash` condenses those three
    fields into a short MD5-based fingerprint, similar to the
    LearningRepository pattern key generation.
    """

    action_type: str
    element_type: str
    failure_context: str

    def to_hash(self) -> str:
        """Return the first 16 hex digits of MD5("action|element|context")."""
        # The separator and field order define the serialized hash (see
        # to_dict); changing either makes previously stored hashes stop matching.
        fingerprint_source = "{}|{}|{}".format(
            self.action_type, self.element_type, self.failure_context
        )
        full_digest = hashlib.md5(fingerprint_source.encode()).hexdigest()
        return full_digest[:16]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the three key fields plus their computed ``hash``."""
        payload: Dict[str, Any] = dict(
            action_type=self.action_type,
            element_type=self.element_type,
            failure_context=self.failure_context,
        )
        payload['hash'] = self.to_hash()
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionKey':
        """
        Build a key from a mapping.

        Missing ``action_type`` / ``element_type`` default to ``'unknown'`` and
        a missing ``failure_context`` defaults to ``''``. Any ``hash`` entry is
        ignored because it is always recomputed.
        """
        lookup = data.get
        return cls(
            lookup('action_type', 'unknown'),
            lookup('element_type', 'unknown'),
            lookup('failure_context', ''),
        )
|
|
|
|
|
|
@dataclass
class CorrectionSource:
    """Origin of a correction: the session, optional workflow/node, and when."""

    session_id: str
    workflow_id: Optional[str] = None
    node_id: Optional[str] = None
    timestamp: datetime = field(default_factory=datetime.now)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict; ``timestamp`` becomes an ISO-8601 string."""
        result: Dict[str, Any] = {}
        result['session_id'] = self.session_id
        result['workflow_id'] = self.workflow_id
        result['node_id'] = self.node_id
        result['timestamp'] = self.timestamp.isoformat()
        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionSource':
        """
        Build a source from a mapping.

        An ISO-8601 string ``timestamp`` is parsed, a missing/``None`` one is
        replaced by the current local time, and any other value is used as-is.
        A missing ``session_id`` defaults to ``''``.
        """
        raw_timestamp = data.get('timestamp')
        if raw_timestamp is None:
            when = datetime.now()
        elif isinstance(raw_timestamp, str):
            when = datetime.fromisoformat(raw_timestamp)
        else:
            when = raw_timestamp

        return cls(
            session_id=data.get('session_id', ''),
            workflow_id=data.get('workflow_id'),
            node_id=data.get('node_id'),
            timestamp=when,
        )
|
|
|
|
|
|
@dataclass
class Correction:
    """
    A single learned fix for a failing action.

    Holds the original and corrected target/parameters, where the correction
    came from, how often applying it succeeded or failed, and its lifecycle
    status.
    """

    id: str
    key: CorrectionKey
    correction_type: CorrectionType

    # Original context (the action before the fix)
    original_target: Optional[Dict[str, Any]] = None
    original_params: Optional[Dict[str, Any]] = None

    # Correction applied
    corrected_target: Optional[Dict[str, Any]] = None
    corrected_params: Optional[Dict[str, Any]] = None
    correction_description: str = ""

    # Source information
    source: Optional[CorrectionSource] = None

    # Application statistics
    success_count: int = 0
    failure_count: int = 0
    last_applied: Optional[datetime] = None

    # Lifecycle status
    status: CorrectionStatus = CorrectionStatus.ACTIVE

    # Metadata
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    tags: List[str] = field(default_factory=list)

    @property
    def success_rate(self) -> float:
        """Fraction of recorded applications that succeeded (0.0 if never applied)."""
        total = self.success_count + self.failure_count
        if total == 0:
            return 0.0
        return self.success_count / total

    @property
    def confidence_score(self) -> float:
        """
        Success rate weighted by the amount of evidence.

        The weight is ``0.5 + 0.5 * min(1, applications / 100)``. It grows
        linearly from 0.5 and reaches 1.0 at 100 or more applications, so a
        correction needs a high success rate and enough usage to score near
        1.0. It returns 0.0 when the correction was never applied.
        """
        total = self.success_count + self.failure_count
        if total == 0:
            return 0.0
        usage_factor = min(1.0, total / 100)
        return self.success_rate * (0.5 + 0.5 * usage_factor)

    def record_application(self, success: bool) -> None:
        """
        Record the outcome of one application of this correction.

        Args:
            success: True if the corrected action succeeded.
        """
        if success:
            self.success_count += 1
        else:
            self.failure_count += 1
        # One timestamp so last_applied and updated_at agree exactly.
        now = datetime.now()
        self.last_applied = now
        self.updated_at = now

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize to a plain dict.

        Datetimes become ISO-8601 strings and enums become their values. The
        derived ``success_rate`` and ``confidence_score`` are included for
        consumers; :meth:`from_dict` ignores them.
        """
        return {
            'id': self.id,
            'key': self.key.to_dict(),
            'correction_type': self.correction_type.value,
            'original_target': self.original_target,
            'original_params': self.original_params,
            'corrected_target': self.corrected_target,
            'corrected_params': self.corrected_params,
            'correction_description': self.correction_description,
            'source': self.source.to_dict() if self.source else None,
            'success_count': self.success_count,
            'failure_count': self.failure_count,
            'success_rate': self.success_rate,
            'confidence_score': self.confidence_score,
            'last_applied': self.last_applied.isoformat() if self.last_applied else None,
            'status': self.status.value,
            'created_at': self.created_at.isoformat(),
            'updated_at': self.updated_at.isoformat(),
            # Copy so mutating the serialized dict cannot alter this object.
            'tags': list(self.tags),
        }

    @staticmethod
    def _parse_datetime(value: Any, default: Optional[datetime]) -> Optional[datetime]:
        """Parse an ISO-8601 string, map None to *default*, pass anything else through."""
        if isinstance(value, str):
            return datetime.fromisoformat(value)
        if value is None:
            return default
        return value

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Correction':
        """
        Build a correction from a mapping produced by :meth:`to_dict`.

        Missing timestamps default to the current time (``last_applied``
        stays None). A missing or null ``key`` yields a key with default
        fields, and a missing or null ``id`` gets a fresh UUID string.

        Raises:
            ValueError: if ``correction_type`` or ``status`` is not a valid
                enum value, or a timestamp string is not ISO-8601.
        """
        now = datetime.now()
        created_at = cls._parse_datetime(data.get('created_at'), now)
        updated_at = cls._parse_datetime(data.get('updated_at'), now)
        last_applied = cls._parse_datetime(data.get('last_applied'), None)

        # `or {}` also covers an explicit "key": null in imported data.
        key = CorrectionKey.from_dict(data.get('key') or {})

        source_data = data.get('source')
        source = CorrectionSource.from_dict(source_data) if source_data else None

        correction_id = data.get('id')
        if correction_id is None:
            correction_id = str(uuid.uuid4())

        return cls(
            id=correction_id,
            key=key,
            correction_type=CorrectionType(data.get('correction_type', 'other')),
            original_target=data.get('original_target'),
            original_params=data.get('original_params'),
            corrected_target=data.get('corrected_target'),
            corrected_params=data.get('corrected_params'),
            correction_description=data.get('correction_description', ''),
            source=source,
            success_count=data.get('success_count', 0),
            failure_count=data.get('failure_count', 0),
            last_applied=last_applied,
            status=CorrectionStatus(data.get('status', 'active')),
            created_at=created_at,
            updated_at=updated_at,
            # Copy so the correction does not share a list with the input data.
            tags=list(data.get('tags') or []),
        )
|
|
|
|
|
|
@dataclass
class CorrectionPackMetadata:
    """Descriptive metadata attached to a correction pack."""

    version: str = "1.0.0"
    author: str = ""
    category: str = "general"
    tags: List[str] = field(default_factory=list)
    description: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return the metadata fields as a plain dict (``tags`` is not copied)."""
        return dict(
            version=self.version,
            author=self.author,
            category=self.category,
            tags=self.tags,
            description=self.description,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionPackMetadata':
        """Build metadata from a mapping, falling back to the defaults for missing keys."""
        kwargs = {
            'version': data.get('version', '1.0.0'),
            'author': data.get('author', ''),
            'category': data.get('category', 'general'),
            'tags': data.get('tags', []),
            'description': data.get('description', ''),
        }
        return cls(**kwargs)
|
|
|
|
|
|
@dataclass
class CorrectionPack:
    """
    A named collection of corrections that can be exported, imported, and shared.

    Corrections are stored in a dict keyed by correction ID. Aggregate
    statistics are derived from the contained corrections on demand.
    """

    id: str
    name: str
    description: str = ""

    # Corrections in this pack, keyed by correction ID
    corrections: Dict[str, Correction] = field(default_factory=dict)

    # Descriptive metadata (version string, author, category, tags, description)
    metadata: CorrectionPackMetadata = field(default_factory=CorrectionPackMetadata)

    # Timestamps
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)

    # Integer version number (separate from the metadata.version string)
    current_version: int = 1

    @property
    def correction_count(self) -> int:
        """Number of corrections in the pack, regardless of status."""
        return len(self.corrections)

    @property
    def total_applications(self) -> int:
        """Total recorded applications (successes + failures) across all corrections."""
        return sum(c.success_count + c.failure_count for c in self.corrections.values())

    @property
    def overall_success_rate(self) -> float:
        """Pooled success rate across all corrections (0.0 if nothing was applied)."""
        total_success = sum(c.success_count for c in self.corrections.values())
        total_failure = sum(c.failure_count for c in self.corrections.values())
        total = total_success + total_failure
        if total == 0:
            return 0.0
        return total_success / total

    @property
    def active_corrections(self) -> List[Correction]:
        """Corrections whose status is ACTIVE."""
        return [c for c in self.corrections.values() if c.status == CorrectionStatus.ACTIVE]

    def add_correction(self, correction: Correction) -> bool:
        """
        Add a correction to the pack.

        Returns:
            True if added, False if a correction with the same ID already exists
            (the existing one is left untouched).
        """
        if correction.id in self.corrections:
            return False
        self.corrections[correction.id] = correction
        self.updated_at = datetime.now()
        return True

    def remove_correction(self, correction_id: str) -> bool:
        """
        Remove a correction from the pack.

        Returns:
            True if removed, False if no correction has that ID.
        """
        if correction_id not in self.corrections:
            return False
        del self.corrections[correction_id]
        self.updated_at = datetime.now()
        return True

    def get_correction(self, correction_id: str) -> Optional[Correction]:
        """Return the correction with the given ID, or None."""
        return self.corrections.get(correction_id)

    def find_by_key_hash(self, key_hash: str) -> List[Correction]:
        """Return ACTIVE corrections whose key hashes to *key_hash*."""
        return [
            c for c in self.corrections.values()
            if c.key.to_hash() == key_hash and c.status == CorrectionStatus.ACTIVE
        ]

    def merge(self, other: 'CorrectionPack', conflict_strategy: str = 'keep_higher_confidence') -> int:
        """
        Merge another pack into this one.

        Corrections whose ID is not yet present are always added. When an ID
        exists in both packs, *conflict_strategy* decides which one is kept:

        - 'keep_higher_confidence': take the incoming one only if its
          confidence_score is strictly higher (ties keep the existing one)
        - 'keep_newer': take the incoming one only if its updated_at is later
        - 'keep_existing': keep the existing one
        - 'replace': always take the incoming one

        Any other strategy string behaves like 'keep_existing'.

        Note: merged Correction objects are shared by reference with *other*,
        not copied.

        Args:
            other: Pack to merge from.
            conflict_strategy: How to resolve ID conflicts (see above).

        Returns:
            Number of corrections added or replaced.
        """
        merged_count = 0

        for correction_id, incoming in other.corrections.items():
            if correction_id not in self.corrections:
                take_incoming = True
            else:
                existing = self.corrections[correction_id]
                if conflict_strategy == 'keep_higher_confidence':
                    take_incoming = incoming.confidence_score > existing.confidence_score
                elif conflict_strategy == 'keep_newer':
                    take_incoming = incoming.updated_at > existing.updated_at
                elif conflict_strategy == 'replace':
                    take_incoming = True
                else:
                    # 'keep_existing' and unrecognised strategies.
                    take_incoming = False

            if take_incoming:
                self.corrections[correction_id] = incoming
                merged_count += 1

        if merged_count > 0:
            self.updated_at = datetime.now()

        return merged_count

    def get_statistics(self) -> Dict[str, Any]:
        """
        Compute aggregate statistics for the pack.

        Returns:
            Dict with the total/active/deprecated/disabled correction counts,
            total applications, overall success rate, a count of corrections
            per correction-type value, and the mean confidence_score of active
            corrections (0.0 when there are none).
        """
        from collections import Counter

        corrections = list(self.corrections.values())
        active = [c for c in corrections if c.status == CorrectionStatus.ACTIVE]
        # One pass for all status counts instead of one list build per status.
        status_counts = Counter(c.status for c in corrections)
        # Converted back to a plain dict so the output type is unchanged;
        # serializers such as yaml.safe_dump reject dict subclasses.
        by_type = dict(Counter(c.correction_type.value for c in corrections))

        return {
            'total_corrections': len(corrections),
            'active_corrections': len(active),
            'deprecated_corrections': status_counts[CorrectionStatus.DEPRECATED],
            'disabled_corrections': status_counts[CorrectionStatus.DISABLED],
            'total_applications': self.total_applications,
            'overall_success_rate': self.overall_success_rate,
            'corrections_by_type': by_type,
            'average_confidence': sum(c.confidence_score for c in active) / len(active) if active else 0.0,
        }

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize the pack, including every correction and a ``statistics`` snapshot.

        The ``statistics`` entry is derived data and is ignored by :meth:`from_dict`.
        """
        return {
            'id': self.id,
            'name': self.name,
            'description': self.description,
            'corrections': {cid: c.to_dict() for cid, c in self.corrections.items()},
            'metadata': self.metadata.to_dict(),
            'created_at': self.created_at.isoformat(),
            'updated_at': self.updated_at.isoformat(),
            'current_version': self.current_version,
            'statistics': self.get_statistics(),
        }

    @staticmethod
    def _parse_timestamp(value: Any, default: datetime) -> datetime:
        """Parse an ISO-8601 string, map None to *default*, pass anything else through."""
        if isinstance(value, str):
            return datetime.fromisoformat(value)
        if value is None:
            return default
        return value

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionPack':
        """
        Build a pack from a mapping produced by :meth:`to_dict`.

        Missing timestamps default to the current time. Missing or null
        ``corrections`` / ``metadata`` become an empty pack / default metadata.
        Corrections are keyed by the IDs used in the input mapping.
        """
        now = datetime.now()
        created_at = cls._parse_timestamp(data.get('created_at'), now)
        updated_at = cls._parse_timestamp(data.get('updated_at'), now)

        # `or {}` also tolerates explicit nulls in hand-edited or imported files.
        corrections_data = data.get('corrections') or {}
        corrections = {
            cid: Correction.from_dict(cdata)
            for cid, cdata in corrections_data.items()
        }

        metadata = CorrectionPackMetadata.from_dict(data.get('metadata') or {})

        return cls(
            id=data.get('id', str(uuid.uuid4())),
            name=data.get('name', 'Unnamed Pack'),
            description=data.get('description', ''),
            corrections=corrections,
            metadata=metadata,
            created_at=created_at,
            updated_at=updated_at,
            current_version=data.get('current_version', 1),
        )
|
|
|
|
|
|
def generate_correction_id() -> str:
    """Return a new random correction ID: ``"corr_"`` plus 12 hex digits."""
    random_suffix = uuid.uuid4().hex[:12]
    return "corr_" + random_suffix
|
|
|
|
|
|
def generate_pack_id() -> str:
    """Return a new random pack ID: ``"pack_"`` plus 12 hex digits."""
    random_suffix = uuid.uuid4().hex[:12]
    return "pack_" + random_suffix
|