Files
OmniParser/demo/gradio/computer_use_demo/omniparser_agent/utils.py

935 lines
39 KiB
Python
Executable File

# from ultralytics import YOLO
import os
import io
import base64
import time
from PIL import Image, ImageDraw, ImageFont
import json
import requests
# utility function
import os
from openai import AzureOpenAI
import json
import sys
import os
import cv2
import numpy as np
# %matplotlib inline
from matplotlib import pyplot as plt
import easyocr
from paddleocr import PaddleOCR
# Module-level OCR engines. Both are constructed once, at import time, and are
# used by check_ocr_box(): `reader` for the default EasyOCR path and
# `paddle_ocr` when the caller passes use_paddleocr=True.
# EasyOCR reader for English and Simplified Chinese, requested with gpu=True.
reader = easyocr.Reader(['en', 'ch_sim'], gpu=True)
# PaddleOCR engine configured for English text with angle classification off.
paddle_ocr = PaddleOCR(
lang='en', # other lang also available
use_angle_cls=False,
use_gpu=False, # using cuda will conflict with pytorch in the same process
show_log=False,
max_batch_size=1024,
use_dilation=True, # improves accuracy
det_db_score_mode='slow', # improves accuracy
rec_batch_num=1024)
import time
import base64
import os
import ast
import torch
from typing import Tuple, List
from torchvision.ops import box_convert
import re
from torchvision.transforms import ToPILImage
import supervision as sv
import torchvision.transforms as T
def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
    """Load a captioning model together with its processor.

    Args:
        model_name: Which model family to load, ``"blip2"`` or ``"florence2"``.
        model_name_or_path: Checkpoint name or local path for the model weights.
            The processor is always loaded from the family's fixed base
            checkpoint ("Salesforce/blip2-opt-2.7b" or
            "microsoft/Florence-2-base"), not from this path.
        device: Target device. When falsy, "cuda" is used if available,
            otherwise "cpu".

    Returns:
        dict with keys ``'model'`` (moved to ``device``) and ``'processor'``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported families.
    """
    # Validate first so an unsupported name fails with a clear message instead
    # of an UnboundLocalError on `model` at the return statement.
    if model_name not in ("blip2", "florence2"):
        raise ValueError(
            f"Unsupported caption model {model_name!r}; expected 'blip2' or 'florence2'."
        )
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Full precision on CPU, half precision otherwise.
    dtype = torch.float32 if device == 'cpu' else torch.float16

    if model_name == "blip2":
        from transformers import Blip2Processor, Blip2ForConditionalGeneration
        processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
        model = Blip2ForConditionalGeneration.from_pretrained(
            model_name_or_path, device_map=None, torch_dtype=dtype
        )
    else:  # "florence2"
        from transformers import AutoProcessor, AutoModelForCausalLM
        processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, torch_dtype=dtype, trust_remote_code=True
        )
    return {'model': model.to(device), 'processor': processor}
def get_yolo_model(model_path):
    """Build and return an Ultralytics ``YOLO`` model from ``model_path``."""
    # Function-scope import: ultralytics is only imported when this is called.
    from ultralytics import YOLO
    return YOLO(model_path)
@torch.inference_mode()
def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=None):
    """Caption the boxes that have no content yet with the caption model.

    Args:
        filtered_boxes: Sequence of boxes as ratio coordinates
            ``[x1, y1, x2, y2]`` relative to the image size.
        starting_idx: Index of the first box to caption; when falsy every box
            is captioned.
        image_source: Image array indexed as ``[row, col, channel]``.
        caption_model_processor: dict with ``'model'`` and ``'processor'``.
        prompt: Optional text prompt. Defaults to ``"<CAPTION>"`` for Florence
            checkpoints and ``"The image shows"`` otherwise.
        batch_size: Crops per forward pass; falls back to 64 when not given.
            (Original author's note: 256 takes roughly 23 GB of GPU memory
            with the Florence model.)

    Returns:
        list[str]: One caption per considered box, in box order. A box whose
        crop could not be produced gets ``''`` so the captions stay aligned
        with the boxes.
    """
    if not batch_size:
        # The declared default is None, which would crash range() below.
        batch_size = 64
    to_pil = ToPILImage()
    non_ocr_boxes = filtered_boxes[starting_idx:] if starting_idx else filtered_boxes

    crops = []
    crop_slots = []  # position in non_ocr_boxes of each successfully built crop
    for slot, coord in enumerate(non_ocr_boxes):
        try:
            xmin, xmax = int(coord[0] * image_source.shape[1]), int(coord[2] * image_source.shape[1])
            ymin, ymax = int(coord[1] * image_source.shape[0]), int(coord[3] * image_source.shape[0])
            cropped_image = image_source[ymin:ymax, xmin:xmax, :]
            # Crops are resized to a fixed 64x64 before being handed to the processor.
            cropped_image = cv2.resize(cropped_image, (64, 64))
            crops.append(to_pil(cropped_image))
            crop_slots.append(slot)
        except Exception:
            # Best effort: a degenerate crop (e.g. empty slice) is skipped, but
            # its slot is remembered as missing so later captions do not shift.
            continue

    model, processor = caption_model_processor['model'], caption_model_processor['processor']
    is_florence = 'florence' in model.config.name_or_path
    if not prompt:
        prompt = "<CAPTION>" if is_florence else "The image shows"

    device = model.device
    generated_texts = []
    for start in range(0, len(crops), batch_size):
        batch = crops[start:start + batch_size]
        if model.device.type == 'cuda':
            inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16)
        else:
            inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
        if is_florence:
            generated_ids = model.generate(input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=20, num_beams=1, do_sample=False)
        else:
            generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1)
        decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)
        generated_texts.extend(text.strip() for text in decoded)

    # Re-expand to one entry per box so callers can pair captions with boxes by position.
    captions = [''] * len(non_ocr_boxes)
    for slot, text in zip(crop_slots, generated_texts):
        captions[slot] = text
    return captions
