update readme; demo

This commit is contained in:
yadonglu
2024-10-09 22:31:38 +00:00
parent 4a8758b865
commit 664407ae9b
6 changed files with 174 additions and 535 deletions

View File

@@ -18,7 +18,8 @@
## Install
```python
conda create -n "omni" python==3.12
pip install -r requirements.txt
conda activate omni
pip install -r requirement.txt
```
## Examples:

Binary file not shown.

View File

@@ -2,9 +2,17 @@
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/yadonglu/sandbox/miniconda/envs/omni/lib/python3.12/site-packages/ultralytics/nn/tasks.py:714: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" ckpt = torch.load(file, map_location=\"cpu\")\n"
]
},
{
"data": {
"text/plain": [
@@ -365,7 +373,7 @@
")"
]
},
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -377,19 +385,34 @@
"from PIL import Image\n",
"device = 'cuda'\n",
"\n",
"# dino_model = get_dino_model(load_hf_model=True, device=device)\n",
"som_model = get_yolo_model(model_path='omniparser/weights/best.pt')\n",
"\n",
"# caption_model_processor = get_caption_model_processor(\"Salesforce/blip2-opt-2.7b\", device=device)\n",
"# caption_model_processor['model'].to(torch.float32)\n",
"som_model.to(device)\n",
"\n"
"som_model = get_yolo_model(model_path='weights/omniparser/weights/best.pt')\n",
"som_model.to(device)\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00, 1.98it/s]\n"
]
}
],
"source": [
"\n",
"caption_model_processor = get_caption_model_processor(model_name_or_path=\"weights/omniparser/blipv2_ui_merge\", device=device)\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
@@ -397,7 +420,7 @@
"(device(type='cuda', index=0), ultralytics.models.yolo.model.YOLO)"
]
},
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@@ -408,7 +431,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"outputs": [
{
@@ -416,8 +439,8 @@
"output_type": "stream",
"text": [
"\n",
"image 1/1 /home/yadonglu/sandbox/OmniParser/imgs/pc_1.png: 800x1280 211 icons, 51.2ms\n",
"Speed: 3.7ms preprocess, 51.2ms inference, 160.7ms postprocess per image at shape (1, 3, 800, 1280)\n"
"image 1/1 /home/yadonglu/sandbox/OmniParser/imgs/pc_1.png: 800x1280 211 icons, 29.0ms\n",
"Speed: 4.1ms preprocess, 29.0ms inference, 121.1ms postprocess per image at shape (1, 3, 800, 1280)\n"
]
}
],
@@ -458,14 +481,14 @@
"ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9})\n",
"text, ocr_bbox = ocr_bbox_rslt\n",
"\n",
"dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=False, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=None, ocr_text=text,use_local_semantics=False)\n",
"dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=False, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text,use_local_semantics=False)\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 10,
"metadata": {},
"outputs": [
{
@@ -532,87 +555,146 @@
" '58': array([ 1617, 1187, 564, 63], dtype=float32),\n",
" '59': array([ 602, 1944, 242, 32], dtype=float32),\n",
" '60': array([ 2034, 277, 68, 39], dtype=float32),\n",
" '61': array([ 2965.8, 11.372, 73.33, 65.827], dtype=float32),\n",
" '62': array([ 2963.6, 104.42, 45.381, 45.756], dtype=float32),\n",
" '63': array([ 198.97, 28.516, 79.805, 38.483], dtype=float32),\n",
" '64': array([ 608.18, 181.26, 342.97, 50.053], dtype=float32),\n",
" '65': array([ 1300.6, 250, 33.767, 35.427], dtype=float32),\n",
" '66': array([ 304.13, 30.349, 37.342, 36.602], dtype=float32),\n",
" '67': array([ 667.74, 241.55, 47.529, 53.978], dtype=float32),\n",
" '68': array([ 822.62, 244.9, 47.754, 51.829], dtype=float32),\n",
" '69': array([ 770.31, 244.16, 46.905, 51.143], dtype=float32),\n",
" '70': array([ 1248.1, 251.22, 31.97, 32.78], dtype=float32),\n",
" '71': array([ 1048.9, 244.81, 45.524, 47], dtype=float32),\n",
" '72': array([ 438.99, 28.466, 35.963, 37.508], dtype=float32),\n",
" '73': array([ 954.49, 181.67, 94.131, 49.485], dtype=float32),\n",
" '74': array([ 363.65, 29.569, 32.883, 36.315], dtype=float32),\n",
" '75': array([ 497.99, 28.813, 32.513, 34.289], dtype=float32),\n",
" '76': array([ 1332.3, 188.4, 32.144, 34.772], dtype=float32),\n",
" '77': array([ 1137.5, 251.92, 34.499, 34.152], dtype=float32),\n",
" '78': array([ 880.41, 249.27, 39.986, 37.76], dtype=float32),\n",
" '79': array([ 954.73, 33.853, 23.966, 27.856], dtype=float32),\n",
" '80': array([ 2888, 21.725, 45.384, 44.9], dtype=float32),\n",
" '81': array([ 1997.3, 222.31, 34.341, 34.824], dtype=float32),\n",
" '82': array([ 625.14, 249.45, 29.211, 36.318], dtype=float32),\n",
" '83': array([ 554.33, 25.136, 36.99, 42.99], dtype=float32),\n",
" '84': array([ 2786.6, 22.758, 52.631, 49.663], dtype=float32),\n",
" '85': array([ 1812.5, 176.7, 57.761, 59.091], dtype=float32),\n",
" '86': array([ 3170.5, 26.987, 44.986, 46.376], dtype=float32),\n",
" '87': array([ 1284.7, 182.54, 32.466, 49.088], dtype=float32),\n",
" '88': array([ 423.68, 280.2, 28.832, 30.261], dtype=float32),\n",
" '89': array([ 1716.8, 179.54, 59.374, 50.265], dtype=float32),\n",
" '90': array([ 344.3, 185.95, 54.543, 42.996], dtype=float32),\n",
" '91': array([ 1515.7, 252.05, 33.876, 33.092], dtype=float32),\n",
" '92': array([ 1090.8, 243.67, 36.191, 50.504], dtype=float32),\n",
" '93': array([ 1248.6, 189.02, 32.175, 33.261], dtype=float32),\n",
" '94': array([ 963.23, 254.24, 40.774, 33.985], dtype=float32),\n",
" '95': array([ 1717.4, 174.03, 52.504, 48.792], dtype=float32),\n",
" '96': array([ 3075.8, 30.369, 39.574, 38.682], dtype=float32),\n",
" '97': array([ 3187.4, 107.39, 33.734, 40.21], dtype=float32),\n",
" '98': array([ 2966.8, 168.2, 91.261, 105.9], dtype=float32),\n",
" '99': array([ 30.782, 33.378, 33.3, 31.807], dtype=float32),\n",
" '100': array([ 1196.9, 324.51, 27.192, 25.869], dtype=float32),\n",
" '101': array([ 3172.2, 310.76, 44.785, 39.261], dtype=float32),\n",
" '102': array([ 1998.4, 173.44, 30.627, 32.274], dtype=float32),\n",
" '103': array([ 787.33, 241.95, 64.931, 57.226], dtype=float32),\n",
" '104': array([ 2692.7, 21.129, 51.957, 55.543], dtype=float32),\n",
" '105': array([ 1170.2, 247.59, 34.892, 42.735], dtype=float32),\n",
" '106': array([ 1910.2, 174.37, 58.716, 53.398], dtype=float32),\n",
" '107': array([ 2259.9, 225.75, 31.152, 32.17], dtype=float32),\n",
" '108': array([ 254.78, 181.7, 56.533, 44.474], dtype=float32),\n",
" '109': array([ 1047.4, 182.75, 71.279, 49.489], dtype=float32),\n",
" '110': array([ 424.87, 175.44, 29.783, 28.272], dtype=float32),\n",
" '111': array([ 892.42, 233.03, 55.596, 76.345], dtype=float32),\n",
" '112': array([ 2873.1, 180.61, 62.642, 47.626], dtype=float32),\n",
" '113': array([ 1996.8, 278.17, 31.037, 31.622], dtype=float32),\n",
" '114': array([ 11.883, 368.91, 564.39, 1578.2], dtype=float32),\n",
" '115': array([ 422.86, 225.74, 29.434, 28.735], dtype=float32),\n",
" '116': array([ 1794.2, 209.55, 92.729, 62.57], dtype=float32),\n",
" '117': array([ 1702.6, 235.84, 84.464, 33.964], dtype=float32)},\n",
" 'AutoSave')"
" '61': array([ 2965.8, 11.37, 73.327, 65.831], dtype=float32),\n",
" '62': array([ 2963.6, 104.43, 45.38, 45.75], dtype=float32),\n",
" '63': array([ 198.97, 28.514, 79.804, 38.484], dtype=float32),\n",
" '64': array([ 608.16, 181.26, 343, 50.052], dtype=float32),\n",
" '65': array([ 1300.6, 250, 33.764, 35.422], dtype=float32),\n",
" '66': array([ 304.13, 30.345, 37.345, 36.607], dtype=float32),\n",
" '67': array([ 667.74, 241.53, 47.522, 54.004], dtype=float32),\n",
" '68': array([ 822.62, 244.89, 47.756, 51.845], dtype=float32),\n",
" '69': array([ 770.31, 244.15, 46.912, 51.157], dtype=float32),\n",
" '70': array([ 1248.1, 251.23, 31.97, 32.777], dtype=float32),\n",
" '71': array([ 1048.9, 244.81, 45.531, 47.005], dtype=float32),\n",
" '72': array([ 438.98, 28.462, 35.958, 37.513], dtype=float32),\n",
" '73': array([ 954.49, 181.66, 94.134, 49.488], dtype=float32),\n",
" '74': array([ 363.65, 29.564, 32.889, 36.319], dtype=float32),\n",
" '75': array([ 497.99, 28.809, 32.51, 34.293], dtype=float32),\n",
" '76': array([ 1332.3, 188.4, 32.162, 34.782], dtype=float32),\n",
" '77': array([ 1137.5, 251.92, 34.494, 34.152], dtype=float32),\n",
" '78': array([ 880.41, 249.27, 39.99, 37.764], dtype=float32),\n",
" '79': array([ 954.73, 33.854, 23.97, 27.857], dtype=float32),\n",
" '80': array([ 2888, 21.728, 45.381, 44.893], dtype=float32),\n",
" '81': array([ 1997.3, 222.31, 34.338, 34.831], dtype=float32),\n",
" '82': array([ 625.13, 249.45, 29.214, 36.326], dtype=float32),\n",
" '83': array([ 554.33, 25.126, 36.996, 43.004], dtype=float32),\n",
" '84': array([ 2786.6, 22.757, 52.635, 49.66], dtype=float32),\n",
" '85': array([ 1812.5, 176.7, 57.783, 59.101], dtype=float32),\n",
" '86': array([ 3170.5, 26.972, 45.001, 46.404], dtype=float32),\n",
" '87': array([ 1284.7, 182.54, 32.421, 49.083], dtype=float32),\n",
" '88': array([ 423.68, 280.19, 28.83, 30.268], dtype=float32),\n",
" '89': array([ 1716.8, 179.53, 59.38, 50.26], dtype=float32),\n",
" '90': array([ 344.29, 185.94, 54.55, 43], dtype=float32),\n",
" '91': array([ 1515.7, 252.05, 33.876, 33.096], dtype=float32),\n",
" '92': array([ 1090.7, 243.67, 36.209, 50.511], dtype=float32),\n",
" '93': array([ 1248.6, 189.02, 32.176, 33.262], dtype=float32),\n",
" '94': array([ 963.23, 254.23, 40.772, 33.994], dtype=float32),\n",
" '95': array([ 1717.4, 174.03, 52.489, 48.778], dtype=float32),\n",
" '96': array([ 3075.8, 30.359, 39.588, 38.699], dtype=float32),\n",
" '97': array([ 3187.4, 107.38, 33.732, 40.21], dtype=float32),\n",
" '98': array([ 2966.8, 168.2, 91.264, 105.89], dtype=float32),\n",
" '99': array([ 30.783, 33.381, 33.3, 31.808], dtype=float32),\n",
" '100': array([ 1196.9, 324.51, 27.192, 25.871], dtype=float32),\n",
" '101': array([ 3172.2, 310.75, 44.795, 39.267], dtype=float32),\n",
" '102': array([ 1998.4, 173.44, 30.628, 32.281], dtype=float32),\n",
" '103': array([ 787.33, 241.94, 64.897, 57.237], dtype=float32),\n",
" '104': array([ 2692.7, 21.127, 51.957, 55.546], dtype=float32),\n",
" '105': array([ 1170.2, 247.59, 34.857, 42.74], dtype=float32),\n",
" '106': array([ 1910.1, 174.37, 58.718, 53.371], dtype=float32),\n",
" '107': array([ 2259.9, 225.75, 31.158, 32.173], dtype=float32),\n",
" '108': array([ 254.78, 181.69, 56.54, 44.482], dtype=float32),\n",
" '109': array([ 1047.4, 182.75, 71.273, 49.486], dtype=float32),\n",
" '110': array([ 892.42, 233.03, 55.595, 76.336], dtype=float32),\n",
" '111': array([ 424.87, 175.43, 29.786, 28.278], dtype=float32),\n",
" '112': array([ 2873.1, 180.62, 62.633, 47.619], dtype=float32),\n",
" '113': array([ 1996.8, 278.16, 31.037, 31.629], dtype=float32),\n",
" '114': array([ 11.928, 369.01, 564.33, 1577.8], dtype=float32),\n",
" '115': array([ 422.86, 225.74, 29.438, 28.741], dtype=float32),\n",
" '116': array([ 1702.6, 235.74, 84.462, 34.063], dtype=float32)},\n",
" ['Text Box ID 0: AutoSave',\n",
" 'Text Box ID 1: Presentation2',\n",
" 'Text Box ID 2: PowerPoint',\n",
" 'Text Box ID 3: General*',\n",
" 'Text Box ID 4: Search',\n",
" 'Text Box ID 5: Yadong',\n",
" 'Text Box ID 6: File',\n",
" 'Text Box ID 7: Home',\n",
" 'Text Box ID 8: Insert',\n",
" 'Text Box ID 9: Draw',\n",
" 'Text Box ID 10: Design',\n",
" 'Text Box ID 11: Transitions',\n",
" 'Text Box ID 12: Animations',\n",
" 'Text Box ID 13: Slide Show',\n",
" 'Text Box ID 14: Record',\n",
" 'Text Box ID 15: Review',\n",
" 'Text Box ID 16: View',\n",
" 'Text Box ID 17: Help',\n",
" 'Text Box ID 18: Record',\n",
" 'Text Box ID 19: Present in Teams',\n",
" 'Text Box ID 20: Share',\n",
" 'Text Box ID 21: Layout',\n",
" 'Text Box ID 22: A\" | A',\n",
" 'Text Box ID 23: 8 =#~',\n",
" 'Text Box ID 24: Shape',\n",
" 'Text Box ID 25: Find',\n",
" 'Text Box ID 26: Paste',\n",
" 'Text Box ID 27: New',\n",
" 'Text Box ID 28: Reuse',\n",
" 'Text Box ID 29: Reset',\n",
" 'Text Box ID 30: [t]',\n",
" 'Text Box ID 31: Shapes Arrange',\n",
" 'Text Box ID 32: Quick',\n",
" 'Text Box ID 33: Shape Outline',\n",
" 'Text Box ID 34: Replace',\n",
" 'Text Box ID 35: Dictate',\n",
" 'Text Box ID 36: Sensitivity',\n",
" 'Text Box ID 37: Add-ins',\n",
" 'Text Box ID 38: Designer Copilot',\n",
" 'Text Box ID 39: 4',\n",
" 'Text Box ID 40: Aa ~',\n",
" 'Text Box ID 41: 22E6',\n",
" 'Text Box ID 42: Slide',\n",
" 'Text Box ID 43: Slides',\n",
" 'Text Box ID 44: Section',\n",
" 'Text Box ID 45: Styles',\n",
" 'Text Box ID 46: Effects',\n",
" 'Text Box ID 47: Select',\n",
" 'Text Box ID 48: Clipboard',\n",
" 'Text Box ID 49: Slides',\n",
" 'Text Box ID 50: Font',\n",
" 'Text Box ID 51: Paragraph',\n",
" 'Text Box ID 52: Drawing',\n",
" 'Text Box ID 53: Editing',\n",
" 'Text Box ID 54: Voice',\n",
" 'Text Box ID 55: Sensitivity',\n",
" 'Text Box ID 56: Add-ins',\n",
" 'Text Box ID 57: Click to add title',\n",
" 'Text Box ID 58: Click to add subtitle',\n",
" 'Text Box ID 59: Click to add notes',\n",
" 'Text Box ID 60: Shape'])"
]
},
"execution_count": 13,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"label_coordinates, parsed_content_list[0].split(': ')[1]"
"label_coordinates, parsed_content_list#[0].split(': ')[1]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7c1b13b5f5b0>"
"<matplotlib.image.AxesImage at 0x7ff3e91f4fb0>"
]
},
"execution_count": 9,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
},
@@ -640,238 +722,6 @@
"plt.imshow(image)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# wrapped Omniparser"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parsing image: examples/pc_1.png\n",
"\n",
"image 1/1 /home/yadonglu/sandbox/screenparsing_collab/screenparsing/omniparser/examples/pc_1.png: 800x1280 210 icons, 55.6ms\n",
"Speed: 7.7ms preprocess, 55.6ms inference, 1.5ms postprocess per image at shape (1, 3, 800, 1280)\n",
"boxes cpu\n",
"Time taken for Omniparser on cpu: 2.029506206512451\n"
]
}
],
"source": [
"from utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_dino_model, get_yolo_model\n",
"import torch\n",
"from ultralytics import YOLO\n",
"from PIL import Image\n",
"from typing import Dict, Tuple, List\n",
"import io\n",
"import base64\n",
"\n",
"\n",
"config = {\n",
" 'som_model_path': 'finetuned_icon_detect.pt',\n",
" 'device': 'cpu',\n",
" 'caption_model_path': 'Salesforce/blip2-opt-2.7b',\n",
" 'draw_bbox_config': {\n",
" 'text_scale': 0.8,\n",
" 'text_thickness': 2,\n",
" 'text_padding': 3,\n",
" 'thickness': 3,\n",
" },\n",
" 'BOX_TRESHOLD': 0.05\n",
"}\n",
"\n",
"\n",
"class Omniparser(object):\n",
" def __init__(self, config: Dict):\n",
" self.config = config\n",
" \n",
" self.som_model = get_yolo_model(model_path=config['som_model_path'])\n",
" # self.caption_model_processor = get_caption_model_processor(config['caption_model_path'], device=cofig['device'])\n",
" # self.caption_model_processor['model'].to(torch.float32)\n",
"\n",
" def parse(self, image_path: str):\n",
" print('Parsing image:', image_path)\n",
" ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9})\n",
" text, ocr_bbox = ocr_bbox_rslt\n",
"\n",
" draw_bbox_config = self.config['draw_bbox_config']\n",
" BOX_TRESHOLD = self.config['BOX_TRESHOLD']\n",
" dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, self.som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=False, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=None, ocr_text=text,use_local_semantics=False)\n",
" \n",
" image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))\n",
" # formating output\n",
" return_list = [{'from': 'omniparser', 'shape': {'x':coord[0], 'y':coord[1], 'width':coord[2], 'height':coord[3]},\n",
" 'text': parsed_content_list[i].split(': ')[1], 'type':'text'} for i, (k, coord) in enumerate(label_coordinates.items()) if i < len(parsed_content_list)]\n",
" return_list.extend(\n",
" [{'from': 'omniparser', 'shape': {'x':coord[0], 'y':coord[1], 'width':coord[2], 'height':coord[3]},\n",
" 'text': 'None', 'type':'icon'} for i, (k, coord) in enumerate(label_coordinates.items()) if i >= len(parsed_content_list)]\n",
" )\n",
"\n",
" return [image, return_list]\n",
" \n",
"parser = Omniparser(config)\n",
"image_path = 'imgs/pc_1.png'\n",
"\n",
"# time the parser\n",
"import time\n",
"s = time.time()\n",
"image, parsed_content_list = parser.parse(image_path)\n",
"device = config['device']\n",
"print(f'Time taken for Omniparser on {device}:', time.time() - s)\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0: 800x1280 210 icons, 49.4ms\n",
"Speed: 5.7ms preprocess, 49.4ms inference, 1.1ms postprocess per image at shape (1, 3, 800, 1280)\n",
"boxes cpu\n",
"Time taken for Omniparser finetuned YOLO module on cpu: 0.2898883819580078\n"
]
}
],
"source": [
"from utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_dino_model, get_yolo_model, predict_yolo\n",
"import torch\n",
"from ultralytics import YOLO\n",
"from PIL import Image\n",
"from typing import Dict, Tuple, List\n",
"import io\n",
"import base64\n",
"\n",
"\n",
"config = {\n",
" 'som_model_path': 'finetuned_icon_detect.pt',\n",
" 'device': 'cpu',\n",
" 'caption_model_path': 'Salesforce/blip2-opt-2.7b',\n",
" 'draw_bbox_config': {\n",
" 'text_scale': 0.8,\n",
" 'text_thickness': 2,\n",
" 'text_padding': 3,\n",
" 'thickness': 3,\n",
" },\n",
" 'BOX_TRESHOLD': 0.05\n",
"}\n",
"\n",
"class OmniparserYOLO(object):\n",
" def __init__(self, config: Dict):\n",
" self.config = config\n",
" self.som_model = get_yolo_model(model_path=config['som_model_path'])\n",
"\n",
" def parse(self, image):\n",
" draw_bbox_config = self.config['draw_bbox_config']\n",
" BOX_TRESHOLD = self.config['BOX_TRESHOLD']\n",
" xyxy, logits, phrases = predict_yolo(model=self.som_model, image_path=image, box_threshold=BOX_TRESHOLD)\n",
" # print('xyxy:', xyxy)\n",
" xyxy = xyxy.tolist()\n",
" # formating output\n",
" return_list = [{'from': 'omniparserYOLO', 'shape': {'x':coord[0], 'y':coord[1], 'width':coord[2]-coord[0], 'height':coord[3] - coord[1]},\n",
" 'text': 'None', 'type':'icon'} for i, coord in enumerate(xyxy)]\n",
" \n",
" return [None, return_list]\n",
" \n",
"parser = OmniparserYOLO(config)\n",
"image_path = 'imgs/pc_1.png'\n",
"image = Image.open(image_path)\n",
"\n",
"# time the parser\n",
"import time\n",
"s = time.time()\n",
"_, parsed_content_list = parser.parse(image)\n",
"device = config['device']\n",
"print(f'Time taken for Omniparser finetuned YOLO module on {device}:', time.time() - s)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# florence caption model"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/yadonglu/anaconda3/envs/pilot/lib/python3.9/site-packages/transformers/utils/generic.py:342: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
" _torch_pytree._register_pytree_node(\n",
"/home/yadonglu/anaconda3/envs/pilot/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
" warnings.warn(\n"
]
}
],
"source": [
"from transformers import AutoProcessor, AutoModelForCausalLM \n",
"import torch\n",
"device = 'cpu'\n",
"torch_dtype = torch.float16 if device == 'cuda' else torch.float32\n",
"model = AutoModelForCausalLM.from_pretrained(\"/home/yadonglu/sandbox/data/orca/florence-2-base-ft-fft_rai_win_ep5/epoch_5\", torch_dtype=torch_dtype, trust_remote_code=True).to(device)\n",
"processor = AutoProcessor.from_pretrained(\"microsoft/Florence-2-base\", trust_remote_code=True)\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['settings or configuration options.']"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from PIL import Image\n",
"prompt = \"<CAPTION>\"\n",
"image_path = 'imgs/settings.png'\n",
"image = [Image.open(image_path).convert('RGB')]\n",
"inputs = processor(images=image, text=[prompt]*len(image), return_tensors=\"pt\").to(device=device)\n",
"generated_ids = model.generate(input_ids=inputs[\"input_ids\"],pixel_values=inputs[\"pixel_values\"],max_new_tokens=1024,num_beams=3, do_sample=False)\n",
"generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)\n",
"generated_text = [gen.strip() for gen in generated_text]\n",
"generated_text"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import cv2"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -896,7 +746,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
"version": "3.12.0"
}
},
"nbformat": 4,

