Files
Geniusia_v2/geniusia2/core/embedders/pix2struct_embedder.py
2026-03-05 00:20:25 +01:00

194 lines
6.7 KiB
Python

"""
Pix2Struct-based embedder implementation.

This module provides a wrapper around Google's Pix2Struct model for generating
image embeddings specialized for UI understanding and document screenshots.
"""
import torch
import numpy as np
from PIL import Image
from typing import List, Optional
import logging

# transformers is an optional dependency: when the import fails the two names
# are set to None so Pix2StructEmbedder.__init__ can raise a helpful
# ImportError instead of this module failing to import.
try:
    from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
except ImportError:
    Pix2StructProcessor = None
    Pix2StructForConditionalGeneration = None

from .base import EmbedderBase

# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
class Pix2StructEmbedder(EmbedderBase):
    """
    Pix2Struct-based image embedder specialized for UI understanding.

    Pix2Struct is a vision-language model trained on screenshots and structured
    documents, making it particularly well-suited for RPA and UI automation tasks.

    This embedder uses the encoder's hidden states as embeddings: the last
    hidden state is mean-pooled over the sequence dimension and L2-normalized,
    so dot products between embeddings equal cosine similarity.
    """

    def __init__(
        self,
        model_name: str = "google/pix2struct-base",
        device: Optional[str] = None
    ) -> None:
        """
        Initialize the Pix2Struct embedder.

        Args:
            model_name: Pix2Struct model to use (default: google/pix2struct-base)
            device: Device to use ('cuda', 'cpu', or None for auto-detect)

        Raises:
            ImportError: If transformers is not installed
            RuntimeError: If model loading fails
        """
        if Pix2StructProcessor is None or Pix2StructForConditionalGeneration is None:
            raise ImportError(
                "Transformers is not installed or version is too old. "
                "Install it with: pip install transformers>=4.35.0"
            )

        # Auto-detect device when the caller did not pin one.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.model_name = model_name
        self.device = device
        self._embedding_dim: Optional[int] = None

        try:
            # Lazy %-style args: the message is only formatted if the
            # logger is actually enabled for INFO.
            logger.info("Loading Pix2Struct model: %s", model_name)
            self.processor = Pix2StructProcessor.from_pretrained(model_name)
            self.model = Pix2StructForConditionalGeneration.from_pretrained(
                model_name
            ).to(device)
            self.model.eval()

            # Probe the encoder once with a dummy image so get_dimension()
            # can answer without a real input.
            with torch.no_grad():
                dummy_image = Image.new('RGB', (224, 224), color=(128, 128, 128))
                inputs = self.processor(images=dummy_image, return_tensors="pt").to(device)
                encoder_outputs = self.model.encoder(**inputs)
                self._embedding_dim = encoder_outputs.last_hidden_state.shape[-1]

            logger.info(
                "Pix2StructEmbedder loaded: %s on %s, dimension=%s",
                model_name, device, self._embedding_dim
            )
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Failed to load Pix2Struct model: {e}") from e

    def _encode(self, images) -> torch.Tensor:
        """
        Run processor + encoder and return L2-normalized, mean-pooled
        embeddings of shape (batch, dim).

        Shared pipeline used by both embed() and embed_batch(); accepts a
        single PIL Image or a list of them (whatever the processor takes).
        """
        inputs = self.processor(images=images, return_tensors="pt").to(self.device)
        with torch.no_grad():
            encoder_outputs = self.model.encoder(**inputs)
            # Mean pooling over the sequence (patch) dimension.
            embeddings = encoder_outputs.last_hidden_state.mean(dim=1)
            # L2 normalize for cosine similarity.
            embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
        return embeddings

    def embed(self, image: Image.Image) -> np.ndarray:
        """
        Generate embedding for a single image.

        Args:
            image: PIL Image to embed

        Returns:
            np.ndarray: Normalized embedding vector of shape (dimension,)

        Raises:
            ValueError: If image is invalid
            RuntimeError: If embedding generation fails
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL Image")
        try:
            return self._encode(image).cpu().numpy().flatten()
        except Exception as e:
            raise RuntimeError(f"Failed to generate embedding: {e}") from e

    def embed_batch(self, images: List[Image.Image]) -> np.ndarray:
        """
        Generate embeddings for multiple images (optimized batch processing).

        Args:
            images: List of PIL Images to embed

        Returns:
            np.ndarray: Array of embeddings with shape (len(images), dimension)

        Raises:
            ValueError: If any image is invalid
            RuntimeError: If embedding generation fails
        """
        # Empty input: return an empty (0, dim) array rather than erroring.
        if not images:
            return np.array([]).reshape(0, self.get_dimension())

        # Validate all inputs up front so a bad element fails fast with
        # its index, before any GPU work is done.
        for i, img in enumerate(images):
            if not isinstance(img, Image.Image):
                raise ValueError(f"Image at index {i} is not a PIL Image")

        try:
            return self._encode(images).cpu().numpy()
        except Exception as e:
            raise RuntimeError(f"Failed to generate batch embeddings: {e}") from e

    def get_dimension(self) -> int:
        """
        Get the dimensionality of embeddings.

        Returns:
            int: Embedding dimension (768 for pix2struct-base), probed from
                the encoder at load time.
        """
        return self._embedding_dim

    def get_model_name(self) -> str:
        """
        Get model identifier.

        Returns:
            str: Model name (e.g., "pix2struct-base")
        """
        # Drop the hub namespace ("google/...") and ensure a pix2struct- prefix.
        model_id = self.model_name.split('/')[-1]
        return f"pix2struct-{model_id}" if not model_id.startswith('pix2struct') else model_id

    def supports_batch(self) -> bool:
        """
        Check if batch processing is supported.

        Returns:
            bool: True (Pix2Struct supports efficient batch processing)
        """
        return True