935 lines
39 KiB
Python
Executable File
935 lines
39 KiB
Python
Executable File
# from ultralytics import YOLO
|
|
import os
|
|
import io
|
|
import base64
|
|
import time
|
|
from PIL import Image, ImageDraw, ImageFont
|
|
import json
|
|
import requests
|
|
# utility function
|
|
import os
|
|
from openai import AzureOpenAI
|
|
|
|
import json
|
|
import sys
|
|
import os
|
|
import cv2
|
|
import numpy as np
|
|
# %matplotlib inline
|
|
from matplotlib import pyplot as plt
|
|
import easyocr
|
|
from paddleocr import PaddleOCR
|
|
# Module-level EasyOCR reader (English + Simplified Chinese, GPU requested);
# check_ocr_box uses it when use_paddleocr is False.
reader = easyocr.Reader(['en', 'ch_sim'], gpu=True)
# Module-level PaddleOCR engine; check_ocr_box uses it when use_paddleocr is True.
paddle_ocr = PaddleOCR(
    lang='en',  # other lang also available
    use_angle_cls=False,
    use_gpu=False,  # using cuda will conflict with pytorch in the same process
    show_log=False,
    max_batch_size=1024,
    use_dilation=True,  # improves accuracy
    det_db_score_mode='slow',  # improves accuracy
    rec_batch_num=1024)
|
|
import time
|
|
import base64
|
|
|
|
import os
|
|
import ast
|
|
import torch
|
|
from typing import Tuple, List
|
|
from torchvision.ops import box_convert
|
|
import re
|
|
from torchvision.transforms import ToPILImage
|
|
import supervision as sv
|
|
import torchvision.transforms as T
|
|
|
|
|
|
def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
    """Load an icon-captioning model together with its processor.

    Args:
        model_name: Model family to load, either ``"blip2"`` or ``"florence2"``.
        model_name_or_path: Checkpoint name or local path for the model
            weights. The processor is always loaded from the family's stock
            checkpoint, independent of this argument.
        device: Target device. When falsy, ``"cuda"`` is used if available,
            otherwise ``"cpu"``.

    Returns:
        dict: ``{'model': <model moved to device>, 'processor': <processor>}``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported families.
    """
    supported = ("blip2", "florence2")
    if model_name not in supported:
        # Previously this fell through to an unbound `model` at the return.
        raise ValueError(
            f"Unsupported caption model {model_name!r}; expected one of {supported}"
        )

    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Full precision on CPU, half precision everywhere else.
    dtype = torch.float32 if device == 'cpu' else torch.float16

    if model_name == "blip2":
        from transformers import Blip2Processor, Blip2ForConditionalGeneration
        processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
        model = Blip2ForConditionalGeneration.from_pretrained(
            model_name_or_path, device_map=None, torch_dtype=dtype
        )
    else:  # "florence2"
        from transformers import AutoProcessor, AutoModelForCausalLM
        processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, torch_dtype=dtype, trust_remote_code=True
        )
    return {'model': model.to(device), 'processor': processor}
|
|
|
|
|
|
def get_yolo_model(model_path):
    """Build an Ultralytics ``YOLO`` model from the weights at ``model_path``."""
    # Imported at call time rather than at module import.
    from ultralytics import YOLO

    return YOLO(model_path)
|
|
|
|
|
|
@torch.inference_mode()
def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=None):
    """Caption image crops with the local caption model.

    Args:
        filtered_boxes: Sequence of boxes indexed as
            ``(xmin, ymin, xmax, ymax)`` in ratio scale (each coordinate is
            multiplied by the image width/height before cropping).
        starting_idx: Index of the first box to caption. A falsy value
            (``0`` or ``None``) captions every box.
        image_source: Image array indexed as ``[row, column, channel]``.
        caption_model_processor: Dict with ``'model'`` and ``'processor'``.
        prompt: Text prompt for the model. When falsy, ``"<CAPTION>"`` is used
            for Florence models and ``"The image shows"`` otherwise.
        batch_size: Number of crops per generation call; ``None`` means 64.
            (A batch of 256 roughly takes 23 GB of GPU memory for Florence.)

    Returns:
        list[str]: One stripped caption per crop that could be prepared, in
        box order. Boxes whose crop fails are skipped, so the result can be
        shorter than the number of boxes.
    """
    if starting_idx:
        non_ocr_boxes = filtered_boxes[starting_idx:]
    else:
        non_ocr_boxes = filtered_boxes
    if len(non_ocr_boxes) == 0:
        # Nothing to caption; avoid touching the model at all.
        return []
    if batch_size is None:
        # range() cannot take a None step; fall back to a moderate batch.
        batch_size = 64

    to_pil = ToPILImage()
    croped_pil_image = []
    for coord in non_ocr_boxes:
        try:
            xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
            ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
            cropped_image = image_source[ymin:ymax, xmin:xmax, :]
            # Shrink every crop to 64x64 to keep the image processor overhead low.
            cropped_image = cv2.resize(cropped_image, (64, 64))
            croped_pil_image.append(to_pil(cropped_image))
        except Exception:
            # Best effort: a crop that cannot be extracted or resized (for
            # example an empty region) is skipped instead of aborting the batch.
            continue

    model, processor = caption_model_processor['model'], caption_model_processor['processor']
    is_florence = 'florence' in model.config.name_or_path
    if not prompt:
        prompt = "<CAPTION>" if is_florence else "The image shows"

    generated_texts = []
    device = model.device
    for i in range(0, len(croped_pil_image), batch_size):
        batch = croped_pil_image[i:i+batch_size]
        if model.device.type == 'cuda':
            # Crops are already resized above, so the processor resize is disabled.
            inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16)
        else:
            inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
        if is_florence:
            generated_ids = model.generate(input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=20, num_beams=1, do_sample=False)
        else:
            generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
        generated_texts.extend(gen.strip() for gen in generated_text)

    return generated_texts
