diff --git a/utils.py b/utils.py
index 6ddf88b..a4e3f00 100755
--- a/utils.py
+++ b/utils.py
@@ -41,6 +41,7 @@ import re
 from torchvision.transforms import ToPILImage
 import supervision as sv
 import torchvision.transforms as T
+from util.box_annotator import BoxAnnotator
 
 
 def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
@@ -77,22 +78,21 @@ def get_yolo_model(model_path):
 
 @torch.inference_mode()
 def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=None):
     # Number of samples per batch, --> 256 roughly takes 23 GB of GPU memory for florence model
-    to_pil = ToPILImage()
     if starting_idx:
         non_ocr_boxes = filtered_boxes[starting_idx:]
     else:
         non_ocr_boxes = filtered_boxes
     croped_pil_image = []
-    t0 = time.time()
     for i, coord in enumerate(non_ocr_boxes):
-        xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
-        ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
-        cropped_image = image_source[ymin:ymax, xmin:xmax, :]
-        # resize the image to 224x224 to avoid long overhead in clipimageprocessor # TODO
-        cropped_image = cv2.resize(cropped_image, (224, 224))
-        croped_pil_image.append(to_pil(cropped_image))
-        print('time to prepare bbox:', time.time()-t0)
+        try:
+            xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
+            ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
+            cropped_image = image_source[ymin:ymax, xmin:xmax, :]
+            cropped_image = cv2.resize(cropped_image, (64, 64))
+            croped_pil_image.append(to_pil(cropped_image))
+        except:
+            continue
 
     model, processor = caption_model_processor['model'], caption_model_processor['processor']
     if not prompt:
@@ -112,14 +112,10 @@ def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_
             inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16)
         else:
             inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
-        t2 = time.time()
-        print('time to process image + tokenize text inputs:', t2-t1)
         if 'florence' in model.config.name_or_path:
             generated_ids = model.generate(input_ids=inputs["input_ids"],pixel_values=inputs["pixel_values"],max_new_tokens=20,num_beams=1, do_sample=False)
         else:
             generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1) # temperature=0.01, do_sample=True,
-        t3 = time.time()
-        print('time to generate:', t3-t2)
         generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
         generated_text = [gen.strip() for gen in generated_text]
         generated_texts.extend(generated_text)
@@ -282,10 +278,10 @@ def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
                     is_valid_box = False
                     break
         if is_valid_box:
-            # add the following 2 lines to include ocr bbox
             if ocr_bbox:
                 # keep yolo boxes + prioritize ocr label
                 box_added = False
+                ocr_labels = ''
                 for box3_elem in ocr_bbox:
                     if not box_added:
                         box3 = box3_elem['bbox']
@@ -293,25 +289,22 @@
                             # box_added = True
                             # delete the box3_elem from ocr_bbox
                             try:
-                                filtered_boxes.append({'type': 'text', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': box3_elem['content']})
+                                # gather all ocr labels
+                                ocr_labels += box3_elem['content'] + ' '
                                 filtered_boxes.remove(box3_elem)
-                                # print('remove ocr bbox:', box3_elem)
                             except:
                                 continue
                             # break
-                        elif is_inside(box1, box3): # icon inside ocr
+                        elif is_inside(box1, box3): # icon inside ocr, don't added this icon box, no need to check other ocr bbox bc no overlap between ocr bbox, icon can only be in one ocr box
                             box_added = True
-                            # try:
-                            #     filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None})
-                            #     filtered_boxes.remove(box3_elem)
-                            # except:
-                            #     continue
                             break
                         else:
                             continue
                 if not box_added:
-                    filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None})
-
+                    if ocr_labels:
+                        filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': ocr_labels, 'source':'box_yolo_content_ocr'})
+                    else:
+                        filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None, 'source':'box_yolo_content_yolo'})
             else:
                 filtered_boxes.append(box1)
     return filtered_boxes # torch.tensor(filtered_boxes)
@@ -354,7 +347,6 @@ def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor
     labels = [f"{phrase}" for phrase in range(boxes.shape[0])]
 
-    from util.box_annotator import BoxAnnotator
     box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,text_thickness=text_thickness,thickness=thickness) # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
     annotated_frame = image_source.copy()
     annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w,h))
@@ -407,7 +399,12 @@ def predict_yolo(model, image, box_threshold, imgsz, scale_img, iou_threshold=0.
     return boxes, conf, phrases
 
-
+def int_box_area(box, w, h):
+    x1, y1, x2, y2 = box
+    int_box = [int(x1*w), int(y1*h), int(x2*w), int(y2*h)]
+    area = (int_box[2] - int_box[0]) * (int_box[3] - int_box[1])
+    return area
+
 def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=64):
     """Process either an image path or Image object
@@ -428,19 +425,15 @@ def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_T
     phrases = [str(i) for i in range(len(phrases))]
     # annotate the image with labels
-    h, w, _ = image_source.shape
     if ocr_bbox:
         ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
         ocr_bbox=ocr_bbox.tolist()
     else:
         print('no ocr bbox!!!')
         ocr_bbox = None
-    # filtered_boxes = remove_overlap(boxes=xyxy, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox)
-    # starting_idx = len(ocr_bbox)
-    # print('len(filtered_boxes):', len(filtered_boxes), starting_idx)
-    ocr_bbox_elem = [{'type': 'text', 'bbox':box, 'interactivity':False, 'content':txt} for box, txt in zip(ocr_bbox, ocr_text)]
-    xyxy_elem = [{'type': 'icon', 'bbox':box, 'interactivity':True, 'content':None} for box in xyxy.tolist()]
+    ocr_bbox_elem = [{'type': 'text', 'bbox':box, 'interactivity':False, 'content':txt, 'source': 'box_ocr_content_ocr'} for box, txt in zip(ocr_bbox, ocr_text) if int_box_area(box, w, h) > 0]
+    xyxy_elem = [{'type': 'icon', 'bbox':box, 'interactivity':True, 'content':None} for box in xyxy.tolist() if int_box_area(box, w, h) > 0]
     filtered_boxes = remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)
     # sort the filtered_boxes so that the one with 'content': None is at the end, and get the index of the first 'content': None
@@ -450,7 +443,6 @@ def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_T
     filtered_boxes = torch.tensor([box['bbox'] for box in filtered_boxes_elem])
     print('len(filtered_boxes):', len(filtered_boxes), starting_idx)
-
     # get parsed icon local semantics
     time1 = time.time()
     if use_local_semantics:
@@ -489,7 +481,6 @@ def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_T
     pil_img.save(buffered, format="PNG")
     encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
     if output_coord_in_ratio:
-        # h, w, _ = image_source.shape
        label_coordinates = {k: [v[0]/w, v[1]/h, v[2]/w, v[3]/h] for k, v in label_coordinates.items()}
        assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]
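
Note (not part of the patch): a minimal standalone sketch of how the new int_box_area helper behaves. The screenshot size and boxes below are made-up values for illustration only. Boxes are normalized xyxy coordinates, and get_som_labeled_img now keeps only elements with int_box_area(box, w, h) > 0, so OCR or YOLO boxes that collapse to zero pixels after scaling and truncation are dropped before remove_overlap_new runs.

def int_box_area(box, w, h):
    # same helper as added in the patch above: scale a normalized xyxy box
    # to integer pixel coordinates and return its pixel area
    x1, y1, x2, y2 = box
    int_box = [int(x1*w), int(y1*h), int(x2*w), int(y2*h)]
    area = (int_box[2] - int_box[0]) * (int_box[3] - int_box[1])
    return area

w, h = 1920, 1080  # hypothetical screenshot size
print(int_box_area([0.10, 0.10, 0.20, 0.15], w, h))      # 10368 -> kept
print(int_box_area([0.50, 0.50, 0.5003, 0.5004], w, h))  # 0 -> filtered out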