""" CLIP-based embedder implementation. This module provides a wrapper around OpenCLIP for generating image embeddings using the CLIP (Contrastive Language-Image Pre-training) model. """ import torch import numpy as np from PIL import Image from typing import List, Optional import logging try: import open_clip except ImportError: open_clip = None from .base import EmbedderBase logger = logging.getLogger(__name__) class CLIPEmbedder(EmbedderBase): """ CLIP-based image embedder using OpenCLIP. This embedder uses the ViT-B/32 architecture by default, which produces 512-dimensional embeddings. It automatically handles GPU/CPU device selection. """ def __init__( self, model_name: str = "ViT-B-32", pretrained: str = "openai", device: Optional[str] = None ): """ Initialize the CLIP embedder. Args: model_name: CLIP model architecture (default: ViT-B-32) pretrained: Pretrained weights to use (default: openai) device: Device to use ('cuda', 'cpu', or None for auto-detect) Note: Defaults to CPU to save GPU memory for other models Raises: ImportError: If open_clip is not installed RuntimeError: If model loading fails """ if open_clip is None: raise ImportError( "OpenCLIP is not installed. " "Install it with: pip install open-clip-torch" ) # Default to CPU to save GPU for vision models (Qwen3-VL) if device is None: device = "cpu" self.model_name = model_name self.pretrained = pretrained self.device = device self._embedding_dim = None # Load model try: self.model, _, self.preprocess = open_clip.create_model_and_transforms( model_name, pretrained=pretrained, device=device ) self.model.eval() # Determine embedding dimension with torch.no_grad(): dummy_image = torch.zeros(1, 3, 224, 224).to(self.device) dummy_embedding = self.model.encode_image(dummy_image) self._embedding_dim = dummy_embedding.shape[-1] logger.info( f"CLIPEmbedder loaded: {model_name} on {device}, " f"dimension={self._embedding_dim}" ) except Exception as e: raise RuntimeError(f"Failed to load CLIP model: {e}") def embed(self, image: Image.Image) -> np.ndarray: """ Generate embedding for a single image. Args: image: PIL Image to embed Returns: np.ndarray: Normalized embedding vector of shape (dimension,) Raises: ValueError: If image is invalid RuntimeError: If embedding generation fails """ if not isinstance(image, Image.Image): raise ValueError("Input must be a PIL Image") try: # Preprocess image image_tensor = self.preprocess(image).unsqueeze(0).to(self.device) # Generate embedding with torch.no_grad(): embedding = self.model.encode_image(image_tensor) # L2 normalize for cosine similarity embedding = embedding / embedding.norm(dim=-1, keepdim=True) return embedding.cpu().numpy().flatten() except Exception as e: raise RuntimeError(f"Failed to generate embedding: {e}") def embed_batch(self, images: List[Image.Image]) -> np.ndarray: """ Generate embeddings for multiple images (optimized batch processing). 
        Args:
            images: List of PIL Images to embed

        Returns:
            np.ndarray: Array of embeddings with shape (len(images), dimension)

        Raises:
            ValueError: If any image is invalid
            RuntimeError: If embedding generation fails
        """
        if not images:
            return np.array([]).reshape(0, self.get_dimension())

        # Validate all images
        for i, img in enumerate(images):
            if not isinstance(img, Image.Image):
                raise ValueError(f"Image at index {i} is not a PIL Image")

        try:
            # Preprocess all images
            image_tensors = torch.stack([
                self.preprocess(img) for img in images
            ]).to(self.device)

            # Generate embeddings in batch
            with torch.no_grad():
                embeddings = self.model.encode_image(image_tensors)
                # L2 normalize for cosine similarity
                embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)

            return embeddings.cpu().numpy()
        except Exception as e:
            raise RuntimeError(f"Failed to generate batch embeddings: {e}") from e

    def get_dimension(self) -> int:
        """
        Get the dimensionality of embeddings.

        Returns:
            int: Embedding dimension (512 for ViT-B/32)
        """
        return self._embedding_dim

    def get_model_name(self) -> str:
        """
        Get model identifier.

        Returns:
            str: Model name (e.g., "clip-vit-b-32")
        """
        return f"clip-{self.model_name.lower().replace('/', '-')}"

    def supports_batch(self) -> bool:
        """
        Check if batch processing is supported.

        Returns:
            bool: True (CLIP supports efficient batch processing)
        """
        return True

    def fine_tune(
        self,
        positive_images: List[Image.Image],
        negative_images: List[Image.Image],
        epochs: int = 1,
        learning_rate: float = 1e-4
    ) -> dict:
        """
        Fine-tune the model using contrastive learning.

        This method fine-tunes only the final projection layer to adapt the
        model to user-specific workflows. It uses a simple contrastive loss:
        positive examples should be similar to each other, and negative
        examples should be dissimilar from the positives.

        Args:
            positive_images: Images from successful workflows
            negative_images: Images from rejected workflows
            epochs: Number of training epochs (default: 1 for speed)
            learning_rate: Learning rate (default: 1e-4)

        Returns:
            dict: Training metrics (loss, example counts, etc.)
""" if not positive_images and not negative_images: return {'loss': 0.0, 'note': 'No examples to train on'} # Set model to training mode self.model.train() # Only train the visual projection layer (last layer) # Freeze all other parameters for param in self.model.parameters(): param.requires_grad = False # Unfreeze visual projection if hasattr(self.model, 'visual') and hasattr(self.model.visual, 'proj'): if self.model.visual.proj is not None: self.model.visual.proj.requires_grad = True # Setup optimizer (only for trainable parameters) trainable_params = [p for p in self.model.parameters() if p.requires_grad] if not trainable_params: logger.warning("No trainable parameters found, skipping fine-tuning") self.model.eval() return {'loss': 0.0, 'note': 'No trainable parameters'} optimizer = torch.optim.Adam(trainable_params, lr=learning_rate) total_loss = 0.0 num_batches = 0 try: for epoch in range(epochs): # Process positive examples (should be similar to each other) if len(positive_images) >= 2: pos_loss = self._contrastive_loss_positive(positive_images) total_loss += pos_loss num_batches += 1 optimizer.zero_grad() pos_loss.backward() optimizer.step() # Process negative examples (should be dissimilar from positives) if positive_images and negative_images: neg_loss = self._contrastive_loss_negative( positive_images[:5], # Use subset for speed negative_images[:5] ) total_loss += neg_loss num_batches += 1 optimizer.zero_grad() neg_loss.backward() optimizer.step() finally: # Always restore eval mode and freeze parameters self.model.eval() for param in self.model.parameters(): param.requires_grad = False avg_loss = total_loss / num_batches if num_batches > 0 else 0.0 return { 'loss': float(avg_loss), 'epochs': epochs, 'learning_rate': learning_rate, 'positive_count': len(positive_images), 'negative_count': len(negative_images), 'num_batches': num_batches } def _contrastive_loss_positive(self, images: List[Image.Image]) -> torch.Tensor: """ Contrastive loss for positive examples (should be similar). Args: images: List of positive example images Returns: torch.Tensor: Loss value """ # Generate embeddings embeddings = [] for img in images: img_tensor = self.preprocess(img).unsqueeze(0).to(self.device) emb = self.model.encode_image(img_tensor) emb = emb / emb.norm(dim=-1, keepdim=True) embeddings.append(emb) embeddings = torch.cat(embeddings, dim=0) # Compute pairwise cosine similarities similarities = torch.mm(embeddings, embeddings.t()) # Loss: maximize similarity (minimize negative similarity) # Exclude diagonal (self-similarity) mask = torch.eye(len(images), device=self.device).bool() similarities = similarities.masked_fill(mask, 0) # We want high similarity, so minimize (1 - similarity) loss = (1 - similarities).mean() return loss def _contrastive_loss_negative( self, positive_images: List[Image.Image], negative_images: List[Image.Image] ) -> torch.Tensor: """ Contrastive loss for negative examples (should be dissimilar). 
        Args:
            positive_images: Positive example images
            negative_images: Negative example images

        Returns:
            torch.Tensor: Loss value
        """
        # Generate embeddings for positives
        pos_embeddings = []
        for img in positive_images:
            img_tensor = self.preprocess(img).unsqueeze(0).to(self.device)
            emb = self.model.encode_image(img_tensor)
            emb = emb / emb.norm(dim=-1, keepdim=True)
            pos_embeddings.append(emb)

        pos_embeddings = torch.cat(pos_embeddings, dim=0)

        # Generate embeddings for negatives
        neg_embeddings = []
        for img in negative_images:
            img_tensor = self.preprocess(img).unsqueeze(0).to(self.device)
            emb = self.model.encode_image(img_tensor)
            emb = emb / emb.norm(dim=-1, keepdim=True)
            neg_embeddings.append(emb)

        neg_embeddings = torch.cat(neg_embeddings, dim=0)

        # Compute cross-similarities (positive vs negative)
        similarities = torch.mm(pos_embeddings, neg_embeddings.t())

        # We want low similarity, so minimize similarity directly
        # (i.e., maximize dissimilarity)
        loss = similarities.mean()

        return loss
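

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The image filenames below are hypothetical
# and must be replaced with real files; because of the relative import of
# EmbedderBase, this block only runs when the module is executed as part of
# its package (e.g. `python -m <package>.clip_embedder`). It shows the
# expected call pattern: embed() -> (dim,) vector, embed_batch() -> (N, dim)
# array, and an optional projection-layer fine_tune() on labelled screenshots.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical example images; replace with real paths.
    embedder = CLIPEmbedder()  # ViT-B-32, OpenAI weights, CPU by default
    positive = Image.open("positive_example.png").convert("RGB")
    negative = Image.open("negative_example.png").convert("RGB")

    # Single-image embedding: L2-normalized vector of length get_dimension()
    vector = embedder.embed(positive)
    print(embedder.get_model_name(), vector.shape)  # e.g. clip-vit-b-32 (512,)

    # Batch embedding: shape (N, dimension)
    batch = embedder.embed_batch([positive, negative])
    print(batch.shape)

    # Optional lightweight fine-tuning of the visual projection layer
    metrics = embedder.fine_tune(
        positive_images=[positive, positive],
        negative_images=[negative],
    )
    print(metrics)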