align main utils with demo utils
This commit is contained in:
59
utils.py
59
utils.py
@@ -41,6 +41,7 @@ import re
|
||||
from torchvision.transforms import ToPILImage
|
||||
import supervision as sv
|
||||
import torchvision.transforms as T
|
||||
from util.box_annotator import BoxAnnotator
|
||||
|
||||
|
||||
def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
|
||||
@@ -77,22 +78,21 @@ def get_yolo_model(model_path):
|
||||
@torch.inference_mode()
|
||||
def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=None):
|
||||
# Number of samples per batch, --> 256 roughly takes 23 GB of GPU memory for florence model
|
||||
|
||||
to_pil = ToPILImage()
|
||||
if starting_idx:
|
||||
non_ocr_boxes = filtered_boxes[starting_idx:]
|
||||
else:
|
||||
non_ocr_boxes = filtered_boxes
|
||||
croped_pil_image = []
|
||||
t0 = time.time()
|
||||
for i, coord in enumerate(non_ocr_boxes):
|
||||
xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
|
||||
ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
|
||||
cropped_image = image_source[ymin:ymax, xmin:xmax, :]
|
||||
# resize the image to 224x224 to avoid long overhead in clipimageprocessor # TODO
|
||||
cropped_image = cv2.resize(cropped_image, (224, 224))
|
||||
croped_pil_image.append(to_pil(cropped_image))
|
||||
print('time to prepare bbox:', time.time()-t0)
|
||||
try:
|
||||
xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
|
||||
ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
|
||||
cropped_image = image_source[ymin:ymax, xmin:xmax, :]
|
||||
cropped_image = cv2.resize(cropped_image, (64, 64))
|
||||
croped_pil_image.append(to_pil(cropped_image))
|
||||
except:
|
||||
continue
|
||||
|
||||
model, processor = caption_model_processor['model'], caption_model_processor['processor']
|
||||
if not prompt:
|
||||
@@ -112,14 +112,10 @@ def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_
|
||||
inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16)
|
||||
else:
|
||||
inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
|
||||
t2 = time.time()
|
||||
print('time to process image + tokenize text inputs:', t2-t1)
|
||||
if 'florence' in model.config.name_or_path:
|
||||
generated_ids = model.generate(input_ids=inputs["input_ids"],pixel_values=inputs["pixel_values"],max_new_tokens=20,num_beams=1, do_sample=False)
|
||||
else:
|
||||
generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1) # temperature=0.01, do_sample=True,
|
||||
t3 = time.time()
|
||||
print('time to generate:', t3-t2)
|
||||
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
generated_text = [gen.strip() for gen in generated_text]
|
||||
generated_texts.extend(generated_text)
|
||||
@@ -282,10 +278,10 @@ def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
|
||||
is_valid_box = False
|
||||
break
|
||||
if is_valid_box:
|
||||
# add the following 2 lines to include ocr bbox
|
||||
if ocr_bbox:
|
||||
# keep yolo boxes + prioritize ocr label
|
||||
box_added = False
|
||||
ocr_labels = ''
|
||||
for box3_elem in ocr_bbox:
|
||||
if not box_added:
|
||||
box3 = box3_elem['bbox']
|
||||
@@ -293,25 +289,22 @@ def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
|
||||
# box_added = True
|
||||
# delete the box3_elem from ocr_bbox
|
||||
try:
|
||||
filtered_boxes.append({'type': 'text', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': box3_elem['content']})
|
||||
# gather all ocr labels
|
||||
ocr_labels += box3_elem['content'] + ' '
|
||||
filtered_boxes.remove(box3_elem)
|
||||
# print('remove ocr bbox:', box3_elem)
|
||||
except:
|
||||
continue
|
||||
# break
|
||||
elif is_inside(box1, box3): # icon inside ocr
|
||||
elif is_inside(box1, box3): # icon inside ocr, don't added this icon box, no need to check other ocr bbox bc no overlap between ocr bbox, icon can only be in one ocr box
|
||||
box_added = True
|
||||
# try:
|
||||
# filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None})
|
||||
# filtered_boxes.remove(box3_elem)
|
||||
# except:
|
||||
# continue
|
||||
break
|
||||
else:
|
||||
continue
|
||||
if not box_added:
|
||||
filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None})
|
||||
|
||||
if ocr_labels:
|
||||
filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': ocr_labels, 'source':'box_yolo_content_ocr'})
|
||||
else:
|
||||
filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None, 'source':'box_yolo_content_yolo'})
|
||||
else:
|
||||
filtered_boxes.append(box1)
|
||||
return filtered_boxes # torch.tensor(filtered_boxes)
|
||||
@@ -354,7 +347,6 @@ def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor
|
||||
|
||||
labels = [f"{phrase}" for phrase in range(boxes.shape[0])]
|
||||
|
||||
from util.box_annotator import BoxAnnotator
|
||||
box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,text_thickness=text_thickness,thickness=thickness) # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
|
||||
annotated_frame = image_source.copy()
|
||||
annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w,h))
|
||||
@@ -407,7 +399,12 @@ def predict_yolo(model, image, box_threshold, imgsz, scale_img, iou_threshold=0.
|
||||
|
||||
return boxes, conf, phrases
|
||||
|
||||
|
||||
def int_box_area(box, w, h):
    """Return the integer pixel area of a ratio-coordinate box.

    Args:
        box: Four values ``(x1, y1, x2, y2)``; each is multiplied by ``w``
            (x values) or ``h`` (y values) and truncated with ``int``.
        w: Horizontal scale factor applied to ``x1`` and ``x2``.
        h: Vertical scale factor applied to ``y1`` and ``y2``.

    Returns:
        The product of the truncated width and height, never negative.
        A box whose truncated width or height is zero or negative
        (collapsed or inverted) yields ``0``.
    """
    x1, y1, x2, y2 = box
    left, top = int(x1 * w), int(y1 * h)
    right, bottom = int(x2 * w), int(y2 * h)
    # Clamp each side at zero: without this, a box inverted on BOTH axes
    # multiplies two negative lengths into a positive "area" and would pass
    # an ``area > 0`` validity check despite being degenerate.
    width = max(0, right - left)
    height = max(0, bottom - top)
    return width * height
|
||||
|
||||
def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=64):
|
||||
"""Process either an image path or Image object
|
||||
|
||||
@@ -428,19 +425,15 @@ def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_T
|
||||
phrases = [str(i) for i in range(len(phrases))]
|
||||
|
||||
# annotate the image with labels
|
||||
h, w, _ = image_source.shape
|
||||
if ocr_bbox:
|
||||
ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
|
||||
ocr_bbox=ocr_bbox.tolist()
|
||||
else:
|
||||
print('no ocr bbox!!!')
|
||||
ocr_bbox = None
|
||||
# filtered_boxes = remove_overlap(boxes=xyxy, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox)
|
||||
# starting_idx = len(ocr_bbox)
|
||||
# print('len(filtered_boxes):', len(filtered_boxes), starting_idx)
|
||||
|
||||
ocr_bbox_elem = [{'type': 'text', 'bbox':box, 'interactivity':False, 'content':txt} for box, txt in zip(ocr_bbox, ocr_text)]
|
||||
xyxy_elem = [{'type': 'icon', 'bbox':box, 'interactivity':True, 'content':None} for box in xyxy.tolist()]
|
||||
ocr_bbox_elem = [{'type': 'text', 'bbox':box, 'interactivity':False, 'content':txt, 'source': 'box_ocr_content_ocr'} for box, txt in zip(ocr_bbox, ocr_text) if int_box_area(box, w, h) > 0]
|
||||
xyxy_elem = [{'type': 'icon', 'bbox':box, 'interactivity':True, 'content':None} for box in xyxy.tolist() if int_box_area(box, w, h) > 0]
|
||||
filtered_boxes = remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)
|
||||
|
||||
# sort the filtered_boxes so that the one with 'content': None is at the end, and get the index of the first 'content': None
|
||||
@@ -450,7 +443,6 @@ def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_T
|
||||
filtered_boxes = torch.tensor([box['bbox'] for box in filtered_boxes_elem])
|
||||
print('len(filtered_boxes):', len(filtered_boxes), starting_idx)
|
||||
|
||||
|
||||
# get parsed icon local semantics
|
||||
time1 = time.time()
|
||||
if use_local_semantics:
|
||||
@@ -489,7 +481,6 @@ def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_T
|
||||
pil_img.save(buffered, format="PNG")
|
||||
encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
|
||||
if output_coord_in_ratio:
|
||||
# h, w, _ = image_source.shape
|
||||
label_coordinates = {k: [v[0]/w, v[1]/h, v[2]/w, v[3]/h] for k, v in label_coordinates.items()}
|
||||
assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user