From e268184f8d9d84631cc96556ff4ffbc8c6a2580e Mon Sep 17 00:00:00 2001 From: "Thomas Dhome Casanova (from Dev Box)" Date: Fri, 31 Jan 2025 18:37:14 -0800 Subject: [PATCH] init r1 --- .../gradio/agent/llm_utils/groqclient.py | 59 +++++++++++++++++++ .../gradio/agent/llm_utils/oai.py | 16 +---- .../agent/llm_utils/omniparserclient.py | 2 +- .../gradio/agent/llm_utils/utils.py | 13 ++++ computer_use_demo/gradio/agent/vlm_agent.py | 35 ++++++++++- computer_use_demo/gradio/app.py | 19 +++--- requirements.txt | 3 +- 7 files changed, 119 insertions(+), 28 deletions(-) create mode 100644 computer_use_demo/gradio/agent/llm_utils/groqclient.py create mode 100644 computer_use_demo/gradio/agent/llm_utils/utils.py diff --git a/computer_use_demo/gradio/agent/llm_utils/groqclient.py b/computer_use_demo/gradio/agent/llm_utils/groqclient.py new file mode 100644 index 0000000..812d929 --- /dev/null +++ b/computer_use_demo/gradio/agent/llm_utils/groqclient.py @@ -0,0 +1,59 @@ +from groq import Groq +import os +from .utils import is_image_path + +def run_groq_interleaved(messages: list, system: str, llm: str, api_key: str, max_tokens=256, temperature=0.6): + """ + Run a chat completion through Groq's API, ignoring any images in the messages. + """ + api_key = api_key or os.environ.get("GROQ_API_KEY") + if not api_key: + raise ValueError("GROQ_API_KEY is not set") + + client = Groq(api_key=api_key) + # avoid using system messages for R1 + final_messages = [{"role": "user", "content": system}] + + if isinstance(messages, list): + for item in messages: + if isinstance(item, dict): + # For dict items, concatenate all text content, ignoring images + text_contents = [] + for cnt in item["content"]: + if isinstance(cnt, str): + if not is_image_path(cnt): # Skip image paths + text_contents.append(cnt) + else: + text_contents.append(str(cnt)) + + if text_contents: # Only add if there's text content + message = {"role": "user", "content": " ".join(text_contents)} + final_messages.append(message) + else: # str + message = {"role": "user", "content": item} + final_messages.append(message) + + elif isinstance(messages, str): + final_messages.append({"role": "user", "content": messages}) + + try: + completion = client.chat.completions.create( + model="deepseek-r1-distill-llama-70b", + messages=final_messages, + temperature=0.6, + max_completion_tokens=max_tokens, + top_p=0.95, + stream=False, + reasoning_format="raw" + ) + + response = completion.choices[0].message.content + final_answer = response.split('\n')[-1] if '' in response else response + final_answer = final_answer.replace("", "").replace("", "") + token_usage = completion.usage.total_tokens + + return final_answer, token_usage + except Exception as e: + print(f"Error in interleaved Groq: {e}") + + return str(e), 0 \ No newline at end of file diff --git a/computer_use_demo/gradio/agent/llm_utils/oai.py b/computer_use_demo/gradio/agent/llm_utils/oai.py index 2d613f1..e2daba7 100644 --- a/computer_use_demo/gradio/agent/llm_utils/oai.py +++ b/computer_use_demo/gradio/agent/llm_utils/oai.py @@ -2,21 +2,9 @@ import os import logging import base64 import requests - -def is_image_path(text): - image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".tif") - if text.endswith(image_extensions): - return True - else: - return False - -def encode_image(image_path): - """Encode image file to base64.""" - with open(image_path, "rb") as image_file: - return base64.b64encode(image_file.read()).decode("utf-8") +from .utils import is_image_path, encode_image def run_oai_interleaved(messages: list, system: str, llm: str, api_key: str, max_tokens=256, temperature=0): - api_key = api_key or os.environ.get("OPENAI_API_KEY") if not api_key: raise ValueError("OPENAI_API_KEY is not set") @@ -54,8 +42,6 @@ def run_oai_interleaved(messages: list, system: str, llm: str, api_key: str, max elif isinstance(messages, str): final_messages = [{"role": "user", "content": messages}] - print("[oai] sending messages:", {"role": "user", "content": messages}) - payload = { "model": llm, "messages": final_messages, diff --git a/computer_use_demo/gradio/agent/llm_utils/omniparserclient.py b/computer_use_demo/gradio/agent/llm_utils/omniparserclient.py index c29eb58..e90ddef 100644 --- a/computer_use_demo/gradio/agent/llm_utils/omniparserclient.py +++ b/computer_use_demo/gradio/agent/llm_utils/omniparserclient.py @@ -2,7 +2,7 @@ import requests import base64 from pathlib import Path from tools.screen_capture import get_screenshot -from agent.llm_utils.oai import encode_image +from agent.llm_utils.utils import encode_image OUTPUT_DIR = "./tmp/outputs" diff --git a/computer_use_demo/gradio/agent/llm_utils/utils.py b/computer_use_demo/gradio/agent/llm_utils/utils.py new file mode 100644 index 0000000..12ab36e --- /dev/null +++ b/computer_use_demo/gradio/agent/llm_utils/utils.py @@ -0,0 +1,13 @@ +import base64 + +def is_image_path(text): + image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".tif") + if text.endswith(image_extensions): + return True + else: + return False + +def encode_image(image_path): + """Encode image file to base64.""" + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") \ No newline at end of file diff --git a/computer_use_demo/gradio/agent/vlm_agent.py b/computer_use_demo/gradio/agent/vlm_agent.py index 76815a6..ece91c0 100644 --- a/computer_use_demo/gradio/agent/vlm_agent.py +++ b/computer_use_demo/gradio/agent/vlm_agent.py @@ -11,6 +11,7 @@ from anthropic.types import ToolResultBlockParam from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam, BetaUsage from agent.llm_utils.oai import run_oai_interleaved +from agent.llm_utils.groqclient import run_groq_interleaved import time import re @@ -39,9 +40,12 @@ class VLMAgent: ): if model == "omniparser + gpt-4o": self.model = "gpt-4o-2024-11-20" + elif model == "omniparser + R1": + self.model = "deepseek-r1-distill-llama-70b" else: raise ValueError(f"Model {model} not supported") + self.provider = provider self.api_key = api_key self.api_response_callback = api_response_callback @@ -108,8 +112,17 @@ class VLMAgent: print(f"oai token usage: {token_usage}") self.total_token_usage += token_usage self.total_cost += (token_usage * 0.15 / 1000000) # https://openai.com/api/pricing/ - elif "phi" in self.model: - pass # TODO + elif "r1" in self.model: + vlm_response, token_usage = run_groq_interleaved( + messages=planner_messages, + system=system, + llm=self.model, + api_key=self.api_key, + max_tokens=self.max_tokens, + ) + print(f"groq token usage: {token_usage}") + self.total_token_usage += token_usage + self.total_cost += (token_usage * 0.99 / 1000000) else: raise ValueError(f"Model {self.model} not supported") latency_vlm = time.time() - start @@ -168,7 +181,7 @@ class VLMAgent: self.api_response_callback(response) def _get_system_prompt(self, screen_info: str = ""): - return f""" + main_section = f""" You are using a Windows device. You are able to use a mouse and keyboard to interact with the computer based on the given task and screenshot. You can only interact with the desktop GUI (no terminal or application menu access). @@ -222,11 +235,27 @@ Another Example: IMPORTANT NOTES: 1. You should only give a single action at a time. + +""" + thinking_model = "r1" in self.model + if not thinking_model: + main_section += """ 2. You should give an analysis to the current screen, and reflect on what has been done by looking at the history, then describe your step-by-step thoughts on how to achieve the task. + +""" + else: + main_section += """ +2. In XML tags give an analysis to the current screen, and reflect on what has been done by looking at the history, then describe your step-by-step thoughts on how to achieve the task. In XML tags put the next action prediction JSON. + +""" + main_section += """ 3. Attach the next action prediction in the "Next Action". 4. You should not include other actions, such as keyboard shortcuts. 5. When the task is completed, you should say "Next Action": "None" in the json field. """ + + return main_section + def draw_action(self, vlm_response_json, image_base64): # draw a circle using the coordinate in parsed_screen['som_image_base64'] image_data = base64.b64decode(image_base64) diff --git a/computer_use_demo/gradio/app.py b/computer_use_demo/gradio/app.py index 97cc48a..7ca2a22 100644 --- a/computer_use_demo/gradio/app.py +++ b/computer_use_demo/gradio/app.py @@ -28,12 +28,13 @@ API_KEY_FILE = CONFIG_DIR / "api_key" INTRO_TEXT = ''' 🚀🤖✨ It's Play Time! -Welcome to the OmniParser+X Computer Use Demo! X = [GPT-4o/4o-mini, Claude, Phi, Llama]. Let OmniParser turn your general purpose vision-langauge model to an AI agent. +Welcome to the OmniParser+X Computer Use Demo! X = [GPT-4o, R1, Claude]. Let OmniParser turn your general purpose vision-langauge model to an AI agent. Type a message and press submit to start OmniParser+X. Press the trash icon in the chat to clear the message history. ''' def parse_arguments(): + parser = argparse.ArgumentParser(description="Gradio App") parser.add_argument("--windows_host_url", type=str, default='localhost:8006') parser.add_argument("--omniparser_server_url", type=str, default="localhost:8000") @@ -255,11 +256,10 @@ with gr.Blocks(theme=gr.themes.Default()) as demo: } """) - state = gr.State({}) # Use Gradio's state management + state = gr.State({}) - setup_state(state.value) # Initialize the state - - # Retrieve screen details + setup_state(state.value) + gr.Markdown("# OmniParser + ✖️ Demo") if not os.getenv("HIDE_WARNING", False): @@ -270,8 +270,8 @@ with gr.Blocks(theme=gr.themes.Default()) as demo: with gr.Column(): model = gr.Dropdown( label="Model", - choices=["omniparser + gpt-4o", "omniparser + phi35v", "claude-3-5-sonnet-20241022"], - value="omniparser + gpt-4o", # Set to one of the choices + choices=["omniparser + gpt-4o", "omniparser + R1", "claude-3-5-sonnet-20241022"], + value="omniparser + gpt-4o", interactive=True, ) with gr.Column(): @@ -322,11 +322,14 @@ with gr.Blocks(theme=gr.themes.Default()) as demo: if model_selection == "claude-3-5-sonnet-20241022": provider_choices = [option.value for option in APIProvider if option.value != "openai"] - elif model_selection == "omniparser + gpt-4o" or model_selection == "omniparser + phi35v": + elif model_selection == "omniparser + gpt-4o": provider_choices = ["openai"] + elif model_selection == "omniparser + R1": + provider_choices = ["groq"] else: provider_choices = [option.value for option in APIProvider] default_provider_value = provider_choices[0] + provider_interactive = len(provider_choices) > 1 api_key_placeholder = f"{default_provider_value.title()} API Key" diff --git a/requirements.txt b/requirements.txt index 5631fb6..b55038f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,4 +28,5 @@ boto3>=1.28.57 google-auth<3,>=2 screeninfo uiautomation -dashscope \ No newline at end of file +dashscope +groq \ No newline at end of file