|
|
|
|
|
|
|
|
def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
    """Describe image crops with a chat-style vision model, five crops at a time.

    Args:
        filtered_boxes: Sequence of boxes indexed as
            ``(xmin, ymin, xmax, ymax)`` in ratio scale.
        ocr_bbox: When truthy, the first ``len(ocr_bbox)`` entries of
            ``filtered_boxes`` are skipped and only the remainder is captioned.
        image_source: Image array indexed as ``[row, column, channel]``.
        caption_model_processor: Dict with ``'model'`` and ``'processor'``.

    Returns:
        list[str]: One stripped description per captioned box, in order.
    """
    as_pil = ToPILImage()
    icon_boxes = filtered_boxes[len(ocr_bbox):] if ocr_bbox else filtered_boxes

    icon_crops = []
    for box in icon_boxes:
        left, right = int(box[0]*image_source.shape[1]), int(box[2]*image_source.shape[1])
        top, bottom = int(box[1]*image_source.shape[0]), int(box[3]*image_source.shape[0])
        icon_crops.append(as_pil(image_source[top:bottom, left:right, :]))

    model = caption_model_processor['model']
    processor = caption_model_processor['processor']
    device = model.device
    messages = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}]
    prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    batch_size = 5  # Number of samples per batch
    generation_args = {
        "max_new_tokens": 25,
        "temperature": 0.01,
        "do_sample": False,
    }
    tensor_keys = ('input_ids', 'attention_mask', 'pixel_values', 'image_sizes')
    generated_texts = []

    for start in range(0, len(icon_crops), batch_size):
        images = icon_crops[start:start+batch_size]
        image_inputs = [processor.image_processor(x, return_tensors="pt") for x in images]

        # Encode every (image, prompt) pair separately, collecting per-key lists.
        inputs = {key: [] for key in tensor_keys}
        for image_input in image_inputs:
            encoded = processor._convert_images_texts_to_inputs(image_input, prompt, return_tensors="pt")
            for key in tensor_keys:
                inputs[key].append(encoded[key])

        # Left-pad ids and masks to the longest sequence so they can be stacked.
        max_len = max([ids.shape[1] for ids in inputs['input_ids']])
        for idx, ids in enumerate(inputs['input_ids']):
            pad = max_len - ids.shape[1]
            inputs['input_ids'][idx] = torch.cat([processor.tokenizer.pad_token_id * torch.ones(1, pad, dtype=torch.long), ids], dim=1)
            inputs['attention_mask'][idx] = torch.cat([torch.zeros(1, pad, dtype=torch.long), inputs['attention_mask'][idx]], dim=1)
        inputs_cat = {k: torch.concatenate(v).to(device) for k, v in inputs.items()}

        generate_ids = model.generate(**inputs_cat, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
        # Keep only the newly generated tokens (drop the prompt prefix).
        generate_ids = generate_ids[:, inputs_cat['input_ids'].shape[1]:]
        response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        generated_texts.extend([res.strip('\n').strip() for res in response])

    return generated_texts
|
|
|
|
def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
    """Filter overlapping detector boxes, optionally giving OCR boxes priority.

    The overlap score of two boxes is the maximum of their IoU and of the
    fraction of either box covered by the intersection.

    Args:
        boxes: Object with ``tolist()`` yielding ``[x1, y1, x2, y2]`` boxes.
        iou_threshold: Score above which two boxes count as overlapping.
        ocr_bbox: Optional list of ``[x1, y1, x2, y2]`` OCR boxes. They are
            placed first in the output, unfiltered.

    Returns:
        torch.Tensor: OCR boxes (if any) followed by the surviving boxes. A box
        is dropped when it overlaps a strictly smaller box, or when it overlaps
        an OCR box without having more than 95% of its area inside that box.
    """
    assert ocr_bbox is None or isinstance(ocr_bbox, List)

    def area(b):
        return (b[2] - b[0]) * (b[3] - b[1])

    def shared_area(a, b):
        left, top = max(a[0], b[0]), max(a[1], b[1])
        right, bottom = min(a[2], b[2]), min(a[3], b[3])
        return max(0, right - left) * max(0, bottom - top)

    def overlap_score(a, b):
        shared = shared_area(a, b)
        union = area(a) + area(b) - shared + 1e-6
        if area(a) > 0 and area(b) > 0:
            cover_a, cover_b = shared / area(a), shared / area(b)
        else:
            cover_a, cover_b = 0, 0
        return max(shared / union, cover_a, cover_b)

    def mostly_within(inner, outer):
        return shared_area(inner, outer) / area(inner) > 0.95

    candidates = boxes.tolist()
    kept = list(ocr_bbox) if ocr_bbox else []

    for i, box in enumerate(candidates):
        # Of two overlapping boxes the smaller one wins.
        beaten_by_smaller = any(
            i != j and overlap_score(box, other) > iou_threshold and area(box) > area(other)
            for j, other in enumerate(candidates)
        )
        if beaten_by_smaller:
            continue
        if ocr_bbox:
            clashes_with_text = any(
                overlap_score(box, text_box) > iou_threshold and not mostly_within(box, text_box)
                for text_box in ocr_bbox
            )
            if clashes_with_text:
                continue
        kept.append(box)

    return torch.tensor(kept)
|
|
|
|
|
|
def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
    '''Merge detector (icon) elements with OCR (text) elements, removing overlaps.

    Element formats (``bbox`` is indexed as ``[x1, y1, x2, y2]``):
        ocr_bbox: [{'type': 'text', 'bbox': [...], 'interactivity': False, 'content': str}, ...]
        boxes:    [{'type': 'icon', 'bbox': [...], 'interactivity': True, 'content': None}, ...]

    Rules:
        * An icon is dropped when it overlaps a strictly smaller icon with a
          score above ``iou_threshold`` (score = max of IoU and of the fraction
          of either box covered by the intersection).
        * An OCR box with more than 80% of its area inside a kept icon is
          removed from the output and its text becomes part of the icon's
          ``content`` (``source='box_yolo_content_ocr'``).
        * An icon with more than 80% of its area inside an OCR box is dropped.
        * Any other kept icon is emitted with ``content=None`` and
          ``source='box_yolo_content_yolo'``.

    Args:
        boxes: List of icon elements.
        iou_threshold: Overlap score threshold for icon-vs-icon filtering.
        ocr_bbox: Optional list of OCR elements; they come first in the output.

    Returns:
        list[dict]: Remaining OCR elements followed by icon elements. Every
        entry is an element dict, also when no OCR boxes are given.
    '''
    assert ocr_bbox is None or isinstance(ocr_bbox, List)

    def box_area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    def intersection_area(box1, box2):
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[2], box2[2])
        y2 = min(box1[3], box2[3])
        return max(0, x2 - x1) * max(0, y2 - y1)

    def IoU(box1, box2):
        intersection = intersection_area(box1, box2)
        union = box_area(box1) + box_area(box2) - intersection + 1e-6
        if box_area(box1) > 0 and box_area(box2) > 0:
            ratio1 = intersection / box_area(box1)
            ratio2 = intersection / box_area(box2)
        else:
            ratio1, ratio2 = 0, 0
        return max(intersection / union, ratio1, ratio2)

    def is_inside(box1, box2):
        # True when more than 80% of box1's area lies within box2.
        area1 = box_area(box1)
        if area1 <= 0:
            # A degenerate box cannot be "inside" anything (and would divide by zero).
            return False
        return intersection_area(box1, box2) / area1 > 0.80

    filtered_boxes = []
    if ocr_bbox:
        filtered_boxes.extend(ocr_bbox)

    for i, box1_elem in enumerate(boxes):
        box1 = box1_elem['bbox']
        # Keep the smaller box of any overlapping icon pair.
        is_valid_box = True
        for j, box2_elem in enumerate(boxes):
            box2 = box2_elem['bbox']
            if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
                is_valid_box = False
                break
        if not is_valid_box:
            continue

        box_added = False  # True when the icon is swallowed by an OCR box
        ocr_labels = ''
        if ocr_bbox:
            # Keep detector boxes, but prefer OCR text as their label.
            for box3_elem in ocr_bbox:
                box3 = box3_elem['bbox']
                if is_inside(box3, box1):  # OCR text inside the icon
                    try:
                        # Gather all OCR labels, then drop the text element.
                        ocr_labels += box3_elem['content'] + ' '
                        filtered_boxes.remove(box3_elem)
                    except (TypeError, ValueError):
                        # TypeError: content is not a string.
                        # ValueError: element already removed for another icon.
                        continue
                elif is_inside(box1, box3):
                    # Icon inside OCR text: do not add the icon. No need to
                    # check other OCR boxes since they do not overlap each other.
                    box_added = True
                    break

        if box_added:
            continue
        if ocr_labels:
            filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': ocr_labels, 'source': 'box_yolo_content_ocr'})
        else:
            # Also used when there are no OCR boxes, so that the output is
            # always a list of element dicts (never bare bbox lists).
            filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None, 'source': 'box_yolo_content_yolo'})
    return filtered_boxes
