feat(corrections): Add Correction Packs system for cross-workflow learning
Implement a complete system for capitalizing user corrections across multiple workflows and sessions. This enables automatic application of learned fixes when similar failures occur in different contexts. New components: - core/corrections/models.py: CorrectionKey, Correction, CorrectionPack models - core/corrections/correction_repository.py: JSON storage with atomic writes - core/corrections/aggregator.py: Aggregation by hash and quality filtering - core/corrections/correction_pack_service.py: CRUD, export/import, versioning - backend/api/correction_packs.py: REST API with 15 endpoints Features: - MD5-based key hashing for correction deduplication - Export/import in JSON and YAML formats - Version history with rollback support - Cross-workflow pattern detection - Integration with SelfHealingEngine for automatic application - 29 unit tests (all passing) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
644
tests/test_correction_packs.py
Normal file
644
tests/test_correction_packs.py
Normal file
@@ -0,0 +1,644 @@
|
||||
"""
|
||||
Tests for the Correction Packs system.
|
||||
|
||||
Tests models, repository, aggregator, and service.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from core.corrections.models import (
|
||||
CorrectionKey,
|
||||
CorrectionType,
|
||||
CorrectionStatus,
|
||||
Correction,
|
||||
CorrectionPack,
|
||||
CorrectionPackMetadata,
|
||||
CorrectionSource,
|
||||
generate_correction_id,
|
||||
generate_pack_id
|
||||
)
|
||||
from core.corrections.correction_repository import CorrectionRepository
|
||||
from core.corrections.aggregator import CorrectionAggregator, CrossWorkflowAggregator
|
||||
from core.corrections.correction_pack_service import CorrectionPackService
|
||||
|
||||
|
||||
class TestCorrectionKey(unittest.TestCase):
|
||||
"""Tests for CorrectionKey model."""
|
||||
|
||||
def test_to_hash(self):
|
||||
"""Test hash generation is consistent."""
|
||||
key1 = CorrectionKey("click", "button", "element_not_found")
|
||||
key2 = CorrectionKey("click", "button", "element_not_found")
|
||||
key3 = CorrectionKey("type", "input", "element_not_found")
|
||||
|
||||
# Same keys should produce same hash
|
||||
self.assertEqual(key1.to_hash(), key2.to_hash())
|
||||
|
||||
# Different keys should produce different hash
|
||||
self.assertNotEqual(key1.to_hash(), key3.to_hash())
|
||||
|
||||
# Hash should be 16 characters
|
||||
self.assertEqual(len(key1.to_hash()), 16)
|
||||
|
||||
def test_to_dict_from_dict(self):
|
||||
"""Test serialization round-trip."""
|
||||
key = CorrectionKey("click", "button", "timeout")
|
||||
key_dict = key.to_dict()
|
||||
|
||||
self.assertEqual(key_dict['action_type'], "click")
|
||||
self.assertEqual(key_dict['element_type'], "button")
|
||||
self.assertEqual(key_dict['failure_context'], "timeout")
|
||||
self.assertIn('hash', key_dict)
|
||||
|
||||
restored = CorrectionKey.from_dict(key_dict)
|
||||
self.assertEqual(key.to_hash(), restored.to_hash())
|
||||
|
||||
|
||||
class TestCorrection(unittest.TestCase):
|
||||
"""Tests for Correction model."""
|
||||
|
||||
def test_create_correction(self):
|
||||
"""Test creating a correction."""
|
||||
key = CorrectionKey("click", "button", "not_visible")
|
||||
correction = Correction(
|
||||
id=generate_correction_id(),
|
||||
key=key,
|
||||
correction_type=CorrectionType.WAIT_ADDED,
|
||||
original_target={"selector": "#submit"},
|
||||
corrected_params={"wait_time": 2.0},
|
||||
correction_description="Added wait for button visibility"
|
||||
)
|
||||
|
||||
self.assertTrue(correction.id.startswith("corr_"))
|
||||
self.assertEqual(correction.correction_type, CorrectionType.WAIT_ADDED)
|
||||
self.assertEqual(correction.success_rate, 0.0)
|
||||
|
||||
def test_success_rate_calculation(self):
|
||||
"""Test success rate calculation."""
|
||||
key = CorrectionKey("click", "button", "not_visible")
|
||||
correction = Correction(
|
||||
id=generate_correction_id(),
|
||||
key=key,
|
||||
correction_type=CorrectionType.WAIT_ADDED,
|
||||
success_count=7,
|
||||
failure_count=3
|
||||
)
|
||||
|
||||
self.assertEqual(correction.success_rate, 0.7)
|
||||
|
||||
def test_record_application(self):
|
||||
"""Test recording application updates stats."""
|
||||
key = CorrectionKey("click", "button", "not_visible")
|
||||
correction = Correction(
|
||||
id=generate_correction_id(),
|
||||
key=key,
|
||||
correction_type=CorrectionType.WAIT_ADDED
|
||||
)
|
||||
|
||||
correction.record_application(success=True)
|
||||
self.assertEqual(correction.success_count, 1)
|
||||
self.assertEqual(correction.failure_count, 0)
|
||||
self.assertIsNotNone(correction.last_applied)
|
||||
|
||||
correction.record_application(success=False)
|
||||
self.assertEqual(correction.success_count, 1)
|
||||
self.assertEqual(correction.failure_count, 1)
|
||||
|
||||
def test_to_dict_from_dict(self):
|
||||
"""Test serialization round-trip."""
|
||||
key = CorrectionKey("type", "input", "validation_error")
|
||||
source = CorrectionSource(
|
||||
session_id="sess_123",
|
||||
workflow_id="wf_456",
|
||||
node_id="node_789"
|
||||
)
|
||||
correction = Correction(
|
||||
id="corr_test123",
|
||||
key=key,
|
||||
correction_type=CorrectionType.PARAMETER_CHANGE,
|
||||
original_params={"value": "test"},
|
||||
corrected_params={"value": "corrected"},
|
||||
source=source,
|
||||
success_count=5,
|
||||
failure_count=2,
|
||||
tags=["form", "validation"]
|
||||
)
|
||||
|
||||
correction_dict = correction.to_dict()
|
||||
restored = Correction.from_dict(correction_dict)
|
||||
|
||||
self.assertEqual(restored.id, correction.id)
|
||||
self.assertEqual(restored.key.to_hash(), key.to_hash())
|
||||
self.assertEqual(restored.correction_type, CorrectionType.PARAMETER_CHANGE)
|
||||
self.assertEqual(restored.success_count, 5)
|
||||
self.assertEqual(restored.tags, ["form", "validation"])
|
||||
|
||||
|
||||
class TestCorrectionPack(unittest.TestCase):
|
||||
"""Tests for CorrectionPack model."""
|
||||
|
||||
def test_create_pack(self):
|
||||
"""Test creating a pack."""
|
||||
pack = CorrectionPack(
|
||||
id=generate_pack_id(),
|
||||
name="Test Pack",
|
||||
description="A test correction pack"
|
||||
)
|
||||
|
||||
self.assertTrue(pack.id.startswith("pack_"))
|
||||
self.assertEqual(pack.name, "Test Pack")
|
||||
self.assertEqual(pack.correction_count, 0)
|
||||
|
||||
def test_add_remove_correction(self):
|
||||
"""Test adding and removing corrections."""
|
||||
pack = CorrectionPack(
|
||||
id=generate_pack_id(),
|
||||
name="Test Pack"
|
||||
)
|
||||
|
||||
key = CorrectionKey("click", "button", "not_found")
|
||||
correction = Correction(
|
||||
id="corr_test1",
|
||||
key=key,
|
||||
correction_type=CorrectionType.TARGET_CHANGE
|
||||
)
|
||||
|
||||
# Add correction
|
||||
self.assertTrue(pack.add_correction(correction))
|
||||
self.assertEqual(pack.correction_count, 1)
|
||||
|
||||
# Adding same correction should return False
|
||||
self.assertFalse(pack.add_correction(correction))
|
||||
|
||||
# Remove correction
|
||||
self.assertTrue(pack.remove_correction("corr_test1"))
|
||||
self.assertEqual(pack.correction_count, 0)
|
||||
|
||||
# Removing non-existent should return False
|
||||
self.assertFalse(pack.remove_correction("corr_test1"))
|
||||
|
||||
def test_find_by_key_hash(self):
|
||||
"""Test finding corrections by key hash."""
|
||||
pack = CorrectionPack(id=generate_pack_id(), name="Test Pack")
|
||||
|
||||
key1 = CorrectionKey("click", "button", "not_found")
|
||||
key2 = CorrectionKey("type", "input", "validation")
|
||||
|
||||
correction1 = Correction(id="corr_1", key=key1, correction_type=CorrectionType.TARGET_CHANGE)
|
||||
correction2 = Correction(id="corr_2", key=key2, correction_type=CorrectionType.PARAMETER_CHANGE)
|
||||
|
||||
pack.add_correction(correction1)
|
||||
pack.add_correction(correction2)
|
||||
|
||||
# Find by key1 hash
|
||||
found = pack.find_by_key_hash(key1.to_hash())
|
||||
self.assertEqual(len(found), 1)
|
||||
self.assertEqual(found[0].id, "corr_1")
|
||||
|
||||
def test_merge_packs(self):
|
||||
"""Test merging two packs."""
|
||||
pack1 = CorrectionPack(id="pack_1", name="Pack 1")
|
||||
pack2 = CorrectionPack(id="pack_2", name="Pack 2")
|
||||
|
||||
key = CorrectionKey("click", "button", "error")
|
||||
corr1 = Correction(id="corr_1", key=key, correction_type=CorrectionType.TARGET_CHANGE)
|
||||
corr2 = Correction(id="corr_2", key=key, correction_type=CorrectionType.WAIT_ADDED)
|
||||
|
||||
pack1.add_correction(corr1)
|
||||
pack2.add_correction(corr2)
|
||||
|
||||
merged_count = pack1.merge(pack2)
|
||||
self.assertEqual(merged_count, 1)
|
||||
self.assertEqual(pack1.correction_count, 2)
|
||||
|
||||
def test_get_statistics(self):
|
||||
"""Test getting pack statistics."""
|
||||
pack = CorrectionPack(id=generate_pack_id(), name="Test Pack")
|
||||
|
||||
key = CorrectionKey("click", "button", "error")
|
||||
correction = Correction(
|
||||
id="corr_1",
|
||||
key=key,
|
||||
correction_type=CorrectionType.TARGET_CHANGE,
|
||||
success_count=10,
|
||||
failure_count=2
|
||||
)
|
||||
pack.add_correction(correction)
|
||||
|
||||
stats = pack.get_statistics()
|
||||
self.assertEqual(stats['total_corrections'], 1)
|
||||
self.assertEqual(stats['active_corrections'], 1)
|
||||
self.assertEqual(stats['total_applications'], 12)
|
||||
|
||||
|
||||
class TestCorrectionRepository(unittest.TestCase):
|
||||
"""Tests for CorrectionRepository."""
|
||||
|
||||
def setUp(self):
|
||||
"""Create temporary directory for tests."""
|
||||
self.test_dir = tempfile.mkdtemp()
|
||||
self.repo = CorrectionRepository(self.test_dir)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up temporary directory."""
|
||||
shutil.rmtree(self.test_dir)
|
||||
|
||||
def test_create_and_get_pack(self):
|
||||
"""Test creating and retrieving a pack."""
|
||||
pack = self.repo.create_pack("Test Pack", "Description", {
|
||||
'category': 'testing',
|
||||
'tags': ['unit', 'test']
|
||||
})
|
||||
|
||||
self.assertIsNotNone(pack)
|
||||
self.assertEqual(pack.name, "Test Pack")
|
||||
|
||||
retrieved = self.repo.get_pack(pack.id)
|
||||
self.assertIsNotNone(retrieved)
|
||||
self.assertEqual(retrieved.name, "Test Pack")
|
||||
|
||||
def test_list_packs(self):
|
||||
"""Test listing packs."""
|
||||
self.repo.create_pack("Pack 1", "First pack")
|
||||
self.repo.create_pack("Pack 2", "Second pack")
|
||||
|
||||
packs = self.repo.list_packs()
|
||||
self.assertEqual(len(packs), 2)
|
||||
|
||||
def test_update_pack(self):
|
||||
"""Test updating a pack."""
|
||||
pack = self.repo.create_pack("Original Name")
|
||||
|
||||
updated = self.repo.update_pack(pack.id, {'name': 'New Name'})
|
||||
self.assertIsNotNone(updated)
|
||||
self.assertEqual(updated.name, 'New Name')
|
||||
|
||||
def test_delete_pack(self):
|
||||
"""Test deleting a pack."""
|
||||
pack = self.repo.create_pack("To Delete")
|
||||
|
||||
self.assertTrue(self.repo.delete_pack(pack.id))
|
||||
self.assertIsNone(self.repo.get_pack(pack.id))
|
||||
self.assertFalse(self.repo.delete_pack(pack.id))
|
||||
|
||||
def test_add_and_find_correction(self):
|
||||
"""Test adding and finding corrections."""
|
||||
pack = self.repo.create_pack("Test Pack")
|
||||
|
||||
key = CorrectionKey("click", "button", "timeout")
|
||||
correction = Correction(
|
||||
id=generate_correction_id(),
|
||||
key=key,
|
||||
correction_type=CorrectionType.TIMING_ADJUST
|
||||
)
|
||||
|
||||
self.assertTrue(self.repo.add_correction(pack.id, correction))
|
||||
|
||||
# Find by context
|
||||
found = self.repo.find_by_context("click", "button", "timeout")
|
||||
self.assertEqual(len(found), 1)
|
||||
self.assertEqual(found[0][1].id, correction.id)
|
||||
|
||||
def test_export_import_pack(self):
|
||||
"""Test exporting and importing a pack."""
|
||||
# Create pack with correction
|
||||
pack = self.repo.create_pack("Export Test")
|
||||
key = CorrectionKey("click", "button", "error")
|
||||
correction = Correction(
|
||||
id="corr_export",
|
||||
key=key,
|
||||
correction_type=CorrectionType.TARGET_CHANGE
|
||||
)
|
||||
self.repo.add_correction(pack.id, correction)
|
||||
|
||||
# Export
|
||||
exported = self.repo.export_pack(pack.id)
|
||||
self.assertIsNotNone(exported)
|
||||
|
||||
# Import as new
|
||||
imported = self.repo.import_pack(exported, merge_strategy='create_new')
|
||||
self.assertIsNotNone(imported)
|
||||
self.assertNotEqual(imported.id, pack.id)
|
||||
self.assertEqual(imported.correction_count, 1)
|
||||
|
||||
def test_version_and_rollback(self):
|
||||
"""Test versioning and rollback."""
|
||||
pack = self.repo.create_pack("Version Test")
|
||||
|
||||
# Add correction
|
||||
key = CorrectionKey("click", "button", "error")
|
||||
correction1 = Correction(id="corr_v1", key=key, correction_type=CorrectionType.TARGET_CHANGE)
|
||||
self.repo.add_correction(pack.id, correction1)
|
||||
|
||||
# Create version
|
||||
version = self.repo.create_version(pack.id, "First version")
|
||||
self.assertEqual(version, 1)
|
||||
|
||||
# Add another correction
|
||||
correction2 = Correction(id="corr_v2", key=key, correction_type=CorrectionType.WAIT_ADDED)
|
||||
self.repo.add_correction(pack.id, correction2)
|
||||
|
||||
# Verify 2 corrections
|
||||
pack = self.repo.get_pack(pack.id)
|
||||
self.assertEqual(pack.correction_count, 2)
|
||||
|
||||
# Rollback to version 1
|
||||
self.assertTrue(self.repo.rollback_pack(pack.id, 1))
|
||||
|
||||
# Should have 1 correction again
|
||||
pack = self.repo.get_pack(pack.id)
|
||||
self.assertEqual(pack.correction_count, 1)
|
||||
|
||||
|
||||
class TestCorrectionAggregator(unittest.TestCase):
|
||||
"""Tests for CorrectionAggregator."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up aggregator for tests."""
|
||||
self.aggregator = CorrectionAggregator(
|
||||
min_occurrences=2,
|
||||
min_success_rate=0.5,
|
||||
min_confidence=0.3
|
||||
)
|
||||
|
||||
def test_aggregate_by_key(self):
|
||||
"""Test aggregating corrections by key."""
|
||||
key = CorrectionKey("click", "button", "not_found")
|
||||
|
||||
# Create multiple corrections with same key
|
||||
corrections = [
|
||||
Correction(
|
||||
id=f"corr_{i}",
|
||||
key=key,
|
||||
correction_type=CorrectionType.WAIT_ADDED,
|
||||
success_count=5,
|
||||
failure_count=2
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
aggregated = self.aggregator.aggregate_corrections(corrections)
|
||||
|
||||
# Should have 1 merged correction
|
||||
self.assertEqual(len(aggregated), 1)
|
||||
# Sum of success counts
|
||||
self.assertEqual(aggregated[0].success_count, 15)
|
||||
|
||||
def test_filter_by_quality(self):
|
||||
"""Test filtering by quality thresholds."""
|
||||
key = CorrectionKey("click", "button", "error")
|
||||
|
||||
high_quality = Correction(
|
||||
id="corr_high",
|
||||
key=key,
|
||||
correction_type=CorrectionType.TARGET_CHANGE,
|
||||
success_count=9,
|
||||
failure_count=1
|
||||
)
|
||||
|
||||
low_quality = Correction(
|
||||
id="corr_low",
|
||||
key=key,
|
||||
correction_type=CorrectionType.TARGET_CHANGE,
|
||||
success_count=1,
|
||||
failure_count=9
|
||||
)
|
||||
|
||||
filtered = self.aggregator.filter_by_quality(
|
||||
[high_quality, low_quality],
|
||||
min_success_rate=0.7
|
||||
)
|
||||
|
||||
self.assertEqual(len(filtered), 1)
|
||||
self.assertEqual(filtered[0].id, "corr_high")
|
||||
|
||||
def test_rank_corrections(self):
|
||||
"""Test ranking corrections."""
|
||||
key = CorrectionKey("click", "button", "error")
|
||||
|
||||
corrections = [
|
||||
Correction(id="corr_1", key=key, correction_type=CorrectionType.TARGET_CHANGE, success_count=3, failure_count=7),
|
||||
Correction(id="corr_2", key=key, correction_type=CorrectionType.TARGET_CHANGE, success_count=8, failure_count=2),
|
||||
Correction(id="corr_3", key=key, correction_type=CorrectionType.TARGET_CHANGE, success_count=5, failure_count=5),
|
||||
]
|
||||
|
||||
ranked = self.aggregator.rank_corrections(corrections, sort_by='success_rate')
|
||||
|
||||
self.assertEqual(ranked[0].id, "corr_2")
|
||||
self.assertEqual(ranked[1].id, "corr_3")
|
||||
self.assertEqual(ranked[2].id, "corr_1")
|
||||
|
||||
|
||||
class TestCorrectionPackService(unittest.TestCase):
|
||||
"""Tests for CorrectionPackService."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up service for tests."""
|
||||
self.test_dir = tempfile.mkdtemp()
|
||||
self.training_dir = tempfile.mkdtemp()
|
||||
self.service = CorrectionPackService(self.test_dir, self.training_dir)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up temporary directories."""
|
||||
shutil.rmtree(self.test_dir)
|
||||
shutil.rmtree(self.training_dir)
|
||||
|
||||
def test_create_and_list_packs(self):
|
||||
"""Test creating and listing packs."""
|
||||
pack = self.service.create_pack(
|
||||
name="Service Test Pack",
|
||||
description="Testing the service",
|
||||
category="testing",
|
||||
tags=["unit", "test"]
|
||||
)
|
||||
|
||||
self.assertIsNotNone(pack)
|
||||
self.assertEqual(pack['name'], "Service Test Pack")
|
||||
|
||||
packs = self.service.list_packs()
|
||||
self.assertEqual(len(packs), 1)
|
||||
self.assertEqual(packs[0]['name'], "Service Test Pack")
|
||||
|
||||
def test_add_correction(self):
|
||||
"""Test adding corrections via service."""
|
||||
pack = self.service.create_pack(name="Correction Test")
|
||||
|
||||
correction = self.service.add_correction(pack['id'], {
|
||||
'action_type': 'click',
|
||||
'element_type': 'button',
|
||||
'failure_context': 'element_not_found',
|
||||
'correction_type': 'target_change',
|
||||
'original_target': {'selector': '#old'},
|
||||
'corrected_target': {'selector': '#new'},
|
||||
'description': 'Updated selector'
|
||||
})
|
||||
|
||||
self.assertIsNotNone(correction)
|
||||
self.assertTrue(correction['id'].startswith('corr_'))
|
||||
|
||||
def test_find_applicable_corrections(self):
|
||||
"""Test finding applicable corrections."""
|
||||
pack = self.service.create_pack(name="Find Test")
|
||||
|
||||
self.service.add_correction(pack['id'], {
|
||||
'action_type': 'click',
|
||||
'element_type': 'button',
|
||||
'failure_context': 'timeout',
|
||||
'correction_type': 'timing_adjust',
|
||||
'corrected_params': {'wait_time': 5}
|
||||
})
|
||||
|
||||
found = self.service.find_applicable_corrections(
|
||||
action_type='click',
|
||||
element_type='button',
|
||||
failure_context='timeout',
|
||||
min_confidence=0.0
|
||||
)
|
||||
|
||||
self.assertEqual(len(found), 1)
|
||||
|
||||
def test_export_import_file(self):
|
||||
"""Test file-based export and import."""
|
||||
pack = self.service.create_pack(name="File Export Test")
|
||||
self.service.add_correction(pack['id'], {
|
||||
'action_type': 'type',
|
||||
'element_type': 'input',
|
||||
'failure_context': 'validation',
|
||||
'correction_type': 'parameter_change'
|
||||
})
|
||||
|
||||
# Export to file
|
||||
export_file = os.path.join(self.test_dir, 'export.json')
|
||||
self.assertTrue(self.service.export_pack_to_file(pack['id'], export_file))
|
||||
self.assertTrue(os.path.exists(export_file))
|
||||
|
||||
# Import from file
|
||||
imported = self.service.import_pack_from_file(export_file)
|
||||
self.assertIsNotNone(imported)
|
||||
self.assertEqual(len(imported['corrections']), 1)
|
||||
|
||||
def test_version_management(self):
|
||||
"""Test version creation and rollback."""
|
||||
pack = self.service.create_pack(name="Version Test")
|
||||
|
||||
# Add correction
|
||||
self.service.add_correction(pack['id'], {
|
||||
'action_type': 'click',
|
||||
'element_type': 'button',
|
||||
'failure_context': 'error',
|
||||
'correction_type': 'target_change'
|
||||
})
|
||||
|
||||
# Create version
|
||||
version = self.service.create_version(pack['id'], "v1 snapshot")
|
||||
self.assertEqual(version, 1)
|
||||
|
||||
# List versions
|
||||
versions = self.service.list_versions(pack['id'])
|
||||
self.assertEqual(len(versions), 1)
|
||||
|
||||
# Add another correction
|
||||
self.service.add_correction(pack['id'], {
|
||||
'action_type': 'type',
|
||||
'element_type': 'input',
|
||||
'failure_context': 'error',
|
||||
'correction_type': 'parameter_change'
|
||||
})
|
||||
|
||||
# Get pack - should have 2 corrections
|
||||
current_pack = self.service.get_pack(pack['id'])
|
||||
self.assertEqual(len(current_pack['corrections']), 2)
|
||||
|
||||
# Rollback
|
||||
self.assertTrue(self.service.rollback_pack(pack['id'], 1))
|
||||
|
||||
# Get pack - should have 1 correction
|
||||
rolled_back = self.service.get_pack(pack['id'])
|
||||
self.assertEqual(len(rolled_back['corrections']), 1)
|
||||
|
||||
def test_statistics(self):
|
||||
"""Test getting statistics."""
|
||||
pack = self.service.create_pack(name="Stats Test")
|
||||
self.service.add_correction(pack['id'], {
|
||||
'action_type': 'click',
|
||||
'element_type': 'button',
|
||||
'failure_context': 'error',
|
||||
'correction_type': 'target_change'
|
||||
})
|
||||
|
||||
# Pack statistics
|
||||
pack_stats = self.service.get_pack_statistics(pack['id'])
|
||||
self.assertIsNotNone(pack_stats)
|
||||
self.assertEqual(pack_stats['total_corrections'], 1)
|
||||
|
||||
# Global statistics
|
||||
global_stats = self.service.get_global_statistics()
|
||||
self.assertIsNotNone(global_stats)
|
||||
self.assertEqual(global_stats['total_packs'], 1)
|
||||
|
||||
|
||||
class TestCrossWorkflowAggregator(unittest.TestCase):
|
||||
"""Tests for CrossWorkflowAggregator."""
|
||||
|
||||
def test_find_common_corrections(self):
|
||||
"""Test finding corrections common across workflows."""
|
||||
aggregator = CrossWorkflowAggregator()
|
||||
|
||||
key = CorrectionKey("click", "submit_button", "timeout")
|
||||
|
||||
# Corrections from different workflows
|
||||
corrections = [
|
||||
Correction(
|
||||
id="corr_1",
|
||||
key=key,
|
||||
correction_type=CorrectionType.WAIT_ADDED,
|
||||
source=CorrectionSource(session_id="s1", workflow_id="wf_1")
|
||||
),
|
||||
Correction(
|
||||
id="corr_2",
|
||||
key=key,
|
||||
correction_type=CorrectionType.WAIT_ADDED,
|
||||
source=CorrectionSource(session_id="s2", workflow_id="wf_2")
|
||||
),
|
||||
Correction(
|
||||
id="corr_3",
|
||||
key=key,
|
||||
correction_type=CorrectionType.WAIT_ADDED,
|
||||
source=CorrectionSource(session_id="s3", workflow_id="wf_3")
|
||||
),
|
||||
]
|
||||
|
||||
common = aggregator.find_common_corrections(corrections, min_workflows=2)
|
||||
|
||||
self.assertEqual(len(common), 1)
|
||||
merged, workflows = common[0]
|
||||
self.assertEqual(len(workflows), 3)
|
||||
|
||||
def test_detect_patterns(self):
|
||||
"""Test pattern detection."""
|
||||
aggregator = CrossWorkflowAggregator()
|
||||
|
||||
key = CorrectionKey("click", "button", "error")
|
||||
corrections = [
|
||||
Correction(id=f"corr_{i}", key=key, correction_type=CorrectionType.WAIT_ADDED)
|
||||
for i in range(5)
|
||||
]
|
||||
|
||||
patterns = aggregator.detect_patterns(corrections)
|
||||
|
||||
self.assertIn('patterns', patterns)
|
||||
self.assertIn('summary', patterns)
|
||||
self.assertEqual(patterns['summary']['total_corrections'], 5)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user