├── tests ├── __init__.py ├── unit │ ├── envs │ │ ├── __init__.py │ │ ├── crm │ │ │ └── __init__.py │ │ ├── math │ │ │ └── __init__.py │ │ ├── mcp │ │ │ ├── __init__.py │ │ │ ├── provisioners │ │ │ │ ├── __init__.py │ │ │ │ ├── test_manual.py │ │ │ │ ├── test_local.py │ │ │ │ └── test_provisioners_utils.py │ │ │ └── test_mcp_utils.py │ │ ├── wikipedia │ │ │ ├── __init__.py │ │ │ └── test_wiki_env.py │ │ └── excel │ │ │ ├── test_inputs │ │ │ └── test.xlsx │ │ │ ├── test_excel_code_runner_mcp.py │ │ │ └── test_excel_env.py │ └── prompts │ │ └── test_tools.py ├── integration │ ├── __init__.py │ ├── adapters │ │ ├── __init__.py │ │ └── skyrl │ │ │ ├── __init__.py │ │ │ └── skyrl_adapter_integration.py │ └── envs │ │ ├── __init__.py │ │ ├── mcp │ │ ├── __init__.py │ │ ├── provisioners │ │ │ ├── __init__.py │ │ │ ├── utils.py │ │ │ ├── test_local_integration.py │ │ │ └── test_skypilot_integration.py │ │ └── conftest.py │ │ ├── math │ │ └── __init__.py │ │ ├── wikipedia │ │ └── __init__.py │ │ └── excel │ │ ├── test_inputs │ │ └── test.xlsx │ │ └── test_excel_utils.py └── conftest.py ├── src └── benchmax │ ├── adapters │ ├── __init__.py │ └── skyrl │ │ └── benchmax_data_process.py │ ├── envs │ ├── __init__.py │ ├── excel │ │ ├── workdir │ │ │ ├── __init__.py │ │ │ ├── mcp_config.yaml │ │ │ ├── setup.sh │ │ │ ├── reward_fn.py │ │ │ ├── excel_code_runner_mcp.py │ │ │ └── excel_utils.py │ │ ├── data_utils.py │ │ ├── README.md │ │ └── excel_env.py │ ├── math │ │ ├── workdir │ │ │ ├── setup.sh │ │ │ ├── mcp_config.yaml │ │ │ └── reward_fn.py │ │ ├── README.md │ │ └── math_env.py │ ├── crm │ │ ├── workdir │ │ │ ├── setup.sh │ │ │ ├── mcp_config.yaml │ │ │ └── reward_fn.py │ │ ├── crm_env.py │ │ └── README.md │ ├── mcp │ │ ├── example_workdir │ │ │ ├── setup.sh │ │ │ ├── mcp_config.yaml │ │ │ ├── reward_fn.py │ │ │ └── demo_mcp_server.py │ │ ├── __init__.py │ │ ├── provisioners │ │ │ ├── __init__.py │ │ │ ├── base_provisioner.py │ │ │ ├── manual_provisioner.py │ │ │ ├── utils.py │ │ │ └── skypilot_provisioner.py │ │ ├── utils.py │ │ └── README.md │ ├── types.py │ ├── README.md │ ├── wikipedia │ │ ├── README.md │ │ └── utils.py │ ├── how-to-extend-base-env.md │ └── base_env.py │ └── prompts │ ├── __init__.py │ └── tools.py ├── .env.example ├── static └── benchmax.png ├── pytest.ini ├── .gitignore ├── pyproject.toml ├── examples └── skyrl │ ├── run_benchmax_math.sh │ ├── run_benchmax_excel.sh │ ├── benchmax_math.py │ ├── benchmax_excel.py │ └── README.md └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/envs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/benchmax/adapters/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/benchmax/envs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/benchmax/prompts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/integration/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/envs/crm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/envs/math/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/envs/mcp/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/integration/adapters/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/integration/envs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/integration/envs/mcp/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/envs/wikipedia/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/benchmax/envs/excel/workdir/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/integration/envs/math/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/integration/adapters/skyrl/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/integration/envs/wikipedia/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/envs/mcp/provisioners/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/integration/envs/mcp/provisioners/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | WIKIPEDIA_API_KEYS=DUMMY_API_KEY_1,DUMMY_API_KEY_2,DUMMY_API_KEY_3 -------------------------------------------------------------------------------- /static/benchmax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgftinc/benchmax/HEAD/static/benchmax.png -------------------------------------------------------------------------------- /src/benchmax/envs/math/workdir/setup.sh: -------------------------------------------------------------------------------- 1 | # Install uv 2 | curl -LsSf 
https://astral.sh/uv/install.sh | sh -------------------------------------------------------------------------------- /src/benchmax/envs/crm/workdir/setup.sh: -------------------------------------------------------------------------------- 1 | uv pip install "simple-salesforce>=1.12.3" python-dateutil==2.9.0.post0 -------------------------------------------------------------------------------- /tests/unit/envs/excel/test_inputs/test.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgftinc/benchmax/HEAD/tests/unit/envs/excel/test_inputs/test.xlsx -------------------------------------------------------------------------------- /src/benchmax/envs/math/workdir/mcp_config.yaml: -------------------------------------------------------------------------------- 1 | mcpServers: 2 | calculator: 3 | command: uvx 4 | args: 5 | - mcp-server-calculator -------------------------------------------------------------------------------- /tests/integration/envs/excel/test_inputs/test.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgftinc/benchmax/HEAD/tests/integration/envs/excel/test_inputs/test.xlsx -------------------------------------------------------------------------------- /src/benchmax/envs/excel/workdir/mcp_config.yaml: -------------------------------------------------------------------------------- 1 | mcpServers: 2 | test_server: 3 | command: python 4 | args: 5 | - ${{ sync_workdir }}/excel_code_runner_mcp.py -------------------------------------------------------------------------------- /src/benchmax/envs/mcp/example_workdir/setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Setup script for test MCP server 3 | 4 | echo "Installing test MCP server dependencies..." 5 | 6 | echo "Test MCP server setup complete" -------------------------------------------------------------------------------- /src/benchmax/envs/mcp/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | MCP-based environment infrastructure for parallel rollout execution.
3 | """ 4 | 5 | from .parallel_mcp_env import ParallelMcpEnv 6 | from .server_pool import ServerPool, ServerInfo 7 | 8 | __all__ = [ 9 | "ParallelMcpEnv", 10 | "ServerPool", 11 | "ServerInfo", 12 | ] -------------------------------------------------------------------------------- /src/benchmax/envs/mcp/example_workdir/mcp_config.yaml: -------------------------------------------------------------------------------- 1 | # Test MCP configuration 2 | # This configuration is used for testing the MCP infrastructure 3 | # Use ${{ sync_workdir }} to refer to the location of the synchronized workdir in the remote machine 4 | 5 | mcpServers: 6 | test_server: 7 | command: python 8 | args: 9 | - ${{ sync_workdir }}/demo_mcp_server.py -------------------------------------------------------------------------------- /src/benchmax/envs/types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict, Optional, TypedDict 3 | 4 | 5 | class StandardizedExample(TypedDict): 6 | prompt: str 7 | ground_truth: Any 8 | init_rollout_args: Optional[Dict[str, Any]] 9 | 10 | 11 | @dataclass 12 | class ToolDefinition: 13 | """Definition of a tool's interface""" 14 | 15 | name: str 16 | description: str 17 | input_schema: Optional[Dict[str, Any]] = None 18 | -------------------------------------------------------------------------------- /src/benchmax/envs/mcp/provisioners/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Server provisioning strategies for ParallelMcpEnv. 3 | """ 4 | 5 | from .base_provisioner import BaseProvisioner 6 | from .manual_provisioner import ManualProvisioner 7 | from .local_provisioner import LocalProvisioner 8 | from .skypilot_provisioner import SkypilotProvisioner 9 | 10 | __all__ = [ 11 | "BaseProvisioner", 12 | "ManualProvisioner", 13 | "LocalProvisioner", 14 | "SkypilotProvisioner", 15 | ] 16 | -------------------------------------------------------------------------------- /src/benchmax/envs/crm/workdir/mcp_config.yaml: -------------------------------------------------------------------------------- 1 | mcpServers: 2 | test_server: 3 | command: python 4 | args: 5 | - ${{ sync_workdir }}/salesforce_mcp.py 6 | env: 7 | SALESFORCE_USERNAME: crmarena_b2b@gmaill.com 8 | SALESFORCE_PASSWORD: crmarenatest 9 | SALESFORCE_SECURITY_TOKEN: zdaqqSYBEQTjjLuq0zLUHkC3 10 | # Uncomment the following lines to use B2C credentials 11 | # SALESFORCE_USERNAME: crmarena_b2c@gmaill.com 12 | # SALESFORCE_PASSWORD: crmarenatest 13 | # SALESFORCE_SECURITY_TOKEN: 2AQCtK8MnnV4lJdRNF0DGCs1 -------------------------------------------------------------------------------- /src/benchmax/envs/excel/workdir/setup.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | OS="$(uname -s)" 5 | 6 | if [[ "$OS" == "Linux"* ]]; then 7 | echo "Detected Linux system. Installing LibreOffice (if not installed) and openpyxl..." 8 | if ! command -v libreoffice >/dev/null 2>&1; then 9 | sudo apt update -qq 10 | sudo apt install -y libreoffice >/dev/null 11 | fi 12 | uv pip install openpyxl 13 | elif [[ "$OS" == "Darwin"* || "$OS" == MINGW* || "$OS" == MSYS* || "$OS" == CYGWIN* ]]; then 14 | echo "Detected macOS/Windows system. Installing xlwings and openpyxl..." 
15 | uv pip install openpyxl xlwings 16 | else 17 | echo "Unsupported OS: $OS" >&2 18 | exit 1 19 | fi 20 | -------------------------------------------------------------------------------- /src/benchmax/envs/README.md: -------------------------------------------------------------------------------- 1 | # Envs 2 | 3 | This directory contains: 4 | ```bash 5 | ├── crm/ # Salesforce env (extends MCP) 6 | ├── excel/ # Excel env (extends MCP) 7 | ├── math/ # Math env (extends MCP) 8 | ├── mcp/ # MCP env class (extends BaseEnv) 9 | ├── wikipedia/ # Wikipedia env (extends BaseEnv) 10 | ├── types.py # Shared types 11 | └── base_env.py # Base env class 12 | ``` 13 | 14 | Pre-built envs like CRM, Excel, and Math use MCP, which has built-in multi-node parallelization, and are ready to use out of the box. To learn how to create your own parallelized MCP env, check out [this guide](mcp/README.md). 15 | 16 | If you want to manually extend `BaseEnv` (no multi-node support), you can check out the Wikipedia env or [follow this guide](how-to-extend-base-env.md). 17 | 18 | -------------------------------------------------------------------------------- /src/benchmax/envs/excel/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import tarfile 4 | 5 | 6 | def download_and_extract(url, output_path): 7 | """ 8 | Downloads a tar.gz file from the given URL and extracts it into output_path. 9 | """ 10 | # Ensure the output directory exists 11 | os.makedirs(output_path, exist_ok=True) 12 | 13 | # Determine local file name 14 | local_filename = os.path.join(output_path, os.path.basename(url)) 15 | 16 | # Download the file 17 | with requests.get(url, stream=True) as r: 18 | r.raise_for_status() 19 | with open(local_filename, "wb") as f: 20 | for chunk in r.iter_content(chunk_size=8192): 21 | f.write(chunk) 22 | 23 | # Extract the tar.gz 24 | with tarfile.open(local_filename, "r:gz") as tar: 25 | tar.extractall(path=output_path) 26 | 27 | print(f"Downloaded and extracted to {output_path}") 28 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | # Pytest configuration for benchmax 3 | pythonpath = .
4 | 5 | testpaths = tests 6 | 7 | # Test discovery 8 | python_files = test_*.py 9 | python_classes = Test* 10 | python_functions = test_* 11 | 12 | # Asyncio mode 13 | asyncio_mode = auto 14 | 15 | # Markers 16 | markers = 17 | slow: Slow-running tests (deselect with '-m "not slow"') 18 | remote: marks tests that require remote resources (deselect with '-m "not remote"') 19 | excel: mark tests that require opening of excel app 20 | unit: Fast, isolated tests 21 | 22 | # Default deselection to speed up CI runs 23 | addopts = -m "not slow and not remote and not excel" 24 | 25 | # Logging 26 | log_cli = false 27 | log_cli_level = INFO 28 | log_cli_format = %(asctime)s [%(levelname)8s] %(message)s 29 | log_cli_date_format = %Y-%m-%d %H:%M:%S 30 | 31 | # Warnings 32 | filterwarnings = 33 | ignore::DeprecationWarning 34 | ignore::PendingDeprecationWarning 35 | 36 | minversion = 7.0 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | 23 | # Virtual Environments 24 | .env 25 | .venv 26 | .python-version 27 | env/ 28 | venv/ 29 | ENV/ 30 | 31 | # IDE 32 | .idea/ 33 | .vscode/ 34 | *.swp 35 | *.swo 36 | 37 | # Distribution / packaging 38 | .Python 39 | build/ 40 | develop-eggs/ 41 | dist/ 42 | downloads/ 43 | eggs/ 44 | .eggs/ 45 | lib/ 46 | lib64/ 47 | parts/ 48 | sdist/ 49 | var/ 50 | wheels/ 51 | *.egg-info/ 52 | .installed.cfg 53 | *.egg 54 | 55 | # Unit test / coverage reports 56 | htmlcov/ 57 | .tox/ 58 | .coverage 59 | .coverage.* 60 | .cache 61 | nosetests.xml 62 | coverage.xml 63 | *.cover 64 | .hypothesis/ 65 | .pytest_cache/ 66 | 67 | # mypy 68 | .mypy_cache/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # macOS 74 | .DS_Store 75 | 76 | # Poetry 77 | poetry.lock 78 | 79 | # Generated workspaces 80 | workspaces/ 81 | outputs/ -------------------------------------------------------------------------------- /src/benchmax/envs/math/workdir/reward_fn.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import re 3 | from typing import Any, Callable, Dict, Union, Awaitable 4 | from fastmcp import Client 5 | from html import unescape 6 | 7 | RewardFunction = Callable[..., Union[float, Awaitable[float]]] 8 | 9 | 10 | async def text_match_reward( 11 | completion: str, 12 | ground_truth: str, 13 | mcp_client: Client, 14 | workspace: Path, 15 | **kwargs: Any, 16 | ) -> float: 17 | """ 18 | Reward = 1 if `ground_truth` (case-insensitive) appears anywhere *inside* 19 | the first <answer>...</answer> block of `completion`; otherwise 0. 20 | 21 | Falls back to 0 if the tag is missing or empty. 22 | """ 23 | 24 | # Grab only the text inside the first <answer>...</answer> pair (case-insensitive). 25 | m = re.search( 26 | r"<answer>(.*?)</answer>", completion, flags=re.IGNORECASE | re.DOTALL 27 | ) 28 | if m is None: 29 | return 0.0 30 | 31 | # Unescape any XML entities (&amp; → &, etc.) and normalise whitespace. 32 | answer_text = unescape(m.group(1)).strip().lower() 33 | 34 | try: 35 | # Try to interpret both as floats for numerical comparison.
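        # e.g. ground truth "3.0" matches an answer of "3"; if either side is not numeric, the ValueError branch below scores 0.0.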
36 | return float(float(ground_truth.lower()) == float(answer_text)) 37 | except ValueError: 38 | return 0.0 39 | 40 | 41 | # ------------------------------- 42 | # Export reward functions 43 | # ------------------------------- 44 | reward_functions: Dict[str, RewardFunction] = {"match": text_match_reward} 45 | -------------------------------------------------------------------------------- /src/benchmax/envs/mcp/provisioners/base_provisioner.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base provisioner interface for server provisioning strategies. 3 | """ 4 | 5 | from abc import ABC, abstractmethod 6 | from typing import List 7 | 8 | 9 | class BaseProvisioner(ABC): 10 | """ 11 | Abstract base class for server provisioning strategies. 12 | 13 | A provisioner is responsible for: 14 | 1. Starting/launching servers (returning their addresses) 15 | 2. Cleaning up resources when done 16 | """ 17 | 18 | @property 19 | @abstractmethod 20 | def num_servers(self) -> int: 21 | """ 22 | Total number of servers 23 | 24 | This reports the number of servers that are / will be provisioned. 25 | """ 26 | pass 27 | 28 | @abstractmethod 29 | async def provision_servers(self, api_secret: str) -> List[str]: 30 | """ 31 | Provision servers and return their addresses. 32 | 33 | Args: 34 | api_secret: Secret for server authentication. 35 | 36 | Returns: 37 | List of server addresses in "host:port" format. 38 | Example: ["localhost:8080", "192.168.1.10:8080"] 39 | """ 40 | pass 41 | 42 | @abstractmethod 43 | async def teardown(self) -> None: 44 | """ 45 | Tear down provisioned resources. 46 | 47 | This should clean up any resources created during provisioning, 48 | such as stopping processes, terminating cloud instances, etc. 49 | """ 50 | pass 51 | -------------------------------------------------------------------------------- /tests/integration/envs/mcp/provisioners/utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Optional 3 | import aiohttp 4 | 5 | 6 | 7 | async def check_health(address: str) -> bool: 8 | """ 9 | Check if a server's /health endpoint is responding. 10 | 11 | Args: 12 | address: Server address in format "host:port" 13 | 14 | Returns: 15 | True if server is healthy, False otherwise 16 | """ 17 | host, port = address.split(":") 18 | url = f"http://{host}:{port}/health" 19 | 20 | timeout_obj = aiohttp.ClientTimeout(total=2.0) 21 | 22 | try: 23 | async with aiohttp.ClientSession(timeout=timeout_obj) as session: 24 | async with session.get(url) as response: 25 | return response.status == 200 26 | except (aiohttp.ClientError, asyncio.TimeoutError): 27 | return False 28 | 29 | 30 | async def wait_for_server_health(address: str, timeout: float = 60.0) -> bool: 31 | """ 32 | Wait for a server to be healthy by polling its /health endpoint. 
33 | 34 | Args: 35 | address: Server address in format "host:port" 36 | timeout: Maximum time to wait in seconds 37 | 38 | Returns: 39 | True if server becomes healthy, False if timeout 40 | """ 41 | start_time = asyncio.get_event_loop().time() 42 | 43 | while (asyncio.get_event_loop().time() - start_time) < timeout: 44 | if await check_health(address): 45 | return True 46 | await asyncio.sleep(1.0) 47 | 48 | return False 49 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """ 2 | Shared fixtures 3 | """ 4 | 5 | import os 6 | import tempfile 7 | import uuid 8 | import pytest 9 | from pathlib import Path 10 | import importlib.util 11 | 12 | 13 | @pytest.fixture 14 | def unique_rollout_id() -> str: 15 | """Generate a unique rollout ID for testing.""" 16 | return f"test-rollout-{uuid.uuid4().hex[:8]}" 17 | 18 | 19 | @pytest.fixture 20 | def test_sync_dir(tmp_path: Path) -> Path: 21 | """Temporary directory for mocking syncdir (unit tests only).""" 22 | sync_dir = tmp_path / "sync" 23 | os.mkdir(sync_dir) 24 | return sync_dir 25 | 26 | 27 | @pytest.fixture(scope="session") 28 | def session_tmp_path() -> Path: 29 | """Temporary directory for test session.""" 30 | return Path(tempfile.mkdtemp(prefix="benchmax_test_session_")) 31 | 32 | 33 | @pytest.fixture(scope="session") 34 | def example_workdir() -> Path: 35 | """Path to example MCP workdir inside benchmax.envs.mcp.""" 36 | # Locate the mcp package dynamically 37 | spec = importlib.util.find_spec("benchmax.envs.mcp") 38 | if not spec or not spec.submodule_search_locations: 39 | raise RuntimeError("Could not locate benchmax.envs.mcp package") 40 | 41 | # The directory containing __init__.py 42 | mcp_pkg_dir = Path(spec.submodule_search_locations[0]) 43 | 44 | # Workdir is relative to that 45 | workdir = mcp_pkg_dir / "example_workdir" 46 | 47 | if not workdir.exists(): 48 | raise FileNotFoundError(f"Expected example_workdir not found at: {workdir}") 49 | 50 | return workdir 51 | -------------------------------------------------------------------------------- /src/benchmax/envs/math/README.md: -------------------------------------------------------------------------------- 1 | # Math Environment 2 | 3 | This environment provides capabilities for solving mathematical problems through a local calculator MCP server. 4 | 5 | ## Prerequisites 6 | 7 | Before using this environment, ensure you have: 8 | - Python 3.12 or later installed 9 | - The `mcp-server-calculator` package installed 10 | 11 | ## Installation 12 | 13 | ```bash 14 | pip install "benchmax[skypilot]" 15 | ``` 16 | 17 | Includes: 18 | - fastmcp: For MCP server functionality that enables calculator operations 19 | 20 | ## Usage 21 | 22 | Use `MathEnvLocal` to run the servers locally on the same machine as benchmax or use `MathEnvSkypilot` to parallelize the servers across multiple nodes. 23 | 24 | ## Available Tools 25 | 26 | The environment provides a calculator MCP tool through the server configuration: 27 | 28 | ### Calculator Tool 29 | The calculator tool is provided through a local MCP server that: 30 | - Handles mathematical computations 31 | - Takes mathematical expressions as input 32 | - Returns computed results 33 | - Supports standard mathematical operations 34 | 35 | ## Reward Function 36 | Written in workdir/reward_fn.py so that the reward function can easily be calculated with the remote node. 
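For example, with ground truth `42`, a completion whose final line is the following earns a reward of 1.0 under the scheme described next (a minimal sketch, assuming the `<answer>` tag convention from the environment's system prompt):

```
<answer>42</answer>
```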
37 | 38 | The evaluator extracts the text inside the first `<answer>...</answer>` block of the completion (matched case-insensitively), unescapes XML entities, and normalizes whitespace. It awards 1.0 when that text and the ground-truth string are numerically equal when parsed as floats; otherwise the reward is 0.0. If the tag pair is missing, empty, or not parseable as a number, the reward defaults to 0.0. This binary scheme incentivizes placing an exact, normalized final answer within the required `<answer>` tags. 39 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "benchmax" 3 | version = "0.1.2.dev5" 4 | description = "Framework-Agnostic RL Environments for LLM Fine-Tuning" 5 | readme = "README.md" 6 | authors = [{ name = "cgft.io" }] 7 | requires-python = "==3.12.*" 8 | dependencies = [ 9 | "aiohttp>=3.13.1", 10 | "asyncio>=4.0.0", 11 | "datasets>=4.0.0", 12 | "fastmcp~=2.12.0", 13 | "pyjwt>=2.10.1", 14 | "skypilot~=0.8.1", 15 | ] 16 | classifiers = [ 17 | "Programming Language :: Python :: 3", 18 | "Operating System :: OS Independent", 19 | ] 20 | 21 | [build-system] 22 | requires = ["setuptools>=61.0", "wheel"] 23 | build-backend = "setuptools.build_meta" 24 | 25 | [tool.setuptools.packages.find] 26 | where = ["src"] 27 | 28 | [dependency-groups] 29 | dev = [ 30 | "pytest>=8.4.2", 31 | "pytest-asyncio>=1.2.0", 32 | "python-dotenv>=1.2.1", 33 | "ruff>=0.14.2", 34 | ] 35 | skypilot = [ 36 | "skypilot[aws,gcp,azure]~=0.8.1", # Change this to your cloud provider 37 | "pip>=25.3", # Added as needed for skypilot launch 38 | "msrestazure>=0.6.4.post1", 39 | ] 40 | skyrl = [ 41 | "grpcio>=1.60.0", 42 | "hydra-core>=1.3.2", 43 | "omegaconf>=2.3.0", 44 | "ray>=2.48.0", 45 | "skyrl-gym>=0.1.1", 46 | "skyrl-train[vllm]>=0.2.0", 47 | ] 48 | excel = ["openpyxl>=3.1.5"] 49 | excel-mac-windows = ["openpyxl>=3.1.5", "xlwings>=0.33.16"] 50 | crm = ["python-dateutil>=2.9.0.post0", "simple-salesforce>=1.12.9"] 51 | 52 | [tool.uv] 53 | conflicts = [[{ group = "skypilot" }, { group = "skyrl" }]] 54 | 55 | [tool.uv.pip] 56 | extra = ["dev", "skypilot", "skyrl", "excel", "excel-mac-windows", "crm"] 57 | 58 | [tool.uv.extra-build-dependencies] 59 | flash-attn = [{ requirement = "torch", match-runtime = true }] 60 | 61 | [tool.uv.extra-build-variables] 62 | flash-attn = { FLASH_ATTENTION_SKIP_CUDA_BUILD = "TRUE" } 63 | -------------------------------------------------------------------------------- /src/benchmax/envs/wikipedia/README.md: -------------------------------------------------------------------------------- 1 | # Wikipedia Environment 2 | 3 | This environment provides capabilities for interacting with Wikipedia through its API, enabling search and article retrieval functionality. 4 | 5 | ## Prerequisites 6 | 7 | No additional software installation is required. However, we suggest instantiating the `WikipediaEnv` class with API keys for higher rate limits. 8 | 9 | ## Installation 10 | 11 | ```bash 12 | pip install "benchmax" 13 | ``` 14 | 15 | ## Usage 16 | 17 | Use `WikipediaEnv` to run the servers locally.
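A minimal usage sketch is shown below; the module path, constructor argument, and tool parameter names are assumptions for illustration (check the wikipedia env source and `test_wiki_env.py` for the exact signatures):

```python
from benchmax.envs.wikipedia.wiki_env import WikipediaEnv  # assumed module name

# API keys are optional but raise the Wikipedia API rate limits (see .env.example)
env = WikipediaEnv(api_keys=["DUMMY_API_KEY_1", "DUMMY_API_KEY_2"])

print([tool.name for tool in env.list_tools()])  # search_wikipedia, get_wikipedia_article

# BaseEnv-style tool calls: rollout id, tool name, then keyword arguments
results = env.run_tool("rollout-0", "search_wikipedia", query="Alan Turing", limit=3)
article = env.run_tool("rollout-0", "get_wikipedia_article", title="Alan Turing")
```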
18 | 19 | ## Available Tools 20 | 21 | The environment provides two MCP tools for Wikipedia interaction: 22 | 23 | ### search_wikipedia 24 | Searches Wikipedia articles by keyword: 25 | - Takes a search query and optional result limit 26 | - Returns a list of relevant articles with titles and snippets 27 | - Handles proper escaping and HTML cleanup 28 | 29 | ### get_wikipedia_article 30 | Fetches the full plaintext of a Wikipedia article: 31 | - Takes an exact article title as input 32 | - Returns the complete article text (up to specified character limit) 33 | - Handles redirects automatically 34 | - Returns plain text with HTML markup removed 35 | 36 | ## Reward Function 37 | The task scores 1.0 only if the ground-truth string, after XML-entity unescaping and whitespace normalization, exactly matches the text inside the first `<answer>...</answer>` block (case-insensitive); otherwise it returns 0.0. If that block is missing or empty, the reward defaults to 0.0. This binary scheme forces the model to place a single, exact final answer inside the first `<answer>` tag while allowing any additional explanation outside it. 38 | 39 | ## Features 40 | 41 | - API key rotation support to handle rate limits 42 | - HTML cleanup and entity unescaping 43 | - Configurable result limits 44 | - Error handling for API failures 45 | - Support for article redirects 46 | - Plain text extraction 47 | 48 | Dataset: https://huggingface.co/datasets/chiayewken/bamboogle -------------------------------------------------------------------------------- /src/benchmax/envs/excel/workdir/reward_fn.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Awaitable, Callable, Dict, Optional, Union 3 | 4 | from fastmcp import Client 5 | 6 | try: 7 | from excel_utils import compare_excel_cells 8 | except: 9 | # Added except local import for unit testing purposes 10 | from .excel_utils import compare_excel_cells 11 | 12 | RewardFunction = Callable[..., Union[float, Awaitable[float]]] 13 | 14 | 15 | def spreadsheet_comparison_reward( 16 | completion: str, 17 | ground_truth: dict, 18 | mcp_client: Client, 19 | workspace: Path, 20 | **kwargs: Any, 21 | ) -> float: 22 | """ 23 | Compares the output spreadsheet to the ground truth using cell values in the specified range. 24 | Returns 1.0 if all values match, else 0.0.
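    Expected kwargs (inferred from the checks below):
    - answer_position: the cell/range in which the answer is checked, passed to compare_excel_cells
    - output_filename: name of the model-produced workbook inside `workspace`
    - ground_truth_filename: name of the reference workbook inside `workspace`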
25 | """ 26 | answer_position: Optional[str] = kwargs.get("answer_position") 27 | output_filename: Optional[str] = kwargs.get("output_filename") 28 | ground_truth_filename: Optional[str] = kwargs.get("ground_truth_filename") 29 | 30 | if not answer_position or not output_filename or not ground_truth_filename: 31 | raise ValueError( 32 | "kwargs must contain 'answer_position', 'output_filename', and 'ground_truth_filename' fields" 33 | ) 34 | 35 | output_path = workspace / output_filename 36 | ground_truth_path = workspace / ground_truth_filename 37 | 38 | # Return 1.0 score if the output completely matches the ground truth 39 | try: 40 | match, _ = compare_excel_cells( 41 | str(ground_truth_path), str(output_path), answer_position 42 | ) 43 | return 1.0 if match else 0.0 44 | except Exception as e: 45 | print( 46 | f"Error comparing spreadsheets {ground_truth_path} and {output_path}: {e}" 47 | ) 48 | return 0.0 49 | 50 | 51 | # ------------------------------- 52 | # Export reward functions 53 | # ------------------------------- 54 | reward_functions: Dict[str, RewardFunction] = { 55 | "spreadsheet": spreadsheet_comparison_reward, 56 | } 57 | -------------------------------------------------------------------------------- /examples/skyrl/run_benchmax_math.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | # Colocated GRPO training+generation for Qwen2.5-1.5B-Instruct on a sample benchmax math environment. 4 | # uv run benchmax/adapters/skyrl/benchmax_data_process.py --local_dir ~/data/math --dataset_name dawidmt/arithmetic50 --env_path benchmax.envs.math.math_env.MathEnvLocal 5 | # bash examples/skyrl/run_benchmax_math.sh 6 | 7 | DATA_DIR="$HOME/data/math" 8 | NUM_GPUS=2 9 | ENV_CLASS="MathEnv" 10 | 11 | uv run --isolated --group skyrl -m examples.skyrl.benchmax_math \ 12 | data.train_data="['$DATA_DIR/train.parquet']" \ 13 | data.val_data="['$DATA_DIR/test.parquet']" \ 14 | trainer.algorithm.advantage_estimator="grpo" \ 15 | trainer.policy.model.path="Qwen/Qwen2.5-3B-Instruct" \ 16 | trainer.placement.colocate_all=true \ 17 | trainer.strategy=fsdp2 \ 18 | trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \ 19 | trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \ 20 | generator.num_inference_engines=$NUM_GPUS \ 21 | generator.inference_engine_tensor_parallel_size=1 \ 22 | trainer.epochs=20 \ 23 | trainer.update_epochs_per_batch=1 \ 24 | trainer.train_batch_size=32 \ 25 | trainer.policy_mini_batch_size=32 \ 26 | trainer.critic_mini_batch_size=32 \ 27 | trainer.micro_forward_batch_size_per_gpu=16 \ 28 | trainer.micro_train_batch_size_per_gpu=16 \ 29 | trainer.eval_batch_size=32 \ 30 | trainer.eval_before_train=true \ 31 | trainer.eval_interval=5 \ 32 | trainer.ckpt_interval=20 \ 33 | trainer.max_prompt_length=512 \ 34 | generator.sampling_params.max_generate_length=1024 \ 35 | trainer.policy.optimizer_config.lr=1.0e-7 \ 36 | trainer.algorithm.use_kl_loss=true \ 37 | generator.backend=vllm \ 38 | generator.run_engines_locally=true \ 39 | generator.weight_sync_backend=nccl \ 40 | generator.async_engine=true \ 41 | generator.batched=false \ 42 | environment.env_class=$ENV_CLASS \ 43 | generator.n_samples_per_prompt=5 \ 44 | generator.gpu_memory_utilization=0.8 \ 45 | trainer.logger="wandb" \ 46 | trainer.project_name="benchmax_math" \ 47 | trainer.run_name="benchmax_math" \ 48 | $@ -------------------------------------------------------------------------------- /examples/skyrl/run_benchmax_excel.sh: 
-------------------------------------------------------------------------------- 1 | set -x 2 | 3 | # Colocated GRPO training+generation for Qwen2.5-3B-Instruct on a sample benchmax excel environment. 4 | # uv run benchmax/adapters/skyrl/benchmax_data_process.py --local_dir ~/data/excel --dataset_name spreadsheetbench --env_path benchmax.envs.excel.excel_env.ExcelEnvLocal 5 | # bash examples/skyrl/run_benchmax_excel.sh 6 | 7 | DATA_DIR="$HOME/data/excel" 8 | NUM_GPUS=2 9 | ENV_CLASS="ExcelEnv" 10 | 11 | uv run --isolated --group skyrl --group excel -m examples.skyrl.benchmax_excel \ 12 | data.train_data="['$DATA_DIR/train.parquet']" \ 13 | data.val_data="['$DATA_DIR/test.parquet']" \ 14 | trainer.algorithm.advantage_estimator="grpo" \ 15 | trainer.policy.model.path="Qwen/Qwen2.5-3B-Instruct" \ 16 | trainer.placement.colocate_all=true \ 17 | trainer.strategy=fsdp2 \ 18 | trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \ 19 | trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \ 20 | generator.num_inference_engines=$NUM_GPUS \ 21 | generator.inference_engine_tensor_parallel_size=1 \ 22 | trainer.epochs=20 \ 23 | trainer.update_epochs_per_batch=1 \ 24 | trainer.train_batch_size=20 \ 25 | trainer.policy_mini_batch_size=20 \ 26 | trainer.critic_mini_batch_size=20 \ 27 | trainer.micro_forward_batch_size_per_gpu=10 \ 28 | trainer.micro_train_batch_size_per_gpu=10 \ 29 | trainer.eval_batch_size=20 \ 30 | trainer.eval_before_train=true \ 31 | trainer.eval_interval=5 \ 32 | trainer.ckpt_interval=10 \ 33 | trainer.max_prompt_length=3000 \ 34 | generator.sampling_params.max_generate_length=1024 \ 35 | trainer.policy.optimizer_config.lr=1.0e-7 \ 36 | trainer.algorithm.use_kl_loss=true \ 37 | generator.backend=vllm \ 38 | generator.run_engines_locally=true \ 39 | generator.weight_sync_backend=nccl \ 40 | generator.async_engine=true \ 41 | generator.batched=false \ 42 | environment.env_class=$ENV_CLASS \ 43 | generator.n_samples_per_prompt=5 \ 44 | generator.gpu_memory_utilization=0.8 \ 45 | trainer.logger="wandb" \ 46 | trainer.project_name="benchmax_excel" \ 47 | trainer.run_name="benchmax_excel" \ 48 | $@ -------------------------------------------------------------------------------- /tests/unit/envs/excel/test_excel_code_runner_mcp.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import os 3 | from pathlib import Path 4 | import pytest 5 | import shutil 6 | from benchmax.envs.excel.workdir.excel_code_runner_mcp import run_excel_code_impl 7 | from benchmax.envs.excel.workdir.excel_utils import excel_to_str_repr 8 | 9 | 10 | # Fixtures 11 | @pytest.fixture(scope="session") 12 | def test_xlsx_path() -> str: 13 | return str(Path(__file__).parent / "test_inputs" / "test.xlsx") 14 | 15 | 16 | @contextmanager 17 | def temporary_cwd(path: Path): 18 | old_cwd = os.getcwd() 19 | os.chdir(path) 20 | try: 21 | yield 22 | finally: 23 | os.chdir(old_cwd) 24 | 25 | 26 | def test_excel_to_str_repr_basic(test_xlsx_path: str): 27 | output = excel_to_str_repr(test_xlsx_path) 28 | assert "Sheet Name: Sheet1" in output 29 | assert "Sheet Name: Sheet2" in output 30 | # Check for some known cell values and styles 31 | assert "D3: 1" in output 32 | assert "H3: D" in output 33 | assert "D6: =SUM(D3:D5) -> 2" in output 34 | 35 | 36 | def test_run_excel_code_impl_success(tmp_path: Path, test_xlsx_path: str): 37 | # Copy input to a temp file for isolation 38 | input_path = tmp_path / "input.xlsx" 39 | output_path = tmp_path / "output.xlsx" 40 
| shutil.copyfile(test_xlsx_path, input_path) 41 | 42 | # User code: set D3 to 10 and save to output_path 43 | user_code = f''' 44 | from openpyxl import load_workbook 45 | wb = load_workbook("{input_path}") 46 | ws = wb["Sheet1"] 47 | ws["D3"].value = 10 48 | wb.save("{output_path}") 49 | wb.close() 50 | ''' 51 | 52 | # Run code with temporary cwd 53 | with temporary_cwd(tmp_path): 54 | result = run_excel_code_impl(user_code, "output.xlsx") 55 | 56 | assert "D3: 10" in result 57 | 58 | 59 | def test_run_excel_code_impl_error(tmp_path: Path): 60 | output_path = tmp_path / "output.xlsx" 61 | user_code = "raise ValueError('test error')" 62 | 63 | with temporary_cwd(tmp_path): 64 | result = run_excel_code_impl(user_code, str(output_path)) 65 | 66 | assert result.startswith("ERROR:") 67 | assert "test error" in result 68 | -------------------------------------------------------------------------------- /src/benchmax/envs/excel/workdir/excel_code_runner_mcp.py: -------------------------------------------------------------------------------- 1 | from fastmcp import FastMCP 2 | import subprocess 3 | import sys 4 | 5 | try: 6 | from excel_utils import excel_to_str_repr 7 | except Exception: 8 | # Added except local import for unit testing purposes 9 | from .excel_utils import excel_to_str_repr 10 | 11 | mcp = FastMCP( 12 | name="ExcelCodeRunner", 13 | instructions="This server provides a tool for running Python code to manipulate Excel files.", 14 | ) 15 | 16 | WHITE_LIKE_COLORS = [ 17 | "00000000", 18 | "FFFFFFFF", 19 | "FFFFFF00", 20 | ] 21 | 22 | 23 | def run_excel_code_impl(python_code: str, output_excel_path: str) -> str: 24 | """ 25 | Run Python code which should use openpyxl to manipulate an Excel file. 26 | Call load_workbook with the input excel path as specified by the user. 27 | Remember to save the workbook to the output path that you specified and then call close() so you do not overwrite the input file. 28 | 29 | If code executes with no errors, return the string representation of the Excel file with styles. 30 | If there are errors, return an error message. 31 | """ 32 | code_path = "script.py" 33 | # Write the user code to a file 34 | with open(code_path, "w") as f: 35 | f.write(python_code) 36 | try: 37 | subprocess.run( 38 | [sys.executable, code_path], check=True, capture_output=True, timeout=60 39 | ) 40 | except subprocess.CalledProcessError as e: 41 | return f"ERROR: User code failed: {e.stderr.decode()}" 42 | except Exception as e: 43 | return f"ERROR: Error running user code: {str(e)}" 44 | # Convert the manipulated Excel file to JSON with styles 45 | excel_str = excel_to_str_repr(output_excel_path) 46 | return excel_str 47 | 48 | 49 | @mcp.tool 50 | def run_excel_code(python_code: str, output_excel_path: str) -> str: 51 | """ 52 | Run Python code which should use openpyxl to manipulate an Excel file. 53 | If code executes with no errors, returns the string representation of the Excel file with styles. 54 | If there are errors, return an error message. 55 | """ 56 | return run_excel_code_impl(python_code, output_excel_path) 57 | 58 | 59 | if __name__ == "__main__": 60 | mcp.run(show_banner=False) 61 | -------------------------------------------------------------------------------- /src/benchmax/envs/mcp/provisioners/manual_provisioner.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manual provisioner for using pre-existing servers. 
3 | """ 4 | 5 | import logging 6 | from typing import List 7 | from .base_provisioner import BaseProvisioner 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class ManualProvisioner(BaseProvisioner): 13 | """ 14 | Provisioner for manually specified server addresses. 15 | 16 | Use this when you have already started servers and want to provide 17 | their addresses directly. Useful for: 18 | - Debugging with pre-started local servers 19 | - Testing against persistent test infrastructure 20 | - Using servers managed by external systems 21 | 22 | Example: 23 | provisioner = ManualProvisioner([ 24 | "localhost:8080", 25 | "localhost:8081", 26 | "192.168.1.10:8080" 27 | ]) 28 | """ 29 | 30 | def __init__(self, addresses: List[str]): 31 | """ 32 | Initialize with pre-existing server addresses. 33 | 34 | Args: 35 | addresses: List of server addresses in "host:port" format. 36 | """ 37 | if not addresses: 38 | raise ValueError("ManualProvisioner requires at least one address") 39 | 40 | self._addresses = addresses 41 | logger.info(f"ManualProvisioner configured with {len(addresses)} addresses") 42 | 43 | @property 44 | def num_servers(self) -> int: 45 | """ 46 | Total number of servers 47 | """ 48 | return len(self._addresses) 49 | 50 | async def provision_servers(self, api_secret: str) -> List[str]: 51 | """ 52 | Return the pre-configured server addresses. 53 | 54 | Args: 55 | api_secret: Unused in this function. Servers already have set ther api secret. 56 | 57 | Returns: 58 | The list of addresses provided during initialization. 59 | """ 60 | logger.info(f"Using {len(self._addresses)} manually configured servers") 61 | return self._addresses.copy() 62 | 63 | async def teardown(self) -> None: 64 | """ 65 | No-op teardown since servers are externally managed. 66 | 67 | ManualProvisioner does not start servers, so it does not stop them. 68 | The user is responsible for managing the server lifecycle. 69 | """ 70 | logger.info("ManualProvisioner teardown (servers externally managed)") 71 | -------------------------------------------------------------------------------- /src/benchmax/envs/math/math_env.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any 3 | import sky 4 | 5 | from benchmax.envs.mcp.parallel_mcp_env import ParallelMcpEnv 6 | from benchmax.envs.mcp.provisioners.base_provisioner import BaseProvisioner 7 | from benchmax.envs.mcp.provisioners.local_provisioner import LocalProvisioner 8 | from benchmax.envs.mcp.provisioners.skypilot_provisioner import SkypilotProvisioner 9 | from benchmax.envs.types import StandardizedExample 10 | 11 | SYSTEM_PROMPT = """Please use the tools provided to do any computation. 
12 | Write your complete answer on the final line only, within the xml tags <answer></answer>.\n 13 | """ 14 | 15 | 16 | class MathEnv(ParallelMcpEnv): 17 | """Environment for math problems, using local MCP tools.""" 18 | 19 | system_prompt: str = SYSTEM_PROMPT 20 | 21 | def __init__(self, workdir_path: Path, provisioner: BaseProvisioner, **kwargs): 22 | super().__init__(workdir_path=workdir_path, provisioner=provisioner, **kwargs) 23 | 24 | @classmethod 25 | def dataset_preprocess(cls, example: Any, **kwargs) -> StandardizedExample: 26 | return StandardizedExample( 27 | prompt=example.get("task", ""), 28 | ground_truth=example.get("answer", ""), 29 | init_rollout_args=None, 30 | ) 31 | 32 | 33 | class MathEnvLocal(MathEnv): 34 | """Import this env to run environment locally""" 35 | 36 | def __init__(self, num_local_servers: int = 5, **kwargs): 37 | workdir_path = Path(__file__).parent / "workdir" 38 | provisioner = LocalProvisioner( 39 | workdir_path=workdir_path, num_servers=num_local_servers 40 | ) 41 | super().__init__(workdir_path=workdir_path, provisioner=provisioner, **kwargs) 42 | 43 | 44 | class MathEnvSkypilot(MathEnv): 45 | """Import this env to run environment on any cloud (e.g. AWS / GCP / Azure) with Skypilot""" 46 | 47 | def __init__( 48 | self, 49 | cloud: sky.clouds.Cloud = sky.Azure(), 50 | num_nodes: int = 2, 51 | servers_per_node: int = 5, 52 | **kwargs, 53 | ): 54 | workdir_path = Path(__file__).parent / "workdir" 55 | provisioner = SkypilotProvisioner( 56 | workdir_path=workdir_path, 57 | cloud=cloud, 58 | num_nodes=num_nodes, 59 | servers_per_node=servers_per_node, 60 | ) 61 | super().__init__(workdir_path=workdir_path, provisioner=provisioner, **kwargs) 62 | -------------------------------------------------------------------------------- /examples/skyrl/benchmax_math.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import ray 3 | from ray.actor import ActorProxy 4 | from omegaconf import DictConfig 5 | import skyrl_gym 6 | from skyrl_train.utils import initialize_ray 7 | from skyrl_train.entrypoints.main_base import BasePPOExp, validate_cfg 8 | from skyrl_train.config.utils import CONFIG_DIR 9 | from skyrl_gym.envs import register 10 | 11 | from benchmax.adapters.skyrl.skyrl_adapter import ( 12 | cleanup_actor, 13 | get_or_create_benchmax_env_actor, 14 | load_benchmax_env_skyrl, 15 | ) 16 | from benchmax.adapters.benchmax_wrapper import BenchmaxEnv 17 | 18 | BENCHMAX_ACTOR_NAME = "BenchmaxEnvService" 19 | 20 | 21 | @ray.remote(num_cpus=1) 22 | def skyrl_entrypoint(cfg: DictConfig): 23 | actor = None 24 | try: 25 | # UNCOMMENT the following to run MathEnv with skypilot (comment the other) 26 | # from benchmax.envs.math.math_env import MathEnvSkypilot 27 | # import sky 28 | 29 | # actor = get_or_create_benchmax_env_actor( 30 | # MathEnvSkypilot, 31 | # env_kwargs={ 32 | # "cloud": sky.Azure(), 33 | # "num_nodes": 5, 34 | # "servers_per_node": 32, 35 | # }, # samples / prompt * batch size = 160 = 32 * 5 36 | # actor_name=BENCHMAX_ACTOR_NAME, 37 | # ) 38 | # UNCOMMENT the following to run MathEnv locally (comment the other) 39 | from benchmax.envs.math.math_env import MathEnvLocal 40 | 41 | actor = get_or_create_benchmax_env_actor( 42 | MathEnvLocal, 43 | env_kwargs={ 44 | "num_local_servers": 160 45 | }, # samples / prompt * batch size = 160 46 | actor_name=BENCHMAX_ACTOR_NAME, 47 | ) 48 | register( 49 | id="MathEnv", 50 | entry_point=load_benchmax_env_skyrl, 51 | kwargs={"actor": actor}, 52 | ) 53 | skyrl_gym.pprint_registry()
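        # At this point "MathEnv" is registered with skyrl_gym, so the trainer run
        # below can construct it via environment.env_class=MathEnv (see
        # run_benchmax_math.sh); tool calls are routed through the shared Benchmax actor created above.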
54 | 55 | exp = BasePPOExp(cfg) 56 | exp.run() 57 | 58 | finally: 59 | cleanup_actor(actor) 60 | 61 | 62 | @hydra.main( 63 | config_path=str(CONFIG_DIR), 64 | config_name="ppo_base_config", 65 | version_base=None, 66 | ) 67 | def main(cfg: DictConfig) -> None: 68 | try: 69 | validate_cfg(cfg) 70 | initialize_ray(cfg) 71 | ray.get(skyrl_entrypoint.remote(cfg)) 72 | finally: 73 | try: 74 | benchmax_actor: ActorProxy[BenchmaxEnv] = ray.get_actor(BENCHMAX_ACTOR_NAME) 75 | cleanup_actor(benchmax_actor) 76 | except Exception: 77 | pass 78 | finally: 79 | ray.shutdown() 80 | 81 | 82 | if __name__ == "__main__": 83 | main() 84 | -------------------------------------------------------------------------------- /examples/skyrl/benchmax_excel.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import ray 3 | from ray.actor import ActorProxy 4 | from omegaconf import DictConfig 5 | import skyrl_gym 6 | from skyrl_train.utils import initialize_ray 7 | from skyrl_train.entrypoints.main_base import BasePPOExp, validate_cfg 8 | from skyrl_train.config.utils import CONFIG_DIR 9 | from skyrl_gym.envs import register 10 | 11 | from benchmax.adapters.skyrl.skyrl_adapter import ( 12 | cleanup_actor, 13 | get_or_create_benchmax_env_actor, 14 | load_benchmax_env_skyrl, 15 | ) 16 | from benchmax.adapters.benchmax_wrapper import BenchmaxEnv 17 | 18 | BENCHMAX_ACTOR_NAME = "BenchmaxEnvService" 19 | 20 | 21 | @ray.remote(num_cpus=1) 22 | def skyrl_entrypoint(cfg: DictConfig): 23 | actor = None 24 | try: 25 | # UNCOMMENT the following to run ExcelEnv with skypilot (comment the other) 26 | from benchmax.envs.excel.excel_env import ExcelEnvSkypilot 27 | import sky 28 | 29 | actor = get_or_create_benchmax_env_actor( 30 | ExcelEnvSkypilot, 31 | env_kwargs={ 32 | "cloud": sky.Azure(), 33 | "num_nodes": 5, 34 | "servers_per_node": 20, 35 | }, # samples / prompt * batch size = 100 = 20 * 5 36 | actor_name=BENCHMAX_ACTOR_NAME, 37 | ) 38 | 39 | # UNCOMMENT the following to run ExcelEnv locally (comment the other) 40 | # from benchmax.envs.excel.excel_env import ExcelEnvLocal 41 | 42 | # actor = get_or_create_benchmax_env_actor( 43 | # ExcelEnvLocal, 44 | # env_kwargs={ 45 | # "num_local_servers": 100 46 | # }, # samples / prompt * batch size = 100 47 | # actor_name=BENCHMAX_ACTOR_NAME, 48 | # ) 49 | 50 | register( 51 | id="ExcelEnv", 52 | entry_point=load_benchmax_env_skyrl, 53 | kwargs={"actor": actor}, 54 | ) 55 | skyrl_gym.pprint_registry() 56 | 57 | exp = BasePPOExp(cfg) 58 | exp.run() 59 | 60 | finally: 61 | cleanup_actor(actor) 62 | 63 | 64 | @hydra.main( 65 | config_path=str(CONFIG_DIR), 66 | config_name="ppo_base_config", 67 | version_base=None, 68 | ) 69 | def main(cfg: DictConfig) -> None: 70 | try: 71 | validate_cfg(cfg) 72 | initialize_ray(cfg) 73 | ray.get(skyrl_entrypoint.remote(cfg)) 74 | finally: 75 | try: 76 | benchmax_actor: ActorProxy[BenchmaxEnv] = ray.get_actor(BENCHMAX_ACTOR_NAME) 77 | cleanup_actor(benchmax_actor) 78 | except Exception: 79 | pass 80 | finally: 81 | ray.shutdown() 82 | 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /tests/unit/envs/mcp/provisioners/test_manual.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from benchmax.envs.mcp.provisioners import ManualProvisioner 3 | 4 | 5 | class TestManualProvisioner: 6 | """Tests for ManualProvisioner class.""" 7 | 8 | def 
test_init_with_valid_addresses(self): 9 | """Test initialization with valid addresses.""" 10 | addresses = ["localhost:8080", "192.168.1.10:8080"] 11 | provisioner = ManualProvisioner(addresses) 12 | 13 | assert provisioner._addresses == addresses 14 | 15 | def test_init_with_empty_list_raises_error(self): 16 | """Test that empty address list raises ValueError.""" 17 | with pytest.raises(ValueError, match="at least one address"): 18 | ManualProvisioner([]) 19 | 20 | @pytest.mark.asyncio 21 | async def test_provision_servers_returns_addresses(self): 22 | """Test that provision_servers returns configured addresses.""" 23 | addresses = ["localhost:8080", "localhost:8081"] 24 | provisioner = ManualProvisioner(addresses) 25 | 26 | result = await provisioner.provision_servers("dummy-api-secret") 27 | 28 | assert result == addresses 29 | 30 | @pytest.mark.asyncio 31 | async def test_provision_servers_returns_copy(self): 32 | """Test that provision_servers returns a copy, not the original list.""" 33 | addresses = ["localhost:8080"] 34 | provisioner = ManualProvisioner(addresses) 35 | 36 | result = await provisioner.provision_servers("dummy-api-secret") 37 | result.append("modified") 38 | 39 | # Original should be unchanged 40 | assert provisioner._addresses == ["localhost:8080"] 41 | 42 | @pytest.mark.asyncio 43 | async def test_teardown_is_noop(self): 44 | """Test that teardown completes without error (no-op).""" 45 | provisioner = ManualProvisioner(["localhost:8080"]) 46 | await provisioner.teardown() # Should not raise 47 | 48 | @pytest.mark.asyncio 49 | async def test_multiple_provision_calls(self): 50 | """Test that provision_servers can be called multiple times.""" 51 | provisioner = ManualProvisioner(["localhost:8080"]) 52 | 53 | result1 = await provisioner.provision_servers("dummy-api-secret") 54 | result2 = await provisioner.provision_servers("dummy-api-secret") 55 | 56 | assert result1 == result2 57 | 58 | @pytest.mark.asyncio 59 | async def test_provision_after_teardown(self): 60 | """Test that provision_servers works after teardown.""" 61 | provisioner = ManualProvisioner(["localhost:8080"]) 62 | 63 | await provisioner.provision_servers("dummy-api-secret") 64 | await provisioner.teardown() 65 | result = await provisioner.provision_servers("dummy-api-secret") 66 | 67 | assert result == ["localhost:8080"] 68 | -------------------------------------------------------------------------------- /src/benchmax/envs/excel/README.md: -------------------------------------------------------------------------------- 1 | # Excel Environment 2 | 3 | This environment provides capabilities for interacting with Excel files through either LibreOffice (Linux) or Microsoft Excel (Windows/macOS). 
4 | 5 | This is based on the [SpreadsheetBench Benchmark](https://spreadsheetbench.github.io/) 6 | 7 | ## Prerequisites 8 | 9 | **Important**: Before using this environment, ensure you have the appropriate spreadsheet application installed: 10 | - **Linux**: LibreOffice must be installed 11 | ```bash 12 | sudo apt install libreoffice 13 | ``` 14 | - **Windows/macOS**: Microsoft Excel must be installed 15 | 16 | ## Installation 17 | 18 | ### Linux 19 | ```bash 20 | pip install "benchmax[excel,skypilot]" 21 | ``` 22 | Includes: 23 | - openpyxl: For Excel file manipulation 24 | - fastmcp: For MCP server functionality 25 | 26 | ### Windows/macOS 27 | ```bash 28 | pip install "benchmax[excel-mac-windows,skypilot]" 29 | ``` 30 | Includes: 31 | - openpyxl: For Excel file manipulation 32 | - xlwings: For direct Excel application interaction 33 | - fastmcp: For MCP server functionality 34 | 35 | ## Usage 36 | 37 | Use `ExcelEnvLocal` to run the servers locally on the same machine as benchmax or use `ExcelEnvSkypilot` to parallelize the servers across multiple nodes. 38 | 39 | ## Available Tool 40 | 41 | The environment provides a single MCP tool for Excel manipulation: 42 | 43 | ### run_excel_code 44 | Executes Python code that uses openpyxl to manipulate Excel files. The tool: 45 | - Takes Python code and an output Excel path as input 46 | - Runs the code in a controlled environment 47 | - Returns a string representation of the modified Excel file 48 | - Preserves spreadsheet formatting (colors, fonts, styles) 49 | - Handles both cell-level and sheet-level operations 50 | 51 | ## Reward Functions 52 | 53 | Reward functions measure how well a generated spreadsheet matches the expected output. 54 | 55 | ### Default Reward Function 56 | 57 | The built-in reward function compares the output spreadsheet against a ground truth using only the cells specified in the task. 58 | 59 | - If all the values in the relevant cells are correct, the reward is **1.0** 60 | - If there are any mismatches, the reward is **0.0** 61 | 62 | ### Comparison Strategy 63 | 64 | - Compares evaluated values, not formulas (e.g., it checks the result `10`, not the formula `=5+5`) 65 | - Only compares the cells within the defined answer range 66 | - Supports multiple cell types like numbers, strings, times, and dates 67 | - Can optionally check visual formatting, such as: 68 | - Fill color (background) 69 | - Font color 70 | 71 | ### Error Handling 72 | 73 | If the comparison fails due to a missing file, invalid cell reference, or formatting error, the reward function will return **0.0** and log a helpful message for debugging. 74 | 75 | ### Outcome 76 | 77 | - **1.0** score if the generated spreadsheet is fully correct 78 | - **0.0** if any discrepancy is found -------------------------------------------------------------------------------- /src/benchmax/envs/how-to-extend-base-env.md: -------------------------------------------------------------------------------- 1 | ## Create your env from BaseEnv 2 | 3 | ### 1. **Define the system prompt** 4 | 5 | This helps instruct the model on how to interact with the tool and format output. 6 | 7 | ```python 8 | SYSTEM_PROMPT = """Use the `evaluate` tool to perform any computation. 9 | Write your final answer on the last line inside <answer>...</answer>. 10 | """ 11 | ``` 12 | 13 | ### 2.
**Create a reward function** 14 | 15 | We'll score the model 1.0 if it places the correct answer inside `<answer>...</answer>` tags: 16 | 17 | ```python 18 | import re 19 | from html import unescape 20 | from pathlib import Path 21 | 22 | def reward_func(prompt: str, completion: str, ground_truth: str, workspace: Path, **kwargs) -> float: 23 | m = re.search(r'<answer>(.*?)</answer>', completion, flags=re.IGNORECASE | re.DOTALL) 24 | if not m: 25 | return 0.0 26 | answer_text = unescape(m.group(1)).strip().lower() 27 | return float(answer_text == ground_truth.lower()) 28 | ``` 29 | 30 | ### 3. **Define your math tool** 31 | 32 | A simple safe `eval` for math expressions: 33 | 34 | ```python 35 | def evaluate_expression(expr: str) -> str: 36 | try: 37 | result = eval(expr, {"__builtins__": {}}) 38 | return str(result) 39 | except Exception as e: 40 | return f"Error: {str(e)}" 41 | ``` 42 | 43 | ### 4. **Create the environment class** 44 | 45 | Bring it all together in a subclass of `BaseEnv`: 46 | 47 | ```python 48 | class SimpleMathEnv(BaseEnv): 49 | system_prompt: str = SYSTEM_PROMPT 50 | _reward_funcs: List[RewardFunction] = [reward_func] 51 | 52 | def __init__(self): 53 | eval_tool = ToolDefinition( 54 | name="evaluate", 55 | description="Safely evaluate a math expression like '2 + 3 * 4'.", 56 | input_schema={ 57 | "type": "object", 58 | "properties": { 59 | "expr": { 60 | "type": "string", 61 | "description": "Math expression to evaluate.", 62 | }, 63 | }, 64 | "required": ["expr"], 65 | } 66 | ) 67 | self.tools: Dict[str, Tuple[ToolDefinition, Callable]] = { 68 | "evaluate": (eval_tool, evaluate_expression) 69 | } 70 | def dataset_preprocess(self, example: dict) -> StandardizedExample: 71 | return { 72 | "prompt": f"Question: {example['question']}\n\nWrite your answer below.", 73 | "ground_truth": example.get("answer", ""), 74 | "init_rollout_args": {} 75 | } 76 | 77 | def list_tools(self) -> List[ToolDefinition]: 78 | return [tool_def for tool_def, _ in self.tools.values()] 79 | 80 | def run_tool(self, rollout_id: str, tool_name: str, **tool_args) -> Any: 81 | _, tool_fn = self.tools[tool_name] 82 | return tool_fn(**tool_args) 83 | ``` 84 | 85 | -------------------------------------------------------------------------------- /src/benchmax/envs/wikipedia/utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import html 3 | import re 4 | import threading 5 | from typing import Any, Dict, List, Optional 6 | 7 | import aiohttp 8 | 9 | 10 | def clean_html(raw: str) -> str: 11 | """Strip HTML tags and unescape entities.""" 12 | text = re.sub(r"<[^>]+>", "", raw) 13 | return html.unescape(text) 14 | 15 | 16 | class APIKeyRotator: 17 | """Thread-safe round-robin iterator over API keys.""" 18 | 19 | def __init__(self, keys: Optional[List[str]] = None): 20 | self._keys: List[str] = keys or [] 21 | self._lock = threading.Lock() 22 | self._idx = 0 23 | 24 | def next(self) -> Optional[str]: 25 | """Return the next key, or None if no keys configured.""" 26 | if not self._keys: 27 | return None 28 | with self._lock: 29 | key = self._keys[self._idx] 30 | self._idx = (self._idx + 1) % len(self._keys) 31 | return key 32 | 33 | 34 | class RateLimitExceeded(Exception): 35 | """Raised when the API repeatedly returns HTTP 429.""" 36 | 37 | 38 | async def safe_request( 39 | method: str, 40 | url: str, 41 | *, 42 | headers: Dict[str, str], 43 | params: Dict[str, Any] | None = None, 44 | timeout: float = 10.0, 45 | json: Any | None = None, 46 | max_retries: int = 3, 47 |
retry_delay_seconds: float = 20, 48 | rate_limit_seconds: float = 2.5, 49 | ) -> Optional[aiohttp.ClientResponse]: 50 | """ 51 | Async HTTP request with exponential backoff on 429 rate limits. 52 | 53 | Args: 54 | method: HTTP method (GET, POST, etc.) 55 | url: Target URL 56 | headers: Request headers 57 | params: Query parameters 58 | timeout: Request timeout in seconds 59 | json: JSON body for request 60 | max_retries: Maximum retry attempts on 429 61 | retry_delay_seconds: Base delay between retries 62 | rate_limit_seconds: Initial delay before first request 63 | 64 | Returns: 65 | aiohttp.ClientResponse object 66 | 67 | Raises: 68 | RateLimitExceeded: When max retries exhausted on 429 errors 69 | """ 70 | await asyncio.sleep(rate_limit_seconds) 71 | 72 | async with aiohttp.ClientSession() as session: 73 | for attempt in range(max_retries + 1): 74 | async with session.request( 75 | method, 76 | url, 77 | headers=headers, 78 | params=params, 79 | json=json, 80 | timeout=aiohttp.ClientTimeout(total=timeout), 81 | ) as resp: 82 | if resp.status != 429: 83 | # Read response content before returning 84 | content = await resp.read() 85 | # Create a new response object with the content 86 | resp._body = content 87 | return resp 88 | 89 | if attempt == max_retries: 90 | raise RateLimitExceeded( 91 | f"Rate limit hit and {max_retries} retries exhausted." 92 | ) 93 | 94 | print(f"Rate limit hit, retrying in {retry_delay_seconds:.1f}s...") 95 | await asyncio.sleep(retry_delay_seconds) 96 | -------------------------------------------------------------------------------- /tests/unit/prompts/test_tools.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, List 2 | from benchmax.envs.base_env import ToolDefinition 3 | from benchmax.prompts.tools import ( 4 | mcp2openai, 5 | parse_hermes_tool_call, 6 | render_tools_prompt, 7 | ) 8 | 9 | 10 | def test_mcp2openai(): 11 | # Test basic conversion 12 | tool_def = ToolDefinition( 13 | name="test_tool", 14 | description="A test tool", 15 | input_schema={"type": "object", "properties": {"arg1": {"type": "string"}}}, 16 | ) 17 | result = mcp2openai(tool_def) 18 | 19 | assert result["type"] == "function" 20 | assert result["function"]["name"] == "test_tool" 21 | assert result["function"]["description"] == "A test tool" 22 | assert result["function"]["parameters"] == { 23 | "type": "object", 24 | "properties": {"arg1": {"type": "string"}}, 25 | "required": [], 26 | } 27 | assert result["function"]["strict"] is False 28 | 29 | # Test with empty input schema 30 | tool_def_no_schema = ToolDefinition( 31 | name="empty_tool", description="Tool with no schema", input_schema=None 32 | ) 33 | result_no_schema = mcp2openai(tool_def_no_schema) 34 | assert result_no_schema["function"]["parameters"] == {"required": []} 35 | 36 | 37 | def test_parse_hermes_tool_call(): 38 | # Test single tool call 39 | single_call = """{"name": "get_weather", "arguments": {"location": "New York"}}""" 40 | result: List[Dict[str, Any]] = parse_hermes_tool_call(single_call) 41 | assert len(result) == 1 42 | assert result[0]["name"] == "get_weather" 43 | assert result[0]["arguments"]["location"] == "New York" 44 | 45 | # Test multiple tool calls 46 | multiple_calls = """ 47 | {"name": "tool1", "arguments": {"arg1": "value1"}} 48 | {"name": "tool2", "arguments": {"arg2": "value2"}} 49 | """ 50 | result: List[Dict[str, Any]] = parse_hermes_tool_call(multiple_calls) 51 | assert len(result) == 2 52 | assert result[0]["name"] == "tool1" 
53 | assert result[1]["name"] == "tool2" 54 | 55 | # Test empty string 56 | assert parse_hermes_tool_call("") == [] 57 | 58 | 59 | def test_render_tools_prompt(): 60 | # Test with empty tool list 61 | assert render_tools_prompt([]) == "" 62 | 63 | # Test with single tool 64 | tool_def = ToolDefinition( 65 | name="test_tool", 66 | description="A test tool", 67 | input_schema={"type": "object", "properties": {"arg1": {"type": "string"}}}, 68 | ) 69 | result = render_tools_prompt([tool_def], system_message="Test System Message") 70 | 71 | assert "Test System Message" in result 72 | assert "# Tools" in result 73 | assert "<tools>" in result 74 | assert "</tools>" in result 75 | assert "test_tool" in result 76 | assert "<tool_call>" in result 77 | assert "</tool_call>" in result 78 | 79 | # Test with multiple tools 80 | tool_def2 = ToolDefinition( 81 | name="another_tool", 82 | description="Another test tool", 83 | input_schema={"type": "object", "properties": {"arg2": {"type": "number"}}}, 84 | ) 85 | result_multiple = render_tools_prompt([tool_def, tool_def2]) 86 | assert "test_tool" in result_multiple 87 | assert "another_tool" in result_multiple 88 | -------------------------------------------------------------------------------- /src/benchmax/prompts/tools.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Dict, List 3 | 4 | from benchmax.envs.types import ToolDefinition 5 | 6 | def mcp2openai(mcp_tool: ToolDefinition) -> dict: 7 | """Convert a ToolDefinition to an OpenAI Function Call format.""" 8 | openai_format = { 9 | "type": "function", 10 | "function": { 11 | "name": mcp_tool.name, 12 | "description": mcp_tool.description, 13 | "parameters": mcp_tool.input_schema or {}, 14 | "strict": False, 15 | }, 16 | } 17 | if not openai_format["function"]["parameters"].get("required", None): 18 | openai_format["function"]["parameters"]["required"] = [] 19 | return openai_format 20 | 21 | def parse_hermes_tool_call(text: str) -> List[Dict[str, str]]: 22 | """ 23 | Parse a tool call from Hermes XML format. 24 | Example: 25 | <tool_call> 26 | {"name": "get_weather", "arguments": {"location": "New York"}} 27 | </tool_call> 28 | <tool_call> 29 | {"name": "get_weather", "arguments": {"location": "New York"}} 30 | </tool_call> 31 | """ 32 | import re 33 | import json 34 | # Match all tool call XML tags and extract the JSON content 35 | matches = re.finditer(r'<tool_call>(.*?)</tool_call>', text, re.DOTALL) 36 | tool_calls = [] 37 | 38 | for match in matches: 39 | tool_call_json = match.group(1).strip() 40 | try: 41 | tool_calls.append(json.loads(tool_call_json)) 42 | except json.JSONDecodeError: 43 | return [] 44 | 45 | return tool_calls if tool_calls else [] 46 | 47 | def render_tools_prompt( 48 | tool_definitions: List[ToolDefinition], 49 | system_message: str = "" 50 | ) -> str: 51 | """ 52 | Build the prompt block that advertises the available function tools to the model. 53 | 54 | Parameters 55 | ---------- 56 | tool_definitions : list[ToolDefinition] 57 | A list of tool definitions, converted internally to the OpenAI Tools / function-calling format. 58 | system_message : str, optional 59 | The system message that will be placed at the top of the prompt 60 | (defaults to an empty string). 61 | 62 | Returns 63 | ------- 64 | str 65 | A fully-rendered prompt string with system message and tool information.
66 | """ 67 | tool_schema = [mcp2openai(tool_def) for tool_def in tool_definitions] 68 | if not tool_schema: 69 | return system_message 70 | 71 | # Header 72 | lines = [system_message, "", "# Tools", "", 73 | "You may call one or more functions to assist with the user query.", 74 | "", 75 | "You are provided with function signatures within XML tags:", 76 | ""] 77 | 78 | # One line-per-tool JSON dump (compact, no extra spaces) 79 | for tool in tool_schema: 80 | lines.append(json.dumps(tool, separators=(",", ":"))) 81 | 82 | lines.extend([ 83 | "", 84 | "", 85 | "For each function call, return a json object with function name and arguments within XML tags:", 86 | "", 87 | "{\"name\": , \"arguments\": }", 88 | "", 89 | ]) 90 | 91 | return "\n".join(lines) 92 | 93 | -------------------------------------------------------------------------------- /src/benchmax/envs/mcp/provisioners/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for server provisioners 3 | """ 4 | 5 | import re 6 | import shutil 7 | import tempfile 8 | from pathlib import Path 9 | from typing import List, Optional 10 | 11 | 12 | def setup_sync_dir(workdir_path: Path) -> Path: 13 | """ 14 | Set up a temporary directory to sync the workdir contents. 15 | 16 | This creates a temp directory and copies: 17 | 1. proxy_server.py from the mcp/ directory 18 | 2. All contents of the provided workdir_path 19 | 20 | Args: 21 | workdir_path: Path to workdir containing mcp_config.yaml, setup.sh, etc. 22 | 23 | Returns: 24 | Path to the temporary sync directory. 25 | 26 | Raises: 27 | FileNotFoundError: If required files are missing. 28 | ValueError: If workdir_path is not a directory. 29 | """ 30 | sync_dir = Path(tempfile.mkdtemp(prefix="benchmax_skypilot_")) 31 | 32 | try: 33 | # Copy proxy_server.py (located in the mcp/ directory) 34 | src_server_path = Path(__file__).parent.parent / "proxy_server.py" 35 | if not src_server_path.exists(): 36 | raise FileNotFoundError( 37 | f"Expected proxy_server.py at {src_server_path}, but not found." 38 | ) 39 | shutil.copy(src_server_path, sync_dir / "proxy_server.py") 40 | 41 | # Validate workdir exists and is a directory 42 | if not workdir_path.exists(): 43 | raise FileNotFoundError( 44 | f"Expected workdir_path at {workdir_path}, but not found." 45 | ) 46 | if not workdir_path.is_dir(): 47 | raise ValueError( 48 | f"Expected workdir_path at {workdir_path} to be a directory." 49 | ) 50 | 51 | # Validate required files in the workdir using regex patterns 52 | required_patterns = { 53 | r"^reward_fn\.py$": "reward_fn.py", 54 | r"^setup\.sh$": "setup.sh", 55 | r"^mcp_config\.(yaml|yml)$": "mcp_config.yaml or mcp_config.yml", 56 | } 57 | 58 | workdir_files = {f.name for f in workdir_path.iterdir() if f.is_file()} 59 | 60 | for pattern, description in required_patterns.items(): 61 | pattern_re = re.compile(pattern) 62 | if not any(pattern_re.match(filename) for filename in workdir_files): 63 | raise FileNotFoundError( 64 | f"Required file matching '{description}' not found in workdir_path '{workdir_path}'." 65 | ) 66 | 67 | # Copy all contents of the workdir 68 | shutil.copytree(workdir_path, sync_dir, dirs_exist_ok=True) 69 | 70 | except Exception: 71 | shutil.rmtree(sync_dir, ignore_errors=True) 72 | raise 73 | 74 | return sync_dir 75 | 76 | 77 | def cleanup_dir(path: Optional[Path]) -> None: 78 | """ 79 | Recursively delete a directory if it exists. 80 | 81 | Args: 82 | path: Path to directory to delete. 
If None or doesn't exist, no-op. 83 | """ 84 | if path and path.exists() and path.is_dir(): 85 | shutil.rmtree(path, ignore_errors=True) 86 | 87 | 88 | def get_setup_command() -> str: 89 | """Generate setup command for installing dependencies.""" 90 | return """ 91 | # Install uv 92 | curl -LsSf https://astral.sh/uv/install.sh | sh 93 | UV_VENV_CLEAR=1 uv venv ~/venv && source ~/venv/bin/activate 94 | uv pip install fastmcp~=2.12.0 pyyaml psutil 95 | bash setup.sh 96 | """ 97 | 98 | 99 | def get_run_command(ports: List[str]) -> str: 100 | """Generate command to start multiple proxy servers on different ports.""" 101 | commands = [] 102 | for port in ports: 103 | cmd = f"source ~/venv/bin/activate && python proxy_server.py --port {port} --base-dir ../workspace &" 104 | commands.append(cmd) 105 | commands.append("wait") # Wait for all background processes 106 | return "\n".join(commands) 107 | -------------------------------------------------------------------------------- /src/benchmax/envs/mcp/example_workdir/reward_fn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reward functions for demo calculator MCP server. 3 | 4 | Three reward types: 5 | 1. Stateless completion check 6 | 2. Tool call variable in memory check 7 | 3. Workspace log check 8 | 9 | All reward functions will receive the same ground truth and completion. 10 | When defining the dataset, you can determine the shape of your ground truth. 11 | 12 | In this example, the ground truth is a dictionary of the shape: 13 | { 14 | 'completion': str, 15 | 'variable': { name: str, expected_value: float }, 16 | 'log': { filename: str, expected_content: str } 17 | } 18 | 19 | Each reward fn extracts what it needs from the ground_truth dictionary. 20 | Each of these rewards shows a way of computing reward using the completion, calling an MCP server tool, 21 | and reading from the workspace that the MCP server is operating in. 22 | 23 | """ 24 | 25 | from pathlib import Path 26 | from typing import Any, Callable, Dict, Union, Awaitable 27 | from mcp.types import TextContent 28 | from fastmcp import Client 29 | from fastmcp.exceptions import ToolError 30 | 31 | RewardFunction = Callable[..., Union[float, Awaitable[float]]] 32 | 33 | 34 | # ------------------------------- 35 | # Reward 0: Stateless completion check 36 | # ------------------------------- 37 | async def completion_match_reward( 38 | completion: str, 39 | ground_truth: dict, 40 | mcp_client: Client, 41 | workspace: Path, 42 | **kwargs: Any, 43 | ) -> float: 44 | """ 45 | Return 1.0 if completion matches ground_truth['completion'], else 0.0 46 | 47 | Uses: ground_truth['completion'] (str) 48 | """ 49 | expected = ground_truth.get("completion", "") 50 | return 1.0 if completion.strip() == expected.strip() else 0.0 51 | 52 | 53 | # ------------------------------- 54 | # Reward 1: Tool call variable in memory check 55 | # ------------------------------- 56 | async def variable_in_memory_reward( 57 | completion: str, ground_truth: dict, mcp_client: Client, workspace: Path, **kwargs 58 | ) -> float: 59 | """ 60 | Reward uses tool call to match in-memory variable value.
61 | 62 | Uses: ground_truth['variable'] = {"name": str, "expected_value": float} 63 | """ 64 | variable_spec = ground_truth.get("variable", {}) 65 | var_name = variable_spec.get("name") 66 | expected = variable_spec.get("expected_value") 67 | 68 | if not var_name or expected is None: 69 | return 0.0 70 | 71 | if not mcp_client or not mcp_client.is_connected(): 72 | return 0.0 73 | 74 | try: 75 | # Call tool via MCP client 76 | response = await mcp_client.call_tool("get_variable", {"name": var_name}) 77 | 78 | # Extract text from TextContent objects 79 | text_contents = [] 80 | for content in response.content: 81 | if isinstance(content, TextContent): 82 | text_contents.append(content.text) 83 | 84 | combined_text = "\n".join(text_contents).strip() 85 | value = float(combined_text) 86 | 87 | return 1.0 if value == expected else 0.0 88 | 89 | except ToolError: 90 | return 0.0 91 | except Exception: 92 | return 0.0 93 | 94 | 95 | # ------------------------------- 96 | # Reward 2: Workspace log check 97 | # ------------------------------- 98 | async def log_in_workspace_reward( 99 | completion: str, 100 | ground_truth: dict, 101 | mcp_client: Client, 102 | workspace: Path, 103 | **kwargs: Any, 104 | ) -> float: 105 | """ 106 | Reward based on workspace file content. 107 | 108 | Uses: ground_truth['log'] = {"filename": str, "expected_content": str} 109 | """ 110 | log_spec = ground_truth.get("log", {}) 111 | filename = log_spec.get("filename") 112 | expected = log_spec.get("expected_content", "") 113 | 114 | if not filename: 115 | return 0.0 116 | 117 | file_path = Path(workspace) / filename 118 | if not file_path.exists(): 119 | return 0.0 120 | 121 | content = file_path.read_text().strip() 122 | return 1.0 if content == expected.strip() else 0.0 123 | 124 | 125 | # ------------------------------- 126 | # Export reward functions 127 | # ------------------------------- 128 | reward_functions: Dict[str, RewardFunction] = { 129 | "completion": completion_match_reward, 130 | "variable": variable_in_memory_reward, 131 | "log": log_in_workspace_reward, 132 | } 133 | -------------------------------------------------------------------------------- /tests/integration/envs/mcp/provisioners/test_local_integration.py: -------------------------------------------------------------------------------- 1 | """ 2 | Integration tests for LocalProvisioner. 3 | """ 4 | 5 | import pytest 6 | import asyncio 7 | from pathlib import Path 8 | from benchmax.envs.mcp.provisioners.local_provisioner import LocalProvisioner 9 | from tests.integration.envs.mcp.provisioners.utils import wait_for_server_health, check_health 10 | 11 | 12 | class TestEndToEnd: 13 | """End-to-end integration tests for provisioning and teardown.""" 14 | 15 | @pytest.mark.asyncio 16 | @pytest.mark.slow 17 | async def test_single_server_lifecycle(self, example_workdir: Path): 18 | """Test complete lifecycle of a single server: provision, verify, and teardown.""" 19 | provisioner = LocalProvisioner( 20 | workdir_path=example_workdir, 21 | num_servers=1, 22 | base_port=9000, 23 | ) 24 | 25 | # Provision 26 | api_secret = "single-server-test-secret-32chars!!" 
27 | addresses = await provisioner.provision_servers(api_secret) 28 | assert len(addresses) == 1 29 | assert addresses[0] == "localhost:9000" 30 | 31 | # Verify server is up and healthy 32 | is_healthy = await wait_for_server_health(addresses[0]) 33 | assert is_healthy, "Server failed to become healthy" 34 | 35 | # Verify process is running 36 | assert len(provisioner._processes) == 1 37 | assert provisioner._processes[0].poll() is None 38 | 39 | # Check that double-provisioning would result in an error 40 | with pytest.raises(RuntimeError, match="already provisioned"): 41 | await provisioner.provision_servers(api_secret) 42 | 43 | # Teardown 44 | await provisioner.teardown() 45 | 46 | # Verify cleanup 47 | assert len(provisioner._processes) == 0 48 | await asyncio.sleep(0.5) 49 | assert not await check_health("localhost:9000"), "Server still responding after teardown" 50 | 51 | @pytest.mark.asyncio 52 | @pytest.mark.slow 53 | async def test_multiple_servers_lifecycle(self, example_workdir: Path): 54 | """Test complete lifecycle of multiple servers: provision, verify, and teardown.""" 55 | provisioner = LocalProvisioner( 56 | workdir_path=example_workdir, 57 | num_servers=5, 58 | base_port=9100, 59 | ) 60 | 61 | # Provision 62 | api_secret = "single-server-test-secret-32chars!!" 63 | addresses = await provisioner.provision_servers(api_secret) 64 | assert len(addresses) == 5 65 | expected_addresses = [f"localhost:{9100 + i}" for i in range(5)] 66 | assert addresses == expected_addresses 67 | 68 | # Verify all servers are up and healthy 69 | health_checks = await asyncio.gather( 70 | *[wait_for_server_health(addr) for addr in addresses] 71 | ) 72 | assert all(health_checks), "Not all servers became healthy" 73 | 74 | # Verify all processes are running 75 | assert len(provisioner._processes) == 5 76 | assert all(p.poll() is None for p in provisioner._processes) 77 | 78 | # Teardown 79 | await provisioner.teardown() 80 | 81 | # Verify cleanup 82 | assert len(provisioner._processes) == 0 83 | await asyncio.sleep(0.5) 84 | 85 | # Verify all servers are down 86 | for addr in addresses: 87 | assert not await check_health(addr), f"Server {addr} still responding after teardown" 88 | 89 | 90 | class TestValidation: 91 | """Test parameter validation.""" 92 | 93 | def test_invalid_num_servers(self): 94 | """Test validation of num_servers parameter.""" 95 | with pytest.raises(ValueError, match="at least 1"): 96 | LocalProvisioner(workdir_path=".", num_servers=0, base_port=8080) 97 | 98 | def test_invalid_base_port(self): 99 | """Test validation of base_port parameter.""" 100 | with pytest.raises(ValueError, match="between 1024 and 65535"): 101 | LocalProvisioner(workdir_path=".", num_servers=1, base_port=500) 102 | 103 | def test_port_range_exceeds_max(self): 104 | """Test validation of port range.""" 105 | with pytest.raises(ValueError, match="exceeds max port"): 106 | LocalProvisioner(workdir_path=".", num_servers=100, base_port=65500) 107 | -------------------------------------------------------------------------------- /examples/skyrl/README.md: -------------------------------------------------------------------------------- 1 | # Benchmax Environments with SkyRL 2 | 3 | This example directory contains example for both Math environment and Excel environment. The guide below is written for the Math environment but the high-level idea is broadly transferrable. 
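For instance, swapping the Math environment for the Excel one only changes the dataset and the dotted environment path in the data-preprocessing step described below. A hedged sketch of that swap (the dataset name is a placeholder, and the class path assumes the Excel environment exposes `ExcelEnvLocal` from `benchmax.envs.excel.excel_env`, as its README suggests):

```bash
uv run src/benchmax/adapters/skyrl/benchmax_data_process.py \
    --local_dir ~/data/excel \
    --dataset_name <your-spreadsheet-dataset> \
    --env_path benchmax.envs.excel.excel_env.ExcelEnvLocal
```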
4 | 5 | ### **🔍 Quickstart: RL Math Agent with SkyRL + `benchmax`** 6 | 7 | We can fine-tune an RL agent with SkyRL using a `benchmax` environment — for example, the `math` environment — and a tool-enabled setup for calculator-like reasoning. 8 | 9 | Example script: 10 | `examples/skyrl/run_benchmax_math.sh` 11 | 12 | --- 13 | 14 | ## 1. Prepare the dataset 15 | 16 | Use `benchmax_data_process.py` to convert a HuggingFace dataset into the multi-turn chat format expected by SkyRL rollouts. 17 | 18 | Example for `arithmetic50`: 19 | 20 | ```bash 21 | uv run src/benchmax/adapters/skyrl/benchmax_data_process.py \ 22 | --local_dir ~/data/math \ 23 | --dataset_name dawidmt/arithmetic50 \ 24 | --env_path benchmax.envs.math.math_env.MathEnvLocal 25 | ``` 26 | 27 | **Arguments**: 28 | 29 | | Flag | Required? | Description | 30 | | ---------------- | --------- | --------------------------------------------------------------------------------------------- | 31 | | `--local_dir` | yes | Output folder for `train.parquet` & `test.parquet`. | 32 | | `--dataset_name` | yes | HuggingFace dataset name or path. | 33 | | `--env_path` | yes | Dotted path to the `benchmax` environment class (e.g. `benchmax.envs.math.math_env.MathEnvLocal`). | 34 | 35 | --- 36 | 37 | ## 2. Launch training — **Focusing on Environment Arguments** 38 | 39 | In `run_benchmax_math.sh`, the environment is configured with: 40 | 41 | ```bash 42 | ENV_CLASS="MathEnv" 43 | ... 44 | environment.env_class=$ENV_CLASS \ 45 | ``` 46 | 47 | **How it works**: 48 | 49 | * **`environment.env_class`**: This must match the **ID** registered in `skyrl_gym.envs.register(...)`. 50 | 51 | * In `benchmax_math.py`, the registration happens here: 52 | 53 | ```python 54 | register( 55 | id="MathEnv", 56 | entry_point=load_benchmax_env_skyrl, 57 | kwargs={"actor": actor} 58 | ) 59 | ``` 60 | * The `id` value (`"MathEnv"`) is what `ENV_CLASS` should be set to in the shell script and in the preprocessing step. 61 | 62 | * **Ray Actor Creation**: Before registration, the script calls: 63 | 64 | ```python 65 | get_or_create_benchmax_env_actor(MathEnvLocal) 66 | ``` 67 | 68 | This: 69 | 70 | 1. Imports your environment class (`from benchmax.envs.math.math_env import MathEnvLocal`). 71 | 2. Starts a persistent Ray actor (`BenchmaxEnvActor`) that wraps the environment. 72 | 3. Names the actor `BaseEnvService`, so `load_benchmax_env_skyrl` can attach to it when Gym launches the env. 73 | 74 | * **Putting it together**: 75 | `ENV_CLASS` → matches Gym registry `id` → which maps to `load_benchmax_env_skyrl` → which connects to the Ray actor created from your imported environment class. 76 | 77 | To run: 78 | 79 | ```bash 80 | bash examples/skyrl/run_benchmax_math.sh 81 | ``` 82 | 83 | --- 84 | 85 | ## 3. Using Your Own Benchmax Environment with SkyRL 86 | 87 | To integrate a new `benchmax` environment: 88 | 89 | 1. **Create or select your `benchmax` environment class** 90 | 91 | * Must subclass `benchmax.BaseEnv` (or `benchmax.mcp.ParallelMcpEnv` for multi-node support) and implement required methods. 92 | * Add tools, rewards, and `system_prompt` as needed. 93 | 94 | 2. **Update the SkyRL entrypoint** 95 | 96 | * In your copy of `benchmax_math.py`, change: 97 | 98 | ```python 99 | from benchmax.envs.math.math_env import MathEnvLocal 100 | get_or_create_benchmax_env_actor(MathEnvLocal) 101 | ``` 102 | 103 | to import and use your environment class. Change the environment to `MathEnvSkypilot` to run the MCP servers on multiple external nodes. 104 | 105 | 3.
**Update the registry ID & shell script** 106 | 107 | * Change the `id` in `register(...)` to match your environment name. 108 | * Update `ENV_CLASS` in your run script to match this ID. 109 | 110 | 4. **Update dataset preprocessing** 111 | 112 | * Use `--env_path` pointing to your environment class in the preprocessing command. 113 | 114 | --- 115 | 116 | Once set up, SkyRL will: 117 | 118 | * Launch your environment in a Ray actor 119 | * Register it in the SkyRL gym 120 | * Fine-tune your model via PPO with configurable multi-turn RL training -------------------------------------------------------------------------------- /src/benchmax/envs/crm/crm_env.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Dict, List, Optional 3 | import sky 4 | 5 | from benchmax.envs.mcp.parallel_mcp_env import ParallelMcpEnv 6 | from benchmax.envs.mcp.provisioners.base_provisioner import BaseProvisioner 7 | from benchmax.envs.mcp.provisioners.local_provisioner import LocalProvisioner 8 | from benchmax.envs.mcp.provisioners.skypilot_provisioner import SkypilotProvisioner 9 | from benchmax.envs.types import StandardizedExample 10 | 11 | 12 | SYSTEM_PROMPT = """\ 13 | You are an expert in Salesforce and you have access to a Salesforce instance. 14 | 15 | # Instructions 16 | - You will be provided a question, the system description, and relevant task context. 17 | - Interact with the Salesforce instance using the tools provided to help you answer the question. 18 | - You should ALWAYS make ONLY ONE tool call at a time. If you want to submit your final answer, just respond with the answer without tool call. If not, you should call some other tool. 19 | - Always end by respond with ONLY the answer, NO full sentence or any explanation. 20 | - If your answer is empty that is there are no records found matching the requirements mentioned, just return 'None' as the response. 21 | - You should be able to get to an answer within 2-3 tool calls, so don't overthink. 22 | 23 | Write your complete answer on the final line, within the xml tags . If there are multiple answers, use comma as a delimiter. 24 | e.g. 25 | For Case IDs, final answer should look like 0XA124XDF. 
If there are multiple, it could look like 0XA124XDF, 001XX000003GXXX 26 | For Months, it could look like May,July 27 | If nothing matches, output None 28 | """ 29 | 30 | 31 | class CRMExample(StandardizedExample): 32 | reward_metric: str 33 | 34 | 35 | class CRMEnv(ParallelMcpEnv): 36 | """Environment for CRM tasks using MCP with Salesforce""" 37 | 38 | system_prompt: str = SYSTEM_PROMPT 39 | 40 | def __init__(self, workdir_path: Path, provisioner: BaseProvisioner, **kwargs): 41 | """Initialize CRMEnv.""" 42 | super().__init__(workdir_path, provisioner, **kwargs) 43 | 44 | @classmethod 45 | def dataset_preprocess(cls, example: Any, **kwargs) -> CRMExample: 46 | # convert dataset example into CRMExample (inherit from StandardizedExample) 47 | task: Optional[str] = example.get("task") 48 | persona: Optional[str] = example.get("persona") 49 | metadata: Optional[Dict[str, Any]] = example.get("metadata") 50 | answer: Optional[List[str]] = example.get("answer") 51 | query: Optional[str] = example.get("query") 52 | reward_metric: Optional[str] = example.get("reward_metric") 53 | 54 | if not task or not persona or not query or answer is None or not reward_metric: 55 | raise ValueError( 56 | "Example must contain 'task', 'persona', 'query', 'answer', and 'reward_metric' fields" 57 | ) 58 | 59 | prompt = f"{persona}\n{task}\n{query}" 60 | if metadata and "required" in metadata: 61 | required_metadata = metadata["required"] 62 | prompt = f"{persona}\n{task}\n{required_metadata}\n{query}" 63 | 64 | return CRMExample( 65 | prompt=prompt, 66 | ground_truth=answer, 67 | init_rollout_args=None, 68 | reward_metric=reward_metric, 69 | ) 70 | 71 | 72 | class CRMEnvLocal(CRMEnv): 73 | """Import this env to run environment locally""" 74 | 75 | def __init__(self, num_local_servers: int = 5, **kwargs): 76 | workdir_path = Path(__file__).parent / "workdir" 77 | provisioner = LocalProvisioner( 78 | workdir_path=workdir_path, num_servers=num_local_servers 79 | ) 80 | super().__init__(workdir_path=workdir_path, provisioner=provisioner, **kwargs) 81 | 82 | 83 | class CRMEnvSkypilot(CRMEnv): 84 | """Import this env to run environment on any cloud (i.e. AWS / GCP / Azure) with Skypilot""" 85 | 86 | def __init__( 87 | self, 88 | cloud: sky.clouds.Cloud = sky.Azure(), 89 | num_nodes: int = 2, 90 | servers_per_node: int = 5, 91 | **kwargs, 92 | ): 93 | workdir_path = Path(__file__).parent / "workdir" 94 | provisioner = SkypilotProvisioner( 95 | workdir_path=workdir_path, 96 | cloud=cloud, 97 | num_nodes=num_nodes, 98 | servers_per_node=servers_per_node, 99 | ) 100 | super().__init__(workdir_path=workdir_path, provisioner=provisioner, **kwargs) 101 | -------------------------------------------------------------------------------- /src/benchmax/envs/mcp/example_workdir/demo_mcp_server.py: -------------------------------------------------------------------------------- 1 | """ 2 | Demo MCP server for calculator-style workflow. 
3 | 4 | Tools: 5 | - hello_world: stateless sanity check 6 | - define_variable / get_variable: in-memory calculator variables 7 | - evaluate: arithmetic using stored variables 8 | - append_log / read_log: workspace file I/O 9 | - allocate_memory: stress test 10 | """ 11 | 12 | from pathlib import Path 13 | from typing import Dict 14 | from fastmcp import FastMCP 15 | from fastmcp.exceptions import ToolError 16 | 17 | # ---------------------------------------------------------------------- 18 | # MCP server setup 19 | # ---------------------------------------------------------------------- 20 | 21 | mcp = FastMCP("demo-calculator-server") 22 | 23 | # In-memory state 24 | _variables: Dict[str, float] = {} 25 | 26 | # Memory stress 27 | _leaked_memory = [] 28 | 29 | # ---------------------------------------------------------------------- 30 | # Tools 31 | # ---------------------------------------------------------------------- 32 | 33 | 34 | @mcp.tool() 35 | async def hello_world(name: str) -> str: 36 | """Simple stateless greeting.""" 37 | return f"Hello, {name}!" 38 | 39 | 40 | def is_valid_var_name(name: str) -> bool: 41 | """Check if a variable name is valid: start with letter/_ and contain only letters, digits, or _""" 42 | if not name: 43 | return False 44 | if not (name[0].isalpha() or name[0] == "_"): 45 | return False 46 | return all(c.isalnum() or c == "_" for c in name) 47 | 48 | 49 | @mcp.tool() 50 | async def define_variable(name: str, value: float) -> str: 51 | """Store a named variable in memory.""" 52 | if not is_valid_var_name(name): 53 | raise ToolError(f"Invalid variable name: '{name}'") 54 | _variables[name] = value 55 | return f"Variable '{name}' set to {value}" 56 | 57 | 58 | @mcp.tool() 59 | async def get_variable(name: str) -> str: 60 | """Retrieve a variable from memory.""" 61 | if not is_valid_var_name(name): 62 | raise ToolError(f"Invalid variable name: '{name}'") 63 | if name not in _variables: 64 | raise ToolError(f"Variable '{name}' not defined") 65 | return str(_variables[name]) 66 | 67 | 68 | @mcp.tool() 69 | async def evaluate(expression: str) -> str: 70 | """ 71 | Evaluate arithmetic expression using stored variables. 72 | 73 | Only numbers and defined variable names are allowed. 
74 | """ 75 | allowed_names = {k: v for k, v in _variables.items()} 76 | allowed_chars = "0123456789+-*/()., " 77 | 78 | # Split expression manually into potential identifiers and other tokens 79 | token = "" 80 | for c in expression + " ": # add space to flush last token 81 | if c.isalnum() or c == "_": 82 | token += c 83 | else: 84 | if token: 85 | if token[0].isalpha() or token[0] == "_": 86 | if token not in allowed_names: 87 | raise ToolError(f"Undefined variable: '{token}'") 88 | if not is_valid_var_name(token): 89 | raise ToolError(f"Invalid variable name: '{token}'") 90 | # numeric token is implicitly allowed 91 | token = "" 92 | if c not in allowed_chars and not c.isspace(): 93 | raise ToolError(f"Invalid character in expression: '{c}'") 94 | 95 | try: 96 | result = eval(expression, {"__builtins__": {}}, allowed_names) 97 | return str(result) 98 | except Exception as e: 99 | raise ToolError(f"Evaluation error: {str(e)}") 100 | 101 | 102 | @mcp.tool() 103 | async def append_log(filename: str, message: str) -> str: 104 | """Append a message to a workspace file.""" 105 | file_path = Path(filename) 106 | with open(file_path, "a") as f: 107 | f.write(message + "\n") 108 | return f"Appended message to {filename}" 109 | 110 | 111 | @mcp.tool() 112 | async def read_log(filename: str) -> str: 113 | """Read the content of a workspace file.""" 114 | file_path = Path(filename) 115 | if not file_path.exists(): 116 | raise ToolError(f"File '{filename}' not found") 117 | return file_path.read_text() 118 | 119 | 120 | @mcp.tool() 121 | async def allocate_memory(megabytes: int) -> str: 122 | """Allocate memory to simulate stress / OOM.""" 123 | global _leaked_memory 124 | size = megabytes * 1024 * 1024 125 | _leaked_memory.append(bytearray(size)) 126 | return f"Leaked {megabytes} MB (total allocations: {len(_leaked_memory)})" 127 | 128 | 129 | if __name__ == "__main__": 130 | # Run the server 131 | mcp.run() 132 | -------------------------------------------------------------------------------- /tests/unit/envs/mcp/provisioners/test_local.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pathlib import Path 3 | from unittest.mock import Mock, patch, MagicMock, AsyncMock 4 | from benchmax.envs.mcp.provisioners import LocalProvisioner 5 | 6 | 7 | # --------------------------------------------------------------------------- 8 | # Initialization 9 | # --------------------------------------------------------------------------- 10 | 11 | 12 | class TestLocalProvisionerInit: 13 | """Tests for LocalProvisioner initialization.""" 14 | 15 | def test_init_requires_workdir_path(self, example_workdir: Path) -> None: 16 | """Provisioner initializes with given workdir and defaults.""" 17 | LocalProvisioner(workdir_path=example_workdir) 18 | 19 | 20 | # --------------------------------------------------------------------------- 21 | # Provisioning 22 | # --------------------------------------------------------------------------- 23 | 24 | 25 | class TestLocalProvisionerProvision: 26 | """Tests for LocalProvisioner.provision_servers.""" 27 | 28 | @pytest.mark.asyncio 29 | @patch("benchmax.envs.mcp.provisioners.local_provisioner.setup_sync_dir") 30 | @patch.object(LocalProvisioner, "_spawn_process", new_callable=AsyncMock) 31 | async def test_provision_runs_expected_commands( 32 | self, 33 | mock_spawn: AsyncMock, 34 | mock_setup: Mock, 35 | example_workdir: Path, 36 | test_sync_dir: Path, 37 | ) -> None: 38 | """Provisioner should prepare sync dir and 
start subprocesses.""" 39 | mock_setup.return_value = test_sync_dir 40 | mock_proc = MagicMock() 41 | mock_proc.pid = 1234 42 | mock_spawn.return_value = mock_proc 43 | 44 | provisioner = LocalProvisioner(workdir_path=example_workdir, num_servers=2) 45 | addresses = await provisioner.provision_servers("dummy-api-secret") 46 | 47 | assert len(addresses) == 2 48 | mock_setup.assert_called_once_with(example_workdir) 49 | # First call is setup_cmd (wait=True), next calls are servers 50 | assert mock_spawn.call_count == 3 51 | assert addresses == ["localhost:8080", "localhost:8081"] 52 | 53 | @pytest.mark.asyncio 54 | @patch("benchmax.envs.mcp.provisioners.local_provisioner.setup_sync_dir") 55 | async def test_provision_handles_setup_failure( 56 | self, mock_setup: Mock, example_workdir: Path 57 | ) -> None: 58 | """setup_sync_dir errors are propagated.""" 59 | mock_setup.side_effect = OSError("setup failed") 60 | provisioner = LocalProvisioner(workdir_path=example_workdir) 61 | 62 | with pytest.raises(OSError): 63 | await provisioner.provision_servers("dummy-api-secret") 64 | 65 | 66 | # --------------------------------------------------------------------------- 67 | # Teardown 68 | # --------------------------------------------------------------------------- 69 | 70 | 71 | class TestLocalProvisionerTeardown: 72 | """Tests for LocalProvisioner.teardown.""" 73 | 74 | @pytest.mark.asyncio 75 | @patch("benchmax.envs.mcp.provisioners.local_provisioner.cleanup_dir") 76 | async def test_teardown_kills_processes_and_cleans_up( 77 | self, mock_cleanup: Mock, example_workdir: Path, test_sync_dir: Path 78 | ) -> None: 79 | """Active processes are killed and sync dir cleaned up.""" 80 | provisioner = LocalProvisioner(workdir_path=example_workdir) 81 | 82 | proc = MagicMock() 83 | proc.poll.return_value = None 84 | proc.kill = MagicMock() 85 | proc.wait = MagicMock() 86 | provisioner._processes = [proc] 87 | provisioner._sync_dir = test_sync_dir 88 | provisioner._is_provisioned = True 89 | 90 | await provisioner.teardown() 91 | proc.kill.assert_called_once() 92 | proc.wait.assert_called_once() 93 | mock_cleanup.assert_called_once_with(test_sync_dir) 94 | 95 | @pytest.mark.asyncio 96 | @patch("benchmax.envs.mcp.provisioners.local_provisioner.cleanup_dir") 97 | async def test_teardown_skips_already_terminated( 98 | self, mock_cleanup: Mock, example_workdir: Path, test_sync_dir: Path 99 | ) -> None: 100 | """Teardown skips processes that are already terminated.""" 101 | provisioner = LocalProvisioner(workdir_path=example_workdir) 102 | proc = MagicMock() 103 | proc.poll.return_value = 0 # Already exited 104 | proc.kill = MagicMock() 105 | provisioner._processes = [proc] 106 | provisioner._sync_dir = test_sync_dir 107 | provisioner._is_provisioned = True 108 | 109 | await provisioner.teardown() 110 | proc.kill.assert_not_called() 111 | mock_cleanup.assert_called_once_with(test_sync_dir) 112 | -------------------------------------------------------------------------------- /tests/unit/envs/mcp/provisioners/test_provisioners_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for MCP utility functions. 
3 | """ 4 | 5 | import pytest 6 | import tempfile 7 | from pathlib import Path 8 | 9 | from benchmax.envs.mcp.provisioners.utils import ( 10 | setup_sync_dir, 11 | cleanup_dir, 12 | ) 13 | 14 | 15 | class TestSetupSyncDir: 16 | """Tests for setup_sync_dir function.""" 17 | 18 | def test_setup_sync_dir_creates_temp_directory(self, example_workdir: Path): 19 | """Test that sync directory is created in temp location.""" 20 | sync_dir = setup_sync_dir(example_workdir) 21 | 22 | try: 23 | assert sync_dir.exists() 24 | assert sync_dir.is_dir() 25 | assert "benchmax_skypilot_" in sync_dir.name 26 | finally: 27 | cleanup_dir(sync_dir) 28 | 29 | def test_setup_sync_dir_copies_proxy_server(self, example_workdir: Path): 30 | """Test that proxy_server.py is copied to sync directory.""" 31 | sync_dir = setup_sync_dir(example_workdir) 32 | 33 | try: 34 | proxy_server = sync_dir / "proxy_server.py" 35 | assert proxy_server.exists() 36 | assert proxy_server.is_file() 37 | finally: 38 | cleanup_dir(sync_dir) 39 | 40 | def test_setup_sync_dir_copies_workdir_contents(self, example_workdir: Path): 41 | """Test that all workdir contents are copied.""" 42 | sync_dir = setup_sync_dir(example_workdir) 43 | 44 | try: 45 | # Check required files 46 | assert (sync_dir / "mcp_config.yaml").exists() 47 | assert (sync_dir / "reward_fn.py").exists() 48 | assert (sync_dir / "setup.sh").exists() 49 | assert (sync_dir / "proxy_server.py").exists() 50 | finally: 51 | cleanup_dir(sync_dir) 52 | 53 | def test_setup_sync_dir_validates_required_files(self, tmp_path: Path): 54 | """Test that missing required files raise FileNotFoundError.""" 55 | # Create incomplete workdir 56 | incomplete_workdir = tmp_path / "incomplete" 57 | incomplete_workdir.mkdir() 58 | (incomplete_workdir / "mcp_config.yaml").touch() 59 | # Missing reward_fn.py and setup.sh 60 | 61 | with pytest.raises(FileNotFoundError, match="reward_fn.py"): 62 | setup_sync_dir(incomplete_workdir) 63 | 64 | def test_setup_sync_dir_handles_nonexistent_workdir(self, tmp_path: Path): 65 | """Test that nonexistent workdir raises FileNotFoundError.""" 66 | nonexistent = tmp_path / "does_not_exist" 67 | 68 | with pytest.raises(FileNotFoundError, match="workdir_path"): 69 | setup_sync_dir(nonexistent) 70 | 71 | def test_setup_sync_dir_handles_file_as_workdir(self, tmp_path: Path): 72 | """Test that file instead of directory raises ValueError.""" 73 | file_path = tmp_path / "not_a_dir.txt" 74 | file_path.touch() 75 | 76 | with pytest.raises(ValueError, match="directory"): 77 | setup_sync_dir(file_path) 78 | 79 | def test_setup_sync_dir_cleanup_on_error(self, tmp_path: Path): 80 | """Test that sync directory is cleaned up if setup fails.""" 81 | # Create workdir with required files but missing reward_fn.py 82 | incomplete_workdir = tmp_path / "workdir" 83 | incomplete_workdir.mkdir() 84 | (incomplete_workdir / "mcp_config.yaml").touch() 85 | (incomplete_workdir / "setup.sh").touch() 86 | 87 | # This should fail because reward_fn.py doesn't exist in mcp/ 88 | # But the temp directory should be cleaned up 89 | with pytest.raises(FileNotFoundError): 90 | setup_sync_dir(incomplete_workdir) 91 | 92 | 93 | class TestCleanupDir: 94 | """Tests for cleanup_dir function.""" 95 | 96 | def test_cleanup_dir_removes_directory(self): 97 | """Test that cleanup_dir removes an existing directory.""" 98 | temp_dir = Path(tempfile.mkdtemp(prefix="test_cleanup_")) 99 | assert temp_dir.exists() 100 | 101 | cleanup_dir(temp_dir) 102 | assert not temp_dir.exists() 103 | 104 | def 
test_cleanup_dir_handles_none(self): 105 | """Test that cleanup_dir handles None gracefully.""" 106 | cleanup_dir(None) # Should not raise 107 | 108 | def test_cleanup_dir_handles_nonexistent(self, tmp_path: Path): 109 | """Test that cleanup_dir handles nonexistent path gracefully.""" 110 | nonexistent = tmp_path / "does_not_exist" 111 | cleanup_dir(nonexistent) # Should not raise 112 | 113 | def test_cleanup_dir_handles_file(self, tmp_path: Path): 114 | """Test that cleanup_dir ignores files (only removes directories).""" 115 | file_path = tmp_path / "file.txt" 116 | file_path.touch() 117 | 118 | cleanup_dir(file_path) # Should not raise 119 | assert file_path.exists() # File should still exist 120 | -------------------------------------------------------------------------------- /src/benchmax/envs/crm/README.md: -------------------------------------------------------------------------------- 1 | # CRM Environment 2 | 3 | This environment provides capabilities for interacting with Salesforce instances through the Salesforce API. 4 | 5 | This is based off [CRMArena Pro](https://github.com/SalesforceAIResearch/CRMArena) & supports both B2B and B2C configurations. 6 | 7 | ## Prerequisites 8 | 9 | **Important**: Before using this environment, ensure you have: 10 | - Python 3.12 or higher 11 | 12 | ## Installation 13 | 14 | ```bash 15 | pip install "benchmax[crm,skypilot]" 16 | ``` 17 | 18 | Includes: 19 | - simple-salesforce: For Salesforce API interactions 20 | - python-dateutil: For date/time handling 21 | - fastmcp: For MCP server functionality 22 | 23 | ## Usage 24 | 25 | Use `CRMEnvLocal` to run the servers locally on the same machine as benchmax or use `CRMEnvSkypilot` to parallelize the servers across multiple nodes. 26 | 27 | ## Configuration 28 | 29 | The environment has built-in Salesforce configurations: 30 | 31 | ### B2B Configuration 32 | Default configuration using B2B Salesforce instance. Used automatically when initializing CRMEnv. 33 | 34 | ### B2C Configuration 35 | Alternative configuration for B2C use cases, can be enabled by passing "b2c" to get_mcp_config(). 
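To make the Usage and Configuration sections above concrete, here is a minimal, illustrative sketch (not taken from the repository) of how a CRMArena-style record is normalized by `dataset_preprocess` and how the local variant is constructed. The record field values below are invented placeholders; only the keys mirror what `crm_env.py` expects, and whether construction immediately starts the local MCP servers depends on the provisioning configuration.

```python
from benchmax.envs.crm.crm_env import CRMEnvLocal

# dataset_preprocess is a classmethod, so no servers are needed for this step.
# The field values are placeholders; only the keys matter to the preprocessor.
raw_example = {
    "task": "Identify the agent who closed the most cases last quarter.",
    "persona": "You are a support operations analyst.",
    "query": "Which agent ID closed the most cases?",
    "answer": ["005XX000001A1AA"],
    "reward_metric": "exact_match",
}

standardized = CRMEnvLocal.dataset_preprocess(raw_example)
# standardized holds the prompt (persona + task + query), the ground-truth
# answer list, and the reward_metric used later by the reward function.
print(standardized)

# Constructing the local variant wires up a LocalProvisioner over the CRM
# workdir; depending on configuration this also starts the local MCP servers.
env = CRMEnvLocal(num_local_servers=2)
```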
36 | 37 | ## Available Tools 38 | 39 | The environment provides a comprehensive set of MCP tools for Salesforce interactions: 40 | 41 | ### Case Management 42 | - `get_cases`: Retrieve cases based on various filtering criteria (dates, agents, statuses) 43 | - `get_non_transferred_case_ids`: Get cases not transferred between agents in a period 44 | - `get_agent_handled_cases_by_period`: Get number of cases handled by each agent 45 | - `get_agent_transferred_cases_by_period`: Get number of cases transferred between agents 46 | - `get_livechat_transcript_by_case_id`: Retrieve live chat transcripts for a case 47 | 48 | ### Agent Analysis 49 | - `get_qualified_agent_ids_by_case_count`: Filter agents based on case handling count 50 | - `get_agents_with_max_cases`: Find agents with most cases in a subset 51 | - `get_agents_with_min_cases`: Find agents with fewest cases in a subset 52 | - `calculate_average_handle_time`: Calculate average case handling time per agent 53 | 54 | ### Regional Analysis 55 | - `get_shipping_state`: Add shipping state information to cases 56 | - `calculate_region_average_closure_times`: Calculate average case closure times by region 57 | 58 | ### Time Period Management 59 | - `get_start_date`: Calculate start date based on period and interval 60 | - `get_period`: Get date range for named periods (months, quarters, seasons) 61 | - `get_month_to_case_count`: Count cases created in each month 62 | 63 | ### Product and Issue Management 64 | - `search_products`: Search for products by name/description 65 | - `get_purchase_history`: Get purchase history for account/products 66 | - `get_issues`: Retrieve list of issue records 67 | - `get_issue_counts`: Get issue counts for products in a time period 68 | - `get_order_item_ids_by_product`: Get order items for a product 69 | 70 | ### Knowledge Base 71 | - `search_knowledge_articles`: Search knowledge articles by term 72 | 73 | ### Account Management 74 | - `get_account_id_by_contact_id`: Get Account ID for a Contact 75 | 76 | ### Utility Functions 77 | - `find_id_with_max_value`: Find IDs with maximum value in a dataset 78 | - `find_id_with_min_value`: Find IDs with minimum value in a dataset 79 | - `issue_soql_query`: Execute custom SOQL queries 80 | - `issue_sosl_query`: Execute custom SOSL queries 81 | 82 | ## Features 83 | 84 | - Seamless interaction with Salesforce instances 85 | - Support for both B2B and B2C configurations 86 | - Robust answer parsing and evaluation through fuzzy matching 87 | - Standardized example preprocessing for dataset handling 88 | 89 | ## Reward Functions 90 | 91 | The environment provides two reward metrics for evaluating model completions: 92 | 93 | ### Exact Match 94 | - Uses IoU (Intersection over Union) score 95 | - Compares completion tokens with ground truth tokens 96 | - Perfect score (1.0) for exact matches 97 | - Partial score based on token overlap 98 | - Returns 0.0 if one set is empty while other isn't 99 | 100 | ### Fuzzy Match 101 | - Uses F1 score for more lenient evaluation 102 | - Handles variations in text formatting and word order 103 | - Normalizes text by removing punctuation and articles 104 | - Suitable for cases where exact matching is too strict 105 | 106 | ## MCP Server 107 | 108 | The environment runs a Salesforce MCP server that handles API interactions. 
The server: 109 | - Manages authentication with Salesforce 110 | - Provides secure access to Salesforce operations 111 | - Handles API rate limiting and session management 112 | 113 | 114 | # CRMArena-Pro data 115 | crmarena_pro_queries = `load_dataset`("Salesforce/CRMArenaPro", "CRMArenaPro") 116 | b2b_schema = load_dataset("Salesforce/CRMArenaPro", "b2b_schema") 117 | b2c_schema = load_dataset("Salesforce/CRMArenaPro", "b2c_schema") 118 | 119 | TODO: Link to HuggingFace dataset -------------------------------------------------------------------------------- /src/benchmax/envs/base_env.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, List, Any, Optional, Tuple 3 | from pathlib import Path 4 | from datasets import ( 5 | DatasetDict, 6 | Dataset, 7 | IterableDatasetDict, 8 | IterableDataset, 9 | load_dataset, 10 | ) 11 | 12 | from benchmax.envs.types import ToolDefinition, StandardizedExample 13 | from benchmax.prompts.tools import render_tools_prompt 14 | 15 | 16 | class BaseEnv(ABC): 17 | """Base benchmax environment for tool execution and reward computation""" 18 | 19 | system_prompt: str = "" 20 | 21 | @abstractmethod 22 | async def shutdown(self): 23 | pass 24 | 25 | # Override this method if your example does not match the default structure 26 | @classmethod 27 | def dataset_preprocess(cls, example: Any, **kwargs) -> StandardizedExample: 28 | """ 29 | Preprocess a single dataset example into a dict with keys: 30 | - "prompt": str 31 | - "ground_truth": Any 32 | - "init_rollout_args": Dict[str, Any] 33 | """ 34 | prompt = example.pop("prompt", "") 35 | ground_truth = example.pop("ground_truth", "") 36 | init_rollout_args = example.pop("init_rollout_args") 37 | return StandardizedExample( 38 | prompt=prompt, 39 | ground_truth=ground_truth, 40 | init_rollout_args=init_rollout_args, 41 | **example, 42 | ) 43 | 44 | @classmethod 45 | def load_dataset( 46 | cls, dataset_name: str, **kwargs 47 | ) -> Tuple[ 48 | DatasetDict | Dataset | IterableDatasetDict | IterableDataset, str | None 49 | ]: 50 | """ 51 | Download and prepare a dataset for use with this environment. 52 | 53 | This method should handle retrieving the specified dataset (e.g., from HuggingFace, local files, 54 | or a custom source), preprocessing or converting it into a compatible structure, and storing it 55 | locally in a reusable format. The processed dataset should be suitable for downstream use with 56 | `dataset_preprocess`, which standardizes individual examples into the expected format. 57 | 58 | Args: 59 | dataset_name (str): Identifier of the dataset to be loaded. 60 | **kwargs: Additional dataset-specific arguments (e.g., split, filtering options, cache directory). 61 | 62 | Returns: 63 | Dataset: A dataset object (e.g., HuggingFace Dataset or similar) ready for processing. 
64 | str: Optional string pointing to where the dataset is stored locally 65 | """ 66 | return load_dataset(dataset_name, **kwargs), None 67 | 68 | @abstractmethod 69 | async def list_tools(self) -> List[ToolDefinition]: 70 | """Return list of available tools""" 71 | pass 72 | 73 | @abstractmethod 74 | async def run_tool(self, rollout_id: str, tool_name: str, **tool_args) -> Any: 75 | """Execute named tool in rollout context with given arguments""" 76 | pass 77 | 78 | @abstractmethod 79 | async def init_rollout(self, rollout_id: str, **rollout_args) -> None: 80 | """Initialize resources for a new rollout""" 81 | pass 82 | 83 | @abstractmethod 84 | async def release_rollout(self, rollout_id: str) -> None: 85 | """Free up resources for a new rollout. Called by compute_reward internally but also available for cleanup.""" 86 | pass 87 | 88 | @abstractmethod 89 | async def copy_to_workspace( 90 | self, rollout_id: str, src_path: Path, dst_filename: Optional[str] = None 91 | ) -> None: 92 | """Copy a file to the workspace for a specific rollout. If dst_filename is None, use the original filename.""" 93 | pass 94 | 95 | @abstractmethod 96 | async def copy_content_to_workspace( 97 | self, rollout_id: str, src_content: str | bytes, dst_filename: str 98 | ) -> None: 99 | """Create a file with given content in the workspace for a specific rollout""" 100 | pass 101 | 102 | @abstractmethod 103 | async def copy_from_workspace( 104 | self, rollout_id: str, src_filename: str, dst_path: Path 105 | ) -> None: 106 | """Copy a file from the workspace for a specific rollout""" 107 | pass 108 | 109 | @abstractmethod 110 | async def compute_reward( 111 | self, rollout_id: str, completion: str, ground_truth: Any, **kwargs: Any 112 | ) -> Dict[str, float]: 113 | """Compute rewards using registered functions 114 | 115 | Returns dict mapping reward function names to their computed scores. 116 | """ 117 | pass 118 | 119 | async def get_system_prompt(self, add_tool_defs: bool = False) -> str: 120 | """Get system prompt. To add tool definitions, set add_tool_defs to True.""" 121 | if add_tool_defs: 122 | return render_tools_prompt( 123 | await self.list_tools(), self.system_prompt or "" 124 | ) 125 | else: 126 | return self.system_prompt 127 | -------------------------------------------------------------------------------- /src/benchmax/envs/crm/workdir/reward_fn.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | from html import unescape 4 | from pathlib import Path 5 | from collections import Counter 6 | from typing import Any, Callable, Dict, List, Optional, Union, Awaitable 7 | from fastmcp import Client 8 | 9 | RewardFunction = Callable[..., Union[float, Awaitable[float]]] 10 | 11 | 12 | def parse_answers(proposed_answer: str) -> str: 13 | """ 14 | Parse the proposed answer. 15 | """ 16 | m = re.search( 17 | r"(.*?)", proposed_answer, flags=re.IGNORECASE | re.DOTALL 18 | ) 19 | if not m: 20 | proposed_answer = "" 21 | else: 22 | # Unescape any XML entities (& → &, etc.) and normalise whitespace. 23 | proposed_answer = unescape(m.group(1)).strip().lower() 24 | return proposed_answer 25 | 26 | 27 | def parse_text_to_tokens(text: str) -> set: 28 | """ 29 | Parse text into normalized tokens using common separators. 
30 | 31 | Args: 32 | text: Input text to parse 33 | 34 | Returns: 35 | set: Set of normalized tokens 36 | """ 37 | if not text: 38 | return set() 39 | 40 | # Clean up the text by removing quotes and extra whitespace 41 | cleaned_text = text.strip().strip('"').strip("'").lower() 42 | 43 | # Split by common separators: spaces, commas, semicolons, pipes, tabs, newlines 44 | # Using regex to split on multiple separators 45 | tokens = re.split(r"[,\s|]+", cleaned_text) 46 | 47 | # Filter out empty tokens and normalize 48 | normalized_tokens = {token.strip() for token in tokens if token.strip()} 49 | 50 | return normalized_tokens 51 | 52 | 53 | def get_all_metrics(proposed_answer: str, ground_truth: str) -> float: 54 | """ 55 | Compute fuzzy matching score between proposed answer and ground truth. 56 | Uses F1 score as the primary metric. 57 | """ 58 | 59 | def normalize_answer(s): 60 | """Lower text and remove punctuation, articles and extra whitespace.""" 61 | 62 | def remove_articles(text): 63 | return re.sub(r"\b(a|an|the)\b", " ", text) 64 | 65 | def white_space_fix(text): 66 | return " ".join(text.split()) 67 | 68 | def handle_punc(text): 69 | exclude = set(string.punctuation + "".join(["'", "'", "´", "`"])) 70 | return "".join(ch if ch not in exclude else " " for ch in text) 71 | 72 | def lower(text): 73 | return text.lower() 74 | 75 | def replace_underscore(text): 76 | return text.replace("_", " ") 77 | 78 | return white_space_fix( 79 | remove_articles(handle_punc(lower(replace_underscore(s)))) 80 | ).strip() 81 | 82 | def f1_score(prediction, ground_truth): 83 | prediction_tokens = normalize_answer(prediction).split() 84 | ground_truth_tokens = normalize_answer(ground_truth).split() 85 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 86 | num_same = sum(common.values()) 87 | if num_same == 0: 88 | return 0 89 | precision = 1.0 * num_same / len(prediction_tokens) 90 | recall = 1.0 * num_same / len(ground_truth_tokens) 91 | f1 = (2 * precision * recall) / (precision + recall) 92 | return f1 93 | 94 | return f1_score(proposed_answer, ground_truth) 95 | 96 | 97 | def crm_matching_reward_function( 98 | completion: str, 99 | ground_truth: List[str], 100 | mcp_client: Client, 101 | workspace: Path, 102 | **kwargs: Any, 103 | ) -> float: 104 | """ 105 | Reward function for CRM environment that evaluates model completions. 
106 | 107 | Args: 108 | prompt: Input prompt given to the model 109 | completion: Model's generated completion/response 110 | ground_truth: Expected/correct output (should be a list) 111 | workspace: Path to rollout's workspace 112 | **kwargs: Additional context 113 | 114 | Returns: 115 | float: Reward score between 0 and 1 116 | """ 117 | reward_metric: Optional[str] = kwargs.get("reward_metric") 118 | 119 | if not reward_metric: 120 | raise ValueError("kwargs must contain reward metric") 121 | 122 | proposed_answer = completion.strip() if completion else "" 123 | proposed_answer = parse_answers(proposed_answer) 124 | 125 | if reward_metric == "exact_match": 126 | # Parse and normalize the completion text 127 | completion_tokens = parse_text_to_tokens(proposed_answer) 128 | 129 | # Parse and normalize all ground truth items 130 | all_ground_truth_tokens = set() 131 | for gt_item in ground_truth: 132 | gt_tokens = parse_text_to_tokens(str(gt_item)) 133 | all_ground_truth_tokens.update(gt_tokens) 134 | 135 | # Calculate IoU (Intersection over Union) 136 | if not all_ground_truth_tokens and not completion_tokens: 137 | return 1.0 # Both empty sets match perfectly 138 | elif not all_ground_truth_tokens or not completion_tokens: 139 | return 0.0 # One empty, one non-empty 140 | 141 | intersection = completion_tokens.intersection(all_ground_truth_tokens) 142 | union = completion_tokens.union(all_ground_truth_tokens) 143 | 144 | iou_score = len(intersection) / len(union) if union else 0.0 145 | 146 | # Return 1.0 if perfect match (IoU = 1.0), otherwise return IoU score 147 | return iou_score 148 | 149 | elif reward_metric == "fuzzy_match": 150 | # For fuzzy match, we only have 1 ground truth item 151 | if ground_truth[0] is not None: 152 | return get_all_metrics(proposed_answer, str(ground_truth[0])) 153 | else: 154 | return 0.0 155 | 156 | else: 157 | print(f"Unknown reward metric: {reward_metric}") 158 | return 0.0 159 | 160 | 161 | # ------------------------------- 162 | # Export reward functions 163 | # ------------------------------- 164 | reward_functions: Dict[str, RewardFunction] = { 165 | "match": crm_matching_reward_function, 166 | } 167 | -------------------------------------------------------------------------------- /tests/integration/envs/mcp/conftest.py: -------------------------------------------------------------------------------- 1 | """ 2 | Integration fixtures for MCP environment tests. 3 | These may start subprocesses, use real ports, or access the filesystem. 4 | """ 5 | 6 | import uuid 7 | import pytest 8 | from pathlib import Path 9 | from typing import Tuple, List 10 | from collections.abc import AsyncGenerator 11 | import sky 12 | 13 | from benchmax.envs.mcp import ParallelMcpEnv 14 | from benchmax.envs.mcp.provisioners import ( 15 | LocalProvisioner, 16 | ManualProvisioner, 17 | SkypilotProvisioner, 18 | ) 19 | 20 | # ===== Session-scoped: Provision servers once for all tests ===== 21 | 22 | 23 | @pytest.fixture(scope="session") 24 | async def local_servers_with_secret( 25 | example_workdir: Path, 26 | ) -> AsyncGenerator[Tuple[List[str], str], None]: 27 | """ 28 | Provision local servers once for the entire test session. 29 | Returns (addresses, api_secret) tuple. 30 | 31 | These servers are reused across tests for speed. 
32 | """ 33 | api_secret = uuid.uuid4().hex 34 | provisioner = LocalProvisioner( 35 | workdir_path=example_workdir, 36 | num_servers=4, 37 | base_port=8080, 38 | ) 39 | 40 | addresses = await provisioner.provision_servers(api_secret) 41 | print(f"\n[Session Setup] Provisioned {len(addresses)} local servers") 42 | 43 | yield addresses, api_secret 44 | 45 | print(f"\n[Session Teardown] Tearing down {len(addresses)} local servers") 46 | await provisioner.teardown() 47 | 48 | 49 | @pytest.fixture(scope="session") 50 | async def skypilot_servers_with_secret( 51 | example_workdir: Path, 52 | ) -> AsyncGenerator[Tuple[List[str], str], None]: 53 | """ 54 | Provision Skypilot servers once for the entire test session. 55 | Returns (addresses, api_secret) tuple. 56 | 57 | These servers are reused across tests for speed. 58 | Only provisioned if test actually uses this fixture. 59 | """ 60 | 61 | api_secret = uuid.uuid4().hex 62 | provisioner = SkypilotProvisioner( 63 | workdir_path=example_workdir, 64 | cloud=sky.Azure(), 65 | num_nodes=2, 66 | servers_per_node=4, 67 | base_cluster_name="test-cluster", 68 | cpus=2, 69 | memory=8 70 | ) 71 | 72 | addresses = await provisioner.provision_servers(api_secret) 73 | print(f"\n[Session Setup] Provisioned {len(addresses)} Skypilot servers") 74 | 75 | yield addresses, api_secret 76 | 77 | print(f"\n[Session Teardown] Tearing down {len(addresses)} Skypilot servers") 78 | await provisioner.teardown() 79 | 80 | 81 | # ===== Function-scoped: Fresh env for each test ===== 82 | 83 | 84 | @pytest.fixture 85 | async def local_env( 86 | local_servers_with_secret: Tuple[List[str], str], example_workdir: Path 87 | ) -> AsyncGenerator[ParallelMcpEnv, None]: 88 | """ 89 | Create a fresh ParallelMcpEnv using reused local servers. 90 | 91 | Each test gets a clean env instance, but servers are shared 92 | across tests for speed. This means server state may be 93 | contaminated, but that's acceptable for most tests. 94 | """ 95 | addresses, api_secret = local_servers_with_secret 96 | 97 | manual_provisioner = ManualProvisioner(addresses) 98 | env = ParallelMcpEnv( 99 | workdir_path=example_workdir, 100 | provisioner=manual_provisioner, 101 | api_secret=api_secret, 102 | provision_at_init=True, 103 | ) 104 | 105 | yield env 106 | 107 | await env.shutdown() 108 | 109 | 110 | @pytest.fixture 111 | async def skypilot_env( 112 | skypilot_servers_with_secret: Tuple[List[str], str], example_workdir: Path 113 | ) -> AsyncGenerator[ParallelMcpEnv, None]: 114 | """ 115 | Create a fresh ParallelMcpEnv using reused Skypilot servers. 116 | 117 | Each test gets a clean env instance, but servers are shared 118 | across tests for speed. Mark tests using this with @pytest.mark.remote. 119 | """ 120 | addresses, api_secret = skypilot_servers_with_secret 121 | 122 | manual_provisioner = ManualProvisioner(addresses) 123 | env = ParallelMcpEnv( 124 | workdir_path=example_workdir, 125 | provisioner=manual_provisioner, 126 | api_secret=api_secret, 127 | provision_at_init=True, 128 | ) 129 | 130 | yield env 131 | 132 | await env.shutdown() 133 | 134 | 135 | @pytest.fixture 136 | async def fresh_local_env(example_workdir: Path) -> AsyncGenerator[ParallelMcpEnv, None]: 137 | """ 138 | Create a fresh ParallelMcpEnv with its own dedicated servers. 139 | 140 | Use this for tests that need clean server state. 141 | Slower than local_env but provides isolation. 142 | Mark tests using this with @pytest.mark.slow. 
143 | """ 144 | provisioner = LocalProvisioner( 145 | workdir_path=example_workdir, 146 | num_servers=4, 147 | base_port=9080, # Different port range to avoid conflicts 148 | ) 149 | 150 | env = ParallelMcpEnv( 151 | workdir_path=example_workdir, 152 | provisioner=provisioner, 153 | provision_at_init=True, 154 | ) 155 | 156 | yield env 157 | 158 | await env.shutdown() 159 | 160 | 161 | @pytest.fixture 162 | async def fresh_skypilot_env( 163 | example_workdir: Path, 164 | ) -> AsyncGenerator[ParallelMcpEnv, None]: 165 | """ 166 | Create a fresh ParallelMcpEnv with its own dedicated Skypilot servers. 167 | 168 | Use this for E2E tests that need clean server state on cloud infrastructure. 169 | Slower and more expensive than skypilot_env. 170 | Mark tests using this with @pytest.mark.remote and @pytest.mark.slow. 171 | """ 172 | provisioner = SkypilotProvisioner( 173 | workdir_path=example_workdir, 174 | cloud=sky.Azure(), 175 | num_nodes=2, 176 | servers_per_node=4, 177 | base_cluster_name="test-cluster-fresh", 178 | cpus=2, 179 | memory=8, 180 | ) 181 | 182 | env = ParallelMcpEnv( 183 | workdir_path=example_workdir, 184 | provisioner=provisioner, 185 | provision_at_init=True, 186 | ) 187 | 188 | yield env 189 | 190 | await env.shutdown() 191 | 192 | -------------------------------------------------------------------------------- /tests/integration/adapters/skyrl/skyrl_adapter_integration.py: -------------------------------------------------------------------------------- 1 | # import pytest 2 | # import ray 3 | 4 | # from benchmax.adapters.skyrl.skyrl_adapter import load_benchmax_env_skyrl, RemoteBaseEnvProxy, expose 5 | 6 | # # ---- Fixtures and dummies ---- 7 | # class DummyActor: 8 | # """Dummy Ray actor to simulate remote Benchmax service.""" 9 | # def __init__(self): 10 | # self.calls = [] 11 | 12 | # class Method: 13 | # def __init__(self, name, parent): 14 | # self.name = name 15 | # self.parent = parent 16 | # def remote(self, *args, **kwargs): 17 | # self.parent.calls.append((self.name, args, kwargs)) 18 | # # return a sentinel value 19 | # if self.name == 'compute_reward': 20 | # return {'a': 1.0, 'b': 2.0} 21 | # return f"{self.name}-result" 22 | 23 | # @property 24 | # def list_tools(self): 25 | # return DummyActor.Method('list_tools', self) 26 | 27 | # @property 28 | # def compute_reward(self): 29 | # return DummyActor.Method('compute_reward', self) 30 | 31 | # @property 32 | # def run_tool(self): 33 | # return DummyActor.Method('run_tool', self) 34 | 35 | # @property 36 | # def init_rollout(self): 37 | # return DummyActor.Method('init_rollout', self) 38 | 39 | # @property 40 | # def cleanup_rollout(self): 41 | # return DummyActor.Method('cleanup_rollout', self) 42 | 43 | # @pytest.fixture(autouse=True) 44 | # def patch_ray(monkeypatch): 45 | # dummy = DummyActor() 46 | # monkeypatch.setattr(ray, 'get_actor', lambda name: dummy) 47 | # monkeypatch.setattr(ray, 'get', lambda x: x) 48 | # return dummy 49 | 50 | # # ---- Tests for RemoteBaseEnvProxy ---- 51 | # def test_list_tools_proxy(patch_ray): 52 | # proxy = RemoteBaseEnvProxy(actor_name='BenchmaxEnvService', rollout_id='rid') 53 | # result = proxy.list_tools() 54 | # assert result == 'list_tools-result' 55 | # # Ensure no rollout_id injected for list_tools 56 | # assert patch_ray.calls == [('list_tools', (), {})] 57 | 58 | 59 | # def test_compute_reward_proxy(patch_ray): 60 | # proxy = RemoteBaseEnvProxy(actor_name='BenchmaxEnvService', rollout_id='RID123') 61 | # reward = proxy.compute_reward('task1', 'action1', {'gt': 
True}) 62 | # assert reward == {'a': 1.0, 'b': 2.0} 63 | # # rollout_id should be first arg 64 | # name, args, kwargs = patch_ray.calls[-1] 65 | # assert name == 'compute_reward' 66 | # assert args[0] == 'RID123' 67 | 68 | # # ---- Tests for _call_tool ---- 69 | # class DummyEnv: 70 | # def __init__(self): 71 | # self.benchmax_env = RemoteBaseEnvProxy(actor_name='BenchmaxEnvService') 72 | # self.extras = {} 73 | # _call_tool = load_benchmax_env_skyrl.__wrapped__.__defaults__[0]._call_tool if False else None 74 | 75 | # @pytest.fixture 76 | # def configured_env(patch_ray): 77 | # # Build a minimal SkyRL env 78 | # cfg = {} 79 | # extras = {'init_rollout_args': {}, 'task': 'T1', 'ground_truth': {}} 80 | # env = load_benchmax_env_skyrl(actor_name='BenchmaxEnvService', env_config=cfg, extras=extras) 81 | # return env 82 | 83 | 84 | # def test_call_tool_errors(configured_env): 85 | # # Not dict 86 | # assert configured_env._call_tool('not_a_dict') == "Error: Tool command must be a JSON object." 87 | # # Missing name 88 | # assert configured_env._call_tool({'arguments': {}}) == "Error: Missing 'name' field in tool command." 89 | # # Arguments not dict 90 | # assert "must be a JSON object" in configured_env._call_tool({'name': 'foo', 'arguments': 'bad'}) 91 | 92 | 93 | # def test_call_tool_success_and_truncate(configured_env): 94 | # # Valid call 95 | # out = configured_env._call_tool({'name': 'run_tool', 'arguments': {'x': 1}}) 96 | # assert out.startswith('run_tool-result') 97 | # # Truncate 98 | # long_res = configured_env._call_tool({'name': 'run_tool', 'arguments': {}}, max_chars=5) 99 | # assert long_res.endswith('...') 100 | 101 | # # ---- Tests for step() ---- 102 | 103 | # @ pytest.fixture(autouse=True) 104 | # def patch_parse(monkeypatch): 105 | # # default parse returns no tools 106 | # monkeypatch.setattr('benchmax.prompts.tools.parse_hermes_tool_call', lambda x: []) 107 | 108 | 109 | # def test_step_final_reward(configured_env, patch_parse): 110 | # # parse returns [], so done=True 111 | # out = configured_env.step('final answer') 112 | # assert out["done"] is True 113 | # assert out["reward"] == 3.0 114 | # assert out["observations"] == [] 115 | 116 | 117 | # def test_step_tool_flow(monkeypatch, configured_env): 118 | # # parse returns one tool call 119 | # monkeypatch.setattr('benchmax.prompts.tools.parse_hermes_tool_call', lambda x: [{'name': 'run_tool', 'arguments': {}}]) 120 | # # Stub _call_tool 121 | # monkeypatch.setattr(configured_env, '_call_tool', lambda call: 'obs-text') 122 | # out = configured_env.step('{"name": "calculate", "arguments": {"expression": "25 ÷ 5 + 4 × 3"}}') 123 | # assert out["done"] is False 124 | # assert out["reward"] == 0.0 125 | # assert out["observations"] == [{'role': 'user', 'content': 'obs-text'}] 126 | 127 | 128 | # def test_step_tool_error(monkeypatch, configured_env): 129 | # monkeypatch.setattr('benchmax.prompts.tools.parse_hermes_tool_call', lambda x: [{'name': 'run_tool', 'arguments': {}}]) 130 | # def bad_call(call): 131 | # raise RuntimeError('fail') 132 | # monkeypatch.setattr(configured_env, '_call_tool', bad_call) 133 | # out = configured_env.step('{"name": "calculate", "arguments": {"expression": "25 ÷ 5 + 4 × 3"}}') 134 | # print(out) 135 | # assert out["done"] is False 136 | # assert out["observations"][0]['content'] == 'fail' 137 | 138 | # # ---- Tests for expose decorator ---- 139 | # import asyncio 140 | # class FakeEnv: 141 | # async def foo(self, x): 142 | # return x * 2 143 | 144 | # @pytest.mark.asyncio 145 | # async def 
test_expose_wrapper(): 146 | # wrapper = expose('foo') 147 | # class Host: 148 | # def __init__(self): 149 | # self.env = FakeEnv() 150 | # host = Host() 151 | # res = await wrapper(host, 10) 152 | # assert res == 20 -------------------------------------------------------------------------------- /tests/unit/envs/mcp/test_mcp_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for MCP utility functions. 3 | """ 4 | 5 | import jwt 6 | from mcp import Tool 7 | from benchmax.envs.mcp.utils import ( 8 | convert_tool_definitions, 9 | generate_jwt_token, 10 | get_auth_headers, 11 | ) 12 | from benchmax.envs.types import ToolDefinition 13 | 14 | 15 | class TestConvertToolDefinitions: 16 | """Tests for convert_tool_definitions function.""" 17 | 18 | def test_convert_tool_definitions_basic(self): 19 | """Test basic conversion of MCP tools to ToolDefinitions.""" 20 | mcp_tools = [ 21 | Tool( 22 | name="read_file", 23 | description="Read a file", 24 | inputSchema={ 25 | "type": "object", 26 | "properties": {"path": {"type": "string"}}, 27 | }, 28 | ), 29 | Tool( 30 | name="write_file", 31 | description="Write to a file", 32 | inputSchema={ 33 | "type": "object", 34 | "properties": {"path": {"type": "string"}}, 35 | }, 36 | ), 37 | ] 38 | 39 | result = convert_tool_definitions(mcp_tools, allowed_tools=None) 40 | 41 | assert len(result) == 2 42 | assert all(isinstance(t, ToolDefinition) for t in result) 43 | assert result[0].name == "read_file" 44 | assert result[1].name == "write_file" 45 | 46 | def test_convert_tool_definitions_with_filter(self): 47 | """Test filtering tools with allowed_tools list.""" 48 | mcp_tools = [ 49 | Tool(name="read_file", description="Read", inputSchema={}), 50 | Tool(name="write_file", description="Write", inputSchema={}), 51 | Tool(name="execute", description="Execute", inputSchema={}), 52 | ] 53 | 54 | allowed = ["read_file", "write_file"] 55 | result = convert_tool_definitions(mcp_tools, allowed_tools=allowed) 56 | 57 | assert len(result) == 2 58 | assert all(t.name in allowed for t in result) 59 | assert not any(t.name == "execute" for t in result) 60 | 61 | def test_convert_tool_definitions_empty_description(self): 62 | """Test handling of tools with no description.""" 63 | mcp_tools = [Tool(name="tool1", description=None, inputSchema={})] 64 | 65 | result = convert_tool_definitions(mcp_tools, allowed_tools=None) 66 | 67 | assert len(result) == 1 68 | assert result[0].description == "" 69 | 70 | def test_convert_tool_definitions_empty_list(self): 71 | """Test conversion of empty tool list.""" 72 | result = convert_tool_definitions([], allowed_tools=None) 73 | assert result == [] 74 | 75 | 76 | class TestGenerateJwtToken: 77 | """Tests for generate_jwt_token function.""" 78 | 79 | def test_generate_jwt_token_contains_standard_claims(self): 80 | """Ensure generated JWT includes required standard claims.""" 81 | secret = "secret" 82 | token = generate_jwt_token(secret) 83 | decoded = jwt.decode( 84 | token, secret, algorithms=["HS256"], audience="mcp-proxy-server" 85 | ) 86 | 87 | assert decoded["iss"] == "mcp-client" 88 | assert decoded["aud"] == "mcp-proxy-server" 89 | assert "iat" in decoded 90 | assert "exp" in decoded 91 | assert decoded["exp"] > decoded["iat"] 92 | 93 | def test_generate_jwt_token_includes_rollout_id(self): 94 | """Ensure rollout_id is included if provided.""" 95 | secret = "secret" 96 | token = generate_jwt_token(secret, rollout_id="rollout-123") 97 | decoded = jwt.decode( 98 | token, 
secret, algorithms=["HS256"], audience="mcp-proxy-server" 99 | ) 100 | 101 | assert decoded["rollout_id"] == "rollout-123" 102 | 103 | def test_generate_jwt_token_includes_extra_claims(self): 104 | """Ensure extra custom claims are included.""" 105 | secret = "secret" 106 | token = generate_jwt_token(secret, user="test_user", env="staging") 107 | decoded = jwt.decode( 108 | token, secret, algorithms=["HS256"], audience="mcp-proxy-server" 109 | ) 110 | 111 | assert decoded["user"] == "test_user" 112 | assert decoded["env"] == "staging" 113 | 114 | def test_generate_jwt_token_expiration_respects_custom_value(self): 115 | """Ensure custom expiration_seconds is respected.""" 116 | secret = "secret" 117 | token = generate_jwt_token(secret, expiration_seconds=60) 118 | decoded = jwt.decode( 119 | token, secret, algorithms=["HS256"], audience="mcp-proxy-server" 120 | ) 121 | 122 | assert ( 123 | abs(decoded["exp"] - decoded["iat"] - 60) <= 1 124 | ) # small clock drift margin 125 | 126 | 127 | class TestGetAuthHeaders: 128 | """Tests for get_auth_headers function.""" 129 | 130 | def test_get_auth_headers_contains_bearer_prefix(self): 131 | """Ensure Authorization header has proper Bearer format.""" 132 | headers = get_auth_headers("secret", rollout_id="r1") 133 | assert "Authorization" in headers 134 | assert headers["Authorization"].startswith("Bearer ") 135 | 136 | def test_get_auth_headers_token_decodes_correctly(self): 137 | """Ensure token inside header is valid and decodable.""" 138 | secret = "secret" 139 | headers = get_auth_headers(secret, rollout_id="rollout-xyz") 140 | token = headers["Authorization"].split("Bearer ")[1] 141 | 142 | decoded = jwt.decode( 143 | token, secret, algorithms=["HS256"], audience="mcp-proxy-server" 144 | ) 145 | assert decoded["rollout_id"] == "rollout-xyz" 146 | 147 | def test_get_auth_headers_includes_extra_claims(self): 148 | """Ensure extra claims propagate correctly through get_auth_headers.""" 149 | secret = "secret" 150 | headers = get_auth_headers(secret, env="prod", user="alice") 151 | token = headers["Authorization"].split("Bearer ")[1] 152 | 153 | decoded = jwt.decode( 154 | token, secret, algorithms=["HS256"], audience="mcp-proxy-server" 155 | ) 156 | assert decoded["env"] == "prod" 157 | assert decoded["user"] == "alice" 158 | -------------------------------------------------------------------------------- /tests/integration/envs/excel/test_excel_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import pytest 3 | from pathlib import Path 4 | from openpyxl import load_workbook 5 | from openpyxl.styles import PatternFill, Font 6 | from benchmax.envs.excel.workdir.excel_utils import compare_excel_cells, evaluate_excel 7 | 8 | @pytest.fixture(scope="session") 9 | def setup_files(): 10 | base_excel_path = Path(__file__).parent / "test_inputs" / "test.xlsx" 11 | ground_truth_path = Path(__file__).parent / "test_inputs" / "test_gt.xlsx" 12 | output_path = Path(__file__).parent / "test_inputs" / "test_output.xlsx" 13 | 14 | # Create ground truth file 15 | wb_gt = load_workbook(base_excel_path) 16 | ws_gt = wb_gt.active 17 | assert ws_gt is not None, "Worksheet should not be None" 18 | ws_gt["H3"] = "Text Value" 19 | ws_gt["H4"] = 4 20 | ws_gt["H5"] = 2 # Numeric value 21 | ws_gt["H6"] = "Mismatch Value" 22 | ws_gt["I3"] = "Matching Value" 23 | ws_gt["I4"] = "Matching Value" 24 | ws_gt["I3"].fill = PatternFill(start_color="FFFF00", end_color="FFFF00", fill_type="solid") # Yellow 25 | 
ws_gt["I4"].font = Font(color="FF0000") # Red font 26 | wb_gt.save(ground_truth_path) 27 | wb_gt.close() 28 | 29 | # Create output file 30 | wb_output = load_workbook(base_excel_path) 31 | ws_output = wb_output.active 32 | assert ws_output is not None, "Worksheet should not be None" 33 | ws_output["H3"] = "Text Value" 34 | ws_output["H4"] = 4 35 | ws_output["H5"] = "=G5" # Formula 36 | ws_output["H6"] = "Different Mismatch Value" 37 | ws_output["I3"] = "Matching Value" 38 | ws_output["I4"] = "Matching Value" 39 | ws_output["I3"].fill = PatternFill(start_color="00FF00", end_color="00FF00", fill_type="solid") # Green 40 | ws_output["I4"].font = Font(color="0000FF") # Blue font 41 | wb_output.save(output_path) 42 | wb_output.close() 43 | 44 | evaluate_excel(str(ground_truth_path)) 45 | evaluate_excel(str(output_path)) 46 | 47 | yield ground_truth_path, output_path 48 | 49 | # Cleanup 50 | if ground_truth_path.exists(): 51 | ground_truth_path.unlink() 52 | if output_path.exists(): 53 | output_path.unlink() 54 | 55 | # Test for mismatched values in a single cell 56 | @pytest.mark.excel 57 | def test_single_cell_comparison_mismatch(setup_files: Tuple[Path, Path]): 58 | ground_truth_path, test_output_path = setup_files 59 | result, message = compare_excel_cells( 60 | str(ground_truth_path), 61 | str(test_output_path), 62 | "H6" 63 | ) 64 | assert not result 65 | assert "Value mismatch" in message 66 | 67 | # Test for matching values in a single cell 68 | @pytest.mark.excel 69 | def test_single_cell_comparison_match(setup_files: Tuple[Path, Path]): 70 | ground_truth_path, test_output_path = setup_files 71 | result, message = compare_excel_cells( 72 | str(ground_truth_path), 73 | str(test_output_path), 74 | "H3" 75 | ) 76 | assert result 77 | assert "passed" in message 78 | 79 | # Test for mismatched values in a range of cells 80 | @pytest.mark.excel 81 | def test_range_comparison_mismatch(setup_files: Tuple[Path, Path]): 82 | ground_truth_path, test_output_path = setup_files 83 | result, message = compare_excel_cells( 84 | str(ground_truth_path), 85 | str(test_output_path), 86 | "D4:H6" 87 | ) 88 | assert not result 89 | assert "Value mismatch" in message 90 | 91 | # Test for matching values in a range of cells 92 | @pytest.mark.excel 93 | def test_range_comparison_match(setup_files: Tuple[Path, Path]): 94 | ground_truth_path, test_output_path = setup_files 95 | result, message = compare_excel_cells( 96 | str(ground_truth_path), 97 | str(test_output_path), 98 | "H5:H5" 99 | ) 100 | assert result 101 | assert "passed" in message 102 | 103 | # Test for handling sheet names with mismatched values 104 | @pytest.mark.excel 105 | def test_sheet_name_handling_mismatch(setup_files: Tuple[Path, Path]): 106 | ground_truth_path, test_output_path = setup_files 107 | result, message = compare_excel_cells( 108 | str(ground_truth_path), 109 | str(test_output_path), 110 | "'Sheet1'!H6" 111 | ) 112 | assert not result 113 | assert "Value mismatch" in message 114 | 115 | # Test for handling sheet names with matching values 116 | @pytest.mark.excel 117 | def test_sheet_name_handling_match(setup_files: Tuple[Path, Path]): 118 | ground_truth_path, test_output_path = setup_files 119 | result, message = compare_excel_cells( 120 | str(ground_truth_path), 121 | str(test_output_path), 122 | "'Sheet1'!H3" 123 | ) 124 | assert result 125 | assert "passed" in message 126 | 127 | # Test for comparing cell formatting with mismatched styles 128 | @pytest.mark.excel 129 | def test_fill_color_comparison_mismatch(setup_files: 
Tuple[Path, Path]): 130 | ground_truth_path, test_output_path = setup_files 131 | result, message = compare_excel_cells( 132 | str(ground_truth_path), 133 | str(test_output_path), 134 | "I3", 135 | is_CF=True 136 | ) 137 | assert not result 138 | assert "Fill color mismatch" in message 139 | 140 | 141 | # Test for comparing cell formatting with mismatched styles 142 | @pytest.mark.excel 143 | def test_font_color_comparison_mismatch(setup_files: Tuple[Path, Path]): 144 | ground_truth_path, test_output_path = setup_files 145 | result, message = compare_excel_cells( 146 | str(ground_truth_path), 147 | str(test_output_path), 148 | "I4", 149 | is_CF=True 150 | ) 151 | assert not result 152 | assert "Font color mismatch" in message 153 | 154 | # Test for comparing cell formatting with matching styles 155 | @pytest.mark.excel 156 | def test_formatting_comparison_match(setup_files: Tuple[Path, Path]): 157 | ground_truth_path, test_output_path = setup_files 158 | result, message = compare_excel_cells( 159 | str(ground_truth_path), 160 | str(test_output_path), 161 | "H3", 162 | is_CF=True 163 | ) 164 | assert result 165 | assert "passed" in message 166 | 167 | # Test for handling missing sheets with mismatched values 168 | @pytest.mark.excel 169 | def test_missing_sheet_mismatch(setup_files: Tuple[Path, Path]): 170 | ground_truth_path, test_output_path = setup_files 171 | result, message = compare_excel_cells( 172 | str(ground_truth_path), 173 | str(test_output_path), 174 | "'NonExistentSheet'!A1" 175 | ) 176 | assert not result 177 | assert "Worksheet 'NonExistentSheet' not found" in message -------------------------------------------------------------------------------- /tests/integration/envs/mcp/provisioners/test_skypilot_integration.py: -------------------------------------------------------------------------------- 1 | """ 2 | Integration tests for SkypilotProvisioner. 3 | """ 4 | 5 | import pytest 6 | import asyncio 7 | import sky 8 | from pathlib import Path 9 | from benchmax.envs.mcp.provisioners.skypilot_provisioner import SkypilotProvisioner 10 | from tests.integration.envs.mcp.provisioners.utils import wait_for_server_health 11 | 12 | 13 | class TestEndToEnd: 14 | """End-to-end integration tests for provisioning and teardown.""" 15 | 16 | @pytest.mark.asyncio 17 | @pytest.mark.slow 18 | @pytest.mark.remote 19 | async def test_single_node_lifecycle(self, example_workdir: Path): 20 | """Test complete lifecycle of a single-node cluster with all validations.""" 21 | base_name = "test-single-node-cluster" 22 | api_secret = "single-node-test-secret-32chars!!" 
23 | 24 | provisioner = SkypilotProvisioner( 25 | workdir_path=example_workdir, 26 | cloud=sky.Azure(), 27 | num_nodes=1, 28 | servers_per_node=2, 29 | cpus="4+", 30 | memory="16+", 31 | base_cluster_name=base_name, 32 | ) 33 | 34 | try: 35 | # Verify configuration before provisioning 36 | assert provisioner.cluster_name.startswith(f"{base_name}-") 37 | assert provisioner._workdir_path.is_absolute() 38 | 39 | # Provision 40 | addresses = await provisioner.provision_servers(api_secret) 41 | assert len(addresses) == 2 # 1 node * 2 servers 42 | 43 | # Verify all addresses have correct format 44 | for addr in addresses: 45 | assert ":" in addr 46 | host, port = addr.split(":") 47 | assert port.isdigit() 48 | assert 8080 <= int(port) < 8090 49 | 50 | # Verify servers are up and healthy 51 | health_checks = await asyncio.gather( 52 | *[wait_for_server_health(addr, timeout=90.0) for addr in addresses] 53 | ) 54 | assert all(health_checks), "Not all servers became healthy" 55 | 56 | # Check that double-provisioning would result in an error 57 | with pytest.raises(RuntimeError, match="already provisioned"): 58 | await provisioner.provision_servers(api_secret) 59 | finally: 60 | # Teardown - always attempt cleanup 61 | await provisioner.teardown() 62 | 63 | @pytest.mark.asyncio 64 | @pytest.mark.slow 65 | @pytest.mark.remote 66 | async def test_multi_node_lifecycle(self, example_workdir: Path): 67 | """Test complete lifecycle of a multi-node cluster with all validations.""" 68 | provisioner1 = SkypilotProvisioner( 69 | workdir_path=example_workdir, 70 | cloud=sky.Azure(), 71 | num_nodes=2, 72 | servers_per_node=3, 73 | cpus=2, 74 | memory=8, 75 | ) 76 | 77 | provisioner2 = SkypilotProvisioner( 78 | workdir_path=example_workdir, 79 | cloud=sky.Azure(), 80 | num_nodes=1, 81 | servers_per_node=1, 82 | ) 83 | 84 | api_secret = "multi-node-test-secret-32chars!!" 
85 | 86 | try: 87 | # Verify cluster names are unique 88 | assert provisioner1.cluster_name != provisioner2.cluster_name 89 | assert provisioner1.cluster_name.startswith("benchmax-env-cluster-") 90 | assert provisioner2.cluster_name.startswith("benchmax-env-cluster-") 91 | 92 | # Provision main test cluster 93 | addresses = await provisioner1.provision_servers(api_secret) 94 | assert len(addresses) == 6 # 2 nodes * 3 servers 95 | 96 | # Verify addresses are grouped by node 97 | hosts = [addr.split(":")[0] for addr in addresses] 98 | unique_hosts = set(hosts) 99 | assert len(unique_hosts) == 2, "Should have 2 unique node IPs" 100 | 101 | for host in unique_hosts: 102 | host_servers = [addr for addr in addresses if addr.startswith(host)] 103 | assert len(host_servers) == 3, f"Host {host} should have 3 servers" 104 | 105 | # Verify all servers are up and healthy 106 | health_checks = await asyncio.gather( 107 | *[wait_for_server_health(addr, timeout=90.0) for addr in addresses] 108 | ) 109 | assert all(health_checks), "Not all servers became healthy" 110 | finally: 111 | # Teardown - always attempt cleanup 112 | await provisioner1.teardown() 113 | 114 | class TestValidation: 115 | """Test parameter validation without provisioning.""" 116 | 117 | def test_invalid_num_nodes(self): 118 | """Test validation of num_nodes parameter.""" 119 | with pytest.raises(ValueError, match="at least 1"): 120 | SkypilotProvisioner( 121 | workdir_path=".", 122 | cloud=sky.Azure(), 123 | num_nodes=0, 124 | servers_per_node=5, 125 | ) 126 | 127 | def test_invalid_servers_per_node_low(self): 128 | """Test validation of servers_per_node parameter (too low).""" 129 | with pytest.raises(ValueError, match="between 1 and 100"): 130 | SkypilotProvisioner( 131 | workdir_path=".", 132 | cloud=sky.Azure(), 133 | num_nodes=1, 134 | servers_per_node=0, 135 | ) 136 | 137 | def test_invalid_servers_per_node_high(self): 138 | """Test validation of servers_per_node parameter (too high).""" 139 | with pytest.raises(ValueError, match="between 1 and 100"): 140 | SkypilotProvisioner( 141 | workdir_path=".", 142 | cloud=sky.Azure(), 143 | num_nodes=1, 144 | servers_per_node=101, 145 | ) 146 | 147 | def test_custom_base_cluster_name(self): 148 | """Test that custom base cluster name is used.""" 149 | provisioner = SkypilotProvisioner( 150 | workdir_path=".", 151 | cloud=sky.Azure(), 152 | num_nodes=1, 153 | servers_per_node=1, 154 | base_cluster_name="my-custom-cluster", 155 | ) 156 | 157 | assert provisioner.cluster_name.startswith("my-custom-cluster-") 158 | 159 | def test_workdir_path_conversion(self): 160 | """Test that workdir_path is properly converted to absolute Path.""" 161 | provisioner = SkypilotProvisioner( 162 | workdir_path="./relative/path", 163 | cloud=sky.Azure(), 164 | num_nodes=1, 165 | servers_per_node=1, 166 | ) 167 | 168 | # Should be converted to absolute path 169 | assert provisioner._workdir_path.is_absolute() 170 | -------------------------------------------------------------------------------- /src/benchmax/adapters/skyrl/benchmax_data_process.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess a huggingface/benchmax dataset to a multiturn format suitable for a benchmax environment. 
3 | """ 4 | 5 | import argparse 6 | import logging 7 | from importlib import import_module 8 | from pathlib import Path 9 | from types import ModuleType 10 | from typing import Type 11 | from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict 12 | import datasets 13 | import asyncio 14 | import inspect 15 | 16 | from benchmax.envs.base_env import BaseEnv 17 | 18 | # Set logging level to WARNING and above 19 | logging.basicConfig(level=logging.WARNING) 20 | 21 | 22 | def load_class(dotted_path: str) -> Type[BaseEnv]: 23 | """ 24 | Load and return the class specified by `dotted_path`. 25 | Example: "benchmax.envs.wikipedia.wiki_env.WikipediaEnv" 26 | """ 27 | try: 28 | module_path, class_name = dotted_path.rsplit(".", 1) 29 | except ValueError as exc: 30 | raise ImportError( 31 | f'"{dotted_path}" doesn\'t look like "package.module.Class"' 32 | ) from exc 33 | 34 | module: ModuleType = import_module(module_path) 35 | try: 36 | cls: Type[BaseEnv] = getattr(module, class_name) 37 | except AttributeError as exc: 38 | raise ImportError( 39 | f'Module "{module_path}" has no attribute "{class_name}"' 40 | ) from exc 41 | 42 | return cls 43 | 44 | 45 | def get_canonical_class_name(cls: Type[BaseEnv]) -> str: 46 | """ 47 | Get the canonical class name, removing local/skypilot prefix/suffix if the parent class 48 | has the same name without that prefix/suffix. 49 | """ 50 | class_name = cls.__name__ 51 | 52 | # Check for prefixes/suffixes to strip 53 | prefixes = ["local", "skypilot"] 54 | suffixes = ["local", "skypilot"] 55 | 56 | # Try to find a matching parent class without the prefix/suffix 57 | for base_cls in cls.__bases__: 58 | base_name = base_cls.__name__ 59 | 60 | # Check if current class has prefix that base doesn't 61 | for prefix in prefixes: 62 | if class_name.lower().startswith( 63 | prefix 64 | ) and not base_name.lower().startswith(prefix): 65 | # Check if removing prefix gives us the base name 66 | stripped = class_name[len(prefix) :] 67 | if stripped == base_name: 68 | return base_name 69 | 70 | # Check if current class has suffix that base doesn't 71 | for suffix in suffixes: 72 | if class_name.lower().endswith(suffix) and not base_name.lower().endswith( 73 | suffix 74 | ): 75 | # Check if removing suffix gives us the base name 76 | stripped = class_name[: -len(suffix)] 77 | if stripped == base_name: 78 | return base_name 79 | 80 | # No matching parent found, return original name 81 | return class_name 82 | 83 | 84 | async def get_system_prompt(cls: Type[BaseEnv]) -> str: 85 | """Setup env and get system prompt in async context.""" 86 | # Initialize env with num_local_servers=1 if supported 87 | init_signature = inspect.signature(cls.__init__) 88 | if "num_local_servers" in init_signature.parameters: 89 | env = cls(num_local_servers=1) # type: ignore 90 | else: 91 | env = cls() 92 | 93 | # Get system prompt (async function) 94 | prompt = await env.get_system_prompt(add_tool_defs=True) 95 | 96 | await env.shutdown() 97 | return prompt 98 | 99 | 100 | if __name__ == "__main__": 101 | parser = argparse.ArgumentParser() 102 | parser.add_argument( 103 | "--local_dir", 104 | required=True, 105 | help="Local directory where processed train/test parquet files will be written.", 106 | ) 107 | parser.add_argument( 108 | "--dataset_name", 109 | required=True, 110 | help="Identifier of the HuggingFace dataset to load (e.g., 'squad', 'wikitext').", 111 | ) 112 | parser.add_argument( 113 | "--env_path", 114 | required=True, 115 | help=( 116 | "Dotted path to the BaseEnv 
subclass to use for preprocessing, " 117 | "e.g. 'benchmax.envs.wikipedia.wiki_env.WikipediaEnv'." 118 | ), 119 | ) 120 | 121 | args = parser.parse_args() 122 | 123 | print(f"Loading {args.dataset_name} dataset...", flush=True) 124 | benchmax_cls: Type[BaseEnv] = load_class(args.env_path) 125 | raw_dataset, dataset_path = benchmax_cls.load_dataset(args.dataset_name) 126 | 127 | if isinstance(raw_dataset, (IterableDataset, IterableDatasetDict)): 128 | raise TypeError( 129 | f"Iterable datasets are currently not supported. Got {type(raw_dataset).__name__}. " 130 | ) 131 | 132 | if not isinstance(raw_dataset, (DatasetDict, Dataset)): 133 | raise TypeError( 134 | f"Expected DatasetDict or Dataset, but got {type(raw_dataset).__name__}." 135 | ) 136 | 137 | print("Getting system prompt...", flush=True) 138 | system_prompt = asyncio.run(get_system_prompt(benchmax_cls)) 139 | 140 | # Get canonical class name (strips local/skypilot if parent matches) 141 | canonical_name = get_canonical_class_name(benchmax_cls) 142 | 143 | def process_example(example): 144 | """Single mapping function that does all processing.""" 145 | # First apply dataset-specific preprocessing 146 | standardized = benchmax_cls.dataset_preprocess( 147 | example, dataset_path=dataset_path 148 | ) 149 | 150 | # Then format as multiturn prompt 151 | prompt = [ 152 | { 153 | "role": "system", 154 | "content": system_prompt, 155 | }, 156 | {"role": "user", "content": standardized["prompt"]}, 157 | ] 158 | result = { 159 | **standardized, 160 | "prompt": prompt, 161 | "env_class": canonical_name, 162 | "data_source": canonical_name, 163 | } 164 | 165 | # Remove keys with None values 166 | result = {k: v for k, v in result.items() if v is not None} 167 | 168 | return result 169 | 170 | print("Processing examples...", flush=True) 171 | processed_dataset = raw_dataset.map(process_example) 172 | 173 | if isinstance(processed_dataset, DatasetDict) and set( 174 | processed_dataset.keys() 175 | ) == set(["train", "test"]): 176 | # If train and test dataset split already exist 177 | train_dataset = processed_dataset["train"] 178 | test_dataset = processed_dataset["test"] 179 | else: 180 | if isinstance(processed_dataset, DatasetDict): 181 | processed_dataset = datasets.concatenate_datasets( 182 | [ds for ds in processed_dataset.values()] 183 | ).shuffle(seed=42) 184 | 185 | split = processed_dataset.train_test_split( 186 | test_size=0.2, seed=42, shuffle=False 187 | ) 188 | train_dataset = split["train"] 189 | test_dataset = split["test"] 190 | 191 | print(f"Saving to {args.local_dir}...", flush=True) 192 | local_dir = Path(args.local_dir) 193 | train_dataset.to_parquet(local_dir / "train.parquet") 194 | test_dataset.to_parquet(local_dir / "test.parquet") 195 | -------------------------------------------------------------------------------- /src/benchmax/envs/mcp/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for MCP environment infrastructure. 3 | """ 4 | 5 | import jwt 6 | import time 7 | import asyncio 8 | from contextlib import AsyncExitStack 9 | from pathlib import Path 10 | from typing import Optional, Dict, List, Any 11 | import aiohttp 12 | from fastmcp import Client 13 | from mcp import Tool 14 | 15 | from benchmax.envs.types import ToolDefinition 16 | 17 | 18 | def convert_tool_definitions( 19 | tools: List[Tool], allowed_tools: Optional[List[str]] 20 | ) -> List[ToolDefinition]: 21 | """ 22 | Convert MCP Tool objects to ToolDefinition dataclass. 
23 | 24 | Args: 25 | tools: List of MCP Tool objects. 26 | allowed_tools: Optional whitelist of tool names. If provided, 27 | only tools in this list are included. 28 | 29 | Returns: 30 | List of ToolDefinition objects. 31 | """ 32 | tool_definitions = [ 33 | ToolDefinition( 34 | name=tool.name, 35 | description=tool.description or "", 36 | input_schema=tool.inputSchema, 37 | ) 38 | for tool in tools 39 | ] 40 | 41 | if not allowed_tools: 42 | return tool_definitions 43 | 44 | return [tool for tool in tool_definitions if tool.name in allowed_tools] 45 | 46 | 47 | def generate_jwt_token( 48 | api_secret: str, 49 | rollout_id: Optional[str] = None, 50 | expiration_seconds: int = 300, 51 | **extra_claims: Any, 52 | ) -> str: 53 | """ 54 | Generate a JWT token with standard and custom claims. 55 | 56 | Args: 57 | api_secret: Shared secret for signing (HS256). 58 | rollout_id: Optional rollout ID to include in claims. 59 | expiration_seconds: Token validity duration (default: 5 minutes). 60 | **extra_claims: Additional custom claims to include. 61 | 62 | Returns: 63 | JWT token string. 64 | """ 65 | current_time = int(time.time()) 66 | 67 | payload = { 68 | # Standard claims 69 | "iss": "mcp-client", 70 | "aud": "mcp-proxy-server", 71 | "iat": current_time, 72 | "exp": current_time + expiration_seconds, 73 | # Custom claims 74 | **extra_claims, 75 | } 76 | 77 | # Add rollout_id if provided 78 | if rollout_id: 79 | payload["rollout_id"] = rollout_id 80 | 81 | # Sign with HS256 82 | token = jwt.encode(payload, api_secret, algorithm="HS256") 83 | return token 84 | 85 | 86 | def get_auth_headers( 87 | api_secret: str, rollout_id: Optional[str] = None, **extra_claims: Any 88 | ) -> Dict[str, str]: 89 | """ 90 | Generate Authorization header with JWT token. 91 | 92 | Args: 93 | api_secret: Shared secret for signing. 94 | rollout_id: Optional rollout ID to include in claims. 95 | **extra_claims: Additional custom claims. 96 | 97 | Returns: 98 | Headers dict with Authorization Bearer token. 99 | """ 100 | token = generate_jwt_token(api_secret, rollout_id, **extra_claims) 101 | return {"Authorization": f"Bearer {token}"} 102 | 103 | 104 | async def upload_form( 105 | http_session: aiohttp.ClientSession, 106 | upload_url: str, 107 | api_secret: str, 108 | file_bytes: bytes, 109 | filename: str, 110 | rollout_id: Optional[str] = None, 111 | content_type: str = "application/octet-stream", 112 | ) -> None: 113 | """ 114 | Upload a file or content to a remote URL using multipart form with JWT auth. 115 | 116 | Args: 117 | http_session: aiohttp client session. 118 | upload_url: URL to upload to. 119 | api_secret: Shared secret for JWT signing. 120 | file_bytes: File content as bytes. 121 | filename: Name for the uploaded file. 122 | rollout_id: Optional rollout ID for JWT claims. 123 | content_type: MIME type of the content. 124 | 125 | Raises: 126 | RuntimeError: If upload fails. 
127 | """ 128 | # Generate JWT token with rollout_id 129 | headers = get_auth_headers(api_secret, rollout_id) 130 | 131 | # Create multipart form data 132 | data = aiohttp.FormData() 133 | data.add_field("file", file_bytes, filename=filename, content_type=content_type) 134 | 135 | async with http_session.post(upload_url, headers=headers, data=data) as response: 136 | if response.status == 200: 137 | return 138 | error_text = await response.text() 139 | raise RuntimeError(f"Upload failed: {response.status} - {error_text}") 140 | 141 | 142 | async def download_file( 143 | http_session: aiohttp.ClientSession, 144 | download_url: str, 145 | api_secret: str, 146 | params: Dict[str, str], 147 | dst_path: Path, 148 | rollout_id: Optional[str] = None, 149 | ) -> None: 150 | """ 151 | Download a file from a remote URL and save it locally with JWT auth. 152 | 153 | Args: 154 | http_session: aiohttp client session. 155 | download_url: URL to download from. 156 | api_secret: Shared secret for JWT signing. 157 | params: Query parameters. 158 | dst_path: Local path to save the downloaded file. 159 | rollout_id: Optional rollout ID for JWT claims. 160 | 161 | Raises: 162 | RuntimeError: If download fails. 163 | """ 164 | # Generate JWT token with rollout_id 165 | headers = get_auth_headers(api_secret, rollout_id) 166 | 167 | async with http_session.get( 168 | download_url, headers=headers, params=params 169 | ) as response: 170 | if response.status != 200: 171 | error_text = await response.text() 172 | raise RuntimeError(f"Download failed: {response.status} - {error_text}") 173 | 174 | dst_path.parent.mkdir(parents=True, exist_ok=True) 175 | with open(dst_path, "wb") as f: 176 | async for chunk in response.content.iter_chunked(8192): 177 | f.write(chunk) 178 | 179 | 180 | async def _safe_session_runner(self): 181 | """ 182 | Patched version of FastMCPClient._session_runner that catches exceptions. 183 | 184 | This prevents crashes when servers disconnect or restart, logging errors 185 | instead of propagating them. 186 | """ 187 | try: 188 | async with AsyncExitStack() as stack: 189 | try: 190 | await stack.enter_async_context(self._context_manager()) 191 | self._session_state.ready_event.set() 192 | await self._session_state.stop_event.wait() 193 | except (aiohttp.ClientError, asyncio.CancelledError) as e: 194 | # Common expected errors when server disconnects or restarts 195 | print(f"[INFO] Client session ended: {type(e).__name__}: {e}") 196 | except Exception as e: 197 | # Log unexpected errors 198 | print(f"[WARN] Client session crashed: {e}") 199 | finally: 200 | self._session_state.ready_event.set() 201 | except Exception as e: 202 | # Catch outer-level async exit errors 203 | print(f"[WARN] Session runner outer error: {e}") 204 | 205 | 206 | def apply_fastmcp_patch(): 207 | """ 208 | Apply the monkey patch to FastMCP Client. 209 | 210 | Call this once at module import to enable graceful session handling. 211 | """ 212 | Client._session_runner = _safe_session_runner 213 | -------------------------------------------------------------------------------- /src/benchmax/envs/mcp/README.md: -------------------------------------------------------------------------------- 1 | # Parallel MCP Env 2 | 3 | This directory contains: 4 | ```bash 5 | ├── example_workdir/ # An example of a workdir 6 | ├── parallel_mcp_env.py # Parallel MCP env 7 | ├── provisioners/ # Various methods of provisioning server (e.g. 
local, skypilot) 8 | ├── proxy_server.py # Proxy server that redirects MCP connection 9 | ├── server_pool.py # Helper to manage servers 10 | └── utils.py # Shared utils 11 | ``` 12 | 13 | With parallel MCP env, you can easily run multiple MCP servers in parallel either locally or in the cloud. The environment supports both local development and cloud deployment through various provisioners, making it flexible for different use cases. 14 | 15 | ## Workdir Structure 16 | 17 | The `workdir` directory contains all the files necessary to run your MCP environment. It should include: 18 | 19 | 1. `mcp_config.yaml`: Configuration for spinning up MCP servers 20 | 2. `reward_fn.py`: Implementation of the reward function for evaluating model outputs 21 | 3. `setup.sh`: Script that runs at startup to install dependencies 22 | 4. Additional files: Any other files required by your MCP server (e.g., mcp server code, assets, helper scripts) 23 | 24 | When the environment is initialized, the entire workdir is copied to the execution machine, ensuring all necessary components are available whether running locally or in the cloud. 25 | 26 | 27 | ## Creating a custom environment using MCP 28 | 29 | To create a custom environment using an MCP server (like a calculator, browser, or spreadsheet), you can extend `ParallelMcpEnv`. Here's a quick step-by-step guide using `benchmax.envs.math.math_env.MathEnv` as an example. 30 | 31 | ### 1. **Define a System Prompt** 32 | 33 | This prompt guides the LLM’s behavior. It can include any instruction, such as how to format the answer or when to use tools. 34 | 35 | ```python 36 | SYSTEM_PROMPT = """Please use the tools provided to do any computation. 37 | Write your complete answer on the final line only, within the xml tags . 38 | """ 39 | ``` 40 | 41 | ### 2. **Configure MCP Server(s)** 42 | 43 | Define the MCP servers to be launched. You can configure one or more in `src/benchmax/envs/math/workdir/mcp_config.yaml`: 44 | 45 | ```yaml 46 | mcpServers: 47 | calculator: 48 | command: uvx 49 | args: 50 | - mcp-server-calculator 51 | ``` 52 | 53 | This MCP config will be used to spin up multiple MCP servers, which can run on a single machine (if configured locally) or distributed across multiple nodes in the cloud (if using SkyPilot multi-node). 54 | 55 | ### 3. **Write a Reward Function** 56 | 57 | The reward function evaluates how "correct" the model's output is, based on structured output. Here’s a simple XML-based example: 58 | 59 | Note that `**kwargs` contains all the other fields in your dataset example, so feel free to use them in `reward_fn` calculations. 60 | 61 | ```python 62 | async def text_match_reward( 63 | completion: str, 64 | ground_truth: str, 65 | mcp_client: Client, 66 | workspace: Path, 67 | **kwargs: Any 68 | ) -> float: 69 | """ 70 | Reward = 1 if `ground_truth` (case-insensitive) appears anywhere *inside* 71 | the first block of `completion`; otherwise 0. 72 | 73 | Falls back to 0 if the tag is missing or empty. 74 | """ 75 | 76 | # Grab only the text inside the first pair (case-insensitive). 77 | m = re.search(r'(.*?)', completion, flags=re.IGNORECASE | re.DOTALL) 78 | if m is None: 79 | return 0.0 80 | 81 | # Unescape any XML entities (& → &, etc.) and normalise whitespace. 82 | answer_text = unescape(m.group(1)).strip().lower() 83 | 84 | try: 85 | # Try to interpret both as floats for numerical comparison. 
86 | return float(float(ground_truth.lower()) == float(answer_text)) 87 | except ValueError: 88 | return 0.0 89 | ``` 90 | 91 | Check out [math env's reward_fn.py](/src/benchmax/envs/math/workdir/reward_fn.py) for the full implementation. You can reference [this example reward_fn.py](/src/benchmax/envs/mcp/example_workdir/reward_fn.py) for a comprehensive usage of various ways to compute reward. 92 | 93 | ### 4. Define **`dataset_preprocess`** 94 | 95 | If your dataset is not already standardized, implement this method to convert a raw example into a standardized one with: 96 | 97 | - `"prompt"`: A fully constructed string prompt. 98 | - `"ground_truth"`: A known correct output (optional depending on reward). 99 | - `"init_rollout_args"`: Arguments needed to initialize a rollout. 100 | 101 | Example for our math task: 102 | 103 | ```python 104 | def dataset_preprocess(self, example: dict) -> StandardizedExample: 105 | return StandardizedExample( 106 | prompt=example.get("task", ""), 107 | ground_truth=example.get("answer", ""), 108 | init_rollout_args={} 109 | ) 110 | ``` 111 | 112 |
113 | Notes on init_rollout_args 114 | The `init_rollout_args` dictionary is passed from `dataset_preprocess()` to your environment's `init_rollout()` method. It is used to initialize any **per-example files, resources, or execution context** needed before a rollout begins. 115 | 116 | Common use cases include: 117 | 118 | - **Input files**: For environments that manipulate files like spreadsheets, images, or databases, pass the necessary file paths. 119 | - **Version control**: For code-related tasks, you might pass a `commit_id` to check out the correct code state. 120 | - **Task-specific settings**: Pass metadata like cell ranges, task IDs, or execution flags. 121 | 122 | Example: 123 | 124 | ```python 125 | # Inside dataset_preprocess 126 | return { 127 | "prompt": "...", 128 | "ground_truth": "...", 129 | "init_rollout_args": { 130 | "spreadsheet_path": "/path/to/1_001_input.xlsx" 131 | } 132 | } 133 | ``` 134 | 135 | Then in your `init_rollout()` method: 136 | 137 | ```python 138 | def init_rollout(self, rollout_id: str, **rollout_args): 139 | spreadsheet_path = rollout_args["spreadsheet_path"] 140 | workspace = self.get_rollout_workspace(rollout_id) 141 | 142 | # Copy the input file into the rollout's workspace 143 | shutil.copy(spreadsheet_path, workspace / Path(spreadsheet_path).name) 144 | ``` 145 | 146 | This pattern ensures each rollout starts with the correct inputs and configuration. 147 |
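Beyond `init_rollout_args`, any additional keys you return from `dataset_preprocess` are kept in the processed dataset and show up in your reward function as `**kwargs` (the Excel environment uses this to pass the answer cell position and output filename). Below is a minimal sketch of that flow; `input_path` and `answer_position` are illustrative field names rather than part of the math example:

```python
from pathlib import Path
from typing import Any

from fastmcp import Client


def dataset_preprocess(example: dict) -> dict:
    # "prompt", "ground_truth", and "init_rollout_args" are the standardized keys;
    # "answer_position" is an extra, illustrative field that will reach the reward
    # function as a keyword argument.
    return {
        "prompt": example.get("task", ""),
        "ground_truth": example.get("answer", ""),
        "init_rollout_args": {"input_src_path": example["input_path"]},
        "answer_position": example.get("answer_position", "A1"),
    }


async def cell_check_reward(
    completion: str,
    ground_truth: str,
    mcp_client: Client,
    workspace: Path,
    **kwargs: Any,
) -> float:
    # Extra dataset fields arrive here via **kwargs.
    answer_position = kwargs.get("answer_position")
    if answer_position is None:
        raise ValueError("kwargs must contain answer_position")
    # Placeholder: a real implementation would compare the output file that
    # init_rollout copied into `workspace` at cell `answer_position` against
    # `ground_truth`.
    return 0.0
```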
148 | 149 | ### 5. **Extend `ParallelMcpEnv`** 150 | 151 | Now bring everything together into a custom environment class: 152 | 153 | ```python 154 | SYSTEM_PROMPT = """Please use the tools provided to do any computation. 155 | Write your complete answer on the final line only, within the xml tags .\n 156 | """ 157 | 158 | class MathEnv(ParallelMcpEnv): 159 | """Environment for math problems, using local MCP tools.""" 160 | 161 | system_prompt: str = SYSTEM_PROMPT 162 | 163 | def __init__(self, workdir_path: Path, provisioner: BaseProvisioner, **kwargs): 164 | super().__init__(workdir_path=workdir_path, provisioner=provisioner, **kwargs) 165 | 166 | @classmethod 167 | def dataset_preprocess(cls, example: Any, **kwargs) -> StandardizedExample: 168 | return StandardizedExample( 169 | prompt=example.get("task", ""), 170 | ground_truth=example.get("answer", ""), 171 | init_rollout_args=None, 172 | ) 173 | ``` 174 | 175 | You're done! This environment is now compatible with `benchmax` and can be plugged into any compatible RL trainer. 176 | 177 | In [math_env.py](/src/benchmax/envs/math/math_env.py), we have also provided `MathEnvLocal` class to run the env locally and `MathEnvSkypilot` to run it remotely across multiple machines. 178 | -------------------------------------------------------------------------------- /tests/unit/envs/excel/test_excel_env.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for ExcelEnv. 3 | 4 | All tests are fast with no external service calls. 5 | """ 6 | 7 | import pytest 8 | from pathlib import Path 9 | from typing import Dict, Any 10 | from unittest.mock import Mock, patch 11 | 12 | from benchmax.envs.excel.excel_env import ExcelEnv 13 | from benchmax.envs.excel.workdir.reward_fn import spreadsheet_comparison_reward 14 | from benchmax.envs.mcp.provisioners.base_provisioner import BaseProvisioner 15 | 16 | 17 | # Fixtures 18 | @pytest.fixture(scope="session") 19 | def test_xlsx_path() -> str: 20 | return str(Path(__file__).parent / "test_inputs" / "test.xlsx") 21 | 22 | 23 | @pytest.fixture 24 | def excel_env(tmp_path: Path) -> ExcelEnv: 25 | """Fixture to create an ExcelEnv instance without initializing the parent.""" 26 | env = ExcelEnv( 27 | workdir_path=tmp_path, 28 | provisioner=Mock(spec=BaseProvisioner), 29 | provision_at_init=False, 30 | ) 31 | env._servers_provisioned = True 32 | return env 33 | 34 | 35 | @pytest.fixture 36 | def mock_dataset(tmp_path: Path) -> Path: 37 | """Create a fake dataset folder with a sample input file. 38 | 39 | The ExcelEnv expects the dataset layout to contain a folder (spreadsheet_path) 40 | with files like `1__input.xlsx`. We touch a file to satisfy existence 41 | checks; reading is patched in tests that need it. 
42 | """ 43 | base = tmp_path / "dataset_root" 44 | sheet_dir = base / "sheet_folder" 45 | sheet_dir.mkdir(parents=True) 46 | 47 | # Create a dummy input file path that ExcelEnv will look for 48 | (sheet_dir / "1_42_input.xlsx").write_text("dummy") 49 | 50 | # Also create an answer file (not strictly required for dataset_preprocess) 51 | (sheet_dir / "1_42_answer.xlsx").write_text("dummy-answer") 52 | 53 | return base 54 | 55 | 56 | @pytest.fixture 57 | def mock_mcp_client() -> Mock: 58 | return Mock() 59 | 60 | 61 | @pytest.fixture 62 | def mock_workspace(tmp_path: Path) -> Path: 63 | ws = tmp_path / "workspace" 64 | ws.mkdir() 65 | return ws 66 | 67 | 68 | class TestDatasetPreprocess: 69 | """Tests for ExcelEnv.dataset_preprocess.""" 70 | 71 | def test_valid_example( 72 | self, mock_dataset: Path, test_xlsx_path: str 73 | ) -> None: 74 | """Valid example returns an ExcelExample with expected fields.""" 75 | example: Dict[str, Any] = { 76 | "id": "42", 77 | "spreadsheet_path": test_xlsx_path, 78 | "instruction": "Fill cell A1 with 10", 79 | "instruction_type": "Cell-Level Manipulation", 80 | "answer_position": "A1", 81 | } 82 | 83 | with patch( 84 | "benchmax.envs.excel.excel_env.excel_to_str_repr", 85 | return_value="A1=10", 86 | ): 87 | result = ExcelEnv.dataset_preprocess(example, dataset_path=mock_dataset) 88 | 89 | # Result is a mapping-like StandardizedExample (TypedDict / dataclass) 90 | assert result is not None 91 | assert "Fill cell A1 with 10" in result["prompt"] 92 | assert "1_42_input.xlsx" in result["prompt"] 93 | assert result["init_rollout_args"] is not None 94 | assert result["init_rollout_args"]["input_src_path"].endswith("1_42_input.xlsx") 95 | assert result["answer_position"] == "A1" 96 | assert result["output_filename"] == "1_42_output.xlsx" 97 | 98 | def test_missing_fields_raise(self, excel_env: ExcelEnv) -> None: 99 | """Missing required fields should raise ValueError.""" 100 | example: Dict[str, Any] = {"id": "1", "instruction": "do"} 101 | 102 | with pytest.raises(ValueError): 103 | excel_env.dataset_preprocess(example) 104 | 105 | def test_non_string_spreadsheet_path_raises( 106 | self, mock_dataset: Path 107 | ) -> None: 108 | """Non-string spreadsheet_path should raise ValueError.""" 109 | example: Dict[str, Any] = { 110 | "id": "1", 111 | "spreadsheet_path": 123, # invalid type 112 | "instruction": "x", 113 | "instruction_type": "Cell-Level Manipulation", 114 | "answer_position": "A1", 115 | } 116 | 117 | with pytest.raises(TypeError): 118 | ExcelEnv.dataset_preprocess(example, dataset_path=mock_dataset) 119 | 120 | def test_missing_spreadsheet_folder_raises( 121 | self, tmp_path: Path 122 | ) -> None: 123 | """If the spreadsheet folder does not exist under the dataset path, raise FileNotFoundError.""" 124 | dataset_path = tmp_path / "some_other_root" 125 | 126 | example: Dict[str, Any] = { 127 | "id": "99", 128 | "spreadsheet_path": "no_such_folder", 129 | "instruction": "x", 130 | "instruction_type": "Cell-Level Manipulation", 131 | "answer_position": "A1", 132 | } 133 | 134 | with pytest.raises(FileNotFoundError): 135 | ExcelEnv.dataset_preprocess(example, dataset_path=dataset_path) 136 | 137 | 138 | class TestRewardComputation: 139 | """Test reward computation for spreadsheet comparison.""" 140 | 141 | def test_exact_match_returns_one( 142 | self, mock_mcp_client: Mock, mock_workspace: Path 143 | ) -> None: 144 | # Patch compare_excel_cells to return match=True 145 | with patch( 146 | "benchmax.envs.excel.workdir.reward_fn.compare_excel_cells", 147 | 
return_value=(True, None), 148 | ): 149 | score = spreadsheet_comparison_reward( 150 | completion="irrelevant", 151 | ground_truth={}, 152 | mcp_client=mock_mcp_client, 153 | workspace=mock_workspace, 154 | answer_position="A1", 155 | output_filename="out.xlsx", 156 | ground_truth_filename="gt.xlsx", 157 | ) 158 | 159 | assert score == 1.0 160 | 161 | def test_mismatch_returns_zero( 162 | self, mock_mcp_client: Mock, mock_workspace: Path 163 | ) -> None: 164 | with patch( 165 | "benchmax.envs.excel.workdir.reward_fn.compare_excel_cells", 166 | return_value=(False, None), 167 | ): 168 | score = spreadsheet_comparison_reward( 169 | completion="irrelevant", 170 | ground_truth={}, 171 | mcp_client=mock_mcp_client, 172 | workspace=mock_workspace, 173 | answer_position="A1", 174 | output_filename="out.xlsx", 175 | ground_truth_filename="gt.xlsx", 176 | ) 177 | 178 | assert score == 0.0 179 | 180 | def test_missing_kwargs_raises( 181 | self, mock_mcp_client: Mock, mock_workspace: Path 182 | ) -> None: 183 | with pytest.raises(ValueError): 184 | spreadsheet_comparison_reward( 185 | completion="x", 186 | ground_truth={}, 187 | mcp_client=mock_mcp_client, 188 | workspace=mock_workspace, 189 | # missing answer_position, output_filename, ground_truth_filename 190 | ) 191 | 192 | def test_compare_raises_returns_zero( 193 | self, mock_mcp_client: Mock, mock_workspace: Path 194 | ) -> None: 195 | with patch( 196 | "benchmax.envs.excel.workdir.reward_fn.compare_excel_cells", 197 | side_effect=Exception("boom"), 198 | ): 199 | score = spreadsheet_comparison_reward( 200 | completion="irrelevant", 201 | ground_truth={}, 202 | mcp_client=mock_mcp_client, 203 | workspace=mock_workspace, 204 | answer_position="A1", 205 | output_filename="out.xlsx", 206 | ground_truth_filename="gt.xlsx", 207 | ) 208 | 209 | assert score == 0.0 210 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | Benchmax 3 | 4 | 5 | ## benchmax: Framework-Agnostic RL Environments for LLM Fine-Tuning 6 | *A lightweight, training-framework agnostic library for defining, running, and parallelizing environments, to fine-tune OSS LLMs with reinforcement learning.* 7 |
20 | 
21 | ## 📌 News
22 | 
23 | - **[29 Oct 2025]** 🎉 Added support for easy multi-node parallelization across all major cloud providers using [SkyPilot](https://github.com/skypilot-org/skypilot)
24 | - **[29 Oct 2025]** 🎉 Integration with [SkyRL](https://github.com/NovaSky-AI/SkyRL) for distributed RL training across clusters
25 | - **[Upcoming]** 🛠️ Integration with the Tinker API.
26 | 
27 | ## 📘 Quickstart
28 | 
29 | **Example: Multi-node parallelization of the Excel env with SkyRL and SkyPilot**
30 | 
31 | RL environments can be computationally expensive to run (e.g. running tests). To handle these workloads efficiently, we distribute rollouts across multiple nodes using **SkyPilot**, horizontally scaling `benchmax` across cloud providers like GCP, AWS, Azure, etc.
32 | 
33 | **SkyRL** is a training framework `benchmax` is currently integrated with. Use our **SkyRL** integration to RL-finetune Qwen-2.5 on spreadsheet manipulation using an Excel MCP parallelized across multiple nodes. The environment is defined in [`benchmax.envs.excel.excel_env.ExcelEnvSkypilot`](/src/benchmax/envs/excel/excel_env.py).
34 | 
35 | 1. **Prepare the dataset**
36 | 
37 | ```bash
38 | uv run src/benchmax/adapters/skyrl/benchmax_data_process.py \
39 | --local_dir ~/data/excel \
40 | --dataset_name spreadsheetbench \
41 | --env_path benchmax.envs.excel.excel_env.ExcelEnvLocal
42 | ```
43 | 
44 | Note: We use `ExcelEnvLocal` instead of `ExcelEnvSkypilot` here because the MCP is only used to list tools when preparing the system prompt.
45 | 
46 | 2. **Run training and parallelize the Excel environment**
47 | 
48 | ```bash
49 | bash examples/skyrl/run_benchmax_excel.sh
50 | ```
51 | 
52 | This Excel env example spins up 5 nodes with 20 servers per node (100 MCP servers in parallel in total). For more details, check out [multi-node parallelization](/src/benchmax/envs/mcp/README.md) and the [SkyRL integration](/examples/skyrl/README.md).
53 | 
54 | ## ℹ️ Overview
55 | 
56 | `benchmax` comes with:
57 | 
58 | - A collection of ready-to-use reinforcement learning (RL) environments for LLM fine-tuning, ranging from multi-hop search to spreadsheet manipulation to CRM agents
59 | - An easy way to define, compose, and parallelize your own environments, including by leveraging the existing ecosystem of MCP servers
60 | - Built-in integrations with popular RL training libraries (SkyRL, etc.). `benchmax` is trainer-agnostic by design
61 | 
62 | Define your environment as:
63 | 
64 | 1. A **toolset** (LLM calls, external APIs, calculators, MCPs, etc.).
65 | 2. **Output parsing** logic to extract structured observations.
66 | 3. **Reward functions** to score model outputs.
67 | 
68 | Rollout management, parallel execution, etc. come out of the box.
69 | 
70 | ⭐ Star our repository to show your support!
71 | 
72 | ## 💡 Core Features
73 | 
74 | **Built-in examples & templates**
75 | 
76 | Get started with ready-to-use recipes, from Wikipedia search to spreadsheet manipulation. Easy to copy, customize, and extend. And yes, more are on the way.
77 | 
78 | **Trainer integrations**
79 | 
80 | Use your own trainer or training framework - no lock-in. `benchmax` is already integrated into SkyRL, with more integrations (Tinker, etc.) coming soon!
81 | 
82 | **MCP support**
83 | 
84 | Tap into the growing ecosystem of MCP servers and integrate them as tools within your environments.
85 | 
86 | **Multi-node parallel execution**
87 | 
88 | Multi-node parallelization is enabled out of the box, with state isolation across rollouts (e.g. when editing files on the filesystem). A minimal usage sketch is shown below.
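For example, the SkyPilot-backed Excel environment from the Quickstart can be constructed directly in Python. This is a minimal, illustrative sketch based on the constructor signatures in `excel_env.py`; the choice of `sky.GCP()` and the node counts are placeholders, and in practice the SkyRL scripts above drive provisioning and training for you.

```python
import sky

from benchmax.envs.excel.excel_env import ExcelEnvLocal, ExcelEnvSkypilot

# Multi-node setup: 5 nodes x 20 MCP servers per node = 100 servers, provisioned via SkyPilot.
env = ExcelEnvSkypilot(cloud=sky.GCP(), num_nodes=5, servers_per_node=20)
# For quick local debugging you can use the local variant instead:
# env = ExcelEnvLocal(num_local_servers=5)

# Load SpreadsheetBench and turn one raw example into a standardized prompt + rollout args.
dataset, dataset_path = ExcelEnvSkypilot.load_dataset("spreadsheetbench")
example = ExcelEnvSkypilot.dataset_preprocess(dataset[0], dataset_path=dataset_path)
print(example["prompt"][:200])
```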
89 | 
90 | 
91 | ## 🌐 Creating & Training with Environments
92 | 
93 | ### What is an environment?
94 | 
95 | An environment consists of:
96 | 
97 | - A list of tools that an LLM can call
98 | - A list of reward functions that evaluate the quality & correctness of the model's final output.
99 | 
100 | We also support MCP servers natively, allowing you to easily leverage the many servers built by the community.
101 | 
102 | ### Pre-built environments
103 | 
104 | Ready-to-use environments with pre-configured tools and reward functions:
105 | 
106 | - [CRM](/src/benchmax/envs/crm/README.md)
107 | - [Excel](/src/benchmax/envs/excel/README.md)
108 | - [Math](/src/benchmax/envs/math/README.md)
109 | - [Wikipedia](/src/benchmax/envs/wikipedia/README.md)
110 | 
111 | ### How do I create a custom environment?
112 | 
113 | 1. [With existing MCP servers](/src/benchmax/envs/mcp/README.md) (built-in support for multi-node parallelization)
114 | 
115 | 2. [Extend BaseEnv](/src/benchmax/envs/README.md)
116 | 
117 | ### How about more complex environments?
118 | 
119 | - Check out our Excel spreadsheet RL environment: `benchmax.envs.excel.excel_env.ExcelEnv`
120 | 
121 | ### How do I use an environment with my preferred RL trainer?
122 | 
123 | We currently have an integration with SkyRL. More incoming!
124 | 
125 | [`benchmax` environments with SkyRL](/examples/skyrl/README.md)
126 | 
127 | ### I want a specific environment
128 | 
129 | Open an issue and tag us, and we will look into building one for you!
130 | 
131 | ---
132 | 
133 | ## 🎯 Motivation
134 | 
135 | - **Modularity and Simplicity**:
136 | 
137 | We set out to build a lightweight, modular system for defining RL environments, breaking them down into simple, composable parts: tools, tool output parsing, and reward functions.
138 | 
139 | The goal is to make it easy for software engineers to build and experiment with RL environments without needing deep RL expertise.
140 | 
141 | - **Trainer Integrations**:
142 | 
143 | There has been a steady stream of new RL training frameworks (e.g., numerous forks of verl), and we expect this to continue. They are often tightly coupled with specific environments, leading to fragmentation and limited compatibility.
144 | 
145 | We are building `benchmax` as a standalone library with integrations into these different training frameworks, and as an easy way for new frameworks to tap into an existing pool of environments. We're already integrated with SkyRL (Tinker coming soon)!
146 | 
147 | - **Task Recipes and Ideas**:
148 | 
149 | We want `benchmax` to be a living library of reusable, RL-compatible task recipes, ready to inspire and extend beyond the usual suspects like math and coding. We aim to support more real-world workflows, including open-ended and long-horizon tasks.
150 | 
151 | - **Parallelization and Cloud Compatibility**:
152 |   - Enable efficient parallelization while maintaining statefulness between rollouts.
153 |   - Facilitate easy deployment and scalability in cloud environments.
154 | 
155 | - **MCP as a first-class citizen**:
156 | 
157 | There has been an explosion of MCP servers and tools built for use cases ranging from browser use to Excel to game creation. `benchmax` lets you leverage and compose these existing MCP servers to build environments integrated with real-world systems, e.g. Excel.
158 | 
159 | 
160 | ## 🤝 Contributing
161 | 
162 | We welcome new environment recipes, bug reports, and trainer integrations!
163 | 
164 | ⭐ Star our repository to show your support!
165 | 
166 | ## 📜 License
167 | 
168 | Apache 2.0 © 2025 CGFT Inc.
169 | -------------------------------------------------------------------------------- /tests/unit/envs/wikipedia/test_wiki_env.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import AsyncMock, MagicMock, patch 3 | from aiohttp import ClientResponse 4 | 5 | from benchmax.envs.wikipedia.wiki_env import WikipediaEnv 6 | 7 | 8 | @pytest.fixture 9 | def wiki_env() -> WikipediaEnv: 10 | """Create WikipediaEnv instance without API keys for unit tests.""" 11 | return WikipediaEnv(wikipedia_api_keys=None) 12 | 13 | 14 | class TestDatasetPreprocess: 15 | """Tests for dataset preprocessing.""" 16 | 17 | @pytest.mark.asyncio 18 | async def test_dataset_preprocess_valid_example( 19 | self, wiki_env: WikipediaEnv 20 | ) -> None: 21 | """Test preprocessing a valid dataset example.""" 22 | example = {"Question": "Who created Python?", "Answer": "Guido van Rossum"} 23 | 24 | result = wiki_env.dataset_preprocess(example) 25 | 26 | assert result["prompt"] == "Who created Python?" 27 | assert result["ground_truth"] == "Guido van Rossum" 28 | 29 | 30 | class TestComputeReward: 31 | """Tests for reward computation.""" 32 | 33 | @pytest.mark.asyncio 34 | async def test_compute_reward_exact_match(self, wiki_env: WikipediaEnv) -> None: 35 | """Test that exact match returns 1.0.""" 36 | completion = "The answer is Paris" 37 | ground_truth = "Paris" 38 | 39 | rewards = await wiki_env.compute_reward( 40 | rollout_id="test", completion=completion, ground_truth=ground_truth 41 | ) 42 | 43 | assert rewards["text_match"] == 1.0 44 | 45 | @pytest.mark.asyncio 46 | async def test_compute_reward_case_insensitive( 47 | self, wiki_env: WikipediaEnv 48 | ) -> None: 49 | """Test that matching is case-insensitive.""" 50 | completion = "The answer is PARIS" 51 | ground_truth = "paris" 52 | 53 | rewards = await wiki_env.compute_reward( 54 | rollout_id="test", completion=completion, ground_truth=ground_truth 55 | ) 56 | 57 | assert rewards["text_match"] == 1.0 58 | 59 | @pytest.mark.asyncio 60 | async def test_compute_reward_no_match(self, wiki_env: WikipediaEnv) -> None: 61 | """Test that wrong answer returns 0.0.""" 62 | completion = "The answer is London" 63 | ground_truth = "Paris" 64 | 65 | rewards = await wiki_env.compute_reward( 66 | rollout_id="test", completion=completion, ground_truth=ground_truth 67 | ) 68 | 69 | assert rewards["text_match"] == 0.0 70 | 71 | @pytest.mark.asyncio 72 | async def test_compute_reward_missing_answer_tags( 73 | self, wiki_env: WikipediaEnv 74 | ) -> None: 75 | """Test that missing answer tags returns 0.0.""" 76 | completion = "The answer is Paris" 77 | ground_truth = "Paris" 78 | 79 | rewards = await wiki_env.compute_reward( 80 | rollout_id="test", completion=completion, ground_truth=ground_truth 81 | ) 82 | 83 | assert rewards["text_match"] == 0.0 84 | 85 | 86 | class TestListTools: 87 | """Tests for listing available tools.""" 88 | 89 | @pytest.mark.asyncio 90 | async def test_list_tools_returns_two_tools(self, wiki_env: WikipediaEnv) -> None: 91 | """Test that exactly 2 tools are returned with correct names.""" 92 | tools = await wiki_env.list_tools() 93 | 94 | assert len(tools) == 2 95 | 96 | tool_names = {tool.name for tool in tools} 97 | assert "search_wikipedia" in tool_names 98 | assert "get_wikipedia_article" in tool_names 99 | 100 | # Verify each tool has required schema properties 101 | for tool in tools: 102 | assert tool.name is not None 103 | assert tool.description is not None 104 | assert 
tool.input_schema is not None
105 |             assert "type" in tool.input_schema
106 |             assert "properties" in tool.input_schema
107 | 
108 | 
109 | class TestRunTool:
110 |     """Tests for tool execution with mocked HTTP calls."""
111 | 
112 |     @pytest.mark.asyncio
113 |     @patch("benchmax.envs.wikipedia.wiki_env.safe_request")
114 |     async def test_run_tool_search_wikipedia_success(
115 |         self, mock_safe_request: AsyncMock, wiki_env: WikipediaEnv
116 |     ) -> None:
117 |         """Test successful Wikipedia search with mocked response."""
118 |         mock_response = MagicMock(spec=ClientResponse)
119 |         mock_response.status = 200
120 |         mock_response.json = AsyncMock(
121 |             return_value={
122 |                 "query": {
123 |                     "search": [
124 |                         {
125 |                             "title": "Python (programming language)",
126 |                             "snippet": "Python is a high-level programming language",
127 |                         },
128 |                         {
129 |                             "title": "Python (genus)",
130 |                             "snippet": "Python is a genus of constricting snakes",
131 |                         },
132 |                     ]
133 |                 }
134 |             }
135 |         )
136 |         mock_safe_request.return_value = mock_response
137 | 
138 |         result = await wiki_env.run_tool(
139 |             rollout_id="test", tool_name="search_wikipedia", q="Python", limit=5
140 |         )
141 | 
142 |         assert isinstance(result, str)
143 |         assert "Python (programming language)" in result
144 |         assert "high-level programming language" in result
145 | 
146 |     @pytest.mark.asyncio
147 |     @patch("benchmax.envs.wikipedia.wiki_env.safe_request")
148 |     async def test_run_tool_get_article_success(
149 |         self, mock_safe_request: AsyncMock, wiki_env: WikipediaEnv
150 |     ) -> None:
151 |         """Test successful article fetch with mocked response."""
152 |         mock_response = MagicMock(spec=ClientResponse)
153 |         mock_response.status = 200
154 |         mock_response.json = AsyncMock(
155 |             return_value={
156 |                 "query": {
157 |                     "pages": {
158 |                         "12345": {
159 |                             "extract": "Python is a high-level, general-purpose programming language. Its design philosophy emphasizes code readability with the use of significant indentation."
160 | } 161 | } 162 | } 163 | } 164 | ) 165 | mock_safe_request.return_value = mock_response 166 | 167 | result = await wiki_env.run_tool( 168 | rollout_id="test", 169 | tool_name="get_wikipedia_article", 170 | title="Python (programming language)", 171 | max_chars=1000, 172 | ) 173 | 174 | assert isinstance(result, str) 175 | assert "Python is a high-level" in result 176 | assert "code readability" in result 177 | 178 | @pytest.mark.asyncio 179 | async def test_run_tool_missing_required_param( 180 | self, wiki_env: WikipediaEnv 181 | ) -> None: 182 | """Test error handling when required parameter is missing.""" 183 | # Search without query 184 | result = await wiki_env.run_tool( 185 | rollout_id="test", 186 | tool_name="search_wikipedia", 187 | q="", # Empty query 188 | limit=5, 189 | ) 190 | 191 | assert isinstance(result, str) 192 | assert "Error" in result 193 | 194 | 195 | class TestInitRollout: 196 | """Tests for rollout initialization.""" 197 | 198 | @pytest.mark.asyncio 199 | async def test_init_rollout_completes(self, wiki_env: WikipediaEnv) -> None: 200 | """Test that init_rollout completes without errors (no-op).""" 201 | # Should complete without raising exceptions 202 | await wiki_env.init_rollout(rollout_id="test_rollout_1") 203 | await wiki_env.init_rollout(rollout_id="test_rollout_2", extra_arg="value") 204 | 205 | # No assertions needed - just verify no exceptions raised 206 | assert True 207 | -------------------------------------------------------------------------------- /src/benchmax/envs/mcp/provisioners/skypilot_provisioner.py: -------------------------------------------------------------------------------- 1 | """ 2 | SkyPilot provisioner for launching cloud-based server clusters. 3 | """ 4 | 5 | import logging 6 | import uuid 7 | from pathlib import Path 8 | from typing import List, Optional 9 | import sky 10 | 11 | from .base_provisioner import BaseProvisioner 12 | from .utils import get_run_command, setup_sync_dir, cleanup_dir, get_setup_command 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class SkypilotProvisioner(BaseProvisioner): 18 | """ 19 | Provisioner that launches a SkyPilot cluster in the cloud. 20 | 21 | Use this for: 22 | - Production-scale parallel execution 23 | - Distributed benchmarking across many nodes 24 | - Cloud-based compute resource management 25 | 26 | Example: 27 | import sky 28 | provisioner = SkypilotProvisioner( 29 | workdir_path=Path("my_workdir"), 30 | cloud=sky.Azure(), 31 | num_nodes=5, 32 | servers_per_node=4, 33 | cpus=4, 34 | memory=16, 35 | ) 36 | # Will provision 20 total servers (5 nodes * 4 servers/node) 37 | """ 38 | 39 | def __init__( 40 | self, 41 | workdir_path: Path | str, 42 | cloud: sky.clouds.Cloud, 43 | num_nodes: int = 1, 44 | servers_per_node: int = 5, 45 | cpus: Optional[str | int] = "2+", 46 | memory: Optional[str | int] = "8+", 47 | base_cluster_name: str = "benchmax-env-cluster", 48 | ): 49 | """ 50 | Initialize SkyPilot provisioner. 51 | 52 | Args: 53 | workdir_path: Path to workdir containing mcp_config.yaml, setup.sh, etc. 54 | cloud: SkyPilot cloud instance (e.g., sky.AWS(), sky.Azure(), sky.GCP()). 55 | num_nodes: Number of nodes in the cluster. 56 | servers_per_node: Number of proxy servers to run on each node. 57 | cpus: CPU requirement per node (e.g., "2+", 4, "8"). 58 | memory: Memory requirement per node in GB (e.g., "16+", 32). 59 | base_cluster_name: Base name for the cluster (timestamp will be appended). 
60 | """ 61 | if num_nodes < 1: 62 | raise ValueError("num_nodes must be at least 1") 63 | if servers_per_node < 1 or servers_per_node > 100: 64 | raise ValueError("servers_per_node must be between 1 and 100") 65 | 66 | self._workdir_path = Path(workdir_path).absolute() 67 | self._cloud = cloud 68 | self._num_nodes = num_nodes 69 | self._servers_per_node = servers_per_node 70 | self._total_num_servers = num_nodes * servers_per_node 71 | self._cpus = cpus 72 | self._memory = memory 73 | 74 | # Generate unique cluster name 75 | unique_suffix = uuid.uuid4().hex[:4] 76 | self._cluster_name = f"{base_cluster_name}-{unique_suffix}" 77 | 78 | # Internal state 79 | self._sync_workdir: Optional[Path] = None 80 | self._cluster_provisioned: bool = False 81 | 82 | logger.info( 83 | f"SkypilotProvisioner configured: {num_nodes} nodes * {servers_per_node} servers/node = " 84 | f"{self._total_num_servers} total servers, cluster: '{self._cluster_name}'" 85 | ) 86 | 87 | @property 88 | def num_servers(self) -> int: 89 | """ 90 | Total number of servers 91 | """ 92 | return self._total_num_servers 93 | 94 | async def provision_servers(self, api_secret: str) -> List[str]: 95 | """ 96 | Launch SkyPilot cluster and return server addresses. 97 | 98 | Args: 99 | api_secret: API token for server authentication. 100 | 101 | Returns: 102 | List of server addresses in "host:port" format. 103 | """ 104 | if self._cluster_provisioned: 105 | raise RuntimeError("Cluster already provisioned. Call teardown() first.") 106 | 107 | self._cluster_provisioned = True 108 | 109 | logger.info(f"Launching SkyPilot cluster '{self._cluster_name}'...") 110 | 111 | # Setup sync directory with workdir contents + proxy_server.py 112 | try: 113 | self._sync_workdir = setup_sync_dir(self._workdir_path) 114 | logger.debug(f"Synced workdir to temporary directory: {self._sync_workdir}") 115 | except Exception as e: 116 | logger.error(f"Failed to setup sync directory: {e}") 117 | raise 118 | 119 | # Calculate ports 120 | base_port = 8080 121 | all_ports = [str(base_port + i) for i in range(self._servers_per_node)] 122 | 123 | env = None if api_secret is None else {"API_SECRET": api_secret} 124 | 125 | # Configure SkyPilot task 126 | sky_task = sky.Task( 127 | name="mcp-server", 128 | run=get_run_command(ports=all_ports), 129 | setup=get_setup_command(), 130 | workdir=str(self._sync_workdir), 131 | num_nodes=self._num_nodes, 132 | envs=env, 133 | ) 134 | 135 | sky_task.set_resources( 136 | sky.Resources( 137 | cloud=self._cloud, 138 | cpus=self._cpus, 139 | memory=self._memory, 140 | ports=all_ports, 141 | ) 142 | ) 143 | 144 | # Launch cluster 145 | logger.info( 146 | f"Submitting cluster launch: {self._num_nodes} nodes, " 147 | f"{self._cpus} CPUs, {self._memory}GB memory per node" 148 | ) 149 | cluster_handle = None 150 | try: 151 | _, handle = sky.launch( 152 | task=sky_task, 153 | cluster_name=self._cluster_name, 154 | detach_run=True, 155 | detach_setup=True, 156 | retry_until_up=True, 157 | ) 158 | cluster_handle = handle 159 | except Exception as e: 160 | logger.error(f"Failed to launch SkyPilot cluster: {e}") 161 | cleanup_dir(self._sync_workdir) 162 | self._sync_workdir = None 163 | raise RuntimeError(f"SkyPilot cluster launch failed: {e}") from e 164 | 165 | if cluster_handle is None: 166 | cleanup_dir(self._sync_workdir) 167 | self._sync_workdir = None 168 | raise RuntimeError("SkyPilot launch returned no handle") 169 | 170 | # Collect server addresses 171 | addresses = [] 172 | for node_idx, (_, node_ip) in enumerate( 173 | 
cluster_handle.stable_internal_external_ips
174 |         ):
175 |             for port in all_ports:
176 |                 addresses.append(f"{node_ip}:{port}")
177 |             logger.debug(f"Node {node_idx}: {node_ip} with {len(all_ports)} servers")
178 | 
179 |         logger.info(
180 |             f"Successfully launched cluster '{self._cluster_name}' "
181 |             f"with {len(addresses)} servers across {self._num_nodes} node(s)"
182 |         )
183 |         return addresses
184 | 
185 |     async def teardown(self) -> None:
186 |         """
187 |         Tear down SkyPilot cluster and clean up resources.
188 |         """
189 |         if not self._cluster_provisioned:
190 |             logger.warning("teardown() called but no cluster is active.")
191 |             return
192 | 
193 |         logger.info(f"Tearing down SkyPilot cluster '{self._cluster_name}'...")
194 |         try:
195 |             sky.down(cluster_name=self._cluster_name)
196 |             logger.info(f"Cluster '{self._cluster_name}' torn down successfully")
197 |         except Exception as e:
198 |             logger.error(f"Error tearing down cluster '{self._cluster_name}': {e}")
199 |         finally:
200 |             self._cluster_provisioned = False
201 |             if self._sync_workdir:
202 |                 cleanup_dir(self._sync_workdir)
203 |                 self._sync_workdir = None
204 |                 logger.debug("Cleaned up sync directory")
205 | 
206 |     @property
207 |     def cluster_name(self) -> str:
208 |         """
209 |         Unique name of the SkyPilot cluster.
210 |         """
211 |         return self._cluster_name
212 | 
--------------------------------------------------------------------------------
/src/benchmax/envs/excel/workdir/excel_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import shutil
4 | import subprocess
5 | import platform
6 | import tempfile
7 | import datetime
8 | from typing import Tuple
9 | 
10 | WHITE_LIKE_COLORS = [
11 |     "00000000",
12 |     "FFFFFFFF",
13 |     "FFFFFF00",
14 | ]
15 | 
16 | 
17 | def evaluate_excel(excel_path: Path | str):
18 |     """
19 |     Re-evaluate all formulas in an Excel file in place (xlwings on Windows/macOS, LibreOffice on Linux).
20 |     """
21 |     if platform.system() == "Linux":
22 |         # Use LibreOffice for Linux
23 |         evaluate_excel_libre(excel_path)
24 |         return
25 |     else:
26 |         # Use xlwings for Windows and MacOS (assuming Excel is installed)
27 |         import xlwings  # type: ignore
28 | 
29 |         excel_app = xlwings.App(visible=False)
30 |         excel_book = excel_app.books.open(excel_path)
31 |         excel_book.save()
32 |         excel_book.close()
33 |         excel_app.quit()
34 | 
35 | 
36 | def evaluate_excel_libre(excel_path: Path | str) -> None:
37 |     """
38 |     Force-recalculate in place under Linux using LibreOffice.
39 |     Raises subprocess.CalledProcessError if soffice exits abnormally.
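    Runs soffice with a temporary HOME directory so each invocation gets a fresh LibreOffice profile.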
40 | """ 41 | tmp_outdir = tempfile.mkdtemp(prefix="lo_convert_") 42 | cmd = [ 43 | "soffice", 44 | "--headless", 45 | "--nologo", 46 | "--nofirststartwizard", 47 | "--norestore", 48 | "--calc", 49 | "--convert-to", 50 | "xlsx", 51 | "--outdir", 52 | tmp_outdir, 53 | os.path.abspath(excel_path), 54 | ] 55 | lo_home = Path(tempfile.mkdtemp(prefix="lo_profile_")) 56 | env = dict(os.environ, HOME=str(lo_home)) 57 | try: 58 | subprocess.run( 59 | cmd, 60 | check=True, 61 | stdout=subprocess.DEVNULL, 62 | stderr=subprocess.STDOUT, 63 | text=True, 64 | env=env, 65 | ) 66 | # Determine the converted file name (same base name, .xlsx extension) 67 | base_name = os.path.splitext(os.path.basename(excel_path))[0] + ".xlsx" 68 | converted_path = os.path.join(tmp_outdir, base_name) 69 | # Overwrite the original file with the converted one 70 | shutil.move(converted_path, excel_path) 71 | finally: 72 | # Clean up the temp folder 73 | shutil.rmtree(tmp_outdir, ignore_errors=True) 74 | pass 75 | 76 | 77 | def excel_to_str_repr(excel_path: Path | str, evaluate_formulas=False) -> str: 78 | from openpyxl import load_workbook 79 | 80 | # Load workbook twice: data_only=True to get the evaluated values, 81 | # and data_only=False to get the formulas and styles. 82 | if evaluate_formulas: 83 | evaluate_excel(excel_path) 84 | 85 | wb_evaluated = load_workbook(excel_path, data_only=True) 86 | wb_raw = load_workbook(excel_path, data_only=False) 87 | 88 | result = [] 89 | 90 | for sheet_name in wb_evaluated.sheetnames: 91 | sheet_evaluated = wb_evaluated[sheet_name] 92 | sheet_raw = wb_raw[sheet_name] 93 | 94 | sheet_result = f"Sheet Name: {sheet_name}" 95 | result.append(sheet_result) 96 | 97 | for row_evaluated, row_raw in zip( 98 | sheet_evaluated.iter_rows(), sheet_raw.iter_rows() 99 | ): 100 | is_row_empty = True 101 | 102 | for cell_evaluated, cell_raw in zip(row_evaluated, row_raw): 103 | is_default_background = True 104 | style = [] 105 | 106 | if ( 107 | cell_raw.fill.start_color.index != "00000000" 108 | and type(cell_raw.fill.start_color.rgb) is str 109 | and cell_raw.fill.start_color.rgb not in WHITE_LIKE_COLORS 110 | ): 111 | is_default_background = False 112 | style.append(f"bg:{cell_raw.fill.start_color.rgb}") 113 | if ( 114 | cell_raw.font.color 115 | and cell_raw.font.color.index != 1 116 | and type(cell_raw.font.color.rgb) is str 117 | ): 118 | style.append(f"color:{cell_raw.font.color.rgb}") 119 | if cell_raw.font.bold: 120 | style.append("bold") 121 | if cell_raw.font.italic: 122 | style.append("italic") 123 | if cell_raw.font.underline: 124 | style.append("underline") 125 | 126 | display_value = cell_evaluated.value 127 | if cell_raw.data_type == "f": 128 | cell_raw_val = cell_raw.value 129 | if type(cell_raw_val) is not str: 130 | cell_raw_val = cell_raw.value.text # type: ignore 131 | display_value = f"{cell_raw_val} -> {cell_evaluated.value}" 132 | 133 | coords = cell_evaluated.coordinate 134 | 135 | if display_value is None and not is_default_background: 136 | # If cell is empty but has background color, still include it 137 | result.append(f"{coords}: null [{', '.join(style)}]") 138 | is_row_empty = False 139 | elif display_value: 140 | style_str = f" [{', '.join(style)}]" if style else "" 141 | result.append(f"{coords}: {display_value}{style_str}") 142 | is_row_empty = False 143 | if not is_row_empty: 144 | result.append("") # Newline after each row 145 | 146 | return "\n".join(result) 147 | 148 | 149 | def transform_value(v): 150 | if isinstance(v, (int, float)): 151 | v = round(float(v), 
2) 152 | elif isinstance(v, datetime.time): 153 | v = str(v)[:-3] 154 | elif isinstance(v, datetime.datetime): 155 | v = round( 156 | (v - datetime.datetime(1899, 12, 30)).days 157 | + (v - datetime.datetime(1899, 12, 30)).seconds / 86400.0, 158 | 0, 159 | ) 160 | elif isinstance(v, str): 161 | try: 162 | v = round(float(v), 2) 163 | except ValueError: 164 | pass 165 | return v 166 | 167 | 168 | def compare_fill_color(fill1, fill2): 169 | fgColor1 = fill1.fgColor.rgb if fill1.fgColor else None 170 | fgColor2 = fill2.fgColor.rgb if fill2.fgColor else None 171 | bgColor1 = fill1.bgColor.rgb if fill1.bgColor else None 172 | bgColor2 = fill2.bgColor.rgb if fill2.bgColor else None 173 | return fgColor1 == fgColor2 and bgColor1 == bgColor2 174 | 175 | 176 | def compare_font_color(font1, font2): 177 | # UNSURE if this is actually correct. 178 | if font1.color and font2.color: 179 | return font1.color.rgb == font2.color.rgb 180 | return font1.color is None and font2.color is None 181 | 182 | 183 | def col_name2num(name): 184 | """Convert an Excel column name to a column number""" 185 | num = 0 186 | for c in name: 187 | num = num * 26 + (ord(c.upper()) - ord("A") + 1) 188 | return num 189 | 190 | 191 | def parse_cell_range(range_str): 192 | start_cell, end_cell = range_str.split(":") 193 | start_col, start_row = "", "" 194 | for char in start_cell: 195 | if char.isdigit(): 196 | start_row += char 197 | else: 198 | start_col += char 199 | end_col, end_row = "", "" 200 | for char in end_cell: 201 | if char.isdigit(): 202 | end_row += char 203 | else: 204 | end_col += char 205 | return (col_name2num(start_col), int(start_row)), ( 206 | col_name2num(end_col), 207 | int(end_row), 208 | ) 209 | 210 | 211 | def generate_cell_names(range_str): 212 | from openpyxl.utils import get_column_letter 213 | 214 | if ":" not in range_str: 215 | return [range_str] 216 | (start_col, start_row), (end_col, end_row) = parse_cell_range(range_str) 217 | columns = [get_column_letter(i) for i in range(start_col, end_col + 1)] 218 | return [f"{col}{row}" for col in columns for row in range(start_row, end_row + 1)] 219 | 220 | 221 | def compare_excel_cells( 222 | ground_truth_path: str, output_path: str, answer_position: str, is_CF: bool = False 223 | ) -> Tuple[bool, str]: 224 | from openpyxl import load_workbook 225 | 226 | wb_gt = load_workbook(ground_truth_path, data_only=True) 227 | wb_out = load_workbook(output_path, data_only=True) 228 | sheet_ranges = answer_position.split(",") 229 | for sheet_range in sheet_ranges: 230 | if "!" in sheet_range: 231 | sheet_name, cell_range = sheet_range.split("!") 232 | sheet_name = sheet_name.strip("'") 233 | else: 234 | sheet_name = wb_gt.sheetnames[0] 235 | cell_range = sheet_range 236 | if sheet_name not in wb_out.sheetnames: 237 | return False, f"Worksheet '{sheet_name}' not found in output workbook." 
238 |         ws_gt = wb_gt[sheet_name]
239 |         ws_out = wb_out[sheet_name]
240 |         cell_names = generate_cell_names(cell_range)
241 |         for cell_name in cell_names:
242 |             cell_gt = ws_gt[cell_name]
243 |             cell_out = ws_out[cell_name]
244 |             if transform_value(cell_gt.value) != transform_value(cell_out.value):
245 |                 return (
246 |                     False,
247 |                     f"Value mismatch at {cell_name}: expected {cell_gt.value}, got {cell_out.value}",
248 |                 )
249 |             if is_CF:
250 |                 if not compare_fill_color(cell_gt.fill, cell_out.fill):
251 |                     return False, f"Fill color mismatch at {cell_name}"
252 |                 if not compare_font_color(cell_gt.font, cell_out.font):
253 |                     return False, f"Font color mismatch at {cell_name}"
254 |     return True, "All comparisons passed."
255 | 
--------------------------------------------------------------------------------
/src/benchmax/envs/excel/excel_env.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from pathlib import Path
4 | from typing import Any, Dict, Tuple, Optional
5 | 
6 | from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
7 | import sky
8 | 
9 | from benchmax.envs.mcp.parallel_mcp_env import ParallelMcpEnv
10 | from benchmax.envs.mcp.provisioners.base_provisioner import BaseProvisioner
11 | from benchmax.envs.mcp.provisioners.local_provisioner import LocalProvisioner
12 | from benchmax.envs.mcp.provisioners.skypilot_provisioner import SkypilotProvisioner
13 | from benchmax.envs.types import StandardizedExample
14 | from .data_utils import download_and_extract
15 | 
16 | # Using library shared with mcp workdir
17 | from .workdir.excel_utils import excel_to_str_repr
18 | 
19 | SYSTEM_PROMPT = """You are a spreadsheet expert who can manipulate spreadsheets through Python code.
20 | 
21 | You need to solve the given spreadsheet manipulation question, which contains six types of information:
22 | - instruction: The question about spreadsheet manipulation.
23 | - spreadsheet_path: The path of the spreadsheet file you need to manipulate.
24 | - spreadsheet_content: The content of the spreadsheet file.
25 | - instruction_type: There are two values (Cell-Level Manipulation, Sheet-Level Manipulation) used to indicate whether the answer to this question applies only to specific cells or to the entire worksheet.
26 | - answer_position: The position that needs to be modified or filled. For Cell-Level Manipulation questions, this field is filled with the cell position; for Sheet-Level Manipulation, it is the maximum range of cells you need to modify. You only need to modify or fill in values within the cell range specified by answer_position.
27 | - output_path: You need to generate the modified spreadsheet file in this new path.
28 | """ 29 | 30 | DEFAULT_DATA_OUTPUT_PATH = os.path.expanduser("~/.cache/excel_data") 31 | SPREADSHEET_FULL = "all_data_912_v0.1" 32 | SPREADSHEET_SAMPLE = "sample_data_200" 33 | 34 | # Set train data to full for proper training 35 | SPREADSHEET_BENCH_TRAIN_DATA = SPREADSHEET_SAMPLE 36 | 37 | 38 | class ExcelExample(StandardizedExample): 39 | id: str 40 | answer_position: str 41 | output_filename: str 42 | ground_truth_filename: str 43 | spreadsheet_base_dir: str 44 | 45 | 46 | class ExcelEnv(ParallelMcpEnv): 47 | """Environment for spreadsheet manipulation tasks using MCP with Excel support""" 48 | 49 | system_prompt: str = SYSTEM_PROMPT 50 | 51 | def __init__( 52 | self, 53 | workdir_path: Path, 54 | provisioner: BaseProvisioner, 55 | **kwargs, 56 | ): 57 | """Initialize the ExcelEnv with an optional dataset path.""" 58 | super().__init__(workdir_path=workdir_path, provisioner=provisioner, **kwargs) 59 | 60 | @classmethod 61 | def load_dataset( 62 | cls, 63 | dataset_name: str = "spreadsheetbench", 64 | data_output_path: str = DEFAULT_DATA_OUTPUT_PATH, 65 | **kwargs, 66 | ) -> Tuple[ 67 | DatasetDict | Dataset | IterableDatasetDict | IterableDataset, str | None 68 | ]: 69 | # Currently only support spreadsheetbench dataset but can be extended to other datasets in the future 70 | if dataset_name == "spreadsheetbench": 71 | folder_path = Path(data_output_path) / SPREADSHEET_BENCH_TRAIN_DATA 72 | json_path = folder_path / "dataset.json" 73 | if not os.path.exists(json_path): 74 | download_and_extract( 75 | f"https://github.com/RUCKBReasoning/SpreadsheetBench/raw/refs/heads/main/data/{SPREADSHEET_BENCH_TRAIN_DATA}.tar.gz", 76 | data_output_path, 77 | ) 78 | with open(json_path, "r") as f: 79 | data = json.load(f) 80 | for example in data: 81 | example["id"] = str(example["id"]) # Ensure IDs are strings 82 | dataset = Dataset.from_list(data) 83 | return dataset, str(folder_path) 84 | return super().load_dataset(dataset_name, **kwargs) 85 | 86 | @classmethod 87 | def dataset_preprocess( 88 | cls, example: Any, dataset_path: Optional[str | Path] = None, **kwargs 89 | ) -> ExcelExample: 90 | # convert dataset json into ExcelExample (a subclass of StandardizedExample) 91 | example_id: Optional[str] = example.get("id") 92 | spreadsheet_path: Optional[str] = example.get("spreadsheet_path") 93 | instruction: Optional[str] = example.get("instruction") 94 | instruction_type: Optional[str] = example.get("instruction_type") 95 | answer_position: Optional[str] = example.get("answer_position") 96 | 97 | if ( 98 | not example_id 99 | or not spreadsheet_path 100 | or not instruction 101 | or not instruction_type 102 | or not answer_position 103 | ): 104 | raise ValueError( 105 | "Example must contain 'id', 'spreadsheet_path', 'instruction', 'instruction_type', and 'answer_position' fields" 106 | ) 107 | if not isinstance(spreadsheet_path, str): 108 | raise TypeError("spreadsheet_path must be a string") 109 | 110 | if dataset_path is None: 111 | dataset_path = Path(DEFAULT_DATA_OUTPUT_PATH) / SPREADSHEET_BENCH_TRAIN_DATA 112 | elif not isinstance(dataset_path, (str, Path)): 113 | raise TypeError("dataset_path must be a str or Path") 114 | 115 | spreadsheet_base_dir = Path(dataset_path) / spreadsheet_path 116 | 117 | if os.path.exists(spreadsheet_base_dir) is False: 118 | raise FileNotFoundError( 119 | f"Spreadsheet path {spreadsheet_base_dir} does not exist." 
120 | ) 121 | 122 | # File path in the workspace (input spreadsheet will be copied into the workspace at init_rollout) 123 | input_filename = f"1_{example_id}_input.xlsx" 124 | output_filename = f"1_{example_id}_output.xlsx" 125 | ground_truth_filename = f"1_{example_id}_answer.xlsx" 126 | 127 | input_src_path = spreadsheet_base_dir / input_filename 128 | input_spreadsheet_content = excel_to_str_repr(input_src_path, True) 129 | 130 | prompt = f""" 131 | Instruction: {instruction} 132 | Spreadsheet Path: {input_filename} 133 | Spreadsheet Content: {input_spreadsheet_content} 134 | Instruction Type: {instruction_type} 135 | Answer Position: {answer_position} 136 | Output Path: {output_filename}""" 137 | 138 | return ExcelExample( 139 | prompt=prompt.strip(), 140 | # Ground truth unused in ExcelEnv 141 | ground_truth=None, 142 | init_rollout_args={ 143 | "input_src_path": str(input_src_path), 144 | }, 145 | id=example_id, 146 | answer_position=answer_position, 147 | output_filename=output_filename, 148 | ground_truth_filename=ground_truth_filename, 149 | spreadsheet_base_dir=str(spreadsheet_base_dir), 150 | ) 151 | 152 | async def init_rollout(self, rollout_id: str, **rollout_args): 153 | input_src_path: Optional[str] = rollout_args.get("input_src_path") 154 | 155 | if not input_src_path: 156 | raise ValueError("rollout_args must contain 'input_src_path' field") 157 | 158 | await super().init_rollout(rollout_id, **rollout_args) 159 | await self.copy_to_workspace(rollout_id, Path(input_src_path)) 160 | 161 | async def compute_reward( 162 | self, rollout_id: str, completion: str, ground_truth: Any, **kwargs: Any 163 | ) -> Dict[str, float]: 164 | answer_position: Optional[str] = kwargs.get("answer_position") 165 | output_filename: Optional[str] = kwargs.get("output_filename") 166 | ground_truth_filename: Optional[str] = kwargs.get("ground_truth_filename") 167 | spreadsheet_base_dir: Optional[str] = kwargs.get("spreadsheet_base_dir") 168 | 169 | if ( 170 | not answer_position 171 | or not output_filename 172 | or not ground_truth_filename 173 | or not spreadsheet_base_dir 174 | ): 175 | raise ValueError( 176 | "kwargs must contain 'answer_position', 'output_filename', 'ground_truth_filename', and 'spreadsheet_base_dir' fields" 177 | ) 178 | 179 | # Copy ground truth file to workspace for reward computation 180 | await self.copy_to_workspace( 181 | rollout_id, Path(spreadsheet_base_dir) / ground_truth_filename 182 | ) 183 | return await super().compute_reward( 184 | rollout_id, 185 | completion, 186 | ground_truth, 187 | answer_position=answer_position, 188 | output_filename=output_filename, 189 | ground_truth_filename=ground_truth_filename, 190 | ) 191 | 192 | 193 | class ExcelEnvLocal(ExcelEnv): 194 | """Import this env to run environment locally""" 195 | 196 | def __init__(self, num_local_servers: int = 5, **kwargs): 197 | workdir_path = Path(__file__).parent / "workdir" 198 | provisioner = LocalProvisioner( 199 | workdir_path=workdir_path, num_servers=num_local_servers 200 | ) 201 | super().__init__(workdir_path=workdir_path, provisioner=provisioner, **kwargs) 202 | 203 | 204 | class ExcelEnvSkypilot(ExcelEnv): 205 | """Import this env to run environment on any cloud (i.e. 
AWS / GCP / Azure) with Skypilot""" 206 | 207 | def __init__( 208 | self, 209 | cloud: sky.clouds.Cloud = sky.Azure(), 210 | num_nodes: int = 2, 211 | servers_per_node: int = 5, 212 | **kwargs, 213 | ): 214 | workdir_path = Path(__file__).parent / "workdir" 215 | provisioner = SkypilotProvisioner( 216 | workdir_path=workdir_path, 217 | cloud=cloud, 218 | num_nodes=num_nodes, 219 | servers_per_node=servers_per_node, 220 | ) 221 | super().__init__(workdir_path=workdir_path, provisioner=provisioner, **kwargs) 222 | --------------------------------------------------------------------------------