docker demo, migration, speedup inference using cv2

This commit is contained in:
yadonglu
2025-01-04 20:06:33 -08:00
parent d0c163cd02
commit b9d3cb715b
36 changed files with 5842 additions and 2456 deletions

16
demo/tools/__init__.py Normal file
View File

@@ -0,0 +1,16 @@
from .base import CLIResult, ToolResult
from .bash import BashTool
from .collection import ToolCollection
from .computer import ComputerTool
from .edit import EditTool
from .screen_capture import get_screenshot
# Public API of the tools package.
# Python only honours the lowercase ``__all__`` name, and its entries must be
# the *names* (strings) of the exported objects, not the objects themselves.
__all__ = [
    "BashTool",
    "CLIResult",
    "ComputerTool",
    "EditTool",
    "ToolCollection",
    "ToolResult",
    "get_screenshot",
]

69
demo/tools/base.py Normal file
View File

@@ -0,0 +1,69 @@
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, fields, replace
from typing import Any
from anthropic.types.beta import BetaToolUnionParam
class BaseAnthropicTool(metaclass=ABCMeta):
    """Interface shared by all Anthropic-defined tools.

    Concrete tools must be callable (to execute the tool) and must be able to
    describe themselves as an API parameter block via ``to_params``.
    """

    @abstractmethod
    def __call__(self, **kwargs) -> Any:
        """Run the tool using the supplied keyword arguments."""
        ...

    @abstractmethod
    def to_params(self) -> BetaToolUnionParam:
        """Return the parameter block that declares this tool to the API."""
        raise NotImplementedError
@dataclass(kw_only=True, frozen=True)
class ToolResult:
"""Represents the result of a tool execution."""
output: str | None = None
error: str | None = None
base64_image: str | None = None
system: str | None = None
def __bool__(self):
return any(getattr(self, field.name) for field in fields(self))
def __add__(self, other: "ToolResult"):
def combine_fields(
field: str | None, other_field: str | None, concatenate: bool = True
):
if field and other_field:
if concatenate:
return field + other_field
raise ValueError("Cannot combine tool results")
return field or other_field
return ToolResult(
output=combine_fields(self.output, other.output),
error=combine_fields(self.error, other.error),
base64_image=combine_fields(self.base64_image, other.base64_image, False),
system=combine_fields(self.system, other.system),
)
def replace(self, **kwargs):
"""Returns a new ToolResult with the given fields replaced."""
return replace(self, **kwargs)
class CLIResult(ToolResult):
    """ToolResult subtype marking output that is meant to be shown as CLI text."""
class ToolFailure(ToolResult):
    """ToolResult subtype signalling that the tool call did not succeed."""
class ToolError(Exception):
    """Raised when a tool encounters an error.

    Args:
        message: Human-readable description of the failure; also exposed as
            the ``message`` attribute.
    """

    def __init__(self, message):
        # Forward to Exception so ``args``/``str(exc)`` carry the message even
        # when it is passed by keyword (``ToolError(message=...)``).
        super().__init__(message)
        self.message = message

136
demo/tools/bash.py Normal file
View File