|
|
|
|
|
|
def load_image(image_path: str) -> Tuple[np.ndarray, torch.Tensor]:
    """Load an image as both a raw RGB array and a normalized tensor.

    Args:
        image_path: Path of the image file to open.

    Returns:
        A pair ``(image, image_transformed)``: the RGB image as a NumPy array,
        and the same image resized, converted to a tensor and normalized with
        the mean/std constants below.
    """
    transform = T.Compose(
        [
            # torchvision.transforms has no RandomResize; Resize with an int
            # size scales the shorter edge to 800 and max_size caps the longer
            # edge at 1333.
            T.Resize(800, max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    # convert() returns a new image, so the file handle can be closed right away.
    with Image.open(image_path) as opened:
        image_source = opened.convert("RGB")
    image = np.asarray(image_source)
    # torchvision's Compose takes the image only (no target argument).
    image_transformed = transform(image_source)
    return image, image_transformed
|
|
|
|
|
|
def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str], text_scale: float,
             text_padding=5, text_thickness=2, thickness=3) -> Tuple[np.ndarray, dict]:
    """
    Draw numbered bounding boxes on a copy of an image.

    Parameters:
    image_source (np.ndarray): The source image, shaped (height, width, channels).
    boxes (torch.Tensor): Bounding boxes in cxcywh format, in ratio scale; they
        are multiplied by the image width/height here to obtain pixels.
    logits (torch.Tensor): Confidence scores; accepted for interface
        compatibility but not used.
    phrases (List[str]): One key per box for the returned coordinate mapping.
    text_scale (float): The scale of the text to be displayed. 0.8 for mobile/web, 0.3 for desktop, 0.4 for mind2web
    text_padding, text_thickness, thickness: Forwarded to BoxAnnotator.

    Returns:
    Tuple[np.ndarray, dict]: The annotated image, and a mapping from
        ``str(phrase)`` to that box's pixel-scale xywh array. The labels drawn
        on the image are the box indices ("0", "1", ...), not the phrases.
    """
    h, w, _ = image_source.shape
    # Build the scale on the boxes' device so the product works off-CPU too.
    scale = torch.tensor([w, h, w, h], dtype=torch.float32, device=boxes.device)
    boxes = boxes * scale
    xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").cpu().numpy()
    xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").cpu().numpy()
    detections = sv.Detections(xyxy=xyxy)

    labels = [str(index) for index in range(boxes.shape[0])]

    box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding, text_thickness=text_thickness, thickness=thickness)
    annotated_frame = image_source.copy()
    annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w, h))

    label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
    return annotated_frame, label_coordinates
