init r1
This commit is contained in:
59
computer_use_demo/gradio/agent/llm_utils/groqclient.py
Normal file
59
computer_use_demo/gradio/agent/llm_utils/groqclient.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
from groq import Groq
|
||||||
|
import os
|
||||||
|
from .utils import is_image_path
|
||||||
|
|
||||||
|
def run_groq_interleaved(messages: list, system: str, llm: str, api_key: str, max_tokens=256, temperature=0.6):
    """
    Run a chat completion through Groq's API, ignoring any images in the messages.

    Args:
        messages: a list of message dicts / plain strings, or a single prompt string.
        system: system prompt text; delivered as the first *user* turn because
            R1-style reasoning models should not receive a system role.
        llm: Groq model name to query (e.g. "deepseek-r1-distill-llama-70b").
        api_key: Groq API key; falls back to the GROQ_API_KEY environment variable.
        max_tokens: completion token budget (forwarded as max_completion_tokens).
        temperature: sampling temperature forwarded to the API.

    Returns:
        (final_answer, token_usage) on success, or (error_string, 0) on failure.

    Raises:
        ValueError: if no API key is available from either source.
    """
    api_key = api_key or os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise ValueError("GROQ_API_KEY is not set")

    client = Groq(api_key=api_key)

    # Avoid using system messages for R1: send the system prompt as the
    # first user turn instead.
    final_messages = [{"role": "user", "content": system}]

    if isinstance(messages, list):
        for item in messages:
            if isinstance(item, dict):
                # For dict items, concatenate all text content, ignoring images.
                text_contents = []
                for cnt in item["content"]:
                    if isinstance(cnt, str):
                        if not is_image_path(cnt):  # Skip image paths
                            text_contents.append(cnt)
                    else:
                        text_contents.append(str(cnt))

                if text_contents:  # Only add if there's text content
                    message = {"role": "user", "content": " ".join(text_contents)}
                    final_messages.append(message)
            else:  # str
                message = {"role": "user", "content": item}
                final_messages.append(message)

    elif isinstance(messages, str):
        final_messages.append({"role": "user", "content": messages})

    try:
        completion = client.chat.completions.create(
            # BUG FIX: honor the `llm` and `temperature` parameters instead of
            # hard-coding the model name and 0.6 — the previous code silently
            # ignored both arguments. The defaults preserve prior behavior.
            model=llm,
            messages=final_messages,
            temperature=temperature,
            max_completion_tokens=max_tokens,
            top_p=0.95,
            stream=False,
            reasoning_format="raw"
        )

        response = completion.choices[0].message.content
        # Strip the raw <think>...</think> reasoning block, keeping only the answer,
        # and remove any <output> wrapper tags the model may emit.
        final_answer = response.split('</think>\n')[-1] if '</think>' in response else response
        final_answer = final_answer.replace("<output>", "").replace("</output>", "")
        token_usage = completion.usage.total_tokens

        return final_answer, token_usage
    except Exception as e:
        # Best-effort: surface the error to the caller as the "answer" with
        # zero token usage rather than crashing the agent loop.
        print(f"Error in interleaved Groq: {e}")

        return str(e), 0
