Clean up folder structure

This commit is contained in:
Thomas Dhome Casanova (from Dev Box)
2025-01-29 21:54:35 -08:00
parent 53900f8411
commit 124d9f6fb6
19 changed files with 141 additions and 1952 deletions

View File

@@ -0,0 +1,44 @@
import requests
import base64
from pathlib import Path
from computer_use_demo.tools.screen_capture import get_screenshot
from computer_use_demo.gui_agent.llm_utils.oai import encode_image
OUTPUT_DIR = "./tmp/outputs"
class OmniParserClient:
def __init__(self,
url: str) -> None:
self.url = url
def __call__(self,):
screenshot, screenshot_path = get_screenshot()
screenshot_path = str(screenshot_path)
image_base64 = encode_image(screenshot_path)
response = requests.post(self.url, json={"base64_image": image_base64})
response_json = response.json()
print('omniparser latency:', response_json['latency'])
som_image_data = base64.b64decode(response_json['som_image_base64'])
screenshot_path_uuid = Path(screenshot_path).stem.replace("screenshot_", "")
som_screenshot_path = f"{OUTPUT_DIR}/screenshot_som_{screenshot_path_uuid}.png"
with open(som_screenshot_path, "wb") as f:
f.write(som_image_data)
response_json['width'] = screenshot.size[0]
response_json['height'] = screenshot.size[1]
response_json['original_screenshot_base64'] = image_base64
response_json['screenshot_uuid'] = screenshot_path_uuid
response_json = self.reformat_messages(response_json)
return response_json
def reformat_messages(self, response_json: dict):
screen_info = ""
for idx, element in enumerate(response_json["parsed_content_list"]):
element['idx'] = idx
if element['type'] == 'text':
screen_info += f'ID: {idx}, Text: {element["content"]}\n'
elif element['type'] == 'icon':
screen_info += f'ID: {idx}, Icon: {element["content"]}\n'
response_json['screen_info'] = screen_info
return response_json

View File

