remove custom prompt functionality
This commit is contained in:
@@ -67,8 +67,6 @@ def setup_state(state):
|
|||||||
state["tools"] = {}
|
state["tools"] = {}
|
||||||
if "only_n_most_recent_images" not in state:
|
if "only_n_most_recent_images" not in state:
|
||||||
state["only_n_most_recent_images"] = 2
|
state["only_n_most_recent_images"] = 2
|
||||||
if "custom_system_prompt" not in state:
|
|
||||||
state["custom_system_prompt"] = ""
|
|
||||||
if 'chatbot_messages' not in state:
|
if 'chatbot_messages' not in state:
|
||||||
state['chatbot_messages'] = []
|
state['chatbot_messages'] = []
|
||||||
|
|
||||||
@@ -204,7 +202,6 @@ def process_input(user_input, state):
|
|||||||
|
|
||||||
# Run sampling_loop_sync with the chatbot_output_callback
|
# Run sampling_loop_sync with the chatbot_output_callback
|
||||||
for loop_msg in sampling_loop_sync(
|
for loop_msg in sampling_loop_sync(
|
||||||
system_prompt_suffix=state["custom_system_prompt"],
|
|
||||||
model=state["model"],
|
model=state["model"],
|
||||||
provider=state["provider"],
|
provider=state["provider"],
|
||||||
messages=state["messages"],
|
messages=state["messages"],
|
||||||
@@ -252,12 +249,6 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
|
|||||||
value="omniparser + gpt-4o", # Set to one of the choices
|
value="omniparser + gpt-4o", # Set to one of the choices
|
||||||
interactive=True,
|
interactive=True,
|
||||||
)
|
)
|
||||||
with gr.Column():
|
|
||||||
custom_prompt = gr.Textbox(
|
|
||||||
label="System Prompt Suffix",
|
|
||||||
value="",
|
|
||||||
interactive=True,
|
|
||||||
)
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
only_n_images = gr.Slider(
|
only_n_images = gr.Slider(
|
||||||
label="N most recent screenshots",
|
label="N most recent screenshots",
|
||||||
@@ -333,9 +324,6 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return provider_update, api_key_update
|
return provider_update, api_key_update
|
||||||
|
|
||||||
def update_system_prompt_suffix(system_prompt_suffix, state):
|
|
||||||
state["custom_system_prompt"] = system_prompt_suffix
|
|
||||||
|
|
||||||
def update_only_n_images(only_n_images_value, state):
|
def update_only_n_images(only_n_images_value, state):
|
||||||
state["only_n_most_recent_images"] = only_n_images_value
|
state["only_n_most_recent_images"] = only_n_images_value
|
||||||
@@ -357,7 +345,6 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
|
|||||||
state[f'{state["provider"]}_api_key'] = api_key_value
|
state[f'{state["provider"]}_api_key'] = api_key_value
|
||||||
|
|
||||||
model.change(fn=update_model, inputs=[model, state], outputs=[provider, api_key])
|
model.change(fn=update_model, inputs=[model, state], outputs=[provider, api_key])
|
||||||
custom_prompt.change(fn=update_system_prompt_suffix, inputs=[custom_prompt, state], outputs=None)
|
|
||||||
only_n_images.change(fn=update_only_n_images, inputs=[only_n_images, state], outputs=None)
|
only_n_images.change(fn=update_only_n_images, inputs=[only_n_images, state], outputs=None)
|
||||||
provider.change(fn=update_provider, inputs=[provider, state], outputs=api_key)
|
provider.change(fn=update_provider, inputs=[provider, state], outputs=api_key)
|
||||||
api_key.change(fn=update_api_key, inputs=[api_key, state], outputs=None)
|
api_key.change(fn=update_api_key, inputs=[api_key, state], outputs=None)
|
||||||
|
|||||||
@@ -61,8 +61,7 @@ class AnthropicActor:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
provider: APIProvider,
|
provider: APIProvider,
|
||||||
system_prompt_suffix: str,
|
|
||||||
api_key: str,
|
api_key: str,
|
||||||
api_response_callback: Callable[[APIResponse[BetaMessage]], None],
|
api_response_callback: Callable[[APIResponse[BetaMessage]], None],
|
||||||
max_tokens: int = 4096,
|
max_tokens: int = 4096,
|
||||||
@@ -72,7 +71,6 @@ class AnthropicActor:
|
|||||||
):
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self.system_prompt_suffix = system_prompt_suffix
|
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.api_response_callback = api_response_callback
|
self.api_response_callback = api_response_callback
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
@@ -83,9 +81,7 @@ class AnthropicActor:
|
|||||||
ComputerTool(selected_screen=selected_screen),
|
ComputerTool(selected_screen=selected_screen),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.system = (
|
self.system = SYSTEM_PROMPT
|
||||||
f"{SYSTEM_PROMPT}{' ' + system_prompt_suffix if system_prompt_suffix else ''}"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.total_token_usage = 0
|
self.total_token_usage = 0
|
||||||
self.total_cost = 0
|
self.total_cost = 0
|
||||||
|
|||||||
@@ -75,7 +75,6 @@ def sampling_loop_sync(
|
|||||||
*,
|
*,
|
||||||
model: str,
|
model: str,
|
||||||
provider: APIProvider | None,
|
provider: APIProvider | None,
|
||||||
system_prompt_suffix: str,
|
|
||||||
messages: list[BetaMessageParam],
|
messages: list[BetaMessageParam],
|
||||||
output_callback: Callable[[BetaContentBlock], None],
|
output_callback: Callable[[BetaContentBlock], None],
|
||||||
tool_output_callback: Callable[[ToolResult, str], None],
|
tool_output_callback: Callable[[ToolResult, str], None],
|
||||||
@@ -96,8 +95,7 @@ def sampling_loop_sync(
|
|||||||
# Register Actor and Executor
|
# Register Actor and Executor
|
||||||
actor = AnthropicActor(
|
actor = AnthropicActor(
|
||||||
model=model,
|
model=model,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
system_prompt_suffix=system_prompt_suffix,
|
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_response_callback=api_response_callback,
|
api_response_callback=api_response_callback,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
@@ -121,7 +119,6 @@ def sampling_loop_sync(
|
|||||||
actor = VLMAgent(
|
actor = VLMAgent(
|
||||||
model=model,
|
model=model,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
system_prompt_suffix=system_prompt_suffix,
|
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_response_callback=api_response_callback,
|
api_response_callback=api_response_callback,
|
||||||
selected_screen=selected_screen,
|
selected_screen=selected_screen,
|
||||||
|
|||||||
@@ -22,13 +22,7 @@ from computer_use_demo.gui_agent.llm_utils.qwen import run_qwen
|
|||||||
from computer_use_demo.gui_agent.llm_utils.llm_utils import extract_data
|
from computer_use_demo.gui_agent.llm_utils.llm_utils import extract_data
|
||||||
from computer_use_demo.colorful_text import colorful_text_vlm
|
from computer_use_demo.colorful_text import colorful_text_vlm
|
||||||
import time
|
import time
|
||||||
# start = time.time()
|
|
||||||
|
|
||||||
SYSTEM_PROMPT = f"""<SYSTEM_CAPABILITY>
|
|
||||||
* You are utilizing a Windows system with internet access.
|
|
||||||
* The current date is {datetime.today().strftime('%A, %B %d, %Y')}.
|
|
||||||
</SYSTEM_CAPABILITY>
|
|
||||||
"""
|
|
||||||
OUTPUT_DIR = "./tmp/outputs"
|
OUTPUT_DIR = "./tmp/outputs"
|
||||||
|
|
||||||
class OmniParser:
|
class OmniParser:
|
||||||
@@ -94,13 +88,11 @@ class OmniParser:
|
|||||||
return response_json
|
return response_json
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class VLMAgent:
|
class VLMAgent:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
provider: str,
|
provider: str,
|
||||||
system_prompt_suffix: str,
|
|
||||||
api_key: str,
|
api_key: str,
|
||||||
output_callback: Callable,
|
output_callback: Callable,
|
||||||
api_response_callback: Callable,
|
api_response_callback: Callable,
|
||||||
@@ -115,7 +107,6 @@ class VLMAgent:
|
|||||||
raise ValueError(f"Model {model} not supported")
|
raise ValueError(f"Model {model} not supported")
|
||||||
|
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self.system_prompt_suffix = system_prompt_suffix
|
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.api_response_callback = api_response_callback
|
self.api_response_callback = api_response_callback
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
@@ -127,7 +118,7 @@ class VLMAgent:
|
|||||||
self.total_token_usage = 0
|
self.total_token_usage = 0
|
||||||
self.total_cost = 0
|
self.total_cost = 0
|
||||||
|
|
||||||
self.system = system_prompt_suffix
|
self.system = ''
|
||||||
|
|
||||||
def __call__(self, messages: list, parsed_screen: list[str, list, dict]):
|
def __call__(self, messages: list, parsed_screen: list[str, list, dict]):
|
||||||
# Show results of Omniparser
|
# Show results of Omniparser
|
||||||
@@ -144,7 +135,7 @@ class VLMAgent:
|
|||||||
|
|
||||||
# example parsed_screen: {"som_image_base64": dino_labled_img, "parsed_content_list": parsed_content_list, "screen_info"}
|
# example parsed_screen: {"som_image_base64": dino_labled_img, "parsed_content_list": parsed_content_list, "screen_info"}
|
||||||
boxids_and_labels = parsed_screen["screen_info"]
|
boxids_and_labels = parsed_screen["screen_info"]
|
||||||
system = self._get_system_prompt(boxids_and_labels) + self.system_prompt_suffix
|
system = self._get_system_prompt(boxids_and_labels)
|
||||||
|
|
||||||
# drop looping actions msg, byte image etc
|
# drop looping actions msg, byte image etc
|
||||||
planner_messages = messages
|
planner_messages = messages
|
||||||
|
|||||||
Reference in New Issue
Block a user