move back to check_ocr_box
This commit is contained in:
@@ -4,7 +4,7 @@ import sys
|
|||||||
import os
|
import os
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
from utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, get_ocr_bbox
|
from utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from typing import Dict, Tuple, List
|
from typing import Dict, Tuple, List
|
||||||
@@ -45,7 +45,7 @@ class Omniparser(object):
|
|||||||
}
|
}
|
||||||
BOX_TRESHOLD = config['BOX_TRESHOLD']
|
BOX_TRESHOLD = config['BOX_TRESHOLD']
|
||||||
|
|
||||||
text, ocr_bbox = get_ocr_bbox(image)
|
(text, ocr_bbox), _ = check_ocr_box(image, display_img=False, output_bb_format='xyxy', easyocr_args={'text_threshold': 0.8}, use_paddleocr=False)
|
||||||
dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image, self.som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text,use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128)
|
dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image, self.som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text,use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128)
|
||||||
|
|
||||||
return dino_labled_img, parsed_content_list
|
return dino_labled_img, parsed_content_list
|
||||||
|
|||||||
36
utils.py
36
utils.py
@@ -501,50 +501,52 @@ def get_xywh_yolo(input):
|
|||||||
x, y, w, h = input[0], input[1], input[2] - input[0], input[3] - input[1]
|
x, y, w, h = input[0], input[1], input[2] - input[0], input[3] - input[1]
|
||||||
x, y, w, h = int(x), int(y), int(w), int(h)
|
x, y, w, h = int(x), int(y), int(w), int(h)
|
||||||
return x, y, w, h
|
return x, y, w, h
|
||||||
|
|
||||||
|
|
||||||
|
def check_ocr_box(image_source: Union[str, Image.Image], display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False):
|
||||||
def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False):
|
if isinstance(image_source, str):
|
||||||
|
image_source = Image.open(image_source)
|
||||||
|
if image_source.mode == 'RGBA':
|
||||||
|
# Convert RGBA to RGB to avoid alpha channel issues
|
||||||
|
image_source = image_source.convert('RGB')
|
||||||
|
image_np = np.array(image_source)
|
||||||
|
w, h = image_source.size
|
||||||
if use_paddleocr:
|
if use_paddleocr:
|
||||||
if easyocr_args is None:
|
if easyocr_args is None:
|
||||||
text_threshold = 0.5
|
text_threshold = 0.5
|
||||||
else:
|
else:
|
||||||
text_threshold = easyocr_args['text_threshold']
|
text_threshold = easyocr_args['text_threshold']
|
||||||
result = paddle_ocr.ocr(image_path, cls=False)[0]
|
result = paddle_ocr.ocr(image_np, cls=False)[0]
|
||||||
# conf = [item[1] for item in result]
|
|
||||||
coord = [item[0] for item in result if item[1][1] > text_threshold]
|
coord = [item[0] for item in result if item[1][1] > text_threshold]
|
||||||
text = [item[1][0] for item in result if item[1][1] > text_threshold]
|
text = [item[1][0] for item in result if item[1][1] > text_threshold]
|
||||||
else: # EasyOCR
|
else: # EasyOCR
|
||||||
if easyocr_args is None:
|
if easyocr_args is None:
|
||||||
easyocr_args = {}
|
easyocr_args = {}
|
||||||
result = reader.readtext(image_path, **easyocr_args)
|
result = reader.readtext(image_np, **easyocr_args)
|
||||||
# print('goal filtering pred:', result[-5:])
|
|
||||||
coord = [item[0] for item in result]
|
coord = [item[0] for item in result]
|
||||||
text = [item[1] for item in result]
|
text = [item[1] for item in result]
|
||||||
# read the image using cv2
|
|
||||||
if display_img:
|
if display_img:
|
||||||
opencv_img = cv2.imread(image_path)
|
opencv_img = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
|
||||||
opencv_img = cv2.cvtColor(opencv_img, cv2.COLOR_RGB2BGR)
|
|
||||||
bb = []
|
bb = []
|
||||||
for item in coord:
|
for item in coord:
|
||||||
x, y, a, b = get_xywh(item)
|
x, y, a, b = get_xywh(item)
|
||||||
# print(x, y, a, b)
|
|
||||||
bb.append((x, y, a, b))
|
bb.append((x, y, a, b))
|
||||||
cv2.rectangle(opencv_img, (x, y), (x+a, y+b), (0, 255, 0), 2)
|
cv2.rectangle(opencv_img, (x, y), (x+a, y+b), (0, 255, 0), 2)
|
||||||
|
# matplotlib expects RGB
|
||||||
# Display the image
|
plt.imshow(cv2.cvtColor(opencv_img, cv2.COLOR_BGR2RGB))
|
||||||
plt.imshow(opencv_img)
|
|
||||||
else:
|
else:
|
||||||
if output_bb_format == 'xywh':
|
if output_bb_format == 'xywh':
|
||||||
bb = [get_xywh(item) for item in coord]
|
bb = [get_xywh(item) for item in coord]
|
||||||
elif output_bb_format == 'xyxy':
|
elif output_bb_format == 'xyxy':
|
||||||
bb = [get_xyxy(item) for item in coord]
|
bb = [get_xyxy(item) for item in coord]
|
||||||
# print('bounding box!!!', bb)
|
|
||||||
return (text, bb), goal_filtering
|
return (text, bb), goal_filtering
|
||||||
|
|
||||||
def get_ocr_bbox(image):
|
def get_ocr_bbox(image: Image.Image):
|
||||||
|
if image.mode == 'RGBA':
|
||||||
|
# Convert RGBA to RGB to avoid alpha channel issues
|
||||||
|
image = image.convert('RGB')
|
||||||
|
image_np = np.array(image)
|
||||||
|
result = paddle_ocr.ocr(image_np, cls=False)[0]
|
||||||
text_threshold = 0.8
|
text_threshold = 0.8
|
||||||
result = paddle_ocr.ocr(image, cls=False)[0]
|
|
||||||
coord = [item[0] for item in result if item[1][1] > text_threshold]
|
coord = [item[0] for item in result if item[1][1] > text_threshold]
|
||||||
text = [item[1][0] for item in result if item[1][1] > text_threshold]
|
text = [item[1][0] for item in result if item[1][1] > text_threshold]
|
||||||
bb = [get_xyxy(item) for item in coord]
|
bb = [get_xyxy(item) for item in coord]
|
||||||
|
|||||||
Reference in New Issue
Block a user