diff --git a/demo/remote_request.py b/demo/remote_request.py index 32c1730..95f3595 100644 --- a/demo/remote_request.py +++ b/demo/remote_request.py @@ -4,7 +4,7 @@ import sys import os sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, get_ocr_bbox +from utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box import torch from PIL import Image from typing import Dict, Tuple, List @@ -45,7 +45,7 @@ class Omniparser(object): } BOX_TRESHOLD = config['BOX_TRESHOLD'] - text, ocr_bbox = get_ocr_bbox(image) + (text, ocr_bbox), _ = check_ocr_box(image, display_img=False, output_bb_format='xyxy', easyocr_args={'text_threshold': 0.8}, use_paddleocr=False) dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image, self.som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text,use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128) return dino_labled_img, parsed_content_list diff --git a/utils.py b/utils.py index a4e3f00..1da7e68 100755 --- a/utils.py +++ b/utils.py @@ -501,50 +501,52 @@ def get_xywh_yolo(input): x, y, w, h = input[0], input[1], input[2] - input[0], input[3] - input[1] x, y, w, h = int(x), int(y), int(w), int(h) return x, y, w, h - - -def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False): +def check_ocr_box(image_source: Union[str, Image.Image], display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False): + if isinstance(image_source, str): + image_source = Image.open(image_source) + if image_source.mode == 'RGBA': + # Convert RGBA to RGB to avoid alpha channel issues + image_source = image_source.convert('RGB') + image_np = np.array(image_source) + w, h = image_source.size if use_paddleocr: if easyocr_args is None: text_threshold = 0.5 else: text_threshold = easyocr_args['text_threshold'] - result = paddle_ocr.ocr(image_path, cls=False)[0] - # conf = [item[1] for item in result] + result = paddle_ocr.ocr(image_np, cls=False)[0] coord = [item[0] for item in result if item[1][1] > text_threshold] text = [item[1][0] for item in result if item[1][1] > text_threshold] else: # EasyOCR if easyocr_args is None: easyocr_args = {} - result = reader.readtext(image_path, **easyocr_args) - # print('goal filtering pred:', result[-5:]) + result = reader.readtext(image_np, **easyocr_args) coord = [item[0] for item in result] text = [item[1] for item in result] - # read the image using cv2 if display_img: - opencv_img = cv2.imread(image_path) - opencv_img = cv2.cvtColor(opencv_img, cv2.COLOR_RGB2BGR) + opencv_img = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR) bb = [] for item in coord: x, y, a, b = get_xywh(item) - # print(x, y, a, b) bb.append((x, y, a, b)) cv2.rectangle(opencv_img, (x, y), (x+a, y+b), (0, 255, 0), 2) - - # Display the image - plt.imshow(opencv_img) + # matplotlib expects RGB + plt.imshow(cv2.cvtColor(opencv_img, cv2.COLOR_BGR2RGB)) else: if output_bb_format == 'xywh': bb = [get_xywh(item) for item in coord] elif output_bb_format == 'xyxy': bb = [get_xyxy(item) for item in coord] - # print('bounding box!!!', bb) return (text, bb), goal_filtering -def get_ocr_bbox(image): +def get_ocr_bbox(image: Image.Image): + if image.mode == 'RGBA': + # Convert RGBA to RGB to avoid alpha channel issues + image = image.convert('RGB') + image_np = np.array(image) + result = paddle_ocr.ocr(image_np, cls=False)[0] text_threshold = 0.8 - result = paddle_ocr.ocr(image, cls=False)[0] coord = [item[0] for item in result if item[1][1] > text_threshold] text = [item[1][0] for item in result if item[1][1] > text_threshold] bb = [get_xyxy(item) for item in coord]