diff --git a/util/utils.py b/util/utils.py
index 62278cb..eb7c8b2 100644
--- a/util/utils.py
+++ b/util/utils.py
@@ -76,8 +76,8 @@ def get_yolo_model(model_path):
 
 
 @torch.inference_mode()
-def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=None):
-    # Number of samples per batch, --> 256 roughly takes 23 GB of GPU memory for florence model
+def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=128):
+    # Number of samples per batch, --> 128 roughly takes 4 GB of GPU memory for florence v2 model
     to_pil = ToPILImage()
     if starting_idx:
         non_ocr_boxes = filtered_boxes[starting_idx:]
@@ -103,7 +103,6 @@ def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_
 
     generated_texts = []
     device = model.device
-    # batch_size = 64
     for i in range(0, len(croped_pil_image), batch_size):
         start = time.time()
         batch = croped_pil_image[i:i+batch_size]
@@ -405,7 +404,7 @@ def int_box_area(box, w, h):
     area = (int_box[2] - int_box[0]) * (int_box[3] - int_box[1])
     return area
 
-def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=64):
+def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=128):
     """Process either an image path or Image object
 
     Args:
@@ -413,8 +412,8 @@ def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_T
         ...
     """
     if isinstance(image_source, str):
-        image_source = Image.open(image_source).convert("RGB")
-
+        image_source = Image.open(image_source)
+        image_source = image_source.convert("RGB") # for CLIP
     w, h = image_source.size
     if not imgsz:
         imgsz = (h, w)
@@ -538,6 +537,4 @@ def check_ocr_box(image_source: Union[str, Image.Image], display_img = True, out
         bb = [get_xywh(item) for item in coord]
     elif output_bb_format == 'xyxy':
         bb = [get_xyxy(item) for item in coord]
-    return (text, bb), goal_filtering
-
-
+    return (text, bb), goal_filtering
\ No newline at end of file
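
For reviewers, a minimal usage sketch (not part of the patch) showing what the new default means for callers. The weight paths are placeholders, and get_caption_model_processor is assumed to exist in the same module; only get_yolo_model and the get_som_labeled_img parameters visible in this diff are taken as given.

    from util.utils import get_yolo_model, get_caption_model_processor, get_som_labeled_img

    # Placeholder weight paths; get_caption_model_processor is not shown
    # in this diff and is assumed from the rest of util/utils.py.
    som_model = get_yolo_model("weights/icon_detect/model.pt")
    caption_model_processor = get_caption_model_processor(
        model_name="florence2", model_name_or_path="weights/icon_caption_florence"
    )

    # batch_size now defaults to 128 (~4 GB of GPU memory with the Florence-2
    # captioner, per the updated comment); pass a smaller value on
    # memory-constrained GPUs.
    result = get_som_labeled_img(
        "screenshot.png",            # str path: opened, then converted to RGB internally
        model=som_model,
        caption_model_processor=caption_model_processor,
        BOX_TRESHOLD=0.05,
        batch_size=32,               # explicit override of the new 128 default
    )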