|
||||||
@@ -2,21 +2,9 @@ import os
|
|||||||
import logging
|
import logging
|
||||||
import base64
|
import base64
|
||||||
import requests
|
import requests
|
||||||
|
from .utils import is_image_path, encode_image
|
||||||
def is_image_path(text):
|
|
||||||
image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".tif")
|
|
||||||
if text.endswith(image_extensions):
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def encode_image(image_path):
|
|
||||||
"""Encode image file to base64."""
|
|
||||||
with open(image_path, "rb") as image_file:
|
|
||||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
|
||||||
|
|
||||||
def run_oai_interleaved(messages: list, system: str, llm: str, api_key: str, max_tokens=256, temperature=0):
|
def run_oai_interleaved(messages: list, system: str, llm: str, api_key: str, max_tokens=256, temperature=0):
|
||||||
|
|
||||||
api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
||||||
if not api_key:
|
if not api_key:
|
||||||
raise ValueError("OPENAI_API_KEY is not set")
|
raise ValueError("OPENAI_API_KEY is not set")
|
||||||
@@ -54,8 +42,6 @@ def run_oai_interleaved(messages: list, system: str, llm: str, api_key: str, max
|
|||||||
elif isinstance(messages, str):
|
elif isinstance(messages, str):
|
||||||
final_messages = [{"role": "user", "content": messages}]
|
final_messages = [{"role": "user", "content": messages}]
|
||||||
|
|
||||||
print("[oai] sending messages:", {"role": "user", "content": messages})
|
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": llm,
|
"model": llm,
|
||||||
"messages": final_messages,
|
"messages": final_messages,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import requests
|
|||||||
import base64
|
import base64
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tools.screen_capture import get_screenshot
|
from tools.screen_capture import get_screenshot
|
||||||
from agent.llm_utils.oai import encode_image
|
from agent.llm_utils.utils import encode_image
|
||||||
|
|
||||||
OUTPUT_DIR = "./tmp/outputs"
|
OUTPUT_DIR = "./tmp/outputs"
|
||||||
|
|
||||||
|
|||||||
13
computer_use_demo/gradio/agent/llm_utils/utils.py
Normal file
13
computer_use_demo/gradio/agent/llm_utils/utils.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
import base64
|
||||||
|
|
||||||
|
def is_image_path(text):
    """Return True if *text* ends with a common image-file extension.

    The check is case-insensitive, so paths like "photo.PNG" are also
    recognized (the original exact-case check missed uppercase extensions).
    """
    image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".tif")
    # str.endswith accepts a tuple of suffixes, so no explicit branch is needed.
    return text.lower().endswith(image_extensions)
