194 lines
6.7 KiB
Python
194 lines
6.7 KiB
Python
"""
Pix2Struct-based embedder implementation.

This module provides a wrapper around Google's Pix2Struct model for generating
image embeddings specialized for UI understanding and document screenshots.
"""
|
|
|
|
import torch
|
|
import numpy as np
|
|
from PIL import Image
|
|
from typing import List, Optional
|
|
import logging
|
|
|
|
try:
|
|
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
|
|
except ImportError:
|
|
Pix2StructProcessor = None
|
|
Pix2StructForConditionalGeneration = None
|
|
|
|
from .base import EmbedderBase
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Pix2StructEmbedder(EmbedderBase):
    """
    Pix2Struct-based image embedder specialized for UI understanding.

    Pix2Struct is a vision-language model trained on screenshots and structured
    documents, making it particularly well-suited for RPA and UI automation tasks.

    This embedder mean-pools the encoder's last hidden state over the sequence
    dimension and L2-normalizes the result, so embeddings can be compared with
    a plain dot product (cosine similarity).
    """

    def __init__(
        self,
        model_name: str = "google/pix2struct-base",
        device: Optional[str] = None
    ):
        """
        Initialize the Pix2Struct embedder.

        Loads the processor and model, moves the model to ``device``, switches
        it to eval mode, and runs one forward pass on a dummy image to discover
        the embedding dimension.

        Args:
            model_name: Pix2Struct model to use (default: google/pix2struct-base)
            device: Device to use ('cuda', 'cpu', or None for auto-detect)

        Raises:
            ImportError: If transformers is not installed
            RuntimeError: If model loading fails
        """
        if Pix2StructProcessor is None or Pix2StructForConditionalGeneration is None:
            raise ImportError(
                "Transformers is not installed or version is too old. "
                "Install it with: pip install transformers>=4.35.0"
            )

        # Auto-detect device
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.model_name = model_name
        self.device = device
        self._embedding_dim = None

        # Load model and processor
        try:
            logger.info("Loading Pix2Struct model: %s", model_name)

            self.processor = Pix2StructProcessor.from_pretrained(model_name)
            self.model = Pix2StructForConditionalGeneration.from_pretrained(
                model_name
            ).to(device)
            self.model.eval()

            # Probe the encoder once with a dummy image; only the size of the
            # last (hidden) axis is read, which is the embedding dimension.
            with torch.no_grad():
                probe_image = Image.new('RGB', (224, 224), color=(128, 128, 128))
                inputs = self.processor(images=probe_image, return_tensors="pt").to(device)
                hidden = self.model.encoder(**inputs).last_hidden_state
                self._embedding_dim = hidden.shape[-1]

            logger.info(
                "Pix2StructEmbedder loaded: %s on %s, dimension=%s",
                model_name,
                device,
                self._embedding_dim,
            )

        except Exception as e:
            # Chain the cause so the underlying load failure stays visible.
            raise RuntimeError(f"Failed to load Pix2Struct model: {e}") from e

    def _encode_normalized(self, images) -> np.ndarray:
        """
        Run the encoder on one image or a list of images.

        Returns:
            np.ndarray: One L2-normalized, mean-pooled embedding per row.
        """
        inputs = self.processor(images=images, return_tensors="pt").to(self.device)

        with torch.no_grad():
            hidden = self.model.encoder(**inputs).last_hidden_state
            # Mean pooling over sequence dimension.
            # NOTE(review): every sequence position is averaged; if the
            # processor pads patches, padded positions are included too.
            # Confirm whether attention-mask-weighted pooling is wanted
            # (switching would change previously stored embeddings).
            pooled = hidden.mean(dim=1)
            # L2 normalize for cosine similarity. normalize() clamps the norm
            # with a small eps, so an all-zero vector stays zero, not NaN.
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=-1)

        return pooled.cpu().numpy()

    def embed(self, image: Image.Image) -> np.ndarray:
        """
        Generate embedding for a single image.

        Args:
            image: PIL Image to embed

        Returns:
            np.ndarray: Normalized embedding vector of shape (dimension,)

        Raises:
            ValueError: If image is invalid
            RuntimeError: If embedding generation fails
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL Image")

        try:
            return self._encode_normalized(image).flatten()
        except Exception as e:
            raise RuntimeError(f"Failed to generate embedding: {e}") from e

    def embed_batch(self, images: List[Image.Image]) -> np.ndarray:
        """
        Generate embeddings for multiple images (optimized batch processing).

        Args:
            images: List of PIL Images to embed

        Returns:
            np.ndarray: Array of embeddings with shape (len(images), dimension)

        Raises:
            ValueError: If any image is invalid
            RuntimeError: If embedding generation fails
        """
        if not images:
            # float32 so an empty result concatenates with non-empty results
            # without upcasting (assumes the model runs in its default dtype).
            return np.empty((0, self.get_dimension()), dtype=np.float32)

        # Validate all images up front so the error names the offending index.
        for i, img in enumerate(images):
            if not isinstance(img, Image.Image):
                raise ValueError(f"Image at index {i} is not a PIL Image")

        try:
            return self._encode_normalized(images)
        except Exception as e:
            raise RuntimeError(f"Failed to generate batch embeddings: {e}") from e

    def get_dimension(self) -> int:
        """
        Get the dimensionality of embeddings.

        Returns:
            int: Embedding dimension measured at load time
                (768 for pix2struct-base)
        """
        return self._embedding_dim

    def get_model_name(self) -> str:
        """
        Get model identifier.

        Returns:
            str: Model name (e.g., "pix2struct-base")
        """
        # Keep only the last path component; strip a trailing slash first so
        # a name like "models/pix2struct-base/" does not yield an empty id.
        model_id = self.model_name.rstrip('/').split('/')[-1]
        if model_id.startswith('pix2struct'):
            return model_id
        return f"pix2struct-{model_id}"

    def supports_batch(self) -> bool:
        """
        Check if batch processing is supported.

        Returns:
            bool: True (Pix2Struct supports efficient batch processing)
        """
        return True
|