Add a drop-down (collapsible details) for OmniParser output in the chat; add command-line arguments for app.py.

This commit is contained in:
yadonglu
2025-01-28 21:33:58 -08:00
parent 16570a9bf3
commit 7ea2239e10
10 changed files with 86 additions and 29 deletions

1
.gitignore vendored
View File

@@ -6,3 +6,4 @@ weights/icon_detect_v1_5_2/
.gradio .gradio
__pycache__/ __pycache__/
debug.ipynb debug.ipynb
util/__pycache__/

View File

@@ -1,5 +1,6 @@
""" """
Entrypoint for Gradio, see https://gradio.app/ Entrypoint for Gradio, see https://gradio.app/
python app.py --windows_host_url xxxx:8006/ --omniparser_host_url localhost:8000
""" """
import platform import platform
@@ -15,6 +16,7 @@ from pathlib import Path
from typing import cast, Dict from typing import cast, Dict
from PIL import Image from PIL import Image
import socket import socket
import argparse
import gradio as gr import gradio as gr
from anthropic import APIResponse from anthropic import APIResponse
@@ -39,6 +41,18 @@ Welcome to the OmniParser+X Demo! X = [GPT-4o/4o-mini, Claude, Phi, Llama]. Let
Type a message and press submit to start OmniParser+X. Press the trash icon in the chat to clear the message history. Type a message and press submit to start OmniParser+X. Press the trash icon in the chat to clear the message history.
''' '''
def parse_arguments():
parser = argparse.ArgumentParser(description="Gradio App")
parser.add_argument("--windows_host_url", type=str, default='GCRSANDBOX336.redmond.corp.microsoft.com:8006/') # http://gcrsandbox336.redmond.corp.microsoft.com/
parser.add_argument("--omniparser_host_url", type=str, default="localhost:8000")
return parser.parse_args()
args = parse_arguments()
windows_host_url = args.windows_host_url
omniparser_host_url = args.omniparser_host_url
print(f"Windows host URL: {windows_host_url}")
print(f"OmniParser host URL: {omniparser_host_url}")
class Sender(StrEnum): class Sender(StrEnum):
USER = "user" USER = "user"
BOT = "assistant" BOT = "assistant"
@@ -68,8 +82,8 @@ def setup_state(state):
state["only_n_most_recent_images"] = 2 state["only_n_most_recent_images"] = 2
if 'chatbot_messages' not in state: if 'chatbot_messages' not in state:
state['chatbot_messages'] = [] state['chatbot_messages'] = []
if "omniparser_url" not in state: # if "omniparser_url" not in state:
state["omniparser_url"] = "localhost:8000" # state["omniparser_url"] = "localhost:8000"
async def main(state): async def main(state):
"""Render loop for Gradio""" """Render loop for Gradio"""
@@ -211,7 +225,7 @@ def process_input(user_input, state):
api_response_callback=partial(_api_response_callback, response_state=state["responses"]), api_response_callback=partial(_api_response_callback, response_state=state["responses"]),
api_key=state["api_key"], api_key=state["api_key"],
only_n_most_recent_images=state["only_n_most_recent_images"], only_n_most_recent_images=state["only_n_most_recent_images"],
omniparser_url=state["omniparser_url"] omniparser_url=omniparser_host_url #state["omniparser_url"]
): ):
if loop_msg is None: if loop_msg is None:
yield state['chatbot_messages'] yield state['chatbot_messages']
@@ -275,13 +289,13 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
placeholder="Paste your API key here", placeholder="Paste your API key here",
interactive=True, interactive=True,
) )
with gr.Row(): # with gr.Row():
omniparser_url = gr.Textbox( # omniparser_url = gr.Textbox(
label="OmniParser Base URL", # label="OmniParser Base URL",
value="localhost:8000", # value="localhost:8000",
placeholder="Enter OmniParser base URL (e.g. localhost:8000)", # placeholder="Enter OmniParser base URL (e.g. localhost:8000)",
interactive=True # interactive=True
) # )
# hide_images = gr.Checkbox(label="Hide screenshots", value=False) # hide_images = gr.Checkbox(label="Hide screenshots", value=False)
with gr.Row(): with gr.Row():
@@ -294,11 +308,20 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
with gr.Column(scale=1): with gr.Column(scale=1):
chatbot = gr.Chatbot(label="Chatbot History", autoscroll=True, height=580) chatbot = gr.Chatbot(label="Chatbot History", autoscroll=True, height=580)
with gr.Column(scale=3): with gr.Column(scale=3):
if not windows_host_url:
iframe = gr.HTML( iframe = gr.HTML(
f'<iframe src="http://localhost:8006/vnc.html?view_only=1&autoconnect=1&resize=scale" width="100%" height="580" allow="fullscreen"></iframe>', f'<iframe src="http://localhost:8006/vnc.html?view_only=1&autoconnect=1&resize=scale" width="100%" height="580" allow="fullscreen"></iframe>',
container=False, container=False,
elem_classes="no-padding" elem_classes="no-padding"
) )
else:
# machine_fqdn = socket.getfqdn()
# print('machine_fqdn:', machine_fqdn)
iframe = gr.HTML(
f'<iframe src="http://{windows_host_url}/vnc.html?view_only=1&autoconnect=1&resize=scale" width="100%" height="580" allow="fullscreen"></iframe>',
container=False,
elem_classes="no-padding"
)
def update_model(model_selection, state): def update_model(model_selection, state):
state["model"] = model_selection state["model"] = model_selection
@@ -350,8 +373,8 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
state["api_key"] = api_key_value state["api_key"] = api_key_value
state[f'{state["provider"]}_api_key'] = api_key_value state[f'{state["provider"]}_api_key'] = api_key_value
def update_omniparser_url(url_value, state): # def update_omniparser_url(url_value, state):
state["omniparser_url"] = url_value # state["omniparser_url"] = url_value
def clear_chat(state): def clear_chat(state):
# Reset message-related state # Reset message-related state
@@ -365,7 +388,7 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
only_n_images.change(fn=update_only_n_images, inputs=[only_n_images, state], outputs=None) only_n_images.change(fn=update_only_n_images, inputs=[only_n_images, state], outputs=None)
provider.change(fn=update_provider, inputs=[provider, state], outputs=api_key) provider.change(fn=update_provider, inputs=[provider, state], outputs=api_key)
api_key.change(fn=update_api_key, inputs=[api_key, state], outputs=None) api_key.change(fn=update_api_key, inputs=[api_key, state], outputs=None)
omniparser_url.change(fn=update_omniparser_url, inputs=[omniparser_url, state], outputs=None) # omniparser_url.change(fn=update_omniparser_url, inputs=[omniparser_url, state], outputs=None)
chatbot.clear(fn=clear_chat, inputs=[state], outputs=[chatbot]) chatbot.clear(fn=clear_chat, inputs=[state], outputs=[chatbot])
submit_button.click(process_input, [chat_input, state], chatbot) submit_button.click(process_input, [chat_input, state], chatbot)