|
|
|
|
|
|
def predict(model, image, caption, box_threshold, text_threshold):
    """Run a grounded object detector on one image.

    Args:
        model: Dict with ``'model'`` (the detector) and ``'processor'``.
        image: Input image; its ``size`` attribute, reversed, is passed to the
            post-processor as the target size.
        caption: Text prompt passed to the processor.
        box_threshold: Box score threshold for post-processing (e.g. 0.4).
        text_threshold: Text score threshold for post-processing (e.g. 0.3).

    Returns:
        tuple: ``(boxes, scores, labels)`` of the first post-processed result.
    """
    detector = model['model']
    processor = model['processor']

    inputs = processor(images=image, text=caption, return_tensors="pt").to(detector.device)
    with torch.no_grad():
        outputs = detector(**inputs)

    per_image = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        target_sizes=[image.size[::-1]],
    )
    first = per_image[0]
    return first["boxes"], first["scores"], first["labels"]
|
|
|
|
|
|
def predict_yolo(model, image_path, box_threshold, imgsz, scale_img, iou_threshold=0.7):
    """Run a YOLO-style model on an image and return its boxes.

    Args:
        model: Object exposing ``predict(source=..., conf=..., iou=..., [imgsz=...])``.
        image_path: Image source handed to ``model.predict``.
        box_threshold: Confidence threshold (``conf``).
        imgsz: Inference size; only forwarded when ``scale_img`` is truthy.
        scale_img: Whether to pass ``imgsz`` to the model.
        iou_threshold: ``iou`` value for the model (default 0.7).

    Returns:
        tuple: ``(boxes, conf, phrases)`` where ``boxes`` is the first result's
        ``boxes.xyxy`` (pixel space), ``conf`` its ``boxes.conf`` and
        ``phrases`` the box indices as strings.
    """
    predict_kwargs = {"source": image_path, "conf": box_threshold}
    if scale_img:
        predict_kwargs["imgsz"] = imgsz
    predict_kwargs["iou"] = iou_threshold

    first_result = model.predict(**predict_kwargs)[0]
    boxes = first_result.boxes.xyxy
    conf = first_result.boxes.conf
    phrases = [str(index) for index in range(len(boxes))]
    return boxes, conf, phrases
|
|
|
|
|
|
def int_box_area(box, w, h):
    """Return the area of a ratio-scale box after scaling to ``w`` x ``h``.

    Each corner coordinate is scaled and truncated with ``int`` before the
    area is computed, so boxes thinner than one unit can yield 0.
    """
    left, top, right, bottom = box
    px_left, px_top = int(left * w), int(top * h)
    px_right, px_bottom = int(right * w), int(bottom * h)
    return (px_right - px_left) * (px_bottom - px_top)
|
|
|
|
|
|
def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD = 0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=64):
    """Detect, merge, caption and draw the UI elements of a screenshot.

    Steps: run the detector (``predict_yolo``), merge its boxes with the OCR
    boxes (``remove_overlap_new``), caption the elements that have no text
    (when ``use_local_semantics``), then draw numbered boxes (``annotate``).

    Args:
        img_path: Path of the image to process.
        model: Detector passed to ``predict_yolo``.
        BOX_TRESHOLD: Detector confidence threshold.
        output_coord_in_ratio: When true, returned coordinates are divided by
            the image width/height.
        ocr_bbox: List of OCR boxes in pixel-scale xyxy format, or None/empty.
        text_scale, text_padding: Label style, used when ``draw_bbox_config``
            is falsy.
        draw_bbox_config: Optional dict of keyword arguments for ``annotate``.
        caption_model_processor: Dict with ``'model'`` and ``'processor'``;
            required when captions are needed.
        ocr_text: OCR strings aligned with ``ocr_bbox`` (only read, never
            mutated, so the shared default list is harmless).
        use_local_semantics: Whether to caption elements without text.
        iou_threshold: Threshold for ``remove_overlap_new``. The detector's own
            IoU setting is fixed to 0.1 below and is not affected by this.
        prompt: Optional caption prompt.
        scale_img, imgsz: Detector input scaling; ``imgsz`` defaults to (h, w).
        batch_size: Caption batch size.

    Returns:
        tuple: ``(encoded_image, label_coordinates, filtered_boxes_elem)`` - the
        annotated image as a base64 PNG string, a mapping from box index
        (string) to xywh coordinates, and the list of element dicts.
    """
    # convert() returns a new image, so the file handle can be closed right away.
    with Image.open(img_path) as opened:
        image_source = opened.convert("RGB")
    w, h = image_source.size
    if not imgsz:
        imgsz = (h, w)
    xyxy, logits, _ = predict_yolo(model=model, image_path=img_path, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1)
    xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
    image_source = np.asarray(image_source)

    # Bring the OCR boxes to ratio scale, like the detector boxes.
    if ocr_bbox:
        ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
        ocr_bbox = ocr_bbox.tolist()
    else:
        print('no ocr bbox!!!')
        # An empty list (not None) keeps the zip() below working.
        ocr_bbox = []

    ocr_bbox_elem = [{'type': 'text', 'bbox': box, 'interactivity': False, 'content': txt, 'source': 'box_ocr_content_ocr'} for box, txt in zip(ocr_bbox, ocr_text) if int_box_area(box, w, h) > 0]
    xyxy_elem = [{'type': 'icon', 'bbox': box, 'interactivity': True, 'content': None} for box in xyxy.tolist() if int_box_area(box, w, h) > 0]
    filtered_boxes = remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)
    # Defensive: make sure every entry is an element dict, in case
    # remove_overlap_new hands back bare bbox lists when there are no OCR boxes.
    filtered_boxes = [
        elem if isinstance(elem, dict)
        else {'type': 'icon', 'bbox': elem, 'interactivity': True, 'content': None, 'source': 'box_yolo_content_yolo'}
        for elem in filtered_boxes
    ]

    # Sort so that elements without content come last, and find the first one.
    filtered_boxes_elem = sorted(filtered_boxes, key=lambda x: x['content'] is None)
    starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1)
    # reshape keeps an (N, 4) layout even when there are no boxes at all.
    filtered_boxes = torch.tensor([box['bbox'] for box in filtered_boxes_elem]).reshape(-1, 4)
    print('len(filtered_boxes):', len(filtered_boxes), starting_idx)

    # Caption the elements that have no content yet.
    time1 = time.time()
    if use_local_semantics and starting_idx != -1:
        caption_model = caption_model_processor['model']
        if 'phi3_v' in caption_model.config.model_type:
            # The boxes needing a caption are exactly those from starting_idx
            # on; the count of OCR boxes no longer matches after the merge.
            parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes[starting_idx:], None, image_source, caption_model_processor)
        else:
            parsed_content_icon = get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=prompt, batch_size=batch_size)
        # Fill the empty contents in order; if the captioner returned fewer
        # captions than needed, the remaining elements keep content None.
        captions = iter(parsed_content_icon)
        for box in filtered_boxes_elem:
            if box['content'] is None:
                box['content'] = next(captions, None)
    print('time to get parsed content:', time.time()-time1)

    filtered_boxes = box_convert(boxes=filtered_boxes, in_fmt="xyxy", out_fmt="cxcywh")
    phrases = list(range(len(filtered_boxes)))

    # Draw the boxes.
    if draw_bbox_config:
        annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, **draw_bbox_config)
    else:
        annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, text_scale=text_scale, text_padding=text_padding)

    pil_img = Image.fromarray(annotated_frame)
    buffered = io.BytesIO()
    pil_img.save(buffered, format="PNG")
    encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
    if output_coord_in_ratio:
        label_coordinates = {k: [v[0]/w, v[1]/h, v[2]/w, v[3]/h] for k, v in label_coordinates.items()}
        assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]

    return encoded_image, label_coordinates, filtered_boxes_elem
