├── tool_eval ├── __init__.py ├── chat_templates │ ├── chatml.j2 │ ├── zephyr.j2 │ └── vicuna.j2 ├── schema.py ├── prompt_assets │ ├── sys_prompt.yml │ ├── sys_prompt_scratchpad.yml │ └── few_shot.json ├── prompter.py ├── evaluator_json_mode.py ├── utils.py ├── validator.py └── evaluator.py ├── requirements.txt └── README.md /tool_eval/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.1.1 2 | transformers @ git+https://github.com/huggingface/transformers.git@ae49b218c3d718df90d8e4a109016450fb8f0632 3 | peft @ git+https://github.com/huggingface/peft.git 4 | bitsandbytes>=0.41.1 5 | accelerate==0.27.2 6 | datasets==2.17.1 7 | ninja==1.11.1.1 8 | pydantic==2.7.4 9 | jsonschema==4.22.0 -------------------------------------------------------------------------------- /tool_eval/chat_templates/chatml.j2: -------------------------------------------------------------------------------- 1 | {% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %} -------------------------------------------------------------------------------- /tool_eval/chat_templates/zephyr.j2: -------------------------------------------------------------------------------- 1 | {% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %} -------------------------------------------------------------------------------- /tool_eval/chat_templates/vicuna.j2: -------------------------------------------------------------------------------- 1 | {% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'].strip() + '\n\n' %}{% else %}{% set loop_messages = messages %}{% set system_message = '' %}{% endif %}{{ bos_token + system_message }}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ 'USER: ' + message['content'].strip() + '\n' }}{% elif message['role'] == 'assistant' %}{{ 'ASSISTANT: ' + message['content'].strip() + eos_token + '\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %} -------------------------------------------------------------------------------- /tool_eval/schema.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from typing import List, Dict, Literal, Optional 3 | 4 | class FunctionCall(BaseModel): 5 | arguments: dict 6 | """ 7 | The arguments to call the function with, as generated by the model in JSON 8 | format. Note that the model does not always generate valid JSON, and may 9 | hallucinate parameters not defined by your function schema. 
Validate the 10 | arguments in your code before calling your function. 11 | """ 12 | 13 | name: str 14 | """The name of the function to call.""" 15 | 16 | class FunctionDefinition(BaseModel): 17 | name: str 18 | description: Optional[str] = None 19 | parameters: Optional[Dict[str, object]] = None 20 | 21 | class FunctionSignature(BaseModel): 22 | function: FunctionDefinition 23 | type: Literal["function"] 24 | -------------------------------------------------------------------------------- /tool_eval/prompt_assets/sys_prompt.yml: -------------------------------------------------------------------------------- 1 | Role: > 2 | You are a function calling AI model. 3 | You are provided with function signatures within XML tags. 4 | Objective: > 5 | You may call one or more functions to assist with the user query. 6 | Don't make assumptions about what values to plug into functions. 7 | Tools: | 8 | Here are the available tools: 9 | 10 | {tools} 11 | 12 | Examples: | 13 | Here are some example usage of functions: 14 | {examples} 15 | Schema: | 16 | Use the following pydantic model json schema for each tool call you will make: 17 | 18 | {schema} 19 | 20 | Instructions: | 21 | For each function call return a json object with function name and arguments within XML tags as follows: 22 | 23 | {{"name": , "arguments": }} 24 | -------------------------------------------------------------------------------- /tool_eval/prompt_assets/sys_prompt_scratchpad.yml: -------------------------------------------------------------------------------- 1 | Role: > 2 | You are a function calling AI model. 3 | You are provided with function signatures within XML tags. 4 | Objective: > 5 | You may call one or more functions to assist with the user query. 6 | If available tools are not relevant in assisting with user query, just respond in natural conversational language. 7 | Don't make assumptions about what values to plug into functions. 8 | After calling & executing the functions, you will be provided with function results within XML tags. 9 | Tools: | 10 | Here are the available tools: 11 | 12 | {tools} 13 | 14 | Examples: | 15 | Here are some example usage of functions: 16 | {examples} 17 | Schema: | 18 | For each function call return a JSON object, with the following pydantic model json schema for each: 19 | 20 | {schema} 21 | 22 | Instructions: | 23 | Each function call should be enclosed within XML tags. Please use XML tags to record your reasoning and planning before you call the functions. 24 | Example: 25 | 26 | Goal: 27 | Actions: 28 | 29 | - {{result_var_name1}} = functions.{{function_name1}}({{param1}}={{value1}},...) 30 | - {{result_var_name2, result_var_name3}} = ... 31 | None 32 | Observation: 33 | Reflection: 34 | 35 | 36 | {{"name": , "arguments": }} 37 | -------------------------------------------------------------------------------- /tool_eval/prompt_assets/few_shot.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "example": "```\nSYSTEM: You are a helpful assistant who has access to functions. 
Use them if required\n[\n {\n \"name\": \"calculate_distance\",\n \"description\": \"Calculate the distance between two locations\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"origin\": {\n \"type\": \"string\",\n \"description\": \"The starting location\"\n },\n \"destination\": {\n \"type\": \"string\",\n \"description\": \"The destination location\"\n },\n \"mode\": {\n \"type\": \"string\",\n \"description\": \"The mode of transportation\"\n }\n },\n \"required\": [\n \"origin\",\n \"destination\",\n \"mode\"\n ]\n }\n },\n {\n \"name\": \"generate_password\",\n \"description\": \"Generate a random password\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"length\": {\n \"type\": \"integer\",\n \"description\": \"The length of the password\"\n }\n },\n \"required\": [\n \"length\"\n ]\n }\n }\n]\n\n\nUSER: Hi, I need to know the distance from New York to Los Angeles by car.\nASSISTANT:\n\n{\"arguments\": {\"origin\": \"New York\",\n \"destination\": \"Los Angeles\", \"mode\": \"car\"}, \"name\": \"calculate_distance\"}\n\n```\n" 4 | }, 5 | { 6 | "example": "```\nSYSTEM: You are a helpful assistant with access to functions. Use them if required\n[\n {\n \"name\": \"calculate_distance\",\n \"description\": \"Calculate the distance between two locations\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"origin\": {\n \"type\": \"string\",\n \"description\": \"The starting location\"\n },\n \"destination\": {\n \"type\": \"string\",\n \"description\": \"The destination location\"\n },\n \"mode\": {\n \"type\": \"string\",\n \"description\": \"The mode of transportation\"\n }\n },\n \"required\": [\n \"origin\",\n \"destination\",\n \"mode\"\n ]\n }\n },\n {\n \"name\": \"generate_password\",\n \"description\": \"Generate a random password\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"length\": {\n \"type\": \"integer\",\n \"description\": \"The length of the password\"\n }\n },\n \"required\": [\n \"length\"\n ]\n }\n }\n]\n\n\nUSER: Can you help me generate a random password with a length of 8 characters?\nASSISTANT:\n\n{\"arguments\": {\"length\": 8}, \"name\": \"generate_password\"}\n\n```" 7 | } 8 | ] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Function-calling & JSON-mode Evaluation 2 | A framework for evaluating function calls and json output by LLMs using Hermes Tool Call and JSON-mode format. 3 | 4 | This script evaluates the performance of a language model on a function calling and JSON output tasks. It preprocesses prompts, runs model completions, parses the function calls/json objects in the completions, validates the function calls/json objects, and calculates the pass rate. 5 | 6 | ## Usage 7 | 8 | 1. Clone the repository or copy the script to your local machine. 9 | ```bash 10 | git clone https://github.com/your-repo/function-calling-eval.git 11 | cd function-calling-eval/tool_eval 12 | ``` 13 | 14 | 2. Install the required dependencies: 15 | ```bash 16 | pip -r requirements.txt 17 | MAX_JOBS=4 pip install flash-attn --no-build-isolation 18 | ``` 19 | 20 | ### Arguments 21 | 22 | - `--model_path`: Path to the model folder (required). 23 | - `--chat_template`: Chat template for prompt formatting (default: `"chatml"`). 24 | - `--num_fewshot`: Option to subset the evaluation dataset (default: `None`). 
25 | - `--dataset_path`: Path to the Hugging Face dataset (default: function-calling: `"NousResearch/func-calling-eval"` & json-mode: `"NousResearch/json-mode-eval"`). 26 | - `--load_in_4bit`: Option to load the model in 4-bit mode with `bitsandbytes` (default: `"False"`). 27 | - `--dpo`: Option to save the dataset for DPO (default: `"False"`). 28 | 29 | ## Example 30 | 31 | ### Function-calling 32 | ```bash 33 | python evaluator.py --model_path /path/to/model --chat_template chatml --dataset_path dataset/path --load_in_4bit True --dpo False 34 | ``` 35 | #### Output 36 | 37 | The script generates the following outputs: 38 | - `function_calling_eval_results.json`: A JSON file containing the function-calling evaluation results, including prompts, completions, model outputs, and pass/fail status. 39 | - `function_calling_dpo_pairs.json` (if `--dpo` is set to `"True"`): A JSON file containing the DPO dataset for function-calling consisting of system messages, questions, chosen completions, and rejected completions. 40 | 41 | ### JSON-mode 42 | ```bash 43 | python evaluator_json_mode.py --model_path /path/to/model --load_in_4bit True --dpo False 44 | ``` 45 | #### Output 46 | 47 | The script generates the following outputs: 48 | - `json_mode_eval_results.json`: A JSON file containing the json-mode evaluation results, including prompts, completions, model outputs, and pass/fail status. 49 | - `json_mode_dpo_pairs.json` (if `--dpo` is set to `"True"`): A JSON file containing the DPO dataset for json-mode consisting of system messages, questions, chosen completions, and rejected completions. 50 | -------------------------------------------------------------------------------- /tool_eval/prompter.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from typing import Dict 3 | from schema import FunctionCall 4 | from utils import ( 5 | get_fewshot_examples 6 | ) 7 | import yaml 8 | import json 9 | import os 10 | 11 | class PromptSchema(BaseModel): 12 | Role: str 13 | Objective: str 14 | Tools: str 15 | Examples: str 16 | Schema: str 17 | Instructions: str 18 | 19 | class PromptManager: 20 | def __init__(self): 21 | self.script_dir = os.path.dirname(os.path.abspath(__file__)) 22 | 23 | def format_yaml_prompt(self, prompt_schema: PromptSchema, variables: Dict) -> str: 24 | formatted_prompt = "" 25 | for field, value in prompt_schema.dict().items(): 26 | if field == "Examples" and variables.get("examples") is None: 27 | continue 28 | 29 | formatted_value = value.format(**variables) 30 | 31 | # Add Markdown header 32 | #formatted_prompt += f"# {field}\n" 33 | formatted_prompt += formatted_value 34 | #formatted_prompt += "\n" 35 | 36 | return formatted_prompt.strip() 37 | 38 | def read_yaml_file(self, file_path: str) -> PromptSchema: 39 | with open(file_path, 'r') as file: 40 | yaml_content = yaml.safe_load(file) 41 | 42 | prompt_schema = PromptSchema( 43 | Role=yaml_content.get('Role', ''), 44 | Objective=yaml_content.get('Objective', ''), 45 | Tools=yaml_content.get('Tools', ''), 46 | Examples=yaml_content.get('Examples', ''), 47 | Schema=yaml_content.get('Schema', ''), 48 | Instructions=yaml_content.get('Instructions', ''), 49 | ) 50 | return prompt_schema 51 | 52 | def generate_prompt(self, sample, scratch_pad=False, num_fewshot=None): 53 | if scratch_pad: 54 | prompt_path = os.path.join(self.script_dir, 'prompt_assets', 'sys_prompt_scratchpad.yml') 55 | else: 56 | prompt_path = os.path.join(self.script_dir, 'prompt_assets', 
'sys_prompt.yml') 57 | prompt_schema = self.read_yaml_file(prompt_path) 58 | 59 | if num_fewshot is not None: 60 | examples = get_fewshot_examples(num_fewshot) 61 | else: 62 | examples = None 63 | 64 | schema_json = json.loads(FunctionCall.schema_json()) 65 | #schema = schema_json.get("properties", {}) 66 | 67 | variables = { 68 | "tools": sample['tools'], 69 | "examples": examples, 70 | "schema": schema_json 71 | } 72 | sys_prompt = self.format_yaml_prompt(prompt_schema, variables) 73 | sample['system'] = sys_prompt 74 | 75 | prompt = [ 76 | {'content': sys_prompt, 'role': 'system'} 77 | ] 78 | prompt.extend(sample['prompt']) 79 | return prompt, sys_prompt 80 | 81 | 82 | -------------------------------------------------------------------------------- /tool_eval/evaluator_json_mode.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import json 4 | from tqdm import tqdm 5 | from datasets import load_dataset 6 | 7 | from transformers import ( 8 | AutoModelForCausalLM, 9 | AutoTokenizer, 10 | BitsAndBytesConfig 11 | ) 12 | 13 | from validator import validate_json_completion, validate_json_data 14 | 15 | from utils import ( 16 | eval_logger, 17 | calculate_pass_rate, 18 | get_assistant_message, 19 | get_chat_template, 20 | ) 21 | 22 | class ModelEvaluator: 23 | def __init__(self, model_path, chat_template, load_in_4bit, flash_attn, dpo): 24 | self.bnb_config = None 25 | 26 | if load_in_4bit: 27 | self.bnb_config = BitsAndBytesConfig( 28 | load_in_4bit=True, 29 | bnb_4bit_quant_type="nf4", 30 | bnb_4bit_use_double_quant=True, 31 | ) 32 | 33 | # Prepare model loading arguments 34 | model_args = { 35 | "trust_remote_code": True, 36 | "return_dict": True, 37 | "quantization_config": self.bnb_config, 38 | "torch_dtype": torch.bfloat16, 39 | "device_map": "auto", 40 | } 41 | 42 | # Conditionally add attn_implementation if flash_attn is True 43 | if flash_attn: 44 | model_args["attn_implementation"] = "flash_attention_2" 45 | 46 | self.model = AutoModelForCausalLM.from_pretrained( 47 | model_path, 48 | **model_args 49 | ) 50 | 51 | self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 52 | if self.tokenizer.pad_token is None: 53 | self.tokenizer.pad_token = self.tokenizer.eos_token 54 | self.tokenizer.padding_side = "left" 55 | 56 | if self.tokenizer.chat_template is None: 57 | self.tokenizer.chat_template = get_chat_template(chat_template) 58 | 59 | self.eval_results = [] 60 | if dpo: 61 | self.dpo_results = [] 62 | 63 | eval_logger.info(self.model.config) 64 | eval_logger.info(self.model.generation_config) 65 | eval_logger.info(self.model.parameters) 66 | eval_logger.info(self.tokenizer.chat_template) 67 | eval_logger.info(self.tokenizer.special_tokens_map) 68 | 69 | def evaluate_model(self, eval_dataset, chat_template, num_fewshot): 70 | 71 | for sample in tqdm(eval_dataset, desc="processing samples", unit="sample"): 72 | prompt = sample['prompt'] 73 | 74 | inputs = self.tokenizer.apply_chat_template( 75 | prompt, 76 | add_generation_prompt=True, 77 | return_tensors='pt' 78 | ) 79 | 80 | tokens = self.model.generate( 81 | inputs.to(self.model.device), 82 | max_new_tokens=512, 83 | temperature=0.1, 84 | do_sample=True 85 | ) 86 | 87 | completion = self.tokenizer.decode(tokens[0], skip_special_tokens=False) 88 | eval_logger.info(f"model completion with eval prompt:\n{completion}") 89 | 90 | assistant_message = get_assistant_message(completion, chat_template, self.tokenizer.eos_token) 91 | 92 | 
sample['model_completion'] = "" 93 | sample['result'] = "failed" 94 | 95 | if assistant_message is not None: 96 | sample['model_completion'] = assistant_message 97 | validation, json_object = validate_json_data(assistant_message, json.loads(sample['schema'])) 98 | if validation: 99 | result = validate_json_completion(json_object, json.loads(sample['completion'])) 100 | if result == "failed": 101 | eval_logger.info("Json completion validation failed") 102 | else: 103 | sample['result'] = "passed" 104 | sample['model_completion'] = json_object 105 | eval_logger.info(f"all validations passed") 106 | eval_logger.info(f"parsed json object:\n{json.dumps(json_object, indent=2)}") 107 | 108 | if sample['result'] == "failed": 109 | eval_logger.info("all validations failed") 110 | eval_logger.info(f"failed assistant message:\n{sample['model_completion']}") 111 | 112 | if hasattr(self, 'dpo_results'): 113 | self.dpo_results.append({ 114 | "system": prompt[0]["content"], 115 | "question": prompt[1]['content'], 116 | "chosen": sample['completion'], 117 | "rejected": assistant_message 118 | }) 119 | 120 | self.eval_results.append(sample) 121 | 122 | if __name__ == "__main__": 123 | parser = argparse.ArgumentParser(description="Evaluate model performance on fireworks-ai dataset") 124 | parser.add_argument("--model_path", type=str, help="Path to the model folder") 125 | parser.add_argument("--chat_template", type=str, default="chatml", help="Chat template for prompt formatting") 126 | parser.add_argument("--num_fewshot", type=int, default=None, help="Option to subset eval dataset") 127 | parser.add_argument("--dataset_path", type=str, default=None, help="Huggingface dataset path") 128 | parser.add_argument("--flash_attn", type=bool, default=False, help="weather to use flash attention; requires installing flash-attn") 129 | parser.add_argument("--load_in_4bit", type=str, default=False, help="Option to load in 4bit with bitsandbytes") 130 | parser.add_argument("--dpo", type=str, default=False, help="Option to save dpo dataset") 131 | args = parser.parse_args() 132 | 133 | # load eval dataset 134 | if args.dataset_path: 135 | eval_dataset = load_dataset(args.dataset_path)["train"] 136 | else: 137 | eval_dataset = load_dataset("NousResearch/json-mode-eval")['train'] 138 | 139 | # Create model evaluator instance 140 | model_evaluator = ModelEvaluator(args.model_path, args.chat_template, args.load_in_4bit, args.flash_attn, args.dpo) 141 | 142 | # Run the model evaluator 143 | model_evaluator.evaluate_model(eval_dataset, args.chat_template, args.num_fewshot) 144 | 145 | # Calculate and print pass rate 146 | pass_rate = calculate_pass_rate(model_evaluator.eval_results) 147 | eval_logger.info(f"json-mode eval (pass@1): {pass_rate}") 148 | 149 | results_path = './json_mode_eval_results.json' 150 | with open(results_path, 'w') as file: 151 | json.dump(model_evaluator.eval_results, file) 152 | 153 | if args.dpo: 154 | dpo_path = './json_mode_dpo_pairs.json' 155 | with open(dpo_path, 'w') as file: 156 | json.dump(model_evaluator.dpo_results, file) 157 | -------------------------------------------------------------------------------- /tool_eval/utils.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import os 3 | import re 4 | import json 5 | import logging 6 | import datetime 7 | import xml.etree.ElementTree as ET 8 | from logging.handlers import RotatingFileHandler 9 | 10 | logging.basicConfig( 11 | format="%(asctime)s,%(msecs)03d %(levelname)-8s 
[%(filename)s:%(lineno)d] %(message)s",
12 |     datefmt="%Y-%m-%d:%H:%M:%S",
13 |     level=logging.INFO,
14 | )
15 | script_dir = os.path.dirname(os.path.abspath(__file__))
16 | now = datetime.datetime.now()
17 | log_folder = os.path.join(script_dir, "eval_logs")
18 | os.makedirs(log_folder, exist_ok=True)
19 | log_file_path = os.path.join(
20 |     log_folder, f"function-calling-eval_{now.strftime('%Y-%m-%d_%H-%M-%S')}.log"
21 | )
22 | # Use RotatingFileHandler from the logging.handlers module
23 | file_handler = RotatingFileHandler(log_file_path, maxBytes=0, backupCount=0)
24 | file_handler.setLevel(logging.INFO)
25 | 
26 | formatter = logging.Formatter("%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s", datefmt="%Y-%m-%d:%H:%M:%S")
27 | file_handler.setFormatter(formatter)
28 | 
29 | eval_logger = logging.getLogger("function-calling-eval")
30 | eval_logger.addHandler(file_handler)
31 | 
32 | def get_fewshot_examples(num_fewshot):
33 |     """return a list of few shot examples"""
34 |     example_path = os.path.join(script_dir, 'prompt_assets', 'few_shot.json')
35 |     with open(example_path, 'r') as file:
36 |         examples = json.load(file) # Use json.load with the file object, not the file path
37 |     if num_fewshot > len(examples):
38 |         raise ValueError(f"Not enough examples (got {num_fewshot}, but there are only {len(examples)} examples).")
39 |     return examples[:num_fewshot]
40 | 
41 | def get_chat_template(chat_template):
42 |     """read chat template from jinja file"""
43 |     template_path = os.path.join(script_dir, 'chat_templates', f"{chat_template}.j2")
44 | 
45 |     if not os.path.exists(template_path):
46 |         eval_logger.error(f"Template file not found: {chat_template}")
47 |         return None
48 |     try:
49 |         with open(template_path, 'r') as file:
50 |             template = file.read()
51 |         return template
52 |     except Exception as e:
53 |         print(f"Error loading template: {e}")
54 |         return None
55 | 
56 | def get_assistant_message(completion, chat_template, eos_token):
57 |     """define and match pattern to find the assistant message"""
58 |     completion = completion.strip()
59 | 
60 |     if chat_template == "zephyr":
61 |         assistant_pattern = re.compile(r'<\|assistant\|>((?:(?!<\|assistant\|>).)*)$', re.DOTALL)
62 |     elif chat_template == "chatml":
63 |         assistant_pattern = re.compile(r'<\|im_start\|>\s*assistant((?:(?!<\|im_start\|>\s*assistant).)*)$', re.DOTALL)
64 |     elif chat_template == "vicuna":
65 |         assistant_pattern = re.compile(r'ASSISTANT:\s*((?:(?!ASSISTANT:).)*)$', re.DOTALL)
66 |     else:
67 |         raise NotImplementedError(f"Handling for chat_template '{chat_template}' is not implemented.")
68 | 
69 |     assistant_match = assistant_pattern.search(completion)
70 |     if assistant_match:
71 |         assistant_content = assistant_match.group(1).strip()
72 |         #if chat_template == "vicuna":
73 |         # eos_token = f"{eos_token}"
74 |         return assistant_content.replace(eos_token, "")
75 |     else:
76 |         assistant_content = None
77 |         eval_logger.info("No match found for the assistant pattern")
78 |         return assistant_content
79 | 
80 | def validate_and_extract_tool_calls_regex(assistant_content):
81 |     validation_result = False
82 |     tool_calls = []
83 | 
84 |     # Regular expression to find content within <tool_call> tags
85 |     tool_call_pattern = re.compile(r'<tool_call>(.*?)</tool_call>', re.DOTALL)
86 | 
87 |     # Find all matches
88 |     matches = tool_call_pattern.findall(assistant_content)
89 | 
90 |     for match in matches:
91 |         try:
92 |             # Try to parse the content as JSON
93 |             json_data = json.loads(match.strip())
94 |             tool_calls.append(json_data)
95 |             validation_result = True
96 |         except
json.JSONDecodeError as json_err: 97 | eval_logger.error("JSON parsing failed:") 98 | eval_logger.error("- JSON Decode Error: %s", json_err) 99 | eval_logger.error("- Problematic JSON text: %s", match.strip()) 100 | 101 | return validation_result, tool_calls 102 | 103 | def validate_and_extract_tool_calls(assistant_content): 104 | validation_result = False 105 | tool_calls = [] 106 | try: 107 | # wrap content in root element 108 | xml_root_element = f"{assistant_content}" 109 | root = ET.fromstring(xml_root_element) 110 | 111 | # extract JSON data 112 | for element in root.findall(".//tool_call"): 113 | if element.text is not None: 114 | json_text = element.text.strip() 115 | 116 | try: 117 | # Prioritize json.loads for better error handling 118 | json_data = json.loads(json_text) 119 | except json.JSONDecodeError as json_err: 120 | try: 121 | # Fallback to ast.literal_eval if json.loads fails 122 | json_data = ast.literal_eval(json_text) 123 | except (SyntaxError, ValueError) as eval_err: 124 | eval_logger.error("JSON parsing failed with both json.loads and ast.literal_eval:") 125 | eval_logger.error("- JSON Decode Error: %s", json_err) 126 | eval_logger.error("- Fallback Syntax/Value Error: %s", eval_err) 127 | eval_logger.error("- Problematic JSON text: %s", json_text) 128 | continue 129 | 130 | tool_calls.append(json_data) 131 | validation_result = True 132 | 133 | except ET.ParseError as err: 134 | eval_logger.error("XML Parse Error: %s", err) 135 | 136 | 137 | # Return default values if no valid data is extracted 138 | return validation_result, tool_calls 139 | 140 | def validate_tool_calls(generated_arguments, expected_arguments): 141 | for key, expected_value in expected_arguments.items(): 142 | if generated_arguments.get(key) != expected_value: 143 | eval_logger.info("Expected: %s", expected_value) 144 | eval_logger.info("Got: %s", generated_arguments.get(key)) 145 | return "failed" 146 | return "passed" 147 | 148 | def calculate_pass_rate(eval_results): 149 | passed_count =sum(1 for sample in eval_results if sample["result"] == "passed") 150 | eval_logger.info("Number of eval tests passed: %s", passed_count) 151 | eval_logger.info("Number of eval tests failed: %s", len(eval_results) - passed_count) 152 | 153 | pass_rate = passed_count / len(eval_results) 154 | return pass_rate -------------------------------------------------------------------------------- /tool_eval/validator.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import json 3 | from jsonschema import validate, ValidationError 4 | from pydantic import ValidationError 5 | from utils import eval_logger 6 | from schema import FunctionCall, FunctionSignature 7 | 8 | def validate_function_call_schema(call, signatures): 9 | try: 10 | call_data = FunctionCall(**call) 11 | except ValidationError as e: 12 | eval_logger.info(f"Invalid function call: {e}") 13 | return False 14 | 15 | for signature in signatures: 16 | # Inside the main validation function 17 | try: 18 | signature_data = FunctionSignature(**signature) 19 | 20 | if signature_data.function.name == call_data.name: 21 | 22 | # Validate types in function arguments 23 | for arg_name, arg_schema in signature_data.function.parameters.get('properties', {}).items(): 24 | if arg_name in call_data.arguments: 25 | call_arg_value = call_data.arguments[arg_name] 26 | if call_arg_value: 27 | try: 28 | validate_argument_type(arg_name, call_arg_value, arg_schema) 29 | except Exception as arg_validation_error: 30 | 
eval_logger.info(f"Invalid argument '{arg_name}': {arg_validation_error}") 31 | return False 32 | 33 | # Check if all required arguments are present 34 | required_arguments = signature_data.function.parameters.get('required', []) 35 | result, missing_arguments = check_required_arguments(call_data.arguments, required_arguments) 36 | 37 | if not result: 38 | eval_logger.info(f"Missing required arguments: {missing_arguments}") 39 | return False 40 | 41 | return True 42 | except Exception as e: 43 | # Handle validation errors for the function signature 44 | eval_logger.info(f"Error validating function call: {e}") 45 | return False 46 | 47 | # Moved the "No matching function signature found" message here 48 | eval_logger.info(f"No matching function signature found for function: {call_data.name}") 49 | return False 50 | 51 | def check_required_arguments(call_arguments, required_arguments): 52 | missing_arguments = [arg for arg in required_arguments if arg not in call_arguments] 53 | return not bool(missing_arguments), missing_arguments 54 | 55 | def validate_enum_value(arg_name, arg_value, enum_values): 56 | if arg_value not in enum_values: 57 | raise Exception( 58 | f"Invalid value '{arg_value}' for parameter {arg_name}. Expected one of {', '.join(map(str, enum_values))}" 59 | ) 60 | 61 | def validate_argument_type(arg_name, arg_value, arg_schema): 62 | arg_type = arg_schema.get('type', None) 63 | if arg_type: 64 | if arg_type == 'string' and 'enum' in arg_schema: 65 | enum_values = arg_schema['enum'] 66 | if None not in enum_values and enum_values != []: 67 | try: 68 | validate_enum_value(arg_name, arg_value, enum_values) 69 | except Exception as e: 70 | # Propagate the validation error message 71 | raise Exception(f"Error validating function call: {e}") 72 | 73 | python_type = get_python_type(arg_type) 74 | if not isinstance(arg_value, python_type): 75 | raise Exception(f"Type mismatch for parameter {arg_name}. 
Expected: {arg_type}, Got: {type(arg_value)}") 76 | 77 | def get_python_type(json_type): 78 | type_mapping = { 79 | 'string': str, 80 | 'number': (int, float), 81 | 'integer': int, 82 | 'boolean': bool, 83 | 'array': list, 84 | 'object': dict, 85 | 'null': type(None), 86 | } 87 | return type_mapping[json_type] 88 | 89 | 90 | def validate_json_data(json_object, json_schema): 91 | valid = False 92 | error_message = None 93 | result_json = None 94 | 95 | try: 96 | # Attempt to load JSON using json.loads 97 | try: 98 | result_json = json.loads(json_object) 99 | except json.decoder.JSONDecodeError: 100 | # If json.loads fails, try ast.literal_eval 101 | try: 102 | result_json = ast.literal_eval(json_object) 103 | except (SyntaxError, ValueError) as e: 104 | error_message = f"JSON decoding error: {e}" 105 | # Return early if both json.loads and ast.literal_eval fail 106 | eval_logger.info(f"Validation failed for JSON data: {error_message}") 107 | return valid, result_json 108 | 109 | # Validate each item in the list against schema if it's a list 110 | if isinstance(result_json, list): 111 | for index, item in enumerate(result_json): 112 | try: 113 | validate(instance=item, schema=json_schema) 114 | eval_logger.info(f"Item {index+1} is valid against the schema.") 115 | except ValidationError as e: 116 | error_message = f"Validation failed for item {index+1}: {e}" 117 | break 118 | else: # Default to validation without list 119 | try: 120 | validate(instance=result_json, schema=json_schema) 121 | #eval_logger.info("JSON object is valid against the schema.") 122 | except ValidationError as e: 123 | error_message = f"Validation failed: {e}" 124 | except Exception as e: 125 | error_message = f"Error occurred: {e}" 126 | 127 | if error_message is None: 128 | valid = True 129 | eval_logger.info("JSON data is valid against the schema.") 130 | else: 131 | eval_logger.info(f"Validation failed for JSON data: {error_message}") 132 | 133 | return valid, result_json 134 | 135 | 136 | def validate_json_completion(json_obj1, json_obj2): 137 | # Check if keys match 138 | try: 139 | if set(json_obj1.keys()) != set(json_obj2.keys()): 140 | eval_logger.info("Keys don't match:") 141 | eval_logger.info(f"Expected: {set(json_obj1.keys())}") 142 | eval_logger.info(f"Got: {set(json_obj2.keys())}") 143 | return "failed" 144 | 145 | # Check if values match 146 | for key in json_obj1.keys(): 147 | if json_obj1[key] != json_obj2[key]: 148 | eval_logger.info(f"Values don't match for key '{key}'") 149 | eval_logger.info(f"Expected: {json_obj1[key]}") 150 | eval_logger.info(f"Got: {json_obj2[key]}") 151 | return "failed" 152 | except Exception as e: 153 | eval_logger.info(f"Exception occured: {e}") 154 | return "failed" 155 | 156 | # If keys and values match, result remains "passed" 157 | return "passed" 158 | 159 | -------------------------------------------------------------------------------- /tool_eval/evaluator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import json 4 | from tqdm import tqdm 5 | from datasets import load_dataset 6 | 7 | from transformers import ( 8 | AutoModelForCausalLM, 9 | AutoTokenizer, 10 | BitsAndBytesConfig 11 | ) 12 | 13 | from prompter import PromptManager 14 | from validator import validate_function_call_schema 15 | 16 | from utils import ( 17 | eval_logger, 18 | calculate_pass_rate, 19 | get_assistant_message, 20 | get_chat_template, 21 | validate_tool_calls, 22 | validate_and_extract_tool_calls_regex 23 | ) 24 | 
25 | class ModelEvaluator: 26 | def __init__(self, model_path, chat_template, load_in_4bit, flash_attn, dpo): 27 | self.prompter = PromptManager() 28 | self.bnb_config = None 29 | 30 | if load_in_4bit: 31 | self.bnb_config = BitsAndBytesConfig( 32 | load_in_4bit=True, 33 | bnb_4bit_quant_type="nf4", 34 | bnb_4bit_use_double_quant=True, 35 | ) 36 | 37 | # Prepare model loading arguments 38 | model_args = { 39 | "trust_remote_code": True, 40 | "return_dict": True, 41 | "quantization_config": self.bnb_config, 42 | "torch_dtype": torch.bfloat16, 43 | "device_map": "auto", 44 | } 45 | 46 | # Conditionally add attn_implementation if flash_attn is True 47 | if flash_attn: 48 | model_args["attn_implementation"] = "flash_attention_2" 49 | 50 | self.model = AutoModelForCausalLM.from_pretrained( 51 | model_path, 52 | **model_args 53 | ) 54 | 55 | self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 56 | 57 | if self.tokenizer.pad_token is None: 58 | self.tokenizer.pad_token = self.tokenizer.eos_token 59 | self.tokenizer.padding_side = "left" 60 | 61 | if self.tokenizer.chat_template is None: 62 | self.tokenizer.chat_template = get_chat_template(chat_template) 63 | 64 | self.eval_results = [] 65 | if dpo: 66 | self.dpo_results = [] 67 | 68 | eval_logger.info(self.model.config) 69 | eval_logger.info(self.model.generation_config) 70 | eval_logger.info(self.model.parameters) 71 | eval_logger.info(self.tokenizer.chat_template) 72 | eval_logger.info(self.tokenizer.special_tokens_map) 73 | 74 | def evaluate_model(self, eval_dataset, chat_template, scratch_pad, num_fewshot): 75 | 76 | for sample in tqdm(eval_dataset, desc="processing samples", unit="sample"): 77 | prompt, sys_prompt = self.prompter.generate_prompt(sample, scratch_pad, num_fewshot) 78 | 79 | inputs = self.tokenizer.apply_chat_template( 80 | prompt, 81 | add_generation_prompt=True, 82 | return_tensors='pt' 83 | ) 84 | 85 | tokens = self.model.generate( 86 | inputs.to(self.model.device), 87 | max_new_tokens=2048, 88 | temperature=0.1, 89 | do_sample=True 90 | ) 91 | 92 | completion = self.tokenizer.decode(tokens[0], skip_special_tokens=False) 93 | eval_logger.info(f"model completion with eval prompt:\n{completion}") 94 | 95 | assistant_message = get_assistant_message(completion, chat_template, self.tokenizer.eos_token) 96 | validation, tool_calls = validate_and_extract_tool_calls_regex(assistant_message) 97 | 98 | sample['model_completion'] = [] 99 | sample['result'] = "failed" 100 | 101 | eval_completion = json.loads(sample['completion']) 102 | if validation: 103 | if isinstance(eval_completion, list): 104 | eval_tool_calls = eval_completion 105 | else: 106 | eval_tool_calls = [eval_completion] 107 | 108 | all_valid = True 109 | if len(tool_calls) != len(eval_tool_calls): 110 | all_valid = False 111 | eval_logger.info("Number of tool calls doesn't match") 112 | eval_logger.info(f"Expected: {len(eval_tool_calls)} tool calls; Got: {len(tool_calls)}") 113 | 114 | for eval_tool_call in eval_tool_calls: 115 | function_found = False 116 | 117 | for tool_call in tool_calls: 118 | schema_validation = validate_function_call_schema(tool_call, json.loads(sample['tools'])) 119 | if not schema_validation: 120 | all_valid = False 121 | break 122 | 123 | if tool_call['name'] == eval_tool_call['name']: 124 | function_found = True 125 | result = validate_tool_calls(tool_call['arguments'], eval_tool_call['arguments']) 126 | sample['model_completion'].append(tool_call) 127 | eval_logger.info(f"{tool_call['name']} validation: 
{result}")
128 |                             if result == "failed":
129 |                                 all_valid = False
130 |                                 break
131 |                     if not function_found:
132 |                         eval_logger.info(f"Function '{eval_tool_call['name']}' not found")
133 |                         all_valid = False
134 |             else:
135 |                 eval_logger.info("Function call validation failed")
136 |                 all_valid = False
137 | 
138 |             if all_valid:
139 |                 sample['result'] = "passed"
140 |                 eval_logger.info(f"all validations: {sample['result']}")
141 |                 eval_logger.info(f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}")
142 |             else:
143 |                 sample['model_completion'] = assistant_message
144 |                 eval_logger.info(f"all validations: {sample['result']}")
145 |                 eval_logger.info(f"failed tool calls:\n{assistant_message}")
146 | 
147 |                 if hasattr(self, 'dpo_results'):
148 |                     chosen_completion = ""
149 |                     for tool_call in eval_completion:
150 |                         chosen_completion += f"<tool_call>\n{tool_call}\n</tool_call>\n"
151 |                     self.dpo_results.append({
152 |                         "system": sys_prompt,
153 |                         "question": sample['prompt'][0]['content'],
154 |                         "chosen": chosen_completion,
155 |                         "rejected": assistant_message
156 |                     })
157 | 
158 |             self.eval_results.append(sample)
159 | 
160 | if __name__ == "__main__":
161 |     parser = argparse.ArgumentParser(description="Evaluate model performance on fireworks-ai dataset")
162 |     parser.add_argument("--model_path", type=str, help="Path to the model folder")
163 |     parser.add_argument("--chat_template", type=str, default="chatml", help="Chat template for prompt formatting")
164 |     parser.add_argument("--num_fewshot", type=int, default=None, help="Option to subset eval dataset")
165 |     parser.add_argument("--dataset_path", type=str, default=None, help="Huggingface dataset path")
166 |     parser.add_argument("--load_in_4bit", type=bool, default=False, help="Option to load in 4bit with bitsandbytes")
167 |     parser.add_argument("--flash_attn", type=bool, default=False, help="whether to use flash attention; requires installing flash-attn")
168 |     parser.add_argument("--scratch_pad", type=bool, default=False, help="whether to use scratchpad reasoning")
169 |     parser.add_argument("--dpo", type=bool, default=False, help="Option to save dpo dataset")
170 |     args = parser.parse_args()
171 | 
172 |     # load eval dataset
173 |     if args.dataset_path:
174 |         eval_dataset = load_dataset(args.dataset_path)["train"]
175 |     else:
176 |         eval_dataset = load_dataset("NousResearch/func-calling-eval-glaive")['train']
177 | 
178 |     # Create model evaluator instance
179 |     model_evaluator = ModelEvaluator(args.model_path, args.chat_template, args.load_in_4bit, args.flash_attn, args.dpo)
180 | 
181 |     # Run the model evaluator
182 |     model_evaluator.evaluate_model(eval_dataset, args.chat_template, args.scratch_pad, args.num_fewshot)
183 | 
184 |     # Calculate and print pass rate
185 |     pass_rate = calculate_pass_rate(model_evaluator.eval_results)
186 |     eval_logger.info(f"function-calling eval (pass@1): {pass_rate}")
187 | 
188 |     results_path = './function_calling_eval_results.json'
189 |     with open(results_path, 'w') as file:
190 |         json.dump(model_evaluator.eval_results, file)
191 | 
192 |     if args.dpo:
193 |         dpo_path = './function_calling_dpo_pairs.json'
194 |         with open(dpo_path, 'w') as file:
195 |             json.dump(model_evaluator.dpo_results, file)
196 | 
197 | 
198 | 
--------------------------------------------------------------------------------
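
The sketches below are not part of the repository; they are minimal, self-contained illustrations of the formats the two evaluators rely on. First, tool-call extraction and argument checking: `evaluator.py` pulls `<tool_call>` blocks out of the assistant message and compares the parsed arguments with the reference call, following `validate_and_extract_tool_calls_regex` and `validate_tool_calls` in `utils.py`. The completion string and expected call here are invented for illustration.

```python
import json
import re

# Illustrative assistant completion in the Hermes tool-call format:
# a JSON object wrapped in <tool_call> ... </tool_call> tags.
completion = """<tool_call>
{"arguments": {"origin": "New York", "destination": "Los Angeles", "mode": "car"}, "name": "calculate_distance"}
</tool_call>"""

# Reference call, as it would appear in the dataset's `completion` column.
expected = {
    "name": "calculate_distance",
    "arguments": {"origin": "New York", "destination": "Los Angeles", "mode": "car"},
}

# Extract and parse every <tool_call> block (cf. validate_and_extract_tool_calls_regex).
pattern = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
tool_calls = []
for match in pattern.findall(completion):
    try:
        tool_calls.append(json.loads(match.strip()))
    except json.JSONDecodeError as err:
        print(f"could not parse tool call: {err}")

# Compare generated arguments against the expected ones (cf. validate_tool_calls).
def arguments_match(generated: dict, expected_args: dict) -> bool:
    return all(generated.get(key) == value for key, value in expected_args.items())

result = "passed" if any(
    call.get("name") == expected["name"]
    and arguments_match(call.get("arguments", {}), expected["arguments"])
    for call in tool_calls
) else "failed"
print(result)  # -> passed
```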
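
JSON-mode works the same way but validates a bare JSON object against a per-sample schema with `jsonschema` before comparing it to the reference completion (see `validate_json_data` and `validate_json_completion` in `validator.py`). A minimal sketch of that two-step check, with a made-up schema and model output; real samples carry their own `schema` and `completion` fields.

```python
import json

from jsonschema import ValidationError, validate

# Made-up schema and model output for illustration only.
schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}
model_output = '{"name": "Ada", "age": 36}'

try:
    parsed = json.loads(model_output)         # step 1: output must be parseable JSON
    validate(instance=parsed, schema=schema)  # step 2: output must satisfy the sample schema
    print("valid:", parsed)
except (json.JSONDecodeError, ValidationError) as err:
    print("failed:", err)
```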
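
The Jinja files under `chat_templates/` are only used when a tokenizer ships without its own `chat_template`. Rendering `chatml.j2` directly with `jinja2` (assumed to be installed; `transformers` pulls it in) shows the prompt layout; the messages below are placeholders, whereas in the evaluators the system turn is built by `PromptManager`.

```python
from jinja2 import Template

# chatml.j2 inlined as a raw string so the example is self-contained.
chatml = r"""{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"""

# Placeholder conversation.
messages = [
    {"role": "system", "content": "You are a function calling AI model."},
    {"role": "user", "content": "How far is New York from Los Angeles by car?"},
]

# Prints the ChatML-formatted prompt ending in an open assistant turn.
print(Template(chatml).render(messages=messages, add_generation_prompt=True))
```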
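
Finally, when `--dpo` is enabled, each failed sample is stored as a preference pair of system prompt, question, reference ("chosen") completion, and the model's failed ("rejected") completion, and the list is written to `function_calling_dpo_pairs.json` or `json_mode_dpo_pairs.json`. An illustrative record, with invented field values:

```python
import json

# Invented example of a single DPO record; field names match what the evaluators emit.
pair = {
    "system": "You are a function calling AI model. ...",
    "question": "Hi, I need to know the distance from New York to Los Angeles by car.",
    "chosen": '<tool_call>\n{"arguments": {"origin": "New York", "destination": "Los Angeles", "mode": "car"}, "name": "calculate_distance"}\n</tool_call>\n',
    "rejected": "I'm sorry, I cannot calculate distances.",
}
print(json.dumps([pair], indent=2))
```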