Initial commit
This commit is contained in:
193
geniusia2/core/embedders/pix2struct_embedder.py
Normal file
193
geniusia2/core/embedders/pix2struct_embedder.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
Pix2Struct-based embedder implementation.
|
||||
|
||||
This module provides a wrapper around Google's Pix2Struct model for generating
|
||||
image embeddings specialized for UI understanding and document screenshots.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from typing import List, Optional
|
||||
import logging
|
||||
|
||||
try:
|
||||
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
|
||||
except ImportError:
|
||||
Pix2StructProcessor = None
|
||||
Pix2StructForConditionalGeneration = None
|
||||
|
||||
from .base import EmbedderBase
|
||||
|
||||
|
||||
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Pix2StructEmbedder(EmbedderBase):
    """
    Pix2Struct-based image embedder specialized for UI understanding.

    Pix2Struct is a vision-language model trained on screenshots and structured
    documents, making it particularly well-suited for RPA and UI automation tasks.

    This embedder uses the encoder's hidden states as embeddings: the last
    hidden state is mean-pooled over the sequence dimension and L2-normalized,
    so that a dot product between two embeddings is their cosine similarity.
    """

    def __init__(
        self,
        model_name: str = "google/pix2struct-base",
        device: Optional[str] = None
    ):
        """
        Initialize the Pix2Struct embedder.

        Loads the processor and model, moves the model to ``device``, switches
        it to eval mode, and runs one dummy image through the encoder to
        discover the embedding dimension.

        Args:
            model_name: Pix2Struct model to use (default: google/pix2struct-base)
            device: Device to use ('cuda', 'cpu', or None for auto-detect)

        Raises:
            ImportError: If transformers is not installed
            RuntimeError: If model loading fails
        """
        if Pix2StructProcessor is None or Pix2StructForConditionalGeneration is None:
            # The requirement is quoted so that a shell does not interpret
            # ">=4.35.0" as an output redirection when the hint is pasted.
            raise ImportError(
                "Transformers is not installed or version is too old. "
                'Install it with: pip install "transformers>=4.35.0"'
            )

        # Auto-detect device
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.model_name = model_name
        self.device = device
        self._embedding_dim = None

        # Load model and processor
        try:
            logger.info(f"Loading Pix2Struct model: {model_name}")

            self.processor = Pix2StructProcessor.from_pretrained(model_name)
            self.model = Pix2StructForConditionalGeneration.from_pretrained(
                model_name
            ).to(device)
            self.model.eval()

            # Determine embedding dimension from the encoder by running a
            # single uniform gray dummy image through it.
            with torch.no_grad():
                dummy_image = Image.new('RGB', (224, 224), color=(128, 128, 128))
                inputs = self.processor(images=dummy_image, return_tensors="pt").to(device)
                encoder_outputs = self.model.encoder(**inputs)
                # Width of the last hidden state == size of the pooled embedding
                self._embedding_dim = encoder_outputs.last_hidden_state.shape[-1]

            logger.info(
                f"Pix2StructEmbedder loaded: {model_name} on {device}, "
                f"dimension={self._embedding_dim}"
            )

        except Exception as e:
            # Chain the original exception so the root cause stays in the traceback.
            raise RuntimeError(f"Failed to load Pix2Struct model: {e}") from e

    def _encode(self, images):
        """
        Run the encoder on one image or a list of images.

        Args:
            images: A PIL Image or a list of PIL Images (already validated
                by the caller).

        Returns:
            torch.Tensor: L2-normalized, mean-pooled encoder states with one
            row per input image, still on ``self.device``.
        """
        inputs = self.processor(images=images, return_tensors="pt").to(self.device)

        with torch.no_grad():
            encoder_outputs = self.model.encoder(**inputs)
            # Mean pooling over sequence dimension.
            # NOTE(review): the mean covers every sequence position; if the
            # processor pads patch sequences, padded positions are included.
            # Confirm whether attention-mask-weighted pooling is wanted before
            # changing this, since it would alter previously stored embeddings.
            pooled = encoder_outputs.last_hidden_state.mean(dim=1)
            # L2 normalize for cosine similarity. F.normalize clamps the
            # denominator, so an all-zero vector stays zero instead of
            # becoming NaN through a division by zero.
            return torch.nn.functional.normalize(pooled, p=2, dim=-1)

    def embed(self, image: Image.Image) -> np.ndarray:
        """
        Generate embedding for a single image.

        Args:
            image: PIL Image to embed

        Returns:
            np.ndarray: Normalized embedding vector of shape (dimension,)

        Raises:
            ValueError: If image is invalid
            RuntimeError: If embedding generation fails
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL Image")

        try:
            return self._encode(image).cpu().numpy().flatten()
        except Exception as e:
            raise RuntimeError(f"Failed to generate embedding: {e}") from e

    def embed_batch(self, images: List[Image.Image]) -> np.ndarray:
        """
        Generate embeddings for multiple images (optimized batch processing).

        Args:
            images: List of PIL Images to embed

        Returns:
            np.ndarray: Array of embeddings with shape (len(images), dimension)

        Raises:
            ValueError: If any image is invalid
            RuntimeError: If embedding generation fails
        """
        if not images:
            # Explicit float32 so the empty result does not silently fall back
            # to NumPy's float64 default (presumably matching the dtype of
            # non-empty results from a default-precision model).
            return np.empty((0, self.get_dimension()), dtype=np.float32)

        # Validate all images before doing any model work
        for i, img in enumerate(images):
            if not isinstance(img, Image.Image):
                raise ValueError(f"Image at index {i} is not a PIL Image")

        try:
            return self._encode(images).cpu().numpy()
        except Exception as e:
            raise RuntimeError(f"Failed to generate batch embeddings: {e}") from e

    def get_dimension(self) -> int:
        """
        Get the dimensionality of embeddings.

        Returns:
            int: Embedding dimension (768 for pix2struct-base), as measured
            from the encoder output during initialization
        """
        return self._embedding_dim

    def get_model_name(self) -> str:
        """
        Get model identifier.

        Returns:
            str: Model name (e.g., "pix2struct-base")
        """
        # Extract the last path component; trailing slashes are stripped first
        # so that a directory-style name such as "models/pix2struct-base/"
        # does not yield an empty component.
        model_id = self.model_name.rstrip('/').split('/')[-1]
        if model_id.startswith('pix2struct'):
            return model_id
        return f"pix2struct-{model_id}"

    def supports_batch(self) -> bool:
        """
        Check if batch processing is supported.

        Returns:
            bool: True (Pix2Struct supports efficient batch processing)
        """
        return True
|
||||
Reference in New Issue
Block a user