├── .gitignore ├── README.md ├── accelerate_configs ├── deepspeed_zero3.yaml └── deepspeed_zero3_cpu.yaml ├── app.py ├── bash_scrips ├── qwen2_72b_instruct_step_dpo.sh ├── qwen2_72b_step_dpo.sh └── qwen2_7b_step_dpo.sh ├── configs └── config_full.yaml ├── data └── test │ ├── GSM8K_test_data.jsonl │ └── MATH_test_data.jsonl ├── data_pipeline ├── generate_dataset.py ├── locate_error_by_gpt4.py ├── merge.sh ├── predictions │ └── sample.json ├── prepare_for_correction.py ├── step1.sh ├── step2.sh └── step3.sh ├── eval_math.py ├── eval_results ├── gsm8k │ └── sample.json └── math │ ├── qwen2-7b-dpo-v3-continue-from-incorrect-fix-part1-filtered+2-filtered+aqua-filtered-rej-original_acc0.6-topk1-beta0.5-8ep-fixbug-fixeos-bf16-keywords-fix.json │ └── sample.json ├── evaluation ├── data_processing │ ├── answer_extraction.py │ └── process_utils.py └── eval │ ├── eval_script.py │ ├── eval_utils.py │ ├── ocwcourses_eval_utils.py │ ├── python_executor.py │ └── utils.py ├── imgs ├── .DS_Store ├── coreidea.png ├── example1.png ├── example2.png ├── example3.png ├── example4.png ├── example5.jpg ├── summary.jpg └── triangle.png ├── licenses ├── DATA_LICENSE ├── LICENSE └── WEIGHT_LICENSE ├── paper └── paper.pdf ├── requirements.txt ├── stepdpo_trainer.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__ 2 | wandb/ 3 | outputs/ 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | ![image](imgs/coreidea.png) 3 | # Step-DPO: Step-wise Preference Optimization for Long-chain Reasoning of LLMs 4 | [Xin Lai](https://scholar.google.com/citations?user=tqNDPA4AAAAJ&hl), 5 | [Zhuotao Tian](https://scholar.google.com/citations?user=mEjhz-IAAAAJ&hl), 6 | [Yukang Chen](https://scholar.google.com/citations?user=6p0ygKUAAAAJ&hl), 7 | [Senqiao 
Yang](https://scholar.google.com/citations?user=NcJc-RwAAAAJ&hl), 8 | [Xiangru Peng](xxxx), 9 | [Jiaya Jia](https://scholar.google.com/citations?user=XPAkzTEAAAAJ&hl=en) 10 | 11 | 12 | [![](https://img.shields.io/badge/Models-HuggingFace-pink)](https://huggingface.co/collections/xinlai/step-dpo-6682e12dfbbb2917c8161df7) 13 | [![](https://img.shields.io/badge/Dataset-Math--Step--DPO--10K-blue)](https://huggingface.co/datasets/xinlai/Math-Step-DPO-10K) 14 | [![](https://img.shields.io/badge/Paper-arXiv%20Link-green)](https://arxiv.org/pdf/2406.18629) 15 | [![](https://img.shields.io/badge/Demo-Huggingface-yellow)](http://103.170.5.190:7870/) 16 | 17 | [![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-yellow.svg)](licenses/LICENSE) 18 | [![Data License](https://img.shields.io/badge/Data%20License-CC%20By%20NC%204.0-orange.svg)](licenses/DATA_LICENSE) 19 | [![Weight License](https://img.shields.io/badge/Weight%20License-CC%20By%20NC%204.0-red)](licenses/WEIGHT_LICENSE) 20 | 21 | This repo provides the implementation of **Step-DPO**, a simple, effective, and data-efficient method for boosting the long-chain reasoning ability of LLMs, with **a data construction pipeline** that yields a **high-quality dataset** containing 10K step-wise preference pairs. 22 | 23 | Notably, **Step-DPO** boosts the performance of **Qwen2-7B-Instruct** from **53.0%** to **58.6%** on MATH, and from **85.5%** to **87.9%** on GSM8K, with as few as **10K preference pairs** and **hundreds of training steps**! 24 | 25 | Moreover, **Step-DPO**, when applied to **Qwen2-72B-Instruct**, achieves scores of **70.8%** and **94.0%** on the test sets of **MATH** and **GSM8K**, respectively, **surpassing a series of closed-source models** without bells and whistles, including GPT-4-1106, Claude-3-Opus, and Gemini-1.5-Pro. 26 | 27 | ![image](imgs/summary.jpg) 28 | 29 | ## TABLE OF CONTENTS 30 | 1. [News](#news) 31 | 2. [Datasets](#datasets) 32 | 3. [Models](#models) 33 | 4. 
[Installation](#installation) 34 | 5. [Training](#training) 35 | 6. [Evaluation](#evaluation) 36 | 7. [Data Construction Pipeline](#data-construction-pipeline) 37 | 8. [Deployment](#deployment) 38 | 9. [Examples](#examples) 39 | 10. [Acknowledgement](#acknowledgement) 40 | 11. [Citation](#citation) 41 | 42 | ## News 43 | - [x] [2024.7.7] We release the scripts for the [Data Construction Pipeline](#data-construction-pipeline)! You can construct the dataset on your own with these scripts! 44 | - [x] [2024.7.1] We release the demo of the model [Qwen2-7B-Instruct-Step-DPO](https://huggingface.co/xinlai/Qwen2-7B-Instruct-Step-DPO). Welcome to try it in the [Demo](http://103.170.5.190:7870/)! 45 | - [x] [2024.6.28] We release the pre-print of [Step-DPO](https://arxiv.org/pdf/2406.18629) and this GitHub repo, including training/evaluation scripts, pre-trained models, and data. 46 | 47 | ## Datasets 48 | 49 | We build a 10K math preference dataset for Step-DPO, which can be downloaded from the link below. 50 | 51 | | Dataset | Size | Link | 52 | | ------------------------ | ------ | ------------------------------------------------------------ | 53 | | xinlai/Math-Step-DPO-10K | 10,795 | 🤗 [Hugging Face](https://huggingface.co/datasets/xinlai/Math-Step-DPO-10K) | 54 | 55 | ## Models 56 | 57 | Notably, **Qwen2-72B-Instruct + Step-DPO** achieves **70.8%** and **94.0%** on the MATH and GSM8K test sets, respectively. Step-DPO also brings considerable improvements to the various models listed below. Welcome to download and use them. 
58 | 59 | | Models | Size | MATH | GSM8K | Odyssey-MATH | Link | 60 | | :------------------------------ | :--: | :----: | :---: | :---: | :----------------------------------------------------------: | 61 | | Qwen2-7B-Instruct | 7B | 53.0 | 85.5 | - | - | 62 | | **Qwen2-7B-Instruct + Step-DPO** | 7B | **58.6 (+5.6)** | **87.9 (+2.4)** | - | 🤗 [HF](https://huggingface.co/xinlai/Qwen2-7B-Instruct-Step-DPO) | 63 | | DeepSeekMath-RL | 7B | 51.7 | 88.2 | - | - | 64 | | **DeepSeekMath-RL + Step-DPO** | 7B | **53.2 (+1.5)** | **88.7 (+0.5)** | - | 🤗 [HF](https://huggingface.co/xinlai/DeepSeekMath-RL-Step-DPO) | 65 | | Qwen2-7B-SFT | 7B | 54.8 | 88.2 | - | 🤗 [HF](https://huggingface.co/xinlai/Qwen2-7B-SFT) | 66 | | **Qwen2-7B-SFT + Step-DPO** | 7B | **55.8 (+1.0)** | **88.5 (+0.3)** | - |🤗 [HF](https://huggingface.co/xinlai/Qwen2-7B-SFT-Step-DPO) | 67 | | Qwen1.5-32B-SFT | 32B | 54.9 | 90.0 | - | 🤗 [HF](https://huggingface.co/xinlai/Qwen1.5-32B-SFT) | 68 | | **Qwen1.5-32B-SFT + Step-DPO** | 32B | **56.9 (+2.0)** | **90.9 (+0.9)** | - |🤗 [HF](https://huggingface.co/xinlai/Qwen1.5-32B-SFT-Step-DPO) | 69 | | Qwen2-57B-A14B-SFT | 57B | 54.6 | 89.8 | - | 🤗 [HF](https://huggingface.co/xinlai/Qwen2-57B-A14B-SFT) | 70 | | **Qwen2-57B-A14B-SFT + Step-DPO** | 57B | **56.5 (+1.9)** | **90.0 (+0.2)** | - |🤗 [HF](https://huggingface.co/xinlai/Qwen2-57B-A14B-SFT-Step-DPO) | 71 | | Llama-3-70B-SFT | 70B | 56.9 | 92.2 | - | 🤗 [HF](https://huggingface.co/xinlai/Llama-3-70B-SFT) | 72 | | **Llama-3-70B-SFT + Step-DPO** | 70B | **59.5 (+2.6)** | **93.3 (+1.1)** | - |🤗 [HF](https://huggingface.co/xinlai/Llama-3-70B-SFT-Step-DPO) | 73 | | Qwen2-72B-SFT | 72B | 61.7 | 92.9 | 44.2 | 🤗 [HF](https://huggingface.co/xinlai/Qwen2-72B-SFT) | 74 | | **Qwen2-72B-SFT + Step-DPO** | 72B | **64.7 (+3.0)** | **93.9 (+1.0)** | **47.0 (+2.8)** | 🤗 [HF](https://huggingface.co/xinlai/Qwen2-72B-SFT-Step-DPO) | 75 | | Qwen2-72B-Instruct | 72B | 69.4 | 92.4 | 47.0 | - | 76 | | **Qwen2-72B-Instruct + Step-DPO** | 72B 
| **70.8 (+1.4)** | **94.0 (+1.6)** | **50.1 (+3.1)** | 🤗 [HF](https://huggingface.co/xinlai/Qwen2-72B-Instruct-Step-DPO) | 77 | 78 | Note: **Odyssey-MATH** contains competition-level math problems. 79 | 80 | ## Installation 81 | ``` 82 | conda create -n step_dpo python=3.10 83 | conda activate step_dpo 84 | 85 | pip install -r requirements.txt 86 | ``` 87 | 88 | ## Training 89 | 90 | ### Pre-trained weights 91 | We use Qwen2, Qwen1.5, Llama-3, and DeepSeekMath models as the pre-trained weights and fine-tune them with Step-DPO. Download them based on your choice. 92 | 93 | | Pre-trained weights | 94 | |:---------------------------------------------------------------------------| 95 | | [Qwen/Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct) | 96 | | [deepseek-ai/deepseek-math-7b-rl](https://huggingface.co/deepseek-ai/deepseek-math-7b-rl) | 97 | | [xinlai/Qwen2-7B-SFT](https://huggingface.co/xinlai/Qwen2-7B-SFT) | 98 | | [xinlai/Qwen1.5-32B-SFT](https://huggingface.co/xinlai/Qwen1.5-32B-SFT) | 99 | | [xinlai/Qwen2-57B-A14B-SFT](https://huggingface.co/xinlai/Qwen2-57B-A14B-SFT) | 100 | | [xinlai/Llama-3-70B-SFT](https://huggingface.co/xinlai/Llama-3-70B-SFT) | 101 | | [xinlai/Qwen2-72B-SFT](https://huggingface.co/xinlai/Qwen2-72B-SFT) | 102 | | [Qwen/Qwen2-72B-Instruct](https://huggingface.co/Qwen/Qwen2-72B-Instruct) | 103 | 104 | **Note**: models with '-SFT' are open-source base models supervised fine-tuned on our 299K SFT data. You can perform Step-DPO on either our SFT models or existing open-source instruct models. 
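Each record in **Math-Step-DPO-10K** pairs a shared problem and reasoning prefix with a chosen (corrected) and a rejected (erroneous) next step, so the preference loss is applied only at the step where the two continuations diverge. The sketch below illustrates this with field names taken from `data_pipeline/generate_dataset.py` in this repo; the exact schema of the published dataset may differ, and the math problem shown is purely illustrative:

```python
# Illustrative record: field names mirror data_pipeline/generate_dataset.py,
# but the published dataset schema may differ slightly.
record = {
    "dataset": "MATH",
    "prompt": "What is $1+2+\\cdots+10$?",
    "prefix": ("Let's think step by step.\n"
               "Step 1: The sum of the first n positive integers is n(n+1)/2.\n"
               "Step 2:"),
    "chosen": " With n = 10, the sum is 10*11/2 = 55. The answer is: 55",
    "rejected": " With n = 10, the sum is 10*11/2 = 65. The answer is: 65",
}

def build_pair(record):
    """Concatenate the shared prefix with each candidate step, as a
    step-wise DPO trainer would before scoring log-probabilities."""
    chosen_seq = record["prefix"] + record["chosen"]
    rejected_seq = record["prefix"] + record["rejected"]
    return chosen_seq, rejected_seq

chosen_seq, rejected_seq = build_pair(record)
# Both sequences share everything up to "Step 2:" and differ only afterwards,
# which is what localizes the preference signal to a single reasoning step.
assert chosen_seq.startswith(record["prefix"])
assert rejected_seq.startswith(record["prefix"])
assert chosen_seq != rejected_seq
```

Because only the first divergent step is contrasted (rather than the full solutions), the model receives feedback pinpointing exactly where the erroneous reasoning begins.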
105 | 106 | Here is a script example to perform Step-DPO on `Qwen/Qwen2-72B-Instruct`: 107 | 108 | ```shell 109 | ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3_cpu.yaml --mixed_precision bf16 \ 110 | --num_processes 8 \ 111 | train.py configs/config_full.yaml \ 112 | --model_name_or_path="Qwen/Qwen2-72B-Instruct" \ 113 | --data_path="xinlai/Math-Step-DPO-10K" \ 114 | --per_device_train_batch_size=2 \ 115 | --gradient_accumulation_steps=8 \ 116 | --torch_dtype=bfloat16 \ 117 | --bf16=True \ 118 | --beta=0.4 \ 119 | --num_train_epochs=4 \ 120 | --save_strategy='steps' \ 121 | --save_steps=200 \ 122 | --save_total_limit=1 \ 123 | --output_dir=outputs/qwen2-72b-instruct-step-dpo \ 124 | --hub_model_id=qwen2-72b-instruct-step-dpo \ 125 | --prompt=qwen2-boxed 126 | ``` 127 | 128 | ## Evaluation 129 | 130 | Here are script examples to evaluate fine-tuned models on both GSM8K and MATH test sets: 131 | ``` 132 | python eval_math.py \ 133 | --model outputs/qwen2-72b-instruct-step-dpo \ 134 | --data_file ./data/test/GSM8K_test_data.jsonl \ 135 | --save_path 'eval_results/gsm8k/qwen2-72b-instruct-step-dpo.json' \ 136 | --prompt 'qwen2-boxed' \ 137 | --tensor_parallel_size 8 138 | ``` 139 | 140 | ``` 141 | python eval_math.py \ 142 | --model outputs/qwen2-72b-instruct-step-dpo \ 143 | --data_file ./data/test/MATH_test_data.jsonl \ 144 | --save_path 'eval_results/math/qwen2-72b-instruct-step-dpo.json' \ 145 | --prompt 'qwen2-boxed' \ 146 | --tensor_parallel_size 8 147 | ``` 148 | 149 | ## Data Construction Pipeline 150 | 151 | We release the scripts to construct the Step-DPO data, as shown in the `data_pipeline/` directory. Please follow the instructions below. 
152 | 153 | ``` 154 | cd Step-DPO 155 | 156 | # Step 1: Error Collection 157 | # Before executing, please set the MODEL_PATH, PRED_PATH, EVAL_PROMPT 158 | bash data_pipeline/step1.sh 159 | 160 | # Step 2: Locate Erroneous Step by GPT-4o 161 | # Before executing, please set the OPENAI_BASE_URL, OPENAI_API_KEY 162 | bash data_pipeline/step2.sh 163 | 164 | # Step 3: Rectify by the model itself 165 | # Before executing, please set the MODEL_PATH, EVAL_PROMPT, JSON_FILE, PRED_PATH, SAVE_PATH 166 | bash data_pipeline/step3.sh 167 | 168 | # Finally, Get the resulting dataset 169 | # Before executing, please set the EVAL_PROMPT, JSON_FILE, PRED_PATH, SAVE_PATH 170 | bash data_pipeline/merge.sh 171 | ``` 172 | 173 | ## Deployment 174 | 175 | For deployment, please directly use the following command: 176 | ``` 177 | python3 app.py --model_path_or_name xinlai/Qwen2-7B-Instruct-Step-DPO 178 | ``` 179 | 180 | 181 | ## Examples 182 | 183 | ![image](imgs/example5.jpg) 184 | 185 | ![image](imgs/example1.png) 186 | 187 | ![image](imgs/example4.png) 188 | 189 | ![image](imgs/example2.png) 190 | 191 | ## Acknowledgement 192 | 193 | This repository is based on [alignment-handbook](https://github.com/huggingface/alignment-handbook), [DeepSeekMath](https://github.com/deepseek-ai/DeepSeek-Math), and [MetaMath](https://github.com/meta-math/MetaMath). 194 | 195 | Many thanks for their efforts! 
196 | 197 | ## Citation 198 | If you find this project useful in your research, please consider citing us: 199 | 200 | ``` 201 | @article{lai2024stepdpo, 202 | title={Step-DPO: Step-wise Preference Optimization for Long-chain Reasoning of LLMs}, 203 | author={Xin Lai and Zhuotao Tian and Yukang Chen and Senqiao Yang and Xiangru Peng and Jiaya Jia}, 204 | journal={arXiv:2406.18629}, 205 | year={2024} 206 | } 207 | ``` 208 | -------------------------------------------------------------------------------- /accelerate_configs/deepspeed_zero3.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: true 8 | zero3_save_16bit_model: true 9 | zero_stage: 3 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: no #bf16 15 | num_machines: 1 16 | num_processes: 8 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /accelerate_configs/deepspeed_zero3_cpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: cpu 6 | offload_param_device: cpu 7 | zero3_init_flag: true 8 | zero3_save_16bit_model: true 9 | zero_stage: 3 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: no #bf16 15 | num_machines: 1 16 | num_processes: 8 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: 
false 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import os 4 | import sys 5 | 6 | import bleach 7 | import gradio as gr 8 | import torch 9 | import transformers 10 | 11 | 12 | def parse_args(args): 13 | parser = argparse.ArgumentParser(description='LISA chat') 14 | parser.add_argument('--model_path_or_name', default='') 15 | parser.add_argument('--save_path', default='/data/step_dpo_history') 16 | return parser.parse_args(args) 17 | 18 | args = parse_args(sys.argv[1:]) 19 | os.makedirs(args.save_path, exist_ok=True) 20 | 21 | # Create model 22 | tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_path_or_name) 23 | model = transformers.AutoModelForCausalLM.from_pretrained(args.model_path_or_name, torch_dtype=torch.bfloat16, device_map="auto") 24 | 25 | # Gradio 26 | examples = [ 27 | ['Suppose that $h(x)=f^{-1}(x)$. If $h(2)=10$, $h(10)=1$ and $h(1)=2$, what is $f(f(10))$?'], 28 | ] 29 | output_labels = ['Output'] 30 | 31 | title = 'Step-DPO: Step-wise Preference Optimization for Long-chain Reasoning of LLMs' 32 | 33 | description = """ 34 | 35 | 36 | This is the online demo of **Qwen2-7B-Instruct-Step-DPO**. \n 37 | 38 | It is obtained by performing **Step-DPO** on **Qwen2-7B-Instruct**, with as few as **10K data and hundreds of training steps**. \n 39 | 40 | **Step-DPO** improves the mathematical reasoning of **Qwen2-7B-Instruct** significantly, from **53.0\%** to **58.6\%** on MATH, and **85.5\%** to **87.9\%** on GSM8K. \n 41 | Besides, **Qwen2-72B-Instruct-Step-DPO** achieves **70.8\%** on MATH and **94.0\%** on GSM8K, **outperforming GPT-4-1106, Gemini-1.5-Pro, and Claude-3-Opus**. 42 | 43 | Code, models, data are available at [GitHub](https://github.com/dvlab-research/Step-DPO). 44 | 45 | Hope you can enjoy our work! 46 | 47 | """ 48 | 49 | article = """ 50 |

51 | 52 | Preprint Paper 53 | 54 | \n 55 |

56 | Github Repo

57 | """ 58 | 59 | 60 | def inference(input_str): 61 | 62 | ## filter out special chars 63 | input_str = bleach.clean(input_str) 64 | 65 | print("input_str: ", input_str) 66 | 67 | prompt = input_str + "\nPlease reason step by step, and put your final answer within \\boxed{{}}." #input("Please input your prompt: ") 68 | 69 | messages = [ 70 | {"role": "user", "content": prompt} 71 | ] 72 | 73 | text = tokenizer.apply_chat_template( 74 | messages, 75 | tokenize=False, 76 | add_generation_prompt=True 77 | ) 78 | 79 | model_inputs = tokenizer([text], return_tensors="pt").to('cuda') 80 | 81 | generated_ids = model.generate( 82 | model_inputs.input_ids, 83 | max_new_tokens=1024 84 | ) 85 | generated_ids = [ 86 | output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) 87 | ] 88 | text_output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] 89 | 90 | return text_output 91 | 92 | 93 | demo = gr.Interface( 94 | inference, 95 | inputs=[ 96 | gr.Textbox( 97 | lines=1, placeholder=None, label='Math Problem'), 98 | ], 99 | outputs=[ 100 | gr.Textbox( 101 | lines=1, placeholder=None, label='Text Output'), 102 | ], 103 | title=title, 104 | description=description, 105 | article=article, 106 | examples=examples, 107 | allow_flagging='auto', 108 | flagging_dir=args.save_path) 109 | 110 | demo.queue() 111 | 112 | demo.launch(server_name='0.0.0.0', show_error=True) 113 | -------------------------------------------------------------------------------- /bash_scrips/qwen2_72b_instruct_step_dpo.sh: -------------------------------------------------------------------------------- 1 | export output_dir="qwen2-72b-instruct-step-dpo" 2 | export prompt="qwen2-boxed" 3 | 4 | ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3_cpu.yaml --mixed_precision bf16 \ 5 | --num_processes 8 \ 6 | train.py configs/config_full.yaml \ 7 | --model_name_or_path="Qwen/Qwen2-72B-Instruct" \ 8 | 
--data_path="xinlai/Math-Step-DPO-10K" \ 9 | --per_device_train_batch_size=2 \ 10 | --gradient_accumulation_steps=8 \ 11 | --torch_dtype=bfloat16 \ 12 | --bf16=True \ 13 | --beta=0.4 \ 14 | --num_train_epochs=4 \ 15 | --save_strategy='steps' \ 16 | --save_steps=200 \ 17 | --save_total_limit=1 \ 18 | --output_dir=outputs/$output_dir \ 19 | --hub_model_id=$output_dir \ 20 | --prompt=$prompt 21 | 22 | python eval_math.py --model outputs/$output_dir --data_file ./data/test/GSM8K_test_data.jsonl --save_path 'eval_results/gsm8k/'$output_dir'.json' --prompt $prompt --tensor_parallel_size 4 23 | 24 | python eval_math.py --model outputs/$output_dir --data_file ./data/test/MATH_test_data.jsonl --save_path 'eval_results/math/'$output_dir'.json' --prompt $prompt --tensor_parallel_size 4 25 | -------------------------------------------------------------------------------- /bash_scrips/qwen2_72b_step_dpo.sh: -------------------------------------------------------------------------------- 1 | export output_dir="qwen2-72b-step-dpo" 2 | export prompt="alpaca" 3 | 4 | ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3_cpu.yaml --mixed_precision bf16 \ 5 | --num_processes 8 \ 6 | train.py configs/config_full.yaml \ 7 | --model_name_or_path="xinlai/Qwen2-72B-SFT" \ 8 | --data_path="xinlai/Math-Step-DPO-10K" \ 9 | --per_device_train_batch_size=2 \ 10 | --gradient_accumulation_steps=8 \ 11 | --torch_dtype=bfloat16 \ 12 | --bf16=True \ 13 | --beta=0.4 \ 14 | --num_train_epochs=4 \ 15 | --save_strategy='steps' \ 16 | --save_steps=200 \ 17 | --save_total_limit=1 \ 18 | --output_dir=outputs/$output_dir \ 19 | --hub_model_id=$output_dir \ 20 | --prompt=$prompt 21 | 22 | python eval_math.py --model outputs/$output_dir --data_file ./data/test/GSM8K_test_data.jsonl --save_path 'eval_results/gsm8k/'$output_dir'.json' --prompt $prompt --tensor_parallel_size 4 23 | 24 | python eval_math.py --model outputs/$output_dir --data_file 
./data/test/MATH_test_data.jsonl --save_path 'eval_results/math/'$output_dir'.json' --prompt $prompt --tensor_parallel_size 4 25 | -------------------------------------------------------------------------------- /bash_scrips/qwen2_7b_step_dpo.sh: -------------------------------------------------------------------------------- 1 | export output_dir="qwen2-7b-step-dpo" 2 | export prompt="alpaca" 3 | 4 | ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml --mixed_precision bf16 \ 5 | --num_processes 8 \ 6 | train.py configs/config_full.yaml \ 7 | --model_name_or_path="xinlai/Qwen2-7B-SFT" \ 8 | --data_path="xinlai/Math-Step-DPO-10K" \ 9 | --per_device_train_batch_size=4 \ 10 | --gradient_accumulation_steps=4 \ 11 | --torch_dtype=bfloat16 \ 12 | --bf16=True \ 13 | --beta=0.5 \ 14 | --num_train_epochs=8 \ 15 | --save_strategy='steps' \ 16 | --save_steps=400 \ 17 | --save_total_limit=1 \ 18 | --output_dir=outputs/$output_dir \ 19 | --hub_model_id=$output_dir \ 20 | --prompt=$prompt 21 | 22 | python eval_math.py --model outputs/$output_dir --data_file ./data/test/GSM8K_test_data.jsonl --save_path 'eval_results/gsm8k/'$output_dir'.json' --prompt $prompt --tensor_parallel_size 4 23 | 24 | python eval_math.py --model outputs/$output_dir --data_file ./data/test/MATH_test_data.jsonl --save_path 'eval_results/math/'$output_dir'.json' --prompt $prompt --tensor_parallel_size 4 25 | -------------------------------------------------------------------------------- /configs/config_full.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments 2 | model_name_or_path: 3 | torch_dtype: bfloat16 4 | 5 | # Data training arguments 6 | # For definitions, see: src/h4/training/config.py 7 | data_path: 8 | dataset_splits: 9 | - train 10 | preprocessing_num_workers: 12 11 | 12 | # DPOTrainer arguments 13 | bf16: True 14 | beta: 0.05 15 | do_eval: False 16 | evaluation_strategy: 'no' 17 | eval_steps: 
100 18 | gradient_accumulation_steps: 16 19 | gradient_checkpointing: true 20 | gradient_checkpointing_kwargs: 21 | use_reentrant: False 22 | hub_model_id: step-dpo 23 | learning_rate: 5.0e-7 24 | log_level: info 25 | logging_steps: 1 26 | lr_scheduler_type: cosine 27 | max_length: 1024 28 | max_prompt_length: 512 29 | num_train_epochs: 2 30 | optim: adamw_torch 31 | output_dir: data/step-dpo 32 | per_device_train_batch_size: 1 33 | per_device_eval_batch_size: 4 34 | push_to_hub: false 35 | report_to: 36 | - tensorboard 37 | - wandb 38 | save_strategy: "no" 39 | seed: 42 40 | warmup_ratio: 0.1 41 | -------------------------------------------------------------------------------- /data_pipeline/generate_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import glob 4 | import jsonlines 5 | 6 | def main(args): 7 | save_path = args.save_path 8 | json_files = sorted(glob.glob(args.corrected_files)) 9 | identifier2items = {} 10 | for json_file in json_files: 11 | with open(json_file) as f: 12 | for item in json.load(f): 13 | if item['result']: 14 | if 'alpaca' in args.prompt: 15 | prompt = item['prompt'].split("### Instruction:")[1].split("### Response:")[0].strip() 16 | prefix = item['prompt'].split("### Response:")[-1].lstrip() 17 | elif 'qwen2-boxed' in args.prompt: 18 | prompt = item['prompt'].split("<|im_start|>user\n")[1].split("\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>")[0].strip() 19 | prefix = item['prompt'].split("<|im_start|>assistant\n")[-1].lstrip() 20 | else: 21 | raise NotImplementedError("Prompt {} is not supported currently".format(args.prompt)) 22 | 23 | prefix = prefix.replace("Let's think step by step.\n", "") 24 | identifier = prompt + "||" + prefix 25 | if identifier not in identifier2items: 26 | identifier2items[identifier] = [] 27 | identifier2items[identifier].append(item) 28 | 29 | new_items = [] 30 | invalid_cnt = 0 31 | 
cnt = 0 32 | with jsonlines.open(args.json_file, "r") as f: 33 | for line in f: 34 | prompt = line['instruction'] 35 | prefix = line['prefix'] 36 | identifier = prompt + "||" + prefix 37 | 38 | if identifier not in identifier2items: 39 | invalid_cnt += 1 40 | continue 41 | items = identifier2items[identifier] 42 | visited_chosen = set() 43 | for item in items: 44 | cnt += 1 45 | chosen = item['completion'] 46 | rejected = line['output'] 47 | 48 | chosen_first_step = chosen.split("\nStep ")[0] 49 | rejected_first_step = rejected.split("\nStep ")[0] 50 | 51 | if chosen_first_step in visited_chosen: 52 | continue 53 | 54 | visited_chosen.add(chosen_first_step) 55 | 56 | new_item = { 57 | 'dataset': line['type'], 58 | 'prompt': prompt, 59 | 'prefix': "Let's think step by step.\n" + prefix, 60 | 'chosen': chosen_first_step, 61 | 'rejected': rejected_first_step, 62 | 'original_chosen': chosen, 63 | 'answer': line['answer'], 64 | } 65 | new_items.append(new_item) 66 | 67 | print("len(new_items): {}, invalid_cnt: {}, cnt: {}".format(len(new_items), invalid_cnt, cnt)) 68 | with open(save_path, "w+") as f: 69 | json.dump(new_items, f, indent=4) 70 | 71 | def parse_args(): 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument("--prompt", type=str, default='qwen2-boxed-step') 74 | parser.add_argument("--save_path", type=str, default='./data_pipeline/data.json') 75 | parser.add_argument("--json_file", type=str, default="./data_pipeline/continue_from_incorrect_step.jsonl") 76 | parser.add_argument("--corrected_files", type=str, default="./data_pipeline/corrections/qwen2-7b-instruct-correction*.json") 77 | return parser.parse_args() 78 | 79 | if __name__ == "__main__": 80 | args = parse_args() 81 | main(args) 82 | -------------------------------------------------------------------------------- /data_pipeline/locate_error_by_gpt4.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import json 4 | 
import os 5 | import time 6 | 7 | import openai 8 | import tqdm 9 | 10 | client = openai.OpenAI( 11 | base_url=os.getenv("OPENAI_BASE_URL"), 12 | api_key=os.getenv("OPENAI_API_KEY"), 13 | ) 14 | 15 | prompt = '''### Problem: 16 | {problem} 17 | 18 | ### Correct solution: 19 | {solution} 20 | 21 | ### Incorrect answer: 22 | {answer} 23 | 24 | --- 25 | 26 | A math problem and its correct solution are listed above. We also give another incorrect answer, where step-by-step reasoning process is shown. Please output the correctness for each reasoning step in the given answer. 27 | 28 | Requirements: 29 | 1. You should first output a step-by-step analysis process (no more than 200 words), and finally output the decision ("correct", "neutral", "incorrect") for each step following the format of "Final Decision:\nStep 1: correct; Step 2: neutral; ..."; 30 | 2. Stop when you find the first incorrect step.''' 31 | 32 | def main(args): 33 | 34 | if not os.path.exists(args.save_dir): 35 | os.mkdir(args.save_dir) 36 | 37 | save_dir = args.save_dir 38 | visited_dirs = save_dir if len(args.visited_dirs) == 0 else args.visited_dirs 39 | json_files = sorted(glob.glob(args.json_files)) 40 | 41 | pred_data = [] 42 | for json_file in json_files: 43 | with open(json_file) as f: 44 | for item in json.load(f): 45 | if not item['result']: 46 | pred_data.append(item) 47 | 48 | n_groups = args.n_groups 49 | remainder = args.remainder 50 | 51 | print("n_groups: {}, remainder: {}".format(n_groups, remainder)) 52 | print("len(pred_data): ", len(pred_data)) 53 | 54 | cnt = 0 55 | question2cnt = dict() 56 | for idx, pred_dict in tqdm.tqdm(enumerate(pred_data)): 57 | 58 | if 'alpaca' in args.prompt: 59 | question = pred_dict['prompt'].split("### Instruction:")[1].split("### Response:")[0].strip() 60 | elif 'qwen2-boxed' in args.prompt: 61 | question = pred_dict['prompt'].split("<|im_start|>user\n")[1].split("\nPlease reason step by step, and put your final answer within 
\\boxed{}.<|im_end|>")[0].strip() 62 | else: 63 | raise NotImplementedError("Prompt {} is not supported currently".format(args.prompt)) 64 | 65 | if question in question2cnt and question2cnt[question] > args.max_count_per_question: 66 | continue 67 | if question not in question2cnt: 68 | question2cnt[question] = 0 69 | question2cnt[question] += 1 70 | 71 | # skip the invalid questions without diagram 72 | if "diagram" in question and 'asy' not in question: 73 | continue 74 | 75 | # skip other threads 76 | if idx % n_groups != remainder: 77 | continue 78 | 79 | # skip the visited questions 80 | if any([os.path.exists(os.path.join(visited_dir, "{}.json".format(idx))) for visited_dir in visited_dirs.split("||")]): 81 | continue 82 | 83 | completion = "Step 1: " + pred_dict['completion'] 84 | instruction = prompt.format(problem=question, solution=pred_dict['gt_output'].replace("\n\n", "\n"), answer=completion.replace("\n\n", "\n")) 85 | 86 | # print("instruction: ", instruction) 87 | # import pdb; pdb.set_trace() 88 | 89 | while True: 90 | try: 91 | chat_completion = client.chat.completions.create( 92 | messages=[ 93 | { 94 | "role": "user", 95 | "content": instruction, 96 | } 97 | ], 98 | model="gpt-4o", 99 | ) 100 | except (openai.APIConnectionError, openai.InternalServerError) as e: 101 | print(str(e)) 102 | time.sleep(3) 103 | continue 104 | break 105 | 106 | item = pred_dict.copy() 107 | item['gpt4-output'] = chat_completion.choices[0].message.content 108 | item['gpt4-prompt'] = instruction 109 | save_path = os.path.join(save_dir, "{}.json".format(idx)) 110 | with open(save_path, "w+") as f: 111 | json.dump(item, f, indent=4) 112 | cnt += 1 113 | print("cnt: ", cnt, "idx: ", idx) 114 | if cnt >= args.max_count_total: 115 | break 116 | 117 | def parse_args(): 118 | parser = argparse.ArgumentParser() 119 | parser.add_argument("--prompt", type=str, default='qwen2-boxed-step') 120 | parser.add_argument("--visited_dirs", type=str, default='') # will skip the files in 
$visited_dirs 121 | parser.add_argument("--save_dir", type=str, default='./data_pipeline/generated') 122 | parser.add_argument("--remainder", type=int, default=0) # remainder 123 | parser.add_argument("--n_groups", type=int, default=1) # n_groups 124 | parser.add_argument("--json_files", type=str, default="./data_pipeline/predictions/qwen2-7b-instruct-temp0.8-top_p0.95_rep2_seed0-alpaca-group*.json") 125 | parser.add_argument("--max_count_per_question", type=int, default=1) 126 | parser.add_argument("--max_count_total", type=int, default=10000) 127 | return parser.parse_args() 128 | 129 | if __name__ == "__main__": 130 | args = parse_args() 131 | main(args) 132 | -------------------------------------------------------------------------------- /data_pipeline/merge.sh: -------------------------------------------------------------------------------- 1 | export EVAL_PROMPT='qwen2-boxed-prefix' 2 | export JSON_FILE='./data_pipeline/continue_from_incorrect_step.jsonl' 3 | export PRED_PATH='./data_pipeline/corrections/qwen2-7b-instruct-correction' 4 | export SAVE_PATH='./data_pipeline/data.json' 5 | 6 | python3 data_pipeline/generate_dataset.py --prompt $EVAL_PROMPT \ 7 | --save_path $SAVE_PATH \ 8 | --json_file $JSON_FILE \ 9 | --corrected_files $PRED_PATH"*.json" 10 | -------------------------------------------------------------------------------- /data_pipeline/predictions/sample.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Step-DPO/1f504ead5004f252025cb234017dfd9897cd542c/data_pipeline/predictions/sample.json -------------------------------------------------------------------------------- /data_pipeline/prepare_for_correction.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import glob 4 | import jsonlines 5 | import re 6 | 7 | def main(args): 8 | save_file = args.save_file 9 | generated_files = 
sorted(glob.glob(args.generated_files)) 10 | 11 | invalid_cnt0 = 0 12 | invalid_cnt1 = 0 13 | invalid_cnt2 = 0 14 | with jsonlines.open(save_file, "w") as f: 15 | for json_file in generated_files: 16 | 17 | with open(json_file) as ff: 18 | item = json.load(ff) 19 | 20 | correctness = item['gpt4-output'].lower() 21 | correctness = correctness.split("final decision")[-1].split("summary decision:")[-1].strip() 22 | if not any([x in correctness for x in ['neutral', 'incorrect']]): 23 | invalid_cnt0 += 1 24 | continue 25 | 26 | step_num = correctness.split("neutral")[0].split("incorrect")[0] 27 | step_num = step_num.split("\n")[-1].split(";")[-1] 28 | if step_num.count("step") > 1: 29 | invalid_cnt1 += 1 30 | continue 31 | step_num = step_num.split("step")[-1].split(":")[0] 32 | try: 33 | step_num = int(step_num.strip()) 34 | except ValueError:  # non-numeric step label 35 | # import pdb; pdb.set_trace() 36 | invalid_cnt2 += 1 37 | continue 38 | 39 | if 'alpaca' in args.prompt: 40 | prompt = item['prompt'].split("### Instruction:")[1].split("### Response:")[0].strip() 41 | prefix = item['prompt'].split("### Response:")[-1].lstrip() 42 | elif 'qwen2-boxed' in args.prompt: 43 | prompt = item['prompt'].split("<|im_start|>user\n")[1].split("\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>")[0].strip() 44 | prefix = item['prompt'].split("<|im_start|>assistant\n")[-1].lstrip() 45 | else: 46 | raise NotImplementedError("Prompt {} is not supported currently".format(args.prompt)) 47 | 48 | completion = prefix + item['completion'] 49 | # pred_answer = completion.split("The answer is:")[-1].strip() 50 | type = item['type'] 51 | 52 | if completion.count("Step {}:".format(step_num)) == 0: 53 | continue 54 | 55 | prefix = completion.split("Step {}:".format(step_num))[0] + "Step {}:".format(step_num) 56 | 57 | new_item = { 58 | 'idx': "n/a", 59 | 'instruction': prompt, 60 | 'prefix': prefix.replace("Let's think step by step.\n", ""), 61 | 'output': completion.replace(prefix, ""), 62 | 
'gt_output': item['gt_output'], 63 | 'answer': item['prompt_answer'], 64 | 'step_num': step_num, 65 | 'input': "", 66 | 'type': type, 67 | 'ori_filepath': item['path'] if 'path' in item else 'n/a', 68 | } 69 | f.write(new_item) 70 | 71 | print("invalid_cnt0: ", invalid_cnt0) 72 | print("invalid_cnt1: ", invalid_cnt1) 73 | print("invalid_cnt2: ", invalid_cnt2) 74 | 75 | def parse_args(): 76 | parser = argparse.ArgumentParser() 77 | parser.add_argument("--prompt", type=str, default='qwen2-boxed-step') 78 | parser.add_argument("--save_file", type=str, default='./data_pipeline/continue_from_incorrect.jsonl') 79 | parser.add_argument("--generated_files", type=str, default="./data_pipeline/generated/*.json") 80 | return parser.parse_args() 81 | 82 | if __name__ == "__main__": 83 | args = parse_args() 84 | main(args) 85 | -------------------------------------------------------------------------------- /data_pipeline/step1.sh: -------------------------------------------------------------------------------- 1 | export MODEL_PATH='/dataset/pretrained-models/Qwen2-7B-Instruct' 2 | export PRED_PATH='./data_pipeline/predictions/qwen2-7b-instruct-temp0.8-top_p0.95_rep2_seed0-alpaca-group' 3 | export EVAL_PROMPT='qwen2-boxed-step' 4 | 5 | CUDA_VISIBLE_DEVICES=0 python eval_math.py --model $MODEL_PATH --remainder 0 --n_groups 8 --save_path $PRED_PATH"0.json" --data_file /dataset/industry_gpt/llm_infer/AQuA/train_qa.jsonl --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 2 --seed 0 --tensor_parallel_size 1 & 6 | CUDA_VISIBLE_DEVICES=1 python eval_math.py --model $MODEL_PATH --remainder 1 --n_groups 8 --save_path $PRED_PATH"1.json" --data_file /dataset/industry_gpt/llm_infer/AQuA/train_qa.jsonl --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 2 --seed 0 --tensor_parallel_size 1 & 7 | CUDA_VISIBLE_DEVICES=2 python eval_math.py --model $MODEL_PATH --remainder 2 --n_groups 8 --save_path $PRED_PATH"2.json" --data_file /dataset/industry_gpt/llm_infer/AQuA/train_qa.jsonl --prompt 
$EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 2 --seed 0 --tensor_parallel_size 1 & 8 | CUDA_VISIBLE_DEVICES=3 python eval_math.py --model $MODEL_PATH --remainder 3 --n_groups 8 --save_path $PRED_PATH"3.json" --data_file /dataset/industry_gpt/llm_infer/AQuA/train_qa.jsonl --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 2 --seed 0 --tensor_parallel_size 1 & 9 | CUDA_VISIBLE_DEVICES=4 python eval_math.py --model $MODEL_PATH --remainder 4 --n_groups 8 --save_path $PRED_PATH"4.json" --data_file /dataset/industry_gpt/llm_infer/AQuA/train_qa.jsonl --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 2 --seed 0 --tensor_parallel_size 1 & 10 | CUDA_VISIBLE_DEVICES=5 python eval_math.py --model $MODEL_PATH --remainder 5 --n_groups 8 --save_path $PRED_PATH"5.json" --data_file /dataset/industry_gpt/llm_infer/AQuA/train_qa.jsonl --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 2 --seed 0 --tensor_parallel_size 1 & 11 | CUDA_VISIBLE_DEVICES=6 python eval_math.py --model $MODEL_PATH --remainder 6 --n_groups 8 --save_path $PRED_PATH"6.json" --data_file /dataset/industry_gpt/llm_infer/AQuA/train_qa.jsonl --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 2 --seed 0 --tensor_parallel_size 1 & 12 | CUDA_VISIBLE_DEVICES=7 python eval_math.py --model $MODEL_PATH --remainder 7 --n_groups 8 --save_path $PRED_PATH"7.json" --data_file /dataset/industry_gpt/llm_infer/AQuA/train_qa.jsonl --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 2 --seed 0 --tensor_parallel_size 1 13 | -------------------------------------------------------------------------------- /data_pipeline/step2.sh: -------------------------------------------------------------------------------- 1 | export OPENAI_BASE_URL="" # input openai base_url here 2 | export OPENAI_API_KEY="" # input openai api_key here 3 | 4 | python3 data_pipeline/locate_error_by_gpt4.py \ 5 | --prompt "qwen2-boxed-step" \ 6 | --save_dir "./data_pipeline/generated" \ 7 | --json_files 
"./data_pipeline/predictions/qwen2-7b-instruct-temp0.8-top_p0.95_rep2_seed0-alpaca-group*.json" \ 8 | --max_count_total 100 9 | -------------------------------------------------------------------------------- /data_pipeline/step3.sh: -------------------------------------------------------------------------------- 1 | export MODEL_PATH='/dataset/pretrained-models/Qwen2-7B-Instruct' 2 | export EVAL_PROMPT='qwen2-boxed-prefix' 3 | export JSON_FILE='./data_pipeline/continue_from_incorrect_step.jsonl' 4 | export PRED_PATH='./data_pipeline/corrections/qwen2-7b-instruct-correction' 5 | export SAVE_PATH='./data_pipeline/data.json' 6 | 7 | python3 data_pipeline/prepare_for_correction.py --prompt $EVAL_PROMPT \ 8 | --save_file $JSON_FILE \ 9 | --generated_files "./data_pipeline/generated/*.json" 10 | 11 | CUDA_VISIBLE_DEVICES=0 python eval_math.py --model $MODEL_PATH --remainder 0 --n_groups 8 --save_path $PRED_PATH"0.json" --data_file $JSON_FILE --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 20 --seed 0 --tensor_parallel_size 1 & 12 | CUDA_VISIBLE_DEVICES=1 python eval_math.py --model $MODEL_PATH --remainder 1 --n_groups 8 --save_path $PRED_PATH"1.json" --data_file $JSON_FILE --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 20 --seed 0 --tensor_parallel_size 1 & 13 | CUDA_VISIBLE_DEVICES=2 python eval_math.py --model $MODEL_PATH --remainder 2 --n_groups 8 --save_path $PRED_PATH"2.json" --data_file $JSON_FILE --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 20 --seed 0 --tensor_parallel_size 1 & 14 | CUDA_VISIBLE_DEVICES=3 python eval_math.py --model $MODEL_PATH --remainder 3 --n_groups 8 --save_path $PRED_PATH"3.json" --data_file $JSON_FILE --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 20 --seed 0 --tensor_parallel_size 1 & 15 | CUDA_VISIBLE_DEVICES=4 python eval_math.py --model $MODEL_PATH --remainder 4 --n_groups 8 --save_path $PRED_PATH"4.json" --data_file $JSON_FILE --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 20 --seed 0 --tensor_parallel_size 1 
& 16 | CUDA_VISIBLE_DEVICES=5 python eval_math.py --model $MODEL_PATH --remainder 5 --n_groups 8 --save_path $PRED_PATH"5.json" --data_file $JSON_FILE --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 20 --seed 0 --tensor_parallel_size 1 & 17 | CUDA_VISIBLE_DEVICES=6 python eval_math.py --model $MODEL_PATH --remainder 6 --n_groups 8 --save_path $PRED_PATH"6.json" --data_file $JSON_FILE --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 20 --seed 0 --tensor_parallel_size 1 & 18 | CUDA_VISIBLE_DEVICES=7 python eval_math.py --model $MODEL_PATH --remainder 7 --n_groups 8 --save_path $PRED_PATH"7.json" --data_file $JSON_FILE --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 20 --seed 0 --tensor_parallel_size 1 19 | -------------------------------------------------------------------------------- /eval_math.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import pdb 5 | import sys 6 | 7 | import jsonlines 8 | import torch 9 | from evaluation.data_processing.answer_extraction import extract_math_answer 10 | from evaluation.eval.eval_script import eval_math 11 | from vllm import LLM, SamplingParams 12 | 13 | MAX_INT = sys.maxsize 14 | INVALID_ANS = "[invalid]" 15 | 16 | invalid_outputs = [] 17 | 18 | def batch_data(data_list, batch_size=1): 19 | n = len(data_list) // batch_size 20 | batch_data = [] 21 | for i in range(n-1): 22 | start = i * batch_size 23 | end = (i+1)*batch_size 24 | batch_data.append(data_list[start:end]) 25 | 26 | last_start = (n-1) * batch_size 27 | last_end = MAX_INT 28 | batch_data.append(data_list[last_start:last_end]) 29 | return batch_data 30 | 31 | def test_hendrycks_math(model, data_path, remainder=0, n_groups=MAX_INT, batch_size=1, tensor_parallel_size=1, args=None): 32 | 33 | save_path = args.save_path 34 | hendrycks_math_ins = [] 35 | hendrycks_math_answers = [] 36 | attributes = [] 37 | if args.prompt == 'alpaca': 38 | problem_prompt = ( 39 | "Below 
is an instruction that describes a task. " 40 | "Write a response that appropriately completes the request.\n\n" 41 | "### Instruction:\n{instruction}\n\n### Response: Let's think step by step." 42 | ) 43 | elif args.prompt == 'alpaca-cot-step': 44 | problem_prompt = ( 45 | "Below is an instruction that describes a task. " 46 | "Write a response that appropriately completes the request.\n\n" 47 | "### Instruction:\n{instruction}\n\n### Response:\nLet's think step by step.\nStep 1: " 48 | ) 49 | elif args.prompt == 'alpaca-cot-prefix': 50 | problem_prompt = ( 51 | "Below is an instruction that describes a task. " 52 | "Write a response that appropriately completes the request.\n\n" 53 | "### Instruction:\n{instruction}\n\n### Response:\nLet's think step by step.\n{prefix}" 54 | ) 55 | elif args.prompt == 'deepseek-math': 56 | problem_prompt = ( 57 | "User: {instruction}\nPlease reason step by step, and put your final answer within \\boxed{{}}.\n\nAssistant:" 58 | ) 59 | elif args.prompt == 'deepseek-math-step': 60 | problem_prompt = ( 61 | "User: {instruction}\nPlease reason step by step, and put your final answer within \\boxed{{}}.\n\nAssistant: Let's think step by step.\nStep 1: " 62 | ) 63 | elif args.prompt == 'qwen2-boxed': 64 | problem_prompt = ( 65 | "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" 66 | "<|im_start|>user\n{instruction}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n" 67 | "<|im_start|>assistant\n" 68 | ) 69 | elif args.prompt == 'qwen2-boxed-step': 70 | problem_prompt = ( 71 | "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" 72 | "<|im_start|>user\n{instruction}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n" 73 | "<|im_start|>assistant\nLet's think step by step.\nStep 1: " 74 | ) 75 | elif args.prompt == 'qwen2-boxed-prefix': 76 | problem_prompt = ( 77 | "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" 78 | 
"<|im_start|>user\n{instruction}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n" 79 | "<|im_start|>assistant\nLet's think step by step.\n{prefix}" 80 | ) 81 | 82 | print('prompt =====', problem_prompt) 83 | with open(data_path, "r+", encoding="utf8") as f: 84 | for idx, item in enumerate(jsonlines.Reader(f)): 85 | if "prefix" in item: 86 | temp_instr = problem_prompt.format(instruction=item["instruction"], prefix=item['prefix']) 87 | else: 88 | temp_instr = problem_prompt.format(instruction=item["instruction"]) 89 | hendrycks_math_ins.append(temp_instr) 90 | temp_ans = item['answer'] 91 | hendrycks_math_answers.append(temp_ans) 92 | attribute = {} 93 | if 'filepath' in item: 94 | attribute['filepath'] = item['filepath'] 95 | if 'type' in item: 96 | attribute['type'] = item['type'] 97 | if 'output' in item: 98 | attribute['gt_output'] = item['output'] 99 | attributes.append(attribute) 100 | 101 | print("args.seed: ", args.seed) 102 | print('length ===', len(hendrycks_math_ins)) 103 | hendrycks_math_ins = hendrycks_math_ins[remainder::n_groups] 104 | hendrycks_math_answers = hendrycks_math_answers[remainder::n_groups] 105 | attributes = attributes[remainder::n_groups] 106 | 107 | print("processed length ===", len(hendrycks_math_ins)) 108 | hendrycks_math_ins = hendrycks_math_ins * args.rep 109 | hendrycks_math_answers = hendrycks_math_answers * args.rep 110 | attributes = attributes * args.rep 111 | 112 | print('total length ===', len(hendrycks_math_ins)) 113 | batch_hendrycks_math_ins = batch_data(hendrycks_math_ins, batch_size=batch_size) 114 | 115 | sampling_params = SamplingParams(temperature=args.temp, top_p=args.top_p, max_tokens=2048) 116 | print('sampling =====', sampling_params) 117 | if not os.path.exists(save_path): 118 | llm = LLM(model=model, tensor_parallel_size=tensor_parallel_size, dtype=torch.bfloat16, seed=args.seed) 119 | 120 | res_completions = [] 121 | for idx, (prompt, prompt_answer) in 
enumerate(zip(batch_hendrycks_math_ins, hendrycks_math_answers)): 122 | if isinstance(prompt, list): 123 | pass 124 | else: 125 | prompt = [prompt] 126 | completions = llm.generate(prompt, sampling_params) 127 | for output in completions: 128 | prompt_temp = output.prompt 129 | generated_text = output.outputs[0].text 130 | res_completions.append(generated_text) 131 | else: 132 | res_completions = [] 133 | with open(save_path) as f: 134 | items = json.load(f) 135 | for idx, item in enumerate(items): 136 | res_completions.append(item['completion']) 137 | 138 | to_save_list = [] 139 | results = [] 140 | for idx, (prompt, completion, prompt_answer, attribute) in enumerate(zip(hendrycks_math_ins, res_completions, hendrycks_math_answers, attributes)): 141 | 142 | if isinstance(prompt_answer, str) and prompt_answer.startswith("\\text{") and prompt_answer.endswith("}"): 143 | prompt_answer = prompt_answer.split("{", 1)[1][:-1]  # drop the \text{...} wrapper (`remove_text` was referenced but never defined) 144 | 145 | if "The answer is:" in completion and (isinstance(prompt_answer, list) and len(prompt_answer) == 1 and "\\begin{pmatrix}" in prompt_answer[0]): 146 | prompt_answer[0] = prompt_answer[0].replace("\\\\", "\\") 147 | completion = completion.replace("\\\\", "\\") 148 | 149 | item = { 150 | 'question': prompt, 151 | 'model_output': completion, 152 | 'prediction': extract_math_answer(prompt, completion, task='cot'), 153 | 'answer': prompt_answer if isinstance(prompt_answer, list) else [prompt_answer], 154 | } 155 | 156 | if len(item['prediction']) == 0: 157 | invalid_outputs.append({'question': prompt, 'output': completion, 'answer': item['prediction']}) 158 | res = False 159 | extract_ans = None 160 | else: 161 | extract_ans = item['prediction'] 162 | res = eval_math(item) 163 | 164 | results.append(res) 165 | 166 | to_save_dict = { 167 | 'prompt': prompt, 168 | 'completion': completion, 169 | 'extract_answer': extract_ans, 170 | 'prompt_answer': prompt_answer, 171 | 'result': res, 172 | } 173 | to_save_dict.update(attribute) 174 | to_save_list.append(to_save_dict) 175 | 176 | acc = 
sum(results) / len(results) 177 | # print('valid_outputs===', invalid_outputs) 178 | print('len invalid outputs ====', len(invalid_outputs)) 179 | print('n_groups===', n_groups, ', remainder====', remainder) 180 | print('length====', len(results), ', acc====', acc) 181 | 182 | try: 183 | with open(save_path, "w+") as f: 184 | json.dump(to_save_list, f, indent=4) 185 | except Exception: 186 | pdb.set_trace() 187 | 188 | def parse_args(): 189 | parser = argparse.ArgumentParser() 190 | parser.add_argument("--model", type=str, default='') # model path 191 | parser.add_argument("--data_file", type=str, default='') # data path 192 | parser.add_argument("--remainder", type=int, default=0) # index 193 | parser.add_argument("--n_groups", type=int, default=1) # group number 194 | parser.add_argument("--batch_size", type=int, default=400) # batch_size 195 | parser.add_argument("--tensor_parallel_size", type=int, default=8) # tensor_parallel_size 196 | parser.add_argument("--save_path", type=str) 197 | parser.add_argument("--prompt", type=str, default='alpaca') 198 | parser.add_argument("--temp", type=float, default=0.0) 199 | parser.add_argument("--top_p", type=float, default=1.0) 200 | parser.add_argument("--seed", type=int, default=None) 201 | parser.add_argument("--rep", type=int, default=1) 202 | return parser.parse_args() 203 | 204 | if __name__ == "__main__": 205 | args = parse_args() 206 | test_hendrycks_math(model=args.model, data_path=args.data_file, remainder=args.remainder, n_groups=args.n_groups, batch_size=args.batch_size, tensor_parallel_size=args.tensor_parallel_size, args=args) 207 | -------------------------------------------------------------------------------- /eval_results/gsm8k/sample.json: -------------------------------------------------------------------------------- 1 | [] -------------------------------------------------------------------------------- 
/eval_results/math/qwen2-7b-dpo-v3-continue-from-incorrect-fix-part1-filtered+2-filtered+aqua-filtered-rej-original_acc0.6-topk1-beta0.5-8ep-fixbug-fixeos-bf16-keywords-fix.json: -------------------------------------------------------------------------------- 1 | [ 2 | 3 | ] -------------------------------------------------------------------------------- /eval_results/math/sample.json: -------------------------------------------------------------------------------- 1 | [] -------------------------------------------------------------------------------- /evaluation/data_processing/answer_extraction.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import regex 4 | 5 | 6 | def _fix_fracs(string): 7 | substrs = string.split("\\frac") 8 | new_str = substrs[0] 9 | if len(substrs) > 1: 10 | substrs = substrs[1:] 11 | for substr in substrs: 12 | new_str += "\\frac" 13 | if len(substr) > 0 and substr[0] == "{": 14 | new_str += substr 15 | else: 16 | try: 17 | assert len(substr) >= 2 18 | except Exception: 19 | return string 20 | a = substr[0] 21 | b = substr[1] 22 | if b != "{": 23 | if len(substr) > 2: 24 | post_substr = substr[2:] 25 | new_str += "{" + a + "}{" + b + "}" + post_substr 26 | else: 27 | new_str += "{" + a + "}{" + b + "}" 28 | else: 29 | if len(substr) > 2: 30 | post_substr = substr[2:] 31 | new_str += "{" + a + "}" + b + post_substr 32 | else: 33 | new_str += "{" + a + "}" + b 34 | string = new_str 35 | return string 36 | 37 | 38 | def _fix_a_slash_b(string): 39 | if len(string.split("/")) != 2: 40 | return string 41 | a = string.split("/")[0] 42 | b = string.split("/")[1] 43 | try: 44 | if "sqrt" not in a: 45 | a = int(a) 46 | if "sqrt" not in b: 47 | b = int(b) 48 | assert string == "{}/{}".format(a, b) 49 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 50 | return new_string 51 | except Exception: 52 | return string 53 | 54 | 55 | def _fix_sqrt(string): 56 | _string = 
re.sub(r"\\sqrt(-?[0-9.a-zA-Z]+)", r"\\sqrt{\1}", string) 57 | _string = re.sub(r"\\sqrt\s+(\w+)$", r"\\sqrt{\1}", _string) 58 | return _string 59 | 60 | 61 | def _fix_tan(string): 62 | _string = re.sub(r"\\tan(-?[0-9.a-zA-Z]+)", r"\\tan{\1}", string) 63 | _string = re.sub(r"\\tan\s+(\w+)$", r"\\tan{\1}", _string) 64 | return _string 65 | 66 | 67 | def strip_string(string): 68 | string = str(string).strip() 69 | # linebreaks 70 | string = string.replace("\n", "") 71 | 72 | # right "." 73 | string = string.rstrip(".") 74 | 75 | # remove inverse spaces 76 | string = string.replace("\\!", "") 77 | # string = string.replace("\\ ", "") 78 | 79 | # replace \\ with \ 80 | # string = string.replace("\\\\", "\\") 81 | # string = string.replace("\\\\", "\\") 82 | 83 | if string.startswith("\\text{") and string.endswith("}"): 84 | string = string.split("{", 1)[1][:-1] 85 | 86 | # replace tfrac and dfrac with frac 87 | string = string.replace("tfrac", "frac") 88 | string = string.replace("dfrac", "frac") 89 | string = string.replace("cfrac", "frac") 90 | 91 | # remove \left and \right 92 | string = string.replace("\\left", "") 93 | string = string.replace("\\right", "") 94 | 95 | # Remove unit: miles, dollars if after is not none 96 | _string = re.sub(r"\\text{.*?}$", "", string).strip() 97 | if _string != "" and _string != string: 98 | # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string)) 99 | string = _string 100 | 101 | # Remove circ (degrees) 102 | string = string.replace("^{\\circ}", "").strip() 103 | string = string.replace("^\\circ", "").strip() 104 | 105 | string = regex.sub(r"\{(c|m)?m\}(\^(2|3))?", "", string).strip() 106 | string = regex.sub(r"p\.m\.$", "", string).strip() 107 | string = regex.sub(r"(\d)\s*t$", r"\1", string).strip() 108 | 109 | # remove dollar signs 110 | string = string.replace("\\$", "") 111 | string = string.replace("$", "") 112 | 113 | # string = string.replace("\\text", "") 114 | string = string.replace("x\\in", "") 115 | 
116 | # remove percentage 117 | string = string.replace("\\%", "%") 118 | string = string.replace("\%", "%") 119 | # string = string.replace("%", "") 120 | 121 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string 122 | string = string.replace(" .", " 0.") 123 | string = string.replace("{.", "{0.") 124 | 125 | # cdot 126 | string = string.replace("\\cdot", "") 127 | 128 | # inf 129 | string = string.replace("infinity", "\\infty") 130 | if "\\infty" not in string: 131 | string = string.replace("inf", "\\infty") 132 | string = string.replace("+\\inity", "\\infty") 133 | 134 | # and 135 | # string = string.replace("and", "") 136 | string = string.replace("\\mathbf", "") 137 | string = string.replace("\\mathrm", "") 138 | 139 | # use regex to remove \mbox{...} 140 | string = re.sub(r"\\mbox{.*?}", "", string) 141 | 142 | # quote 143 | string = string.replace("'", "") 144 | string = string.replace("\"", "") 145 | 146 | # i, j 147 | if "j" in string and "i" not in string: 148 | string = string.replace("j", "i") 149 | 150 | # replace a.000b where b is not number or b is end, with ab, use regex 151 | string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string) 152 | string = re.sub(r"(\d+)\.0+$", r"\1", string) 153 | 154 | # if empty, return empty string 155 | if len(string) == 0: 156 | return string 157 | if string[0] == ".": 158 | string = "0" + string 159 | 160 | # to consider: get rid of e.g. "k = " or "q = " at beginning 161 | # if len(string.split("=")) == 2: 162 | # if len(string.split("=")[0]) <= 2: 163 | # string = string.split("=")[1] 164 | 165 | string = _fix_sqrt(string) 166 | string = _fix_tan(string) 167 | string = string.replace(" ", "") 168 | 169 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). 
Also does a/b --> \\frac{a}{b} 170 | string = _fix_fracs(string) 171 | 172 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y 173 | string = _fix_a_slash_b(string) 174 | 175 | string = regex.sub(r"(\\|,|\.)+$", "", string) 176 | 177 | return string 178 | 179 | def extract_boxed_answers(text): 180 | answers = [] 181 | for piece in text.split('boxed{')[1:]: 182 | n = 0 183 | for i in range(len(piece)): 184 | if piece[i] == '{': 185 | n += 1 186 | elif piece[i] == '}': 187 | n -= 1 188 | if n < 0: 189 | if i + 1 < len(piece) and piece[i + 1] == '%': 190 | answers.append(piece[: i + 1]) 191 | else: 192 | answers.append(piece[:i]) 193 | break 194 | return answers 195 | 196 | def extract_program_output(pred_str): 197 | """ 198 | extract output between the last ```output\n...\n``` 199 | """ 200 | if "```output" not in pred_str: 201 | return "" 202 | if '```output' in pred_str: 203 | pred_str = pred_str.split('```output')[-1] 204 | if '```' in pred_str: 205 | pred_str = pred_str.split('```')[0] 206 | output = pred_str.strip() 207 | return output 208 | 209 | def extract_answer(pred_str, exhaust=False): 210 | pred = [] 211 | 212 | # import pdb; pdb.set_trace() 213 | 214 | if 'final answer is $' in pred_str and '$. I hope' in pred_str: 215 | tmp = pred_str.split('final answer is $', 1)[1] 216 | pred = [tmp.split('$. 
I hope', 1)[0].strip()] 217 | elif 'boxed' in pred_str: 218 | pred = extract_boxed_answers(pred_str) 219 | 220 | # import pdb; pdb.set_trace() 221 | 222 | elif ('he answer is' in pred_str): 223 | pred = [pred_str.split('he answer is')[-1].strip()] 224 | else: 225 | program_output = extract_program_output(pred_str) 226 | if program_output != "": 227 | # fall back to program 228 | pred.append(program_output) 229 | else: # use the last number 230 | pattern = '-?\d*\.?\d+' 231 | ans = re.findall(pattern, pred_str.replace(",", "")) 232 | if(len(ans) >= 1): 233 | ans = ans[-1] 234 | else: 235 | ans = '' 236 | if ans: 237 | pred.append(ans) 238 | 239 | # multiple line 240 | _pred = [] 241 | for ans in pred: 242 | # ans = ans.strip().split("\n")[0] 243 | ans = ans.strip() 244 | 245 | ans = ans.lstrip(":") 246 | ans = ans.rstrip(".") 247 | ans = ans.rstrip("/") 248 | ans = strip_string(ans) 249 | _pred.append(ans) 250 | if exhaust: 251 | return _pred 252 | else: 253 | return _pred[-1] if _pred else "" 254 | 255 | def extract_math_answer(question, reasoning, task): 256 | answer = [] 257 | for ans in extract_answer(reasoning, exhaust=True): 258 | if 'separated by commas' in question and all(ch not in ans for ch in '()[]'): 259 | answer.extend([a.strip() for a in ans.split(",")]) 260 | elif regex.search(r"\\text\{\s*and\s*\}", ans): 261 | answer.extend([a.strip() for a in regex.sub(r"\\text\{\s*and\s*\}", "[SEP]", ans).split("[SEP]")]) 262 | else: 263 | answer.append(ans.strip()) 264 | return answer 265 | 266 | def extract_math_few_shot_cot_answer(question, reasoning, task): 267 | if 'Problem:' in reasoning: 268 | reasoning = reasoning.split("Problem:", 1)[0] 269 | return extract_math_answer(question, reasoning, task) 270 | 271 | def extract_last_single_answer(question, reasoning, task): 272 | return extract_answer(reasoning, exhaust=False) 273 | 274 | def extract_gsm_few_shot_cot_answer(question, reasoning, task): 275 | if 'Q: ' in reasoning: 276 | reasoning = 
reasoning.split("Q: ", 1)[0] 277 | pred = [s for s in regex.findall(r'-?\d+\.?\d*', reasoning)] 278 | if pred: 279 | return pred[-1] 280 | else: 281 | return "[invalid]" 282 | 283 | def extract_agieval_gaokao_mathcloze_few_shot_cot_test(question, reasoning, task): 284 | if '问题 ' in reasoning: 285 | reasoning = reasoning.split("问题 ", 1)[0] 286 | if '答案是' in reasoning: 287 | ans = reasoning.split('答案是', 1)[1].strip() 288 | ans = ans.split("\n")[0].strip() 289 | ans = [ans.strip("$")] 290 | else: 291 | ans = ['placeholder'] 292 | return ans 293 | 294 | def extract_agieval_gaokao_mathqa_few_shot_cot_test(question, reasoning, task): 295 | if '问题 ' in reasoning: 296 | reasoning = reasoning.split("问题 ", 1)[0] 297 | if '答案是' in reasoning: 298 | ans = reasoning.split('答案是', 1)[1].strip() 299 | ans = ans.split("\n")[0].strip() 300 | else: 301 | ans = 'placeholder' 302 | return ans 303 | 304 | def extract_sat_few_shot_answer(question, reasoning, task): 305 | if 'Problem:' in reasoning: 306 | reasoning = reasoning.split("Problem:", 1)[0] 307 | patt = regex.search(r"the final answer is \(?(?P<ans>[abcd])\)?", reasoning.lower()) 308 | if patt is not None: 309 | return patt.group('ans').upper() 310 | return 'placeholder' 311 | 312 | def extract_ocwcourses_few_shot_answer(question, reasoning, task): 313 | if 'Problem:' in reasoning: 314 | reasoning = reasoning.split("Problem:", 1)[0] 315 | patt = regex.search(r"final answer is (?P<ans>.*)\. 
I hope it is correct.", reasoning) 316 | if patt is None: 317 | pred = "[invalid]" 318 | print(f"DEBUG >>>\n{reasoning}", flush=True) 319 | else: 320 | pred = patt.group('ans') 321 | return pred 322 | 323 | def extract_mmlu_stem(question, reasoning, task): 324 | if 'Problem:' in reasoning: 325 | reasoning = reasoning.split("Problem:", 1)[0] 326 | return extract_sat_few_shot_answer(question, reasoning, task) 327 | 328 | def extract_minif2f_isabelle(question, reasoning, task): 329 | if 'Informal:' in reasoning: 330 | reasoning = reasoning.split("Informal:", 1)[0] 331 | return reasoning.strip() 332 | 333 | def extract_cmath_few_shot_test(question, reasoning, task): 334 | if '问题:' in reasoning: 335 | reasoning = reasoning.split("问题:", 1)[0] 336 | if '答案是' in reasoning: 337 | ans = reasoning.split('答案是', 1)[1].strip() 338 | ans = ans.split("\n")[0] 339 | ans = ans.strip(":") 340 | ans = ans.strip("。") 341 | try: 342 | ans = [s for s in regex.findall(r'-?\d+\.?\d*', ans)][-1] 343 | except Exception: 344 | print(f"DEBUG CMATH: {reasoning}", flush=True) 345 | ans = "[invalid]" 346 | else: 347 | ans = extract_last_single_answer(question, reasoning, task) 348 | return ans 349 | -------------------------------------------------------------------------------- /evaluation/data_processing/process_utils.py: -------------------------------------------------------------------------------- 1 | import regex 2 | from data_processing.answer_extraction import extract_math_answer, strip_string 3 | 4 | 5 | def process_gsm8k_test(item): 6 | sample = { 7 | 'dataset': 'gsm8k-cot', 8 | 'id': item['id'], 9 | 'messages': [ 10 | {'role': 'user', 'content': item['question']}, 11 | {'role': 'assistant', 'content': regex.sub(r"<<[^<>]*>>", "", item['cot']) + "\nSo the answer is $\\boxed{" + item['answer'].strip() + "}$."} 12 | ], 13 | 'answer': item['answer'].replace(',', '') 14 | } 15 | yield sample 16 | 17 | def process_math_test(item): 18 | question = item["problem"] 19 | try: 20 | answer = 
extract_math_answer(question, item['solution'], task="cot") 21 | except Exception: 22 | return 23 | sample = { 24 | "dataset": "math-cot", 25 | "id": item['id'], 26 | "level": item["level"], 27 | "type": item["type"], 28 | "category": item["category"], 29 | "messages": [ 30 | {"role": "user", "content": question}, 31 | {"role": "assistant", "content": "\n".join(regex.split(r"(?<=\.) (?=[A-Z])", item["solution"]))} 32 | ], 33 | "answer": answer 34 | } 35 | yield sample 36 | 37 | def process_math_sat(item): 38 | options = item['options'].strip() 39 | assert 'A' == options[0] 40 | options = '(' + options 41 | for ch in 'BCDEFG': 42 | if f' {ch}) ' in options: 43 | options = regex.sub(f' {ch}\) ', f" ({ch}) ", options) 44 | question = f"{item['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options.strip()}" 45 | messages = [ 46 | {'role': 'user', 'content': question}, 47 | {'role': 'assistant', 'content': item['Answer']} 48 | ] 49 | item = { 50 | 'dataset': 'math_sat', 51 | 'id': item['id'], 52 | 'language': 'en', 53 | 'messages': messages, 54 | 'answer': item['Answer'], 55 | } 56 | yield item 57 | 58 | def process_ocwcourses(item): 59 | messages = [ 60 | {'role': 'user', 'content': item['problem'].strip()}, 61 | {'role': 'assistant', 'content': item['solution'].strip()} 62 | ] 63 | item = { 64 | "dataset": "OCWCourses", 65 | "id": item['id'], 66 | "language": "en", 67 | "messages": messages, 68 | "answer": item['answer'] 69 | } 70 | yield item 71 | 72 | def process_mmlu_stem(item): 73 | options = item['options'] 74 | for i, (label, option) in enumerate(zip('ABCD', options)): 75 | options[i] = f"({label}) {str(option).strip()}" 76 | options = ", ".join(options) 77 | question = f"{item['question'].strip()}\nWhat of the following is the right choice? 
Explain your answer.\n{options}" 78 | messages = [ 79 | {'role': 'user', 'content': question}, 80 | {'role': 'assistant', 'content': item['answer']} 81 | ] 82 | item = { 83 | "dataset": "MMLU-STEM", 84 | "id": item['id'], 85 | "language": "en", 86 | "messages": messages, 87 | "answer": item['answer'] 88 | } 89 | yield item 90 | 91 | def process_mgsm_zh(item): 92 | item['answer'] = item['answer'].replace(',', '') 93 | yield item 94 | 95 | def process_cmath(item): 96 | item = { 97 | 'dataset': 'cmath', 98 | 'id': item['id'], 99 | 'grade': item['grade'], 100 | 'reasoning_step': item['reasoning_step'], 101 | 'messages': [ 102 | {'role': 'user', 'content': item['question'].strip()}, 103 | {'role': 'assistant', 'content': ''} 104 | ], 105 | 'answer': item['golden'].strip().replace(",", "") 106 | } 107 | yield item 108 | 109 | def process_agieval_gaokao_math_cloze(item): 110 | item = { 111 | 'dataset': 'agieval-gaokao-math-cloze', 112 | 'id': item['id'], 113 | 'messages': [ 114 | {'role': 'user', 'content': item['question'].strip()}, 115 | {'role': 'assistant', 'content': ''} 116 | ], 117 | 'answer': [strip_string(ans) for ans in item['answer'].strip().split(";")] 118 | } 119 | yield item 120 | 121 | def process_agieval_gaokao_mathqa(item): 122 | question = item['question'].strip() 123 | options = [] 124 | for option in item['options']: 125 | option = option.strip() 126 | assert option[0] == '(' 127 | assert option[2] == ')' 128 | assert option[1] in 'ABCD' 129 | option = f"{option[1]}: {option[3:].strip()}" 130 | options.append(option.strip()) 131 | question = f"{question}\n{options}" 132 | item = { 133 | 'dataset': 'agieval-gaokao-mathqa', 134 | 'id': item['id'], 135 | 'messages': [ 136 | {'role': 'user', 'content': question}, 137 | {'role': 'assistant', 'content': ''} 138 | ], 139 | "answer": item['label'] 140 | } 141 | yield item 142 | 143 | def process_agieval_gaokao_mathqa_few_shot_cot_test(item): 144 | question = item['question'].strip().rstrip('\\') 145 | options 
= " ".join([opt.strip() for opt in item['options']]) 146 | question = f"{question}\n从以下选项中选择: {options}" 147 | item = { 148 | 'dataset': 'agieval-gaokao-mathqa', 149 | 'id': item['id'], 150 | 'messages': [ 151 | {'role': 'user', 'content': question}, 152 | {'role': 'assistant', 'content': ''} 153 | ], 154 | "answer": item['label'] 155 | } 156 | yield item 157 | 158 | def process_minif2f_isabelle(item): 159 | question = f"(*### Problem\n\n{item['informal_statement'].strip()}\n\n### Solution\n\n{item['informal_proof'].strip()} *)\n\nFormal:\n{item['formal_statement'].strip()}" 160 | item = { 161 | 'dataset': 'minif2f-isabelle', 162 | 'id': item['id'], 163 | 'messages': [ 164 | {'role': 'user', 'content': question}, 165 | {'role': 'assistant', 'content': ''} 166 | ], 167 | "answer": "placeholder" 168 | } 169 | yield item 170 | -------------------------------------------------------------------------------- /evaluation/eval/eval_script.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import regex 4 | from evaluation.eval.eval_utils import math_equal 5 | from evaluation.eval.ocwcourses_eval_utils import ( 6 | SymbolicMathMixin, 7 | normalize_numeric, 8 | normalize_symbolic_equation, 9 | numeric_equality, 10 | ) 11 | 12 | 13 | def is_correct(item, pred_key='prediction', prec=1e-3): 14 | pred = item[pred_key] 15 | ans = item['answer'] 16 | if isinstance(pred, list) and isinstance(ans, list): 17 | pred_matched = set() 18 | ans_matched = set() 19 | for i in range(len(pred)): 20 | for j in range(len(ans)): 21 | item_cpy = deepcopy(item) 22 | item_cpy.update({ 23 | pred_key: pred[i], 24 | 'answer': ans[j] 25 | }) 26 | if is_correct(item_cpy, pred_key=pred_key, prec=prec): 27 | pred_matched.add(i) 28 | ans_matched.add(j) 29 | if item_cpy[pred_key] == '2,3,4': 30 | print(item, flush=True) 31 | print("wtf", flush=True) 32 | return len(pred_matched) == len(pred) and len(ans_matched) == len(ans) 33 | elif 
isinstance(pred, str) and isinstance(ans, str): 34 | if '\cup' in pred and '\cup' in ans: 35 | item = deepcopy(item) 36 | item.update({ 37 | pred_key: pred.split('\cup'), 38 | 'answer': ans.split('\cup'), 39 | }) 40 | return is_correct(item, pred_key=pred_key, prec=prec) 41 | else: 42 | label = False 43 | try: 44 | label = abs(float(regex.sub(r',', '', str(pred))) - float(regex.sub(r',', '', str(ans)))) < prec 45 | except Exception: 46 | pass 47 | 51 | label = label or (ans and pred == ans) or math_equal(pred, ans) 52 | return label 53 | else: 54 | print(item, flush=True) 55 | raise NotImplementedError() 56 | 57 | def eval_math(item, pred_key='prediction', prec=1e-3): 58 | pred = item[pred_key] 59 | if pred_key == 'program_output' and isinstance(pred, str): 60 | pred = [pred] 61 | ans = item['answer'] 62 | if isinstance(pred, list) and isinstance(ans, list): 63 | # for some questions in MATH, `reference` repeats answers 64 | _ans = [] 65 | for a in ans: 66 | if a not in _ans: 67 | _ans.append(a) 68 | ans = _ans 69 | # some predictions for MATH questions also repeat answers 70 | _pred = [] 71 | for a in pred: 72 | if a not in _pred: 73 | _pred.append(a) 74 | # some predictions mistakenly box non-answer strings 75 | pred = _pred[-len(ans):] 76 | 77 | item.update({ 78 | pred_key: pred, 79 | 'answer': ans 80 | }) 81 | return is_correct(item, pred_key=pred_key, prec=prec) 82 | 83 | def eval_last_single_answer(item, pred_key='prediction', prec=1e-3): 84 | for key in [pred_key, 'answer']: 85 | assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str" 86 | return is_correct(item, pred_key=pred_key, prec=prec) 87 | 88 | def eval_agieval_gaokao_math_cloze(item, pred_key='prediction', prec=1e-3): 89 | if pred_key == 'program_output' and isinstance(item[pred_key], str): 90 | item[pred_key] = [item[pred_key]] 91 | for key in [pred_key, 'answer']: 92 | assert isinstance(item[key], list), f"{key} 
= `{item[key]}` is not a list" 93 | pred = item[pred_key] 94 | ans = item['answer'] 95 | _pred = [] 96 | for p in pred: 97 | p = p + ";" 98 | while p: 99 | left_brackets = 0 100 | for i in range(len(p)): 101 | if p[i] == ';' or (p[i] == ',' and left_brackets == 0): 102 | _p, p = p[:i].strip(), p[i + 1:].strip() 103 | if _p not in _pred: 104 | _pred.append(_p) 105 | break 106 | elif p[i] in '([{': 107 | left_brackets += 1 108 | elif p[i] in ')]}': 109 | left_brackets -= 1 110 | pred = _pred[-len(ans):] 111 | if len(pred) == len(ans): 112 | for p, a in zip(pred, ans): 113 | item.update({ 114 | pred_key: p, 115 | 'answer': a, 116 | }) 117 | if not is_correct(item, pred_key=pred_key, prec=prec): 118 | return False 119 | return True 120 | else: 121 | return False 122 | 123 | def eval_agieval_gaokao_mathqa(item, pred_key='prediction', prec=1e-3): 124 | if pred_key == 'program_output' and isinstance(item[pred_key], str): 125 | item[pred_key] = [item[pred_key]] 126 | pred_str = " ".join(item[pred_key]) 127 | ans = item['answer'] 128 | tag = None 129 | idx = -1 130 | for t in 'ABCD': 131 | if t in pred_str and pred_str.index(t) > idx: 132 | tag = t 133 | idx = pred_str.index(t) 134 | return tag == ans 135 | 136 | def eval_math_sat(item, pred_key='prediction', prec=1e-3): 137 | for key in [pred_key, 'answer']: 138 | assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str" 139 | return item[pred_key].lower() == item['answer'].lower() 140 | 141 | def eval_mmlu_stem(item, pred_key='prediction', prec=1e-3): 142 | return eval_math_sat(item, pred_key=pred_key, prec=prec) 143 | 144 | def eval_ocwcourses(item, pred_key='prediction', prec=1e-3): 145 | INVALID_ANSWER = "[invalidanswer]" 146 | for key in [pred_key, 'answer']: 147 | assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str" 148 | pred = item[pred_key] 149 | ans = item['answer'] 150 | 151 | try: 152 | # numeric 153 | float(ans) 154 | normalize_fn = normalize_numeric 155 | is_equiv = 
numeric_equality 156 | except ValueError: 157 | if "=" in ans: 158 | # equation 159 | normalize_fn = normalize_symbolic_equation 160 | is_equiv = lambda x, y: x==y 161 | else: 162 | # expression 163 | normalize_fn = SymbolicMathMixin().normalize_tex 164 | is_equiv = SymbolicMathMixin().is_tex_equiv 165 | 166 | correct_answer = normalize_fn(ans) 167 | 168 | unnormalized_answer = pred if pred else INVALID_ANSWER 169 | model_answer = normalize_fn(unnormalized_answer) 170 | 171 | if unnormalized_answer == INVALID_ANSWER: 172 | acc = 0 173 | elif model_answer == INVALID_ANSWER: 174 | acc = 0 175 | elif is_equiv(model_answer, correct_answer): 176 | acc = 1 177 | else: 178 | acc = 0 179 | 180 | return acc 181 | 182 | def eval_minif2f_isabelle(item, pred_key='prediction', prec=1e-3): 183 | return True 184 | -------------------------------------------------------------------------------- /evaluation/eval/eval_utils.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import re 3 | from math import isclose 4 | from typing import Any, Dict, Union 5 | 6 | import numpy as np 7 | import regex 8 | from evaluation.data_processing.answer_extraction import ( 9 | extract_answer, 10 | extract_program_output, 11 | strip_string, 12 | ) 13 | from sympy import N, simplify 14 | from sympy.parsing.latex import parse_latex 15 | from sympy.parsing.sympy_parser import parse_expr 16 | 17 | 18 | def extract_program(result: str, last_only=True): 19 | """ 20 | extract the program after "```python", and before "```" 21 | """ 22 | program = "" 23 | start = False 24 | for line in result.split("\n"): 25 | if line.startswith("```python"): 26 | if last_only: 27 | program = "" # only extract the last program 28 | else: 29 | program += "\n# ========\n" 30 | start = True 31 | elif line.startswith("```"): 32 | start = False 33 | elif start: 34 | program += line + "\n" 35 | return program 36 | 37 | 38 | def parse_ground_truth(example: Dict[str, 
Any], data_name): 39 | if 'gt_cot' in example: 40 | return example['gt_cot'], strip_string(example['gt']) 41 | 42 | # parse ground truth 43 | if data_name in ["math", 'ocw']: 44 | gt_cot = example['solution'] 45 | gt_ans = extract_answer(gt_cot) 46 | elif data_name == "gsm8k": 47 | gt_cot, gt_ans = example['answer'].split("####") 48 | elif data_name == "gsm-hard": 49 | gt_cot, gt_ans = example['code'], example['target'] 50 | elif data_name == "svamp": 51 | gt_cot, gt_ans = example['Equation'], example['Answer'] 52 | elif data_name == "asdiv": 53 | gt_cot = example['formula'] 54 | gt_ans = re.sub(r"\(.*?\)", "", example['answer']) 55 | elif data_name == "mawps": 56 | gt_cot, gt_ans = None, example['target'] 57 | elif data_name == "tabmwp": 58 | gt_cot = example['solution'] 59 | gt_ans = example['answer'] 60 | if example['ans_type'] in ['integer_number', 'decimal_number']: 61 | if '/' in gt_ans: 62 | gt_ans = int(gt_ans.split('/')[0]) / int(gt_ans.split('/')[1]) 63 | elif ',' in gt_ans: 64 | gt_ans = float(gt_ans.replace(',', '')) 65 | elif '%' in gt_ans: 66 | gt_ans = float(gt_ans.split('%')[0]) / 100 67 | else: 68 | gt_ans = float(gt_ans) 69 | elif data_name == "bbh": 70 | gt_cot, gt_ans = None, example['target'] 71 | else: 72 | raise NotImplementedError(data_name) 73 | # post process 74 | gt_cot = str(gt_cot).strip() 75 | gt_ans = strip_string(gt_ans) 76 | return gt_cot, gt_ans 77 | 78 | 79 | def parse_question(example, data_name): 80 | question = "" 81 | if data_name == "asdiv": 82 | question = f"{example['body'].strip()} {example['question'].strip()}" 83 | elif data_name == "svamp": 84 | body = example["Body"].strip() 85 | if not body.endswith("."): 86 | body = body + "." 
87 | question = f'{body} {example["Question"].strip()}' 88 | elif data_name == "tabmwp": 89 | title_str = f'regarding "{example["table_title"]}" ' if example['table_title'] else "" 90 | question = f'Read the following table {title_str}and answer a question:\n' 91 | question += f'{example["table"]}\n{example["question"]}' 92 | if example['choices']: 93 | question += f' Please select from the following options: {example["choices"]}' 94 | else: 95 | for key in ['question', 'problem', 'Question', 'input']: 96 | if key in example: 97 | question = example[key] 98 | break 99 | assert question != "" 100 | return question.strip() 101 | 102 | 103 | def run_execute(executor, result, prompt_type, execute=False): 104 | if not result or result == 'error': 105 | return None, None 106 | report = None 107 | 108 | if "program_only" in prompt_type: 109 | prediction = extract_program_output(result) 110 | elif prompt_type in ["pot", "pal"] and execute: 111 | code = extract_program(result) 112 | prediction, report = executor.apply(code) 113 | else: 114 | prediction = extract_answer(result) 115 | 116 | prediction = strip_string(prediction) 117 | return prediction, report 118 | 119 | 120 | def parse_digits(num): 121 | # format: 234.23 || 23% 122 | num = regex.sub(',', '', str(num)) 123 | try: 124 | return float(num) 125 | except Exception: 126 | if num.endswith('%'): 127 | num = num[:-1] 128 | if num.endswith('\\'): 129 | num = num[:-1] 130 | try: 131 | return float(num) / 100 132 | except Exception: 133 | pass 134 | return None 135 | 136 | def is_digit(num): 137 | # paired with parse_digits 138 | return parse_digits(num) is not None 139 | 140 | 141 | def normalize_prediction(prediction): 142 | try: # 1. numerical equal 143 | if is_digit(prediction): 144 | prediction = np.round(float(str(prediction).replace(",", "")), 6) 145 | return str(prediction) 146 | except Exception: 147 | pass 148 | 149 | # 2. 
symbolic equal 150 | prediction = str(prediction).strip() 151 | 152 | ## deal with [], (), {} 153 | brackets = [] 154 | while prediction.startswith("[") and prediction.endswith("]") or (prediction.startswith("(") and prediction.endswith(")")): 155 | brackets.append(prediction[0]); prediction = prediction[1:-1]  # record the stripped bracket so it can be restored below 156 | if brackets and ',' in prediction: 157 | pred_parts = [normalize_prediction(part) for part in prediction.split(",")] 158 | prediction = ",".join(pred_parts) 159 | 160 | if brackets: 161 | for b in reversed(brackets): 162 | if b == '[': 163 | prediction = '[' + prediction + ']' 164 | else: 165 | assert b == '(' 166 | prediction = '(' + prediction + ')' 167 | 168 | def _parse(s): 169 | for f in [parse_latex, parse_expr]: 170 | try: 171 | return f(s) 172 | except Exception: 173 | pass 174 | return s 175 | 176 | prediction = _parse(prediction) 177 | 178 | for s in ['{', "}", "(", ")"]: 179 | prediction = prediction.replace(s, "") 180 | 181 | return prediction 182 | 183 | 184 | def math_equal(prediction: Union[bool, float, str], 185 | reference: Union[float, str], 186 | include_percentage: bool = True, 187 | is_close: bool = True, 188 | timeout: bool = False, 189 | ) -> bool: 190 | """ 191 | Exact match of math if and only if: 192 | 1. numerical equal: both can convert to float and are equal 193 | 2. symbolic equal: both can convert to sympy expression and are equal 194 | """ 195 | if str(prediction) == str(reference): 196 | return True 197 | 205 | try: # 1. 
numerical equal 206 | if is_digit(prediction) and is_digit(reference): 207 | prediction = parse_digits(prediction) 208 | reference = parse_digits(reference) 209 | # number questions 210 | if include_percentage: 211 | gt_result = [reference / 100, reference, reference * 100] 212 | else: 213 | gt_result = [reference] 214 | for item in gt_result: 215 | try: 216 | if is_close: 217 | if isclose(item, prediction, abs_tol=1e-3): 218 | return True 219 | else: 220 | if item == prediction: 221 | return True 222 | except Exception: 223 | continue 224 | return False 225 | except Exception: 226 | pass 227 | 228 | if not prediction and prediction not in [0, False]: 229 | return False 230 | 231 | # 2. symbolic equal 232 | reference = str(reference).strip() 233 | prediction = str(prediction).strip() 234 | 235 | if regex.match(r'(\(|\[).+(\)|\])', prediction) is not None and regex.match(r'(\(|\[).+(\)|\])', reference) is not None: 236 | pred_parts = prediction[1:-1].split(",") 237 | ref_parts = reference[1:-1].split(",") 238 | if len(pred_parts) == len(ref_parts): 239 | if all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]): 240 | return True 241 | 242 | if (prediction.startswith("\\begin{pmatrix}") or prediction.startswith("\\begin{bmatrix}")) and (prediction.endswith("\\end{pmatrix}") or prediction.endswith("\\end{bmatrix}")) and \ 243 | (reference.startswith("\\begin{pmatrix}") or reference.startswith("\\begin{bmatrix}")) and (reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")): 244 | pred_lines = [line.strip() for line in prediction[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()] 245 | ref_lines = [line.strip() for line in reference[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()] 246 | matched = True 247 | if len(pred_lines) == len(ref_lines): 248 | for pred_line, ref_line in zip(pred_lines, ref_lines): 249 | pred_parts = 
pred_line.split("&") 250 | ref_parts = ref_line.split("&") 251 | if len(pred_parts) == len(ref_parts): 252 | if not all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]): 253 | matched = False 254 | break 255 | else: 256 | matched = False 257 | if not matched: 258 | break 259 | else: 260 | matched = False 261 | if matched: 262 | return True 263 | 264 | if prediction.count('=') == 1 and reference.count('=') == 1: 265 | pred = prediction.split('=') 266 | pred = f"{pred[0].strip()} - ({pred[1].strip()})" 267 | ref = reference.split('=') 268 | ref = f"{ref[0].strip()} - ({ref[1].strip()})" 269 | if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref): 270 | return True 271 | elif prediction.count('=') == 1 and len(prediction.split('=')[0].strip()) <= 2 and '=' not in reference: 272 | if math_equal(prediction.split('=')[1], reference, include_percentage, is_close): 273 | return True 274 | elif reference.count('=') == 1 and len(reference.split('=')[0].strip()) <= 2 and '=' not in prediction: 275 | if math_equal(prediction, reference.split('=')[1], include_percentage, is_close): 276 | return True 277 | 278 | # symbolic equal with sympy 279 | if timeout: 280 | if call_with_timeout(symbolic_equal_process, prediction, reference): 281 | return True 282 | else: 283 | if symbolic_equal(prediction, reference): 284 | return True 285 | 286 | return False 287 | 288 | 289 | def math_equal_process(param): 290 | return math_equal(param[-2], param[-1]) 291 | 292 | 293 | def symbolic_equal(a, b): 294 | def _parse(s): 295 | for f in [parse_latex, parse_expr]: 296 | try: 297 | return f(s) 298 | except Exception: 299 | pass 300 | return s 301 | a = _parse(a) 302 | b = _parse(b) 303 | 304 | try: 305 | if simplify(a-b) == 0: 306 | return True 307 | except Exception: 308 | pass 309 | 310 | try: 311 | if isclose(N(a), N(b), abs_tol=1e-3): 312 | return True 313 | except Exception: 314 | pass 315 | return False 316 | 317 | 318 | 
def symbolic_equal_process(a, b, output_queue): 319 | result = symbolic_equal(a, b) 320 | output_queue.put(result) 321 | 322 | 323 | def call_with_timeout(func, *args, timeout=1, **kwargs): 324 | output_queue = multiprocessing.Queue() 325 | process_args = args + (output_queue,) 326 | process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs) 327 | process.start() 328 | process.join(timeout) 329 | 330 | if process.is_alive(): 331 | process.terminate() 332 | process.join() 333 | return False 334 | 335 | return output_queue.get() 336 | -------------------------------------------------------------------------------- /evaluation/eval/ocwcourses_eval_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import signal 3 | 4 | import numpy as np 5 | import sympy 6 | from sympy.core.sympify import SympifyError 7 | from sympy.parsing.latex import parse_latex 8 | 9 | INVALID_ANSWER = "[invalidanswer]" 10 | 11 | class timeout: 12 | def __init__(self, seconds=1, error_message="Timeout"): 13 | self.seconds = seconds 14 | self.error_message = error_message 15 | 16 | def handle_timeout(self, signum, frame): 17 | raise TimeoutError(self.error_message) 18 | 19 | def __enter__(self): 20 | signal.signal(signal.SIGALRM, self.handle_timeout) 21 | signal.alarm(self.seconds) 22 | 23 | def __exit__(self, type, value, traceback): 24 | signal.alarm(0) 25 | 26 | def normalize_numeric(s): 27 | if s is None: 28 | return None 29 | for unit in [ 30 | "eV", 31 | " \\mathrm{~kg} \\cdot \\mathrm{m} / \\mathrm{s}", 32 | " kg m/s", 33 | "kg*m/s", 34 | "kg", 35 | "m/s", 36 | "m / s", 37 | "m s^{-1}", 38 | "\\text{ m/s}", 39 | " \\mathrm{m/s}", 40 | " \\text{ m/s}", 41 | "g/mole", 42 | "g/mol", 43 | "\\mathrm{~g}", 44 | "\\mathrm{~g} / \\mathrm{mol}", 45 | "W", 46 | "erg/s", 47 | "years", 48 | "year", 49 | "cm", 50 | ]: 51 | s = s.replace(unit, "") 52 | s = s.strip() 53 | for maybe_unit in ["m", "s", "cm"]: 54 | s = 
s.replace("\\mathrm{" + maybe_unit + "}", "") 55 | s = s.replace("\\mathrm{~" + maybe_unit + "}", "") 56 | s = s.strip() 57 | s = s.strip("$") 58 | try: 59 | return float(eval(s)) 60 | except Exception: 61 | try: 62 | expr = parse_latex(s) 63 | if expr.is_number: 64 | return float(expr) 65 | return INVALID_ANSWER 66 | except Exception: 67 | return INVALID_ANSWER 68 | 69 | def numeric_equality(n1, n2, threshold=0.01): 70 | if n1 is None or n2 is None: 71 | return False 72 | if np.isclose(n1, 0) or np.isclose(n2, 0) or np.isclose(n1 - n2, 0): 73 | return np.abs(n1 - n2) < threshold * (n1 + n2) / 2 74 | else: 75 | return np.isclose(n1, n2) 76 | 77 | def normalize_symbolic_equation(s): 78 | if not isinstance(s, str): 79 | return INVALID_ANSWER 80 | if s.startswith("\\["): 81 | s = s[2:] 82 | if s.endswith("\\]"): 83 | s = s[:-2] 84 | s = s.replace("\\left(", "(") 85 | s = s.replace("\\right)", ")") 86 | s = s.replace("\\\\", "\\") 87 | if s.startswith("$") or s.endswith("$"): 88 | s = s.strip("$") 89 | try: 90 | maybe_expression = parse_latex(s) 91 | if not isinstance(maybe_expression, sympy.core.relational.Equality): 92 | # we have equation, not expression 93 | return INVALID_ANSWER 94 | else: 95 | return maybe_expression 96 | except Exception: 97 | return INVALID_ANSWER 98 | 99 | class SymbolicMathMixin: 100 | """ 101 | Methods useful for parsing mathematical expressions from text and determining equivalence of expressions. 
102 | """ 103 | 104 | SUBSTITUTIONS = [ # used for text normalize 105 | ("an ", ""), 106 | ("a ", ""), 107 | (".$", "$"), 108 | ("\\$", ""), 109 | (r"\ ", ""), 110 | (" ", ""), 111 | ("mbox", "text"), 112 | (",\\text{and}", ","), 113 | ("\\text{and}", ","), 114 | ("\\text{m}", "\\text{}"), 115 | ] 116 | REMOVED_EXPRESSIONS = [ # used for text normalizer 117 | "square", 118 | "ways", 119 | "integers", 120 | "dollars", 121 | "mph", 122 | "inches", 123 | "ft", 124 | "hours", 125 | "km", 126 | "units", 127 | "\\ldots", 128 | "sue", 129 | "points", 130 | "feet", 131 | "minutes", 132 | "digits", 133 | "cents", 134 | "degrees", 135 | "cm", 136 | "gm", 137 | "pounds", 138 | "meters", 139 | "meals", 140 | "edges", 141 | "students", 142 | "childrentickets", 143 | "multiples", 144 | "\\text{s}", 145 | "\\text{.}", 146 | "\\text{\ns}", 147 | "\\text{}^2", 148 | "\\text{}^3", 149 | "\\text{\n}", 150 | "\\text{}", 151 | r"\mathrm{th}", 152 | r"^\circ", 153 | r"^{\circ}", 154 | r"\;", 155 | r",\!", 156 | "{,}", 157 | '"', 158 | "\\dots", 159 | ] 160 | 161 | def normalize_tex(self, final_answer: str) -> str: 162 | """ 163 | Normalizes a string representing a mathematical expression. 164 | Used as a preprocessing step before parsing methods. 165 | 166 | Copied character for character from appendix D of Lewkowycz et al. (2022) 167 | """ 168 | final_answer = final_answer.split("=")[-1] 169 | 170 | for before, after in self.SUBSTITUTIONS: 171 | final_answer = final_answer.replace(before, after) 172 | for expr in self.REMOVED_EXPRESSIONS: 173 | final_answer = final_answer.replace(expr, "") 174 | 175 | # Extract answer that is in LaTeX math, is bold, 176 | # is surrounded by a box, etc. 
177 | final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) 178 | final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) 179 | final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) 180 | final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) 181 | final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) 182 | 183 | # Normalize shorthand TeX: 184 | # \fracab -> \frac{a}{b} 185 | # \frac{abc}{bef} -> \frac{abc}{bef} 186 | # \fracabc -> \frac{a}{b}c 187 | # \sqrta -> \sqrt{a} 188 | # \sqrtab -> sqrt{a}b 189 | final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) 190 | final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) 191 | final_answer = final_answer.replace("$", "") 192 | 193 | # Normalize 100,000 -> 100000 194 | if final_answer.replace(",", "").isdigit(): 195 | final_answer = final_answer.replace(",", "") 196 | 197 | return final_answer 198 | 199 | def parse_tex(self, text: str, time_limit: int = 5) -> sympy.Basic: 200 | """ 201 | Wrapper around `sympy.parsing.latex.parse_latex` that outputs a SymPy expression. 202 | Typically, you want to apply `normalize_tex` as a preprocessing step. 203 | """ 204 | try: 205 | with timeout(seconds=time_limit): 206 | parsed = parse_latex(text) 207 | except ( 208 | # general error handling: there is a long tail of possible sympy/other 209 | # errors we would like to catch 210 | Exception 211 | ) as e: 212 | print(f"failed to parse {text} with exception {e}") 213 | return None 214 | 215 | return parsed 216 | 217 | def is_exp_equiv(self, x1: sympy.Basic, x2: sympy.Basic, time_limit=5) -> bool: 218 | """ 219 | Determines whether two sympy expressions are equal. 
220 | 221 | try: 222 | with timeout(seconds=time_limit): 223 | try: 224 | diff = x1 - x2 225 | except (SympifyError, ValueError, TypeError) as e: 226 | print( 227 | f"Couldn't subtract {x1} and {x2} with exception {e}" 228 | ) 229 | return False 230 | 231 | try: 232 | if sympy.simplify(diff) == 0: 233 | return True 234 | else: 235 | return False 236 | except (SympifyError, ValueError, TypeError) as e: 237 | print(f"Failed to simplify {x1}-{x2} with {e}") 238 | return False 239 | except TimeoutError: 240 | print(f"Timed out comparing {x1} and {x2}") 241 | return False 242 | except Exception as e: 243 | print(f"Failed on unrecognized exception: {e}") 244 | return False 245 | 246 | def is_tex_equiv(self, x1: str, x2: str, time_limit=5) -> bool: 247 | """ 248 | Determines whether two (ideally normalized using `normalize_tex`) TeX expressions are equal. 249 | 250 | Does so by first checking for string exact-match, then falls back on sympy-equivalence, 251 | following the (Lewkowycz et al. 2022) methodology. 252 | """ 253 | if x1 == x2: 254 | # don't resort to sympy if we have full string match, post-normalization 255 | return True 256 | 258 | parsed_x2 = self.parse_tex(x2) 259 | if not parsed_x2: 260 | # if our reference fails to parse into a Sympy object, 261 | # we forgo parsing + checking our generated answer. 
262 | return False 263 | return self.is_exp_equiv(self.parse_tex(x1), parsed_x2, time_limit=time_limit) 264 | -------------------------------------------------------------------------------- /evaluation/eval/python_executor.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import io 3 | import pickle 4 | import traceback 5 | from concurrent.futures import TimeoutError 6 | from contextlib import redirect_stdout 7 | from functools import partial 8 | from typing import Any, Dict, Optional 9 | 10 | import multiprocess 11 | import regex 12 | from pebble import ProcessPool 13 | from timeout_decorator import timeout 14 | 15 | 16 | class GenericRuntime: 17 | GLOBAL_DICT = {} 18 | LOCAL_DICT = None 19 | HEADERS = [] 20 | def __init__(self): 21 | self._global_vars = copy.copy(self.GLOBAL_DICT) 22 | self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None 23 | 24 | for c in self.HEADERS: 25 | self.exec_code(c) 26 | 27 | def exec_code(self, code_piece: str) -> None: 28 | if regex.search(r'(\s|^)?input\(', code_piece) or regex.search(r'(\s|^)?os.system\(', code_piece): 29 | raise RuntimeError() 30 | exec(code_piece, self._global_vars) 31 | 32 | def eval_code(self, expr: str) -> Any: 33 | return eval(expr, self._global_vars) 34 | 35 | def inject(self, var_dict: Dict[str, Any]) -> None: 36 | for k, v in var_dict.items(): 37 | self._global_vars[k] = v 38 | 39 | @property 40 | def answer(self): 41 | return self._global_vars['answer'] 42 | 43 | class PythonExecutor: 44 | def __init__( 45 | self, 46 | runtime: Optional[Any] = None, 47 | get_answer_symbol: Optional[str] = None, 48 | get_answer_expr: Optional[str] = None, 49 | get_answer_from_stdout: bool = False, 50 | ) -> None: 51 | self.runtime = runtime if runtime else GenericRuntime() 52 | self.answer_symbol = get_answer_symbol 53 | self.answer_expr = get_answer_expr 54 | self.get_answer_from_stdout = get_answer_from_stdout 55 | 56 | def 
process_generation_to_code(self, gens: str): 57 | batch_code = [] 58 | for g in gens: 59 | multiline_comments = False 60 | code = [] 61 | for line in g.split('\n'): 62 | strip_line = line.strip() 63 | if strip_line.startswith("#"): 64 | line = line.split("#", 1)[0] + "# comments" 65 | elif not multiline_comments and strip_line.startswith('"""') and strip_line.endswith('"""') and len(strip_line) >= 6: 66 | line = line.split('"""', 1)[0] + '"""comments"""' 67 | elif not multiline_comments and strip_line.startswith('"""'): 68 | multiline_comments = True 69 | elif multiline_comments and strip_line.endswith('"""'): 70 | multiline_comments = False 71 | line = "" 72 | if not multiline_comments: 73 | code.append(line) 74 | batch_code.append(code) 75 | return batch_code 76 | 77 | @staticmethod 78 | def execute( 79 | code, 80 | get_answer_from_stdout = None, 81 | runtime = None, 82 | answer_symbol = None, 83 | answer_expr = None, 84 | timeout_length = 10, 85 | ): 86 | try: 87 | if get_answer_from_stdout: 88 | program_io = io.StringIO() 89 | with redirect_stdout(program_io): 90 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code)) 91 | program_io.seek(0) 92 | result = "".join(program_io.readlines()) # [-1] 93 | elif answer_symbol: 94 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code)) 95 | result = runtime._global_vars[answer_symbol] 96 | elif answer_expr: 97 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code)) 98 | result = timeout(timeout_length)(runtime.eval_code)(answer_expr) 99 | else: 100 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code[:-1])) 101 | result = timeout(timeout_length)(runtime.eval_code)(code[-1]) 102 | concise_exec_info = "" 103 | exec_info = "" 104 | str(result) 105 | pickle.dumps(result) # serialization check 106 | except Exception: 107 | # traceback.print_exc() 108 | result = '' 109 | concise_exec_info = traceback.format_exc().split('\n')[-2] 110 | exec_info = traceback.format_exc() 111 | if get_answer_from_stdout 
and 'exec(code_piece, self._global_vars)' in exec_info: 112 | exec_info = exec_info.split('exec(code_piece, self._global_vars)')[-1].strip() 113 | msg = [] 114 | for line in exec_info.split("\n"): 115 | patt = regex.search(r'(?P.*)File "(?P.*)", line (?P\d+), (?P.*)', line) 116 | if patt is not None: 117 | if '' in patt.group('end'): 118 | continue 119 | fname = patt.group("file") 120 | if "site-packages" in fname: 121 | fname = f"site-packages{fname.split('site-packages', 1)[1]}" 122 | line = f'{patt.group("start")}File "{fname}", {patt.group("end")}' 123 | else: 124 | line = f'{patt.group("start")}{patt.group("end")}' 125 | else: 126 | patt = regex.search(r'(?P.*)(?P/.*site-packages/.*\.py)(?P.*)', line) 127 | if patt is not None: 128 | line = f'{patt.group("start")}site-packages{patt.group("file").split("site-packages", 1)[1]}{patt.group("end")}' 129 | msg.append(line) 130 | exec_info = "\n".join(msg) 131 | return result, concise_exec_info, exec_info 132 | 133 | def apply(self, code): 134 | return self.batch_apply([code])[0] 135 | 136 | def batch_apply(self, batch_code): 137 | all_code_snippets = self.process_generation_to_code(batch_code) 138 | all_exec_results = [] 139 | executor = partial( 140 | self.execute, 141 | get_answer_from_stdout=self.get_answer_from_stdout, 142 | runtime=self.runtime, 143 | answer_symbol=self.answer_symbol, 144 | answer_expr=self.answer_expr, 145 | timeout_length=10, 146 | ) 147 | with ProcessPool(max_workers=multiprocess.cpu_count()) as pool: 148 | iterator = pool.map(executor, all_code_snippets, timeout=10).result() 149 | 150 | while True: 151 | try: 152 | result = next(iterator) 153 | all_exec_results.append(result) 154 | except StopIteration: 155 | break 156 | except TimeoutError: 157 | all_exec_results.append(("", "Timeout Error", "Timeout Error")) 158 | except Exception as error: 159 | print(error) 160 | exit() 161 | 162 | batch_results = [] 163 | for code, (result, concise_exec_info, exec_info) in zip(all_code_snippets, 
all_exec_results): 164 | metadata = {'code': code, 'exec_result': result, 'concise_exec_info': concise_exec_info, 'exec_info': exec_info} 165 | batch_results.append((result, metadata)) 166 | return batch_results 167 | -------------------------------------------------------------------------------- /evaluation/eval/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | from transformers import GenerationConfig, StoppingCriteria 4 | 5 | 6 | class KeyWordsCriteria(StoppingCriteria): 7 | def __init__(self, stop_id_sequences, tokenizer, prompt_length): 8 | assert isinstance(stop_id_sequences[0], list), "stop_id_sequences should be a list of list of ids" 9 | self.tokenizer = tokenizer 10 | self.stop_id_sequences = stop_id_sequences 11 | self.stop_sequences = [tokenizer.decode(sequence) for sequence in stop_id_sequences] 12 | print(f"stop sequences: {self.stop_sequences}", flush=True) 13 | self.prompt_length = prompt_length 14 | 15 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 16 | sequences_should_be_stopped = [] 17 | for i in range(input_ids.shape[0]): 18 | ids = input_ids[i][self.prompt_length:].tolist() 19 | should_be_stopped = False 20 | for stop_ids, stop_sequence in zip(self.stop_id_sequences, self.stop_sequences): 21 | _ids = ids 22 | for j in range(len(_ids), 0, -1): 23 | s = self.tokenizer.decode(_ids[max(j - len(stop_ids) - 3, 0) :j]) 24 | if s.endswith(stop_sequence): 25 | should_be_stopped = True 26 | break 27 | if should_be_stopped: 28 | break 29 | sequences_should_be_stopped.append(should_be_stopped) 30 | return all(sequences_should_be_stopped) 31 | 32 | @torch.no_grad() 33 | def generate_completions(model, tokenizer, prompts, batch_size=1, stop_id_sequences=None, end_of_generation_id_sequence=None, disable_tqdm=False, **generation_kwargs): 34 | generations = [] 35 | finish_completion = [] 36 | if not disable_tqdm: 37 | progress = 
tqdm.tqdm(total=len(prompts), desc="Generating Completions") 38 | 39 | if stop_id_sequences is not None: 40 | stop_sequences = [tokenizer.decode(stop_id_sequence) for stop_id_sequence in stop_id_sequences] 41 | 42 | if end_of_generation_id_sequence is not None: 43 | end_of_generation_sequence = tokenizer.decode(end_of_generation_id_sequence) 44 | 45 | num_return_sequences = generation_kwargs.get("num_return_sequences", 1) 46 | generation_kwargs['use_cache'] = True 47 | for i in range(0, len(prompts), batch_size): 48 | batch_prompts = prompts[i:i+batch_size] 49 | tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens='chatglm2' in str(model.__class__)) 50 | batch_input_ids = tokenized_prompts.input_ids 51 | attention_mask = tokenized_prompts.attention_mask 52 | 53 | if model.device.type == "cuda": 54 | batch_input_ids = batch_input_ids.cuda() 55 | attention_mask = attention_mask.cuda() 56 | 57 | batch_finish_completion = [False] * len(batch_prompts) * num_return_sequences 58 | try: 59 | batch_outputs = model.generate( 60 | input_ids=batch_input_ids, 61 | attention_mask=attention_mask, 62 | stopping_criteria=[KeyWordsCriteria(stop_id_sequences, tokenizer, batch_input_ids.size(1))] if stop_id_sequences else None, 63 | **generation_kwargs 64 | ) 65 | 66 | # the stopping criteria are applied at the batch level, so if other examples have not stopped, the entire batch will continue to generate. 67 | # as a result, some outputs may still contain the stop sequence, which we need to remove.
68 | if stop_id_sequences: 69 | for output_idx in range(batch_outputs.shape[0]): 70 | for token_idx in range(batch_input_ids.shape[1], batch_outputs.shape[1]): 71 | if any(tokenizer.decode(batch_outputs[output_idx, token_idx: token_idx + len(stop_sequence) + 3]).startswith(stop_sequence) for stop_sequence in stop_sequences): 72 | if end_of_generation_id_sequence is not None and tokenizer.decode(batch_outputs[output_idx, token_idx: token_idx + len(end_of_generation_id_sequence) + 3]).startswith(end_of_generation_sequence): 73 | batch_finish_completion[output_idx] = True 74 | batch_outputs[output_idx, token_idx:] = tokenizer.pad_token_id 75 | break 76 | 77 | # remove the prompt from the output 78 | # we need to re-encode the prompt because we need to make sure the special tokens are treated the same way as in the outputs. 79 | # we changed our previous way of truncating the output token ids directly because some tokenizers (e.g., llama) won't add a space token before the first token. 80 | # space is important for some tasks (e.g., code completion).
81 | batch_outputs = tokenizer.batch_decode(batch_outputs, skip_special_tokens=True) 82 | batch_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True) 83 | # duplicate the prompts to match the number of return sequences 84 | batch_prompts = [prompt for prompt in batch_prompts for _ in range(num_return_sequences)] 85 | batch_generations = [ 86 | output[len(prompt):] for prompt, output in zip(batch_prompts, batch_outputs) 87 | ] 88 | except Exception as e: 89 | print("Error when generating completions for batch:") 90 | print(batch_prompts) 91 | print("Error message:") 92 | print(e) 93 | print("Use empty string as the completion.") 94 | batch_generations = [""] * len(batch_prompts) * num_return_sequences 95 | 96 | generations += batch_generations 97 | finish_completion += batch_finish_completion 98 | 99 | if not disable_tqdm: 100 | progress.update(len(batch_prompts)//num_return_sequences) 101 | 102 | assert len(generations) == len(prompts) * num_return_sequences, "number of generations should be equal to number of prompts * num_return_sequences" 103 | return generations, finish_completion 104 | 105 | 106 | @torch.no_grad() 107 | def get_next_word_predictions(model, tokenizer, prompts, candidate_token_ids=None, batch_size=1, return_token_predictions=False, disable_tqdm=False): 108 | predictions, probs = [], [] 109 | if not disable_tqdm: 110 | progress = tqdm.tqdm(total=len(prompts), desc="Getting Predictions") 111 | 112 | for i in range(0, len(prompts), batch_size): 113 | batch_prompts = prompts[i: i+batch_size] 114 | tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=False) 115 | batch_input_ids = tokenized_prompts.input_ids 116 | attention_mask = tokenized_prompts.attention_mask 117 | 118 | if model.device.type == "cuda": 119 | batch_input_ids = batch_input_ids.cuda() 120 | attention_mask = attention_mask.cuda() 121 | 122 | batch_logits = model(input_ids=batch_input_ids, 
attention_mask=attention_mask).logits[:, -1, :] 123 | if candidate_token_ids is not None: 124 | batch_logits = batch_logits[:, candidate_token_ids] 125 | batch_probs = torch.softmax(batch_logits, dim=-1) 126 | batch_prediction_indices = torch.argmax(batch_probs, dim=-1) 127 | if return_token_predictions: 128 | if candidate_token_ids is not None: 129 | candidate_tokens = tokenizer.convert_ids_to_tokens(candidate_token_ids) 130 | batch_predictions = [candidate_tokens[idx] for idx in batch_prediction_indices] 131 | else: 132 | batch_predictions = tokenizer.convert_ids_to_tokens(batch_prediction_indices) 133 | predictions += batch_predictions 134 | else: 135 | predictions += batch_prediction_indices.tolist() 136 | probs += batch_probs.tolist() 137 | 138 | if not disable_tqdm: 139 | progress.update(len(batch_prompts)) 140 | 141 | assert len(predictions) == len(prompts), "number of predictions should be equal to number of prompts" 142 | return predictions, probs 143 | 144 | 145 | @torch.no_grad() 146 | def score_completions(model, tokenizer, scoring_examples, disable_tqdm=False): 147 | ''' 148 | Each scoring example is a dict, which contains the following keys: 149 | - prompt: the prompt to score 150 | - completions: a list of completions to score 151 | ''' 152 | 153 | if not disable_tqdm: 154 | progress = tqdm.tqdm(total=len(scoring_examples), desc="Scoring Completions") 155 | 156 | # unroll the scoring examples 157 | unrolled_examples = [] 158 | for scoring_example in scoring_examples: 159 | prompt = scoring_example["prompt"] 160 | for completion in scoring_example["completions"]: 161 | unrolled_examples.append({ 162 | "prompt": prompt, 163 | "completion": completion 164 | }) 165 | 166 | scores = [] 167 | # currently we don't support batching, because we want to directly use the loss returned by the model to score each completion. 
168 | for unrolled_example in unrolled_examples: 169 | encoded_example = encode_with_prompt_completion_format(unrolled_example, tokenizer, max_seq_length=None) 170 | # unsqueeze the batch dimension 171 | for key, value in encoded_example.items(): 172 | encoded_example[key] = value.unsqueeze(0) 173 | if model.device.type == "cuda": 174 | encoded_example = { 175 | key: value.cuda() for key, value in encoded_example.items() 176 | } 177 | outputs = model(**encoded_example) 178 | loss = outputs.loss 179 | scores.append(-loss.item()) 180 | if not disable_tqdm: 181 | progress.update(1) 182 | 183 | # roll up the scores 184 | rolled_up_scores = {} 185 | for unrolled_example, score in zip(unrolled_examples, scores): 186 | prompt = unrolled_example["prompt"] 187 | completion = unrolled_example["completion"] 188 | if prompt not in rolled_up_scores: 189 | rolled_up_scores[prompt] = {} 190 | rolled_up_scores[prompt][completion] = score 191 | 192 | return rolled_up_scores 193 | 194 | 195 | 196 | def load_hf_lm_and_tokenizer( 197 | model_name_or_path, 198 | tokenizer_name_or_path=None, 199 | device_map="auto", 200 | load_in_8bit=False, 201 | load_in_half=False, 202 | gptq_model=False, 203 | use_fast_tokenizer=True, 204 | padding_side="left", 205 | ): 206 | 207 | from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer 208 | 209 | if not tokenizer_name_or_path: 210 | tokenizer_name_or_path = model_name_or_path 211 | 212 | is_chatglm2 = 'chatglm2' in tokenizer_name_or_path.lower() or 'chatglm2' in model_name_or_path 213 | is_qwen = 'qwen' in tokenizer_name_or_path.lower() or 'qwen' in model_name_or_path 214 | 215 | if is_chatglm2 or is_qwen: 216 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True) 217 | if is_qwen: 218 | tokenizer.eos_token = '<|endoftext|>' 219 | tokenizer.eos_token_id = 151643 220 | tokenizer.pad_token = tokenizer.eos_token 221 | tokenizer.pad_token_id = tokenizer.eos_token_id 222 | else: 223 | tokenizer = 
AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True, use_fast=use_fast_tokenizer) 224 | # set padding side to left for batch generation 225 | tokenizer.padding_side = padding_side 226 | # set pad token to eos token if pad token is not set (as is the case for llama models) 227 | if tokenizer.pad_token is None: 228 | tokenizer.pad_token = tokenizer.eos_token 229 | tokenizer.pad_token_id = tokenizer.eos_token_id 230 | 231 | if gptq_model: 232 | from auto_gptq import AutoGPTQForCausalLM 233 | model_wrapper = AutoGPTQForCausalLM.from_quantized( 234 | model_name_or_path, device="cuda:0", use_triton=True 235 | ) 236 | model = model_wrapper.model 237 | elif load_in_8bit: 238 | model = AutoModelForCausalLM.from_pretrained( 239 | model_name_or_path, 240 | device_map=device_map, 241 | load_in_8bit=True 242 | ) 243 | else: 244 | kwargs = {} 245 | model_class = AutoModelForCausalLM 246 | if is_chatglm2: 247 | kwargs = {'trust_remote_code': True} 248 | model_class = AutoModel 249 | elif is_qwen: 250 | kwargs = {'trust_remote_code': True} 251 | if device_map: 252 | model = model_class.from_pretrained(model_name_or_path, device_map=device_map, **kwargs) 253 | else: 254 | model = model_class.from_pretrained(model_name_or_path, **kwargs) 255 | if torch.cuda.is_available(): 256 | model = model.cuda() 257 | if is_qwen: 258 | model.generation_config = GenerationConfig.from_pretrained(model_name_or_path, trust_remote_code=True) 259 | model.generation_config.do_sample = False 260 | if not is_chatglm2 and not is_qwen and load_in_half: 261 | model = model.half() 262 | model.eval() 263 | return model, tokenizer 264 | -------------------------------------------------------------------------------- /imgs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Step-DPO/1f504ead5004f252025cb234017dfd9897cd542c/imgs/.DS_Store 
-------------------------------------------------------------------------------- /imgs/coreidea.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Step-DPO/1f504ead5004f252025cb234017dfd9897cd542c/imgs/coreidea.png -------------------------------------------------------------------------------- /imgs/example1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Step-DPO/1f504ead5004f252025cb234017dfd9897cd542c/imgs/example1.png -------------------------------------------------------------------------------- /imgs/example2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Step-DPO/1f504ead5004f252025cb234017dfd9897cd542c/imgs/example2.png -------------------------------------------------------------------------------- /imgs/example3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Step-DPO/1f504ead5004f252025cb234017dfd9897cd542c/imgs/example3.png -------------------------------------------------------------------------------- /imgs/example4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Step-DPO/1f504ead5004f252025cb234017dfd9897cd542c/imgs/example4.png -------------------------------------------------------------------------------- /imgs/example5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Step-DPO/1f504ead5004f252025cb234017dfd9897cd542c/imgs/example5.jpg -------------------------------------------------------------------------------- /imgs/summary.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dvlab-research/Step-DPO/1f504ead5004f252025cb234017dfd9897cd542c/imgs/summary.jpg -------------------------------------------------------------------------------- /imgs/triangle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Step-DPO/1f504ead5004f252025cb234017dfd9897cd542c/imgs/triangle.png -------------------------------------------------------------------------------- /licenses/DATA_LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 
23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. 
More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. 
Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. 
Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. 
The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. 
Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. 
indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 
291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. 
This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. 
To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the “Licensor.” The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication. Except for the limited purpose of indicating that 396 | material is shared under a Creative Commons public license or as 397 | otherwise permitted by the Creative Commons policies published at 398 | creativecommons.org/policies, Creative Commons does not authorize the 399 | use of the trademark "Creative Commons" or any other trademark or logo 400 | of Creative Commons without its prior written consent including, 401 | without limitation, in connection with any unauthorized modifications 402 | to any of its public licenses or any other arrangements, 403 | understandings, or agreements concerning use of licensed material. For 404 | the avoidance of doubt, this paragraph does not form part of the 405 | public licenses. 
406 | 407 | Creative Commons may be contacted at creativecommons.org. 408 | -------------------------------------------------------------------------------- /licenses/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /licenses/WEIGHT_LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 
14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. 
A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. 
Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. 
Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. 
The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. 
Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. 
indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 
291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. 
This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. 
To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the “Licensor.” The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication. Except for the limited purpose of indicating that 396 | material is shared under a Creative Commons public license or as 397 | otherwise permitted by the Creative Commons policies published at 398 | creativecommons.org/policies, Creative Commons does not authorize the 399 | use of the trademark "Creative Commons" or any other trademark or logo 400 | of Creative Commons without its prior written consent including, 401 | without limitation, in connection with any unauthorized modifications 402 | to any of its public licenses or any other arrangements, 403 | understandings, or agreements concerning use of licensed material. For 404 | the avoidance of doubt, this paragraph does not form part of the 405 | public licenses. 
406 | 407 | Creative Commons may be contacted at creativecommons.org. 408 | -------------------------------------------------------------------------------- /paper/paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Step-DPO/1f504ead5004f252025cb234017dfd9897cd542c/paper/paper.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jsonlines 2 | trl==0.12.2 3 | alignment-handbook @ git+https://github.com/huggingface/alignment-handbook 4 | wandb 5 | deepspeed 6 | accelerate 7 | flash_attn 8 | vllm 9 | antlr4-python3-runtime==4.11 10 | -------------------------------------------------------------------------------- /stepdpo_trainer.py: -------------------------------------------------------------------------------- 1 | # Modified from trl/trl/trainer/dpo_trainer.py 2 | from typing import Dict, Optional, Union 3 | 4 | import torch 5 | from transformers import PreTrainedModel 6 | from trl import DPOTrainer 7 | 8 | 9 | class StepDPOTrainer(DPOTrainer): 10 | 11 | def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None) -> Dict: 12 | """Tokenize a single row from a DPO specific dataset. 13 | 14 | At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation 15 | in case the prompt + chosen or prompt + rejected responses is/are too long. First 16 | we truncate the prompt; if we're still too long, we truncate the chosen/rejected. 17 | 18 | We also create the labels for the chosen/rejected responses, which are of length equal to 19 | the sum of the length of the prompt and the chosen/rejected response, with 20 | label_pad_token_id for the prompt tokens. 
21 | """ 22 | batch = {} 23 | prompt = feature["prompt"] 24 | chosen = feature["chosen"] 25 | rejected = feature["rejected"] 26 | 27 | if not self.is_encoder_decoder: 28 | # Check issues below for more details 29 | # 1. https://github.com/huggingface/trl/issues/907 30 | # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 31 | # 3. https://github.com/LianjiaTech/BELLE/issues/337 32 | 33 | if not isinstance(prompt, str): 34 | raise ValueError(f"prompt should be an str but got {type(prompt)}") 35 | prompt_tokens = self.tokenizer(prompt, add_special_tokens=False) 36 | prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()} 37 | 38 | if not isinstance(chosen, str): 39 | raise ValueError(f"chosen should be an str but got {type(chosen)}") 40 | chosen_tokens = self.build_tokenized_answer(prompt, chosen) 41 | 42 | if not isinstance(rejected, str): 43 | raise ValueError(f"rejected should be an str but got {type(rejected)}") 44 | rejected_tokens = self.build_tokenized_answer(prompt, rejected) 45 | 46 | # Last prompt token might get merged by tokenizer and 47 | # it should not be included for generation if that happens 48 | prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"]) 49 | 50 | chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"]) 51 | rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"]) 52 | prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids) 53 | 54 | for k, v in prompt_tokens.items(): 55 | prompt_tokens[k] = v[:prompt_len_input_ids] 56 | 57 | # Make sure prompts only have one different token at most an 58 | # and length only differs by 1 at most 59 | num_diff_tokens = sum( 60 | [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])] 61 | ) 62 | num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids) 63 | if num_diff_tokens > 1 or num_diff_len > 1: 64 | raise ValueError( 65 | 
"Chosen and rejected prompt_input_ids might only differ on the " 66 | "last token due to tokenizer merge ops." 67 | ) 68 | 69 | # add BOS token to head of prompt 70 | if self.tokenizer.bos_token_id is not None: 71 | prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"] 72 | chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"] 73 | rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"] 74 | 75 | prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"] 76 | chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"] 77 | rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"] 78 | 79 | # # add EOS token to end of answer 80 | # chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id) 81 | # chosen_tokens["attention_mask"].append(1) 82 | 83 | # rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id) 84 | # rejected_tokens["attention_mask"].append(1) 85 | 86 | longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"])) 87 | 88 | # if combined sequence is too long, truncate the prompt 89 | for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]: 90 | if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: 91 | if self.truncation_mode == "keep_start": 92 | for k in ["prompt_input_ids", "prompt_attention_mask"]: 93 | answer_tokens[k] = answer_tokens[k][: self.max_prompt_length] 94 | elif self.truncation_mode == "keep_end": 95 | for k in ["prompt_input_ids", "prompt_attention_mask"]: 96 | answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :] 97 | else: 98 | raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") 99 | 100 | # if that's still too long, truncate the response 101 | for answer_tokens in [chosen_tokens, 
rejected_tokens]: 102 | if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: 103 | for k in ["input_ids", "attention_mask"]: 104 | answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length] 105 | 106 | # Create labels 107 | chosen_sequence_tokens = { 108 | k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"] 109 | } 110 | rejected_sequence_tokens = { 111 | k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"] 112 | } 113 | chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] 114 | chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [ 115 | self.label_pad_token_id 116 | ] * len(chosen_tokens["prompt_input_ids"]) 117 | rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] 118 | rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [ 119 | self.label_pad_token_id 120 | ] * len(rejected_tokens["prompt_input_ids"]) 121 | 122 | for k, toks in { 123 | "chosen_": chosen_sequence_tokens, 124 | "rejected_": rejected_sequence_tokens, 125 | "": prompt_tokens, 126 | }.items(): 127 | for type_key, tokens in toks.items(): 128 | if type_key == "token_type_ids": 129 | continue 130 | batch[f"{k}{type_key}"] = tokens 131 | 132 | # import pdb; pdb.set_trace() 133 | 134 | else: 135 | chosen_tokens = self.tokenizer( 136 | chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True 137 | ) 138 | rejected_tokens = self.tokenizer( 139 | rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True 140 | ) 141 | prompt_tokens = self.tokenizer( 142 | prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True 143 | ) 144 | 145 | batch["chosen_labels"] = chosen_tokens["input_ids"] 146 | batch["rejected_labels"] = rejected_tokens["input_ids"] 147 | batch["prompt_input_ids"] = 
prompt_tokens["input_ids"] 148 | batch["prompt_attention_mask"] = prompt_tokens["attention_mask"] 149 | 150 | if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"): 151 | batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( 152 | labels=torch.tensor(batch["rejected_labels"]) 153 | ) 154 | batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( 155 | labels=torch.tensor(batch["chosen_labels"]) 156 | ) 157 | 158 | return batch 159 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | import sys 4 | from dataclasses import dataclass, field 5 | 6 | import torch 7 | import transformers 8 | from alignment import ( 9 | DataArguments, 10 | DPOConfig, 11 | H4ArgumentParser, 12 | ModelArguments, 13 | get_checkpoint, 14 | get_kbit_device_map, 15 | get_peft_config, 16 | get_quantization_config, 17 | get_tokenizer, 18 | ) 19 | from datasets import load_dataset 20 | from stepdpo_trainer import StepDPOTrainer 21 | from transformers import set_seed 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | def apply_step_wise_chat_template( 26 | example, 27 | tokenizer, 28 | task, 29 | prompt, 30 | auto_insert_empty_system_msg: bool = True 31 | ): 32 | assert task in ["dpo"] 33 | if prompt == 'alpaca': 34 | prompt_input = ( 35 | "Below is an instruction that describes a task, paired with an input that provides further context. " 36 | "Write a response that appropriately completes the request.\n\n" 37 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" 38 | ) 39 | prompt_no_input = ( 40 | "Below is an instruction that describes a task. 
" 41 | "Write a response that appropriately completes the request.\n\n" 42 | "### Instruction:\n{instruction}\n\n### Response:" 43 | ) 44 | elif prompt == 'deepseek-math': 45 | prompt_input = None 46 | prompt_no_input = "User: {instruction}\nPlease reason step by step, and put your final answer within \\boxed{{}}.\n\nAssistant:" 47 | elif prompt == 'qwen2-boxed': 48 | prompt_input = None 49 | prompt_no_input = ( 50 | "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" 51 | "<|im_start|>user\n{instruction}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n" 52 | "<|im_start|>assistant\n" 53 | ) 54 | 55 | text_chosen = example['chosen'] 56 | text_rejected = example['rejected'] 57 | 58 | if prompt == 'alpaca': 59 | if len(example['initial_reason_steps']) == 0: 60 | new_example = { 61 | 'prompt': prompt_no_input.format(instruction=example['prompt']), 62 | 'chosen': text_chosen, 63 | 'rejected': text_rejected, 64 | } 65 | else: 66 | new_example = { 67 | 'prompt': prompt_no_input.format(instruction=example['prompt']) + "\n" + example['initial_reason_steps'], 68 | 'chosen': text_chosen, 69 | 'rejected': text_rejected, 70 | } 71 | elif prompt == 'deepseek-math': 72 | if len(example['initial_reason_steps']) == 0: 73 | new_example = { 74 | 'prompt': prompt_no_input.format(instruction=example['prompt']), 75 | 'chosen': text_chosen, 76 | 'rejected': text_rejected, 77 | } 78 | else: 79 | new_example = { 80 | 'prompt': prompt_no_input.format(instruction=example['prompt']) + " " + example['initial_reason_steps'], 81 | 'chosen': text_chosen, 82 | 'rejected': text_rejected, 83 | } 84 | elif prompt == 'qwen2-boxed': 85 | if len(example['initial_reason_steps']) == 0: 86 | new_example = { 87 | 'prompt': prompt_no_input.format(instruction=example['prompt']), 88 | 'chosen': text_chosen, 89 | 'rejected': text_rejected, 90 | } 91 | else: 92 | new_example = { 93 | 'prompt': prompt_no_input.format(instruction=example['prompt']) + 
example['initial_reason_steps'], 94 | 'chosen': text_chosen, 95 | 'rejected': text_rejected, 96 | } 97 | return new_example 98 | 99 | @dataclass 100 | class StepDPOConfig(DPOConfig): 101 | data_path: str = field(default="xinlai/math-step-dpo-10K") 102 | prompt: str = field(default="alpaca") 103 | 104 | def main(): 105 | parser = H4ArgumentParser((ModelArguments, DataArguments, StepDPOConfig)) 106 | model_args, data_args, training_args = parser.parse() 107 | 108 | ####### 109 | # Setup 110 | ####### 111 | logging.basicConfig( 112 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 113 | datefmt="%Y-%m-%d %H:%M:%S", 114 | handlers=[logging.StreamHandler(sys.stdout)], 115 | ) 116 | log_level = training_args.get_process_log_level() 117 | logger.setLevel(log_level) 118 | transformers.utils.logging.set_verbosity(log_level) 119 | transformers.utils.logging.enable_default_handler() 120 | transformers.utils.logging.enable_explicit_format() 121 | 122 | # Log on each process the small summary: 123 | logger.info(f"Model parameters {model_args}") 124 | logger.info(f"Data parameters {data_args}") 125 | logger.info(f"Training/evaluation parameters {training_args}") 126 | 127 | # Check for last checkpoint 128 | last_checkpoint = get_checkpoint(training_args) 129 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None: 130 | logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.") 131 | 132 | # Set seed for reproducibility 133 | set_seed(training_args.seed) 134 | 135 | ############### 136 | # Load datasets 137 | ############### 138 | if ".json" in training_args.data_path: 139 | raw_datasets = load_dataset( 140 | "json", 141 | data_files=training_args.data_path.split("||"), 142 | ) 143 | else: 144 | raw_datasets = load_dataset(training_args.data_path) 145 | 146 | logger.info( 147 | f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}" 148 | ) 149 | column_names = 
list(raw_datasets["train"].features) 150 | 151 | ##################################### 152 | # Load tokenizer and process datasets 153 | ##################################### 154 | data_args.truncation_side = "left" # Truncate from left to ensure we don't lose labels in final turn 155 | tokenizer = get_tokenizer(model_args, data_args) 156 | 157 | ##################### 158 | # Apply chat template 159 | ##################### 160 | 161 | raw_datasets = raw_datasets.map( 162 | apply_step_wise_chat_template, 163 | fn_kwargs={ 164 | "tokenizer": tokenizer, 165 | "task": "dpo", 166 | "prompt": training_args.prompt, 167 | "auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg, 168 | }, 169 | num_proc=data_args.preprocessing_num_workers, 170 | remove_columns=column_names, 171 | desc="Formatting comparisons with prompt template", 172 | ) 173 | 174 | # Log a few random samples from the training set: 175 | for index in random.sample(range(len(raw_datasets["train"])), 3): 176 | logger.info(f"Prompt sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['prompt']}") 177 | logger.info(f"Chosen sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['chosen']}") 178 | logger.info(f"Rejected sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['rejected']}") 179 | 180 | torch_dtype = ( 181 | model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) 182 | ) 183 | quantization_config = get_quantization_config(model_args) 184 | 185 | model_kwargs = dict( 186 | revision=model_args.model_revision, 187 | trust_remote_code=model_args.trust_remote_code, 188 | use_flash_attention_2=model_args.use_flash_attention_2, 189 | torch_dtype=torch_dtype, 190 | use_cache=False if training_args.gradient_checkpointing else True, 191 | device_map=get_kbit_device_map() if quantization_config is not None else None, 192 | quantization_config=quantization_config, 193 | ) 194 | 195 | 
model = model_args.model_name_or_path 196 | ref_model = model 197 | ref_model_kwargs = model_kwargs 198 | 199 | if model_args.use_peft is True: 200 | ref_model = None 201 | ref_model_kwargs = None 202 | 203 | ######################### 204 | # Instantiate DPO trainer 205 | ######################### 206 | trainer = StepDPOTrainer( 207 | model, 208 | ref_model, 209 | model_init_kwargs=model_kwargs, 210 | ref_model_init_kwargs=ref_model_kwargs, 211 | args=training_args, 212 | beta=training_args.beta, 213 | train_dataset=raw_datasets["train"], 214 | eval_dataset=raw_datasets["test"] if "test" in raw_datasets.keys() else None, 215 | tokenizer=tokenizer, 216 | max_length=training_args.max_length, 217 | max_prompt_length=training_args.max_prompt_length, 218 | peft_config=get_peft_config(model_args), 219 | loss_type=training_args.loss_type, 220 | ) 221 | 222 | ############### 223 | # Training loop 224 | ############### 225 | checkpoint = None 226 | if training_args.resume_from_checkpoint is not None: 227 | checkpoint = training_args.resume_from_checkpoint 228 | elif last_checkpoint is not None: 229 | checkpoint = last_checkpoint 230 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 231 | metrics = train_result.metrics 232 | metrics["train_samples"] = len(raw_datasets["train"]) 233 | trainer.log_metrics("train", metrics) 234 | trainer.save_metrics("train", metrics) 235 | trainer.save_state() 236 | 237 | logger.info("*** Training complete ***") 238 | 239 | ################################## 240 | # Save model and create model card 241 | ################################## 242 | logger.info("*** Save model ***") 243 | trainer.save_model(training_args.output_dir) 244 | logger.info(f"Model saved to {training_args.output_dir}") 245 | 246 | # Save everything else on main process 247 | kwargs = { 248 | "finetuned_from": model_args.model_name_or_path, 249 | "dataset": [training_args.data_path], 250 | "dataset_tags": [training_args.data_path], 251 | "tags": 
["alignment-handbook"], 252 | } 253 | if trainer.accelerator.is_main_process: 254 | trainer.create_model_card(**kwargs) 255 | # Restore k,v cache for fast inference 256 | trainer.model.config.use_cache = True 257 | trainer.model.config.save_pretrained(training_args.output_dir) 258 | 259 | ########## 260 | # Evaluate 261 | ########## 262 | if training_args.do_eval: 263 | logger.info("*** Evaluate ***") 264 | metrics = trainer.evaluate() 265 | metrics["eval_samples"] = len(raw_datasets["test"]) 266 | trainer.log_metrics("eval", metrics) 267 | trainer.save_metrics("eval", metrics) 268 | 269 | if training_args.push_to_hub is True: 270 | logger.info("Pushing to hub...") 271 | trainer.push_to_hub(**kwargs) 272 | 273 | logger.info("*** Run complete! ***") 274 | 275 | 276 | if __name__ == "__main__": 277 | main() 278 | --------------------------------------------------------------------------------
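
As a quick illustration of how `apply_step_wise_chat_template` in `train.py` assembles a Step-DPO training example for the `qwen2-boxed` prompt, here is a minimal, dependency-free sketch. The template string is copied from `train.py`; the function name `build_step_dpo_example` and the math record below are invented example data, not part of the repository or the Math-Step-DPO-10K dataset.

```python
# Standalone sketch of the Step-DPO prompt construction for the "qwen2-boxed"
# template: the prompt is the chat-templated question plus the reasoning steps
# shared by both continuations, so "chosen" and "rejected" differ only at the
# first divergent step -- the core idea of Step-DPO.

QWEN2_BOXED_TEMPLATE = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\n{instruction}\nPlease reason step by step, and put your "
    "final answer within \\boxed{{}}.<|im_end|>\n"
    "<|im_start|>assistant\n"
)


def build_step_dpo_example(example: dict) -> dict:
    """Turn one raw preference record into a (prompt, chosen, rejected) triple."""
    prompt = QWEN2_BOXED_TEMPLATE.format(instruction=example["prompt"])
    # Append the initial (correct) reasoning steps shared by both answers, if any.
    if example["initial_reason_steps"]:
        prompt += example["initial_reason_steps"]
    return {
        "prompt": prompt,
        "chosen": example["chosen"],
        "rejected": example["rejected"],
    }


# Invented record for illustration only.
record = {
    "prompt": "What is 3 + 4 * 2?",
    "initial_reason_steps": "Step 1: Do the multiplication first: 4 * 2 = 8.\n",
    "chosen": "Step 2: Add 3 to get 3 + 8 = \\boxed{11}.",
    "rejected": "Step 2: Add 3 to get 3 + 8 = \\boxed{14}.",
}

formatted = build_step_dpo_example(record)
print(formatted["prompt"])
```

Note that `{{}}` in the template renders as a literal `\boxed{}` after `str.format`, and that the prompt ends mid-solution, so the model is optimized on the single step that follows rather than on the whole chain.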