align main utils with demo utils

This commit is contained in:
Thomas Dhome-Casanova
2025-01-27 22:08:54 -08:00
parent 5cf55e116f
commit 9cb2263545

View File

@@ -41,6 +41,7 @@ import re
from torchvision.transforms import ToPILImage
import supervision as sv
import torchvision.transforms as T
from util.box_annotator import BoxAnnotator
def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
@@ -77,22 +78,21 @@ def get_yolo_model(model_path):
@torch.inference_mode()
def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=None):
# Number of samples per batch, --> 256 roughly takes 23 GB of GPU memory for florence model
to_pil = ToPILImage()
if starting_idx:
non_ocr_boxes = filtered_boxes[starting_idx:]
else:
non_ocr_boxes = filtered_boxes
croped_pil_image = []
t0 = time.time()
for i, coord in enumerate(non_ocr_boxes):
xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
cropped_image = image_source[ymin:ymax, xmin:xmax, :]
# resize the image to 224x224 to avoid long overhead in clipimageprocessor # TODO
cropped_image = cv2.resize(cropped_image, (224, 224))
croped_pil_image.append(to_pil(cropped_image))
print('time to prepare bbox:', time.time()-t0)
try:
xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
cropped_image = image_source[ymin:ymax, xmin:xmax, :]
cropped_image = cv2.resize(cropped_image, (64, 64))
croped_pil_image.append(to_pil(cropped_image))
except:
continue
model, processor = caption_model_processor['model'], caption_model_processor['processor']
if not prompt:
@@ -112,14 +112,10 @@ def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_
inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16)
else:
inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
t2 = time.time()
print('time to process image + tokenize text inputs:', t2-t1)
if 'florence' in model.config.name_or_path:
generated_ids = model.generate(input_ids=inputs["input_ids"],pixel_values=inputs["pixel_values"],max_new_tokens=20,num_beams=1, do_sample=False)
else:
generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1) # temperature=0.01, do_sample=True,
t3 = time.time()
print('time to generate:', t3-t2)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
generated_text = [gen.strip() for gen in generated_text]
generated_texts.extend(generated_text)
@@ -282,10 +278,10 @@ def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
is_valid_box = False
break
if is_valid_box:
# add the following 2 lines to include ocr bbox
if ocr_bbox:
# keep yolo boxes + prioritize ocr label
box_added = False
ocr_labels = ''
for box3_elem in ocr_bbox:
if not box_added:
box3 = box3_elem['bbox']
@@ -293,25 +289,22 @@ def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
# box_added = True
# delete the box3_elem from ocr_bbox
try:
filtered_boxes.append({'type': 'text', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': box3_elem['content']})
# gather all ocr labels
ocr_labels += box3_elem['content'] + ' '
filtered_boxes.remove(box3_elem)
# print('remove ocr bbox:', box3_elem)
except:
continue
# break
elif is_inside(box1, box3): # icon inside ocr
elif is_inside(box1, box3): # icon inside ocr, don't added this icon box, no need to check other ocr bbox bc no overlap between ocr bbox, icon can only be in one ocr box
box_added = True
# try:
# filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None})
# filtered_boxes.remove(box3_elem)
# except:
# continue
break
else:
continue
if not box_added:
filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None})
if ocr_labels:
filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': ocr_labels, 'source':'box_yolo_content_ocr'})
else:
filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None, 'source':'box_yolo_content_yolo'})
else:
filtered_boxes.append(box1)
return filtered_boxes # torch.tensor(filtered_boxes)
@@ -354,7 +347,6 @@ def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor
labels = [f"{phrase}" for phrase in range(boxes.shape[0])]
from util.box_annotator import BoxAnnotator
box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,text_thickness=text_thickness,thickness=thickness) # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
annotated_frame = image_source.copy()
annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w,h))
@@ -407,7 +399,12 @@ def predict_yolo(model, image, box_threshold, imgsz, scale_img, iou_threshold=0.
return boxes, conf, phrases
def int_box_area(box, w, h):
x1, y1, x2, y2 = box
int_box = [int(x1*w), int(y1*h), int(x2*w), int(y2*h)]
area = (int_box[2] - int_box[0]) * (int_box[3] - int_box[1])
return area
def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=64):
"""Process either an image path or Image object
@@ -428,19 +425,15 @@ def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_T
phrases = [str(i) for i in range(len(phrases))]
# annotate the image with labels
h, w, _ = image_source.shape
if ocr_bbox:
ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
ocr_bbox=ocr_bbox.tolist()
else:
print('no ocr bbox!!!')
ocr_bbox = None
# filtered_boxes = remove_overlap(boxes=xyxy, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox)
# starting_idx = len(ocr_bbox)
# print('len(filtered_boxes):', len(filtered_boxes), starting_idx)
ocr_bbox_elem = [{'type': 'text', 'bbox':box, 'interactivity':False, 'content':txt} for box, txt in zip(ocr_bbox, ocr_text)]
xyxy_elem = [{'type': 'icon', 'bbox':box, 'interactivity':True, 'content':None} for box in xyxy.tolist()]
ocr_bbox_elem = [{'type': 'text', 'bbox':box, 'interactivity':False, 'content':txt, 'source': 'box_ocr_content_ocr'} for box, txt in zip(ocr_bbox, ocr_text) if int_box_area(box, w, h) > 0]
xyxy_elem = [{'type': 'icon', 'bbox':box, 'interactivity':True, 'content':None} for box in xyxy.tolist() if int_box_area(box, w, h) > 0]
filtered_boxes = remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)
# sort the filtered_boxes so that the one with 'content': None is at the end, and get the index of the first 'content': None
@@ -450,7 +443,6 @@ def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_T
filtered_boxes = torch.tensor([box['bbox'] for box in filtered_boxes_elem])
print('len(filtered_boxes):', len(filtered_boxes), starting_idx)
# get parsed icon local semantics
time1 = time.time()
if use_local_semantics:
@@ -489,7 +481,6 @@ def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_T
pil_img.save(buffered, format="PNG")
encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
if output_coord_in_ratio:
# h, w, _ = image_source.shape
label_coordinates = {k: [v[0]/w, v[1]/h, v[2]/w, v[3]/h] for k, v in label_coordinates.items()}
assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]