def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
    """Describe icon crops in one sentence each using a Phi-3-vision style model.

    Boxes are ratio coordinates ``[x1, y1, x2, y2]``. When ``ocr_bbox`` is
    truthy the first ``len(ocr_bbox)`` entries of ``filtered_boxes`` are
    skipped. Crops are processed five at a time with left padding, and the
    decoded, stripped responses are returned in crop order.
    """
    as_pil = ToPILImage()
    icon_boxes = filtered_boxes[len(ocr_bbox):] if ocr_bbox else filtered_boxes

    icon_crops = []
    for box in icon_boxes:
        left, right = int(box[0]*image_source.shape[1]), int(box[2]*image_source.shape[1])
        top, bottom = int(box[1]*image_source.shape[0]), int(box[3]*image_source.shape[0])
        icon_crops.append(as_pil(image_source[top:bottom, left:right, :]))

    model = caption_model_processor['model']
    processor = caption_model_processor['processor']
    device = model.device
    chat = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}]
    prompt = processor.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

    chunk_size = 5  # Number of samples per batch
    generation_args = {
        "max_new_tokens": 25,
        "temperature": 0.01,
        "do_sample": False,
    }
    field_names = ('input_ids', 'attention_mask', 'pixel_values', 'image_sizes')
    captions = []
    for offset in range(0, len(icon_crops), chunk_size):
        chunk = icon_crops[offset:offset + chunk_size]
        pixel_inputs = [processor.image_processor(crop, return_tensors="pt") for crop in chunk]

        # Encode every (image, prompt) pair on its own, collecting each field per sample.
        collected = {name: [] for name in field_names}
        for single_image in pixel_inputs:
            encoded = processor._convert_images_texts_to_inputs(single_image, prompt, return_tensors="pt")
            for name in field_names:
                collected[name].append(encoded[name])

        # Left-pad ids and masks so every sample in the chunk has the same length.
        longest = max([ids.shape[1] for ids in collected['input_ids']])
        for row, ids in enumerate(collected['input_ids']):
            pad_len = longest - ids.shape[1]
            id_padding = processor.tokenizer.pad_token_id * torch.ones(1, pad_len, dtype=torch.long)
            mask_padding = torch.zeros(1, pad_len, dtype=torch.long)
            collected['input_ids'][row] = torch.cat([id_padding, ids], dim=1)
            collected['attention_mask'][row] = torch.cat([mask_padding, collected['attention_mask'][row]], dim=1)

        inputs_cat = {name: torch.concatenate(parts).to(device) for name, parts in collected.items()}
        generate_ids = model.generate(**inputs_cat, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
        # Keep only the newly generated tokens (drop the prompt part).
        generate_ids = generate_ids[:, inputs_cat['input_ids'].shape[1]:]
        decoded = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        captions.extend([answer.strip('\n').strip() for answer in decoded])
    return captions
def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
    """Filter detector boxes that overlap each other or the OCR boxes.

    ``boxes`` is a tensor of ``[x1, y1, x2, y2]`` rows and ``ocr_bbox`` an
    optional list of boxes in the same layout. OCR boxes are always kept and
    come first in the result. A detector box is dropped when a smaller
    detector box overlaps it above ``iou_threshold``, or when it overlaps an
    OCR box above the threshold without being (>95%) contained in it.
    Returns the surviving boxes as a tensor.
    """
    assert ocr_bbox is None or isinstance(ocr_bbox, List)

    def area(b):
        return (b[2] - b[0]) * (b[3] - b[1])

    def overlap_area(a, b):
        left, top = max(a[0], b[0]), max(a[1], b[1])
        right, bottom = min(a[2], b[2]), min(a[3], b[3])
        return max(0, right - left) * max(0, bottom - top)

    def overlap_score(a, b):
        # Max of plain IoU and the share of either box covered by the overlap.
        shared = overlap_area(a, b)
        union = area(a) + area(b) - shared + 1e-6
        if area(a) > 0 and area(b) > 0:
            share_a, share_b = shared / area(a), shared / area(b)
        else:
            share_a, share_b = 0, 0
        return max(shared / union, share_a, share_b)

    def mostly_inside(inner, outer):
        return overlap_area(inner, outer) / area(inner) > 0.95

    candidates = boxes.tolist()
    kept = []
    if ocr_bbox:
        kept.extend(ocr_bbox)

    for i, box in enumerate(candidates):
        # Of two heavily overlapping detector boxes, the smaller one wins.
        has_smaller_rival = any(
            i != j and overlap_score(box, other) > iou_threshold and area(box) > area(other)
            for j, other in enumerate(candidates)
        )
        if has_smaller_rival:
            continue
        if ocr_bbox:
            conflicts_with_text = any(
                overlap_score(box, text_box) > iou_threshold and not mostly_inside(box, text_box)
                for text_box in ocr_bbox
            )
            if conflicts_with_text:
                continue
        kept.append(box)
    return torch.tensor(kept)
def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
    '''
    Merge detector (icon) elements with OCR (text) elements, removing overlaps.

    ocr_bbox format: [{'type': 'text', 'bbox':[x1,y1,x2,y2], 'interactivity':False, 'content':str }, ...]
    boxes format: [{'type': 'icon', 'bbox':[x1,y1,x2,y2], 'interactivity':True, 'content':None }, ...]

    Rules:
      * An icon is dropped when a smaller icon overlaps it above ``iou_threshold``.
      * An OCR element lying (>80%) inside a kept icon is removed from the
        output and its text becomes part of that icon's content.
      * An icon lying (>80%) inside an OCR element is dropped.

    Returns a list of element dicts: the remaining OCR elements first, then the
    kept icons. Every entry is a dict, also when ``ocr_bbox`` is None or empty.
    '''
    assert ocr_bbox is None or isinstance(ocr_bbox, List)

    def box_area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    def intersection_area(box1, box2):
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[2], box2[2])
        y2 = min(box1[3], box2[3])
        return max(0, x2 - x1) * max(0, y2 - y1)

    def IoU(box1, box2):
        intersection = intersection_area(box1, box2)
        union = box_area(box1) + box_area(box2) - intersection + 1e-6
        if box_area(box1) > 0 and box_area(box2) > 0:
            ratio1 = intersection / box_area(box1)
            ratio2 = intersection / box_area(box2)
        else:
            ratio1, ratio2 = 0, 0
        return max(intersection / union, ratio1, ratio2)

    def is_inside(box1, box2):
        # True when more than 80% of box1 is covered by box2.
        area1 = box_area(box1)
        if area1 <= 0:
            # A degenerate box cannot be "inside" anything; avoids dividing by zero.
            return False
        return intersection_area(box1, box2) / area1 > 0.80

    text_elems = ocr_bbox or []
    filtered_boxes = list(text_elems)

    for i, box1_elem in enumerate(boxes):
        box1 = box1_elem['bbox']
        # keep the smaller box of any heavily overlapping icon pair
        has_smaller_rival = any(
            i != j
            and IoU(box1, box2_elem['bbox']) > iou_threshold
            and box_area(box1) > box_area(box2_elem['bbox'])
            for j, box2_elem in enumerate(boxes)
        )
        if has_smaller_rival:
            continue

        icon_inside_text = False
        ocr_labels = ''
        for box3_elem in text_elems:
            box3 = box3_elem['bbox']
            if is_inside(box3, box1):  # ocr inside icon: icon absorbs the text
                try:
                    ocr_labels += box3_elem['content'] + ' '
                    filtered_boxes.remove(box3_elem)
                except (KeyError, TypeError, ValueError):
                    # ValueError: this OCR element was already absorbed by an
                    # earlier icon; KeyError/TypeError: content missing or not a str.
                    continue
            elif is_inside(box1, box3):  # icon inside ocr: the text element wins
                icon_inside_text = True
                break
        if icon_inside_text:
            continue

        # Always emit an element dict so the return format is uniform; the
        # no-OCR case goes through the same path with empty labels.
        filtered_boxes.append({
            'type': 'icon',
            'bbox': box1_elem['bbox'],
            'interactivity': True,
            'content': ocr_labels if ocr_labels else None,
            'source': 'box_yolo_content_ocr' if ocr_labels else 'box_yolo_content_yolo',
        })
    return filtered_boxes
