diff --git a/computer_use_demo/gradio/agent/llm_utils/groqclient.py b/computer_use_demo/gradio/agent/llm_utils/groqclient.py
new file mode 100644
index 0000000..812d929
--- /dev/null
+++ b/computer_use_demo/gradio/agent/llm_utils/groqclient.py
@@ -0,0 +1,60 @@
+from groq import Groq
+import os
+from .utils import is_image_path
+
+def run_groq_interleaved(messages: list, system: str, llm: str, api_key: str, max_tokens=256, temperature=0.6):
+ """
+ Run a chat completion through Groq's API, ignoring any images in the messages.
+ """
+ api_key = api_key or os.environ.get("GROQ_API_KEY")
+ if not api_key:
+ raise ValueError("GROQ_API_KEY is not set")
+
+ client = Groq(api_key=api_key)
+ # avoid using system messages for R1
+ final_messages = [{"role": "user", "content": system}]
+
+ if isinstance(messages, list):
+ for item in messages:
+ if isinstance(item, dict):
+ # For dict items, concatenate all text content, ignoring images
+ text_contents = []
+ for cnt in item["content"]:
+ if isinstance(cnt, str):
+ if not is_image_path(cnt): # Skip image paths
+ text_contents.append(cnt)
+ else:
+ text_contents.append(str(cnt))
+
+ if text_contents: # Only add if there's text content
+ message = {"role": "user", "content": " ".join(text_contents)}
+ final_messages.append(message)
+ else: # str
+ message = {"role": "user", "content": item}
+ final_messages.append(message)
+
+ elif isinstance(messages, str):
+ final_messages.append({"role": "user", "content": messages})
+
+ try:
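+        # with reasoning_format="raw", the reasoning model returns its chain of thought inline inside <think>...</think> tags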
+ completion = client.chat.completions.create(
+            model=llm,
+ messages=final_messages,
+            temperature=temperature,
+ max_completion_tokens=max_tokens,
+ top_p=0.95,
+ stream=False,
+ reasoning_format="raw"
+ )
+
+        response = completion.choices[0].message.content
+        final_answer = response.split('</think>\n')[-1] if '</think>' in response else response  # drop the <think> reasoning block
+        final_answer = final_answer.replace("</think>", "")
+ token_usage = completion.usage.total_tokens
+
+ return final_answer, token_usage
+ except Exception as e:
+ print(f"Error in interleaved Groq: {e}")
+
+ return str(e), 0
\ No newline at end of file
diff --git a/computer_use_demo/gradio/agent/llm_utils/oai.py b/computer_use_demo/gradio/agent/llm_utils/oai.py
index 2d613f1..e2daba7 100644
--- a/computer_use_demo/gradio/agent/llm_utils/oai.py
+++ b/computer_use_demo/gradio/agent/llm_utils/oai.py
@@ -2,21 +2,9 @@ import os
import logging
import base64
import requests
-
-def is_image_path(text):
- image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".tif")
- if text.endswith(image_extensions):
- return True
- else:
- return False
-
-def encode_image(image_path):
- """Encode image file to base64."""
- with open(image_path, "rb") as image_file:
- return base64.b64encode(image_file.read()).decode("utf-8")
+from .utils import is_image_path, encode_image
def run_oai_interleaved(messages: list, system: str, llm: str, api_key: str, max_tokens=256, temperature=0):
-
api_key = api_key or os.environ.get("OPENAI_API_KEY")
if not api_key:
raise ValueError("OPENAI_API_KEY is not set")
@@ -54,8 +42,6 @@ def run_oai_interleaved(messages: list, system: str, llm: str, api_key: str, max
elif isinstance(messages, str):
final_messages = [{"role": "user", "content": messages}]
- print("[oai] sending messages:", {"role": "user", "content": messages})
-
payload = {
"model": llm,
"messages": final_messages,
diff --git a/computer_use_demo/gradio/agent/llm_utils/omniparserclient.py b/computer_use_demo/gradio/agent/llm_utils/omniparserclient.py
index c29eb58..e90ddef 100644
--- a/computer_use_demo/gradio/agent/llm_utils/omniparserclient.py
+++ b/computer_use_demo/gradio/agent/llm_utils/omniparserclient.py
@@ -2,7 +2,7 @@ import requests
import base64
from pathlib import Path
from tools.screen_capture import get_screenshot
-from agent.llm_utils.oai import encode_image
+from agent.llm_utils.utils import encode_image
OUTPUT_DIR = "./tmp/outputs"
diff --git a/computer_use_demo/gradio/agent/llm_utils/utils.py b/computer_use_demo/gradio/agent/llm_utils/utils.py
new file mode 100644
index 0000000..12ab36e
--- /dev/null
+++ b/computer_use_demo/gradio/agent/llm_utils/utils.py
@@ -0,0 +1,13 @@
+import base64
+
+def is_image_path(text):
+ image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".tif")
+ if text.endswith(image_extensions):
+ return True
+ else:
+ return False
+
+def encode_image(image_path):
+ """Encode image file to base64."""
+ with open(image_path, "rb") as image_file:
+ return base64.b64encode(image_file.read()).decode("utf-8")
\ No newline at end of file
diff --git a/computer_use_demo/gradio/agent/vlm_agent.py b/computer_use_demo/gradio/agent/vlm_agent.py
index 76815a6..ece91c0 100644
--- a/computer_use_demo/gradio/agent/vlm_agent.py
+++ b/computer_use_demo/gradio/agent/vlm_agent.py
@@ -11,6 +11,7 @@ from anthropic.types import ToolResultBlockParam
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam, BetaUsage
from agent.llm_utils.oai import run_oai_interleaved
+from agent.llm_utils.groqclient import run_groq_interleaved
import time
import re
@@ -39,9 +40,12 @@ class VLMAgent:
):
if model == "omniparser + gpt-4o":
self.model = "gpt-4o-2024-11-20"
+ elif model == "omniparser + R1":
+ self.model = "deepseek-r1-distill-llama-70b"
else:
raise ValueError(f"Model {model} not supported")
+
self.provider = provider
self.api_key = api_key
self.api_response_callback = api_response_callback
@@ -108,8 +112,18 @@
print(f"oai token usage: {token_usage}")
self.total_token_usage += token_usage
self.total_cost += (token_usage * 0.15 / 1000000) # https://openai.com/api/pricing/
- elif "phi" in self.model:
- pass # TODO
+ elif "r1" in self.model:
+ vlm_response, token_usage = run_groq_interleaved(
+ messages=planner_messages,
+ system=system,
+ llm=self.model,
+ api_key=self.api_key,
+ max_tokens=self.max_tokens,
+ )
+ print(f"groq token usage: {token_usage}")
+ self.total_token_usage += token_usage
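+            # rough cost estimate at $0.99 per 1M tokens (approximate Groq pricing for this model)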
+ self.total_cost += (token_usage * 0.99 / 1000000)
else:
raise ValueError(f"Model {self.model} not supported")
latency_vlm = time.time() - start
@@ -168,7 +181,7 @@ class VLMAgent:
self.api_response_callback(response)
def _get_system_prompt(self, screen_info: str = ""):
- return f"""
+ main_section = f"""
You are using a Windows device.
You are able to use a mouse and keyboard to interact with the computer based on the given task and screenshot.
You can only interact with the desktop GUI (no terminal or application menu access).
@@ -222,11 +235,27 @@ Another Example:
IMPORTANT NOTES:
1. You should only give a single action at a time.
+
+"""
+ thinking_model = "r1" in self.model
+ if not thinking_model:
+ main_section += """
2. You should give an analysis to the current screen, and reflect on what has been done by looking at the history, then describe your step-by-step thoughts on how to achieve the task.
+
+"""
+ else:
+ main_section += """
+2. In <think> XML tags give an analysis to the current screen, and reflect on what has been done by looking at the history, then describe your step-by-step thoughts on how to achieve the task. In