|
||||||
|
|
||||||
|
def encode_image(image_path):
    """Encode image file to base64."""
    # Read the file fully first, then emit the payload as an ASCII-safe string.
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode("utf-8")
|
||||||
@@ -11,6 +11,7 @@ from anthropic.types import ToolResultBlockParam
|
|||||||
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam, BetaUsage
|
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam, BetaUsage
|
||||||
|
|
||||||
from agent.llm_utils.oai import run_oai_interleaved
|
from agent.llm_utils.oai import run_oai_interleaved
|
||||||
|
from agent.llm_utils.groqclient import run_groq_interleaved
|
||||||
import time
|
import time
|
||||||
import re
|
import re
|
||||||
|
|
||||||
@@ -39,9 +40,12 @@ class VLMAgent:
|
|||||||
):
|
):
|
||||||
if model == "omniparser + gpt-4o":
|
if model == "omniparser + gpt-4o":
|
||||||
self.model = "gpt-4o-2024-11-20"
|
self.model = "gpt-4o-2024-11-20"
|
||||||
|
elif model == "omniparser + R1":
|
||||||
|
self.model = "deepseek-r1-distill-llama-70b"
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Model {model} not supported")
|
raise ValueError(f"Model {model} not supported")
|
||||||
|
|
||||||
|
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.api_response_callback = api_response_callback
|
self.api_response_callback = api_response_callback
|
||||||
@@ -108,8 +112,17 @@ class VLMAgent:
|
|||||||
print(f"oai token usage: {token_usage}")
|
print(f"oai token usage: {token_usage}")
|
||||||
self.total_token_usage += token_usage
|
self.total_token_usage += token_usage
|
||||||
self.total_cost += (token_usage * 0.15 / 1000000) # https://openai.com/api/pricing/
|
self.total_cost += (token_usage * 0.15 / 1000000) # https://openai.com/api/pricing/
|
||||||
elif "phi" in self.model:
|
elif "r1" in self.model:
|
||||||
pass # TODO
|
vlm_response, token_usage = run_groq_interleaved(
|
||||||
|
messages=planner_messages,
|
||||||
|
system=system,
|
||||||
|
llm=self.model,
|
||||||
|
api_key=self.api_key,
|
||||||
|
max_tokens=self.max_tokens,
|
||||||
|
)
|
||||||
|
print(f"groq token usage: {token_usage}")
|
||||||
|
self.total_token_usage += token_usage
|
||||||
|
self.total_cost += (token_usage * 0.99 / 1000000)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Model {self.model} not supported")
|
raise ValueError(f"Model {self.model} not supported")
|
||||||
latency_vlm = time.time() - start
|
latency_vlm = time.time() - start
|
||||||
@@ -168,7 +181,7 @@ class VLMAgent:
|
|||||||
self.api_response_callback(response)
|
self.api_response_callback(response)
|
||||||
|
|
||||||
def _get_system_prompt(self, screen_info: str = ""):
|
def _get_system_prompt(self, screen_info: str = ""):
|
||||||
return f"""
|
main_section = f"""
|
||||||
You are using a Windows device.
|
You are using a Windows device.
|
||||||
You are able to use a mouse and keyboard to interact with the computer based on the given task and screenshot.
|
You are able to use a mouse and keyboard to interact with the computer based on the given task and screenshot.
|
||||||
You can only interact with the desktop GUI (no terminal or application menu access).
|
You can only interact with the desktop GUI (no terminal or application menu access).
|
||||||
@@ -222,11 +235,27 @@ Another Example:
|
|||||||
|
|
||||||
IMPORTANT NOTES:
|
IMPORTANT NOTES:
|
||||||
1. You should only give a single action at a time.
|
1. You should only give a single action at a time.
|
||||||
|
|
||||||
|
"""
|
||||||
|
thinking_model = "r1" in self.model
|
||||||
|
if not thinking_model:
|
||||||
|
main_section += """
|
||||||
2. You should give an analysis to the current screen, and reflect on what has been done by looking at the history, then describe your step-by-step thoughts on how to achieve the task.
|
2. You should give an analysis to the current screen, and reflect on what has been done by looking at the history, then describe your step-by-step thoughts on how to achieve the task.
|
||||||
|
|
||||||
|
"""
|
||||||
|
else:
|
||||||
|
main_section += """
|
||||||
|
2. In <think> XML tags give an analysis to the current screen, and reflect on what has been done by looking at the history, then describe your step-by-step thoughts on how to achieve the task. In <output> XML tags put the next action prediction JSON.
|
||||||
|
|
||||||
|
"""
|
||||||
|
main_section += """
|
||||||
3. Attach the next action prediction in the "Next Action".
|
3. Attach the next action prediction in the "Next Action".
|
||||||
4. You should not include other actions, such as keyboard shortcuts.
|
4. You should not include other actions, such as keyboard shortcuts.
|
||||||
5. When the task is completed, you should say "Next Action": "None" in the json field.
|
5. When the task is completed, you should say "Next Action": "None" in the json field.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
return main_section
|
||||||
|
|
||||||
def draw_action(self, vlm_response_json, image_base64):
|
def draw_action(self, vlm_response_json, image_base64):
|
||||||
# draw a circle using the coordinate in parsed_screen['som_image_base64']
|
# draw a circle using the coordinate in parsed_screen['som_image_base64']
|
||||||
image_data = base64.b64decode(image_base64)
|
image_data = base64.b64decode(image_base64)
|
||||||
|
|||||||
@@ -28,12 +28,13 @@ API_KEY_FILE = CONFIG_DIR / "api_key"
|
|||||||
INTRO_TEXT = '''
|
INTRO_TEXT = '''
|
||||||
🚀🤖✨ It's Play Time!
|
🚀🤖✨ It's Play Time!
|
||||||
|
|
||||||
Welcome to the OmniParser+X Computer Use Demo! X = [GPT-4o/4o-mini, Claude, Phi, Llama]. Let OmniParser turn your general purpose vision-langauge model to an AI agent.
|
Welcome to the OmniParser+X Computer Use Demo! X = [GPT-4o, R1, Claude]. Let OmniParser turn your general purpose vision-langauge model to an AI agent.
|
||||||
|
|
||||||
Type a message and press submit to start OmniParser+X. Press the trash icon in the chat to clear the message history.
|
Type a message and press submit to start OmniParser+X. Press the trash icon in the chat to clear the message history.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def parse_arguments():
|
def parse_arguments():
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Gradio App")
|
parser = argparse.ArgumentParser(description="Gradio App")
|
||||||
parser.add_argument("--windows_host_url", type=str, default='localhost:8006')
|
parser.add_argument("--windows_host_url", type=str, default='localhost:8006')
|
||||||
parser.add_argument("--omniparser_server_url", type=str, default="localhost:8000")
|
parser.add_argument("--omniparser_server_url", type=str, default="localhost:8000")
|
||||||
@@ -255,11 +256,10 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
|
|||||||
}
|
}
|
||||||
</style>
|
</style>
|
||||||
""")
|
""")
|
||||||
state = gr.State({}) # Use Gradio's state management
|
state = gr.State({})
|
||||||
|
|
||||||
setup_state(state.value) # Initialize the state
|
setup_state(state.value)
|
||||||
|
|
||||||
# Retrieve screen details
|
|
||||||
gr.Markdown("# OmniParser + ✖️ Demo")
|
gr.Markdown("# OmniParser + ✖️ Demo")
|
||||||
|
|
||||||
if not os.getenv("HIDE_WARNING", False):
|
if not os.getenv("HIDE_WARNING", False):
|
||||||
@@ -270,8 +270,8 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
|
|||||||
with gr.Column():
|
with gr.Column():
|
||||||
model = gr.Dropdown(
|
model = gr.Dropdown(
|
||||||
label="Model",
|
label="Model",
|
||||||
choices=["omniparser + gpt-4o", "omniparser + phi35v", "claude-3-5-sonnet-20241022"],
|
choices=["omniparser + gpt-4o", "omniparser + R1", "claude-3-5-sonnet-20241022"],
|
||||||
value="omniparser + gpt-4o", # Set to one of the choices
|
value="omniparser + gpt-4o",
|
||||||
interactive=True,
|
interactive=True,
|
||||||
)
|
)
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
@@ -322,11 +322,14 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
|
|||||||
|
|
||||||
if model_selection == "claude-3-5-sonnet-20241022":
|
if model_selection == "claude-3-5-sonnet-20241022":
|
||||||
provider_choices = [option.value for option in APIProvider if option.value != "openai"]
|
provider_choices = [option.value for option in APIProvider if option.value != "openai"]
|
||||||
elif model_selection == "omniparser + gpt-4o" or model_selection == "omniparser + phi35v":
|
elif model_selection == "omniparser + gpt-4o":
|
||||||
provider_choices = ["openai"]
|
provider_choices = ["openai"]
|
||||||
|
elif model_selection == "omniparser + R1":
|
||||||
|
provider_choices = ["groq"]
|
||||||
else:
|
else:
|
||||||
provider_choices = [option.value for option in APIProvider]
|
provider_choices = [option.value for option in APIProvider]
|
||||||
default_provider_value = provider_choices[0]
|
default_provider_value = provider_choices[0]
|
||||||
|
|
||||||
provider_interactive = len(provider_choices) > 1
|
provider_interactive = len(provider_choices) > 1
|
||||||
api_key_placeholder = f"{default_provider_value.title()} API Key"
|
api_key_placeholder = f"{default_provider_value.title()} API Key"
|
||||||
|
|
||||||
|
|||||||
@@ -29,3 +29,4 @@ google-auth<3,>=2
|
|||||||
screeninfo
|
screeninfo
|
||||||
uiautomation
|
uiautomation
|
||||||
dashscope
|
dashscope
|
||||||
|
groq
|
||||||
Reference in New Issue
Block a user