def load_image(image_path: str) -> Tuple[np.ndarray, torch.Tensor]:
    """Load an image and return it both raw and as a normalised tensor.

    Args:
        image_path: Path of the image file to open.

    Returns:
        ``(image, image_transformed)`` where ``image`` is the RGB image as a
        NumPy array and ``image_transformed`` is the resized, tensor-converted
        and mean/std-normalised version.
    """
    # `T` is torchvision.transforms, which has no RandomResize and whose
    # Compose takes a single argument. Resize(800, max_size=1333) scales the
    # shorter edge to 800 while capping the longer edge at 1333.
    transform = T.Compose(
        [
            T.Resize(800, max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    # The context manager closes the underlying file once the pixels are converted.
    with Image.open(image_path) as opened:
        image_source = opened.convert("RGB")
    image = np.asarray(image_source)
    image_transformed = transform(image_source)
    return image, image_transformed
def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str], text_scale: float,
             text_padding=5, text_thickness=2, thickness=3) -> Tuple[np.ndarray, dict]:
    """
    Draw numbered bounding boxes on a copy of the image.

    Parameters:
    image_source (np.ndarray): The source image to be annotated (height, width, channels).
    boxes (torch.Tensor): Bounding boxes in cxcywh format as ratios of the image size;
        they are multiplied by [w, h, w, h] here to obtain pixel coordinates.
    logits (torch.Tensor): Confidence scores for each bounding box. Currently unused.
    phrases (List[str]): One key per box, used for the returned coordinate mapping.
        The labels drawn on the image are the box indices, not these phrases.
    text_scale (float): The scale of the text to be displayed. 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
    text_padding, text_thickness, thickness: Forwarded to BoxAnnotator.

    Returns:
    Tuple[np.ndarray, dict]: The annotated image and a mapping from ``str(phrase)``
        to that box's pixel-space ``xywh`` coordinates.
    """
    h, w, _ = image_source.shape
    boxes = boxes * torch.Tensor([w, h, w, h])
    xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
    xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").numpy()
    detections = sv.Detections(xyxy=xyxy)

    labels = [str(index) for index in range(boxes.shape[0])]

    box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding, text_thickness=text_thickness, thickness=thickness)  # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
    annotated_frame = box_annotator.annotate(scene=image_source.copy(), detections=detections, labels=labels, image_size=(w, h))

    label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
    return annotated_frame, label_coordinates
