├── tool_eval
│   ├── __init__.py
│   ├── chat_templates
│   │   ├── chatml.j2
│   │   ├── zephyr.j2
│   │   └── vicuna.j2
│   ├── schema.py
│   ├── prompt_assets
│   │   ├── sys_prompt.yml
│   │   ├── sys_prompt_scratchpad.yml
│   │   └── few_shot.json
│   ├── prompter.py
│   ├── evaluator_json_mode.py
│   ├── utils.py
│   ├── validator.py
│   └── evaluator.py
├── requirements.txt
└── README.md
/tool_eval/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.1.1
2 | transformers @ git+https://github.com/huggingface/transformers.git@ae49b218c3d718df90d8e4a109016450fb8f0632
3 | peft @ git+https://github.com/huggingface/peft.git
4 | bitsandbytes>=0.41.1
5 | accelerate==0.27.2
6 | datasets==2.17.1
7 | ninja==1.11.1.1
8 | pydantic==2.7.4
9 | jsonschema==4.22.0
--------------------------------------------------------------------------------
/tool_eval/chat_templates/chatml.j2:
--------------------------------------------------------------------------------
1 | {% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}
--------------------------------------------------------------------------------
/tool_eval/chat_templates/zephyr.j2:
--------------------------------------------------------------------------------
1 | {% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}
--------------------------------------------------------------------------------
/tool_eval/chat_templates/vicuna.j2:
--------------------------------------------------------------------------------
1 | {% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'].strip() + '\n\n' %}{% else %}{% set loop_messages = messages %}{% set system_message = '' %}{% endif %}{{ bos_token + system_message }}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ 'USER: ' + message['content'].strip() + '\n' }}{% elif message['role'] == 'assistant' %}{{ 'ASSISTANT: ' + message['content'].strip() + eos_token + '\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}
--------------------------------------------------------------------------------
/tool_eval/schema.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 | from typing import List, Dict, Literal, Optional
3 |
4 | class FunctionCall(BaseModel):
5 | arguments: dict
6 | """
7 | The arguments to call the function with, as generated by the model in JSON
8 | format. Note that the model does not always generate valid JSON, and may
9 | hallucinate parameters not defined by your function schema. Validate the
10 | arguments in your code before calling your function.
11 | """
12 |
13 | name: str
14 | """The name of the function to call."""
15 |
16 | class FunctionDefinition(BaseModel):
17 | name: str
18 | description: Optional[str] = None
19 | parameters: Optional[Dict[str, object]] = None
20 |
21 | class FunctionSignature(BaseModel):
22 | function: FunctionDefinition
23 | type: Literal["function"]
24 |
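25 | # Illustrative payload that FunctionCall is meant to validate (mirrors prompt_assets/few_shot.json):
26 | #   {"name": "calculate_distance", "arguments": {"origin": "New York", "destination": "Los Angeles", "mode": "car"}}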
--------------------------------------------------------------------------------
/tool_eval/prompt_assets/sys_prompt.yml:
--------------------------------------------------------------------------------
1 | Role: >
2 |   You are a function calling AI model.
3 |   You are provided with function signatures within <tools></tools> XML tags.
4 | Objective: >
5 |   You may call one or more functions to assist with the user query.
6 |   Don't make assumptions about what values to plug into functions.
7 | Tools: |
8 |   Here are the available tools:
9 |   <tools>
10 |   {tools}
11 |   </tools>
12 | Examples: |
13 |   Here are some example usage of functions:
14 |   {examples}
15 | Schema: |
16 |   Use the following pydantic model json schema for each tool call you will make:
17 |   <schema>
18 |   {schema}
19 |   </schema>
20 | Instructions: |
21 |   For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
22 |   <tool_call>
23 |   {{"name": <function-name>, "arguments": <args-dict>}}
24 |   </tool_call>
--------------------------------------------------------------------------------
/tool_eval/prompt_assets/sys_prompt_scratchpad.yml:
--------------------------------------------------------------------------------
1 | Role: >
2 |   You are a function calling AI model.
3 |   You are provided with function signatures within <tools></tools> XML tags.
4 | Objective: >
5 |   You may call one or more functions to assist with the user query.
6 |   If available tools are not relevant in assisting with user query, just respond in natural conversational language.
7 |   Don't make assumptions about what values to plug into functions.
8 |   After calling & executing the functions, you will be provided with function results within <tool_response></tool_response> XML tags.
9 | Tools: |
10 |   Here are the available tools:
11 |   <tools>
12 |   {tools}
13 |   </tools>
14 | Examples: |
15 |   Here are some example usage of functions:
16 |   {examples}
17 | Schema: |
18 |   For each function call return a JSON object, with the following pydantic model json schema for each:
19 |   <schema>
20 |   {schema}
21 |   </schema>
22 | Instructions: |
23 |   Each function call should be enclosed within <tool_call></tool_call> XML tags. Please use <scratch_pad></scratch_pad> XML tags to record your reasoning and planning before you call the functions.
24 |   Example:
25 |   <scratch_pad>
26 |   Goal:
27 |   Actions:
28 |
29 |   - {{result_var_name1}} = functions.{{function_name1}}({{param1}}={{value1}},...)
30 |   - {{result_var_name2, result_var_name3}} = ...
31 |   None
32 |   Observation:
33 |   Reflection:
34 |   </scratch_pad>
35 |   <tool_call>
36 |   {{"name": <function-name>, "arguments": <args-dict>}}
37 |   </tool_call>
--------------------------------------------------------------------------------
/tool_eval/prompt_assets/few_shot.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "example": "```\nSYSTEM: You are a helpful assistant who has access to functions. Use them if required\n[\n {\n \"name\": \"calculate_distance\",\n \"description\": \"Calculate the distance between two locations\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"origin\": {\n \"type\": \"string\",\n \"description\": \"The starting location\"\n },\n \"destination\": {\n \"type\": \"string\",\n \"description\": \"The destination location\"\n },\n \"mode\": {\n \"type\": \"string\",\n \"description\": \"The mode of transportation\"\n }\n },\n \"required\": [\n \"origin\",\n \"destination\",\n \"mode\"\n ]\n }\n },\n {\n \"name\": \"generate_password\",\n \"description\": \"Generate a random password\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"length\": {\n \"type\": \"integer\",\n \"description\": \"The length of the password\"\n }\n },\n \"required\": [\n \"length\"\n ]\n }\n }\n]\n\n\nUSER: Hi, I need to know the distance from New York to Los Angeles by car.\nASSISTANT:\n\n{\"arguments\": {\"origin\": \"New York\",\n \"destination\": \"Los Angeles\", \"mode\": \"car\"}, \"name\": \"calculate_distance\"}\n\n```\n"
4 | },
5 | {
6 | "example": "```\nSYSTEM: You are a helpful assistant with access to functions. Use them if required\n[\n {\n \"name\": \"calculate_distance\",\n \"description\": \"Calculate the distance between two locations\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"origin\": {\n \"type\": \"string\",\n \"description\": \"The starting location\"\n },\n \"destination\": {\n \"type\": \"string\",\n \"description\": \"The destination location\"\n },\n \"mode\": {\n \"type\": \"string\",\n \"description\": \"The mode of transportation\"\n }\n },\n \"required\": [\n \"origin\",\n \"destination\",\n \"mode\"\n ]\n }\n },\n {\n \"name\": \"generate_password\",\n \"description\": \"Generate a random password\",\n \"parameters\": {\n \"type\": \"object\",\n \"properties\": {\n \"length\": {\n \"type\": \"integer\",\n \"description\": \"The length of the password\"\n }\n },\n \"required\": [\n \"length\"\n ]\n }\n }\n]\n\n\nUSER: Can you help me generate a random password with a length of 8 characters?\nASSISTANT:\n\n{\"arguments\": {\"length\": 8}, \"name\": \"generate_password\"}\n\n```"
7 | }
8 | ]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Function-calling & JSON-mode Evaluation
2 | A framework for evaluating LLM function calling and JSON output using the Hermes tool-call and JSON-mode formats.
3 |
4 | The evaluation scripts measure a language model's performance on function-calling and JSON-output tasks. They preprocess prompts, run model completions, parse the function calls / JSON objects in the completions, validate them, and calculate the pass rate.
5 |
6 | ## Usage
7 |
8 | 1. Clone the repository or copy the script to your local machine.
9 | ```bash
10 | git clone https://github.com/your-repo/function-calling-eval.git
11 | cd function-calling-eval/tool_eval
12 | ```
13 |
14 | 2. Install the required dependencies:
15 | ```bash
16 | pip install -r requirements.txt
17 | MAX_JOBS=4 pip install flash-attn --no-build-isolation
18 | ```
19 |
20 | ### Arguments
21 |
22 | - `--model_path`: Path to the model folder (required).
23 | - `--chat_template`: Chat template for prompt formatting (default: `"chatml"`).
24 | - `--num_fewshot`: Number of few-shot examples to include in the system prompt (default: `None`).
25 | - `--dataset_path`: Path to the Hugging Face dataset (default: function-calling: `"NousResearch/func-calling-eval-glaive"` & json-mode: `"NousResearch/json-mode-eval"`).
26 | - `--load_in_4bit`: Option to load the model in 4-bit mode with `bitsandbytes` (default: `"False"`).
27 | - `--dpo`: Option to save the dataset for DPO (default: `"False"`).
28 |
29 | ## Example
30 |
31 | ### Function-calling
32 | ```bash
33 | python evaluator.py --model_path /path/to/model --chat_template chatml --dataset_path dataset/path --load_in_4bit True --dpo False
34 | ```
35 | #### Output
36 |
37 | The script generates the following outputs:
38 | - `function_calling_eval_results.json`: A JSON file containing the function-calling evaluation results, including prompts, completions, model outputs, and pass/fail status.
39 | - `function_calling_dpo_pairs.json` (if `--dpo` is set to `"True"`): A JSON file containing the DPO dataset for function-calling consisting of system messages, questions, chosen completions, and rejected completions.
40 |
41 | ### JSON-mode
42 | ```bash
43 | python evaluator_json_mode.py --model_path /path/to/model --load_in_4bit True --dpo False
44 | ```
45 | #### Output
46 |
47 | The script generates the following outputs:
48 | - `json_mode_eval_results.json`: A JSON file containing the json-mode evaluation results, including prompts, completions, model outputs, and pass/fail status.
49 | - `json_mode_dpo_pairs.json` (if `--dpo` is set to `"True"`): A JSON file containing the DPO dataset for json-mode consisting of system messages, questions, chosen completions, and rejected completions.
50 |
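51 | ## Inspecting results
52 |
53 | A minimal sketch (assuming the default output paths above) for re-computing the pass rate from a saved results file; each entry is an eval sample with a `result` field set to `"passed"` or `"failed"`:
54 |
55 | ```python
56 | import json
57 |
58 | # works the same way for function_calling_eval_results.json or json_mode_eval_results.json
59 | with open("function_calling_eval_results.json") as f:
60 |     results = json.load(f)
61 |
62 | # same computation as calculate_pass_rate() in tool_eval/utils.py
63 | passed = sum(1 for sample in results if sample["result"] == "passed")
64 | print(f"pass@1: {passed / len(results):.3f}")
65 | ```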
--------------------------------------------------------------------------------
/tool_eval/prompter.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 | from typing import Dict
3 | from schema import FunctionCall
4 | from utils import (
5 | get_fewshot_examples
6 | )
7 | import yaml
8 | import json
9 | import os
10 |
11 | class PromptSchema(BaseModel):
12 | Role: str
13 | Objective: str
14 | Tools: str
15 | Examples: str
16 | Schema: str
17 | Instructions: str
18 |
19 | class PromptManager:
20 | def __init__(self):
21 | self.script_dir = os.path.dirname(os.path.abspath(__file__))
22 |
23 | def format_yaml_prompt(self, prompt_schema: PromptSchema, variables: Dict) -> str:
24 | formatted_prompt = ""
25 | for field, value in prompt_schema.model_dump().items():
26 | if field == "Examples" and variables.get("examples") is None:
27 | continue
28 |
29 | formatted_value = value.format(**variables)
30 |
31 | # Add Markdown header
32 | #formatted_prompt += f"# {field}\n"
33 | formatted_prompt += formatted_value
34 | #formatted_prompt += "\n"
35 |
36 | return formatted_prompt.strip()
37 |
38 | def read_yaml_file(self, file_path: str) -> PromptSchema:
39 | with open(file_path, 'r') as file:
40 | yaml_content = yaml.safe_load(file)
41 |
42 | prompt_schema = PromptSchema(
43 | Role=yaml_content.get('Role', ''),
44 | Objective=yaml_content.get('Objective', ''),
45 | Tools=yaml_content.get('Tools', ''),
46 | Examples=yaml_content.get('Examples', ''),
47 | Schema=yaml_content.get('Schema', ''),
48 | Instructions=yaml_content.get('Instructions', ''),
49 | )
50 | return prompt_schema
51 |
52 | def generate_prompt(self, sample, scratch_pad=False, num_fewshot=None):
53 | if scratch_pad:
54 | prompt_path = os.path.join(self.script_dir, 'prompt_assets', 'sys_prompt_scratchpad.yml')
55 | else:
56 | prompt_path = os.path.join(self.script_dir, 'prompt_assets', 'sys_prompt.yml')
57 | prompt_schema = self.read_yaml_file(prompt_path)
58 |
59 | if num_fewshot is not None:
60 | examples = get_fewshot_examples(num_fewshot)
61 | else:
62 | examples = None
63 |
64 | schema_json = FunctionCall.model_json_schema()
65 | #schema = schema_json.get("properties", {})
66 |
67 | variables = {
68 | "tools": sample['tools'],
69 | "examples": examples,
70 | "schema": schema_json
71 | }
72 | sys_prompt = self.format_yaml_prompt(prompt_schema, variables)
73 | sample['system'] = sys_prompt
74 |
75 | prompt = [
76 | {'content': sys_prompt, 'role': 'system'}
77 | ]
78 | prompt.extend(sample['prompt'])
79 | return prompt, sys_prompt
80 |
81 |
82 |
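83 | # Usage (as called from evaluator.py):
84 | #   prompt, sys_prompt = PromptManager().generate_prompt(sample, scratch_pad, num_fewshot)
85 | # `prompt` is the chat-format message list with the rendered system prompt prepended to sample['prompt'];
86 | # `sys_prompt` is the rendered system prompt string.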
--------------------------------------------------------------------------------
/tool_eval/evaluator_json_mode.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import json
4 | from tqdm import tqdm
5 | from datasets import load_dataset
6 |
7 | from transformers import (
8 | AutoModelForCausalLM,
9 | AutoTokenizer,
10 | BitsAndBytesConfig
11 | )
12 |
13 | from validator import validate_json_completion, validate_json_data
14 |
15 | from utils import (
16 | eval_logger,
17 | calculate_pass_rate,
18 | get_assistant_message,
19 | get_chat_template,
20 | )
21 |
22 | class ModelEvaluator:
23 | def __init__(self, model_path, chat_template, load_in_4bit, flash_attn, dpo):
24 | self.bnb_config = None
25 |
26 | if load_in_4bit:
27 | self.bnb_config = BitsAndBytesConfig(
28 | load_in_4bit=True,
29 | bnb_4bit_quant_type="nf4",
30 | bnb_4bit_use_double_quant=True,
31 | )
32 |
33 | # Prepare model loading arguments
34 | model_args = {
35 | "trust_remote_code": True,
36 | "return_dict": True,
37 | "quantization_config": self.bnb_config,
38 | "torch_dtype": torch.bfloat16,
39 | "device_map": "auto",
40 | }
41 |
42 | # Conditionally add attn_implementation if flash_attn is True
43 | if flash_attn:
44 | model_args["attn_implementation"] = "flash_attention_2"
45 |
46 | self.model = AutoModelForCausalLM.from_pretrained(
47 | model_path,
48 | **model_args
49 | )
50 |
51 | self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
52 | if self.tokenizer.pad_token is None:
53 | self.tokenizer.pad_token = self.tokenizer.eos_token
54 | self.tokenizer.padding_side = "left"
55 |
56 | if self.tokenizer.chat_template is None:
57 | self.tokenizer.chat_template = get_chat_template(chat_template)
58 |
59 | self.eval_results = []
60 | if dpo:
61 | self.dpo_results = []
62 |
63 | eval_logger.info(self.model.config)
64 | eval_logger.info(self.model.generation_config)
65 | eval_logger.info(self.model.parameters)
66 | eval_logger.info(self.tokenizer.chat_template)
67 | eval_logger.info(self.tokenizer.special_tokens_map)
68 |
69 | def evaluate_model(self, eval_dataset, chat_template, num_fewshot):
70 |
71 | for sample in tqdm(eval_dataset, desc="processing samples", unit="sample"):
72 | prompt = sample['prompt']
73 |
74 | inputs = self.tokenizer.apply_chat_template(
75 | prompt,
76 | add_generation_prompt=True,
77 | return_tensors='pt'
78 | )
79 |
80 | tokens = self.model.generate(
81 | inputs.to(self.model.device),
82 | max_new_tokens=512,
83 | temperature=0.1,
84 | do_sample=True
85 | )
86 |
87 | completion = self.tokenizer.decode(tokens[0], skip_special_tokens=False)
88 | eval_logger.info(f"model completion with eval prompt:\n{completion}")
89 |
90 | assistant_message = get_assistant_message(completion, chat_template, self.tokenizer.eos_token)
91 |
92 | sample['model_completion'] = ""
93 | sample['result'] = "failed"
94 |
95 | if assistant_message is not None:
96 | sample['model_completion'] = assistant_message
97 | validation, json_object = validate_json_data(assistant_message, json.loads(sample['schema']))
98 | if validation:
99 | result = validate_json_completion(json_object, json.loads(sample['completion']))
100 | if result == "failed":
101 | eval_logger.info("Json completion validation failed")
102 | else:
103 | sample['result'] = "passed"
104 | sample['model_completion'] = json_object
105 | eval_logger.info(f"all validations passed")
106 | eval_logger.info(f"parsed json object:\n{json.dumps(json_object, indent=2)}")
107 |
108 | if sample['result'] == "failed":
109 | eval_logger.info("all validations failed")
110 | eval_logger.info(f"failed assistant message:\n{sample['model_completion']}")
111 |
112 | if hasattr(self, 'dpo_results'):
113 | self.dpo_results.append({
114 | "system": prompt[0]["content"],
115 | "question": prompt[1]['content'],
116 | "chosen": sample['completion'],
117 | "rejected": assistant_message
118 | })
119 |
120 | self.eval_results.append(sample)
121 |
122 | if __name__ == "__main__":
123 | parser = argparse.ArgumentParser(description="Evaluate model performance on a JSON-mode eval dataset")
124 | parser.add_argument("--model_path", type=str, help="Path to the model folder")
125 | parser.add_argument("--chat_template", type=str, default="chatml", help="Chat template for prompt formatting")
126 | parser.add_argument("--num_fewshot", type=int, default=None, help="Option to subset eval dataset")
127 | parser.add_argument("--dataset_path", type=str, default=None, help="Huggingface dataset path")
128 | parser.add_argument("--flash_attn", type=bool, default=False, help="weather to use flash attention; requires installing flash-attn")
129 | parser.add_argument("--load_in_4bit", type=str, default=False, help="Option to load in 4bit with bitsandbytes")
130 | parser.add_argument("--dpo", type=str, default=False, help="Option to save dpo dataset")
131 | args = parser.parse_args()
132 |
133 | # load eval dataset
134 | if args.dataset_path:
135 | eval_dataset = load_dataset(args.dataset_path)["train"]
136 | else:
137 | eval_dataset = load_dataset("NousResearch/json-mode-eval")['train']
138 |
139 | # Create model evaluator instance
140 | model_evaluator = ModelEvaluator(args.model_path, args.chat_template, args.load_in_4bit == "True", args.flash_attn == "True", args.dpo == "True")
141 |
142 | # Run the model evaluator
143 | model_evaluator.evaluate_model(eval_dataset, args.chat_template, args.num_fewshot)
144 |
145 | # Calculate and print pass rate
146 | pass_rate = calculate_pass_rate(model_evaluator.eval_results)
147 | eval_logger.info(f"json-mode eval (pass@1): {pass_rate}")
148 |
149 | results_path = './json_mode_eval_results.json'
150 | with open(results_path, 'w') as file:
151 | json.dump(model_evaluator.eval_results, file)
152 |
153 | if args.dpo == "True":
154 | dpo_path = './json_mode_dpo_pairs.json'
155 | with open(dpo_path, 'w') as file:
156 | json.dump(model_evaluator.dpo_results, file)
157 |
--------------------------------------------------------------------------------
/tool_eval/utils.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import os
3 | import re
4 | import json
5 | import logging
6 | import datetime
7 | import xml.etree.ElementTree as ET
8 | from logging.handlers import RotatingFileHandler
9 |
10 | logging.basicConfig(
11 | format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
12 | datefmt="%Y-%m-%d:%H:%M:%S",
13 | level=logging.INFO,
14 | )
15 | script_dir = os.path.dirname(os.path.abspath(__file__))
16 | now = datetime.datetime.now()
17 | log_folder = os.path.join(script_dir, "eval_logs")
18 | os.makedirs(log_folder, exist_ok=True)
19 | log_file_path = os.path.join(
20 | log_folder, f"function-calling-eval_{now.strftime('%Y-%m-%d_%H-%M-%S')}.log"
21 | )
22 | # Use RotatingFileHandler from the logging.handlers module
23 | file_handler = RotatingFileHandler(log_file_path, maxBytes=0, backupCount=0)
24 | file_handler.setLevel(logging.INFO)
25 |
26 | formatter = logging.Formatter("%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s", datefmt="%Y-%m-%d:%H:%M:%S")
27 | file_handler.setFormatter(formatter)
28 |
29 | eval_logger = logging.getLogger("function-calling-eval")
30 | eval_logger.addHandler(file_handler)
31 |
32 | def get_fewshot_examples(num_fewshot):
33 | """return a list of few shot examples"""
34 | example_path = os.path.join(script_dir, 'prompt_assets', 'few_shot.json')
35 | with open(example_path, 'r') as file:
36 | examples = json.load(file) # Use json.load with the file object, not the file path
37 | if num_fewshot > len(examples):
38 | raise ValueError(f"Not enough examples (got {num_fewshot}, but there are only {len(examples)} examples).")
39 | return examples[:num_fewshot]
40 |
41 | def get_chat_template(chat_template):
42 | """read chat template from jinja file"""
43 | template_path = os.path.join(script_dir, 'chat_templates', f"{chat_template}.j2")
44 |
45 | if not os.path.exists(template_path):
46 | eval_logger.error(f"Template file not found: {chat_template}")
47 | return None
48 | try:
49 | with open(template_path, 'r') as file:
50 | template = file.read()
51 | return template
52 | except Exception as e:
53 | print(f"Error loading template: {e}")
54 | return None
55 |
56 | def get_assistant_message(completion, chat_template, eos_token):
57 | """define and match pattern to find the assistant message"""
58 | completion = completion.strip()
59 |
60 | if chat_template == "zephyr":
61 | assistant_pattern = re.compile(r'<\|assistant\|>((?:(?!<\|assistant\|>).)*)$', re.DOTALL)
62 | elif chat_template == "chatml":
63 | assistant_pattern = re.compile(r'<\|im_start\|>\s*assistant((?:(?!<\|im_start\|>\s*assistant).)*)$', re.DOTALL)
64 | elif chat_template == "vicuna":
65 | assistant_pattern = re.compile(r'ASSISTANT:\s*((?:(?!ASSISTANT:).)*)$', re.DOTALL)
66 | else:
67 | raise NotImplementedError(f"Handling for chat_template '{chat_template}' is not implemented.")
68 |
69 | assistant_match = assistant_pattern.search(completion)
70 | if assistant_match:
71 | assistant_content = assistant_match.group(1).strip()
72 | #if chat_template == "vicuna":
73 | # eos_token = f"{eos_token}"
74 | return assistant_content.replace(eos_token, "")
75 | else:
76 | assistant_content = None
77 | eval_logger.info("No match found for the assistant pattern")
78 | return assistant_content
79 |
80 | def validate_and_extract_tool_calls_regex(assistant_content):
81 | validation_result = False
82 | tool_calls = []
83 |
84 | # Regular expression to find content within <tool_call></tool_call> tags
85 | tool_call_pattern = re.compile(r'<tool_call>(.*?)</tool_call>', re.DOTALL)
86 |
87 | # Find all matches
88 | matches = tool_call_pattern.findall(assistant_content)
89 |
90 | for match in matches:
91 | try:
92 | # Try to parse the content as JSON
93 | json_data = json.loads(match.strip())
94 | tool_calls.append(json_data)
95 | validation_result = True
96 | except json.JSONDecodeError as json_err:
97 | eval_logger.error("JSON parsing failed:")
98 | eval_logger.error("- JSON Decode Error: %s", json_err)
99 | eval_logger.error("- Problematic JSON text: %s", match.strip())
100 |
101 | return validation_result, tool_calls
102 |
103 | def validate_and_extract_tool_calls(assistant_content):
104 | validation_result = False
105 | tool_calls = []
106 | try:
107 | # wrap content in root element
108 | xml_root_element = f"<root>{assistant_content}</root>"
109 | root = ET.fromstring(xml_root_element)
110 |
111 | # extract JSON data
112 | for element in root.findall(".//tool_call"):
113 | if element.text is not None:
114 | json_text = element.text.strip()
115 |
116 | try:
117 | # Prioritize json.loads for better error handling
118 | json_data = json.loads(json_text)
119 | except json.JSONDecodeError as json_err:
120 | try:
121 | # Fallback to ast.literal_eval if json.loads fails
122 | json_data = ast.literal_eval(json_text)
123 | except (SyntaxError, ValueError) as eval_err:
124 | eval_logger.error("JSON parsing failed with both json.loads and ast.literal_eval:")
125 | eval_logger.error("- JSON Decode Error: %s", json_err)
126 | eval_logger.error("- Fallback Syntax/Value Error: %s", eval_err)
127 | eval_logger.error("- Problematic JSON text: %s", json_text)
128 | continue
129 |
130 | tool_calls.append(json_data)
131 | validation_result = True
132 |
133 | except ET.ParseError as err:
134 | eval_logger.error("XML Parse Error: %s", err)
135 |
136 |
137 | # Return default values if no valid data is extracted
138 | return validation_result, tool_calls
139 |
140 | def validate_tool_calls(generated_arguments, expected_arguments):
141 | for key, expected_value in expected_arguments.items():
142 | if generated_arguments.get(key) != expected_value:
143 | eval_logger.info("Expected: %s", expected_value)
144 | eval_logger.info("Got: %s", generated_arguments.get(key))
145 | return "failed"
146 | return "passed"
147 |
148 | def calculate_pass_rate(eval_results):
149 | passed_count = sum(1 for sample in eval_results if sample["result"] == "passed")
150 | eval_logger.info("Number of eval tests passed: %s", passed_count)
151 | eval_logger.info("Number of eval tests failed: %s", len(eval_results) - passed_count)
152 |
153 | pass_rate = passed_count / len(eval_results)
154 | return pass_rate
--------------------------------------------------------------------------------
/tool_eval/validator.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import json
3 | from jsonschema import validate, ValidationError
4 | from pydantic import ValidationError as PydanticValidationError
5 | from utils import eval_logger
6 | from schema import FunctionCall, FunctionSignature
7 |
8 | def validate_function_call_schema(call, signatures):
9 | try:
10 | call_data = FunctionCall(**call)
11 | except PydanticValidationError as e:
12 | eval_logger.info(f"Invalid function call: {e}")
13 | return False
14 |
15 | for signature in signatures:
16 | # Inside the main validation function
17 | try:
18 | signature_data = FunctionSignature(**signature)
19 |
20 | if signature_data.function.name == call_data.name:
21 |
22 | # Validate types in function arguments
23 | for arg_name, arg_schema in signature_data.function.parameters.get('properties', {}).items():
24 | if arg_name in call_data.arguments:
25 | call_arg_value = call_data.arguments[arg_name]
26 | if call_arg_value:
27 | try:
28 | validate_argument_type(arg_name, call_arg_value, arg_schema)
29 | except Exception as arg_validation_error:
30 | eval_logger.info(f"Invalid argument '{arg_name}': {arg_validation_error}")
31 | return False
32 |
33 | # Check if all required arguments are present
34 | required_arguments = signature_data.function.parameters.get('required', [])
35 | result, missing_arguments = check_required_arguments(call_data.arguments, required_arguments)
36 |
37 | if not result:
38 | eval_logger.info(f"Missing required arguments: {missing_arguments}")
39 | return False
40 |
41 | return True
42 | except Exception as e:
43 | # Handle validation errors for the function signature
44 | eval_logger.info(f"Error validating function call: {e}")
45 | return False
46 |
47 | # Moved the "No matching function signature found" message here
48 | eval_logger.info(f"No matching function signature found for function: {call_data.name}")
49 | return False
50 |
51 | def check_required_arguments(call_arguments, required_arguments):
52 | missing_arguments = [arg for arg in required_arguments if arg not in call_arguments]
53 | return not bool(missing_arguments), missing_arguments
54 |
55 | def validate_enum_value(arg_name, arg_value, enum_values):
56 | if arg_value not in enum_values:
57 | raise Exception(
58 | f"Invalid value '{arg_value}' for parameter {arg_name}. Expected one of {', '.join(map(str, enum_values))}"
59 | )
60 |
61 | def validate_argument_type(arg_name, arg_value, arg_schema):
62 | arg_type = arg_schema.get('type', None)
63 | if arg_type:
64 | if arg_type == 'string' and 'enum' in arg_schema:
65 | enum_values = arg_schema['enum']
66 | if None not in enum_values and enum_values != []:
67 | try:
68 | validate_enum_value(arg_name, arg_value, enum_values)
69 | except Exception as e:
70 | # Propagate the validation error message
71 | raise Exception(f"Error validating function call: {e}")
72 |
73 | python_type = get_python_type(arg_type)
74 | if not isinstance(arg_value, python_type):
75 | raise Exception(f"Type mismatch for parameter {arg_name}. Expected: {arg_type}, Got: {type(arg_value)}")
76 |
77 | def get_python_type(json_type):
78 | type_mapping = {
79 | 'string': str,
80 | 'number': (int, float),
81 | 'integer': int,
82 | 'boolean': bool,
83 | 'array': list,
84 | 'object': dict,
85 | 'null': type(None),
86 | }
87 | return type_mapping[json_type]
88 |
89 |
90 | def validate_json_data(json_object, json_schema):
91 | valid = False
92 | error_message = None
93 | result_json = None
94 |
95 | try:
96 | # Attempt to load JSON using json.loads
97 | try:
98 | result_json = json.loads(json_object)
99 | except json.decoder.JSONDecodeError:
100 | # If json.loads fails, try ast.literal_eval
101 | try:
102 | result_json = ast.literal_eval(json_object)
103 | except (SyntaxError, ValueError) as e:
104 | error_message = f"JSON decoding error: {e}"
105 | # Return early if both json.loads and ast.literal_eval fail
106 | eval_logger.info(f"Validation failed for JSON data: {error_message}")
107 | return valid, result_json
108 |
109 | # Validate each item in the list against schema if it's a list
110 | if isinstance(result_json, list):
111 | for index, item in enumerate(result_json):
112 | try:
113 | validate(instance=item, schema=json_schema)
114 | eval_logger.info(f"Item {index+1} is valid against the schema.")
115 | except ValidationError as e:
116 | error_message = f"Validation failed for item {index+1}: {e}"
117 | break
118 | else: # Default to validation without list
119 | try:
120 | validate(instance=result_json, schema=json_schema)
121 | #eval_logger.info("JSON object is valid against the schema.")
122 | except ValidationError as e:
123 | error_message = f"Validation failed: {e}"
124 | except Exception as e:
125 | error_message = f"Error occurred: {e}"
126 |
127 | if error_message is None:
128 | valid = True
129 | eval_logger.info("JSON data is valid against the schema.")
130 | else:
131 | eval_logger.info(f"Validation failed for JSON data: {error_message}")
132 |
133 | return valid, result_json
134 |
135 |
136 | def validate_json_completion(json_obj1, json_obj2):
137 | # Check if keys match
138 | try:
139 | if set(json_obj1.keys()) != set(json_obj2.keys()):
140 | eval_logger.info("Keys don't match:")
141 | eval_logger.info(f"Expected: {set(json_obj1.keys())}")
142 | eval_logger.info(f"Got: {set(json_obj2.keys())}")
143 | return "failed"
144 |
145 | # Check if values match
146 | for key in json_obj1.keys():
147 | if json_obj1[key] != json_obj2[key]:
148 | eval_logger.info(f"Values don't match for key '{key}'")
149 | eval_logger.info(f"Expected: {json_obj1[key]}")
150 | eval_logger.info(f"Got: {json_obj2[key]}")
151 | return "failed"
152 | except Exception as e:
153 | eval_logger.info(f"Exception occured: {e}")
154 | return "failed"
155 |
156 | # If keys and values match, result remains "passed"
157 | return "passed"
158 |
159 |
--------------------------------------------------------------------------------
/tool_eval/evaluator.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import json
4 | from tqdm import tqdm
5 | from datasets import load_dataset
6 |
7 | from transformers import (
8 | AutoModelForCausalLM,
9 | AutoTokenizer,
10 | BitsAndBytesConfig
11 | )
12 |
13 | from prompter import PromptManager
14 | from validator import validate_function_call_schema
15 |
16 | from utils import (
17 | eval_logger,
18 | calculate_pass_rate,
19 | get_assistant_message,
20 | get_chat_template,
21 | validate_tool_calls,
22 | validate_and_extract_tool_calls_regex
23 | )
24 |
25 | class ModelEvaluator:
26 | def __init__(self, model_path, chat_template, load_in_4bit, flash_attn, dpo):
27 | self.prompter = PromptManager()
28 | self.bnb_config = None
29 |
30 | if load_in_4bit:
31 | self.bnb_config = BitsAndBytesConfig(
32 | load_in_4bit=True,
33 | bnb_4bit_quant_type="nf4",
34 | bnb_4bit_use_double_quant=True,
35 | )
36 |
37 | # Prepare model loading arguments
38 | model_args = {
39 | "trust_remote_code": True,
40 | "return_dict": True,
41 | "quantization_config": self.bnb_config,
42 | "torch_dtype": torch.bfloat16,
43 | "device_map": "auto",
44 | }
45 |
46 | # Conditionally add attn_implementation if flash_attn is True
47 | if flash_attn:
48 | model_args["attn_implementation"] = "flash_attention_2"
49 |
50 | self.model = AutoModelForCausalLM.from_pretrained(
51 | model_path,
52 | **model_args
53 | )
54 |
55 | self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
56 |
57 | if self.tokenizer.pad_token is None:
58 | self.tokenizer.pad_token = self.tokenizer.eos_token
59 | self.tokenizer.padding_side = "left"
60 |
61 | if self.tokenizer.chat_template is None:
62 | self.tokenizer.chat_template = get_chat_template(chat_template)
63 |
64 | self.eval_results = []
65 | if dpo:
66 | self.dpo_results = []
67 |
68 | eval_logger.info(self.model.config)
69 | eval_logger.info(self.model.generation_config)
70 | eval_logger.info(self.model.parameters)
71 | eval_logger.info(self.tokenizer.chat_template)
72 | eval_logger.info(self.tokenizer.special_tokens_map)
73 |
74 | def evaluate_model(self, eval_dataset, chat_template, scratch_pad, num_fewshot):
75 |
76 | for sample in tqdm(eval_dataset, desc="processing samples", unit="sample"):
77 | prompt, sys_prompt = self.prompter.generate_prompt(sample, scratch_pad, num_fewshot)
78 |
79 | inputs = self.tokenizer.apply_chat_template(
80 | prompt,
81 | add_generation_prompt=True,
82 | return_tensors='pt'
83 | )
84 |
85 | tokens = self.model.generate(
86 | inputs.to(self.model.device),
87 | max_new_tokens=2048,
88 | temperature=0.1,
89 | do_sample=True
90 | )
91 |
92 | completion = self.tokenizer.decode(tokens[0], skip_special_tokens=False)
93 | eval_logger.info(f"model completion with eval prompt:\n{completion}")
94 |
95 | assistant_message = get_assistant_message(completion, chat_template, self.tokenizer.eos_token)
96 | validation, tool_calls = validate_and_extract_tool_calls_regex(assistant_message)
97 |
98 | sample['model_completion'] = []
99 | sample['result'] = "failed"
100 |
101 | eval_completion = json.loads(sample['completion'])
102 | if validation:
103 | if isinstance(eval_completion, list):
104 | eval_tool_calls = eval_completion
105 | else:
106 | eval_tool_calls = [eval_completion]
107 |
108 | all_valid = True
109 | if len(tool_calls) != len(eval_tool_calls):
110 | all_valid = False
111 | eval_logger.info("Number of tool calls doesn't match")
112 | eval_logger.info(f"Expected: {len(eval_tool_calls)} tool calls; Got: {len(tool_calls)}")
113 |
114 | for eval_tool_call in eval_tool_calls:
115 | function_found = False
116 |
117 | for tool_call in tool_calls:
118 | schema_validation = validate_function_call_schema(tool_call, json.loads(sample['tools']))
119 | if not schema_validation:
120 | all_valid = False
121 | break
122 |
123 | if tool_call['name'] == eval_tool_call['name']:
124 | function_found = True
125 | result = validate_tool_calls(tool_call['arguments'], eval_tool_call['arguments'])
126 | sample['model_completion'].append(tool_call)
127 | eval_logger.info(f"{tool_call['name']} validation: {result}")
128 | if result == "failed":
129 | all_valid = False
130 | break
131 | if not function_found:
132 | eval_logger.info(f"Function '{eval_tool_call['name']}' not found")
133 | all_valid = False
134 | else:
135 | eval_logger.info("Function call validation failed")
136 | all_valid = False
137 |
138 | if all_valid:
139 | sample['result'] = "passed"
140 | eval_logger.info(f"all validations: {sample['result']}")
141 | eval_logger.info(f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}")
142 | else:
143 | sample['model_completion'] = assistant_message
144 | eval_logger.info(f"all validations: {sample['result']}")
145 | eval_logger.info(f"failed tool calls:\n{assistant_message}")
146 |
147 | if hasattr(self, 'dpo_results'):
148 | chosen_completion = ""
149 | for tool_call in (eval_completion if isinstance(eval_completion, list) else [eval_completion]):
150 | chosen_completion += f"<tool_call>\n{tool_call}\n</tool_call>\n"
151 | self.dpo_results.append({
152 | "system": sys_prompt,
153 | "question": sample['prompt'][0]['content'],
154 | "chosen": chosen_completion,
155 | "rejected": assistant_message
156 | })
157 |
158 | self.eval_results.append(sample)
159 |
160 | if __name__ == "__main__":
161 | parser = argparse.ArgumentParser(description="Evaluate model performance on a function-calling eval dataset")
162 | parser.add_argument("--model_path", type=str, help="Path to the model folder")
163 | parser.add_argument("--chat_template", type=str, default="chatml", help="Chat template for prompt formatting")
164 | parser.add_argument("--num_fewshot", type=int, default=None, help="Option to subset eval dataset")
165 | parser.add_argument("--dataset_path", type=str, default=None, help="Huggingface dataset path")
166 | parser.add_argument("--load_in_4bit", type=bool, default=False, help="Option to load in 4bit with bitsandbytes")
167 | parser.add_argument("--flash_attn", type=bool, default=False, help="weather to use flash attention; requires installing flash-attn")
168 | parser.add_argument("--scratch_pad", type=bool, default=False, help="whether to use scratchpad reasoning")
169 | parser.add_argument("--dpo", type=bool, default=False, help="Option to save dpo dataset")
170 | args = parser.parse_args()
171 |
172 | # load eval dataset
173 | if args.dataset_path:
174 | eval_dataset = load_dataset(args.dataset_path)["train"]
175 | else:
176 | eval_dataset = load_dataset("NousResearch/func-calling-eval-glaive")['train']
177 |
178 | # Create model evaluator instance
179 | model_evaluator = ModelEvaluator(args.model_path, args.chat_template, args.load_in_4bit == "True", args.flash_attn == "True", args.dpo == "True")
180 |
181 | # Run the model evaluator
182 | model_evaluator.evaluate_model(eval_dataset, args.chat_template, args.scratch_pad == "True", args.num_fewshot)
183 |
184 | # Calculate and print pass rate
185 | pass_rate = calculate_pass_rate(model_evaluator.eval_results)
186 | eval_logger.info(f"function-calling eval (pass@1): {pass_rate}")
187 |
188 | results_path = './function_calling_eval_results.json'
189 | with open(results_path, 'w') as file:
190 | json.dump(model_evaluator.eval_results, file)
191 |
192 | if args.dpo == "True":
193 | dpo_path = './function_calling_dpo_pairs.json'
194 | with open(dpo_path, 'w') as file:
195 | json.dump(model_evaluator.dpo_results, file)
196 |
197 |
198 |
--------------------------------------------------------------------------------