feat(corrections): Add Correction Packs system for cross-workflow learning

Implement a complete system for capitalizing on user corrections across
multiple workflows and sessions. This enables learned fixes to be applied
automatically when similar failures occur in different contexts.

New components:
- core/corrections/models.py: CorrectionKey, Correction, CorrectionPack models
- core/corrections/correction_repository.py: JSON storage with atomic writes
- core/corrections/aggregator.py: Aggregation by hash and quality filtering
- core/corrections/correction_pack_service.py: CRUD, export/import, versioning
- backend/api/correction_packs.py: REST API with 15 endpoints

Features:
- MD5-based key hashing for correction deduplication
- Export/import in JSON and YAML formats
- Version history with rollback support
- Cross-workflow pattern detection
- Integration with SelfHealingEngine for automatic application
- 29 unit tests (all passing)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Dom
2026-01-18 18:48:35 +01:00
parent fa57ecdbfd
commit d8756883c5
9 changed files with 4411 additions and 0 deletions

View File

@@ -0,0 +1,644 @@
"""
Tests for the Correction Packs system.
Tests models, repository, aggregator, and service.
"""
import json
import os
import shutil
import tempfile
import unittest
from datetime import datetime
from pathlib import Path
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from core.corrections.models import (
CorrectionKey,
CorrectionType,
CorrectionStatus,
Correction,
CorrectionPack,
CorrectionPackMetadata,
CorrectionSource,
generate_correction_id,
generate_pack_id
)
from core.corrections.correction_repository import CorrectionRepository
from core.corrections.aggregator import CorrectionAggregator, CrossWorkflowAggregator
from core.corrections.correction_pack_service import CorrectionPackService
class TestCorrectionKey(unittest.TestCase):
    """Unit tests for hashing and serialization of ``CorrectionKey``."""

    def test_to_hash(self):
        """Equal keys share a hash, different keys do not, length is 16."""
        first = CorrectionKey("click", "button", "element_not_found")
        twin = CorrectionKey("click", "button", "element_not_found")
        other = CorrectionKey("type", "input", "element_not_found")

        # Two keys built from identical fields must hash identically.
        self.assertEqual(first.to_hash(), twin.to_hash())

        # Keys that differ in their fields must not collide here.
        self.assertNotEqual(first.to_hash(), other.to_hash())

        # The hash string is expected to be exactly 16 characters long.
        self.assertEqual(len(first.to_hash()), 16)

    def test_to_dict_from_dict(self):
        """A key survives a ``to_dict`` / ``from_dict`` round-trip."""
        original = CorrectionKey("click", "button", "timeout")
        payload = original.to_dict()

        # Each constructor argument must appear under its own dict entry.
        expected_fields = (
            ('action_type', "click"),
            ('element_type', "button"),
            ('failure_context', "timeout"),
        )
        for field_name, field_value in expected_fields:
            self.assertEqual(payload[field_name], field_value)
        self.assertIn('hash', payload)

        # Rebuilding from the dict must yield a key with the same hash.
        rebuilt = CorrectionKey.from_dict(payload)
        self.assertEqual(original.to_hash(), rebuilt.to_hash())
