Initial commit
This commit is contained in:
321
geniusia2/core/embedders/fine_tuner.py
Normal file
321
geniusia2/core/embedders/fine_tuner.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""
|
||||
Lightweight fine-tuner for embedding models.
|
||||
|
||||
This module provides incremental fine-tuning capabilities that run in the
|
||||
background, adapting the embedding model to user-specific workflows over time.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
import pickle
|
||||
import logging
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LightweightFineTuner:
    """
    Lightweight fine-tuner for incremental model adaptation.

    Collects positive and negative examples from user interactions and
    periodically fine-tunes the embedding model in a background thread to
    improve accuracy on user-specific workflows.

    Features:
    - Automatic triggering after N examples
    - Background training (non-blocking)
    - Checkpoint save/load for recovery
    - Metrics tracking
    """

    def __init__(
        self,
        embedder,
        trigger_threshold: int = 10,
        max_examples: int = 1000,
        checkpoint_dir: str = "data/fine_tuning"
    ):
        """
        Initialize the fine-tuner.

        Args:
            embedder: Embedder instance to fine-tune (must support fine_tune method)
            trigger_threshold: Number of new examples before triggering fine-tuning
            max_examples: Maximum examples to keep (LRU eviction)
            checkpoint_dir: Directory for saving checkpoints
        """
        self.embedder = embedder
        self.trigger_threshold = trigger_threshold
        self.max_examples = max_examples
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Example storage (deque with maxlen gives automatic LRU eviction)
        self.positive_examples = deque(maxlen=max_examples)
        self.negative_examples = deque(maxlen=max_examples)

        # Training state.
        # _state_lock guards is_training: previously the flag was only set
        # inside the worker thread, leaving a window between Thread.start()
        # and the thread's first instruction in which a second trigger could
        # spawn a second concurrent training thread.
        self._state_lock = threading.Lock()
        self.is_training = False
        self.training_thread: Optional[threading.Thread] = None
        self.last_training_time = 0
        self.training_count = 0

        # Metrics
        self.metrics_history: List[Dict[str, Any]] = []

        logger.info(
            f"LightweightFineTuner initialized: "
            f"trigger={trigger_threshold}, max_examples={max_examples}"
        )

    def _add_example(self, store: deque, label: str, image: "Image.Image",
                     workflow_id: str, metadata: Optional[Dict]):
        """Append one example to *store* and check the training trigger.

        Shared implementation behind add_positive_example /
        add_negative_example (their bodies were duplicates).
        """
        store.append({
            'image': image,
            'workflow_id': workflow_id,
            'metadata': metadata or {},
            'timestamp': time.time()
        })

        logger.debug(
            f"Added {label} example: workflow={workflow_id}, "
            f"total_{label}={len(store)}"
        )

        self._check_trigger()

    def add_positive_example(self, image: "Image.Image", workflow_id: str, metadata: Optional[Dict] = None):
        """
        Add a positive example (successful workflow execution).

        Args:
            image: Screenshot where workflow succeeded
            workflow_id: ID of the successful workflow
            metadata: Optional additional metadata
        """
        self._add_example(self.positive_examples, "positive", image, workflow_id, metadata)

    def add_negative_example(self, image: "Image.Image", workflow_id: str, metadata: Optional[Dict] = None):
        """
        Add a negative example (rejected workflow suggestion).

        Args:
            image: Screenshot where workflow was rejected
            workflow_id: ID of the rejected workflow
            metadata: Optional additional metadata
        """
        self._add_example(self.negative_examples, "negative", image, workflow_id, metadata)

    def _check_trigger(self):
        """Start background fine-tuning once enough examples have accumulated."""
        total_new = len(self.positive_examples) + len(self.negative_examples)

        # Don't trigger if already training
        if self.is_training:
            logger.debug("Fine-tuning already in progress, skipping trigger check")
            return

        # Check if we have enough examples
        if total_new >= self.trigger_threshold:
            logger.info(
                f"Fine-tuning triggered: {total_new} examples "
                f"({len(self.positive_examples)} positive, "
                f"{len(self.negative_examples)} negative)"
            )
            self._start_training()

    def _start_training(self):
        """Start training in a background daemon thread (at most one at a time)."""
        with self._state_lock:
            if self.is_training:
                logger.warning("Training already in progress")
                return
            # Claim the flag *before* the thread runs so that a concurrent
            # trigger cannot start a second training thread (the original
            # code set it inside _train, which raced with Thread.start).
            self.is_training = True

        self.training_thread = threading.Thread(
            target=self._train,
            name="FineTuningThread",
            daemon=True
        )
        self.training_thread.start()
        logger.info("Fine-tuning thread started")

    def _train(self):
        """Fine-tune the model (runs in the background thread).

        is_training is already True (set by _start_training); this method is
        responsible for clearing it in the finally block.
        """
        start_time = time.time()

        try:
            # Check if embedder supports fine-tuning
            if not hasattr(self.embedder, 'fine_tune'):
                logger.info(
                    f"Embedder {self.embedder.get_model_name()} doesn't support "
                    f"fine-tuning, skipping"
                )
                return

            # Snapshot the deques before iterating: the main thread may keep
            # appending, and iterating a deque while it is mutated raises
            # RuntimeError.
            positive_images = [ex['image'] for ex in list(self.positive_examples)]
            negative_images = [ex['image'] for ex in list(self.negative_examples)]

            if not positive_images and not negative_images:
                logger.warning("No examples to train on")
                return

            logger.info(
                f"Starting fine-tuning: {len(positive_images)} positive, "
                f"{len(negative_images)} negative examples"
            )

            # Fine-tune (implementation depends on embedder)
            metrics = self.embedder.fine_tune(
                positive_images=positive_images,
                negative_images=negative_images,
                epochs=1,
                learning_rate=1e-4
            )

            # Record metrics
            duration = time.time() - start_time
            metrics['duration_seconds'] = duration
            metrics['timestamp'] = time.time()
            metrics['positive_count'] = len(positive_images)
            metrics['negative_count'] = len(negative_images)
            metrics['training_number'] = self.training_count

            self.metrics_history.append(metrics)
            self.last_training_time = time.time()
            self.training_count += 1

            # 'loss' may be absent from the embedder's metrics; formatting a
            # string fallback with :.4f raised ValueError in the original.
            loss = metrics.get('loss')
            loss_str = f"{loss:.4f}" if isinstance(loss, (int, float)) else "N/A"
            logger.info(
                f"Fine-tuning complete #{self.training_count}: "
                f"loss={loss_str}, "
                f"duration={duration:.1f}s"
            )

            # Clear examples after successful training
            self.positive_examples.clear()
            self.negative_examples.clear()
            logger.debug("Training examples cleared")

        except Exception as e:
            # Deliberate best-effort boundary: training failures are logged
            # and must never crash the host application.
            logger.error(f"Fine-tuning failed: {e}", exc_info=True)

        finally:
            with self._state_lock:
                self.is_training = False

    def save_checkpoint(self, name: str = "checkpoint"):
        """
        Save training examples and metrics for recovery.

        Best-effort: failures are logged, not raised.

        Args:
            name: Checkpoint name
        """
        try:
            checkpoint_path = self.checkpoint_dir / f"{name}.pkl"

            data = {
                'positive_examples': list(self.positive_examples),
                'negative_examples': list(self.negative_examples),
                'metrics_history': self.metrics_history,
                'training_count': self.training_count,
                'last_training_time': self.last_training_time
            }

            with open(checkpoint_path, 'wb') as f:
                pickle.dump(data, f)

            logger.info(f"Checkpoint saved: {checkpoint_path}")

        except Exception as e:
            logger.error(f"Failed to save checkpoint: {e}")

    def load_checkpoint(self, name: str = "checkpoint"):
        """
        Load training examples and metrics from checkpoint.

        Loaded examples are appended (extend) to any already in memory.

        SECURITY NOTE: this unpickles the checkpoint file; only load
        checkpoints from the application's own trusted checkpoint_dir —
        pickle.load on untrusted data can execute arbitrary code.

        Args:
            name: Checkpoint name

        Returns:
            bool: True if loaded successfully, False otherwise
        """
        try:
            checkpoint_path = self.checkpoint_dir / f"{name}.pkl"

            if not checkpoint_path.exists():
                logger.warning(f"Checkpoint not found: {checkpoint_path}")
                return False

            with open(checkpoint_path, 'rb') as f:
                data = pickle.load(f)

            self.positive_examples.extend(data.get('positive_examples', []))
            self.negative_examples.extend(data.get('negative_examples', []))
            self.metrics_history = data.get('metrics_history', [])
            self.training_count = data.get('training_count', 0)
            self.last_training_time = data.get('last_training_time', 0)

            logger.info(
                f"Checkpoint loaded: {len(self.positive_examples)} positive, "
                f"{len(self.negative_examples)} negative examples"
            )
            return True

        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            return False

    def get_stats(self) -> Dict[str, Any]:
        """
        Get fine-tuning statistics.

        Returns:
            Dict with example counts, training state, metrics history and
            trigger threshold.
        """
        return {
            'positive_examples': len(self.positive_examples),
            'negative_examples': len(self.negative_examples),
            'total_examples': len(self.positive_examples) + len(self.negative_examples),
            'is_training': self.is_training,
            'training_count': self.training_count,
            'last_training_time': self.last_training_time,
            'metrics_history': self.metrics_history,
            'trigger_threshold': self.trigger_threshold
        }

    def wait_for_training(self, timeout: Optional[float] = None):
        """
        Wait for current training to complete.

        Args:
            timeout: Maximum time to wait in seconds (None = wait forever)
        """
        if self.training_thread and self.training_thread.is_alive():
            self.training_thread.join(timeout=timeout)

    def __repr__(self) -> str:
        """String representation."""
        stats = self.get_stats()
        return (
            f"LightweightFineTuner("
            f"examples={stats['total_examples']}, "
            f"trainings={stats['training_count']}, "
            f"is_training={stats['is_training']})"
        )
|
||||
Reference in New Issue
Block a user