first commit
This commit is contained in:
32
README.md
Normal file
32
README.md
Normal file
@@ -0,0 +1,32 @@
|
||||
# OmniParser: Screen Parsing tool for Pure Vision Based GUI Agent
|
||||
|
||||

|
||||
[](https://arxiv.org/abs/2408.00203)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
|
||||
**OmniParser** is a comprehensive method for parsing user interface screenshots into structured and easy-to-understand elements, which significantly enhances the ability of GPT-4V to generate actions that can be accurately grounded in the corresponding regions of the interface.
|
||||
|
||||
## Examples:
|
||||
We put together a few simple examples in the demo.ipynb.
|
||||
|
||||
## Gradio Demo
|
||||
To run gradio demo, simply run:
|
||||
```python
|
||||
python gradion_demo.py
|
||||
```
|
||||
|
||||
|
||||
## 📚 Citation
|
||||
Our technical report can be found [here](https://arxiv.org/abs/2408.00203).
|
||||
If you find our work useful, please consider citing our work:
|
||||
```
|
||||
@misc{lu2024omniparserpurevisionbased,
|
||||
title={OmniParser for Pure Vision Based GUI Agent},
|
||||
author={Yadong Lu and Jianwei Yang and Yelong Shen and Ahmed Awadallah},
|
||||
year={2024},
|
||||
eprint={2408.00203},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CV},
|
||||
url={https://arxiv.org/abs/2408.00203},
|
||||
}
|
||||
```
|
||||
542
demo.ipynb
Normal file
542
demo.ipynb
Normal file
File diff suppressed because one or more lines are too long
104
gradio_demo.py
Normal file
104
gradio_demo.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from typing import Optional
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
|
||||
import base64, os
|
||||
from utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
yolo_model = get_yolo_model()
|
||||
caption_model_processor = get_caption_model_processor('florence', device='cuda') # 'blip2-opt-2.7b-ui', phi3v_ui florence
|
||||
platform = 'pc'
|
||||
if platform == 'pc':
|
||||
draw_bbox_config = {
|
||||
'text_scale': 0.8,
|
||||
'text_thickness': 2,
|
||||
'text_padding': 2,
|
||||
'thickness': 2,
|
||||
}
|
||||
BOX_TRESHOLD = 0.05
|
||||
elif platform == 'web':
|
||||
draw_bbox_config = {
|
||||
'text_scale': 0.8,
|
||||
'text_thickness': 2,
|
||||
'text_padding': 3,
|
||||
'thickness': 3,
|
||||
}
|
||||
BOX_TRESHOLD = 0.05
|
||||
elif platform == 'mobile':
|
||||
draw_bbox_config = {
|
||||
'text_scale': 0.8,
|
||||
'text_thickness': 2,
|
||||
'text_padding': 3,
|
||||
'thickness': 3,
|
||||
}
|
||||
BOX_TRESHOLD = 0.05
|
||||
|
||||
|
||||
|
||||
MARKDOWN = """
|
||||
# OmniParser for Pure Vision Based General GUI Agent 🔥
|
||||
<div>
|
||||
<a href="https://arxiv.org/pdf/2408.00203">
|
||||
<img src="https://img.shields.io/badge/arXiv-2408.00203-b31b1b.svg" alt="Arxiv" style="display:inline-block;">
|
||||
</a>
|
||||
</div>
|
||||
|
||||
OmniParser is a screen parsing tool to convert general GUI screen to structured elements. **Trained models will be released soon**
|
||||
"""
|
||||
|
||||
DEVICE = torch.device('cuda')
|
||||
|
||||
# @spaces.GPU
|
||||
# @torch.inference_mode()
|
||||
# @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
def process(
|
||||
image_input,
|
||||
prompt: str = None
|
||||
) -> Optional[Image.Image]:
|
||||
|
||||
image_path = "/home/yadonglu/sandbox/data/omniparser_demo/image_input.png"
|
||||
image_input.save(image_path)
|
||||
# import pdb; pdb.set_trace()
|
||||
|
||||
ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9})
|
||||
text, ocr_bbox = ocr_bbox_rslt
|
||||
print('prompt:', prompt)
|
||||
dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, yolo_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text,iou_threshold=0.3,prompt=prompt)
|
||||
image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
|
||||
print('finish processing')
|
||||
parsed_content_list = '\n'.join(parsed_content_list)
|
||||
return image, str(parsed_content_list)
|
||||
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown(MARKDOWN)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
image_input_component = gr.Image(
|
||||
type='pil', label='Upload image')
|
||||
prompt_input_component = gr.Textbox(label='Prompt', placeholder='')
|
||||
submit_button_component = gr.Button(
|
||||
value='Submit', variant='primary')
|
||||
with gr.Column():
|
||||
image_output_component = gr.Image(type='pil', label='Image Output')
|
||||
text_output_component = gr.Textbox(label='Parsed screen elements', placeholder='Text Output')
|
||||
|
||||
submit_button_component.click(
|
||||
fn=process,
|
||||
inputs=[
|
||||
image_input_component,
|
||||
prompt_input_component,
|
||||
],
|
||||
outputs=[image_output_component, text_output_component]
|
||||
)
|
||||
|
||||
# demo.launch(debug=False, show_error=True, share=True)
|
||||
demo.launch(share=True, server_port=7861, server_name='0.0.0.0')
|
||||
BIN
imgs/logo.png
Normal file
BIN
imgs/logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 3.0 MiB |
BIN
imgs/mobile_4.png
Normal file
BIN
imgs/mobile_4.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.8 MiB |
BIN
imgs/pc_1.png
Normal file
BIN
imgs/pc_1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 197 KiB |
BIN
imgs/settings.png
Normal file
BIN
imgs/settings.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 805 B |
60
omniparser.py
Normal file
60
omniparser.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_dino_model, get_yolo_model
|
||||
import torch
|
||||
from ultralytics import YOLO
|
||||
from PIL import Image
|
||||
from typing import Dict, Tuple, List
|
||||
import io
|
||||
import base64
|
||||
|
||||
|
||||
config = {
|
||||
'som_model_path': 'finetuned_icon_detect.pt',
|
||||
'device': 'cpu',
|
||||
'caption_model_path': 'Salesforce/blip2-opt-2.7b',
|
||||
'draw_bbox_config': {
|
||||
'text_scale': 0.8,
|
||||
'text_thickness': 2,
|
||||
'text_padding': 3,
|
||||
'thickness': 3,
|
||||
},
|
||||
'BOX_TRESHOLD': 0.05
|
||||
}
|
||||
|
||||
|
||||
class Omniparser(object):
|
||||
def __init__(self, config: Dict):
|
||||
self.config = config
|
||||
|
||||
self.som_model = get_yolo_model(model_path=config['som_model_path'])
|
||||
# self.caption_model_processor = get_caption_model_processor(config['caption_model_path'], device=cofig['device'])
|
||||
# self.caption_model_processor['model'].to(torch.float32)
|
||||
|
||||
def parse(self, image_path: str):
|
||||
print('Parsing image:', image_path)
|
||||
ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9})
|
||||
text, ocr_bbox = ocr_bbox_rslt
|
||||
|
||||
draw_bbox_config = self.config['draw_bbox_config']
|
||||
BOX_TRESHOLD = self.config['BOX_TRESHOLD']
|
||||
dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, self.som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=False, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=None, ocr_text=text,use_local_semantics=False)
|
||||
|
||||
image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
|
||||
# formating output
|
||||
return_list = [{'from': 'omniparser', 'shape': {'x':coord[0], 'y':coord[1], 'width':coord[2], 'height':coord[3]},
|
||||
'text': parsed_content_list[i].split(': ')[1], 'type':'text'} for i, (k, coord) in enumerate(label_coordinates.items()) if i < len(parsed_content_list)]
|
||||
return_list.extend(
|
||||
[{'from': 'omniparser', 'shape': {'x':coord[0], 'y':coord[1], 'width':coord[2], 'height':coord[3]},
|
||||
'text': 'None', 'type':'icon'} for i, (k, coord) in enumerate(label_coordinates.items()) if i >= len(parsed_content_list)]
|
||||
)
|
||||
|
||||
return [image, return_list]
|
||||
|
||||
parser = Omniparser(config)
|
||||
image_path = 'examples/pc_1.png'
|
||||
|
||||
# time the parser
|
||||
import time
|
||||
s = time.time()
|
||||
image, parsed_content_list = parser.parse(image_path)
|
||||
device = config['device']
|
||||
print(f'Time taken for Omniparser on {device}:', time.time() - s)
|
||||
14
requirement.txt
Normal file
14
requirement.txt
Normal file
@@ -0,0 +1,14 @@
|
||||
torch==2.2.2
|
||||
easyocr==1.7.1
|
||||
torchvision==0.17.2
|
||||
supervision==0.18.0
|
||||
openai==1.3.5
|
||||
transformers==4.40.2
|
||||
ultralytics==8.1.24
|
||||
azure-identity
|
||||
numpy
|
||||
opencv-python==4.8.1.78
|
||||
opencv-python-headless==4.8.0.74
|
||||
supervision==0.18.0
|
||||
gradio==4.40.0
|
||||
|
||||
0
util/__init__.py
Executable file
0
util/__init__.py
Executable file
BIN
util/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
util/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
util/__pycache__/action_matching.cpython-39.pyc
Normal file
BIN
util/__pycache__/action_matching.cpython-39.pyc
Normal file
Binary file not shown.
BIN
util/__pycache__/box_annotator.cpython-39.pyc
Normal file
BIN
util/__pycache__/box_annotator.cpython-39.pyc
Normal file
Binary file not shown.
425
util/action_matching.py
Normal file
425
util/action_matching.py
Normal file
@@ -0,0 +1,425 @@
|
||||
'''
|
||||
Adapted from https://github.com/google-research/google-research/tree/master/android_in_the_wild
|
||||
'''
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
# import action_type as action_type_lib
|
||||
import enum
|
||||
|
||||
class ActionType(enum.IntEnum):
|
||||
# Placeholders for unused enum values
|
||||
UNUSED_0 = 0
|
||||
UNUSED_1 = 1
|
||||
UNUSED_2 = 2
|
||||
UNUSED_8 = 8
|
||||
UNUSED_9 = 9
|
||||
|
||||
########### Agent actions ###########
|
||||
|
||||
# A type action that sends text to the emulator. Note that this simply sends
|
||||
# text and does not perform any clicks for element focus or enter presses for
|
||||
# submitting text.
|
||||
TYPE = 3
|
||||
|
||||
# The dual point action used to represent all gestures.
|
||||
DUAL_POINT = 4
|
||||
|
||||
# These actions differentiate pressing the home and back button from touches.
|
||||
# They represent explicit presses of back and home performed using ADB.
|
||||
PRESS_BACK = 5
|
||||
PRESS_HOME = 6
|
||||
|
||||
# An action representing that ADB command for hitting enter was performed.
|
||||
PRESS_ENTER = 7
|
||||
|
||||
########### Episode status actions ###########
|
||||
|
||||
# An action used to indicate the desired task has been completed and resets
|
||||
# the environment. This action should also be used in the case that the task
|
||||
# has already been completed and there is nothing to do.
|
||||
# e.g. The task is to turn on the Wi-Fi when it is already on
|
||||
STATUS_TASK_COMPLETE = 10
|
||||
|
||||
# An action used to indicate that desired task is impossible to complete and
|
||||
# resets the environment. This can be a result of many different things
|
||||
# including UI changes, Android version differences, etc.
|
||||
STATUS_TASK_IMPOSSIBLE = 11
|
||||
|
||||
|
||||
_TAP_DISTANCE_THRESHOLD = 0.14 # Fraction of the screen
|
||||
ANNOTATION_WIDTH_AUGMENT_FRACTION = 1.4
|
||||
ANNOTATION_HEIGHT_AUGMENT_FRACTION = 1.4
|
||||
|
||||
# Interval determining if an action is a tap or a swipe.
|
||||
_SWIPE_DISTANCE_THRESHOLD = 0.04
|
||||
|
||||
|
||||
def _yx_in_bounding_boxes(
|
||||
yx, bounding_boxes
|
||||
):
|
||||
"""Check if the (y,x) point is contained in each bounding box.
|
||||
|
||||
Args:
|
||||
yx: The (y, x) coordinate in pixels of the point.
|
||||
bounding_boxes: A 2D int array of shape (num_bboxes, 4), where each row
|
||||
represents a bounding box: (y_top_left, x_top_left, box_height,
|
||||
box_width). Note: containment is inclusive of the bounding box edges.
|
||||
|
||||
Returns:
|
||||
is_inside: A 1D bool array where each element specifies if the point is
|
||||
contained within the respective box.
|
||||
"""
|
||||
y, x = yx
|
||||
|
||||
# `bounding_boxes` has shape (n_elements, 4); we extract each array along the
|
||||
# last axis into shape (n_elements, 1), then squeeze unneeded dimension.
|
||||
top, left, height, width = [
|
||||
jnp.squeeze(v, axis=-1) for v in jnp.split(bounding_boxes, 4, axis=-1)
|
||||
]
|
||||
|
||||
# The y-axis is inverted for AndroidEnv, so bottom = top + height.
|
||||
bottom, right = top + height, left + width
|
||||
|
||||
return jnp.logical_and(y >= top, y <= bottom) & jnp.logical_and(
|
||||
x >= left, x <= right)
|
||||
|
||||
|
||||
def _resize_annotation_bounding_boxes(
|
||||
annotation_positions, annotation_width_augment_fraction,
|
||||
annotation_height_augment_fraction):
|
||||
"""Resize the bounding boxes by the given fractions.
|
||||
|
||||
Args:
|
||||
annotation_positions: Array of shape (N, 4), where each row represents the
|
||||
(y, x, height, width) of the bounding boxes.
|
||||
annotation_width_augment_fraction: The fraction to augment the box widths,
|
||||
E.g., 1.4 == 240% total increase.
|
||||
annotation_height_augment_fraction: Same as described for width, but for box
|
||||
height.
|
||||
|
||||
Returns:
|
||||
Resized bounding box.
|
||||
|
||||
"""
|
||||
height_change = (
|
||||
annotation_height_augment_fraction * annotation_positions[:, 2])
|
||||
width_change = (
|
||||
annotation_width_augment_fraction * annotation_positions[:, 3])
|
||||
|
||||
# Limit bounding box positions to the screen.
|
||||
resized_annotations = jnp.stack([
|
||||
jnp.maximum(0, annotation_positions[:, 0] - (height_change / 2)),
|
||||
jnp.maximum(0, annotation_positions[:, 1] - (width_change / 2)),
|
||||
jnp.minimum(1, annotation_positions[:, 2] + height_change),
|
||||
jnp.minimum(1, annotation_positions[:, 3] + width_change),
|
||||
],
|
||||
axis=1)
|
||||
return resized_annotations
|
||||
|
||||
|
||||
def is_tap_action(normalized_start_yx,
|
||||
normalized_end_yx):
|
||||
distance = jnp.linalg.norm(
|
||||
jnp.array(normalized_start_yx) - jnp.array(normalized_end_yx))
|
||||
return distance <= _SWIPE_DISTANCE_THRESHOLD
|
||||
|
||||
|
||||
def _is_non_dual_point_action(action_type):
|
||||
return jnp.not_equal(action_type, ActionType.DUAL_POINT)
|
||||
|
||||
|
||||
def _check_tap_actions_match(
|
||||
tap_1_yx,
|
||||
tap_2_yx,
|
||||
annotation_positions,
|
||||
matching_tap_distance_threshold_screen_percentage,
|
||||
annotation_width_augment_fraction,
|
||||
annotation_height_augment_fraction,
|
||||
):
|
||||
"""Determines if two tap actions are the same."""
|
||||
resized_annotation_positions = _resize_annotation_bounding_boxes(
|
||||
annotation_positions,
|
||||
annotation_width_augment_fraction,
|
||||
annotation_height_augment_fraction,
|
||||
)
|
||||
|
||||
# Check if the ground truth tap action falls in an annotation's bounding box.
|
||||
tap1_in_box = _yx_in_bounding_boxes(tap_1_yx, resized_annotation_positions)
|
||||
tap2_in_box = _yx_in_bounding_boxes(tap_2_yx, resized_annotation_positions)
|
||||
both_in_box = jnp.max(tap1_in_box & tap2_in_box)
|
||||
|
||||
# If the ground-truth tap action falls outside any of the annotation
|
||||
# bounding boxes or one of the actions is inside a bounding box and the other
|
||||
# is outside bounding box or vice versa, compare the points using Euclidean
|
||||
# distance.
|
||||
within_threshold = (
|
||||
jnp.linalg.norm(jnp.array(tap_1_yx) - jnp.array(tap_2_yx))
|
||||
<= matching_tap_distance_threshold_screen_percentage
|
||||
)
|
||||
return jnp.logical_or(both_in_box, within_threshold)
|
||||
|
||||
|
||||
def _check_drag_actions_match(
|
||||
drag_1_touch_yx,
|
||||
drag_1_lift_yx,
|
||||
drag_2_touch_yx,
|
||||
drag_2_lift_yx,
|
||||
):
|
||||
"""Determines if two drag actions are the same."""
|
||||
# Store drag deltas (the change in the y and x coordinates from touch to
|
||||
# lift), magnitudes, and the index of the main axis, which is the axis with
|
||||
# the greatest change in coordinate value (e.g. a drag starting at (0, 0) and
|
||||
# ending at (0.3, 0.5) has a main axis index of 1).
|
||||
drag_1_deltas = drag_1_lift_yx - drag_1_touch_yx
|
||||
drag_1_magnitudes = jnp.abs(drag_1_deltas)
|
||||
drag_1_main_axis = np.argmax(drag_1_magnitudes)
|
||||
drag_2_deltas = drag_2_lift_yx - drag_2_touch_yx
|
||||
drag_2_magnitudes = jnp.abs(drag_2_deltas)
|
||||
drag_2_main_axis = np.argmax(drag_2_magnitudes)
|
||||
|
||||
return jnp.equal(drag_1_main_axis, drag_2_main_axis)
|
||||
|
||||
|
||||
def check_actions_match(
|
||||
action_1_touch_yx,
|
||||
action_1_lift_yx,
|
||||
action_1_action_type,
|
||||
action_2_touch_yx,
|
||||
action_2_lift_yx,
|
||||
action_2_action_type,
|
||||
annotation_positions,
|
||||
tap_distance_threshold = _TAP_DISTANCE_THRESHOLD,
|
||||
annotation_width_augment_fraction = ANNOTATION_WIDTH_AUGMENT_FRACTION,
|
||||
annotation_height_augment_fraction = ANNOTATION_HEIGHT_AUGMENT_FRACTION,
|
||||
):
|
||||
"""Determines if two actions are considered to be the same.
|
||||
|
||||
Two actions being "the same" is defined here as two actions that would result
|
||||
in a similar screen state.
|
||||
|
||||
Args:
|
||||
action_1_touch_yx: The (y, x) coordinates of the first action's touch.
|
||||
action_1_lift_yx: The (y, x) coordinates of the first action's lift.
|
||||
action_1_action_type: The action type of the first action.
|
||||
action_2_touch_yx: The (y, x) coordinates of the second action's touch.
|
||||
action_2_lift_yx: The (y, x) coordinates of the second action's lift.
|
||||
action_2_action_type: The action type of the second action.
|
||||
annotation_positions: The positions of the UI annotations for the screen. It
|
||||
is A 2D int array of shape (num_bboxes, 4), where each row represents a
|
||||
bounding box: (y_top_left, x_top_left, box_height, box_width). Note that
|
||||
containment is inclusive of the bounding box edges.
|
||||
tap_distance_threshold: The threshold that determines if two taps result in
|
||||
a matching screen state if they don't fall the same bounding boxes.
|
||||
annotation_width_augment_fraction: The fraction to increase the width of the
|
||||
bounding box by.
|
||||
annotation_height_augment_fraction: The fraction to increase the height of
|
||||
of the bounding box by.
|
||||
|
||||
Returns:
|
||||
A boolean representing whether the two given actions are the same or not.
|
||||
"""
|
||||
action_1_touch_yx = jnp.asarray(action_1_touch_yx)
|
||||
action_1_lift_yx = jnp.asarray(action_1_lift_yx)
|
||||
action_2_touch_yx = jnp.asarray(action_2_touch_yx)
|
||||
action_2_lift_yx = jnp.asarray(action_2_lift_yx)
|
||||
|
||||
# Checks if at least one of the actions is global (i.e. not DUAL_POINT),
|
||||
# because if that is the case, only the actions' types need to be compared.
|
||||
has_non_dual_point_action = jnp.logical_or(
|
||||
_is_non_dual_point_action(action_1_action_type),
|
||||
_is_non_dual_point_action(action_2_action_type),
|
||||
)
|
||||
#print("non dual point: "+str(has_non_dual_point_action))
|
||||
|
||||
different_dual_point_types = jnp.logical_xor(
|
||||
is_tap_action(action_1_touch_yx, action_1_lift_yx),
|
||||
is_tap_action(action_2_touch_yx, action_2_lift_yx),
|
||||
)
|
||||
#print("different dual type: "+str(different_dual_point_types))
|
||||
|
||||
is_tap = jnp.logical_and(
|
||||
is_tap_action(action_1_touch_yx, action_1_lift_yx),
|
||||
is_tap_action(action_2_touch_yx, action_2_lift_yx),
|
||||
)
|
||||
#print("is tap: "+str(is_tap))
|
||||
|
||||
taps_match = _check_tap_actions_match(
|
||||
action_1_touch_yx,
|
||||
action_2_touch_yx,
|
||||
annotation_positions,
|
||||
tap_distance_threshold,
|
||||
annotation_width_augment_fraction,
|
||||
annotation_height_augment_fraction,
|
||||
)
|
||||
#print("tap match: "+str(taps_match))
|
||||
|
||||
taps_match = jnp.logical_and(is_tap, taps_match)
|
||||
#print("tap match: "+str(taps_match))
|
||||
|
||||
drags_match = _check_drag_actions_match(
|
||||
action_1_touch_yx, action_1_lift_yx, action_2_touch_yx, action_2_lift_yx
|
||||
)
|
||||
drags_match = jnp.where(is_tap, False, drags_match)
|
||||
#print("drag match: "+str(drags_match))
|
||||
|
||||
return jnp.where(
|
||||
has_non_dual_point_action,
|
||||
jnp.equal(action_1_action_type, action_2_action_type),
|
||||
jnp.where(
|
||||
different_dual_point_types,
|
||||
False,
|
||||
jnp.logical_or(taps_match, drags_match),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def action_2_format(step_data):
|
||||
# 把test数据集中的动作格式转换为计算matching score的格式
|
||||
action_type = step_data["action_type_id"]
|
||||
|
||||
if action_type == 4:
|
||||
if step_data["action_type_text"] == 'click': # 点击
|
||||
touch_point = step_data["touch"]
|
||||
lift_point = step_data["lift"]
|
||||
else: # 上下左右滑动
|
||||
if step_data["action_type_text"] == 'scroll down':
|
||||
touch_point = [0.5, 0.8]
|
||||
lift_point = [0.5, 0.2]
|
||||
elif step_data["action_type_text"] == 'scroll up':
|
||||
touch_point = [0.5, 0.2]
|
||||
lift_point = [0.5, 0.8]
|
||||
elif step_data["action_type_text"] == 'scroll left':
|
||||
touch_point = [0.2, 0.5]
|
||||
lift_point = [0.8, 0.5]
|
||||
elif step_data["action_type_text"] == 'scroll right':
|
||||
touch_point = [0.8, 0.5]
|
||||
lift_point = [0.2, 0.5]
|
||||
else:
|
||||
touch_point = [-1.0, -1.0]
|
||||
lift_point = [-1.0, -1.0]
|
||||
|
||||
if action_type == 3:
|
||||
typed_text = step_data["type_text"]
|
||||
else:
|
||||
typed_text = ""
|
||||
|
||||
action = {"action_type": action_type, "touch_point": touch_point, "lift_point": lift_point,
|
||||
"typed_text": typed_text}
|
||||
|
||||
action["touch_point"] = [action["touch_point"][1], action["touch_point"][0]]
|
||||
action["lift_point"] = [action["lift_point"][1], action["lift_point"][0]]
|
||||
action["typed_text"] = action["typed_text"].lower()
|
||||
|
||||
return action
|
||||
|
||||
|
||||
def pred_2_format(step_data):
|
||||
# 把模型输出的内容转换为计算action_matching的格式
|
||||
action_type = step_data["action_type"]
|
||||
|
||||
if action_type == 4: # 点击
|
||||
action_type_new = 4
|
||||
touch_point = step_data["click_point"]
|
||||
lift_point = step_data["click_point"]
|
||||
typed_text = ""
|
||||
elif action_type == 0:
|
||||
action_type_new = 4
|
||||
touch_point = [0.5, 0.8]
|
||||
lift_point = [0.5, 0.2]
|
||||
typed_text = ""
|
||||
elif action_type == 1:
|
||||
action_type_new = 4
|
||||
touch_point = [0.5, 0.2]
|
||||
lift_point = [0.5, 0.8]
|
||||
typed_text = ""
|
||||
elif action_type == 8:
|
||||
action_type_new = 4
|
||||
touch_point = [0.2, 0.5]
|
||||
lift_point = [0.8, 0.5]
|
||||
typed_text = ""
|
||||
elif action_type == 9:
|
||||
action_type_new = 4
|
||||
touch_point = [0.8, 0.5]
|
||||
lift_point = [0.2, 0.5]
|
||||
typed_text = ""
|
||||
else:
|
||||
action_type_new = action_type
|
||||
touch_point = [-1.0, -1.0]
|
||||
lift_point = [-1.0, -1.0]
|
||||
typed_text = ""
|
||||
if action_type_new == 3:
|
||||
typed_text = step_data["typed_text"]
|
||||
|
||||
action = {"action_type": action_type_new, "touch_point": touch_point, "lift_point": lift_point,
|
||||
"typed_text": typed_text}
|
||||
|
||||
action["touch_point"] = [action["touch_point"][1], action["touch_point"][0]]
|
||||
action["lift_point"] = [action["lift_point"][1], action["lift_point"][0]]
|
||||
action["typed_text"] = action["typed_text"].lower()
|
||||
|
||||
return action
|
||||
|
||||
|
||||
def pred_2_format_simplified(step_data):
|
||||
# 把模型输出的内容转换为计算action_matching的格式
|
||||
action_type = step_data["action_type"]
|
||||
|
||||
if action_type == 'click' : # 点击
|
||||
action_type_new = 4
|
||||
touch_point = step_data["click_point"]
|
||||
lift_point = step_data["click_point"]
|
||||
typed_text = ""
|
||||
elif action_type == 'scroll' and step_data["direction"] == 'down':
|
||||
action_type_new = 4
|
||||
touch_point = [0.5, 0.8]
|
||||
lift_point = [0.5, 0.2]
|
||||
typed_text = ""
|
||||
elif action_type == 'scroll' and step_data["direction"] == 'up':
|
||||
action_type_new = 4
|
||||
touch_point = [0.5, 0.2]
|
||||
lift_point = [0.5, 0.8]
|
||||
typed_text = ""
|
||||
elif action_type == 'scroll' and step_data["direction"] == 'left':
|
||||
action_type_new = 4
|
||||
touch_point = [0.2, 0.5]
|
||||
lift_point = [0.8, 0.5]
|
||||
typed_text = ""
|
||||
elif action_type == 'scroll' and step_data["direction"] == 'right':
|
||||
action_type_new = 4
|
||||
touch_point = [0.8, 0.5]
|
||||
lift_point = [0.2, 0.5]
|
||||
typed_text = ""
|
||||
elif action_type == 'type':
|
||||
action_type_new = 3
|
||||
touch_point = [-1.0, -1.0]
|
||||
lift_point = [-1.0, -1.0]
|
||||
typed_text = step_data["text"]
|
||||
elif action_type == 'navigate_back':
|
||||
action_type_new = 5
|
||||
touch_point = [-1.0, -1.0]
|
||||
lift_point = [-1.0, -1.0]
|
||||
typed_text = ""
|
||||
elif action_type == 'navigate_home':
|
||||
action_type_new = 6
|
||||
touch_point = [-1.0, -1.0]
|
||||
lift_point = [-1.0, -1.0]
|
||||
typed_text = ""
|
||||
else:
|
||||
action_type_new = action_type
|
||||
touch_point = [-1.0, -1.0]
|
||||
lift_point = [-1.0, -1.0]
|
||||
typed_text = ""
|
||||
# if action_type_new == 'type':
|
||||
# typed_text = step_data["text"]
|
||||
|
||||
action = {"action_type": action_type_new, "touch_point": touch_point, "lift_point": lift_point,
|
||||
"typed_text": typed_text}
|
||||
|
||||
action["touch_point"] = [action["touch_point"][1], action["touch_point"][0]]
|
||||
action["lift_point"] = [action["lift_point"][1], action["lift_point"][0]]
|
||||
action["typed_text"] = action["typed_text"].lower()
|
||||
|
||||
return action
|
||||
45
util/action_type.py
Normal file
45
util/action_type.py
Normal file
@@ -0,0 +1,45 @@
|
||||
'''
|
||||
Adapted from https://github.com/google-research/google-research/tree/master/android_in_the_wild
|
||||
'''
|
||||
|
||||
import enum
|
||||
|
||||
class ActionType(enum.IntEnum):
|
||||
|
||||
# Placeholders for unused enum values
|
||||
UNUSED_0 = 0
|
||||
UNUSED_1 = 1
|
||||
UNUSED_2 = 2
|
||||
UNUSED_8 = 8
|
||||
UNUSED_9 = 9
|
||||
|
||||
########### Agent actions ###########
|
||||
|
||||
# A type action that sends text to the emulator. Note that this simply sends
|
||||
# text and does not perform any clicks for element focus or enter presses for
|
||||
# submitting text.
|
||||
TYPE = 3
|
||||
|
||||
# The dual point action used to represent all gestures.
|
||||
DUAL_POINT = 4
|
||||
|
||||
# These actions differentiate pressing the home and back button from touches.
|
||||
# They represent explicit presses of back and home performed using ADB.
|
||||
PRESS_BACK = 5
|
||||
PRESS_HOME = 6
|
||||
|
||||
# An action representing that ADB command for hitting enter was performed.
|
||||
PRESS_ENTER = 7
|
||||
|
||||
########### Episode status actions ###########
|
||||
|
||||
# An action used to indicate the desired task has been completed and resets
|
||||
# the environment. This action should also be used in the case that the task
|
||||
# has already been completed and there is nothing to do.
|
||||
# e.g. The task is to turn on the Wi-Fi when it is already on
|
||||
STATUS_TASK_COMPLETE = 10
|
||||
|
||||
# An action used to indicate that desired task is impossible to complete and
|
||||
# resets the environment. This can be a result of many different things
|
||||
# including UI changes, Android version differences, etc.
|
||||
STATUS_TASK_IMPOSSIBLE = 11
|
||||
262
util/box_annotator.py
Normal file
262
util/box_annotator.py
Normal file
@@ -0,0 +1,262 @@
|
||||
from typing import List, Optional, Union, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from supervision.detection.core import Detections
|
||||
from supervision.draw.color import Color, ColorPalette
|
||||
|
||||
|
||||
class BoxAnnotator:
|
||||
"""
|
||||
A class for drawing bounding boxes on an image using detections provided.
|
||||
|
||||
Attributes:
|
||||
color (Union[Color, ColorPalette]): The color to draw the bounding box,
|
||||
can be a single color or a color palette
|
||||
thickness (int): The thickness of the bounding box lines, default is 2
|
||||
text_color (Color): The color of the text on the bounding box, default is white
|
||||
text_scale (float): The scale of the text on the bounding box, default is 0.5
|
||||
text_thickness (int): The thickness of the text on the bounding box,
|
||||
default is 1
|
||||
text_padding (int): The padding around the text on the bounding box,
|
||||
default is 5
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
|
||||
thickness: int = 3, # 1 for seeclick 2 for mind2web and 3 for demo
|
||||
text_color: Color = Color.BLACK,
|
||||
text_scale: float = 0.5, # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
|
||||
text_thickness: int = 2, #1, # 2 for demo
|
||||
text_padding: int = 10,
|
||||
avoid_overlap: bool = True,
|
||||
):
|
||||
self.color: Union[Color, ColorPalette] = color
|
||||
self.thickness: int = thickness
|
||||
self.text_color: Color = text_color
|
||||
self.text_scale: float = text_scale
|
||||
self.text_thickness: int = text_thickness
|
||||
self.text_padding: int = text_padding
|
||||
self.avoid_overlap: bool = avoid_overlap
|
||||
|
||||
def annotate(
|
||||
self,
|
||||
scene: np.ndarray,
|
||||
detections: Detections,
|
||||
labels: Optional[List[str]] = None,
|
||||
skip_label: bool = False,
|
||||
image_size: Optional[Tuple[int, int]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Draws bounding boxes on the frame using the detections provided.
|
||||
|
||||
Args:
|
||||
scene (np.ndarray): The image on which the bounding boxes will be drawn
|
||||
detections (Detections): The detections for which the
|
||||
bounding boxes will be drawn
|
||||
labels (Optional[List[str]]): An optional list of labels
|
||||
corresponding to each detection. If `labels` are not provided,
|
||||
corresponding `class_id` will be used as label.
|
||||
skip_label (bool): Is set to `True`, skips bounding box label annotation.
|
||||
Returns:
|
||||
np.ndarray: The image with the bounding boxes drawn on it
|
||||
|
||||
Example:
|
||||
```python
|
||||
import supervision as sv
|
||||
|
||||
classes = ['person', ...]
|
||||
image = ...
|
||||
detections = sv.Detections(...)
|
||||
|
||||
box_annotator = sv.BoxAnnotator()
|
||||
labels = [
|
||||
f"{classes[class_id]} {confidence:0.2f}"
|
||||
for _, _, confidence, class_id, _ in detections
|
||||
]
|
||||
annotated_frame = box_annotator.annotate(
|
||||
scene=image.copy(),
|
||||
detections=detections,
|
||||
labels=labels
|
||||
)
|
||||
```
|
||||
"""
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
for i in range(len(detections)):
|
||||
x1, y1, x2, y2 = detections.xyxy[i].astype(int)
|
||||
class_id = (
|
||||
detections.class_id[i] if detections.class_id is not None else None
|
||||
)
|
||||
idx = class_id if class_id is not None else i
|
||||
color = (
|
||||
self.color.by_idx(idx)
|
||||
if isinstance(self.color, ColorPalette)
|
||||
else self.color
|
||||
)
|
||||
cv2.rectangle(
|
||||
img=scene,
|
||||
pt1=(x1, y1),
|
||||
pt2=(x2, y2),
|
||||
color=color.as_bgr(),
|
||||
thickness=self.thickness,
|
||||
)
|
||||
if skip_label:
|
||||
continue
|
||||
|
||||
text = (
|
||||
f"{class_id}"
|
||||
if (labels is None or len(detections) != len(labels))
|
||||
else labels[i]
|
||||
)
|
||||
|
||||
text_width, text_height = cv2.getTextSize(
|
||||
text=text,
|
||||
fontFace=font,
|
||||
fontScale=self.text_scale,
|
||||
thickness=self.text_thickness,
|
||||
)[0]
|
||||
|
||||
if not self.avoid_overlap:
|
||||
text_x = x1 + self.text_padding
|
||||
text_y = y1 - self.text_padding
|
||||
|
||||
text_background_x1 = x1
|
||||
text_background_y1 = y1 - 2 * self.text_padding - text_height
|
||||
|
||||
text_background_x2 = x1 + 2 * self.text_padding + text_width
|
||||
text_background_y2 = y1
|
||||
# text_x = x1 - self.text_padding - text_width
|
||||
# text_y = y1 + self.text_padding + text_height
|
||||
# text_background_x1 = x1 - 2 * self.text_padding - text_width
|
||||
# text_background_y1 = y1
|
||||
# text_background_x2 = x1
|
||||
# text_background_y2 = y1 + 2 * self.text_padding + text_height
|
||||
else:
|
||||
text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 = get_optimal_label_pos(self.text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size)
|
||||
|
||||
cv2.rectangle(
|
||||
img=scene,
|
||||
pt1=(text_background_x1, text_background_y1),
|
||||
pt2=(text_background_x2, text_background_y2),
|
||||
color=color.as_bgr(),
|
||||
thickness=cv2.FILLED,
|
||||
)
|
||||
# import pdb; pdb.set_trace()
|
||||
box_color = color.as_rgb()
|
||||
luminance = 0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2]
|
||||
text_color = (0,0,0) if luminance > 160 else (255,255,255)
|
||||
cv2.putText(
|
||||
img=scene,
|
||||
text=text,
|
||||
org=(text_x, text_y),
|
||||
fontFace=font,
|
||||
fontScale=self.text_scale,
|
||||
# color=self.text_color.as_rgb(),
|
||||
color=text_color,
|
||||
thickness=self.text_thickness,
|
||||
lineType=cv2.LINE_AA,
|
||||
)
|
||||
return scene
|
||||
|
||||
|
||||
def box_area(box):
|
||||
return (box[2] - box[0]) * (box[3] - box[1])
|
||||
|
||||
def intersection_area(box1, box2):
|
||||
x1 = max(box1[0], box2[0])
|
||||
y1 = max(box1[1], box2[1])
|
||||
x2 = min(box1[2], box2[2])
|
||||
y2 = min(box1[3], box2[3])
|
||||
return max(0, x2 - x1) * max(0, y2 - y1)
|
||||
|
||||
def IoU(box1, box2, return_max=True):
|
||||
intersection = intersection_area(box1, box2)
|
||||
union = box_area(box1) + box_area(box2) - intersection
|
||||
if box_area(box1) > 0 and box_area(box2) > 0:
|
||||
ratio1 = intersection / box_area(box1)
|
||||
ratio2 = intersection / box_area(box2)
|
||||
else:
|
||||
ratio1, ratio2 = 0, 0
|
||||
if return_max:
|
||||
return max(intersection / union, ratio1, ratio2)
|
||||
else:
|
||||
return intersection / union
|
||||
|
||||
|
||||
def get_optimal_label_pos(text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size):
|
||||
""" check overlap of text and background detection box, and get_optimal_label_pos,
|
||||
pos: str, position of the text, must be one of 'top left', 'top right', 'outer left', 'outer right' TODO: if all are overlapping, return the last one, i.e. outer right
|
||||
Threshold: default to 0.3
|
||||
"""
|
||||
|
||||
def get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size):
|
||||
is_overlap = False
|
||||
for i in range(len(detections)):
|
||||
detection = detections.xyxy[i].astype(int)
|
||||
if IoU([text_background_x1, text_background_y1, text_background_x2, text_background_y2], detection) > 0.3:
|
||||
is_overlap = True
|
||||
break
|
||||
# check if the text is out of the image
|
||||
if text_background_x1 < 0 or text_background_x2 > image_size[0] or text_background_y1 < 0 or text_background_y2 > image_size[1]:
|
||||
is_overlap = True
|
||||
return is_overlap
|
||||
|
||||
# if pos == 'top left':
|
||||
text_x = x1 + text_padding
|
||||
text_y = y1 - text_padding
|
||||
|
||||
text_background_x1 = x1
|
||||
text_background_y1 = y1 - 2 * text_padding - text_height
|
||||
|
||||
text_background_x2 = x1 + 2 * text_padding + text_width
|
||||
text_background_y2 = y1
|
||||
is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
|
||||
if not is_overlap:
|
||||
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
|
||||
|
||||
# elif pos == 'outer left':
|
||||
text_x = x1 - text_padding - text_width
|
||||
text_y = y1 + text_padding + text_height
|
||||
|
||||
text_background_x1 = x1 - 2 * text_padding - text_width
|
||||
text_background_y1 = y1
|
||||
|
||||
text_background_x2 = x1
|
||||
text_background_y2 = y1 + 2 * text_padding + text_height
|
||||
is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
|
||||
if not is_overlap:
|
||||
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
|
||||
|
||||
|
||||
# elif pos == 'outer right':
|
||||
text_x = x2 + text_padding
|
||||
text_y = y1 + text_padding + text_height
|
||||
|
||||
text_background_x1 = x2
|
||||
text_background_y1 = y1
|
||||
|
||||
text_background_x2 = x2 + 2 * text_padding + text_width
|
||||
text_background_y2 = y1 + 2 * text_padding + text_height
|
||||
|
||||
is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
|
||||
if not is_overlap:
|
||||
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
|
||||
|
||||
# elif pos == 'top right':
|
||||
text_x = x2 - text_padding - text_width
|
||||
text_y = y1 - text_padding
|
||||
|
||||
text_background_x1 = x2 - 2 * text_padding - text_width
|
||||
text_background_y1 = y1 - 2 * text_padding - text_height
|
||||
|
||||
text_background_x2 = x2
|
||||
text_background_y2 = y1
|
||||
|
||||
is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
|
||||
if not is_overlap:
|
||||
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
|
||||
|
||||
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
|
||||
607
utils.py
Executable file
607
utils.py
Executable file
@@ -0,0 +1,607 @@
|
||||
# from ultralytics import YOLO
|
||||
import os
|
||||
import io
|
||||
import base64
|
||||
import time
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import json
|
||||
import requests
|
||||
# utility function
|
||||
import os
|
||||
from openai import AzureOpenAI
|
||||
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
# %matplotlib inline
|
||||
from matplotlib import pyplot as plt
|
||||
import easyocr
|
||||
reader = easyocr.Reader(['en']) # this needs to run only once to load the model into memory # 'ch_sim',
|
||||
import time
|
||||
import base64
|
||||
|
||||
import os
|
||||
import ast
|
||||
import torch
|
||||
from typing import Tuple, List
|
||||
from torchvision.ops import box_convert
|
||||
import re
|
||||
from torchvision.transforms import ToPILImage
|
||||
import supervision as sv
|
||||
import torchvision.transforms as T
|
||||
|
||||
|
||||
def get_caption_model_processor(model_name="Salesforce/blip2-opt-2.7b", device=None):
|
||||
if not device:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if model_name == "Salesforce/blip2-opt-2.7b":
|
||||
from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
||||
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
||||
model = Blip2ForConditionalGeneration.from_pretrained(
|
||||
"Salesforce/blip2-opt-2.7b", device_map=None, torch_dtype=torch.float16
|
||||
# '/home/yadonglu/sandbox/data/orca/blipv2_ui_merge', device_map=None, torch_dtype=torch.float16
|
||||
)
|
||||
elif model_name == "blip2-opt-2.7b-ui":
|
||||
from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
||||
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
||||
if device == 'cpu':
|
||||
model = Blip2ForConditionalGeneration.from_pretrained(
|
||||
'/home/yadonglu/sandbox/data/orca/blipv2_ui_merge', device_map=None, torch_dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
model = Blip2ForConditionalGeneration.from_pretrained(
|
||||
'/home/yadonglu/sandbox/data/orca/blipv2_ui_merge', device_map=None, torch_dtype=torch.float16
|
||||
)
|
||||
elif model_name == "florence":
|
||||
from transformers import AutoProcessor, AutoModelForCausalLM
|
||||
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
|
||||
if device == 'cpu':
|
||||
model = AutoModelForCausalLM.from_pretrained("/home/yadonglu/sandbox/data/orca/florence-2-base-ft-fft_ep1_rai", torch_dtype=torch.float32, trust_remote_code=True)#.to(device)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained("/home/yadonglu/sandbox/data/orca/florence-2-base-ft-fft_ep1_rai_win_ep5_fixed", torch_dtype=torch.float16, trust_remote_code=True).to(device)
|
||||
elif model_name == 'phi3v_ui':
|
||||
from transformers import AutoModelForCausalLM, AutoProcessor
|
||||
model_id = "microsoft/Phi-3-vision-128k-instruct"
|
||||
model = AutoModelForCausalLM.from_pretrained('/home/yadonglu/sandbox/data/orca/phi3v_ui', device_map=device, trust_remote_code=True, torch_dtype="auto")
|
||||
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
|
||||
elif model_name == 'phi3v':
|
||||
from transformers import AutoModelForCausalLM, AutoProcessor
|
||||
model_id = "microsoft/Phi-3-vision-128k-instruct"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, trust_remote_code=True, torch_dtype="auto")
|
||||
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
|
||||
return {'model': model.to(device), 'processor': processor}
|
||||
|
||||
|
||||
def get_yolo_model():
|
||||
from ultralytics import YOLO
|
||||
# Load the model.
|
||||
# model = YOLO('/home/yadonglu/sandbox/data/yolo/runs/detect/yolov8n_v8_xcyc/weights/best.pt')
|
||||
model = YOLO('/home/yadonglu/sandbox/data/yolo/runs/detect/yolov8n_v8_seq_xcyc_b32_n4_office_ep20/weights/best.pt')
|
||||
return model
|
||||
|
||||
|
||||
def get_parsed_content_icon(filtered_boxes, ocr_bbox, image_source, caption_model_processor, prompt=None):
|
||||
to_pil = ToPILImage()
|
||||
if ocr_bbox:
|
||||
non_ocr_boxes = filtered_boxes[len(ocr_bbox):]
|
||||
else:
|
||||
non_ocr_boxes = filtered_boxes
|
||||
croped_pil_image = []
|
||||
for i, coord in enumerate(non_ocr_boxes):
|
||||
xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
|
||||
ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
|
||||
cropped_image = image_source[ymin:ymax, xmin:xmax, :]
|
||||
croped_pil_image.append(to_pil(cropped_image))
|
||||
|
||||
# import pdb; pdb.set_trace()
|
||||
model, processor = caption_model_processor['model'], caption_model_processor['processor']
|
||||
if not prompt:
|
||||
if 'florence' in model.config.name_or_path:
|
||||
prompt = "<CAPTION>"
|
||||
else:
|
||||
prompt = "The image shows"
|
||||
# prompt = "NO gender!NO gender!NO gender! The image shows a icon:"
|
||||
|
||||
batch_size = 10 # Number of samples per batch
|
||||
generated_texts = []
|
||||
device = model.device
|
||||
|
||||
for i in range(0, len(croped_pil_image), batch_size):
|
||||
batch = croped_pil_image[i:i+batch_size]
|
||||
if model.device.type == 'cuda':
|
||||
inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device, dtype=torch.float16)
|
||||
else:
|
||||
inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
|
||||
if 'florence' in model.config.name_or_path:
|
||||
generated_ids = model.generate(input_ids=inputs["input_ids"],pixel_values=inputs["pixel_values"],max_new_tokens=1024,num_beams=3, do_sample=False)
|
||||
else:
|
||||
generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1) # temperature=0.01, do_sample=True,
|
||||
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
generated_text = [gen.strip() for gen in generated_text]
|
||||
generated_texts.extend(generated_text)
|
||||
|
||||
return generated_texts
|
||||
|
||||
|
||||
|
||||
def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
|
||||
to_pil = ToPILImage()
|
||||
if ocr_bbox:
|
||||
non_ocr_boxes = filtered_boxes[len(ocr_bbox):]
|
||||
else:
|
||||
non_ocr_boxes = filtered_boxes
|
||||
croped_pil_image = []
|
||||
for i, coord in enumerate(non_ocr_boxes):
|
||||
xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
|
||||
ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
|
||||
cropped_image = image_source[ymin:ymax, xmin:xmax, :]
|
||||
croped_pil_image.append(to_pil(cropped_image))
|
||||
|
||||
model, processor = caption_model_processor['model'], caption_model_processor['processor']
|
||||
device = model.device
|
||||
messages = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}]
|
||||
prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
|
||||
batch_size = 5 # Number of samples per batch
|
||||
generated_texts = []
|
||||
|
||||
for i in range(0, len(croped_pil_image), batch_size):
|
||||
images = croped_pil_image[i:i+batch_size]
|
||||
image_inputs = [processor.image_processor(x, return_tensors="pt") for x in images]
|
||||
inputs ={'input_ids': [], 'attention_mask': [], 'pixel_values': [], 'image_sizes': []}
|
||||
texts = [prompt] * len(images)
|
||||
for i, txt in enumerate(texts):
|
||||
input = processor._convert_images_texts_to_inputs(image_inputs[i], txt, return_tensors="pt")
|
||||
inputs['input_ids'].append(input['input_ids'])
|
||||
inputs['attention_mask'].append(input['attention_mask'])
|
||||
inputs['pixel_values'].append(input['pixel_values'])
|
||||
inputs['image_sizes'].append(input['image_sizes'])
|
||||
max_len = max([x.shape[1] for x in inputs['input_ids']])
|
||||
for i, v in enumerate(inputs['input_ids']):
|
||||
inputs['input_ids'][i] = torch.cat([processor.tokenizer.pad_token_id * torch.ones(1, max_len - v.shape[1], dtype=torch.long), v], dim=1)
|
||||
inputs['attention_mask'][i] = torch.cat([torch.zeros(1, max_len - v.shape[1], dtype=torch.long), inputs['attention_mask'][i]], dim=1)
|
||||
inputs_cat = {k: torch.concatenate(v).to(device) for k, v in inputs.items()}
|
||||
|
||||
generation_args = {
|
||||
"max_new_tokens": 25,
|
||||
"temperature": 0.01,
|
||||
"do_sample": False,
|
||||
}
|
||||
generate_ids = model.generate(**inputs_cat, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
|
||||
# # remove input tokens
|
||||
generate_ids = generate_ids[:, inputs_cat['input_ids'].shape[1]:]
|
||||
response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
response = [res.strip('\n').strip() for res in response]
|
||||
generated_texts.extend(response)
|
||||
|
||||
return generated_texts
|
||||
|
||||
def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
|
||||
assert ocr_bbox is None or isinstance(ocr_bbox, List)
|
||||
|
||||
def box_area(box):
|
||||
return (box[2] - box[0]) * (box[3] - box[1])
|
||||
|
||||
def intersection_area(box1, box2):
|
||||
x1 = max(box1[0], box2[0])
|
||||
y1 = max(box1[1], box2[1])
|
||||
x2 = min(box1[2], box2[2])
|
||||
y2 = min(box1[3], box2[3])
|
||||
return max(0, x2 - x1) * max(0, y2 - y1)
|
||||
|
||||
def IoU(box1, box2):
|
||||
intersection = intersection_area(box1, box2)
|
||||
union = box_area(box1) + box_area(box2) - intersection + 1e-6
|
||||
if box_area(box1) > 0 and box_area(box2) > 0:
|
||||
ratio1 = intersection / box_area(box1)
|
||||
ratio2 = intersection / box_area(box2)
|
||||
else:
|
||||
ratio1, ratio2 = 0, 0
|
||||
return max(intersection / union, ratio1, ratio2)
|
||||
|
||||
boxes = boxes.tolist()
|
||||
filtered_boxes = []
|
||||
if ocr_bbox:
|
||||
filtered_boxes.extend(ocr_bbox)
|
||||
# print('ocr_bbox!!!', ocr_bbox)
|
||||
for i, box1 in enumerate(boxes):
|
||||
# if not any(IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2) for j, box2 in enumerate(boxes) if i != j):
|
||||
is_valid_box = True
|
||||
for j, box2 in enumerate(boxes):
|
||||
if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
|
||||
is_valid_box = False
|
||||
break
|
||||
if is_valid_box:
|
||||
# add the following 2 lines to include ocr bbox
|
||||
if ocr_bbox:
|
||||
if not any(IoU(box1, box3) > iou_threshold for k, box3 in enumerate(ocr_bbox)):
|
||||
filtered_boxes.append(box1)
|
||||
else:
|
||||
filtered_boxes.append(box1)
|
||||
return torch.tensor(filtered_boxes)
|
||||
|
||||
def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
|
||||
transform = T.Compose(
|
||||
[
|
||||
T.RandomResize([800], max_size=1333),
|
||||
T.ToTensor(),
|
||||
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||
]
|
||||
)
|
||||
image_source = Image.open(image_path).convert("RGB")
|
||||
image = np.asarray(image_source)
|
||||
image_transformed, _ = transform(image_source, None)
|
||||
return image, image_transformed
|
||||
|
||||
|
||||
def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str], text_scale: float,
|
||||
text_padding=5, text_thickness=2, thickness=3) -> np.ndarray:
|
||||
"""
|
||||
This function annotates an image with bounding boxes and labels.
|
||||
|
||||
Parameters:
|
||||
image_source (np.ndarray): The source image to be annotated.
|
||||
boxes (torch.Tensor): A tensor containing bounding box coordinates. in cxcywh format, pixel scale
|
||||
logits (torch.Tensor): A tensor containing confidence scores for each bounding box.
|
||||
phrases (List[str]): A list of labels for each bounding box.
|
||||
text_scale (float): The scale of the text to be displayed. 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
|
||||
|
||||
Returns:
|
||||
np.ndarray: The annotated image.
|
||||
"""
|
||||
h, w, _ = image_source.shape
|
||||
boxes = boxes * torch.Tensor([w, h, w, h])
|
||||
xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
|
||||
xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").numpy()
|
||||
detections = sv.Detections(xyxy=xyxy)
|
||||
|
||||
labels = [f"{phrase}" for phrase in range(boxes.shape[0])]
|
||||
|
||||
from util.box_annotator import BoxAnnotator
|
||||
box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,text_thickness=text_thickness,thickness=thickness) # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
|
||||
annotated_frame = image_source.copy()
|
||||
annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w,h))
|
||||
|
||||
label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
|
||||
return annotated_frame, label_coordinates
|
||||
|
||||
|
||||
def predict(model, image, caption, box_threshold, text_threshold):
|
||||
""" Use huggingface model to replace the original model
|
||||
"""
|
||||
model, processor = model['model'], model['processor']
|
||||
device = model.device
|
||||
|
||||
inputs = processor(images=image, text=caption, return_tensors="pt").to(device)
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
results = processor.post_process_grounded_object_detection(
|
||||
outputs,
|
||||
inputs.input_ids,
|
||||
box_threshold=box_threshold, # 0.4,
|
||||
text_threshold=text_threshold, # 0.3,
|
||||
target_sizes=[image.size[::-1]]
|
||||
)[0]
|
||||
boxes, logits, phrases = results["boxes"], results["scores"], results["labels"]
|
||||
return boxes, logits, phrases
|
||||
|
||||
|
||||
def predict_yolo(model, image_path, box_threshold):
|
||||
""" Use huggingface model to replace the original model
|
||||
"""
|
||||
# model = model['model']
|
||||
|
||||
result = model.predict(
|
||||
source=image_path,
|
||||
conf=box_threshold,
|
||||
# iou=0.5, # default 0.7
|
||||
)
|
||||
boxes = result[0].boxes.xyxy#.tolist() # in pixel space
|
||||
conf = result[0].boxes.conf
|
||||
phrases = [str(i) for i in range(len(boxes))]
|
||||
|
||||
return boxes, conf, phrases
|
||||
|
||||
|
||||
def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD = 0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None):
|
||||
""" ocr_bbox: list of xyxy format bbox
|
||||
"""
|
||||
TEXT_PROMPT = "clickable buttons on the screen"
|
||||
# BOX_TRESHOLD = 0.02 # 0.05/0.02 for web and 0.1 for mobile
|
||||
TEXT_TRESHOLD = 0.01 # 0.9 # 0.01
|
||||
image_source = Image.open(img_path).convert("RGB")
|
||||
w, h = image_source.size
|
||||
# import pdb; pdb.set_trace()
|
||||
if False: # TODO
|
||||
xyxy, logits, phrases = predict(model=model, image=image_source, caption=TEXT_PROMPT, box_threshold=BOX_TRESHOLD, text_threshold=TEXT_TRESHOLD)
|
||||
else:
|
||||
xyxy, logits, phrases = predict_yolo(model=model, image_path=img_path, box_threshold=BOX_TRESHOLD)
|
||||
xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
|
||||
image_source = np.asarray(image_source)
|
||||
phrases = [str(i) for i in range(len(phrases))]
|
||||
|
||||
# annotate the image with labels
|
||||
h, w, _ = image_source.shape
|
||||
if ocr_bbox:
|
||||
ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
|
||||
ocr_bbox=ocr_bbox.tolist()
|
||||
else:
|
||||
print('no ocr bbox!!!')
|
||||
ocr_bbox = None
|
||||
filtered_boxes = remove_overlap(boxes=xyxy, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox)
|
||||
|
||||
# get parsed icon local semantics
|
||||
if use_local_semantics:
|
||||
caption_model = caption_model_processor['model']
|
||||
if 'phi3_v' in caption_model.config.model_type:
|
||||
parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor)
|
||||
else:
|
||||
parsed_content_icon = get_parsed_content_icon(filtered_boxes, ocr_bbox, image_source, caption_model_processor, prompt=prompt)
|
||||
ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
|
||||
icon_start = len(ocr_text)
|
||||
parsed_content_icon_ls = []
|
||||
for i, txt in enumerate(parsed_content_icon):
|
||||
parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
|
||||
parsed_content_merged = ocr_text + parsed_content_icon_ls
|
||||
else:
|
||||
ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
|
||||
parsed_content_merged = ocr_text
|
||||
|
||||
filtered_boxes = box_convert(boxes=filtered_boxes, in_fmt="xyxy", out_fmt="cxcywh")
|
||||
|
||||
phrases = [i for i in range(len(filtered_boxes))]
|
||||
|
||||
# draw boxes
|
||||
if draw_bbox_config:
|
||||
annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, **draw_bbox_config)
|
||||
else:
|
||||
annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, text_scale=text_scale, text_padding=text_padding)
|
||||
|
||||
pil_img = Image.fromarray(annotated_frame)
|
||||
buffered = io.BytesIO()
|
||||
pil_img.save(buffered, format="PNG")
|
||||
encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
|
||||
if output_coord_in_ratio:
|
||||
# h, w, _ = image_source.shape
|
||||
label_coordinates = {k: [v[0]/w, v[1]/h, v[2]/w, v[3]/h] for k, v in label_coordinates.items()}
|
||||
assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]
|
||||
|
||||
return encoded_image, label_coordinates, parsed_content_merged
|
||||
|
||||
|
||||
def get_xywh(input):
|
||||
x, y, w, h = input[0][0], input[0][1], input[2][0] - input[0][0], input[2][1] - input[0][1]
|
||||
x, y, w, h = int(x), int(y), int(w), int(h)
|
||||
return x, y, w, h
|
||||
|
||||
def get_xyxy(input):
|
||||
x, y, xp, yp = input[0][0], input[0][1], input[2][0], input[2][1]
|
||||
x, y, xp, yp = int(x), int(y), int(xp), int(yp)
|
||||
return x, y, xp, yp
|
||||
|
||||
def get_xywh_yolo(input):
|
||||
x, y, w, h = input[0], input[1], input[2] - input[0], input[3] - input[1]
|
||||
x, y, w, h = int(x), int(y), int(w), int(h)
|
||||
return x, y, w, h
|
||||
|
||||
|
||||
def run_api(body, max_tokens=1024):
|
||||
'''
|
||||
API call, check https://platform.openai.com/docs/guides/vision for the latest api usage.
|
||||
'''
|
||||
max_num_trial = 3
|
||||
num_trial = 0
|
||||
while num_trial < max_num_trial:
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=deployment,
|
||||
messages=body,
|
||||
temperature=0.01,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
except:
|
||||
print('retry call gptv', num_trial)
|
||||
num_trial += 1
|
||||
time.sleep(10)
|
||||
return ''
|
||||
|
||||
def call_gpt4v_new(message_text, image_path=None, max_tokens=2048):
|
||||
if image_path:
|
||||
try:
|
||||
with open(image_path, "rb") as img_file:
|
||||
encoded_image = base64.b64encode(img_file.read()).decode('ascii')
|
||||
except:
|
||||
encoded_image = image_path
|
||||
|
||||
if image_path:
|
||||
content = [{"type": "image_url","image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}, {"type": "text","text": message_text},]
|
||||
else:
|
||||
content = [{"type": "text","text": message_text},]
|
||||
|
||||
max_num_trial = 3
|
||||
num_trial = 0
|
||||
call_api_success = True
|
||||
|
||||
while num_trial < max_num_trial:
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=deployment,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are an AI assistant that is good at making plans and analyzing screens, and helping people find information."
|
||||
},
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": content
|
||||
}
|
||||
],
|
||||
temperature=0.01,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
ans_1st_pass = response.choices[0].message.content
|
||||
break
|
||||
except:
|
||||
print('retry call gptv', num_trial)
|
||||
num_trial += 1
|
||||
ans_1st_pass = ''
|
||||
time.sleep(10)
|
||||
if num_trial == max_num_trial:
|
||||
call_api_success = False
|
||||
return ans_1st_pass, call_api_success
|
||||
|
||||
|
||||
def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None):
|
||||
if easyocr_args is None:
|
||||
easyocr_args = {}
|
||||
result = reader.readtext(image_path, **easyocr_args)
|
||||
is_goal_filtered = False
|
||||
if goal_filtering:
|
||||
ocr_filter_fs = "Example 1:\n Based on task and ocr results, ```In summary, the task related bboxes are: [([[3060, 111], [3135, 111], [3135, 141], [3060, 141]], 'Share', 0.949013667261589), ([[3068, 197], [3135, 197], [3135, 227], [3068, 227]], 'Link _', 0.3567054243152049), ([[3006, 321], [3178, 321], [3178, 354], [3006, 354]], 'Manage Access', 0.8800734456437066)] ``` \n Example 2:\n Based on task and ocr results, ```In summary, the task related bboxes are: [([[3060, 111], [3135, 111], [3135, 141], [3060, 141]], 'Search Google or type a URL', 0.949013667261589)] ```"
|
||||
# message_text = f"Based on the ocr results which contains text+bounding box in a dictionary, please filter it so that it only contains the task related bboxes. The task is: {goal_filtering}, the ocr results are: {str(result)}. Your final answer should be in the exact same format as the ocr results, please do not include any other redundant information, please do not include any analysis."
|
||||
message_text = f"Based on the task and ocr results which contains text+bounding box in a dictionary, please filter it so that it only contains the task related bboxes. Requirement: 1. first give a brief analysis. 2. provide an answer in the format: ```In summary, the task related bboxes are: ..```, you must put it inside ``` ```. Do not include any info after ```.\n {ocr_filter_fs}\n The task is: {goal_filtering}, the ocr results are: {str(result)}."
|
||||
|
||||
prompt = [{"role":"system", "content": "You are an AI assistant that helps people find the correct way to operate computer or smartphone."}, {"role":"user","content": message_text},]
|
||||
print('[Perform OCR filtering by goal] ongoing ...')
|
||||
# pred, _, _ = call_gpt4(prompt)
|
||||
pred, _, = call_gpt4v(message_text)
|
||||
# import pdb; pdb.set_trace()
|
||||
try:
|
||||
# match = re.search(r"```(.*?)```", pred, re.DOTALL)
|
||||
# result = match.group(1).strip()
|
||||
# pred = result.split('In summary, the task related bboxes are:')[-1].strip()
|
||||
pred = pred.split('In summary, the task related bboxes are:')[-1].strip().strip('```')
|
||||
result = ast.literal_eval(pred)
|
||||
print('[Perform OCR filtering by goal] success!!! Filtered buttons: ', pred)
|
||||
is_goal_filtered = True
|
||||
except:
|
||||
print('[Perform OCR filtering by goal] failed or unused!!!')
|
||||
pass
|
||||
# added_prompt = [{"role":"assistant","content":pred},
|
||||
# {"role":"user","content": "given the previous answers, please provide the final answer in the exact same format as the ocr results, please do not include any other redundant information, please do not include any analysis."}]
|
||||
# prompt.extend(added_prompt)
|
||||
# pred, _, _ = call_gpt4(prompt)
|
||||
# print('goal filtering pred 2nd:', pred)
|
||||
# result = ast.literal_eval(pred)
|
||||
# print('goal filtering pred:', result[-5:])
|
||||
coord = [item[0] for item in result]
|
||||
text = [item[1] for item in result]
|
||||
# confidence = [item[2] for item in result]
|
||||
# if confidence_filtering:
|
||||
# coord = [coord[i] for i in range(len(coord)) if confidence[i] > confidence_filtering]
|
||||
# text = [text[i] for i in range(len(text)) if confidence[i] > confidence_filtering]
|
||||
# read the image using cv2
|
||||
if display_img:
|
||||
opencv_img = cv2.imread(image_path)
|
||||
opencv_img = cv2.cvtColor(opencv_img, cv2.COLOR_RGB2BGR)
|
||||
bb = []
|
||||
for item in coord:
|
||||
x, y, a, b = get_xywh(item)
|
||||
# print(x, y, a, b)
|
||||
bb.append((x, y, a, b))
|
||||
cv2.rectangle(opencv_img, (x, y), (x+a, y+b), (0, 255, 0), 2)
|
||||
|
||||
# Display the image
|
||||
plt.imshow(opencv_img)
|
||||
else:
|
||||
if output_bb_format == 'xywh':
|
||||
bb = [get_xywh(item) for item in coord]
|
||||
elif output_bb_format == 'xyxy':
|
||||
bb = [get_xyxy(item) for item in coord]
|
||||
# print('bounding box!!!', bb)
|
||||
return (text, bb), is_goal_filtered
|
||||
|
||||
|
||||
def get_pred_gptv(message_text, yolo_labled_img, label_coordinates, summarize_history=True, verbose=True, history=None, id_key='Click ID'):
|
||||
""" This func first
|
||||
1. call gptv(yolo_labled_img, text bbox+task) -> ans_1st_cal
|
||||
2. call gpt4(ans_1st_cal, label_coordinates) -> final ans
|
||||
"""
|
||||
|
||||
# Configuration
|
||||
encoded_image = yolo_labled_img
|
||||
|
||||
# Payload for the request
|
||||
if not history:
|
||||
messages = [
|
||||
{"role": "system", "content": [{"type": "text","text": "You are an AI assistant that is great at interpreting screenshot and predict action."},]},
|
||||
{"role": "user","content": [{"type": "text","text": message_text}, {"type": "image_url","image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},]}
|
||||
]
|
||||
else:
|
||||
messages = [
|
||||
{"role": "system", "content": [{"type": "text","text": "You are an AI assistant that is great at interpreting screenshot and predict action."},]},
|
||||
history,
|
||||
{"role": "user","content": [{"type": "image_url","image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},{"type": "text","text": message_text},]}
|
||||
]
|
||||
|
||||
payload = {
|
||||
"messages": messages,
|
||||
"temperature": 0.01, # 0.01
|
||||
"top_p": 0.95,
|
||||
"max_tokens": 800
|
||||
}
|
||||
|
||||
max_num_trial = 3
|
||||
num_trial = 0
|
||||
call_api_success = True
|
||||
while num_trial < max_num_trial:
|
||||
try:
|
||||
# response = requests.post(GPT4V_ENDPOINT, headers=headers, json=payload)
|
||||
# response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
|
||||
# ans_1st_pass = response.json()['choices'][0]['message']['content']
|
||||
response = client.chat.completions.create(
|
||||
model=deployment,
|
||||
messages=messages,
|
||||
temperature=0.01,
|
||||
max_tokens=512,
|
||||
)
|
||||
ans_1st_pass = response.choices[0].message.content
|
||||
break
|
||||
except requests.RequestException as e:
|
||||
print('retry call gptv', num_trial)
|
||||
num_trial += 1
|
||||
ans_1st_pass = ''
|
||||
time.sleep(30)
|
||||
# raise SystemExit(f"Failed to make the request. Error: {e}")
|
||||
if num_trial == max_num_trial:
|
||||
call_api_success = False
|
||||
if verbose:
|
||||
print('Answer by GPTV: ', ans_1st_pass)
|
||||
# extract by simple parsing
|
||||
try:
|
||||
match = re.search(r"```(.*?)```", ans_1st_pass, re.DOTALL)
|
||||
if match:
|
||||
result = match.group(1).strip()
|
||||
pred = result.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '')
|
||||
pred = ast.literal_eval(pred)
|
||||
else:
|
||||
pred = ans_1st_pass.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '')
|
||||
pred = ast.literal_eval(pred)
|
||||
|
||||
if id_key in pred:
|
||||
icon_id = pred[id_key]
|
||||
bbox = label_coordinates[str(icon_id)]
|
||||
pred['click_point'] = [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2]
|
||||
except:
|
||||
# import pdb; pdb.set_trace()
|
||||
print('gptv action regex extract fail!!!')
|
||||
print('ans_1st_pass:', ans_1st_pass)
|
||||
pred = {'action_type': 'CLICK', 'click_point': [0, 0], 'value': 'None', 'is_completed': False}
|
||||
|
||||
step_pred_summary = None
|
||||
if summarize_history:
|
||||
step_pred_summary, _ = call_gpt4v_new('Summarize what action you decide to perform in the current step, in one sentence, and do not include any icon box number: ' + ans_1st_pass, max_tokens=128)
|
||||
print('step_pred_summary', step_pred_summary)
|
||||
return pred, [call_api_success, ans_1st_pass, None, step_pred_summary]
|
||||
# return pred, [call_api_success, message_2nd, completion_2nd.choices[0].message.content, step_pred_summary]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user