""" Tests for the Correction Packs system. Tests models, repository, aggregator, and service. """ import json import os import shutil import tempfile import unittest from datetime import datetime from pathlib import Path import sys sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from core.corrections.models import ( CorrectionKey, CorrectionType, CorrectionStatus, Correction, CorrectionPack, CorrectionPackMetadata, CorrectionSource, generate_correction_id, generate_pack_id ) from core.corrections.correction_repository import CorrectionRepository from core.corrections.aggregator import CorrectionAggregator, CrossWorkflowAggregator from core.corrections.correction_pack_service import CorrectionPackService class TestCorrectionKey(unittest.TestCase): """Tests for CorrectionKey model.""" def test_to_hash(self): """Test hash generation is consistent.""" key1 = CorrectionKey("click", "button", "element_not_found") key2 = CorrectionKey("click", "button", "element_not_found") key3 = CorrectionKey("type", "input", "element_not_found") # Same keys should produce same hash self.assertEqual(key1.to_hash(), key2.to_hash()) # Different keys should produce different hash self.assertNotEqual(key1.to_hash(), key3.to_hash()) # Hash should be 16 characters self.assertEqual(len(key1.to_hash()), 16) def test_to_dict_from_dict(self): """Test serialization round-trip.""" key = CorrectionKey("click", "button", "timeout") key_dict = key.to_dict() self.assertEqual(key_dict['action_type'], "click") self.assertEqual(key_dict['element_type'], "button") self.assertEqual(key_dict['failure_context'], "timeout") self.assertIn('hash', key_dict) restored = CorrectionKey.from_dict(key_dict) self.assertEqual(key.to_hash(), restored.to_hash()) class TestCorrection(unittest.TestCase): """Tests for Correction model.""" def test_create_correction(self): """Test creating a correction.""" key = CorrectionKey("click", "button", "not_visible") correction = Correction( id=generate_correction_id(), key=key, correction_type=CorrectionType.WAIT_ADDED, original_target={"selector": "#submit"}, corrected_params={"wait_time": 2.0}, correction_description="Added wait for button visibility" ) self.assertTrue(correction.id.startswith("corr_")) self.assertEqual(correction.correction_type, CorrectionType.WAIT_ADDED) self.assertEqual(correction.success_rate, 0.0) def test_success_rate_calculation(self): """Test success rate calculation.""" key = CorrectionKey("click", "button", "not_visible") correction = Correction( id=generate_correction_id(), key=key, correction_type=CorrectionType.WAIT_ADDED, success_count=7, failure_count=3 ) self.assertEqual(correction.success_rate, 0.7) def test_record_application(self): """Test recording application updates stats.""" key = CorrectionKey("click", "button", "not_visible") correction = Correction( id=generate_correction_id(), key=key, correction_type=CorrectionType.WAIT_ADDED ) correction.record_application(success=True) self.assertEqual(correction.success_count, 1) self.assertEqual(correction.failure_count, 0) self.assertIsNotNone(correction.last_applied) correction.record_application(success=False) self.assertEqual(correction.success_count, 1) self.assertEqual(correction.failure_count, 1) def test_to_dict_from_dict(self): """Test serialization round-trip.""" key = CorrectionKey("type", "input", "validation_error") source = CorrectionSource( session_id="sess_123", workflow_id="wf_456", node_id="node_789" ) correction = Correction( id="corr_test123", key=key, correction_type=CorrectionType.PARAMETER_CHANGE, original_params={"value": "test"}, corrected_params={"value": "corrected"}, source=source, success_count=5, failure_count=2, tags=["form", "validation"] ) correction_dict = correction.to_dict() restored = Correction.from_dict(correction_dict) self.assertEqual(restored.id, correction.id) self.assertEqual(restored.key.to_hash(), key.to_hash()) self.assertEqual(restored.correction_type, CorrectionType.PARAMETER_CHANGE) self.assertEqual(restored.success_count, 5) self.assertEqual(restored.tags, ["form", "validation"]) class TestCorrectionPack(unittest.TestCase): """Tests for CorrectionPack model.""" def test_create_pack(self): """Test creating a pack.""" pack = CorrectionPack( id=generate_pack_id(), name="Test Pack", description="A test correction pack" ) self.assertTrue(pack.id.startswith("pack_")) self.assertEqual(pack.name, "Test Pack") self.assertEqual(pack.correction_count, 0) def test_add_remove_correction(self): """Test adding and removing corrections.""" pack = CorrectionPack( id=generate_pack_id(), name="Test Pack" ) key = CorrectionKey("click", "button", "not_found") correction = Correction( id="corr_test1", key=key, correction_type=CorrectionType.TARGET_CHANGE ) # Add correction self.assertTrue(pack.add_correction(correction)) self.assertEqual(pack.correction_count, 1) # Adding same correction should return False self.assertFalse(pack.add_correction(correction)) # Remove correction self.assertTrue(pack.remove_correction("corr_test1")) self.assertEqual(pack.correction_count, 0) # Removing non-existent should return False self.assertFalse(pack.remove_correction("corr_test1")) def test_find_by_key_hash(self): """Test finding corrections by key hash.""" pack = CorrectionPack(id=generate_pack_id(), name="Test Pack") key1 = CorrectionKey("click", "button", "not_found") key2 = CorrectionKey("type", "input", "validation") correction1 = Correction(id="corr_1", key=key1, correction_type=CorrectionType.TARGET_CHANGE) correction2 = Correction(id="corr_2", key=key2, correction_type=CorrectionType.PARAMETER_CHANGE) pack.add_correction(correction1) pack.add_correction(correction2) # Find by key1 hash found = pack.find_by_key_hash(key1.to_hash()) self.assertEqual(len(found), 1) self.assertEqual(found[0].id, "corr_1") def test_merge_packs(self): """Test merging two packs.""" pack1 = CorrectionPack(id="pack_1", name="Pack 1") pack2 = CorrectionPack(id="pack_2", name="Pack 2") key = CorrectionKey("click", "button", "error") corr1 = Correction(id="corr_1", key=key, correction_type=CorrectionType.TARGET_CHANGE) corr2 = Correction(id="corr_2", key=key, correction_type=CorrectionType.WAIT_ADDED) pack1.add_correction(corr1) pack2.add_correction(corr2) merged_count = pack1.merge(pack2) self.assertEqual(merged_count, 1) self.assertEqual(pack1.correction_count, 2) def test_get_statistics(self): """Test getting pack statistics.""" pack = CorrectionPack(id=generate_pack_id(), name="Test Pack") key = CorrectionKey("click", "button", "error") correction = Correction( id="corr_1", key=key, correction_type=CorrectionType.TARGET_CHANGE, success_count=10, failure_count=2 ) pack.add_correction(correction) stats = pack.get_statistics() self.assertEqual(stats['total_corrections'], 1) self.assertEqual(stats['active_corrections'], 1) self.assertEqual(stats['total_applications'], 12) class TestCorrectionRepository(unittest.TestCase): """Tests for CorrectionRepository.""" def setUp(self): """Create temporary directory for tests.""" self.test_dir = tempfile.mkdtemp() self.repo = CorrectionRepository(self.test_dir) def tearDown(self): """Clean up temporary directory.""" shutil.rmtree(self.test_dir) def test_create_and_get_pack(self): """Test creating and retrieving a pack.""" pack = self.repo.create_pack("Test Pack", "Description", { 'category': 'testing', 'tags': ['unit', 'test'] }) self.assertIsNotNone(pack) self.assertEqual(pack.name, "Test Pack") retrieved = self.repo.get_pack(pack.id) self.assertIsNotNone(retrieved) self.assertEqual(retrieved.name, "Test Pack") def test_list_packs(self): """Test listing packs.""" self.repo.create_pack("Pack 1", "First pack") self.repo.create_pack("Pack 2", "Second pack") packs = self.repo.list_packs() self.assertEqual(len(packs), 2) def test_update_pack(self): """Test updating a pack.""" pack = self.repo.create_pack("Original Name") updated = self.repo.update_pack(pack.id, {'name': 'New Name'}) self.assertIsNotNone(updated) self.assertEqual(updated.name, 'New Name') def test_delete_pack(self): """Test deleting a pack.""" pack = self.repo.create_pack("To Delete") self.assertTrue(self.repo.delete_pack(pack.id)) self.assertIsNone(self.repo.get_pack(pack.id)) self.assertFalse(self.repo.delete_pack(pack.id)) def test_add_and_find_correction(self): """Test adding and finding corrections.""" pack = self.repo.create_pack("Test Pack") key = CorrectionKey("click", "button", "timeout") correction = Correction( id=generate_correction_id(), key=key, correction_type=CorrectionType.TIMING_ADJUST ) self.assertTrue(self.repo.add_correction(pack.id, correction)) # Find by context found = self.repo.find_by_context("click", "button", "timeout") self.assertEqual(len(found), 1) self.assertEqual(found[0][1].id, correction.id) def test_export_import_pack(self): """Test exporting and importing a pack.""" # Create pack with correction pack = self.repo.create_pack("Export Test") key = CorrectionKey("click", "button", "error") correction = Correction( id="corr_export", key=key, correction_type=CorrectionType.TARGET_CHANGE ) self.repo.add_correction(pack.id, correction) # Export exported = self.repo.export_pack(pack.id) self.assertIsNotNone(exported) # Import as new imported = self.repo.import_pack(exported, merge_strategy='create_new') self.assertIsNotNone(imported) self.assertNotEqual(imported.id, pack.id) self.assertEqual(imported.correction_count, 1) def test_version_and_rollback(self): """Test versioning and rollback.""" pack = self.repo.create_pack("Version Test") # Add correction key = CorrectionKey("click", "button", "error") correction1 = Correction(id="corr_v1", key=key, correction_type=CorrectionType.TARGET_CHANGE) self.repo.add_correction(pack.id, correction1) # Create version version = self.repo.create_version(pack.id, "First version") self.assertEqual(version, 1) # Add another correction correction2 = Correction(id="corr_v2", key=key, correction_type=CorrectionType.WAIT_ADDED) self.repo.add_correction(pack.id, correction2) # Verify 2 corrections pack = self.repo.get_pack(pack.id) self.assertEqual(pack.correction_count, 2) # Rollback to version 1 self.assertTrue(self.repo.rollback_pack(pack.id, 1)) # Should have 1 correction again pack = self.repo.get_pack(pack.id) self.assertEqual(pack.correction_count, 1) class TestCorrectionAggregator(unittest.TestCase): """Tests for CorrectionAggregator.""" def setUp(self): """Set up aggregator for tests.""" self.aggregator = CorrectionAggregator( min_occurrences=2, min_success_rate=0.5, min_confidence=0.3 ) def test_aggregate_by_key(self): """Test aggregating corrections by key.""" key = CorrectionKey("click", "button", "not_found") # Create multiple corrections with same key corrections = [ Correction( id=f"corr_{i}", key=key, correction_type=CorrectionType.WAIT_ADDED, success_count=5, failure_count=2 ) for i in range(3) ] aggregated = self.aggregator.aggregate_corrections(corrections) # Should have 1 merged correction self.assertEqual(len(aggregated), 1) # Sum of success counts self.assertEqual(aggregated[0].success_count, 15) def test_filter_by_quality(self): """Test filtering by quality thresholds.""" key = CorrectionKey("click", "button", "error") high_quality = Correction( id="corr_high", key=key, correction_type=CorrectionType.TARGET_CHANGE, success_count=9, failure_count=1 ) low_quality = Correction( id="corr_low", key=key, correction_type=CorrectionType.TARGET_CHANGE, success_count=1, failure_count=9 ) filtered = self.aggregator.filter_by_quality( [high_quality, low_quality], min_success_rate=0.7 ) self.assertEqual(len(filtered), 1) self.assertEqual(filtered[0].id, "corr_high") def test_rank_corrections(self): """Test ranking corrections.""" key = CorrectionKey("click", "button", "error") corrections = [ Correction(id="corr_1", key=key, correction_type=CorrectionType.TARGET_CHANGE, success_count=3, failure_count=7), Correction(id="corr_2", key=key, correction_type=CorrectionType.TARGET_CHANGE, success_count=8, failure_count=2), Correction(id="corr_3", key=key, correction_type=CorrectionType.TARGET_CHANGE, success_count=5, failure_count=5), ] ranked = self.aggregator.rank_corrections(corrections, sort_by='success_rate') self.assertEqual(ranked[0].id, "corr_2") self.assertEqual(ranked[1].id, "corr_3") self.assertEqual(ranked[2].id, "corr_1") class TestCorrectionPackService(unittest.TestCase): """Tests for CorrectionPackService.""" def setUp(self): """Set up service for tests.""" self.test_dir = tempfile.mkdtemp() self.training_dir = tempfile.mkdtemp() self.service = CorrectionPackService(self.test_dir, self.training_dir) def tearDown(self): """Clean up temporary directories.""" shutil.rmtree(self.test_dir) shutil.rmtree(self.training_dir) def test_create_and_list_packs(self): """Test creating and listing packs.""" pack = self.service.create_pack( name="Service Test Pack", description="Testing the service", category="testing", tags=["unit", "test"] ) self.assertIsNotNone(pack) self.assertEqual(pack['name'], "Service Test Pack") packs = self.service.list_packs() self.assertEqual(len(packs), 1) self.assertEqual(packs[0]['name'], "Service Test Pack") def test_add_correction(self): """Test adding corrections via service.""" pack = self.service.create_pack(name="Correction Test") correction = self.service.add_correction(pack['id'], { 'action_type': 'click', 'element_type': 'button', 'failure_context': 'element_not_found', 'correction_type': 'target_change', 'original_target': {'selector': '#old'}, 'corrected_target': {'selector': '#new'}, 'description': 'Updated selector' }) self.assertIsNotNone(correction) self.assertTrue(correction['id'].startswith('corr_')) def test_find_applicable_corrections(self): """Test finding applicable corrections.""" pack = self.service.create_pack(name="Find Test") self.service.add_correction(pack['id'], { 'action_type': 'click', 'element_type': 'button', 'failure_context': 'timeout', 'correction_type': 'timing_adjust', 'corrected_params': {'wait_time': 5} }) found = self.service.find_applicable_corrections( action_type='click', element_type='button', failure_context='timeout', min_confidence=0.0 ) self.assertEqual(len(found), 1) def test_export_import_file(self): """Test file-based export and import.""" pack = self.service.create_pack(name="File Export Test") self.service.add_correction(pack['id'], { 'action_type': 'type', 'element_type': 'input', 'failure_context': 'validation', 'correction_type': 'parameter_change' }) # Export to file export_file = os.path.join(self.test_dir, 'export.json') self.assertTrue(self.service.export_pack_to_file(pack['id'], export_file)) self.assertTrue(os.path.exists(export_file)) # Import from file imported = self.service.import_pack_from_file(export_file) self.assertIsNotNone(imported) self.assertEqual(len(imported['corrections']), 1) def test_version_management(self): """Test version creation and rollback.""" pack = self.service.create_pack(name="Version Test") # Add correction self.service.add_correction(pack['id'], { 'action_type': 'click', 'element_type': 'button', 'failure_context': 'error', 'correction_type': 'target_change' }) # Create version version = self.service.create_version(pack['id'], "v1 snapshot") self.assertEqual(version, 1) # List versions versions = self.service.list_versions(pack['id']) self.assertEqual(len(versions), 1) # Add another correction self.service.add_correction(pack['id'], { 'action_type': 'type', 'element_type': 'input', 'failure_context': 'error', 'correction_type': 'parameter_change' }) # Get pack - should have 2 corrections current_pack = self.service.get_pack(pack['id']) self.assertEqual(len(current_pack['corrections']), 2) # Rollback self.assertTrue(self.service.rollback_pack(pack['id'], 1)) # Get pack - should have 1 correction rolled_back = self.service.get_pack(pack['id']) self.assertEqual(len(rolled_back['corrections']), 1) def test_statistics(self): """Test getting statistics.""" pack = self.service.create_pack(name="Stats Test") self.service.add_correction(pack['id'], { 'action_type': 'click', 'element_type': 'button', 'failure_context': 'error', 'correction_type': 'target_change' }) # Pack statistics pack_stats = self.service.get_pack_statistics(pack['id']) self.assertIsNotNone(pack_stats) self.assertEqual(pack_stats['total_corrections'], 1) # Global statistics global_stats = self.service.get_global_statistics() self.assertIsNotNone(global_stats) self.assertEqual(global_stats['total_packs'], 1) class TestCrossWorkflowAggregator(unittest.TestCase): """Tests for CrossWorkflowAggregator.""" def test_find_common_corrections(self): """Test finding corrections common across workflows.""" aggregator = CrossWorkflowAggregator() key = CorrectionKey("click", "submit_button", "timeout") # Corrections from different workflows corrections = [ Correction( id="corr_1", key=key, correction_type=CorrectionType.WAIT_ADDED, source=CorrectionSource(session_id="s1", workflow_id="wf_1") ), Correction( id="corr_2", key=key, correction_type=CorrectionType.WAIT_ADDED, source=CorrectionSource(session_id="s2", workflow_id="wf_2") ), Correction( id="corr_3", key=key, correction_type=CorrectionType.WAIT_ADDED, source=CorrectionSource(session_id="s3", workflow_id="wf_3") ), ] common = aggregator.find_common_corrections(corrections, min_workflows=2) self.assertEqual(len(common), 1) merged, workflows = common[0] self.assertEqual(len(workflows), 3) def test_detect_patterns(self): """Test pattern detection.""" aggregator = CrossWorkflowAggregator() key = CorrectionKey("click", "button", "error") corrections = [ Correction(id=f"corr_{i}", key=key, correction_type=CorrectionType.WAIT_ADDED) for i in range(5) ] patterns = aggregator.detect_patterns(corrections) self.assertIn('patterns', patterns) self.assertIn('summary', patterns) self.assertEqual(patterns['summary']['total_corrections'], 5) if __name__ == '__main__': unittest.main()