├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── api_examples.json
│   ├── dataset.zip
│   ├── llama3-josh-toolwoz-for-KTO.zip
│   ├── testListFile.json
│   ├── tools.json
│   ├── valListFile.json
│   └── valid_api_defs.json
├── db
│   ├── attraction-dbase.db
│   ├── attraction_db.json
│   ├── bus-dbase.db
│   ├── bus_db.json
│   ├── hospital-dbase.db
│   ├── hospital_db.json
│   ├── hotel-dbase.db
│   ├── hotel_db.json
│   ├── police_db.json
│   ├── restaurant-dbase.db
│   ├── restaurant_db.json
│   ├── taxi-dbase.db
│   ├── taxi_db.json
│   ├── train-dbase.db
│   └── train_db.json
├── final_file_push.py
├── final_file_push_kto.py
├── final_file_push_taubench.py
├── final_file_push_taubench_kto.py
├── instructionsToRunMTBench.md
├── josh_train
│   ├── __init__.py
│   ├── _version.py
│   ├── agents
│   │   ├── fc_agent.py
│   │   └── react_agent.py
│   ├── async_josh.py
│   ├── build_apis.py
│   ├── config.py
│   ├── conversation_types
│   │   ├── conversation_state.py
│   │   └── conversation_state_pref_tree.py
│   ├── finetune
│   │   ├── finetune_agent.py
│   │   ├── kto_ft.py
│   │   ├── sft_agent.py
│   │   ├── sft_agent_final_mod.py
│   │   └── sft_agent_successful.py
│   ├── josh.py
│   ├── main.py
│   ├── training_fnames.json
│   ├── users
│   │   ├── base_user_simulator.py
│   │   ├── goal_user_simulator.py
│   │   ├── guide_user_simulator.py
│   │   └── script_user_simulator.py
│   └── utils.py
├── mtbencheval.zip
├── prompts
│   └── prompts.yaml
├── pyproject.toml
├── requirements.txt
├── setup.cfg
├── setup.py
├── takeAggregateWithOptBaseline.py
├── tau-bench-eval
│   ├── LICENSE
│   ├── README.md
│   ├── auto_error_identification.py
│   ├── run.py
│   ├── setup.py
│   └── tau_bench
│       ├── __init__.py
│       ├── agents
│       │   ├── __init__.py
│       │   ├── base.py
│       │   ├── chat_react_agent.py
│       │   ├── josh_chat_react_agent.py
│       │   ├── josh_tool_calling_agent.py
│       │   └── tool_calling_agent.py
│       ├── envs
│       │   ├── __init__.py
│       │   ├── airline
│       │   │   ├── __init__.py
│       │   │   ├── data
│       │   │   │   ├── __init__.py
│       │   │   │   ├── flights.json
│       │   │   │   ├── reservations.json
│       │   │   │   └── users.json
│       │   │   ├── env.py
│       │   │   ├── rules.py
│       │   │   ├── tasks.py
│       │   │   ├── tasks_test.py
│       │   │   ├── tools
│       │   │   │   ├── __init__.py
│       │   │   │   ├── book_reservation.py
│       │   │   │   ├── calculate.py
│       │   │   │   ├── cancel_reservation.py
│       │   │   │   ├── get_reservation_details.py
│       │   │   │   ├── get_user_details.py
│       │   │   │   ├── list_all_airports.py
│       │   │   │   ├── search_direct_flight.py
│       │   │   │   ├── search_onestop_flight.py
│       │   │   │   ├── send_certificate.py
│       │   │   │   ├── think.py
│       │   │   │   ├── transfer_to_human_agents.py
│       │   │   │   ├── update_reservation_baggages.py
│       │   │   │   ├── update_reservation_flights.py
│       │   │   │   └── update_reservation_passengers.py
│       │   │   ├── wiki.md
│       │   │   └── wiki.py
│       │   ├── base.py
│       │   ├── retail
│       │   │   ├── __init__.py
│       │   │   ├── data
│       │   │   │   ├── __init__.py
│       │   │   │   ├── orders.json
│       │   │   │   ├── products.json
│       │   │   │   ├── readme.md
│       │   │   │   └── users.json
│       │   │   ├── env.py
│       │   │   ├── rules.py
│       │   │   ├── tasks.py
│       │   │   ├── tasks_dev.py
│       │   │   ├── tasks_test.py
│       │   │   ├── tasks_train.py
│       │   │   ├── tools
│       │   │   │   ├── __init__.py
│       │   │   │   ├── calculate.py
│       │   │   │   ├── cancel_pending_order.py
│       │   │   │   ├── exchange_delivered_order_items.py
│       │   │   │   ├── find_user_id_by_email.py
│       │   │   │   ├── find_user_id_by_name_zip.py
│       │   │   │   ├── get_order_details.py
│       │   │   │   ├── get_product_details.py
│       │   │   │   ├── get_user_details.py
│       │   │   │   ├── list_all_product_types.py
│       │   │   │   ├── modify_pending_order_address.py
│       │   │   │   ├── modify_pending_order_items.py
│       │   │   │   ├── modify_pending_order_payment.py
│       │   │   │   ├── modify_user_address.py
│       │   │   │   ├── return_delivered_order_items.py
│       │   │   │   ├── think.py
│       │   │   │   └── transfer_to_human_agents.py
│       │   │   ├── wiki.md
│       │   │   └── wiki.py
│       │   ├── tool.py
│       │   └── user.py
│       ├── model_utils
│       │   ├── __init__.py
│       │   ├── api
│       │   │   ├── __init__.py
│       │   │   ├── _model_methods.py
│       │   │   ├── api.py
│       │   │   ├── cache.py
│       │   │   ├── datapoint.py
│       │   │   ├── exception.py
│       │   │   ├── logging.py
│       │   │   ├── router.py
│       │   │   ├── sample.py
│       │   │   ├── tokens.py
│       │   │   └── types.py
│       │   ├── func_tools
│       │   │   ├── __init__.py
│       │   │   ├── filter.py
│       │   │   └── map.py
│       │   └── model
│       │       ├── __init__.py
│       │       ├── anyscale.py
│       │       ├── chat.py
│       │       ├── claude.py
│       │       ├── completion.py
│       │       ├── exception.py
│       │       ├── general_model.py
│       │       ├── mistral.py
│       │       ├── model.py
│       │       ├── openai.py
│       │       ├── outlines_completion.py
│       │       ├── utils.py
│       │       ├── vllm_chat.py
│       │       ├── vllm_completion.py
│       │       └── vllm_utils.py
│       └── types.py
└── training_fnames.json
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Run result artifacts
results/**

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*.egg-info/
.vscode/
*.jsonl
records/
sig_folder/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2025 Barrett Martin Lattimer

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Sparse Rewards Can Self-Train Dialogue Agents
Barrett Martin Lattimer, Varun Gangal, Ryan McDonald, Yi Yang

contact: blattimer@asapp.com

paper: https://arxiv.org/abs/2409.04617

This repo runs JOSH as well as the ToolWOZ and τ-bench benchmarks. It also contains tooling for logging training and preference-annotated episodes from user-simulator interactions, and for LoRA-driven preference tuning of small LLMs on such preference-annotated experience.


## Setup
1. Run the following in a new env
```
pip install josh-train
```
or
```
pip install -e .
```
2. Unzip the ```dataset.zip``` file in the ```data``` folder

3. Set up your OpenAI credentials
```
export OPENAI_API_KEY= # api_key
export OPENAI_ORGANIZATION= # api_org
```
If you're running Llama or another local model, you will also need to set HF_TOKEN in much the same way. Wherever you see HF_KEY, please replace it with your Hugging Face token.
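For example, in a bash shell (an illustrative snippet; substitute your own token):
```
export HF_TOKEN= # huggingface token
```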

## Running ToolWOZ

You can run ToolWOZ normally by doing the following
```
python josh_train/main.py
```
Increase ```--max_concurrency``` depending on your API rate limits
### JOSH on ToolWOZ
Enable JOSH on ToolWOZ by adding the ```--josh``` flag, and make JOSH print progress updates by also adding ```--josh_debug```

One example of a more involved JOSH command would be the following
```
python josh_train/main.py --josh --josh_debug --max_concurrency 20 --seed 20 --task_split train --temperature 1.0 --agent_strategy react --user_mode goal --model gpt-4o-mini --end_index 10 --beam_size 8
```

## Running τ-bench

We have added a clone of [τ-bench](https://github.com/sierra-research/tau-bench) to this repo with two run files, one for normal τ-bench testing and another for JOSH rollouts on τ-bench.

To run τ-bench normally you can do
```
python tau-bench-eval/run.py
```

### JOSH on τ-bench
To run JOSH on τ-bench you can do
```
python tau-bench-eval/run.py --josh --debug
```

## Using JOSH
The JOSH class provided in this repo is designed to be flexible and to work for a wide variety of user/agent interactions. To use JOSH yourself, you can start with the following code snippet
```
from josh_train.josh import JOSH, BaseJOSHAgent, BaseRewards, BaseJOSHUser
def add_error_message(agent):
    agent.messages.append({'role':'assistant', 'content':'Error: Agent ran out of retries.'})
    return agent

def step_agent(agent:BaseJOSHAgent, **kwargs):
    pass_to_customer = agent.step(**kwargs)
    return agent, pass_to_customer

def step_user(user:BaseJOSHUser, agent:BaseJOSHAgent):
    agent, end_conversation = user.step(agent)
    return agent, end_conversation

josh = JOSH(
    rewards=BaseRewards(['say hello', 'say hello', 'say hello']),
    agent_step=step_agent,
    user_step=step_user,
    add_error_message=add_error_message,
    root_agent = BaseJOSHAgent(),
    user = BaseJOSHUser(),
    debug=True
)

for _ in range(10):
    max_reward, all_done = josh.step()
    if all_done:
        break

print(max_reward)
print(josh.training_examples)
```

All classes can be built on top of and extended for further use.


## MT-Bench

(If you want to evaluate on MT-Bench later)
```
unzip mtbencheval.zip
```

## Citation
Please cite if you enjoyed this work!
106 | ``` 107 | @article{lattimer2024sparse, 108 | title={Sparse Rewards Can Self-Train Dialogue Agents}, 109 | author={Lattimer, Barrett Martin and Gangal, Varun and McDonald, Ryan and Yang, Yi}, 110 | journal={arXiv preprint arXiv:2409.04617}, 111 | year={2024} 112 | } 113 | ``` 114 | -------------------------------------------------------------------------------- /data/api_examples.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "name": "book_restaurant", 4 | "parameters": { 5 | "time": "13:00", 6 | "day": "thursday", 7 | "people": "3", 8 | "name": "the river bar steakhouse and grill" 9 | } 10 | }, 11 | { 12 | "name": "search_restaurant", 13 | "parameters": { 14 | "food": "modern european", 15 | "pricerange": "cheap", 16 | "name": "jinling noodle bar", 17 | "area": "centre" 18 | } 19 | }, 20 | { 21 | "name": "book_hotel", 22 | "parameters": { 23 | "stay": "4", 24 | "day": "friday", 25 | "people": "2", 26 | "name": "acorn guest house" 27 | } 28 | }, 29 | { 30 | "name": "search_hotel", 31 | "parameters": { 32 | "name": "hamilton lodge", 33 | "area": "north", 34 | "parking": "yes", 35 | "pricerange": "moderate", 36 | "stars": "4", 37 | "internet": "yes", 38 | "type": "guesthouse" 39 | } 40 | }, 41 | { 42 | "name": "book_attraction", 43 | "parameters": {} 44 | }, 45 | { 46 | "name": "search_attraction", 47 | "parameters": { 48 | "type": "boat", 49 | "name": "sheep's green and lammas land park fen causeway", 50 | "area": "centre" 51 | } 52 | }, 53 | { 54 | "name": "book_train", 55 | "parameters": { 56 | "people": "3", 57 | "trainID": "TR2048" 58 | } 59 | }, 60 | { 61 | "name": "search_train", 62 | "parameters": { 63 | "leaveAt": "08:45", 64 | "destination": "cambridge", 65 | "day": "tuesday", 66 | "arriveBy": "12:30", 67 | "departure": "london liverpool street" 68 | } 69 | } 70 | ] -------------------------------------------------------------------------------- /data/dataset.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/josh-llm-simulation-training/9c3a6076751ca3ac07e20a01fe5da780e42d23b0/data/dataset.zip -------------------------------------------------------------------------------- /data/llama3-josh-toolwoz-for-KTO.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/josh-llm-simulation-training/9c3a6076751ca3ac07e20a01fe5da780e42d23b0/data/llama3-josh-toolwoz-for-KTO.zip -------------------------------------------------------------------------------- /data/valid_api_defs.json: -------------------------------------------------------------------------------- 1 | { 2 | "restaurant": { 3 | "book_restaurant": { 4 | "parameters": [ 5 | "time", 6 | "day", 7 | "people", 8 | "name" 9 | ], 10 | "returns": [ 11 | "name", 12 | "reference" 13 | ] 14 | }, 15 | "search_restaurant": { 16 | "parameters": [ 17 | "food", 18 | "pricerange", 19 | "name", 20 | "area" 21 | ], 22 | "returns": [ 23 | "id", 24 | "address", 25 | "area", 26 | "food", 27 | "introduction", 28 | "name", 29 | "phone", 30 | "postcode", 31 | "pricerange", 32 | "signature", 33 | "type" 34 | ] 35 | } 36 | }, 37 | "hotel": { 38 | "book_hotel": { 39 | "parameters": [ 40 | "stay", 41 | "day", 42 | "people", 43 | "name" 44 | ], 45 | "returns": [ 46 | "name", 47 | "reference" 48 | ] 49 | }, 50 | "search_hotel": { 51 | "parameters": [ 52 | "name", 53 | "area", 54 | "parking", 55 | "pricerange", 56 | "stars", 57 | "internet", 
58 | "type" 59 | ], 60 | "returns": [ 61 | "id", 62 | "address", 63 | "area", 64 | "internet", 65 | "parking", 66 | "single", 67 | "double", 68 | "family", 69 | "name", 70 | "phone", 71 | "postcode", 72 | "pricerange", 73 | "takesbookings", 74 | "stars", 75 | "type" 76 | ] 77 | } 78 | }, 79 | "attraction": { 80 | "book_attraction": { 81 | "parameters": [], 82 | "returns": [] 83 | }, 84 | "search_attraction": { 85 | "parameters": [ 86 | "type", 87 | "name", 88 | "area" 89 | ], 90 | "returns": [ 91 | "id", 92 | "address", 93 | "area", 94 | "entrance", 95 | "name", 96 | "phone", 97 | "postcode", 98 | "pricerange", 99 | "openhours", 100 | "type" 101 | ] 102 | } 103 | }, 104 | "train": { 105 | "book_train": { 106 | "parameters": [ 107 | "people", 108 | "trainID" 109 | ], 110 | "returns": [ 111 | "trainID", 112 | "reference" 113 | ] 114 | }, 115 | "search_train": { 116 | "parameters": [ 117 | "leaveAt", 118 | "destination", 119 | "day", 120 | "arriveBy", 121 | "departure" 122 | ], 123 | "returns": [ 124 | "trainID", 125 | "arriveBy", 126 | "day", 127 | "departure", 128 | "destination", 129 | "duration", 130 | "leaveAt", 131 | "price" 132 | ] 133 | } 134 | }, 135 | "taxi": { 136 | "book_taxi": { 137 | "parameters": [], 138 | "returns": [ 139 | "phone", 140 | "type" 141 | ] 142 | }, 143 | "search_taxi": { 144 | "parameters": [ 145 | "leaveAt", 146 | "destination", 147 | "departure", 148 | "arriveBy" 149 | ], 150 | "returns": [] 151 | } 152 | }, 153 | "hospital": { 154 | "book_hospital": { 155 | "parameters": [], 156 | "returns": [ 157 | "reference", 158 | "time", 159 | "department" 160 | ] 161 | }, 162 | "search_hospital": { 163 | "parameters": [ 164 | "department" 165 | ], 166 | "returns": [] 167 | } 168 | } 169 | } -------------------------------------------------------------------------------- /db/attraction-dbase.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/josh-llm-simulation-training/9c3a6076751ca3ac07e20a01fe5da780e42d23b0/db/attraction-dbase.db -------------------------------------------------------------------------------- /db/bus-dbase.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/josh-llm-simulation-training/9c3a6076751ca3ac07e20a01fe5da780e42d23b0/db/bus-dbase.db -------------------------------------------------------------------------------- /db/hospital-dbase.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/josh-llm-simulation-training/9c3a6076751ca3ac07e20a01fe5da780e42d23b0/db/hospital-dbase.db -------------------------------------------------------------------------------- /db/hotel-dbase.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/josh-llm-simulation-training/9c3a6076751ca3ac07e20a01fe5da780e42d23b0/db/hotel-dbase.db -------------------------------------------------------------------------------- /db/police_db.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "name": "Parkside Police Station", 4 | "address": "Parkside, Cambridge", 5 | "id": 0, 6 | "phone": "01223358966" 7 | } 8 | ] 9 | -------------------------------------------------------------------------------- /db/restaurant-dbase.db: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/asappresearch/josh-llm-simulation-training/9c3a6076751ca3ac07e20a01fe5da780e42d23b0/db/restaurant-dbase.db -------------------------------------------------------------------------------- /db/taxi-dbase.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/josh-llm-simulation-training/9c3a6076751ca3ac07e20a01fe5da780e42d23b0/db/taxi-dbase.db -------------------------------------------------------------------------------- /db/taxi_db.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "taxi_colors" : ["black","white","red","yellow","blue","grey"], 4 | "taxi_types": ["toyota","skoda","bmw","honda","ford","audi","lexus","volvo","volkswagen","tesla"], 5 | "taxi_phone": ["^[0-9]{10}$"] 6 | } 7 | ] -------------------------------------------------------------------------------- /db/train-dbase.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/josh-llm-simulation-training/9c3a6076751ca3ac07e20a01fe5da780e42d23b0/db/train-dbase.db -------------------------------------------------------------------------------- /final_file_push.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import yaml 4 | import os 5 | 6 | path = os.getcwd() 7 | 8 | with open(f'{path}/prompts/prompts.yaml', 'r') as file: 9 | prompts = yaml.safe_load(file) 10 | with open(f'{path}/data/tools.json', 'r') as file: 11 | tools_list = json.load(file) 12 | MONO_PROMPT = prompts['mistral_mono_prompt_v2'].replace('{example_filled}', json.dumps(tools_list, indent=2)) 13 | 14 | folder = f'{path}/records/mini_training_data_full/' 15 | file_list = [folder+x for x in os.listdir(folder) if '.json' in x] 16 | file_list = sorted(file_list, key=lambda x: x.split('.json')[0].split('_')[-1]) 17 | files = [] 18 | for fname in file_list: 19 | with open(fname, 'r') as file: 20 | tmp_data = json.load(file) 21 | if tmp_data['100_percent_reached']: 22 | longest = [] 23 | for path, _, is_gold in tmp_data['training_paths']: 24 | if is_gold and len(path) > len(longest): 25 | longest = path 26 | files.append(longest) 27 | print(len(files)) 28 | def testing(utterance): 29 | if utterance['role']=='assistant': 30 | if 'APIRETURN' in utterance['content']: 31 | return False 32 | if 'ERROR:' in utterance['content']: 33 | return False 34 | if 'Error:' in utterance['content']: 35 | return False 36 | if 'FAILURE' in utterance['content']: 37 | return False 38 | return True 39 | tots = [] 40 | for f in files: 41 | if all([testing(q) for q in f]): 42 | tots.append(f) 43 | 44 | files = tots 45 | 46 | kto_data_dic_full_convo_only = {'data':[]} 47 | for f in files: 48 | if not f or len(f) ==0: 49 | continue 50 | kto_data_dic_full_convo_only['data'].append({'messages':[{'role':'system', 'content':MONO_PROMPT}] + f}) 51 | print(len(kto_data_dic_full_convo_only['data'])) 52 | with open('training_data_mini_0807.json', 'w') as file: 53 | json.dump(kto_data_dic_full_convo_only, file, indent=2) 54 | 55 | with open('train_data_mini_0807.jsonl', 'w') as file: 56 | for idx, item in enumerate(kto_data_dic_full_convo_only['data']): 57 | json_line = json.dumps(item) 58 | additional = '' 59 | if idx < len(kto_data_dic_full_convo_only['data'])-1: 60 | additional = '\n' 61 | file.write(json_line + additional) 62 | 63 | 
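A minimal sketch of how the JSONL written above can be sanity-checked (assuming final_file_push.py has already run in the current directory; the filename and the one-object-per-line {"messages": [...]} schema come from the script itself):
```
import json

# Each line of the SFT file is one conversation: a {"messages": [...]} dict
# whose first message is the system prompt prepended by final_file_push.py.
with open('train_data_mini_0807.jsonl', 'r') as file:
    for line in file:
        record = json.loads(line)
        roles = [message['role'] for message in record['messages']]
        assert roles[0] == 'system'
        print(len(roles), 'messages, first roles:', roles[:3])
```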
-------------------------------------------------------------------------------- /final_file_push_kto.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pandas as pd 4 | import yaml 5 | import numpy as np 6 | import os 7 | import copy 8 | 9 | path = os.getcwd() 10 | 11 | with open(f'{path}/prompts/prompts.yaml', 'r') as file: 12 | prompts = yaml.safe_load(file) 13 | with open(f'{path}/data/tools.json', 'r') as file: 14 | tools_list = json.load(file) 15 | MONO_PROMPT = prompts['mistral_mono_prompt_v2'].replace('{example_filled}', json.dumps(tools_list, indent=2)) 16 | 17 | 18 | 19 | folder = f'{path}/records/training_data_v2/' 20 | file_list = [folder+x for x in os.listdir(folder) if '.json' in x] 21 | file_list = sorted(file_list, key=lambda x: x.split('.json')[0].split('_')[-1]) 22 | files = [] 23 | bad_ex = [] 24 | for fname in file_list: 25 | with open(fname, 'r') as file: 26 | tmp_data = json.load(file) 27 | if tmp_data['100_percent_reached']: 28 | new_bads = [] 29 | longest = [] 30 | for path, successful, is_gold in tmp_data['training_paths']: 31 | if is_gold and len(path) > len(longest): 32 | longest = path 33 | if not successful and not is_gold: 34 | new_bads.append(path) 35 | files.append(longest) 36 | bad_ex.append(new_bads) 37 | 38 | def testing(utterance): 39 | if utterance['role']=='assistant': 40 | if 'APIRETURN' in utterance['content']: 41 | return False 42 | if 'ERROR:' in utterance['content']: 43 | return False 44 | if 'Error:' in utterance['content']: 45 | return False 46 | if 'FAILURE' in utterance['content']: 47 | return False 48 | return True 49 | tots = [] 50 | bads = [] 51 | for j, f in enumerate(files): 52 | if all([testing(q) for q in f]): 53 | tots.append(f) 54 | for ex in bad_ex[j]: 55 | bads.append(ex) 56 | 57 | files = tots 58 | 59 | 60 | 61 | kto_data_dic_full_convo_only = {'data':[]} 62 | unique_prompts = [] 63 | # add good 64 | for jj, training_example in enumerate(files): 65 | if len(training_example)==0 or training_example in unique_prompts: 66 | continue 67 | unique_prompts.append(training_example) 68 | cleaned_training_example = [{'role':'system', 'content':MONO_PROMPT}] 69 | for idx, utterance in enumerate(training_example): 70 | if utterance['role']=='assistant': 71 | kto_data_dic_full_convo_only['data'].append({'prompt':copy.deepcopy(cleaned_training_example), 'completion':[copy.deepcopy(utterance)], 'label':True}) 72 | cleaned_training_example.append(utterance) 73 | else: 74 | cleaned_training_example.append(utterance) 75 | 76 | # add bad 77 | bad_data = [] 78 | for jj, training_example in enumerate(bads): 79 | if len(training_example)==0 or training_example in unique_prompts: 80 | continue 81 | unique_prompts.append(training_example) 82 | cleaned_training_example = [{'role':'system', 'content':MONO_PROMPT}] 83 | for idx, dic in enumerate(reversed(training_example)): 84 | if dic.get('role')=='system': 85 | continue 86 | if dic.get('role')=='user' and 'APIRETURN' not in dic.get('content'): 87 | break 88 | if not dic.get('content').startswith('APIRETURN'): 89 | bad_data.append({'prompt':copy.deepcopy([{'role':'system', 'content':MONO_PROMPT}]+training_example[:len(training_example)-idx-1]), 'completion':[copy.deepcopy(dic)], 'label':False}) 90 | 91 | print(len(kto_data_dic_full_convo_only['data'])) 92 | shared_bads = [] 93 | 94 | for good_example in kto_data_dic_full_convo_only['data']: 95 | c=[] 96 | good_prompt = good_example['prompt'] 97 | for bad_example in bad_data: 98 | if 
bad_example['prompt'] == good_prompt and bad_example not in shared_bads:
            shared_bads.append(bad_example)

print(len(shared_bads))
kto_data_dic_full_convo_only['data'] = kto_data_dic_full_convo_only['data']+shared_bads
# kto_data_dic_full_convo_only = {'data':[]}
# for f in files:
#     if not f or len(f) ==0:
#         continue
#     kto_data_dic_full_convo_only['data'].append({'messages':[{'role':'system', 'content':MONO_PROMPT}] + f})

with open('train_data_kto_just_div.jsonl', 'w') as file:
    for idx, item in enumerate(kto_data_dic_full_convo_only['data']):
        json_line = json.dumps(item)
        additional = ''
        if idx < len(kto_data_dic_full_convo_only['data'])-1:
            additional = '\n'
        file.write(json_line + additional)

--------------------------------------------------------------------------------
/final_file_push_taubench.py:
--------------------------------------------------------------------------------
import json
import os
import pandas as pd
import yaml
import numpy as np
import copy

path = os.getcwd()

toolwoz=True

folder = '/Users/blattimer/code/josh-llm-simulation-training/records/beam-size-16-gpt-4o-mini-react-toolowz.json'#f'records/josh_function_calling_toolwoz_gpt-4o.json'
with open(folder, 'r') as file:
    tmp_data = json.load(file)

def filter_errors(convo):
    updated_convo = []
    for idx, msg in enumerate(convo['messages']):
        error_next = False
        # Compare against the message list, not the enclosing dict
        if idx+1 < len(convo['messages']):
            error_next = ('content' in convo['messages'][idx+1] and 'Error' in convo['messages'][idx+1]['content']) or ('content' in msg and 'Error' in msg['content'])
        if not error_next:
            updated_convo.append(msg)
    convo['messages'] = updated_convo
    return convo

def no_error(convo):
    for msg in convo['messages']:
        if 'content' in msg and 'Error' in msg['content']:
            return False
        if 'content' in msg and msg['content'] == 'API output: ':
            return False
    return True

def no_error_toolwoz(convo):
    for utterance in convo['messages']:
        if 'content' not in utterance:
            continue
        if utterance['role']=='assistant':
            if 'APIRETURN' in utterance['content']:
                return False
            if 'ERROR:' in utterance['content']:
                return False
            if 'Error:' in utterance['content']:
                return False
            if 'FAILURE' in utterance['content']:
                return False
    return True

files = []
bad_ex = []
for convo in tmp_data:
    if convo['reward']==1:
        new_bads = []
        longest = []
        convo_to_it = convo['info']['training_examples'] if not toolwoz else convo['training_examples']
        for path, successful, is_gold in convo_to_it:
            # Compare message counts; guard against the initial empty `longest`
            if is_gold and len(path['messages']) > (len(longest['messages']) if longest else 0):
                longest = path
        if toolwoz:
            if no_error_toolwoz(longest):
                files.append(longest)
        else:
            if no_error(longest):
                files.append(longest)
print(len(files))

import random
random.seed(42)
random.shuffle(files)

with open('train_data_ft-gpt-4o-mini-react-beam-16-toolwoz.jsonl', 'w') as file:
    for idx, item in enumerate(files):
        json_line = json.dumps(item)
        additional = ''
        if idx < len(files)-1:
            additional = '\n'
        file.write(json_line + additional)

--------------------------------------------------------------------------------
/final_file_push_taubench_kto.py:
-------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pandas as pd 4 | import yaml 5 | import numpy as np 6 | import os 7 | import copy 8 | 9 | path = os.getcwd() 10 | 11 | 12 | 13 | folder = f'/root/code/multiwoz-api/results/react1-llama-1.0_range_0--1_usergpt-4o_0819233641.json' 14 | with open(folder, 'r') as file: 15 | tmp_data = json.load(file) 16 | 17 | files = [] 18 | bad_ex = [] 19 | for convo in tmp_data: 20 | if convo['reward']==1: 21 | new_bads = [] 22 | longest = [] 23 | for path, successful, is_gold in convo['info']['training_examples']: 24 | if is_gold and len(path) > len(longest): 25 | longest = path 26 | if not successful and not is_gold: 27 | new_bads.append(path) 28 | files.append(longest) 29 | bad_ex.append(new_bads) 30 | print(len(files)) 31 | # def testing(utterance): 32 | # if utterance['role']=='assistant': 33 | # if 'API output:' in utterance['content']: 34 | # return False 35 | # if 'ERROR:' in utterance['content']: 36 | # return False 37 | # if 'Error:' in utterance['content']: 38 | # return False 39 | # if 'FAILURE' in utterance['content']: 40 | # return False 41 | # return True 42 | # tots = [] 43 | bads = [] 44 | for j, f in enumerate(files): 45 | for ex in bad_ex[j]: 46 | bads.append(ex) 47 | # print(files[0]) 48 | # files = tots 49 | # print(len(files)) 50 | 51 | kto_data_dic_full_convo_only = {'data':[]} 52 | unique_prompts = [] 53 | # add good 54 | for jj, training_example in enumerate(files): 55 | if len(training_example)==0 or training_example in unique_prompts: 56 | continue 57 | unique_prompts.append(training_example) 58 | cleaned_training_example = [] 59 | for idx, utterance in enumerate(training_example): 60 | if utterance['role']=='assistant': 61 | kto_data_dic_full_convo_only['data'].append({'prompt':copy.deepcopy(cleaned_training_example), 'completion':[copy.deepcopy(utterance)], 'label':True}) 62 | cleaned_training_example.append(utterance) 63 | else: 64 | cleaned_training_example.append(utterance) 65 | 66 | # add bad 67 | bad_data = [] 68 | for jj, training_example in enumerate(bads): 69 | if len(training_example)==0 or training_example in unique_prompts: 70 | continue 71 | if training_example[-1].get('role')=='user': 72 | training_example = training_example[:-1] 73 | unique_prompts.append(training_example) 74 | for idx, dic in enumerate(reversed(training_example)): 75 | if dic.get('role')=='system': 76 | continue 77 | if dic.get('role')=='user' and 'API output:' not in dic.get('content'): 78 | break 79 | if not dic.get('content').startswith('API output:'): 80 | bad_data.append({'prompt':copy.deepcopy(training_example[:len(training_example)-idx-1]), 'completion':[copy.deepcopy(dic)], 'label':False}) 81 | 82 | print(len(kto_data_dic_full_convo_only['data'])) 83 | shared_bads = [] 84 | 85 | for good_example in kto_data_dic_full_convo_only['data']: 86 | c=[] 87 | good_prompt = good_example['prompt'] 88 | for bad_example in bad_data: 89 | if bad_example['prompt'] == good_prompt and bad_example not in shared_bads: 90 | shared_bads.append(bad_example) 91 | 92 | print(len(shared_bads)) 93 | kto_data_dic_full_convo_only['data'] = kto_data_dic_full_convo_only['data']+shared_bads 94 | # kto_data_dic_full_convo_only = {'data':[]} 95 | # for f in files: 96 | # if not f or len(f) ==0: 97 | # continue 98 | # kto_data_dic_full_convo_only['data'].append({'messages':[{'role':'system', 'content':MONO_PROMPT}] + f}) 99 | 100 | with open('train_data_kto_retail.jsonl', 'w') as file: 101 | for idx, 
item in enumerate(kto_data_dic_full_convo_only['data']):
        json_line = json.dumps(item)
        additional = ''
        if idx < len(kto_data_dic_full_convo_only['data'])-1:
            additional = '\n'
        file.write(json_line + additional)

--------------------------------------------------------------------------------
/instructionsToRunMTBench.md:
--------------------------------------------------------------------------------


```
cd mtbencheval/FastChat
pip install -e ".[model_worker,llm_judge]"
```
(To minimize env disruptions, you can first open up the pyproject.toml and see everything that is going to be installed. Install the dependencies yourself one by one if you can, so that you don't suddenly pull in a long chain of dependencies. This will ensure that the final pip install -e . only installs 1-2 new packages apart from building the repo itself, which will make your job easy. Nevertheless, keep a close eye on the last command too.)

```
python download_mt_bench_pregenerated.py
```


```
python gen_model_answer.py --model-path lmsys/vicuna-7b-v1.5 --model-id vicuna-7b-v1.5
```

Ensure OPENAI_API_KEY is set.

Now run judgement on the answers you generated above (it may ask you to confirm with yes after you run it).

```
python gen_judgment.py --model-list vicuna-7b-v1.5 --parallel 2
```

You can see the results via
```
python show_result.py --model-list vicuna-7b-v1.5
```
You get 2 turn-level scores and 1 overall aggregate. For vicuna it's around 6.11

FAQs and declutter section:


Why do we need the most recent repo version?
Llama3-related fixes were added only recently.
e.g. see https://github.com/lm-sys/FastChat/pull/3326


What has been run and checked so far?
Huggingface lmsys vicuna
Llama3 [will check next]
llama3 with peft adapter [will check after]
--------------------------------------------------------------------------------
/josh_train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asappresearch/josh-llm-simulation-training/9c3a6076751ca3ac07e20a01fe5da780e42d23b0/josh_train/__init__.py
--------------------------------------------------------------------------------
/josh_train/_version.py:
--------------------------------------------------------------------------------
# file generated by setuptools-scm
# don't change, don't track in version control

__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]

TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple
    from typing import Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
12 | else: 13 | VERSION_TUPLE = object 14 | 15 | version: str 16 | __version__: str 17 | __version_tuple__: VERSION_TUPLE 18 | version_tuple: VERSION_TUPLE 19 | 20 | __version__ = version = '0.1.dev34+gf3bfe13' 21 | __version_tuple__ = version_tuple = (0, 1, 'dev34', 'gf3bfe13') 22 | -------------------------------------------------------------------------------- /josh_train/agents/fc_agent.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import yaml 3 | import copy 4 | import json 5 | from josh_train.utils import request_openai, handle_api_calls 6 | import os 7 | import josh_train.config as config 8 | from josh_train.josh import BaseJOSHAgent 9 | 10 | class FCAgentSimulator(BaseJOSHAgent): 11 | def __init__(self, api_examples, api_defs, model_name:Optional[str]=None, temperature = 0.0, debug=False): 12 | super().__init__() 13 | cwd = os.getcwd() 14 | with open(f'{cwd}/prompts/prompts.yaml', 'r') as file: 15 | prompts = yaml.safe_load(file) 16 | self.api_defs = api_defs 17 | self.api_examples = api_examples 18 | self.apis_to_examples = {x['name']: x for x in api_examples} 19 | 20 | if os.path.isfile('data/tools.json'): 21 | with open('data/tools.json', 'r') as file: 22 | tools_list = json.load(file) 23 | else: 24 | tools_list = self._create_oai_function_list(api_examples) 25 | with open('data/tools_updated.json', 'w') as file: 26 | json.dump(tools_list, file, indent=2) 27 | self.tool_list = tools_list 28 | self.MONO_PROMPT = "You are a travel agent. Help the customer with the provided apis. Do not say you can do something you cannot. You can only do things with the provided apis." 29 | self.modelname = model_name 30 | self.debug = debug 31 | self.temperature = temperature 32 | 33 | def _create_oai_function_list(self, api_examples): 34 | official_list = [] 35 | for item in api_examples: 36 | action, domain = item['name'].split('_') 37 | api_def = {'type':'function', 'function':{'name':item['name'], 'description': f"Allows you to {action} a {domain}", 'parameters':{'type':'object', 'required':[], 'properties':{}}}} 38 | for param in item['parameters'].keys(): 39 | api_def['function']['parameters']['properties'][param] = {'type': 'string', 'description':f'One example would be {item["parameters"][param]}'} 40 | official_list.append(api_def) 41 | return official_list 42 | 43 | def request(self, messages): 44 | output = request_openai(messages, self.modelname, config.client, tools=self.tool_list, temperature=self.temperature) 45 | return output 46 | 47 | 48 | def _step(self, turn, conversation_state): 49 | msg_return = [] 50 | if turn.tool_calls is not None: 51 | # API Call 52 | tool_call = turn.tool_calls[0] 53 | msg_return.append({ 54 | 'role':'assistant', 55 | 'tool_calls' : [ { 56 | 'id' : tool_call.id, 57 | 'type' : 'function', 58 | 'function' : { 59 | 'name' : tool_call.function.name, 60 | 'arguments' : tool_call.function.arguments 61 | } 62 | } ] 63 | } 64 | ) 65 | 66 | api_call = { 67 | "name": tool_call.function.name, 68 | "arguments": json.loads(tool_call.function.arguments), 69 | } 70 | if self.debug: 71 | print(json.dumps(api_call)) 72 | api_returns = handle_api_calls(api_call['name'], api_call['arguments'], conversation_state=conversation_state) 73 | if type(api_returns)==list: 74 | self.recent_actions.append({'name':api_call['name'], 'parameters': api_call['arguments'], 'returned': api_returns[0] if len(api_returns)>0 else api_returns}) 75 | else: 76 | self.recent_actions.append({'name':api_call['name'], 
'parameters': api_call['arguments'], 'returned': api_returns}) 77 | if self.debug: 78 | print(json.dumps(api_returns)) 79 | # Add the return 80 | msg_return.append( 81 | { 82 | "role": "tool", 83 | "tool_call_id": tool_call.id, 84 | "content": json.dumps(api_returns), 85 | } 86 | ) 87 | 88 | return msg_return 89 | else: 90 | # Speak 91 | if self.debug: 92 | print(turn.content) 93 | msg_return.append({'role':'assistant', 'content':turn.content}) 94 | return msg_return 95 | 96 | def step(self, **kwargs): 97 | conversation_state = kwargs['env'] 98 | 99 | self.recent_actions = [] 100 | count=0 101 | while count < 3: 102 | agent_messages = [{'role':'system', 'content':self.MONO_PROMPT}]+self.messages_internal 103 | turn = self.request(agent_messages) 104 | message_return = self._step(turn, conversation_state) 105 | self.messages_internal.extend(message_return) 106 | if all([x['role']!='tool' for x in message_return]): 107 | self.messages.extend(message_return) 108 | return 109 | count+=1 110 | self.messages_internal.append({'role':'assistant', 'content':'Error: Agent ran out of retries.'}) 111 | self.messages.append({'role':'assistant', 'content':'Error: Agent ran out of retries.'}) 112 | return 113 | -------------------------------------------------------------------------------- /josh_train/agents/react_agent.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | import yaml 4 | import re 5 | import copy 6 | import json 7 | from transformers import pipeline 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | import torch 10 | from josh_train.utils import make_transcript, request_openai, parse_api_call, handle_api_calls 11 | import os 12 | import josh_train.config as config 13 | from josh_train.josh import BaseJOSHAgent 14 | 15 | 16 | class ReACTAgentSimulator(BaseJOSHAgent): 17 | def __init__(self, api_examples, api_defs, model_name:Optional[str]=None, temperature=0.0, debug = False): 18 | super().__init__() 19 | cwd = os.getcwd() 20 | with open(f'{cwd}/prompts/prompts.yaml', 'r') as file: 21 | prompts = yaml.safe_load(file) 22 | self.api_defs = api_defs 23 | self.api_examples = api_examples 24 | self.apis_to_examples = {x['name']: x for x in api_examples} 25 | with open(f'{cwd}/data/tools.json', 'r') as file: 26 | tools_list = json.load(file) 27 | self.MONO_PROMPT = prompts['react_prompt'].replace('{example_filled}', json.dumps(tools_list, indent=2)) 28 | self.pattern = "(PLAN|APICALL|SPEAK)(.*?)(?=PLAN|APICALL|SPEAK|$)" 29 | self.model_name=model_name 30 | self.debug = debug 31 | self.temperature = temperature 32 | 33 | def parse_agent_message(self, output): 34 | commands = re.findall(self.pattern , output , re.DOTALL) 35 | return commands 36 | 37 | 38 | def request(self, messages, model=None, tokenizer=None) -> str: 39 | if model and tokenizer: 40 | encoding = tokenizer.apply_chat_template(messages, return_tensors="pt").to('cuda') 41 | prompt_len=len(encoding[0]) 42 | with torch.no_grad(): 43 | if math.isclose(self.temperature, 0.0, rel_tol=1e-6): 44 | generated_ids = model.generate(encoding, max_new_tokens=256, do_sample=False) 45 | else: 46 | generated_ids = model.generate(encoding, max_new_tokens=256, temperature=self.temperature, top_k=50, top_p=0.95) 47 | return tokenizer.batch_decode(generated_ids[0][prompt_len:].unsqueeze(0), skip_special_tokens=True)[0] 48 | else: 49 | output = request_openai(messages, self.model_name, config.client, temperature=self.temperature) 50 | return output 51 
| 52 | def handle_api(self, command, conversation_state): 53 | try: 54 | api_values = parse_api_call(command) 55 | except: 56 | return 'FAILURE INCORRECTLY FORMATTED APICALL', None 57 | if api_values['api_name'] not in self.apis_to_examples: 58 | return 'FAILURE INCORRECTLY FORMATTED APICALL', None 59 | returns = handle_api_calls(api_values['api_name'], api_values['api_args'], conversation_state=conversation_state) 60 | if type(returns)==list: 61 | called_api = {'name':api_values['api_name'], 'parameters': api_values['api_args'], 'returned': returns[0] if len(returns)>0 else returns} 62 | else: 63 | called_api = {'name':api_values['api_name'], 'parameters': api_values['api_args'], 'returned': returns} 64 | return returns, called_api 65 | 66 | def step(self, **kwargs): 67 | conversation_state = kwargs['env'] 68 | model = kwargs['model'] 69 | tokenizer = kwargs['tokenizer'] 70 | 71 | self.recent_actions = [] 72 | count=0 73 | while count < 3: 74 | agent_messages = [{'role':'system', 'content':self.MONO_PROMPT}]+self.messages_internal 75 | turn = self.request(agent_messages, model, tokenizer) 76 | if self.debug: 77 | print(turn) 78 | parsed = self.parse_agent_message(turn.replace('', '').strip().replace('\n','').replace('\\','')) 79 | if len(parsed)==0: 80 | self.messages_internal.append({'role':'assistant', 'content':'ERROR: NO COMMAND FOUND'}) 81 | thought_string = '' 82 | for command_type, command in parsed: 83 | command_type = command_type.strip() 84 | command=command.strip() 85 | if command_type=='PLAN': 86 | thought_string = 'PLAN '+command+' ' 87 | elif command_type == 'SPEAK': 88 | self.messages_internal.append({'role':'assistant', 'content':thought_string+'SPEAK '+command+' '}) 89 | self.messages.append({'role':'assistant', 'content':command}) 90 | return 91 | elif command_type == 'APICALL': 92 | command = command.strip().replace('\n','') 93 | output, called_api = self.handle_api(command, conversation_state) 94 | self.recent_actions.append(called_api) 95 | if self.debug: 96 | print(output) 97 | # Add the api call 98 | self.messages_internal.append({'role':'assistant', 'content':thought_string+'APICALL '+command+' '}) 99 | # Add the return 100 | self.messages_internal.append({'role':'user', 'content':'APIRETURN ' + json.dumps(output)}) 101 | else: 102 | self.messages_internal.append({'role':'assistant', 'content':'ERROR: INVALID COMMAND TYPE'}) 103 | count+=1 104 | self.messages.append({'role':'assistant', 'content':'Error: Agent ran out of retries.'}) 105 | return 106 | -------------------------------------------------------------------------------- /josh_train/build_apis.py: -------------------------------------------------------------------------------- 1 | import json 2 | # Load in MultiWOZ 2.2 data 3 | with open('data/multi-woz/data.json') as fin1: 4 | data = json.load(fin1) 5 | 6 | # Build the api arguments from real MultiWOZ examples 7 | apis = {} 8 | for k in data.keys(): 9 | for idx, tstt in enumerate(data[k]['log']): 10 | if idx % 2 == 0: 11 | continue 12 | tst = tstt['metadata'] 13 | # assert base.keys() == tst.keys() 14 | for domain in ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital']: 15 | if domain not in apis: 16 | apis[domain] = {} 17 | 18 | for key in tst[domain].keys(): 19 | api_action = 'book' if key=='book' else 'search' 20 | for values in tst[domain][key].keys(): 21 | if f'{api_action}_{domain}' not in apis[domain]: 22 | apis[domain][f'{api_action}_{domain}'] = {'parameters': [], 'returns': []} 23 | 24 | if api_action == 'book' and values == 
'booked': 25 | booked_list = tst[domain][key][values] 26 | for book in booked_list: 27 | [apis[domain][f'{api_action}_{domain}']['returns'].append(k) for k in book.keys() if k not in apis[domain][f'{api_action}_{domain}']['returns']] 28 | else: 29 | if values not in apis[domain][f'{api_action}_{domain}']['parameters']: 30 | apis[domain][f'{api_action}_{domain}']['parameters'].append(values) 31 | 32 | # Build the return types for search apis by pulling return information from the databases 33 | import sqlite3 34 | # These are the only domains that have databases, so they are the only ones we will use search on 35 | domains = ['restaurant', 'hotel', 'attraction', 'train'] 36 | dbs = {} 37 | for domain in domains: 38 | db = 'db/{}-dbase.db'.format(domain) 39 | conn = sqlite3.connect(db) 40 | c = conn.cursor() 41 | dbs[domain] = c 42 | 43 | # Query the domains to find the column names 44 | res = {} 45 | for domain in domains: 46 | c = dbs[domain] 47 | # Get all table names in the DB. 48 | c.execute("SELECT name FROM sqlite_master WHERE type='table';") 49 | tables = c.fetchall() 50 | 51 | # Iterate over each table. 52 | for table_name in tables: 53 | table_name = table_name[0] 54 | c.execute('PRAGMA TABLE_INFO({})'.format(table_name)) 55 | # Collect the column names. 56 | columns = [tup[1] for tup in c.fetchall()] 57 | # Save the result. 58 | if domain not in res: 59 | res[domain] = {} 60 | res[domain][table_name] = columns 61 | 62 | # Add returns to search apis 63 | for domain in res.keys(): 64 | apis[domain][f'search_{domain}']['returns'] = res[domain][domain] 65 | 66 | 67 | with open('apis.json', 'w') as file: 68 | json.dump(apis, file, indent=2) -------------------------------------------------------------------------------- /josh_train/config.py: -------------------------------------------------------------------------------- 1 | client=None -------------------------------------------------------------------------------- /josh_train/conversation_types/conversation_state.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from josh_train.utils import * 3 | 4 | class Conversation: 5 | def __init__(self, key, apis, delex): 6 | self.dbs = create_dbs() 7 | self.apis = apis[key] 8 | self.convo_key = key 9 | 10 | self.goals = delex[key]['goal'] 11 | self.ground_truth_conversaiton = [{'role': 'assistant' if utterance['metadata'] else 'user', 'content':utterance['text']} for utterance in delex[key]['log']] 12 | 13 | self.called_apis = [] 14 | 15 | # Filter the DBs because sometimes failure cases are present 16 | for domain in self.apis.keys(): 17 | if f'search_{domain}' in self.apis[domain]['failure']: 18 | filter_dbs(domain, self.apis[domain]['failure'][f'search_{domain}']['parameters'], self.dbs) 19 | 20 | def close_convos(self): 21 | keys = [k for k in self.dbs.keys()] 22 | for k in keys: 23 | del self.dbs[k] 24 | 25 | def add_api_call(self, api_name, api_args, returned): 26 | if 'search' == api_name.split('_')[0] and type(returned)==list: 27 | returned_value = returned[0] if len(returned)!=0 else None 28 | else: 29 | returned_value = returned 30 | self.called_apis.append({'name':api_name, 'parameters': api_args, 'returned': returned_value}) 31 | 32 | def _test_book_apis(self, goal_api, domain): 33 | # Loop through all the called apis 34 | for called_api in self.called_apis: 35 | if called_api['name'] == goal_api: 36 | successful_parameters = {k:v for k,v in self.apis[domain]['success'][goal_api]['parameters'].items() if not 
test_if_val_is_empty(v)} 37 | subset_test = is_subset(successful_parameters, called_api['parameters']) 38 | if subset_test: 39 | return 1 40 | return 0 41 | 42 | def _test_search_apis(self, goal_api, domain): 43 | correctly_searched = 0 44 | unique_id_type = 'trainID' if domain == 'train' else 'name' 45 | # Pull the correct item 46 | unique_id_of_correct = self.apis[domain]['success'][f'book_{domain}']['unique_id'] if f'book_{domain}' in self.apis[domain]['success'] else '' 47 | for called_api in self.called_apis: 48 | if not called_api['returned'] or type(called_api['returned'])==str: 49 | continue 50 | if called_api['name'] == goal_api: 51 | goal_parameters = {k:v for k,v in self.apis[domain]['success'][goal_api]['parameters'].items() if not test_if_val_is_empty(v)} 52 | # if the Goal (correct) parameters are a subset of what was called 53 | if is_subset(goal_parameters, called_api['parameters']): 54 | # if theres no "technically" correct answer, its correct 55 | if len(unique_id_of_correct)==0: 56 | correctly_searched = 1 57 | # if theres a "correct" answer, make sure that was returned 58 | else: 59 | # if the unique id is correct its correct 60 | if unique_id_of_correct == called_api['returned'][unique_id_type]: 61 | correctly_searched = 1 62 | # else it's wrong. 63 | # if the called api is a subset of the Goal (correct) parameters 64 | elif is_subset(called_api['parameters'], goal_parameters): 65 | # if theres no "technically" correct answer 66 | if len(unique_id_of_correct)==0: 67 | # if the goals are a subset of the returned object, its "right" 68 | if is_subset(goal_parameters, called_api['returned']): 69 | correctly_searched = 1 70 | # else it is not right 71 | # if there is a "correct" answer 72 | else: 73 | # if the unique id is correct its correct 74 | if unique_id_of_correct == called_api['returned'][unique_id_type]: 75 | correctly_searched = 1 76 | # else it's wrong. 
77 | # neither, so it is incorrect 78 | return correctly_searched 79 | 80 | 81 | def evaluate_apis(self): 82 | number_of_successful_apis = 0 83 | correct_calls = 0 84 | failed_api_calls = [] 85 | for domain in self.apis.keys(): 86 | for goal_api in self.apis[domain]['success'].keys(): 87 | #Looping through each goal api 88 | action, _ = goal_api.split('_') 89 | if action == 'book': 90 | successful_call = self._test_book_apis(goal_api, domain) 91 | elif action == 'search': 92 | successful_call = self._test_search_apis(goal_api, domain) 93 | else: 94 | successful_call = 0 95 | 96 | if successful_call==0: 97 | failed_api_calls.append({goal_api:self.apis[domain]['success'][goal_api]['parameters']}) 98 | 99 | correct_calls += successful_call 100 | number_of_successful_apis+=1 101 | 102 | return correct_calls/number_of_successful_apis, failed_api_calls -------------------------------------------------------------------------------- /josh_train/conversation_types/conversation_state_pref_tree.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import copy 3 | from josh_train.utils import * 4 | 5 | class Conversation: 6 | def __init__(self, key, apis, delex): 7 | self.dbs = create_dbs() 8 | self.apis = apis[key] 9 | self.apis_for_eval = copy.deepcopy(apis[key]) 10 | self.convo_key = key 11 | 12 | self.goals = delex[key]['goal'] 13 | self.ground_truth_conversaiton = [{'role': 'assistant' if utterance['metadata'] else 'user', 'content':utterance['text']} for utterance in delex[key]['log']] 14 | 15 | # Filter the DBs because sometimes failure cases are present 16 | for domain in self.apis.keys(): 17 | if f'search_{domain}' in self.apis[domain]['failure']: 18 | filter_dbs(domain, self.apis[domain]['failure'][f'search_{domain}']['parameters'], self.dbs) 19 | 20 | self.called_apis = [] 21 | 22 | def close_convos(self): 23 | keys = [k for k in self.dbs.keys()] 24 | for k in keys: 25 | del self.dbs[k] 26 | 27 | def add_api_call(self, api_name, api_args, returned): 28 | if 'search' == api_name.split('_')[0] and type(returned)==list: 29 | returned_value = returned[0] if len(returned)!=0 else None 30 | else: 31 | returned_value = returned 32 | self.called_apis.append({'name':api_name, 'parameters': api_args, 'returned': returned_value}) 33 | 34 | def _test_book_apis(self, goal_api, domain, called_apis): 35 | # Loop through all the called apis 36 | for called_api in called_apis: 37 | if called_api['name'] == goal_api: 38 | successful_parameters = {k:v for k,v in self.apis_for_eval[domain]['success'][goal_api]['parameters'].items() if not test_if_val_is_empty(v)} 39 | subset_test = is_subset(successful_parameters, called_api['parameters']) 40 | if subset_test: 41 | return 1 42 | return 0 43 | 44 | def _test_search_apis(self, goal_api, domain, called_apis): 45 | correctly_searched = 0 46 | unique_id_type = 'trainID' if domain == 'train' else 'name' 47 | # Pull the correct item 48 | unique_id_of_correct = self.apis_for_eval[domain]['success'][f'book_{domain}']['unique_id'] if f'book_{domain}' in self.apis_for_eval[domain]['success'] else '' 49 | for called_api in called_apis: 50 | if not called_api['returned'] or type(called_api['returned'])==str: 51 | continue 52 | if called_api['name'] == goal_api: 53 | goal_parameters = {k:v for k,v in self.apis_for_eval[domain]['success'][goal_api]['parameters'].items() if not test_if_val_is_empty(v)} 54 | # if the Goal (correct) parameters are a subset of what was called 55 | if is_subset(goal_parameters, 
called_api['parameters']): 56 | # if theres no "technically" correct answer, its correct 57 | if len(unique_id_of_correct)==0: 58 | correctly_searched = 1 59 | # if theres a "correct" answer, make sure that was returned 60 | else: 61 | # if the unique id is correct its correct 62 | if unique_id_of_correct == called_api['returned'][unique_id_type]: 63 | correctly_searched = 1 64 | # else it's wrong. 65 | # if the called api is a subset of the Goal (correct) parameters 66 | elif is_subset(called_api['parameters'], goal_parameters): 67 | # if theres no "technically" correct answer 68 | if len(unique_id_of_correct)==0: 69 | # if the goals are a subset of the returned object, its "right" 70 | if is_subset(goal_parameters, called_api['returned']): 71 | correctly_searched = 1 72 | # else it is not right 73 | # if there is a "correct" answer 74 | else: 75 | # if the unique id is correct its correct 76 | if unique_id_of_correct == called_api['returned'][unique_id_type]: 77 | correctly_searched = 1 78 | # else it's wrong. 79 | # neither, so it is incorrect 80 | return correctly_searched 81 | 82 | 83 | def evaluate_apis(self, called_apis): 84 | number_of_successful_apis = 0 85 | correct_calls = 0 86 | failed_api_calls = [] 87 | apis_to_delete = [] 88 | for domain in self.apis_for_eval.keys(): 89 | for goal_api in self.apis_for_eval[domain]['success'].keys(): 90 | #Looping through each goal api 91 | action, _ = goal_api.split('_') 92 | if action == 'book': 93 | successful_call = self._test_book_apis(goal_api, domain, called_apis) 94 | elif action == 'search': 95 | successful_call = self._test_search_apis(goal_api, domain, called_apis) 96 | else: 97 | successful_call = 0 98 | 99 | if successful_call==0: 100 | failed_api_calls.append({goal_api:self.apis_for_eval[domain]['success'][goal_api]['parameters']}) 101 | 102 | if successful_call==1: 103 | apis_to_delete.append((domain, goal_api)) 104 | correct_calls += successful_call 105 | number_of_successful_apis+=1 106 | 107 | return correct_calls, number_of_successful_apis, failed_api_calls, apis_to_delete -------------------------------------------------------------------------------- /josh_train/finetune/kto_ft.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from datasets import Dataset, DatasetDict 4 | import pandas as pd 5 | from trl import KTOConfig, KTOTrainer 6 | import torch 7 | from transformers import AutoTokenizer, AutoModelForCausalLM 8 | import yaml 9 | import os 10 | from datasets import load_dataset 11 | from peft import LoraConfig 12 | config = LoraConfig( 13 | r=64, 14 | lora_alpha=64, 15 | target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", 16 | "gate_proj", "up_proj", "down_proj",], 17 | lora_dropout=0.05, 18 | bias="none", # add bias to the nn.Linear layers? 
19 | task_type="CAUSAL_LM", 20 | ) 21 | model_name = "meta-llama/Meta-Llama-3-8B-Instruct" 22 | model = AutoModelForCausalLM.from_pretrained( 23 | model_name, 24 | device_map="auto", 25 | torch_dtype=torch.bfloat16, 26 | load_in_4bit=True, 27 | bnb_4bit_compute_dtype=torch.bfloat16, 28 | bnb_4bit_use_double_quant=True, 29 | bnb_4bit_quant_type="nf4", 30 | attn_implementation="flash_attention_2", 31 | ) 32 | 33 | model_ref = AutoModelForCausalLM.from_pretrained( 34 | model_name, 35 | device_map="auto", 36 | torch_dtype=torch.bfloat16, 37 | load_in_4bit=True, 38 | bnb_4bit_compute_dtype=torch.bfloat16, 39 | bnb_4bit_use_double_quant=True, 40 | bnb_4bit_quant_type="nf4", 41 | attn_implementation="flash_attention_2", 42 | ) 43 | 44 | 45 | tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") 46 | tokenizer.pad_token = tokenizer.eos_token 47 | 48 | 49 | dataset = load_dataset("json", data_files="/root/code/multiwoz-api/train_data_kto_just_div.jsonl", split="train") 50 | 51 | def format_dataset(example): 52 | example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False) 53 | example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False) 54 | return example 55 | 56 | formatted_dataset = dataset.map(format_dataset) 57 | 58 | training_args = KTOConfig( 59 | report_to="tensorboard", 60 | per_device_train_batch_size = 1, 61 | gradient_accumulation_steps = 16, 62 | warmup_ratio = 0.1, 63 | num_train_epochs = 1, 64 | learning_rate = 5e-7, 65 | fp16 = not torch.cuda.is_bf16_supported(), 66 | bf16 = torch.cuda.is_bf16_supported(), 67 | logging_steps = 1, 68 | logging_dir="/root/code/multiwoz-api/logs/", 69 | optim = "adamw_8bit", 70 | lr_scheduler_type = "cosine", 71 | seed = 42, 72 | beta=0.1, 73 | desirable_weight=1.0, 74 | undesirable_weight=2.71, 75 | gradient_checkpointing = True, 76 | save_strategy= "steps", 77 | save_steps= 100, 78 | max_prompt_length=5000, 79 | max_length=5256, 80 | output_dir = 'data/llama_8b_kto_0612', 81 | ) 82 | 83 | kto_trainer = KTOTrainer( 84 | model = model, 85 | ref_model = model_ref, 86 | peft_config = config, 87 | args = training_args, 88 | train_dataset = formatted_dataset, 89 | tokenizer = tokenizer, 90 | ) 91 | 92 | 93 | kto_trainer.train() 94 | -------------------------------------------------------------------------------- /josh_train/finetune/sft_agent_final_mod.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from datasets import Dataset, DatasetDict 4 | import pandas as pd 5 | from trl import KTOConfig, KTOTrainer 6 | import torch 7 | from transformers import AutoTokenizer, AutoModelForCausalLM 8 | import yaml 9 | import numpy as np 10 | import os 11 | from josh_train.utils import parse_api_call 12 | from huggingface_hub import login 13 | login(token="HF_TOKEN") 14 | from peft import LoraConfig, get_peft_model 15 | from trl import SFTTrainer 16 | from transformers import TrainingArguments 17 | from datasets import load_dataset 18 | config = LoraConfig( 19 | r=64, 20 | lora_alpha=64, 21 | target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", 22 | "gate_proj", "up_proj", "down_proj",], 23 | lora_dropout=0.05, 24 | bias="none", # add bias to the nn.Linear layers? 
25 | task_type="CAUSAL_LM", 26 | ) 27 | model_name = "meta-llama/Meta-Llama-3-8B-Instruct" 28 | model = AutoModelForCausalLM.from_pretrained( 29 | model_name, 30 | device_map="auto", 31 | torch_dtype=torch.bfloat16, 32 | load_in_4bit=True, 33 | bnb_4bit_compute_dtype=torch.bfloat16, 34 | bnb_4bit_use_double_quant=True, 35 | bnb_4bit_quant_type="nf4", 36 | attn_implementation="flash_attention_2", 37 | ) 38 | 39 | 40 | dataset = load_dataset("json", data_files="/root/code/multiwoz-api/train_data_full_convo.jsonl", split="train") 41 | trainer = SFTTrainer( 42 | model = model, 43 | train_dataset = dataset, 44 | peft_config = config, 45 | args = TrainingArguments( 46 | num_train_epochs = 3, 47 | per_device_train_batch_size = 1, 48 | gradient_accumulation_steps = 16, 49 | warmup_steps = 5, 50 | learning_rate = 2e-5, 51 | fp16 = not torch.cuda.is_bf16_supported(), 52 | bf16 = torch.cuda.is_bf16_supported(), 53 | logging_steps = 10, 54 | logging_dir = "/root/code/multiwoz-api/multiwoz_api/logs", 55 | optim = "adamw_8bit", 56 | weight_decay = 0.01, 57 | lr_scheduler_type = "linear", 58 | seed = 42, 59 | output_dir = "data/sft_0611_output", 60 | gradient_checkpointing = True, 61 | save_strategy= "steps", 62 | save_steps= 13, 63 | report_to='tensorboard', 64 | ), 65 | ) 66 | trainer_stats = trainer.train() 67 | -------------------------------------------------------------------------------- /josh_train/users/base_user_simulator.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Dict, List 3 | from josh_train.josh import BaseJOSHAgent 4 | 5 | class BaseUserSimulator: 6 | def __init__(self): 7 | pass 8 | def step(self, agent:BaseJOSHAgent): 9 | return agent, True -------------------------------------------------------------------------------- /josh_train/users/goal_user_simulator.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import yaml 3 | from josh_train.users.base_user_simulator import BaseUserSimulator 4 | from josh_train.utils import compute_cost, request_openai 5 | import josh_train.config as config 6 | from josh_train.josh import BaseJOSHAgent 7 | 8 | class GoalUserSimulator(BaseUserSimulator): 9 | def __init__(self, goals, modelname, debug =False): 10 | self.modelname = modelname 11 | self.goals = goals['message'] 12 | with open('prompts/prompts.yaml', 'r') as file: 13 | prompts = yaml.safe_load(file) 14 | self.prompt = prompts['user_simulator_prompt'] 15 | self.opening_prompt = prompts['user_simulator_opening_message_prompt'].replace('{goals}', '\n'.join(self.goals)) 16 | self.system_prompt = prompts['user_simulator_system_prompt'].replace('{goals}', '\n'.join(self.goals)) 17 | self.cost = 0.0 18 | self.debug = debug 19 | 20 | def make_messages(self, messages): 21 | if len(messages)==0: 22 | return [{'role': 'system', 'content': self.opening_prompt}] 23 | mod_messages = [{'role':'user' if message['role']=='assistant' else 'assistant', 'content':message['content']} for message in messages] 24 | return [{'role': 'system', 'content': self.system_prompt}] + mod_messages + [{'role': 'system', 'content':self.prompt}] 25 | 26 | def step(self, agent): 27 | messages = copy.deepcopy(agent.messages) 28 | user_sim_messages = self.make_messages(messages=messages) 29 | output = request_openai(user_sim_messages, self.modelname, config.client) 30 | self.cost += compute_cost(user_sim_messages+ [{'role': 'user', 'content': output}]) 31 | 32 | if self.debug: 33 | print('#'*30) 34 | 
print(f'USER: {output}') 35 | print('#'*30) 36 | 37 | agent.messages.append({'role': 'user', 'content': output}) 38 | agent.messages_internal.append({'role': 'user', 'content': output}) 39 | 40 | end_convo = 'END_CONVERSATION' in output 41 | return agent, end_convo 42 | -------------------------------------------------------------------------------- /josh_train/users/guide_user_simulator.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import yaml 3 | from josh_train.utils import make_transcript, request_openai 4 | from josh_train.users.base_user_simulator import BaseUserSimulator 5 | import josh_train.config as config 6 | from josh_train.josh import BaseJOSHAgent 7 | guide_prompt = """You are a coach giving tips to a user simulator that is trying to replicate a conversation as consistently as possible. The user simulator is in the middle of a conversation; give it advice on what to do in the next turn. 8 | Consistency means that over multiple runs, the user simulator should behave in the exact same way; it is your job to help it stay on the same trajectory every run. 9 | 10 | ###### Grounding Goals and Conversation ######### 11 | Customer goals: 12 | {goals} 13 | 14 | The following is the source conversation the user simulator is trying to replicate: 15 | {goal_convo} 16 | ################################################### 17 | 18 | ######## CURRENT (real) Conversation ####################### 19 | This is the CURRENT conversation the user simulator is having: 20 | {current_convo} 21 | 22 | Use your best judgement: if the conversation is not going well, it's possible the agent is not good enough and you need to end the conversation. End the conversation by putting END_CONVERSATION after your quote. 23 | Keep in mind that all of the Customer goals must be communicated in order to give the agent enough information to properly search and book. 24 | It is critical that you give consistent advice over multiple iterations of the same conversation. The best way to do that is to ground your response in the source conversation and to provide quotes whenever possible. 25 | Please write brief advice on what the user simulator should say in order to keep it consistent and aligned with the source conversation. Write this advice to the user simulator, referring to it as "you". No yapping: 26 | 27 | Example: 28 | Advice: 29 | The user should ... 30 | Suggested quote: 31 | "Hello, how can I help you?"
32 | 33 | Advice: 34 | The conversation should be ended 35 | Suggested quote: 36 | "Thanks, goodbye" END_CONVERSATION 37 | 38 | Output: 39 | """ 40 | 41 | class GuideUserSimulator(BaseUserSimulator): 42 | def __init__(self, goals, convo, modelname='gpt-4o-2024-05-13', debug=False): 43 | self.goals = goals['message'] 44 | with open('prompts/prompts.yaml', 'r') as file: 45 | prompts = yaml.safe_load(file) 46 | self.prompt = prompts['user_simulator_prompt'] 47 | self.modelname = modelname 48 | self.opening_prompt = prompts['user_simulator_opening_message_prompt'].replace('{goals}', '\n'.join(self.goals)) 49 | self.system_prompt = prompts['user_simulator_system_prompt'].replace('{goals}', '\n'.join(self.goals)) 50 | self.convo = [] 51 | self.debug = debug 52 | for log in convo['log']: 53 | tag = 'agent: ' if log['metadata'] else 'customer: ' 54 | self.convo.append(tag+log['text'].strip()) 55 | def make_messages(self, messages, system=None, prompt=None): 56 | mod_messages = [{'role':'user' if message['role']=='assistant' else 'assistant', 'content':message['content']} for message in messages] 57 | return [{'role': 'system', 'content': self.system_prompt if not system else system}] + mod_messages + [{'role': 'system', 'content':self.prompt if not prompt else prompt}] 58 | 59 | def step(self, agent): 60 | user_sim_messages = copy.deepcopy(agent.messages) 61 | transcript = make_transcript(user_sim_messages, {'assistant':'user', 'user':'agent'}) 62 | filled_guide_prompt = guide_prompt.replace('{goals}', '\n'.join(self.goals)).replace('{goal_convo}', '\n'.join(self.convo)).replace('{current_convo}', transcript if len(transcript)!=0 else 'NOT YET STARTED') 63 | output_guide = request_openai([{'role': 'user', 'content': filled_guide_prompt}], self.modelname, config.client) 64 | output = output_guide.split("Suggested quote:")[-1].replace('\n', '').replace('"', '') 65 | 66 | if self.debug: 67 | print('#'*30) 68 | print(f'USER: {output}') 69 | print('#'*30) 70 | 71 | agent.messages.append({'role': 'user', 'content': output}) 72 | agent.messages_internal.append({'role': 'user', 'content': output}) 73 | 74 | end_convo = 'END_CONVERSATION' in output 75 | 76 | return agent, end_convo -------------------------------------------------------------------------------- /josh_train/users/script_user_simulator.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from josh_train.utils import make_transcript, request_openai 3 | from josh_train.users.base_user_simulator import BaseUserSimulator 4 | 5 | class ScriptUserSimulator(BaseUserSimulator): 6 | def __init__(self, goals, convo): 7 | self.convo = [x['text'] for x in convo['log'] if not x['metadata']] 8 | self.idx = -1 9 | def step(self, messages): 10 | self.idx += 1 11 | if self.idx >= len(self.convo): 12 | return [{'role':'user', 'content':'END_CONVERSATION'}] 13 | return [{'role': 'user', 'content': self.convo[self.idx]}] -------------------------------------------------------------------------------- /mtbencheval.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/josh-llm-simulation-training/9c3a6076751ca3ac07e20a01fe5da780e42d23b0/mtbencheval.zip -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # pyproject.toml 2 | [build-system] 3 | requires = ["setuptools>=45", "wheel", 
"setuptools_scm[toml]>=6.2"] 4 | 5 | [tool.setuptools_scm] 6 | write_to = "josh_train/_version.py" 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # put requriemetns here 2 | python-dotenv 3 | colorama 4 | openai 5 | pyyaml 6 | datasets 7 | accelerate 8 | transformers 9 | numpy 10 | pandas 11 | trl 12 | torch 13 | deepsig 14 | tiktoken -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | junit_family = xunit1 3 | addopts = -ra -vl 4 | --junitxml=test-results/$TOX_ENVNAME/junit.xml 5 | --cov-config setup.cfg 6 | --cov=josh_train 7 | --cov-report html:test-results/coverage/cov_html 8 | --cov-report xml:test-results/coverage/cov.xml 9 | --cov-report term 10 | --html=test-results/$TOX_ENVNAME/report.html 11 | --self-contained-html 12 | markers = 13 | integration: mark a test as integration 14 | unit: mark a test as a unit test 15 | slow: mark a test as being slow 16 | notdev: mark a test as needing additional credentials or resources than most dev machines 17 | pinned: mark a test as potentially breaking if a particular dataset changes 18 | 19 | testpaths = test 20 | norecursedirs = .git test/helpers test/test-data test/test-cases .toxenv .toxenv36 .toxenv37 *site-packages* 21 | 22 | [flake8] 23 | max-line-length = 120 24 | ignore = W503 W504 25 | per-file-ignores = 26 | __init__.py:F401,F403 27 | 28 | [coverage:report] 29 | # Regexes for lines to exclude from consideration 30 | exclude_lines = 31 | # Have to re-enable the standard pragma 32 | pragma: no cover 33 | 34 | # Don't complain about missing debug-only code: 35 | def __repr__ 36 | if self\.debug 37 | 38 | # Don't complain if tests don't hit defensive assertion code: 39 | raise AssertionError 40 | raise NotImplementedError 41 | 42 | # Don't complain if non-runnable code isn't run: 43 | if 0: 44 | if __name__ == .__main__.: 45 | 46 | [coverage:html] 47 | directory = test-results/$TOX_ENVNAME/coverage/cov 48 | 49 | [coverage:paths] 50 | source = 51 | josh_train/ 52 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from setuptools import PEP420PackageFinder, setup 4 | 5 | 6 | def get_requires(requires_filename: str) -> List[str]: 7 | requirements = [] 8 | with open(requires_filename, "r") as infile: 9 | for line in infile.readlines(): 10 | line = line.strip() 11 | requirements.append(line) 12 | return requirements 13 | 14 | 15 | setup( 16 | name="josh_train", 17 | description="A simulation training framework", 18 | long_description=open("README.md", "r").read(), 19 | long_description_content_type="text/markdown", 20 | author="blattimer", 21 | author_email="blattimer@asapp.com", 22 | url="https://github.com/josh-llm-simulation-training", 23 | packages=PEP420PackageFinder.find(exclude=("test*",)), 24 | python_requires=">=3.8", 25 | install_requires=get_requires("requirements.txt"), 26 | include_package_data=True, 27 | setup_requires=["setuptools_scm"], 28 | version="0.1.4", 29 | classifiers=[ 30 | "Programming Language :: Python :: 3", 31 | "License :: OSI Approved :: MIT License", 32 | "Operating System :: OS Independent", 33 | ], 34 | ) 35 | 
-------------------------------------------------------------------------------- /takeAggregateWithOptBaseline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import sys 5 | from deepsig import bootstrap_test, aso 6 | 7 | def get_aggregates(input_dir): 8 | data = [] 9 | 10 | 11 | # for file in os.listdir('./'+input_dir+"/"): 12 | with open('./'+input_dir, 'r') as f: 13 | data = json.load(f) 14 | 15 | num_of_100 = 0 16 | scores = [] 17 | for d in data: 18 | scores.append(d['reward']) 19 | if d['reward'] == 1.0: 20 | num_of_100+=1 21 | 22 | return data, scores, num_of_100 23 | 24 | 25 | if __name__ == "__main__": 26 | 27 | input_dir = sys.argv[1] 28 | data, scores, num_of_100 = get_aggregates(input_dir) 29 | variance = np.var(scores) 30 | print_model_name="Model" 31 | print("\t".join(["Approach","Full Success","Mean Reward","Test Samples","Reward Variance","Is Model Stat Sig Better? (p below 0.01 --> Model is indeed better)"])) 32 | print(print_model_name,"\t",num_of_100/len(data),"\t",np.mean(scores),"\t",len(data),"\t",variance,"\t","N/A") 33 | 34 | if len(sys.argv)>=3: 35 | for baseline_id in range(len(sys.argv)-2): 36 | baseline_dir = sys.argv[2+baseline_id] 37 | data_baseline, scores_baseline, num_of_100_baseline = get_aggregates(baseline_dir) 38 | variance_baseline = np.var(scores_baseline) 39 | scores_model = scores 40 | 41 | #print("Model scores contain NaN:", np.isnan(scores).any()) 42 | #print("Model scores contain inf:", np.isinf(scores).any()) 43 | #print("Baseline scores contain NaN:", np.isnan(scores_baseline).any()) 44 | #print("Baseline scores contain inf:", np.isinf(scores_baseline).any()) 45 | 46 | p_value = bootstrap_test(np.array(scores_model),np.array(scores_baseline),seed=42,num_jobs=100,num_samples=100000) 47 | print(p_value) 48 | print_model_name="Baseline"+" "+str(baseline_id) 49 | print(print_model_name,"\t",num_of_100_baseline/len(data_baseline),"\t",np.mean(scores_baseline),"\t",len(data_baseline),"\t",variance_baseline,"\t",p_value) 50 | 51 | -------------------------------------------------------------------------------- /tau-bench-eval/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Sierra 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /tau-bench-eval/README.md: -------------------------------------------------------------------------------- 1 | # τ-bench: A Benchmark for Tool-Agent-User Interaction in Real-World Domains 2 | 3 | **Paper**: https://arxiv.org/abs/2406.12045 4 | 5 | Citation: 6 | 7 | ```bibtex 8 | @misc{yao2024tau, 9 | title={$\tau$-bench: A Benchmark for Tool-Agent-User Interaction in Real-World Domains}, 10 | author={Shunyu Yao and Noah Shinn and Pedram Razavi and Karthik Narasimhan}, 11 | year={2024}, 12 | eprint={2406.12045}, 13 | archivePrefix={arXiv}, 14 | primaryClass={cs.AI}, 15 | url={https://arxiv.org/abs/2406.12045}, 16 | } 17 | ``` 18 | 19 | ## Leaderboard 20 | 21 | | Strategy | Pass^1 | Pass^2 | Pass^3 | Pass^4 | 22 | | -------------- | ------ | ------ | ------ | ------ | 23 | | [TC (gpt-4o)](https://platform.openai.com/docs/guides/function-calling) | 0.380 | 0.237 | 0.185 | 0.160 | 24 | | [TC (claude-3-5-sonnet-20240620)](https://docs.anthropic.com/en/docs/build-with-claude/tool-use) | ?? | ?? | ?? | ?? | 25 | | [TC (mistral-large-2407)](https://docs.mistral.ai/capabilities/function_calling/) | ?? | ?? | ?? | ?? | 26 | | [TC (gpt-4o-mini)](https://platform.openai.com/docs/guides/function-calling) | 0.225 | 0.140 | 0.110 | 0.100 | 27 | | [Act](https://arxiv.org/abs/2210.03629) (gpt-4o) | 0.365 | 0.217 | 0.160 | 0.140 | 28 | | [ReAct](https://arxiv.org/abs/2210.03629) (gpt-4o) | 0.325 | 0.233 | 0.185 | 0.160 | 29 | 30 | *TC = `tool-calling` strategy (the function-calling strategy reported in the paper) 31 | 32 | ## Setup 33 | 34 | 1. Clone this repository: 35 | 36 | ```bash 37 | git clone https://github.com/sierra-research/tau-bench && cd ./tau-bench 38 | ``` 39 | 40 | 2. Install from source (which also installs required packages): 41 | 42 | ```bash 43 | pip install -e . 44 | ``` 45 | 46 | 3. Set up your OpenAI / Anthropic / Google / Mistral / AnyScale API keys as environment variables. 47 | 48 | ```bash 49 | OPENAI_API_KEY=... 50 | ANTHROPIC_API_KEY=... 51 | GOOGLE_API_KEY=... 52 | MISTRAL_API_KEY=... 53 | ``` 54 | 55 | ## Run 56 | 57 | Run a tool-calling agent on the τ-retail environment: 58 | 59 | ```bash 60 | python run.py --agent-strategy tool-calling --env retail --model gpt-4o --model-provider openai --user-model gpt-4o --user-model-provider openai --max-concurrency 10 61 | ``` 62 | 63 | Set max concurrency according to your API limit(s). 64 | 65 | ## User simulators 66 | 67 | By default, we use `gpt-4o` as the user simulator. You can use other models by setting the `--user-model` flag. For example, run a function-calling agent with a Claude user simulator: 68 | 69 | ```bash 70 | python run.py --agent-strategy tool-calling --env retail --model gpt-4o --model-provider openai --max-concurrency 10 --user-model claude-3-5-sonnet-20240620 --user-model-provider anthropic 71 | ``` 72 | 73 | ## Auto error identification 74 | 75 | Oftentimes, it is difficult and time-consuming to manually identify specific error locations in trajectories, as they can be long and the constraints can be complex. We have provided an auto error identification tool that can do the following: 76 | 77 | 1. Fault assignment: determine the entity that is responsible for the fault (user, agent, environment) 78 | 2. Fault type classification: classify the type of fault (goal_partially_completed, used_wrong_tool, used_wrong_tool_argument, took_unintended_action) 79 | 80 | Both of the labels are accompanied by a description.
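For orientation, a single labeled trajectory might look like the record below. This is an illustrative shape only: the entity values and fault types are the ones enumerated above, but the field names are assumptions, not the tool's actual output schema.

```python
# Illustrative only: one possible record combining the two labels described above.
# The entities (user/agent/environment) and the four fault types come from the
# README; the field names here are hypothetical, not the tool's real schema.
example_label = {
    "fault_assignment": "agent",               # "user", "agent", or "environment"
    "fault_type": "used_wrong_tool_argument",  # one of the four fault types above
    "description": "The agent passed the wrong order id to cancel_pending_order.",
}
```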
81 | 82 | To run the auto error identification, run: 83 | 84 | ```bash 85 | python auto_error_identification.py --env <env> --results-path <results-path> --max-concurrency 16 --output-path test-auto-error-identification -n 10 86 | ``` 87 | 88 | Please note that this feature utilizes an LLM, which may lead to inaccurate error identifications. 89 | 90 | *Notice: If an error is raised due to the structure of your results file, you may have to rerun the benchmark to produce a new results file. We have recently [rewritten](https://github.com/sierra-research/tau-bench/commit/043b544371757ebb3762b3d02a6675dfe0c41798) the benchmark to be more type-safe and extensible. 91 | 92 | ## License 93 | 94 | See `./LICENSE`. 95 | 96 | ## Contact 97 | 98 | Please submit issues or pull requests if you find problems with the benchmark. 99 | -------------------------------------------------------------------------------- /tau-bench-eval/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from setuptools import find_packages, setup 4 | 5 | setup( 6 | name="tau_bench", 7 | version="0.1.0", 8 | description="The Tau-Bench package", 9 | long_description=open("README.md").read(), 10 | packages=find_packages(), 11 | include_package_data=True, 12 | install_requires=[ 13 | "openai>=1.13.3", 14 | "mistralai>=0.4.0", 15 | "anthropic>=0.26.1", 16 | "google-generativeai>=0.5.4", 17 | "tenacity>=8.3.0", 18 | "termcolor>=2.4.0", 19 | "numpy>=1.26.4", 20 | "litellm>=1.41.0", 21 | ], 22 | ) 23 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from tau_bench.envs.base import Env as Env 4 | from tau_bench.agents.base import Agent as Agent 5 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/agents/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/agents/base.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import abc 4 | from typing import Optional 5 | from tau_bench.envs.base import Env 6 | from tau_bench.types import SolveResult 7 | 8 | 9 | class Agent(abc.ABC): 10 | @abc.abstractmethod 11 | def solve( 12 | self, env: Env, task_index: Optional[int] = None, max_num_steps: int = 30 13 | ) -> SolveResult: 14 | raise NotImplementedError 15 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/agents/tool_calling_agent.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from litellm import completion 5 | from typing import List, Optional, Dict, Any 6 | 7 | from tau_bench.agents.base import Agent 8 | from tau_bench.envs.base import Env 9 | from tau_bench.types import SolveResult, Action, RESPOND_ACTION_NAME 10 | 11 | 12 | class ToolCallingAgent(Agent): 13 | def __init__( 14 | self, 15 | tools_info: List[Dict[str, Any]], 16 | wiki: str, 17 | model: str, 18 | provider: str, 19 | temperature: float = 0.0, 20 | ): 21 | self.tools_info = tools_info 22 | self.wiki = wiki 23 | self.model = model 24 | self.provider = provider 25 | self.temperature = temperature 26 | 27 | def solve(
28 | self, env: Env, task_index: Optional[int] = None, max_num_steps: int = 30 29 | ) -> SolveResult: 30 | total_cost = 0.0 31 | env_reset_res = env.reset(task_index=task_index) 32 | obs = env_reset_res.observation 33 | info = env_reset_res.info.model_dump() 34 | reward = 0.0 35 | messages: List[Dict[str, Any]] = [ 36 | {"role": "system", "content": self.wiki}, 37 | {"role": "user", "content": obs}, 38 | ] 39 | for _ in range(max_num_steps): 40 | res = completion( 41 | messages=messages, 42 | model=self.model, 43 | custom_llm_provider=self.provider, 44 | tools=self.tools_info, 45 | temperature=self.temperature, 46 | ) 47 | next_message = res.choices[0].message.model_dump() 48 | total_cost += res._hidden_params["response_cost"] if res._hidden_params["response_cost"] else 0.0 49 | action = message_to_action(next_message) 50 | env_response = env.step(action) 51 | reward = env_response.reward 52 | info = {**info, **env_response.info.model_dump()} 53 | if action.name != RESPOND_ACTION_NAME: 54 | next_message["tool_calls"] = next_message["tool_calls"][:1] 55 | messages.extend( 56 | [ 57 | next_message, 58 | { 59 | "role": "tool", 60 | "tool_call_id": next_message["tool_calls"][0]["id"], 61 | "name": next_message["tool_calls"][0]["function"]["name"], 62 | "content": env_response.observation, 63 | }, 64 | ] 65 | ) 66 | else: 67 | messages.extend( 68 | [ 69 | next_message, 70 | {"role": "user", "content": env_response.observation}, 71 | ] 72 | ) 73 | if env_response.done: 74 | break 75 | return SolveResult( 76 | reward=reward, 77 | info=info, 78 | messages=messages, 79 | total_cost=total_cost, 80 | ) 81 | 82 | 83 | def message_to_action( 84 | message: Dict[str, Any], 85 | ) -> Action: 86 | if "tool_calls" in message and message["tool_calls"] is not None and len(message["tool_calls"]) > 0 and message["tool_calls"][0]["function"] is not None: 87 | tool_call = message["tool_calls"][0] 88 | return Action( 89 | name=tool_call["function"]["name"], 90 | kwargs=json.loads(tool_call["function"]["arguments"]), 91 | ) 92 | else: 93 | return Action(name=RESPOND_ACTION_NAME, kwargs={"content": message["content"]}) 94 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from typing import Optional 4 | from tau_bench.envs.base import Env 5 | 6 | 7 | def get_env( 8 | env_name: str, 9 | user_strategy: str, 10 | user_model: str, 11 | task_split: str, 12 | user_provider: Optional[str] = None, 13 | task_index: Optional[int] = None, 14 | ) -> Env: 15 | if env_name == "retail": 16 | from tau_bench.envs.retail import MockRetailDomainEnv 17 | 18 | return MockRetailDomainEnv( 19 | user_strategy=user_strategy, 20 | user_model=user_model, 21 | task_split=task_split, 22 | user_provider=user_provider, 23 | task_index=task_index, 24 | ) 25 | elif env_name == "airline": 26 | from tau_bench.envs.airline import MockAirlineDomainEnv 27 | 28 | return MockAirlineDomainEnv( 29 | user_strategy=user_strategy, 30 | user_model=user_model, 31 | task_split=task_split, 32 | user_provider=user_provider, 33 | task_index=task_index, 34 | ) 35 | else: 36 | raise ValueError(f"Unknown environment: {env_name}") 37 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/airline/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from 
tau_bench.envs.airline.env import MockAirlineDomainEnv as MockAirlineDomainEnv 4 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/airline/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | import os 5 | from typing import Any 6 | 7 | FOLDER_PATH = os.path.dirname(__file__) 8 | 9 | 10 | def load_data() -> dict[str, Any]: 11 | with open(os.path.join(FOLDER_PATH, "flights.json")) as f: 12 | flight_data = json.load(f) 13 | with open(os.path.join(FOLDER_PATH, "reservations.json")) as f: 14 | reservation_data = json.load(f) 15 | with open(os.path.join(FOLDER_PATH, "users.json")) as f: 16 | user_data = json.load(f) 17 | return { 18 | "flights": flight_data, 19 | "reservations": reservation_data, 20 | "users": user_data, 21 | } 22 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/airline/env.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from tau_bench.envs.airline.data import load_data 4 | from tau_bench.envs.airline.rules import RULES 5 | from tau_bench.envs.airline.tools import ALL_TOOLS 6 | from tau_bench.envs.airline.wiki import WIKI 7 | from tau_bench.envs.base import Env 8 | from typing import Optional 9 | 10 | 11 | class MockAirlineDomainEnv(Env): 12 | def __init__( 13 | self, 14 | user_strategy: str = "llm", 15 | user_model: str = "gpt-4o", 16 | user_provider: Optional[str] = None, 17 | task_split: str = "test", 18 | task_index: Optional[int] = None, 19 | ): 20 | if task_split == "test": 21 | from tau_bench.envs.airline.tasks_test import TASKS as tasks 22 | else: 23 | raise ValueError(f"Unknown task split: {task_split}") 24 | super().__init__( 25 | data_load_func=load_data, 26 | tools=ALL_TOOLS, 27 | tasks=tasks, 28 | wiki=WIKI, 29 | rules=RULES, 30 | user_strategy=user_strategy, 31 | user_model=user_model, 32 | user_provider=user_provider, 33 | task_index=task_index, 34 | ) 35 | self.terminate_tools = ["transfer_to_human_agents"] 36 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/airline/rules.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | RULES = [] 4 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/airline/tools/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from .book_reservation import BookReservation 4 | from .calculate import Calculate 5 | from .cancel_reservation import CancelReservation 6 | from .get_reservation_details import GetReservationDetails 7 | from .get_user_details import GetUserDetails 8 | from .list_all_airports import ListAllAirports 9 | from .search_direct_flight import SearchDirectFlight 10 | from .search_onestop_flight import SearchOnestopFlight 11 | from .send_certificate import SendCertificate 12 | from .think import Think 13 | from .transfer_to_human_agents import TransferToHumanAgents 14 | from .update_reservation_baggages import UpdateReservationBaggages 15 | from .update_reservation_flights import UpdateReservationFlights 16 | from .update_reservation_passengers import UpdateReservationPassengers 17 | 18 | ALL_TOOLS = [ 19 | BookReservation, 20 | Calculate, 21 | CancelReservation, 22 | 
GetReservationDetails, 23 | GetUserDetails, 24 | ListAllAirports, 25 | SearchDirectFlight, 26 | SearchOnestopFlight, 27 | SendCertificate, 28 | Think, 29 | TransferToHumanAgents, 30 | UpdateReservationBaggages, 31 | UpdateReservationFlights, 32 | UpdateReservationPassengers, 33 | ] 34 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/airline/tools/calculate.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from typing import Any, Dict 4 | from tau_bench.envs.tool import Tool 5 | 6 | 7 | class Calculate(Tool): 8 | @staticmethod 9 | def invoke(data: Dict[str, Any], expression: str) -> str: 10 | if not all(char in "0123456789+-*/(). " for char in expression): 11 | return "Error: invalid characters in expression" 12 | try: 13 | return str(round(float(eval(expression, {"__builtins__": None}, {})), 2)) 14 | except Exception as e: 15 | return f"Error: {e}" 16 | 17 | @staticmethod 18 | def get_info() -> Dict[str, Any]: 19 | return { 20 | "type": "function", 21 | "function": { 22 | "name": "calculate", 23 | "description": "Calculate the result of a mathematical expression.", 24 | "parameters": { 25 | "type": "object", 26 | "properties": { 27 | "expression": { 28 | "type": "string", 29 | "description": "The mathematical expression to calculate, such as '2 + 2'. The expression can contain numbers, operators (+, -, *, /), parentheses, and spaces.", 30 | }, 31 | }, 32 | "required": ["expression"], 33 | }, 34 | }, 35 | } 36 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/airline/tools/cancel_reservation.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class CancelReservation(Tool): 9 | @staticmethod 10 | def invoke( 11 | data: Dict[str, Any], 12 | reservation_id: str, 13 | ) -> str: 14 | reservations = data["reservations"] 15 | if reservation_id not in reservations: 16 | return "Error: reservation not found" 17 | reservation = reservations[reservation_id] 18 | 19 | # reverse the payment 20 | refunds = [] 21 | for payment in reservation["payment_history"]: 22 | refunds.append( 23 | { 24 | "payment_id": payment["payment_id"], 25 | "amount": -payment["amount"], 26 | } 27 | ) 28 | reservation["payment_history"].extend(refunds) 29 | reservation["status"] = "cancelled" 30 | return json.dumps(reservation) 31 | 32 | @staticmethod 33 | def get_info() -> Dict[str, Any]: 34 | return { 35 | "type": "function", 36 | "function": { 37 | "name": "cancel_reservation", 38 | "description": "Cancel the whole reservation.", 39 | "parameters": { 40 | "type": "object", 41 | "properties": { 42 | "reservation_id": { 43 | "type": "string", 44 | "description": "The reservation ID, such as 'ZFA04Y'.", 45 | }, 46 | }, 47 | "required": ["reservation_id"], 48 | }, 49 | }, 50 | } 51 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/airline/tools/get_reservation_details.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class GetReservationDetails(Tool): 9 | @staticmethod 10 | def invoke(data: Dict[str, Any], reservation_id: str) -> str: 11 | reservations = 
data["reservations"] 12 | if reservation_id in reservations: 13 | return json.dumps(reservations[reservation_id]) 14 | return "Error: user not found" 15 | 16 | @staticmethod 17 | def get_info() -> Dict[str, Any]: 18 | return { 19 | "type": "function", 20 | "function": { 21 | "name": "get_reservation_details", 22 | "description": "Get the details of a reservation.", 23 | "parameters": { 24 | "type": "object", 25 | "properties": { 26 | "reservation_id": { 27 | "type": "string", 28 | "description": "The reservation id, such as '8JX2WO'.", 29 | }, 30 | }, 31 | "required": ["reservation_id"], 32 | }, 33 | }, 34 | } 35 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/airline/tools/get_user_details.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class GetUserDetails(Tool): 9 | @staticmethod 10 | def invoke(data: Dict[str, Any], user_id: str) -> str: 11 | users = data["users"] 12 | if user_id in users: 13 | return json.dumps(users[user_id]) 14 | return "Error: user not found" 15 | 16 | @staticmethod 17 | def get_info() -> Dict[str, Any]: 18 | return { 19 | "type": "function", 20 | "function": { 21 | "name": "get_user_details", 22 | "description": "Get the details of an user.", 23 | "parameters": { 24 | "type": "object", 25 | "properties": { 26 | "user_id": { 27 | "type": "string", 28 | "description": "The user id, such as 'sara_doe_496'.", 29 | }, 30 | }, 31 | "required": ["user_id"], 32 | }, 33 | }, 34 | } 35 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/airline/tools/list_all_airports.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class ListAllAirports(Tool): 9 | @staticmethod 10 | def invoke(data: Dict[str, Any]) -> str: 11 | airports = [ 12 | "SFO", 13 | "JFK", 14 | "LAX", 15 | "ORD", 16 | "DFW", 17 | "DEN", 18 | "SEA", 19 | "ATL", 20 | "MIA", 21 | "BOS", 22 | "PHX", 23 | "IAH", 24 | "LAS", 25 | "MCO", 26 | "EWR", 27 | "CLT", 28 | "MSP", 29 | "DTW", 30 | "PHL", 31 | "LGA", 32 | ] 33 | cities = [ 34 | "San Francisco", 35 | "New York", 36 | "Los Angeles", 37 | "Chicago", 38 | "Dallas", 39 | "Denver", 40 | "Seattle", 41 | "Atlanta", 42 | "Miami", 43 | "Boston", 44 | "Phoenix", 45 | "Houston", 46 | "Las Vegas", 47 | "Orlando", 48 | "Newark", 49 | "Charlotte", 50 | "Minneapolis", 51 | "Detroit", 52 | "Philadelphia", 53 | "LaGuardia", 54 | ] 55 | return json.dumps({airport: city for airport, city in zip(airports, cities)}) 56 | 57 | @staticmethod 58 | def get_info() -> Dict[str, Any]: 59 | return { 60 | "type": "function", 61 | "function": { 62 | "name": "list_all_airports", 63 | "description": "List all airports and their cities.", 64 | "parameters": { 65 | "type": "object", 66 | "properties": {}, 67 | "required": [], 68 | }, 69 | }, 70 | } 71 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/airline/tools/search_direct_flight.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class SearchDirectFlight(Tool): 9 | 
@staticmethod 10 | def invoke(data: Dict[str, Any], origin: str, destination: str, date: str) -> str: 11 | flights = data["flights"] 12 | results = [] 13 | for flight in flights.values(): 14 | if flight["origin"] == origin and flight["destination"] == destination: 15 | if ( 16 | date in flight["dates"] 17 | and flight["dates"][date]["status"] == "available" 18 | ): 19 | # add the flight minus its "dates" field, then merge in flight["dates"][date] 20 | results.append({k: v for k, v in flight.items() if k != "dates"}) 21 | results[-1].update(flight["dates"][date]) 22 | return json.dumps(results) 23 | 24 | @staticmethod 25 | def get_info() -> Dict[str, Any]: 26 | return { 27 | "type": "function", 28 | "function": { 29 | "name": "search_direct_flight", 30 | "description": "Search direct flights between two cities on a specific date.", 31 | "parameters": { 32 | "type": "object", 33 | "properties": { 34 | "origin": { 35 | "type": "string", 36 | "description": "The origin city airport in three letters, such as 'JFK'.", 37 | }, 38 | "destination": { 39 | "type": "string", 40 | "description": "The destination city airport in three letters, such as 'LAX'.", 41 | }, 42 | "date": { 43 | "type": "string", 44 | "description": "The date of the flight in the format 'YYYY-MM-DD', such as '2024-01-01'.", 45 | }, 46 | }, 47 | "required": ["origin", "destination", "date"], 48 | }, 49 | }, 50 | } 51 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/airline/tools/search_onestop_flight.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class SearchOnestopFlight(Tool): 9 | @staticmethod 10 | def invoke(data: Dict[str, Any], origin: str, destination: str, date: str) -> str: 11 | flights = data["flights"] 12 | results = [] 13 | for flight1 in flights.values(): 14 | if flight1["origin"] == origin: 15 | for flight2 in flights.values(): 16 | if ( 17 | flight2["destination"] == destination 18 | and flight1["destination"] == flight2["origin"] 19 | ): 20 | date2 = (  # second leg is next day if the first leg arrives "+1"; zero-pad to match date keys 21 | f"2024-05-{int(date[-2:])+1:02d}" 22 | if "+1" in flight1["scheduled_arrival_time_est"] 23 | else date 24 | ) 25 | if ( 26 | flight1["scheduled_arrival_time_est"] 27 | > flight2["scheduled_departure_time_est"] 28 | ): 29 | continue 30 | if date in flight1["dates"] and date2 in flight2["dates"]: 31 | if ( 32 | flight1["dates"][date]["status"] == "available" 33 | and flight2["dates"][date2]["status"] == "available" 34 | ): 35 | result1 = { 36 | k: v for k, v in flight1.items() if k != "dates" 37 | } 38 | result1.update(flight1["dates"][date]) 39 | result1["date"] = date 40 | result2 = { 41 | k: v for k, v in flight2.items() if k != "dates" 42 | } 43 | result2.update(flight2["dates"][date2])  # use date2: the second leg's own travel day 44 | result2["date"] = date2 45 | results.append([result1, result2]) 46 | return json.dumps(results) 47 | 48 | @staticmethod 49 | def get_info() -> Dict[str, Any]: 50 | return { 51 | "type": "function", 52 | "function": { 53 | "name": "search_onestop_flight", 54 | "description": "Search one-stop flights between two cities on a specific date.", 55 | "parameters": { 56 | "type": "object", 57 | "properties": { 58 | "origin": { 59 | "type": "string", 60 | "description": "The origin city airport in three letters, such as 'JFK'.", 61 | }, 62 | "destination": { 63 | "type": "string", 64 | "description": "The destination city airport in three letters, such as 'LAX'.",
}, 66 | "date": { 67 | "type": "string", 68 | "description": "The date of the flight in the format 'YYYY-MM-DD', such as '2024-05-01'.", 69 | }, 70 | }, 71 | "required": ["origin", "destination", "date"], 72 | }, 73 | }, 74 | } 75 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/airline/tools/send_certificate.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from typing import Any, Dict 4 | from tau_bench.envs.tool import Tool 5 | 6 | 7 | class SendCertificate(Tool): 8 | @staticmethod 9 | def invoke( 10 | data: Dict[str, Any], 11 | user_id: str, 12 | amount: int, 13 | ) -> str: 14 | users = data["users"] 15 | if user_id not in users: 16 | return "Error: user not found" 17 | user = users[user_id] 18 | 19 | # add a certificate, assume at most 3 cases per task 20 | for id in [3221322, 3221323, 3221324]: 21 | payment_id = f"certificate_{id}" 22 | if payment_id not in user["payment_methods"]: 23 | user["payment_methods"][payment_id] = { 24 | "source": "certificate", 25 | "amount": amount, 26 | "id": payment_id, 27 | } 28 | return f"Certificate {payment_id} added to user {user_id} with amount {amount}." 29 | 30 | @staticmethod 31 | def get_info() -> Dict[str, Any]: 32 | return { 33 | "type": "function", 34 | "function": { 35 | "name": "send_certificate", 36 | "description": "Send a certificate to a user. Be careful!", 37 | "parameters": { 38 | "type": "object", 39 | "properties": { 40 | "user_id": { 41 | "type": "string", 42 | "description": "The ID of the user to book the reservation, such as 'sara_doe_496'.", 43 | }, 44 | "amount": { 45 | "type": "number", 46 | "description": "Certificate amount to send.", 47 | }, 48 | }, 49 | "required": ["user_id", "amount"], 50 | }, 51 | }, 52 | } 53 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/airline/tools/think.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from typing import Any, Dict 4 | from tau_bench.envs.tool import Tool 5 | 6 | 7 | class Think(Tool): 8 | @staticmethod 9 | def invoke(data: Dict[str, Any], thought: str) -> str: 10 | return "" 11 | 12 | @staticmethod 13 | def get_info() -> Dict[str, Any]: 14 | return { 15 | "type": "function", 16 | "function": { 17 | "name": "think", 18 | "description": "Use the tool to think about something. It will not obtain new information or change the database, but just append the thought to the log. 
Use it when complex reasoning is needed.", 19 | "parameters": { 20 | "type": "object", 21 | "properties": { 22 | "thought": { 23 | "type": "string", 24 | "description": "A thought to think about.", 25 | }, 26 | }, 27 | "required": ["thought"], 28 | }, 29 | }, 30 | } 31 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/airline/tools/transfer_to_human_agents.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from typing import Any, Dict 4 | from tau_bench.envs.tool import Tool 5 | 6 | 7 | class TransferToHumanAgents(Tool): 8 | @staticmethod 9 | def invoke( 10 | data: Dict[str, Any], 11 | summary: str, 12 | ) -> str: 13 | return "Transfer successful" 14 | 15 | @staticmethod 16 | def get_info() -> Dict[str, Any]: 17 | return { 18 | "type": "function", 19 | "function": { 20 | "name": "transfer_to_human_agents", 21 | "description": "Transfer the user to a human agent, with a summary of the user's issue. Only transfer if the user explicitly asks for a human agent, or if the user's issue cannot be resolved by the agent with the available tools.", 22 | "parameters": { 23 | "type": "object", 24 | "properties": { 25 | "summary": { 26 | "type": "string", 27 | "description": "A summary of the user's issue.", 28 | }, 29 | }, 30 | "required": [ 31 | "summary", 32 | ], 33 | }, 34 | }, 35 | } 36 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/airline/tools/update_reservation_baggages.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class UpdateReservationBaggages(Tool): 9 | @staticmethod 10 | def invoke( 11 | data: Dict[str, Any], 12 | reservation_id: str, 13 | total_baggages: int, 14 | nonfree_baggages: int, 15 | payment_id: str, 16 | ) -> str: 17 | users, reservations = data["users"], data["reservations"] 18 | if reservation_id not in reservations: 19 | return "Error: reservation not found" 20 | reservation = reservations[reservation_id] 21 | 22 | total_price = 50 * max(0, nonfree_baggages - reservation["nonfree_baggages"]) 23 | if payment_id not in users[reservation["user_id"]]["payment_methods"]: 24 | return "Error: payment method not found" 25 | payment_method = users[reservation["user_id"]]["payment_methods"][payment_id] 26 | if payment_method["source"] == "certificate": 27 | return "Error: certificate cannot be used to update reservation" 28 | elif ( 29 | payment_method["source"] == "gift_card" 30 | and payment_method["amount"] < total_price 31 | ): 32 | return "Error: gift card balance is not enough" 33 | 34 | reservation["total_baggages"] = total_baggages 35 | reservation["nonfree_baggages"] = nonfree_baggages 36 | if payment_method["source"] == "gift_card": 37 | payment_method["amount"] -= total_price 38 | 39 | if total_price != 0: 40 | reservation["payment_history"].append( 41 | { 42 | "payment_id": payment_id, 43 | "amount": total_price, 44 | } 45 | ) 46 | 47 | return json.dumps(reservation) 48 | 49 | @staticmethod 50 | def get_info() -> Dict[str, Any]: 51 | return { 52 | "type": "function", 53 | "function": { 54 | "name": "update_reservation_baggages", 55 | "description": "Update the baggage information of a reservation.", 56 | "parameters": { 57 | "type": "object", 58 | "properties": { 59 | "reservation_id": { 60 | "type": "string", 61 | 
"description": "The reservation ID, such as 'ZFA04Y'.", 62 | }, 63 | "total_baggages": { 64 | "type": "integer", 65 | "description": "The updated total number of baggage items included in the reservation.", 66 | }, 67 | "nonfree_baggages": { 68 | "type": "integer", 69 | "description": "The updated number of non-free baggage items included in the reservation.", 70 | }, 71 | "payment_id": { 72 | "type": "string", 73 | "description": "The payment id stored in user profile, such as 'credit_card_7815826', 'gift_card_7815826', 'certificate_7815826'.", 74 | }, 75 | }, 76 | "required": [ 77 | "reservation_id", 78 | "total_baggages", 79 | "nonfree_baggages", 80 | "payment_id", 81 | ], 82 | }, 83 | }, 84 | } 85 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/airline/tools/update_reservation_passengers.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict, List 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class UpdateReservationPassengers(Tool): 9 | @staticmethod 10 | def invoke( 11 | data: Dict[str, Any], 12 | reservation_id: str, 13 | passengers: List[Dict[str, Any]], 14 | ) -> str: 15 | reservations = data["reservations"] 16 | if reservation_id not in reservations: 17 | return "Error: reservation not found" 18 | reservation = reservations[reservation_id] 19 | if len(passengers) != len(reservation["passengers"]): 20 | return "Error: number of passengers does not match" 21 | reservation["passengers"] = passengers 22 | return json.dumps(reservation) 23 | 24 | @staticmethod 25 | def get_info() -> Dict[str, Any]: 26 | return { 27 | "type": "function", 28 | "function": { 29 | "name": "update_reservation_passengers", 30 | "description": "Update the passenger information of a reservation.", 31 | "parameters": { 32 | "type": "object", 33 | "properties": { 34 | "reservation_id": { 35 | "type": "string", 36 | "description": "The reservation ID, such as 'ZFA04Y'.", 37 | }, 38 | "passengers": { 39 | "type": "array", 40 | "description": "An array of objects containing details about each passenger.", 41 | "items": { 42 | "type": "object", 43 | "properties": { 44 | "first_name": { 45 | "type": "string", 46 | "description": "The first name of the passenger, such as 'Noah'.", 47 | }, 48 | "last_name": { 49 | "type": "string", 50 | "description": "The last name of the passenger, such as 'Brown'.", 51 | }, 52 | "dob": { 53 | "type": "string", 54 | "description": "The date of birth of the passenger in the format 'YYYY-MM-DD', such as '1990-01-01'.", 55 | }, 56 | }, 57 | "required": ["first_name", "last_name", "dob"], 58 | }, 59 | }, 60 | }, 61 | "required": ["reservation_id", "passengers"], 62 | }, 63 | }, 64 | } 65 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/airline/wiki.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import os 4 | 5 | FOLDER_PATH = os.path.dirname(__file__) 6 | 7 | with open(os.path.join(FOLDER_PATH, "wiki.md"), "r") as f: 8 | WIKI = f.read() 9 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/retail/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from tau_bench.envs.retail.env import MockRetailDomainEnv as MockRetailDomainEnv 4 | 
-------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/retail/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | import os 5 | from typing import Any 6 | 7 | FOLDER_PATH = os.path.dirname(__file__) 8 | 9 | 10 | def load_data() -> dict[str, Any]: 11 | with open(os.path.join(FOLDER_PATH, "orders.json")) as f: 12 | order_data = json.load(f) 13 | with open(os.path.join(FOLDER_PATH, "products.json")) as f: 14 | product_data = json.load(f) 15 | with open(os.path.join(FOLDER_PATH, "users.json")) as f: 16 | user_data = json.load(f) 17 | return { 18 | "orders": order_data, 19 | "products": product_data, 20 | "users": user_data, 21 | } 22 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/retail/data/readme.md: -------------------------------------------------------------------------------- 1 | # Mock Data Generation 2 | 3 | ## Current Mock Data for the Benchmark 4 | Feel free to use some of the data for other purposes. 5 | - `users.json`: a database of users with their emails, addresses, and orders 6 | - `products.json`: a database of products, where each product has variants (e.g., size, color). 7 | - `orders.json`: a database of orders that can be operated upon. 8 | 9 | 10 | Check `../tools` for mock APIs on top of current mock data. 11 | 12 | 13 | ### Experience of Mock Data Generation 14 | 15 | Read our paper to learn more about the generation process for each database. In general, it involves the following stages: 16 | 17 | 1. Design the type and schema of each database. Can use GPT for co-brainstorming, but this has to be human-decided, as it is the foundation of everything else. 18 | 2. For each schema, figure out which parts can be programmatically generated and which parts need GPT. For example, 19 | - Product types (shirt, lamp, pen) and user names (Sara, John, Noah) need GPT generation 20 | - Product price and shipping date can be generated via code 21 | 3. Use GPT to generate seed data (first names, last names, addresses, cities, etc.), then use a program to compose them with other code-generated data (a minimal sketch of this composition step follows below). Can use GPT to help write the code for this part, but I think code-based database construction is more reliable than GPT-based database construction (e.g., giving some example user profiles and asking GPT to generate more runs into issues with diversity and reliability).
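A minimal sketch of that stage-3 composition step, under the assumption that GPT has already produced small seed lists; all values, field names, and helpers here are hypothetical illustrations, not the actual generation code:

```python
# Hypothetical stage-3 composition: GPT-made seed lists + code-generated fields.
import random

random.seed(0)  # deterministic, reproducible database construction

# Seed data of the kind GPT would generate up front (values are made up):
FIRST_NAMES = ["Sara", "John", "Noah"]
LAST_NAMES = ["Doe", "Smith", "Brown"]
CITIES = ["San Francisco", "Chicago", "Boston"]

def make_user(user_num: int) -> dict:
    first = random.choice(FIRST_NAMES)
    last = random.choice(LAST_NAMES)
    return {
        # ids like 'sara_doe_496' compose seed names with a code-generated number
        "user_id": f"{first.lower()}_{last.lower()}_{user_num}",
        "name": {"first_name": first, "last_name": last},
        "address": {
            "city": random.choice(CITIES),
            "zip": f"{random.randint(10000, 99999)}",  # purely code-generated
        },
        "orders": [],
    }

users = {u["user_id"]: u for u in (make_user(100 + i) for i in range(3))}
```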
22 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/retail/env.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from tau_bench.envs.base import Env 4 | from tau_bench.envs.retail.data import load_data 5 | from tau_bench.envs.retail.rules import RULES 6 | from tau_bench.envs.retail.tools import ALL_TOOLS 7 | from tau_bench.envs.retail.wiki import WIKI 8 | from typing import Optional 9 | 10 | 11 | class MockRetailDomainEnv(Env): 12 | def __init__( 13 | self, 14 | user_strategy: str = "llm", 15 | user_model: str = "gpt-4o", 16 | user_provider: Optional[str] = None, 17 | task_split: str = "test", 18 | task_index: Optional[int] = None, 19 | ): 20 | if task_split == "test": 21 | from tau_bench.envs.retail.tasks_test import TASKS_TEST as tasks 22 | elif task_split == "train": 23 | from tau_bench.envs.retail.tasks_train import TASKS_TRAIN as tasks 24 | elif task_split == "dev": 25 | from tau_bench.envs.retail.tasks_dev import TASKS_DEV as tasks 26 | else: 27 | raise ValueError(f"Unknown task split: {task_split}") 28 | super().__init__( 29 | data_load_func=load_data, 30 | tools=ALL_TOOLS, 31 | tasks=tasks, 32 | wiki=WIKI, 33 | rules=RULES, 34 | user_strategy=user_strategy, 35 | user_model=user_model, 36 | user_provider=user_provider, 37 | task_index=task_index, 38 | ) 39 | self.terminate_tools = ["transfer_to_human_agents"] 40 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/retail/rules.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | RULES = [ 4 | "You are a customer service representative for an online retail company. 
You are chatting with a customer, and you can call tools or respond to the user.", 5 | "The agent should always first confirm the user id by email or name+zip before proceeding with any task.", 6 | "The agent should not proceed with any task if the user id is not found.", 7 | "For any change to the backend database, e.g., address update, refund, or order cancellation, the agent must confirm the transaction details with the user and ask for permission, and get explicit authorization (yes) to proceed.", 8 | "The agent should solve the user task given the tools, without transferring to a human agent.", 9 | "The agent should not make up any information or knowledge not provided from the user or the tools.", 10 | "The agent should at most make one tool call at a time, and if the agent makes a tool call, it does not respond to the user at the same time.", 11 | ] 12 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/retail/tools/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from .calculate import Calculate 4 | from .cancel_pending_order import CancelPendingOrder 5 | from .exchange_delivered_order_items import ExchangeDeliveredOrderItems 6 | from .find_user_id_by_email import FindUserIdByEmail 7 | from .find_user_id_by_name_zip import FindUserIdByNameZip 8 | from .get_order_details import GetOrderDetails 9 | from .get_product_details import GetProductDetails 10 | from .get_user_details import GetUserDetails 11 | from .list_all_product_types import ListAllProductTypes 12 | from .modify_pending_order_address import ModifyPendingOrderAddress 13 | from .modify_pending_order_items import ModifyPendingOrderItems 14 | from .modify_pending_order_payment import ModifyPendingOrderPayment 15 | from .modify_user_address import ModifyUserAddress 16 | from .return_delivered_order_items import ReturnDeliveredOrderItems 17 | from .think import Think 18 | from .transfer_to_human_agents import TransferToHumanAgents 19 | 20 | 21 | ALL_TOOLS = [ 22 | Calculate, 23 | CancelPendingOrder, 24 | ExchangeDeliveredOrderItems, 25 | FindUserIdByEmail, 26 | FindUserIdByNameZip, 27 | GetOrderDetails, 28 | GetProductDetails, 29 | GetUserDetails, 30 | ListAllProductTypes, 31 | ModifyPendingOrderAddress, 32 | ModifyPendingOrderItems, 33 | ModifyPendingOrderPayment, 34 | ModifyUserAddress, 35 | ReturnDeliveredOrderItems, 36 | Think, 37 | TransferToHumanAgents, 38 | ] 39 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/retail/tools/calculate.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from typing import Any, Dict 4 | from tau_bench.envs.tool import Tool 5 | 6 | 7 | class Calculate(Tool): 8 | @staticmethod 9 | def invoke(data: Dict[str, Any], expression: str) -> str: 10 | if not all(char in "0123456789+-*/(). 
" for char in expression): 11 | return "Error: invalid characters in expression" 12 | try: 13 | # Evaluate the mathematical expression safely 14 | return str(round(float(eval(expression, {"__builtins__": None}, {})), 2)) 15 | except Exception as e: 16 | return f"Error: {e}" 17 | 18 | @staticmethod 19 | def get_info() -> Dict[str, Any]: 20 | return { 21 | "type": "function", 22 | "function": { 23 | "name": "calculate", 24 | "description": "Calculate the result of a mathematical expression.", 25 | "parameters": { 26 | "type": "object", 27 | "properties": { 28 | "expression": { 29 | "type": "string", 30 | "description": "The mathematical expression to calculate, such as '2 + 2'. The expression can contain numbers, operators (+, -, *, /), parentheses, and spaces.", 31 | }, 32 | }, 33 | "required": ["expression"], 34 | }, 35 | }, 36 | } 37 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/retail/tools/cancel_pending_order.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class CancelPendingOrder(Tool): 9 | @staticmethod 10 | def invoke(data: Dict[str, Any], order_id: str, reason: str) -> str: 11 | # check order exists and is pending 12 | orders = data["orders"] 13 | if order_id not in orders: 14 | return "Error: order not found" 15 | order = orders[order_id] 16 | if order["status"] != "pending": 17 | return "Error: non-pending order cannot be cancelled" 18 | 19 | # check reason 20 | if reason not in ["no longer needed", "ordered by mistake"]: 21 | return "Error: invalid reason" 22 | 23 | # handle refund 24 | refunds = [] 25 | for payment in order["payment_history"]: 26 | payment_id = payment["payment_method_id"] 27 | refund = { 28 | "transaction_type": "refund", 29 | "amount": payment["amount"], 30 | "payment_method_id": payment_id, 31 | } 32 | refunds.append(refund) 33 | if "gift_card" in payment_id: # refund to gift card immediately 34 | payment_method = data["users"][order["user_id"]]["payment_methods"][ 35 | payment_id 36 | ] 37 | payment_method["balance"] += payment["amount"] 38 | payment_method["balance"] = round(payment_method["balance"], 2) 39 | 40 | # update order status 41 | order["status"] = "cancelled" 42 | order["cancel_reason"] = reason 43 | order["payment_history"].extend(refunds) 44 | 45 | return json.dumps(order) 46 | 47 | @staticmethod 48 | def get_info() -> Dict[str, Any]: 49 | return { 50 | "type": "function", 51 | "function": { 52 | "name": "cancel_pending_order", 53 | "description": ( 54 | "Cancel a pending order. If the order is already processed or delivered, " 55 | "it cannot be cancelled. The agent needs to explain the cancellation detail " 56 | "and ask for explicit user confirmation (yes/no) to proceed. If the user confirms, " 57 | "the order status will be changed to 'cancelled' and the payment will be refunded. " 58 | "The refund will be added to the user's gift card balance immediately if the payment " 59 | "was made using a gift card, otherwise the refund would take 5-7 business days to process. " 60 | "The function returns the order details after the cancellation." 61 | ), 62 | "parameters": { 63 | "type": "object", 64 | "properties": { 65 | "order_id": { 66 | "type": "string", 67 | "description": "The order id, such as '#W0000000'. 
Be careful there is a '#' symbol at the beginning of the order id.", 68 | }, 69 | "reason": { 70 | "type": "string", 71 | "enum": ["no longer needed", "ordered by mistake"], 72 | "description": "The reason for cancellation, which should be either 'no longer needed' or 'ordered by mistake'.", 73 | }, 74 | }, 75 | "required": ["order_id", "reason"], 76 | }, 77 | }, 78 | } 79 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/retail/tools/find_user_id_by_email.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from typing import Any, Dict 4 | from tau_bench.envs.tool import Tool 5 | 6 | 7 | class FindUserIdByEmail(Tool): 8 | @staticmethod 9 | def invoke(data: Dict[str, Any], email: str) -> str: 10 | users = data["users"] 11 | for user_id, profile in users.items(): 12 | if profile["email"].lower() == email.lower(): 13 | return user_id 14 | return "Error: user not found" 15 | 16 | @staticmethod 17 | def get_info() -> Dict[str, Any]: 18 | return { 19 | "type": "function", 20 | "function": { 21 | "name": "find_user_id_by_email", 22 | "description": "Find user id by email. If the user is not found, the function will return an error message.", 23 | "parameters": { 24 | "type": "object", 25 | "properties": { 26 | "email": { 27 | "type": "string", 28 | "description": "The email of the user, such as 'something@example.com'.", 29 | }, 30 | }, 31 | "required": ["email"], 32 | }, 33 | }, 34 | } 35 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/retail/tools/find_user_id_by_name_zip.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from typing import Any, Dict 4 | from tau_bench.envs.tool import Tool 5 | 6 | 7 | class FindUserIdByNameZip(Tool): 8 | @staticmethod 9 | def invoke(data: Dict[str, Any], first_name: str, last_name: str, zip: str) -> str: 10 | users = data["users"] 11 | for user_id, profile in users.items(): 12 | if ( 13 | profile["name"]["first_name"].lower() == first_name.lower() 14 | and profile["name"]["last_name"].lower() == last_name.lower() 15 | and profile["address"]["zip"] == zip 16 | ): 17 | return user_id 18 | return "Error: user not found" 19 | 20 | @staticmethod 21 | def get_info() -> Dict[str, Any]: 22 | return { 23 | "type": "function", 24 | "function": { 25 | "name": "find_user_id_by_name_zip", 26 | "description": ( 27 | "Find user id by first name, last name, and zip code. If the user is not found, the function " 28 | "will return an error message. By default, find user id by email, and only call this function " 29 | "if the user is not found by email or cannot remember email." 
30 | ), 31 | "parameters": { 32 | "type": "object", 33 | "properties": { 34 | "first_name": { 35 | "type": "string", 36 | "description": "The first name of the customer, such as 'John'.", 37 | }, 38 | "last_name": { 39 | "type": "string", 40 | "description": "The last name of the customer, such as 'Doe'.", 41 | }, 42 | "zip": { 43 | "type": "string", 44 | "description": "The zip code of the customer, such as '12345'.", 45 | }, 46 | }, 47 | "required": ["first_name", "last_name", "zip"], 48 | }, 49 | }, 50 | } 51 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/retail/tools/get_order_details.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class GetOrderDetails(Tool): 9 | @staticmethod 10 | def invoke(data: Dict[str, Any], order_id: str) -> str: 11 | orders = data["orders"] 12 | if order_id in orders: 13 | return json.dumps(orders[order_id]) 14 | return "Error: order not found" 15 | 16 | @staticmethod 17 | def get_info() -> Dict[str, Any]: 18 | return { 19 | "type": "function", 20 | "function": { 21 | "name": "get_order_details", 22 | "description": "Get the status and details of an order.", 23 | "parameters": { 24 | "type": "object", 25 | "properties": { 26 | "order_id": { 27 | "type": "string", 28 | "description": "The order id, such as '#W0000000'. Be careful there is a '#' symbol at the beginning of the order id.", 29 | }, 30 | }, 31 | "required": ["order_id"], 32 | }, 33 | }, 34 | } 35 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/retail/tools/get_product_details.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class GetProductDetails(Tool): 9 | @staticmethod 10 | def invoke(data: Dict[str, Any], product_id: str) -> str: 11 | products = data["products"] 12 | if product_id in products: 13 | return json.dumps(products[product_id]) 14 | return "Error: product not found" 15 | 16 | @staticmethod 17 | def get_info() -> Dict[str, Any]: 18 | return { 19 | "type": "function", 20 | "function": { 21 | "name": "get_product_details", 22 | "description": "Get the inventory details of a product.", 23 | "parameters": { 24 | "type": "object", 25 | "properties": { 26 | "product_id": { 27 | "type": "string", 28 | "description": "The product id, such as '6086499569'. 
Be careful the product id is different from the item id.", 29 | }, 30 | }, 31 | "required": ["product_id"], 32 | }, 33 | }, 34 | } 35 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/retail/tools/get_user_details.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class GetUserDetails(Tool): 9 | @staticmethod 10 | def invoke(data: Dict[str, Any], user_id: str) -> str: 11 | users = data["users"] 12 | if user_id in users: 13 | return json.dumps(users[user_id]) 14 | return "Error: user not found" 15 | 16 | @staticmethod 17 | def get_info() -> Dict[str, Any]: 18 | return { 19 | "type": "function", 20 | "function": { 21 | "name": "get_user_details", 22 | "description": "Get the details of a user.", 23 | "parameters": { 24 | "type": "object", 25 | "properties": { 26 | "user_id": { 27 | "type": "string", 28 | "description": "The user id, such as 'sara_doe_496'.", 29 | }, 30 | }, 31 | "required": ["user_id"], 32 | }, 33 | }, 34 | } 35 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/retail/tools/list_all_product_types.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class ListAllProductTypes(Tool): 9 | @staticmethod 10 | def invoke(data: Dict[str, Any]) -> str: 11 | products = data["products"] 12 | product_dict = { 13 | product["name"]: product["product_id"] for product in products.values() 14 | } 15 | product_dict = dict(sorted(product_dict.items())) 16 | return json.dumps(product_dict) 17 | 18 | @staticmethod 19 | def get_info() -> Dict[str, Any]: 20 | return { 21 | "type": "function", 22 | "function": { 23 | "name": "list_all_product_types", 24 | "description": "List the name and product id of all product types. Each product type has a variety of different items with unique item ids and options. 
There are only 50 product types in the store.", 25 | "parameters": { 26 | "type": "object", 27 | "properties": {}, 28 | "required": [], 29 | }, 30 | }, 31 | } 32 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/retail/tools/modify_pending_order_address.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class ModifyPendingOrderAddress(Tool): 9 | @staticmethod 10 | def invoke( 11 | data: Dict[str, Any], 12 | order_id: str, 13 | address1: str, 14 | address2: str, 15 | city: str, 16 | state: str, 17 | country: str, 18 | zip: str, 19 | ) -> str: 20 | # Check if the order exists and is pending 21 | orders = data["orders"] 22 | if order_id not in orders: 23 | return "Error: order not found" 24 | order = orders[order_id] 25 | if order["status"] != "pending": 26 | return "Error: non-pending order cannot be modified" 27 | 28 | # Modify the address 29 | order["address"] = { 30 | "address1": address1, 31 | "address2": address2, 32 | "city": city, 33 | "state": state, 34 | "country": country, 35 | "zip": zip, 36 | } 37 | return json.dumps(order) 38 | 39 | @staticmethod 40 | def get_info() -> Dict[str, Any]: 41 | return { 42 | "type": "function", 43 | "function": { 44 | "name": "modify_pending_order_address", 45 | "description": "Modify the shipping address of a pending order. The agent needs to explain the modification detail and ask for explicit user confirmation (yes/no) to proceed.", 46 | "parameters": { 47 | "type": "object", 48 | "properties": { 49 | "order_id": { 50 | "type": "string", 51 | "description": "The order id, such as '#W0000000'. 
Be careful there is a '#' symbol at the beginning of the order id.", 52 | }, 53 | "address1": { 54 | "type": "string", 55 | "description": "The first line of the address, such as '123 Main St'.", 56 | }, 57 | "address2": { 58 | "type": "string", 59 | "description": "The second line of the address, such as 'Apt 1' or ''.", 60 | }, 61 | "city": { 62 | "type": "string", 63 | "description": "The city, such as 'San Francisco'.", 64 | }, 65 | "state": { 66 | "type": "string", 67 | "description": "The province, such as 'CA'.", 68 | }, 69 | "country": { 70 | "type": "string", 71 | "description": "The country, such as 'USA'.", 72 | }, 73 | "zip": { 74 | "type": "string", 75 | "description": "The zip code, such as '12345'.", 76 | }, 77 | }, 78 | "required": [ 79 | "order_id", 80 | "address1", 81 | "address2", 82 | "city", 83 | "state", 84 | "country", 85 | "zip", 86 | ], 87 | }, 88 | }, 89 | } 90 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/retail/tools/modify_pending_order_payment.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class ModifyPendingOrderPayment(Tool): 9 | @staticmethod 10 | def invoke( 11 | data: Dict[str, Any], 12 | order_id: str, 13 | payment_method_id: str, 14 | ) -> str: 15 | orders = data["orders"] 16 | 17 | # Check if the order exists and is pending 18 | if order_id not in orders: 19 | return "Error: order not found" 20 | order = orders[order_id] 21 | if order["status"] != "pending": 22 | return "Error: non-pending order cannot be modified" 23 | 24 | # Check if the payment method exists 25 | if payment_method_id not in data["users"][order["user_id"]]["payment_methods"]: 26 | return "Error: payment method not found" 27 | 28 | # Check that the payment history should only have one payment 29 | if ( 30 | len(order["payment_history"]) > 1 31 | or order["payment_history"][0]["transaction_type"] != "payment" 32 | ): 33 | return "Error: there should be exactly one payment for a pending order" 34 | 35 | # Check that the payment method is different 36 | if order["payment_history"][0]["payment_method_id"] == payment_method_id: 37 | return ( 38 | "Error: the new payment method should be different from the current one" 39 | ) 40 | 41 | amount = order["payment_history"][0]["amount"] 42 | payment_method = data["users"][order["user_id"]]["payment_methods"][ 43 | payment_method_id 44 | ] 45 | 46 | # Check if the new payment method has enough balance if it is a gift card 47 | if ( 48 | payment_method["source"] == "gift_card" 49 | and payment_method["balance"] < amount 50 | ): 51 | return "Error: insufficient gift card balance to pay for the order" 52 | 53 | # Modify the payment method 54 | order["payment_history"].extend( 55 | [ 56 | { 57 | "transaction_type": "payment", 58 | "amount": amount, 59 | "payment_method_id": payment_method_id, 60 | }, 61 | { 62 | "transaction_type": "refund", 63 | "amount": amount, 64 | "payment_method_id": order["payment_history"][0][ 65 | "payment_method_id" 66 | ], 67 | }, 68 | ] 69 | ) 70 | 71 | # If payment is made by gift card, update the balance 72 | if payment_method["source"] == "gift_card": 73 | payment_method["balance"] -= amount 74 | payment_method["balance"] = round(payment_method["balance"], 2) 75 | 76 | # If refund is made to a gift card, update the balance 77 | if "gift_card" in 
order["payment_history"][0]["payment_method_id"]: 78 | old_payment_method = data["users"][order["user_id"]]["payment_methods"][ 79 | order["payment_history"][0]["payment_method_id"] 80 | ] 81 | old_payment_method["balance"] += amount 82 | old_payment_method["balance"] = round(old_payment_method["balance"], 2) 83 | 84 | return json.dumps(order) 85 | 86 | @staticmethod 87 | def get_info() -> Dict[str, Any]: 88 | return { 89 | "type": "function", 90 | "function": { 91 | "name": "modify_pending_order_payment", 92 | "description": "Modify the payment method of a pending order. The agent needs to explain the modification detail and ask for explicit user confirmation (yes/no) to proceed.", 93 | "parameters": { 94 | "type": "object", 95 | "properties": { 96 | "order_id": { 97 | "type": "string", 98 | "description": "The order id, such as '#W0000000'. Be careful there is a '#' symbol at the beginning of the order id.", 99 | }, 100 | "payment_method_id": { 101 | "type": "string", 102 | "description": "The payment method id to pay or receive refund for the item price difference, such as 'gift_card_0000000' or 'credit_card_0000000'. These can be looked up from the user or order details.", 103 | }, 104 | }, 105 | "required": [ 106 | "order_id", 107 | "payment_method_id", 108 | ], 109 | }, 110 | }, 111 | } 112 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/retail/tools/modify_user_address.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class ModifyUserAddress(Tool): 9 | @staticmethod 10 | def invoke( 11 | data: Dict[str, Any], 12 | user_id: str, 13 | address1: str, 14 | address2: str, 15 | city: str, 16 | state: str, 17 | country: str, 18 | zip: str, 19 | ) -> str: 20 | users = data["users"] 21 | if user_id not in users: 22 | return "Error: user not found" 23 | user = users[user_id] 24 | user["address"] = { 25 | "address1": address1, 26 | "address2": address2, 27 | "city": city, 28 | "state": state, 29 | "country": country, 30 | "zip": zip, 31 | } 32 | return json.dumps(user) 33 | 34 | @staticmethod 35 | def get_info() -> Dict[str, Any]: 36 | return { 37 | "type": "function", 38 | "function": { 39 | "name": "modify_user_address", 40 | "description": "Modify the default address of a user. 
The agent needs to explain the modification detail and ask for explicit user confirmation (yes/no) to proceed.", 41 | "parameters": { 42 | "type": "object", 43 | "properties": { 44 | "user_id": { 45 | "type": "string", 46 | "description": "The user id, such as 'sara_doe_496'.", 47 | }, 48 | "address1": { 49 | "type": "string", 50 | "description": "The first line of the address, such as '123 Main St'.", 51 | }, 52 | "address2": { 53 | "type": "string", 54 | "description": "The second line of the address, such as 'Apt 1' or ''.", 55 | }, 56 | "city": { 57 | "type": "string", 58 | "description": "The city, such as 'San Francisco'.", 59 | }, 60 | "state": { 61 | "type": "string", 62 | "description": "The province, such as 'CA'.", 63 | }, 64 | "country": { 65 | "type": "string", 66 | "description": "The country, such as 'USA'.", 67 | }, 68 | "zip": { 69 | "type": "string", 70 | "description": "The zip code, such as '12345'.", 71 | }, 72 | }, 73 | "required": [ 74 | "user_id", 75 | "address1", 76 | "address2", 77 | "city", 78 | "state", 79 | "country", 80 | "zip", 81 | ], 82 | }, 83 | }, 84 | } 85 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/retail/tools/return_delivered_order_items.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import json 4 | from typing import Any, Dict, List 5 | from tau_bench.envs.tool import Tool 6 | 7 | 8 | class ReturnDeliveredOrderItems(Tool): 9 | @staticmethod 10 | def invoke( 11 | data: Dict[str, Any], order_id: str, item_ids: List[str], payment_method_id: str 12 | ) -> str: 13 | orders = data["orders"] 14 | 15 | # Check if the order exists and is delivered 16 | if order_id not in orders: 17 | return "Error: order not found" 18 | order = orders[order_id] 19 | if order["status"] != "delivered": 20 | return "Error: non-delivered order cannot be returned" 21 | 22 | # Check if the payment method exists and is either the original payment method or a gift card 23 | if payment_method_id not in data["users"][order["user_id"]]["payment_methods"]: 24 | return "Error: payment method not found" 25 | if ( 26 | "gift_card" not in payment_method_id 27 | and payment_method_id != order["payment_history"][0]["payment_method_id"] 28 | ): 29 | return "Error: payment method should be either the original payment method or a gift card" 30 | 31 | # Check if the items to be returned exist (there could be duplicate items in either list) 32 | all_item_ids = [item["item_id"] for item in order["items"]] 33 | for item_id in item_ids: 34 | if item_ids.count(item_id) > all_item_ids.count(item_id): 35 | return "Error: some item not found" 36 | 37 | # Update the order status 38 | order["status"] = "return requested" 39 | order["return_items"] = sorted(item_ids) 40 | order["return_payment_method_id"] = payment_method_id 41 | 42 | return json.dumps(order) 43 | 44 | @staticmethod 45 | def get_info() -> Dict[str, Any]: 46 | return { 47 | "type": "function", 48 | "function": { 49 | "name": "return_delivered_order_items", 50 | "description": ( 51 | "Return some items of a delivered order. The order status will be changed to 'return requested'. " 52 | "The agent needs to explain the return detail and ask for explicit user confirmation (yes/no) to proceed. " 53 | "The user will receive follow-up email for how and where to return the item." 
54 | ), 55 | "parameters": { 56 | "type": "object", 57 | "properties": { 58 | "order_id": { 59 | "type": "string", 60 | "description": ( 61 | "The order id, such as '#W0000000'. Be careful there is a '#' symbol at the beginning of the order id." 62 | ), 63 | }, 64 | "item_ids": { 65 | "type": "array", 66 | "items": {"type": "string"}, 67 | "description": ( 68 | "The item ids to be returned, each such as '1008292230'. There could be duplicate items in the list." 69 | ), 70 | }, 71 | "payment_method_id": { 72 | "type": "string", 73 | "description": ( 74 | "The payment method id to pay or receive refund for the item price difference, such as 'gift_card_0000000' or 'credit_card_0000000'. " 75 | "These can be looked up from the user or order details." 76 | ), 77 | }, 78 | }, 79 | "required": ["order_id", "item_ids", "payment_method_id"], 80 | }, 81 | }, 82 | } 83 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/retail/tools/think.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from typing import Any, Dict 4 | from tau_bench.envs.tool import Tool 5 | 6 | 7 | class Think(Tool): 8 | @staticmethod 9 | def invoke(data: Dict[str, Any], thought: str) -> str: 10 | # This method does not change the state of the data; it simply returns an empty string. 11 | return "" 12 | 13 | @staticmethod 14 | def get_info() -> Dict[str, Any]: 15 | return { 16 | "type": "function", 17 | "function": { 18 | "name": "think", 19 | "description": ( 20 | "Use the tool to think about something. It will not obtain new information or change the database, " 21 | "but just append the thought to the log. Use it when complex reasoning or some cache memory is needed." 22 | ), 23 | "parameters": { 24 | "type": "object", 25 | "properties": { 26 | "thought": { 27 | "type": "string", 28 | "description": "A thought to think about.", 29 | }, 30 | }, 31 | "required": ["thought"], 32 | }, 33 | }, 34 | } 35 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/retail/tools/transfer_to_human_agents.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | from typing import Any, Dict 4 | from tau_bench.envs.tool import Tool 5 | 6 | 7 | class TransferToHumanAgents(Tool): 8 | @staticmethod 9 | def invoke(data: Dict[str, Any], summary: str) -> str: 10 | # This method simulates the transfer to a human agent. 11 | return "Transfer successful" 12 | 13 | @staticmethod 14 | def get_info() -> Dict[str, Any]: 15 | return { 16 | "type": "function", 17 | "function": { 18 | "name": "transfer_to_human_agents", 19 | "description": ( 20 | "Transfer the user to a human agent, with a summary of the user's issue. " 21 | "Only transfer if the user explicitly asks for a human agent, or if the user's issue cannot be resolved by the agent with the available tools." 
22 | ), 23 | "parameters": { 24 | "type": "object", 25 | "properties": { 26 | "summary": { 27 | "type": "string", 28 | "description": "A summary of the user's issue.", 29 | }, 30 | }, 31 | "required": ["summary"], 32 | }, 33 | }, 34 | } 35 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/retail/wiki.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import os 4 | 5 | FOLDER_PATH = os.path.dirname(__file__) 6 | 7 | with open(os.path.join(FOLDER_PATH, "wiki.md"), "r") as f: 8 | WIKI = f.read() 9 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/tool.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Any 3 | 4 | 5 | class Tool(abc.ABC): 6 | @staticmethod 7 | def invoke(*args, **kwargs): 8 | raise NotImplementedError 9 | 10 | @staticmethod 11 | def get_info() -> dict[str, Any]: 12 | raise NotImplementedError 13 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/envs/user.py: -------------------------------------------------------------------------------- 1 | # Copyright Sierra 2 | 3 | import abc 4 | from litellm import completion 5 | 6 | from typing import Optional, List, Dict, Any 7 | 8 | 9 | class BaseUserSimulationEnv(abc.ABC): 10 | metadata = {} 11 | 12 | @abc.abstractmethod 13 | def reset(self, instruction: Optional[str] = None) -> str: 14 | raise NotImplementedError 15 | 16 | @abc.abstractmethod 17 | def step(self, content: str) -> str: 18 | raise NotImplementedError 19 | 20 | @abc.abstractmethod 21 | def get_total_cost(self) -> float: 22 | raise NotImplementedError 23 | 24 | 25 | class HumanUserSimulationEnv(BaseUserSimulationEnv): 26 | def reset(self, instruction: str) -> str: 27 | return input(f"{instruction}\n") 28 | 29 | def step(self, content: str) -> str: 30 | return input(f"{content}\n") 31 | 32 | def get_total_cost(self) -> float: 33 | return 0 34 | 35 | 36 | def build_system_prompt(instruction: Optional[str]) -> str: 37 | inst = ("\n\nInstruction: " + instruction + "\n") if instruction is not None else "" 38 | return f"""You are a user interacting with an agent.{inst} 39 | Rules: 40 | - Just generate one line at a time to simulate the user's message. 41 | - Do not give away all the instructions at once. Only provide the information that is necessary for the current step. 42 | - Do not hallucinate information that is not provided in the instruction. For example, if the agent asks for the order id but it is not mentioned in the instruction, do not make up an order id, just say you do not remember or have it. 43 | - If the instruction goal is satisfied, generate '###STOP###' as a standalone message without anything else to end the conversation. 44 | - Do not repeat the exact instruction in the conversation. Instead, use your own words to convey the same information. 45 | - Try to make the conversation as natural as possible, and stick to the personalities in the instruction. 
46 | """ 47 | 48 | 49 | class LLMUserSimulationEnv(BaseUserSimulationEnv): 50 | def __init__(self, model: str, provider: str) -> None: 51 | super().__init__() 52 | self.messages: List[Dict[str, Any]] = [] 53 | self.model = model 54 | self.provider = provider 55 | self.total_cost = 0.0 56 | self.reset() 57 | 58 | def reset(self, instruction: Optional[str] = None) -> str: 59 | self.messages = [ 60 | { 61 | "role": "system", 62 | "content": build_system_prompt(instruction=instruction), 63 | }, 64 | {"role": "user", "content": "Hi! How can I help you today?"}, 65 | ] 66 | res = completion( 67 | model=self.model, custom_llm_provider=self.provider, messages=self.messages, temperature=0.0 68 | ) 69 | message = res.choices[0].message 70 | self.messages.append(message.model_dump()) 71 | self.total_cost = res._hidden_params["response_cost"] 72 | return message.content 73 | 74 | def step(self, content: str) -> str: 75 | self.messages.append({"role": "user", "content": content}) 76 | res = completion( 77 | model=self.model, custom_llm_provider=self.provider, messages=self.messages, temperature=0.0 78 | ) 79 | message = res.choices[0].message 80 | self.messages.append(message.model_dump()) 81 | self.total_cost += res._hidden_params["response_cost"] 82 | return message.content 83 | 84 | def get_total_cost(self) -> float: 85 | return self.total_cost 86 | 87 | 88 | def load_user( 89 | user_strategy: str, model: Optional[str] = "gpt-4o", provider: Optional[str] = None 90 | ) -> BaseUserSimulationEnv: 91 | if user_strategy == "human": 92 | return HumanUserSimulationEnv() 93 | elif user_strategy == "llm": 94 | if model is None: 95 | raise ValueError("LLM user strategy requires a model") 96 | if provider is None: 97 | raise ValueError("LLM user strategy requires a model provider") 98 | return LLMUserSimulationEnv(model=model, provider=provider) 99 | else: 100 | raise ValueError(f"Unknown user strategy {user_strategy}") 101 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/model_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from tau_bench.model_utils.api.api import API as API 2 | from tau_bench.model_utils.api.api import BinaryClassifyDatapoint as BinaryClassifyDatapoint 3 | from tau_bench.model_utils.api.api import ClassifyDatapoint as ClassifyDatapoint 4 | from tau_bench.model_utils.api.api import GenerateDatapoint as GenerateDatapoint 5 | from tau_bench.model_utils.api.api import ParseDatapoint as ParseDatapoint 6 | from tau_bench.model_utils.api.api import ParseForceDatapoint as ParseForceDatapoint 7 | from tau_bench.model_utils.api.api import ScoreDatapoint as ScoreDatapoint 8 | from tau_bench.model_utils.api.api import default_api as default_api 9 | from tau_bench.model_utils.api.api import default_quick_api as default_quick_api 10 | from tau_bench.model_utils.api.datapoint import Datapoint as Datapoint 11 | from tau_bench.model_utils.api.datapoint import EvaluationResult as EvaluationResult 12 | from tau_bench.model_utils.api.datapoint import datapoint_factory as datapoint_factory 13 | from tau_bench.model_utils.api.datapoint import load_from_disk as load_from_disk 14 | from tau_bench.model_utils.api.exception import APIError as APIError 15 | from tau_bench.model_utils.api.sample import ( 16 | EnsembleSamplingStrategy as EnsembleSamplingStrategy, 17 | ) 18 | from tau_bench.model_utils.api.sample import ( 19 | MajoritySamplingStrategy as MajoritySamplingStrategy, 20 | ) 21 | from 
tau_bench.model_utils.api.sample import ( 22 | RedundantSamplingStrategy as RedundantSamplingStrategy, 23 | ) 24 | from tau_bench.model_utils.api.sample import RetrySamplingStrategy as RetrySamplingStrategy 25 | from tau_bench.model_utils.api.sample import SamplingStrategy as SamplingStrategy 26 | from tau_bench.model_utils.api.sample import SingleSamplingStrategy as SingleSamplingStrategy 27 | from tau_bench.model_utils.api.sample import ( 28 | UnanimousSamplingStrategy as UnanimousSamplingStrategy, 29 | ) 30 | from tau_bench.model_utils.api.sample import ( 31 | get_default_sampling_strategy as get_default_sampling_strategy, 32 | ) 33 | from tau_bench.model_utils.api.sample import ( 34 | set_default_sampling_strategy as set_default_sampling_strategy, 35 | ) 36 | from tau_bench.model_utils.model.chat import PromptSuffixStrategy as PromptSuffixStrategy 37 | from tau_bench.model_utils.model.exception import ModelError as ModelError 38 | from tau_bench.model_utils.model.general_model import GeneralModel as GeneralModel 39 | from tau_bench.model_utils.model.general_model import default_model as default_model 40 | from tau_bench.model_utils.model.general_model import model_factory as model_factory 41 | from tau_bench.model_utils.model.model import BinaryClassifyModel as BinaryClassifyModel 42 | from tau_bench.model_utils.model.model import ClassifyModel as ClassifyModel 43 | from tau_bench.model_utils.model.model import GenerateModel as GenerateModel 44 | from tau_bench.model_utils.model.model import ParseForceModel as ParseForceModel 45 | from tau_bench.model_utils.model.model import ParseModel as ParseModel 46 | from tau_bench.model_utils.model.model import Platform as Platform 47 | from tau_bench.model_utils.model.model import ScoreModel as ScoreModel 48 | from tau_bench.model_utils.model.openai import OpenAIModel as OpenAIModel 49 | from tau_bench.model_utils.model.utils import InputType as InputType 50 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/model_utils/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/josh-llm-simulation-training/9c3a6076751ca3ac07e20a01fe5da780e42d23b0/tau-bench-eval/tau_bench/model_utils/api/__init__.py -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/model_utils/api/_model_methods.py: -------------------------------------------------------------------------------- 1 | MODEL_METHODS = [ 2 | "classify", 3 | "binary_classify", 4 | "parse", 5 | "generate", 6 | "parse_force", 7 | "score", 8 | ] 9 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/model_utils/api/cache.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import hashlib 3 | import inspect 4 | import threading 5 | from collections import defaultdict 6 | from multiprocessing import Lock 7 | from typing import Any, Callable, TypeVar 8 | 9 | from pydantic import BaseModel 10 | 11 | T = TypeVar("T") 12 | 13 | USE_CACHE = True 14 | _USE_CACHE_LOCK = Lock() 15 | cache: dict[str, tuple[T, threading.Event]] = {} 16 | lock = threading.Lock() 17 | conditions = defaultdict(threading.Condition) 18 | 19 | 20 | def disable_cache(): 21 | global USE_CACHE 22 | with _USE_CACHE_LOCK: 23 | USE_CACHE = False 24 | 25 | 26 | def enable_cache(): 27 | global USE_CACHE 28 | with _USE_CACHE_LOCK: 29 | 
USE_CACHE = True 30 | 31 | 32 | def hash_item(item: Any) -> int: 33 | if isinstance(item, dict): 34 | return hash(tuple({k: hash_item(v) for k, v in sorted(item.items())})) 35 | elif isinstance(item, list): 36 | return hash(tuple([hash_item(x) for x in item])) 37 | elif isinstance(item, set): 38 | return hash(frozenset([hash_item(x) for x in item])) 39 | elif isinstance(item, tuple): 40 | return hash(tuple([hash_item(x) for x in item])) 41 | elif isinstance(item, BaseModel): 42 | return hash_item(item.model_json_schema()) 43 | return hash(item) 44 | 45 | 46 | def hash_func_call(func: Callable[..., Any], args: tuple[Any], kwargs: dict[str, Any]) -> str: 47 | bound_args = inspect.signature(func).bind(*args, **kwargs) 48 | bound_args.apply_defaults() 49 | standardized_args = sorted(bound_args.arguments.items()) 50 | arg_hash = hash_item(standardized_args) 51 | hashed_func = id(func) 52 | call = (hashed_func, arg_hash) 53 | return hashlib.md5(str(call).encode()).hexdigest() 54 | 55 | 56 | def cache_call_w_dedup(func: Callable[..., T]) -> Callable[..., T]: 57 | @functools.wraps(func) 58 | def wrapper(*args: Any, **kwargs: Any) -> T: 59 | if not USE_CACHE: 60 | return func(*args, **kwargs) 61 | key = hash_func_call(func=func, args=args, kwargs=kwargs) 62 | if key in cache: 63 | result, event = cache[key] 64 | if event.is_set(): 65 | return result 66 | else: 67 | with lock: 68 | cache[key] = (None, threading.Event()) 69 | 70 | condition = conditions[key] 71 | with condition: 72 | if cache[key][1].is_set(): 73 | return cache[key][0] 74 | if not cache[key][0]: 75 | try: 76 | result = func(*args, **kwargs) 77 | with lock: 78 | cache[key] = (result, threading.Event()) 79 | cache[key][1].set() 80 | except Exception as e: 81 | with lock: 82 | cache[key] = (e, threading.Event()) 83 | cache[key][1].set() 84 | raise e 85 | return cache[key][0] 86 | 87 | return wrapper 88 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/model_utils/api/exception.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | from concurrent.futures import ThreadPoolExecutor 5 | from typing import Any, Callable, TypeVar 6 | 7 | from tau_bench.model_utils.model.exception import ModelError, Result 8 | 9 | T = TypeVar("T") 10 | 11 | _REPORT_DIR = os.path.expanduser("~/.llm-primitives/log") 12 | 13 | 14 | def set_report_dir(path: str) -> None: 15 | global _REPORT_DIR 16 | _REPORT_DIR = path 17 | 18 | 19 | def get_report_dir() -> str: 20 | return _REPORT_DIR 21 | 22 | 23 | def log_report_to_disk(report: dict[str, Any], path: str) -> None: 24 | with open(path, "w") as f: 25 | json.dump(report, f, indent=4) 26 | 27 | 28 | def generate_report_location() -> str: 29 | if not os.path.exists(_REPORT_DIR): 30 | os.makedirs(_REPORT_DIR) 31 | return os.path.join(_REPORT_DIR, f"report-{time.time_ns()}.json") 32 | 33 | 34 | class APIError(Exception): 35 | def __init__(self, short_message: str, report: dict[str, Any] | None = None) -> None: 36 | self.report_path = generate_report_location() 37 | self.short_message = short_message 38 | self.report = report 39 | if self.report is not None: 40 | log_report_to_disk( 41 | report={"error_type": "APIError", "report": report}, path=self.report_path 42 | ) 43 | super().__init__(f"{short_message}\n\nSee the full report at {self.report_path}") 44 | 45 | 46 | def execute_and_filter_model_errors( 47 | funcs: list[Callable[[], T]], 48 | max_concurrency: int | None = None, 49 | 
) -> list[T] | list[ModelError]: 50 | def _invoke_w_o_llm_error(invocable: Callable[[], T]) -> Result: 51 | try: 52 | return Result(value=invocable(), error=None) 53 | except ModelError as e: 54 | return Result(value=None, error=e) 55 | 56 | with ThreadPoolExecutor(max_workers=max_concurrency) as executor: 57 | results = list(executor.map(_invoke_w_o_llm_error, funcs)) 58 | 59 | errors: list[ModelError] = [] 60 | values = [] 61 | for res in results: 62 | if res.error is not None: 63 | errors.append(res.error) 64 | else: 65 | values.append(res.value) 66 | if len(values) == 0: 67 | assert len(errors) > 0 68 | raise errors[0] 69 | return values 70 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/model_utils/api/logging.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import inspect 3 | import json 4 | from multiprocessing import Lock 5 | from typing import Any 6 | 7 | from pydantic import BaseModel 8 | 9 | from tau_bench.model_utils.api.sample import SamplingStrategy 10 | from tau_bench.model_utils.model.utils import optionalize_type 11 | 12 | log_files = {} 13 | 14 | 15 | def prep_for_json_serialization(obj: Any, from_parse_method: bool = False): 16 | # TODO: refine type annotations 17 | if isinstance(obj, (str, int, float, bool, type(None))): 18 | return obj 19 | elif isinstance(obj, dict): 20 | return {k: prep_for_json_serialization(v) for k, v in obj.items()} 21 | elif isinstance(obj, list): 22 | return [prep_for_json_serialization(v) for v in obj] 23 | elif isinstance(obj, tuple): 24 | return tuple(prep_for_json_serialization(v) for v in obj) 25 | elif isinstance(obj, set): 26 | return {prep_for_json_serialization(v) for v in obj} 27 | elif isinstance(obj, frozenset): 28 | return frozenset(prep_for_json_serialization(v) for v in obj) 29 | elif isinstance(obj, BaseModel): 30 | return obj.model_dump(mode="json") 31 | elif isinstance(obj, type) and issubclass(obj, BaseModel): 32 | if from_parse_method: 33 | optionalized_type = optionalize_type(obj) 34 | return optionalized_type.model_json_schema() 35 | else: 36 | return obj.model_json_schema() 37 | elif isinstance(obj, SamplingStrategy): 38 | return obj.__class__.__name__ 39 | else: 40 | raise TypeError(f"Object of type {type(obj)} is not JSON serializable") 41 | 42 | 43 | def log_call(func): 44 | @functools.wraps(func) 45 | def wrapper(self, *args, **kwargs): 46 | response = func(self, *args, **kwargs) 47 | log_file = getattr(self, "_log_file", None) 48 | if log_file is not None: 49 | if log_file not in log_files: 50 | log_files[log_file] = Lock() 51 | sig = inspect.signature(func) 52 | bound_args = sig.bind(self, *args, **kwargs) 53 | bound_args.apply_defaults() 54 | all_args = bound_args.arguments 55 | all_args.pop("self", None) 56 | 57 | cls_name = self.__class__.__name__ 58 | log_entry = { 59 | "cls_name": cls_name, 60 | "method_name": func.__name__, 61 | "kwargs": { 62 | k: prep_for_json_serialization( 63 | v, from_parse_method=func.__name__ in ["parse", "async_parse"] 64 | ) 65 | for k, v in all_args.items() 66 | }, 67 | "response": prep_for_json_serialization(response), 68 | } 69 | with log_files[log_file]: 70 | with open(log_file, "a") as f: 71 | f.write(f"{json.dumps(log_entry)}\n") 72 | return response 73 | 74 | return wrapper 75 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/model_utils/api/router.py: 
-------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from pydantic import BaseModel 4 | 5 | from tau_bench.model_utils.api.datapoint import Datapoint, ScoreDatapoint 6 | from tau_bench.model_utils.model.model import Model 7 | 8 | 9 | class RequestRouter(abc.ABC): 10 | @abc.abstractmethod 11 | def route(self, dp: Datapoint, available_models: list[Model]) -> Model: 12 | raise NotImplementedError 13 | 14 | 15 | class FirstModelRequestRouter(RequestRouter): 16 | def route(self, dp: Datapoint, available_models: list[Model]) -> Model: 17 | supporting_models = [model for model in available_models if model.supports_dp(dp)] 18 | if len(supporting_models) == 0: 19 | raise ValueError(f"No supporting models found from {available_models}") 20 | return supporting_models[0] 21 | 22 | 23 | class CapabilityScoreModel(abc.ABC): 24 | @abc.abstractmethod 25 | def score_dp(self, dp: Datapoint) -> float: 26 | raise NotImplementedError 27 | 28 | 29 | class PromptedLLMCapabilityScoreModel: 30 | def __init__(self, model: Model | None = None) -> None: 31 | if model is None: 32 | from tau_bench.model_utils.model.claude import ClaudeModel 33 | 34 | # claude is used as the default model as it is better at meta-level tasks 35 | model = ClaudeModel() 36 | self.model = model 37 | 38 | def score_dp(self, dp: Datapoint, examples: list[ScoreDatapoint] | None = None) -> float: 39 | return ( 40 | self.model.score( 41 | instruction="Score the task in the datapoint on a scale of 1 (least complex) to 10 (most complex).", 42 | text=f"----- start task -----\n{dp.model_dump_json()}\n----- end task -----", 43 | min=1, 44 | max=10, 45 | examples=examples, 46 | ) 47 | / 10.0 48 | ) 49 | 50 | 51 | class MinimumCapabilityRequestRouter(RequestRouter): 52 | def __init__(self, capability_score_model: CapabilityScoreModel) -> None: 53 | self.capability_score_model = capability_score_model 54 | 55 | def route(self, dp: Datapoint, available_models: list[Model]) -> Model: 56 | supporting_models = [model for model in available_models if model.supports_dp(dp)] 57 | if len(supporting_models) == 0: 58 | raise ValueError(f"No supporting models found from {available_models}") 59 | required_capability = self.capability_score_model.score_dp(dp) 60 | minimum_model: Model | None = None 61 | minimum_model_capability: float | None = None 62 | for model in supporting_models: 63 | capability = model.get_capability() 64 | if capability >= required_capability and ( 65 | minimum_model_capability is None or capability < minimum_model_capability 66 | ): 67 | minimum_model = model 68 | minimum_model_capability = capability 69 | if minimum_model is None: 70 | raise ValueError(f"No model found with capability >= {required_capability}") 71 | return minimum_model 72 | 73 | 74 | def request_router_factory( 75 | router_id: str, capability_score_model: CapabilityScoreModel | None = None 76 | ) -> RequestRouter: 77 | if router_id == "first-model": 78 | return FirstModelRequestRouter() 79 | elif router_id == "minimum-capability": 80 | if capability_score_model is None: 81 | raise ValueError("CapabilityScoreModel is required for minimum-capability router") 82 | return MinimumCapabilityRequestRouter(capability_score_model=capability_score_model) 83 | raise ValueError(f"Unknown router_id: {router_id}") 84 | 85 | 86 | def default_request_router() -> RequestRouter: 87 | return FirstModelRequestRouter() 88 | 89 | 90 | class RequestRouteDatapoint(BaseModel): 91 | dp: Datapoint 92 | capability_score: float 93 | 
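94 | 
95 | # A minimal usage sketch of the routers above (kept as a comment so nothing runs
96 | # on import; `cheap_model`, `big_model`, and `dp` are placeholder names standing
97 | # in for real `Model` instances and a `Datapoint`). `FirstModelRequestRouter`
98 | # returns the first model that reports `supports_dp(dp)`, so the order of
99 | # `available_models` encodes priority; "minimum-capability" routing instead
100 | # scores the task and picks the least capable model that still clears the score:
101 | #
102 | #   router = default_request_router()
103 | #   model = router.route(dp=dp, available_models=[cheap_model, big_model])
104 | #
105 | #   router = request_router_factory(
106 | #       "minimum-capability",
107 | #       capability_score_model=PromptedLLMCapabilityScoreModel(),
108 | #   )
109 | #   model = router.route(dp=dp, available_models=[cheap_model, big_model])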
-------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/model_utils/api/tokens.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from pydantic import BaseModel 4 | 5 | from tau_bench.model_utils.api.datapoint import ( 6 | BinaryClassifyDatapoint, 7 | ClassifyDatapoint, 8 | Datapoint, 9 | GenerateDatapoint, 10 | ParseDatapoint, 11 | ParseForceDatapoint, 12 | ScoreDatapoint, 13 | ) 14 | 15 | 16 | class TokenUsage(BaseModel): 17 | input_tokens: int 18 | output_tokens: int 19 | by_primitive: dict[str, "TokenUsage"] 20 | 21 | 22 | def batch_token_analysis(dps: list[Datapoint], encoding_for_model: str = "gpt-4o") -> TokenUsage: 23 | import tiktoken 24 | 25 | enc = tiktoken.encoding_for_model(encoding_for_model) 26 | # very rough estimates 27 | inputs_by_primitive: dict[str, list[str]] = {} 28 | outputs_by_primitive: dict[str, list[str]] = {} 29 | for dp in dps: 30 | input = json.dumps({k: v for k, v in dp.model_dump().items() if k != "response"}) 31 | inputs_by_primitive.setdefault(type(dp).__name__, []).append(input) 32 | if isinstance(dp, ClassifyDatapoint): 33 | output = f'{{"classification": {dp.response}}}' 34 | elif isinstance(dp, BinaryClassifyDatapoint): 35 | output = f'{{"classification": {0 if dp.response else 1}}}' 36 | elif isinstance(dp, ParseForceDatapoint): 37 | output = ( 38 | json.dumps(dp.response) 39 | if isinstance(dp.response, dict) 40 | else dp.response.model_dump_json() 41 | ) 42 | elif isinstance(dp, GenerateDatapoint): 43 | output = json.dumps(dp.response) 44 | elif isinstance(dp, ParseDatapoint): 45 | output = ( 46 | json.dumps(dp.response) 47 | if isinstance(dp.response, dict) 48 | else dp.response.model_dump_json() 49 | ) 50 | elif isinstance(dp, ScoreDatapoint): 51 | output = f"{{'score': {dp.response}}}" 52 | else: 53 | raise ValueError(f"Unknown datapoint type: {type(dp)}") 54 | outputs_by_primitive.setdefault(type(dp).__name__, []).append(output) 55 | input_tokens_by_primitive = {} 56 | output_tokens_by_primitive = {} 57 | for primitive, inputs in inputs_by_primitive.items(): 58 | input_tokens = sum([len(item) for item in enc.encode_batch(inputs)]) 59 | input_tokens_by_primitive[primitive] = input_tokens 60 | for primitive, outputs in outputs_by_primitive.items(): 61 | output_tokens = sum([len(item) for item in enc.encode_batch(outputs)]) 62 | output_tokens_by_primitive[primitive] = output_tokens 63 | return TokenUsage( 64 | input_tokens=sum(input_tokens_by_primitive.values()), 65 | output_tokens=sum(output_tokens_by_primitive.values()), 66 | by_primitive={ 67 | primitive: TokenUsage( 68 | input_tokens=input_tokens_by_primitive.get(primitive, 0), 69 | output_tokens=output_tokens_by_primitive.get(primitive, 0), 70 | by_primitive={}, 71 | ) 72 | for primitive in set(input_tokens_by_primitive.keys()) 73 | | set(output_tokens_by_primitive.keys()) 74 | }, 75 | ) 76 | 77 | 78 | def token_analysis(dp: Datapoint, encoding_for_model: str = "gpt-4o") -> TokenUsage: 79 | return batch_token_analysis([dp], encoding_for_model) 80 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/model_utils/api/types.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | PartialObj = dict[str, Any] 4 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/model_utils/func_tools/__init__.py: 
-------------------------------------------------------------------------------- 1 | from tau_bench.model_utils.func_tools.filter import filter as filter 2 | from tau_bench.model_utils.func_tools.map import map as map 3 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/model_utils/func_tools/filter.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Iterable, TypeVar 2 | 3 | from tau_bench.model_utils.func_tools.map import map 4 | 5 | T = TypeVar("T") 6 | 7 | builtin_filter = filter 8 | 9 | 10 | def filter( 11 | func: Callable[[T], bool], 12 | iterable: Iterable[T], 13 | max_concurrency: int | None = None, 14 | ) -> Iterable[T]: 15 | assert max_concurrency is None or max_concurrency > 0 16 | bits = map(func, iterable=iterable, max_concurrency=max_concurrency) 17 | return [x for x, y in zip(iterable, bits) if y] 18 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/model_utils/func_tools/map.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ThreadPoolExecutor 2 | from typing import Callable, Iterable, TypeVar 3 | 4 | T = TypeVar("T") 5 | U = TypeVar("U") 6 | 7 | 8 | def map( 9 | func: Callable[[T], U], 10 | iterable: Iterable[T], 11 | max_concurrency: int | None = None, 12 | use_tqdm: bool = False, 13 | ) -> Iterable[U]: 14 | assert max_concurrency is None or max_concurrency > 0 15 | with ThreadPoolExecutor(max_workers=max_concurrency) as executor: 16 | if use_tqdm: 17 | from tqdm import tqdm 18 | 19 | return list(tqdm(executor.map(func, iterable), total=len(iterable))) 20 | return executor.map(func, iterable) 21 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/model_utils/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/josh-llm-simulation-training/9c3a6076751ca3ac07e20a01fe5da780e42d23b0/tau-bench-eval/tau_bench/model_utils/model/__init__.py -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/model_utils/model/anyscale.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tau_bench.model_utils.api.datapoint import Datapoint 4 | from tau_bench.model_utils.model.chat import ChatModel, Message 5 | from tau_bench.model_utils.model.completion import approx_cost_for_datapoint, approx_prompt_str 6 | from tau_bench.model_utils.model.general_model import wrap_temperature 7 | from tau_bench.model_utils.model.utils import approx_num_tokens 8 | 9 | API_KEY_ENV_VAR = "ANYSCALE_API_KEY" 10 | BASE_URL = "https://api.endpoints.anyscale.com/v1" 11 | 12 | PRICE_PER_INPUT_TOKEN_MAP = {"meta-llama/Meta-Llama-3-8B-Instruct": ...} 13 | INPUT_PRICE_PER_TOKEN_FALLBACK = 10 / 1000000 14 | 15 | CAPABILITY_SCORE_MAP = { 16 | "meta-llama/Meta-Llama-3-8B-Instruct": 0.2, 17 | "meta-llama/Meta-Llama-3-70B-Instruct": 0.6, 18 | } 19 | CAPABILITY_SCORE_FALLBACK = 0.2 20 | 21 | # TODO: implement 22 | LATENCY_MS_PER_OUTPUT_TOKEN_MAP = {} 23 | # TODO: implement 24 | LATENCY_MS_PER_OUTPUT_TOKEN_FALLBACK = 0.0 25 | 26 | MAX_CONTEXT_LENGTH_MAP = { 27 | "meta-llama/Meta-Llama-3-8B-Instruct": 8192, 28 | "meta-llama/Meta-Llama-3-70B-Instruct": 8192, 29 | } 30 | MAX_CONTEXT_LENGTH_FALLBACK = 8192 31 | 32 | 33 | class 
AnyscaleModel(ChatModel): 34 | def __init__( 35 | self, 36 | model: str, 37 | api_key: str | None = None, 38 | temperature: float = 0.0, 39 | ) -> None: 40 | from openai import AsyncOpenAI, OpenAI 41 | 42 | self.model = model 43 | 44 | # fall back to the environment variable only when no api_key is passed in 45 | if api_key is None: 46 | api_key = os.getenv(API_KEY_ENV_VAR) 47 | if api_key is None: 48 | raise ValueError(f"{API_KEY_ENV_VAR} environment variable is not set") 49 | self.client = OpenAI(api_key=api_key, base_url=BASE_URL) 50 | self.async_client = AsyncOpenAI(api_key=api_key, base_url=BASE_URL) 51 | self.temperature = temperature 52 | 53 | def generate_message( 54 | self, 55 | messages: list[Message], 56 | force_json: bool, 57 | temperature: float | None = None, 58 | ) -> Message: 59 | if temperature is None: 60 | temperature = self.temperature 61 | msgs = self.build_generate_message_state(messages) 62 | res = self.client.chat.completions.create( 63 | model=self.model, 64 | messages=msgs, 65 | temperature=wrap_temperature(temperature), 66 | response_format={"type": "json_object" if force_json else "text"}, 67 | ) 68 | return self.handle_generate_message_response( 69 | prompt=msgs, content=res.choices[0].message.content, force_json=force_json 70 | ) 71 | 72 | def get_approx_cost(self, dp: Datapoint) -> float: 73 | cost_per_token = PRICE_PER_INPUT_TOKEN_MAP.get(self.model, INPUT_PRICE_PER_TOKEN_FALLBACK) 74 | return approx_cost_for_datapoint(dp=dp, price_per_input_token=cost_per_token) 75 | 76 | def get_latency(self, dp: Datapoint) -> float: 77 | latency_per_output_token = LATENCY_MS_PER_OUTPUT_TOKEN_MAP.get( 78 | self.model, LATENCY_MS_PER_OUTPUT_TOKEN_FALLBACK 79 | ) 80 | return approx_cost_for_datapoint(dp=dp, price_per_input_token=latency_per_output_token) 81 | 82 | def get_capability(self) -> float: 83 | return CAPABILITY_SCORE_MAP.get(self.model, CAPABILITY_SCORE_FALLBACK) 84 | 85 | def supports_dp(self, dp: Datapoint) -> bool: 86 | prompt = approx_prompt_str(dp) 87 | return approx_num_tokens(prompt) <= MAX_CONTEXT_LENGTH_MAP.get( 88 | self.model, MAX_CONTEXT_LENGTH_FALLBACK 89 | ) 90 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/model_utils/model/claude.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from tau_bench.model_utils.api.datapoint import Datapoint 5 | from tau_bench.model_utils.model.chat import ChatModel, Message 6 | from tau_bench.model_utils.model.completion import approx_cost_for_datapoint, approx_prompt_str 7 | from tau_bench.model_utils.model.general_model import wrap_temperature 8 | from tau_bench.model_utils.model.utils import approx_num_tokens 9 | 10 | DEFAULT_CLAUDE_MODEL = "claude-3-5-sonnet-20240620" 11 | DEFAULT_MAX_TOKENS = 8192 12 | ENV_VAR_API_KEY = "ANTHROPIC_API_KEY" 13 | 14 | PRICE_PER_INPUT_TOKEN_MAP = { 15 | "claude-3-5-sonnet-20240620": 3 / 1000000, 16 | } 17 | INPUT_PRICE_PER_TOKEN_FALLBACK = 15 / 1000000 18 | 19 | CAPABILITY_SCORE_MAP = { 20 | "claude-3-5-sonnet-20240620": 1.0, 21 | } 22 | CAPABILITY_SCORE_FALLBACK = 0.5 23 | 24 | # TODO: implement 25 | LATENCY_MS_PER_OUTPUT_TOKEN_MAP = {} 26 | # TODO: implement 27 | LATENCY_MS_PER_OUTPUT_TOKEN_FALLBACK = 0.0 28 | 29 | MAX_CONTEXT_LENGTH_MAP = { 30 | "claude-3-5-sonnet-20240620": 8192, 31 | } 32 | MAX_CONTEXT_LENGTH_FALLBACK = 8192 33 | 34 | 35 | class ClaudeModel(ChatModel): 36 | def __init__( 37 | self, 38 | model: str | None = None, 39 | api_key: str | None = None, 40 | temperature: float = 0.0, 41 | ) -> 
None: 42 | from anthropic import Anthropic, AsyncAnthropic 43 | 44 | if model is None: 45 | self.model = DEFAULT_CLAUDE_MODEL 46 | else: 47 | self.model = model 48 | 49 | # fall back to the environment variable only when no api_key is passed in 50 | if api_key is None: 51 | api_key = os.getenv(ENV_VAR_API_KEY) 52 | if api_key is None: 53 | raise ValueError(f"{ENV_VAR_API_KEY} environment variable is not set") 54 | # `anthropic-beta` header is needed for the 8192 context length (https://docs.anthropic.com/en/docs/about-claude/models) 55 | self.client = Anthropic( 56 | api_key=api_key, default_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"} 57 | ) 58 | self.async_client = AsyncAnthropic(api_key=api_key) 59 | self.temperature = temperature 60 | 61 | def get_approx_cost(self, dp: Datapoint) -> float: 62 | cost_per_token = PRICE_PER_INPUT_TOKEN_MAP.get(self.model, INPUT_PRICE_PER_TOKEN_FALLBACK) 63 | return approx_cost_for_datapoint(dp=dp, price_per_input_token=cost_per_token) 64 | 65 | def get_latency(self, dp: Datapoint) -> float: 66 | latency_per_output_token = LATENCY_MS_PER_OUTPUT_TOKEN_MAP.get( 67 | self.model, LATENCY_MS_PER_OUTPUT_TOKEN_FALLBACK 68 | ) 69 | return approx_cost_for_datapoint(dp=dp, price_per_input_token=latency_per_output_token) 70 | 71 | def get_capability(self) -> float: 72 | return CAPABILITY_SCORE_MAP.get(self.model, CAPABILITY_SCORE_FALLBACK) 73 | 74 | def supports_dp(self, dp: Datapoint) -> bool: 75 | prompt = approx_prompt_str(dp) 76 | return approx_num_tokens(prompt) <= MAX_CONTEXT_LENGTH_MAP.get( 77 | self.model, MAX_CONTEXT_LENGTH_FALLBACK 78 | ) 79 | 80 | def _remap_messages(self, messages: list[dict[str, str]]) -> list[dict[str, str]]: 81 | remapped: list[dict[str, str]] = [] 82 | is_user = True 83 | for i, message in enumerate(messages): 84 | role = message["role"] 85 | if role == "assistant": 86 | if i == 0: 87 | raise ValueError( 88 | f"First message must be a system or user message, got {[m['role'] for m in messages]}" 89 | ) 90 | elif is_user: 91 | raise ValueError( 92 | f"Must alternate between user and assistant, got {[m['role'] for m in messages]}" 93 | ) 94 | remapped.append(message) 95 | is_user = True 96 | else: 97 | if is_user: 98 | remapped.append({"role": "user", "content": message["content"]}) 99 | is_user = False 100 | else: 101 | if remapped[-1]["role"] != "user": 102 | raise ValueError( 103 | f"Invalid sequence, expected user message but got {[m['role'] for m in messages]}" 104 | ) 105 | remapped[-1]["content"] += "\n\n" + message["content"] 106 | return remapped 107 | 108 | def build_generate_message_state( 109 | self, 110 | messages: list[Message], 111 | ) -> list[dict[str, str]]: 112 | msgs: list[dict[str, str]] = [] 113 | for msg in messages: 114 | if msg.obj is not None: 115 | content = json.dumps(msg.obj) 116 | else: 117 | content = msg.content 118 | msgs.append({"role": msg.role.value, "content": content}) 119 | return self._remap_messages(msgs) 120 | 121 | def generate_message( 122 | self, 123 | messages: list[Message], 124 | force_json: bool, 125 | temperature: float | None = None, 126 | ) -> Message: 127 | if temperature is None: 128 | temperature = self.temperature 129 | msgs = self.build_generate_message_state(messages) 130 | res = self.client.messages.create( 131 | model=self.model, 132 | messages=msgs, 133 | temperature=wrap_temperature(temperature), 134 | max_tokens=DEFAULT_MAX_TOKENS, 135 | ) 136 | return self.handle_generate_message_response( 137 | prompt=msgs, content=res.content[0].text, force_json=force_json 138 | ) 139 | 
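140 | # Illustration of the remapping performed by `_remap_messages` above (the
141 | # "S"/"U"/"A" contents are placeholders): consecutive non-assistant messages
142 | # are merged into a single user turn, so the Anthropic API always sees strict
143 | # user/assistant alternation. For example:
144 | #
145 | #   [{"role": "system", "content": "S"},
146 | #    {"role": "user", "content": "U"},
147 | #    {"role": "assistant", "content": "A"}]
148 | #   becomes
149 | #   [{"role": "user", "content": "S\n\nU"},
150 | #    {"role": "assistant", "content": "A"}]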
--------------------------------------------------------------------------------
/tau-bench-eval/tau_bench/model_utils/model/exception.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Generic, TypeVar
3 | 
4 | T = TypeVar("T")
5 | 
6 | 
7 | class ModelError(Exception):
8 |     def __init__(
9 |         self,
10 |         short_message: str,
11 |         prompt: str | list[dict[str, str]] | None = None,
12 |         response: str | None = None,
13 |     ) -> None:
14 |         super().__init__(short_message)
15 |         self.short_message = short_message
16 |         self.prompt = prompt
17 |         self.response = response
18 | 
19 | 
20 | @dataclass
21 | class Result(Generic[T]):
22 |     value: T | None
23 |     error: ModelError | None
--------------------------------------------------------------------------------
/tau-bench-eval/tau_bench/model_utils/model/mistral.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | from tau_bench.model_utils.api.datapoint import Datapoint
4 | from tau_bench.model_utils.model.chat import ChatModel, Message
5 | from tau_bench.model_utils.model.completion import approx_cost_for_datapoint, approx_prompt_str
6 | from tau_bench.model_utils.model.general_model import wrap_temperature
7 | from tau_bench.model_utils.model.utils import approx_num_tokens
8 | 
9 | DEFAULT_MISTRAL_MODEL = "mistral-large-latest"
10 | 
11 | PRICE_PER_INPUT_TOKEN_MAP = {
12 |     "mistral-large-latest": 3 / 1000000,
13 | }
14 | INPUT_PRICE_PER_TOKEN_FALLBACK = 10 / 1000000
15 | 
16 | CAPABILITY_SCORE_MAP = {
17 |     "mistral-large-latest": 0.9,
18 | }
19 | CAPABILITY_SCORE_FALLBACK = 0.3
20 | 
21 | # TODO: implement
22 | LATENCY_MS_PER_OUTPUT_TOKEN_MAP = {}
23 | # TODO: implement
24 | LATENCY_MS_PER_OUTPUT_TOKEN_FALLBACK = 0.0
25 | 
26 | MAX_CONTEXT_LENGTH_MAP = {
27 |     "mistral-large-latest": 128000,
28 | }
29 | MAX_CONTEXT_LENGTH_FALLBACK = 128000
30 | 
31 | 
32 | class MistralModel(ChatModel):
33 |     def __init__(
34 |         self, model: str | None = None, api_key: str | None = None, temperature: float = 0.0
35 |     ) -> None:
36 |         from mistralai.async_client import MistralAsyncClient
37 |         from mistralai.client import MistralClient
38 | 
39 |         if model is None:
40 |             self.model = DEFAULT_MISTRAL_MODEL
41 |         else:
42 |             self.model = model
43 | 
44 |         # fall back to the environment variable when no key is passed in
45 |         if api_key is None:
46 |             api_key = os.getenv("MISTRAL_API_KEY")
47 |             if api_key is None:
48 |                 raise ValueError("MISTRAL_API_KEY environment variable is not set")
49 |         self.client = MistralClient(api_key=api_key)
50 |         self.async_client = MistralAsyncClient(api_key=api_key)
51 |         self.temperature = temperature
52 | 
53 |     def generate_message(
54 |         self,
55 |         messages: list[Message],
56 |         force_json: bool,
57 |         temperature: float | None = None,
58 |     ) -> Message:
59 |         if temperature is None:
60 |             temperature = self.temperature
61 |         msgs = self.build_generate_message_state(messages)
62 |         res = self.client.chat(
63 |             model=self.model,
64 |             messages=msgs,
65 |             temperature=wrap_temperature(temperature),
66 |             response_format={"type": "json_object" if force_json else "text"},
67 |         )
68 |         return self.handle_generate_message_response(
69 |             prompt=msgs, content=res.choices[0].message.content, force_json=force_json
70 |         )
71 | 
72 |     def get_approx_cost(self, dp: Datapoint) -> float:
73 |         cost_per_token = PRICE_PER_INPUT_TOKEN_MAP.get(self.model, INPUT_PRICE_PER_TOKEN_FALLBACK)
74 |         return approx_cost_for_datapoint(dp=dp, price_per_input_token=cost_per_token)
75 | 
76 |     def get_latency(self, dp: Datapoint) -> float:
77 |         latency_per_output_token = LATENCY_MS_PER_OUTPUT_TOKEN_MAP.get(
78 |             self.model, LATENCY_MS_PER_OUTPUT_TOKEN_FALLBACK
79 |         )
80 |         return approx_cost_for_datapoint(dp=dp, price_per_input_token=latency_per_output_token)
81 | 
82 |     def get_capability(self) -> float:
83 |         return CAPABILITY_SCORE_MAP.get(self.model, CAPABILITY_SCORE_FALLBACK)
84 | 
85 |     def supports_dp(self, dp: Datapoint) -> bool:
86 |         prompt = approx_prompt_str(dp)
87 |         return approx_num_tokens(prompt) <= MAX_CONTEXT_LENGTH_MAP.get(
88 |             self.model, MAX_CONTEXT_LENGTH_FALLBACK
89 |         )
90 | 
--------------------------------------------------------------------------------
/tau-bench-eval/tau_bench/model_utils/model/model.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import enum
3 | from typing import Any, TypeVar
4 | 
5 | from pydantic import BaseModel
6 | 
7 | from tau_bench.model_utils.api.datapoint import (
8 |     BinaryClassifyDatapoint,
9 |     ClassifyDatapoint,
10 |     Datapoint,
11 |     GenerateDatapoint,
12 |     ParseDatapoint,
13 |     ParseForceDatapoint,
14 |     ScoreDatapoint,
15 | )
16 | from tau_bench.model_utils.api.types import PartialObj
17 | 
18 | T = TypeVar("T", bound=BaseModel)
19 | 
20 | 
21 | class Platform(enum.Enum):
22 |     OPENAI = "openai"
23 |     MISTRAL = "mistral"
24 |     ANTHROPIC = "anthropic"
25 |     ANYSCALE = "anyscale"
26 |     OUTLINES = "outlines"
27 |     VLLM_CHAT = "vllm-chat"
28 |     VLLM_COMPLETION = "vllm-completion"
29 | 
30 | 
31 | # @runtime_checkable
32 | # class Model(Protocol):
33 | class Model(abc.ABC):
34 |     @abc.abstractmethod
35 |     def get_capability(self) -> float:
36 |         """Return the capability of the model, a float between 0.0 and 1.0."""
37 |         raise NotImplementedError
38 | 
39 |     @abc.abstractmethod
40 |     def get_approx_cost(self, dp: Datapoint) -> float:
41 |         raise NotImplementedError
42 | 
43 |     @abc.abstractmethod
44 |     def get_latency(self, dp: Datapoint) -> float:
45 |         raise NotImplementedError
46 | 
47 |     @abc.abstractmethod
48 |     def supports_dp(self, dp: Datapoint) -> bool:
49 |         raise NotImplementedError
50 | 
51 | 
52 | class ClassifyModel(Model):
53 |     @abc.abstractmethod
54 |     def classify(
55 |         self,
56 |         instruction: str,
57 |         text: str,
58 |         options: list[str],
59 |         examples: list[ClassifyDatapoint] | None = None,
60 |         temperature: float | None = None,
61 |     ) -> int:
62 |         raise NotImplementedError
63 | 
64 | 
65 | class BinaryClassifyModel(Model):
66 |     @abc.abstractmethod
67 |     def binary_classify(
68 |         self,
69 |         instruction: str,
70 |         text: str,
71 |         examples: list[BinaryClassifyDatapoint] | None = None,
72 |         temperature: float | None = None,
73 |     ) -> bool:
74 |         raise NotImplementedError
75 | 
76 | 
77 | class ParseModel(Model):
78 |     @abc.abstractmethod
79 |     def parse(
80 |         self,
81 |         text: str,
82 |         typ: type[T] | dict[str, Any],
83 |         examples: list[ParseDatapoint] | None = None,
84 |         temperature: float | None = None,
85 |     ) -> T | PartialObj | dict[str, Any]:
86 |         raise NotImplementedError
87 | 
88 | 
89 | class GenerateModel(Model):
90 |     @abc.abstractmethod
91 |     def generate(
92 |         self,
93 |         instruction: str,
94 |         text: str,
95 |         examples: list[GenerateDatapoint] | None = None,
96 |         temperature: float | None = None,
97 |     ) -> str:
98 |         raise NotImplementedError
99 | 
100 | 
101 | class ParseForceModel(Model):
102 |     @abc.abstractmethod
103 |     def parse_force(
104 |         self,
105 |         instruction: str,
106 |         typ: type[T] | dict[str, Any],
107 |         text: str | None = None,
108 |         examples: list[ParseForceDatapoint] | None = None,
109 |         temperature: float | None = None,
110 |     ) -> T | dict[str, Any]:
111 |         raise NotImplementedError
112 | 
113 | 
114 | class ScoreModel(Model):
115 |     @abc.abstractmethod
116 |     def score(
117 |         self,
118 |         instruction: str,
119 |         text: str,
120 |         min: int,
121 |         max: int,
122 |         examples: list[ScoreDatapoint] | None = None,
123 |         temperature: float | None = None,
124 |     ) -> int:
125 |         raise NotImplementedError
126 | 
127 | 
128 | AnyModel = (
129 |     BinaryClassifyModel | ClassifyModel | ParseForceModel | GenerateModel | ParseModel | ScoreModel
130 | )
131 | 
--------------------------------------------------------------------------------
/tau-bench-eval/tau_bench/model_utils/model/openai.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | from tau_bench.model_utils.api.datapoint import Datapoint
4 | from tau_bench.model_utils.model.chat import ChatModel, Message
5 | from tau_bench.model_utils.model.completion import approx_cost_for_datapoint, approx_prompt_str
6 | from tau_bench.model_utils.model.general_model import wrap_temperature
7 | from tau_bench.model_utils.model.utils import approx_num_tokens
8 | 
9 | DEFAULT_OPENAI_MODEL = "gpt-4o-2024-08-06"
10 | API_KEY_ENV_VAR = "OPENAI_API_KEY"
11 | 
12 | PRICE_PER_INPUT_TOKEN_MAP = {
13 |     "gpt-4o-2024-08-06": 2.5 / 1000000,
14 |     "gpt-4o": 5 / 1000000,
15 |     "gpt-4o-2024-05-13": 5 / 1000000,
16 |     "gpt-4-turbo": 10 / 1000000,
17 |     "gpt-4-turbo-2024-04-09": 10 / 1000000,
18 |     "gpt-4": 30 / 1000000,
19 |     "gpt-4o-mini": 0.15 / 1000000,
20 |     "gpt-4o-mini-2024-07-18": 0.15 / 1000000,
21 |     "gpt-3.5-turbo": 0.5 / 1000000,
22 |     "gpt-3.5-turbo-0125": 0.5 / 1000000,
23 |     "gpt-3.5-turbo-instruct": 1.5 / 1000000,
24 | }
25 | INPUT_PRICE_PER_TOKEN_FALLBACK = 10 / 1000000
26 | 
27 | CAPABILITY_SCORE_MAP = {
28 |     "gpt-4o-2024-08-06": 0.8,
29 |     "gpt-4o": 0.8,
30 |     "gpt-4o-2024-05-13": 0.8,
31 |     "gpt-4-turbo": 0.9,
32 |     "gpt-4-turbo-2024-04-09": 0.9,
33 |     "gpt-4": 0.8,
34 |     "gpt-4o-mini": 0.5,
35 |     "gpt-4o-mini-2024-07-18": 0.5,
36 |     "gpt-3.5-turbo": 0.3,
37 |     "gpt-3.5-turbo-0125": 0.3,
38 | }
39 | CAPABILITY_SCORE_FALLBACK = 0.3
40 | 
41 | # TODO: implement
42 | LATENCY_MS_PER_OUTPUT_TOKEN_MAP = {}
43 | # TODO: implement
44 | LATENCY_MS_PER_OUTPUT_TOKEN_FALLBACK = 0.0
45 | 
46 | MAX_CONTEXT_LENGTH_MAP = {
47 |     "gpt-4o-2024-08-06": 128000,
48 |     "gpt-4o": 128000,
49 |     "gpt-4o-2024-05-13": 128000,
50 |     "gpt-4-turbo": 128000,
51 |     "gpt-4-turbo-2024-04-09": 128000,
52 |     "gpt-4": 8192,
53 |     "gpt-4o-mini": 128000,
54 |     "gpt-4o-mini-2024-07-18": 128000,
55 |     "gpt-3.5-turbo": 16385,
56 |     "gpt-3.5-turbo-0125": 16385,
57 | }
58 | MAX_CONTEXT_LENGTH_FALLBACK = 128000
59 | 
60 | 
61 | class OpenAIModel(ChatModel):
62 |     def __init__(
63 |         self,
64 |         model: str | None = None,
65 |         api_key: str | None = None,
66 |         temperature: float = 0.0,
67 |     ) -> None:
68 |         from openai import AsyncOpenAI, OpenAI
69 | 
70 |         if model is None:
71 |             self.model = DEFAULT_OPENAI_MODEL
72 |         else:
73 |             self.model = model
74 | 
75 |         # fall back to the environment variable when no key is passed in
76 |         if api_key is None:
77 |             api_key = os.getenv(API_KEY_ENV_VAR)
78 |             if api_key is None:
79 |                 raise ValueError(f"{API_KEY_ENV_VAR} environment variable is not set")
80 |         self.client = OpenAI(api_key=api_key)
81 |         self.async_client = AsyncOpenAI(api_key=api_key)
82 |         self.temperature = temperature
83 | 
84 |     def generate_message(
85 |         self,
86 |         messages: list[Message],
87 |         force_json: bool,
88 |         temperature: float | None = None,
89 |     ) -> Message:
90 |         if temperature is None:
91 |             temperature = self.temperature
92 |         msgs = self.build_generate_message_state(messages)
93 |         res = self.client.chat.completions.create(
94 |             model=self.model,
95 |             messages=msgs,
96 |             temperature=wrap_temperature(temperature),
97 |             response_format={"type": "json_object" if force_json else "text"},
98 |         )
99 |         return self.handle_generate_message_response(
100 |             prompt=msgs, content=res.choices[0].message.content, force_json=force_json
101 |         )
102 | 
103 |     def get_approx_cost(self, dp: Datapoint) -> float:
104 |         cost_per_token = PRICE_PER_INPUT_TOKEN_MAP.get(self.model, INPUT_PRICE_PER_TOKEN_FALLBACK)
105 |         return approx_cost_for_datapoint(dp=dp, price_per_input_token=cost_per_token)
106 | 
107 |     def get_latency(self, dp: Datapoint) -> float:
108 |         latency_per_output_token = LATENCY_MS_PER_OUTPUT_TOKEN_MAP.get(
109 |             self.model, LATENCY_MS_PER_OUTPUT_TOKEN_FALLBACK
110 |         )
111 |         return approx_cost_for_datapoint(dp=dp, price_per_input_token=latency_per_output_token)
112 | 
113 |     def get_capability(self) -> float:
114 |         return CAPABILITY_SCORE_MAP.get(self.model, CAPABILITY_SCORE_FALLBACK)
115 | 
116 |     def supports_dp(self, dp: Datapoint) -> bool:
117 |         prompt = approx_prompt_str(dp)
118 |         return approx_num_tokens(prompt) <= MAX_CONTEXT_LENGTH_MAP.get(
119 |             self.model, MAX_CONTEXT_LENGTH_FALLBACK
120 |         )
121 | 
--------------------------------------------------------------------------------
/tau-bench-eval/tau_bench/model_utils/model/outlines_completion.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | 
3 | from pydantic import BaseModel
4 | 
5 | from tau_bench.model_utils.api.datapoint import Datapoint
6 | from tau_bench.model_utils.model.vllm_completion import VLLMCompletionModel
7 | from tau_bench.model_utils.model.vllm_utils import generate_request
8 | 
9 | 
10 | class OutlinesCompletionModel(VLLMCompletionModel):
11 |     def parse_force_from_prompt(
12 |         self, prompt: str, typ: BaseModel, temperature: float | None = None
13 |     ) -> dict[str, Any]:
14 |         if temperature is None:
15 |             temperature = self.temperature
16 |         schema = typ.model_json_schema()
17 |         res = generate_request(
18 |             url=self.url,
19 |             prompt=prompt,
20 |             force_json=True,
21 |             schema=schema,
22 |             temperature=temperature,
23 |         )
24 |         return self.handle_parse_force_response(prompt=prompt, content=res)
25 | 
26 |     def get_approx_cost(self, dp: Datapoint) -> float:
27 |         return super().get_approx_cost(dp)
28 | 
29 |     def get_latency(self, dp: Datapoint) -> float:
30 |         return super().get_latency(dp)
31 | 
32 |     def get_capability(self) -> float:
33 |         return super().get_capability()
34 | 
35 |     def supports_dp(self, dp: Datapoint) -> bool:
36 |         return super().supports_dp(dp)
37 | 
--------------------------------------------------------------------------------
/tau-bench-eval/tau_bench/model_utils/model/utils.py:
--------------------------------------------------------------------------------
1 | import enum
2 | import json
3 | import re
4 | from typing import Any, Optional, TypeVar
5 | 
6 | from pydantic import BaseModel, create_model
7 | 
8 | from tau_bench.model_utils.api.types import PartialObj
9 | 
10 | T = TypeVar("T", bound=BaseModel)
11 | 
12 | 
13 | class InputType(enum.Enum):
14 |     CHAT = "chat"
15 |     COMPLETION = "completion"
16 | 
17 | 
18 | def display_choices(choices: list[str]) -> tuple[str, dict[str, int]]:
19 |     choice_displays = []
20 |     decode_map = {}
21 |     for i, choice in enumerate(choices):
22 |         label = index_to_alpha(i)
23 |         choice_display = f"{label}. {choice}"
24 |         choice_displays.append(choice_display)
25 |         decode_map[label] = i
26 |     return "\n".join(choice_displays), decode_map
27 | 
28 | 
29 | def index_to_alpha(index: int) -> str:
30 |     alpha = ""
31 |     while index >= 0:
32 |         alpha = chr(index % 26 + ord("A")) + alpha
33 |         index = index // 26 - 1
34 |     return alpha
35 | 
36 | 
37 | def type_to_json_schema_string(typ: type[T]) -> str:
38 |     json_schema = typ.model_json_schema()
39 |     return json.dumps(json_schema, indent=4)
40 | 
41 | 
42 | def optionalize_type(typ: type[T]) -> type[T]:
43 |     # Subclass `typ` via `create_model` so that every field is overridden to be
44 |     # optional (defaulting to None); mutating `model_fields` after class
45 |     # creation would not regenerate the pydantic validator.
46 |     optional_fields = {
47 |         name: (Optional[field.annotation], None)
48 |         for name, field in typ.model_fields.items()
49 |     }
50 |     OptionalModel = create_model(typ.__name__, __base__=typ, **optional_fields)
51 |     return OptionalModel
52 | 
53 | 
54 | def json_response_to_obj_or_partial_obj(
55 |     response: dict[str, Any], typ: type[T] | dict[str, Any]
56 | ) -> T | PartialObj | dict[str, Any]:
57 |     if isinstance(typ, dict):
58 |         return response
59 |     else:
60 |         required_field_names = [
61 |             name for name, field in typ.model_fields.items() if field.is_required()
62 |         ]
63 |         for name in required_field_names:
64 |             if name not in response.keys() or response[name] is None:
65 |                 return response
66 |         return typ.model_validate(response)
67 | 
68 | 
69 | def clean_top_level_keys(d: dict[str, Any]) -> dict[str, Any]:
70 |     new_d = {}
71 |     for k, v in d.items():
72 |         new_d[k.strip()] = v
73 |     return new_d
74 | 
75 | 
76 | def parse_json_or_json_markdown(text: str) -> dict[str, Any]:
77 |     def parse(s: str) -> dict[str, Any] | None:
78 |         try:
79 |             return json.loads(s)
80 |         except json.decoder.JSONDecodeError:
81 |             return None
82 | 
83 |     # pass #1: try to parse as json
84 |     parsed = parse(text)
85 |     if parsed is not None:
86 |         return parsed
87 | 
88 |     # pass #2: try to parse as json markdown
89 |     stripped = text.strip()
90 |     if stripped.startswith("```json"):
91 |         stripped = stripped[len("```json") :].strip()
92 |     if stripped.endswith("```"):
93 |         stripped = stripped[: -len("```")].strip()
94 |     parsed = parse(stripped)
95 |     if parsed is not None:
96 |         return parsed
97 | 
98 |     # pass #3: try to parse an arbitrary md block
99 |     pattern = r"```(?:\w+\n)?(.*?)```"
100 |     match = re.search(pattern, text, re.DOTALL)
101 |     if match:
102 |         content = match.group(1).strip()
103 |         parsed = parse(content)
104 |         if parsed is not None:
105 |             return parsed
106 | 
107 |     # pass #4: try to parse arbitrary sections as json
108 |     lines = text.split("\n")
109 |     seen = set()
110 |     for i in range(len(lines)):
111 |         for j in range(i + 1, len(lines) + 1):
112 |             if i < j and (i, j) not in seen:
113 |                 seen.add((i, j))
114 |                 content = "\n".join(lines[i:j])
115 |                 parsed = parse(content)
116 |                 if parsed is not None:
117 |                     return parsed
118 |     raise ValueError("Could not parse JSON or JSON markdown")
119 | 
120 | 
121 | def longest_valid_string(s: str, options: list[str]) -> str | None:
122 |     longest = 0
123 |     longest_str = None
124 |     options_set = set(options)
125 |     for i in range(len(s)):
126 |         if s[: i + 1] in options_set and i + 1 > longest:
127 |             longest = i + 1
128 |             longest_str = s[: i + 1]
129 |     return longest_str
130 | 
131 | 
132 | def try_classify_recover(s: str, decode_map: dict[str, int]) -> str | None:
133 |     lvs = longest_valid_string(s, list(decode_map.keys()))
134 |     if lvs is not None and lvs in decode_map:
135 |         return lvs
136 |     for k, v in decode_map.items():
137 |         if s == str(v):
138 |             return k
139 | 
140 | 
141 | def approx_num_tokens(text: str) -> int:
142 |     return len(text) // 4
143 | 
144 | 
145 | def add_md_close_tag(prompt: str) -> str:
146 |     return f"{prompt}\n```"
147 | 
148 | 
149 | def add_md_tag(prompt: str) -> str:
150 |     return f"```json\n{prompt}\n```"
151 | 
--------------------------------------------------------------------------------
/tau-bench-eval/tau_bench/model_utils/model/vllm_chat.py:
--------------------------------------------------------------------------------
1 | from tau_bench.model_utils.api.datapoint import Datapoint
2 | from tau_bench.model_utils.model.chat import ChatModel, Message
3 | from tau_bench.model_utils.model.completion import approx_cost_for_datapoint, approx_prompt_str
4 | from tau_bench.model_utils.model.general_model import wrap_temperature
5 | from tau_bench.model_utils.model.utils import approx_num_tokens
6 | 
7 | PRICE_PER_INPUT_TOKEN_MAP = {
8 |     "Qwen/Qwen2-0.5B-Instruct": 0.0,
9 |     "Qwen/Qwen2-1.5B-Instruct": 0.0,
10 |     "Qwen/Qwen2-7B-Instruct": 0.0,
11 |     "Qwen/Qwen2-72B-Instruct": 0.0,
12 |     "meta-llama/Meta-Llama-3.1-8B-Instruct": 0.0,
13 |     "sierra-research/Meta-Llama-3.1-8B-Instruct": 0.0,
14 |     "meta-llama/Meta-Llama-3.1-70B-Instruct": 0.0,
15 |     "mistralai/Mistral-Nemo-Instruct-2407": 0.0,
16 | }
17 | INPUT_PRICE_PER_TOKEN_FALLBACK = 0.0
18 | 
19 | # TODO: refine this
20 | CAPABILITY_SCORE_MAP = {
21 |     "Qwen/Qwen2-0.5B-Instruct": 0.05,
22 |     "Qwen/Qwen2-1.5B-Instruct": 0.07,
23 |     "Qwen/Qwen2-7B-Instruct": 0.2,
24 |     "Qwen/Qwen2-72B-Instruct": 0.4,
25 |     "meta-llama/Meta-Llama-3.1-8B-Instruct": 0.3,
26 |     "sierra-research/Meta-Llama-3.1-8B-Instruct": 0.3,
27 |     "meta-llama/Meta-Llama-3.1-70B-Instruct": 0.4,
28 |     "mistralai/Mistral-Nemo-Instruct-2407": 0.3,
29 | }
30 | CAPABILITY_SCORE_FALLBACK = 0.3
31 | 
32 | # TODO: implement
33 | LATENCY_MS_PER_OUTPUT_TOKEN_MAP = {}
34 | # TODO: implement
35 | LATENCY_MS_PER_OUTPUT_TOKEN_FALLBACK = 0.0
36 | 
37 | MAX_CONTEXT_LENGTH_MAP = {
38 |     "Qwen/Qwen2-0.5B-Instruct": 32768,
39 |     "Qwen/Qwen2-1.5B-Instruct": 32768,
40 |     "Qwen/Qwen2-7B-Instruct": 131072,
41 |     "Qwen/Qwen2-72B-Instruct": 131072,
42 |     "meta-llama/Meta-Llama-3.1-8B-Instruct": 128000,
43 |     "sierra-research/Meta-Llama-3.1-8B-Instruct": 128000,
44 |     "meta-llama/Meta-Llama-3.1-70B-Instruct": 128000,
45 |     "mistralai/Mistral-Nemo-Instruct-2407": 128000,
46 | }
47 | MAX_CONTEXT_LENGTH_FALLBACK = 128000
48 | 
49 | 
50 | class VLLMChatModel(ChatModel):
51 |     def __init__(
52 |         self,
53 |         model: str,
54 |         base_url: str,
55 |         api_key: str,
56 |         temperature: float = 0.0,
57 |         price_per_input_token: float | None = None,
58 |         capability: float | None = None,
59 |         latency_ms_per_output_token: float | None = None,
60 |         max_context_length: int | None = None,
61 |     ) -> None:
62 |         from openai import AsyncOpenAI, OpenAI
63 | 
64 |         self.model = model
65 |         self.client = OpenAI(
66 |             base_url=base_url,
67 |             api_key=api_key,
68 |         )
69 |         self.async_client = AsyncOpenAI(
70 |             base_url=base_url,
71 |             api_key=api_key,
72 |         )
73 |         self.temperature = temperature
74 |         self.price_per_input_token = (
75 |             price_per_input_token
76 |             if price_per_input_token is not None
77 |             else PRICE_PER_INPUT_TOKEN_MAP.get(model, INPUT_PRICE_PER_TOKEN_FALLBACK)
78 |         )
79 |         self.capability = (
80 |             capability
81 |             if capability is not None
82 |             else CAPABILITY_SCORE_MAP.get(model, CAPABILITY_SCORE_FALLBACK)
83 |         )
84 |         self.latency_ms_per_output_token = (
85 |             latency_ms_per_output_token
86 |             if latency_ms_per_output_token is not None
87 |             else LATENCY_MS_PER_OUTPUT_TOKEN_MAP.get(model, LATENCY_MS_PER_OUTPUT_TOKEN_FALLBACK)
88 |         )
89 |         self.max_context_length = (
90 |             max_context_length
91 |             if max_context_length is not None
92 |             else MAX_CONTEXT_LENGTH_MAP.get(model, MAX_CONTEXT_LENGTH_FALLBACK)
93 |         )
94 | 
95 |     def get_approx_cost(self, dp: Datapoint) -> float:
96 |         cost_per_token = self.price_per_input_token
97 |         return approx_cost_for_datapoint(dp=dp, price_per_input_token=cost_per_token)
98 | 
99 |     def get_latency(self, dp: Datapoint) -> float:
100 |         latency_per_output_token = self.latency_ms_per_output_token
101 |         return approx_cost_for_datapoint(dp=dp, price_per_input_token=latency_per_output_token)
102 | 
103 |     def get_capability(self) -> float:
104 |         # honor a capability override passed to the constructor
105 |         return self.capability
106 | 
107 |     def supports_dp(self, dp: Datapoint) -> bool:
108 |         prompt = approx_prompt_str(dp)
109 |         return approx_num_tokens(prompt) <= self.max_context_length
110 | 
111 |     def generate_message(
112 |         self,
113 |         messages: list[Message],
114 |         force_json: bool,
115 |         temperature: float | None = None,
116 |     ) -> Message:
117 |         if temperature is None:
118 |             temperature = self.temperature
119 |         msgs = self.build_generate_message_state(messages)
120 |         res = self.client.chat.completions.create(
121 |             model=self.model,
122 |             messages=msgs,
123 |             temperature=wrap_temperature(temperature=temperature),
124 |         )
125 |         return self.handle_generate_message_response(
126 |             prompt=msgs, content=res.choices[0].message.content, force_json=force_json
127 |         )
128 | 
129 |     def force_json_prompt(self, text: str, _: bool = False) -> str:
130 |         return super().force_json_prompt(text, with_prefix=True)
131 | 
--------------------------------------------------------------------------------
/tau-bench-eval/tau_bench/model_utils/model/vllm_completion.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any
3 | 
4 | from pydantic import BaseModel
5 | 
6 | from tau_bench.model_utils.api.datapoint import Datapoint
7 | from tau_bench.model_utils.model.completion import (
8 |     CompletionModel,
9 |     approx_cost_for_datapoint,
10 |     approx_prompt_str,
11 | )
12 | from tau_bench.model_utils.model.utils import approx_num_tokens
13 | from tau_bench.model_utils.model.vllm_utils import generate_request
14 | 
15 | PRICE_PER_INPUT_TOKEN_MAP = {
16 |     "Qwen/Qwen2-0.5B-Instruct": 0.0,
17 |     "Qwen/Qwen2-1.5B-Instruct": 0.0,
18 |     "Qwen/Qwen2-7B-Instruct": 0.0,
19 |     "Qwen/Qwen2-72B-Instruct": 0.0,
20 |     "meta-llama/Meta-Llama-3-8B-Instruct": 0.0,
21 |     "meta-llama/Meta-Llama-3.1-8B-Instruct": 0.0,
22 |     "meta-llama/Meta-Llama-3-70B-Instruct": 0.0,
23 |     "mistralai/Mistral-Nemo-Instruct-2407": 0.0,
24 | }
25 | INPUT_PRICE_PER_TOKEN_FALLBACK = 0.0
26 | 
27 | # TODO: refine this
28 | CAPABILITY_SCORE_MAP = {
29 |     "Qwen/Qwen2-0.5B-Instruct": 0.05,
30 |     "Qwen/Qwen2-1.5B-Instruct": 0.07,
31 |     "Qwen/Qwen2-7B-Instruct": 0.2,
32 |     "Qwen/Qwen2-72B-Instruct": 0.4,
33 |     "meta-llama/Meta-Llama-3.1-8B-Instruct": 0.3,
34 |     "sierra-research/Meta-Llama-3.1-8B-Instruct": 0.3,
35 |     "meta-llama/Meta-Llama-3.1-70B-Instruct": 0.5,
36 |     "mistralai/Mistral-Nemo-Instruct-2407": 0.3,
37 | }
38 | CAPABILITY_SCORE_FALLBACK = 0.1
39 | 
40 | # TODO: implement
41 | LATENCY_MS_PER_OUTPUT_TOKEN_MAP = {}
42 | # TODO: implement
43 | LATENCY_MS_PER_OUTPUT_TOKEN_FALLBACK = 0.0
44 | 
45 | MAX_CONTEXT_LENGTH_MAP = {
46 |     "Qwen/Qwen2-0.5B-Instruct": 32768,
47 |     "Qwen/Qwen2-1.5B-Instruct": 32768,
48 |     "Qwen/Qwen2-7B-Instruct": 131072,
49 |     "Qwen/Qwen2-72B-Instruct": 131072,
"Qwen/Qwen2-72B-Instruct": 131072, 50 | "meta-llama/Meta-Llama-3.1-8B-Instruct": 128000, 51 | "sierra-research/Meta-Llama-3.1-8B-Instruct": 128000, 52 | "meta-llama/Meta-Llama-3.1-70B-Instruct": 128000, 53 | "mistralai/Mistral-Nemo-Instruct-2407": 128000, 54 | } 55 | MAX_CONTEXT_LENGTH_FALLBACK = 128000 56 | 57 | 58 | class VLLMCompletionModel(CompletionModel): 59 | def __init__( 60 | self, 61 | model: str, 62 | base_url: str, 63 | endpoint: str = "generate", 64 | temperature: float = 0.0, 65 | price_per_input_token: float | None = None, 66 | capability: float | None = None, 67 | latency_ms_per_output_token: float | None = None, 68 | max_context_length: int | None = None, 69 | ) -> None: 70 | self.model = model 71 | self.base_url = base_url 72 | self.url = os.path.join(base_url, endpoint) 73 | self.temperature = temperature 74 | self.price_per_input_token = ( 75 | price_per_input_token 76 | if price_per_input_token is not None 77 | else PRICE_PER_INPUT_TOKEN_MAP.get(model, INPUT_PRICE_PER_TOKEN_FALLBACK) 78 | ) 79 | self.capability = ( 80 | capability 81 | if capability is not None 82 | else CAPABILITY_SCORE_MAP.get(model, CAPABILITY_SCORE_FALLBACK) 83 | ) 84 | self.latency_ms_per_output_token = ( 85 | latency_ms_per_output_token 86 | if latency_ms_per_output_token is not None 87 | else LATENCY_MS_PER_OUTPUT_TOKEN_MAP.get(model, LATENCY_MS_PER_OUTPUT_TOKEN_FALLBACK) 88 | ) 89 | self.max_context_length = ( 90 | max_context_length 91 | if max_context_length is not None 92 | else MAX_CONTEXT_LENGTH_MAP.get(model, MAX_CONTEXT_LENGTH_FALLBACK) 93 | ) 94 | 95 | def generate_from_prompt(self, prompt: str, temperature: float = 0.0) -> str: 96 | return generate_request(url=self.url, prompt=prompt, temperature=temperature) 97 | 98 | def parse_force_from_prompt( 99 | self, prompt: str, typ: BaseModel | dict[str, Any], temperature: float | None = None 100 | ) -> dict[str, Any]: 101 | if temperature is None: 102 | temperature = self.temperature 103 | res = generate_request( 104 | url=self.url, prompt=prompt, force_json=True, temperature=temperature 105 | ) 106 | return self.handle_parse_force_response(prompt=prompt, content=res) 107 | 108 | def get_approx_cost(self, dp: Datapoint) -> float: 109 | cost_per_token = self.price_per_input_token 110 | return approx_cost_for_datapoint(dp=dp, price_per_input_token=cost_per_token) 111 | 112 | def get_latency(self, dp: Datapoint) -> float: 113 | latency_per_output_token = self.latency_ms_per_output_token 114 | return approx_cost_for_datapoint(dp=dp, price_per_input_token=latency_per_output_token) 115 | 116 | def get_capability(self) -> float: 117 | return CAPABILITY_SCORE_MAP.get(self.model, CAPABILITY_SCORE_FALLBACK) 118 | 119 | def supports_dp(self, dp: Datapoint) -> bool: 120 | prompt = approx_prompt_str(dp) 121 | return approx_num_tokens(prompt) <= self.max_context_length 122 | -------------------------------------------------------------------------------- /tau-bench-eval/tau_bench/model_utils/model/vllm_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import requests 4 | 5 | from tau_bench.model_utils.model.general_model import wrap_temperature 6 | 7 | 8 | def generate_request( 9 | url: str, 10 | prompt: str, 11 | temperature: float = 0.0, 12 | force_json: bool = False, 13 | **req_body_kwargs: Any, 14 | ) -> str: 15 | args = { 16 | "prompt": prompt, 17 | "temperature": wrap_temperature(temperature), 18 | "max_tokens": 4096, 19 | **req_body_kwargs, 20 | } 21 | if force_json: 22 | # 
23 |         args["stop"] = ["```"]
24 |     res = requests.post(
25 |         url,
26 |         json=args,
27 |     )
28 |     res.raise_for_status()
29 |     json_res = res.json()
30 |     if "text" not in json_res:
31 |         raise ValueError(f"Unexpected response: {json_res}")
32 |     elif len(json_res["text"]) == 0:
33 |         raise ValueError(f"Empty response: {json_res}")
34 |     text = json_res["text"][0]
35 |     assert isinstance(text, str)
36 |     return text.removeprefix(prompt)
--------------------------------------------------------------------------------
/tau-bench-eval/tau_bench/types.py:
--------------------------------------------------------------------------------
1 | # Copyright Sierra
2 | 
3 | from pydantic import BaseModel
4 | from typing import List, Dict, Any, Optional, Union
5 | 
6 | RESPOND_ACTION_NAME = "respond"
7 | RESPOND_ACTION_FIELD_NAME = "content"
8 | 
9 | 
10 | class Action(BaseModel):
11 |     name: str
12 |     kwargs: Dict[str, Any]
13 | 
14 | 
15 | class Task(BaseModel):
16 |     user_id: str
17 |     actions: List[Action]
18 |     instruction: str
19 |     outputs: List[str]
20 | 
21 | 
22 | class RewardOutputInfo(BaseModel):
23 |     r_outputs: float
24 |     outputs: Dict[str, bool]
25 | 
26 | 
27 | class RewardActionInfo(BaseModel):
28 |     r_actions: float
29 |     gt_data_hash: str
30 | 
31 | 
32 | class RewardResult(BaseModel):
33 |     reward: float
34 |     info: Union[RewardOutputInfo, RewardActionInfo]
35 |     actions: List[Action]
36 | 
37 | 
38 | class SolveResult(BaseModel):
39 |     reward: float
40 |     messages: List[Dict[str, Any]]
41 |     info: Dict[str, Any]
42 |     total_cost: Optional[float] = None
43 | 
44 | 
45 | class EnvInfo(BaseModel):
46 |     task: Task
47 |     source: Optional[str] = None
48 |     user_cost: Optional[float] = None
49 |     reward_info: Optional[RewardResult] = None
50 | 
51 | 
52 | class EnvResponse(BaseModel):
53 |     observation: str
54 |     reward: float
55 |     done: bool
56 |     info: EnvInfo
57 | 
58 | 
59 | class EnvResetResponse(BaseModel):
60 |     observation: str
61 |     info: EnvInfo
62 | 
63 | 
64 | class EnvRunResult(BaseModel):
65 |     task_id: int
66 |     reward: float
67 |     info: Dict[str, Any]
68 |     traj: List[Dict[str, Any]]
69 |     trial: int
70 | 
--------------------------------------------------------------------------------
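Note: a minimal sketch of how the pydantic types above compose (hypothetical field values; assumes the tau_bench package is installed):

    from tau_bench.types import Action, EnvRunResult, Task

    task = Task(
        user_id="user_123",  # hypothetical id
        instruction="Cancel my most recent pending order.",
        actions=[Action(name="cancel_pending_order", kwargs={"order_id": "#W100"})],
        outputs=[],
    )
    result = EnvRunResult(task_id=0, reward=1.0, info={}, traj=[], trial=0)
    # pydantic v2 serialization helpers
    print(task.model_dump_json(indent=2))
    print(result.model_dump_json())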