├── .gitignore ├── resources ├── complex-example.png ├── data-collection.png └── evaluation-pipeline.png ├── .env ├── requirements.txt ├── prompts ├── prompts.py ├── compare.py └── response.py ├── models ├── mistral.py ├── qwen.py ├── claude.py ├── gpt.py ├── llama.py └── glm.py ├── utils ├── logger.py ├── utils.py ├── rapidapi.py ├── tool_info.json ├── exact_match_values.json └── compare_method.py ├── print_results.py ├── runner ├── response_runner.py ├── base_runner.py ├── qwen_runner.py ├── llama_runner.py ├── mistral_runner.py ├── gpt_runner.py ├── claude_runner.py └── glm_runner.py ├── evaluation.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | result 3 | data -------------------------------------------------------------------------------- /resources/complex-example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zai-org/ComplexFuncBench/HEAD/resources/complex-example.png -------------------------------------------------------------------------------- /resources/data-collection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zai-org/ComplexFuncBench/HEAD/resources/data-collection.png -------------------------------------------------------------------------------- /resources/evaluation-pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zai-org/ComplexFuncBench/HEAD/resources/evaluation-pipeline.png -------------------------------------------------------------------------------- /.env: -------------------------------------------------------------------------------- 1 | OPENAI_API_KEY= 2 | ZHIPUAI_API_KEY= 3 | ANTHROPIC_API_KEY= 4 | MISTRAL_API_KEY= 5 | Qwen_aliyuncs_KEY= 6 | 7 | RAPID_API_KEY= -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | anthropic 2 | zhipuai 3 | mistralai 4 | numpy 5 | openai 6 | Requests 7 | scipy 8 | FlagEmbedding 9 | torch==2.3.0 10 | -------------------------------------------------------------------------------- /prompts/prompts.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Text 2 | from dataclasses import dataclass 3 | 4 | @dataclass 5 | class SimpleTemplatePrompt: 6 | template: str 7 | args_order: list 8 | 9 | def __call__(self, **kwargs: Any) -> Text: 10 | self.cur_template = self.template 11 | args = [kwargs[arg] for arg in self.args_order] 12 | for i, arg in enumerate(args): 13 | if isinstance(arg, int): arg = str(arg) 14 | self.cur_template = self.cur_template.replace(f"[args{str(i+1)}]", arg) 15 | return self.cur_template 16 | -------------------------------------------------------------------------------- /models/mistral.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import os 3 | from openai import OpenAI 4 | import json 5 | import sys 6 | import copy 7 | import os 8 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 9 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 10 | from prompts.prompts import SimpleTemplatePrompt 11 | from utils.utils import * 12 | from mistralai import Mistral 13 | 14 | 15 | class MistralModel: 16 | def __init__(self, model_name): 17 | super().__init__() 18 | self.model_name = model_name 19 | self.client = Mistral(api_key=os.getenv("MISTRAL_API_KEY")) 20 | 21 | self.messages = [] 22 | 23 | @retry(max_attempts=10, delay=60) 24 | def __call__(self, messages, tools=None, **kwargs: Any): 25 | if "function_call" not in json.dumps(messages, ensure_ascii=False): 26 | self.messages = copy.deepcopy(messages) 27 | try: 28 | completion = self.client.chat.complete( 29 | model=self.model_name, 30 | messages=self.messages, 31 | temperature=0.0, 32 | tools=tools, 33 | tool_choice="auto", 34 | max_tokens=2048 35 | ) 36 | return completion.choices[0].message 37 | except Exception as e: 38 | print(f"Exception: {e}") 39 | return None 40 | -------------------------------------------------------------------------------- /models/qwen.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | from typing import Any, Dict 4 | from utils.utils import * 5 | from openai import OpenAI 6 | 7 | """ 8 | You can also deploy Qwen2.5 via vLLM, please enable the auto-tool-choice. 9 | ```bash 10 | vllm serve Qwen/Qwen2.5-7B-Instruct --enable-auto-tool-choice --tool-call-parser hermes 11 | ``` 12 | Note: Tool support has been available in vllm since v0.6.0. Be sure to install a version that supports tool use. 13 | Reference: https://qwen.readthedocs.io/en/latest/framework/function_call.html#vllm 14 | """ 15 | 16 | class QwenModel: 17 | def __init__(self, model_name): 18 | self.model_name = model_name 19 | self.messages = [] 20 | self.client = OpenAI( 21 | api_key=os.getenv("Qwen_aliyuncs_KEY"), 22 | base_url="https://dashscope.aliyuncs.com/compatible-mode/v1") 23 | 24 | @retry(max_attempts=5, delay=20) 25 | def __call__(self, messages, tools=None, **kwargs: Any): 26 | if "function_call" not in json.dumps(messages, ensure_ascii=False): 27 | self.messages = copy.deepcopy(messages) 28 | try: 29 | completion = self.client.chat.completions.create( 30 | model=self.model_name, 31 | messages=self.messages, 32 | temperature=0.0, 33 | tools=tools, 34 | tool_choice="auto", 35 | max_tokens=2048 36 | ) 37 | return completion.model_dump()['choices'][0]['message'] 38 | except Exception as e: 39 | print(f"Exception: {e}") 40 | return None -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | class Logger: 5 | def __init__(self, name='my_logger', log_file='test.log', level=logging.INFO): 6 | self.logger = logging.getLogger(name) 7 | self.logger.setLevel(level) 8 | 9 | # Create handlers 10 | self.file_handler = logging.FileHandler(log_file) 11 | self.console_handler = logging.StreamHandler() 12 | 13 | # Configure file handler 14 | self.file_handler.setLevel(level) 15 | 16 | # Configure console handler 17 | self.console_handler.setLevel(level) 18 | 19 | # Create a logging format 20 | self.formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 21 | 22 | # Add the formatter to handlers 23 | self.file_handler.setFormatter(self.formatter) 24 | self.console_handler.setFormatter(self.formatter) 25 | 26 | # Add handlers to the logger 27 | if not self.logger.handlers: 28 | self.logger.addHandler(self.file_handler) 29 | self.logger.addHandler(self.console_handler) 30 | 31 | def debug(self, msg): 32 | self.logger.debug(msg) 33 | 34 | def info(self, msg): 35 | self.logger.info(msg) 36 | 37 | def warning(self, msg): 38 | self.logger.warning(msg) 39 | 40 | def error(self, msg): 41 | self.logger.error(msg) 42 | 43 | def critical(self, msg): 44 | self.logger.critical(msg) 45 | 46 | # Example usage 47 | if __name__ == "__main__": 48 | log = Logger(name="test_logger", log_file="logs/test.log", level=logging.DEBUG) 49 | log.debug("This is a debug message") 50 | log.info("This is an info message") 51 | log.warning("This is a warning message") 52 | log.error("This is an error message") -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import time 4 | import traceback 5 | import logging 6 | 7 | 8 | def load_json(dir_path): 9 | if dir_path.endswith('.json'): 10 | return json.load(open(dir_path, 'r')) 11 | elif dir_path.endswith('.jsonl'): 12 | return [json.loads(line) for line in open(dir_path, 'r')] 13 | 14 | 15 | def save_json(data, dir_path): 16 | if dir_path.endswith('.json'): 17 | with open(dir_path, 'w') as f: 18 | json.dump(data, f, ensure_ascii=False, indent=4) 19 | elif dir_path.endswith(".jsonl"): 20 | with open(dir_path, 'w') as f: 21 | for line in data: 22 | f.write(json.dumps(line, ensure_ascii=False) + "\n") 23 | 24 | def decode_json(json_str): 25 | json_str = json_str.strip('```JSON\n').strip('```json\n').strip('\n```') 26 | json_str = json_str.replace('\n', '').replace('False', 'false').replace('True', 'true') 27 | try: 28 | return json.loads(json_str) 29 | except: 30 | return None 31 | 32 | 33 | def exception_handler(func): 34 | def wrapper(*args, **kwargs): 35 | try: 36 | return func(*args, **kwargs) 37 | except Exception as e: 38 | print(f"An error occurred in {func.__name__}: {e}") 39 | tb = traceback.format_exc() 40 | print(f"Traceback:\n{tb}") 41 | return None 42 | return wrapper 43 | 44 | 45 | def apply_decorator_to_all_methods(decorator): 46 | def class_decorator(cls): 47 | for attr in dir(cls): 48 | if callable(getattr(cls, attr)) and not attr.startswith("__"): 49 | setattr(cls, attr, decorator(getattr(cls, attr))) 50 | return cls 51 | return class_decorator 52 | 53 | 54 | from functools import wraps 55 | def retry(max_attempts=5, delay=2): 56 | def decorator(func): 57 | @wraps(func) 58 | def wrapper(*args, **kwargs): 59 | attempt = 0 60 | while attempt < max_attempts: 61 | response = func(*args, **kwargs) 62 | if response is not None: 63 | return response 64 | attempt += 1 65 | print(f"Attempt {attempt}/{max_attempts} failed.") 66 | time.sleep(delay) 67 | return response 68 | return wrapper 69 | return decorator -------------------------------------------------------------------------------- /models/claude.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import os 3 | from anthropic import Anthropic 4 | import copy 5 | import json 6 | import sys 7 | import os 8 | from urllib.parse import unquote 9 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 10 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 11 | from prompts.prompts import SimpleTemplatePrompt 12 | from utils.utils import * 13 | 14 | class ClaudeModel: 15 | def __init__(self, model_name): 16 | super().__init__() 17 | self.model_name = model_name 18 | self.client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) 19 | 20 | def __call__(self, prefix, prompt: SimpleTemplatePrompt, **kwargs: Any): 21 | filled_prompt = prompt(**kwargs) 22 | prediction = self._predict(prefix, filled_prompt) 23 | return prediction 24 | 25 | @retry(max_attempts=10, delay=60) 26 | def _predict(self, prefix, query): 27 | try: 28 | completion = self.client.messages.create( 29 | model=self.model_name, 30 | messages=[ 31 | {"role": "system", "content": prefix}, 32 | {"role": "user", "content": query} 33 | ], 34 | temperature=0.0, 35 | do_sample=False 36 | ) 37 | return completion.choices[0].message.content 38 | except Exception as e: 39 | print(f"Exception: {e}") 40 | return None 41 | 42 | class FunctionCallClaude(ClaudeModel): 43 | def __init__(self, model_name): 44 | super().__init__(model_name) 45 | 46 | @retry(max_attempts=10, delay=60) 47 | def __call__(self, messages, tools=None, **kwargs: Any): 48 | if "function_call" not in json.dumps(messages, ensure_ascii=False): 49 | self.messages = copy.deepcopy(messages) 50 | try: 51 | response = self.client.messages.create( 52 | model=self.model_name, 53 | messages=self.messages, 54 | temperature=0.0, 55 | tools=tools, 56 | max_tokens=2048, 57 | tool_choice={"type": "auto"} 58 | ) 59 | return response 60 | except Exception as e: 61 | print(f"Exception: {e}") 62 | return None 63 | 64 | if __name__ == "__main__": 65 | model = ClaudeModel("claude-3-5-sonnet-20240620") 66 | response_message = model._predict("You are a helpful assistant.", query="What is the capital of France?") 67 | print(response_message) -------------------------------------------------------------------------------- /prompts/compare.py: -------------------------------------------------------------------------------- 1 | from prompts.prompts import SimpleTemplatePrompt 2 | 3 | 4 | system_prompt = """"You are an assistant for function call comparison. Your task is to determine whether two function calls are equivalent based on the conversation history and the function descriptions, and provide specific reasons. 5 | # Instructions: 6 | You need to determine whether two function calls are equivalent based on the following criteria: 7 | 1. The same parameter can be expressed in different languages. For example: `America`, `美国` and `アメリカ` are equivalent. 8 | 2. The same parameter can be expressed in different forms as long as the meaning is the same. For example, `New York` and `NY` are equivalent, `Shanghai` and `Shanghai City` are equivalent. `Narita International Airport` and `Tokyo Narita International Airport` are equivalent. 9 | 3. A location with or without a country suffix is considered equivalent. For example, Van Gogh Museum, Amsterdam and Van Gogh Museum are equivalent. 10 | 4. The order of parameters in a function call can differ, for example: `add(1, 2)` and `add(2, 1)` are equivalent. 11 | 5. If a parameter in the function description has a default value, and the current parameter value is equal to that default value, then the parameter can be omitted in the function call. For example, if the adults parameter in the Search_Hotels function has a default value of 1, then the following two function calls are equivalent: `Search_Hotels(New_York, adults=1)` and `Search_Hotels(New_York)`. 12 | 13 | # Output: 14 | You need to output the result in JSON format, containing the following fields: 15 | - **is_equal**: A boolean indicating whether the two function calls are equivalent. 16 | - **reason**: Please provide the reason for your judgment. 17 | 18 | ## Example 1 19 | output: 20 | ```JSON 21 | {"is_equal": true, "reason": "xx and yy are equivalent."} 22 | ``` 23 | 24 | ## Example 2 25 | output: 26 | ```JSON 27 | {"is_equal": false, "reason": "The parameter in call xx is not present in call yy."} 28 | ``` 29 | 30 | Note: The output must be a valid JSON format that can be properly loaded. 31 | """ 32 | 33 | # User prompt 34 | user_prompt = SimpleTemplatePrompt( 35 | template=("""Function list: 36 | ```JSON 37 | [args1] 38 | ``` 39 | Conversation history: 40 | ```JSON 41 | [args2] 42 | ``` 43 | Function call 1: 44 | ```JSON 45 | [args3] 46 | ``` 47 | Function call 2: 48 | ```JSON 49 | [args4] 50 | ``` 51 | Please determine whether Function call 1 and Function call 2 are equivalent and provide your reason. 52 | output:\n 53 | """ 54 | ), args_order=["functions", "history", "function_call_1", "function_call_2"]) 55 | -------------------------------------------------------------------------------- /models/gpt.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import os 3 | from openai import OpenAI 4 | import json 5 | import sys 6 | import copy 7 | import os 8 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 9 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 10 | from prompts.prompts import SimpleTemplatePrompt 11 | from utils.utils import * 12 | 13 | 14 | class GPTModel: 15 | def __init__(self, model_name): 16 | super().__init__() 17 | self.model_name = model_name 18 | self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) 19 | 20 | 21 | def __call__(self, prefix, prompt: SimpleTemplatePrompt, **kwargs: Any): 22 | filled_prompt = prompt(**kwargs) 23 | prediction = self._predict(prefix, filled_prompt, **kwargs) 24 | return prediction 25 | 26 | @retry(max_attempts=10) 27 | def _predict(self, prefix, text, **kwargs): 28 | try: 29 | completion = self.client.chat.completions.create( 30 | model=self.model_name, 31 | messages=[ 32 | {"role": "system", "content": prefix}, 33 | {"role": "user", "content": text} 34 | ], 35 | temperature=0.0, 36 | ) 37 | return completion.choices[0].message.content 38 | except Exception as e: 39 | print(f"Exception: {e}") 40 | return None 41 | 42 | 43 | class FunctionCallGPT(GPTModel): 44 | def __init__(self, model_name): 45 | super().__init__(None) 46 | self.model_name = model_name 47 | self.messages = [] 48 | 49 | @retry(max_attempts=5, delay=10) 50 | def __call__(self, messages, tools=None, **kwargs: Any): 51 | if "function_call" not in json.dumps(messages, ensure_ascii=False): 52 | self.messages = copy.deepcopy(messages) 53 | try: 54 | completion = self.client.chat.completions.create( 55 | model=self.model_name, 56 | messages=self.messages, 57 | temperature=0.0, 58 | tools=tools, 59 | tool_choice="auto", 60 | max_tokens=2048 61 | ) 62 | return completion.choices[0].message 63 | except Exception as e: 64 | print(f"Exception: {e}") 65 | return None 66 | 67 | 68 | if __name__ == "__main__": 69 | model = GPTModel("gpt-4") 70 | response = model("You are a helpful assistant.", SimpleTemplatePrompt(template=("What is the capital of France?"), args_order=[])) 71 | print(response) -------------------------------------------------------------------------------- /print_results.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | from utils.utils import * 4 | from collections import Counter 5 | import argparse 6 | 7 | 8 | def basic_metric(result_dir): 9 | results = load_json(result_dir) 10 | domain_success = defaultdict(int) 11 | domain_turn_count = defaultdict(lambda: [0, 0]) 12 | domain_call_count = defaultdict(lambda: [0, 0]) 13 | complete_score_count = defaultdict(lambda: [0, 0]) 14 | correct_score_count = defaultdict(lambda: [0, 0]) 15 | for result in results: 16 | domain = result['id'].rsplit("-", 1)[0] 17 | if result['message'] == "Success.": 18 | domain_success[domain] += 1 19 | domain_turn_count[domain][0] += result['count_dict']['success_turn_num'] 20 | domain_turn_count[domain][1] += result['count_dict']['total_turn_num'] 21 | 22 | domain_call_count[domain][0] += result['count_dict']['correct_call_num'] 23 | domain_call_count[domain][1] += result['count_dict']['total_call_num'] 24 | 25 | if result["resp_eval"] is None: 26 | continue 27 | 28 | if result["resp_eval"]['complete']['score'] in {0, 1, 2}: 29 | complete_score_count[domain][0] += result["resp_eval"]['complete']['score'] 30 | complete_score_count[domain][1] += 1 31 | 32 | if result["resp_eval"]['correct']['score'] in {0, 1, 2}: 33 | correct_score_count[domain][0] += result["resp_eval"]['correct']['score'] 34 | correct_score_count[domain][1] += 1 35 | 36 | domain_success_rate = {k: v / 150 * 100 if k != "Cross" else v / 400 * 100 for k, v in domain_success.items()} 37 | domain_turn_acc = {k: v[0] / v[1] * 100 if v[1] != 0 else 0 for k, v in domain_turn_count.items()} 38 | domain_call_acc = {k: v[0] / v[1] * 100 if v[1] != 0 else 0 for k, v in domain_call_count.items()} 39 | 40 | overall_success = sum(domain_success.values()) / 1000 * 100 41 | overall_call_acc = sum([v[0] for v in domain_call_count.values()]) / sum([v[1] for v in domain_call_count.values()]) * 100 42 | 43 | complete_score, complete_total = 0, 0 44 | for k, v in complete_score_count.items(): 45 | complete_score += v[0] 46 | complete_total += v[1] 47 | complete_score_avg = complete_score / complete_total if complete_total != 0 else 0 48 | 49 | correct_score, correct_total = 0, 0 50 | for k, v in correct_score_count.items(): 51 | correct_score += v[0] 52 | correct_total += v[1] 53 | correct_score_avg = correct_score / correct_total if correct_total != 0 else 0 54 | 55 | 56 | print(f"Domain Success Rate: {domain_success_rate}") 57 | # print(f"Domain Turn Accuracy: {domain_turn_acc}") 58 | print(f"Domain Call Accuracy: {domain_call_acc}") 59 | print(f"Overall Success Rate: {overall_success}") 60 | print(f"Overall Call Accuracy: {overall_call_acc}") 61 | print(f"Complete Score: {complete_score_avg}") 62 | print(f"Correct Score: {correct_score_avg}") 63 | 64 | 65 | def main(): 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument("--log_dir", type=str, default="logs/test.log") 68 | parser.add_argument("--result_dir", type=str, default="result/../All.jsonl") 69 | args = parser.parse_args() 70 | basic_metric(args.result_dir) 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | if __name__ == "__main__": 79 | main() -------------------------------------------------------------------------------- /runner/response_runner.py: -------------------------------------------------------------------------------- 1 | import json 2 | import copy 3 | from utils.utils import * 4 | from models.gpt import GPTModel 5 | 6 | from prompts.response import ( 7 | complete_system_prompt, 8 | complete_user_prompt, 9 | correct_system_prompt, 10 | correct_user_prompt 11 | ) 12 | 13 | class RespEvalRunner: 14 | def __init__(self, args, logger): 15 | self.logger = logger 16 | self.model = GPTModel("gpt-4o-2024-08-06") 17 | 18 | @retry(max_attempts=10) 19 | def completeness_eval(self, **kwargs): 20 | complete_result = self.model(complete_system_prompt, complete_user_prompt, **kwargs) 21 | decoded_complete_result = decode_json(complete_result) 22 | 23 | self.logger.info(f"Complete Result: {decoded_complete_result}") 24 | 25 | if not isinstance(decoded_complete_result, dict) or "score" not in decoded_complete_result: 26 | return None 27 | if decoded_complete_result['score'] not in [0, 1, 2]: 28 | return None 29 | return decoded_complete_result 30 | 31 | @retry(max_attempts=10) 32 | def correctness_eval(self, **kwargs): 33 | correct_result = self.model(correct_system_prompt, correct_user_prompt, **kwargs) 34 | decoded_correct_result = decode_json(correct_result) 35 | self.logger.info(f"Correct Result: {decoded_correct_result}") 36 | if not isinstance(decoded_correct_result, dict) or "score" not in decoded_correct_result: 37 | return None 38 | if decoded_correct_result['score'] not in [0, 1, 2]: 39 | return None 40 | return decoded_correct_result 41 | 42 | def run(self, data, gen_response): 43 | if gen_response == "": 44 | return { 45 | "complete": {"score": -2, "reason": "Do not generate response successfully."}, 46 | "correct": {"score": -2, "reason": "Do not generate response successfully."} 47 | } 48 | 49 | convs = data['conversations'] 50 | 51 | kwargs = { 52 | "query": convs[0]['content'], 53 | "gen_response": gen_response, 54 | } 55 | 56 | complete_result = self.completeness_eval(**kwargs) 57 | 58 | kwargs = { 59 | "history": json.dumps(convs[:-1], ensure_ascii=False), 60 | "gen_response": gen_response 61 | } 62 | 63 | correct_result = self.correctness_eval(**kwargs) 64 | 65 | if complete_result and correct_result: 66 | return { 67 | "complete": {"score": complete_result['score'], "reason": complete_result.get("reason", None)}, 68 | "correct": {"score": correct_result['score'], "reason": correct_result.get("reason", None)} 69 | } 70 | elif complete_result: 71 | return { 72 | "complete": {"score": complete_result['score'], "reason": complete_result.get("reason", None)}, 73 | "correct": {"score": -1, "reason": "Correctness Eval failed."} 74 | } 75 | elif correct_result: 76 | return { 77 | "complete": {"score": -1, "reason": "Completeness Eval failed."}, 78 | "correct": {"score": correct_result['score'], "reason": correct_result.get("reason", None)} 79 | } 80 | else: 81 | return { 82 | "complete": {"score": -1, "reason": "Completeness Eval failed."}, 83 | "correct": {"score": -1, "reason": "Correctness Eval failed."} 84 | } -------------------------------------------------------------------------------- /utils/rapidapi.py: -------------------------------------------------------------------------------- 1 | import json 2 | import requests 3 | import copy 4 | import os 5 | import random 6 | from utils.utils import * 7 | 8 | 9 | class RapidAPICall(): 10 | def __init__(self, tool, tool_info): 11 | self.remote = True 12 | self.name_to_url = tool_info['name_to_url'] 13 | self.headers = { 14 | "X-RapidAPI-Key": os.getenv("RAPID_API_KEY"), 15 | "X-RapidAPI-Host": tool_info['host'] 16 | } 17 | self.path_params = tool_info['path_params'] 18 | self.tool = tool 19 | 20 | @retry(max_attempts=3) 21 | def _call(self, func_call): 22 | self.url = self.name_to_url[func_call["name"]] 23 | params_copy = copy.deepcopy(func_call['arguments']) 24 | 25 | param_dict = {} 26 | for path_param in self.path_params: 27 | if f"{{{path_param}}}" in self.url and path_param in params_copy: 28 | param_dict[path_param] = params_copy.pop(path_param) 29 | self.url = self.url.format(**param_dict) 30 | 31 | for k, value in params_copy.items(): 32 | if k == "legs": 33 | params_copy[k] = json.dumps(value, ensure_ascii=False) 34 | try: 35 | response = requests.get(self.url, headers=self.headers, params=params_copy) 36 | except: 37 | return None 38 | 39 | if response.status_code == 200: 40 | # print("Request success.") 41 | response = response.json() 42 | if response['status'] == True: 43 | if "timestamp" in response: 44 | response.pop("timestamp") 45 | if "data" in response: 46 | response = response['data'] 47 | return response 48 | else: 49 | return None 50 | 51 | 52 | def observation_shorten(self, response): 53 | if isinstance(response, dict): 54 | keys_to_delete = [key for key, value in response.items() if (value in ["", None, {}, []])] 55 | for key in keys_to_delete: 56 | response.pop(key) 57 | 58 | for key, value in response.items(): 59 | response[key] = self.observation_shorten(value) 60 | 61 | elif isinstance(response, list): 62 | if len(response) > 10 and isinstance(response[0], dict): 63 | response = response[:10] 64 | response = [self.observation_shorten(item) for item in response] 65 | 66 | return response 67 | 68 | 69 | if __name__ == "__main__": 70 | with open("utils/tool_info.json", 'r') as f: 71 | tool_info = json.load(f) 72 | tool_info = tool_info['booking-com15'] 73 | api_call = RapidAPICall(tool="booking-com15", tool_info=tool_info) 74 | func_call = { 75 | "name": "Search_Flights_Multi_Stops", 76 | "arguments": { 77 | "legs": [ 78 | { 79 | "fromId": "ORD.AIRPORT", 80 | "toId": "HND.AIRPORT", 81 | "date": "2024-09-01" 82 | }, 83 | { 84 | "fromId": "HND.AIRPORT", 85 | "toId": "PVG.AIRPORT", 86 | "date": "2024-09-05" 87 | }, 88 | { 89 | "fromId": "PVG.AIRPORT", 90 | "toId": "ORD.AIRPORT", 91 | "date": "2024-09-10" 92 | } 93 | ] 94 | } 95 | } 96 | response = api_call._call(func_call) 97 | print(response) 98 | -------------------------------------------------------------------------------- /models/llama.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import os 3 | from openai import OpenAI 4 | import json 5 | import sys 6 | import copy 7 | import os 8 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 9 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 10 | from prompts.prompts import SimpleTemplatePrompt 11 | from utils.utils import * 12 | 13 | 14 | class LlamaModel: 15 | def __init__(self, url, model_name): 16 | super().__init__() 17 | self.temperature = 0.95 18 | self.model_name = model_name 19 | self.url = url 20 | self.client = OpenAI( 21 | api_key="EMPTY", 22 | base_url=self.url) 23 | 24 | self.messages = [] 25 | 26 | def _format_prompt(self, messages, function): 27 | formatted_prompt = "<|begin_of_text|>" 28 | 29 | system_message = "" 30 | remaining_messages = messages 31 | if messages[0]["role"] == "system": 32 | system_message = messages[0]["content"].strip() 33 | remaining_messages = messages[1:] 34 | 35 | formatted_prompt += "<|start_header_id|>system<|end_header_id|>\n\n" 36 | formatted_prompt += "Cutting Knowledge Date: December 2023\n" 37 | formatted_prompt += "Today Date: 23 Jul 2024\n\n" 38 | formatted_prompt += "When you receive a tool call response, use the output to format an answer to the orginal user question.\n\n" 39 | formatted_prompt += "You are a helpful assistant with tool calling capabilities." 40 | formatted_prompt += system_message + "<|eot_id|>\n" 41 | 42 | # Llama pass in custom tools in first user message 43 | is_first_user_message = True 44 | for message in remaining_messages: 45 | if message["role"] == "user" and is_first_user_message: 46 | is_first_user_message = False 47 | formatted_prompt += "<|start_header_id|>user<|end_header_id|>\n\n" 48 | formatted_prompt += "Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\n" 49 | formatted_prompt += 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.\n\n' 50 | for func in function: 51 | formatted_prompt += json.dumps(func, indent=4) + "\n\n" 52 | formatted_prompt += f"Question: {message['content'].strip()}<|eot_id|>" 53 | 54 | elif message["role"] == "tool": 55 | formatted_prompt += "<|start_header_id|>ipython<|end_header_id|>\n\n" 56 | if isinstance(message["content"], (dict, list)): 57 | formatted_prompt += json.dumps(message["content"]) 58 | else: 59 | formatted_prompt += message["content"] 60 | formatted_prompt += "<|eot_id|>" 61 | 62 | else: 63 | formatted_prompt += f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n{message['content'].strip()}<|eot_id|>" 64 | 65 | formatted_prompt += "\n<|start_header_id|>assistant<|end_header_id|>\n\n" 66 | 67 | return formatted_prompt 68 | 69 | 70 | @retry(max_attempts=5) 71 | def __call__(self, messages, tools=None, **kwargs: Any): 72 | if "function_call" not in json.dumps(messages, ensure_ascii=False): 73 | self.messages = copy.deepcopy(messages) 74 | prompt = self._format_prompt(self.messages, tools) 75 | try: 76 | completion = self.client.completions.create( 77 | model=self.model_name, 78 | prompt=prompt, 79 | temperature=0.0, 80 | max_tokens=4096 81 | ) 82 | return completion.choices[0].text 83 | except Exception as e: 84 | print(f"Exception: {e}") 85 | return None 86 | -------------------------------------------------------------------------------- /runner/base_runner.py: -------------------------------------------------------------------------------- 1 | import json 2 | import copy 3 | from utils.compare_method import CompareFC 4 | 5 | class ModelRunner: 6 | def __init__(self, args, logger): 7 | self.logger = logger 8 | 9 | self.error_message = None 10 | self.unexpect_call_resp = {"api_status": True, "content": "There is a problem with your api call, please double-check for possible problems."} 11 | 12 | self.CompareClass = CompareFC(args, logger) 13 | self.free_function_list = self.CompareClass.free_function_list 14 | 15 | def only_free_function(self, temp_fcs): 16 | for call in temp_fcs: 17 | if call['name'] == "Search_Hotels" and call["arguments"]["search_type"] == "hotel": 18 | return True 19 | 20 | temp_fcs = set([fc["name"] for fc in temp_fcs]) 21 | 22 | return temp_fcs.issubset(set(self.free_function_list)) 23 | 24 | def get_success_turn(self, remain_fcs, total_fcs): 25 | remain_ids = [] 26 | for idx, fc_list in enumerate(total_fcs): 27 | for remain_fc in remain_fcs: 28 | if remain_fc in fc_list: 29 | remain_ids.append(idx) 30 | if remain_ids == []: 31 | return len(total_fcs) 32 | 33 | return max(min(remain_ids), 0) 34 | 35 | def init_golden(self, convs): 36 | self.fc_chain = [] 37 | self.obs_chain = [] 38 | for turn in convs: 39 | if "function_call" in turn: 40 | self.fc_chain.append(turn['function_call']) 41 | elif turn['role'] == "observation": 42 | self.obs_chain.append(turn['content']) 43 | 44 | assert len(self.fc_chain) == len(self.obs_chain), "function call and observation length mismatch." 45 | 46 | self.turn_id, self.correct_count = 0, 0 47 | self.golden_fcs, self.golden_obs = copy.deepcopy(self.fc_chain[self.turn_id]), copy.deepcopy(self.obs_chain[self.turn_id]) 48 | 49 | if self.only_free_function(self.golden_fcs): 50 | self.update_current_golden() 51 | 52 | def update_current_golden(self): 53 | self.turn_id += 1 54 | if self.turn_id < len(self.fc_chain): 55 | self.golden_fcs.extend(copy.deepcopy(self.fc_chain[self.turn_id])) 56 | self.golden_obs.extend(copy.deepcopy(self.obs_chain[self.turn_id])) 57 | 58 | def return_result(self, messages, error_info=None): 59 | if error_info: 60 | success_turn = self.get_success_turn(self.golden_fcs, self.fc_chain) 61 | return messages, error_info, success_turn, self.correct_count 62 | 63 | # free function post process 64 | if len(self.golden_fcs) != 0: 65 | for call in self.golden_fcs: 66 | if call['name'] == "Search_Hotels" and call["arguments"]["search_type"] == "hotel": 67 | self.golden_fcs.remove(call) 68 | if call['name'] in ["Search_Hotel_Destination", "Search_Attraction_Location", "Search_Car_Location", "Search_Flight_Location", "Taxi_Search_Location"]: 69 | self.golden_fcs.remove(call) 70 | 71 | if self.turn_id < len(self.fc_chain) or len(self.golden_fcs) > 0: 72 | self.logger.info(f"turn id = {self.turn_id}; len(golden_answer) = {len(self.fc_chain)}") 73 | self.logger.info(f"golden_function_calls = {self.golden_fcs}") 74 | return self.return_result(messages, {"error_type": "stop_early", "content": "Stop early."}) 75 | elif len(self.golden_fcs) == 0: 76 | return messages, "Success.", len(self.fc_chain), self.correct_count 77 | else: 78 | raise NotImplementedError("Unexpected error.") 79 | 80 | def process_matches(self, success_matched): 81 | for matched in success_matched: 82 | if matched in self.golden_fcs: 83 | self.golden_obs.pop(self.golden_fcs.index(matched)) 84 | self.golden_fcs.remove(matched) 85 | 86 | if len(success_matched) > 0: 87 | self.update_current_golden() 88 | 89 | for k, v in self.CompareClass.free_functions.items(): 90 | if v['called'] == True and json.loads(k) in self.golden_fcs: 91 | self.golden_obs.pop(self.golden_fcs.index(json.loads(k))) 92 | self.golden_fcs.remove(json.loads(k)) 93 | 94 | if self.only_free_function(self.golden_fcs): 95 | self.update_current_golden() -------------------------------------------------------------------------------- /runner/qwen_runner.py: -------------------------------------------------------------------------------- 1 | import re 2 | import copy 3 | import json 4 | from models.qwen import QwenModel 5 | from runner.base_runner import ModelRunner 6 | 7 | 8 | class QwenRunner(ModelRunner): 9 | def __init__(self, args, logger): 10 | super().__init__(args, logger) 11 | self.model_name = args.model_name 12 | self.model = QwenModel(self.model_name) 13 | 14 | def get_standard_functions(self, functions): 15 | return [{"type": "function", "function": copy.deepcopy(func)} for func in functions] 16 | 17 | def get_standard_fc(self, tool_call): 18 | try: 19 | return {"name": tool_call['function']['name'], "arguments": json.loads(tool_call['function']['arguments'])} 20 | except: 21 | return None 22 | 23 | def run(self, data): 24 | convs, functions = data['conversations'], data['functions'] 25 | self.CompareClass.add_free_function(convs) 26 | standard_functions = self.get_standard_functions(functions) 27 | 28 | messages = [] 29 | query = convs[0]['content'] 30 | messages.append({"role": "user", "content": query}) 31 | 32 | self.init_golden(convs) 33 | 34 | while True: 35 | llm_response = self.model(messages, tools=standard_functions) 36 | 37 | if llm_response is None: 38 | raise NotImplementedError("LLM response is None") 39 | 40 | if llm_response['tool_calls'] != None: 41 | if self.golden_fcs == []: 42 | self.logger.error(f"Output FC:\n{llm_response}") 43 | return self.return_result(messages, {"error_type": "func_hallucination", "content": "`self.golden_fcs == []`. Expected to stop. But Model continue to output function call."}) 44 | if llm_response['content'] is None: 45 | llm_response['content'] = "" 46 | self.model.messages.append(llm_response) 47 | tool_calls = llm_response['tool_calls'] 48 | 49 | function_calls = [] 50 | for tool_call in tool_calls: 51 | function_call = self.get_standard_fc(tool_call) 52 | if function_call is None: 53 | return self.return_result(messages, {"error_type": "decode_error", "content": f"{tool_call} is not Valid."}) 54 | function_calls.append(function_call) 55 | self.logger.info(f"Function Calls: \n{json.dumps(function_calls, ensure_ascii=False, indent=4)}\n") 56 | self.logger.info(f"Golden Function Call: \n{json.dumps(self.golden_fcs, ensure_ascii=False, indent=4)}\n") 57 | messages.append({"role": "assistant", "function_call": function_calls}) 58 | 59 | self.error_message, success_map, success_matched, format_error = self.CompareClass.compare_turn_prediction( 60 | functions, messages[:-1], 61 | copy.deepcopy(function_calls), self.golden_fcs, 62 | self.golden_obs 63 | ) 64 | if len(success_map) == 0 and format_error == {}: 65 | return self.return_result(messages, self.error_message) 66 | self.correct_count += len(success_map) 67 | 68 | real_time_obs = [] 69 | for t, function_call in enumerate(function_calls): 70 | if t in success_map: 71 | temp_obs = success_map[t] 72 | elif t in format_error: 73 | temp_obs = format_error[t] 74 | else: 75 | temp_obs = self.unexpect_call_resp 76 | 77 | real_time_obs.append(temp_obs) 78 | self.model.messages.append( 79 | { 80 | "name": function_call['name'], 81 | "role": "tool", 82 | "content": json.dumps(temp_obs, ensure_ascii=False) 83 | } 84 | ) 85 | self.process_matches(success_matched) 86 | 87 | self.logger.info(f"Observations:\n{json.dumps(real_time_obs, ensure_ascii=False, indent=4)}\n") 88 | messages.append({"role": "observation", "content": real_time_obs}) 89 | 90 | elif llm_response['tool_calls'] == None: 91 | final_response = llm_response['content'] 92 | self.logger.info(f"Final Response: {final_response}\n") 93 | messages.append({"role": "assistant", "content": final_response}) 94 | 95 | return self.return_result(messages) 96 | 97 | else: 98 | return self.return_result(messages, {"error_type": "unknown_error", "content": "llm_response is None"}) -------------------------------------------------------------------------------- /utils/tool_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "booking-com15": { 3 | "host": "booking-com15.p.rapidapi.com", 4 | "path_params": [], 5 | "name_to_url": { 6 | "Search_Hotels_By_Coordinates": "https://booking-com15.p.rapidapi.com/api/v1/hotels/searchHotelsByCoordinates", 7 | "Get_Room_List_With_Availability": "https://booking-com15.p.rapidapi.com/api/v1/hotels/getRoomListWithAvailability", 8 | "Get_Room_Availability": "https://booking-com15.p.rapidapi.com/api/v1/hotels/getAvailability", 9 | "Get_Question_And_Answer": "https://booking-com15.p.rapidapi.com/api/v1/hotels/getQuestionAndAnswer", 10 | "Get_Nearby_Cities": "https://booking-com15.p.rapidapi.com/api/v1/hotels/getNearbyCities", 11 | "Get_Popular_Attraction_Near_By": "https://booking-com15.p.rapidapi.com/api/v1/hotels/getPopularAttractionNearBy", 12 | "Get_Hotel_Reviews(Tips)": "https://booking-com15.p.rapidapi.com/api/v1/hotels/getHotelReviews", 13 | "Get_Hotel_Reviews(Tips)_Sort_Option": "https://booking-com15.p.rapidapi.com/api/v1/hotels/getHotelReviewsSortOption", 14 | "Get_Hotel_Review_Scores": "https://booking-com15.p.rapidapi.com/api/v1/hotels/getHotelReviewScores", 15 | "Get_Hotel_Reviews_Filter_Metadata": "https://booking-com15.p.rapidapi.com/api/v1/hotels/getHotelReviewsFilterMetadata", 16 | "Property_Children_Policies": "https://booking-com15.p.rapidapi.com/api/v1/hotels/propertyChildrenPolicies", 17 | "Get_Hotel_Policies": "https://booking-com15.p.rapidapi.com/api/v1/hotels/getHotelPolicies", 18 | "Get_Room_List": "https://booking-com15.p.rapidapi.com/api/v1/hotels/getRoomList", 19 | "Get_Description_And_Info": "https://booking-com15.p.rapidapi.com/api/v1/hotels/getDescriptionAndInfo", 20 | "Get_Hotel_Details": "https://booking-com15.p.rapidapi.com/api/v1/hotels/getHotelDetails", 21 | "Get_Filter": "https://booking-com15.p.rapidapi.com/api/v1/hotels/getFilter", 22 | "Search_Hotel_Destination": "https://booking-com15.p.rapidapi.com/api/v1/hotels/searchDestination", 23 | "Get_Sort_By": "https://booking-com15.p.rapidapi.com/api/v1/hotels/getSortBy", 24 | "Payment_features_of_the_Hotel": "https://booking-com15.p.rapidapi.com/api/v1/hotels/getPaymentFeatures", 25 | "Search_Hotels": "https://booking-com15.p.rapidapi.com/api/v1/hotels/searchHotels", 26 | "Search_Car_Rentals": "https://booking-com15.p.rapidapi.com/api/v1/cars/searchCarRentals", 27 | "Search_Car_Location": "https://booking-com15.p.rapidapi.com/api/v1/cars/searchDestination", 28 | "Booking_Summary": "https://booking-com15.p.rapidapi.com/api/v1/cars/bookingSummary", 29 | "Vehicle_Supplier_Ratings": "https://booking-com15.p.rapidapi.com/api/v1/cars/vehicleSupplierRatings", 30 | "Vehicle_Details": "https://booking-com15.p.rapidapi.com/api/v1/cars/vehicleDetails", 31 | "Vehicle_Supplier_Details": "https://booking-com15.p.rapidapi.com/api/v1/cars/vehicleSupplierDetails", 32 | "Get_Packages": "https://booking-com15.p.rapidapi.com/api/v1/cars/getPackages", 33 | "Vehicle_Supplier_Review": "https://booking-com15.p.rapidapi.com/api/v1/cars/vehicleSupplierReview", 34 | "Location_to_Lat_Long": "https://booking-com15.p.rapidapi.com/api/v1/meta/locationToLatLong", 35 | "Test_API": "https://booking-com15.p.rapidapi.com/api/v1/test", 36 | "Get_Exchange_Rates": "https://booking-com15.p.rapidapi.com/api/v1/meta/getExchangeRates", 37 | "Get_Currency": "https://booking-com15.p.rapidapi.com/api/v1/meta/getCurrency", 38 | "Get_Languages": "https://booking-com15.p.rapidapi.com/api/v1/meta/getLanguages", 39 | "Get_Attraction_Details": "https://booking-com15.p.rapidapi.com/api/v1/attraction/getAttractionDetails", 40 | "Search_Attractions": "https://booking-com15.p.rapidapi.com/api/v1/attraction/searchAttractions", 41 | "Get_Availability_Calendar": "https://booking-com15.p.rapidapi.com/api/v1/attraction/getAvailabilityCalendar", 42 | "Get_Availability": "https://booking-com15.p.rapidapi.com/api/v1/attraction/getAvailability", 43 | "Search_Attraction_Location": "https://booking-com15.p.rapidapi.com/api/v1/attraction/searchLocation", 44 | "Get_Min_Price_Multi_Stops": "https://booking-com15.p.rapidapi.com/api/v1/flights/getMinPriceMultiStops", 45 | "Get_Flight_Details": "https://booking-com15.p.rapidapi.com/api/v1/flights/getFlightDetails", 46 | "Get_Min_Price": "https://booking-com15.p.rapidapi.com/api/v1/flights/getMinPrice", 47 | "Get_Seat_Map": "https://booking-com15.p.rapidapi.com/api/v1/flights/getSeatMap", 48 | "Search_Flights_Multi_Stops": "https://booking-com15.p.rapidapi.com/api/v1/flights/searchFlightsMultiStops", 49 | "Search_Flights": "https://booking-com15.p.rapidapi.com/api/v1/flights/searchFlights", 50 | "Search_Flight_Location": "https://booking-com15.p.rapidapi.com/api/v1/flights/searchDestination", 51 | "Search_Taxi": "https://booking-com15.p.rapidapi.com/api/v1/taxi/searchTaxi", 52 | "Taxi_Search_Location": "https://booking-com15.p.rapidapi.com/api/v1/taxi/searchLocation", 53 | "v2/auto-complete": "https://seeking-alpha.p.rapidapi.com/v2/auto-complete" 54 | } 55 | } 56 | } -------------------------------------------------------------------------------- /prompts/response.py: -------------------------------------------------------------------------------- 1 | from prompts.prompts import SimpleTemplatePrompt 2 | 3 | """ 4 | Completeness Evaluation. 5 | """ 6 | # System prompt 7 | complete_system_prompt = """"You are a helpful response completeness detect assistant. Your task is to evaluate the response based on whether it fully addresses all aspects of the user's query. 8 | # Your Task 9 | For each user query and corresponding response, you should determine the completeness of the response using the following criteria: 10 | - If the response covers all requested information and addresses all parts of the user's query, it should be considered complete and receive a score of 2. 11 | - If the response addresses some but not all parts of the user's query, it should be considered partial and receive a score of 1. 12 | - If the response does not address any of the requested information in the user's query, it should be considered incomplete and receive a score of 0. 13 | 14 | # Output Format 15 | You should output the score for each user query and corresponding response in JSON format with following keys: 16 | - score: the completeness score for the response (0, 1, or 2) 17 | - reason: a string describing the reason for the score (e.g. "all requested information is addressed", "some requested information is not addressed", "no requested information is addressed") 18 | 19 | # Example 20 | ## Example 1 21 | input: 22 | query: I'm thinking about heading to Paris. Can you tell me which museums are currently trending? 23 | response: The Louvre is the most treding museums in Paris. 24 | 25 | output: 26 | ```JSON 27 | {"score": 2, "reason": "all requested information is addressed"} 28 | ``` 29 | 30 | ## Example 2 31 | input: 32 | query: I'm planning to stay in Milan for a week starting on November 2nd, 2024. Help me find detailed information about the most trending attractions in the center of Milan and book a hotel nearby this attraction. 33 | response: The Milan Cathedral (Duomo di Milano), located in the central square of Milan, Italy, is the largest Gothic cathedral in the world. It is a Roman Catholic cathedral and ranks among the top five churches in the world. \nConstruction of the Milan Cathedral began in 1386, the vault was completed in 1500, and the highest Gothic spire, topped with the gilded statue of the Virgin Mary (La Madonnina), was built by Giuseppe Perego in 1774. It is a symbol of Milan city. The entire cathedral was completed in 1965, taking five centuries to finish. Napoleon held his coronation ceremony at the Milan Cathedral in 1805. 34 | 35 | output: 36 | ```JSON 37 | {"score": 1, "reason": "Only give the detailed information about the most trending attractions in the center of Milan. Miss the hotel booking part."} 38 | ``` 39 | 40 | ## Example 3 41 | input: 42 | query: My partner and I are gonna fly from Chicago to Tokyo on September 5th, 2024. Then we'll fly from Tokyo to Shanghai on September 9th, 2024. Finally, we'll go back to Chicago on September 19th, 2024. Can you give us the info about economy class? 43 | response: Sorry, I can't find any info about economy class. 44 | 45 | output: 46 | ```JSON 47 | {"score": 0, "reason": "The response did not give any information about economy class."} 48 | ``` 49 | """ 50 | 51 | # User prompt 52 | complete_user_prompt = SimpleTemplatePrompt( 53 | template=("""input: 54 | query: [args1] 55 | response: [args2] 56 | 57 | output:\n 58 | """ 59 | ), args_order=["query", "gen_response"]) 60 | 61 | 62 | 63 | """ 64 | Correctness Evaluation. 65 | """ 66 | # System prompt 67 | correct_system_prompt = """"You are a helpful response correctness detect assistant. Your task is to evaluate the response based on its accuracy in matching the details provided by API response. 68 | # Your task 69 | Give a dialogue history containing user query, function calls and api responses, you should determine the correctness of the corresponding respons using the following criteria: 70 | - If the response is consistent with the information provided in the API response, it should be considered entirely correct and receive a score of 2. 71 | - If the response partially matches the information provided in the API response (with some correct and some incorrect details), it should be considered partially correct and receive a score of 1. 72 | - If the response does not match any of the information provided in the API response, it should be considered incorrect and receive a score of 0. 73 | 74 | # Output Format 75 | You should output the score for each dialogue history and corresponding response in JSON format with following keys: 76 | - score: the completeness score for the response (0, 1, or 2) 77 | - reason: a string describing the reason for the score (e.g. "all mentioned information is correct.", "Some information is correct, and some are incorrect.", "All information in the response is incorrect.") 78 | 79 | ## example output 1 80 | output: 81 | ```JSON 82 | {"score": 2, "reason": "All mentioned information is correct."} 83 | ``` 84 | 85 | ## example output 2 86 | output: 87 | ```JSON 88 | {"score": 1, "reason": "xx is correct, and yy is incorrect."} 89 | ``` 90 | 91 | ## example output 3 92 | output: 93 | ```JSON 94 | {"score": 0, "reason": "All information in the response is incorrect."} 95 | ``` 96 | """ 97 | 98 | 99 | # User prompt 100 | correct_user_prompt = SimpleTemplatePrompt( 101 | template=("""dialogue history: [args1] 102 | response: [args2] 103 | output:\n 104 | """ 105 | ), args_order=["history", "gen_response"]) 106 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import random 4 | import argparse 5 | import os 6 | import logging 7 | import datetime 8 | from collections import defaultdict 9 | import multiprocessing 10 | from multiprocessing import Pool, Manager 11 | from functools import partial 12 | 13 | from utils.logger import Logger 14 | from utils.utils import * 15 | 16 | from runner.gpt_runner import GPTRunner 17 | from runner.glm_runner import GLMRunner, GLMAPIRunner 18 | from runner.claude_runner import ClaudeRunner 19 | from runner.qwen_runner import QwenRunner 20 | from runner.llama_runner import LlamaRunner 21 | from runner.mistral_runner import MistralRunner 22 | from runner.response_runner import RespEvalRunner 23 | 24 | MODEL_MAPPING = { 25 | "gpt-4o-2024-08-06": GPTRunner, 26 | "gpt-4-turbo-2024-04-09": GPTRunner, 27 | "claude-3-5-sonnet-20240620": ClaudeRunner, 28 | "claude-3-5-sonnet-20241022": ClaudeRunner, 29 | "claude-3-5-haiku-20241022": ClaudeRunner, 30 | "glm-4-9b-chat": GLMRunner, 31 | "glm-4-long": GLMAPIRunner, 32 | "Llama-3.1-70B": LlamaRunner, 33 | "Llama-3.1-8B": LlamaRunner, 34 | "Meta-Llama-3.1-405B-Instruct-FP8": LlamaRunner, 35 | "qwen2.5-7b-instruct": QwenRunner, 36 | "qwen2.5-72b-instruct": QwenRunner, 37 | "qwen2.5-7b-instruct": QwenRunner, 38 | "mistral-large-2407": MistralRunner, 39 | } 40 | 41 | 42 | def get_args(): 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument("--log_dir", type=str, default="logs/test.log") 45 | parser.add_argument("--input_file", type=str, default="data/ComplexFuncBench.jsonl") 46 | parser.add_argument("--model_name", type=str, required=True, choices=list(MODEL_MAPPING.keys()), help="The name of the model to be evaluated.") 47 | parser.add_argument('--exp_name', type=str, default='full-1000') 48 | parser.add_argument("--vllm_url", type=str) 49 | parser.add_argument("--proc_num", type=int, default=1) 50 | parser.add_argument("--debug", action="store_true") 51 | 52 | args = parser.parse_args() 53 | 54 | os.makedirs(f"logs/{datetime.date.today().strftime('%Y-%m-%d')}/{args.model_name}", exist_ok=True) 55 | os.makedirs(f"result/{args.model_name}/{args.exp_name}/logs", exist_ok=True) 56 | 57 | args.log_dir = f"logs/{datetime.date.today().strftime('%Y-%m-%d')}/{args.model_name}/{args.exp_name}.log" 58 | args.output_dir = f"result/{args.model_name}/{args.exp_name}.jsonl" 59 | args.log_dir = f"result/{args.model_name}/{args.exp_name}/logs" 60 | return args 61 | 62 | 63 | def process_example(data, args): 64 | log_dir = f"{args.log_dir}/{data['id']}.log" 65 | logger = Logger(f"evaluation_logger_{data['id']}", log_dir, logging.DEBUG) 66 | 67 | model = MODEL_MAPPING[args.model_name](args=args, logger=logger) 68 | resp_eval_model = RespEvalRunner(args=args, logger=logger) 69 | 70 | logger.info(f"Test Example {data['id']}") 71 | logger.info(f"Query: {data['conversations'][0]['content']}") 72 | 73 | turn_count, call_count = 0, 0 74 | for turn in data['conversations']: 75 | if turn['role'] == "assistant" and "function_call" in turn: 76 | turn_count += 1 77 | call_count += len(turn["function_call"]) 78 | 79 | convs, message, turn_id, correct_count = model.run(data) 80 | 81 | # API Error 82 | if isinstance(message, dict) and message["error_type"] == "unknown_error": 83 | return None 84 | 85 | real_turn_count = 0 86 | for turn in convs: 87 | if turn['role'] == "assistant" and "function_call" in turn: 88 | real_turn_count += 1 89 | 90 | if convs[-1]['role'] == "assistant" and "content" in convs[-1]: 91 | gen_response = convs[-1]['content'] 92 | resp_eval_result = resp_eval_model.run(data, gen_response) 93 | else: 94 | resp_eval_result = None 95 | 96 | logger.info(f"Message: {message}") 97 | logger.info(f"Success turn num = {turn_id}") 98 | logger.info("-" * 100) 99 | 100 | result = { 101 | "id": data['id'], 102 | "gen_convs": convs, 103 | "message": message, 104 | "count_dict": { 105 | "success_turn_num": turn_id, 106 | "total_turn_num": turn_count, 107 | "correct_call_num": correct_count, 108 | "total_call_num": call_count, 109 | "real_turn_num": real_turn_count 110 | }, 111 | "resp_eval": resp_eval_result 112 | } 113 | 114 | with open(args.output_dir, 'a+') as f: 115 | f.write(json.dumps(result, ensure_ascii=False) + "\n") 116 | f.flush() 117 | 118 | return result 119 | 120 | 121 | def main(): 122 | args = get_args() 123 | test_data = load_json(args.input_file) 124 | if args.debug: 125 | test_data = random.sample(test_data, 10) 126 | 127 | if os.path.exists(args.output_dir): 128 | finished_data = load_json(args.output_dir) 129 | finised_ids = [d["id"] for d in finished_data] 130 | else: 131 | finised_ids = [] 132 | test_data = [d for d in test_data if d['id'] not in finised_ids] 133 | 134 | with Manager() as manager: 135 | pool = Pool(processes=args.proc_num) 136 | process_example_partial = partial(process_example) 137 | results = pool.starmap(process_example_partial, [(data, args) for data in test_data]) 138 | 139 | pool.close() 140 | pool.join() 141 | 142 | 143 | if __name__ == '__main__': 144 | multiprocessing.set_start_method('spawn') 145 | main() 146 | -------------------------------------------------------------------------------- /runner/llama_runner.py: -------------------------------------------------------------------------------- 1 | import re 2 | import copy 3 | import json 4 | from models.llama import LlamaModel 5 | from runner.base_runner import ModelRunner 6 | 7 | 8 | class LlamaRunner(ModelRunner): 9 | def __init__(self, args, logger): 10 | super().__init__(args, logger) 11 | self.model_name = args.model_name 12 | self.model = LlamaModel(args.vllm_url, self.model_name) 13 | 14 | def get_standard_functions(self, functions): 15 | return [{"type": "function", "function": copy.deepcopy(func)} for func in functions] 16 | 17 | def get_standard_fc(self, call): 18 | try: 19 | return {"name": call['name'], "arguments": call['parameters']} 20 | except: 21 | return None 22 | 23 | def decode_response(self, result): 24 | result = result.replace("<|python_tag|>", "") 25 | try: 26 | if ";" in result: 27 | function_calls = result.split(";") 28 | function_calls = [json.loads(func_call) for func_call in function_calls] 29 | else: 30 | function_calls = eval(result) 31 | if type(function_calls) == dict: 32 | function_calls = [function_calls] 33 | 34 | decoded_output = {"return_type": "tool_calls", "tool_calls": function_calls} 35 | except: 36 | decoded_output = {"return_type": "response", "content": result} 37 | 38 | return decoded_output 39 | 40 | def run(self, data): 41 | convs, functions = data['conversations'], data['functions'] 42 | self.CompareClass.add_free_function(convs) 43 | standard_functions = self.get_standard_functions(functions) 44 | 45 | messages = [] 46 | query = convs[0]['content'] 47 | messages.append({"role": "user", "content": query}) 48 | 49 | self.init_golden(convs) 50 | 51 | while True: 52 | llm_response = self.model(messages, tools=standard_functions) 53 | if llm_response is None: 54 | return self.return_result(messages, {"error_type": "unknown_error", "content": "llm_response is None"}) 55 | decoded_response = self.decode_response(llm_response) 56 | 57 | if decoded_response['return_type'] == "tool_calls": 58 | if self.golden_fcs == []: 59 | self.logger.error(f"Output FC:\n{decoded_response['tool_calls']}") 60 | return self.return_result(messages, {"error_type": "func_hallucination", "content": "`self.golden_fcs == []`. Expected to stop. But Model continue to output function call."}) 61 | self.model.messages.append({"role": "assistant", "content": llm_response.replace("<|python_tag|>", "")}) 62 | tool_calls = decoded_response['tool_calls'] 63 | 64 | function_calls = [] 65 | for tool_call in tool_calls: 66 | function_call = self.get_standard_fc(tool_call) 67 | if function_call is None: 68 | return self.return_result(messages, {"error_type": "name_error", "content": f"{tool_call} is not Valid."}) 69 | function_calls.append(function_call) 70 | self.logger.info(f"Function Calls: \n{json.dumps(function_calls, ensure_ascii=False, indent=4)}\n") 71 | self.logger.info(f"Golden Function Call: \n{json.dumps(self.golden_fcs, ensure_ascii=False, indent=4)}\n") 72 | messages.append({"role": "assistant", "function_call": function_calls}) 73 | 74 | self.error_message, success_map, success_matched, format_error = self.CompareClass.compare_turn_prediction( 75 | functions, messages[:-1], 76 | copy.deepcopy(function_calls), self.golden_fcs, 77 | self.golden_obs 78 | ) 79 | if len(success_map) == 0 and format_error == {}: 80 | return self.return_result(messages, self.error_message) 81 | self.correct_count += len(success_map) 82 | 83 | real_time_obs = [] 84 | for t, function_call in enumerate(function_calls): 85 | if t in success_map: 86 | temp_obs = success_map[t] 87 | elif t in format_error: 88 | temp_obs = format_error[t] 89 | else: 90 | temp_obs = self.unexpect_call_resp 91 | 92 | real_time_obs.append(temp_obs) 93 | self.model.messages.append( 94 | { 95 | "role": "tool", 96 | "content": temp_obs 97 | } 98 | ) 99 | 100 | self.process_matches(success_matched) 101 | 102 | self.logger.info(f"Observations:\n{json.dumps(real_time_obs, ensure_ascii=False, indent=4)}\n") 103 | messages.append({"role": "observation", "content": real_time_obs}) 104 | 105 | elif decoded_response['return_type'] == "response": 106 | final_response = decoded_response['content'] 107 | self.logger.info(f"Final Response: {final_response}\n") 108 | messages.append({"role": "assistant", "content": final_response}) 109 | 110 | return self.return_result(messages) 111 | 112 | else: 113 | return self.return_result(messages, {"error_type": "unknown_error", "content": "llm_response is None"}) -------------------------------------------------------------------------------- /runner/mistral_runner.py: -------------------------------------------------------------------------------- 1 | import re 2 | import copy 3 | import json 4 | from models.mistral import MistralModel 5 | from runner.base_runner import ModelRunner 6 | 7 | 8 | class MistralRunner(ModelRunner): 9 | def __init__(self, args, logger): 10 | super().__init__(args, logger) 11 | self.model_name = args.model_name 12 | self.model = MistralModel(self.model_name) 13 | 14 | def replace_invalid_chars(self, s): 15 | valid_pattern = re.compile(r'[a-zA-Z0-9_-]') 16 | result = ''.join([char if valid_pattern.match(char) else '-' for char in s]) 17 | return result[:64] 18 | 19 | def get_standard_functions(self, functions): 20 | self.name_dict = {api['name']: self.replace_invalid_chars(api['name']) for api in functions} 21 | gpt_functions = [{"type": "function", "function": copy.deepcopy(func)} for func in functions] 22 | for func in gpt_functions: 23 | func['function']['name'] = self.name_dict[func['function']['name']] 24 | return gpt_functions 25 | 26 | def get_standard_fc(self, call): 27 | tool_call = copy.deepcopy(call) 28 | 29 | function_call = {} 30 | function_call['name'] = next((k for k, v in self.name_dict.items() if v == tool_call.function.name), None) 31 | if function_call['name'] is None: 32 | return None 33 | function_call['arguments'] = json.loads(tool_call.function.arguments) 34 | if function_call['arguments'] is None: 35 | function_call['arguments'] = {} 36 | 37 | return function_call 38 | 39 | def run(self, data): 40 | convs, functions = data['conversations'], data['functions'] 41 | self.CompareClass.add_free_function(convs) 42 | standard_functions = self.get_standard_functions(functions) 43 | 44 | messages = [] 45 | query = convs[0]['content'] 46 | messages.append({"role": "user", "content": query}) 47 | 48 | self.init_golden(convs) 49 | 50 | while True: 51 | llm_response = self.model(messages, tools=standard_functions) 52 | 53 | if llm_response is None: 54 | return self.return_result(messages, {"error_type": "unknown_error", "content": "llm_response is None"}) 55 | 56 | if llm_response.tool_calls: 57 | if self.golden_fcs == []: 58 | self.logger.error(f"Output FC:\n{llm_response.tool_calls}") 59 | return self.return_result(messages, {"error_type": "func_hallucination", "content": "`self.golden_fcs == []`. Expected to stop. But Model continue to output function call."}) 60 | self.model.messages.append(llm_response) 61 | tool_calls = llm_response.tool_calls 62 | 63 | function_calls = [] 64 | for tool_call in tool_calls: 65 | function_call = self.get_standard_fc(tool_call) 66 | if function_call is None: 67 | return self.return_result(messages, {"error_type": "name_error", "content": f"{tool_call.function} is not Valid."}) 68 | function_calls.append(function_call) 69 | self.logger.info(f"Function Calls: \n{json.dumps(function_calls, ensure_ascii=False, indent=4)}\n") 70 | self.logger.info(f"Golden Function Call: \n{json.dumps(self.golden_fcs, ensure_ascii=False, indent=4)}\n") 71 | messages.append({"role": "assistant", "function_call": function_calls}) 72 | 73 | self.error_message, success_map, success_matched, format_error = self.CompareClass.compare_turn_prediction( 74 | functions, messages[:-1], 75 | copy.deepcopy(function_calls), self.golden_fcs, 76 | self.golden_obs 77 | ) 78 | if len(success_map) == 0 and format_error == {}: 79 | return self.return_result(messages, self.error_message) 80 | self.correct_count += len(success_map) 81 | 82 | real_time_obs = [] 83 | for t, function_call in enumerate(function_calls): 84 | if t in success_map: 85 | temp_obs = success_map[t] 86 | elif t in format_error: 87 | temp_obs = format_error[t] 88 | else: 89 | temp_obs = self.unexpect_call_resp 90 | 91 | real_time_obs.append(temp_obs) 92 | self.model.messages.append( 93 | { 94 | "role": "tool", 95 | "name": self.name_dict[function_call['name']], 96 | "content": json.dumps(temp_obs, ensure_ascii=False), 97 | "tool_call_id": tool_calls[t].id, 98 | } 99 | ) 100 | 101 | self.process_matches(success_matched) 102 | 103 | self.logger.info(f"Observations:\n{json.dumps(real_time_obs, ensure_ascii=False, indent=4)}\n") 104 | messages.append({"role": "observation", "content": real_time_obs}) 105 | 106 | elif llm_response.content is not None: 107 | final_response = llm_response.content 108 | self.logger.info(f"Final Response: {final_response}\n") 109 | messages.append({"role": "assistant", "content": final_response}) 110 | 111 | return self.return_result(messages, self.error_message) 112 | 113 | else: 114 | return self.return_result(messages, {"error_type": "unknown_error", "content": "llm_response is None"}) -------------------------------------------------------------------------------- /runner/gpt_runner.py: -------------------------------------------------------------------------------- 1 | import re 2 | import copy 3 | import json 4 | from models.gpt import FunctionCallGPT 5 | from runner.base_runner import ModelRunner 6 | 7 | 8 | class GPTRunner(ModelRunner): 9 | def __init__(self, args, logger): 10 | super().__init__(args, logger) 11 | self.model_name = args.model_name 12 | self.model = FunctionCallGPT(self.model_name) 13 | 14 | def replace_invalid_chars(self, s): 15 | valid_pattern = re.compile(r'[a-zA-Z0-9_-]') 16 | result = ''.join([char if valid_pattern.match(char) else '-' for char in s]) 17 | return result[:64] 18 | 19 | def get_standard_functions(self, functions): 20 | self.name_dict = {api['name']: self.replace_invalid_chars(api['name']) for api in functions} 21 | gpt_functions = [{"type": "function", "function": copy.deepcopy(func)} for func in functions] 22 | for func in gpt_functions: 23 | func['function']['name'] = self.name_dict[func['function']['name']] 24 | return gpt_functions 25 | 26 | def get_standard_fc(self, gpt4_tool_call): 27 | tool_call = copy.deepcopy(gpt4_tool_call) 28 | 29 | function_call = {} 30 | function_call['name'] = next((k for k, v in self.name_dict.items() if v == tool_call.function.name), None) 31 | if function_call['name'] is None: 32 | return None 33 | try: 34 | function_call['arguments'] = json.loads(tool_call.function.arguments) 35 | except: 36 | return None 37 | if function_call['arguments'] is None: 38 | function_call['arguments'] = {} 39 | 40 | return function_call 41 | 42 | def run(self, data): 43 | convs, functions = data['conversations'], data['functions'] 44 | self.CompareClass.add_free_function(convs) 45 | gpt_functions = self.get_standard_functions(functions) 46 | 47 | messages = [] 48 | query = convs[0]['content'] 49 | messages.append({"role": "user", "content": query}) 50 | 51 | self.init_golden(convs) 52 | 53 | while True: 54 | llm_response = self.model(messages, tools=gpt_functions) 55 | if llm_response is None: 56 | return self.return_result(messages, {"error_type": "unknown_error", "content": "llm_response is None"}) 57 | 58 | if llm_response.tool_calls: 59 | if self.golden_fcs == []: 60 | self.logger.error(f"Output FC:\n{llm_response.tool_calls}") 61 | return self.return_result(messages, {"error_type": "func_hallucination", "content": "`self.golden_fcs == []`. Expected to stop. But Model continue to output function call."}) 62 | self.model.messages.append({"role": "assistant", "content": None, "tool_calls": llm_response.tool_calls}) 63 | tool_calls = llm_response.tool_calls 64 | 65 | function_calls = [] 66 | for tool_call in tool_calls: 67 | function_call = self.get_standard_fc(tool_call) 68 | if function_call is None: 69 | return self.return_result(messages, {"error_type": "decode_error", "content": f"{tool_call.function} is not Valid."}) 70 | function_calls.append(function_call) 71 | self.logger.info(f"Function Calls: \n{json.dumps(function_calls, ensure_ascii=False, indent=4)}\n") 72 | self.logger.info(f"Golden Function Call: \n{json.dumps(self.golden_fcs, ensure_ascii=False, indent=4)}\n") 73 | messages.append({"role": "assistant", "function_call": function_calls}) 74 | 75 | self.error_message, success_map, success_matched, format_error = self.CompareClass.compare_turn_prediction( 76 | functions, messages[:-1], 77 | copy.deepcopy(function_calls), self.golden_fcs, 78 | self.golden_obs 79 | ) 80 | if len(success_map) == 0 and format_error == {}: 81 | return self.return_result(messages, self.error_message) 82 | self.correct_count += len(success_map) 83 | 84 | real_time_obs = [] 85 | for t, function_call in enumerate(function_calls): 86 | if t in success_map: 87 | temp_obs = success_map[t] 88 | elif t in format_error: 89 | temp_obs = format_error[t] 90 | else: 91 | temp_obs = self.unexpect_call_resp 92 | 93 | real_time_obs.append(temp_obs) 94 | self.model.messages.append( 95 | { 96 | "tool_call_id": tool_calls[t].id, 97 | "role": "tool", 98 | "name": self.name_dict[function_call['name']], 99 | "content": json.dumps(temp_obs, ensure_ascii=False) 100 | } 101 | ) 102 | 103 | self.process_matches(success_matched) 104 | 105 | self.logger.info(f"Observations:\n{json.dumps(real_time_obs, ensure_ascii=False, indent=4)}\n") 106 | messages.append({"role": "observation", "content": real_time_obs}) 107 | 108 | elif llm_response.content is not None: 109 | final_response = llm_response.content 110 | self.logger.info(f"Final Response: {final_response}\n") 111 | messages.append({"role": "assistant", "content": final_response}) 112 | 113 | return self.return_result(messages, self.error_message) 114 | 115 | else: 116 | return self.return_result(messages, {"error_type": "unknown_error", "content": "llm_response is None"}) 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /utils/exact_match_values.json: -------------------------------------------------------------------------------- 1 | { 2 | "Search_Hotels_By_Coordinates": [ 3 | "arrival_date", 4 | "departure_date", 5 | "latitude", 6 | "longitude", 7 | "currency_code", 8 | "languagecode" 9 | ], 10 | "Get_Room_List_With_Availability": [ 11 | "hotel_id", 12 | "departure_date", 13 | "arrival_date", 14 | "currency_code", 15 | "languagecode" 16 | ], 17 | "Get_Room_Availability": [ 18 | "currency_code", 19 | "hotel_id" 20 | ], 21 | "Get_Question_And_Answer": [ 22 | "languagecode", 23 | "hotel_id" 24 | ], 25 | "Get_Nearby_Cities": [ 26 | "languagecode", 27 | "longitude", 28 | "latitude" 29 | ], 30 | "Get_Popular_Attraction_Near_By": [ 31 | "languagecode", 32 | "hotel_id" 33 | ], 34 | "Get_Hotel_Reviews(Tips)": [ 35 | "languagecode", 36 | "hotel_id", 37 | "sort_option_id" 38 | ], 39 | "Get_Hotel_Reviews(Tips)_Sort_Option": [ 40 | "languagecode", 41 | "hotel_id" 42 | ], 43 | "Get_Hotel_Review_Scores": [ 44 | "languagecode", 45 | "hotel_id" 46 | ], 47 | "Get_Hotel_Reviews_Filter_Metadata": [ 48 | "languagecode", 49 | "hotel_id" 50 | ], 51 | "Property_Children_Policies": [ 52 | "languagecode", 53 | "hotel_id" 54 | ], 55 | "Get_Hotel_Policies": [ 56 | "languagecode", 57 | "hotel_id" 58 | ], 59 | "Get_Room_List": [ 60 | "currency_code", 61 | "languagecode", 62 | "departure_date", 63 | "arrival_date", 64 | "hotel_id" 65 | ], 66 | "Get_Description_And_Info": [ 67 | "languagecode", 68 | "hotel_id" 69 | ], 70 | "Get_Hotel_Details": [ 71 | "currency_code", 72 | "languagecode", 73 | "departure_date", 74 | "hotel_id", 75 | "arrival_date" 76 | ], 77 | "Get_Filter": [ 78 | "search_type", 79 | "languagecode", 80 | "dest_id", 81 | "departure_date", 82 | "arrival_date", 83 | "categories_filter" 84 | ], 85 | "Search_Hotel_Destination": [ 86 | ], 87 | "Get_Sort_By": [ 88 | "search_type", 89 | "departure_date", 90 | "languagecode", 91 | "dest_id", 92 | "arrival_date" 93 | ], 94 | "Payment_features_of_the_Hotel": [ 95 | "languagecode", 96 | "hotel_id" 97 | ], 98 | "Search_Hotels": [ 99 | "currency_code", 100 | "languagecode", 101 | "search_type", 102 | "arrival_date", 103 | "dest_id", 104 | "departure_date", 105 | "categories_filter", 106 | "sort_by" 107 | ], 108 | "Search_Car_Rentals": [ 109 | "currency_code", 110 | "languagecode", 111 | "pick_up_longitude", 112 | "drop_off_time", 113 | "drop_off_latitude", 114 | "drop_off_date", 115 | "pick_up_time", 116 | "pick_up_date", 117 | "drop_off_longitude" 118 | ], 119 | "Search_Car_Location": [ 120 | "languagecode" 121 | ], 122 | "Booking_Summary": [ 123 | "currency_code", 124 | "languagecode", 125 | "search_key", 126 | "vehicle_id" 127 | ], 128 | "Vehicle_Supplier_Ratings": [ 129 | "currency_code", 130 | "languagecode", 131 | "vehicle_id", 132 | "search_key" 133 | ], 134 | "Vehicle_Details": [ 135 | "currency_code", 136 | "languagecode", 137 | "vehicle_id", 138 | "search_key" 139 | ], 140 | "Vehicle_Supplier_Details": [ 141 | "currency_code", 142 | "languagecode", 143 | "vehicle_id", 144 | "search_key" 145 | ], 146 | "Get_Packages": [ 147 | "currency_code", 148 | "languagecode", 149 | "vehicle_id", 150 | "search_key" 151 | ], 152 | "Vehicle_Supplier_Review": [ 153 | "languagecode", 154 | "vehicle_id", 155 | "search_key" 156 | ], 157 | "Location_to_Lat_Long": [ 158 | ], 159 | "Get_Exchange_Rates": [ 160 | ], 161 | "Get_Currency": [], 162 | "Get_Languages": [], 163 | "Get_Attraction_Details": [ 164 | "currency_code", 165 | "languagecode", 166 | "slug" 167 | ], 168 | "Search_Attractions": [ 169 | "currency_code", 170 | "languagecode", 171 | "endDate", 172 | "startDate", 173 | "id" 174 | ], 175 | "Get_Availability_Calendar": [ 176 | "languagecode", 177 | "id" 178 | ], 179 | "Get_Availability": [ 180 | "currency_code", 181 | "languagecode", 182 | "slug", 183 | "date" 184 | ], 185 | "Search_Attraction_Location": [ 186 | "languagecode" 187 | ], 188 | "Get_Min_Price_Multi_Stops": [ 189 | "currency_code" 190 | ], 191 | "Get_Flight_Details": [ 192 | "currency_code", 193 | "token" 194 | ], 195 | "Get_Min_Price": [ 196 | "currency_code", 197 | "fromId", 198 | "returnDate", 199 | "toId", 200 | "departDate" 201 | ], 202 | "Get_Seat_Map": [ 203 | "currency_code", 204 | "offerToken" 205 | ], 206 | "Search_Flights_Multi_Stops": [ 207 | "currency_code", 208 | "sort" 209 | ], 210 | "Search_Flights": [ 211 | "currency_code", 212 | "toId", 213 | "fromId", 214 | "returnDate", 215 | "departDate", 216 | "sort" 217 | ], 218 | "Search_Flight_Location": [ 219 | "languagecode" 220 | ], 221 | "Search_Taxi": [ 222 | "currency_code", 223 | "languagecode", 224 | "drop_off_place_id", 225 | "pick_up_place_id", 226 | "pick_up_time", 227 | "pick_up_date" 228 | ], 229 | "Taxi_Search_Location": [ 230 | "languagecode" 231 | ] 232 | } -------------------------------------------------------------------------------- /models/glm.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from typing import Any, Dict, List, Optional, Set, Tuple, Union 3 | import json 4 | import random 5 | import re 6 | import requests 7 | import ast 8 | import os 9 | import datetime 10 | import copy 11 | from zhipuai import ZhipuAI 12 | 13 | from utils.utils import * 14 | 15 | from openai import OpenAI 16 | Message = dict[str, str] # keys role, content 17 | MessageList = list[Message] 18 | 19 | 20 | class GLMAPIModel(): 21 | def __init__(self, model_name): 22 | super().__init__() 23 | self.model_name = model_name 24 | self.message = [] 25 | self.client = ZhipuAI(api_key=os.getenv("ZHIPU_API_KEY")) 26 | 27 | @retry(max_attempts=10) 28 | def __call__(self, messages, tools=None, **kwargs: Any): 29 | if "function_call" not in json.dumps(messages, ensure_ascii=False): 30 | self.messages = copy.deepcopy(messages) 31 | try: 32 | completion = self.client.chat.completions.create( 33 | model=self.model_name, 34 | messages=self.messages, 35 | temperature=0.0, 36 | stream=False, 37 | do_sample=False, 38 | tools=tools, 39 | tool_choice="auto", 40 | max_tokens=2048 41 | ) 42 | return completion.choices[0] 43 | 44 | except Exception as e: 45 | print(f"Exception: {e}") 46 | return None 47 | 48 | 49 | class GLMVllmModel(): 50 | def __init__(self, url, model_name): 51 | super().__init__() 52 | self.model_name = model_name 53 | self.message = [] 54 | self.url = url 55 | 56 | self.client = OpenAI( 57 | api_key="EMPTY", 58 | base_url=self.url 59 | ) 60 | 61 | def build_system_prompt(self, functions=None, current_time: Optional[float] = None): 62 | if functions is None: 63 | functions = [] 64 | value = "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱 AI 公司训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。" 65 | _date_prompt = "当前日期: %Y-%m-%d" 66 | if current_time is not None: 67 | value += "\n\n" + datetime.datetime.fromtimestamp(current_time).strftime(_date_prompt) 68 | if len(functions) > 0: 69 | value += "\n\n# 可用工具" 70 | contents = [] 71 | for function in functions: 72 | content = f"\n\n## {function['name']}\n\n{json.dumps(function, ensure_ascii=False, indent=4)}" 73 | content += "\n在调用上述函数时,请使用 Json 格式表示调用的参数。" 74 | contents.append(content) 75 | random.shuffle(contents) 76 | value += "".join(contents) 77 | return value 78 | 79 | def build_single_message(self, role, metadata, message, add_dummy_prefix=False): 80 | assert role in ["system", "user", "assistant", "observation"], role 81 | role_tokens = f"<|{role}|>" + f"{metadata if metadata is not None else ''}\n" 82 | message_tokens = message 83 | tokens = role_tokens + message_tokens 84 | return tokens 85 | 86 | def get_full_prompt(self, messages): 87 | prompt = "" 88 | for item in messages: 89 | content = item["content"] 90 | prompt += self.build_single_message(item["role"], item.get("metadata", ""), content, 91 | add_dummy_prefix=False) 92 | prompt += " <|assistant|>" 93 | return prompt 94 | 95 | def process_single_call(self, text): 96 | name, arguments = text.split("\n") 97 | arguments = json.loads(arguments) 98 | return {"name": name, "arguments": arguments} 99 | 100 | def get_standard_messages(self, messages, tools): 101 | system_prompt = self.build_system_prompt(functions=tools) 102 | messages.insert(0, {"role": "system", "content": system_prompt}) 103 | new_messages = [] 104 | for message in messages: 105 | if message['role'] == "assistant" and "function_call" in message: 106 | for call in message['function_call']: 107 | value = json.dumps(call["arguments"], ensure_ascii=False) 108 | new_messages.append({"role": "assistant", "metadata": call['name'], "content": value}) 109 | else: 110 | new_messages.append(message) 111 | 112 | return new_messages 113 | 114 | @retry(max_attempts=5) 115 | def __call__(self, messages, tools=None, **kwargs: Any): 116 | generated_result = [] 117 | function_calls = [] 118 | messages = self.get_standard_messages(messages, tools) 119 | 120 | while True: 121 | response = self.client.completions.create( 122 | model=self.model_name, 123 | prompt=self.get_full_prompt(messages), 124 | temperature=0.0, 125 | max_tokens=2048 126 | ) 127 | 128 | try: 129 | function_call = self.process_single_call(response.choices[0].text) 130 | function_calls.append(function_call) 131 | messages.append({ 132 | "role": "assistant", "metadata": function_call['name'], 133 | "content": json.dumps(function_call["arguments"], ensure_ascii=False)}) 134 | except: 135 | single_message = {"role": "assistant", "content": response.choices[0].text.strip()} 136 | generated_result.append(single_message) 137 | messages.append(single_message) 138 | 139 | try: 140 | if response.choices[0].stop_reason == 151338: 141 | generated_result.append({"role": "assistant", "function_call": function_calls}) 142 | return generated_result 143 | 144 | elif response.choices[0].stop_reason == 151336: 145 | return generated_result 146 | except Exception as e: 147 | print(e) 148 | return None 149 | 150 | 151 | if __name__ =="__main__": 152 | model = GLMAPIModel("glm-4-alltools") 153 | 154 | -------------------------------------------------------------------------------- /runner/claude_runner.py: -------------------------------------------------------------------------------- 1 | import re 2 | import copy 3 | import json 4 | from models.claude import FunctionCallClaude 5 | from anthropic.types import TextBlock, ToolUseBlock 6 | from runner.base_runner import ModelRunner 7 | 8 | 9 | class ClaudeRunner(ModelRunner): 10 | def __init__(self, args, logger): 11 | super().__init__(args, logger) 12 | self.model_name = args.model_name 13 | self.model = FunctionCallClaude(self.model_name) 14 | 15 | def replace_invalid_chars(self, s): 16 | valid_pattern = re.compile(r'[a-zA-Z0-9_-]') 17 | result = ''.join([char if valid_pattern.match(char) else '-' for char in s]) 18 | return result[:64] 19 | 20 | def get_standard_functions(self, functions): 21 | self.name_dict = {api['name']: self.replace_invalid_chars(api['name']) for api in functions} 22 | 23 | claude_functions = [] 24 | for func in functions: 25 | new_func = copy.deepcopy(func) 26 | new_func['input_schema'] = new_func.pop("parameters") 27 | claude_functions.append(new_func) 28 | 29 | for func in claude_functions: 30 | func['name'] = self.name_dict[func['name']] 31 | return claude_functions 32 | 33 | def get_standard_fc(self, claude_tool_call): 34 | tool_call = copy.deepcopy(claude_tool_call) 35 | 36 | function_call = {} 37 | function_call['name'] = next((k for k, v in self.name_dict.items() if v == tool_call.name), None) 38 | if function_call['name'] is None: 39 | return None 40 | function_call['arguments'] = tool_call.input 41 | if function_call['arguments'] is None: 42 | function_call['arguments'] = {} 43 | 44 | return function_call 45 | 46 | def run(self, data): 47 | convs, functions = data['conversations'], data['functions'] 48 | self.CompareClass.add_free_function(convs) 49 | gpt_functions = self.get_standard_functions(functions) 50 | 51 | messages = [] 52 | query = convs[0]['content'] 53 | messages.append({"role": "user", "content": query}) 54 | 55 | self.init_golden(convs) 56 | 57 | while True: 58 | llm_response = self.model(messages, tools=gpt_functions) 59 | 60 | if llm_response is None: 61 | return self.return_result(messages, {"error_type": "unknown_error", "content": "llm_response is None"}) 62 | 63 | if llm_response.stop_reason == "tool_use": 64 | if self.golden_fcs == []: 65 | self.logger.error(f"Output FC:\n{llm_response}") 66 | return self.return_result(messages, 67 | {"error_type": "func_hallucination", "content": "`self.golden_fcs == []`. Expected to stop. But Model continue to output function call."}) 68 | 69 | function_calls, call_ids = [], [] 70 | for content in llm_response.content: 71 | if isinstance(content, TextBlock): 72 | messages.append({"role": "assistant", "content": content.text}) 73 | elif isinstance(content, ToolUseBlock): 74 | call_ids.append(content.id) 75 | function_call = self.get_standard_fc(content) 76 | if function_call is None: 77 | return self.return_result(messages, {"error_type": "name_error", "content": f"{content} is not Valid."}) 78 | function_calls.append(function_call) 79 | 80 | self.model.messages.append({"role": "assistant", "content": llm_response.content}) 81 | self.logger.info(f"Function Calls: \n{json.dumps(function_calls, ensure_ascii=False, indent=4)}\n") 82 | self.logger.info(f"Golden Function Call: \n{json.dumps(self.golden_fcs, ensure_ascii=False, indent=4)}\n") 83 | messages.append({"role": "assistant", "function_call": function_calls}) 84 | 85 | self.error_message, success_map, success_matched, format_error = self.CompareClass.compare_turn_prediction( 86 | functions, messages[:-1], 87 | copy.deepcopy(function_calls), self.golden_fcs, 88 | self.golden_obs 89 | ) 90 | if len(success_map) == 0 and format_error == {}: 91 | return self.return_result(messages, self.error_message) 92 | self.correct_count += len(success_map) 93 | 94 | real_time_obs = [] 95 | self.model.messages.append({"role": "user", "content":[]}) 96 | for t, function_call in enumerate(function_calls): 97 | if t in success_map: 98 | temp_obs = success_map[t] 99 | else: 100 | temp_obs = self.unexpect_call_resp 101 | real_time_obs.append(temp_obs) 102 | self.model.messages[-1]['content'].append( 103 | { 104 | "type": "tool_result", 105 | "tool_use_id": call_ids[t], 106 | "content": json.dumps(temp_obs, ensure_ascii=False) 107 | } 108 | ) 109 | 110 | self.process_matches(success_matched) 111 | 112 | self.logger.info(f"Observations:\n{json.dumps(real_time_obs, ensure_ascii=False, indent=4)}\n") 113 | messages.append({"role": "observation", "content": real_time_obs}) 114 | 115 | elif llm_response.stop_reason in ["stop_sequence", "end_turn"]: 116 | final_response = llm_response.content[0].text 117 | self.logger.info(f"Final Response: {final_response}\n") 118 | messages.append({"role": "assistant", "content": final_response}) 119 | 120 | return self.return_result(messages) 121 | 122 | else: 123 | return self.return_result(messages, {"error_type": "unknown_error", "content": "llm_response is None"}) 124 | 125 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Complex Function Calling Benchmark (ComplexFuncBench) 2 | 3 |
4 | 📄 Arxiv Paper • 🤗 HF Paper • 📊 Dataset 5 |
6 | 7 | 8 | ## Table of Contents 9 | 10 | - [Introduction](#introduction) 11 | - [Leaderboard](#Leaderboard) 12 | - [Method](#Method) 13 | - [How to evaluate on ComplexFuncBench](#how-to-evaluate-on-complexfuncbench) 14 | - [Citation](#citation) 15 | 16 | ## Introduction 17 | 18 | Complex Function Calling Benchmark (`ComplexFuncBench`) is specillly designed for complex function calling evaluation. The ComplexFuncBench dataset encompass 1,000 complex function calling samples from five aspects: (1) Function calling with **multi-step** in single turn; (2) Function calling with user-provided **constraints**; (3) Function calling that requires **parameter value reasoning** from implicit information; (4) Function calling with **long parameter values** that exceed 500 tokens; and (5) Function calling with **128k long-context** length. 19 | 20 |  21 | 22 | The difference between `ComplexFuncBench` and other function calling benchmarks is shown in the following table. 23 | 24 | | | Real API Response | Multi-Step | Constraints | Parameter Value Reasoning | Long Parameter Reasoning | Long-Context | 25 | | :----------------: | :---------------: | :--------: | :---------: | :-----------------------: | :----------------------: | :----------: | 26 | | API-Bench | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | 27 | | ToolBench | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | 28 | | T-Eval | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | 29 | | BFCL | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | 30 | | Tool Sandbox | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | 31 | | `ComplexFuncBench` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 32 | 33 | ## Leaderboard 34 | 35 | | Model | Overall Success Rate | Overall Call Acc. | Completeness | Correctness | 36 | | :--------------------------- | :------------------: | :---------------: | :----------: | :---------: | 37 | | Claude-3.5-Sonnet (20241022) | **61.00** | 79.27 | **1.84** | **1.85** | 38 | | GPT-4o (2024-08-06) | 60.50 | **80.55** | 1.66 | 1.75 | 39 | | GLM-4-Long | 57.10 | 76.35 | 1.72 | 1.74 | 40 | | GPT-4-Turbo (2024-04-09) | 49.50 | 71.38 | 1.72 | 1.81 | 41 | | Claude-3.5-Haiku (20241022) | 45.80 | 69.50 | 1.79 | 1.71 | 42 | | Qwen2.5-72B | 40.10 | 58.32 | 1.80 | 1.75 | 43 | | Mistral Large 2 | 20.10 | 48.78 | 0.94 | 1.0 | 44 | | GLM-4-9B | 9.40 | 27.97 | 1.15 | 1.03 | 45 | | Qwen2.5-7B | 5.0 | 18.19 | 1.5 | 1.47 | 46 | | Llama-3.1-405B | 4.00 | 11.87 | 0.43 | 0.30 | 47 | | Llama-3.1-70B | 2.70 | 8.17 | 0.67 | 0.36 | 48 | | Llama-3.1-8B | 0.10 | 1.34 | 0.18 | 0.09 | 49 | 50 | ## Method 51 | 52 | ### Data Collection 53 | 54 | The collection of the ComplexFuncBench dataset consists of three stages: coarse generation, fine-grained annotation, and generalization. The dataset contains 1,000 complex function-calling samples, which comprise 600 single-domain samples and 400 cross-domain samples. 55 | 56 |  57 | 58 | ### Automated Evaluation 59 | 60 | The automated evaluation framework \texttt{ComplexEval} evaluates models' complex function calling ability and response generation ability simultaneously. 61 | 62 |  63 | 64 | ## How to evaluate on ComplexFuncBench 65 | 66 | ### Preparation 67 | First, download the repository and dataset. You can download the benchmarkd dataset through the [HuggingFace datasets](https://huggingface.co/datasets/THUDM/ComplexFuncBench). 68 | 69 | ```shell 70 | git clone https://github.com/THUDM/ComplexFuncBench.git 71 | cd ComplexFuncBench 72 | ``` 73 | 74 | Then, install the dependencies. 75 | 76 | ```shell 77 | pip install -r requirements.txt 78 | ``` 79 | 80 | 81 | ### Serve Model 82 | - For close source models, make sure the corresponding model API keys are included in your evironments `.env` . To enable response-based evaluation, you need to subscribe the [Booking API](https://rapidapi.com/DataCrawler/api/booking-com15) from RapidAPI. 83 | 84 | ```shell 85 | OPENAI_API_KEY=sk-XXXXXX 86 | 87 | RAPID_API_KEY= 88 | ``` 89 | 90 | - For open source models, you need to deploy your model via [vLLM](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html). Run the following command to serve the model. Take `THUDM/glm-4-9b-chat` for example: 91 | 92 | ```shell 93 | vllm serve THUDM/glm-4-9b-chat --api-key token-abc123 --tensor-parallel-size 4 --gpu-memory-utilization 0.95 --max_model_len 131072 --trust-remote-code 94 | ``` 95 | 96 | ### Run Model Inference 97 | 98 | ```shell 99 | python evaluation.py --model_name {model_name} --proc_num {proc_num} 100 | ``` 101 | 102 | Take `gpt-4o-2024-08-06` and `THUDM/glm-4-9b-chat` for example, 103 | 104 | ```shell 105 | python evaluation.py --model_name gpt-4o-2024-08-06 --proc_num 50 106 | ``` 107 | 108 | ```shell 109 | python evaluation.py --model_name THUDM/glm-4-9b-chat --proc_num 50 --vllm_url http://xx.xx.xx.xx:8000/v1 110 | ``` 111 | 112 | The evaluation results is saved in `result/{model_name}` 113 | 114 | ### Export Results 115 | 116 | ```shell 117 | python print_results.py --result_dir {result_dir} 118 | ``` 119 | 120 | 121 | 122 | ## Citation 123 | If you find our work helpful for your research, please consider citing our work. 124 | ``` 125 | @misc{zhong2025complexfuncbench, 126 | title={ComplexFuncBench: Exploring Multi-Step and Constrained Function Calling under Long-Context Scenario}, 127 | author={Lucen Zhong and Zhengxiao Du and Xiaohan Zhang and Haiyi Hu and Jie Tang}, 128 | year={2025}, 129 | eprint={2501.10132}, 130 | archivePrefix={arXiv}, 131 | primaryClass={cs.CL}, 132 | url={https://arxiv.org/abs/2501.10132}, 133 | } 134 | ``` -------------------------------------------------------------------------------- /runner/glm_runner.py: -------------------------------------------------------------------------------- 1 | import re 2 | import copy 3 | import json 4 | from models.glm import GLMAPIModel, GLMVllmModel 5 | from runner.base_runner import ModelRunner 6 | 7 | 8 | class GLMRunner(ModelRunner): 9 | def __init__(self, args, logger): 10 | super().__init__(args, logger) 11 | self.model_name = args.model_name 12 | self.model = GLMVllmModel(args.vllm_url, self.model_name) 13 | 14 | def run(self, data): 15 | convs, functions = data['conversations'], data['functions'] 16 | self.CompareClass.add_free_function(convs) 17 | 18 | messages = [] 19 | query = convs[0]['content'] 20 | messages.append({"role": "user", "content": query}) 21 | 22 | self.init_golden(convs) 23 | 24 | while True: 25 | llm_response = self.model(messages, tools=functions) 26 | 27 | if "function_call" in json.dumps(llm_response, ensure_ascii=False): 28 | if self.golden_fcs == []: 29 | self.logger.error(f"Output FC:\n{llm_response}") 30 | return self.return_result(messages, {"error_type": "func_hallucination", "content": "`self.golden_fcs == []`. Expected to stop. But Model continue to output function call."}) 31 | if len(llm_response) == 2: 32 | messages.append(llm_response[0]) 33 | self.logger.info(f"Thought: {llm_response[0]['content']}") 34 | 35 | function_calls = llm_response[-1]['function_call'] 36 | self.logger.info(f"Function Calls: \n{json.dumps(function_calls, ensure_ascii=False, indent=4)}\n") 37 | self.logger.info(f"Golden Function Call: \n{json.dumps(self.golden_fcs, ensure_ascii=False, indent=4)}\n") 38 | messages.append({"role": "assistant", "function_call": function_calls}) 39 | 40 | self.error_message, success_map, success_matched, format_error = self.CompareClass.compare_turn_prediction( 41 | functions, messages[:-1], 42 | copy.deepcopy(function_calls), self.golden_fcs, 43 | self.golden_obs 44 | ) 45 | if len(success_map) == 0 and format_error == {}: 46 | return self.return_result(messages, self.error_message) 47 | self.correct_count += len(success_map) 48 | 49 | real_time_obs = [] 50 | for t, function_call in enumerate(function_calls): 51 | if t in success_map: 52 | temp_obs = success_map[t] 53 | elif t in format_error: 54 | temp_obs = format_error[t] 55 | else: 56 | temp_obs = self.unexpect_call_resp 57 | real_time_obs.append(temp_obs) 58 | if not isinstance(temp_obs, str): 59 | temp_obs = json.dumps(temp_obs, ensure_ascii=False) 60 | messages.append({"role": "observation", "content": temp_obs}) 61 | 62 | self.process_matches(success_matched) 63 | 64 | self.logger.info(f"Observations:\n{json.dumps(real_time_obs, ensure_ascii=False, indent=4)}\n") 65 | # messages.append({"role": "observation", "content": real_time_obs}) 66 | 67 | elif llm_response is not None: 68 | final_response = llm_response[0]['content'] 69 | self.logger.info(f"Final Response: {final_response}\n") 70 | messages.append({"role": "assistant", "content": final_response}) 71 | 72 | return self.return_result(messages) 73 | 74 | elif llm_response.finish_reason == "length": 75 | self.logger.info(f"{llm_response}") 76 | return self.return_result(messages, {"error_type": "exceed_max_length", "content": "The response is too long."}) 77 | 78 | else: 79 | return self.return_result(messages, {"error_type": "unknown_error", "content": "llm_response is None"}) 80 | 81 | 82 | class GLMAPIRunner(GLMRunner): 83 | def __init__(self, args, logger): 84 | super().__init__(args, logger) 85 | self.model_name = args.model_name 86 | self.model = GLMAPIModel(self.model_name) 87 | 88 | def replace_invalid_chars(self, s): 89 | # 使用正则表达式匹配有效字符 90 | valid_pattern = re.compile(r'[a-zA-Z0-9_-]') 91 | 92 | # 使用list comprehension替换不符合要求的字符 93 | result = ''.join([char if valid_pattern.match(char) else '-' for char in s]) 94 | 95 | # 如果字符串长度超过64个字符,截断为前64个字符 96 | return result[:64] 97 | 98 | def get_standard_functions(self, functions): 99 | # self.name_dict = {api['name']: self.replace_invalid_chars(api['name']) for api in functions} 100 | gpt_functions = [{"type": "function", "function": copy.deepcopy(func)} for func in functions] 101 | # for func in gpt_functions: 102 | # func['function']['name'] = self.name_dict[func['function']['name']] 103 | return gpt_functions 104 | 105 | def get_standard_fc(self, tool_call): 106 | try: 107 | return {"name": tool_call['function']['name'], "arguments": json.loads(tool_call['function']['arguments'])} 108 | except: 109 | return None 110 | 111 | def run(self, data): 112 | convs, functions = data['conversations'], data['functions'] 113 | self.CompareClass.add_free_function(convs) 114 | gpt_functions = self.get_standard_functions(functions) 115 | 116 | messages = [] 117 | query = convs[0]['content'] 118 | messages.append({"role": "user", "content": query}) 119 | 120 | self.init_golden(convs) 121 | 122 | while True: 123 | llm_response = self.model(messages, tools=gpt_functions) 124 | if llm_response is None: 125 | return self.return_result(messages, {"error_type": "unknown_error", "content": "llm_response is None"}) 126 | 127 | if llm_response.finish_reason == "tool_calls": 128 | if self.golden_fcs == []: 129 | self.logger.error(f"Output FC:\n{llm_response.tool_calls}") 130 | return self.return_result(messages, {"error_type": "func_hallucination", "content": "`self.golden_fcs == []`. Expected to stop. But Model continue to output function call."}) 131 | if llm_response.message is not None: 132 | self.model.messages.append({"role": "assistant", "content": llm_response.message.content}) 133 | self.logger.info(f"Thought: {llm_response.message.content}") 134 | self.model.messages.append({"role": "assistant", "tool_calls": llm_response.tool_calls}) 135 | # self.model.message.append(llm_response.message.model_dump()) 136 | tool_calls = llm_response.tool_calls 137 | 138 | function_calls = [] 139 | for tool_call in tool_calls: 140 | function_call = self.get_standard_fc(tool_call) 141 | if function_call is None: 142 | return self.return_result(messages, {"error_type": "decode_error", "content": f"{tool_call} is not Valid."}) 143 | function_calls.append(function_call) 144 | self.logger.info(f"Function Calls: \n{json.dumps(function_calls, ensure_ascii=False, indent=4)}\n") 145 | self.logger.info(f"Golden Function Call: \n{json.dumps(self.golden_fcs, ensure_ascii=False, indent=4)}\n") 146 | messages.append({"role": "assistant", "function_call": function_calls}) 147 | 148 | self.error_message, success_map, success_matched, format_error = self.CompareClass.compare_turn_prediction( 149 | functions, messages[:-1], 150 | copy.deepcopy(function_calls), self.golden_fcs, 151 | self.golden_obs 152 | ) 153 | if len(success_map) == 0 and format_error == {}: 154 | return self.return_result(messages, self.error_message) 155 | self.correct_count += len(success_map) 156 | 157 | real_time_obs = [] 158 | for t, function_call in enumerate(function_calls): 159 | if t in success_map: 160 | temp_obs = success_map[t] 161 | elif t in format_error: 162 | temp_obs = format_error[t] 163 | else: 164 | temp_obs = self.unexpect_call_resp 165 | 166 | real_time_obs.append(temp_obs) 167 | 168 | self.model.messages.append( 169 | { 170 | "role": "tool", 171 | "content": json.dumps(temp_obs, ensure_ascii=False), 172 | "tool_call_id": tool_calls[t].id, 173 | } 174 | ) 175 | 176 | self.process_matches(success_matched) 177 | 178 | self.logger.info(f"Observations:\n{json.dumps(real_time_obs, ensure_ascii=False, indent=4)}\n") 179 | messages.append({"role": "observation", "content": real_time_obs}) 180 | 181 | elif llm_response.finish_reason == "stop": 182 | final_response = llm_response.message.content 183 | self.logger.info(f"Final Response: {final_response}\n") 184 | messages.append({"role": "assistant", "content": final_response}) 185 | 186 | return self.return_result(messages) 187 | 188 | else: 189 | self.logger.info(f"{llm_response}") 190 | return self.return_result(messages, {"error_type": "unknown_error", "content": "llm_response is None"}) -------------------------------------------------------------------------------- /utils/compare_method.py: -------------------------------------------------------------------------------- 1 | import json 2 | import copy 3 | import gc 4 | import random 5 | import torch 6 | import numpy as np 7 | import numpy as np 8 | from FlagEmbedding import FlagModel 9 | from scipy.optimize import linear_sum_assignment 10 | 11 | from utils.utils import * 12 | from utils.rapidapi import RapidAPICall 13 | from models.gpt import GPTModel 14 | from prompts.compare import system_prompt, user_prompt 15 | from utils.logger import Logger 16 | 17 | class CompareFCBase: 18 | def __init__(self, args, logger) -> None: 19 | self.embedding = FlagModel('BAAI/bge-large-en-v1.5', 20 | query_instruction_for_retrieval="Represent this sentence for searching relevant passages:", 21 | use_fp16=True) 22 | 23 | with open("utils/tool_info.json", 'r') as f: 24 | tool_info = json.load(f) 25 | tool_info = tool_info['booking-com15'] 26 | self.api_call = RapidAPICall(tool="booking-com15", tool_info=tool_info) 27 | self.model = GPTModel("gpt-4o-2024-05-13") 28 | self.logger = logger 29 | self.error_message = [] 30 | self.exact_match_dict = load_json("utils/exact_match_values.json") 31 | self.free_function_list = ["Location_to_Lat_Long", "Search_Hotel_Destination", "Search_Attraction_Location", 32 | "Search_Car_Location", "Search_Flight_Location", "Taxi_Search_Location"] 33 | 34 | def format_check(self, func_call, functions): 35 | name_to_func = {func['name']: func for func in functions} 36 | 37 | if func_call['name'] not in name_to_func: 38 | return {"error": f"Function {func_call['name']} is not defined in the function list."} 39 | 40 | used_func = name_to_func[func_call['name']] 41 | required_params = used_func['parameters']['required'] 42 | if not set(required_params).issubset(set(func_call['arguments'].keys())): 43 | # 找到在required_params中,但不在func_call['arguments'].keys()中的参数 44 | missing_param = set(required_params) - set(func_call['arguments'].keys()) 45 | return {"error": f"Function {used_func['name']} requires parameters {required_params}, but {list(func_call['arguments'].keys())} do not provide {missing_param}"} 46 | 47 | if not set(func_call['arguments'].keys()).issubset(set(used_func['parameters']['properties'].keys())): 48 | missing_param = set(func_call['arguments'].keys()) - set(used_func['parameters']['properties'].keys()) 49 | return {"error":f"Function {used_func['name']} does not have parameters {missing_param}" } 50 | 51 | # 参数类型验证 52 | for param_name, param_value in func_call['arguments'].items(): 53 | if used_func['parameters']['properties'][param_name]['type'] == "string": 54 | if not isinstance(param_value, str): 55 | return {"error": f"Parameter {param_name} of function {used_func['name']} should be a string, but {type(param_value)} is provided."} 56 | 57 | elif used_func['parameters']['properties'][param_name]['type'] == "number": 58 | if not isinstance(param_value, int) and not isinstance(param_value, float): 59 | return {"error": f"Parameter {param_name} of function {used_func['name']} should be a number, but {type(param_value)} is provided."} 60 | 61 | elif used_func['parameters']['properties'][param_name]['type'] == "boolean": 62 | if not isinstance(param_value, bool): 63 | return {"error": f"Parameter {param_name} of function {used_func['name']} should be a boolean, but {type(param_value)} is provided."} 64 | 65 | elif used_func['parameters']['properties'][param_name]['type'] == "array": 66 | if not isinstance(param_value, list): 67 | return {"error": f"Parameter {param_name} of function {used_func['name']} should be an array, but {type(param_value)} is provided."} 68 | 69 | return True 70 | 71 | def add_free_function(self, convs): 72 | # free function is optional 73 | self.free_functions = {} 74 | for i, turn in enumerate(convs): 75 | if "function_call" not in turn: 76 | continue 77 | for j, func_call in enumerate(turn['function_call']): 78 | if func_call['name'] in self.free_function_list: 79 | if json.dumps(func_call) not in self.free_functions: 80 | self.free_functions[json.dumps(func_call)] = { 81 | "called": False, 82 | "obs": convs[i+1]['content'][j] 83 | } 84 | 85 | def rule_based(self, predict, golden): 86 | """ 87 | Rule-based Match. 88 | """ 89 | if predict['name'] != golden['name']: 90 | return False 91 | if sorted(predict['arguments'].keys()) != sorted(golden['arguments'].keys()): 92 | return False 93 | for k, v in predict['arguments'].items(): 94 | if k == "categories_filter": 95 | pred_filter = [s.strip() for s in v.split(",")] 96 | golden_filter = [s.strip() for s in golden['arguments'][k].split(",")] 97 | if set(golden_filter) == set(pred_filter): 98 | continue 99 | if v != golden['arguments'][k]: 100 | return False 101 | 102 | return True 103 | 104 | def response_based(self, predict, golden): 105 | try: 106 | resp_1 = self.api_call._call(predict) 107 | if resp_1 == {}: 108 | self.error_message = f"API call failed for {predict}." 109 | return False 110 | if isinstance(resp_1, dict): 111 | if "status" in resp_1 and resp_1["status"] == False: 112 | self.error_response = resp_1 113 | else: 114 | self.error_response = resp_1 115 | resp_2 = self.api_call._call(golden) 116 | except: 117 | return False 118 | if resp_1 is None or resp_2 is None: 119 | return False 120 | 121 | return resp_1 == resp_2 122 | 123 | def similarity_based(self, predict, golden): 124 | embedding_1 = self.embedding.encode([json.dumps(predict, ensure_ascii=False)]) 125 | embedding_2 = self.embedding.encode([json.dumps(golden, ensure_ascii=False)]) 126 | similarity = embedding_1 @ embedding_2.T 127 | del embedding_1, embedding_2 128 | torch.cuda.empty_cache() 129 | gc.collect() 130 | self.logger.debug(f"Similarity-based comparison output: {similarity[0][0]}") 131 | return similarity[0][0] > 0.98 132 | 133 | def llm_based(self, functions, history, predict, golden): 134 | kwargs = { 135 | "functions": json.dumps(functions, ensure_ascii=False), 136 | "history": json.dumps(history, ensure_ascii=False), 137 | "function_call_1": json.dumps(predict, ensure_ascii=False), 138 | "function_call_2": json.dumps(golden, ensure_ascii=False), 139 | } 140 | 141 | output = self.model(system_prompt, user_prompt, **kwargs) 142 | 143 | decode_output = decode_json(output) 144 | 145 | if decode_output: 146 | self.logger.debug(f"LLM-based comparison output: {decode_output}") 147 | return decode_output['is_equal'] 148 | else: 149 | return None 150 | 151 | 152 | class CompareFC(CompareFCBase): 153 | def __init__(self, args, logger) -> None: 154 | super().__init__(args, logger) 155 | 156 | def value_checker(self, pred_call, golden_call): 157 | if pred_call['name'] != golden_call['name']: 158 | return False, {"error_type": "func_error", "content": "Do not call the correct function."} 159 | param_list = self.exact_match_dict[pred_call['name']] 160 | for k, v in golden_call['arguments'].items(): 161 | if k in param_list: 162 | if k not in pred_call['arguments']: 163 | return False, {"error_type": "param_missing", "content": f"Missing parameter {k} in prediction."} 164 | elif k != "categories_filter" and pred_call['arguments'][k] != v: 165 | return False, {"error_type": "value_error", "content": f"Parameter {k} value is not correct in prediction."} 166 | elif k == "categories_filter": 167 | golden_filter = [s.strip() for s in v.split(",")] 168 | pred_filter = [s.strip() for s in pred_call['arguments'][k].split(",")] 169 | if not set(golden_filter) == set(pred_filter): 170 | return False, {"error_type": "value_error", "content": f"Parameter {k} value is not correct in prediction."} 171 | return True, "" 172 | 173 | def remove_called_fc(self, golden, golden_obs): 174 | pop_index = [] 175 | for singel_golden in golden: 176 | if json.dumps(singel_golden) in self.free_functions and self.free_functions[json.dumps(singel_golden)]['called'] == True: 177 | pop_index.append(golden.index(singel_golden)) 178 | 179 | for index in sorted(pop_index, reverse=True): 180 | golden.pop(index) 181 | golden_obs.pop(index) 182 | return golden, golden_obs 183 | 184 | def get_error_message(self, pred_call, golden_call): 185 | # value error 186 | for k, v in golden_call['arguments'].items(): 187 | if k not in pred_call['arguments']: 188 | return {"error_type": "param_missing", "content": f"Missing parameter {k} in prediction."} 189 | if v != pred_call['arguments'][k]: 190 | return {"error_type": "value_error", "content": f"Parameter {k} value do not equal to golden."} 191 | 192 | # hallucination 193 | for k, v in pred_call['arguments'].items(): 194 | if k not in golden_call['arguments']: 195 | return {'error_type': "param_hallucination", "content": f"Parameter {k} is hallucinated."} 196 | 197 | def mapping_call(self, predict, golden, golden_obs): 198 | def sort_arguments(call_list): 199 | for value in call_list: 200 | sorted_arguments = {k: value['arguments'][k] for k in sorted(value['arguments'])} 201 | value['arguments'] = sorted_arguments 202 | sort_arguments(predict) 203 | sort_arguments(golden) 204 | 205 | # exact match 206 | exact_matches = [] 207 | remaining_predict = [] 208 | remaining_predict_index = {} 209 | remaining_golden = [] 210 | remaining_golden_index = {} 211 | matched_indices = set() 212 | 213 | for p_index, p_value in enumerate(predict): 214 | match_found = False 215 | for g_index, g_value in enumerate(golden): 216 | if g_index in matched_indices: 217 | continue 218 | if p_value == g_value: 219 | exact_matches.append({ 220 | "idx": p_index, 221 | "pred_call": p_value, 222 | "golden_call": g_value, 223 | "golden_obs": golden_obs[g_index] 224 | }) 225 | matched_indices.add(g_index) 226 | match_found = True 227 | break 228 | elif json.dumps(p_value) in self.free_functions: 229 | exact_matches.append({ 230 | "idx": p_index, 231 | "pred_call": p_value, 232 | "golden_call": p_value, 233 | "golden_obs": self.free_functions[json.dumps(p_value)]['obs'] 234 | }) 235 | match_found = True 236 | self.free_functions[json.dumps(p_value)]['called'] = True 237 | break 238 | 239 | if not match_found: 240 | remaining_predict.append(p_value) 241 | remaining_predict_index[len(remaining_predict) - 1] = p_index 242 | 243 | for g_index, g_value in enumerate(golden): 244 | if g_index not in matched_indices: 245 | remaining_golden.append(g_value) 246 | remaining_golden_index[len(remaining_golden) - 1] = g_index 247 | 248 | if remaining_predict == [] or remaining_golden == []: 249 | return exact_matches 250 | 251 | # embedding match 252 | pred_embed = self.embedding.encode([json.dumps(value, ensure_ascii=False) for value in remaining_predict]) 253 | gold_embed = self.embedding.encode([json.dumps(value, ensure_ascii=False) for value in remaining_golden]) 254 | matrix = pred_embed @ gold_embed.T 255 | 256 | del pred_embed, gold_embed 257 | torch.cuda.empty_cache() 258 | gc.collect() 259 | 260 | row_ind, col_ind = linear_sum_assignment(-matrix) 261 | 262 | embedding_matches = [] 263 | for i, j in zip(row_ind, col_ind): 264 | embedding_matches.append({ 265 | "idx": remaining_predict_index[i], 266 | "pred_call": remaining_predict[i], 267 | "golden_call": remaining_golden[j], 268 | "golden_obs": golden_obs[remaining_golden_index[j]] 269 | }) 270 | matching = exact_matches + embedding_matches 271 | 272 | return matching 273 | 274 | def compare_single_call(self, functions, history, pred_call, golden_call): 275 | self.logger.info(f"Start compare_single_call: \n{pred_call}\n{golden_call}") 276 | # rule-based 277 | if self.rule_based(pred_call, golden_call): 278 | self.logger.info(f"Rule-based compare success.") 279 | return True, None 280 | 281 | is_valid, error_message = self.value_checker(pred_call, golden_call) 282 | if not is_valid: 283 | self.logger.info(f"{error_message}") 284 | return False, error_message 285 | 286 | # Response-based 287 | if self.response_based(pred_call, golden_call): 288 | self.logger.info(f"Response-based compare success.") 289 | return True, None 290 | 291 | # LLM-based 292 | if self.llm_based(functions, history, pred_call, golden_call): 293 | self.logger.info(f"LLM-based compare success.") 294 | return True, None 295 | 296 | self.logger.info(f"All compare method failed.") 297 | return False, None 298 | 299 | def compare_turn_prediction(self, functions, history, predict, golden, golden_obs): 300 | self.error_message = [] 301 | golden, golden_obs = self.remove_called_fc(golden, golden_obs) 302 | 303 | if len(golden) == 0: 304 | raise NotImplementedError() 305 | 306 | match_list = self.mapping_call(predict, golden, golden_obs) 307 | 308 | format_error = {} 309 | success_map, success_matched = {}, [] 310 | for match_item in match_list: 311 | # format error check 312 | message = self.format_check(match_item['pred_call'], functions) 313 | if message == True: 314 | is_match, single_message = self.compare_single_call(functions, history, match_item['pred_call'], match_item['golden_call']) 315 | if is_match: 316 | success_map[match_item['idx']] = match_item['golden_obs'] 317 | success_matched.append(match_item['golden_call']) 318 | else: 319 | if single_message: 320 | self.error_message.append(single_message) 321 | else: 322 | self.error_message.append(self.get_error_message(match_item['pred_call'], match_item['golden_call'])) 323 | 324 | else: 325 | format_error[match_item['idx']] = message 326 | 327 | self.logger.info(f"Success matched: {success_matched}") 328 | 329 | return self.error_message, success_map, success_matched, format_error --------------------------------------------------------------------------------