class TestCorrection(unittest.TestCase):
    """Tests for the ``Correction`` model.

    Covers construction defaults, the derived ``success_rate`` value,
    ``record_application`` bookkeeping, and dict serialization.
    """

    def test_create_correction(self):
        """A freshly built correction has a ``corr_`` id and a zero rate."""
        key = CorrectionKey("click", "button", "not_visible")
        correction = Correction(
            id=generate_correction_id(),
            key=key,
            correction_type=CorrectionType.WAIT_ADDED,
            original_target={"selector": "#submit"},
            corrected_params={"wait_time": 2.0},
            correction_description="Added wait for button visibility",
        )

        self.assertTrue(correction.id.startswith("corr_"))
        self.assertEqual(correction.correction_type, CorrectionType.WAIT_ADDED)
        # No applications recorded yet, so the rate is expected to be 0.0.
        self.assertEqual(correction.success_rate, 0.0)

    def test_success_rate_calculation(self):
        """7 successes and 3 failures give a success rate of 0.7."""
        key = CorrectionKey("click", "button", "not_visible")
        correction = Correction(
            id=generate_correction_id(),
            key=key,
            correction_type=CorrectionType.WAIT_ADDED,
            success_count=7,
            failure_count=3,
        )

        # The rate is a computed float; compare with a tolerance so the test
        # does not depend on the exact arithmetic used to derive it.
        self.assertAlmostEqual(correction.success_rate, 0.7)

    def test_record_application(self):
        """``record_application`` bumps the matching counter."""
        key = CorrectionKey("click", "button", "not_visible")
        correction = Correction(
            id=generate_correction_id(),
            key=key,
            correction_type=CorrectionType.WAIT_ADDED,
        )

        # A successful application increments only the success counter and
        # sets ``last_applied``.
        correction.record_application(success=True)
        self.assertEqual(correction.success_count, 1)
        self.assertEqual(correction.failure_count, 0)
        self.assertIsNotNone(correction.last_applied)

        # A failed application increments only the failure counter.
        correction.record_application(success=False)
        self.assertEqual(correction.success_count, 1)
        self.assertEqual(correction.failure_count, 1)

    def test_to_dict_from_dict(self):
        """A correction survives a ``to_dict`` / ``from_dict`` round-trip."""
        key = CorrectionKey("type", "input", "validation_error")
        source = CorrectionSource(
            session_id="sess_123",
            workflow_id="wf_456",
            node_id="node_789",
        )
        correction = Correction(
            id="corr_test123",
            key=key,
            correction_type=CorrectionType.PARAMETER_CHANGE,
            original_params={"value": "test"},
            corrected_params={"value": "corrected"},
            source=source,
            success_count=5,
            failure_count=2,
            tags=["form", "validation"],
        )

        correction_dict = correction.to_dict()
        restored = Correction.from_dict(correction_dict)

        # Identity, key hash, type, counters and tags must be preserved.
        self.assertEqual(restored.id, correction.id)
        self.assertEqual(restored.key.to_hash(), key.to_hash())
        self.assertEqual(restored.correction_type, CorrectionType.PARAMETER_CHANGE)
        self.assertEqual(restored.success_count, 5)
        self.assertEqual(restored.tags, ["form", "validation"])
