├── .gitignore
├── README.md
├── data_curation
│   ├── __init__.py
│   ├── bespoke_data.py
│   ├── length_comparsion.py
│   └── record.py
├── imgs
│   └── Training_pipeline.png
├── model
│   ├── LMConfig.py
│   ├── dataset.py
│   ├── minimind_tokenizer
│   │   ├── merges.txt
│   │   ├── tokenizer.json
│   │   ├── tokenizer_config.json
│   │   └── vocab.json
│   └── model.py
├── tools
│   ├── .gitattributes
│   ├── README.md
│   ├── __init__.py
│   ├── base_instruct_evals.md
│   ├── combine_data.py
│   ├── convert_format.py
│   ├── convert_to_data.py
│   ├── eval.py
│   ├── inference_and_check.py
│   ├── label_math_difficulty.py
│   ├── labeled_numina_difficulty
│   │   └── README.md
│   ├── requirements.txt
│   ├── response_rewrite.py
│   ├── upload_hub.py
│   └── util
│       ├── apps
│       │   └── testing_util.py
│       ├── common.py
│       ├── livecodebench
│       │   └── testing_util.py
│       ├── math
│       │   └── testing_util.py
│       ├── model_utils.py
│       ├── prompts.py
│       ├── taco
│       │   ├── pyext2.py
│       │   └── testing_util.py
│       └── task_handlers.py
├── train
│   ├── __init__.py
│   ├── deepseed
│   │   ├── ds_z3_offload_config.json
│   │   ├── zero2_config.json
│   │   ├── zero3_config.json
│   │   └── zero3_config2.json
│   ├── dpo_train.py
│   ├── names.py
│   └── sft_train.py
└── utils
    ├── __init__.py
    ├── data_utils.py
    ├── eval
    │   ├── eval_utils.py
    │   └── qwen_math_parser.py
    ├── load_model.py
    ├── model_utils.py
    ├── settings.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__/*
2 | utils/upload.py
3 | results/
4 | scripts/
5 | tools/eval.sh
6 | tools/results
7 | train/sft_7b.sh
8 | train/wandb
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ThinkPO: Thinking Preference Optimization
2 | > A simple yet effective post-SFT method that enhances long CoT reasoning without requiring new long CoT responses.
3 | 
4 | ![Training Pipeline](./imgs/Training_pipeline.png)
5 | 
6 | 
7 | 
8 | ---
9 | # News
10 | - 2025-02-22: We released our ThinkPO dataset: [Bespoke_dpo_filter](https://huggingface.co/datasets/VanWang/Bespoke_dpo_filter)
11 | - 2025-02-21: We released our four models:
12 | - [DeepSeek-R1-Distill-Qwen-7B-ThinkPO](https://huggingface.co/VanWang/DeepSeek-R1-Distill-Qwen-7B-ThinkPO)
13 | - [Bespoke-Stratos-7B-ThinkPO](https://huggingface.co/VanWang/Bespoke-Stratos-7B-ThinkPO)
14 | - [Bespoke-Stratos-7B-repro-SFT](https://huggingface.co/VanWang/Bespoke-Stratos-7B-repro-SFT)
15 | - [Bespoke-Stratos-7B-repro-ThinkPO](https://huggingface.co/VanWang/Bespoke-Stratos-7B-repro-ThinkPO)
16 |
17 | - 2025-02-19: We released our [paper](https://arxiv.org/abs/2502.13173).
18 |
19 | ---
20 | # Introduction
21 | - Here, we show the results of open-source reasoning LLMs before and after ThinkPO.
22 | ## Accuracy
23 |
24 | | Models | Dataset | SFT | Ours (+ThinkPO) | Improv. (%) |
25 | |:--------:|:--------:|:--------:|:--------:|:--------:|
26 | |DeepSeek-R1-Distill-Qwen-7B (Deepseek) |MATH500 | 87.4 | 91.2 | 4.3% |
27 | || AIME | 56.7 | 43.3 | -23.6% |
28 | || GPQA | 47.0 | 49.5 | 5.3% |
29 | || GSM8K | 87.2 | 87.6 | 0.5% |
30 | || Olympiad | 58.6 | 58.6 | 0.0% |
31 | |Bespoke-Stratos-7B (Bespoke)| MATH500 | 84.0 | 82.8 | -1.4% |
32 | || AIME | 20.0 | 23.3 | 16.5% |
33 | || GPQA | 37.9 | 43.4 | 14.5% |
34 | || GSM8K | 92.9 | 93.3 | 0.4% |
35 | || Olympiad | 44.1 | 48.5 | 10.0% |
36 |
37 | ## Average Response Length
38 |
39 | | Model | Dataset | SFT | Ours (+ThinkPO) | Improv. (%) |
40 | |:--------:|:--------:|:--------:|:--------:|:--------:|
41 | |DeepSeek-R1-Distill-Qwen-7B (Deepseek) | MATH500 | 2577 | 3021 | 17.2% |
42 | || AIME | 11419 | 12875 | 12.8% |
43 | || GPQA | 4895 | 5604 | 14.5% |
44 | || GSM8K | 619 | 668 | 7.9% |
45 | || Olympiad | 7196 | 7383 | 2.6% |
46 | |Bespoke-Stratos-7B (Bespoke)| MATH500 | 5696 | 6404 | 12.4% |
47 | || AIME | 19858 | 20079 | 1.1% |
48 | || GPQA | 5968 | 7301 | 22.3% |
49 | || GSM8K | 1404 | 1755 | 25.0% |
50 | || Olympiad | 11140 | 12204 | 9.6% |
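
In both tables, Improv. (%) is computed as (Ours - SFT) / SFT * 100; for example, on MATH500 accuracy rises from 87.4 to 91.2, i.e. (91.2 - 87.4) / 87.4 ≈ 4.3%.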
51 |
52 | ---
53 | # Quick Use
54 | ## Setting
55 | - In `./utils/settings.py`, set your project path, Hugging Face cache path, and access token:
56 | ```python
57 | project_dir = 'path to your project'
58 | cache_dir = 'path to huggingface cache'
59 | hug_token = 'your huggingface token'
60 | ```
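
A minimal usage sketch (illustrative only; the environment-variable mapping below is an assumption, not the repo's actual code) showing how a script could consume these settings:
```python
import os
from utils.settings import project_dir, cache_dir, hug_token

# Assumption: the cache path and token are exported so Hugging Face libraries pick them up.
os.environ["HF_HOME"] = cache_dir
os.environ["HF_TOKEN"] = hug_token

# Paths inside the project can then be built from project_dir.
data_path = os.path.join(project_dir, "data", "final", "Bespoke_dpo.jsonl")
```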
61 |
62 | ## SFT Train
63 | - To train Qwen2.5-7B-Instruct with SFT on multiple GPUs, run:
64 | ```shell
65 | cd train
66 | deepspeed sft_train.py --model_name Instruct-7b --gradient_accumulation_steps 16 --dataset_name Bespoke --epoch 3 --lr 1e-5 --deepspeed ./deepseed/zero3_config2.json
67 | ```
68 |
69 | ## ThinkPO Train
70 | - To train Qwen2.5-7B-Instruct with ThinkPO on multiple GPUs, run:
71 | ```shell
72 | cd train
73 | deepspeed dpo_train.py --lr 3e-7 --beta 0.01 --model Bespoke-7b --dataset Bespoke_dpo --gradient_accumulation_steps 12 --deepspeed ./deepseed/zero3_config2.json
74 | ```
75 |
76 | ## Evaluate the Model
77 | - The LLM reasoning evaluation is adapted from [Sky-Thought](https://github.com/NovaSky-AI/SkyThought/tree/main).
78 | - Use the following command to evaluate the model on datasets such as MATH500, AIME, GPQADiamond, GSM8K, and OlympiadBenchMath:
79 | ```shell
80 | cd ./tools
81 | python ./eval.py \
82 | --model deepseek-ai/DeepSeek-R1-Distill-Qwen-7B \
83 | --evals MATH500,AIME,GPQADiamond,GSM8K,OlympiadBenchMath \
84 | --tp 2 --output_file ./results/eval/DeepSeek-R1-Distill-Qwen-7B.txt \
85 | --result_dir ./results/generated
86 | ```
87 |
88 | ## Citation
89 | ```bibtex
90 | @misc{yang2025thinkingpreferenceoptimization,
91 | title={Thinking Preference Optimization},
92 | author={Wang Yang and Hongye Jin and Jingfeng Yang and Vipin Chaudhary and Xiaotian Han},
93 | year={2025},
94 | eprint={2502.13173},
95 | archivePrefix={arXiv},
96 | primaryClass={cs.LG},
97 | url={https://arxiv.org/abs/2502.13173},
98 | }
99 | ```
100 |
--------------------------------------------------------------------------------
/data_curation/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.path.abspath(os.path.join(__file__, "../../")))
--------------------------------------------------------------------------------
/data_curation/bespoke_data.py:
--------------------------------------------------------------------------------
1 | import __init__
2 | from utils.utils import *
3 | from utils.data_utils import read_saved_results, save_results
4 | from utils.load_model import *
5 | from utils.data_utils import *
6 | from utils.eval.eval_utils import check_math_correctness
7 | from tqdm import tqdm
8 | import copy
9 |
10 |
11 | def generate_is_ok(predict, refrence):
12 | if not check_math_correctness(refrence, predict): return False
13 | return True
14 |
15 | MAX_LENGTH = 1024 * 1
16 |
17 | save_file = set_global(f'./data/final/Bespoke_dpo.jsonl')
18 | save_data = read_saved_results(save_file)
19 | ref_data = load_data('bespokelabs/Bespoke-Stratos-17k', 'huggingface')['train']
20 |
21 |
22 | tokenizer = load_tokenizer("NovaSky-AI/Sky-T1-32B-Preview")
23 | llm = init_vllm_model("Qwen/Qwen2.5-Math-7B-Instruct", 1)
24 | qwwn_tokenizer = load_tokenizer("Qwen/Qwen2.5-Math-7B-Instruct")
25 | # For each Bespoke prompt, generate a short answer with Qwen2.5-Math-7B-Instruct and keep it as the 'rejected' response only if it is still judged correct against the original (chosen) solution.
26 | for data in tqdm(ref_data):
27 | chosen = data['conversations'][1]['value']
28 | messages = [
29 | {"role": "system", "content": "You are a helpful and harmless assistant. Please reason step by step."},
30 | {"role": "user", "content": data['conversations'][0]['value']}
31 | ]
32 | prompt = qwwn_tokenizer.apply_chat_template(
33 | messages,
34 | tokenize=False,
35 | add_generation_prompt=True
36 | )
37 | rejected = vllm_generate(llm, prompt, MAX_LENGTH)[0]
38 | if len(rejected) == 0 or not generate_is_ok(rejected, chosen):
39 | continue
40 | rejected_ans = copy.deepcopy(data['conversations'])
41 | rejected_ans[1]['value'] = rejected
42 | save_data.append({
43 | 'system': data['system'],
44 | 'chosen': data['conversations'],
45 | 'rejected': rejected_ans
46 | })
47 | save_results(save_file, save_data[-1])
48 |
49 |
50 |
51 | save_file = set_global(f'./data/final/Bespoke_dpo.jsonl')
52 | save_data = []
53 | all_data = read_saved_results(save_file)
54 |
--------------------------------------------------------------------------------
/data_curation/length_comparsion.py:
--------------------------------------------------------------------------------
1 | import __init__
2 | from utils.utils import *
3 | from utils.data_utils import save_all_results, read_saved_results
4 | from utils.load_model import *
5 | from utils.data_utils import *
6 | from utils.eval.eval_utils import check_math_correctness
7 | from tqdm import tqdm
8 |
9 |
10 | def get_middle_1000(lst):
11 | n = len(lst)
12 | if n <= 1000:
13 | return lst
14 | start = (n - 1000) // 2
15 | return lst[start:start + 1000]
16 |
17 | def generate_is_ok(predict, refrence):
18 | if not check_math_correctness(refrence, predict): return False
19 | return True
20 |
21 | save_file = set_global(f'./data/final/Bespoke_dpo.jsonl')
22 | all_data = read_saved_results(save_file)
23 | tokenizer = load_tokenizer("NovaSky-AI/Sky-T1-32B-Preview")
24 |
25 | lengths, save_data = [], []
26 | for d in tqdm(all_data):
27 | chosen, rejected = d['chosen'][2]['content'], d['rejected'][2]['content']
28 | chosen_len, rejected_len = len(tokenizer.encode(chosen, add_special_tokens=True)), len(tokenizer.encode(rejected, add_special_tokens=True))
29 | if rejected_len < 8*1024:
30 | d['dealta'] = rejected_len - chosen_len
31 | chosen, rejected = d['chosen'], d['rejected']
32 | d['chosen'], d['rejected'] = rejected, chosen
33 | save_data.append(d)
34 | lengths.append(d['dealta'])
35 |
36 | all_data = sorted(save_data, key=lambda x: x["dealta"])
37 |
38 | middle_1000 = get_middle_1000(all_data)
39 | save_all_results(set_global(f'./data/final/Bespoke_dpo_filter_len_middle.jsonl'), middle_1000)
40 | save_all_results(set_global(f'./data/final/Bespoke_dpo_filter_len_short.jsonl'), all_data[:1000])
41 | save_all_results(set_global(f'./data/final/Bespoke_dpo_filter_len_long.jsonl'), all_data[-1000:])
42 |
--------------------------------------------------------------------------------
/data_curation/record.py:
--------------------------------------------------------------------------------
1 | import __init__
2 | import os
3 | from utils.data_utils import *
4 | from utils.eval.eval_utils import find_box
5 |
6 |
7 | def count_total_occurrences(text, words):
8 | return sum(text.lower().count(word) for word in words)
9 | word_list = ["wait", "hmm"]
10 |
11 | folder_path = "/home/wxy320/ondemand/program/SkyThought/skythought/tools/results/steps"
12 | subdirs = [d for d in os.listdir(folder_path) if os.path.isdir(os.path.join(folder_path, d))]
13 |
14 | results = dict()
15 | for step in subdirs:
16 | step = os.path.join(folder_path, step)
17 | json_lens = dict()
18 | json_files = [f for f in os.listdir(step) if f.endswith(".json") and os.path.isfile(os.path.join(step, f))]
19 | for file_path in json_files:
20 | with open(os.path.join(step,file_path), "r", encoding="utf-8") as f:
21 | data = json.load(f)
22 | all_prompts_len, num, counts = 0, 0 ,0
23 | for prompt, d in data.items():
24 | if '<|end_of_solution|>' in d['responses']['0.7']['content']:
25 | all_prompts_len = all_prompts_len + d['token_usages']['0.7']['completion_tokens']
26 | num = num +1
27 | if d['responses']['0.7']['correctness']:
28 | counts =counts+count_total_occurrences(d['responses']['0.7']['content'], word_list)
29 | json_lens[file_path] = {'lengths':all_prompts_len / num, "wait":counts}
30 | results[step.split('/')[-1]] = json_lens
31 |
32 |
33 | with open('results.json', "w", encoding="utf-8") as f:
34 | json.dump(results, f, ensure_ascii=False, indent=4)
--------------------------------------------------------------------------------
/imgs/Training_pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uservan/ThinkPO/23aa90e636e0beafabfa24b681f94216bbbaba9a/imgs/Training_pipeline.png
--------------------------------------------------------------------------------
/model/LMConfig.py:
--------------------------------------------------------------------------------
1 | from transformers import PretrainedConfig
2 | from typing import List
3 |
4 |
5 | class LMConfig(PretrainedConfig):
6 | model_type = "minimind"
7 |
8 | def __init__(
9 | self,
10 | dim: int = 512,
11 | n_layers: int = 8,
12 | n_heads: int = 16,
13 | n_kv_heads: int = 8,
14 | vocab_size: int = 6400,
15 | hidden_dim: int = None,
16 | multiple_of: int = 64,
17 | norm_eps: float = 1e-5,
18 | max_seq_len: int = 512,
19 | dropout: float = 0.0,
20 | flash_attn: bool = True,
21 | ####################################################
22 | # Here are the specific configurations of MOE
23 | # When use_moe is false, the following is invalid
24 | ####################################################
25 | use_moe: bool = False,
26 | num_experts_per_tok=2,
27 | n_routed_experts=4,
28 | n_shared_experts: bool = True,
29 | scoring_func='softmax',
30 | aux_loss_alpha=0.01,
31 | seq_aux=True,
32 | norm_topk_prob=True,
33 | **kwargs,
34 | ):
35 | self.dim = dim
36 | self.n_layers = n_layers
37 | self.n_heads = n_heads
38 | self.n_kv_heads = n_kv_heads
39 | self.vocab_size = vocab_size
40 | self.hidden_dim = hidden_dim
41 | self.multiple_of = multiple_of
42 | self.norm_eps = norm_eps
43 | self.max_seq_len = max_seq_len
44 | self.dropout = dropout
45 | self.flash_attn = flash_attn
46 | ####################################################
47 | # Here are the specific configurations of MOE
48 | # When use_moe is false, the following is invalid
49 | ####################################################
50 | self.use_moe = use_moe
51 |         self.num_experts_per_tok = num_experts_per_tok  # number of experts selected per token
52 |         self.n_routed_experts = n_routed_experts  # total number of routed experts
53 |         self.n_shared_experts = n_shared_experts  # shared experts
54 |         self.scoring_func = scoring_func  # scoring function, defaults to 'softmax'
55 |         self.aux_loss_alpha = aux_loss_alpha  # alpha coefficient of the auxiliary (load-balancing) loss
56 |         self.seq_aux = seq_aux  # whether to compute the auxiliary loss at the sequence level
57 |         self.norm_topk_prob = norm_topk_prob  # whether to normalize the top-k probabilities
58 | super().__init__(**kwargs)
59 |
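60 | # Usage sketch (illustrative comment, not part of the original file):
61 | #   config = LMConfig(dim=512, n_layers=8, n_heads=16, n_kv_heads=8,
62 | #                     use_moe=True, n_routed_experts=4, num_experts_per_tok=2)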
--------------------------------------------------------------------------------
/model/dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import re
4 |
5 | import pandas as pd
6 | import numpy as np
7 | from torch.utils.data import Dataset, DataLoader
8 | import torch
9 | from sklearn.model_selection import train_test_split
10 | import os
11 |
12 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
13 |
14 |
15 | class PretrainDataset(Dataset):
16 | def __init__(self, df, tokenizer, max_length=4096):
17 | super().__init__()
18 | self.df = df
19 | self.tokenizer = tokenizer
20 | self.max_length = max_length
21 | self.padding = 0
22 |
23 | def __len__(self):
24 | return self.df.shape[0]
25 |
26 | def __getitem__(self, index: int):
27 | #
28 | sample = self.df.iloc[index]
29 | text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
30 | input_id = self.tokenizer(text).data['input_ids'][:self.max_length]
31 | text_len = len(input_id)
32 |         # remaining length to pad up to max_length
33 | padding_len = self.max_length - text_len
34 | input_id = input_id + [self.padding] * padding_len
35 |         # 0 means the position is excluded from the loss
36 | loss_mask = [1] * text_len + [0] * padding_len
37 |
38 | input_id = np.array(input_id)
39 | X = np.array(input_id[:-1]).astype(np.int64)
40 | Y = np.array(input_id[1:]).astype(np.int64)
41 | loss_mask = np.array(loss_mask[1:]).astype(np.int64)
42 | return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)
43 |
44 | class SFTDataset_List(Dataset):
45 | def __init__(self, list, tokenizer, max_seq_len):
46 | super().__init__()
47 | self.list = list
48 | self.tokenizer = tokenizer
49 | self.MAX_LENGTH = max_seq_len
50 |
51 | def __len__(self):
52 | return len(self.list)
53 |
54 | def __getitem__(self, index: int):
55 | example = self.list[index]
56 | input_ids, attention_mask, labels = [], [], []
57 | inputs, targets = example["prompt"], example["response"]
58 | instruction = self.tokenizer(inputs+'\n', add_special_tokens=False)
59 | response = self.tokenizer(targets, add_special_tokens=False)
60 | input_ids = instruction["input_ids"] + response["input_ids"]
61 | attention_mask = instruction["attention_mask"] + response["attention_mask"]
62 | labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
63 | if len(input_ids) > self.MAX_LENGTH:
64 | input_ids = input_ids[:self.MAX_LENGTH]
65 | attention_mask = attention_mask[:self.MAX_LENGTH]
66 | labels = labels[:self.MAX_LENGTH]
67 | # else:
68 | # padding_len = self.MAX_LENGTH - len(input_ids)
69 | # input_ids = input_ids + [self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token)] * padding_len
70 | # labels = labels + [-100]* padding_len
71 | # attention_mask = attention_mask + [0] * padding_len
72 | return torch.from_numpy(np.array(input_ids)), torch.from_numpy(np.array(labels)), torch.from_numpy(np.array(attention_mask))
73 |
74 | class SFTDataset(Dataset):
75 | def __init__(self, df, tokenizer, max_length=1024, prompt_max_len=512, answer_max_len=256):
76 | super().__init__()
77 | self.df = df
78 | self.max_length = max_length
79 | self.prompt_max_len = prompt_max_len
80 | self.answer_max_len = answer_max_len
81 | #
82 | self.tokenizer = tokenizer
83 | self.padding = 0
84 | self.bos_id = self.tokenizer('assistant').data['input_ids']
85 |
86 | def __len__(self):
87 | return self.df.shape[0]
88 |
89 | def find_sublist_index(self, main_list, sub_list) -> int:
90 | last_index = -1
91 | for i in range(len(main_list) - len(sub_list) + 1):
92 | if main_list[i:i + len(sub_list)] == sub_list:
93 | last_index = i
94 | return last_index
95 |
96 | def safe_eval(self, s):
97 | try:
98 | res = eval(s)
99 | except Exception as e:
100 | return []
101 | return res
102 |
103 | def __getitem__(self, index: int):
104 | #
105 | sample = self.df.iloc[index]
106 | history = self.safe_eval(sample['history'])
107 | q = str(sample['q'])
108 | a = str(sample['a'])
109 |
110 | messages = []
111 | for history_message in history:
112 | if len(history_message) <= 1:
113 | continue
114 | messages.append(
115 | {"role": 'user', "content": str(history_message[0])[:self.max_length // 2]}
116 | )
117 | messages.append(
118 | {"role": 'assistant', "content": str(history_message[1])[:self.max_length // 2]}
119 | )
120 |
121 | messages += [
122 | {"role": "user", "content": q},
123 | {"role": "assistant", "content": a},
124 | ]
125 | new_prompt = self.tokenizer.apply_chat_template(
126 | messages,
127 | tokenize=False,
128 | add_generation_prompt=True
129 | )
130 | input_id = self.tokenizer(new_prompt).data['input_ids'][:self.max_length]
131 |
132 |         # actual length of the question (prompt) part
133 | question_length = self.find_sublist_index(input_id, self.bos_id) + len(self.bos_id)
134 |         # remaining length to pad up to max_length
135 | padding_len = self.max_length - len(input_id)
136 | input_id = input_id + [self.padding] * padding_len
137 | mask_len = len(input_id) - question_length - padding_len
138 |         # 0 means the position is excluded from the loss
139 | loss_mask = [0] * question_length + [1] * (mask_len) + [0] * padding_len
140 |
141 | input_id = np.array(input_id)
142 | X = np.array(input_id[:-1]).astype(np.int64)
143 | Y = np.array(input_id[1:]).astype(np.int64)
144 | loss_mask = np.array(loss_mask[1:]).astype(np.int64)
145 |
146 | X_tensor = torch.from_numpy(X)
147 | Y_tensor = torch.from_numpy(Y)
148 | loss_mask_tensor = torch.from_numpy(loss_mask)
149 |
150 | return X_tensor, Y_tensor, loss_mask_tensor
151 |
152 |
153 | if __name__ == "__main__":
154 | pass
155 |
--------------------------------------------------------------------------------
/model/minimind_tokenizer/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "add_bos_token": false,
3 | "add_eos_token": false,
4 | "add_prefix_space": true,
5 | "added_tokens_decoder": {
6 | "0": {
7 |       "content": "<unk>",
8 | "lstrip": false,
9 | "normalized": false,
10 | "rstrip": false,
11 | "single_word": false,
12 | "special": true
13 | },
14 | "1": {
15 |       "content": "<s>",
16 | "lstrip": false,
17 | "normalized": false,
18 | "rstrip": false,
19 | "single_word": false,
20 | "special": true
21 | },
22 | "2": {
23 |       "content": "</s>",
24 | "lstrip": false,
25 | "normalized": false,
26 | "rstrip": false,
27 | "single_word": false,
28 | "special": true
29 | }
30 | },
31 | "additional_special_tokens": [],
32 |   "bos_token": "<s>",
33 | "clean_up_tokenization_spaces": false,
34 |   "eos_token": "</s>",
35 | "legacy": true,
36 | "model_max_length": 1000000000000000019884624838656,
37 | "pad_token": null,
38 | "sp_model_kwargs": {},
39 | "spaces_between_special_tokens": false,
40 | "tokenizer_class": "PreTrainedTokenizerFast",
41 |   "unk_token": "<unk>",
42 | "use_default_system_prompt": false,
43 |   "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<s>user\\n' + content + '</s>\\n<s>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '</s>' + '\\n' }}{% endif %}{% endfor %}"
44 | }
--------------------------------------------------------------------------------
/model/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import struct
3 | import inspect
4 | import time
5 |
6 | from .LMConfig import LMConfig
7 | from typing import Any, Optional, Tuple
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 | from torch import nn
12 | from transformers import PreTrainedModel
13 | from transformers.modeling_outputs import CausalLMOutputWithPast
14 |
15 |
16 | class RMSNorm(torch.nn.Module):
17 | def __init__(self, dim: int, eps: float):
18 | super().__init__()
19 | self.eps = eps
20 | self.weight = nn.Parameter(torch.ones(dim))
21 |
22 | def _norm(self, x):
23 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
24 |
25 | def forward(self, x):
26 | output = self._norm(x.float()).type_as(x)
27 | return output * self.weight
28 |
29 |
30 | def precompute_pos_cis(dim: int, end: int, theta: float = 10000.0):
31 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
32 | t = torch.arange(end, device=freqs.device) # type: ignore
33 | freqs = torch.outer(t, freqs).float() # type: ignore
34 | pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
35 | return pos_cis
36 |
37 |
38 | def apply_rotary_emb(xq, xk, pos_cis):
39 | def unite_shape(pos_cis, x):
40 | ndim = x.ndim
41 | assert 0 <= 1 < ndim
42 | assert pos_cis.shape == (x.shape[1], x.shape[-1])
43 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
44 | return pos_cis.view(*shape)
45 |
46 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
47 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
48 | pos_cis = unite_shape(pos_cis, xq_)
49 | xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
50 | xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
51 | return xq_out.type_as(xq), xk_out.type_as(xk)
52 |
53 |
54 | def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
55 | """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
56 | bs, slen, n_kv_heads, head_dim = x.shape
57 | if n_rep == 1:
58 | return x
59 | return (
60 | x[:, :, :, None, :]
61 | .expand(bs, slen, n_kv_heads, n_rep, head_dim)
62 | .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
63 | )
64 |
65 |
66 | class Attention(nn.Module):
67 | def __init__(self, args: LMConfig):
68 | super().__init__()
69 | self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
70 | assert args.n_heads % self.n_kv_heads == 0
71 | self.n_local_heads = args.n_heads
72 | self.n_local_kv_heads = self.n_kv_heads
73 | self.n_rep = self.n_local_heads // self.n_local_kv_heads
74 | self.head_dim = args.dim // args.n_heads
75 | self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
76 | self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
77 | self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
78 | self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
79 | self.k_cache, self.v_cache = None, None
80 | self.attn_dropout = nn.Dropout(args.dropout)
81 | self.resid_dropout = nn.Dropout(args.dropout)
82 | self.dropout = args.dropout
83 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
84 |
85 | # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
86 | mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
87 | mask = torch.triu(mask, diagonal=1)
88 | self.register_buffer("mask", mask, persistent=False)
89 |
90 | def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, kv_cache=False):
91 | bsz, seqlen, _ = x.shape
92 |
93 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
94 |
95 | xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
96 | xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
97 | xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
98 |
99 | xq, xk = apply_rotary_emb(xq, xk, pos_cis)
100 |
101 |         # a more efficient kv_cache implementation
102 | if kv_cache and self.eval():
103 | if seqlen == 1 and all(cache is not None for cache in (self.k_cache, self.v_cache)):
104 | xk = torch.cat((self.k_cache, xk), dim=1)
105 | xv = torch.cat((self.v_cache, xv), dim=1)
106 | self.k_cache, self.v_cache = xk, xv
107 |
108 | xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
109 | xv = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
110 |
111 | xq = xq.transpose(1, 2)
112 | xk = xk.transpose(1, 2)
113 | xv = xv.transpose(1, 2)
114 |
115 | if self.flash and seqlen != 1:
116 | output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None,
117 | dropout_p=self.dropout if self.training else 0.0,
118 | is_causal=True)
119 | else:
120 | scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
121 | scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen)
122 | scores = F.softmax(scores.float(), dim=-1).type_as(xq)
123 | scores = self.attn_dropout(scores)
124 | output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim)
125 |
126 | output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
127 |
128 | output = self.wo(output)
129 | output = self.resid_dropout(output)
130 | return output
131 |
132 |
133 | class FeedForward(nn.Module):
134 | def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
135 | super().__init__()
136 | if hidden_dim is None:
137 | hidden_dim = 4 * dim
138 | hidden_dim = int(2 * hidden_dim / 3)
139 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
140 | self.w1 = nn.Linear(dim, hidden_dim, bias=False)
141 | self.w2 = nn.Linear(hidden_dim, dim, bias=False)
142 | self.w3 = nn.Linear(dim, hidden_dim, bias=False)
143 | self.dropout = nn.Dropout(dropout)
144 |
145 | def forward(self, x):
146 | return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
147 |
148 |
149 | class MoEGate(nn.Module):
150 | def __init__(self, config: LMConfig):
151 | super().__init__()
152 | self.config = config
153 | self.top_k = config.num_experts_per_tok
154 | self.n_routed_experts = config.n_routed_experts
155 |
156 | self.scoring_func = config.scoring_func
157 | self.alpha = config.aux_loss_alpha
158 | self.seq_aux = config.seq_aux
159 |
160 | self.norm_topk_prob = config.norm_topk_prob
161 | self.gating_dim = config.dim
162 | self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
163 | self.reset_parameters()
164 |
165 | def reset_parameters(self) -> None:
166 | import torch.nn.init as init
167 | init.kaiming_uniform_(self.weight, a=math.sqrt(5))
168 |
169 | def forward(self, hidden_states):
170 | bsz, seq_len, h = hidden_states.shape
171 |
172 | hidden_states = hidden_states.view(-1, h)
173 | logits = F.linear(hidden_states, self.weight, None)
174 | if self.scoring_func == 'softmax':
175 | scores = logits.softmax(dim=-1)
176 | else:
177 | raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
178 |
179 | topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
180 |
181 | if self.top_k > 1 and self.norm_topk_prob:
182 | denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
183 | topk_weight = topk_weight / denominator
184 |
185 | if self.training and self.alpha > 0.0:
186 | scores_for_aux = scores
187 | aux_topk = self.top_k
188 | topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
189 | if self.seq_aux:
190 | scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
191 | ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
192 | ce.scatter_add_(1, topk_idx_for_aux_loss,
193 | torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
194 | seq_len * aux_topk / self.n_routed_experts)
195 | aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
196 | else:
197 | mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
198 | ce = mask_ce.float().mean(0)
199 | Pi = scores_for_aux.mean(0)
200 | fi = ce * self.n_routed_experts
201 | aux_loss = (Pi * fi).sum() * self.alpha
202 | else:
203 | aux_loss = None
204 | return topk_idx, topk_weight, aux_loss
205 |
206 |
207 | class MOEFeedForward(nn.Module):
208 | def __init__(self, config: LMConfig):
209 | super().__init__()
210 | self.config = config
211 | self.experts = nn.ModuleList([
212 | FeedForward(
213 | dim=config.dim,
214 | hidden_dim=config.hidden_dim,
215 | multiple_of=config.multiple_of,
216 | dropout=config.dropout,
217 | )
218 | for _ in range(config.n_routed_experts)
219 | ])
220 |
221 | self.gate = MoEGate(config)
222 | if config.n_shared_experts is not None:
223 | self.shared_experts = FeedForward(
224 | dim=config.dim,
225 | hidden_dim=config.hidden_dim,
226 | multiple_of=config.multiple_of,
227 | dropout=config.dropout,
228 | )
229 |
230 | def forward(self, x):
231 | identity = x
232 | orig_shape = x.shape
233 | bsz, seq_len, _ = x.shape
234 |
235 |         # select experts via the gating mechanism
236 | topk_idx, topk_weight, aux_loss = self.gate(x)
237 |
238 | x = x.view(-1, x.shape[-1])
239 | flat_topk_idx = topk_idx.view(-1)
240 |
241 | if self.training:
242 |             # in training mode, repeat the input once per selected expert
243 | x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
244 | y = torch.empty_like(x, dtype=torch.float16)
245 | for i, expert in enumerate(self.experts):
246 | y[flat_topk_idx == i] = expert(x[flat_topk_idx == i])
247 | y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
248 | y = y.view(*orig_shape)
249 | else:
250 |             # in inference mode, only the selected experts are used
251 | y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
252 |
253 | if self.config.n_shared_experts is not None:
254 | y = y + self.shared_experts(identity)
255 |
256 | return y
257 |
258 | @torch.no_grad()
259 | def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
260 | expert_cache = torch.zeros_like(x)
261 | idxs = flat_expert_indices.argsort()
262 | tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
263 | token_idxs = idxs // self.config.num_experts_per_tok
264 |         # e.g. if tokens_per_expert = [6, 15, 20, 26, 33, 38, 46, 52]
265 |         # and token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...]
266 |         # then the tokens at token_idxs[:6] -> [3, 7, 19, 21, 24, 25] are processed by expert 0, the tokens at token_idxs[6:15] by expert 1, and so on
267 | for i, end_idx in enumerate(tokens_per_expert):
268 | start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
269 | if start_idx == end_idx:
270 | continue
271 | expert = self.experts[i]
272 | exp_token_idx = token_idxs[start_idx:end_idx]
273 | expert_tokens = x[exp_token_idx]
274 | expert_out = expert(expert_tokens)
275 | expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
276 |             # use scatter_add_ to sum the expert outputs back into place
277 | expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
278 |
279 | return expert_cache
280 |
281 |
282 | class TransformerBlock(nn.Module):
283 | def __init__(self, layer_id: int, args: LMConfig):
284 | super().__init__()
285 | self.n_heads = args.n_heads
286 | self.dim = args.dim
287 | self.head_dim = args.dim // args.n_heads
288 | self.attention = Attention(args)
289 |
290 | self.layer_id = layer_id
291 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
292 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
293 |
294 | if args.use_moe:
295 | self.feed_forward = MOEFeedForward(args)
296 | else:
297 | self.feed_forward = FeedForward(
298 | dim=args.dim,
299 | hidden_dim=args.hidden_dim,
300 | multiple_of=args.multiple_of,
301 | dropout=args.dropout,
302 | )
303 |
304 | def forward(self, x, pos_cis, kv_cache=False):
305 | h = x + self.attention(self.attention_norm(x), pos_cis, kv_cache)
306 | out = h + self.feed_forward(self.ffn_norm(h))
307 | return out
308 |
309 |
310 | class Transformer(PreTrainedModel):
311 | config_class = LMConfig
312 | last_loss: Optional[torch.Tensor]
313 |
314 | def __init__(self, params: LMConfig = None):
315 | super().__init__(params)
316 | if not params:
317 | params = LMConfig()
318 | self.params = params
319 | self.vocab_size = params.vocab_size
320 | self.n_layers = params.n_layers
321 |
322 | self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
323 | self.dropout = nn.Dropout(params.dropout)
324 | self.layers = torch.nn.ModuleList()
325 | for layer_id in range(self.n_layers):
326 | self.layers.append(TransformerBlock(layer_id, params))
327 | self.norm = RMSNorm(params.dim, eps=params.norm_eps)
328 | self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
329 | self.tok_embeddings.weight = self.output.weight
330 | pos_cis = precompute_pos_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
331 | self.register_buffer("pos_cis", pos_cis, persistent=False)
332 |
333 | self.apply(self._init_weights)
334 |
335 | for pn, p in self.named_parameters():
336 | if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
337 | torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * params.n_layers))
338 |
339 | self.last_loss = None
340 | self.OUT = CausalLMOutputWithPast()
341 | self._no_split_modules = [name for name, _ in self.named_modules()]
342 |
343 | def _init_weights(self, module):
344 | if isinstance(module, nn.Linear):
345 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
346 | if module.bias is not None:
347 | torch.nn.init.zeros_(module.bias)
348 | elif isinstance(module, nn.Embedding):
349 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
350 |
351 | def forward(self, tokens: Optional[torch.Tensor] = None, targets: Optional[torch.Tensor] = None,
352 | kv_cache=False, **keyargs):
353 | current_idx = 0
354 | if 'input_ids' in keyargs:
355 | tokens = keyargs['input_ids']
356 | if 'attention_mask' in keyargs:
357 | targets = keyargs['attention_mask']
358 | if 'current_idx' in keyargs:
359 | current_idx = int(keyargs['current_idx'])
360 |
361 | _bsz, seqlen = tokens.shape
362 | h = self.tok_embeddings(tokens)
363 | h = self.dropout(h)
364 | pos_cis = self.pos_cis[current_idx:current_idx + seqlen]
365 | for idx, layer in enumerate(self.layers):
366 | h = layer(h, pos_cis, kv_cache)
367 |
368 | h = self.norm(h)
369 |
370 | if targets is not None:
371 | logits = self.output(h)
372 | self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1),
373 | ignore_index=0, reduction='none')
374 | else:
375 | logits = self.output(h[:, [-1], :])
376 | self.last_loss = None
377 |
378 | self.OUT.__setitem__('logits', logits)
379 | self.OUT.__setitem__('last_loss', self.last_loss)
380 | return self.OUT
381 |
382 | @torch.inference_mode()
383 | def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=8, stream=True, rp=1., kv_cache=True):
384 | # rp: repetition_penalty
385 | index = idx.shape[1]
386 | init_inference = True
387 | while idx.shape[1] < max_new_tokens - 1:
388 | if init_inference or not kv_cache:
389 | inference_res, init_inference = self(idx, kv_cache=kv_cache), False
390 | else:
391 | inference_res = self(idx[:, -1:], kv_cache=kv_cache, current_idx=idx.shape[1] - 1)
392 |
393 | logits = inference_res.logits
394 | logits = logits[:, -1, :]
395 |
396 | for token in set(idx.tolist()[0]):
397 | logits[:, token] /= rp
398 |
399 | if temperature == 0.0:
400 | _, idx_next = torch.topk(logits, k=1, dim=-1)
401 | else:
402 | logits = logits / temperature
403 | if top_k is not None:
404 | v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
405 | logits[logits < v[:, [-1]]] = -float('Inf')
406 |
407 | probs = F.softmax(logits, dim=-1)
408 | idx_next = torch.multinomial(probs, num_samples=1, generator=None)
409 |
410 | if idx_next == eos:
411 | break
412 |
413 | idx = torch.cat((idx, idx_next), dim=1)
414 | if stream:
415 | yield idx[:, index:]
416 |
417 | if not stream:
418 | yield idx[:, index:]
419 |
420 | @torch.inference_mode()
421 | def eval_answer(self, idx):
422 | idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
423 | inference_res = self(idx_cond)
424 | logits = inference_res.logits
425 | logits = logits[:, -1, :]
426 | return logits
427 |
--------------------------------------------------------------------------------
/tools/.gitattributes:
--------------------------------------------------------------------------------
1 | labeled_numina_difficulty/labeled_amc_aime_0_-1.json filter=lfs diff=lfs merge=lfs -text
2 | labeled_numina_difficulty/labeled_math_0_-1.json filter=lfs diff=lfs merge=lfs -text
3 | labeled_numina_difficulty/labeled_olympiads_0_-1.json filter=lfs diff=lfs merge=lfs -text
4 |
--------------------------------------------------------------------------------
/tools/README.md:
--------------------------------------------------------------------------------
1 | # Data Generation, Processing, and Evaluation Tools
2 | This document describes the training data curation steps and the evaluation scripts for Sky-T1.
3 |
4 | ## Requirements
5 | First create the environment as follows.
6 | ```shell
7 | conda create -n eval python==3.10
8 | conda activate eval
9 | pip install -r requirements.txt
10 | ```
11 |
12 | To run OpenAI models, export the OpenAI API key.
13 | ```shell
14 | export OPENAI_API_KEY={openai_api_key}
15 | ```
16 |
17 | ## Training Data Curation
18 | ### Step 0 (Optional, only for NUMINA math dataset): Label Math Difficulty from NUMINA
19 | Put one or more OpenAI API keys in a file, e.g. `keys.txt` (one per line). If there is more than one key, the script uses them in a round-robin fashion to speed up generation. Label math difficulty using GPT-4o-mini:
20 | #### Example usage:
21 | ```
22 | python label_math_difficulty.py --source [amc_aime, math, olympiads] --keys keys.txt
23 | ```
24 | The expected output is `labeled_source_0_-1.json`. We also provide instructions for downloading these files under the `labeled_numina_difficulty` folder (download from HuggingFace).
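
For intuition, here is a minimal sketch of the round-robin key rotation (the variable names are placeholders, not the actual `label_math_difficulty.py` internals):
```python
from itertools import cycle

# Load one API key per line from keys.txt and rotate through them request by request.
with open("keys.txt") as f:
    keys = [line.strip() for line in f if line.strip()]

key_pool = cycle(keys)
api_key_for_next_request = next(key_pool)  # wraps around to the first key after the last
```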
25 |
26 | ### Step 1: Data Inference
27 | Run inference with QwQ on several datasets. In the preview version, we use data from the following datasets.
28 |
29 | ```shell
30 | python inference_and_check.py --dataset APPS --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data --inference
31 |
32 | python inference_and_check.py --dataset TACO --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source MEDIUM --filter-difficulty --result-dir $SKYT_HOME/data --inference
33 |
34 | python inference_and_check.py --dataset TACO --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data --inference
35 |
36 | python inference_and_check.py --dataset NUMINA --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source math --filter-difficulty --result-dir $SKYT_HOME/data --inference
37 |
38 | python inference_and_check.py --dataset NUMINA --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source amc_aime --filter-difficulty --result-dir $SKYT_HOME/data --inference
39 |
40 | python inference_and_check.py --dataset NUMINA --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source olympiads --end 20000 --filter-difficulty --result-dir $SKYT_HOME/data --inference
41 | ```
42 |
43 | ### Step 2: Format the response
44 | After obtaining the list files for training data, convert them to a unified format (note: this step uses GPT-4o-mini to rewrite the responses; the output is long and cost ~$100 for our preview data).
45 | ```shell
46 | python convert_format.py --input_dir $SKYT_HOME/data --keys keys.txt
47 | ```
48 |
49 | ### Step 3: Rejection Sampling on the formatted data (example usage with the previous script)
50 | ```shell
51 | python inference_and_check.py --dataset APPS --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data --check
52 | ```
53 | Similar for other datasets.
54 |
55 | ### Convert to ShareGPT format for training
56 | After obtaining multiple converted files, merge them together and convert them to the ShareGPT format for training. In our preview model, we also add the science and riddle portions from the [STILL-2 model](https://arxiv.org/pdf/2412.09413); interested readers can download that part of the data and simply concatenate it with the data obtained above.
57 | ```shell
58 | python convert_to_data.py --input_dir $SKYT_HOME/data --output $SKYT_HOME/data/train_data.json
59 | ```
60 |
61 |
62 | ## Generation and Evaluation
63 | The file `inference_and_check.py` provides convenient methods for generating sequences (e.g., for distillation or benchmark evaluation) and checking whether the generated solutions are correct (e.g., for rejection sampling or benchmark evaluation).
64 |
65 | ### Distillation and Rejection Sampling
66 | Currently we support distillation and rejection sampling from various self-hosted models for the NUMINA, APPS, and TACO datasets. For NUMINA, the source can be one of `[amc_aime, math, olympiads]`.
67 | #### Example Usage
68 |
69 | ```shell
70 | python inference_and_check.py --dataset APPS --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data
71 |
72 | python inference_and_check.py --dataset TACO --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source MEDIUM --filter-difficulty --result-dir $SKYT_HOME/data
73 |
74 | python inference_and_check.py --dataset TACO --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data
75 |
76 | python inference_and_check.py --dataset NUMINA --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source math --filter-difficulty --result-dir $SKYT_HOME/data --math_difficulty_lower_bound 4 --math_difficulty_upper_bound 9
77 |
78 | python inference_and_check.py --dataset NUMINA --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source amc_aime --filter-difficulty --result-dir $SKYT_HOME/data --math_difficulty_lower_bound 1 --math_difficulty_upper_bound 9
79 |
80 | python inference_and_check.py --dataset NUMINA --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source olympiads --end 20000 --filter-difficulty --result-dir $SKYT_HOME/data --math_difficulty_lower_bound 9 --math_difficulty_upper_bound 9
81 | ```
82 |
83 |
84 | #### Best-of-N Inference and Check
85 | ```bash
86 | python inference_and_check.py --dataset MATH500 --model Qwen/Qwen2-7B-Instruct --tp 4 --max_tokens 4096 --split test --result-dir ./ --inference --temperatures 0.7 --n 64
87 | python inference_and_check.py --dataset MATH500 --model Qwen/Qwen2-7B-Instruct --tp 4 --max_tokens 4096 --split test --result-dir ./ --check --temperatures 0.7 --n 8
88 | ```
89 |
90 | ### Benchmark Evaluations
91 | We provide a wrapper script `eval.py` to conveniently run reasoning benchmarks. We currently support `AIME`, `MATH500`, `GPQADiamond`, and `MMLU`. This script can be used to launch evaluations for multiple benchmarks, then aggregate and log the accuracy for all benchmarks.
92 |
93 | **Note**: The `GPQADiamond` dataset is gated and requires first receiving access at this Huggingface [link](https://huggingface.co/datasets/Idavidrein/gpqa) (which is granted immediately), then logging into your Huggingface account in your terminal session with `huggingface-cli login`.
94 |
95 | **NOTE**: To reproduce `Sky-T1-32B-Preview` results on the `AIME` and `GPQADiamond` datasets, pass in temperature `0.7`.
96 |
97 | ```shell
98 | python eval.py --model NovaSky-AI/Sky-T1-32B-Preview --evals=AIME,GPQADiamond --tp=8 --output_file=results.txt --temperatures 0.7
99 | ```
100 |
101 | #### Example Usage
102 | ```shell
103 | python eval.py --model Qwen/QwQ-32B-Preview --evals=AIME,MATH500,GPQADiamond --tp=8 --output_file=results.txt
104 | ```
105 |
106 | Example result: `{"AIME": , "MATH500": , "GPQADiamond": }`
107 |
108 | ## Response Rewriting
109 | The file `response_rewrite.py` provides a pipeline for filtering and rewriting responses generated with `inference_and_check.py`. We use `response_rewrite.py` to create preference pairs for preference optimization (e.g., DPO, SimPO); however, the logic can be edited for alternative filtering and rewriting steps. Details of the implemented logic can be found in `response_rewrite.py` or in [this blog post](https://novasky-ai.github.io/posts/reduce-overthinking).
110 |
111 | To use our preference optimization pipeline, first generate and score multiple responses using `inference_and_check.py`. For example:
112 |
113 | ```shell
114 | python inference_and_check.py --inference --dataset MATH500 --model Qwen/Qwen2-7B-Instruct --tp 4 --max_tokens 4096 --split test --result-dir ./ --temperatures 0.7 --n 8
115 | python inference_and_check.py --check --dataset MATH500 --model Qwen/Qwen2-7B-Instruct --tp 4 --max_tokens 4096 --split test --result-dir ./ --temperatures 0.7 --n 8
116 | ```
117 |
118 | Then, use `response_rewrite.py` to process the responses into preference pairs. By default, the shortest correct responses are used as positive examples and the longest correct responses as negative examples. The argument `--SILC` can be used to also include short incorrect responses as negative examples and long correct responses as positive examples (see the sketch after the command below).
119 |
120 | ```shell
121 | python response_rewrite.py --SILC --rewrite-model meta-llama/Meta-Llama-3-8B-Instruct --target-model NovaSky-AI/Sky-T1-32B-Preview --dataset [PATH_TO_GENERATED_RESPONSES] --result-dir ./ --checkpoint --tp 8
122 | ```
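
For intuition, a minimal sketch of the default pairing rule described above (the field names are assumptions for illustration, not the actual `response_rewrite.py` logic):
```python
def make_preference_pair(responses):
    # Assumed input: a list of dicts with "content" and "correctness" fields.
    correct = [r for r in responses if r["correctness"]]
    if len(correct) < 2:
        return None  # need at least two correct responses to form a pair
    by_length = sorted(correct, key=lambda r: len(r["content"]))
    # Shortest correct response is the positive example, longest correct is the negative one.
    return {"chosen": by_length[0]["content"], "rejected": by_length[-1]["content"]}
```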
123 |
124 | The `--checkpoint` argument can optionally be used to save intermediate files of the processed data between steps, in case of failure.
125 |
126 | The resulting `.json` files can be used to train a model with preference optimization algorithms. See the `/train/` directory for more details.
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.path.abspath(os.path.join(__file__, "../../")))
4 | from utils.settings import *
5 |
6 |
7 |
--------------------------------------------------------------------------------
/tools/base_instruct_evals.md:
--------------------------------------------------------------------------------
1 | # Reproducing results on non-reasoning benchmarks
2 |
3 | For the full set of results, see [here](./README.md#results-on-qa-and-instruction-following-benchmarks).
4 |
5 | ## Installation instructions
6 |
7 | 1. For `lm_eval`, install the package by executing the following :
8 |
9 | ```bash
10 | git clone https://github.com/EleutherAI/lm-evaluation-harness
11 | cd lm-evaluation-harness
12 | git checkout 703fbff
13 | pip install -e ".[ifeval]"
14 | ```
15 |
16 | For more details, you can refer to the official instructions [here](https://github.com/EleutherAI/lm-evaluation-harness/tree/703fbffd6fe5e136bbb9d884cb40844e5503ae5d?tab=readme-ov-file#install). We report results with commit https://github.com/EleutherAI/lm-evaluation-harness/commit/703fbffd6fe5e136bbb9d884cb40844e5503ae5d
17 |
18 | 2. For `fastchat`, follow the instructions [here](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge#install). The current implementation of Fastchat is based on OpenAI version <= 0.28.0. For making use of the latest vllm backend, it is recommended to migrate the `llm_judge` folder to use openai>=1.0.0. You can run `openai migrate` for the fastchat codebase or follow the PR [here](https://github.com/lm-sys/FastChat/pull/2915/files)
19 | 3. For `BFCL`, you can follow the official instructions [here](https://github.com/ShishirPatil/gorilla/tree/main/berkeley-function-call-leaderboard#basic-installation). We further evaluate on all test categories, which requires [setting up environment variables](https://github.com/ShishirPatil/gorilla/tree/main/berkeley-function-call-leaderboard#setting-up-environment-variables), and [obtaining API keys for executable test categories](https://github.com/ShishirPatil/gorilla/tree/main/berkeley-function-call-leaderboard#api-keys-for-executable-test-categories). Make sure to use changes from [this PR](https://github.com/ShishirPatil/gorilla/pull/888) for QwQ and Sky-T1 model support.
20 | 4. For `Arena-Hard` results, you can follow the instructions [here](https://github.com/lmarena/arena-hard-auto). We use `gpt-4-1106-preview` as the judge.
21 |
22 | ## Commands for reproducing results
23 |
24 | All the benchmarks were run on an 8xH100 machine with the `vllm` backend. If you're running on a different device, make sure to tweak `tensor_parallel_size` and, if needed, the `batch_size` argument. Expect some variance in scores (+/- 1%) across different evaluation settings (e.g., `tensor_parallel_size`).
25 |
26 | All the commands below are given for `NovaSky-AI/Sky-T1-32B-Preview`. Simply substitute the model name to run `Qwen/Qwen2.5-32B-Instruct`. For `Qwen/QwQ-32B-Preview`, we additionally pass the arguments `revision=refs/pr/58,tokenizer_revision=refs/pr/58` to use a corrected revision of QwQ. For more details on this, see https://github.com/NovaSky-AI/SkyThought/pull/26#issuecomment-2606435601.
27 |
28 | ### MMLU (0 shot; no CoT)
29 |
30 | ```bash
31 | lm_eval --model vllm --model_args pretrained=NovaSky-AI/Sky-T1-32B-Preview,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1,max_model_len=2048 --tasks mmlu --trust_remote_code --batch_size 8 --apply_chat_template --fewshot_as_multiturn
32 | ```
33 |
34 | For QwQ, you would do
35 |
36 | ```bash
37 | lm_eval --model vllm --model_args pretrained=Qwen/QwQ-32B-Preview,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1,max_model_len=2048,revision=refs/pr/58,tokenizer_revision=refs/pr/58 --tasks mmlu --trust_remote_code --batch_size 8 --apply_chat_template --fewshot_as_multiturn
38 | ```
39 |
40 | ### MMLU (5 shot; no CoT)
41 |
42 | ```bash
43 | lm_eval --model vllm --model_args pretrained=NovaSky-AI/Sky-T1-32B-Preview,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1,max_model_len=2048 --tasks mmlu --trust_remote_code --batch_size 8 --apply_chat_template --fewshot_as_multiturn --num_fewshot 5
44 | ```
45 |
46 | ### ARC-C (0 shot; no CoT)
47 |
48 | ```bash
49 | lm_eval --model vllm --model_args pretrained=NovaSky-AI/Sky-T1-32B-Preview,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1,max_model_len=2048 --tasks arc_challenge --trust_remote_code --batch_size 8 --apply_chat_template --fewshot_as_multiturn
50 | ```
51 |
52 | ### IFEval
53 |
54 | ```bash
55 | lm_eval --model vllm --model_args pretrained=NovaSky-AI/Sky-T1-32B-Preview,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.9,data_parallel_size=1 --tasks leaderboard_ifeval --trust_remote_code --batch_size auto --apply_chat_template --fewshot_as_multiturn
56 | ```
57 |
58 | We use the `prompt_level_strict_acc` metric following Qwen-2.5.
59 |
60 | ### MGSM (native CoT)
61 |
62 | ```bash
63 | lm_eval --model vllm --model_args pretrained=NovaSky-AI/Sky-T1-32B-Preview,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1,max_model_len=2048 --tasks mgsm_direct --trust_remote_code --batch_size 8 --apply_chat_template --fewshot_as_multiturn
64 | ```
65 |
66 | We report the average value of `flexible-extract` filter.
67 |
68 | ### MGSM (8-shot; native CoT)
69 |
70 | ```bash
71 | lm_eval --model vllm --model_args pretrained=NovaSky-AI/Sky-T1-32B-Preview,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1,max_model_len=2048 --tasks mgsm_direct --trust_remote_code --batch_size 8 --apply_chat_template --fewshot_as_multiturn --num_fewshot 8
72 | ```
73 |
74 | ### LLM-as-a-Judge
75 |
76 | We use the default settings - with `max_tokens` 1024 and the `gpt-4` judge. We observe that some reasoning models like `Qwen/QwQ-32B-Preview` are sometimes unable to provide brief responses and thus get truncated at the `max_tokens` limit. While this will affect the final rating, given the context length limitations of the commonly used `gpt-4` judge (8K tokens), we stick to the 1024 `max_tokens` budget for consistency.
77 |
78 | 1. First, serve the model with vLLM
79 |
80 |
81 | ```bash
82 | vllm serve NovaSky-AI/Sky-T1-32B-Preview --dtype auto --tensor-parallel-size 8 --gpu-memory-utilization 0.9
83 | ```
84 |
85 | For `Qwen/QwQ-32B-Preview`, use
86 |
87 | ```bash
88 | vllm serve Qwen/QwQ-32B-Preview --dtype auto --tensor-parallel-size 8 --gpu-memory-utilization 0.9 --revision refs/pr/58 --tokenizer-revision refs/pr/58
89 | ```
90 |
91 | 2. Next, generate model response
92 |
93 | ```bash
94 | python gen_api_answer.py --model NovaSky-AI/Sky-T1-32B-Preview --openai-api-base http://localhost:8000/v1 --parallel 50
95 | ```
96 |
97 | Note: The generated results will be in `data/model_answer//.jsonl`. Move them to the root folder `data/model_answer/`.
98 |
99 | 3. After generating responses for all the models, evaluate with the default settings
100 |
101 | ```bash
102 | export OPENAI_API_KEY=XXXXXX # set the OpenAI API key
103 | python gen_judgment.py --model-list Sky-T1-32B-Preview QwQ-32B-Preview Qwen2.5-32B-Instruct --parallel 2
104 | ```
105 | 4. Get MTBench scores (we use the average score of both turns)
106 |
107 | ```bash
108 | python show_result.py
109 | ```
110 |
111 | ### BFCL-v3
112 |
113 | Our results are reported on `test-category` `all` . Make sure to get the API keys for the executable test categories by following the instructions [here](https://github.com/ShishirPatil/gorilla/tree/main/berkeley-function-call-leaderboard#api-keys-for-executable-test-categories)
114 |
115 | Run
116 |
117 | ```bash
118 | bfcl generate --model NovaSky-AI/Sky-T1-32B-Preview --test-category all --backend vllm --num-gpus 8 --gpu-memory-utilization 0.9
119 | ```
120 |
121 | For evaluation, you can simply run
122 |
123 | ```bash
124 | bfcl evaluate --model Qwen/QwQ-32B-Preview,NovaSky-AI/Sky-T1-32B-Preview,Qwen/Qwen2.5-32B-Instruct --test-category all --api-sanity-check
125 | ```
126 | ### Arena Hard
127 | For `Arena-Hard`, we use the following script to start a `TGI` service for generating answers
128 | ```bash
129 | hf_pat=
130 | model=NovaSky-AI/Sky-T1-32B-Preview
131 | volume=/mnt/local_storage/data/cache
132 | port=1996
133 |
134 | huggingface-cli download $model
135 | sudo docker run --gpus 8 -e HUGGING_FACE_HUB_TOKEN=$hf_pat --shm-size 2000g -p $port:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.2.0 --model-id $model --max-input-length 8192 --max-batch-total-tokens 8193 --max-batch-prefill-tokens 8193 --max-total-tokens 8193 --sharded true
136 | ```
137 | For running the `gen_answer.py` script, we use the following `config_api` yaml setting. For `qwq-32b-preview`, we explicitly specify the system prompt as `You are a helpful and harmless assistant. You are Qwen developed by Alibaba.` to avoid the CoT prompt.
138 | ```yaml
139 | ...
140 | sky-T1-32B-Preview:
141 | model_name: sky-T1-32B-Preview
142 | endpoints:
143 | - api_base: http://localhost:1996/v1
144 | api_key: empty
145 | api_type: openai
146 | parallel: 8
147 | ...
148 | ```
149 | and finally for `gen_judgment.py`, we use `gpt-4-1106-preview` as the judge.
150 |
151 | #### Supplementary results for Arena-Hard
152 |
153 | Here are some supplementary results for Arena-Hard, compared with o1-mini which is the best performing model on this benchmark (as of Jan 2025).
154 |
155 | | model | score | rating_q025 | rating_q975 | CI | avg_tokens | date |
156 | |-------|--------|------------|-------------|-------|------------|-------|
157 | | o1-mini-2024-09-12 | 91.98 | 90.88 | 93.12 | (-1.10, +1.14) | 1399.0 | 2025-01-18 |
158 | | sky-T1-32B-Preview | 74.79 | 72.28 | 76.8 | (-2.51, +2.01) | 847.0 | 2025-01-18 |
159 | | qwen2.5-32b-instruct | 66.51 | 64.55 | 68.4 | (-1.96, +1.89) | 611.0 | 2025-01-18 |
160 | | qwq-32b-preview | 52.6 | 50.86 | 54.91 | (-1.74, +2.31) | 1005.0 | 2025-01-23 |
161 |
162 | For more details, see: https://github.com/NovaSky-AI/SkyThought/pull/26#issuecomment-2599525551
163 |
--------------------------------------------------------------------------------
/tools/combine_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | from util.prompts import system_prompt
4 |
5 | still2_jsonl_file = "../../data/public_long_form_thought_data_5k.jsonl"
6 | code_json_file = "../../data/converted_apps_long_form_thought_data_5k.json"
7 | output_file = "../../data/converted_v2_long_form_thought_data_9k.json"
8 |
9 | # Load the JSONL file
10 | still2_data = []
11 | with open(still2_jsonl_file, "r") as f:
12 | for line in f:
13 | still2_data.append(json.loads(line.strip()))
14 | # print(math_data)
15 |
16 | # Process the data into the desired format
17 | all_data = []
18 | code_num = 0
19 |
20 | for entry in still2_data:
21 | question = entry["question"]
22 | combined_text = entry["combined_text"]
23 | domain = entry["domain"]
24 | if domain != "code":
25 | # Create the conversation format
26 | conversations = [
27 | {"from": "user", "value": question},
28 | {"from": "assistant", "value": combined_text}
29 | ]
30 |
31 | # Prepare the final structure
32 | cur_data = {
33 | "system": system_prompt,
34 | "conversations": conversations
35 | }
36 | all_data.append(cur_data)
37 | else:
38 | code_num += 1
39 |
40 | print(code_num)
41 | with open(code_json_file, "r") as f:
42 | code_data = json.load(f)
43 | # print(code_data[0])
44 |
45 | all_data.extend(code_data)
46 | print(f"First item slice before shuffle: {all_data[0]['conversations'][-1]['value'][-50:-1]}")
47 | random.shuffle(all_data)
48 | print(f"First item slice after shuffle: {all_data[0]['conversations'][-1]['value'][-50:-1]}")
49 | print(len(all_data))
50 |
51 | # Save the converted data to the output file
52 | with open(output_file, "w") as f:
53 | json.dump(all_data, f, indent=4)
54 |
55 | print(f"Conversion completed. The data has been saved to {output_file} with {len(all_data)} data.")
56 |
57 |
--------------------------------------------------------------------------------
/tools/convert_format.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | from tqdm import tqdm
4 | import multiprocessing as mp
5 | import openai
6 | from itertools import cycle
7 | import time
8 | import os
9 | from util.prompts import convert_prompt, convert_prompt_example
10 |
11 | global args
12 | # Function to set the OpenAI API key
13 | def set_openai_key(api_key):
14 | openai.api_key = api_key
15 |
16 | # GPT API processing function with retry logic
17 | def process_content(content, api_key):
18 | # Set the OpenAI key for this request
19 | set_openai_key(api_key)
20 |
21 | # GPT prompt
22 | prompt = convert_prompt.format(example=convert_prompt_example, content=content)
23 |
24 | retries = 3
25 | while retries > 0:
26 | try:
27 | # OpenAI API call
28 | response = openai.chat.completions.create(
29 | model="gpt-4o-mini",
30 | messages=[
31 | {"role": "system", "content": "You are a solution format convertor."},
32 | {"role": "user", "content": prompt}
33 | ],
34 | max_tokens=16384,
35 | temperature=0.7
36 | )
37 | return response.choices[0].message.content
38 | except openai.RateLimitError:
39 | retries -= 1
40 | if retries == 0:
41 | return "Error: Rate limit reached and retries exhausted."
42 | print(f"Sleep for 5 seconds for API limit.")
43 | time.sleep(5)
44 | except Exception as e:
45 | return f"Error processing content: {e}"
46 |
47 | # Function for multiprocessing
48 | def process_entry(entry, api_key_cycle):
49 | key, values = entry
50 | content = values["responses"]["0.7"]["content"]
51 |
52 | # Get the next API key from the cycle
53 | api_key = next(api_key_cycle)
54 |
55 | processed = process_content(content, api_key)
56 | values["responses"]["0.7"]["processed_content"] = processed
57 |
58 | return key, values
59 |
60 | # Wrapper function for multiprocessing
61 | def process_entry_wrapper(args):
62 | return process_entry(*args)
63 |
64 | if __name__ == "__main__":
65 | parser = argparse.ArgumentParser(description="Process content and save results.")
66 | parser.add_argument("--input_dir", type=str, help="Input directory containing JSON files.")
67 | parser.add_argument("--keys", type=str, help="File containing OpenAI API keys (one per line).")
68 |
69 | global args
70 | args = parser.parse_args()
71 |
72 | # Load API keys and prepare a round-robin cycle
73 | with open(args.keys, "r") as f:
74 | api_keys = [line.strip() for line in f if line.strip()]
75 | api_key_cycle = cycle(api_keys)
76 |
77 | # Process each file in the input directory
78 | for filename in os.listdir(args.input_dir):
79 | if filename.endswith(".json"):
80 | input_path = os.path.join(args.input_dir, filename)
81 |
82 | # Load the data
83 | with open(input_path, "r") as f:
84 | data = json.load(f)
85 |
86 | # Prepare output file
87 | output_file = os.path.join(args.input_dir, f"converted_{filename}")
88 |
89 | # Use multiprocessing to process the content
90 | results = []
91 | with mp.Pool(os.cpu_count()) as pool:
92 | tasks = [(entry, api_key_cycle) for entry in data.items()]
93 | for result in tqdm(pool.imap(process_entry_wrapper, tasks), total=len(data)):
94 | results.append(result)
95 |
96 | # Aggregate and write results in the main process
97 | aggregated_data = {key: values for key, values in results}
98 | with open(output_file, "w") as f:
99 | json.dump(aggregated_data, f, indent=4)
100 |
101 | print(f"Processed data saved to {output_file}")
102 |
--------------------------------------------------------------------------------
/tools/convert_to_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | from util.prompts import system_prompt
5 |
6 | def main():
7 | parser = argparse.ArgumentParser(description="Convert JSON data for processing.")
8 | parser.add_argument("--input_dir", type=str, help="Directory containing input JSON files.")
9 | parser.add_argument("--output", type=str, help="Output JSON file.")
10 | args = parser.parse_args()
11 |
12 | all_data = []
13 |
14 | # Iterate through all files in the input directory
15 | for filename in os.listdir(args.input_dir):
16 | if filename.endswith(".json") and filename.startswith("converted"):
17 | filepath = os.path.join(args.input_dir, filename)
18 | with open(filepath, "r") as f:
19 | cur_data = json.load(f)
20 |
21 | for _, v in cur_data.items():
22 | prompt = v["prompt"]
23 | response_data = v["responses"]
24 |
25 | for cur_temp, cur_temp_response in response_data.items():
26 | # Only support 0.7 for this version
27 | assert cur_temp == "0.7", "Only support a single temperature=0.7 now."
28 | # Accept this data
29 | if cur_temp_response["correctness"]:
30 | # Create the conversation format
31 | conversations = [
32 | {"from": "user", "value": prompt},
33 | {"from": "assistant", "value": cur_temp_response["processed_content"]}
34 | ]
35 |
36 |                         # Prepare the final structure (new name, so we don't shadow the loaded file dict `cur_data`)
37 |                         formatted_entry = {
38 |                             "system": system_prompt,
39 |                             "conversations": conversations
40 |                         }
41 |                         all_data.append(formatted_entry)
42 |
43 | # Save the converted data to the output file
44 | with open(args.output, "w") as f:
45 | json.dump(all_data, f, indent=4)
46 |
47 | print(f"Conversion completed. The data has been saved to {args.output} with {len(all_data)} data.")
48 |
49 | if __name__ == "__main__":
50 | main()
51 |
--------------------------------------------------------------------------------
/tools/eval.py:
--------------------------------------------------------------------------------
1 | import __init__
2 | import argparse
3 | import subprocess
4 | import os
5 | import json
6 |
7 |
8 | # Define eval to split mapping
9 | eval_to_split = {
10 | "MATH500": "test",
11 | "MinvervaMath": "test",
12 | "AIME": "train",
13 | "GPQADiamond": "train",
14 | "MMLU": "test",
15 | "MMLUPro": "test",
16 | "LiveCodeBench": "test",
17 | "GSM8K": "test",
18 | "ARC-C": "test",
19 | "AMC23": "train",
20 | "OlympiadBenchMath": "train",
21 | }
22 |
23 | def parse_arguments():
24 | parser = argparse.ArgumentParser(description="Process model path, prompt format, and evals to run.")
25 | parser.add_argument("--model", required=True, type=str, default='bespokelabs/Bespoke-Stratos-7B', help="Path to the model.")
26 | parser.add_argument("--evals", required=True, type=str, default='MinvervaMath,AIME,MATH500,GPQADiamond,GSM8K,OlympiadBenchMath', help="Comma-separated list of evals to run (no spaces).")
27 | parser.add_argument("--tp", type=int, default=1, help="Tensor Parallelism Degree")
28 | parser.add_argument("--filter-difficulty", action="store_true", help="Filter difficulty.")
29 | parser.add_argument("--source", type=str, help="Source for the dataset.")
30 | parser.add_argument("--output_file", required=True, type=str, default='./eval/Bespoke-7B.txt', help="Output file to write results to.")
31 | parser.add_argument("--temperatures", type=float, nargs="+", default=[0.7], help="Temperature for sampling.")
32 |     parser.add_argument("--result_dir", type=str, default='./results/generated', help="Directory to save generated result files.")
33 | return parser.parse_args()
34 |
35 | def extract_accuracy_from_output(output):
36 | # Iterate through all lines from the end to the beginning
37 | lines = output.splitlines()[::-1]
38 | for line in lines:
39 | try:
40 | # Attempt to parse a JSON object from the line
41 | data = json.loads(line.replace("'", '"'))
42 | if "acc" in data:
43 | return data["acc"]
44 | except json.JSONDecodeError:
45 | continue
46 | return None
47 |
48 | def write_logs_to_file(logs, output_file):
49 | try:
50 | with open(output_file, "w") as file:
51 | file.write(logs)
52 | print(f"Logs successfully written to {output_file}")
53 | except IOError as e:
54 | print(f"Failed to write logs to file {output_file}: {e}")
55 |
56 | def main():
57 | args = parse_arguments()
58 |
59 | # Extract the arguments
60 | model_path = args.model
61 | evals = args.evals.split(",")
62 | output_file = args.output_file
63 | tp = args.tp
64 | temperatures = [str(t) for t in args.temperatures]
65 |
66 | script_path = "inference_and_check.py"
67 |
68 | # Hold all logs
69 | all_logs = ""
70 | results = {}
71 | result_dir = args.model.split("/")[-2] if 'check' in args.model.split("/")[-1] else args.model.split("/")[-1]
72 | result_dir = os.path.join(args.result_dir,result_dir)
73 | # Run the Python command for each eval and collect logs
74 | for eval_name in evals:
75 | command = [
76 | "python", script_path,
77 | "--model", model_path,
78 | "--dataset", eval_name,
79 | "--split", eval_to_split[eval_name],
80 | "--tp", str(tp),
81 | "--result-dir", result_dir,
82 | "--temperatures"
83 | ]
84 | command.extend(temperatures) # Add temperatures as separate arguments
85 |
86 | if args.filter_difficulty:
87 | assert args.source != "", "No source passed for filtering difficulty."
88 | command.append("--filter-difficulty")
89 | command.append("--source")
90 | command.append(args.source)
91 | print(f"Running eval {eval_name} with command {command}")
92 | all_logs += f"\nRunning eval: {eval_name} with command {command}\n"
93 | try:
94 | with subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) as proc:
95 | output_lines = []
96 | for line in proc.stdout:
97 | print(line, end="") # Stream output to the console
98 | output_lines.append(line)
99 | all_logs += line
100 | proc.wait()
101 | if proc.returncode != 0:
102 | raise subprocess.CalledProcessError(proc.returncode, command)
103 |
104 | # Capture output for post-processing
105 | output = "".join(output_lines)
106 | accuracy = extract_accuracy_from_output(output)
107 | results[eval_name] = accuracy
108 |
109 | except subprocess.CalledProcessError as e:
110 | error_message = f"Error occurred while running eval {eval_name}: {e}\n"
111 | print(error_message)
112 | all_logs += error_message
113 |
114 | # Write logs of all stdout / stderr to a file
115 | write_logs_to_file(all_logs, output_file)
116 |
117 | print("Results:")
118 | print(results)
119 |
120 | if __name__ == "__main__":
121 | main()
122 |
--------------------------------------------------------------------------------
/tools/inference_and_check.py:
--------------------------------------------------------------------------------
1 | import __init__
2 | import json
3 | import argparse
4 | import re
5 | from concurrent.futures import ProcessPoolExecutor, as_completed
6 | from vllm import LLM, SamplingParams
7 | from tqdm import tqdm
8 | from util.task_handlers import *
9 | from util.model_utils import *
10 | from openai import OpenAI
11 | import concurrent.futures
12 | from functools import partial
13 |
14 | class NumpyEncoder(json.JSONEncoder):
15 | def default(self, obj):
16 | if isinstance(obj, np.ndarray):
17 | return obj.tolist()
18 | return super().default(obj)
19 |
20 | def fetch_response_openai(llm, model_name, max_tokens, temp, prompt):
21 | model_name = model_name.replace("openai/", "")
22 | if "o1" in model_name:
23 | # O1 doesn't support system prompt
24 | # NOTE: might want to implement this inside handler instead
25 | for p in prompt:
26 | p["role"] = "user"
27 |
28 | response = llm.chat.completions.create(
29 | model=model_name,
30 | messages=prompt,
31 | n=1,
32 | temperature=1, # has to be 1
33 | max_completion_tokens=max_tokens
34 | )
35 | else:
36 | response = llm.chat.completions.create(
37 | model=model_name,
38 | messages=prompt,
39 | n=1,
40 | temperature=temp,
41 | max_tokens=max_tokens
42 | )
43 | return response
44 |
45 | def perform_inference_and_check(handler: TaskHandler, temperatures, max_tokens, result_file, llm, system_prompt, args):
46 | results = handler.load_existing_results(result_file)
47 | print(f"Loaded {len(results)} existing results.")
48 | train_data = handler.load_and_filter_dataset(args.start, args.end, split=args.split, source=args.source, \
49 | filter_difficulty=args.filter_difficulty, args=args)
50 | remaining_data = handler.process_remaining_data(train_data, results)
51 | conversations = handler.make_conversations(remaining_data, system_prompt, args.model)
52 |
53 | for temp in temperatures:
54 |
55 | if args.model.startswith("openai"):
56 | fetch_partial = partial(fetch_response_openai, llm, args.model, max_tokens, temp)
57 |
58 | with concurrent.futures.ThreadPoolExecutor(max_workers=16) as e:
59 | responses = list(e.map(fetch_partial, conversations))
60 |
61 | else:
62 | sampling_params = SamplingParams(max_tokens=max_tokens, temperature=temp)
63 | responses = llm.chat(messages=conversations, sampling_params=sampling_params, use_tqdm=True)
64 |
65 | total_correct = 0
66 | total_finish = 0
67 | with ProcessPoolExecutor(max_workers=32) as executor:
68 | # future_to_task = {
69 | # executor.submit(handler.update_results, remaining_data[idx], response): idx
70 | # for idx, response in enumerate(responses)
71 | # }
72 | future_to_task = {}
73 | token_usages = {}
74 | for idx, response in enumerate(responses):
75 | if args.model.startswith("openai"):
76 | response_str = response.choices[0].message.content.strip()
77 | else:
78 | response_str = response.outputs[0].text.strip()
79 | future_to_task[executor.submit(handler.update_results, remaining_data[idx], response_str)] = idx
80 | # print(f"Request output: {response}")
81 |
82 | if args.model.startswith("openai"):
83 | token_usages[idx] = response.usage
84 | else:
85 | token_usages[idx] = {
86 | "completion_tokens": len(response.outputs[0].token_ids),
87 | "prompt_tokens": len(response.prompt_token_ids)
88 | }
89 |
90 | for future in tqdm(as_completed(future_to_task), total=len(future_to_task), desc="Processing Generations"):
91 | idx = future_to_task[future]
92 | response_entry = future.result()
93 | total_correct += response_entry["correctness"]
94 | total_finish += 1
95 |
96 | problem_key = remaining_data[idx][handler.get_question_key()]
97 | if problem_key not in results:
98 | results[problem_key] = remaining_data[idx]
99 | if isinstance(handler, NUMINATaskHandler):
100 | results[problem_key]["messages"] = ""
101 | results[problem_key]["responses"] = {}
102 | results[problem_key]["token_usages"] = {}
103 | prompt = conversations[idx][1]["content"]
104 | results[problem_key]["prompt"] = prompt
105 | results[problem_key]["input_conversation"] = conversations[idx]
106 |
107 | results[problem_key]["responses"][str(temp)] = response_entry
108 |
109 | if args.model.startswith("openai"):
110 | results[problem_key]["token_usages"][str(temp)] = {
111 | "completion_tokens": token_usages[idx].completion_tokens,
112 | "prompt_tokens": token_usages[idx].prompt_tokens,
113 | }
114 | else:
115 | # TODO: vLLM model, can it do the same thing
116 | results[problem_key]["token_usages"][str(temp)] = token_usages[idx]
117 |
118 | print(f"Final acc: {total_correct}/{total_finish}")
119 | acc = round(total_correct / total_finish, 4) if total_finish > 0 else 0
120 | print(json.dumps({"acc": acc}))
121 |
122 | completion_tokens = [
123 | results[key].get("token_usages", {}).get(str(temp), {}).get("completion_tokens", 0)
124 | for key in results for temp in temperatures
125 | ]
126 | prompt_tokens = [
127 | results[key].get("token_usages", {}).get(str(temp), {}).get("prompt_tokens", 0)
128 | for key in results for temp in temperatures
129 | ]
130 |
131 | # Token usage summary
132 | result_dir, result_name = os.path.split(result_file)
133 | token_usage_dir = os.path.join(result_dir, "token_usage")
134 | os.makedirs(token_usage_dir, exist_ok=True)
135 |
136 | # Construct the token usage result file path
137 | token_usage_result_file = os.path.join(token_usage_dir, result_name)
138 |
139 | # Prepare the token usage dictionary
140 | token_dict = {
141 | "completion_tokens": sum(completion_tokens),
142 | "prompt_tokens": sum(prompt_tokens),
143 | "avg_completion_tokens": round(sum(completion_tokens) / len(completion_tokens), 3) if completion_tokens else 0,
144 | "avg_prompt_tokens": round(sum(prompt_tokens) / len(prompt_tokens), 3) if prompt_tokens else 0,
145 | }
146 |
147 | # Save the token usage dictionary to the result file
148 | with open(token_usage_result_file, "w") as f:
149 | json.dump(token_dict, f, indent=4)
150 |
151 | print(f"Token usage saved to {token_usage_result_file}")
152 |
153 | with open(result_file, 'w', encoding='utf-8') as file:
154 | json.dump(results, file, ensure_ascii=False, indent=4, cls=NumpyEncoder)
155 |
156 | def perform_check(handler: TaskHandler, temperatures, result_file, args):
157 | results = handler.load_existing_results(result_file)
158 | print(f"Loaded {len(results)} existing results.")
159 |
160 | train_data = handler.load_and_filter_dataset(args.start, args.end, split=args.split, source=args.source, \
161 | filter_difficulty=args.filter_difficulty, args=args)
162 | remaining_data = handler.process_remaining_data(train_data, {})
163 |
164 | tasks = []
165 | for item in remaining_data:
166 | problem_key = item[handler.get_question_key()]
167 | # If this item exists in the results file, check each temperature
168 | if problem_key in results and "responses" in results[problem_key]:
169 | for temp in temperatures:
170 | if str(temp) in results[problem_key]["responses"]:
171 | response_entries = results[problem_key]["responses"][str(temp)]
172 | for sample_id, response_entry in enumerate(response_entries):
173 | if sample_id > (args.n - 1): continue
174 | if True or response_entry["correctness"] is None:
175 | processed = "processed_content" in response_entry
176 | tasks.append((item, temp, response_entry["processed_content"] if processed else response_entry["content"], sample_id))
177 |
178 | print(f"Found {len(tasks)} responses requiring reject sampling...")
179 |
180 | total_correct = 0
181 | total_finish = 0
182 | correct = { temp: {} for temp in temperatures }
183 | with ProcessPoolExecutor(max_workers=32) as executor:
184 | future_to_task = {
185 | executor.submit(handler.update_results, item, content): (item, temp, sample_id)
186 | for (item, temp, content, sample_id) in tasks
187 | }
188 |
189 | # 4. Collect the results as they finish.
190 | for future in tqdm(as_completed(future_to_task), total=len(future_to_task), desc="Processing Reject Sampling"):
191 | item, temp, sample_id = future_to_task[future]
192 | new_response_entry = future.result()
193 | total_correct += new_response_entry["correctness"]
194 | total_finish += 1
195 |
196 | # Update the corresponding record in results
197 | problem_key = item[handler.get_question_key()]
198 | if problem_key not in correct[temp]:
199 | correct[temp][problem_key] = False
200 | if new_response_entry["correctness"]:
201 | correct[temp][problem_key] = True
202 | assert problem_key in results and "responses" in results[problem_key] and str(temp) in results[problem_key]["responses"]
203 | response_entry = results[problem_key]["responses"][str(temp)][sample_id]
204 | response_entry["correctness"] = new_response_entry["correctness"]
205 | response_entry["reason"] = new_response_entry["reason"]
206 | results[problem_key]["responses"][str(temp)][sample_id] = response_entry
207 |
208 | print(f"Final reject-sampling accuracy: {total_correct}/{total_finish}")
209 | # per temperature acc
210 | for temp in temperatures:
211 | temp_correct = sum(correct[temp].values())
212 | temp_total = len(correct[temp])
213 | temp_acc = round(temp_correct / temp_total, 4) if temp_total > 0 else 0
214 | print(f"Temperature {temp} acc: {temp_correct}/{temp_total} ({temp_acc})")
215 |
216 | with open(result_file, 'w', encoding='utf-8') as file:
217 | json.dump(results, file, ensure_ascii=False, indent=4, cls=NumpyEncoder)
218 |
219 | def perform_inference_and_save(handler: TaskHandler, temperatures, max_tokens, result_file, llm, system_prompt, args):
220 | print(system_prompt)
221 | results = handler.load_existing_results(result_file)
222 | print(f"Loaded {len(results)} existing results.")
223 | train_data = handler.load_and_filter_dataset(args.start, args.end, split=args.split, source=args.source, \
224 | filter_difficulty=args.filter_difficulty, args=args)
225 | remaining_data = handler.process_remaining_data(train_data, results)
226 | conversations = handler.make_conversations(remaining_data, system_prompt, args.model)
227 |
228 | for temp in temperatures:
229 | if args.model.startswith("openai"):
230 | fetch_partial = partial(fetch_response_openai, llm, args.model, max_tokens, temp)
231 |
232 | with concurrent.futures.ThreadPoolExecutor(max_workers=16) as e:
233 | responses = list(e.map(fetch_partial, conversations))
234 |
235 | else:
236 | sampling_params = SamplingParams(n=args.n, max_tokens=max_tokens, temperature=temp)
237 | responses = llm.chat(messages=conversations, sampling_params=sampling_params, use_tqdm=True)
238 |
239 | completion_tokens = []
240 | prompt_tokens = []
241 | for idx, response in enumerate(responses):
242 | response_entries = []
243 | token_usages = []
244 | completion_token = 0
245 | for sample_idx in range(args.n):
246 | response_entry = {
247 | "content": response.choices[0].message.content.strip() if args.model.startswith("openai") else response.outputs[sample_idx].text.strip(),
248 | "correctness": None,
249 | "reason": None,
250 | }
251 | response_entries.append(response_entry)
252 | if not args.model.startswith("openai"):
253 | token_usages.append({
254 | "completion_tokens": len(response.outputs[sample_idx].token_ids),
255 | "prompt_tokens": len(response.prompt_token_ids)
256 | })
257 | completion_token += len(response.outputs[sample_idx].token_ids)
258 | completion_token /= args.n
259 | prompt_token = len(response.prompt_token_ids)
260 | prompt_tokens.append(prompt_token)
261 | completion_tokens.append(completion_token)
262 |
263 | problem_key = remaining_data[idx][handler.get_question_key()] # can you use this idx
264 | if problem_key not in results:
265 | results[problem_key] = remaining_data[idx]
266 | if isinstance(handler, NUMINATaskHandler):
267 | results[problem_key]["messages"] = ""
268 | results[problem_key]["responses"] = {}
269 | results[problem_key]["token_usages"] = {}
270 | prompt = conversations[idx][1]["content"]
271 | results[problem_key]["prompt"] = prompt
272 |
273 | results[problem_key]["responses"][str(temp)] = response_entries
274 |
275 | if args.model.startswith("openai"):
276 | results[problem_key]["token_usages"][str(temp)] = {
277 | "completion_tokens": response.usage.completion_tokens,
278 | "prompt_tokens": response.usage.prompt_tokens,
279 | }
280 | else:
281 | results[problem_key]["token_usages"][str(temp)] = token_usages
282 |
283 | # Token usage summary put into another subdirectory
284 | result_dir, result_name = os.path.split(result_file)
285 | token_usage_dir = os.path.join(result_dir, "token_usage")
286 | os.makedirs(token_usage_dir, exist_ok=True)
287 |
288 | # Construct the token usage result file path
289 | token_usage_result_file = os.path.join(token_usage_dir, result_name)
290 |
291 | # Prepare the token usage dictionary
292 | token_dict = {
293 | "completion_tokens": sum(completion_tokens),
294 | "prompt_tokens": sum(prompt_tokens),
295 | "avg_completion_tokens": round(sum(completion_tokens) / len(completion_tokens), 3) if completion_tokens else 0,
296 | "avg_prompt_tokens": round(sum(prompt_tokens) / len(prompt_tokens), 3) if prompt_tokens else 0,
297 | }
298 |
299 | # Save the token usage dictionary to the result file
300 | with open(token_usage_result_file, "w") as f:
301 | json.dump(token_dict, f, indent=4)
302 |
303 | print(f"Token usage saved to {token_usage_result_file}")
304 |
305 | with open(result_file, 'w', encoding='utf-8') as file:
306 | json.dump(results, file, ensure_ascii=False, indent=4, cls=NumpyEncoder)
307 |
308 | def main():
309 | parser = argparse.ArgumentParser(description="Unified inference and checking for different datasets/tasks.")
310 | parser.add_argument("--dataset", type=str, required=True, choices=list(TASK_HANDLERS.keys()), help="Dataset to process.")
311 | parser.add_argument("--model", type=str, required=True, default="Qwen/QwQ-32B-Preview", help="The model to run.")
312 | parser.add_argument("--tp", type=int, default=8, help="Tensor Parallelism Degree")
313 | parser.add_argument("--max_tokens", type=int, default=32768, help="Max tokens for the model.")
314 | parser.add_argument("--split", type=str, default="train", help="Split to use for apps (e.g., train, test).")
315 | parser.add_argument("--source", type=str, help="Source for the dataset.")
316 | parser.add_argument("--start", type=int, default=0, help="Start index.")
317 | parser.add_argument("--end", type=int, default=-1, help="End index.")
318 | parser.add_argument("--filter-difficulty", action="store_true", help="Filter difficulty.")
319 | parser.add_argument("--result-dir", type=str, default="./", help="Result dir to save files.")
320 | parser.add_argument("--check", action="store_true", help="Perform evaluation checks on generated samples.")
321 | parser.add_argument("--inference", action="store_true", help="Perform inference.")
322 | parser.add_argument("--temperatures", type=float, nargs="+", default=[0], help="Temperature for sampling.")
323 | parser.add_argument("--math-difficulty-lower-bound", type=int, default=None, help="Lowest difficulty level for math.")
324 | parser.add_argument("--math-difficulty-upper-bound", type=int, default=None, help="Highest difficulty level for math.")
325 | parser.add_argument("--n", type=int, default=1, help="Number of samples generated per problem.")
326 | args = parser.parse_args()
327 |
328 | handler: TaskHandler = TASK_HANDLERS[args.dataset]()
329 | temperatures = [1] if args.model.startswith("openai/o1") else args.temperatures
330 |
331 | print(f"Temperature: {temperatures}")
332 | max_tokens = args.max_tokens
333 | if temperatures == [0] and args.n > 1:
334 | args.n = 1
335 | print("Warning: Temperature 0 does not support multiple samples. Setting n=1.")
336 |
337 | # create result dir if not exists
338 | if args.result_dir and not os.path.exists(args.result_dir):
339 | os.makedirs(args.result_dir)
340 | if args.math_difficulty_lower_bound is not None or args.math_difficulty_upper_bound is not None:
341 | result_file = os.path.join(args.result_dir, f"{get_MODEL_TO_NAME(args.model)}_{args.dataset}_{args.split}_{args.source}_{args.start}_{args.end}_{args.math_difficulty_lower_bound}_{args.math_difficulty_upper_bound}.json")
342 | else:
343 | result_file = os.path.join(args.result_dir, f"{get_MODEL_TO_NAME(args.model)}_{args.dataset}_{args.split}_{args.source}_{args.start}_{args.end}.json")
344 |
345 | if args.check:
346 | # check if converted file exists
347 | if args.math_difficulty_lower_bound is not None or args.math_difficulty_upper_bound is not None:
348 | converted_file = f"{args.result_dir}/converted_{get_MODEL_TO_NAME(args.model)}_{args.dataset}_{args.split}_{args.source}_{args.start}_{args.end}_{args.math_difficulty_lower_bound}_{args.math_difficulty_upper_bound}.json"
349 | else:
350 | converted_file = f"{args.result_dir}/converted_{get_MODEL_TO_NAME(args.model)}_{args.dataset}_{args.split}_{args.source}_{args.start}_{args.end}.json"
351 | if os.path.exists(converted_file):
352 | result_file = converted_file
353 | perform_check(handler, temperatures, result_file, args)
354 | return
355 | elif args.inference:
356 | llm = OpenAI() if args.model.startswith("openai") else LLM(model=args.model, tensor_parallel_size=args.tp)
357 | system_prompt = SYSTEM_PROMPT.get(args.model, SYSTEM_PROMPT['NovaSky-AI/Sky-T1-32B-Preview'])
358 | perform_inference_and_save(handler, temperatures, max_tokens, result_file, llm, system_prompt, args)
359 | return
360 |
361 | llm = OpenAI() if args.model.startswith("openai") else LLM(model=args.model, tensor_parallel_size=args.tp)
362 | system_prompt = SYSTEM_PROMPT.get(args.model, SYSTEM_PROMPT['NovaSky-AI/Sky-T1-32B-Preview'])
363 | perform_inference_and_check(handler, temperatures, max_tokens, result_file, llm, system_prompt, args)
364 |
365 | if __name__ == "__main__":
366 | main()
367 |
--------------------------------------------------------------------------------
/tools/label_math_difficulty.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | from tqdm import tqdm
4 | import multiprocessing as mp
5 | import openai
6 | from itertools import cycle
7 | import time
8 | import os
9 | from datasets import load_dataset
10 | import re
11 | import ast
12 | from util.prompts import grading_prompt, aops_criteria
13 |
14 | # Function to set the OpenAI API key
15 | def set_openai_key(api_key):
16 | openai.api_key = api_key
17 |
18 | # From FastChat
19 | def find_difficulty(judgment):
20 |     one_score_pattern = re.compile(r"\[\[(\d+\.?\d*)\]\]")
21 |     one_score_pattern_backup = re.compile(r"\[(\d+\.?\d*)\]")
22 | match = re.search(one_score_pattern, judgment)
23 | if not match:
24 | match = re.search(one_score_pattern_backup, judgment)
25 |
26 | if match:
27 | rating = ast.literal_eval(match.groups()[0])
28 | else:
29 | rating = -1
30 |
31 | return rating
32 |
33 | # GPT API processing function with retry logic
34 | def process_content(problem, api_key):
35 | # Set the OpenAI key for this request
36 | set_openai_key(api_key)
37 |
38 | # GPT prompt
39 | prompt = grading_prompt.format(problem=problem, aops_criteria=aops_criteria)
40 | retries = 3
41 | while retries > 0:
42 | try:
43 | # OpenAI API call
44 | response = openai.chat.completions.create(
45 | model="gpt-4o-mini",
46 | messages=[
47 | {"role": "system", "content": "You are a math problem difficulty labeler."},
48 | {"role": "user", "content": prompt}
49 | ],
50 | max_tokens=2048,
51 | temperature=0.7
52 | )
53 | return response.choices[0].message.content
54 | except openai.RateLimitError:
55 | retries -= 1
56 | if retries == 0:
57 | return "Error: Rate limit reached and retries exhausted."
58 | print(f"Sleep for 5 seconds for API limit.")
59 | time.sleep(5)
60 | except Exception as e:
61 | return f"Error processing content: {e}"
62 |
63 | def process_entry(entry, api_key_cycle):
64 | # Get the next API key from the cycle
65 | api_key = next(api_key_cycle)
66 |
67 | # Pass only entry["problem"] to the process_content function
68 | processed = process_content(entry["problem"], api_key)
69 |
70 | # Store the processed content in the responses
71 | entry["messages"] = ""
72 | entry["gpt_difficulty"] = processed
73 | entry["gpt_difficulty_parsed"] = find_difficulty(processed)
74 | return entry
75 |
76 | # Wrapper function for multiprocessing
77 | def process_entry_wrapper(args):
78 | return process_entry(*args)
79 |
80 | if __name__ == "__main__":
81 | parser = argparse.ArgumentParser(description="Label difficulty")
82 |     parser.add_argument("--source", type=str, help="NuminaMath-CoT source split to label (e.g., amc_aime, math, olympiads).")
83 |     parser.add_argument("--start", type=int, default=0, help="Start index into the filtered dataset.")
84 |     parser.add_argument("--end", type=int, default=-1, help="End index into the filtered dataset.")
85 | parser.add_argument("--keys", type=str, help="File containing OpenAI API keys (one per line).")
86 | args = parser.parse_args()
87 |
88 | dataset = load_dataset("AI-MO/NuminaMath-CoT")
89 | data = (
90 | dataset["train"]
91 | .to_pandas()
92 | .query('source == @args.source')
93 | .iloc[args.start:args.end]
94 | )
95 |
96 | data = data.to_dict(orient="records")
97 |
98 | # Load API keys and prepare a round-robin cycle
99 | with open(args.keys, "r") as f:
100 | api_keys = [line.strip() for line in f if line.strip()]
101 | api_key_cycle = cycle(api_keys)
102 |
103 | # Prepare output file
104 | output_file = f"labeled_{args.source}_{args.start}_{args.end}.json"
105 |
106 | # Use multiprocessing to process the content
107 | results = []
108 | with mp.Pool(os.cpu_count()) as pool:
109 | tasks = [(entry, api_key_cycle) for entry in data]
110 | for result in tqdm(pool.imap(process_entry_wrapper, tasks), total=len(data)):
111 | results.append(result)
112 |
113 | # Aggregate and write results in the main process
114 | # aggregated_data = {key: values for key, values in results}
115 | # print(results)
116 | with open(output_file, "w") as f:
117 | json.dump(results, f, indent=4)
118 |
119 | print(f"Processed data saved to {output_file}")
120 |
--------------------------------------------------------------------------------
/tools/labeled_numina_difficulty/README.md:
--------------------------------------------------------------------------------
1 | # Labeled NUMINA Difficulty Data
2 |
3 | We also include difficulty-labeled data from NUMINA in the following files: `labeled_amc_aime_0_-1.json`, `labeled_math_0_-1.json`, `labeled_olympiads_0_-1.json`. These files can be found and downloaded from [HuggingFace](https://huggingface.co/datasets/NovaSky-AI/labeled_numina_difficulty).
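4 | 
5 | For convenience, here is a minimal sketch of downloading and inspecting one of these files; it assumes the files keep the structure produced by `label_math_difficulty.py` above (a list of records with a parsed `gpt_difficulty_parsed` field) and that `huggingface_hub` is installed:
6 | 
7 | ```python
8 | import json
9 | 
10 | from huggingface_hub import hf_hub_download
11 | 
12 | # Download one of the labeled difficulty files from the dataset repo.
13 | path = hf_hub_download(
14 |     repo_id="NovaSky-AI/labeled_numina_difficulty",
15 |     filename="labeled_math_0_-1.json",
16 |     repo_type="dataset",
17 | )
18 | 
19 | with open(path) as f:
20 |     labeled = json.load(f)
21 | 
22 | # Each record keeps the raw GPT judgment plus the parsed difficulty score.
23 | print(len(labeled), labeled[0]["gpt_difficulty_parsed"])
24 | ```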
--------------------------------------------------------------------------------
/tools/requirements.txt:
--------------------------------------------------------------------------------
1 | vllm==0.6.2
2 | pyext
3 | word2number
4 | scipy
5 | datasets
6 | latex2sympy2
--------------------------------------------------------------------------------
/tools/response_rewrite.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import random
5 | from tqdm import tqdm
6 | from util.math.testing_util import strip_answer_string
7 | from util.model_utils import *
8 | from vllm import LLM, SamplingParams
9 |
10 | def load_dataset(dataset_path : str):
11 | data = {}
12 | with open(dataset_path, 'r', encoding='utf-8') as file:
13 | data = json.load(file)
14 | return data
15 |
16 |
17 | def make_scoring_conversations(dataset, system_prompt):
18 | conversations = []
19 | for _, key in enumerate(dataset):
20 | problem = dataset[key]
21 | gt_answer = strip_answer_string(problem["answer"])
22 | for response_key in problem["responses"]:
23 | response = problem["responses"][response_key]["content"]
24 | prompt_text = response + "\n#####\nThe ground truth answer is " + gt_answer
25 | conversations.append([
26 | {"role": "system", "content": system_prompt},
27 | {"role": "user", "content": prompt_text}
28 | ])
29 |
30 | return conversations
31 |
32 |
33 | def score_solutions(dataset, responses, outfile):
34 | idx = 0
35 | for _, key in tqdm(enumerate(dataset), total=len(dataset), desc="Scoring original solutions"):
36 | problem = dataset[key]
37 | for response_key in problem["responses"]:
38 | score = responses[idx].outputs[0].text.strip()
39 | problem["responses"][response_key]["correctness"] = (score == "True")
40 | idx += 1
41 |
42 | with open(outfile, 'w', encoding='utf-8') as new_file:
43 | json.dump(dataset, new_file, ensure_ascii=False, indent=2)
44 | return dataset
45 |
46 |
47 | def filter_solutions(dataset):
48 | # First filter out incorrect responses.
49 | for key in dataset:
50 | problem = dataset[key]
51 | keys_to_filter = []
52 | for response_key in problem["responses"]:
53 | if not problem["responses"][response_key]["correctness"]:
54 | keys_to_filter.append(response_key)
55 | for k in keys_to_filter:
56 | del problem["responses"][k]
57 | del problem["token_usages"][k]
58 |
59 | # Next, filter out examples with <2 correct responses.
60 | keys_to_filter = []
61 | for key in dataset:
62 | problem = dataset[key]
63 | if len(problem["responses"]) < 2:
64 | keys_to_filter.append(key)
65 | for k in keys_to_filter:
66 | del dataset[k]
67 |
68 | # Finally, filter for the shortest and longest solutions for each sample.
69 | for key in dataset:
70 | problem = dataset[key]
71 | token_usages = problem["token_usages"]
72 | shortest_key, shortest_entry = min(token_usages.items(), key=lambda x: x[1]["completion_tokens"])
73 | longest_key, longest_entry = max(token_usages.items(), key=lambda x: x[1]["completion_tokens"])
74 | problem["token_usages"] = {
75 | "shortest": shortest_entry,
76 | "longest": longest_entry,
77 | }
78 | new_responses = {
79 | "shortest": problem["responses"][shortest_key],
80 | "longest":problem["responses"][longest_key],
81 | }
82 | problem["responses"] = new_responses
83 |
84 | return dataset
85 |
86 |
87 | def make_splitting_conversations(data, system_prompt):
88 | conversations = []
89 | for problem in data:
90 | response = data[problem]["responses"]["shortest"]
91 | prompt_text = response["content"]
92 | conversations.append([
93 | {"role": "system", "content": system_prompt},
94 | {"role": "user", "content": prompt_text}
95 | ])
96 | return conversations
97 |
98 |
99 | def split_solutions(dataset, responses, delimiter):
100 | outputs = []
101 | for _, response in tqdm(enumerate(responses), total=len(responses), desc="Splitting responses"):
102 | content = response.outputs[0].text.strip()
103 | # Split response by configured delimiter.
104 | split_content = content.split(delimiter)
105 | split_content = [x.strip() for x in split_content if x != ""]
106 | outputs.append(split_content)
107 | for idx, key in enumerate(dataset):
108 | solutions = outputs[idx]
109 | problem = dataset[key]
110 | problem["responses"]["shortest"]["subsolutions"] = solutions
111 | return dataset
112 |
113 |
114 | def make_subscoring_conversations(dataset, system_prompt):
115 | conversations = []
116 | for _, key in enumerate(dataset):
117 | problem = dataset[key]
118 | gt_answer = strip_answer_string(problem["answer"])
119 | subsolutions = problem["responses"]["shortest"]["subsolutions"]
120 | for sub in subsolutions:
121 | prompt_text = sub + "\n#####\nThe ground truth answer is " + gt_answer
122 | conversations.append([
123 | {"role": "system", "content": system_prompt},
124 | {"role": "user", "content": prompt_text}
125 | ])
126 | return conversations
127 |
128 |
129 | def score_subsolutions(dataset, responses):
130 | idx = 0
131 | for _, key in tqdm(enumerate(dataset), total=len(dataset), desc="Scoring sub-solutions"):
132 | problem = dataset[key]
133 | subsolutions = problem["responses"]["shortest"]["subsolutions"]
134 | scores = []
135 | for _, sub in enumerate(subsolutions):
136 | score = responses[idx].outputs[0].text.strip()
137 | scores.append(score == "True")
138 | idx += 1
139 | problem["responses"]["shortest"]["scores"] = scores
140 | return dataset
141 |
142 |
143 | def build_response_variants(dataset):
144 | def clean_response_string(response):
145 | if '<|end_of_thought|>' not in response:
146 | response += '<|end_of_thought|>'
147 | return response
148 |
149 | keys_to_remove = []
150 |
151 | for key, problem in dataset.items():
152 | scores = problem["responses"]["shortest"]["scores"]
153 | subsolutions = problem["responses"]["shortest"]["subsolutions"]
154 |
155 | # Check if there are valid scores
156 | if True not in scores:
157 | keys_to_remove.append(key)
158 | continue
159 |
160 | # Build FCS (First Correct Solution)
161 | fcs_idx = scores.index(True)
162 | fcs_response = "\n".join(subsolutions[:fcs_idx + 1]) if fcs_idx < len(scores) - 1 else "\n".join(subsolutions[:-1])
163 | fcs_response = clean_response_string(fcs_response) + "\n" + subsolutions[-1]
164 | problem["responses"]["fcs"] = fcs_response
165 |
166 | # Build FCS + 1
167 | fcs_plus1_idx = fcs_idx + 1 if fcs_idx + 1 < len(subsolutions) - 1 else fcs_idx
168 | fcs_plus1_response = "\n".join(subsolutions[:fcs_plus1_idx + 1])
169 | fcs_plus1_response = clean_response_string(fcs_plus1_response) + "\n" + subsolutions[-1]
170 | problem["responses"]["fcs_plus1"] = fcs_plus1_response
171 |
172 | # Check if there are valid scores
173 | if True not in scores[fcs_idx + 1:]:
174 | keys_to_remove.append(key)
175 | continue
176 |
177 | # Build FCS + Reflection
178 | fcs_reflection_idx = scores.index(True, fcs_idx + 1)
179 | fcs_reflection_response = "\n".join(subsolutions[:fcs_reflection_idx + 1]) if fcs_reflection_idx < len(scores) - 1 else "\n".join(subsolutions[:-1])
180 | fcs_reflection_response = clean_response_string(fcs_reflection_response) + "\n" + subsolutions[-1]
181 | problem["responses"]["fcs_reflection"] = fcs_reflection_response
182 |
183 | # Remove problems without valid sub-solutions
184 | for key in keys_to_remove:
185 | del dataset[key]
186 |
187 | return dataset
188 |
189 |
190 | def compute_token_usages(dataset, variants, llm):
191 | tokenizer = llm.get_tokenizer()
192 | for key in tqdm(dataset, desc="Computing token usages", total=len(dataset)):
193 | problem = dataset[key]
194 | prompt_tokens = problem["token_usages"]["shortest"]["prompt_tokens"]
195 | for variant in variants:
196 | problem["token_usages"][variant] = {
197 | "prompt_tokens": prompt_tokens,
198 | "completion_tokens": len(tokenizer(problem["responses"][variant]).input_ids)
199 | }
200 | return dataset
201 |
202 |
203 | def build_question_prompt(prompt):
204 |     return "Return your final response within \\boxed{}" + prompt  # plain string concatenation, so use single braces (the doubled "{{}}" was a str.format leftover)
205 |
206 | def make_preference_conversations(final_dataset, format, system_prompt):
207 | conversations = []
208 | for prompt in final_dataset:
209 | problem = final_dataset[prompt]
210 | convo = {}
211 | convo["conversations"] = [
212 | {
213 | "from": "system",
214 | "value": system_prompt,
215 | },
216 | {
217 | "from": "human",
218 | "value": build_question_prompt(prompt),
219 | }
220 | ]
221 | convo["chosen"] = {
222 | "from": "gpt",
223 | "value": problem["responses"][format],
224 | }
225 | convo["rejected"] = {
226 | "from": "gpt",
227 | "value": problem["responses"]["longest"]["content"]
228 | }
229 | conversations.append(convo)
230 |
231 | return conversations
232 |
233 |
234 | def make_SILC_conversations(dataset, system_prompt):
235 | keys_to_filter = []
236 | for prompt in dataset:
237 | problem = dataset[prompt]
238 |         condition = False
239 | for response_key in problem["responses"]:
240 | if not problem["responses"][response_key]['correctness']:
241 | wrong_length = problem["token_usages"][response_key]['completion_tokens']
242 | for k in problem["responses"]:
243 | if k != response_key and problem["token_usages"][k]['completion_tokens'] > wrong_length and problem["responses"][k]['correctness']:
244 |                         condition = True
245 | break
246 | break
247 |         if not condition:
248 | keys_to_filter.append(prompt)
249 |
250 | for key in keys_to_filter:
251 | del dataset[key]
252 |
253 | # Build contrastive pairs out of {short incorrect, long correct}
254 | conversations = []
255 | for prompt in dataset:
256 | problem = dataset[prompt]
257 |
258 | shortest_incorrect_key = None
259 | shortest_incorrect_length = float('inf')
260 |
261 | # Get shortest incorrect.
262 | for response_key in problem["responses"]:
263 | if not problem["responses"][response_key]['correctness']:
264 | length = problem["token_usages"][response_key]['completion_tokens']
265 | if length < shortest_incorrect_length:
266 | shortest_incorrect_length = length
267 | shortest_incorrect_key = response_key
268 |
269 | # Get next longest correct.
270 | shortest_correct_longer_key = None
271 | shortest_correct_longer_length = float('inf')
272 | for response_key in problem["responses"]:
273 | if problem["responses"][response_key]['correctness']:
274 | length = problem["token_usages"][response_key]['completion_tokens']
275 | if length > shortest_incorrect_length and length < shortest_correct_longer_length:
276 | shortest_correct_longer_length = length
277 | shortest_correct_longer_key = response_key
278 |
279 | convo = {}
280 | convo["conversations"] = [
281 | {
282 | "from": "system",
283 | "value": system_prompt,
284 | },
285 | {
286 | "from": "human",
287 | "value": build_question_prompt(prompt),
288 | }
289 | ]
290 | convo["chosen"] = {
291 | "from": "gpt",
292 | "value": problem["responses"][shortest_correct_longer_key]['content'],
293 | }
294 | convo["rejected"] = {
295 | "from": "gpt",
296 | "value": problem["responses"][shortest_incorrect_key]["content"]
297 | }
298 | conversations.append(convo)
299 |
300 | return conversations
301 |
302 |
303 | def main():
304 | parser = argparse.ArgumentParser(description="Filter, rewrite, and format generated responses for high-quality data curation.")
305 | parser.add_argument("--rewrite-model", type=str, required=True, default="meta-llama/Llama-3.3-70B-Instruct", help="The model used for response processing.")
306 | parser.add_argument("--target-model", type=str, required=True, default="NovaSky-AI/Sky-T1-32B-Preview", help="The target model the rewritten responses will be used to train.")
307 | parser.add_argument("--dataset", type=str, required=True, help="Path to the starting dataset of generated responses to filter from.")
308 | parser.add_argument("--result-dir", type=str, default="./", help="Result directory to save processed data.")
309 | parser.add_argument("--checkpoint", action="store_true", help="Whether to checkpoint the dataset at each step.")
310 | parser.add_argument("--SILC", action="store_true", help="Whether to include short-incorrect/long-correct (SILC) preference pairs.")
311 | parser.add_argument("--tp", type=int, default=8, help="Tensor Parallelism Degree")
312 | parser.add_argument("--max_tokens", type=int, default=32768, help="Max tokens for the model.")
313 | args = parser.parse_args()
314 |
315 | if args.result_dir and not os.path.exists(args.result_dir):
316 | os.makedirs(args.result_dir)
317 |
318 | # Initialize model for data processing.
319 | llm = LLM(model=args.rewrite_model, tensor_parallel_size=args.tp)
320 | sampling_params = SamplingParams(max_tokens=args.max_tokens)
321 |
322 | original_dataset = load_dataset(args.dataset)
323 |
324 | # Filter for the shortest and longest correct solutions.
325 | filtered_dataset = filter_solutions(original_dataset)
326 | if args.checkpoint:
327 | outfile = os.path.join(args.result_dir, f"filtered-responses.json")
328 | with open(outfile, 'w', encoding='utf-8') as new_file:
329 | json.dump(filtered_dataset, new_file, ensure_ascii=False, indent=2)
330 |
331 | # Split the shortest solution into subsolutions using the configured model.
332 | conversations = make_splitting_conversations(filtered_dataset, SUBPROBLEM_SPLIT_PROMPT)
333 | responses = llm.chat(messages=conversations, sampling_params=sampling_params, use_tqdm=True)
334 | split_dataset = split_solutions(filtered_dataset, responses, '#####')
335 | if args.checkpoint:
336 | outfile = os.path.join(args.result_dir, f"split-solutions.json")
337 | with open(outfile, 'w', encoding='utf-8') as new_file:
338 | json.dump(split_dataset, new_file, ensure_ascii=False, indent=2)
339 |
340 | # Score the subsolutions using the configured model.
341 | subscoring_conversations = make_subscoring_conversations(split_dataset, SUBSOLUTION_EXTRACTION_PROMPT)
342 | responses = llm.chat(messages=subscoring_conversations, sampling_params=sampling_params, use_tqdm=True)
343 | scored_dataset = score_subsolutions(split_dataset, responses)
344 | if args.checkpoint:
345 | outfile = os.path.join(args.result_dir, f"scored-subsolutions.json")
346 | with open(outfile, 'w', encoding='utf-8') as new_file:
347 | json.dump(scored_dataset, new_file, ensure_ascii=False, indent=2)
348 |
349 | # Rewrite response based on variants of combining sub-solutions. Here are examples for
350 | # FCS, FCS+1, and FCS+Reflection.
351 | variants_dataset = build_response_variants(scored_dataset)
352 | if args.checkpoint:
353 | outfile = os.path.join(args.result_dir, f"response-variants.json")
354 | with open(outfile, 'w', encoding='utf-8') as new_file:
355 | json.dump(variants_dataset, new_file, ensure_ascii=False, indent=2)
356 |
357 | # Add per-variant token counts to dataset for convenience.
358 | final_dataset = compute_token_usages(variants_dataset, ["fcs", "fcs_plus1", "fcs_reflection"], llm)
359 |
360 | system_prompt = SYSTEM_PROMPT[args.target_model]
361 |
362 | # Generate conversation format for each variant, which can be used in SimPO/DPO/etc.
363 | fcs_convo = make_preference_conversations(final_dataset, "fcs", system_prompt)
364 | fcs_plus1_convo = make_preference_conversations(final_dataset, "fcs_plus1", system_prompt)
365 | fcs_reflection_convo = make_preference_conversations(final_dataset, "fcs_reflection", system_prompt)
366 |
367 |     # Optionally add short incorrect, long correct (SILC) conversations
368 | if args.SILC:
369 | short_incorrect_long_correct_conversations = make_SILC_conversations(load_dataset(args.dataset), system_prompt)
370 | for convo in [fcs_convo, fcs_plus1_convo, fcs_reflection_convo]:
371 | convo += short_incorrect_long_correct_conversations
372 | random.shuffle(convo)
373 |
374 | # Save final conversation variants.
375 | fcs_outfile = os.path.join(args.result_dir, "fcs-conversations.json")
376 | with open(fcs_outfile, 'w', encoding='utf-8') as new_file:
377 | json.dump(fcs_convo, new_file, ensure_ascii=False, indent=2)
378 |
379 | fcs_plus1_outfile = os.path.join(args.result_dir, "fcs_plus1-conversations.json")
380 | with open(fcs_plus1_outfile, 'w', encoding='utf-8') as new_file:
381 | json.dump(fcs_plus1_convo, new_file, ensure_ascii=False, indent=2)
382 |
383 | fcs_reflection_outfile = os.path.join(args.result_dir, "fcs_reflection-conversations.json")
384 | with open(fcs_reflection_outfile, 'w', encoding='utf-8') as new_file:
385 | json.dump(fcs_reflection_convo, new_file, ensure_ascii=False, indent=2)
386 |
387 |
388 | if __name__ == "__main__":
389 | main()
390 |
--------------------------------------------------------------------------------
/tools/upload_hub.py:
--------------------------------------------------------------------------------
1 | """
2 | From https://github.com/lm-sys/FastChat/
3 | Upload weights to huggingface.
4 |
5 | Usage:
6 | python upload_hub.py --model-path ~/model_weights/Sky-T1 --hub-repo-id NovaSky-AI/Sky-T1 --private
7 | """
8 | import argparse
9 | import tempfile
10 |
11 | import torch
12 | from transformers import AutoTokenizer, AutoModelForCausalLM
13 |
14 | def upload_hub(model_path, hub_repo_id, component, private):
15 | if component == "all":
16 | components = ["model", "tokenizer"]
17 | else:
18 | components = [component]
19 |
20 |     kwargs = {"push_to_hub": True, "repo_id": hub_repo_id, "private": private}
21 |
22 | if "model" in components:
23 | model = AutoModelForCausalLM.from_pretrained(
24 | model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
25 | )
26 | with tempfile.TemporaryDirectory() as tmp_path:
27 | model.save_pretrained(tmp_path, **kwargs)
28 |
29 | if "tokenizer" in components:
30 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
31 | with tempfile.TemporaryDirectory() as tmp_path:
32 | tokenizer.save_pretrained(tmp_path, **kwargs)
33 |
34 |
35 | if __name__ == "__main__":
36 | parser = argparse.ArgumentParser()
37 | parser.add_argument("--model-path", type=str, required=True)
38 | parser.add_argument("--hub-repo-id", type=str, required=True)
39 | parser.add_argument(
40 | "--component", type=str, choices=["all", "model", "tokenizer"], default="all"
41 | )
42 | parser.add_argument("--private", action="store_true")
43 | args = parser.parse_args()
44 |
45 | upload_hub(args.model_path, args.hub_repo_id, args.component, args.private)
--------------------------------------------------------------------------------
/tools/util/common.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 |
3 | class TimeoutException(Exception):
4 | """Custom exception for function timeout."""
5 | pass
6 |
7 | def timeout(seconds):
8 | """Decorator to enforce a timeout on a function using multiprocessing."""
9 | def decorator(func):
10 | def wrapper(*args, **kwargs):
11 | # A queue to store the result or exception
12 | queue = multiprocessing.Queue()
13 |
14 | def target(queue, *args, **kwargs):
15 | try:
16 | result = func(*args, **kwargs)
17 | queue.put((True, result))
18 | except Exception as e:
19 | queue.put((False, e))
20 |
21 | process = multiprocessing.Process(target=target, args=(queue, *args), kwargs=kwargs)
22 | process.start()
23 | process.join(seconds)
24 |
25 | if process.is_alive():
26 | process.terminate()
27 | process.join()
28 | raise TimeoutException(f"Function '{func.__name__}' timed out after {seconds} seconds!")
29 |
30 | success, value = queue.get()
31 | if success:
32 | return value
33 | else:
34 | raise value
35 | return wrapper
36 | return decorator
--------------------------------------------------------------------------------
/tools/util/math/testing_util.py:
--------------------------------------------------------------------------------
1 | """
2 | The logic in this file largely borrows from Qwen2.5-Math codebase at https://github.com/QwenLM/Qwen2.5-Math:
3 | """
4 |
5 | import re
6 | import regex
7 | from word2number import w2n
8 | from math import isclose
9 | from collections import defaultdict
10 |
11 | from sympy import simplify, N
12 | from sympy.parsing.sympy_parser import parse_expr
13 | from sympy.parsing.latex import parse_latex
14 | from latex2sympy2 import latex2sympy
15 |
16 | def convert_word_number(text: str) -> str:
17 | try:
18 | text = str(w2n.word_to_num(text))
19 | except:
20 | pass
21 | return text
22 |
23 | def _fix_fracs(string):
24 | substrs = string.split("\\frac")
25 | new_str = substrs[0]
26 | if len(substrs) > 1:
27 | substrs = substrs[1:]
28 | for substr in substrs:
29 | new_str += "\\frac"
30 | if len(substr) > 0 and substr[0] == "{":
31 | new_str += substr
32 | else:
33 | try:
34 | assert len(substr) >= 2
35 | except:
36 | return string
37 | a = substr[0]
38 | b = substr[1]
39 | if b != "{":
40 | if len(substr) > 2:
41 | post_substr = substr[2:]
42 | new_str += "{" + a + "}{" + b + "}" + post_substr
43 | else:
44 | new_str += "{" + a + "}{" + b + "}"
45 | else:
46 | if len(substr) > 2:
47 | post_substr = substr[2:]
48 | new_str += "{" + a + "}" + b + post_substr
49 | else:
50 | new_str += "{" + a + "}" + b
51 | string = new_str
52 | return string
53 |
54 |
55 | def _fix_a_slash_b(string):
56 | if len(string.split("/")) != 2:
57 | return string
58 | a = string.split("/")[0]
59 | b = string.split("/")[1]
60 | try:
61 | if "sqrt" not in a:
62 | a = int(a)
63 | if "sqrt" not in b:
64 | b = int(b)
65 | assert string == "{}/{}".format(a, b)
66 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
67 | return new_string
68 | except:
69 | return string
70 |
71 |
72 | def _fix_sqrt(string):
73 | _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
74 | return _string
75 |
76 | def strip_answer_string(string):
77 | string = str(string).strip()
78 | # linebreaks
79 | string = string.replace("\n", "")
80 |
81 | # right "."
82 | string = string.rstrip(".")
83 |
84 | # remove inverse spaces
85 | # replace \\ with \
86 | string = string.replace("\\!", "")
87 | # string = string.replace("\\ ", "")
88 | # string = string.replace("\\\\", "\\")
89 |
90 | # matrix
91 | string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string)
92 | string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string)
93 | string = string.replace("bmatrix", "pmatrix")
94 |
95 | # replace tfrac and dfrac with frac
96 | string = string.replace("tfrac", "frac")
97 | string = string.replace("dfrac", "frac")
98 | string = (
99 | string.replace("\\neq", "\\ne")
100 | .replace("\\leq", "\\le")
101 | .replace("\\geq", "\\ge")
102 | )
103 |
104 | # remove \left and \right
105 | string = string.replace("\\left", "")
106 | string = string.replace("\\right", "")
107 | string = string.replace("\\{", "{")
108 | string = string.replace("\\}", "}")
109 |
110 | # Function to replace number words with corresponding digits
111 | def replace_match(match):
112 | word = match.group(1).lower()
113 | if convert_word_number(word) == word:
114 | return match.group(0)
115 | else:
116 | return convert_word_number(word)
117 | string = re.sub(r"\\text\{([a-zA-Z]+)\}", replace_match, string)
118 |
119 | # Before removing unit, check if the unit is squared (for surface area)
120 | string = re.sub(r"(cm|inches)\}\^2", r"\1}", string)
121 |
122 | # Remove unit: miles, dollars if after is not none
123 | _string = re.sub(r"\\text{.*?}$", "", string).strip()
124 | if _string != "" and _string != string:
125 | # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
126 | string = _string
127 |
128 | # Remove circ (degrees)
129 | string = string.replace("^{\\circ}", "")
130 | string = string.replace("^\\circ", "")
131 |
132 | # remove dollar signs
133 | string = string.replace("\\$", "")
134 | string = string.replace("$", "")
135 | string = string.replace("\\(", "").replace("\\)", "")
136 |
137 | # convert word number to digit
138 | string = convert_word_number(string)
139 |
140 | # replace "\\text{...}" to "..."
141 | string = re.sub(r"\\text\{(.*?)\}", r"\1", string)
142 | for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]:
143 | string = string.replace(key, "")
144 | string = string.replace("\\emptyset", r"{}")
145 | string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}")
146 |
147 | # remove percentage
148 | string = string.replace("\\%", "")
149 | string = string.replace(r"\%", "")
150 | string = string.replace("%", "")
151 |
152 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
153 | string = string.replace(" .", " 0.")
154 | string = string.replace("{.", "{0.")
155 |
156 | # cdot
157 | # string = string.replace("\\cdot", "")
158 | if (
159 | string.startswith("{")
160 | and string.endswith("}")
161 | and string.isalnum()
162 | or string.startswith("(")
163 | and string.endswith(")")
164 | and string.isalnum()
165 | or string.startswith("[")
166 | and string.endswith("]")
167 | and string.isalnum()
168 | ):
169 | string = string[1:-1]
170 |
171 | # inf
172 | string = string.replace("infinity", "\\infty")
173 | if "\\infty" not in string:
174 | string = string.replace("inf", "\\infty")
175 | string = string.replace("+\\inity", "\\infty")
176 |
177 | # and
178 | string = string.replace("and", "")
179 | string = string.replace("\\mathbf", "")
180 |
181 | # use regex to remove \mbox{...}
182 | string = re.sub(r"\\mbox{.*?}", "", string)
183 |
184 | # quote
185 | string = string.replace("'", "")
186 | string = string.replace('"', "")
187 |
188 | # i, j
189 | if "j" in string and "i" not in string:
190 | string = string.replace("j", "i")
191 |
192 | # replace a.000b where b is not number or b is end, with ab, use regex
193 | string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string)
194 | string = re.sub(r"(\d+)\.0*$", r"\1", string)
195 |
196 | # if empty, return empty string
197 | if len(string) == 0:
198 | return string
199 | if string[0] == ".":
200 | string = "0" + string
201 |
202 | # to consider: get rid of e.g. "k = " or "q = " at beginning
203 | if len(string.split("=")) == 2:
204 | if len(string.split("=")[0]) <= 2:
205 | string = string.split("=")[1]
206 |
207 | string = _fix_sqrt(string)
208 | string = string.replace(" ", "")
209 |
210 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
211 | string = _fix_fracs(string)
212 |
213 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
214 | string = _fix_a_slash_b(string)
215 |
216 | # Remove unnecessary '\' before integers
217 | string = re.sub(r"\\(?=\-?\d+(\\|\)|,|\]|$))", "", string)
218 |
219 | # Remove grade level (e.g., 12th grade) and just maintain the integer
220 | string = re.sub(r"thgrade$", "", string)
221 |
222 | # If the answer is a list of integers (without parenthesis), sort them
223 | if re.fullmatch(r"(\s*-?\d+\s*,)*\s*-?\d+\s*", string):
224 | # Split the string into a list of integers
225 | try:
226 | integer_list = list(map(int, string.split(',')))
227 | except:
228 | integer_list = list(map(int, "-1,-1".split(',')))
229 |
230 | # Sort the list in ascending order
231 | sorted_list = sorted(integer_list)
232 |
233 | # Join the sorted list back into a comma-separated string
234 | string = ','.join(map(str, sorted_list))
235 |
236 | return string
237 |
238 | def extract_answer(pred_str, use_last_number=True):
239 | pred_str = pred_str.replace("\u043a\u0438", "")
240 | if "final answer is $" in pred_str and "$. I hope" in pred_str:
241 | # minerva_math
242 | tmp = pred_str.split("final answer is $", 1)[1]
243 | pred = tmp.split("$. I hope", 1)[0].strip()
244 | elif "boxed" in pred_str:
245 | ans = pred_str.split("boxed")[-1]
246 | if len(ans) == 0:
247 | return ""
248 | elif ans[0] == "{":
249 | stack = 1
250 | a = ""
251 | for c in ans[1:]:
252 | if c == "{":
253 | stack += 1
254 | a += c
255 | elif c == "}":
256 | stack -= 1
257 | if stack == 0:
258 | break
259 | a += c
260 | else:
261 | a += c
262 | else:
263 | a = ans.split("$")[0].strip()
264 | pred = a
265 | elif "he answer is" in pred_str:
266 | pred = pred_str.split("he answer is")[-1].strip()
267 | elif "final answer is" in pred_str:
268 | pred = pred_str.split("final answer is")[-1].strip()
269 | elif "答案是" in pred_str:
270 | # Handle Chinese few-shot multiple choice problem answer extraction
271 | pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
272 | else: # use the last number
273 | if use_last_number:
274 | pattern = "-?\d*\.?\d+"
275 | pred = re.findall(pattern, pred_str.replace(",", ""))
276 | if len(pred) >= 1:
277 | pred = pred[-1]
278 | else:
279 | pred = ""
280 | else:
281 | pred = ""
282 |
283 | # multiple line
284 | # pred = pred.split("\n")[0]
285 | pred = re.sub(r"\n\s*", "", pred)
286 | if pred != "" and pred[0] == ":":
287 | pred = pred[1:]
288 | if pred != "" and pred[-1] == ".":
289 | pred = pred[:-1]
290 | if pred != "" and pred[-1] == "/":
291 | pred = pred[:-1]
292 | pred = strip_answer_string(pred)
293 | return pred
294 |
295 | def get_multiple_choice_answer(pred: str):
296 | tmp = re.findall(r"\b(A|B|C|D)\b", pred.upper())
297 | if tmp:
298 | pred = tmp
299 | else:
300 | pred = [pred.strip().strip(".")]
301 |
302 | if len(pred) == 0:
303 | pred = ""
304 | else:
305 | pred = pred[-1]
306 |
307 | # Remove the period at the end, again!
308 | pred = pred.rstrip(".").rstrip("/")
309 |
310 | return pred
311 |
312 | def mmlu_pro_extract_answer(text):
313 | pattern = r"answer is \(?([A-J])\)?"
314 | match = re.search(pattern, text)
315 | if match:
316 | return match.group(1)
317 | else:
318 | # print("1st answer extract failed\n" + text)
319 | match = re.search(r'.*[aA]nswer:\s*([A-J])', text)
320 | if match:
321 | return match.group(1)
322 | else:
323 | # print("2nd answer extract failed\n" + text)
324 | pattern = r"\b[A-J]\b(?!.*\b[A-J]\b)"
325 | match = re.search(pattern, text, re.DOTALL)
326 | if match:
327 | return match.group(0)
328 |
329 | def choice_answer_clean(pred: str):
330 | pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")
331 | # Clean the answer based on the dataset
332 | tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
333 | if tmp:
334 | pred = tmp
335 | else:
336 | pred = [pred.strip().strip(".")]
337 | pred = pred[-1]
338 | # Remove the period at the end, again!
339 | pred = pred.rstrip(".").rstrip("/")
340 | return pred
341 |
342 |
343 | def parse_digits(num):
344 | num = regex.sub(",", "", str(num))
345 | try:
346 | return float(num)
347 | except:
348 | if num.endswith("%"):
349 | num = num[:-1]
350 | if num.endswith("\\"):
351 | num = num[:-1]
352 | try:
353 | return float(num) / 100
354 | except:
355 | pass
356 | return None
357 |
358 |
359 | def is_digit(num):
360 | # paired with parse_digits
361 | return parse_digits(num) is not None
362 |
363 |
364 | def str_to_pmatrix(input_str):
365 | input_str = input_str.strip()
366 | matrix_str = re.findall(r"\{.*,.*\}", input_str)
367 | pmatrix_list = []
368 |
369 | for m in matrix_str:
370 | m = m.strip("{}")
371 | pmatrix = r"\begin{pmatrix}" + m.replace(",", "\\") + r"\end{pmatrix}"
372 | pmatrix_list.append(pmatrix)
373 |
374 | return ", ".join(pmatrix_list)
375 |
376 |
377 | def math_equal(
378 | prediction,
379 | reference,
380 | include_percentage: bool = True,
381 | is_close: bool = True,
382 | timeout: bool = False,
383 | ) -> bool:
384 | """
385 | Exact match of math if and only if:
386 | 1. numerical equal: both can convert to float and are equal
387 | 2. symbolic equal: both can convert to sympy expression and are equal
388 | """
389 | if prediction is None or reference is None:
390 | return False
391 | if str(prediction.strip().lower()) == str(reference.strip().lower()):
392 | return True
393 | if (
394 | reference in ["A", "B", "C", "D", "E"]
395 | and choice_answer_clean(prediction) == reference
396 | ):
397 | return True
398 |
399 | try: # 1. numerical equal
400 | if is_digit(prediction) and is_digit(reference):
401 | prediction = parse_digits(prediction)
402 | reference = parse_digits(reference)
403 | # number questions
404 | if include_percentage:
405 | gt_result = [reference / 100, reference, reference * 100]
406 | else:
407 | gt_result = [reference]
408 | for item in gt_result:
409 | try:
410 | if is_close:
411 | if numeric_equal(prediction, item):
412 | return True
413 | else:
414 | if item == prediction:
415 | return True
416 | except Exception:
417 | continue
418 | return False
419 | except:
420 | pass
421 |
422 | if not prediction and prediction not in [0, False]:
423 | return False
424 |
425 | # 2. symbolic equal
426 | reference = str(reference).strip()
427 | prediction = str(prediction).strip()
428 |
429 | ## pmatrix (amps)
430 | if "pmatrix" in prediction and not "pmatrix" in reference:
431 | reference = str_to_pmatrix(reference)
432 |
433 | ## deal with [], (), {}
434 | pred_str, ref_str = prediction, reference
435 | if (
436 | prediction.startswith("[")
437 | and prediction.endswith("]")
438 | and not reference.startswith("(")
439 | ) or (
440 | prediction.startswith("(")
441 | and prediction.endswith(")")
442 | and not reference.startswith("[")
443 | ):
444 | pred_str = pred_str.strip("[]()")
445 | ref_str = ref_str.strip("[]()")
446 | for s in ["{", "}", "(", ")"]:
447 | ref_str = ref_str.replace(s, "")
448 | pred_str = pred_str.replace(s, "")
449 | if pred_str.lower() == ref_str.lower():
450 | return True
451 |
452 | ## [a, b] vs. [c, d], return a==c and b==d
453 | if (
454 | regex.match(r"(\(|\[).+(\)|\])", prediction) is not None
455 | and regex.match(r"(\(|\[).+(\)|\])", reference) is not None
456 | ):
457 | pred_parts = prediction[1:-1].split(",")
458 | ref_parts = reference[1:-1].split(",")
459 | if len(pred_parts) == len(ref_parts):
460 | if all(
461 | [
462 | math_equal(
463 | pred_parts[i], ref_parts[i], include_percentage, is_close
464 | )
465 | for i in range(len(pred_parts))
466 | ]
467 | ):
468 | return True
469 | if (
470 | (
471 | prediction.startswith("\\begin{pmatrix}")
472 | or prediction.startswith("\\begin{bmatrix}")
473 | )
474 | and (
475 | prediction.endswith("\\end{pmatrix}")
476 | or prediction.endswith("\\end{bmatrix}")
477 | )
478 | and (
479 | reference.startswith("\\begin{pmatrix}")
480 | or reference.startswith("\\begin{bmatrix}")
481 | )
482 | and (
483 | reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")
484 | )
485 | ):
486 | pred_lines = [
487 | line.strip()
488 | for line in prediction[
489 | len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
490 | ].split("\\\\")
491 | if line.strip()
492 | ]
493 | ref_lines = [
494 | line.strip()
495 | for line in reference[
496 | len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
497 | ].split("\\\\")
498 | if line.strip()
499 | ]
500 | matched = True
501 | if len(pred_lines) == len(ref_lines):
502 | for pred_line, ref_line in zip(pred_lines, ref_lines):
503 | pred_parts = pred_line.split("&")
504 | ref_parts = ref_line.split("&")
505 | if len(pred_parts) == len(ref_parts):
506 | if not all(
507 | [
508 | math_equal(
509 | pred_parts[i],
510 | ref_parts[i],
511 | include_percentage,
512 | is_close,
513 | )
514 | for i in range(len(pred_parts))
515 | ]
516 | ):
517 | matched = False
518 | break
519 | else:
520 | matched = False
521 | if not matched:
522 | break
523 | else:
524 | matched = False
525 | if matched:
526 | return True
527 |
528 | if prediction.count("=") == 1 and reference.count("=") == 1:
529 | pred = prediction.split("=")
530 | pred = f"{pred[0].strip()} - ({pred[1].strip()})"
531 | ref = reference.split("=")
532 | ref = f"{ref[0].strip()} - ({ref[1].strip()})"
533 | if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
534 | return True
535 | elif (
536 | prediction.count("=") == 1
537 | and len(prediction.split("=")[0].strip()) <= 2
538 | and "=" not in reference
539 | ):
540 | if math_equal(
541 | prediction.split("=")[1], reference, include_percentage, is_close
542 | ):
543 | return True
544 | elif (
545 | reference.count("=") == 1
546 | and len(reference.split("=")[0].strip()) <= 2
547 | and "=" not in prediction
548 | ):
549 | if math_equal(
550 | prediction, reference.split("=")[1], include_percentage, is_close
551 | ):
552 | return True
553 |
554 | if symbolic_equal(prediction, reference):
555 | return True
556 |
557 | return False
558 |
559 |
560 | def numeric_equal(prediction: float, reference: float):
561 | return isclose(reference, prediction, rel_tol=1e-4)
562 |
563 |
564 | def symbolic_equal(a, b):
565 | def _parse(s):
566 | for f in [parse_latex, parse_expr, latex2sympy]:
567 | try:
568 | return f(s.replace("\\\\", "\\"))
569 | except:
570 | try:
571 | return f(s)
572 | except:
573 | pass
574 | return s
575 |
576 | a = _parse(a)
577 | b = _parse(b)
578 |
579 | # direct equal
580 | try:
581 | if str(a) == str(b) or a == b:
582 | return True
583 | except:
584 | pass
585 |
586 | # simplify equal
587 | try:
588 | if a.equals(b) or simplify(a - b) == 0:
589 | return True
590 | except:
591 | pass
592 |
593 | # equation equal
594 | try:
595 | if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)):
596 | return True
597 | except:
598 | pass
599 |
600 | try:
601 | if numeric_equal(float(N(a)), float(N(b))):
602 | return True
603 | except:
604 | pass
605 |
606 | # matrix
607 | try:
608 | # if a and b are matrix
609 | if a.shape == b.shape:
610 | _a = a.applyfunc(lambda x: round(x, 3))
611 | _b = b.applyfunc(lambda x: round(x, 3))
612 | if _a.equals(_b):
613 | return True
614 | except:
615 | pass
616 |
617 | return False
--------------------------------------------------------------------------------
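A short sketch of how the helpers above are typically chained when grading a model response (the strings are illustrative; `tools.util.math.testing_util` is assumed importable from the repo root, and the symbolic comparison relies on SymPy/latex2sympy2 with the ANTLR Python runtime installed):

```python
from tools.util.math.testing_util import extract_answer, strip_answer_string, math_equal

response = "Adding the halves gives \\boxed{\\frac{1}{2}} as the probability."
gold = "0.5"

pred = extract_answer(response)   # -> "\\frac{1}{2}"
gold = strip_answer_string(gold)  # normalize the reference the same way
print(math_equal(pred, gold))     # True: 1/2 and 0.5 are treated as equivalent
```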
/tools/util/model_utils.py:
--------------------------------------------------------------------------------
1 | SYSTEM_PROMPT = {
2 | "Qwen/Qwen2-7B-Instruct": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.",
3 | "Qwen/QwQ-32B-Preview": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.",
4 | "Qwen/Qwen2.5-72B-Instruct": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.",
5 | "Qwen/Qwen2.5-32B-Instruct": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.",
6 | "Qwen/Qwen2.5-7B-Instruct": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.",
7 | "Qwen/Qwen2.5-1.5B-Instruct": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.",
8 | "Qwen/Qwen2.5-Math-7B-Instruct": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.",
9 | "Qwen/Qwen2.5-Math-72B-Instruct": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.",
10 | "PRIME-RL/Eurus-2-7B-PRIME": """When tackling complex reasoning tasks, you have access to the following actions. Use them as needed to progress through your thought process. After each action, determine and state the next most appropriate action to take.
11 |
12 | Actions:
13 |
14 | {actions}
15 |
16 | Your action should contain multiple steps, and each step starts with #. After each action (except OUTPUT), state which action you will take next with ''Next action: [Your action]'' and finish this turn. Continue this process until you reach a satisfactory conclusion or solution to the problem at hand, at which point you should use the [OUTPUT] action. The thought process is completely invisible to user, so [OUTPUT] should be a complete response. You should strictly follow the format below:
17 |
18 | [ACTION NAME]
19 |
20 | # Your action step 1
21 |
22 | # Your action step 2
23 |
24 | # Your action step 3
25 |
26 | ...
27 |
28 | Next action: [NEXT ACTION NAME]
29 |
30 |
31 | Now, begin with the [ASSESS] action for the following task:
32 | """,
33 | "NovaSky-AI/Sky-T1-32B-Preview": "Your role as an assistant involves thoroughly exploring questions through a systematic long \
34 | thinking process before providing the final precise and accurate solutions. This requires \
35 | engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, \
36 | backtracing, and iteration to develop well-considered thinking process. \
37 | Please structure your response into two main sections: Thought and Solution. \
38 | In the Thought section, detail your reasoning process using the specified format: \
39 | <|begin_of_thought|> {thought with steps separated with '\n\n'} \
40 | <|end_of_thought|> \
41 | Each step should include detailed considerations such as analyzing questions, summarizing \
42 | relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining \
43 | any errors, and revisiting previous steps. \
44 | In the Solution section, based on various attempts, explorations, and reflections from the Thought \
45 | section, systematically present the final solution that you deem correct. The solution should \
46 | remain a logical, accurate, concise expression style and detail necessary step needed to reach the \
47 | conclusion, formatted as follows: \
48 | <|begin_of_solution|> \
49 | {final formatted, precise, and clear solution} \
50 | <|end_of_solution|> \
51 | Now, try to solve the following question through the above guidelines:",
52 | "openai/o1-mini": "Question: {input}\nAnswer: ",
53 | "openai/o1-preview": "Question: {input}\nAnswer: ",
54 | "openai/gpt-4o-mini": "User: {input}\nPlease reason step by step, and put your final answer within \\boxed{{}}.\n\nAssistant:",
55 | }
56 |
57 | def get_MODEL_TO_NAME(key):
58 | if key in MODEL_TO_NAME:
59 | return MODEL_TO_NAME[key]
60 | else:
61 | keys = key.split("/")
62 | if 'checkpoint' in keys[-1]: return keys[-2]
63 | return keys[-1]
64 |
65 | MODEL_TO_NAME = {
66 | "Qwen/Qwen2-7B-Instruct": "Qwen2-7B-Instruct",
67 | "Qwen/QwQ-32B-Preview": "QwQ-32B-Preview",
68 | "Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct",
69 | "Qwen/Qwen2.5-32B-Instruct": "Qwen2.5-32B-Instruct",
70 | "Qwen/Qwen2.5-7B-Instruct": "Qwen2.5-7B-Instruct",
71 | "Qwen/Qwen2.5-1.5B-Instruct": "Qwen2.5-1.5B-Instruct",
72 | "Qwen/Qwen2.5-Math-7B-Instruct": "Qwen2.5-Math-7B-Instruct",
73 | "Qwen/Qwen2.5-Math-72B-Instruct": "Qwen2.5-Math-72B-Instruct",
74 | "PRIME-RL/Eurus-2-7B-PRIME": "Eurus-2-7B-PRIME",
75 | "NovaSky-AI/Sky-T1-32B-Preview": "Sky-T1-32B-Preview",
76 | "openai/o1-mini": "o1-mini",
77 | "openai/o1-preview": "o1-preview",
78 | "openai/gpt-4o-mini": "gpt-4o-mini",
79 | }
80 |
81 | SUBPROBLEM_SPLIT_PROMPT = """
82 | You are given a reasoning sequence that attempts to solve a math problem.
83 | This sequence contains multiple proposed solutions, then provides the final solution.
84 | Each proposed solution within the sequence follows a different line of thought, usually to double check the answer.
85 | Your objective is to identify these separate lines of thought and add the separator string '#####' between the separate lines of thought.
86 | This is important: Your response should be the original unchanged reasoning sequence, except for '#####' injected into the sequence between distinct lines of thought.
87 | Do NOT summarize portions of the reasoning sequence with '...'.
88 |
89 | Please keep the sequence that starts with '<|begin_of_solution|>' and ends with '<|end_of_solution|>' as
90 | one single sequence with no '#####' inside of the sequence. Add the separator '#####' immediately before '<|begin_of_solution|>'.
91 |
92 | Importantly, only use '#####' if a line of thought presents an answer.
93 | If the line of thought does not include an answer, it cannot be considered a separate line of thought, and should not be separated.
94 |
95 | For example, if the input is:
96 | <|begin_of_thought|>The answer to 2+3 is 5. But wait, let me double check this.
97 | If I have two apples and I am given three more apples, I now have 5 apples, so 5 seems like the right answer.
98 | Alternatively, 2+3 is the same as 3+2, which is also 5.<|end_of_thought|>
99 | <|begin_of_solution|>The answer is 5<|end_of_solution|>.
100 |
101 | Your output should be:
102 | <|begin_of_thought|>The answer to 2+3 is 5.
103 | #####
104 | But wait, let me double check this.
105 | If I have two apples and I am given three more apples, I now have 5 apples, so 5 seems like the right answer.
106 | #####
107 | Alternatively, 2+3 is the same as 3+2, which is also 5.<|end_of_thought|>
108 | #####
109 | <|begin_of_solution|>The answer is 5<|end_of_solution|>.
110 | """
111 |
112 | SUBSOLUTION_EXTRACTION_PROMPT = """
113 | You are given the text of an attempt to solve a math problem. The text contains a final proposed answer to the math problem.
114 |
115 | The text also contains a string '#####' and after this string the ground truth answer is presented.
116 |
117 | Your objective is to determine whether the final proposed answer is equivalent to the ground truth answer.
118 | The proposed answer and ground truth answer may be in slightly different formats. For example, the proposed answer may be '1/2' but the ground truth is '0.5'.
119 | Equivalent answers in different formats should be treated as equivalent.
120 | If the text contains multiple proposed answers, use the final proposed answer.
121 |
122 | You should return only "True" if the proposed answer is equivalent to the ground truth answer and "False" if there is no proposed answer or if the proposed answer is not equivalent to the ground truth.
123 | Do NOT respond with anything at all except "True" or "False".
124 |
125 | For example, if you are given:
126 | I believe 2+3 equals 5.
127 | #####
128 | The ground truth answer is five.
129 |
130 | Your response should be:
131 | True
132 |
133 | Another example, if you are given:
134 | I believe 2+2 equals 4. But wait, it is actually 5.
135 | #####
136 | The ground truth answer is five.
137 |
138 | Your response should be:
139 | True
140 | """
--------------------------------------------------------------------------------
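A small sketch of how the lookup helpers above might be used (the checkpoint path is hypothetical):

```python
from tools.util.model_utils import SYSTEM_PROMPT, get_MODEL_TO_NAME

system_prompt = SYSTEM_PROMPT["Qwen/Qwen2.5-7B-Instruct"]
print(get_MODEL_TO_NAME("Qwen/Qwen2.5-7B-Instruct"))          # Qwen2.5-7B-Instruct
# Unknown ids fall back to the last path component; checkpoint dirs use the parent:
print(get_MODEL_TO_NAME("my-org/my-sft-run/checkpoint-500"))  # my-sft-run
```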
/tools/util/prompts.py:
--------------------------------------------------------------------------------
1 | grading_prompt = " \
2 | You will be given a math problem. Your job is to grade the difficulty level from 1-10 according to the AoPS standard. \
3 | Here is the standard: \
4 | {aops_criteria} \
5 | Problem to be labeled: {problem}. Please put your estimation of the difficulty inside [[level]]. \
6 | Important: You should place the difficulty from 1-10 into the [[]], not the solution of the problem."
7 |
8 | aops_criteria = " \
9 | All levels are estimated and refer to averages. The following is a rough standard based on the USA tier system AMC 8 - AMC 10 - AMC 12 - AIME - USAMO/USAJMO - IMO, \
10 | representing Middle School - Junior High - High School - Challenging High School - Olympiad levels. Other contests can be interpolated against this. \
11 | Notes: \
12 | Multiple choice tests like AMC are rated as though they are free-response. Test-takers can use the answer choices as hints, and so correctly answer more AMC questions than Mathcounts or AIME problems of similar difficulty. \
13 | Some Olympiads are taken in 2 sessions, with 2 similarly difficult sets of questions, numbered as one set. For these the first half of the test (questions 1-3) is similar difficulty to the second half (questions 4-6). \
14 | Scale \
15 | 1: Problems strictly for beginner, on the easiest elementary school or middle school levels (MOEMS, MATHCOUNTS Chapter, AMC 8 1-20, AMC 10 1-10, AMC 12 1-5, and others that involve standard techniques introduced up to the middle school level), most traditional middle/high school word problems. \
16 | 2: For motivated beginners, harder questions from the previous categories (AMC 8 21-25, harder MATHCOUNTS States questions, AMC 10 11-20, AMC 12 5-15, AIME 1-3), traditional middle/high school word problems with extremely complex problem solving. \
17 | 3: Advanced Beginner problems that require more creative thinking (harder MATHCOUNTS National questions, AMC 10 21-25, AMC 12 15-20, AIME 4-6). \
18 | 4: Intermediate-level problems (AMC 12 21-25, AIME 7-9). \
19 | 5: More difficult AIME problems (10-12), simple proof-based Olympiad-style problems (early JBMO questions, easiest USAJMO 1/4). \
20 | 6: High-leveled AIME-styled questions (13-15). Introductory-leveled Olympiad-level questions (harder USAJMO 1/4 and easier USAJMO 2/5, easier USAMO and IMO 1/4). \
21 | 7: Tougher Olympiad-level questions, may require more technical knowledge (harder USAJMO 2/5 and most USAJMO 3/6, extremely hard USAMO and IMO 1/4, easy-medium USAMO and IMO 2/5). \
22 | 8: High-level Olympiad-level questions (medium-hard USAMO and IMO 2/5, easiest USAMO and IMO 3/6). \
23 | 9: Expert Olympiad-level questions (average USAMO and IMO 3/6). \
24 | 10: Historically hard problems, generally unsuitable for very hard competitions (such as the IMO) due to being exceedingly tedious, long, and difficult (e.g. very few students are capable of solving on a worldwide basis). \
25 | Examples \
26 | For reference, here are problems from each of the difficulty levels 1-10: \
27 | <1: Jamie counted the number of edges of a cube, Jimmy counted the numbers of corners, and Judy counted the number of faces. They then added the three numbers. What was the resulting sum? (2003 AMC 8, Problem 1) \
28 | 1: How many integer values of $x$ satisfy $|x| < 3\pi$? (2021 Spring AMC 10B, Problem 1) \
29 | 2: A fair $6$-sided die is repeatedly rolled until an odd number appears. What is the probability that every even number appears at least once before the first occurrence of an odd number? (2021 Spring AMC 10B, Problem 18) \
30 | 3: Triangle $ABC$ with $AB=50$ and $AC=10$ has area $120$. Let $D$ be the midpoint of $\overline{AB}$, and let $E$ be the midpoint of $\overline{AC}$. The angle bisector of $\angle BAC$ intersects $\overline{DE}$ and $\overline{BC}$ at $F$ and $G$, respectively. What is the area of quadrilateral $FDBG$? (2018 AMC 10A, Problem 24) \
31 | 4: Define a sequence recursively by $x_0=5$ and\[x_{n+1}=\frac{x_n^2+5x_n+4}{x_n+6}\]for all nonnegative integers $n.$ Let $m$ be the least positive integer such that\[x_m\leq 4+\frac{1}{2^{20}}.\]In which of the following intervals does $m$ lie? \
32 | $\textbf{(A) } [9,26] \qquad\textbf{(B) } [27,80] \qquad\textbf{(C) } [81,242]\qquad\textbf{(D) } [243,728] \qquad\textbf{(E) } [729,\infty)$ \
33 | (2019 AMC 10B, Problem 24 and 2019 AMC 12B, Problem 22) \
34 | 5: Find all triples $(a, b, c)$ of real numbers such that the following system holds:\[a+b+c=\frac{1}{a}+\frac{1}{b}+\frac{1}{c},\]\[a^2+b^2+c^2=\frac{1}{a^2}+\frac{1}{b^2}+\frac{1}{c^2}.\](JBMO 2020/1) \
35 | 6: Let $\triangle ABC$ be an acute triangle with circumcircle $\omega,$ and let $H$ be the intersection of the altitudes of $\triangle ABC.$ Suppose the tangent to the circumcircle of $\triangle HBC$ at $H$ intersects $\omega$ at points $X$ and $Y$ with $HA=3,HX=2,$ and $HY=6.$ The area of $\triangle ABC$ can be written in the form $m\sqrt{n},$ where $m$ and $n$ are positive integers, and $n$ is not divisible by the square of any prime. Find $m+n.$ (2020 AIME I, Problem 15) \
36 | 7: We say that a finite set $\mathcal{S}$ in the plane is balanced if, for any two different points $A$, $B$ in $\mathcal{S}$, there is a point $C$ in $\mathcal{S}$ such that $AC=BC$. We say that $\mathcal{S}$ is centre-free if for any three points $A$, $B$, $C$ in $\mathcal{S}$, there is no point $P$ in $\mathcal{S}$ such that $PA=PB=PC$. \
37 | Show that for all integers $n\geq 3$, there exists a balanced set consisting of $n$ points. \
38 | Determine all integers $n\geq 3$ for which there exists a balanced centre-free set consisting of $n$ points. \
39 | (IMO 2015/1) \
40 | 8: For each positive integer $n$, the Bank of Cape Town issues coins of denomination $\frac1n$. Given a finite collection of such coins (of not necessarily different denominations) with total value at most $99+\frac{1}{2}$, prove that it is possible to split this collection into $100$ or fewer groups, such that each group has total value at most $1$. (IMO 2014/5) \
41 | 9: Let $k$ be a positive integer and let $S$ be a finite set of odd prime numbers. Prove that there is at most one way (up to rotation and reflection) to place the elements of $S$ around the circle such that the product of any two neighbors is of the form $x^2+x+k$ for some positive integer $x$. (IMO 2022/3) \
42 | 10: Prove that there exists a positive constant $c$ such that the following statement is true: Consider an integer $n > 1$, and a set $\mathcal S$ of $n$ points in the plane such that the distance between any two different points in $\mathcal S$ is at least 1. It follows that there is a line $\ell$ separating $\mathcal S$ such that the distance from any point of $\mathcal S$ to $\ell$ is at least $cn^{-1/3}$. \
43 | (A line $\ell$ separates a set of points S if some segment joining two points in $\mathcal S$ crosses $\ell$.) (IMO 2020/6)"
44 |
45 |
46 | convert_prompt = "Another solution is written in an unstructured way. Your job is to convert them into two sections: \
47 | <|begin_of_thought|> \
48 | (Thought process, you should copy exactly the thinking process of the original solution.) \
49 | <|end_of_thought|> \
50 | <|begin_of_solution|> \
51 | (Final formatted, precise, and clear solution; make sure there is only one solution in this section; If it is a coding problem, make sure there is only one code block) \
52 | <|end_of_solution|> \
53 | Here is an example demonstration of a different question, you can refer to its format: \
54 | {example} \
55 | Important: You should almost copy all the contents word-by-word of the original solution. Just convert them into two sections. \
56 | Make sure you include: <|begin_of_slow_thought|>, <|end_of_slow_thought|>, <|begin_of_solution|>,<|end_of_solution|> These four headers explicitly. \
57 | Content to be converted: {content}"
58 |
59 | convert_prompt_example = ("<|begin_of_thought|>\n\n"
60 | "Okay, so I've got this problem here. Mr. Wang leaves home at 6 AM, riding his bike at 12 km/h, "
61 | "and he stops to rest for 6 minutes after every 30 minutes of riding. Then, when he arrives at a park "
62 | "that's 16.8 km away, I need to find out the angle between the hour and minute hands on his watch.\n\n"
63 | "Alright, first things first, I need to figure out how long it takes Mr. Wang to ride 16.8 km, including "
64 | "his rest periods.\n\n"
65 | "So, his speed is 12 km/h. To find out how long it takes to go 16.8 km without any stops, I can use the formula "
66 | "time = distance/speed. That would be 16.8 divided by 12, which is 1.4 hours. To make it easier, that's 1 hour and 24 minutes.\n\n"
67 | "But wait, he doesn't ride straight through. He stops for 6 minutes after every 30 minutes of riding. So, I need to see how many "
68 | "of those 30-minute riding periods are there in his total riding time.\n\n"
69 | "In 1 hour and 24 minutes of riding, how many 30-minute segments are there? Well, 1 hour is 60 minutes, plus 24 minutes makes 84 minutes "
70 | "total riding time. So, 84 divided by 30 is 2.8. That means he has two full 30-minute riding periods and a partial one.\n\n"
71 | "After each full 30-minute riding period, he rests for 6 minutes. So, for two full periods, he rests twice, which is 12 minutes of rest.\n\n"
72 | "Now, for the partial riding period. Since 2 times 30 minutes is 60 minutes, and he has 84 minutes of riding, the remaining riding time is 84 minus 60, "
73 | "which is 24 minutes. So, he rides for 24 minutes without another rest because he doesn't complete another 30-minute segment.\n\n"
74 | "So, total time taken is riding time plus rest time. That's 84 minutes riding plus 12 minutes resting, totaling 96 minutes.\n\n"
75 | "Wait a minute, but he stops after every 30 minutes of riding, but in the last partial period of 24 minutes, does he rest again? I think he only rests after "
76 | "completing 30 minutes of riding, so in this case, since the last riding period is only 24 minutes, he doesn't take an additional rest after that.\n\n"
77 | "So, total time should be 84 minutes riding plus 12 minutes resting, which is indeed 96 minutes, or 1 hour and 36 minutes.\n\n"
78 | "So, he leaves at 6 AM and takes 1 hour and 36 minutes to reach the park, arriving at 7:36 AM.\n\n"
79 | "Now, I need to find the angle between the hour and minute hands at 7:36.\n\n"
80 | "To find the angle between the hour and minute hands, I can use the formula:\n\n"
81 | "|30H - 5.5M|\n\n"
82 | "where H is the hour and M is the minutes.\n\n"
83 | "At 7:36, H is 7 and M is 36.\n\n"
84 | "So, plugging in:\n\n"
85 | "30*7 = 210\n\n"
86 | "5.5*36 = 198\n\n"
87 | "210 - 198 = 12\n\n"
88 | "So, the angle is 12 degrees.\n\n"
89 | "Wait, but I should make sure that's the smaller angle. Sometimes, the larger angle is considered, but usually, the smaller one is what is asked for.\n\n"
90 | "So, the angle between the hour and minute hands at 7:36 AM is 12 degrees.\n\n"
91 | "I think that's the answer.<|end_of_thought|>\n\n"
92 | "<|begin_of_solution|>\n\n"
93 | "Mr. Wang leaves home at 6 AM and rides at a speed of 12 km/h, stopping to rest for 6 minutes after every 30 minutes of riding. "
94 | "He arrives at a park 16.8 km away. To determine the angle between the hour and minute hands on his watch when he arrives, we first calculate the total time taken.\n\n"
95 | "1. **Riding time without stops**:\n\n"
96 | "$$\\text{Time} = \\frac{\\text{Distance}}{\\text{Speed}} = \\frac{16.8 \\text{ km}}{12 \\text{ km/h}} = 1.4 \\text{ hours} = 84 \\text{ minutes}$$\n\n"
97 | "2. **Rest periods**:\n\n"
98 | " - He rests for 6 minutes after every 30 minutes of riding.\n\n"
99 | " - In 84 minutes of riding, he completes 2 full 30-minute segments and a partial 24-minute segment.\n\n"
100 | " - He rests twice, totaling 12 minutes of rest.\n\n"
101 | "3. **Total time**:\n\n"
102 | "$$\\text{Total time} = 84 \\text{ minutes (riding)} + 12 \\text{ minutes (rest)} = 96 \\text{ minutes} = 1 \\text{ hour and } 36 \\text{ minutes}$$\n\n"
103 | " - He arrives at 7:36 AM.\n\n"
104 | "4. **Angle between hour and minute hands at 7:36**:\n\n"
105 | " - Use the formula:\n\n"
106 | "$$\\text{Angle} = |30H - 5.5M|$$\n\n"
107 | " - At 7:36, $H = 7$ and $M = 36$:\n\n"
108 | "$$\\text{Angle} = |30 \\times 7 - 5.5 \\times 36| = |210 - 198| = 12 \\text{ degrees}$$\n\n"
109 | "Thus, the angle between the hour and minute hands on his watch is $\\boxed{12}$.<|end_of_solution|>\n")
110 |
111 | # From https://arxiv.org/pdf/2412.09413
112 | system_prompt = "Your role as an assistant involves thoroughly exploring questions through a systematic long \
113 | thinking process before providing the final precise and accurate solutions. This requires \
114 | engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, \
115 | backtracing, and iteration to develop well-considered thinking process. \
116 | Please structure your response into two main sections: Thought and Solution. \
117 | In the Thought section, detail your reasoning process using the specified format: \
118 | <|begin_of_thought|> {thought with steps separated with '\n\n'} \
119 | <|end_of_thought|> \
120 | Each step should include detailed considerations such as analyzing questions, summarizing \
121 | relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining \
122 | any errors, and revisiting previous steps. \
123 | In the Solution section, based on various attempts, explorations, and reflections from the Thought \
124 | section, systematically present the final solution that you deem correct. The solution should \
125 | remain a logical, accurate, concise expression style and detail necessary step needed to reach the \
126 | conclusion, formatted as follows: \
127 | <|begin_of_solution|> \
128 | {final formatted, precise, and clear solution} \
129 | <|end_of_solution|> \
130 | Now, try to solve the following question through the above guidelines:"
--------------------------------------------------------------------------------
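A brief sketch of how the difficulty-grading template above can be filled in (the sample problem is illustrative); the grader's reply is then expected to carry the difficulty inside double brackets:

```python
from tools.util.prompts import grading_prompt, aops_criteria

problem = "How many integer values of x satisfy |x| < 10?"
prompt = grading_prompt.format(aops_criteria=aops_criteria, problem=problem)
# A well-formed grader reply ends with the difficulty in brackets, e.g. "[[1]]".
```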
/tools/util/taco/pyext2.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2014 Ryan Gonzalez
3 |
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of
6 | this software and associated documentation files (the "Software"), to deal in
7 | the Software without restriction, including without limitation the rights to use,
8 | copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
9 | Software, and to permit persons to whom the Software is furnished to do so,
10 | subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21 | '''
22 |
23 | g_backup = globals().copy()
24 |
25 | __version__ = '0.7'
26 |
27 | __all__ = ['overload', 'RuntimeModule', 'switch', 'tail_recurse', 'copyfunc', 'set_docstring', 'annotate', 'safe_unpack', 'modify_function', 'assign', 'fannotate', 'compare_and_swap', 'is_main', 'call_if_main', 'run_main']
28 |
29 | import sys, inspect, types
30 |
31 | def __targspec(func, specs, attr='__orig_arg__'):
32 | if hasattr(func, '__is_overload__') and func.__is_overload__:
33 | return getattr(func, attr)
34 | return specs(func)
35 |
36 | def set_docstring(doc):
37 | '''A simple decorator to set docstrings.
38 |
39 | :param doc: The docstring to tie to the function.
40 |
41 | Example::
42 |
43 | @set_docstring('This is a docstring')
44 | def myfunc(x):
45 | pass'''
46 | def _wrap(f):
47 | f.__doc__ = doc
48 | return f
49 | return _wrap
50 |
51 | __modify_function_doc = '''
52 | Creates a copy of a function, changing its attributes.
53 |
54 | :param globals: Will be added to the function's globals.
55 |
56 | :param name: The new function name. Set to ``None`` to use the function's original name.
57 |
58 | :param code: The new function code object. Set to ``None`` to use the function's original code object.
59 |
60 | :param defaults: The new function defaults. Set to ``None`` to use the function's original defaults.
61 |
62 | :param closure: The new function closure. Set to ``None`` to use the function's original closure.
63 |
64 | .. warning:: This function can be potentially dangerous.
65 | '''
66 |
67 | def copyfunc(f):
68 | '''Copies a function.
69 |
70 | :param f: The function to copy.
71 |
72 | :return: The copied function.
73 |
74 | .. deprecated:: 0.4
75 | Use :func:`modify_function` instead.
76 | '''
77 | return modify_function(f)
78 |
79 | if sys.version_info.major == 3:
80 | @set_docstring(__modify_function_doc)
81 | def modify_function(f, globals={}, name=None, code=None, defaults=None,
82 | closure=None):
83 | if code is None: code = f.__code__
84 | if name is None: name = f.__name__
85 | if defaults is None: defaults = f.__defaults__
86 | if closure is None: closure = f.__closure__
87 | newf = types.FunctionType(code, dict(f.__globals__, **globals), name=name,
88 | argdefs=defaults, closure=closure)
89 | newf.__dict__.update(f.__dict__)
90 | return newf
91 | def argspec(f):
92 | return inspect.getfullargspec(f)
93 | ofullargspec = inspect.getfullargspec
94 | def _fullargspec(func):
95 | return __targspec(func, ofullargspec)
96 | inspect.getfullargspec = _fullargspec
97 | def _exec(m,g): exec(m,g)
98 | else:
99 | @set_docstring(__modify_function_doc)
100 | def modify_function(f, globals={}, name=None, code=None, defaults=None,
101 | closure=None):
102 | if code is None: code = f.func_code
103 | if name is None: name = f.__name__
104 | if defaults is None: defaults = f.func_defaults
105 | if closure is None: closure = f.func_closure
106 | newf = types.FunctionType(code, dict(f.func_globals, **globals), name=name,
107 | argdefs=defaults, closure=closure)
108 | newf.__dict__.update(f.__dict__)
109 | return newf
110 | def argspec(f):
111 | return inspect.getargspec(f)
112 | eval(compile('def _exec(m,g): exec m in g', '', 'exec'))
113 |
114 | def _gettypes(args):
115 | return tuple(map(type, args))
116 |
117 | oargspec = inspect.getargspec
118 |
119 | def _argspec(func):
120 | return __targspec(func, oargspec)
121 |
122 | inspect.getargspec = _argspec
123 |
124 | try:
125 | import IPython
126 | except ImportError:
127 | IPython = None
128 | else:
129 | # Replace IPython's argspec
130 | oipyargspec = IPython.core.oinspect.getargspec
131 | def _ipyargspec(func):
132 | return __targspec(func, oipyargspec, '__orig_arg_ipy__')
133 | IPython.core.oinspect.getargspec = _ipyargspec
134 |
135 | class overload(object):
136 | '''Simple function overloading in Python.'''
137 | _items = {}
138 | _types = {}
139 | @classmethod
140 | def argc(self, argc=None):
141 | '''Overloads a function based on the specified argument count.
142 |
143 | :param argc: The argument count. Defaults to ``None``. If ``None`` is given, automatically compute the argument count from the given function.
144 |
145 | .. note::
146 |
147 | Keyword argument counts are NOT checked! In addition, when the argument count is automatically calculated, the keyword argument count is also ignored!
148 |
149 | Example::
150 |
151 | @overload.argc()
152 | def func(a):
153 | print 'Function 1 called'
154 |
155 | @overload.argc()
156 | def func(a, b):
157 | print 'Function 2 called'
158 |
159 | func(1) # Calls first function
160 | func(1, 2) # Calls second function
161 | func() # Raises error
162 | '''
163 | # Python 2 UnboundLocalError fix
164 | argc = {'argc': argc}
165 | def _wrap(f):
166 | def _newf(*args, **kwargs):
167 | if len(args) not in self._items[f.__name__]:
168 | raise TypeError("No overload of function '%s' that takes %d args" % (f.__name__, len(args)))
169 | return self._items[f.__name__][len(args)](*args, **kwargs)
170 | if f.__name__ not in self._items:
171 | self._items[f.__name__] = {}
172 | if argc['argc'] is None:
173 | argc['argc'] = len(argspec(f).args)
174 | self._items[f.__name__][argc['argc']] = f
175 | _newf.__name__ = f.__name__
176 | _newf.__doc__ = f.__doc__
177 | _newf.__is_overload__ = True
178 | _newf.__orig_arg__ = argspec(f)
179 | if IPython:
180 | _newf.__orig_arg_ipy__ = IPython.core.oinspect.getargspec(f)
181 | return _newf
182 | return _wrap
183 | @classmethod
184 | def args(self, *argtypes, **kw):
185 | '''Overload a function based on the specified argument types.
186 |
187 | :param argtypes: The argument types. If None is given, get the argument types from the function annotations(Python 3 only)
188 | :param kw: Can only contain 1 argument, `is_cls`. If True, the function is assumed to be part of a class.
189 |
190 | Example::
191 |
192 | @overload.args(str)
193 | def func(s):
194 | print 'Got string'
195 |
196 | @overload.args(int, str)
197 | def func(i, s):
198 | print 'Got int and string'
199 |
200 | @overload.args()
201 | def func(i:int): # A function annotation example
202 | print 'Got int'
203 |
204 | func('s')
205 | func(1)
206 | func(1, 's')
207 | func(True) # Raises error
208 | '''
209 |
210 | # Python 2 UnboundLocalError fix...again!
211 | argtypes = {'args': tuple(argtypes)}
212 | def _wrap(f):
213 | def _newf(*args):
214 | if len(kw) == 0:
215 | cargs = args
216 | elif len(kw) == 1 and 'is_cls' in kw and kw['is_cls']:
217 | cargs = args[1:]
218 | else:
219 | raise ValueError('Invalid keyword args specified')
220 | if _gettypes(cargs) not in self._types[f.__name__]:
221 | raise TypeError("No overload of function '%s' that takes '%s' types and %d arg(s)" % (f.__name__, _gettypes(cargs), len(cargs)))
222 | return self._types[f.__name__][_gettypes(cargs)](*args)
223 | if f.__name__ not in self._types:
224 | self._types[f.__name__] = {}
225 | if len(argtypes['args']) == 1 and argtypes['args'][0] is None:
226 | aspec = argspec(f)
227 | argtypes['args'] = tuple(map(lambda x: x[1], sorted(
228 | aspec.annotations.items(), key=lambda x: aspec.args.index(x[0]))))
229 | self._types[f.__name__][argtypes['args']] = f
230 | _newf.__name__ = f.__name__
231 | _newf.__doc__ = f.__doc__
232 | _newf.__is_overload__ = True
233 | _newf.__orig_arg__ = argspec(f)
234 | if IPython:
235 | _newf.__orig_arg_ipy__ = IPython.core.oinspect.getargspec(f)
236 | return _newf
237 | return _wrap
238 |
239 | class _RuntimeModule(object):
240 | 'Create a module object at runtime and insert it into sys.modules. If called, same as :py:func:`from_objects`.'
241 | def __call__(self, *args, **kwargs):
242 | return self.from_objects(*args, **kwargs)
243 | @staticmethod
244 | @overload.argc(1)
245 | def from_objects(module_name_for_code_eval, **d):
246 | return _RuntimeModule.from_objects(module_name_for_code_eval, '', **d)
247 | @staticmethod
248 | @overload.argc(2)
249 | def from_objects(module_name_for_code_eval, docstring, **d):
250 | '''Create a module at runtime from `d`.
251 |
252 | :param name: The module name.
253 |
254 | :param docstring: Optional. The module's docstring.
255 |
256 | :param \*\*d: All the keyword args, mapped from name->value.
257 |
258 | Example: ``RuntimeModule.from_objects('name', 'doc', a=1, b=2)``'''
259 | module = types.ModuleType(module_name_for_code_eval, docstring)
260 | module.__dict__.update(d)
261 | module.__file__ = ''
262 | sys.modules[module_name_for_code_eval] = module
263 | return module
264 | @staticmethod
265 | @overload.argc(2)
266 | def from_string(module_name_for_code_eval, s):
267 | return _RuntimeModule.from_string(module_name_for_code_eval, '', s)
268 | @staticmethod
269 | @overload.argc(3)
270 | def from_string(module_name_for_code_eval, docstring, s):
271 | '''Create a module at runtime from `s``.
272 |
273 | :param name: The module name.
274 |
275 | :param docstring: Optional. The module docstring.
276 |
277 | :param s: A string containing the module definition.'''
278 | g = {}
279 | _exec(s, g)
280 | return _RuntimeModule.from_objects(module_name_for_code_eval, docstring, **dict(filter(lambda x: x[0] not in g_backup, g.items())))
281 |
282 | RuntimeModule = _RuntimeModule()
283 |
284 | class CaseObject(object):
285 | 'The object returned by a switch statement. When called, it will return True if the given argument equals its value, else False. It can be called with multiple parameters, in which case it checks if its value equals any of the arguments.'
286 | def __init__(self, value):
287 | self.value = value
288 | self.did_match = False
289 | self.did_pass = False
290 | def __call__(self, *args):
291 | if assign('res', not self.did_pass and any([self.value == rhs for rhs in args])):
292 | self.did_match = True
293 | return res
294 | def quit(self):
295 | 'Forces all other calls to return False. Equivalent of a ``break`` statement.'
296 | self.did_pass = True
297 | def default(self):
298 | "Executed if quit wasn't called."
299 | return not self.did_match and not self.did_pass
300 | def __iter__(self):
301 | yield self
302 | def __enter__(self):
303 | return self
304 | def __exit__(self, *args):
305 | pass
306 |
307 | def switch(value):
308 | '''A Python switch statement implementation that is used with a ``with`` statement.
309 |
310 | :param value: The value to "switch".
311 |
312 | ``with`` statement example::
313 |
314 | with switch('x'):
315 | if case(1): print 'Huh?'
316 | if case('x'): print 'It works!!!'
317 |
318 | .. warning:: If you modify a variable named "case" in the same scope that you use the ``with`` statement version, you will get an UnboundLocalError. The solution is to use ``with switch('x') as case:`` instead of ``with switch('x'):``.'''
319 | res = CaseObject(value)
320 | inspect.stack()[1][0].f_globals['case'] = res
321 | return res
322 |
323 | def tail_recurse(spec=None):
324 | '''Remove tail recursion from a function.
325 |
326 | :param spec: A function that, when given the arguments, returns a bool indicating whether or not to exit. If ``None,`` tail recursion is always called unless the function returns a value.
327 |
328 | .. note::
329 |
330 | This function has a slight overhead that is noticeable when using timeit. Only use it if the function has a possibility of going over the recursion limit.
331 |
332 | .. warning::
333 |
334 | This function will BREAK any code that either uses any recursion other than tail recursion or calls itself multiple times. For example, ``def x(): return x()+1`` will fail.
335 |
336 | Example::
337 |
338 | @tail_recurse()
339 | def add(a, b):
340 | if a == 0: return b
341 | return add(a-1, b+1)
342 |
343 | add(10000000, 1) # Doesn't max the recursion limit.
344 | '''
345 | def _wrap(f):
346 | class TailRecursion(Exception):
347 | def __init__(self, args, kwargs):
348 | self.args = args
349 | self.kwargs = kwargs
350 | def _newf(*args, **kwargs):
351 | if inspect.stack()[1][3] == f.__name__:
352 | if (spec and spec(args)) or not spec:
353 | raise TailRecursion(args, kwargs)
354 | while True:
355 | try:
356 | res = f(*args, **kwargs)
357 | except TailRecursion as ex:
358 | args = ex.args
359 | kwargs = ex.kwargs
360 | continue
361 | else:
362 | return res
363 | _newf.__doc__ = f.__doc__
364 | return _newf
365 | return _wrap
366 |
367 | def annotate(*args, **kwargs):
368 | '''Set function annotations using decorators.
369 |
370 | :param args: This is a list of annotations for the function, in the order of the function's parameters. For example, ``annotate('Annotation 1', 'Annotation 2')`` will set the annotations of parameter 1 of the function to ``Annotation 1``.
371 |
372 | :param kwargs: This is a mapping of argument names to annotations. Note that these are applied *after* the argument list, so any args set that way will be overriden by this mapping. If there is a key named `ret`, that will be the annotation for the function's return value.
373 |
374 | .. deprecated:: 0.5
375 | Use :func:`fannotate` instead.
376 | '''
377 | def _wrap(f):
378 | if not hasattr(f, '__annotations__'):
379 | f.__annotations__ = {}
380 | if 'ret' in kwargs:
381 | f.__annotations__['return'] = kwargs.pop('ret')
382 | f.__annotations__.update(dict(zip(argspec(f).args, args)))
383 | f.__annotations__.update(kwargs)
384 | return f
385 | return _wrap
386 |
387 | def fannotate(*args, **kwargs):
388 | '''Set function annotations using decorators.
389 |
390 | :param \*args: The first positional argument is used for the function's return value; all others are discarded.
391 |
392 | :param \**kwargs: This is a mapping of argument names to annotations.
393 |
394 | Example::
395 |
396 | @fannotate('This for the return value', a='Parameter a', b='Parameter b')
397 | def x(a, b):
398 | pass
399 |
400 | '''
401 | def _wrap(f):
402 | if not hasattr(f, '__annotations__'):
403 | f.__annotations__ = {}
404 | if len(args) >= 1:
405 | f.__annotations__['return'] = args[0]
406 | f.__annotations__.update(kwargs)
407 | return f
408 | return _wrap
409 |
410 | def safe_unpack(seq, ln, fill=None):
411 | '''Safely unpack a sequence to length `ln`, without raising ValueError. Based on Lua's method of unpacking. Empty values will be filled in with `fill`, while any extra values will be cut off.
412 |
413 | :param seq: The sequence to unpack.
414 |
415 | :param ln: The expected length of the sequence.
416 |
417 | :param fill: The value to substitute if the sequence is too small. Defaults to ``None``.
418 |
419 | Example::
420 |
421 | s = 'a:b'
422 | a, b = safe_unpack(s.split(':'), 2)
423 | # a = 'a'
424 | # b = 'b'
425 | s = 'a'
426 | a, b = safe_unpack(s.split(':'), 2)
427 | # a = 'a'
428 | # b = None'''
429 | if len(seq) > ln:
430 | return seq[:ln]
431 | elif len(seq) < ln:
432 | return seq + type(seq)([fill]*(ln-len(seq)))
433 | else:
434 | return seq
435 |
436 | def assign(varname, value):
437 | '''Assign `value` to `varname` and return it. If `varname` is an attribute and the instance name it belongs to is not defined, a NameError is raised.
438 | This can be used to emulate assignment as an expression. For example, this::
439 |
440 | if assign('x', 7): ...
441 |
442 | is equivalent to this C code::
443 |
444 | if (x = 7) ...
445 |
446 | .. warning::
447 |
448 | When assigning an attribute, the instance it belongs to MUST be declared as global prior to the assignment. Otherwise, the assignment will not work.
449 | '''
450 | fd = inspect.stack()[1][0].f_globals
451 | if '.' not in varname:
452 | fd[varname] = value
453 | else:
454 | vsplit = list(map(str.strip, varname.split('.')))
455 | if vsplit[0] not in fd:
456 | raise NameError('Unknown object: %s'%vsplit[0])
457 | base = fd[vsplit[0]]
458 | for x in vsplit[1:-1]:
459 | base = getattr(base, x)
460 | setattr(base, vsplit[-1], value)
461 | return value
462 |
463 | def is_main(frame=1):
464 | "Return if the caller is main. Equilavent to ``__name__ == '__main__'``."
465 | return inspect.stack()[frame][0].f_globals['__name__'] == '__main__'
466 |
467 | def _call_if_main(frame, f, args):
468 | if is_main(frame): return f(*args)
469 |
470 | def call_if_main(f,*args):
471 | "Call the `f` with `args` if the caller's module is main."
472 | return _call_if_main(3,f,args)
473 |
474 | def run_main(f,*args):
475 | "Call `f` with the `args` and terminate the program with its return code if the caller's module is main."
476 | sys.exit(_call_if_main(3,f,args))
477 |
478 | def compare_and_swap(var, compare, new):
479 | "If `var` is equal to `compare`, set it to `new`."
480 | if assign('v', inspect.stack()[1][0].f_globals)[var] == compare:
481 | v[var] = new
482 |
--------------------------------------------------------------------------------
/train/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.path.abspath(os.path.join(__file__, "../../")))
4 |
--------------------------------------------------------------------------------
/train/deepseed/ds_z3_offload_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "gradient_accumulation_steps": "auto",
5 | "gradient_clipping": "auto",
6 | "zero_allow_untested_optimizer": true,
7 | "fp16": {
8 | "enabled": "auto",
9 | "loss_scale": 0,
10 | "loss_scale_window": 1000,
11 | "initial_scale_power": 16,
12 | "hysteresis": 2,
13 | "min_loss_scale": 1
14 | },
15 | "bf16": {
16 | "enabled": "auto"
17 | },
18 | "zero_optimization": {
19 | "stage": 3,
20 | "offload_optimizer": {
21 | "device": "cpu",
22 | "pin_memory": true
23 | },
24 | "offload_param": {
25 | "device": "cpu",
26 | "pin_memory": true
27 | },
28 | "overlap_comm": true,
29 | "contiguous_gradients": true,
30 | "sub_group_size": 1e9,
31 | "reduce_bucket_size": "auto",
32 | "stage3_prefetch_bucket_size": "auto",
33 | "stage3_param_persistence_threshold": "auto",
34 | "stage3_max_live_parameters": 1e9,
35 | "stage3_max_reuse_distance": 1e9,
36 | "stage3_gather_16bit_weights_on_model_save": true
37 | }
38 | }
39 |
--------------------------------------------------------------------------------
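
A minimal sketch of how this ZeRO-3 + CPU-offload config is consumed (assumption: deepspeed is installed and the code is run from the repository root). Both sft_train.py and dpo_train.py forward the --deepspeed path into the Hugging Face training config, and the "auto" fields are then resolved from the corresponding TrainingArguments values:

    from transformers import TrainingArguments

    training_args = TrainingArguments(
        output_dir="./train/models/sft/example",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=12,
        bf16=True,
        deepspeed="./train/deepseed/ds_z3_offload_config.json",  # ZeRO stage 3, optimizer and params offloaded to CPU
    )
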
/train/deepseed/zero2_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "gradient_accumulation_steps": "auto",
5 | "gradient_clipping": "auto",
6 | "zero_allow_untested_optimizer": true,
7 | "fp16": {
8 | "enabled": "auto",
9 | "loss_scale": 0,
10 | "loss_scale_window": 1000,
11 | "initial_scale_power": 16,
12 | "hysteresis": 2,
13 | "min_loss_scale": 1
14 | },
15 | "bf16": {
16 | "enabled": "auto"
17 | },
18 | "zero_optimization": {
19 | "stage": 2,
20 | "allgather_partitions": true,
21 | "allgather_bucket_size": 5e8,
22 | "overlap_comm": true,
23 | "reduce_scatter": true,
24 | "reduce_bucket_size": 5e8,
25 | "contiguous_gradients": true,
26 | "round_robin_gradients": true
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/train/deepseed/zero3_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": "auto",
3 | "train_micro_batch_size_per_gpu": "auto",
4 | "gradient_accumulation_steps": "auto",
5 | "gradient_clipping": "auto",
6 | "zero_allow_untested_optimizer": true,
7 | "fp16": {
8 | "enabled": "auto",
9 | "loss_scale": 0,
10 | "loss_scale_window": 1000,
11 | "initial_scale_power": 16,
12 | "hysteresis": 2,
13 | "min_loss_scale": 1
14 | },
15 | "bf16": {
16 | "enabled": "auto"
17 | },
18 | "zero_optimization": {
19 | "stage": 3,
20 | "overlap_comm": true,
21 | "contiguous_gradients": true,
22 | "sub_group_size": 1e9,
23 | "reduce_bucket_size": "auto",
24 | "stage3_prefetch_bucket_size": "auto",
25 | "stage3_param_persistence_threshold": "auto",
26 | "stage3_max_live_parameters": 1e9,
27 | "stage3_max_reuse_distance": 1e9,
28 | "stage3_gather_16bit_weights_on_model_save": true
29 | },
30 | "wandb": {
31 | "enabled": true,
32 | "rank_zero_only": true
33 | }
34 | }
--------------------------------------------------------------------------------
/train/deepseed/zero3_config2.json:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": true
4 | },
5 | "zero_optimization": {
6 | "stage": 3,
7 | "allgather_partitions": true,
8 | "allgather_bucket_size": 2e8,
9 | "overlap_comm": true,
10 | "reduce_scatter": true,
11 | "reduce_bucket_size": "auto",
12 | "contiguous_gradients": true,
13 | "stage3_gather_16bit_weights_on_model_save": true
14 | },
15 | "gradient_accumulation_steps": "auto",
16 | "gradient_clipping": "auto",
17 | "train_batch_size": "auto",
18 | "train_micro_batch_size_per_gpu": "auto",
19 | "wall_clock_breakdown": false
20 | }
--------------------------------------------------------------------------------
/train/dpo_train.py:
--------------------------------------------------------------------------------
1 | import __init__
2 | from utils.utils import *
3 | import os
4 | import warnings
5 | from trl import DPOConfig, DPOTrainer
6 | import wandb
7 | from utils.load_model import *
8 | import argparse
9 | from train.names import *
10 |
11 | warnings.filterwarnings('ignore')
12 | os.environ["WANDB_MODE"] = "offline"
13 |
14 | def parse_args(args=None):
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--lr', type=float, default=5e-7)
17 | parser.add_argument('--beta', type=float, default=0.05)
18 | parser.add_argument('--model', type=str, default='Bespoke-7b')
19 | parser.add_argument('--dataset', type=str, default='Bespoke_dpo')
20 | parser.add_argument('--epoch', type=int, default=1)
21 | parser.add_argument('--MAX_LENGTH', type=int, default=1024 * 8)
22 | parser.add_argument('--seed', type=int, default=42)
23 | parser.add_argument('--gradient_accumulation_steps', type=int, default=12)
24 |     parser.add_argument('--deepspeed', type=str, default=None)  # path to a DeepSpeed config JSON (optional)
25 | parser.add_argument('--local_rank', type=int, default=0)
26 | return parser.parse_args(args)
27 |
28 |
29 | if __name__ == '__main__':
30 | args = parse_args()
31 | lr, beta = args.lr, args.beta
32 | model_name = model_names[args.model]
33 | MAX_LENGTH = args.MAX_LENGTH
34 | dataset_name = set_global(dataset_names[args.dataset])
35 | wandb_name = f"{args.model}_{args.dataset}_dpo_lr{lr}_beta{beta}"
36 |
37 | training_config = DPOConfig(
38 | save_only_model=True,
39 | output_dir=set_global(f"./train/models/dpo/{wandb_name}"),
40 | per_device_train_batch_size=1,
41 | gradient_accumulation_steps=args.gradient_accumulation_steps,
42 | gradient_checkpointing=True,
43 | save_total_limit=1,
44 | num_train_epochs=args.epoch,
45 | report_to="wandb",
46 | save_strategy='epoch',
47 | logging_steps=1,
48 | learning_rate=lr,
49 | beta=beta,
50 | bf16=True,
51 | lr_scheduler_type='cosine',
52 | warmup_ratio = 0.1,
53 | max_length=MAX_LENGTH,
54 | deepspeed= set_global(args.deepspeed) if args.deepspeed is not None else None,
55 | )
56 |
57 | train_dataset = load_train_data(dataset_name)
58 | model, tokenizer = init_train_model(model_name)
59 | ref_model = init_train_model(model_name)[0]
60 | ref_model.eval()
61 | # ref_model = None
62 |
63 | if args.local_rank == 0:
64 | wandb.login()
65 | wandb.init(project="ThinkPO", name=wandb_name)
66 |
67 | dpo_trainer = DPOTrainer(
68 | model,
69 | ref_model=ref_model,
70 | args=training_config,
71 | train_dataset=train_dataset,
72 | tokenizer=tokenizer,
73 | )
74 | dpo_trainer.train()
--------------------------------------------------------------------------------
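
For reference, a self-contained sketch of the same DPO setup at toy scale (hypothetical tiny model and dataset, not the 7B checkpoints from train/names.py; requires trl and network access):

    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import DPOConfig, DPOTrainer

    model_name = "sshleifer/tiny-gpt2"   # hypothetical stand-in for Bespoke-Stratos-7B
    model = AutoModelForCausalLM.from_pretrained(model_name)
    ref_model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # DPOTrainer expects prompt / chosen / rejected columns in the training set.
    train_dataset = Dataset.from_list([
        {"prompt": "1+1=", "chosen": "2", "rejected": "3"},
    ])

    config = DPOConfig(output_dir="./tmp_dpo", per_device_train_batch_size=1,
                       learning_rate=5e-7, beta=0.05, max_length=64,
                       logging_steps=1, report_to="none")
    DPOTrainer(model, ref_model=ref_model, args=config,
               train_dataset=train_dataset, tokenizer=tokenizer).train()
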
/train/names.py:
--------------------------------------------------------------------------------
1 | import __init__
2 | from utils.utils import *
3 | from utils.data_utils import save_all_results, read_saved_results, load_data
4 | from utils.load_model import *
5 |
6 | def init_train_model(model_name=''):
7 | model, tokenizer = init_model(model_name, eval=False, low_cpu_mem_usage=False)
8 | model.enable_input_require_grads()
9 | if tokenizer.pad_token is None:
10 | token = tokenizer.convert_ids_to_tokens(2)
11 | print(f"Token for ID 2: {token}")
12 | tokenizer.pad_token_id = 2
13 | return model, tokenizer
14 |
15 | def load_train_data(choose_data_name):
16 | if choose_data_name in ['Bespoke_dpo', 'Bespoke', 'Deepseek']:
17 | choose_data_name = dataset_names[choose_data_name]
18 | choose_data = load_data(choose_data_name, 'huggingface')
19 | else:
20 | choose_data_name = set_global(dataset_names[choose_data_name])
21 | choose_data = load_data(choose_data_name, 'json')
22 | choose_data = choose_data['train']
23 | return choose_data
24 |
25 | dataset_names = {
26 | 'NuminaMath': 'AI-MO/NuminaMath-CoT',
27 | 'Bespoke':'bespokelabs/Bespoke-Stratos-17k',
28 |
29 |     'Bespoke_dpo':'VanWang/Bespoke_dpo_filter',
30 |     'Bespoke_dpo_long':'data/final/Bespoke_dpo_filter_len_long.jsonl',
31 |     'Bespoke_dpo_short':'data/final/Bespoke_dpo_filter_len_short.jsonl',
32 | 'Bespoke_dpo_middle':'data/final/Bespoke_dpo_filter_len_middle.jsonl'
33 | }
34 | model_names = {
35 | 'Deepseek-7b':'deepseek-ai/DeepSeek-R1-Distill-Qwen-7B',
36 |
37 | 'Instruct-1.5b':'Qwen/Qwen2.5-1.5B-Instruct',
38 | 'Instruct-3b':'Qwen/Qwen2.5-3B-Instruct',
39 | 'Instruct-7b':'Qwen/Qwen2.5-7B-Instruct',
41 | 'Instruct-14b':'Qwen/Qwen2.5-14B-Instruct',
42 | 'Instruct-32b':'Qwen/Qwen2.5-32B-Instruct',
43 |
44 | 'Bespoke-32b': 'bespokelabs/Bespoke-Stratos-32B',
45 | 'Bespoke-7b': 'bespokelabs/Bespoke-Stratos-7B',
46 |
47 | 'OpenThinker-7B':'open-thoughts/OpenThinker-7B',
48 | }
--------------------------------------------------------------------------------
/train/sft_train.py:
--------------------------------------------------------------------------------
1 | from __init__ import *
2 | from utils.utils import *
3 | from transformers import Trainer, TrainingArguments,DataCollatorForSeq2Seq
4 | import wandb
5 | from utils.data_utils import load_data
6 | from utils.load_model import *
7 | import argparse
8 | import os
9 | from train.names import model_names, dataset_names
10 |
11 | os.environ['WANDB_MODE'] = 'offline'
12 |
13 | def init_train_model(model_name=''):
14 | model, tokenizer = init_model(model_name=model_names[model_name], eval=False)
15 | model.enable_input_require_grads()
16 | tokenizer.padding_side = "right"
17 | if tokenizer.pad_token is None:
18 | token = tokenizer.convert_ids_to_tokens(2)
19 | print(f"Token for ID 2: {token}")
20 | tokenizer.pad_token_id = 2
21 | return model, tokenizer
22 |
23 | def preprocess_function(example):
24 | global MAX_LENGTH
25 | input_ids, attention_mask, labels = [], [], []
26 | system, conversations = example["system"], example["conversations"]
27 | targets = conversations[1]['value'].strip()
28 | messages = [
29 | {"role": "system", "content": system},
30 | {"role": "user", "content": conversations[0]['value']}
31 | ]
32 | inputs = tokenizer.apply_chat_template(
33 | messages,
34 | tokenize=False,
35 | add_generation_prompt=True
36 | )
37 | instruction, response = tokenizer(inputs+'\n'), tokenizer(targets)
38 | input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.eos_token_id]
39 | attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]
40 | labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.eos_token_id]
41 | if len(input_ids) > MAX_LENGTH:
42 | input_ids = input_ids[:MAX_LENGTH]
43 | attention_mask = attention_mask[:MAX_LENGTH]
44 | labels = labels[:MAX_LENGTH]
45 | return {
46 | "input_ids": input_ids,
47 | "attention_mask": attention_mask,
48 | "labels": labels
49 | }
50 |
51 | def load_train_data(dataset_name):
52 | if dataset_name == 'NuminaMath':
53 | dataset = load_data(set_global(dataset_names[dataset_name]), 'json')
54 | train_dataset = dataset['train']
55 | if dataset_name == 'Bespoke':
56 | dataset = load_data(dataset_names[dataset_name], 'huggingface')
57 | train_dataset = dataset['train']
58 | return train_dataset.map(preprocess_function, remove_columns=train_dataset.column_names)
59 |
60 |
61 | def parse_args(args=None):
62 | parser = argparse.ArgumentParser()
63 | parser.add_argument('--lr', type=float, default=3e-5)
64 | parser.add_argument('--MAX_LENGTH', type=int, default=1024*10)
65 |     parser.add_argument('--model_name', type=str, default='Instruct-14b', help='key in model_names, e.g. Instruct-3b, Instruct-7b, Instruct-32b, Bespoke-7b')
66 | parser.add_argument('--seed', type=int, default=42)
67 | parser.add_argument('--gradient_accumulation_steps', type=int, default=96)
68 | parser.add_argument('--dataset_name', type=str, default='Bespoke', help='Bespoke, NuminaMath')
69 |     parser.add_argument('--deepspeed', type=str, default=None)  # path to a DeepSpeed config JSON (optional)
70 | parser.add_argument('--local_rank', type=int, default=0)
71 | parser.add_argument('--epoch', type=int, default=1)
72 | return parser.parse_args(args)
73 |
74 | args = parse_args()
75 |
76 | lr = args.lr
77 | MAX_LENGTH = args.MAX_LENGTH
78 | model_name = args.model_name
79 | dataset_name = args.dataset_name
80 | wandb_name = f"{dataset_name}_{model_name}_sft_lr{lr}"
81 |
82 | model, tokenizer = init_train_model(model_name=model_name)
83 | train_dataset = load_train_data(dataset_name)
84 |
85 | if args.local_rank == 0:
86 | wandb.login()
87 | wandb.init(project="ThinkPO-SFT", name=wandb_name)
88 |
89 | training_args = TrainingArguments(
90 | save_only_model=True,
91 | output_dir=set_global(f"./train/models/sft/{wandb_name}"),
92 | per_device_train_batch_size=1,
93 | gradient_accumulation_steps=args.gradient_accumulation_steps,
94 | save_total_limit=1,
95 | num_train_epochs=args.epoch,
96 | learning_rate=lr,
97 | gradient_checkpointing=True,
98 | lr_scheduler_type='cosine',
99 | warmup_ratio=0.1,
100 | logging_steps=10,
101 | save_strategy='epoch',
102 | report_to="wandb",
103 | bf16=True,
104 | # dataloader_num_workers=4,
105 | deepspeed= set_global(args.deepspeed) if args.deepspeed is not None else None,
106 | )
107 |
108 | trainer = Trainer(
109 | model=model,
110 | args=training_args,
111 | tokenizer=tokenizer,
112 | train_dataset=train_dataset,
113 | data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)
114 | )
115 | trainer.train()
116 |
--------------------------------------------------------------------------------
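
A tiny worked example (made-up token ids, purely illustrative) of the label masking done in preprocess_function above:

    prompt_ids   = [101, 7, 8, 9]    # tokenized chat-template prompt (system + user turn)
    response_ids = [20, 21, 22]      # tokenized target answer
    eos_id = 2

    input_ids = prompt_ids + response_ids + [eos_id]
    labels    = [-100] * len(prompt_ids) + response_ids + [eos_id]
    # input_ids -> [101, 7, 8, 9, 20, 21, 22, 2]
    # labels    -> [-100, -100, -100, -100, 20, 21, 22, 2]
    # Positions labeled -100 are ignored by the cross-entropy loss, so the model is
    # trained only on the response tokens (plus EOS), never on the prompt.
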
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.path.abspath(os.path.join(__file__, "../../")))
--------------------------------------------------------------------------------
/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | from .utils import cache_dir
3 | import os
4 | import json
5 |
6 | def save_results(out_path, data):
7 | if os.path.exists(out_path):
8 | with open(out_path, "a", encoding="utf-8") as f:
9 | json.dump(data, f)
10 | f.write('\n')
11 | else:
12 | save_all_results(out_path, [data])
13 |
14 |
15 | def save_all_results(out_path, data):
16 | os.makedirs(os.path.dirname(out_path), exist_ok=True)
17 | with open(out_path, "w", encoding="utf-8") as f:
18 | for pred in data:
19 | json.dump(pred, f, ensure_ascii=False)
20 | f.write('\n')
21 |
22 | def read_saved_results(out_path):
23 | preds=[]
24 | if os.path.exists(out_path):
25 | with open(out_path, "r", encoding="utf-8") as f:
26 | i=0
27 | for line in f:
28 | l = json.loads(line)
29 | preds.append(l)
30 | i+=1
31 |         print(f"loaded {i} saved results from {out_path}")
32 | return preds
33 |
34 |
35 | def load_data(dataset_name, mode='json'):
36 | if mode == 'huggingface':
37 | dataset = load_dataset(dataset_name)
38 | if mode == 'json':
39 | dataset = load_dataset('json', data_files=dataset_name)
40 | return dataset
41 |
--------------------------------------------------------------------------------
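
A round-trip sketch for the JSONL helpers above (hypothetical output path; assumes the repository root is on sys.path and a valid hug_token is set in utils/settings.py, since importing the utils package triggers a Hugging Face login):

    from utils.data_utils import save_all_results, save_results, read_saved_results

    out = "results/demo/preds.jsonl"
    save_all_results(out, [{"idx": 0, "pred": "42"}])   # write a full list, one JSON object per line
    save_results(out, {"idx": 1, "pred": "7"})          # append a single record
    preds = read_saved_results(out)                     # [{'idx': 0, ...}, {'idx': 1, ...}]
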
/utils/eval/eval_utils.py:
--------------------------------------------------------------------------------
1 | from .qwen_math_parser import *
2 |
3 |
4 | def check_math_correctness(reference, generation):
5 | if not find_box(generation) or not find_box(reference): return False
6 | ref = extract_answer(reference)
7 | answer = strip_answer_string(ref)
8 | pred = extract_answer(generation)
9 | pred = strip_answer_string(pred)
10 | return math_equal(pred, answer)
11 |
12 | def check_math_correctness_with_model(reference, generation):
13 | pass
--------------------------------------------------------------------------------
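
A small sketch of the boxed-answer check above (assumes the repository root is on sys.path and the parser's dependencies such as sympy, latex2sympy2 and word2number are installed):

    from utils.eval.eval_utils import check_math_correctness

    ref = "So the answer is \\boxed{\\frac{1}{2}}."
    gen = "... therefore the final result is \\boxed{0.5}."
    print(check_math_correctness(ref, gen))                        # True: 0.5 and \frac{1}{2} compare equal
    print(check_math_correctness(ref, "So we get \\boxed{0.6}."))  # False: 0.6 != 1/2
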
/utils/eval/qwen_math_parser.py:
--------------------------------------------------------------------------------
1 | """
2 | The logic in this file largely borrows from Qwen2.5-Math codebase at https://github.com/QwenLM/Qwen2.5-Math:
3 | """
4 |
5 | import re
6 | import regex
7 | from word2number import w2n
8 | from math import isclose
9 | from collections import defaultdict
10 |
11 | from sympy import simplify, N
12 | from sympy.parsing.sympy_parser import parse_expr
13 | from sympy.parsing.latex import parse_latex
14 | from latex2sympy2 import latex2sympy
15 |
16 |
17 | def convert_word_number(text: str) -> str:
18 | try:
19 | text = str(w2n.word_to_num(text))
20 | except:
21 | pass
22 | return text
23 |
24 | def _fix_fracs(string):
25 | substrs = string.split("\\frac")
26 | new_str = substrs[0]
27 | if len(substrs) > 1:
28 | substrs = substrs[1:]
29 | for substr in substrs:
30 | new_str += "\\frac"
31 | if len(substr) > 0 and substr[0] == "{":
32 | new_str += substr
33 | else:
34 | try:
35 | assert len(substr) >= 2
36 | except:
37 | return string
38 | a = substr[0]
39 | b = substr[1]
40 | if b != "{":
41 | if len(substr) > 2:
42 | post_substr = substr[2:]
43 | new_str += "{" + a + "}{" + b + "}" + post_substr
44 | else:
45 | new_str += "{" + a + "}{" + b + "}"
46 | else:
47 | if len(substr) > 2:
48 | post_substr = substr[2:]
49 | new_str += "{" + a + "}" + b + post_substr
50 | else:
51 | new_str += "{" + a + "}" + b
52 | string = new_str
53 | return string
54 |
55 |
56 | def _fix_a_slash_b(string):
57 | if len(string.split("/")) != 2:
58 | return string
59 | a = string.split("/")[0]
60 | b = string.split("/")[1]
61 | try:
62 | if "sqrt" not in a:
63 | a = int(a)
64 | if "sqrt" not in b:
65 | b = int(b)
66 | assert string == "{}/{}".format(a, b)
67 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
68 | return new_string
69 | except:
70 | return string
71 |
72 |
73 | def _fix_sqrt(string):
74 | _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
75 | return _string
76 |
77 | def strip_answer_string(string):
78 | string = str(string).strip()
79 | # linebreaks
80 | string = string.replace("\n", "")
81 |
82 | # right "."
83 | string = string.rstrip(".")
84 |
85 | # remove inverse spaces
86 | # replace \\ with \
87 | string = string.replace("\\!", "")
88 | # string = string.replace("\\ ", "")
89 | # string = string.replace("\\\\", "\\")
90 |
91 | # matrix
92 | string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string)
93 | string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string)
94 | string = string.replace("bmatrix", "pmatrix")
95 |
96 | # replace tfrac and dfrac with frac
97 | string = string.replace("tfrac", "frac")
98 | string = string.replace("dfrac", "frac")
99 | string = (
100 | string.replace("\\neq", "\\ne")
101 | .replace("\\leq", "\\le")
102 | .replace("\\geq", "\\ge")
103 | )
104 |
105 | # remove \left and \right
106 | string = string.replace("\\left", "")
107 | string = string.replace("\\right", "")
108 | string = string.replace("\\{", "{")
109 | string = string.replace("\\}", "}")
110 |
111 | # Function to replace number words with corresponding digits
112 | def replace_match(match):
113 | word = match.group(1).lower()
114 | if convert_word_number(word) == word:
115 | return match.group(0)
116 | else:
117 | return convert_word_number(word)
118 | string = re.sub(r"\\text\{([a-zA-Z]+)\}", replace_match, string)
119 |
120 | # Before removing unit, check if the unit is squared (for surface area)
121 | string = re.sub(r"(cm|inches)\}\^2", r"\1}", string)
122 |
123 | # Remove unit: miles, dollars if after is not none
124 | _string = re.sub(r"\\text{.*?}$", "", string).strip()
125 | if _string != "" and _string != string:
126 | # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
127 | string = _string
128 |
129 | # Remove circ (degrees)
130 | string = string.replace("^{\\circ}", "")
131 | string = string.replace("^\\circ", "")
132 |
133 | # remove dollar signs
134 | string = string.replace("\\$", "")
135 | string = string.replace("$", "")
136 | string = string.replace("\\(", "").replace("\\)", "")
137 |
138 | # convert word number to digit
139 | string = convert_word_number(string)
140 |
141 | # replace "\\text{...}" to "..."
142 | string = re.sub(r"\\text\{(.*?)\}", r"\1", string)
143 | for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]:
144 | string = string.replace(key, "")
145 | string = string.replace("\\emptyset", r"{}")
146 | string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}")
147 |
148 | # remove percentage
149 | string = string.replace("\\%", "")
150 | string = string.replace("\%", "")
151 | string = string.replace("%", "")
152 |
153 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
154 | string = string.replace(" .", " 0.")
155 | string = string.replace("{.", "{0.")
156 |
157 | # cdot
158 | # string = string.replace("\\cdot", "")
159 | if (
160 | string.startswith("{")
161 | and string.endswith("}")
162 | and string.isalnum()
163 | or string.startswith("(")
164 | and string.endswith(")")
165 | and string.isalnum()
166 | or string.startswith("[")
167 | and string.endswith("]")
168 | and string.isalnum()
169 | ):
170 | string = string[1:-1]
171 |
172 | # inf
173 | string = string.replace("infinity", "\\infty")
174 | if "\\infty" not in string:
175 | string = string.replace("inf", "\\infty")
176 | string = string.replace("+\\inity", "\\infty")
177 |
178 | # and
179 | string = string.replace("and", "")
180 | string = string.replace("\\mathbf", "")
181 |
182 | # use regex to remove \mbox{...}
183 | string = re.sub(r"\\mbox{.*?}", "", string)
184 |
185 | # quote
186 | string.replace("'", "")
187 | string.replace('"', "")
188 |
189 | # i, j
190 | if "j" in string and "i" not in string:
191 | string = string.replace("j", "i")
192 |
193 | # replace a.000b where b is not number or b is end, with ab, use regex
194 | string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string)
195 | string = re.sub(r"(\d+)\.0*$", r"\1", string)
196 |
197 | # if empty, return empty string
198 | if len(string) == 0:
199 | return string
200 | if string[0] == ".":
201 | string = "0" + string
202 |
203 | # to consider: get rid of e.g. "k = " or "q = " at beginning
204 | if len(string.split("=")) == 2:
205 | if len(string.split("=")[0]) <= 2:
206 | string = string.split("=")[1]
207 |
208 | string = _fix_sqrt(string)
209 | string = string.replace(" ", "")
210 |
211 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
212 | string = _fix_fracs(string)
213 |
214 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
215 | string = _fix_a_slash_b(string)
216 |
217 | # Remove unnecessary '\' before integers
218 | string = re.sub(r"\\(?=\-?\d+(\\|\)|,|\]|$))", "", string)
219 |
220 | # Remove grade level (e.g., 12th grade) and just maintain the integer
221 | string = re.sub(r"thgrade$", "", string)
222 |
223 | # If the answer is a list of integers (without parenthesis), sort them
224 | if re.fullmatch(r"(\s*-?\d+\s*,)*\s*-?\d+\s*", string):
225 | if ',' in string:
226 | # Split the string into a list of integers
227 | integer_list = list(map(int, string.split(',')))
228 |
229 | # Sort the list in ascending order
230 | sorted_list = sorted(integer_list)
231 |
232 | # Join the sorted list back into a comma-separated string
233 | string = ','.join(map(str, sorted_list))
234 |
235 | return string
236 |
237 | def extract_answer(pred_str, use_last_number=True):
238 | pred_str = pred_str.replace("\u043a\u0438", "")
239 | if "final answer is $" in pred_str and "$. I hope" in pred_str:
240 | # minerva_math
241 | tmp = pred_str.split("final answer is $", 1)[1]
242 | pred = tmp.split("$. I hope", 1)[0].strip()
243 | elif "boxed" in pred_str:
244 | ans = pred_str.split("boxed")[-1]
245 | if len(ans) == 0:
246 | return ""
247 | elif ans[0] == "{":
248 | stack = 1
249 | a = ""
250 | for c in ans[1:]:
251 | if c == "{":
252 | stack += 1
253 | a += c
254 | elif c == "}":
255 | stack -= 1
256 | if stack == 0:
257 | break
258 | a += c
259 | else:
260 | a += c
261 | else:
262 | a = ans.split("$")[0].strip()
263 | pred = a
264 | elif "he answer is" in pred_str:
265 | pred = pred_str.split("he answer is")[-1].strip()
266 | elif "final answer is" in pred_str:
267 | pred = pred_str.split("final answer is")[-1].strip()
268 | elif "答案是" in pred_str:
269 | # Handle Chinese few-shot multiple choice problem answer extraction
270 | pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
271 | else: # use the last number
272 | if use_last_number:
273 |             pattern = r"-?\d*\.?\d+"
274 | pred = re.findall(pattern, pred_str.replace(",", ""))
275 | if len(pred) >= 1:
276 | pred = pred[-1]
277 | else:
278 | pred = ""
279 | else:
280 | pred = ""
281 |
282 | # multiple line
283 | # pred = pred.split("\n")[0]
284 | pred = re.sub(r"\n\s*", "", pred)
285 | if pred != "" and pred[0] == ":":
286 | pred = pred[1:]
287 | if pred != "" and pred[-1] == ".":
288 | pred = pred[:-1]
289 | if pred != "" and pred[-1] == "/":
290 | pred = pred[:-1]
291 | pred = strip_answer_string(pred)
292 | return pred
293 |
294 | def get_multiple_choice_answer(pred: str):
295 | tmp = re.findall(r"\b(A|B|C|D)\b", pred.upper())
296 | if tmp:
297 | pred = tmp
298 | else:
299 | pred = [pred.strip().strip(".")]
300 |
301 | if len(pred) == 0:
302 | pred = ""
303 | else:
304 | pred = pred[-1]
305 |
306 | # Remove the period at the end, again!
307 | pred = pred.rstrip(".").rstrip("/")
308 |
309 | return pred
310 |
311 | def choice_answer_clean(pred: str):
312 | pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")
313 | # Clean the answer based on the dataset
314 | tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
315 | if tmp:
316 | pred = tmp
317 | else:
318 | pred = [pred.strip().strip(".")]
319 | pred = pred[-1]
320 | # Remove the period at the end, again!
321 | pred = pred.rstrip(".").rstrip("/")
322 | return pred
323 |
324 |
325 | def parse_digits(num):
326 | num = regex.sub(",", "", str(num))
327 | try:
328 | return float(num)
329 | except:
330 | if num.endswith("%"):
331 | num = num[:-1]
332 | if num.endswith("\\"):
333 | num = num[:-1]
334 | try:
335 | return float(num) / 100
336 | except:
337 | pass
338 | return None
339 |
340 |
341 | def is_digit(num):
342 | # paired with parse_digits
343 | return parse_digits(num) is not None
344 |
345 |
346 | def str_to_pmatrix(input_str):
347 | input_str = input_str.strip()
348 | matrix_str = re.findall(r"\{.*,.*\}", input_str)
349 | pmatrix_list = []
350 |
351 | for m in matrix_str:
352 | m = m.strip("{}")
353 | pmatrix = r"\begin{pmatrix}" + m.replace(",", "\\") + r"\end{pmatrix}"
354 | pmatrix_list.append(pmatrix)
355 |
356 | return ", ".join(pmatrix_list)
357 |
358 |
359 | def math_equal(
360 | prediction,
361 | reference,
362 | include_percentage: bool = True,
363 | is_close: bool = True,
364 | timeout: bool = False,
365 | ) -> bool:
366 | """
367 | Exact match of math if and only if:
368 | 1. numerical equal: both can convert to float and are equal
369 | 2. symbolic equal: both can convert to sympy expression and are equal
370 | """
371 | if prediction is None or reference is None:
372 | return False
373 | if str(prediction.strip().lower()) == str(reference.strip().lower()):
374 | return True
375 | if (
376 | reference in ["A", "B", "C", "D", "E"]
377 | and choice_answer_clean(prediction) == reference
378 | ):
379 | return True
380 |
381 | try: # 1. numerical equal
382 | if is_digit(prediction) and is_digit(reference):
383 | prediction = parse_digits(prediction)
384 | reference = parse_digits(reference)
385 | # number questions
386 | if include_percentage:
387 | gt_result = [reference / 100, reference, reference * 100]
388 | else:
389 | gt_result = [reference]
390 | for item in gt_result:
391 | try:
392 | if is_close:
393 | if numeric_equal(prediction, item):
394 | return True
395 | else:
396 | if item == prediction:
397 | return True
398 | except Exception:
399 | continue
400 | return False
401 | except:
402 | pass
403 |
404 | if not prediction and prediction not in [0, False]:
405 | return False
406 |
407 | # 2. symbolic equal
408 | reference = str(reference).strip()
409 | prediction = str(prediction).strip()
410 |
411 | ## pmatrix (amps)
412 | if "pmatrix" in prediction and not "pmatrix" in reference:
413 | reference = str_to_pmatrix(reference)
414 |
415 | ## deal with [], (), {}
416 | pred_str, ref_str = prediction, reference
417 | if (
418 | prediction.startswith("[")
419 | and prediction.endswith("]")
420 | and not reference.startswith("(")
421 | ) or (
422 | prediction.startswith("(")
423 | and prediction.endswith(")")
424 | and not reference.startswith("[")
425 | ):
426 | pred_str = pred_str.strip("[]()")
427 | ref_str = ref_str.strip("[]()")
428 | for s in ["{", "}", "(", ")"]:
429 | ref_str = ref_str.replace(s, "")
430 | pred_str = pred_str.replace(s, "")
431 | if pred_str.lower() == ref_str.lower():
432 | return True
433 |
434 | ## [a, b] vs. [c, d], return a==c and b==d
435 | if (
436 | regex.match(r"(\(|\[).+(\)|\])", prediction) is not None
437 | and regex.match(r"(\(|\[).+(\)|\])", reference) is not None
438 | ):
439 | pred_parts = prediction[1:-1].split(",")
440 | ref_parts = reference[1:-1].split(",")
441 | if len(pred_parts) == len(ref_parts):
442 | if all(
443 | [
444 | math_equal(
445 | pred_parts[i], ref_parts[i], include_percentage, is_close
446 | )
447 | for i in range(len(pred_parts))
448 | ]
449 | ):
450 | return True
451 | if (
452 | (
453 | prediction.startswith("\\begin{pmatrix}")
454 | or prediction.startswith("\\begin{bmatrix}")
455 | )
456 | and (
457 | prediction.endswith("\\end{pmatrix}")
458 | or prediction.endswith("\\end{bmatrix}")
459 | )
460 | and (
461 | reference.startswith("\\begin{pmatrix}")
462 | or reference.startswith("\\begin{bmatrix}")
463 | )
464 | and (
465 | reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")
466 | )
467 | ):
468 | pred_lines = [
469 | line.strip()
470 | for line in prediction[
471 | len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
472 | ].split("\\\\")
473 | if line.strip()
474 | ]
475 | ref_lines = [
476 | line.strip()
477 | for line in reference[
478 | len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
479 | ].split("\\\\")
480 | if line.strip()
481 | ]
482 | matched = True
483 | if len(pred_lines) == len(ref_lines):
484 | for pred_line, ref_line in zip(pred_lines, ref_lines):
485 | pred_parts = pred_line.split("&")
486 | ref_parts = ref_line.split("&")
487 | if len(pred_parts) == len(ref_parts):
488 | if not all(
489 | [
490 | math_equal(
491 | pred_parts[i],
492 | ref_parts[i],
493 | include_percentage,
494 | is_close,
495 | )
496 | for i in range(len(pred_parts))
497 | ]
498 | ):
499 | matched = False
500 | break
501 | else:
502 | matched = False
503 | if not matched:
504 | break
505 | else:
506 | matched = False
507 | if matched:
508 | return True
509 |
510 | if prediction.count("=") == 1 and reference.count("=") == 1:
511 | pred = prediction.split("=")
512 | pred = f"{pred[0].strip()} - ({pred[1].strip()})"
513 | ref = reference.split("=")
514 | ref = f"{ref[0].strip()} - ({ref[1].strip()})"
515 | if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
516 | return True
517 | elif (
518 | prediction.count("=") == 1
519 | and len(prediction.split("=")[0].strip()) <= 2
520 | and "=" not in reference
521 | ):
522 | if math_equal(
523 | prediction.split("=")[1], reference, include_percentage, is_close
524 | ):
525 | return True
526 | elif (
527 | reference.count("=") == 1
528 | and len(reference.split("=")[0].strip()) <= 2
529 | and "=" not in prediction
530 | ):
531 | if math_equal(
532 | prediction, reference.split("=")[1], include_percentage, is_close
533 | ):
534 | return True
535 |
536 | if symbolic_equal(prediction, reference):
537 | return True
538 |
539 | return False
540 |
541 |
542 | def numeric_equal(prediction: float, reference: float):
543 | return isclose(reference, prediction, rel_tol=1e-4)
544 |
545 |
546 | def symbolic_equal(a, b):
547 | def _parse(s):
548 | for f in [parse_latex, parse_expr, latex2sympy]:
549 | try:
550 | return f(s.replace("\\\\", "\\"))
551 | except:
552 | try:
553 | return f(s)
554 | except:
555 | pass
556 | return s
557 |
558 | a = _parse(a)
559 | b = _parse(b)
560 |
561 | # direct equal
562 | try:
563 | if str(a) == str(b) or a == b:
564 | return True
565 | except:
566 | pass
567 |
568 | # simplify equal
569 | try:
570 | if a.equals(b) or simplify(a - b) == 0:
571 | return True
572 | except:
573 | pass
574 |
575 | # equation equal
576 | try:
577 | if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)):
578 | return True
579 | except:
580 | pass
581 |
582 | try:
583 | if numeric_equal(float(N(a)), float(N(b))):
584 | return True
585 | except:
586 | pass
587 |
588 | # matrix
589 | try:
590 | # if a and b are matrix
591 | if a.shape == b.shape:
592 | _a = a.applyfunc(lambda x: round(x, 3))
593 | _b = b.applyfunc(lambda x: round(x, 3))
594 | if _a.equals(_b):
595 | return True
596 | except:
597 | pass
598 |
599 | return False
600 |
601 | def find_box(pred_str: str):
602 | ans = pred_str.split("boxed")[-1]
603 | if not ans:
604 | return ""
605 | if ans[0] == "{":
606 | stack = 1
607 | a = ""
608 | for c in ans[1:]:
609 | if c == "{":
610 | stack += 1
611 | a += c
612 | elif c == "}":
613 | stack -= 1
614 | if stack == 0:
615 | break
616 | a += c
617 | else:
618 | a += c
619 | else:
620 | a = ans.split("$")[0].strip()
621 | return a
--------------------------------------------------------------------------------
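
A few illustrative calls for the parser above (same installation assumptions as for eval_utils.py):

    from utils.eval.qwen_math_parser import extract_answer, math_equal

    print(extract_answer("... so the result is \\boxed{\\frac{3}{4}}."))  # \frac{3}{4} (handles nested braces)
    print(extract_answer("We compute 12 * 3 = 36, hence 36."))            # 36 (falls back to the last number)
    print(math_equal("\\frac{3}{4}", "0.75"))                             # True (symbolic comparison via sympy)
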
/utils/load_model.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 | from transformers import AutoTokenizer, AutoModelForCausalLM
3 | from vllm import LLM, SamplingParams
4 |
5 | def count_parameters(model):
6 | return sum(p.numel() for name, p in model.named_parameters() if p.requires_grad)
7 | # return sum(p.numel() for p in model.parameters() if p.requires_grad)
8 | def check_grad_update_status(model):
9 | return None
10 |
11 | def init_model(model_name='', check=True, eval=True, low_cpu_mem_usage=True,):
12 | if eval:
13 | model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto",
14 | low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch.bfloat16,
15 | use_flash_attention_2=True)
16 | else:
17 | model = AutoModelForCausalLM.from_pretrained(model_name, use_cache=False, low_cpu_mem_usage=low_cpu_mem_usage,
18 | torch_dtype=torch.bfloat16)
19 | tokenizer = AutoTokenizer.from_pretrained(model_name,)
20 | if check:
21 | check_grad_update_status(model)
22 | return model, tokenizer
23 |
24 | def load_tokenizer(model_name):
25 | tokenizer = AutoTokenizer.from_pretrained(model_name,)
26 | return tokenizer
27 |
28 | def init_vllm_model(model_name, gpus=1):
29 | llm = LLM(model=model_name, tensor_parallel_size=gpus)
30 | return llm
31 |
32 | def llm_generate(model, tokenizer, prompts:list, max_length=256, temperature=0.7, top_p=0.9):
33 | inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
34 | inputs = {k: v.to(model.device) for k, v in inputs.items()}
35 | with torch.no_grad():
36 | outputs = model.generate(
37 | **inputs,
38 | max_length=max_length,
39 | temperature=temperature,
40 | top_p=top_p,
41 | # num_beams=4,
42 | # do_sample=False,
43 | repetition_penalty=1.1,
44 | pad_token_id=tokenizer.pad_token_id,
45 | eos_token_id=tokenizer.eos_token_id
46 | )
47 | decoded = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
48 | for i, prompt in enumerate(prompts):
49 | decoded[i] = decoded[i].replace(prompt, '').strip()
50 | return decoded
51 |
52 | def vllm_generate(llm, prompts:list, max_length=256, temperature=0.7, top_p=0.9):
53 | sampling_params = SamplingParams(temperature=temperature,top_p=top_p, max_tokens=max_length)
54 | outputs = llm.generate(prompts, sampling_params)
55 | return [output.outputs[0].text for output in outputs]
56 |
57 | def add_template(model_name, prompt, tokenizer=None):
58 | if tokenizer is None: tokenizer = load_tokenizer(model_name)
59 | if 'QwQ' in model_name:
60 | messages = [
61 | {"role": "system", "content": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step. Return your final response within \\boxed{{}}"},
62 | {"role": "user", "content": prompt}
63 | ]
64 | prompt = tokenizer.apply_chat_template(
65 | messages,
66 | tokenize=False,
67 | add_generation_prompt=True
68 | )
69 | elif 'Sky-T1-32B-Preview' in model_name:
70 | d = {'prompt': "Your role as an assistant involves thoroughly exploring questions through a systematic long \
71 | thinking process before providing the final precise and accurate solutions. This requires \
72 | engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, \
73 | backtracing, and iteration to develop well-considered thinking process. \
74 | Please structure your response into two main sections: Thought and Solution. \
75 | In the Thought section, detail your reasoning process using the specified format: \
76 | <|begin_of_thought|> {thought with steps separated with '\n\n'} \
77 | <|end_of_thought|> \
78 | Each step should include detailed considerations such as analisying questions, summarizing \
79 | relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining \
80 | any errors, and revisiting previous steps. \
81 | In the Solution section, based on various attempts, explorations, and reflections from the Thought \
82 | section, systematically present the final solution that you deem correct. The solution should \
83 | remain a logical, accurate, concise expression style and detail necessary step needed to reach the \
84 | conclusion, formatted as follows: \
85 | <|begin_of_solution|> \
86 | {final formatted, precise, and clear solution} \
87 | <|end_of_solution|> \
88 | Now, try to solve the following question through the above guidelines: "}
89 | prompt = d["prompt"] + prompt
90 | elif 'sky' in model_name:
91 | prompt = 'You are a helpful and harmless assistant. You should solve this math problem using step-by-step reasoning. Require that the output of each step ends with the "\n\n" token. Return your final response with \\[ \\boxed{} \\], if answer is "1/2", the output is \\[ \\boxed{1/2} \\]. '+prompt
92 | else: # elif 'math' in model_name:
93 | # 'follow this structure to step-by-step solve math problem, and final answer must formated with \\[ \\boxed{} \\], if answer is "1/2", the output is \\[ \\boxed{1/2} \\]'
94 | # 'solve this math problem using step-by-step reasoning. Require that the output of each step ends with the "\n\n" token.'
95 | prompt = "You are a helpful and harmless assistant. You should think step-by-step. Return your final response within \\boxed{{}}. "+prompt
96 | return prompt
--------------------------------------------------------------------------------
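
A sketch of the vLLM inference path above (assumes a GPU with vllm installed and a valid hug_token in utils/settings.py, since importing the utils package triggers a Hugging Face login):

    from utils.load_model import init_vllm_model, vllm_generate, add_template

    model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
    prompt = add_template(model_name, "What is 2 + 2?")   # falls through to the default step-by-step boxed prompt
    llm = init_vllm_model(model_name, gpus=1)
    outputs = vllm_generate(llm, [prompt], max_length=1024, temperature=0.7)
    print(outputs[0])
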
/utils/settings.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 |
4 | project_dir = './'
5 | cache_dir = None
6 | hug_token = ''
7 | if cache_dir:
8 | os.environ["HF_HOME"] = cache_dir
9 |
10 |
11 |
12 | SYSTEM_PROMPT = {
13 | "Qwen/Qwen2-7B-Instruct": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.",
14 | "Qwen/QwQ-32B-Preview": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.",
15 | "Qwen/Qwen2.5-72B-Instruct": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.",
16 | "Qwen/Qwen2.5-32B-Instruct": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.",
17 | "Qwen/Qwen2.5-7B-Instruct": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.",
18 | "NovaSky-AI/Sky-T1-32B-Preview": "Your role as an assistant involves thoroughly exploring questions through a systematic long \
19 | thinking process before providing the final precise and accurate solutions. This requires \
20 | engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, \
21 | backtracing, and iteration to develop well-considered thinking process. \
22 | Please structure your response into two main sections: Thought and Solution. \
23 | In the Thought section, detail your reasoning process using the specified format: \
24 | <|begin_of_thought|> {thought with steps separated with '\n\n'} \
25 | <|end_of_thought|> \
26 | Each step should include detailed considerations such as analisying questions, summarizing \
27 | relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining \
28 | any errors, and revisiting previous steps. \
29 | In the Solution section, based on various attempts, explorations, and reflections from the Thought \
30 | section, systematically present the final solution that you deem correct. The solution should \
31 | remain a logical, accurate, concise expression style and detail necessary step needed to reach the \
32 | conclusion, formatted as follows: \
33 | <|begin_of_solution|> \
34 | {final formatted, precise, and clear solution} \
35 | <|end_of_solution|> \
36 | Now, try to solve the following question through the above guidelines:",
37 | "openai/o1-mini": "Question: {input}\nAnswer: ",
38 | "openai/o1-preview": "Question: {input}\nAnswer: ",
39 | "openai/gpt-4o-mini": "User: {input}\nPlease reason step by step, and put your final answer within \\boxed{{}}.\n\nAssistant:",
40 | }
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import random
4 | import torch
5 | from utils.settings import *
6 |
7 |
8 | def set_global(path):
9 | return os.path.join(project_dir, path)
10 |
11 | from huggingface_hub import login
12 | login(token=hug_token)
13 |
14 | api_keys = {
15 |     'openai': ''  # API key removed; set your own key here or load it from an environment variable
16 | }
17 |
18 |
19 | def seed_everything(seed):
20 | torch.manual_seed(seed)
21 | torch.cuda.manual_seed(seed)
22 | np.random.seed(seed)
23 | random.seed(seed)
24 | torch.backends.cudnn.benchmark = False
25 | torch.backends.cudnn.deterministic = True
26 | torch.cuda.manual_seed_all(seed)
27 |
28 | def Logger(content):
29 | print(content)
--------------------------------------------------------------------------------