def predict(model, image, caption, box_threshold, text_threshold):
    """Run a Hugging Face grounded-detection model on one image.

    ``model`` is a dict with ``'model'`` and ``'processor'``. Returns the
    ``boxes``, ``scores`` and ``labels`` of the first post-processed result.
    """
    detector = model['model']
    processor = model['processor']

    inputs = processor(images=image, text=caption, return_tensors="pt").to(detector.device)
    with torch.no_grad():
        outputs = detector(**inputs)

    post_processed = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,  # 0.4,
        text_threshold=text_threshold,  # 0.3,
        target_sizes=[image.size[::-1]],
    )
    first = post_processed[0]
    return first["boxes"], first["scores"], first["labels"]
def predict_yolo(model, image_path, box_threshold, imgsz, scale_img, iou_threshold=0.7):
    """Run the YOLO detector on one image.

    ``imgsz`` is forwarded to ``model.predict`` only when ``scale_img`` is
    truthy. Returns ``(boxes, conf, phrases)``: the first result's xyxy boxes
    (pixel space), their confidences, and the box indices as strings.
    """
    predict_kwargs = {
        'source': image_path,
        'conf': box_threshold,
        'iou': iou_threshold,  # default 0.7
    }
    if scale_img:
        predict_kwargs['imgsz'] = imgsz

    detected = model.predict(**predict_kwargs)[0].boxes
    boxes = detected.xyxy  # in pixel space
    conf = detected.conf
    phrases = list(map(str, range(len(boxes))))
    return boxes, conf, phrases
def int_box_area(box, w, h):
    """Return the pixel area of a ratio box after truncating its corners to ints.

    ``box`` is ``(x1, y1, x2, y2)`` as fractions of an image of width ``w`` and
    height ``h``.
    """
    x1, y1, x2, y2 = box
    left, top = int(x1 * w), int(y1 * h)
    right, bottom = int(x2 * w), int(y2 * h)
    return (right - left) * (bottom - top)