|
|
|
|
|
|
def get_xywh(input):
    """Convert a corner list to an integer ``(x, y, w, h)`` box.

    ``input[0]`` is taken as the origin corner and ``input[2]`` as the opposite
    corner; width and height are their differences, truncated with ``int``.
    """
    origin, opposite = input[0], input[2]
    left, top = origin[0], origin[1]
    width = opposite[0] - left
    height = opposite[1] - top
    return int(left), int(top), int(width), int(height)
|
|
|
|
def get_xyxy(input):
    """Convert a corner list to an integer ``(x1, y1, x2, y2)`` box.

    ``input[0]`` supplies the first corner and ``input[2]`` the opposite one;
    every coordinate is truncated with ``int``.
    """
    first_corner, opposite_corner = input[0], input[2]
    return (
        int(first_corner[0]),
        int(first_corner[1]),
        int(opposite_corner[0]),
        int(opposite_corner[1]),
    )
|
|
|
|
def get_xywh_yolo(input):
    """Convert a flat ``(x1, y1, x2, y2)`` box to integer ``(x, y, w, h)``.

    Width and height are computed before truncation with ``int``.
    """
    left, top = input[0], input[1]
    width = input[2] - left
    height = input[3] - top
    return int(left), int(top), int(width), int(height)
|
|
|
|
|
|
|
|
def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False):
    """Run OCR on an image and return the recognized texts with their boxes.

    Args:
        image_path: Image handed to the OCR engine (and to ``cv2.imread`` when
            ``display_img`` is true).
        display_img: When true, the boxes are drawn on the image and shown with
            matplotlib; the returned boxes are then always in xywh format.
        output_bb_format: ``'xywh'`` or ``'xyxy'``; only used when
            ``display_img`` is false.
        goal_filtering: Returned unchanged as the second element of the result.
        easyocr_args: For EasyOCR, keyword arguments for ``readtext``. For
            PaddleOCR, only its ``'text_threshold'`` entry is read (default 0.5).
        use_paddleocr: Use the module-level PaddleOCR engine instead of EasyOCR.

    Returns:
        tuple: ``((text, bb), goal_filtering)`` where ``text`` is the list of
        recognized strings and ``bb`` the list of matching integer boxes.

    Raises:
        ValueError: If ``display_img`` is false and ``output_bb_format`` is
            neither ``'xywh'`` nor ``'xyxy'``.
    """
    # Fail fast, before the expensive OCR call, instead of hitting an unbound
    # `bb` at the return.
    if not display_img and output_bb_format not in ('xywh', 'xyxy'):
        raise ValueError(
            f"Unsupported output_bb_format {output_bb_format!r}; expected 'xywh' or 'xyxy'"
        )

    if use_paddleocr:
        if easyocr_args is None:
            text_threshold = 0.5
        else:
            text_threshold = easyocr_args.get('text_threshold', 0.5)
        # PaddleOCR can report None (rather than an empty list) for an image
        # in which no text was detected.
        result = paddle_ocr.ocr(image_path, cls=False)[0] or []
        confident = [item for item in result if item[1][1] > text_threshold]
        coord = [item[0] for item in confident]
        text = [item[1][0] for item in confident]
    else:  # EasyOCR
        if easyocr_args is None:
            easyocr_args = {}
        result = reader.readtext(image_path, **easyocr_args)
        coord = [item[0] for item in result]
        text = [item[1] for item in result]

    if display_img:
        opencv_img = cv2.imread(image_path)
        # Swap the first and last channels for display with matplotlib.
        opencv_img = cv2.cvtColor(opencv_img, cv2.COLOR_RGB2BGR)
        bb = []
        for item in coord:
            x, y, a, b = get_xywh(item)
            bb.append((x, y, a, b))
            cv2.rectangle(opencv_img, (x, y), (x+a, y+b), (0, 255, 0), 2)
        # Display the image
        plt.imshow(opencv_img)
    elif output_bb_format == 'xywh':
        bb = [get_xywh(item) for item in coord]
    else:  # 'xyxy'
        bb = [get_xyxy(item) for item in coord]
    return (text, bb), goal_filtering