@@ -1,23 +1,16 @@
import json
import asyncio
import platform
from collections.abc import Callable
from datetime import datetime
from enum import StrEnum
from typing import Any, cast, Dict, Callable
from typing import cast, Callable
import uuid
import requests
from PIL import Image, ImageDraw
import base64
from io import BytesIO
from pathlib import Path
from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, APIResponse
from anthropic.types import TextBlock, ToolResultBlockParam
from anthropic import APIResponse
from anthropic.types import ToolResultBlockParam
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam, BetaUsage
from computer_use_demo.tools.screen_capture import get_screenshot
from computer_use_demo.gui_agent.llm_utils.oai import run_oai_interleaved, encode_image
from computer_use_demo.gui_agent.llm_utils.oai import run_oai_interleaved
from computer_use_demo.colorful_text import colorful_text_vlm
import time
import re
@@ -33,67 +26,6 @@ def extract_data(input_string, data_type):
# Return the first match if exists, trimming whitespace and ignoring potential closing backticks
return matches[0][0].strip() if matches else input_string
class OmniParser:
def __init__(self,
url: str) -> None:
self.url = url
if not self.url:
config = {
'som_model_path': '../weights/icon_detect_v1_5/model_v1_5.pt',
'device': 'cpu',
'caption_model_name': 'florence2',
'caption_model_path': '../weights/icon_caption_florence',
'BOX_TRESHOLD': 0.05
}
from computer_use_demo.omniparser_agent.omniparser import Omniparser as Omniparser_class
self.omniparser = Omniparser_class(config=config)
def __call__(self,):
screenshot, screenshot_path = get_screenshot()
screenshot_path = str(screenshot_path)
image_base64 = encode_image(screenshot_path)
if self.url:
response = requests.post(self.url, json={"base64_image": image_base64, 'prompt': 'omniparser process'})
response_json = response.json()
else:
start = time.time()
dino_labled_img, parsed_content_list = self.omniparser.parse(image_base64)
latency = time.time() - start
response_json = {
'som_image_base64': dino_labled_img,
'parsed_content_list': parsed_content_list,
'latency': latency
}
som_image_data = base64.b64decode(response_json['som_image_base64'])
screenshot_path_uuid = Path(screenshot_path).stem.replace("screenshot_", "")
som_screenshot_path = f"{OUTPUT_DIR}/screenshot_som_{screenshot_path_uuid}.png"
with open(som_screenshot_path, "wb") as f:
f.write(som_image_data)
response_json['width'] = screenshot.size[0]
response_json['height'] = screenshot.size[1]
response_json['original_screenshot_base64'] = image_base64
response_json['screenshot_uuid'] = screenshot_path_uuid
# example response_json: {"som_image_base64": dino_labled_img, "parsed_content_list": parsed_content_list, "latency": 0.1}
print('omniparser latency:', response_json['latency'])
response_json = self.reformat_messages(response_json)
return response_json
def reformat_messages(self, response_json: dict):
screen_info = ""
for idx, element in enumerate(response_json["parsed_content_list"]):
element['idx'] = idx
if element['type'] == 'text':
# screen_info += f'''<p id={idx} class="text" alt="{element['content']}"> </p>\n'''
screen_info += f'ID: {idx}, Text: {element["content"]}\n'
elif element['type'] == 'icon':
# screen_info += f'''<img id={idx} class="icon" alt="{element['content']}"> </img>\n'''
screen_info += f'ID: {idx}, Icon: {element["content"]}\n'
response_json['screen_info'] = screen_info
return response_json
class VLMAgent:
def __init__(
self,

View File

@@ -15,9 +15,10 @@ from anthropic.types.beta import (
)
from computer_use_demo.tools import ToolResult
from computer_use_demo.gui_agent.anthropic_agent import AnthropicActor
from computer_use_demo.agent.llm_utils.omniparserclient import OmniParserClient
from computer_use_demo.agent.anthropic_agent import AnthropicActor
from computer_use_demo.agent.vlm_agent import VLMAgent
from computer_use_demo.executor.anthropic_executor import AnthropicExecutor
from computer_use_demo.omniparser_agent.vlm_agent import OmniParser, VLMAgent
BETA_FLAG = "computer-use-2024-10-22"
@@ -52,7 +53,7 @@ def sampling_loop_sync(
Synchronous agentic sampling loop for the assistant/tool interaction of computer use.
"""
print('in sampling_loop_sync, model:', model)
omniparser = OmniParser(url=f"http://{omniparser_url}/send_text/" if omniparser_url else None)
omniparser_client = OmniParserClient(url=f"http://{omniparser_url}/parse/")
if model == "claude-3-5-sonnet-20241022":
# Register Actor and Executor
actor = AnthropicActor(
@@ -94,7 +95,7 @@ def sampling_loop_sync(
if model == "claude-3-5-sonnet-20241022": # Anthropic loop
while True:
parsed_screen = omniparser() # parsed_screen: {"som_image_base64": dino_labled_img, "parsed_content_list": parsed_content_list, "screen_info"}
parsed_screen = omniparser_client() # parsed_screen: {"som_image_base64": dino_labled_img, "parsed_content_list": parsed_content_list, "screen_info"}
import pdb; pdb.set_trace()
screen_info_block = TextBlock(text='Below is the structured accessibility information of the current UI screen, which includes text and icons you can operate on, take these information into account when you are making the prediction for the next action. Note you will still need to take screenshot to get the image: \n' + parsed_screen['screen_info'], type='text')
# # messages[-1]['content'].append(screen_info_block)
@@ -112,7 +113,7 @@ def sampling_loop_sync(
elif model == "omniparser + gpt-4o" or model == "omniparser + phi35v":
while True:
parsed_screen = omniparser()
parsed_screen = omniparser_client()
tools_use_needed, vlm_response_json = actor(messages=messages, parsed_screen=parsed_screen)
for message, tool_result_content in executor(tools_use_needed, messages):

View File

@@ -1,77 +0,0 @@
# uvicorn remote_request:app --host 0.0.0.0 --port 8000 --reload
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from computer_use_demo.omniparser_agent.utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box
import torch
from PIL import Image
from typing import Dict, Tuple, List
import base64
import io
# config = {
# 'som_model_path': '../weights/icon_detect_v1_5/model_v1_5.pt',
# 'device': 'cpu',
# 'caption_model_name': 'florence2',
# 'caption_model_path': '../weights/icon_caption_florence',
# 'BOX_TRESHOLD': 0.05
# }
class Omniparser(object):
def __init__(self, config: Dict):
self.config = config
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.som_model = get_yolo_model(model_path=config['som_model_path'])
self.caption_model_processor = get_caption_model_processor(model_name=config['caption_model_name'], model_name_or_path=config['caption_model_path'], device=device)
print('Omniparser initialized!!!')
def parse(self, image_base64: str):
image_path = './demo_image.jpg'
with open(image_path, "wb") as fh:
fh.write(base64.b64decode(image_base64))
print('Parsing image:', image_path)
image = Image.open(image_path)
print('image size:', image.size)
box_overlay_ratio = max(image.size) / 3200
draw_bbox_config = {
'text_scale': 0.8 * box_overlay_ratio,
'text_thickness': max(int(2 * box_overlay_ratio), 1),
'text_padding': max(int(3 * box_overlay_ratio), 1),
'thickness': max(int(3 * box_overlay_ratio), 1),
}
BOX_TRESHOLD = self.config['BOX_TRESHOLD']
ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.8}, use_paddleocr=True)
text, ocr_bbox = ocr_bbox_rslt
dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, self.som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text,use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128)
return dino_labled_img, parsed_content_list
# from fastapi import FastAPI
# from pydantic import BaseModel
# app = FastAPI()
# class Item(BaseModel):
# base64_image: str
# prompt: str
# Omniparser = Omniparser(config)
# @app.post("/send_text/")
# async def send_text(item: Item):
# print('start parsing...')
# import time
# start = time.time()
# dino_labled_img, parsed_content_list = Omniparser.parse(item.base64_image)
# latency = time.time() - start
# print('time:', latency)
# return {"som_image_base64": dino_labled_img, "parsed_content_list": parsed_content_list, 'latency': latency}

View File

@@ -1,935 +0,0 @@
# from ultralytics import YOLO
import os
import io
import base64
import time
from PIL import Image, ImageDraw, ImageFont
import json
import requests
# utility function
import os
from openai import AzureOpenAI
import json
import sys
import os
import cv2
import numpy as np
# %matplotlib inline
from matplotlib import pyplot as plt
import easyocr
from paddleocr import PaddleOCR
reader = easyocr.Reader(['en', 'ch_sim'], gpu=True)
paddle_ocr = PaddleOCR(
lang='en', # other lang also available
use_angle_cls=False,
use_gpu=False, # using cuda will conflict with pytorch in the same process
show_log=False,
max_batch_size=1024,
use_dilation=True, # improves accuracy
det_db_score_mode='slow', # improves accuracy
rec_batch_num=1024)
import time
import base64
import os
import ast
import torch
from typing import Tuple, List
from torchvision.ops import box_convert
import re
from torchvision.transforms import ToPILImage
import supervision as sv
import torchvision.transforms as T
def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
if not device:
device = "cuda" if torch.cuda.is_available() else "cpu"
if model_name == "blip2":
from transformers import Blip2Processor, Blip2ForConditionalGeneration
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
if device == 'cpu':
model = Blip2ForConditionalGeneration.from_pretrained(
model_name_or_path, device_map=None, torch_dtype=torch.float32
)
else:
model = Blip2ForConditionalGeneration.from_pretrained(
model_name_or_path, device_map=None, torch_dtype=torch.float16
).to(device)
elif model_name == "florence2":
from transformers import AutoProcessor, AutoModelForCausalLM
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
if device == 'cpu':
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True)
else:
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True).to(device)
return {'model': model.to(device), 'processor': processor}
def get_yolo_model(model_path):
from ultralytics import YOLO
# Load the model.
model = YOLO(model_path)
return model
@torch.inference_mode()
def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=None):
# Number of samples per batch, --> 256 roughly takes 23 GB of GPU memory for florence model
to_pil = ToPILImage()
if starting_idx:
non_ocr_boxes = filtered_boxes[starting_idx:]
else:
non_ocr_boxes = filtered_boxes
croped_pil_image = []
t0 = time.time()
for i, coord in enumerate(non_ocr_boxes):
try:
xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
cropped_image = image_source[ymin:ymax, xmin:xmax, :]
# resize the image to 224x224 to avoid long overhead in clipimageprocessor # TODO
cropped_image = cv2.resize(cropped_image, (64, 64))
croped_pil_image.append(to_pil(cropped_image))
except:
continue
# print('time to prepare bbox:', time.time()-t0)
model, processor = caption_model_processor['model'], caption_model_processor['processor']
if not prompt:
if 'florence' in model.config.name_or_path:
prompt = "<CAPTION>"
else:
prompt = "The image shows"
generated_texts = []
device = model.device
# batch_size = 64
for i in range(0, len(croped_pil_image), batch_size):
start = time.time()
batch = croped_pil_image[i:i+batch_size]
t1 = time.time()
if model.device.type == 'cuda':
inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16)
else:
inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
t2 = time.time()
# print('time to process image + tokenize text inputs:', t2-t1)
if 'florence' in model.config.name_or_path:
generated_ids = model.generate(input_ids=inputs["input_ids"],pixel_values=inputs["pixel_values"],max_new_tokens=20,num_beams=1, do_sample=False)
else:
generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1) # temperature=0.01, do_sample=True,
t3 = time.time()
# print('time to generate:', t3-t2)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
generated_text = [gen.strip() for gen in generated_text]
generated_texts.extend(generated_text)
return generated_texts
def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
to_pil = ToPILImage()
if ocr_bbox:
non_ocr_boxes = filtered_boxes[len(ocr_bbox):]
else:
non_ocr_boxes = filtered_boxes
croped_pil_image = []
for i, coord in enumerate(non_ocr_boxes):
xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
cropped_image = image_source[ymin:ymax, xmin:xmax, :]
croped_pil_image.append(to_pil(cropped_image))
model, processor = caption_model_processor['model'], caption_model_processor['processor']
device = model.device
messages = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}]
prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
batch_size = 5 # Number of samples per batch
generated_texts = []
for i in range(0, len(croped_pil_image), batch_size):
images = croped_pil_image[i:i+batch_size]
image_inputs = [processor.image_processor(x, return_tensors="pt") for x in images]
inputs ={'input_ids': [], 'attention_mask': [], 'pixel_values': [], 'image_sizes': []}
texts = [prompt] * len(images)
for i, txt in enumerate(texts):
input = processor._convert_images_texts_to_inputs(image_inputs[i], txt, return_tensors="pt")
inputs['input_ids'].append(input['input_ids'])
inputs['attention_mask'].append(input['attention_mask'])
inputs['pixel_values'].append(input['pixel_values'])
inputs['image_sizes'].append(input['image_sizes'])
max_len = max([x.shape[1] for x in inputs['input_ids']])
for i, v in enumerate(inputs['input_ids']):
inputs['input_ids'][i] = torch.cat([processor.tokenizer.pad_token_id * torch.ones(1, max_len - v.shape[1], dtype=torch.long), v], dim=1)
inputs['attention_mask'][i] = torch.cat([torch.zeros(1, max_len - v.shape[1], dtype=torch.long), inputs['attention_mask'][i]], dim=1)
inputs_cat = {k: torch.concatenate(v).to(device) for k, v in inputs.items()}
generation_args = {
"max_new_tokens": 25,
"temperature": 0.01,
"do_sample": False,
}
generate_ids = model.generate(**inputs_cat, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
# # remove input tokens
generate_ids = generate_ids[:, inputs_cat['input_ids'].shape[1]:]
response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
response = [res.strip('\n').strip() for res in response]
generated_texts.extend(response)
return generated_texts
def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
assert ocr_bbox is None or isinstance(ocr_bbox, List)
def box_area(box):
return (box[2] - box[0]) * (box[3] - box[1])
def intersection_area(box1, box2):
x1 = max(box1[0], box2[0])
y1 = max(box1[1], box2[1])
x2 = min(box1[2], box2[2])
y2 = min(box1[3], box2[3])
return max(0, x2 - x1) * max(0, y2 - y1)
def IoU(box1, box2):
intersection = intersection_area(box1, box2)
union = box_area(box1) + box_area(box2) - intersection + 1e-6
if box_area(box1) > 0 and box_area(box2) > 0:
ratio1 = intersection / box_area(box1)
ratio2 = intersection / box_area(box2)
else:
ratio1, ratio2 = 0, 0
return max(intersection / union, ratio1, ratio2)
def is_inside(box1, box2):
# return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
intersection = intersection_area(box1, box2)
ratio1 = intersection / box_area(box1)
return ratio1 > 0.95
boxes = boxes.tolist()
filtered_boxes = []
if ocr_bbox:
filtered_boxes.extend(ocr_bbox)
# print('ocr_bbox!!!', ocr_bbox)
for i, box1 in enumerate(boxes):
# if not any(IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2) for j, box2 in enumerate(boxes) if i != j):
is_valid_box = True
for j, box2 in enumerate(boxes):
# keep the smaller box
if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
is_valid_box = False
break
if is_valid_box:
# add the following 2 lines to include ocr bbox
if ocr_bbox:
# only add the box if it does not overlap with any ocr bbox
if not any(IoU(box1, box3) > iou_threshold and not is_inside(box1, box3) for k, box3 in enumerate(ocr_bbox)):
filtered_boxes.append(box1)
else:
filtered_boxes.append(box1)
return torch.tensor(filtered_boxes)
def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
'''
ocr_bbox format: [{'type': 'text', 'bbox':[x,y], 'interactivity':False, 'content':str }, ...]
boxes format: [{'type': 'icon', 'bbox':[x,y], 'interactivity':True, 'content':None }, ...]
'''
assert ocr_bbox is None or isinstance(ocr_bbox, List)
def box_area(box):
return (box[2] - box[0]) * (box[3] - box[1])
def intersection_area(box1, box2):
x1 = max(box1[0], box2[0])
y1 = max(box1[1], box2[1])
x2 = min(box1[2], box2[2])
y2 = min(box1[3], box2[3])
return max(0, x2 - x1) * max(0, y2 - y1)
def IoU(box1, box2):
intersection = intersection_area(box1, box2)
union = box_area(box1) + box_area(box2) - intersection + 1e-6
if box_area(box1) > 0 and box_area(box2) > 0:
ratio1 = intersection / box_area(box1)
ratio2 = intersection / box_area(box2)
else:
ratio1, ratio2 = 0, 0
return max(intersection / union, ratio1, ratio2)
def is_inside(box1, box2):
# return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
intersection = intersection_area(box1, box2)
ratio1 = intersection / box_area(box1)
return ratio1 > 0.80
# boxes = boxes.tolist()
filtered_boxes = []
if ocr_bbox:
filtered_boxes.extend(ocr_bbox)
# print('ocr_bbox!!!', ocr_bbox)
for i, box1_elem in enumerate(boxes):
box1 = box1_elem['bbox']
is_valid_box = True
for j, box2_elem in enumerate(boxes):
# keep the smaller box
box2 = box2_elem['bbox']
if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
is_valid_box = False
break
if is_valid_box:
# add the following 2 lines to include ocr bbox
if ocr_bbox:
# keep yolo boxes + prioritize ocr label
box_added = False
ocr_labels = ''
for box3_elem in ocr_bbox:
if not box_added:
box3 = box3_elem['bbox']
if is_inside(box3, box1): # ocr inside icon
# box_added = True
# delete the box3_elem from ocr_bbox
try:
# filtered_boxes.append({'type': 'text', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': box3_elem['content'], 'source':'box_yolo_content_ocr'})
# gather all ocr labels
ocr_labels += box3_elem['content'] + ' '
filtered_boxes.remove(box3_elem)
# print('remove ocr bbox:', box3_elem)
except:
continue
# break
elif is_inside(box1, box3): # icon inside ocr, don't added this icon box, no need to check other ocr bbox bc no overlap between ocr bbox, icon can only be in one ocr box
box_added = True
# try:
# filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None})
# filtered_boxes.remove(box3_elem)
# except:
# continue
break
else:
continue
if not box_added:
if ocr_labels:
filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': ocr_labels, 'source':'box_yolo_content_ocr'})
else:
filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None, 'source':'box_yolo_content_yolo'})
else:
filtered_boxes.append(box1)
return filtered_boxes # torch.tensor(filtered_boxes)
def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
transform = T.Compose(
[
T.RandomResize([800], max_size=1333),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
image_source = Image.open(image_path).convert("RGB")
image = np.asarray(image_source)
image_transformed, _ = transform(image_source, None)
return image, image_transformed
def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str], text_scale: float,
text_padding=5, text_thickness=2, thickness=3) -> np.ndarray:
"""
This function annotates an image with bounding boxes and labels.
Parameters:
image_source (np.ndarray): The source image to be annotated.
boxes (torch.Tensor): A tensor containing bounding box coordinates. in cxcywh format, pixel scale
logits (torch.Tensor): A tensor containing confidence scores for each bounding box.
phrases (List[str]): A list of labels for each bounding box.
text_scale (float): The scale of the text to be displayed. 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
Returns:
np.ndarray: The annotated image.
"""
h, w, _ = image_source.shape
boxes = boxes * torch.Tensor([w, h, w, h])
xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").numpy()
detections = sv.Detections(xyxy=xyxy)
labels = [f"{phrase}" for phrase in range(boxes.shape[0])]
# from util.box_annotator import BoxAnnotator
box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,text_thickness=text_thickness,thickness=thickness) # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
annotated_frame = image_source.copy()
annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w,h))
label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
return annotated_frame, label_coordinates
def predict(model, image, caption, box_threshold, text_threshold):
""" Use huggingface model to replace the original model
"""
model, processor = model['model'], model['processor']
device = model.device
inputs = processor(images=image, text=caption, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**inputs)
results = processor.post_process_grounded_object_detection(
outputs,
inputs.input_ids,
box_threshold=box_threshold, # 0.4,
text_threshold=text_threshold, # 0.3,
target_sizes=[image.size[::-1]]
)[0]
boxes, logits, phrases = results["boxes"], results["scores"], results["labels"]
return boxes, logits, phrases
def predict_yolo(model, image_path, box_threshold, imgsz, scale_img, iou_threshold=0.7):
""" Use huggingface model to replace the original model
"""
# model = model['model']
if scale_img:
result = model.predict(
source=image_path,
conf=box_threshold,
imgsz=imgsz,
iou=iou_threshold, # default 0.7
)
else:
result = model.predict(
source=image_path,
conf=box_threshold,
iou=iou_threshold, # default 0.7
)
boxes = result[0].boxes.xyxy#.tolist() # in pixel space
conf = result[0].boxes.conf
phrases = [str(i) for i in range(len(boxes))]
return boxes, conf, phrases
def int_box_area(box, w, h):
x1, y1, x2, y2 = box
int_box = [int(x1*w), int(y1*h), int(x2*w), int(y2*h)]
area = (int_box[2] - int_box[0]) * (int_box[3] - int_box[1])
return area
def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD = 0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=64):
""" ocr_bbox: list of xyxy format bbox
"""
image_source = Image.open(img_path).convert("RGB")
w, h = image_source.size
if not imgsz:
imgsz = (h, w)
# print('image size:', w, h)
xyxy, logits, phrases = predict_yolo(model=model, image_path=img_path, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1)
xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
image_source = np.asarray(image_source)
phrases = [str(i) for i in range(len(phrases))]
# annotate the image with labels
if ocr_bbox:
ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
ocr_bbox=ocr_bbox.tolist()
else:
print('no ocr bbox!!!')
ocr_bbox = None
ocr_bbox_elem = [{'type': 'text', 'bbox':box, 'interactivity':False, 'content':txt, 'source': 'box_ocr_content_ocr'} for box, txt in zip(ocr_bbox, ocr_text) if int_box_area(box, w, h) > 0]
xyxy_elem = [{'type': 'icon', 'bbox':box, 'interactivity':True, 'content':None} for box in xyxy.tolist() if int_box_area(box, w, h) > 0]
filtered_boxes = remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)
# sort the filtered_boxes so that the one with 'content': None is at the end, and get the index of the first 'content': None
filtered_boxes_elem = sorted(filtered_boxes, key=lambda x: x['content'] is None)
# get the index of the first 'content': None
starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1)
filtered_boxes = torch.tensor([box['bbox'] for box in filtered_boxes_elem])
print('len(filtered_boxes):', len(filtered_boxes), starting_idx)
# get parsed icon local semantics
time1 = time.time()
if use_local_semantics:
caption_model = caption_model_processor['model']
if 'phi3_v' in caption_model.config.model_type:
parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor)
else:
parsed_content_icon = get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=prompt,batch_size=batch_size)
ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
icon_start = len(ocr_text)
parsed_content_icon_ls = []
# fill the filtered_boxes_elem None content with parsed_content_icon in order
for i, box in enumerate(filtered_boxes_elem):
if box['content'] is None:
box['content'] = parsed_content_icon.pop(0)
for i, txt in enumerate(parsed_content_icon):
parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
parsed_content_merged = ocr_text + parsed_content_icon_ls
else:
ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
parsed_content_merged = ocr_text
print('time to get parsed content:', time.time()-time1)
filtered_boxes = box_convert(boxes=filtered_boxes, in_fmt="xyxy", out_fmt="cxcywh")
phrases = [i for i in range(len(filtered_boxes))]
# draw boxes
if draw_bbox_config:
annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, **draw_bbox_config)
else:
annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, text_scale=text_scale, text_padding=text_padding)
pil_img = Image.fromarray(annotated_frame)
buffered = io.BytesIO()
pil_img.save(buffered, format="PNG")
encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
if output_coord_in_ratio:
# h, w, _ = image_source.shape
label_coordinates = {k: [v[0]/w, v[1]/h, v[2]/w, v[3]/h] for k, v in label_coordinates.items()}
assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]
return encoded_image, label_coordinates, filtered_boxes_elem
def get_xywh(input):
x, y, w, h = input[0][0], input[0][1], input[2][0] - input[0][0], input[2][1] - input[0][1]
x, y, w, h = int(x), int(y), int(w), int(h)
return x, y, w, h
def get_xyxy(input):
x, y, xp, yp = input[0][0], input[0][1], input[2][0], input[2][1]
x, y, xp, yp = int(x), int(y), int(xp), int(yp)
return x, y, xp, yp
def get_xywh_yolo(input):
x, y, w, h = input[0], input[1], input[2] - input[0], input[3] - input[1]
x, y, w, h = int(x), int(y), int(w), int(h)
return x, y, w, h
def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False):
if use_paddleocr:
if easyocr_args is None:
text_threshold = 0.5
else:
text_threshold = easyocr_args['text_threshold']
result = paddle_ocr.ocr(image_path, cls=False)[0]
# conf = [item[1] for item in result]
coord = [item[0] for item in result if item[1][1] > text_threshold]
text = [item[1][0] for item in result if item[1][1] > text_threshold]
else: # EasyOCR
if easyocr_args is None:
easyocr_args = {}
result = reader.readtext(image_path, **easyocr_args)
# print('goal filtering pred:', result[-5:])
coord = [item[0] for item in result]
text = [item[1] for item in result]
# read the image using cv2
if display_img:
opencv_img = cv2.imread(image_path)
opencv_img = cv2.cvtColor(opencv_img, cv2.COLOR_RGB2BGR)
bb = []
for item in coord:
x, y, a, b = get_xywh(item)
# print(x, y, a, b)
bb.append((x, y, a, b))
cv2.rectangle(opencv_img, (x, y), (x+a, y+b), (0, 255, 0), 2)
# Display the image
plt.imshow(opencv_img)
else:
if output_bb_format == 'xywh':
bb = [get_xywh(item) for item in coord]
elif output_bb_format == 'xyxy':
bb = [get_xyxy(item) for item in coord]
# print('bounding box!!!', bb)
return (text, bb), goal_filtering
from typing import List, Optional, Union, Tuple
import cv2
import numpy as np
from supervision.detection.core import Detections
from supervision.draw.color import Color, ColorPalette
class BoxAnnotator:
"""
A class for drawing bounding boxes on an image using detections provided.
Attributes:
color (Union[Color, ColorPalette]): The color to draw the bounding box,
can be a single color or a color palette
thickness (int): The thickness of the bounding box lines, default is 2
text_color (Color): The color of the text on the bounding box, default is white
text_scale (float): The scale of the text on the bounding box, default is 0.5
text_thickness (int): The thickness of the text on the bounding box,
default is 1
text_padding (int): The padding around the text on the bounding box,
default is 5
"""
def __init__(
self,
color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
thickness: int = 3, # 1 for seeclick 2 for mind2web and 3 for demo
text_color: Color = Color.BLACK,
text_scale: float = 0.5, # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
text_thickness: int = 2, #1, # 2 for demo
text_padding: int = 10,
avoid_overlap: bool = True,
):
self.color: Union[Color, ColorPalette] = color
self.thickness: int = thickness
self.text_color: Color = text_color
self.text_scale: float = text_scale
self.text_thickness: int = text_thickness
self.text_padding: int = text_padding
self.avoid_overlap: bool = avoid_overlap
def annotate(
self,
scene: np.ndarray,
detections: Detections,
labels: Optional[List[str]] = None,
skip_label: bool = False,
image_size: Optional[Tuple[int, int]] = None,
) -> np.ndarray:
"""
Draws bounding boxes on the frame using the detections provided.
Args:
scene (np.ndarray): The image on which the bounding boxes will be drawn
detections (Detections): The detections for which the
bounding boxes will be drawn
labels (Optional[List[str]]): An optional list of labels
corresponding to each detection. If `labels` are not provided,
corresponding `class_id` will be used as label.
skip_label (bool): Is set to `True`, skips bounding box label annotation.
Returns:
np.ndarray: The image with the bounding boxes drawn on it
Example:
```python
import supervision as sv
classes = ['person', ...]
image = ...
detections = sv.Detections(...)
box_annotator = sv.BoxAnnotator()
labels = [
f"{classes[class_id]} {confidence:0.2f}"
for _, _, confidence, class_id, _ in detections
]
annotated_frame = box_annotator.annotate(
scene=image.copy(),
detections=detections,
labels=labels
)
```
"""
font = cv2.FONT_HERSHEY_SIMPLEX
for i in range(len(detections)):
x1, y1, x2, y2 = detections.xyxy[i].astype(int)
class_id = (
detections.class_id[i] if detections.class_id is not None else None
)
idx = class_id if class_id is not None else i
color = (
self.color.by_idx(idx)
if isinstance(self.color, ColorPalette)
else self.color
)
cv2.rectangle(
img=scene,
pt1=(x1, y1),
pt2=(x2, y2),
color=color.as_bgr(),
thickness=self.thickness,
)
if skip_label:
continue
text = (
f"{class_id}"
if (labels is None or len(detections) != len(labels))
else labels[i]
)
text_width, text_height = cv2.getTextSize(
text=text,
fontFace=font,
fontScale=self.text_scale,
thickness=self.text_thickness,
)[0]
if not self.avoid_overlap:
text_x = x1 + self.text_padding
text_y = y1 - self.text_padding
text_background_x1 = x1
text_background_y1 = y1 - 2 * self.text_padding - text_height
text_background_x2 = x1 + 2 * self.text_padding + text_width
text_background_y2 = y1
# text_x = x1 - self.text_padding - text_width
# text_y = y1 + self.text_padding + text_height
# text_background_x1 = x1 - 2 * self.text_padding - text_width
# text_background_y1 = y1
# text_background_x2 = x1
# text_background_y2 = y1 + 2 * self.text_padding + text_height
else:
text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 = get_optimal_label_pos(self.text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size)
cv2.rectangle(
img=scene,
pt1=(text_background_x1, text_background_y1),
pt2=(text_background_x2, text_background_y2),
color=color.as_bgr(),
thickness=cv2.FILLED,
)
box_color = color.as_rgb()
luminance = 0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2]
text_color = (0,0,0) if luminance > 160 else (255,255,255)
cv2.putText(
img=scene,
text=text,
org=(text_x, text_y),
fontFace=font,
fontScale=self.text_scale,
# color=self.text_color.as_rgb(),
color=text_color,
thickness=self.text_thickness,
lineType=cv2.LINE_AA,
)
return scene
def box_area(box):
return (box[2] - box[0]) * (box[3] - box[1])
def intersection_area(box1, box2):
x1 = max(box1[0], box2[0])
y1 = max(box1[1], box2[1])
x2 = min(box1[2], box2[2])
y2 = min(box1[3], box2[3])
return max(0, x2 - x1) * max(0, y2 - y1)
def IoU(box1, box2, return_max=True):
intersection = intersection_area(box1, box2)
union = box_area(box1) + box_area(box2) - intersection
if box_area(box1) > 0 and box_area(box2) > 0:
ratio1 = intersection / box_area(box1)
ratio2 = intersection / box_area(box2)
else:
ratio1, ratio2 = 0, 0
if return_max:
return max(intersection / union, ratio1, ratio2)
else:
return intersection / union
def get_optimal_label_pos(text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size):
""" check overlap of text and background detection box, and get_optimal_label_pos,
pos: str, position of the text, must be one of 'top left', 'top right', 'outer left', 'outer right' TODO: if all are overlapping, return the last one, i.e. outer right
Threshold: default to 0.3
"""
def get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size):
is_overlap = False
for i in range(len(detections)):
detection = detections.xyxy[i].astype(int)
if IoU([text_background_x1, text_background_y1, text_background_x2, text_background_y2], detection) > 0.3:
is_overlap = True
break
# check if the text is out of the image
if text_background_x1 < 0 or text_background_x2 > image_size[0] or text_background_y1 < 0 or text_background_y2 > image_size[1]:
is_overlap = True
return is_overlap
# if pos == 'top left':
text_x = x1 + text_padding
text_y = y1 - text_padding
text_background_x1 = x1
text_background_y1 = y1 - 2 * text_padding - text_height
text_background_x2 = x1 + 2 * text_padding + text_width
text_background_y2 = y1
is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
if not is_overlap:
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
# elif pos == 'outer left':
text_x = x1 - text_padding - text_width
text_y = y1 + text_padding + text_height
text_background_x1 = x1 - 2 * text_padding - text_width
text_background_y1 = y1
text_background_x2 = x1
text_background_y2 = y1 + 2 * text_padding + text_height
is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
if not is_overlap:
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
# elif pos == 'outer right':
text_x = x2 + text_padding
text_y = y1 + text_padding + text_height
text_background_x1 = x2
text_background_y1 = y1
text_background_x2 = x2 + 2 * text_padding + text_width
text_background_y2 = y1 + 2 * text_padding + text_height
is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
if not is_overlap:
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
# elif pos == 'top right':
text_x = x2 - text_padding - text_width
text_y = y1 - text_padding
text_background_x1 = x2 - 2 * text_padding - text_width
text_background_y1 = y1 - 2 * text_padding - text_height
text_background_x2 = x2
text_background_y2 = y1
is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
if not is_overlap:
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
import re
def extract_dict_from_text(text):
# Define the regex pattern for a dictionary-like structure
pattern = r"\{\s*'(?P<key1>.*?)':\s*'(?P<value1>.*?)',\s*'(?P<key2>.*?)':\s*'(?P<value2>.*?)'\s*\}"
# Search for the dictionary in the text
match = re.search(pattern, text)
if match:
# Extract matched groups into a dictionary
return {
match.group('key1'): match.group('value1'),
match.group('key2'): match.group('value2'),
}
else:
raise ValueError("No valid dictionary structure found in the text.")
def get_phi3v_model_dict():
from PIL import Image
import requests
from transformers import AutoModelForCausalLM
from transformers import AutoProcessor
model_id = "microsoft/Phi-3.5-vision-instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", trust_remote_code=True, torch_dtype="auto", _attn_implementation='flash_attention_2')
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
print('phi3v model loaded!!!')
return {'model': model, 'processor': processor}
def call_phi3v(messages, image_base64, model_dict):
model, processor = model_dict['model'], model_dict['processor']
device = model.device
prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
if isinstance(image_base64, tuple):
image_base64, dino_labled_img = image_base64
image = Image.open(io.BytesIO(base64.b64decode(image_base64)))
dino_labled_img = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
inputs = processor(prompt, [image, dino_labled_img], return_tensors="pt").to(device)
else:
image = Image.open(io.BytesIO(base64.b64decode(image_base64)))
inputs = processor(prompt, [image], return_tensors="pt").to(device)
generation_args = {
"max_new_tokens": 512,
"temperature": 0.01,
"do_sample": False,
}
generate_ids = model.generate(**inputs, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
# remove input tokens
generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
ans = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
return ans
def get_pred_phi3v(message_text, image_base64, label_coordinates, id_key='Click ID', model_dict=None):
# messages = [
# {"role": "system", "content": '''You are an expert at completing instructions on GUI screens.
# You will be presented with two images. The first is the original screenshot. The second is the same screenshot with some numeric tags. You will also be provided with some descriptions of the bbox, and your task is to choose the numeric bbox idx you want to click in order to complete the user instruction.'''},
# ]
messages = [
{"role": "system", "content": '''You are an expert at completing instructions on GUI screens. You will also be provided with some descriptions of the bbox, and your task is to choose the numeric bbox idx you want to click in order to complete the user instruction.'''},
]
messages = []
if isinstance(image_base64, tuple):
messages.append({"role": "user", "content": '<|image_1|>\n' + '<|image_2|>\n' + message_text})
else:
messages.append({"role": "user", "content": '<|image_1|>\n' + message_text})
response_text = call_phi3v(messages, image_base64, model_dict)
print(response_text)
try:
response_text = ast.literal_eval(response_text)
icon_id = response_text['Click BBox ID']
bbox = label_coordinates[str(icon_id)]
click_point = [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2]
except:
print('error parsing, use regex to parse!!!')
import pdb; pdb.set_trace()
response_text = extract_dict_from_text(response_text)
icon_id = response_text['Click BBox ID']
bbox = label_coordinates[str(icon_id)]
click_point = [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2]
return icon_id, bbox, click_point, response_text
# try:
# match = re.search(r"```(.*?)```", ans, re.DOTALL)
# if match:
# result = match.group(1).strip()
# pred = result.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '')
# pred = ast.literal_eval(pred)
# else:
# pred = ans.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '')
# pred = ast.literal_eval(pred)
# if pred[id_key]:
# icon_id = pred[id_key]
# bbox = label_coordinates[str(icon_id)]
# pred['click_point'] = [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2]
# except:
# print('phi3v action regex extract fail!!!')
# pred = {'action_type': 'CLICK', 'click_point': [0, 0], 'value': 'None', 'is_completed': False}
# step_pred_summary = None
# return pred, [True, ans, None, step_pred_summary]