@@ -0,0 +1,136 @@
import asyncio
import os
from typing import ClassVar, Literal
from anthropic.types.beta import BetaToolBash20241022Param
from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult
class _BashSession:
    """A long-lived bash process that executes commands written to its stdin.

    After each command a sentinel string is echoed on both stdout and stderr,
    so ``run`` can tell where that command's output ends on each stream
    without waiting for the shell itself to exit.
    """

    _started: bool
    _process: asyncio.subprocess.Process

    command: str = "/bin/bash"
    # Kept for backward compatibility; the line-based reader in ``run`` does
    # not poll on a delay.
    _output_delay: float = 0.2  # seconds
    _timeout: float = 120.0  # seconds
    _sentinel: str = "<<exit>>"

    def __init__(self):
        self._started = False
        self._timed_out = False

    async def start(self):
        """Spawn the shell process (no-op if it is already running)."""
        if self._started:
            return

        # create_subprocess_shell always goes through the shell; passing
        # shell=False makes asyncio raise ValueError("shell must be True").
        self._process = await asyncio.create_subprocess_shell(
            self.command,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        self._started = True

    def stop(self):
        """Terminate the bash shell.

        Raises:
            ToolError: if the session was never started.
        """
        if not self._started:
            raise ToolError("Session has not started.")
        if self._process.returncode is not None:
            return
        self._process.terminate()

    async def run(self, command: str):
        """Execute a command in the bash shell.

        Returns:
            A CLIResult with the command's stripped stdout and stderr, or a
            ToolResult asking for a restart when the shell has already exited.

        Raises:
            ToolError: if the session was not started, or if the command (or a
                previous one) did not finish within ``_timeout`` seconds.
        """
        if not self._started:
            raise ToolError("Session has not started.")
        if self._process.returncode is not None:
            return ToolResult(
                system="tool must be restarted",
                error=f"bash has exited with returncode {self._process.returncode}",
            )
        if self._timed_out:
            raise ToolError(
                f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
            )

        # we know these are not None because we created the process with PIPEs
        assert self._process.stdin
        assert self._process.stdout
        assert self._process.stderr

        # Send the command, then mark the end of its output on BOTH streams.
        # Reading stderr to EOF would block until the shell itself exits.
        self._process.stdin.write(
            command.encode()
            + f"; echo '{self._sentinel}'; echo '{self._sentinel}' >&2\n".encode()
        )
        await self._process.stdin.drain()

        try:
            async with asyncio.timeout(self._timeout):
                # Drain both pipes concurrently so a chatty stderr cannot
                # stall the shell while we wait on stdout (and vice versa).
                output, error = await asyncio.gather(
                    self._read_until_sentinel(self._process.stdout),
                    self._read_until_sentinel(self._process.stderr),
                )
        except asyncio.TimeoutError:
            self._timed_out = True
            raise ToolError(
                f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
            ) from None

        return CLIResult(output=output.strip(), error=error.strip())

    async def _read_until_sentinel(self, stream: asyncio.StreamReader) -> str:
        """Collect decoded lines from ``stream`` up to the sentinel or EOF."""
        collected: list[str] = []
        while True:
            data = await stream.readline()
            if not data:
                # EOF: the shell closed this pipe (e.g. the command was `exit`).
                break
            line = data.decode()
            if self._sentinel in line:
                # Output lacking a trailing newline shares a line with the marker.
                collected.append(line.replace(self._sentinel, ""))
                break
            collected.append(line)
        return "".join(collected)
class BashTool(BaseAnthropicTool):
    """
    Tool exposing a persistent bash session to the agent.
    Its name and API type are fixed by Anthropic and must not be changed.
    """

    _session: _BashSession | None
    name: ClassVar[Literal["bash"]] = "bash"
    api_type: ClassVar[Literal["bash_20241022"]] = "bash_20241022"

    def __init__(self):
        self._session = None
        super().__init__()

    async def __call__(
        self, command: str | None = None, restart: bool = False, **kwargs
    ):
        """Run ``command`` in the session, or replace the session when ``restart`` is set.

        Raises:
            ToolError: when neither a restart nor a command was requested.
        """
        if restart:
            previous = self._session
            if previous:
                previous.stop()
            fresh = _BashSession()
            self._session = fresh
            await fresh.start()
            return ToolResult(system="tool has been restarted.")

        # Lazily create the shell on first use (this happens before the
        # command is validated, so a session exists even if the call fails).
        if self._session is None:
            created = _BashSession()
            self._session = created
            await created.start()

        if command is None:
            raise ToolError("no command provided.")
        return await self._session.run(command)

    def to_params(self) -> BetaToolBash20241022Param:
        """Describe this tool for the API request."""
        params = {"type": self.api_type, "name": self.name}
        return params

34
demo/tools/collection.py Normal file
View File

@@ -0,0 +1,34 @@
"""Collection classes for managing multiple tools."""
from typing import Any
from anthropic.types.beta import BetaToolUnionParam
from .base import (
BaseAnthropicTool,
ToolError,
ToolFailure,
ToolResult,
)
class ToolCollection:
    """Registry of Anthropic-defined tools with name-based dispatch."""

    def __init__(self, *tools: BaseAnthropicTool):
        self.tools = tools
        # Index every tool by the name it reports in its API parameters.
        by_name = {}
        for tool in tools:
            by_name[tool.to_params()["name"]] = tool
        self.tool_map = by_name

    def to_params(self) -> list[BetaToolUnionParam]:
        """Return the API parameter blocks of all tools, in registration order."""
        params = []
        for tool in self.tools:
            params.append(tool.to_params())
        return params

    async def run(self, *, name: str, tool_input: dict[str, Any]) -> ToolResult:
        """Invoke the tool called ``name``; failures are returned, not raised.

        An unknown name or a ``ToolError`` raised by the tool is converted
        into a ``ToolFailure`` carrying the error text.
        """
        tool = self.tool_map.get(name)
        if not tool:
            return ToolFailure(error=f"Tool {name} is invalid")
        try:
            result = await tool(**tool_input)
        except ToolError as e:
            return ToolFailure(error=e.message)
        else:
            return result

View File

@@ -0,0 +1,27 @@
"""
Styled speaker labels (Markdown bold + inline HTML) used in the chat view.
"""

# CSS colour for each letter of the "ShowUI" wordmark.
colors = {
    'S': 'rgb(106, 158, 210)',
    'h': 'rgb(111, 163, 82)',
    'o': 'rgb(209, 100, 94)',
    'w': 'rgb(238, 171, 106)',
    'U': 'rgb(0, 0, 0)',
    'I': 'rgb(0, 0, 0)',
}


def _tinted(letter):
    """Wrap one letter in a <span> carrying its colour (black when unmapped)."""
    shade = colors.get(letter, "black")
    return f'<span style="color:{shade}">{letter}</span>'


# Bold "ShowUI" with every letter individually coloured, i.e.
# **<span style="color:rgb(106, 158, 210)">S</span>...<span style="color:rgb(0, 0, 0)">I</span>**
colorful_text_showui = "**" + "".join(map(_tinted, "ShowUI")) + "**"

colorful_text_vlm = "**OmniParser Agent**"
colorful_text_user = "**User**"

519
demo/tools/computer.py Normal file
View File

@@ -0,0 +1,519 @@
import subprocess
import platform
import pyautogui
import asyncio
import base64
import os
import time
if platform.system() == "Darwin":
import Quartz # uncomment this line if you are on macOS
from enum import StrEnum
from pathlib import Path
from typing import Literal, TypedDict
from uuid import uuid4
from screeninfo import get_monitors
from PIL import ImageGrab, Image
from functools import partial
from anthropic.types.beta import BetaToolComputerUse20241022Param
from .base import BaseAnthropicTool, ToolError, ToolResult
from .run import run
# Directory (relative path) into which ComputerTool.screenshot writes its PNG files.
OUTPUT_DIR = "./tmp/outputs"
# Delay between typed characters, in milliseconds; divided by 1000 where it is
# passed to pyautogui as a per-key interval.
TYPING_DELAY_MS = 12
# NOTE(review): not referenced by any code visible in this file; presumably a
# chunk length for grouped typing — confirm before relying on it.
TYPING_GROUP_SIZE = 50
# Every action name accepted by ComputerTool.__call__.
# "left_press" (press, hold for a second, release) is handled by the tool's
# click branch, so it is listed here to keep the type in sync with the
# implementation.
Action = Literal[
    "key",
    "type",
    "mouse_move",
    "left_click",
    "left_click_drag",
    "right_click",
    "middle_click",
    "double_click",
    "screenshot",
    "cursor_position",
    "left_press",
]
class Resolution(TypedDict):
    """Width and height of a display or scaling target, in pixels."""

    width: int
    height: int


# Resolutions that screenshots and coordinates are scaled to, keyed by the
# name of the display standard.
MAX_SCALING_TARGETS: dict[str, Resolution] = {
    "XGA": {"width": 1024, "height": 768},  # 4:3
    "WXGA": {"width": 1280, "height": 800},  # 16:10
    "FWXGA": {"width": 1366, "height": 768},  # ~16:9
}
class ScalingSource(StrEnum):
COMPUTER = "computer"
API = "api"
class ComputerToolOptions(TypedDict):
display_height_px: int
display_width_px: int
display_number: int | None
def chunks(s: str, chunk_size: int) -> list[str]:
    """Split ``s`` into consecutive pieces of at most ``chunk_size`` characters."""
    pieces = []
    for start in range(0, len(s), chunk_size):
        stop = start + chunk_size
        pieces.append(s[start:stop])
    return pieces
def get_screen_details():
    """Describe every monitor, ordered left to right.

    Returns:
        A tuple ``(descriptions, primary_index)``: one human-readable string
        per monitor, and the index (in left-to-right order) of the primary
        monitor, or 0 when none is flagged as primary.
    """
    ordered = sorted(get_monitors(), key=lambda s: s.x)
    last = len(ordered) - 1

    descriptions = []
    primary_index = 0
    for idx, monitor in enumerate(ordered):
        # The leftmost check wins, so a single monitor is reported as "Left".
        layout = "Left" if idx == 0 else ("Right" if idx == last else "Center")

        position = "Primary" if monitor.is_primary else "Secondary"
        if monitor.is_primary:
            primary_index = idx

        descriptions.append(
            f"Screen {idx + 1}: {monitor.width}x{monitor.height}, {layout}, {position}"
        )

    return descriptions, primary_index
class ComputerTool(BaseAnthropicTool):
    """
    A tool that allows the agent to interact with the screen, keyboard, and mouse of the current computer.
    Adapted for Windows using 'pyautogui'.

    Monitor geometry is read through screeninfo on Windows/Linux and through
    Quartz on macOS.
    """

    name: Literal["computer"] = "computer"
    api_type: Literal["computer_20241022"] = "computer_20241022"
    width: int
    height: int
    display_num: int | None

    _screenshot_delay = 2.0
    _scaling_enabled = True

    @property
    def options(self) -> ComputerToolOptions:
        """Display options advertised to the API (screen size after down-scaling)."""
        width, height = self.scale_coordinates(
            ScalingSource.COMPUTER, self.width, self.height
        )
        return {
            "display_width_px": width,
            "display_height_px": height,
            "display_number": self.display_num,
        }

    def to_params(self) -> BetaToolComputerUse20241022Param:
        """Describe this tool (name, type and display options) for the API request."""
        return {"name": self.name, "type": self.api_type, **self.options}

    def __init__(self, selected_screen: int = 0, is_scaling: bool = False):
        """Bind the tool to one monitor.

        Args:
            selected_screen: Index of the monitor in left-to-right order
                (Windows/macOS). Other platforms always use the first monitor
                reported by screeninfo.
            is_scaling: Whether incoming API coordinates are scaled up to real
                screen pixels before being used.

        Raises:
            IndexError: if ``selected_screen`` is out of range.
        """
        super().__init__()

        self.display_num = None
        self.offset_x = 0
        self.offset_y = 0
        self.selected_screen = selected_screen
        self.is_scaling = is_scaling
        self.width, self.height = self.get_screen_size()

        # Path to cliclick
        self.cliclick = "cliclick"
        self.key_conversion = {"Page_Down": "pagedown",
                               "Page_Up": "pageup",
                               "Super_L": "win",
                               "Escape": "esc"}

        screen = self._current_screen()
        self.offset_x = screen["x"]
        self.offset_y = screen["y"]
        self.bbox = self._bbox_of(screen)

    def _current_screen(self) -> dict:
        """Return the selected monitor as a dict with x, y, width and height.

        Windows and macOS honour ``self.selected_screen`` (an index into the
        monitors sorted by x position; ``None`` selects the primary monitor).
        Every other platform uses the first monitor reported by screeninfo.

        Raises:
            IndexError: if ``selected_screen`` is out of range.
            RuntimeError: if the primary monitor is requested but none exists.
        """
        system = platform.system()
        if system == "Windows":
            screens = [
                {
                    "x": monitor.x,
                    "y": monitor.y,
                    "width": monitor.width,
                    "height": monitor.height,
                    "is_primary": monitor.is_primary,
                }
                for monitor in get_monitors()
            ]
        elif system == "Darwin":  # macOS
            max_displays = 32  # Maximum number of displays to handle
            active_displays = Quartz.CGGetActiveDisplayList(max_displays, None, None)[1]
            screens = []
            for display_id in active_displays:
                bounds = Quartz.CGDisplayBounds(display_id)
                screens.append({
                    'id': display_id,
                    'x': int(bounds.origin.x),
                    'y': int(bounds.origin.y),
                    'width': int(bounds.size.width),
                    'height': int(bounds.size.height),
                    'is_primary': Quartz.CGDisplayIsMain(display_id),  # Check if this is the primary display
                })
        else:  # Linux or other OS: single-screen assumption
            monitor = get_monitors()[0]
            return {
                "x": monitor.x,
                "y": monitor.y,
                "width": monitor.width,
                "height": monitor.height,
            }

        if self.selected_screen is None:
            primary = next((s for s in screens if s["is_primary"]), None)
            if primary is None:
                raise RuntimeError("No primary monitor found.")
            return primary
        if self.selected_screen < 0 or self.selected_screen >= len(screens):
            raise IndexError("Invalid screen index.")
        # Sort screens by x position to arrange from left to right
        return sorted(screens, key=lambda s: s["x"])[self.selected_screen]

    @staticmethod
    def _bbox_of(screen: dict) -> tuple[int, int, int, int]:
        """Return the (left, top, right, bottom) bounding box of a screen dict."""
        return (
            screen["x"],
            screen["y"],
            screen["x"] + screen["width"],
            screen["y"] + screen["height"],
        )

    async def __call__(
        self,
        *,
        action: Action,
        text: str | None = None,
        coordinate: tuple[int, int] | None = None,
        **kwargs,
    ):
        """Perform one mouse, keyboard or screenshot action.

        Args:
            action: Name of the action to perform.
            text: Key combination (``key``) or literal text (``type``).
            coordinate: Target ``(x, y)`` for ``mouse_move`` / ``left_click_drag``.

        Returns:
            A ToolResult describing what was done; ``type`` and ``screenshot``
            also attach a base64-encoded screenshot.

        Raises:
            ToolError: on an unknown action or an argument that the action
                does not accept or is missing.
        """
        print(f"action: {action}, text: {text}, coordinate: {coordinate}, is_scaling: {self.is_scaling}")

        if action in ("mouse_move", "left_click_drag"):
            if coordinate is None:
                raise ToolError(f"coordinate is required for {action}")
            if text is not None:
                raise ToolError(f"text is not accepted for {action}")
            if not isinstance(coordinate, (list, tuple)) or len(coordinate) != 2:
                raise ToolError(f"{coordinate} must be a tuple of length 2")
            # Negative values are accepted on purpose (monitors left of / above
            # the primary one have negative coordinates), so only the type is checked.
            if not all(isinstance(i, int) for i in coordinate):
                raise ToolError(f"{coordinate} must be a tuple of ints")

            if self.is_scaling:
                x, y = self.scale_coordinates(
                    ScalingSource.API, coordinate[0], coordinate[1]
                )
            else:
                x, y = coordinate

            # TODO - check whether self.offset_x / self.offset_y must be added here
            print(f"mouse move to {x}, {y}")

            if action == "mouse_move":
                pyautogui.moveTo(x, y)
                return ToolResult(output=f"Moved mouse to ({x}, {y})")

            # left_click_drag
            current_x, current_y = pyautogui.position()
            pyautogui.dragTo(x, y, duration=0.5)  # Adjust duration as needed
            return ToolResult(output=f"Dragged mouse from ({current_x}, {current_y}) to ({x}, {y})")

        if action in ("key", "type"):
            if text is None:
                raise ToolError(f"text is required for {action}")
            if coordinate is not None:
                raise ToolError(f"coordinate is not accepted for {action}")
            if not isinstance(text, str):
                # ToolError only takes a positional message.
                raise ToolError(f"{text} must be a string")

            if action == "key":
                # Translate each part of the combination once, then press in
                # order and release in reverse order.
                keys = [
                    self.key_conversion.get(part.strip(), part.strip()).lower()
                    for part in text.split('+')
                ]
                for key in keys:
                    pyautogui.keyDown(key)
                for key in reversed(keys):
                    pyautogui.keyUp(key)
                return ToolResult(output=f"Pressed keys: {text}")

            # type
            pyautogui.typewrite(text, interval=TYPING_DELAY_MS / 1000)  # Convert ms to seconds
            pyautogui.press('enter')
            screenshot_base64 = (await self.screenshot()).base64_image
            return ToolResult(output=text, base64_image=screenshot_base64)

        if action in (
            "left_click",
            "right_click",
            "double_click",
            "middle_click",
            "screenshot",
            "cursor_position",
            "left_press",
        ):
            if text is not None:
                raise ToolError(f"text is not accepted for {action}")
            if coordinate is not None:
                raise ToolError(f"coordinate is not accepted for {action}")

            if action == "screenshot":
                return await self.screenshot()
            if action == "cursor_position":
                x, y = pyautogui.position()
                x, y = self.scale_coordinates(ScalingSource.COMPUTER, x, y)
                return ToolResult(output=f"X={x},Y={y}")

            if action == "left_click":
                pyautogui.click()
            elif action == "right_click":
                pyautogui.rightClick()
            elif action == "middle_click":
                pyautogui.middleClick()
            elif action == "double_click":
                pyautogui.doubleClick()
            elif action == "left_press":
                # Press and hold for one second before releasing.
                pyautogui.mouseDown()
                time.sleep(1)
                pyautogui.mouseUp()
            return ToolResult(output=f"Performed {action}")

        raise ToolError(f"Invalid action: {action}")

    async def screenshot(self):
        """Take a screenshot of the current screen and return a ToolResult with the base64 encoded image.

        The capture is limited to the selected monitor, resized to
        ``self.target_dimension`` and written to OUTPUT_DIR as a PNG.

        Raises:
            IndexError: if ``selected_screen`` is out of range.
            ToolError: if the image file is missing after saving.
        """
        # Short pause before capturing; asyncio.sleep keeps the event loop free.
        await asyncio.sleep(1)

        output_dir = Path(OUTPUT_DIR)
        output_dir.mkdir(parents=True, exist_ok=True)
        path = output_dir / f"screenshot_{uuid4().hex}.png"

        screen = self._current_screen()

        # Pass all_screens explicitly rather than re-wrapping ImageGrab.grab in
        # a new functools.partial on every call, which mutated the shared
        # module and nested one wrapper deeper per screenshot.
        screenshot = ImageGrab.grab(bbox=self._bbox_of(screen), all_screens=True)

        # Set offsets (for potential future use)
        self.offset_x = screen["x"]
        self.offset_y = screen["y"]

        # target_dimension is otherwise only set by scale_coordinates; fall
        # back to WXGA (padding to 16:10 first) when that has not run yet.
        # NOTE(review): padding only happens on this fallback path — confirm
        # that later screenshots are meant to be resized without padding.
        if not hasattr(self, 'target_dimension'):
            screenshot = self.padding_image(screenshot)
            self.target_dimension = MAX_SCALING_TARGETS["WXGA"]

        print(f"offset is {self.offset_x}, {self.offset_y}")
        print(f"target_dimension is {self.target_dimension}")
        screenshot = screenshot.resize((self.target_dimension["width"], self.target_dimension["height"]))

        # Save the screenshot
        screenshot.save(str(path))

        if path.exists():
            return ToolResult(base64_image=base64.b64encode(path.read_bytes()).decode())
        raise ToolError(f"Failed to take screenshot: {path} does not exist.")

    def padding_image(self, screenshot):
        """Pad the screenshot to 16:10 aspect ratio, when the aspect ratio is not 16:10.

        The original image is pasted at the top-left of a white canvas whose
        width is ``height * 16 // 10``.
        """
        _, height = screenshot.size
        new_width = height * 16 // 10

        padding_image = Image.new("RGB", (new_width, height), (255, 255, 255))
        # padding to top left
        padding_image.paste(screenshot, (0, 0))
        return padding_image

    async def shell(self, command: str, take_screenshot=True) -> ToolResult:
        """Run a shell command and return the output, error, and optionally a screenshot."""
        _, stdout, stderr = await run(command)
        base64_image = None

        if take_screenshot:
            # delay to let things settle before taking a screenshot
            await asyncio.sleep(self._screenshot_delay)
            base64_image = (await self.screenshot()).base64_image

        return ToolResult(output=stdout, error=stderr, base64_image=base64_image)

    def scale_coordinates(self, source: ScalingSource, x: int, y: int):
        """Scale coordinates to a target maximum resolution.

        Args:
            source: ``ScalingSource.API`` scales up to real screen pixels;
                anything else scales down to the target resolution.
            x: Horizontal coordinate.
            y: Vertical coordinate.

        Returns:
            The scaled ``(x, y)``, or the input unchanged when scaling is disabled.

        Raises:
            ToolError: if API coordinates exceed the screen size.
        """
        if not self._scaling_enabled:
            return x, y
        ratio = self.width / self.height
        target_dimension = None

        for dimension in MAX_SCALING_TARGETS.values():
            # allow some error in the aspect ratio - not all ratios are exactly 16:9
            if abs(dimension["width"] / dimension["height"] - ratio) < 0.02:
                if dimension["width"] < self.width:
                    target_dimension = dimension
                    self.target_dimension = target_dimension
                break

        if target_dimension is None:
            # TODO: currently we force the target to be WXGA (16:10), when it cannot find a match
            target_dimension = MAX_SCALING_TARGETS["WXGA"]
            self.target_dimension = MAX_SCALING_TARGETS["WXGA"]

        # Below 1 when the screen is larger than the target resolution.
        x_scaling_factor = target_dimension["width"] / self.width
        y_scaling_factor = target_dimension["height"] / self.height
        if source == ScalingSource.API:
            if x > self.width or y > self.height:
                raise ToolError(f"Coordinates {x}, {y} are out of bounds")
            # scale up
            return round(x / x_scaling_factor), round(y / y_scaling_factor)
        # scale down
        return round(x * x_scaling_factor), round(y * y_scaling_factor)

    def get_screen_size(self):
        """Return ``(width, height)`` of the selected monitor in pixels.

        Raises:
            IndexError: if ``selected_screen`` is out of range.
            RuntimeError: if the primary monitor is requested but none exists.
        """
        screen = self._current_screen()
        return screen["width"], screen["height"]

    def get_mouse_position(self):
        """Return the mouse position as ``(x, y)`` using AppKit (macOS only)."""
        # TODO: enhance this func
        from AppKit import NSEvent

        loc = NSEvent.mouseLocation()
        # Adjust for different coordinate system
        return int(loc.x), int(self.height - loc.y)

    def map_keys(self, text: str):
        """Map text to cliclick key codes if necessary."""
        # For simplicity, return text as is
        # Implement mapping if special keys are needed
        return text

290
demo/tools/edit.py Normal file
View File

@@ -0,0 +1,290 @@
from collections import defaultdict
from pathlib import Path
from typing import Literal, get_args
from anthropic.types.beta import BetaToolTextEditor20241022Param
from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult
from .run import maybe_truncate, run
# Names of the editor commands that EditTool.__call__ dispatches on.
Command = Literal[
    "view",
    "create",
    "str_replace",
    "insert",
    "undo_edit",
]
# Number of context lines shown before and after an edit in the snippet that
# str_replace / insert return.
SNIPPET_LINES: int = 4
class EditTool(BaseAnthropicTool):
"""
An filesystem editor tool that allows the agent to view, create, and edit files.
The tool parameters are defined by Anthropic and are not editable.
"""
api_type: Literal["text_editor_20241022"] = "text_editor_20241022"
name: Literal["str_replace_editor"] = "str_replace_editor"
_file_history: dict[Path, list[str]]
def __init__(self):
self._file_history = defaultdict(list)
super().__init__()
def to_params(self) -> BetaToolTextEditor20241022Param:
return {
"name": self.name,
"type": self.api_type,
}
async def __call__(
self,
*,
command: Command,
path: str,
file_text: str | None = None,
view_range: list[int] | None = None,
old_str: str | None = None,
new_str: str | None = None,
insert_line: int | None = None,
**kwargs,
):
_path = Path(path)
self.validate_path(command, _path)
if command == "view":
return await self.view(_path, view_range)
elif command == "create":
if not file_text:
raise ToolError("Parameter `file_text` is required for command: create")
self.write_file(_path, file_text)
self._file_history[_path].append(file_text)
return ToolResult(output=f"File created successfully at: {_path}")
elif command == "str_replace":
if not old_str:
raise ToolError(
"Parameter `old_str` is required for command: str_replace"
)
return self.str_replace(_path, old_str, new_str)
elif command == "insert":
if insert_line is None:
raise ToolError(
"Parameter `insert_line` is required for command: insert"
)
if not new_str:
raise ToolError("Parameter `new_str` is required for command: insert")
return self.insert(_path, insert_line, new_str)
elif command == "undo_edit":
return self.undo_edit(_path)
raise ToolError(
f'Unrecognized command {command}. The allowed commands for the {self.name} tool are: {", ".join(get_args(Command))}'
)
def validate_path(self, command: str, path: Path):
"""
Check that the path/command combination is valid.
"""
# Check if its an absolute path
if not path.is_absolute():
suggested_path = Path("") / path
raise ToolError(
f"The path {path} is not an absolute path, it should start with `/`. Maybe you meant {suggested_path}?"
)
# Check if path exists
if not path.exists() and command != "create":
raise ToolError(
f"The path {path} does not exist. Please provide a valid path."
)
if path.exists() and command == "create":
raise ToolError(
f"File already exists at: {path}. Cannot overwrite files using command `create`."
)
# Check if the path points to a directory
if path.is_dir():
if command != "view":
raise ToolError(
f"The path {path} is a directory and only the `view` command can be used on directories"
)
async def view(self, path: Path, view_range: list[int] | None = None):
"""Implement the view command"""
if path.is_dir():
if view_range:
raise ToolError(
"The `view_range` parameter is not allowed when `path` points to a directory."
)
_, stdout, stderr = await run(
rf"find {path} -maxdepth 2 -not -path '*/\.*'"
)
if not stderr:
stdout = f"Here's the files and directories up to 2 levels deep in {path}, excluding hidden items:\n{stdout}\n"
return CLIResult(output=stdout, error=stderr)
file_content = self.read_file(path)
init_line = 1
if view_range:
if len(view_range) != 2 or not all(isinstance(i, int) for i in view_range):
raise ToolError(
"Invalid `view_range`. It should be a list of two integers."
)
file_lines = file_content.split("\n")
n_lines_file = len(file_lines)
init_line, final_line = view_range
if init_line < 1 or init_line > n_lines_file:
raise ToolError(
f"Invalid `view_range`: {view_range}. It's first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}"
)
if final_line > n_lines_file:
raise ToolError(
f"Invalid `view_range`: {view_range}. It's second element `{final_line}` should be smaller than the number of lines in the file: `{n_lines_file}`"
)
if final_line != -1 and final_line < init_line:
raise ToolError(
f"Invalid `view_range`: {view_range}. It's second element `{final_line}` should be larger or equal than its first `{init_line}`"
)
if final_line == -1:
file_content = "\n".join(file_lines[init_line - 1 :])
else:
file_content = "\n".join(file_lines[init_line - 1 : final_line])
return CLIResult(
output=self._make_output(file_content, str(path), init_line=init_line)
)
def str_replace(self, path: Path, old_str: str, new_str: str | None):
"""Implement the str_replace command, which replaces old_str with new_str in the file content"""
# Read the file content
file_content = self.read_file(path).expandtabs()
old_str = old_str.expandtabs()
new_str = new_str.expandtabs() if new_str is not None else ""
# Check if old_str is unique in the file
occurrences = file_content.count(old_str)
if occurrences == 0:
raise ToolError(
f"No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}."
)
elif occurrences > 1:
file_content_lines = file_content.split("\n")
lines = [
idx + 1
for idx, line in enumerate(file_content_lines)
if old_str in line
]
raise ToolError(
f"No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique"
)
# Replace old_str with new_str
new_file_content = file_content.replace(old_str, new_str)
# Write the new content to the file
self.write_file(path, new_file_content)
# Save the content to history
self._file_history[path].append(file_content)
# Create a snippet of the edited section
replacement_line = file_content.split(old_str)[0].count("\n")
start_line = max(0, replacement_line - SNIPPET_LINES)
end_line = replacement_line + SNIPPET_LINES + new_str.count("\n")
snippet = "\n".join(new_file_content.split("\n")[start_line : end_line + 1])
# Prepare the success message
success_msg = f"The file {path} has been edited. "
success_msg += self._make_output(
snippet, f"a snippet of {path}", start_line + 1
)
success_msg += "Review the changes and make sure they are as expected. Edit the file again if necessary."
return CLIResult(output=success_msg)
def insert(self, path: Path, insert_line: int, new_str: str):
"""Implement the insert command, which inserts new_str at the specified line in the file content."""
file_text = self.read_file(path).expandtabs()
new_str = new_str.expandtabs()
file_text_lines = file_text.split("\n")
n_lines_file = len(file_text_lines)
if insert_line < 0 or insert_line > n_lines_file:
raise ToolError(
f"Invalid `insert_line` parameter: {insert_line}. It should be within the range of lines of the file: {[0, n_lines_file]}"
)
new_str_lines = new_str.split("\n")
new_file_text_lines = (
file_text_lines[:insert_line]
+ new_str_lines
+ file_text_lines[insert_line:]
)
snippet_lines = (
file_text_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
+ new_str_lines
+ file_text_lines[insert_line : insert_line + SNIPPET_LINES]
)
new_file_text = "\n".join(new_file_text_lines)
snippet = "\n".join(snippet_lines)
self.write_file(path, new_file_text)
self._file_history[path].append(file_text)
success_msg = f"The file {path} has been edited. "
success_msg += self._make_output(
snippet,
"a snippet of the edited file",
max(1, insert_line - SNIPPET_LINES + 1),
)
success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary."
return CLIResult(output=success_msg)
def undo_edit(self, path: Path):
"""Implement the undo_edit command."""
if not self._file_history[path]:
raise ToolError(f"No edit history found for {path}.")
old_text = self._file_history[path].pop()
self.write_file(path, old_text)
return CLIResult(
output=f"Last edit to {path} undone successfully. {self._make_output(old_text, str(path))}"
)
def read_file(self, path: Path):
"""Read the content of a file from a given path; raise a ToolError if an error occurs."""
try:
return path.read_text()
except Exception as e:
raise ToolError(f"Ran into {e} while trying to read {path}") from None
def write_file(self, path: Path, file: str):
    """Write the text `file` to `path`; any failure is re-raised as a ToolError."""
    try:
        path.write_text(file)
    except Exception as err:
        # `from None` hides the original traceback; the message carries the cause.
        failure = f"Ran into {err} while trying to write to {path}"
        raise ToolError(failure) from None
def _make_output(
    self,
    file_content: str,
    file_descriptor: str,
    init_line: int = 1,
    expand_tabs: bool = True,
):
    """Render `file_content` as a `cat -n` style listing for CLI output.

    The content is first passed through `maybe_truncate`, optionally
    tab-expanded, and then every line is prefixed with its line number
    (starting at `init_line`, right-aligned to width 6) and a tab. The
    listing is preceded by a header line naming `file_descriptor`.
    """
    text = maybe_truncate(file_content)
    if expand_tabs:
        text = text.expandtabs()

    numbered_lines = [
        f"{line_number:6}\t{line}"
        for line_number, line in enumerate(text.split("\n"), start=init_line)
    ]
    header = f"Here's the result of running `cat -n` on {file_descriptor}:\n"
    return header + "\n".join(numbered_lines) + "\n"

42
demo/tools/run.py Normal file
View File

@@ -0,0 +1,42 @@
"""Utility to run shell commands asynchronously with a timeout."""
import asyncio
TRUNCATED_MESSAGE: str = "<response clipped><NOTE>To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for.</NOTE>"
MAX_RESPONSE_LEN: int = 16000
def maybe_truncate(content: str, truncate_after: int | None = MAX_RESPONSE_LEN):
"""Truncate content and append a notice if content exceeds the specified length."""
return (
content
if not truncate_after or len(content) <= truncate_after
else content[:truncate_after] + TRUNCATED_MESSAGE
)
async def run(
    cmd: str,
    timeout: float | None = 120.0,  # seconds
    truncate_after: int | None = MAX_RESPONSE_LEN,
):
    """Run a shell command asynchronously with a timeout.

    Args:
        cmd: Command line handed to the shell via
            `asyncio.create_subprocess_shell`.
        timeout: Maximum number of seconds to wait for the command to finish;
            `None` waits indefinitely.
        truncate_after: Character limit forwarded to `maybe_truncate` for
            both stdout and stderr.

    Returns:
        A tuple `(returncode, stdout, stderr)`. A `None` return code is
        reported as 0; both streams are decoded and possibly truncated.

    Raises:
        TimeoutError: If the command does not finish within `timeout`
            seconds. The process is killed before the error is raised.
    """
    process = await asyncio.create_subprocess_shell(
        cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
    )
    try:
        stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
    except asyncio.TimeoutError as exc:
        try:
            process.kill()
        except ProcessLookupError:
            # The process already exited on its own; nothing left to kill.
            pass
        raise TimeoutError(
            f"Command '{cmd}' timed out after {timeout} seconds"
        ) from exc

    # Commands may emit bytes that are not valid UTF-8 (e.g. `cat` on a binary
    # file). Replace undecodable bytes instead of crashing with
    # UnicodeDecodeError, so the caller still receives the command's output.
    return (
        process.returncode or 0,
        maybe_truncate(stdout.decode(errors="replace"), truncate_after=truncate_after),
        maybe_truncate(stderr.decode(errors="replace"), truncate_after=truncate_after),
    )

View File

@@ -0,0 +1,185 @@
import subprocess
import base64
from pathlib import Path
from PIL import ImageGrab
from uuid import uuid4
from screeninfo import get_monitors
import platform
if platform.system() == "Darwin":
import Quartz # uncomment this line if you are on macOS
from PIL import ImageGrab
from functools import partial
from .base import BaseAnthropicTool, ToolError, ToolResult
# Directory (relative to the current working directory) in which
# get_screenshot creates its `screenshot_<hex>.png` files.
OUTPUT_DIR = "./tmp/outputs"
def get_screenshot(selected_screen: int = 0, resize: bool = True, target_width: int = 1920, target_height: int = 1080):
    """Take a screenshot of the selected screen and return it with the path of the saved PNG.

    Args:
        selected_screen: Index of the monitor to capture on Windows and macOS,
            counted left to right by x position. Ignored on other platforms,
            where the first monitor reported by `get_monitors()` is used.
        resize: When true, the returned image is resized to
            `(target_width, target_height)`.
        target_width: Width used when `resize` is true.
        target_height: Height used when `resize` is true.

    Returns:
        A tuple `(screenshot, path)`: the image grabbed through
        `ImageGrab.grab` (resized if requested) and the `Path` of the PNG
        file inside `OUTPUT_DIR`.

    Raises:
        IndexError: If `selected_screen` is out of range (Windows/macOS).
        RuntimeError: If the screen resolution lookup fails on Linux.
        ToolError: If no screenshot file exists at `path` afterwards.
    """
    import os  # kept function-local, as in the original implementation

    output_dir = Path(OUTPUT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)
    path = output_dir / f"screenshot_{uuid4().hex}.png"

    # Make ImageGrab.grab default to all_screens=True. Patch only once:
    # re-wrapping on every call would nest partials without bound.
    current_grab = ImageGrab.grab
    if not (isinstance(current_grab, partial) and current_grab.keywords.get("all_screens")):
        ImageGrab.grab = partial(current_grab, all_screens=True)

    system = platform.system()
    if system == "Windows":
        # Windows: use screeninfo, monitors ordered left to right.
        monitors = sorted(get_monitors(), key=lambda m: m.x)
        if selected_screen < 0 or selected_screen >= len(monitors):
            raise IndexError("Invalid screen index.")
        monitor = monitors[selected_screen]
        bbox = (monitor.x, monitor.y, monitor.x + monitor.width, monitor.y + monitor.height)
    elif system == "Darwin":
        # macOS: use Quartz display bounds, displays ordered left to right.
        max_displays = 32  # Maximum number of displays to handle
        active_displays = Quartz.CGGetActiveDisplayList(max_displays, None, None)[1]
        screens = []
        for display_id in active_displays:
            bounds = Quartz.CGDisplayBounds(display_id)
            screens.append({
                'x': int(bounds.origin.x),
                'y': int(bounds.origin.y),
                'width': int(bounds.size.width),
                'height': int(bounds.size.height),
            })
        screens.sort(key=lambda s: s['x'])
        if selected_screen < 0 or selected_screen >= len(screens):
            raise IndexError("Invalid screen index.")
        screen = screens[selected_screen]
        bbox = (screen['x'], screen['y'], screen['x'] + screen['width'], screen['y'] + screen['height'])
    else:  # Linux or other OS
        try:
            monitor = get_monitors()[0]
            bbox = (0, 0, monitor.width, monitor.height)  # Assuming single primary screen for simplicity
        except subprocess.CalledProcessError as exc:
            raise RuntimeError("Failed to get screen resolution on Linux.") from exc

    # Take screenshot using the bounding box
    screenshot = ImageGrab.grab(bbox=bbox)

    uses_scrot = system not in ("Windows", "Darwin")
    if uses_scrot:
        # On Linux the PNG on disk is produced by `scrot`, optionally pointed
        # at the X display given by the DISPLAY_NUM environment variable.
        display_num = os.getenv("DISPLAY_NUM")
        display_prefix = f"DISPLAY=:{int(display_num)} " if display_num is not None else ""
        subprocess.run(f"{display_prefix}scrot -p {path}", shell=True, capture_output=True)

    if resize:
        screenshot = screenshot.resize((target_width, target_height))

    if not uses_scrot:
        # `scrot` is not available on Windows/macOS, so nothing would ever be
        # written to `path` there; save the grabbed image instead.
        screenshot.save(str(path))

    if path.exists():
        return screenshot, path
    raise ToolError(f"Failed to take screenshot: {path} does not exist.")
def _get_screen_size(selected_screen: int | None = 0):
    """Return the `(width, height)` of a monitor.

    Args:
        selected_screen: On Windows and macOS, the index of the monitor
            counted left to right by x position, or `None` for the primary
            monitor. Ignored on other platforms, where the first monitor
            reported by `get_monitors()` is used.

    Raises:
        IndexError: If `selected_screen` is out of range (Windows/macOS).
        RuntimeError: If no primary monitor is found (Windows/macOS), or if
            the resolution lookup fails on Linux.
    """
    system = platform.system()

    if system == "Windows":
        # Use screeninfo to enumerate monitors on Windows.
        screens = get_monitors()
        if selected_screen is None:
            primary_monitor = next((m for m in screens if m.is_primary), None)
            if primary_monitor is None:
                # Same error as the macOS branch, instead of failing with an
                # AttributeError on `None.width`.
                raise RuntimeError("No primary monitor found.")
            return primary_monitor.width, primary_monitor.height
        if selected_screen < 0 or selected_screen >= len(screens):
            raise IndexError("Invalid screen index.")
        # Sort screens by x position to arrange from left to right
        screen = sorted(screens, key=lambda s: s.x)[selected_screen]
        return screen.width, screen.height

    if system == "Darwin":
        # macOS part using Quartz to get screen information
        max_displays = 32  # Maximum number of displays to handle
        active_displays = Quartz.CGGetActiveDisplayList(max_displays, None, None)[1]
        # Get the display bounds (resolution) for each active display
        screens = []
        for display_id in active_displays:
            bounds = Quartz.CGDisplayBounds(display_id)
            screens.append({
                'id': display_id,
                'x': int(bounds.origin.x),
                'y': int(bounds.origin.y),
                'width': int(bounds.size.width),
                'height': int(bounds.size.height),
                'is_primary': Quartz.CGDisplayIsMain(display_id)  # Check if this is the primary display
            })
        if selected_screen is None:
            # Find the primary monitor
            primary_monitor = next((screen for screen in screens if screen['is_primary']), None)
            if primary_monitor:
                return primary_monitor['width'], primary_monitor['height']
            raise RuntimeError("No primary monitor found.")
        if selected_screen < 0 or selected_screen >= len(screens):
            raise IndexError("Invalid screen index.")
        # Sort screens by x position to arrange from left to right
        screen = sorted(screens, key=lambda s: s['x'])[selected_screen]
        return screen['width'], screen['height']

    # Linux or other OS
    try:
        screen = get_monitors()[0]
        return screen.width, screen.height
    except subprocess.CalledProcessError as exc:
        raise RuntimeError("Failed to get screen resolution on Linux.") from exc