def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD = 0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=None, use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=64):
    """Detect, merge, caption and draw the UI elements of a screenshot.

    Args:
        img_path: Path of the screenshot.
        model: YOLO detector passed to predict_yolo.
        BOX_TRESHOLD: Detector confidence threshold.
        output_coord_in_ratio: When True the returned coordinates are divided
            by the image width/height.
        ocr_bbox: list of xyxy format bbox in pixels (may be None or empty).
        text_scale, text_padding: Label drawing options used when
            ``draw_bbox_config`` is not given.
        draw_bbox_config: Optional dict of keyword arguments for annotate().
        caption_model_processor: dict with ``'model'`` and ``'processor'``,
            needed when ``use_local_semantics`` is True.
        ocr_text: OCR strings paired with ``ocr_bbox`` by position.
        use_local_semantics: Whether to caption boxes that have no content.
        iou_threshold: Overlap threshold handed to remove_overlap_new. (The
            detector's own NMS threshold is fixed at 0.1 below.)
        prompt, batch_size: Forwarded to get_parsed_content_icon.
        scale_img, imgsz: Forwarded to predict_yolo; ``imgsz`` defaults to
            the image's ``(h, w)``.

    Returns:
        ``(encoded_image, label_coordinates, filtered_boxes_elem)``: the
        annotated image as base64-encoded PNG text, a dict mapping box index
        (as str) to xywh coordinates, and the list of element dicts.
    """
    if ocr_text is None:
        ocr_text = []
    image_source = Image.open(img_path).convert("RGB")
    w, h = image_source.size
    if not imgsz:
        imgsz = (h, w)
    xyxy, logits, phrases = predict_yolo(model=model, image_path=img_path, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1)
    xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
    image_source = np.asarray(image_source)

    if ocr_bbox:
        ocr_bbox = (torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])).tolist()
    else:
        print('no ocr bbox!!!')
        # An empty list (not None) keeps the zip() below working when OCR found nothing.
        ocr_bbox = []

    ocr_bbox_elem = [{'type': 'text', 'bbox':box, 'interactivity':False, 'content':txt, 'source': 'box_ocr_content_ocr'} for box, txt in zip(ocr_bbox, ocr_text) if int_box_area(box, w, h) > 0]
    xyxy_elem = [{'type': 'icon', 'bbox':box, 'interactivity':True, 'content':None} for box in xyxy.tolist() if int_box_area(box, w, h) > 0]
    filtered_boxes = remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)
    # Defensive normalisation: wrap any bare bbox list in an icon element dict
    # so the content lookups below always see dicts.
    filtered_boxes = [
        box if isinstance(box, dict)
        else {'type': 'icon', 'bbox': box, 'interactivity': True, 'content': None, 'source': 'box_yolo_content_yolo'}
        for box in filtered_boxes
    ]

    # Elements that already have content come first; the rest still need a caption.
    filtered_boxes_elem = sorted(filtered_boxes, key=lambda x: x['content'] is None)
    # Index of the first element without content, or -1 when there is none.
    starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1)
    # reshape keeps a (0, 4) tensor when nothing was found, so box_convert still works.
    filtered_boxes = torch.tensor([box['bbox'] for box in filtered_boxes_elem]).reshape(-1, 4)
    print('len(filtered_boxes):', len(filtered_boxes), starting_idx)

    # get parsed icon local semantics
    time1 = time.time()
    if use_local_semantics and starting_idx >= 0:
        caption_model = caption_model_processor['model']
        if 'phi3_v' in caption_model.config.model_type:
            # The phi3v helper skips len(<second argument>) leading boxes, so
            # hand it exactly the elements that already have content.
            parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes, filtered_boxes_elem[:starting_idx], image_source, caption_model_processor)
        else:
            parsed_content_icon = get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=prompt, batch_size=batch_size)
        # Fill the content-less elements with the captions, in order.
        captions = iter(parsed_content_icon)
        for box in filtered_boxes_elem:
            if box['content'] is None:
                box['content'] = next(captions, None)
    print('time to get parsed content:', time.time()-time1)

    filtered_boxes = box_convert(boxes=filtered_boxes, in_fmt="xyxy", out_fmt="cxcywh")
    phrases = list(range(len(filtered_boxes)))

    # draw boxes
    if draw_bbox_config:
        annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, **draw_bbox_config)
    else:
        annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, text_scale=text_scale, text_padding=text_padding)

    pil_img = Image.fromarray(annotated_frame)
    buffered = io.BytesIO()
    pil_img.save(buffered, format="PNG")
    encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
    if output_coord_in_ratio:
        label_coordinates = {k: [v[0]/w, v[1]/h, v[2]/w, v[3]/h] for k, v in label_coordinates.items()}
        assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]

    return encoded_image, label_coordinates, filtered_boxes_elem
def get_xywh(input):
    """Convert a corner-point polygon to an integer ``(x, y, w, h)`` tuple.

    Uses point 0 as the top-left corner and point 2 as the opposite corner.
    """
    top_left, bottom_right = input[0], input[2]
    x, y = top_left[0], top_left[1]
    w, h = bottom_right[0] - x, bottom_right[1] - y
    return int(x), int(y), int(w), int(h)
def get_xyxy(input):
    """Convert a corner-point polygon to an integer ``(x1, y1, x2, y2)`` tuple.

    Uses point 0 as the first corner and point 2 as the opposite corner.
    """
    first_corner, opposite_corner = input[0], input[2]
    return (
        int(first_corner[0]),
        int(first_corner[1]),
        int(opposite_corner[0]),
        int(opposite_corner[1]),
    )
def get_xywh_yolo(input):
    """Convert a flat ``[x1, y1, x2, y2]`` box to an integer ``(x, y, w, h)`` tuple."""
    left, top = input[0], input[1]
    width, height = input[2] - left, input[3] - top
    return int(left), int(top), int(width), int(height)