|
|
|
|
|
|
|
|
from typing import List, Optional, Union, Tuple
|
|
|
|
import cv2
|
|
import numpy as np
|
|
|
|
from supervision.detection.core import Detections
|
|
from supervision.draw.color import Color, ColorPalette
|
|
|
|
|
|
class BoxAnnotator:
    """
    A class for drawing bounding boxes on an image using detections provided.

    Attributes:
        color (Union[Color, ColorPalette]): The color to draw the bounding box,
            can be a single color or a color palette
        thickness (int): The thickness of the bounding box lines, default is 3
        text_color (Color): Stored text color, default is black. Note that
            `annotate` picks black or white per box from the box color's
            luminance and does not read this attribute.
        text_scale (float): The scale of the text on the bounding box, default is 0.5
        text_thickness (int): The thickness of the text on the bounding box,
            default is 2
        text_padding (int): The padding around the text on the bounding box,
            default is 10
        avoid_overlap (bool): When True, the label position is chosen with
            `get_optimal_label_pos`; otherwise the label sits on top of the
            box's top-left corner. Default is True.
    """

    def __init__(
        self,
        color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
        thickness: int = 3,  # 1 for seeclick 2 for mind2web and 3 for demo
        text_color: Color = Color.BLACK,
        text_scale: float = 0.5,  # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
        text_thickness: int = 2,  # 2 for demo
        text_padding: int = 10,
        avoid_overlap: bool = True,
    ):
        self.color: Union[Color, ColorPalette] = color
        self.thickness: int = thickness
        self.text_color: Color = text_color
        self.text_scale: float = text_scale
        self.text_thickness: int = text_thickness
        self.text_padding: int = text_padding
        self.avoid_overlap: bool = avoid_overlap

    def annotate(
        self,
        scene: np.ndarray,
        detections: Detections,
        labels: Optional[List[str]] = None,
        skip_label: bool = False,
        image_size: Optional[Tuple[int, int]] = None,
    ) -> np.ndarray:
        """
        Draws bounding boxes on the frame using the detections provided.

        Args:
            scene (np.ndarray): The image on which the bounding boxes will be
                drawn. It is modified in place and also returned.
            detections (Detections): The detections for which the
                bounding boxes will be drawn
            labels (Optional[List[str]]): An optional list of labels
                corresponding to each detection. If `labels` are not provided
                (or their count differs from the detections), the
                corresponding `class_id` will be used as label.
            skip_label (bool): Is set to `True`, skips bounding box label annotation.
            image_size (Optional[Tuple[int, int]]): Image (width, height) used
                to keep labels inside the image when `avoid_overlap` is on.
                Defaults to the size of `scene`.
        Returns:
            np.ndarray: The image with the bounding boxes drawn on it

        Example:
            ```python
            box_annotator = BoxAnnotator()
            annotated_frame = box_annotator.annotate(
                scene=image.copy(),
                detections=detections,
                labels=[str(i) for i in range(len(detections))],
                image_size=(width, height),
            )
            ```
        """
        font = cv2.FONT_HERSHEY_SIMPLEX
        if image_size is None:
            # Without a size the out-of-image check in get_optimal_label_pos
            # would fail on None; fall back to the scene's own (width, height).
            image_size = (scene.shape[1], scene.shape[0])
        use_given_labels = labels is not None and len(detections) == len(labels)

        for i in range(len(detections)):
            x1, y1, x2, y2 = detections.xyxy[i].astype(int)
            class_id = (
                detections.class_id[i] if detections.class_id is not None else None
            )
            idx = class_id if class_id is not None else i
            color = (
                self.color.by_idx(idx)
                if isinstance(self.color, ColorPalette)
                else self.color
            )
            cv2.rectangle(
                img=scene,
                pt1=(x1, y1),
                pt2=(x2, y2),
                color=color.as_bgr(),
                thickness=self.thickness,
            )
            if skip_label:
                continue

            text = labels[i] if use_given_labels else f"{class_id}"

            text_width, text_height = cv2.getTextSize(
                text=text,
                fontFace=font,
                fontScale=self.text_scale,
                thickness=self.text_thickness,
            )[0]

            if not self.avoid_overlap:
                # Label sits directly above the box's top-left corner.
                text_x = x1 + self.text_padding
                text_y = y1 - self.text_padding

                text_background_x1 = x1
                text_background_y1 = y1 - 2 * self.text_padding - text_height

                text_background_x2 = x1 + 2 * self.text_padding + text_width
                text_background_y2 = y1
            else:
                text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 = get_optimal_label_pos(self.text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size)

            cv2.rectangle(
                img=scene,
                pt1=(text_background_x1, text_background_y1),
                pt2=(text_background_x2, text_background_y2),
                color=color.as_bgr(),
                thickness=cv2.FILLED,
            )
            # Choose black text on light backgrounds and white on dark ones.
            box_color = color.as_rgb()
            luminance = 0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2]
            text_color = (0, 0, 0) if luminance > 160 else (255, 255, 255)
            cv2.putText(
                img=scene,
                text=text,
                org=(text_x, text_y),
                fontFace=font,
                fontScale=self.text_scale,
                color=text_color,
                thickness=self.text_thickness,
                lineType=cv2.LINE_AA,
            )
        return scene