class TestCorrectionPack(unittest.TestCase):
    """Unit tests for the in-memory ``CorrectionPack`` model."""

    def test_create_pack(self):
        """A new pack has a ``pack_`` id, its name, and no corrections."""
        new_pack = CorrectionPack(
            id=generate_pack_id(),
            name="Test Pack",
            description="A test correction pack",
        )

        self.assertTrue(new_pack.id.startswith("pack_"))
        self.assertEqual(new_pack.name, "Test Pack")
        self.assertEqual(new_pack.correction_count, 0)

    def test_add_remove_correction(self):
        """Adding/removing reports success once and ``False`` on repeats."""
        bundle = CorrectionPack(id=generate_pack_id(), name="Test Pack")
        entry = Correction(
            id="corr_test1",
            key=CorrectionKey("click", "button", "not_found"),
            correction_type=CorrectionType.TARGET_CHANGE,
        )

        # First insertion succeeds and is reflected in the count.
        first_add = bundle.add_correction(entry)
        self.assertTrue(first_add)
        self.assertEqual(bundle.correction_count, 1)

        # Inserting the very same correction again is rejected.
        second_add = bundle.add_correction(entry)
        self.assertFalse(second_add)

        # Removal by id succeeds and empties the pack.
        first_removal = bundle.remove_correction("corr_test1")
        self.assertTrue(first_removal)
        self.assertEqual(bundle.correction_count, 0)

        # Removing an id that is no longer present is rejected.
        second_removal = bundle.remove_correction("corr_test1")
        self.assertFalse(second_removal)

    def test_find_by_key_hash(self):
        """Lookup by key hash returns only the matching correction."""
        bundle = CorrectionPack(id=generate_pack_id(), name="Test Pack")
        click_key = CorrectionKey("click", "button", "not_found")
        type_key = CorrectionKey("type", "input", "validation")
        entries = (
            Correction(id="corr_1", key=click_key, correction_type=CorrectionType.TARGET_CHANGE),
            Correction(id="corr_2", key=type_key, correction_type=CorrectionType.PARAMETER_CHANGE),
        )
        for entry in entries:
            bundle.add_correction(entry)

        # Searching with the first key's hash must yield just "corr_1".
        matches = bundle.find_by_key_hash(click_key.to_hash())
        self.assertEqual(len(matches), 1)
        self.assertEqual(matches[0].id, "corr_1")

    def test_merge_packs(self):
        """Merging pulls the other pack's correction into the target."""
        target = CorrectionPack(id="pack_1", name="Pack 1")
        donor = CorrectionPack(id="pack_2", name="Pack 2")
        shared_key = CorrectionKey("click", "button", "error")

        target.add_correction(
            Correction(id="corr_1", key=shared_key, correction_type=CorrectionType.TARGET_CHANGE)
        )
        donor.add_correction(
            Correction(id="corr_2", key=shared_key, correction_type=CorrectionType.WAIT_ADDED)
        )

        # One correction comes across; the target then holds two.
        transferred = target.merge(donor)
        self.assertEqual(transferred, 1)
        self.assertEqual(target.correction_count, 2)

    def test_get_statistics(self):
        """Statistics report correction totals and summed applications."""
        bundle = CorrectionPack(id=generate_pack_id(), name="Test Pack")
        bundle.add_correction(
            Correction(
                id="corr_1",
                key=CorrectionKey("click", "button", "error"),
                correction_type=CorrectionType.TARGET_CHANGE,
                success_count=10,
                failure_count=2,
            )
        )

        report = bundle.get_statistics()

        # 10 successes + 2 failures are expected to total 12 applications.
        expected_values = (
            ('total_corrections', 1),
            ('active_corrections', 1),
            ('total_applications', 12),
        )
        for stat_name, stat_value in expected_values:
            self.assertEqual(report[stat_name], stat_value)
