├── tests
│   ├── __init__.py
│   ├── unit
│   │   ├── envs
│   │   │   ├── __init__.py
│   │   │   ├── crm
│   │   │   │   └── __init__.py
│   │   │   ├── math
│   │   │   │   └── __init__.py
│   │   │   ├── mcp
│   │   │   │   ├── __init__.py
│   │   │   │   ├── provisioners
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── test_manual.py
│   │   │   │   │   ├── test_local.py
│   │   │   │   │   └── test_provisioners_utils.py
│   │   │   │   └── test_mcp_utils.py
│   │   │   ├── wikipedia
│   │   │   │   ├── __init__.py
│   │   │   │   └── test_wiki_env.py
│   │   │   └── excel
│   │   │       ├── test_inputs
│   │   │       │   └── test.xlsx
│   │   │       ├── test_excel_code_runner_mcp.py
│   │   │       └── test_excel_env.py
│   │   └── prompts
│   │       └── test_tools.py
│   ├── integration
│   │   ├── __init__.py
│   │   ├── adapters
│   │   │   ├── __init__.py
│   │   │   └── skyrl
│   │   │       ├── __init__.py
│   │   │       └── skyrl_adapter_integration.py
│   │   └── envs
│   │       ├── __init__.py
│   │       ├── mcp
│   │       │   ├── __init__.py
│   │       │   ├── provisioners
│   │       │   │   ├── __init__.py
│   │       │   │   ├── utils.py
│   │       │   │   ├── test_local_integration.py
│   │       │   │   └── test_skypilot_integration.py
│   │       │   └── conftest.py
│   │       ├── math
│   │       │   └── __init__.py
│   │       ├── wikipedia
│   │       │   └── __init__.py
│   │       └── excel
│   │           ├── test_inputs
│   │           │   └── test.xlsx
│   │           └── test_excel_utils.py
│   └── conftest.py
├── src
│   └── benchmax
│       ├── adapters
│       │   ├── __init__.py
│       │   └── skyrl
│       │       └── benchmax_data_process.py
│       ├── envs
│       │   ├── __init__.py
│       │   ├── excel
│       │   │   ├── workdir
│       │   │   │   ├── __init__.py
│       │   │   │   ├── mcp_config.yaml
│       │   │   │   ├── setup.sh
│       │   │   │   ├── reward_fn.py
│       │   │   │   ├── excel_code_runner_mcp.py
│       │   │   │   └── excel_utils.py
│       │   │   ├── data_utils.py
│       │   │   ├── README.md
│       │   │   └── excel_env.py
│       │   ├── math
│       │   │   ├── workdir
│       │   │   │   ├── setup.sh
│       │   │   │   ├── mcp_config.yaml
│       │   │   │   └── reward_fn.py
│       │   │   ├── README.md
│       │   │   └── math_env.py
│       │   ├── crm
│       │   │   ├── workdir
│       │   │   │   ├── setup.sh
│       │   │   │   ├── mcp_config.yaml
│       │   │   │   └── reward_fn.py
│       │   │   ├── crm_env.py
│       │   │   └── README.md
│       │   ├── mcp
│       │   │   ├── example_workdir
│       │   │   │   ├── setup.sh
│       │   │   │   ├── mcp_config.yaml
│       │   │   │   ├── reward_fn.py
│       │   │   │   └── demo_mcp_server.py
│       │   │   ├── __init__.py
│       │   │   ├── provisioners
│       │   │   │   ├── __init__.py
│       │   │   │   ├── base_provisioner.py
│       │   │   │   ├── manual_provisioner.py
│       │   │   │   ├── utils.py
│       │   │   │   └── skypilot_provisioner.py
│       │   │   ├── utils.py
│       │   │   └── README.md
│       │   ├── types.py
│       │   ├── README.md
│       │   ├── wikipedia
│       │   │   ├── README.md
│       │   │   └── utils.py
│       │   ├── how-to-extend-base-env.md
│       │   └── base_env.py
│       └── prompts
│           ├── __init__.py
│           └── tools.py
├── .env.example
├── static
│   └── benchmax.png
├── pytest.ini
├── .gitignore
├── pyproject.toml
├── examples
│   └── skyrl
│       ├── run_benchmax_math.sh
│       ├── run_benchmax_excel.sh
│       ├── benchmax_math.py
│       ├── benchmax_excel.py
│       └── README.md
└── README.md
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/unit/envs/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/benchmax/adapters/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/benchmax/envs/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/benchmax/prompts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/integration/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/unit/envs/crm/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/unit/envs/math/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/unit/envs/mcp/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/integration/adapters/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/integration/envs/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/integration/envs/mcp/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/unit/envs/wikipedia/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/benchmax/envs/excel/workdir/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/integration/envs/math/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/integration/adapters/skyrl/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/integration/envs/wikipedia/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/unit/envs/mcp/provisioners/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/integration/envs/mcp/provisioners/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.env.example:
--------------------------------------------------------------------------------
1 | WIKIPEDIA_API_KEYS=DUMMY_API_KEY_1,DUMMY_API_KEY_2,DUMMY_API_KEY_3
--------------------------------------------------------------------------------
/static/benchmax.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cgftinc/benchmax/HEAD/static/benchmax.png
--------------------------------------------------------------------------------
/src/benchmax/envs/math/workdir/setup.sh:
--------------------------------------------------------------------------------
1 | # Install uv
2 | curl -LsSf https://astral.sh/uv/install.sh | sh
--------------------------------------------------------------------------------
/src/benchmax/envs/crm/workdir/setup.sh:
--------------------------------------------------------------------------------
1 | uv pip install "simple-salesforce>=1.12.3" python-dateutil==2.9.0.post0
--------------------------------------------------------------------------------
/tests/unit/envs/excel/test_inputs/test.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cgftinc/benchmax/HEAD/tests/unit/envs/excel/test_inputs/test.xlsx
--------------------------------------------------------------------------------
/src/benchmax/envs/math/workdir/mcp_config.yaml:
--------------------------------------------------------------------------------
1 | mcpServers:
2 | calculator:
3 | command: uvx
4 | args:
5 | - mcp-server-calculator
--------------------------------------------------------------------------------
/tests/integration/envs/excel/test_inputs/test.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cgftinc/benchmax/HEAD/tests/integration/envs/excel/test_inputs/test.xlsx
--------------------------------------------------------------------------------
/src/benchmax/envs/excel/workdir/mcp_config.yaml:
--------------------------------------------------------------------------------
1 | mcpServers:
2 | test_server:
3 | command: python
4 | args:
5 | - ${{ sync_workdir }}/excel_code_runner_mcp.py
--------------------------------------------------------------------------------
/src/benchmax/envs/mcp/example_workdir/setup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Setup script for test MCP server
3 |
4 | echo "Installing test MCP server dependencies..."
5 |
6 | echo "Test MCP server setup complete"
--------------------------------------------------------------------------------
/src/benchmax/envs/mcp/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | MCP-based environment infrastructure for parallel rollout execution.
3 | """
4 |
5 | from .parallel_mcp_env import ParallelMcpEnv
6 | from .server_pool import ServerPool, ServerInfo
7 |
8 | __all__ = [
9 | "ParallelMcpEnv",
10 | "ServerPool",
11 | "ServerInfo",
12 | ]
--------------------------------------------------------------------------------
/src/benchmax/envs/mcp/example_workdir/mcp_config.yaml:
--------------------------------------------------------------------------------
1 | # Test MCP configuration
2 | # This configuration is used for testing the MCP infrastructure
3 | # Use ${{ sync_workdir }} to refer to the location of the synchronized workdir in the remote machine
4 |
5 | mcpServers:
6 | test_server:
7 | command: python
8 | args:
9 | - ${{ sync_workdir }}/demo_mcp_server.py
--------------------------------------------------------------------------------
/src/benchmax/envs/types.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Any, Dict, Optional, TypedDict
3 |
4 |
5 | class StandardizedExample(TypedDict):
6 | prompt: str
7 | ground_truth: Any
8 | init_rollout_args: Optional[Dict[str, Any]]
9 |
10 |
11 | @dataclass
12 | class ToolDefinition:
13 | """Definition of a tool's interface"""
14 |
15 | name: str
16 | description: str
17 | input_schema: Optional[Dict[str, Any]] = None
18 |
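19 | 
20 | # Example of a StandardizedExample instance (illustrative values only):
21 | # {
22 | #     "prompt": "What is 2 + 2?",
23 | #     "ground_truth": "4",
24 | #     "init_rollout_args": None,
25 | # }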
--------------------------------------------------------------------------------
/src/benchmax/envs/mcp/provisioners/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Server provisioning strategies for ParallelMcpEnv.
3 | """
4 |
5 | from .base_provisioner import BaseProvisioner
6 | from .manual_provisioner import ManualProvisioner
7 | from .local_provisioner import LocalProvisioner
8 | from .skypilot_provisioner import SkypilotProvisioner
9 |
10 | __all__ = [
11 | "BaseProvisioner",
12 | "ManualProvisioner",
13 | "LocalProvisioner",
14 | "SkypilotProvisioner",
15 | ]
16 |
--------------------------------------------------------------------------------
/src/benchmax/envs/crm/workdir/mcp_config.yaml:
--------------------------------------------------------------------------------
1 | mcpServers:
2 | test_server:
3 | command: python
4 | args:
5 | - ${{ sync_workdir }}/salesforce_mcp.py
6 | env:
7 | SALESFORCE_USERNAME: crmarena_b2b@gmaill.com
8 | SALESFORCE_PASSWORD: crmarenatest
9 | SALESFORCE_SECURITY_TOKEN: zdaqqSYBEQTjjLuq0zLUHkC3
10 | # Uncomment the following lines to use B2C credentials
11 | # SALESFORCE_USERNAME: crmarena_b2c@gmaill.com
12 | # SALESFORCE_PASSWORD: crmarenatest
13 | # SALESFORCE_SECURITY_TOKEN: 2AQCtK8MnnV4lJdRNF0DGCs1
--------------------------------------------------------------------------------
/src/benchmax/envs/excel/workdir/setup.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -e
3 |
4 | OS="$(uname -s)"
5 |
6 | if [[ "$OS" == "Linux"* ]]; then
7 | echo "Detected Linux system. Installing LibreOffice (if not installed) and openpyxl..."
8 | if ! command -v libreoffice >/dev/null 2>&1; then
9 | sudo apt update -qq
10 | sudo apt install -y libreoffice >/dev/null
11 | fi
12 | uv pip install openpyxl
13 | elif [[ "$OS" == "Darwin"* || "$OS" == MINGW* || "$OS" == MSYS* || "$OS" == CYGWIN* ]]; then
14 | echo "Detected macOS/Windows system. Installing xlwings and openpyxl..."
15 | uv pip install openpyxl xlwings
16 | else
17 | echo "Unsupported OS: $OS" >&2
18 | exit 1
19 | fi
20 |
--------------------------------------------------------------------------------
/src/benchmax/envs/README.md:
--------------------------------------------------------------------------------
1 | # Envs
2 |
3 | This directory contains:
4 | ```bash
5 | ├── crm/ # Salesforce env (extends MCP)
6 | ├── excel/ # Excel env (extends MCP)
7 | ├── math/ # Math env (extends MCP)
8 | ├── mcp/ # MCP env class (extends BaseEnv)
9 | ├── wikipedia/ # Wikipedia env (extends BaseEnv)
10 | ├── types.py # Shared types
11 | └── base_env.py # Base env class
12 | ```
13 |
14 | Pre-built envs like CRM, Excel, and Math use MCP, which has built-in multi-node parallelization, and are ready to use out of the box. To learn how to create your own parallelized MCP env, check out [this guide](mcp/README.md).
15 |
16 | If you want to manually extend `BaseEnv` (no multi-node support), you can check out the Wikipedia env or [follow this guide](how-to-extend-base-env.md).
17 |
18 |
--------------------------------------------------------------------------------
/src/benchmax/envs/excel/data_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import requests
3 | import tarfile
4 |
5 |
6 | def download_and_extract(url, output_path):
7 | """
8 | Downloads a tar.gz file from the given URL and extracts it into output_path.
9 | """
10 | # Ensure the output directory exists
11 | os.makedirs(output_path, exist_ok=True)
12 |
13 | # Determine local file name
14 | local_filename = os.path.join(output_path, os.path.basename(url))
15 |
16 | # Download the file
17 | with requests.get(url, stream=True) as r:
18 | r.raise_for_status()
19 | with open(local_filename, "wb") as f:
20 | for chunk in r.iter_content(chunk_size=8192):
21 | f.write(chunk)
22 |
23 | # Extract the tar.gz
24 | with tarfile.open(local_filename, "r:gz") as tar:
25 | tar.extractall(path=output_path)
26 |
27 | print(f"Downloaded and extracted to {output_path}")
28 |
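29 | 
30 | # Example usage (illustrative; the URL is a placeholder):
31 | # download_and_extract("https://example.com/data.tar.gz", "./data")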
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | # Pytest configuration for benchmax
3 | pythonpath = .
4 |
5 | testpaths = tests
6 |
7 | # Test discovery
8 | python_files = test_*.py
9 | python_classes = Test*
10 | python_functions = test_*
11 |
12 | # Asyncio mode
13 | asyncio_mode = auto
14 |
15 | # Markers
16 | markers =
17 | slow: Slow-running tests (deselect with '-m "not slow"')
18 | remote: marks tests that require remote resources (deselect with '-m "not remote"')
19 | excel: marks tests that require opening the Excel app
20 | unit: Fast, isolated tests
21 |
22 | # Default deselection to speed up CI runs
23 | addopts = -m "not slow and not remote and not excel"
24 |
25 | # Logging
26 | log_cli = false
27 | log_cli_level = INFO
28 | log_cli_format = %(asctime)s [%(levelname)8s] %(message)s
29 | log_cli_date_format = %Y-%m-%d %H:%M:%S
30 |
31 | # Warnings
32 | filterwarnings =
33 | ignore::DeprecationWarning
34 | ignore::PendingDeprecationWarning
35 |
36 | minversion = 7.0
37 |
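38 | # To run one of the deselected groups, pass -m on the command line
39 | # (a -m given on the CLI overrides the one in addopts), e.g.:
40 | #   pytest -m "remote"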
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.so
6 | .Python
7 | build/
8 | develop-eggs/
9 | dist/
10 | downloads/
11 | eggs/
12 | .eggs/
13 | lib/
14 | lib64/
15 | parts/
16 | sdist/
17 | var/
18 | wheels/
19 | *.egg-info/
20 | .installed.cfg
21 | *.egg
22 |
23 | # Virtual Environments
24 | .env
25 | .venv
26 | .python-version
27 | env/
28 | venv/
29 | ENV/
30 |
31 | # IDE
32 | .idea/
33 | .vscode/
34 | *.swp
35 | *.swo
36 |
37 | # Distribution / packaging
38 | .Python
39 | build/
40 | develop-eggs/
41 | dist/
42 | downloads/
43 | eggs/
44 | .eggs/
45 | lib/
46 | lib64/
47 | parts/
48 | sdist/
49 | var/
50 | wheels/
51 | *.egg-info/
52 | .installed.cfg
53 | *.egg
54 |
55 | # Unit test / coverage reports
56 | htmlcov/
57 | .tox/
58 | .coverage
59 | .coverage.*
60 | .cache
61 | nosetests.xml
62 | coverage.xml
63 | *.cover
64 | .hypothesis/
65 | .pytest_cache/
66 |
67 | # mypy
68 | .mypy_cache/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # macOS
74 | .DS_Store
75 |
76 | # Poetry
77 | poetry.lock
78 |
79 | # Generated workspaces
80 | workspaces/
81 | outputs/
--------------------------------------------------------------------------------
/src/benchmax/envs/math/workdir/reward_fn.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import re
3 | from typing import Any, Callable, Dict, Union, Awaitable
4 | from fastmcp import Client
5 | from html import unescape
6 |
7 | RewardFunction = Callable[..., Union[float, Awaitable[float]]]
8 |
9 |
10 | async def text_match_reward(
11 | completion: str,
12 | ground_truth: str,
13 | mcp_client: Client,
14 | workspace: Path,
15 | **kwargs: Any,
16 | ) -> float:
17 | """
18 | Reward = 1 if `ground_truth` (case-insensitive) appears anywhere *inside*
19 | the first <answer> … </answer> block of `completion`; otherwise 0.
20 |
21 | Falls back to 0 if the tag is missing or empty.
22 | """
23 |
24 | # Grab only the text inside the first <answer> … </answer> pair (case-insensitive).
25 | m = re.search(
26 | r"(.*?)", completion, flags=re.IGNORECASE | re.DOTALL
27 | )
28 | if m is None:
29 | return 0.0
30 |
31 | # Unescape any XML entities (&amp; → &, etc.) and normalise whitespace.
32 | answer_text = unescape(m.group(1)).strip().lower()
33 |
34 | try:
35 | # Try to interpret both as floats for numerical comparison.
36 | return float(float(ground_truth.lower()) == float(answer_text))
37 | except ValueError:
38 | return 0.0
39 |
40 |
41 | # -------------------------------
42 | # Export reward functions
43 | # -------------------------------
44 | reward_functions: Dict[str, RewardFunction] = {"match": text_match_reward}
45 |
--------------------------------------------------------------------------------
/src/benchmax/envs/mcp/provisioners/base_provisioner.py:
--------------------------------------------------------------------------------
1 | """
2 | Base provisioner interface for server provisioning strategies.
3 | """
4 |
5 | from abc import ABC, abstractmethod
6 | from typing import List
7 |
8 |
9 | class BaseProvisioner(ABC):
10 | """
11 | Abstract base class for server provisioning strategies.
12 |
13 | A provisioner is responsible for:
14 | 1. Starting/launching servers (returning their addresses)
15 | 2. Cleaning up resources when done
16 | """
17 |
18 | @property
19 | @abstractmethod
20 | def num_servers(self) -> int:
21 | """
22 | Total number of servers
23 |
24 | This reports the number of servers that are / will be provisioned.
25 | """
26 | pass
27 |
28 | @abstractmethod
29 | async def provision_servers(self, api_secret: str) -> List[str]:
30 | """
31 | Provision servers and return their addresses.
32 |
33 | Args:
34 | api_secret: Secret for server authentication.
35 |
36 | Returns:
37 | List of server addresses in "host:port" format.
38 | Example: ["localhost:8080", "192.168.1.10:8080"]
39 | """
40 | pass
41 |
42 | @abstractmethod
43 | async def teardown(self) -> None:
44 | """
45 | Tear down provisioned resources.
46 |
47 | This should clean up any resources created during provisioning,
48 | such as stopping processes, terminating cloud instances, etc.
49 | """
50 | pass
51 |
--------------------------------------------------------------------------------
/tests/integration/envs/mcp/provisioners/utils.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import Optional
3 | import aiohttp
4 |
5 |
6 |
7 | async def check_health(address: str) -> bool:
8 | """
9 | Check if a server's /health endpoint is responding.
10 |
11 | Args:
12 | address: Server address in format "host:port"
13 |
14 | Returns:
15 | True if server is healthy, False otherwise
16 | """
17 | host, port = address.split(":")
18 | url = f"http://{host}:{port}/health"
19 |
20 | timeout_obj = aiohttp.ClientTimeout(total=2.0)
21 |
22 | try:
23 | async with aiohttp.ClientSession(timeout=timeout_obj) as session:
24 | async with session.get(url) as response:
25 | return response.status == 200
26 | except (aiohttp.ClientError, asyncio.TimeoutError):
27 | return False
28 |
29 |
30 | async def wait_for_server_health(address: str, timeout: float = 60.0) -> bool:
31 | """
32 | Wait for a server to be healthy by polling its /health endpoint.
33 |
34 | Args:
35 | address: Server address in format "host:port"
36 | timeout: Maximum time to wait in seconds
37 |
38 | Returns:
39 | True if server becomes healthy, False if timeout
40 | """
41 | start_time = asyncio.get_event_loop().time()
42 |
43 | while (asyncio.get_event_loop().time() - start_time) < timeout:
44 | if await check_health(address):
45 | return True
46 | await asyncio.sleep(1.0)
47 |
48 | return False
49 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | """
2 | Shared fixtures
3 | """
4 |
5 | import os
6 | import tempfile
7 | import uuid
8 | import pytest
9 | from pathlib import Path
10 | import importlib.util
11 |
12 |
13 | @pytest.fixture
14 | def unique_rollout_id() -> str:
15 | """Generate a unique rollout ID for testing."""
16 | return f"test-rollout-{uuid.uuid4().hex[:8]}"
17 |
18 |
19 | @pytest.fixture
20 | def test_sync_dir(tmp_path: Path) -> Path:
21 | """Temporary directory for mocking syncdir (unit tests only)."""
22 | sync_dir = tmp_path / "sync"
23 | os.mkdir(sync_dir)
24 | return sync_dir
25 |
26 |
27 | @pytest.fixture(scope="session")
28 | def session_tmp_path() -> Path:
29 | """Temporary directory for test session."""
30 | return Path(tempfile.mkdtemp(prefix="benchmax_test_session_"))
31 |
32 |
33 | @pytest.fixture(scope="session")
34 | def example_workdir() -> Path:
35 | """Path to example MCP workdir inside benchmax.envs.mcp."""
36 | # Locate the mcp package dynamically
37 | spec = importlib.util.find_spec("benchmax.envs.mcp")
38 | if not spec or not spec.submodule_search_locations:
39 | raise RuntimeError("Could not locate benchmax.envs.mcp package")
40 |
41 | # The directory containing __init__.py
42 | mcp_pkg_dir = Path(spec.submodule_search_locations[0])
43 |
44 | # Workdir is relative to that
45 | workdir = mcp_pkg_dir / "example_workdir"
46 |
47 | if not workdir.exists():
48 | raise FileNotFoundError(f"Expected example_workdir not found at: {workdir}")
49 |
50 | return workdir
51 |
--------------------------------------------------------------------------------
/src/benchmax/envs/math/README.md:
--------------------------------------------------------------------------------
1 | # Math Environment
2 |
3 | This environment provides capabilities for solving mathematical problems through a local calculator MCP server.
4 |
5 | ## Prerequisites
6 |
7 | Before using this environment, ensure you have:
8 | - Python 3.12 or later installed
9 | - The `mcp-server-calculator` package installed
10 |
11 | ## Installation
12 |
13 | ```bash
14 | pip install "benchmax[skypilot]"
15 | ```
16 |
17 | Includes:
18 | - fastmcp: For MCP server functionality that enables calculator operations
19 |
20 | ## Usage
21 |
22 | Use `MathEnvLocal` to run the servers locally on the same machine as benchmax or use `MathEnvSkypilot` to parallelize the servers across multiple nodes.
23 |
24 | ## Available Tools
25 |
26 | The environment provides a calculator MCP tool through the server configuration:
27 |
28 | ### Calculator Tool
29 | The calculator tool is provided through a local MCP server that:
30 | - Handles mathematical computations
31 | - Takes mathematical expressions as input
32 | - Returns computed results
33 | - Supports standard mathematical operations
34 |
35 | ## Reward Function
36 | The reward function lives in workdir/reward_fn.py so that it can easily be evaluated on the remote node.
37 |
38 | The evaluator awards 1.0 when the ground-truth string (after case-insensitive comparison, whitespace normalization, and XML-entity unescaping) appears inside the first `<answer>...</answer>` block of the completion; otherwise the reward is 0.0. If that tag pair is missing or empty, the reward defaults to 0.0. This binary scheme incentivizes placing an exact, normalized final answer within the required XML tags.
39 |
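40 | ## Example
41 | 
42 | A minimal sketch of instantiating the local variant (the parameter value is illustrative; see `math_env.py` for the full constructor signatures):
43 | 
44 | ```python
45 | from benchmax.envs.math.math_env import MathEnvLocal
46 | 
47 | # Spin up a few local calculator MCP servers for parallel rollouts
48 | env = MathEnvLocal(num_local_servers=5)
49 | ```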
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "benchmax"
3 | version = "0.1.2.dev5"
4 | description = "Framework-Agnostic RL Environments for LLM Fine-Tuning"
5 | readme = "README.md"
6 | authors = [{ name = "cgft.io" }]
7 | requires-python = "==3.12.*"
8 | dependencies = [
9 | "aiohttp>=3.13.1",
10 | "asyncio>=4.0.0",
11 | "datasets>=4.0.0",
12 | "fastmcp~=2.12.0",
13 | "pyjwt>=2.10.1",
14 | "skypilot~=0.8.1",
15 | ]
16 | classifiers = [
17 | "Programming Language :: Python :: 3",
18 | "Operating System :: OS Independent",
19 | ]
20 |
21 | [build-system]
22 | requires = ["setuptools>=61.0", "wheel"]
23 | build-backend = "setuptools.build_meta"
24 |
25 | [tool.setuptools.packages.find]
26 | where = ["src"]
27 |
28 | [dependency-groups]
29 | dev = [
30 | "pytest>=8.4.2",
31 | "pytest-asyncio>=1.2.0",
32 | "python-dotenv>=1.2.1",
33 | "ruff>=0.14.2",
34 | ]
35 | skypilot = [
36 | "skypilot[aws,gcp,azure]~=0.8.1", # Change this to your cloud provider
37 | "pip>=25.3", # Added as needed for skypilot launch
38 | "msrestazure>=0.6.4.post1",
39 | ]
40 | skyrl = [
41 | "grpcio>=1.60.0",
42 | "hydra-core>=1.3.2",
43 | "omegaconf>=2.3.0",
44 | "ray>=2.48.0",
45 | "skyrl-gym>=0.1.1",
46 | "skyrl-train[vllm]>=0.2.0",
47 | ]
48 | excel = ["openpyxl>=3.1.5"]
49 | excel-mac-windows = ["openpyxl>=3.1.5", "xlwings>=0.33.16"]
50 | crm = ["python-dateutil>=2.9.0.post0", "simple-salesforce>=1.12.9"]
51 |
52 | [tool.uv]
53 | conflicts = [[{ group = "skypilot" }, { group = "skyrl" }]]
54 |
55 | [tool.uv.pip]
56 | extra = ["dev", "skypilot", "skyrl", "excel", "excel-mac-windows", "crm"]
57 |
58 | [tool.uv.extra-build-dependencies]
59 | flash-attn = [{ requirement = "torch", match-runtime = true }]
60 |
61 | [tool.uv.extra-build-variables]
62 | flash-attn = { FLASH_ATTENTION_SKIP_CUDA_BUILD = "TRUE" }
63 |
--------------------------------------------------------------------------------
/src/benchmax/envs/wikipedia/README.md:
--------------------------------------------------------------------------------
1 | # Wikipedia Environment
2 |
3 | This environment provides capabilities for interacting with Wikipedia through its API, enabling search and article retrieval functionality.
4 |
5 | ## Prerequisites
6 |
7 | No additional software installation is required. However, we suggest instantiating the `WikipediaEnv` class with API keys for higher rate limits.
8 |
9 | ## Installation
10 |
11 | ```bash
12 | pip install "benchmax"
13 | ```
14 |
15 | ## Usage
16 |
17 | Use `WikipediaEnv` to run the environment locally.
18 |
19 | ## Available Tools
20 |
21 | The environment provides two tools for Wikipedia interaction:
22 |
23 | ### search_wikipedia
24 | Searches Wikipedia articles by keyword:
25 | - Takes a search query and optional result limit
26 | - Returns a list of relevant articles with titles and snippets
27 | - Handles proper escaping and HTML cleanup
28 |
29 | ### get_wikipedia_article
30 | Fetches the full plaintext of a Wikipedia article:
31 | - Takes an exact article title as input
32 | - Returns the complete article text (up to specified character limit)
33 | - Handles redirects automatically
34 | - Returns plain text with HTML markup removed
35 |
36 | ## Reward Function
37 | The task scores 1.0 only if the ground-truth string, after XML-entity unescaping and whitespace normalization, exactly matches the text inside the first `<answer>...</answer>` block (case-insensitive); otherwise it returns 0.0. If that block is missing or empty, the reward defaults to 0.0. This binary scheme forces the model to place a single, exact final answer inside the first answer tag while allowing any additional explanation outside it.
38 |
39 | ## Features
40 |
41 | - API key rotation support to handle rate limits
42 | - HTML cleanup and entity unescaping
43 | - Configurable result limits
44 | - Error handling for API failures
45 | - Support for article redirects
46 | - Plain text extraction
47 |
48 | Dataset: https://huggingface.co/datasets/chiayewken/bamboogle
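49 | 
50 | ## Example
51 | 
52 | A minimal sketch of the key-rotation helper defined in `utils.py` (the keys shown are placeholders):
53 | 
54 | ```python
55 | from benchmax.envs.wikipedia.utils import APIKeyRotator
56 | 
57 | rotator = APIKeyRotator(["DUMMY_API_KEY_1", "DUMMY_API_KEY_2"])
58 | key = rotator.next()  # round-robin; returns None when no keys are configured
59 | ```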
--------------------------------------------------------------------------------
/src/benchmax/envs/excel/workdir/reward_fn.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Any, Awaitable, Callable, Dict, Optional, Union
3 |
4 | from fastmcp import Client
5 |
6 | try:
7 | from excel_utils import compare_excel_cells
8 | except ImportError:
9 | # Fall back to a package-relative import (used when running unit tests)
10 | from .excel_utils import compare_excel_cells
11 |
12 | RewardFunction = Callable[..., Union[float, Awaitable[float]]]
13 |
14 |
15 | def spreadsheet_comparison_reward(
16 | completion: str,
17 | ground_truth: dict,
18 | mcp_client: Client,
19 | workspace: Path,
20 | **kwargs: Any,
21 | ) -> float:
22 | """
23 | Compares the output spreadsheet to the ground truth using cell values in the specified range.
24 | Returns 1.0 if all values match, else 0.0.
25 | """
26 | answer_position: Optional[str] = kwargs.get("answer_position")
27 | output_filename: Optional[str] = kwargs.get("output_filename")
28 | ground_truth_filename: Optional[str] = kwargs.get("ground_truth_filename")
29 |
30 | if not answer_position or not output_filename or not ground_truth_filename:
31 | raise ValueError(
32 | "kwargs must contain 'answer_position', 'output_filename', and 'ground_truth_filename' fields"
33 | )
34 |
35 | output_path = workspace / output_filename
36 | ground_truth_path = workspace / ground_truth_filename
37 |
38 | # Return 1.0 score if the output completely matches the ground truth
39 | try:
40 | match, _ = compare_excel_cells(
41 | str(ground_truth_path), str(output_path), answer_position
42 | )
43 | return 1.0 if match else 0.0
44 | except Exception as e:
45 | print(
46 | f"Error comparing spreadsheets {ground_truth_path} and {output_path}: {e}"
47 | )
48 | return 0.0
49 |
50 |
51 | # -------------------------------
52 | # Export reward functions
53 | # -------------------------------
54 | reward_functions: Dict[str, RewardFunction] = {
55 | "spreadsheet": spreadsheet_comparison_reward,
56 | }
57 |
--------------------------------------------------------------------------------
/examples/skyrl/run_benchmax_math.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | # Colocated GRPO training+generation for Qwen2.5-3B-Instruct on a sample benchmax math environment.
4 | # uv run benchmax/adapters/skyrl/benchmax_data_process.py --local_dir ~/data/math --dataset_name dawidmt/arithmetic50 --env_path benchmax.envs.math.math_env.MathEnvLocal
5 | # bash examples/skyrl/run_benchmax_math.sh
6 |
7 | DATA_DIR="$HOME/data/math"
8 | NUM_GPUS=2
9 | ENV_CLASS="MathEnv"
10 |
11 | uv run --isolated --group skyrl -m examples.skyrl.benchmax_math \
12 | data.train_data="['$DATA_DIR/train.parquet']" \
13 | data.val_data="['$DATA_DIR/test.parquet']" \
14 | trainer.algorithm.advantage_estimator="grpo" \
15 | trainer.policy.model.path="Qwen/Qwen2.5-3B-Instruct" \
16 | trainer.placement.colocate_all=true \
17 | trainer.strategy=fsdp2 \
18 | trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
19 | trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
20 | generator.num_inference_engines=$NUM_GPUS \
21 | generator.inference_engine_tensor_parallel_size=1 \
22 | trainer.epochs=20 \
23 | trainer.update_epochs_per_batch=1 \
24 | trainer.train_batch_size=32 \
25 | trainer.policy_mini_batch_size=32 \
26 | trainer.critic_mini_batch_size=32 \
27 | trainer.micro_forward_batch_size_per_gpu=16 \
28 | trainer.micro_train_batch_size_per_gpu=16 \
29 | trainer.eval_batch_size=32 \
30 | trainer.eval_before_train=true \
31 | trainer.eval_interval=5 \
32 | trainer.ckpt_interval=20 \
33 | trainer.max_prompt_length=512 \
34 | generator.sampling_params.max_generate_length=1024 \
35 | trainer.policy.optimizer_config.lr=1.0e-7 \
36 | trainer.algorithm.use_kl_loss=true \
37 | generator.backend=vllm \
38 | generator.run_engines_locally=true \
39 | generator.weight_sync_backend=nccl \
40 | generator.async_engine=true \
41 | generator.batched=false \
42 | environment.env_class=$ENV_CLASS \
43 | generator.n_samples_per_prompt=5 \
44 | generator.gpu_memory_utilization=0.8 \
45 | trainer.logger="wandb" \
46 | trainer.project_name="benchmax_math" \
47 | trainer.run_name="benchmax_math" \
48 | $@
--------------------------------------------------------------------------------
/examples/skyrl/run_benchmax_excel.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | # Colocated GRPO training+generation for Qwen2.5-3B-Instruct on a sample benchmax excel environment.
4 | # uv run benchmax/adapters/skyrl/benchmax_data_process.py --local_dir ~/data/excel --dataset_name spreadsheetbench --env_path benchmax.envs.excel.excel_env.ExcelEnvLocal
5 | # bash examples/skyrl/run_benchmax_excel.sh
6 |
7 | DATA_DIR="$HOME/data/excel"
8 | NUM_GPUS=2
9 | ENV_CLASS="ExcelEnv"
10 |
11 | uv run --isolated --group skyrl --group excel -m examples.skyrl.benchmax_excel \
12 | data.train_data="['$DATA_DIR/train.parquet']" \
13 | data.val_data="['$DATA_DIR/test.parquet']" \
14 | trainer.algorithm.advantage_estimator="grpo" \
15 | trainer.policy.model.path="Qwen/Qwen2.5-3B-Instruct" \
16 | trainer.placement.colocate_all=true \
17 | trainer.strategy=fsdp2 \
18 | trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
19 | trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
20 | generator.num_inference_engines=$NUM_GPUS \
21 | generator.inference_engine_tensor_parallel_size=1 \
22 | trainer.epochs=20 \
23 | trainer.update_epochs_per_batch=1 \
24 | trainer.train_batch_size=20 \
25 | trainer.policy_mini_batch_size=20 \
26 | trainer.critic_mini_batch_size=20 \
27 | trainer.micro_forward_batch_size_per_gpu=10 \
28 | trainer.micro_train_batch_size_per_gpu=10 \
29 | trainer.eval_batch_size=20 \
30 | trainer.eval_before_train=true \
31 | trainer.eval_interval=5 \
32 | trainer.ckpt_interval=10 \
33 | trainer.max_prompt_length=3000 \
34 | generator.sampling_params.max_generate_length=1024 \
35 | trainer.policy.optimizer_config.lr=1.0e-7 \
36 | trainer.algorithm.use_kl_loss=true \
37 | generator.backend=vllm \
38 | generator.run_engines_locally=true \
39 | generator.weight_sync_backend=nccl \
40 | generator.async_engine=true \
41 | generator.batched=false \
42 | environment.env_class=$ENV_CLASS \
43 | generator.n_samples_per_prompt=5 \
44 | generator.gpu_memory_utilization=0.8 \
45 | trainer.logger="wandb" \
46 | trainer.project_name="benchmax_excel" \
47 | trainer.run_name="benchmax_excel" \
48 | $@
--------------------------------------------------------------------------------
/tests/unit/envs/excel/test_excel_code_runner_mcp.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | import os
3 | from pathlib import Path
4 | import pytest
5 | import shutil
6 | from benchmax.envs.excel.workdir.excel_code_runner_mcp import run_excel_code_impl
7 | from benchmax.envs.excel.workdir.excel_utils import excel_to_str_repr
8 |
9 |
10 | # Fixtures
11 | @pytest.fixture(scope="session")
12 | def test_xlsx_path() -> str:
13 | return str(Path(__file__).parent / "test_inputs" / "test.xlsx")
14 |
15 |
16 | @contextmanager
17 | def temporary_cwd(path: Path):
18 | old_cwd = os.getcwd()
19 | os.chdir(path)
20 | try:
21 | yield
22 | finally:
23 | os.chdir(old_cwd)
24 |
25 |
26 | def test_excel_to_str_repr_basic(test_xlsx_path: str):
27 | output = excel_to_str_repr(test_xlsx_path)
28 | assert "Sheet Name: Sheet1" in output
29 | assert "Sheet Name: Sheet2" in output
30 | # Check for some known cell values and styles
31 | assert "D3: 1" in output
32 | assert "H3: D" in output
33 | assert "D6: =SUM(D3:D5) -> 2" in output
34 |
35 |
36 | def test_run_excel_code_impl_success(tmp_path: Path, test_xlsx_path: str):
37 | # Copy input to a temp file for isolation
38 | input_path = tmp_path / "input.xlsx"
39 | output_path = tmp_path / "output.xlsx"
40 | shutil.copyfile(test_xlsx_path, input_path)
41 |
42 | # User code: set D3 to 10 and save to output_path
43 | user_code = f'''
44 | from openpyxl import load_workbook
45 | wb = load_workbook("{input_path}")
46 | ws = wb["Sheet1"]
47 | ws["D3"].value = 10
48 | wb.save("{output_path}")
49 | wb.close()
50 | '''
51 |
52 | # Run code with temporary cwd
53 | with temporary_cwd(tmp_path):
54 | result = run_excel_code_impl(user_code, "output.xlsx")
55 |
56 | assert "D3: 10" in result
57 |
58 |
59 | def test_run_excel_code_impl_error(tmp_path: Path):
60 | output_path = tmp_path / "output.xlsx"
61 | user_code = "raise ValueError('test error')"
62 |
63 | with temporary_cwd(tmp_path):
64 | result = run_excel_code_impl(user_code, str(output_path))
65 |
66 | assert result.startswith("ERROR:")
67 | assert "test error" in result
68 |
--------------------------------------------------------------------------------
/src/benchmax/envs/excel/workdir/excel_code_runner_mcp.py:
--------------------------------------------------------------------------------
1 | from fastmcp import FastMCP
2 | import subprocess
3 | import sys
4 |
5 | try:
6 | from excel_utils import excel_to_str_repr
7 | except Exception:
8 | # Fall back to a package-relative import (used when running unit tests)
9 | from .excel_utils import excel_to_str_repr
10 |
11 | mcp = FastMCP(
12 | name="ExcelCodeRunner",
13 | instructions="This server provides a tool for running Python code to manipulate Excel files.",
14 | )
15 |
16 | WHITE_LIKE_COLORS = [
17 | "00000000",
18 | "FFFFFFFF",
19 | "FFFFFF00",
20 | ]
21 |
22 |
23 | def run_excel_code_impl(python_code: str, output_excel_path: str) -> str:
24 | """
25 | Run Python code which should use openpyxl to manipulate an Excel file.
26 | Call load_workbook with the input excel path as specified by the user.
27 | Remember to save the workbook to the output path that you specified and then call close() so you do not overwrite the input file.
28 |
29 | If code executes with no errors, return the string representation of the Excel file with styles.
30 | If there are errors, return an error message.
31 | """
32 | code_path = "script.py"
33 | # Write the user code to a file
34 | with open(code_path, "w") as f:
35 | f.write(python_code)
36 | try:
37 | subprocess.run(
38 | [sys.executable, code_path], check=True, capture_output=True, timeout=60
39 | )
40 | except subprocess.CalledProcessError as e:
41 | return f"ERROR: User code failed: {e.stderr.decode()}"
42 | except Exception as e:
43 | return f"ERROR: Error running user code: {str(e)}"
44 | # Convert the manipulated Excel file to JSON with styles
45 | excel_str = excel_to_str_repr(output_excel_path)
46 | return excel_str
47 |
48 |
49 | @mcp.tool
50 | def run_excel_code(python_code: str, output_excel_path: str) -> str:
51 | """
52 | Run Python code which should use openpyxl to manipulate an Excel file.
53 | If code executes with no errors, returns the string representation of the Excel file with styles.
54 | If there are errors, return an error message.
55 | """
56 | return run_excel_code_impl(python_code, output_excel_path)
57 |
58 |
59 | if __name__ == "__main__":
60 | mcp.run(show_banner=False)
61 |
--------------------------------------------------------------------------------
/src/benchmax/envs/mcp/provisioners/manual_provisioner.py:
--------------------------------------------------------------------------------
1 | """
2 | Manual provisioner for using pre-existing servers.
3 | """
4 |
5 | import logging
6 | from typing import List
7 | from .base_provisioner import BaseProvisioner
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | class ManualProvisioner(BaseProvisioner):
13 | """
14 | Provisioner for manually specified server addresses.
15 |
16 | Use this when you have already started servers and want to provide
17 | their addresses directly. Useful for:
18 | - Debugging with pre-started local servers
19 | - Testing against persistent test infrastructure
20 | - Using servers managed by external systems
21 |
22 | Example:
23 | provisioner = ManualProvisioner([
24 | "localhost:8080",
25 | "localhost:8081",
26 | "192.168.1.10:8080"
27 | ])
28 | """
29 |
30 | def __init__(self, addresses: List[str]):
31 | """
32 | Initialize with pre-existing server addresses.
33 |
34 | Args:
35 | addresses: List of server addresses in "host:port" format.
36 | """
37 | if not addresses:
38 | raise ValueError("ManualProvisioner requires at least one address")
39 |
40 | self._addresses = addresses
41 | logger.info(f"ManualProvisioner configured with {len(addresses)} addresses")
42 |
43 | @property
44 | def num_servers(self) -> int:
45 | """
46 | Total number of servers
47 | """
48 | return len(self._addresses)
49 |
50 | async def provision_servers(self, api_secret: str) -> List[str]:
51 | """
52 | Return the pre-configured server addresses.
53 |
54 | Args:
55 | api_secret: Unused by this provisioner; the pre-existing servers are assumed to already have their API secret configured.
56 |
57 | Returns:
58 | The list of addresses provided during initialization.
59 | """
60 | logger.info(f"Using {len(self._addresses)} manually configured servers")
61 | return self._addresses.copy()
62 |
63 | async def teardown(self) -> None:
64 | """
65 | No-op teardown since servers are externally managed.
66 |
67 | ManualProvisioner does not start servers, so it does not stop them.
68 | The user is responsible for managing the server lifecycle.
69 | """
70 | logger.info("ManualProvisioner teardown (servers externally managed)")
71 |
--------------------------------------------------------------------------------
/src/benchmax/envs/math/math_env.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Any
3 | import sky
4 |
5 | from benchmax.envs.mcp.parallel_mcp_env import ParallelMcpEnv
6 | from benchmax.envs.mcp.provisioners.base_provisioner import BaseProvisioner
7 | from benchmax.envs.mcp.provisioners.local_provisioner import LocalProvisioner
8 | from benchmax.envs.mcp.provisioners.skypilot_provisioner import SkypilotProvisioner
9 | from benchmax.envs.types import StandardizedExample
10 |
11 | SYSTEM_PROMPT = """Please use the tools provided to do any computation.
12 | Write your complete answer on the final line only, within the xml tags <answer></answer>.\n
13 | """
14 |
15 |
16 | class MathEnv(ParallelMcpEnv):
17 | """Environment for math problems, using local MCP tools."""
18 |
19 | system_prompt: str = SYSTEM_PROMPT
20 |
21 | def __init__(self, workdir_path: Path, provisioner: BaseProvisioner, **kwargs):
22 | super().__init__(workdir_path=workdir_path, provisioner=provisioner, **kwargs)
23 |
24 | @classmethod
25 | def dataset_preprocess(cls, example: Any, **kwargs) -> StandardizedExample:
26 | return StandardizedExample(
27 | prompt=example.get("task", ""),
28 | ground_truth=example.get("answer", ""),
29 | init_rollout_args=None,
30 | )
31 |
32 |
33 | class MathEnvLocal(MathEnv):
34 | """Import this env to run environment locally"""
35 |
36 | def __init__(self, num_local_servers: int = 5, **kwargs):
37 | workdir_path = Path(__file__).parent / "workdir"
38 | provisioner = LocalProvisioner(
39 | workdir_path=workdir_path, num_servers=num_local_servers
40 | )
41 | super().__init__(workdir_path=workdir_path, provisioner=provisioner, **kwargs)
42 |
43 |
44 | class MathEnvSkypilot(MathEnv):
45 | """Import this env to run environment on any cloud (i.e. AWS / GCP / Azure) with Skypilot"""
46 |
47 | def __init__(
48 | self,
49 | cloud: sky.clouds.Cloud = sky.Azure(),
50 | num_nodes: int = 2,
51 | servers_per_node: int = 5,
52 | **kwargs,
53 | ):
54 | workdir_path = Path(__file__).parent / "workdir"
55 | provisioner = SkypilotProvisioner(
56 | workdir_path=workdir_path,
57 | cloud=cloud,
58 | num_nodes=num_nodes,
59 | servers_per_node=servers_per_node,
60 | )
61 | super().__init__(workdir_path=workdir_path, provisioner=provisioner, **kwargs)
62 |
--------------------------------------------------------------------------------
/examples/skyrl/benchmax_math.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import ray
3 | from ray.actor import ActorProxy
4 | from omegaconf import DictConfig
5 | import skyrl_gym
6 | from skyrl_train.utils import initialize_ray
7 | from skyrl_train.entrypoints.main_base import BasePPOExp, validate_cfg
8 | from skyrl_train.config.utils import CONFIG_DIR
9 | from skyrl_gym.envs import register
10 |
11 | from benchmax.adapters.skyrl.skyrl_adapter import (
12 | cleanup_actor,
13 | get_or_create_benchmax_env_actor,
14 | load_benchmax_env_skyrl,
15 | )
16 | from benchmax.adapters.benchmax_wrapper import BenchmaxEnv
17 |
18 | BENCHMAX_ACTOR_NAME = "BenchmaxEnvService"
19 |
20 |
21 | @ray.remote(num_cpus=1)
22 | def skyrl_entrypoint(cfg: DictConfig):
23 | actor = None
24 | try:
25 | # UNCOMMENT the following to run MathEnv with skypilot (comment the other)
26 | # from benchmax.envs.math.math_env import MathEnvSkypilot
27 | # import sky
28 |
29 | # actor = get_or_create_benchmax_env_actor(
30 | # MathEnvSkypilot,
31 | # env_kwargs={
32 | # "cloud": sky.Azure(),
33 | # "num_nodes": 5,
34 | # "servers_per_node": 32,
35 | # }, # samples / prompt * batch size = 160 = 32 * 5
36 | # actor_name=BENCHMAX_ACTOR_NAME,
37 | # )
38 | # UNCOMMENT the following to run MathEnv locally (comment the other)
39 | from benchmax.envs.math.math_env import MathEnvLocal
40 |
41 | actor = get_or_create_benchmax_env_actor(
42 | MathEnvLocal,
43 | env_kwargs={
44 | "num_local_servers": 160
45 | }, # samples / prompt * batch size = 160
46 | actor_name=BENCHMAX_ACTOR_NAME,
47 | )
48 | register(
49 | id="MathEnv",
50 | entry_point=load_benchmax_env_skyrl,
51 | kwargs={"actor": actor},
52 | )
53 | skyrl_gym.pprint_registry()
54 |
55 | exp = BasePPOExp(cfg)
56 | exp.run()
57 |
58 | finally:
59 | cleanup_actor(actor)
60 |
61 |
62 | @hydra.main(
63 | config_path=str(CONFIG_DIR),
64 | config_name="ppo_base_config",
65 | version_base=None,
66 | )
67 | def main(cfg: DictConfig) -> None:
68 | try:
69 | validate_cfg(cfg)
70 | initialize_ray(cfg)
71 | ray.get(skyrl_entrypoint.remote(cfg))
72 | finally:
73 | try:
74 | benchmax_actor: ActorProxy[BenchmaxEnv] = ray.get_actor(BENCHMAX_ACTOR_NAME)
75 | cleanup_actor(benchmax_actor)
76 | except Exception:
77 | pass
78 | finally:
79 | ray.shutdown()
80 |
81 |
82 | if __name__ == "__main__":
83 | main()
84 |
--------------------------------------------------------------------------------
/examples/skyrl/benchmax_excel.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import ray
3 | from ray.actor import ActorProxy
4 | from omegaconf import DictConfig
5 | import skyrl_gym
6 | from skyrl_train.utils import initialize_ray
7 | from skyrl_train.entrypoints.main_base import BasePPOExp, validate_cfg
8 | from skyrl_train.config.utils import CONFIG_DIR
9 | from skyrl_gym.envs import register
10 |
11 | from benchmax.adapters.skyrl.skyrl_adapter import (
12 | cleanup_actor,
13 | get_or_create_benchmax_env_actor,
14 | load_benchmax_env_skyrl,
15 | )
16 | from benchmax.adapters.benchmax_wrapper import BenchmaxEnv
17 |
18 | BENCHMAX_ACTOR_NAME = "BenchmaxEnvService"
19 |
20 |
21 | @ray.remote(num_cpus=1)
22 | def skyrl_entrypoint(cfg: DictConfig):
23 | actor = None
24 | try:
25 | # UNCOMMENT the following to run ExcelEnv with skypilot (comment the other)
26 | from benchmax.envs.excel.excel_env import ExcelEnvSkypilot
27 | import sky
28 |
29 | actor = get_or_create_benchmax_env_actor(
30 | ExcelEnvSkypilot,
31 | env_kwargs={
32 | "cloud": sky.Azure(),
33 | "num_nodes": 5,
34 | "servers_per_node": 20,
35 | }, # samples / prompt * batch size = 100 = 20 * 5
36 | actor_name=BENCHMAX_ACTOR_NAME,
37 | )
38 |
39 | # UNCOMMENT the following to run ExcelEnv locally (comment the other)
40 | # from benchmax.envs.excel.excel_env import ExcelEnvLocal
41 |
42 | # actor = get_or_create_benchmax_env_actor(
43 | # ExcelEnvLocal,
44 | # env_kwargs={
45 | # "num_local_servers": 100
46 | # }, # samples / prompt * batch size = 100
47 | # actor_name=BENCHMAX_ACTOR_NAME,
48 | # )
49 |
50 | register(
51 | id="ExcelEnv",
52 | entry_point=load_benchmax_env_skyrl,
53 | kwargs={"actor": actor},
54 | )
55 | skyrl_gym.pprint_registry()
56 |
57 | exp = BasePPOExp(cfg)
58 | exp.run()
59 |
60 | finally:
61 | cleanup_actor(actor)
62 |
63 |
64 | @hydra.main(
65 | config_path=str(CONFIG_DIR),
66 | config_name="ppo_base_config",
67 | version_base=None,
68 | )
69 | def main(cfg: DictConfig) -> None:
70 | try:
71 | validate_cfg(cfg)
72 | initialize_ray(cfg)
73 | ray.get(skyrl_entrypoint.remote(cfg))
74 | finally:
75 | try:
76 | benchmax_actor: ActorProxy[BenchmaxEnv] = ray.get_actor(BENCHMAX_ACTOR_NAME)
77 | cleanup_actor(benchmax_actor)
78 | except Exception:
79 | pass
80 | finally:
81 | ray.shutdown()
82 |
83 |
84 | if __name__ == "__main__":
85 | main()
86 |
--------------------------------------------------------------------------------
/tests/unit/envs/mcp/provisioners/test_manual.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from benchmax.envs.mcp.provisioners import ManualProvisioner
3 |
4 |
5 | class TestManualProvisioner:
6 | """Tests for ManualProvisioner class."""
7 |
8 | def test_init_with_valid_addresses(self):
9 | """Test initialization with valid addresses."""
10 | addresses = ["localhost:8080", "192.168.1.10:8080"]
11 | provisioner = ManualProvisioner(addresses)
12 |
13 | assert provisioner._addresses == addresses
14 |
15 | def test_init_with_empty_list_raises_error(self):
16 | """Test that empty address list raises ValueError."""
17 | with pytest.raises(ValueError, match="at least one address"):
18 | ManualProvisioner([])
19 |
20 | @pytest.mark.asyncio
21 | async def test_provision_servers_returns_addresses(self):
22 | """Test that provision_servers returns configured addresses."""
23 | addresses = ["localhost:8080", "localhost:8081"]
24 | provisioner = ManualProvisioner(addresses)
25 |
26 | result = await provisioner.provision_servers("dummy-api-secret")
27 |
28 | assert result == addresses
29 |
30 | @pytest.mark.asyncio
31 | async def test_provision_servers_returns_copy(self):
32 | """Test that provision_servers returns a copy, not the original list."""
33 | addresses = ["localhost:8080"]
34 | provisioner = ManualProvisioner(addresses)
35 |
36 | result = await provisioner.provision_servers("dummy-api-secret")
37 | result.append("modified")
38 |
39 | # Original should be unchanged
40 | assert provisioner._addresses == ["localhost:8080"]
41 |
42 | @pytest.mark.asyncio
43 | async def test_teardown_is_noop(self):
44 | """Test that teardown completes without error (no-op)."""
45 | provisioner = ManualProvisioner(["localhost:8080"])
46 | await provisioner.teardown() # Should not raise
47 |
48 | @pytest.mark.asyncio
49 | async def test_multiple_provision_calls(self):
50 | """Test that provision_servers can be called multiple times."""
51 | provisioner = ManualProvisioner(["localhost:8080"])
52 |
53 | result1 = await provisioner.provision_servers("dummy-api-secret")
54 | result2 = await provisioner.provision_servers("dummy-api-secret")
55 |
56 | assert result1 == result2
57 |
58 | @pytest.mark.asyncio
59 | async def test_provision_after_teardown(self):
60 | """Test that provision_servers works after teardown."""
61 | provisioner = ManualProvisioner(["localhost:8080"])
62 |
63 | await provisioner.provision_servers("dummy-api-secret")
64 | await provisioner.teardown()
65 | result = await provisioner.provision_servers("dummy-api-secret")
66 |
67 | assert result == ["localhost:8080"]
68 |
--------------------------------------------------------------------------------
/src/benchmax/envs/excel/README.md:
--------------------------------------------------------------------------------
1 | # Excel Environment
2 |
3 | This environment provides capabilities for interacting with Excel files through either LibreOffice (Linux) or Microsoft Excel (Windows/macOS).
4 |
5 | This environment is based on the [SpreadsheetBench benchmark](https://spreadsheetbench.github.io/).
6 |
7 | ## Prerequisites
8 |
9 | **Important**: Before using this environment, ensure you have the appropriate spreadsheet application installed:
10 | - **Linux**: LibreOffice must be installed
11 | ```bash
12 | sudo apt install libreoffice
13 | ```
14 | - **Windows/macOS**: Microsoft Excel must be installed
15 |
16 | ## Installation
17 |
18 | ### Linux
19 | ```bash
20 | pip install "benchmax[excel,skypilot]"
21 | ```
22 | Includes:
23 | - openpyxl: For Excel file manipulation
24 | - fastmcp: For MCP server functionality
25 |
26 | ### Windows/macOS
27 | ```bash
28 | pip install "benchmax[excel-mac-windows,skypilot]"
29 | ```
30 | Includes:
31 | - openpyxl: For Excel file manipulation
32 | - xlwings: For direct Excel application interaction
33 | - fastmcp: For MCP server functionality
34 |
35 | ## Usage
36 |
37 | Use `ExcelEnvLocal` to run the servers locally on the same machine as benchmax or use `ExcelEnvSkypilot` to parallelize the servers across multiple nodes.
38 |
39 | ## Available Tool
40 |
41 | The environment provides a single MCP tool for Excel manipulation:
42 |
43 | ### run_excel_code
44 | Executes Python code that uses openpyxl to manipulate Excel files. The tool:
45 | - Takes Python code and an output Excel path as input
46 | - Runs the code in a controlled environment
47 | - Returns a string representation of the modified Excel file
48 | - Preserves spreadsheet formatting (colors, fonts, styles)
49 | - Handles both cell-level and sheet-level operations
50 |
51 | ## Reward Functions
52 |
53 | Reward functions measure how well a generated spreadsheet matches the expected output.
54 |
55 | ### Default Reward Function
56 |
57 | The built-in reward function compares the output spreadsheet against a ground truth using only the cells specified in the task.
58 |
59 | - If all the values in the relevant cells are correct, the reward is **1.0**
60 | - If there are any mismatches, the reward is **0.0**
61 |
62 | ### Comparison Strategy
63 |
64 | - Compares evaluated values, not formulas (e.g., it checks the result `10`, not the formula `=5+5`)
65 | - Only compares the cells within the defined answer range
66 | - Supports multiple cell types like numbers, strings, times, and dates
67 | - Can optionally check visual formatting, such as:
68 | - Fill color (background)
69 | - Font color
70 |
71 | ### Error Handling
72 |
73 | If the comparison fails due to a missing file, invalid cell reference, or formatting error, the reward function will return **0.0** and log a helpful message for debugging.
74 |
75 | ### Outcome
76 |
77 | - **1.0** score if the generated spreadsheet is fully correct
78 | - **0.0** if any discrepancy is found
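79 | 
80 | ## Example
81 | 
82 | A minimal sketch of instantiating the local variant (the parameter value is illustrative; see `excel_env.py` for the full constructor signature):
83 | 
84 | ```python
85 | from benchmax.envs.excel.excel_env import ExcelEnvLocal
86 | 
87 | # Run a few local Excel MCP servers for parallel rollouts
88 | env = ExcelEnvLocal(num_local_servers=4)
89 | ```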
--------------------------------------------------------------------------------
/src/benchmax/envs/how-to-extend-base-env.md:
--------------------------------------------------------------------------------
1 | ## Create your env from BaseEnv
2 |
3 | ### 1. **Define the system prompt**
4 |
5 | This helps instruct the model on how to interact with the tool and format output.
6 |
7 | ```python
8 | SYSTEM_PROMPT = """Use the `evaluate` tool to perform any computation.
9 | Write your final answer on the last line inside <answer>...</answer>.
10 | """
11 | ```
12 |
13 | ### 2. **Create a reward function**
14 |
15 | We'll score the model 1.0 if it places the correct answer inside `<answer>...</answer>` tags:
16 |
17 | ```python
18 | import re
19 | from html import unescape
20 | from pathlib import Path
21 |
22 | def reward_func(prompt: str, completion: str, ground_truth: str, workspace: Path, **kwargs) -> float:
23 |     m = re.search(r'<answer>(.*?)</answer>', completion, flags=re.IGNORECASE | re.DOTALL)
24 | if not m:
25 | return 0.0
26 | answer_text = unescape(m.group(1)).strip().lower()
27 | return float(answer_text == ground_truth.lower())
28 | ```
29 |
30 | ### 3. **Define your math tool**
31 |
32 | A simple `eval`-based tool for math expressions (builtins are disabled, but this is not a full sandbox):
33 |
34 | ```python
35 | def evaluate_expression(expr: str) -> str:
36 | try:
37 | result = eval(expr, {"__builtins__": {}})
38 | return str(result)
39 | except Exception as e:
40 | return f"Error: {str(e)}"
41 | ```
42 |
43 | ### 4. **Create the environment class**
44 |
45 | Bring it all together in a subclass of `BaseEnv`:
46 |
47 | ```python
48 | class SimpleMathEnv(BaseEnv):
49 | system_prompt: str = SYSTEM_PROMPT
50 | _reward_funcs: List[RewardFunction] = [reward_func]
51 |
52 | def __init__(self):
53 | eval_tool = ToolDefinition(
54 | name="evaluate",
55 | description="Safely evaluate a math expression like '2 + 3 * 4'.",
56 | input_schema={
57 | "type": "object",
58 | "properties": {
59 | "expr": {
60 | "type": "string",
61 | "description": "Math expression to evaluate.",
62 | },
63 | },
64 | "required": ["expr"],
65 | }
66 | )
67 | self.tools: Dict[str, Tuple[ToolDefinition, Callable]] = {
68 | "evaluate": (eval_tool, evaluate_expression)
69 | }
70 | def dataset_preprocess(self, example: dict) -> StandardizedExample:
71 | return {
72 | "prompt": f"Question: {example['question']}\n\nWrite your answer below.",
73 | "ground_truth": example.get("answer", ""),
74 | "init_rollout_args": {}
75 | }
76 |
77 | def list_tools(self) -> List[ToolDefinition]:
78 | return [tool_def for tool_def, _ in self.tools.values()]
79 |
80 | def run_tool(self, rollout_id: str, tool_name: str, **tool_args) -> Any:
81 | _, tool_fn = self.tools[tool_name]
82 | return tool_fn(**tool_args)
83 | ```
84 |
85 |
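### 5. **Sanity-check the pieces**

A quick, minimal check of the tool and reward function from steps 2-3 (note that `SimpleMathEnv` itself still leaves several `BaseEnv` abstract methods unimplemented, such as the rollout lifecycle, workspace copies, and `compute_reward`, so it cannot be instantiated until those are added):

```python
print(evaluate_expression("2 + 3 * 4"))   # -> "14"

completion = "The result is\n<answer>14</answer>"
print(reward_func(prompt="", completion=completion,
                  ground_truth="14", workspace=Path(".")))   # -> 1.0
```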
--------------------------------------------------------------------------------
/src/benchmax/envs/wikipedia/utils.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import html
3 | import re
4 | import threading
5 | from typing import Any, Dict, List, Optional
6 |
7 | import aiohttp
8 |
9 |
10 | def clean_html(raw: str) -> str:
11 | """Strip HTML tags and unescape entities."""
12 | text = re.sub(r"<[^>]+>", "", raw)
13 | return html.unescape(text)
14 |
15 |
16 | class APIKeyRotator:
17 | """Thread-safe round-robin iterator over API keys."""
18 |
19 | def __init__(self, keys: Optional[List[str]] = None):
20 | self._keys: List[str] = keys or []
21 | self._lock = threading.Lock()
22 | self._idx = 0
23 |
24 | def next(self) -> Optional[str]:
25 | """Return the next key, or None if no keys configured."""
26 | if not self._keys:
27 | return None
28 | with self._lock:
29 | key = self._keys[self._idx]
30 | self._idx = (self._idx + 1) % len(self._keys)
31 | return key
32 |
33 |
34 | class RateLimitExceeded(Exception):
35 | """Raised when the API repeatedly returns HTTP 429."""
36 |
37 |
38 | async def safe_request(
39 | method: str,
40 | url: str,
41 | *,
42 | headers: Dict[str, str],
43 | params: Dict[str, Any] | None = None,
44 | timeout: float = 10.0,
45 | json: Any | None = None,
46 | max_retries: int = 3,
47 | retry_delay_seconds: float = 20,
48 | rate_limit_seconds: float = 2.5,
49 | ) -> Optional[aiohttp.ClientResponse]:
50 | """
51 |     Async HTTP request that retries on HTTP 429 rate limits (fixed delay between attempts).
52 |
53 | Args:
54 | method: HTTP method (GET, POST, etc.)
55 | url: Target URL
56 | headers: Request headers
57 | params: Query parameters
58 | timeout: Request timeout in seconds
59 | json: JSON body for request
60 | max_retries: Maximum retry attempts on 429
61 | retry_delay_seconds: Base delay between retries
62 | rate_limit_seconds: Initial delay before first request
63 |
64 | Returns:
65 |         aiohttp.ClientResponse with its body already read, so it remains usable after the session closes
66 |
67 | Raises:
68 | RateLimitExceeded: When max retries exhausted on 429 errors
69 | """
70 | await asyncio.sleep(rate_limit_seconds)
71 |
72 | async with aiohttp.ClientSession() as session:
73 | for attempt in range(max_retries + 1):
74 | async with session.request(
75 | method,
76 | url,
77 | headers=headers,
78 | params=params,
79 | json=json,
80 | timeout=aiohttp.ClientTimeout(total=timeout),
81 | ) as resp:
82 | if resp.status != 429:
83 | # Read response content before returning
84 | content = await resp.read()
85 |                     # Stash the body on the response so it can still be accessed after the session closes
86 | resp._body = content
87 | return resp
88 |
89 | if attempt == max_retries:
90 | raise RateLimitExceeded(
91 | f"Rate limit hit and {max_retries} retries exhausted."
92 | )
93 |
94 | print(f"Rate limit hit, retrying in {retry_delay_seconds:.1f}s...")
95 | await asyncio.sleep(retry_delay_seconds)
96 |
--------------------------------------------------------------------------------
/tests/unit/prompts/test_tools.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any, List
2 | from benchmax.envs.base_env import ToolDefinition
3 | from benchmax.prompts.tools import (
4 | mcp2openai,
5 | parse_hermes_tool_call,
6 | render_tools_prompt,
7 | )
8 |
9 |
10 | def test_mcp2openai():
11 | # Test basic conversion
12 | tool_def = ToolDefinition(
13 | name="test_tool",
14 | description="A test tool",
15 | input_schema={"type": "object", "properties": {"arg1": {"type": "string"}}},
16 | )
17 | result = mcp2openai(tool_def)
18 |
19 | assert result["type"] == "function"
20 | assert result["function"]["name"] == "test_tool"
21 | assert result["function"]["description"] == "A test tool"
22 | assert result["function"]["parameters"] == {
23 | "type": "object",
24 | "properties": {"arg1": {"type": "string"}},
25 | "required": [],
26 | }
27 | assert result["function"]["strict"] is False
28 |
29 | # Test with empty input schema
30 | tool_def_no_schema = ToolDefinition(
31 | name="empty_tool", description="Tool with no schema", input_schema=None
32 | )
33 | result_no_schema = mcp2openai(tool_def_no_schema)
34 | assert result_no_schema["function"]["parameters"] == {"required": []}
35 |
36 |
37 | def test_parse_hermes_tool_call():
38 | # Test single tool call
39 |     single_call = """<tool_call>{"name": "get_weather", "arguments": {"location": "New York"}}</tool_call>"""
40 | result: List[Dict[str, Any]] = parse_hermes_tool_call(single_call)
41 | assert len(result) == 1
42 | assert result[0]["name"] == "get_weather"
43 | assert result[0]["arguments"]["location"] == "New York"
44 |
45 | # Test multiple tool calls
46 | multiple_calls = """
47 |     <tool_call>{"name": "tool1", "arguments": {"arg1": "value1"}}</tool_call>
48 |     <tool_call>{"name": "tool2", "arguments": {"arg2": "value2"}}</tool_call>
49 | """
50 | result: List[Dict[str, Any]] = parse_hermes_tool_call(multiple_calls)
51 | assert len(result) == 2
52 | assert result[0]["name"] == "tool1"
53 | assert result[1]["name"] == "tool2"
54 |
55 | # Test empty string
56 | assert parse_hermes_tool_call("") == []
57 |
58 |
59 | def test_render_tools_prompt():
60 | # Test with empty tool list
61 | assert render_tools_prompt([]) == ""
62 |
63 | # Test with single tool
64 | tool_def = ToolDefinition(
65 | name="test_tool",
66 | description="A test tool",
67 | input_schema={"type": "object", "properties": {"arg1": {"type": "string"}}},
68 | )
69 | result = render_tools_prompt([tool_def], system_message="Test System Message")
70 |
71 | assert "Test System Message" in result
72 | assert "# Tools" in result
73 |     assert "<tools>" in result
74 |     assert "</tools>" in result
75 |     assert "test_tool" in result
76 |     assert "<tool_call>" in result
77 |     assert "</tool_call>" in result
78 |
79 | # Test with multiple tools
80 | tool_def2 = ToolDefinition(
81 | name="another_tool",
82 | description="Another test tool",
83 | input_schema={"type": "object", "properties": {"arg2": {"type": "number"}}},
84 | )
85 | result_multiple = render_tools_prompt([tool_def, tool_def2])
86 | assert "test_tool" in result_multiple
87 | assert "another_tool" in result_multiple
88 |
--------------------------------------------------------------------------------
/src/benchmax/prompts/tools.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Dict, List
3 |
4 | from benchmax.envs.types import ToolDefinition
5 |
6 | def mcp2openai(mcp_tool: ToolDefinition) -> dict:
7 | """Convert a ToolDefinition to an OpenAI Function Call format."""
8 | openai_format = {
9 | "type": "function",
10 | "function": {
11 | "name": mcp_tool.name,
12 | "description": mcp_tool.description,
13 | "parameters": mcp_tool.input_schema or {},
14 | "strict": False,
15 | },
16 | }
17 | if not openai_format["function"]["parameters"].get("required", None):
18 | openai_format["function"]["parameters"]["required"] = []
19 | return openai_format
20 |
21 | def parse_hermes_tool_call(text: str) -> List[Dict[str, str]]:
22 | """
23 | Parse a tool call from Hermes XML format.
24 | Example:
25 |     <tool_call>
26 |     {"name": "get_weather", "arguments": {"location": "New York"}}
27 |     </tool_call>
28 |     <tool_call>
29 |     {"name": "get_weather", "arguments": {"location": "New York"}}
30 |     </tool_call>
31 | """
32 |     import re
34 | # Match all tool call XML tags and extract the JSON content
35 |     matches = re.finditer(r'<tool_call>(.*?)</tool_call>', text, re.DOTALL)
36 | tool_calls = []
37 |
38 | for match in matches:
39 | tool_call_json = match.group(1).strip()
40 | try:
41 | tool_calls.append(json.loads(tool_call_json))
42 |         except json.JSONDecodeError:
43 | return []
44 |
45 |     return tool_calls
46 |
47 | def render_tools_prompt(
48 | tool_definitions: List[ToolDefinition],
49 | system_message: str = ""
50 | ) -> str:
51 | """
52 | Build the prompt block that advertises the available function tools to the model.
53 |
54 | Parameters
55 | ----------
56 |     tool_definitions : list[ToolDefinition]
57 |         Tool definitions to advertise; each is converted to the OpenAI Tools / function-calling format.
58 |     system_message : str, optional
59 |         The system message that will be placed at the top of the prompt
60 |         (defaults to an empty string).
61 |
62 | Returns
63 | -------
64 | str
65 | A fully-rendered prompt string with system message and tool information.
66 | """
67 | tool_schema = [mcp2openai(tool_def) for tool_def in tool_definitions]
68 | if not tool_schema:
69 | return system_message
70 |
71 | # Header
72 | lines = [system_message, "", "# Tools", "",
73 | "You may call one or more functions to assist with the user query.",
74 | "",
75 |              "You are provided with function signatures within <tools></tools> XML tags:",
76 |              "<tools>"]
77 |
78 | # One line-per-tool JSON dump (compact, no extra spaces)
79 | for tool in tool_schema:
80 | lines.append(json.dumps(tool, separators=(",", ":")))
81 |
82 | lines.extend([
83 |         "</tools>",
84 |         "",
85 |         "For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:",
86 |         "<tool_call>",
87 |         "{\"name\": <function-name>, \"arguments\": <args-json-object>}",
88 |         "</tool_call>",
89 | ])
90 |
91 | return "\n".join(lines)
92 |
93 |
--------------------------------------------------------------------------------
/src/benchmax/envs/mcp/provisioners/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions for server provisioners
3 | """
4 |
5 | import re
6 | import shutil
7 | import tempfile
8 | from pathlib import Path
9 | from typing import List, Optional
10 |
11 |
12 | def setup_sync_dir(workdir_path: Path) -> Path:
13 | """
14 | Set up a temporary directory to sync the workdir contents.
15 |
16 | This creates a temp directory and copies:
17 | 1. proxy_server.py from the mcp/ directory
18 | 2. All contents of the provided workdir_path
19 |
20 | Args:
21 | workdir_path: Path to workdir containing mcp_config.yaml, setup.sh, etc.
22 |
23 | Returns:
24 | Path to the temporary sync directory.
25 |
26 | Raises:
27 | FileNotFoundError: If required files are missing.
28 | ValueError: If workdir_path is not a directory.
29 | """
30 | sync_dir = Path(tempfile.mkdtemp(prefix="benchmax_skypilot_"))
31 |
32 | try:
33 | # Copy proxy_server.py (located in the mcp/ directory)
34 | src_server_path = Path(__file__).parent.parent / "proxy_server.py"
35 | if not src_server_path.exists():
36 | raise FileNotFoundError(
37 | f"Expected proxy_server.py at {src_server_path}, but not found."
38 | )
39 | shutil.copy(src_server_path, sync_dir / "proxy_server.py")
40 |
41 | # Validate workdir exists and is a directory
42 | if not workdir_path.exists():
43 | raise FileNotFoundError(
44 | f"Expected workdir_path at {workdir_path}, but not found."
45 | )
46 | if not workdir_path.is_dir():
47 | raise ValueError(
48 | f"Expected workdir_path at {workdir_path} to be a directory."
49 | )
50 |
51 | # Validate required files in the workdir using regex patterns
52 | required_patterns = {
53 | r"^reward_fn\.py$": "reward_fn.py",
54 | r"^setup\.sh$": "setup.sh",
55 | r"^mcp_config\.(yaml|yml)$": "mcp_config.yaml or mcp_config.yml",
56 | }
57 |
58 | workdir_files = {f.name for f in workdir_path.iterdir() if f.is_file()}
59 |
60 | for pattern, description in required_patterns.items():
61 | pattern_re = re.compile(pattern)
62 | if not any(pattern_re.match(filename) for filename in workdir_files):
63 | raise FileNotFoundError(
64 | f"Required file matching '{description}' not found in workdir_path '{workdir_path}'."
65 | )
66 |
67 | # Copy all contents of the workdir
68 | shutil.copytree(workdir_path, sync_dir, dirs_exist_ok=True)
69 |
70 | except Exception:
71 | shutil.rmtree(sync_dir, ignore_errors=True)
72 | raise
73 |
74 | return sync_dir
75 |
76 |
77 | def cleanup_dir(path: Optional[Path]) -> None:
78 | """
79 | Recursively delete a directory if it exists.
80 |
81 | Args:
82 | path: Path to directory to delete. If None or doesn't exist, no-op.
83 | """
84 | if path and path.exists() and path.is_dir():
85 | shutil.rmtree(path, ignore_errors=True)
86 |
87 |
88 | def get_setup_command() -> str:
89 | """Generate setup command for installing dependencies."""
90 | return """
91 | # Install uv
92 | curl -LsSf https://astral.sh/uv/install.sh | sh
93 | UV_VENV_CLEAR=1 uv venv ~/venv && source ~/venv/bin/activate
94 | uv pip install fastmcp~=2.12.0 pyyaml psutil
95 | bash setup.sh
96 | """
97 |
98 |
99 | def get_run_command(ports: List[str]) -> str:
100 | """Generate command to start multiple proxy servers on different ports."""
101 | commands = []
102 | for port in ports:
103 | cmd = f"source ~/venv/bin/activate && python proxy_server.py --port {port} --base-dir ../workspace &"
104 | commands.append(cmd)
105 | commands.append("wait") # Wait for all background processes
106 | return "\n".join(commands)
107 |
--------------------------------------------------------------------------------
/src/benchmax/envs/mcp/example_workdir/reward_fn.py:
--------------------------------------------------------------------------------
1 | """
2 | Reward functions for demo calculator MCP server.
3 |
4 | Three reward types:
5 | 1. Stateless completion check
6 | 2. Tool call variable in memory check
7 | 3. Workspace log check
8 |
9 | All reward functions will receive the same ground truth and completion.
10 | When defining the dataset, you can determine the shape of your ground truth.
11 |
12 | In this example, the ground truth is a dictionary of the shape:
13 | {
14 | 'completion': str,
15 | 'variable': { name: str, expected_value: float },
16 | 'log': { filename: str, expected_content: str }
17 | }
18 |
19 | Each reward fn extracts what it needs from the ground_truth dictionary.
20 | Together, these rewards show different ways of computing a reward: from the completion alone, by calling an MCP server tool,
21 | and by reading from the workspace that the MCP server is operating in.
22 |
23 | """
24 |
25 | from pathlib import Path
26 | from typing import Any, Callable, Dict, Union, Awaitable
27 | from mcp.types import TextContent
28 | from fastmcp import Client
29 | from fastmcp.exceptions import ToolError
30 |
31 | RewardFunction = Callable[..., Union[float, Awaitable[float]]]
32 |
33 |
34 | # -------------------------------
35 | # Reward 0: Stateless completion check
36 | # -------------------------------
37 | async def completion_match_reward(
38 | completion: str,
39 | ground_truth: dict,
40 | mcp_client: Client,
41 | workspace: Path,
42 | **kwargs: Any,
43 | ) -> float:
44 | """
45 | Return 1.0 if completion matches ground_truth['completion'], else 0.0
46 |
47 | Uses: ground_truth['completion'] (str)
48 | """
49 | expected = ground_truth.get("completion", "")
50 | return 1.0 if completion.strip() == expected.strip() else 0.0
51 |
52 |
53 | # -------------------------------
54 | # Reward 1: Tool call variable in memory check
55 | # -------------------------------
56 | async def variable_in_memory_reward(
57 | completion: str, ground_truth: dict, mcp_client: Client, workspace: Path, **kwargs
58 | ) -> float:
59 | """
60 | Reward uses tool call to match in-memory variable value.
61 |
62 | Uses: ground_truth['variable'] = {"name": str, "expected_value": float}
63 | """
64 | variable_spec = ground_truth.get("variable", {})
65 | var_name = variable_spec.get("name")
66 | expected = variable_spec.get("expected_value")
67 |
68 | if not var_name or expected is None:
69 | return 0.0
70 |
71 | if not mcp_client or not mcp_client.is_connected():
72 | return 0.0
73 |
74 | try:
75 | # Call tool via MCP client
76 | response = await mcp_client.call_tool("get_variable", {"name": var_name})
77 |
78 | # Extract text from TextContent objects
79 | text_contents = []
80 | for content in response.content:
81 | if isinstance(content, TextContent):
82 | text_contents.append(content.text)
83 |
84 | combined_text = "\n".join(text_contents).strip()
85 | value = float(combined_text)
86 |
87 | return 1.0 if value == expected else 0.0
88 |
89 | except ToolError:
90 | return 0.0
91 | except Exception:
92 | return 0.0
93 |
94 |
95 | # -------------------------------
96 | # Reward 2: Workspace log check
97 | # -------------------------------
98 | async def log_in_workspace_reward(
99 | completion: str,
100 | ground_truth: dict,
101 | mcp_client: Client,
102 | workspace: Path,
103 | **kwargs: Any,
104 | ) -> float:
105 | """
106 | Reward based on workspace file content.
107 |
108 | Uses: ground_truth['log'] = {"filename": str, "expected_content": str}
109 | """
110 | log_spec = ground_truth.get("log", {})
111 | filename = log_spec.get("filename")
112 | expected = log_spec.get("expected_content", "")
113 |
114 | if not filename:
115 | return 0.0
116 |
117 | file_path = Path(workspace) / filename
118 | if not file_path.exists():
119 | return 0.0
120 |
121 | content = file_path.read_text().strip()
122 | return 1.0 if content == expected.strip() else 0.0
123 |
124 |
125 | # -------------------------------
126 | # Export reward functions
127 | # -------------------------------
128 | reward_functions: Dict[str, RewardFunction] = {
129 | "completion": completion_match_reward,
130 | "variable": variable_in_memory_reward,
131 | "log": log_in_workspace_reward,
132 | }
133 |
--------------------------------------------------------------------------------
/tests/integration/envs/mcp/provisioners/test_local_integration.py:
--------------------------------------------------------------------------------
1 | """
2 | Integration tests for LocalProvisioner.
3 | """
4 |
5 | import pytest
6 | import asyncio
7 | from pathlib import Path
8 | from benchmax.envs.mcp.provisioners.local_provisioner import LocalProvisioner
9 | from tests.integration.envs.mcp.provisioners.utils import wait_for_server_health, check_health
10 |
11 |
12 | class TestEndToEnd:
13 | """End-to-end integration tests for provisioning and teardown."""
14 |
15 | @pytest.mark.asyncio
16 | @pytest.mark.slow
17 | async def test_single_server_lifecycle(self, example_workdir: Path):
18 | """Test complete lifecycle of a single server: provision, verify, and teardown."""
19 | provisioner = LocalProvisioner(
20 | workdir_path=example_workdir,
21 | num_servers=1,
22 | base_port=9000,
23 | )
24 |
25 | # Provision
26 | api_secret = "single-server-test-secret-32chars!!"
27 | addresses = await provisioner.provision_servers(api_secret)
28 | assert len(addresses) == 1
29 | assert addresses[0] == "localhost:9000"
30 |
31 | # Verify server is up and healthy
32 | is_healthy = await wait_for_server_health(addresses[0])
33 | assert is_healthy, "Server failed to become healthy"
34 |
35 | # Verify process is running
36 | assert len(provisioner._processes) == 1
37 | assert provisioner._processes[0].poll() is None
38 |
39 | # Check that double-provisioning would result in an error
40 | with pytest.raises(RuntimeError, match="already provisioned"):
41 | await provisioner.provision_servers(api_secret)
42 |
43 | # Teardown
44 | await provisioner.teardown()
45 |
46 | # Verify cleanup
47 | assert len(provisioner._processes) == 0
48 | await asyncio.sleep(0.5)
49 | assert not await check_health("localhost:9000"), "Server still responding after teardown"
50 |
51 | @pytest.mark.asyncio
52 | @pytest.mark.slow
53 | async def test_multiple_servers_lifecycle(self, example_workdir: Path):
54 | """Test complete lifecycle of multiple servers: provision, verify, and teardown."""
55 | provisioner = LocalProvisioner(
56 | workdir_path=example_workdir,
57 | num_servers=5,
58 | base_port=9100,
59 | )
60 |
61 | # Provision
62 |         api_secret = "multi-server-test-secret-32chars!!!"
63 | addresses = await provisioner.provision_servers(api_secret)
64 | assert len(addresses) == 5
65 | expected_addresses = [f"localhost:{9100 + i}" for i in range(5)]
66 | assert addresses == expected_addresses
67 |
68 | # Verify all servers are up and healthy
69 | health_checks = await asyncio.gather(
70 | *[wait_for_server_health(addr) for addr in addresses]
71 | )
72 | assert all(health_checks), "Not all servers became healthy"
73 |
74 | # Verify all processes are running
75 | assert len(provisioner._processes) == 5
76 | assert all(p.poll() is None for p in provisioner._processes)
77 |
78 | # Teardown
79 | await provisioner.teardown()
80 |
81 | # Verify cleanup
82 | assert len(provisioner._processes) == 0
83 | await asyncio.sleep(0.5)
84 |
85 | # Verify all servers are down
86 | for addr in addresses:
87 | assert not await check_health(addr), f"Server {addr} still responding after teardown"
88 |
89 |
90 | class TestValidation:
91 | """Test parameter validation."""
92 |
93 | def test_invalid_num_servers(self):
94 | """Test validation of num_servers parameter."""
95 | with pytest.raises(ValueError, match="at least 1"):
96 | LocalProvisioner(workdir_path=".", num_servers=0, base_port=8080)
97 |
98 | def test_invalid_base_port(self):
99 | """Test validation of base_port parameter."""
100 | with pytest.raises(ValueError, match="between 1024 and 65535"):
101 | LocalProvisioner(workdir_path=".", num_servers=1, base_port=500)
102 |
103 | def test_port_range_exceeds_max(self):
104 | """Test validation of port range."""
105 | with pytest.raises(ValueError, match="exceeds max port"):
106 | LocalProvisioner(workdir_path=".", num_servers=100, base_port=65500)
107 |
--------------------------------------------------------------------------------
/examples/skyrl/README.md:
--------------------------------------------------------------------------------
1 | # Benchmax Environments with SkyRL
2 |
3 | This directory contains examples for both the Math and Excel environments. The guide below is written for the Math environment, but the high-level idea is broadly transferable.
4 |
5 | ### **🔍 Quickstart: RL Math Agent with SkyRL + `benchmax`**
6 |
7 | We can fine-tune an RL agent with SkyRL using a `benchmax` environment — for example, the `math` environment — and a tool-enabled setup for calculator-like reasoning.
8 |
9 | Example script:
10 | `examples/skyrl/run_benchmax_math.sh`
11 |
12 | ---
13 |
14 | ## 1. Prepare the dataset
15 |
16 | Use `benchmax_data_process.py` to convert a HuggingFace dataset into the multiturn chat format expected by SkyRL rollouts.
17 |
18 | Example for `arithmetic50`:
19 |
20 | ```bash
21 | uv run src/benchmax/adapters/skyrl/benchmax_data_process.py \
22 | --local_dir ~/data/math \
23 | --dataset_name dawidmt/arithmetic50 \
24 | --env_path benchmax.envs.math.math_env.MathEnvLocal
25 | ```
26 |
27 | **Arguments**:
28 |
29 | | Flag | Required? | Description |
30 | | ---------------- | --------- | --------------------------------------------------------------------------------------------- |
31 | | `--local_dir` | yes | Output folder for `train.parquet` & `test.parquet`. |
32 | | `--dataset_name` | yes | HuggingFace dataset name or path. |
33 | | `--env_path` | yes | Dotted path to the `benchmax` environment class (e.g. `benchmax.envs.math.math_env.MathEnvLocal`). |
34 |
35 | ---
36 |
37 | ## 2. Launch training — **Focusing on Environment Arguments**
38 |
39 | In `run_benchmax_math.sh`, the environment is configured with:
40 |
41 | ```bash
42 | ENV_CLASS="MathEnv"
43 | ...
44 | environment.env_class=$ENV_CLASS \
45 | ```
46 |
47 | **How it works**:
48 |
49 | * **`environment.env_class`**: This must match the **ID** registered in `skyrl_gym.envs.register(...)`.
50 |
51 | * In `benchmax_math.py`, the registration happens here:
52 |
53 | ```python
54 | register(
55 | id="MathEnv",
56 | entry_point=load_benchmax_env_skyrl,
57 | kwargs={"actor": actor}
58 | )
59 | ```
60 | * The `id` value (`"MathEnv"`) is what `ENV_CLASS` should be set to in the shell script and in the preprocessing step.
61 |
62 | * **Ray Actor Creation**: Before registration, the script calls:
63 |
64 | ```python
65 | get_or_create_benchmax_env_actor(MathEnvLocal)
66 | ```
67 |
68 | This:
69 |
70 | 1. Imports your environment class (`from benchmax.envs.math.math_env import MathEnvLocal`).
71 | 2. Starts a persistent Ray actor (`BenchmaxEnvActor`) that wraps the environment.
72 | 3. Names the actor `BaseEnvService`, so `load_benchmax_env_skyrl` can attach to it when Gym launches the env.
73 |
74 | * **Putting it together**:
75 | `ENV_CLASS` → matches Gym registry `id` → which maps to `load_benchmax_env_skyrl` → which connects to the Ray actor created from your imported environment class.
76 |
77 | To run:
78 |
79 | ```bash
80 | bash examples/skyrl/run_benchmax_math.sh
81 | ```
82 |
83 | ---
84 |
85 | ## 3. Using Your Own Benchmax Environment with SkyRL
86 |
87 | To integrate a new `benchmax` environment:
88 |
89 | 1. **Create or select your `benchmax` environment class**
90 |
91 | * Must subclass `benchmax.BaseEnv` (or `benchmax.mcp.ParallelMcpEnv` for multi-node support) and implement required methods.
92 | * Add tools, rewards, and `system_prompt` as needed.
93 |
94 | 2. **Update the SkyRL entrypoint**
95 |
96 | * In your copy of `benchmax_math.py`, change:
97 |
98 | ```python
99 | from benchmax.envs.math.math_env import MathEnvLocal
100 | get_or_create_benchmax_env_actor(MathEnvLocal)
101 | ```
102 |
103 |    to import and use your environment class. To run MCP servers on multiple external nodes, change the env to `MathEnvSkypilot`. A hypothetical example follows.
104 |
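   For example, with a hypothetical module and class name standing in for your own environment:

   ```python
   from my_project.envs.my_env import MyEnvLocal  # hypothetical import path

   get_or_create_benchmax_env_actor(MyEnvLocal)
   ```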
105 | 3. **Update the registry ID & shell script**
106 |
107 | * Change the `id` in `register(...)` to match your environment name.
108 | * Update `ENV_CLASS` in your run script to match this ID.
109 |
110 | 4. **Update dataset preprocessing**
111 |
112 | * Use `--env_path` pointing to your environment class in the preprocessing command.
113 |
114 | ---
115 |
116 | Once set up, SkyRL will:
117 |
118 | * Launch your environment in a Ray actor
119 | * Register it in the SkyRL gym
120 | * Fine-tune your model via PPO with configurable multi-turn RL training
--------------------------------------------------------------------------------
/src/benchmax/envs/crm/crm_env.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Any, Dict, List, Optional
3 | import sky
4 |
5 | from benchmax.envs.mcp.parallel_mcp_env import ParallelMcpEnv
6 | from benchmax.envs.mcp.provisioners.base_provisioner import BaseProvisioner
7 | from benchmax.envs.mcp.provisioners.local_provisioner import LocalProvisioner
8 | from benchmax.envs.mcp.provisioners.skypilot_provisioner import SkypilotProvisioner
9 | from benchmax.envs.types import StandardizedExample
10 |
11 |
12 | SYSTEM_PROMPT = """\
13 | You are an expert in Salesforce and you have access to a Salesforce instance.
14 |
15 | # Instructions
16 | - You will be provided a question, the system description, and relevant task context.
17 | - Interact with the Salesforce instance using the tools provided to help you answer the question.
18 | - You should ALWAYS make ONLY ONE tool call at a time. If you want to submit your final answer, just respond with the answer without tool call. If not, you should call some other tool.
19 | - Always end by responding with ONLY the answer, NO full sentence or any explanation.
20 | - If your answer is empty, i.e. there are no records found matching the requirements mentioned, just return 'None' as the response.
21 | - You should be able to get to an answer within 2-3 tool calls, so don't overthink.
22 |
23 | Write your complete answer on the final line, within the xml tags <answer></answer>. If there are multiple answers, use comma as a delimiter.
24 | e.g.
25 | For Case IDs, final answer should look like <answer>0XA124XDF</answer>. If there are multiple, it could look like <answer>0XA124XDF, 001XX000003GXXX</answer>
26 | For Months, it could look like <answer>May,July</answer>
27 | If nothing matches, output <answer>None</answer>
28 | """
29 |
30 |
31 | class CRMExample(StandardizedExample):
32 | reward_metric: str
33 |
34 |
35 | class CRMEnv(ParallelMcpEnv):
36 | """Environment for CRM tasks using MCP with Salesforce"""
37 |
38 | system_prompt: str = SYSTEM_PROMPT
39 |
40 | def __init__(self, workdir_path: Path, provisioner: BaseProvisioner, **kwargs):
41 | """Initialize CRMEnv."""
42 | super().__init__(workdir_path, provisioner, **kwargs)
43 |
44 | @classmethod
45 | def dataset_preprocess(cls, example: Any, **kwargs) -> CRMExample:
46 | # convert dataset example into CRMExample (inherit from StandardizedExample)
47 | task: Optional[str] = example.get("task")
48 | persona: Optional[str] = example.get("persona")
49 | metadata: Optional[Dict[str, Any]] = example.get("metadata")
50 | answer: Optional[List[str]] = example.get("answer")
51 | query: Optional[str] = example.get("query")
52 | reward_metric: Optional[str] = example.get("reward_metric")
53 |
54 | if not task or not persona or not query or answer is None or not reward_metric:
55 | raise ValueError(
56 | "Example must contain 'task', 'persona', 'query', 'answer', and 'reward_metric' fields"
57 | )
58 |
59 | prompt = f"{persona}\n{task}\n{query}"
60 | if metadata and "required" in metadata:
61 | required_metadata = metadata["required"]
62 | prompt = f"{persona}\n{task}\n{required_metadata}\n{query}"
63 |
64 | return CRMExample(
65 | prompt=prompt,
66 | ground_truth=answer,
67 | init_rollout_args=None,
68 | reward_metric=reward_metric,
69 | )
70 |
71 |
72 | class CRMEnvLocal(CRMEnv):
73 | """Import this env to run environment locally"""
74 |
75 | def __init__(self, num_local_servers: int = 5, **kwargs):
76 | workdir_path = Path(__file__).parent / "workdir"
77 | provisioner = LocalProvisioner(
78 | workdir_path=workdir_path, num_servers=num_local_servers
79 | )
80 | super().__init__(workdir_path=workdir_path, provisioner=provisioner, **kwargs)
81 |
82 |
83 | class CRMEnvSkypilot(CRMEnv):
84 | """Import this env to run environment on any cloud (i.e. AWS / GCP / Azure) with Skypilot"""
85 |
86 | def __init__(
87 | self,
88 | cloud: sky.clouds.Cloud = sky.Azure(),
89 | num_nodes: int = 2,
90 | servers_per_node: int = 5,
91 | **kwargs,
92 | ):
93 | workdir_path = Path(__file__).parent / "workdir"
94 | provisioner = SkypilotProvisioner(
95 | workdir_path=workdir_path,
96 | cloud=cloud,
97 | num_nodes=num_nodes,
98 | servers_per_node=servers_per_node,
99 | )
100 | super().__init__(workdir_path=workdir_path, provisioner=provisioner, **kwargs)
101 |
--------------------------------------------------------------------------------
/src/benchmax/envs/mcp/example_workdir/demo_mcp_server.py:
--------------------------------------------------------------------------------
1 | """
2 | Demo MCP server for calculator-style workflow.
3 |
4 | Tools:
5 | - hello_world: stateless sanity check
6 | - define_variable / get_variable: in-memory calculator variables
7 | - evaluate: arithmetic using stored variables
8 | - append_log / read_log: workspace file I/O
9 | - allocate_memory: stress test
10 | """
11 |
12 | from pathlib import Path
13 | from typing import Dict
14 | from fastmcp import FastMCP
15 | from fastmcp.exceptions import ToolError
16 |
17 | # ----------------------------------------------------------------------
18 | # MCP server setup
19 | # ----------------------------------------------------------------------
20 |
21 | mcp = FastMCP("demo-calculator-server")
22 |
23 | # In-memory state
24 | _variables: Dict[str, float] = {}
25 |
26 | # Memory stress
27 | _leaked_memory = []
28 |
29 | # ----------------------------------------------------------------------
30 | # Tools
31 | # ----------------------------------------------------------------------
32 |
33 |
34 | @mcp.tool()
35 | async def hello_world(name: str) -> str:
36 | """Simple stateless greeting."""
37 | return f"Hello, {name}!"
38 |
39 |
40 | def is_valid_var_name(name: str) -> bool:
41 | """Check if a variable name is valid: start with letter/_ and contain only letters, digits, or _"""
42 | if not name:
43 | return False
44 | if not (name[0].isalpha() or name[0] == "_"):
45 | return False
46 | return all(c.isalnum() or c == "_" for c in name)
47 |
48 |
49 | @mcp.tool()
50 | async def define_variable(name: str, value: float) -> str:
51 | """Store a named variable in memory."""
52 | if not is_valid_var_name(name):
53 | raise ToolError(f"Invalid variable name: '{name}'")
54 | _variables[name] = value
55 | return f"Variable '{name}' set to {value}"
56 |
57 |
58 | @mcp.tool()
59 | async def get_variable(name: str) -> str:
60 | """Retrieve a variable from memory."""
61 | if not is_valid_var_name(name):
62 | raise ToolError(f"Invalid variable name: '{name}'")
63 | if name not in _variables:
64 | raise ToolError(f"Variable '{name}' not defined")
65 | return str(_variables[name])
66 |
67 |
68 | @mcp.tool()
69 | async def evaluate(expression: str) -> str:
70 | """
71 | Evaluate arithmetic expression using stored variables.
72 |
73 | Only numbers and defined variable names are allowed.
74 | """
75 | allowed_names = {k: v for k, v in _variables.items()}
76 | allowed_chars = "0123456789+-*/()., "
77 |
78 | # Split expression manually into potential identifiers and other tokens
79 | token = ""
80 | for c in expression + " ": # add space to flush last token
81 | if c.isalnum() or c == "_":
82 | token += c
83 | else:
84 | if token:
85 | if token[0].isalpha() or token[0] == "_":
86 | if token not in allowed_names:
87 | raise ToolError(f"Undefined variable: '{token}'")
88 | if not is_valid_var_name(token):
89 | raise ToolError(f"Invalid variable name: '{token}'")
90 | # numeric token is implicitly allowed
91 | token = ""
92 | if c not in allowed_chars and not c.isspace():
93 | raise ToolError(f"Invalid character in expression: '{c}'")
94 |
95 | try:
96 | result = eval(expression, {"__builtins__": {}}, allowed_names)
97 | return str(result)
98 | except Exception as e:
99 | raise ToolError(f"Evaluation error: {str(e)}")
100 |
101 |
102 | @mcp.tool()
103 | async def append_log(filename: str, message: str) -> str:
104 | """Append a message to a workspace file."""
105 | file_path = Path(filename)
106 | with open(file_path, "a") as f:
107 | f.write(message + "\n")
108 | return f"Appended message to {filename}"
109 |
110 |
111 | @mcp.tool()
112 | async def read_log(filename: str) -> str:
113 | """Read the content of a workspace file."""
114 | file_path = Path(filename)
115 | if not file_path.exists():
116 | raise ToolError(f"File '{filename}' not found")
117 | return file_path.read_text()
118 |
119 |
120 | @mcp.tool()
121 | async def allocate_memory(megabytes: int) -> str:
122 | """Allocate memory to simulate stress / OOM."""
123 | global _leaked_memory
124 | size = megabytes * 1024 * 1024
125 | _leaked_memory.append(bytearray(size))
126 | return f"Leaked {megabytes} MB (total allocations: {len(_leaked_memory)})"
127 |
128 |
129 | if __name__ == "__main__":
130 | # Run the server
131 | mcp.run()
132 |
--------------------------------------------------------------------------------
/tests/unit/envs/mcp/provisioners/test_local.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pathlib import Path
3 | from unittest.mock import Mock, patch, MagicMock, AsyncMock
4 | from benchmax.envs.mcp.provisioners import LocalProvisioner
5 |
6 |
7 | # ---------------------------------------------------------------------------
8 | # Initialization
9 | # ---------------------------------------------------------------------------
10 |
11 |
12 | class TestLocalProvisionerInit:
13 | """Tests for LocalProvisioner initialization."""
14 |
15 | def test_init_requires_workdir_path(self, example_workdir: Path) -> None:
16 | """Provisioner initializes with given workdir and defaults."""
17 | LocalProvisioner(workdir_path=example_workdir)
18 |
19 |
20 | # ---------------------------------------------------------------------------
21 | # Provisioning
22 | # ---------------------------------------------------------------------------
23 |
24 |
25 | class TestLocalProvisionerProvision:
26 | """Tests for LocalProvisioner.provision_servers."""
27 |
28 | @pytest.mark.asyncio
29 | @patch("benchmax.envs.mcp.provisioners.local_provisioner.setup_sync_dir")
30 | @patch.object(LocalProvisioner, "_spawn_process", new_callable=AsyncMock)
31 | async def test_provision_runs_expected_commands(
32 | self,
33 | mock_spawn: AsyncMock,
34 | mock_setup: Mock,
35 | example_workdir: Path,
36 | test_sync_dir: Path,
37 | ) -> None:
38 | """Provisioner should prepare sync dir and start subprocesses."""
39 | mock_setup.return_value = test_sync_dir
40 | mock_proc = MagicMock()
41 | mock_proc.pid = 1234
42 | mock_spawn.return_value = mock_proc
43 |
44 | provisioner = LocalProvisioner(workdir_path=example_workdir, num_servers=2)
45 | addresses = await provisioner.provision_servers("dummy-api-secret")
46 |
47 | assert len(addresses) == 2
48 | mock_setup.assert_called_once_with(example_workdir)
49 | # First call is setup_cmd (wait=True), next calls are servers
50 | assert mock_spawn.call_count == 3
51 | assert addresses == ["localhost:8080", "localhost:8081"]
52 |
53 | @pytest.mark.asyncio
54 | @patch("benchmax.envs.mcp.provisioners.local_provisioner.setup_sync_dir")
55 | async def test_provision_handles_setup_failure(
56 | self, mock_setup: Mock, example_workdir: Path
57 | ) -> None:
58 | """setup_sync_dir errors are propagated."""
59 | mock_setup.side_effect = OSError("setup failed")
60 | provisioner = LocalProvisioner(workdir_path=example_workdir)
61 |
62 | with pytest.raises(OSError):
63 | await provisioner.provision_servers("dummy-api-secret")
64 |
65 |
66 | # ---------------------------------------------------------------------------
67 | # Teardown
68 | # ---------------------------------------------------------------------------
69 |
70 |
71 | class TestLocalProvisionerTeardown:
72 | """Tests for LocalProvisioner.teardown."""
73 |
74 | @pytest.mark.asyncio
75 | @patch("benchmax.envs.mcp.provisioners.local_provisioner.cleanup_dir")
76 | async def test_teardown_kills_processes_and_cleans_up(
77 | self, mock_cleanup: Mock, example_workdir: Path, test_sync_dir: Path
78 | ) -> None:
79 | """Active processes are killed and sync dir cleaned up."""
80 | provisioner = LocalProvisioner(workdir_path=example_workdir)
81 |
82 | proc = MagicMock()
83 | proc.poll.return_value = None
84 | proc.kill = MagicMock()
85 | proc.wait = MagicMock()
86 | provisioner._processes = [proc]
87 | provisioner._sync_dir = test_sync_dir
88 | provisioner._is_provisioned = True
89 |
90 | await provisioner.teardown()
91 | proc.kill.assert_called_once()
92 | proc.wait.assert_called_once()
93 | mock_cleanup.assert_called_once_with(test_sync_dir)
94 |
95 | @pytest.mark.asyncio
96 | @patch("benchmax.envs.mcp.provisioners.local_provisioner.cleanup_dir")
97 | async def test_teardown_skips_already_terminated(
98 | self, mock_cleanup: Mock, example_workdir: Path, test_sync_dir: Path
99 | ) -> None:
100 | """Teardown skips processes that are already terminated."""
101 | provisioner = LocalProvisioner(workdir_path=example_workdir)
102 | proc = MagicMock()
103 | proc.poll.return_value = 0 # Already exited
104 | proc.kill = MagicMock()
105 | provisioner._processes = [proc]
106 | provisioner._sync_dir = test_sync_dir
107 | provisioner._is_provisioned = True
108 |
109 | await provisioner.teardown()
110 | proc.kill.assert_not_called()
111 | mock_cleanup.assert_called_once_with(test_sync_dir)
112 |
--------------------------------------------------------------------------------
/tests/unit/envs/mcp/provisioners/test_provisioners_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for MCP utility functions.
3 | """
4 |
5 | import pytest
6 | import tempfile
7 | from pathlib import Path
8 |
9 | from benchmax.envs.mcp.provisioners.utils import (
10 | setup_sync_dir,
11 | cleanup_dir,
12 | )
13 |
14 |
15 | class TestSetupSyncDir:
16 | """Tests for setup_sync_dir function."""
17 |
18 | def test_setup_sync_dir_creates_temp_directory(self, example_workdir: Path):
19 | """Test that sync directory is created in temp location."""
20 | sync_dir = setup_sync_dir(example_workdir)
21 |
22 | try:
23 | assert sync_dir.exists()
24 | assert sync_dir.is_dir()
25 | assert "benchmax_skypilot_" in sync_dir.name
26 | finally:
27 | cleanup_dir(sync_dir)
28 |
29 | def test_setup_sync_dir_copies_proxy_server(self, example_workdir: Path):
30 | """Test that proxy_server.py is copied to sync directory."""
31 | sync_dir = setup_sync_dir(example_workdir)
32 |
33 | try:
34 | proxy_server = sync_dir / "proxy_server.py"
35 | assert proxy_server.exists()
36 | assert proxy_server.is_file()
37 | finally:
38 | cleanup_dir(sync_dir)
39 |
40 | def test_setup_sync_dir_copies_workdir_contents(self, example_workdir: Path):
41 | """Test that all workdir contents are copied."""
42 | sync_dir = setup_sync_dir(example_workdir)
43 |
44 | try:
45 | # Check required files
46 | assert (sync_dir / "mcp_config.yaml").exists()
47 | assert (sync_dir / "reward_fn.py").exists()
48 | assert (sync_dir / "setup.sh").exists()
49 | assert (sync_dir / "proxy_server.py").exists()
50 | finally:
51 | cleanup_dir(sync_dir)
52 |
53 | def test_setup_sync_dir_validates_required_files(self, tmp_path: Path):
54 | """Test that missing required files raise FileNotFoundError."""
55 | # Create incomplete workdir
56 | incomplete_workdir = tmp_path / "incomplete"
57 | incomplete_workdir.mkdir()
58 | (incomplete_workdir / "mcp_config.yaml").touch()
59 | # Missing reward_fn.py and setup.sh
60 |
61 | with pytest.raises(FileNotFoundError, match="reward_fn.py"):
62 | setup_sync_dir(incomplete_workdir)
63 |
64 | def test_setup_sync_dir_handles_nonexistent_workdir(self, tmp_path: Path):
65 | """Test that nonexistent workdir raises FileNotFoundError."""
66 | nonexistent = tmp_path / "does_not_exist"
67 |
68 | with pytest.raises(FileNotFoundError, match="workdir_path"):
69 | setup_sync_dir(nonexistent)
70 |
71 | def test_setup_sync_dir_handles_file_as_workdir(self, tmp_path: Path):
72 | """Test that file instead of directory raises ValueError."""
73 | file_path = tmp_path / "not_a_dir.txt"
74 | file_path.touch()
75 |
76 | with pytest.raises(ValueError, match="directory"):
77 | setup_sync_dir(file_path)
78 |
79 | def test_setup_sync_dir_cleanup_on_error(self, tmp_path: Path):
80 | """Test that sync directory is cleaned up if setup fails."""
81 | # Create workdir with required files but missing reward_fn.py
82 | incomplete_workdir = tmp_path / "workdir"
83 | incomplete_workdir.mkdir()
84 | (incomplete_workdir / "mcp_config.yaml").touch()
85 | (incomplete_workdir / "setup.sh").touch()
86 |
87 |         # This should fail because reward_fn.py doesn't exist in the workdir,
88 |         # and the temp sync directory should be cleaned up on failure
89 | with pytest.raises(FileNotFoundError):
90 | setup_sync_dir(incomplete_workdir)
91 |
92 |
93 | class TestCleanupDir:
94 | """Tests for cleanup_dir function."""
95 |
96 | def test_cleanup_dir_removes_directory(self):
97 | """Test that cleanup_dir removes an existing directory."""
98 | temp_dir = Path(tempfile.mkdtemp(prefix="test_cleanup_"))
99 | assert temp_dir.exists()
100 |
101 | cleanup_dir(temp_dir)
102 | assert not temp_dir.exists()
103 |
104 | def test_cleanup_dir_handles_none(self):
105 | """Test that cleanup_dir handles None gracefully."""
106 | cleanup_dir(None) # Should not raise
107 |
108 | def test_cleanup_dir_handles_nonexistent(self, tmp_path: Path):
109 | """Test that cleanup_dir handles nonexistent path gracefully."""
110 | nonexistent = tmp_path / "does_not_exist"
111 | cleanup_dir(nonexistent) # Should not raise
112 |
113 | def test_cleanup_dir_handles_file(self, tmp_path: Path):
114 | """Test that cleanup_dir ignores files (only removes directories)."""
115 | file_path = tmp_path / "file.txt"
116 | file_path.touch()
117 |
118 | cleanup_dir(file_path) # Should not raise
119 | assert file_path.exists() # File should still exist
120 |
--------------------------------------------------------------------------------
/src/benchmax/envs/crm/README.md:
--------------------------------------------------------------------------------
1 | # CRM Environment
2 |
3 | This environment provides capabilities for interacting with Salesforce instances through the Salesforce API.
4 |
5 | This is based off [CRMArena Pro](https://github.com/SalesforceAIResearch/CRMArena) & supports both B2B and B2C configurations.
6 |
7 | ## Prerequisites
8 |
9 | **Important**: Before using this environment, ensure you have:
10 | - Python 3.12 or higher
11 |
12 | ## Installation
13 |
14 | ```bash
15 | pip install "benchmax[crm,skypilot]"
16 | ```
17 |
18 | Includes:
19 | - simple-salesforce: For Salesforce API interactions
20 | - python-dateutil: For date/time handling
21 | - fastmcp: For MCP server functionality
22 |
23 | ## Usage
24 |
25 | Use `CRMEnvLocal` to run the servers locally on the same machine as benchmax or use `CRMEnvSkypilot` to parallelize the servers across multiple nodes.
26 |
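A minimal instantiation sketch (the constructor arguments mirror the defaults in `crm_env.py`):

```python
from benchmax.envs.crm.crm_env import CRMEnvLocal, CRMEnvSkypilot

# Run 5 MCP servers locally, next to benchmax
env = CRMEnvLocal(num_local_servers=5)

# Or fan the servers out across cloud nodes via SkyPilot
# import sky
# env = CRMEnvSkypilot(cloud=sky.Azure(), num_nodes=2, servers_per_node=5)
```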
27 | ## Configuration
28 |
29 | The environment has built-in Salesforce configurations:
30 |
31 | ### B2B Configuration
32 | Default configuration using B2B Salesforce instance. Used automatically when initializing CRMEnv.
33 |
34 | ### B2C Configuration
35 | Alternative configuration for B2C use cases, can be enabled by passing "b2c" to get_mcp_config().
36 |
37 | ## Available Tools
38 |
39 | The environment provides a comprehensive set of MCP tools for Salesforce interactions:
40 |
41 | ### Case Management
42 | - `get_cases`: Retrieve cases based on various filtering criteria (dates, agents, statuses)
43 | - `get_non_transferred_case_ids`: Get cases not transferred between agents in a period
44 | - `get_agent_handled_cases_by_period`: Get number of cases handled by each agent
45 | - `get_agent_transferred_cases_by_period`: Get number of cases transferred between agents
46 | - `get_livechat_transcript_by_case_id`: Retrieve live chat transcripts for a case
47 |
48 | ### Agent Analysis
49 | - `get_qualified_agent_ids_by_case_count`: Filter agents based on case handling count
50 | - `get_agents_with_max_cases`: Find agents with most cases in a subset
51 | - `get_agents_with_min_cases`: Find agents with fewest cases in a subset
52 | - `calculate_average_handle_time`: Calculate average case handling time per agent
53 |
54 | ### Regional Analysis
55 | - `get_shipping_state`: Add shipping state information to cases
56 | - `calculate_region_average_closure_times`: Calculate average case closure times by region
57 |
58 | ### Time Period Management
59 | - `get_start_date`: Calculate start date based on period and interval
60 | - `get_period`: Get date range for named periods (months, quarters, seasons)
61 | - `get_month_to_case_count`: Count cases created in each month
62 |
63 | ### Product and Issue Management
64 | - `search_products`: Search for products by name/description
65 | - `get_purchase_history`: Get purchase history for account/products
66 | - `get_issues`: Retrieve list of issue records
67 | - `get_issue_counts`: Get issue counts for products in a time period
68 | - `get_order_item_ids_by_product`: Get order items for a product
69 |
70 | ### Knowledge Base
71 | - `search_knowledge_articles`: Search knowledge articles by term
72 |
73 | ### Account Management
74 | - `get_account_id_by_contact_id`: Get Account ID for a Contact
75 |
76 | ### Utility Functions
77 | - `find_id_with_max_value`: Find IDs with maximum value in a dataset
78 | - `find_id_with_min_value`: Find IDs with minimum value in a dataset
79 | - `issue_soql_query`: Execute custom SOQL queries
80 | - `issue_sosl_query`: Execute custom SOSL queries
81 |
82 | ## Features
83 |
84 | - Seamless interaction with Salesforce instances
85 | - Support for both B2B and B2C configurations
86 | - Robust answer parsing and evaluation through fuzzy matching
87 | - Standardized example preprocessing for dataset handling
88 |
89 | ## Reward Functions
90 |
91 | The environment provides two reward metrics for evaluating model completions:
92 |
93 | ### Exact Match
94 | - Uses IoU (Intersection over Union) score (see the sketch below)
95 | - Compares completion tokens with ground truth tokens
96 | - Perfect score (1.0) for exact matches
97 | - Partial score based on token overlap
98 | - Returns 0.0 if one set is empty while other isn't
99 |
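A rough sketch of that IoU computation (illustrative; the full implementation in `workdir/reward_fn.py` also extracts the `<answer>` span and normalizes tokens first):

```python
def iou_score(pred_tokens: set, truth_tokens: set) -> float:
    """Intersection-over-union of two token sets, as used for the exact-match reward."""
    if not pred_tokens and not truth_tokens:
        return 1.0  # both empty: perfect match
    if not pred_tokens or not truth_tokens:
        return 0.0  # one empty, one not
    return len(pred_tokens & truth_tokens) / len(pred_tokens | truth_tokens)
```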
100 | ### Fuzzy Match
101 | - Uses F1 score for more lenient evaluation
102 | - Handles variations in text formatting and word order
103 | - Normalizes text by removing punctuation and articles
104 | - Suitable for cases where exact matching is too strict
105 |
106 | ## MCP Server
107 |
108 | The environment runs a Salesforce MCP server that handles API interactions. The server:
109 | - Manages authentication with Salesforce
110 | - Provides secure access to Salesforce operations
111 | - Handles API rate limiting and session management
112 |
113 |
114 | ## CRMArena-Pro Data
115 | 
116 | The underlying CRMArena-Pro queries and schemas can be loaded from HuggingFace:
117 | 
118 | ```python
119 | from datasets import load_dataset
120 | 
121 | crmarena_pro_queries = load_dataset("Salesforce/CRMArenaPro", "CRMArenaPro")
122 | b2b_schema = load_dataset("Salesforce/CRMArenaPro", "b2b_schema")
123 | b2c_schema = load_dataset("Salesforce/CRMArenaPro", "b2c_schema")
124 | ```
125 | 
126 | TODO: Link to HuggingFace dataset
--------------------------------------------------------------------------------
/src/benchmax/envs/base_env.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Dict, List, Any, Optional, Tuple
3 | from pathlib import Path
4 | from datasets import (
5 | DatasetDict,
6 | Dataset,
7 | IterableDatasetDict,
8 | IterableDataset,
9 | load_dataset,
10 | )
11 |
12 | from benchmax.envs.types import ToolDefinition, StandardizedExample
13 | from benchmax.prompts.tools import render_tools_prompt
14 |
15 |
16 | class BaseEnv(ABC):
17 | """Base benchmax environment for tool execution and reward computation"""
18 |
19 | system_prompt: str = ""
20 |
21 | @abstractmethod
22 | async def shutdown(self):
23 | pass
24 |
25 | # Override this method if your example does not match the default structure
26 | @classmethod
27 | def dataset_preprocess(cls, example: Any, **kwargs) -> StandardizedExample:
28 | """
29 | Preprocess a single dataset example into a dict with keys:
30 | - "prompt": str
31 | - "ground_truth": Any
32 | - "init_rollout_args": Dict[str, Any]
33 | """
34 | prompt = example.pop("prompt", "")
35 | ground_truth = example.pop("ground_truth", "")
36 |         init_rollout_args = example.pop("init_rollout_args", {})
37 | return StandardizedExample(
38 | prompt=prompt,
39 | ground_truth=ground_truth,
40 | init_rollout_args=init_rollout_args,
41 | **example,
42 | )
43 |
44 | @classmethod
45 | def load_dataset(
46 | cls, dataset_name: str, **kwargs
47 | ) -> Tuple[
48 | DatasetDict | Dataset | IterableDatasetDict | IterableDataset, str | None
49 | ]:
50 | """
51 | Download and prepare a dataset for use with this environment.
52 |
53 | This method should handle retrieving the specified dataset (e.g., from HuggingFace, local files,
54 | or a custom source), preprocessing or converting it into a compatible structure, and storing it
55 | locally in a reusable format. The processed dataset should be suitable for downstream use with
56 | `dataset_preprocess`, which standardizes individual examples into the expected format.
57 |
58 | Args:
59 | dataset_name (str): Identifier of the dataset to be loaded.
60 | **kwargs: Additional dataset-specific arguments (e.g., split, filtering options, cache directory).
61 |
62 | Returns:
63 | Dataset: A dataset object (e.g., HuggingFace Dataset or similar) ready for processing.
64 | str: Optional string pointing to where the dataset is stored locally
65 | """
66 | return load_dataset(dataset_name, **kwargs), None
67 |
68 | @abstractmethod
69 | async def list_tools(self) -> List[ToolDefinition]:
70 | """Return list of available tools"""
71 | pass
72 |
73 | @abstractmethod
74 | async def run_tool(self, rollout_id: str, tool_name: str, **tool_args) -> Any:
75 | """Execute named tool in rollout context with given arguments"""
76 | pass
77 |
78 | @abstractmethod
79 | async def init_rollout(self, rollout_id: str, **rollout_args) -> None:
80 | """Initialize resources for a new rollout"""
81 | pass
82 |
83 | @abstractmethod
84 | async def release_rollout(self, rollout_id: str) -> None:
85 | """Free up resources for a new rollout. Called by compute_reward internally but also available for cleanup."""
86 | pass
87 |
88 | @abstractmethod
89 | async def copy_to_workspace(
90 | self, rollout_id: str, src_path: Path, dst_filename: Optional[str] = None
91 | ) -> None:
92 | """Copy a file to the workspace for a specific rollout. If dst_filename is None, use the original filename."""
93 | pass
94 |
95 | @abstractmethod
96 | async def copy_content_to_workspace(
97 | self, rollout_id: str, src_content: str | bytes, dst_filename: str
98 | ) -> None:
99 | """Create a file with given content in the workspace for a specific rollout"""
100 | pass
101 |
102 | @abstractmethod
103 | async def copy_from_workspace(
104 | self, rollout_id: str, src_filename: str, dst_path: Path
105 | ) -> None:
106 | """Copy a file from the workspace for a specific rollout"""
107 | pass
108 |
109 | @abstractmethod
110 | async def compute_reward(
111 | self, rollout_id: str, completion: str, ground_truth: Any, **kwargs: Any
112 | ) -> Dict[str, float]:
113 | """Compute rewards using registered functions
114 |
115 | Returns dict mapping reward function names to their computed scores.
116 | """
117 | pass
118 |
119 | async def get_system_prompt(self, add_tool_defs: bool = False) -> str:
120 | """Get system prompt. To add tool definitions, set add_tool_defs to True."""
121 | if add_tool_defs:
122 | return render_tools_prompt(
123 | await self.list_tools(), self.system_prompt or ""
124 | )
125 | else:
126 | return self.system_prompt
127 |
--------------------------------------------------------------------------------
/src/benchmax/envs/crm/workdir/reward_fn.py:
--------------------------------------------------------------------------------
1 | import re
2 | import string
3 | from html import unescape
4 | from pathlib import Path
5 | from collections import Counter
6 | from typing import Any, Callable, Dict, List, Optional, Union, Awaitable
7 | from fastmcp import Client
8 |
9 | RewardFunction = Callable[..., Union[float, Awaitable[float]]]
10 |
11 |
12 | def parse_answers(proposed_answer: str) -> str:
13 | """
14 | Parse the proposed answer.
15 | """
16 | m = re.search(
17 |         r"<answer>(.*?)</answer>", proposed_answer, flags=re.IGNORECASE | re.DOTALL
18 | )
19 | if not m:
20 | proposed_answer = ""
21 | else:
22 |         # Unescape any XML entities (&amp; → &, etc.) and normalise whitespace.
23 | proposed_answer = unescape(m.group(1)).strip().lower()
24 | return proposed_answer
25 |
26 |
27 | def parse_text_to_tokens(text: str) -> set:
28 | """
29 | Parse text into normalized tokens using common separators.
30 |
31 | Args:
32 | text: Input text to parse
33 |
34 | Returns:
35 | set: Set of normalized tokens
36 | """
37 | if not text:
38 | return set()
39 |
40 | # Clean up the text by removing quotes and extra whitespace
41 | cleaned_text = text.strip().strip('"').strip("'").lower()
42 |
43 | # Split by common separators: spaces, commas, semicolons, pipes, tabs, newlines
44 | # Using regex to split on multiple separators
45 | tokens = re.split(r"[,\s|]+", cleaned_text)
46 |
47 | # Filter out empty tokens and normalize
48 | normalized_tokens = {token.strip() for token in tokens if token.strip()}
49 |
50 | return normalized_tokens
51 |
52 |
53 | def get_all_metrics(proposed_answer: str, ground_truth: str) -> float:
54 | """
55 | Compute fuzzy matching score between proposed answer and ground truth.
56 | Uses F1 score as the primary metric.
57 | """
58 |
59 | def normalize_answer(s):
60 | """Lower text and remove punctuation, articles and extra whitespace."""
61 |
62 | def remove_articles(text):
63 | return re.sub(r"\b(a|an|the)\b", " ", text)
64 |
65 | def white_space_fix(text):
66 | return " ".join(text.split())
67 |
68 | def handle_punc(text):
69 |             exclude = set(string.punctuation + "".join(["‘", "’", "´", "`"]))
70 | return "".join(ch if ch not in exclude else " " for ch in text)
71 |
72 | def lower(text):
73 | return text.lower()
74 |
75 | def replace_underscore(text):
76 | return text.replace("_", " ")
77 |
78 | return white_space_fix(
79 | remove_articles(handle_punc(lower(replace_underscore(s))))
80 | ).strip()
81 |
82 | def f1_score(prediction, ground_truth):
83 | prediction_tokens = normalize_answer(prediction).split()
84 | ground_truth_tokens = normalize_answer(ground_truth).split()
85 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
86 | num_same = sum(common.values())
87 | if num_same == 0:
88 | return 0
89 | precision = 1.0 * num_same / len(prediction_tokens)
90 | recall = 1.0 * num_same / len(ground_truth_tokens)
91 | f1 = (2 * precision * recall) / (precision + recall)
92 | return f1
93 |
94 | return f1_score(proposed_answer, ground_truth)
95 |
96 |
97 | def crm_matching_reward_function(
98 | completion: str,
99 | ground_truth: List[str],
100 | mcp_client: Client,
101 | workspace: Path,
102 | **kwargs: Any,
103 | ) -> float:
104 | """
105 | Reward function for CRM environment that evaluates model completions.
106 |
107 | Args:
108 |         completion: Model's generated completion/response
109 |         ground_truth: Expected/correct output (should be a list)
110 |         mcp_client: MCP client for the rollout (unused by this reward function)
111 |         workspace: Path to rollout's workspace
112 |         **kwargs: Additional context; must include "reward_metric"
113 |
114 | Returns:
115 | float: Reward score between 0 and 1
116 | """
117 | reward_metric: Optional[str] = kwargs.get("reward_metric")
118 |
119 | if not reward_metric:
120 | raise ValueError("kwargs must contain reward metric")
121 |
122 | proposed_answer = completion.strip() if completion else ""
123 | proposed_answer = parse_answers(proposed_answer)
124 |
125 | if reward_metric == "exact_match":
126 | # Parse and normalize the completion text
127 | completion_tokens = parse_text_to_tokens(proposed_answer)
128 |
129 | # Parse and normalize all ground truth items
130 | all_ground_truth_tokens = set()
131 | for gt_item in ground_truth:
132 | gt_tokens = parse_text_to_tokens(str(gt_item))
133 | all_ground_truth_tokens.update(gt_tokens)
134 |
135 | # Calculate IoU (Intersection over Union)
136 | if not all_ground_truth_tokens and not completion_tokens:
137 | return 1.0 # Both empty sets match perfectly
138 | elif not all_ground_truth_tokens or not completion_tokens:
139 | return 0.0 # One empty, one non-empty
140 |
141 | intersection = completion_tokens.intersection(all_ground_truth_tokens)
142 | union = completion_tokens.union(all_ground_truth_tokens)
143 |
144 | iou_score = len(intersection) / len(union) if union else 0.0
145 |
146 |         # IoU is 1.0 for a perfect match; partial overlap yields a fractional score
147 | return iou_score
148 |
149 | elif reward_metric == "fuzzy_match":
150 | # For fuzzy match, we only have 1 ground truth item
151 | if ground_truth[0] is not None:
152 | return get_all_metrics(proposed_answer, str(ground_truth[0]))
153 | else:
154 | return 0.0
155 |
156 | else:
157 | print(f"Unknown reward metric: {reward_metric}")
158 | return 0.0
159 |
160 |
161 | # -------------------------------
162 | # Export reward functions
163 | # -------------------------------
164 | reward_functions: Dict[str, RewardFunction] = {
165 | "match": crm_matching_reward_function,
166 | }
167 |
--------------------------------------------------------------------------------
/tests/integration/envs/mcp/conftest.py:
--------------------------------------------------------------------------------
1 | """
2 | Integration fixtures for MCP environment tests.
3 | These may start subprocesses, use real ports, or access the filesystem.
4 | """
5 |
6 | import uuid
7 | import pytest
8 | from pathlib import Path
9 | from typing import Tuple, List
10 | from collections.abc import AsyncGenerator
11 | import sky
12 |
13 | from benchmax.envs.mcp import ParallelMcpEnv
14 | from benchmax.envs.mcp.provisioners import (
15 | LocalProvisioner,
16 | ManualProvisioner,
17 | SkypilotProvisioner,
18 | )
19 |
20 | # ===== Session-scoped: Provision servers once for all tests =====
21 |
22 |
23 | @pytest.fixture(scope="session")
24 | async def local_servers_with_secret(
25 | example_workdir: Path,
26 | ) -> AsyncGenerator[Tuple[List[str], str], None]:
27 | """
28 | Provision local servers once for the entire test session.
29 | Returns (addresses, api_secret) tuple.
30 |
31 | These servers are reused across tests for speed.
32 | """
33 | api_secret = uuid.uuid4().hex
34 | provisioner = LocalProvisioner(
35 | workdir_path=example_workdir,
36 | num_servers=4,
37 | base_port=8080,
38 | )
39 |
40 | addresses = await provisioner.provision_servers(api_secret)
41 | print(f"\n[Session Setup] Provisioned {len(addresses)} local servers")
42 |
43 | yield addresses, api_secret
44 |
45 | print(f"\n[Session Teardown] Tearing down {len(addresses)} local servers")
46 | await provisioner.teardown()
47 |
48 |
49 | @pytest.fixture(scope="session")
50 | async def skypilot_servers_with_secret(
51 | example_workdir: Path,
52 | ) -> AsyncGenerator[Tuple[List[str], str], None]:
53 | """
54 | Provision Skypilot servers once for the entire test session.
55 | Returns (addresses, api_secret) tuple.
56 |
57 | These servers are reused across tests for speed.
58 |     Only provisioned if a test actually uses this fixture.
59 | """
60 |
61 | api_secret = uuid.uuid4().hex
62 | provisioner = SkypilotProvisioner(
63 | workdir_path=example_workdir,
64 | cloud=sky.Azure(),
65 | num_nodes=2,
66 | servers_per_node=4,
67 | base_cluster_name="test-cluster",
68 | cpus=2,
69 | memory=8
70 | )
71 |
72 | addresses = await provisioner.provision_servers(api_secret)
73 | print(f"\n[Session Setup] Provisioned {len(addresses)} Skypilot servers")
74 |
75 | yield addresses, api_secret
76 |
77 | print(f"\n[Session Teardown] Tearing down {len(addresses)} Skypilot servers")
78 | await provisioner.teardown()
79 |
80 |
81 | # ===== Function-scoped: Fresh env for each test =====
82 |
83 |
84 | @pytest.fixture
85 | async def local_env(
86 | local_servers_with_secret: Tuple[List[str], str], example_workdir: Path
87 | ) -> AsyncGenerator[ParallelMcpEnv, None]:
88 | """
89 | Create a fresh ParallelMcpEnv using reused local servers.
90 |
91 | Each test gets a clean env instance, but servers are shared
92 | across tests for speed. This means server state may be
93 | contaminated, but that's acceptable for most tests.
94 | """
95 | addresses, api_secret = local_servers_with_secret
96 |
97 | manual_provisioner = ManualProvisioner(addresses)
98 | env = ParallelMcpEnv(
99 | workdir_path=example_workdir,
100 | provisioner=manual_provisioner,
101 | api_secret=api_secret,
102 | provision_at_init=True,
103 | )
104 |
105 | yield env
106 |
107 | await env.shutdown()
108 |
109 |
110 | @pytest.fixture
111 | async def skypilot_env(
112 | skypilot_servers_with_secret: Tuple[List[str], str], example_workdir: Path
113 | ) -> AsyncGenerator[ParallelMcpEnv, None]:
114 | """
115 | Create a fresh ParallelMcpEnv using reused Skypilot servers.
116 |
117 | Each test gets a clean env instance, but servers are shared
118 | across tests for speed. Mark tests using this with @pytest.mark.remote.
119 | """
120 | addresses, api_secret = skypilot_servers_with_secret
121 |
122 | manual_provisioner = ManualProvisioner(addresses)
123 | env = ParallelMcpEnv(
124 | workdir_path=example_workdir,
125 | provisioner=manual_provisioner,
126 | api_secret=api_secret,
127 | provision_at_init=True,
128 | )
129 |
130 | yield env
131 |
132 | await env.shutdown()
133 |
134 |
135 | @pytest.fixture
136 | async def fresh_local_env(example_workdir: Path) -> AsyncGenerator[ParallelMcpEnv, None]:
137 | """
138 | Create a fresh ParallelMcpEnv with its own dedicated servers.
139 |
140 | Use this for tests that need clean server state.
141 | Slower than local_env but provides isolation.
142 | Mark tests using this with @pytest.mark.slow.
143 | """
144 | provisioner = LocalProvisioner(
145 | workdir_path=example_workdir,
146 | num_servers=4,
147 | base_port=9080, # Different port range to avoid conflicts
148 | )
149 |
150 | env = ParallelMcpEnv(
151 | workdir_path=example_workdir,
152 | provisioner=provisioner,
153 | provision_at_init=True,
154 | )
155 |
156 | yield env
157 |
158 | await env.shutdown()
159 |
160 |
161 | @pytest.fixture
162 | async def fresh_skypilot_env(
163 | example_workdir: Path,
164 | ) -> AsyncGenerator[ParallelMcpEnv, None]:
165 | """
166 | Create a fresh ParallelMcpEnv with its own dedicated Skypilot servers.
167 |
168 | Use this for E2E tests that need clean server state on cloud infrastructure.
169 | Slower and more expensive than skypilot_env.
170 | Mark tests using this with @pytest.mark.remote and @pytest.mark.slow.
171 | """
172 | provisioner = SkypilotProvisioner(
173 | workdir_path=example_workdir,
174 | cloud=sky.Azure(),
175 | num_nodes=2,
176 | servers_per_node=4,
177 | base_cluster_name="test-cluster-fresh",
178 | cpus=2,
179 | memory=8,
180 | )
181 |
182 | env = ParallelMcpEnv(
183 | workdir_path=example_workdir,
184 | provisioner=provisioner,
185 | provision_at_init=True,
186 | )
187 |
188 | yield env
189 |
190 | await env.shutdown()
191 |
192 |
--------------------------------------------------------------------------------
/tests/integration/adapters/skyrl/skyrl_adapter_integration.py:
--------------------------------------------------------------------------------
1 | # import pytest
2 | # import ray
3 |
4 | # from benchmax.adapters.skyrl.skyrl_adapter import load_benchmax_env_skyrl, RemoteBaseEnvProxy, expose
5 |
6 | # # ---- Fixtures and dummies ----
7 | # class DummyActor:
8 | # """Dummy Ray actor to simulate remote Benchmax service."""
9 | # def __init__(self):
10 | # self.calls = []
11 |
12 | # class Method:
13 | # def __init__(self, name, parent):
14 | # self.name = name
15 | # self.parent = parent
16 | # def remote(self, *args, **kwargs):
17 | # self.parent.calls.append((self.name, args, kwargs))
18 | # # return a sentinel value
19 | # if self.name == 'compute_reward':
20 | # return {'a': 1.0, 'b': 2.0}
21 | # return f"{self.name}-result"
22 |
23 | # @property
24 | # def list_tools(self):
25 | # return DummyActor.Method('list_tools', self)
26 |
27 | # @property
28 | # def compute_reward(self):
29 | # return DummyActor.Method('compute_reward', self)
30 |
31 | # @property
32 | # def run_tool(self):
33 | # return DummyActor.Method('run_tool', self)
34 |
35 | # @property
36 | # def init_rollout(self):
37 | # return DummyActor.Method('init_rollout', self)
38 |
39 | # @property
40 | # def cleanup_rollout(self):
41 | # return DummyActor.Method('cleanup_rollout', self)
42 |
43 | # @pytest.fixture(autouse=True)
44 | # def patch_ray(monkeypatch):
45 | # dummy = DummyActor()
46 | # monkeypatch.setattr(ray, 'get_actor', lambda name: dummy)
47 | # monkeypatch.setattr(ray, 'get', lambda x: x)
48 | # return dummy
49 |
50 | # # ---- Tests for RemoteBaseEnvProxy ----
51 | # def test_list_tools_proxy(patch_ray):
52 | # proxy = RemoteBaseEnvProxy(actor_name='BenchmaxEnvService', rollout_id='rid')
53 | # result = proxy.list_tools()
54 | # assert result == 'list_tools-result'
55 | # # Ensure no rollout_id injected for list_tools
56 | # assert patch_ray.calls == [('list_tools', (), {})]
57 |
58 |
59 | # def test_compute_reward_proxy(patch_ray):
60 | # proxy = RemoteBaseEnvProxy(actor_name='BenchmaxEnvService', rollout_id='RID123')
61 | # reward = proxy.compute_reward('task1', 'action1', {'gt': True})
62 | # assert reward == {'a': 1.0, 'b': 2.0}
63 | # # rollout_id should be first arg
64 | # name, args, kwargs = patch_ray.calls[-1]
65 | # assert name == 'compute_reward'
66 | # assert args[0] == 'RID123'
67 |
68 | # # ---- Tests for _call_tool ----
69 | # class DummyEnv:
70 | # def __init__(self):
71 | # self.benchmax_env = RemoteBaseEnvProxy(actor_name='BenchmaxEnvService')
72 | # self.extras = {}
73 | # _call_tool = load_benchmax_env_skyrl.__wrapped__.__defaults__[0]._call_tool if False else None
74 |
75 | # @pytest.fixture
76 | # def configured_env(patch_ray):
77 | # # Build a minimal SkyRL env
78 | # cfg = {}
79 | # extras = {'init_rollout_args': {}, 'task': 'T1', 'ground_truth': {}}
80 | # env = load_benchmax_env_skyrl(actor_name='BenchmaxEnvService', env_config=cfg, extras=extras)
81 | # return env
82 |
83 |
84 | # def test_call_tool_errors(configured_env):
85 | # # Not dict
86 | # assert configured_env._call_tool('not_a_dict') == "Error: Tool command must be a JSON object."
87 | # # Missing name
88 | # assert configured_env._call_tool({'arguments': {}}) == "Error: Missing 'name' field in tool command."
89 | # # Arguments not dict
90 | # assert "must be a JSON object" in configured_env._call_tool({'name': 'foo', 'arguments': 'bad'})
91 |
92 |
93 | # def test_call_tool_success_and_truncate(configured_env):
94 | # # Valid call
95 | # out = configured_env._call_tool({'name': 'run_tool', 'arguments': {'x': 1}})
96 | # assert out.startswith('run_tool-result')
97 | # # Truncate
98 | # long_res = configured_env._call_tool({'name': 'run_tool', 'arguments': {}}, max_chars=5)
99 | # assert long_res.endswith('...')
100 |
101 | # # ---- Tests for step() ----
102 |
103 | # @ pytest.fixture(autouse=True)
104 | # def patch_parse(monkeypatch):
105 | # # default parse returns no tools
106 | # monkeypatch.setattr('benchmax.prompts.tools.parse_hermes_tool_call', lambda x: [])
107 |
108 |
109 | # def test_step_final_reward(configured_env, patch_parse):
110 | # # parse returns [], so done=True
111 | # out = configured_env.step('final answer')
112 | # assert out["done"] is True
113 | # assert out["reward"] == 3.0
114 | # assert out["observations"] == []
115 |
116 |
117 | # def test_step_tool_flow(monkeypatch, configured_env):
118 | # # parse returns one tool call
119 | # monkeypatch.setattr('benchmax.prompts.tools.parse_hermes_tool_call', lambda x: [{'name': 'run_tool', 'arguments': {}}])
120 | # # Stub _call_tool
121 | # monkeypatch.setattr(configured_env, '_call_tool', lambda call: 'obs-text')
122 | # out = configured_env.step('{"name": "calculate", "arguments": {"expression": "25 ÷ 5 + 4 × 3"}}')
123 | # assert out["done"] is False
124 | # assert out["reward"] == 0.0
125 | # assert out["observations"] == [{'role': 'user', 'content': 'obs-text'}]
126 |
127 |
128 | # def test_step_tool_error(monkeypatch, configured_env):
129 | # monkeypatch.setattr('benchmax.prompts.tools.parse_hermes_tool_call', lambda x: [{'name': 'run_tool', 'arguments': {}}])
130 | # def bad_call(call):
131 | # raise RuntimeError('fail')
132 | # monkeypatch.setattr(configured_env, '_call_tool', bad_call)
133 | # out = configured_env.step('{"name": "calculate", "arguments": {"expression": "25 ÷ 5 + 4 × 3"}}')
134 | # print(out)
135 | # assert out["done"] is False
136 | # assert out["observations"][0]['content'] == 'fail'
137 |
138 | # # ---- Tests for expose decorator ----
139 | # import asyncio
140 | # class FakeEnv:
141 | # async def foo(self, x):
142 | # return x * 2
143 |
144 | # @pytest.mark.asyncio
145 | # async def test_expose_wrapper():
146 | # wrapper = expose('foo')
147 | # class Host:
148 | # def __init__(self):
149 | # self.env = FakeEnv()
150 | # host = Host()
151 | # res = await wrapper(host, 10)
152 | # assert res == 20
--------------------------------------------------------------------------------
/tests/unit/envs/mcp/test_mcp_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for MCP utility functions.
3 | """
4 |
5 | import jwt
6 | from mcp import Tool
7 | from benchmax.envs.mcp.utils import (
8 | convert_tool_definitions,
9 | generate_jwt_token,
10 | get_auth_headers,
11 | )
12 | from benchmax.envs.types import ToolDefinition
13 |
14 |
15 | class TestConvertToolDefinitions:
16 | """Tests for convert_tool_definitions function."""
17 |
18 | def test_convert_tool_definitions_basic(self):
19 | """Test basic conversion of MCP tools to ToolDefinitions."""
20 | mcp_tools = [
21 | Tool(
22 | name="read_file",
23 | description="Read a file",
24 | inputSchema={
25 | "type": "object",
26 | "properties": {"path": {"type": "string"}},
27 | },
28 | ),
29 | Tool(
30 | name="write_file",
31 | description="Write to a file",
32 | inputSchema={
33 | "type": "object",
34 | "properties": {"path": {"type": "string"}},
35 | },
36 | ),
37 | ]
38 |
39 | result = convert_tool_definitions(mcp_tools, allowed_tools=None)
40 |
41 | assert len(result) == 2
42 | assert all(isinstance(t, ToolDefinition) for t in result)
43 | assert result[0].name == "read_file"
44 | assert result[1].name == "write_file"
45 |
46 | def test_convert_tool_definitions_with_filter(self):
47 | """Test filtering tools with allowed_tools list."""
48 | mcp_tools = [
49 | Tool(name="read_file", description="Read", inputSchema={}),
50 | Tool(name="write_file", description="Write", inputSchema={}),
51 | Tool(name="execute", description="Execute", inputSchema={}),
52 | ]
53 |
54 | allowed = ["read_file", "write_file"]
55 | result = convert_tool_definitions(mcp_tools, allowed_tools=allowed)
56 |
57 | assert len(result) == 2
58 | assert all(t.name in allowed for t in result)
59 | assert not any(t.name == "execute" for t in result)
60 |
61 | def test_convert_tool_definitions_empty_description(self):
62 | """Test handling of tools with no description."""
63 | mcp_tools = [Tool(name="tool1", description=None, inputSchema={})]
64 |
65 | result = convert_tool_definitions(mcp_tools, allowed_tools=None)
66 |
67 | assert len(result) == 1
68 | assert result[0].description == ""
69 |
70 | def test_convert_tool_definitions_empty_list(self):
71 | """Test conversion of empty tool list."""
72 | result = convert_tool_definitions([], allowed_tools=None)
73 | assert result == []
74 |
75 |
76 | class TestGenerateJwtToken:
77 | """Tests for generate_jwt_token function."""
78 |
79 | def test_generate_jwt_token_contains_standard_claims(self):
80 | """Ensure generated JWT includes required standard claims."""
81 | secret = "secret"
82 | token = generate_jwt_token(secret)
83 | decoded = jwt.decode(
84 | token, secret, algorithms=["HS256"], audience="mcp-proxy-server"
85 | )
86 |
87 | assert decoded["iss"] == "mcp-client"
88 | assert decoded["aud"] == "mcp-proxy-server"
89 | assert "iat" in decoded
90 | assert "exp" in decoded
91 | assert decoded["exp"] > decoded["iat"]
92 |
93 | def test_generate_jwt_token_includes_rollout_id(self):
94 | """Ensure rollout_id is included if provided."""
95 | secret = "secret"
96 | token = generate_jwt_token(secret, rollout_id="rollout-123")
97 | decoded = jwt.decode(
98 | token, secret, algorithms=["HS256"], audience="mcp-proxy-server"
99 | )
100 |
101 | assert decoded["rollout_id"] == "rollout-123"
102 |
103 | def test_generate_jwt_token_includes_extra_claims(self):
104 | """Ensure extra custom claims are included."""
105 | secret = "secret"
106 | token = generate_jwt_token(secret, user="test_user", env="staging")
107 | decoded = jwt.decode(
108 | token, secret, algorithms=["HS256"], audience="mcp-proxy-server"
109 | )
110 |
111 | assert decoded["user"] == "test_user"
112 | assert decoded["env"] == "staging"
113 |
114 | def test_generate_jwt_token_expiration_respects_custom_value(self):
115 | """Ensure custom expiration_seconds is respected."""
116 | secret = "secret"
117 | token = generate_jwt_token(secret, expiration_seconds=60)
118 | decoded = jwt.decode(
119 | token, secret, algorithms=["HS256"], audience="mcp-proxy-server"
120 | )
121 |
122 | assert (
123 | abs(decoded["exp"] - decoded["iat"] - 60) <= 1
124 | ) # small clock drift margin
125 |
126 |
127 | class TestGetAuthHeaders:
128 | """Tests for get_auth_headers function."""
129 |
130 | def test_get_auth_headers_contains_bearer_prefix(self):
131 | """Ensure Authorization header has proper Bearer format."""
132 | headers = get_auth_headers("secret", rollout_id="r1")
133 | assert "Authorization" in headers
134 | assert headers["Authorization"].startswith("Bearer ")
135 |
136 | def test_get_auth_headers_token_decodes_correctly(self):
137 | """Ensure token inside header is valid and decodable."""
138 | secret = "secret"
139 | headers = get_auth_headers(secret, rollout_id="rollout-xyz")
140 | token = headers["Authorization"].split("Bearer ")[1]
141 |
142 | decoded = jwt.decode(
143 | token, secret, algorithms=["HS256"], audience="mcp-proxy-server"
144 | )
145 | assert decoded["rollout_id"] == "rollout-xyz"
146 |
147 | def test_get_auth_headers_includes_extra_claims(self):
148 | """Ensure extra claims propagate correctly through get_auth_headers."""
149 | secret = "secret"
150 | headers = get_auth_headers(secret, env="prod", user="alice")
151 | token = headers["Authorization"].split("Bearer ")[1]
152 |
153 | decoded = jwt.decode(
154 | token, secret, algorithms=["HS256"], audience="mcp-proxy-server"
155 | )
156 | assert decoded["env"] == "prod"
157 | assert decoded["user"] == "alice"
158 |
--------------------------------------------------------------------------------
/tests/integration/envs/excel/test_excel_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | import pytest
3 | from pathlib import Path
4 | from openpyxl import load_workbook
5 | from openpyxl.styles import PatternFill, Font
6 | from benchmax.envs.excel.workdir.excel_utils import compare_excel_cells, evaluate_excel
7 |
8 | @pytest.fixture(scope="session")
9 | def setup_files():
10 | base_excel_path = Path(__file__).parent / "test_inputs" / "test.xlsx"
11 | ground_truth_path = Path(__file__).parent / "test_inputs" / "test_gt.xlsx"
12 | output_path = Path(__file__).parent / "test_inputs" / "test_output.xlsx"
13 |
14 | # Create ground truth file
15 | wb_gt = load_workbook(base_excel_path)
16 | ws_gt = wb_gt.active
17 | assert ws_gt is not None, "Worksheet should not be None"
18 | ws_gt["H3"] = "Text Value"
19 | ws_gt["H4"] = 4
20 | ws_gt["H5"] = 2 # Numeric value
21 | ws_gt["H6"] = "Mismatch Value"
22 | ws_gt["I3"] = "Matching Value"
23 | ws_gt["I4"] = "Matching Value"
24 | ws_gt["I3"].fill = PatternFill(start_color="FFFF00", end_color="FFFF00", fill_type="solid") # Yellow
25 | ws_gt["I4"].font = Font(color="FF0000") # Red font
26 | wb_gt.save(ground_truth_path)
27 | wb_gt.close()
28 |
29 | # Create output file
30 | wb_output = load_workbook(base_excel_path)
31 | ws_output = wb_output.active
32 | assert ws_output is not None, "Worksheet should not be None"
33 | ws_output["H3"] = "Text Value"
34 | ws_output["H4"] = 4
35 | ws_output["H5"] = "=G5" # Formula
36 | ws_output["H6"] = "Different Mismatch Value"
37 | ws_output["I3"] = "Matching Value"
38 | ws_output["I4"] = "Matching Value"
39 | ws_output["I3"].fill = PatternFill(start_color="00FF00", end_color="00FF00", fill_type="solid") # Green
40 | ws_output["I4"].font = Font(color="0000FF") # Blue font
41 | wb_output.save(output_path)
42 | wb_output.close()
43 |
44 | evaluate_excel(str(ground_truth_path))
45 | evaluate_excel(str(output_path))
46 |
47 | yield ground_truth_path, output_path
48 |
49 | # Cleanup
50 | if ground_truth_path.exists():
51 | ground_truth_path.unlink()
52 | if output_path.exists():
53 | output_path.unlink()
54 |
55 | # Test for mismatched values in a single cell
56 | @pytest.mark.excel
57 | def test_single_cell_comparison_mismatch(setup_files: Tuple[Path, Path]):
58 | ground_truth_path, test_output_path = setup_files
59 | result, message = compare_excel_cells(
60 | str(ground_truth_path),
61 | str(test_output_path),
62 | "H6"
63 | )
64 | assert not result
65 | assert "Value mismatch" in message
66 |
67 | # Test for matching values in a single cell
68 | @pytest.mark.excel
69 | def test_single_cell_comparison_match(setup_files: Tuple[Path, Path]):
70 | ground_truth_path, test_output_path = setup_files
71 | result, message = compare_excel_cells(
72 | str(ground_truth_path),
73 | str(test_output_path),
74 | "H3"
75 | )
76 | assert result
77 | assert "passed" in message
78 |
79 | # Test for mismatched values in a range of cells
80 | @pytest.mark.excel
81 | def test_range_comparison_mismatch(setup_files: Tuple[Path, Path]):
82 | ground_truth_path, test_output_path = setup_files
83 | result, message = compare_excel_cells(
84 | str(ground_truth_path),
85 | str(test_output_path),
86 | "D4:H6"
87 | )
88 | assert not result
89 | assert "Value mismatch" in message
90 |
91 | # Test for matching values in a range of cells
92 | @pytest.mark.excel
93 | def test_range_comparison_match(setup_files: Tuple[Path, Path]):
94 | ground_truth_path, test_output_path = setup_files
95 | result, message = compare_excel_cells(
96 | str(ground_truth_path),
97 | str(test_output_path),
98 | "H5:H5"
99 | )
100 | assert result
101 | assert "passed" in message
102 |
103 | # Test for handling sheet names with mismatched values
104 | @pytest.mark.excel
105 | def test_sheet_name_handling_mismatch(setup_files: Tuple[Path, Path]):
106 | ground_truth_path, test_output_path = setup_files
107 | result, message = compare_excel_cells(
108 | str(ground_truth_path),
109 | str(test_output_path),
110 | "'Sheet1'!H6"
111 | )
112 | assert not result
113 | assert "Value mismatch" in message
114 |
115 | # Test for handling sheet names with matching values
116 | @pytest.mark.excel
117 | def test_sheet_name_handling_match(setup_files: Tuple[Path, Path]):
118 | ground_truth_path, test_output_path = setup_files
119 | result, message = compare_excel_cells(
120 | str(ground_truth_path),
121 | str(test_output_path),
122 | "'Sheet1'!H3"
123 | )
124 | assert result
125 | assert "passed" in message
126 |
127 | # Test for comparing cell fill color with mismatched styles
128 | @pytest.mark.excel
129 | def test_fill_color_comparison_mismatch(setup_files: Tuple[Path, Path]):
130 | ground_truth_path, test_output_path = setup_files
131 | result, message = compare_excel_cells(
132 | str(ground_truth_path),
133 | str(test_output_path),
134 | "I3",
135 | is_CF=True
136 | )
137 | assert not result
138 | assert "Fill color mismatch" in message
139 |
140 |
141 | # Test for comparing cell font color with mismatched styles
142 | @pytest.mark.excel
143 | def test_font_color_comparison_mismatch(setup_files: Tuple[Path, Path]):
144 | ground_truth_path, test_output_path = setup_files
145 | result, message = compare_excel_cells(
146 | str(ground_truth_path),
147 | str(test_output_path),
148 | "I4",
149 | is_CF=True
150 | )
151 | assert not result
152 | assert "Font color mismatch" in message
153 |
154 | # Test for comparing cell formatting with matching styles
155 | @pytest.mark.excel
156 | def test_formatting_comparison_match(setup_files: Tuple[Path, Path]):
157 | ground_truth_path, test_output_path = setup_files
158 | result, message = compare_excel_cells(
159 | str(ground_truth_path),
160 | str(test_output_path),
161 | "H3",
162 | is_CF=True
163 | )
164 | assert result
165 | assert "passed" in message
166 |
167 | # Test for handling missing sheets with mismatched values
168 | @pytest.mark.excel
169 | def test_missing_sheet_mismatch(setup_files: Tuple[Path, Path]):
170 | ground_truth_path, test_output_path = setup_files
171 | result, message = compare_excel_cells(
172 | str(ground_truth_path),
173 | str(test_output_path),
174 | "'NonExistentSheet'!A1"
175 | )
176 | assert not result
177 | assert "Worksheet 'NonExistentSheet' not found" in message
--------------------------------------------------------------------------------
/tests/integration/envs/mcp/provisioners/test_skypilot_integration.py:
--------------------------------------------------------------------------------
1 | """
2 | Integration tests for SkypilotProvisioner.
3 | """
4 |
5 | import pytest
6 | import asyncio
7 | import sky
8 | from pathlib import Path
9 | from benchmax.envs.mcp.provisioners.skypilot_provisioner import SkypilotProvisioner
10 | from tests.integration.envs.mcp.provisioners.utils import wait_for_server_health
11 |
12 |
13 | class TestEndToEnd:
14 | """End-to-end integration tests for provisioning and teardown."""
15 |
16 | @pytest.mark.asyncio
17 | @pytest.mark.slow
18 | @pytest.mark.remote
19 | async def test_single_node_lifecycle(self, example_workdir: Path):
20 | """Test complete lifecycle of a single-node cluster with all validations."""
21 | base_name = "test-single-node-cluster"
22 | api_secret = "single-node-test-secret-32chars!!"
23 |
24 | provisioner = SkypilotProvisioner(
25 | workdir_path=example_workdir,
26 | cloud=sky.Azure(),
27 | num_nodes=1,
28 | servers_per_node=2,
29 | cpus="4+",
30 | memory="16+",
31 | base_cluster_name=base_name,
32 | )
33 |
34 | try:
35 | # Verify configuration before provisioning
36 | assert provisioner.cluster_name.startswith(f"{base_name}-")
37 | assert provisioner._workdir_path.is_absolute()
38 |
39 | # Provision
40 | addresses = await provisioner.provision_servers(api_secret)
41 | assert len(addresses) == 2 # 1 node * 2 servers
42 |
43 | # Verify all addresses have correct format
44 | for addr in addresses:
45 | assert ":" in addr
46 | host, port = addr.split(":")
47 | assert port.isdigit()
48 | assert 8080 <= int(port) < 8090
49 |
50 | # Verify servers are up and healthy
51 | health_checks = await asyncio.gather(
52 | *[wait_for_server_health(addr, timeout=90.0) for addr in addresses]
53 | )
54 | assert all(health_checks), "Not all servers became healthy"
55 |
56 | # Check that double-provisioning would result in an error
57 | with pytest.raises(RuntimeError, match="already provisioned"):
58 | await provisioner.provision_servers(api_secret)
59 | finally:
60 | # Teardown - always attempt cleanup
61 | await provisioner.teardown()
62 |
63 | @pytest.mark.asyncio
64 | @pytest.mark.slow
65 | @pytest.mark.remote
66 | async def test_multi_node_lifecycle(self, example_workdir: Path):
67 | """Test complete lifecycle of a multi-node cluster with all validations."""
68 | provisioner1 = SkypilotProvisioner(
69 | workdir_path=example_workdir,
70 | cloud=sky.Azure(),
71 | num_nodes=2,
72 | servers_per_node=3,
73 | cpus=2,
74 | memory=8,
75 | )
76 |
77 | provisioner2 = SkypilotProvisioner(
78 | workdir_path=example_workdir,
79 | cloud=sky.Azure(),
80 | num_nodes=1,
81 | servers_per_node=1,
82 | )
83 |
84 | api_secret = "multi-node-test-secret-32chars!!"
85 |
86 | try:
87 | # Verify cluster names are unique
88 | assert provisioner1.cluster_name != provisioner2.cluster_name
89 | assert provisioner1.cluster_name.startswith("benchmax-env-cluster-")
90 | assert provisioner2.cluster_name.startswith("benchmax-env-cluster-")
91 |
92 | # Provision main test cluster
93 | addresses = await provisioner1.provision_servers(api_secret)
94 | assert len(addresses) == 6 # 2 nodes * 3 servers
95 |
96 | # Verify addresses are grouped by node
97 | hosts = [addr.split(":")[0] for addr in addresses]
98 | unique_hosts = set(hosts)
99 | assert len(unique_hosts) == 2, "Should have 2 unique node IPs"
100 |
101 | for host in unique_hosts:
102 | host_servers = [addr for addr in addresses if addr.startswith(host)]
103 | assert len(host_servers) == 3, f"Host {host} should have 3 servers"
104 |
105 | # Verify all servers are up and healthy
106 | health_checks = await asyncio.gather(
107 | *[wait_for_server_health(addr, timeout=90.0) for addr in addresses]
108 | )
109 | assert all(health_checks), "Not all servers became healthy"
110 | finally:
111 | # Teardown - always attempt cleanup
112 | await provisioner1.teardown()
113 |
114 | class TestValidation:
115 | """Test parameter validation without provisioning."""
116 |
117 | def test_invalid_num_nodes(self):
118 | """Test validation of num_nodes parameter."""
119 | with pytest.raises(ValueError, match="at least 1"):
120 | SkypilotProvisioner(
121 | workdir_path=".",
122 | cloud=sky.Azure(),
123 | num_nodes=0,
124 | servers_per_node=5,
125 | )
126 |
127 | def test_invalid_servers_per_node_low(self):
128 | """Test validation of servers_per_node parameter (too low)."""
129 | with pytest.raises(ValueError, match="between 1 and 100"):
130 | SkypilotProvisioner(
131 | workdir_path=".",
132 | cloud=sky.Azure(),
133 | num_nodes=1,
134 | servers_per_node=0,
135 | )
136 |
137 | def test_invalid_servers_per_node_high(self):
138 | """Test validation of servers_per_node parameter (too high)."""
139 | with pytest.raises(ValueError, match="between 1 and 100"):
140 | SkypilotProvisioner(
141 | workdir_path=".",
142 | cloud=sky.Azure(),
143 | num_nodes=1,
144 | servers_per_node=101,
145 | )
146 |
147 | def test_custom_base_cluster_name(self):
148 | """Test that custom base cluster name is used."""
149 | provisioner = SkypilotProvisioner(
150 | workdir_path=".",
151 | cloud=sky.Azure(),
152 | num_nodes=1,
153 | servers_per_node=1,
154 | base_cluster_name="my-custom-cluster",
155 | )
156 |
157 | assert provisioner.cluster_name.startswith("my-custom-cluster-")
158 |
159 | def test_workdir_path_conversion(self):
160 | """Test that workdir_path is properly converted to absolute Path."""
161 | provisioner = SkypilotProvisioner(
162 | workdir_path="./relative/path",
163 | cloud=sky.Azure(),
164 | num_nodes=1,
165 | servers_per_node=1,
166 | )
167 |
168 | # Should be converted to absolute path
169 | assert provisioner._workdir_path.is_absolute()
170 |
--------------------------------------------------------------------------------
/src/benchmax/adapters/skyrl/benchmax_data_process.py:
--------------------------------------------------------------------------------
1 | """
2 | Preprocess a huggingface/benchmax dataset to a multiturn format suitable for a benchmax environment.
3 | """
4 |
5 | import argparse
6 | import logging
7 | from importlib import import_module
8 | from pathlib import Path
9 | from types import ModuleType
10 | from typing import Type
11 | from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
12 | import datasets
13 | import asyncio
14 | import inspect
15 |
16 | from benchmax.envs.base_env import BaseEnv
17 |
18 | # Set logging level to WARNING and above
19 | logging.basicConfig(level=logging.WARNING)
20 |
21 |
22 | def load_class(dotted_path: str) -> Type[BaseEnv]:
23 | """
24 | Load and return the class specified by `dotted_path`.
25 | Example: "benchmax.envs.wikipedia.wiki_env.WikipediaEnv"
26 | """
27 | try:
28 | module_path, class_name = dotted_path.rsplit(".", 1)
29 | except ValueError as exc:
30 | raise ImportError(
31 | f'"{dotted_path}" doesn\'t look like "package.module.Class"'
32 | ) from exc
33 |
34 | module: ModuleType = import_module(module_path)
35 | try:
36 | cls: Type[BaseEnv] = getattr(module, class_name)
37 | except AttributeError as exc:
38 | raise ImportError(
39 | f'Module "{module_path}" has no attribute "{class_name}"'
40 | ) from exc
41 |
42 | return cls
43 |
44 |
45 | def get_canonical_class_name(cls: Type[BaseEnv]) -> str:
46 | """
47 | Get the canonical class name, removing local/skypilot prefix/suffix if the parent class
48 | has the same name without that prefix/suffix.
49 | """
50 | class_name = cls.__name__
51 |
52 | # Check for prefixes/suffixes to strip
53 | prefixes = ["local", "skypilot"]
54 | suffixes = ["local", "skypilot"]
55 |
56 | # Try to find a matching parent class without the prefix/suffix
57 | for base_cls in cls.__bases__:
58 | base_name = base_cls.__name__
59 |
60 | # Check if current class has prefix that base doesn't
61 | for prefix in prefixes:
62 | if class_name.lower().startswith(
63 | prefix
64 | ) and not base_name.lower().startswith(prefix):
65 | # Check if removing prefix gives us the base name
66 | stripped = class_name[len(prefix) :]
67 | if stripped == base_name:
68 | return base_name
69 |
70 | # Check if current class has suffix that base doesn't
71 | for suffix in suffixes:
72 | if class_name.lower().endswith(suffix) and not base_name.lower().endswith(
73 | suffix
74 | ):
75 | # Check if removing suffix gives us the base name
76 | stripped = class_name[: -len(suffix)]
77 | if stripped == base_name:
78 | return base_name
79 |
80 | # No matching parent found, return original name
81 | return class_name
82 |
83 |
84 | async def get_system_prompt(cls: Type[BaseEnv]) -> str:
85 | """Setup env and get system prompt in async context."""
86 | # Initialize env with num_local_servers=1 if supported
87 | init_signature = inspect.signature(cls.__init__)
88 | if "num_local_servers" in init_signature.parameters:
89 | env = cls(num_local_servers=1) # type: ignore
90 | else:
91 | env = cls()
92 |
93 | # Get system prompt (async function)
94 | prompt = await env.get_system_prompt(add_tool_defs=True)
95 |
96 | await env.shutdown()
97 | return prompt
98 |
99 |
100 | if __name__ == "__main__":
101 | parser = argparse.ArgumentParser()
102 | parser.add_argument(
103 | "--local_dir",
104 | required=True,
105 | help="Local directory where processed train/test parquet files will be written.",
106 | )
107 | parser.add_argument(
108 | "--dataset_name",
109 | required=True,
110 | help="Identifier of the HuggingFace dataset to load (e.g., 'squad', 'wikitext').",
111 | )
112 | parser.add_argument(
113 | "--env_path",
114 | required=True,
115 | help=(
116 | "Dotted path to the BaseEnv subclass to use for preprocessing, "
117 | "e.g. 'benchmax.envs.wikipedia.wiki_env.WikipediaEnv'."
118 | ),
119 | )
120 |
121 | args = parser.parse_args()
122 |
123 | print(f"Loading {args.dataset_name} dataset...", flush=True)
124 | benchmax_cls: Type[BaseEnv] = load_class(args.env_path)
125 | raw_dataset, dataset_path = benchmax_cls.load_dataset(args.dataset_name)
126 |
127 | if isinstance(raw_dataset, (IterableDataset, IterableDatasetDict)):
128 | raise TypeError(
129 | f"Iterable datasets are currently not supported. Got {type(raw_dataset).__name__}. "
130 | )
131 |
132 | if not isinstance(raw_dataset, (DatasetDict, Dataset)):
133 | raise TypeError(
134 | f"Expected DatasetDict or Dataset, but got {type(raw_dataset).__name__}."
135 | )
136 |
137 | print("Getting system prompt...", flush=True)
138 | system_prompt = asyncio.run(get_system_prompt(benchmax_cls))
139 |
140 | # Get canonical class name (strips local/skypilot if parent matches)
141 | canonical_name = get_canonical_class_name(benchmax_cls)
142 |
143 | def process_example(example):
144 | """Single mapping function that does all processing."""
145 | # First apply dataset-specific preprocessing
146 | standardized = benchmax_cls.dataset_preprocess(
147 | example, dataset_path=dataset_path
148 | )
149 |
150 | # Then format as multiturn prompt
151 | prompt = [
152 | {
153 | "role": "system",
154 | "content": system_prompt,
155 | },
156 | {"role": "user", "content": standardized["prompt"]},
157 | ]
158 | result = {
159 | **standardized,
160 | "prompt": prompt,
161 | "env_class": canonical_name,
162 | "data_source": canonical_name,
163 | }
164 |
165 | # Remove keys with None values
166 | result = {k: v for k, v in result.items() if v is not None}
167 |
168 | return result
169 |
170 | print("Processing examples...", flush=True)
171 | processed_dataset = raw_dataset.map(process_example)
172 |
173 | if isinstance(processed_dataset, DatasetDict) and set(
174 | processed_dataset.keys()
175 | ) == set(["train", "test"]):
176 | # If train and test dataset split already exist
177 | train_dataset = processed_dataset["train"]
178 | test_dataset = processed_dataset["test"]
179 | else:
180 | if isinstance(processed_dataset, DatasetDict):
181 | processed_dataset = datasets.concatenate_datasets(
182 | [ds for ds in processed_dataset.values()]
183 | ).shuffle(seed=42)
184 |
185 | split = processed_dataset.train_test_split(
186 | test_size=0.2, seed=42, shuffle=False
187 | )
188 | train_dataset = split["train"]
189 | test_dataset = split["test"]
190 |
191 | print(f"Saving to {args.local_dir}...", flush=True)
192 | local_dir = Path(args.local_dir)
193 | train_dataset.to_parquet(local_dir / "train.parquet")
194 | test_dataset.to_parquet(local_dir / "test.parquet")
195 |
--------------------------------------------------------------------------------
/src/benchmax/envs/mcp/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions for MCP environment infrastructure.
3 | """
4 |
5 | import jwt
6 | import time
7 | import asyncio
8 | from contextlib import AsyncExitStack
9 | from pathlib import Path
10 | from typing import Optional, Dict, List, Any
11 | import aiohttp
12 | from fastmcp import Client
13 | from mcp import Tool
14 |
15 | from benchmax.envs.types import ToolDefinition
16 |
17 |
18 | def convert_tool_definitions(
19 | tools: List[Tool], allowed_tools: Optional[List[str]]
20 | ) -> List[ToolDefinition]:
21 | """
22 | Convert MCP Tool objects to ToolDefinition dataclass.
23 |
24 | Args:
25 | tools: List of MCP Tool objects.
26 | allowed_tools: Optional whitelist of tool names. If provided,
27 | only tools in this list are included.
28 |
29 | Returns:
30 | List of ToolDefinition objects.
31 | """
32 | tool_definitions = [
33 | ToolDefinition(
34 | name=tool.name,
35 | description=tool.description or "",
36 | input_schema=tool.inputSchema,
37 | )
38 | for tool in tools
39 | ]
40 |
41 | if not allowed_tools:
42 | return tool_definitions
43 |
44 | return [tool for tool in tool_definitions if tool.name in allowed_tools]
45 |
46 |
47 | def generate_jwt_token(
48 | api_secret: str,
49 | rollout_id: Optional[str] = None,
50 | expiration_seconds: int = 300,
51 | **extra_claims: Any,
52 | ) -> str:
53 | """
54 | Generate a JWT token with standard and custom claims.
55 |
56 | Args:
57 | api_secret: Shared secret for signing (HS256).
58 | rollout_id: Optional rollout ID to include in claims.
59 | expiration_seconds: Token validity duration (default: 5 minutes).
60 | **extra_claims: Additional custom claims to include.
61 |
62 | Returns:
63 | JWT token string.
64 | """
65 | current_time = int(time.time())
66 |
67 | payload = {
68 | # Standard claims
69 | "iss": "mcp-client",
70 | "aud": "mcp-proxy-server",
71 | "iat": current_time,
72 | "exp": current_time + expiration_seconds,
73 | # Custom claims
74 | **extra_claims,
75 | }
76 |
77 | # Add rollout_id if provided
78 | if rollout_id:
79 | payload["rollout_id"] = rollout_id
80 |
81 | # Sign with HS256
82 | token = jwt.encode(payload, api_secret, algorithm="HS256")
83 | return token
84 |
85 |
86 | def get_auth_headers(
87 | api_secret: str, rollout_id: Optional[str] = None, **extra_claims: Any
88 | ) -> Dict[str, str]:
89 | """
90 | Generate Authorization header with JWT token.
91 |
92 | Args:
93 | api_secret: Shared secret for signing.
94 | rollout_id: Optional rollout ID to include in claims.
95 | **extra_claims: Additional custom claims.
96 |
97 | Returns:
98 | Headers dict with Authorization Bearer token.
99 | """
100 | token = generate_jwt_token(api_secret, rollout_id, **extra_claims)
101 | return {"Authorization": f"Bearer {token}"}
102 |
103 |
104 | async def upload_form(
105 | http_session: aiohttp.ClientSession,
106 | upload_url: str,
107 | api_secret: str,
108 | file_bytes: bytes,
109 | filename: str,
110 | rollout_id: Optional[str] = None,
111 | content_type: str = "application/octet-stream",
112 | ) -> None:
113 | """
114 | Upload a file or content to a remote URL using multipart form with JWT auth.
115 |
116 | Args:
117 | http_session: aiohttp client session.
118 | upload_url: URL to upload to.
119 | api_secret: Shared secret for JWT signing.
120 | file_bytes: File content as bytes.
121 | filename: Name for the uploaded file.
122 | rollout_id: Optional rollout ID for JWT claims.
123 | content_type: MIME type of the content.
124 |
125 | Raises:
126 | RuntimeError: If upload fails.
127 | """
128 | # Generate JWT token with rollout_id
129 | headers = get_auth_headers(api_secret, rollout_id)
130 |
131 | # Create multipart form data
132 | data = aiohttp.FormData()
133 | data.add_field("file", file_bytes, filename=filename, content_type=content_type)
134 |
135 | async with http_session.post(upload_url, headers=headers, data=data) as response:
136 | if response.status == 200:
137 | return
138 | error_text = await response.text()
139 | raise RuntimeError(f"Upload failed: {response.status} - {error_text}")
140 |
141 |
142 | async def download_file(
143 | http_session: aiohttp.ClientSession,
144 | download_url: str,
145 | api_secret: str,
146 | params: Dict[str, str],
147 | dst_path: Path,
148 | rollout_id: Optional[str] = None,
149 | ) -> None:
150 | """
151 | Download a file from a remote URL and save it locally with JWT auth.
152 |
153 | Args:
154 | http_session: aiohttp client session.
155 | download_url: URL to download from.
156 | api_secret: Shared secret for JWT signing.
157 | params: Query parameters.
158 | dst_path: Local path to save the downloaded file.
159 | rollout_id: Optional rollout ID for JWT claims.
160 |
161 | Raises:
162 | RuntimeError: If download fails.
163 | """
164 | # Generate JWT token with rollout_id
165 | headers = get_auth_headers(api_secret, rollout_id)
166 |
167 | async with http_session.get(
168 | download_url, headers=headers, params=params
169 | ) as response:
170 | if response.status != 200:
171 | error_text = await response.text()
172 | raise RuntimeError(f"Download failed: {response.status} - {error_text}")
173 |
174 | dst_path.parent.mkdir(parents=True, exist_ok=True)
175 | with open(dst_path, "wb") as f:
176 | async for chunk in response.content.iter_chunked(8192):
177 | f.write(chunk)
178 |
179 |
180 | async def _safe_session_runner(self):
181 | """
182 | Patched version of FastMCPClient._session_runner that catches exceptions.
183 |
184 | This prevents crashes when servers disconnect or restart, logging errors
185 | instead of propagating them.
186 | """
187 | try:
188 | async with AsyncExitStack() as stack:
189 | try:
190 | await stack.enter_async_context(self._context_manager())
191 | self._session_state.ready_event.set()
192 | await self._session_state.stop_event.wait()
193 | except (aiohttp.ClientError, asyncio.CancelledError) as e:
194 | # Common expected errors when server disconnects or restarts
195 | print(f"[INFO] Client session ended: {type(e).__name__}: {e}")
196 | except Exception as e:
197 | # Log unexpected errors
198 | print(f"[WARN] Client session crashed: {e}")
199 | finally:
200 | self._session_state.ready_event.set()
201 | except Exception as e:
202 | # Catch outer-level async exit errors
203 | print(f"[WARN] Session runner outer error: {e}")
204 |
205 |
206 | def apply_fastmcp_patch():
207 | """
208 | Apply the monkey patch to FastMCP Client.
209 |
210 | Call this once at module import to enable graceful session handling.
211 | """
212 | Client._session_runner = _safe_session_runner
213 |
--------------------------------------------------------------------------------
/src/benchmax/envs/mcp/README.md:
--------------------------------------------------------------------------------
1 | # Parallel MCP Env
2 |
3 | This directory contains:
4 | ```bash
5 | ├── example_workdir/ # An example of a workdir
6 | ├── parallel_mcp_env.py # Parallel MCP env
7 | ├── provisioners/ # Various methods of provisioning servers (e.g. local, skypilot)
8 | ├── proxy_server.py # Proxy server that redirects MCP connections
9 | ├── server_pool.py # Helper to manage servers
10 | └── utils.py # Shared utils
11 | ```
12 |
13 | With the parallel MCP env, you can run multiple MCP servers in parallel, either locally or in the cloud. Provisioners abstract away where those servers run, so the same environment code works for local development and for distributed cloud deployment.
14 |
15 | ## Workdir Structure
16 |
17 | The `workdir` directory contains all the files necessary to run your MCP environment. It should include:
18 |
19 | 1. `mcp_config.yaml`: Configuration for spinning up MCP servers
20 | 2. `reward_fn.py`: Implementation of the reward function for evaluating model outputs
21 | 3. `setup.sh`: Script that runs at startup to install dependencies
22 | 4. Additional files: Any other files required by your MCP server (e.g., mcp server code, assets, helper scripts)
23 |
24 | When the environment is initialized, the entire workdir is copied to the execution machine, ensuring all necessary components are available whether running locally or in the cloud.
25 |
26 |
27 | ## Creating a custom environment using MCP
28 |
29 | To create a custom environment using an MCP server (like a calculator, browser, or spreadsheet), you can extend `ParallelMcpEnv`. Here's a quick step-by-step guide using `benchmax.envs.math.math_env.MathEnv` as an example.
30 |
31 | ### 1. **Define a System Prompt**
32 |
33 | This prompt guides the LLM’s behavior. It can include any instruction, such as how to format the answer or when to use tools.
34 |
35 | ```python
36 | SYSTEM_PROMPT = """Please use the tools provided to do any computation.
37 | Write your complete answer on the final line only, within the xml tags <answer></answer>.
38 | """
39 | ```
40 |
41 | ### 2. **Configure MCP Server(s)**
42 |
43 | Define the MCP servers to be launched. You can configure one or more in `src/benchmax/envs/math/workdir/mcp_config.yaml`:
44 |
45 | ```yaml
46 | mcpServers:
47 | calculator:
48 | command: uvx
49 | args:
50 | - mcp-server-calculator
51 | ```
52 |
53 | This MCP config will be used to spin up multiple MCP servers, which can run on a single machine (if configured locally) or distributed across multiple nodes in the cloud (if using SkyPilot multi-node).
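
As a rough illustration of how these servers are provisioned (mirroring the integration-test fixtures in this repo; the paths and parameter values below are purely illustrative), a provisioner can be constructed like this:

```python
from pathlib import Path

from benchmax.envs.mcp.provisioners import LocalProvisioner

# Illustrative path: point this at the workdir containing mcp_config.yaml,
# reward_fn.py and setup.sh.
workdir = Path("src/benchmax/envs/math/workdir")

# Run a handful of MCP servers on this machine, on ports starting at base_port.
provisioner = LocalProvisioner(
    workdir_path=workdir,
    num_servers=4,
    base_port=8080,
)

# For multi-node cloud deployment, SkypilotProvisioner takes the same workdir
# plus a SkyPilot cloud, num_nodes and servers_per_node; see the provisioners
# package and the integration tests for the full set of options.
```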
54 |
55 | ### 3. **Write a Reward Function**
56 |
57 | The reward function evaluates how "correct" the model's output is, based on structured output. Here’s a simple XML-based example:
58 |
59 | Note that `**kwargs` contains all the other fields in your dataset example, so feel free to use them in `reward_fn` calculations.
60 |
61 | ```python
62 | async def text_match_reward(
63 | completion: str,
64 | ground_truth: str,
65 | mcp_client: Client,
66 | workspace: Path,
67 | **kwargs: Any
68 | ) -> float:
69 | """
70 | Reward = 1 if `ground_truth` (case-insensitive) appears anywhere *inside*
71 |     the first <answer> … </answer> block of `completion`; otherwise 0.
72 |
73 | Falls back to 0 if the tag is missing or empty.
74 | """
75 |
76 |     # Grab only the text inside the first <answer> … </answer> pair (case-insensitive).
77 |     m = re.search(r'<answer>(.*?)</answer>', completion, flags=re.IGNORECASE | re.DOTALL)
78 | if m is None:
79 | return 0.0
80 |
81 |     # Unescape any XML entities (&amp; → &, etc.) and normalise whitespace.
82 | answer_text = unescape(m.group(1)).strip().lower()
83 |
84 | try:
85 | # Try to interpret both as floats for numerical comparison.
86 | return float(float(ground_truth.lower()) == float(answer_text))
87 | except ValueError:
88 | return 0.0
89 | ```
90 |
91 | Check out [math env's reward_fn.py](/src/benchmax/envs/math/workdir/reward_fn.py) for the full implementation. You can also reference [this example reward_fn.py](/src/benchmax/envs/mcp/example_workdir/reward_fn.py) for a more comprehensive tour of the different ways to compute rewards.
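
Note that the bundled workdirs export their reward functions through a `reward_functions` dict at the bottom of `reward_fn.py`, which is how the environment discovers them by name. A minimal sketch following that pattern for the example above (the `"text_match"` key is just an illustrative name):

```python
from typing import Awaitable, Callable, Dict, Union

# Same alias used by the bundled reward_fn.py modules.
RewardFunction = Callable[..., Union[float, Awaitable[float]]]

reward_functions: Dict[str, RewardFunction] = {
    "text_match": text_match_reward,
}
```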
92 |
93 | ### 4. Define **`dataset_preprocess`**
94 |
95 | If your dataset is not already standardized, implement this method to convert a raw example into a standardized one with:
96 |
97 | - `"prompt"`: A fully constructed string prompt.
98 | - `"ground_truth"`: A known correct output (optional depending on reward).
99 | - `"init_rollout_args"`: Arguments needed to initialize a rollout.
100 |
101 | Example for our math task:
102 |
103 | ```python
104 | def dataset_preprocess(self, example: dict) -> StandardizedExample:
105 | return StandardizedExample(
106 | prompt=example.get("task", ""),
107 | ground_truth=example.get("answer", ""),
108 | init_rollout_args={}
109 | )
110 | ```
111 |
112 |
113 | #### Notes on `init_rollout_args`
114 | The `init_rollout_args` dictionary is passed from `dataset_preprocess()` to your environment's `init_rollout()` method. It is used to initialize any **per-example files, resources, or execution context** needed before a rollout begins.
115 |
116 | Common use cases include:
117 |
118 | - **Input files**: For environments that manipulate files like spreadsheets, images, or databases, pass the necessary file paths.
119 | - **Version control**: For code-related tasks, you might pass a `commit_id` to check out the correct code state.
120 | - **Task-specific settings**: Pass metadata like cell ranges, task IDs, or execution flags.
121 |
122 | Example:
123 |
124 | ```python
125 | # Inside dataset_preprocess
126 | return {
127 | "prompt": "...",
128 | "ground_truth": "...",
129 | "init_rollout_args": {
130 | "spreadsheet_path": "/path/to/1_001_input.xlsx"
131 | }
132 | }
133 | ```
134 |
135 | Then in your `init_rollout()` method:
136 |
137 | ```python
138 | async def init_rollout(self, rollout_id: str, **rollout_args) -> None:
139 | spreadsheet_path = rollout_args["spreadsheet_path"]
140 | workspace = self.get_rollout_workspace(rollout_id)
141 |
142 | # Copy the input file into the rollout's workspace
143 | shutil.copy(spreadsheet_path, workspace / Path(spreadsheet_path).name)
144 | ```
145 |
146 | This pattern ensures each rollout starts with the correct inputs and configuration.
147 |
148 |
149 | ### 5. **Extend `ParallelMcpEnv`**
150 |
151 | Now bring everything together into a custom environment class:
152 |
153 | ```python
154 | SYSTEM_PROMPT = """Please use the tools provided to do any computation.
155 | Write your complete answer on the final line only, within the xml tags <answer></answer>.\n
156 | """
157 |
158 | class MathEnv(ParallelMcpEnv):
159 | """Environment for math problems, using local MCP tools."""
160 |
161 | system_prompt: str = SYSTEM_PROMPT
162 |
163 | def __init__(self, workdir_path: Path, provisioner: BaseProvisioner, **kwargs):
164 | super().__init__(workdir_path=workdir_path, provisioner=provisioner, **kwargs)
165 |
166 | @classmethod
167 | def dataset_preprocess(cls, example: Any, **kwargs) -> StandardizedExample:
168 | return StandardizedExample(
169 | prompt=example.get("task", ""),
170 | ground_truth=example.get("answer", ""),
171 | init_rollout_args=None,
172 | )
173 | ```
174 |
175 | You're done! This environment is now compatible with `benchmax` and can be plugged into any compatible RL trainer.
176 |
177 | In [math_env.py](/src/benchmax/envs/math/math_env.py), we have also provided `MathEnvLocal` class to run the env locally and `MathEnvSkypilot` to run it remotely across multiple machines.
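
Finally, here is a purely illustrative sketch of driving the environment directly with a local provisioner (not part of the library; the `calculate` tool name, its `expression` argument, and the paths are assumptions that depend on the MCP server configured in your `mcp_config.yaml`):

```python
import asyncio
from pathlib import Path

from benchmax.envs.math.math_env import MathEnv
from benchmax.envs.mcp.provisioners import LocalProvisioner


async def main() -> None:
    workdir = Path("src/benchmax/envs/math/workdir")  # illustrative path
    env = MathEnv(
        workdir_path=workdir,
        provisioner=LocalProvisioner(workdir_path=workdir, num_servers=1),
        provision_at_init=True,
    )
    try:
        rollout_id = "rollout-0"
        await env.init_rollout(rollout_id)

        # System prompt with the MCP tool definitions rendered in.
        print(await env.get_system_prompt(add_tool_defs=True))

        # Execute a tool call on behalf of the model (tool name/args are
        # hypothetical and depend on the configured MCP server).
        result = await env.run_tool(rollout_id, "calculate", expression="25 / 5 + 4 * 3")
        print(result)

        # Score a final completion; returns a dict of reward name -> score.
        rewards = await env.compute_reward(
            rollout_id, completion="<answer>17</answer>", ground_truth="17"
        )
        print(rewards)
    finally:
        await env.shutdown()


asyncio.run(main())
```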
178 |
--------------------------------------------------------------------------------
/tests/unit/envs/excel/test_excel_env.py:
--------------------------------------------------------------------------------
1 | """
2 | Unit tests for ExcelEnv.
3 |
4 | All tests are fast with no external service calls.
5 | """
6 |
7 | import pytest
8 | from pathlib import Path
9 | from typing import Dict, Any
10 | from unittest.mock import Mock, patch
11 |
12 | from benchmax.envs.excel.excel_env import ExcelEnv
13 | from benchmax.envs.excel.workdir.reward_fn import spreadsheet_comparison_reward
14 | from benchmax.envs.mcp.provisioners.base_provisioner import BaseProvisioner
15 |
16 |
17 | # Fixtures
18 | @pytest.fixture(scope="session")
19 | def test_xlsx_path() -> str:
20 | return str(Path(__file__).parent / "test_inputs" / "test.xlsx")
21 |
22 |
23 | @pytest.fixture
24 | def excel_env(tmp_path: Path) -> ExcelEnv:
25 | """Fixture to create an ExcelEnv instance without initializing the parent."""
26 | env = ExcelEnv(
27 | workdir_path=tmp_path,
28 | provisioner=Mock(spec=BaseProvisioner),
29 | provision_at_init=False,
30 | )
31 | env._servers_provisioned = True
32 | return env
33 |
34 |
35 | @pytest.fixture
36 | def mock_dataset(tmp_path: Path) -> Path:
37 | """Create a fake dataset folder with a sample input file.
38 |
39 | The ExcelEnv expects the dataset layout to contain a folder (spreadsheet_path)
40 | with files like `1__input.xlsx`. We touch a file to satisfy existence
41 | checks; reading is patched in tests that need it.
42 | """
43 | base = tmp_path / "dataset_root"
44 | sheet_dir = base / "sheet_folder"
45 | sheet_dir.mkdir(parents=True)
46 |
47 | # Create a dummy input file path that ExcelEnv will look for
48 | (sheet_dir / "1_42_input.xlsx").write_text("dummy")
49 |
50 | # Also create an answer file (not strictly required for dataset_preprocess)
51 | (sheet_dir / "1_42_answer.xlsx").write_text("dummy-answer")
52 |
53 | return base
54 |
55 |
56 | @pytest.fixture
57 | def mock_mcp_client() -> Mock:
58 | return Mock()
59 |
60 |
61 | @pytest.fixture
62 | def mock_workspace(tmp_path: Path) -> Path:
63 | ws = tmp_path / "workspace"
64 | ws.mkdir()
65 | return ws
66 |
67 |
68 | class TestDatasetPreprocess:
69 | """Tests for ExcelEnv.dataset_preprocess."""
70 |
71 | def test_valid_example(
72 | self, mock_dataset: Path, test_xlsx_path: str
73 | ) -> None:
74 | """Valid example returns an ExcelExample with expected fields."""
75 | example: Dict[str, Any] = {
76 | "id": "42",
77 | "spreadsheet_path": test_xlsx_path,
78 | "instruction": "Fill cell A1 with 10",
79 | "instruction_type": "Cell-Level Manipulation",
80 | "answer_position": "A1",
81 | }
82 |
83 | with patch(
84 | "benchmax.envs.excel.excel_env.excel_to_str_repr",
85 | return_value="A1=10",
86 | ):
87 | result = ExcelEnv.dataset_preprocess(example, dataset_path=mock_dataset)
88 |
89 | # Result is a mapping-like StandardizedExample (TypedDict / dataclass)
90 | assert result is not None
91 | assert "Fill cell A1 with 10" in result["prompt"]
92 | assert "1_42_input.xlsx" in result["prompt"]
93 | assert result["init_rollout_args"] is not None
94 | assert result["init_rollout_args"]["input_src_path"].endswith("1_42_input.xlsx")
95 | assert result["answer_position"] == "A1"
96 | assert result["output_filename"] == "1_42_output.xlsx"
97 |
98 | def test_missing_fields_raise(self, excel_env: ExcelEnv) -> None:
99 | """Missing required fields should raise ValueError."""
100 | example: Dict[str, Any] = {"id": "1", "instruction": "do"}
101 |
102 | with pytest.raises(ValueError):
103 | excel_env.dataset_preprocess(example)
104 |
105 | def test_non_string_spreadsheet_path_raises(
106 | self, mock_dataset: Path
107 | ) -> None:
108 |         """Non-string spreadsheet_path should raise TypeError."""
109 | example: Dict[str, Any] = {
110 | "id": "1",
111 | "spreadsheet_path": 123, # invalid type
112 | "instruction": "x",
113 | "instruction_type": "Cell-Level Manipulation",
114 | "answer_position": "A1",
115 | }
116 |
117 | with pytest.raises(TypeError):
118 | ExcelEnv.dataset_preprocess(example, dataset_path=mock_dataset)
119 |
120 | def test_missing_spreadsheet_folder_raises(
121 | self, tmp_path: Path
122 | ) -> None:
123 | """If the spreadsheet folder does not exist under the dataset path, raise FileNotFoundError."""
124 | dataset_path = tmp_path / "some_other_root"
125 |
126 | example: Dict[str, Any] = {
127 | "id": "99",
128 | "spreadsheet_path": "no_such_folder",
129 | "instruction": "x",
130 | "instruction_type": "Cell-Level Manipulation",
131 | "answer_position": "A1",
132 | }
133 |
134 | with pytest.raises(FileNotFoundError):
135 | ExcelEnv.dataset_preprocess(example, dataset_path=dataset_path)
136 |
137 |
138 | class TestRewardComputation:
139 | """Test reward computation for spreadsheet comparison."""
140 |
141 | def test_exact_match_returns_one(
142 | self, mock_mcp_client: Mock, mock_workspace: Path
143 | ) -> None:
144 | # Patch compare_excel_cells to return match=True
145 | with patch(
146 | "benchmax.envs.excel.workdir.reward_fn.compare_excel_cells",
147 | return_value=(True, None),
148 | ):
149 | score = spreadsheet_comparison_reward(
150 | completion="irrelevant",
151 | ground_truth={},
152 | mcp_client=mock_mcp_client,
153 | workspace=mock_workspace,
154 | answer_position="A1",
155 | output_filename="out.xlsx",
156 | ground_truth_filename="gt.xlsx",
157 | )
158 |
159 | assert score == 1.0
160 |
161 | def test_mismatch_returns_zero(
162 | self, mock_mcp_client: Mock, mock_workspace: Path
163 | ) -> None:
164 | with patch(
165 | "benchmax.envs.excel.workdir.reward_fn.compare_excel_cells",
166 | return_value=(False, None),
167 | ):
168 | score = spreadsheet_comparison_reward(
169 | completion="irrelevant",
170 | ground_truth={},
171 | mcp_client=mock_mcp_client,
172 | workspace=mock_workspace,
173 | answer_position="A1",
174 | output_filename="out.xlsx",
175 | ground_truth_filename="gt.xlsx",
176 | )
177 |
178 | assert score == 0.0
179 |
180 | def test_missing_kwargs_raises(
181 | self, mock_mcp_client: Mock, mock_workspace: Path
182 | ) -> None:
183 | with pytest.raises(ValueError):
184 | spreadsheet_comparison_reward(
185 | completion="x",
186 | ground_truth={},
187 | mcp_client=mock_mcp_client,
188 | workspace=mock_workspace,
189 | # missing answer_position, output_filename, ground_truth_filename
190 | )
191 |
192 | def test_compare_raises_returns_zero(
193 | self, mock_mcp_client: Mock, mock_workspace: Path
194 | ) -> None:
195 | with patch(
196 | "benchmax.envs.excel.workdir.reward_fn.compare_excel_cells",
197 | side_effect=Exception("boom"),
198 | ):
199 | score = spreadsheet_comparison_reward(
200 | completion="irrelevant",
201 | ground_truth={},
202 | mcp_client=mock_mcp_client,
203 | workspace=mock_workspace,
204 | answer_position="A1",
205 | output_filename="out.xlsx",
206 | ground_truth_filename="gt.xlsx",
207 | )
208 |
209 | assert score == 0.0
210 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | ## benchmax: Framework-Agnostic RL Environments for LLM Fine-Tuning
6 | *A lightweight, training-framework agnostic library for defining, running, and parallelizing environments, to fine-tune OSS LLMs with reinforcement learning.*
7 |
8 |
9 |
17 |
18 |
19 |
20 |
21 | ## 📌 News
22 |
23 | - **[29 Oct 2025]** 🎉 Added support for easy multi-node parallelization across all major cloud providers using [SkyPilot](https://github.com/skypilot-org/skypilot)
24 | - **[29 Oct 2025]** 🎉 Integration with [SkyRL](https://github.com/NovaSky-AI/SkyRL) for distributed RL training across clusters
25 | - **[Upcoming]** 🛠️ Integration with Tinker API.
26 |
27 | ## 📘 Quickstart
28 |
29 | **Example: Multi-node parallelization of Excel Env with SkyRL and SkyPilot**
30 |
31 | RL environments can be computationally expensive to run (e.g. running tests). To handle these workloads efficiently, we distribute rollouts across multiple nodes using **SkyPilot**, horizontally scaling `benchmax` across cloud providers like GCP, AWS, Azure, etc.
32 |
33 | **SkyRL** is one of the training frameworks `benchmax` integrates with. Use our ***SkyRL*** integration to fine-tune Qwen-2.5 with RL on spreadsheet manipulation, using an Excel MCP server parallelized across multiple nodes. The environment is defined in [`benchmax.envs.excel.excel_env.ExcelEnvSkypilot`](/src/benchmax/envs/excel/excel_env.py).
34 |
35 | 1. **Prepare the dataset**
36 |
37 | ```bash
38 | uv run src/benchmax/adapters/skyrl/benchmax_data_process.py \
39 | --local_dir ~/data/excel \
40 | --dataset_name spreadsheetbench \
41 | --env_path benchmax.envs.excel.excel_env.ExcelEnvLocal
42 | ```
43 |
44 | Note: We use `ExcelEnvLocal` instead of `ExcelEnvSkypilot` here because, at this stage, the MCP server is only used to list tools when preparing the system prompt.
45 |
46 | 2. **Run training and parallelize Excel environment**
47 |
48 | ```bash
49 | bash examples/skyrl/run_benchmax_excel.sh
50 | ```
51 |
52 | This Excel env example will spin up 5 nodes with 20 servers per node (100 MCP servers in parallel in total). For more details, check out [multi-node parallelization](/src/benchmax/envs/mcp/README.md) and [SkyRL integration](/examples/skyrl/README.md).
53 |
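54 | If you would rather configure the same cluster shape directly in Python instead of through the shell script, the environment class exposes the node and server counts. A rough sketch, assuming the constructor arguments shown in [excel_env.py](/src/benchmax/envs/excel/excel_env.py):
55 | 
56 | ```python
57 | import sky
58 | from benchmax.envs.excel.excel_env import ExcelEnvSkypilot
59 | 
60 | # Same shape as the script above: 5 nodes x 20 servers/node = 100 MCP servers.
61 | # provision_at_init=False defers provisioning (the unit tests use the same flag).
62 | env = ExcelEnvSkypilot(
63 |     cloud=sky.Azure(),
64 |     num_nodes=5,
65 |     servers_per_node=20,
66 |     provision_at_init=False,
67 | )
68 | ```
69 | 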
54 | ## ℹ️ Overview
55 |
56 | `benchmax` comes with:
57 |
58 | - A collection of ready-to-use reinforcement learning (RL) environments for LLM fine-tuning ranging from multi-hop search to spreadsheet manipulation to CRM agents
59 | - An easy way to define, compose, and parallelize your own environments, including leveraging the existing ecosystem of MCP servers
60 | - Built-in integrations with popular RL training libraries (SkyRL, etc.). `benchmax` is trainer-agnostic by design
61 |
62 | Define your environment as:
63 |
64 | 1. A **toolset** (LLM calls, external APIs, calculators, MCPs, etc.).
65 | 2. **Output parsing** logic to extract structured observations.
66 | 3. **Reward functions** to score model outputs.
67 |
68 | Rollout management, parallel execution, and more come out of the box.
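69 | 
70 | For MCP-backed environments, this boils down to subclassing `ParallelMcpEnv`: tools come from the MCP servers declared in your workdir's `mcp_config.yaml`, reward functions live in `workdir/reward_fn.py`, and you supply a system prompt plus a `dataset_preprocess` hook. A condensed sketch of that pattern (see the guides linked below for full walkthroughs):
71 | 
72 | ```python
73 | from typing import Any
74 | 
75 | from benchmax.envs.mcp.parallel_mcp_env import ParallelMcpEnv
76 | from benchmax.envs.types import StandardizedExample
77 | 
78 | 
79 | class MyEnv(ParallelMcpEnv):
80 |     # Tools are served by the MCP servers listed in workdir/mcp_config.yaml;
81 |     # rewards are computed by the functions in workdir/reward_fn.py.
82 |     system_prompt: str = "Use the provided tools, then give your final answer."
83 | 
84 |     @classmethod
85 |     def dataset_preprocess(cls, example: Any, **kwargs) -> StandardizedExample:
86 |         return StandardizedExample(
87 |             prompt=example.get("task", ""),
88 |             ground_truth=example.get("answer", ""),
89 |             init_rollout_args=None,
90 |         )
91 | ```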
69 |
70 | ⭐ Star our repository to show your support!
71 |
72 | ## 💡 Core Features
73 |
74 | **Built-in examples & templates**
75 |
76 | Get started with ready-to-use recipes, from Wikipedia search to spreadsheet manipulation. Easy to copy, customize, and extend. And yes, more are on the way.
77 |
78 | **Trainer integrations**
79 |
80 | Use your own trainer or training framework - no lock-in. `benchmax` is already integrated into SkyRL, with more integrations (Tinker, etc.) coming soon!
81 |
82 | **MCP support**
83 |
84 | Tap into the growing MCP ecosystem and integrate them as tools within your environments.
85 |
86 | **Multi-node parallel execution**
87 |
88 | Multi-node parallelization enabled out of the box, with state isolation across rollouts (e.g. files edited on the filesystem).
89 |
90 |
91 | ## 🌐 Creating & Training with Environments
92 |
93 | ### What is an environment?
94 |
95 | An environment consists of:
96 |
97 | - A list of tools that an LLM can call
98 | - A list of reward functions that evaluate the quality & correctness of the model's final output.
99 |
100 | We also support MCP servers natively, allowing you to easily leverage the many servers built by the community.
101 |
102 | ### Pre-built environments
103 |
104 | Ready-to-use environments with pre-configured tools and reward functions.
105 |
106 | - [CRM](/src/benchmax/envs/crm/README.md)
107 | - [Excel](/src/benchmax/envs/excel/README.md)
108 | - [Math](/src/benchmax/envs/math/README.md)
109 | - [Wikipedia](/src/benchmax/envs/wikipedia/README.md)
110 |
111 | ### How do I create a custom environment?
112 |
113 | 1. [With existing MCP servers](/src/benchmax/envs/mcp/README.md) (Built-in support for multi-node parallelization)
114 |
115 | 2. [Extend BaseEnv](/src/benchmax/envs/README.md)
116 |
117 | ### How about more complex environments?
118 |
119 | - Check out our excel spreadsheet RL environment: `benchmax.envs.excel.excel_env.ExcelEnv`
120 |
121 | ### How do I use an environment with my preferred RL Trainer?
122 |
123 | We currently have integrations with SkyRL. More incoming!
124 |
125 | [`benchmax` environments with skyrl](/examples/skyrl/README.md)
126 |
127 | ### I want a specific environment
128 |
129 | Open an issue and tag us & we will look into building you one!
130 |
131 | ---
132 |
133 | ## 🎯 Motivation
134 |
135 | - **Modularity and Simplicity**:
136 |
137 | We set out to build a lightweight, modular system for defining RL environments—breaking them down into simple, composable parts: tools, tool output parsing, and reward functions.
138 |
139 |   The goal is to make it easy for software engineers to build and experiment with RL environments without needing deep RL expertise.
140 |
141 | - **Trainer Integrations**:
142 |
143 |   There have been lots of new RL training frameworks popping up (e.g., numerous forks of verl), and we expect this to continue. They are often tightly coupled with specific environments, leading to fragmentation and limited compatibility.
144 |
145 | We are building `benchmax` as a standalone library with integrations to these different training frameworks & as an easy way for new frameworks to tap into an existing pool of environments. We're already integrated with SkyRL (Tinker coming soon)!
146 |
147 | - **Task Recipes and Ideas**:
148 |
149 | We want `benchmax` to be a living library of reusable, RL-compatible task recipes, ready to inspire and extend beyond the usual suspects like math and coding. We aim to support more real-world workflows, including open-ended and long-horizon tasks.
150 |
151 | - **Parallelization and Cloud Compatibility**:
152 | - Enable efficient parallelization with maintained statefulness between rollouts.
153 | - Facilitate easy deployment and scalability in cloud environments.
154 |
155 | - **MCP as a first class citizen**:
156 |
157 |   There has been an explosion of MCP servers/tools built for use cases ranging from browser use to Excel to game creation. `benchmax` lets you leverage and compose these existing MCP servers to build environments integrated with real-world systems (e.g. Excel).
158 |
159 |
160 | ## 🤝 Contributing
161 |
162 | We welcome new environment recipes, bug reports, and trainer integrations!
163 |
164 | ⭐ Star our repository to show your support!
165 |
166 | ## 📜 License
167 |
168 | Apache 2.0 © 2025 CGFT Inc.
169 |
--------------------------------------------------------------------------------
/tests/unit/envs/wikipedia/test_wiki_env.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from unittest.mock import AsyncMock, MagicMock, patch
3 | from aiohttp import ClientResponse
4 |
5 | from benchmax.envs.wikipedia.wiki_env import WikipediaEnv
6 |
7 |
8 | @pytest.fixture
9 | def wiki_env() -> WikipediaEnv:
10 | """Create WikipediaEnv instance without API keys for unit tests."""
11 | return WikipediaEnv(wikipedia_api_keys=None)
12 |
13 |
14 | class TestDatasetPreprocess:
15 | """Tests for dataset preprocessing."""
16 |
17 | @pytest.mark.asyncio
18 | async def test_dataset_preprocess_valid_example(
19 | self, wiki_env: WikipediaEnv
20 | ) -> None:
21 | """Test preprocessing a valid dataset example."""
22 | example = {"Question": "Who created Python?", "Answer": "Guido van Rossum"}
23 |
24 | result = wiki_env.dataset_preprocess(example)
25 |
26 | assert result["prompt"] == "Who created Python?"
27 | assert result["ground_truth"] == "Guido van Rossum"
28 |
29 |
30 | class TestComputeReward:
31 | """Tests for reward computation."""
32 |
33 | @pytest.mark.asyncio
34 | async def test_compute_reward_exact_match(self, wiki_env: WikipediaEnv) -> None:
35 | """Test that exact match returns 1.0."""
36 |         completion = "The answer is <answer>Paris</answer>"
37 | ground_truth = "Paris"
38 |
39 | rewards = await wiki_env.compute_reward(
40 | rollout_id="test", completion=completion, ground_truth=ground_truth
41 | )
42 |
43 | assert rewards["text_match"] == 1.0
44 |
45 | @pytest.mark.asyncio
46 | async def test_compute_reward_case_insensitive(
47 | self, wiki_env: WikipediaEnv
48 | ) -> None:
49 | """Test that matching is case-insensitive."""
50 |         completion = "The answer is <answer>PARIS</answer>"
51 | ground_truth = "paris"
52 |
53 | rewards = await wiki_env.compute_reward(
54 | rollout_id="test", completion=completion, ground_truth=ground_truth
55 | )
56 |
57 | assert rewards["text_match"] == 1.0
58 |
59 | @pytest.mark.asyncio
60 | async def test_compute_reward_no_match(self, wiki_env: WikipediaEnv) -> None:
61 | """Test that wrong answer returns 0.0."""
62 |         completion = "The answer is <answer>London</answer>"
63 | ground_truth = "Paris"
64 |
65 | rewards = await wiki_env.compute_reward(
66 | rollout_id="test", completion=completion, ground_truth=ground_truth
67 | )
68 |
69 | assert rewards["text_match"] == 0.0
70 |
71 | @pytest.mark.asyncio
72 | async def test_compute_reward_missing_answer_tags(
73 | self, wiki_env: WikipediaEnv
74 | ) -> None:
75 | """Test that missing answer tags returns 0.0."""
76 | completion = "The answer is Paris"
77 | ground_truth = "Paris"
78 |
79 | rewards = await wiki_env.compute_reward(
80 | rollout_id="test", completion=completion, ground_truth=ground_truth
81 | )
82 |
83 | assert rewards["text_match"] == 0.0
84 |
85 |
86 | class TestListTools:
87 | """Tests for listing available tools."""
88 |
89 | @pytest.mark.asyncio
90 | async def test_list_tools_returns_two_tools(self, wiki_env: WikipediaEnv) -> None:
91 | """Test that exactly 2 tools are returned with correct names."""
92 | tools = await wiki_env.list_tools()
93 |
94 | assert len(tools) == 2
95 |
96 | tool_names = {tool.name for tool in tools}
97 | assert "search_wikipedia" in tool_names
98 | assert "get_wikipedia_article" in tool_names
99 |
100 | # Verify each tool has required schema properties
101 | for tool in tools:
102 | assert tool.name is not None
103 | assert tool.description is not None
104 | assert tool.input_schema is not None
105 | assert "type" in tool.input_schema
106 | assert "properties" in tool.input_schema
107 |
108 |
109 | class TestRunTool:
110 | """Tests for tool execution with mocked HTTP calls."""
111 |
112 |     @pytest.mark.asyncio
113 |     @patch("benchmax.envs.wikipedia.wiki_env.safe_request")
114 |     async def test_run_tool_search_wikipedia_success(
115 |         self, mock_safe_request: AsyncMock, wiki_env: WikipediaEnv
116 |     ) -> None:
117 |         """Test successful Wikipedia search with mocked response."""
118 |         mock_response = MagicMock(spec=ClientResponse)
119 |         mock_response.status = 200
120 |         mock_response.json = AsyncMock(
121 |             return_value={
122 |                 "query": {
123 |                     "search": [
124 |                         {
125 |                             "title": "Python (programming language)",
126 |                             "snippet": "Python is a high-level programming language",
127 |                         },
128 |                         {
129 |                             "title": "Python (genus)",
130 |                             "snippet": "Python is a genus of constricting snakes",
131 |                         },
132 |                     ]
133 |                 }
134 |             }
135 |         )
136 |         mock_safe_request.return_value = mock_response
137 | 
138 |         result = await wiki_env.run_tool(
139 |             rollout_id="test", tool_name="search_wikipedia", q="Python", limit=5
140 |         )
141 | 
142 |         assert isinstance(result, str)
143 |         assert "Python (programming language)" in result
144 |         assert "high-level programming language" in result
145 | 
146 | @pytest.mark.asyncio
147 | @patch("benchmax.envs.wikipedia.wiki_env.safe_request")
148 | async def test_run_tool_get_article_success(
149 | self, mock_safe_request: AsyncMock, wiki_env: WikipediaEnv
150 | ) -> None:
151 | """Test successful article fetch with mocked response."""
152 | mock_response = MagicMock(spec=ClientResponse)
153 | mock_response.status = 200
154 | mock_response.json = AsyncMock(
155 | return_value={
156 | "query": {
157 | "pages": {
158 | "12345": {
159 | "extract": "Python is a high-level, general-purpose programming language. Its design philosophy emphasizes code readability with the use of significant indentation."
160 | }
161 | }
162 | }
163 | }
164 | )
165 | mock_safe_request.return_value = mock_response
166 |
167 | result = await wiki_env.run_tool(
168 | rollout_id="test",
169 | tool_name="get_wikipedia_article",
170 | title="Python (programming language)",
171 | max_chars=1000,
172 | )
173 |
174 | assert isinstance(result, str)
175 | assert "Python is a high-level" in result
176 | assert "code readability" in result
177 |
178 | @pytest.mark.asyncio
179 | async def test_run_tool_missing_required_param(
180 | self, wiki_env: WikipediaEnv
181 | ) -> None:
182 | """Test error handling when required parameter is missing."""
183 | # Search without query
184 | result = await wiki_env.run_tool(
185 | rollout_id="test",
186 | tool_name="search_wikipedia",
187 | q="", # Empty query
188 | limit=5,
189 | )
190 |
191 | assert isinstance(result, str)
192 | assert "Error" in result
193 |
194 |
195 | class TestInitRollout:
196 | """Tests for rollout initialization."""
197 |
198 | @pytest.mark.asyncio
199 | async def test_init_rollout_completes(self, wiki_env: WikipediaEnv) -> None:
200 | """Test that init_rollout completes without errors (no-op)."""
201 | # Should complete without raising exceptions
202 | await wiki_env.init_rollout(rollout_id="test_rollout_1")
203 | await wiki_env.init_rollout(rollout_id="test_rollout_2", extra_arg="value")
204 |
205 | # No assertions needed - just verify no exceptions raised
206 | assert True
207 |
--------------------------------------------------------------------------------
/src/benchmax/envs/mcp/provisioners/skypilot_provisioner.py:
--------------------------------------------------------------------------------
1 | """
2 | SkyPilot provisioner for launching cloud-based server clusters.
3 | """
4 |
5 | import logging
6 | import uuid
7 | from pathlib import Path
8 | from typing import List, Optional
9 | import sky
10 |
11 | from .base_provisioner import BaseProvisioner
12 | from .utils import get_run_command, setup_sync_dir, cleanup_dir, get_setup_command
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | class SkypilotProvisioner(BaseProvisioner):
18 | """
19 | Provisioner that launches a SkyPilot cluster in the cloud.
20 |
21 | Use this for:
22 | - Production-scale parallel execution
23 | - Distributed benchmarking across many nodes
24 | - Cloud-based compute resource management
25 |
26 | Example:
27 | import sky
28 | provisioner = SkypilotProvisioner(
29 | workdir_path=Path("my_workdir"),
30 | cloud=sky.Azure(),
31 | num_nodes=5,
32 | servers_per_node=4,
33 | cpus=4,
34 | memory=16,
35 | )
36 | # Will provision 20 total servers (5 nodes * 4 servers/node)
37 | """
38 |
39 | def __init__(
40 | self,
41 | workdir_path: Path | str,
42 | cloud: sky.clouds.Cloud,
43 | num_nodes: int = 1,
44 | servers_per_node: int = 5,
45 | cpus: Optional[str | int] = "2+",
46 | memory: Optional[str | int] = "8+",
47 | base_cluster_name: str = "benchmax-env-cluster",
48 | ):
49 | """
50 | Initialize SkyPilot provisioner.
51 |
52 | Args:
53 | workdir_path: Path to workdir containing mcp_config.yaml, setup.sh, etc.
54 | cloud: SkyPilot cloud instance (e.g., sky.AWS(), sky.Azure(), sky.GCP()).
55 | num_nodes: Number of nodes in the cluster.
56 | servers_per_node: Number of proxy servers to run on each node.
57 | cpus: CPU requirement per node (e.g., "2+", 4, "8").
58 | memory: Memory requirement per node in GB (e.g., "16+", 32).
59 | base_cluster_name: Base name for the cluster (timestamp will be appended).
60 | """
61 | if num_nodes < 1:
62 | raise ValueError("num_nodes must be at least 1")
63 | if servers_per_node < 1 or servers_per_node > 100:
64 | raise ValueError("servers_per_node must be between 1 and 100")
65 |
66 | self._workdir_path = Path(workdir_path).absolute()
67 | self._cloud = cloud
68 | self._num_nodes = num_nodes
69 | self._servers_per_node = servers_per_node
70 | self._total_num_servers = num_nodes * servers_per_node
71 | self._cpus = cpus
72 | self._memory = memory
73 |
74 | # Generate unique cluster name
75 | unique_suffix = uuid.uuid4().hex[:4]
76 | self._cluster_name = f"{base_cluster_name}-{unique_suffix}"
77 |
78 | # Internal state
79 | self._sync_workdir: Optional[Path] = None
80 | self._cluster_provisioned: bool = False
81 |
82 | logger.info(
83 | f"SkypilotProvisioner configured: {num_nodes} nodes * {servers_per_node} servers/node = "
84 | f"{self._total_num_servers} total servers, cluster: '{self._cluster_name}'"
85 | )
86 |
87 | @property
88 | def num_servers(self) -> int:
89 | """
90 | Total number of servers
91 | """
92 | return self._total_num_servers
93 |
94 | async def provision_servers(self, api_secret: str) -> List[str]:
95 | """
96 | Launch SkyPilot cluster and return server addresses.
97 |
98 | Args:
99 | api_secret: API token for server authentication.
100 |
101 | Returns:
102 | List of server addresses in "host:port" format.
103 | """
104 | if self._cluster_provisioned:
105 | raise RuntimeError("Cluster already provisioned. Call teardown() first.")
106 |
107 | self._cluster_provisioned = True
108 |
109 | logger.info(f"Launching SkyPilot cluster '{self._cluster_name}'...")
110 |
111 | # Setup sync directory with workdir contents + proxy_server.py
112 | try:
113 | self._sync_workdir = setup_sync_dir(self._workdir_path)
114 | logger.debug(f"Synced workdir to temporary directory: {self._sync_workdir}")
115 | except Exception as e:
116 | logger.error(f"Failed to setup sync directory: {e}")
117 | raise
118 |
119 | # Calculate ports
120 | base_port = 8080
121 | all_ports = [str(base_port + i) for i in range(self._servers_per_node)]
122 |
123 | env = None if api_secret is None else {"API_SECRET": api_secret}
124 |
125 | # Configure SkyPilot task
126 | sky_task = sky.Task(
127 | name="mcp-server",
128 | run=get_run_command(ports=all_ports),
129 | setup=get_setup_command(),
130 | workdir=str(self._sync_workdir),
131 | num_nodes=self._num_nodes,
132 | envs=env,
133 | )
134 |
135 | sky_task.set_resources(
136 | sky.Resources(
137 | cloud=self._cloud,
138 | cpus=self._cpus,
139 | memory=self._memory,
140 | ports=all_ports,
141 | )
142 | )
143 |
144 | # Launch cluster
145 | logger.info(
146 | f"Submitting cluster launch: {self._num_nodes} nodes, "
147 | f"{self._cpus} CPUs, {self._memory}GB memory per node"
148 | )
149 | cluster_handle = None
150 | try:
151 | _, handle = sky.launch(
152 | task=sky_task,
153 | cluster_name=self._cluster_name,
154 | detach_run=True,
155 | detach_setup=True,
156 | retry_until_up=True,
157 | )
158 | cluster_handle = handle
159 | except Exception as e:
160 | logger.error(f"Failed to launch SkyPilot cluster: {e}")
161 | cleanup_dir(self._sync_workdir)
162 | self._sync_workdir = None
163 | raise RuntimeError(f"SkyPilot cluster launch failed: {e}") from e
164 |
165 | if cluster_handle is None:
166 | cleanup_dir(self._sync_workdir)
167 | self._sync_workdir = None
168 | raise RuntimeError("SkyPilot launch returned no handle")
169 |
170 | # Collect server addresses
171 | addresses = []
172 | for node_idx, (_, node_ip) in enumerate(
173 | cluster_handle.stable_internal_external_ips
174 | ):
175 | for port in all_ports:
176 | addresses.append(f"{node_ip}:{port}")
177 | logger.debug(f"Node {node_idx}: {node_ip} with {len(all_ports)} servers")
178 |
179 | logger.info(
180 | f"Successfully launched cluster '{self._cluster_name}' "
181 | f"with {len(addresses)} servers across {self._num_nodes} node(s)"
182 | )
183 | return addresses
184 |
185 | async def teardown(self) -> None:
186 | """
187 | Tear down SkyPilot cluster and clean up resources.
188 | """
189 |         if not self._cluster_provisioned:
190 | logger.warning("teardown() called but no cluster is active.")
191 | return
192 |
193 | logger.info(f"Tearing down SkyPilot cluster '{self._cluster_name}'...")
194 | try:
195 | sky.down(cluster_name=self._cluster_name)
196 | logger.info(f"Cluster '{self._cluster_name}' torn down successfully")
197 | except Exception as e:
198 | logger.error(f"Error tearing down cluster '{self._cluster_name}': {e}")
199 | finally:
200 | self._cluster_provisioned = False
201 | if self._sync_workdir:
202 | cleanup_dir(self._sync_workdir)
203 | self._sync_workdir = None
204 | logger.debug("Cleaned up sync directory")
205 |
206 | @property
207 | def cluster_name(self) -> str:
208 | """
209 | Unique name of the SkyPilot cluster.
210 | """
211 | return self._cluster_name
212 |
--------------------------------------------------------------------------------
/src/benchmax/envs/excel/workdir/excel_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import shutil
4 | import subprocess
5 | import platform
6 | import tempfile
7 | import datetime
8 | from typing import Tuple
9 |
10 | WHITE_LIKE_COLORS = [
11 | "00000000",
12 | "FFFFFFFF",
13 | "FFFFFF00",
14 | ]
15 |
16 |
17 | def evaluate_excel(excel_path: Path | str):
18 | """
19 |     Recalculate the Excel file's formulas in place (LibreOffice on Linux, xlwings on Windows/macOS).
20 | """
21 | if platform.system() == "Linux":
22 | # Use LibreOffice for Linux
23 | evaluate_excel_libre(excel_path)
24 | return
25 | else:
26 | # Use xlwings for Windows and MacOS (assuming Excel is installed)
27 | import xlwings # type: ignore
28 |
29 | excel_app = xlwings.App(visible=False)
30 | excel_book = excel_app.books.open(excel_path)
31 | excel_book.save()
32 | excel_book.close()
33 | excel_app.quit()
34 |
35 |
36 | def evaluate_excel_libre(excel_path: Path | str) -> None:
37 | """
38 | Force‑recalculate in place under Linux using LibreOffice.
39 | Raises subprocess.CalledProcessError if soffice exits abnormally.
40 | """
41 | tmp_outdir = tempfile.mkdtemp(prefix="lo_convert_")
42 | cmd = [
43 | "soffice",
44 | "--headless",
45 | "--nologo",
46 | "--nofirststartwizard",
47 | "--norestore",
48 | "--calc",
49 | "--convert-to",
50 | "xlsx",
51 | "--outdir",
52 | tmp_outdir,
53 | os.path.abspath(excel_path),
54 | ]
55 | lo_home = Path(tempfile.mkdtemp(prefix="lo_profile_"))
56 | env = dict(os.environ, HOME=str(lo_home))
57 | try:
58 | subprocess.run(
59 | cmd,
60 | check=True,
61 | stdout=subprocess.DEVNULL,
62 | stderr=subprocess.STDOUT,
63 | text=True,
64 | env=env,
65 | )
66 | # Determine the converted file name (same base name, .xlsx extension)
67 | base_name = os.path.splitext(os.path.basename(excel_path))[0] + ".xlsx"
68 | converted_path = os.path.join(tmp_outdir, base_name)
69 | # Overwrite the original file with the converted one
70 | shutil.move(converted_path, excel_path)
71 | finally:
72 |         # Clean up the temp folders (conversion output and LibreOffice profile)
73 |         shutil.rmtree(tmp_outdir, ignore_errors=True)
74 |         shutil.rmtree(lo_home, ignore_errors=True)
75 |
76 |
77 | def excel_to_str_repr(excel_path: Path | str, evaluate_formulas=False) -> str:
78 | from openpyxl import load_workbook
79 |
80 | # Load workbook twice: data_only=True to get the evaluated values,
81 | # and data_only=False to get the formulas and styles.
82 | if evaluate_formulas:
83 | evaluate_excel(excel_path)
84 |
85 | wb_evaluated = load_workbook(excel_path, data_only=True)
86 | wb_raw = load_workbook(excel_path, data_only=False)
87 |
88 | result = []
89 |
90 | for sheet_name in wb_evaluated.sheetnames:
91 | sheet_evaluated = wb_evaluated[sheet_name]
92 | sheet_raw = wb_raw[sheet_name]
93 |
94 | sheet_result = f"Sheet Name: {sheet_name}"
95 | result.append(sheet_result)
96 |
97 | for row_evaluated, row_raw in zip(
98 | sheet_evaluated.iter_rows(), sheet_raw.iter_rows()
99 | ):
100 | is_row_empty = True
101 |
102 | for cell_evaluated, cell_raw in zip(row_evaluated, row_raw):
103 | is_default_background = True
104 | style = []
105 |
106 | if (
107 | cell_raw.fill.start_color.index != "00000000"
108 | and type(cell_raw.fill.start_color.rgb) is str
109 | and cell_raw.fill.start_color.rgb not in WHITE_LIKE_COLORS
110 | ):
111 | is_default_background = False
112 | style.append(f"bg:{cell_raw.fill.start_color.rgb}")
113 | if (
114 | cell_raw.font.color
115 | and cell_raw.font.color.index != 1
116 | and type(cell_raw.font.color.rgb) is str
117 | ):
118 | style.append(f"color:{cell_raw.font.color.rgb}")
119 | if cell_raw.font.bold:
120 | style.append("bold")
121 | if cell_raw.font.italic:
122 | style.append("italic")
123 | if cell_raw.font.underline:
124 | style.append("underline")
125 |
126 | display_value = cell_evaluated.value
127 | if cell_raw.data_type == "f":
128 | cell_raw_val = cell_raw.value
129 | if type(cell_raw_val) is not str:
130 | cell_raw_val = cell_raw.value.text # type: ignore
131 | display_value = f"{cell_raw_val} -> {cell_evaluated.value}"
132 |
133 | coords = cell_evaluated.coordinate
134 |
135 | if display_value is None and not is_default_background:
136 | # If cell is empty but has background color, still include it
137 | result.append(f"{coords}: null [{', '.join(style)}]")
138 | is_row_empty = False
139 | elif display_value:
140 | style_str = f" [{', '.join(style)}]" if style else ""
141 | result.append(f"{coords}: {display_value}{style_str}")
142 | is_row_empty = False
143 | if not is_row_empty:
144 | result.append("") # Newline after each row
145 |
146 | return "\n".join(result)
147 |
148 |
149 | def transform_value(v):
150 | if isinstance(v, (int, float)):
151 | v = round(float(v), 2)
152 | elif isinstance(v, datetime.time):
153 | v = str(v)[:-3]
154 | elif isinstance(v, datetime.datetime):
155 | v = round(
156 | (v - datetime.datetime(1899, 12, 30)).days
157 | + (v - datetime.datetime(1899, 12, 30)).seconds / 86400.0,
158 | 0,
159 | )
160 | elif isinstance(v, str):
161 | try:
162 | v = round(float(v), 2)
163 | except ValueError:
164 | pass
165 | return v
166 |
167 |
168 | def compare_fill_color(fill1, fill2):
169 | fgColor1 = fill1.fgColor.rgb if fill1.fgColor else None
170 | fgColor2 = fill2.fgColor.rgb if fill2.fgColor else None
171 | bgColor1 = fill1.bgColor.rgb if fill1.bgColor else None
172 | bgColor2 = fill2.bgColor.rgb if fill2.bgColor else None
173 | return fgColor1 == fgColor2 and bgColor1 == bgColor2
174 |
175 |
176 | def compare_font_color(font1, font2):
177 | # UNSURE if this is actually correct.
178 | if font1.color and font2.color:
179 | return font1.color.rgb == font2.color.rgb
180 | return font1.color is None and font2.color is None
181 |
182 |
183 | def col_name2num(name):
184 | """Convert an Excel column name to a column number"""
185 | num = 0
186 | for c in name:
187 | num = num * 26 + (ord(c.upper()) - ord("A") + 1)
188 | return num
189 |
190 |
191 | def parse_cell_range(range_str):
192 | start_cell, end_cell = range_str.split(":")
193 | start_col, start_row = "", ""
194 | for char in start_cell:
195 | if char.isdigit():
196 | start_row += char
197 | else:
198 | start_col += char
199 | end_col, end_row = "", ""
200 | for char in end_cell:
201 | if char.isdigit():
202 | end_row += char
203 | else:
204 | end_col += char
205 | return (col_name2num(start_col), int(start_row)), (
206 | col_name2num(end_col),
207 | int(end_row),
208 | )
209 |
210 |
211 | def generate_cell_names(range_str):
212 | from openpyxl.utils import get_column_letter
213 |
214 | if ":" not in range_str:
215 | return [range_str]
216 | (start_col, start_row), (end_col, end_row) = parse_cell_range(range_str)
217 | columns = [get_column_letter(i) for i in range(start_col, end_col + 1)]
218 | return [f"{col}{row}" for col in columns for row in range(start_row, end_row + 1)]
219 |
220 |
221 | def compare_excel_cells(
222 | ground_truth_path: str, output_path: str, answer_position: str, is_CF: bool = False
223 | ) -> Tuple[bool, str]:
224 | from openpyxl import load_workbook
225 |
226 | wb_gt = load_workbook(ground_truth_path, data_only=True)
227 | wb_out = load_workbook(output_path, data_only=True)
228 | sheet_ranges = answer_position.split(",")
229 | for sheet_range in sheet_ranges:
230 | if "!" in sheet_range:
231 | sheet_name, cell_range = sheet_range.split("!")
232 | sheet_name = sheet_name.strip("'")
233 | else:
234 | sheet_name = wb_gt.sheetnames[0]
235 | cell_range = sheet_range
236 | if sheet_name not in wb_out.sheetnames:
237 | return False, f"Worksheet '{sheet_name}' not found in output workbook."
238 | ws_gt = wb_gt[sheet_name]
239 | ws_out = wb_out[sheet_name]
240 | cell_names = generate_cell_names(cell_range)
241 | for cell_name in cell_names:
242 | cell_gt = ws_gt[cell_name]
243 | cell_out = ws_out[cell_name]
244 | if not transform_value(cell_gt.value) == transform_value(cell_out.value):
245 | return (
246 | False,
247 | f"Value mismatch at {cell_name}: expected {cell_gt.value}, got {cell_out.value}",
248 | )
249 | if is_CF:
250 | if not compare_fill_color(cell_gt.fill, cell_out.fill):
251 | return False, f"Fill color mismatch at {cell_name}"
252 | if not compare_font_color(cell_gt.font, cell_out.font):
253 | return False, f"Font color mismatch at {cell_name}"
254 | return True, "All comparisons passed."
255 |
--------------------------------------------------------------------------------
/src/benchmax/envs/excel/excel_env.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from pathlib import Path
4 | from typing import Any, Dict, Tuple, Optional
5 |
6 | from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
7 | import sky
8 |
9 | from benchmax.envs.mcp.parallel_mcp_env import ParallelMcpEnv
10 | from benchmax.envs.mcp.provisioners.base_provisioner import BaseProvisioner
11 | from benchmax.envs.mcp.provisioners.local_provisioner import LocalProvisioner
12 | from benchmax.envs.mcp.provisioners.skypilot_provisioner import SkypilotProvisioner
13 | from benchmax.envs.types import StandardizedExample
14 | from .data_utils import download_and_extract
15 |
16 | # Using library shared with mcp workdir
17 | from .workdir.excel_utils import excel_to_str_repr
18 |
19 | SYSTEM_PROMPT = """You are a spreadsheet expert who can manipulate spreadsheets through Python code.
20 |
21 | You need to solve the given spreadsheet manipulation question, which contains six types of information:
22 | - instruction: The question about spreadsheet manipulation.
23 | - spreadsheet_path: The path of the spreadsheet file you need to manipulate.
24 | - spreadsheet_content: The content of the spreadsheet file.
25 | - instruction_type: There are two values (Cell-Level Manipulation, Sheet-Level Manipulation) used to indicate whether the answer to this question applies only to specific cells or to the entire worksheet.
26 | - answer_position: The position need to be modified or filled. For Cell-Level Manipulation questions, this field is filled with the cell position; for Sheet-Level Manipulation, it is the maximum range of cells you need to modify. You only need to modify or fill in values within the cell range specified by answer_position.
27 | - output_path: You need to generate the modified spreadsheet file in this new path.
28 | """
29 |
30 | DEFAULT_DATA_OUTPUT_PATH = os.path.expanduser("~/.cache/excel_data")
31 | SPREADSHEET_FULL = "all_data_912_v0.1"
32 | SPREADSHEET_SAMPLE = "sample_data_200"
33 |
34 | # Sample split by default; set SPREADSHEET_BENCH_TRAIN_DATA to SPREADSHEET_FULL for full-scale training
35 | SPREADSHEET_BENCH_TRAIN_DATA = SPREADSHEET_SAMPLE
36 |
37 |
38 | class ExcelExample(StandardizedExample):
39 | id: str
40 | answer_position: str
41 | output_filename: str
42 | ground_truth_filename: str
43 | spreadsheet_base_dir: str
44 |
45 |
46 | class ExcelEnv(ParallelMcpEnv):
47 | """Environment for spreadsheet manipulation tasks using MCP with Excel support"""
48 |
49 | system_prompt: str = SYSTEM_PROMPT
50 |
51 | def __init__(
52 | self,
53 | workdir_path: Path,
54 | provisioner: BaseProvisioner,
55 | **kwargs,
56 | ):
57 | """Initialize the ExcelEnv with an optional dataset path."""
58 | super().__init__(workdir_path=workdir_path, provisioner=provisioner, **kwargs)
59 |
60 | @classmethod
61 | def load_dataset(
62 | cls,
63 | dataset_name: str = "spreadsheetbench",
64 | data_output_path: str = DEFAULT_DATA_OUTPUT_PATH,
65 | **kwargs,
66 | ) -> Tuple[
67 | DatasetDict | Dataset | IterableDatasetDict | IterableDataset, str | None
68 | ]:
69 | # Currently only support spreadsheetbench dataset but can be extended to other datasets in the future
70 | if dataset_name == "spreadsheetbench":
71 | folder_path = Path(data_output_path) / SPREADSHEET_BENCH_TRAIN_DATA
72 | json_path = folder_path / "dataset.json"
73 | if not os.path.exists(json_path):
74 | download_and_extract(
75 | f"https://github.com/RUCKBReasoning/SpreadsheetBench/raw/refs/heads/main/data/{SPREADSHEET_BENCH_TRAIN_DATA}.tar.gz",
76 | data_output_path,
77 | )
78 | with open(json_path, "r") as f:
79 | data = json.load(f)
80 | for example in data:
81 | example["id"] = str(example["id"]) # Ensure IDs are strings
82 | dataset = Dataset.from_list(data)
83 | return dataset, str(folder_path)
84 | return super().load_dataset(dataset_name, **kwargs)
85 |
86 | @classmethod
87 | def dataset_preprocess(
88 | cls, example: Any, dataset_path: Optional[str | Path] = None, **kwargs
89 | ) -> ExcelExample:
90 | # convert dataset json into ExcelExample (a subclass of StandardizedExample)
91 | example_id: Optional[str] = example.get("id")
92 | spreadsheet_path: Optional[str] = example.get("spreadsheet_path")
93 | instruction: Optional[str] = example.get("instruction")
94 | instruction_type: Optional[str] = example.get("instruction_type")
95 | answer_position: Optional[str] = example.get("answer_position")
96 |
97 | if (
98 | not example_id
99 | or not spreadsheet_path
100 | or not instruction
101 | or not instruction_type
102 | or not answer_position
103 | ):
104 | raise ValueError(
105 | "Example must contain 'id', 'spreadsheet_path', 'instruction', 'instruction_type', and 'answer_position' fields"
106 | )
107 | if not isinstance(spreadsheet_path, str):
108 | raise TypeError("spreadsheet_path must be a string")
109 |
110 | if dataset_path is None:
111 | dataset_path = Path(DEFAULT_DATA_OUTPUT_PATH) / SPREADSHEET_BENCH_TRAIN_DATA
112 | elif not isinstance(dataset_path, (str, Path)):
113 | raise TypeError("dataset_path must be a str or Path")
114 |
115 | spreadsheet_base_dir = Path(dataset_path) / spreadsheet_path
116 |
117 |         if not os.path.exists(spreadsheet_base_dir):
118 | raise FileNotFoundError(
119 | f"Spreadsheet path {spreadsheet_base_dir} does not exist."
120 | )
121 |
122 | # File path in the workspace (input spreadsheet will be copied into the workspace at init_rollout)
123 | input_filename = f"1_{example_id}_input.xlsx"
124 | output_filename = f"1_{example_id}_output.xlsx"
125 | ground_truth_filename = f"1_{example_id}_answer.xlsx"
126 |
127 | input_src_path = spreadsheet_base_dir / input_filename
128 | input_spreadsheet_content = excel_to_str_repr(input_src_path, True)
129 |
130 | prompt = f"""
131 | Instruction: {instruction}
132 | Spreadsheet Path: {input_filename}
133 | Spreadsheet Content: {input_spreadsheet_content}
134 | Instruction Type: {instruction_type}
135 | Answer Position: {answer_position}
136 | Output Path: {output_filename}"""
137 |
138 | return ExcelExample(
139 | prompt=prompt.strip(),
140 | # Ground truth unused in ExcelEnv
141 | ground_truth=None,
142 | init_rollout_args={
143 | "input_src_path": str(input_src_path),
144 | },
145 | id=example_id,
146 | answer_position=answer_position,
147 | output_filename=output_filename,
148 | ground_truth_filename=ground_truth_filename,
149 | spreadsheet_base_dir=str(spreadsheet_base_dir),
150 | )
151 |
152 | async def init_rollout(self, rollout_id: str, **rollout_args):
153 | input_src_path: Optional[str] = rollout_args.get("input_src_path")
154 |
155 | if not input_src_path:
156 | raise ValueError("rollout_args must contain 'input_src_path' field")
157 |
158 | await super().init_rollout(rollout_id, **rollout_args)
159 | await self.copy_to_workspace(rollout_id, Path(input_src_path))
160 |
161 | async def compute_reward(
162 | self, rollout_id: str, completion: str, ground_truth: Any, **kwargs: Any
163 | ) -> Dict[str, float]:
164 | answer_position: Optional[str] = kwargs.get("answer_position")
165 | output_filename: Optional[str] = kwargs.get("output_filename")
166 | ground_truth_filename: Optional[str] = kwargs.get("ground_truth_filename")
167 | spreadsheet_base_dir: Optional[str] = kwargs.get("spreadsheet_base_dir")
168 |
169 | if (
170 | not answer_position
171 | or not output_filename
172 | or not ground_truth_filename
173 | or not spreadsheet_base_dir
174 | ):
175 | raise ValueError(
176 | "kwargs must contain 'answer_position', 'output_filename', 'ground_truth_filename', and 'spreadsheet_base_dir' fields"
177 | )
178 |
179 | # Copy ground truth file to workspace for reward computation
180 | await self.copy_to_workspace(
181 | rollout_id, Path(spreadsheet_base_dir) / ground_truth_filename
182 | )
183 | return await super().compute_reward(
184 | rollout_id,
185 | completion,
186 | ground_truth,
187 | answer_position=answer_position,
188 | output_filename=output_filename,
189 | ground_truth_filename=ground_truth_filename,
190 | )
191 |
192 |
193 | class ExcelEnvLocal(ExcelEnv):
194 | """Import this env to run environment locally"""
195 |
196 | def __init__(self, num_local_servers: int = 5, **kwargs):
197 | workdir_path = Path(__file__).parent / "workdir"
198 | provisioner = LocalProvisioner(
199 | workdir_path=workdir_path, num_servers=num_local_servers
200 | )
201 | super().__init__(workdir_path=workdir_path, provisioner=provisioner, **kwargs)
202 |
203 |
204 | class ExcelEnvSkypilot(ExcelEnv):
205 | """Import this env to run environment on any cloud (i.e. AWS / GCP / Azure) with Skypilot"""
206 |
207 | def __init__(
208 | self,
209 | cloud: sky.clouds.Cloud = sky.Azure(),
210 | num_nodes: int = 2,
211 | servers_per_node: int = 5,
212 | **kwargs,
213 | ):
214 | workdir_path = Path(__file__).parent / "workdir"
215 | provisioner = SkypilotProvisioner(
216 | workdir_path=workdir_path,
217 | cloud=cloud,
218 | num_nodes=num_nodes,
219 | servers_per_node=servers_per_node,
220 | )
221 | super().__init__(workdir_path=workdir_path, provisioner=provisioner, **kwargs)
222 |
--------------------------------------------------------------------------------