Naming conventions
This commit is contained in:
162
omnitool/gradio/agent/anthropic_agent.py
Normal file
162
omnitool/gradio/agent/anthropic_agent.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
Agentic sampling loop that calls the Anthropic API and a local implementation of the Anthropic-defined computer use tools.
|
||||
"""
|
||||
import asyncio
|
||||
import platform
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, cast
|
||||
|
||||
from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, APIResponse
|
||||
from anthropic.types import (
|
||||
ToolResultBlockParam,
|
||||
)
|
||||
from anthropic.types.beta import (
|
||||
BetaContentBlock,
|
||||
BetaContentBlockParam,
|
||||
BetaImageBlockParam,
|
||||
BetaMessage,
|
||||
BetaMessageParam,
|
||||
BetaTextBlockParam,
|
||||
BetaToolResultBlockParam,
|
||||
)
|
||||
from anthropic.types import TextBlock
|
||||
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
|
||||
|
||||
from tools import ComputerTool, ToolCollection, ToolResult
|
||||
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import gradio as gr
|
||||
from typing import Dict
|
||||
|
||||
# Beta identifier string for the computer-use tools; AnthropicActor.__call__
# sends this value in the request's `betas` list.
BETA_FLAG = "computer-use-2024-10-22"
|
||||
|
||||
class APIProvider(StrEnum):
|
||||
ANTHROPIC = "anthropic"
|
||||
BEDROCK = "bedrock"
|
||||
VERTEX = "vertex"
|
||||
|
||||
# System prompt sent with every request. The date is interpolated once, when
# this module is imported, so it is not refreshed for a long-running process.
SYSTEM_PROMPT = (
    "<SYSTEM_CAPABILITY>\n"
    "* You are utilizing a Windows system with internet access.\n"
    f"* The current date is {datetime.today().strftime('%A, %B %d, %Y')}.\n"
    "</SYSTEM_CAPABILITY>\n"
)
|
||||
|
||||
class AnthropicActor:
    """Callable that sends a conversation to the Anthropic beta Messages API.

    Each call forwards the message history together with the system prompt and
    the computer-use tool definitions, hands the raw API response to a
    caller-supplied callback, and keeps a running tally of token usage and
    estimated cost on the instance.
    """

    # USD per one million input / output tokens, used for the running cost
    # estimate. NOTE(review): presumably the list price of the Claude model
    # this agent is normally run with - confirm against the configured model.
    _INPUT_USD_PER_MILLION_TOKENS = 3
    _OUTPUT_USD_PER_MILLION_TOKENS = 15

    def __init__(
        self,
        model: str,
        provider: APIProvider,
        api_key: str,
        api_response_callback: Callable[[APIResponse[BetaMessage]], None],
        max_tokens: int = 4096,
        only_n_most_recent_images: int | None = None,
        print_usage: bool = True,
    ):
        """Store the configuration and build the API client for ``provider``.

        Args:
            model: Model identifier passed to the Messages API.
            provider: Which backend to talk to.
            api_key: API key; only used when ``provider`` is
                ``APIProvider.ANTHROPIC`` (the Vertex and Bedrock clients are
                constructed without arguments).
            api_response_callback: Invoked with the raw API response after
                every request.
            max_tokens: ``max_tokens`` value sent with every request.
            only_n_most_recent_images: When truthy, older tool-result images
                are pruned from the history before each request.
            print_usage: Whether to print cumulative token usage and cost
                after every request.

        Raises:
            ValueError: If ``provider`` is not a supported ``APIProvider``.
        """
        self.model = model
        self.provider = provider
        self.api_key = api_key
        self.api_response_callback = api_response_callback
        self.max_tokens = max_tokens
        self.only_n_most_recent_images = only_n_most_recent_images

        self.tool_collection = ToolCollection(ComputerTool())

        self.system = SYSTEM_PROMPT

        # Running totals across all calls made through this instance.
        self.total_token_usage = 0
        self.total_cost = 0
        self.print_usage = print_usage

        # Instantiate the appropriate API client based on the provider.
        if provider == APIProvider.ANTHROPIC:
            self.client = Anthropic(api_key=api_key)
        elif provider == APIProvider.VERTEX:
            self.client = AnthropicVertex()
        elif provider == APIProvider.BEDROCK:
            self.client = AnthropicBedrock()
        else:
            # Fail here rather than with an AttributeError on `self.client`
            # at the first call.
            raise ValueError(f"Unsupported API provider: {provider!r}")

    def __call__(
        self,
        *,
        messages: list[BetaMessageParam]
    ) -> BetaMessage:
        """Generate a response given the history ``messages``.

        When ``only_n_most_recent_images`` is set, ``messages`` is pruned in
        place before the request is sent.

        Args:
            messages: Conversation history in Messages API format.

        Returns:
            The parsed API response.
        """
        if self.only_n_most_recent_images:
            _maybe_filter_to_n_most_recent_images(messages, self.only_n_most_recent_images)

        # Call the API synchronously, keeping the raw response so the
        # callback can inspect it before it is parsed.
        raw_response = self.client.beta.messages.with_raw_response.create(
            max_tokens=self.max_tokens,
            messages=messages,
            model=self.model,
            system=self.system,
            tools=self.tool_collection.to_params(),
            betas=[BETA_FLAG],
        )

        self.api_response_callback(cast(APIResponse[BetaMessage], raw_response))

        response = raw_response.parse()
        print(f"AnthropicActor response: {response}")

        usage = response.usage
        self.total_token_usage += usage.input_tokens + usage.output_tokens
        self.total_cost += (
            usage.input_tokens * self._INPUT_USD_PER_MILLION_TOKENS / 1000000
            + usage.output_tokens * self._OUTPUT_USD_PER_MILLION_TOKENS / 1000000
        )

        if self.print_usage:
            print(f"Claude total token usage so far: {self.total_token_usage}, total cost so far: $USD{self.total_cost}")

        return response
|
||||
|
||||
|
||||
def _maybe_filter_to_n_most_recent_images(
|
||||
messages: list[BetaMessageParam],
|
||||
images_to_keep: int,
|
||||
min_removal_threshold: int = 10,
|
||||
):
|
||||
"""
|
||||
With the assumption that images are screenshots that are of diminishing value as
|
||||
the conversation progresses, remove all but the final `images_to_keep` tool_result
|
||||
images in place, with a chunk of min_removal_threshold to reduce the amount we
|
||||
break the implicit prompt cache.
|
||||
"""
|
||||
if images_to_keep is None:
|
||||
return messages
|
||||
|
||||
tool_result_blocks = cast(
|
||||
list[ToolResultBlockParam],
|
||||
[
|
||||
item
|
||||
for message in messages
|
||||
for item in (
|
||||
message["content"] if isinstance(message["content"], list) else []
|
||||
)
|
||||
if isinstance(item, dict) and item.get("type") == "tool_result"
|
||||
],
|
||||
)
|
||||
|
||||
total_images = sum(
|
||||
1
|
||||
for tool_result in tool_result_blocks
|
||||
for content in tool_result.get("content", [])
|
||||
if isinstance(content, dict) and content.get("type") == "image"
|
||||
)
|
||||
|
||||
images_to_remove = total_images - images_to_keep
|
||||
# for better cache behavior, we want to remove in chunks
|
||||
images_to_remove -= images_to_remove % min_removal_threshold
|
||||
|
||||
for tool_result in tool_result_blocks:
|
||||
if isinstance(tool_result.get("content"), list):
|
||||
new_content = []
|
||||
for content in tool_result.get("content", []):
|
||||
if isinstance(content, dict) and content.get("type") == "image":
|
||||
if images_to_remove > 0:
|
||||
images_to_remove -= 1
|
||||
continue
|
||||
new_content.append(content)
|
||||
tool_result["content"] = new_content
|
||||
Reference in New Issue
Block a user