class TestCorrectionRepository(unittest.TestCase):
    """Tests for ``CorrectionRepository`` against a throwaway directory."""

    def setUp(self):
        """Create a temporary directory and a repository rooted in it.

        Removal of the directory is registered with ``addCleanup`` right
        after it is created. Cleanups run even when ``setUp`` itself fails
        (``tearDown`` does not), so the directory is no longer leaked if the
        repository constructor raises.
        """
        self.test_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.test_dir)
        self.repo = CorrectionRepository(self.test_dir)

    def test_create_and_get_pack(self):
        """A created pack can be fetched back by its id."""
        pack = self.repo.create_pack("Test Pack", "Description", {
            'category': 'testing',
            'tags': ['unit', 'test'],
        })
        self.assertIsNotNone(pack)
        self.assertEqual(pack.name, "Test Pack")

        retrieved = self.repo.get_pack(pack.id)
        self.assertIsNotNone(retrieved)
        self.assertEqual(retrieved.name, "Test Pack")

    def test_list_packs(self):
        """Listing returns every pack that was created."""
        self.repo.create_pack("Pack 1", "First pack")
        self.repo.create_pack("Pack 2", "Second pack")

        packs = self.repo.list_packs()
        self.assertEqual(len(packs), 2)

    def test_update_pack(self):
        """Updating the name returns a pack carrying the new name."""
        pack = self.repo.create_pack("Original Name")

        updated = self.repo.update_pack(pack.id, {'name': 'New Name'})
        self.assertIsNotNone(updated)
        self.assertEqual(updated.name, 'New Name')

    def test_delete_pack(self):
        """Deleting succeeds once; the pack is then gone."""
        pack = self.repo.create_pack("To Delete")

        self.assertTrue(self.repo.delete_pack(pack.id))
        self.assertIsNone(self.repo.get_pack(pack.id))
        # A second delete of the same id must report failure.
        self.assertFalse(self.repo.delete_pack(pack.id))

    def test_add_and_find_correction(self):
        """A stored correction is found again via its context triple."""
        pack = self.repo.create_pack("Test Pack")
        key = CorrectionKey("click", "button", "timeout")
        correction = Correction(
            id=generate_correction_id(),
            key=key,
            correction_type=CorrectionType.TIMING_ADJUST,
        )
        self.assertTrue(self.repo.add_correction(pack.id, correction))

        # Look the correction up by the same (action, element, failure) triple.
        found = self.repo.find_by_context("click", "button", "timeout")
        self.assertEqual(len(found), 1)
        # Each result is indexable; the correction sits at position 1.
        self.assertEqual(found[0][1].id, correction.id)

    def test_export_import_pack(self):
        """An exported pack can be imported as a new, distinct pack."""
        # Create a pack holding a single correction.
        pack = self.repo.create_pack("Export Test")
        key = CorrectionKey("click", "button", "error")
        correction = Correction(
            id="corr_export",
            key=key,
            correction_type=CorrectionType.TARGET_CHANGE,
        )
        self.repo.add_correction(pack.id, correction)

        # Export it.
        exported = self.repo.export_pack(pack.id)
        self.assertIsNotNone(exported)

        # Import as a new pack: the id must differ, the content must not.
        imported = self.repo.import_pack(exported, merge_strategy='create_new')
        self.assertIsNotNone(imported)
        self.assertNotEqual(imported.id, pack.id)
        self.assertEqual(imported.correction_count, 1)

    def test_version_and_rollback(self):
        """Rolling back to a version restores that version's corrections."""
        pack = self.repo.create_pack("Version Test")

        # Add the first correction and snapshot the pack as version 1.
        key = CorrectionKey("click", "button", "error")
        correction1 = Correction(id="corr_v1", key=key, correction_type=CorrectionType.TARGET_CHANGE)
        self.repo.add_correction(pack.id, correction1)
        version = self.repo.create_version(pack.id, "First version")
        self.assertEqual(version, 1)

        # Add a second correction after the snapshot.
        correction2 = Correction(id="corr_v2", key=key, correction_type=CorrectionType.WAIT_ADDED)
        self.repo.add_correction(pack.id, correction2)

        # Re-read the pack: it now holds two corrections.
        pack = self.repo.get_pack(pack.id)
        self.assertEqual(pack.correction_count, 2)

        # Roll back to version 1.
        self.assertTrue(self.repo.rollback_pack(pack.id, 1))

        # Re-read again: only the first correction should remain.
        pack = self.repo.get_pack(pack.id)
        self.assertEqual(pack.correction_count, 1)
class TestCorrectionAggregator(unittest.TestCase):
    """Unit tests for ``CorrectionAggregator`` merging, filtering, ranking."""

    def setUp(self):
        """Build an aggregator with the thresholds shared by these tests."""
        thresholds = {
            'min_occurrences': 2,
            'min_success_rate': 0.5,
            'min_confidence': 0.3,
        }
        self.aggregator = CorrectionAggregator(**thresholds)

    def test_aggregate_by_key(self):
        """Corrections sharing one key collapse into a single entry."""
        shared_key = CorrectionKey("click", "button", "not_found")

        # Three corrections, all with the same key and the same counters.
        batch = []
        for i in range(3):
            batch.append(
                Correction(
                    id=f"corr_{i}",
                    key=shared_key,
                    correction_type=CorrectionType.WAIT_ADDED,
                    success_count=5,
                    failure_count=2,
                )
            )

        merged = self.aggregator.aggregate_corrections(batch)

        # One merged correction whose success count is 3 * 5 = 15.
        self.assertEqual(len(merged), 1)
        self.assertEqual(merged[0].success_count, 15)

    def test_filter_by_quality(self):
        """Only corrections meeting the success-rate threshold survive."""
        shared_key = CorrectionKey("click", "button", "error")
        # 9 successes / 1 failure.
        strong = Correction(
            id="corr_high",
            key=shared_key,
            correction_type=CorrectionType.TARGET_CHANGE,
            success_count=9,
            failure_count=1,
        )
        # 1 success / 9 failures.
        weak = Correction(
            id="corr_low",
            key=shared_key,
            correction_type=CorrectionType.TARGET_CHANGE,
            success_count=1,
            failure_count=9,
        )
        candidates = [strong, weak]

        survivors = self.aggregator.filter_by_quality(
            candidates,
            min_success_rate=0.7,
        )

        self.assertEqual(len(survivors), 1)
        self.assertEqual(survivors[0].id, "corr_high")

    def test_rank_corrections(self):
        """Ranking by success rate orders corrections best-first."""
        shared_key = CorrectionKey("click", "button", "error")
        # (id, successes, failures) for each correction to rank.
        specs = (
            ("corr_1", 3, 7),
            ("corr_2", 8, 2),
            ("corr_3", 5, 5),
        )
        candidates = [
            Correction(
                id=corr_id,
                key=shared_key,
                correction_type=CorrectionType.TARGET_CHANGE,
                success_count=wins,
                failure_count=losses,
            )
            for corr_id, wins, losses in specs
        ]

        ordered = self.aggregator.rank_corrections(candidates, sort_by='success_rate')

        # Expected order: 8/10, then 5/10, then 3/10.
        for position, expected_id in enumerate(("corr_2", "corr_3", "corr_1")):
            self.assertEqual(ordered[position].id, expected_id)