View File

@@ -52,7 +52,7 @@ def sampling_loop_sync(
Synchronous agentic sampling loop for the assistant/tool interaction of computer use. Synchronous agentic sampling loop for the assistant/tool interaction of computer use.
""" """
print('in sampling_loop_sync, model:', model) print('in sampling_loop_sync, model:', model)
omniparser = OmniParser(url=f"http://{omniparser_url}/send_text/" if omniparser_url and omniparser_url != "localhost:8000" else None) omniparser = OmniParser(url=f"http://{omniparser_url}/send_text/" if omniparser_url else None)
if model == "claude-3-5-sonnet-20241022": if model == "claude-3-5-sonnet-20241022":
# Register Actor and Executor # Register Actor and Executor
actor = AnthropicActor( actor = AnthropicActor(

View File

@@ -132,7 +132,15 @@ class VLMAgent:
sender="bot") sender="bot")
self.output_callback(f'Set of Marks Screenshot for {colorful_text_vlm}:\n<img src="data:image/png;base64,{parsed_screen["som_image_base64"]}">', sender="bot") self.output_callback(f'Set of Marks Screenshot for {colorful_text_vlm}:\n<img src="data:image/png;base64,{parsed_screen["som_image_base64"]}">', sender="bot")
screen_info = str(parsed_screen['screen_info']) screen_info = str(parsed_screen['screen_info'])
self.output_callback(f'Screen Info for {colorful_text_vlm}:\n{screen_info}', sender="bot") # self.output_callback(f'Screen Info for {colorful_text_vlm}:\n{screen_info}', sender="bot")
self.output_callback(
f'<details>'
f' <summary>Screen Info for {colorful_text_vlm}</summary>'
f' <pre>{screen_info}</pre>'
f'</details>',
sender="bot"
)
screenshot_uuid = parsed_screen['screenshot_uuid'] screenshot_uuid = parsed_screen['screenshot_uuid']
screen_width, screen_height = parsed_screen['width'], parsed_screen['height'] screen_width, screen_height = parsed_screen['width'], parsed_screen['height']
@@ -155,7 +163,7 @@ class VLMAgent:
planner_messages[-1]["content"].append(f"{OUTPUT_DIR}/screenshot_{screenshot_uuid}.png") planner_messages[-1]["content"].append(f"{OUTPUT_DIR}/screenshot_{screenshot_uuid}.png")
planner_messages[-1]["content"].append(f"{OUTPUT_DIR}/screenshot_som_{screenshot_uuid}.png") planner_messages[-1]["content"].append(f"{OUTPUT_DIR}/screenshot_som_{screenshot_uuid}.png")
print(f"Sending messages to VLMPlanner : {planner_messages}") # print(f"Sending messages to VLMPlanner : {planner_messages}")
start = time.time() start = time.time()
if "gpt" in self.model: if "gpt" in self.model:
vlm_response, token_usage = run_oai_interleaved( vlm_response, token_usage = run_oai_interleaved(

View File

@@ -1,4 +1,7 @@
# uvicorn remote_request:app --host 0.0.0.0 --port 8000 --reload # uvicorn remote_request:app --host 0.0.0.0 --port 8000 --reload
'''
python -m remote_request --som_model_path ../weights/icon_detect_v1_5/model_v1_5.pt --caption_model_name florence2 --caption_model_path ../weights/icon_detect_v1_5/model_v1_5.pt --device cuda --BOX_TRESHOLD 0.05
'''
import sys import sys
import os import os
@@ -12,14 +15,31 @@ import base64
import io import io
from fastapi import FastAPI from fastapi import FastAPI
from pydantic import BaseModel from pydantic import BaseModel
import argparse
config = { def parse_arguments():
'som_model_path': '../weights/icon_detect_v1_5/model_v1_5.pt', parser = argparse.ArgumentParser(description='Omniparser API')
'device': 'cpu', parser.add_argument('--som_model_path', type=str, default='../weights/icon_detect_v1_5/model_v1_5.pt', help='Path to the som model')
'caption_model_name': 'florence2', parser.add_argument('--caption_model_name', type=str, default='florence2', help='Name of the caption model')
'caption_model_path': '../weights/icon_caption_florence', parser.add_argument('--caption_model_path', type=str, default='../weights/icon_caption_florence', help='Path to the caption model')
'BOX_TRESHOLD': 0.05 parser.add_argument('--device', type=str, default='cpu', help='Device to run the model')
} parser.add_argument('--BOX_TRESHOLD', type=float, default=0.05, help='Threshold for box detection')
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host for the API')
parser.add_argument('--port', type=int, default=8000, help='Port for the API')
args = parser.parse_args()
return args
args = parse_arguments()
config = vars(args)
# config = {
# 'som_model_path': '../weights/icon_detect_v1_5/model_v1_5.pt',
# 'device': 'cpu',
# 'caption_model_name': 'florence2',
# 'caption_model_path': '../weights/icon_caption_florence',
# 'BOX_TRESHOLD': 0.05
# }
class Omniparser(object): class Omniparser(object):
@@ -75,3 +95,8 @@ async def send_text(item: Item):
@app.get("/") @app.get("/")
async def root(): async def root():
return {"message": "Omniparser API ready"} return {"message": "Omniparser API ready"}
if __name__ == "__main__":
import uvicorn
uvicorn.run("remote_request:app", host=args.host, port=args.port, reload=True)