Clean up folder structure
This commit is contained in:
201
Dockerfile
201
Dockerfile
@@ -1,201 +0,0 @@
|
||||
# Dockerfile for OmniParser with GPU support and OpenGL libraries
|
||||
#
|
||||
# This Dockerfile is intended to create an environment with NVIDIA CUDA
|
||||
# support and the necessary dependencies to run the OmniParser project.
|
||||
# The configuration is designed to support applications that rely on
|
||||
# Python 3.12, OpenCV, Hugging Face transformers, and Gradio. Additionally,
|
||||
# it includes steps to pull large files from Git LFS and a script to
|
||||
# convert model weights from .safetensor to .pt format. The container
|
||||
# runs a Gradio server by default, exposed on port 7861.
|
||||
#
|
||||
# Base image: nvidia/cuda:12.3.1-devel-ubuntu22.04
|
||||
#
|
||||
# Key features:
|
||||
# - System dependencies for OpenGL to support graphical libraries.
|
||||
# - Miniconda for Python 3.12, allowing for environment management.
|
||||
# - Git Large File Storage (LFS) setup for handling large model files.
|
||||
# - Requirement file installation, including specific versions of
|
||||
# OpenCV and Hugging Face Hub.
|
||||
# - Entrypoint script execution with Gradio server configuration for
|
||||
# external access.
|
||||
|
||||
FROM nvidia/cuda:12.3.1-devel-ubuntu22.04
|
||||
|
||||
# Install system dependencies with explicit OpenGL libraries
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
|
||||
git \
|
||||
git-lfs \
|
||||
wget \
|
||||
libgl1 \
|
||||
libglib2.0-0 \
|
||||
libsm6 \
|
||||
libxext6 \
|
||||
libxrender1 \
|
||||
libglu1-mesa \
|
||||
libglib2.0-0 \
|
||||
libsm6 \
|
||||
libxrender1 \
|
||||
libxext6 \
|
||||
python3-opencv \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& git lfs install
|
||||
|
||||
# Install Miniconda for Python 3.12
|
||||
RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh && \
|
||||
bash miniconda.sh -b -p /opt/conda && \
|
||||
rm miniconda.sh
|
||||
ENV PATH="/opt/conda/bin:$PATH"
|
||||
|
||||
# Create and activate Conda environment with Python 3.12, and set it as the default
|
||||
RUN conda create -n omni python=3.12 && \
|
||||
echo "source activate omni" > ~/.bashrc
|
||||
ENV CONDA_DEFAULT_ENV=omni
|
||||
ENV PATH="/opt/conda/envs/omni/bin:$PATH"
|
||||
|
||||
# Set the working directory in the container
|
||||
WORKDIR /usr/src/app
|
||||
|
||||
# Copy project files and requirements
|
||||
COPY . .
|
||||
COPY requirements.txt /usr/src/app/requirements.txt
|
||||
|
||||
# Initialize Git LFS and pull LFS files
|
||||
RUN git lfs install && \
|
||||
git lfs pull
|
||||
|
||||
# Install dependencies from requirements.txt with specific opencv-python-headless version
|
||||
RUN . /opt/conda/etc/profile.d/conda.sh && conda activate omni && \
|
||||
pip uninstall -y opencv-python opencv-python-headless && \
|
||||
pip install --no-cache-dir opencv-python-headless==4.8.1.78 && \
|
||||
pip install -r requirements.txt && \
|
||||
pip install huggingface_hub
|
||||
|
||||
# Run download.py to fetch model weights and convert safetensors to .pt format
|
||||
# RUN . /opt/conda/etc/profile.d/conda.sh && conda activate omni && \
|
||||
# python download.py && \
|
||||
# echo "Contents of weights directory:" && \
|
||||
# ls -lR weights && \
|
||||
# python weights/convert_safetensor_to_pt.py
|
||||
|
||||
# Expose the default Gradio port
|
||||
EXPOSE 7861
|
||||
|
||||
# Configure Gradio to be accessible externally
|
||||
ENV GRADIO_SERVER_NAME="0.0.0.0"
|
||||
|
||||
# Copy and set permissions for entrypoint script
|
||||
# COPY entrypoint.sh /usr/src/app/entrypoint.sh
|
||||
# RUN chmod +x /usr/src/app/entrypoint.sh
|
||||
|
||||
# To debug, keep the container running
|
||||
# CMD ["tail", "-f", "/dev/null"]
|
||||
|
||||
################################################################################################
|
||||
# virtual display related setup --> from anthropic-quickstarts/computer-use-demo/Dockerfile
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV DEBIAN_PRIORITY=high
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get -y upgrade && \
|
||||
apt-get -y install \
|
||||
# UI Requirements
|
||||
xvfb \
|
||||
xterm \
|
||||
xdotool \
|
||||
scrot \
|
||||
imagemagick \
|
||||
sudo \
|
||||
mutter \
|
||||
x11vnc \
|
||||
# Python/pyenv reqs
|
||||
build-essential \
|
||||
libssl-dev \
|
||||
zlib1g-dev \
|
||||
libbz2-dev \
|
||||
libreadline-dev \
|
||||
libsqlite3-dev \
|
||||
curl \
|
||||
git \
|
||||
libncursesw5-dev \
|
||||
xz-utils \
|
||||
tk-dev \
|
||||
libxml2-dev \
|
||||
libxmlsec1-dev \
|
||||
libffi-dev \
|
||||
liblzma-dev \
|
||||
# Network tools
|
||||
net-tools \
|
||||
netcat \
|
||||
# PPA req
|
||||
software-properties-common && \
|
||||
# Userland apps
|
||||
sudo add-apt-repository ppa:mozillateam/ppa && \
|
||||
sudo apt-get install -y --no-install-recommends \
|
||||
libreoffice \
|
||||
firefox-esr \
|
||||
x11-apps \
|
||||
xpdf \
|
||||
gedit \
|
||||
xpaint \
|
||||
tint2 \
|
||||
galculator \
|
||||
pcmanfm \
|
||||
unzip && \
|
||||
apt-get clean
|
||||
|
||||
# Install noVNC
|
||||
RUN git clone --branch v1.5.0 https://github.com/novnc/noVNC.git /opt/noVNC && \
|
||||
git clone --branch v0.12.0 https://github.com/novnc/websockify /opt/noVNC/utils/websockify && \
|
||||
ln -s /opt/noVNC/vnc.html /opt/noVNC/index.html
|
||||
|
||||
# setup user
|
||||
ENV USERNAME=computeruse
|
||||
ENV HOME=/home/$USERNAME
|
||||
RUN useradd -m -s /bin/bash -d $HOME $USERNAME
|
||||
RUN echo "${USERNAME} ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers
|
||||
USER computeruse
|
||||
WORKDIR $HOME
|
||||
|
||||
# setup python
|
||||
RUN git clone https://github.com/pyenv/pyenv.git ~/.pyenv && \
|
||||
cd ~/.pyenv && src/configure && make -C src && cd .. && \
|
||||
echo 'export PYENV_ROOT="$HOME/.pyenv"' >> ~/.bashrc && \
|
||||
echo 'command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH"' >> ~/.bashrc && \
|
||||
echo 'eval "$(pyenv init -)"' >> ~/.bashrc
|
||||
ENV PYENV_ROOT="$HOME/.pyenv"
|
||||
ENV PATH="$PYENV_ROOT/bin:$PATH"
|
||||
ENV PYENV_VERSION_MAJOR=3
|
||||
ENV PYENV_VERSION_MINOR=11
|
||||
ENV PYENV_VERSION_PATCH=6
|
||||
ENV PYENV_VERSION=$PYENV_VERSION_MAJOR.$PYENV_VERSION_MINOR.$PYENV_VERSION_PATCH
|
||||
RUN eval "$(pyenv init -)" && \
|
||||
pyenv install $PYENV_VERSION && \
|
||||
pyenv global $PYENV_VERSION && \
|
||||
pyenv rehash
|
||||
|
||||
ENV PATH="$HOME/.pyenv/shims:$HOME/.pyenv/bin:$PATH"
|
||||
|
||||
RUN python -m pip install --upgrade pip==23.1.2 setuptools==58.0.4 wheel==0.40.0 && \
|
||||
python -m pip config set global.disable-pip-version-check true
|
||||
|
||||
# only reinstall if requirements.txt changes
|
||||
# COPY --chown=$USERNAME:$USERNAME computer_use_demo/requirements.txt $HOME/computer_use_demo/requirements.txt
|
||||
# RUN python -m pip install -r $HOME/computer_use_demo/requirements.txt
|
||||
|
||||
# setup desktop env & app
|
||||
# COPY --chown=$USERNAME:$USERNAME image/ $HOME
|
||||
# COPY --chown=$USERNAME:$USERNAME computer_use_demo/ $HOME/computer_use_demo/
|
||||
|
||||
ARG DISPLAY_NUM=1
|
||||
ARG HEIGHT=768
|
||||
ARG WIDTH=1024
|
||||
ENV DISPLAY_NUM=$DISPLAY_NUM
|
||||
ENV HEIGHT=$HEIGHT
|
||||
ENV WIDTH=$WIDTH
|
||||
|
||||
# Set the entrypoint
|
||||
# ENTRYPOINT ["/usr/src/app/entrypoint.sh"]
|
||||
|
||||
# docker build . -t omniparser-x-demo:local # manually build the docker image (optional)
|
||||
@@ -3,21 +3,13 @@ Entrypoint for Gradio, see https://gradio.app/
|
||||
python app.py --windows_host_url xxxx:8006/ --omniparser_host_url localhost:8000
|
||||
"""
|
||||
|
||||
import platform
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
import io
|
||||
import json
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import cast, Dict
|
||||
from PIL import Image
|
||||
import socket
|
||||
from typing import cast
|
||||
import argparse
|
||||
|
||||
import gradio as gr
|
||||
from anthropic import APIResponse
|
||||
from anthropic.types import TextBlock
|
||||
@@ -27,7 +19,6 @@ from computer_use_demo.loop import (
|
||||
APIProvider,
|
||||
sampling_loop_sync,
|
||||
)
|
||||
|
||||
from computer_use_demo.tools import ToolResult
|
||||
|
||||
CONFIG_DIR = Path("~/.anthropic").expanduser()
|
||||
@@ -43,7 +34,7 @@ Type a message and press submit to start OmniParser+X. Press the trash icon in t
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description="Gradio App")
|
||||
parser.add_argument("--windows_host_url", type=str, default='GCRSANDBOX336.redmond.corp.microsoft.com:8006/') # http://gcrsandbox336.redmond.corp.microsoft.com/
|
||||
parser.add_argument("--windows_host_url", type=str, default='localhost:8006')
|
||||
parser.add_argument("--omniparser_host_url", type=str, default="localhost:8000")
|
||||
return parser.parse_args()
|
||||
args = parse_arguments()
|
||||
@@ -82,8 +73,6 @@ def setup_state(state):
|
||||
state["only_n_most_recent_images"] = 2
|
||||
if 'chatbot_messages' not in state:
|
||||
state['chatbot_messages'] = []
|
||||
# if "omniparser_url" not in state:
|
||||
# state["omniparser_url"] = "localhost:8000"
|
||||
|
||||
async def main(state):
|
||||
"""Render loop for Gradio"""
|
||||
@@ -225,7 +214,7 @@ def process_input(user_input, state):
|
||||
api_response_callback=partial(_api_response_callback, response_state=state["responses"]),
|
||||
api_key=state["api_key"],
|
||||
only_n_most_recent_images=state["only_n_most_recent_images"],
|
||||
omniparser_url=omniparser_host_url #state["omniparser_url"]
|
||||
omniparser_url=omniparser_host_url
|
||||
):
|
||||
if loop_msg is None:
|
||||
yield state['chatbot_messages']
|
||||
@@ -289,14 +278,6 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
|
||||
placeholder="Paste your API key here",
|
||||
interactive=True,
|
||||
)
|
||||
# with gr.Row():
|
||||
# omniparser_url = gr.Textbox(
|
||||
# label="OmniParser Base URL",
|
||||
# value="localhost:8000",
|
||||
# placeholder="Enter OmniParser base URL (e.g. localhost:8000)",
|
||||
# interactive=True
|
||||
# )
|
||||
# hide_images = gr.Checkbox(label="Hide screenshots", value=False)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=8):
|
||||
@@ -373,9 +354,6 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
|
||||
state["api_key"] = api_key_value
|
||||
state[f'{state["provider"]}_api_key'] = api_key_value
|
||||
|
||||
# def update_omniparser_url(url_value, state):
|
||||
# state["omniparser_url"] = url_value
|
||||
|
||||
def clear_chat(state):
|
||||
# Reset message-related state
|
||||
state["messages"] = []
|
||||
@@ -388,7 +366,6 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
|
||||
only_n_images.change(fn=update_only_n_images, inputs=[only_n_images, state], outputs=None)
|
||||
provider.change(fn=update_provider, inputs=[provider, state], outputs=api_key)
|
||||
api_key.change(fn=update_api_key, inputs=[api_key, state], outputs=None)
|
||||
# omniparser_url.change(fn=update_omniparser_url, inputs=[omniparser_url, state], outputs=None)
|
||||
chatbot.clear(fn=clear_chat, inputs=[state], outputs=[chatbot])
|
||||
|
||||
submit_button.click(process_input, [chat_input, state], chatbot)
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
import requests
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from computer_use_demo.tools.screen_capture import get_screenshot
|
||||
from computer_use_demo.gui_agent.llm_utils.oai import encode_image
|
||||
|
||||
OUTPUT_DIR = "./tmp/outputs"
|
||||
|
||||
class OmniParserClient:
|
||||
def __init__(self,
|
||||
url: str) -> None:
|
||||
self.url = url
|
||||
|
||||
def __call__(self,):
|
||||
screenshot, screenshot_path = get_screenshot()
|
||||
screenshot_path = str(screenshot_path)
|
||||
image_base64 = encode_image(screenshot_path)
|
||||
response = requests.post(self.url, json={"base64_image": image_base64})
|
||||
response_json = response.json()
|
||||
print('omniparser latency:', response_json['latency'])
|
||||
|
||||
som_image_data = base64.b64decode(response_json['som_image_base64'])
|
||||
screenshot_path_uuid = Path(screenshot_path).stem.replace("screenshot_", "")
|
||||
som_screenshot_path = f"{OUTPUT_DIR}/screenshot_som_{screenshot_path_uuid}.png"
|
||||
with open(som_screenshot_path, "wb") as f:
|
||||
f.write(som_image_data)
|
||||
|
||||
response_json['width'] = screenshot.size[0]
|
||||
response_json['height'] = screenshot.size[1]
|
||||
response_json['original_screenshot_base64'] = image_base64
|
||||
response_json['screenshot_uuid'] = screenshot_path_uuid
|
||||
response_json = self.reformat_messages(response_json)
|
||||
return response_json
|
||||
|
||||
def reformat_messages(self, response_json: dict):
|
||||
screen_info = ""
|
||||
for idx, element in enumerate(response_json["parsed_content_list"]):
|
||||
element['idx'] = idx
|
||||
if element['type'] == 'text':
|
||||
screen_info += f'ID: {idx}, Text: {element["content"]}\n'
|
||||
elif element['type'] == 'icon':
|
||||
screen_info += f'ID: {idx}, Icon: {element["content"]}\n'
|
||||
response_json['screen_info'] = screen_info
|
||||
return response_json
|
||||
@@ -1,23 +1,16 @@
|
||||
import json
|
||||
import asyncio
|
||||
import platform
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, cast, Dict, Callable
|
||||
from typing import cast, Callable
|
||||
import uuid
|
||||
import requests
|
||||
from PIL import Image, ImageDraw
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, APIResponse
|
||||
from anthropic.types import TextBlock, ToolResultBlockParam
|
||||
from anthropic import APIResponse
|
||||
from anthropic.types import ToolResultBlockParam
|
||||
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam, BetaUsage
|
||||
|
||||
from computer_use_demo.tools.screen_capture import get_screenshot
|
||||
from computer_use_demo.gui_agent.llm_utils.oai import run_oai_interleaved, encode_image
|
||||
from computer_use_demo.gui_agent.llm_utils.oai import run_oai_interleaved
|
||||
from computer_use_demo.colorful_text import colorful_text_vlm
|
||||
import time
|
||||
import re
|
||||
@@ -33,67 +26,6 @@ def extract_data(input_string, data_type):
|
||||
# Return the first match if exists, trimming whitespace and ignoring potential closing backticks
|
||||
return matches[0][0].strip() if matches else input_string
|
||||
|
||||
class OmniParser:
|
||||
def __init__(self,
|
||||
url: str) -> None:
|
||||
self.url = url
|
||||
if not self.url:
|
||||
config = {
|
||||
'som_model_path': '../weights/icon_detect_v1_5/model_v1_5.pt',
|
||||
'device': 'cpu',
|
||||
'caption_model_name': 'florence2',
|
||||
'caption_model_path': '../weights/icon_caption_florence',
|
||||
'BOX_TRESHOLD': 0.05
|
||||
}
|
||||
from computer_use_demo.omniparser_agent.omniparser import Omniparser as Omniparser_class
|
||||
self.omniparser = Omniparser_class(config=config)
|
||||
|
||||
def __call__(self,):
|
||||
screenshot, screenshot_path = get_screenshot()
|
||||
screenshot_path = str(screenshot_path)
|
||||
image_base64 = encode_image(screenshot_path)
|
||||
if self.url:
|
||||
response = requests.post(self.url, json={"base64_image": image_base64, 'prompt': 'omniparser process'})
|
||||
response_json = response.json()
|
||||
else:
|
||||
start = time.time()
|
||||
dino_labled_img, parsed_content_list = self.omniparser.parse(image_base64)
|
||||
latency = time.time() - start
|
||||
response_json = {
|
||||
'som_image_base64': dino_labled_img,
|
||||
'parsed_content_list': parsed_content_list,
|
||||
'latency': latency
|
||||
}
|
||||
|
||||
som_image_data = base64.b64decode(response_json['som_image_base64'])
|
||||
screenshot_path_uuid = Path(screenshot_path).stem.replace("screenshot_", "")
|
||||
som_screenshot_path = f"{OUTPUT_DIR}/screenshot_som_{screenshot_path_uuid}.png"
|
||||
with open(som_screenshot_path, "wb") as f:
|
||||
f.write(som_image_data)
|
||||
|
||||
response_json['width'] = screenshot.size[0]
|
||||
response_json['height'] = screenshot.size[1]
|
||||
response_json['original_screenshot_base64'] = image_base64
|
||||
response_json['screenshot_uuid'] = screenshot_path_uuid
|
||||
# example response_json: {"som_image_base64": dino_labled_img, "parsed_content_list": parsed_content_list, "latency": 0.1}
|
||||
print('omniparser latency:', response_json['latency'])
|
||||
response_json = self.reformat_messages(response_json)
|
||||
return response_json
|
||||
|
||||
def reformat_messages(self, response_json: dict):
|
||||
screen_info = ""
|
||||
for idx, element in enumerate(response_json["parsed_content_list"]):
|
||||
element['idx'] = idx
|
||||
if element['type'] == 'text':
|
||||
# screen_info += f'''<p id={idx} class="text" alt="{element['content']}"> </p>\n'''
|
||||
screen_info += f'ID: {idx}, Text: {element["content"]}\n'
|
||||
elif element['type'] == 'icon':
|
||||
# screen_info += f'''<img id={idx} class="icon" alt="{element['content']}"> </img>\n'''
|
||||
screen_info += f'ID: {idx}, Icon: {element["content"]}\n'
|
||||
response_json['screen_info'] = screen_info
|
||||
return response_json
|
||||
|
||||
|
||||
class VLMAgent:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -15,9 +15,10 @@ from anthropic.types.beta import (
|
||||
)
|
||||
from computer_use_demo.tools import ToolResult
|
||||
|
||||
from computer_use_demo.gui_agent.anthropic_agent import AnthropicActor
|
||||
from computer_use_demo.agent.llm_utils.omniparserclient import OmniParserClient
|
||||
from computer_use_demo.agent.anthropic_agent import AnthropicActor
|
||||
from computer_use_demo.agent.vlm_agent import VLMAgent
|
||||
from computer_use_demo.executor.anthropic_executor import AnthropicExecutor
|
||||
from computer_use_demo.omniparser_agent.vlm_agent import OmniParser, VLMAgent
|
||||
|
||||
BETA_FLAG = "computer-use-2024-10-22"
|
||||
|
||||
@@ -52,7 +53,7 @@ def sampling_loop_sync(
|
||||
Synchronous agentic sampling loop for the assistant/tool interaction of computer use.
|
||||
"""
|
||||
print('in sampling_loop_sync, model:', model)
|
||||
omniparser = OmniParser(url=f"http://{omniparser_url}/send_text/" if omniparser_url else None)
|
||||
omniparser_client = OmniParserClient(url=f"http://{omniparser_url}/parse/")
|
||||
if model == "claude-3-5-sonnet-20241022":
|
||||
# Register Actor and Executor
|
||||
actor = AnthropicActor(
|
||||
@@ -94,7 +95,7 @@ def sampling_loop_sync(
|
||||
|
||||
if model == "claude-3-5-sonnet-20241022": # Anthropic loop
|
||||
while True:
|
||||
parsed_screen = omniparser() # parsed_screen: {"som_image_base64": dino_labled_img, "parsed_content_list": parsed_content_list, "screen_info"}
|
||||
parsed_screen = omniparser_client() # parsed_screen: {"som_image_base64": dino_labled_img, "parsed_content_list": parsed_content_list, "screen_info"}
|
||||
import pdb; pdb.set_trace()
|
||||
screen_info_block = TextBlock(text='Below is the structured accessibility information of the current UI screen, which includes text and icons you can operate on, take these information into account when you are making the prediction for the next action. Note you will still need to take screenshot to get the image: \n' + parsed_screen['screen_info'], type='text')
|
||||
# # messages[-1]['content'].append(screen_info_block)
|
||||
@@ -112,7 +113,7 @@ def sampling_loop_sync(
|
||||
|
||||
elif model == "omniparser + gpt-4o" or model == "omniparser + phi35v":
|
||||
while True:
|
||||
parsed_screen = omniparser()
|
||||
parsed_screen = omniparser_client()
|
||||
tools_use_needed, vlm_response_json = actor(messages=messages, parsed_screen=parsed_screen)
|
||||
|
||||
for message, tool_result_content in executor(tools_use_needed, messages):
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
# uvicorn remote_request:app --host 0.0.0.0 --port 8000 --reload
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from computer_use_demo.omniparser_agent.utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box
|
||||
import torch
|
||||
from PIL import Image
|
||||
from typing import Dict, Tuple, List
|
||||
import base64
|
||||
import io
|
||||
|
||||
|
||||
# config = {
|
||||
# 'som_model_path': '../weights/icon_detect_v1_5/model_v1_5.pt',
|
||||
# 'device': 'cpu',
|
||||
# 'caption_model_name': 'florence2',
|
||||
# 'caption_model_path': '../weights/icon_caption_florence',
|
||||
# 'BOX_TRESHOLD': 0.05
|
||||
# }
|
||||
|
||||
|
||||
class Omniparser(object):
|
||||
def __init__(self, config: Dict):
|
||||
self.config = config
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
self.som_model = get_yolo_model(model_path=config['som_model_path'])
|
||||
self.caption_model_processor = get_caption_model_processor(model_name=config['caption_model_name'], model_name_or_path=config['caption_model_path'], device=device)
|
||||
print('Omniparser initialized!!!')
|
||||
|
||||
def parse(self, image_base64: str):
|
||||
image_path = './demo_image.jpg'
|
||||
with open(image_path, "wb") as fh:
|
||||
fh.write(base64.b64decode(image_base64))
|
||||
print('Parsing image:', image_path)
|
||||
|
||||
image = Image.open(image_path)
|
||||
print('image size:', image.size)
|
||||
|
||||
box_overlay_ratio = max(image.size) / 3200
|
||||
draw_bbox_config = {
|
||||
'text_scale': 0.8 * box_overlay_ratio,
|
||||
'text_thickness': max(int(2 * box_overlay_ratio), 1),
|
||||
'text_padding': max(int(3 * box_overlay_ratio), 1),
|
||||
'thickness': max(int(3 * box_overlay_ratio), 1),
|
||||
}
|
||||
BOX_TRESHOLD = self.config['BOX_TRESHOLD']
|
||||
|
||||
ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.8}, use_paddleocr=True)
|
||||
text, ocr_bbox = ocr_bbox_rslt
|
||||
dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, self.som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text,use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128)
|
||||
|
||||
return dino_labled_img, parsed_content_list
|
||||
|
||||
|
||||
# from fastapi import FastAPI
|
||||
# from pydantic import BaseModel
|
||||
|
||||
# app = FastAPI()
|
||||
|
||||
# class Item(BaseModel):
|
||||
# base64_image: str
|
||||
# prompt: str
|
||||
|
||||
# Omniparser = Omniparser(config)
|
||||
|
||||
# @app.post("/send_text/")
|
||||
# async def send_text(item: Item):
|
||||
# print('start parsing...')
|
||||
# import time
|
||||
# start = time.time()
|
||||
# dino_labled_img, parsed_content_list = Omniparser.parse(item.base64_image)
|
||||
# latency = time.time() - start
|
||||
# print('time:', latency)
|
||||
# return {"som_image_base64": dino_labled_img, "parsed_content_list": parsed_content_list, 'latency': latency}
|
||||
@@ -1,935 +0,0 @@
|
||||
# from ultralytics import YOLO
|
||||
import os
|
||||
import io
|
||||
import base64
|
||||
import time
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import json
|
||||
import requests
|
||||
# utility function
|
||||
import os
|
||||
from openai import AzureOpenAI
|
||||
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
# %matplotlib inline
|
||||
from matplotlib import pyplot as plt
|
||||
import easyocr
|
||||
from paddleocr import PaddleOCR
|
||||
reader = easyocr.Reader(['en', 'ch_sim'], gpu=True)
|
||||
paddle_ocr = PaddleOCR(
|
||||
lang='en', # other lang also available
|
||||
use_angle_cls=False,
|
||||
use_gpu=False, # using cuda will conflict with pytorch in the same process
|
||||
show_log=False,
|
||||
max_batch_size=1024,
|
||||
use_dilation=True, # improves accuracy
|
||||
det_db_score_mode='slow', # improves accuracy
|
||||
rec_batch_num=1024)
|
||||
import time
|
||||
import base64
|
||||
|
||||
import os
|
||||
import ast
|
||||
import torch
|
||||
from typing import Tuple, List
|
||||
from torchvision.ops import box_convert
|
||||
import re
|
||||
from torchvision.transforms import ToPILImage
|
||||
import supervision as sv
|
||||
import torchvision.transforms as T
|
||||
|
||||
|
||||
def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
|
||||
if not device:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if model_name == "blip2":
|
||||
from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
||||
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
||||
if device == 'cpu':
|
||||
model = Blip2ForConditionalGeneration.from_pretrained(
|
||||
model_name_or_path, device_map=None, torch_dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
model = Blip2ForConditionalGeneration.from_pretrained(
|
||||
model_name_or_path, device_map=None, torch_dtype=torch.float16
|
||||
).to(device)
|
||||
elif model_name == "florence2":
|
||||
from transformers import AutoProcessor, AutoModelForCausalLM
|
||||
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
|
||||
if device == 'cpu':
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True).to(device)
|
||||
return {'model': model.to(device), 'processor': processor}
|
||||
|
||||
|
||||
def get_yolo_model(model_path):
|
||||
from ultralytics import YOLO
|
||||
# Load the model.
|
||||
model = YOLO(model_path)
|
||||
return model
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=None):
|
||||
# Number of samples per batch, --> 256 roughly takes 23 GB of GPU memory for florence model
|
||||
|
||||
to_pil = ToPILImage()
|
||||
if starting_idx:
|
||||
non_ocr_boxes = filtered_boxes[starting_idx:]
|
||||
else:
|
||||
non_ocr_boxes = filtered_boxes
|
||||
croped_pil_image = []
|
||||
t0 = time.time()
|
||||
for i, coord in enumerate(non_ocr_boxes):
|
||||
try:
|
||||
xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
|
||||
ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
|
||||
cropped_image = image_source[ymin:ymax, xmin:xmax, :]
|
||||
# resize the image to 224x224 to avoid long overhead in clipimageprocessor # TODO
|
||||
cropped_image = cv2.resize(cropped_image, (64, 64))
|
||||
croped_pil_image.append(to_pil(cropped_image))
|
||||
except:
|
||||
continue
|
||||
# print('time to prepare bbox:', time.time()-t0)
|
||||
|
||||
model, processor = caption_model_processor['model'], caption_model_processor['processor']
|
||||
if not prompt:
|
||||
if 'florence' in model.config.name_or_path:
|
||||
prompt = "<CAPTION>"
|
||||
else:
|
||||
prompt = "The image shows"
|
||||
|
||||
generated_texts = []
|
||||
device = model.device
|
||||
# batch_size = 64
|
||||
for i in range(0, len(croped_pil_image), batch_size):
|
||||
start = time.time()
|
||||
batch = croped_pil_image[i:i+batch_size]
|
||||
t1 = time.time()
|
||||
if model.device.type == 'cuda':
|
||||
inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16)
|
||||
else:
|
||||
inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
|
||||
t2 = time.time()
|
||||
# print('time to process image + tokenize text inputs:', t2-t1)
|
||||
if 'florence' in model.config.name_or_path:
|
||||
generated_ids = model.generate(input_ids=inputs["input_ids"],pixel_values=inputs["pixel_values"],max_new_tokens=20,num_beams=1, do_sample=False)
|
||||
else:
|
||||
generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1) # temperature=0.01, do_sample=True,
|
||||
t3 = time.time()
|
||||
# print('time to generate:', t3-t2)
|
||||
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
generated_text = [gen.strip() for gen in generated_text]
|
||||
generated_texts.extend(generated_text)
|
||||
|
||||
return generated_texts
|
||||
|
||||
|
||||
|
||||
def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
|
||||
to_pil = ToPILImage()
|
||||
if ocr_bbox:
|
||||
non_ocr_boxes = filtered_boxes[len(ocr_bbox):]
|
||||
else:
|
||||
non_ocr_boxes = filtered_boxes
|
||||
croped_pil_image = []
|
||||
for i, coord in enumerate(non_ocr_boxes):
|
||||
xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
|
||||
ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
|
||||
cropped_image = image_source[ymin:ymax, xmin:xmax, :]
|
||||
croped_pil_image.append(to_pil(cropped_image))
|
||||
|
||||
model, processor = caption_model_processor['model'], caption_model_processor['processor']
|
||||
device = model.device
|
||||
messages = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}]
|
||||
prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
|
||||
batch_size = 5 # Number of samples per batch
|
||||
generated_texts = []
|
||||
|
||||
for i in range(0, len(croped_pil_image), batch_size):
|
||||
images = croped_pil_image[i:i+batch_size]
|
||||
image_inputs = [processor.image_processor(x, return_tensors="pt") for x in images]
|
||||
inputs ={'input_ids': [], 'attention_mask': [], 'pixel_values': [], 'image_sizes': []}
|
||||
texts = [prompt] * len(images)
|
||||
for i, txt in enumerate(texts):
|
||||
input = processor._convert_images_texts_to_inputs(image_inputs[i], txt, return_tensors="pt")
|
||||
inputs['input_ids'].append(input['input_ids'])
|
||||
inputs['attention_mask'].append(input['attention_mask'])
|
||||
inputs['pixel_values'].append(input['pixel_values'])
|
||||
inputs['image_sizes'].append(input['image_sizes'])
|
||||
max_len = max([x.shape[1] for x in inputs['input_ids']])
|
||||
for i, v in enumerate(inputs['input_ids']):
|
||||
inputs['input_ids'][i] = torch.cat([processor.tokenizer.pad_token_id * torch.ones(1, max_len - v.shape[1], dtype=torch.long), v], dim=1)
|
||||
inputs['attention_mask'][i] = torch.cat([torch.zeros(1, max_len - v.shape[1], dtype=torch.long), inputs['attention_mask'][i]], dim=1)
|
||||
inputs_cat = {k: torch.concatenate(v).to(device) for k, v in inputs.items()}
|
||||
|
||||
generation_args = {
|
||||
"max_new_tokens": 25,
|
||||
"temperature": 0.01,
|
||||
"do_sample": False,
|
||||
}
|
||||
generate_ids = model.generate(**inputs_cat, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
|
||||
# # remove input tokens
|
||||
generate_ids = generate_ids[:, inputs_cat['input_ids'].shape[1]:]
|
||||
response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
response = [res.strip('\n').strip() for res in response]
|
||||
generated_texts.extend(response)
|
||||
|
||||
return generated_texts
|
||||
|
||||
def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
|
||||
assert ocr_bbox is None or isinstance(ocr_bbox, List)
|
||||
|
||||
def box_area(box):
|
||||
return (box[2] - box[0]) * (box[3] - box[1])
|
||||
|
||||
def intersection_area(box1, box2):
|
||||
x1 = max(box1[0], box2[0])
|
||||
y1 = max(box1[1], box2[1])
|
||||
x2 = min(box1[2], box2[2])
|
||||
y2 = min(box1[3], box2[3])
|
||||
return max(0, x2 - x1) * max(0, y2 - y1)
|
||||
|
||||
def IoU(box1, box2):
|
||||
intersection = intersection_area(box1, box2)
|
||||
union = box_area(box1) + box_area(box2) - intersection + 1e-6
|
||||
if box_area(box1) > 0 and box_area(box2) > 0:
|
||||
ratio1 = intersection / box_area(box1)
|
||||
ratio2 = intersection / box_area(box2)
|
||||
else:
|
||||
ratio1, ratio2 = 0, 0
|
||||
return max(intersection / union, ratio1, ratio2)
|
||||
|
||||
def is_inside(box1, box2):
|
||||
# return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
|
||||
intersection = intersection_area(box1, box2)
|
||||
ratio1 = intersection / box_area(box1)
|
||||
return ratio1 > 0.95
|
||||
|
||||
boxes = boxes.tolist()
|
||||
filtered_boxes = []
|
||||
if ocr_bbox:
|
||||
filtered_boxes.extend(ocr_bbox)
|
||||
# print('ocr_bbox!!!', ocr_bbox)
|
||||
for i, box1 in enumerate(boxes):
|
||||
# if not any(IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2) for j, box2 in enumerate(boxes) if i != j):
|
||||
is_valid_box = True
|
||||
for j, box2 in enumerate(boxes):
|
||||
# keep the smaller box
|
||||
if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
|
||||
is_valid_box = False
|
||||
break
|
||||
if is_valid_box:
|
||||
# add the following 2 lines to include ocr bbox
|
||||
if ocr_bbox:
|
||||
# only add the box if it does not overlap with any ocr bbox
|
||||
if not any(IoU(box1, box3) > iou_threshold and not is_inside(box1, box3) for k, box3 in enumerate(ocr_bbox)):
|
||||
filtered_boxes.append(box1)
|
||||
else:
|
||||
filtered_boxes.append(box1)
|
||||
return torch.tensor(filtered_boxes)
|
||||
|
||||
|
||||
def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
|
||||
'''
|
||||
ocr_bbox format: [{'type': 'text', 'bbox':[x,y], 'interactivity':False, 'content':str }, ...]
|
||||
boxes format: [{'type': 'icon', 'bbox':[x,y], 'interactivity':True, 'content':None }, ...]
|
||||
|
||||
'''
|
||||
assert ocr_bbox is None or isinstance(ocr_bbox, List)
|
||||
|
||||
def box_area(box):
|
||||
return (box[2] - box[0]) * (box[3] - box[1])
|
||||
|
||||
def intersection_area(box1, box2):
|
||||
x1 = max(box1[0], box2[0])
|
||||
y1 = max(box1[1], box2[1])
|
||||
x2 = min(box1[2], box2[2])
|
||||
y2 = min(box1[3], box2[3])
|
||||
return max(0, x2 - x1) * max(0, y2 - y1)
|
||||
|
||||
def IoU(box1, box2):
|
||||
intersection = intersection_area(box1, box2)
|
||||
union = box_area(box1) + box_area(box2) - intersection + 1e-6
|
||||
if box_area(box1) > 0 and box_area(box2) > 0:
|
||||
ratio1 = intersection / box_area(box1)
|
||||
ratio2 = intersection / box_area(box2)
|
||||
else:
|
||||
ratio1, ratio2 = 0, 0
|
||||
return max(intersection / union, ratio1, ratio2)
|
||||
|
||||
def is_inside(box1, box2):
|
||||
# return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
|
||||
intersection = intersection_area(box1, box2)
|
||||
ratio1 = intersection / box_area(box1)
|
||||
return ratio1 > 0.80
|
||||
|
||||
# boxes = boxes.tolist()
|
||||
filtered_boxes = []
|
||||
if ocr_bbox:
|
||||
filtered_boxes.extend(ocr_bbox)
|
||||
# print('ocr_bbox!!!', ocr_bbox)
|
||||
for i, box1_elem in enumerate(boxes):
|
||||
box1 = box1_elem['bbox']
|
||||
is_valid_box = True
|
||||
for j, box2_elem in enumerate(boxes):
|
||||
# keep the smaller box
|
||||
box2 = box2_elem['bbox']
|
||||
if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
|
||||
is_valid_box = False
|
||||
break
|
||||
if is_valid_box:
|
||||
# add the following 2 lines to include ocr bbox
|
||||
if ocr_bbox:
|
||||
# keep yolo boxes + prioritize ocr label
|
||||
box_added = False
|
||||
ocr_labels = ''
|
||||
for box3_elem in ocr_bbox:
|
||||
if not box_added:
|
||||
box3 = box3_elem['bbox']
|
||||
if is_inside(box3, box1): # ocr inside icon
|
||||
# box_added = True
|
||||
# delete the box3_elem from ocr_bbox
|
||||
try:
|
||||
# filtered_boxes.append({'type': 'text', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': box3_elem['content'], 'source':'box_yolo_content_ocr'})
|
||||
# gather all ocr labels
|
||||
ocr_labels += box3_elem['content'] + ' '
|
||||
filtered_boxes.remove(box3_elem)
|
||||
# print('remove ocr bbox:', box3_elem)
|
||||
except:
|
||||
continue
|
||||
# break
|
||||
elif is_inside(box1, box3): # icon inside ocr, don't added this icon box, no need to check other ocr bbox bc no overlap between ocr bbox, icon can only be in one ocr box
|
||||
box_added = True
|
||||
# try:
|
||||
# filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None})
|
||||
# filtered_boxes.remove(box3_elem)
|
||||
# except:
|
||||
# continue
|
||||
break
|
||||
else:
|
||||
continue
|
||||
if not box_added:
|
||||
if ocr_labels:
|
||||
filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': ocr_labels, 'source':'box_yolo_content_ocr'})
|
||||
else:
|
||||
filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None, 'source':'box_yolo_content_yolo'})
|
||||
|
||||
else:
|
||||
filtered_boxes.append(box1)
|
||||
return filtered_boxes # torch.tensor(filtered_boxes)
|
||||
|
||||
|
||||
def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
|
||||
transform = T.Compose(
|
||||
[
|
||||
T.RandomResize([800], max_size=1333),
|
||||
T.ToTensor(),
|
||||
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||
]
|
||||
)
|
||||
image_source = Image.open(image_path).convert("RGB")
|
||||
image = np.asarray(image_source)
|
||||
image_transformed, _ = transform(image_source, None)
|
||||
return image, image_transformed
|
||||
|
||||
|
||||
def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str], text_scale: float,
|
||||
text_padding=5, text_thickness=2, thickness=3) -> np.ndarray:
|
||||
"""
|
||||
This function annotates an image with bounding boxes and labels.
|
||||
|
||||
Parameters:
|
||||
image_source (np.ndarray): The source image to be annotated.
|
||||
boxes (torch.Tensor): A tensor containing bounding box coordinates. in cxcywh format, pixel scale
|
||||
logits (torch.Tensor): A tensor containing confidence scores for each bounding box.
|
||||
phrases (List[str]): A list of labels for each bounding box.
|
||||
text_scale (float): The scale of the text to be displayed. 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
|
||||
|
||||
Returns:
|
||||
np.ndarray: The annotated image.
|
||||
"""
|
||||
h, w, _ = image_source.shape
|
||||
boxes = boxes * torch.Tensor([w, h, w, h])
|
||||
xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
|
||||
xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").numpy()
|
||||
detections = sv.Detections(xyxy=xyxy)
|
||||
|
||||
labels = [f"{phrase}" for phrase in range(boxes.shape[0])]
|
||||
|
||||
# from util.box_annotator import BoxAnnotator
|
||||
box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,text_thickness=text_thickness,thickness=thickness) # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
|
||||
annotated_frame = image_source.copy()
|
||||
annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w,h))
|
||||
|
||||
label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
|
||||
return annotated_frame, label_coordinates
|
||||
|
||||
|
||||
def predict(model, image, caption, box_threshold, text_threshold):
|
||||
""" Use huggingface model to replace the original model
|
||||
"""
|
||||
model, processor = model['model'], model['processor']
|
||||
device = model.device
|
||||
|
||||
inputs = processor(images=image, text=caption, return_tensors="pt").to(device)
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
results = processor.post_process_grounded_object_detection(
|
||||
outputs,
|
||||
inputs.input_ids,
|
||||
box_threshold=box_threshold, # 0.4,
|
||||
text_threshold=text_threshold, # 0.3,
|
||||
target_sizes=[image.size[::-1]]
|
||||
)[0]
|
||||
boxes, logits, phrases = results["boxes"], results["scores"], results["labels"]
|
||||
return boxes, logits, phrases
|
||||
|
||||
|
||||
def predict_yolo(model, image_path, box_threshold, imgsz, scale_img, iou_threshold=0.7):
|
||||
""" Use huggingface model to replace the original model
|
||||
"""
|
||||
# model = model['model']
|
||||
if scale_img:
|
||||
result = model.predict(
|
||||
source=image_path,
|
||||
conf=box_threshold,
|
||||
imgsz=imgsz,
|
||||
iou=iou_threshold, # default 0.7
|
||||
)
|
||||
else:
|
||||
result = model.predict(
|
||||
source=image_path,
|
||||
conf=box_threshold,
|
||||
iou=iou_threshold, # default 0.7
|
||||
)
|
||||
boxes = result[0].boxes.xyxy#.tolist() # in pixel space
|
||||
conf = result[0].boxes.conf
|
||||
phrases = [str(i) for i in range(len(boxes))]
|
||||
|
||||
return boxes, conf, phrases
|
||||
|
||||
|
||||
def int_box_area(box, w, h):
|
||||
x1, y1, x2, y2 = box
|
||||
int_box = [int(x1*w), int(y1*h), int(x2*w), int(y2*h)]
|
||||
area = (int_box[2] - int_box[0]) * (int_box[3] - int_box[1])
|
||||
return area
|
||||
|
||||
|
||||
def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD = 0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=64):
|
||||
""" ocr_bbox: list of xyxy format bbox
|
||||
"""
|
||||
image_source = Image.open(img_path).convert("RGB")
|
||||
w, h = image_source.size
|
||||
if not imgsz:
|
||||
imgsz = (h, w)
|
||||
# print('image size:', w, h)
|
||||
xyxy, logits, phrases = predict_yolo(model=model, image_path=img_path, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1)
|
||||
xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
|
||||
image_source = np.asarray(image_source)
|
||||
phrases = [str(i) for i in range(len(phrases))]
|
||||
|
||||
|
||||
# annotate the image with labels
|
||||
if ocr_bbox:
|
||||
ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
|
||||
ocr_bbox=ocr_bbox.tolist()
|
||||
else:
|
||||
print('no ocr bbox!!!')
|
||||
ocr_bbox = None
|
||||
|
||||
ocr_bbox_elem = [{'type': 'text', 'bbox':box, 'interactivity':False, 'content':txt, 'source': 'box_ocr_content_ocr'} for box, txt in zip(ocr_bbox, ocr_text) if int_box_area(box, w, h) > 0]
|
||||
xyxy_elem = [{'type': 'icon', 'bbox':box, 'interactivity':True, 'content':None} for box in xyxy.tolist() if int_box_area(box, w, h) > 0]
|
||||
filtered_boxes = remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)
|
||||
|
||||
|
||||
# sort the filtered_boxes so that the one with 'content': None is at the end, and get the index of the first 'content': None
|
||||
filtered_boxes_elem = sorted(filtered_boxes, key=lambda x: x['content'] is None)
|
||||
# get the index of the first 'content': None
|
||||
starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1)
|
||||
filtered_boxes = torch.tensor([box['bbox'] for box in filtered_boxes_elem])
|
||||
print('len(filtered_boxes):', len(filtered_boxes), starting_idx)
|
||||
|
||||
# get parsed icon local semantics
|
||||
time1 = time.time()
|
||||
if use_local_semantics:
|
||||
caption_model = caption_model_processor['model']
|
||||
if 'phi3_v' in caption_model.config.model_type:
|
||||
parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor)
|
||||
else:
|
||||
parsed_content_icon = get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=prompt,batch_size=batch_size)
|
||||
ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
|
||||
icon_start = len(ocr_text)
|
||||
parsed_content_icon_ls = []
|
||||
# fill the filtered_boxes_elem None content with parsed_content_icon in order
|
||||
for i, box in enumerate(filtered_boxes_elem):
|
||||
if box['content'] is None:
|
||||
box['content'] = parsed_content_icon.pop(0)
|
||||
for i, txt in enumerate(parsed_content_icon):
|
||||
parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
|
||||
parsed_content_merged = ocr_text + parsed_content_icon_ls
|
||||
else:
|
||||
ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
|
||||
parsed_content_merged = ocr_text
|
||||
print('time to get parsed content:', time.time()-time1)
|
||||
|
||||
filtered_boxes = box_convert(boxes=filtered_boxes, in_fmt="xyxy", out_fmt="cxcywh")
|
||||
|
||||
phrases = [i for i in range(len(filtered_boxes))]
|
||||
|
||||
# draw boxes
|
||||
if draw_bbox_config:
|
||||
annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, **draw_bbox_config)
|
||||
else:
|
||||
annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, text_scale=text_scale, text_padding=text_padding)
|
||||
|
||||
pil_img = Image.fromarray(annotated_frame)
|
||||
buffered = io.BytesIO()
|
||||
pil_img.save(buffered, format="PNG")
|
||||
encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
|
||||
if output_coord_in_ratio:
|
||||
# h, w, _ = image_source.shape
|
||||
label_coordinates = {k: [v[0]/w, v[1]/h, v[2]/w, v[3]/h] for k, v in label_coordinates.items()}
|
||||
assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]
|
||||
|
||||
return encoded_image, label_coordinates, filtered_boxes_elem
|
||||
|
||||
|
||||
def get_xywh(input):
|
||||
x, y, w, h = input[0][0], input[0][1], input[2][0] - input[0][0], input[2][1] - input[0][1]
|
||||
x, y, w, h = int(x), int(y), int(w), int(h)
|
||||
return x, y, w, h
|
||||
|
||||
def get_xyxy(input):
|
||||
x, y, xp, yp = input[0][0], input[0][1], input[2][0], input[2][1]
|
||||
x, y, xp, yp = int(x), int(y), int(xp), int(yp)
|
||||
return x, y, xp, yp
|
||||
|
||||
def get_xywh_yolo(input):
|
||||
x, y, w, h = input[0], input[1], input[2] - input[0], input[3] - input[1]
|
||||
x, y, w, h = int(x), int(y), int(w), int(h)
|
||||
return x, y, w, h
|
||||
|
||||
|
||||
|
||||
def check_ocr_box(image_path, display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False):
|
||||
if use_paddleocr:
|
||||
if easyocr_args is None:
|
||||
text_threshold = 0.5
|
||||
else:
|
||||
text_threshold = easyocr_args['text_threshold']
|
||||
result = paddle_ocr.ocr(image_path, cls=False)[0]
|
||||
# conf = [item[1] for item in result]
|
||||
coord = [item[0] for item in result if item[1][1] > text_threshold]
|
||||
text = [item[1][0] for item in result if item[1][1] > text_threshold]
|
||||
else: # EasyOCR
|
||||
if easyocr_args is None:
|
||||
easyocr_args = {}
|
||||
result = reader.readtext(image_path, **easyocr_args)
|
||||
# print('goal filtering pred:', result[-5:])
|
||||
coord = [item[0] for item in result]
|
||||
text = [item[1] for item in result]
|
||||
# read the image using cv2
|
||||
if display_img:
|
||||
opencv_img = cv2.imread(image_path)
|
||||
opencv_img = cv2.cvtColor(opencv_img, cv2.COLOR_RGB2BGR)
|
||||
bb = []
|
||||
for item in coord:
|
||||
x, y, a, b = get_xywh(item)
|
||||
# print(x, y, a, b)
|
||||
bb.append((x, y, a, b))
|
||||
cv2.rectangle(opencv_img, (x, y), (x+a, y+b), (0, 255, 0), 2)
|
||||
|
||||
# Display the image
|
||||
plt.imshow(opencv_img)
|
||||
else:
|
||||
if output_bb_format == 'xywh':
|
||||
bb = [get_xywh(item) for item in coord]
|
||||
elif output_bb_format == 'xyxy':
|
||||
bb = [get_xyxy(item) for item in coord]
|
||||
# print('bounding box!!!', bb)
|
||||
return (text, bb), goal_filtering
|
||||
|
||||
|
||||
|
||||
from typing import List, Optional, Union, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from supervision.detection.core import Detections
|
||||
from supervision.draw.color import Color, ColorPalette
|
||||
|
||||
|
||||
class BoxAnnotator:
|
||||
"""
|
||||
A class for drawing bounding boxes on an image using detections provided.
|
||||
|
||||
Attributes:
|
||||
color (Union[Color, ColorPalette]): The color to draw the bounding box,
|
||||
can be a single color or a color palette
|
||||
thickness (int): The thickness of the bounding box lines, default is 2
|
||||
text_color (Color): The color of the text on the bounding box, default is white
|
||||
text_scale (float): The scale of the text on the bounding box, default is 0.5
|
||||
text_thickness (int): The thickness of the text on the bounding box,
|
||||
default is 1
|
||||
text_padding (int): The padding around the text on the bounding box,
|
||||
default is 5
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
|
||||
thickness: int = 3, # 1 for seeclick 2 for mind2web and 3 for demo
|
||||
text_color: Color = Color.BLACK,
|
||||
text_scale: float = 0.5, # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
|
||||
text_thickness: int = 2, #1, # 2 for demo
|
||||
text_padding: int = 10,
|
||||
avoid_overlap: bool = True,
|
||||
):
|
||||
self.color: Union[Color, ColorPalette] = color
|
||||
self.thickness: int = thickness
|
||||
self.text_color: Color = text_color
|
||||
self.text_scale: float = text_scale
|
||||
self.text_thickness: int = text_thickness
|
||||
self.text_padding: int = text_padding
|
||||
self.avoid_overlap: bool = avoid_overlap
|
||||
|
||||
def annotate(
|
||||
self,
|
||||
scene: np.ndarray,
|
||||
detections: Detections,
|
||||
labels: Optional[List[str]] = None,
|
||||
skip_label: bool = False,
|
||||
image_size: Optional[Tuple[int, int]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Draws bounding boxes on the frame using the detections provided.
|
||||
|
||||
Args:
|
||||
scene (np.ndarray): The image on which the bounding boxes will be drawn
|
||||
detections (Detections): The detections for which the
|
||||
bounding boxes will be drawn
|
||||
labels (Optional[List[str]]): An optional list of labels
|
||||
corresponding to each detection. If `labels` are not provided,
|
||||
corresponding `class_id` will be used as label.
|
||||
skip_label (bool): Is set to `True`, skips bounding box label annotation.
|
||||
Returns:
|
||||
np.ndarray: The image with the bounding boxes drawn on it
|
||||
|
||||
Example:
|
||||
```python
|
||||
import supervision as sv
|
||||
|
||||
classes = ['person', ...]
|
||||
image = ...
|
||||
detections = sv.Detections(...)
|
||||
|
||||
box_annotator = sv.BoxAnnotator()
|
||||
labels = [
|
||||
f"{classes[class_id]} {confidence:0.2f}"
|
||||
for _, _, confidence, class_id, _ in detections
|
||||
]
|
||||
annotated_frame = box_annotator.annotate(
|
||||
scene=image.copy(),
|
||||
detections=detections,
|
||||
labels=labels
|
||||
)
|
||||
```
|
||||
"""
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
for i in range(len(detections)):
|
||||
x1, y1, x2, y2 = detections.xyxy[i].astype(int)
|
||||
class_id = (
|
||||
detections.class_id[i] if detections.class_id is not None else None
|
||||
)
|
||||
idx = class_id if class_id is not None else i
|
||||
color = (
|
||||
self.color.by_idx(idx)
|
||||
if isinstance(self.color, ColorPalette)
|
||||
else self.color
|
||||
)
|
||||
cv2.rectangle(
|
||||
img=scene,
|
||||
pt1=(x1, y1),
|
||||
pt2=(x2, y2),
|
||||
color=color.as_bgr(),
|
||||
thickness=self.thickness,
|
||||
)
|
||||
if skip_label:
|
||||
continue
|
||||
|
||||
text = (
|
||||
f"{class_id}"
|
||||
if (labels is None or len(detections) != len(labels))
|
||||
else labels[i]
|
||||
)
|
||||
|
||||
text_width, text_height = cv2.getTextSize(
|
||||
text=text,
|
||||
fontFace=font,
|
||||
fontScale=self.text_scale,
|
||||
thickness=self.text_thickness,
|
||||
)[0]
|
||||
|
||||
if not self.avoid_overlap:
|
||||
text_x = x1 + self.text_padding
|
||||
text_y = y1 - self.text_padding
|
||||
|
||||
text_background_x1 = x1
|
||||
text_background_y1 = y1 - 2 * self.text_padding - text_height
|
||||
|
||||
text_background_x2 = x1 + 2 * self.text_padding + text_width
|
||||
text_background_y2 = y1
|
||||
# text_x = x1 - self.text_padding - text_width
|
||||
# text_y = y1 + self.text_padding + text_height
|
||||
# text_background_x1 = x1 - 2 * self.text_padding - text_width
|
||||
# text_background_y1 = y1
|
||||
# text_background_x2 = x1
|
||||
# text_background_y2 = y1 + 2 * self.text_padding + text_height
|
||||
else:
|
||||
text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 = get_optimal_label_pos(self.text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size)
|
||||
|
||||
cv2.rectangle(
|
||||
img=scene,
|
||||
pt1=(text_background_x1, text_background_y1),
|
||||
pt2=(text_background_x2, text_background_y2),
|
||||
color=color.as_bgr(),
|
||||
thickness=cv2.FILLED,
|
||||
)
|
||||
box_color = color.as_rgb()
|
||||
luminance = 0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2]
|
||||
text_color = (0,0,0) if luminance > 160 else (255,255,255)
|
||||
cv2.putText(
|
||||
img=scene,
|
||||
text=text,
|
||||
org=(text_x, text_y),
|
||||
fontFace=font,
|
||||
fontScale=self.text_scale,
|
||||
# color=self.text_color.as_rgb(),
|
||||
color=text_color,
|
||||
thickness=self.text_thickness,
|
||||
lineType=cv2.LINE_AA,
|
||||
)
|
||||
return scene
|
||||
|
||||
|
||||
def box_area(box):
|
||||
return (box[2] - box[0]) * (box[3] - box[1])
|
||||
|
||||
def intersection_area(box1, box2):
|
||||
x1 = max(box1[0], box2[0])
|
||||
y1 = max(box1[1], box2[1])
|
||||
x2 = min(box1[2], box2[2])
|
||||
y2 = min(box1[3], box2[3])
|
||||
return max(0, x2 - x1) * max(0, y2 - y1)
|
||||
|
||||
def IoU(box1, box2, return_max=True):
|
||||
intersection = intersection_area(box1, box2)
|
||||
union = box_area(box1) + box_area(box2) - intersection
|
||||
if box_area(box1) > 0 and box_area(box2) > 0:
|
||||
ratio1 = intersection / box_area(box1)
|
||||
ratio2 = intersection / box_area(box2)
|
||||
else:
|
||||
ratio1, ratio2 = 0, 0
|
||||
if return_max:
|
||||
return max(intersection / union, ratio1, ratio2)
|
||||
else:
|
||||
return intersection / union
|
||||
|
||||
|
||||
def get_optimal_label_pos(text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size):
    """Check overlap between the candidate label box and the detection boxes, and pick a label position.

    Candidate positions are tried in order: 'top left', 'outer left', 'outer right', 'top right'.
    TODO: if every candidate overlaps, the last one tried ('top right') is returned.
    Overlap threshold: 0.3 (see get_is_overlap below).
    """

    def get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size):
        is_overlap = False
        for i in range(len(detections)):
            detection = detections.xyxy[i].astype(int)
            if IoU([text_background_x1, text_background_y1, text_background_x2, text_background_y2], detection) > 0.3:
                is_overlap = True
                break
        # check if the text is out of the image
        if text_background_x1 < 0 or text_background_x2 > image_size[0] or text_background_y1 < 0 or text_background_y2 > image_size[1]:
            is_overlap = True
        return is_overlap

    # if pos == 'top left':
    text_x = x1 + text_padding
    text_y = y1 - text_padding

    text_background_x1 = x1
    text_background_y1 = y1 - 2 * text_padding - text_height

    text_background_x2 = x1 + 2 * text_padding + text_width
    text_background_y2 = y1
    is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
    if not is_overlap:
        return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2

    # elif pos == 'outer left':
    text_x = x1 - text_padding - text_width
    text_y = y1 + text_padding + text_height

    text_background_x1 = x1 - 2 * text_padding - text_width
    text_background_y1 = y1

    text_background_x2 = x1
    text_background_y2 = y1 + 2 * text_padding + text_height
    is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
    if not is_overlap:
        return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2

    # elif pos == 'outer right':
    text_x = x2 + text_padding
    text_y = y1 + text_padding + text_height

    text_background_x1 = x2
    text_background_y1 = y1

    text_background_x2 = x2 + 2 * text_padding + text_width
    text_background_y2 = y1 + 2 * text_padding + text_height

    is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
    if not is_overlap:
        return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2

    # elif pos == 'top right':
    text_x = x2 - text_padding - text_width
    text_y = y1 - text_padding

    text_background_x1 = x2 - 2 * text_padding - text_width
    text_background_y1 = y1 - 2 * text_padding - text_height

    text_background_x2 = x2
    text_background_y2 = y1

    is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
    if not is_overlap:
        return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2

    return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2

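# Hypothetical call to get_optimal_label_pos with a minimal stand-in for a supervision
# Detections object; only the .xyxy array and len() are used by get_is_overlap, and the
# numbers below are made up for illustration.
import numpy as np

class _FakeDetections:
    def __init__(self, xyxy):
        self.xyxy = np.asarray(xyxy)
    def __len__(self):
        return len(self.xyxy)

dets = _FakeDetections([[300, 300, 400, 400]])
# Label geometry for a box spanning (50, 50)-(150, 150) on a 1920x1080 screen.
pos = get_optimal_label_pos(
    text_padding=3, text_width=40, text_height=12,
    x1=50, y1=50, x2=150, y2=150,
    detections=dets, image_size=(1920, 1080),
)
print(pos)  # (text_x, text_y, bg_x1, bg_y1, bg_x2, bg_y2) of the first non-overlapping candidate
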
import re

def extract_dict_from_text(text):
    # Define the regex pattern for a dictionary-like structure
    pattern = r"\{\s*'(?P<key1>.*?)':\s*'(?P<value1>.*?)',\s*'(?P<key2>.*?)':\s*'(?P<value2>.*?)'\s*\}"

    # Search for the dictionary in the text
    match = re.search(pattern, text)

    if match:
        # Extract matched groups into a dictionary
        return {
            match.group('key1'): match.group('value1'),
            match.group('key2'): match.group('value2'),
        }
    else:
        raise ValueError("No valid dictionary structure found in the text.")

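# Quick illustration: extract_dict_from_text pulls the first two quoted key/value pairs
# out of a model reply that embeds a Python-style dict (the reply string is made up).
reply = "Sure. {'Click BBox ID': '7', 'Reason': 'the search icon'} is my choice."
print(extract_dict_from_text(reply))  # {'Click BBox ID': '7', 'Reason': 'the search icon'}
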
def get_phi3v_model_dict():
    from PIL import Image
    import requests
    from transformers import AutoModelForCausalLM
    from transformers import AutoProcessor

    model_id = "microsoft/Phi-3.5-vision-instruct"
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", trust_remote_code=True, torch_dtype="auto", _attn_implementation='flash_attention_2')
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    print('phi3v model loaded!!!')
    return {'model': model, 'processor': processor}


def call_phi3v(messages, image_base64, model_dict):
    model, processor = model_dict['model'], model_dict['processor']
    device = model.device
    prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    if isinstance(image_base64, tuple):
        image_base64, dino_labled_img = image_base64
        image = Image.open(io.BytesIO(base64.b64decode(image_base64)))
        dino_labled_img = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
        inputs = processor(prompt, [image, dino_labled_img], return_tensors="pt").to(device)
    else:
        image = Image.open(io.BytesIO(base64.b64decode(image_base64)))
        inputs = processor(prompt, [image], return_tensors="pt").to(device)

    generation_args = {
        "max_new_tokens": 512,
        "temperature": 0.01,
        "do_sample": False,
    }

    generate_ids = model.generate(**inputs, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
    # remove input tokens
    generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
    ans = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    return ans

def get_pred_phi3v(message_text, image_base64, label_coordinates, id_key='Click ID', model_dict=None):
    # messages = [
    #     {"role": "system", "content": '''You are an expert at completing instructions on GUI screens.
    #      You will be presented with two images. The first is the original screenshot. The second is the same screenshot with some numeric tags. You will also be provided with some descriptions of the bbox, and your task is to choose the numeric bbox idx you want to click in order to complete the user instruction.'''},
    # ]
    messages = [
        {"role": "system", "content": '''You are an expert at completing instructions on GUI screens. You will also be provided with some descriptions of the bbox, and your task is to choose the numeric bbox idx you want to click in order to complete the user instruction.'''},
    ]
    messages = []
    if isinstance(image_base64, tuple):
        messages.append({"role": "user", "content": '<|image_1|>\n' + '<|image_2|>\n' + message_text})
    else:
        messages.append({"role": "user", "content": '<|image_1|>\n' + message_text})

    response_text = call_phi3v(messages, image_base64, model_dict)
    print(response_text)

    try:
        response_text = ast.literal_eval(response_text)

        icon_id = response_text['Click BBox ID']
        bbox = label_coordinates[str(icon_id)]
        click_point = [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2]
    except:
        print('error parsing, use regex to parse!!!')
        import pdb; pdb.set_trace()
        response_text = extract_dict_from_text(response_text)
        icon_id = response_text['Click BBox ID']
        bbox = label_coordinates[str(icon_id)]
        click_point = [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2]
    return icon_id, bbox, click_point, response_text

    # try:
    #     match = re.search(r"```(.*?)```", ans, re.DOTALL)
    #     if match:
    #         result = match.group(1).strip()
    #         pred = result.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '')
    #         pred = ast.literal_eval(pred)
    #     else:
    #         pred = ans.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '')
    #         pred = ast.literal_eval(pred)

    #     if pred[id_key]:
    #         icon_id = pred[id_key]
    #         bbox = label_coordinates[str(icon_id)]
    #         pred['click_point'] = [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2]
    # except:
    #     print('phi3v action regex extract fail!!!')
    #     pred = {'action_type': 'CLICK', 'click_point': [0, 0], 'value': 'None', 'is_completed': False}

    # step_pred_summary = None
    # return pred, [True, ans, None, step_pred_summary]

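# Hypothetical end-to-end use of the Phi-3.5-vision helpers above. Assumes a CUDA GPU with
# flash-attention installed, a local 'screenshot.png', and label_coordinates in the shape
# produced by get_som_labeled_img (stringified box ids mapped to xywh boxes); all values
# here are placeholders.
import base64
with open('screenshot.png', 'rb') as f:
    image_base64 = base64.b64encode(f.read()).decode('utf-8')
label_coordinates = {'0': [0.10, 0.20, 0.05, 0.03], '1': [0.40, 0.60, 0.08, 0.04]}
model_dict = get_phi3v_model_dict()
icon_id, bbox, click_point, response = get_pred_phi3v(
    "Task: open the settings page. Which labeled box should be clicked?",
    image_base64,
    label_coordinates,
    model_dict=model_dict,
)
print(icon_id, click_point)
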
Binary file not shown.
Before Width: | Height: | Size: 226 KiB
Binary file not shown.
Before Width: | Height: | Size: 120 KiB
33  demo/omniparserserver/omniparser.py  Normal file
@@ -0,0 +1,33 @@
from ...util.utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box
import torch
from PIL import Image
import io
import base64
from typing import Dict  # needed for the Dict annotation below

class Omniparser(object):
    def __init__(self, config: Dict):
        self.config = config
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.som_model = get_yolo_model(model_path=config['som_model_path'])
        self.caption_model_processor = get_caption_model_processor(model_name=config['caption_model_name'], model_name_or_path=config['caption_model_path'], device=device)
        print('Omniparser initialized!!!')

    def parse(self, image_base64: str):
        image_bytes = base64.b64decode(image_base64)
        image = Image.open(io.BytesIO(image_bytes))
        print('image size:', image.size)

        box_overlay_ratio = max(image.size) / 3200
        draw_bbox_config = {
            'text_scale': 0.8 * box_overlay_ratio,
            'text_thickness': max(int(2 * box_overlay_ratio), 1),
            'text_padding': max(int(3 * box_overlay_ratio), 1),
            'thickness': max(int(3 * box_overlay_ratio), 1),
        }
        BOX_TRESHOLD = self.config['BOX_TRESHOLD']

        (text, ocr_bbox), _ = check_ocr_box(image, display_img=False, output_bb_format='xyxy', easyocr_args={'text_threshold': 0.8}, use_paddleocr=False)
        dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image, self.som_model, BOX_TRESHOLD=BOX_TRESHOLD, output_coord_in_ratio=True, ocr_bbox=ocr_bbox, draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text, use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128)

        return dino_labled_img, parsed_content_list

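# Hypothetical usage of the Omniparser class above; the weight paths and 'screenshot.png'
# are placeholders for real files, and BOX_TRESHOLD mirrors the server default.
import base64

config = {
    'som_model_path': '../../weights/icon_detect_v1_5/model_v1_5.pt',
    'caption_model_name': 'florence2',
    'caption_model_path': '../../weights/icon_caption_florence',
    'BOX_TRESHOLD': 0.05,
}
omniparser = Omniparser(config)
with open('screenshot.png', 'rb') as f:
    image_base64 = base64.b64encode(f.read()).decode('utf-8')
som_image_base64, parsed_content_list = omniparser.parse(image_base64)
print(parsed_content_list[:3])
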
50  demo/omniparserserver/remote_request.py  Normal file
@@ -0,0 +1,50 @@
'''
python -m remote_request --som_model_path ../weights/icon_detect_v1_5/model_v1_5.pt --caption_model_name florence2 --caption_model_path ../weights/icon_caption_florence --device cuda --BOX_TRESHOLD 0.05
'''

import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import time
from fastapi import FastAPI
from pydantic import BaseModel
import argparse
import uvicorn
from omniparser import Omniparser

def parse_arguments():
    parser = argparse.ArgumentParser(description='Omniparser API')
    parser.add_argument('--som_model_path', type=str, default='../weights/icon_detect_v1_5/model_v1_5.pt', help='Path to the som model')
    parser.add_argument('--caption_model_name', type=str, default='florence2', help='Name of the caption model')
    parser.add_argument('--caption_model_path', type=str, default='../weights/icon_caption_florence', help='Path to the caption model')
    parser.add_argument('--device', type=str, default='cpu', help='Device to run the model')
    parser.add_argument('--BOX_TRESHOLD', type=float, default=0.05, help='Threshold for box detection')
    parser.add_argument('--host', type=str, default='0.0.0.0', help='Host for the API')
    parser.add_argument('--port', type=int, default=8000, help='Port for the API')
    args = parser.parse_args()
    return args

args = parse_arguments()
config = vars(args)

app = FastAPI()
omniparser = Omniparser(config)

class ParseRequest(BaseModel):
    base64_image: str

@app.post("/parse/")
async def parse(parse_request: ParseRequest):
    print('start parsing...')
    start = time.time()
    dino_labled_img, parsed_content_list = omniparser.parse(parse_request.base64_image)
    latency = time.time() - start
    print('time:', latency)
    return {"som_image_base64": dino_labled_img, "parsed_content_list": parsed_content_list, 'latency': latency}

@app.get("/probe/")
async def root():
    return {"message": "Omniparser API ready"}

if __name__ == "__main__":
    uvicorn.run("remote_request:app", host=args.host, port=args.port, reload=True)

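# Sketch of a client for the FastAPI server above, assuming it is running locally on the
# default port 8000 and that 'screenshot.png' exists; /parse/ expects {"base64_image": ...}.
import base64
import requests

print(requests.get('http://localhost:8000/probe/').json())  # {"message": "Omniparser API ready"}

with open('screenshot.png', 'rb') as f:
    payload = {'base64_image': base64.b64encode(f.read()).decode('utf-8')}
resp = requests.post('http://localhost:8000/parse/', json=payload).json()
print(resp['latency'], len(resp['parsed_content_list']))
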
@@ -1,101 +0,0 @@
'''
python -m remote_request --som_model_path ../weights/icon_detect_v1_5/model_v1_5.pt --caption_model_name florence2 --caption_model_path ../weights/icon_caption_florence --device cuda --BOX_TRESHOLD 0.05
'''

import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import time
from utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box
import torch
from PIL import Image
from typing import Dict, Tuple, List
import base64
import io
from fastapi import FastAPI
from pydantic import BaseModel
import argparse

def parse_arguments():
    parser = argparse.ArgumentParser(description='Omniparser API')
    parser.add_argument('--som_model_path', type=str, default='../weights/icon_detect_v1_5/model_v1_5.pt', help='Path to the som model')
    parser.add_argument('--caption_model_name', type=str, default='florence2', help='Name of the caption model')
    parser.add_argument('--caption_model_path', type=str, default='../weights/icon_caption_florence', help='Path to the caption model')
    parser.add_argument('--device', type=str, default='cpu', help='Device to run the model')
    parser.add_argument('--BOX_TRESHOLD', type=float, default=0.05, help='Threshold for box detection')
    parser.add_argument('--host', type=str, default='0.0.0.0', help='Host for the API')
    parser.add_argument('--port', type=int, default=8000, help='Port for the API')
    args = parser.parse_args()
    return args

args = parse_arguments()
config = vars(args)

# config = {
#     'som_model_path': '../weights/icon_detect_v1_5/model_v1_5.pt',
#     'device': 'cpu',
#     'caption_model_name': 'florence2',
#     'caption_model_path': '../weights/icon_caption_florence',
#     'BOX_TRESHOLD': 0.05
# }

class Omniparser(object):
    def __init__(self, config: Dict):
        self.config = config
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.som_model = get_yolo_model(model_path=config['som_model_path'])
        self.caption_model_processor = get_caption_model_processor(model_name=config['caption_model_name'], model_name_or_path=config['caption_model_path'], device=device)
        print('Omniparser initialized!!!')

    def parse(self, image_base64: str):
        # Convert base64 to image directly without saving to disk
        image_bytes = base64.b64decode(image_base64)
        image = Image.open(io.BytesIO(image_bytes))
        print('image size:', image.size)

        box_overlay_ratio = max(image.size) / 3200
        draw_bbox_config = {
            'text_scale': 0.8 * box_overlay_ratio,
            'text_thickness': max(int(2 * box_overlay_ratio), 1),
            'text_padding': max(int(3 * box_overlay_ratio), 1),
            'thickness': max(int(3 * box_overlay_ratio), 1),
        }
        BOX_TRESHOLD = config['BOX_TRESHOLD']

        (text, ocr_bbox), _ = check_ocr_box(image, display_img=False, output_bb_format='xyxy', easyocr_args={'text_threshold': 0.8}, use_paddleocr=False)
        dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image, self.som_model, BOX_TRESHOLD=BOX_TRESHOLD, output_coord_in_ratio=True, ocr_bbox=ocr_bbox, draw_bbox_config=draw_bbox_config, caption_model_processor=self.caption_model_processor, ocr_text=text, use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128)

        return dino_labled_img, parsed_content_list


app = FastAPI()

class Item(BaseModel):
    base64_image: str
    prompt: str

Omniparser = Omniparser(config)

@app.post("/send_text/")
async def send_text(item: Item):
    print('start parsing...')

    start = time.time()
    dino_labled_img, parsed_content_list = Omniparser.parse(item.base64_image)
    latency = time.time() - start
    print('time:', latency)
    return {"som_image_base64": dino_labled_img, "parsed_content_list": parsed_content_list, 'latency': latency}

@app.get("/")
async def root():
    return {"message": "Omniparser API ready"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run("remote_request:app", host=args.host, port=args.port, reload=True)

@@ -8,7 +8,7 @@ import io

import base64, os
-from utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
+from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
import torch
from PIL import Image

@@ -17,8 +17,6 @@ yolo_model = get_yolo_model(model_path='weights/icon_detect_v1_5/best.pt')
caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="weights/icon_caption_florence")
# caption_model_processor = get_caption_model_processor(model_name="blip2", model_name_or_path="weights/icon_caption_blip2")


MARKDOWN = """
# OmniParser for Pure Vision Based General GUI Agent 🔥
<div>

@@ -65,8 +63,6 @@ def process(
    # parsed_content_list = str(parsed_content_list)
    return image, str(parsed_content_list)


with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():

@@ -1,60 +0,0 @@
from utils import get_som_labeled_img, check_ocr_box, get_yolo_model
import torch
from ultralytics import YOLO
from PIL import Image
from typing import Dict, Tuple, List
import io
import base64


config = {
    'som_model_path': 'finetuned_icon_detect.pt',
    'device': 'cpu',
    'caption_model_path': 'Salesforce/blip2-opt-2.7b',
    'draw_bbox_config': {
        'text_scale': 0.8,
        'text_thickness': 2,
        'text_padding': 3,
        'thickness': 3,
    },
    'BOX_TRESHOLD': 0.05
}


class Omniparser(object):
    def __init__(self, config: Dict):
        self.config = config

        self.som_model = get_yolo_model(model_path=config['som_model_path'])
        # self.caption_model_processor = get_caption_model_processor(config['caption_model_path'], device=config['device'])
        # self.caption_model_processor['model'].to(torch.float32)

    def parse(self, image_path: str):
        print('Parsing image:', image_path)
        ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_path, display_img=False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold': 0.9})
        text, ocr_bbox = ocr_bbox_rslt

        draw_bbox_config = self.config['draw_bbox_config']
        BOX_TRESHOLD = self.config['BOX_TRESHOLD']
        dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, self.som_model, BOX_TRESHOLD=BOX_TRESHOLD, output_coord_in_ratio=False, ocr_bbox=ocr_bbox, draw_bbox_config=draw_bbox_config, caption_model_processor=None, ocr_text=text, use_local_semantics=False)

        image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
        # formatting output
        return_list = [{'from': 'omniparser', 'shape': {'x': coord[0], 'y': coord[1], 'width': coord[2], 'height': coord[3]},
                        'text': parsed_content_list[i].split(': ')[1], 'type': 'text'} for i, (k, coord) in enumerate(label_coordinates.items()) if i < len(parsed_content_list)]
        return_list.extend(
            [{'from': 'omniparser', 'shape': {'x': coord[0], 'y': coord[1], 'width': coord[2], 'height': coord[3]},
              'text': 'None', 'type': 'icon'} for i, (k, coord) in enumerate(label_coordinates.items()) if i >= len(parsed_content_list)]
        )

        return [image, return_list]

parser = Omniparser(config)
image_path = 'examples/pc_1.png'

# time the parser
import time
s = time.time()
image, parsed_content_list = parser.parse(image_path)
device = config['device']
print(f'Time taken for Omniparser on {device}:', time.time() - s)

@@ -1,425 +0,0 @@
'''
Adapted from https://github.com/google-research/google-research/tree/master/android_in_the_wild
'''

import jax
import jax.numpy as jnp
import numpy as np

# import action_type as action_type_lib
import enum

class ActionType(enum.IntEnum):
    # Placeholders for unused enum values
    UNUSED_0 = 0
    UNUSED_1 = 1
    UNUSED_2 = 2
    UNUSED_8 = 8
    UNUSED_9 = 9

    ########### Agent actions ###########

    # A type action that sends text to the emulator. Note that this simply sends
    # text and does not perform any clicks for element focus or enter presses for
    # submitting text.
    TYPE = 3

    # The dual point action used to represent all gestures.
    DUAL_POINT = 4

    # These actions differentiate pressing the home and back button from touches.
    # They represent explicit presses of back and home performed using ADB.
    PRESS_BACK = 5
    PRESS_HOME = 6

    # An action representing that ADB command for hitting enter was performed.
    PRESS_ENTER = 7

    ########### Episode status actions ###########

    # An action used to indicate the desired task has been completed and resets
    # the environment. This action should also be used in the case that the task
    # has already been completed and there is nothing to do.
    # e.g. The task is to turn on the Wi-Fi when it is already on
    STATUS_TASK_COMPLETE = 10

    # An action used to indicate that desired task is impossible to complete and
    # resets the environment. This can be a result of many different things
    # including UI changes, Android version differences, etc.
    STATUS_TASK_IMPOSSIBLE = 11


_TAP_DISTANCE_THRESHOLD = 0.14  # Fraction of the screen
ANNOTATION_WIDTH_AUGMENT_FRACTION = 1.4
ANNOTATION_HEIGHT_AUGMENT_FRACTION = 1.4

# Interval determining if an action is a tap or a swipe.
_SWIPE_DISTANCE_THRESHOLD = 0.04

def _yx_in_bounding_boxes(
    yx, bounding_boxes
):
  """Check if the (y,x) point is contained in each bounding box.

  Args:
    yx: The (y, x) coordinate in pixels of the point.
    bounding_boxes: A 2D int array of shape (num_bboxes, 4), where each row
      represents a bounding box: (y_top_left, x_top_left, box_height,
      box_width). Note: containment is inclusive of the bounding box edges.

  Returns:
    is_inside: A 1D bool array where each element specifies if the point is
      contained within the respective box.
  """
  y, x = yx

  # `bounding_boxes` has shape (n_elements, 4); we extract each array along the
  # last axis into shape (n_elements, 1), then squeeze unneeded dimension.
  top, left, height, width = [
      jnp.squeeze(v, axis=-1) for v in jnp.split(bounding_boxes, 4, axis=-1)
  ]

  # The y-axis is inverted for AndroidEnv, so bottom = top + height.
  bottom, right = top + height, left + width

  return jnp.logical_and(y >= top, y <= bottom) & jnp.logical_and(
      x >= left, x <= right)


def _resize_annotation_bounding_boxes(
    annotation_positions, annotation_width_augment_fraction,
    annotation_height_augment_fraction):
  """Resize the bounding boxes by the given fractions.

  Args:
    annotation_positions: Array of shape (N, 4), where each row represents the
      (y, x, height, width) of the bounding boxes.
    annotation_width_augment_fraction: The fraction to augment the box widths,
      E.g., 1.4 == 240% total increase.
    annotation_height_augment_fraction: Same as described for width, but for box
      height.

  Returns:
    Resized bounding box.

  """
  height_change = (
      annotation_height_augment_fraction * annotation_positions[:, 2])
  width_change = (
      annotation_width_augment_fraction * annotation_positions[:, 3])

  # Limit bounding box positions to the screen.
  resized_annotations = jnp.stack([
      jnp.maximum(0, annotation_positions[:, 0] - (height_change / 2)),
      jnp.maximum(0, annotation_positions[:, 1] - (width_change / 2)),
      jnp.minimum(1, annotation_positions[:, 2] + height_change),
      jnp.minimum(1, annotation_positions[:, 3] + width_change),
  ],
                                  axis=1)
  return resized_annotations

def is_tap_action(normalized_start_yx,
                  normalized_end_yx):
  distance = jnp.linalg.norm(
      jnp.array(normalized_start_yx) - jnp.array(normalized_end_yx))
  return distance <= _SWIPE_DISTANCE_THRESHOLD


def _is_non_dual_point_action(action_type):
  return jnp.not_equal(action_type, ActionType.DUAL_POINT)


def _check_tap_actions_match(
    tap_1_yx,
    tap_2_yx,
    annotation_positions,
    matching_tap_distance_threshold_screen_percentage,
    annotation_width_augment_fraction,
    annotation_height_augment_fraction,
):
  """Determines if two tap actions are the same."""
  resized_annotation_positions = _resize_annotation_bounding_boxes(
      annotation_positions,
      annotation_width_augment_fraction,
      annotation_height_augment_fraction,
  )

  # Check if the ground truth tap action falls in an annotation's bounding box.
  tap1_in_box = _yx_in_bounding_boxes(tap_1_yx, resized_annotation_positions)
  tap2_in_box = _yx_in_bounding_boxes(tap_2_yx, resized_annotation_positions)
  both_in_box = jnp.max(tap1_in_box & tap2_in_box)

  # If the ground-truth tap action falls outside any of the annotation
  # bounding boxes or one of the actions is inside a bounding box and the other
  # is outside bounding box or vice versa, compare the points using Euclidean
  # distance.
  within_threshold = (
      jnp.linalg.norm(jnp.array(tap_1_yx) - jnp.array(tap_2_yx))
      <= matching_tap_distance_threshold_screen_percentage
  )
  return jnp.logical_or(both_in_box, within_threshold)


def _check_drag_actions_match(
    drag_1_touch_yx,
    drag_1_lift_yx,
    drag_2_touch_yx,
    drag_2_lift_yx,
):
  """Determines if two drag actions are the same."""
  # Store drag deltas (the change in the y and x coordinates from touch to
  # lift), magnitudes, and the index of the main axis, which is the axis with
  # the greatest change in coordinate value (e.g. a drag starting at (0, 0) and
  # ending at (0.3, 0.5) has a main axis index of 1).
  drag_1_deltas = drag_1_lift_yx - drag_1_touch_yx
  drag_1_magnitudes = jnp.abs(drag_1_deltas)
  drag_1_main_axis = np.argmax(drag_1_magnitudes)
  drag_2_deltas = drag_2_lift_yx - drag_2_touch_yx
  drag_2_magnitudes = jnp.abs(drag_2_deltas)
  drag_2_main_axis = np.argmax(drag_2_magnitudes)

  return jnp.equal(drag_1_main_axis, drag_2_main_axis)

def check_actions_match(
    action_1_touch_yx,
    action_1_lift_yx,
    action_1_action_type,
    action_2_touch_yx,
    action_2_lift_yx,
    action_2_action_type,
    annotation_positions,
    tap_distance_threshold = _TAP_DISTANCE_THRESHOLD,
    annotation_width_augment_fraction = ANNOTATION_WIDTH_AUGMENT_FRACTION,
    annotation_height_augment_fraction = ANNOTATION_HEIGHT_AUGMENT_FRACTION,
):
  """Determines if two actions are considered to be the same.

  Two actions being "the same" is defined here as two actions that would result
  in a similar screen state.

  Args:
    action_1_touch_yx: The (y, x) coordinates of the first action's touch.
    action_1_lift_yx: The (y, x) coordinates of the first action's lift.
    action_1_action_type: The action type of the first action.
    action_2_touch_yx: The (y, x) coordinates of the second action's touch.
    action_2_lift_yx: The (y, x) coordinates of the second action's lift.
    action_2_action_type: The action type of the second action.
    annotation_positions: The positions of the UI annotations for the screen. It
      is a 2D int array of shape (num_bboxes, 4), where each row represents a
      bounding box: (y_top_left, x_top_left, box_height, box_width). Note that
      containment is inclusive of the bounding box edges.
    tap_distance_threshold: The threshold that determines if two taps result in
      a matching screen state if they don't fall in the same bounding boxes.
    annotation_width_augment_fraction: The fraction to increase the width of the
      bounding box by.
    annotation_height_augment_fraction: The fraction to increase the height of
      the bounding box by.

  Returns:
    A boolean representing whether the two given actions are the same or not.
  """
  action_1_touch_yx = jnp.asarray(action_1_touch_yx)
  action_1_lift_yx = jnp.asarray(action_1_lift_yx)
  action_2_touch_yx = jnp.asarray(action_2_touch_yx)
  action_2_lift_yx = jnp.asarray(action_2_lift_yx)

  # Checks if at least one of the actions is global (i.e. not DUAL_POINT),
  # because if that is the case, only the actions' types need to be compared.
  has_non_dual_point_action = jnp.logical_or(
      _is_non_dual_point_action(action_1_action_type),
      _is_non_dual_point_action(action_2_action_type),
  )
  # print("non dual point: " + str(has_non_dual_point_action))

  different_dual_point_types = jnp.logical_xor(
      is_tap_action(action_1_touch_yx, action_1_lift_yx),
      is_tap_action(action_2_touch_yx, action_2_lift_yx),
  )
  # print("different dual type: " + str(different_dual_point_types))

  is_tap = jnp.logical_and(
      is_tap_action(action_1_touch_yx, action_1_lift_yx),
      is_tap_action(action_2_touch_yx, action_2_lift_yx),
  )
  # print("is tap: " + str(is_tap))

  taps_match = _check_tap_actions_match(
      action_1_touch_yx,
      action_2_touch_yx,
      annotation_positions,
      tap_distance_threshold,
      annotation_width_augment_fraction,
      annotation_height_augment_fraction,
  )
  # print("tap match: " + str(taps_match))

  taps_match = jnp.logical_and(is_tap, taps_match)
  # print("tap match: " + str(taps_match))

  drags_match = _check_drag_actions_match(
      action_1_touch_yx, action_1_lift_yx, action_2_touch_yx, action_2_lift_yx
  )
  drags_match = jnp.where(is_tap, False, drags_match)
  # print("drag match: " + str(drags_match))

  return jnp.where(
      has_non_dual_point_action,
      jnp.equal(action_1_action_type, action_2_action_type),
      jnp.where(
          different_dual_point_types,
          False,
          jnp.logical_or(taps_match, drags_match),
      ),
  )

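# Small worked example for check_actions_match: two taps that land inside the same
# (augmented) annotation box count as a match even though their coordinates differ.
# Coordinates and boxes are normalized to [0, 1]; the numbers are made up.
import jax.numpy as jnp

annotations = jnp.array([[0.10, 0.10, 0.20, 0.30]])  # one box: (y, x, height, width)
same = check_actions_match(
    action_1_touch_yx=(0.15, 0.20), action_1_lift_yx=(0.15, 0.20),
    action_1_action_type=ActionType.DUAL_POINT,
    action_2_touch_yx=(0.22, 0.30), action_2_lift_yx=(0.22, 0.30),
    action_2_action_type=ActionType.DUAL_POINT,
    annotation_positions=annotations,
)
print(bool(same))  # True
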
def action_2_format(step_data):
    # Convert the action format used in the test set into the format expected by the matching-score computation.
    action_type = step_data["action_type_id"]

    if action_type == 4:
        if step_data["action_type_text"] == 'click':  # click
            touch_point = step_data["touch"]
            lift_point = step_data["lift"]
        else:  # scroll up/down/left/right
            if step_data["action_type_text"] == 'scroll down':
                touch_point = [0.5, 0.8]
                lift_point = [0.5, 0.2]
            elif step_data["action_type_text"] == 'scroll up':
                touch_point = [0.5, 0.2]
                lift_point = [0.5, 0.8]
            elif step_data["action_type_text"] == 'scroll left':
                touch_point = [0.2, 0.5]
                lift_point = [0.8, 0.5]
            elif step_data["action_type_text"] == 'scroll right':
                touch_point = [0.8, 0.5]
                lift_point = [0.2, 0.5]
    else:
        touch_point = [-1.0, -1.0]
        lift_point = [-1.0, -1.0]

    if action_type == 3:
        typed_text = step_data["type_text"]
    else:
        typed_text = ""

    action = {"action_type": action_type, "touch_point": touch_point, "lift_point": lift_point,
              "typed_text": typed_text}

    action["touch_point"] = [action["touch_point"][1], action["touch_point"][0]]
    action["lift_point"] = [action["lift_point"][1], action["lift_point"][0]]
    action["typed_text"] = action["typed_text"].lower()

    return action


def pred_2_format(step_data):
    # Convert the model output into the format expected by the action_matching computation.
    action_type = step_data["action_type"]

    if action_type == 4:  # click
        action_type_new = 4
        touch_point = step_data["click_point"]
        lift_point = step_data["click_point"]
        typed_text = ""
    elif action_type == 0:
        action_type_new = 4
        touch_point = [0.5, 0.8]
        lift_point = [0.5, 0.2]
        typed_text = ""
    elif action_type == 1:
        action_type_new = 4
        touch_point = [0.5, 0.2]
        lift_point = [0.5, 0.8]
        typed_text = ""
    elif action_type == 8:
        action_type_new = 4
        touch_point = [0.2, 0.5]
        lift_point = [0.8, 0.5]
        typed_text = ""
    elif action_type == 9:
        action_type_new = 4
        touch_point = [0.8, 0.5]
        lift_point = [0.2, 0.5]
        typed_text = ""
    else:
        action_type_new = action_type
        touch_point = [-1.0, -1.0]
        lift_point = [-1.0, -1.0]
        typed_text = ""
    if action_type_new == 3:
        typed_text = step_data["typed_text"]

    action = {"action_type": action_type_new, "touch_point": touch_point, "lift_point": lift_point,
              "typed_text": typed_text}

    action["touch_point"] = [action["touch_point"][1], action["touch_point"][0]]
    action["lift_point"] = [action["lift_point"][1], action["lift_point"][0]]
    action["typed_text"] = action["typed_text"].lower()

    return action


def pred_2_format_simplified(step_data):
    # Convert the (simplified) model output into the format expected by the action_matching computation.
    action_type = step_data["action_type"]

    if action_type == 'click':  # click
        action_type_new = 4
        touch_point = step_data["click_point"]
        lift_point = step_data["click_point"]
        typed_text = ""
    elif action_type == 'scroll' and step_data["direction"] == 'down':
        action_type_new = 4
        touch_point = [0.5, 0.8]
        lift_point = [0.5, 0.2]
        typed_text = ""
    elif action_type == 'scroll' and step_data["direction"] == 'up':
        action_type_new = 4
        touch_point = [0.5, 0.2]
        lift_point = [0.5, 0.8]
        typed_text = ""
    elif action_type == 'scroll' and step_data["direction"] == 'left':
        action_type_new = 4
        touch_point = [0.2, 0.5]
        lift_point = [0.8, 0.5]
        typed_text = ""
    elif action_type == 'scroll' and step_data["direction"] == 'right':
        action_type_new = 4
        touch_point = [0.8, 0.5]
        lift_point = [0.2, 0.5]
        typed_text = ""
    elif action_type == 'type':
        action_type_new = 3
        touch_point = [-1.0, -1.0]
        lift_point = [-1.0, -1.0]
        typed_text = step_data["text"]
    elif action_type == 'navigate_back':
        action_type_new = 5
        touch_point = [-1.0, -1.0]
        lift_point = [-1.0, -1.0]
        typed_text = ""
    elif action_type == 'navigate_home':
        action_type_new = 6
        touch_point = [-1.0, -1.0]
        lift_point = [-1.0, -1.0]
        typed_text = ""
    else:
        action_type_new = action_type
        touch_point = [-1.0, -1.0]
        lift_point = [-1.0, -1.0]
        typed_text = ""
    # if action_type_new == 'type':
    #     typed_text = step_data["text"]

    action = {"action_type": action_type_new, "touch_point": touch_point, "lift_point": lift_point,
              "typed_text": typed_text}

    action["touch_point"] = [action["touch_point"][1], action["touch_point"][0]]
    action["lift_point"] = [action["lift_point"][1], action["lift_point"][0]]
    action["typed_text"] = action["typed_text"].lower()

    return action

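# Worked example of the conversion helpers above; note that both swap the two coordinates
# of each point before returning, and lower-case any typed text.
pred = {'action_type': 'click', 'click_point': [0.30, 0.60]}
print(pred_2_format_simplified(pred))
# {'action_type': 4, 'touch_point': [0.6, 0.3], 'lift_point': [0.6, 0.3], 'typed_text': ''}
gt = {'action_type_id': 3, 'action_type_text': 'type', 'type_text': 'Hello'}
print(action_2_format(gt))
# {'action_type': 3, 'touch_point': [-1.0, -1.0], 'lift_point': [-1.0, -1.0], 'typed_text': 'hello'}
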
@@ -1,45 +0,0 @@
'''
Adapted from https://github.com/google-research/google-research/tree/master/android_in_the_wild
'''

import enum

class ActionType(enum.IntEnum):

    # Placeholders for unused enum values
    UNUSED_0 = 0
    UNUSED_1 = 1
    UNUSED_2 = 2
    UNUSED_8 = 8
    UNUSED_9 = 9

    ########### Agent actions ###########

    # A type action that sends text to the emulator. Note that this simply sends
    # text and does not perform any clicks for element focus or enter presses for
    # submitting text.
    TYPE = 3

    # The dual point action used to represent all gestures.
    DUAL_POINT = 4

    # These actions differentiate pressing the home and back button from touches.
    # They represent explicit presses of back and home performed using ADB.
    PRESS_BACK = 5
    PRESS_HOME = 6

    # An action representing that ADB command for hitting enter was performed.
    PRESS_ENTER = 7

    ########### Episode status actions ###########

    # An action used to indicate the desired task has been completed and resets
    # the environment. This action should also be used in the case that the task
    # has already been completed and there is nothing to do.
    # e.g. The task is to turn on the Wi-Fi when it is already on
    STATUS_TASK_COMPLETE = 10

    # An action used to indicate that desired task is impossible to complete and
    # resets the environment. This can be a result of many different things
    # including UI changes, Android version differences, etc.
    STATUS_TASK_IMPOSSIBLE = 11
0  utils.py → util/utils.py  Executable file → Normal file