""" Lightweight fine-tuner for embedding models. This module provides incremental fine-tuning capabilities that run in the background, adapting the embedding model to user-specific workflows over time. """ import threading import time import pickle import logging from collections import deque from pathlib import Path from typing import List, Dict, Any, Optional from PIL import Image import numpy as np logger = logging.getLogger(__name__) class LightweightFineTuner: """ Lightweight fine-tuner for incremental model adaptation. This class collects positive and negative examples from user interactions and periodically fine-tunes the embedding model to improve accuracy on user-specific workflows. Features: - Automatic triggering after N examples - Background training (non-blocking) - Checkpoint save/load for recovery - Metrics tracking """ def __init__( self, embedder, trigger_threshold: int = 10, max_examples: int = 1000, checkpoint_dir: str = "data/fine_tuning" ): """ Initialize the fine-tuner. Args: embedder: Embedder instance to fine-tune (must support fine_tune method) trigger_threshold: Number of new examples before triggering fine-tuning max_examples: Maximum examples to keep (LRU eviction) checkpoint_dir: Directory for saving checkpoints """ self.embedder = embedder self.trigger_threshold = trigger_threshold self.max_examples = max_examples self.checkpoint_dir = Path(checkpoint_dir) self.checkpoint_dir.mkdir(parents=True, exist_ok=True) # Example storage (deque for automatic LRU) self.positive_examples = deque(maxlen=max_examples) self.negative_examples = deque(maxlen=max_examples) # Training state self.is_training = False self.training_thread: Optional[threading.Thread] = None self.last_training_time = 0 self.training_count = 0 # Metrics self.metrics_history: List[Dict[str, Any]] = [] logger.info( f"LightweightFineTuner initialized: " f"trigger={trigger_threshold}, max_examples={max_examples}" ) def add_positive_example(self, image: Image.Image, workflow_id: str, metadata: Optional[Dict] = None): """ Add a positive example (successful workflow execution). Args: image: Screenshot where workflow succeeded workflow_id: ID of the successful workflow metadata: Optional additional metadata """ example = { 'image': image, 'workflow_id': workflow_id, 'metadata': metadata or {}, 'timestamp': time.time() } self.positive_examples.append(example) logger.debug( f"Added positive example: workflow={workflow_id}, " f"total_positive={len(self.positive_examples)}" ) self._check_trigger() def add_negative_example(self, image: Image.Image, workflow_id: str, metadata: Optional[Dict] = None): """ Add a negative example (rejected workflow suggestion). 
        Args:
            image: Screenshot where the workflow was rejected
            workflow_id: ID of the rejected workflow
            metadata: Optional additional metadata
        """
        example = {
            'image': image,
            'workflow_id': workflow_id,
            'metadata': metadata or {},
            'timestamp': time.time()
        }
        self.negative_examples.append(example)

        logger.debug(
            f"Added negative example: workflow={workflow_id}, "
            f"total_negative={len(self.negative_examples)}"
        )

        self._check_trigger()

    def _check_trigger(self):
        """Check whether fine-tuning should be triggered."""
        total_new = len(self.positive_examples) + len(self.negative_examples)

        # Don't trigger if already training
        if self.is_training:
            logger.debug("Fine-tuning already in progress, skipping trigger check")
            return

        # Check if we have enough examples
        if total_new >= self.trigger_threshold:
            logger.info(
                f"Fine-tuning triggered: {total_new} examples "
                f"({len(self.positive_examples)} positive, "
                f"{len(self.negative_examples)} negative)"
            )
            self._start_training()

    def _start_training(self):
        """Start training in a background thread."""
        if self.is_training:
            logger.warning("Training already in progress")
            return

        self.training_thread = threading.Thread(
            target=self._train,
            name="FineTuningThread",
            daemon=True
        )
        self.training_thread.start()

        logger.info("Fine-tuning thread started")

    def _train(self):
        """Fine-tune the model (runs in a background thread)."""
        self.is_training = True
        start_time = time.time()

        try:
            # Check if the embedder supports fine-tuning
            if not hasattr(self.embedder, 'fine_tune'):
                logger.info(
                    f"Embedder {self.embedder.get_model_name()} doesn't support "
                    f"fine-tuning, skipping"
                )
                return

            # Prepare training data
            positive_images = [ex['image'] for ex in self.positive_examples]
            negative_images = [ex['image'] for ex in self.negative_examples]

            if not positive_images and not negative_images:
                logger.warning("No examples to train on")
                return

            logger.info(
                f"Starting fine-tuning: {len(positive_images)} positive, "
                f"{len(negative_images)} negative examples"
            )

            # Fine-tune (implementation depends on the embedder)
            metrics = self.embedder.fine_tune(
                positive_images=positive_images,
                negative_images=negative_images,
                epochs=1,
                learning_rate=1e-4
            )

            # Record metrics
            duration = time.time() - start_time
            metrics['duration_seconds'] = duration
            metrics['timestamp'] = time.time()
            metrics['positive_count'] = len(positive_images)
            metrics['negative_count'] = len(negative_images)
            metrics['training_number'] = self.training_count

            self.metrics_history.append(metrics)
            self.last_training_time = time.time()
            self.training_count += 1

            # Only apply float formatting when a numeric loss was reported;
            # formatting the 'N/A' fallback with :.4f would raise ValueError.
            loss = metrics.get('loss')
            loss_str = f"{loss:.4f}" if isinstance(loss, (int, float)) else "N/A"
            logger.info(
                f"Fine-tuning complete #{self.training_count}: "
                f"loss={loss_str}, duration={duration:.1f}s"
            )

            # Clear examples after successful training
            self.positive_examples.clear()
            self.negative_examples.clear()
            logger.debug("Training examples cleared")

        except Exception as e:
            logger.error(f"Fine-tuning failed: {e}", exc_info=True)

        finally:
            self.is_training = False

    def save_checkpoint(self, name: str = "checkpoint"):
        """
        Save training examples and metrics for recovery.
        Args:
            name: Checkpoint name
        """
        try:
            checkpoint_path = self.checkpoint_dir / f"{name}.pkl"

            data = {
                'positive_examples': list(self.positive_examples),
                'negative_examples': list(self.negative_examples),
                'metrics_history': self.metrics_history,
                'training_count': self.training_count,
                'last_training_time': self.last_training_time
            }

            with open(checkpoint_path, 'wb') as f:
                pickle.dump(data, f)

            logger.info(f"Checkpoint saved: {checkpoint_path}")

        except Exception as e:
            logger.error(f"Failed to save checkpoint: {e}")

    def load_checkpoint(self, name: str = "checkpoint") -> bool:
        """
        Load training examples and metrics from a checkpoint.

        Args:
            name: Checkpoint name

        Returns:
            bool: True if loaded successfully, False otherwise
        """
        try:
            checkpoint_path = self.checkpoint_dir / f"{name}.pkl"

            if not checkpoint_path.exists():
                logger.warning(f"Checkpoint not found: {checkpoint_path}")
                return False

            with open(checkpoint_path, 'rb') as f:
                data = pickle.load(f)

            self.positive_examples.extend(data.get('positive_examples', []))
            self.negative_examples.extend(data.get('negative_examples', []))
            self.metrics_history = data.get('metrics_history', [])
            self.training_count = data.get('training_count', 0)
            self.last_training_time = data.get('last_training_time', 0)

            logger.info(
                f"Checkpoint loaded: {len(self.positive_examples)} positive, "
                f"{len(self.negative_examples)} negative examples"
            )
            return True

        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            return False

    def get_stats(self) -> Dict[str, Any]:
        """
        Get fine-tuning statistics.

        Returns:
            Dict with statistics
        """
        return {
            'positive_examples': len(self.positive_examples),
            'negative_examples': len(self.negative_examples),
            'total_examples': len(self.positive_examples) + len(self.negative_examples),
            'is_training': self.is_training,
            'training_count': self.training_count,
            'last_training_time': self.last_training_time,
            'metrics_history': self.metrics_history,
            'trigger_threshold': self.trigger_threshold
        }

    def wait_for_training(self, timeout: Optional[float] = None):
        """
        Wait for the current training run to complete.

        Args:
            timeout: Maximum time to wait in seconds (None = wait forever)
        """
        if self.training_thread and self.training_thread.is_alive():
            self.training_thread.join(timeout=timeout)

    def __repr__(self) -> str:
        """String representation."""
        stats = self.get_stats()
        return (
            f"LightweightFineTuner("
            f"examples={stats['total_examples']}, "
            f"trainings={stats['training_count']}, "
            f"is_training={stats['is_training']})"
        )
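

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the production code path).
#
# `_DummyEmbedder` below is a hypothetical stand-in: LightweightFineTuner only
# requires an object exposing fine_tune(positive_images, negative_images,
# epochs, learning_rate) that returns a metrics dict, plus get_model_name()
# used in log messages. Any real embedder with that interface can be
# substituted; the threshold, image sizes, and temp checkpoint directory are
# arbitrary demo values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import tempfile

    class _DummyEmbedder:
        """Minimal stand-in that satisfies the fine_tune interface."""

        def get_model_name(self) -> str:
            return "dummy-embedder"

        def fine_tune(self, positive_images, negative_images,
                      epochs: int = 1, learning_rate: float = 1e-4) -> Dict[str, Any]:
            # A real embedder would run a gradient update over the examples
            # here; this stub just reports a placeholder loss.
            return {'loss': 0.0, 'epochs': epochs, 'learning_rate': learning_rate}

    tuner = LightweightFineTuner(
        _DummyEmbedder(),
        trigger_threshold=4,
        checkpoint_dir=tempfile.mkdtemp(prefix="fine_tuning_demo_")
    )

    # Feed synthetic examples; training triggers automatically once the
    # threshold is reached and runs on a background thread.
    for i in range(4):
        img = Image.new('RGB', (64, 64), color=(i * 40, 0, 0))
        if i % 2 == 0:
            tuner.add_positive_example(img, workflow_id=f"wf-{i}")
        else:
            tuner.add_negative_example(img, workflow_id=f"wf-{i}")

    tuner.wait_for_training(timeout=10)
    tuner.save_checkpoint()
    print(tuner.get_stats())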