Merge branch 'master' of https://github.com/microsoft/OmniParser into demo

This commit is contained in:
yadonglu
2024-12-05 11:47:04 -08:00
5 changed files with 100 additions and 161 deletions

1
.gitignore vendored
View File

@@ -2,6 +2,7 @@ weights/icon_caption_blip2
weights/icon_caption_florence weights/icon_caption_florence
weights/icon_detect/ weights/icon_detect/
weights/icon_detect_v1_5/ weights/icon_detect_v1_5/
weights/icon_detect_v1_5_2/
.gradio .gradio
__pycache__/ __pycache__/
debug.ipynb debug.ipynb

Binary file not shown.

File diff suppressed because one or more lines are too long

View File

@@ -255,7 +255,7 @@ def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
# return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3] # return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
intersection = intersection_area(box1, box2) intersection = intersection_area(box1, box2)
ratio1 = intersection / box_area(box1) ratio1 = intersection / box_area(box1)
return ratio1 > 0.95 return ratio1 > 0.80
# boxes = boxes.tolist() # boxes = boxes.tolist()
filtered_boxes = [] filtered_boxes = []
@@ -274,27 +274,28 @@ def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
if is_valid_box: if is_valid_box:
# add the following 2 lines to include ocr bbox # add the following 2 lines to include ocr bbox
if ocr_bbox: if ocr_bbox:
# only add the box if it does not overlap with any ocr bbox # keep yolo boxes + prioritize ocr label
box_added = False box_added = False
for box3_elem in ocr_bbox: for box3_elem in ocr_bbox:
if not box_added: if not box_added:
box3 = box3_elem['bbox'] box3 = box3_elem['bbox']
if is_inside(box3, box1): # ocr inside icon if is_inside(box3, box1): # ocr inside icon
box_added = True # box_added = True
# delete the box3_elem from ocr_bbox # delete the box3_elem from ocr_bbox
try: try:
filtered_boxes.append({'type': 'text', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': box3_elem['content']}) filtered_boxes.append({'type': 'text', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': box3_elem['content']})
filtered_boxes.remove(box3_elem) filtered_boxes.remove(box3_elem)
# print('remove ocr bbox:', box3_elem)
except: except:
continue continue
break # break
elif is_inside(box1, box3): # icon inside ocr elif is_inside(box1, box3): # icon inside ocr
box_added = True box_added = True
try: # try:
filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None}) # filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None})
filtered_boxes.remove(box3_elem) # filtered_boxes.remove(box3_elem)
except: # except:
continue # continue
break break
else: else:
continue continue

View File

@@ -3,25 +3,24 @@ from ultralytics.nn.tasks import DetectionModel
from safetensors.torch import load_file from safetensors.torch import load_file
import argparse import argparse
import yaml import yaml
import os
# accept args to specify v1 or v1_5 # accept args to specify v1 or v1_5
parser = argparse.ArgumentParser(description='Specify version v1 or v1_5') parser = argparse.ArgumentParser(description='Specify version v1 or v1_5')
parser.add_argument('--version', choices=['v1', 'v1_5'], required=True, help='Specify the version: v1 or v1_5') parser.add_argument('--weights_dir', type=str, required=True, help='Specify the path to the safetensor file', default='weights/icon_detect_v1_5')
args = parser.parse_args() args = parser.parse_args()
if args.version == 'v1': tensor_dict = load_file(os.path.join(args.weights_dir, "model.safetensors"))
tensor_dict = load_file("weights/icon_detect/model.safetensors") model = DetectionModel(os.path.join(args.weights_dir, "model.yaml"))
model = DetectionModel('weights/icon_detect/model.yaml') # from ultralytics import YOLO
model.load_state_dict(tensor_dict) # som_model = YOLO("yolo11m.pt")
torch.save({'model':model}, 'weights/icon_detect/best.pt') # model = som_model.model
elif args.version == 'v1_5':
print("Converting v1_5") model.load_state_dict(tensor_dict)
tensor_dict = load_file("weights/icon_detect_v1_5/model.safetensors") save_dict = {'model':model}
model = DetectionModel('weights/icon_detect_v1_5/model.yaml')
model.load_state_dict(tensor_dict) with open(os.path.join(args.weights_dir, "train_args.yaml"), 'r') as file:
save_dict = {'model':model} train_args = yaml.safe_load(file)
save_dict.update(train_args)
torch.save(save_dict, os.path.join(args.weights_dir, "best.pt"))
with open("weights/icon_detect_v1_5/train_args.yaml", 'r') as file:
train_args = yaml.safe_load(file)
save_dict.update(train_args)
torch.save(save_dict, 'weights/icon_detect_v1_5/best.pt')