feat(corrections): Add Correction Packs system for cross-workflow learning
Implement a complete system for capitalizing user corrections across multiple workflows and sessions. This enables automatic application of learned fixes when similar failures occur in different contexts. New components: - core/corrections/models.py: CorrectionKey, Correction, CorrectionPack models - core/corrections/correction_repository.py: JSON storage with atomic writes - core/corrections/aggregator.py: Aggregation by hash and quality filtering - core/corrections/correction_pack_service.py: CRUD, export/import, versioning - backend/api/correction_packs.py: REST API with 15 endpoints Features: - MD5-based key hashing for correction deduplication - Export/import in JSON and YAML formats - Version history with rollback support - Cross-workflow pattern detection - Integration with SelfHealingEngine for automatic application - 29 unit tests (all passing) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
480
core/corrections/models.py
Normal file
480
core/corrections/models.py
Normal file
@@ -0,0 +1,480 @@
|
||||
"""
|
||||
Data models for the Correction Packs system.
|
||||
|
||||
Defines CorrectionKey, Correction, CorrectionPack and related types.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import uuid
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
|
||||
class CorrectionType(Enum):
|
||||
"""Types of corrections that can be applied."""
|
||||
TARGET_CHANGE = "target_change" # Changed target selector
|
||||
PARAMETER_CHANGE = "parameter_change" # Changed action parameters
|
||||
ACTION_REPLACE = "action_replace" # Replaced action type
|
||||
WAIT_ADDED = "wait_added" # Added wait before action
|
||||
RETRY_CONFIG = "retry_config" # Changed retry configuration
|
||||
FALLBACK_ADDED = "fallback_added" # Added fallback action
|
||||
COORDINATES_ADJUST = "coordinates_adjust" # Adjusted click coordinates
|
||||
TIMING_ADJUST = "timing_adjust" # Adjusted timing/delay
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
class CorrectionStatus(Enum):
|
||||
"""Status of a correction."""
|
||||
ACTIVE = "active" # Correction is active and can be applied
|
||||
DEPRECATED = "deprecated" # Correction is deprecated
|
||||
DISABLED = "disabled" # Correction is manually disabled
|
||||
|
||||
|
||||
@dataclass
|
||||
class CorrectionKey:
|
||||
"""
|
||||
Unique key to identify similar corrections across workflows.
|
||||
|
||||
Uses action_type + element_type + failure_context to generate MD5 hash.
|
||||
Similar to LearningRepository pattern key generation.
|
||||
"""
|
||||
action_type: str
|
||||
element_type: str
|
||||
failure_context: str
|
||||
|
||||
def to_hash(self) -> str:
|
||||
"""Generate MD5 hash of the correction key (16 chars)."""
|
||||
key_string = f"{self.action_type}|{self.element_type}|{self.failure_context}"
|
||||
return hashlib.md5(key_string.encode()).hexdigest()[:16]
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
'action_type': self.action_type,
|
||||
'element_type': self.element_type,
|
||||
'failure_context': self.failure_context,
|
||||
'hash': self.to_hash()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionKey':
|
||||
"""Create from dictionary."""
|
||||
return cls(
|
||||
action_type=data.get('action_type', 'unknown'),
|
||||
element_type=data.get('element_type', 'unknown'),
|
||||
failure_context=data.get('failure_context', '')
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CorrectionSource:
|
||||
"""Source information for a correction."""
|
||||
session_id: str
|
||||
workflow_id: Optional[str] = None
|
||||
node_id: Optional[str] = None
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
'session_id': self.session_id,
|
||||
'workflow_id': self.workflow_id,
|
||||
'node_id': self.node_id,
|
||||
'timestamp': self.timestamp.isoformat()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionSource':
|
||||
"""Create from dictionary."""
|
||||
timestamp = data.get('timestamp')
|
||||
if isinstance(timestamp, str):
|
||||
timestamp = datetime.fromisoformat(timestamp)
|
||||
elif timestamp is None:
|
||||
timestamp = datetime.now()
|
||||
|
||||
return cls(
|
||||
session_id=data.get('session_id', ''),
|
||||
workflow_id=data.get('workflow_id'),
|
||||
node_id=data.get('node_id'),
|
||||
timestamp=timestamp
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Correction:
|
||||
"""
|
||||
Individual correction record.
|
||||
|
||||
Contains the original context, the applied correction, source information,
|
||||
and statistics about success/failure.
|
||||
"""
|
||||
id: str
|
||||
key: CorrectionKey
|
||||
correction_type: CorrectionType
|
||||
|
||||
# Original context
|
||||
original_target: Optional[Dict[str, Any]] = None
|
||||
original_params: Optional[Dict[str, Any]] = None
|
||||
|
||||
# Correction applied
|
||||
corrected_target: Optional[Dict[str, Any]] = None
|
||||
corrected_params: Optional[Dict[str, Any]] = None
|
||||
correction_description: str = ""
|
||||
|
||||
# Source information
|
||||
source: Optional[CorrectionSource] = None
|
||||
|
||||
# Statistics
|
||||
success_count: int = 0
|
||||
failure_count: int = 0
|
||||
last_applied: Optional[datetime] = None
|
||||
|
||||
# Status
|
||||
status: CorrectionStatus = CorrectionStatus.ACTIVE
|
||||
|
||||
# Metadata
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
updated_at: datetime = field(default_factory=datetime.now)
|
||||
tags: List[str] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def success_rate(self) -> float:
|
||||
"""Calculate success rate."""
|
||||
total = self.success_count + self.failure_count
|
||||
if total == 0:
|
||||
return 0.0
|
||||
return self.success_count / total
|
||||
|
||||
@property
|
||||
def confidence_score(self) -> float:
|
||||
"""Calculate confidence score based on success rate and usage count."""
|
||||
total = self.success_count + self.failure_count
|
||||
if total == 0:
|
||||
return 0.0
|
||||
# Confidence increases with more data points (up to 100 uses)
|
||||
usage_factor = min(1.0, total / 100)
|
||||
return self.success_rate * (0.5 + 0.5 * usage_factor)
|
||||
|
||||
def record_application(self, success: bool):
|
||||
"""Record an application of this correction."""
|
||||
if success:
|
||||
self.success_count += 1
|
||||
else:
|
||||
self.failure_count += 1
|
||||
self.last_applied = datetime.now()
|
||||
self.updated_at = datetime.now()
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
'id': self.id,
|
||||
'key': self.key.to_dict(),
|
||||
'correction_type': self.correction_type.value,
|
||||
'original_target': self.original_target,
|
||||
'original_params': self.original_params,
|
||||
'corrected_target': self.corrected_target,
|
||||
'corrected_params': self.corrected_params,
|
||||
'correction_description': self.correction_description,
|
||||
'source': self.source.to_dict() if self.source else None,
|
||||
'success_count': self.success_count,
|
||||
'failure_count': self.failure_count,
|
||||
'success_rate': self.success_rate,
|
||||
'confidence_score': self.confidence_score,
|
||||
'last_applied': self.last_applied.isoformat() if self.last_applied else None,
|
||||
'status': self.status.value,
|
||||
'created_at': self.created_at.isoformat(),
|
||||
'updated_at': self.updated_at.isoformat(),
|
||||
'tags': self.tags
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'Correction':
|
||||
"""Create from dictionary."""
|
||||
# Parse datetime fields
|
||||
created_at = data.get('created_at')
|
||||
if isinstance(created_at, str):
|
||||
created_at = datetime.fromisoformat(created_at)
|
||||
elif created_at is None:
|
||||
created_at = datetime.now()
|
||||
|
||||
updated_at = data.get('updated_at')
|
||||
if isinstance(updated_at, str):
|
||||
updated_at = datetime.fromisoformat(updated_at)
|
||||
elif updated_at is None:
|
||||
updated_at = datetime.now()
|
||||
|
||||
last_applied = data.get('last_applied')
|
||||
if isinstance(last_applied, str):
|
||||
last_applied = datetime.fromisoformat(last_applied)
|
||||
|
||||
# Parse nested objects
|
||||
key_data = data.get('key', {})
|
||||
key = CorrectionKey.from_dict(key_data)
|
||||
|
||||
source_data = data.get('source')
|
||||
source = CorrectionSource.from_dict(source_data) if source_data else None
|
||||
|
||||
return cls(
|
||||
id=data.get('id', str(uuid.uuid4())),
|
||||
key=key,
|
||||
correction_type=CorrectionType(data.get('correction_type', 'other')),
|
||||
original_target=data.get('original_target'),
|
||||
original_params=data.get('original_params'),
|
||||
corrected_target=data.get('corrected_target'),
|
||||
corrected_params=data.get('corrected_params'),
|
||||
correction_description=data.get('correction_description', ''),
|
||||
source=source,
|
||||
success_count=data.get('success_count', 0),
|
||||
failure_count=data.get('failure_count', 0),
|
||||
last_applied=last_applied,
|
||||
status=CorrectionStatus(data.get('status', 'active')),
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
tags=data.get('tags', [])
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CorrectionPackMetadata:
|
||||
"""Metadata for a correction pack."""
|
||||
version: str = "1.0.0"
|
||||
author: str = ""
|
||||
category: str = "general"
|
||||
tags: List[str] = field(default_factory=list)
|
||||
description: str = ""
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
'version': self.version,
|
||||
'author': self.author,
|
||||
'category': self.category,
|
||||
'tags': self.tags,
|
||||
'description': self.description
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionPackMetadata':
|
||||
"""Create from dictionary."""
|
||||
return cls(
|
||||
version=data.get('version', '1.0.0'),
|
||||
author=data.get('author', ''),
|
||||
category=data.get('category', 'general'),
|
||||
tags=data.get('tags', []),
|
||||
description=data.get('description', '')
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CorrectionPack:
|
||||
"""
|
||||
A pack of corrections that can be exported, imported, and shared.
|
||||
|
||||
Contains multiple corrections with metadata and statistics.
|
||||
"""
|
||||
id: str
|
||||
name: str
|
||||
description: str = ""
|
||||
|
||||
# Corrections in this pack
|
||||
corrections: Dict[str, Correction] = field(default_factory=dict)
|
||||
|
||||
# Metadata
|
||||
metadata: CorrectionPackMetadata = field(default_factory=CorrectionPackMetadata)
|
||||
|
||||
# Timestamps
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
updated_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
# Version history
|
||||
current_version: int = 1
|
||||
|
||||
@property
|
||||
def correction_count(self) -> int:
|
||||
"""Number of corrections in the pack."""
|
||||
return len(self.corrections)
|
||||
|
||||
@property
|
||||
def total_applications(self) -> int:
|
||||
"""Total number of times corrections were applied."""
|
||||
return sum(c.success_count + c.failure_count for c in self.corrections.values())
|
||||
|
||||
@property
|
||||
def overall_success_rate(self) -> float:
|
||||
"""Overall success rate across all corrections."""
|
||||
total_success = sum(c.success_count for c in self.corrections.values())
|
||||
total_failure = sum(c.failure_count for c in self.corrections.values())
|
||||
total = total_success + total_failure
|
||||
if total == 0:
|
||||
return 0.0
|
||||
return total_success / total
|
||||
|
||||
@property
|
||||
def active_corrections(self) -> List[Correction]:
|
||||
"""Get only active corrections."""
|
||||
return [c for c in self.corrections.values() if c.status == CorrectionStatus.ACTIVE]
|
||||
|
||||
def add_correction(self, correction: Correction) -> bool:
|
||||
"""
|
||||
Add a correction to the pack.
|
||||
|
||||
Returns True if added, False if already exists.
|
||||
"""
|
||||
if correction.id in self.corrections:
|
||||
return False
|
||||
self.corrections[correction.id] = correction
|
||||
self.updated_at = datetime.now()
|
||||
return True
|
||||
|
||||
def remove_correction(self, correction_id: str) -> bool:
|
||||
"""
|
||||
Remove a correction from the pack.
|
||||
|
||||
Returns True if removed, False if not found.
|
||||
"""
|
||||
if correction_id not in self.corrections:
|
||||
return False
|
||||
del self.corrections[correction_id]
|
||||
self.updated_at = datetime.now()
|
||||
return True
|
||||
|
||||
def get_correction(self, correction_id: str) -> Optional[Correction]:
|
||||
"""Get a correction by ID."""
|
||||
return self.corrections.get(correction_id)
|
||||
|
||||
def find_by_key_hash(self, key_hash: str) -> List[Correction]:
|
||||
"""Find corrections matching a key hash."""
|
||||
return [
|
||||
c for c in self.corrections.values()
|
||||
if c.key.to_hash() == key_hash and c.status == CorrectionStatus.ACTIVE
|
||||
]
|
||||
|
||||
def merge(self, other: 'CorrectionPack', conflict_strategy: str = 'keep_higher_confidence') -> int:
|
||||
"""
|
||||
Merge another pack into this one.
|
||||
|
||||
Args:
|
||||
other: Pack to merge
|
||||
conflict_strategy: How to handle conflicts:
|
||||
- 'keep_higher_confidence': Keep correction with higher confidence
|
||||
- 'keep_newer': Keep the most recently updated correction
|
||||
- 'keep_existing': Keep existing, ignore new
|
||||
- 'replace': Always replace with new
|
||||
|
||||
Returns:
|
||||
Number of corrections added/updated
|
||||
"""
|
||||
merged_count = 0
|
||||
|
||||
for correction_id, correction in other.corrections.items():
|
||||
if correction_id in self.corrections:
|
||||
existing = self.corrections[correction_id]
|
||||
|
||||
should_replace = False
|
||||
if conflict_strategy == 'keep_higher_confidence':
|
||||
should_replace = correction.confidence_score > existing.confidence_score
|
||||
elif conflict_strategy == 'keep_newer':
|
||||
should_replace = correction.updated_at > existing.updated_at
|
||||
elif conflict_strategy == 'replace':
|
||||
should_replace = True
|
||||
# 'keep_existing' doesn't replace
|
||||
|
||||
if should_replace:
|
||||
self.corrections[correction_id] = correction
|
||||
merged_count += 1
|
||||
else:
|
||||
self.corrections[correction_id] = correction
|
||||
merged_count += 1
|
||||
|
||||
if merged_count > 0:
|
||||
self.updated_at = datetime.now()
|
||||
|
||||
return merged_count
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""Get aggregate statistics for the pack."""
|
||||
corrections = list(self.corrections.values())
|
||||
active = [c for c in corrections if c.status == CorrectionStatus.ACTIVE]
|
||||
|
||||
# Group by correction type
|
||||
by_type = {}
|
||||
for c in corrections:
|
||||
type_name = c.correction_type.value
|
||||
if type_name not in by_type:
|
||||
by_type[type_name] = 0
|
||||
by_type[type_name] += 1
|
||||
|
||||
return {
|
||||
'total_corrections': len(corrections),
|
||||
'active_corrections': len(active),
|
||||
'deprecated_corrections': len([c for c in corrections if c.status == CorrectionStatus.DEPRECATED]),
|
||||
'disabled_corrections': len([c for c in corrections if c.status == CorrectionStatus.DISABLED]),
|
||||
'total_applications': self.total_applications,
|
||||
'overall_success_rate': self.overall_success_rate,
|
||||
'corrections_by_type': by_type,
|
||||
'average_confidence': sum(c.confidence_score for c in active) / len(active) if active else 0.0
|
||||
}
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
'id': self.id,
|
||||
'name': self.name,
|
||||
'description': self.description,
|
||||
'corrections': {cid: c.to_dict() for cid, c in self.corrections.items()},
|
||||
'metadata': self.metadata.to_dict(),
|
||||
'created_at': self.created_at.isoformat(),
|
||||
'updated_at': self.updated_at.isoformat(),
|
||||
'current_version': self.current_version,
|
||||
'statistics': self.get_statistics()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'CorrectionPack':
|
||||
"""Create from dictionary."""
|
||||
# Parse datetime fields
|
||||
created_at = data.get('created_at')
|
||||
if isinstance(created_at, str):
|
||||
created_at = datetime.fromisoformat(created_at)
|
||||
elif created_at is None:
|
||||
created_at = datetime.now()
|
||||
|
||||
updated_at = data.get('updated_at')
|
||||
if isinstance(updated_at, str):
|
||||
updated_at = datetime.fromisoformat(updated_at)
|
||||
elif updated_at is None:
|
||||
updated_at = datetime.now()
|
||||
|
||||
# Parse corrections
|
||||
corrections_data = data.get('corrections', {})
|
||||
corrections = {
|
||||
cid: Correction.from_dict(cdata)
|
||||
for cid, cdata in corrections_data.items()
|
||||
}
|
||||
|
||||
# Parse metadata
|
||||
metadata_data = data.get('metadata', {})
|
||||
metadata = CorrectionPackMetadata.from_dict(metadata_data)
|
||||
|
||||
return cls(
|
||||
id=data.get('id', str(uuid.uuid4())),
|
||||
name=data.get('name', 'Unnamed Pack'),
|
||||
description=data.get('description', ''),
|
||||
corrections=corrections,
|
||||
metadata=metadata,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
current_version=data.get('current_version', 1)
|
||||
)
|
||||
|
||||
|
||||
def generate_correction_id() -> str:
|
||||
"""Generate a unique correction ID."""
|
||||
return f"corr_{uuid.uuid4().hex[:12]}"
|
||||
|
||||
|
||||
def generate_pack_id() -> str:
|
||||
"""Generate a unique pack ID."""
|
||||
return f"pack_{uuid.uuid4().hex[:12]}"
|
||||
Reference in New Issue
Block a user