Fix import path issues

This commit is contained in:
Thomas Dhome-Casanova
2025-01-30 06:18:24 +00:00
parent 124d9f6fb6
commit 41464ccf1c
3 changed files with 19 additions and 26 deletions

View File

@@ -14,7 +14,7 @@
} }
], ],
"source": [ "source": [
"from utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_yolo_model\n", "from util.utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_yolo_model\n",
"import torch\n", "import torch\n",
"from ultralytics import YOLO\n", "from ultralytics import YOLO\n",
"from PIL import Image\n", "from PIL import Image\n",
@@ -48,9 +48,9 @@
"source": [ "source": [
"# two choices for caption model: fine-tuned blip2 or florence2\n", "# two choices for caption model: fine-tuned blip2 or florence2\n",
"import importlib\n", "import importlib\n",
"import utils\n", "import util.utils\n",
"importlib.reload(utils)\n", "importlib.reload(util.utils)\n",
"from utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_yolo_model\n", "from util.utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_yolo_model\n",
"# caption_model_processor = get_caption_model_processor(model_name=\"blip2\", model_name_or_path=\"weights/icon_caption_blip2\", device=device)\n", "# caption_model_processor = get_caption_model_processor(model_name=\"blip2\", model_name_or_path=\"weights/icon_caption_blip2\", device=device)\n",
"caption_model_processor = get_caption_model_processor(model_name=\"florence2\", model_name_or_path=\"weights/icon_caption_florence\", device=device)\n", "caption_model_processor = get_caption_model_processor(model_name=\"florence2\", model_name_or_path=\"weights/icon_caption_florence\", device=device)\n",
"\n" "\n"
@@ -102,9 +102,9 @@
"source": [ "source": [
"# reload utils\n", "# reload utils\n",
"import importlib\n", "import importlib\n",
"import utils\n", "import util.utils\n",
"importlib.reload(utils)\n", "importlib.reload(util.utils)\n",
"from utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_yolo_model\n", "from util.utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_yolo_model\n",
"\n", "\n",
"image_path = 'imgs/google_page.png'\n", "image_path = 'imgs/google_page.png'\n",
"image_path = 'imgs/windows_home.png'\n", "image_path = 'imgs/windows_home.png'\n",
@@ -167,9 +167,9 @@
"# run on cpu!!!\n", "# run on cpu!!!\n",
"# reload utils\n", "# reload utils\n",
"import importlib\n", "import importlib\n",
"import utils\n", "import util.utils\n",
"importlib.reload(utils)\n", "importlib.reload(util.utils)\n",
"from utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_yolo_model\n", "from util.utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_yolo_model\n",
"\n", "\n",
"image_path = 'imgs/google_page.png'\n", "image_path = 'imgs/google_page.png'\n",
"image_path = 'imgs/windows_home.png'\n", "image_path = 'imgs/windows_home.png'\n",
@@ -447,13 +447,6 @@
"source": [ "source": [
"parsed_content_list[-1]" "parsed_content_list[-1]"
] ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
} }
], ],
"metadata": { "metadata": {

View File

@@ -1,22 +1,23 @@
''' '''
python -m remote_request --som_model_path ../weights/icon_detect_v1_5/model_v1_5.pt --caption_model_name florence2 --caption_model_path ../weights/icon_caption_florence --device cuda --BOX_TRESHOLD 0.05 python -m omniparserserver --som_model_path ../../weights/icon_detect_v1_5/model_v1_5.pt --caption_model_name florence2 --caption_model_path ../../weights/icon_caption_florence --device cuda --BOX_TRESHOLD 0.05
''' '''
import sys import sys
import os import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import time import time
from fastapi import FastAPI from fastapi import FastAPI
from pydantic import BaseModel from pydantic import BaseModel
import argparse import argparse
import uvicorn import uvicorn
from omniparser import Omniparser root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(root_dir)
from util.omniparser import Omniparser
def parse_arguments(): def parse_arguments():
parser = argparse.ArgumentParser(description='Omniparser API') parser = argparse.ArgumentParser(description='Omniparser API')
parser.add_argument('--som_model_path', type=str, default='../weights/icon_detect_v1_5/model_v1_5.pt', help='Path to the som model') parser.add_argument('--som_model_path', type=str, default='../../weights/icon_detect_v1_5/model_v1_5.pt', help='Path to the som model')
parser.add_argument('--caption_model_name', type=str, default='florence2', help='Name of the caption model') parser.add_argument('--caption_model_name', type=str, default='florence2', help='Name of the caption model')
parser.add_argument('--caption_model_path', type=str, default='../weights/icon_caption_florence', help='Path to the caption model') parser.add_argument('--caption_model_path', type=str, default='../../weights/icon_caption_florence', help='Path to the caption model')
parser.add_argument('--device', type=str, default='cpu', help='Device to run the model') parser.add_argument('--device', type=str, default='cpu', help='Device to run the model')
parser.add_argument('--BOX_TRESHOLD', type=float, default=0.05, help='Threshold for box detection') parser.add_argument('--BOX_TRESHOLD', type=float, default=0.05, help='Threshold for box detection')
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host for the API') parser.add_argument('--host', type=str, default='0.0.0.0', help='Host for the API')

View File

@@ -1,9 +1,9 @@
from ...util.utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box from util.utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box
import torch import torch
from PIL import Image from PIL import Image
import io import io
import base64 import base64
from typing import Dict
class Omniparser(object): class Omniparser(object):
def __init__(self, config: Dict): def __init__(self, config: Dict):
self.config = config self.config = config
@@ -25,9 +25,8 @@ class Omniparser(object):
'text_padding': max(int(3 * box_overlay_ratio), 1), 'text_padding': max(int(3 * box_overlay_ratio), 1),
'thickness': max(int(3 * box_overlay_ratio), 1), 'thickness': max(int(3 * box_overlay_ratio), 1),
} }
BOX_TRESHOLD = self.config['BOX_TRESHOLD']
(text, ocr_bbox), _ = check_ocr_box(image, display_img=False, output_bb_format='xyxy', easyocr_args={'text_threshold': 0.8}, use_paddleocr=False) (text, ocr_bbox), _ = check_ocr_box(image, display_img=False, output_bb_format='xyxy', easyocr_args={'text_threshold': 0.8}, use_paddleocr=False)
dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image, self.som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text,use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128) dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image, self.som_model, BOX_TRESHOLD = self.config['BOX_TRESHOLD'], output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text,use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128)
return dino_labled_img, parsed_content_list return dino_labled_img, parsed_content_list