diff --git a/README.md b/README.md
index e6f0d30..0465b9f 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,8 @@
 ## Install
 ```python
 conda create -n "omni" python==3.12
-pip install -r requirements.txt
+conda activate omni
+pip install -r requirements.txt
 ```
 
 ## Examples:
diff --git a/__pycache__/utils.cpython-312.pyc b/__pycache__/utils.cpython-312.pyc
index 395a981..3b24508 100644
Binary files a/__pycache__/utils.cpython-312.pyc and b/__pycache__/utils.cpython-312.pyc differ
diff --git a/demo.ipynb b/demo.ipynb
index 3dc633e..ecdb8cb 100644
--- a/demo.ipynb
+++ b/demo.ipynb
@@ -2,9 +2,17 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/yadonglu/sandbox/miniconda/envs/omni/lib/python3.12/site-packages/ultralytics/nn/tasks.py:714: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
+      " ckpt = torch.load(file, map_location=\"cpu\")\n"
+     ]
+    },
     {
      "data": {
       "text/plain": [
@@ -365,7 +373,7 @@
        ")"
       ]
      },
-     "execution_count": 6,
+     "execution_count": 5,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -377,19 +385,34 @@
     "from PIL import Image\n",
     "device = 'cuda'\n",
     "\n",
-    "# dino_model = get_dino_model(load_hf_model=True, device=device)\n",
-    "som_model = get_yolo_model(model_path='omniparser/weights/best.pt')\n",
-    "\n",
-    "# caption_model_processor = get_caption_model_processor(\"Salesforce/blip2-opt-2.7b\", device=device)\n",
-    "# caption_model_processor['model'].to(torch.float32)\n",
-    "som_model.to(device)\n",
-    "\n"
+    "som_model = get_yolo_model(model_path='weights/omniparser/weights/best.pt')\n",
+    "som_model.to(device)\n"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 7,
    "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00, 1.98it/s]\n"
+     ]
+    }
+   ],
+   "source": [
+    "\n",
+    "caption_model_processor = get_caption_model_processor(model_name_or_path=\"weights/omniparser/blipv2_ui_merge\", device=device)\n",
+    "\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
    "outputs": [
     {
      "data": {
@@ -397,7 +420,7 @@
        "(device(type='cuda', index=0), ultralytics.models.yolo.model.YOLO)"
       ]
      },
-     "execution_count": 7,
+     "execution_count": 8,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -408,7 +431,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "\n",
-      "image 1/1 /home/yadonglu/sandbox/OmniParser/imgs/pc_1.png: 800x1280 211 icons, 51.2ms\n",
-      "Speed: 3.7ms preprocess, 51.2ms inference, 160.7ms
postprocess per image at shape (1, 3, 800, 1280)\n" + "image 1/1 /home/yadonglu/sandbox/OmniParser/imgs/pc_1.png: 800x1280 211 icons, 29.0ms\n", + "Speed: 4.1ms preprocess, 29.0ms inference, 121.1ms postprocess per image at shape (1, 3, 800, 1280)\n" ] } ], @@ -458,14 +481,14 @@ "ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9})\n", "text, ocr_bbox = ocr_bbox_rslt\n", "\n", - "dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=False, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=None, ocr_text=text,use_local_semantics=False)\n", + "dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=False, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text,use_local_semantics=False)\n", "\n", "\n" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -532,87 +555,146 @@ " '58': array([ 1617, 1187, 564, 63], dtype=float32),\n", " '59': array([ 602, 1944, 242, 32], dtype=float32),\n", " '60': array([ 2034, 277, 68, 39], dtype=float32),\n", - " '61': array([ 2965.8, 11.372, 73.33, 65.827], dtype=float32),\n", - " '62': array([ 2963.6, 104.42, 45.381, 45.756], dtype=float32),\n", - " '63': array([ 198.97, 28.516, 79.805, 38.483], dtype=float32),\n", - " '64': array([ 608.18, 181.26, 342.97, 50.053], dtype=float32),\n", - " '65': array([ 1300.6, 250, 33.767, 35.427], dtype=float32),\n", - " '66': array([ 304.13, 30.349, 37.342, 36.602], dtype=float32),\n", - " '67': array([ 667.74, 241.55, 47.529, 53.978], dtype=float32),\n", - " '68': array([ 822.62, 244.9, 47.754, 51.829], dtype=float32),\n", - " '69': array([ 770.31, 244.16, 46.905, 51.143], dtype=float32),\n", - " '70': array([ 1248.1, 251.22, 31.97, 32.78], dtype=float32),\n", - " '71': array([ 1048.9, 244.81, 45.524, 47], dtype=float32),\n", - " '72': array([ 438.99, 28.466, 35.963, 37.508], dtype=float32),\n", - " '73': array([ 954.49, 181.67, 94.131, 49.485], dtype=float32),\n", - " '74': array([ 363.65, 29.569, 32.883, 36.315], dtype=float32),\n", - " '75': array([ 497.99, 28.813, 32.513, 34.289], dtype=float32),\n", - " '76': array([ 1332.3, 188.4, 32.144, 34.772], dtype=float32),\n", - " '77': array([ 1137.5, 251.92, 34.499, 34.152], dtype=float32),\n", - " '78': array([ 880.41, 249.27, 39.986, 37.76], dtype=float32),\n", - " '79': array([ 954.73, 33.853, 23.966, 27.856], dtype=float32),\n", - " '80': array([ 2888, 21.725, 45.384, 44.9], dtype=float32),\n", - " '81': array([ 1997.3, 222.31, 34.341, 34.824], dtype=float32),\n", - " '82': array([ 625.14, 249.45, 29.211, 36.318], dtype=float32),\n", - " '83': array([ 554.33, 25.136, 36.99, 42.99], dtype=float32),\n", - " '84': array([ 2786.6, 22.758, 52.631, 49.663], dtype=float32),\n", - " '85': array([ 1812.5, 176.7, 57.761, 59.091], dtype=float32),\n", - " '86': array([ 3170.5, 26.987, 44.986, 46.376], dtype=float32),\n", - " '87': array([ 1284.7, 182.54, 32.466, 49.088], dtype=float32),\n", - " '88': array([ 423.68, 280.2, 28.832, 30.261], dtype=float32),\n", - " '89': array([ 1716.8, 179.54, 59.374, 50.265], dtype=float32),\n", - " '90': array([ 344.3, 185.95, 54.543, 42.996], dtype=float32),\n", - " '91': array([ 
1515.7, 252.05, 33.876, 33.092], dtype=float32),\n", - " '92': array([ 1090.8, 243.67, 36.191, 50.504], dtype=float32),\n", - " '93': array([ 1248.6, 189.02, 32.175, 33.261], dtype=float32),\n", - " '94': array([ 963.23, 254.24, 40.774, 33.985], dtype=float32),\n", - " '95': array([ 1717.4, 174.03, 52.504, 48.792], dtype=float32),\n", - " '96': array([ 3075.8, 30.369, 39.574, 38.682], dtype=float32),\n", - " '97': array([ 3187.4, 107.39, 33.734, 40.21], dtype=float32),\n", - " '98': array([ 2966.8, 168.2, 91.261, 105.9], dtype=float32),\n", - " '99': array([ 30.782, 33.378, 33.3, 31.807], dtype=float32),\n", - " '100': array([ 1196.9, 324.51, 27.192, 25.869], dtype=float32),\n", - " '101': array([ 3172.2, 310.76, 44.785, 39.261], dtype=float32),\n", - " '102': array([ 1998.4, 173.44, 30.627, 32.274], dtype=float32),\n", - " '103': array([ 787.33, 241.95, 64.931, 57.226], dtype=float32),\n", - " '104': array([ 2692.7, 21.129, 51.957, 55.543], dtype=float32),\n", - " '105': array([ 1170.2, 247.59, 34.892, 42.735], dtype=float32),\n", - " '106': array([ 1910.2, 174.37, 58.716, 53.398], dtype=float32),\n", - " '107': array([ 2259.9, 225.75, 31.152, 32.17], dtype=float32),\n", - " '108': array([ 254.78, 181.7, 56.533, 44.474], dtype=float32),\n", - " '109': array([ 1047.4, 182.75, 71.279, 49.489], dtype=float32),\n", - " '110': array([ 424.87, 175.44, 29.783, 28.272], dtype=float32),\n", - " '111': array([ 892.42, 233.03, 55.596, 76.345], dtype=float32),\n", - " '112': array([ 2873.1, 180.61, 62.642, 47.626], dtype=float32),\n", - " '113': array([ 1996.8, 278.17, 31.037, 31.622], dtype=float32),\n", - " '114': array([ 11.883, 368.91, 564.39, 1578.2], dtype=float32),\n", - " '115': array([ 422.86, 225.74, 29.434, 28.735], dtype=float32),\n", - " '116': array([ 1794.2, 209.55, 92.729, 62.57], dtype=float32),\n", - " '117': array([ 1702.6, 235.84, 84.464, 33.964], dtype=float32)},\n", - " 'AutoSave')" + " '61': array([ 2965.8, 11.37, 73.327, 65.831], dtype=float32),\n", + " '62': array([ 2963.6, 104.43, 45.38, 45.75], dtype=float32),\n", + " '63': array([ 198.97, 28.514, 79.804, 38.484], dtype=float32),\n", + " '64': array([ 608.16, 181.26, 343, 50.052], dtype=float32),\n", + " '65': array([ 1300.6, 250, 33.764, 35.422], dtype=float32),\n", + " '66': array([ 304.13, 30.345, 37.345, 36.607], dtype=float32),\n", + " '67': array([ 667.74, 241.53, 47.522, 54.004], dtype=float32),\n", + " '68': array([ 822.62, 244.89, 47.756, 51.845], dtype=float32),\n", + " '69': array([ 770.31, 244.15, 46.912, 51.157], dtype=float32),\n", + " '70': array([ 1248.1, 251.23, 31.97, 32.777], dtype=float32),\n", + " '71': array([ 1048.9, 244.81, 45.531, 47.005], dtype=float32),\n", + " '72': array([ 438.98, 28.462, 35.958, 37.513], dtype=float32),\n", + " '73': array([ 954.49, 181.66, 94.134, 49.488], dtype=float32),\n", + " '74': array([ 363.65, 29.564, 32.889, 36.319], dtype=float32),\n", + " '75': array([ 497.99, 28.809, 32.51, 34.293], dtype=float32),\n", + " '76': array([ 1332.3, 188.4, 32.162, 34.782], dtype=float32),\n", + " '77': array([ 1137.5, 251.92, 34.494, 34.152], dtype=float32),\n", + " '78': array([ 880.41, 249.27, 39.99, 37.764], dtype=float32),\n", + " '79': array([ 954.73, 33.854, 23.97, 27.857], dtype=float32),\n", + " '80': array([ 2888, 21.728, 45.381, 44.893], dtype=float32),\n", + " '81': array([ 1997.3, 222.31, 34.338, 34.831], dtype=float32),\n", + " '82': array([ 625.13, 249.45, 29.214, 36.326], dtype=float32),\n", + " '83': array([ 554.33, 25.126, 36.996, 43.004], dtype=float32),\n", + " '84': 
array([ 2786.6, 22.757, 52.635, 49.66], dtype=float32),\n", + " '85': array([ 1812.5, 176.7, 57.783, 59.101], dtype=float32),\n", + " '86': array([ 3170.5, 26.972, 45.001, 46.404], dtype=float32),\n", + " '87': array([ 1284.7, 182.54, 32.421, 49.083], dtype=float32),\n", + " '88': array([ 423.68, 280.19, 28.83, 30.268], dtype=float32),\n", + " '89': array([ 1716.8, 179.53, 59.38, 50.26], dtype=float32),\n", + " '90': array([ 344.29, 185.94, 54.55, 43], dtype=float32),\n", + " '91': array([ 1515.7, 252.05, 33.876, 33.096], dtype=float32),\n", + " '92': array([ 1090.7, 243.67, 36.209, 50.511], dtype=float32),\n", + " '93': array([ 1248.6, 189.02, 32.176, 33.262], dtype=float32),\n", + " '94': array([ 963.23, 254.23, 40.772, 33.994], dtype=float32),\n", + " '95': array([ 1717.4, 174.03, 52.489, 48.778], dtype=float32),\n", + " '96': array([ 3075.8, 30.359, 39.588, 38.699], dtype=float32),\n", + " '97': array([ 3187.4, 107.38, 33.732, 40.21], dtype=float32),\n", + " '98': array([ 2966.8, 168.2, 91.264, 105.89], dtype=float32),\n", + " '99': array([ 30.783, 33.381, 33.3, 31.808], dtype=float32),\n", + " '100': array([ 1196.9, 324.51, 27.192, 25.871], dtype=float32),\n", + " '101': array([ 3172.2, 310.75, 44.795, 39.267], dtype=float32),\n", + " '102': array([ 1998.4, 173.44, 30.628, 32.281], dtype=float32),\n", + " '103': array([ 787.33, 241.94, 64.897, 57.237], dtype=float32),\n", + " '104': array([ 2692.7, 21.127, 51.957, 55.546], dtype=float32),\n", + " '105': array([ 1170.2, 247.59, 34.857, 42.74], dtype=float32),\n", + " '106': array([ 1910.1, 174.37, 58.718, 53.371], dtype=float32),\n", + " '107': array([ 2259.9, 225.75, 31.158, 32.173], dtype=float32),\n", + " '108': array([ 254.78, 181.69, 56.54, 44.482], dtype=float32),\n", + " '109': array([ 1047.4, 182.75, 71.273, 49.486], dtype=float32),\n", + " '110': array([ 892.42, 233.03, 55.595, 76.336], dtype=float32),\n", + " '111': array([ 424.87, 175.43, 29.786, 28.278], dtype=float32),\n", + " '112': array([ 2873.1, 180.62, 62.633, 47.619], dtype=float32),\n", + " '113': array([ 1996.8, 278.16, 31.037, 31.629], dtype=float32),\n", + " '114': array([ 11.928, 369.01, 564.33, 1577.8], dtype=float32),\n", + " '115': array([ 422.86, 225.74, 29.438, 28.741], dtype=float32),\n", + " '116': array([ 1702.6, 235.74, 84.462, 34.063], dtype=float32)},\n", + " ['Text Box ID 0: AutoSave',\n", + " 'Text Box ID 1: Presentation2',\n", + " 'Text Box ID 2: PowerPoint',\n", + " 'Text Box ID 3: General*',\n", + " 'Text Box ID 4: Search',\n", + " 'Text Box ID 5: Yadong',\n", + " 'Text Box ID 6: File',\n", + " 'Text Box ID 7: Home',\n", + " 'Text Box ID 8: Insert',\n", + " 'Text Box ID 9: Draw',\n", + " 'Text Box ID 10: Design',\n", + " 'Text Box ID 11: Transitions',\n", + " 'Text Box ID 12: Animations',\n", + " 'Text Box ID 13: Slide Show',\n", + " 'Text Box ID 14: Record',\n", + " 'Text Box ID 15: Review',\n", + " 'Text Box ID 16: View',\n", + " 'Text Box ID 17: Help',\n", + " 'Text Box ID 18: Record',\n", + " 'Text Box ID 19: Present in Teams',\n", + " 'Text Box ID 20: Share',\n", + " 'Text Box ID 21: Layout',\n", + " 'Text Box ID 22: A\" | A',\n", + " 'Text Box ID 23: 8 =#~',\n", + " 'Text Box ID 24: Shape',\n", + " 'Text Box ID 25: Find',\n", + " 'Text Box ID 26: Paste',\n", + " 'Text Box ID 27: New',\n", + " 'Text Box ID 28: Reuse',\n", + " 'Text Box ID 29: Reset',\n", + " 'Text Box ID 30: [t]',\n", + " 'Text Box ID 31: Shapes Arrange',\n", + " 'Text Box ID 32: Quick',\n", + " 'Text Box ID 33: Shape Outline',\n", + " 'Text Box ID 34: Replace',\n", + " 
'Text Box ID 35: Dictate',\n", + " 'Text Box ID 36: Sensitivity',\n", + " 'Text Box ID 37: Add-ins',\n", + " 'Text Box ID 38: Designer Copilot',\n", + " 'Text Box ID 39: 4',\n", + " 'Text Box ID 40: Aa ~',\n", + " 'Text Box ID 41: 22E6',\n", + " 'Text Box ID 42: Slide',\n", + " 'Text Box ID 43: Slides',\n", + " 'Text Box ID 44: Section',\n", + " 'Text Box ID 45: Styles',\n", + " 'Text Box ID 46: Effects',\n", + " 'Text Box ID 47: Select',\n", + " 'Text Box ID 48: Clipboard',\n", + " 'Text Box ID 49: Slides',\n", + " 'Text Box ID 50: Font',\n", + " 'Text Box ID 51: Paragraph',\n", + " 'Text Box ID 52: Drawing',\n", + " 'Text Box ID 53: Editing',\n", + " 'Text Box ID 54: Voice',\n", + " 'Text Box ID 55: Sensitivity',\n", + " 'Text Box ID 56: Add-ins',\n", + " 'Text Box ID 57: Click to add title',\n", + " 'Text Box ID 58: Click to add subtitle',\n", + " 'Text Box ID 59: Click to add notes',\n", + " 'Text Box ID 60: Shape'])" ] }, - "execution_count": 13, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "label_coordinates, parsed_content_list[0].split(': ')[1]" + "label_coordinates, parsed_content_list#[0].split(': ')[1]" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 9, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" }, @@ -640,238 +722,6 @@ "plt.imshow(image)\n" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# wrapped Omniparser" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parsing image: examples/pc_1.png\n", - "\n", - "image 1/1 /home/yadonglu/sandbox/screenparsing_collab/screenparsing/omniparser/examples/pc_1.png: 800x1280 210 icons, 55.6ms\n", - "Speed: 7.7ms preprocess, 55.6ms inference, 1.5ms postprocess per image at shape (1, 3, 800, 1280)\n", - "boxes cpu\n", - "Time taken for Omniparser on cpu: 2.029506206512451\n" - ] - } - ], - "source": [ - "from utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_dino_model, get_yolo_model\n", - "import torch\n", - "from ultralytics import YOLO\n", - "from PIL import Image\n", - "from typing import Dict, Tuple, List\n", - "import io\n", - "import base64\n", - "\n", - "\n", - "config = {\n", - " 'som_model_path': 'finetuned_icon_detect.pt',\n", - " 'device': 'cpu',\n", - " 'caption_model_path': 'Salesforce/blip2-opt-2.7b',\n", - " 'draw_bbox_config': {\n", - " 'text_scale': 0.8,\n", - " 'text_thickness': 2,\n", - " 'text_padding': 3,\n", - " 'thickness': 3,\n", - " },\n", - " 'BOX_TRESHOLD': 0.05\n", - "}\n", - "\n", - "\n", - "class Omniparser(object):\n", - " def __init__(self, config: Dict):\n", - " self.config = config\n", - " \n", - " self.som_model = get_yolo_model(model_path=config['som_model_path'])\n", - " # self.caption_model_processor = get_caption_model_processor(config['caption_model_path'], device=cofig['device'])\n", - " # self.caption_model_processor['model'].to(torch.float32)\n", - "\n", - " def parse(self, image_path: str):\n", - " print('Parsing image:', image_path)\n", - " ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9})\n", - " text, ocr_bbox = ocr_bbox_rslt\n", - "\n", - " draw_bbox_config = self.config['draw_bbox_config']\n", - " 
BOX_TRESHOLD = self.config['BOX_TRESHOLD']\n", - " dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, self.som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=False, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=None, ocr_text=text,use_local_semantics=False)\n", - " \n", - " image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))\n", - " # formating output\n", - " return_list = [{'from': 'omniparser', 'shape': {'x':coord[0], 'y':coord[1], 'width':coord[2], 'height':coord[3]},\n", - " 'text': parsed_content_list[i].split(': ')[1], 'type':'text'} for i, (k, coord) in enumerate(label_coordinates.items()) if i < len(parsed_content_list)]\n", - " return_list.extend(\n", - " [{'from': 'omniparser', 'shape': {'x':coord[0], 'y':coord[1], 'width':coord[2], 'height':coord[3]},\n", - " 'text': 'None', 'type':'icon'} for i, (k, coord) in enumerate(label_coordinates.items()) if i >= len(parsed_content_list)]\n", - " )\n", - "\n", - " return [image, return_list]\n", - " \n", - "parser = Omniparser(config)\n", - "image_path = 'imgs/pc_1.png'\n", - "\n", - "# time the parser\n", - "import time\n", - "s = time.time()\n", - "image, parsed_content_list = parser.parse(image_path)\n", - "device = config['device']\n", - "print(f'Time taken for Omniparser on {device}:', time.time() - s)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0: 800x1280 210 icons, 49.4ms\n", - "Speed: 5.7ms preprocess, 49.4ms inference, 1.1ms postprocess per image at shape (1, 3, 800, 1280)\n", - "boxes cpu\n", - "Time taken for Omniparser finetuned YOLO module on cpu: 0.2898883819580078\n" - ] - } - ], - "source": [ - "from utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_dino_model, get_yolo_model, predict_yolo\n", - "import torch\n", - "from ultralytics import YOLO\n", - "from PIL import Image\n", - "from typing import Dict, Tuple, List\n", - "import io\n", - "import base64\n", - "\n", - "\n", - "config = {\n", - " 'som_model_path': 'finetuned_icon_detect.pt',\n", - " 'device': 'cpu',\n", - " 'caption_model_path': 'Salesforce/blip2-opt-2.7b',\n", - " 'draw_bbox_config': {\n", - " 'text_scale': 0.8,\n", - " 'text_thickness': 2,\n", - " 'text_padding': 3,\n", - " 'thickness': 3,\n", - " },\n", - " 'BOX_TRESHOLD': 0.05\n", - "}\n", - "\n", - "class OmniparserYOLO(object):\n", - " def __init__(self, config: Dict):\n", - " self.config = config\n", - " self.som_model = get_yolo_model(model_path=config['som_model_path'])\n", - "\n", - " def parse(self, image):\n", - " draw_bbox_config = self.config['draw_bbox_config']\n", - " BOX_TRESHOLD = self.config['BOX_TRESHOLD']\n", - " xyxy, logits, phrases = predict_yolo(model=self.som_model, image_path=image, box_threshold=BOX_TRESHOLD)\n", - " # print('xyxy:', xyxy)\n", - " xyxy = xyxy.tolist()\n", - " # formating output\n", - " return_list = [{'from': 'omniparserYOLO', 'shape': {'x':coord[0], 'y':coord[1], 'width':coord[2]-coord[0], 'height':coord[3] - coord[1]},\n", - " 'text': 'None', 'type':'icon'} for i, coord in enumerate(xyxy)]\n", - " \n", - " return [None, return_list]\n", - " \n", - "parser = OmniparserYOLO(config)\n", - "image_path = 'imgs/pc_1.png'\n", - "image = Image.open(image_path)\n", - "\n", - "# time the parser\n", - "import time\n", - "s = time.time()\n", - "_, 
parsed_content_list = parser.parse(image)\n", - "device = config['device']\n", - "print(f'Time taken for Omniparser finetuned YOLO module on {device}:', time.time() - s)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# florence caption model" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/yadonglu/anaconda3/envs/pilot/lib/python3.9/site-packages/transformers/utils/generic.py:342: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", - " _torch_pytree._register_pytree_node(\n", - "/home/yadonglu/anaconda3/envs/pilot/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", - " warnings.warn(\n" - ] - } - ], - "source": [ - "from transformers import AutoProcessor, AutoModelForCausalLM \n", - "import torch\n", - "device = 'cpu'\n", - "torch_dtype = torch.float16 if device == 'cuda' else torch.float32\n", - "model = AutoModelForCausalLM.from_pretrained(\"/home/yadonglu/sandbox/data/orca/florence-2-base-ft-fft_rai_win_ep5/epoch_5\", torch_dtype=torch_dtype, trust_remote_code=True).to(device)\n", - "processor = AutoProcessor.from_pretrained(\"microsoft/Florence-2-base\", trust_remote_code=True)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['settings or configuration options.']" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from PIL import Image\n", - "prompt = \"\"\n", - "image_path = 'imgs/settings.png'\n", - "image = [Image.open(image_path).convert('RGB')]\n", - "inputs = processor(images=image, text=[prompt]*len(image), return_tensors=\"pt\").to(device=device)\n", - "generated_ids = model.generate(input_ids=inputs[\"input_ids\"],pixel_values=inputs[\"pixel_values\"],max_new_tokens=1024,num_beams=3, do_sample=False)\n", - "generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)\n", - "generated_text = [gen.strip() for gen in generated_text]\n", - "generated_text" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import cv2" - ] - }, { "cell_type": "code", "execution_count": null, @@ -896,7 +746,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.18" + "version": "3.12.0" } }, "nbformat": 4, diff --git a/util/__pycache__/__init__.cpython-312.pyc b/util/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..cf380e0 Binary files /dev/null and b/util/__pycache__/__init__.cpython-312.pyc differ diff --git a/util/__pycache__/box_annotator.cpython-312.pyc b/util/__pycache__/box_annotator.cpython-312.pyc new file mode 100644 index 0000000..678898d Binary files /dev/null and b/util/__pycache__/box_annotator.cpython-312.pyc differ diff --git a/utils.py b/utils.py index 40fb079..2fc180d 100755 --- a/utils.py +++ b/utils.py @@ -18,7 +18,7 @@ import numpy as np # %matplotlib inline from matplotlib import pyplot as plt import easyocr -reader = easyocr.Reader(['en']) # this needs to run only once to load the model into memory # 'ch_sim', +reader = easyocr.Reader(['en']) import time 
import base64 @@ -33,44 +33,19 @@ import supervision as sv import torchvision.transforms as T -def get_caption_model_processor(model_name="Salesforce/blip2-opt-2.7b", device=None): +def get_caption_model_processor(model_name_or_path="Salesforce/blip2-opt-2.7b", device=None): if not device: device = "cuda" if torch.cuda.is_available() else "cpu" - if model_name == "Salesforce/blip2-opt-2.7b": - from transformers import Blip2Processor, Blip2ForConditionalGeneration - processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") + from transformers import Blip2Processor, Blip2ForConditionalGeneration + processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") + if device == 'cpu': model = Blip2ForConditionalGeneration.from_pretrained( - "Salesforce/blip2-opt-2.7b", device_map=None, torch_dtype=torch.float16 - # '/home/yadonglu/sandbox/data/orca/blipv2_ui_merge', device_map=None, torch_dtype=torch.float16 - ) - elif model_name == "blip2-opt-2.7b-ui": - from transformers import Blip2Processor, Blip2ForConditionalGeneration - processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") - if device == 'cpu': - model = Blip2ForConditionalGeneration.from_pretrained( - '/home/yadonglu/sandbox/data/orca/blipv2_ui_merge', device_map=None, torch_dtype=torch.float32 - ) - else: - model = Blip2ForConditionalGeneration.from_pretrained( - '/home/yadonglu/sandbox/data/orca/blipv2_ui_merge', device_map=None, torch_dtype=torch.float16 - ) - elif model_name == "florence": - from transformers import AutoProcessor, AutoModelForCausalLM - processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True) - if device == 'cpu': - model = AutoModelForCausalLM.from_pretrained("/home/yadonglu/sandbox/data/orca/florence-2-base-ft-fft_ep1_rai", torch_dtype=torch.float32, trust_remote_code=True)#.to(device) - else: - model = AutoModelForCausalLM.from_pretrained("/home/yadonglu/sandbox/data/orca/florence-2-base-ft-fft_ep1_rai_win_ep5_fixed", torch_dtype=torch.float16, trust_remote_code=True).to(device) - elif model_name == 'phi3v_ui': - from transformers import AutoModelForCausalLM, AutoProcessor - model_id = "microsoft/Phi-3-vision-128k-instruct" - model = AutoModelForCausalLM.from_pretrained('/home/yadonglu/sandbox/data/orca/phi3v_ui', device_map=device, trust_remote_code=True, torch_dtype="auto") - processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) - elif model_name == 'phi3v': - from transformers import AutoModelForCausalLM, AutoProcessor - model_id = "microsoft/Phi-3-vision-128k-instruct" - model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, trust_remote_code=True, torch_dtype="auto") - processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) + model_name_or_path, device_map=None, torch_dtype=torch.float32 + ) + else: + model = Blip2ForConditionalGeneration.from_pretrained( + model_name_or_path, device_map=None, torch_dtype=torch.float16 + ) return {'model': model.to(device), 'processor': processor} @@ -94,14 +69,12 @@ def get_parsed_content_icon(filtered_boxes, ocr_bbox, image_source, caption_mode cropped_image = image_source[ymin:ymax, xmin:xmax, :] croped_pil_image.append(to_pil(cropped_image)) - # import pdb; pdb.set_trace() model, processor = caption_model_processor['model'], caption_model_processor['processor'] if not prompt: if 'florence' in model.config.name_or_path: prompt = "" else: prompt = "The image shows" - # prompt = "NO gender!NO gender!NO gender! 
The image shows a icon:" batch_size = 10 # Number of samples per batch generated_texts = [] @@ -387,117 +360,15 @@ def get_xywh_yolo(input): return x, y, w, h -def run_api(body, max_tokens=1024): - ''' - API call, check https://platform.openai.com/docs/guides/vision for the latest api usage. - ''' - max_num_trial = 3 - num_trial = 0 - while num_trial < max_num_trial: - try: - response = client.chat.completions.create( - model=deployment, - messages=body, - temperature=0.01, - max_tokens=max_tokens, - ) - return response.choices[0].message.content - except: - print('retry call gptv', num_trial) - num_trial += 1 - time.sleep(10) - return '' - -def call_gpt4v_new(message_text, image_path=None, max_tokens=2048): - if image_path: - try: - with open(image_path, "rb") as img_file: - encoded_image = base64.b64encode(img_file.read()).decode('ascii') - except: - encoded_image = image_path - - if image_path: - content = [{"type": "image_url","image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}, {"type": "text","text": message_text},] - else: - content = [{"type": "text","text": message_text},] - - max_num_trial = 3 - num_trial = 0 - call_api_success = True - - while num_trial < max_num_trial: - try: - response = client.chat.completions.create( - model=deployment, - messages=[ - { - "role": "system", - "content": [ - { - "type": "text", - "text": "You are an AI assistant that is good at making plans and analyzing screens, and helping people find information." - }, - ] - }, - { - "role": "user", - "content": content - } - ], - temperature=0.01, - max_tokens=max_tokens, - ) - ans_1st_pass = response.choices[0].message.content - break - except: - print('retry call gptv', num_trial) - num_trial += 1 - ans_1st_pass = '' - time.sleep(10) - if num_trial == max_num_trial: - call_api_success = False - return ans_1st_pass, call_api_success - def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None): if easyocr_args is None: easyocr_args = {} result = reader.readtext(image_path, **easyocr_args) is_goal_filtered = False - if goal_filtering: - ocr_filter_fs = "Example 1:\n Based on task and ocr results, ```In summary, the task related bboxes are: [([[3060, 111], [3135, 111], [3135, 141], [3060, 141]], 'Share', 0.949013667261589), ([[3068, 197], [3135, 197], [3135, 227], [3068, 227]], 'Link _', 0.3567054243152049), ([[3006, 321], [3178, 321], [3178, 354], [3006, 354]], 'Manage Access', 0.8800734456437066)] ``` \n Example 2:\n Based on task and ocr results, ```In summary, the task related bboxes are: [([[3060, 111], [3135, 111], [3135, 141], [3060, 141]], 'Search Google or type a URL', 0.949013667261589)] ```" - # message_text = f"Based on the ocr results which contains text+bounding box in a dictionary, please filter it so that it only contains the task related bboxes. The task is: {goal_filtering}, the ocr results are: {str(result)}. Your final answer should be in the exact same format as the ocr results, please do not include any other redundant information, please do not include any analysis." - message_text = f"Based on the task and ocr results which contains text+bounding box in a dictionary, please filter it so that it only contains the task related bboxes. Requirement: 1. first give a brief analysis. 2. provide an answer in the format: ```In summary, the task related bboxes are: ..```, you must put it inside ``` ```. Do not include any info after ```.\n {ocr_filter_fs}\n The task is: {goal_filtering}, the ocr results are: {str(result)}." 
- - prompt = [{"role":"system", "content": "You are an AI assistant that helps people find the correct way to operate computer or smartphone."}, {"role":"user","content": message_text},] - print('[Perform OCR filtering by goal] ongoing ...') - # pred, _, _ = call_gpt4(prompt) - pred, _, = call_gpt4v(message_text) - # import pdb; pdb.set_trace() - try: - # match = re.search(r"```(.*?)```", pred, re.DOTALL) - # result = match.group(1).strip() - # pred = result.split('In summary, the task related bboxes are:')[-1].strip() - pred = pred.split('In summary, the task related bboxes are:')[-1].strip().strip('```') - result = ast.literal_eval(pred) - print('[Perform OCR filtering by goal] success!!! Filtered buttons: ', pred) - is_goal_filtered = True - except: - print('[Perform OCR filtering by goal] failed or unused!!!') - pass - # added_prompt = [{"role":"assistant","content":pred}, - # {"role":"user","content": "given the previous answers, please provide the final answer in the exact same format as the ocr results, please do not include any other redundant information, please do not include any analysis."}] - # prompt.extend(added_prompt) - # pred, _, _ = call_gpt4(prompt) - # print('goal filtering pred 2nd:', pred) - # result = ast.literal_eval(pred) # print('goal filtering pred:', result[-5:]) coord = [item[0] for item in result] text = [item[1] for item in result] - # confidence = [item[2] for item in result] - # if confidence_filtering: - # coord = [coord[i] for i in range(len(coord)) if confidence[i] > confidence_filtering] - # text = [text[i] for i in range(len(text)) if confidence[i] > confidence_filtering] # read the image using cv2 if display_img: opencv_img = cv2.imread(image_path) @@ -520,87 +391,4 @@ def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_ return (text, bb), is_goal_filtered -def get_pred_gptv(message_text, yolo_labled_img, label_coordinates, summarize_history=True, verbose=True, history=None, id_key='Click ID'): - """ This func first - 1. call gptv(yolo_labled_img, text bbox+task) -> ans_1st_cal - 2. 
call gpt4(ans_1st_cal, label_coordinates) -> final ans - """ - - # Configuration - encoded_image = yolo_labled_img - - # Payload for the request - if not history: - messages = [ - {"role": "system", "content": [{"type": "text","text": "You are an AI assistant that is great at interpreting screenshot and predict action."},]}, - {"role": "user","content": [{"type": "text","text": message_text}, {"type": "image_url","image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},]} - ] - else: - messages = [ - {"role": "system", "content": [{"type": "text","text": "You are an AI assistant that is great at interpreting screenshot and predict action."},]}, - history, - {"role": "user","content": [{"type": "image_url","image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},{"type": "text","text": message_text},]} - ] - - payload = { - "messages": messages, - "temperature": 0.01, # 0.01 - "top_p": 0.95, - "max_tokens": 800 - } - - max_num_trial = 3 - num_trial = 0 - call_api_success = True - while num_trial < max_num_trial: - try: - # response = requests.post(GPT4V_ENDPOINT, headers=headers, json=payload) - # response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code - # ans_1st_pass = response.json()['choices'][0]['message']['content'] - response = client.chat.completions.create( - model=deployment, - messages=messages, - temperature=0.01, - max_tokens=512, - ) - ans_1st_pass = response.choices[0].message.content - break - except requests.RequestException as e: - print('retry call gptv', num_trial) - num_trial += 1 - ans_1st_pass = '' - time.sleep(30) - # raise SystemExit(f"Failed to make the request. Error: {e}") - if num_trial == max_num_trial: - call_api_success = False - if verbose: - print('Answer by GPTV: ', ans_1st_pass) - # extract by simple parsing - try: - match = re.search(r"```(.*?)```", ans_1st_pass, re.DOTALL) - if match: - result = match.group(1).strip() - pred = result.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '') - pred = ast.literal_eval(pred) - else: - pred = ans_1st_pass.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '') - pred = ast.literal_eval(pred) - - if id_key in pred: - icon_id = pred[id_key] - bbox = label_coordinates[str(icon_id)] - pred['click_point'] = [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2] - except: - # import pdb; pdb.set_trace() - print('gptv action regex extract fail!!!') - print('ans_1st_pass:', ans_1st_pass) - pred = {'action_type': 'CLICK', 'click_point': [0, 0], 'value': 'None', 'is_completed': False} - - step_pred_summary = None - if summarize_history: - step_pred_summary, _ = call_gpt4v_new('Summarize what action you decide to perform in the current step, in one sentence, and do not include any icon box number: ' + ans_1st_pass, max_tokens=128) - print('step_pred_summary', step_pred_summary) - return pred, [call_api_success, ans_1st_pass, None, step_pred_summary] - # return pred, [call_api_success, message_2nd, completion_2nd.choices[0].message.content, step_pred_summary] -
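
Usage note (not part of the diff): after this change, `demo.ipynb` loads the finetuned YOLO detector and the merged BLIP-2 captioner from local weights instead of hardcoded sandbox paths. A minimal standalone sketch of that pipeline, condensed from the updated notebook cells; the `BOX_TRESHOLD` value of 0.05 and the `draw_bbox_config` dict are copied from the wrapper config this diff deletes, so treat them as assumed defaults rather than fixed settings:

```python
# Sketch of the post-change demo.ipynb flow. Assumes the new repo layout
# with weights under weights/omniparser/ (see the notebook diff above).
import base64
import io

import torch
from PIL import Image

from utils import (check_ocr_box, get_caption_model_processor,
                   get_som_labeled_img, get_yolo_model)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Finetuned YOLO icon detector plus the merged BLIP-2 UI captioner.
som_model = get_yolo_model(model_path='weights/omniparser/weights/best.pt')
som_model.to(device)
caption_model_processor = get_caption_model_processor(
    model_name_or_path='weights/omniparser/blipv2_ui_merge', device=device)

image_path = 'imgs/pc_1.png'

# EasyOCR pass: recognized strings plus their boxes in xyxy format.
(text, ocr_bbox), _ = check_ocr_box(
    image_path, display_img=False, output_bb_format='xyxy',
    goal_filtering=None,
    easyocr_args={'paragraph': False, 'text_threshold': 0.9})

# Overlay style copied from the removed wrapper config (assumed defaults).
draw_bbox_config = {'text_scale': 0.8, 'text_thickness': 2,
                    'text_padding': 3, 'thickness': 3}

# Set-of-Marks labeling: merge OCR boxes with YOLO detections.
dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
    image_path, som_model, BOX_TRESHOLD=0.05, output_coord_in_ratio=False,
    ocr_bbox=ocr_bbox, draw_bbox_config=draw_bbox_config,
    caption_model_processor=caption_model_processor, ocr_text=text,
    use_local_semantics=False)

# The annotated screenshot is returned base64-encoded.
annotated = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
```

Note that `get_caption_model_processor` now loads the captioner in `float32` on CPU and `float16` otherwise, per the simplified branch in `utils.py`.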
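
As the cell output above shows, `label_coordinates` maps a string box ID to an `[x, y, width, height]` array, and `parsed_content_list` holds `'Text Box ID n: <text>'` strings. A small illustrative helper (hypothetical names, not part of the diff) that mirrors how the removed `get_pred_gptv` turned a box ID into a click point:

```python
def click_point(label_coordinates, box_id):
    # Center of the box, as in the removed get_pred_gptv:
    # [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2]
    x, y, w, h = label_coordinates[str(box_id)]
    return [float(x + w / 2), float(y + h / 2)]

def box_text(parsed_content_list, box_id):
    # 'Text Box ID 0: AutoSave' -> 'AutoSave'
    return parsed_content_list[box_id].split(': ', 1)[1]
```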
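
The new stderr output captured in the notebook comes from `ultralytics/nn/tasks.py` calling `torch.load` with the `weights_only=False` default, so it cannot be silenced from this repo's code. If you load the checkpoint directly and trust its origin, passing the argument explicitly avoids the `FutureWarning` on the PyTorch versions that emit it (a sketch, not part of this change):

```python
import torch

# Trusted local checkpoint only: pickle-based loading can execute arbitrary
# code, which is exactly what the FutureWarning is about.
ckpt = torch.load('weights/omniparser/weights/best.pt',
                  map_location='cpu', weights_only=False)
```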