Initial commit
This commit is contained in:
206
test_pix2struct_vs_clip.py
Executable file
206
test_pix2struct_vs_clip.py
Executable file
@@ -0,0 +1,206 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Benchmark Pix2Struct vs CLIP for UI understanding.
|
||||
|
||||
This script compares the two models on:
|
||||
1. Embedding quality for UI screenshots
|
||||
2. Performance (single and batch embedding time)
|
||||
3. Similarity matching accuracy
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from geniusia2.core.embedders import CLIPEmbedder, EmbeddingManager
|
||||
|
||||
# Pix2Struct support is optional: record whether its embedder could be
# imported so the rest of the script can skip that half of the comparison.
try:
    from geniusia2.core.embedders import Pix2StructEmbedder
    PIX2STRUCT_AVAILABLE = True
except ImportError:
    PIX2STRUCT_AVAILABLE = False
    # The requirement specifier is quoted so that pasting the command into a
    # shell does not interpret ">=4.35.0" as an output redirection.
    print('⚠️ Pix2Struct not available. Install with: pip install "transformers>=4.35.0"')
|
||||
|
||||
|
||||
def create_ui_screenshot(text: str, button_color=(100, 150, 255), size=(400, 300)):
    """Create a fake UI screenshot containing a single labelled button.

    Args:
        text: Label drawn in white over the button.
        button_color: RGB fill colour of the button.
        size: (width, height) of the generated image in pixels.

    Returns:
        A new RGB PIL image with a light grey background, one outlined
        button rectangle and the label text.
    """
    img = Image.new('RGB', size, color=(240, 240, 240))
    draw = ImageDraw.Draw(img)

    # Draw a button. The geometry is fixed and is not derived from ``size``.
    button_rect = [100, 100, 300, 150]
    draw.rectangle(button_rect, fill=button_color, outline=(50, 50, 50), width=2)

    # Prefer a scalable TrueType font; fall back to Pillow's built-in bitmap
    # font when the file is missing/unreadable (OSError) or FreeType support
    # is not compiled in (ImportError). Anything else is a real error and
    # must not be swallowed.
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 20)
    except (OSError, ImportError):
        font = ImageFont.load_default()

    # Measure the rendered label so it can be centred on the button's
    # midpoint, (200, 125).
    text_bbox = draw.textbbox((0, 0), text, font=font)
    text_width = text_bbox[2] - text_bbox[0]
    text_height = text_bbox[3] - text_bbox[1]
    text_pos = (200 - text_width // 2, 125 - text_height // 2)
    draw.text(text_pos, text, fill=(255, 255, 255), font=font)

    return img
|
||||
|
||||
|
||||
def benchmark_model(embedder, name, images):
    """Benchmark an embedder on a set of images and print the measurements.

    Three things are measured: the time of a single ``embed`` call, the time
    of one ``embed_batch`` call over all images, and how well the embeddings
    separate a matching button from a different one.

    Args:
        embedder: Object providing ``embed(image)``, ``embed_batch(images)``
            and ``get_dimension()``.
        name: Human-readable model name used in the printed report and in
            the returned dict.
        images: Non-empty sequence of images. ``images[0]`` serves as the
            reference for the similarity test and is expected to be the
            blue "Submit" button.

    Returns:
        A dict with the keys ``name``, ``dimension``, ``single_time_ms``,
        ``batch_time_ms``, ``time_per_image_ms``, ``sim_similar``,
        ``sim_different`` and ``discrimination``.

    Raises:
        ValueError: If ``images`` is empty.
    """
    image_count = len(images)
    if image_count == 0:
        # Fail early with a clear message instead of an IndexError on
        # images[0] or a division by zero in the per-image average.
        raise ValueError("benchmark_model requires at least one image")

    print(f"\n{'='*60}")
    print(f"BENCHMARKING: {name}")
    print(f"{'='*60}")

    # Test single embedding. perf_counter() is a monotonic, high-resolution
    # clock intended for measuring short intervals.
    print("\n1. Single Embedding")
    start = time.perf_counter()
    emb = embedder.embed(images[0])
    single_time = time.perf_counter() - start
    print(f" Time: {single_time*1000:.2f}ms")
    print(f" Shape: {emb.shape}")
    print(f" Norm: {np.linalg.norm(emb):.4f}")

    # Test batch embedding
    print(f"\n2. Batch Embedding ({image_count} images)")
    start = time.perf_counter()
    embs = embedder.embed_batch(images)
    batch_time = time.perf_counter() - start
    time_per_image_ms = batch_time * 1000 / image_count
    print(f" Time: {batch_time*1000:.2f}ms")
    print(f" Time per image: {time_per_image_ms:.2f}ms")
    print(f" Shape: {embs.shape}")

    # Test similarity: compare the reference image against a freshly drawn
    # identical button and against a button with different text and colour.
    print("\n3. Similarity Test")
    similar_img = create_ui_screenshot("Submit", button_color=(100, 150, 255))
    different_img = create_ui_screenshot("Cancel", button_color=(255, 100, 100))

    ref_emb = embedder.embed(images[0])  # "Submit" button
    similar_emb = embedder.embed(similar_img)
    different_emb = embedder.embed(different_img)

    # NOTE(review): a plain dot product equals cosine similarity only if the
    # embedder returns unit-norm vectors — check the "Norm" line printed above.
    sim_similar = np.dot(ref_emb, similar_emb)
    sim_different = np.dot(ref_emb, different_emb)
    discrimination = sim_similar - sim_different

    print(f" Submit vs Submit: {sim_similar:.4f}")
    print(f" Submit vs Cancel: {sim_different:.4f}")
    print(f" Discrimination: {discrimination:.4f} (higher is better)")

    return {
        'name': name,
        'dimension': embedder.get_dimension(),
        'single_time_ms': single_time * 1000,
        'batch_time_ms': batch_time * 1000,
        'time_per_image_ms': time_per_image_ms,
        'sim_similar': sim_similar,
        'sim_different': sim_different,
        'discrimination': discrimination,
    }
|
||||
|
||||
|
||||
def main():
    """Run the CLIP vs Pix2Struct benchmark and print a comparison report.

    CLIP is always benchmarked; Pix2Struct only when its embedder could be
    imported. A Pix2Struct failure is reported with a traceback but does not
    abort the run.

    Returns:
        Process exit code: 1 if the CLIP benchmark fails, otherwise 0.
    """
    rule = "=" * 60

    def banner(title):
        # Section header: blank line, rule, title, rule.
        print("\n" + rule)
        print(title)
        print(rule)

    banner("PIX2STRUCT VS CLIP BENCHMARK")

    # Build the synthetic screenshots from (label, colour) pairs.
    print("\nCreating test UI screenshots...")
    button_specs = [
        ("Submit", (100, 150, 255)),
        ("OK", (100, 200, 100)),
        ("Cancel", (255, 100, 100)),
        ("Apply", (150, 100, 255)),
        ("Close", (200, 200, 200)),
    ]
    test_images = [
        create_ui_screenshot(label, button_color=rgb)
        for label, rgb in button_specs
    ]
    print(f"✓ Created {len(test_images)} test images")

    results = []

    # CLIP is the baseline: without it there is nothing to compare against.
    banner("Testing CLIP")
    try:
        clip = CLIPEmbedder(device='cpu')
        results.append(benchmark_model(clip, "CLIP ViT-B/32", test_images))
    except Exception as err:
        print(f"❌ CLIP failed: {err}")
        return 1

    # Pix2Struct is optional and its failure is non-fatal.
    if not PIX2STRUCT_AVAILABLE:
        print("\n⚠️ Skipping Pix2Struct (not installed)")
    else:
        banner("Testing Pix2Struct")
        try:
            pix2struct = Pix2StructEmbedder(device='cpu')
            results.append(
                benchmark_model(pix2struct, "Pix2Struct Base", test_images)
            )
        except Exception as err:
            print(f"❌ Pix2Struct failed: {err}")
            import traceback
            traceback.print_exc()

    banner("COMPARISON SUMMARY")

    if len(results) != 2:
        # Only the CLIP numbers are available.
        clip_only = results[0]
        print("\n✓ CLIP benchmark completed")
        print(f" Dimension: {clip_only['dimension']}")
        print(f" Speed: {clip_only['time_per_image_ms']:.2f}ms per image")
        print(f" Discrimination: {clip_only['discrimination']:.4f}")
    else:
        clip_res, pix_res = results
        clip_ms = clip_res['time_per_image_ms']
        pix_ms = pix_res['time_per_image_ms']
        clip_disc = clip_res['discrimination']
        pix_disc = pix_res['discrimination']

        print(f"\n{'Metric':<30} {'CLIP':<15} {'Pix2Struct':<15} {'Winner':<10}")
        print("-" * 70)

        # Dimension row has no winner.
        print(f"{'Embedding Dimension':<30} {clip_res['dimension']:<15} {pix_res['dimension']:<15} {'-':<10}")

        # Speed row: lower time per image wins.
        clip_faster = clip_ms < pix_ms
        if clip_faster:
            speed_winner = "CLIP"
        else:
            speed_winner = "Pix2Struct"
        print(f"{'Time per image (ms)':<30} {clip_ms:<15.2f} {pix_ms:<15.2f} {speed_winner:<10}")

        # Discrimination row: a larger gap wins; ties go to CLIP.
        pix_better = pix_disc > clip_disc
        if pix_better:
            quality_winner = "Pix2Struct"
        else:
            quality_winner = "CLIP"
        print(f"{'UI Discrimination':<30} {clip_disc:<15.4f} {pix_disc:<15.4f} {quality_winner:<10}")

        banner("RECOMMENDATION")

        if pix_better:
            print("✅ Pix2Struct shows better UI understanding")
            print(" Recommended for production use")
        else:
            print("⚠️ CLIP performs similarly or better")
            print(" Pix2Struct may not provide significant benefit")

        if not clip_faster:
            slowdown = pix_ms / clip_ms
            print(f"\n⚠️ Pix2Struct is {slowdown:.1f}x slower than CLIP")
            print(" Consider performance vs accuracy tradeoff")

    print("\n✅ Benchmark complete!")
    return 0
|
||||
|
||||
|
||||
# Script entry point: exit the process with the status code returned by main().
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
Reference in New Issue
Block a user