class TestCorrectionPackService(unittest.TestCase):
    """Tests for ``CorrectionPackService`` using throwaway directories."""

    def setUp(self):
        """Create two temporary directories and a service built on them.

        Each directory is registered for removal with ``addCleanup`` as soon
        as it exists. Cleanups run even if a later step of ``setUp`` raises,
        and each one runs independently, so a failure removing one directory
        no longer prevents the other from being removed.
        """
        self.test_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.test_dir)
        self.training_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.training_dir)
        self.service = CorrectionPackService(self.test_dir, self.training_dir)

    def test_create_and_list_packs(self):
        """A created pack is returned as a dict and shows up in listings."""
        pack = self.service.create_pack(
            name="Service Test Pack",
            description="Testing the service",
            category="testing",
            tags=["unit", "test"],
        )
        self.assertIsNotNone(pack)
        self.assertEqual(pack['name'], "Service Test Pack")

        packs = self.service.list_packs()
        self.assertEqual(len(packs), 1)
        self.assertEqual(packs[0]['name'], "Service Test Pack")

    def test_add_correction(self):
        """Adding a correction from a plain dict yields a ``corr_`` id."""
        pack = self.service.create_pack(name="Correction Test")

        correction = self.service.add_correction(pack['id'], {
            'action_type': 'click',
            'element_type': 'button',
            'failure_context': 'element_not_found',
            'correction_type': 'target_change',
            'original_target': {'selector': '#old'},
            'corrected_target': {'selector': '#new'},
            'description': 'Updated selector',
        })

        self.assertIsNotNone(correction)
        self.assertTrue(correction['id'].startswith('corr_'))

    def test_find_applicable_corrections(self):
        """A stored correction is found for a matching failure context."""
        pack = self.service.create_pack(name="Find Test")
        self.service.add_correction(pack['id'], {
            'action_type': 'click',
            'element_type': 'button',
            'failure_context': 'timeout',
            'correction_type': 'timing_adjust',
            'corrected_params': {'wait_time': 5},
        })

        # min_confidence=0.0 so the never-applied correction is not excluded
        # by the confidence threshold.
        found = self.service.find_applicable_corrections(
            action_type='click',
            element_type='button',
            failure_context='timeout',
            min_confidence=0.0,
        )

        self.assertEqual(len(found), 1)

    def test_export_import_file(self):
        """A pack exported to a file can be imported back from that file."""
        pack = self.service.create_pack(name="File Export Test")
        self.service.add_correction(pack['id'], {
            'action_type': 'type',
            'element_type': 'input',
            'failure_context': 'validation',
            'correction_type': 'parameter_change',
        })

        # Export to a file and confirm it was written.
        export_file = os.path.join(self.test_dir, 'export.json')
        self.assertTrue(self.service.export_pack_to_file(pack['id'], export_file))
        self.assertTrue(os.path.exists(export_file))

        # Import from the same file; the single correction must come along.
        imported = self.service.import_pack_from_file(export_file)
        self.assertIsNotNone(imported)
        self.assertEqual(len(imported['corrections']), 1)

    def test_version_management(self):
        """Versions can be created, listed, and rolled back to."""
        pack = self.service.create_pack(name="Version Test")

        # Add the first correction.
        self.service.add_correction(pack['id'], {
            'action_type': 'click',
            'element_type': 'button',
            'failure_context': 'error',
            'correction_type': 'target_change',
        })

        # Snapshot the pack; the first version number is expected to be 1.
        version = self.service.create_version(pack['id'], "v1 snapshot")
        self.assertEqual(version, 1)

        # Exactly one version should be listed.
        versions = self.service.list_versions(pack['id'])
        self.assertEqual(len(versions), 1)

        # Add a second correction after the snapshot.
        self.service.add_correction(pack['id'], {
            'action_type': 'type',
            'element_type': 'input',
            'failure_context': 'error',
            'correction_type': 'parameter_change',
        })

        # The live pack now holds two corrections.
        current_pack = self.service.get_pack(pack['id'])
        self.assertEqual(len(current_pack['corrections']), 2)

        # Roll back to version 1.
        self.assertTrue(self.service.rollback_pack(pack['id'], 1))

        # Only the first correction should remain.
        rolled_back = self.service.get_pack(pack['id'])
        self.assertEqual(len(rolled_back['corrections']), 1)

    def test_statistics(self):
        """Per-pack and global statistics reflect the stored data."""
        pack = self.service.create_pack(name="Stats Test")
        self.service.add_correction(pack['id'], {
            'action_type': 'click',
            'element_type': 'button',
            'failure_context': 'error',
            'correction_type': 'target_change',
        })

        # Pack-level statistics.
        pack_stats = self.service.get_pack_statistics(pack['id'])
        self.assertIsNotNone(pack_stats)
        self.assertEqual(pack_stats['total_corrections'], 1)

        # Statistics across all packs.
        global_stats = self.service.get_global_statistics()
        self.assertIsNotNone(global_stats)
        self.assertEqual(global_stats['total_packs'], 1)