Binary file not shown.

Binary file not shown.

234
utils.py
View File

@@ -18,7 +18,7 @@ import numpy as np
# %matplotlib inline
from matplotlib import pyplot as plt
import easyocr
reader = easyocr.Reader(['en']) # this needs to run only once to load the model into memory # 'ch_sim',
reader = easyocr.Reader(['en'])
import time
import base64
@@ -33,44 +33,19 @@ import supervision as sv
import torchvision.transforms as T
def get_caption_model_processor(model_name="Salesforce/blip2-opt-2.7b", device=None):
def get_caption_model_processor(model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
if not device:
device = "cuda" if torch.cuda.is_available() else "cpu"
if model_name == "Salesforce/blip2-opt-2.7b":
from transformers import Blip2Processor, Blip2ForConditionalGeneration
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
from transformers import Blip2Processor, Blip2ForConditionalGeneration
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
if device == 'cpu':
model = Blip2ForConditionalGeneration.from_pretrained(
"Salesforce/blip2-opt-2.7b", device_map=None, torch_dtype=torch.float16
# '/home/yadonglu/sandbox/data/orca/blipv2_ui_merge', device_map=None, torch_dtype=torch.float16
)
elif model_name == "blip2-opt-2.7b-ui":
from transformers import Blip2Processor, Blip2ForConditionalGeneration
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
if device == 'cpu':
model = Blip2ForConditionalGeneration.from_pretrained(
'/home/yadonglu/sandbox/data/orca/blipv2_ui_merge', device_map=None, torch_dtype=torch.float32
)
else:
model = Blip2ForConditionalGeneration.from_pretrained(
'/home/yadonglu/sandbox/data/orca/blipv2_ui_merge', device_map=None, torch_dtype=torch.float16
)
elif model_name == "florence":
from transformers import AutoProcessor, AutoModelForCausalLM
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
if device == 'cpu':
model = AutoModelForCausalLM.from_pretrained("/home/yadonglu/sandbox/data/orca/florence-2-base-ft-fft_ep1_rai", torch_dtype=torch.float32, trust_remote_code=True)#.to(device)
else:
model = AutoModelForCausalLM.from_pretrained("/home/yadonglu/sandbox/data/orca/florence-2-base-ft-fft_ep1_rai_win_ep5_fixed", torch_dtype=torch.float16, trust_remote_code=True).to(device)
elif model_name == 'phi3v_ui':
from transformers import AutoModelForCausalLM, AutoProcessor
model_id = "microsoft/Phi-3-vision-128k-instruct"
model = AutoModelForCausalLM.from_pretrained('/home/yadonglu/sandbox/data/orca/phi3v_ui', device_map=device, trust_remote_code=True, torch_dtype="auto")
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
elif model_name == 'phi3v':
from transformers import AutoModelForCausalLM, AutoProcessor
model_id = "microsoft/Phi-3-vision-128k-instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, trust_remote_code=True, torch_dtype="auto")
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model_name_or_path, device_map=None, torch_dtype=torch.float32
)
else:
model = Blip2ForConditionalGeneration.from_pretrained(
model_name_or_path, device_map=None, torch_dtype=torch.float16
)
return {'model': model.to(device), 'processor': processor}
@@ -94,14 +69,12 @@ def get_parsed_content_icon(filtered_boxes, ocr_bbox, image_source, caption_mode
cropped_image = image_source[ymin:ymax, xmin:xmax, :]
croped_pil_image.append(to_pil(cropped_image))
# import pdb; pdb.set_trace()
model, processor = caption_model_processor['model'], caption_model_processor['processor']
if not prompt:
if 'florence' in model.config.name_or_path:
prompt = "<CAPTION>"
else:
prompt = "The image shows"
# prompt = "NO gender!NO gender!NO gender! The image shows a icon:"
batch_size = 10 # Number of samples per batch
generated_texts = []
@@ -387,117 +360,15 @@ def get_xywh_yolo(input):
return x, y, w, h
def run_api(body, max_tokens=1024):
'''
API call, check https://platform.openai.com/docs/guides/vision for the latest api usage.
'''
max_num_trial = 3
num_trial = 0
while num_trial < max_num_trial:
try:
response = client.chat.completions.create(
model=deployment,
messages=body,
temperature=0.01,
max_tokens=max_tokens,
)
return response.choices[0].message.content
except:
print('retry call gptv', num_trial)
num_trial += 1
time.sleep(10)
return ''
def call_gpt4v_new(message_text, image_path=None, max_tokens=2048):
if image_path:
try:
with open(image_path, "rb") as img_file:
encoded_image = base64.b64encode(img_file.read()).decode('ascii')
except:
encoded_image = image_path
if image_path:
content = [{"type": "image_url","image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}, {"type": "text","text": message_text},]
else:
content = [{"type": "text","text": message_text},]
max_num_trial = 3
num_trial = 0
call_api_success = True
while num_trial < max_num_trial:
try:
response = client.chat.completions.create(
model=deployment,
messages=[
{
"role": "system",
"content": [
{
"type": "text",
"text": "You are an AI assistant that is good at making plans and analyzing screens, and helping people find information."
},
]
},
{
"role": "user",
"content": content
}
],
temperature=0.01,
max_tokens=max_tokens,
)
ans_1st_pass = response.choices[0].message.content
break
except:
print('retry call gptv', num_trial)
num_trial += 1
ans_1st_pass = ''
time.sleep(10)
if num_trial == max_num_trial:
call_api_success = False
return ans_1st_pass, call_api_success
def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None):
if easyocr_args is None:
easyocr_args = {}
result = reader.readtext(image_path, **easyocr_args)
is_goal_filtered = False
if goal_filtering:
ocr_filter_fs = "Example 1:\n Based on task and ocr results, ```In summary, the task related bboxes are: [([[3060, 111], [3135, 111], [3135, 141], [3060, 141]], 'Share', 0.949013667261589), ([[3068, 197], [3135, 197], [3135, 227], [3068, 227]], 'Link _', 0.3567054243152049), ([[3006, 321], [3178, 321], [3178, 354], [3006, 354]], 'Manage Access', 0.8800734456437066)] ``` \n Example 2:\n Based on task and ocr results, ```In summary, the task related bboxes are: [([[3060, 111], [3135, 111], [3135, 141], [3060, 141]], 'Search Google or type a URL', 0.949013667261589)] ```"
# message_text = f"Based on the ocr results which contains text+bounding box in a dictionary, please filter it so that it only contains the task related bboxes. The task is: {goal_filtering}, the ocr results are: {str(result)}. Your final answer should be in the exact same format as the ocr results, please do not include any other redundant information, please do not include any analysis."
message_text = f"Based on the task and ocr results which contains text+bounding box in a dictionary, please filter it so that it only contains the task related bboxes. Requirement: 1. first give a brief analysis. 2. provide an answer in the format: ```In summary, the task related bboxes are: ..```, you must put it inside ``` ```. Do not include any info after ```.\n {ocr_filter_fs}\n The task is: {goal_filtering}, the ocr results are: {str(result)}."
prompt = [{"role":"system", "content": "You are an AI assistant that helps people find the correct way to operate computer or smartphone."}, {"role":"user","content": message_text},]
print('[Perform OCR filtering by goal] ongoing ...')
# pred, _, _ = call_gpt4(prompt)
pred, _, = call_gpt4v(message_text)
# import pdb; pdb.set_trace()
try:
# match = re.search(r"```(.*?)```", pred, re.DOTALL)
# result = match.group(1).strip()
# pred = result.split('In summary, the task related bboxes are:')[-1].strip()
pred = pred.split('In summary, the task related bboxes are:')[-1].strip().strip('```')
result = ast.literal_eval(pred)
print('[Perform OCR filtering by goal] success!!! Filtered buttons: ', pred)
is_goal_filtered = True
except:
print('[Perform OCR filtering by goal] failed or unused!!!')
pass
# added_prompt = [{"role":"assistant","content":pred},
# {"role":"user","content": "given the previous answers, please provide the final answer in the exact same format as the ocr results, please do not include any other redundant information, please do not include any analysis."}]
# prompt.extend(added_prompt)
# pred, _, _ = call_gpt4(prompt)
# print('goal filtering pred 2nd:', pred)
# result = ast.literal_eval(pred)
# print('goal filtering pred:', result[-5:])
coord = [item[0] for item in result]
text = [item[1] for item in result]
# confidence = [item[2] for item in result]
# if confidence_filtering:
# coord = [coord[i] for i in range(len(coord)) if confidence[i] > confidence_filtering]
# text = [text[i] for i in range(len(text)) if confidence[i] > confidence_filtering]
# read the image using cv2
if display_img:
opencv_img = cv2.imread(image_path)
@@ -520,87 +391,4 @@ def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_
return (text, bb), is_goal_filtered
def get_pred_gptv(message_text, yolo_labled_img, label_coordinates, summarize_history=True, verbose=True, history=None, id_key='Click ID'):
""" This func first
1. call gptv(yolo_labled_img, text bbox+task) -> ans_1st_cal
2. call gpt4(ans_1st_cal, label_coordinates) -> final ans
"""
# Configuration
encoded_image = yolo_labled_img
# Payload for the request
if not history:
messages = [
{"role": "system", "content": [{"type": "text","text": "You are an AI assistant that is great at interpreting screenshot and predict action."},]},
{"role": "user","content": [{"type": "text","text": message_text}, {"type": "image_url","image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},]}
]
else:
messages = [
{"role": "system", "content": [{"type": "text","text": "You are an AI assistant that is great at interpreting screenshot and predict action."},]},
history,
{"role": "user","content": [{"type": "image_url","image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},{"type": "text","text": message_text},]}
]
payload = {
"messages": messages,
"temperature": 0.01, # 0.01
"top_p": 0.95,
"max_tokens": 800
}
max_num_trial = 3
num_trial = 0
call_api_success = True
while num_trial < max_num_trial:
try:
# response = requests.post(GPT4V_ENDPOINT, headers=headers, json=payload)
# response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
# ans_1st_pass = response.json()['choices'][0]['message']['content']
response = client.chat.completions.create(
model=deployment,
messages=messages,
temperature=0.01,
max_tokens=512,
)
ans_1st_pass = response.choices[0].message.content
break
except requests.RequestException as e:
print('retry call gptv', num_trial)
num_trial += 1
ans_1st_pass = ''
time.sleep(30)
# raise SystemExit(f"Failed to make the request. Error: {e}")
if num_trial == max_num_trial:
call_api_success = False
if verbose:
print('Answer by GPTV: ', ans_1st_pass)
# extract by simple parsing
try:
match = re.search(r"```(.*?)```", ans_1st_pass, re.DOTALL)
if match:
result = match.group(1).strip()
pred = result.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '')
pred = ast.literal_eval(pred)
else:
pred = ans_1st_pass.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '')
pred = ast.literal_eval(pred)
if id_key in pred:
icon_id = pred[id_key]
bbox = label_coordinates[str(icon_id)]
pred['click_point'] = [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2]
except:
# import pdb; pdb.set_trace()
print('gptv action regex extract fail!!!')
print('ans_1st_pass:', ans_1st_pass)
pred = {'action_type': 'CLICK', 'click_point': [0, 0], 'value': 'None', 'is_completed': False}
step_pred_summary = None
if summarize_history:
step_pred_summary, _ = call_gpt4v_new('Summarize what action you decide to perform in the current step, in one sentence, and do not include any icon box number: ' + ans_1st_pass, max_tokens=128)
print('step_pred_summary', step_pred_summary)
return pred, [call_api_success, ans_1st_pass, None, step_pred_summary]
# return pred, [call_api_success, message_2nd, completion_2nd.choices[0].message.content, step_pred_summary]