def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False):
    """Run OCR on an image and return the recognised text with bounding boxes.

    Args:
        image_path: Path of the image to read.
        display_img: When True, the detected boxes are also drawn on the image
            and shown with matplotlib.
        output_bb_format: ``'xywh'`` or ``'xyxy'``; format of the returned boxes.
        goal_filtering: Returned unchanged as the second element of the result.
        easyocr_args: Extra keyword arguments for EasyOCR's ``readtext``. On the
            PaddleOCR path only its ``'text_threshold'`` entry is used
            (default 0.5) as the minimum recognition score.
        use_paddleocr: Use the module-level PaddleOCR engine instead of EasyOCR.

    Returns:
        ``((text, bb), goal_filtering)`` where ``text`` is the list of strings
        and ``bb`` the list of integer box tuples in ``output_bb_format``.

    Raises:
        ValueError: If ``output_bb_format`` is not ``'xywh'`` or ``'xyxy'``.
    """
    if output_bb_format not in ('xywh', 'xyxy'):
        # Fail early and clearly instead of hitting an unbound `bb` later on.
        raise ValueError(f"output_bb_format must be 'xywh' or 'xyxy', got {output_bb_format!r}")

    if use_paddleocr:
        text_threshold = (easyocr_args or {}).get('text_threshold', 0.5)
        # PaddleOCR can yield None for a page without any detected text.
        result = paddle_ocr.ocr(image_path, cls=False)[0] or []
        confident = [item for item in result if item[1][1] > text_threshold]
        coord = [item[0] for item in confident]
        text = [item[1][0] for item in confident]
    else:  # EasyOCR
        result = reader.readtext(image_path, **(easyocr_args or {}))
        coord = [item[0] for item in result]
        text = [item[1] for item in result]

    if display_img:
        opencv_img = cv2.imread(image_path)
        # cv2.imread yields BGR; swap channels so matplotlib shows true colours.
        opencv_img = cv2.cvtColor(opencv_img, cv2.COLOR_BGR2RGB)
        for item in coord:
            x, y, a, b = get_xywh(item)
            cv2.rectangle(opencv_img, (x, y), (x+a, y+b), (0, 255, 0), 2)
        # Display the image
        plt.imshow(opencv_img)

    # The requested format is honoured whether or not the image was displayed.
    to_box = get_xywh if output_bb_format == 'xywh' else get_xyxy
    bb = [to_box(item) for item in coord]
    return (text, bb), goal_filtering
from typing import List, Optional, Union, Tuple
import cv2
import numpy as np
from supervision.detection.core import Detections
from supervision.draw.color import Color, ColorPalette
class BoxAnnotator:
    """
    Draws detection bounding boxes and their text labels onto an image.

    Attributes:
        color (Union[Color, ColorPalette]): Box colour, either one colour or a
            palette indexed per detection (default ``ColorPalette.DEFAULT``).
        thickness (int): Line thickness of the boxes (default 3).
        text_color (Color): Stored label colour (default ``Color.BLACK``).
            ``annotate`` does not use it; it picks black or white from the
            luminance of the box colour.
        text_scale (float): Font scale of the labels (default 0.5).
        text_thickness (int): Stroke thickness of the label text (default 2).
        text_padding (int): Padding around the label text (default 10).
        avoid_overlap (bool): When True the label position is chosen by
            ``get_optimal_label_pos``; otherwise it sits above the box's
            top-left corner (default True).
    """

    def __init__(
        self,
        color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
        thickness: int = 3,  # 1 for seeclick 2 for mind2web and 3 for demo
        text_color: Color = Color.BLACK,
        text_scale: float = 0.5,  # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
        text_thickness: int = 2,  # 1, # 2 for demo
        text_padding: int = 10,
        avoid_overlap: bool = True,
    ):
        self.color = color
        self.thickness = thickness
        self.text_color = text_color
        self.text_scale = text_scale
        self.text_thickness = text_thickness
        self.text_padding = text_padding
        self.avoid_overlap = avoid_overlap

    def _color_for(self, idx):
        """Return the colour for detection ``idx`` (palette lookup or the fixed colour)."""
        if isinstance(self.color, ColorPalette):
            return self.color.by_idx(idx)
        return self.color

    def _label_layout(self, text_width, text_height, x1, y1, x2, y2, detections, image_size):
        """Return ``(text_x, text_y, bg_x1, bg_y1, bg_x2, bg_y2)`` for one label."""
        if self.avoid_overlap:
            return get_optimal_label_pos(self.text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size)
        # Fixed placement: label strip directly above the box's top-left corner.
        pad = self.text_padding
        return (
            x1 + pad,
            y1 - pad,
            x1,
            y1 - 2 * pad - text_height,
            x1 + 2 * pad + text_width,
            y1,
        )

    def annotate(
        self,
        scene: np.ndarray,
        detections: Detections,
        labels: Optional[List[str]] = None,
        skip_label: bool = False,
        image_size: Optional[Tuple[int, int]] = None,
    ) -> np.ndarray:
        """
        Draw the detections' bounding boxes (and labels) on ``scene`` in place.

        Args:
            scene (np.ndarray): Image to draw on; it is modified and returned.
            detections (Detections): Detections whose ``xyxy`` boxes are drawn.
            labels (Optional[List[str]]): One label per detection. When missing
                or of a different length than ``detections``, the detection's
                ``class_id`` is used as the label text.
            skip_label (bool): When True only the boxes are drawn.
            image_size (Optional[Tuple[int, int]]): ``(width, height)`` passed
                to ``get_optimal_label_pos`` when ``avoid_overlap`` is set.
        Returns:
            np.ndarray: The image with the bounding boxes drawn on it.
        """
        font = cv2.FONT_HERSHEY_SIMPLEX
        for i in range(len(detections)):
            x1, y1, x2, y2 = detections.xyxy[i].astype(int)
            class_id = None if detections.class_id is None else detections.class_id[i]
            color = self._color_for(i if class_id is None else class_id)

            cv2.rectangle(
                img=scene,
                pt1=(x1, y1),
                pt2=(x2, y2),
                color=color.as_bgr(),
                thickness=self.thickness,
            )
            if skip_label:
                continue

            if labels is None or len(detections) != len(labels):
                text = f"{class_id}"
            else:
                text = labels[i]

            (text_width, text_height), _baseline = cv2.getTextSize(
                text=text,
                fontFace=font,
                fontScale=self.text_scale,
                thickness=self.text_thickness,
            )
            text_x, text_y, bg_x1, bg_y1, bg_x2, bg_y2 = self._label_layout(
                text_width, text_height, x1, y1, x2, y2, detections, image_size
            )

            cv2.rectangle(
                img=scene,
                pt1=(bg_x1, bg_y1),
                pt2=(bg_x2, bg_y2),
                color=color.as_bgr(),
                thickness=cv2.FILLED,
            )

            # Black text on light backgrounds, white text on dark ones.
            rgb = color.as_rgb()
            luminance = 0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2]
            label_color = (0, 0, 0) if luminance > 160 else (255, 255, 255)
            cv2.putText(
                img=scene,
                text=text,
                org=(text_x, text_y),
                fontFace=font,
                fontScale=self.text_scale,
                color=label_color,
                thickness=self.text_thickness,
                lineType=cv2.LINE_AA,
            )
        return scene
