This commit is contained in:
Thomas Dhome Casanova (from Dev Box)
2025-01-31 18:37:14 -08:00
parent ed7b34621b
commit e268184f8d
7 changed files with 119 additions and 28 deletions

View File

@@ -0,0 +1,59 @@
from groq import Groq
import os
from .utils import is_image_path
def run_groq_interleaved(messages: list, system: str, llm: str, api_key: str, max_tokens=256, temperature=0.6):
    """
    Run a chat completion through Groq's API, ignoring any images in the messages.

    Args:
        messages: A string (one user message) or a list of items. Dict items
            must have a "content" iterable whose string entries are joined
            with spaces (entries that look like image paths are skipped);
            non-dict items are sent as-is after str() coercion.
        system: System instructions. Sent as the first *user* message because
            R1-style reasoning models should not receive system messages.
        llm: Groq model name to run the completion with.
        api_key: Groq API key; falls back to the GROQ_API_KEY env var.
        max_tokens: Cap on completion tokens (max_completion_tokens).
        temperature: Sampling temperature.

    Returns:
        Tuple of (final_answer, total_token_usage). On API failure, returns
        (str(error), 0) after printing the error (best-effort behavior kept).

    Raises:
        ValueError: if no API key is provided and GROQ_API_KEY is unset.
    """
    api_key = api_key or os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise ValueError("GROQ_API_KEY is not set")
    client = Groq(api_key=api_key)

    # avoid using system messages for R1
    final_messages = [{"role": "user", "content": system}]

    if isinstance(messages, list):
        for item in messages:
            if isinstance(item, dict):
                # For dict items, concatenate all text content, ignoring images
                text_contents = []
                for cnt in item["content"]:
                    if isinstance(cnt, str):
                        if not is_image_path(cnt):  # Skip image paths
                            text_contents.append(cnt)
                    else:
                        text_contents.append(str(cnt))
                if text_contents:  # Only add if there's text content
                    final_messages.append({"role": "user", "content": " ".join(text_contents)})
            else:  # str
                final_messages.append({"role": "user", "content": item})
    elif isinstance(messages, str):
        final_messages.append({"role": "user", "content": messages})

    try:
        completion = client.chat.completions.create(
            # BUG FIX: model and temperature were hard-coded
            # ("deepseek-r1-distill-llama-70b" and 0.6), silently ignoring the
            # `llm` and `temperature` parameters. Pass them through instead;
            # backward-compatible since the caller passes that same model name.
            model=llm,
            messages=final_messages,
            temperature=temperature,
            max_completion_tokens=max_tokens,
            top_p=0.95,
            stream=False,
            reasoning_format="raw"
        )
        response = completion.choices[0].message.content
        # Drop the <think>...</think> reasoning prefix, keep only the answer,
        # and strip the <output> wrapper tags the prompt asks the model for.
        final_answer = response.split('</think>\n')[-1] if '</think>' in response else response
        final_answer = final_answer.replace("<output>", "").replace("</output>", "")
        token_usage = completion.usage.total_tokens
        return final_answer, token_usage
    except Exception as e:
        print(f"Error in interleaved Groq: {e}")
        return str(e), 0

View File