|
|
|
|
|
|
def box_area(box):
    """Return width times height of a box indexed as ``(x1, y1, x2, y2)``."""
    width = box[2] - box[0]
    height = box[3] - box[1]
    return width * height
|
|
|
|
def intersection_area(box1, box2):
    """Return the overlap area of two ``(x1, y1, x2, y2)`` boxes (0 if disjoint)."""
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2], box2[2])
    bottom = min(box1[3], box2[3])
    overlap_width = max(0, right - left)
    overlap_height = max(0, bottom - top)
    return overlap_width * overlap_height
|
|
|
|
def IoU(box1, box2, return_max=True):
    """Return an overlap score for two ``(x1, y1, x2, y2)`` boxes.

    Args:
        box1, box2: The boxes to compare.
        return_max: When True (default), return the maximum of the plain IoU
            and of the fraction of either box covered by the intersection, so
            a small box fully inside a large one scores 1. When False, return
            the plain intersection-over-union.

    Returns:
        The score, or 0.0 when the union is zero (e.g. two zero-area boxes),
        where the ratio would otherwise be a division by zero.
    """
    intersection = intersection_area(box1, box2)
    area1, area2 = box_area(box1), box_area(box2)
    union = area1 + area2 - intersection
    if union == 0:
        return 0.0
    iou = intersection / union
    if not return_max:
        return iou
    if area1 > 0 and area2 > 0:
        ratio1 = intersection / area1
        ratio2 = intersection / area2
    else:
        ratio1, ratio2 = 0, 0
    return max(iou, ratio1, ratio2)
|
|
|
|
|
|
def get_optimal_label_pos(text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size):
    """Pick a label position that avoids the detections and stays on the image.

    Candidate placements around the box ``(x1, y1, x2, y2)`` are tried in this
    order: 'top left', 'outer left', 'outer right', 'top right'. A placement is
    rejected when its background rectangle scores above 0.3 with ``IoU``
    against any detection box, or when it extends beyond ``image_size``
    (indexed as ``(width, height)``). The first accepted placement is returned;
    if all are rejected, the last one ('top right') is returned anyway.

    Returns:
        tuple: ``(text_x, text_y, background_x1, background_y1, background_x2,
        background_y2)``.
    """

    def is_rejected(bg_x1, bg_y1, bg_x2, bg_y2):
        hits_detection = False
        for i in range(len(detections)):
            other = detections.xyxy[i].astype(int)
            if IoU([bg_x1, bg_y1, bg_x2, bg_y2], other) > 0.3:
                hits_detection = True
                break
        # Also reject a label that leaves the image.
        off_image = bg_x1 < 0 or bg_x2 > image_size[0] or bg_y1 < 0 or bg_y2 > image_size[1]
        return hits_detection or bool(off_image)

    # Each entry: (text_x, text_y, bg_x1, bg_y1, bg_x2, bg_y2)
    placements = (
        # top left: above the box, aligned with its left edge
        (
            x1 + text_padding,
            y1 - text_padding,
            x1,
            y1 - 2 * text_padding - text_height,
            x1 + 2 * text_padding + text_width,
            y1,
        ),
        # outer left: left of the box, aligned with its top edge
        (
            x1 - text_padding - text_width,
            y1 + text_padding + text_height,
            x1 - 2 * text_padding - text_width,
            y1,
            x1,
            y1 + 2 * text_padding + text_height,
        ),
        # outer right: right of the box, aligned with its top edge
        (
            x2 + text_padding,
            y1 + text_padding + text_height,
            x2,
            y1,
            x2 + 2 * text_padding + text_width,
            y1 + 2 * text_padding + text_height,
        ),
        # top right: above the box, aligned with its right edge
        (
            x2 - text_padding - text_width,
            y1 - text_padding,
            x2 - 2 * text_padding - text_width,
            y1 - 2 * text_padding - text_height,
            x2,
            y1,
        ),
    )

    for placement in placements:
        if not is_rejected(*placement[2:]):
            return placement
    return placements[-1]
|
|
|
|
|
|
|
|
import re
|
|
def extract_dict_from_text(text):
    """Recover a two-entry, single-quoted dict literal embedded in ``text``.

    The first substring shaped like ``{'k1': 'v1', 'k2': 'v2'}`` is located
    with a regular expression and returned as a real ``dict``. Keys and
    values are returned as the raw strings found between the single quotes.

    Args:
        text: String to search.

    Returns:
        A dict built from the two matched key/value pairs.

    Raises:
        ValueError: If no substring of that shape is present.
    """
    # Two 'key': 'value' pairs, single-quoted, wrapped in braces; every
    # capture is lazy so the shortest workable match is taken.
    dict_regex = r"\{\s*'(?P<key1>.*?)':\s*'(?P<value1>.*?)',\s*'(?P<key2>.*?)':\s*'(?P<value2>.*?)'\s*\}"

    found = re.search(dict_regex, text)
    if not found:
        raise ValueError("No valid dictionary structure found in the text.")

    captured = found.groupdict()
    return {
        captured['key1']: captured['value1'],
        captured['key2']: captured['value2'],
    }
