diff --git a/gradio_demo.py b/gradio_demo.py index 11f3324..0ecc5b8 100644 --- a/gradio_demo.py +++ b/gradio_demo.py @@ -58,14 +58,15 @@ DEVICE = torch.device('cuda') def process( image_input, box_threshold, - iou_threshold + iou_threshold, + use_paddleocr ) -> Optional[Image.Image]: image_save_path = 'imgs/saved_image_demo.png' image_input.save(image_save_path) # import pdb; pdb.set_trace() - ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_save_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9}) + ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_save_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9}, use_paddleocr=use_paddleocr) text, ocr_bbox = ocr_bbox_rslt # print('prompt:', prompt) dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_save_path, yolo_model, BOX_TRESHOLD = box_threshold, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text,iou_threshold=iou_threshold) @@ -88,6 +89,8 @@ with gr.Blocks() as demo: # set the threshold for removing the bounding boxes with large overlap, default is 0.1 iou_threshold_component = gr.Slider( label='IOU Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.1) + use_paddleocr_component = gr.Checkbox( + label='Use PaddleOCR', default=True) submit_button_component = gr.Button( value='Submit', variant='primary') with gr.Column(): @@ -99,7 +102,8 @@ with gr.Blocks() as demo: inputs=[ image_input_component, box_threshold_component, - iou_threshold_component + iou_threshold_component, + use_paddleocr_component ], outputs=[image_output_component, text_output_component] ) diff --git a/requirements.txt b/requirements.txt index cddf7fd..08bfc38 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,5 @@ dill accelerate timm einops==0.8.0 +paddlepaddle +paddleocr \ No newline at end of file diff --git a/utils.py b/utils.py index b56588e..13b3c9e 100755 --- a/utils.py +++ b/utils.py @@ -18,7 +18,17 @@ import numpy as np # %matplotlib inline from matplotlib import pyplot as plt import easyocr +from paddleocr import PaddleOCR reader = easyocr.Reader(['en']) +paddle_ocr = PaddleOCR( + lang='en', # other lang also available + use_angle_cls=False, + use_gpu=False, # using cuda will conflict with pytorch in the same process + show_log=False, + max_batch_size=1024, + use_dilation=True, # improves accuracy + det_db_score_mode='slow', # improves accuracy + rec_batch_num=1024) import time import base64 @@ -370,14 +380,18 @@ def get_xywh_yolo(input): -def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None): - if easyocr_args is None: - easyocr_args = {} - result = reader.readtext(image_path, **easyocr_args) - is_goal_filtered = False - # print('goal filtering pred:', result[-5:]) - coord = [item[0] for item in result] - text = [item[1] for item in result] +def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False): + if use_paddleocr: + result = paddle_ocr.ocr(image_path, cls=False)[0] + coord = [item[0] for item in result] + text = [item[1][0] for item in result] + else: # EasyOCR + if easyocr_args is None: + easyocr_args = {} + result = reader.readtext(image_path, **easyocr_args) + # print('goal filtering pred:', result[-5:]) + coord = [item[0] for item in result] + text = [item[1] for item in result] # read the image using cv2 if display_img: opencv_img = cv2.imread(image_path)