move back to check_ocr_box

This commit is contained in:
Thomas Dhome-Casanova
2025-01-27 22:38:05 -08:00
parent 9cb2263545
commit 72040b9ded
2 changed files with 21 additions and 19 deletions

View File

@@ -501,50 +501,52 @@ def get_xywh_yolo(input):
x, y, w, h = input[0], input[1], input[2] - input[0], input[3] - input[1]
x, y, w, h = int(x), int(y), int(w), int(h)
return x, y, w, h
def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False):
def check_ocr_box(image_source: Union[str, Image.Image], display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False):
if isinstance(image_source, str):
image_source = Image.open(image_source)
if image_source.mode == 'RGBA':
# Convert RGBA to RGB to avoid alpha channel issues
image_source = image_source.convert('RGB')
image_np = np.array(image_source)
w, h = image_source.size
if use_paddleocr:
if easyocr_args is None:
text_threshold = 0.5
else:
text_threshold = easyocr_args['text_threshold']
result = paddle_ocr.ocr(image_path, cls=False)[0]
# conf = [item[1] for item in result]
result = paddle_ocr.ocr(image_np, cls=False)[0]
coord = [item[0] for item in result if item[1][1] > text_threshold]
text = [item[1][0] for item in result if item[1][1] > text_threshold]
else: # EasyOCR
if easyocr_args is None:
easyocr_args = {}
result = reader.readtext(image_path, **easyocr_args)
# print('goal filtering pred:', result[-5:])
result = reader.readtext(image_np, **easyocr_args)
coord = [item[0] for item in result]
text = [item[1] for item in result]
# read the image using cv2
if display_img:
opencv_img = cv2.imread(image_path)
opencv_img = cv2.cvtColor(opencv_img, cv2.COLOR_RGB2BGR)
opencv_img = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
bb = []
for item in coord:
x, y, a, b = get_xywh(item)
# print(x, y, a, b)
bb.append((x, y, a, b))
cv2.rectangle(opencv_img, (x, y), (x+a, y+b), (0, 255, 0), 2)
# Display the image
plt.imshow(opencv_img)
# matplotlib expects RGB
plt.imshow(cv2.cvtColor(opencv_img, cv2.COLOR_BGR2RGB))
else:
if output_bb_format == 'xywh':
bb = [get_xywh(item) for item in coord]
elif output_bb_format == 'xyxy':
bb = [get_xyxy(item) for item in coord]
# print('bounding box!!!', bb)
return (text, bb), goal_filtering
def get_ocr_bbox(image):
def get_ocr_bbox(image: Image.Image):
if image.mode == 'RGBA':
# Convert RGBA to RGB to avoid alpha channel issues
image = image.convert('RGB')
image_np = np.array(image)
result = paddle_ocr.ocr(image_np, cls=False)[0]
text_threshold = 0.8
result = paddle_ocr.ocr(image, cls=False)[0]
coord = [item[0] for item in result if item[1][1] > text_threshold]
text = [item[1][0] for item in result if item[1][1] > text_threshold]
bb = [get_xyxy(item) for item in coord]