Geniusia_v2/geniusia2/core/embedders/fine_tuner.py

"""
Lightweight fine-tuner for embedding models.
This module provides incremental fine-tuning capabilities that run in the
background, adapting the embedding model to user-specific workflows over time.
"""
import threading
import time
import pickle
import logging
from collections import deque
from pathlib import Path
from typing import List, Dict, Any, Optional
from PIL import Image
import numpy as np
logger = logging.getLogger(__name__)
class LightweightFineTuner:
"""
Lightweight fine-tuner for incremental model adaptation.
This class collects positive and negative examples from user interactions
and periodically fine-tunes the embedding model to improve accuracy on
user-specific workflows.
Features:
- Automatic triggering after N examples
- Background training (non-blocking)
- Checkpoint save/load for recovery
- Metrics tracking
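
    Example (a minimal sketch; ``SomeEmbedder`` and ``screenshot`` are
    hypothetical stand-ins for an embedder implementing ``fine_tune``
    and a PIL image):

        # SomeEmbedder and screenshot are placeholders, not real names
        tuner = LightweightFineTuner(SomeEmbedder(), trigger_threshold=10)
        tuner.add_positive_example(screenshot, workflow_id="wf-42")
        # fine-tuning starts in the background once 10 examples accumulate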
"""
def __init__(
self,
embedder,
trigger_threshold: int = 10,
max_examples: int = 1000,
checkpoint_dir: str = "data/fine_tuning"
):
"""
Initialize the fine-tuner.
Args:
embedder: Embedder instance to fine-tune (must support fine_tune method)
trigger_threshold: Number of new examples before triggering fine-tuning
            max_examples: Maximum examples to keep (oldest evicted first)
checkpoint_dir: Directory for saving checkpoints
"""
self.embedder = embedder
self.trigger_threshold = trigger_threshold
self.max_examples = max_examples
self.checkpoint_dir = Path(checkpoint_dir)
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        # Example storage (bounded deques; when full, the oldest example is
        # evicted automatically)
        self.positive_examples = deque(maxlen=max_examples)
        self.negative_examples = deque(maxlen=max_examples)
        # Guards the deques: the background training thread snapshots and
        # clears them while caller threads may still be appending
        self._lock = threading.Lock()
# Training state
self.is_training = False
self.training_thread: Optional[threading.Thread] = None
self.last_training_time = 0
self.training_count = 0
# Metrics
self.metrics_history: List[Dict[str, Any]] = []
logger.info(
f"LightweightFineTuner initialized: "
f"trigger={trigger_threshold}, max_examples={max_examples}"
)
def add_positive_example(self, image: Image.Image, workflow_id: str, metadata: Optional[Dict] = None):
"""
Add a positive example (successful workflow execution).
Args:
image: Screenshot where workflow succeeded
workflow_id: ID of the successful workflow
metadata: Optional additional metadata
"""
example = {
'image': image,
'workflow_id': workflow_id,
'metadata': metadata or {},
'timestamp': time.time()
}
        with self._lock:
            self.positive_examples.append(example)
logger.debug(
f"Added positive example: workflow={workflow_id}, "
f"total_positive={len(self.positive_examples)}"
)
self._check_trigger()
def add_negative_example(self, image: Image.Image, workflow_id: str, metadata: Optional[Dict] = None):
"""
Add a negative example (rejected workflow suggestion).
Args:
image: Screenshot where workflow was rejected
workflow_id: ID of the rejected workflow
metadata: Optional additional metadata
"""
example = {
'image': image,
'workflow_id': workflow_id,
'metadata': metadata or {},
'timestamp': time.time()
}
        with self._lock:
            self.negative_examples.append(example)
logger.debug(
f"Added negative example: workflow={workflow_id}, "
f"total_negative={len(self.negative_examples)}"
)
self._check_trigger()
def _check_trigger(self):
"""Check if we should trigger fine-tuning."""
total_new = len(self.positive_examples) + len(self.negative_examples)
# Don't trigger if already training
if self.is_training:
logger.debug("Fine-tuning already in progress, skipping trigger check")
return
# Check if we have enough examples
if total_new >= self.trigger_threshold:
logger.info(
f"Fine-tuning triggered: {total_new} examples "
f"({len(self.positive_examples)} positive, "
f"{len(self.negative_examples)} negative)"
)
self._start_training()
    def _start_training(self):
        """Start training in a background thread."""
        if self.is_training:
            logger.warning("Training already in progress")
            return
        # Set the flag before the thread starts so a concurrent trigger check
        # cannot launch a second training run during the startup window
        self.is_training = True
        self.training_thread = threading.Thread(
            target=self._train,
            name="FineTuningThread",
            daemon=True
        )
        self.training_thread.start()
        logger.info("Fine-tuning thread started")
def _train(self):
"""Fine-tune the model (runs in background thread)."""
        # is_training was already set to True by _start_training
start_time = time.time()
try:
# Check if embedder supports fine-tuning
if not hasattr(self.embedder, 'fine_tune'):
logger.info(
f"Embedder {self.embedder.get_model_name()} doesn't support "
f"fine-tuning, skipping"
)
return
            # Prepare training data: snapshot under the lock so concurrent
            # appends cannot mutate the deques mid-iteration
            with self._lock:
                positive_images = [ex['image'] for ex in self.positive_examples]
                negative_images = [ex['image'] for ex in self.negative_examples]
if not positive_images and not negative_images:
logger.warning("No examples to train on")
return
logger.info(
f"Starting fine-tuning: {len(positive_images)} positive, "
f"{len(negative_images)} negative examples"
)
# Fine-tune (implementation depends on embedder)
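            # fine_tune is expected to return a mutable metrics dict (e.g.
            # containing 'loss'); duration and counts are attached below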
metrics = self.embedder.fine_tune(
positive_images=positive_images,
negative_images=negative_images,
epochs=1,
learning_rate=1e-4
)
# Record metrics
duration = time.time() - start_time
metrics['duration_seconds'] = duration
metrics['timestamp'] = time.time()
metrics['positive_count'] = len(positive_images)
metrics['negative_count'] = len(negative_images)
            self.training_count += 1
            metrics['training_number'] = self.training_count
            self.metrics_history.append(metrics)
            self.last_training_time = time.time()
            loss = metrics.get('loss')
            loss_str = f"{loss:.4f}" if isinstance(loss, (int, float)) else "N/A"
            logger.info(
                f"Fine-tuning complete #{self.training_count}: "
                f"loss={loss_str}, duration={duration:.1f}s"
            )
# Clear examples after successful training
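            # Note: examples appended by other threads while training ran are
            # cleared as well; with a small trigger_threshold they are quickly
            # re-collected, so no extra bookkeeping is done here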
            with self._lock:
                self.positive_examples.clear()
                self.negative_examples.clear()
logger.debug("Training examples cleared")
except Exception as e:
logger.error(f"Fine-tuning failed: {e}", exc_info=True)
finally:
self.is_training = False
def save_checkpoint(self, name: str = "checkpoint"):
"""
Save training examples and metrics for recovery.
Args:
name: Checkpoint name
"""
try:
checkpoint_path = self.checkpoint_dir / f"{name}.pkl"
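            # PIL images pickle whole, so checkpoints can be large; size is
            # bounded by max_examples on each deque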
            # Snapshot the deques under the lock in case training is running
            with self._lock:
                data = {
                    'positive_examples': list(self.positive_examples),
                    'negative_examples': list(self.negative_examples),
                    'metrics_history': self.metrics_history,
                    'training_count': self.training_count,
                    'last_training_time': self.last_training_time
                }
with open(checkpoint_path, 'wb') as f:
pickle.dump(data, f)
logger.info(f"Checkpoint saved: {checkpoint_path}")
except Exception as e:
logger.error(f"Failed to save checkpoint: {e}")
def load_checkpoint(self, name: str = "checkpoint"):
"""
Load training examples and metrics from checkpoint.
Args:
name: Checkpoint name
Returns:
bool: True if loaded successfully, False otherwise
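
        Example (sketch):
            tuner.save_checkpoint("session")
            # ...process restarts...
            tuner.load_checkpoint("session")  # restores examples and counters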
"""
try:
checkpoint_path = self.checkpoint_dir / f"{name}.pkl"
if not checkpoint_path.exists():
logger.warning(f"Checkpoint not found: {checkpoint_path}")
return False
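            # pickle.load can execute arbitrary code; checkpoints are assumed
            # to be trusted local files produced by save_checkpoint()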
with open(checkpoint_path, 'rb') as f:
data = pickle.load(f)
            with self._lock:
                self.positive_examples.extend(data.get('positive_examples', []))
                self.negative_examples.extend(data.get('negative_examples', []))
self.metrics_history = data.get('metrics_history', [])
self.training_count = data.get('training_count', 0)
self.last_training_time = data.get('last_training_time', 0)
logger.info(
f"Checkpoint loaded: {len(self.positive_examples)} positive, "
f"{len(self.negative_examples)} negative examples"
)
return True
except Exception as e:
logger.error(f"Failed to load checkpoint: {e}")
return False
def get_stats(self) -> Dict[str, Any]:
"""
Get fine-tuning statistics.
Returns:
            Dict with example counts, training state, and metrics history
"""
return {
'positive_examples': len(self.positive_examples),
'negative_examples': len(self.negative_examples),
'total_examples': len(self.positive_examples) + len(self.negative_examples),
'is_training': self.is_training,
'training_count': self.training_count,
'last_training_time': self.last_training_time,
'metrics_history': self.metrics_history,
'trigger_threshold': self.trigger_threshold
}
def wait_for_training(self, timeout: Optional[float] = None):
"""
Wait for current training to complete.
Args:
timeout: Maximum time to wait in seconds (None = wait forever)
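
        Note:
            If the timeout elapses first, the training thread keeps running;
            check is_training afterwards if completion must be confirmed.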
"""
if self.training_thread and self.training_thread.is_alive():
self.training_thread.join(timeout=timeout)
def __repr__(self) -> str:
"""String representation."""
stats = self.get_stats()
return (
f"LightweightFineTuner("
f"examples={stats['total_examples']}, "
f"trainings={stats['training_count']}, "
f"is_training={stats['is_training']})"
)