feat(corrections): Add automatic COACHING integration for Correction Packs
- Add CorrectionPackIntegration class to bridge learning components
- Modify TrainingDataCollector to auto-propagate corrections to packs
- Modify FeedbackProcessor to capture corrections on INCORRECT/PARTIAL feedback
- Add convenience functions: get_correction_pack_integration(), capture_coaching_correction()
- Add 19 integration tests (all passing)

Corrections made during COACHING mode are now automatically captured into a dedicated "auto_captured_corrections" pack for cross-workflow reuse.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
486
tests/test_correction_pack_integration.py
Normal file
486
tests/test_correction_pack_integration.py
Normal file
@@ -0,0 +1,486 @@
|
||||
"""
|
||||
Tests for Correction Pack Integration with learning components.
|
||||
|
||||
Tests the automatic capture of corrections from COACHING mode into
|
||||
the Correction Pack system.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from core.corrections import (
|
||||
CorrectionPackIntegration,
|
||||
CorrectionPackService,
|
||||
get_correction_pack_integration,
|
||||
capture_coaching_correction
|
||||
)
|
||||
from core.corrections.models import CorrectionType
|
||||
from core.training.training_data_collector import TrainingDataCollector
|
||||
from core.learning.feedback_processor import FeedbackProcessor, FeedbackType
|
||||
|
||||
|
||||
@pytest.fixture
def temp_storage():
    """Yield a fresh scratch directory as a ``Path`` and remove it afterwards."""
    scratch_dir = Path(tempfile.mkdtemp())
    yield scratch_dir
    # Best-effort cleanup: errors raised while deleting are deliberately ignored.
    shutil.rmtree(scratch_dir, ignore_errors=True)
|
||||
|
||||
|
||||
@pytest.fixture
def service(temp_storage):
    """Provide a ``CorrectionPackService`` whose storage path is the scratch directory."""
    pack_service = CorrectionPackService(storage_path=temp_storage)
    return pack_service
|
||||
|
||||
|
||||
@pytest.fixture
def integration(service):
    """Provide a ``CorrectionPackIntegration`` bound to the test service.

    ``auto_create_pack`` is passed explicitly as ``True``.
    """
    bridge = CorrectionPackIntegration(
        service=service,
        auto_create_pack=True,
    )
    return bridge
|
||||
|
||||
|
||||
class TestCorrectionPackIntegration:
    """Unit tests that exercise ``CorrectionPackIntegration`` directly."""

    @staticmethod
    def _pack_field(pack, field):
        """Read ``field`` from a pack returned either as a dict or as an object."""
        if isinstance(pack, dict):
            return pack.get(field)
        return getattr(pack, field)

    def test_init(self, service):
        """A new integration keeps its service and starts uninitialized."""
        fresh = CorrectionPackIntegration(service=service)

        assert fresh._service is service
        assert fresh._auto_create_pack is True
        assert fresh._initialized is False

    def test_ensure_initialized_creates_pack(self, integration):
        """``_ensure_initialized`` returns the id of a pack with the default name."""
        created_id = integration._ensure_initialized()

        assert created_id is not None
        assert integration._initialized is True
        assert integration._default_pack_id == created_id

        # The pack must be retrievable through the service.
        stored_pack = integration.service.get_pack(created_id)
        assert stored_pack is not None
        # The service may hand back a dict or an object.
        stored_name = self._pack_field(stored_pack, 'name')
        assert stored_name == CorrectionPackIntegration.DEFAULT_PACK_NAME

    def test_capture_correction(self, integration):
        """A fully described correction is stored in the default pack."""
        payload = {
            'action_type': 'click',
            'element_type': 'button',
            'failure_reason': 'element_not_found',
            'correction_type': 'target_change',
            'original_target': {'text': 'Submit'},
            'corrected_target': {'text': 'Envoyer'},
        }

        captured_id = integration.capture_correction(
            correction_data=payload,
            session_id='test_session_1',
            workflow_id='wf_001',
        )

        assert captured_id is not None

        # Exactly one correction should now live in the default pack.
        default_pack = integration.service.get_pack(integration._default_pack_id)
        stored = self._pack_field(default_pack, 'corrections')
        assert len(stored) == 1

    def test_capture_correction_with_minimal_data(self, integration):
        """A correction carrying only ``type`` and ``new_target`` still gets an id."""
        captured_id = integration.capture_correction(
            correction_data={
                'type': 'input',
                'new_target': {'xpath': '//input[@name="email"]'},
            },
            session_id='session_2',
        )

        assert captured_id is not None

    def test_capture_feedback_correction(self, integration):
        """``capture_feedback_correction`` returns an id for feedback data."""
        feedback_corrections = {
            'action_type': 'type',
            'element_type': 'input',
            'original_target': {'id': 'username'},
            'corrected_target': {'name': 'user_name'},
        }
        feedback_context = {'failure_reason': 'id_changed'}

        captured_id = integration.capture_feedback_correction(
            workflow_id='wf_feedback_test',
            execution_id='exec_001',
            corrections=feedback_corrections,
            context=feedback_context,
        )

        assert captured_id is not None

    def test_infer_correction_type(self, integration):
        """Each payload shape maps to the expected ``CorrectionType`` value."""
        # (payload, expected type), checked in order.
        cases = [
            ({'corrected_target': {'id': 'new'}}, CorrectionType.TARGET_CHANGE),
            ({'corrected_params': {'timeout': 5}}, CorrectionType.PARAMETER_CHANGE),
            ({'wait_time': 2.0}, CorrectionType.TIMING_ADJUST),
            ({'coordinates': {'x': 100, 'y': 200}}, CorrectionType.COORDINATES_ADJUST),
            ({}, CorrectionType.OTHER),
        ]

        for payload, expected in cases:
            assert integration._infer_correction_type(payload) == expected.value

    def test_create_hook_for_collector(self, integration):
        """The collector hook is callable and stores what it is called with."""
        collector_hook = integration.create_hook_for_collector()

        assert callable(collector_hook)

        # Invoke the hook with a single correction payload.
        collector_hook({
            'action_type': 'click',
            'element_type': 'button',
            'correction_type': 'target_change',
        })

        # The correction must have landed in the default pack.
        default_pack = integration.service.get_pack(integration._default_pack_id)
        stored = self._pack_field(default_pack, 'corrections')
        assert len(stored) == 1

    def test_get_statistics(self, integration):
        """Statistics report the number of captured corrections."""
        # Capture three corrections, each in its own session.
        for index in range(3):
            payload = {
                'action_type': f'action_{index}',
                'element_type': 'button',
            }
            integration.capture_correction(
                correction_data=payload,
                session_id=f'session_{index}',
            )

        summary = integration.get_statistics()
        assert summary['total_corrections'] == 3
|
||||
|
||||
|
||||
class TestGlobalIntegration:
    """Tests for the module-level integration helpers."""

    # NOTE(review): both tests request ``temp_storage`` but never pass it on, so
    # the global integration presumably uses its own default storage. Confirm
    # that this is intended.

    def test_get_correction_pack_integration(self, temp_storage):
        """Two consecutive lookups return the very same integration object."""
        # Start from a cleared global so the first call has to create it.
        with patch('core.corrections.integration._global_integration', None):
            first_lookup = get_correction_pack_integration()
            second_lookup = get_correction_pack_integration()

            # Identity, not just equality.
            assert first_lookup is second_lookup

    def test_capture_coaching_correction(self, temp_storage):
        """The convenience function returns an id for a simple correction."""
        payload = {
            'action_type': 'click',
            'element_type': 'button',
        }

        with patch('core.corrections.integration._global_integration', None):
            captured_id = capture_coaching_correction(
                correction_data=payload,
                session_id='test_session',
            )

            # Expected to succeed; the pack is created automatically.
            assert captured_id is not None
|
||||
|
||||
|
||||
class TestTrainingDataCollectorIntegration:
    """Tests for how ``TrainingDataCollector`` forwards recorded corrections."""

    def test_collector_with_auto_integration(self, temp_storage, service):
        """A recorded correction ends up in the injected integration's pack."""
        training_dir = str(temp_storage / 'training')
        collector = TrainingDataCollector(
            output_dir=training_dir,
            auto_integrate_corrections=True,
        )

        # Swap in an integration backed by the test service.
        bridge = CorrectionPackIntegration(service=service)
        collector._correction_integration = bridge

        collector.start_session('session_auto_test', workflow_id='wf_auto')
        collector.record_correction({
            'action_type': 'click',
            'element_type': 'button',
            'failure_reason': 'element_moved',
            'correction_type': 'coordinates_adjust',
            'original_target': {'x': 100, 'y': 100},
            'corrected_target': {'x': 150, 'y': 120},
        })
        collector.end_session(success=True)

        # The correction must be present in the default pack.
        default_pack = service.get_pack(bridge._default_pack_id)
        assert default_pack is not None
        if isinstance(default_pack, dict):
            stored = default_pack.get('corrections')
        else:
            stored = default_pack.corrections
        assert len(stored) == 1

    def test_collector_with_custom_callback(self, temp_storage):
        """A user-supplied callback receives the session and workflow ids."""
        seen_calls = []

        def record_call(correction, session_id, workflow_id):
            seen_calls.append({
                'correction': correction,
                'session_id': session_id,
                'workflow_id': workflow_id,
            })

        collector = TrainingDataCollector(
            output_dir=str(temp_storage / 'training'),
            correction_callback=record_call,
            auto_integrate_corrections=False,
        )

        collector.start_session('session_callback', workflow_id='wf_callback')
        collector.record_correction({'type': 'test_correction'})

        assert len(seen_calls) == 1
        only_call = seen_calls[0]
        assert only_call['session_id'] == 'session_callback'
        assert only_call['workflow_id'] == 'wf_callback'

    def test_collector_disabled_integration(self, temp_storage):
        """With integration switched off, corrections are still kept on the session."""
        collector = TrainingDataCollector(
            output_dir=str(temp_storage / 'training'),
            auto_integrate_corrections=False,
        )

        collector.start_session('session_disabled')

        # Recording must not raise even though no integration is attached.
        collector.record_correction({'type': 'test'})

        # The session itself still holds the correction.
        session_corrections = collector.current_session.user_corrections
        assert len(session_corrections) == 1
