R1 fixes
This commit is contained in:
@@ -56,37 +56,27 @@ class VLMAgent:
|
|||||||
self.print_usage = print_usage
|
self.print_usage = print_usage
|
||||||
self.total_token_usage = 0
|
self.total_token_usage = 0
|
||||||
self.total_cost = 0
|
self.total_cost = 0
|
||||||
|
self.step_count = 0
|
||||||
|
|
||||||
self.system = ''
|
self.system = ''
|
||||||
|
|
||||||
def __call__(self, messages: list, parsed_screen: list[str, list, dict]):
|
def __call__(self, messages: list, parsed_screen: list[str, list, dict]):
|
||||||
# Show results of Omniparser
|
self.step_count += 1
|
||||||
image_base64 = parsed_screen['original_screenshot_base64']
|
image_base64 = parsed_screen['original_screenshot_base64']
|
||||||
latency_omniparser = parsed_screen['latency']
|
latency_omniparser = parsed_screen['latency']
|
||||||
self.output_callback(f'Screenshot for OmniParser Agent:\n<img src="data:image/png;base64,{image_base64}">',
|
self.output_callback(f'-- Step {self.step_count}: --', sender="bot")
|
||||||
sender="bot")
|
|
||||||
self.output_callback(f'Set of Marks Screenshot for OmniParser Agent:\n<img src="data:image/png;base64,{parsed_screen["som_image_base64"]}">', sender="bot")
|
|
||||||
screen_info = str(parsed_screen['screen_info'])
|
screen_info = str(parsed_screen['screen_info'])
|
||||||
# self.output_callback(f'Screen Info for OmniParser Agent:\n{screen_info}', sender="bot")
|
|
||||||
self.output_callback(
|
|
||||||
f'<details>'
|
|
||||||
f' <summary>Screen Info for OmniParser Agent</summary>'
|
|
||||||
f' <pre>{screen_info}</pre>'
|
|
||||||
f'</details>',
|
|
||||||
sender="bot"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
screenshot_uuid = parsed_screen['screenshot_uuid']
|
screenshot_uuid = parsed_screen['screenshot_uuid']
|
||||||
screen_width, screen_height = parsed_screen['width'], parsed_screen['height']
|
screen_width, screen_height = parsed_screen['width'], parsed_screen['height']
|
||||||
|
|
||||||
# example parsed_screen: {"som_image_base64": dino_labled_img, "parsed_content_list": parsed_content_list, "screen_info"}
|
|
||||||
boxids_and_labels = parsed_screen["screen_info"]
|
boxids_and_labels = parsed_screen["screen_info"]
|
||||||
system = self._get_system_prompt(boxids_and_labels)
|
system = self._get_system_prompt(boxids_and_labels)
|
||||||
|
|
||||||
# drop looping actions msg, byte image etc
|
# drop looping actions msg, byte image etc
|
||||||
planner_messages = messages
|
planner_messages = messages
|
||||||
# import pdb; pdb.set_trace()
|
|
||||||
planner_messages = _keep_latest_images(planner_messages)
|
planner_messages = _keep_latest_images(planner_messages)
|
||||||
# if self.only_n_most_recent_images:
|
# if self.only_n_most_recent_images:
|
||||||
# _maybe_filter_to_n_most_recent_images(planner_messages, self.only_n_most_recent_images)
|
# _maybe_filter_to_n_most_recent_images(planner_messages, self.only_n_most_recent_images)
|
||||||
@@ -98,7 +88,6 @@ class VLMAgent:
|
|||||||
planner_messages[-1]["content"].append(f"{OUTPUT_DIR}/screenshot_{screenshot_uuid}.png")
|
planner_messages[-1]["content"].append(f"{OUTPUT_DIR}/screenshot_{screenshot_uuid}.png")
|
||||||
planner_messages[-1]["content"].append(f"{OUTPUT_DIR}/screenshot_som_{screenshot_uuid}.png")
|
planner_messages[-1]["content"].append(f"{OUTPUT_DIR}/screenshot_som_{screenshot_uuid}.png")
|
||||||
|
|
||||||
# print(f"Sending messages to VLMPlanner : {planner_messages}")
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
if "gpt" in self.model:
|
if "gpt" in self.model:
|
||||||
vlm_response, token_usage = run_oai_interleaved(
|
vlm_response, token_usage = run_oai_interleaved(
|
||||||
@@ -113,6 +102,7 @@ class VLMAgent:
|
|||||||
self.total_token_usage += token_usage
|
self.total_token_usage += token_usage
|
||||||
self.total_cost += (token_usage * 0.15 / 1000000) # https://openai.com/api/pricing/
|
self.total_cost += (token_usage * 0.15 / 1000000) # https://openai.com/api/pricing/
|
||||||
elif "r1" in self.model:
|
elif "r1" in self.model:
|
||||||
|
print(f"Sending messages to Groq: {planner_messages}")
|
||||||
vlm_response, token_usage = run_groq_interleaved(
|
vlm_response, token_usage = run_groq_interleaved(
|
||||||
messages=planner_messages,
|
messages=planner_messages,
|
||||||
system=system,
|
system=system,
|
||||||
@@ -126,33 +116,46 @@ class VLMAgent:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Model {self.model} not supported")
|
raise ValueError(f"Model {self.model} not supported")
|
||||||
latency_vlm = time.time() - start
|
latency_vlm = time.time() - start
|
||||||
self.output_callback(f"VLMPlanner latency: {latency_vlm}, Omniparser latency: {latency_omniparser}", sender="bot")
|
self.output_callback(f"LLM: {latency_vlm:.2f}s, OmniParser: {latency_omniparser:.2f}s", sender="bot")
|
||||||
|
|
||||||
print(f"VLMPlanner response: {vlm_response}")
|
print(f"{vlm_response}")
|
||||||
|
|
||||||
if self.print_usage:
|
if self.print_usage:
|
||||||
print(f"VLMPlanner total token usage so far: {self.total_token_usage}. Total cost so far: $USD{self.total_cost:.5f}")
|
print(f"Total token so far: {self.total_token_usage}. Total cost so far: $USD{self.total_cost:.5f}")
|
||||||
|
|
||||||
vlm_response_json = extract_data(vlm_response, "json")
|
vlm_response_json = extract_data(vlm_response, "json")
|
||||||
vlm_response_json = json.loads(vlm_response_json)
|
vlm_response_json = json.loads(vlm_response_json)
|
||||||
|
|
||||||
# map "box_id" to "idx" in parsed_screen, and output the xy coordinate of bbox
|
img_to_show_base64 = parsed_screen["som_image_base64"]
|
||||||
try:
|
if "Box ID" in vlm_response_json:
|
||||||
bbox = parsed_screen["parsed_content_list"][int(vlm_response_json["Box ID"])]["bbox"]
|
bbox = parsed_screen["parsed_content_list"][int(vlm_response_json["Box ID"])]["bbox"]
|
||||||
vlm_response_json["box_centroid_coordinate"] = [int((bbox[0] + bbox[2]) / 2 * screen_width), int((bbox[1] + bbox[3]) / 2 * screen_height)]
|
vlm_response_json["box_centroid_coordinate"] = [int((bbox[0] + bbox[2]) / 2 * screen_width), int((bbox[1] + bbox[3]) / 2 * screen_height)]
|
||||||
# draw a circle on the screenshot image to indicate the action
|
img_to_show_data = base64.b64decode(img_to_show_base64)
|
||||||
self.draw_action(vlm_response_json, image_base64)
|
img_to_show = Image.open(BytesIO(img_to_show_data))
|
||||||
except:
|
|
||||||
print("No Box ID in the response.")
|
|
||||||
|
|
||||||
# Convert the VLM output to a string for printing in chat
|
draw = ImageDraw.Draw(img_to_show)
|
||||||
|
x, y = vlm_response_json["box_centroid_coordinate"]
|
||||||
|
radius = 10
|
||||||
|
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill='red')
|
||||||
|
draw.ellipse((x - radius*3, y - radius*3, x + radius*3, y + radius*3), fill=None, outline='red', width=2)
|
||||||
|
|
||||||
|
buffered = BytesIO()
|
||||||
|
img_to_show.save(buffered, format="PNG")
|
||||||
|
img_to_show_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||||
|
self.output_callback(f'<img src="data:image/png;base64,{img_to_show_base64}">', sender="bot")
|
||||||
|
self.output_callback(
|
||||||
|
f'<details>'
|
||||||
|
f' <summary>Screen Info for OmniParser Agent</summary>'
|
||||||
|
f' <pre>{screen_info}</pre>'
|
||||||
|
f'</details>',
|
||||||
|
sender="bot"
|
||||||
|
)
|
||||||
vlm_plan_str = ""
|
vlm_plan_str = ""
|
||||||
for key, value in vlm_response_json.items():
|
for key, value in vlm_response_json.items():
|
||||||
if key == "Reasoning":
|
if key == "Reasoning":
|
||||||
vlm_plan_str += f'{value}'
|
vlm_plan_str += f'{value}'
|
||||||
else:
|
else:
|
||||||
vlm_plan_str += f'\n{key}: {value}'
|
vlm_plan_str += f'\n{key}: {value}'
|
||||||
# self.output_callback(f"OmniParser Agent:\n{vlm_plan_str}", sender="bot")
|
|
||||||
|
|
||||||
# construct the response so that anthropicExcutor can execute the tool
|
# construct the response so that anthropicExcutor can execute the tool
|
||||||
response_content = [BetaTextBlock(text=vlm_plan_str, type='text')]
|
response_content = [BetaTextBlock(text=vlm_plan_str, type='text')]
|
||||||
@@ -161,6 +164,7 @@ class VLMAgent:
|
|||||||
input={'action': 'mouse_move', 'coordinate': vlm_response_json["box_centroid_coordinate"]},
|
input={'action': 'mouse_move', 'coordinate': vlm_response_json["box_centroid_coordinate"]},
|
||||||
name='computer', type='tool_use')
|
name='computer', type='tool_use')
|
||||||
response_content.append(move_cursor_block)
|
response_content.append(move_cursor_block)
|
||||||
|
|
||||||
if vlm_response_json["Next Action"] == "type":
|
if vlm_response_json["Next Action"] == "type":
|
||||||
click_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}', input={'action': 'left_click'}, name='computer', type='tool_use')
|
click_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}', input={'action': 'left_click'}, name='computer', type='tool_use')
|
||||||
sim_content_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
|
sim_content_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
|
||||||
@@ -194,13 +198,12 @@ Here is the list of all detected bounding boxes by IDs on the screen and their d
|
|||||||
Your available "Next Action" only include:
|
Your available "Next Action" only include:
|
||||||
- type: type a string of text.
|
- type: type a string of text.
|
||||||
- left_click: Describe the ui element to be clicked.
|
- left_click: Describe the ui element to be clicked.
|
||||||
- double_click: Describe the ui element to be double clicked.
|
|
||||||
- right_click: Describe the ui element to be right clicked.
|
- right_click: Describe the ui element to be right clicked.
|
||||||
- escape: Press an ESCAPE key.
|
- double_click: Describe the ui element to be double clicked.
|
||||||
- hover: Describe the ui element to be hovered.
|
- hover: Describe the ui element to be hovered.
|
||||||
- scroll_up: Scroll the screen up.
|
- scroll_up: Scroll the screen up.
|
||||||
- scroll_down: Scroll the screen down.
|
- scroll_down: Scroll the screen down.
|
||||||
- press: Describe the ui element to be pressed.
|
- wait: Wait for 1 second for the device to load or respond.
|
||||||
|
|
||||||
Based on the visual information from the screenshot image and the detected bounding boxes, please determine the next action, the Box ID you should operate on, and the value (if the action is 'type') in order to complete the task.
|
Based on the visual information from the screenshot image and the detected bounding boxes, please determine the next action, the Box ID you should operate on, and the value (if the action is 'type') in order to complete the task.
|
||||||
|
|
||||||
@@ -209,8 +212,8 @@ Output format:
|
|||||||
{{
|
{{
|
||||||
"Reasoning": str, # describe what is in the current screen, taking into account the history, then describe your step-by-step thoughts on how to achieve the task, choose one action from available actions at a time.
|
"Reasoning": str, # describe what is in the current screen, taking into account the history, then describe your step-by-step thoughts on how to achieve the task, choose one action from available actions at a time.
|
||||||
"Next Action": "action_type, action description" | "None" # one action at a time, describe it in short and precisely.
|
"Next Action": "action_type, action description" | "None" # one action at a time, describe it in short and precisely.
|
||||||
'Box ID': n,
|
"Box ID": n,
|
||||||
'value': "xxx" # if the action is type, you should provide the text to type.
|
"value": "xxx" # only provide value field if the action is type, else don't include value key
|
||||||
}}
|
}}
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -219,7 +222,7 @@ One Example:
|
|||||||
{{
|
{{
|
||||||
"Reasoning": "The current screen shows google result of amazon, in previous action I have searched amazon on google. Then I need to click on the first search results to go to amazon.com.",
|
"Reasoning": "The current screen shows google result of amazon, in previous action I have searched amazon on google. Then I need to click on the first search results to go to amazon.com.",
|
||||||
"Next Action": "left_click",
|
"Next Action": "left_click",
|
||||||
'Box ID': m,
|
"Box ID": m
|
||||||
}}
|
}}
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -228,8 +231,8 @@ Another Example:
|
|||||||
{{
|
{{
|
||||||
"Reasoning": "The current screen shows the front page of amazon. There is no previous action. Therefore I need to type "Apple watch" in the search bar.",
|
"Reasoning": "The current screen shows the front page of amazon. There is no previous action. Therefore I need to type "Apple watch" in the search bar.",
|
||||||
"Next Action": "type",
|
"Next Action": "type",
|
||||||
'Box ID': n,
|
"Box ID": n,
|
||||||
'value': "Apple watch"
|
"value": "Apple watch"
|
||||||
}}
|
}}
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -251,26 +254,11 @@ IMPORTANT NOTES:
|
|||||||
main_section += """
|
main_section += """
|
||||||
3. Attach the next action prediction in the "Next Action".
|
3. Attach the next action prediction in the "Next Action".
|
||||||
4. You should not include other actions, such as keyboard shortcuts.
|
4. You should not include other actions, such as keyboard shortcuts.
|
||||||
5. When the task is completed, you should say "Next Action": "None" in the json field.
|
5. When the task is completed, don't complete additional actions. You should say "Next Action": "None" in the json field.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return main_section
|
return main_section
|
||||||
|
|
||||||
def draw_action(self, vlm_response_json, image_base64):
|
|
||||||
# draw a circle using the coordinate in parsed_screen['som_image_base64']
|
|
||||||
image_data = base64.b64decode(image_base64)
|
|
||||||
image = Image.open(BytesIO(image_data))
|
|
||||||
|
|
||||||
draw = ImageDraw.Draw(image)
|
|
||||||
x, y = vlm_response_json["box_centroid_coordinate"]
|
|
||||||
radius = 30
|
|
||||||
draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill='red')
|
|
||||||
buffered = BytesIO()
|
|
||||||
image.save(buffered, format="PNG")
|
|
||||||
image_with_circle_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
|
||||||
self.output_callback(f'Action performed on the red circle with centroid ({x}, {y}), for OmniParser Agent:\n<img src="data:image/png;base64,{image_with_circle_base64}">', sender="bot")
|
|
||||||
|
|
||||||
|
|
||||||
def _keep_latest_images(messages):
|
def _keep_latest_images(messages):
|
||||||
for i in range(len(messages)-1):
|
for i in range(len(messages)-1):
|
||||||
if isinstance(messages[i]["content"], list):
|
if isinstance(messages[i]["content"], list):
|
||||||
|
|||||||
@@ -64,14 +64,7 @@ def sampling_loop_sync(
|
|||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
only_n_most_recent_images=only_n_most_recent_images
|
only_n_most_recent_images=only_n_most_recent_images
|
||||||
)
|
)
|
||||||
|
elif model == "omniparser + gpt-4o" or model == "omniparser + R1":
|
||||||
# from IPython.core.debugger import Pdb; Pdb().set_trace()
|
|
||||||
executor = AnthropicExecutor(
|
|
||||||
output_callback=output_callback,
|
|
||||||
tool_output_callback=tool_output_callback
|
|
||||||
)
|
|
||||||
|
|
||||||
elif model == "omniparser + gpt-4o" or model == "omniparser + phi35v":
|
|
||||||
actor = VLMAgent(
|
actor = VLMAgent(
|
||||||
model=model,
|
model=model,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
@@ -79,14 +72,12 @@ def sampling_loop_sync(
|
|||||||
api_response_callback=api_response_callback,
|
api_response_callback=api_response_callback,
|
||||||
output_callback=output_callback,
|
output_callback=output_callback,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Model {model} not supported")
|
||||||
executor = AnthropicExecutor(
|
executor = AnthropicExecutor(
|
||||||
output_callback=output_callback,
|
output_callback=output_callback,
|
||||||
tool_output_callback=tool_output_callback,
|
tool_output_callback=tool_output_callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Model {model} not supported")
|
|
||||||
print(f"Model Inited: {model}, Provider: {provider}")
|
print(f"Model Inited: {model}, Provider: {provider}")
|
||||||
|
|
||||||
tool_result_content = None
|
tool_result_content = None
|
||||||
@@ -109,7 +100,7 @@ def sampling_loop_sync(
|
|||||||
|
|
||||||
messages.append({"content": tool_result_content, "role": "user"})
|
messages.append({"content": tool_result_content, "role": "user"})
|
||||||
|
|
||||||
elif model == "omniparser + gpt-4o" or model == "omniparser + phi35v":
|
elif model == "omniparser + gpt-4o" or model == "omniparser + R1":
|
||||||
while True:
|
while True:
|
||||||
parsed_screen = omniparser_client()
|
parsed_screen = omniparser_client()
|
||||||
tools_use_needed, vlm_response_json = actor(messages=messages, parsed_screen=parsed_screen)
|
tools_use_needed, vlm_response_json = actor(messages=messages, parsed_screen=parsed_screen)
|
||||||
|
|||||||
@@ -28,6 +28,8 @@ Action = Literal[
|
|||||||
"double_click",
|
"double_click",
|
||||||
"screenshot",
|
"screenshot",
|
||||||
"cursor_position",
|
"cursor_position",
|
||||||
|
"hover",
|
||||||
|
"wait"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -213,7 +215,11 @@ class ComputerTool(BaseAnthropicTool):
|
|||||||
elif action == "scroll_down":
|
elif action == "scroll_down":
|
||||||
self.send_to_vm("pyautogui.scroll(-100)")
|
self.send_to_vm("pyautogui.scroll(-100)")
|
||||||
return ToolResult(output=f"Performed {action}")
|
return ToolResult(output=f"Performed {action}")
|
||||||
|
if action == "hover":
|
||||||
|
return ToolResult(output=f"Performed {action}")
|
||||||
|
if action == "wait":
|
||||||
|
time.sleep(1)
|
||||||
|
return ToolResult(output=f"Performed {action}")
|
||||||
raise ToolError(f"Invalid action: {action}")
|
raise ToolError(f"Invalid action: {action}")
|
||||||
|
|
||||||
def send_to_vm(self, action: str):
|
def send_to_vm(self, action: str):
|
||||||
|
|||||||
Reference in New Issue
Block a user