class TestCrossWorkflowAggregator(unittest.TestCase):
    """Unit tests for ``CrossWorkflowAggregator``."""

    def test_find_common_corrections(self):
        """A key seen in three workflows is reported once, with all three."""
        cross = CrossWorkflowAggregator()
        shared_key = CorrectionKey("click", "submit_button", "timeout")

        # (correction id, session id, workflow id) -- one per workflow.
        origins = (
            ("corr_1", "s1", "wf_1"),
            ("corr_2", "s2", "wf_2"),
            ("corr_3", "s3", "wf_3"),
        )
        batch = [
            Correction(
                id=corr_id,
                key=shared_key,
                correction_type=CorrectionType.WAIT_ADDED,
                source=CorrectionSource(session_id=session, workflow_id=workflow),
            )
            for corr_id, session, workflow in origins
        ]

        shared = cross.find_common_corrections(batch, min_workflows=2)

        self.assertEqual(len(shared), 1)
        # Each result unpacks into two items; only the second is checked.
        _, workflows = shared[0]
        self.assertEqual(len(workflows), 3)

    def test_detect_patterns(self):
        """Pattern detection returns patterns plus a summary with totals."""
        cross = CrossWorkflowAggregator()
        shared_key = CorrectionKey("click", "button", "error")

        batch = []
        for i in range(5):
            batch.append(
                Correction(id=f"corr_{i}", key=shared_key, correction_type=CorrectionType.WAIT_ADDED)
            )

        report = cross.detect_patterns(batch)

        for section in ('patterns', 'summary'):
            self.assertIn(section, report)
        self.assertEqual(report['summary']['total_corrections'], 5)
if __name__ == '__main__':
    # When this file is executed directly (rather than imported), hand
    # control to unittest's command-line runner for this module.
    unittest.main()