convert rgba to rgb if passed into omniparserserver
This commit is contained in:
@@ -76,8 +76,8 @@ def get_yolo_model(model_path):
|
|||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=None):
|
def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=128):
|
||||||
# Number of samples per batch, --> 256 roughly takes 23 GB of GPU memory for florence model
|
# Number of samples per batch, --> 128 roughly takes 4 GB of GPU memory for florence v2 model
|
||||||
to_pil = ToPILImage()
|
to_pil = ToPILImage()
|
||||||
if starting_idx:
|
if starting_idx:
|
||||||
non_ocr_boxes = filtered_boxes[starting_idx:]
|
non_ocr_boxes = filtered_boxes[starting_idx:]
|
||||||
@@ -103,7 +103,6 @@ def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_
|
|||||||
|
|
||||||
generated_texts = []
|
generated_texts = []
|
||||||
device = model.device
|
device = model.device
|
||||||
# batch_size = 64
|
|
||||||
for i in range(0, len(croped_pil_image), batch_size):
|
for i in range(0, len(croped_pil_image), batch_size):
|
||||||
start = time.time()
|
start = time.time()
|
||||||
batch = croped_pil_image[i:i+batch_size]
|
batch = croped_pil_image[i:i+batch_size]
|
||||||
@@ -405,7 +404,7 @@ def int_box_area(box, w, h):
|
|||||||
area = (int_box[2] - int_box[0]) * (int_box[3] - int_box[1])
|
area = (int_box[2] - int_box[0]) * (int_box[3] - int_box[1])
|
||||||
return area
|
return area
|
||||||
|
|
||||||
def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=64):
|
def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=128):
|
||||||
"""Process either an image path or Image object
|
"""Process either an image path or Image object
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -413,8 +412,8 @@ def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_T
|
|||||||
...
|
...
|
||||||
"""
|
"""
|
||||||
if isinstance(image_source, str):
|
if isinstance(image_source, str):
|
||||||
image_source = Image.open(image_source).convert("RGB")
|
image_source = Image.open(image_source)
|
||||||
|
image_source = image_source.convert("RGB") # for CLIP
|
||||||
w, h = image_source.size
|
w, h = image_source.size
|
||||||
if not imgsz:
|
if not imgsz:
|
||||||
imgsz = (h, w)
|
imgsz = (h, w)
|
||||||
@@ -539,5 +538,3 @@ def check_ocr_box(image_source: Union[str, Image.Image], display_img = True, out
|
|||||||
elif output_bb_format == 'xyxy':
|
elif output_bb_format == 'xyxy':
|
||||||
bb = [get_xyxy(item) for item in coord]
|
bb = [get_xyxy(item) for item in coord]
|
||||||
return (text, bb), goal_filtering
|
return (text, bb), goal_filtering
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user