feat(corrections): Add automatic COACHING integration for Correction Packs

- Add CorrectionPackIntegration class to bridge learning components
- Modify TrainingDataCollector to auto-propagate corrections to packs
- Modify FeedbackProcessor to capture corrections on INCORRECT/PARTIAL feedback
- Add convenience functions: get_correction_pack_integration(), capture_coaching_correction()
- Add 19 integration tests (all passing)

Corrections made during COACHING mode are now automatically captured
into a dedicated "auto_captured_corrections" pack for cross-workflow reuse.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Dom
2026-01-18 19:06:09 +01:00
parent d8756883c5
commit efb184fdb9
5 changed files with 1206 additions and 1 deletions

View File

@@ -0,0 +1,486 @@
"""
Tests for Correction Pack Integration with learning components.
Tests the automatic capture of corrections from COACHING mode into
the Correction Pack system.
"""
import pytest
import tempfile
import shutil
from pathlib import Path
from datetime import datetime
from unittest.mock import patch, MagicMock
from core.corrections import (
CorrectionPackIntegration,
CorrectionPackService,
get_correction_pack_integration,
capture_coaching_correction
)
from core.corrections.models import CorrectionType
from core.training.training_data_collector import TrainingDataCollector
from core.learning.feedback_processor import FeedbackProcessor, FeedbackType
@pytest.fixture
def temp_storage():
    """Yield a fresh scratch directory as a ``Path`` and remove it afterwards."""
    scratch_root = Path(tempfile.mkdtemp())
    yield scratch_root
    # Best-effort cleanup: errors raised while deleting are deliberately ignored.
    shutil.rmtree(scratch_root, ignore_errors=True)
@pytest.fixture
def service(temp_storage):
    """Provide a CorrectionPackService whose storage path is the scratch directory."""
    pack_service = CorrectionPackService(storage_path=temp_storage)
    return pack_service
@pytest.fixture
def integration(service):
    """Provide a CorrectionPackIntegration over the test service, auto-create enabled."""
    bridge = CorrectionPackIntegration(service=service, auto_create_pack=True)
    return bridge