|
|
|
|
|
|
def get_phi3v_model_dict(model_id="microsoft/Phi-3.5-vision-instruct"):
    """Load a Phi-3-vision checkpoint and its processor.

    The model is loaded with ``device_map="cuda"`` and the
    ``flash_attention_2`` attention implementation, so a CUDA device and a
    working flash-attention install are required.

    Args:
        model_id: Hugging Face model id or local path to load. Defaults to
            the Phi-3.5-vision-instruct checkpoint.

    Returns:
        A dict with two entries: ``'model'`` (the loaded causal LM) and
        ``'processor'`` (the matching processor).
    """
    # Function-scope import, as the other model loaders in this file do.
    from transformers import AutoModelForCausalLM, AutoProcessor

    # NOTE: trust_remote_code=True runs Python code shipped with the
    # checkpoint repository; only point model_id at sources you trust.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="cuda",
        trust_remote_code=True,
        torch_dtype="auto",
        _attn_implementation='flash_attention_2',
    )
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    print('phi3v model loaded!!!')
    return {'model': model, 'processor': processor}
|
|
|
|
|
|
def call_phi3v(messages, image_base64, model_dict):
    """Run one generation with the Phi-3-vision model and return its text.

    Args:
        messages: Chat messages handed to the tokenizer's
            ``apply_chat_template`` to build the prompt string.
        image_base64: Either one base64-encoded image, or a tuple of exactly
            two base64-encoded images (the second was named
            ``dino_labled_img`` by the original author — presumably the
            annotated screenshot; confirm against callers).
        model_dict: Mapping with ``'model'`` and ``'processor'`` entries.

    Returns:
        The decoded completion for the first sequence, containing only the
        newly generated tokens (the prompt tokens are sliced off).
    """
    model = model_dict['model']
    processor = model_dict['processor']
    target_device = model.device
    prompt = processor.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # A tuple carries two images; anything else is treated as one image.
    if isinstance(image_base64, tuple):
        first_b64, second_b64 = image_base64
        encoded_images = [first_b64, second_b64]
    else:
        encoded_images = [image_base64]
    pil_images = [
        Image.open(io.BytesIO(base64.b64decode(payload)))
        for payload in encoded_images
    ]
    inputs = processor(prompt, pil_images, return_tensors="pt").to(target_device)

    generate_ids = model.generate(
        **inputs,
        eos_token_id=processor.tokenizer.eos_token_id,
        max_new_tokens=512,
        temperature=0.01,
        do_sample=False,
    )

    # Keep only what the model produced after the prompt.
    prompt_length = inputs['input_ids'].shape[1]
    completion_ids = generate_ids[:, prompt_length:]
    decoded = processor.batch_decode(
        completion_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return decoded[0]
|
|
|
|
|
|
def get_pred_phi3v(message_text, image_base64, label_coordinates, id_key='Click ID', model_dict=None):
    """Ask Phi-3-vision which labelled box to click and resolve it to a point.

    The model reply is expected to be a dict literal containing the key
    ``'Click BBox ID'``. It is parsed with ``ast.literal_eval`` first; if
    that fails (or yields something that is not a dict) the raw reply is
    handed to ``extract_dict_from_text`` as a regex fallback.

    Args:
        message_text: Instruction text appended after the image tag(s) in
            the single user message.
        image_base64: One base64-encoded image, or a tuple of two; forwarded
            unchanged to ``call_phi3v``.
        label_coordinates: Mapping from box id (as ``str``) to a box with at
            least four numeric elements. The click point adds half of
            elements 2 and 3 to elements 0 and 1, which assumes an
            (x, y, width, height) layout — TODO confirm against the caller.
        id_key: Currently unused; the reply key is hard-coded to
            ``'Click BBox ID'``. NOTE(review): the default ``'Click ID'``
            differs from the key actually read — confirm before wiring
            this parameter through.
        model_dict: Forwarded to ``call_phi3v``.

    Returns:
        A tuple ``(icon_id, bbox, click_point, response)`` where ``response``
        is the parsed reply dict.

    Raises:
        ValueError: If the reply is neither a dict literal nor matched by
            the regex fallback.
        KeyError: If the parsed reply lacks ``'Click BBox ID'`` or the id is
            not present in ``label_coordinates``.
    """
    # No system message is sent: the previous code built one and then reset
    # ``messages`` to an empty list before use, so only the user turn ever
    # reached the model. That behaviour is kept.
    if isinstance(image_base64, tuple):
        image_tags = '<|image_1|>\n' + '<|image_2|>\n'
    else:
        image_tags = '<|image_1|>\n'
    messages = [{"role": "user", "content": image_tags + message_text}]

    raw_response = call_phi3v(messages, image_base64, model_dict)
    print(raw_response)

    # Keep the raw text separate from the parsed value so the regex fallback
    # always receives a string.
    try:
        candidate = ast.literal_eval(raw_response)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # These are the exceptions literal_eval documents for malformed input.
        candidate = None

    if isinstance(candidate, dict):
        parsed = candidate
    else:
        print('error parsing, use regex to parse!!!')
        parsed = extract_dict_from_text(raw_response)

    icon_id = parsed['Click BBox ID']
    bbox = label_coordinates[str(icon_id)]
    click_point = [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2]
    return icon_id, bbox, click_point, parsed
|
|
|
|
# try:
|
|
# match = re.search(r"```(.*?)```", ans, re.DOTALL)
|
|
# if match:
|
|
# result = match.group(1).strip()
|
|
# pred = result.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '')
|
|
# pred = ast.literal_eval(pred)
|
|
# else:
|
|
# pred = ans.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '')
|
|
# pred = ast.literal_eval(pred)
|
|
|
|
# if pred[id_key]:
|
|
# icon_id = pred[id_key]
|
|
# bbox = label_coordinates[str(icon_id)]
|
|
# pred['click_point'] = [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2]
|
|
# except:
|
|
# print('phi3v action regex extract fail!!!')
|
|
# pred = {'action_type': 'CLICK', 'click_point': [0, 0], 'value': 'None', 'is_completed': False}
|
|
|
|
# step_pred_summary = None
|
|
# return pred, [True, ans, None, step_pred_summary] |