Initial commit
This commit is contained in:
358
geniusia2/core/embedders/clip_embedder.py
Normal file
358
geniusia2/core/embedders/clip_embedder.py
Normal file
@@ -0,0 +1,358 @@
|
||||
"""
|
||||
CLIP-based embedder implementation.
|
||||
|
||||
This module provides a wrapper around OpenCLIP for generating image embeddings
|
||||
using the CLIP (Contrastive Language-Image Pre-training) model.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from typing import List, Optional
|
||||
import logging
|
||||
|
||||
try:
|
||||
import open_clip
|
||||
except ImportError:
|
||||
open_clip = None
|
||||
|
||||
from .base import EmbedderBase
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CLIPEmbedder(EmbedderBase):
    """
    CLIP-based image embedder using OpenCLIP.

    This embedder uses the ViT-B/32 architecture by default, which produces
    512-dimensional embeddings. It automatically handles GPU/CPU device selection.
    """

    def __init__(
        self,
        model_name: str = "ViT-B-32",
        pretrained: str = "openai",
        device: Optional[str] = None
    ):
        """
        Initialize the CLIP embedder.

        Args:
            model_name: CLIP model architecture (default: ViT-B-32)
            pretrained: Pretrained weights to use (default: openai)
            device: Device to use ('cuda', 'cpu', or None for auto-detect)
                Note: Defaults to CPU to save GPU memory for other models

        Raises:
            ImportError: If open_clip is not installed
            RuntimeError: If model loading fails
        """
        if open_clip is None:
            raise ImportError(
                "OpenCLIP is not installed. "
                "Install it with: pip install open-clip-torch"
            )

        # Default to CPU to save GPU for vision models (Qwen3-VL)
        if device is None:
            device = "cpu"

        self.model_name = model_name
        self.pretrained = pretrained
        self.device = device
        # Discovered by probing the model below; None until loaded.
        self._embedding_dim = None

        # Load model
        try:
            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
                model_name,
                pretrained=pretrained,
                device=device
            )
            self.model.eval()

            # Probe the model with a dummy image to discover the embedding
            # dimension instead of hard-coding it per architecture.
            with torch.no_grad():
                dummy_image = torch.zeros(1, 3, 224, 224).to(self.device)
                dummy_embedding = self.model.encode_image(dummy_image)
                self._embedding_dim = int(dummy_embedding.shape[-1])

            logger.info(
                f"CLIPEmbedder loaded: {model_name} on {device}, "
                f"dimension={self._embedding_dim}"
            )

        except Exception as e:
            # Chain the original exception so the root cause isn't lost.
            raise RuntimeError(f"Failed to load CLIP model: {e}") from e

    def embed(self, image: Image.Image) -> np.ndarray:
        """
        Generate embedding for a single image.

        Args:
            image: PIL Image to embed

        Returns:
            np.ndarray: Normalized embedding vector of shape (dimension,)

        Raises:
            ValueError: If image is invalid
            RuntimeError: If embedding generation fails
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL Image")

        try:
            # Preprocess image
            image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

            # Generate embedding
            with torch.no_grad():
                embedding = self.model.encode_image(image_tensor)
                # L2 normalize for cosine similarity
                embedding = embedding / embedding.norm(dim=-1, keepdim=True)

            return embedding.cpu().numpy().flatten()

        except Exception as e:
            # Chain the original exception so the root cause isn't lost.
            raise RuntimeError(f"Failed to generate embedding: {e}") from e

    def embed_batch(self, images: List[Image.Image]) -> np.ndarray:
        """
        Generate embeddings for multiple images (optimized batch processing).

        Args:
            images: List of PIL Images to embed

        Returns:
            np.ndarray: Array of embeddings with shape (len(images), dimension)

        Raises:
            ValueError: If any image is invalid
            RuntimeError: If embedding generation fails
        """
        if not images:
            return np.array([]).reshape(0, self.get_dimension())

        # Validate all images up front so a bad element fails fast with
        # its index, before any GPU work is done.
        for i, img in enumerate(images):
            if not isinstance(img, Image.Image):
                raise ValueError(f"Image at index {i} is not a PIL Image")

        try:
            # Preprocess all images into one (N, C, H, W) batch tensor
            image_tensors = torch.stack([
                self.preprocess(img) for img in images
            ]).to(self.device)

            # Generate embeddings in a single forward pass
            with torch.no_grad():
                embeddings = self.model.encode_image(image_tensors)
                # L2 normalize for cosine similarity
                embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)

            return embeddings.cpu().numpy()

        except Exception as e:
            # Chain the original exception so the root cause isn't lost.
            raise RuntimeError(f"Failed to generate batch embeddings: {e}") from e

    def get_dimension(self) -> int:
        """
        Get the dimensionality of embeddings.

        Returns:
            int: Embedding dimension (512 for ViT-B/32)
        """
        return self._embedding_dim

    def get_model_name(self) -> str:
        """
        Get model identifier.

        Returns:
            str: Model name (e.g., "clip-vit-b32")
        """
        return f"clip-{self.model_name.lower().replace('/', '-')}"

    def supports_batch(self) -> bool:
        """
        Check if batch processing is supported.

        Returns:
            bool: True (CLIP supports efficient batch processing)
        """
        return True

    def fine_tune(
        self,
        positive_images: List[Image.Image],
        negative_images: List[Image.Image],
        epochs: int = 1,
        learning_rate: float = 1e-4
    ) -> dict:
        """
        Fine-tune the model using contrastive learning.

        This method fine-tunes only the final projection layer to adapt
        the model to user-specific workflows. It uses a simple contrastive
        loss: positive examples should be similar, negative examples should
        be dissimilar.

        Args:
            positive_images: Images from successful workflows
            negative_images: Images from rejected workflows
            epochs: Number of training epochs (default: 1 for speed)
            learning_rate: Learning rate (default: 1e-4)

        Returns:
            dict: Training metrics (loss, accuracy, etc.)
        """
        if not positive_images and not negative_images:
            return {'loss': 0.0, 'note': 'No examples to train on'}

        # Set model to training mode
        self.model.train()

        # Freeze all parameters, then unfreeze only the visual projection
        # (last layer), so training is cheap and the backbone is preserved.
        for param in self.model.parameters():
            param.requires_grad = False

        # Unfreeze visual projection (a leaf Parameter in OpenCLIP ViT models)
        if hasattr(self.model, 'visual') and hasattr(self.model.visual, 'proj'):
            if self.model.visual.proj is not None:
                self.model.visual.proj.requires_grad = True

        # Setup optimizer (only for trainable parameters)
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]

        if not trainable_params:
            logger.warning("No trainable parameters found, skipping fine-tuning")
            self.model.eval()
            return {'loss': 0.0, 'note': 'No trainable parameters'}

        optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)

        total_loss = 0.0
        num_batches = 0

        try:
            for epoch in range(epochs):
                # Process positive examples (should be similar to each other)
                if len(positive_images) >= 2:
                    pos_loss = self._contrastive_loss_positive(positive_images)
                    # Accumulate as a plain float (.item()) so we don't keep
                    # every batch's autograd graph alive in total_loss.
                    total_loss += pos_loss.item()
                    num_batches += 1

                    optimizer.zero_grad()
                    pos_loss.backward()
                    optimizer.step()

                # Process negative examples (should be dissimilar from positives)
                if positive_images and negative_images:
                    neg_loss = self._contrastive_loss_negative(
                        positive_images[:5],  # Use subset for speed
                        negative_images[:5]
                    )
                    total_loss += neg_loss.item()
                    num_batches += 1

                    optimizer.zero_grad()
                    neg_loss.backward()
                    optimizer.step()

        finally:
            # Always restore eval mode and freeze parameters
            self.model.eval()
            for param in self.model.parameters():
                param.requires_grad = False

        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0

        return {
            'loss': float(avg_loss),
            'epochs': epochs,
            'learning_rate': learning_rate,
            'positive_count': len(positive_images),
            'negative_count': len(negative_images),
            'num_batches': num_batches
        }

    def _encode_normalized(self, images: List[Image.Image]) -> torch.Tensor:
        """
        Encode images one at a time and L2-normalize each embedding.

        Gradients stay enabled so this is usable inside the fine-tuning
        loss computations.

        Args:
            images: PIL Images to encode

        Returns:
            torch.Tensor: Normalized embeddings of shape (len(images), dim)
        """
        embeddings = []
        for img in images:
            img_tensor = self.preprocess(img).unsqueeze(0).to(self.device)
            emb = self.model.encode_image(img_tensor)
            emb = emb / emb.norm(dim=-1, keepdim=True)
            embeddings.append(emb)
        return torch.cat(embeddings, dim=0)

    def _contrastive_loss_positive(self, images: List[Image.Image]) -> torch.Tensor:
        """
        Contrastive loss for positive examples (should be similar).

        Args:
            images: List of positive example images

        Returns:
            torch.Tensor: Loss value
        """
        embeddings = self._encode_normalized(images)

        # Compute pairwise cosine similarities
        similarities = torch.mm(embeddings, embeddings.t())

        # Exclude the diagonal (self-similarity) entirely: averaging only
        # over off-diagonal pairs. The previous version zeroed the diagonal
        # but still averaged over all n^2 entries, so each diagonal cell
        # contributed a constant 1 to (1 - sim) and diluted the signal.
        mask = torch.eye(len(images), device=self.device).bool()
        off_diagonal = similarities[~mask]

        # We want high similarity, so minimize (1 - similarity)
        loss = (1 - off_diagonal).mean()

        return loss

    def _contrastive_loss_negative(
        self,
        positive_images: List[Image.Image],
        negative_images: List[Image.Image]
    ) -> torch.Tensor:
        """
        Contrastive loss for negative examples (should be dissimilar).

        Args:
            positive_images: Positive example images
            negative_images: Negative example images

        Returns:
            torch.Tensor: Loss value
        """
        pos_embeddings = self._encode_normalized(positive_images)
        neg_embeddings = self._encode_normalized(negative_images)

        # Compute cross-similarities (positive vs negative)
        similarities = torch.mm(pos_embeddings, neg_embeddings.t())

        # We want low similarity, so minimize similarity directly
        loss = similarities.mean()

        return loss
|
||||
Reference in New Issue
Block a user