@@ -2,21 +2,9 @@ import os
import logging import logging
import base64 import base64
import requests import requests
from .utils import is_image_path, encode_image
def is_image_path(text):
image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".tif")
if text.endswith(image_extensions):
return True
else:
return False
def encode_image(image_path):
"""Encode image file to base64."""
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
def run_oai_interleaved(messages: list, system: str, llm: str, api_key: str, max_tokens=256, temperature=0): def run_oai_interleaved(messages: list, system: str, llm: str, api_key: str, max_tokens=256, temperature=0):
api_key = api_key or os.environ.get("OPENAI_API_KEY") api_key = api_key or os.environ.get("OPENAI_API_KEY")
if not api_key: if not api_key:
raise ValueError("OPENAI_API_KEY is not set") raise ValueError("OPENAI_API_KEY is not set")
@@ -54,8 +42,6 @@ def run_oai_interleaved(messages: list, system: str, llm: str, api_key: str, max
elif isinstance(messages, str): elif isinstance(messages, str):
final_messages = [{"role": "user", "content": messages}] final_messages = [{"role": "user", "content": messages}]
print("[oai] sending messages:", {"role": "user", "content": messages})
payload = { payload = {
"model": llm, "model": llm,
"messages": final_messages, "messages": final_messages,

View File

@@ -2,7 +2,7 @@ import requests
import base64 import base64
from pathlib import Path from pathlib import Path
from tools.screen_capture import get_screenshot from tools.screen_capture import get_screenshot
from agent.llm_utils.oai import encode_image from agent.llm_utils.utils import encode_image
OUTPUT_DIR = "./tmp/outputs" OUTPUT_DIR = "./tmp/outputs"

View File

@@ -0,0 +1,13 @@
import base64
def is_image_path(text):
    """Return True if *text* ends with a common image-file extension.

    The comparison is case-sensitive (".PNG" does not match), preserving
    the original behavior.
    """
    image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".tif")
    # str.endswith accepts a tuple of suffixes and already returns a bool,
    # so the explicit if/else returning True/False was redundant.
    return text.endswith(image_extensions)
def encode_image(image_path):
    """Read the file at *image_path* and return its bytes as a base64 UTF-8 string."""
    with open(image_path, "rb") as fh:
        raw_bytes = fh.read()
    return base64.b64encode(raw_bytes).decode("utf-8")

View File

@@ -11,6 +11,7 @@ from anthropic.types import ToolResultBlockParam
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam, BetaUsage from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam, BetaUsage
from agent.llm_utils.oai import run_oai_interleaved from agent.llm_utils.oai import run_oai_interleaved
from agent.llm_utils.groqclient import run_groq_interleaved
import time import time
import re import re
@@ -39,9 +40,12 @@ class VLMAgent:
): ):
if model == "omniparser + gpt-4o": if model == "omniparser + gpt-4o":
self.model = "gpt-4o-2024-11-20" self.model = "gpt-4o-2024-11-20"
elif model == "omniparser + R1":
self.model = "deepseek-r1-distill-llama-70b"
else: else:
raise ValueError(f"Model {model} not supported") raise ValueError(f"Model {model} not supported")
self.provider = provider self.provider = provider
self.api_key = api_key self.api_key = api_key
self.api_response_callback = api_response_callback self.api_response_callback = api_response_callback
@@ -108,8 +112,17 @@ class VLMAgent:
print(f"oai token usage: {token_usage}") print(f"oai token usage: {token_usage}")
self.total_token_usage += token_usage self.total_token_usage += token_usage
self.total_cost += (token_usage * 0.15 / 1000000) # https://openai.com/api/pricing/ self.total_cost += (token_usage * 0.15 / 1000000) # https://openai.com/api/pricing/
elif "phi" in self.model: elif "r1" in self.model:
pass # TODO vlm_response, token_usage = run_groq_interleaved(
messages=planner_messages,
system=system,
llm=self.model,
api_key=self.api_key,
max_tokens=self.max_tokens,
)
print(f"groq token usage: {token_usage}")
self.total_token_usage += token_usage
self.total_cost += (token_usage * 0.99 / 1000000)
else: else:
raise ValueError(f"Model {self.model} not supported") raise ValueError(f"Model {self.model} not supported")
latency_vlm = time.time() - start latency_vlm = time.time() - start
@@ -168,7 +181,7 @@ class VLMAgent:
self.api_response_callback(response) self.api_response_callback(response)
def _get_system_prompt(self, screen_info: str = ""): def _get_system_prompt(self, screen_info: str = ""):
return f""" main_section = f"""
You are using a Windows device. You are using a Windows device.
You are able to use a mouse and keyboard to interact with the computer based on the given task and screenshot. You are able to use a mouse and keyboard to interact with the computer based on the given task and screenshot.
You can only interact with the desktop GUI (no terminal or application menu access). You can only interact with the desktop GUI (no terminal or application menu access).
@@ -222,11 +235,27 @@ Another Example:
IMPORTANT NOTES: IMPORTANT NOTES:
1. You should only give a single action at a time. 1. You should only give a single action at a time.
"""
thinking_model = "r1" in self.model
if not thinking_model:
main_section += """
2. You should give an analysis to the current screen, and reflect on what has been done by looking at the history, then describe your step-by-step thoughts on how to achieve the task. 2. You should give an analysis to the current screen, and reflect on what has been done by looking at the history, then describe your step-by-step thoughts on how to achieve the task.
"""
else:
main_section += """
2. In <think> XML tags give an analysis to the current screen, and reflect on what has been done by looking at the history, then describe your step-by-step thoughts on how to achieve the task. In <output> XML tags put the next action prediction JSON.
"""
main_section += """
3. Attach the next action prediction in the "Next Action". 3. Attach the next action prediction in the "Next Action".
4. You should not include other actions, such as keyboard shortcuts. 4. You should not include other actions, such as keyboard shortcuts.
5. When the task is completed, you should say "Next Action": "None" in the json field. 5. When the task is completed, you should say "Next Action": "None" in the json field.
""" """
return main_section
def draw_action(self, vlm_response_json, image_base64): def draw_action(self, vlm_response_json, image_base64):
# draw a circle using the coordinate in parsed_screen['som_image_base64'] # draw a circle using the coordinate in parsed_screen['som_image_base64']
image_data = base64.b64decode(image_base64) image_data = base64.b64decode(image_base64)

