update readme; demo

This commit is contained in:
yadonglu
2024-10-09 22:31:38 +00:00
parent 4a8758b865
commit 664407ae9b
6 changed files with 174 additions and 535 deletions

View File

@@ -2,9 +2,17 @@
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/yadonglu/sandbox/miniconda/envs/omni/lib/python3.12/site-packages/ultralytics/nn/tasks.py:714: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" ckpt = torch.load(file, map_location=\"cpu\")\n"
]
},
{
"data": {
"text/plain": [
@@ -365,7 +373,7 @@
")"
]
},
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -377,19 +385,34 @@
"from PIL import Image\n",
"device = 'cuda'\n",
"\n",
"# dino_model = get_dino_model(load_hf_model=True, device=device)\n",
"som_model = get_yolo_model(model_path='omniparser/weights/best.pt')\n",
"\n",
"# caption_model_processor = get_caption_model_processor(\"Salesforce/blip2-opt-2.7b\", device=device)\n",
"# caption_model_processor['model'].to(torch.float32)\n",
"som_model.to(device)\n",
"\n"
"som_model = get_yolo_model(model_path='weights/omniparser/weights/best.pt')\n",
"som_model.to(device)\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00, 1.98it/s]\n"
]
}
],
"source": [
"\n",
"caption_model_processor = get_caption_model_processor(model_name_or_path=\"weights/omniparser/blipv2_ui_merge\", device=device)\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
@@ -397,7 +420,7 @@
"(device(type='cuda', index=0), ultralytics.models.yolo.model.YOLO)"
]
},
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@@ -408,7 +431,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"outputs": [
{
@@ -416,8 +439,8 @@
"output_type": "stream",
"text": [
"\n",
"image 1/1 /home/yadonglu/sandbox/OmniParser/imgs/pc_1.png: 800x1280 211 icons, 51.2ms\n",
"Speed: 3.7ms preprocess, 51.2ms inference, 160.7ms postprocess per image at shape (1, 3, 800, 1280)\n"
"image 1/1 /home/yadonglu/sandbox/OmniParser/imgs/pc_1.png: 800x1280 211 icons, 29.0ms\n",
"Speed: 4.1ms preprocess, 29.0ms inference, 121.1ms postprocess per image at shape (1, 3, 800, 1280)\n"
]
}
],
@@ -458,14 +481,14 @@
"ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9})\n",
"text, ocr_bbox = ocr_bbox_rslt\n",
"\n",
"dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=False, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=None, ocr_text=text,use_local_semantics=False)\n",
"dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=False, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text,use_local_semantics=False)\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 10,
"metadata": {},
"outputs": [
{
@@ -532,87 +555,146 @@
" '58': array([ 1617, 1187, 564, 63], dtype=float32),\n",
" '59': array([ 602, 1944, 242, 32], dtype=float32),\n",
" '60': array([ 2034, 277, 68, 39], dtype=float32),\n",
" '61': array([ 2965.8, 11.372, 73.33, 65.827], dtype=float32),\n",
" '62': array([ 2963.6, 104.42, 45.381, 45.756], dtype=float32),\n",
" '63': array([ 198.97, 28.516, 79.805, 38.483], dtype=float32),\n",
" '64': array([ 608.18, 181.26, 342.97, 50.053], dtype=float32),\n",
" '65': array([ 1300.6, 250, 33.767, 35.427], dtype=float32),\n",
" '66': array([ 304.13, 30.349, 37.342, 36.602], dtype=float32),\n",
" '67': array([ 667.74, 241.55, 47.529, 53.978], dtype=float32),\n",
" '68': array([ 822.62, 244.9, 47.754, 51.829], dtype=float32),\n",
" '69': array([ 770.31, 244.16, 46.905, 51.143], dtype=float32),\n",
" '70': array([ 1248.1, 251.22, 31.97, 32.78], dtype=float32),\n",
" '71': array([ 1048.9, 244.81, 45.524, 47], dtype=float32),\n",
" '72': array([ 438.99, 28.466, 35.963, 37.508], dtype=float32),\n",
" '73': array([ 954.49, 181.67, 94.131, 49.485], dtype=float32),\n",
" '74': array([ 363.65, 29.569, 32.883, 36.315], dtype=float32),\n",
" '75': array([ 497.99, 28.813, 32.513, 34.289], dtype=float32),\n",
" '76': array([ 1332.3, 188.4, 32.144, 34.772], dtype=float32),\n",
" '77': array([ 1137.5, 251.92, 34.499, 34.152], dtype=float32),\n",
" '78': array([ 880.41, 249.27, 39.986, 37.76], dtype=float32),\n",
" '79': array([ 954.73, 33.853, 23.966, 27.856], dtype=float32),\n",
" '80': array([ 2888, 21.725, 45.384, 44.9], dtype=float32),\n",
" '81': array([ 1997.3, 222.31, 34.341, 34.824], dtype=float32),\n",
" '82': array([ 625.14, 249.45, 29.211, 36.318], dtype=float32),\n",
" '83': array([ 554.33, 25.136, 36.99, 42.99], dtype=float32),\n",
" '84': array([ 2786.6, 22.758, 52.631, 49.663], dtype=float32),\n",
" '85': array([ 1812.5, 176.7, 57.761, 59.091], dtype=float32),\n",
" '86': array([ 3170.5, 26.987, 44.986, 46.376], dtype=float32),\n",
" '87': array([ 1284.7, 182.54, 32.466, 49.088], dtype=float32),\n",
" '88': array([ 423.68, 280.2, 28.832, 30.261], dtype=float32),\n",
" '89': array([ 1716.8, 179.54, 59.374, 50.265], dtype=float32),\n",
" '90': array([ 344.3, 185.95, 54.543, 42.996], dtype=float32),\n",
" '91': array([ 1515.7, 252.05, 33.876, 33.092], dtype=float32),\n",
" '92': array([ 1090.8, 243.67, 36.191, 50.504], dtype=float32),\n",
" '93': array([ 1248.6, 189.02, 32.175, 33.261], dtype=float32),\n",
" '94': array([ 963.23, 254.24, 40.774, 33.985], dtype=float32),\n",
" '95': array([ 1717.4, 174.03, 52.504, 48.792], dtype=float32),\n",
" '96': array([ 3075.8, 30.369, 39.574, 38.682], dtype=float32),\n",
" '97': array([ 3187.4, 107.39, 33.734, 40.21], dtype=float32),\n",
" '98': array([ 2966.8, 168.2, 91.261, 105.9], dtype=float32),\n",
" '99': array([ 30.782, 33.378, 33.3, 31.807], dtype=float32),\n",
" '100': array([ 1196.9, 324.51, 27.192, 25.869], dtype=float32),\n",
" '101': array([ 3172.2, 310.76, 44.785, 39.261], dtype=float32),\n",
" '102': array([ 1998.4, 173.44, 30.627, 32.274], dtype=float32),\n",
" '103': array([ 787.33, 241.95, 64.931, 57.226], dtype=float32),\n",
" '104': array([ 2692.7, 21.129, 51.957, 55.543], dtype=float32),\n",
" '105': array([ 1170.2, 247.59, 34.892, 42.735], dtype=float32),\n",
" '106': array([ 1910.2, 174.37, 58.716, 53.398], dtype=float32),\n",
" '107': array([ 2259.9, 225.75, 31.152, 32.17], dtype=float32),\n",
" '108': array([ 254.78, 181.7, 56.533, 44.474], dtype=float32),\n",
" '109': array([ 1047.4, 182.75, 71.279, 49.489], dtype=float32),\n",
" '110': array([ 424.87, 175.44, 29.783, 28.272], dtype=float32),\n",
" '111': array([ 892.42, 233.03, 55.596, 76.345], dtype=float32),\n",
" '112': array([ 2873.1, 180.61, 62.642, 47.626], dtype=float32),\n",
" '113': array([ 1996.8, 278.17, 31.037, 31.622], dtype=float32),\n",
" '114': array([ 11.883, 368.91, 564.39, 1578.2], dtype=float32),\n",
" '115': array([ 422.86, 225.74, 29.434, 28.735], dtype=float32),\n",
" '116': array([ 1794.2, 209.55, 92.729, 62.57], dtype=float32),\n",
" '117': array([ 1702.6, 235.84, 84.464, 33.964], dtype=float32)},\n",
" 'AutoSave')"
" '61': array([ 2965.8, 11.37, 73.327, 65.831], dtype=float32),\n",
" '62': array([ 2963.6, 104.43, 45.38, 45.75], dtype=float32),\n",
" '63': array([ 198.97, 28.514, 79.804, 38.484], dtype=float32),\n",
" '64': array([ 608.16, 181.26, 343, 50.052], dtype=float32),\n",
" '65': array([ 1300.6, 250, 33.764, 35.422], dtype=float32),\n",
" '66': array([ 304.13, 30.345, 37.345, 36.607], dtype=float32),\n",
" '67': array([ 667.74, 241.53, 47.522, 54.004], dtype=float32),\n",
" '68': array([ 822.62, 244.89, 47.756, 51.845], dtype=float32),\n",
" '69': array([ 770.31, 244.15, 46.912, 51.157], dtype=float32),\n",
" '70': array([ 1248.1, 251.23, 31.97, 32.777], dtype=float32),\n",
" '71': array([ 1048.9, 244.81, 45.531, 47.005], dtype=float32),\n",
" '72': array([ 438.98, 28.462, 35.958, 37.513], dtype=float32),\n",
" '73': array([ 954.49, 181.66, 94.134, 49.488], dtype=float32),\n",
" '74': array([ 363.65, 29.564, 32.889, 36.319], dtype=float32),\n",
" '75': array([ 497.99, 28.809, 32.51, 34.293], dtype=float32),\n",
" '76': array([ 1332.3, 188.4, 32.162, 34.782], dtype=float32),\n",
" '77': array([ 1137.5, 251.92, 34.494, 34.152], dtype=float32),\n",
" '78': array([ 880.41, 249.27, 39.99, 37.764], dtype=float32),\n",
" '79': array([ 954.73, 33.854, 23.97, 27.857], dtype=float32),\n",
" '80': array([ 2888, 21.728, 45.381, 44.893], dtype=float32),\n",
" '81': array([ 1997.3, 222.31, 34.338, 34.831], dtype=float32),\n",
" '82': array([ 625.13, 249.45, 29.214, 36.326], dtype=float32),\n",
" '83': array([ 554.33, 25.126, 36.996, 43.004], dtype=float32),\n",
" '84': array([ 2786.6, 22.757, 52.635, 49.66], dtype=float32),\n",
" '85': array([ 1812.5, 176.7, 57.783, 59.101], dtype=float32),\n",
" '86': array([ 3170.5, 26.972, 45.001, 46.404], dtype=float32),\n",
" '87': array([ 1284.7, 182.54, 32.421, 49.083], dtype=float32),\n",
" '88': array([ 423.68, 280.19, 28.83, 30.268], dtype=float32),\n",
" '89': array([ 1716.8, 179.53, 59.38, 50.26], dtype=float32),\n",
" '90': array([ 344.29, 185.94, 54.55, 43], dtype=float32),\n",
" '91': array([ 1515.7, 252.05, 33.876, 33.096], dtype=float32),\n",
" '92': array([ 1090.7, 243.67, 36.209, 50.511], dtype=float32),\n",
" '93': array([ 1248.6, 189.02, 32.176, 33.262], dtype=float32),\n",
" '94': array([ 963.23, 254.23, 40.772, 33.994], dtype=float32),\n",
" '95': array([ 1717.4, 174.03, 52.489, 48.778], dtype=float32),\n",
" '96': array([ 3075.8, 30.359, 39.588, 38.699], dtype=float32),\n",
" '97': array([ 3187.4, 107.38, 33.732, 40.21], dtype=float32),\n",
" '98': array([ 2966.8, 168.2, 91.264, 105.89], dtype=float32),\n",
" '99': array([ 30.783, 33.381, 33.3, 31.808], dtype=float32),\n",
" '100': array([ 1196.9, 324.51, 27.192, 25.871], dtype=float32),\n",
" '101': array([ 3172.2, 310.75, 44.795, 39.267], dtype=float32),\n",
" '102': array([ 1998.4, 173.44, 30.628, 32.281], dtype=float32),\n",
" '103': array([ 787.33, 241.94, 64.897, 57.237], dtype=float32),\n",
" '104': array([ 2692.7, 21.127, 51.957, 55.546], dtype=float32),\n",
" '105': array([ 1170.2, 247.59, 34.857, 42.74], dtype=float32),\n",
" '106': array([ 1910.1, 174.37, 58.718, 53.371], dtype=float32),\n",
" '107': array([ 2259.9, 225.75, 31.158, 32.173], dtype=float32),\n",
" '108': array([ 254.78, 181.69, 56.54, 44.482], dtype=float32),\n",
" '109': array([ 1047.4, 182.75, 71.273, 49.486], dtype=float32),\n",
" '110': array([ 892.42, 233.03, 55.595, 76.336], dtype=float32),\n",
" '111': array([ 424.87, 175.43, 29.786, 28.278], dtype=float32),\n",
" '112': array([ 2873.1, 180.62, 62.633, 47.619], dtype=float32),\n",
" '113': array([ 1996.8, 278.16, 31.037, 31.629], dtype=float32),\n",
" '114': array([ 11.928, 369.01, 564.33, 1577.8], dtype=float32),\n",
" '115': array([ 422.86, 225.74, 29.438, 28.741], dtype=float32),\n",
" '116': array([ 1702.6, 235.74, 84.462, 34.063], dtype=float32)},\n",
" ['Text Box ID 0: AutoSave',\n",
" 'Text Box ID 1: Presentation2',\n",
" 'Text Box ID 2: PowerPoint',\n",
" 'Text Box ID 3: General*',\n",
" 'Text Box ID 4: Search',\n",
" 'Text Box ID 5: Yadong',\n",
" 'Text Box ID 6: File',\n",
" 'Text Box ID 7: Home',\n",
" 'Text Box ID 8: Insert',\n",
" 'Text Box ID 9: Draw',\n",
" 'Text Box ID 10: Design',\n",
" 'Text Box ID 11: Transitions',\n",
" 'Text Box ID 12: Animations',\n",
" 'Text Box ID 13: Slide Show',\n",
" 'Text Box ID 14: Record',\n",
" 'Text Box ID 15: Review',\n",
" 'Text Box ID 16: View',\n",
" 'Text Box ID 17: Help',\n",
" 'Text Box ID 18: Record',\n",
" 'Text Box ID 19: Present in Teams',\n",
" 'Text Box ID 20: Share',\n",
" 'Text Box ID 21: Layout',\n",
" 'Text Box ID 22: A\" | A',\n",
" 'Text Box ID 23: 8 =#~',\n",
" 'Text Box ID 24: Shape',\n",
" 'Text Box ID 25: Find',\n",
" 'Text Box ID 26: Paste',\n",
" 'Text Box ID 27: New',\n",
" 'Text Box ID 28: Reuse',\n",
" 'Text Box ID 29: Reset',\n",
" 'Text Box ID 30: [t]',\n",
" 'Text Box ID 31: Shapes Arrange',\n",
" 'Text Box ID 32: Quick',\n",
" 'Text Box ID 33: Shape Outline',\n",
" 'Text Box ID 34: Replace',\n",
" 'Text Box ID 35: Dictate',\n",
" 'Text Box ID 36: Sensitivity',\n",
" 'Text Box ID 37: Add-ins',\n",
" 'Text Box ID 38: Designer Copilot',\n",
" 'Text Box ID 39: 4',\n",
" 'Text Box ID 40: Aa ~',\n",
" 'Text Box ID 41: 22E6',\n",
" 'Text Box ID 42: Slide',\n",
" 'Text Box ID 43: Slides',\n",
" 'Text Box ID 44: Section',\n",
" 'Text Box ID 45: Styles',\n",
" 'Text Box ID 46: Effects',\n",
" 'Text Box ID 47: Select',\n",
" 'Text Box ID 48: Clipboard',\n",
" 'Text Box ID 49: Slides',\n",
" 'Text Box ID 50: Font',\n",
" 'Text Box ID 51: Paragraph',\n",
" 'Text Box ID 52: Drawing',\n",
" 'Text Box ID 53: Editing',\n",
" 'Text Box ID 54: Voice',\n",
" 'Text Box ID 55: Sensitivity',\n",
" 'Text Box ID 56: Add-ins',\n",
" 'Text Box ID 57: Click to add title',\n",
" 'Text Box ID 58: Click to add subtitle',\n",
" 'Text Box ID 59: Click to add notes',\n",
" 'Text Box ID 60: Shape'])"
]
},
"execution_count": 13,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"label_coordinates, parsed_content_list[0].split(': ')[1]"
"label_coordinates, parsed_content_list#[0].split(': ')[1]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7c1b13b5f5b0>"
"<matplotlib.image.AxesImage at 0x7ff3e91f4fb0>"
]
},
"execution_count": 9,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
},
@@ -640,238 +722,6 @@
"plt.imshow(image)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# wrapped Omniparser"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parsing image: examples/pc_1.png\n",
"\n",
"image 1/1 /home/yadonglu/sandbox/screenparsing_collab/screenparsing/omniparser/examples/pc_1.png: 800x1280 210 icons, 55.6ms\n",
"Speed: 7.7ms preprocess, 55.6ms inference, 1.5ms postprocess per image at shape (1, 3, 800, 1280)\n",
"boxes cpu\n",
"Time taken for Omniparser on cpu: 2.029506206512451\n"
]
}
],
"source": [
"from utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_dino_model, get_yolo_model\n",
"import torch\n",
"from ultralytics import YOLO\n",
"from PIL import Image\n",
"from typing import Dict, Tuple, List\n",
"import io\n",
"import base64\n",
"\n",
"\n",
"config = {\n",
" 'som_model_path': 'finetuned_icon_detect.pt',\n",
" 'device': 'cpu',\n",
" 'caption_model_path': 'Salesforce/blip2-opt-2.7b',\n",
" 'draw_bbox_config': {\n",
" 'text_scale': 0.8,\n",
" 'text_thickness': 2,\n",
" 'text_padding': 3,\n",
" 'thickness': 3,\n",
" },\n",
" 'BOX_TRESHOLD': 0.05\n",
"}\n",
"\n",
"\n",
"class Omniparser(object):\n",
" def __init__(self, config: Dict):\n",
" self.config = config\n",
" \n",
" self.som_model = get_yolo_model(model_path=config['som_model_path'])\n",
" # self.caption_model_processor = get_caption_model_processor(config['caption_model_path'], device=cofig['device'])\n",
" # self.caption_model_processor['model'].to(torch.float32)\n",
"\n",
" def parse(self, image_path: str):\n",
" print('Parsing image:', image_path)\n",
" ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9})\n",
" text, ocr_bbox = ocr_bbox_rslt\n",
"\n",
" draw_bbox_config = self.config['draw_bbox_config']\n",
" BOX_TRESHOLD = self.config['BOX_TRESHOLD']\n",
" dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, self.som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=False, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=None, ocr_text=text,use_local_semantics=False)\n",
" \n",
" image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))\n",
" # formating output\n",
" return_list = [{'from': 'omniparser', 'shape': {'x':coord[0], 'y':coord[1], 'width':coord[2], 'height':coord[3]},\n",
" 'text': parsed_content_list[i].split(': ')[1], 'type':'text'} for i, (k, coord) in enumerate(label_coordinates.items()) if i < len(parsed_content_list)]\n",
" return_list.extend(\n",
" [{'from': 'omniparser', 'shape': {'x':coord[0], 'y':coord[1], 'width':coord[2], 'height':coord[3]},\n",
" 'text': 'None', 'type':'icon'} for i, (k, coord) in enumerate(label_coordinates.items()) if i >= len(parsed_content_list)]\n",
" )\n",
"\n",
" return [image, return_list]\n",
" \n",
"parser = Omniparser(config)\n",
"image_path = 'imgs/pc_1.png'\n",
"\n",
"# time the parser\n",
"import time\n",
"s = time.time()\n",
"image, parsed_content_list = parser.parse(image_path)\n",
"device = config['device']\n",
"print(f'Time taken for Omniparser on {device}:', time.time() - s)\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0: 800x1280 210 icons, 49.4ms\n",
"Speed: 5.7ms preprocess, 49.4ms inference, 1.1ms postprocess per image at shape (1, 3, 800, 1280)\n",
"boxes cpu\n",
"Time taken for Omniparser finetuned YOLO module on cpu: 0.2898883819580078\n"
]
}
],
"source": [
"from utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_dino_model, get_yolo_model, predict_yolo\n",
"import torch\n",
"from ultralytics import YOLO\n",
"from PIL import Image\n",
"from typing import Dict, Tuple, List\n",
"import io\n",
"import base64\n",
"\n",
"\n",
"config = {\n",
" 'som_model_path': 'finetuned_icon_detect.pt',\n",
" 'device': 'cpu',\n",
" 'caption_model_path': 'Salesforce/blip2-opt-2.7b',\n",
" 'draw_bbox_config': {\n",
" 'text_scale': 0.8,\n",
" 'text_thickness': 2,\n",
" 'text_padding': 3,\n",
" 'thickness': 3,\n",
" },\n",
" 'BOX_TRESHOLD': 0.05\n",
"}\n",
"\n",
"class OmniparserYOLO(object):\n",
" def __init__(self, config: Dict):\n",
" self.config = config\n",
" self.som_model = get_yolo_model(model_path=config['som_model_path'])\n",
"\n",
" def parse(self, image):\n",
" draw_bbox_config = self.config['draw_bbox_config']\n",
" BOX_TRESHOLD = self.config['BOX_TRESHOLD']\n",
" xyxy, logits, phrases = predict_yolo(model=self.som_model, image_path=image, box_threshold=BOX_TRESHOLD)\n",
" # print('xyxy:', xyxy)\n",
" xyxy = xyxy.tolist()\n",
" # formating output\n",
" return_list = [{'from': 'omniparserYOLO', 'shape': {'x':coord[0], 'y':coord[1], 'width':coord[2]-coord[0], 'height':coord[3] - coord[1]},\n",
" 'text': 'None', 'type':'icon'} for i, coord in enumerate(xyxy)]\n",
" \n",
" return [None, return_list]\n",
" \n",
"parser = OmniparserYOLO(config)\n",
"image_path = 'imgs/pc_1.png'\n",
"image = Image.open(image_path)\n",
"\n",
"# time the parser\n",
"import time\n",
"s = time.time()\n",
"_, parsed_content_list = parser.parse(image)\n",
"device = config['device']\n",
"print(f'Time taken for Omniparser finetuned YOLO module on {device}:', time.time() - s)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# florence caption model"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/yadonglu/anaconda3/envs/pilot/lib/python3.9/site-packages/transformers/utils/generic.py:342: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
" _torch_pytree._register_pytree_node(\n",
"/home/yadonglu/anaconda3/envs/pilot/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
" warnings.warn(\n"
]
}
],
"source": [
"from transformers import AutoProcessor, AutoModelForCausalLM \n",
"import torch\n",
"device = 'cpu'\n",
"torch_dtype = torch.float16 if device == 'cuda' else torch.float32\n",
"model = AutoModelForCausalLM.from_pretrained(\"/home/yadonglu/sandbox/data/orca/florence-2-base-ft-fft_rai_win_ep5/epoch_5\", torch_dtype=torch_dtype, trust_remote_code=True).to(device)\n",
"processor = AutoProcessor.from_pretrained(\"microsoft/Florence-2-base\", trust_remote_code=True)\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['settings or configuration options.']"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from PIL import Image\n",
"prompt = \"<CAPTION>\"\n",
"image_path = 'imgs/settings.png'\n",
"image = [Image.open(image_path).convert('RGB')]\n",
"inputs = processor(images=image, text=[prompt]*len(image), return_tensors=\"pt\").to(device=device)\n",
"generated_ids = model.generate(input_ids=inputs[\"input_ids\"],pixel_values=inputs[\"pixel_values\"],max_new_tokens=1024,num_beams=3, do_sample=False)\n",
"generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)\n",
"generated_text = [gen.strip() for gen in generated_text]\n",
"generated_text"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import cv2"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -896,7 +746,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
"version": "3.12.0"
}
},
"nbformat": 4,