"""
|
|
Lightweight fine-tuner for embedding models.
|
|
|
|
This module provides incremental fine-tuning capabilities that run in the
|
|
background, adapting the embedding model to user-specific workflows over time.
|
|
"""
|
|
|
|
import threading
|
|
import time
|
|
import pickle
|
|
import logging
|
|
from collections import deque
|
|
from pathlib import Path
|
|
from typing import List, Dict, Any, Optional
|
|
from PIL import Image
|
|
import numpy as np
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class LightweightFineTuner:
    """
    Lightweight fine-tuner for incremental model adaptation.

    Collects positive and negative examples from user interactions and
    periodically fine-tunes the embedding model to improve accuracy on
    user-specific workflows.

    Features:
    - Automatic triggering after N examples
    - Background training (non-blocking, daemon thread)
    - Checkpoint save/load for recovery
    - Metrics tracking
    """

    def __init__(
        self,
        embedder,
        trigger_threshold: int = 10,
        max_examples: int = 1000,
        checkpoint_dir: str = "data/fine_tuning"
    ):
        """
        Initialize the fine-tuner.

        Args:
            embedder: Embedder instance to fine-tune. If it lacks a
                ``fine_tune`` method, training runs are skipped.
            trigger_threshold: Number of new examples before triggering fine-tuning.
            max_examples: Maximum examples to keep per polarity (LRU eviction).
            checkpoint_dir: Directory for saving checkpoints (created if missing).
        """
        self.embedder = embedder
        self.trigger_threshold = trigger_threshold
        self.max_examples = max_examples
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Example storage: deque(maxlen=...) gives automatic LRU eviction.
        self.positive_examples = deque(maxlen=max_examples)
        self.negative_examples = deque(maxlen=max_examples)

        # Training state. _state_lock serializes the claim of the training
        # slot so two concurrent triggers cannot both spawn a thread.
        self._state_lock = threading.Lock()
        self.is_training = False
        self.training_thread: Optional[threading.Thread] = None
        self.last_training_time = 0
        self.training_count = 0

        # One metrics dict is appended per completed training run.
        self.metrics_history: List[Dict[str, Any]] = []

        logger.info(
            "LightweightFineTuner initialized: trigger=%s, max_examples=%s",
            trigger_threshold, max_examples
        )

    def add_positive_example(self, image: "Image.Image", workflow_id: str, metadata: Optional[Dict] = None):
        """
        Add a positive example (successful workflow execution).

        Args:
            image: Screenshot where workflow succeeded
            workflow_id: ID of the successful workflow
            metadata: Optional additional metadata
        """
        self._record_example(self.positive_examples, "positive", image, workflow_id, metadata)

    def add_negative_example(self, image: "Image.Image", workflow_id: str, metadata: Optional[Dict] = None):
        """
        Add a negative example (rejected workflow suggestion).

        Args:
            image: Screenshot where workflow was rejected
            workflow_id: ID of the rejected workflow
            metadata: Optional additional metadata
        """
        self._record_example(self.negative_examples, "negative", image, workflow_id, metadata)

    def _record_example(self, store: deque, polarity: str, image, workflow_id: str, metadata: Optional[Dict]):
        """Append a timestamped example to *store* and re-check the trigger."""
        store.append({
            'image': image,
            'workflow_id': workflow_id,
            'metadata': metadata or {},
            'timestamp': time.time()
        })
        logger.debug(
            "Added %s example: workflow=%s, total_%s=%d",
            polarity, workflow_id, polarity, len(store)
        )
        self._check_trigger()

    def _check_trigger(self):
        """Check if we should trigger fine-tuning."""
        # Don't trigger if already training (cheap early-out; the
        # authoritative check happens under the lock in _start_training).
        if self.is_training:
            logger.debug("Fine-tuning already in progress, skipping trigger check")
            return

        total_new = len(self.positive_examples) + len(self.negative_examples)
        if total_new >= self.trigger_threshold:
            logger.info(
                "Fine-tuning triggered: %d examples (%d positive, %d negative)",
                total_new, len(self.positive_examples), len(self.negative_examples)
            )
            self._start_training()

    def _start_training(self):
        """Start training in a background daemon thread (at most one at a time)."""
        # Claim the training slot atomically BEFORE launching the thread:
        # the original set is_training inside the worker, leaving a window
        # in which a second trigger could spawn a duplicate thread.
        with self._state_lock:
            if self.is_training:
                logger.warning("Training already in progress")
                return
            self.is_training = True

        self.training_thread = threading.Thread(
            target=self._train,
            name="FineTuningThread",
            daemon=True
        )
        self.training_thread.start()
        logger.info("Fine-tuning thread started")

    def _train(self):
        """Fine-tune the model (runs in the background thread)."""
        self.is_training = True  # no-op when launched via _start_training
        start_time = time.time()

        try:
            # Check if embedder supports fine-tuning.
            if not hasattr(self.embedder, 'fine_tune'):
                logger.info(
                    "Embedder %s doesn't support fine-tuning, skipping",
                    self.embedder.get_model_name()
                )
                return

            # Snapshot the deques before iterating: producer threads may
            # append concurrently, and mutating a deque during iteration
            # raises RuntimeError.
            positive_images = [ex['image'] for ex in list(self.positive_examples)]
            negative_images = [ex['image'] for ex in list(self.negative_examples)]

            if not positive_images and not negative_images:
                logger.warning("No examples to train on")
                return

            logger.info(
                "Starting fine-tuning: %d positive, %d negative examples",
                len(positive_images), len(negative_images)
            )

            # Fine-tune (implementation depends on embedder).
            metrics = self.embedder.fine_tune(
                positive_images=positive_images,
                negative_images=negative_images,
                epochs=1,
                learning_rate=1e-4
            )

            # Record metrics for this run.
            duration = time.time() - start_time
            metrics['duration_seconds'] = duration
            metrics['timestamp'] = time.time()
            metrics['positive_count'] = len(positive_images)
            metrics['negative_count'] = len(negative_images)
            metrics['training_number'] = self.training_count

            self.metrics_history.append(metrics)
            self.last_training_time = time.time()
            self.training_count += 1

            # Format loss defensively: the original applied :.4f to the
            # 'N/A' fallback string, raising ValueError whenever the
            # embedder returned no 'loss' key.
            loss = metrics.get('loss')
            loss_str = f"{loss:.4f}" if isinstance(loss, (int, float)) else "N/A"
            logger.info(
                "Fine-tuning complete #%d: loss=%s, duration=%.1fs",
                self.training_count, loss_str, duration
            )

            # Clear examples only after successful training.
            self.positive_examples.clear()
            self.negative_examples.clear()
            logger.debug("Training examples cleared")

        except Exception as e:
            logger.error(f"Fine-tuning failed: {e}", exc_info=True)

        finally:
            self.is_training = False

    def save_checkpoint(self, name: str = "checkpoint"):
        """
        Save training examples and metrics for recovery.

        Args:
            name: Checkpoint name
        """
        try:
            checkpoint_path = self.checkpoint_dir / f"{name}.pkl"

            data = {
                'positive_examples': list(self.positive_examples),
                'negative_examples': list(self.negative_examples),
                'metrics_history': self.metrics_history,
                'training_count': self.training_count,
                'last_training_time': self.last_training_time
            }

            with open(checkpoint_path, 'wb') as f:
                pickle.dump(data, f)

            logger.info(f"Checkpoint saved: {checkpoint_path}")

        except Exception as e:
            logger.error(f"Failed to save checkpoint: {e}")

    def load_checkpoint(self, name: str = "checkpoint") -> bool:
        """
        Load training examples and metrics from checkpoint.

        Args:
            name: Checkpoint name

        Returns:
            bool: True if loaded successfully, False otherwise
        """
        try:
            checkpoint_path = self.checkpoint_dir / f"{name}.pkl"

            if not checkpoint_path.exists():
                logger.warning(f"Checkpoint not found: {checkpoint_path}")
                return False

            # SECURITY: pickle.load executes arbitrary code from the file;
            # only load checkpoints written by this process/trusted sources.
            with open(checkpoint_path, 'rb') as f:
                data = pickle.load(f)

            # extend() preserves any examples collected before the load;
            # the deque maxlen still enforces the LRU cap.
            self.positive_examples.extend(data.get('positive_examples', []))
            self.negative_examples.extend(data.get('negative_examples', []))
            self.metrics_history = data.get('metrics_history', [])
            self.training_count = data.get('training_count', 0)
            self.last_training_time = data.get('last_training_time', 0)

            logger.info(
                "Checkpoint loaded: %d positive, %d negative examples",
                len(self.positive_examples), len(self.negative_examples)
            )
            return True

        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            return False

    def get_stats(self) -> Dict[str, Any]:
        """
        Get fine-tuning statistics.

        Returns:
            Dict with example counts, training state, and metrics history.
        """
        return {
            'positive_examples': len(self.positive_examples),
            'negative_examples': len(self.negative_examples),
            'total_examples': len(self.positive_examples) + len(self.negative_examples),
            'is_training': self.is_training,
            'training_count': self.training_count,
            'last_training_time': self.last_training_time,
            'metrics_history': self.metrics_history,
            'trigger_threshold': self.trigger_threshold
        }

    def wait_for_training(self, timeout: Optional[float] = None):
        """
        Wait for the current training thread to complete.

        Args:
            timeout: Maximum time to wait in seconds (None = wait forever)
        """
        if self.training_thread and self.training_thread.is_alive():
            self.training_thread.join(timeout=timeout)

    def __repr__(self) -> str:
        """String representation."""
        stats = self.get_stats()
        return (
            f"LightweightFineTuner("
            f"examples={stats['total_examples']}, "
            f"trainings={stats['training_count']}, "
            f"is_training={stats['is_training']})"
        )