|
||||
|
||||
|
||||
class TestFeedbackProcessorIntegration:
    """Tests for how ``FeedbackProcessor`` forwards corrections from feedback."""

    @staticmethod
    def _wired_processor(service):
        """Build an auto-integrating processor with a test-service integration injected."""
        processor = FeedbackProcessor(auto_integrate_corrections=True)
        bridge = CorrectionPackIntegration(service=service)
        processor._correction_integration = bridge
        return processor, bridge

    def test_processor_with_auto_integration(self, service):
        """INCORRECT feedback with corrections yields a pack id and a stored entry."""
        processor, bridge = self._wired_processor(service)

        outcome = processor.process_feedback(
            workflow_id='wf_feedback',
            execution_id='exec_001',
            feedback_type=FeedbackType.INCORRECT,
            confidence=0.5,
            corrections={
                'action_type': 'click',
                'element_type': 'link',
                'original_target': {'text': 'More'},
                'corrected_target': {'text': 'Plus'},
            },
            context={'failure_reason': 'text_changed'},
        )

        assert outcome['feedback_recorded'] is True
        assert outcome['correction_pack_id'] is not None

        # The correction must be present in the default pack.
        default_pack = service.get_pack(bridge._default_pack_id)
        if isinstance(default_pack, dict):
            stored = default_pack.get('corrections')
        else:
            stored = default_pack.corrections
        assert len(stored) == 1

    def test_processor_correct_feedback_no_propagation(self, service):
        """CORRECT feedback reports no pack id even when corrections are supplied."""
        processor, _bridge = self._wired_processor(service)

        outcome = processor.process_feedback(
            workflow_id='wf_correct',
            execution_id='exec_002',
            feedback_type=FeedbackType.CORRECT,  # nothing needed correcting
            confidence=0.95,
            corrections={'some': 'data'},  # present, yet must not be propagated
        )

        assert outcome['correction_pack_id'] is None

    def test_processor_partial_feedback_propagates(self, service):
        """PARTIAL feedback with corrections yields a pack id."""
        processor, _bridge = self._wired_processor(service)

        outcome = processor.process_feedback(
            workflow_id='wf_partial',
            execution_id='exec_003',
            feedback_type=FeedbackType.PARTIAL,
            confidence=0.7,
            corrections={
                'action_type': 'fill',
                'corrected_params': {'value': 'corrected_value'},
            },
        )

        assert outcome['correction_pack_id'] is not None

    def test_processor_disabled_integration(self):
        """With integration switched off, feedback is recorded without a pack id."""
        processor = FeedbackProcessor(auto_integrate_corrections=False)

        outcome = processor.process_feedback(
            workflow_id='wf_disabled',
            execution_id='exec_004',
            feedback_type=FeedbackType.INCORRECT,
            confidence=0.3,
            corrections={'action_type': 'test'},
        )

        # The feedback is stored, but nothing is sent to a correction pack.
        assert outcome['feedback_recorded'] is True
        assert outcome['correction_pack_id'] is None
