.*?)'\s*\}"
-
- # Search for the dictionary in the text
- match = re.search(pattern, text)
-
- if match:
- # Extract matched groups into a dictionary
- return {
- match.group('key1'): match.group('value1'),
- match.group('key2'): match.group('value2'),
- }
- else:
- raise ValueError("No valid dictionary structure found in the text.")
-
-
-def get_phi3v_model_dict():
- from PIL import Image
- import requests
- from transformers import AutoModelForCausalLM
- from transformers import AutoProcessor
-
- model_id = "microsoft/Phi-3.5-vision-instruct"
- model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", trust_remote_code=True, torch_dtype="auto", _attn_implementation='flash_attention_2')
- processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
- print('phi3v model loaded!!!')
- return {'model': model, 'processor': processor}
-
-
-def call_phi3v(messages, image_base64, model_dict):
- model, processor = model_dict['model'], model_dict['processor']
- device = model.device
- prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
- if isinstance(image_base64, tuple):
- image_base64, dino_labled_img = image_base64
- image = Image.open(io.BytesIO(base64.b64decode(image_base64)))
- dino_labled_img = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
- inputs = processor(prompt, [image, dino_labled_img], return_tensors="pt").to(device)
- else:
- image = Image.open(io.BytesIO(base64.b64decode(image_base64)))
- inputs = processor(prompt, [image], return_tensors="pt").to(device)
-
- generation_args = {
- "max_new_tokens": 512,
- "temperature": 0.01,
- "do_sample": False,
- }
-
- generate_ids = model.generate(**inputs, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
- # remove input tokens
- generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
- ans = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- return ans
-
-
-def get_pred_phi3v(message_text, image_base64, label_coordinates, id_key='Click ID', model_dict=None):
- # messages = [
- # {"role": "system", "content": '''You are an expert at completing instructions on GUI screens.
- # You will be presented with two images. The first is the original screenshot. The second is the same screenshot with some numeric tags. You will also be provided with some descriptions of the bbox, and your task is to choose the numeric bbox idx you want to click in order to complete the user instruction.'''},
- # ]
- messages = [
- {"role": "system", "content": '''You are an expert at completing instructions on GUI screens. You will also be provided with some descriptions of the bbox, and your task is to choose the numeric bbox idx you want to click in order to complete the user instruction.'''},
- ]
- messages = []
- if isinstance(image_base64, tuple):
- messages.append({"role": "user", "content": '<|image_1|>\n' + '<|image_2|>\n' + message_text})
- else:
- messages.append({"role": "user", "content": '<|image_1|>\n' + message_text})
-
- response_text = call_phi3v(messages, image_base64, model_dict)
- print(response_text)
-
- try:
- response_text = ast.literal_eval(response_text)
-
- icon_id = response_text['Click BBox ID']
- bbox = label_coordinates[str(icon_id)]
- click_point = [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2]
- except:
- print('error parsing, use regex to parse!!!')
- import pdb; pdb.set_trace()
- response_text = extract_dict_from_text(response_text)
- icon_id = response_text['Click BBox ID']
- bbox = label_coordinates[str(icon_id)]
- click_point = [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2]
- return icon_id, bbox, click_point, response_text
-
- # try:
- # match = re.search(r"```(.*?)```", ans, re.DOTALL)
- # if match:
- # result = match.group(1).strip()
- # pred = result.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '')
- # pred = ast.literal_eval(pred)
- # else:
- # pred = ans.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '')
- # pred = ast.literal_eval(pred)
-
- # if pred[id_key]:
- # icon_id = pred[id_key]
- # bbox = label_coordinates[str(icon_id)]
- # pred['click_point'] = [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2]
- # except:
- # print('phi3v action regex extract fail!!!')
- # pred = {'action_type': 'CLICK', 'click_point': [0, 0], 'value': 'None', 'is_completed': False}
-
- # step_pred_summary = None
- # return pred, [True, ans, None, step_pred_summary]
\ No newline at end of file
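
The hunk above removes the Phi-3.5-vision evaluation helper. Its parsing contract is easy to lose in the diff, so here is a minimal, self-contained sketch of what get_pred_phi3v did with the model reply: the model answers with a Python-dict literal carrying a 'Click BBox ID' key, the id is looked up in label_coordinates, and the (x, y, w, h) box is reduced to its center click point, with a regex fallback when literal_eval fails. The sample reply and coordinates below are hypothetical.

import ast
import re

def parse_click(response_text, label_coordinates):
    try:
        pred = ast.literal_eval(response_text)
    except (ValueError, SyntaxError):
        # Fallback: recover the id by regex when the reply is not a clean literal.
        match = re.search(r"'Click BBox ID'\s*:\s*'?(\d+)'?", response_text)
        if match is None:
            raise ValueError('no Click BBox ID found in model response')
        pred = {'Click BBox ID': match.group(1)}
    icon_id = pred['Click BBox ID']
    x, y, w, h = label_coordinates[str(icon_id)]
    return icon_id, [x + w / 2, y + h / 2]  # center of the chosen box

print(parse_click("{'Click BBox ID': '3'}", {'3': (10, 20, 100, 40)}))  # ('3', [60.0, 40.0])
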
diff --git a/demo/gradio/demo_image.jpg b/demo/gradio/demo_image.jpg
deleted file mode 100644
index 06a901c..0000000
Binary files a/demo/gradio/demo_image.jpg and /dev/null differ
diff --git a/demo/gradio/demo_image_som.jpg b/demo/gradio/demo_image_som.jpg
deleted file mode 100644
index 2d580f3..0000000
Binary files a/demo/gradio/demo_image_som.jpg and /dev/null differ
diff --git a/demo/omniparserserver/omniparser.py b/demo/omniparserserver/omniparser.py
new file mode 100644
index 0000000..424430f
--- /dev/null
+++ b/demo/omniparserserver/omniparser.py
@@ -0,0 +1,34 @@
+from util.utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box  # resolves once the repo root is on sys.path (see remote_request.py)
+import torch
+from PIL import Image
+import io
+import base64
+from typing import Dict
+
+class Omniparser(object):
+ def __init__(self, config: Dict):
+ self.config = config
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ self.som_model = get_yolo_model(model_path=config['som_model_path'])
+ self.caption_model_processor = get_caption_model_processor(model_name=config['caption_model_name'], model_name_or_path=config['caption_model_path'], device=device)
+ print('Omniparser initialized!!!')
+
+ def parse(self, image_base64: str):
+ image_bytes = base64.b64decode(image_base64)
+ image = Image.open(io.BytesIO(image_bytes))
+ print('image size:', image.size)
+
+ box_overlay_ratio = max(image.size) / 3200
+ draw_bbox_config = {
+ 'text_scale': 0.8 * box_overlay_ratio,
+ 'text_thickness': max(int(2 * box_overlay_ratio), 1),
+ 'text_padding': max(int(3 * box_overlay_ratio), 1),
+ 'thickness': max(int(3 * box_overlay_ratio), 1),
+ }
+ BOX_TRESHOLD = self.config['BOX_TRESHOLD']
+
+ (text, ocr_bbox), _ = check_ocr_box(image, display_img=False, output_bb_format='xyxy', easyocr_args={'text_threshold': 0.8}, use_paddleocr=False)
+        dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image, self.som_model, BOX_TRESHOLD=BOX_TRESHOLD, output_coord_in_ratio=True, ocr_bbox=ocr_bbox, draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text, use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128)
+
+ return dino_labled_img, parsed_content_list
\ No newline at end of file
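
A minimal sketch of driving the new Omniparser class directly, assuming it is run from demo/omniparserserver and that the weight paths (taken from the remote_request.py defaults) exist locally; screenshot.png is a hypothetical input:

import base64

from omniparser import Omniparser

config = {
    'som_model_path': '../weights/icon_detect_v1_5/model_v1_5.pt',
    'caption_model_name': 'florence2',
    'caption_model_path': '../weights/icon_caption_florence',
    'device': 'cuda',
    'BOX_TRESHOLD': 0.05,
}

parser = Omniparser(config)
with open('screenshot.png', 'rb') as f:
    image_base64 = base64.b64encode(f.read()).decode('utf-8')

# parse() returns the annotated screenshot (base64) and the parsed element list.
som_image_base64, parsed_content_list = parser.parse(image_base64)
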
diff --git a/demo/omniparserserver/remote_request.py b/demo/omniparserserver/remote_request.py
new file mode 100644
index 0000000..816940a
--- /dev/null
+++ b/demo/omniparserserver/remote_request.py
@@ -0,0 +1,50 @@
+'''
+python -m remote_request --som_model_path ../weights/icon_detect_v1_5/model_v1_5.pt --caption_model_name florence2 --caption_model_path ../weights/icon_caption_florence --device cuda --BOX_TRESHOLD 0.05
+'''
+
+import sys
+import os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))  # repo root, so util.utils resolves
+import time
+from fastapi import FastAPI
+from pydantic import BaseModel
+import argparse
+import uvicorn
+from omniparser import Omniparser
+
+def parse_arguments():
+ parser = argparse.ArgumentParser(description='Omniparser API')
+ parser.add_argument('--som_model_path', type=str, default='../weights/icon_detect_v1_5/model_v1_5.pt', help='Path to the som model')
+ parser.add_argument('--caption_model_name', type=str, default='florence2', help='Name of the caption model')
+ parser.add_argument('--caption_model_path', type=str, default='../weights/icon_caption_florence', help='Path to the caption model')
+ parser.add_argument('--device', type=str, default='cpu', help='Device to run the model')
+ parser.add_argument('--BOX_TRESHOLD', type=float, default=0.05, help='Threshold for box detection')
+ parser.add_argument('--host', type=str, default='0.0.0.0', help='Host for the API')
+ parser.add_argument('--port', type=int, default=8000, help='Port for the API')
+ args = parser.parse_args()
+ return args
+
+args = parse_arguments()
+config = vars(args)
+
+app = FastAPI()
+omniparser = Omniparser(config)
+
+class ParseRequest(BaseModel):
+ base64_image: str
+
+@app.post("/parse/")
+async def parse(parse_request: ParseRequest):
+ print('start parsing...')
+ start = time.time()
+ dino_labled_img, parsed_content_list = omniparser.parse(parse_request.base64_image)
+ latency = time.time() - start
+ print('time:', latency)
+ return {"som_image_base64": dino_labled_img, "parsed_content_list": parsed_content_list, 'latency': latency}
+
+@app.get("/probe/")
+async def root():
+ return {"message": "Omniparser API ready"}
+
+if __name__ == "__main__":
+ uvicorn.run("remote_request:app", host=args.host, port=args.port, reload=True)
\ No newline at end of file
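
The endpoint above accepts a JSON body matching ParseRequest and returns the annotated image, the parsed element list, and the server-side latency. A sketch of a client call with the requests library, assuming the server runs on the argparse defaults (port 8000) and screenshot.png is a hypothetical local image:

import base64

import requests

with open('screenshot.png', 'rb') as f:
    payload = {'base64_image': base64.b64encode(f.read()).decode('utf-8')}

resp = requests.post('http://localhost:8000/parse/', json=payload)
resp.raise_for_status()
result = resp.json()
print(result['latency'], len(result['parsed_content_list']))
# result['som_image_base64'] holds the base64-encoded annotated screenshot;
# GET /probe/ can be used as a readiness check before sending work.
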
diff --git a/demo/remote_request.py b/demo/remote_request.py
deleted file mode 100644
index 414be81..0000000
--- a/demo/remote_request.py
+++ /dev/null
@@ -1,101 +0,0 @@
-'''
-python -m remote_request --som_model_path ../weights/icon_detect_v1_5/model_v1_5.pt --caption_model_name florence2 --caption_model_path ../weights/icon_caption_florence --device cuda --BOX_TRESHOLD 0.05
-'''
-
-import sys
-import os
-sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-import time
-from utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box
-import torch
-from PIL import Image
-from typing import Dict, Tuple, List
-import base64
-import io
-from fastapi import FastAPI
-from pydantic import BaseModel
-import argparse
-
-def parse_arguments():
- parser = argparse.ArgumentParser(description='Omniparser API')
- parser.add_argument('--som_model_path', type=str, default='../weights/icon_detect_v1_5/model_v1_5.pt', help='Path to the som model')
- parser.add_argument('--caption_model_name', type=str, default='florence2', help='Name of the caption model')
- parser.add_argument('--caption_model_path', type=str, default='../weights/icon_caption_florence', help='Path to the caption model')
- parser.add_argument('--device', type=str, default='cpu', help='Device to run the model')
- parser.add_argument('--BOX_TRESHOLD', type=float, default=0.05, help='Threshold for box detection')
- parser.add_argument('--host', type=str, default='0.0.0.0', help='Host for the API')
- parser.add_argument('--port', type=int, default=8000, help='Port for the API')
- args = parser.parse_args()
- return args
-
-args = parse_arguments()
-config = vars(args)
-
-
-# config = {
-# 'som_model_path': '../weights/icon_detect_v1_5/model_v1_5.pt',
-# 'device': 'cpu',
-# 'caption_model_name': 'florence2',
-# 'caption_model_path': '../weights/icon_caption_florence',
-# 'BOX_TRESHOLD': 0.05
-# }
-
-
-class Omniparser(object):
- def __init__(self, config: Dict):
- self.config = config
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
- self.som_model = get_yolo_model(model_path=config['som_model_path'])
- self.caption_model_processor = get_caption_model_processor(model_name=config['caption_model_name'], model_name_or_path=config['caption_model_path'], device=device)
- print('Omniparser initialized!!!')
-
- def parse(self, image_base64: str):
- # Convert base64 to image directly without saving to disk
- image_bytes = base64.b64decode(image_base64)
- image = Image.open(io.BytesIO(image_bytes))
- print('image size:', image.size)
-
- box_overlay_ratio = max(image.size) / 3200
- draw_bbox_config = {
- 'text_scale': 0.8 * box_overlay_ratio,
- 'text_thickness': max(int(2 * box_overlay_ratio), 1),
- 'text_padding': max(int(3 * box_overlay_ratio), 1),
- 'thickness': max(int(3 * box_overlay_ratio), 1),
- }
- BOX_TRESHOLD = config['BOX_TRESHOLD']
-
- (text, ocr_bbox), _ = check_ocr_box(image, display_img=False, output_bb_format='xyxy', easyocr_args={'text_threshold': 0.8}, use_paddleocr=False)
- dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image, self.som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text,use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128)
-
- return dino_labled_img, parsed_content_list
-
-
-
-
-app = FastAPI()
-
-class Item(BaseModel):
- base64_image: str
- prompt: str
-
-Omniparser = Omniparser(config)
-
-@app.post("/send_text/")
-async def send_text(item: Item):
- print('start parsing...')
-
- start = time.time()
- dino_labled_img, parsed_content_list = Omniparser.parse(item.base64_image)
- latency = time.time() - start
- print('time:', latency)
- return {"som_image_base64": dino_labled_img, "parsed_content_list": parsed_content_list, 'latency': latency}
-
-@app.get("/")
-async def root():
- return {"message": "Omniparser API ready"}
-
-
-if __name__ == "__main__":
- import uvicorn
- uvicorn.run("remote_request:app", host=args.host, port=args.port, reload=True)
\ No newline at end of file
diff --git a/gradio_demo.py b/gradio_demo.py
index b09168b..15d46a7 100644
--- a/gradio_demo.py
+++ b/gradio_demo.py
@@ -8,7 +8,7 @@ import io
import base64, os
-from utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
+from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
import torch
from PIL import Image
@@ -17,8 +17,6 @@ yolo_model = get_yolo_model(model_path='weights/icon_detect_v1_5/best.pt')
caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="weights/icon_caption_florence")
# caption_model_processor = get_caption_model_processor(model_name="blip2", model_name_or_path="weights/icon_caption_blip2")
-
-
MARKDOWN = """
# OmniParser for Pure Vision Based General GUI Agent 🔥
@@ -65,8 +63,6 @@ def process(
# parsed_content_list = str(parsed_content_list)
return image, str(parsed_content_list)
-
-
with gr.Blocks() as demo:
gr.Markdown(MARKDOWN)
with gr.Row():
diff --git a/omniparser.py b/omniparser.py
deleted file mode 100644
index 634ae9f..0000000
--- a/omniparser.py
+++ /dev/null
@@ -1,60 +0,0 @@
-from utils import get_som_labeled_img, check_ocr_box, get_yolo_model
-import torch
-from ultralytics import YOLO
-from PIL import Image
-from typing import Dict, Tuple, List
-import io
-import base64
-
-
-config = {
- 'som_model_path': 'finetuned_icon_detect.pt',
- 'device': 'cpu',
- 'caption_model_path': 'Salesforce/blip2-opt-2.7b',
- 'draw_bbox_config': {
- 'text_scale': 0.8,
- 'text_thickness': 2,
- 'text_padding': 3,
- 'thickness': 3,
- },
- 'BOX_TRESHOLD': 0.05
-}
-
-
-class Omniparser(object):
- def __init__(self, config: Dict):
- self.config = config
-
- self.som_model = get_yolo_model(model_path=config['som_model_path'])
- # self.caption_model_processor = get_caption_model_processor(config['caption_model_path'], device=cofig['device'])
- # self.caption_model_processor['model'].to(torch.float32)
-
- def parse(self, image_path: str):
- print('Parsing image:', image_path)
- ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9})
- text, ocr_bbox = ocr_bbox_rslt
-
- draw_bbox_config = self.config['draw_bbox_config']
- BOX_TRESHOLD = self.config['BOX_TRESHOLD']
- dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, self.som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=False, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=None, ocr_text=text,use_local_semantics=False)
-
- image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
- # formating output
- return_list = [{'from': 'omniparser', 'shape': {'x':coord[0], 'y':coord[1], 'width':coord[2], 'height':coord[3]},
- 'text': parsed_content_list[i].split(': ')[1], 'type':'text'} for i, (k, coord) in enumerate(label_coordinates.items()) if i < len(parsed_content_list)]
- return_list.extend(
- [{'from': 'omniparser', 'shape': {'x':coord[0], 'y':coord[1], 'width':coord[2], 'height':coord[3]},
- 'text': 'None', 'type':'icon'} for i, (k, coord) in enumerate(label_coordinates.items()) if i >= len(parsed_content_list)]
- )
-
- return [image, return_list]
-
-parser = Omniparser(config)
-image_path = 'examples/pc_1.png'
-
-# time the parser
-import time
-s = time.time()
-image, parsed_content_list = parser.parse(image_path)
-device = config['device']
-print(f'Time taken for Omniparser on {device}:', time.time() - s)
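
The deleted script's main job beyond detection was output formatting: boxes that received a caption become 'text' records and the remainder become 'icon' records. A self-contained sketch of that loop, with hypothetical stand-ins for the get_som_labeled_img outputs:

label_coordinates = {'0': (10, 10, 80, 20), '1': (50, 120, 32, 32)}  # id -> (x, y, w, h)
parsed_content_list = ['Text Box ID 0: Submit']  # a caption exists only for the first box

records = []
for i, (k, coord) in enumerate(label_coordinates.items()):
    has_caption = i < len(parsed_content_list)
    records.append({
        'from': 'omniparser',
        'shape': {'x': coord[0], 'y': coord[1], 'width': coord[2], 'height': coord[3]},
        'text': parsed_content_list[i].split(': ')[1] if has_caption else 'None',
        'type': 'text' if has_caption else 'icon',
    })
print(records)
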
diff --git a/util/action_matching.py b/util/action_matching.py
deleted file mode 100644
index 51c5359..0000000
--- a/util/action_matching.py
+++ /dev/null
@@ -1,425 +0,0 @@
-'''
-Adapted from https://github.com/google-research/google-research/tree/master/android_in_the_wild
-'''
-
-import jax
-import jax.numpy as jnp
-import numpy as np
-
-# import action_type as action_type_lib
-import enum
-
-class ActionType(enum.IntEnum):
- # Placeholders for unused enum values
- UNUSED_0 = 0
- UNUSED_1 = 1
- UNUSED_2 = 2
- UNUSED_8 = 8
- UNUSED_9 = 9
-
- ########### Agent actions ###########
-
- # A type action that sends text to the emulator. Note that this simply sends
- # text and does not perform any clicks for element focus or enter presses for
- # submitting text.
- TYPE = 3
-
- # The dual point action used to represent all gestures.
- DUAL_POINT = 4
-
- # These actions differentiate pressing the home and back button from touches.
- # They represent explicit presses of back and home performed using ADB.
- PRESS_BACK = 5
- PRESS_HOME = 6
-
- # An action representing that ADB command for hitting enter was performed.
- PRESS_ENTER = 7
-
- ########### Episode status actions ###########
-
- # An action used to indicate the desired task has been completed and resets
- # the environment. This action should also be used in the case that the task
- # has already been completed and there is nothing to do.
- # e.g. The task is to turn on the Wi-Fi when it is already on
- STATUS_TASK_COMPLETE = 10
-
- # An action used to indicate that desired task is impossible to complete and
- # resets the environment. This can be a result of many different things
- # including UI changes, Android version differences, etc.
- STATUS_TASK_IMPOSSIBLE = 11
-
-
-_TAP_DISTANCE_THRESHOLD = 0.14 # Fraction of the screen
-ANNOTATION_WIDTH_AUGMENT_FRACTION = 1.4
-ANNOTATION_HEIGHT_AUGMENT_FRACTION = 1.4
-
-# Interval determining if an action is a tap or a swipe.
-_SWIPE_DISTANCE_THRESHOLD = 0.04
-
-
-def _yx_in_bounding_boxes(
- yx, bounding_boxes
-):
- """Check if the (y,x) point is contained in each bounding box.
-
- Args:
- yx: The (y, x) coordinate in pixels of the point.
- bounding_boxes: A 2D int array of shape (num_bboxes, 4), where each row
- represents a bounding box: (y_top_left, x_top_left, box_height,
- box_width). Note: containment is inclusive of the bounding box edges.
-
- Returns:
- is_inside: A 1D bool array where each element specifies if the point is
- contained within the respective box.
- """
- y, x = yx
-
- # `bounding_boxes` has shape (n_elements, 4); we extract each array along the
- # last axis into shape (n_elements, 1), then squeeze unneeded dimension.
- top, left, height, width = [
- jnp.squeeze(v, axis=-1) for v in jnp.split(bounding_boxes, 4, axis=-1)
- ]
-
- # The y-axis is inverted for AndroidEnv, so bottom = top + height.
- bottom, right = top + height, left + width
-
- return jnp.logical_and(y >= top, y <= bottom) & jnp.logical_and(
- x >= left, x <= right)
-
-
-def _resize_annotation_bounding_boxes(
- annotation_positions, annotation_width_augment_fraction,
- annotation_height_augment_fraction):
- """Resize the bounding boxes by the given fractions.
-
- Args:
- annotation_positions: Array of shape (N, 4), where each row represents the
- (y, x, height, width) of the bounding boxes.
- annotation_width_augment_fraction: The fraction to augment the box widths,
- E.g., 1.4 == 240% total increase.
- annotation_height_augment_fraction: Same as described for width, but for box
- height.
-
- Returns:
- Resized bounding box.
-
- """
- height_change = (
- annotation_height_augment_fraction * annotation_positions[:, 2])
- width_change = (
- annotation_width_augment_fraction * annotation_positions[:, 3])
-
- # Limit bounding box positions to the screen.
- resized_annotations = jnp.stack([
- jnp.maximum(0, annotation_positions[:, 0] - (height_change / 2)),
- jnp.maximum(0, annotation_positions[:, 1] - (width_change / 2)),
- jnp.minimum(1, annotation_positions[:, 2] + height_change),
- jnp.minimum(1, annotation_positions[:, 3] + width_change),
- ],
- axis=1)
- return resized_annotations
-
-
-def is_tap_action(normalized_start_yx,
- normalized_end_yx):
- distance = jnp.linalg.norm(
- jnp.array(normalized_start_yx) - jnp.array(normalized_end_yx))
- return distance <= _SWIPE_DISTANCE_THRESHOLD
-
-
-def _is_non_dual_point_action(action_type):
- return jnp.not_equal(action_type, ActionType.DUAL_POINT)
-
-
-def _check_tap_actions_match(
- tap_1_yx,
- tap_2_yx,
- annotation_positions,
- matching_tap_distance_threshold_screen_percentage,
- annotation_width_augment_fraction,
- annotation_height_augment_fraction,
-):
- """Determines if two tap actions are the same."""
- resized_annotation_positions = _resize_annotation_bounding_boxes(
- annotation_positions,
- annotation_width_augment_fraction,
- annotation_height_augment_fraction,
- )
-
- # Check if the ground truth tap action falls in an annotation's bounding box.
- tap1_in_box = _yx_in_bounding_boxes(tap_1_yx, resized_annotation_positions)
- tap2_in_box = _yx_in_bounding_boxes(tap_2_yx, resized_annotation_positions)
- both_in_box = jnp.max(tap1_in_box & tap2_in_box)
-
- # If the ground-truth tap action falls outside any of the annotation
- # bounding boxes or one of the actions is inside a bounding box and the other
- # is outside bounding box or vice versa, compare the points using Euclidean
- # distance.
- within_threshold = (
- jnp.linalg.norm(jnp.array(tap_1_yx) - jnp.array(tap_2_yx))
- <= matching_tap_distance_threshold_screen_percentage
- )
- return jnp.logical_or(both_in_box, within_threshold)
-
-
-def _check_drag_actions_match(
- drag_1_touch_yx,
- drag_1_lift_yx,
- drag_2_touch_yx,
- drag_2_lift_yx,
-):
- """Determines if two drag actions are the same."""
- # Store drag deltas (the change in the y and x coordinates from touch to
- # lift), magnitudes, and the index of the main axis, which is the axis with
- # the greatest change in coordinate value (e.g. a drag starting at (0, 0) and
- # ending at (0.3, 0.5) has a main axis index of 1).
- drag_1_deltas = drag_1_lift_yx - drag_1_touch_yx
- drag_1_magnitudes = jnp.abs(drag_1_deltas)
- drag_1_main_axis = np.argmax(drag_1_magnitudes)
- drag_2_deltas = drag_2_lift_yx - drag_2_touch_yx
- drag_2_magnitudes = jnp.abs(drag_2_deltas)
- drag_2_main_axis = np.argmax(drag_2_magnitudes)
-
- return jnp.equal(drag_1_main_axis, drag_2_main_axis)
-
-
-def check_actions_match(
- action_1_touch_yx,
- action_1_lift_yx,
- action_1_action_type,
- action_2_touch_yx,
- action_2_lift_yx,
- action_2_action_type,
- annotation_positions,
- tap_distance_threshold = _TAP_DISTANCE_THRESHOLD,
- annotation_width_augment_fraction = ANNOTATION_WIDTH_AUGMENT_FRACTION,
- annotation_height_augment_fraction = ANNOTATION_HEIGHT_AUGMENT_FRACTION,
-):
- """Determines if two actions are considered to be the same.
-
- Two actions being "the same" is defined here as two actions that would result
- in a similar screen state.
-
- Args:
- action_1_touch_yx: The (y, x) coordinates of the first action's touch.
- action_1_lift_yx: The (y, x) coordinates of the first action's lift.
- action_1_action_type: The action type of the first action.
- action_2_touch_yx: The (y, x) coordinates of the second action's touch.
- action_2_lift_yx: The (y, x) coordinates of the second action's lift.
- action_2_action_type: The action type of the second action.
- annotation_positions: The positions of the UI annotations for the screen. It
- is A 2D int array of shape (num_bboxes, 4), where each row represents a
- bounding box: (y_top_left, x_top_left, box_height, box_width). Note that
- containment is inclusive of the bounding box edges.
- tap_distance_threshold: The threshold that determines if two taps result in
- a matching screen state if they don't fall the same bounding boxes.
- annotation_width_augment_fraction: The fraction to increase the width of the
- bounding box by.
- annotation_height_augment_fraction: The fraction to increase the height of
- of the bounding box by.
-
- Returns:
- A boolean representing whether the two given actions are the same or not.
- """
- action_1_touch_yx = jnp.asarray(action_1_touch_yx)
- action_1_lift_yx = jnp.asarray(action_1_lift_yx)
- action_2_touch_yx = jnp.asarray(action_2_touch_yx)
- action_2_lift_yx = jnp.asarray(action_2_lift_yx)
-
- # Checks if at least one of the actions is global (i.e. not DUAL_POINT),
- # because if that is the case, only the actions' types need to be compared.
- has_non_dual_point_action = jnp.logical_or(
- _is_non_dual_point_action(action_1_action_type),
- _is_non_dual_point_action(action_2_action_type),
- )
- #print("non dual point: "+str(has_non_dual_point_action))
-
- different_dual_point_types = jnp.logical_xor(
- is_tap_action(action_1_touch_yx, action_1_lift_yx),
- is_tap_action(action_2_touch_yx, action_2_lift_yx),
- )
- #print("different dual type: "+str(different_dual_point_types))
-
- is_tap = jnp.logical_and(
- is_tap_action(action_1_touch_yx, action_1_lift_yx),
- is_tap_action(action_2_touch_yx, action_2_lift_yx),
- )
- #print("is tap: "+str(is_tap))
-
- taps_match = _check_tap_actions_match(
- action_1_touch_yx,
- action_2_touch_yx,
- annotation_positions,
- tap_distance_threshold,
- annotation_width_augment_fraction,
- annotation_height_augment_fraction,
- )
- #print("tap match: "+str(taps_match))
-
- taps_match = jnp.logical_and(is_tap, taps_match)
- #print("tap match: "+str(taps_match))
-
- drags_match = _check_drag_actions_match(
- action_1_touch_yx, action_1_lift_yx, action_2_touch_yx, action_2_lift_yx
- )
- drags_match = jnp.where(is_tap, False, drags_match)
- #print("drag match: "+str(drags_match))
-
- return jnp.where(
- has_non_dual_point_action,
- jnp.equal(action_1_action_type, action_2_action_type),
- jnp.where(
- different_dual_point_types,
- False,
- jnp.logical_or(taps_match, drags_match),
- ),
- )
-
-
-def action_2_format(step_data):
-    # Convert actions from the test set into the format used to compute the matching score
- action_type = step_data["action_type_id"]
-
- if action_type == 4:
-        if step_data["action_type_text"] == 'click':  # click
- touch_point = step_data["touch"]
- lift_point = step_data["lift"]
-        else:  # scroll up/down/left/right
- if step_data["action_type_text"] == 'scroll down':
- touch_point = [0.5, 0.8]
- lift_point = [0.5, 0.2]
- elif step_data["action_type_text"] == 'scroll up':
- touch_point = [0.5, 0.2]
- lift_point = [0.5, 0.8]
- elif step_data["action_type_text"] == 'scroll left':
- touch_point = [0.2, 0.5]
- lift_point = [0.8, 0.5]
- elif step_data["action_type_text"] == 'scroll right':
- touch_point = [0.8, 0.5]
- lift_point = [0.2, 0.5]
- else:
- touch_point = [-1.0, -1.0]
- lift_point = [-1.0, -1.0]
-
- if action_type == 3:
- typed_text = step_data["type_text"]
- else:
- typed_text = ""
-
- action = {"action_type": action_type, "touch_point": touch_point, "lift_point": lift_point,
- "typed_text": typed_text}
-
- action["touch_point"] = [action["touch_point"][1], action["touch_point"][0]]
- action["lift_point"] = [action["lift_point"][1], action["lift_point"][0]]
- action["typed_text"] = action["typed_text"].lower()
-
- return action
-
-
-def pred_2_format(step_data):
-    # Convert the model output into the format used to compute action matching
- action_type = step_data["action_type"]
-
-    if action_type == 4:  # click
- action_type_new = 4
- touch_point = step_data["click_point"]
- lift_point = step_data["click_point"]
- typed_text = ""
- elif action_type == 0:
- action_type_new = 4
- touch_point = [0.5, 0.8]
- lift_point = [0.5, 0.2]
- typed_text = ""
- elif action_type == 1:
- action_type_new = 4
- touch_point = [0.5, 0.2]
- lift_point = [0.5, 0.8]
- typed_text = ""
- elif action_type == 8:
- action_type_new = 4
- touch_point = [0.2, 0.5]
- lift_point = [0.8, 0.5]
- typed_text = ""
- elif action_type == 9:
- action_type_new = 4
- touch_point = [0.8, 0.5]
- lift_point = [0.2, 0.5]
- typed_text = ""
- else:
- action_type_new = action_type
- touch_point = [-1.0, -1.0]
- lift_point = [-1.0, -1.0]
- typed_text = ""
- if action_type_new == 3:
- typed_text = step_data["typed_text"]
-
- action = {"action_type": action_type_new, "touch_point": touch_point, "lift_point": lift_point,
- "typed_text": typed_text}
-
- action["touch_point"] = [action["touch_point"][1], action["touch_point"][0]]
- action["lift_point"] = [action["lift_point"][1], action["lift_point"][0]]
- action["typed_text"] = action["typed_text"].lower()
-
- return action
-
-
-def pred_2_format_simplified(step_data):
-    # Convert the model output into the format used to compute action matching
- action_type = step_data["action_type"]
-
-    if action_type == 'click':  # click
- action_type_new = 4
- touch_point = step_data["click_point"]
- lift_point = step_data["click_point"]
- typed_text = ""
- elif action_type == 'scroll' and step_data["direction"] == 'down':
- action_type_new = 4
- touch_point = [0.5, 0.8]
- lift_point = [0.5, 0.2]
- typed_text = ""
- elif action_type == 'scroll' and step_data["direction"] == 'up':
- action_type_new = 4
- touch_point = [0.5, 0.2]
- lift_point = [0.5, 0.8]
- typed_text = ""
- elif action_type == 'scroll' and step_data["direction"] == 'left':
- action_type_new = 4
- touch_point = [0.2, 0.5]
- lift_point = [0.8, 0.5]
- typed_text = ""
- elif action_type == 'scroll' and step_data["direction"] == 'right':
- action_type_new = 4
- touch_point = [0.8, 0.5]
- lift_point = [0.2, 0.5]
- typed_text = ""
- elif action_type == 'type':
- action_type_new = 3
- touch_point = [-1.0, -1.0]
- lift_point = [-1.0, -1.0]
- typed_text = step_data["text"]
- elif action_type == 'navigate_back':
- action_type_new = 5
- touch_point = [-1.0, -1.0]
- lift_point = [-1.0, -1.0]
- typed_text = ""
- elif action_type == 'navigate_home':
- action_type_new = 6
- touch_point = [-1.0, -1.0]
- lift_point = [-1.0, -1.0]
- typed_text = ""
- else:
- action_type_new = action_type
- touch_point = [-1.0, -1.0]
- lift_point = [-1.0, -1.0]
- typed_text = ""
- # if action_type_new == 'type':
- # typed_text = step_data["text"]
-
- action = {"action_type": action_type_new, "touch_point": touch_point, "lift_point": lift_point,
- "typed_text": typed_text}
-
- action["touch_point"] = [action["touch_point"][1], action["touch_point"][0]]
- action["lift_point"] = [action["lift_point"][1], action["lift_point"][0]]
- action["typed_text"] = action["typed_text"].lower()
-
- return action
\ No newline at end of file
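
Two thresholds drive the matching logic deleted above: a dual-point gesture counts as a tap when touch and lift lie within 4% of the screen of each other, and two taps match when they share an (augmented) bounding box or lie within 14% of each other. A plain-numpy sketch of those two distance checks (the original uses jax.numpy; the box-containment branch is omitted):

import numpy as np

_SWIPE_DISTANCE_THRESHOLD = 0.04  # tap vs. swipe cutoff (fraction of screen)
_TAP_DISTANCE_THRESHOLD = 0.14    # max distance for two taps to match

def is_tap_action(start_yx, end_yx):
    return np.linalg.norm(np.asarray(start_yx) - np.asarray(end_yx)) <= _SWIPE_DISTANCE_THRESHOLD

def taps_match(tap_1_yx, tap_2_yx):
    return np.linalg.norm(np.asarray(tap_1_yx) - np.asarray(tap_2_yx)) <= _TAP_DISTANCE_THRESHOLD

print(is_tap_action([0.5, 0.5], [0.51, 0.5]))  # True: too short to be a swipe
print(taps_match([0.5, 0.5], [0.6, 0.5]))      # True: within 14% of the screen
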
diff --git a/util/action_type.py b/util/action_type.py
deleted file mode 100644
index 2e2b06d..0000000
--- a/util/action_type.py
+++ /dev/null
@@ -1,45 +0,0 @@
-'''
-Adapted from https://github.com/google-research/google-research/tree/master/android_in_the_wild
-'''
-
-import enum
-
-class ActionType(enum.IntEnum):
-
- # Placeholders for unused enum values
- UNUSED_0 = 0
- UNUSED_1 = 1
- UNUSED_2 = 2
- UNUSED_8 = 8
- UNUSED_9 = 9
-
- ########### Agent actions ###########
-
- # A type action that sends text to the emulator. Note that this simply sends
- # text and does not perform any clicks for element focus or enter presses for
- # submitting text.
- TYPE = 3
-
- # The dual point action used to represent all gestures.
- DUAL_POINT = 4
-
- # These actions differentiate pressing the home and back button from touches.
- # They represent explicit presses of back and home performed using ADB.
- PRESS_BACK = 5
- PRESS_HOME = 6
-
- # An action representing that ADB command for hitting enter was performed.
- PRESS_ENTER = 7
-
- ########### Episode status actions ###########
-
- # An action used to indicate the desired task has been completed and resets
- # the environment. This action should also be used in the case that the task
- # has already been completed and there is nothing to do.
- # e.g. The task is to turn on the Wi-Fi when it is already on
- STATUS_TASK_COMPLETE = 10
-
- # An action used to indicate that desired task is impossible to complete and
- # resets the environment. This can be a result of many different things
- # including UI changes, Android version differences, etc.
- STATUS_TASK_IMPOSSIBLE = 11
\ No newline at end of file
diff --git a/utils.py b/util/utils.py
old mode 100755
new mode 100644
similarity index 100%
rename from utils.py
rename to util/utils.py
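
With utils.py now living at util/utils.py, repo-root scripts import the helpers from the new package path, as the gradio_demo.py hunk shows. A sketch of the bare OCR plus set-of-mark labeling flow under the new layout, assuming it runs from the repo root with the v1.5 weights present and a hypothetical screenshot.png; it mirrors the deleted root omniparser.py by skipping captioning (use_local_semantics=False):

from PIL import Image

from util.utils import check_ocr_box, get_som_labeled_img, get_yolo_model

som_model = get_yolo_model(model_path='weights/icon_detect_v1_5/model_v1_5.pt')
image = Image.open('screenshot.png')
draw_bbox_config = {'text_scale': 0.8, 'text_thickness': 2, 'text_padding': 3, 'thickness': 3}

(text, ocr_bbox), _ = check_ocr_box(image, display_img=False, output_bb_format='xyxy',
                                    easyocr_args={'text_threshold': 0.8}, use_paddleocr=False)
dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
    image, som_model, BOX_TRESHOLD=0.05, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,
    draw_bbox_config=draw_bbox_config, caption_model_processor=None, ocr_text=text,
    use_local_semantics=False)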