class TestCorrectionPackIntegration:
    """Unit tests exercising the CorrectionPackIntegration class directly."""

    @staticmethod
    def _field(pack, key):
        """Read ``key`` from a pack that the service may return as a dict or an object."""
        if isinstance(pack, dict):
            return pack.get(key)
        return getattr(pack, key)

    def test_init(self, service):
        """A new integration keeps its service, has auto-create on, and starts uninitialized."""
        bridge = CorrectionPackIntegration(service=service)

        assert bridge._service is service
        assert bridge._auto_create_pack is True
        assert bridge._initialized is False

    def test_ensure_initialized_creates_pack(self, integration):
        """Initializing on demand creates the default pack and records its id."""
        default_pack_id = integration._ensure_initialized()

        assert default_pack_id is not None
        assert integration._initialized is True
        assert integration._default_pack_id == default_pack_id

        # The pack must be retrievable from the service under the default name.
        stored_pack = integration.service.get_pack(default_pack_id)
        assert stored_pack is not None
        stored_name = self._field(stored_pack, 'name')
        assert stored_name == CorrectionPackIntegration.DEFAULT_PACK_NAME

    def test_capture_correction(self, integration):
        """A fully described correction is stored in the default pack."""
        new_id = integration.capture_correction(
            correction_data={
                'action_type': 'click',
                'element_type': 'button',
                'failure_reason': 'element_not_found',
                'correction_type': 'target_change',
                'original_target': {'text': 'Submit'},
                'corrected_target': {'text': 'Envoyer'},
            },
            session_id='test_session_1',
            workflow_id='wf_001',
        )
        assert new_id is not None

        # Exactly one correction should now live in the default pack.
        stored_pack = integration.service.get_pack(integration._default_pack_id)
        assert len(self._field(stored_pack, 'corrections')) == 1

    def test_capture_correction_with_minimal_data(self, integration):
        """A correction carrying only a type and a new target is still accepted."""
        sparse_payload = {
            'type': 'input',
            'new_target': {'xpath': '//input[@name="email"]'},
        }

        new_id = integration.capture_correction(
            correction_data=sparse_payload,
            session_id='session_2',
        )

        assert new_id is not None

    def test_capture_feedback_correction(self, integration):
        """Corrections supplied through the feedback entry point yield an id."""
        new_id = integration.capture_feedback_correction(
            workflow_id='wf_feedback_test',
            execution_id='exec_001',
            corrections={
                'action_type': 'type',
                'element_type': 'input',
                'original_target': {'id': 'username'},
                'corrected_target': {'name': 'user_name'},
            },
            context={'failure_reason': 'id_changed'},
        )

        assert new_id is not None

    def test_infer_correction_type(self, integration):
        """Each payload shape maps to the expected CorrectionType value."""
        # (payload, expected type) pairs, checked in this order; the empty
        # payload is expected to fall back to OTHER.
        expectations = [
            ({'corrected_target': {'id': 'new'}}, CorrectionType.TARGET_CHANGE),
            ({'corrected_params': {'timeout': 5}}, CorrectionType.PARAMETER_CHANGE),
            ({'wait_time': 2.0}, CorrectionType.TIMING_ADJUST),
            ({'coordinates': {'x': 100, 'y': 200}}, CorrectionType.COORDINATES_ADJUST),
            ({}, CorrectionType.OTHER),
        ]

        for payload, expected in expectations:
            assert integration._infer_correction_type(payload) == expected.value

    def test_create_hook_for_collector(self, integration):
        """The collector hook is callable and stores what it is handed."""
        collector_hook = integration.create_hook_for_collector()
        assert callable(collector_hook)

        # Feed one correction through the hook.
        collector_hook({
            'action_type': 'click',
            'element_type': 'button',
            'correction_type': 'target_change',
        })

        # It should have landed in the default pack.
        stored_pack = integration.service.get_pack(integration._default_pack_id)
        assert len(self._field(stored_pack, 'corrections')) == 1

    def test_get_statistics(self, integration):
        """Statistics report the number of corrections captured so far."""
        # Capture three corrections, each from its own session.
        for index in range(3):
            payload = {
                'action_type': f'action_{index}',
                'element_type': 'button',
            }
            integration.capture_correction(
                correction_data=payload,
                session_id=f'session_{index}',
            )

        summary = integration.get_statistics()
        assert summary['total_corrections'] == 3
class TestGlobalIntegration:
    """Tests for the module-level integration accessors."""

    # Patch target holding the shared integration instance; patched to None so
    # each test starts without a previously created instance.
    _GLOBAL_TARGET = 'core.corrections.integration._global_integration'

    # NOTE(review): both tests request ``temp_storage`` but never pass it on, so
    # the global instance presumably uses its default storage location —
    # confirm this does not write into real project data.

    def test_get_correction_pack_integration(self, temp_storage):
        """Two consecutive lookups hand back the very same instance."""
        with patch(self._GLOBAL_TARGET, None):
            first = get_correction_pack_integration()
            second = get_correction_pack_integration()

            assert first is second

    def test_capture_coaching_correction(self, temp_storage):
        """The convenience capture function returns an id for a simple correction."""
        payload = {
            'action_type': 'click',
            'element_type': 'button',
        }

        with patch(self._GLOBAL_TARGET, None):
            new_id = capture_coaching_correction(
                correction_data=payload,
                session_id='test_session',
            )

            # Expected to succeed without any pack having been set up beforehand.
            assert new_id is not None
