# Geniusia_v2/geniusia2/core/embedders/clip_embedder.py
"""
CLIP-based embedder implementation.
This module provides a wrapper around OpenCLIP for generating image embeddings
using the CLIP (Contrastive Language-Image Pre-training) model.
"""
import torch
import numpy as np
from PIL import Image
from typing import List, Optional
import logging
try:
    import open_clip
except ImportError:
    open_clip = None

from .base import EmbedderBase

logger = logging.getLogger(__name__)


class CLIPEmbedder(EmbedderBase):
    """
    CLIP-based image embedder using OpenCLIP.

    This embedder uses the ViT-B/32 architecture by default, which produces
    512-dimensional embeddings. The device defaults to CPU unless one is
    passed explicitly, so GPU memory stays free for other models.
    """
    def __init__(
        self,
        model_name: str = "ViT-B-32",
        pretrained: str = "openai",
        device: Optional[str] = None
    ):
        """
        Initialize the CLIP embedder.

        Args:
            model_name: CLIP model architecture (default: ViT-B-32)
            pretrained: Pretrained weights to use (default: openai)
            device: Device to use ('cuda', 'cpu', or None).
                None defaults to CPU to keep GPU memory free for other models.

        Raises:
            ImportError: If open_clip is not installed
            RuntimeError: If model loading fails
        """
        if open_clip is None:
            raise ImportError(
                "OpenCLIP is not installed. "
                "Install it with: pip install open-clip-torch"
            )

        # Default to CPU to save GPU for vision models (Qwen3-VL)
        if device is None:
            device = "cpu"

        self.model_name = model_name
        self.pretrained = pretrained
        self.device = device
        self._embedding_dim = None

        # Load model and preprocessing transforms
        try:
            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
                model_name,
                pretrained=pretrained,
                device=device
            )
            self.model.eval()

            # Determine embedding dimension with a dummy forward pass
            with torch.no_grad():
                dummy_image = torch.zeros(1, 3, 224, 224).to(self.device)
                dummy_embedding = self.model.encode_image(dummy_image)
                self._embedding_dim = dummy_embedding.shape[-1]

            logger.info(
                f"CLIPEmbedder loaded: {model_name} on {device}, "
                f"dimension={self._embedding_dim}"
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load CLIP model: {e}") from e

    def embed(self, image: Image.Image) -> np.ndarray:
        """
        Generate embedding for a single image.

        Args:
            image: PIL Image to embed

        Returns:
            np.ndarray: Normalized embedding vector of shape (dimension,)

        Raises:
            ValueError: If image is invalid
            RuntimeError: If embedding generation fails
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL Image")

        try:
            # Preprocess image
            image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

            # Generate embedding
            with torch.no_grad():
                embedding = self.model.encode_image(image_tensor)
                # L2 normalize for cosine similarity
                embedding = embedding / embedding.norm(dim=-1, keepdim=True)

            return embedding.cpu().numpy().flatten()
        except Exception as e:
            raise RuntimeError(f"Failed to generate embedding: {e}") from e

    def embed_batch(self, images: List[Image.Image]) -> np.ndarray:
        """
        Generate embeddings for multiple images (optimized batch processing).

        Args:
            images: List of PIL Images to embed

        Returns:
            np.ndarray: Array of embeddings with shape (len(images), dimension)

        Raises:
            ValueError: If any image is invalid
            RuntimeError: If embedding generation fails
        """
        if not images:
            return np.array([]).reshape(0, self.get_dimension())

        # Validate all images
        for i, img in enumerate(images):
            if not isinstance(img, Image.Image):
                raise ValueError(f"Image at index {i} is not a PIL Image")

        try:
            # Preprocess all images into a single batch tensor
            image_tensors = torch.stack([
                self.preprocess(img) for img in images
            ]).to(self.device)

            # Generate embeddings in batch
            with torch.no_grad():
                embeddings = self.model.encode_image(image_tensors)
                # L2 normalize for cosine similarity
                embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)

            return embeddings.cpu().numpy()
        except Exception as e:
            raise RuntimeError(f"Failed to generate batch embeddings: {e}") from e

    def get_dimension(self) -> int:
        """
        Get the dimensionality of embeddings.

        Returns:
            int: Embedding dimension (512 for ViT-B/32)
        """
        return self._embedding_dim

    def get_model_name(self) -> str:
        """
        Get model identifier.

        Returns:
            str: Model name (e.g., "clip-vit-b-32")
        """
        return f"clip-{self.model_name.lower().replace('/', '-')}"

    def supports_batch(self) -> bool:
        """
        Check if batch processing is supported.

        Returns:
            bool: True (CLIP supports efficient batch processing)
        """
        return True

    def fine_tune(
        self,
        positive_images: List[Image.Image],
        negative_images: List[Image.Image],
        epochs: int = 1,
        learning_rate: float = 1e-4
    ) -> dict:
        """
        Fine-tune the model using contrastive learning.

        This method fine-tunes only the final projection layer to adapt
        the model to user-specific workflows. It uses a simple contrastive
        loss: positive examples should be similar to each other, and
        negative examples should be dissimilar from the positives.

        Args:
            positive_images: Images from successful workflows
            negative_images: Images from rejected workflows
            epochs: Number of training epochs (default: 1 for speed)
            learning_rate: Learning rate (default: 1e-4)

        Returns:
            dict: Training metrics (loss, counts, etc.)
        """
        if not positive_images and not negative_images:
            return {'loss': 0.0, 'note': 'No examples to train on'}

        # Set model to training mode
        self.model.train()

        # Only train the visual projection layer (last layer);
        # freeze all other parameters
        for param in self.model.parameters():
            param.requires_grad = False

        # Unfreeze visual projection
        if hasattr(self.model, 'visual') and hasattr(self.model.visual, 'proj'):
            if self.model.visual.proj is not None:
                self.model.visual.proj.requires_grad = True

        # Setup optimizer (only for trainable parameters)
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        if not trainable_params:
            logger.warning("No trainable parameters found, skipping fine-tuning")
            self.model.eval()
            return {'loss': 0.0, 'note': 'No trainable parameters'}

        optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)
        total_loss = 0.0
        num_batches = 0

        try:
            for epoch in range(epochs):
                # Process positive examples (should be similar to each other)
                if len(positive_images) >= 2:
                    pos_loss = self._contrastive_loss_positive(positive_images)
                    optimizer.zero_grad()
                    pos_loss.backward()
                    optimizer.step()
                    # Accumulate as a plain float so the graph is not retained
                    total_loss += pos_loss.item()
                    num_batches += 1

                # Process negative examples (should be dissimilar from positives)
                if positive_images and negative_images:
                    neg_loss = self._contrastive_loss_negative(
                        positive_images[:5],  # Use subset for speed
                        negative_images[:5]
                    )
                    optimizer.zero_grad()
                    neg_loss.backward()
                    optimizer.step()
                    total_loss += neg_loss.item()
                    num_batches += 1
        finally:
            # Always restore eval mode and freeze parameters
            self.model.eval()
            for param in self.model.parameters():
                param.requires_grad = False

        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
        return {
            'loss': float(avg_loss),
            'epochs': epochs,
            'learning_rate': learning_rate,
            'positive_count': len(positive_images),
            'negative_count': len(negative_images),
            'num_batches': num_batches
        }
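
    # Usage sketch for fine_tune() (the image lists are hypothetical): only
    # the visual projection layer is updated, so one epoch is cheap even on CPU.
    #
    #     metrics = embedder.fine_tune(
    #         positive_images=accepted_imgs,
    #         negative_images=rejected_imgs,
    #         epochs=1,
    #         learning_rate=1e-4,
    #     )
    #     logger.info("fine-tune loss: %.4f", metrics["loss"])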

    def _contrastive_loss_positive(self, images: List[Image.Image]) -> torch.Tensor:
        """
        Contrastive loss for positive examples (should be similar).

        Args:
            images: List of positive example images

        Returns:
            torch.Tensor: Loss value
        """
        # Generate embeddings (with gradients, since this is used for training)
        embeddings = []
        for img in images:
            img_tensor = self.preprocess(img).unsqueeze(0).to(self.device)
            emb = self.model.encode_image(img_tensor)
            emb = emb / emb.norm(dim=-1, keepdim=True)
            embeddings.append(emb)
        embeddings = torch.cat(embeddings, dim=0)

        # Compute pairwise cosine similarities
        similarities = torch.mm(embeddings, embeddings.t())

        # Loss: maximize similarity, i.e. minimize (1 - similarity),
        # excluding the diagonal (self-similarity) from the average
        mask = torch.eye(len(images), device=self.device).bool()
        loss = (1 - similarities[~mask]).mean()
        return loss

    def _contrastive_loss_negative(
        self,
        positive_images: List[Image.Image],
        negative_images: List[Image.Image]
    ) -> torch.Tensor:
        """
        Contrastive loss for negative examples (should be dissimilar).

        Args:
            positive_images: Positive example images
            negative_images: Negative example images

        Returns:
            torch.Tensor: Loss value
        """
        # Generate embeddings for positives
        pos_embeddings = []
        for img in positive_images:
            img_tensor = self.preprocess(img).unsqueeze(0).to(self.device)
            emb = self.model.encode_image(img_tensor)
            emb = emb / emb.norm(dim=-1, keepdim=True)
            pos_embeddings.append(emb)
        pos_embeddings = torch.cat(pos_embeddings, dim=0)

        # Generate embeddings for negatives
        neg_embeddings = []
        for img in negative_images:
            img_tensor = self.preprocess(img).unsqueeze(0).to(self.device)
            emb = self.model.encode_image(img_tensor)
            emb = emb / emb.norm(dim=-1, keepdim=True)
            neg_embeddings.append(emb)
        neg_embeddings = torch.cat(neg_embeddings, dim=0)

        # Compute cross-similarities (positive vs negative)
        similarities = torch.mm(pos_embeddings, neg_embeddings.t())

        # We want low similarity, so minimize the mean similarity directly
        loss = similarities.mean()
        return loss
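

# Minimal end-to-end sketch, guarded so importing this module stays side-effect
# free. It assumes open-clip-torch is installed and that a local "sample.jpg"
# exists; both are illustrative assumptions, not requirements of the module.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    embedder = CLIPEmbedder()                        # CPU, ViT-B-32, openai weights
    image = Image.open("sample.jpg").convert("RGB")
    vector = embedder.embed(image)
    print(embedder.get_model_name(), embedder.get_dimension(), vector.shape)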