o1 (has vision) and o3-mini (no vision)

Author: Thomas Dhome-Casanova
Date:   2025-02-03 23:52:04 -08:00
Parent: 8725445881
Commit: 31d7b1d096

4 changed files with 26 additions and 14 deletions

File 1 of 4

```diff
@@ -7,7 +7,6 @@ from .utils import is_image_path, encode_image
 def run_oai_interleaved(messages: list, system: str, model_name: str, api_key: str, max_tokens=256, temperature=0, provider_base_url: str = "https://api.openai.com/v1"):
     headers = {"Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}"}
     final_messages = [{"role": "system", "content": system}]
     if type(messages) == list:
@@ -16,7 +15,8 @@ def run_oai_interleaved(messages: list, system: str, model_name: str, api_key: s
             if isinstance(item, dict):
                 for cnt in item["content"]:
                     if isinstance(cnt, str):
-                        if is_image_path(cnt):
+                        if is_image_path(cnt) and 'o3-mini' not in model_name:
+                            # o3-mini does not support images
                             base64_image = encode_image(cnt)
                             content = {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
                         else:
@@ -41,9 +41,12 @@ def run_oai_interleaved(messages: list, system: str, model_name: str, api_key: s
     payload = {
         "model": model_name,
         "messages": final_messages,
-        "max_tokens": max_tokens,
-        "temperature": temperature
     }
+    if 'o1' in model_name or 'o3-mini' in model_name:
+        payload['reasoning_effort'] = 'low'
+        payload['max_completion_tokens'] = max_tokens
+    else:
+        payload['max_tokens'] = max_tokens
     response = requests.post(
         f"{provider_base_url}/chat/completions", headers=headers, json=payload
```

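Two changes in this client are worth calling out. First, screenshot paths are only inlined as base64 images when the model can see them; o3-mini has no vision, so its image paths fall through to the plain-text branch. Second, the payload is now assembled in two branches, because the o-series models budget output via `max_completion_tokens` and accept a `reasoning_effort` hint instead of the classic `max_tokens`. A minimal sketch of that branching, using only the names visible in this diff:

```python
def build_payload(model_name: str, final_messages: list, max_tokens: int) -> dict:
    """Sketch of the payload branching above, not the repo's exact code."""
    payload = {"model": model_name, "messages": final_messages}
    if 'o1' in model_name or 'o3-mini' in model_name:
        # Reasoning models take max_completion_tokens (which also covers
        # hidden reasoning tokens); 'low' keeps the reasoning pass cheap.
        payload['reasoning_effort'] = 'low'
        payload['max_completion_tokens'] = max_tokens
    else:
        payload['max_tokens'] = max_tokens
    return payload
```

Note that the diff drops `temperature` from the payload for every model, not just the o-series (which do not accept it); if deterministic sampling still matters for gpt-4o, it would have to be re-added in the `else` branch.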
File 2 of 4

```diff
@@ -10,7 +10,7 @@ from anthropic import APIResponse
 from anthropic.types import ToolResultBlockParam
 from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam, BetaUsage
-from agent.llm_utils.oai import run_oai_interleaved
+from agent.llm_utils.oaiclient import run_oai_interleaved
 from agent.llm_utils.groqclient import run_groq_interleaved
 from agent.llm_utils.utils import is_image_path
 import time
@@ -45,6 +45,10 @@ class VLMAgent:
             self.model = "deepseek-r1-distill-llama-70b"
         elif model == "omniparser + qwen2.5vl":
             self.model = "qwen2.5-vl-72b-instruct"
+        elif model == "omniparser + o1":
+            self.model = "o1"
+        elif model == "omniparser + o3-mini":
+            self.model = "o3-mini"
         else:
             raise ValueError(f"Model {model} not supported")
@@ -69,9 +73,6 @@ class VLMAgent:
         latency_omniparser = parsed_screen['latency']
         self.output_callback(f'-- Step {self.step_count}: --', sender="bot")
-        screen_info = str(parsed_screen['screen_info'])
         screenshot_uuid = parsed_screen['screenshot_uuid']
         screen_width, screen_height = parsed_screen['width'], parsed_screen['height']
@@ -90,7 +91,7 @@ class VLMAgent:
         planner_messages[-1]["content"].append(f"{OUTPUT_DIR}/screenshot_som_{screenshot_uuid}.png")
         start = time.time()
-        if "gpt" in self.model:
+        if "gpt" in self.model or "o1" in self.model or "o3-mini" in self.model:
             vlm_response, token_usage = run_oai_interleaved(
                 messages=planner_messages,
                 system=system,
@@ -102,7 +103,12 @@ class VLMAgent:
             )
             print(f"oai token usage: {token_usage}")
             self.total_token_usage += token_usage
-            self.total_cost += (token_usage * 2.5 / 1000000)  # https://openai.com/api/pricing/
+            if 'gpt' in self.model:
+                self.total_cost += (token_usage * 2.5 / 1000000)  # https://openai.com/api/pricing/
+            elif 'o1' in self.model:
+                self.total_cost += (token_usage * 15 / 1000000)  # https://openai.com/api/pricing/
+            elif 'o3-mini' in self.model:
+                self.total_cost += (token_usage * 1.1 / 1000000)  # https://openai.com/api/pricing/
         elif "r1" in self.model:
             vlm_response, token_usage = run_groq_interleaved(
                 messages=planner_messages,
```

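Cost tracking now branches on the model family, applying the per-token prices cited in the diff (USD per 1M tokens: 2.5 for gpt-4o, 15 for o1, 1.1 for o3-mini) as a single flat rate to the combined `token_usage` count rather than splitting prompt and completion rates. A table-driven sketch of the same accounting, with names of my choosing:

```python
# USD per 1M tokens, per https://openai.com/api/pricing/ at commit time.
# Insertion order mirrors the if/elif chain: the first matching family wins.
PRICE_PER_MTOK = {"gpt": 2.5, "o1": 15.0, "o3-mini": 1.1}

def estimate_cost(model: str, token_usage: int) -> float:
    """Hypothetical helper equivalent to the pricing branches above."""
    for family, price in PRICE_PER_MTOK.items():
        if family in model:
            return token_usage * price / 1_000_000
    return 0.0
```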
File 3 of 4

```diff
@@ -242,6 +242,7 @@ def process_input(user_input, state):
         api_response_callback=partial(_api_response_callback, response_state=state["responses"]),
         api_key=state["api_key"],
         only_n_most_recent_images=state["only_n_most_recent_images"],
+        max_tokens=16384,
         omniparser_url=args.omniparser_server_url
     ):
         if loop_msg is None or state.get("stop"):
@@ -280,7 +281,7 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
         with gr.Column():
             model = gr.Dropdown(
                 label="Model",
-                choices=["omniparser + gpt-4o", "omniparser + R1", "omniparser + qwen2.5vl", "claude-3-5-sonnet-20241022"],
+                choices=["omniparser + gpt-4o", "omniparser + o1", "omniparser + o3-mini", "omniparser + R1", "omniparser + qwen2.5vl", "claude-3-5-sonnet-20241022"],
                 value="omniparser + gpt-4o",
                 interactive=True,
             )
@@ -334,7 +335,7 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
     if model_selection == "claude-3-5-sonnet-20241022":
         provider_choices = [option.value for option in APIProvider if option.value != "openai"]
-    elif model_selection == "omniparser + gpt-4o":
+    elif model_selection in set(["omniparser + gpt-4o", "omniparser + o1", "omniparser + o3-mini"]):
         provider_choices = ["openai"]
     elif model_selection == "omniparser + R1":
         provider_choices = ["groq"]
```

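In the Gradio app, the two new dropdown entries route to the OpenAI provider alongside gpt-4o, and the sampling loop is now started with `max_tokens=16384`, far above the client's default of 256, which makes sense given that o-series reasoning tokens are charged against the same completion budget. A self-contained sketch of the provider routing, covering only the branches visible in this hunk, with a stub standing in for the repo's `APIProvider` enum:

```python
from enum import Enum

class APIProvider(Enum):
    # Stub for the repo's APIProvider; member values here are assumptions.
    ANTHROPIC = "anthropic"
    OPENAI = "openai"
    GROQ = "groq"

# Dropdown entries served by the OpenAI API after this commit.
OPENAI_BACKED = {"omniparser + gpt-4o", "omniparser + o1", "omniparser + o3-mini"}

def provider_choices_for(model_selection: str) -> list:
    """Table-driven version of the if/elif routing shown above."""
    if model_selection == "claude-3-5-sonnet-20241022":
        return [option.value for option in APIProvider if option.value != "openai"]
    if model_selection in OPENAI_BACKED:
        return ["openai"]
    if model_selection == "omniparser + R1":
        return ["groq"]
    raise ValueError(f"Model {model_selection} not supported")
```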
File 4 of 4

```diff
@@ -64,13 +64,15 @@ def sampling_loop_sync(
             max_tokens=max_tokens,
             only_n_most_recent_images=only_n_most_recent_images
         )
-    elif model == "omniparser + gpt-4o" or model == "omniparser + R1" or model == "omniparser + qwen2.5vl":
+    elif model in set(["omniparser + gpt-4o", "omniparser + o1", "omniparser + o3-mini", "omniparser + R1", "omniparser + qwen2.5vl"]):
         actor = VLMAgent(
             model=model,
             provider=provider,
             api_key=api_key,
             api_response_callback=api_response_callback,
             output_callback=output_callback,
+            max_tokens=max_tokens,
+            only_n_most_recent_images=only_n_most_recent_images
         )
     else:
         raise ValueError(f"Model {model} not supported")
@@ -100,7 +102,7 @@ def sampling_loop_sync(
         messages.append({"content": tool_result_content, "role": "user"})
-    elif model == "omniparser + gpt-4o" or model == "omniparser + R1" or model == "omniparser + qwen2.5vl":
+    elif model in set(["omniparser + gpt-4o", "omniparser + o1", "omniparser + o3-mini", "omniparser + R1", "omniparser + qwen2.5vl"]):
         while True:
             parsed_screen = omniparser_client()
             tools_use_needed, vlm_response_json = actor(messages=messages, parsed_screen=parsed_screen)
```
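Taken together, `sampling_loop_sync` now accepts the two new model strings and forwards `max_tokens` into `VLMAgent`, which hands it to `run_oai_interleaved`. A hedged usage sketch; the prompt, screenshot path, and environment variable are illustrative, and the real loop builds its messages from OmniParser output:

```python
import os

from agent.llm_utils.oaiclient import run_oai_interleaved

messages = [{"role": "user",
             "content": ["Open Settings and enable dark mode",
                         "tmp/outputs/screenshot_som_abc123.png"]}]

# With "o1" the screenshot path would be inlined as a base64 image_url
# block; with "o3-mini" (no vision) it is passed through as plain text.
vlm_response, token_usage = run_oai_interleaved(
    messages=messages,
    system="You are a GUI automation agent.",
    model_name="o3-mini",
    api_key=os.environ["OPENAI_API_KEY"],
    max_tokens=16384,  # generous: reasoning tokens draw from this budget too
)
print(f"tokens used: {token_usage}")
```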