def box_area(box):
    """Return the area of an ``[x1, y1, x2, y2]`` box."""
    return (box[2] - box[0]) * (box[3] - box[1])


def intersection_area(box1, box2):
    """Return the overlap area of two ``[x1, y1, x2, y2]`` boxes (0 if disjoint)."""
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])
    return max(0, x2 - x1) * max(0, y2 - y1)


def IoU(box1, box2, return_max=True):
    """Return an overlap score for two ``[x1, y1, x2, y2]`` boxes.

    Args:
        box1, box2: The boxes to compare.
        return_max: When True, return the maximum of the plain IoU and the
            fraction of each box covered by the intersection (so a small box
            fully inside a large one scores 1). When False, return plain IoU.

    Returns:
        The score; 0 when the union is empty (e.g. two zero-area boxes).
    """
    intersection = intersection_area(box1, box2)
    area1, area2 = box_area(box1), box_area(box2)
    union = area1 + area2 - intersection
    # Guard the division: degenerate boxes can make the union zero.
    iou = intersection / union if union > 0 else 0
    if not return_max:
        return iou
    if area1 > 0 and area2 > 0:
        ratio1 = intersection / area1
        ratio2 = intersection / area2
    else:
        ratio1, ratio2 = 0, 0
    return max(iou, ratio1, ratio2)
def get_optimal_label_pos(text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size):
    """Pick a label position that does not collide with any detection box.

    Candidate positions are tried in this order: top left, outer left, outer
    right, top right. A candidate is rejected when its background rectangle
    scores above 0.3 (``IoU``) against any detection or sticks out of the
    image (``image_size`` is ``(width, height)``). If every candidate is
    rejected, the last one (top right) is returned anyway.

    Returns ``(text_x, text_y, bg_x1, bg_y1, bg_x2, bg_y2)``.
    """
    double_pad = 2 * text_padding

    def is_blocked(bg_x1, bg_y1, bg_x2, bg_y2):
        label_box = [bg_x1, bg_y1, bg_x2, bg_y2]
        hits_detection = any(
            IoU(label_box, detections.xyxy[i].astype(int)) > 0.3
            for i in range(len(detections))
        )
        # check if the text is out of the image
        outside_image = bg_x1 < 0 or bg_x2 > image_size[0] or bg_y1 < 0 or bg_y2 > image_size[1]
        return hits_detection or outside_image

    candidates = (
        # top left
        (x1 + text_padding, y1 - text_padding,
         x1, y1 - double_pad - text_height, x1 + double_pad + text_width, y1),
        # outer left
        (x1 - text_padding - text_width, y1 + text_padding + text_height,
         x1 - double_pad - text_width, y1, x1, y1 + double_pad + text_height),
        # outer right
        (x2 + text_padding, y1 + text_padding + text_height,
         x2, y1, x2 + double_pad + text_width, y1 + double_pad + text_height),
        # top right
        (x2 - text_padding - text_width, y1 - text_padding,
         x2 - double_pad - text_width, y1 - double_pad - text_height, x2, y1),
    )
    for candidate in candidates:
        if not is_blocked(*candidate[2:]):
            return candidate
    # Everything collides: fall back to the last candidate tried.
    return candidates[-1]
import re
def extract_dict_from_text(text):
    """Find the first two-entry, single-quoted dict literal in ``text``.

    Returns it as a dict of strings; raises ValueError when no such structure
    is present.
    """
    # Regex for a dictionary-like structure with exactly two quoted key/value pairs
    pattern = r"\{\s*'(?P<key1>.*?)':\s*'(?P<value1>.*?)',\s*'(?P<key2>.*?)':\s*'(?P<value2>.*?)'\s*\}"
    found = re.search(pattern, text)
    if not found:
        raise ValueError("No valid dictionary structure found in the text.")
    groups = found.groupdict()
    return {
        groups['key1']: groups['value1'],
        groups['key2']: groups['value2'],
    }
