remove need to write to disk

This commit is contained in:
Thomas Dhome Casanova (from Dev Box)
2025-01-20 18:29:46 -08:00
parent 85f5fc0385
commit 6cb310d124
2 changed files with 29 additions and 20 deletions

View File

@@ -4,11 +4,12 @@ import sys
import os import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_yolo_model from utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, get_ocr_bbox
import torch import torch
from PIL import Image from PIL import Image
from typing import Dict, Tuple, List from typing import Dict, Tuple, List
import base64 import base64
import io
config = { config = {
@@ -30,12 +31,9 @@ class Omniparser(object):
print('Omniparser initialized!!!') print('Omniparser initialized!!!')
def parse(self, image_base64: str): def parse(self, image_base64: str):
image_path = '../imgs/demo_image.jpg' # Convert base64 to image directly without saving to disk
with open(image_path, "wb") as fh: image_bytes = base64.b64decode(image_base64)
fh.write(base64.b64decode(image_base64)) image = Image.open(io.BytesIO(image_bytes))
print('Parsing image:', image_path)
image = Image.open(image_path)
print('image size:', image.size) print('image size:', image.size)
box_overlay_ratio = max(image.size) / 3200 box_overlay_ratio = max(image.size) / 3200
@@ -47,11 +45,8 @@ class Omniparser(object):
} }
BOX_TRESHOLD = config['BOX_TRESHOLD'] BOX_TRESHOLD = config['BOX_TRESHOLD']
ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.8}, use_paddleocr=True) text, ocr_bbox = get_ocr_bbox(image)
text, ocr_bbox = ocr_bbox_rslt dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image, self.som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text,use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128)
dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, self.som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text,use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128)
with open('../imgs/demo_image_som.jpg', "wb") as fh:
fh.write(base64.b64decode(dino_labled_img))
return dino_labled_img, parsed_content_list return dino_labled_img, parsed_content_list

View File

@@ -35,7 +35,7 @@ import base64
import os import os
import ast import ast
import torch import torch
from typing import Tuple, List from typing import Tuple, List, Union
from torchvision.ops import box_convert from torchvision.ops import box_convert
import re import re
from torchvision.transforms import ToPILImage from torchvision.transforms import ToPILImage
@@ -384,20 +384,20 @@ def predict(model, image, caption, box_threshold, text_threshold):
return boxes, logits, phrases return boxes, logits, phrases
def predict_yolo(model, image_path, box_threshold, imgsz, scale_img, iou_threshold=0.7): def predict_yolo(model, image, box_threshold, imgsz, scale_img, iou_threshold=0.7):
""" Use huggingface model to replace the original model """ Use huggingface model to replace the original model
""" """
# model = model['model'] # model = model['model']
if scale_img: if scale_img:
result = model.predict( result = model.predict(
source=image_path, source=image,
conf=box_threshold, conf=box_threshold,
imgsz=imgsz, imgsz=imgsz,
iou=iou_threshold, # default 0.7 iou=iou_threshold, # default 0.7
) )
else: else:
result = model.predict( result = model.predict(
source=image_path, source=image,
conf=box_threshold, conf=box_threshold,
iou=iou_threshold, # default 0.7 iou=iou_threshold, # default 0.7
) )
@@ -408,15 +408,21 @@ def predict_yolo(model, image_path, box_threshold, imgsz, scale_img, iou_thresho
return boxes, conf, phrases return boxes, conf, phrases
def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD = 0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=64): def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=64):
""" ocr_bbox: list of xyxy format bbox """Process either an image path or Image object
Args:
image_source: Either a file path (str) or PIL Image object
...
""" """
image_source = Image.open(img_path).convert("RGB") if isinstance(image_source, str):
image_source = Image.open(image_source).convert("RGB")
w, h = image_source.size w, h = image_source.size
if not imgsz: if not imgsz:
imgsz = (h, w) imgsz = (h, w)
# print('image size:', w, h) # print('image size:', w, h)
xyxy, logits, phrases = predict_yolo(model=model, image_path=img_path, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1) xyxy, logits, phrases = predict_yolo(model=model, image=image_source, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1)
xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device) xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
image_source = np.asarray(image_source) image_source = np.asarray(image_source)
phrases = [str(i) for i in range(len(phrases))] phrases = [str(i) for i in range(len(phrases))]
@@ -545,5 +551,13 @@ def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_
# print('bounding box!!!', bb) # print('bounding box!!!', bb)
return (text, bb), goal_filtering return (text, bb), goal_filtering
def get_ocr_bbox(image):
    """Run PaddleOCR on an in-memory image and keep the confident detections.

    Args:
        image: The image to scan. A NumPy array, path string, bytes object or
            list is handed to the OCR engine unchanged. Anything else (in
            practice a PIL image) is converted to an RGB NumPy array first.

    Returns:
        A ``(text, bb)`` tuple of two parallel lists: the recognised strings
        whose confidence is above 0.8, and the ``get_xyxy`` result for each
        of those detections. Both lists are empty when the engine finds no
        text at all.
    """
    text_threshold = 0.8

    # The OCR engine only takes ndarray / list / str / bytes input, so a PIL
    # image has to be turned into an array before the call.
    if not isinstance(image, (np.ndarray, str, bytes, list)):
        # Palette / alpha / CMYK images would otherwise yield arrays whose
        # channels do not mean what the detector expects.
        if hasattr(image, "convert") and getattr(image, "mode", "RGB") != "RGB":
            image = image.convert("RGB")
        image = np.asarray(image)

    pages = paddle_ocr.ocr(image, cls=False)
    # The engine reports "nothing detected" as a None page; treat that (and an
    # empty outer list) as zero detections instead of crashing on iteration.
    result = pages[0] if pages else None
    if not result:
        return [], []

    # Filter once so the text and box lists are guaranteed to stay aligned.
    confident = [item for item in result if item[1][1] > text_threshold]
    text = [item[1][0] for item in confident]
    bb = [get_xyxy(item[0]) for item in confident]
    return text, bb