class TestTrainingDataCollectorIntegration:
    """Tests for how TrainingDataCollector forwards corrections."""

    def test_collector_with_auto_integration(self, temp_storage, service):
        """With auto integration on, a recorded correction reaches the pack."""
        training_dir = str(temp_storage / 'training')
        collector = TrainingDataCollector(
            output_dir=training_dir,
            auto_integrate_corrections=True,
        )

        # Replace the collector's integration with one backed by the test service.
        bridge = CorrectionPackIntegration(service=service)
        collector._correction_integration = bridge

        # One session containing a single correction.
        collector.start_session('session_auto_test', workflow_id='wf_auto')
        collector.record_correction({
            'action_type': 'click',
            'element_type': 'button',
            'failure_reason': 'element_moved',
            'correction_type': 'coordinates_adjust',
            'original_target': {'x': 100, 'y': 100},
            'corrected_target': {'x': 150, 'y': 120},
        })
        collector.end_session(success=True)

        # The correction must now be present in the default pack.
        stored_pack = service.get_pack(bridge._default_pack_id)
        assert stored_pack is not None
        if isinstance(stored_pack, dict):
            captured = stored_pack.get('corrections')
        else:
            captured = stored_pack.corrections
        assert len(captured) == 1

    def test_collector_with_custom_callback(self, temp_storage):
        """A user-supplied callback receives the correction with its session context."""
        seen = []

        def remember(correction, session_id, workflow_id):
            # Keep every invocation so the test can inspect the arguments.
            entry = {
                'correction': correction,
                'session_id': session_id,
                'workflow_id': workflow_id,
            }
            seen.append(entry)

        collector = TrainingDataCollector(
            output_dir=str(temp_storage / 'training'),
            correction_callback=remember,
            auto_integrate_corrections=False,
        )
        collector.start_session('session_callback', workflow_id='wf_callback')
        collector.record_correction({'type': 'test_correction'})

        assert len(seen) == 1
        only_call = seen[0]
        assert only_call['session_id'] == 'session_callback'
        assert only_call['workflow_id'] == 'wf_callback'

    def test_collector_disabled_integration(self, temp_storage):
        """With integration off, corrections are still kept on the session."""
        collector = TrainingDataCollector(
            output_dir=str(temp_storage / 'training'),
            auto_integrate_corrections=False,
        )
        collector.start_session('session_disabled')

        # Recording must not raise even though no integration is attached.
        collector.record_correction({'type': 'test'})

        session_corrections = collector.current_session.user_corrections
        assert len(session_corrections) == 1
class TestFeedbackProcessorIntegration:
    """Tests for how FeedbackProcessor forwards corrections."""

    @staticmethod
    def _wired_processor(service):
        """Build an auto-integrating processor bound to an integration over ``service``.

        Returns the ``(processor, integration)`` pair.
        """
        processor = FeedbackProcessor(auto_integrate_corrections=True)
        # Swap in an integration backed by the test service.
        bridge = CorrectionPackIntegration(service=service)
        processor._correction_integration = bridge
        return processor, bridge

    def test_processor_with_auto_integration(self, service):
        """INCORRECT feedback with corrections is recorded and reaches the pack."""
        processor, bridge = self._wired_processor(service)

        outcome = processor.process_feedback(
            workflow_id='wf_feedback',
            execution_id='exec_001',
            feedback_type=FeedbackType.INCORRECT,
            confidence=0.5,
            corrections={
                'action_type': 'click',
                'element_type': 'link',
                'original_target': {'text': 'More'},
                'corrected_target': {'text': 'Plus'},
            },
            context={'failure_reason': 'text_changed'},
        )

        assert outcome['feedback_recorded'] is True
        assert outcome['correction_pack_id'] is not None

        # The correction must be present in the default pack.
        stored_pack = service.get_pack(bridge._default_pack_id)
        if isinstance(stored_pack, dict):
            captured = stored_pack.get('corrections')
        else:
            captured = stored_pack.corrections
        assert len(captured) == 1

    def test_processor_correct_feedback_no_propagation(self, service):
        """CORRECT feedback yields no correction pack id, even with correction data."""
        processor, _bridge = self._wired_processor(service)

        outcome = processor.process_feedback(
            workflow_id='wf_correct',
            execution_id='exec_002',
            feedback_type=FeedbackType.CORRECT,  # nothing needed correcting
            confidence=0.95,
            corrections={'some': 'data'},  # supplied anyway; must be ignored
        )

        assert outcome['correction_pack_id'] is None

    def test_processor_partial_feedback_propagates(self, service):
        """PARTIAL feedback with corrections yields a correction pack id."""
        processor, _bridge = self._wired_processor(service)

        outcome = processor.process_feedback(
            workflow_id='wf_partial',
            execution_id='exec_003',
            feedback_type=FeedbackType.PARTIAL,
            confidence=0.7,
            corrections={
                'action_type': 'fill',
                'corrected_params': {'value': 'corrected_value'},
            },
        )

        assert outcome['correction_pack_id'] is not None

    def test_processor_disabled_integration(self):
        """With integration off, feedback is recorded but no pack id is returned."""
        processor = FeedbackProcessor(auto_integrate_corrections=False)

        outcome = processor.process_feedback(
            workflow_id='wf_disabled',
            execution_id='exec_004',
            feedback_type=FeedbackType.INCORRECT,
            confidence=0.3,
            corrections={'action_type': 'test'},
        )

        assert outcome['feedback_recorded'] is True
        assert outcome['correction_pack_id'] is None