|
||||
|
||||
|
||||
class TestEndToEndFlow:
    """End-to-end scenarios covering the whole COACHING correction path."""

    @staticmethod
    def _pack_field(pack, field):
        """Read ``field`` from a pack returned either as a dict or as an object."""
        if isinstance(pack, dict):
            return pack.get(field)
        return getattr(pack, field)

    def test_coaching_correction_flow(self, temp_storage, service):
        """Two in-session corrections plus one from feedback all reach the pack."""
        # One integration is shared by the collector and the processor.
        bridge = CorrectionPackIntegration(service=service)

        collector = TrainingDataCollector(
            output_dir=str(temp_storage / 'training'),
            auto_integrate_corrections=True,
        )
        collector._correction_integration = bridge

        processor = FeedbackProcessor(auto_integrate_corrections=True)
        processor._correction_integration = bridge

        # Corrections the user makes while the workflow executes.
        in_session_corrections = (
            {
                'action_type': 'click',
                'element_type': 'button',
                'failure_reason': 'element_not_found',
                'correction_type': 'target_change',
                'original_target': {'text': 'OK'},
                'corrected_target': {'text': 'Valider'},
            },
            {
                'action_type': 'type',
                'element_type': 'input',
                'failure_reason': 'wrong_field',
                'correction_type': 'target_change',
                'original_target': {'id': 'email'},
                'corrected_target': {'name': 'user_email'},
            },
        )

        # Simulated COACHING session.
        collector.start_session('coaching_session_001', workflow_id='wf_coaching')
        for payload in in_session_corrections:
            collector.record_correction(payload)
        collector.end_session(success=True)

        # Feedback the user gives once the session is over.
        processor.process_feedback(
            workflow_id='wf_coaching',
            execution_id='coaching_session_001',
            feedback_type=FeedbackType.PARTIAL,
            confidence=0.75,
            corrections={
                'action_type': 'submit',
                'element_type': 'form',
                'corrected_params': {'wait_after': 2.0},
            },
        )

        # All three corrections must be in the default pack.
        default_pack = service.get_pack(bridge._default_pack_id)
        stored = self._pack_field(default_pack, 'corrections')
        assert len(stored) == 3

        # The pack statistics must report the same count.
        pack_id = self._pack_field(default_pack, 'id')
        summary = service.get_pack_statistics(pack_id)
        assert summary['total_corrections'] == 3

    def test_multiple_sessions_aggregation(self, temp_storage, service):
        """Corrections from two separate collectors accumulate in one pack."""
        bridge = CorrectionPackIntegration(service=service)

        # Each session gets its own collector and output directory; the second
        # session records the same kind of correction as the first.
        sessions = (
            ('training1', 'session_001'),
            ('training2', 'session_002'),
        )
        collectors = []  # keep every collector referenced until the test ends
        for subdir, session_id in sessions:
            collector = TrainingDataCollector(
                output_dir=str(temp_storage / subdir)
            )
            collector._correction_integration = bridge
            collectors.append(collector)

            collector.start_session(session_id, workflow_id='wf_multi')
            collector.record_correction({
                'action_type': 'click',
                'element_type': 'button',
                'failure_reason': 'timeout',
                'correction_type': 'timing_adjust',
            })
            collector.end_session(success=True)

        # Both corrections must have been captured.
        default_pack = service.get_pack(bridge._default_pack_id)
        stored = self._pack_field(default_pack, 'corrections')
        assert len(stored) == 2
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # pytest.main() returns the run's exit code instead of exiting itself.
    # Propagate it so that running this file as a script reports failures
    # to the shell/CI rather than always exiting with status 0.
    raise SystemExit(pytest.main([__file__, '-v']))
|
||||
Reference in New Issue
Block a user