├── .gitignore
├── README.md
├── accelerate_configs
│   ├── deepspeed_zero3.yaml
│   └── deepspeed_zero3_cpu.yaml
├── app.py
├── bash_scrips
│   ├── qwen2_72b_instruct_step_dpo.sh
│   ├── qwen2_72b_step_dpo.sh
│   └── qwen2_7b_step_dpo.sh
├── configs
│   └── config_full.yaml
├── data
│   └── test
│       ├── GSM8K_test_data.jsonl
│       └── MATH_test_data.jsonl
├── data_pipeline
│   ├── generate_dataset.py
│   ├── locate_error_by_gpt4.py
│   ├── merge.sh
│   ├── predictions
│   │   └── sample.json
│   ├── prepare_for_correction.py
│   ├── step1.sh
│   ├── step2.sh
│   └── step3.sh
├── eval_math.py
├── eval_results
│   ├── gsm8k
│   │   └── sample.json
│   └── math
│       ├── qwen2-7b-dpo-v3-continue-from-incorrect-fix-part1-filtered+2-filtered+aqua-filtered-rej-original_acc0.6-topk1-beta0.5-8ep-fixbug-fixeos-bf16-keywords-fix.json
│       └── sample.json
├── evaluation
│   ├── data_processing
│   │   ├── answer_extraction.py
│   │   └── process_utils.py
│   └── eval
│       ├── eval_script.py
│       ├── eval_utils.py
│       ├── ocwcourses_eval_utils.py
│       ├── python_executor.py
│       └── utils.py
├── imgs
│   ├── .DS_Store
│   ├── coreidea.png
│   ├── example1.png
│   ├── example2.png
│   ├── example3.png
│   ├── example4.png
│   ├── example5.jpg
│   ├── summary.jpg
│   └── triangle.png
├── licenses
│   ├── DATA_LICENSE
│   ├── LICENSE
│   └── WEIGHT_LICENSE
├── paper
│   └── paper.pdf
├── requirements.txt
├── stepdpo_trainer.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__
2 | wandb/
3 | outputs/
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | 
3 | # Step-DPO: Step-wise Preference Optimization for Long-chain Reasoning of LLMs
4 | [Xin Lai](https://scholar.google.com/citations?user=tqNDPA4AAAAJ&hl),
5 | [Zhuotao Tian](https://scholar.google.com/citations?user=mEjhz-IAAAAJ&hl),
6 | [Yukang Chen](https://scholar.google.com/citations?user=6p0ygKUAAAAJ&hl),
7 | [Senqiao Yang](https://scholar.google.com/citations?user=NcJc-RwAAAAJ&hl),
8 | [Xiangru Peng](xxxx),
9 | [Jiaya Jia](https://scholar.google.com/citations?user=XPAkzTEAAAAJ&hl=en)
10 |
11 |
12 | [](https://huggingface.co/collections/xinlai/step-dpo-6682e12dfbbb2917c8161df7)
13 | [](https://huggingface.co/datasets/xinlai/Math-Step-DPO-10K)
14 | [](https://arxiv.org/pdf/2406.18629)
15 | [](http://103.170.5.190:7870/)
16 |
17 | [](licenses/LICENSE)
18 | [](licenses/DATA_LICENSE)
19 | [](licenses/WEIGHT_LICENSE)
20 |
21 | This repo provides the implementation of **Step-DPO**, a simple, effective, and data-efficient method for boosting the long-chain reasoning ability of LLMs, with **a data construction pipeline** that yields a **high-quality dataset** containing 10K step-wise preference pairs.
22 |
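Conceptually, Step-DPO optimizes the standard DPO objective, with the chosen/rejected pair being a single corrected vs. erroneous reasoning step (conditioned on the shared problem and solution prefix) rather than a full response. Below is a minimal numeric sketch of that loss, for illustration only; the actual implementation lives in `stepdpo_trainer.py`.

```python
import math

def step_dpo_loss(policy_chosen_logp, policy_rejected_logp,
                  ref_chosen_logp, ref_rejected_logp, beta=0.4):
    """DPO loss on the summed token log-probs of the chosen (corrected)
    and rejected (erroneous) reasoning step, under the trained policy
    and the frozen reference model."""
    logits = beta * ((policy_chosen_logp - ref_chosen_logp)
                     - (policy_rejected_logp - ref_rejected_logp))
    return -math.log(1.0 / (1.0 + math.exp(-logits)))  # -log(sigmoid(logits))

# The loss equals log(2) when policy and reference agree, and shrinks as the
# policy prefers the corrected step more strongly than the reference does.
print(step_dpo_loss(-5.0, -9.0, -6.0, -6.0) < math.log(2.0))  # True
print(step_dpo_loss(-9.0, -5.0, -6.0, -6.0) > math.log(2.0))  # True
```
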
23 | Notably, **Step-DPO** boosts the performance of **Qwen2-7B-Instruct** from **53.0%** to **58.6%** on MATH, and **85.5%** to **87.9%** on GSM8K, with as few as **10K data** and **hundreds of training steps**!
24 |
25 | Moreover, **Step-DPO**, when applied to **Qwen2-72B-Instruct**, achieves scores of **70.8%** and **94.0%** on the test sets of **MATH** and **GSM8K**, respectively, **surpassing a series of closed-source models** without bells and whistles, including GPT-4-1106, Claude-3-Opus, and Gemini-1.5-Pro.
26 |
27 | 
28 |
29 | ## Table of Contents
30 | 1. [News](#news)
31 | 2. [Datasets](#datasets)
32 | 3. [Models](#models)
33 | 4. [Installation](#installation)
34 | 5. [Training](#training)
35 | 6. [Evaluation](#evaluation)
36 | 7. [Data Construction Pipeline](#data-construction-pipeline)
37 | 8. [Deployment](#deployment)
38 | 9. [Examples](#examples)
39 | 10. [Acknowledgement](#acknowledgement)
40 | 11. [Citation](#citation)
41 |
42 | ## News
43 | - [x] [2024.7.7] We release the scripts for the [Data Construction Pipeline](#data-construction-pipeline)! You can construct the dataset on your own with these scripts!
44 | - [x] [2024.7.1] We release the demo of the model [Qwen2-7B-Instruct-Step-DPO](https://huggingface.co/xinlai/Qwen2-7B-Instruct-Step-DPO). Welcome to try it on [Demo](http://103.170.5.190:7870/)!
45 | - [x] [2024.6.28] We release the pre-print of [Step-DPO](https://arxiv.org/pdf/2406.18629) and this GitHub repo, including training/evaluation scripts, pre-trained models and data.
46 |
47 | ## Datasets
48 |
49 | We build a 10K-pair math preference dataset for Step-DPO, which can be downloaded from the link below.
50 |
51 | | Dataset | Size | Link |
52 | | ------------------------ | ------ | ------------------------------------------------------------ |
53 | | xinlai/Math-Step-DPO-10K | 10,795 | 🤗 [Hugging Face](https://huggingface.co/datasets/xinlai/Math-Step-DPO-10K) |
54 |
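Each record in the dataset pairs a shared problem and solution prefix with a chosen (corrected) and a rejected (erroneous) next step. A sketch of the record shape, with field names mirroring `data_pipeline/generate_dataset.py` and entirely made-up values:

```python
# Hypothetical record for illustration; real records come from the
# Math-Step-DPO-10K dataset on Hugging Face.
sample = {
    "dataset": "MATH",                      # source benchmark of the problem
    "prompt": "What is $1+2+\\cdots+10$?",  # problem statement
    "prefix": "Let's think step by step.\nStep 1: Use the formula $n(n+1)/2$.",
    "chosen": "Step 2: With n = 10, the sum is 10 * 11 / 2 = 55.",  # corrected step
    "rejected": "Step 2: With n = 10, the sum is 10 * 11 = 110.",   # erroneous step
    "answer": "55",
}

# Chosen and rejected differ only in the step being compared; everything
# before it (prompt + prefix) is shared between the two.
required = {"dataset", "prompt", "prefix", "chosen", "rejected", "answer"}
print(required.issubset(sample))  # True
```
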
55 | ## Models
56 |
57 | Notably, **Qwen2-72B-Instruct + Step-DPO** achieves **70.8%** and **94.0%** on the MATH and GSM8K test sets, respectively. Step-DPO also brings considerable improvements to a range of other models, as shown below. Feel free to download and use them.
58 |
59 | | Models | Size | MATH | GSM8K | Odyssey-MATH | Link |
60 | | :------------------------------ | :--: | :----: | :---: | :---: | :----------------------------------------------------------: |
61 | | Qwen2-7B-Instruct | 7B | 53.0 | 85.5 | - | - |
62 | | **Qwen2-7B-Instruct + Step-DPO** | 7B | **58.6 (+5.6)** | **87.9 (+2.4)** | - | 🤗 [HF](https://huggingface.co/xinlai/Qwen2-7B-Instruct-Step-DPO) |
63 | | DeepSeekMath-RL | 7B | 51.7 | 88.2 | - | - |
64 | | **DeepSeekMath-RL + Step-DPO** | 7B | **53.2 (+1.5)** | **88.7 (+0.5)** | - | 🤗 [HF](https://huggingface.co/xinlai/DeepSeekMath-RL-Step-DPO) |
65 | | Qwen2-7B-SFT | 7B | 54.8 | 88.2 | - | 🤗 [HF](https://huggingface.co/xinlai/Qwen2-7B-SFT) |
66 | | **Qwen2-7B-SFT + Step-DPO** | 7B | **55.8 (+1.0)** | **88.5 (+0.3)** | - |🤗 [HF](https://huggingface.co/xinlai/Qwen2-7B-SFT-Step-DPO) |
67 | | Qwen1.5-32B-SFT | 32B | 54.9 | 90.0 | - | 🤗 [HF](https://huggingface.co/xinlai/Qwen1.5-32B-SFT) |
68 | | **Qwen1.5-32B-SFT + Step-DPO** | 32B | **56.9 (+2.0)** | **90.9 (+0.9)** | - |🤗 [HF](https://huggingface.co/xinlai/Qwen1.5-32B-SFT-Step-DPO) |
69 | | Qwen2-57B-A14B-SFT | 57B | 54.6 | 89.8 | - | 🤗 [HF](https://huggingface.co/xinlai/Qwen2-57B-A14B-SFT) |
70 | | **Qwen2-57B-A14B-SFT + Step-DPO** | 57B | **56.5 (+1.9)** | **90.0 (+0.2)** | - |🤗 [HF](https://huggingface.co/xinlai/Qwen2-57B-A14B-SFT-Step-DPO) |
71 | | Llama-3-70B-SFT | 70B | 56.9 | 92.2 | - | 🤗 [HF](https://huggingface.co/xinlai/Llama-3-70B-SFT) |
72 | | **Llama-3-70B-SFT + Step-DPO** | 70B | **59.5 (+2.6)** | **93.3 (+1.1)** | - |🤗 [HF](https://huggingface.co/xinlai/Llama-3-70B-SFT-Step-DPO) |
73 | | Qwen2-72B-SFT | 72B | 61.7 | 92.9 | 44.2 | 🤗 [HF](https://huggingface.co/xinlai/Qwen2-72B-SFT) |
74 | | **Qwen2-72B-SFT + Step-DPO** | 72B | **64.7 (+3.0)** | **93.9 (+1.0)** | **47.0 (+2.8)** | 🤗 [HF](https://huggingface.co/xinlai/Qwen2-72B-SFT-Step-DPO) |
75 | | Qwen2-72B-Instruct | 72B | 69.4 | 92.4 | 47.0 | - |
76 | | **Qwen2-72B-Instruct + Step-DPO** | 72B | **70.8 (+1.4)** | **94.0 (+1.6)** | **50.1 (+3.1)** | 🤗 [HF](https://huggingface.co/xinlai/Qwen2-72B-Instruct-Step-DPO) |
77 |
78 | Note: **Odyssey-MATH** contains competition-level math problems.
79 |
80 | ## Installation
81 | ```shell
82 | conda create -n step_dpo python=3.10
83 | conda activate step_dpo
84 |
85 | pip install -r requirements.txt
86 | ```
87 |
88 | ## Training
89 |
90 | ### Pre-trained weights
91 | We use Qwen2, Qwen1.5, Llama-3, and DeepSeekMath models as the pre-trained weights and fine-tune them with Step-DPO. Download whichever you need.
92 |
93 | | Pre-trained weights |
94 | |:---------------------------------------------------------------------------|
95 | | [Qwen/Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct) |
96 | | [deepseek-ai/deepseek-math-7b-rl](https://huggingface.co/deepseek-ai/deepseek-math-7b-rl) |
97 | | [xinlai/Qwen2-7B-SFT](https://huggingface.co/xinlai/Qwen2-7B-SFT) |
98 | | [xinlai/Qwen1.5-32B-SFT](https://huggingface.co/xinlai/Qwen1.5-32B-SFT) |
99 | | [xinlai/Qwen2-57B-A14B-SFT](https://huggingface.co/xinlai/Qwen2-57B-A14B-SFT) |
100 | | [xinlai/Llama-3-70B-SFT](https://huggingface.co/xinlai/Llama-3-70B-SFT) |
101 | | [xinlai/Qwen2-72B-SFT](https://huggingface.co/xinlai/Qwen2-72B-SFT) |
102 | | [Qwen/Qwen2-72B-Instruct](https://huggingface.co/Qwen/Qwen2-72B-Instruct) |
103 |
104 | **Note**: models with '-SFT' are supervised fine-tuned on our 299K SFT data, starting from open-source base models. You can perform Step-DPO on either our SFT models or existing open-source instruct models.
105 |
106 | Here is a script example to perform Step-DPO on `Qwen/Qwen2-72B-Instruct`:
107 |
108 | ```shell
109 | ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3_cpu.yaml --mixed_precision bf16 \
110 | --num_processes 8 \
111 | train.py configs/config_full.yaml \
112 | --model_name_or_path="Qwen/Qwen2-72B-Instruct" \
113 | --data_path="xinlai/Math-Step-DPO-10K" \
114 | --per_device_train_batch_size=2 \
115 | --gradient_accumulation_steps=8 \
116 | --torch_dtype=bfloat16 \
117 | --bf16=True \
118 | --beta=0.4 \
119 | --num_train_epochs=4 \
120 | --save_strategy='steps' \
121 | --save_steps=200 \
122 | --save_total_limit=1 \
123 | --output_dir=outputs/qwen2-72b-instruct-step-dpo \
124 | --hub_model_id=qwen2-72b-instruct-step-dpo \
125 | --prompt=qwen2-boxed
126 | ```
127 |
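As a sanity check on the hyperparameters above: the effective batch size is `num_processes x per_device_train_batch_size x gradient_accumulation_steps`, and with the 10,795-pair dataset the resulting optimizer step count is indeed in the "hundreds of training steps" range:

```python
import math

num_processes = 8                 # from --num_processes
per_device_train_batch_size = 2   # from --per_device_train_batch_size
gradient_accumulation_steps = 8   # from --gradient_accumulation_steps
num_train_epochs = 4              # from --num_train_epochs
dataset_size = 10795              # pairs in Math-Step-DPO-10K

effective_batch = (num_processes * per_device_train_batch_size
                   * gradient_accumulation_steps)
total_steps = math.ceil(dataset_size / effective_batch) * num_train_epochs
print(effective_batch, total_steps)  # 128 340
```
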
128 | ## Evaluation
129 |
130 | Here are script examples to evaluate fine-tuned models on both GSM8K and MATH test sets:
131 | ```shell
132 | python eval_math.py \
133 | --model outputs/qwen2-72b-instruct-step-dpo \
134 | --data_file ./data/test/GSM8K_test_data.jsonl \
135 | --save_path 'eval_results/gsm8k/qwen2-72b-instruct-step-dpo.json' \
136 | --prompt 'qwen2-boxed' \
137 | --tensor_parallel_size 8
138 | ```
139 |
140 | ```shell
141 | python eval_math.py \
142 | --model outputs/qwen2-72b-instruct-step-dpo \
143 | --data_file ./data/test/MATH_test_data.jsonl \
144 | --save_path 'eval_results/math/qwen2-72b-instruct-step-dpo.json' \
145 | --prompt 'qwen2-boxed' \
146 | --tensor_parallel_size 8
147 | ```
148 |
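Both commands score a response by extracting the model's final `\boxed{...}` answer and comparing it against the reference. The repo's full extraction logic lives in `evaluation/data_processing/answer_extraction.py`; the sketch below is a simplified, brace-balanced extractor for illustration only:

```python
def extract_boxed(text):
    """Return the contents of the last \\boxed{...} in `text`, or None."""
    start = text.rfind("\\boxed{")
    if start == -1:
        return None
    depth = 0
    # scan from the opening brace, tracking nesting depth
    for i in range(start + len("\\boxed"), len(text)):
        if text[i] == "{":
            depth += 1
        elif text[i] == "}":
            depth -= 1
            if depth == 0:
                return text[start + len("\\boxed{"):i]
    return None  # unbalanced braces

print(extract_boxed("The sum is $\\boxed{55}$."))             # 55
print(extract_boxed("So $\\boxed{\\frac{1}{2}}$ is final."))  # \frac{1}{2}
```
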
149 | ## Data Construction Pipeline
150 |
151 | We release the scripts to construct the Step-DPO data, as shown in the `data_pipeline/` directory. Please follow the instructions below.
152 |
153 | ```shell
154 | cd Step-DPO
155 |
156 | # Step 1: Error Collection
157 | # Before executing, please set the MODEL_PATH, PRED_PATH, EVAL_PROMPT
158 | bash data_pipeline/step1.sh
159 |
160 | # Step 2: Locate Erroneous Step by GPT-4o
161 | # Before executing, please set the OPENAI_BASE_URL, OPENAI_API_KEY
162 | bash data_pipeline/step2.sh
163 |
164 | # Step 3: Rectify by the model itself
165 | # Before executing, please set the MODEL_PATH, EVAL_PROMPT, JSON_FILE, PRED_PATH, SAVE_PATH
166 | bash data_pipeline/step3.sh
167 |
168 | # Finally, get the resulting dataset
169 | # Before executing, please set the EVAL_PROMPT, JSON_FILE, PRED_PATH, SAVE_PATH
170 | bash data_pipeline/merge.sh
171 | ```
172 |
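Conceptually, the final merge pairs each erroneous step with its model-generated correction and keeps only the first step after the shared prefix as the chosen/rejected pair, mirroring the `split("\nStep ")` logic in `data_pipeline/generate_dataset.py`. A minimal sketch with made-up step text:

```python
def first_step(completion):
    """Keep only the text up to the next 'Step N:' marker."""
    return completion.split("\nStep ")[0]

# Hypothetical continuations from the same erroneous position (Step 3)
chosen_full = "Step 3: Substitute x = 2, giving 7.\nStep 4: Therefore ..."
rejected_full = "Step 3: Substitute x = 2, giving 9.\nStep 4: Hence ..."

pair = {"chosen": first_step(chosen_full), "rejected": first_step(rejected_full)}
print(pair["chosen"])  # Step 3: Substitute x = 2, giving 7.
```
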
173 | ## Deployment
174 |
175 | For deployment, please directly use the following command:
176 | ```shell
177 | python3 app.py --model_path_or_name xinlai/Qwen2-7B-Instruct-Step-DPO
178 | ```
179 |
180 |
181 | ## Examples
182 |
183 | 
184 |
185 | 
186 |
187 | 
188 |
189 | 
190 |
191 | ## Acknowledgement
192 |
193 | This repository is based on [alignment-handbook](https://github.com/huggingface/alignment-handbook), [DeepSeekMath](https://github.com/deepseek-ai/DeepSeek-Math), and [MetaMath](https://github.com/meta-math/MetaMath).
194 |
195 | Many thanks for their efforts!
196 |
197 | ## Citation
198 | If you find this project useful in your research, please consider citing us:
199 |
200 | ```bibtex
201 | @article{lai2024stepdpo,
202 |   title={Step-DPO: Step-wise Preference Optimization for Long-chain Reasoning of LLMs},
203 |   author={Xin Lai and Zhuotao Tian and Yukang Chen and Senqiao Yang and Xiangru Peng and Jiaya Jia},
204 |   journal={arXiv preprint arXiv:2406.18629},
205 |   year={2024}
206 | }
207 | ```
208 |
--------------------------------------------------------------------------------
/accelerate_configs/deepspeed_zero3.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 |   deepspeed_multinode_launcher: standard
5 |   offload_optimizer_device: none
6 |   offload_param_device: none
7 |   zero3_init_flag: true
8 |   zero3_save_16bit_model: true
9 |   zero_stage: 3
10 | distributed_type: DEEPSPEED
11 | downcast_bf16: 'no'
12 | machine_rank: 0
13 | main_training_function: main
14 | mixed_precision: 'no' #bf16
15 | num_machines: 1
16 | num_processes: 8
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_env: []
20 | tpu_use_cluster: false
21 | tpu_use_sudo: false
22 | use_cpu: false
23 |
--------------------------------------------------------------------------------
/accelerate_configs/deepspeed_zero3_cpu.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 |   deepspeed_multinode_launcher: standard
5 |   offload_optimizer_device: cpu
6 |   offload_param_device: cpu
7 |   zero3_init_flag: true
8 |   zero3_save_16bit_model: true
9 |   zero_stage: 3
10 | distributed_type: DEEPSPEED
11 | downcast_bf16: 'no'
12 | machine_rank: 0
13 | main_training_function: main
14 | mixed_precision: 'no' #bf16
15 | num_machines: 1
16 | num_processes: 8
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_env: []
20 | tpu_use_cluster: false
21 | tpu_use_sudo: false
22 | use_cpu: false
23 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 |
2 | import argparse
3 | import os
4 | import sys
5 |
6 | import bleach
7 | import gradio as gr
8 | import torch
9 | import transformers
10 |
11 |
12 | def parse_args(args):
13 |     parser = argparse.ArgumentParser(description='LISA chat')
14 |     parser.add_argument('--model_path_or_name', default='')
15 |     parser.add_argument('--save_path', default='/data/step_dpo_history')
16 |     return parser.parse_args(args)
17 |
18 | args = parse_args(sys.argv[1:])
19 | os.makedirs(args.save_path, exist_ok=True)
20 |
21 | # Create model
22 | tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_path_or_name)
23 | model = transformers.AutoModelForCausalLM.from_pretrained(args.model_path_or_name, torch_dtype=torch.bfloat16, device_map="auto")
24 |
25 | # Gradio
26 | examples = [
27 |     ['Suppose that $h(x)=f^{-1}(x)$. If $h(2)=10$, $h(10)=1$ and $h(1)=2$, what is $f(f(10))$?'],
28 | ]
29 | output_labels = ['Output']
30 |
31 | title = 'Step-DPO: Step-wise Preference Optimization for Long-chain Reasoning of LLMs'
32 |
33 | description = """
34 |
35 |
36 | This is the online demo of **Qwen2-7B-Instruct-Step-DPO**. \n
37 |
38 | It is obtained by performing **Step-DPO** on **Qwen2-7B-Instruct**, with as few as **10K data and hundreds of training steps**. \n
39 |
40 | **Step-DPO** improves the mathematical reasoning of **Qwen2-7B-Instruct** significantly, from **53.0\%** to **58.6\%** on MATH, and **85.5\%** to **87.9\%** on GSM8K. \n
41 | Besides, **Qwen2-72B-Instruct-Step-DPO** achieves **70.8\%** on MATH and **94.0\%** on GSM8K, **outperforming GPT-4-1106, Gemini-1.5-Pro, and Claude-3-Opus**.
42 |
43 | Code, models, data are available at [GitHub](https://github.com/dvlab-research/Step-DPO).
44 |
45 | Hope you can enjoy our work!
46 |
47 | """
48 |
49 | article = """
50 |
51 |
52 | Preprint Paper
53 |
54 | \n
55 |
56 | Github Repo
57 | """
58 |
59 |
60 | def inference(input_str):
61 |
62 |     ## filter out special chars
63 |     input_str = bleach.clean(input_str)
64 |
65 |     print("input_str: ", input_str)
66 |
67 |     prompt = input_str + "\nPlease reason step by step, and put your final answer within \\boxed{}."
68 |
69 |     messages = [
70 |         {"role": "user", "content": prompt}
71 |     ]
72 |
73 |     text = tokenizer.apply_chat_template(
74 |         messages,
75 |         tokenize=False,
76 |         add_generation_prompt=True
77 |     )
78 |
79 |     model_inputs = tokenizer([text], return_tensors="pt").to('cuda')
80 |
81 |     generated_ids = model.generate(
82 |         model_inputs.input_ids,
83 |         max_new_tokens=1024
84 |     )
85 |     generated_ids = [
86 |         output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
87 |     ]
88 |     text_output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
89 |
90 |     return text_output
91 |
92 |
93 | demo = gr.Interface(
94 |     inference,
95 |     inputs=[
96 |         gr.Textbox(
97 |             lines=1, placeholder=None, label='Math Problem'),
98 |     ],
99 |     outputs=[
100 |         gr.Textbox(
101 |             lines=1, placeholder=None, label='Text Output'),
102 |     ],
103 |     title=title,
104 |     description=description,
105 |     article=article,
106 |     examples=examples,
107 |     allow_flagging='auto',
108 |     flagging_dir=args.save_path)
109 |
110 | demo.queue()
111 |
112 | demo.launch(server_name='0.0.0.0', show_error=True)
113 |
--------------------------------------------------------------------------------
/bash_scrips/qwen2_72b_instruct_step_dpo.sh:
--------------------------------------------------------------------------------
1 | export output_dir="qwen2-72b-instruct-step-dpo"
2 | export prompt="qwen2-boxed"
3 |
4 | ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3_cpu.yaml --mixed_precision bf16 \
5 | --num_processes 8 \
6 | train.py configs/config_full.yaml \
7 | --model_name_or_path="Qwen/Qwen2-72B-Instruct" \
8 | --data_path="xinlai/Math-Step-DPO-10K" \
9 | --per_device_train_batch_size=2 \
10 | --gradient_accumulation_steps=8 \
11 | --torch_dtype=bfloat16 \
12 | --bf16=True \
13 | --beta=0.4 \
14 | --num_train_epochs=4 \
15 | --save_strategy='steps' \
16 | --save_steps=200 \
17 | --save_total_limit=1 \
18 | --output_dir=outputs/$output_dir \
19 | --hub_model_id=$output_dir \
20 | --prompt=$prompt
21 |
22 | python eval_math.py --model outputs/$output_dir --data_file ./data/test/GSM8K_test_data.jsonl --save_path 'eval_results/gsm8k/'$output_dir'.json' --prompt $prompt --tensor_parallel_size 4
23 |
24 | python eval_math.py --model outputs/$output_dir --data_file ./data/test/MATH_test_data.jsonl --save_path 'eval_results/math/'$output_dir'.json' --prompt $prompt --tensor_parallel_size 4
25 |
--------------------------------------------------------------------------------
/bash_scrips/qwen2_72b_step_dpo.sh:
--------------------------------------------------------------------------------
1 | export output_dir="qwen2-72b-step-dpo"
2 | export prompt="alpaca"
3 |
4 | ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3_cpu.yaml --mixed_precision bf16 \
5 | --num_processes 8 \
6 | train.py configs/config_full.yaml \
7 | --model_name_or_path="xinlai/Qwen2-72B-SFT" \
8 | --data_path="xinlai/Math-Step-DPO-10K" \
9 | --per_device_train_batch_size=2 \
10 | --gradient_accumulation_steps=8 \
11 | --torch_dtype=bfloat16 \
12 | --bf16=True \
13 | --beta=0.4 \
14 | --num_train_epochs=4 \
15 | --save_strategy='steps' \
16 | --save_steps=200 \
17 | --save_total_limit=1 \
18 | --output_dir=outputs/$output_dir \
19 | --hub_model_id=$output_dir \
20 | --prompt=$prompt
21 |
22 | python eval_math.py --model outputs/$output_dir --data_file ./data/test/GSM8K_test_data.jsonl --save_path 'eval_results/gsm8k/'$output_dir'.json' --prompt $prompt --tensor_parallel_size 4
23 |
24 | python eval_math.py --model outputs/$output_dir --data_file ./data/test/MATH_test_data.jsonl --save_path 'eval_results/math/'$output_dir'.json' --prompt $prompt --tensor_parallel_size 4
25 |
--------------------------------------------------------------------------------
/bash_scrips/qwen2_7b_step_dpo.sh:
--------------------------------------------------------------------------------
1 | export output_dir="qwen2-7b-step-dpo"
2 | export prompt="alpaca"
3 |
4 | ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml --mixed_precision bf16 \
5 | --num_processes 8 \
6 | train.py configs/config_full.yaml \
7 | --model_name_or_path="xinlai/Qwen2-7B-SFT" \
8 | --data_path="xinlai/Math-Step-DPO-10K" \
9 | --per_device_train_batch_size=4 \
10 | --gradient_accumulation_steps=4 \
11 | --torch_dtype=bfloat16 \
12 | --bf16=True \
13 | --beta=0.5 \
14 | --num_train_epochs=8 \
15 | --save_strategy='steps' \
16 | --save_steps=400 \
17 | --save_total_limit=1 \
18 | --output_dir=outputs/$output_dir \
19 | --hub_model_id=$output_dir \
20 | --prompt=$prompt
21 |
22 | python eval_math.py --model outputs/$output_dir --data_file ./data/test/GSM8K_test_data.jsonl --save_path 'eval_results/gsm8k/'$output_dir'.json' --prompt $prompt --tensor_parallel_size 4
23 |
24 | python eval_math.py --model outputs/$output_dir --data_file ./data/test/MATH_test_data.jsonl --save_path 'eval_results/math/'$output_dir'.json' --prompt $prompt --tensor_parallel_size 4
25 |
--------------------------------------------------------------------------------
/configs/config_full.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path:
3 | torch_dtype: bfloat16
4 |
5 | # Data training arguments
6 | # For definitions, see: src/h4/training/config.py
7 | data_path:
8 | dataset_splits:
9 | - train
10 | preprocessing_num_workers: 12
11 |
12 | # DPOTrainer arguments
13 | bf16: True
14 | beta: 0.05
15 | do_eval: False
16 | evaluation_strategy: 'no'
17 | eval_steps: 100
18 | gradient_accumulation_steps: 16
19 | gradient_checkpointing: true
20 | gradient_checkpointing_kwargs:
21 |   use_reentrant: False
22 | hub_model_id: step-dpo
23 | learning_rate: 5.0e-7
24 | log_level: info
25 | logging_steps: 1
26 | lr_scheduler_type: cosine
27 | max_length: 1024
28 | max_prompt_length: 512
29 | num_train_epochs: 2
30 | optim: adamw_torch
31 | output_dir: data/step-dpo
32 | per_device_train_batch_size: 1
33 | per_device_eval_batch_size: 4
34 | push_to_hub: false
35 | report_to:
36 | - tensorboard
37 | - wandb
38 | save_strategy: "no"
39 | seed: 42
40 | warmup_ratio: 0.1
41 |
--------------------------------------------------------------------------------
/data_pipeline/generate_dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import glob
4 | import jsonlines
5 |
6 | def main(args):
7 |     save_path = args.save_path
8 |     json_files = sorted(glob.glob(args.corrected_files))
9 |     identifier2items = {}
10 |     for json_file in json_files:
11 |         with open(json_file) as f:
12 |             for item in json.load(f):
13 |                 if item['result']:
14 |                     if 'alpaca' in args.prompt:
15 |                         prompt = item['prompt'].split("### Instruction:")[1].split("### Response:")[0].strip()
16 |                         prefix = item['prompt'].split("### Response:")[-1].lstrip()
17 |                     elif 'qwen2-boxed' in args.prompt:
18 |                         prompt = item['prompt'].split("<|im_start|>user\n")[1].split("\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>")[0].strip()
19 |                         prefix = item['prompt'].split("<|im_start|>assistant\n")[-1].lstrip()
20 |                     else:
21 |                         raise NotImplementedError("Prompt {} is not supported currently".format(args.prompt))
22 |
23 |                     prefix = prefix.replace("Let's think step by step.\n", "")
24 |                     identifier = prompt + "||" + prefix
25 |                     if identifier not in identifier2items:
26 |                         identifier2items[identifier] = []
27 |                     identifier2items[identifier].append(item)
28 |
29 |     new_items = []
30 |     invalid_cnt = 0
31 |     cnt = 0
32 |     with jsonlines.open(args.json_file, "r") as f:
33 |         for line in f:
34 |             prompt = line['instruction']
35 |             prefix = line['prefix']
36 |             identifier = prompt + "||" + prefix
37 |
38 |             if identifier not in identifier2items:
39 |                 invalid_cnt += 1
40 |                 continue
41 |             items = identifier2items[identifier]
42 |             visited_chosen = set()
43 |             for item in items:
44 |                 cnt += 1
45 |                 chosen = item['completion']
46 |                 rejected = line['output']
47 |
48 |                 chosen_first_step = chosen.split("\nStep ")[0]
49 |                 rejected_first_step = rejected.split("\nStep ")[0]
50 |
51 |                 if chosen_first_step in visited_chosen:
52 |                     continue
53 |
54 |                 visited_chosen.add(chosen_first_step)
55 |
56 |                 new_item = {
57 |                     'dataset': line['type'],
58 |                     'prompt': prompt,
59 |                     'prefix': "Let's think step by step.\n" + prefix,
60 |                     'chosen': chosen_first_step,
61 |                     'rejected': rejected_first_step,
62 |                     'original_chosen': chosen,
63 |                     'answer': line['answer'],
64 |                 }
65 |                 new_items.append(new_item)
66 |
67 |     print("len(new_items): {}, invalid_cnt: {}, cnt: {}".format(len(new_items), invalid_cnt, cnt))
68 |     with open(save_path, "w+") as f:
69 |         json.dump(new_items, f, indent=4)
70 |
71 | def parse_args():
72 |     parser = argparse.ArgumentParser()
73 |     parser.add_argument("--prompt", type=str, default='qwen2-boxed-step')
74 |     parser.add_argument("--save_path", type=str, default='./data_pipeline/data.json')
75 |     parser.add_argument("--json_file", type=str, default="./data_pipeline/continue_from_incorrect_step.jsonl")
76 |     parser.add_argument("--corrected_files", type=str, default="./data_pipeline/corrections/qwen2-7b-instruct-correction*.json")
77 |     return parser.parse_args()
78 |
79 | if __name__ == "__main__":
80 |     args = parse_args()
81 |     main(args)
82 |
--------------------------------------------------------------------------------
/data_pipeline/locate_error_by_gpt4.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import json
4 | import os
5 | import time
6 |
7 | import openai
8 | import tqdm
9 |
10 | client = openai.OpenAI(
11 |     base_url=os.getenv("OPENAI_BASE_URL"),
12 |     api_key=os.getenv("OPENAI_API_KEY"),
13 | )
14 |
15 | prompt = '''### Problem:
16 | {problem}
17 |
18 | ### Correct solution:
19 | {solution}
20 |
21 | ### Incorrect answer:
22 | {answer}
23 |
24 | ---
25 |
26 | A math problem and its correct solution are listed above. We also give another incorrect answer, where step-by-step reasoning process is shown. Please output the correctness for each reasoning step in the given answer.
27 |
28 | Requirements:
29 | 1. You should first output a step-by-step analysis process (no more than 200 words), and finally output the decision ("correct", "neutral", "incorrect") for each step following the format of "Final Decision:\nStep 1: correct; Step 2: neutral; ...";
30 | 2. Stop when you find the first incorrect step.'''
31 |
32 | def main(args):
33 |
34 |     if not os.path.exists(args.save_dir):
35 |         os.mkdir(args.save_dir)
36 |
37 |     save_dir = args.save_dir
38 |     visited_dirs = save_dir if len(args.visited_dirs) == 0 else args.visited_dirs
39 |     json_files = sorted(glob.glob(args.json_files))
40 |
41 |     pred_data = []
42 |     for json_file in json_files:
43 |         with open(json_file) as f:
44 |             for item in json.load(f):
45 |                 if not item['result']:
46 |                     pred_data.append(item)
47 |
48 |     n_groups = args.n_groups
49 |     remainder = args.remainder
50 |
51 |     print("n_groups: {}, remainder: {}".format(n_groups, remainder))
52 |     print("len(pred_data): ", len(pred_data))
53 |
54 |     cnt = 0
55 |     question2cnt = dict()
56 |     for idx, pred_dict in tqdm.tqdm(enumerate(pred_data)):
57 |
58 |         if 'alpaca' in args.prompt:
59 |             question = pred_dict['prompt'].split("### Instruction:")[1].split("### Response:")[0].strip()
60 |         elif 'qwen2-boxed' in args.prompt:
61 |             question = pred_dict['prompt'].split("<|im_start|>user\n")[1].split("\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>")[0].strip()
62 |         else:
63 |             raise NotImplementedError("Prompt {} is not supported currently".format(args.prompt))
64 |
65 |         if question in question2cnt and question2cnt[question] > args.max_count_per_question:
66 |             continue
67 |         if question not in question2cnt:
68 |             question2cnt[question] = 0
69 |         question2cnt[question] += 1
70 |
71 |         # skip the invalid questions without diagram
72 |         if "diagram" in question and 'asy' not in question:
73 |             continue
74 |
75 |         # skip other threads
76 |         if idx % n_groups != remainder:
77 |             continue
78 |
79 |         # skip the visited questions
80 |         if any([os.path.exists(os.path.join(visited_dir, "{}.json".format(idx))) for visited_dir in visited_dirs.split("||")]):
81 |             continue
82 |
83 |         completion = "Step 1: " + pred_dict['completion']
84 |         instruction = prompt.format(problem=question, solution=pred_dict['gt_output'].replace("\n\n", "\n"), answer=completion.replace("\n\n", "\n"))
85 |
86 |         # print("instruction: ", instruction)
87 |         # import pdb; pdb.set_trace()
88 |
89 |         while True:
90 |             try:
91 |                 chat_completion = client.chat.completions.create(
92 |                     messages=[
93 |                         {
94 |                             "role": "user",
95 |                             "content": instruction,
96 |                         }
97 |                     ],
98 |                     model="gpt-4o",
99 |                 )
100 |             except (openai.APIConnectionError, openai.InternalServerError) as e:
101 |                 print(str(e))
102 |                 time.sleep(3)
103 |                 continue
104 |             break
105 |
106 |         item = pred_dict.copy()
107 |         item['gpt4-output'] = chat_completion.choices[0].message.content
108 |         item['gpt4-prompt'] = instruction
109 |         save_path = os.path.join(save_dir, "{}.json".format(idx))
110 |         with open(save_path, "w+") as f:
111 |             json.dump(item, f, indent=4)
112 |         cnt += 1
113 |         print("cnt: ", cnt, "idx: ", idx)
114 |         if cnt >= args.max_count_total:
115 |             break
116 |
117 | def parse_args():
118 |     parser = argparse.ArgumentParser()
119 |     parser.add_argument("--prompt", type=str, default='qwen2-boxed-step')
120 |     parser.add_argument("--visited_dirs", type=str, default='') # will skip the files in $visited_dirs
121 |     parser.add_argument("--save_dir", type=str, default='./data_pipeline/generated')
122 |     parser.add_argument("--remainder", type=int, default=0) # remainder
123 |     parser.add_argument("--n_groups", type=int, default=1) # n_groups
124 |     parser.add_argument("--json_files", type=str, default="./data_pipeline/predictions/qwen2-7b-instruct-temp0.8-top_p0.95_rep2_seed0-alpaca-group*.json")
125 |     parser.add_argument("--max_count_per_question", type=int, default=1)
126 |     parser.add_argument("--max_count_total", type=int, default=10000)
127 |     return parser.parse_args()
128 |
129 | if __name__ == "__main__":
130 |     args = parse_args()
131 |     main(args)
132 |
--------------------------------------------------------------------------------
/data_pipeline/merge.sh:
--------------------------------------------------------------------------------
1 | export EVAL_PROMPT='qwen2-boxed-prefix'
2 | export JSON_FILE='./data_pipeline/continue_from_incorrect_step.jsonl'
3 | export PRED_PATH='./data_pipeline/corrections/qwen2-7b-instruct-correction'
4 | export SAVE_PATH='./data_pipeline/data.json'
5 |
6 | python3 data_pipeline/generate_dataset.py --prompt $EVAL_PROMPT \
7 | --save_path $SAVE_PATH \
8 | --json_file $JSON_FILE \
9 | --corrected_files $PRED_PATH"*.json"
10 |
--------------------------------------------------------------------------------
/data_pipeline/predictions/sample.json:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/Step-DPO/1f504ead5004f252025cb234017dfd9897cd542c/data_pipeline/predictions/sample.json
--------------------------------------------------------------------------------
/data_pipeline/prepare_for_correction.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import glob
4 | import jsonlines
5 | import re
6 |
7 | def main(args):
8 | save_file = args.save_file
9 | generated_files = sorted(glob.glob(args.generated_files))
10 |
11 | invalid_cnt0 = 0
12 | invalid_cnt1 = 0
13 | invalid_cnt2 = 0
14 | with jsonlines.open(save_file, "w") as f:
15 | for json_file in generated_files:
16 |
17 | with open(json_file) as ff:
18 | item = json.load(ff)
19 |
20 | correctness = item['gpt4-output'].lower()
21 | correctness = correctness.split("final decision")[-1].split("summary decision:")[-1].strip()
22 | if not any([x in correctness for x in ['neutral', 'incorrect']]):
23 | invalid_cnt0 += 1
24 | continue
25 |
26 | step_num = correctness.split("neutral")[0].split("incorrect")[0]
27 | step_num = step_num.split("\n")[-1].split(";")[-1]
28 | if step_num.count("step") > 1:
29 | invalid_cnt1 += 1
30 | continue
31 | step_num = step_num.split("step")[-1].split(":")[0]
32 | try:
33 | step_num = int(step_num.strip())
34 |                 except ValueError:
35 | # import pdb; pdb.set_trace()
36 | invalid_cnt2 += 1
37 | continue
38 |
39 | if 'alpaca' in args.prompt:
40 | prompt = item['prompt'].split("### Instruction:")[1].split("### Response:")[0].strip()
41 | prefix = item['prompt'].split("### Response:")[-1].lstrip()
42 | elif 'qwen2-boxed' in args.prompt:
43 | prompt = item['prompt'].split("<|im_start|>user\n")[1].split("\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>")[0].strip()
44 | prefix = item['prompt'].split("<|im_start|>assistant\n")[-1].lstrip()
45 | else:
46 | raise NotImplementedError("Prompt {} is not supported currently".format(args.prompt))
47 |
48 | completion = prefix + item['completion']
49 | # pred_answer = completion.split("The answer is:")[-1].strip()
50 | type = item['type']
51 |
52 | if completion.count("Step {}:".format(step_num)) == 0:
53 | continue
54 |
55 | prefix = completion.split("Step {}:".format(step_num))[0] + "Step {}:".format(step_num)
56 |
57 | new_item = {
58 | 'idx': "n/a",
59 | 'instruction': prompt,
60 | 'prefix': prefix.replace("Let's think step by step.\n", ""),
61 | 'output': completion.replace(prefix, ""),
62 | 'gt_output': item['gt_output'],
63 | 'answer': item['prompt_answer'],
64 | 'step_num': step_num,
65 | 'input': "",
66 | 'type': type,
67 | 'ori_filepath': item['path'] if 'path' in item else 'n/a',
68 | }
69 | f.write(new_item)
70 |
71 | print("invalid_cnt0: ", invalid_cnt0)
72 | print("invalid_cnt1: ", invalid_cnt1)
73 | print("invalid_cnt2: ", invalid_cnt2)
74 |
75 | def parse_args():
76 | parser = argparse.ArgumentParser()
77 | parser.add_argument("--prompt", type=str, default='qwen2-boxed-step')
78 | parser.add_argument("--save_file", type=str, default='./data_pipeline/continue_from_incorrect.jsonl')
79 | parser.add_argument("--generated_files", type=str, default="./data_pipeline/generated/*.json")
80 | return parser.parse_args()
81 |
82 | if __name__ == "__main__":
83 | args = parse_args()
84 | main(args)
85 |
--------------------------------------------------------------------------------
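The chained `split` calls in `prepare_for_correction.py` that recover the erroneous step number from the GPT-4 verdict are easy to misread. Below is a minimal, self-contained sketch of that parsing, run on a hypothetical verdict string (real outputs come from `locate_error_by_gpt4.py` and may be worded differently):

```python
def parse_step_num(gpt4_output: str):
    # mirror the filtering in prepare_for_correction.py: keep only the text
    # after the "final decision" / "summary decision:" markers
    correctness = gpt4_output.lower()
    correctness = correctness.split("final decision")[-1].split("summary decision:")[-1].strip()
    if not any(x in correctness for x in ("neutral", "incorrect")):
        return None  # no verdict found, or solution judged fully correct
    step_num = correctness.split("neutral")[0].split("incorrect")[0]
    step_num = step_num.split("\n")[-1].split(";")[-1]
    if step_num.count("step") > 1:
        return None  # ambiguous: more than one step referenced
    step_num = step_num.split("step")[-1].split(":")[0]
    try:
        return int(step_num.strip())
    except ValueError:
        return None

print(parse_step_num("Final decision:\nStep 3: incorrect"))  # -> 3
```

Items for which this parsing fails are counted in `invalid_cnt0`/`invalid_cnt1`/`invalid_cnt2` and skipped.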
/data_pipeline/step1.sh:
--------------------------------------------------------------------------------
1 | export MODEL_PATH='/dataset/pretrained-models/Qwen2-7B-Instruct'
2 | export PRED_PATH='./data_pipeline/predictions/qwen2-7b-instruct-temp0.8-top_p0.95_rep2_seed0-alpaca-group'
3 | export EVAL_PROMPT='qwen2-boxed-step'
4 |
5 | CUDA_VISIBLE_DEVICES=0 python eval_math.py --model $MODEL_PATH --remainder 0 --n_groups 8 --save_path $PRED_PATH"0.json" --data_file /dataset/industry_gpt/llm_infer/AQuA/train_qa.jsonl --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 2 --seed 0 --tensor_parallel_size 1 &
6 | CUDA_VISIBLE_DEVICES=1 python eval_math.py --model $MODEL_PATH --remainder 1 --n_groups 8 --save_path $PRED_PATH"1.json" --data_file /dataset/industry_gpt/llm_infer/AQuA/train_qa.jsonl --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 2 --seed 0 --tensor_parallel_size 1 &
7 | CUDA_VISIBLE_DEVICES=2 python eval_math.py --model $MODEL_PATH --remainder 2 --n_groups 8 --save_path $PRED_PATH"2.json" --data_file /dataset/industry_gpt/llm_infer/AQuA/train_qa.jsonl --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 2 --seed 0 --tensor_parallel_size 1 &
8 | CUDA_VISIBLE_DEVICES=3 python eval_math.py --model $MODEL_PATH --remainder 3 --n_groups 8 --save_path $PRED_PATH"3.json" --data_file /dataset/industry_gpt/llm_infer/AQuA/train_qa.jsonl --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 2 --seed 0 --tensor_parallel_size 1 &
9 | CUDA_VISIBLE_DEVICES=4 python eval_math.py --model $MODEL_PATH --remainder 4 --n_groups 8 --save_path $PRED_PATH"4.json" --data_file /dataset/industry_gpt/llm_infer/AQuA/train_qa.jsonl --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 2 --seed 0 --tensor_parallel_size 1 &
10 | CUDA_VISIBLE_DEVICES=5 python eval_math.py --model $MODEL_PATH --remainder 5 --n_groups 8 --save_path $PRED_PATH"5.json" --data_file /dataset/industry_gpt/llm_infer/AQuA/train_qa.jsonl --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 2 --seed 0 --tensor_parallel_size 1 &
11 | CUDA_VISIBLE_DEVICES=6 python eval_math.py --model $MODEL_PATH --remainder 6 --n_groups 8 --save_path $PRED_PATH"6.json" --data_file /dataset/industry_gpt/llm_infer/AQuA/train_qa.jsonl --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 2 --seed 0 --tensor_parallel_size 1 &
12 | CUDA_VISIBLE_DEVICES=7 python eval_math.py --model $MODEL_PATH --remainder 7 --n_groups 8 --save_path $PRED_PATH"7.json" --data_file /dataset/industry_gpt/llm_infer/AQuA/train_qa.jsonl --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 2 --seed 0 --tensor_parallel_size 1
13 |
--------------------------------------------------------------------------------
/data_pipeline/step2.sh:
--------------------------------------------------------------------------------
1 | export OPENAI_BASE_URL="" # input openai base_url here
2 | export OPENAI_API_KEY="" # input openai api_key here
3 |
4 | python3 data_pipeline/locate_error_by_gpt4.py \
5 | --prompt "qwen2-boxed-step" \
6 | --save_dir "./data_pipeline/generated" \
7 | --json_files "./data_pipeline/predictions/qwen2-7b-instruct-temp0.8-top_p0.95_rep2_seed0-alpaca-group*.json" \
8 | --max_count_total 100
9 |
--------------------------------------------------------------------------------
/data_pipeline/step3.sh:
--------------------------------------------------------------------------------
1 | export MODEL_PATH='/dataset/pretrained-models/Qwen2-7B-Instruct'
2 | export EVAL_PROMPT='qwen2-boxed-prefix'
3 | export JSON_FILE='./data_pipeline/continue_from_incorrect_step.jsonl'
4 | export PRED_PATH='./data_pipeline/corrections/qwen2-7b-instruct-correction'
5 | export SAVE_PATH='./data_pipeline/data.json'
6 |
7 | python3 data_pipeline/prepare_for_correction.py --prompt $EVAL_PROMPT \
8 | --save_file $JSON_FILE \
9 | --generated_files "./data_pipeline/generated/*.json"
10 |
11 | CUDA_VISIBLE_DEVICES=0 python eval_math.py --model $MODEL_PATH --remainder 0 --n_groups 8 --save_path $PRED_PATH"0.json" --data_file $JSON_FILE --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 20 --seed 0 --tensor_parallel_size 1 &
12 | CUDA_VISIBLE_DEVICES=1 python eval_math.py --model $MODEL_PATH --remainder 1 --n_groups 8 --save_path $PRED_PATH"1.json" --data_file $JSON_FILE --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 20 --seed 0 --tensor_parallel_size 1 &
13 | CUDA_VISIBLE_DEVICES=2 python eval_math.py --model $MODEL_PATH --remainder 2 --n_groups 8 --save_path $PRED_PATH"2.json" --data_file $JSON_FILE --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 20 --seed 0 --tensor_parallel_size 1 &
14 | CUDA_VISIBLE_DEVICES=3 python eval_math.py --model $MODEL_PATH --remainder 3 --n_groups 8 --save_path $PRED_PATH"3.json" --data_file $JSON_FILE --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 20 --seed 0 --tensor_parallel_size 1 &
15 | CUDA_VISIBLE_DEVICES=4 python eval_math.py --model $MODEL_PATH --remainder 4 --n_groups 8 --save_path $PRED_PATH"4.json" --data_file $JSON_FILE --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 20 --seed 0 --tensor_parallel_size 1 &
16 | CUDA_VISIBLE_DEVICES=5 python eval_math.py --model $MODEL_PATH --remainder 5 --n_groups 8 --save_path $PRED_PATH"5.json" --data_file $JSON_FILE --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 20 --seed 0 --tensor_parallel_size 1 &
17 | CUDA_VISIBLE_DEVICES=6 python eval_math.py --model $MODEL_PATH --remainder 6 --n_groups 8 --save_path $PRED_PATH"6.json" --data_file $JSON_FILE --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 20 --seed 0 --tensor_parallel_size 1 &
18 | CUDA_VISIBLE_DEVICES=7 python eval_math.py --model $MODEL_PATH --remainder 7 --n_groups 8 --save_path $PRED_PATH"7.json" --data_file $JSON_FILE --prompt $EVAL_PROMPT --temp 0.8 --top_p 0.95 --rep 20 --seed 0 --tensor_parallel_size 1
19 |
--------------------------------------------------------------------------------
/eval_math.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import pdb
5 | import sys
6 |
7 | import jsonlines
8 | import torch
9 | from evaluation.data_processing.answer_extraction import extract_math_answer
10 | from evaluation.eval.eval_script import eval_math
11 | from vllm import LLM, SamplingParams
12 |
13 | MAX_INT = sys.maxsize
14 | INVALID_ANS = "[invalid]"
15 |
16 | invalid_outputs = []
17 |
18 | def batch_data(data_list, batch_size=1):
19 | n = len(data_list) // batch_size
20 | batch_data = []
21 | for i in range(n-1):
22 | start = i * batch_size
23 | end = (i+1)*batch_size
24 | batch_data.append(data_list[start:end])
25 |
26 | last_start = (n-1) * batch_size
27 | last_end = MAX_INT
28 |     batch_data.append(data_list[last_start:last_end])  # final batch takes all remaining items (also covers len(data_list) < batch_size)
29 | return batch_data
30 |
31 | def test_hendrycks_math(model, data_path, remainder=0, n_groups=MAX_INT, batch_size=1, tensor_parallel_size=1, args=None):
32 |
33 | save_path = args.save_path
34 | hendrycks_math_ins = []
35 | hendrycks_math_answers = []
36 | attributes = []
37 | if args.prompt == 'alpaca':
38 | problem_prompt = (
39 | "Below is an instruction that describes a task. "
40 | "Write a response that appropriately completes the request.\n\n"
41 | "### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
42 | )
43 | elif args.prompt == 'alpaca-cot-step':
44 | problem_prompt = (
45 | "Below is an instruction that describes a task. "
46 | "Write a response that appropriately completes the request.\n\n"
47 | "### Instruction:\n{instruction}\n\n### Response:\nLet's think step by step.\nStep 1: "
48 | )
49 | elif args.prompt == 'alpaca-cot-prefix':
50 | problem_prompt = (
51 | "Below is an instruction that describes a task. "
52 | "Write a response that appropriately completes the request.\n\n"
53 | "### Instruction:\n{instruction}\n\n### Response:\nLet's think step by step.\n{prefix}"
54 | )
55 | elif args.prompt == 'deepseek-math':
56 | problem_prompt = (
57 | "User: {instruction}\nPlease reason step by step, and put your final answer within \\boxed{{}}.\n\nAssistant:"
58 | )
59 | elif args.prompt == 'deepseek-math-step':
60 | problem_prompt = (
61 | "User: {instruction}\nPlease reason step by step, and put your final answer within \\boxed{{}}.\n\nAssistant: Let's think step by step.\nStep 1: "
62 | )
63 | elif args.prompt == 'qwen2-boxed':
64 | problem_prompt = (
65 | "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
66 | "<|im_start|>user\n{instruction}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n"
67 | "<|im_start|>assistant\n"
68 | )
69 | elif args.prompt == 'qwen2-boxed-step':
70 | problem_prompt = (
71 | "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
72 | "<|im_start|>user\n{instruction}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n"
73 | "<|im_start|>assistant\nLet's think step by step.\nStep 1: "
74 | )
75 | elif args.prompt == 'qwen2-boxed-prefix':
76 | problem_prompt = (
77 | "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
78 | "<|im_start|>user\n{instruction}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n"
79 | "<|im_start|>assistant\nLet's think step by step.\n{prefix}"
80 | )
81 |
82 | print('prompt =====', problem_prompt)
83 |     with open(data_path, "r", encoding="utf8") as f:
84 | for idx, item in enumerate(jsonlines.Reader(f)):
85 | if "prefix" in item:
86 | temp_instr = problem_prompt.format(instruction=item["instruction"], prefix=item['prefix'])
87 | else:
88 | temp_instr = problem_prompt.format(instruction=item["instruction"])
89 | hendrycks_math_ins.append(temp_instr)
90 | temp_ans = item['answer']
91 | hendrycks_math_answers.append(temp_ans)
92 | attribute = {}
93 | if 'filepath' in item:
94 | attribute['filepath'] = item['filepath']
95 | if 'type' in item:
96 | attribute['type'] = item['type']
97 | if 'output' in item:
98 | attribute['gt_output'] = item['output']
99 | attributes.append(attribute)
100 |
101 | print("args.seed: ", args.seed)
102 | print('length ===', len(hendrycks_math_ins))
103 | hendrycks_math_ins = hendrycks_math_ins[remainder::n_groups]
104 | hendrycks_math_answers = hendrycks_math_answers[remainder::n_groups]
105 | attributes = attributes[remainder::n_groups]
106 |
107 | print("processed length ===", len(hendrycks_math_ins))
108 | hendrycks_math_ins = hendrycks_math_ins * args.rep
109 | hendrycks_math_answers = hendrycks_math_answers * args.rep
110 | attributes = attributes * args.rep
111 |
112 | print('total length ===', len(hendrycks_math_ins))
113 | batch_hendrycks_math_ins = batch_data(hendrycks_math_ins, batch_size=batch_size)
114 |
115 | sampling_params = SamplingParams(temperature=args.temp, top_p=args.top_p, max_tokens=2048)
116 | print('sampling =====', sampling_params)
117 | if not os.path.exists(save_path):
118 | llm = LLM(model=model, tensor_parallel_size=tensor_parallel_size, dtype=torch.bfloat16, seed=args.seed)
119 |
120 | res_completions = []
121 | for idx, (prompt, prompt_answer) in enumerate(zip(batch_hendrycks_math_ins, hendrycks_math_answers)):
122 | if isinstance(prompt, list):
123 | pass
124 | else:
125 | prompt = [prompt]
126 | completions = llm.generate(prompt, sampling_params)
127 | for output in completions:
128 | prompt_temp = output.prompt
129 | generated_text = output.outputs[0].text
130 | res_completions.append(generated_text)
131 | else:
132 | res_completions = []
133 | with open(save_path) as f:
134 | items = json.load(f)
135 | for idx, item in enumerate(items):
136 | res_completions.append(item['completion'])
137 |
138 | to_save_list = []
139 | results = []
140 | for idx, (prompt, completion, prompt_answer, attribute) in enumerate(zip(hendrycks_math_ins, res_completions, hendrycks_math_answers, attributes)):
141 |
142 | if isinstance(prompt_answer, str) and prompt_answer.startswith("\\text{"):
143 |             prompt_answer = prompt_answer[len("\\text{"):-1]  # strip the \text{...} wrapper
144 |
145 | if "The answer is:" in completion and (isinstance(prompt_answer, list) and len(prompt_answer) == 1 and "\\begin{pmatrix}" in prompt_answer[0]):
146 | prompt_answer[0] = prompt_answer[0].replace("\\\\", "\\")
147 | completion = completion.replace("\\\\", "\\")
148 |
149 | item = {
150 | 'question': prompt,
151 | 'model_output': completion,
152 | 'prediction': extract_math_answer(prompt, completion, task='cot'),
153 | 'answer': prompt_answer if isinstance(prompt_answer, list) else [prompt_answer],
154 | }
155 |
156 | if len(item['prediction']) == 0:
157 | invalid_outputs.append({'question': prompt, 'output': completion, 'answer': item['prediction']})
158 | res = False
159 | extract_ans = None
160 | else:
161 | extract_ans = item['prediction']
162 | res = eval_math(item)
163 |
164 | results.append(res)
165 |
166 | to_save_dict = {
167 | 'prompt': prompt,
168 | 'completion': completion,
169 | 'extract_answer': extract_ans,
170 | 'prompt_answer': prompt_answer,
171 | 'result': res,
172 | }
173 | to_save_dict.update(attribute)
174 | to_save_list.append(to_save_dict)
175 |
176 | acc = sum(results) / len(results)
177 | # print('valid_outputs===', invalid_outputs)
178 | print('len invalid outputs ====', len(invalid_outputs))
179 | print('n_groups===', n_groups, ', remainder====', remainder)
180 | print('length====', len(results), ', acc====', acc)
181 |
182 | try:
183 | with open(save_path, "w+") as f:
184 | json.dump(to_save_list, f, indent=4)
185 | except Exception:
186 | pdb.set_trace()
187 |
188 | def parse_args():
189 | parser = argparse.ArgumentParser()
190 | parser.add_argument("--model", type=str, default='') # model path
191 | parser.add_argument("--data_file", type=str, default='') # data path
192 | parser.add_argument("--remainder", type=int, default=0) # index
193 | parser.add_argument("--n_groups", type=int, default=1) # group number
194 | parser.add_argument("--batch_size", type=int, default=400) # batch_size
195 | parser.add_argument("--tensor_parallel_size", type=int, default=8) # tensor_parallel_size
196 | parser.add_argument("--save_path", type=str)
197 | parser.add_argument("--prompt", type=str, default='alpaca')
198 | parser.add_argument("--temp", type=float, default=0.0)
199 | parser.add_argument("--top_p", type=float, default=1.0)
200 | parser.add_argument("--seed", type=int, default=None)
201 | parser.add_argument("--rep", type=int, default=1)
202 | return parser.parse_args()
203 |
204 | if __name__ == "__main__":
205 | args = parse_args()
206 | test_hendrycks_math(model=args.model, data_path=args.data_file, remainder=args.remainder, n_groups=args.n_groups, batch_size=args.batch_size, tensor_parallel_size=args.tensor_parallel_size, args=args)
207 |
--------------------------------------------------------------------------------
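The `--remainder`/`--n_groups` flags shard the dataset across the eight parallel jobs launched by `step1.sh` and `step3.sh` via strided slicing (`data[remainder::n_groups]` in `test_hendrycks_math`). A small sketch showing that the shards partition the data without overlap:

```python
# Strided sharding as used in test_hendrycks_math: job r of n_groups
# processes items r, r + n_groups, r + 2*n_groups, ...
data = list(range(10))
n_groups = 4
shards = [data[r::n_groups] for r in range(n_groups)]
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
# every item lands in exactly one shard
assert sorted(x for shard in shards for x in shard) == data
```

Because the shards are disjoint, the per-job JSON outputs (`$PRED_PATH"0.json"` … `$PRED_PATH"7.json"`) can simply be globbed back together in `merge.sh`.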
/eval_results/gsm8k/sample.json:
--------------------------------------------------------------------------------
1 | []
--------------------------------------------------------------------------------
/eval_results/math/qwen2-7b-dpo-v3-continue-from-incorrect-fix-part1-filtered+2-filtered+aqua-filtered-rej-original_acc0.6-topk1-beta0.5-8ep-fixbug-fixeos-bf16-keywords-fix.json:
--------------------------------------------------------------------------------
1 | [
2 |
3 | ]
--------------------------------------------------------------------------------
/eval_results/math/sample.json:
--------------------------------------------------------------------------------
1 | []
--------------------------------------------------------------------------------
/evaluation/data_processing/answer_extraction.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import regex
4 |
5 |
6 | def _fix_fracs(string):
7 | substrs = string.split("\\frac")
8 | new_str = substrs[0]
9 | if len(substrs) > 1:
10 | substrs = substrs[1:]
11 | for substr in substrs:
12 | new_str += "\\frac"
13 | if len(substr) > 0 and substr[0] == "{":
14 | new_str += substr
15 | else:
16 | try:
17 | assert len(substr) >= 2
18 | except Exception:
19 | return string
20 | a = substr[0]
21 | b = substr[1]
22 | if b != "{":
23 | if len(substr) > 2:
24 | post_substr = substr[2:]
25 | new_str += "{" + a + "}{" + b + "}" + post_substr
26 | else:
27 | new_str += "{" + a + "}{" + b + "}"
28 | else:
29 | if len(substr) > 2:
30 | post_substr = substr[2:]
31 | new_str += "{" + a + "}" + b + post_substr
32 | else:
33 | new_str += "{" + a + "}" + b
34 | string = new_str
35 | return string
36 |
37 |
38 | def _fix_a_slash_b(string):
39 | if len(string.split("/")) != 2:
40 | return string
41 | a = string.split("/")[0]
42 | b = string.split("/")[1]
43 | try:
44 | if "sqrt" not in a:
45 | a = int(a)
46 | if "sqrt" not in b:
47 | b = int(b)
48 | assert string == "{}/{}".format(a, b)
49 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
50 | return new_string
51 | except Exception:
52 | return string
53 |
54 |
55 | def _fix_sqrt(string):
56 | _string = re.sub(r"\\sqrt(-?[0-9.a-zA-Z]+)", r"\\sqrt{\1}", string)
57 | _string = re.sub(r"\\sqrt\s+(\w+)$", r"\\sqrt{\1}", _string)
58 | return _string
59 |
60 |
61 | def _fix_tan(string):
62 | _string = re.sub(r"\\tan(-?[0-9.a-zA-Z]+)", r"\\tan{\1}", string)
63 | _string = re.sub(r"\\tan\s+(\w+)$", r"\\tan{\1}", _string)
64 | return _string
65 |
66 |
67 | def strip_string(string):
68 | string = str(string).strip()
69 | # linebreaks
70 | string = string.replace("\n", "")
71 |
72 | # right "."
73 | string = string.rstrip(".")
74 |
75 | # remove inverse spaces
76 | string = string.replace("\\!", "")
77 | # string = string.replace("\\ ", "")
78 |
79 | # replace \\ with \
80 | # string = string.replace("\\\\", "\\")
81 | # string = string.replace("\\\\", "\\")
82 |
83 | if string.startswith("\\text{") and string.endswith("}"):
84 | string = string.split("{", 1)[1][:-1]
85 |
86 | # replace tfrac and dfrac with frac
87 | string = string.replace("tfrac", "frac")
88 | string = string.replace("dfrac", "frac")
89 | string = string.replace("cfrac", "frac")
90 |
91 | # remove \left and \right
92 | string = string.replace("\\left", "")
93 | string = string.replace("\\right", "")
94 |
95 | # Remove unit: miles, dollars if after is not none
96 | _string = re.sub(r"\\text{.*?}$", "", string).strip()
97 | if _string != "" and _string != string:
98 | # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
99 | string = _string
100 |
101 | # Remove circ (degrees)
102 | string = string.replace("^{\\circ}", "").strip()
103 | string = string.replace("^\\circ", "").strip()
104 |
105 | string = regex.sub(r"\{(c|m)?m\}(\^(2|3))?", "", string).strip()
106 | string = regex.sub(r"p\.m\.$", "", string).strip()
107 | string = regex.sub(r"(\d)\s*t$", r"\1", string).strip()
108 |
109 | # remove dollar signs
110 | string = string.replace("\\$", "")
111 | string = string.replace("$", "")
112 |
113 | # string = string.replace("\\text", "")
114 | string = string.replace("x\\in", "")
115 |
116 | # remove percentage
117 | string = string.replace("\\%", "%")
118 | string = string.replace("\%", "%")
119 | # string = string.replace("%", "")
120 |
121 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
122 | string = string.replace(" .", " 0.")
123 | string = string.replace("{.", "{0.")
124 |
125 | # cdot
126 | string = string.replace("\\cdot", "")
127 |
128 | # inf
129 | string = string.replace("infinity", "\\infty")
130 | if "\\infty" not in string:
131 | string = string.replace("inf", "\\infty")
132 | string = string.replace("+\\inity", "\\infty")
133 |
134 | # and
135 | # string = string.replace("and", "")
136 | string = string.replace("\\mathbf", "")
137 | string = string.replace("\\mathrm", "")
138 |
139 | # use regex to remove \mbox{...}
140 | string = re.sub(r"\\mbox{.*?}", "", string)
141 |
142 | # quote
143 |     string = string.replace("'", "")
144 |     string = string.replace("\"", "")
145 |
146 | # i, j
147 | if "j" in string and "i" not in string:
148 | string = string.replace("j", "i")
149 |
150 | # replace a.000b where b is not number or b is end, with ab, use regex
151 | string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string)
152 | string = re.sub(r"(\d+)\.0+$", r"\1", string)
153 |
154 | # if empty, return empty string
155 | if len(string) == 0:
156 | return string
157 | if string[0] == ".":
158 | string = "0" + string
159 |
160 | # to consider: get rid of e.g. "k = " or "q = " at beginning
161 | # if len(string.split("=")) == 2:
162 | # if len(string.split("=")[0]) <= 2:
163 | # string = string.split("=")[1]
164 |
165 | string = _fix_sqrt(string)
166 | string = _fix_tan(string)
167 | string = string.replace(" ", "")
168 |
169 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
170 | string = _fix_fracs(string)
171 |
172 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
173 | string = _fix_a_slash_b(string)
174 |
175 | string = regex.sub(r"(\\|,|\.)+$", "", string)
176 |
177 | return string
178 |
179 | def extract_boxed_answers(text):
180 | answers = []
181 | for piece in text.split('boxed{')[1:]:
182 | n = 0
183 | for i in range(len(piece)):
184 | if piece[i] == '{':
185 | n += 1
186 | elif piece[i] == '}':
187 | n -= 1
188 | if n < 0:
189 | if i + 1 < len(piece) and piece[i + 1] == '%':
190 | answers.append(piece[: i + 1])
191 | else:
192 | answers.append(piece[:i])
193 | break
194 | return answers
195 |
196 | def extract_program_output(pred_str):
197 | """
198 | extract output between the last ```output\n...\n```
199 | """
200 |     if "```output" not in pred_str:
201 |         return ""
202 |     # keep only the segment after the last ```output fence
203 |     pred_str = pred_str.split('```output')[-1]
204 |     if '```' in pred_str:
205 |         pred_str = pred_str.split('```')[0]
206 |     output = pred_str.strip()
207 |     return output
208 |
209 | def extract_answer(pred_str, exhaust=False):
210 | pred = []
211 |
212 | # import pdb; pdb.set_trace()
213 |
214 | if 'final answer is $' in pred_str and '$. I hope' in pred_str:
215 | tmp = pred_str.split('final answer is $', 1)[1]
216 | pred = [tmp.split('$. I hope', 1)[0].strip()]
217 | elif 'boxed' in pred_str:
218 | pred = extract_boxed_answers(pred_str)
219 |
220 | # import pdb; pdb.set_trace()
221 |
222 | elif ('he answer is' in pred_str):
223 | pred = [pred_str.split('he answer is')[-1].strip()]
224 | else:
225 | program_output = extract_program_output(pred_str)
226 | if program_output != "":
227 | # fall back to program
228 | pred.append(program_output)
229 | else: # use the last number
230 |                 pattern = r'-?\d*\.?\d+'
231 |                 ans = re.findall(pattern, pred_str.replace(",", ""))
232 |                 if len(ans) >= 1:
233 | ans = ans[-1]
234 | else:
235 | ans = ''
236 | if ans:
237 | pred.append(ans)
238 |
239 | # multiple line
240 | _pred = []
241 | for ans in pred:
242 | # ans = ans.strip().split("\n")[0]
243 | ans = ans.strip()
244 |
245 | ans = ans.lstrip(":")
246 | ans = ans.rstrip(".")
247 | ans = ans.rstrip("/")
248 | ans = strip_string(ans)
249 | _pred.append(ans)
250 | if exhaust:
251 | return _pred
252 | else:
253 | return _pred[-1] if _pred else ""
254 |
255 | def extract_math_answer(question, reasoning, task):
256 | answer = []
257 | for ans in extract_answer(reasoning, exhaust=True):
258 | if 'separated by commas' in question and all(ch not in ans for ch in '()[]'):
259 | answer.extend([a.strip() for a in ans.split(",")])
260 | elif regex.search(r"\\text\{\s*and\s*\}", ans):
261 | answer.extend([a.strip() for a in regex.sub(r"\\text\{\s*and\s*\}", "[SEP]", ans).split("[SEP]")])
262 | else:
263 | answer.append(ans.strip())
264 | return answer
265 |
266 | def extract_math_few_shot_cot_answer(question, reasoning, task):
267 | if 'Problem:' in reasoning:
268 | reasoning = reasoning.split("Problem:", 1)[0]
269 | return extract_math_answer(question, reasoning, task)
270 |
271 | def extract_last_single_answer(question, reasoning, task):
272 | return extract_answer(reasoning, exhaust=False)
273 |
274 | def extract_gsm_few_shot_cot_answer(question, reasoning, task):
275 | if 'Q: ' in reasoning:
276 | reasoning = reasoning.split("Q: ", 1)[0]
277 | pred = [s for s in regex.findall(r'-?\d+\.?\d*', reasoning)]
278 | if pred:
279 | return pred[-1]
280 | else:
281 | return "[invalid]"
282 |
283 | def extract_agieval_gaokao_mathcloze_few_shot_cot_test(question, reasoning, task):
284 | if '问题 ' in reasoning:
285 | reasoning = reasoning.split("问题 ", 1)[0]
286 | if '答案是' in reasoning:
287 | ans = reasoning.split('答案是', 1)[1].strip()
288 | ans = ans.split("\n")[0].strip()
289 | ans = [ans.strip("$")]
290 | else:
291 | ans = ['placeholder']
292 | return ans
293 |
294 | def extract_agieval_gaokao_mathqa_few_shot_cot_test(question, reasoning, task):
295 | if '问题 ' in reasoning:
296 | reasoning = reasoning.split("问题 ", 1)[0]
297 | if '答案是' in reasoning:
298 | ans = reasoning.split('答案是', 1)[1].strip()
299 | ans = ans.split("\n")[0].strip()
300 | else:
301 | ans = 'placeholder'
302 | return ans
303 |
304 | def extract_sat_few_shot_answer(question, reasoning, task):
305 | if 'Problem:' in reasoning:
306 | reasoning = reasoning.split("Problem:", 1)[0]
307 |     patt = regex.search(r"the final answer is \(?(?P<ans>[abcd])\)?", reasoning.lower())
308 | if patt is not None:
309 | return patt.group('ans').upper()
310 | return 'placeholder'
311 |
312 | def extract_ocwcourses_few_shot_answer(question, reasoning, task):
313 | if 'Problem:' in reasoning:
314 | reasoning = reasoning.split("Problem:", 1)[0]
315 |     patt = regex.search(r"final answer is (?P<ans>.*)\. I hope it is correct.", reasoning)
316 | if patt is None:
317 | pred = "[invalid]"
318 | print(f"DEBUG >>>\n{reasoning}", flush=True)
319 | else:
320 | pred = patt.group('ans')
321 | return pred
322 |
323 | def extract_mmlu_stem(question, reasoning, task):
324 | if 'Problem:' in reasoning:
325 | reasoning = reasoning.split("Problem:", 1)[0]
326 | return extract_sat_few_shot_answer(question, reasoning, task)
327 |
328 | def extract_minif2f_isabelle(question, reasoning, task):
329 | if 'Informal:' in reasoning:
330 | reasoning = reasoning.split("Informal:", 1)[0]
331 | return reasoning.strip()
332 |
333 | def extract_cmath_few_shot_test(question, reasoning, task):
334 | if '问题:' in reasoning:
335 | reasoning = reasoning.split("问题:", 1)[0]
336 | if '答案是' in reasoning:
337 | ans = reasoning.split('答案是', 1)[1].strip()
338 | ans = ans.split("\n")[0]
339 | ans = ans.strip(":")
340 | ans = ans.strip("。")
341 | try:
342 | ans = [s for s in regex.findall(r'-?\d+\.?\d*', ans)][-1]
343 | except Exception:
344 | print(f"DEBUG CMATH: {reasoning}", flush=True)
345 | ans = "[invalid]"
346 | else:
347 | ans = extract_last_single_answer(question, reasoning, task)
348 | return ans
349 |
--------------------------------------------------------------------------------
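`extract_boxed_answers` walks each `boxed{` suffix with a brace counter and cuts at the matching close brace, which is why nested braces inside an answer survive intact. A standalone sketch of that scan (the trailing-`%` special case is omitted here):

```python
def extract_boxed_answers(text):
    # brace-balanced scan: cut each "boxed{...}" at its matching close brace
    answers = []
    for piece in text.split('boxed{')[1:]:
        n = 0
        for i, ch in enumerate(piece):
            if ch == '{':
                n += 1
            elif ch == '}':
                n -= 1
            if n < 0:  # the close brace matching "boxed{" itself
                answers.append(piece[:i])
                break
    return answers

print(extract_boxed_answers("so the answer is $\\boxed{\\frac{1}{2}}$"))
# -> ['\\frac{1}{2}']
```

A naive `split("}")` would truncate `\frac{1}{2}` at the first close brace; the counter keeps the full LaTeX expression.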
/evaluation/data_processing/process_utils.py:
--------------------------------------------------------------------------------
1 | import regex
2 | from data_processing.answer_extraction import extract_math_answer, strip_string
3 |
4 |
5 | def process_gsm8k_test(item):
6 | sample = {
7 | 'dataset': 'gsm8k-cot',
8 | 'id': item['id'],
9 | 'messages': [
10 | {'role': 'user', 'content': item['question']},
11 | {'role': 'assistant', 'content': regex.sub(r"<<[^<>]*>>", "", item['cot']) + "\nSo the answer is $\\boxed{" + item['answer'].strip() + "}$."}
12 | ],
13 | 'answer': item['answer'].replace(',', '')
14 | }
15 | yield sample
16 |
17 | def process_math_test(item):
18 | question = item["problem"]
19 | try:
20 | answer = extract_math_answer(question, item['solution'], task="cot")
21 | except Exception:
22 | return
23 | sample = {
24 | "dataset": "math-cot",
25 | "id": item['id'],
26 | "level": item["level"],
27 | "type": item["type"],
28 | "category": item["category"],
29 | "messages": [
30 | {"role": "user", "content": question},
31 | {"role": "assistant", "content": "\n".join(regex.split(r"(?<=\.) (?=[A-Z])", item["solution"]))}
32 | ],
33 | "answer": answer
34 | }
35 | yield sample
36 |
37 | def process_math_sat(item):
38 | options = item['options'].strip()
39 | assert 'A' == options[0]
40 | options = '(' + options
41 | for ch in 'BCDEFG':
42 | if f' {ch}) ' in options:
43 |             options = regex.sub(rf' {ch}\) ', f" ({ch}) ", options)
44 | question = f"{item['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options.strip()}"
45 | messages = [
46 | {'role': 'user', 'content': question},
47 | {'role': 'assistant', 'content': item['Answer']}
48 | ]
49 | item = {
50 | 'dataset': 'math_sat',
51 | 'id': item['id'],
52 | 'language': 'en',
53 | 'messages': messages,
54 | 'answer': item['Answer'],
55 | }
56 | yield item
57 |
58 | def process_ocwcourses(item):
59 | messages = [
60 | {'role': 'user', 'content': item['problem'].strip()},
61 | {'role': 'assistant', 'content': item['solution'].strip()}
62 | ]
63 | item = {
64 | "dataset": "OCWCourses",
65 | "id": item['id'],
66 | "language": "en",
67 | "messages": messages,
68 | "answer": item['answer']
69 | }
70 | yield item
71 |
72 | def process_mmlu_stem(item):
73 | options = item['options']
74 | for i, (label, option) in enumerate(zip('ABCD', options)):
75 | options[i] = f"({label}) {str(option).strip()}"
76 | options = ", ".join(options)
77 | question = f"{item['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options}"
78 | messages = [
79 | {'role': 'user', 'content': question},
80 | {'role': 'assistant', 'content': item['answer']}
81 | ]
82 | item = {
83 | "dataset": "MMLU-STEM",
84 | "id": item['id'],
85 | "language": "en",
86 | "messages": messages,
87 | "answer": item['answer']
88 | }
89 | yield item
90 |
91 | def process_mgsm_zh(item):
92 | item['answer'] = item['answer'].replace(',', '')
93 | yield item
94 |
95 | def process_cmath(item):
96 | item = {
97 | 'dataset': 'cmath',
98 | 'id': item['id'],
99 | 'grade': item['grade'],
100 | 'reasoning_step': item['reasoning_step'],
101 | 'messages': [
102 | {'role': 'user', 'content': item['question'].strip()},
103 | {'role': 'assistant', 'content': ''}
104 | ],
105 | 'answer': item['golden'].strip().replace(",", "")
106 | }
107 | yield item
108 |
109 | def process_agieval_gaokao_math_cloze(item):
110 | item = {
111 | 'dataset': 'agieval-gaokao-math-cloze',
112 | 'id': item['id'],
113 | 'messages': [
114 | {'role': 'user', 'content': item['question'].strip()},
115 | {'role': 'assistant', 'content': ''}
116 | ],
117 | 'answer': [strip_string(ans) for ans in item['answer'].strip().split(";")]
118 | }
119 | yield item
120 |
121 | def process_agieval_gaokao_mathqa(item):
122 | question = item['question'].strip()
123 | options = []
124 | for option in item['options']:
125 | option = option.strip()
126 | assert option[0] == '('
127 | assert option[2] == ')'
128 | assert option[1] in 'ABCD'
129 | option = f"{option[1]}: {option[3:].strip()}"
130 | options.append(option.strip())
131 |     question = f"{question}\n" + " ".join(options)
132 | item = {
133 | 'dataset': 'agieval-gaokao-mathqa',
134 | 'id': item['id'],
135 | 'messages': [
136 | {'role': 'user', 'content': question},
137 | {'role': 'assistant', 'content': ''}
138 | ],
139 | "answer": item['label']
140 | }
141 | yield item
142 |
143 | def process_agieval_gaokao_mathqa_few_shot_cot_test(item):
144 | question = item['question'].strip().rstrip('\\')
145 | options = " ".join([opt.strip() for opt in item['options']])
146 | question = f"{question}\n从以下选项中选择: {options}"
147 | item = {
148 | 'dataset': 'agieval-gaokao-mathqa',
149 | 'id': item['id'],
150 | 'messages': [
151 | {'role': 'user', 'content': question},
152 | {'role': 'assistant', 'content': ''}
153 | ],
154 | "answer": item['label']
155 | }
156 | yield item
157 |
158 | def process_minif2f_isabelle(item):
159 | question = f"(*### Problem\n\n{item['informal_statement'].strip()}\n\n### Solution\n\n{item['informal_proof'].strip()} *)\n\nFormal:\n{item['formal_statement'].strip()}"
160 | item = {
161 | 'dataset': 'minif2f-isabelle',
162 | 'id': item['id'],
163 | 'messages': [
164 | {'role': 'user', 'content': question},
165 | {'role': 'assistant', 'content': ''}
166 | ],
167 | "answer": "placeholder"
168 | }
169 | yield item
170 |
--------------------------------------------------------------------------------
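The processors above all emit the same unified record shape (`dataset`, `id`, `messages`, `answer`). A minimal standalone sketch of that conversion for an MMLU-style multiple-choice item, mirroring `process_mmlu_stem` (the `to_messages` name is illustrative, not part of the repo):

```python
def to_messages(item):
    """Convert a raw multiple-choice item into the unified chat-message
    schema used by the dataset processors (sketch; field names follow
    process_mmlu_stem)."""
    options = [f"({label}) {str(opt).strip()}" for label, opt in zip("ABCD", item["options"])]
    question = (
        f"{item['question'].strip()}\n"
        "Which of the following is the right choice? Explain your answer.\n"
        + ", ".join(options)
    )
    return {
        "dataset": "MMLU-STEM",
        "id": item["id"],
        "language": "en",
        "messages": [
            {"role": "user", "content": question},
            {"role": "assistant", "content": item["answer"]},
        ],
        "answer": item["answer"],
    }
```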
/evaluation/eval/eval_script.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | import regex
4 | from evaluation.eval.eval_utils import math_equal
5 | from evaluation.eval.ocwcourses_eval_utils import (
6 | SymbolicMathMixin,
7 | normalize_numeric,
8 | normalize_symbolic_equation,
9 | numeric_equality,
10 | )
11 |
12 |
13 | def is_correct(item, pred_key='prediction', prec=1e-3):
14 | pred = item[pred_key]
15 | ans = item['answer']
16 | if isinstance(pred, list) and isinstance(ans, list):
17 | pred_matched = set()
18 | ans_matched = set()
19 | for i in range(len(pred)):
20 | for j in range(len(ans)):
21 | item_cpy = deepcopy(item)
22 | item_cpy.update({
23 | pred_key: pred[i],
24 | 'answer': ans[j]
25 | })
26 | if is_correct(item_cpy, pred_key=pred_key, prec=prec):
27 | pred_matched.add(i)
28 | ans_matched.add(j)
32 | return len(pred_matched) == len(pred) and len(ans_matched) == len(ans)
33 | elif isinstance(pred, str) and isinstance(ans, str):
34 | if '\\cup' in pred and '\\cup' in ans:
35 | item = deepcopy(item)
36 | item.update({
37 | pred_key: pred.split('\\cup'),
38 | 'answer': ans.split('\\cup'),
39 | })
40 | return is_correct(item, pred_key=pred_key, prec=prec)
41 | else:
42 | label = False
43 | try:
44 | label = abs(float(regex.sub(r',', '', str(pred))) - float(regex.sub(r',', '', str(ans)))) < prec
45 | except Exception:
46 | pass
47 |
50 |
51 | label = label or (ans and pred == ans) or math_equal(pred, ans)
52 | return label
53 | else:
54 | print(item, flush=True)
55 | raise NotImplementedError()
56 |
57 | def eval_math(item, pred_key='prediction', prec=1e-3):
58 | pred = item[pred_key]
59 | if pred_key == 'program_output' and isinstance(pred, str):
60 | pred = [pred]
61 | ans = item['answer']
62 | if isinstance(pred, list) and isinstance(ans, list):
63 | # for some questions in MATH, `reference` repeats answers
64 | _ans = []
65 | for a in ans:
66 | if a not in _ans:
67 | _ans.append(a)
68 | ans = _ans
69 | # some predictions for MATH questions also repeats answers
70 | _pred = []
71 | for a in pred:
72 | if a not in _pred:
73 | _pred.append(a)
74 | # some predictions mistakenly box non-answer strings
75 | pred = _pred[-len(ans):]
76 |
77 | item.update({
78 | pred_key: pred,
79 | 'answer': ans
80 | })
81 | return is_correct(item, pred_key=pred_key, prec=prec)
82 |
83 | def eval_last_single_answer(item, pred_key='prediction', prec=1e-3):
84 | for key in [pred_key, 'answer']:
85 | assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
86 | return is_correct(item, pred_key=pred_key, prec=prec)
87 |
88 | def eval_agieval_gaokao_math_cloze(item, pred_key='prediction', prec=1e-3):
89 | if pred_key == 'program_output' and isinstance(item[pred_key], str):
90 | item[pred_key] = [item[pred_key]]
91 | for key in [pred_key, 'answer']:
92 | assert isinstance(item[key], list), f"{key} = `{item[key]}` is not a list"
93 | pred = item[pred_key]
94 | ans = item['answer']
95 | _pred = []
96 | for p in pred:
97 | p = p + ";"
98 | while p:
99 | left_brackets = 0
100 | for i in range(len(p)):
101 | if p[i] == ';' or (p[i] == ',' and left_brackets == 0):
102 | _p, p = p[:i].strip(), p[i + 1:].strip()
103 | if _p not in _pred:
104 | _pred.append(_p)
105 | break
106 | elif p[i] in '([{':
107 | left_brackets += 1
108 | elif p[i] in ')]}':
109 | left_brackets -= 1
110 | pred = _pred[-len(ans):]
111 | if len(pred) == len(ans):
112 | for p, a in zip(pred, ans):
113 | item.update({
114 | pred_key: p,
115 | 'answer': a,
116 | })
117 | if not is_correct(item, pred_key=pred_key, prec=prec):
118 | return False
119 | return True
120 | else:
121 | return False
122 |
123 | def eval_agieval_gaokao_mathqa(item, pred_key='prediction', prec=1e-3):
124 | if pred_key == 'program_output' and isinstance(item[pred_key], str):
125 | item[pred_key] = [item[pred_key]]
126 | pred_str = " ".join(item[pred_key])
127 | ans = item['answer']
128 | tag = None
129 | idx = -1
130 | for t in 'ABCD':
131 | if t in pred_str and pred_str.index(t) > idx:
132 | tag = t
133 | idx = pred_str.index(t)
134 | return tag == ans
135 |
136 | def eval_math_sat(item, pred_key='prediction', prec=1e-3):
137 | for key in [pred_key, 'answer']:
138 | assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
139 | return item[pred_key].lower() == item['answer'].lower()
140 |
141 | def eval_mmlu_stem(item, pred_key='prediction', prec=1e-3):
142 | return eval_math_sat(item, pred_key=pred_key, prec=prec)
143 |
144 | def eval_ocwcourses(item, pred_key='prediction', prec=1e-3):
145 | INVALID_ANSWER = "[invalidanswer]"
146 | for key in [pred_key, 'answer']:
147 | assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
148 | pred = item[pred_key]
149 | ans = item['answer']
150 |
151 | try:
152 | # numeric
153 | float(ans)
154 | normalize_fn = normalize_numeric
155 | is_equiv = numeric_equality
156 | except ValueError:
157 | if "=" in ans:
158 | # equation
159 | normalize_fn = normalize_symbolic_equation
160 | is_equiv = lambda x, y: x==y
161 | else:
162 | # expression
163 | normalize_fn = SymbolicMathMixin().normalize_tex
164 | is_equiv = SymbolicMathMixin().is_tex_equiv
165 |
166 | correct_answer = normalize_fn(ans)
167 |
168 | unnormalized_answer = pred if pred else INVALID_ANSWER
169 | model_answer = normalize_fn(unnormalized_answer)
170 |
171 | if unnormalized_answer == INVALID_ANSWER:
172 | acc = 0
173 | elif model_answer == INVALID_ANSWER:
174 | acc = 0
175 | elif is_equiv(model_answer, correct_answer):
176 | acc = 1
177 | else:
178 | acc = 0
179 |
180 | return acc
181 |
182 | def eval_minif2f_isabelle(item, pred_key='prediction', prec=1e-3):
183 | return True
184 |
--------------------------------------------------------------------------------
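For list-valued answers, `is_correct` requires a perfect two-sided cover: every prediction must match some reference answer and every reference answer must be matched by some prediction. A standalone sketch of that matching rule (`lists_match` and the pairwise `equal` callback are illustrative names):

```python
def lists_match(preds, answers, equal):
    """Sketch of the list branch in is_correct: accept only if every
    prediction matches some answer and every answer is matched, under
    the pairwise `equal` test."""
    pred_matched, ans_matched = set(), set()
    for i, p in enumerate(preds):
        for j, a in enumerate(answers):
            if equal(p, a):
                pred_matched.add(i)
                ans_matched.add(j)
    return len(pred_matched) == len(preds) and len(ans_matched) == len(answers)
```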
/evaluation/eval/eval_utils.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import re
3 | from math import isclose
4 | from typing import Any, Dict, Union
5 |
6 | import numpy as np
7 | import regex
8 | from evaluation.data_processing.answer_extraction import (
9 | extract_answer,
10 | extract_program_output,
11 | strip_string,
12 | )
13 | from sympy import N, simplify
14 | from sympy.parsing.latex import parse_latex
15 | from sympy.parsing.sympy_parser import parse_expr
16 |
17 |
18 | def extract_program(result: str, last_only=True):
19 | """
20 | extract the program after "```python", and before "```"
21 | """
22 | program = ""
23 | start = False
24 | for line in result.split("\n"):
25 | if line.startswith("```python"):
26 | if last_only:
27 | program = "" # only extract the last program
28 | else:
29 | program += "\n# ========\n"
30 | start = True
31 | elif line.startswith("```"):
32 | start = False
33 | elif start:
34 | program += line + "\n"
35 | return program
36 |
37 |
38 | def parse_ground_truth(example: Dict[str, Any], data_name):
39 | if 'gt_cot' in example:
40 | return example['gt_cot'], strip_string(example['gt'])
41 |
42 | # parse ground truth
43 | if data_name in ["math", 'ocw']:
44 | gt_cot = example['solution']
45 | gt_ans = extract_answer(gt_cot)
46 | elif data_name == "gsm8k":
47 | gt_cot, gt_ans = example['answer'].split("####")
48 | elif data_name == "gsm-hard":
49 | gt_cot, gt_ans = example['code'], example['target']
50 | elif data_name == "svamp":
51 | gt_cot, gt_ans = example['Equation'], example['Answer']
52 | elif data_name == "asdiv":
53 | gt_cot = example['formula']
54 | gt_ans = re.sub(r"\(.*?\)", "", example['answer'])
55 | elif data_name == "mawps":
56 | gt_cot, gt_ans = None, example['target']
57 | elif data_name == "tabmwp":
58 | gt_cot = example['solution']
59 | gt_ans = example['answer']
60 | if example['ans_type'] in ['integer_number', 'decimal_number']:
61 | if '/' in gt_ans:
62 | gt_ans = int(gt_ans.split('/')[0]) / int(gt_ans.split('/')[1])
63 | elif ',' in gt_ans:
64 | gt_ans = float(gt_ans.replace(',', ''))
65 | elif '%' in gt_ans:
66 | gt_ans = float(gt_ans.split('%')[0]) / 100
67 | else:
68 | gt_ans = float(gt_ans)
69 | elif data_name == "bbh":
70 | gt_cot, gt_ans = None, example['target']
71 | else:
72 | raise NotImplementedError(data_name)
73 | # post process
74 | gt_cot = str(gt_cot).strip()
75 | gt_ans = strip_string(gt_ans)
76 | return gt_cot, gt_ans
77 |
78 |
79 | def parse_question(example, data_name):
80 | question = ""
81 | if data_name == "asdiv":
82 | question = f"{example['body'].strip()} {example['question'].strip()}"
83 | elif data_name == "svamp":
84 | body = example["Body"].strip()
85 | if not body.endswith("."):
86 | body = body + "."
87 | question = f'{body} {example["Question"].strip()}'
88 | elif data_name == "tabmwp":
89 | title_str = f'regarding "{example["table_title"]}" ' if example['table_title'] else ""
90 | question = f'Read the following table {title_str}and answer a question:\n'
91 | question += f'{example["table"]}\n{example["question"]}'
92 | if example['choices']:
93 | question += f' Please select from the following options: {example["choices"]}'
94 | else:
95 | for key in ['question', 'problem', 'Question', 'input']:
96 | if key in example:
97 | question = example[key]
98 | break
99 | assert question != ""
100 | return question.strip()
101 |
102 |
103 | def run_execute(executor, result, prompt_type, execute=False):
104 | if not result or result == 'error':
105 | return None, None
106 | report = None
107 |
108 | if "program_only" in prompt_type:
109 | prediction = extract_program_output(result)
110 | elif prompt_type in ["pot", "pal"] and execute:
111 | code = extract_program(result)
112 | prediction, report = executor.apply(code)
113 | else:
114 | prediction = extract_answer(result)
115 |
116 | prediction = strip_string(prediction)
117 | return prediction, report
118 |
119 |
120 | def parse_digits(num):
121 | # format: 234.23 || 23%
122 | num = regex.sub(',', '', str(num))
123 | try:
124 | return float(num)
125 | except Exception:
126 | if num.endswith('%'):
127 | num = num[:-1]
128 | if num.endswith('\\'):
129 | num = num[:-1]
130 | try:
131 | return float(num) / 100
132 | except Exception:
133 | pass
134 | return None
135 |
136 | def is_digit(num):
137 | # paired with parse_digits
138 | return parse_digits(num) is not None
139 |
140 |
141 | def normalize_prediction(prediction):
142 | try: # 1. numerical equal
143 | if is_digit(prediction):
144 | prediction = np.round(float(str(prediction).replace(",", "")), 6)
145 | return str(prediction)
146 | except Exception:
147 | pass
148 |
149 | # 2. symbolic equal
150 | prediction = str(prediction).strip()
151 |
152 | ## deal with [], (), {}
153 | brackets = []
154 |     while (prediction.startswith("[") and prediction.endswith("]")) or (prediction.startswith("(") and prediction.endswith(")")):
155 |         brackets.append(prediction[0]); prediction = prediction[1:-1]  # record stripped bracket so it can be restored below
156 | if brackets and ',' in prediction:
157 | pred_parts = [normalize_prediction(part) for part in prediction.split(",")]
158 | prediction = ",".join(pred_parts)
159 |
160 | if brackets:
161 | for b in reversed(brackets):
162 | if b == '[':
163 | prediction = '[' + prediction + ']'
164 | else:
165 | assert b == '('
166 | prediction = '(' + prediction + ')'
167 |
168 | def _parse(s):
169 | for f in [parse_latex, parse_expr]:
170 | try:
171 | return f(s)
172 | except Exception:
173 | pass
174 | return s
175 |
176 | prediction = _parse(prediction)
177 |
178 | for s in ['{', "}", "(", ")"]:
179 | prediction = prediction.replace(s, "")
180 |
181 | return prediction
182 |
183 |
184 | def math_equal(prediction: Union[bool, float, str],
185 | reference: Union[float, str],
186 | include_percentage: bool = True,
187 | is_close: bool = True,
188 | timeout: bool = False,
189 | ) -> bool:
190 | """
191 | Exact match of math if and only if:
192 | 1. numerical equal: both can convert to float and are equal
193 | 2. symbolic equal: both can convert to sympy expression and are equal
194 | """
195 | if str(prediction) == str(reference):
196 | return True
197 |
198 |     if "^{216}" in str(prediction):  # known bad-parse pattern; bail out early
199 |         return False
200 |
204 |
205 | try: # 1. numerical equal
206 | if is_digit(prediction) and is_digit(reference):
207 | prediction = parse_digits(prediction)
208 | reference = parse_digits(reference)
209 | # number questions
210 | if include_percentage:
211 | gt_result = [reference / 100, reference, reference * 100]
212 | else:
213 | gt_result = [reference]
214 | for item in gt_result:
215 | try:
216 | if is_close:
217 | if isclose(item, prediction, abs_tol=1e-3):
218 | return True
219 | else:
220 | if item == prediction:
221 | return True
222 | except Exception:
223 | continue
224 | return False
225 | except Exception:
226 | pass
227 |
228 | if not prediction and prediction not in [0, False]:
229 | return False
230 |
231 | # 2. symbolic equal
232 | reference = str(reference).strip()
233 | prediction = str(prediction).strip()
234 |
235 | if regex.match(r'(\(|\[).+(\)|\])', prediction) is not None and regex.match(r'(\(|\[).+(\)|\])', reference) is not None:
236 | pred_parts = prediction[1:-1].split(",")
237 | ref_parts = reference[1:-1].split(",")
238 | if len(pred_parts) == len(ref_parts):
239 | if all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]):
240 | return True
241 |
242 | if (prediction.startswith("\\begin{pmatrix}") or prediction.startswith("\\begin{bmatrix}")) and (prediction.endswith("\\end{pmatrix}") or prediction.endswith("\\end{bmatrix}")) and \
243 | (reference.startswith("\\begin{pmatrix}") or reference.startswith("\\begin{bmatrix}")) and (reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")):
244 | pred_lines = [line.strip() for line in prediction[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()]
245 | ref_lines = [line.strip() for line in reference[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()]
246 | matched = True
247 | if len(pred_lines) == len(ref_lines):
248 | for pred_line, ref_line in zip(pred_lines, ref_lines):
249 | pred_parts = pred_line.split("&")
250 | ref_parts = ref_line.split("&")
251 | if len(pred_parts) == len(ref_parts):
252 | if not all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]):
253 | matched = False
254 | break
255 | else:
256 | matched = False
257 | if not matched:
258 | break
259 | else:
260 | matched = False
261 | if matched:
262 | return True
263 |
264 | if prediction.count('=') == 1 and reference.count('=') == 1:
265 | pred = prediction.split('=')
266 | pred = f"{pred[0].strip()} - ({pred[1].strip()})"
267 | ref = reference.split('=')
268 | ref = f"{ref[0].strip()} - ({ref[1].strip()})"
269 | if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
270 | return True
271 | elif prediction.count('=') == 1 and len(prediction.split('=')[0].strip()) <= 2 and '=' not in reference:
272 | if math_equal(prediction.split('=')[1], reference, include_percentage, is_close):
273 | return True
274 | elif reference.count('=') == 1 and len(reference.split('=')[0].strip()) <= 2 and '=' not in prediction:
275 | if math_equal(prediction, reference.split('=')[1], include_percentage, is_close):
276 | return True
277 |
278 | # symbolic equal with sympy
279 | if timeout:
280 | if call_with_timeout(symbolic_equal_process, prediction, reference):
281 | return True
282 | else:
283 | if symbolic_equal(prediction, reference):
284 | return True
285 |
286 | return False
287 |
288 |
289 | def math_equal_process(param):
290 | return math_equal(param[-2], param[-1])
291 |
292 |
293 | def symbolic_equal(a, b):
294 | def _parse(s):
295 | for f in [parse_latex, parse_expr]:
296 | try:
297 | return f(s)
298 | except Exception:
299 | pass
300 | return s
301 | a = _parse(a)
302 | b = _parse(b)
303 |
304 | try:
305 | if simplify(a-b) == 0:
306 | return True
307 | except Exception:
308 | pass
309 |
310 | try:
311 | if isclose(N(a), N(b), abs_tol=1e-3):
312 | return True
313 | except Exception:
314 | pass
315 | return False
316 |
317 |
318 | def symbolic_equal_process(a, b, output_queue):
319 | result = symbolic_equal(a, b)
320 | output_queue.put(result)
321 |
322 |
323 | def call_with_timeout(func, *args, timeout=1, **kwargs):
324 | output_queue = multiprocessing.Queue()
325 | process_args = args + (output_queue,)
326 | process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
327 | process.start()
328 | process.join(timeout)
329 |
330 | if process.is_alive():
331 | process.terminate()
332 | process.join()
333 | return False
334 |
335 | return output_queue.get()
336 |
--------------------------------------------------------------------------------
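`math_equal` first tries numerical equality via `parse_digits`, which strips thousands separators and interprets a trailing percent sign as division by 100. A simplified standalone sketch of that contract (omitting the trailing-backslash handling in the version above):

```python
import re

def parse_digits(num):
    """Simplified sketch of parse_digits: drop thousands separators,
    treat a trailing '%' as /100, return None when unparseable."""
    num = re.sub(r",", "", str(num))
    try:
        return float(num)
    except ValueError:
        if num.endswith("%"):
            try:
                return float(num[:-1]) / 100
            except ValueError:
                pass
    return None
```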
/evaluation/eval/ocwcourses_eval_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import signal
3 |
4 | import numpy as np
5 | import sympy
6 | from sympy.core.sympify import SympifyError
7 | from sympy.parsing.latex import parse_latex
8 |
9 | INVALID_ANSWER = "[invalidanswer]"
10 |
11 | class timeout:
12 | def __init__(self, seconds=1, error_message="Timeout"):
13 | self.seconds = seconds
14 | self.error_message = error_message
15 |
16 | def handle_timeout(self, signum, frame):
17 | raise TimeoutError(self.error_message)
18 |
19 | def __enter__(self):
20 | signal.signal(signal.SIGALRM, self.handle_timeout)
21 | signal.alarm(self.seconds)
22 |
23 | def __exit__(self, type, value, traceback):
24 | signal.alarm(0)
25 |
26 | def normalize_numeric(s):
27 | if s is None:
28 | return None
29 | for unit in [
30 | "eV",
31 | " \\mathrm{~kg} \\cdot \\mathrm{m} / \\mathrm{s}",
32 | " kg m/s",
33 | "kg*m/s",
34 | "kg",
35 | "m/s",
36 | "m / s",
37 | "m s^{-1}",
38 | "\\text{ m/s}",
39 | " \\mathrm{m/s}",
40 | " \\text{ m/s}",
41 | "g/mole",
42 | "g/mol",
43 | "\\mathrm{~g}",
44 | "\\mathrm{~g} / \\mathrm{mol}",
45 | "W",
46 | "erg/s",
47 | "years",
48 | "year",
49 | "cm",
50 | ]:
51 | s = s.replace(unit, "")
52 | s = s.strip()
53 | for maybe_unit in ["m", "s", "cm"]:
54 | s = s.replace("\\mathrm{" + maybe_unit + "}", "")
55 | s = s.replace("\\mathrm{~" + maybe_unit + "}", "")
56 | s = s.strip()
57 | s = s.strip("$")
58 | try:
59 | return float(eval(s))
60 | except Exception:
61 | try:
62 | expr = parse_latex(s)
63 | if expr.is_number:
64 | return float(expr)
65 | return INVALID_ANSWER
66 | except Exception:
67 | return INVALID_ANSWER
68 |
69 | def numeric_equality(n1, n2, threshold=0.01):
70 | if n1 is None or n2 is None:
71 | return False
72 | if np.isclose(n1, 0) or np.isclose(n2, 0) or np.isclose(n1 - n2, 0):
73 | return np.abs(n1 - n2) < threshold * (n1 + n2) / 2
74 | else:
75 | return np.isclose(n1, n2)
76 |
77 | def normalize_symbolic_equation(s):
78 | if not isinstance(s, str):
79 | return INVALID_ANSWER
80 | if s.startswith("\\["):
81 | s = s[2:]
82 | if s.endswith("\\]"):
83 | s = s[:-2]
84 | s = s.replace("\\left(", "(")
85 | s = s.replace("\\right)", ")")
86 | s = s.replace("\\\\", "\\")
87 | if s.startswith("$") or s.endswith("$"):
88 | s = s.strip("$")
89 | try:
90 | maybe_expression = parse_latex(s)
91 | if not isinstance(maybe_expression, sympy.core.relational.Equality):
92 |             # we want an equation, not a bare expression
93 | return INVALID_ANSWER
94 | else:
95 | return maybe_expression
96 | except Exception:
97 | return INVALID_ANSWER
98 |
99 | class SymbolicMathMixin:
100 | """
101 | Methods useful for parsing mathematical expressions from text and determining equivalence of expressions.
102 | """
103 |
104 | SUBSTITUTIONS = [ # used for text normalize
105 | ("an ", ""),
106 | ("a ", ""),
107 | (".$", "$"),
108 | ("\\$", ""),
109 | (r"\ ", ""),
110 | (" ", ""),
111 | ("mbox", "text"),
112 | (",\\text{and}", ","),
113 | ("\\text{and}", ","),
114 | ("\\text{m}", "\\text{}"),
115 | ]
116 | REMOVED_EXPRESSIONS = [ # used for text normalizer
117 | "square",
118 | "ways",
119 | "integers",
120 | "dollars",
121 | "mph",
122 | "inches",
123 | "ft",
124 | "hours",
125 | "km",
126 | "units",
127 | "\\ldots",
128 | "sue",
129 | "points",
130 | "feet",
131 | "minutes",
132 | "digits",
133 | "cents",
134 | "degrees",
135 | "cm",
136 | "gm",
137 | "pounds",
138 | "meters",
139 | "meals",
140 | "edges",
141 | "students",
142 | "childrentickets",
143 | "multiples",
144 | "\\text{s}",
145 | "\\text{.}",
146 | "\\text{\ns}",
147 | "\\text{}^2",
148 | "\\text{}^3",
149 | "\\text{\n}",
150 | "\\text{}",
151 | r"\mathrm{th}",
152 | r"^\circ",
153 | r"^{\circ}",
154 | r"\;",
155 | r",\!",
156 | "{,}",
157 | '"',
158 | "\\dots",
159 | ]
160 |
161 | def normalize_tex(self, final_answer: str) -> str:
162 | """
163 | Normalizes a string representing a mathematical expression.
164 | Used as a preprocessing step before parsing methods.
165 |
166 | Copied character for character from appendix D of Lewkowycz et al. (2022)
167 | """
168 | final_answer = final_answer.split("=")[-1]
169 |
170 | for before, after in self.SUBSTITUTIONS:
171 | final_answer = final_answer.replace(before, after)
172 | for expr in self.REMOVED_EXPRESSIONS:
173 | final_answer = final_answer.replace(expr, "")
174 |
175 | # Extract answer that is in LaTeX math, is bold,
176 | # is surrounded by a box, etc.
177 | final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
178 | final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
179 | final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
180 | final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
181 | final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
182 |
183 | # Normalize shorthand TeX:
184 | # \fracab -> \frac{a}{b}
185 | # \frac{abc}{bef} -> \frac{abc}{bef}
186 | # \fracabc -> \frac{a}{b}c
187 | # \sqrta -> \sqrt{a}
188 |     # \sqrtab -> \sqrt{a}b
189 | final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
190 | final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
191 | final_answer = final_answer.replace("$", "")
192 |
193 | # Normalize 100,000 -> 100000
194 | if final_answer.replace(",", "").isdigit():
195 | final_answer = final_answer.replace(",", "")
196 |
197 | return final_answer
198 |
199 | def parse_tex(self, text: str, time_limit: int = 5) -> sympy.Basic:
200 | """
201 |         Wrapper around `parse_latex` that outputs a SymPy expression.
202 |         Typically, you want to apply `normalize_tex` as a preprocessing step.
203 | """
204 | try:
205 | with timeout(seconds=time_limit):
206 | parsed = parse_latex(text)
207 | except (
208 | # general error handling: there is a long tail of possible sympy/other
209 | # errors we would like to catch
210 | Exception
211 | ) as e:
212 | print(f"failed to parse {text} with exception {e}")
213 | return None
214 |
215 | return parsed
216 |
217 | def is_exp_equiv(self, x1: sympy.Basic, x2: sympy.Basic, time_limit=5) -> bool:
218 | """
219 | Determines whether two sympy expressions are equal.
220 | """
221 | try:
222 | with timeout(seconds=time_limit):
223 | try:
224 | diff = x1 - x2
225 | except (SympifyError, ValueError, TypeError) as e:
226 | print(
227 | f"Couldn't subtract {x1} and {x2} with exception {e}"
228 | )
229 | return False
230 |
231 | try:
232 | if sympy.simplify(diff) == 0:
233 | return True
234 | else:
235 | return False
236 | except (SympifyError, ValueError, TypeError) as e:
237 | print(f"Failed to simplify {x1}-{x2} with {e}")
238 | return False
239 | except TimeoutError:
240 | print(f"Timed out comparing {x1} and {x2}")
241 | return False
242 | except Exception as e:
243 | print(f"Failed on unrecognized exception: {e}")
244 | return False
245 |
246 | def is_tex_equiv(self, x1: str, x2: str, time_limit=5) -> bool:
247 | """
248 |         Determines whether two (ideally normalized using `normalize_tex`) TeX expressions are equal.
249 |
250 | Does so by first checking for string exact-match, then falls back on sympy-equivalence,
251 | following the (Lewkowycz et al. 2022) methodology.
252 | """
253 |         if x1 == x2:
254 |             # don't resort to sympy if we have full string match, post-normalization
255 |             return True
258 | parsed_x2 = self.parse_tex(x2)
259 | if not parsed_x2:
260 | # if our reference fails to parse into a Sympy object,
261 | # we forgo parsing + checking our generated answer.
262 | return False
263 | return self.is_exp_equiv(self.parse_tex(x1), parsed_x2, time_limit=time_limit)
264 |
--------------------------------------------------------------------------------
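The `timeout` context manager above guards slow SymPy calls with `SIGALRM`, which is Unix-only and has whole-second granularity. A self-contained sketch of the same pattern (renamed `alarm_timeout` here to avoid clashing with other `timeout` symbols in the repo):

```python
import signal

class alarm_timeout:
    """SIGALRM-based time limit, as in ocwcourses_eval_utils.timeout.
    Unix-only; whole-second granularity."""
    def __init__(self, seconds=1, error_message="Timeout"):
        self.seconds = seconds
        self.error_message = error_message

    def _handle(self, signum, frame):
        raise TimeoutError(self.error_message)

    def __enter__(self):
        # install the handler and arm the alarm
        signal.signal(signal.SIGALRM, self._handle)
        signal.alarm(self.seconds)

    def __exit__(self, exc_type, exc, tb):
        # disarm the alarm on exit, whether or not an exception fired
        signal.alarm(0)
```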
/evaluation/eval/python_executor.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import io
3 | import pickle
4 | import traceback
5 | from concurrent.futures import TimeoutError
6 | from contextlib import redirect_stdout
7 | from functools import partial
8 | from typing import Any, Dict, Optional
9 |
10 | import multiprocess
11 | import regex
12 | from pebble import ProcessPool
13 | from timeout_decorator import timeout
14 |
15 |
16 | class GenericRuntime:
17 | GLOBAL_DICT = {}
18 | LOCAL_DICT = None
19 | HEADERS = []
20 | def __init__(self):
21 | self._global_vars = copy.copy(self.GLOBAL_DICT)
22 | self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None
23 |
24 | for c in self.HEADERS:
25 | self.exec_code(c)
26 |
27 | def exec_code(self, code_piece: str) -> None:
28 | if regex.search(r'(\s|^)?input\(', code_piece) or regex.search(r'(\s|^)?os.system\(', code_piece):
29 | raise RuntimeError()
30 | exec(code_piece, self._global_vars)
31 |
32 | def eval_code(self, expr: str) -> Any:
33 | return eval(expr, self._global_vars)
34 |
35 | def inject(self, var_dict: Dict[str, Any]) -> None:
36 | for k, v in var_dict.items():
37 | self._global_vars[k] = v
38 |
39 | @property
40 | def answer(self):
41 | return self._global_vars['answer']
42 |
43 | class PythonExecutor:
44 | def __init__(
45 | self,
46 | runtime: Optional[Any] = None,
47 | get_answer_symbol: Optional[str] = None,
48 | get_answer_expr: Optional[str] = None,
49 | get_answer_from_stdout: bool = False,
50 | ) -> None:
51 | self.runtime = runtime if runtime else GenericRuntime()
52 | self.answer_symbol = get_answer_symbol
53 | self.answer_expr = get_answer_expr
54 | self.get_answer_from_stdout = get_answer_from_stdout
55 |
56 | def process_generation_to_code(self, gens: str):
57 | batch_code = []
58 | for g in gens:
59 | multiline_comments = False
60 | code = []
61 | for line in g.split('\n'):
62 | strip_line = line.strip()
63 | if strip_line.startswith("#"):
64 | line = line.split("#", 1)[0] + "# comments"
65 | elif not multiline_comments and strip_line.startswith('"""') and strip_line.endswith('"""') and len(strip_line) >= 6:
66 | line = line.split('"""', 1)[0] + '"""comments"""'
67 | elif not multiline_comments and strip_line.startswith('"""'):
68 | multiline_comments = True
69 | elif multiline_comments and strip_line.endswith('"""'):
70 | multiline_comments = False
71 | line = ""
72 | if not multiline_comments:
73 | code.append(line)
74 | batch_code.append(code)
75 | return batch_code
76 |
77 | @staticmethod
78 | def execute(
79 | code,
80 | get_answer_from_stdout = None,
81 | runtime = None,
82 | answer_symbol = None,
83 | answer_expr = None,
84 | timeout_length = 10,
85 | ):
86 | try:
87 | if get_answer_from_stdout:
88 | program_io = io.StringIO()
89 | with redirect_stdout(program_io):
90 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
91 | program_io.seek(0)
92 |                 result = "".join(program_io.readlines())
93 | elif answer_symbol:
94 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
95 | result = runtime._global_vars[answer_symbol]
96 | elif answer_expr:
97 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
98 | result = timeout(timeout_length)(runtime.eval_code)(answer_expr)
99 | else:
100 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code[:-1]))
101 | result = timeout(timeout_length)(runtime.eval_code)(code[-1])
102 | concise_exec_info = ""
103 | exec_info = ""
104 | str(result)
105 | pickle.dumps(result) # serialization check
106 | except Exception:
107 | # traceback.print_exc()
108 | result = ''
109 | concise_exec_info = traceback.format_exc().split('\n')[-2]
110 | exec_info = traceback.format_exc()
111 | if get_answer_from_stdout and 'exec(code_piece, self._global_vars)' in exec_info:
112 | exec_info = exec_info.split('exec(code_piece, self._global_vars)')[-1].strip()
113 | msg = []
114 | for line in exec_info.split("\n"):
115 |             patt = regex.search(r'(?P<start>.*)File "(?P<file>.*)", line (?P<lineno>\d+), (?P<end>.*)', line)
116 |             if patt is not None:
117 |                 if '<module>' in patt.group('end'):
118 |                     continue
119 | fname = patt.group("file")
120 | if "site-packages" in fname:
121 | fname = f"site-packages{fname.split('site-packages', 1)[1]}"
122 | line = f'{patt.group("start")}File "{fname}", {patt.group("end")}'
123 | else:
124 | line = f'{patt.group("start")}{patt.group("end")}'
125 | else:
126 |                 patt = regex.search(r'(?P<start>.*)(?P<file>/.*site-packages/.*\.py)(?P<end>.*)', line)
127 | if patt is not None:
128 | line = f'{patt.group("start")}site-packages{patt.group("file").split("site-packages", 1)[1]}{patt.group("end")}'
129 | msg.append(line)
130 | exec_info = "\n".join(msg)
131 | return result, concise_exec_info, exec_info
132 |
133 | def apply(self, code):
134 | return self.batch_apply([code])[0]
135 |
136 | def batch_apply(self, batch_code):
137 | all_code_snippets = self.process_generation_to_code(batch_code)
138 | all_exec_results = []
139 | executor = partial(
140 | self.execute,
141 | get_answer_from_stdout=self.get_answer_from_stdout,
142 | runtime=self.runtime,
143 | answer_symbol=self.answer_symbol,
144 | answer_expr=self.answer_expr,
145 | timeout_length=10,
146 | )
147 | with ProcessPool(max_workers=multiprocess.cpu_count()) as pool:
148 | iterator = pool.map(executor, all_code_snippets, timeout=10).result()
149 |
150 | while True:
151 | try:
152 | result = next(iterator)
153 | all_exec_results.append(result)
154 | except StopIteration:
155 | break
156 | except TimeoutError:
157 | all_exec_results.append(("", "Timeout Error", "Timeout Error"))
158 | except Exception as error:
159 | print(error)
160 | exit()
161 |
162 | batch_results = []
163 | for code, (result, concise_exec_info, exec_info) in zip(all_code_snippets, all_exec_results):
164 | metadata = {'code': code, 'exec_result': result, 'concise_exec_info': concise_exec_info, 'exec_info': exec_info}
165 | batch_results.append((result, metadata))
166 | return batch_results
167 |
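The `batch_apply` method above submits each code snippet to a process pool and maps per-task timeouts and errors onto sentinel result tuples. That collection pattern can be sketched with the standard library alone; the helper below is a simplified illustration (using threads instead of processes, and a made-up name `run_with_timeouts`), not part of the repository:

```python
import concurrent.futures

def run_with_timeouts(tasks, timeout=5.0, max_workers=4):
    """Collect results task by task, mapping timeouts and exceptions to
    sentinel tuples, mirroring the (result, error_info) bookkeeping that
    batch_apply performs while draining its pool iterator."""
    results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(fn) for fn in tasks]
        for fut in futures:
            try:
                results.append((fut.result(timeout=timeout), ""))
            except concurrent.futures.TimeoutError:
                results.append(("", "Timeout Error"))
            except Exception as err:
                # Any other failure becomes an error string, so one bad
                # snippet cannot abort the whole batch.
                results.append(("", repr(err)))
    return results
```

As in `batch_apply`, the draining loop keeps positional alignment between inputs and outputs, which is what lets the caller later `zip` code snippets with their execution results.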
--------------------------------------------------------------------------------
/evaluation/eval/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import tqdm
3 | from transformers import GenerationConfig, StoppingCriteria
4 |
5 |
6 | class KeyWordsCriteria(StoppingCriteria):
7 | def __init__(self, stop_id_sequences, tokenizer, prompt_length):
8 | assert isinstance(stop_id_sequences[0], list), "stop_id_sequences should be a list of list of ids"
9 | self.tokenizer = tokenizer
10 | self.stop_id_sequences = stop_id_sequences
11 | self.stop_sequences = [tokenizer.decode(sequence) for sequence in stop_id_sequences]
12 | print(f"stop sequences: {self.stop_sequences}", flush=True)
13 | self.prompt_length = prompt_length
14 |
15 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
16 | sequences_should_be_stopped = []
17 | for i in range(input_ids.shape[0]):
18 | ids = input_ids[i][self.prompt_length:].tolist()
19 | should_be_stopped = False
20 | for stop_ids, stop_sequence in zip(self.stop_id_sequences, self.stop_sequences):
21 | _ids = ids
22 | for j in range(len(_ids), 0, -1):
23 | s = self.tokenizer.decode(_ids[max(j - len(stop_ids) - 3, 0):j])
24 | if s.endswith(stop_sequence):
25 | should_be_stopped = True
26 | break
27 | if should_be_stopped:
28 | break
29 | sequences_should_be_stopped.append(should_be_stopped)
30 | return all(sequences_should_be_stopped)
31 |
32 | @torch.no_grad()
33 | def generate_completions(model, tokenizer, prompts, batch_size=1, stop_id_sequences=None, end_of_generation_id_sequence=None, disable_tqdm=False, **generation_kwargs):
34 | generations = []
35 | finish_completion = []
36 | if not disable_tqdm:
37 | progress = tqdm.tqdm(total=len(prompts), desc="Generating Completions")
38 |
39 | if stop_id_sequences is not None:
40 | stop_sequences = [tokenizer.decode(stop_id_sequence) for stop_id_sequence in stop_id_sequences]
41 |
42 | if end_of_generation_id_sequence is not None:
43 | end_of_generation_sequence = tokenizer.decode(end_of_generation_id_sequence)
44 |
45 | num_return_sequences = generation_kwargs.get("num_return_sequences", 1)
46 | generation_kwargs['use_cache'] = True
47 | for i in range(0, len(prompts), batch_size):
48 | batch_prompts = prompts[i:i+batch_size]
49 | tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens='chatglm2' in str(model.__class__))
50 | batch_input_ids = tokenized_prompts.input_ids
51 | attention_mask = tokenized_prompts.attention_mask
52 |
53 | if model.device.type == "cuda":
54 | batch_input_ids = batch_input_ids.cuda()
55 | attention_mask = attention_mask.cuda()
56 |
57 | batch_finish_completion = [False] * len(batch_prompts) * num_return_sequences
58 | try:
59 | batch_outputs = model.generate(
60 | input_ids=batch_input_ids,
61 | attention_mask=attention_mask,
62 | stopping_criteria=[KeyWordsCriteria(stop_id_sequences, tokenizer, batch_input_ids.size(1))] if stop_id_sequences else None,
63 | **generation_kwargs
64 | )
65 |
66 | # the stopping criteria are applied at the batch level, so if other sequences in the batch have not stopped yet, the whole batch keeps generating.
67 | # as a result, some outputs may still contain the stop sequence, which we need to remove.
68 | if stop_id_sequences:
69 | for output_idx in range(batch_outputs.shape[0]):
70 | for token_idx in range(batch_input_ids.shape[1], batch_outputs.shape[1]):
71 | if any(tokenizer.decode(batch_outputs[output_idx, token_idx: token_idx + len(stop_sequence) + 3]).startswith(stop_sequence) for stop_sequence in stop_sequences):
72 | if end_of_generation_id_sequence is not None and tokenizer.decode(batch_outputs[output_idx, token_idx: token_idx + len(end_of_generation_id_sequence) + 3]).startswith(end_of_generation_sequence):
73 | batch_finish_completion[output_idx] = True
74 | batch_outputs[output_idx, token_idx:] = tokenizer.pad_token_id
75 | break
76 |
77 | # remove the prompt from the output
78 | # we need to re-encode the prompt because we need to make sure the special tokens are treated the same way as in the outputs.
79 | # we changed our previous way of truncating the output token ids directly because some tokenizers (e.g., llama) won't add a space token before the first token.
80 | # space is important for some tasks (e.g., code completion).
81 | batch_outputs = tokenizer.batch_decode(batch_outputs, skip_special_tokens=True)
82 | batch_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True)
83 | # duplicate the prompts to match the number of return sequences
84 | batch_prompts = [prompt for prompt in batch_prompts for _ in range(num_return_sequences)]
85 | batch_generations = [
86 | output[len(prompt):] for prompt, output in zip(batch_prompts, batch_outputs)
87 | ]
88 | except Exception as e:
89 | print("Error when generating completions for batch:")
90 | print(batch_prompts)
91 | print("Error message:")
92 | print(e)
93 | print("Use empty string as the completion.")
94 | batch_generations = [""] * len(batch_prompts) * num_return_sequences
95 |
96 | generations += batch_generations
97 | finish_completion += batch_finish_completion
98 |
99 | if not disable_tqdm:
100 | progress.update(len(batch_prompts)//num_return_sequences)
101 |
102 | assert len(generations) == len(prompts) * num_return_sequences, "number of generations should be equal to number of prompts * num_return_sequences"
103 | return generations, finish_completion
104 |
105 |
106 | @torch.no_grad()
107 | def get_next_word_predictions(model, tokenizer, prompts, candidate_token_ids=None, batch_size=1, return_token_predictions=False, disable_tqdm=False):
108 | predictions, probs = [], []
109 | if not disable_tqdm:
110 | progress = tqdm.tqdm(total=len(prompts), desc="Getting Predictions")
111 |
112 | for i in range(0, len(prompts), batch_size):
113 | batch_prompts = prompts[i: i+batch_size]
114 | tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=False)
115 | batch_input_ids = tokenized_prompts.input_ids
116 | attention_mask = tokenized_prompts.attention_mask
117 |
118 | if model.device.type == "cuda":
119 | batch_input_ids = batch_input_ids.cuda()
120 | attention_mask = attention_mask.cuda()
121 |
122 | batch_logits = model(input_ids=batch_input_ids, attention_mask=attention_mask).logits[:, -1, :]
123 | if candidate_token_ids is not None:
124 | batch_logits = batch_logits[:, candidate_token_ids]
125 | batch_probs = torch.softmax(batch_logits, dim=-1)
126 | batch_prediction_indices = torch.argmax(batch_probs, dim=-1)
127 | if return_token_predictions:
128 | if candidate_token_ids is not None:
129 | candidate_tokens = tokenizer.convert_ids_to_tokens(candidate_token_ids)
130 | batch_predictions = [candidate_tokens[idx] for idx in batch_prediction_indices]
131 | else:
132 | batch_predictions = tokenizer.convert_ids_to_tokens(batch_prediction_indices)
133 | predictions += batch_predictions
134 | else:
135 | predictions += batch_prediction_indices.tolist()
136 | probs += batch_probs.tolist()
137 |
138 | if not disable_tqdm:
139 | progress.update(len(batch_prompts))
140 |
141 | assert len(predictions) == len(prompts), "number of predictions should be equal to number of prompts"
142 | return predictions, probs
143 |
144 |
145 | @torch.no_grad()
146 | def score_completions(model, tokenizer, scoring_examples, disable_tqdm=False):
147 | '''
148 | Each scoring example is a dict, which contains the following keys:
149 | - prompt: the prompt to score
150 | - completions: a list of completions to score
151 | '''
152 |
153 | if not disable_tqdm:
154 | progress = tqdm.tqdm(total=len(scoring_examples), desc="Scoring Completions")
155 |
156 | # unroll the scoring examples
157 | unrolled_examples = []
158 | for scoring_example in scoring_examples:
159 | prompt = scoring_example["prompt"]
160 | for completion in scoring_example["completions"]:
161 | unrolled_examples.append({
162 | "prompt": prompt,
163 | "completion": completion
164 | })
165 |
166 | scores = []
167 | # currently we don't support batching, because we want to directly use the loss returned by the model to score each completion.
168 | for unrolled_example in unrolled_examples:
169 | encoded_example = encode_with_prompt_completion_format(unrolled_example, tokenizer, max_seq_length=None)  # NOTE: this helper is not defined or imported in this file; it must be provided elsewhere
170 | # unsqueeze the batch dimension
171 | for key, value in encoded_example.items():
172 | encoded_example[key] = value.unsqueeze(0)
173 | if model.device.type == "cuda":
174 | encoded_example = {
175 | key: value.cuda() for key, value in encoded_example.items()
176 | }
177 | outputs = model(**encoded_example)
178 | loss = outputs.loss
179 | scores.append(-loss.item())
180 | if not disable_tqdm:
181 | progress.update(1)
182 |
183 | # roll up the scores
184 | rolled_up_scores = {}
185 | for unrolled_example, score in zip(unrolled_examples, scores):
186 | prompt = unrolled_example["prompt"]
187 | completion = unrolled_example["completion"]
188 | if prompt not in rolled_up_scores:
189 | rolled_up_scores[prompt] = {}
190 | rolled_up_scores[prompt][completion] = score
191 |
192 | return rolled_up_scores
193 |
194 |
195 |
196 | def load_hf_lm_and_tokenizer(
197 | model_name_or_path,
198 | tokenizer_name_or_path=None,
199 | device_map="auto",
200 | load_in_8bit=False,
201 | load_in_half=False,
202 | gptq_model=False,
203 | use_fast_tokenizer=True,
204 | padding_side="left",
205 | ):
206 |
207 | from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
208 |
209 | if not tokenizer_name_or_path:
210 | tokenizer_name_or_path = model_name_or_path
211 |
212 | is_chatglm2 = 'chatglm2' in tokenizer_name_or_path.lower() or 'chatglm2' in model_name_or_path.lower()
213 | is_qwen = 'qwen' in tokenizer_name_or_path.lower() or 'qwen' in model_name_or_path.lower()
214 |
215 | if is_chatglm2 or is_qwen:
216 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True)
217 | if is_qwen:
218 | tokenizer.eos_token = '<|endoftext|>'
219 | tokenizer.eos_token_id = 151643
220 | tokenizer.pad_token = tokenizer.eos_token
221 | tokenizer.pad_token_id = tokenizer.eos_token_id
222 | else:
223 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True, use_fast=use_fast_tokenizer)
224 | # set padding side to left for batch generation
225 | tokenizer.padding_side = padding_side
226 | # set pad token to eos token if pad token is not set (as is the case for llama models)
227 | if tokenizer.pad_token is None:
228 | tokenizer.pad_token = tokenizer.eos_token
229 | tokenizer.pad_token_id = tokenizer.eos_token_id
230 |
231 | if gptq_model:
232 | from auto_gptq import AutoGPTQForCausalLM
233 | model_wrapper = AutoGPTQForCausalLM.from_quantized(
234 | model_name_or_path, device="cuda:0", use_triton=True
235 | )
236 | model = model_wrapper.model
237 | elif load_in_8bit:
238 | model = AutoModelForCausalLM.from_pretrained(
239 | model_name_or_path,
240 | device_map=device_map,
241 | load_in_8bit=True
242 | )
243 | else:
244 | kwargs = {}
245 | model_class = AutoModelForCausalLM
246 | if is_chatglm2:
247 | kwargs = {'trust_remote_code': True}
248 | model_class = AutoModel
249 | elif is_qwen:
250 | kwargs = {'trust_remote_code': True}
251 | if device_map:
252 | model = model_class.from_pretrained(model_name_or_path, device_map=device_map, **kwargs)
253 | else:
254 | model = model_class.from_pretrained(model_name_or_path, **kwargs)
255 | if torch.cuda.is_available():
256 | model = model.cuda()
257 | if is_qwen:
258 | model.generation_config = GenerationConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
259 | model.generation_config.do_sample = False
260 | if not is_chatglm2 and not is_qwen and load_in_half:
261 | model = model.half()
262 | model.eval()
263 | return model, tokenizer
264 |
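The trailing-window scan inside `KeyWordsCriteria.__call__` (decode the last `len(stop_ids) + 3` tokens at each position and test the text suffix, so tokenizer merges across the stop boundary are still caught) can be illustrated without a real tokenizer. The toy character-level `decode` below is an assumption standing in for `tokenizer.decode`, and `ends_with_stop` is a hypothetical name for the extracted check:

```python
def ends_with_stop(ids, stop_ids, decode):
    """Mirror KeyWordsCriteria's per-sequence scan: at each end position j,
    decode a small trailing window (stop length plus 3 tokens of slack) and
    test whether its text ends with the stop string."""
    stop_text = decode(stop_ids)
    for j in range(len(ids), 0, -1):
        window = decode(ids[max(j - len(stop_ids) - 3, 0):j])
        if window.endswith(stop_text):
            return True
    return False

# Toy "tokenizer": one id per character.
decode = lambda token_ids: "".join(chr(i) for i in token_ids)
```

Checking the decoded text, rather than comparing raw id subsequences, is what makes the criterion robust to tokenizers that encode the same stop string with different token boundaries depending on context.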
--------------------------------------------------------------------------------
/imgs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/Step-DPO/1f504ead5004f252025cb234017dfd9897cd542c/imgs/.DS_Store
--------------------------------------------------------------------------------
/imgs/coreidea.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/Step-DPO/1f504ead5004f252025cb234017dfd9897cd542c/imgs/coreidea.png
--------------------------------------------------------------------------------
/imgs/example1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/Step-DPO/1f504ead5004f252025cb234017dfd9897cd542c/imgs/example1.png
--------------------------------------------------------------------------------
/imgs/example2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/Step-DPO/1f504ead5004f252025cb234017dfd9897cd542c/imgs/example2.png
--------------------------------------------------------------------------------
/imgs/example3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/Step-DPO/1f504ead5004f252025cb234017dfd9897cd542c/imgs/example3.png
--------------------------------------------------------------------------------
/imgs/example4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/Step-DPO/1f504ead5004f252025cb234017dfd9897cd542c/imgs/example4.png
--------------------------------------------------------------------------------
/imgs/example5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/Step-DPO/1f504ead5004f252025cb234017dfd9897cd542c/imgs/example5.jpg
--------------------------------------------------------------------------------
/imgs/summary.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/Step-DPO/1f504ead5004f252025cb234017dfd9897cd542c/imgs/summary.jpg
--------------------------------------------------------------------------------
/imgs/triangle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/Step-DPO/1f504ead5004f252025cb234017dfd9897cd542c/imgs/triangle.png
--------------------------------------------------------------------------------
/licenses/DATA_LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial 4.0 International Public
58 | License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial 4.0 International Public License ("Public
63 | License"). To the extent this Public License may be interpreted as a
64 | contract, You are granted the Licensed Rights in consideration of Your
65 | acceptance of these terms and conditions, and the Licensor grants You
66 | such rights in consideration of benefits the Licensor receives from
67 | making the Licensed Material available under these terms and
68 | conditions.
69 |
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Adapter's License means the license You apply to Your Copyright
84 | and Similar Rights in Your contributions to Adapted Material in
85 | accordance with the terms and conditions of this Public License.
86 |
87 | c. Copyright and Similar Rights means copyright and/or similar rights
88 | closely related to copyright including, without limitation,
89 | performance, broadcast, sound recording, and Sui Generis Database
90 | Rights, without regard to how the rights are labeled or
91 | categorized. For purposes of this Public License, the rights
92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
93 | Rights.
94 | d. Effective Technological Measures means those measures that, in the
95 | absence of proper authority, may not be circumvented under laws
96 | fulfilling obligations under Article 11 of the WIPO Copyright
97 | Treaty adopted on December 20, 1996, and/or similar international
98 | agreements.
99 |
100 | e. Exceptions and Limitations means fair use, fair dealing, and/or
101 | any other exception or limitation to Copyright and Similar Rights
102 | that applies to Your use of the Licensed Material.
103 |
104 | f. Licensed Material means the artistic or literary work, database,
105 | or other material to which the Licensor applied this Public
106 | License.
107 |
108 | g. Licensed Rights means the rights granted to You subject to the
109 | terms and conditions of this Public License, which are limited to
110 | all Copyright and Similar Rights that apply to Your use of the
111 | Licensed Material and that the Licensor has authority to license.
112 |
113 | h. Licensor means the individual(s) or entity(ies) granting rights
114 | under this Public License.
115 |
116 | i. NonCommercial means not primarily intended for or directed towards
117 | commercial advantage or monetary compensation. For purposes of
118 | this Public License, the exchange of the Licensed Material for
119 | other material subject to Copyright and Similar Rights by digital
120 | file-sharing or similar means is NonCommercial provided there is
121 | no payment of monetary compensation in connection with the
122 | exchange.
123 |
124 | j. Share means to provide material to the public by any means or
125 | process that requires permission under the Licensed Rights, such
126 | as reproduction, public display, public performance, distribution,
127 | dissemination, communication, or importation, and to make material
128 | available to the public including in ways that members of the
129 | public may access the material from a place and at a time
130 | individually chosen by them.
131 |
132 | k. Sui Generis Database Rights means rights other than copyright
133 | resulting from Directive 96/9/EC of the European Parliament and of
134 | the Council of 11 March 1996 on the legal protection of databases,
135 | as amended and/or succeeded, as well as other essentially
136 | equivalent rights anywhere in the world.
137 |
138 | l. You means the individual or entity exercising the Licensed Rights
139 | under this Public License. Your has a corresponding meaning.
140 |
141 |
142 | Section 2 -- Scope.
143 |
144 | a. License grant.
145 |
146 | 1. Subject to the terms and conditions of this Public License,
147 | the Licensor hereby grants You a worldwide, royalty-free,
148 | non-sublicensable, non-exclusive, irrevocable license to
149 | exercise the Licensed Rights in the Licensed Material to:
150 |
151 | a. reproduce and Share the Licensed Material, in whole or
152 | in part, for NonCommercial purposes only; and
153 |
154 | b. produce, reproduce, and Share Adapted Material for
155 | NonCommercial purposes only.
156 |
157 | 2. Exceptions and Limitations. For the avoidance of doubt, where
158 | Exceptions and Limitations apply to Your use, this Public
159 | License does not apply, and You do not need to comply with
160 | its terms and conditions.
161 |
162 | 3. Term. The term of this Public License is specified in Section
163 | 6(a).
164 |
165 | 4. Media and formats; technical modifications allowed. The
166 | Licensor authorizes You to exercise the Licensed Rights in
167 | all media and formats whether now known or hereafter created,
168 | and to make technical modifications necessary to do so. The
169 | Licensor waives and/or agrees not to assert any right or
170 | authority to forbid You from making technical modifications
171 | necessary to exercise the Licensed Rights, including
172 | technical modifications necessary to circumvent Effective
173 | Technological Measures. For purposes of this Public License,
174 | simply making modifications authorized by this Section 2(a)
175 | (4) never produces Adapted Material.
176 |
177 | 5. Downstream recipients.
178 |
179 | a. Offer from the Licensor -- Licensed Material. Every
180 | recipient of the Licensed Material automatically
181 | receives an offer from the Licensor to exercise the
182 | Licensed Rights under the terms and conditions of this
183 | Public License.
184 |
185 | b. No downstream restrictions. You may not offer or impose
186 | any additional or different terms or conditions on, or
187 | apply any Effective Technological Measures to, the
188 | Licensed Material if doing so restricts exercise of the
189 | Licensed Rights by any recipient of the Licensed
190 | Material.
191 |
192 | 6. No endorsement. Nothing in this Public License constitutes or
193 | may be construed as permission to assert or imply that You
194 | are, or that Your use of the Licensed Material is, connected
195 | with, or sponsored, endorsed, or granted official status by,
196 | the Licensor or others designated to receive attribution as
197 | provided in Section 3(a)(1)(A)(i).
198 |
199 | b. Other rights.
200 |
201 | 1. Moral rights, such as the right of integrity, are not
202 | licensed under this Public License, nor are publicity,
203 | privacy, and/or other similar personality rights; however, to
204 | the extent possible, the Licensor waives and/or agrees not to
205 | assert any such rights held by the Licensor to the limited
206 | extent necessary to allow You to exercise the Licensed
207 | Rights, but not otherwise.
208 |
209 | 2. Patent and trademark rights are not licensed under this
210 | Public License.
211 |
212 | 3. To the extent possible, the Licensor waives any right to
213 | collect royalties from You for the exercise of the Licensed
214 | Rights, whether directly or through a collecting society
215 | under any voluntary or waivable statutory or compulsory
216 | licensing scheme. In all other cases the Licensor expressly
217 | reserves any right to collect such royalties, including when
218 | the Licensed Material is used other than for NonCommercial
219 | purposes.
220 |
221 |
222 | Section 3 -- License Conditions.
223 |
224 | Your exercise of the Licensed Rights is expressly made subject to the
225 | following conditions.
226 |
227 | a. Attribution.
228 |
229 | 1. If You Share the Licensed Material (including in modified
230 | form), You must:
231 |
232 | a. retain the following if it is supplied by the Licensor
233 | with the Licensed Material:
234 |
235 | i. identification of the creator(s) of the Licensed
236 | Material and any others designated to receive
237 | attribution, in any reasonable manner requested by
238 | the Licensor (including by pseudonym if
239 | designated);
240 |
241 | ii. a copyright notice;
242 |
243 | iii. a notice that refers to this Public License;
244 |
245 | iv. a notice that refers to the disclaimer of
246 | warranties;
247 |
248 | v. a URI or hyperlink to the Licensed Material to the
249 | extent reasonably practicable;
250 |
251 | b. indicate if You modified the Licensed Material and
252 | retain an indication of any previous modifications; and
253 |
254 | c. indicate the Licensed Material is licensed under this
255 | Public License, and include the text of, or the URI or
256 | hyperlink to, this Public License.
257 |
258 | 2. You may satisfy the conditions in Section 3(a)(1) in any
259 | reasonable manner based on the medium, means, and context in
260 | which You Share the Licensed Material. For example, it may be
261 | reasonable to satisfy the conditions by providing a URI or
262 | hyperlink to a resource that includes the required
263 | information.
264 |
265 | 3. If requested by the Licensor, You must remove any of the
266 | information required by Section 3(a)(1)(A) to the extent
267 | reasonably practicable.
268 |
269 | 4. If You Share Adapted Material You produce, the Adapter's
270 | License You apply must not prevent recipients of the Adapted
271 | Material from complying with this Public License.
272 |
273 |
274 | Section 4 -- Sui Generis Database Rights.
275 |
276 | Where the Licensed Rights include Sui Generis Database Rights that
277 | apply to Your use of the Licensed Material:
278 |
279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
280 | to extract, reuse, reproduce, and Share all or a substantial
281 | portion of the contents of the database for NonCommercial purposes
282 | only;
283 |
284 | b. if You include all or a substantial portion of the database
285 | contents in a database in which You have Sui Generis Database
286 | Rights, then the database in which You have Sui Generis Database
287 | Rights (but not its individual contents) is Adapted Material; and
288 |
289 | c. You must comply with the conditions in Section 3(a) if You Share
290 | all or a substantial portion of the contents of the database.
291 |
292 | For the avoidance of doubt, this Section 4 supplements and does not
293 | replace Your obligations under this Public License where the Licensed
294 | Rights include other Copyright and Similar Rights.
295 |
296 |
297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
298 |
299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
309 |
310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
319 |
320 | c. The disclaimer of warranties and limitation of liability provided
321 | above shall be interpreted in a manner that, to the extent
322 | possible, most closely approximates an absolute disclaimer and
323 | waiver of all liability.
324 |
325 |
326 | Section 6 -- Term and Termination.
327 |
328 | a. This Public License applies for the term of the Copyright and
329 | Similar Rights licensed here. However, if You fail to comply with
330 | this Public License, then Your rights under this Public License
331 | terminate automatically.
332 |
333 | b. Where Your right to use the Licensed Material has terminated under
334 | Section 6(a), it reinstates:
335 |
336 | 1. automatically as of the date the violation is cured, provided
337 | it is cured within 30 days of Your discovery of the
338 | violation; or
339 |
340 | 2. upon express reinstatement by the Licensor.
341 |
342 | For the avoidance of doubt, this Section 6(b) does not affect any
343 | right the Licensor may have to seek remedies for Your violations
344 | of this Public License.
345 |
346 | c. For the avoidance of doubt, the Licensor may also offer the
347 | Licensed Material under separate terms or conditions or stop
348 | distributing the Licensed Material at any time; however, doing so
349 | will not terminate this Public License.
350 |
351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
352 | License.
353 |
354 |
355 | Section 7 -- Other Terms and Conditions.
356 |
357 | a. The Licensor shall not be bound by any additional or different
358 | terms or conditions communicated by You unless expressly agreed.
359 |
360 | b. Any arrangements, understandings, or agreements regarding the
361 | Licensed Material not stated herein are separate from and
362 | independent of the terms and conditions of this Public License.
363 |
364 |
365 | Section 8 -- Interpretation.
366 |
367 | a. For the avoidance of doubt, this Public License does not, and
368 | shall not be interpreted to, reduce, limit, restrict, or impose
369 | conditions on any use of the Licensed Material that could lawfully
370 | be made without permission under this Public License.
371 |
372 | b. To the extent possible, if any provision of this Public License is
373 | deemed unenforceable, it shall be automatically reformed to the
374 | minimum extent necessary to make it enforceable. If the provision
375 | cannot be reformed, it shall be severed from this Public License
376 | without affecting the enforceability of the remaining terms and
377 | conditions.
378 |
379 | c. No term or condition of this Public License will be waived and no
380 | failure to comply consented to unless expressly agreed to by the
381 | Licensor.
382 |
383 | d. Nothing in this Public License constitutes or may be interpreted
384 | as a limitation upon, or waiver of, any privileges and immunities
385 | that apply to the Licensor or You, including from the legal
386 | processes of any jurisdiction or authority.
387 |
388 | =======================================================================
389 |
390 | Creative Commons is not a party to its public
391 | licenses. Notwithstanding, Creative Commons may elect to apply one of
392 | its public licenses to material it publishes and in those instances
393 | will be considered the “Licensor.” The text of the Creative Commons
394 | public licenses is dedicated to the public domain under the CC0 Public
395 | Domain Dedication. Except for the limited purpose of indicating that
396 | material is shared under a Creative Commons public license or as
397 | otherwise permitted by the Creative Commons policies published at
398 | creativecommons.org/policies, Creative Commons does not authorize the
399 | use of the trademark "Creative Commons" or any other trademark or logo
400 | of Creative Commons without its prior written consent including,
401 | without limitation, in connection with any unauthorized modifications
402 | to any of its public licenses or any other arrangements,
403 | understandings, or agreements concerning use of licensed material. For
404 | the avoidance of doubt, this paragraph does not form part of the
405 | public licenses.
406 |
407 | Creative Commons may be contacted at creativecommons.org.
408 |
--------------------------------------------------------------------------------
/licenses/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/licenses/WEIGHT_LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial 4.0 International Public
58 | License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial 4.0 International Public License ("Public
63 | License"). To the extent this Public License may be interpreted as a
64 | contract, You are granted the Licensed Rights in consideration of Your
65 | acceptance of these terms and conditions, and the Licensor grants You
66 | such rights in consideration of benefits the Licensor receives from
67 | making the Licensed Material available under these terms and
68 | conditions.
69 |
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Adapter's License means the license You apply to Your Copyright
84 | and Similar Rights in Your contributions to Adapted Material in
85 | accordance with the terms and conditions of this Public License.
86 |
87 | c. Copyright and Similar Rights means copyright and/or similar rights
88 | closely related to copyright including, without limitation,
89 | performance, broadcast, sound recording, and Sui Generis Database
90 | Rights, without regard to how the rights are labeled or
91 | categorized. For purposes of this Public License, the rights
92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
93 | Rights.
94 | d. Effective Technological Measures means those measures that, in the
95 | absence of proper authority, may not be circumvented under laws
96 | fulfilling obligations under Article 11 of the WIPO Copyright
97 | Treaty adopted on December 20, 1996, and/or similar international
98 | agreements.
99 |
100 | e. Exceptions and Limitations means fair use, fair dealing, and/or
101 | any other exception or limitation to Copyright and Similar Rights
102 | that applies to Your use of the Licensed Material.
103 |
104 | f. Licensed Material means the artistic or literary work, database,
105 | or other material to which the Licensor applied this Public
106 | License.
107 |
108 | g. Licensed Rights means the rights granted to You subject to the
109 | terms and conditions of this Public License, which are limited to
110 | all Copyright and Similar Rights that apply to Your use of the
111 | Licensed Material and that the Licensor has authority to license.
112 |
113 | h. Licensor means the individual(s) or entity(ies) granting rights
114 | under this Public License.
115 |
116 | i. NonCommercial means not primarily intended for or directed towards
117 | commercial advantage or monetary compensation. For purposes of
118 | this Public License, the exchange of the Licensed Material for
119 | other material subject to Copyright and Similar Rights by digital
120 | file-sharing or similar means is NonCommercial provided there is
121 | no payment of monetary compensation in connection with the
122 | exchange.
123 |
124 | j. Share means to provide material to the public by any means or
125 | process that requires permission under the Licensed Rights, such
126 | as reproduction, public display, public performance, distribution,
127 | dissemination, communication, or importation, and to make material
128 | available to the public including in ways that members of the
129 | public may access the material from a place and at a time
130 | individually chosen by them.
131 |
132 | k. Sui Generis Database Rights means rights other than copyright
133 | resulting from Directive 96/9/EC of the European Parliament and of
134 | the Council of 11 March 1996 on the legal protection of databases,
135 | as amended and/or succeeded, as well as other essentially
136 | equivalent rights anywhere in the world.
137 |
138 | l. You means the individual or entity exercising the Licensed Rights
139 | under this Public License. Your has a corresponding meaning.
140 |
141 |
142 | Section 2 -- Scope.
143 |
144 | a. License grant.
145 |
146 | 1. Subject to the terms and conditions of this Public License,
147 | the Licensor hereby grants You a worldwide, royalty-free,
148 | non-sublicensable, non-exclusive, irrevocable license to
149 | exercise the Licensed Rights in the Licensed Material to:
150 |
151 | a. reproduce and Share the Licensed Material, in whole or
152 | in part, for NonCommercial purposes only; and
153 |
154 | b. produce, reproduce, and Share Adapted Material for
155 | NonCommercial purposes only.
156 |
157 | 2. Exceptions and Limitations. For the avoidance of doubt, where
158 | Exceptions and Limitations apply to Your use, this Public
159 | License does not apply, and You do not need to comply with
160 | its terms and conditions.
161 |
162 | 3. Term. The term of this Public License is specified in Section
163 | 6(a).
164 |
165 | 4. Media and formats; technical modifications allowed. The
166 | Licensor authorizes You to exercise the Licensed Rights in
167 | all media and formats whether now known or hereafter created,
168 | and to make technical modifications necessary to do so. The
169 | Licensor waives and/or agrees not to assert any right or
170 | authority to forbid You from making technical modifications
171 | necessary to exercise the Licensed Rights, including
172 | technical modifications necessary to circumvent Effective
173 | Technological Measures. For purposes of this Public License,
174 | simply making modifications authorized by this Section 2(a)
175 | (4) never produces Adapted Material.
176 |
177 | 5. Downstream recipients.
178 |
179 | a. Offer from the Licensor -- Licensed Material. Every
180 | recipient of the Licensed Material automatically
181 | receives an offer from the Licensor to exercise the
182 | Licensed Rights under the terms and conditions of this
183 | Public License.
184 |
185 | b. No downstream restrictions. You may not offer or impose
186 | any additional or different terms or conditions on, or
187 | apply any Effective Technological Measures to, the
188 | Licensed Material if doing so restricts exercise of the
189 | Licensed Rights by any recipient of the Licensed
190 | Material.
191 |
192 | 6. No endorsement. Nothing in this Public License constitutes or
193 | may be construed as permission to assert or imply that You
194 | are, or that Your use of the Licensed Material is, connected
195 | with, or sponsored, endorsed, or granted official status by,
196 | the Licensor or others designated to receive attribution as
197 | provided in Section 3(a)(1)(A)(i).
198 |
199 | b. Other rights.
200 |
201 | 1. Moral rights, such as the right of integrity, are not
202 | licensed under this Public License, nor are publicity,
203 | privacy, and/or other similar personality rights; however, to
204 | the extent possible, the Licensor waives and/or agrees not to
205 | assert any such rights held by the Licensor to the limited
206 | extent necessary to allow You to exercise the Licensed
207 | Rights, but not otherwise.
208 |
209 | 2. Patent and trademark rights are not licensed under this
210 | Public License.
211 |
212 | 3. To the extent possible, the Licensor waives any right to
213 | collect royalties from You for the exercise of the Licensed
214 | Rights, whether directly or through a collecting society
215 | under any voluntary or waivable statutory or compulsory
216 | licensing scheme. In all other cases the Licensor expressly
217 | reserves any right to collect such royalties, including when
218 | the Licensed Material is used other than for NonCommercial
219 | purposes.
220 |
221 |
222 | Section 3 -- License Conditions.
223 |
224 | Your exercise of the Licensed Rights is expressly made subject to the
225 | following conditions.
226 |
227 | a. Attribution.
228 |
229 | 1. If You Share the Licensed Material (including in modified
230 | form), You must:
231 |
232 | a. retain the following if it is supplied by the Licensor
233 | with the Licensed Material:
234 |
235 | i. identification of the creator(s) of the Licensed
236 | Material and any others designated to receive
237 | attribution, in any reasonable manner requested by
238 | the Licensor (including by pseudonym if
239 | designated);
240 |
241 | ii. a copyright notice;
242 |
243 | iii. a notice that refers to this Public License;
244 |
245 | iv. a notice that refers to the disclaimer of
246 | warranties;
247 |
248 | v. a URI or hyperlink to the Licensed Material to the
249 | extent reasonably practicable;
250 |
251 | b. indicate if You modified the Licensed Material and
252 | retain an indication of any previous modifications; and
253 |
254 | c. indicate the Licensed Material is licensed under this
255 | Public License, and include the text of, or the URI or
256 | hyperlink to, this Public License.
257 |
258 | 2. You may satisfy the conditions in Section 3(a)(1) in any
259 | reasonable manner based on the medium, means, and context in
260 | which You Share the Licensed Material. For example, it may be
261 | reasonable to satisfy the conditions by providing a URI or
262 | hyperlink to a resource that includes the required
263 | information.
264 |
265 | 3. If requested by the Licensor, You must remove any of the
266 | information required by Section 3(a)(1)(A) to the extent
267 | reasonably practicable.
268 |
269 | 4. If You Share Adapted Material You produce, the Adapter's
270 | License You apply must not prevent recipients of the Adapted
271 | Material from complying with this Public License.
272 |
273 |
274 | Section 4 -- Sui Generis Database Rights.
275 |
276 | Where the Licensed Rights include Sui Generis Database Rights that
277 | apply to Your use of the Licensed Material:
278 |
279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
280 | to extract, reuse, reproduce, and Share all or a substantial
281 | portion of the contents of the database for NonCommercial purposes
282 | only;
283 |
284 | b. if You include all or a substantial portion of the database
285 | contents in a database in which You have Sui Generis Database
286 | Rights, then the database in which You have Sui Generis Database
287 | Rights (but not its individual contents) is Adapted Material; and
288 |
289 | c. You must comply with the conditions in Section 3(a) if You Share
290 | all or a substantial portion of the contents of the database.
291 |
292 | For the avoidance of doubt, this Section 4 supplements and does not
293 | replace Your obligations under this Public License where the Licensed
294 | Rights include other Copyright and Similar Rights.
295 |
296 |
297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
298 |
299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
309 |
310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
319 |
320 | c. The disclaimer of warranties and limitation of liability provided
321 | above shall be interpreted in a manner that, to the extent
322 | possible, most closely approximates an absolute disclaimer and
323 | waiver of all liability.
324 |
325 |
326 | Section 6 -- Term and Termination.
327 |
328 | a. This Public License applies for the term of the Copyright and
329 | Similar Rights licensed here. However, if You fail to comply with
330 | this Public License, then Your rights under this Public License
331 | terminate automatically.
332 |
333 | b. Where Your right to use the Licensed Material has terminated under
334 | Section 6(a), it reinstates:
335 |
336 | 1. automatically as of the date the violation is cured, provided
337 | it is cured within 30 days of Your discovery of the
338 | violation; or
339 |
340 | 2. upon express reinstatement by the Licensor.
341 |
342 | For the avoidance of doubt, this Section 6(b) does not affect any
343 | right the Licensor may have to seek remedies for Your violations
344 | of this Public License.
345 |
346 | c. For the avoidance of doubt, the Licensor may also offer the
347 | Licensed Material under separate terms or conditions or stop
348 | distributing the Licensed Material at any time; however, doing so
349 | will not terminate this Public License.
350 |
351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
352 | License.
353 |
354 |
355 | Section 7 -- Other Terms and Conditions.
356 |
357 | a. The Licensor shall not be bound by any additional or different
358 | terms or conditions communicated by You unless expressly agreed.
359 |
360 | b. Any arrangements, understandings, or agreements regarding the
361 | Licensed Material not stated herein are separate from and
362 | independent of the terms and conditions of this Public License.
363 |
364 |
365 | Section 8 -- Interpretation.
366 |
367 | a. For the avoidance of doubt, this Public License does not, and
368 | shall not be interpreted to, reduce, limit, restrict, or impose
369 | conditions on any use of the Licensed Material that could lawfully
370 | be made without permission under this Public License.
371 |
372 | b. To the extent possible, if any provision of this Public License is
373 | deemed unenforceable, it shall be automatically reformed to the
374 | minimum extent necessary to make it enforceable. If the provision
375 | cannot be reformed, it shall be severed from this Public License
376 | without affecting the enforceability of the remaining terms and
377 | conditions.
378 |
379 | c. No term or condition of this Public License will be waived and no
380 | failure to comply consented to unless expressly agreed to by the
381 | Licensor.
382 |
383 | d. Nothing in this Public License constitutes or may be interpreted
384 | as a limitation upon, or waiver of, any privileges and immunities
385 | that apply to the Licensor or You, including from the legal
386 | processes of any jurisdiction or authority.
387 |
388 | =======================================================================
389 |
390 | Creative Commons is not a party to its public
391 | licenses. Notwithstanding, Creative Commons may elect to apply one of
392 | its public licenses to material it publishes and in those instances
393 | will be considered the “Licensor.” The text of the Creative Commons
394 | public licenses is dedicated to the public domain under the CC0 Public
395 | Domain Dedication. Except for the limited purpose of indicating that
396 | material is shared under a Creative Commons public license or as
397 | otherwise permitted by the Creative Commons policies published at
398 | creativecommons.org/policies, Creative Commons does not authorize the
399 | use of the trademark "Creative Commons" or any other trademark or logo
400 | of Creative Commons without its prior written consent including,
401 | without limitation, in connection with any unauthorized modifications
402 | to any of its public licenses or any other arrangements,
403 | understandings, or agreements concerning use of licensed material. For
404 | the avoidance of doubt, this paragraph does not form part of the
405 | public licenses.
406 |
407 | Creative Commons may be contacted at creativecommons.org.
408 |
--------------------------------------------------------------------------------
/paper/paper.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/Step-DPO/1f504ead5004f252025cb234017dfd9897cd542c/paper/paper.pdf
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | jsonlines
2 | trl==0.12.2
3 | alignment-handbook @ git+https://github.com/huggingface/alignment-handbook
4 | wandb
5 | deepspeed
6 | accelerate
7 | flash_attn
8 | vllm
9 | antlr4-python3-runtime==4.11
10 |
--------------------------------------------------------------------------------
/stepdpo_trainer.py:
--------------------------------------------------------------------------------
1 | # Modified from trl/trl/trainer/dpo_trainer.py
2 | from typing import Dict, Optional, Union
3 |
4 | import torch
5 | from transformers import PreTrainedModel
6 | from trl import DPOTrainer
7 |
8 |
9 | class StepDPOTrainer(DPOTrainer):
10 |
11 | def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None) -> Dict:
12 | """Tokenize a single row from a DPO specific dataset.
13 |
14 |         At this stage, we don't convert to PyTorch tensors yet; we just handle truncation
15 |         in case the prompt plus the chosen or rejected response is too long. First
16 |         we truncate the prompt; if the sequence is still too long, we truncate the chosen/rejected response.
17 |
18 | We also create the labels for the chosen/rejected responses, which are of length equal to
19 | the sum of the length of the prompt and the chosen/rejected response, with
20 | label_pad_token_id for the prompt tokens.
21 | """
22 | batch = {}
23 | prompt = feature["prompt"]
24 | chosen = feature["chosen"]
25 | rejected = feature["rejected"]
26 |
27 | if not self.is_encoder_decoder:
28 | # Check issues below for more details
29 | # 1. https://github.com/huggingface/trl/issues/907
30 | # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
31 | # 3. https://github.com/LianjiaTech/BELLE/issues/337
32 |
33 | if not isinstance(prompt, str):
34 |                 raise ValueError(f"prompt should be a str but got {type(prompt)}")
35 | prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
36 | prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
37 |
38 | if not isinstance(chosen, str):
39 |                 raise ValueError(f"chosen should be a str but got {type(chosen)}")
40 | chosen_tokens = self.build_tokenized_answer(prompt, chosen)
41 |
42 | if not isinstance(rejected, str):
43 |                 raise ValueError(f"rejected should be a str but got {type(rejected)}")
44 | rejected_tokens = self.build_tokenized_answer(prompt, rejected)
45 |
46 | # Last prompt token might get merged by tokenizer and
47 | # it should not be included for generation if that happens
48 | prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
49 |
50 | chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
51 | rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
52 | prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
53 |
54 | for k, v in prompt_tokens.items():
55 | prompt_tokens[k] = v[:prompt_len_input_ids]
56 |
57 |             # Make sure the prompts differ by at most one token
58 |             # and that their lengths differ by at most 1
59 | num_diff_tokens = sum(
60 | [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
61 | )
62 | num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
63 | if num_diff_tokens > 1 or num_diff_len > 1:
64 | raise ValueError(
65 | "Chosen and rejected prompt_input_ids might only differ on the "
66 | "last token due to tokenizer merge ops."
67 | )
68 |
69 | # add BOS token to head of prompt
70 | if self.tokenizer.bos_token_id is not None:
71 | prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"]
72 | chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"]
73 | rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"]
74 |
75 | prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
76 | chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
77 | rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
78 |
79 | # # add EOS token to end of answer
80 | # chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
81 | # chosen_tokens["attention_mask"].append(1)
82 |
83 | # rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
84 | # rejected_tokens["attention_mask"].append(1)
85 |
86 | longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
87 |
88 | # if combined sequence is too long, truncate the prompt
89 | for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
90 | if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
91 | if self.truncation_mode == "keep_start":
92 | for k in ["prompt_input_ids", "prompt_attention_mask"]:
93 | answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
94 | elif self.truncation_mode == "keep_end":
95 | for k in ["prompt_input_ids", "prompt_attention_mask"]:
96 | answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
97 | else:
98 | raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
99 |
100 | # if that's still too long, truncate the response
101 | for answer_tokens in [chosen_tokens, rejected_tokens]:
102 | if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
103 | for k in ["input_ids", "attention_mask"]:
104 | answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
105 |
106 | # Create labels
107 | chosen_sequence_tokens = {
108 | k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
109 | }
110 | rejected_sequence_tokens = {
111 | k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
112 | }
113 | chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
114 | chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
115 | self.label_pad_token_id
116 | ] * len(chosen_tokens["prompt_input_ids"])
117 | rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
118 | rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
119 | self.label_pad_token_id
120 | ] * len(rejected_tokens["prompt_input_ids"])
121 |
122 | for k, toks in {
123 | "chosen_": chosen_sequence_tokens,
124 | "rejected_": rejected_sequence_tokens,
125 | "": prompt_tokens,
126 | }.items():
127 | for type_key, tokens in toks.items():
128 | if type_key == "token_type_ids":
129 | continue
130 | batch[f"{k}{type_key}"] = tokens
131 |
132 | # import pdb; pdb.set_trace()
133 |
134 | else:
135 | chosen_tokens = self.tokenizer(
136 | chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True
137 | )
138 | rejected_tokens = self.tokenizer(
139 | rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True
140 | )
141 | prompt_tokens = self.tokenizer(
142 | prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
143 | )
144 |
145 | batch["chosen_labels"] = chosen_tokens["input_ids"]
146 | batch["rejected_labels"] = rejected_tokens["input_ids"]
147 | batch["prompt_input_ids"] = prompt_tokens["input_ids"]
148 | batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
149 |
150 | if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
151 | batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
152 | labels=torch.tensor(batch["rejected_labels"])
153 | )
154 | batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
155 | labels=torch.tensor(batch["chosen_labels"])
156 | )
157 |
158 | return batch
159 |
--------------------------------------------------------------------------------
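The label construction in `StepDPOTrainer.tokenize_row` above (copy the full `input_ids`, then overwrite every prompt position with `label_pad_token_id` so the DPO loss only scores response tokens) can be sketched standalone. The token IDs below are made up for illustration; `-100` is assumed as the pad value, matching trl's default ignore index:

```python
# Illustrative sketch (not part of the repo) of the label masking done in
# StepDPOTrainer.tokenize_row. Token IDs are hypothetical.
LABEL_PAD_TOKEN_ID = -100  # assumed; trl's default label_pad_token_id

prompt_input_ids = [101, 2054, 2003]    # hypothetical prompt tokens
response_input_ids = [1996, 3437, 102]  # hypothetical response tokens

# Full sequence is prompt followed by response, as in the trainer.
sequence = prompt_input_ids + response_input_ids

# Labels copy the sequence, then the prompt span is masked out so the
# loss is computed only over the response tokens.
labels = sequence[:]
labels[: len(prompt_input_ids)] = [LABEL_PAD_TOKEN_ID] * len(prompt_input_ids)

print(labels)  # -> [-100, -100, -100, 1996, 3437, 102]
```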
/train.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import random
3 | import sys
4 | from dataclasses import dataclass, field
5 |
6 | import torch
7 | import transformers
8 | from alignment import (
9 | DataArguments,
10 | DPOConfig,
11 | H4ArgumentParser,
12 | ModelArguments,
13 | get_checkpoint,
14 | get_kbit_device_map,
15 | get_peft_config,
16 | get_quantization_config,
17 | get_tokenizer,
18 | )
19 | from datasets import load_dataset
20 | from stepdpo_trainer import StepDPOTrainer
21 | from transformers import set_seed
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 | def apply_step_wise_chat_template(
26 | example,
27 | tokenizer,
28 | task,
29 | prompt,
30 | auto_insert_empty_system_msg: bool = True
31 | ):
32 | assert task in ["dpo"]
33 | if prompt == 'alpaca':
34 | prompt_input = (
35 | "Below is an instruction that describes a task, paired with an input that provides further context. "
36 | "Write a response that appropriately completes the request.\n\n"
37 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
38 | )
39 | prompt_no_input = (
40 | "Below is an instruction that describes a task. "
41 | "Write a response that appropriately completes the request.\n\n"
42 | "### Instruction:\n{instruction}\n\n### Response:"
43 | )
44 | elif prompt == 'deepseek-math':
45 | prompt_input = None
46 | prompt_no_input = "User: {instruction}\nPlease reason step by step, and put your final answer within \\boxed{{}}.\n\nAssistant:"
47 | elif prompt == 'qwen2-boxed':
48 | prompt_input = None
49 | prompt_no_input = (
50 | "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
51 | "<|im_start|>user\n{instruction}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n"
52 | "<|im_start|>assistant\n"
53 | )
54 |
55 | text_chosen = example['chosen']
56 | text_rejected = example['rejected']
57 |
58 | if prompt == 'alpaca':
59 | if len(example['initial_reason_steps']) == 0:
60 | new_example = {
61 | 'prompt': prompt_no_input.format(instruction=example['prompt']),
62 | 'chosen': text_chosen,
63 | 'rejected': text_rejected,
64 | }
65 | else:
66 | new_example = {
67 | 'prompt': prompt_no_input.format(instruction=example['prompt']) + "\n" + example['initial_reason_steps'],
68 | 'chosen': text_chosen,
69 | 'rejected': text_rejected,
70 | }
71 | elif prompt == 'deepseek-math':
72 | if len(example['initial_reason_steps']) == 0:
73 | new_example = {
74 | 'prompt': prompt_no_input.format(instruction=example['prompt']),
75 | 'chosen': text_chosen,
76 | 'rejected': text_rejected,
77 | }
78 | else:
79 | new_example = {
80 | 'prompt': prompt_no_input.format(instruction=example['prompt']) + " " + example['initial_reason_steps'],
81 | 'chosen': text_chosen,
82 | 'rejected': text_rejected,
83 | }
84 | elif prompt == 'qwen2-boxed':
85 | if len(example['initial_reason_steps']) == 0:
86 | new_example = {
87 | 'prompt': prompt_no_input.format(instruction=example['prompt']),
88 | 'chosen': text_chosen,
89 | 'rejected': text_rejected,
90 | }
91 | else:
92 | new_example = {
93 | 'prompt': prompt_no_input.format(instruction=example['prompt']) + example['initial_reason_steps'],
94 | 'chosen': text_chosen,
95 | 'rejected': text_rejected,
96 | }
97 | return new_example
98 |
99 | @dataclass
100 | class StepDPOConfig(DPOConfig):
101 | data_path: str = field(default="xinlai/math-step-dpo-10K")
102 | prompt: str = field(default="alpaca")
103 |
104 | def main():
105 | parser = H4ArgumentParser((ModelArguments, DataArguments, StepDPOConfig))
106 | model_args, data_args, training_args = parser.parse()
107 |
108 | #######
109 | # Setup
110 | #######
111 | logging.basicConfig(
112 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
113 | datefmt="%Y-%m-%d %H:%M:%S",
114 | handlers=[logging.StreamHandler(sys.stdout)],
115 | )
116 | log_level = training_args.get_process_log_level()
117 | logger.setLevel(log_level)
118 | transformers.utils.logging.set_verbosity(log_level)
119 | transformers.utils.logging.enable_default_handler()
120 | transformers.utils.logging.enable_explicit_format()
121 |
122 | # Log on each process the small summary:
123 | logger.info(f"Model parameters {model_args}")
124 | logger.info(f"Data parameters {data_args}")
125 | logger.info(f"Training/evaluation parameters {training_args}")
126 |
127 | # Check for last checkpoint
128 | last_checkpoint = get_checkpoint(training_args)
129 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
130 | logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
131 |
132 | # Set seed for reproducibility
133 | set_seed(training_args.seed)
134 |
135 | ###############
136 | # Load datasets
137 | ###############
138 | if ".json" in training_args.data_path:
139 | raw_datasets = load_dataset(
140 | "json",
141 | data_files=training_args.data_path.split("||"),
142 | )
143 | else:
144 | raw_datasets = load_dataset(training_args.data_path)
145 |
146 | logger.info(
147 | f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
148 | )
149 | column_names = list(raw_datasets["train"].features)
150 |
151 | #####################################
152 | # Load tokenizer and process datasets
153 | #####################################
154 | data_args.truncation_side = "left" # Truncate from left to ensure we don't lose labels in final turn
155 | tokenizer = get_tokenizer(model_args, data_args)
156 |
157 | #####################
158 | # Apply chat template
159 | #####################
160 |
161 | raw_datasets = raw_datasets.map(
162 | apply_step_wise_chat_template,
163 | fn_kwargs={
164 | "tokenizer": tokenizer,
165 | "task": "dpo",
166 | "prompt": training_args.prompt,
167 | "auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
168 | },
169 | num_proc=data_args.preprocessing_num_workers,
170 | remove_columns=column_names,
171 | desc="Formatting comparisons with prompt template",
172 | )
173 |
174 | # Log a few random samples from the training set:
175 | for index in random.sample(range(len(raw_datasets["train"])), 3):
176 | logger.info(f"Prompt sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['prompt']}")
177 | logger.info(f"Chosen sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['chosen']}")
178 | logger.info(f"Rejected sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['rejected']}")
179 |
180 | torch_dtype = (
181 | model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
182 | )
183 | quantization_config = get_quantization_config(model_args)
184 |
185 | model_kwargs = dict(
186 | revision=model_args.model_revision,
187 | trust_remote_code=model_args.trust_remote_code,
188 | use_flash_attention_2=model_args.use_flash_attention_2,
189 | torch_dtype=torch_dtype,
190 | use_cache=False if training_args.gradient_checkpointing else True,
191 | device_map=get_kbit_device_map() if quantization_config is not None else None,
192 | quantization_config=quantization_config,
193 | )
194 |
195 | model = model_args.model_name_or_path
196 | ref_model = model
197 | ref_model_kwargs = model_kwargs
198 |
199 | if model_args.use_peft is True:
200 | ref_model = None
201 | ref_model_kwargs = None
202 |
203 | #########################
204 | # Instantiate DPO trainer
205 | #########################
206 | trainer = StepDPOTrainer(
207 | model,
208 | ref_model,
209 | model_init_kwargs=model_kwargs,
210 | ref_model_init_kwargs=ref_model_kwargs,
211 | args=training_args,
212 | beta=training_args.beta,
213 | train_dataset=raw_datasets["train"],
214 | eval_dataset=raw_datasets["test"] if "test" in raw_datasets.keys() else None,
215 | tokenizer=tokenizer,
216 | max_length=training_args.max_length,
217 | max_prompt_length=training_args.max_prompt_length,
218 | peft_config=get_peft_config(model_args),
219 | loss_type=training_args.loss_type,
220 | )
221 |
222 | ###############
223 | # Training loop
224 | ###############
225 | checkpoint = None
226 | if training_args.resume_from_checkpoint is not None:
227 | checkpoint = training_args.resume_from_checkpoint
228 | elif last_checkpoint is not None:
229 | checkpoint = last_checkpoint
230 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
231 | metrics = train_result.metrics
232 | metrics["train_samples"] = len(raw_datasets["train"])
233 | trainer.log_metrics("train", metrics)
234 | trainer.save_metrics("train", metrics)
235 | trainer.save_state()
236 |
237 | logger.info("*** Training complete ***")
238 |
239 | ##################################
240 | # Save model and create model card
241 | ##################################
242 | logger.info("*** Save model ***")
243 | trainer.save_model(training_args.output_dir)
244 | logger.info(f"Model saved to {training_args.output_dir}")
245 |
246 | # Save everything else on main process
247 | kwargs = {
248 | "finetuned_from": model_args.model_name_or_path,
249 | "dataset": [training_args.data_path],
250 | "dataset_tags": [training_args.data_path],
251 | "tags": ["alignment-handbook"],
252 | }
253 | if trainer.accelerator.is_main_process:
254 | trainer.create_model_card(**kwargs)
255 | # Restore k,v cache for fast inference
256 | trainer.model.config.use_cache = True
257 | trainer.model.config.save_pretrained(training_args.output_dir)
258 |
259 | ##########
260 | # Evaluate
261 | ##########
262 | if training_args.do_eval:
263 | logger.info("*** Evaluate ***")
264 | metrics = trainer.evaluate()
265 | metrics["eval_samples"] = len(raw_datasets["test"])
266 | trainer.log_metrics("eval", metrics)
267 | trainer.save_metrics("eval", metrics)
268 |
269 | if training_args.push_to_hub is True:
270 | logger.info("Pushing to hub...")
271 | trainer.push_to_hub(**kwargs)
272 |
273 |     logger.info("*** Run complete! ***")
274 |
275 |
276 | if __name__ == "__main__":
277 | main()
278 |
--------------------------------------------------------------------------------
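The `qwen2-boxed` branch of `apply_step_wise_chat_template` in train.py formats each question with Qwen2 chat markers and appends any partially completed `initial_reason_steps` directly after the assistant tag. A minimal standalone sketch (the template string is copied from train.py; the question and step text are hypothetical):

```python
# Minimal sketch (not part of the repo) of train.py's `qwen2-boxed` prompt
# construction. The template is copied verbatim from
# apply_step_wise_chat_template; the inputs are illustrative.
QWEN2_BOXED_TEMPLATE = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\n{instruction}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n"
    "<|im_start|>assistant\n"
)

def build_prompt(question: str, initial_reason_steps: str = "") -> str:
    # The qwen2-boxed branch appends the initial reasoning steps with no
    # separator, so Step-DPO conditions on the partial solution prefix.
    return QWEN2_BOXED_TEMPLATE.format(instruction=question) + initial_reason_steps

prompt = build_prompt("What is 2 + 2?", "Step 1: Add the numbers. ")
print(prompt)
```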