#!/usr/bin/env python3
"""Test Training System - Phase 8

Smoke tests for the ``core.training`` package: data collection, offline
training and model validation.

The file can be run directly (``python <this file>``), in which case ``main``
runs every test and returns a process exit code. It can also be collected by
pytest. Each ``test_*`` function raises on failure and returns ``None``, so
pytest reports failures correctly instead of treating a returned ``False`` as
a pass.
"""
import logging
import sys
import tempfile
from pathlib import Path

# Put the directory two levels up (the parent of this file's directory) on
# sys.path so that ``core.training`` is importable when this script is run
# directly.
sys.path.insert(0, str(Path(__file__).parent.parent))

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def test_training_data_collector():
    """Record one simulated session and export it as a training set.

    Raises:
        Exception: any error raised by ``TrainingDataCollector`` propagates
            unchanged, so the caller (``main`` or pytest) sees the failure.
    """
    logger.info("\n=== Testing TrainingDataCollector ===")
    from core.training.training_data_collector import TrainingDataCollector

    # Use a throwaway directory so repeated runs don't litter the CWD.
    with tempfile.TemporaryDirectory() as tmp:
        collector = TrainingDataCollector(
            output_dir=str(Path(tmp) / "test_training_data")
        )
        logger.info("✓ TrainingDataCollector created")

        # Simulate collecting data. The paths are placeholders; nothing here
        # shows whether the collector requires them to exist.
        collector.start_session("session_001", workflow_id="email_workflow")
        collector.record_screenshot("/path/to/screenshot1.png")
        collector.record_action({'type': 'click', 'target': 'compose_button'})
        collector.record_embedding("/path/to/embedding1.npy")
        collector.end_session(success=True, metadata={'duration_ms': 1500})
        logger.info("✓ Session recorded")

        # Export training set.
        # NOTE(review): presumably written under output_dir; confirm against
        # TrainingDataCollector.export_training_set.
        training_set = collector.export_training_set("test_training_set.json")
        logger.info(
            f"✓ Training set exported: "
            f"{training_set['metadata']['total_sessions']} sessions"
        )


def test_offline_trainer():
    """Train prototypes and thresholds on synthetic data, then validate.

    Raises:
        Exception: any error from ``OfflineTrainer`` propagates, including
            a ``KeyError`` if ``validate_model`` omits ``'accuracy'``.
    """
    logger.info("\n=== Testing OfflineTrainer ===")
    from core.training.offline_trainer import OfflineTrainer, TrainingConfig

    config = TrainingConfig(
        learning_rate=0.001,
        num_epochs=5,
        min_samples_per_workflow=3,
    )
    trainer = OfflineTrainer(config)
    logger.info("✓ OfflineTrainer created")

    # Create dummy training data: 10 successful sessions for one workflow,
    # which exceeds min_samples_per_workflow=3 above.
    dummy_data = {
        'metadata': {
            'total_sessions': 10,
            'total_patterns': 2,
        },
        'sessions': [
            {
                'session_id': f'session_{i}',
                'workflow_id': 'test_workflow',
                'timestamp': '2024-11-23T12:00:00',
                'success': True,
                'actions': [],
                'embeddings': [],
            }
            for i in range(10)
        ],
        'patterns': [],
    }

    # Train prototypes
    prototypes = trainer.train_prototypes(dummy_data)
    logger.info(f"✓ Prototypes trained: {len(prototypes)} workflows")

    # Train thresholds
    thresholds = trainer.train_thresholds(dummy_data)
    logger.info(f"✓ Thresholds trained: {len(thresholds)} workflows")

    # Validate
    metrics = trainer.validate_model(dummy_data)
    logger.info(f"✓ Model validated: accuracy={metrics['accuracy']:.2%}")


def test_model_validator():
    """Check that ``ModelValidator`` can be imported and constructed.

    A full validation run needs a trained model, which this test does not
    build.
    """
    logger.info("\n=== Testing ModelValidator ===")
    from core.training.model_validator import ModelValidator

    ModelValidator(min_accuracy=0.80)
    logger.info("✓ ModelValidator created")
    logger.info("✓ Validator ready (requires trained model for full test)")


def test_complete_workflow():
    """Exercise the collect -> train -> validate pipeline end to end.

    Only data collection is actually performed. The training and validation
    steps currently just check that ``OfflineTrainer`` can be constructed
    with its defaults.
    """
    logger.info("\n=== Testing Complete Training Workflow ===")
    from core.training.training_data_collector import TrainingDataCollector
    from core.training.offline_trainer import OfflineTrainer

    with tempfile.TemporaryDirectory() as tmp:
        # Step 1: Collect data
        collector = TrainingDataCollector(
            output_dir=str(Path(tmp) / "workflow_test")
        )
        for i in range(5):
            collector.start_session(f"session_{i}", workflow_id="test_wf")
            collector.record_action({'type': 'click'})
            collector.end_session(success=True)
        collector.export_training_set("training_set.json")
        logger.info("✓ Step 1: Data collected")

        # Step 2: Train model. Only construction is exercised for now;
        # real training would consume the exported set above.
        OfflineTrainer()
        logger.info("✓ Step 2: Model training ready")

        # Step 3: Validate (placeholder)
        logger.info("✓ Step 3: Validation ready")

    logger.info("✓ Complete workflow tested")


def main():
    """Run every test, log a summary, and return a process exit code.

    Each test signals failure by raising. The failure is logged, and the
    loop continues so that one broken component does not hide the others.

    Returns:
        0 if every test passed, 1 otherwise.
    """
    logger.info("=" * 60)
    logger.info("Phase 8 - Training System Tests")
    logger.info("=" * 60)

    tests = [
        test_training_data_collector,
        test_offline_trainer,
        test_model_validator,
        test_complete_workflow,
    ]

    results = []
    for test in tests:
        try:
            test()
        except Exception as e:
            logger.error(f"✗ Test failed: {e}", exc_info=True)
            results.append(False)
        else:
            results.append(True)

    passed = sum(results)
    logger.info(f"\n{'='*60}\nResults: {passed}/{len(results)} tests passed\n{'='*60}")
    return 0 if passed == len(results) else 1


if __name__ == '__main__':
    sys.exit(main())