├── .python-version ├── computer_use_modal ├── sandbox │ ├── __init__.py │ ├── io.py │ ├── bash_manager.py │ ├── sandbox_manager.py │ └── edit_manager.py ├── server │ ├── __init__.py │ ├── prompts.py │ ├── server.py │ └── messages.py ├── tools │ ├── __init__.py │ ├── edit │ │ ├── __init__.py │ │ ├── edit.py │ │ └── types.py │ ├── computer │ │ ├── __init__.py │ │ ├── types.py │ │ └── computer.py │ ├── bash.py │ └── base.py ├── vnd │ ├── __init__.py │ └── anthropic │ │ ├── __init__.py │ │ └── tools │ │ ├── __init__.py │ │ ├── edit.py │ │ ├── shared.py │ │ └── computer.py ├── __init__.py ├── __pycache__ │ └── __init__.cpython-312.pyc ├── demo.py ├── app.py └── streamlit.py ├── demo.png ├── pyproject.toml ├── README.md └── .gitignore /.python-version: -------------------------------------------------------------------------------- 1 | 3.12 2 | -------------------------------------------------------------------------------- /computer_use_modal/sandbox/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /computer_use_modal/server/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /computer_use_modal/tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /computer_use_modal/vnd/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /computer_use_modal/vnd/anthropic/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /computer_use_modal/vnd/anthropic/tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /computer_use_modal/tools/edit/__init__.py: -------------------------------------------------------------------------------- 1 | from .edit import EditTool 2 | -------------------------------------------------------------------------------- /computer_use_modal/tools/computer/__init__.py: -------------------------------------------------------------------------------- 1 | from .computer import ComputerTool 2 | -------------------------------------------------------------------------------- /demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yasyf/anthropic-computer-use-modal/HEAD/demo.png -------------------------------------------------------------------------------- /computer_use_modal/__init__.py: -------------------------------------------------------------------------------- 1 | from .app import app 2 | from .sandbox.sandbox_manager import SandboxManager 3 | from .server.server import ComputerUseServer 4 | -------------------------------------------------------------------------------- /computer_use_modal/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yasyf/anthropic-computer-use-modal/HEAD/computer_use_modal/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "computer-use-modal" 3 | version = "0.1.2" 4 | description = "Anthropic Computer Use with Modal Sandboxes" 5 | readme = "README.md" 6 | requires-python = ">=3.12" 7 | dependencies = [ 8 | "annotated-types>=0.7.0", 9 | "anthropic>=0.37.1", 10 | "backoff>=2.2.1", 11 | "fuzzysearch>=0.7.3", 12 | "modal>=0.64.211", 13 | "pydantic>=2.9.2", 14 | "rich>=13.9.3", 15 | "uuid6>=2024.7.10", 16 | "wand>=0.6.13", 17 | ] 18 | 19 | [project.urls] 20 | Homepage = "https://musings.yasyf.com/improving-claude-computer-use/" 21 | Repository = "https://github.com/yasyf/anthropic-tool-use-modal.git" 22 | 23 | [tool.uv] 24 | dev-dependencies = [ 25 | "streamlit>=1.39.0", 26 | ] 27 | -------------------------------------------------------------------------------- /computer_use_modal/tools/edit/edit.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from anthropic.types.beta import BetaToolTextEditor20241022Param 4 | from pydantic import ValidationError 5 | 6 | from computer_use_modal.sandbox.edit_manager import EditSession, EditSessionManager 7 | from computer_use_modal.tools.base import BaseTool, ToolError 8 | from computer_use_modal.tools.edit.types import BaseEditRequest 9 | 10 | 11 | @dataclass(kw_only=True) 12 | class EditTool(BaseTool[BetaToolTextEditor20241022Param]): 13 | @property 14 | def options(self) -> BetaToolTextEditor20241022Param: 15 | return {"name": "str_replace_editor", "type": "text_editor_20241022"} 16 | 17 | async def __call__( 18 | self, 19 | /, 20 | **data, 21 | ): 22 | try: 23 | request = BaseEditRequest.parse(data) 24 | except ValidationError as e: 25 | raise ToolError(f"Invalid tool parameters:\n{e.json()}") from e 26 | return await (await self.edit_manager()).dispatch(request) 27 | 28 | async def edit_manager(self) -> EditSessionManager: 29 | return EditSessionManager( 30 | sandbox=self.manager, 31 | session=await EditSession.from_request_id(self.manager.request_id), 32 | ) 33 | -------------------------------------------------------------------------------- /computer_use_modal/demo.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import base64 3 | from uuid import uuid4 4 | 5 | from rich import print 6 | 7 | from computer_use_modal import ComputerUseServer, SandboxManager 8 | from computer_use_modal.app import app 9 | from computer_use_modal.tools.base import ToolResult 10 | 11 | 12 | @app.local_entrypoint() 13 | async def demo(request_id: str = uuid4().hex): 14 | sandbox = SandboxManager(request_id=request_id) 15 | print("[bold]Debug URLs:[/bold]", await sandbox.debug_urls.remote.aio()) 16 | 17 | server = ComputerUseServer() 18 | res = server.messages_create_gen.remote_gen.aio( 19 | request_id=request_id, 20 | user_messages=[ 21 | { 22 | "role": "user", 23 | "content": "What is the weather in San Francisco?", 24 | } 25 | ], 26 | ) 27 | async for msg in res: 28 | if isinstance(msg, ToolResult): 29 | if msg.base64_image: 30 | proc = await asyncio.create_subprocess_shell( 31 | "viu -", stdin=asyncio.subprocess.PIPE 32 | ) 33 | await proc.communicate(base64.b64decode(msg.base64_image)) 34 | await proc.wait() 35 | else: 36 | print("[bold]Tool Result:[/bold]", msg) 37 | elif isinstance(msg, dict) and msg["role"] == "assistant": 38 | print("[bold]Response:[/bold]", msg) 39 | -------------------------------------------------------------------------------- /computer_use_modal/tools/bash.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from anthropic.types.beta import BetaToolBash20241022Param 4 | 5 | from computer_use_modal.sandbox.bash_manager import BashSession 6 | from computer_use_modal.tools.base import BaseTool, ToolError, ToolResult 7 | 8 | 9 | @dataclass(kw_only=True) 10 | class BashTool(BaseTool[BetaToolBash20241022Param]): 11 | session: BashSession | None = None 12 | 13 | @property 14 | def options(self) -> BetaToolBash20241022Param: 15 | return {"name": "bash", "type": "bash_20241022"} 16 | 17 | async def __call__( 18 | self, 19 | /, 20 | command: str | None = None, 21 | restart: bool = False, 22 | ): 23 | if restart: 24 | if self.session is None: 25 | raise ToolError("No active bash session") 26 | await self.manager.end_bash_session.remote.aio(self.session) 27 | self.session = None 28 | return ToolResult(system="bash tool has been restarted") 29 | if not command: 30 | return ToolResult(system="no command provided") 31 | 32 | result = ToolResult() 33 | if self.session is None: 34 | self.session = await self.manager.start_bash_session.remote.aio() 35 | result += ToolResult(system="bash tool has been started") 36 | result += await self.manager.execute_bash_command.remote.aio( 37 | self.session, command 38 | ) 39 | return result 40 | -------------------------------------------------------------------------------- /computer_use_modal/vnd/anthropic/tools/edit.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | TRUNCATED_MESSAGE: str = "To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for." 4 | MAX_RESPONSE_LEN: int = 16000 5 | 6 | 7 | Command = Literal[ 8 | "view", 9 | "create", 10 | "str_replace", 11 | "insert", 12 | "undo_edit", 13 | ] 14 | 15 | 16 | def maybe_truncate(content: str, truncate_after: int | None = MAX_RESPONSE_LEN): 17 | """Truncate content and append a notice if content exceeds the specified length.""" 18 | return ( 19 | content 20 | if not truncate_after or len(content) <= truncate_after 21 | else content[:truncate_after] + TRUNCATED_MESSAGE 22 | ) 23 | 24 | 25 | def make_output( 26 | file_content: str, 27 | file_descriptor: str, 28 | init_line: int = 1, 29 | ): 30 | """Generate output for the CLI based on the content of a file.""" 31 | file_content = maybe_truncate(file_content) 32 | file_content = file_content.expandtabs() 33 | file_content = "\n".join( 34 | [ 35 | f"{i + init_line:6}\t{line}" 36 | for i, line in enumerate(file_content.split("\n")) 37 | ] 38 | ) 39 | return ( 40 | f"Here's the result of running `cat -n` on {file_descriptor}:\n" 41 | + file_content 42 | + "\n" 43 | ) 44 | -------------------------------------------------------------------------------- /computer_use_modal/vnd/anthropic/tools/shared.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field, fields, replace 2 | 3 | 4 | class ToolError(Exception): 5 | """Raised when a tool encounters an error.""" 6 | 7 | def __init__(self, message): 8 | self.message = message 9 | 10 | 11 | @dataclass(kw_only=True, frozen=True) 12 | class ToolResult: 13 | """Represents the result of a tool execution.""" 14 | 15 | output: str | None = None 16 | error: str | None = None 17 | base64_image: str | None = field(default=None, repr=False) 18 | system: str | None = None 19 | 20 | def __bool__(self): 21 | return any(getattr(self, field.name) for field in fields(self)) 22 | 23 | @staticmethod 24 | def combine_fields( 25 | field: str | None, other_field: str | None, concatenate: bool = True 26 | ): 27 | if field and other_field: 28 | if concatenate: 29 | return field + "\n" + other_field 30 | raise ValueError("Cannot combine tool results") 31 | return field or other_field 32 | 33 | def __add__(self, other: "ToolResult"): 34 | return self.__class__( 35 | output=self.combine_fields(self.output, other.output), 36 | error=self.combine_fields(self.error, other.error), 37 | base64_image=self.combine_fields( 38 | self.base64_image, other.base64_image, False 39 | ), 40 | system=self.combine_fields(self.system, other.system), 41 | ) 42 | 43 | def replace(self, **kwargs): 44 | """Returns a new ToolResult with the given fields replaced.""" 45 | return replace(self, **kwargs) 46 | -------------------------------------------------------------------------------- /computer_use_modal/app.py: -------------------------------------------------------------------------------- 1 | import base64 2 | 3 | from modal import App, Image, Secret 4 | 5 | MOUNT_PATH = "/mnt/nfs" 6 | 7 | app = App("anthropic-computer-use-modal") 8 | 9 | FIREFOX_PIN = base64.b64encode( 10 | """ 11 | Package: * 12 | Pin: release o=LP-PPA-mozillateam 13 | Pin-Priority: 1001 14 | 15 | Package: firefox 16 | Pin: version 1:1snap* 17 | Pin-Priority: -1 18 | """.encode() 19 | ).decode() 20 | 21 | image = ( 22 | Image.debian_slim(python_version="3.12") 23 | .apt_install("libmagickwand-dev") 24 | .env( 25 | { 26 | "UV_PROJECT_ENVIRONMENT": "/usr/local", 27 | "UV_COMPILE_BYTECODE": "1", 28 | "UV_LINK_MODE": "copy", 29 | } 30 | ) 31 | .pip_install("uv") 32 | .copy_local_file("pyproject.toml") 33 | .copy_local_file("uv.lock") 34 | .run_commands("uv sync --frozen --inexact --no-dev") 35 | ) 36 | sandbox_image = ( 37 | Image.from_registry( 38 | "ghcr.io/anthropics/anthropic-quickstarts:computer-use-demo-latest", 39 | ) 40 | .workdir("/home/computeruse") 41 | .run_commands( 42 | "sed -i 's|Exec=firefox-esr -new-window|Exec=sudo firefox-esr -new-window|' /home/computeruse/.config/tint2/applications/firefox-custom.desktop", 43 | "add-apt-repository ppa:mozillateam/ppa", 44 | "add-apt-repository ppa:apt-fast/stable", 45 | f"echo '{FIREFOX_PIN}' | base64 --decode | tee /etc/apt/preferences.d/mozilla-firefox", 46 | "apt-get update -y && apt-get install -y firefox-esr apt-fast", 47 | "apt remove -y xdg-desktop-portal", 48 | ) 49 | .run_commands("timeout 30 sudo firefox-esr -headless -new-window || true") 50 | ) 51 | secrets = Secret.from_local_environ(["ANTHROPIC_API_KEY"]) 52 | -------------------------------------------------------------------------------- /computer_use_modal/tools/edit/types.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Annotated, Literal, Union 3 | 4 | from annotated_types import Gt, Len 5 | from pydantic import BaseModel, Field, TypeAdapter, field_validator 6 | 7 | from computer_use_modal.vnd.anthropic.tools.edit import Command 8 | 9 | 10 | class BaseEditRequest(BaseModel): 11 | command: Command 12 | path: Path 13 | 14 | @classmethod 15 | def parse(cls, data: dict): 16 | adapter: TypeAdapter[TRequest] = TypeAdapter( 17 | Annotated[TRequest, Field(discriminator="command")] 18 | ) 19 | return adapter.validate_python(data) 20 | 21 | @field_validator("path") 22 | def validate_path(cls, v: str | Path): 23 | if isinstance(v, str): 24 | v = Path(v) 25 | return v 26 | 27 | 28 | class ViewRequest(BaseEditRequest): 29 | command: Literal["view"] = "view" 30 | view_range: Annotated[tuple[int, int], Len(2, 2), Gt(0)] | None = None 31 | 32 | 33 | class CreateRequest(BaseEditRequest): 34 | command: Literal["create"] = "create" 35 | file_text: str 36 | 37 | 38 | class StrReplaceRequest(BaseEditRequest): 39 | command: Literal["str_replace"] = "str_replace" 40 | old_str: str 41 | new_str: str = "" 42 | 43 | @field_validator("old_str", "new_str") 44 | def validate_strs(cls, v: str): 45 | return v.expandtabs() 46 | 47 | 48 | class InsertRequest(BaseEditRequest): 49 | command: Literal["insert"] = "insert" 50 | insert_line: int 51 | new_str: str 52 | 53 | @field_validator("new_str") 54 | def validate_strs(cls, v: str): 55 | return v.expandtabs() 56 | 57 | 58 | class UndoEditRequest(BaseEditRequest): 59 | command: Literal["undo_edit"] = "undo_edit" 60 | 61 | 62 | TRequest = Union[ 63 | ViewRequest, CreateRequest, StrReplaceRequest, InsertRequest, UndoEditRequest 64 | ] 65 | -------------------------------------------------------------------------------- /computer_use_modal/tools/computer/types.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated, Literal, Union 2 | 3 | from annotated_types import Gt 4 | from pydantic import BaseModel, Field, TypeAdapter 5 | 6 | from computer_use_modal.vnd.anthropic.tools.computer import Action 7 | 8 | 9 | class BaseComputerRequest(BaseModel): 10 | action: Action 11 | 12 | @classmethod 13 | def parse(cls, data: dict): 14 | adapter: TypeAdapter[TRequest] = TypeAdapter( 15 | Annotated[TRequest, Field(discriminator="action")] 16 | ) 17 | return adapter.validate_python(data) 18 | 19 | 20 | class CoordinateRequest(BaseComputerRequest): 21 | coordinate: tuple[Annotated[int, Gt(0)], Annotated[int, Gt(0)]] 22 | 23 | 24 | class MouseMoveRequest(CoordinateRequest): 25 | action: Literal["mouse_move"] = "mouse_move" 26 | 27 | 28 | class LeftClickDragRequest(CoordinateRequest): 29 | action: Literal["left_click_drag"] = "left_click_drag" 30 | 31 | 32 | class KeysRequest(BaseComputerRequest): 33 | text: str 34 | 35 | 36 | class KeyRequest(KeysRequest): 37 | action: Literal["key"] = "key" 38 | 39 | 40 | class TypeRequest(KeysRequest): 41 | action: Literal["type"] = "type" 42 | 43 | 44 | class MouseRequest(BaseComputerRequest): 45 | pass 46 | 47 | 48 | class LeftClickRequest(MouseRequest): 49 | action: Literal["left_click"] = "left_click" 50 | 51 | 52 | class RightClickRequest(MouseRequest): 53 | action: Literal["right_click"] = "right_click" 54 | 55 | 56 | class DoubleClickRequest(MouseRequest): 57 | action: Literal["double_click"] = "double_click" 58 | 59 | 60 | class MiddleClickRequest(MouseRequest): 61 | action: Literal["middle_click"] = "middle_click" 62 | 63 | 64 | class ScreenshotRequest(BaseComputerRequest): 65 | action: Literal["screenshot"] = "screenshot" 66 | 67 | 68 | class CursorPositionRequest(BaseComputerRequest): 69 | action: Literal["cursor_position"] = "cursor_position" 70 | 71 | 72 | TRequest = Union[ 73 | MouseMoveRequest, 74 | LeftClickDragRequest, 75 | KeyRequest, 76 | TypeRequest, 77 | LeftClickRequest, 78 | RightClickRequest, 79 | DoubleClickRequest, 80 | MiddleClickRequest, 81 | ScreenshotRequest, 82 | CursorPositionRequest, 83 | ] 84 | -------------------------------------------------------------------------------- /computer_use_modal/server/prompts.py: -------------------------------------------------------------------------------- 1 | import platform 2 | from datetime import datetime 3 | 4 | from computer_use_modal.app import MOUNT_PATH 5 | 6 | SYSTEM_PROMPT = f""" 7 | * You are utilising an Ubuntu virtual machine using {platform.machine()} architecture with internet access. 8 | * You can feel free to install Ubuntu applications with your bash tool. Use curl instead of wget. 9 | * You can also use apt to install applications. Use apt-fast instead of apt or apt-get. 10 | * To open firefox, please just click on the firefox icon. Note, firefox-esr is what is installed on your system. 11 | * Using bash tool you can start GUI applications, but you need to set export DISPLAY=:1 and use a subshell. For example "(DISPLAY=:1 xterm &)". GUI apps run with bash tool will appear within your desktop environment, but they may take some time to appear. Take a screenshot to confirm it did. 12 | * When using your bash tool with commands that are expected to output very large quantities of text, redirect into a tmp file and use str_replace_editor or `grep -n -B -A ` to confirm output. 13 | * When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available. 14 | * When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request. 15 | * The current date is {datetime.today().strftime('%A, %B %-d, %Y')}. 16 | 17 | 18 | 19 | The only directory you can directly write to is {MOUNT_PATH}. If you need to create files, do so within {MOUNT_PATH}. 20 | The `str_replace_editor` tool is rooted at {MOUNT_PATH}. You do not need to include the full path in your requests, just the relative path from {MOUNT_PATH}. 21 | The `str_replace_editor` tool cannot operate on files outside of {MOUNT_PATH}. 22 | 23 | 24 | 25 | * When using Firefox, if a startup wizard appears, IGNORE IT. Do not even click "skip this step". Instead, click on the address bar where it says "Search or enter address", and enter the appropriate search term or URL there. 26 | * If the item you are looking at is a pdf, if after taking a single screenshot of the pdf it seems that you want to read the entire document instead of trying to continue to read the pdf from your screenshots + navigation, determine the URL, use curl to download the pdf, install and use pdftotext to convert it to a text file, and then read that text file directly with your StrReplaceEditTool. 27 | """ 28 | -------------------------------------------------------------------------------- /computer_use_modal/sandbox/io.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from dataclasses import dataclass 4 | from functools import cached_property 5 | from typing import AsyncIterator, Literal 6 | 7 | from modal.container_process import ContainerProcess 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | @dataclass(frozen=True, kw_only=True) 13 | class IOChunk: 14 | data: str 15 | stream: Literal["stdout", "stderr"] 16 | exit_code: int | None = None 17 | 18 | 19 | @dataclass(frozen=True, kw_only=True) 20 | class IOTask: 21 | proc: ContainerProcess 22 | timeout: float 23 | queue: asyncio.Queue[IOChunk] 24 | 25 | @cached_property 26 | def iters(self) -> tuple[AsyncIterator[str], AsyncIterator[str]]: 27 | return aiter(self.proc.stdout), aiter(self.proc.stderr) 28 | 29 | async def _select(self, tasks: list[asyncio.Task]): 30 | done, _ = await asyncio.wait( 31 | tasks, 32 | return_when=asyncio.FIRST_COMPLETED, 33 | timeout=self.timeout, 34 | ) 35 | if not done: # timeout 36 | await self.queue.put( 37 | IOChunk( 38 | data=f"timed out: bash has not output anything in {self.timeout} seconds and must be restarted", 39 | stream="stderr", 40 | exit_code=-999, 41 | ) 42 | ) 43 | return True 44 | 45 | if tasks[1].done() and not tasks[1].exception(): 46 | await self.queue.put(IOChunk(data=tasks[1].result(), stream="stderr")) 47 | tasks[1] = asyncio.create_task(anext(self.iters[1])) 48 | if tasks[0].done() and not tasks[0].exception(): 49 | await self.queue.put(IOChunk(data=tasks[0].result(), stream="stdout")) 50 | tasks[0] = asyncio.create_task(anext(self.iters[0])) 51 | if tasks[2].done(): 52 | exit_code = tasks[2].result() 53 | await self.queue.put( 54 | IOChunk( 55 | data=f"error: bash has exited with returncode {exit_code} and must be restarted", 56 | stream="stderr", 57 | exit_code=exit_code, 58 | ) 59 | ) 60 | return True 61 | 62 | async def run(self): 63 | tasks: list[asyncio.Task] = [ 64 | asyncio.create_task(anext(self.iters[0])), 65 | asyncio.create_task(anext(self.iters[1])), 66 | asyncio.create_task(self.proc.wait.aio()), 67 | ] 68 | 69 | try: 70 | while True: 71 | if await self._select(tasks): 72 | break 73 | finally: 74 | for task in tasks: 75 | task.cancel() 76 | -------------------------------------------------------------------------------- /computer_use_modal/vnd/anthropic/tools/computer.py: -------------------------------------------------------------------------------- 1 | from enum import StrEnum 2 | from typing import Literal, TypedDict 3 | 4 | from .shared import ToolError 5 | 6 | OUTPUT_DIR = "/tmp/outputs" 7 | 8 | TYPING_DELAY_MS = 12 9 | TYPING_GROUP_SIZE = 50 10 | 11 | Action = Literal[ 12 | "key", 13 | "type", 14 | "mouse_move", 15 | "left_click", 16 | "left_click_drag", 17 | "right_click", 18 | "middle_click", 19 | "double_click", 20 | "screenshot", 21 | "cursor_position", 22 | ] 23 | 24 | 25 | class Resolution(TypedDict): 26 | width: int 27 | height: int 28 | 29 | 30 | # sizes above XGA/WXGA are not recommended (see README.md) 31 | # scale down to one of these targets if ComputerTool._scaling_enabled is set 32 | MAX_SCALING_TARGETS: dict[str, Resolution] = { 33 | "XGA": Resolution(width=1024, height=768), # 4:3 34 | "WXGA": Resolution(width=1280, height=800), # 16:10 35 | "FWXGA": Resolution(width=1366, height=768), # ~16:9 36 | } 37 | 38 | 39 | class ScalingSource(StrEnum): 40 | COMPUTER = "computer" 41 | API = "api" 42 | 43 | 44 | class ComputerToolOptions(TypedDict): 45 | display_height_px: int 46 | display_width_px: int 47 | display_number: int | None 48 | 49 | 50 | class ComputerToolMixin: 51 | TYPING_DELAY_MS = TYPING_DELAY_MS 52 | TYPING_GROUP_SIZE = TYPING_GROUP_SIZE 53 | SCREENSHOT_DELAY_S = 1 54 | 55 | width: int 56 | height: int 57 | 58 | @staticmethod 59 | def chunks(s: str, chunk_size: int) -> list[str]: 60 | return [s[i : i + chunk_size] for i in range(0, len(s), chunk_size)] 61 | 62 | def scale_coordinates(self, source: ScalingSource, x: int, y: int): 63 | """Scale coordinates to a target maximum resolution.""" 64 | 65 | ratio = self.width / self.height 66 | target_dimension = next( 67 | ( 68 | dimension 69 | for dimension in MAX_SCALING_TARGETS.values() 70 | if abs(dimension["width"] / dimension["height"] - ratio) < 0.02 71 | and dimension["width"] < self.width 72 | ), 73 | None, 74 | ) 75 | if target_dimension is None: 76 | return x, y 77 | # should be less than 1 78 | x_scaling_factor = target_dimension["width"] / self.width 79 | y_scaling_factor = target_dimension["height"] / self.height 80 | if source == ScalingSource.API: 81 | if x > self.width or y > self.height: 82 | raise ToolError(f"Coordinates {x}, {y} are out of bounds") 83 | # scale up 84 | return round(x / x_scaling_factor), round(y / y_scaling_factor) 85 | # scale down 86 | return round(x * x_scaling_factor), round(y * y_scaling_factor) 87 | -------------------------------------------------------------------------------- /computer_use_modal/server/server.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import AsyncGenerator, cast 3 | 4 | import modal 5 | from anthropic import Anthropic 6 | from anthropic.types.beta import ( 7 | BetaContentBlock, 8 | BetaContentBlockParam, 9 | BetaMessageParam, 10 | ) 11 | 12 | from computer_use_modal.app import app, image, secrets 13 | from computer_use_modal.sandbox.sandbox_manager import SandboxManager 14 | from computer_use_modal.server.messages import Messages 15 | from computer_use_modal.server.prompts import SYSTEM_PROMPT 16 | from computer_use_modal.tools.base import ToolCollection, ToolResult 17 | from computer_use_modal.tools.bash import BashTool 18 | from computer_use_modal.tools.computer.computer import ComputerTool 19 | from computer_use_modal.tools.edit.edit import EditTool 20 | 21 | 22 | @app.cls(image=image, allow_concurrent_inputs=10, secrets=[secrets], timeout=60 * 60) 23 | class ComputerUseServer: 24 | @modal.enter() 25 | def init(self): 26 | logging.basicConfig(level=logging.INFO) 27 | 28 | self.client = Anthropic() 29 | 30 | @modal.method() 31 | async def messages_create( 32 | self, 33 | request_id: str, 34 | user_messages: list[BetaMessageParam], 35 | max_tokens: int = 4096, 36 | model: str = "claude-3-5-sonnet-20241022", 37 | ): 38 | messages = [ 39 | msg 40 | async for msg in self.messages_create_gen.remote_gen.aio( 41 | request_id=request_id, 42 | user_messages=user_messages, 43 | max_tokens=max_tokens, 44 | model=model, 45 | ) 46 | ] 47 | return messages[-1] 48 | 49 | @modal.method(is_generator=True) 50 | async def messages_create_gen( 51 | self, 52 | request_id: str, 53 | user_messages: list[BetaMessageParam], 54 | max_tokens: int = 4096, 55 | model: str = "claude-3-5-sonnet-20241022", 56 | ) -> AsyncGenerator[BetaMessageParam | ToolResult, None]: 57 | manager = SandboxManager(request_id=request_id) 58 | messages = await Messages.from_request_id(request_id) 59 | await messages.add_user_messages(user_messages) 60 | 61 | tools = ( 62 | ComputerTool(manager=manager), 63 | EditTool(manager=manager), 64 | BashTool(manager=manager), 65 | ) 66 | 67 | while True: 68 | tool_runner = ToolCollection(tools=tools) 69 | response = self.client.beta.messages.create( 70 | max_tokens=max_tokens, 71 | messages=messages.messages, 72 | model=model, 73 | system=SYSTEM_PROMPT, 74 | tools=tool_runner.to_params(), 75 | betas=["computer-use-2024-10-22", "prompt-caching-2024-07-31"], 76 | ) 77 | yield await messages.add_assistant_content( 78 | cast(list[BetaContentBlockParam], response.content) 79 | ) 80 | for content_block in cast(list[BetaContentBlock], response.content): 81 | if content_block.type != "tool_use": 82 | continue 83 | yield await tool_runner.run( 84 | name=content_block.name, 85 | tool_input=cast(dict, content_block.input), 86 | tool_use_id=content_block.id, 87 | ) 88 | if not tool_runner.results: 89 | return 90 | yield await messages.add_tool_result( 91 | [r.to_api() for r in tool_runner.results] 92 | ) 93 | 94 | @modal.method() 95 | async def debug(self, request_id: str): 96 | manager = SandboxManager(request_id=request_id) 97 | return await manager.debug_urls.remote.aio() 98 | -------------------------------------------------------------------------------- /computer_use_modal/streamlit.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import base64 3 | from enum import StrEnum 4 | from typing import cast 5 | from uuid import uuid4 6 | 7 | import streamlit as st 8 | from anthropic.types import TextBlock 9 | from anthropic.types.beta import BetaTextBlock, BetaToolUseBlock 10 | from anthropic.types.tool_use_block import ToolUseBlock 11 | from modal import Cls 12 | 13 | from computer_use_modal import ComputerUseServer, app 14 | from computer_use_modal.tools.base import ToolResult 15 | 16 | STREAMLIT_STYLE = """ 17 | 28 | """ 29 | 30 | 31 | class Sender(StrEnum): 32 | USER = "user" 33 | BOT = "assistant" 34 | TOOL = "tool" 35 | 36 | 37 | def setup_state(): 38 | if "request_id" in st.session_state: 39 | return 40 | st.session_state.last_role = None 41 | st.session_state.request_id = uuid4().hex 42 | 43 | 44 | async def main(): 45 | setup_state() 46 | 47 | st.markdown(STREAMLIT_STYLE, unsafe_allow_html=True) 48 | st.title("Modal Computer Use Demo") 49 | 50 | new_message = st.chat_input( 51 | "Type a message to send to Claude to control the computer..." 52 | ) 53 | 54 | if new_message: 55 | st.session_state.last_role = Sender.USER 56 | _render_message(Sender.USER, new_message) 57 | 58 | if st.session_state.last_role is not Sender.USER: 59 | return 60 | 61 | with st.spinner("Running Agent..."): 62 | res = Cls.lookup( 63 | app.name, ComputerUseServer.__name__ 64 | ).messages_create.remote_gen.aio( 65 | request_id=st.session_state.request_id, 66 | user_messages=[{"role": "user", "content": new_message}], 67 | ) 68 | async for msg in res: 69 | if msg.__class__.__name__ == "ToolResult": 70 | _render_message(Sender.TOOL, msg) 71 | st.session_state.last_role = Sender.TOOL 72 | else: 73 | st.session_state.last_role = msg["role"] 74 | if isinstance(msg["content"], str): 75 | _render_message(msg["role"], msg["content"]) 76 | elif isinstance(msg["content"], list): 77 | for block in msg["content"]: 78 | if isinstance(block, dict) and block["type"] == "tool_result": 79 | continue 80 | _render_message( 81 | msg["role"], 82 | cast(BetaTextBlock | BetaToolUseBlock, block), 83 | ) 84 | 85 | 86 | def _render_message( 87 | sender: Sender, 88 | message: str | BetaTextBlock | BetaToolUseBlock | ToolResult, 89 | ): 90 | with st.chat_message(sender): 91 | if sender == Sender.TOOL: 92 | message = cast(ToolResult, message) 93 | if message.output and message.output.strip(): 94 | st.code(message.output) 95 | if message.error and message.error.strip(): 96 | st.error(message.error) 97 | if message.base64_image: 98 | st.image(base64.b64decode(message.base64_image)) 99 | elif isinstance(message, BetaTextBlock) or isinstance(message, TextBlock): 100 | if message.text: 101 | st.write(message.text) 102 | elif isinstance(message, BetaToolUseBlock) or isinstance(message, ToolUseBlock): 103 | st.code(f"Tool Use: {message.name}\nInput: {message.input}") 104 | elif message: 105 | st.markdown(message) 106 | 107 | 108 | if __name__ == "__main__": 109 | asyncio.run(main()) 110 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Anthropic Computer Use on Modal 2 | 3 | Anthropic's recent release of [Computer Use](https://anthropic.com/news/3-5-models-and-computer-use) is fantastic, but spinning up random Docker images didn't fit into our production workflow. Luckily, we use [Modal](https://modal.com) and they have all the primitives we need to implement the Computer Use API. 4 | 5 | If you're curious about why this library exists, check out [this blog post](https://musings.yasyf.com/improving-claude-computer-use/). This repo has a reimplementation of the Computer Use Tools from ~scratch, with a focus on using distributed primitives and state management. 6 | 7 | This library provides an out-of-the-box implementation that can be deployed into your Modal environment. Everything runs in a Sandbox, with tool calls being translated into Modal API calls. It may or may not spectacularly explode. It's also quite slow at the moment. Caveat emptor! 8 | 9 | ## Features 10 | 11 | - Deploys into its own app that can be called from your existing apps 12 | - Sandboxes scale to zero and are resumable 13 | - VNC tunnel to each sandbox for debugging 14 | - One NFS per sandbox, available for inspection 15 | - Image processing outside the sandbox, greatly speeding up screenshot generation 16 | - Fuzzy matching for the Edit tool, since the model often misses a newline or two 17 | - Hardware-accelerated browsing in the sandbox 18 | - Pre-warming of the sandbox for faster startup times 19 | - Tools for the LLM to work faster, such as `apt-fast` 20 | 21 | ## Installation 22 | 23 | You can install this library without cloning the repo by running: 24 | 25 | ```bash 26 | pip install computer-use-modal 27 | ``` 28 | 29 | To use it in your own project, simply deploy it once: 30 | 31 | ```bash 32 | modal deploy computer_use_modal 33 | ``` 34 | 35 | Then you can use it in your app like this: 36 | 37 | ```python 38 | from modal import Cls 39 | 40 | server = Cls.lookup("anthropic-computer-use-modal", "ComputerUseServer") 41 | response = server.messages_create.remote.aio( 42 | request_id=uuid4().hex, 43 | user_messages=[{"role": "user", "content": "What is the weather in San Francisco?"}], 44 | ) 45 | print(response) 46 | ``` 47 | 48 | ```python 49 | { 50 | "role": "assistant", 51 | "content": [ 52 | BetaTextBlock( 53 | text="According to the National Weather Service, the current weather in San Francisco is:\n\nTemperature: 65°F (18°C)\nHumidity: 53%\nDewpoint: 48°F (9°C)\nLast update: October 23, 2:43 PM PDT\n\nThe website shows the forecast details as well. Would you like me to provide the extended forecast for the coming days?", 54 | type="text", 55 | ) 56 | ] 57 | } 58 | ``` 59 | 60 | You can also watch the progress with a VNC tunnel: 61 | 62 | ```python 63 | manager = Cls.lookup("anthropic-computer-use-modal", "SandboxManager") 64 | urls = manager.debug_urls.remote() 65 | print(urls["vnc"]) 66 | ``` 67 | 68 | ```python 69 | "https://x2xzanmu4yg.r9.modal.host" 70 | ``` 71 | 72 | If you want to stream the responses, you can use `ComputerUseServer.messages_create_gen`. 73 | 74 | ## Demo 75 | 76 | You can clone this repo and run two demos locally. 77 | 78 | ### CLI Demo 79 | 80 | This demo will deploy an ephemeral Modal app, and ask the LLM to browse the web to fetch the weather in San Francisco. 81 | Screenshots will be shown in your terminal so you can follow along! 82 | 83 | ```bash 84 | git clone https://github.com/yasyf/anthropic-tool-use-modal 85 | cd anthropic-tool-use-modal 86 | uv sync 87 | modal run computer_use_modal.demo 88 | ``` 89 | 90 | ### Streamlit Demo 91 | 92 | This demo deploys the app persistently to its own namespace in your Modal account, then starts a Streamlit app that can interact with it. 93 | 94 | ```bash 95 | git clone https://github.com/yasyf/anthropic-tool-use-modal 96 | cd anthropic-tool-use-modal 97 | uv sync --dev 98 | modal deploy computer_use_modal 99 | python -m streamlit run computer_use_modal/streamlit.py 100 | ``` 101 | 102 | ![Streamlit Demo](demo.png) 103 | 104 | ## Thanks 105 | 106 | Thanks to the Anthropic team for the awesome starting point! 107 | -------------------------------------------------------------------------------- /computer_use_modal/server/messages.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | from typing import Self, cast 4 | 5 | import modal 6 | from anthropic.types.beta import ( 7 | BetaCacheControlEphemeralParam, 8 | BetaContentBlockParam, 9 | BetaMessageParam, 10 | BetaToolResultBlockParam, 11 | BetaToolUseBlockParam, 12 | ) 13 | 14 | MESSAGES = modal.Dict.from_name("messages", create_if_missing=True) 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | @dataclass(kw_only=True) 20 | class Messages: 21 | CHUNK_SIZE: int = 10 22 | MAX_CACHE_CONTROL: int = 4 23 | 24 | request_id: str 25 | _messages: list[BetaMessageParam] 26 | keep_n_images: int = 10 27 | 28 | @classmethod 29 | async def from_request_id(cls, request_id: str) -> Self: 30 | return await MESSAGES.get.aio( 31 | request_id, cls(request_id=request_id, _messages=[]) 32 | ) 33 | 34 | async def flush(self): 35 | self._filter_cache_control() 36 | self._filter_images() 37 | await MESSAGES.put.aio(self.request_id, self) 38 | 39 | @property 40 | def messages(self) -> tuple[BetaMessageParam, ...]: 41 | return tuple(self._messages) 42 | 43 | async def add_assistant_content(self, content: list[BetaContentBlockParam]): 44 | logger.info(f"AI said: {content}") 45 | self._messages.append( 46 | msg := {"role": "assistant", "content": content}, 47 | ) 48 | await self.flush() 49 | return msg 50 | 51 | async def add_user_messages(self, messages: list[BetaMessageParam]): 52 | logger.info(f"User said: {messages}") 53 | self._messages.extend(messages) 54 | await self.flush() 55 | 56 | async def add_tool_result(self, tool_results: list[BetaToolResultBlockParam]): 57 | self._messages.append( 58 | msg := {"content": tool_results, "role": "user"}, 59 | ) 60 | self.tool_results[-1]["cache_control"] = cast( 61 | BetaCacheControlEphemeralParam, {"type": "ephemeral"} 62 | ) 63 | await self.flush() 64 | return msg 65 | 66 | @property 67 | def tool_results(self) -> list[BetaToolUseBlockParam]: 68 | return cast( 69 | list[BetaToolUseBlockParam], 70 | [ 71 | item 72 | for message in self._messages 73 | for item in ( 74 | message["content"] if isinstance(message["content"], list) else [] 75 | ) 76 | if isinstance(item, dict) and item.get("type") == "tool_result" 77 | ], 78 | ) 79 | 80 | def _filter_cache_control(self): 81 | total_cache_control = sum( 82 | 1 for tool_result in self.tool_results if "cache_control" in tool_result 83 | ) 84 | if (to_remove := total_cache_control - self.MAX_CACHE_CONTROL) <= 0: 85 | return 86 | while to_remove > 0: 87 | for tool_result in self.tool_results: 88 | if "cache_control" not in tool_result: 89 | continue 90 | tool_result.pop("cache_control") 91 | to_remove -= 1 92 | if to_remove == 0: 93 | return 94 | 95 | def _filter_images(self): 96 | total_images = sum( 97 | 1 98 | for tool_result in self.tool_results 99 | for content in tool_result.get("content", []) 100 | if isinstance(content, dict) and content.get("type") == "image" 101 | ) 102 | 103 | if total_images <= self.keep_n_images: 104 | return 105 | if not ( 106 | images_to_remove := (total_images - self.keep_n_images) % self.CHUNK_SIZE 107 | ): 108 | return 109 | 110 | logger.info(f"Removing {images_to_remove} images") 111 | 112 | while images_to_remove > 0: 113 | for res in self.tool_results: 114 | if not isinstance(contents := res.get("content"), list): 115 | continue 116 | for content in contents.copy(): 117 | if not ( 118 | isinstance(content, dict) and content.get("type") == "image" 119 | ): 120 | continue 121 | contents.remove(content) 122 | images_to_remove -= 1 123 | if images_to_remove == 0: 124 | return 125 | -------------------------------------------------------------------------------- /computer_use_modal/tools/base.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from abc import ABC, abstractmethod 4 | from dataclasses import dataclass, field 5 | from typing import TYPE_CHECKING, Generic, Mapping, TypeVar 6 | 7 | from anthropic.types.beta import ( 8 | BetaImageBlockParam, 9 | BetaTextBlockParam, 10 | BetaToolResultBlockParam, 11 | BetaToolUnionParam, 12 | ) 13 | 14 | from computer_use_modal.vnd.anthropic.tools.shared import ToolError as _ToolError 15 | from computer_use_modal.vnd.anthropic.tools.shared import ToolResult as _ToolResult 16 | 17 | if TYPE_CHECKING: 18 | from computer_use_modal.sandbox.sandbox_manager import SandboxManager 19 | 20 | P = TypeVar("P", bound=Mapping) 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | class ToolError(_ToolError): ... 25 | 26 | 27 | @dataclass(kw_only=True, frozen=True) 28 | class ToolResult(_ToolResult): 29 | tool_use_id: str | None = None 30 | is_error: bool = False 31 | 32 | def __add__(self, other: "ToolResult"): 33 | result = super().__add__(other) 34 | return result.replace( 35 | is_error=self.is_error or other.is_error, 36 | tool_use_id=self.combine_fields( 37 | self.tool_use_id, other.tool_use_id, concatenate=False 38 | ), 39 | ) 40 | 41 | def is_empty(self) -> bool: 42 | return not ( 43 | self.error 44 | or self.output 45 | or self.base64_image 46 | or self.system 47 | or self.tool_use_id 48 | ) 49 | 50 | def to_api(self) -> BetaToolResultBlockParam: 51 | assert self.tool_use_id is not None, "tool_use_id is required" 52 | assert not self.is_empty(), "content is required" 53 | 54 | content: list[BetaTextBlockParam | BetaImageBlockParam] | str = [] 55 | system = f"{self.system}\n" if self.system else "" 56 | 57 | if system: 58 | content.append({"type": "text", "text": system}) 59 | if self.error: 60 | content.append({"type": "text", "text": self.error}) 61 | if self.output: 62 | content.append({"type": "text", "text": self.output}) 63 | if self.base64_image: 64 | content.append( 65 | { 66 | "type": "image", 67 | "source": { 68 | "type": "base64", 69 | "media_type": "image/png", 70 | "data": self.base64_image, 71 | }, 72 | } 73 | ) 74 | return { 75 | "type": "tool_result", 76 | "content": content, 77 | "tool_use_id": self.tool_use_id, 78 | "is_error": self.is_error, 79 | } 80 | 81 | 82 | @dataclass(kw_only=True) 83 | class BaseTool(ABC, Generic[P]): 84 | manager: "SandboxManager" 85 | 86 | @property 87 | @abstractmethod 88 | def options(self) -> P: ... 89 | 90 | @abstractmethod 91 | async def __call__(self, /, **kwargs) -> ToolResult: ... 92 | 93 | async def execute(self, command: str, *args): 94 | return await self.manager.run_command.remote.aio(command, *args) 95 | 96 | 97 | @dataclass(kw_only=True, frozen=True) 98 | class ToolCollection: 99 | tools: tuple[BaseTool, ...] 100 | results: list[ToolResult] = field(default_factory=list) 101 | timeout: int = 60 102 | 103 | @property 104 | def tool_map(self) -> dict[str, BaseTool]: 105 | return {tool.options["name"]: tool for tool in self.tools} 106 | 107 | def to_params( 108 | self, 109 | ) -> list[BetaToolUnionParam]: 110 | return [tool.options for tool in self.tools] 111 | 112 | async def _run(self, *, name: str, tool_input: dict) -> ToolResult: 113 | tool = self.tool_map.get(name) 114 | if not tool: 115 | return ToolResult(error=f"Tool {name} is invalid", is_error=True) 116 | try: 117 | async with asyncio.timeout(self.timeout): 118 | return await tool(**tool_input) 119 | except asyncio.TimeoutError: 120 | return ToolResult(error=f"Tool {name} timed out. Try again.", is_error=True) 121 | except ToolError as e: 122 | logger.error(f"ToolError: {e}") 123 | return ToolResult(error=e.message, is_error=True) 124 | except Exception as e: 125 | logger.error(f"Exception: {e}") 126 | return ToolResult(error=str(e), is_error=True) 127 | 128 | async def run(self, *, name: str, tool_input: dict, tool_use_id: str) -> ToolResult: 129 | result = await self._run(name=name, tool_input=tool_input) 130 | result = result.replace(tool_use_id=tool_use_id) 131 | self.results.append(result) 132 | return result 133 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/python,osx 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,osx 3 | 4 | ### OSX ### 5 | # General 6 | .DS_Store 7 | .AppleDouble 8 | .LSOverride 9 | 10 | # Icon must end with two \r 11 | Icon 12 | 13 | # Thumbnails 14 | ._* 15 | 16 | # Files that might appear in the root of a volume 17 | .DocumentRevisions-V100 18 | .fseventsd 19 | .Spotlight-V100 20 | .TemporaryItems 21 | .Trashes 22 | .VolumeIcon.icns 23 | .com.apple.timemachine.donotpresent 24 | 25 | # Directories potentially created on remote AFP share 26 | .AppleDB 27 | .AppleDesktop 28 | Network Trash Folder 29 | Temporary Items 30 | .apdisk 31 | 32 | ### Python ### 33 | # Byte-compiled / optimized / DLL files 34 | __pycache__/ 35 | *.py[cod] 36 | *$py.class 37 | 38 | # C extensions 39 | *.so 40 | 41 | # Distribution / packaging 42 | .Python 43 | build/ 44 | develop-eggs/ 45 | dist/ 46 | downloads/ 47 | eggs/ 48 | .eggs/ 49 | lib/ 50 | lib64/ 51 | parts/ 52 | sdist/ 53 | var/ 54 | wheels/ 55 | share/python-wheels/ 56 | *.egg-info/ 57 | .installed.cfg 58 | *.egg 59 | MANIFEST 60 | 61 | # PyInstaller 62 | # Usually these files are written by a python script from a template 63 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 64 | *.manifest 65 | *.spec 66 | 67 | # Installer logs 68 | pip-log.txt 69 | pip-delete-this-directory.txt 70 | 71 | # Unit test / coverage reports 72 | htmlcov/ 73 | .tox/ 74 | .nox/ 75 | .coverage 76 | .coverage.* 77 | .cache 78 | nosetests.xml 79 | coverage.xml 80 | *.cover 81 | *.py,cover 82 | .hypothesis/ 83 | .pytest_cache/ 84 | cover/ 85 | 86 | # Translations 87 | *.mo 88 | *.pot 89 | 90 | # Django stuff: 91 | *.log 92 | local_settings.py 93 | db.sqlite3 94 | db.sqlite3-journal 95 | 96 | # Flask stuff: 97 | instance/ 98 | .webassets-cache 99 | 100 | # Scrapy stuff: 101 | .scrapy 102 | 103 | # Sphinx documentation 104 | docs/_build/ 105 | 106 | # PyBuilder 107 | .pybuilder/ 108 | target/ 109 | 110 | # Jupyter Notebook 111 | .ipynb_checkpoints 112 | 113 | # IPython 114 | profile_default/ 115 | ipython_config.py 116 | 117 | # pyenv 118 | # For a library or package, you might want to ignore these files since the code is 119 | # intended to run in multiple environments; otherwise, check them in: 120 | # .python-version 121 | 122 | # pipenv 123 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 124 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 125 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 126 | # install all needed dependencies. 127 | #Pipfile.lock 128 | 129 | # poetry 130 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 131 | # This is especially recommended for binary packages to ensure reproducibility, and is more 132 | # commonly ignored for libraries. 133 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 134 | #poetry.lock 135 | 136 | # pdm 137 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 138 | #pdm.lock 139 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 140 | # in version control. 141 | # https://pdm.fming.dev/#use-with-ide 142 | .pdm.toml 143 | 144 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 145 | __pypackages__/ 146 | 147 | # Celery stuff 148 | celerybeat-schedule 149 | celerybeat.pid 150 | 151 | # SageMath parsed files 152 | *.sage.py 153 | 154 | # Environments 155 | .env 156 | .venv 157 | env/ 158 | venv/ 159 | ENV/ 160 | env.bak/ 161 | venv.bak/ 162 | 163 | # Spyder project settings 164 | .spyderproject 165 | .spyproject 166 | 167 | # Rope project settings 168 | .ropeproject 169 | 170 | # mkdocs documentation 171 | /site 172 | 173 | # mypy 174 | .mypy_cache/ 175 | .dmypy.json 176 | dmypy.json 177 | 178 | # Pyre type checker 179 | .pyre/ 180 | 181 | # pytype static type analyzer 182 | .pytype/ 183 | 184 | # Cython debug symbols 185 | cython_debug/ 186 | 187 | # PyCharm 188 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 189 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 190 | # and can be added to the global gitignore or merged into this file. For a more nuclear 191 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 192 | #.idea/ 193 | 194 | ### Python Patch ### 195 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 196 | poetry.toml 197 | 198 | # ruff 199 | .ruff_cache/ 200 | 201 | # LSP config files 202 | pyrightconfig.json 203 | 204 | # End of https://www.toptal.com/developers/gitignore/api/python,osx 205 | -------------------------------------------------------------------------------- /computer_use_modal/sandbox/bash_manager.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from dataclasses import dataclass, field, replace 4 | from io import StringIO 5 | from typing import Any, cast 6 | 7 | from modal import Sandbox 8 | from modal.container_process import ContainerProcess 9 | 10 | from computer_use_modal.sandbox.io import IOChunk, IOTask 11 | from computer_use_modal.tools.base import ToolResult 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | @dataclass(frozen=True, kw_only=True, unsafe_hash=True) 16 | class BashSession: 17 | session_id: str 18 | pid: int | None = None 19 | 20 | 21 | @dataclass(kw_only=True) 22 | class BashSessionManager: 23 | sandbox: Sandbox 24 | session: BashSession | None = None 25 | proc: ContainerProcess | None = None 26 | timeout: float = 30 27 | 28 | io_queue: asyncio.Queue[IOChunk] = field(default_factory=asyncio.Queue) 29 | _io_task: asyncio.Task | None = None 30 | 31 | async def start(self) -> BashSession: 32 | if self.session: 33 | self.proc = ContainerProcess( 34 | self.session.session_id, _gross_modal_hack(self.sandbox)._client 35 | ) 36 | else: 37 | self.proc = cast( 38 | ContainerProcess, 39 | await self.sandbox.exec.aio("bash"), 40 | ) 41 | process_id = _gross_modal_hack(self.proc)._process_id 42 | assert process_id is not None 43 | self.session = BashSession(session_id=process_id) 44 | 45 | self._io_task = asyncio.create_task( 46 | IOTask(proc=self.proc, timeout=self.timeout, queue=self.io_queue).run() 47 | ) 48 | 49 | if not self.session.pid: 50 | self.session = replace( 51 | self.session, pid=int((await self.run("echo $BASHPID")).output.strip()) 52 | ) 53 | 54 | return self.session 55 | 56 | async def kill(self): 57 | assert self.session is not None 58 | if self._io_task: 59 | self._io_task.cancel() 60 | logger.info(f"killing bash with pid {self.session.pid}") 61 | proc = await self.sandbox.exec.aio("kill", str(self.session.pid)) 62 | await proc.wait.aio() 63 | self.session = None 64 | self.proc = None 65 | 66 | async def run(self, command: str) -> ToolResult: 67 | if not self.proc: 68 | await self.start() 69 | assert self.session is not None 70 | cmd = BashCommandManager(session=self) 71 | await cmd.start(command) 72 | res = await cmd.wait() 73 | if cmd.exit_code: 74 | await self.kill() 75 | return res 76 | 77 | 78 | @dataclass(kw_only=True) 79 | class BashCommandManager: 80 | SENTINEL = "<>" 81 | 82 | session: BashSessionManager 83 | 84 | stdout: StringIO = field(default_factory=StringIO) 85 | stderr: StringIO = field(default_factory=StringIO) 86 | exit_code: int | None = None 87 | 88 | @property 89 | def proc(self) -> ContainerProcess: 90 | assert self.session.proc is not None 91 | return self.session.proc 92 | 93 | async def start(self, command: str): 94 | assert self.proc is not None 95 | logger.info(f"running command: {command}") 96 | self.proc.stdin.write(f"{command}; echo '{self.SENTINEL}'\n") 97 | await self.proc.stdin.drain.aio() 98 | 99 | async def _loop(self): 100 | while True: 101 | chunk = await self.session.io_queue.get() 102 | if chunk.stream == "stdout": 103 | logger.info(f"stdout: {chunk.data}") 104 | self.stdout.write(chunk.data) 105 | elif chunk.stream == "stderr": 106 | logger.info(f"stderr: {chunk.data}") 107 | self.stderr.write(chunk.data) 108 | 109 | if chunk.exit_code is not None: 110 | logger.info(f"command exited with code {chunk.exit_code}") 111 | self.exit_code = chunk.exit_code 112 | break 113 | elif self.SENTINEL in self.stdout.getvalue(): 114 | logger.info("command succeeded") 115 | break 116 | 117 | async def loop(self): 118 | try: 119 | async with asyncio.timeout(self.session.timeout): 120 | await self._loop() 121 | except asyncio.TimeoutError: 122 | self.stderr.write( 123 | f"timed out: bash has not returned in {self.session.timeout} seconds and must be restarted" 124 | ) 125 | self.exit_code = -999 126 | 127 | async def wait(self): 128 | await self.loop() 129 | if self.SENTINEL in (output := self.stdout.getvalue()): 130 | output = output[: output.index(self.SENTINEL)] 131 | return ToolResult( 132 | output=output, 133 | error=self.stderr.getvalue(), 134 | system=( 135 | "tool must be restarted" if self.exit_code else "bash command succeeded" 136 | ), 137 | is_error=bool(self.exit_code), 138 | ) 139 | 140 | 141 | def _gross_modal_hack(obj: Any): 142 | for k, v in obj.__dict__.items(): 143 | if k.startswith("_sync_original_"): 144 | return v 145 | return obj 146 | -------------------------------------------------------------------------------- /computer_use_modal/sandbox/sandbox_manager.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from io import BytesIO 4 | from pathlib import Path 5 | from typing import cast 6 | 7 | import backoff 8 | import modal 9 | from grpclib import GRPCError 10 | from modal import NetworkFileSystem, Sandbox 11 | from modal.container_process import ContainerProcess 12 | 13 | from computer_use_modal.app import MOUNT_PATH, app, image, sandbox_image 14 | from computer_use_modal.sandbox.bash_manager import BashSession, BashSessionManager 15 | from computer_use_modal.tools.base import ToolResult 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | @app.cls( 20 | image=image, 21 | concurrency_limit=1, 22 | allow_concurrent_inputs=15, 23 | timeout=60 * 30, 24 | container_idle_timeout=60 * 20, 25 | ) 26 | class SandboxManager: 27 | request_id: str = modal.parameter() 28 | auto_cleanup: int = modal.parameter(default=1) 29 | 30 | @modal.enter() 31 | async def create_sandbox(self): 32 | logging.basicConfig(level=logging.INFO) 33 | 34 | self.bash_sessions: dict[BashSession, BashSessionManager] = {} 35 | self.nfs = await NetworkFileSystem.lookup.aio( 36 | f"anthropic-computer-use-{self.request_id}", create_if_missing=True 37 | ) 38 | if sandbox := await anext( 39 | Sandbox.list.aio(tags={"request_id": self.request_id}), None 40 | ): 41 | self.sandbox = sandbox 42 | else: 43 | self.sandbox = await Sandbox.create.aio( 44 | image=sandbox_image, 45 | cpu=8, 46 | memory=1024 * 8, 47 | gpu="T4", 48 | network_file_systems={MOUNT_PATH: self.nfs}, 49 | timeout=60 * 60, 50 | encrypted_ports=[8501, 6080], 51 | ) 52 | logger.info("Waiting for sandbox to start...") 53 | await asyncio.sleep(30) 54 | logger.info("Sandbox started") 55 | 56 | @modal.exit() 57 | async def cleanup_sandbox(self): 58 | if not self.auto_cleanup: 59 | return 60 | for manager in self.bash_sessions.values(): 61 | await manager.kill() 62 | await self.sandbox.terminate.aio() 63 | 64 | @modal.method() 65 | async def debug_urls(self): 66 | tunnels = await self.sandbox.tunnels.aio() 67 | return { 68 | "vnc": tunnels[6080].url, 69 | "webui": tunnels[8501].url, 70 | } 71 | 72 | @modal.method() 73 | async def run_command(self, *command: str) -> ToolResult: 74 | logger.info(f"Running command: {command}") 75 | proc: ContainerProcess = await self.sandbox.exec.aio(*map(str, command)) 76 | await proc.wait.aio() 77 | res = ToolResult( 78 | output=await proc.stdout.read.aio(), 79 | error=await proc.stderr.read.aio(), 80 | ) 81 | logger.info(f"Command returned: {res}") 82 | return res 83 | 84 | @modal.method() 85 | @backoff.on_exception(backoff.expo, FileNotFoundError, max_tries=3) 86 | async def read_file(self, path: Path) -> bytes: 87 | buff = BytesIO() 88 | try: 89 | async for chunk in self.nfs.read_file.aio(path.as_posix()): 90 | buff.write(chunk) 91 | except GRPCError: 92 | raise FileNotFoundError(f"File not found: {path}") 93 | buff.seek(0) 94 | return buff.getvalue() 95 | 96 | @modal.method() 97 | async def write_file(self, path: Path, content: bytes): 98 | await self.nfs.write_file.aio(path.as_posix(), BytesIO(content)) 99 | 100 | @modal.method() 101 | async def stat_file(self, path: Path) -> list[dict]: 102 | try: 103 | return [e.__dict__ for e in await self.nfs.listdir.aio(path.as_posix())] 104 | except GRPCError: 105 | return [] 106 | 107 | @modal.method() 108 | async def take_screenshot(self, display: int, size: tuple[int, int]) -> ToolResult: 109 | from base64 import b64encode 110 | 111 | from uuid6 import uuid7 112 | from wand.image import Image 113 | 114 | path = Path(MOUNT_PATH) / f"{uuid7().hex}.png" 115 | await self.run_command.local( 116 | "env", 117 | f"DISPLAY=:{display}", 118 | "scrot", 119 | "-p", 120 | path.as_posix(), 121 | ) 122 | with Image( 123 | blob=await self.read_file.remote.aio(path.relative_to(MOUNT_PATH)) 124 | ) as img: 125 | img.resize(width=size[0], height=size[1]) 126 | return ToolResult( 127 | base64_image=b64encode(cast(bytes, img.make_blob())).decode() 128 | ) 129 | 130 | @modal.method() 131 | async def start_bash_session(self) -> BashSession: 132 | manager = BashSessionManager(sandbox=self.sandbox) 133 | session = await manager.start() 134 | self.bash_sessions[session] = manager 135 | return session 136 | 137 | @modal.method() 138 | async def execute_bash_command(self, session: BashSession, cmd: str) -> ToolResult: 139 | try: 140 | manager = self.bash_sessions[session] 141 | except KeyError: 142 | manager = BashSessionManager(sandbox=self.sandbox, session=session) 143 | self.bash_sessions[session] = manager 144 | return await manager.run(cmd) 145 | 146 | @modal.method() 147 | async def end_bash_session(self, session: BashSession): 148 | try: 149 | manager = self.bash_sessions.pop(session) 150 | except KeyError: 151 | manager = BashSessionManager(sandbox=self.sandbox, session=session) 152 | await manager.kill() 153 | -------------------------------------------------------------------------------- /computer_use_modal/tools/computer/computer.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import shlex 3 | from dataclasses import dataclass 4 | from functools import singledispatchmethod 5 | 6 | from anthropic.types.beta import BetaToolComputerUse20241022Param 7 | from pydantic import ValidationError 8 | 9 | from computer_use_modal.tools.base import BaseTool, ToolError, ToolResult 10 | from computer_use_modal.tools.computer.types import ( 11 | BaseComputerRequest, 12 | CursorPositionRequest, 13 | DoubleClickRequest, 14 | KeyRequest, 15 | LeftClickDragRequest, 16 | LeftClickRequest, 17 | MiddleClickRequest, 18 | MouseMoveRequest, 19 | RightClickRequest, 20 | ScreenshotRequest, 21 | TypeRequest, 22 | ) 23 | from computer_use_modal.vnd.anthropic.tools.computer import ( 24 | ComputerToolMixin, 25 | ScalingSource, 26 | ) 27 | 28 | 29 | @dataclass(kw_only=True) 30 | class ComputerTool(BaseTool[BetaToolComputerUse20241022Param], ComputerToolMixin): 31 | width: int = 1024 32 | height: int = 768 33 | display_num: int = 1 34 | 35 | @property 36 | def options(self) -> BetaToolComputerUse20241022Param: 37 | width, height = self.scale_coordinates( 38 | ScalingSource.COMPUTER, self.width, self.height 39 | ) 40 | return { 41 | "name": "computer", 42 | "type": "computer_20241022", 43 | "display_width_px": width, 44 | "display_height_px": height, 45 | "display_number": self.display_num, 46 | } 47 | 48 | def _command(self, *args): 49 | return ("env", f"DISPLAY=:{self.display_num}", "xdotool") + args 50 | 51 | async def __call__( 52 | self, 53 | /, 54 | **data, 55 | ): 56 | try: 57 | request = BaseComputerRequest.parse(data) 58 | except ValidationError as e: 59 | raise ToolError(f"Invalid tool parameters:\n{e.json()}") from e 60 | return await self.dispatch(request) 61 | 62 | @singledispatchmethod 63 | async def dispatch(self, request: BaseComputerRequest) -> ToolResult: 64 | raise ToolError(f"Unknown action: {request.action}") 65 | 66 | @dispatch.register(MouseMoveRequest) 67 | async def mouse_move(self, request: MouseMoveRequest): 68 | x, y = self.scale_coordinates( 69 | ScalingSource.API, request.coordinate[0], request.coordinate[1] 70 | ) 71 | return await self.execute( 72 | *self._command("mousemove", "--sync", x, y), take_screenshot=False 73 | ) 74 | 75 | @dispatch.register(LeftClickDragRequest) 76 | async def left_click_drag(self, request: LeftClickDragRequest): 77 | x, y = self.scale_coordinates( 78 | ScalingSource.API, request.coordinate[0], request.coordinate[1] 79 | ) 80 | return await self.execute( 81 | *self._command("mousedown", 1, "mousemove", "--sync", x, y, "mouseup", 1) 82 | ) 83 | 84 | @dispatch.register(KeyRequest) 85 | async def key(self, request: KeyRequest): 86 | return await self.execute(*self._command("key", "--", request.text)) 87 | 88 | @dispatch.register(TypeRequest) 89 | async def type(self, request: TypeRequest): 90 | results = [ 91 | await self.execute( 92 | *self._command( 93 | "type", "--delay", self.TYPING_DELAY_MS, "--", shlex.quote(chunk) 94 | ), 95 | take_screenshot=False, 96 | ) 97 | for chunk in self.chunks(request.text, self.TYPING_GROUP_SIZE) 98 | ] 99 | result = sum(results, ToolResult()) 100 | return result.replace( 101 | base64_image=(await self.screenshot(ScreenshotRequest())).base64_image 102 | ) 103 | 104 | @dispatch.register(LeftClickRequest) 105 | async def left_click(self, request: LeftClickRequest): 106 | return await self.execute(*self._command("click", "1")) 107 | 108 | @dispatch.register(RightClickRequest) 109 | async def right_click(self, request: RightClickRequest): 110 | return await self.execute(*self._command("click", "3")) 111 | 112 | @dispatch.register(DoubleClickRequest) 113 | async def double_click(self, request: DoubleClickRequest): 114 | return await self.execute( 115 | *self._command("click", "--repeat", "2", "--delay", "500", "1") 116 | ) 117 | 118 | @dispatch.register(MiddleClickRequest) 119 | async def middle_click(self, request: MiddleClickRequest): 120 | return await self.execute(*self._command("click", "2")) 121 | 122 | @dispatch.register(CursorPositionRequest) 123 | async def cursor_position(self, request: CursorPositionRequest): 124 | import re 125 | 126 | result = await self.execute( 127 | *self._command("getmouselocation", "--shell"), take_screenshot=False 128 | ) 129 | if not result.output: 130 | raise ToolError("Failed to get cursor position") 131 | x, y = self.scale_coordinates( 132 | ScalingSource.COMPUTER, 133 | *map(int, re.match(r"X=(\d+).*Y=(\d+)", result.output).groups()), 134 | ) 135 | return result.replace(output=f"X={x},Y={y}") 136 | 137 | @dispatch.register(ScreenshotRequest) 138 | async def screenshot(self, request: ScreenshotRequest): 139 | return await self.manager.take_screenshot.remote.aio( 140 | self.display_num, 141 | self.scale_coordinates(ScalingSource.COMPUTER, self.width, self.height), 142 | ) 143 | 144 | async def execute(self, command: str, *args, take_screenshot: bool = True): 145 | result = await super().execute(command, *args) 146 | if not take_screenshot: 147 | return result 148 | await asyncio.sleep(self.SCREENSHOT_DELAY_S) 149 | return result + await self.screenshot(ScreenshotRequest()) 150 | -------------------------------------------------------------------------------- /computer_use_modal/sandbox/edit_manager.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import defaultdict 3 | from dataclasses import dataclass, field 4 | from functools import singledispatchmethod 5 | from pathlib import Path 6 | from typing import TYPE_CHECKING, Self 7 | 8 | import modal 9 | from modal.volume import FileEntry, FileEntryType 10 | 11 | from computer_use_modal.app import MOUNT_PATH 12 | from computer_use_modal.tools.base import ToolError, ToolResult 13 | from computer_use_modal.tools.edit.types import ( 14 | CreateRequest, 15 | InsertRequest, 16 | StrReplaceRequest, 17 | TRequest, 18 | UndoEditRequest, 19 | ViewRequest, 20 | ) 21 | from computer_use_modal.vnd.anthropic.tools.edit import make_output 22 | 23 | if TYPE_CHECKING: 24 | from computer_use_modal.sandbox.sandbox_manager import SandboxManager 25 | 26 | SESSIONS = modal.Dict.from_name("edit-sessions", create_if_missing=True) 27 | 28 | logger = logging.getLogger(__name__) 29 | @dataclass(frozen=True, kw_only=True) 30 | class EditSession: 31 | file_versions: dict[Path, list[str]] = field( 32 | default_factory=lambda: defaultdict(list) 33 | ) 34 | 35 | @classmethod 36 | async def from_request_id(cls, request_id: str) -> Self: 37 | return await SESSIONS.get.aio(request_id, cls()) 38 | 39 | 40 | @dataclass(kw_only=True, frozen=True) 41 | class FileInfo: 42 | path: Path 43 | listing: list[FileEntry] 44 | manager: "EditSessionManager" 45 | 46 | def exists(self) -> bool: 47 | return bool(self.listing) 48 | 49 | def is_file(self) -> bool: 50 | return len(self.listing) == 1 and self.listing[0].type == FileEntryType.FILE 51 | 52 | def is_dir(self) -> bool: 53 | return self.exists and not self.is_file 54 | 55 | @property 56 | def local_path(self) -> Path: 57 | return Path(MOUNT_PATH) / self.path 58 | 59 | async def read(self) -> str: 60 | return ( 61 | (await self.manager.sandbox.read_file.remote.aio(self.path)) 62 | .decode() 63 | .expandtabs() 64 | ) 65 | 66 | async def write(self, content: str): 67 | if self.exists(): 68 | self.manager.session.file_versions[self.path].append(await self.read()) 69 | await self.manager.sandbox.write_file.remote.aio( 70 | self.path, content.expandtabs().encode() 71 | ) 72 | 73 | def __str__(self) -> str: 74 | return self.path.as_posix() 75 | 76 | 77 | @dataclass(kw_only=True, frozen=True) 78 | class EditSessionManager: 79 | SNIPPET_LINES: int = 4 80 | 81 | sandbox: "SandboxManager" 82 | session: EditSession 83 | 84 | async def _validate_request(self, request: TRequest): 85 | info = FileInfo( 86 | path=( 87 | request.path.relative_to(MOUNT_PATH) 88 | if MOUNT_PATH in request.path.as_posix() 89 | else request.path 90 | ), 91 | listing=[ 92 | FileEntry(**e) 93 | for e in await self.sandbox.stat_file.remote.aio(request.path) 94 | ], 95 | manager=self, 96 | ) 97 | if request.command != "create" and not info.exists(): 98 | raise ToolError( 99 | f"The path {request.path} does not exist. Please provide a valid path." 100 | ) 101 | if request.command == "create" and info.exists(): 102 | raise ToolError( 103 | f"File already exists at: {request.path}. Cannot overwrite files using command `create`." 104 | ) 105 | if request.command != "view" and info.is_dir(): 106 | raise ToolError( 107 | f"The path {request.path} is a directory and only the `view` command can be used on directories" 108 | ) 109 | if request.command == "view" and request.view_range and info.is_dir(): 110 | raise ToolError( 111 | "The `view_range` parameter is not allowed when `path` points to a directory." 112 | ) 113 | return info 114 | 115 | def _make_output(self, body: str, fname: str, start: int = 1): 116 | res = ToolResult(output=make_output(body, fname, start)) 117 | logger.info(f"edit_manager: {res}") 118 | return res 119 | 120 | @singledispatchmethod 121 | async def dispatch(self, request: TRequest) -> ToolResult: 122 | raise ToolError(f"Action {request.command} not supported") 123 | 124 | @dispatch.register 125 | async def view(self, request: ViewRequest): 126 | f = await self._validate_request(request) 127 | if f.is_dir(): 128 | res = await self.sandbox.run_command.local( 129 | "find", str(f.local_path), "-maxdepth", "2", "-not", "-path", r"'*/\.*'" 130 | ) 131 | if res.output: 132 | return ( 133 | ToolResult( 134 | output=f"Here's the files and directories up to 2 levels deep in {f}, excluding hidden items" 135 | ) 136 | + res 137 | ) 138 | else: 139 | return res 140 | 141 | lines = (await f.read()).splitlines(keepends=True) 142 | (start, end) = request.view_range or (1, -1) 143 | start, end = ( 144 | max(1, start), 145 | min(len(lines) + 1, len(lines) + 1 if end == -1 else end), 146 | ) 147 | return self._make_output( 148 | body="\n".join(lines[start - 1 : end]), 149 | fname=str(f), 150 | start=start, 151 | ) 152 | 153 | @dispatch.register(CreateRequest) 154 | async def create(self, request: CreateRequest): 155 | f = await self._validate_request(request) 156 | await f.write(request.file_text) 157 | return ToolResult(output=f"File created successfully at: {f}") 158 | 159 | async def _make_snippet( 160 | self, f: FileInfo, center: int, length: int = SNIPPET_LINES 161 | ): 162 | start, end = max(0, center - length), center + length 163 | snippet = "\n".join((await f.read()).split("\n")[start : end + 1]) 164 | return self._make_output(snippet, f"a snippet of {f}", start + 1) 165 | 166 | @dispatch.register(StrReplaceRequest) 167 | async def str_replace(self, request: StrReplaceRequest): 168 | import fuzzysearch 169 | 170 | f = await self._validate_request(request) 171 | content = await f.read() 172 | 173 | if request.old_str not in content and ( 174 | matches := fuzzysearch.find_near_matches( 175 | request.old_str, content, max_l_dist=3 176 | ) 177 | ): 178 | request.old_str = matches[0].matched 179 | 180 | if (occurrences := content.count(request.old_str)) == 0: 181 | raise ToolError( 182 | f"No replacement was performed, old_str `{request.old_str}` did not appear verbatim in {f}." 183 | ) 184 | elif occurrences > 1: 185 | lines = [ 186 | idx + 1 187 | for idx, line in enumerate(content.split("\n")) 188 | if request.old_str in line 189 | ] 190 | raise ToolError( 191 | f"No replacement was performed. Multiple occurrences of old_str `{request.old_str}` in lines {lines}. Please ensure it is unique." 192 | ) 193 | 194 | await f.write(content.replace(request.old_str, request.new_str)) 195 | replacement = content.split(request.old_str)[0].count("\n") 196 | 197 | return ( 198 | ToolResult(output=f"The file {f} has been edited. ") 199 | + await self._make_snippet( 200 | f, 201 | replacement, 202 | length=self.SNIPPET_LINES + len(request.new_str.splitlines()), 203 | ) 204 | + ToolResult( 205 | output="Review the changes and make sure they are as expected. Edit the file again if necessary." 206 | ) 207 | ) 208 | 209 | @dispatch.register(InsertRequest) 210 | async def insert(self, request: InsertRequest): 211 | f = await self._validate_request(request) 212 | content = await f.read() 213 | lines = content.splitlines(keepends=True) 214 | 215 | if request.insert_line < 0 or request.insert_line > len(lines): 216 | raise ToolError( 217 | f"Invalid `insert_line` parameter: {request.insert_line}. It should be within the range of lines of the file: {[0, len(lines)]}" 218 | ) 219 | 220 | lines = ( 221 | lines[: request.insert_line] 222 | + request.new_str.splitlines(keepends=True) 223 | + lines[request.insert_line :] 224 | ) 225 | await f.write("\n".join(lines)) 226 | 227 | return ( 228 | ToolResult(output=f"The file {f} has been edited. ") 229 | + await self._make_snippet( 230 | f, 231 | request.insert_line, 232 | length=self.SNIPPET_LINES + len(request.new_str.splitlines()), 233 | ) 234 | + ToolResult( 235 | output="Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary." 236 | ) 237 | ) 238 | 239 | @dispatch.register(UndoEditRequest) 240 | async def undo_edit(self, request: UndoEditRequest): 241 | f = await self._validate_request(request) 242 | await f.write(old_content := self.session.file_versions[f.path].pop()) 243 | return ToolResult( 244 | output=f"Last edit to {f} undone successfully." 245 | ) + self._make_output(old_content, str(f)) 246 | --------------------------------------------------------------------------------