├── .gitignore
├── README.md
├── eval_layercast.py
├── eval_main.py
├── eval_passk.py
├── evals
│   ├── README.md
│   ├── __init__.py
│   ├── base_instruct_evals.md
│   ├── batch
│   │   ├── __init__.py
│   │   ├── engines
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── initializer.py
│   │   │   └── vllm_engine.py
│   │   ├── env_config.py
│   │   ├── logging
│   │   │   └── __init__.py
│   │   ├── pipeline.py
│   │   ├── tokenizer.py
│   │   ├── utils.py
│   │   └── workload.py
│   ├── cli.py
│   ├── common
│   │   ├── __init__.py
│   │   └── entities.py
│   ├── inference_and_check.py
│   ├── labeled_numina_difficulty
│   │   └── README.md
│   ├── models
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── model_configs.yaml
│   │   └── system_prompts
│   │       └── prime.txt
│   ├── ray_configs
│   │   └── ray_config.yaml
│   ├── scoring
│   │   ├── __init__.py
│   │   ├── apps
│   │   │   ├── __init__.py
│   │   │   ├── apps_scorer.py
│   │   │   └── apps_util.py
│   │   ├── base.py
│   │   ├── gsm8k
│   │   │   ├── __init__.py
│   │   │   └── gsm8k_scorer.py
│   │   ├── ifeval
│   │   │   ├── __init__.py
│   │   │   ├── ifeval_scorer.py
│   │   │   ├── instructions.py
│   │   │   ├── instructions_main.py
│   │   │   ├── instructions_registry.py
│   │   │   └── instructions_util.py
│   │   ├── livecodebench
│   │   │   ├── __init__.py
│   │   │   ├── livecodebench_scorer.py
│   │   │   └── livecodebench_util.py
│   │   ├── math
│   │   │   ├── __init__.py
│   │   │   └── math_scorer.py
│   │   ├── taco
│   │   │   ├── __init__.py
│   │   │   ├── taco_scorer.py
│   │   │   └── taco_util.py
│   │   └── utils
│   │       ├── __init__.py
│   │       └── pyext2.py
│   ├── tasks
│   │   ├── __init__.py
│   │   ├── aime
│   │   │   ├── aime24.yaml
│   │   │   ├── aime24_sky.yaml
│   │   │   ├── aime25_1.yaml
│   │   │   ├── aime25_2.yaml
│   │   │   └── aime_handler.py
│   │   ├── amc23
│   │   │   ├── amc23.yaml
│   │   │   └── amc23_handler.py
│   │   ├── apps
│   │   │   ├── apps.yaml
│   │   │   ├── apps_handler.py
│   │   │   └── apps_util.py
│   │   ├── arc
│   │   │   ├── arc_c.yaml
│   │   │   └── arc_handler.py
│   │   ├── base.py
│   │   ├── gpqa_diamond
│   │   │   ├── gpqa_diamond.yaml
│   │   │   └── gpqa_diamond_handler.py
│   │   ├── gsm8k
│   │   │   ├── gsm8k.yaml
│   │   │   └── gsm8k_handler.py
│   │   ├── liveaops
│   │   │   ├── liveaops.yaml
│   │   │   └── liveaops_handler.py
│   │   ├── livecodebench
│   │   │   ├── livecodebench.yaml
│   │   │   ├── livecodebench_easy.yaml
│   │   │   ├── livecodebench_handler.py
│   │   │   ├── livecodebench_hard.yaml
│   │   │   ├── livecodebench_medium.yaml
│   │   │   └── livecodebench_util.py
│   │   ├── math
│   │   │   ├── math500.yaml
│   │   │   └── math_handler.py
│   │   ├── minervamath
│   │   │   ├── minervamath.yaml
│   │   │   └── minervamath_handler.py
│   │   ├── mmlu
│   │   │   ├── mmlu.yaml
│   │   │   ├── mmlu_handler.py
│   │   │   └── mmlu_pro.yaml
│   │   ├── numina
│   │   │   ├── numina.yaml
│   │   │   ├── numina_amc_aime.yaml
│   │   │   ├── numina_handler.py
│   │   │   ├── numina_math.yaml
│   │   │   └── numina_olympiads.yaml
│   │   ├── olympiadbench
│   │   │   ├── olympiadbench_handler.py
│   │   │   └── olympiadbench_math_en.yaml
│   │   ├── omni_math
│   │   │   ├── omni_handler.py
│   │   │   └── omni_math.yaml
│   │   ├── taco
│   │   │   ├── pyext2.py
│   │   │   ├── taco.yaml
│   │   │   ├── taco_handler.py
│   │   │   └── taco_util.py
│   │   └── task_util.py
│   └── util
│       ├── __init__.py
│       ├── cli_util.py
│       ├── common.py
│       ├── math_parsing_util.py
│       ├── metrics.py
│       ├── response.py
│       └── results.py
├── figures
│   └── reproduciblellm_fig1.png
├── patch_vllm.py
└── prompt_util
    └── prompt_template.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *pytest*
2 | *.egg-info
3 | *output*
4 | *.log
5 | **/__pycache__/
6 | outputs/
7 | scoring_results/
8 | sh/my*
9 | sh/test*
10 | vllm_version_test/test*
11 | acc_folder*/
12 | scoring*.py
13 | arxiv_exp/test*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Challenges and Solutions of LLM Reproducibility
2 |
3 | Codebase of [Give Me FP32 or Give Me Death? Challenges and Solutions for Reproducible Reasoning](https://arxiv.org/abs/2506.09501)
4 |
5 | ## News
6 | - [2025.06.18]: Our paper has been released on [arXiv](https://arxiv.org/abs/2506.09501). Feel free to ⭐ upvote it on [Hugging Face](https://huggingface.co/papers/2506.09501).
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | Figure 1. Left: Under BF16 precision and greedy decoding, the model's output can vary significantly depending on factors such as GPU count, evaluation batch size, and GPU hardware version. Right: For example, changes in evaluation batch size alone can lead to noticeable differences in responses, which is often ignored and not standardized by evaluation benchmarks.
16 |
17 |
18 | ## Overview
19 | This repository contains the official implementation of **"Give Me FP32 or Give Me Death? Challenges and Solutions for Reproducible Reasoning"**. We present the first systematic study on the fragility of LLM reproducibility under different system configurations. Our work identifies reduced numerical precision as a key source of divergence, and introduces LayerCast, a hybrid-precision inference pipeline that balances memory efficiency with numerical stability.
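To make the idea concrete, below is a minimal, illustrative sketch of the LayerCast principle (keep weights in 16-bit storage and upcast them to FP32 just before computation). It is not the repository's actual implementation, which lives in `patch_vllm.py` and `eval_layercast.py`; the class name and wrapper design here are hypothetical.

```python
import torch
import torch.nn as nn


class LayerCastLinear(nn.Module):
    """Hypothetical wrapper: stores weights in bfloat16, computes in float32."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        # Keep the stored copy in bfloat16 for memory efficiency.
        self.weight = nn.Parameter(linear.weight.detach().to(torch.bfloat16), requires_grad=False)
        self.bias = (
            nn.Parameter(linear.bias.detach().to(torch.bfloat16), requires_grad=False)
            if linear.bias is not None
            else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Upcast weights (and activations) to float32 right before the matmul,
        # so the numerically sensitive computation runs in full precision.
        w = self.weight.to(torch.float32)
        b = self.bias.to(torch.float32) if self.bias is not None else None
        return nn.functional.linear(x.to(torch.float32), w, b)
```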
20 |
21 | ## Environment Setup
22 |
23 | ```bash
24 | conda create -n reproducible_llm python=3.12 -y
25 | conda activate reproducible_llm
26 | pip install vllm==0.8.2
27 | pip install datasets latex2sympy2 word2number immutabledict nltk langdetect
28 | ```
29 | #### Impact of Serving System Version
30 | We consistently used vLLM 0.8.2 for our experiments. Please make sure to use the same vLLM version, since different versions of serving frameworks may employ different GPU kernels, which may have varying numerical stability.
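A quick, optional sanity check at runtime (vLLM exposes its version as `vllm.__version__`):

```python
# Optional: verify the pinned vLLM version before running experiments.
import vllm

assert vllm.__version__ == "0.8.2", f"Expected vLLM 0.8.2, got {vllm.__version__}"
```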
31 |
32 |
33 | ## Getting Started
34 | ### To download this repository:
35 | ```bash
36 | git clone https://github.com/nanomaoli/llm_reproducibility.git
37 | cd llm_reproducibility
38 | ```
39 | ### To reproduce the main experiments:
40 | Set `CUDA_VISIBLE_DEVICES` to control the number of GPUs used, and specify a descriptive `exp_name` to help track different configurations.
41 | #### Run inference with greedy decoding:
42 | ```bash
43 | [CUDA_VISIBLE_DEVICES] python eval_main.py --model [MODEL] \
44 | --task [TASK] \
45 | --dtype [dtype] \
46 | --seed [RANDOM_SEED] \
47 | --batch_size [BS] \
48 | --max_tokens [MAX_TOKENS] \
49 | --exp_name [NAME_OF_THE_RUN]
50 | ```
51 | Model responses and logprobs will be saved to `outputs/vllm_main/{exp_name}/{model}`. We save logprobs of the 5 most likely tokens for analysis in our paper.
52 | Scoring results will appear in `scoring_results/greedy`.
53 |
54 | *Example:*
55 | ```bash
56 | CUDA_VISIBLE_DEVICES=0,1 python eval_main.py --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
57 | --task math500 \
58 | --dtype bfloat16 \
59 | --seed 42 \
60 | --batch_size 32 \
61 | --max_tokens 32768 \
62 | --exp_name 2a100_math500_bf16_bs32
63 | ```
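The saved tensors follow the naming scheme used in `eval_main.py` (`problem_{idx}_{task}_token_ids_bs_{batch_size}_{dtype}_max_tokens_{max_tokens}.pt`). For the example run above, the top-5 token ids and logprobs of a finished problem can be loaded like this (problem index 0 is used for illustration):

```python
import torch

out_dir = "outputs/vllm_main/2a100_math500_bf16_bs32/deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
token_ids = torch.load(f"{out_dir}/problem_0_math500_token_ids_bs_32_bfloat16_max_tokens_32768.pt")
logprobs = torch.load(f"{out_dir}/problem_0_math500_logprobs_bs_32_bfloat16_max_tokens_32768.pt")
print(token_ids.shape, logprobs.shape)  # both are (num_generated_tokens, 5)
```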
64 |
65 | #### Run inference with greedy decoding using LayerCast:
66 | LayerCast uses `float32` for computation, so `--dtype` should be set accordingly.
67 | ```bash
68 | [CUDA_VISIBLE_DEVICES] python eval_layercast.py --model [MODEL] \
69 | --task [TASK] \
70 | --dtype float32 \
71 | --seed [RANDOM_SEED] \
72 | --batch_size [BS] \
73 | --max_tokens [MAX_TOKENS] \
74 | --exp_name [NAME_OF_THE_RUN]
75 | ```
76 | Model responses and logprobs will be saved to `outputs/vllm_layercast/{exp_name}/{model}`.
77 | Scoring results will appear in `scoring_results/greedy_layercast`.
78 |
79 | *Example:*
80 | ```bash
81 | CUDA_VISIBLE_DEVICES=0,1 python eval_layercast.py --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
82 | --task math500 \
83 | --dtype float32 \
84 | --seed 42 \
85 | --batch_size 32 \
86 | --max_tokens 32768 \
87 | --exp_name 2a100_math500_layercast_bs32
88 | ```
89 |
90 |
91 | #### Run inference with random sampling (`n` independent samples per problem):
92 |
93 | ```bash
94 | [CUDA_VISIBLE_DEVICES] python eval_passk.py --model [MODEL] \
95 | --task [TASK] \
96 | --dtype [dtype] \
97 | --seed [RANDOM_SEED] \
98 | --batch_size [BS] \
99 | --max_tokens [MAX_TOKENS] \
100 | --passk [n] \
101 | --exp_name [NAME_OF_THE_RUN]
102 | ```
103 | Model responses will be saved to `outputs/vllm_passk/{exp_name}/{model}`.
104 | Scoring results will appear in `scoring_results/random_passk`.
105 |
106 | *Example:*
107 | ```bash
108 | CUDA_VISIBLE_DEVICES=0,1 python eval_passk.py --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
109 | --task math500 \
110 | --dtype bfloat16 \
111 | --seed 42 \
112 | --batch_size 32 \
113 | --max_tokens 32768 \
114 | --passk 4 \
115 | --exp_name 2a100_pass4_math500_bf16_bs32
116 | ```
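For reference, the standard unbiased pass@k estimator over `n` samples with `c` correct ones is sketched below; this is the generic formula, and the repository's own scoring code under `evals/scoring` may aggregate results differently:

```python
from math import comb


def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimate: n samples per problem, c of them correct."""
    if n - c < k:
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)


print(pass_at_k(n=4, c=1, k=1))  # 0.25
```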
117 |
118 | ## Citation
119 |
120 | If you find our work interesting or helpful, please kindly cite our paper.
121 |
122 | ```bibtex
123 | @misc{yuan2025fp32deathchallengessolutions,
124 | title={Give Me FP32 or Give Me Death? Challenges and Solutions for Reproducible Reasoning},
125 | author={Jiayi Yuan and Hao Li and Xinheng Ding and Wenya Xie and Yu-Jhe Li and Wentian Zhao and Kun Wan and Jing Shi and Xia Hu and Zirui Liu},
126 | year={2025},
127 | eprint={2506.09501},
128 | archivePrefix={arXiv},
129 | primaryClass={cs.CL},
130 | url={https://arxiv.org/abs/2506.09501},
131 | }
132 | ```
133 |
134 | ## References
135 | Our evaluation implementation is adapted from the [SkyThought](https://github.com/NovaSky-AI/SkyThought) repository.
136 |
--------------------------------------------------------------------------------
/eval_main.py:
--------------------------------------------------------------------------------
1 | import vllm
2 | import torch
3 | import logging
4 | import datasets
5 | from vllm import SamplingParams
6 | import os
7 | import argparse
8 | import json
9 | import glob
10 | from pathlib import Path
11 | from typing import Any, Dict, List, Tuple, Optional
12 | from tqdm import tqdm
13 | from evals.tasks import TASK_HANDLER_MAP, TASK_NAMES_TO_YAML, TaskConfig, TaskHandler
14 | from evals.util.results import SummaryResults, save_summary
15 |
16 | from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
17 | parse_chat_messages,)
18 | from vllm.utils import is_list_of
19 | from vllm.inputs import TextPrompt, TokensPrompt
20 | from prompt_util.prompt_template import make_conversation_from_contents
21 | from evals.tasks import TASK_HANDLER_MAP, TASK_NAMES_TO_YAML, TaskConfig
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 | # Add argument parser
26 | def parse_args():
27 | parser = argparse.ArgumentParser(description='Run model inference with configurable parameters')
28 | parser.add_argument('--model', type=str, default='deepseek-ai/DeepSeek-R1-Distill-Llama-8B',
29 | help='Model name or path')
30 | parser.add_argument('--task', type=str, default='math500',
31 | help='Task name')
32 | parser.add_argument('--dtype', type=str, default='bfloat16',
33 | help='Data type for model (e.g., bfloat16, float16)')
34 | parser.add_argument('--seed', type=int, default=42,
35 | help='Random seed')
36 | parser.add_argument('--batch_size', type=int, default=1,
37 | help='Batch size for inference')
38 | parser.add_argument('--max_tokens', type=int, default=32768,
39 | help='Maximum number of tokens to generate')
40 | parser.add_argument('--exp_name', type=str, default='baseline',
41 | help='Experiment name')
42 | return parser.parse_args()
43 |
44 |
45 |
46 | def set_seed(seed):
47 | torch.manual_seed(seed)
48 | torch.cuda.manual_seed(seed)
49 | torch.cuda.manual_seed_all(seed)
50 | torch.backends.cudnn.deterministic = True
51 | torch.backends.cudnn.benchmark = False
52 |
53 |
54 | # Determine the starting point based on existing .pt files
55 | def get_resume_point(output_path, batch_size):
56 | # Find all .pt files in the output directory
57 | pt_files = glob.glob(f'{output_path}/problem_*_token_ids_*.pt')
58 | if not pt_files:
59 | return 0 # No files exist, start from the beginning
60 |
61 | # Extract global_idx from filenames
62 | global_indices = []
63 | for pt_file in pt_files:
64 |         # Filename format: problem_{global_idx}_{task}_token_ids_bs_{batch_size}_{dtype}_max_tokens_{max_tokens}.pt
65 | parts = os.path.basename(pt_file).split('_')
66 | try:
67 | global_idx = int(parts[1]) # Extract the global_idx
68 | global_indices.append(global_idx)
69 | except (IndexError, ValueError):
70 | continue
71 |
72 | if not global_indices:
73 | return 0 # No valid indices found, start from the beginning
74 |
75 | # Find the largest global_idx and calculate the starting batch
76 | max_global_idx = max(global_indices)
77 | resume_point = ((max_global_idx + 1) // batch_size) * batch_size
78 | print(f"Resuming from batch starting at index {resume_point} (max_global_idx={max_global_idx})")
79 | return resume_point
80 |
81 | def score_responses(
82 | handler: TaskHandler,
83 | list_of_results: List[Dict[str, Any]],
84 | eval_data: List[Dict[str, Any]],
85 | ) -> Tuple[float, Dict[str, List[int]], int]:
86 |
87 | if not list_of_results:
88 | return 0.0, {}, 0
89 |
90 | total_correct = 0
91 | total_finish = 0
92 | id_to_scores = {}
93 |
94 | for result in tqdm(list_of_results, desc="Scoring responses"):
95 | # Get content from the result
96 | model_response = result['model_answer']
97 | problem_id = result['problem_id']
98 | problem = eval_data[problem_id]
99 |
100 | new_response_entry = handler.update_results(
101 | problem=problem,
102 | response=model_response,
103 | )
104 |
105 | if problem_id not in id_to_scores:
106 | id_to_scores[problem_id] = [0]
107 | id_to_scores[problem_id][0] = new_response_entry["correctness"]
108 |
109 | total_correct += new_response_entry["correctness"]
110 | total_finish += 1
111 |
112 | accuracy = round(total_correct / total_finish, 4) if total_finish else 0
113 | return accuracy, id_to_scores, total_finish
114 |
115 | if __name__ == '__main__':
116 | args = parse_args()
117 | set_seed(args.seed)
118 | # Create outputs directory if it doesn't exist
119 | output_path = f'./outputs/vllm_main/{args.exp_name}/{args.model}'
120 | os.makedirs(output_path, exist_ok=True)
121 |
122 | task_config = TaskConfig.from_yaml(TASK_NAMES_TO_YAML[args.task])
123 | handler_name = task_config.handler
124 | handler_cls = TASK_HANDLER_MAP[handler_name]
125 | handler = handler_cls(task_config)
126 | eval_data = handler.load_and_filter_dataset(0, -1) # start from 0, load all
127 | remaining_data = handler.process_remaining_data(eval_data, {})
128 | conversations = handler.make_conversations(
129 | remaining_data,
130 | None, # str(model_config.system_prompt),
131 | None, # model_config.user_template,
132 | None, # model_config.assistant_prefill,
133 | )
134 | total_samples = len(conversations)
135 | print(f"Total samples in the dataset: {total_samples}")
136 |
137 | # Get number of available GPUs
138 | num_gpus = torch.cuda.device_count()
139 | print(f"Using {num_gpus} GPUs for tensor parallelism")
140 |
141 | model = vllm.LLM(model=args.model,
142 | tensor_parallel_size=num_gpus,
143 | # max_model_len=length_used,
144 | dtype=args.dtype,
145 | enforce_eager=True)
146 | # Configure sampling parameters to return logits
147 | sampling_params = SamplingParams(temperature=0.0, logprobs=5, max_tokens=args.max_tokens, seed=args.seed)
148 |
149 | # Process in batches
150 | qa_pairs = []
151 | jsonl_path = f'{output_path}/qa_pairs_{args.dtype}_bs_{args.batch_size}.jsonl'
152 |
153 | start_point = get_resume_point(output_path, args.batch_size)
154 |
155 | for batch_start in range(start_point, total_samples, args.batch_size):
156 | batch_end = min(batch_start + args.batch_size, total_samples)
157 | current_batch = conversations[batch_start:batch_end]
158 | print(f"Processing batch {batch_start//args.batch_size + 1}/{(total_samples + args.batch_size - 1)//args.batch_size}")
159 |
160 | tokenizer = model.get_tokenizer()
161 | model_config = model.llm_engine.get_model_config()
162 | prompts = []
163 |
164 | for msgs in current_batch:
165 | # NOTE: _parse_chat_message_content_parts() currently doesn't
166 | # handle mm_processor_kwargs, since there is no implementation in
167 | # the chat message parsing for it.
168 | conversation, mm_data = parse_chat_messages(
169 | msgs,
170 | model_config,
171 | tokenizer,
172 | content_format='string',
173 | )
174 |
175 | prompt_data = apply_hf_chat_template(
176 | tokenizer,
177 | conversation=conversation,
178 | chat_template=None,
179 | add_generation_prompt=True,
180 | continue_final_message=False,
181 | tools=None,
182 | )
183 |
184 | if is_list_of(prompt_data, int):
185 | prompt = TokensPrompt(prompt_token_ids=prompt_data)
186 | else:
187 | prompt = TextPrompt(prompt=prompt_data)
188 |
189 | if mm_data is not None:
190 | prompt["multi_modal_data"] = mm_data
191 |
192 | prompts.append(prompt)
193 |
194 | # Generate with logits for current batch
195 | response = model.generate(prompts, sampling_params=sampling_params)
196 | # Extract output text and logits for each sample in the batch
197 | qa_pairs = []
198 | for idx, output in enumerate(response):
199 | global_idx = batch_start + idx
200 | generated_text = output.outputs[0].text
201 | token_logprobs = output.outputs[0].logprobs
202 | # Create tensors from token_logprobs
203 | num_tokens = len(token_logprobs)
204 | token_ids = torch.zeros((num_tokens, 5), dtype=torch.long)
205 | logprobs = torch.zeros((num_tokens, 5), dtype=torch.float32)
206 | # Save QA pair to JSONL file
207 | qa_pair = {
208 | "problem_id": global_idx,
209 | "question": current_batch[idx],
210 | "model_answer": generated_text,
211 | }
212 |
213 | for i, logprobs_dict in enumerate(token_logprobs):
214 | # Extract token IDs in order of rank
215 | sorted_items = sorted(logprobs_dict.items(), key=lambda x: x[1].rank)
216 | for j, (token_id, L) in enumerate(sorted_items):
217 | token_ids[i, j] = token_id
218 | logprobs[i, j] = L.logprob
219 |
220 | torch.save(token_ids, f'{output_path}/problem_{global_idx}_{args.task}_token_ids_bs_{args.batch_size}_{args.dtype}_max_tokens_{args.max_tokens}.pt')
221 | torch.save(logprobs, f'{output_path}/problem_{global_idx}_{args.task}_logprobs_bs_{args.batch_size}_{args.dtype}_max_tokens_{args.max_tokens}.pt')
222 | print(f"Saved tensors for problem {global_idx}")
223 |
224 | qa_pairs.append(qa_pair)
225 |
226 | with open(jsonl_path, 'a') as f:
227 | for qa_pair in qa_pairs:
228 | f.write(json.dumps(qa_pair) + '\n')
229 |         print(f"Saved QA pairs for batch {batch_start//args.batch_size + 1}")
230 |
231 | responses_path = Path(jsonl_path)
232 |
233 | if responses_path.stat().st_size == 0:
234 | raise ValueError(f"Response file is empty: {responses_path}")
235 |
236 | print(f"Valid response file: {responses_path}")
237 |
238 | # Read the .jsonl file line by line and parse each line as a JSON object
239 | with open(responses_path, "r") as f:
240 | list_of_results = [json.loads(line) for line in f]
241 |
242 | # Check if the response file is a list of dictionaries
243 | if not all(isinstance(result, dict) for result in list_of_results):
244 | raise ValueError(f"Response file does not contain valid dictionaries on each line: {responses_path}")
245 |
246 |     # Check that the parsed file is a list
247 | if not isinstance(list_of_results, list):
248 | raise ValueError(f"Response file is not a list of dictionaries: {responses_path}")
249 |
250 | # Obtain the correct task handler
251 | task = args.task
252 | if task not in TASK_NAMES_TO_YAML:
253 | raise ValueError(
254 | f"Task {task} not found. Should be one of {TASK_NAMES_TO_YAML.keys()}"
255 | )
256 | task_config = TaskConfig.from_yaml(TASK_NAMES_TO_YAML[task])
257 | handler_name = task_config.handler
258 | handler_cls = TASK_HANDLER_MAP[handler_name]
259 | handler = handler_cls(task_config)
260 |
261 | raw_dataset = handler.load_and_filter_dataset(0, -1) # start from 0, load all
262 | eval_data = [
263 | row.to_dict()
264 | for _, row in raw_dataset.iterrows()
265 | ]
266 |
267 | accuracy, id_to_scores, total_finish = score_responses(handler, list_of_results, eval_data)
268 | logger.info(f"Accuracy: {accuracy}")
269 |
270 | num_responses_total = len(id_to_scores)
271 |
272 | summary_data = SummaryResults(
273 | accuracy=accuracy,
274 | )
275 |
276 | # Create outputs directory if it doesn't exist
277 | acc_path = f'./scoring_results/greedy'
278 | os.makedirs(acc_path, exist_ok=True)
279 |
280 | sanitized_model_name = args.model.replace("/", "_")
281 | summary_file = Path(acc_path) / f"{sanitized_model_name}_{args.exp_name}_summary.jsonl"
282 | save_summary(summary_file, summary_data)
283 | logger.info(f"Summary saved to {summary_file}")
284 |
--------------------------------------------------------------------------------
/evals/README.md:
--------------------------------------------------------------------------------
1 | # Skythought-evals: Data Generation and Evaluation Tools
2 |
3 |
4 | ## Requirements
5 |
6 | Make sure you have installed the `skythought` package as outlined in the [README.md](/README.md#usage).
7 |
8 | For running OpenAI model, export the OpenAI key.
9 | ```shell
10 | export OPENAI_API_KEY={openai_api_key}
11 | ```
12 |
13 | ## Usage
14 |
15 | We provide three commands in the CLI:
16 |
17 | - `skythought evaluate` : Evaluate a model on a given task.
18 | - `skythought generate`: Generate model outputs for a pre-configured task.
19 | - `skythought score`: Score saved generations for a given task.
20 |
21 | For a walkthrough on the basics, please refer to the [example](../../examples/evaluate.ipynb).
22 |
23 | ## Generation and Evaluation
24 |
25 | ### Benchmark Evaluation
26 |
27 | Given below are two examples for evaluation.
28 |
29 | ```shell
30 | skythought evaluate --model NovaSky-AI/Sky-T1-32B-Preview --task aime --backend vllm --backend-args tensor_parallel_size=8 --sampling-params temperature=0.6,top_p=0.95 --n 8 --result-dir ./
31 |
32 | skythought evaluate --model NovaSky-AI/Sky-T1-32B-Preview --task gpqa_diamond --backend vllm --backend-args tensor_parallel_size=8 --sampling-params temperature=0.6,top_p=0.95 --n 8
33 | ```
34 |
35 | **Note**: The `GPQADiamond` dataset is gated and requires first receiving access at this Huggingface [link](https://huggingface.co/datasets/Idavidrein/gpqa) (which is granted immediately), then logging into your Huggingface account in your terminal session with `huggingface-cli login`.
36 |
37 |
38 | The results will be saved in a folder in `result-dir`:
39 |
40 | ```bash
41 | result-dir/
42 | ├── Qwen_QwQ-32B-Preview_aime_myHash
43 | │ ├── results.json # contains the full results for the benchmark
44 | │ └── summary.json # contains summary of the run with configuration and metrics
45 | ```
46 |
47 | ### Scaling evaluation with Ray
48 |
49 | You can scale evaluations across multiple model replicas (and across multiple nodes) using [ray](https://docs.ray.io) backend:
50 |
51 | ```shell
52 | skythought evaluate --model Qwen/QwQ-32B-Preview --task aime --backend ray --backend-args tensor_parallel_size=4,num_replicas=4 --result-dir ./
53 | ```
54 |
55 | By default, we make use of the configuration in [ray_configs/ray_config.yaml](./ray_configs/ray_config.yaml). You can also customize the following parameters for ray:
56 |
57 | - `tensor_parallel_size`: Tensor parallel size per replica. Defaults to 4.
58 | - `accelerator_type`: GPU accelerator type. See [the list of available types](https://docs.ray.io/en/latest/ray-core/accelerator-types.html) for more information. Defaults to None, which means any available GPUs in the Ray cluster will be used.
59 | - `num_replicas`: Number of model replicas to use for inference. Defaults to 2.
60 | - `batch_size`: Batch size per model replica for inference.
61 | - `gpu_memory_utilization`: Fraction of GPU memory allocated to the model executor in vLLM. Defaults to 0.9.
62 | - `dtype`: Data type used for inference. Defaults to "auto".
63 |
64 |
65 | ### Optimized settings for 32B and 7B models
66 |
67 | The following are optimized settings on an 8xH100 or an 8xA100 node. We recommend using the `ray` backend for best performance.
68 |
69 | For 32B models, we recommend using the default backend configuration for best performance.
70 |
71 | ```shell
72 | skythought evaluate --model Qwen/QwQ-32B-Preview --task aime24 --backend ray --result-dir ./
73 | ```
74 |
75 | For 7B models, we recommend using `tensor_parallel_size=1` and `num_replicas=8` for best performance. For example, the previous command will change to:
76 |
77 | ```shell
78 | skythought evaluate --model Qwen/Qwen2-7B-Instruct --task math500 --backend ray --backend-args tensor_parallel_size=1,num_replicas=8 --result-dir ./
79 | ```
80 |
81 | #### Multi-node inference
82 |
83 | Note that if you have a Ray cluster set up, you can scale the number of replicas as needed with the `num_replicas` argument in `backend-args` to make full use of your cluster. Make sure to execute the script on the head node and ensure that `--result-dir` is a valid directory that the head node can write to.
84 |
85 | ### Best-of-N Evaluation
86 |
87 | You can use the `--n` parameter to specify the number of generations per problem. For `n>1`, we calculate pass@k metrics.
88 |
89 | ```bash
90 | skythought evaluate --model Qwen/Qwen2-7B-Instruct --task math500 --backend ray --backend-args tensor_parallel_size=1,num_replicas=8 --sampling-params temperature=0.7,max_tokens=4096 --n 64 --result-dir ./
91 | ```
92 |
93 | ### Distill and Reject Sampling
94 | Currently we support distill and reject sampling for the NUMINA, APPS, and TACO datasets. For NUMINA, the source can be one of `[amc_aime, math, olympiads]`.
95 |
96 | #### Example Usage
97 |
98 | ```shell
99 | skythought generate --model Qwen/QwQ-32B-Preview --task numina_amc_aime --backend ray --backend-args tensor_parallel_size=8 --sampling-params max_tokens=16384 --result-dir $SKYT_HOME/data
100 | ```
101 |
102 | Once the generations are saved, you can then apply any postprocessing on the results (saved in a `results.json` file in a separate run folder) and then run:
103 |
104 | ```shell
105 | skythought score --task numina_amc_aime --run-dir
106 | ```
107 |
108 | ### Reproducibility Issues
109 |
110 |
111 | We've noticed that it can be hard to reproduce results in reasoning benchmarks. Beyond the current lack of agreed-upon sampling parameters and metrics in the field, there can be significant differences in results across different evaluation codebases, and even for the same codebase with a different set of dependencies. In half precision (bfloat16 or float16), numerical error accumulation changes outputs ever so slightly, which can dramatically alter final performance. There are three factors we've noticed that affect results:
112 |
113 | - Long context generations: Errors can accumulate so that the output changes at 1k+ tokens, and these differences compound as you keep generating. Since we typically set the max tokens to 16k or 32k, the final solution can change significantly.
114 | - vLLM settings: With vLLM, we’ve also noticed that at half-precision, different batch sizes can affect downstream evaluation results by a few percentage points. Further, different tensor parallelism settings can also change results in half-precision.
115 | - vLLM version: Different versions of vLLM will use different CUDA-Toolkit or Flash attention versions. Even for the same settings, these differences in the underlying kernels used can change results.
116 |
117 | We recommend running evaluation benchmarks at full precision, i.e., float32, to avoid this. In full precision, evaluation results should be robust to changes in batch size, tensor parallel size, version differences, etc.
118 |
119 |
120 | ## Key Concepts
121 |
122 | ### Tasks
123 |
124 | A Task consists of task-specific configuration and implements
125 | - Dataset loading and preprocessing
126 | - Creating the input conversation for the model
127 | - Scoring of model responses
128 |
129 | The configuration (`TaskConfig`) contains dataset loading details such as the Hugging Face dataset ID, the particular subset for this benchmark (e.g., the "Challenge" subset for ARC), and a task template, which contains task-specific instructions to be used (e.g., `Return your answer in \boxed{}`). Each configuration is stored in a YAML file. For example, see the [aime24.yaml file](./tasks/aime/aime24.yaml).
130 |
131 | Internally, a Task implementation is termed a "TaskHandler"; you can see one such implementation [here](./tasks/aime/aime_handler.py).
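For orientation, the snippet below mirrors how `eval_main.py` in this repository resolves a task name into its config and handler and builds the input conversations (the task name `math500` and the `None` prompt arguments are illustrative):

```python
from evals.tasks import TASK_HANDLER_MAP, TASK_NAMES_TO_YAML, TaskConfig

task_config = TaskConfig.from_yaml(TASK_NAMES_TO_YAML["math500"])
handler = TASK_HANDLER_MAP[task_config.handler](task_config)

eval_data = handler.load_and_filter_dataset(0, -1)  # start at 0, load everything
conversations = handler.make_conversations(
    handler.process_remaining_data(eval_data, {}),
    None,  # system prompt
    None,  # user template
    None,  # assistant prefill
)
```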
132 |
133 |
134 | To add a new task `mytask`:
135 | - First, see if the task can be simply specified as a configuration (One example is [`aime25`](./tasks/aime/aime25.yaml)). If so, you can add a YAML file in the appropriate folder and re-use an existing handler. (All available handlers are specified [here](./tasks/__init__.py)).
136 | - If not, you should create a new `TaskHandler` subclass for this task along with a task configuration YAML (`mytask.yaml`).
137 |
138 | ### Models
139 |
140 | A Model consists of the model ID and templating configuration. This configuration optionally contains the system prompt and an assistant prefill message. Different reasoning models use their own system prompt, and some perform best when the response is prefilled with special tokens.
141 |
142 | We store our pre-configured models as well as a list of system prompt templates [here](./models/model_configs.yaml).
143 |
144 | ### Backend
145 |
146 | The Backend is concerned with how the LLM instance is created and queried. For flexibility, we support
147 | - Local inference with vLLM (basic single node) or Ray+vLLM (more scalable single and multi-node inference)
148 | - Remote inference behind an OpenAI-compatible endpoint.
149 |
150 | The Backend also consists of configuration at instantiation (e.g., the data type for the model), along with sampling parameters during generation (temperature, max tokens, etc.).
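As a sketch of how a backend engine can be configured programmatically, the batch pipeline's `init_engine_from_config` (see `evals/batch/engines/initializer.py`) accepts a dict with `llm_engine`, `model_id`, `accelerator_type`, and `engine_kwargs`; the concrete values below are illustrative:

```python
from evals.batch import init_engine_from_config

engine_cfg = {
    "llm_engine": "vllm",
    "model_id": "Qwen/Qwen2-7B-Instruct",
    "accelerator_type": "A100",  # illustrative; any accelerator type available in your cluster
    "engine_kwargs": {"tensor_parallel_size": 1, "dtype": "float32"},
}
initializer = init_engine_from_config(engine_cfg)
print(initializer.num_gpus)  # tensor_parallel_size * pipeline_parallel_size
```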
151 |
152 | During evaluation, the above components tie together and the flow is as follows:
153 | 1. Load dataset and create conversations based on the Task and Model specified by the user
154 | 2. Generate model responses from the Backend based on the provided sampling parameters
155 | 3. Score model responses based on the Task
156 | 4. Output final results
157 |
--------------------------------------------------------------------------------
/evals/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nanomaoli/llm_reproducibility/8a373c5a159a27e59783394827cecadd6255484e/evals/__init__.py
--------------------------------------------------------------------------------
/evals/base_instruct_evals.md:
--------------------------------------------------------------------------------
1 | # Reproducing results on non-reasoning benchmarks
2 |
3 | For the full set of results, see [here](./README.md#results-on-qa-and-instruction-following-benchmarks).
4 |
5 | ## Installation instructions
6 |
7 | 1. For `lm_eval`, install the package by executing the following :
8 |
9 | ```bash
10 | git clone https://github.com/EleutherAI/lm-evaluation-harness
11 | cd lm-evaluation-harness
12 | git checkout 703fbff
13 | pip install -e ".[ifeval]"
14 | ```
15 |
16 | For more details, you can refer to the official instructions [here](https://github.com/EleutherAI/lm-evaluation-harness/tree/703fbffd6fe5e136bbb9d884cb40844e5503ae5d?tab=readme-ov-file#install). We report results with commit https://github.com/EleutherAI/lm-evaluation-harness/commit/703fbffd6fe5e136bbb9d884cb40844e5503ae5d
17 |
18 | 2. For `fastchat`, follow the instructions [here](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge#install). The current implementation of Fastchat is based on OpenAI version <= 0.28.0. For making use of the latest vllm backend, it is recommended to migrate the `llm_judge` folder to use openai>=1.0.0. You can run `openai migrate` for the fastchat codebase or follow the PR [here](https://github.com/lm-sys/FastChat/pull/2915/files)
19 | 3. For `BFCL`, you can follow the official instructions [here](https://github.com/ShishirPatil/gorilla/tree/main/berkeley-function-call-leaderboard#basic-installation). We further evaluate on all test categories, which requires [setting up environment variables](https://github.com/ShishirPatil/gorilla/tree/main/berkeley-function-call-leaderboard#setting-up-environment-variables), and [obtaining API keys for executable test categories](https://github.com/ShishirPatil/gorilla/tree/main/berkeley-function-call-leaderboard#api-keys-for-executable-test-categories). Make sure to use changes from [this PR](https://github.com/ShishirPatil/gorilla/pull/888) for QwQ and Sky-T1 model support.
20 | 4. For `Arena-Hard` results, you can follow the instructions [here](https://github.com/lmarena/arena-hard-auto). We use `gpt-4-1106-preview` as the judge.
21 |
22 | ## Commands for reproducing results
23 |
24 | All the benchmarks were run on an 8xH100 machine with the `vllm` backend. If you're running on a different device, make sure to tweak `tensor_parallel_size` and, if needed, the `batch_size` arguments. Expect some variance in scores (+/- 1%) across different evaluation settings (e.g., `tensor_parallel_size`).
25 |
26 | All the commands below are given for `NovaSky-AI/Sky-T1-32B-Preview`. Simply substitute the model name to run `Qwen/Qwen2.5-32B-Instruct`. For `Qwen/QwQ-32B-Preview`, we further make use of the two arguments `revision=refs/pr/58,tokenizer_revision=refs/pr/58` to use a corrected revision of QwQ. For more details on this, see https://github.com/NovaSky-AI/SkyThought/pull/26#issuecomment-2606435601.
27 |
28 | ### MMLU (0 shot; no CoT)
29 |
30 | ```bash
31 | lm_eval --model vllm --model_args pretrained=NovaSky-AI/Sky-T1-32B-Preview,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1,max_model_len=2048 --tasks mmlu --trust_remote_code --batch_size 8 --apply_chat_template --fewshot_as_multiturn
32 | ```
33 |
34 | For QwQ, you would do
35 |
36 | ```bash
37 | lm_eval --model vllm --model_args pretrained=Qwen/QwQ-32B-Preview,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1,max_model_len=2048,revision=refs/pr/58,tokenizer_revision=refs/pr/58 --tasks mmlu --trust_remote_code --batch_size 8 --apply_chat_template --fewshot_as_multiturn
38 | ```
39 |
40 | ### MMLU (5 shot; no CoT)
41 |
42 | ```bash
43 | lm_eval --model vllm --model_args pretrained=NovaSky-AI/Sky-T1-32B-Preview,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1,max_model_len=2048 --tasks mmlu --trust_remote_code --batch_size 8 --apply_chat_template --fewshot_as_multiturn --num_fewshot 5
44 | ```
45 |
46 | ### ARC-C (0 shot; no CoT)
47 |
48 | ```bash
49 | lm_eval --model vllm --model_args pretrained=NovaSky-AI/Sky-T1-32B-Preview,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1,max_model_len=2048 --tasks arc_challenge --trust_remote_code --batch_size 8 --apply_chat_template --fewshot_as_multiturn
50 | ```
51 |
52 | ### IFEval
53 |
54 | ```bash
55 | lm_eval --model vllm --model_args pretrained=NovaSky-AI/Sky-T1-32B-Preview,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.9,data_parallel_size=1 --tasks leaderboard_ifeval --trust_remote_code --batch_size auto --apply_chat_template --fewshot_as_multiturn
56 | ```
57 |
58 | We use the `prompt_level_strict_acc` metric following Qwen-2.5.
59 |
60 | ### MGSM (native CoT)
61 |
62 | ```bash
63 | lm_eval --model vllm --model_args pretrained=NovaSky-AI/Sky-T1-32B-Preview,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1,max_model_len=2048 --tasks mgsm_direct --trust_remote_code --batch_size 8 --apply_chat_template --fewshot_as_multiturn
64 | ```
65 |
66 | We report the average value of `flexible-extract` filter.
67 |
68 | ### MGSM (8-shot; native CoT)
69 |
70 | ```bash
71 | lm_eval --model vllm --model_args pretrained=NovaSky-AI/Sky-T1-32B-Preview,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1,max_model_len=2048 --tasks mgsm_direct --trust_remote_code --batch_size 8 --apply_chat_template --fewshot_as_multiturn --num_fewshot 8
72 | ```
73 |
74 | ### LLM-as-a-Judge
75 |
76 | We use the default settings - with `max_tokens` 1024 and the `gpt-4` judge. We observe that some reasoning models like `Qwen/QwQ-32B-Preview` are sometimes unable to provide brief responses and thus get truncated at the configured `max_tokens`. While this will affect the final rating, given the context length limitations of the commonly used `gpt-4` judge (8K tokens), we stick to the 1024 `max_tokens` budget for consistency.
77 |
78 | 1. First, serve the model with vLLM
79 |
80 |
81 | ```bash
82 | vllm serve NovaSky-AI/Sky-T1-32B-Preview --dtype auto --tensor-parallel-size 8 --gpu-memory-utilization 0.9
83 | ```
84 |
85 | For `Qwen/QwQ-32B-Preview`, use
86 |
87 | ```bash
88 | vllm serve Qwen/QwQ-32B-Preview --dtype auto --tensor-parallel-size 8 --gpu-memory-utilization 0.9 --revision refs/pr/58 --tokenizer-revision refs/pr/58
89 | ```
90 |
91 | 2. Next, generate model response
92 |
93 | ```bash
94 | python gen_api_answer.py --model NovaSky-AI/Sky-T1-32B-Preview --openai-api-base http://localhost:8000/v1 --parallel 50
95 | ```
96 |
97 | Note: The generated results will be in `data/model_answer//.jsonl` . Move them to the root folder `data/model_answer/`
98 |
99 | 3. After generating responses for all the models, evaluate with the default settings
100 |
101 | ```bash
102 | export OPENAI_API_KEY=XXXXXX # set the OpenAI API key
103 | python gen_judgment.py --model-list Sky-T1-32B-Preview QwQ-32B-Preview Qwen2.5-32B-Instruct --parallel 2
104 | ```
105 | 4. Get MTBench scores (we use the average score of both turns)
106 |
107 | ```bash
108 | python show_result.py
109 | ```
110 |
111 | ### BFCL-v3
112 |
113 | Our results are reported on `test-category` `all` . Make sure to get the API keys for the executable test categories by following the instructions [here](https://github.com/ShishirPatil/gorilla/tree/main/berkeley-function-call-leaderboard#api-keys-for-executable-test-categories)
114 |
115 | Run
116 |
117 | ```bash
118 | bfcl generate --model NovaSky-AI/Sky-T1-32B-Preview --test-category all --backend vllm --num-gpus 8 --gpu-memory-utilization 0.9
119 | ```
120 |
121 | For evaluation, you can simply run
122 |
123 | ```bash
124 | bfcl evaluate --model Qwen/QwQ-32B-Preview,NovaSky-AI/Sky-T1-32B-Preview,Qwen/Qwen2.5-32B-Instruct --test-category all --api-sanity-check
125 | ```
126 | ### Arena Hard
127 | For `Arena-Hard`, we use the following script to start a `TGI` service for generating answers
128 | ```bash
129 | hf_pat=
130 | model=NovaSky-AI/Sky-T1-32B-Preview
131 | volume=/mnt/local_storage/data/cache
132 | port=1996
133 |
134 | huggingface-cli download $model
135 | sudo docker run --gpus 8 -e HUGGING_FACE_HUB_TOKEN=$hf_pat --shm-size 2000g -p $port:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.2.0 --model-id $model --max-input-length 8192 --max-batch-total-tokens 8193 --max-batch-prefill-tokens 8193 --max-total-tokens 8193 --sharded true
136 | ```
137 | For running the `gen_answer.py` script, we use the following `config_api` yaml setting. For `qwq-32b-preview`, we explicitly specify the system prompt as `You are a helpful and harmless assistant. You are Qwen developed by Alibaba.` to avoid the CoT prompt.
138 | ```yaml
139 | ...
140 | sky-T1-32B-Preview:
141 | model_name: sky-T1-32B-Preview
142 | endpoints:
143 | - api_base: http://localhost:1996/v1
144 | api_key: empty
145 | api_type: openai
146 | parallel: 8
147 | ...
148 | ```
149 | and finally for `gen_judgment.py`, we use `gpt-4-1106-preview` as the judge.
150 |
151 | #### Supplementary results for Arena-Hard
152 |
153 | Here are some supplementary results for Arena-Hard, compared with o1-mini which is the best performing model on this benchmark (as of Jan 2025).
154 |
155 | | model | score | rating_q025 | rating_q975 | CI | avg_tokens | date |
156 | |-------|--------|------------|-------------|-------|------------|-------|
157 | | o1-mini-2024-09-12 | 91.98 | 90.88 | 93.12 | (-1.10, +1.14) | 1399.0 | 2025-01-18 |
158 | | sky-T1-32B-Preview | 74.79 | 72.28 | 76.8 | (-2.51, +2.01) | 847.0 | 2025-01-18 |
159 | | qwen2.5-32b-instruct | 66.51 | 64.55 | 68.4 | (-1.96, +1.89) | 611.0 | 2025-01-18 |
160 | | qwq-32b-preview | 52.6 | 50.86 | 54.91 | (-1.74, +2.31) | 1005.0 | 2025-01-23 |
161 |
162 | For more details, see: https://github.com/NovaSky-AI/SkyThought/pull/26#issuecomment-2599525551
163 |
--------------------------------------------------------------------------------
/evals/batch/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = []
2 |
3 | from .engines import init_engine_from_config
4 | from .pipeline import Pipeline
5 | from .workload import (
6 | EvalWorkload,
7 | )
8 |
9 | __all__ = [
10 | "Pipeline",
11 | "init_engine_from_config",
12 | "EvalWorkload",
13 | ]
14 |
--------------------------------------------------------------------------------
/evals/batch/engines/__init__.py:
--------------------------------------------------------------------------------
1 | """LLM Engines."""
2 |
3 | __all__ = []
4 |
5 | from .initializer import EngineInitializerBase, init_engine_from_config
6 |
7 | __all__ = [
8 | "EngineInitializerBase",
9 | "init_engine_from_config",
10 | ]
11 |
--------------------------------------------------------------------------------
/evals/batch/engines/base.py:
--------------------------------------------------------------------------------
1 | """Engine base."""
2 |
3 | from typing import Any, AsyncGenerator, Dict
4 |
5 | import numpy as np
6 |
7 |
8 | class EngineBase:
9 | """Base class for engines."""
10 |
11 | async def __call__(
12 | self, batch: Dict[str, np.ndarray]
13 | ) -> AsyncGenerator[Dict[str, Any], None]:
14 | """Call the LLM engine asynchronously to process a Ray Data batch.
15 |
16 | Args:
17 | batch: The batch.
18 |
19 | Yields:
20 | The output.
21 | """
22 | raise NotImplementedError
23 |
--------------------------------------------------------------------------------
/evals/batch/engines/initializer.py:
--------------------------------------------------------------------------------
1 | """Engine initializers.
2 | Note that this file should not import any engine-dependent modules, such as
3 | vLLM, because the engine initializer is used in the driver node which may
4 | not have GPUs.
5 | """
6 |
7 | import os
8 | from pathlib import Path
9 | from typing import Any, Dict, Optional, Union
10 |
11 | import yaml
12 |
13 | from ..utils import (
14 | download_model_from_hf,
15 | update_dict_recursive,
16 | )
17 | from ..workload import EvalWorkload
18 | from .base import EngineBase
19 |
20 |
21 | class EngineInitializerBase:
22 | """Base class for engine initializer.
23 |
24 | Args:
25 | model_id: The model id.
26 | accelerator_type: The accelerator type.
27 | engine_kwargs: The engine specific configurations.
28 |         ray_env_vars: The Ray runtime environment variables.
29 | """
30 |
31 | use_ray_placement_group: bool = False
32 |
33 | def __init__(
34 | self,
35 | model_id: str,
36 | accelerator_type: str,
37 | engine_kwargs: Dict[str, Any],
38 | lora_adapter: Optional[str] = None,
39 | ray_env_vars: Dict[str, Any] = None,
40 | ):
41 | self._model = model_id
42 | self._accelerator_type = accelerator_type
43 | self._ray_env_vars = ray_env_vars or {}
44 | self.lora_adapter = lora_adapter
45 | self.engine_kwargs = engine_kwargs
46 |
47 | @property
48 | def model(self) -> str:
49 | return self._model
50 |
51 | @property
52 | def accelerator_type(self) -> str:
53 | return self._accelerator_type
54 |
55 | @property
56 | def ray_env_vars(self) -> Dict[str, str]:
57 | return self._ray_env_vars
58 |
59 | @property
60 | def num_gpus(self) -> int:
61 | """The number of GPUs used per engine."""
62 | raise NotImplementedError
63 |
64 | @property
65 | def max_model_len(self) -> Optional[int]:
66 | """The maximum model length set by the engine."""
67 | return None
68 |
69 | def get_engine_cls(self) -> EngineBase:
70 | """Get the engine class.
71 |
72 | Returns:
73 | The engine class.
74 | """
75 | raise NotImplementedError
76 |
77 | def get_engine_constructor_args(self, workload: EvalWorkload) -> Dict[str, Any]:
78 | """Get the engine constructor arguments.
79 |
80 | Args:
81 | workload: The workload that the engine will process.
82 |
83 | Returns:
84 | The engine constructor keyword arguments.
85 | """
86 | raise NotImplementedError
87 |
88 |
89 | class vLLMEngineInitializer(EngineInitializerBase):
90 | use_ray_placement_group: bool = False
91 |
92 | def __init__(
93 | self,
94 | model_id: str,
95 | accelerator_type: str,
96 | engine_kwargs: Dict[str, Any],
97 | lora_adapter: Optional[str] = None,
98 | ray_env_vars: Dict[str, Any] = None,
99 | ):
100 | super().__init__(
101 | model_id, accelerator_type, engine_kwargs, lora_adapter, ray_env_vars
102 | )
103 |
104 | # Override vLLM default configs. Note that this is only effective
105 | # when the config is not set by users.
106 | self.engine_kwargs.setdefault("gpu_memory_utilization", 0.95)
107 | self.engine_kwargs.setdefault("use_v2_block_manager", True)
108 | self.engine_kwargs.setdefault("enable_prefix_caching", False)
109 | self.engine_kwargs.setdefault("enforce_eager", False)
110 | self.engine_kwargs.setdefault("pipeline_parallel_size", 1)
111 | self.engine_kwargs.setdefault("max_num_seqs", 256)
112 | self.engine_kwargs.setdefault("tensor_parallel_size", 1)
113 | self.engine_kwargs.setdefault("max_logprobs", 0)
114 | self.engine_kwargs.setdefault("distributed_executor_backend", "mp")
115 |
116 | # Set engine environment variables.
117 | self._ray_env_vars.setdefault("VLLM_ATTENTION_BACKEND", "FLASH_ATTN")
118 | self._ray_env_vars.setdefault("ENABLE_ANYSCALE_PREFIX_OPTIMIZATIONS", "0")
119 | # FIXME: This should already be deprecated and can be removed.
120 | self._ray_env_vars.setdefault("VLLM_DISABLE_LOGPROBS", "1")
121 | for key, value in self._ray_env_vars.items():
122 | os.environ[key] = str(value)
123 |
124 | def get_engine_cls(self):
125 | from .vllm_engine import AsyncLLMPredictor
126 |
127 | return AsyncLLMPredictor
128 |
129 | @property
130 | def num_gpus(self) -> int:
131 | assert "tensor_parallel_size" in self.engine_kwargs
132 | assert "pipeline_parallel_size" in self.engine_kwargs
133 | tp_size = self.engine_kwargs["tensor_parallel_size"]
134 | pp_size = self.engine_kwargs["pipeline_parallel_size"]
135 | return tp_size * pp_size
136 |
137 | @property
138 | def max_model_len(self) -> Optional[int]:
139 | """The maximum model length set by the engine."""
140 | return self.engine_kwargs.get("max_model_len", None)
141 |
142 | def get_engine_constructor_args(self, workload: EvalWorkload):
143 | from vllm import PoolingParams, SamplingParams
144 | from vllm.config import PoolerConfig
145 |
146 | constructor_kwargs = {
147 | "model": self.model,
148 | "lora_adapter": self.lora_adapter,
149 | }
150 |
151 | if sampling_params := workload.sampling_params:
152 | # Sampling params is given: Auto-regressive generation.
153 | # In this case, we need to set max_tokens and max_model_len.
154 |
155 | max_tokens = sampling_params.get("max_tokens", None)
156 | if max_tokens is None:
157 | raise ValueError("max_tokens is required for vLLM engine.")
158 |
159 | vllm_sampling_params = SamplingParams(**workload.sampling_params)
160 | vllm_sampling_params.max_tokens = max_tokens
161 | vllm_sampling_params.detokenize = False
162 | constructor_kwargs["params"] = vllm_sampling_params
163 |
164 | if (
165 | "max_model_len" not in self.engine_kwargs
166 | and workload.max_tokens_in_prompt < 0
167 | ):
168 | raise ValueError(
169 | "Neither max_tokens_in_prompt nor max_model_len is set. If you "
170 | "intend to let the pipeline infer max_tokens_in_prompt but got this error, "
171 | "it is either because the workload has not been tokenized, or the "
172 | "workload bypass the tokenizer but does not set max_tokens_in_prompt by itself."
173 | )
174 |
175 | # Use max_tokens_in_prompt + max_tokens as the max_model_len. max_tokens_in_prompt
176 | # is either inferred by materializing tokenized dataset, set by the workload, or
177 | # set by the engine.
178 | self.engine_kwargs["max_model_len"] = (
179 | workload.max_tokens_in_prompt + max_tokens
180 | )
181 | else:
182 | # Sampling params is not given: Embedding workload.
183 | # In this case, we need to set pooling_params and task.
184 |
185 | if workload.pooling_params is None:
186 | raise ValueError(
187 | "pooling_params is required for vLLM engine for embedding workload."
188 | )
189 | constructor_kwargs["params"] = PoolingParams(**workload.pooling_params)
190 | constructor_kwargs["task"] = "embed"
191 |
192 | # Construct PoolerConfig if override_pooler_config is specified.
193 | if pooler_config := self.engine_kwargs.get("override_pooler_config", None):
194 | self.engine_kwargs["override_pooler_config"] = PoolerConfig(
195 | **pooler_config
196 | )
197 |
198 | constructor_kwargs.update(self.engine_kwargs)
199 | return constructor_kwargs
200 |
201 |
202 | def init_engine_from_config(
203 | config: Union[Dict[str, Any], str], override: Optional[Dict[str, Any]] = None
204 | ) -> EngineInitializerBase:
205 | """Initialize an engine initializer from a config file or a config dict.
206 |
207 | Args:
208 | config: A config file (in YAML) or a config dict. It should include
209 | the following keys: "engine", backend engine to use; "model",
210 | model to use; "accelerator_type", the GPU type; "configs",
211 | the engine specific configurations.
212 | override: Override values in config["configs"].
213 |
214 | Returns:
215 | An engine initializer.
216 | """
217 | if isinstance(config, str):
218 | config_path = Path(config)
219 | if not config_path.exists():
220 | raise FileNotFoundError(f"Engine config file {config} not found.")
221 | with open(config_path, "r") as filep:
222 | config = yaml.safe_load(filep)
223 |
224 | assert isinstance(config, dict)
225 |
226 | # Override configs
227 | if override is not None:
228 | update_dict_recursive(config, override)
229 |
230 | # Ray runtime environments.
231 | runtime_env: Dict[str, Any] = config.get("runtime_env", {})
232 | ray_env_vars: Dict[str, Any] = runtime_env.get("env_vars", {})
233 |
234 | # Download model and save to local path in advance, in case
235 |     # too many workers download the model in parallel and hit the Hugging Face rate limit.
236 | assert "model_id" in config and isinstance(config["model_id"], str)
237 | if ray_env_vars.pop("PREDOWNLOAD_MODEL_FROM_HF", "0") == "1":
238 | config["model_id"] = download_model_from_hf(
239 | config["model_id"], "/mnt/cluster_storage"
240 | )
241 |
242 | # Do not download LoRA adapter here because it is not used in the driver node.
243 | lora_adapter = None
244 | if "lora_config" in config:
245 | lora_adapter = config["lora_config"].get("dynamic_lora_loading_path", None)
246 |
247 | # Sanity check for engine kwargs.
248 | for key in ("llm_engine", "model_id", "accelerator_type"):
249 | if key not in config:
250 | raise KeyError(f"Required {key} not found in config.")
251 | if "engine_kwargs" not in config:
252 | config["engine_kwargs"] = {}
253 |
254 | name = config["llm_engine"]
255 | if name == "vllm":
256 | return vLLMEngineInitializer(
257 | model_id=config["model_id"],
258 | accelerator_type=config["accelerator_type"],
259 | engine_kwargs=config["engine_kwargs"],
260 | lora_adapter=lora_adapter,
261 | ray_env_vars=ray_env_vars,
262 | )
263 |
264 | raise ValueError(f"Unknown engine: {name}")
265 |
--------------------------------------------------------------------------------
/evals/batch/env_config.py:
--------------------------------------------------------------------------------
1 | """Environment configurations for Ray."""
2 |
3 | from dataclasses import dataclass
4 | from typing import Dict, Optional
5 |
6 | from .logging import get_logger
7 |
8 | logger = get_logger(__name__)
9 |
10 |
11 | @dataclass
12 | class EnvConfig:
13 | """Environment configurations for Ray."""
14 |
15 | # General configurations.
16 | hf_token: Optional[str] = None
17 | ray_override_job_runtime_env: str = "1"
18 |
19 | # Ray Data configurations.
20 | ray_data_default_wait_for_min_actors_s: int = 600
21 |
22 | # The number of LLM engine replicas to use.
23 | num_replicas: int = 1
24 | # The batch size. This represents the unit of fault tolerance.
25 | # Smaller batch size implies more fault tolerance but may
26 | # introduce more overhead. Batch size should at least be 16 to
27 | # avoid hanging.
28 | batch_size: int = 256
29 |
30 | def gen_ray_runtime_envs(self, engine_envs: Dict[str, str]) -> Dict[str, str]:
31 | """Generate Ray runtime environment variables."""
32 | envs = {k.upper(): str(v) for k, v in engine_envs.items()}
33 |
34 | for key in (
35 | "hf_token",
36 | "ray_override_job_runtime_env",
37 | "ray_data_default_wait_for_min_actors_s",
38 | ):
39 | if getattr(self, key) is not None:
40 | envs[key.upper()] = str(getattr(self, key))
41 | return envs
42 |
--------------------------------------------------------------------------------
/evals/batch/logging/__init__.py:
--------------------------------------------------------------------------------
1 | """Logging."""
2 |
3 | import logging
4 | from typing import Optional
5 |
6 | from ray._private.ray_logging.filters import CoreContextFilter
7 | from ray._private.ray_logging.formatters import JSONFormatter
8 |
9 |
10 | def _add_ray_logging(handler: logging.Handler):
11 | """Add Ray logging to the handler.
12 |
13 | This is not used for now and will be enabled after the Ray Job is supported.
14 |
15 | Args:
16 | handler: The handler to add Ray logging to.
17 | """
18 | handler.addFilter(CoreContextFilter())
19 | handler.setFormatter(JSONFormatter())
20 |
21 |
22 | def _setup_logger(logger_name: str):
23 | """Setup logger given the logger name.
24 |
25 | This function is idempotent and won't set up the same logger multiple times.
26 |
27 | Args:
28 | logger_name: The name of the logger.
29 | """
30 | logger = logging.getLogger(logger_name)
31 |
32 | # Skip setup if the logger already has handlers setup.
33 | if logger.handlers:
34 | return
35 |
36 | handler = logging.StreamHandler()
37 | logger.addHandler(handler)
38 | logger.setLevel(logging.INFO)
39 | logger.propagate = False
40 |
41 |
42 | def get_logger(name: Optional[str] = None) -> logging.Logger:
43 | """Get a structured logger.
44 |
45 | Loggers by default are logging to stdout, and are expected to be scraped by an
46 | external process.
47 |
48 | Args:
49 | name: The name of the logger.
50 |
51 | Returns:
52 | A logger instance.
53 | """
54 | _setup_logger(name)
55 | return logging.getLogger(name)
56 |
--------------------------------------------------------------------------------
/evals/batch/pipeline.py:
--------------------------------------------------------------------------------
1 | """Pipeline for batch processing large-scale LLM workloads."""
2 |
3 | import os
4 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
5 |
6 | import ray
7 | from ray.data._internal.stats import DatasetStats
8 | from ray.data.dataset import Dataset
9 | from ray.util import remove_placement_group
10 | from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
11 |
12 | from .engines import EngineInitializerBase, init_engine_from_config
13 | from .env_config import EnvConfig
14 | from .logging import get_logger
15 | from .tokenizer import Detokenizer
16 | from .workload import EvalWorkload
17 |
18 | if TYPE_CHECKING:
19 | from ray.util.placement_group import PlacementGroup
20 |
21 | logger = get_logger(__name__)
22 |
23 |
24 | class Pipeline:
25 | """Pipeline for batch processing large-scale LLM workloads.
26 |
27 | Args:
28 | engine_initializer: An engine initializer to create and initialize an engine.
29 | workload: Workload instance.
30 | env_config: EnvConfig to provide environment configurations of Ray.
31 | """
32 |
33 | def __init__(
34 | self,
35 | engine_initializer: EngineInitializerBase,
36 | env_config: EnvConfig,
37 | ):
38 | self.engine_initializer = engine_initializer
39 | self.env_config = env_config
40 | self.num_replicas: int = self.env_config.num_replicas
41 | self.ds: Optional[Dataset] = None
42 | self.stats: Optional[DatasetStats] = None
43 |
44 | self.pgs: List["PlacementGroup"] = []
45 |
46 | if not ray.is_initialized():
47 | ray.init(runtime_env={"env_vars": self.env_vars})
48 |
49 | @classmethod
50 | def from_config(
51 | cls, engine_cfg: Union[Dict[str, Any], str], workload: EvalWorkload, **kwargs
52 | ):
53 | """Initialize the pipeline from a configuration file or dictionary.
54 |
55 | Args:
56 | engine_cfg: A config file (in YAML) or a config dict. It should include
57 | the following keys: "engine", backend engine to use; "model",
58 | model to use; "accelerator_type", the GPU type; "configs",
59 | the engine specific configurations.
60 | workload: Workload instance.
61 | **kwargs: environment configuration parameters. See `EnvConfig` for more details.
62 | """
63 | engine_initializer = init_engine_from_config(engine_cfg)
64 | env_config = EnvConfig(**kwargs)
65 | return cls(engine_initializer, workload, env_config)
66 |
67 | @property
68 | def env_vars(self) -> Dict[str, Any]:
69 | return self.env_config.gen_ray_runtime_envs(
70 | self.engine_initializer.ray_env_vars
71 | )
72 |
73 | def load(
74 | self,
75 | repartition_by_batch_size: bool = False,
76 | ) -> Dataset:
77 | """Use the given workload to load and process the dataset,
78 | and then tokenize the prompts if needed. The processed dataset
79 | will be repartitioned based on the number of replicas and batch size.
80 |
81 | Args:
82 | repartition_by_batch_size: Whether to repartition the dataset by the
83 | batch size for fault tolerance granularity. You should enable
84 | this when the dataset is not from parquet and checkpointing is
85 | disabled.
86 |
87 | Returns:
88 | The processed dataset.
89 | """
90 | ds, num_blocks = self.workload.get_preprocessed_dataset(
91 | self.env_config.batch_size,
92 | repartition_by_batch_size,
93 | )
94 | if num_blocks is not None and num_blocks < self.num_replicas:
95 | logger.warning(
96 | "The number of blocks (%d) is less than the number of replicas (%d). "
97 | "This may result in suboptimal performance.",
98 | num_blocks,
99 | self.num_replicas,
100 | )
101 |
102 | if self.workload.need_tokenize:
103 | # TODO: Figure out a better concurrency.
104 | # Now we simply assume each LLM replica could have 4 tokenizers.
105 | # This is a heuristic and may not be optimal.
106 | tokenizer_concurrency = self.num_replicas * 4
107 | ds = ds.map_batches(
108 | self.workload.tokenizer_cls,
109 | fn_constructor_kwargs=self.workload.tokenizer_constructor_kwargs(
110 | self.engine_initializer.model
111 | ),
112 | zero_copy_batch=True,
113 | concurrency=(1, tokenizer_concurrency),
114 | batch_size=self.env_config.batch_size,
115 | )
116 |
117 | # If max tokens in prompt is not set in the workload and max_model_len is not set
118 | # in the engine, we need to materialize the dataset to get the maximum tokens in prompt.
119 | # This may hurt the overall throughput but may be memory efficient.
120 | if self.workload.max_tokens_in_prompt == -1:
121 | if self.engine_initializer.max_model_len is not None:
122 | max_tokens = self.workload.sampling_params.get("max_tokens", 0)
123 | max_tokens_in_prompt = (
124 | self.engine_initializer.max_model_len - max_tokens
125 | )
126 | msg = f"Max Prompt Tokens (max_model_len - max_tokens): {max_tokens_in_prompt}"
127 | else:
128 | logger.info(
129 | "Materializing dataset after tokenization to get max prompt tokens"
130 | )
131 | ds = ds.materialize()
132 |
133 | max_tokens_in_prompt = int(ds.max("num_text_tokens"))
134 | msg = f"Max Prompt Tokens (inferred): {max_tokens_in_prompt}"
135 | self.workload.max_tokens_in_prompt = max_tokens_in_prompt
136 | else:
137 |             msg = f"Max Prompt Tokens (specified in workload): {self.workload.max_tokens_in_prompt}"
138 |
139 | logger.info(msg)
140 | self.ds = ds
141 | return ds
142 |
143 | def __call__(self, workload: EvalWorkload):
144 | self.workload: EvalWorkload = workload
145 | # Set the task to "embed" if sampling params are not given.
146 | self.task_type_str: str = (
147 | "auto" if self.workload.sampling_params is not None else "embed"
148 | )
149 | return self.run(eager=False)
150 |
151 | def run(
152 | self,
153 | dataset: Optional[Dataset] = None,
154 | output_path: Optional[str] = None,
155 | detokenize: bool = True,
156 | eager: bool = True,
157 | repartition_by_batch_size: bool = False,
158 | ) -> Optional[Dataset]:
159 | """Perform batch processing on the dataset with LLM engines.
160 |
161 | Args:
162 | dataset: The dataset to process. If None, we directly use the given workload
163 | to load and process the dataset.
164 | output_path: The output path to write the processed dataset to parquet. It can be
165 | a path to a S3 bucket, or a path to local disk (with local:// as the prefix). If None,
166 | the processed dataset will be materialized but not be written.
167 | detokenize: Whether to detokenize the generated text. Default is True.
168 | eager: Whether to run the pipeline eagerly. If True, the dataset will be materialized.
169 | If False, we skip the materialization step and return the dataset. If output_path is specified,
170 | the dataset will be written to files and therefore will be materialized
171 | regardless of the eager flag.
172 | repartition_by_batch_size: Whether to repartition the dataset by the
173 | batch size for fault tolerance granularity. You should enable
174 | this when the dataset is not from parquet and checkpointing is
175 | disabled.
176 |
177 | Returns:
178 | The processed dataset. If output_path is not None, the dataset will be None after writing.
179 | """
180 | if not eager and output_path is not None:
181 | logger.warning("Eager mode is enforced because output path is specified")
182 | eager = True
183 |
184 |         # Expand the user home directory ("~") in output_path if present.
185 | if output_path is not None:
186 | output_path = os.path.expanduser(output_path)
187 |
188 | # Force skipping detokenizer if task is "embed".
189 | if self.task_type_str == "embed" and detokenize:
190 | logger.info("Detokenization is skipped because of embedding workload")
191 | detokenize = False
192 |
193 | ray_remote_args = {}
194 | if self.engine_initializer.accelerator_type:
195 | ray_remote_args["accelerator_type"] = (
196 | self.engine_initializer.accelerator_type
197 | )
198 | ray_remote_args.update({"runtime_env": {"env_vars": self.env_vars}})
199 |
200 | if dataset is not None:
201 | self.ds = dataset
202 | elif self.ds is None:
203 | self.load(repartition_by_batch_size)
204 | assert self.ds is not None
205 |
206 | num_gpus = self.engine_initializer.num_gpus
207 | if self.engine_initializer.use_ray_placement_group:
208 | # Specify the number of GPUs required per LLM instance.
209 | # Note: for TP>1, num_gpus has to be 0 - instead, we specify a placement group
210 |             if self.engine_initializer.num_gpus > 1:
211 |                 num_gpus = 0  # GPUs are requested through the placement group bundles below.
212 | def _scheduling_strategy_fn(
213 | num_gpus_per_instance: int, accelerator_type: str
214 | ):
215 | def _get_bundle() -> Dict[str, float]:
216 | bundle: Dict[str, float] = {"GPU": 1, "CPU": 1}
217 | if accelerator_type:
218 | bundle[f"accelerator_type:{accelerator_type}"] = 0.001
219 | return bundle
220 |
221 | pg = ray.util.placement_group(
222 | [_get_bundle()] * num_gpus_per_instance,
223 | strategy="STRICT_PACK",
224 | )
225 | self.pgs.append(pg)
226 | return dict(
227 | scheduling_strategy=PlacementGroupSchedulingStrategy(
228 | pg, placement_group_capture_child_tasks=True
229 | )
230 | )
231 |
232 | ray_remote_args.update(
233 | _scheduling_strategy_fn(
234 | self.engine_initializer.num_gpus,
235 | self.engine_initializer.accelerator_type,
236 | )
237 | )
238 |
239 | self.ds = self.ds.map_batches(
240 | self.engine_initializer.get_engine_cls(),
241 | fn_constructor_kwargs=self.engine_initializer.get_engine_constructor_args(
242 | self.workload
243 | ),
244 | zero_copy_batch=True,
245 | # The number of running actors.
246 | concurrency=self.env_config.num_replicas,
247 | # The number of running batches for an actor in Ray Core level.
248 | # The value may not be optimal when the batch size is too small,
249 | # but it should be good enough for batch size >= 64.
250 | max_concurrency=4,
251 | batch_size=self.env_config.batch_size,
252 | num_gpus=num_gpus,
253 | **ray_remote_args,
254 | )
255 |
256 |         # Detokenize unless skipped; skipping is usually used for tuning, profiling, and embedding.
257 | if detokenize:
258 | self.ds = self.ds.map_batches(
259 | Detokenizer,
260 | fn_constructor_kwargs={"model": self.engine_initializer.model},
261 | zero_copy_batch=True,
262 | concurrency=(1, self.num_replicas),
263 | batch_size=self.env_config.batch_size,
264 | )
265 |
266 | if output_path is not None:
267 | # Dataset will become None after writing to parquet.
268 | self.ds = self.ds.write_parquet(output_path)
269 | elif eager:
270 | self.ds = self.ds.materialize()
271 |
272 | # If the dataset pipeline is executed due to eager mode, we can cleanup.
273 | if eager:
274 | self.cleanup()
275 |
276 | return self.ds
277 |
278 | def cleanup(self):
279 | for pg in self.pgs:
280 | remove_placement_group(pg)
281 | self.pgs.clear()
282 |
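283 |
284 | # Usage sketch: wiring an engine config and a workload into the pipeline. The YAML path
285 | # and the dataset below are illustrative placeholders rather than files in this repo.
286 | #
287 | #     import ray
288 | #
289 | #     workload = EvalWorkload(dataset=ray.data.from_items([...]))
290 | #     pipeline = Pipeline.from_config("configs/vllm_engine.yaml", workload, batch_size=64)
291 | #     ds = pipeline(workload)  # builds and lazily runs the Ray Data pipeline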
--------------------------------------------------------------------------------
/evals/batch/tokenizer.py:
--------------------------------------------------------------------------------
1 | """Tokenizer and detokenizer for LLMs."""
2 |
3 | import time
4 | from typing import Any, AsyncGenerator, Dict, Union
5 |
6 | import numpy as np
7 | from transformers import (
8 | AutoProcessor,
9 | AutoTokenizer,
10 | PreTrainedTokenizer, # type: ignore
11 | PreTrainedTokenizerFast,
12 | )
13 |
14 | from .logging import get_logger
15 | from .utils import async_caller_empty_batch_handler, maybe_download_model_from_s3
16 |
17 | AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, Any]
18 |
19 | logger = get_logger(__name__)
20 |
21 |
22 | def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
23 | """Get tokenizer with cached properties.
24 |
25 | This will patch the tokenizer object in place.
26 | By default, transformers will recompute multiple tokenizer properties
27 | each time they are called, leading to a significant slowdown. This
28 | function caches these properties for faster access.
29 |
30 | Args:
31 | tokenizer: The tokenizer object.
32 |
33 | Returns:
34 | The patched tokenizer object.
35 | """
36 | chat_template = getattr(tokenizer, "chat_template", None)
37 | # For VLM, the text tokenizer is wrapped by a processor.
38 | if hasattr(tokenizer, "tokenizer"):
39 | tokenizer = tokenizer.tokenizer
40 |     # Some VLM tokenizers have a chat_template attribute (e.g. Qwen/Qwen2-VL-7B-Instruct),
41 |     # while others do not (e.g. mistral-community/pixtral-12b). Therefore, we cache the
42 |     # processor's chat_template.
43 | if chat_template is None:
44 | chat_template = getattr(tokenizer, "chat_template", None)
45 |
46 | tokenizer_all_special_ids = set(tokenizer.all_special_ids)
47 | tokenizer_all_special_tokens_extended = tokenizer.all_special_tokens_extended
48 | tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
49 | tokenizer_len = len(tokenizer)
50 |
51 | class CachedTokenizer(tokenizer.__class__): # type: ignore
52 | @property
53 | def all_special_ids(self):
54 | return tokenizer_all_special_ids
55 |
56 | @property
57 | def all_special_tokens(self):
58 | return tokenizer_all_special_tokens
59 |
60 | @property
61 | def all_special_tokens_extended(self):
62 | return tokenizer_all_special_tokens_extended
63 |
64 | @property
65 | def chat_template(self):
66 | return chat_template
67 |
68 | def __len__(self):
69 | return tokenizer_len
70 |
71 | CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"
72 |
73 | tokenizer.__class__ = CachedTokenizer
74 | return tokenizer
75 |
76 |
77 | class ChatTemplateTokenizer:
78 | """Tokenizer with chat template applied.
79 |
80 | Args:
81 | model: The model name.
82 | """
83 |
84 | def __init__(self, model: str) -> None:
85 | self.model = maybe_download_model_from_s3(model)
86 | self.tokenizer = get_cached_tokenizer(AutoProcessor.from_pretrained(self.model))
87 |
88 | @async_caller_empty_batch_handler
89 | async def __call__(
90 | self, batch: Dict[str, np.ndarray]
91 | ) -> AsyncGenerator[Dict[str, Any], None]:
92 | """Call the tokenizer to process a batch.
93 |         This function first applies the chat template to each input in the batch, because
94 |         that step cannot be batched. It then tokenizes all prompts at once.
95 |
96 | Args:
97 | batch: The batch.
98 |
99 | Yields:
100 | The output.
101 | """
102 | if "messages" not in batch:
103 | raise KeyError(f'"messages" not found in {batch.keys()=}')
104 |
105 | start_t = time.perf_counter()
106 | messages = batch["messages"].tolist()
107 |
108 | # Tokenize text prompts.
109 | full_prompts = []
110 | for conversation in messages:
111 | # add generation prompt only if the last message is from the user
112 | add_generation_prompt = conversation[-1]["role"] == "user"
113 | full_prompts.append(
114 | self.tokenizer.apply_chat_template(
115 | conversation,
116 | tokenize=False,
117 | add_generation_prompt=add_generation_prompt,
118 | continue_final_message=not add_generation_prompt,
119 | )
120 | )
121 | tokens = self.tokenizer(full_prompts)["input_ids"]
122 | time_taken_tokenizer = time.perf_counter() - start_t
123 |
124 | ret = {
125 | **batch,
126 | "prompt": full_prompts,
127 | "tokenized_prompt": tokens,
128 | "num_text_tokens": [len(t) for t in tokens],
129 | "time_taken_tokenizer": [time_taken_tokenizer] * len(tokens),
130 | }
131 |
132 | yield ret
133 |
134 |
135 | class Detokenizer:
136 | """Detokenizer for LLMs.
137 |
138 | Args:
139 | model: The model name.
140 | """
141 |
142 | def __init__(self, model: str) -> None:
143 | self.model = maybe_download_model_from_s3(model)
144 | self.tokenizer = get_cached_tokenizer(AutoTokenizer.from_pretrained(self.model))
145 |
146 | async def __call__(
147 | self, batch: Dict[str, np.ndarray]
148 | ) -> AsyncGenerator[Dict[str, Any], None]:
149 | """Detokenize the batch.
150 |
151 | Args:
152 | batch: The batch data.
153 |
154 |         Yields:
155 |             The detokenized batch.
156 | """
157 | start_t = time.perf_counter()
158 | generated_tokens = batch["generated_tokens"]
159 | flattened = False
160 | # if the generated tokens are nested lists, flatten them
161 | if isinstance(generated_tokens[0][0], np.ndarray):
162 | # flatten the lists of lists for detokenization
163 | flattened = True
164 | generated_tokens = [
165 | token for tokens in generated_tokens for token in tokens
166 | ] # flattens list
167 | generated_text = self.tokenizer.batch_decode(
168 | generated_tokens, skip_special_tokens=True
169 | )
170 | if flattened:
171 | # unflatten the list back to original structure
172 | curr_idx = 0
173 | generated_text_unflattened = []
174 | for sublist in batch["generated_tokens"]:
175 | sublist_len = len(sublist)
176 | generated_text_unflattened.append(
177 | generated_text[curr_idx : curr_idx + sublist_len]
178 | )
179 | curr_idx += sublist_len
180 | generated_text = generated_text_unflattened
181 | time_taken_detokenizer = time.perf_counter() - start_t
182 | yield {
183 | **batch,
184 | "generated_text": generated_text,
185 | "time_taken_detokenizer": [time_taken_detokenizer] * len(generated_text),
186 | }
187 |
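188 |
189 | # Usage sketch: the batch schema the two stages expect. `ChatTemplateTokenizer` consumes a
190 | # "messages" column of chat conversations and adds "prompt", "tokenized_prompt" and
191 | # "num_text_tokens"; `Detokenizer` consumes "generated_tokens" and adds "generated_text".
192 | # The model name below is illustrative.
193 | #
194 | #     import asyncio, numpy as np
195 | #
196 | #     batch = {"messages": np.array([[{"role": "user", "content": "Hi!"}]], dtype=object)}
197 | #     tokenizer = ChatTemplateTokenizer("Qwen/Qwen2.5-7B-Instruct")
198 | #     out = asyncio.run(tokenizer(batch).__anext__())  # single yielded dict with the new columns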
--------------------------------------------------------------------------------
/evals/batch/utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions"""
2 |
3 | import os
4 | import subprocess
5 | import time
6 | from functools import wraps
7 | from pathlib import Path
8 | from typing import Any, Callable, Dict, List, Optional
9 |
10 | import pyarrow
11 | import ray
12 | from filelock import FileLock
13 | from huggingface_hub import snapshot_download
14 | from pynvml import nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlInit # type: ignore
15 | from ray.data import Dataset
16 |
17 | from .logging import get_logger
18 |
19 | logger = get_logger(__name__)
20 |
21 |
22 | # The default local root directory to store models downloaded from S3.
23 | # This path should always be available on the Anyscale platform. If not,
24 | # we fall back to FALLBACK_LOCAL_MODEL_ROOT.
25 | DEFAULT_LOCAL_MODEL_ROOT = "/mnt/local_storage/cache"
26 | FALLBACK_LOCAL_MODEL_ROOT = "/tmp/cache"
27 |
28 |
29 | def update_dict_recursive(
30 | orig: Dict[str, Any], update_dict: Dict[str, Any]
31 | ) -> Dict[str, Any]:
32 | """Update a dictionary (in-place) recursively.
33 |
34 | Args:
35 | orig: The original dictionary.
36 |         update_dict: The dictionary containing the updates to apply.
37 |
38 | Returns:
39 | The updated dictionary.
40 | """
41 | for key, value in update_dict.items():
42 | if isinstance(value, dict):
43 | orig[key] = update_dict_recursive(orig.get(key, {}), value)
44 | else:
45 | orig[key] = value
46 | return orig
47 |
48 |
49 | def wait_for_gpu_memory_to_clear(threshold_bytes: int, timeout_s: float = 120) -> None:
50 | """Wait for GPU memory to be below a threshold.
51 | Use nvml instead of pytorch to reduce measurement error from torch cuda context.
52 |
53 | Args:
54 | threshold_bytes: The threshold in bytes.
55 | timeout_s: The timeout in seconds.
56 |
57 | Raises:
58 | ValueError: If the memory is not free after the timeout.
59 | """
60 | devices = [int(x) for x in ray.get_gpu_ids()]
61 | nvmlInit()
62 | start_time = time.monotonic()
63 | while True:
64 | output = {}
65 | output_raw = {}
66 | for device in devices:
67 | dev_handle = nvmlDeviceGetHandleByIndex(device)
68 | mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
69 | gb_used = mem_info.used / 2**30
70 | output_raw[device] = gb_used
71 | output[device] = f"{gb_used:.02f}"
72 |
73 | logger.info(
74 | "GPU memory used (GB): " + "; ".join(f"{k}={v}" for k, v in output.items())
75 | )
76 |
77 | dur_s = time.monotonic() - start_time
78 | if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()):
79 | logger.info(
80 | "Done waiting for free GPU memory on devices %s (%.2f GB) %.02f s",
81 | devices,
82 | threshold_bytes / 2**30,
83 | dur_s,
84 | )
85 | break
86 |
87 | if dur_s >= timeout_s:
88 | raise ValueError(
89 | f"Memory of devices {devices=} not free after "
90 | f"{dur_s=:.02f} ({threshold_bytes/2**30=})"
91 | )
92 |
93 | time.sleep(5)
94 |
95 |
96 | def run_s3_command(command: List[str], error_msg: Optional[str] = None) -> Any:
97 | """Run a S3 command and raise an exception if it fails.
98 |
99 | Args:
100 | command: The command to run.
101 | error_msg: The error message to raise if the command fails.
102 |
103 | Returns:
104 | The result of the command.
105 | """
106 | try:
107 | return subprocess.run(command, check=True, capture_output=True)
108 | except Exception as err:
109 | # Not using logger.exception since we raise anyway.
110 | if isinstance(err, (subprocess.TimeoutExpired, subprocess.CalledProcessError)):
111 | stdout_txt = f"\nSTDOUT: {err.stdout.decode()}" if err.stdout else ""
112 | stderr_txt = f"\nSTDERR: {err.stderr.decode()}" if err.stderr else ""
113 | else:
114 | stdout_txt = ""
115 | stderr_txt = ""
116 |
117 | if error_msg is not None:
118 | logger.error(
119 | "(%s) %s. Command %s.%s%s",
120 | str(err),
121 | error_msg,
122 | command,
123 | stdout_txt,
124 | stderr_txt,
125 | )
126 | raise
127 |
128 |
129 | def download_hf_model_from_s3(s3_path: str, local_path_root: str) -> str:
130 | """Download model files from s3 to the local path. The model path prefix
131 | will be added to the local path.
132 |
133 | Args:
134 | s3_path: The s3 path to download from.
135 | local_path_root: The local path root to download to.
136 |
137 | Returns:
138 | The local path where the files are downloaded.
139 | """
140 | if not s3_path.startswith("s3://"):
141 | raise ValueError(f"Invalid s3 path: {s3_path}")
142 |
143 | prefix = "/".join(s3_path.split("/")[3:])
144 | local_path = Path(local_path_root) / prefix
145 |
146 | # Use aws s3 sync to make sure we don't download the same files again.
147 | command = ["aws", "s3", "sync", s3_path, local_path]
148 |
149 | logger.info(
150 | "Downloading %s to %s using %s",
151 | s3_path,
152 | local_path,
153 | command,
154 | )
155 | with FileLock(local_path / ".lock", timeout=-1):
156 | run_s3_command(command, f"Failed to sync model from {s3_path} to {local_path}")
157 | return str(local_path)
158 |
159 |
160 | def maybe_download_model_from_s3(
161 | model_path: str, local_path_root: Optional[str] = None
162 | ) -> str:
163 | """Download model from s3 to the local path, and return the local model path.
164 |
165 | Args:
166 | model_path: The maybe s3 path to download from.
167 |         local_path_root: The local path root to download to. If not provided,
168 | will use the default path (/mnt/local_storage/cache or /tmp/cache).
169 |
170 | Returns:
171 | The local path where the model is downloaded.
172 | """
173 | s3_path = os.path.expandvars(model_path)
174 | if not s3_path.startswith("s3://"):
175 | return model_path
176 |
177 | local_root = Path(local_path_root or DEFAULT_LOCAL_MODEL_ROOT)
178 | try:
179 | local_root.mkdir(parents=True, exist_ok=True)
180 | # Check if the directory is writable.
181 | with open(local_root / ".test", "w") as fp:
182 | fp.write("test")
183 | except PermissionError:
184 | logger.warning(
185 | "Failed to create local root directory at %s (Permission denied). "
186 | "Reset local root to %s",
187 | local_root,
188 | FALLBACK_LOCAL_MODEL_ROOT,
189 | )
190 | local_root = Path(FALLBACK_LOCAL_MODEL_ROOT)
191 | local_root.mkdir(parents=True, exist_ok=True)
192 |
193 | return download_hf_model_from_s3(s3_path, local_root)
194 |
195 |
196 | def download_model_from_hf(
197 | model_name: str, local_path_root: Optional[str] = None
198 | ) -> str:
199 | """Download model files from Hugging Face to the local path.
200 | If the local path has permission issues, return the original model name, but warn the user.
201 |
202 | Args:
203 | model_name: The model name to download.
204 | local_path_root: The local path root to download to. If not provided,
205 |             will use the default path (/mnt/local_storage/cache or /tmp/cache).
206 |
207 | Returns:
208 | The local path where the files are downloaded.
209 | """
210 | # If the model_name is already a local path, skip downloading
211 | if model_name.startswith("/"):
212 | return model_name
213 |
214 | local_model_path = Path(local_path_root or DEFAULT_LOCAL_MODEL_ROOT) / model_name
215 | try:
216 | local_model_path.mkdir(parents=True, exist_ok=True)
217 |
218 | # Check directory is writable by trying to list files (avoiding .test file creation)
219 | if not os.access(local_model_path, os.W_OK):
220 | raise PermissionError
221 | except PermissionError:
222 | logger.warning(
223 | "Failed to create or write to the model directory at %s (Permission denied). "
224 | "Please grant permission, or each worker may download the model, hitting rate limits.",
225 | local_model_path,
226 | )
227 | return model_name # Return the original model name
228 |
229 | snapshot_download(repo_id=model_name, local_dir=str(local_model_path))
230 |
231 | return str(local_model_path)
232 |
233 |
234 | def async_caller_empty_batch_handler(func) -> Callable:
235 | """A decorator to handle the case where all rows are checkpointed.
236 | When all rows are checkpointed, we will still get a batch
237 | in pyarrow.Table format with empty rows. This is a bug and
238 | is being tracked here:
239 | https://github.com/anyscale/rayturbo/issues/1292
240 |
241 | Args:
242 | func: The function to wrap.
243 |
244 | Returns:
245 | The wrapped function.
246 | """
247 |
248 | @wraps(func)
249 | async def wrapper(self, batch):
250 | if not isinstance(batch, pyarrow.lib.Table) or batch.num_rows > 0:
251 | async for x in func(self, batch):
252 | yield x
253 | else:
254 | yield {}
255 |
256 | return wrapper
257 |
258 |
259 | def has_materialized(ds: Dataset) -> bool:
260 | """Check if the dataset has been materialized.
261 | TODO: This API should be moved to Ray Data.
262 |
263 | Args:
264 | ds: The dataset to check.
265 |
266 | Returns:
267 | True if the dataset is materialized, False otherwise.
268 | """
269 | return bool(ds.stats())
270 |
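271 |
272 | if __name__ == "__main__":
273 |     # Usage sketch for `update_dict_recursive`: overrides are merged into nested defaults
274 |     # without dropping sibling keys. The values below are illustrative.
275 |     defaults = {"engine_kwargs": {"dtype": "auto", "tensor_parallel_size": 4}}
276 |     overrides = {"engine_kwargs": {"dtype": "bfloat16"}}
277 |     merged = update_dict_recursive(defaults, overrides)
278 |     # -> {"engine_kwargs": {"dtype": "bfloat16", "tensor_parallel_size": 4}}
279 |     print(merged)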
--------------------------------------------------------------------------------
/evals/batch/workload.py:
--------------------------------------------------------------------------------
1 | """The workload."""
2 |
3 | import math
4 | from dataclasses import dataclass, field
5 | from pathlib import Path
6 | from typing import Any, Dict, Optional, Tuple
7 |
8 | import yaml
9 | from ray.data.dataset import Dataset
10 |
11 | from .logging import get_logger
12 | from .tokenizer import ChatTemplateTokenizer
13 |
14 | logger = get_logger(__name__)
15 |
16 |
17 | def load_config_from_path(config_path: str) -> Dict[str, Any]:
18 | if isinstance(config_path, str):
19 | config_path = Path(config_path)
20 | if not config_path.exists():
21 | raise FileNotFoundError(f"Engine config file {config_path} not found.")
22 | with open(config_path, "r") as filep:
23 | config = yaml.safe_load(filep)
24 |
25 | assert isinstance(config, dict)
26 | return config
27 |
28 |
29 | @dataclass
30 | class EvalWorkload:
31 |     # The ray.data.Dataset. If None, the Workload must initialize the dataset
32 | # in __post_init__().
33 | dataset: Optional[Dataset]
34 | # Sampling a fraction of dataset for benchmarking and testing. If the value
35 | # is greater than one, it means to take the first N rows from the dataset.
36 | dataset_fraction: float = 1.0
37 | # Tokenizer class for the workload.
38 | tokenizer_cls: Any = ChatTemplateTokenizer
39 |
40 | # Sampling parameters for the workload, such as max_tokens, temperature, etc.
41 | # It can only be None when the workload is used for embedding.
42 | sampling_params: Dict[str, Any] = field(
43 | default_factory=lambda: {"max_tokens": 4096}
44 | )
45 | # Pooling parameters for the workload, such as pooling_type, etc.
46 | # It can only be None when the workload is used for auto-regressive generation.
47 | pooling_params: Optional[Dict[str, Any]] = None
48 |
49 | need_tokenize: bool = True
50 |     # When specified, tokenization does not need to materialize the entire tokenized
51 |     # dataset just to compute the maximum number of prompt tokens, so it can stream.
52 |     # With the default value of -1, the actual value will be set after tokenization.
53 | max_tokens_in_prompt: int = -1
54 |
55 | # Do we want to carry over input keys that are not in the output?
56 | carryover_inputs: bool = True
57 |
58 | def validate(self):
59 | if not ((self.sampling_params is None) ^ (self.pooling_params is None)):
60 | raise ValueError(
61 |                 "Exactly one of sampling_params or pooling_params must be specified."
62 | )
63 |
64 | def get_preprocessed_dataset(
65 | self,
66 | max_batch_size: int = 256,
67 | repartition_by_batch_size: bool = False,
68 | ) -> Tuple[Dataset, Optional[int]]:
69 | """Load the dataset and process it.
70 |
71 | Args:
72 | max_batch_size: The batch size. This determines the number of rows per
73 |                 block. Note that if some rows have already been processed (checkpointed),
74 | the actual batch size may be smaller than this value.
75 | repartition_by_batch_size: Whether to repartition the dataset by the
76 | batch size for fault tolerance granularity. You should enable
77 | this when the dataset is not from parquet and checkpointing is
78 | disabled.
79 |
80 | Returns:
81 | The processed dataset and the number of blocks. If checkpointing is
82 | enabled, then the number of blocks is unknown.
83 | """
84 | self.validate()
85 | if self.dataset is None:
86 | raise ValueError(
87 | "dataset must be specified or initialized before calling "
88 | "get_preprocessed_dataset()."
89 | )
90 |
91 | self.max_batch_size = max_batch_size
92 |
93 | ds = self.dataset
94 | if self.dataset_fraction < 1.0:
95 | logger.info("Sampling %f dataset", self.dataset_fraction)
96 | ds = ds.random_sample(self.dataset_fraction, seed=0)
97 | elif self.dataset_fraction > 1.0:
98 | n_rows = int(self.dataset_fraction)
99 | logger.info("Taking the first %d rows from dataset", n_rows)
100 | ds = ds.limit(n_rows)
101 |
102 | if repartition_by_batch_size:
103 | num_requests = ds.count()
104 | num_blocks = math.ceil(num_requests / max_batch_size)
105 | ds = ds.repartition(num_blocks)
106 |
107 | logger.info("#Requests: %d (%d blocks)", num_requests, num_blocks)
108 | else:
109 | # When checkpointing is enabled, the number of blocks is unknown
110 | # at this point.
111 | num_blocks = None
112 |
113 | mapper_fn = (
114 | self.parse_row_with_carryover_input
115 | if self.carryover_inputs
116 | else self.parse_row
117 | )
118 | return ds.map(mapper_fn), num_blocks
119 |
120 | def tokenizer_constructor_kwargs(self, model: str):
121 | """Return the keyword arguments for tokenizer constructor.
122 |
123 | Args:
124 | model: The model name.
125 |
126 | Returns:
127 | The keyword arguments for tokenizer constructor.
128 | """
129 | return {"model": model}
130 |
131 | def parse_row_with_carryover_input(self, row: dict[str, Any]) -> dict[str, Any]:
132 | """Same as parse_row but carries over the input keys that are not in the output row.
133 |
134 | This is useful when we want to keep the input keys in the output.
135 |         This method assumes that if the parsed row returns the same keys as the
136 |         input, those values have already been copied over and do not need to be
137 |         copied again. Only the input keys that are missing from the output row
138 |         are carried over.
139 |
140 | Args:
141 | row: The row to be parsed.
142 |
143 | Returns:
144 | The parsed row.
145 | """
146 | input_row_keys = set(row.keys())
147 | output_row = self.parse_row(row)
148 | output_row_keys = set(output_row.keys())
149 | return {
150 | **{k: row[k] for k in input_row_keys if k not in output_row_keys},
151 | **output_row,
152 | }
153 |
154 | def parse_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
155 |         """Parse each row in the dataset to make it compatible with the
156 |         OpenAI chat API. The output row contains a "messages" key holding the
157 |         list of chat messages and an "index" key identifying the request.
158 | """
159 | return {"messages": row["item"][1], "index": row["item"][0]}
160 |
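161 |
162 | if __name__ == "__main__":
163 |     # Usage sketch: `parse_row` expects rows shaped as {"item": (index, messages)} and emits
164 |     # OpenAI-style "messages" plus an "index" column. The prompt below is illustrative.
165 |     workload = EvalWorkload(dataset=None)
166 |     row = {"item": (7, [{"role": "user", "content": "What is 2 + 2?"}])}
167 |     print(workload.parse_row_with_carryover_input(row))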
--------------------------------------------------------------------------------
/evals/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nanomaoli/llm_reproducibility/8a373c5a159a27e59783394827cecadd6255484e/evals/common/__init__.py
--------------------------------------------------------------------------------
/evals/common/entities.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from dataclasses import dataclass
3 | from enum import Enum
4 | from importlib import resources
5 | from pathlib import Path
6 | from typing import Literal, Optional, Union
7 |
8 | import yaml
9 | from openai import NOT_GIVEN, NotGiven
10 | from openai.types.chat import ChatCompletionReasoningEffort
11 | from pydantic import BaseModel, ConfigDict, Field
12 | from vllm import SamplingParams as VLLMSamplingParams
13 |
14 | TEMPERATURE_DEFAULT = 0
15 | TOP_P_DEFAULT = 1
16 | MAX_TOKENS_DEFAULT = 32768
17 |
18 |
19 | class Backend(str, Enum):
20 | VLLM = "vllm"
21 | OPENAI = "openai"
22 | RAY = "ray"
23 |
24 |
25 | class OpenAISamplingParams(BaseModel):
26 | model_config = ConfigDict(arbitrary_types_allowed=True)
27 |
28 | temperature: float = TEMPERATURE_DEFAULT
29 | top_p: float = TOP_P_DEFAULT
30 | n: int = 1
31 | max_tokens: int = MAX_TOKENS_DEFAULT
32 | reasoning_effort: Union[ChatCompletionReasoningEffort, NotGiven] = NOT_GIVEN
33 | frequency_penalty: Optional[float] = None
34 |
35 |
36 | class SamplingParameters(BaseModel):
37 | model_config = ConfigDict(arbitrary_types_allowed=True)
38 |
39 | params: Union[OpenAISamplingParams, VLLMSamplingParams]
40 |
41 | @classmethod
42 | def from_dict(cls, backend: Backend, params: dict):
43 | params = copy.deepcopy(params)
44 | if backend == Backend.OPENAI:
45 | return cls(params=OpenAISamplingParams(**params))
46 | # Currently, ray-data based processor only supports vllm as the inference engine
47 | elif backend in [Backend.VLLM, Backend.RAY]:
48 | return cls(params=VLLMSamplingParams(**params))
49 | else:
50 | raise ValueError(f"Invalid backend type: {backend}")
51 |
52 | def __repr__(self):
53 | return f"SamplingParameters(params={self.params})"
54 |
55 | def to_dict(self):
56 | if isinstance(self.params, OpenAISamplingParams):
57 | return self.params.model_dump()
58 | elif isinstance(self.params, VLLMSamplingParams):
59 | return {k: getattr(self.params, k) for k in self.params.__annotations__}
60 | else:
61 | raise ValueError(f"Invalid sampling parameters type: {type(self.params)}")
62 |
63 |
64 | class OpenAIClientArgs(BaseModel):
65 | api_key: Optional[str] = Field(None, description="OpenAI API key")
66 | base_url: Optional[str] = Field(None, description="OpenAI base URL")
67 | project: Optional[str] = Field(None, description="OpenAI project")
68 | organization: Optional[str] = Field(None, description="OpenAI organization")
69 |
70 |
71 | class RayLLMEngineArgs(BaseModel):
72 |
73 | tensor_parallel_size: Optional[int] = Field(
74 | default=None, description="Tensor parallelism size"
75 | )
76 | num_replicas: Optional[int] = Field(
77 | default=None, description="Number of replicas to use for Ray"
78 | )
79 | batch_size: Optional[int] = Field(default=None, description="Batch size for Ray")
80 | accelerator_type: Optional[str] = Field(
81 | default=None, description="Accelerator type for the inference engine"
82 | )
83 | gpu_memory_utilization: Optional[float] = Field(
84 | default=None, description="GPU memory utilization for the inference engine"
85 | )
86 | dtype: Optional[Literal["float32", "float16", "bfloat16", "float8", "auto"]] = (
87 | Field(default=None, description="Data type for inference engine.")
88 | )
89 |
90 | def get_ray_llm_config(self):
91 | config_path = Path(
92 | resources.files("skythought.evals").joinpath("ray_configs/ray_config.yaml")
93 | )
94 | with open(config_path) as f:
95 | default_config = yaml.safe_load(f)
96 |
97 | if self.tensor_parallel_size is not None:
98 | default_config["engine_kwargs"][
99 | "tensor_parallel_size"
100 | ] = self.tensor_parallel_size
101 |
102 | if self.num_replicas is not None:
103 | default_config["env_config"]["num_replicas"] = self.num_replicas
104 |
105 | if self.batch_size is not None:
106 | default_config["env_config"]["batch_size"] = self.batch_size
107 |
108 | if self.accelerator_type is not None:
109 | default_config["accelerator_type"] = self.accelerator_type
110 |
111 | if self.gpu_memory_utilization is not None:
112 | default_config["engine_kwargs"][
113 | "gpu_memory_utilization"
114 | ] = self.gpu_memory_utilization
115 |
116 | # FIXME (sumanthrh): there can be a corner case when we support providing a config yaml directly, and this will override the dtype
117 | if self.dtype is not None:
118 | default_config["engine_kwargs"]["dtype"] = self.dtype
119 |
120 | return default_config
121 |
122 |
123 | @dataclass
124 | class BackendParameters:
125 | model_config = ConfigDict(arbitrary_types_allowed=True)
126 |
127 | params: Union[dict, OpenAIClientArgs, RayLLMEngineArgs]
128 |
129 | @classmethod
130 | def from_dict(cls, backend_type: Backend, params: dict):
131 | if backend_type == Backend.RAY:
132 | return cls(params=RayLLMEngineArgs(**params))
133 | elif backend_type == Backend.VLLM:
134 | # passed directly to LLM(..) instantiation
135 | return cls(params=params)
136 | elif backend_type == Backend.OPENAI:
137 | return cls(params=OpenAIClientArgs(**params))
138 | else:
139 | raise ValueError(f"Invalid backend type: {backend_type}")
140 |
141 | def to_dict(self):
142 | if isinstance(self.params, RayLLMEngineArgs):
143 | return self.params.model_dump()
144 | elif isinstance(self.params, dict):
145 | return self.params
146 | elif isinstance(self.params, OpenAIClientArgs):
147 | return self.params.model_dump()
148 | else:
149 | raise ValueError(f"Invalid backend parameters type: {type(self.params)}")
150 |
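151 |
152 | if __name__ == "__main__":
153 |     # Usage sketch: building sampling and backend parameters for the OpenAI backend.
154 |     # The base URL below is an illustrative placeholder.
155 |     sp = SamplingParameters.from_dict(Backend.OPENAI, {"temperature": 0.0, "max_tokens": 1024})
156 |     print(sp.to_dict())
157 |     bp = BackendParameters.from_dict(Backend.OPENAI, {"base_url": "http://localhost:8000/v1"})
158 |     print(bp.to_dict())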
--------------------------------------------------------------------------------
/evals/labeled_numina_difficulty/README.md:
--------------------------------------------------------------------------------
1 | # Labeled NUMINA Difficulty Data
2 |
3 | We also include data of labeled difficulty from NUMINA, in the following files: `labeled_amc_aime_0_-1.json`, `labeled_math_0_-1.json`, `labeled_olympiads_0_-1.json`. These files can be found and downloaded from [HuggingFace](https://huggingface.co/datasets/NovaSky-AI/labeled_numina_difficulty).
--------------------------------------------------------------------------------
/evals/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import ModelConfig, get_system_prompt_keys
2 |
3 | __all__ = ["ModelConfig", "get_system_prompt_keys"]
4 |
--------------------------------------------------------------------------------
/evals/models/base.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from pathlib import Path
3 | from typing import Optional, Union
4 |
5 | import yaml
6 | from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
7 |
8 | MODEL_CONFIG_FILE_PATH = Path(__file__).parent / "model_configs.yaml"
9 | # cache the configs in a global var
10 | ALL_MODEL_CONFIGS = None
11 |
12 |
13 | class StringInFile(BaseModel):
14 | path: str
15 | _string: str = PrivateAttr(default=None)
16 |
17 | @model_validator(mode="after")
18 | def validate_and_extract_string(self):
19 | full_path = Path(MODEL_CONFIG_FILE_PATH).parent / self.path
20 | if full_path.exists():
21 | with open(full_path, "r") as f:
22 | self._string = f.read()
23 | else:
24 |             raise ValueError(f"Invalid path: {full_path}")
25 | return self
26 |
27 | @property
28 | def string(self):
29 | return self._string
30 |
31 | def __str__(self) -> str:
32 | return self._string
33 |
34 |
35 | def read_yaml(path: str):
36 | with open(path, "r") as f:
37 | return yaml.safe_load(f)
38 |
39 |
40 | class ModelConfig(BaseModel):
41 | model_config = ConfigDict(protected_namespaces=())
42 |
43 | model_id: str
44 | name: Optional[str] = Field(default=None)
45 | # can be a string or a path to a file with the string
46 | system_prompt: Optional[Union[str, StringInFile]] = None
47 | user_template: Optional[Union[str, StringInFile]] = None
48 | assistant_prefill: Optional[str] = None
49 |
50 | @model_validator(mode="after")
51 | def validate_name(self):
52 | if self.name is None:
53 | self.name = self.model_id.split("/")[-1]
54 | return self
55 |
56 | @classmethod
57 | def from_model_id(
58 | cls,
59 | model_id: str,
60 | system_prompt_name: Optional[str] = None,
61 | system_prompt: Optional[str] = None,
62 | assistant_prefill: Optional[str] = None,
63 | ):
64 | global ALL_MODEL_CONFIGS
65 | # only one of the two can be provided
66 | assert (
67 | system_prompt_name is None or system_prompt is None
68 | ), "Only one of `system_prompt_name` or `system_prompt` can be provided"
69 | init_kwargs = {}
70 | if ALL_MODEL_CONFIGS is None:
71 | ALL_MODEL_CONFIGS = read_yaml(MODEL_CONFIG_FILE_PATH)
72 | if model_id in ALL_MODEL_CONFIGS["models"]:
73 | init_kwargs = ALL_MODEL_CONFIGS["models"][model_id]
74 |
75 | if system_prompt_name:
76 | if system_prompt_name not in ALL_MODEL_CONFIGS["system_prompts"]:
77 | raise ValueError(
78 | f"Invalid system prompt template {system_prompt_name} provided."
79 | )
80 | init_kwargs["system_prompt"] = ALL_MODEL_CONFIGS["system_prompts"][
81 | system_prompt_name
82 | ]
83 | elif system_prompt:
84 | init_kwargs["system_prompt"] = system_prompt
85 | # if none was provided, and the model is not in the config file
86 | elif model_id not in ALL_MODEL_CONFIGS["models"]:
87 | init_kwargs = {}
88 | warnings.warn(
89 | f"Model {model_id} not found in {MODEL_CONFIG_FILE_PATH}. Initializing without any system prompt.",
90 | stacklevel=2,
91 | )
92 |
93 | if assistant_prefill:
94 | init_kwargs["assistant_prefill"] = assistant_prefill
95 |
96 | init_kwargs["model_id"] = model_id
97 | return cls(**init_kwargs)
98 |
99 |
100 | def get_system_prompt_keys():
101 | global ALL_MODEL_CONFIGS
102 | if ALL_MODEL_CONFIGS is None:
103 | ALL_MODEL_CONFIGS = read_yaml(MODEL_CONFIG_FILE_PATH)
104 | return list(ALL_MODEL_CONFIGS["system_prompts"].keys())
105 |
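106 |
107 | if __name__ == "__main__":
108 |     # Usage sketch: resolve the config for a model listed in model_configs.yaml.
109 |     # The model id below is one of the entries shipped with the repo.
110 |     cfg = ModelConfig.from_model_id("Qwen/Qwen2.5-7B-Instruct")
111 |     print(cfg.name, cfg.system_prompt is not None, get_system_prompt_keys())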
--------------------------------------------------------------------------------
/evals/models/model_configs.yaml:
--------------------------------------------------------------------------------
1 | system_prompts:
2 | qwen_cot: &qwen_cot_system_prompt "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step."
3 | prime_rl: &prime_rl_system_prompt
4 | # system prompt can also point to a text file. the path to the file should be relative to the parent dir of model_configs.yaml
5 | path: system_prompts/prime.txt
6 | skythought: &sky_t1_system_prompt "Your role as an assistant involves thoroughly exploring questions through a systematic long \
7 | thinking process before providing the final precise and accurate solutions. This requires \
8 | engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, \
9 | backtracing, and iteration to develop well-considered thinking process. \
10 | Please structure your response into two main sections: Thought and Solution. \
11 | In the Thought section, detail your reasoning process using the specified format: \
12 | <|begin_of_thought|> {thought with steps separated with '\n\n'} \
13 | <|end_of_thought|> \
14 | Each step should include detailed considerations such as analisying questions, summarizing \
15 | relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining \
16 | any errors, and revisiting previous steps. \
17 | In the Solution section, based on various attempts, explorations, and reflections from the Thought \
18 | section, systematically present the final solution that you deem correct. The solution should \
19 | remain a logical, accurate, concise expression style and detail necessary step needed to reach the \
20 | conclusion, formatted as follows: \
21 | <|begin_of_solution|> \
22 | {final formatted, precise, and clear solution} \
23 | <|end_of_solution|> \
24 | Now, try to solve the following question through the above guidelines:"
25 |
26 | user_templates: null
27 | # Example:
28 | # o1_mini: &o1_mini "Question: {input}\nAnswer: "
29 |
30 | models:
31 | o1-mini:
32 |     # 'name' is by default in the huggingface format "org/model", but can be customized here
33 | name: o1-mini
34 | system_prompt: null
35 | # user template's use positional argument for formatting
36 | user_template: "Question: {}\nAnswer: "
37 |
38 | o1-preview:
39 | system_prompt: null
40 | user_template: "Question: {}\nAnswer: "
41 |
42 | gpt-4o-mini:
43 | system_prompt: null
44 | user_template: "User: {}\nAssistant: "
45 |
46 | Qwen/Qwen2-7B-Instruct:
47 | system_prompt: *qwen_cot_system_prompt
48 |
49 | Qwen/QwQ-32B-Preview:
50 | system_prompt: *qwen_cot_system_prompt
51 |
52 | Qwen/Qwen2.5-72B-Instruct:
53 | system_prompt: *qwen_cot_system_prompt
54 |
55 | Qwen/Qwen2.5-32B-Instruct:
56 | system_prompt: *qwen_cot_system_prompt
57 |
58 | Qwen/Qwen2.5-7B-Instruct:
59 | system_prompt: *qwen_cot_system_prompt
60 |
61 | Qwen/Qwen2.5-1.5B-Instruct:
62 | system_prompt: *qwen_cot_system_prompt
63 |
64 | Qwen/Qwen2.5-Math-7B-Instruct:
65 | system_prompt: *qwen_cot_system_prompt
66 |
67 | Qwen/Qwen2.5-Math-72B-Instruct:
68 | system_prompt: *qwen_cot_system_prompt
69 |
70 | PRIME-RL/Eurus-2-7B-PRIME:
71 | system_prompt: *prime_rl_system_prompt
72 |
73 | NovaSky-AI/Sky-T1-32B-Preview:
74 | system_prompt: *sky_t1_system_prompt
75 |
76 | NovaSky-AI/Sky-T1-32B-Flash:
77 | system_prompt: *sky_t1_system_prompt
--------------------------------------------------------------------------------
/evals/models/system_prompts/prime.txt:
--------------------------------------------------------------------------------
1 | When tackling complex reasoning tasks, you have access to the following actions. Use them as needed to progress through your thought process. After each action, determine and state the next most appropriate action to take.
2 |
3 | Actions:
4 |
5 | {actions}
6 |
7 | Your action should contain multiple steps, and each step starts with #. After each action (except OUTPUT), state which action you will take next with ''Next action: [Your action]'' and finish this turn. Continue this process until you reach a satisfactory conclusion or solution to the problem at hand, at which point you should use the [OUTPUT] action. The thought process is completely invisible to user, so [OUTPUT] should be a complete response. You should strictly follow the format below:
8 |
9 | [ACTION NAME]
10 |
11 | # Your action step 1
12 |
13 | # Your action step 2
14 |
15 | # Your action step 3
16 |
17 | ...
18 |
19 | Next action: [NEXT ACTION NAME]
20 |
21 |
22 | Now, begin with the [ASSESS] action for the following task:
--------------------------------------------------------------------------------
/evals/ray_configs/ray_config.yaml:
--------------------------------------------------------------------------------
1 | llm_engine: vllm # currently only vllm supported
2 | accelerator_type: null # accelerator name as specified here: https://docs.ray.io/en/master/ray-core/accelerator-types.html#accelerator-types
3 | engine_kwargs: # vllm engine kwargs
4 | tensor_parallel_size: 4
5 | gpu_memory_utilization: 0.9
6 | dtype: auto
7 | # other optional vllm engine kwargs to tune performance!
8 | # pipeline_parallel_size: 1
9 | # max_num_seqs: 448
10 | # use_v2_block_manager: True
11 | # enable_prefix_caching: False
12 | # preemption_mode: "recompute"
13 | # block_size: 16
14 | # kv_cache_dtype: "auto"
15 | # enforce_eager: False
16 | # enable_chunked_prefill: True
17 | # max_num_batched_tokens: 8192
18 | # max_seq_len_to_capture: 32768
19 | runtime_env:
20 | env_vars:
21 | VLLM_ATTENTION_BACKEND: "FLASH_ATTN"
22 | env_config:
23 | num_replicas: 2 # number of vllm replicas
24 | batch_size: 128 # ray pipeline internal batch size (used for map_batches call internally). Should usually be set to a value in [64, 128, 256] for best performance.
25 |
--------------------------------------------------------------------------------
/evals/scoring/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import Scorer
2 | from .gsm8k import GSM8KScorer
3 | from .math import MathEqualScorer, MathVerifyScorer
4 |
5 | __all__ = ["Scorer", "MathEqualScorer", "MathVerifyScorer", "GSM8KScorer"]
6 |
--------------------------------------------------------------------------------
/evals/scoring/apps/__init__.py:
--------------------------------------------------------------------------------
1 | from .apps_scorer import APPSScorer
2 |
3 | __all__ = ["APPSScorer"]
4 |
--------------------------------------------------------------------------------
/evals/scoring/apps/apps_scorer.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import json
3 | import multiprocessing
4 | from multiprocessing import Manager
5 | from typing import Any, Dict, List, Literal
6 |
7 | import numpy as np
8 | import ray
9 | from ray.exceptions import GetTimeoutError
10 |
11 | from ..base import Scorer
12 | from ...util.common import has_code
13 |
14 | from .apps_util import run_test as apps_run_test
15 |
16 |
17 | class APPSScorer(Scorer):
18 | """Scorer for the APPS dataset
19 |
20 | For the APPS dataset format, see https://huggingface.co/datasets/codeparrot/apps
21 |
22 | Args:
23 | response_column: The column name for the response (str).
24 | solutions_column: The column name with solutions (str).
25 | input_output_column: The column name with the test inputs and outputs (str).
26 |         backend: The backend to use for scoring. Supports "ray" or "mp" (str).
27 |             Each sample is scored in a separate Ray task or subprocess with a
28 |             per-sample timeout of `TIMEOUT` seconds.
29 | """
30 |
31 | SCORE_COLUMN = "apps_score"
32 | # timeout per sample
33 | TIMEOUT = 10
34 |
35 | def __init__(
36 | self,
37 | response_column="response",
38 | solutions_column="solutions",
39 | input_output_column="input_output",
40 | backend: Literal["mp", "ray"] = "ray",
41 | ) -> None:
42 | super().__init__()
43 | self.response_column = response_column
44 | self.solutions_column = solutions_column
45 | self.input_output_column = input_output_column
46 | self.backend = backend
47 | if self.backend not in ["mp", "ray"]:
48 | raise ValueError(f"Invalid backend for `APPSScorer`: {self.backend}")
49 |
50 | def score(self, row: Dict[str, Any]):
51 |
52 | code_filter_result = has_code(row[self.response_column])
53 | if len(code_filter_result) == 0:
54 | return {self.SCORE_COLUMN: False}
55 | else:
56 | last_code = code_filter_result[-1]
57 | problem_to_check = copy.deepcopy(row)
58 | problem_to_check[self.input_output_column] = json.loads(
59 | row[self.input_output_column]
60 | )
61 | try:
62 | problem_to_check[self.solutions_column] = json.loads(
63 | row[self.solutions_column]
64 | )
65 | except Exception:
66 | problem_to_check[self.solutions_column] = ""
67 |
68 | if self.backend == "ray":
69 | score = _run_test_ray(
70 | problem_to_check[self.input_output_column],
71 | last_code,
72 | self.TIMEOUT,
73 | False,
74 | )
75 | else:
76 | score = _run_test_mp(
77 | problem_to_check[self.input_output_column],
78 | last_code,
79 | self.TIMEOUT,
80 | False,
81 | )
82 | return {self.SCORE_COLUMN: score}
83 |
84 |
85 | # NOTE (sumanthrh): We make sure that scoring for code generation is run on a separate process for isolation
86 | # We need to run scoring for each data sample in a separate process. Since ray doesn't play well with
87 | # multiprocessing, we launch scoring as a standalone ray task. Further, to make sure that resource requests
88 | # don't blow up for batched processing (for example, in a Ray Data pipeline), we reduce `num_cpus` to 0.001 from the default
89 | # value of 1. That way, scoring for different samples can timeshare on the same set of cpus.
90 | @ray.remote(num_cpus=0.001)
91 | def _temp_run_ray(input_outputs, generation, debug) -> List[bool]:
92 | try:
93 | result: List[bool] = apps_run_test(input_outputs, test=generation, debug=debug)
94 | return result
95 | except Exception:
96 | pass
97 | return []
98 |
99 |
100 | def _run_test_ray(input_outputs, generation, timeout, debug):
101 | try:
102 | result = ray.get(
103 | _temp_run_ray.remote(input_outputs, generation, debug),
104 | timeout=timeout + 1,
105 | )
106 | except GetTimeoutError:
107 | result = []
108 | return bool(result and np.all(result))
109 |
110 |
111 | def _run_test_mp(input_outputs, generation, timeout, debug):
112 |     def _temp_run(input_outputs, generation, debug, result) -> None:
113 | try:
114 | result.append(
115 | apps_run_test(input_outputs=input_outputs, test=generation, debug=debug)
116 | )
117 | except Exception:
118 | pass
119 |
120 | manager = Manager()
121 | result = manager.list()
122 | p = multiprocessing.Process(
123 | target=_temp_run, args=(input_outputs, generation, False, result)
124 | )
125 | p.start()
126 | p.join(timeout=timeout + 1)
127 | if p.is_alive():
128 | p.kill()
129 | return bool(result and np.all(result[0]))
130 |
--------------------------------------------------------------------------------
/evals/scoring/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any, AsyncIterator, Dict, List
3 |
4 |
5 | class Scorer(ABC):
6 | """Abstract base class for scorers."""
7 |
8 | SCORE_COLUMN = "score"
9 |
10 | @abstractmethod
11 | def score(self, row: dict) -> Dict[str, Any]:
12 | """Scores a single row of data
13 |
14 | Args:
15 | row: A dictionary containing the data to score. (dict)
16 |
17 | Returns:
18 | A dictionary containing the score and any other relevant information.
19 | """
20 | pass
21 |
22 | def __call__(self, row: dict):
23 | return {**row, **self.score(row)}
24 |
25 |
26 | class BatchScorer(ABC):
27 | """
28 | Abstract base class for batch scorers.
29 | """
30 |
31 | SCORE_COLUMN = "score"
32 |
33 | INTERNAL_IDX_KEY = "__internal_idx__"
34 |
35 | @abstractmethod
36 | async def score(self, rows: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]:
37 | """Scores a batch of data
38 |
39 | Args:
40 | rows: list of input dictionaries. (list)
41 |
42 | Returns:
43 | An async iterator of dictionaries containing the score and any other relevant information.
44 | """
45 | pass
46 |
47 | async def __call__(self, batch: Dict[str, Any]) -> AsyncIterator[Dict[str, Any]]:
48 | """Scores a batch of data
49 |
50 | Yields results for each row in the batch as they finish.
51 |
52 | Args:
53 | batch: A dictionary containing the data to score. (dict)
54 |
55 | Returns:
56 | An async iterator of dictionaries containing the score and any other relevant information.
57 | """
58 | key = next(iter(batch.keys()))
59 | value = batch[key]
60 | num_rows = len(value)
61 | if hasattr(value, "tolist"):
62 | batch = {k: v.tolist() for k, v in batch.items()}
63 | else:
64 | batch = {k: list(v) for k, v in batch.items()}
65 | batch[self.INTERNAL_IDX_KEY] = list(range(num_rows))
66 | rows = [{k: batch[k][i] for k in batch.keys()} for i in range(num_rows)]
67 | async for result in self.score(rows):
68 | if self.INTERNAL_IDX_KEY not in result:
69 | raise ValueError(
70 | f"`score` function must yield dictionaries with the key {self.INTERNAL_IDX_KEY}"
71 | )
72 | idx = result[self.INTERNAL_IDX_KEY]
73 | row = rows[idx]
74 | yield {**row, **result}
75 |
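76 |
77 | if __name__ == "__main__":
78 |     # Usage sketch: the minimal `Scorer` contract. `score` returns a dict keyed by
79 |     # SCORE_COLUMN and `__call__` merges it back into the row; the column names and
80 |     # the scorer below are illustrative.
81 |     class ExactMatchScorer(Scorer):
82 |         SCORE_COLUMN = "exact_match"
83 |
84 |         def score(self, row: dict) -> Dict[str, Any]:
85 |             return {self.SCORE_COLUMN: row["response"].strip() == row["answer"].strip()}
86 |
87 |     print(ExactMatchScorer()({"response": " 42 ", "answer": "42"}))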
--------------------------------------------------------------------------------
/evals/scoring/gsm8k/__init__.py:
--------------------------------------------------------------------------------
1 | from .gsm8k_scorer import GSM8KScorer
2 |
3 | __all__ = ["GSM8KScorer"]
4 |
--------------------------------------------------------------------------------
/evals/scoring/gsm8k/gsm8k_scorer.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Any, Dict, List
3 |
4 | from ...util.math_parsing_util import extract_answer, math_equal
5 |
6 | from ..base import Scorer
7 |
8 |
9 | class GSM8KScorer(Scorer):
10 | """Scorer for GSM8K based on the `math_equal` function from Qwen Math
11 |
12 | Args:
13 | response_column: The column name for the model generated response.
14 | answer_column: The column name for the ground truth answer.
15 | """
16 |
17 | SCORE_COLUMN = "gsm8k_score"
18 | INVALID_ANS = "[invalid]"
19 | GT_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
20 | ANS_RE = re.compile(r"((-?[$0-9.,]{2,})|(-?[0-9]+))")
21 |
22 | def __init__(self, response_column: str, answer_column: str):
23 |
24 | self.response_column = response_column
25 | self.answer_column = answer_column
26 |
27 | def score(self, row: dict) -> Dict[str, Any]:
28 | try:
29 |             pred = self.extract_pred_from_response(row[self.response_column])
30 |             ref = self.extract_gt_answer(row[self.answer_column])
31 |         except Exception:
32 |             return {self.SCORE_COLUMN: False}
33 | return {
34 | self.SCORE_COLUMN: math_equal(pred, ref),
35 | }
36 |
37 | def extract_gt_answer(self, completion):
38 | match = self.GT_RE.search(completion)
39 | if match:
40 | match_str = match.group(1).strip()
41 | match_str = match_str.replace(",", "")
42 | return match_str
43 | else:
44 | return self.INVALID_ANS
45 |
46 | def extract_pred_from_response(self, response):
47 | answer = extract_answer(response)
48 |         answer = self.sanitize_answer(answer)
49 | return answer
50 |
51 | def sanitize_answer(self, answer):
52 | patterns_to_remove = [
53 | ",", # Remove commas
54 | r"\$", # Remove dollar signs
55 |             r"\.$", r"\*",  # Remove trailing period and asterisks
56 | ]
57 | for pattern in patterns_to_remove:
58 | answer = re.sub(pattern, "", answer)
59 |
60 | matches = self.ANS_RE.findall(answer)
61 | if matches:
62 | # get the last match (i.e final response) and the first / outer capturing group
63 | match_str = matches[-1][0].strip()
64 | return match_str
65 | else:
66 | return self.INVALID_ANS
67 |
68 | @property
69 | def expected_keys(self) -> List[str]:
70 | return [self.response_column, self.answer_column]
71 |
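72 |
73 | if __name__ == "__main__":
74 |     # Usage sketch: score a single row. The column names and the worked example are
75 |     # illustrative; the prediction is pulled from the response via `extract_answer`.
76 |     scorer = GSM8KScorer(response_column="response", answer_column="answer")
77 |     row = {
78 |         "response": "18 + 24 = 42, so the answer is 42.",
79 |         "answer": "18 + 24 = 42\n#### 42",
80 |     }
81 |     print(scorer.score(row))  # -> {"gsm8k_score": True} when extraction succeeds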
--------------------------------------------------------------------------------
/evals/scoring/ifeval/__init__.py:
--------------------------------------------------------------------------------
1 | from .ifeval_scorer import IfEvalScorer
2 |
3 | __all__ = ["IfEvalScorer"]
4 |
--------------------------------------------------------------------------------
/evals/scoring/ifeval/ifeval_scorer.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List
2 |
3 | from .instructions_main import (
4 | InputExample,
5 | test_instruction_following_loose,
6 | test_instruction_following_strict,
7 | )
8 |
9 | from ..base import Scorer
10 |
11 |
12 | def process_results(doc, response):
13 | inp = InputExample(
14 | key=doc["key"],
15 | instruction_id_list=doc["instruction_id_list"],
16 | prompt=doc["prompt"],
17 | kwargs=doc["kwargs"],
18 | )
19 |
20 | out_strict = test_instruction_following_strict(inp, response)
21 | out_loose = test_instruction_following_loose(inp, response)
22 |
23 | return {
24 | "prompt_level_strict_acc": out_strict.follow_all_instructions,
25 | "inst_level_strict_acc": out_strict.follow_instruction_list,
26 | "prompt_level_loose_acc": out_loose.follow_all_instructions,
27 | "inst_level_loose_acc": out_loose.follow_instruction_list,
28 | }
29 |
30 |
31 | class IfEvalScorer(Scorer):
32 | """Scorer for the IF-Eval task
33 |
34 | For the IFEval dataset format, see https://huggingface.co/datasets/google/IFEval
35 |
36 | Args:
37 | instruction_ids_column: The column name for the list of instruction ids (str).
38 | prompt_column: The column name for the prompt (str).
39 | keyword_args_column: The column name for the keyword arguments to the instruction builder (str).
40 | key_column: The column name for the unique identifier (str).
41 | response_column: The column name for the response (str).
42 | """
43 |
44 | SCORE_COLUMN = "ifeval_score"
45 |
46 | def __init__(
47 | self,
48 | instruction_ids_column: str = "instruction_id_list",
49 | prompt_column: str = "prompt",
50 | keyword_args_column: str = "kwargs",
51 | key_column: str = "key",
52 | response_column: str = "response",
53 | ):
54 | self.instruction_ids_column = instruction_ids_column
55 | self.prompt_column, self.keyword_args_column = prompt_column, keyword_args_column
56 | self.key_column, self.response_column = key_column, response_column
57 | def score(self, row: dict) -> Dict[str, Any]:
58 | return {self.SCORE_COLUMN: process_results(row, row[self.response_column])}
59 |
60 | @property
61 | def expected_keys(self) -> List[str]:
62 | return [
63 | self.instruction_ids_column,
64 | self.prompt_column,
65 | self.keyword_args_column,
66 | self.key_column,
67 | self.response_column,
68 | ]
69 |
--------------------------------------------------------------------------------
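A minimal usage sketch for the scorer above. The row mirrors the IFEval dataset format referenced in the docstring; the instruction id is taken from `instructions_registry.py` below, and the concrete values are illustrative:

# Hypothetical usage example; assumes the package is importable as `evals`.
from evals.scoring.ifeval.ifeval_scorer import IfEvalScorer

scorer = IfEvalScorer()
row = {
    "key": 0,
    "instruction_id_list": ["punctuation:no_comma"],
    "prompt": "Describe the sky without using any commas.",
    "kwargs": [{}],  # no extra arguments for this instruction
    "response": "The sky is a deep blue that fades to pale gold near the horizon.",
}
print(scorer.score(row))
# expected: {"ifeval_score": {"prompt_level_strict_acc": True, ...}}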
/evals/scoring/ifeval/instructions_main.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from typing import Dict, Optional, Union
3 |
4 | from . import instructions_registry
5 |
6 |
7 | @dataclasses.dataclass
8 | class InputExample:
9 | key: int
10 | instruction_id_list: list[str]
11 | prompt: str
12 | kwargs: list[Dict[str, Optional[Union[str, int]]]]
13 |
14 |
15 | @dataclasses.dataclass
16 | class OutputExample:
17 | instruction_id_list: list[str]
18 | prompt: str
19 | response: str
20 | follow_all_instructions: bool
21 | follow_instruction_list: list[bool]
22 |
23 |
24 | def test_instruction_following_strict(
25 | inp,
26 | response,
27 | ):
28 | """Tests response to see if instructions are followed."""
29 | instruction_list = inp.instruction_id_list
30 | is_following_list = []
31 |
32 | for index, instruction_id in enumerate(instruction_list):
33 | instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
34 | instruction = instruction_cls(instruction_id)
35 |
36 | # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method.
37 | kwargs = {k: v for k, v in inp.kwargs[index].items() if v}
38 | instruction.build_description(**kwargs)
39 | args = instruction.get_instruction_args()
40 | if args and "prompt" in args:
41 | instruction.build_description(prompt=inp.prompt)
42 |
43 | if response.strip() and instruction.check_following(response):
44 | is_following_list.append(True)
45 | else:
46 | is_following_list.append(False)
47 |
48 | return OutputExample(
49 | instruction_id_list=inp.instruction_id_list,
50 | prompt=inp.prompt,
51 | response=response,
52 | follow_all_instructions=all(is_following_list),
53 | follow_instruction_list=is_following_list,
54 | )
55 |
56 |
57 | def test_instruction_following_loose(
58 | inp,
59 | response,
60 | ):
61 | """Tests response for an upper bound for following instructions."""
62 | r = response.split("\n")
63 | response_remove_first = "\n".join(r[1:]).strip()
64 | response_remove_last = "\n".join(r[:-1]).strip()
65 | response_remove_both = "\n".join(r[1:-1]).strip()
66 | revised_response = response.replace("*", "")
67 | revised_response_remove_first = response_remove_first.replace("*", "")
68 | revised_response_remove_last = response_remove_last.replace("*", "")
69 | revised_response_remove_both = response_remove_both.replace("*", "")
70 | all_responses = [
71 | response,
72 | revised_response,
73 | response_remove_first,
74 | response_remove_last,
75 | response_remove_both,
76 | revised_response_remove_first,
77 | revised_response_remove_last,
78 | revised_response_remove_both,
79 | ]
80 | instruction_list = inp.instruction_id_list
81 | is_following_list = []
82 |
83 | for index, instruction_id in enumerate(instruction_list):
84 | instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
85 | instruction = instruction_cls(instruction_id)
86 |
87 | # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method.
88 | kwargs = {k: v for k, v in inp.kwargs[index].items() if v}
89 | instruction.build_description(**kwargs)
90 | args = instruction.get_instruction_args()
91 | if args and "prompt" in args:
92 | instruction.build_description(prompt=inp.prompt)
93 |
94 | is_following = False
95 | for r in all_responses:
96 | if r.strip() and instruction.check_following(r):
97 | is_following = True
98 | break
99 |
100 | is_following_list.append(is_following)
101 |
102 | return OutputExample(
103 | instruction_id_list=inp.instruction_id_list,
104 | prompt=inp.prompt,
105 | response=response,
106 | follow_all_instructions=all(is_following_list),
107 | follow_instruction_list=is_following_list,
108 | )
109 |
110 |
111 | def agg_inst_level_acc(items):
112 | flat_items = [item for sublist in items for item in sublist]
113 | inst_level_acc = sum(flat_items) / len(flat_items)
114 | return inst_level_acc
115 |
--------------------------------------------------------------------------------
/evals/scoring/ifeval/instructions_registry.py:
--------------------------------------------------------------------------------
1 | """
2 | IFEval scoring functions from Google's source code: https://github.com/google-research/google-research/blob/master/instruction_following_eval/instruction_following_eval.py
3 | """
4 |
5 | # coding=utf-8
6 | # Copyright 2024 The Google Research Authors.
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 |
20 | from . import instructions
21 |
22 | _KEYWORD = "keywords:"
23 |
24 | _LANGUAGE = "language:"
25 |
26 | _LENGTH = "length_constraints:"
27 |
28 | _CONTENT = "detectable_content:"
29 |
30 | _FORMAT = "detectable_format:"
31 |
32 | _MULTITURN = "multi-turn:"
33 |
34 | _COMBINATION = "combination:"
35 |
36 | _STARTEND = "startend:"
37 |
38 | _CHANGE_CASES = "change_case:"
39 |
40 | _PUNCTUATION = "punctuation:"
41 |
42 | INSTRUCTION_DICT = {
43 | _KEYWORD + "existence": instructions.KeywordChecker,
44 | _KEYWORD + "frequency": instructions.KeywordFrequencyChecker,
45 | # TODO(jeffreyzhou): make a proper set of sentences to choose from
46 | # _KEYWORD + "key_sentences": instructions.KeySentenceChecker,
47 | _KEYWORD + "forbidden_words": instructions.ForbiddenWords,
48 | _KEYWORD + "letter_frequency": instructions.LetterFrequencyChecker,
49 | _LANGUAGE + "response_language": instructions.ResponseLanguageChecker,
50 | _LENGTH + "number_sentences": instructions.NumberOfSentences,
51 | _LENGTH + "number_paragraphs": instructions.ParagraphChecker,
52 | _LENGTH + "number_words": instructions.NumberOfWords,
53 | _LENGTH + "nth_paragraph_first_word": instructions.ParagraphFirstWordCheck,
54 | _CONTENT + "number_placeholders": instructions.PlaceholderChecker,
55 | _CONTENT + "postscript": instructions.PostscriptChecker,
56 | _FORMAT + "number_bullet_lists": instructions.BulletListChecker,
57 | # TODO(jeffreyzhou): Pre-create paragraph or use prompt to replace
58 | # _CONTENT + "rephrase_paragraph": instructions.RephraseParagraph,
59 | _FORMAT + "constrained_response": instructions.ConstrainedResponseChecker,
60 | _FORMAT + "number_highlighted_sections": (instructions.HighlightSectionChecker),
61 | _FORMAT + "multiple_sections": instructions.SectionChecker,
62 | # TODO(tianjianlu): Re-enable rephrasing with preprocessing the message.
63 | # _FORMAT + "rephrase": instructions.RephraseChecker,
64 | _FORMAT + "json_format": instructions.JsonFormat,
65 | _FORMAT + "title": instructions.TitleChecker,
66 | # TODO(tianjianlu): Re-enable with specific prompts.
67 | # _MULTITURN + "constrained_start": instructions.ConstrainedStartChecker,
68 | _COMBINATION + "two_responses": instructions.TwoResponsesChecker,
69 | _COMBINATION + "repeat_prompt": instructions.RepeatPromptThenAnswer,
70 | _STARTEND + "end_checker": instructions.EndChecker,
71 | _CHANGE_CASES + "capital_word_frequency": instructions.CapitalWordFrequencyChecker,
72 | _CHANGE_CASES + "english_capital": instructions.CapitalLettersEnglishChecker,
73 | _CHANGE_CASES + "english_lowercase": instructions.LowercaseLettersEnglishChecker,
74 | _PUNCTUATION + "no_comma": instructions.CommaChecker,
75 | _STARTEND + "quotation": instructions.QuotationChecker,
76 | }
77 |
--------------------------------------------------------------------------------
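The registry above can also be used directly. A small sketch, assuming the checker behind `punctuation:no_comma` needs no extra build arguments:

# Hypothetical usage example; assumes the package is importable as `evals`.
from evals.scoring.ifeval import instructions_registry

checker_cls = instructions_registry.INSTRUCTION_DICT["punctuation:no_comma"]
checker = checker_cls("punctuation:no_comma")
checker.build_description()  # assumed to take no required arguments for this instruction
print(checker.check_following("No commas appear in this sentence"))  # expected: True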
/evals/scoring/livecodebench/__init__.py:
--------------------------------------------------------------------------------
1 | from .livecodebench_scorer import LiveCodeBenchScorer
2 |
3 | __all__ = ["LiveCodeBenchScorer"]
4 |
--------------------------------------------------------------------------------
/evals/scoring/livecodebench/livecodebench_scorer.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import copy
3 | from typing import Any, AsyncIterator, Dict, List, Literal, Tuple
4 |
5 | from ...util.common import has_code
6 |
7 | from ..base import BatchScorer, Scorer
8 | from .livecodebench_util import (
9 | _ray_wrapper,
10 | has_test_type,
11 | post_process_code,
12 | unsafe_lcb_runTests_mp,
13 | unsafe_lcb_runTests_ray,
14 | )
15 |
16 |
17 | class LiveCodeBenchScorer(Scorer):
18 | """Scorer for LiveCodeBench
19 |
20 | For the LiveCodeBench dataset format, see https://huggingface.co/datasets/livecodebench/code_generation_lite
21 |
22 | Args:
23 | question_content_column: The column name for the question (str).
24 | private_test_cases_column: The column name for the private test cases (str).
25 | public_test_cases_column: The column name for the public test cases (str).
26 | starter_code_column: The column name for the starter code (str).
27 | difficulty_column: The column name for the difficulty level (str).
28 | question_id_column: The column name for the question id (str).
29 | response_column: The column name for the response (str).
30 | backend: The backend to use for scoring. Supports "ray" or "mp" (str).
31 | """
32 |
33 | TIMEOUT = 6
34 | SCORE_COLUMN = "livecodebench_score"
35 |
36 | def __init__(
37 | self,
38 | question_content_column: str = "question_content",
39 | private_test_cases_column: str = "private_test_cases",
40 | public_test_cases_column: str = "public_test_cases",
41 | starter_code_column: str = "starter_code",
42 | difficulty_column: str = "difficulty",
43 | question_id_column: str = "question_id",
44 | response_column: str = "response",
45 | backend: Literal["ray", "mp"] = "ray",
46 | ):
47 |
48 | self.question_content_column = question_content_column
49 | self.private_test_cases_column = private_test_cases_column
50 | self.public_test_cases_column = public_test_cases_column
51 | self.starter_code_column = starter_code_column
52 | self.difficulty_column = difficulty_column
53 | self.question_id_column = question_id_column
54 | self.response_column = response_column
55 | self.backend = backend
56 |
57 | def score(self, row: dict) -> Dict[str, Any]:
58 | row = self.map_to_example(row)
59 |
60 | code_filter_result = has_code(row[self.response_column])
61 | last_code = None
62 | if len(code_filter_result) == 0:
63 | return {self.SCORE_COLUMN: False}
64 | else:
65 | last_code = code_filter_result[-1]
66 | problem_to_check = copy.deepcopy(row)
67 |
68 | if self.backend == "ray":
69 | result_list = unsafe_lcb_runTests_ray(
70 | problem_to_check,
71 | post_process_code(last_code),
72 | self.TIMEOUT,
73 | runtime_debug=False,
74 | is_extracted=not row["is_stdin"],
75 | )
76 | else:
77 | result_list = unsafe_lcb_runTests_mp(
78 | problem_to_check,
79 | post_process_code(last_code),
80 | self.TIMEOUT,
81 | runtime_debug=False,
82 | is_extracted=not row["is_stdin"],
83 | )
84 | details = [r[0] for r in result_list]
85 | all_passed = all(details)
86 |
87 | result = ""
88 | if result_list and all_passed:
89 | result = "passed"
90 |
91 | return {self.SCORE_COLUMN: result == "passed"}
92 |
93 | @property
94 | def expected_keys(self) -> List[str]:
95 | return [
96 | self.question_content_column,
97 | self.private_test_cases_column,
98 | self.public_test_cases_column,
99 | self.difficulty_column,
100 | self.question_id_column,
101 | self.starter_code_column,
102 | self.response_column,
103 | ]
104 |
105 | def map_to_example(self, row):
106 | return {
107 | "prompt": row[self.question_content_column],
108 | "test": row[self.private_test_cases_column],
109 | "entry_point": row[self.starter_code_column],
110 | "canonical_solution": "", # seems like live code bench lite does not have this field
111 | "task_id": row[self.question_id_column],
112 | "is_stdin": has_test_type(row[self.public_test_cases_column], "stdin"),
113 | "public_test_cases": row[self.public_test_cases_column],
114 | "difficulty": row[self.difficulty_column],
115 | self.response_column: row[self.response_column],
116 | }
117 |
118 |
119 | class LiveCodeBenchBatchScorer(BatchScorer):
120 | """Batch scorer for LiveCodeBench
121 |
122 | For the LiveCodeBench dataset format, see https://huggingface.co/datasets/livecodebench/code_generation_lite
123 |
124 | Args:
125 | question_content_column: The column name for the question (str).
126 | private_test_cases_column: The column name for the private test cases (str).
127 | public_test_cases_column: The column name for the public test cases (str).
128 | starter_code_column: The column name for the starter code (str).
129 | difficulty_column: The column name for the difficulty level (str).
130 | question_id_column: The column name for the question id (str).
131 | response_column: The column name for the response (str).
132 | """
133 |
134 | TIMEOUT = 6
135 | SCORE_COLUMN = "livecodebench_score"
136 |
137 | def __init__(
138 | self,
139 | question_content_column: str = "question_content",
140 | private_test_cases_column: str = "private_test_cases",
141 | public_test_cases_column: str = "public_test_cases",
142 | starter_code_column: str = "starter_code",
143 | difficulty_column: str = "difficulty",
144 | question_id_column: str = "question_id",
145 | response_column: str = "response",
146 | ):
147 | self.question_content_column = question_content_column
148 | self.private_test_cases_column = private_test_cases_column
149 | self.public_test_cases_column = public_test_cases_column
150 | self.starter_code_column = starter_code_column
151 | self.difficulty_column = difficulty_column
152 | self.question_id_column = question_id_column
153 | self.response_column = response_column
154 |
155 | async def score(self, rows: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]:
156 |
157 | inputs = []
158 | ids = []
159 | for row in rows:
160 | row = self.map_to_example(row)
161 | code_filter_result = has_code(row[self.response_column])
162 | last_code = None
163 | if len(code_filter_result) == 0:
164 | yield {
165 | self.INTERNAL_IDX_KEY: row[self.INTERNAL_IDX_KEY],
166 | self.SCORE_COLUMN: False,
167 | }
168 | else:
169 | last_code = code_filter_result[-1]
170 | problem_to_check = copy.deepcopy(row)
171 |
172 | inputs.append(
173 | {
174 | "problem": problem_to_check,
175 | "completion": post_process_code(last_code),
176 | "timeout": self.TIMEOUT,
177 | "runtime_debug": False,
178 | "is_extracted": row["is_stdin"],
179 | }
180 | )
181 | ids.append(row[self.INTERNAL_IDX_KEY])
182 |
183 | async for output in _unsafe_lcb_runTests_ray_batch(ids, inputs):
184 | idx, result_list = output
185 | details = [r[0] for r in result_list]
186 | all_passed = all(details)
187 |
188 | result = ""
189 | if result_list and all_passed:
190 | result = "passed"
191 |
192 | yield {
193 | self.INTERNAL_IDX_KEY: idx,
194 | self.SCORE_COLUMN: result == "passed",
195 | }
196 |
197 | def map_to_example(self, row):
198 | return {
199 | "prompt": row[self.question_content_column],
200 | "test": row[self.private_test_cases_column],
201 | "entry_point": row[self.starter_code_column],
202 | "canonical_solution": "", # seems like live code bench lite does not have this field
203 | "task_id": row[self.question_id_column],
204 | "is_stdin": has_test_type(row[self.public_test_cases_column], "stdin"),
205 | "public_test_cases": row[self.public_test_cases_column],
206 | "difficulty": row[self.difficulty_column],
207 | self.response_column: row[self.response_column],
208 | self.INTERNAL_IDX_KEY: row[self.INTERNAL_IDX_KEY],
209 | }
210 |
211 |
212 | async def _unsafe_lcb_runTests_ray_batch(
213 | ids, inputs
214 | ) -> AsyncIterator[Tuple[int, List[Tuple[bool, str, str, float]]]]:
215 | refs = []
216 | for idx, _input in zip(ids, inputs):
217 | problem = _input["problem"]
218 | completion = _input["completion"]
219 | timeout = _input["timeout"]
220 | runtime_debug = _input["runtime_debug"]
221 | is_extracted = _input["is_extracted"]
222 | test_cases = problem["test"]
223 |
224 | result_ref = _ray_wrapper.remote(
225 | test_cases, completion, timeout, runtime_debug, is_extracted, idx
226 | )
227 | refs.append(result_ref)
228 |
229 | futs = [asyncio.wrap_future(ref.future()) for ref in refs]
230 | for fut in asyncio.as_completed(futs):
231 | idx, result = await fut
232 | _input = inputs[ids.index(idx)]
233 | # Pad with failures for test cases that did not finish within the given timeout
234 | for _i in range(len(_input["problem"]["test"]) - len(result)):
235 | result.append((False, "Time out!.", "Error: Time out!", float("inf")))
236 | yield idx, result
237 |
--------------------------------------------------------------------------------
/evals/scoring/math/__init__.py:
--------------------------------------------------------------------------------
1 | from .math_scorer import MathEqualScorer, MathVerifyScorer
2 |
3 | __all__ = ["MathVerifyScorer", "MathEqualScorer"]
4 |
--------------------------------------------------------------------------------
/evals/scoring/math/math_scorer.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List
2 |
3 | from ...util.math_parsing_util import extract_answer, math_equal
4 |
5 | from ..base import Scorer
6 |
7 | try:
8 | from math_verify import parse as mv_parse
9 | from math_verify import verify as mv_verify
10 | except ImportError:
11 | mv_parse = None
12 | mv_verify = None
13 |
14 |
15 | class MathEqualScorer(Scorer):
16 | """Scorer for math based on the `math_equal` function from Qwen Math
17 |
18 | Args:
19 | response_column: The column name for the model generated response. (str)
20 | answer_column: The column name for the ground truth answer. (str)
21 | """
22 |
23 | SCORE_COLUMN = "math_equal_score"
24 |
25 | def __init__(self, response_column: str, answer_column: str):
26 | self.response_column = response_column
27 | self.answer_column = answer_column
28 |
29 | def score(self, row: dict) -> Dict[str, Any]:
30 | try:
31 | pred = extract_answer(row[self.response_column])
32 | ref = extract_answer(row[self.answer_column])
33 | except Exception:
34 | return {self.SCORE_COLUMN: False}
35 | return {self.SCORE_COLUMN: math_equal(pred, ref)}
36 |
37 | @property
38 | def expected_keys(self) -> List[str]:
39 | return [self.response_column, self.answer_column]
40 |
41 |
42 | class MathVerifyScorer(Scorer):
43 | """Scorer for math based on the `math_verify` function from HuggingFace
44 |
45 | Args:
46 | response_column: The column name for the model generated response. (str)
47 | answer_column: The column name for the ground truth answer. (str)
48 | """
49 |
50 | SCORE_COLUMN = "math_verify_score"
51 |
52 | def __init__(self, response_column: str, answer_column: str):
53 | self.response_column = response_column
54 | self.answer_column = answer_column
55 | if mv_parse is None or mv_verify is None:
56 | raise ImportError(
57 | "`math_verify` is not installed. Please install it with `pip install math_verify`."
58 | )
59 |
60 | def score(self, row: dict) -> Dict[str, Any]:
61 | try:
62 | pred = mv_parse(row[self.response_column])
63 | ref = mv_parse(row[self.answer_column])
64 | except Exception:
65 | return {self.SCORE_COLUMN: False}
66 | return {self.SCORE_COLUMN: mv_verify(pred, ref)}
67 |
68 | @property
69 | def expected_keys(self) -> List[str]:
70 | return [self.response_column, self.answer_column]
71 |
--------------------------------------------------------------------------------
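A minimal usage sketch for `MathEqualScorer`; the import path and example values are illustrative:

# Hypothetical usage example; assumes the package is importable as `evals`.
from evals.scoring.math.math_scorer import MathEqualScorer

scorer = MathEqualScorer(response_column="response", answer_column="answer")
row = {
    "response": r"... so the result is \boxed{42}.",
    "answer": "42",
}
print(scorer.score(row))  # expected to yield {"math_equal_score": True}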
/evals/scoring/taco/__init__.py:
--------------------------------------------------------------------------------
1 | from .taco_scorer import TACOScorer
2 |
3 | __all__ = ["TACOScorer"]
4 |
--------------------------------------------------------------------------------
/evals/scoring/taco/taco_scorer.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | from multiprocessing import Manager
3 | from typing import Any, Dict, Literal
4 |
5 | import ray
6 |
7 | from ...util.common import has_code
8 |
9 | from ..base import Scorer
10 | from .taco_util import run_test as taco_run_test
11 |
12 |
13 | class TACOScorer(Scorer):
14 | SCORE_COLUMN = "taco_score"
15 |
16 | def __init__(
17 | self,
18 | response_column="response",
19 | input_output_column="input_output",
20 | backend: Literal["ray", "mp"] = "ray",
21 | ) -> None:
22 | super().__init__()
23 | self.response_column = response_column
24 | self.input_output_column = input_output_column
25 | self.backend = backend
26 | if backend not in ["ray", "mp"]:
27 | raise ValueError(f"Unsupported backend for launching tests: {backend}")
28 |
29 | def score(self, row: Dict[str, Any]):
30 | # Initialize the response structure
31 | response = row[self.response_column]
32 | input_outputs = row[self.input_output_column]
33 |
34 | code_filter_result = has_code(response)
35 | if len(code_filter_result) == 0:
36 | return {self.SCORE_COLUMN: False}
37 | else:
38 | last_code = code_filter_result[-1]
39 | if self.backend == "mp":
40 | curr_res, _ = _taco_run_tests_mp(input_outputs, generation=last_code)
41 | else:
42 | curr_res, _ = _taco_run_tests_ray(input_outputs, generation=last_code)
43 |
44 | if curr_res:
45 | return {self.SCORE_COLUMN: True}
46 | else:
47 | return {self.SCORE_COLUMN: False}
48 |
49 |
50 | def _taco_run_tests_mp(input_outputs, generation):
51 |
52 | def _temp_run(input_outputs, generation, debug, result):
53 | try:
54 | result.append(taco_run_test(input_outputs, test=generation, debug=debug))
55 | except Exception as e:
56 | print(f"Error in _temp_run: {e}")
57 |
58 | # run the test in a separate process for safety
59 | manager = Manager()
60 | result = manager.list()
61 | p = multiprocessing.Process(
62 | target=_temp_run, args=(input_outputs, generation, False, result)
63 | )
64 | p.start()
65 | p.join()
66 | if p.is_alive():
67 | p.kill()
68 | # get the first element in ListProxy - this is the result
69 | result = result[0] if result else []
70 | return bool(result and all(result)), result
71 |
72 |
73 | # NOTE (sumanthrh): We make sure that scoring for code generation is run on a separate process for isolation
74 | # We need to run scoring for each data sample in a separate process. Since ray doesn't play well with
75 | # multiprocessing, we launch scoring as a standalone ray task. Further, to make sure that resource requests
76 | # don't blow up for batched processing- for example, in a ray data pipeline, we reduce `num_cpus` to 0.001 from the default
77 | # value of 1. That way, scoring for different samples can timeshare on the same set of cpus.
78 | @ray.remote(num_cpus=0.001)
79 | def _temp_run_ray(input_outputs, generation, debug):
80 | result = []
81 | try:
82 | result = taco_run_test(input_outputs, test=generation, debug=debug)
83 | except Exception as e:
84 | print(f"Error in _temp_run: {e}")
85 | return result
86 |
87 |
88 | def _taco_run_tests_ray(input_outputs, generation):
89 | # run the test in a separate process for safety
90 | obj_ref = _temp_run_ray.remote(input_outputs, generation, False)
91 | result = ray.get(obj_ref)
92 | return bool(result and all(result)), result
93 |
--------------------------------------------------------------------------------
/evals/scoring/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nanomaoli/llm_reproducibility/8a373c5a159a27e59783394827cecadd6255484e/evals/scoring/utils/__init__.py
--------------------------------------------------------------------------------
/evals/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .aime.aime_handler import AIMETaskHandler
4 | from .amc23.amc23_handler import AMC23TaskHandler
5 | from .apps.apps_handler import APPSTaskHandler
6 | from .arc.arc_handler import ARCChallengeTaskHandler
7 | from .base import ConversationType, TaskConfig, TaskHandler
8 | from .gpqa_diamond.gpqa_diamond_handler import GPQADiamondTaskHandler
9 | from .gsm8k.gsm8k_handler import GSM8KTaskHandler
10 | from .liveaops.liveaops_handler import LiveAOPSTaskHandler
11 | from .livecodebench.livecodebench_handler import LiveCodeBenchTaskHandler
12 | from .math.math_handler import MathTaskHandler
13 | from .minervamath.minervamath_handler import MinervaMathTaskHandler
14 | from .mmlu.mmlu_handler import MMLUProTaskHandler, MMLUTaskHandler
15 | from .numina.numina_handler import NUMINATaskHandler
16 | from .olympiadbench.olympiadbench_handler import OlympiadBenchMathTaskHandler
17 | from .omni_math.omni_handler import OMNIMathTaskHandler
18 | from .taco.taco_handler import TACOTaskHandler
19 | from .task_util import get_tasks
20 |
21 | TASK_HANDLER_MAP = {
22 | "numina": NUMINATaskHandler,
23 | "apps": APPSTaskHandler,
24 | "taco": TACOTaskHandler,
25 | "math": MathTaskHandler,
26 | "aime": AIMETaskHandler,
27 | "gpqa_diamond": GPQADiamondTaskHandler,
28 | "mmlu": MMLUTaskHandler,
29 | "mmlu_pro": MMLUProTaskHandler,
30 | "livecodebench": LiveCodeBenchTaskHandler,
31 | "gsm8k": GSM8KTaskHandler,
32 | "arc_c": ARCChallengeTaskHandler,
33 | "amc23": AMC23TaskHandler,
34 | "minervamath": MinervaMathTaskHandler,
35 | "olympiadbench_math": OlympiadBenchMathTaskHandler,
36 | "omni_math": OMNIMathTaskHandler,
37 | "liveaops": LiveAOPSTaskHandler,
38 | }
39 | TASK_NAMES_TO_YAML = get_tasks(os.path.dirname(__file__))
40 |
41 | __all__ = [
42 | "AIMETaskHandler",
43 | "APPSTaskHandler",
44 | "TACOTaskHandler",
45 | "MathTaskHandler",
46 | "AMC23TaskHandler",
47 | "NUMINATaskHandler",
48 | "GPQADiamondTaskHandler",
49 | "MMLUTaskHandler",
50 | "MMLUProTaskHandler",
51 | "LiveCodeBenchTaskHandler",
52 | "GSM8KTaskHandler",
53 | "ARCChallengeTaskHandler",
54 | "TaskHandler",
55 | "MathTaskHandler",
56 | "OlympiadBenchMathTaskHandler",
57 | "MinervaMathTaskHandler",
58 | "TaskConfig",
59 | "TASK_HANDLER_MAP",
60 | "TASK_NAMES_TO_YAML",
61 | "ConversationType",
62 | ]
63 |
--------------------------------------------------------------------------------
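A sketch of how these exports are typically wired together. It assumes `TASK_NAMES_TO_YAML` maps task names to their YAML config paths (per `get_tasks`); the concrete task and values are illustrative:

# Hypothetical usage example; assumes the package is importable as `evals`.
from evals.tasks import TASK_HANDLER_MAP, TASK_NAMES_TO_YAML

handler_cls = TASK_HANDLER_MAP["gsm8k"]  # GSM8KTaskHandler
handler = handler_cls.from_config_path(TASK_NAMES_TO_YAML["gsm8k"])  # assumes the map stores YAML paths
df = handler.load_and_filter_dataset(start=0, end=8)  # first 8 rows as a pandas DataFrame
conversations = handler.make_conversations(
    [row.to_dict() for _, row in df.iterrows()],
    system_prompt="You are a helpful assistant.",
)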
/evals/tasks/aime/aime24.yaml:
--------------------------------------------------------------------------------
1 | handler: aime
2 | dataset_path: AI-MO/aimo-validation-aime
3 | dataset_split: train
4 | question_key: problem
5 | answer_key: answer
6 | templating_parameters:
7 | template: "Return your final response within \\boxed{{}}. {prompt}"
8 | preprocess_config:
9 | url: "2024"
10 |
--------------------------------------------------------------------------------
/evals/tasks/aime/aime24_sky.yaml:
--------------------------------------------------------------------------------
1 | handler: aime
2 | dataset_path: AI-MO/aimo-validation-aime
3 | dataset_split: train
4 | question_key: problem
5 | answer_key: answer
6 | templating_parameters:
7 | template: "{prompt}\nReturn your final response within \\boxed{{}}"
8 | preprocess_config:
9 | url: "2024"
--------------------------------------------------------------------------------
/evals/tasks/aime/aime25_1.yaml:
--------------------------------------------------------------------------------
1 | handler: aime
2 | dataset_path: opencompass/AIME2025
3 | dataset_subset: AIME2025-I
4 | dataset_split: test
5 | question_key: question
6 | answer_key: answer
7 | templating_parameters:
8 | template: "{prompt}\nReturn your final response within \\boxed{{}}"
9 |
10 |
--------------------------------------------------------------------------------
/evals/tasks/aime/aime25_2.yaml:
--------------------------------------------------------------------------------
1 | handler: aime
2 | dataset_path: opencompass/AIME2025
3 | dataset_subset: AIME2025-II
4 | dataset_split: test
5 | question_key: question
6 | answer_key: answer
7 | templating_parameters:
8 | template: "{prompt}\nReturn your final response within \\boxed{{}}"
9 |
10 |
--------------------------------------------------------------------------------
/evals/tasks/aime/aime_handler.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | from ..math.math_handler import MathTaskHandler
4 |
5 |
6 | class AIMETaskHandler(MathTaskHandler):
7 | def generate_prompt(self, problem: Dict):
8 | return self.task_config.templating_parameters["template"].format(
9 | prompt=problem[self.question_key]
10 | )
11 |
12 | def load_and_filter_dataset(
13 | self, start, end, split=None, subset=None, difficulty=None
14 | ):
15 | train_data = self.load_dataset(subset=subset, split=split).to_pandas()
16 | if self.task_config.preprocess_config:
17 | if "url" in self.task_config.preprocess_config:
18 | train_data = train_data[
19 | train_data["url"].str.contains(
20 | self.task_config.preprocess_config["url"], na=False
21 | )
22 | ]
23 | return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:]
24 |
--------------------------------------------------------------------------------
/evals/tasks/amc23/amc23.yaml:
--------------------------------------------------------------------------------
1 | handler: amc23
2 | dataset_path: AI-MO/aimo-validation-amc
3 | dataset_kwargs:
4 | trust_remote_code: true
5 | dataset_split: train
6 | question_key: problem
7 | answer_key: answer
8 | # Optionally, you can filter the dataset by difficulty
9 | # preprocess_config:
10 | # difficulty: easy
11 | templating_parameters:
12 | template: "Return your final response within \\boxed{{}}. {problem}"
13 |
--------------------------------------------------------------------------------
/evals/tasks/amc23/amc23_handler.py:
--------------------------------------------------------------------------------
1 | from ..math.math_handler import MathTaskHandler
2 |
3 |
4 | class AMC23TaskHandler(MathTaskHandler):
5 | def load_and_filter_dataset(
6 | self, start, end, split=None, subset=None, difficulty=None
7 | ):
8 | train_data = self.load_dataset(subset=subset, split=split).to_pandas()
9 | filtered_data = train_data[train_data["url"].str.contains("2023", na=False)]
10 | return filtered_data.iloc[start:end] if end > 0 else filtered_data.iloc[start:]
11 |
--------------------------------------------------------------------------------
/evals/tasks/apps/apps.yaml:
--------------------------------------------------------------------------------
1 | handler: apps
2 | dataset_path: codeparrot/apps
3 | dataset_subset: all
4 | dataset_kwargs:
5 | trust_remote_code: true
6 | dataset_split: test
7 | question_key: question
8 | answer_key: null
9 | # preprocess_config:
10 | # difficulty: null
11 | templating_parameters:
12 | with_fn_name_template: "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition. {prompt}"
13 | without_fn_name_template: "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution. {prompt}"
14 | # Add starter code on top of the initial template
15 | with_starter_code_template: "{input}\n{starter_code}"
16 | # Optionally, you can filter the dataset by difficulty
17 | # preprocess_config:
18 | # difficulty: easy
19 |
--------------------------------------------------------------------------------
/evals/tasks/apps/apps_handler.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import json
3 | import multiprocessing
4 | from multiprocessing import Manager
5 |
6 | import numpy as np
7 |
8 | from ...util.common import has_code
9 |
10 | from .apps_util import run_test as apps_run_test
11 | from ..base import TaskHandler
12 |
13 |
14 | class APPSTaskHandler(TaskHandler):
15 |
16 | def generate_prompt(self, problem):
17 | # test_case, prompt, starter_code=None
18 | test_case = json.loads(problem["input_output"])
19 | starter_code = problem["starter_code"]
20 | prompt = problem["question"]
21 | if not test_case.get("fn_name"):
22 | _input = self.task_config.templating_parameters[
23 | "with_fn_name_template"
24 | ].format(prompt=prompt)
25 | else:
26 | _input = self.task_config.templating_parameters[
27 | "without_fn_name_template"
28 | ].format(prompt=prompt)
29 |
30 | if starter_code is not None:
31 | _input = self.task_config.templating_parameters[
32 | "with_starter_code_template"
33 | ].format(input=_input, starter_code=starter_code)
34 | return _input
35 |
36 | def check_correctness(self, problem, generation):
37 | TIMEOUT = 10
38 |
39 | def _temp_run(problem, generation, debug, result):
40 | try:
41 | result.append(
42 | apps_run_test(problem=problem, test=generation, debug=debug)
43 | )
44 | except Exception:
45 | pass
46 |
47 | manager = Manager()
48 | result = manager.list()
49 | p = multiprocessing.Process(
50 | target=_temp_run, args=(problem, generation, False, result)
51 | )
52 | p.start()
53 | p.join(timeout=TIMEOUT + 1)
54 | if p.is_alive():
55 | p.kill()
56 | return bool(result and np.all(result[0]))
57 |
58 | def update_results(self, problem, response):
59 | # Initialize the response structure
60 | response_entry = {
61 | "content": response,
62 | "correctness": None,
63 | "reason": None,
64 | }
65 | code_filter_result = has_code(response)
66 | if len(code_filter_result) == 0:
67 | response_entry["correctness"] = False
68 | response_entry["reason"] = "Does not contain code component."
69 | else:
70 | last_code = code_filter_result[-1]
71 | problem_to_check = copy.deepcopy(problem)
72 | problem_to_check["input_output"] = json.loads(problem["input_output"])
73 | try:
74 | problem_to_check["solutions"] = json.loads(problem["solutions"])
75 | except Exception:
76 | problem_to_check["solutions"] = ""
77 | print("Empty solution from the dataset")
78 | curr_res = self.check_correctness(problem_to_check, generation=last_code)
79 | if curr_res:
80 | response_entry["correctness"] = True
81 | response_entry["reason"] = ""
82 | else:
83 | response_entry["correctness"] = False
84 | response_entry["reason"] = "Code is incorrect."
85 |
86 | return response_entry
87 |
88 | def load_and_filter_dataset(
89 | self, start, end, split=None, subset=None, difficulty=None
90 | ):
91 | train_data = self.load_dataset(subset=subset, split=split)
92 | if difficulty or "difficulty" in self.task_config.preprocess_config:
93 | difficulty = (
94 | self.task_config.preprocess_config["difficulty"]
95 | if not difficulty
96 | else difficulty
97 | )
98 | train_data = train_data.filter(lambda x: x["difficulty"] == difficulty)
99 |
100 | train_data = train_data.to_pandas()
101 |
102 | return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:]
103 |
--------------------------------------------------------------------------------
/evals/tasks/arc/arc_c.yaml:
--------------------------------------------------------------------------------
1 | handler: arc_c
2 | dataset_path: allenai/ai2_arc
3 | dataset_subset: ARC-Challenge
4 | dataset_split: train
5 | question_key: question
6 | answer_key: answerKey
7 | templating_parameters:
8 | # We combine choices for a question into choices_text entry in the dataset
9 | template: "Given the following question and four candidate answers (A, B, C and D), choose the best answer. Your response should end with \"The best answer is [the_answer_letter]\" where [the_answer_letter] is one of the four letter choice (A, B, C, or D).\n{question}\n{choices_text}"
--------------------------------------------------------------------------------
/evals/tasks/arc/arc_handler.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Any, Dict
3 |
4 | from ...util.math_parsing_util import extract_answer
5 |
6 | from ..base import TaskConfig, TaskHandler
7 |
8 |
9 | class ARCChallengeTaskHandler(TaskHandler):
10 | def __init__(self, task_config: TaskConfig) -> None:
11 | super().__init__(task_config)
12 | self.ans_re = re.compile(r"[Tt]he best answer is ([A-D])[\.\,]*", re.IGNORECASE)
13 | self.letter_re = re.compile(r"([A-D])[\.\,]*")
14 | self.canonical_options = ["A", "B", "C", "D"]
15 | self.invalid_ans = "[invalid]"
16 |
17 | def generate_prompt(self, problem):
18 | choices = problem["choices"]
19 | choices_text = "\n".join(
20 | [
21 | f"{label}.{choice}"
22 | for label, choice in zip(self.canonical_options, choices["text"])
23 | ]
24 | )
25 | problem["choices_text"] = choices_text
26 | full_prompt = self.task_config.templating_parameters["template"].format(
27 | **problem
28 | )
29 | return full_prompt
30 |
31 | def check_correctness(self, problem: Dict[str, Any], generation: str) -> bool:
32 | gt_answer = problem[self.task_config.answer_key]
33 | if gt_answer not in self.canonical_options:
34 | gt_answer = self.canonical_options[
35 | int(problem[self.task_config.answer_key]) - 1
36 | ]
37 | model_answer = self.get_answer(generation)
38 | return model_answer == gt_answer
39 |
40 | def update_results(self, problem, response):
41 | # Initialize the response structure
42 | response_entry = {
43 | "content": response,
44 | "correctness": None,
45 | "reason": None,
46 | }
47 | curr_res = self.check_correctness(problem, generation=response)
48 | if curr_res:
49 | response_entry["correctness"] = True
50 | response_entry["reason"] = ""
51 | else:
52 | response_entry["correctness"] = False
53 | response_entry["reason"] = "Solution is incorrect."
54 |
55 | return response_entry
56 |
57 | def load_and_filter_dataset(
58 | self, start, end, split=None, subset=None, difficulty=None
59 | ):
60 | train_data = self.load_dataset(subset=subset, split=split).to_pandas()
61 | return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:]
62 |
63 | def get_answer(self, completion):
64 | # First, we try to extract similar to MATH answers
65 | answer = extract_answer(completion)
66 | match = None
67 | if answer:
68 | # match for the letter answer needed.
69 | match = self.letter_re.search(answer)
70 | if match:
71 | return match.group(1).strip()
72 |
73 | if not answer or not match:
74 | # try basic-regex based search
75 | patterns_to_remove = [
76 | ",", # Remove commas
77 | r"\$", # Remove dollar signs
78 | r"\.$" r"\\", # Remove trailing period # Remove stray backslashes
79 | r"\*", # Remove asterisks
80 | ]
81 | answer = completion
82 | for pattern in patterns_to_remove:
83 | answer = re.sub(pattern, "", answer)
84 | matches = self.ans_re.findall(answer)
85 | if not matches:
86 | return self.invalid_ans
87 | return matches[-1].strip()
88 |
--------------------------------------------------------------------------------
/evals/tasks/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any, Dict, List, Optional
3 | from urllib.parse import urlparse
4 |
5 | import pandas as pd
6 | import yaml
7 | from datasets import Dataset as HFDataset
8 | from datasets import load_dataset
9 | from pydantic import BaseModel, Field
10 |
11 | ConversationType = List[Dict[str, Any]]
12 |
13 |
14 | class TaskConfig(BaseModel):
15 | handler: str
16 | dataset_path: str
17 | dataset_subset: Optional[str] = None
18 | dataset_split: Optional[str] = None
19 | dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
20 | question_key: str
21 | # Optional answer key for datasets with a single correct answer
22 | answer_key: Optional[str] = None
23 | templating_parameters: Dict[str, str] = Field(default_factory=dict)
24 | # Example fields
25 | # fewshot_config: List[Dict[str, Any]] = Field(default_factory=list)
26 | # num_fewshot: int = 0
27 |
28 | preprocess_config: Dict[str, Any] = Field(default_factory=dict)
29 |
30 | @classmethod
31 | def from_yaml(cls, yaml_file_path) -> "TaskConfig":
32 | with open(yaml_file_path, "r", encoding="utf-8") as f:
33 | config_dict = yaml.safe_load(f)
34 | return cls(**config_dict)
35 |
36 | def update(self, **kwargs):
37 | for key, value in kwargs.items():
38 | setattr(self, key, value)
39 |
40 |
41 | class TaskHandler(ABC):
42 |
43 | def __init__(self, task_config: TaskConfig):
44 | self.task_config = task_config
45 |
46 | @classmethod
47 | def from_config_path(cls, config_path: str) -> "TaskHandler":
48 | task_config = TaskConfig.from_yaml(config_path)
49 | return cls(task_config)
50 |
51 | @property
52 | def question_key(self):
53 | return self.task_config.question_key
54 |
55 | @abstractmethod
56 | def check_correctness(
57 | self, problem: Dict[str, Any], generation: Dict[str, Any]
58 | ) -> bool:
59 | pass
60 |
61 | @abstractmethod
62 | def update_results(self, problem: Dict[str, Any], response: str) -> Dict[str, Any]:
63 | pass
64 |
65 | def make_conversations(
66 | self,
67 | data: List[Dict[str, Any]],
68 | system_prompt: Optional[str] = None,
69 | user_template: Optional[str] = None,
70 | assistant_prefill: Optional[str] = None,
71 | ) -> List[ConversationType]:
72 | conversations = []
73 | for _, problem in enumerate(data):
74 | prompt_text = self.generate_prompt(problem)
75 | conversations.append(
76 | make_conversation_from_contents(
77 | [prompt_text],
78 | system_prompt=system_prompt,
79 | user_template=user_template,
80 | assistant_prefill=assistant_prefill,
81 | )
82 | )
83 | return conversations
84 |
85 | def load_dataset(self, subset=None, split=None, **kwargs) -> HFDataset:
86 | # check if the path provided is a valid URL
87 | parsed = urlparse(self.task_config.dataset_path)
88 | if not parsed.scheme:
89 | # HF dataset
90 | dataset = load_dataset(
91 | path=self.task_config.dataset_path,
92 | name=subset if subset else self.task_config.dataset_subset,
93 | split=split if split else self.task_config.dataset_split,
94 | **self.task_config.dataset_kwargs,
95 | )
96 | else:
97 | # Try to load URL
98 | # Only JSON supported for now
99 | if split is not None or subset is not None:
100 | raise ValueError(
101 | "URL-based dataset does not support loading arguments like `split`, `subset`"
102 | )
103 | # By default, Huggingface will create a DatasetDict object with "train" split
104 | dataset = load_dataset("json", data_files=[self.task_config.dataset_path])[
105 | "train"
106 | ]
107 |
108 | # add an index column efficiently with map
109 | dataset = dataset.map(add_idx_map, with_indices=True)
110 | return dataset
111 |
112 | @abstractmethod
113 | def load_and_filter_dataset(
114 | self, start, end, split=None, subset=None, difficulty=None
115 | ) -> pd.DataFrame:
116 | pass
117 |
118 | def process_remaining_data(self, train_data, id_to_results: dict):
119 | return [
120 | row.to_dict()
121 | for _, row in train_data.iterrows()
122 | if str(row["_index"]) not in id_to_results
123 | ]
124 |
125 |
126 | def add_idx_map(x: dict, idx: int) -> dict:
127 | # We convert to string for consistency
128 | x["_index"] = str(idx)
129 | return x
130 |
131 |
132 | def make_conversation_from_contents(
133 | contents: List[str],
134 | system_prompt: Optional[str] = None,
135 | user_template: Optional[str] = None,
136 | assistant_prefill: Optional[str] = None,
137 | ) -> ConversationType:
138 | """Makes a conversation given a list of user/assistant message strings.
139 |
140 | If system_prompt is provided, it will be added as the first message.
141 | If user_template is provided, it will be used to format the user messages. This is useful for model-specific formatting.
142 |
143 | Args:
144 | contents: A list of user/assistant message strings.
145 | system_prompt: An optional string for the system prompt.
146 | user_template: An optional string for the user template.
147 | assistant_prefill: An optional string used to prefill the assistant's reply.
148 | Returns:
149 | A list of dictionaries representing the conversation.
150 | """
151 |
152 | conversation = []
153 | if system_prompt:
154 | conversation.append({"role": "system", "content": system_prompt})
155 |
156 | for i, content in enumerate(contents):
157 | if i % 2 == 0:
158 | content = user_template.format(content) if user_template else content
159 | conversation.append({"role": "user", "content": content})
160 | else:
161 | conversation.append({"role": "assistant", "content": content})
162 |
163 | if assistant_prefill and conversation[-1]["role"] == "user":
164 | conversation.append({"role": "assistant", "content": assistant_prefill})
165 |
166 | return conversation
167 |
--------------------------------------------------------------------------------
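A small example of the conversation helper above, with illustrative message contents:

# Hypothetical usage example; assumes the module is importable as shown.
from evals.tasks.base import make_conversation_from_contents

conversation = make_conversation_from_contents(
    ["What is 2 + 2?"],
    system_prompt="You are a careful math assistant.",
    assistant_prefill="<think>",
)
# -> [{"role": "system", "content": "You are a careful math assistant."},
#     {"role": "user", "content": "What is 2 + 2?"},
#     {"role": "assistant", "content": "<think>"}]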
/evals/tasks/gpqa_diamond/gpqa_diamond.yaml:
--------------------------------------------------------------------------------
1 | handler: gpqa_diamond
2 | dataset_path: Idavidrein/gpqa
3 | dataset_subset: gpqa_diamond
4 | dataset_split: train
5 | question_key: Question
6 | answer_key: Answer
7 | templating_parameters:
8 | # For GPQA, we combine the Question key and the multiple choice answers into a single `prompt` entry
9 | template: "Return your final response within \\boxed{{}} and only include the letter choice (A, B, C, or D) as your final response. {prompt}"
--------------------------------------------------------------------------------
/evals/tasks/gpqa_diamond/gpqa_diamond_handler.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from ...util.math_parsing_util import get_multiple_choice_answer
4 |
5 | from ..base import TaskHandler
6 |
7 |
8 | class GPQADiamondTaskHandler(TaskHandler):
9 |
10 | def generate_prompt(self, problem):
11 | multiple_choice_string, correct_answer_letter = (
12 | self.get_multiple_choice_answers(problem)
13 | )
14 | problem["Answer"] = correct_answer_letter
15 | problem["prompt"] = problem["Question"] + "\n" + multiple_choice_string
16 | return self.task_config.templating_parameters["template"].format(
17 | prompt=problem["prompt"]
18 | )
19 |
20 | def update_results(self, problem, response):
21 | # Initialize the response structure
22 | response_entry = {
23 | "content": response,
24 | "correctness": None,
25 | "reason": None,
26 | }
27 | curr_res = self.check_correctness(problem, generation=response)
28 | if curr_res:
29 | response_entry["correctness"] = True
30 | response_entry["reason"] = ""
31 | else:
32 | response_entry["correctness"] = False
33 | response_entry["reason"] = "Solution is incorrect."
34 |
35 | return response_entry
36 |
37 | def check_correctness(self, problem, generation):
38 | pred = get_multiple_choice_answer(generation)
39 | answer = problem[self.task_config.answer_key]
40 | return answer == pred
41 |
42 | def get_multiple_choice_answers(self, data):
43 | answers = [
44 | data["Correct Answer"],
45 | data["Incorrect Answer 1"],
46 | data["Incorrect Answer 2"],
47 | data["Incorrect Answer 3"],
48 | ]
49 | random.shuffle(answers)
50 |
51 | # Map options to letters
52 | options = ["A", "B", "C", "D"]
53 | options_to_answers = {
54 | letter: answer for letter, answer in zip(options, answers)
55 | }
56 |
57 | # Format the options into the string
58 | multiple_choice_string = ", ".join(
59 | f"{letter}) {options_to_answers[letter]}" for letter in options
60 | )
61 |
62 | # Save the letter corresponding to the correct answer
63 | correct_answer_letter = next(
64 | letter
65 | for letter, answer in options_to_answers.items()
66 | if answer == data["Correct Answer"]
67 | )
68 |
69 | return multiple_choice_string, correct_answer_letter
70 |
71 | def load_and_filter_dataset(
72 | self, start, end, split=None, subset=None, difficulty=None
73 | ):
74 | train_data = self.load_dataset(subset=subset, split=split).to_pandas()
75 | return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:]
76 |
--------------------------------------------------------------------------------
/evals/tasks/gsm8k/gsm8k.yaml:
--------------------------------------------------------------------------------
1 | handler: gsm8k
2 | dataset_path: "openai/gsm8k"
3 | dataset_subset: main
4 | dataset_split: test
5 | question_key: question
6 | answer_key: answer
7 | templating_parameters:
8 | template: "Given the following problem, reason and give a final answer to the problem.\nProblem: {question}\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem."
9 |
10 |
--------------------------------------------------------------------------------
/evals/tasks/gsm8k/gsm8k_handler.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Any, Dict
3 |
4 | from ...util.math_parsing_util import extract_answer
5 |
6 | from ..base import TaskConfig, TaskHandler
7 |
8 |
9 | class GSM8KTaskHandler(TaskHandler):
10 | def __init__(self, task_config: TaskConfig) -> None:
11 | super().__init__(task_config)
12 | self.ans_re = re.compile(r"((-?[$0-9.,]{2,})|(-?[0-9]+))")
13 | self.gt_re = re.compile(r"#### (\-?[0-9\.\,]+)")
14 | self.invalid_ans = "[invalid]"
15 |
16 | def generate_prompt(self, problem):
17 | return self.task_config.templating_parameters["template"].format(**problem)
18 |
19 | def check_correctness(self, problem: Dict[str, Any], generation: str) -> bool:
20 | gt_answer = self.extract_gt_answer(problem[self.task_config.answer_key])
21 | model_answer = extract_answer(generation)
22 | model_answer = self.sanitize_answer(model_answer)
23 | return model_answer == gt_answer
24 |
25 | def update_results(self, problem, response):
26 | # Initialize the response structure
27 | response_entry = {
28 | "content": response,
29 | "correctness": None,
30 | "reason": None,
31 | }
32 | curr_res = self.check_correctness(problem, generation=response)
33 | if curr_res:
34 | response_entry["correctness"] = True
35 | response_entry["reason"] = ""
36 | else:
37 | response_entry["correctness"] = False
38 | response_entry["reason"] = "Solution is incorrect."
39 |
40 | return response_entry
41 |
42 | def load_and_filter_dataset(
43 | self, start, end, split=None, subset=None, difficulty=None
44 | ):
45 | train_data = self.load_dataset(subset=subset, split=split).to_pandas()
46 | return train_data.iloc[start:end] if end > 0 else train_data.iloc[start:]
47 |
48 | def extract_gt_answer(self, completion):
49 | match = self.gt_re.search(completion)
50 | if match:
51 | match_str = match.group(1).strip()
52 | match_str = match_str.replace(",", "")
53 | return match_str
54 | else:
55 | return self.invalid_ans
56 |
57 | def sanitize_answer(self, answer):
58 | patterns_to_remove = [
59 | ",", # Remove commas
60 | r"\$", # Remove dollar signs
61 | r"\.$" r"\*", # Remove trailing period # Remove asterisks
62 | ]
63 | for pattern in patterns_to_remove:
64 | answer = re.sub(pattern, "", answer)
65 |
66 | matches = self.ans_re.findall(answer)
67 | if matches:
68 | # get the last match (i.e final response) and the first / outer capturing group
69 | match_str = matches[-1][0].strip()
70 | return match_str
71 | else:
72 | return self.invalid_ans
73 |
--------------------------------------------------------------------------------
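A minimal sketch of grading a single example with the handler above, assuming the YAML path from this repository; the problem values are illustrative and the exact parsing depends on `extract_answer`:

# Hypothetical usage example; assumes the package is importable as `evals`.
from evals.tasks.gsm8k.gsm8k_handler import GSM8KTaskHandler

handler = GSM8KTaskHandler.from_config_path("evals/tasks/gsm8k/gsm8k.yaml")
problem = {
    "question": "Tom has 3 boxes with 4 apples each. How many apples does he have?",
    "answer": "3 * 4 = 12\n#### 12",
}
result = handler.update_results(problem, response="3 * 4 = 12. The final answer is 12")
print(result["correctness"])  # expected: True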
/evals/tasks/liveaops/liveaops.yaml:
--------------------------------------------------------------------------------
1 | handler: liveaops
2 | dataset_path: https://livemathbench.github.io/data/LiveAoPSBench-2024.jsonl
3 | dataset_subset: null # which subset on huggingface. Not applicable for a URL dataset
4 | dataset_split: null # Rule based evaluation
5 | question_key: question
6 | answer_key: answer
7 | templating_parameters:
8 | template: "Return your final response within \\boxed{{}}. {question}"
9 |
--------------------------------------------------------------------------------
/evals/tasks/liveaops/liveaops_handler.py:
--------------------------------------------------------------------------------
1 | from ...util.math_parsing_util import (
2 | extract_answer,
3 | math_equal,
4 | strip_answer_string,
5 | )
6 |
7 | from ..math.math_handler import MathTaskHandler
8 |
9 |
10 | class LiveAOPSTaskHandler(MathTaskHandler):
11 | def generate_prompt(self, problem):
12 | return self.task_config.templating_parameters["template"].format(**problem)
13 |
14 | def check_correctness(self, problem, generation):
15 | # no preprocessing needed
16 | answer = problem[self.task_config.answer_key]
17 | pred = extract_answer(generation)
18 | pred = strip_answer_string(pred)
19 | return math_equal(pred, answer)
20 |
21 | def load_and_filter_dataset(
22 | self, start, end, split=None, subset=None, difficulty=None
23 | ):
24 | assert difficulty is None, "LiveAOPS does not support `difficulty` argument"
25 | dataset = self.load_dataset(subset=subset, split=split).to_pandas()
26 | return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:]
27 |
--------------------------------------------------------------------------------
/evals/tasks/livecodebench/livecodebench.yaml:
--------------------------------------------------------------------------------
1 | handler: livecodebench
2 | dataset_path: "livecodebench/code_generation_lite" # repo ID in huggingface
3 | dataset_subset: null
4 | dataset_split: test
5 | dataset_kwargs:
6 | version_tag: release_v2
7 | trust_remote_code: true
8 | question_key: task_id
9 | answer_key: null
10 | templating_parameters:
11 | stdin_template: "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition. {prompt}"
12 | non_stdin_template: "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution. {prompt}"
13 | # Optionally, you can filter the dataset by difficulty
14 | # preprocess_config:
15 | # difficulty: easy
16 |
--------------------------------------------------------------------------------
/evals/tasks/livecodebench/livecodebench_easy.yaml:
--------------------------------------------------------------------------------
1 | handler: livecodebench
2 | dataset_path: "livecodebench/code_generation_lite" # repo ID in huggingface
3 | dataset_subset: null
4 | dataset_split: test
5 | dataset_kwargs:
6 | version_tag: release_v2
7 | trust_remote_code: true
8 | question_key: task_id
9 | answer_key: null
10 | templating_parameters:
11 | stdin_template: "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition. {prompt}"
12 | non_stdin_template: "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution. {prompt}"
13 | preprocess_config:
14 | difficulty: easy
15 |
--------------------------------------------------------------------------------
/evals/tasks/livecodebench/livecodebench_handler.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import Dict
3 |
4 | from datasets import Dataset as HFDataset
5 |
6 | from ...util.common import has_code
7 |
8 | from ..base import TaskHandler
9 | from .livecodebench_util import (
10 | map_to_example,
11 | post_process_code,
12 | translate_private_test_cases,
13 | unsafe_lcb_runTests,
14 | )
15 |
16 |
17 | class LiveCodeBenchTaskHandler(TaskHandler):
18 |
19 | def generate_prompt(self, problem):
20 | if problem["is_stdin"]:
21 | return self.task_config.templating_parameters["stdin_template"].format(
22 | **problem
23 | )
24 | else:
25 | return self.task_config.templating_parameters["non_stdin_template"].format(
26 | **problem
27 | )
28 |
29 | def check_correctness(
30 | self,
31 | problem: Dict,
32 | completion: str,
33 | timeout: float,
34 | runtime_debug=False,
35 | is_extracted=False,
36 | ) -> Dict:
37 | """
38 | Evaluates the functional correctness of a completion by running the test
39 | suite provided in the problem.
40 |
41 |         :param problem: problem dictionary containing the test cases to run.
42 |         :param completion: candidate solution code to execute against those tests.
43 | """
44 | result_list = unsafe_lcb_runTests(
45 | problem, completion, timeout, runtime_debug, is_extracted
46 | )
47 | details = [r[0] for r in result_list]
48 | all_passed = all(details)
49 |
50 | result = ""
51 | if result_list and all_passed:
52 | result = "passed"
53 |
54 | return result == "passed"
55 |
56 | def update_results(self, problem, response):
57 | # Initialize the response structure
58 | response_entry = {
59 | "content": response,
60 | "correctness": None,
61 | "reason": None,
62 | }
63 | code_filter_result = has_code(response)
64 | # print(response)
65 | if len(code_filter_result) == 0:
66 | response_entry["correctness"] = False
67 | response_entry["reason"] = "Does not contain code component."
68 | else:
69 | last_code = code_filter_result[-1]
70 | problem_to_check = copy.deepcopy(problem)
71 |
72 | curr_res = self.check_correctness(
73 | problem=problem_to_check,
74 | completion=post_process_code(last_code),
75 | timeout=6,
76 | is_extracted=not problem_to_check["is_stdin"],
77 | )
78 | if curr_res:
79 | response_entry["correctness"] = True
80 | response_entry["reason"] = ""
81 | else:
82 | response_entry["correctness"] = False
83 | response_entry["reason"] = "Code is incorrect."
84 |
85 | return response_entry
86 |
87 | def load_and_filter_dataset(
88 | self, start, end, split=None, subset=None, difficulty=None
89 | ):
90 | dataset: HFDataset = self.load_dataset(subset=subset, split=split)
91 | # Filter by CLI or config
92 | if difficulty or "difficulty" in self.task_config.preprocess_config:
93 | difficulty = (
94 | difficulty
95 | if difficulty
96 | else self.task_config.preprocess_config["difficulty"]
97 | )
98 | dataset = dataset.filter(
99 | lambda example: example["difficulty"] == difficulty
100 | )
101 | # We use a lower writer_batch_size to avoid pyarrow issues. JSON entries with LiveCodeBench are large.
102 | # See: https://github.com/NovaSky-AI/SkyThought/pull/45 for details.
103 | dataset = dataset.map(
104 | lambda example: {
105 | "private_test_cases": translate_private_test_cases(
106 | example["private_test_cases"]
107 | )
108 | },
109 | writer_batch_size=100,
110 | )
111 | # Apply the mapping function
112 | # TODO (sumanthrh): See if the appropriate livecodebench columns can be renamed instead and let other columns pass-through
113 | dataset = dataset.map(
114 | map_to_example,
115 |             remove_columns=[c for c in dataset.column_names if c != "_index"],
116 | writer_batch_size=100,
117 | ).to_pandas()
118 | return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:]
119 |
--------------------------------------------------------------------------------
/evals/tasks/livecodebench/livecodebench_hard.yaml:
--------------------------------------------------------------------------------
1 | handler: livecodebench
2 | dataset_path: "livecodebench/code_generation_lite" # repo ID in huggingface
3 | dataset_subset: null
4 | dataset_split: test
5 | dataset_kwargs:
6 | version_tag: release_v2
7 | trust_remote_code: true
8 | question_key: task_id
9 | answer_key: null
10 | templating_parameters:
11 | stdin_template: "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition. {prompt}"
12 | non_stdin_template: "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution. {prompt}"
13 | preprocess_config:
14 | difficulty: hard
15 |
--------------------------------------------------------------------------------
/evals/tasks/livecodebench/livecodebench_medium.yaml:
--------------------------------------------------------------------------------
1 | handler: livecodebench
2 | dataset_path: "livecodebench/code_generation_lite" # repo ID in huggingface
3 | dataset_subset: null
4 | dataset_split: test
5 | dataset_kwargs:
6 | version_tag: release_v2
7 | trust_remote_code: true
8 | question_key: task_id
9 | answer_key: null
10 | templating_parameters:
11 | stdin_template: "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition. {prompt}"
12 | non_stdin_template: "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution. {prompt}"
13 | preprocess_config:
14 | difficulty: medium
15 |
--------------------------------------------------------------------------------
/evals/tasks/math/math500.yaml:
--------------------------------------------------------------------------------
1 | handler: math
2 | dataset_path: "HuggingFaceH4/MATH-500" # repo ID in huggingface
3 | dataset_subset: null # which subset on huggingface
4 | question_key: problem
5 | answer_key: answer
6 | dataset_split: test
7 | templating_parameters:
8 | template: "Return your final response within \\boxed{{}}. {problem}"
9 | # optional. Not supported yet.
10 | # fewshot_config:
11 | # - question: ...
12 | # - target: ...
13 | # num_fewshot: 0
14 |
--------------------------------------------------------------------------------
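The doubled braces in the `template` above are `str.format` escapes: `MathTaskHandler.generate_prompt` (next file) substitutes `{problem}` and leaves a literal `\boxed{}` in the prompt. A minimal sketch of that rendering, with a made-up problem:

```python
# Same template string as in math500.yaml; the sample problem is illustrative.
template = "Return your final response within \\boxed{{}}. {problem}"
print(template.format(problem="What is 2 + 2?"))
# Return your final response within \boxed{}. What is 2 + 2?
```
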
/evals/tasks/math/math_handler.py:
--------------------------------------------------------------------------------
1 | from ...util.math_parsing_util import (
2 | extract_answer,
3 | math_equal,
4 | strip_answer_string,
5 | )
6 |
7 | from ..base import TaskHandler
8 |
9 |
10 | class MathTaskHandler(TaskHandler):
11 | def generate_prompt(self, problem):
12 | return self.task_config.templating_parameters["template"].format(**problem)
13 |
14 | def check_correctness(self, problem, generation):
15 | answer = strip_answer_string(problem[self.task_config.answer_key])
16 | pred = extract_answer(generation)
17 | pred = strip_answer_string(pred)
18 | return math_equal(pred, answer)
19 |
20 | def update_results(self, problem, response):
21 | # Initialize the response structure
22 | response_entry = {
23 | "content": response,
24 | "correctness": None,
25 | "reason": None,
26 | }
27 | curr_res = self.check_correctness(problem, generation=response)
28 | if curr_res:
29 | response_entry["correctness"] = True
30 | response_entry["reason"] = ""
31 | else:
32 | response_entry["correctness"] = False
33 | response_entry["reason"] = "Solution is incorrect."
34 |
35 | return response_entry
36 |
37 | def load_and_filter_dataset(
38 | self, start, end, split=None, subset=None, difficulty=None
39 | ):
40 | dataset = self.load_dataset(subset=subset, split=split).to_pandas()
41 | return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:]
42 |
--------------------------------------------------------------------------------
/evals/tasks/minervamath/minervamath.yaml:
--------------------------------------------------------------------------------
1 | handler: math
2 | dataset_path: "svc-huggingface/minerva-math" # repo ID in huggingface
3 | dataset_subset: null # which subset on huggingface
4 | question_key: problem
5 | answer_key: solution
6 | dataset_split: test
7 | templating_parameters:
8 | template: "Return your final response within \\boxed{{}}. {problem}"
--------------------------------------------------------------------------------
/evals/tasks/minervamath/minervamath_handler.py:
--------------------------------------------------------------------------------
1 | from ...util.math_parsing_util import (
2 | extract_answer,
3 | math_equal,
4 | strip_answer_string,
5 | )
6 |
7 | from ..math.math_handler import MathTaskHandler
8 |
9 |
10 | class MinervaMathTaskHandler(MathTaskHandler):
11 |
12 | def check_correctness(self, problem, generation):
13 | answer = extract_answer(problem[self.task_config.answer_key])
14 | answer = strip_answer_string(answer)
15 |
16 | pred = extract_answer(generation)
17 | pred = strip_answer_string(pred)
18 | return math_equal(pred, answer)
19 |
--------------------------------------------------------------------------------
/evals/tasks/mmlu/mmlu.yaml:
--------------------------------------------------------------------------------
1 | handler: mmlu
2 | dataset_path: cais/mmlu
3 | dataset_subset: all
4 | dataset_split: test
5 | question_key: question
6 | answer_key: answer
7 | templating_parameters:
8 | template: "Return your final response within \\boxed{{}}. {prompt}"
9 |
--------------------------------------------------------------------------------
/evals/tasks/mmlu/mmlu_handler.py:
--------------------------------------------------------------------------------
1 | from ...util.math_parsing_util import (
2 | get_multiple_choice_answer,
3 | mmlu_pro_extract_answer,
4 | )
5 |
6 | from ..base import TaskConfig, TaskHandler
7 |
8 |
9 | class MMLUTaskHandler(TaskHandler):
10 | def generate_prompt(self, problem):
11 | multiple_choice_string = self.get_multiple_choice_answers(problem)
12 | prompt = problem["question"] + "\n" + multiple_choice_string
13 | return self.task_config.templating_parameters["template"].format(prompt=prompt)
14 |
15 | def check_correctness(self, problem, generation):
16 | pred = get_multiple_choice_answer(generation)
17 | abcd = "ABCD"
18 | answer = abcd[problem[self.task_config.answer_key]]
19 | return answer == pred
20 |
21 | def update_results(self, problem, response):
22 | # Initialize the response structure
23 | response_entry = {
24 | "content": response,
25 | "correctness": None,
26 | "reason": None,
27 | }
28 | curr_res = self.check_correctness(problem, generation=response)
29 | if curr_res:
30 | response_entry["correctness"] = True
31 | response_entry["reason"] = ""
32 | else:
33 | response_entry["correctness"] = False
34 | response_entry["reason"] = "Solution is incorrect."
35 | return response_entry
36 |
37 | def get_multiple_choice_answers(self, problem):
38 | options = problem["choices"]
39 | options_str = ""
40 |         for label, option in zip("ABCD", options):
41 | options_str += f"({label}) {str(option).strip()} "
42 | options_str = options_str[:-1] # remove the last space
43 | return f"Answer Choices: {options_str}"
44 |
45 | def load_and_filter_dataset(
46 | self, start, end, split=None, subset=None, difficulty=None
47 | ):
48 | dataset = self.load_dataset(subset=subset, split=split).to_pandas()
49 | return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:]
50 |
51 |
52 | class MMLUProTaskHandler(MMLUTaskHandler):
53 | def __init__(self, task_config: TaskConfig):
54 | super().__init__(task_config)
55 | self.choices = [
56 | "A",
57 | "B",
58 | "C",
59 | "D",
60 | "E",
61 | "F",
62 | "G",
63 | "H",
64 | "I",
65 | "J",
66 | "K",
67 | "L",
68 | "M",
69 | "N",
70 | "O",
71 | "P",
72 | ]
73 |
74 | def generate_prompt(self, prompt):
75 | return self.task_config.templating_parameters["template"].format(prompt=prompt)
76 |
77 | def check_correctness(self, problem, generation):
78 | pred = mmlu_pro_extract_answer(generation)
79 | answer = self.choices[problem["answer_index"]]
80 | return answer == pred
81 |
82 | def get_multiple_choice_answers(self, problem):
83 | options = problem["options"]
84 | for i, (label, option) in enumerate(zip(self.choices[: len(options)], options)):
85 | options[i] = f"({label}) {str(option).strip()}"
86 | options = " ".join(options)
87 | return f"Answer Choices: {options}"
88 |
89 | def load_and_filter_dataset(
90 | self, start, end, split=None, subset=None, difficulty=None
91 | ):
92 | dataset = self.load_dataset(subset=subset, split=split).to_pandas()
93 | return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:]
94 |
--------------------------------------------------------------------------------
/evals/tasks/mmlu/mmlu_pro.yaml:
--------------------------------------------------------------------------------
1 | handler: mmlu_pro
2 | dataset_path: TIGER-Lab/MMLU-Pro
3 | dataset_subset: default
4 | dataset_split: test
5 | question_key: question
6 | answer_key: answer
7 | templating_parameters:
8 | template: "Return your final response within \\boxed{{}}. {prompt}"
9 |
--------------------------------------------------------------------------------
/evals/tasks/numina/numina.yaml:
--------------------------------------------------------------------------------
1 | handler: numina
2 | dataset_path: "AI-MO/NuminaMath-CoT"
3 | dataset_subset: null
4 | dataset_split: train
5 | question_key: problem
6 | answer_key: solution
7 | templating_parameters:
8 | template: "Return your final response within \\boxed{{}}. {prompt}"
9 | # Optionally, you can filter the dataset by difficulty
10 | # preprocess_config:
11 | # filter_difficulty: true
12 | # math_difficulty_lower_bound: 4
13 | # math_difficulty_upper_bound: 9
14 | # source: math
15 |
--------------------------------------------------------------------------------
/evals/tasks/numina/numina_amc_aime.yaml:
--------------------------------------------------------------------------------
1 | handler: numina
2 | dataset_path: "AI-MO/NuminaMath-CoT"
3 | dataset_subset: null
4 | dataset_split: train
5 | question_key: problem
6 | answer_key: solution
7 | templating_parameters:
8 | template: "Return your final response within \\boxed{{}}. {prompt}"
9 | preprocess_config:
10 | filter_difficulty: true
11 | math_difficulty_lower_bound: 1
12 | math_difficulty_upper_bound: 9
13 | source: amc_aime
14 |
--------------------------------------------------------------------------------
/evals/tasks/numina/numina_handler.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 |
3 | from datasets import load_dataset
4 |
5 | from ...util.common import TimeoutException, timeout
6 | from ...util.math_parsing_util import (
7 | extract_answer,
8 | math_equal,
9 | )
10 |
11 | from ..base import TaskHandler
12 |
13 |
14 | class NUMINATaskHandler(TaskHandler):
15 |
16 | def generate_prompt(self, problem: Dict[str, Any]):
17 | prompt = problem["problem"]
18 | return self.task_config.templating_parameters["template"].format(prompt=prompt)
19 |
20 | @timeout(5) # Add timeout of 5 seconds
21 | def check_correctness(self, problem, generation):
22 | solution = extract_answer(problem[self.task_config.answer_key])
23 | pred = extract_answer(generation)
24 | return math_equal(pred, solution)
25 |
26 | def update_results(self, problem, response):
27 | # Initialize the response structure
28 | response_entry = {
29 | "content": response,
30 | "correctness": None,
31 | "reason": None,
32 | }
33 |
34 | try:
35 | curr_res = self.check_correctness(problem, generation=response)
36 | if curr_res:
37 | response_entry["correctness"] = True
38 | response_entry["reason"] = ""
39 | else:
40 | response_entry["correctness"] = False
41 | response_entry["reason"] = "Solution is incorrect."
42 | except TimeoutException as e:
43 | response_entry["correctness"] = False
44 | response_entry["reason"] = str(e)
45 |
46 | return response_entry
47 |
48 | @staticmethod
49 | def get_difficulty_dict(subset, start, end):
50 | diff_dict = {}
51 | dataset = load_dataset(
52 | "NovaSky-AI/labeled_numina_difficulty_859K",
53 | trust_remote_code=True,
54 | split="train",
55 | )
56 | for example in dataset:
57 | # print(example)
58 | diff_dict[example["problem"]] = example["gpt_difficulty_parsed"]
59 | return diff_dict
60 |
61 | def load_and_filter_dataset(
62 | self, start, end, split=None, subset=None, difficulty=None
63 | ):
64 | dataset = self.load_dataset(subset=subset, split=split)
65 |
66 | if "source" in self.task_config.preprocess_config:
67 | source = self.task_config.preprocess_config["source"]
68 | dataset = dataset.filter(lambda x: x["source"] == source)
69 |
70 | dataset = dataset.to_pandas()
71 | # TODO (sumanthrh): this is hacky for numina. the start and end filter should be applied at the very end
72 | # it is kept here for consistency with the original code.
73 | dataset = dataset.iloc[start:end] if end > 0 else dataset.iloc[start:]
74 | dataset = dataset[dataset["solution"].str.contains("boxed", na=False)]
75 |
76 | if "filter_difficulty" in self.task_config.preprocess_config:
77 | lower_bound = self.task_config.preprocess_config[
78 | "math_difficulty_lower_bound"
79 | ]
80 | upper_bound = self.task_config.preprocess_config[
81 | "math_difficulty_upper_bound"
82 | ]
83 | diff_dict = self.get_difficulty_dict(
84 | self.task_config.dataset_subset, start, end
85 | )
86 | dataset = dataset[
87 | dataset["problem"]
88 | .map(diff_dict)
89 | .apply(lambda x: x >= lower_bound and x <= upper_bound)
90 | ]
91 |
92 | return dataset
93 |
--------------------------------------------------------------------------------
/evals/tasks/numina/numina_math.yaml:
--------------------------------------------------------------------------------
1 | handler: numina
2 | dataset_path: "AI-MO/NuminaMath-CoT"
3 | dataset_subset: null
4 | dataset_split: train
5 | question_key: problem
6 | answer_key: solution
7 | templating_parameters:
8 | template: "Return your final response within \\boxed{{}}. {prompt}"
9 | preprocess_config:
10 | filter_difficulty: true
11 | math_difficulty_lower_bound: 4
12 | math_difficulty_upper_bound: 9
13 | source: math
14 |
--------------------------------------------------------------------------------
/evals/tasks/numina/numina_olympiads.yaml:
--------------------------------------------------------------------------------
1 | handler: numina
2 | dataset_path: "AI-MO/NuminaMath-CoT"
3 | dataset_subset: null
4 | dataset_split: train
5 | question_key: problem
6 | answer_key: solution
7 | templating_parameters:
8 | template: "Return your final response within \\boxed{{}}. {prompt}"
9 | preprocess_config:
10 | filter_difficulty: true
11 | math_difficulty_lower_bound: 9
12 | math_difficulty_upper_bound: 9
13 | source: olympiads
14 |
--------------------------------------------------------------------------------
/evals/tasks/olympiadbench/olympiadbench_handler.py:
--------------------------------------------------------------------------------
1 | from ...util.math_parsing_util import (
2 | extract_answer,
3 | math_equal,
4 | strip_answer_string,
5 | )
6 |
7 | from ..math.math_handler import MathTaskHandler
8 |
9 |
10 | class OlympiadBenchMathTaskHandler(MathTaskHandler):
11 | def check_correctness(self, problem, generation):
12 | # all problems have final answer in a list
13 | answer = strip_answer_string(problem[self.task_config.answer_key][0])
14 | pred = extract_answer(generation)
15 | pred = strip_answer_string(pred)
16 | return math_equal(pred, answer)
17 |
--------------------------------------------------------------------------------
/evals/tasks/olympiadbench/olympiadbench_math_en.yaml:
--------------------------------------------------------------------------------
1 | handler: olympiadbench_math
2 | dataset_path: Hothan/OlympiadBench
3 | dataset_subset: OE_TO_maths_en_COMP
4 | dataset_split: train
5 | question_key: question
6 | answer_key: final_answer
7 | templating_parameters:
8 | template: "Return your final response within \\boxed{{}}. {question}"
9 |
--------------------------------------------------------------------------------
/evals/tasks/omni_math/omni_handler.py:
--------------------------------------------------------------------------------
1 | from ...util.math_parsing_util import (
2 | extract_answer,
3 | math_equal,
4 | strip_answer_string,
5 | )
6 |
7 | from ..math.math_handler import MathTaskHandler
8 |
9 |
10 | class OMNIMathTaskHandler(MathTaskHandler):
11 | def generate_prompt(self, problem):
12 | return self.task_config.templating_parameters["template"].format(**problem)
13 |
14 | def check_correctness(self, problem, generation):
15 | # no preprocessing needed
16 | answer = problem[self.task_config.answer_key]
17 | pred = extract_answer(generation)
18 | pred = strip_answer_string(pred)
19 | return math_equal(pred, answer)
20 |
--------------------------------------------------------------------------------
/evals/tasks/omni_math/omni_math.yaml:
--------------------------------------------------------------------------------
1 | handler: omni_math
2 | dataset_path: "KbsdJames/Omni-MATH" # repo ID in huggingface
3 | dataset_subset: null # which subset on huggingface
4 | dataset_split: test_rule_based # Rule based evaluation
5 | dataset_kwargs:
6 | # NOTE: This is using the subset for rule-based evaluation in the below PR
7 | revision: refs/pr/2
8 | question_key: problem
9 | answer_key: answer
10 | templating_parameters:
11 | template: "Return your final response within \\boxed{{}}. {problem}"
--------------------------------------------------------------------------------
/evals/tasks/taco/taco.yaml:
--------------------------------------------------------------------------------
1 | handler: taco
2 | dataset_path: "BAAI/TACO"
3 | dataset_subset: MEDIUM
4 | dataset_split: train
5 | dataset_kwargs:
6 | trust_remote_code: true
7 | question_key: question
8 | answer_key: null
9 | templating_parameters:
10 | initial_template: "\nQUESTION:\n{prompt}"
11 | # Add starter code to initial template
12 | starter_code_template: "{input}\n{starter_code}"
13 | # stdin template is used when there is no starter code or fn_name
14 | stdin_template: "{input}\nUse Standard Input format\nANSWER:\n"
15 | # call template is used when there is starter code or fn_name
16 | call_template: "{input}\nUse Call-Based format\nANSWER:\n"
17 | # Optionally, you can filter the dataset by difficulty
18 | # preprocess_config:
19 | # difficulty: easy
20 |
21 |
--------------------------------------------------------------------------------
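For problems with neither starter code nor an `fn_name`, `TACOTaskHandler.generate_prompt` (next file) renders `initial_template` and then wraps it in `stdin_template`; with starter code it uses `starter_code_template` followed by `call_template` instead. A small sketch of the stdin path, using templates copied from the yaml above and a made-up question:

```python
# Templates copied from taco.yaml; the question text is illustrative.
initial_template = "\nQUESTION:\n{prompt}"
stdin_template = "{input}\nUse Standard Input format\nANSWER:\n"

rendered = initial_template.format(prompt="Read an integer n and print n * 2.")
print(stdin_template.format(input=rendered))
#
# QUESTION:
# Read an integer n and print n * 2.
# Use Standard Input format
# ANSWER:
```
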
/evals/tasks/taco/taco_handler.py:
--------------------------------------------------------------------------------
1 | import json
2 | import multiprocessing
3 | from multiprocessing import Manager
4 |
5 | import numpy as np
6 |
7 | from ...util.common import has_code
8 |
9 | from ..base import TaskHandler
10 | from .taco_util import run_test as taco_run_test
11 |
12 |
13 | class TACOTaskHandler(TaskHandler):
14 |
15 | def generate_prompt(self, problem):
16 | prompt = problem["question"]
17 | starter_code = (
18 | None if len(problem["starter_code"]) == 0 else problem["starter_code"]
19 | )
20 | try:
21 |             input_output = json.loads(problem["input_output"])
22 |             fn_name = (
23 |                 None if not input_output.get("fn_name") else input_output["fn_name"]
24 | )
25 | except ValueError:
26 | fn_name = None
27 |
28 | _input = self.task_config.templating_parameters["initial_template"].format(
29 | prompt=prompt
30 | )
31 |
32 | if starter_code:
33 | _input = self.task_config.templating_parameters[
34 | "starter_code_template"
35 | ].format(input=_input, starter_code=starter_code)
36 | else:
37 | _input = self.task_config.templating_parameters["initial_template"].format(
38 | prompt=prompt
39 | )
40 | if (not fn_name) and (not starter_code):
41 | _input = self.task_config.templating_parameters["stdin_template"].format(
42 | input=_input
43 | )
44 | else:
45 | _input = self.task_config.templating_parameters["call_template"].format(
46 | input=_input
47 | )
48 |
49 | return _input
50 |
51 | def check_correctness(self, problem, generation):
52 | TIME_OUT = 300
53 |
54 | manager = Manager()
55 | result = manager.list()
56 | p = multiprocessing.Process(
57 | target=_temp_run, args=(problem, generation, False, result)
58 | )
59 | p.start()
60 | p.join(timeout=TIME_OUT + 1)
61 | if p.is_alive():
62 | p.kill()
63 | return bool(result and np.all(result[0]))
64 |
65 | def update_results(self, problem, response):
66 | # Initialize the response structure
67 | response_entry = {
68 | "content": response,
69 | "correctness": None,
70 | "reason": None,
71 | }
72 | code_filter_result = has_code(response)
73 | if len(code_filter_result) == 0:
74 | response_entry["correctness"] = False
75 | response_entry["reason"] = "Does not contain code component."
76 | else:
77 | last_code = code_filter_result[-1]
78 | curr_res = self.check_correctness(problem, generation=last_code)
79 | if curr_res:
80 | response_entry["correctness"] = True
81 | response_entry["reason"] = ""
82 | else:
83 | response_entry["correctness"] = False
84 | response_entry["reason"] = "Code is incorrect."
85 |
86 | return response_entry
87 |
88 | def load_and_filter_dataset(
89 | self, start, end, split=None, subset=None, difficulty=None
90 | ):
91 | dataset = self.load_dataset(subset=subset, split=split).to_pandas()
92 | if difficulty or "difficulty" in self.task_config.preprocess_config:
93 | difficulty = (
94 | difficulty
95 | if difficulty
96 | else self.task_config.preprocess_config["difficulty"]
97 | )
98 |             # `dataset` is a pandas DataFrame here, so filter rows with a boolean
99 |             # mask instead of a datasets-style `.filter(lambda ...)` call.
100 |             dataset = dataset[dataset["difficulty"] == difficulty]
101 |
102 | return dataset.iloc[start:end] if end > 0 else dataset.iloc[start:]
103 |
104 |
105 | def _temp_run(problem, generation, debug, result):
106 | try:
107 | result.append(
108 | taco_run_test(problem["input_output"], test=generation, debug=debug)
109 | )
110 | except Exception as e:
111 | print(f"Error in _temp_run: {e}")
112 |
--------------------------------------------------------------------------------
/evals/tasks/task_util.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | from typing import Dict
4 |
5 |
6 | def get_tasks(task_root_dir: str) -> Dict[str, str]:
7 | """Returns a dictionary of task names and their corresponding yaml file paths"""
8 | # list all yamls in subdirectories
9 | name_to_yaml = {}
10 | for yaml_file in glob.glob(
11 | os.path.join(task_root_dir, "**", "*.yaml"), recursive=True
12 | ):
13 | # arc.yaml -> arc
14 | name = os.path.basename(yaml_file).split(".")[0]
15 |
16 | name_to_yaml[name] = yaml_file
17 |
18 | return name_to_yaml
19 |
--------------------------------------------------------------------------------
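A short usage sketch for `get_tasks`. The import assumes the repository root is on `sys.path`; the `math500` key corresponds to the `math500.yaml` config shown earlier:

```python
import os

from evals.tasks.task_util import get_tasks  # assumes repo root is on sys.path

tasks = get_tasks(os.path.join("evals", "tasks"))
# Keys are yaml basenames, values are yaml paths, e.g.
# tasks["math500"] == "evals/tasks/math/math500.yaml"
print(sorted(tasks)[:5])
```
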
/evals/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nanomaoli/llm_reproducibility/8a373c5a159a27e59783394827cecadd6255484e/evals/util/__init__.py
--------------------------------------------------------------------------------
/evals/util/cli_util.py:
--------------------------------------------------------------------------------
1 | from ast import literal_eval
2 | from typing import Any, List
3 |
4 | import msgpack
5 | import xxhash
6 |
7 |
8 | def _parse_multi_args(vals: str) -> dict:
9 | """Parse a multi-value argument into a dictionary.
10 |
11 | The argument can either be a comma separated list of key=value pairs, or a dictionary.
12 | """
13 | try:
14 | # try to parse as a dictionary first
15 | my_dict = literal_eval(vals)
16 | assert isinstance(my_dict, dict)
17 | return my_dict
18 | except Exception:
19 | # try to parse as a comma separated list of key=value pairs
20 | vals = vals.replace(" ", "")
21 | if not len(vals):
22 | return {}
23 | ret = {}
24 | for val in vals.split(","):
25 | k, v = val.split("=")
26 | try:
27 | ret[k] = literal_eval(v)
28 | except (ValueError, SyntaxError):
29 | # if literal eval fails, propagate as a string
30 | ret[k] = v
31 | return ret
32 |
33 |
34 | def parse_multi_args(vals: str) -> dict:
35 | try:
36 | return _parse_multi_args(vals)
37 | except Exception as err:
38 | raise ValueError(
39 |             f"Expected comma separated list of parameters arg1=val1,arg2=val2 or a dictionary, got invalid argument {vals}."
40 | ) from err
41 |
42 |
43 | def comma_separated_to_list(vals: str) -> List[str]:
44 | vals = vals.replace(" ", "")
45 | return vals.split(",")
46 |
47 |
48 | def to_tuple(d) -> tuple:
49 | if isinstance(d, dict):
50 | return tuple(map(to_tuple, d.items()))
51 | elif isinstance(d, (set, list, tuple)):
52 | return tuple(map(to_tuple, d))
53 | else:
54 | return d
55 |
56 |
57 | def get_deterministic_hash(d: Any, num_digits: int = 6) -> str:
58 | """Get deterministic hash"""
59 | tuple_form = to_tuple(d)
60 | serialized = msgpack.packb(tuple_form, use_bin_type=True)
61 | return xxhash.xxh32(serialized).hexdigest()[:num_digits]
62 |
--------------------------------------------------------------------------------
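Both input forms accepted by `parse_multi_args` above, with illustrative values (the import assumes the repository root is on `sys.path`):

```python
from evals.util.cli_util import parse_multi_args  # assumes repo root is on sys.path

# Comma-separated key=value pairs; values go through literal_eval where possible.
print(parse_multi_args("temperature=0.0,max_tokens=1024,backend=vllm"))
# {'temperature': 0.0, 'max_tokens': 1024, 'backend': 'vllm'}

# A Python dict literal is accepted as-is.
print(parse_multi_args("{'temperature': 0.0, 'seed': 42}"))
# {'temperature': 0.0, 'seed': 42}
```
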
/evals/util/common.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import os
3 | import random
4 | import re
5 |
6 | import numpy as np
7 | import torch
8 |
9 |
10 | def set_seed(seed: int):
11 | os.environ["PYTHONHASHSEED"] = str(seed)
12 | random.seed(seed)
13 | np.random.seed(seed)
14 | torch.manual_seed(seed)
15 | torch.cuda.manual_seed_all(seed)
16 |
17 |
18 | class TimeoutException(Exception):
19 | """Custom exception for function timeout."""
20 |
21 | pass
22 |
23 |
24 | def timeout(seconds):
25 | """Decorator to enforce a timeout on a function using multiprocessing."""
26 |
27 | def decorator(func):
28 | def wrapper(*args, **kwargs):
29 | # A queue to store the result or exception
30 | queue = multiprocessing.Queue()
31 |
32 | def target(queue, *args, **kwargs):
33 | try:
34 | result = func(*args, **kwargs)
35 | queue.put((True, result))
36 | except Exception as e:
37 | queue.put((False, e))
38 |
39 | process = multiprocessing.Process(
40 | target=target, args=(queue, *args), kwargs=kwargs
41 | )
42 | process.start()
43 | process.join(seconds)
44 |
45 | if process.is_alive():
46 | process.terminate()
47 | process.join()
48 | raise TimeoutException(
49 | f"Function '{func.__name__}' timed out after {seconds} seconds!"
50 | )
51 |
52 | success, value = queue.get()
53 | if success:
54 | return value
55 | else:
56 | raise value
57 |
58 | return wrapper
59 |
60 | return decorator
61 |
62 |
63 | def has_code(response):
64 | pattern = r"```(?:[a-zA-Z]*)\n(.*?)```"
65 | # Use re.DOTALL to match multiline content inside backticks
66 | matches = re.findall(pattern, response, re.DOTALL)
67 | # print(matches)
68 | return matches
69 |
--------------------------------------------------------------------------------
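A minimal sketch of the `timeout` decorator above, which is how `NUMINATaskHandler.check_correctness` guards its scoring call. The child is a plain `multiprocessing.Process` running a nested function, so this pattern assumes a fork-based start method; the function below is made up for illustration:

```python
import time

from evals.util.common import TimeoutException, timeout  # assumes repo root is on sys.path


@timeout(2)  # abort the call if it runs longer than 2 seconds
def slow_square(x):
    time.sleep(5)  # deliberately exceeds the limit
    return x * x


try:
    slow_square(3)
except TimeoutException as exc:
    print(exc)  # Function 'slow_square' timed out after 2 seconds!
```
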
/evals/util/metrics.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | from collections import defaultdict
4 | from typing import Dict, List
5 |
6 | import numpy as np
7 |
8 |
9 | def _pass_at_k(n, c, k):
10 | """
11 | :param n: total number of samples
12 | :param c: number of correct samples
13 | :param k: k in pass@$k$
14 | """
15 | if n - c < k:
16 | return 1.0
17 | return float(1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))
18 |
19 |
20 | def pass_at_k(N: int, id_to_scores: Dict[str, List[int]]):
21 | final_passk_scores = {}
22 | k_to_passk_scores = defaultdict(list) # k -> list of scores
23 | for _, sample_scores in id_to_scores.items():
24 | # Start at N
25 | k = N
26 | is_power_of_2 = N == 2 ** (int(math.log2(N)))
27 | while k > 0:
28 | # calculate pass @ k
29 | num_correct = np.sum(sample_scores)
30 | pass_k = _pass_at_k(N, num_correct, k)
31 | k_to_passk_scores[k].append(pass_k)
32 | # corner case: when N is not a power of 2
33 | if not is_power_of_2 and k == N:
34 | k = 2 ** (int(math.log2(N)))
35 | else:
36 | # otherwise, just divide by 2
37 | k = k // 2
38 |
39 | for k in k_to_passk_scores:
40 | final_passk_scores[f"{k=}"] = round(np.mean(k_to_passk_scores[k]) * 100, 3)
41 |
42 | # print("Final pass @ k:")
43 | for k, s in final_passk_scores.items():
44 | logging.info(f"k: {k}, pass @ k: {s}")
45 | return final_passk_scores
46 |
--------------------------------------------------------------------------------
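`_pass_at_k` is the standard unbiased pass@k estimator, 1 - C(n-c, k) / C(n, k), written in its numerically stable product form; `pass_at_k` averages it over problems for k = N and for each power of two below N down to 1. A small worked example with made-up per-problem scores:

```python
from evals.util.metrics import pass_at_k  # assumes repo root is on sys.path

# 1 = correct sample, 0 = incorrect sample; problem ids and scores are made up.
id_to_scores = {
    "problem-1": [1, 0, 0, 1],  # 2 of 4 generations correct
    "problem-2": [0, 0, 0, 0],  # none correct
}
print(pass_at_k(N=4, id_to_scores=id_to_scores))
# approximately {'k=4': 50.0, 'k=2': 41.667, 'k=1': 25.0}
```
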
/evals/util/response.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Optional
3 |
4 |
5 | @dataclass
6 | class Response:
7 | response: List[str]
8 | num_completion_tokens: List[int]
9 | num_input_tokens: int
10 | index: Optional[int] = None
11 |
12 | @classmethod
13 | def from_ray_response(cls, response) -> "Response":
14 | """
15 | Factory method to create a Response instance from a rayllm response.
16 |
17 | Args:
18 | response: Ray response object containing generated text and token information
19 |
20 | Returns:
21 |             Response: New instance initialized with Ray response data
22 | """
23 |
24 | if isinstance(response["generated_text"], list):
25 | # n > 1 samples
26 | response_texts = response["generated_text"]
27 | num_completion_tokens = [
28 | int(response["num_generated_tokens"][i])
29 | for i in range(len(response["num_generated_tokens"]))
30 | ]
31 | else:
32 | response_texts = [response["generated_text"]]
33 | num_completion_tokens = [int(response["num_generated_tokens"])]
34 | return cls(
35 | response=response_texts,
36 | num_completion_tokens=num_completion_tokens,
37 | num_input_tokens=int(response["num_input_tokens"]),
38 | index=response["index"],
39 | )
40 |
41 | @classmethod
42 | def from_openai_response(cls, response) -> "Response":
43 | """
44 | Factory method to create a Response instance from an OpenAI response.
45 |
46 | Args:
47 | response: OpenAI response object containing message content and token information
48 |
49 | Returns:
50 |             Response: New instance initialized with OpenAI response data
51 | """
52 | return cls(
53 | response=[
54 | response.choices[i].message.content
55 | for i in range(len(response.choices))
56 | ],
57 | num_completion_tokens=[
58 | response.usage.completion_tokens if i == 0 else 0
59 | for i in range(len(response.choices))
60 | ],
61 | num_input_tokens=response.usage.prompt_tokens,
62 | )
63 |
64 | @classmethod
65 | def from_vllm_response(cls, response) -> "Response":
66 | """
67 | Factory method to create a Response instance from a vLLM response.
68 |
69 | Args:
70 | response: vLLM response object containing output text and token information
71 |
72 | Returns:
73 |             Response: New instance initialized with vLLM response data
74 | """
75 | response_texts = [
76 | response.outputs[i].text for i in range(len(response.outputs))
77 | ]
78 | num_completion_tokens = [
79 | len(response.outputs[i].token_ids) for i in range(len(response.outputs))
80 | ]
81 | return cls(
82 | response=response_texts,
83 | num_completion_tokens=num_completion_tokens,
84 | num_input_tokens=len(response.prompt_token_ids),
85 | )
86 |
87 |
88 | @dataclass
89 | class SingleParsedResponse:
90 | content: str
91 | correctness: Optional[bool] = None
92 | reason: Optional[str] = None
93 |
94 | def to_dict(self):
95 | return {
96 | "content": self.content,
97 | "correctness": self.correctness,
98 | "reason": self.reason,
99 | }
100 |
--------------------------------------------------------------------------------
/evals/util/results.py:
--------------------------------------------------------------------------------
1 | import json
2 | from dataclasses import asdict, dataclass
3 | from pathlib import Path
4 | from typing import Any, Dict, Optional
5 |
6 |
7 | @dataclass
8 | class SummaryResults:
9 | # configuration: Dict[str, Any]
10 | # total_completion_tokens: int = 0
11 | # avg_completion_tokens: float = 0
12 | # total_prompt_tokens: int = 0
13 | # avg_prompt_tokens: float = 0
14 | accuracy: float = 0.0
15 | pass_at_k: Optional[Dict[str, float]] = None
16 | # mean_of_stdevs: float = None
17 | # run_level_stdev: float = None
18 |
19 | def to_json_dict(self) -> Dict[str, Any]:
20 | """Convert to a JSON-compatible dictionary."""
21 | return asdict(self)
22 |
23 |
24 | def save_summary(summary_path: Path, summary: SummaryResults) -> None:
25 | with open(summary_path, "w", encoding="utf-8") as f:
26 | json.dump(summary.to_json_dict(), f, indent=4)
27 |
--------------------------------------------------------------------------------
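A quick sketch of writing a run summary with the helpers above; the accuracy, pass@k values, and file name are illustrative:

```python
from pathlib import Path

from evals.util.results import SummaryResults, save_summary  # assumes repo root is on sys.path

summary = SummaryResults(accuracy=0.875, pass_at_k={"k=1": 87.5})
save_summary(Path("summary.json"), summary)
# summary.json now holds {"accuracy": 0.875, "pass_at_k": {"k=1": 87.5}}
```
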
/figures/reproduciblellm_fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nanomaoli/llm_reproducibility/8a373c5a159a27e59783394827cecadd6255484e/figures/reproduciblellm_fig1.png
--------------------------------------------------------------------------------
/patch_vllm.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Tuple
2 | 
3 | import torch
4 | import torch.nn.functional as F
5 | 
6 | from vllm.model_executor.layers.rotary_embedding import get_rope, RotaryEmbedding
7 | from vllm.model_executor.layers.linear import UnquantizedLinearMethod
8 | from vllm.model_executor.layers.vocab_parallel_embedding import UnquantizedEmbeddingMethod
9 | from vllm.attention import get_attn_backend
10 | from vllm.worker.cache_engine import CacheEngine
11 | from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
12 |                         get_dtype_size, is_pin_memory_available)
13 | 
14 | from vllm.distributed import get_pp_group
15 | from vllm.model_executor.models.qwen2 import Qwen2Model, LogitsProcessor, get_sampler, ParallelLMHead, PPMissingLayer, maybe_prefix
16 | from vllm.attention import Attention
17 | # from vllm.v1.worker.gpu_model_runner
18 | import sys
19 | import pdb
20 | # from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
21 | # STR_DTYPE_TO_TORCH_DTYPE["float32"] = torch.float32
22 | # from vllm.config import CacheConfig
23 |
24 |
25 | class ForkedPdb(pdb.Pdb):
26 | """
27 | PDB Subclass for debugging multi-processed code
28 | Suggested in: https://stackoverflow.com/questions/4716533/how-to-attach-debugger-to-a-python-subproccess
29 | """
30 | def interaction(self, *args, **kwargs):
31 | _stdin = sys.stdin
32 | try:
33 | sys.stdin = open('/dev/stdin')
34 | pdb.Pdb.interaction(self, *args, **kwargs)
35 | finally:
36 | sys.stdin = _stdin
37 |
38 | def convert_linear_weights_to_fp16(model: torch.nn.Module):
39 | """Convert weights of linear layers to fp16 for storage."""
40 | for name, module in model.named_modules():
41 | if 'proj' in name:
42 | module.weight.data = module.weight.data.to(torch.float16)
43 | if module.bias is not None:
44 | module.bias.data = module.bias.data.to(torch.float16)
45 |
46 | def convert_linear_weights_to_bfloat16(model: torch.nn.Module):
47 | """Convert weights of linear layers to bfloat16 for storage."""
48 | for name, module in model.named_modules():
49 | if 'proj' in name:
50 | module.weight.data = module.weight.data.to(torch.bfloat16)
51 | if module.bias is not None:
52 | module.bias.data = module.bias.data.to(torch.bfloat16)
53 |
54 | def our_attn_forward(
55 | self,
56 | positions: torch.Tensor,
57 | hidden_states: torch.Tensor,
58 | ) -> torch.Tensor:
59 | # Input is already in fp32 from previous layer
60 | qkv, _ = self.qkv_proj(hidden_states)
61 | q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
62 | q, k = self.rotary_emb(positions, q, k)
63 | attn_output = self.attn(q, k, v)
64 | output, _ = self.o_proj(attn_output)
65 | return output # Keep in fp32
66 |
67 | def our_fp32_rope_forward_cuda(
68 | self,
69 | positions: torch.Tensor,
70 | query: torch.Tensor,
71 | key: torch.Tensor,
72 | offsets: Optional[torch.Tensor] = None,
73 | ) -> Tuple[torch.Tensor, torch.Tensor]:
74 | from vllm import _custom_ops as ops
75 | # Everything is already in fp32, no need for conversion
76 | if self.cos_sin_cache.device != query.device:
77 | self.cos_sin_cache = self.cos_sin_cache.to(query.device)
78 |
79 | if offsets is not None:
80 | ops.batched_rotary_embedding(positions, query, key, self.head_size,
81 | self.cos_sin_cache,
82 | self.is_neox_style, self.rotary_dim,
83 | offsets)
84 | else:
85 | ops.rotary_embedding(positions, query, key, self.head_size,
86 | self.cos_sin_cache, self.is_neox_style)
87 | return query, key
88 |
89 |
90 | def our_linear_apply(self,
91 | layer: torch.nn.Module,
92 | x: torch.Tensor,
93 | bias: Optional[torch.Tensor] = None) -> torch.Tensor:
94 | # x is already in fp32
95 | assert x.dtype == torch.float32
96 | # Upcast weights to fp32 for computation
97 | weight = layer.weight.to(torch.float32)
98 | if bias is not None:
99 | bias = bias.to(torch.float32)
100 | return F.linear(x, weight, bias) # Result stays in fp32
101 |
102 | def patch_cache_engine():
103 | original_init = CacheEngine.__init__
104 | def custom_cache_engine_init(
105 | self,
106 | cache_config,
107 | model_config,
108 | parallel_config,
109 | device_config,
110 | ) -> None:
111 | self.cache_config = cache_config
112 | self.model_config = model_config
113 | self.parallel_config = parallel_config
114 | self.device_config = device_config
115 |
116 | self.head_size = model_config.get_head_size()
117 |         # Models like Jamba have mixed-typed layers, e.g. Mamba
118 | self.num_attention_layers = model_config.get_num_layers_by_block_type(
119 | parallel_config, LayerBlockType.attention)
120 | self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
121 |
122 | self.block_size = cache_config.block_size
123 | self.num_gpu_blocks = cache_config.num_gpu_blocks
124 | if self.num_gpu_blocks:
125 | self.num_gpu_blocks //= parallel_config.pipeline_parallel_size
126 | self.num_cpu_blocks = cache_config.num_cpu_blocks
127 | if self.num_cpu_blocks:
128 | self.num_cpu_blocks //= parallel_config.pipeline_parallel_size
129 |
130 | self.dtype = torch.float32 # Force fp32 for cache
131 |
132 | # Get attention backend.
133 | self.attn_backend = get_attn_backend(self.head_size,
134 | model_config.dtype,
135 | cache_config.cache_dtype,
136 | self.block_size,
137 | model_config.is_attention_free,
138 | use_mla=model_config.use_mla)
139 |
140 | # Initialize the cache.
141 | self.gpu_cache = self._allocate_kv_cache(
142 | self.num_gpu_blocks, self.device_config.device_type)
143 | self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")
144 |
145 | @staticmethod
146 | def our_get_cache_block_size(
147 | cache_config,
148 | model_config,
149 | parallel_config,
150 | ) -> int:
151 | head_size = model_config.get_head_size()
152 | num_heads = model_config.get_num_kv_heads(parallel_config)
153 | num_attention_layers = model_config.get_num_layers_by_block_type(
154 | parallel_config, LayerBlockType.attention)
155 |
156 | dtype = torch.float32 # Force fp32 for cache
157 | key_cache_entry = num_heads * head_size
158 |
159 |         # For MLA there is no value cache, since the latent vector
160 |         # jointly encodes the keys and values.
161 | value_cache_entry = key_cache_entry if not model_config.use_mla else 0
162 | total = num_attention_layers * cache_config.block_size * \
163 | (key_cache_entry + value_cache_entry)
164 |
165 | dtype_size = get_dtype_size(dtype)
166 | return dtype_size * total
167 |
168 | CacheEngine.__init__ = custom_cache_engine_init
169 | CacheEngine.get_cache_block_size = our_get_cache_block_size
170 |
171 |
172 | def patch_qwen2_vllm():
173 | # from vllm.platforms import _Backend
174 | # from vllm.attention.selector import global_force_attn_backend
175 | # global_force_attn_backend(_Backend.XFORMERS)
176 | import os
177 | os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS" # FLASHINFER
178 | from vllm.model_executor.models.qwen2 import Qwen2Attention, Qwen2ForCausalLM
179 |
180 | patch_cache_engine()
181 |
182 | def new_qwen2_lm_init(self, *, vllm_config, prefix: str = ""):
183 | torch.nn.Module.__init__(self)
184 | config = vllm_config.model_config.hf_config
185 | quant_config = vllm_config.quant_config
186 | lora_config = vllm_config.lora_config
187 |
188 | self.config = config
189 | self.lora_config = lora_config
190 | self.quant_config = quant_config
191 |
192 | self.model = Qwen2Model(vllm_config=vllm_config,
193 | prefix=maybe_prefix(prefix, "model"))
194 |
195 | if get_pp_group().is_last_rank:
196 | if config.tie_word_embeddings:
197 | self.lm_head = self.model.embed_tokens
198 | else:
199 | self.lm_head = ParallelLMHead(config.vocab_size,
200 | config.hidden_size,
201 | quant_config=quant_config,
202 | prefix=maybe_prefix(
203 | prefix, "lm_head"))
204 | else:
205 | self.lm_head = PPMissingLayer()
206 |
207 |         # Convert linear weights to bfloat16 for storage
208 | convert_linear_weights_to_bfloat16(self.model)
209 | if not isinstance(self.lm_head, PPMissingLayer):
210 | convert_linear_weights_to_bfloat16(self.lm_head)
211 |
212 | self.logits_processor = LogitsProcessor(config.vocab_size, scale=1.2)
213 | self.sampler = get_sampler()
214 | self.make_empty_intermediate_tensors = (
215 | self.model.make_empty_intermediate_tensors)
216 |
217 | Qwen2ForCausalLM.__init__ = new_qwen2_lm_init
218 |
219 | # Store the original __init__
220 | original_init = Qwen2Attention.__init__
221 | def new_qwen2_init(self, *args, **kwargs):
222 | # Call the original init first
223 | original_init(self, *args, **kwargs)
224 | self.rotary_emb = get_rope(
225 | self.head_dim,
226 | rotary_dim=self.head_dim,
227 | max_position=kwargs['max_position'],
228 | base=self.rope_theta,
229 | rope_scaling=kwargs['rope_scaling'],
230 | dtype=torch.float32 # RoPE computation in fp32
231 | )
232 |
233 | Qwen2Attention.__init__ = new_qwen2_init
234 | # Replace the apply method
235 | UnquantizedLinearMethod.apply = our_linear_apply
236 | UnquantizedEmbeddingMethod.apply = our_linear_apply
237 | RotaryEmbedding.forward_cuda = our_fp32_rope_forward_cuda
238 | Qwen2Attention.forward = our_attn_forward
239 |     print("Patched vLLM: Model loaded in fp32, linear weights stored in bf16, all computations in fp32")
--------------------------------------------------------------------------------
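A sketch of how the patch above could be applied; it assumes the script runs from the repository root and that the patch is invoked before the vLLM engine is constructed (the repo's own entry points may wire this differently). The model name and sampling settings are illustrative:

```python
# Apply the monkey-patches first: XFORMERS backend, fp32 KV cache,
# fp32 linear/RoPE computation, bf16 storage for linear weights.
from patch_vllm import patch_qwen2_vllm

patch_qwen2_vllm()

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-7B-Instruct", dtype="float32")  # illustrative model
outputs = llm.generate(
    ["What is 2 + 2?"], SamplingParams(temperature=0.0, max_tokens=64)
)
print(outputs[0].outputs[0].text)
```
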
/prompt_util/prompt_template.py:
--------------------------------------------------------------------------------
1 |
2 | def make_conversation_from_contents(
3 | contents,
4 | system_prompt=None,
5 | user_template=None,
6 | assistant_prefill=None,
7 | ):
8 | """Makes a conversation given a list of user/assistant message strings.
9 |
10 | If system_prompt is provided, it will be added as the first message.
11 |     If user_template is provided, it will be used to format the user messages. This is useful for model-specific formatting.
12 |     If assistant_prefill is provided, it is appended as a trailing assistant message to prefill the model's response.
13 | Args:
14 |         contents: A list of user/assistant message strings.
15 | system_prompt: An optional string for the system prompt.
16 | user_template: An optional string for the user template.
17 |
18 | Returns:
19 | A list of dictionaries representing the conversation.
20 | """
21 |
22 | conversation = []
23 | if system_prompt:
24 | conversation.append({"role": "system", "content": system_prompt})
25 |
26 | for i, content in enumerate(contents):
27 | if i % 2 == 0:
28 | content = user_template.format(content) if user_template else content
29 | conversation.append({"role": "user", "content": content})
30 | else:
31 | conversation.append({"role": "assistant", "content": content})
32 |
33 | if assistant_prefill and conversation[-1]["role"] == "user":
34 | conversation.append({"role": "assistant", "content": assistant_prefill})
35 |
36 | return conversation
--------------------------------------------------------------------------------
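A usage sketch for `make_conversation_from_contents`; the system prompt, user template, and assistant prefill are illustrative:

```python
from prompt_util.prompt_template import (  # assumes repo root is on sys.path
    make_conversation_from_contents,
)

conversation = make_conversation_from_contents(
    ["What is 2 + 2?"],
    system_prompt="You are a helpful assistant.",
    user_template="Return your final response within \\boxed{{}}. {}",
    assistant_prefill="<think>",
)
# [{'role': 'system', 'content': 'You are a helpful assistant.'},
#  {'role': 'user', 'content': 'Return your final response within \\boxed{}. What is 2 + 2?'},
#  {'role': 'assistant', 'content': '<think>'}]
```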