def get_phi3v_model_dict():
    """Load the Phi-3.5-vision-instruct model and processor onto CUDA.

    Returns a dict with keys ``'model'`` and ``'processor'``.
    """
    # Function-scope imports, kept as in the original (PIL and requests are
    # imported here although this function does not reference them).
    from PIL import Image
    import requests
    from transformers import AutoModelForCausalLM, AutoProcessor

    model_id = "microsoft/Phi-3.5-vision-instruct"
    loaded = {
        'model': AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="cuda",
            trust_remote_code=True,
            torch_dtype="auto",
            _attn_implementation='flash_attention_2',
        ),
        'processor': AutoProcessor.from_pretrained(model_id, trust_remote_code=True),
    }
    print('phi3v model loaded!!!')
    return loaded
def call_phi3v(messages, image_base64, model_dict):
    """Send chat ``messages`` plus one or two base64 images to the Phi-3-vision model.

    ``image_base64`` is either a single base64 string or a 2-tuple of them
    (screenshot, labelled screenshot). Returns the decoded text of the first
    generated sequence, without the prompt tokens.
    """
    model = model_dict['model']
    processor = model_dict['processor']
    device = model.device
    prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    def open_b64(encoded):
        return Image.open(io.BytesIO(base64.b64decode(encoded)))

    if isinstance(image_base64, tuple):
        screenshot_b64, labeled_b64 = image_base64
        images = [open_b64(screenshot_b64), open_b64(labeled_b64)]
    else:
        images = [open_b64(image_base64)]
    inputs = processor(prompt, images, return_tensors="pt").to(device)

    generation_args = {
        "max_new_tokens": 512,
        "temperature": 0.01,
        "do_sample": False,
    }
    generate_ids = model.generate(**inputs, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
    # remove input tokens
    prompt_len = inputs['input_ids'].shape[1]
    new_token_ids = generate_ids[:, prompt_len:]
    decoded = processor.batch_decode(new_token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return decoded[0]
def get_pred_phi3v(message_text, image_base64, label_coordinates, id_key='Click ID', model_dict=None):
    """Ask the Phi-3-vision model which labelled box to click and resolve it.

    Args:
        message_text: The user instruction / prompt text.
        image_base64: One base64 image string, or a 2-tuple of them; the
            matching ``<|image_N|>`` tags are prepended to the prompt.
        label_coordinates: Mapping from box id (as str) to ``[x, y, w, h]``.
        id_key: Accepted for interface compatibility; currently unused. The
            response is always read under the key ``'Click BBox ID'``.
        model_dict: dict with ``'model'`` and ``'processor'`` for call_phi3v.

    Returns:
        ``(icon_id, bbox, click_point, parsed_response)`` where ``click_point``
        is the centre of ``bbox`` and ``parsed_response`` the parsed dict.

    Raises:
        ValueError / KeyError: If the response can be parsed neither as a
        Python literal nor by the regex fallback, or names an unknown box.
    """
    # Only a single user turn is sent; no system message is used.
    if isinstance(image_base64, tuple):
        image_tags = '<|image_1|>\n' + '<|image_2|>\n'
    else:
        image_tags = '<|image_1|>\n'
    messages = [{"role": "user", "content": image_tags + message_text}]

    raw_response = call_phi3v(messages, image_base64, model_dict)
    print(raw_response)

    def _locate(parsed):
        # Resolve the chosen id to its box and the box centre.
        icon_id = parsed['Click BBox ID']
        bbox = label_coordinates[str(icon_id)]
        return icon_id, bbox, [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2]

    try:
        parsed = ast.literal_eval(raw_response)
        icon_id, bbox, click_point = _locate(parsed)
    except Exception:
        # Fall back to regex extraction on the raw text. The interactive
        # debugger breakpoint that used to sit here was removed because it
        # blocks non-interactive runs.
        print('error parsing, use regex to parse!!!')
        parsed = extract_dict_from_text(raw_response)
        icon_id, bbox, click_point = _locate(parsed)
    return icon_id, bbox, click_point, parsed
# try:
# match = re.search(r"```(.*?)```", ans, re.DOTALL)
# if match:
# result = match.group(1).strip()
# pred = result.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '')
# pred = ast.literal_eval(pred)
# else:
# pred = ans.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '')
# pred = ast.literal_eval(pred)
# if pred[id_key]:
# icon_id = pred[id_key]
# bbox = label_coordinates[str(icon_id)]
# pred['click_point'] = [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2]
# except:
# print('phi3v action regex extract fail!!!')
# pred = {'action_type': 'CLICK', 'click_point': [0, 0], 'value': 'None', 'is_completed': False}
# step_pred_summary = None
# return pred, [True, ans, None, step_pred_summary]