View File

@@ -28,12 +28,13 @@ API_KEY_FILE = CONFIG_DIR / "api_key"
INTRO_TEXT = ''' INTRO_TEXT = '''
🚀🤖✨ It's Play Time! 🚀🤖✨ It's Play Time!
Welcome to the OmniParser+X Computer Use Demo! X = [GPT-4o/4o-mini, Claude, Phi, Llama]. Let OmniParser turn your general purpose vision-language model to an AI agent. Welcome to the OmniParser+X Computer Use Demo! X = [GPT-4o, R1, Claude]. Let OmniParser turn your general purpose vision-language model to an AI agent.
Type a message and press submit to start OmniParser+X. Press the trash icon in the chat to clear the message history. Type a message and press submit to start OmniParser+X. Press the trash icon in the chat to clear the message history.
''' '''
def parse_arguments(): def parse_arguments():
parser = argparse.ArgumentParser(description="Gradio App") parser = argparse.ArgumentParser(description="Gradio App")
parser.add_argument("--windows_host_url", type=str, default='localhost:8006') parser.add_argument("--windows_host_url", type=str, default='localhost:8006')
parser.add_argument("--omniparser_server_url", type=str, default="localhost:8000") parser.add_argument("--omniparser_server_url", type=str, default="localhost:8000")
@@ -255,11 +256,10 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
} }
</style> </style>
""") """)
state = gr.State({}) # Use Gradio's state management state = gr.State({})
setup_state(state.value) # Initialize the state setup_state(state.value)
# Retrieve screen details
gr.Markdown("# OmniParser + ✖️ Demo") gr.Markdown("# OmniParser + ✖️ Demo")
if not os.getenv("HIDE_WARNING", False): if not os.getenv("HIDE_WARNING", False):
@@ -270,8 +270,8 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
with gr.Column(): with gr.Column():
model = gr.Dropdown( model = gr.Dropdown(
label="Model", label="Model",
choices=["omniparser + gpt-4o", "omniparser + phi35v", "claude-3-5-sonnet-20241022"], choices=["omniparser + gpt-4o", "omniparser + R1", "claude-3-5-sonnet-20241022"],
value="omniparser + gpt-4o", # Set to one of the choices value="omniparser + gpt-4o",
interactive=True, interactive=True,
) )
with gr.Column(): with gr.Column():
@@ -322,11 +322,14 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
if model_selection == "claude-3-5-sonnet-20241022": if model_selection == "claude-3-5-sonnet-20241022":
provider_choices = [option.value for option in APIProvider if option.value != "openai"] provider_choices = [option.value for option in APIProvider if option.value != "openai"]
elif model_selection == "omniparser + gpt-4o" or model_selection == "omniparser + phi35v": elif model_selection == "omniparser + gpt-4o":
provider_choices = ["openai"] provider_choices = ["openai"]
elif model_selection == "omniparser + R1":
provider_choices = ["groq"]
else: else:
provider_choices = [option.value for option in APIProvider] provider_choices = [option.value for option in APIProvider]
default_provider_value = provider_choices[0] default_provider_value = provider_choices[0]
provider_interactive = len(provider_choices) > 1 provider_interactive = len(provider_choices) > 1
api_key_placeholder = f"{default_provider_value.title()} API Key" api_key_placeholder = f"{default_provider_value.title()} API Key"

View File

@@ -28,4 +28,5 @@ boto3>=1.28.57
google-auth<3,>=2 google-auth<3,>=2
screeninfo screeninfo
uiautomation uiautomation
dashscope dashscope
groq