remove need to write to disk

This commit is contained in:
Thomas Dhome Casanova (from Dev Box)
2025-01-20 18:29:46 -08:00
parent 85f5fc0385
commit 6cb310d124
2 changed files with 29 additions and 20 deletions

View File

@@ -35,7 +35,7 @@ import base64
import os
import ast
import torch
from typing import Tuple, List
from typing import Tuple, List, Union
from torchvision.ops import box_convert
import re
from torchvision.transforms import ToPILImage
@@ -384,20 +384,20 @@ def predict(model, image, caption, box_threshold, text_threshold):
return boxes, logits, phrases
def predict_yolo(model, image_path, box_threshold, imgsz, scale_img, iou_threshold=0.7):
def predict_yolo(model, image, box_threshold, imgsz, scale_img, iou_threshold=0.7):
""" Use huggingface model to replace the original model
"""
# model = model['model']
if scale_img:
result = model.predict(
source=image_path,
source=image,
conf=box_threshold,
imgsz=imgsz,
iou=iou_threshold, # default 0.7
)
else:
result = model.predict(
source=image_path,
source=image,
conf=box_threshold,
iou=iou_threshold, # default 0.7
)
@@ -408,15 +408,21 @@ def predict_yolo(model, image_path, box_threshold, imgsz, scale_img, iou_thresho
return boxes, conf, phrases
def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD = 0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=64):
""" ocr_bbox: list of xyxy format bbox
def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=64):
"""Process either an image path or Image object
Args:
image_source: Either a file path (str) or PIL Image object
...
"""
image_source = Image.open(img_path).convert("RGB")
if isinstance(image_source, str):
image_source = Image.open(image_source).convert("RGB")
w, h = image_source.size
if not imgsz:
imgsz = (h, w)
# print('image size:', w, h)
xyxy, logits, phrases = predict_yolo(model=model, image_path=img_path, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1)
xyxy, logits, phrases = predict_yolo(model=model, image=image_source, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1)
xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
image_source = np.asarray(image_source)
phrases = [str(i) for i in range(len(phrases))]
@@ -545,5 +551,13 @@ def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_
# print('bounding box!!!', bb)
return (text, bb), goal_filtering
def get_ocr_bbox(image):
text_threshold = 0.8
result = paddle_ocr.ocr(image, cls=False)[0]
coord = [item[0] for item in result if item[1][1] > text_threshold]
text = [item[1][0] for item in result if item[1][1] > text_threshold]
bb = [get_xyxy(item) for item in coord]
return text, bb