├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── auto_error_identification.py ├── few_shot_data ├── MockAirlineDomainEnv-few_shot.jsonl └── MockRetailDomainEnv-few_shot.jsonl ├── historical_trajectories ├── gpt-4o-airline.json ├── gpt-4o-retail.json ├── sonnet-35-new-airline.json └── sonnet-35-new-retail.json ├── run.py ├── setup.py └── tau_bench ├── __init__.py ├── agents ├── __init__.py ├── base.py ├── chat_react_agent.py ├── few_shot_agent.py └── tool_calling_agent.py ├── envs ├── __init__.py ├── airline │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── flights.json │ │ ├── reservations.json │ │ └── users.json │ ├── env.py │ ├── rules.py │ ├── tasks.py │ ├── tasks_test.py │ ├── tools │ │ ├── __init__.py │ │ ├── book_reservation.py │ │ ├── calculate.py │ │ ├── cancel_reservation.py │ │ ├── get_reservation_details.py │ │ ├── get_user_details.py │ │ ├── list_all_airports.py │ │ ├── search_direct_flight.py │ │ ├── search_onestop_flight.py │ │ ├── send_certificate.py │ │ ├── think.py │ │ ├── transfer_to_human_agents.py │ │ ├── update_reservation_baggages.py │ │ ├── update_reservation_flights.py │ │ └── update_reservation_passengers.py │ ├── wiki.md │ └── wiki.py ├── base.py ├── retail │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── orders.json │ │ ├── products.json │ │ ├── readme.md │ │ └── users.json │ ├── env.py │ ├── rules.py │ ├── tasks.py │ ├── tasks_dev.py │ ├── tasks_test.py │ ├── tasks_train.py │ ├── tools │ │ ├── __init__.py │ │ ├── calculate.py │ │ ├── cancel_pending_order.py │ │ ├── exchange_delivered_order_items.py │ │ ├── find_user_id_by_email.py │ │ ├── find_user_id_by_name_zip.py │ │ ├── get_order_details.py │ │ ├── get_product_details.py │ │ ├── get_user_details.py │ │ ├── list_all_product_types.py │ │ ├── modify_pending_order_address.py │ │ ├── modify_pending_order_items.py │ │ ├── modify_pending_order_payment.py │ │ ├── modify_user_address.py │ │ ├── return_delivered_order_items.py │ │ ├── think.py │ │ └── 
transfer_to_human_agents.py │ ├── wiki.md │ └── wiki.py ├── tool.py └── user.py ├── model_utils ├── __init__.py ├── api │ ├── __init__.py │ ├── _model_methods.py │ ├── api.py │ ├── cache.py │ ├── datapoint.py │ ├── exception.py │ ├── logging.py │ ├── router.py │ ├── sample.py │ ├── tokens.py │ └── types.py ├── args.py ├── func_tools │ ├── __init__.py │ ├── filter.py │ └── map.py └── model │ ├── __init__.py │ ├── anyscale.py │ ├── chat.py │ ├── claude.py │ ├── completion.py │ ├── exception.py │ ├── general_model.py │ ├── mistral.py │ ├── model.py │ ├── openai.py │ ├── outlines_completion.py │ ├── utils.py │ ├── vllm_chat.py │ ├── vllm_completion.py │ └── vllm_utils.py ├── run.py └── types.py /.gitignore: -------------------------------------------------------------------------------- 1 | #Run result artifacts 2 | results/** 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *.egg-info/ 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Sierra 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include tau_bench *.json 2 | recursive-include tau_bench *.md 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # τ-bench: A Benchmark for Tool-Agent-User Interaction in Real-World Domains 2 | 3 | **Paper**: [https://arxiv.org/abs/2406.12045](https://arxiv.org/abs/2406.12045) 4 | 5 | ## Leaderboard 6 | 7 | ### Airline 8 | 9 | | Strategy | Pass^1 | Pass^2 | Pass^3 | Pass^4 | 10 | | -------------- | ------ | ------ | ------ | ------ | 11 | | [TC (claude-3-5-sonnet-20241022)](https://www.anthropic.com/news/3-5-models-and-computer-use) | **0.460** | **0.326** | **0.263** | **0.225** | 12 | | [TC (gpt-4o)](https://platform.openai.com/docs/guides/function-calling) | 0.420 | 0.273 | 0.220 | 0.200 | 13 | | [TC (claude-3-5-sonnet-20240620)](https://docs.anthropic.com/en/docs/build-with-claude/tool-use) | 0.360 | 0.224 | 0.169 | 0.139 | 14 | | [TC (mistral-large-2407)](https://docs.mistral.ai/capabilities/function_calling/) | ?? | ?? | ?? | ?? 
| 15 | [TC (gpt-4o-mini)](https://platform.openai.com/docs/guides/function-calling) | 0.225 | 0.140 | 0.110 | 0.100 | 16 | [Act](https://arxiv.org/abs/2210.03629) (gpt-4o) | 0.365 | 0.217 | 0.160 | 0.140 | 17 | [ReAct](https://arxiv.org/abs/2210.03629) (gpt-4o) | 0.325 | 0.233 | 0.185 | 0.160 | 18 | 19 | ### Retail 20 | 21 | | Strategy | Pass^1 | Pass^2 | Pass^3 | Pass^4 | 22 | | -------------- | ------ | ------ | ------ | ------ | 23 | | [TC (claude-3-5-sonnet-20241022)](https://www.anthropic.com/news/3-5-models-and-computer-use) | **0.692** | **0.576** | **0.509** | **0.462** | 24 | | [TC (gpt-4o)](https://platform.openai.com/docs/guides/function-calling) | 0.604 | 0.491 | 0.430 | 0.383 | 25 | | [TC (claude-3-5-sonnet-20240620)](https://docs.anthropic.com/en/docs/build-with-claude/tool-use) | 0.626 | 0.506 | 0.435 | 0.387 | 26 | | [TC (mistral-large-2407)](https://docs.mistral.ai/capabilities/function_calling/) | ?? | ?? | ?? | ?? | 27 | | [TC (gpt-4o-mini)](https://platform.openai.com/docs/guides/function-calling) | ?? | ?? | ?? | ?? | 28 | | [Act](https://arxiv.org/abs/2210.03629) (gpt-4o) | ?? | ?? | ?? | ?? | 29 | | [ReAct](https://arxiv.org/abs/2210.03629) (gpt-4o) | ?? | ?? | ?? | ?? | 30 | 31 | *TC = `tool-calling` strategy (the function-calling strategy reported in the paper) 32 | 33 | ## Setup 34 | 35 | 1. Clone this repository: 36 | 37 | ```bash 38 | git clone https://github.com/sierra-research/tau-bench && cd ./tau-bench 39 | ``` 40 | 41 | 2. Install from source (which also installs required packages): 42 | 43 | ```bash 44 | pip install -e . 45 | ``` 46 | 47 | 3. Set up your OpenAI / Anthropic / Google / Mistral / AnyScale API keys as environment variables (use `export` so that processes such as `python run.py` inherit them): 48 | 49 | ```bash 50 | export OPENAI_API_KEY=... 51 | export ANTHROPIC_API_KEY=... 52 | export GOOGLE_API_KEY=... 53 | export MISTRAL_API_KEY=...
54 | ``` 55 | 56 | ## Run 57 | 58 | Run a tool-calling agent on the τ-retail environment: 59 | 60 | ```bash 61 | python run.py --agent-strategy tool-calling --env retail --model gpt-4o --model-provider openai --user-model gpt-4o --user-model-provider openai --user-strategy llm --max-concurrency 10 62 | ``` 63 | 64 | Set max concurrency according to your API limit(s). 65 | 66 | To run specific tasks, use the `--task-ids` flag. For example: 67 | 68 | ```bash 69 | python run.py --agent-strategy tool-calling --env retail --model gpt-4o --model-provider openai --user-model gpt-4o --user-model-provider openai --user-strategy llm --max-concurrency 10 --task-ids 2 4 6 70 | ``` 71 | 72 | This command will run only the tasks with IDs 2, 4, and 6. 73 | 74 | ## User simulators 75 | 76 | By default, we use `gpt-4o` as the user simulator with strategy `llm`. You can use other models by setting the `--user-model` flag, or other strategies by setting the `--user-strategy` flag. For example, run a tool-calling agent with a claude user simulator: 77 | 78 | ```bash 79 | python run.py --agent-strategy tool-calling --env retail --model gpt-4o --model-provider openai --max-concurrency 10 --user-model claude-3-5-sonnet-20240620 --user-model-provider anthropic --user-strategy llm 80 | ``` 81 | 82 | Other strategies: 83 | 84 | To run `react` user simulator: 85 | 86 | ```bash 87 | python run.py --agent-strategy tool-calling --env retail --model gpt-4o --model-provider openai --max-concurrency 10 --user-model gpt-4o --user-model-provider openai --user-strategy react 88 | ``` 89 | 90 | Example of a `react` user response: 91 | 92 | ```md 93 | Thought: 94 | I should provide my name and zip code as I wasn't given an email address to use. 95 | 96 | User Response: 97 | Sure, my name is Yusuf Rossi, and my zip code is 19122. 
98 | ``` 99 | 100 | To run `verify` user simulator: 101 | 102 | ```bash 103 | python run.py --agent-strategy tool-calling --env retail --model gpt-4o --model-provider openai --max-concurrency 10 --user-model gpt-4o --user-model-provider openai --user-strategy verify 104 | ``` 105 | 106 | This strategy uses a subsequent LLM verification step to check if the user simulator's response is satisfactory. If not, the user simulator will be prompted to generate a new response. 107 | 108 | To run `reflection` user simulator: 109 | 110 | ```bash 111 | python run.py --agent-strategy tool-calling --env retail --model gpt-4o --model-provider openai --max-concurrency 10 --user-model gpt-4o --user-model-provider openai --user-strategy reflection 112 | ``` 113 | 114 | This strategy uses a subsequent LLM verification step to check if the user simulator's response is satisfactory. If not, the user simulator will be prompted to reflect on its response and generate a new response. 115 | 116 | ## Auto error identification 117 | 118 | It is often difficult and time-consuming to manually identify specific error locations in trajectories, as they can be long and the constraints can be complex. We have provided an auto error identification tool that can do the following: 119 | 120 | 1. Fault assignment: determine the entity that is responsible for the fault (user, agent, environment) 121 | 2. Fault type classification: classify the type of fault (goal_partially_completed, used_wrong_tool, used_wrong_tool_argument, took_unintended_action) 122 | 123 | Both labels are accompanied by a description. 124 | 125 | To run the auto error identification, run: 126 | 127 | ```bash 128 | python auto_error_identification.py --env <airline/retail> --platform openai --results-path <the path to your results file here> --max-concurrency 16 --output-path test-auto-error-identification --max-num-failed-results 10 129 | ``` 130 | 131 | Please note that this feature utilizes an LLM, which may lead to inaccurate error identifications.
132 | 133 | *Notice: If an error is raised due to the structure of your results file, you may have to rerun the benchmark to produce a new results file. We have recently [rewritten](https://github.com/sierra-research/tau-bench/commit/043b544371757ebb3762b3d02a6675dfe0c41798) the benchmark to be more type-safe and extensible. 134 | 135 | ## Historical trajectories 136 | 137 | τ-bench might be expensive to run. We have provided a set of historical trajectories for the airline and retail environments in `./historical_trajectories`. 138 | 139 | If you would like to contribute your historical trajectories to this benchmark, please submit a PR! 140 | 141 | ## License 142 | 143 | See `./LICENSE`. 144 | 145 | ## Contact 146 | 147 | Please submit issues or pull requests if you find problems with the benchmark. 148 | 149 | ## Citation 150 | 151 | ```bibtex 152 | @misc{yao2024tau, 153 | title={$\tau$-bench: A Benchmark for Tool-Agent-User Interaction in Real-World Domains}, 154 | author={Shunyu Yao and Noah Shinn and Pedram Razavi and Karthik Narasimhan}, 155 | year={2024}, 156 | eprint={2406.12045}, 157 | archivePrefix={arXiv}, 158 | primaryClass={cs.AI}, 159 | url={https://arxiv.org/abs/2406.12045}, 160 | } 161 | ``` 162 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import argparse 4 | from tau_bench.types import RunConfig 5 | from tau_bench.run import run 6 | from litellm import provider_list 7 | from tau_bench.envs.user import UserStrategy 8 | 9 | 10 | def parse_args() -> RunConfig: 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--num-trials", type=int, default=1) 13 | parser.add_argument( 14 | "--env", type=str, choices=["retail", "airline"], default="retail" 15 | ) 16 | parser.add_argument( 17 | "--model", 18 | type=str, 19 | help="The model to use for the agent", 20 | ) 21 | parser.add_argument( 
22 | "--model-provider", 23 | type=str, 24 | choices=provider_list, 25 | help="The model provider for the agent", 26 | ) 27 | parser.add_argument( 28 | "--user-model", 29 | type=str, 30 | default="gpt-4o", 31 | help="The model to use for the user simulator", 32 | ) 33 | parser.add_argument( 34 | "--user-model-provider", 35 | type=str, 36 | choices=provider_list, 37 | help="The model provider for the user simulator", 38 | ) 39 | parser.add_argument( 40 | "--agent-strategy", 41 | type=str, 42 | default="tool-calling", 43 | choices=["tool-calling", "act", "react", "few-shot"], 44 | ) 45 | parser.add_argument( 46 | "--temperature", 47 | type=float, 48 | default=0.0, 49 | help="The sampling temperature for the action model", 50 | ) 51 | parser.add_argument( 52 | "--task-split", 53 | type=str, 54 | default="test", 55 | choices=["train", "test", "dev"], 56 | help="The split of tasks to run (only applies to the retail domain for now", 57 | ) 58 | parser.add_argument("--start-index", type=int, default=0) 59 | parser.add_argument("--end-index", type=int, default=-1, help="Run all tasks if -1") 60 | parser.add_argument("--task-ids", type=int, nargs="+", help="(Optional) run only the tasks with the given IDs") 61 | parser.add_argument("--log-dir", type=str, default="results") 62 | parser.add_argument( 63 | "--max-concurrency", 64 | type=int, 65 | default=1, 66 | help="Number of tasks to run in parallel", 67 | ) 68 | parser.add_argument("--seed", type=int, default=10) 69 | parser.add_argument("--shuffle", type=int, default=0) 70 | parser.add_argument("--user-strategy", type=str, default="llm", choices=[item.value for item in UserStrategy]) 71 | parser.add_argument("--few-shot-displays-path", type=str, help="Path to a jsonlines file containing few shot displays") 72 | args = parser.parse_args() 73 | print(args) 74 | return RunConfig( 75 | model_provider=args.model_provider, 76 | user_model_provider=args.user_model_provider, 77 | model=args.model, 78 | user_model=args.user_model, 
79 | num_trials=args.num_trials, 80 | env=args.env, 81 | agent_strategy=args.agent_strategy, 82 | temperature=args.temperature, 83 | task_split=args.task_split, 84 | start_index=args.start_index, 85 | end_index=args.end_index, 86 | task_ids=args.task_ids, 87 | log_dir=args.log_dir, 88 | max_concurrency=args.max_concurrency, 89 | seed=args.seed, 90 | shuffle=args.shuffle, 91 | user_strategy=args.user_strategy, 92 | few_shot_displays_path=args.few_shot_displays_path, 93 | ) 94 | 95 | 96 | def main(): 97 | config = parse_args() 98 | run(config) 99 | 100 | 101 | if __name__ == "__main__": 102 | main() 103 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from setuptools import find_packages, setup 4 | 5 | setup( 6 | name="tau_bench", 7 | version="0.1.0", 8 | description="The Tau-Bench package", 9 | long_description=open("README.md").read(), 10 | packages=find_packages(), 11 | include_package_data=True, 12 | install_requires=[ 13 | "openai>=1.13.3", 14 | "mistralai>=0.4.0", 15 | "anthropic>=0.26.1", 16 | "google-generativeai>=0.5.4", 17 | "tenacity>=8.3.0", 18 | "termcolor>=2.4.0", 19 | "numpy>=1.26.4", 20 | "litellm>=1.41.0", 21 | ], 22 | ) 23 | -------------------------------------------------------------------------------- /tau_bench/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from tau_bench.envs.base import Env as Env 4 | from tau_bench.agents.base import Agent as Agent 5 | -------------------------------------------------------------------------------- /tau_bench/agents/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | -------------------------------------------------------------------------------- /tau_bench/agents/base.py: 
-------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import abc 4 | from typing import Optional 5 | from tau_bench.envs.base import Env 6 | from tau_bench.types import SolveResult 7 | 8 | 9 | class Agent(abc.ABC): 10 | @abc.abstractmethod 11 | def solve( 12 | self, env: Env, task_index: Optional[int] = None, max_num_steps: int = 30 13 | ) -> SolveResult: 14 | raise NotImplementedError 15 | -------------------------------------------------------------------------------- /tau_bench/agents/chat_react_agent.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from litellm import completion 5 | 6 | from tau_bench.agents.base import Agent 7 | from tau_bench.envs.base import Env 8 | from tau_bench.types import ( 9 | Action, 10 | SolveResult, 11 | RESPOND_ACTION_NAME, 12 | RESPOND_ACTION_FIELD_NAME, 13 | ) 14 | from typing import Optional, List, Dict, Any, Tuple 15 | 16 | 17 | class ChatReActAgent(Agent): 18 | def __init__( 19 | self, 20 | tools_info: List[Dict[str, Any]], 21 | wiki: str, 22 | model: str, 23 | provider: str, 24 | use_reasoning: bool = True, 25 | temperature: float = 0.0, 26 | ) -> None: 27 | instruction = REACT_INSTRUCTION if use_reasoning else ACT_INSTRUCTION 28 | self.prompt = ( 29 | wiki + "\n#Available tools\n" + json.dumps(tools_info) + instruction 30 | ) 31 | self.model = model 32 | self.provider = provider 33 | self.temperature = temperature 34 | self.use_reasoning = use_reasoning 35 | self.tools_info = tools_info 36 | 37 | def generate_next_step( 38 | self, messages: List[Dict[str, Any]] 39 | ) -> Tuple[Dict[str, Any], Action, float]: 40 | res = completion( 41 | model=self.model, 42 | custom_llm_provider=self.provider, 43 | messages=messages, 44 | temperature=self.temperature, 45 | ) 46 | message = res.choices[0].message 47 | action_str = message.content.split("Action:")[-1].strip() 48 | try: 49 | action_parsed = 
json.loads(action_str) 50 | except json.JSONDecodeError: 51 | # this is a hack 52 | action_parsed = { 53 | "name": RESPOND_ACTION_NAME, 54 | "arguments": {RESPOND_ACTION_FIELD_NAME: action_str}, 55 | } 56 | assert "name" in action_parsed 57 | assert "arguments" in action_parsed 58 | action = Action(name=action_parsed["name"], kwargs=action_parsed["arguments"]) 59 | return message.model_dump(), action, res._hidden_params["response_cost"] 60 | 61 | def solve( 62 | self, env: Env, task_index: Optional[int] = None, max_num_steps: int = 30 63 | ) -> SolveResult: 64 | response = env.reset(task_index=task_index) 65 | reward = 0.0 66 | messages: List[Dict[str, Any]] = [ 67 | {"role": "system", "content": self.prompt}, 68 | {"role": "user", "content": response.observation}, 69 | ] 70 | total_cost = 0.0 71 | info = {} 72 | for _ in range(max_num_steps): 73 | message, action, cost = self.generate_next_step(messages) 74 | response = env.step(action) 75 | obs = response.observation 76 | reward = response.reward 77 | info = {**info, **response.info.model_dump()} 78 | if action.name != RESPOND_ACTION_NAME: 79 | obs = "API output: " + obs 80 | messages.extend( 81 | [ 82 | message, 83 | {"role": "user", "content": obs}, 84 | ] 85 | ) 86 | total_cost += cost 87 | if response.done: 88 | break 89 | return SolveResult( 90 | messages=messages, 91 | reward=reward, 92 | info=info, 93 | ) 94 | 95 | 96 | REACT_INSTRUCTION = f""" 97 | # Instruction 98 | You need to act as an agent that use the above tools to help the user according to the above policy. 99 | 100 | At each step, your generation should have exactly the following format: 101 | Thought: 102 | 103 | Action: 104 | {{"name": , "arguments": }} 105 | 106 | The Action will be parsed, so it must be valid JSON. 107 | 108 | You should not use made-up or placeholder arguments. 
109 | 110 | For example, if the user says "I want to know the current weather of San Francisco", and there is such a tool available 111 | {{ 112 | "type": "function", 113 | "function": {{ 114 | "name": "get_current_weather", 115 | "description": "Get the current weather", 116 | "parameters": {{ 117 | "type": "object", 118 | "properties": {{ 119 | "location": {{ 120 | "type": "string", 121 | "description": "The city and state, e.g. San Francisco, CA", 122 | }}, 123 | "format": {{ 124 | "type": "string", 125 | "enum": ["celsius", "fahrenheit"], 126 | "description": "The temperature unit to use. Infer this from the users location.", 127 | }}, 128 | }}, 129 | "required": ["location", "format"], 130 | }}, 131 | }} 132 | }} 133 | 134 | Your response can be like this: 135 | Thought: 136 | Since the user asks for the weather of San Francisco in USA, the unit should be in fahrenheit. I can query get_current_weather to get the weather. 137 | Action: 138 | {{"name": "get_current_weather", "arguments": {{"location": "San Francisco, CA", "format": "fahrenheit"}}}} 139 | 140 | And if the tool returns "70F", your response can be: 141 | Thought: 142 | I can answer the user now. 143 | Action: 144 | {{"name": {RESPOND_ACTION_NAME}, "arguments": {{"{RESPOND_ACTION_FIELD_NAME}": "The current weather of San Francisco is 70F."}}}} 145 | 146 | Try to be helpful and always follow the policy. 147 | """ 148 | 149 | 150 | ACT_INSTRUCTION = f""" 151 | # Instruction 152 | You need to act as an agent that use the above tools to help the user according to the above policy. 153 | 154 | At each step, your generation should have exactly the following format: 155 | 156 | Action: 157 | {{"name": , "arguments": }} 158 | 159 | You should not use made-up or placeholder arguments. 160 | 161 | The Action will be parsed, so it must be valid JSON. 
162 | 163 | For example, if the user says "I want to know the current weather of San Francisco", and there is such a tool available 164 | ```json 165 | {{ 166 | "type": "function", 167 | "function": {{ 168 | "name": "get_current_weather", 169 | "description": "Get the current weather", 170 | "parameters": {{ 171 | "type": "object", 172 | "properties": {{ 173 | "location": {{ 174 | "type": "string", 175 | "description": "The city and state, e.g. San Francisco, CA", 176 | }}, 177 | "format": {{ 178 | "type": "string", 179 | "enum": ["celsius", "fahrenheit"], 180 | "description": "The temperature unit to use. Infer this from the users location.", 181 | }}, 182 | }}, 183 | "required": ["location", "format"], 184 | }}, 185 | }} 186 | }} 187 | ``` 188 | 189 | Your response can be like this: 190 | Action: 191 | {{"name": "get_current_weather", "arguments": {{"location": "San Francisco, CA", "format": "fahrenheit"}}}} 192 | 193 | And if the tool returns "70F", your response can be: 194 | Action: 195 | {{"name": {RESPOND_ACTION_NAME}, "arguments": {{"{RESPOND_ACTION_FIELD_NAME}": "The current weather of San Francisco is 70F."}}}} 196 | 197 | Try to be helpful and always follow the policy. Always make sure you generate valid JSON only. 
198 | """ 199 | -------------------------------------------------------------------------------- /tau_bench/agents/few_shot_agent.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | import random 5 | from litellm import completion 6 | from typing import List, Optional, Dict, Any 7 | 8 | from tau_bench.agents.base import Agent 9 | from tau_bench.envs.base import Env 10 | from tau_bench.types import SolveResult, Action, RESPOND_ACTION_NAME 11 | 12 | 13 | class FewShotToolCallingAgent(Agent): 14 | def __init__( 15 | self, 16 | tools_info: List[Dict[str, Any]], 17 | wiki: str, 18 | model: str, 19 | provider: str, 20 | few_shot_displays: List[str], 21 | temperature: float = 0.0, 22 | num_few_shots: int = 5, 23 | ): 24 | self.tools_info = tools_info 25 | self.wiki = wiki 26 | self.model = model 27 | self.provider = provider 28 | if len(few_shot_displays) == 0: 29 | raise ValueError("Few shot displays are empty") 30 | elif len(few_shot_displays) < num_few_shots: 31 | raise ValueError(f"Few shot displays are less than num_few_shots requested: {len(few_shot_displays)} < {num_few_shots}") 32 | self.few_shot_displays = few_shot_displays 33 | self.temperature = temperature 34 | self.num_few_shots = num_few_shots 35 | def solve( 36 | self, env: Env, task_index: Optional[int] = None, max_num_steps: int = 30 37 | ) -> SolveResult: 38 | sampled_few_shot_displays = random.sample(self.few_shot_displays, self.num_few_shots) 39 | few_shots = "\n\n".join([f"Example {i+1}:\n{display}" for i, display in enumerate(sampled_few_shot_displays)]) 40 | total_cost = 0.0 41 | env_reset_res = env.reset(task_index=task_index) 42 | obs = env_reset_res.observation 43 | info = env_reset_res.info.model_dump() 44 | reward = 0.0 45 | messages: List[Dict[str, Any]] = [ 46 | {"role": "system", "content": f"{self.wiki}\n\n{few_shots}"}, 47 | {"role": "user", "content": obs}, 48 | ] 49 | for _ in range(max_num_steps): 50 | res = 
completion( 51 | messages=messages, 52 | model=self.model, 53 | custom_llm_provider=self.provider, 54 | tools=self.tools_info, 55 | temperature=self.temperature, 56 | ) 57 | next_message = res.choices[0].message.model_dump() 58 | total_cost += res._hidden_params["response_cost"] 59 | action = message_to_action(next_message) 60 | env_response = env.step(action) 61 | reward = env_response.reward 62 | info = {**info, **env_response.info.model_dump()} 63 | if action.name != RESPOND_ACTION_NAME: 64 | next_message["tool_calls"] = next_message["tool_calls"][:1] 65 | messages.extend( 66 | [ 67 | next_message, 68 | { 69 | "role": "tool", 70 | "tool_call_id": next_message["tool_calls"][0]["id"], 71 | "name": next_message["tool_calls"][0]["function"]["name"], 72 | "content": env_response.observation, 73 | }, 74 | ] 75 | ) 76 | else: 77 | messages.extend( 78 | [ 79 | next_message, 80 | {"role": "user", "content": env_response.observation}, 81 | ] 82 | ) 83 | if env_response.done: 84 | break 85 | return SolveResult( 86 | reward=reward, 87 | info=info, 88 | messages=messages, 89 | total_cost=total_cost, 90 | ) 91 | 92 | 93 | def message_to_action( 94 | message: Dict[str, Any], 95 | ) -> Action: 96 | if "tool_calls" in message and message["tool_calls"] is not None and len(message["tool_calls"]) > 0 and message["tool_calls"][0]["function"] is not None: 97 | tool_call = message["tool_calls"][0] 98 | return Action( 99 | name=tool_call["function"]["name"], 100 | kwargs=json.loads(tool_call["function"]["arguments"]), 101 | ) 102 | else: 103 | return Action(name=RESPOND_ACTION_NAME, kwargs={"content": message["content"]}) 104 | -------------------------------------------------------------------------------- /tau_bench/agents/tool_calling_agent.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from litellm import completion 5 | from typing import List, Optional, Dict, Any 6 | 7 | from tau_bench.agents.base 
import Agent 8 | from tau_bench.envs.base import Env 9 | from tau_bench.types import SolveResult, Action, RESPOND_ACTION_NAME 10 | 11 | 12 | class ToolCallingAgent(Agent): 13 | def __init__( 14 | self, 15 | tools_info: List[Dict[str, Any]], 16 | wiki: str, 17 | model: str, 18 | provider: str, 19 | temperature: float = 0.0, 20 | ): 21 | self.tools_info = tools_info 22 | self.wiki = wiki 23 | self.model = model 24 | self.provider = provider 25 | self.temperature = temperature 26 | 27 | def solve( 28 | self, env: Env, task_index: Optional[int] = None, max_num_steps: int = 30 29 | ) -> SolveResult: 30 | total_cost = 0.0 31 | env_reset_res = env.reset(task_index=task_index) 32 | obs = env_reset_res.observation 33 | info = env_reset_res.info.model_dump() 34 | reward = 0.0 35 | messages: List[Dict[str, Any]] = [ 36 | {"role": "system", "content": self.wiki}, 37 | {"role": "user", "content": obs}, 38 | ] 39 | for _ in range(max_num_steps): 40 | res = completion( 41 | messages=messages, 42 | model=self.model, 43 | custom_llm_provider=self.provider, 44 | tools=self.tools_info, 45 | temperature=self.temperature, 46 | ) 47 | next_message = res.choices[0].message.model_dump() 48 | total_cost += res._hidden_params["response_cost"] 49 | action = message_to_action(next_message) 50 | env_response = env.step(action) 51 | reward = env_response.reward 52 | info = {**info, **env_response.info.model_dump()} 53 | if action.name != RESPOND_ACTION_NAME: 54 | next_message["tool_calls"] = next_message["tool_calls"][:1] 55 | messages.extend( 56 | [ 57 | next_message, 58 | { 59 | "role": "tool", 60 | "tool_call_id": next_message["tool_calls"][0]["id"], 61 | "name": next_message["tool_calls"][0]["function"]["name"], 62 | "content": env_response.observation, 63 | }, 64 | ] 65 | ) 66 | else: 67 | messages.extend( 68 | [ 69 | next_message, 70 | {"role": "user", "content": env_response.observation}, 71 | ] 72 | ) 73 | if env_response.done: 74 | break 75 | return SolveResult( 76 | reward=reward, 77 
| info=info, 78 | messages=messages, 79 | total_cost=total_cost, 80 | ) 81 | 82 | 83 | def message_to_action( 84 | message: Dict[str, Any], 85 | ) -> Action: 86 | if "tool_calls" in message and message["tool_calls"] is not None and len(message["tool_calls"]) > 0 and message["tool_calls"][0]["function"] is not None: 87 | tool_call = message["tool_calls"][0] 88 | return Action( 89 | name=tool_call["function"]["name"], 90 | kwargs=json.loads(tool_call["function"]["arguments"]), 91 | ) 92 | else: 93 | return Action(name=RESPOND_ACTION_NAME, kwargs={"content": message["content"]}) 94 | -------------------------------------------------------------------------------- /tau_bench/envs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from typing import Optional, Union 4 | from tau_bench.envs.base import Env 5 | from tau_bench.envs.user import UserStrategy 6 | 7 | 8 | def get_env( 9 | env_name: str, 10 | user_strategy: Union[str, UserStrategy], 11 | user_model: str, 12 | task_split: str, 13 | user_provider: Optional[str] = None, 14 | task_index: Optional[int] = None, 15 | ) -> Env: 16 | if env_name == "retail": 17 | from tau_bench.envs.retail import MockRetailDomainEnv 18 | 19 | return MockRetailDomainEnv( 20 | user_strategy=user_strategy, 21 | user_model=user_model, 22 | task_split=task_split, 23 | user_provider=user_provider, 24 | task_index=task_index, 25 | ) 26 | elif env_name == "airline": 27 | from tau_bench.envs.airline import MockAirlineDomainEnv 28 | 29 | return MockAirlineDomainEnv( 30 | user_strategy=user_strategy, 31 | user_model=user_model, 32 | task_split=task_split, 33 | user_provider=user_provider, 34 | task_index=task_index, 35 | ) 36 | else: 37 | raise ValueError(f"Unknown environment: {env_name}") 38 | -------------------------------------------------------------------------------- /tau_bench/envs/airline/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from tau_bench.envs.airline.env import MockAirlineDomainEnv as MockAirlineDomainEnv 4 | -------------------------------------------------------------------------------- /tau_bench/envs/airline/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | import os 5 | from typing import Any 6 | 7 | FOLDER_PATH = os.path.dirname(__file__) 8 | 9 | 10 | def load_data() -> dict[str, Any]: 11 | with open(os.path.join(FOLDER_PATH, "flights.json")) as f: 12 | flight_data = json.load(f) 13 | with open(os.path.join(FOLDER_PATH, "reservations.json")) as f: 14 | reservation_data = json.load(f) 15 | with open(os.path.join(FOLDER_PATH, "users.json")) as f: 16 | user_data = json.load(f) 17 | return { 18 | "flights": flight_data, 19 | "reservations": reservation_data, 20 | "users": user_data, 21 | } 22 | -------------------------------------------------------------------------------- /tau_bench/envs/airline/env.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from tau_bench.envs.airline.data import load_data 4 | from tau_bench.envs.airline.rules import RULES 5 | from tau_bench.envs.airline.tools import ALL_TOOLS 6 | from tau_bench.envs.airline.wiki import WIKI 7 | from tau_bench.envs.base import Env 8 | from typing import Optional, Union 9 | from tau_bench.envs.user import UserStrategy 10 | 11 | 12 | class MockAirlineDomainEnv(Env): 13 | def __init__( 14 | self, 15 | user_strategy: Union[str, UserStrategy] = UserStrategy.LLM, 16 | user_model: str = "gpt-4o", 17 | user_provider: Optional[str] = None, 18 | task_split: str = "test", 19 | task_index: Optional[int] = None, 20 | ): 21 | match task_split: 22 | case "test": 23 | from tau_bench.envs.airline.tasks_test import TASKS as tasks 24 | case _: 25 | raise 
ValueError(f"Unknown task split: {task_split}") 26 | super().__init__( 27 | data_load_func=load_data, 28 | tools=ALL_TOOLS, 29 | tasks=tasks, 30 | wiki=WIKI, 31 | rules=RULES, 32 | user_strategy=user_strategy, 33 | user_model=user_model, 34 | user_provider=user_provider, 35 | task_index=task_index, 36 | ) 37 | self.terminate_tools = ["transfer_to_human_agents"] 38 | -------------------------------------------------------------------------------- /tau_bench/envs/airline/rules.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | RULES = [] 4 | -------------------------------------------------------------------------------- /tau_bench/envs/airline/tools/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from .book_reservation import BookReservation 4 | from .calculate import Calculate 5 | from .cancel_reservation import CancelReservation 6 | from .get_reservation_details import GetReservationDetails 7 | from .get_user_details import GetUserDetails 8 | from .list_all_airports import ListAllAirports 9 | from .search_direct_flight import SearchDirectFlight 10 | from .search_onestop_flight import SearchOnestopFlight 11 | from .send_certificate import SendCertificate 12 | from .think import Think 13 | from .transfer_to_human_agents import TransferToHumanAgents 14 | from .update_reservation_baggages import UpdateReservationBaggages 15 | from .update_reservation_flights import UpdateReservationFlights 16 | from .update_reservation_passengers import UpdateReservationPassengers 17 | 18 | ALL_TOOLS = [ 19 | BookReservation, 20 | Calculate, 21 | CancelReservation, 22 | GetReservationDetails, 23 | GetUserDetails, 24 | ListAllAirports, 25 | SearchDirectFlight, 26 | SearchOnestopFlight, 27 | SendCertificate, 28 | Think, 29 | TransferToHumanAgents, 30 | UpdateReservationBaggages, 31 | UpdateReservationFlights, 32 | UpdateReservationPassengers, 33 | 
] 34 | -------------------------------------------------------------------------------- /tau_bench/envs/airline/tools/calculate.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from typing import Any, Dict 4 | from tau_bench.envs.tool import Tool 5 | 6 | 7 | class Calculate(Tool): 8 | @staticmethod 9 | def invoke(data: Dict[str, Any], expression: str) -> str: 10 | if not all(char in "0123456789+-*/(). " for char in expression): 11 | return "Error: invalid characters in expression" 12 | try: 13 | return str(round(float(eval(expression, {"__builtins__": None}, {})), 2)) 14 | except Exception as e: 15 | return f"Error: {e}" 16 | 17 | @staticmethod 18 | def get_info() -> Dict[str, Any]: 19 | return { 20 | "type": "function", 21 | "function": { 22 | "name": "calculate", 23 | "description": "Calculate the result of a mathematical expression.", 24 | "parameters": { 25 | "type": "object", 26 | "properties": { 27 | "expression": { 28 | "type": "string", 29 | "description": "The mathematical expression to calculate, such as '2 + 2'. 
The expression can contain numbers, operators (+, -, *, /), parentheses, and spaces.", 30 | }, 31 | }, 32 | "required": ["expression"], 33 | }, 34 | }, 35 | } 36 | -------------------------------------------------------------------------------- /tau_bench/envs/airline/tools/cancel_reservation.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class CancelReservation(Tool): 9 | @staticmethod 10 | def invoke( 11 | data: Dict[str, Any], 12 | reservation_id: str, 13 | ) -> str: 14 | reservations = data["reservations"] 15 | if reservation_id not in reservations: 16 | return "Error: reservation not found" 17 | reservation = reservations[reservation_id] 18 | 19 | # reverse the payment 20 | refunds = [] 21 | for payment in reservation["payment_history"]: 22 | refunds.append( 23 | { 24 | "payment_id": payment["payment_id"], 25 | "amount": -payment["amount"], 26 | } 27 | ) 28 | reservation["payment_history"].extend(refunds) 29 | reservation["status"] = "cancelled" 30 | return json.dumps(reservation) 31 | 32 | @staticmethod 33 | def get_info() -> Dict[str, Any]: 34 | return { 35 | "type": "function", 36 | "function": { 37 | "name": "cancel_reservation", 38 | "description": "Cancel the whole reservation.", 39 | "parameters": { 40 | "type": "object", 41 | "properties": { 42 | "reservation_id": { 43 | "type": "string", 44 | "description": "The reservation ID, such as 'ZFA04Y'.", 45 | }, 46 | }, 47 | "required": ["reservation_id"], 48 | }, 49 | }, 50 | } 51 | -------------------------------------------------------------------------------- /tau_bench/envs/airline/tools/get_reservation_details.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class 
GetReservationDetails(Tool): 9 | @staticmethod 10 | def invoke(data: Dict[str, Any], reservation_id: str) -> str: 11 | reservations = data["reservations"] 12 | if reservation_id in reservations: 13 | return json.dumps(reservations[reservation_id]) 14 | return "Error: user not found" 15 | 16 | @staticmethod 17 | def get_info() -> Dict[str, Any]: 18 | return { 19 | "type": "function", 20 | "function": { 21 | "name": "get_reservation_details", 22 | "description": "Get the details of a reservation.", 23 | "parameters": { 24 | "type": "object", 25 | "properties": { 26 | "reservation_id": { 27 | "type": "string", 28 | "description": "The reservation id, such as '8JX2WO'.", 29 | }, 30 | }, 31 | "required": ["reservation_id"], 32 | }, 33 | }, 34 | } 35 | -------------------------------------------------------------------------------- /tau_bench/envs/airline/tools/get_user_details.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class GetUserDetails(Tool): 9 | @staticmethod 10 | def invoke(data: Dict[str, Any], user_id: str) -> str: 11 | users = data["users"] 12 | if user_id in users: 13 | return json.dumps(users[user_id]) 14 | return "Error: user not found" 15 | 16 | @staticmethod 17 | def get_info() -> Dict[str, Any]: 18 | return { 19 | "type": "function", 20 | "function": { 21 | "name": "get_user_details", 22 | "description": "Get the details of an user, including their reservations.", 23 | "parameters": { 24 | "type": "object", 25 | "properties": { 26 | "user_id": { 27 | "type": "string", 28 | "description": "The user id, such as 'sara_doe_496'.", 29 | }, 30 | }, 31 | "required": ["user_id"], 32 | }, 33 | }, 34 | } 35 | -------------------------------------------------------------------------------- /tau_bench/envs/airline/tools/list_all_airports.py: 
-------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class ListAllAirports(Tool): 9 | @staticmethod 10 | def invoke(data: Dict[str, Any]) -> str: 11 | airports = [ 12 | "SFO", 13 | "JFK", 14 | "LAX", 15 | "ORD", 16 | "DFW", 17 | "DEN", 18 | "SEA", 19 | "ATL", 20 | "MIA", 21 | "BOS", 22 | "PHX", 23 | "IAH", 24 | "LAS", 25 | "MCO", 26 | "EWR", 27 | "CLT", 28 | "MSP", 29 | "DTW", 30 | "PHL", 31 | "LGA", 32 | ] 33 | cities = [ 34 | "San Francisco", 35 | "New York", 36 | "Los Angeles", 37 | "Chicago", 38 | "Dallas", 39 | "Denver", 40 | "Seattle", 41 | "Atlanta", 42 | "Miami", 43 | "Boston", 44 | "Phoenix", 45 | "Houston", 46 | "Las Vegas", 47 | "Orlando", 48 | "Newark", 49 | "Charlotte", 50 | "Minneapolis", 51 | "Detroit", 52 | "Philadelphia", 53 | "LaGuardia", 54 | ] 55 | return json.dumps({airport: city for airport, city in zip(airports, cities)}) 56 | 57 | @staticmethod 58 | def get_info() -> Dict[str, Any]: 59 | return { 60 | "type": "function", 61 | "function": { 62 | "name": "list_all_airports", 63 | "description": "List all airports and their cities.", 64 | "parameters": { 65 | "type": "object", 66 | "properties": {}, 67 | "required": [], 68 | }, 69 | }, 70 | } 71 | -------------------------------------------------------------------------------- /tau_bench/envs/airline/tools/search_direct_flight.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class SearchDirectFlight(Tool): 9 | @staticmethod 10 | def invoke(data: Dict[str, Any], origin: str, destination: str, date: str) -> str: 11 | flights = data["flights"] 12 | results = [] 13 | for flight in flights.values(): 14 | if flight["origin"] == origin and flight["destination"] == destination: 15 | if 
( 16 | date in flight["dates"] 17 | and flight["dates"][date]["status"] == "available" 18 | ): 19 | # results add flight except dates, but add flight["datas"][date] 20 | results.append({k: v for k, v in flight.items() if k != "dates"}) 21 | results[-1].update(flight["dates"][date]) 22 | return json.dumps(results) 23 | 24 | @staticmethod 25 | def get_info() -> Dict[str, Any]: 26 | return { 27 | "type": "function", 28 | "function": { 29 | "name": "search_direct_flight", 30 | "description": "Search direct flights between two cities on a specific date.", 31 | "parameters": { 32 | "type": "object", 33 | "properties": { 34 | "origin": { 35 | "type": "string", 36 | "description": "The origin city airport in three letters, such as 'JFK'.", 37 | }, 38 | "destination": { 39 | "type": "string", 40 | "description": "The destination city airport in three letters, such as 'LAX'.", 41 | }, 42 | "date": { 43 | "type": "string", 44 | "description": "The date of the flight in the format 'YYYY-MM-DD', such as '2024-01-01'.", 45 | }, 46 | }, 47 | "required": ["origin", "destination", "date"], 48 | }, 49 | }, 50 | } 51 | -------------------------------------------------------------------------------- /tau_bench/envs/airline/tools/search_onestop_flight.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class SearchOnestopFlight(Tool): 9 | @staticmethod 10 | def invoke(data: Dict[str, Any], origin: str, destination: str, date: str) -> str: 11 | flights = data["flights"] 12 | results = [] 13 | for flight1 in flights.values(): 14 | if flight1["origin"] == origin: 15 | for flight2 in flights.values(): 16 | if ( 17 | flight2["destination"] == destination 18 | and flight1["destination"] == flight2["origin"] 19 | ): 20 | date2 = ( 21 | f"2024-05-{int(date[-2:])+1}" 22 | if "+1" in flight1["scheduled_arrival_time_est"] 23 | else date 
24 | ) 25 | if ( 26 | flight1["scheduled_arrival_time_est"] 27 | > flight2["scheduled_departure_time_est"] 28 | ): 29 | continue 30 | if date in flight1["dates"] and date2 in flight2["dates"]: 31 | if ( 32 | flight1["dates"][date]["status"] == "available" 33 | and flight2["dates"][date2]["status"] == "available" 34 | ): 35 | result1 = { 36 | k: v for k, v in flight1.items() if k != "dates" 37 | } 38 | result1.update(flight1["dates"][date]) 39 | result1["date"] = date 40 | result2 = { 41 | k: v for k, v in flight2.items() if k != "dates" 42 | } 43 | result2.update(flight2["dates"][date]) 44 | result2["date"] = date2 45 | results.append([result1, result2]) 46 | return json.dumps(results) 47 | 48 | @staticmethod 49 | def get_info() -> Dict[str, Any]: 50 | return { 51 | "type": "function", 52 | "function": { 53 | "name": "search_onestop_flight", 54 | "description": "Search direct flights between two cities on a specific date.", 55 | "parameters": { 56 | "type": "object", 57 | "properties": { 58 | "origin": { 59 | "type": "string", 60 | "description": "The origin city airport in three letters, such as 'JFK'.", 61 | }, 62 | "destination": { 63 | "type": "string", 64 | "description": "The destination city airport in three letters, such as 'LAX'.", 65 | }, 66 | "date": { 67 | "type": "string", 68 | "description": "The date of the flight in the format 'YYYY-MM-DD', such as '2024-05-01'.", 69 | }, 70 | }, 71 | "required": ["origin", "destination", "date"], 72 | }, 73 | }, 74 | } 75 | -------------------------------------------------------------------------------- /tau_bench/envs/airline/tools/send_certificate.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from typing import Any, Dict 4 | from tau_bench.envs.tool import Tool 5 | 6 | 7 | class SendCertificate(Tool): 8 | @staticmethod 9 | def invoke( 10 | data: Dict[str, Any], 11 | user_id: str, 12 | amount: int, 13 | ) -> str: 14 | users = data["users"] 15 | 
if user_id not in users: 16 | return "Error: user not found" 17 | user = users[user_id] 18 | 19 | # add a certificate, assume at most 3 cases per task 20 | for id in [3221322, 3221323, 3221324]: 21 | payment_id = f"certificate_{id}" 22 | if payment_id not in user["payment_methods"]: 23 | user["payment_methods"][payment_id] = { 24 | "source": "certificate", 25 | "amount": amount, 26 | "id": payment_id, 27 | } 28 | return f"Certificate {payment_id} added to user {user_id} with amount {amount}." 29 | 30 | @staticmethod 31 | def get_info() -> Dict[str, Any]: 32 | return { 33 | "type": "function", 34 | "function": { 35 | "name": "send_certificate", 36 | "description": "Send a certificate to a user. Be careful!", 37 | "parameters": { 38 | "type": "object", 39 | "properties": { 40 | "user_id": { 41 | "type": "string", 42 | "description": "The ID of the user to book the reservation, such as 'sara_doe_496'.", 43 | }, 44 | "amount": { 45 | "type": "number", 46 | "description": "Certificate amount to send.", 47 | }, 48 | }, 49 | "required": ["user_id", "amount"], 50 | }, 51 | }, 52 | } 53 | -------------------------------------------------------------------------------- /tau_bench/envs/airline/tools/think.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from typing import Any, Dict 4 | from tau_bench.envs.tool import Tool 5 | 6 | 7 | class Think(Tool): 8 | @staticmethod 9 | def invoke(data: Dict[str, Any], thought: str) -> str: 10 | return "" 11 | 12 | @staticmethod 13 | def get_info() -> Dict[str, Any]: 14 | return { 15 | "type": "function", 16 | "function": { 17 | "name": "think", 18 | "description": "Use the tool to think about something. It will not obtain new information or change the database, but just append the thought to the log. 
Use it when complex reasoning is needed.", 19 | "parameters": { 20 | "type": "object", 21 | "properties": { 22 | "thought": { 23 | "type": "string", 24 | "description": "A thought to think about.", 25 | }, 26 | }, 27 | "required": ["thought"], 28 | }, 29 | }, 30 | } 31 | -------------------------------------------------------------------------------- /tau_bench/envs/airline/tools/transfer_to_human_agents.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from typing import Any, Dict 4 | from tau_bench.envs.tool import Tool 5 | 6 | 7 | class TransferToHumanAgents(Tool): 8 | @staticmethod 9 | def invoke( 10 | data: Dict[str, Any], 11 | summary: str, 12 | ) -> str: 13 | return "Transfer successful" 14 | 15 | @staticmethod 16 | def get_info() -> Dict[str, Any]: 17 | return { 18 | "type": "function", 19 | "function": { 20 | "name": "transfer_to_human_agents", 21 | "description": "Transfer the user to a human agent, with a summary of the user's issue. 
Only transfer if the user explicitly asks for a human agent, or if the user's issue cannot be resolved by the agent with the available tools.", 22 | "parameters": { 23 | "type": "object", 24 | "properties": { 25 | "summary": { 26 | "type": "string", 27 | "description": "A summary of the user's issue.", 28 | }, 29 | }, 30 | "required": [ 31 | "summary", 32 | ], 33 | }, 34 | }, 35 | } 36 | -------------------------------------------------------------------------------- /tau_bench/envs/airline/tools/update_reservation_baggages.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class UpdateReservationBaggages(Tool): 9 | @staticmethod 10 | def invoke( 11 | data: Dict[str, Any], 12 | reservation_id: str, 13 | total_baggages: int, 14 | nonfree_baggages: int, 15 | payment_id: str, 16 | ) -> str: 17 | users, reservations = data["users"], data["reservations"] 18 | if reservation_id not in reservations: 19 | return "Error: reservation not found" 20 | reservation = reservations[reservation_id] 21 | 22 | total_price = 50 * max(0, nonfree_baggages - reservation["nonfree_baggages"]) 23 | if payment_id not in users[reservation["user_id"]]["payment_methods"]: 24 | return "Error: payment method not found" 25 | payment_method = users[reservation["user_id"]]["payment_methods"][payment_id] 26 | if payment_method["source"] == "certificate": 27 | return "Error: certificate cannot be used to update reservation" 28 | elif ( 29 | payment_method["source"] == "gift_card" 30 | and payment_method["amount"] < total_price 31 | ): 32 | return "Error: gift card balance is not enough" 33 | 34 | reservation["total_baggages"] = total_baggages 35 | reservation["nonfree_baggages"] = nonfree_baggages 36 | if payment_method["source"] == "gift_card": 37 | payment_method["amount"] -= total_price 38 | 39 | if total_price != 0: 40 | 
reservation["payment_history"].append( 41 | { 42 | "payment_id": payment_id, 43 | "amount": total_price, 44 | } 45 | ) 46 | 47 | return json.dumps(reservation) 48 | 49 | @staticmethod 50 | def get_info() -> Dict[str, Any]: 51 | return { 52 | "type": "function", 53 | "function": { 54 | "name": "update_reservation_baggages", 55 | "description": "Update the baggage information of a reservation.", 56 | "parameters": { 57 | "type": "object", 58 | "properties": { 59 | "reservation_id": { 60 | "type": "string", 61 | "description": "The reservation ID, such as 'ZFA04Y'.", 62 | }, 63 | "total_baggages": { 64 | "type": "integer", 65 | "description": "The updated total number of baggage items included in the reservation.", 66 | }, 67 | "nonfree_baggages": { 68 | "type": "integer", 69 | "description": "The updated number of non-free baggage items included in the reservation.", 70 | }, 71 | "payment_id": { 72 | "type": "string", 73 | "description": "The payment id stored in user profile, such as 'credit_card_7815826', 'gift_card_7815826', 'certificate_7815826'.", 74 | }, 75 | }, 76 | "required": [ 77 | "reservation_id", 78 | "total_baggages", 79 | "nonfree_baggages", 80 | "payment_id", 81 | ], 82 | }, 83 | }, 84 | } 85 | -------------------------------------------------------------------------------- /tau_bench/envs/airline/tools/update_reservation_flights.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from copy import deepcopy 5 | from typing import Any, Dict, List 6 | from tau_bench.envs.tool import Tool 7 | 8 | 9 | class UpdateReservationFlights(Tool): 10 | @staticmethod 11 | def invoke( 12 | data: Dict[str, Any], 13 | reservation_id: str, 14 | cabin: str, 15 | flights: List[Dict[str, Any]], 16 | payment_id: str, 17 | ) -> str: 18 | users, reservations = data["users"], data["reservations"] 19 | if reservation_id not in reservations: 20 | return "Error: reservation not found" 21 | reservation 
= reservations[reservation_id] 22 | 23 | # update flights and calculate price 24 | total_price = 0 25 | flights = deepcopy(flights) 26 | for flight in flights: 27 | # if existing flight, ignore 28 | if _ := [ 29 | f 30 | for f in reservation["flights"] 31 | if f["flight_number"] == flight["flight_number"] 32 | and f["date"] == flight["date"] 33 | and cabin == reservation["cabin"] 34 | ]: 35 | total_price += _[0]["price"] * len(reservation["passengers"]) 36 | flight["price"] = _[0]["price"] 37 | flight["origin"] = _[0]["origin"] 38 | flight["destination"] = _[0]["destination"] 39 | continue 40 | flight_number = flight["flight_number"] 41 | if flight_number not in data["flights"]: 42 | return f"Error: flight {flight_number} not found" 43 | flight_data = data["flights"][flight_number] 44 | if flight["date"] not in flight_data["dates"]: 45 | return ( 46 | f"Error: flight {flight_number} not found on date {flight['date']}" 47 | ) 48 | flight_date_data = flight_data["dates"][flight["date"]] 49 | if flight_date_data["status"] != "available": 50 | return f"Error: flight {flight_number} not available on date {flight['date']}" 51 | if flight_date_data["available_seats"][cabin] < len( 52 | reservation["passengers"] 53 | ): 54 | return f"Error: not enough seats on flight {flight_number}" 55 | flight["price"] = flight_date_data["prices"][cabin] 56 | flight["origin"] = flight_data["origin"] 57 | flight["destination"] = flight_data["destination"] 58 | total_price += flight["price"] * len(reservation["passengers"]) 59 | 60 | total_price -= sum(flight["price"] for flight in reservation["flights"]) * len( 61 | reservation["passengers"] 62 | ) 63 | 64 | # check payment 65 | if payment_id not in users[reservation["user_id"]]["payment_methods"]: 66 | return "Error: payment method not found" 67 | payment_method = users[reservation["user_id"]]["payment_methods"][payment_id] 68 | if payment_method["source"] == "certificate": 69 | return "Error: certificate cannot be used to update 
reservation" 70 | elif ( 71 | payment_method["source"] == "gift_card" 72 | and payment_method["amount"] < total_price 73 | ): 74 | return "Error: gift card balance is not enough" 75 | 76 | # if checks pass, deduct payment and update seats 77 | if payment_method["source"] == "gift_card": 78 | payment_method["amount"] -= total_price 79 | reservation["flights"] = flights 80 | if total_price != 0: 81 | reservation["payment_history"].append( 82 | { 83 | "payment_id": payment_id, 84 | "amount": total_price, 85 | } 86 | ) 87 | # do not make flight database update here, assume it takes time to be updated 88 | return json.dumps(reservation) 89 | 90 | @staticmethod 91 | def get_info() -> Dict[str, Any]: 92 | return { 93 | "type": "function", 94 | "function": { 95 | "name": "update_reservation_flights", 96 | "description": "Update the flight information of a reservation.", 97 | "parameters": { 98 | "type": "object", 99 | "properties": { 100 | "reservation_id": { 101 | "type": "string", 102 | "description": "The reservation ID, such as 'ZFA04Y'.", 103 | }, 104 | "cabin": { 105 | "type": "string", 106 | "enum": [ 107 | "basic_economy", 108 | "economy", 109 | "business", 110 | ], 111 | }, 112 | "flights": { 113 | "type": "array", 114 | "description": "An array of objects containing details about each piece of flight in the ENTIRE new reservation. 
Even if the a flight segment is not changed, it should still be included in the array.", 115 | "items": { 116 | "type": "object", 117 | "properties": { 118 | "flight_number": { 119 | "type": "string", 120 | "description": "Flight number, such as 'HAT001'.", 121 | }, 122 | "date": { 123 | "type": "string", 124 | "description": "The date for the flight in the format 'YYYY-MM-DD', such as '2024-05-01'.", 125 | }, 126 | }, 127 | "required": ["flight_number", "date"], 128 | }, 129 | }, 130 | "payment_id": { 131 | "type": "string", 132 | "description": "The payment id stored in user profile, such as 'credit_card_7815826', 'gift_card_7815826', 'certificate_7815826'.", 133 | }, 134 | }, 135 | "required": ["reservation_id", "cabin", "flights", "payment_id"], 136 | }, 137 | }, 138 | } 139 | -------------------------------------------------------------------------------- /tau_bench/envs/airline/tools/update_reservation_passengers.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict, List 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class UpdateReservationPassengers(Tool): 9 | @staticmethod 10 | def invoke( 11 | data: Dict[str, Any], 12 | reservation_id: str, 13 | passengers: List[Dict[str, Any]], 14 | ) -> str: 15 | reservations = data["reservations"] 16 | if reservation_id not in reservations: 17 | return "Error: reservation not found" 18 | reservation = reservations[reservation_id] 19 | if len(passengers) != len(reservation["passengers"]): 20 | return "Error: number of passengers does not match" 21 | reservation["passengers"] = passengers 22 | return json.dumps(reservation) 23 | 24 | @staticmethod 25 | def get_info() -> Dict[str, Any]: 26 | return { 27 | "type": "function", 28 | "function": { 29 | "name": "update_reservation_passengers", 30 | "description": "Update the passenger information of a reservation.", 31 | "parameters": { 32 | "type": "object", 33 | 
"properties": { 34 | "reservation_id": { 35 | "type": "string", 36 | "description": "The reservation ID, such as 'ZFA04Y'.", 37 | }, 38 | "passengers": { 39 | "type": "array", 40 | "description": "An array of objects containing details about each passenger.", 41 | "items": { 42 | "type": "object", 43 | "properties": { 44 | "first_name": { 45 | "type": "string", 46 | "description": "The first name of the passenger, such as 'Noah'.", 47 | }, 48 | "last_name": { 49 | "type": "string", 50 | "description": "The last name of the passenger, such as 'Brown'.", 51 | }, 52 | "dob": { 53 | "type": "string", 54 | "description": "The date of birth of the passenger in the format 'YYYY-MM-DD', such as '1990-01-01'.", 55 | }, 56 | }, 57 | "required": ["first_name", "last_name", "dob"], 58 | }, 59 | }, 60 | }, 61 | "required": ["reservation_id", "passengers"], 62 | }, 63 | }, 64 | } 65 | -------------------------------------------------------------------------------- /tau_bench/envs/airline/wiki.md: -------------------------------------------------------------------------------- 1 | # Airline Agent Policy 2 | 3 | The current time is 2024-05-15 15:00:00 EST. 4 | 5 | As an airline agent, you can help users book, modify, or cancel flight reservations. 6 | 7 | - Before taking any actions that update the booking database (booking, modifying flights, editing baggage, upgrading cabin class, or updating passenger information), you must list the action details and obtain explicit user confirmation (yes) to proceed. 8 | 9 | - You should not provide any information, knowledge, or procedures not provided by the user or available tools, or give subjective recommendations or comments. 10 | 11 | - You should only make one tool call at a time, and if you make a tool call, you should not respond to the user simultaneously. If you respond to the user, you should not make a tool call at the same time. 12 | 13 | - You should deny user requests that are against this policy. 
14 | 15 | - You should transfer the user to a human agent if and only if the request cannot be handled within the scope of your actions. 16 | 17 | ## Domain Basics 18 | 19 | - Each user has a profile containing user id, email, addresses, date of birth, payment methods, reservation numbers, and membership tier. 20 | 21 | - Each reservation has a reservation id, user id, trip type (one way, round trip), flights, passengers, payment methods, created time, baggages, and travel insurance information. 22 | 23 | - Each flight has a flight number, an origin, destination, scheduled departure and arrival time (local time), and for each date: 24 | - If the status is "available", the flight has not taken off, and available seats and prices are listed. 25 | - If the status is "delayed" or "on time", the flight has not taken off and cannot be booked. 26 | - If the status is "flying", the flight has taken off but not landed and cannot be booked. 27 | 28 | ## Book flight 29 | 30 | - The agent must first obtain the user id, then ask for the trip type, origin, and destination. 31 | 32 | - Passengers: Each reservation can have at most five passengers. The agent needs to collect the first name, last name, and date of birth for each passenger. All passengers must fly the same flights in the same cabin. 33 | 34 | - Payment: each reservation can use at most one travel certificate, at most one credit card, and at most three gift cards. The remaining amount of a travel certificate is not refundable. All payment methods must already be in the user profile for safety reasons. 35 | 36 | - Checked bag allowance: If the booking user is a regular member, 0 free checked bags for each basic economy passenger, 1 free checked bag for each economy passenger, and 2 free checked bags for each business passenger. If the booking user is a silver member, 1 free checked bag for each basic economy passenger, 2 free checked bags for each economy passenger, and 3 free checked bags for each business passenger.
If the booking user is a gold member, 2 free checked bags for each basic economy passenger, 3 free checked bags for each economy passenger, and 3 free checked bags for each business passenger. Each extra bag costs 50 dollars. 37 | 38 | - Travel insurance: the agent should ask if the user wants to buy travel insurance, which costs 30 dollars per passenger and enables a full refund if the user needs to cancel the flight for health or weather reasons. 39 | 40 | ## Modify flight 41 | 42 | - The agent must first obtain the user id and the reservation id. 43 | 44 | - Change flights: Basic economy flights cannot be modified. Other reservations can be modified without changing the origin, destination, and trip type. Some flight segments can be kept, but their prices will not be updated based on the current price. The API does not check these for the agent, so the agent must make sure the rules apply before calling the API! 45 | 46 | - Change cabin: all reservations, including basic economy, can change cabin without changing the flights. Cabin changes require the user to pay for the difference between their current cabin and the new cabin class. Cabin class must be the same across all the flights in the same reservation; changing cabin for just one flight segment is not possible. 47 | 48 | - Change baggage and insurance: The user can add but not remove checked bags. The user cannot add insurance after the initial booking. 49 | 50 | - Change passengers: The user can modify passengers but cannot modify the number of passengers. This is something that even a human agent cannot assist with. 51 | 52 | - Payment: If the flights are changed, the user needs to provide one gift card or credit card as the payment or refund method. The agent should ask for the payment or refund method instead.
53 | 54 | ## Cancel flight 55 | 56 | - The agent must first obtain the user id, the reservation id, and the reason for cancellation (change of plan, airline cancelled flight, or other reasons). 57 | 58 | - All reservations can be cancelled within 24 hours of booking, or if the airline cancelled the flight. Otherwise, basic economy or economy flights can be cancelled only if travel insurance is bought and the condition is met, and business flights can always be cancelled. The rules are strict regardless of the membership status. The API does not check these for the agent, so the agent must make sure the rules apply before calling the API! 59 | 60 | - The agent can only cancel a whole trip that has not been flown. If any of the segments have already been used, the agent cannot help and a transfer is needed. 61 | 62 | - The refund will go to the original payment methods in 5 to 7 business days. 63 | 64 | ## Refund 65 | 66 | - If the user is a silver/gold member or has travel insurance or flies business, and complains about cancelled flights in a reservation, the agent can offer a certificate as a gesture after confirming the facts, with the amount being $100 times the number of passengers. 67 | 68 | - If the user is a silver/gold member or has travel insurance or flies business, and complains about delayed flights in a reservation and wants to change or cancel the reservation, the agent can offer a certificate as a gesture after confirming the facts and changing or cancelling the reservation, with the amount being $50 times the number of passengers. 69 | 70 | - Do not proactively offer these unless the user complains about the situation and explicitly asks for some compensation. Do not compensate if the user is a regular member and has no travel insurance and flies (basic) economy.
71 | -------------------------------------------------------------------------------- /tau_bench/envs/airline/wiki.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import os 4 | 5 | FOLDER_PATH = os.path.dirname(__file__) 6 | 7 | with open(os.path.join(FOLDER_PATH, "wiki.md"), "r") as f: 8 | WIKI = f.read() 9 | -------------------------------------------------------------------------------- /tau_bench/envs/base.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import random 4 | from hashlib import sha256 5 | from tau_bench.envs.tool import Tool 6 | from typing import Any, Callable, Dict, List, Type, Optional, Set, Union, Tuple 7 | 8 | from tau_bench.envs.user import load_user, UserStrategy 9 | from tau_bench.types import ( 10 | Action, 11 | Task, 12 | EnvInfo, 13 | EnvResetResponse, 14 | EnvResponse, 15 | RewardResult, 16 | RewardOutputInfo, 17 | RewardActionInfo, 18 | RESPOND_ACTION_NAME, 19 | ) 20 | 21 | ToHashable = Union[ 22 | str, int, float, Dict[str, "ToHashable"], List["ToHashable"], Set["ToHashable"] 23 | ] 24 | Hashable = Union[str, int, float, Tuple["Hashable"], Tuple[Tuple[str, "Hashable"]]] 25 | 26 | 27 | def to_hashable(item: ToHashable) -> Hashable: 28 | if isinstance(item, dict): 29 | return tuple((key, to_hashable(value)) for key, value in sorted(item.items())) 30 | elif isinstance(item, list): 31 | return tuple(to_hashable(element) for element in item) 32 | elif isinstance(item, set): 33 | return tuple(sorted(to_hashable(element) for element in item)) 34 | else: 35 | return item 36 | 37 | 38 | def consistent_hash( 39 | value: Hashable, 40 | ) -> str: 41 | return sha256(str(value).encode("utf-8")).hexdigest() 42 | 43 | 44 | class Env(object): 45 | def __init__( 46 | self, 47 | data_load_func: Callable[[], Dict[str, Any]], 48 | tools: List[Type[Tool]], 49 | tasks: List[Task], 50 | wiki: str, 51 | rules: List[str], 52 | 
user_strategy: Union[str, UserStrategy], 53 | user_model: str, 54 | user_provider: Optional[str] = None, 55 | task_index: Optional[int] = None, 56 | ) -> None: 57 | super().__init__() 58 | self.data_load_func = data_load_func 59 | self.data = data_load_func() 60 | self.tools_map: Dict[str, Type[Tool]] = { 61 | tool.get_info()["function"]["name"]: tool for tool in tools 62 | } 63 | self.tools_info = [tool.get_info() for tool in tools] 64 | self.terminate_tools = [] 65 | self.tasks = tasks 66 | if task_index is not None: 67 | self.task_index = task_index 68 | else: 69 | self.task_index = random.randint(0, len(tasks)) 70 | self.task = tasks[self.task_index] 71 | self.wiki = wiki 72 | self.rules = rules 73 | self.user = load_user( 74 | user_strategy=user_strategy, model=user_model, provider=user_provider 75 | ) 76 | self.actions: List[Action] = [] 77 | 78 | def reset(self, task_index: Optional[int] = None) -> EnvResetResponse: 79 | if task_index is None: 80 | task_index = random.randint(0, len(self.tasks)) 81 | self.task_index = task_index 82 | self.data = self.data_load_func() 83 | self.task = self.tasks[task_index] 84 | self.actions = [] 85 | initial_observation = self.user.reset(instruction=self.task.instruction) 86 | return EnvResetResponse( 87 | observation=initial_observation, info=EnvInfo(task=self.task, source="user") 88 | ) 89 | 90 | def step(self, action: Action) -> EnvResponse: 91 | self.actions.append(action) 92 | 93 | info = EnvInfo(task=self.task) 94 | reward = 0 95 | done = False 96 | if action.name == RESPOND_ACTION_NAME: 97 | observation = self.user.step(action.kwargs["content"]) 98 | info.source = "user" 99 | done = "###STOP###" in observation 100 | elif action.name in self.tools_map: 101 | try: 102 | observation = self.tools_map[action.name].invoke( 103 | data=self.data, **action.kwargs 104 | ) 105 | except Exception as e: 106 | observation = f"Error: {e}" 107 | info.source = action.name 108 | if action.name in self.terminate_tools: 109 | done = True 
110 | else: 111 | observation = f"Unknown action {action.name}" 112 | info.source = action.name 113 | 114 | if done: 115 | reward_res = self.calculate_reward() 116 | reward = reward_res.reward 117 | info.reward_info = reward_res 118 | info.user_cost = self.user.get_total_cost() 119 | return EnvResponse(observation=observation, reward=reward, done=done, info=info) 120 | 121 | def get_data_hash(self) -> str: 122 | return consistent_hash(to_hashable(self.data)) 123 | 124 | def calculate_reward(self) -> RewardResult: 125 | data_hash = self.get_data_hash() 126 | reward = 1.0 127 | actions = [ 128 | action for action in self.task.actions if action.name != RESPOND_ACTION_NAME 129 | ] 130 | 131 | # Check if the database changes are correct. If they are not correct, then we set the reward to 0. 132 | # TODO: cache gt_data_hash in tasks.py (low priority) 133 | self.data = self.data_load_func() 134 | for action in self.task.actions: 135 | if action.name not in self.terminate_tools: 136 | self.step(action) 137 | gt_data_hash = self.get_data_hash() 138 | info = RewardActionInfo( 139 | r_actions=data_hash == gt_data_hash, gt_data_hash=gt_data_hash 140 | ) 141 | if not info.r_actions: 142 | reward = 0.0 143 | 144 | if len(self.task.outputs) > 0: 145 | # check outputs 146 | r_outputs = 1.0 147 | outputs = {} 148 | for output in self.task.outputs: 149 | found = False 150 | for action in self.actions: 151 | if ( 152 | action.name == RESPOND_ACTION_NAME 153 | and output.lower() 154 | in action.kwargs["content"].lower().replace(",", "") 155 | ): 156 | found = True 157 | break 158 | outputs[output] = found 159 | if not found: 160 | r_outputs = 0.0 161 | reward = 0.0 162 | info = RewardOutputInfo(r_outputs=r_outputs, outputs=outputs) 163 | 164 | return RewardResult(reward=reward, info=info, actions=actions) 165 | -------------------------------------------------------------------------------- /tau_bench/envs/retail/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from tau_bench.envs.retail.env import MockRetailDomainEnv as MockRetailDomainEnv 4 | -------------------------------------------------------------------------------- /tau_bench/envs/retail/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | import os 5 | from typing import Any 6 | 7 | FOLDER_PATH = os.path.dirname(__file__) 8 | 9 | 10 | def load_data() -> dict[str, Any]: 11 | with open(os.path.join(FOLDER_PATH, "orders.json")) as f: 12 | order_data = json.load(f) 13 | with open(os.path.join(FOLDER_PATH, "products.json")) as f: 14 | product_data = json.load(f) 15 | with open(os.path.join(FOLDER_PATH, "users.json")) as f: 16 | user_data = json.load(f) 17 | return { 18 | "orders": order_data, 19 | "products": product_data, 20 | "users": user_data, 21 | } 22 | -------------------------------------------------------------------------------- /tau_bench/envs/retail/data/readme.md: -------------------------------------------------------------------------------- 1 | # Mock Data Generation 2 | 3 | ## Current Mock Data for the Benchmark 4 | Feel free to use some of the data for other purposes. 5 | - `users.json`: a database of users with their emails, addresses, and orders 6 | - `products.json`: a database of products, where each product has variants (e.g., size, color). 7 | - `orders.json`: a database of orders that can be operated upon. 8 | 9 | 10 | Check `../tools` for mock APIs on top of current mock data. 11 | 12 | 13 | ### Experience of Mock Data Generation 14 | 15 | Read our paper to learn more about the generation process for each database. In general, it involves the following stages: 16 | 17 | 1. Design the type and schema of each database. Can use GPT for co-brainstorming but has to be human decided as it is the foundation of everything else. 18 | 2. 
For each schema, figure out which parts can be programmaticly generated and which parts need GPT. For example, 19 | - Product types (shirt, lamp, pen) and user names (Sara, John, Noah) need GPT generation 20 | - Product price and shipping date can be generated via code 21 | 3. Use GPT to generate seed data (first names, last names, addresses, cities, etc.), then use a program to compose them with other code generated data. Can use GPT to help write the code for this part, but I think code-based database construction is more reliable than GPT-based database construction (e.g., give some example user profiles and ask GPT to generate more --- issues with diversity and reliability). 22 | -------------------------------------------------------------------------------- /tau_bench/envs/retail/env.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from tau_bench.envs.base import Env 4 | from tau_bench.envs.retail.data import load_data 5 | from tau_bench.envs.retail.rules import RULES 6 | from tau_bench.envs.retail.tools import ALL_TOOLS 7 | from tau_bench.envs.retail.wiki import WIKI 8 | from typing import Optional, Union 9 | from tau_bench.envs.user import UserStrategy 10 | 11 | 12 | class MockRetailDomainEnv(Env): 13 | def __init__( 14 | self, 15 | user_strategy: Union[str, UserStrategy] = UserStrategy.LLM, 16 | user_model: str = "gpt-4o", 17 | user_provider: Optional[str] = None, 18 | task_split: str = "test", 19 | task_index: Optional[int] = None, 20 | ): 21 | match task_split: 22 | case "test": 23 | from tau_bench.envs.retail.tasks_test import TASKS_TEST as tasks 24 | case "train": 25 | from tau_bench.envs.retail.tasks_train import TASKS_TRAIN as tasks 26 | case "dev": 27 | from tau_bench.envs.retail.tasks_dev import TASKS_DEV as tasks 28 | case _: 29 | raise ValueError(f"Unknown task split: {task_split}") 30 | super().__init__( 31 | data_load_func=load_data, 32 | tools=ALL_TOOLS, 33 | tasks=tasks, 34 | 
wiki=WIKI, 35 | rules=RULES, 36 | user_strategy=user_strategy, 37 | user_model=user_model, 38 | user_provider=user_provider, 39 | task_index=task_index, 40 | ) 41 | self.terminate_tools = ["transfer_to_human_agents"] 42 | -------------------------------------------------------------------------------- /tau_bench/envs/retail/rules.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | RULES = [ 4 | "You are a customer service representative for an online retail company. You are chatting with a customer, and you can call tools or respond to the user.", 5 | "The agent should always first confirm the user id by email or name+zip before proceeding with any task.", 6 | "The agent should not proceed with any task if the user id is not found.", 7 | "For any change to the backend database, e.g., address update, refund, or order cancellation, the agent must confirm the transaction details with the user and ask for permission, and get explicit authorization (yes) to proceed.", 8 | "The agent should solve the user task given the tools, without transferring to a human agent.", 9 | "The agent should not make up any information or knowledge not provided from the user or the tools.", 10 | "The agent should at most make one tool call at a time, and if the agent makes a tool call, it does not respond to the user at the same time.", 11 | ] 12 | -------------------------------------------------------------------------------- /tau_bench/envs/retail/tools/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from .calculate import Calculate 4 | from .cancel_pending_order import CancelPendingOrder 5 | from .exchange_delivered_order_items import ExchangeDeliveredOrderItems 6 | from .find_user_id_by_email import FindUserIdByEmail 7 | from .find_user_id_by_name_zip import FindUserIdByNameZip 8 | from .get_order_details import GetOrderDetails 9 | from 
.get_product_details import GetProductDetails 10 | from .get_user_details import GetUserDetails 11 | from .list_all_product_types import ListAllProductTypes 12 | from .modify_pending_order_address import ModifyPendingOrderAddress 13 | from .modify_pending_order_items import ModifyPendingOrderItems 14 | from .modify_pending_order_payment import ModifyPendingOrderPayment 15 | from .modify_user_address import ModifyUserAddress 16 | from .return_delivered_order_items import ReturnDeliveredOrderItems 17 | from .think import Think 18 | from .transfer_to_human_agents import TransferToHumanAgents 19 | 20 | 21 | ALL_TOOLS = [ 22 | Calculate, 23 | CancelPendingOrder, 24 | ExchangeDeliveredOrderItems, 25 | FindUserIdByEmail, 26 | FindUserIdByNameZip, 27 | GetOrderDetails, 28 | GetProductDetails, 29 | GetUserDetails, 30 | ListAllProductTypes, 31 | ModifyPendingOrderAddress, 32 | ModifyPendingOrderItems, 33 | ModifyPendingOrderPayment, 34 | ModifyUserAddress, 35 | ReturnDeliveredOrderItems, 36 | Think, 37 | TransferToHumanAgents, 38 | ] 39 | -------------------------------------------------------------------------------- /tau_bench/envs/retail/tools/calculate.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from typing import Any, Dict 4 | from tau_bench.envs.tool import Tool 5 | 6 | 7 | class Calculate(Tool): 8 | @staticmethod 9 | def invoke(data: Dict[str, Any], expression: str) -> str: 10 | if not all(char in "0123456789+-*/(). 
" for char in expression): 11 | return "Error: invalid characters in expression" 12 | try: 13 | # Evaluate the mathematical expression safely 14 | return str(round(float(eval(expression, {"__builtins__": None}, {})), 2)) 15 | except Exception as e: 16 | return f"Error: {e}" 17 | 18 | @staticmethod 19 | def get_info() -> Dict[str, Any]: 20 | return { 21 | "type": "function", 22 | "function": { 23 | "name": "calculate", 24 | "description": "Calculate the result of a mathematical expression.", 25 | "parameters": { 26 | "type": "object", 27 | "properties": { 28 | "expression": { 29 | "type": "string", 30 | "description": "The mathematical expression to calculate, such as '2 + 2'. The expression can contain numbers, operators (+, -, *, /), parentheses, and spaces.", 31 | }, 32 | }, 33 | "required": ["expression"], 34 | }, 35 | }, 36 | } 37 | -------------------------------------------------------------------------------- /tau_bench/envs/retail/tools/cancel_pending_order.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class CancelPendingOrder(Tool): 9 | @staticmethod 10 | def invoke(data: Dict[str, Any], order_id: str, reason: str) -> str: 11 | # check order exists and is pending 12 | orders = data["orders"] 13 | if order_id not in orders: 14 | return "Error: order not found" 15 | order = orders[order_id] 16 | if order["status"] != "pending": 17 | return "Error: non-pending order cannot be cancelled" 18 | 19 | # check reason 20 | if reason not in ["no longer needed", "ordered by mistake"]: 21 | return "Error: invalid reason" 22 | 23 | # handle refund 24 | refunds = [] 25 | for payment in order["payment_history"]: 26 | payment_id = payment["payment_method_id"] 27 | refund = { 28 | "transaction_type": "refund", 29 | "amount": payment["amount"], 30 | "payment_method_id": payment_id, 31 | } 32 | 
refunds.append(refund) 33 | if "gift_card" in payment_id: # refund to gift card immediately 34 | payment_method = data["users"][order["user_id"]]["payment_methods"][ 35 | payment_id 36 | ] 37 | payment_method["balance"] += payment["amount"] 38 | payment_method["balance"] = round(payment_method["balance"], 2) 39 | 40 | # update order status 41 | order["status"] = "cancelled" 42 | order["cancel_reason"] = reason 43 | order["payment_history"].extend(refunds) 44 | 45 | return json.dumps(order) 46 | 47 | @staticmethod 48 | def get_info() -> Dict[str, Any]: 49 | return { 50 | "type": "function", 51 | "function": { 52 | "name": "cancel_pending_order", 53 | "description": ( 54 | "Cancel a pending order. If the order is already processed or delivered, " 55 | "it cannot be cancelled. The agent needs to explain the cancellation detail " 56 | "and ask for explicit user confirmation (yes/no) to proceed. If the user confirms, " 57 | "the order status will be changed to 'cancelled' and the payment will be refunded. " 58 | "The refund will be added to the user's gift card balance immediately if the payment " 59 | "was made using a gift card, otherwise the refund would take 5-7 business days to process. " 60 | "The function returns the order details after the cancellation." 61 | ), 62 | "parameters": { 63 | "type": "object", 64 | "properties": { 65 | "order_id": { 66 | "type": "string", 67 | "description": "The order id, such as '#W0000000'. 
Be careful there is a '#' symbol at the beginning of the order id.", 68 | }, 69 | "reason": { 70 | "type": "string", 71 | "enum": ["no longer needed", "ordered by mistake"], 72 | "description": "The reason for cancellation, which should be either 'no longer needed' or 'ordered by mistake'.", 73 | }, 74 | }, 75 | "required": ["order_id", "reason"], 76 | }, 77 | }, 78 | } 79 | -------------------------------------------------------------------------------- /tau_bench/envs/retail/tools/exchange_delivered_order_items.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict, List 5 | 6 | from tau_bench.envs.tool import Tool 7 | 8 | 9 | class ExchangeDeliveredOrderItems(Tool): 10 | @staticmethod 11 | def invoke( 12 | data: Dict[str, Any], 13 | order_id: str, 14 | item_ids: List[str], 15 | new_item_ids: List[str], 16 | payment_method_id: str, 17 | ) -> str: 18 | products, orders, users = data["products"], data["orders"], data["users"] 19 | 20 | # check order exists and is delivered 21 | if order_id not in orders: 22 | return "Error: order not found" 23 | order = orders[order_id] 24 | if order["status"] != "delivered": 25 | return "Error: non-delivered order cannot be exchanged" 26 | 27 | # check the items to be exchanged exist 28 | all_item_ids = [item["item_id"] for item in order["items"]] 29 | for item_id in item_ids: 30 | if item_ids.count(item_id) > all_item_ids.count(item_id): 31 | return f"Error: {item_id} not found" 32 | 33 | # check new items exist and match old items and are available 34 | if len(item_ids) != len(new_item_ids): 35 | return "Error: the number of items to be exchanged should match" 36 | 37 | diff_price = 0 38 | for item_id, new_item_id in zip(item_ids, new_item_ids): 39 | item = [item for item in order["items"] if item["item_id"] == item_id][0] 40 | product_id = item["product_id"] 41 | if not ( 42 | new_item_id in products[product_id]["variants"] 
43 | and products[product_id]["variants"][new_item_id]["available"] 44 | ): 45 | return f"Error: new item {new_item_id} not found or available" 46 | 47 | old_price = item["price"] 48 | new_price = products[product_id]["variants"][new_item_id]["price"] 49 | diff_price += new_price - old_price 50 | 51 | diff_price = round(diff_price, 2) 52 | 53 | # check payment method exists and can cover the price difference if gift card 54 | if payment_method_id not in users[order["user_id"]]["payment_methods"]: 55 | return "Error: payment method not found" 56 | 57 | payment_method = users[order["user_id"]]["payment_methods"][payment_method_id] 58 | if ( 59 | payment_method["source"] == "gift_card" 60 | and payment_method["balance"] < diff_price 61 | ): 62 | return ( 63 | "Error: insufficient gift card balance to pay for the price difference" 64 | ) 65 | 66 | # modify the order 67 | order["status"] = "exchange requested" 68 | order["exchange_items"] = sorted(item_ids) 69 | order["exchange_new_items"] = sorted(new_item_ids) 70 | order["exchange_payment_method_id"] = payment_method_id 71 | order["exchange_price_difference"] = diff_price 72 | 73 | return json.dumps(order) 74 | 75 | @staticmethod 76 | def get_info() -> Dict[str, Any]: 77 | return { 78 | "type": "function", 79 | "function": { 80 | "name": "exchange_delivered_order_items", 81 | "description": ( 82 | "Exchange items in a delivered order to new items of the same product type. " 83 | "For a delivered order, return or exchange can be only done once by the agent. " 84 | "The agent needs to explain the exchange detail and ask for explicit user confirmation (yes/no) to proceed." 85 | ), 86 | "parameters": { 87 | "type": "object", 88 | "properties": { 89 | "order_id": { 90 | "type": "string", 91 | "description": "The order id, such as '#W0000000'. 
Be careful there is a '#' symbol at the beginning of the order id.", 92 | }, 93 | "item_ids": { 94 | "type": "array", 95 | "items": { 96 | "type": "string", 97 | }, 98 | "description": "The item ids to be exchanged, each such as '1008292230'. There could be duplicate items in the list.", 99 | }, 100 | "new_item_ids": { 101 | "type": "array", 102 | "items": { 103 | "type": "string", 104 | }, 105 | "description": ( 106 | "The item ids to be exchanged for, each such as '1008292230'. " 107 | "There could be duplicate items in the list. Each new item id should match the item id in the same position and be of the same product." 108 | ), 109 | }, 110 | "payment_method_id": { 111 | "type": "string", 112 | "description": ( 113 | "The payment method id to pay or receive refund for the item price difference, " 114 | "such as 'gift_card_0000000' or 'credit_card_0000000'. These can be looked up from the user or order details." 115 | ), 116 | }, 117 | }, 118 | "required": [ 119 | "order_id", 120 | "item_ids", 121 | "new_item_ids", 122 | "payment_method_id", 123 | ], 124 | }, 125 | }, 126 | } 127 | -------------------------------------------------------------------------------- /tau_bench/envs/retail/tools/find_user_id_by_email.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from typing import Any, Dict 4 | from tau_bench.envs.tool import Tool 5 | 6 | 7 | class FindUserIdByEmail(Tool): 8 | @staticmethod 9 | def invoke(data: Dict[str, Any], email: str) -> str: 10 | users = data["users"] 11 | for user_id, profile in users.items(): 12 | if profile["email"].lower() == email.lower(): 13 | return user_id 14 | return "Error: user not found" 15 | 16 | @staticmethod 17 | def get_info() -> Dict[str, Any]: 18 | return { 19 | "type": "function", 20 | "function": { 21 | "name": "find_user_id_by_email", 22 | "description": "Find user id by email. 
If the user is not found, the function will return an error message.", 23 | "parameters": { 24 | "type": "object", 25 | "properties": { 26 | "email": { 27 | "type": "string", 28 | "description": "The email of the user, such as 'something@example.com'.", 29 | }, 30 | }, 31 | "required": ["email"], 32 | }, 33 | }, 34 | } 35 | -------------------------------------------------------------------------------- /tau_bench/envs/retail/tools/find_user_id_by_name_zip.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from typing import Any, Dict 4 | from tau_bench.envs.tool import Tool 5 | 6 | 7 | class FindUserIdByNameZip(Tool): 8 | @staticmethod 9 | def invoke(data: Dict[str, Any], first_name: str, last_name: str, zip: str) -> str: 10 | users = data["users"] 11 | for user_id, profile in users.items(): 12 | if ( 13 | profile["name"]["first_name"].lower() == first_name.lower() 14 | and profile["name"]["last_name"].lower() == last_name.lower() 15 | and profile["address"]["zip"] == zip 16 | ): 17 | return user_id 18 | return "Error: user not found" 19 | 20 | @staticmethod 21 | def get_info() -> Dict[str, Any]: 22 | return { 23 | "type": "function", 24 | "function": { 25 | "name": "find_user_id_by_name_zip", 26 | "description": ( 27 | "Find user id by first name, last name, and zip code. If the user is not found, the function " 28 | "will return an error message. By default, find user id by email, and only call this function " 29 | "if the user is not found by email or cannot remember email." 
30 | ), 31 | "parameters": { 32 | "type": "object", 33 | "properties": { 34 | "first_name": { 35 | "type": "string", 36 | "description": "The first name of the customer, such as 'John'.", 37 | }, 38 | "last_name": { 39 | "type": "string", 40 | "description": "The last name of the customer, such as 'Doe'.", 41 | }, 42 | "zip": { 43 | "type": "string", 44 | "description": "The zip code of the customer, such as '12345'.", 45 | }, 46 | }, 47 | "required": ["first_name", "last_name", "zip"], 48 | }, 49 | }, 50 | } 51 | -------------------------------------------------------------------------------- /tau_bench/envs/retail/tools/get_order_details.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class GetOrderDetails(Tool): 9 | @staticmethod 10 | def invoke(data: Dict[str, Any], order_id: str) -> str: 11 | orders = data["orders"] 12 | if order_id in orders: 13 | return json.dumps(orders[order_id]) 14 | return "Error: order not found" 15 | 16 | @staticmethod 17 | def get_info() -> Dict[str, Any]: 18 | return { 19 | "type": "function", 20 | "function": { 21 | "name": "get_order_details", 22 | "description": "Get the status and details of an order.", 23 | "parameters": { 24 | "type": "object", 25 | "properties": { 26 | "order_id": { 27 | "type": "string", 28 | "description": "The order id, such as '#W0000000'. 
Be careful there is a '#' symbol at the beginning of the order id.", 29 | }, 30 | }, 31 | "required": ["order_id"], 32 | }, 33 | }, 34 | } 35 | -------------------------------------------------------------------------------- /tau_bench/envs/retail/tools/get_product_details.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class GetProductDetails(Tool): 9 | @staticmethod 10 | def invoke(data: Dict[str, Any], product_id: str) -> str: 11 | products = data["products"] 12 | if product_id in products: 13 | return json.dumps(products[product_id]) 14 | return "Error: product not found" 15 | 16 | @staticmethod 17 | def get_info() -> Dict[str, Any]: 18 | return { 19 | "type": "function", 20 | "function": { 21 | "name": "get_product_details", 22 | "description": "Get the inventory details of a product.", 23 | "parameters": { 24 | "type": "object", 25 | "properties": { 26 | "product_id": { 27 | "type": "string", 28 | "description": "The product id, such as '6086499569'. 
Be careful the product id is different from the item id.", 29 | }, 30 | }, 31 | "required": ["product_id"], 32 | }, 33 | }, 34 | } 35 | -------------------------------------------------------------------------------- /tau_bench/envs/retail/tools/get_user_details.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class GetUserDetails(Tool): 9 | @staticmethod 10 | def invoke(data: Dict[str, Any], user_id: str) -> str: 11 | users = data["users"] 12 | if user_id in users: 13 | return json.dumps(users[user_id]) 14 | return "Error: user not found" 15 | 16 | @staticmethod 17 | def get_info() -> Dict[str, Any]: 18 | return { 19 | "type": "function", 20 | "function": { 21 | "name": "get_user_details", 22 | "description": "Get the details of a user, including their orders.", 23 | "parameters": { 24 | "type": "object", 25 | "properties": { 26 | "user_id": { 27 | "type": "string", 28 | "description": "The user id, such as 'sara_doe_496'.", 29 | }, 30 | }, 31 | "required": ["user_id"], 32 | }, 33 | }, 34 | } 35 | -------------------------------------------------------------------------------- /tau_bench/envs/retail/tools/list_all_product_types.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class ListAllProductTypes(Tool): 9 | @staticmethod 10 | def invoke(data: Dict[str, Any]) -> str: 11 | products = data["products"] 12 | product_dict = { 13 | product["name"]: product["product_id"] for product in products.values() 14 | } 15 | product_dict = dict(sorted(product_dict.items())) 16 | return json.dumps(product_dict) 17 | 18 | @staticmethod 19 | def get_info() -> Dict[str, Any]: 20 | return { 21 | "type": "function", 22 | "function": { 23 | "name": 
"list_all_product_types", 24 | "description": "List the name and product id of all product types. Each product type has a variety of different items with unique item ids and options. There are only 50 product types in the store.", 25 | "parameters": { 26 | "type": "object", 27 | "properties": {}, 28 | "required": [], 29 | }, 30 | }, 31 | } 32 | -------------------------------------------------------------------------------- /tau_bench/envs/retail/tools/modify_pending_order_address.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class ModifyPendingOrderAddress(Tool): 9 | @staticmethod 10 | def invoke( 11 | data: Dict[str, Any], 12 | order_id: str, 13 | address1: str, 14 | address2: str, 15 | city: str, 16 | state: str, 17 | country: str, 18 | zip: str, 19 | ) -> str: 20 | # Check if the order exists and is pending 21 | orders = data["orders"] 22 | if order_id not in orders: 23 | return "Error: order not found" 24 | order = orders[order_id] 25 | if order["status"] != "pending": 26 | return "Error: non-pending order cannot be modified" 27 | 28 | # Modify the address 29 | order["address"] = { 30 | "address1": address1, 31 | "address2": address2, 32 | "city": city, 33 | "state": state, 34 | "country": country, 35 | "zip": zip, 36 | } 37 | return json.dumps(order) 38 | 39 | @staticmethod 40 | def get_info() -> Dict[str, Any]: 41 | return { 42 | "type": "function", 43 | "function": { 44 | "name": "modify_pending_order_address", 45 | "description": "Modify the shipping address of a pending order. The agent needs to explain the modification detail and ask for explicit user confirmation (yes/no) to proceed.", 46 | "parameters": { 47 | "type": "object", 48 | "properties": { 49 | "order_id": { 50 | "type": "string", 51 | "description": "The order id, such as '#W0000000'. 
Be careful there is a '#' symbol at the beginning of the order id.", 52 | }, 53 | "address1": { 54 | "type": "string", 55 | "description": "The first line of the address, such as '123 Main St'.", 56 | }, 57 | "address2": { 58 | "type": "string", 59 | "description": "The second line of the address, such as 'Apt 1' or ''.", 60 | }, 61 | "city": { 62 | "type": "string", 63 | "description": "The city, such as 'San Francisco'.", 64 | }, 65 | "state": { 66 | "type": "string", 67 | "description": "The state, such as 'CA'.", 68 | }, 69 | "country": { 70 | "type": "string", 71 | "description": "The country, such as 'USA'.", 72 | }, 73 | "zip": { 74 | "type": "string", 75 | "description": "The zip code, such as '12345'.", 76 | }, 77 | }, 78 | "required": [ 79 | "order_id", 80 | "address1", 81 | "address2", 82 | "city", 83 | "state", 84 | "country", 85 | "zip", 86 | ], 87 | }, 88 | }, 89 | } 90 | -------------------------------------------------------------------------------- /tau_bench/envs/retail/tools/modify_pending_order_items.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict, List 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class ModifyPendingOrderItems(Tool): 9 | @staticmethod 10 | def invoke( 11 | data: Dict[str, Any], 12 | order_id: str, 13 | item_ids: List[str], 14 | new_item_ids: List[str], 15 | payment_method_id: str, 16 | ) -> str: 17 | products, orders, users = data["products"], data["orders"], data["users"] 18 | 19 | # Check if the order exists and is pending 20 | if order_id not in orders: 21 | return "Error: order not found" 22 | order = orders[order_id] 23 | if order["status"] != "pending": 24 | return "Error: non-pending order cannot be modified" 25 | 26 | # Check if the items to be modified exist 27 | all_item_ids = [item["item_id"] for item in order["items"]] 28 | for item_id in item_ids: 29 | if item_ids.count(item_id) > 
all_item_ids.count(item_id): 30 | return f"Error: {item_id} not found" 31 | 32 | # Check new items exist, match old items, and are available 33 | if len(item_ids) != len(new_item_ids): 34 | return "Error: the number of items to be exchanged should match" 35 | 36 | diff_price = 0 37 | for item_id, new_item_id in zip(item_ids, new_item_ids): 38 | item = [item for item in order["items"] if item["item_id"] == item_id][0] 39 | product_id = item["product_id"] 40 | if not ( 41 | new_item_id in products[product_id]["variants"] 42 | and products[product_id]["variants"][new_item_id]["available"] 43 | ): 44 | return f"Error: new item {new_item_id} not found or available" 45 | 46 | old_price = item["price"] 47 | new_price = products[product_id]["variants"][new_item_id]["price"] 48 | diff_price += new_price - old_price 49 | 50 | # Check if the payment method exists 51 | if payment_method_id not in users[order["user_id"]]["payment_methods"]: 52 | return "Error: payment method not found" 53 | 54 | # If the new item is more expensive, check if the gift card has enough balance 55 | payment_method = users[order["user_id"]]["payment_methods"][payment_method_id] 56 | if ( 57 | payment_method["source"] == "gift_card" 58 | and payment_method["balance"] < diff_price 59 | ): 60 | return "Error: insufficient gift card balance to pay for the new item" 61 | 62 | # Handle the payment or refund 63 | order["payment_history"].append( 64 | { 65 | "transaction_type": "payment" if diff_price > 0 else "refund", 66 | "amount": abs(diff_price), 67 | "payment_method_id": payment_method_id, 68 | } 69 | ) 70 | if payment_method["source"] == "gift_card": 71 | payment_method["balance"] -= diff_price 72 | payment_method["balance"] = round(payment_method["balance"], 2) 73 | 74 | # Modify the order 75 | for item_id, new_item_id in zip(item_ids, new_item_ids): 76 | item = [item for item in order["items"] if item["item_id"] == item_id][0] 77 | item["item_id"] = new_item_id 78 | item["price"] = 
products[item["product_id"]]["variants"][new_item_id][ 79 | "price" 80 | ] 81 | item["options"] = products[item["product_id"]]["variants"][new_item_id][ 82 | "options" 83 | ] 84 | order["status"] = "pending (item modified)" 85 | 86 | return json.dumps(order) 87 | 88 | @staticmethod 89 | def get_info() -> Dict[str, Any]: 90 | return { 91 | "type": "function", 92 | "function": { 93 | "name": "modify_pending_order_items", 94 | "description": "Modify items in a pending order to new items of the same product type. For a pending order, this function can only be called once. The agent needs to explain the exchange detail and ask for explicit user confirmation (yes/no) to proceed.", 95 | "parameters": { 96 | "type": "object", 97 | "properties": { 98 | "order_id": { 99 | "type": "string", 100 | "description": "The order id, such as '#W0000000'. Be careful there is a '#' symbol at the beginning of the order id.", 101 | }, 102 | "item_ids": { 103 | "type": "array", 104 | "items": { 105 | "type": "string", 106 | }, 107 | "description": "The item ids to be modified, each such as '1008292230'. There could be duplicate items in the list.", 108 | }, 109 | "new_item_ids": { 110 | "type": "array", 111 | "items": { 112 | "type": "string", 113 | }, 114 | "description": "The item ids to be modified for, each such as '1008292230'. There could be duplicate items in the list. Each new item id should match the item id in the same position and be of the same product.", 115 | }, 116 | "payment_method_id": { 117 | "type": "string", 118 | "description": "The payment method id to pay or receive refund for the item price difference, such as 'gift_card_0000000' or 'credit_card_0000000'. 
These can be looked up from the user or order details.", 119 | }, 120 | }, 121 | "required": [ 122 | "order_id", 123 | "item_ids", 124 | "new_item_ids", 125 | "payment_method_id", 126 | ], 127 | }, 128 | }, 129 | } 130 | -------------------------------------------------------------------------------- /tau_bench/envs/retail/tools/modify_pending_order_payment.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class ModifyPendingOrderPayment(Tool): 9 | @staticmethod 10 | def invoke( 11 | data: Dict[str, Any], 12 | order_id: str, 13 | payment_method_id: str, 14 | ) -> str: 15 | orders = data["orders"] 16 | 17 | # Check if the order exists and is pending 18 | if order_id not in orders: 19 | return "Error: order not found" 20 | order = orders[order_id] 21 | if order["status"] != "pending": 22 | return "Error: non-pending order cannot be modified" 23 | 24 | # Check if the payment method exists 25 | if payment_method_id not in data["users"][order["user_id"]]["payment_methods"]: 26 | return "Error: payment method not found" 27 | 28 | # Check that the payment history should only have one payment 29 | if ( 30 | len(order["payment_history"]) > 1 31 | or order["payment_history"][0]["transaction_type"] != "payment" 32 | ): 33 | return "Error: there should be exactly one payment for a pending order" 34 | 35 | # Check that the payment method is different 36 | if order["payment_history"][0]["payment_method_id"] == payment_method_id: 37 | return ( 38 | "Error: the new payment method should be different from the current one" 39 | ) 40 | 41 | amount = order["payment_history"][0]["amount"] 42 | payment_method = data["users"][order["user_id"]]["payment_methods"][ 43 | payment_method_id 44 | ] 45 | 46 | # Check if the new payment method has enough balance if it is a gift card 47 | if ( 48 | payment_method["source"] == 
"gift_card" 49 | and payment_method["balance"] < amount 50 | ): 51 | return "Error: insufficient gift card balance to pay for the order" 52 | 53 | # Modify the payment method 54 | order["payment_history"].extend( 55 | [ 56 | { 57 | "transaction_type": "payment", 58 | "amount": amount, 59 | "payment_method_id": payment_method_id, 60 | }, 61 | { 62 | "transaction_type": "refund", 63 | "amount": amount, 64 | "payment_method_id": order["payment_history"][0][ 65 | "payment_method_id" 66 | ], 67 | }, 68 | ] 69 | ) 70 | 71 | # If payment is made by gift card, update the balance 72 | if payment_method["source"] == "gift_card": 73 | payment_method["balance"] -= amount 74 | payment_method["balance"] = round(payment_method["balance"], 2) 75 | 76 | # If refund is made to a gift card, update the balance 77 | if "gift_card" in order["payment_history"][0]["payment_method_id"]: 78 | old_payment_method = data["users"][order["user_id"]]["payment_methods"][ 79 | order["payment_history"][0]["payment_method_id"] 80 | ] 81 | old_payment_method["balance"] += amount 82 | old_payment_method["balance"] = round(old_payment_method["balance"], 2) 83 | 84 | return json.dumps(order) 85 | 86 | @staticmethod 87 | def get_info() -> Dict[str, Any]: 88 | return { 89 | "type": "function", 90 | "function": { 91 | "name": "modify_pending_order_payment", 92 | "description": "Modify the payment method of a pending order. The agent needs to explain the modification detail and ask for explicit user confirmation (yes/no) to proceed.", 93 | "parameters": { 94 | "type": "object", 95 | "properties": { 96 | "order_id": { 97 | "type": "string", 98 | "description": "The order id, such as '#W0000000'. Be careful there is a '#' symbol at the beginning of the order id.", 99 | }, 100 | "payment_method_id": { 101 | "type": "string", 102 | "description": "The payment method id to pay or receive refund for the item price difference, such as 'gift_card_0000000' or 'credit_card_0000000'. 
These can be looked up from the user or order details.", 103 | }, 104 | }, 105 | "required": [ 106 | "order_id", 107 | "payment_method_id", 108 | ], 109 | }, 110 | }, 111 | } 112 | -------------------------------------------------------------------------------- /tau_bench/envs/retail/tools/modify_user_address.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class ModifyUserAddress(Tool): 9 | @staticmethod 10 | def invoke( 11 | data: Dict[str, Any], 12 | user_id: str, 13 | address1: str, 14 | address2: str, 15 | city: str, 16 | state: str, 17 | country: str, 18 | zip: str, 19 | ) -> str: 20 | users = data["users"] 21 | if user_id not in users: 22 | return "Error: user not found" 23 | user = users[user_id] 24 | user["address"] = { 25 | "address1": address1, 26 | "address2": address2, 27 | "city": city, 28 | "state": state, 29 | "country": country, 30 | "zip": zip, 31 | } 32 | return json.dumps(user) 33 | 34 | @staticmethod 35 | def get_info() -> Dict[str, Any]: 36 | return { 37 | "type": "function", 38 | "function": { 39 | "name": "modify_user_address", 40 | "description": "Modify the default address of a user. 
The agent needs to explain the modification detail and ask for explicit user confirmation (yes/no) to proceed.", 41 | "parameters": { 42 | "type": "object", 43 | "properties": { 44 | "user_id": { 45 | "type": "string", 46 | "description": "The user id, such as 'sara_doe_496'.", 47 | }, 48 | "address1": { 49 | "type": "string", 50 | "description": "The first line of the address, such as '123 Main St'.", 51 | }, 52 | "address2": { 53 | "type": "string", 54 | "description": "The second line of the address, such as 'Apt 1' or ''.", 55 | }, 56 | "city": { 57 | "type": "string", 58 | "description": "The city, such as 'San Francisco'.", 59 | }, 60 | "state": { 61 | "type": "string", 62 | "description": "The state, such as 'CA'.", 63 | }, 64 | "country": { 65 | "type": "string", 66 | "description": "The country, such as 'USA'.", 67 | }, 68 | "zip": { 69 | "type": "string", 70 | "description": "The zip code, such as '12345'.", 71 | }, 72 | }, 73 | "required": [ 74 | "user_id", 75 | "address1", 76 | "address2", 77 | "city", 78 | "state", 79 | "country", 80 | "zip", 81 | ], 82 | }, 83 | }, 84 | } 85 | -------------------------------------------------------------------------------- /tau_bench/envs/retail/tools/return_delivered_order_items.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict, List 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class ReturnDeliveredOrderItems(Tool): 9 | @staticmethod 10 | def invoke( 11 | data: Dict[str, Any], order_id: str, item_ids: List[str], payment_method_id: str 12 | ) -> str: 13 | orders = data["orders"] 14 | 15 | # Check if the order exists and is delivered 16 | if order_id not in orders: 17 | return "Error: order not found" 18 | order = orders[order_id] 19 | if order["status"] != "delivered": 20 | return "Error: non-delivered order cannot be returned" 21 | 22 | # Check if the payment method exists and is either the original 
payment method or a gift card 23 | if payment_method_id not in data["users"][order["user_id"]]["payment_methods"]: 24 | return "Error: payment method not found" 25 | if ( 26 | "gift_card" not in payment_method_id 27 | and payment_method_id != order["payment_history"][0]["payment_method_id"] 28 | ): 29 | return "Error: payment method should be either the original payment method or a gift card" 30 | 31 | # Check if the items to be returned exist (there could be duplicate items in either list) 32 | all_item_ids = [item["item_id"] for item in order["items"]] 33 | for item_id in item_ids: 34 | if item_ids.count(item_id) > all_item_ids.count(item_id): 35 | return "Error: some item not found" 36 | 37 | # Update the order status 38 | order["status"] = "return requested" 39 | order["return_items"] = sorted(item_ids) 40 | order["return_payment_method_id"] = payment_method_id 41 | 42 | return json.dumps(order) 43 | 44 | @staticmethod 45 | def get_info() -> Dict[str, Any]: 46 | return { 47 | "type": "function", 48 | "function": { 49 | "name": "return_delivered_order_items", 50 | "description": ( 51 | "Return some items of a delivered order. The order status will be changed to 'return requested'. " 52 | "The agent needs to explain the return detail and ask for explicit user confirmation (yes/no) to proceed. " 53 | "The user will receive follow-up email for how and where to return the item." 54 | ), 55 | "parameters": { 56 | "type": "object", 57 | "properties": { 58 | "order_id": { 59 | "type": "string", 60 | "description": ( 61 | "The order id, such as '#W0000000'. Be careful there is a '#' symbol at the beginning of the order id." 62 | ), 63 | }, 64 | "item_ids": { 65 | "type": "array", 66 | "items": {"type": "string"}, 67 | "description": ( 68 | "The item ids to be returned, each such as '1008292230'. There could be duplicate items in the list." 
69 | ), 70 | }, 71 | "payment_method_id": { 72 | "type": "string", 73 | "description": ( 74 | "The payment method id to pay or receive refund for the item price difference, such as 'gift_card_0000000' or 'credit_card_0000000'. " 75 | "These can be looked up from the user or order details." 76 | ), 77 | }, 78 | }, 79 | "required": ["order_id", "item_ids", "payment_method_id"], 80 | }, 81 | }, 82 | } 83 | -------------------------------------------------------------------------------- /tau_bench/envs/retail/tools/think.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from typing import Any, Dict 4 | from tau_bench.envs.tool import Tool 5 | 6 | 7 | class Think(Tool): 8 | @staticmethod 9 | def invoke(data: Dict[str, Any], thought: str) -> str: 10 | # This method does not change the state of the data; it simply returns an empty string. 11 | return "" 12 | 13 | @staticmethod 14 | def get_info() -> Dict[str, Any]: 15 | return { 16 | "type": "function", 17 | "function": { 18 | "name": "think", 19 | "description": ( 20 | "Use the tool to think about something. It will not obtain new information or change the database, " 21 | "but just append the thought to the log. Use it when complex reasoning or some cache memory is needed." 
22 | ), 23 | "parameters": { 24 | "type": "object", 25 | "properties": { 26 | "thought": { 27 | "type": "string", 28 | "description": "A thought to think about.", 29 | }, 30 | }, 31 | "required": ["thought"], 32 | }, 33 | }, 34 | } 35 | -------------------------------------------------------------------------------- /tau_bench/envs/retail/tools/transfer_to_human_agents.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from typing import Any, Dict 4 | from tau_bench.envs.tool import Tool 5 | 6 | 7 | class TransferToHumanAgents(Tool): 8 | @staticmethod 9 | def invoke(data: Dict[str, Any], summary: str) -> str: 10 | # This method simulates the transfer to a human agent. 11 | return "Transfer successful" 12 | 13 | @staticmethod 14 | def get_info() -> Dict[str, Any]: 15 | return { 16 | "type": "function", 17 | "function": { 18 | "name": "transfer_to_human_agents", 19 | "description": ( 20 | "Transfer the user to a human agent, with a summary of the user's issue. " 21 | "Only transfer if the user explicitly asks for a human agent, or if the user's issue cannot be resolved by the agent with the available tools." 22 | ), 23 | "parameters": { 24 | "type": "object", 25 | "properties": { 26 | "summary": { 27 | "type": "string", 28 | "description": "A summary of the user's issue.", 29 | }, 30 | }, 31 | "required": ["summary"], 32 | }, 33 | }, 34 | } 35 | -------------------------------------------------------------------------------- /tau_bench/envs/retail/wiki.md: -------------------------------------------------------------------------------- 1 | # Retail agent policy 2 | 3 | As a retail agent, you can help users cancel or modify pending orders, return or exchange delivered orders, modify their default user address, or provide information about their own profile, orders, and related products. 
4 | 5 | - At the beginning of the conversation, you have to authenticate the user identity by locating their user id via email, or via name + zip code. This has to be done even when the user already provides the user id. 6 | 7 | - Once the user has been authenticated, you can provide the user with information about their orders, products, and profile, e.g. help the user look up an order id. 8 | 9 | - You can only help one user per conversation (but you can handle multiple requests from the same user), and must deny any requests for tasks related to any other user. 10 | 11 | - Before taking consequential actions that update the database (cancel, modify, return, exchange), you have to list the action detail and obtain explicit user confirmation (yes) to proceed. 12 | 13 | - You should not make up any information or knowledge or procedures not provided by the user or the tools, or give subjective recommendations or comments. 14 | 15 | - You should make at most one tool call at a time, and if you make a tool call, you should not respond to the user at the same time. If you respond to the user, you should not make a tool call. 16 | 17 | - You should transfer the user to a human agent if and only if the request cannot be handled within the scope of your actions. 18 | 19 | ## Domain basics 20 | 21 | - All times in the database are EST and 24 hour based. For example "02:30:00" means 2:30 AM EST. 22 | 23 | - Each user has a profile containing their email, default address, user id, and payment methods. Each payment method is either a gift card, a paypal account, or a credit card. 24 | 25 | - Our retail store has 50 types of products. For each type of product, there are variant items of different options. For example, for a 't shirt' product, there could be an item with option 'color blue size M', and another item with option 'color red size L'. 26 | 27 | - Each product has a unique product id, and each item has a unique item id. They have no relations and should not be confused.
28 | 29 | - Each order can be in status 'pending', 'processed', 'delivered', or 'cancelled'. Generally, you can only take action on pending or delivered orders. 30 | 31 | - Exchange or modify order tools can only be called once. Be sure that all items to be changed are collected into a list before making the tool call!!! 32 | 33 | ## Cancel pending order 34 | 35 | - An order can only be cancelled if its status is 'pending', and you should check its status before taking the action. 36 | 37 | - The user needs to confirm the order id and the reason (either 'no longer needed' or 'ordered by mistake') for cancellation. 38 | 39 | - After user confirmation, the order status will be changed to 'cancelled', and the total will be refunded via the original payment method immediately if it is a gift card, otherwise in 5 to 7 business days. 40 | 41 | ## Modify pending order 42 | 43 | - An order can only be modified if its status is 'pending', and you should check its status before taking the action. 44 | 45 | - For a pending order, you can take actions to modify its shipping address, payment method, or product item options, but nothing else. 46 | 47 | ### Modify payment 48 | 49 | - The user can only choose a single payment method different from the original payment method. 50 | 51 | - If the user wants to modify the payment method to a gift card, it must have enough balance to cover the total amount. 52 | 53 | - After user confirmation, the order status will be kept 'pending'. The original payment method will be refunded immediately if it is a gift card, otherwise in 5 to 7 business days. 54 | 55 | ### Modify items 56 | 57 | - This action can only be called once, and will change the order status to 'pending (item modified)', and the agent will not be able to modify or cancel the order anymore. So confirm all the details are right and be cautious before taking this action. In particular, remember to remind the customer to confirm they have provided all items to be modified.
58 | 59 | - For a pending order, each item can be modified to an available new item of the same product but with a different product option. There cannot be any change of product types, e.g. modify a shirt to a shoe. 60 | 61 | - The user must provide a payment method to pay or receive a refund of the price difference. If the user provides a gift card, it must have enough balance to cover the price difference. 62 | 63 | ## Return delivered order 64 | 65 | - An order can only be returned if its status is 'delivered', and you should check its status before taking the action. 66 | 67 | - The user needs to confirm the order id, the list of items to be returned, and a payment method to receive the refund. 68 | 69 | - The refund must go either to the original payment method or to an existing gift card. 70 | 71 | - After user confirmation, the order status will be changed to 'return requested', and the user will receive an email regarding how to return items. 72 | 73 | ## Exchange delivered order 74 | 75 | - An order can only be exchanged if its status is 'delivered', and you should check its status before taking the action. In particular, remember to remind the customer to confirm they have provided all items to be exchanged. 76 | 77 | - For a delivered order, each item can be exchanged to an available new item of the same product but with a different product option. There cannot be any change of product types, e.g. exchange a shirt for a shoe. 78 | 79 | - The user must provide a payment method to pay or receive a refund of the price difference. If the user provides a gift card, it must have enough balance to cover the price difference. 80 | 81 | - After user confirmation, the order status will be changed to 'exchange requested', and the user will receive an email regarding how to return items. There is no need to place a new order.
82 | -------------------------------------------------------------------------------- /tau_bench/envs/retail/wiki.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import os 4 | 5 | FOLDER_PATH = os.path.dirname(__file__) 6 | 7 | with open(os.path.join(FOLDER_PATH, "wiki.md"), "r") as f: 8 | WIKI = f.read() 9 | -------------------------------------------------------------------------------- /tau_bench/envs/tool.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Any 3 | 4 | 5 | class Tool(abc.ABC): 6 | @staticmethod 7 | def invoke(*args, **kwargs): 8 | raise NotImplementedError 9 | 10 | @staticmethod 11 | def get_info() -> dict[str, Any]: 12 | raise NotImplementedError 13 | -------------------------------------------------------------------------------- /tau_bench/model_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from tau_bench.model_utils.api.api import API as API 2 | from tau_bench.model_utils.api.api import default_api_from_args as default_api_from_args 3 | from tau_bench.model_utils.api.api import BinaryClassifyDatapoint as BinaryClassifyDatapoint 4 | from tau_bench.model_utils.api.api import ClassifyDatapoint as ClassifyDatapoint 5 | from tau_bench.model_utils.api.api import GenerateDatapoint as GenerateDatapoint 6 | from tau_bench.model_utils.api.api import ParseDatapoint as ParseDatapoint 7 | from tau_bench.model_utils.api.api import ParseForceDatapoint as ParseForceDatapoint 8 | from tau_bench.model_utils.api.api import ScoreDatapoint as ScoreDatapoint 9 | from tau_bench.model_utils.api.api import default_api as default_api 10 | from tau_bench.model_utils.api.api import default_quick_api as default_quick_api 11 | from tau_bench.model_utils.api.datapoint import Datapoint as Datapoint 12 | from tau_bench.model_utils.api.datapoint import EvaluationResult as 
EvaluationResult 13 | from tau_bench.model_utils.api.datapoint import datapoint_factory as datapoint_factory 14 | from tau_bench.model_utils.api.datapoint import load_from_disk as load_from_disk 15 | from tau_bench.model_utils.api.exception import APIError as APIError 16 | from tau_bench.model_utils.api.sample import ( 17 | EnsembleSamplingStrategy as EnsembleSamplingStrategy, 18 | ) 19 | from tau_bench.model_utils.api.sample import ( 20 | MajoritySamplingStrategy as MajoritySamplingStrategy, 21 | ) 22 | from tau_bench.model_utils.api.sample import ( 23 | RedundantSamplingStrategy as RedundantSamplingStrategy, 24 | ) 25 | from tau_bench.model_utils.api.sample import RetrySamplingStrategy as RetrySamplingStrategy 26 | from tau_bench.model_utils.api.sample import SamplingStrategy as SamplingStrategy 27 | from tau_bench.model_utils.api.sample import SingleSamplingStrategy as SingleSamplingStrategy 28 | from tau_bench.model_utils.api.sample import ( 29 | UnanimousSamplingStrategy as UnanimousSamplingStrategy, 30 | ) 31 | from tau_bench.model_utils.api.sample import ( 32 | get_default_sampling_strategy as get_default_sampling_strategy, 33 | ) 34 | from tau_bench.model_utils.api.sample import ( 35 | set_default_sampling_strategy as set_default_sampling_strategy, 36 | ) 37 | from tau_bench.model_utils.model.chat import PromptSuffixStrategy as PromptSuffixStrategy 38 | from tau_bench.model_utils.model.exception import ModelError as ModelError 39 | from tau_bench.model_utils.model.general_model import GeneralModel as GeneralModel 40 | from tau_bench.model_utils.model.general_model import default_model as default_model 41 | from tau_bench.model_utils.model.general_model import model_factory as model_factory 42 | from tau_bench.model_utils.model.model import BinaryClassifyModel as BinaryClassifyModel 43 | from tau_bench.model_utils.model.model import ClassifyModel as ClassifyModel 44 | from tau_bench.model_utils.model.model import GenerateModel as GenerateModel 45 | from 
tau_bench.model_utils.model.model import ParseForceModel as ParseForceModel 46 | from tau_bench.model_utils.model.model import ParseModel as ParseModel 47 | from tau_bench.model_utils.model.model import Platform as Platform 48 | from tau_bench.model_utils.model.model import ScoreModel as ScoreModel 49 | from tau_bench.model_utils.model.openai import OpenAIModel as OpenAIModel 50 | from tau_bench.model_utils.model.utils import InputType as InputType 51 | -------------------------------------------------------------------------------- /tau_bench/model_utils/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sierra-research/tau-bench/14bf0ef52e595922d597a38f32d3e8c0dce3a8f8/tau_bench/model_utils/api/__init__.py -------------------------------------------------------------------------------- /tau_bench/model_utils/api/_model_methods.py: -------------------------------------------------------------------------------- 1 | MODEL_METHODS = [ 2 | "classify", 3 | "binary_classify", 4 | "parse", 5 | "generate", 6 | "parse_force", 7 | "score", 8 | ] 9 | -------------------------------------------------------------------------------- /tau_bench/model_utils/api/cache.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import hashlib 3 | import inspect 4 | import threading 5 | from collections import defaultdict 6 | from multiprocessing import Lock 7 | from typing import Any, Callable, TypeVar 8 | 9 | from pydantic import BaseModel 10 | 11 | T = TypeVar("T") 12 | 13 | USE_CACHE = True 14 | _USE_CACHE_LOCK = Lock() 15 | cache: dict[str, tuple[T, threading.Event]] = {} 16 | lock = threading.Lock() 17 | conditions = defaultdict(threading.Condition) 18 | 19 | 20 | def disable_cache(): 21 | global USE_CACHE 22 | with _USE_CACHE_LOCK: 23 | USE_CACHE = False 24 | 25 | 26 | def enable_cache(): 27 | global USE_CACHE 28 | with _USE_CACHE_LOCK: 29 | 
USE_CACHE = True 30 | 31 | 32 | def hash_item(item: Any) -> int: 33 | if isinstance(item, dict): 34 | return hash(tuple({k: hash_item(v) for k, v in sorted(item.items())})) 35 | elif isinstance(item, list): 36 | return hash(tuple([hash_item(x) for x in item])) 37 | elif isinstance(item, set): 38 | return hash(frozenset([hash_item(x) for x in item])) 39 | elif isinstance(item, tuple): 40 | return hash(tuple([hash_item(x) for x in item])) 41 | elif isinstance(item, BaseModel): 42 | return hash_item(item.model_json_schema()) 43 | return hash(item) 44 | 45 | 46 | def hash_func_call(func: Callable[..., Any], args: tuple[Any], kwargs: dict[str, Any]) -> str: 47 | bound_args = inspect.signature(func).bind(*args, **kwargs) 48 | bound_args.apply_defaults() 49 | standardized_args = sorted(bound_args.arguments.items()) 50 | arg_hash = hash_item(standardized_args) 51 | hashed_func = id(func) 52 | call = (hashed_func, arg_hash) 53 | return hashlib.md5(str(call).encode()).hexdigest() 54 | 55 | 56 | def cache_call_w_dedup(func: Callable[..., T]) -> Callable[..., T]: 57 | @functools.wraps(func) 58 | def wrapper(*args: Any, **kwargs: Any) -> T: 59 | if not USE_CACHE: 60 | return func(*args, **kwargs) 61 | key = hash_func_call(func=func, args=args, kwargs=kwargs) 62 | if key in cache: 63 | result, event = cache[key] 64 | if event.is_set(): 65 | return result 66 | else: 67 | with lock: 68 | cache[key] = (None, threading.Event()) 69 | 70 | condition = conditions[key] 71 | with condition: 72 | if cache[key][1].is_set(): 73 | return cache[key][0] 74 | if not cache[key][0]: 75 | try: 76 | result = func(*args, **kwargs) 77 | with lock: 78 | cache[key] = (result, threading.Event()) 79 | cache[key][1].set() 80 | except Exception as e: 81 | with lock: 82 | cache[key] = (e, threading.Event()) 83 | cache[key][1].set() 84 | raise e 85 | return cache[key][0] 86 | 87 | return wrapper 88 | -------------------------------------------------------------------------------- 
# /tau_bench/model_utils/api/exception.py:
# --------------------------------------------------------------------------------
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, TypeVar

from tau_bench.model_utils.model.exception import ModelError, Result

T = TypeVar("T")

# Directory that APIError reports are written to; change it with set_report_dir().
_REPORT_DIR = os.path.expanduser("~/.llm-primitives/log")


def set_report_dir(path: str) -> None:
    """Use `path` as the directory for subsequent APIError reports."""
    global _REPORT_DIR
    _REPORT_DIR = path


def get_report_dir() -> str:
    """Return the directory currently used for APIError reports."""
    return _REPORT_DIR


def log_report_to_disk(report: dict[str, Any], path: str) -> None:
    """Write `report` to `path` as indented JSON, overwriting any existing file."""
    with open(path, "w") as f:
        json.dump(report, f, indent=4)


def generate_report_location() -> str:
    """Return a fresh timestamp-based report path, creating the report dir if needed."""
    # exist_ok avoids an exists()/makedirs() race when several threads raise at once.
    os.makedirs(_REPORT_DIR, exist_ok=True)
    return os.path.join(_REPORT_DIR, f"report-{time.time_ns()}.json")


class APIError(Exception):
    """Error surfaced to API callers, optionally backed by a JSON report on disk.

    Attributes:
        short_message: the human-readable summary.
        report: the diagnostic payload, or None.
        report_path: the path the report is written to (only written when
            `report` is not None).
    """

    def __init__(self, short_message: str, report: dict[str, Any] | None = None) -> None:
        self.report_path = generate_report_location()
        self.short_message = short_message
        self.report = report
        if self.report is not None:
            log_report_to_disk(
                report={"error_type": "APIError", "report": report}, path=self.report_path
            )
            message = f"{short_message}\n\nSee the full report at {self.report_path}"
        else:
            # No file was written, so don't point the reader at one.
            message = short_message
        super().__init__(message)


def execute_and_filter_model_errors(
    funcs: list[Callable[[], T]],
    max_concurrency: int | None = None,
) -> list[T]:
    """Run `funcs` concurrently and return the values of the calls that succeeded.

    A ModelError from an individual call is dropped; any other exception
    propagates. Values keep the order of `funcs`.

    Raises:
        ModelError: the first one (in `funcs` order) if every call failed.
        ValueError: if `funcs` is empty.
    """
    if len(funcs) == 0:
        raise ValueError("execute_and_filter_model_errors requires at least one function")

    def _invoke_w_o_llm_error(invocable: Callable[[], T]) -> Result:
        try:
            return Result(value=invocable(), error=None)
        except ModelError as e:
            return Result(value=None, error=e)

    with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
        results = list(executor.map(_invoke_w_o_llm_error, funcs))

    errors: list[ModelError] = [res.error for res in results if res.error is not None]
    values = [res.value for res in results if res.error is None]
    if len(values) == 0:
        raise errors[0]
    return values
# --------------------------------------------------------------------------------
# /tau_bench/model_utils/api/logging.py:
# --------------------------------------------------------------------------------
import functools
import inspect
import json
from multiprocessing import Lock
from typing import Any

from pydantic import BaseModel

from tau_bench.model_utils.api.sample import SamplingStrategy
from tau_bench.model_utils.model.utils import optionalize_type

# log file path -> lock that serializes appends to that file.
log_files = {}


def prep_for_json_serialization(obj: Any, from_parse_method: bool = False):
    """Convert `obj` into a value that `json.dumps` accepts.

    Containers are converted recursively. Sets and frozensets become lists,
    because JSON has no set type. Pydantic instances are dumped in JSON mode.
    Pydantic classes become their JSON schema; when `from_parse_method` is True,
    this is the schema of the optionalized class. Sampling strategies become
    their class name.

    Raises:
        TypeError: for any other type.
    """
    # TODO: refine type annotations
    if isinstance(obj, (str, int, float, bool, type(None))):
        return obj
    elif isinstance(obj, dict):
        return {k: prep_for_json_serialization(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [prep_for_json_serialization(v) for v in obj]
    elif isinstance(obj, tuple):
        return tuple(prep_for_json_serialization(v) for v in obj)
    elif isinstance(obj, (set, frozenset)):
        return [prep_for_json_serialization(v) for v in obj]
    elif isinstance(obj, BaseModel):
        return obj.model_dump(mode="json")
    elif isinstance(obj, type) and issubclass(obj, BaseModel):
        if from_parse_method:
            optionalized_type = optionalize_type(obj)
            return optionalized_type.model_json_schema()
        else:
            return obj.model_json_schema()
    elif isinstance(obj, SamplingStrategy):
        return obj.__class__.__name__
    else:
        raise TypeError(f"Object of type {type(obj)} is not JSON serializable")


def log_call(func):
    """Decorate a model method so each successful call is appended as one JSON line.

    Logging happens only when the instance has a non-None `_log_file` attribute.
    Each entry records the class name, method name, every bound argument
    (defaults applied, `self` removed) and the response.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        response = func(self, *args, **kwargs)
        log_file = getattr(self, "_log_file", None)
        if log_file is None:
            return response
        # get() first so a Lock is only allocated for a new file; setdefault()
        # ensures concurrent first writers still end up sharing a single lock.
        file_lock = log_files.get(log_file)
        if file_lock is None:
            file_lock = log_files.setdefault(log_file, Lock())
        sig = inspect.signature(func)
        bound_args = sig.bind(self, *args, **kwargs)
        bound_args.apply_defaults()
        all_args = bound_args.arguments
        all_args.pop("self", None)

        is_parse = func.__name__ in ["parse", "async_parse"]
        log_entry = {
            "cls_name": self.__class__.__name__,
            "method_name": func.__name__,
            "kwargs": {
                k: prep_for_json_serialization(v, from_parse_method=is_parse)
                for k, v in all_args.items()
            },
            "response": prep_for_json_serialization(response),
        }
        with file_lock:
            with open(log_file, "a") as f:
                f.write(f"{json.dumps(log_entry)}\n")
        return response

    return wrapper
# --------------------------------------------------------------------------------
# /tau_bench/model_utils/api/router.py:
# --------------------------------------------------------------------------------
import abc

from pydantic import BaseModel

from tau_bench.model_utils.api.datapoint import Datapoint, ScoreDatapoint
from tau_bench.model_utils.model.model import Model


class RequestRouter(abc.ABC):
    """Chooses which model should serve a datapoint."""

    @abc.abstractmethod
    def route(self, dp: Datapoint, available_models: list[Model]) -> Model:
        raise NotImplementedError


class FirstModelRequestRouter(RequestRouter):
    """Route to the first available model that supports the datapoint."""

    def route(self, dp: Datapoint, available_models: list[Model]) -> Model:
        supporting_models = [model for model in available_models if model.supports_dp(dp)]
        if len(supporting_models) == 0:
            raise ValueError(f"No supporting models found from {available_models}")
        return supporting_models[0]


class CapabilityScoreModel(abc.ABC):
    """Estimates how much model capability a datapoint requires."""

    @abc.abstractmethod
    def score_dp(self, dp: Datapoint) -> float:
        raise NotImplementedError


class PromptedLLMCapabilityScoreModel(CapabilityScoreModel):
    """Capability scorer that asks an LLM to rate task complexity from 1 to 10.

    The score is divided by 10. Defaults to ClaudeModel when no model is given.
    """

    def __init__(self, model: Model | None = None) -> None:
        if model is None:
            from tau_bench.model_utils.model.claude import ClaudeModel

            # claude is used as the default model as it is better at meta-level tasks
            model = ClaudeModel()
        self.model = model

    def score_dp(self, dp: Datapoint, examples: list[ScoreDatapoint] | None = None) -> float:
        raw_score = self.model.score(
            instruction="Score the task in the datapoint on a scale of 1 (least complex) to 10 (most complex).",
            text=f"----- start task -----\n{dp.model_dump_json()}\n----- end task -----",
            min=1,
            max=10,
            examples=examples,
        )
        return raw_score / 10.0


class MinimumCapabilityRequestRouter(RequestRouter):
    """Route to the least-capable supporting model that meets the datapoint's required score."""

    def __init__(self, capability_score_model: CapabilityScoreModel) -> None:
        self.capability_score_model = capability_score_model

    def route(self, dp: Datapoint, available_models: list[Model]) -> Model:
        supporting_models = [model for model in available_models if model.supports_dp(dp)]
        if len(supporting_models) == 0:
            raise ValueError(f"No supporting models found from {available_models}")
        required_capability = self.capability_score_model.score_dp(dp)
        minimum_model: Model | None = None
        minimum_model_capability: float | None = None
        for model in supporting_models:
            capability = model.get_capability()
            if capability >= required_capability and (
                minimum_model_capability is None or capability < minimum_model_capability
            ):
                minimum_model = model
                minimum_model_capability = capability
        if minimum_model is None:
            raise ValueError(f"No model found with capability >= {required_capability}")
        return minimum_model


def request_router_factory(
    router_id: str, capability_score_model: CapabilityScoreModel | None = None
) -> RequestRouter:
    """Build a router by id ("first-model" or "minimum-capability").

    Raises:
        ValueError: for an unknown id, or a "minimum-capability" router without
            a capability score model.
    """
    if router_id == "first-model":
        return FirstModelRequestRouter()
    elif router_id == "minimum-capability":
        if capability_score_model is None:
            raise ValueError("CapabilityScoreModel is required for minimum-capability router")
        return MinimumCapabilityRequestRouter(capability_score_model=capability_score_model)
    raise ValueError(f"Unknown router_id: {router_id}")


def default_request_router() -> RequestRouter:
    """Return the router used when none is configured."""
    return FirstModelRequestRouter()


class RequestRouteDatapoint(BaseModel):
    """A datapoint paired with its capability score."""

    dp: Datapoint
    capability_score: float
# --------------------------------------------------------------------------------
# /tau_bench/model_utils/api/sample.py:
# --------------------------------------------------------------------------------
import abc
import functools
from multiprocessing import Lock
from typing import Any, Callable, TypeVar

from pydantic import BaseModel

from tau_bench.model_utils.api.exception import APIError, execute_and_filter_model_errors
from tau_bench.model_utils.model.exception import ModelError
from tau_bench.model_utils import func_tools

T = TypeVar("T")


class SamplingStrategy(abc.ABC):
    """Policy that turns one or more model invocations into a single result."""

    @abc.abstractmethod
    def execute(self, invocable_or_invokables: Callable[..., T] | list[Callable[..., T]]) -> T:
        raise NotImplementedError


def catch_model_errors(func: Callable[..., T]) -> Callable[..., T]:
    """Re-raise a ModelError escaping `func` as an APIError with a diagnostic report."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs) -> T:
        try:
            return func(*args, **kwargs)
        except ModelError as e:
            raise APIError(
                short_message=str(e),
                report={
                    "prompt": e.prompt,
                    "response": e.response,
                    "error_message": str(e),
                },
            ) from e

    return wrapper


def _collect_results(
    invocable_or_invokables: Callable[..., T] | list[Callable[..., T]],
    n: int,
    max_concurrency: int | None,
    panic_on_first_model_error: bool,
) -> list[T]:
    """Run `n` copies of a single invocable, or each invocable of a list, concurrently.

    With `panic_on_first_model_error`, the first ModelError propagates.
    Otherwise failed calls are dropped, and a ModelError is raised only when
    every call failed.
    """
    if isinstance(invocable_or_invokables, Callable):
        invocables = [invocable_or_invokables] * n
    else:
        invocables = invocable_or_invokables
    if panic_on_first_model_error:
        return list(
            func_tools.map(
                lambda invocable: invocable(),
                invocables,
                max_concurrency=max_concurrency,
            )
        )
    return execute_and_filter_model_errors(invocables, max_concurrency=max_concurrency)


class SingleSamplingStrategy(SamplingStrategy):
    """Call the invocable exactly once."""

    @catch_model_errors
    def execute(self, invocable_or_invokables: Callable[..., T]) -> T:
        assert isinstance(invocable_or_invokables, Callable)
        return invocable_or_invokables()


class RedundantSamplingStrategy(SamplingStrategy):
    """Issue `n` calls (or every call in a list) concurrently and return the first success."""

    def __init__(self, n: int = 2) -> None:
        assert n > 0
        self.n = n

    @catch_model_errors
    def execute(self, invocable_or_invokables: Callable[..., T] | list[Callable[..., T]]) -> T:
        results = _collect_results(
            invocable_or_invokables,
            self.n,
            max_concurrency=None,
            panic_on_first_model_error=False,
        )
        assert len(results) > 0
        return results[0]


class RetrySamplingStrategy(SamplingStrategy):
    """Call the invocable up to `max_retries` times; if all fail, raise the first ModelError."""

    def __init__(self, max_retries: int = 5) -> None:
        assert max_retries > 0
        self.max_retries = max_retries

    @catch_model_errors
    def execute(self, invocable_or_invokables: Callable[..., T]) -> T:
        assert isinstance(invocable_or_invokables, Callable)
        first_error = None
        for _ in range(self.max_retries):
            try:
                return invocable_or_invokables()
            except ModelError as e:
                if first_error is None:
                    first_error = e
        assert first_error is not None
        raise first_error


class MajoritySamplingStrategy(SamplingStrategy):
    """Sample `n` times (or run every call in a list) and return the most common result."""

    def __init__(
        self,
        n: int = 5,
        max_concurrency: int | None = None,
        panic_on_first_model_error: bool = False,
    ) -> None:
        self.n = n
        self.max_concurrency = max_concurrency if max_concurrency is not None else n
        self.panic_on_first_model_error = panic_on_first_model_error

    @catch_model_errors
    def execute(self, invocable_or_invokables: Callable[..., T] | list[Callable[..., T]]) -> T:
        results = _collect_results(
            invocable_or_invokables,
            self.n,
            max_concurrency=self.max_concurrency,
            panic_on_first_model_error=self.panic_on_first_model_error,
        )
        if len(results) == 0:
            raise SamplingError(
                "No results from majority sampling (all calls resulted in LLM errors)"
            )
        return get_majority(results)


def get_majority(results: list[T]) -> T:
    """Return a representative of the most frequent result.

    Results are grouped by their JSON dump (pydantic models) or by `str()` (any
    other value). Ties go to the group seen first. Raises ValueError if
    `results` is empty.
    """
    grouped: dict[str, list[T]] = {}
    for result in results:
        key = result.model_dump_json() if isinstance(result, BaseModel) else str(result)
        grouped.setdefault(key, []).append(result)
    majority = max(grouped, key=lambda key: len(grouped[key]))
    return grouped[majority][0]


class EnsembleSamplingStrategy(SamplingStrategy):
    """Run a list of at least two different invocables and return the majority result."""

    def __init__(
        self, max_concurrency: int | None = None, panic_on_first_model_error: bool = False
    ) -> None:
        self.max_concurrency = max_concurrency
        self.panic_on_first_model_error = panic_on_first_model_error

    @catch_model_errors
    def execute(self, invocable_or_invokables: Callable[..., T] | list[Callable[..., T]]) -> T:
        if not isinstance(invocable_or_invokables, list) or len(invocable_or_invokables) < 2:
            raise ValueError("Ensemble sampling requires at least 2 invocables")
        results = _collect_results(
            invocable_or_invokables,
            len(invocable_or_invokables),
            max_concurrency=self.max_concurrency,
            panic_on_first_model_error=self.panic_on_first_model_error,
        )
        if len(results) == 0:
            raise SamplingError(
                "No results from ensemble sampling (all calls resulted in LLM errors)"
            )
        return get_majority(results)


class UnanimousSamplingStrategy(SamplingStrategy):
    """Sample `n` times (or run every call in a list) and require all results to be equal."""

    def __init__(
        self,
        n: int = 5,
        max_concurrency: int | None = None,
        panic_on_first_model_error: bool = False,
    ) -> None:
        self.n = n
        self.max_concurrency = max_concurrency if max_concurrency is not None else n
        self.panic_on_first_model_error = panic_on_first_model_error

    @catch_model_errors
    def execute(self, invocable_or_invokables: Callable[..., T] | list[Callable[..., T]]) -> T:
        results = _collect_results(
            invocable_or_invokables,
            self.n,
            max_concurrency=self.max_concurrency,
            panic_on_first_model_error=self.panic_on_first_model_error,
        )
        if len(results) == 0:
            raise SamplingError("No results from unanimous sampling")
        first = results[0]
        # Compare with == rather than building a set: results may be unhashable
        # (pydantic models, dicts), which made set(results) raise TypeError.
        if any(result != first for result in results[1:]):
            raise SamplingError("Results are not unanimous")
        return first


class SamplingError(Exception):
    pass


DEFAULT_SAMPLING_STRATEGY = SingleSamplingStrategy()
_DEFAULT_SAMPLING_STRATEGY_LOCK = Lock()


def set_default_sampling_strategy(strategy: SamplingStrategy) -> None:
    """Replace the process-wide default sampling strategy."""
    with _DEFAULT_SAMPLING_STRATEGY_LOCK:
        global DEFAULT_SAMPLING_STRATEGY
        DEFAULT_SAMPLING_STRATEGY = strategy


def get_default_sampling_strategy() -> SamplingStrategy:
    """Return the process-wide default sampling strategy."""
    with _DEFAULT_SAMPLING_STRATEGY_LOCK:
        return DEFAULT_SAMPLING_STRATEGY
# --------------------------------------------------------------------------------
# /tau_bench/model_utils/api/tokens.py:
# --------------------------------------------------------------------------------
import json

from pydantic import BaseModel

from tau_bench.model_utils.api.datapoint import (
    BinaryClassifyDatapoint,
    ClassifyDatapoint,
    Datapoint,
    GenerateDatapoint,
    ParseDatapoint,
    ParseForceDatapoint,
    ScoreDatapoint,
)


class TokenUsage(BaseModel):
    """Token counts, overall and broken down by datapoint type name."""

    input_tokens: int
    output_tokens: int
    by_primitive: dict[str, "TokenUsage"]


def batch_token_analysis(dps: list[Datapoint], encoding_for_model: str = "gpt-4o") -> TokenUsage:
    """Roughly estimate input/output token counts for `dps` using tiktoken.

    The input is the JSON of each datapoint without its `response` field. The
    output is an approximate JSON rendering of the response. Counts are grouped
    by datapoint class name.

    Raises:
        ValueError: for an unrecognized datapoint type.
    """
    import tiktoken

    enc = tiktoken.encoding_for_model(encoding_for_model)
    # very rough estimates
    inputs_by_primitive: dict[str, list[str]] = {}
    outputs_by_primitive: dict[str, list[str]] = {}
    for dp in dps:
        primitive = type(dp).__name__
        input_json = json.dumps({k: v for k, v in dp.model_dump().items() if k != "response"})
        inputs_by_primitive.setdefault(primitive, []).append(input_json)
        if isinstance(dp, ClassifyDatapoint):
            output = f'{{"classification": {dp.response}}}'
        elif isinstance(dp, BinaryClassifyDatapoint):
            output = f'{{"classification": {0 if dp.response else 1}}}'
        elif isinstance(dp, (ParseForceDatapoint, ParseDatapoint)):
            output = (
                json.dumps(dp.response)
                if isinstance(dp.response, dict)
                else dp.response.model_dump_json()
            )
        elif isinstance(dp, GenerateDatapoint):
            output = json.dumps(dp.response)
        elif isinstance(dp, ScoreDatapoint):
            output = f"{{'score': {dp.response}}}"
        else:
            raise ValueError(f"Unknown datapoint type: {type(dp)}")
        outputs_by_primitive.setdefault(primitive, []).append(output)
    input_tokens_by_primitive = {
        primitive: sum(len(tokens) for tokens in enc.encode_batch(inputs))
        for primitive, inputs in inputs_by_primitive.items()
    }
    output_tokens_by_primitive = {
        primitive: sum(len(tokens) for tokens in enc.encode_batch(outputs))
        for primitive, outputs in outputs_by_primitive.items()
    }
    return TokenUsage(
        input_tokens=sum(input_tokens_by_primitive.values()),
        output_tokens=sum(output_tokens_by_primitive.values()),
        by_primitive={
            primitive: TokenUsage(
                input_tokens=input_tokens_by_primitive.get(primitive, 0),
                output_tokens=output_tokens_by_primitive.get(primitive, 0),
                by_primitive={},
            )
            for primitive in set(input_tokens_by_primitive.keys())
            | set(output_tokens_by_primitive.keys())
        },
    )


def token_analysis(dp: Datapoint, encoding_for_model: str = "gpt-4o") -> TokenUsage:
    """Single-datapoint convenience wrapper around batch_token_analysis."""
    return batch_token_analysis([dp], encoding_for_model)
# --------------------------------------------------------------------------------
# /tau_bench/model_utils/api/types.py:
# --------------------------------------------------------------------------------
from typing import Any

# A possibly-incomplete object represented as a plain dict.
PartialObj = dict[str, Any]
# --------------------------------------------------------------------------------
# /tau_bench/model_utils/args.py:
# --------------------------------------------------------------------------------
import argparse

from tau_bench.model_utils.model.model import Platform


def api_parser() -> argparse.ArgumentParser:
    """Return a parser with --model, --base-url and a required --platform choice."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str)
    parser.add_argument("--base-url", type=str)
    parser.add_argument("--platform", type=str, required=True, choices=[e.value for e in Platform])
    return parser
# --------------------------------------------------------------------------------
# /tau_bench/model_utils/func_tools/__init__.py:
# --------------------------------------------------------------------------------
from tau_bench.model_utils.func_tools.filter import filter as filter
from tau_bench.model_utils.func_tools.map import map as map
# --------------------------------------------------------------------------------
# /tau_bench/model_utils/func_tools/filter.py:
# --------------------------------------------------------------------------------
from typing import Callable, Iterable, TypeVar

from tau_bench.model_utils.func_tools.map import map

T = TypeVar("T")

builtin_filter = filter


def filter(
    func: Callable[[T], bool],
    iterable: Iterable[T],
    max_concurrency: int | None = None,
) -> list[T]:
    """Evaluate `func` on every item concurrently; return the truthy items in order.

    The input is materialized first. Otherwise a one-shot iterable, such as a
    generator, would be consumed by the predicate pass and the result would
    always be empty.
    """
    assert max_concurrency is None or max_concurrency > 0
    items = list(iterable)
    bits = map(func, iterable=items, max_concurrency=max_concurrency)
    return [item for item, keep in zip(items, bits) if keep]
# --------------------------------------------------------------------------------
# /tau_bench/model_utils/func_tools/map.py:
# --------------------------------------------------------------------------------
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Iterable, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def map(
    func: Callable[[T], U],
    iterable: Iterable[T],
    max_concurrency: int | None = None,
    use_tqdm: bool = False,
) -> Iterable[U]:
    """Apply `func` to every item on a thread pool; results keep input order.

    All calls have finished when this returns, because leaving the executor
    waits for them. Without `use_tqdm`, the results are returned as the
    executor's iterator, and an exception from `func` is raised when that
    result is reached. With `use_tqdm`, a list is returned and a progress bar
    is shown. The input is materialized first, so unsized iterables work too.
    """
    assert max_concurrency is None or max_concurrency > 0
    with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
        if use_tqdm:
            from tqdm import tqdm

            items = list(iterable)
            return list(tqdm(executor.map(func, items), total=len(items)))
        return executor.map(func, iterable)
# --------------------------------------------------------------------------------
# /tau_bench/model_utils/model/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/sierra-research/tau-bench/14bf0ef52e595922d597a38f32d3e8c0dce3a8f8/tau_bench/model_utils/model/__init__.py
# --------------------------------------------------------------------------------
# /tau_bench/model_utils/model/anyscale.py:
# --------------------------------------------------------------------------------
import os

from tau_bench.model_utils.api.datapoint import Datapoint
from tau_bench.model_utils.model.chat import ChatModel, Message
from tau_bench.model_utils.model.completion import approx_cost_for_datapoint, approx_prompt_str
from tau_bench.model_utils.model.general_model import wrap_temperature
from tau_bench.model_utils.model.utils import approx_num_tokens

API_KEY_ENV_VAR = "ANYSCALE_API_KEY"
BASE_URL = "https://api.endpoints.anyscale.com/v1"

# Per-input-token prices by model id.
# TODO: fill in real prices. The former "meta-llama/Meta-Llama-3-8B-Instruct"
# entry was an unfilled `...` placeholder, so get_approx_cost() passed an
# Ellipsis into the cost computation. Until real prices are known, all models
# use the fallback.
PRICE_PER_INPUT_TOKEN_MAP: dict[str, float] = {}
INPUT_PRICE_PER_TOKEN_FALLBACK = 10 / 1000000

CAPABILITY_SCORE_MAP = {
    "meta-llama/Meta-Llama-3-8B-Instruct": 0.2,
    "meta-llama/Meta-Llama-3-70B-Instruct": 0.6,
}
CAPABILITY_SCORE_FALLBACK = 0.2

# TODO: implement
LATENCY_MS_PER_OUTPUT_TOKEN_MAP = {}
# TODO: implement
LATENCY_MS_PER_OUTPUT_TOKEN_FALLBACK = 0.0

MAX_CONTEXT_LENGTH_MAP = {
    "meta-llama/Meta-Llama-3-8B-Instruct": 8192,
    "meta-llama/Meta-Llama-3-70B-Instruct": 8192,
}
MAX_CONTEXT_LENGTH_FALLBACK = 8192


class AnyscaleModel(ChatModel):
    """Chat model served through Anyscale's OpenAI-compatible endpoint."""

    def __init__(
        self,
        model: str,
        api_key: str | None = None,
        temperature: float = 0.0,
    ) -> None:
        """Create sync and async OpenAI clients pointed at BASE_URL.

        Args:
            model: Anyscale model id.
            api_key: API key; if omitted, read from the ANYSCALE_API_KEY env var.
            temperature: default sampling temperature for generate_message().

        Raises:
            ValueError: if no key is passed and the environment variable is unset.
        """
        from openai import AsyncOpenAI, OpenAI

        self.model = model

        # Honor an explicitly passed key; previously it was reset to None here.
        if api_key is None:
            api_key = os.getenv(API_KEY_ENV_VAR)
        if api_key is None:
            raise ValueError(f"{API_KEY_ENV_VAR} environment variable is not set")
        self.client = OpenAI(api_key=api_key, base_url=BASE_URL)
        self.async_client = AsyncOpenAI(api_key=api_key, base_url=BASE_URL)
        self.temperature = temperature

    def generate_message(
        self,
        messages: list[Message],
        force_json: bool,
        temperature: float | None = None,
    ) -> Message:
        """Send `messages` to the chat completions API and wrap the reply.

        `force_json` requests a JSON-object response format. `temperature`
        defaults to the instance temperature.
        """
        if temperature is None:
            temperature = self.temperature
        msgs = self.build_generate_message_state(messages)
        res = self.client.chat.completions.create(
            model=self.model,
            messages=msgs,
            temperature=wrap_temperature(temperature),
            response_format={"type": "json_object" if force_json else "text"},
        )
        return self.handle_generate_message_response(
            prompt=msgs, content=res.choices[0].message.content, force_json=force_json
        )

    def get_approx_cost(self, dp: Datapoint) -> float:
        cost_per_token = PRICE_PER_INPUT_TOKEN_MAP.get(self.model, INPUT_PRICE_PER_TOKEN_FALLBACK)
        return approx_cost_for_datapoint(dp=dp, price_per_input_token=cost_per_token)

    def get_latency(self, dp: Datapoint) -> float:
        # Reuses the cost helper with a per-token latency in place of a price.
        latency_per_output_token = LATENCY_MS_PER_OUTPUT_TOKEN_MAP.get(
            self.model, LATENCY_MS_PER_OUTPUT_TOKEN_FALLBACK
        )
        return approx_cost_for_datapoint(dp=dp, price_per_input_token=latency_per_output_token)

    def get_capability(self) -> float:
        return CAPABILITY_SCORE_MAP.get(self.model, CAPABILITY_SCORE_FALLBACK)

    def supports_dp(self, dp: Datapoint) -> bool:
        """True when the approximate prompt token count fits the model's context length."""
        prompt = approx_prompt_str(dp)
        return approx_num_tokens(prompt) <= MAX_CONTEXT_LENGTH_MAP.get(
            self.model, MAX_CONTEXT_LENGTH_FALLBACK
        )
# --------------------------------------------------------------------------------
/tau_bench/model_utils/model/claude.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from tau_bench.model_utils.api.datapoint import Datapoint 5 | from tau_bench.model_utils.model.chat import ChatModel, Message 6 | from tau_bench.model_utils.model.completion import approx_cost_for_datapoint, approx_prompt_str 7 | from tau_bench.model_utils.model.general_model import wrap_temperature 8 | from tau_bench.model_utils.model.utils import approx_num_tokens 9 | 10 | DEFAULT_CLAUDE_MODEL = "claude-3-5-sonnet-20240620" 11 | DEFAULT_MAX_TOKENS = 8192 12 | ENV_VAR_API_KEY = "ANTHROPIC_API_KEY" 13 | 14 | PRICE_PER_INPUT_TOKEN_MAP = { 15 | "claude-3-5-sonnet-20240620": 3 / 1000000, 16 | } 17 | INPUT_PRICE_PER_TOKEN_FALLBACK = 15 / 1000000 18 | 19 | CAPABILITY_SCORE_MAP = { 20 | "claude-3-5-sonnet-20240620": 1.0, 21 | } 22 | CAPABILITY_SCORE_FALLBACK = 0.5 23 | 24 | # TODO: implement 25 | LATENCY_MS_PER_OUTPUT_TOKEN_MAP = {} 26 | # TODO: implement 27 | LATENCY_MS_PER_OUTPUT_TOKEN_FALLBACK = 0.0 28 | 29 | MAX_CONTEXT_LENGTH_MAP = { 30 | "claude-3-5-sonnet-20240620": 8192, 31 | } 32 | MAX_CONTEXT_LENGTH_FALLBACK = 8192 33 | 34 | 35 | class ClaudeModel(ChatModel): 36 | def __init__( 37 | self, 38 | model: str | None = None, 39 | api_key: str | None = None, 40 | temperature: float = 0.0, 41 | ) -> None: 42 | from anthropic import Anthropic, AsyncAnthropic 43 | 44 | if model is None: 45 | self.model = DEFAULT_CLAUDE_MODEL 46 | else: 47 | self.model = model 48 | 49 | api_key = None 50 | if api_key is None: 51 | api_key = os.getenv(ENV_VAR_API_KEY) 52 | if api_key is None: 53 | raise ValueError(f"{ENV_VAR_API_KEY} environment variable is not set") 54 | # `anthropic-beta` header is needed for the 8192 context length (https://docs.anthropic.com/en/docs/about-claude/models) 55 | self.client = Anthropic( 56 | api_key=api_key, default_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"} 57 | ) 58 | 
self.async_client = AsyncAnthropic(api_key=api_key) 59 | self.temperature = temperature 60 | 61 | def get_approx_cost(self, dp: Datapoint) -> float: 62 | cost_per_token = PRICE_PER_INPUT_TOKEN_MAP.get(self.model, INPUT_PRICE_PER_TOKEN_FALLBACK) 63 | return approx_cost_for_datapoint(dp=dp, price_per_input_token=cost_per_token) 64 | 65 | def get_latency(self, dp: Datapoint) -> float: 66 | latency_per_output_token = LATENCY_MS_PER_OUTPUT_TOKEN_MAP.get( 67 | self.model, LATENCY_MS_PER_OUTPUT_TOKEN_FALLBACK 68 | ) 69 | return approx_cost_for_datapoint(dp=dp, price_per_input_token=latency_per_output_token) 70 | 71 | def get_capability(self) -> float: 72 | return CAPABILITY_SCORE_MAP.get(self.model, CAPABILITY_SCORE_FALLBACK) 73 | 74 | def supports_dp(self, dp: Datapoint) -> bool: 75 | prompt = approx_prompt_str(dp) 76 | return approx_num_tokens(prompt) <= MAX_CONTEXT_LENGTH_MAP.get( 77 | self.model, MAX_CONTEXT_LENGTH_FALLBACK 78 | ) 79 | 80 | def _remap_messages(self, messages: list[dict[str, str]]) -> list[dict[str, str]]: 81 | remapped: list[dict[str, str]] = [] 82 | is_user = True 83 | for i, message in enumerate(messages): 84 | role = message["role"] 85 | if role == "assistant": 86 | if i == 0: 87 | raise ValueError( 88 | f"First message must be a system or user message, got {[m['role'] for m in messages]}" 89 | ) 90 | elif is_user: 91 | raise ValueError( 92 | f"Must alternate between user and assistant, got {[m['role'] for m in messages]}" 93 | ) 94 | remapped.append(message) 95 | is_user = True 96 | else: 97 | if is_user: 98 | remapped.append({"role": "user", "content": message["content"]}) 99 | is_user = False 100 | else: 101 | if remapped[-1]["role"] != "user": 102 | raise ValueError( 103 | f"Invalid sequence, expected user message but got {[m['role'] for m in messages]}" 104 | ) 105 | remapped[-1]["content"] += "\n\n" + message["content"] 106 | return remapped 107 | 108 | def build_generate_message_state( 109 | self, 110 | messages: list[Message], 111 | ) -> 
list[dict[str, str]]: 112 | msgs: list[dict[str, str]] = [] 113 | for msg in messages: 114 | if msg.obj is not None: 115 | content = json.dumps(msg.obj) 116 | else: 117 | content = msg.content 118 | msgs.append({"role": msg.role.value, "content": content}) 119 | return self._remap_messages(msgs) 120 | 121 | def generate_message(  # Sends the converted history through self.client.messages.create and passes res.content[0].text to handle_generate_message_response. 122 | self, 123 | messages: list[Message], 124 | force_json: bool, 125 | temperature: float | None = None, 126 | ) -> Message: 127 | if temperature is None: 128 | temperature = self.temperature 129 | msgs = self.build_generate_message_state(messages) 130 | res = self.client.messages.create( 131 | model=self.model, 132 | messages=msgs, 133 | temperature=wrap_temperature(temperature), 134 | max_tokens=DEFAULT_MAX_TOKENS, 135 | ) 136 | return self.handle_generate_message_response( 137 | prompt=msgs, content=res.content[0].text, force_json=force_json 138 | ) 139 | -------------------------------------------------------------------------------- /tau_bench/model_utils/model/exception.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Generic, TypeVar 3 | 4 | T = TypeVar("T") 5 | 6 | 7 | class ModelError(Exception):  # Exception that keeps the prompt and the raw response next to a short message. 8 | def __init__( 9 | self, 10 | short_message: str, 11 | prompt: str | list[dict[str, str]] | None = None, 12 | response: str | None = None, 13 | ) -> None: 14 | super().__init__(short_message) 15 | self.short_message = short_message 16 | self.prompt = prompt 17 | self.response = response 18 | 19 | 20 | @dataclass 21 | class Result(Generic[T]):  # Pairs an optional value with an optional ModelError. 22 | value: T | None 23 | error: ModelError | None 24 | -------------------------------------------------------------------------------- /tau_bench/model_utils/model/general_model.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Any, TypeVar 3 | 4 | from pydantic import BaseModel 5 | 6 | from tau_bench.model_utils.api.datapoint
import ( 7 | BinaryClassifyDatapoint, 8 | ClassifyDatapoint, 9 | GenerateDatapoint, 10 | ParseDatapoint, 11 | ParseForceDatapoint, 12 | ScoreDatapoint, 13 | ) 14 | from tau_bench.model_utils.api.types import PartialObj 15 | from tau_bench.model_utils.model.model import ( 16 | BinaryClassifyModel, 17 | ClassifyModel, 18 | GenerateModel, 19 | ParseForceModel, 20 | ParseModel, 21 | Platform, 22 | ScoreModel, 23 | ) 24 | 25 | T = TypeVar("T", bound=BaseModel) 26 | 27 | LLM_SAMPLING_TEMPERATURE_EPS = 1e-5 28 | 29 | 30 | def wrap_temperature(temperature: float) -> float:  # Clamps the temperature to at least LLM_SAMPLING_TEMPERATURE_EPS, so 0.0 becomes 1e-5. 31 | return max(temperature, LLM_SAMPLING_TEMPERATURE_EPS) 32 | 33 | 34 | class GeneralModel( 35 | ClassifyModel, 36 | BinaryClassifyModel, 37 | ParseModel, 38 | GenerateModel, 39 | ParseForceModel, 40 | ScoreModel, 41 | ): 42 | @abc.abstractmethod 43 | def classify( 44 | self, 45 | instruction: str, 46 | text: str, 47 | options: list[str], 48 | examples: list[ClassifyDatapoint] | None = None, 49 | temperature: float | None = None, 50 | ) -> int: 51 | raise NotImplementedError 52 | 53 | def binary_classify(  # Implemented on top of classify() with options ["true", "false"]; a result of index 0 means True. 54 | self, 55 | instruction: str, 56 | text: str, 57 | examples: list[BinaryClassifyDatapoint] | None = None, 58 | temperature: float | None = None, 59 | ) -> bool: 60 | return ( 61 | self.classify( 62 | instruction, 63 | text, 64 | ["true", "false"], 65 | examples=( 66 | None 67 | if examples is None 68 | else [ 69 | ClassifyDatapoint( 70 | instruction=example.instruction, 71 | text=example.text, 72 | options=["true", "false"], 73 | response=0 if example.response else 1,  # True examples map to option index 0 ("true"). 74 | ) 75 | for example in examples 76 | ] 77 | ), 78 | temperature=temperature, 79 | ) 80 | == 0 81 | ) 82 | 83 | @abc.abstractmethod 84 | def parse( 85 | self, 86 | text: str, 87 | typ: type[T] | dict[str, Any], 88 | examples: list[ParseDatapoint] | None = None, 89 | temperature: float | None = None, 90 | ) -> T | PartialObj | dict[str, Any]: 91 | raise NotImplementedError 92 | 93 | @abc.abstractmethod 94 | def generate( 95 | self,
96 | instruction: str, 97 | text: str, 98 | examples: list[GenerateDatapoint] | None = None, 99 | temperature: float | None = None, 100 | ) -> str: 101 | raise NotImplementedError 102 | 103 | @abc.abstractmethod 104 | def parse_force( 105 | self, 106 | instruction: str, 107 | typ: type[T] | dict[str, Any], 108 | text: str | None = None, 109 | examples: list[ParseForceDatapoint] | None = None, 110 | temperature: float | None = None, 111 | ) -> T | dict[str, Any]: 112 | raise NotImplementedError 113 | 114 | @abc.abstractmethod 115 | def score( 116 | self, 117 | instruction: str, 118 | text: str, 119 | min: int, 120 | max: int, 121 | examples: list[ScoreDatapoint] | None = None, 122 | temperature: float | None = None, 123 | ) -> int: 124 | raise NotImplementedError 125 | 126 | 127 | def default_model() -> GeneralModel:  # OpenAIModel with its default model id. 128 | from tau_bench.model_utils.model.openai import OpenAIModel 129 | 130 | return OpenAIModel() 131 | 132 | 133 | def default_quick_model() -> GeneralModel: 134 | from tau_bench.model_utils.model.openai import OpenAIModel 135 | 136 | return OpenAIModel(model="gpt-4o-mini") 137 | 138 | 139 | def model_factory(  # Maps a Platform (or its string value) to a model instance; provider modules are imported lazily inside each branch. 140 | model_id: str, 141 | platform: str | Platform, 142 | base_url: str | None = None, 143 | api_key: str | None = None, 144 | temperature: float = 0.0, 145 | ) -> GeneralModel: 146 | if isinstance(platform, str): 147 | platform = Platform(platform) 148 | if platform == Platform.OPENAI: 149 | from tau_bench.model_utils.model.openai import OpenAIModel 150 | 151 | return OpenAIModel(model=model_id, api_key=api_key, temperature=temperature) 152 | elif platform == Platform.MISTRAL: 153 | from tau_bench.model_utils.model.mistral import MistralModel 154 | 155 | return MistralModel(model=model_id, api_key=api_key, temperature=temperature) 156 | elif platform == Platform.ANTHROPIC: 157 | from tau_bench.model_utils.model.claude import ClaudeModel 158 | 159 | return ClaudeModel(model=model_id, api_key=api_key, temperature=temperature) 160 | 161 | elif
platform == Platform.ANYSCALE: 162 | from tau_bench.model_utils.model.anyscale import AnyscaleModel 163 | 164 | return AnyscaleModel(model=model_id, api_key=api_key, temperature=temperature) 165 | elif platform == Platform.OUTLINES: 166 | if base_url is None: 167 | raise ValueError("base_url must be provided for custom models") 168 | from tau_bench.model_utils.model.outlines_completion import OutlinesCompletionModel 169 | 170 | return OutlinesCompletionModel(model=model_id, base_url=base_url, temperature=temperature) 171 | elif platform == Platform.VLLM_CHAT: 172 | if base_url is None: 173 | raise ValueError("base_url must be provided for custom models") 174 | from tau_bench.model_utils.model.vllm_chat import VLLMChatModel 175 | 176 | return VLLMChatModel( 177 | model=model_id, 178 | base_url=base_url, 179 | api_key="sk-no-api-key-required" if api_key is None else api_key, 180 | temperature=temperature, 181 | ) 182 | else: 183 | if base_url is None: 184 | raise ValueError("base_url must be provided for custom models") 185 | from tau_bench.model_utils.model.vllm_completion import VLLMCompletionModel 186 | 187 | return VLLMCompletionModel(model=model_id, base_url=base_url, temperature=temperature) 188 | -------------------------------------------------------------------------------- /tau_bench/model_utils/model/mistral.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tau_bench.model_utils.api.datapoint import Datapoint 4 | from tau_bench.model_utils.model.chat import ChatModel, Message 5 | from tau_bench.model_utils.model.completion import approx_cost_for_datapoint, approx_prompt_str 6 | from tau_bench.model_utils.model.general_model import wrap_temperature 7 | from tau_bench.model_utils.model.utils import approx_num_tokens 8 | 9 | DEFAULT_MISTRAL_MODEL = "mistral-large-latest" 10 | 11 | PRICE_PER_INPUT_TOKEN_MAP = { 12 | "mistral-largest-latest": 3 / 1000000, 13 | } 14 | INPUT_PRICE_PER_TOKEN_FALLBACK = 
10 / 1000000 15 | 16 | CAPABILITY_SCORE_MAP = { 17 | "mistral-largest-latest": 0.9, 18 | } 19 | CAPABILITY_SCORE_FALLBACK = 0.3 20 | 21 | # TODO: implement 22 | LATENCY_MS_PER_OUTPUT_TOKEN_MAP = {} 23 | # TODO: implement 24 | LATENCY_MS_PER_OUTPUT_TOKEN_FALLBACK = 0.0 25 | 26 | MAX_CONTEXT_LENGTH_MAP = { 27 | "mistral-largest-latest": 128000, 28 | } 29 | MAX_CONTEXT_LENGTH_FALLBACK = 128000 30 | 31 | 32 | class MistralModel(ChatModel): 33 | def __init__( 34 | self, model: str | None = None, api_key: str | None = None, temperature: float = 0.0 35 | ) -> None: 36 | from mistralai.async_client import MistralAsyncClient 37 | from mistralai.client import MistralClient 38 | 39 | if model is None: 40 | self.model = DEFAULT_MISTRAL_MODEL 41 | else: 42 | self.model = model 43 | 44 | api_key = None 45 | if api_key is None: 46 | api_key = os.getenv("MISTRAL_API_KEY") 47 | if api_key is None: 48 | raise ValueError("MISTRAL_API_KEY environment variable is not set") 49 | self.client = MistralClient(api_key=api_key) 50 | self.async_client = MistralAsyncClient(api_key=api_key) 51 | self.temperature = temperature 52 | 53 | def generate_message( 54 | self, 55 | messages: list[Message], 56 | force_json: bool, 57 | temperature: float | None = None, 58 | ) -> Message: 59 | if temperature is None: 60 | temperature = self.temperature 61 | msgs = self.build_generate_message_state(messages) 62 | res = self.client.chat( 63 | model=self.model, 64 | messages=msgs, 65 | temperature=wrap_temperature(temperature), 66 | response_format={"type": "json_object" if force_json else "text"}, 67 | ) 68 | return self.handle_generate_message_response( 69 | prompt=msgs, content=res.choices[0].message.content, force_json=force_json 70 | ) 71 | 72 | def get_approx_cost(self, dp: Datapoint) -> float: 73 | cost_per_token = PRICE_PER_INPUT_TOKEN_MAP.get(self.model, INPUT_PRICE_PER_TOKEN_FALLBACK) 74 | return approx_cost_for_datapoint(dp=dp, price_per_input_token=cost_per_token) 75 | 76 | def get_latency(self, 
dp: Datapoint) -> float: 77 | latency_per_output_token = LATENCY_MS_PER_OUTPUT_TOKEN_MAP.get( 78 | self.model, LATENCY_MS_PER_OUTPUT_TOKEN_FALLBACK 79 | ) 80 | return approx_cost_for_datapoint(dp=dp, price_per_input_token=latency_per_output_token)  # NOTE(review): reuses the cost estimator with a per-token latency rate; the latency map is still a TODO in this file. 81 | 82 | def get_capability(self) -> float: 83 | return CAPABILITY_SCORE_MAP.get(self.model, CAPABILITY_SCORE_FALLBACK) 84 | 85 | def supports_dp(self, dp: Datapoint) -> bool: 86 | prompt = approx_prompt_str(dp) 87 | return approx_num_tokens(prompt) <= MAX_CONTEXT_LENGTH_MAP.get( 88 | self.model, MAX_CONTEXT_LENGTH_FALLBACK 89 | ) 90 | -------------------------------------------------------------------------------- /tau_bench/model_utils/model/model.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import enum 3 | from typing import Any, TypeVar 4 | 5 | from pydantic import BaseModel 6 | 7 | from tau_bench.model_utils.api.datapoint import ( 8 | BinaryClassifyDatapoint, 9 | ClassifyDatapoint, 10 | Datapoint, 11 | GenerateDatapoint, 12 | ParseDatapoint, 13 | ParseForceDatapoint, 14 | ScoreDatapoint, 15 | ) 16 | from tau_bench.model_utils.api.types import PartialObj 17 | 18 | T = TypeVar("T", bound=BaseModel) 19 | 20 | 21 | class Platform(enum.Enum):  # Values are the strings model_factory accepts via Platform(platform). 22 | OPENAI = "openai" 23 | MISTRAL = "mistral" 24 | ANTHROPIC = "anthropic" 25 | ANYSCALE = "anyscale" 26 | OUTLINES = "outlines" 27 | VLLM_CHAT = "vllm-chat" 28 | VLLM_COMPLETION = "vllm-completion" 29 | 30 | 31 | # @runtime_checkable 32 | # class Model(Protocol): 33 | class Model(abc.ABC):  # Base interface: capability score, approximate cost/latency per datapoint, and datapoint support check. 34 | @abc.abstractmethod 35 | def get_capability(self) -> float: 36 | """Return the capability of the model, a float between 0.0 and 1.0.""" 37 | raise NotImplementedError 38 | 39 | @abc.abstractmethod 40 | def get_approx_cost(self, dp: Datapoint) -> float: 41 | raise NotImplementedError 42 | 43 | @abc.abstractmethod 44 | def get_latency(self, dp: Datapoint) -> float: 45 | raise NotImplementedError 46 | 47 | @abc.abstractmethod 48 | def
supports_dp(self, dp: Datapoint) -> bool: 49 | raise NotImplementedError 50 | 51 | 52 | class ClassifyModel(Model): 53 | @abc.abstractmethod 54 | def classify(  # Presumably returns an index into options (GeneralModel.binary_classify relies on index 0 == "true"). 55 | self, 56 | instruction: str, 57 | text: str, 58 | options: list[str], 59 | examples: list[ClassifyDatapoint] | None = None, 60 | temperature: float | None = None, 61 | ) -> int: 62 | raise NotImplementedError 63 | 64 | 65 | class BinaryClassifyModel(Model): 66 | @abc.abstractmethod 67 | def binary_classify( 68 | self, 69 | instruction: str, 70 | text: str, 71 | examples: list[BinaryClassifyDatapoint] | None = None, 72 | temperature: float | None = None, 73 | ) -> bool: 74 | raise NotImplementedError 75 | 76 | 77 | class ParseModel(Model): 78 | @abc.abstractmethod 79 | def parse( 80 | self, 81 | text: str, 82 | typ: type[T] | dict[str, Any], 83 | examples: list[ParseDatapoint] | None = None, 84 | temperature: float | None = None, 85 | ) -> T | PartialObj | dict[str, Any]: 86 | raise NotImplementedError 87 | 88 | 89 | class GenerateModel(Model): 90 | @abc.abstractmethod 91 | def generate( 92 | self, 93 | instruction: str, 94 | text: str, 95 | examples: list[GenerateDatapoint] | None = None, 96 | temperature: float | None = None, 97 | ) -> str: 98 | raise NotImplementedError 99 | 100 | 101 | class ParseForceModel(Model): 102 | @abc.abstractmethod 103 | def parse_force( 104 | self, 105 | instruction: str, 106 | typ: type[T] | dict[str, Any], 107 | text: str | None = None, 108 | examples: list[ParseForceDatapoint] | None = None, 109 | temperature: float | None = None, 110 | ) -> T | dict[str, Any]: 111 | raise NotImplementedError 112 | 113 | 114 | class ScoreModel(Model): 115 | @abc.abstractmethod 116 | def score( 117 | self, 118 | instruction: str, 119 | text: str, 120 | min: int, 121 | max: int, 122 | examples: list[ScoreDatapoint] | None = None, 123 | temperature: float | None = None, 124 | ) -> int: 125 | raise NotImplementedError 126 | 127 | 128 | AnyModel = ( 129 | BinaryClassifyModel | ClassifyModel |
ParseForceModel | GenerateModel | ParseModel | ScoreModel 130 | ) 131 | -------------------------------------------------------------------------------- /tau_bench/model_utils/model/openai.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tau_bench.model_utils.api.datapoint import Datapoint 4 | from tau_bench.model_utils.model.chat import ChatModel, Message 5 | from tau_bench.model_utils.model.completion import approx_cost_for_datapoint, approx_prompt_str 6 | from tau_bench.model_utils.model.general_model import wrap_temperature 7 | from tau_bench.model_utils.model.utils import approx_num_tokens 8 | 9 | DEFAULT_OPENAI_MODEL = "gpt-4o-2024-08-06" 10 | API_KEY_ENV_VAR = "OPENAI_API_KEY" 11 | 12 | PRICE_PER_INPUT_TOKEN_MAP = { 13 | "gpt-4o-2024-08-06": 2.5 / 1000000, 14 | "gpt-4o": 5 / 1000000, 15 | "gpt-4o-2024-08-06": 2.5 / 1000000, 16 | "gpt-4o-2024-05-13": 5 / 1000000, 17 | "gpt-4-turbo": 10 / 1000000, 18 | "gpt-4-turbo-2024-04-09": 10 / 1000000, 19 | "gpt-4": 30 / 1000000, 20 | "gpt-4o-mini": 0.15 / 1000000, 21 | "gpt-4o-mini-2024-07-18": 0.15 / 1000000, 22 | "gpt-3.5-turbo": 0.5 / 1000000, 23 | "gpt-3.5-turbo-0125": 0.5 / 1000000, 24 | "gpt-3.5-turbo-instruct": 1.5 / 1000000, 25 | } 26 | INPUT_PRICE_PER_TOKEN_FALLBACK = 10 / 1000000 27 | 28 | CAPABILITY_SCORE_MAP = { 29 | "gpt-4o-2024-08-06": 0.8, 30 | "gpt-4o": 0.8, 31 | "gpt-4o-2024-08-06": 0.8, 32 | "gpt-4o-2024-05-13": 0.8, 33 | "gpt-4-turbo": 0.9, 34 | "gpt-4-turbo-2024-04-09": 0.9, 35 | "gpt-4": 0.8, 36 | "gpt-4o-mini": 0.5, 37 | "gpt-4o-mini-2024-07-18": 0.5, 38 | "gpt-3.5-turbo": 0.3, 39 | "gpt-3.5-turbo-0125": 0.3, 40 | } 41 | CAPABILITY_SCORE_FALLBACK = 0.3 42 | 43 | # TODO: implement 44 | LATENCY_MS_PER_OUTPUT_TOKEN_MAP = {} 45 | # TODO: implement 46 | LATENCY_MS_PER_OUTPUT_TOKEN_FALLBACK = 0.0 47 | 48 | MAX_CONTEXT_LENGTH_MAP = { 49 | "gpt-4o-2024-08-06": 128000, 50 | "gpt-4o": 128000, 51 | "gpt-4o-2024-08-06": 128000, 52 | 
"gpt-4o-2024-05-13": 128000, 53 | "gpt-4-turbo": 128000, 54 | "gpt-4-turbo-2024-04-09": 128000, 55 | "gpt-4": 8192, 56 | "gpt-4o-mini": 128000, 57 | "gpt-4o-mini-2024-07-18": 128000, 58 | "gpt-3.5-turbo": 16385, 59 | "gpt-3.5-turbo-0125": 16385, 60 | } 61 | MAX_CONTEXT_LENGTH_FALLBACK = 128000 62 | 63 | 64 | class OpenAIModel(ChatModel): 65 | def __init__( 66 | self, 67 | model: str | None = None, 68 | api_key: str | None = None, 69 | temperature: float = 0.0, 70 | ) -> None: 71 | from openai import AsyncOpenAI, OpenAI 72 | 73 | if model is None: 74 | self.model = DEFAULT_OPENAI_MODEL 75 | else: 76 | self.model = model 77 | 78 | api_key = None 79 | if api_key is None: 80 | api_key = os.getenv(API_KEY_ENV_VAR) 81 | if api_key is None: 82 | raise ValueError(f"{API_KEY_ENV_VAR} environment variable is not set") 83 | self.client = OpenAI(api_key=api_key) 84 | self.async_client = AsyncOpenAI(api_key=api_key) 85 | self.temperature = temperature 86 | 87 | def generate_message( 88 | self, 89 | messages: list[Message], 90 | force_json: bool, 91 | temperature: float | None = None, 92 | ) -> Message: 93 | if temperature is None: 94 | temperature = self.temperature 95 | msgs = self.build_generate_message_state(messages) 96 | res = self.client.chat.completions.create( 97 | model=self.model, 98 | messages=msgs, 99 | temperature=wrap_temperature(temperature), 100 | response_format={"type": "json_object" if force_json else "text"}, 101 | ) 102 | return self.handle_generate_message_response( 103 | prompt=msgs, content=res.choices[0].message.content, force_json=force_json 104 | ) 105 | 106 | def get_approx_cost(self, dp: Datapoint) -> float: 107 | cost_per_token = PRICE_PER_INPUT_TOKEN_MAP.get(self.model, INPUT_PRICE_PER_TOKEN_FALLBACK) 108 | return approx_cost_for_datapoint(dp=dp, price_per_input_token=cost_per_token) 109 | 110 | def get_latency(self, dp: Datapoint) -> float: 111 | latency_per_output_token = LATENCY_MS_PER_OUTPUT_TOKEN_MAP.get( 112 | self.model, 
LATENCY_MS_PER_OUTPUT_TOKEN_FALLBACK 113 | ) 114 | return approx_cost_for_datapoint(dp=dp, price_per_input_token=latency_per_output_token) 115 | 116 | def get_capability(self) -> float: 117 | return CAPABILITY_SCORE_MAP.get(self.model, CAPABILITY_SCORE_FALLBACK) 118 | 119 | def supports_dp(self, dp: Datapoint) -> bool: 120 | prompt = approx_prompt_str(dp) 121 | return approx_num_tokens(prompt) <= MAX_CONTEXT_LENGTH_MAP.get( 122 | self.model, MAX_CONTEXT_LENGTH_FALLBACK 123 | ) 124 | -------------------------------------------------------------------------------- /tau_bench/model_utils/model/outlines_completion.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from pydantic import BaseModel 4 | 5 | from tau_bench.model_utils.api.datapoint import Datapoint 6 | from tau_bench.model_utils.model.vllm_completion import VLLMCompletionModel 7 | from tau_bench.model_utils.model.vllm_utils import generate_request 8 | 9 | 10 | class OutlinesCompletionModel(VLLMCompletionModel): 11 | def parse_force_from_prompt(  # Like the vLLM version, but also sends typ.model_json_schema() as "schema" in the request body. 12 | self, prompt: str, typ: BaseModel, temperature: float | None = None  # NOTE(review): typ must be a pydantic model here; a dict typ (allowed by the parent signature) would fail on model_json_schema(). 13 | ) -> dict[str, Any]: 14 | if temperature is None: 15 | temperature = self.temperature 16 | schema = typ.model_json_schema() 17 | res = generate_request( 18 | url=self.url, 19 | prompt=prompt, 20 | force_json=True, 21 | schema=schema, 22 | temperature=temperature, 23 | ) 24 | return self.handle_parse_force_response(prompt=prompt, content=res) 25 | 26 | def get_approx_cost(self, dp: Datapoint) -> float: 27 | return super().get_approx_cost(dp) 28 | 29 | def get_latency(self, dp: Datapoint) -> float: 30 | return super().get_latency(dp) 31 | 32 | def get_capability(self) -> float: 33 | return super().get_capability() 34 | 35 | def supports_dp(self, dp: Datapoint) -> bool: 36 | return super().supports_dp(dp) 37 | --------------------------------------------------------------------------------
/tau_bench/model_utils/model/utils.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import json 3 | import re 4 | from typing import Any, Optional, TypeVar 5 | 6 | from pydantic import BaseModel, Field 7 | 8 | from tau_bench.model_utils.api.types import PartialObj 9 | 10 | T = TypeVar("T", bound=BaseModel) 11 | 12 | 13 | class InputType(enum.Enum): 14 | CHAT = "chat" 15 | COMPLETION = "completion" 16 | 17 | 18 | def display_choices(choices: list[str]) -> tuple[str, dict[str, int]]:  # Returns "A. <choice>" lines joined by newlines plus a label -> index map. 19 | choice_displays = [] 20 | decode_map = {} 21 | for i, choice in enumerate(choices): 22 | label = index_to_alpha(i) 23 | choice_display = f"{label}. {choice}" 24 | choice_displays.append(choice_display) 25 | decode_map[label] = i 26 | return "\n".join(choice_displays), decode_map 27 | 28 | 29 | def index_to_alpha(index: int) -> str:  # Spreadsheet-style labels: 0 -> "A", 25 -> "Z", 26 -> "AA". 30 | alpha = "" 31 | while index >= 0: 32 | alpha = chr(index % 26 + ord("A")) + alpha 33 | index = index // 26 - 1 34 | return alpha 35 | 36 | 37 | def type_to_json_schema_string(typ: type[T]) -> str: 38 | json_schema = typ.model_json_schema() 39 | return json.dumps(json_schema, indent=4) 40 | 41 | 42 | def optionalize_type(typ: type[T]) -> type[T]: 43 | class OptionalModel(typ): 44 | ...
45 | 46 | new_fields = {} 47 | for name, field in OptionalModel.model_fields.items(): 48 | new_fields[name] = Field(default=None, annotation=Optional[field.annotation]) 49 | OptionalModel.model_fields = new_fields  # NOTE(review): under pydantic v2, reassigning model_fields may not rebuild the validator; confirm the fields actually become optional. 50 | OptionalModel.__name__ = typ.__name__ 51 | return OptionalModel 52 | 53 | 54 | def json_response_to_obj_or_partial_obj(  # Returns the raw dict when typ is a dict or a required field is missing/None; otherwise a validated typ instance. 55 | response: dict[str, Any], typ: type[T] | dict[str, Any] 56 | ) -> T | PartialObj | dict[str, Any]: 57 | if isinstance(typ, dict): 58 | return response 59 | else: 60 | required_field_names = [ 61 | name for name, field in typ.model_fields.items() if field.is_required() 62 | ] 63 | for name in required_field_names: 64 | if name not in response.keys() or response[name] is None: 65 | return response 66 | return typ.model_validate(response) 67 | 68 | 69 | def clean_top_level_keys(d: dict[str, Any]) -> dict[str, Any]: 70 | new_d = {} 71 | for k, v in d.items(): 72 | new_d[k.strip()] = v 73 | return new_d 74 | 75 | 76 | def parse_json_or_json_markdown(text: str) -> dict[str, Any]: 77 | def parse(s: str) -> dict[str, Any] | None: 78 | try: 79 | return json.loads(s) 80 | except json.decoder.JSONDecodeError: 81 | return None 82 | 83 | # pass #1: try to parse as json 84 | parsed = parse(text) 85 | if parsed is not None: 86 | return parsed 87 | 88 | # pass #2: try to parse as json markdown 89 | stripped = text.strip() 90 | if stripped.startswith("```json"): 91 | stripped = stripped[len("```json") :].strip() 92 | if stripped.endswith("```"): 93 | stripped = stripped[: -len("```")].strip() 94 | parsed = parse(stripped) 95 | if parsed is not None: 96 | return parsed 97 | 98 | # pass #3: try to parse an arbitrary md block 99 | pattern = r"```(?:\w+\n)?(.*?)```" 100 | match = re.search(pattern, text, re.DOTALL) 101 | if match: 102 | content = match.group(1).strip() 103 | parsed = parse(content) 104 | if parsed is not None: 105 | return parsed 106 | 107 | # pass #4: try to parse arbitrary sections as json 108 | lines = text.split("\n") 109 |
seen = set() 110 | for i in range(len(lines)): 111 | for j in range(i + 1, len(lines) + 1): 112 | if i < j and (i, j) not in seen:  # NOTE(review): always true (j > i and each pair is visited once); pass #4 makes O(lines^2) json.loads attempts. 113 | seen.add((i, j)) 114 | content = "\n".join(lines[i:j]) 115 | parsed = parse(content) 116 | if parsed is not None: 117 | return parsed 118 | raise ValueError("Could not parse JSON or JSON markdown") 119 | 120 | 121 | def longest_valid_string(s: str, options: list[str]) -> str | None:  # Longest prefix of s that is exactly one of options, else None. 122 | longest = 0 123 | longest_str = None 124 | options_set = set(options) 125 | for i in range(len(s)): 126 | if s[: i + 1] in options_set and i + 1 > longest: 127 | longest = i + 1 128 | longest_str = s[: i + 1] 129 | return longest_str 130 | 131 | 132 | def try_classify_recover(s: str, decode_map: dict[str, int]) -> str | None: 133 | lvs = longest_valid_string(s, list(decode_map.keys())) 134 | if lvs is not None and lvs in decode_map: 135 | return lvs 136 | for k, v in decode_map.items(): 137 | if s == v:  # NOTE(review): s is a str and v an int per the annotations, so this never matches as written; confirm the intended comparison. 138 | return k 139 | 140 | 141 | def approx_num_tokens(text: str) -> int:  # Rough heuristic of ~4 characters per token. 142 | return len(text) // 4 143 | 144 | 145 | def add_md_close_tag(prompt: str) -> str: 146 | return f"{prompt}\n```" 147 | 148 | 149 | def add_md_tag(prompt: str) -> str: 150 | return f"```json\n{prompt}\n```" 151 | -------------------------------------------------------------------------------- /tau_bench/model_utils/model/vllm_chat.py: -------------------------------------------------------------------------------- 1 | from tau_bench.model_utils.api.datapoint import Datapoint 2 | from tau_bench.model_utils.model.chat import ChatModel, Message 3 | from tau_bench.model_utils.model.completion import approx_cost_for_datapoint, approx_prompt_str 4 | from tau_bench.model_utils.model.general_model import wrap_temperature 5 | from tau_bench.model_utils.model.utils import approx_num_tokens 6 | 7 | PRICE_PER_INPUT_TOKEN_MAP = { 8 | "Qwen/Qwen2-0.5B-Instruct": 0.0, 9 | "Qwen/Qwen2-1.5B-Instruct": 0.0, 10 | "Qwen/Qwen2-7B-Instruct": 0.0, 11 | "Qwen/Qwen2-72B-Instruct": 0.0, 12 |
"meta-llama/Meta-Llama-3.1-8B-Instruct": 0.0, 13 | "sierra-research/Meta-Llama-3.1-8B-Instruct": 0.0, 14 | "meta-llama/Meta-Llama-3.1-70B-Instruct": 0.0, 15 | "mistralai/Mistral-Nemo-Instruct-2407": 0.0, 16 | } 17 | INPUT_PRICE_PER_TOKEN_FALLBACK = 0.0 18 | 19 | # TODO: refine this 20 | CAPABILITY_SCORE_MAP = { 21 | "Qwen/Qwen2-0.5B-Instruct": 0.05, 22 | "Qwen/Qwen2-1.5B-Instruct": 0.07, 23 | "Qwen/Qwen2-7B-Instruct": 0.2, 24 | "Qwen/Qwen2-72B-Instruct": 0.4, 25 | "meta-llama/Meta-Llama-3.1-8B-Instruct": 0.3, 26 | "sierra-research/Meta-Llama-3.1-8B-Instruct": 0.3, 27 | "meta-llama/Meta-Llama-3.1-70B-Instruct": 0.4, 28 | "mistralai/Mistral-Nemo-Instruct-2407": 0.3, 29 | } 30 | CAPABILITY_SCORE_FALLBACK = 0.3 31 | 32 | # TODO: implement 33 | LATENCY_MS_PER_OUTPUT_TOKEN_MAP = {} 34 | # TODO: implement 35 | LATENCY_MS_PER_OUTPUT_TOKEN_FALLBACK = 0.0 36 | 37 | MAX_CONTEXT_LENGTH_MAP = { 38 | "Qwen/Qwen2-0.5B-Instruct": 32768, 39 | "Qwen/Qwen2-1.5B-Instruct": 32768, 40 | "Qwen/Qwen2-7B-Instruct": 131072, 41 | "Qwen/Qwen2-72B-Instruct": 131072, 42 | "meta-llama/Meta-Llama-3.1-8B-Instruct": 128000, 43 | "sierra-research/Meta-Llama-3.1-8B-Instruct": 128000, 44 | "meta-llama/Meta-Llama-3.1-70B-Instruct": 128000, 45 | "mistralai/Mistral-Nemo-Instruct-2407": 128000, 46 | } 47 | MAX_CONTEXT_LENGTH_FALLBACK = 128000 48 | 49 | 50 | class VLLMChatModel(ChatModel):  # Chat model reached at base_url through the openai client (OpenAI / AsyncOpenAI). 51 | def __init__( 52 | self, 53 | model: str, 54 | base_url: str, 55 | api_key: str, 56 | temperature: float = 0.0, 57 | price_per_input_token: float | None = None,  # This and the next three are optional overrides; None falls back to the per-model maps. 58 | capability: float | None = None, 59 | latency_ms_per_output_token: float | None = None, 60 | max_context_length: int | None = None, 61 | ) -> None: 62 | from openai import AsyncOpenAI, OpenAI 63 | 64 | self.model = model 65 | self.client = OpenAI( 66 | base_url=base_url, 67 | api_key=api_key, 68 | ) 69 | self.async_client = AsyncOpenAI( 70 | base_url=base_url, 71 | api_key=api_key, 72 | ) 73 | self.temperature = temperature 74 |
self.price_per_input_token = ( 75 | price_per_input_token 76 | if price_per_input_token is not None 77 | else PRICE_PER_INPUT_TOKEN_MAP.get(model, INPUT_PRICE_PER_TOKEN_FALLBACK) 78 | ) 79 | self.capability = ( 80 | capability 81 | if capability is not None 82 | else CAPABILITY_SCORE_MAP.get(model, CAPABILITY_SCORE_FALLBACK) 83 | ) 84 | self.latency_ms_per_output_token = ( 85 | latency_ms_per_output_token 86 | if latency_ms_per_output_token is not None 87 | else LATENCY_MS_PER_OUTPUT_TOKEN_MAP.get(model, LATENCY_MS_PER_OUTPUT_TOKEN_FALLBACK) 88 | ) 89 | self.max_context_length = ( 90 | max_context_length 91 | if max_context_length is not None 92 | else MAX_CONTEXT_LENGTH_MAP.get(model, MAX_CONTEXT_LENGTH_FALLBACK) 93 | ) 94 | 95 | def get_approx_cost(self, dp: Datapoint) -> float: 96 | cost_per_token = self.price_per_input_token 97 | return approx_cost_for_datapoint(dp=dp, price_per_input_token=cost_per_token) 98 | 99 | def get_latency(self, dp: Datapoint) -> float: 100 | latency_per_output_token = self.latency_ms_per_output_token 101 | return approx_cost_for_datapoint(dp=dp, price_per_input_token=latency_per_output_token) 102 | 103 | def get_capability(self) -> float: 104 | return CAPABILITY_SCORE_MAP.get(self.model, CAPABILITY_SCORE_FALLBACK) 105 | 106 | def supports_dp(self, dp: Datapoint) -> bool: 107 | prompt = approx_prompt_str(dp) 108 | return approx_num_tokens(prompt) <= self.max_context_length 109 | 110 | def generate_message( 111 | self, 112 | messages: list[Message], 113 | force_json: bool, 114 | temperature: float | None = None, 115 | ) -> Message: 116 | if temperature is None: 117 | temperature = self.temperature 118 | msgs = self.build_generate_message_state(messages) 119 | res = self.client.chat.completions.create( 120 | model=self.model, 121 | messages=msgs, 122 | temperature=wrap_temperature(temperature=temperature), 123 | ) 124 | return self.handle_generate_message_response( 125 | prompt=msgs, content=res.choices[0].message.content, 
force_json=force_json 126 | ) 127 | 128 | def force_json_prompt(self, text: str, _: bool = False) -> str:  # Always asks the parent for the prefixed variant; the second argument is ignored. 129 | return super().force_json_prompt(text, with_prefix=True) 130 | -------------------------------------------------------------------------------- /tau_bench/model_utils/model/vllm_completion.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any 3 | 4 | from pydantic import BaseModel 5 | 6 | from tau_bench.model_utils.api.datapoint import Datapoint 7 | from tau_bench.model_utils.model.completion import ( 8 | CompletionModel, 9 | approx_cost_for_datapoint, 10 | approx_prompt_str, 11 | ) 12 | from tau_bench.model_utils.model.utils import approx_num_tokens 13 | from tau_bench.model_utils.model.vllm_utils import generate_request 14 | 15 | PRICE_PER_INPUT_TOKEN_MAP = {  # NOTE(review): lists Meta-Llama-3 ids where the other maps use 3.1 / sierra-research ids; every price is 0.0 either way. 16 | "Qwen/Qwen2-0.5B-Instruct": 0.0, 17 | "Qwen/Qwen2-1.5B-Instruct": 0.0, 18 | "Qwen/Qwen2-7B-Instruct": 0.0, 19 | "Qwen/Qwen2-72B-Instruct": 0.0, 20 | "meta-llama/Meta-Llama-3-8B-Instruct": 0.0, 21 | "meta-llama/Meta-Llama-3.1-8B-Instruct": 0.0, 22 | "meta-llama/Meta-Llama-3-70B-Instruct": 0.0, 23 | "mistralai/Mistral-Nemo-Instruct-2407": 0.0, 24 | } 25 | INPUT_PRICE_PER_TOKEN_FALLBACK = 0.0 26 | 27 | # TODO: refine this 28 | CAPABILITY_SCORE_MAP = { 29 | "Qwen/Qwen2-0.5B-Instruct": 0.05, 30 | "Qwen/Qwen2-1.5B-Instruct": 0.07, 31 | "Qwen/Qwen2-7B-Instruct": 0.2, 32 | "Qwen/Qwen2-72B-Instruct": 0.4, 33 | "meta-llama/Meta-Llama-3.1-8B-Instruct": 0.3, 34 | "sierra-research/Meta-Llama-3.1-8B-Instruct": 0.3, 35 | "meta-llama/Meta-Llama-3.1-70B-Instruct": 0.5, 36 | "mistralai/Mistral-Nemo-Instruct-2407": 0.3, 37 | } 38 | CAPABILITY_SCORE_FALLBACK = 0.1 39 | 40 | # TODO: implement 41 | LATENCY_MS_PER_OUTPUT_TOKEN_MAP = {} 42 | # TODO: implement 43 | LATENCY_MS_PER_OUTPUT_TOKEN_FALLBACK = 0.0 44 | 45 | MAX_CONTEXT_LENGTH_MAP = { 46 | "Qwen/Qwen2-0.5B-Instruct": 32768, 47 | "Qwen/Qwen2-1.5B-Instruct": 32768, 48 |
"Qwen/Qwen2-7B-Instruct": 131072, 49 | "Qwen/Qwen2-72B-Instruct": 131072, 50 | "meta-llama/Meta-Llama-3.1-8B-Instruct": 128000, 51 | "sierra-research/Meta-Llama-3.1-8B-Instruct": 128000, 52 | "meta-llama/Meta-Llama-3.1-70B-Instruct": 128000, 53 | "mistralai/Mistral-Nemo-Instruct-2407": 128000, 54 | } 55 | MAX_CONTEXT_LENGTH_FALLBACK = 128000 56 | 57 | 58 | class VLLMCompletionModel(CompletionModel): 59 | def __init__( 60 | self, 61 | model: str, 62 | base_url: str, 63 | endpoint: str = "generate", 64 | temperature: float = 0.0, 65 | price_per_input_token: float | None = None, 66 | capability: float | None = None, 67 | latency_ms_per_output_token: float | None = None, 68 | max_context_length: int | None = None, 69 | ) -> None: 70 | self.model = model 71 | self.base_url = base_url 72 | self.url = os.path.join(base_url, endpoint) 73 | self.temperature = temperature 74 | self.price_per_input_token = ( 75 | price_per_input_token 76 | if price_per_input_token is not None 77 | else PRICE_PER_INPUT_TOKEN_MAP.get(model, INPUT_PRICE_PER_TOKEN_FALLBACK) 78 | ) 79 | self.capability = ( 80 | capability 81 | if capability is not None 82 | else CAPABILITY_SCORE_MAP.get(model, CAPABILITY_SCORE_FALLBACK) 83 | ) 84 | self.latency_ms_per_output_token = ( 85 | latency_ms_per_output_token 86 | if latency_ms_per_output_token is not None 87 | else LATENCY_MS_PER_OUTPUT_TOKEN_MAP.get(model, LATENCY_MS_PER_OUTPUT_TOKEN_FALLBACK) 88 | ) 89 | self.max_context_length = ( 90 | max_context_length 91 | if max_context_length is not None 92 | else MAX_CONTEXT_LENGTH_MAP.get(model, MAX_CONTEXT_LENGTH_FALLBACK) 93 | ) 94 | 95 | def generate_from_prompt(self, prompt: str, temperature: float = 0.0) -> str: 96 | return generate_request(url=self.url, prompt=prompt, temperature=temperature) 97 | 98 | def parse_force_from_prompt( 99 | self, prompt: str, typ: BaseModel | dict[str, Any], temperature: float | None = None 100 | ) -> dict[str, Any]: 101 | if temperature is None: 102 | temperature = 
self.temperature 103 | res = generate_request( 104 | url=self.url, prompt=prompt, force_json=True, temperature=temperature 105 | ) 106 | return self.handle_parse_force_response(prompt=prompt, content=res) 107 | 108 | def get_approx_cost(self, dp: Datapoint) -> float: 109 | cost_per_token = self.price_per_input_token 110 | return approx_cost_for_datapoint(dp=dp, price_per_input_token=cost_per_token) 111 | 112 | def get_latency(self, dp: Datapoint) -> float: 113 | latency_per_output_token = self.latency_ms_per_output_token 114 | return approx_cost_for_datapoint(dp=dp, price_per_input_token=latency_per_output_token) 115 | 116 | def get_capability(self) -> float: 117 | return CAPABILITY_SCORE_MAP.get(self.model, CAPABILITY_SCORE_FALLBACK) 118 | 119 | def supports_dp(self, dp: Datapoint) -> bool: 120 | prompt = approx_prompt_str(dp) 121 | return approx_num_tokens(prompt) <= self.max_context_length 122 | -------------------------------------------------------------------------------- /tau_bench/model_utils/model/vllm_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import requests 4 | 5 | from tau_bench.model_utils.model.general_model import wrap_temperature 6 | 7 | 8 | def generate_request( 9 | url: str, 10 | prompt: str, 11 | temperature: float = 0.0, 12 | force_json: bool = False, 13 | **req_body_kwargs: Any, 14 | ) -> str: 15 | args = { 16 | "prompt": prompt, 17 | "temperature": wrap_temperature(temperature), 18 | "max_tokens": 4096, 19 | **req_body_kwargs, 20 | } 21 | if force_json: 22 | # the prompt will have a suffix of '```json\n' to indicate that the response should be a JSON object 23 | args["stop"] = ["```"] 24 | res = requests.post( 25 | url, 26 | json=args, 27 | ) 28 | res.raise_for_status() 29 | json_res = res.json() 30 | if "text" not in json_res: 31 | raise ValueError(f"Unexpected response: {json_res}") 32 | elif len(json_res["text"]) == 0: 33 | raise ValueError(f"Empty response: 
{json_res}") 34 | text = json_res["text"][0] 35 | assert isinstance(text, str) 36 | return text.removeprefix(prompt) 37 | -------------------------------------------------------------------------------- /tau_bench/run.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import os 4 | import json 5 | import random 6 | import traceback 7 | from math import comb 8 | import multiprocessing 9 | from typing import List, Dict, Any 10 | from datetime import datetime 11 | from concurrent.futures import ThreadPoolExecutor 12 | 13 | from tau_bench.envs import get_env 14 | from tau_bench.agents.base import Agent 15 | from tau_bench.types import EnvRunResult, RunConfig 16 | from litellm import provider_list 17 | from tau_bench.envs.user import UserStrategy 18 | 19 | 20 | def run(config: RunConfig) -> List[EnvRunResult]: 21 | assert config.env in ["retail", "airline"], "Only retail and airline envs are supported" 22 | assert config.model_provider in provider_list, "Invalid model provider" 23 | assert config.user_model_provider in provider_list, "Invalid user model provider" 24 | assert config.agent_strategy in ["tool-calling", "act", "react", "few-shot"], "Invalid agent strategy" 25 | assert config.task_split in ["train", "test", "dev"], "Invalid task split" 26 | assert config.user_strategy in [item.value for item in UserStrategy], "Invalid user strategy" 27 | 28 | random.seed(config.seed) 29 | time_str = datetime.now().strftime("%m%d%H%M%S") 30 | ckpt_path = f"{config.log_dir}/{config.agent_strategy}-{config.model.split('/')[-1]}-{config.temperature}_range_{config.start_index}-{config.end_index}_user-{config.user_model}-{config.user_strategy}_{time_str}.json" 31 | if not os.path.exists(config.log_dir): 32 | os.makedirs(config.log_dir) 33 | 34 | print(f"Loading user with strategy: {config.user_strategy}") 35 | env = get_env( 36 | config.env, 37 | user_strategy=config.user_strategy, 38 | user_model=config.user_model, 
39 | user_provider=config.user_model_provider, 40 | task_split=config.task_split, 41 | ) 42 | agent = agent_factory( 43 | tools_info=env.tools_info, 44 | wiki=env.wiki, 45 | config=config, 46 | ) 47 | end_index = ( 48 | len(env.tasks) if config.end_index == -1 else min(config.end_index, len(env.tasks)) 49 | ) 50 | results: List[EnvRunResult] = [] 51 | lock = multiprocessing.Lock() 52 | if config.task_ids and len(config.task_ids) > 0: 53 | print(f"Running tasks {config.task_ids} (checkpoint path: {ckpt_path})") 54 | else: 55 | print( 56 | f"Running tasks {config.start_index} to {end_index} (checkpoint path: {ckpt_path})" 57 | ) 58 | for i in range(config.num_trials): 59 | if config.task_ids and len(config.task_ids) > 0: 60 | idxs = config.task_ids 61 | else: 62 | idxs = list(range(config.start_index, end_index)) 63 | if config.shuffle: 64 | random.shuffle(idxs) 65 | 66 | def _run(idx: int) -> EnvRunResult: 67 | isolated_env = get_env( 68 | config.env, 69 | user_strategy=config.user_strategy, 70 | user_model=config.user_model, 71 | task_split=config.task_split, 72 | user_provider=config.user_model_provider, 73 | task_index=idx, 74 | ) 75 | 76 | print(f"Running task {idx}") 77 | try: 78 | res = agent.solve( 79 | env=isolated_env, 80 | task_index=idx, 81 | ) 82 | result = EnvRunResult( 83 | task_id=idx, 84 | reward=res.reward, 85 | info=res.info, 86 | traj=res.messages, 87 | trial=i, 88 | ) 89 | except Exception as e: 90 | result = EnvRunResult( 91 | task_id=idx, 92 | reward=0.0, 93 | info={"error": str(e), "traceback": traceback.format_exc()}, 94 | traj=[], 95 | trial=i, 96 | ) 97 | print( 98 | "✅" if result.reward == 1 else "❌", 99 | f"task_id={idx}", 100 | result.info, 101 | ) 102 | print("-----") 103 | with lock: 104 | data = [] 105 | if os.path.exists(ckpt_path): 106 | with open(ckpt_path, "r") as f: 107 | data = json.load(f) 108 | with open(ckpt_path, "w") as f: 109 | json.dump(data + [result.model_dump()], f, indent=2) 110 | return result 111 | 112 | with 
ThreadPoolExecutor(max_workers=config.max_concurrency) as executor: 113 | res = list(executor.map(_run, idxs)) 114 | results.extend(res) 115 | 116 | display_metrics(results) 117 | 118 | with open(ckpt_path, "w") as f: 119 | json.dump([result.model_dump() for result in results], f, indent=2) 120 | print(f"\n📄 Results saved to {ckpt_path}\n") 121 | return results 122 | 123 | 124 | def agent_factory( 125 | tools_info: List[Dict[str, Any]], wiki, config: RunConfig 126 | ) -> Agent: 127 | if config.agent_strategy == "tool-calling": 128 | # native tool calling 129 | from tau_bench.agents.tool_calling_agent import ToolCallingAgent 130 | 131 | return ToolCallingAgent( 132 | tools_info=tools_info, 133 | wiki=wiki, 134 | model=config.model, 135 | provider=config.model_provider, 136 | temperature=config.temperature, 137 | ) 138 | elif config.agent_strategy == "act": 139 | # `act` from https://arxiv.org/abs/2210.03629 140 | from tau_bench.agents.chat_react_agent import ChatReActAgent 141 | 142 | return ChatReActAgent( 143 | tools_info=tools_info, 144 | wiki=wiki, 145 | model=config.model, 146 | provider=config.model_provider, 147 | use_reasoning=False, 148 | temperature=config.temperature, 149 | ) 150 | elif config.agent_strategy == "react": 151 | # `react` from https://arxiv.org/abs/2210.03629 152 | from tau_bench.agents.chat_react_agent import ChatReActAgent 153 | 154 | return ChatReActAgent( 155 | tools_info=tools_info, 156 | wiki=wiki, 157 | model=config.model, 158 | provider=config.model_provider, 159 | use_reasoning=True, 160 | temperature=config.temperature, 161 | ) 162 | elif config.agent_strategy == "few-shot": 163 | from tau_bench.agents.few_shot_agent import FewShotToolCallingAgent 164 | assert config.few_shot_displays_path is not None, "Few shot displays path is required for few-shot agent strategy" 165 | with open(config.few_shot_displays_path, "r") as f: 166 | few_shot_displays = [json.loads(line)["messages_display"] for line in f] 167 | 168 | return 
FewShotToolCallingAgent( 169 | tools_info=tools_info, 170 | wiki=wiki, 171 | model=config.model, 172 | provider=config.model_provider, 173 | few_shot_displays=few_shot_displays, 174 | temperature=config.temperature, 175 | ) 176 | else: 177 | raise ValueError(f"Unknown agent strategy: {config.agent_strategy}") 178 | 179 | 180 | def display_metrics(results: List[EnvRunResult]) -> None: 181 | def is_successful(reward: float) -> bool: 182 | return (1 - 1e-6) <= reward <= (1 + 1e-6) 183 | 184 | num_trials = len(set([r.trial for r in results])) 185 | rewards = [r.reward for r in results] 186 | avg_reward = sum(rewards) / len(rewards) 187 | # c from https://arxiv.org/pdf/2406.12045 188 | c_per_task_id: dict[int, int] = {} 189 | for result in results: 190 | if result.task_id not in c_per_task_id: 191 | c_per_task_id[result.task_id] = 1 if is_successful(result.reward) else 0 192 | else: 193 | c_per_task_id[result.task_id] += 1 if is_successful(result.reward) else 0 194 | pass_hat_ks: dict[int, float] = {} 195 | for k in range(1, num_trials + 1): 196 | sum_task_pass_hat_k = 0 197 | for c in c_per_task_id.values(): 198 | sum_task_pass_hat_k += comb(c, k) / comb(num_trials, k) 199 | pass_hat_ks[k] = sum_task_pass_hat_k / len(c_per_task_id) 200 | print(f"🏆 Average reward: {avg_reward}") 201 | print("📈 Pass^k") 202 | for k, pass_hat_k in pass_hat_ks.items(): 203 | print(f" k={k}: {pass_hat_k}") 204 | -------------------------------------------------------------------------------- /tau_bench/types.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from pydantic import BaseModel 4 | from typing import List, Dict, Any, Optional, Union 5 | 6 | RESPOND_ACTION_NAME = "respond" 7 | RESPOND_ACTION_FIELD_NAME = "content" 8 | 9 | 10 | class Action(BaseModel): 11 | name: str 12 | kwargs: Dict[str, Any] 13 | 14 | 15 | class Task(BaseModel): 16 | user_id: str 17 | actions: List[Action] 18 | instruction: str 19 | outputs: 
List[str] 20 | 21 | 22 | class RewardOutputInfo(BaseModel): 23 | r_outputs: float 24 | outputs: Dict[str, bool] 25 | 26 | 27 | class RewardActionInfo(BaseModel): 28 | r_actions: float 29 | gt_data_hash: str 30 | 31 | 32 | class RewardResult(BaseModel): 33 | reward: float 34 | info: Union[RewardOutputInfo, RewardActionInfo] 35 | actions: List[Action] 36 | 37 | 38 | class SolveResult(BaseModel): 39 | reward: float 40 | messages: List[Dict[str, Any]] 41 | info: Dict[str, Any] 42 | total_cost: Optional[float] = None 43 | 44 | 45 | class EnvInfo(BaseModel): 46 | task: Task 47 | source: Optional[str] = None 48 | user_cost: Optional[float] = None 49 | reward_info: Optional[RewardResult] = None 50 | 51 | 52 | class EnvResponse(BaseModel): 53 | observation: str 54 | reward: float 55 | done: bool 56 | info: EnvInfo 57 | 58 | 59 | class EnvResetResponse(BaseModel): 60 | observation: str 61 | info: EnvInfo 62 | 63 | 64 | class EnvRunResult(BaseModel): 65 | task_id: int 66 | reward: float 67 | info: Dict[str, Any] 68 | traj: List[Dict[str, Any]] 69 | trial: int 70 | 71 | 72 | class RunConfig(BaseModel): 73 | model_provider: str 74 | user_model_provider: str 75 | model: str 76 | user_model: str = "gpt-4o" 77 | num_trials: int = 1 78 | env: str = "retail" 79 | agent_strategy: str = "tool-calling" 80 | temperature: float = 0.0 81 | task_split: str = "test" 82 | start_index: int = 0 83 | end_index: int = -1 84 | task_ids: Optional[List[int]] = None 85 | log_dir: str = "results" 86 | max_concurrency: int = 1 87 | seed: int = 10 88 | shuffle: int = 0 89 | user_strategy: str = "llm" 90 | few_shot_displays_path: Optional[str] = None 91 | --------------------------------------------------------------------------------