From 6cb310d12485377e860286de6d763910882e1024 Mon Sep 17 00:00:00 2001 From: "Thomas Dhome Casanova (from Dev Box)" Date: Mon, 20 Jan 2025 18:29:46 -0800 Subject: [PATCH] remove need to write to disk --- demo/remote_request.py | 19 +++++++------------ utils.py | 30 ++++++++++++++++++++++-------- 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/demo/remote_request.py b/demo/remote_request.py index d0bc1c3..32c1730 100644 --- a/demo/remote_request.py +++ b/demo/remote_request.py @@ -4,11 +4,12 @@ import sys import os sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_yolo_model +from utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, get_ocr_bbox import torch from PIL import Image from typing import Dict, Tuple, List import base64 +import io config = { @@ -30,12 +31,9 @@ class Omniparser(object): print('Omniparser initialized!!!') def parse(self, image_base64: str): - image_path = '../imgs/demo_image.jpg' - with open(image_path, "wb") as fh: - fh.write(base64.b64decode(image_base64)) - print('Parsing image:', image_path) - - image = Image.open(image_path) + # Convert base64 to image directly without saving to disk + image_bytes = base64.b64decode(image_base64) + image = Image.open(io.BytesIO(image_bytes)) print('image size:', image.size) box_overlay_ratio = max(image.size) / 3200 @@ -47,11 +45,8 @@ class Omniparser(object): } BOX_TRESHOLD = config['BOX_TRESHOLD'] - ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.8}, use_paddleocr=True) - text, ocr_bbox = ocr_bbox_rslt - dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, self.som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text,use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128) - with open('../imgs/demo_image_som.jpg', "wb") as fh: - fh.write(base64.b64decode(dino_labled_img)) + text, ocr_bbox = get_ocr_bbox(image) + dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image, self.som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text,use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128) return dino_labled_img, parsed_content_list diff --git a/utils.py b/utils.py index c4b6f88..6ddf88b 100755 --- a/utils.py +++ b/utils.py @@ -35,7 +35,7 @@ import base64 import os import ast import torch -from typing import Tuple, List +from typing import Tuple, List, Union from torchvision.ops import box_convert import re from torchvision.transforms import ToPILImage @@ -384,20 +384,20 @@ def predict(model, image, caption, box_threshold, text_threshold): return boxes, logits, phrases -def predict_yolo(model, image_path, box_threshold, imgsz, scale_img, iou_threshold=0.7): +def predict_yolo(model, image, box_threshold, imgsz, scale_img, iou_threshold=0.7): """ Use huggingface model to replace the original model """ # model = model['model'] if scale_img: result = model.predict( - source=image_path, + source=image, conf=box_threshold, imgsz=imgsz, iou=iou_threshold, # default 0.7 ) else: result = model.predict( - source=image_path, + source=image, conf=box_threshold, iou=iou_threshold, # default 0.7 ) @@ -408,15 +408,21 @@ def predict_yolo(model, image_path, box_threshold, imgsz, scale_img, iou_thresho return boxes, conf, phrases -def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD = 0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=64): - """ ocr_bbox: list of xyxy format bbox +def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=64): + """Process either an image path or Image object + + Args: + image_source: Either a file path (str) or PIL Image object + ... """ - image_source = Image.open(img_path).convert("RGB") + if isinstance(image_source, str): + image_source = Image.open(image_source).convert("RGB") + w, h = image_source.size if not imgsz: imgsz = (h, w) # print('image size:', w, h) - xyxy, logits, phrases = predict_yolo(model=model, image_path=img_path, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1) + xyxy, logits, phrases = predict_yolo(model=model, image=image_source, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1) xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device) image_source = np.asarray(image_source) phrases = [str(i) for i in range(len(phrases))] @@ -545,5 +551,13 @@ def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_ # print('bounding box!!!', bb) return (text, bb), goal_filtering +def get_ocr_bbox(image): + text_threshold = 0.8 + result = paddle_ocr.ocr(image, cls=False)[0] + coord = [item[0] for item in result if item[1][1] > text_threshold] + text = [item[1][0] for item in result if item[1][1] > text_threshold] + bb = [get_xyxy(item) for item in coord] + return text, bb +