class TestEndToEndFlow:
    """End-to-end scenarios covering the whole COACHING correction path."""

    @staticmethod
    def _field(pack, key):
        """Read ``key`` from a pack that the service may return as a dict or an object."""
        if isinstance(pack, dict):
            return pack.get(key)
        return getattr(pack, key)

    def test_coaching_correction_flow(self, temp_storage, service):
        """Corrections from a session plus final feedback all land in one pack."""
        # Wire a collector and a processor to the same integration.
        bridge = CorrectionPackIntegration(service=service)
        collector = TrainingDataCollector(
            output_dir=str(temp_storage / 'training'),
            auto_integrate_corrections=True,
        )
        collector._correction_integration = bridge
        processor = FeedbackProcessor(auto_integrate_corrections=True)
        processor._correction_integration = bridge

        # A COACHING session in which the user corrects two actions.
        collector.start_session('coaching_session_001', workflow_id='wf_coaching')
        in_session_corrections = (
            {
                'action_type': 'click',
                'element_type': 'button',
                'failure_reason': 'element_not_found',
                'correction_type': 'target_change',
                'original_target': {'text': 'OK'},
                'corrected_target': {'text': 'Valider'},
            },
            {
                'action_type': 'type',
                'element_type': 'input',
                'failure_reason': 'wrong_field',
                'correction_type': 'target_change',
                'original_target': {'id': 'email'},
                'corrected_target': {'name': 'user_email'},
            },
        )
        for payload in in_session_corrections:
            collector.record_correction(payload)
        collector.end_session(success=True)

        # Closing feedback contributes a third correction.
        processor.process_feedback(
            workflow_id='wf_coaching',
            execution_id='coaching_session_001',
            feedback_type=FeedbackType.PARTIAL,
            confidence=0.75,
            corrections={
                'action_type': 'submit',
                'element_type': 'form',
                'corrected_params': {'wait_after': 2.0},
            },
        )

        # All three must be in the default pack...
        stored_pack = service.get_pack(bridge._default_pack_id)
        assert len(self._field(stored_pack, 'corrections')) == 3

        # ...and the pack statistics must agree.
        pack_stats = service.get_pack_statistics(self._field(stored_pack, 'id'))
        assert pack_stats['total_corrections'] == 3

    def test_multiple_sessions_aggregation(self, temp_storage, service):
        """Same-shaped corrections from two sessions are both kept in the pack."""
        bridge = CorrectionPackIntegration(service=service)

        # Two independent collectors, run one after the other, record the same
        # kind of correction for the same workflow.
        runs = (
            ('training1', 'session_001'),
            ('training2', 'session_002'),
        )
        for dir_name, session_id in runs:
            collector = TrainingDataCollector(
                output_dir=str(temp_storage / dir_name)
            )
            collector._correction_integration = bridge
            collector.start_session(session_id, workflow_id='wf_multi')
            collector.record_correction({
                'action_type': 'click',
                'element_type': 'button',
                'failure_reason': 'timeout',
                'correction_type': 'timing_adjust',
            })
            collector.end_session(success=True)

        # Both corrections must be stored.
        stored_pack = service.get_pack(bridge._default_pack_id)
        assert len(self._field(stored_pack, 'corrections')) == 2
if __name__ == '__main__':
    # Running this file directly hands it to pytest in verbose mode.
    cli_args = [__file__, '-v']
    pytest.main(cli_args)