remove need to write to disk
This commit is contained in:
@@ -4,11 +4,12 @@ import sys
|
|||||||
import os
|
import os
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
from utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_yolo_model
|
from utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, get_ocr_bbox
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from typing import Dict, Tuple, List
|
from typing import Dict, Tuple, List
|
||||||
import base64
|
import base64
|
||||||
|
import io
|
||||||
|
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
@@ -30,12 +31,9 @@ class Omniparser(object):
|
|||||||
print('Omniparser initialized!!!')
|
print('Omniparser initialized!!!')
|
||||||
|
|
||||||
def parse(self, image_base64: str):
|
def parse(self, image_base64: str):
|
||||||
image_path = '../imgs/demo_image.jpg'
|
# Convert base64 to image directly without saving to disk
|
||||||
with open(image_path, "wb") as fh:
|
image_bytes = base64.b64decode(image_base64)
|
||||||
fh.write(base64.b64decode(image_base64))
|
image = Image.open(io.BytesIO(image_bytes))
|
||||||
print('Parsing image:', image_path)
|
|
||||||
|
|
||||||
image = Image.open(image_path)
|
|
||||||
print('image size:', image.size)
|
print('image size:', image.size)
|
||||||
|
|
||||||
box_overlay_ratio = max(image.size) / 3200
|
box_overlay_ratio = max(image.size) / 3200
|
||||||
@@ -47,11 +45,8 @@ class Omniparser(object):
|
|||||||
}
|
}
|
||||||
BOX_TRESHOLD = config['BOX_TRESHOLD']
|
BOX_TRESHOLD = config['BOX_TRESHOLD']
|
||||||
|
|
||||||
ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.8}, use_paddleocr=True)
|
text, ocr_bbox = get_ocr_bbox(image)
|
||||||
text, ocr_bbox = ocr_bbox_rslt
|
dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image, self.som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text,use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128)
|
||||||
dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, self.som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text,use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128)
|
|
||||||
with open('../imgs/demo_image_som.jpg', "wb") as fh:
|
|
||||||
fh.write(base64.b64decode(dino_labled_img))
|
|
||||||
|
|
||||||
return dino_labled_img, parsed_content_list
|
return dino_labled_img, parsed_content_list
|
||||||
|
|
||||||
|
|||||||
30
utils.py
30
utils.py
@@ -35,7 +35,7 @@ import base64
|
|||||||
import os
|
import os
|
||||||
import ast
|
import ast
|
||||||
import torch
|
import torch
|
||||||
from typing import Tuple, List
|
from typing import Tuple, List, Union
|
||||||
from torchvision.ops import box_convert
|
from torchvision.ops import box_convert
|
||||||
import re
|
import re
|
||||||
from torchvision.transforms import ToPILImage
|
from torchvision.transforms import ToPILImage
|
||||||
@@ -384,20 +384,20 @@ def predict(model, image, caption, box_threshold, text_threshold):
|
|||||||
return boxes, logits, phrases
|
return boxes, logits, phrases
|
||||||
|
|
||||||
|
|
||||||
def predict_yolo(model, image_path, box_threshold, imgsz, scale_img, iou_threshold=0.7):
|
def predict_yolo(model, image, box_threshold, imgsz, scale_img, iou_threshold=0.7):
|
||||||
""" Use huggingface model to replace the original model
|
""" Use huggingface model to replace the original model
|
||||||
"""
|
"""
|
||||||
# model = model['model']
|
# model = model['model']
|
||||||
if scale_img:
|
if scale_img:
|
||||||
result = model.predict(
|
result = model.predict(
|
||||||
source=image_path,
|
source=image,
|
||||||
conf=box_threshold,
|
conf=box_threshold,
|
||||||
imgsz=imgsz,
|
imgsz=imgsz,
|
||||||
iou=iou_threshold, # default 0.7
|
iou=iou_threshold, # default 0.7
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
result = model.predict(
|
result = model.predict(
|
||||||
source=image_path,
|
source=image,
|
||||||
conf=box_threshold,
|
conf=box_threshold,
|
||||||
iou=iou_threshold, # default 0.7
|
iou=iou_threshold, # default 0.7
|
||||||
)
|
)
|
||||||
@@ -408,15 +408,21 @@ def predict_yolo(model, image_path, box_threshold, imgsz, scale_img, iou_thresho
|
|||||||
return boxes, conf, phrases
|
return boxes, conf, phrases
|
||||||
|
|
||||||
|
|
||||||
def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD = 0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=64):
|
def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=64):
|
||||||
""" ocr_bbox: list of xyxy format bbox
|
"""Process either an image path or Image object
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_source: Either a file path (str) or PIL Image object
|
||||||
|
...
|
||||||
"""
|
"""
|
||||||
image_source = Image.open(img_path).convert("RGB")
|
if isinstance(image_source, str):
|
||||||
|
image_source = Image.open(image_source).convert("RGB")
|
||||||
|
|
||||||
w, h = image_source.size
|
w, h = image_source.size
|
||||||
if not imgsz:
|
if not imgsz:
|
||||||
imgsz = (h, w)
|
imgsz = (h, w)
|
||||||
# print('image size:', w, h)
|
# print('image size:', w, h)
|
||||||
xyxy, logits, phrases = predict_yolo(model=model, image_path=img_path, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1)
|
xyxy, logits, phrases = predict_yolo(model=model, image=image_source, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1)
|
||||||
xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
|
xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
|
||||||
image_source = np.asarray(image_source)
|
image_source = np.asarray(image_source)
|
||||||
phrases = [str(i) for i in range(len(phrases))]
|
phrases = [str(i) for i in range(len(phrases))]
|
||||||
@@ -545,5 +551,13 @@ def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_
|
|||||||
# print('bounding box!!!', bb)
|
# print('bounding box!!!', bb)
|
||||||
return (text, bb), goal_filtering
|
return (text, bb), goal_filtering
|
||||||
|
|
||||||
|
def get_ocr_bbox(image):
|
||||||
|
text_threshold = 0.8
|
||||||
|
result = paddle_ocr.ocr(image, cls=False)[0]
|
||||||
|
coord = [item[0] for item in result if item[1][1] > text_threshold]
|
||||||
|
text = [item[1][0] for item in result if item[1][1] > text_threshold]
|
||||||
|
bb = [get_xyxy(item) for item in coord]
|
||||||
|
return text, bb
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user