Merge pull request #53 from aliencaocao/paddle-ocr

Add PaddleOCR option
This commit is contained in:
yadong-lu
2024-10-31 14:59:13 -07:00
committed by GitHub
3 changed files with 31 additions and 11 deletions

View File

@@ -58,14 +58,15 @@ DEVICE = torch.device('cuda')
def process( def process(
image_input, image_input,
box_threshold, box_threshold,
iou_threshold iou_threshold,
use_paddleocr
) -> Optional[Image.Image]: ) -> Optional[Image.Image]:
image_save_path = 'imgs/saved_image_demo.png' image_save_path = 'imgs/saved_image_demo.png'
image_input.save(image_save_path) image_input.save(image_save_path)
# import pdb; pdb.set_trace() # import pdb; pdb.set_trace()
ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_save_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9}) ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_save_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9}, use_paddleocr=use_paddleocr)
text, ocr_bbox = ocr_bbox_rslt text, ocr_bbox = ocr_bbox_rslt
# print('prompt:', prompt) # print('prompt:', prompt)
dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_save_path, yolo_model, BOX_TRESHOLD = box_threshold, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text,iou_threshold=iou_threshold) dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_save_path, yolo_model, BOX_TRESHOLD = box_threshold, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text,iou_threshold=iou_threshold)
@@ -88,6 +89,8 @@ with gr.Blocks() as demo:
# set the threshold for removing the bounding boxes with large overlap, default is 0.1 # set the threshold for removing the bounding boxes with large overlap, default is 0.1
iou_threshold_component = gr.Slider( iou_threshold_component = gr.Slider(
label='IOU Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.1) label='IOU Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.1)
use_paddleocr_component = gr.Checkbox(
label='Use PaddleOCR', default=True)
submit_button_component = gr.Button( submit_button_component = gr.Button(
value='Submit', variant='primary') value='Submit', variant='primary')
with gr.Column(): with gr.Column():
@@ -99,7 +102,8 @@ with gr.Blocks() as demo:
inputs=[ inputs=[
image_input_component, image_input_component,
box_threshold_component, box_threshold_component,
iou_threshold_component iou_threshold_component,
use_paddleocr_component
], ],
outputs=[image_output_component, text_output_component] outputs=[image_output_component, text_output_component]
) )

View File

@@ -14,3 +14,5 @@ dill
accelerate accelerate
timm timm
einops==0.8.0 einops==0.8.0
paddlepaddle
paddleocr

View File

@@ -18,7 +18,17 @@ import numpy as np
# %matplotlib inline # %matplotlib inline
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
import easyocr import easyocr
from paddleocr import PaddleOCR
reader = easyocr.Reader(['en']) reader = easyocr.Reader(['en'])
paddle_ocr = PaddleOCR(
lang='en', # other lang also available
use_angle_cls=False,
use_gpu=False, # using cuda will conflict with pytorch in the same process
show_log=False,
max_batch_size=1024,
use_dilation=True, # improves accuracy
det_db_score_mode='slow', # improves accuracy
rec_batch_num=1024)
import time import time
import base64 import base64
@@ -370,14 +380,18 @@ def get_xywh_yolo(input):
def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None): def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False):
if easyocr_args is None: if use_paddleocr:
easyocr_args = {} result = paddle_ocr.ocr(image_path, cls=False)[0]
result = reader.readtext(image_path, **easyocr_args) coord = [item[0] for item in result]
is_goal_filtered = False text = [item[1][0] for item in result]
# print('goal filtering pred:', result[-5:]) else: # EasyOCR
coord = [item[0] for item in result] if easyocr_args is None:
text = [item[1] for item in result] easyocr_args = {}
result = reader.readtext(image_path, **easyocr_args)
# print('goal filtering pred:', result[-5:])
coord = [item[0] for item in result]
text = [item[1] for item in result]
# read the image using cv2 # read the image using cv2
if display_img: if display_img:
opencv_img = cv2.imread(image_path) opencv_img = cv2.imread(image_path)