"""
Pix2Struct-based embedder implementation.

This module provides a wrapper around Google's Pix2Struct model for generating
image embeddings specialized for UI understanding and document screenshots.
"""

import logging
from typing import List, Optional

import numpy as np
import torch
from PIL import Image

try:
    from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
except ImportError:
    # Deferred failure: the constructor raises a clear ImportError instead.
    Pix2StructProcessor = None
    Pix2StructForConditionalGeneration = None

from .base import EmbedderBase

logger = logging.getLogger(__name__)


class Pix2StructEmbedder(EmbedderBase):
    """
    Pix2Struct-based image embedder specialized for UI understanding.

    Pix2Struct is a vision-language model trained on screenshots and structured
    documents, making it particularly well-suited for RPA and UI automation tasks.

    Embeddings are produced by running only the model's vision encoder,
    mean-pooling its last hidden state over the real (non-padding) patches,
    and L2-normalizing the result so dot products equal cosine similarity.
    """

    def __init__(
        self,
        model_name: str = "google/pix2struct-base",
        device: Optional[str] = None,
    ):
        """
        Initialize the Pix2Struct embedder.

        Loads the processor and model, moves the model to ``device``, switches
        it to eval mode, and probes the encoder once with a dummy image to
        record the embedding dimension.

        Args:
            model_name: Pix2Struct model to use (default: google/pix2struct-base)
            device: Device to use ('cuda', 'cpu', or None for auto-detect)

        Raises:
            ImportError: If transformers is not installed
            RuntimeError: If model loading fails
        """
        if Pix2StructProcessor is None or Pix2StructForConditionalGeneration is None:
            raise ImportError(
                "Transformers is not installed or version is too old. "
                "Install it with: pip install transformers>=4.35.0"
            )

        # Auto-detect device
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.model_name = model_name
        self.device = device
        self._embedding_dim = None

        try:
            logger.info("Loading Pix2Struct model: %s", model_name)
            self.processor = Pix2StructProcessor.from_pretrained(model_name)
            self.model = Pix2StructForConditionalGeneration.from_pretrained(
                model_name
            ).to(device)
            self.model.eval()

            # Determine the embedding dimension empirically from one encoder pass,
            # which also verifies that processor + encoder work end to end.
            with torch.no_grad():
                dummy_image = Image.new('RGB', (224, 224), color=(128, 128, 128))
                inputs = self.processor(images=dummy_image, return_tensors="pt").to(device)
                encoder_outputs = self.model.encoder(**inputs)
                self._embedding_dim = encoder_outputs.last_hidden_state.shape[-1]

            logger.info(
                "Pix2StructEmbedder loaded: %s on %s, dimension=%s",
                model_name,
                device,
                self._embedding_dim,
            )
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Failed to load Pix2Struct model: {e}") from e

    @staticmethod
    def _pool_and_normalize(
        last_hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Mean-pool encoder states over real patches, then L2-normalize.

        The Pix2Struct processor pads flattened patches to a fixed length and
        marks real patches in ``attention_mask`` (per the transformers
        Pix2Struct image processor). Averaging over padding positions would
        mix content-independent vectors into every embedding, so only
        masked-in positions are averaged. If no mask is supplied, this falls
        back to a plain mean over the sequence dimension.

        Args:
            last_hidden_state: Encoder output; dim 1 is pooled away
                (assumed (batch, seq, hidden) — matches the original
                ``mean(dim=1)`` usage).
            attention_mask: Optional mask over the sequence positions, or None.

        Returns:
            torch.Tensor: Pooled, unit-norm embeddings of shape (batch, hidden).
        """
        if attention_mask is None:
            pooled = last_hidden_state.mean(dim=1)
        else:
            mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
            summed = (last_hidden_state * mask).sum(dim=1)
            # clamp avoids division by zero for a (degenerate) fully-masked row
            counts = mask.sum(dim=1).clamp_min(1.0)
            pooled = summed / counts
        # normalize() clamps the norm with a small eps, so an all-zero vector
        # does not produce NaNs the way a raw division would.
        return torch.nn.functional.normalize(pooled, p=2, dim=-1)

    def embed(self, image: Image.Image) -> np.ndarray:
        """
        Generate embedding for a single image.

        Args:
            image: PIL Image to embed

        Returns:
            np.ndarray: Normalized embedding vector of shape (dimension,)

        Raises:
            ValueError: If image is invalid
            RuntimeError: If embedding generation fails
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL Image")

        try:
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)
            with torch.no_grad():
                encoder_outputs = self.model.encoder(**inputs)
                embedding = self._pool_and_normalize(
                    encoder_outputs.last_hidden_state,
                    inputs.get("attention_mask"),
                )
            return embedding.cpu().numpy().flatten()
        except Exception as e:
            raise RuntimeError(f"Failed to generate embedding: {e}") from e

    def embed_batch(self, images: List[Image.Image]) -> np.ndarray:
        """
        Generate embeddings for multiple images (optimized batch processing).

        Args:
            images: List of PIL Images to embed

        Returns:
            np.ndarray: Array of embeddings with shape (len(images), dimension)

        Raises:
            ValueError: If any image is invalid
            RuntimeError: If embedding generation fails
        """
        if not images:
            # float32 to match the dtype of non-empty results from the model.
            return np.empty((0, self.get_dimension()), dtype=np.float32)

        for i, img in enumerate(images):
            if not isinstance(img, Image.Image):
                raise ValueError(f"Image at index {i} is not a PIL Image")

        try:
            inputs = self.processor(images=images, return_tensors="pt").to(self.device)
            with torch.no_grad():
                encoder_outputs = self.model.encoder(**inputs)
                embeddings = self._pool_and_normalize(
                    encoder_outputs.last_hidden_state,
                    inputs.get("attention_mask"),
                )
            return embeddings.cpu().numpy()
        except Exception as e:
            raise RuntimeError(f"Failed to generate batch embeddings: {e}") from e

    def get_dimension(self) -> int:
        """
        Get the dimensionality of embeddings.

        Returns:
            int: Size of the encoder's last hidden dimension, measured once
                at construction time.
        """
        return self._embedding_dim

    def get_model_name(self) -> str:
        """
        Get model identifier.

        The last path segment of ``model_name`` is used, prefixed with
        ``pix2struct-`` unless it already starts with ``pix2struct``.

        Returns:
            str: Model name (e.g., "pix2struct-base")
        """
        model_id = self.model_name.split('/')[-1]
        if model_id.startswith('pix2struct'):
            return model_id
        return f"pix2struct-{model_id}"

    def supports_batch(self) -> bool:
        """
        Check if batch processing is supported.

        Returns:
            bool: True (Pix2Struct supports efficient batch processing)
        """
        return True