"""
|
|
CLIP-based embedder implementation.
|
|
|
|
This module provides a wrapper around OpenCLIP for generating image embeddings
|
|
using the CLIP (Contrastive Language-Image Pre-training) model.
|
|
"""
|
|
|
|
import torch
|
|
import numpy as np
|
|
from PIL import Image
|
|
from typing import List, Optional
|
|
import logging
|
|
|
|
try:
|
|
import open_clip
|
|
except ImportError:
|
|
open_clip = None
|
|
|
|
from .base import EmbedderBase
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class CLIPEmbedder(EmbedderBase):
|
|
"""
|
|
CLIP-based image embedder using OpenCLIP.
|
|
|
|
This embedder uses the ViT-B/32 architecture by default, which produces
|
|
512-dimensional embeddings. It automatically handles GPU/CPU device selection.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
model_name: str = "ViT-B-32",
|
|
pretrained: str = "openai",
|
|
device: Optional[str] = None
|
|
):
|
|
"""
|
|
Initialize the CLIP embedder.
|
|
|
|
Args:
|
|
model_name: CLIP model architecture (default: ViT-B-32)
|
|
pretrained: Pretrained weights to use (default: openai)
|
|
device: Device to use ('cuda', 'cpu', or None for auto-detect)
|
|
Note: Defaults to CPU to save GPU memory for other models
|
|
|
|
Raises:
|
|
ImportError: If open_clip is not installed
|
|
RuntimeError: If model loading fails
|
|
"""
|
|
if open_clip is None:
|
|
raise ImportError(
|
|
"OpenCLIP is not installed. "
|
|
"Install it with: pip install open-clip-torch"
|
|
)
|
|
|
|
# Default to CPU to save GPU for vision models (Qwen3-VL)
|
|
if device is None:
|
|
device = "cpu"
|
|
|
|
self.model_name = model_name
|
|
self.pretrained = pretrained
|
|
self.device = device
|
|
self._embedding_dim = None
|
|
|
|
# Load model
|
|
try:
|
|
self.model, _, self.preprocess = open_clip.create_model_and_transforms(
|
|
model_name,
|
|
pretrained=pretrained,
|
|
device=device
|
|
)
|
|
self.model.eval()
|
|
|
|
# Determine embedding dimension
|
|
with torch.no_grad():
|
|
dummy_image = torch.zeros(1, 3, 224, 224).to(self.device)
|
|
dummy_embedding = self.model.encode_image(dummy_image)
|
|
self._embedding_dim = dummy_embedding.shape[-1]
|
|
|
|
logger.info(
|
|
f"CLIPEmbedder loaded: {model_name} on {device}, "
|
|
f"dimension={self._embedding_dim}"
|
|
)
|
|
|
|
except Exception as e:
|
|
raise RuntimeError(f"Failed to load CLIP model: {e}")
|
|
|
|
    def embed(self, image: Image.Image) -> np.ndarray:
        """
        Generate embedding for a single image.

        Args:
            image: PIL Image to embed

        Returns:
            np.ndarray: Normalized embedding vector of shape (dimension,)

        Raises:
            ValueError: If image is invalid
            RuntimeError: If embedding generation fails
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL Image")

        try:
            # Preprocess image
            image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

            # Generate embedding
            with torch.no_grad():
                embedding = self.model.encode_image(image_tensor)
                # L2 normalize for cosine similarity
                embedding = embedding / embedding.norm(dim=-1, keepdim=True)

            return embedding.cpu().numpy().flatten()

        except Exception as e:
            raise RuntimeError(f"Failed to generate embedding: {e}") from e

    def embed_batch(self, images: List[Image.Image]) -> np.ndarray:
        """
        Generate embeddings for multiple images (optimized batch processing).

        Args:
            images: List of PIL Images to embed

        Returns:
            np.ndarray: Array of embeddings with shape (len(images), dimension)

        Raises:
            ValueError: If any image is invalid
            RuntimeError: If embedding generation fails
        """
        if not images:
            return np.array([]).reshape(0, self.get_dimension())

        # Validate all images
        for i, img in enumerate(images):
            if not isinstance(img, Image.Image):
                raise ValueError(f"Image at index {i} is not a PIL Image")

        try:
            # Preprocess all images
            image_tensors = torch.stack([
                self.preprocess(img) for img in images
            ]).to(self.device)

            # Generate embeddings in batch
            with torch.no_grad():
                embeddings = self.model.encode_image(image_tensors)
                # L2 normalize for cosine similarity
                embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)

            return embeddings.cpu().numpy()

        except Exception as e:
            raise RuntimeError(f"Failed to generate batch embeddings: {e}") from e

    def get_dimension(self) -> int:
        """
        Get the dimensionality of embeddings.

        Returns:
            int: Embedding dimension (512 for ViT-B/32)
        """
        return self._embedding_dim

    def get_model_name(self) -> str:
        """
        Get model identifier.

        Returns:
            str: Model name (e.g., "clip-vit-b-32")
        """
        return f"clip-{self.model_name.lower().replace('/', '-')}"

    def supports_batch(self) -> bool:
        """
        Check if batch processing is supported.

        Returns:
            bool: True (CLIP supports efficient batch processing)
        """
        return True

    def fine_tune(
        self,
        positive_images: List[Image.Image],
        negative_images: List[Image.Image],
        epochs: int = 1,
        learning_rate: float = 1e-4
    ) -> dict:
        """
        Fine-tune the model using contrastive learning.

        This method fine-tunes only the final projection layer to adapt
        the model to user-specific workflows. It uses a simple contrastive
        loss: positive examples should be similar to each other, and negative
        examples should be dissimilar from positives.

        Args:
            positive_images: Images from successful workflows
            negative_images: Images from rejected workflows
            epochs: Number of training epochs (default: 1 for speed)
            learning_rate: Learning rate (default: 1e-4)

        Returns:
            dict: Training metrics (average loss, epochs, learning rate,
                example counts, number of batches)
        """
        if not positive_images and not negative_images:
            return {'loss': 0.0, 'note': 'No examples to train on'}

        # Set model to training mode
        self.model.train()

        # Only train the visual projection layer (last layer);
        # freeze all other parameters
        for param in self.model.parameters():
            param.requires_grad = False

        # Unfreeze visual projection
        if hasattr(self.model, 'visual') and hasattr(self.model.visual, 'proj'):
            if self.model.visual.proj is not None:
                self.model.visual.proj.requires_grad = True

        # Setup optimizer (only for trainable parameters)
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]

        if not trainable_params:
            logger.warning("No trainable parameters found, skipping fine-tuning")
            self.model.eval()
            return {'loss': 0.0, 'note': 'No trainable parameters'}

        optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)

        total_loss = 0.0
        num_batches = 0

        try:
            for epoch in range(epochs):
                # Process positive examples (should be similar to each other)
                if len(positive_images) >= 2:
                    pos_loss = self._contrastive_loss_positive(positive_images)
                    total_loss += pos_loss.item()
                    num_batches += 1

                    optimizer.zero_grad()
                    pos_loss.backward()
                    optimizer.step()

                # Process negative examples (should be dissimilar from positives)
                if positive_images and negative_images:
                    neg_loss = self._contrastive_loss_negative(
                        positive_images[:5],  # Use subset for speed
                        negative_images[:5]
                    )
                    total_loss += neg_loss.item()
                    num_batches += 1

                    optimizer.zero_grad()
                    neg_loss.backward()
                    optimizer.step()

        finally:
            # Always restore eval mode and freeze parameters
            self.model.eval()
            for param in self.model.parameters():
                param.requires_grad = False

        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0

        return {
            'loss': float(avg_loss),
            'epochs': epochs,
            'learning_rate': learning_rate,
            'positive_count': len(positive_images),
            'negative_count': len(negative_images),
            'num_batches': num_batches
        }

    def _contrastive_loss_positive(self, images: List[Image.Image]) -> torch.Tensor:
        """
        Contrastive loss for positive examples (should be similar).

        Args:
            images: List of positive example images

        Returns:
            torch.Tensor: Loss value
        """
        # Generate embeddings
        embeddings = []
        for img in images:
            img_tensor = self.preprocess(img).unsqueeze(0).to(self.device)
            emb = self.model.encode_image(img_tensor)
            emb = emb / emb.norm(dim=-1, keepdim=True)
            embeddings.append(emb)

        embeddings = torch.cat(embeddings, dim=0)

        # Compute pairwise cosine similarities
        similarities = torch.mm(embeddings, embeddings.t())

        # We want high similarity between positives, so minimize (1 - similarity)
        # averaged over the off-diagonal pairs (the diagonal is self-similarity)
        mask = torch.eye(len(images), device=self.device).bool()
        loss = (1 - similarities[~mask]).mean()

        return loss

    def _contrastive_loss_negative(
        self,
        positive_images: List[Image.Image],
        negative_images: List[Image.Image]
    ) -> torch.Tensor:
        """
        Contrastive loss for negative examples (should be dissimilar).

        Args:
            positive_images: Positive example images
            negative_images: Negative example images

        Returns:
            torch.Tensor: Loss value
        """
        # Generate embeddings for positives
        pos_embeddings = []
        for img in positive_images:
            img_tensor = self.preprocess(img).unsqueeze(0).to(self.device)
            emb = self.model.encode_image(img_tensor)
            emb = emb / emb.norm(dim=-1, keepdim=True)
            pos_embeddings.append(emb)

        pos_embeddings = torch.cat(pos_embeddings, dim=0)

        # Generate embeddings for negatives
        neg_embeddings = []
        for img in negative_images:
            img_tensor = self.preprocess(img).unsqueeze(0).to(self.device)
            emb = self.model.encode_image(img_tensor)
            emb = emb / emb.norm(dim=-1, keepdim=True)
            neg_embeddings.append(emb)

        neg_embeddings = torch.cat(neg_embeddings, dim=0)

        # Compute cross-similarities (positive vs negative)
        similarities = torch.mm(pos_embeddings, neg_embeddings.t())

        # We want low similarity across the two groups, so minimize the mean
        # cross-similarity directly
        loss = similarities.mean()

        return loss
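

# The block below is an illustrative usage sketch, not part of the original
# module API. It shows how the embedder's embed(), embed_batch(), and
# fine_tune() methods are expected to be driven, assuming open-clip-torch is
# installed and the "openai" ViT-B-32 weights can be downloaded. Synthetic
# PIL images stand in for real workflow screenshots.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    embedder = CLIPEmbedder()  # CPU by default; pass device="cuda" for GPU

    # Single-image embedding: a unit-length vector of shape (dimension,)
    image = Image.new("RGB", (224, 224), color=(128, 64, 200))
    vector = embedder.embed(image)
    print(vector.shape, float(np.linalg.norm(vector)))

    # Batch embedding: shape (len(images), dimension)
    batch = [Image.new("RGB", (224, 224), color=(i * 40, 0, 0)) for i in range(4)]
    vectors = embedder.embed_batch(batch)
    print(vectors.shape)

    # Because embeddings are L2-normalized, cosine similarity is a dot product
    print(float(vectors[0] @ vectors[1]))

    # Fine-tuning sketch: positives should cluster, negatives should move away
    metrics = embedder.fine_tune(
        positive_images=batch[:2],
        negative_images=batch[2:],
        epochs=1,
    )
    print(metrics)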