├── .gitignore ├── README.md ├── data_curation ├── __init__.py ├── bespoke_data.py ├── length_comparsion.py └── record.py ├── imgs └── Training_pipeline.png ├── model ├── LMConfig.py ├── dataset.py ├── minimind_tokenizer │ ├── merges.txt │ ├── tokenizer.json │ ├── tokenizer_config.json │ └── vocab.json └── model.py ├── tools ├── .gitattributes ├── README.md ├── __init__.py ├── base_instruct_evals.md ├── combine_data.py ├── convert_format.py ├── convert_to_data.py ├── eval.py ├── inference_and_check.py ├── label_math_difficulty.py ├── labeled_numina_difficulty │ └── README.md ├── requirements.txt ├── response_rewrite.py ├── upload_hub.py └── util │ ├── apps │ └── testing_util.py │ ├── common.py │ ├── livecodebench │ └── testing_util.py │ ├── math │ └── testing_util.py │ ├── model_utils.py │ ├── prompts.py │ ├── taco │ ├── pyext2.py │ └── testing_util.py │ └── task_handlers.py ├── train ├── __init__.py ├── deepseed │ ├── ds_z3_offload_config.json │ ├── zero2_config.json │ ├── zero3_config.json │ └── zero3_config2.json ├── dpo_train.py ├── names.py └── sft_train.py └── utils ├── __init__.py ├── data_utils.py ├── eval ├── eval_utils.py └── qwen_math_parser.py ├── load_model.py ├── model_utils.py ├── settings.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/* 2 | utils/upload.py 3 | results/ 4 | scripts/ 5 | tools/eval.sh 6 | tools/results 7 | train/sft_7b.sh 8 | train/wandb -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ThinkPO: Thinking Preference Optimization 2 | > a simple yet effective postSFT method that enhances long CoT reasoning without requiring new long CoT responses 3 | 4 |

5 | <img src="./imgs/Training_pipeline.png" alt="Training Pipeline"> 6 |

7 | 8 | --- 9 | # News 10 | - 2025-02-22: We released our ThinkPO dataset: [Bespoke_dpo_filter](https://huggingface.co/datasets/VanWang/Bespoke_dpo_filter) 11 | - 2025-02-21: We released our four models: 12 | - [DeepSeek-R1-Distill-Qwen-7B-ThinkPO](https://huggingface.co/VanWang/DeepSeek-R1-Distill-Qwen-7B-ThinkPO) 13 | - [Bespoke-Stratos-7B-ThinkPO](https://huggingface.co/VanWang/Bespoke-Stratos-7B-ThinkPO) 14 | - [Bespoke-Stratos-7B-repro-SFT](https://huggingface.co/VanWang/Bespoke-Stratos-7B-repro-SFT) 15 | - [Bespoke-Stratos-7B-repro-ThinkPO](https://huggingface.co/VanWang/Bespoke-Stratos-7B-repro-ThinkPO) 16 | 17 | - 2025-02-19: We released our [paper](https://arxiv.org/abs/2502.13173). 18 | 19 | --- 20 | # Introduction 21 | - Here, we show the results of open-source reasoning LLMs before and after ThinkPO. 22 | ## Accuracy 23 | 24 | | Models | Dataset | SFT | Ours (+ThinkPO) | Improv. (%) | 25 | |:--------:|:--------:|:--------:|:--------:|:--------:| 26 | |DeepSeek-R1-Distill-Qwen-7B (Deepseek) |MATH500 | 87.4 | 91.2 | 4.3% | 27 | || AIME | 56.7 | 43.3 | -23.6% | 28 | || GPQA | 47.0 | 49.5 | 5.3% | 29 | || GSM8K | 87.2 | 87.6 | 0.5% | 30 | || Olympiad | 58.6 | 58.6 | 0.0% | 31 | |Bespoke-Stratos-7B (Bespoke)| MATH500 | 84.0 | 82.8 | -1.4% | 32 | || AIME | 20.0 | 23.3 | 16.5% | 33 | || GPQA | 37.9 | 43.4 | 14.5% | 34 | || GSM8K | 92.9 | 93.3 | 0.4% | 35 | || Olympiad | 44.1 | 48.5 | 10.0% | 36 | 37 | ## Average Response Length 38 | 39 | | Model | Dataset | SFT | Ours (+ThinkPO) | Improv. (%) | 40 | |:--------:|:--------:|:--------:|:--------:|:--------:| 41 | |DeepSeek-R1-Distill-Qwen-7B (Deepseek) | MATH500 | 2577 | 3021 | 17.2% | 42 | || AIME | 11419 | 12875 | 12.8% | 43 | || GPQA | 4895 | 5604 | 14.5% | 44 | || GSM8K | 619 | 668 | 7.9% | 45 | || Olympiad | 7196 | 7383 | 2.6% | 46 | |Bespoke-Stratos-7B (Bespoke)| MATH500 | 5696 | 6404 | 12.4% | 47 | || AIME | 19858 | 20079 | 1.1% | 48 | || GPQA | 5968 | 7301 | 22.3% | 49 | || GSM8K | 1404 | 1755 | 25.0% | 50 | || Olympiad | 11140 | 12204 | 9.6% | 51 | 52 | --- 53 | # Quick Use 54 | ## Settting 55 | - in ./utils/settings.py, you could set your project path, huggingface cache path and token 56 | ```python 57 | project_dir = 'path to your project' 58 | cache_dir = 'path to huggingface cache' 59 | hug_token = 'your huggingface token' 60 | ``` 61 | 62 | ## SFT Train 63 | - if you wanna use multi-gpus to train Qwen2.5-7B-Instruct with SFT, you could use the following command: 64 | ```shell 65 | cd train 66 | deepspeed sft_train.py --model_name Instruct-7b --gradient_accumulation_steps 16 --dataset_name Bespoke --epoch 3 --lr 1e-5 --deepspeed ./deepseed/zero3_config2.json 67 | ``` 68 | 69 | ## ThinkPO Train 70 | - if you wanna use multi-gpus to train Qwen2.5-7B-Instruct with ThinkPO, you could use the following command: 71 | ```shell 72 | cd train 73 | deepspeed dpo_train.py --lr 3e-7 --beta 0.01 --model Bespoke-7b --dataset Bespoke_dpo --gradient_accumulation_steps 12 --deepspeed ./deepseed/zero3_config2.json 74 | ``` 75 | 76 | ## eval the model 77 | - The LLM Reasoning Evaluation refers to [Sky-Thought](https://github.com/NovaSky-AI/SkyThought/tree/main) 78 | - you could use the following command to evaluate the model, like datasets MATH500,AIME,GPQADiamond,GSM8K,OlympiadBenchMath 79 | ```shell 80 | cd ./tools 81 | python ./eval.py \ 82 | --model deepseek-ai/DeepSeek-R1-Distill-Qwen-7B \ 83 | --evals MATH500,AIME,GPQADiamond,GSM8K,OlympiadBenchMath \ 84 | --tp 2 --output_file ./results/eval/DeepSeek-R1-Distill-Qwen-7B.txt \ 85 | 
--result_dir ./results/generated 86 | ``` 87 | 88 | ## citation 89 | ```bibtex 90 | @misc{yang2025thinkingpreferenceoptimization, 91 | title={Thinking Preference Optimization}, 92 | author={Wang Yang and Hongye Jin and Jingfeng Yang and Vipin Chaudhary and Xiaotian Han}, 93 | year={2025}, 94 | eprint={2502.13173}, 95 | archivePrefix={arXiv}, 96 | primaryClass={cs.LG}, 97 | url={https://arxiv.org/abs/2502.13173}, 98 | } 99 | ``` 100 | -------------------------------------------------------------------------------- /data_curation/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.abspath(os.path.join(__file__, "../../"))) -------------------------------------------------------------------------------- /data_curation/bespoke_data.py: -------------------------------------------------------------------------------- 1 | import __init__ 2 | from utils.utils import * 3 | from utils.data_utils import read_saved_results, save_results 4 | from utils.load_model import * 5 | from utils.data_utils import * 6 | from utils.eval.eval_utils import check_math_correctness 7 | from tqdm import tqdm 8 | import copy 9 | 10 | 11 | def generate_is_ok(predict, refrence): 12 | if not check_math_correctness(refrence, predict): return False 13 | return True 14 | 15 | MAX_LENGTH = 1024 * 1 16 | 17 | save_file = set_global(f'./data/final/Bespoke_dpo.jsonl') 18 | save_data = read_saved_results(save_file) 19 | ref_data = load_data('bespokelabs/Bespoke-Stratos-17k', 'huggingface')['train'] 20 | 21 | 22 | tokenizer = load_tokenizer("NovaSky-AI/Sky-T1-32B-Preview") 23 | llm = init_vllm_model("Qwen/Qwen2.5-Math-7B-Instruct", 1) 24 | qwwn_tokenizer = load_tokenizer("Qwen/Qwen2.5-Math-7B-Instruct") 25 | 26 | for data in tqdm(ref_data): 27 | chosen = data['conversations'][1]['value'] 28 | messages = [ 29 | {"role": "system", "content": "You are a helpful and harmless assistant. 
Please reason step by step."}, 30 | {"role": "user", "content": data['conversations'][0]['value']} 31 | ] 32 | prompt = qwwn_tokenizer.apply_chat_template( 33 | messages, 34 | tokenize=False, 35 | add_generation_prompt=True 36 | ) 37 | rejected = vllm_generate(llm, prompt, MAX_LENGTH)[0] 38 | if len(rejected) == 0 or not generate_is_ok(rejected, chosen): 39 | continue 40 | rejected_ans = copy.deepcopy(data['conversations']) 41 | rejected_ans[1]['value'] = rejected 42 | save_data.append({ 43 | 'system': data['system'], 44 | 'chosen': data['conversations'], 45 | 'rejected': rejected_ans 46 | }) 47 | save_results(save_file, save_data[-1]) 48 | 49 | 50 | 51 | save_file = set_global(f'./data/final/Bespoke_dpo.jsonl') 52 | save_data = [] 53 | all_data = read_saved_results(save_file) 54 | -------------------------------------------------------------------------------- /data_curation/length_comparsion.py: -------------------------------------------------------------------------------- 1 | import __init__ 2 | from utils.utils import * 3 | from utils.data_utils import save_all_results, read_saved_results 4 | from utils.load_model import * 5 | from utils.data_utils import * 6 | from utils.eval.eval_utils import check_math_correctness 7 | from tqdm import tqdm 8 | 9 | 10 | def get_middle_1000(lst): 11 | n = len(lst) 12 | if n <= 1000: 13 | return lst 14 | start = (n - 1000) // 2 15 | return lst[start:start + 1000] 16 | 17 | def generate_is_ok(predict, refrence): 18 | if not check_math_correctness(refrence, predict): return False 19 | return True 20 | 21 | save_file = set_global(f'./data/final/Bespoke_dpo.jsonl') 22 | all_data = read_saved_results(save_file) 23 | tokenizer = load_tokenizer("NovaSky-AI/Sky-T1-32B-Preview") 24 | 25 | lengths, save_data = [], [] 26 | for d in tqdm(all_data): 27 | chosen, rejected = d['chosen'][2]['content'], d['rejected'][2]['content'] 28 | chosen_len, rejected_len = len(tokenizer.encode(chosen, add_special_tokens=True)), len(tokenizer.encode(rejected, add_special_tokens=True)) 29 | if rejected_len < 8*1024: 30 | d['dealta'] = rejected_len - chosen_len 31 | chosen, rejected = d['chosen'], d['rejected'] 32 | d['chosen'], d['rejected'] = rejected, chosen 33 | save_data.append(d) 34 | lengths.append(d['dealta']) 35 | 36 | all_data = sorted(save_data, key=lambda x: x["dealta"]) 37 | 38 | middle_1000 = get_middle_1000(all_data) 39 | save_all_results(set_global(f'./data/final/Bespoke_dpo_filter_len_middle.jsonl'), middle_1000) 40 | save_all_results(set_global(f'./data/final/Bespoke_dpo_filter_len_short.jsonl'), all_data[:1000]) 41 | save_all_results(set_global(f'./data/final/Bespoke_dpo_filter_len_long.jsonl'), all_data[-1000:]) 42 | -------------------------------------------------------------------------------- /data_curation/record.py: -------------------------------------------------------------------------------- 1 | import __init__ 2 | import os 3 | from utils.data_utils import * 4 | from utils.eval.eval_utils import find_box 5 | 6 | 7 | def count_total_occurrences(text, words): 8 | return sum(text.lower().count(word) for word in words) 9 | word_list = ["wait", "hmm"] 10 | 11 | folder_path = "/home/wxy320/ondemand/program/SkyThought/skythought/tools/results/steps" 12 | subdirs = [d for d in os.listdir(folder_path) if os.path.isdir(os.path.join(folder_path, d))] 13 | 14 | results = dict() 15 | for step in subdirs: 16 | step = os.path.join(folder_path, step) 17 | json_lens = dict() 18 | json_files = [f for f in os.listdir(step) if f.endswith(".json") and 
os.path.isfile(os.path.join(step, f))] 19 | for file_path in json_files: 20 | with open(os.path.join(step,file_path), "r", encoding="utf-8") as f: 21 | data = json.load(f) 22 | all_prompts_len, num, counts = 0, 0 ,0 23 | for prompt, d in data.items(): 24 | if '<|end_of_solution|>' in d['responses']['0.7']['content']: 25 | all_prompts_len = all_prompts_len + d['token_usages']['0.7']['completion_tokens'] 26 | num = num +1 27 | if d['responses']['0.7']['correctness']: 28 | counts =counts+count_total_occurrences(d['responses']['0.7']['content'], word_list) 29 | json_lens[file_path] = {'lengths':all_prompts_len / num, "wait":counts} 30 | results[step.split('/')[-1]] = json_lens 31 | 32 | 33 | with open('results.json', "w", encoding="utf-8") as f: 34 | json.dump(results, f, ensure_ascii=False, indent=4) -------------------------------------------------------------------------------- /imgs/Training_pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uservan/ThinkPO/23aa90e636e0beafabfa24b681f94216bbbaba9a/imgs/Training_pipeline.png -------------------------------------------------------------------------------- /model/LMConfig.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | from typing import List 3 | 4 | 5 | class LMConfig(PretrainedConfig): 6 | model_type = "minimind" 7 | 8 | def __init__( 9 | self, 10 | dim: int = 512, 11 | n_layers: int = 8, 12 | n_heads: int = 16, 13 | n_kv_heads: int = 8, 14 | vocab_size: int = 6400, 15 | hidden_dim: int = None, 16 | multiple_of: int = 64, 17 | norm_eps: float = 1e-5, 18 | max_seq_len: int = 512, 19 | dropout: float = 0.0, 20 | flash_attn: bool = True, 21 | #################################################### 22 | # Here are the specific configurations of MOE 23 | # When use_moe is false, the following is invalid 24 | #################################################### 25 | use_moe: bool = False, 26 | num_experts_per_tok=2, 27 | n_routed_experts=4, 28 | n_shared_experts: bool = True, 29 | scoring_func='softmax', 30 | aux_loss_alpha=0.01, 31 | seq_aux=True, 32 | norm_topk_prob=True, 33 | **kwargs, 34 | ): 35 | self.dim = dim 36 | self.n_layers = n_layers 37 | self.n_heads = n_heads 38 | self.n_kv_heads = n_kv_heads 39 | self.vocab_size = vocab_size 40 | self.hidden_dim = hidden_dim 41 | self.multiple_of = multiple_of 42 | self.norm_eps = norm_eps 43 | self.max_seq_len = max_seq_len 44 | self.dropout = dropout 45 | self.flash_attn = flash_attn 46 | #################################################### 47 | # Here are the specific configurations of MOE 48 | # When use_moe is false, the following is invalid 49 | #################################################### 50 | self.use_moe = use_moe 51 | self.num_experts_per_tok = num_experts_per_tok # 每个token选择的专家数量 52 | self.n_routed_experts = n_routed_experts # 总的专家数量 53 | self.n_shared_experts = n_shared_experts # 共享专家 54 | self.scoring_func = scoring_func # 评分函数,默认为'softmax' 55 | self.aux_loss_alpha = aux_loss_alpha # 辅助损失的alpha参数 56 | self.seq_aux = seq_aux # 是否在序列级别上计算辅助损失 57 | self.norm_topk_prob = norm_topk_prob # 是否标准化top-k概率 58 | super().__init__(**kwargs) 59 | -------------------------------------------------------------------------------- /model/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import re 4 | 5 | import pandas as pd 6 | import numpy as np 7 | from 
torch.utils.data import Dataset, DataLoader 8 | import torch 9 | from sklearn.model_selection import train_test_split 10 | import os 11 | 12 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 13 | 14 | 15 | class PretrainDataset(Dataset): 16 | def __init__(self, df, tokenizer, max_length=4096): 17 | super().__init__() 18 | self.df = df 19 | self.tokenizer = tokenizer 20 | self.max_length = max_length 21 | self.padding = 0 22 | 23 | def __len__(self): 24 | return self.df.shape[0] 25 | 26 | def __getitem__(self, index: int): 27 | # 28 | sample = self.df.iloc[index] 29 | text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}" 30 | input_id = self.tokenizer(text).data['input_ids'][:self.max_length] 31 | text_len = len(input_id) 32 | # 没满最大长度的剩余部分 33 | padding_len = self.max_length - text_len 34 | input_id = input_id + [self.padding] * padding_len 35 | # 0表示不计算损失 36 | loss_mask = [1] * text_len + [0] * padding_len 37 | 38 | input_id = np.array(input_id) 39 | X = np.array(input_id[:-1]).astype(np.int64) 40 | Y = np.array(input_id[1:]).astype(np.int64) 41 | loss_mask = np.array(loss_mask[1:]).astype(np.int64) 42 | return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask) 43 | 44 | class SFTDataset_List(Dataset): 45 | def __init__(self, list, tokenizer, max_seq_len): 46 | super().__init__() 47 | self.list = list 48 | self.tokenizer = tokenizer 49 | self.MAX_LENGTH = max_seq_len 50 | 51 | def __len__(self): 52 | return len(self.list) 53 | 54 | def __getitem__(self, index: int): 55 | example = self.list[index] 56 | input_ids, attention_mask, labels = [], [], [] 57 | inputs, targets = example["prompt"], example["response"] 58 | instruction = self.tokenizer(inputs+'\n', add_special_tokens=False) 59 | response = self.tokenizer(targets, add_special_tokens=False) 60 | input_ids = instruction["input_ids"] + response["input_ids"] 61 | attention_mask = instruction["attention_mask"] + response["attention_mask"] 62 | labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] 63 | if len(input_ids) > self.MAX_LENGTH: 64 | input_ids = input_ids[:self.MAX_LENGTH] 65 | attention_mask = attention_mask[:self.MAX_LENGTH] 66 | labels = labels[:self.MAX_LENGTH] 67 | # else: 68 | # padding_len = self.MAX_LENGTH - len(input_ids) 69 | # input_ids = input_ids + [self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token)] * padding_len 70 | # labels = labels + [-100]* padding_len 71 | # attention_mask = attention_mask + [0] * padding_len 72 | return torch.from_numpy(np.array(input_ids)), torch.from_numpy(np.array(labels)), torch.from_numpy(np.array(attention_mask)) 73 | 74 | class SFTDataset(Dataset): 75 | def __init__(self, df, tokenizer, max_length=1024, prompt_max_len=512, answer_max_len=256): 76 | super().__init__() 77 | self.df = df 78 | self.max_length = max_length 79 | self.prompt_max_len = prompt_max_len 80 | self.answer_max_len = answer_max_len 81 | # 82 | self.tokenizer = tokenizer 83 | self.padding = 0 84 | self.bos_id = self.tokenizer('assistant').data['input_ids'] 85 | 86 | def __len__(self): 87 | return self.df.shape[0] 88 | 89 | def find_sublist_index(self, main_list, sub_list) -> int: 90 | last_index = -1 91 | for i in range(len(main_list) - len(sub_list) + 1): 92 | if main_list[i:i + len(sub_list)] == sub_list: 93 | last_index = i 94 | return last_index 95 | 96 | def safe_eval(self, s): 97 | try: 98 | res = eval(s) 99 | except Exception as e: 100 | return [] 101 | return res 102 | 103 | def __getitem__(self, index: int): 104 | # 105 | sample = 
self.df.iloc[index] 106 | history = self.safe_eval(sample['history']) 107 | q = str(sample['q']) 108 | a = str(sample['a']) 109 | 110 | messages = [] 111 | for history_message in history: 112 | if len(history_message) <= 1: 113 | continue 114 | messages.append( 115 | {"role": 'user', "content": str(history_message[0])[:self.max_length // 2]} 116 | ) 117 | messages.append( 118 | {"role": 'assistant', "content": str(history_message[1])[:self.max_length // 2]} 119 | ) 120 | 121 | messages += [ 122 | {"role": "user", "content": q}, 123 | {"role": "assistant", "content": a}, 124 | ] 125 | new_prompt = self.tokenizer.apply_chat_template( 126 | messages, 127 | tokenize=False, 128 | add_generation_prompt=True 129 | ) 130 | input_id = self.tokenizer(new_prompt).data['input_ids'][:self.max_length] 131 | 132 | # 实际长度 133 | question_length = self.find_sublist_index(input_id, self.bos_id) + len(self.bos_id) 134 | # 没满最大长度的剩余部分 135 | padding_len = self.max_length - len(input_id) 136 | input_id = input_id + [self.padding] * padding_len 137 | mask_len = len(input_id) - question_length - padding_len 138 | # 0表示不计算损失 139 | loss_mask = [0] * question_length + [1] * (mask_len) + [0] * padding_len 140 | 141 | input_id = np.array(input_id) 142 | X = np.array(input_id[:-1]).astype(np.int64) 143 | Y = np.array(input_id[1:]).astype(np.int64) 144 | loss_mask = np.array(loss_mask[1:]).astype(np.int64) 145 | 146 | X_tensor = torch.from_numpy(X) 147 | Y_tensor = torch.from_numpy(Y) 148 | loss_mask_tensor = torch.from_numpy(loss_mask) 149 | 150 | return X_tensor, Y_tensor, loss_mask_tensor 151 | 152 | 153 | if __name__ == "__main__": 154 | pass 155 | -------------------------------------------------------------------------------- /model/minimind_tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_bos_token": false, 3 | "add_eos_token": false, 4 | "add_prefix_space": true, 5 | "added_tokens_decoder": { 6 | "0": { 7 | "content": "", 8 | "lstrip": false, 9 | "normalized": false, 10 | "rstrip": false, 11 | "single_word": false, 12 | "special": true 13 | }, 14 | "1": { 15 | "content": "", 16 | "lstrip": false, 17 | "normalized": false, 18 | "rstrip": false, 19 | "single_word": false, 20 | "special": true 21 | }, 22 | "2": { 23 | "content": "", 24 | "lstrip": false, 25 | "normalized": false, 26 | "rstrip": false, 27 | "single_word": false, 28 | "special": true 29 | } 30 | }, 31 | "additional_special_tokens": [], 32 | "bos_token": "", 33 | "clean_up_tokenization_spaces": false, 34 | "eos_token": "", 35 | "legacy": true, 36 | "model_max_length": 1000000000000000019884624838656, 37 | "pad_token": null, 38 | "sp_model_kwargs": {}, 39 | "spaces_between_special_tokens": false, 40 | "tokenizer_class": "PreTrainedTokenizerFast", 41 | "unk_token": "", 42 | "use_default_system_prompt": false, 43 | "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ 'user\\n' + content + '\\nassistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '' + '\\n' }}{% endif %}{% endfor %}" 44 | } -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import struct 3 | import inspect 4 | import 
time 5 | 6 | from .LMConfig import LMConfig 7 | from typing import Any, Optional, Tuple 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | from transformers import PreTrainedModel 13 | from transformers.modeling_outputs import CausalLMOutputWithPast 14 | 15 | 16 | class RMSNorm(torch.nn.Module): 17 | def __init__(self, dim: int, eps: float): 18 | super().__init__() 19 | self.eps = eps 20 | self.weight = nn.Parameter(torch.ones(dim)) 21 | 22 | def _norm(self, x): 23 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 24 | 25 | def forward(self, x): 26 | output = self._norm(x.float()).type_as(x) 27 | return output * self.weight 28 | 29 | 30 | def precompute_pos_cis(dim: int, end: int, theta: float = 10000.0): 31 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 32 | t = torch.arange(end, device=freqs.device) # type: ignore 33 | freqs = torch.outer(t, freqs).float() # type: ignore 34 | pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 35 | return pos_cis 36 | 37 | 38 | def apply_rotary_emb(xq, xk, pos_cis): 39 | def unite_shape(pos_cis, x): 40 | ndim = x.ndim 41 | assert 0 <= 1 < ndim 42 | assert pos_cis.shape == (x.shape[1], x.shape[-1]) 43 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] 44 | return pos_cis.view(*shape) 45 | 46 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 47 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 48 | pos_cis = unite_shape(pos_cis, xq_) 49 | xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3) 50 | xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3) 51 | return xq_out.type_as(xq), xk_out.type_as(xk) 52 | 53 | 54 | def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: 55 | """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" 56 | bs, slen, n_kv_heads, head_dim = x.shape 57 | if n_rep == 1: 58 | return x 59 | return ( 60 | x[:, :, :, None, :] 61 | .expand(bs, slen, n_kv_heads, n_rep, head_dim) 62 | .reshape(bs, slen, n_kv_heads * n_rep, head_dim) 63 | ) 64 | 65 | 66 | class Attention(nn.Module): 67 | def __init__(self, args: LMConfig): 68 | super().__init__() 69 | self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads 70 | assert args.n_heads % self.n_kv_heads == 0 71 | self.n_local_heads = args.n_heads 72 | self.n_local_kv_heads = self.n_kv_heads 73 | self.n_rep = self.n_local_heads // self.n_local_kv_heads 74 | self.head_dim = args.dim // args.n_heads 75 | self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False) 76 | self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) 77 | self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) 78 | self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) 79 | self.k_cache, self.v_cache = None, None 80 | self.attn_dropout = nn.Dropout(args.dropout) 81 | self.resid_dropout = nn.Dropout(args.dropout) 82 | self.dropout = args.dropout 83 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn 84 | 85 | # print("WARNING: using slow attention. 
Flash Attention requires PyTorch >= 2.0") 86 | mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf")) 87 | mask = torch.triu(mask, diagonal=1) 88 | self.register_buffer("mask", mask, persistent=False) 89 | 90 | def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, kv_cache=False): 91 | bsz, seqlen, _ = x.shape 92 | 93 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) 94 | 95 | xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) 96 | xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) 97 | xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) 98 | 99 | xq, xk = apply_rotary_emb(xq, xk, pos_cis) 100 | 101 | # 更高效的kv_cache实现 102 | if kv_cache and self.eval(): 103 | if seqlen == 1 and all(cache is not None for cache in (self.k_cache, self.v_cache)): 104 | xk = torch.cat((self.k_cache, xk), dim=1) 105 | xv = torch.cat((self.v_cache, xv), dim=1) 106 | self.k_cache, self.v_cache = xk, xv 107 | 108 | xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) 109 | xv = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) 110 | 111 | xq = xq.transpose(1, 2) 112 | xk = xk.transpose(1, 2) 113 | xv = xv.transpose(1, 2) 114 | 115 | if self.flash and seqlen != 1: 116 | output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, 117 | dropout_p=self.dropout if self.training else 0.0, 118 | is_causal=True) 119 | else: 120 | scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim) 121 | scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen) 122 | scores = F.softmax(scores.float(), dim=-1).type_as(xq) 123 | scores = self.attn_dropout(scores) 124 | output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim) 125 | 126 | output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) 127 | 128 | output = self.wo(output) 129 | output = self.resid_dropout(output) 130 | return output 131 | 132 | 133 | class FeedForward(nn.Module): 134 | def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float): 135 | super().__init__() 136 | if hidden_dim is None: 137 | hidden_dim = 4 * dim 138 | hidden_dim = int(2 * hidden_dim / 3) 139 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 140 | self.w1 = nn.Linear(dim, hidden_dim, bias=False) 141 | self.w2 = nn.Linear(hidden_dim, dim, bias=False) 142 | self.w3 = nn.Linear(dim, hidden_dim, bias=False) 143 | self.dropout = nn.Dropout(dropout) 144 | 145 | def forward(self, x): 146 | return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x))) 147 | 148 | 149 | class MoEGate(nn.Module): 150 | def __init__(self, config: LMConfig): 151 | super().__init__() 152 | self.config = config 153 | self.top_k = config.num_experts_per_tok 154 | self.n_routed_experts = config.n_routed_experts 155 | 156 | self.scoring_func = config.scoring_func 157 | self.alpha = config.aux_loss_alpha 158 | self.seq_aux = config.seq_aux 159 | 160 | self.norm_topk_prob = config.norm_topk_prob 161 | self.gating_dim = config.dim 162 | self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim))) 163 | self.reset_parameters() 164 | 165 | def reset_parameters(self) -> None: 166 | import torch.nn.init as init 167 | init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 168 | 169 | def forward(self, hidden_states): 170 | bsz, seq_len, h = hidden_states.shape 171 | 172 | hidden_states = hidden_states.view(-1, h) 173 | logits = F.linear(hidden_states, self.weight, None) 174 | 
if self.scoring_func == 'softmax': 175 | scores = logits.softmax(dim=-1) 176 | else: 177 | raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}') 178 | 179 | topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) 180 | 181 | if self.top_k > 1 and self.norm_topk_prob: 182 | denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 183 | topk_weight = topk_weight / denominator 184 | 185 | if self.training and self.alpha > 0.0: 186 | scores_for_aux = scores 187 | aux_topk = self.top_k 188 | topk_idx_for_aux_loss = topk_idx.view(bsz, -1) 189 | if self.seq_aux: 190 | scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1) 191 | ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device) 192 | ce.scatter_add_(1, topk_idx_for_aux_loss, 193 | torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_( 194 | seq_len * aux_topk / self.n_routed_experts) 195 | aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha 196 | else: 197 | mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts) 198 | ce = mask_ce.float().mean(0) 199 | Pi = scores_for_aux.mean(0) 200 | fi = ce * self.n_routed_experts 201 | aux_loss = (Pi * fi).sum() * self.alpha 202 | else: 203 | aux_loss = None 204 | return topk_idx, topk_weight, aux_loss 205 | 206 | 207 | class MOEFeedForward(nn.Module): 208 | def __init__(self, config: LMConfig): 209 | super().__init__() 210 | self.config = config 211 | self.experts = nn.ModuleList([ 212 | FeedForward( 213 | dim=config.dim, 214 | hidden_dim=config.hidden_dim, 215 | multiple_of=config.multiple_of, 216 | dropout=config.dropout, 217 | ) 218 | for _ in range(config.n_routed_experts) 219 | ]) 220 | 221 | self.gate = MoEGate(config) 222 | if config.n_shared_experts is not None: 223 | self.shared_experts = FeedForward( 224 | dim=config.dim, 225 | hidden_dim=config.hidden_dim, 226 | multiple_of=config.multiple_of, 227 | dropout=config.dropout, 228 | ) 229 | 230 | def forward(self, x): 231 | identity = x 232 | orig_shape = x.shape 233 | bsz, seq_len, _ = x.shape 234 | 235 | # 使用门控机制选择专家 236 | topk_idx, topk_weight, aux_loss = self.gate(x) 237 | 238 | x = x.view(-1, x.shape[-1]) 239 | flat_topk_idx = topk_idx.view(-1) 240 | 241 | if self.training: 242 | # 训练模式下,重复输入数据 243 | x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0) 244 | y = torch.empty_like(x, dtype=torch.float16) 245 | for i, expert in enumerate(self.experts): 246 | y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]) 247 | y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) 248 | y = y.view(*orig_shape) 249 | else: 250 | # 推理模式下,只选择最优专家 251 | y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape) 252 | 253 | if self.config.n_shared_experts is not None: 254 | y = y + self.shared_experts(identity) 255 | 256 | return y 257 | 258 | @torch.no_grad() 259 | def moe_infer(self, x, flat_expert_indices, flat_expert_weights): 260 | expert_cache = torch.zeros_like(x) 261 | idxs = flat_expert_indices.argsort() 262 | tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0) 263 | token_idxs = idxs // self.config.num_experts_per_tok 264 | # 例如当tokens_per_expert=[6, 15, 20, 26, 33, 38, 46, 52] 265 | # 当token_idxs=[3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...] 266 | # 意味着当token_idxs[:6] -> [3, 7, 19, 21, 24, 25, 4]位置的token都由专家0处理,token_idxs[6:15]位置的token都由专家1处理...... 
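# English gloss of the example above: tokens_per_expert holds the cumulative number of routed tokens per
# expert (after sorting flat_expert_indices), so in the loop below token_idxs[start_idx:end_idx] are exactly
# the original token positions that expert i processes; their weighted outputs are scatter-added back
# into expert_cache at those positions.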
267 | for i, end_idx in enumerate(tokens_per_expert): 268 | start_idx = 0 if i == 0 else tokens_per_expert[i - 1] 269 | if start_idx == end_idx: 270 | continue 271 | expert = self.experts[i] 272 | exp_token_idx = token_idxs[start_idx:end_idx] 273 | expert_tokens = x[exp_token_idx] 274 | expert_out = expert(expert_tokens) 275 | expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) 276 | # 使用 scatter_add_ 进行 sum 操作 277 | expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out) 278 | 279 | return expert_cache 280 | 281 | 282 | class TransformerBlock(nn.Module): 283 | def __init__(self, layer_id: int, args: LMConfig): 284 | super().__init__() 285 | self.n_heads = args.n_heads 286 | self.dim = args.dim 287 | self.head_dim = args.dim // args.n_heads 288 | self.attention = Attention(args) 289 | 290 | self.layer_id = layer_id 291 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) 292 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) 293 | 294 | if args.use_moe: 295 | self.feed_forward = MOEFeedForward(args) 296 | else: 297 | self.feed_forward = FeedForward( 298 | dim=args.dim, 299 | hidden_dim=args.hidden_dim, 300 | multiple_of=args.multiple_of, 301 | dropout=args.dropout, 302 | ) 303 | 304 | def forward(self, x, pos_cis, kv_cache=False): 305 | h = x + self.attention(self.attention_norm(x), pos_cis, kv_cache) 306 | out = h + self.feed_forward(self.ffn_norm(h)) 307 | return out 308 | 309 | 310 | class Transformer(PreTrainedModel): 311 | config_class = LMConfig 312 | last_loss: Optional[torch.Tensor] 313 | 314 | def __init__(self, params: LMConfig = None): 315 | super().__init__(params) 316 | if not params: 317 | params = LMConfig() 318 | self.params = params 319 | self.vocab_size = params.vocab_size 320 | self.n_layers = params.n_layers 321 | 322 | self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) 323 | self.dropout = nn.Dropout(params.dropout) 324 | self.layers = torch.nn.ModuleList() 325 | for layer_id in range(self.n_layers): 326 | self.layers.append(TransformerBlock(layer_id, params)) 327 | self.norm = RMSNorm(params.dim, eps=params.norm_eps) 328 | self.output = nn.Linear(params.dim, params.vocab_size, bias=False) 329 | self.tok_embeddings.weight = self.output.weight 330 | pos_cis = precompute_pos_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len) 331 | self.register_buffer("pos_cis", pos_cis, persistent=False) 332 | 333 | self.apply(self._init_weights) 334 | 335 | for pn, p in self.named_parameters(): 336 | if pn.endswith('w3.weight') or pn.endswith('wo.weight'): 337 | torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * params.n_layers)) 338 | 339 | self.last_loss = None 340 | self.OUT = CausalLMOutputWithPast() 341 | self._no_split_modules = [name for name, _ in self.named_modules()] 342 | 343 | def _init_weights(self, module): 344 | if isinstance(module, nn.Linear): 345 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 346 | if module.bias is not None: 347 | torch.nn.init.zeros_(module.bias) 348 | elif isinstance(module, nn.Embedding): 349 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 350 | 351 | def forward(self, tokens: Optional[torch.Tensor] = None, targets: Optional[torch.Tensor] = None, 352 | kv_cache=False, **keyargs): 353 | current_idx = 0 354 | if 'input_ids' in keyargs: 355 | tokens = keyargs['input_ids'] 356 | if 'attention_mask' in keyargs: 357 | targets = keyargs['attention_mask'] 358 | if 'current_idx' in keyargs: 359 | current_idx = 
int(keyargs['current_idx']) 360 | 361 | _bsz, seqlen = tokens.shape 362 | h = self.tok_embeddings(tokens) 363 | h = self.dropout(h) 364 | pos_cis = self.pos_cis[current_idx:current_idx + seqlen] 365 | for idx, layer in enumerate(self.layers): 366 | h = layer(h, pos_cis, kv_cache) 367 | 368 | h = self.norm(h) 369 | 370 | if targets is not None: 371 | logits = self.output(h) 372 | self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), 373 | ignore_index=0, reduction='none') 374 | else: 375 | logits = self.output(h[:, [-1], :]) 376 | self.last_loss = None 377 | 378 | self.OUT.__setitem__('logits', logits) 379 | self.OUT.__setitem__('last_loss', self.last_loss) 380 | return self.OUT 381 | 382 | @torch.inference_mode() 383 | def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=8, stream=True, rp=1., kv_cache=True): 384 | # rp: repetition_penalty 385 | index = idx.shape[1] 386 | init_inference = True 387 | while idx.shape[1] < max_new_tokens - 1: 388 | if init_inference or not kv_cache: 389 | inference_res, init_inference = self(idx, kv_cache=kv_cache), False 390 | else: 391 | inference_res = self(idx[:, -1:], kv_cache=kv_cache, current_idx=idx.shape[1] - 1) 392 | 393 | logits = inference_res.logits 394 | logits = logits[:, -1, :] 395 | 396 | for token in set(idx.tolist()[0]): 397 | logits[:, token] /= rp 398 | 399 | if temperature == 0.0: 400 | _, idx_next = torch.topk(logits, k=1, dim=-1) 401 | else: 402 | logits = logits / temperature 403 | if top_k is not None: 404 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 405 | logits[logits < v[:, [-1]]] = -float('Inf') 406 | 407 | probs = F.softmax(logits, dim=-1) 408 | idx_next = torch.multinomial(probs, num_samples=1, generator=None) 409 | 410 | if idx_next == eos: 411 | break 412 | 413 | idx = torch.cat((idx, idx_next), dim=1) 414 | if stream: 415 | yield idx[:, index:] 416 | 417 | if not stream: 418 | yield idx[:, index:] 419 | 420 | @torch.inference_mode() 421 | def eval_answer(self, idx): 422 | idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:] 423 | inference_res = self(idx_cond) 424 | logits = inference_res.logits 425 | logits = logits[:, -1, :] 426 | return logits 427 | -------------------------------------------------------------------------------- /tools/.gitattributes: -------------------------------------------------------------------------------- 1 | labeled_numina_difficulty/labeled_amc_aime_0_-1.json filter=lfs diff=lfs merge=lfs -text 2 | labeled_numina_difficulty/labeled_math_0_-1.json filter=lfs diff=lfs merge=lfs -text 3 | labeled_numina_difficulty/labeled_olympiads_0_-1.json filter=lfs diff=lfs merge=lfs -text 4 | -------------------------------------------------------------------------------- /tools/README.md: -------------------------------------------------------------------------------- 1 | # Data Generation, Processing, and Evaluation Tools 2 | This document describes the steps to training data curation and evaluation scripts for Sky-T1. 3 | 4 | ## Requirements 5 | First create the environment as follows. 6 | ```shell 7 | conda create -n eval python==3.10 8 | conda activate eval 9 | pip install -r requirements.txt 10 | ``` 11 | 12 | For running OpenAI model, export the OpenAI key. 13 | ```shell 14 | export OPENAI_API_KEY={openai_api_key} 15 | ``` 16 | 17 | ## Training Data Curation 18 | ### Step 0 (Optional, only for NUMINA math dataset): Label Math Difficulty from NUMINA 19 | Put one or multiple OPENAI_API_KEY in a file, e.g. 
keys.txt (one per line). If there is more than one key, the script will use them in a round-robin way to speed up generation. Label Math difficulty using GPT-4o-mini: 20 | #### Example usage: 21 | ``` 22 | python label_math_difficulty.py --source [amc_aime, math, olympiads] --keys keys.txt 23 | ``` 24 | The expected output is labeled_source_0_-1.json. We also provide instructions to download these files under the labeled_numina_difficulty folder (Download from HuggingFace). 25 | 26 | ### Step 1: Data Inference 27 | Inference the results from QwQ on several datasets. In preview version, we use data from the following dataset. 28 | 29 | ```shell 30 | python inference_and_check.py --dataset APPS --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data --inference 31 | 32 | python inference_and_check.py --dataset TACO --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source MEDIUM --filter-difficulty --result-dir $SKYT_HOME/data --inference 33 | 34 | python inference_and_check.py --dataset TACO --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data --inference 35 | 36 | python inference_and_check.py --dataset NUMINA --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source math --filter-difficulty --result-dir $SKYT_HOME/data --inference 37 | 38 | python inference_and_check.py --dataset NUMINA --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source amc_aime --filter-difficulty --result-dir $SKYT_HOME/data --inference 39 | 40 | python inference_and_check.py --dataset NUMINA --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source olympiads --end 20000 --filter-difficulty --result-dir $SKYT_HOME/data --inference 41 | ``` 42 | 43 | ### Step 2: Format the response 44 | After obtaining a list file for training data, convert them to a unified format (Note: This uses GPT-4o-mini to rewrite. The output is long and takes ~100 dollars for our preview data). 45 | ```shell 46 | python convert_format.py --input_dir $SKYT_HOME/data --keys keys.txt 47 | ``` 48 | 49 | ### Step 3: Reject Sampling on the formatted data (Example Usage with previous script) 50 | ```shell 51 | python inference_and_check.py --dataset APPS --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data --check 52 | ``` 53 | Similar for other datasets. 54 | 55 | ### Convert to ShareGPT format for training 56 | After obtaining multiple converted files, merge them together and convert to the ShareGPT format to perform training. In our preview model, we also add the science and riddle portion from the [STILL-2 model](https://arxiv.org/pdf/2412.09413), where interested readers can download their part of data and simply concatenating to the data obtained above. 57 | ```shell 58 | python convert_to_data.py --input_dir $SKYT_HOME/data --output $SKYT_HOME/data/train_data.json 59 | ``` 60 | 61 | 62 | ## Generation and Evaluation 63 | The file `inference_and_check.py` provides convenient methods for generating sequences (e.g., for distillation or benchmark evaluation) and checking whether the generated solutions are correct (e.g., for reject sampling or benchmark evaluation). 64 | 65 | ### Distill and Reject Sampling 66 | Currently we support distill and reject sampling from various self-hosted models for NUMINA, APPS, and TACO datasets. For NUMINA, the source can be one from `[amc_aime, math, olympiads]`. 
67 | #### Example Usage 68 | 69 | ```shell 70 | python inference_and_check.py --dataset APPS --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data 71 | 72 | python inference_and_check.py --dataset TACO --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source MEDIUM --filter-difficulty --result-dir $SKYT_HOME/data 73 | 74 | python inference_and_check.py --dataset TACO --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split test --source all --result-dir $SKYT_HOME/data 75 | 76 | python inference_and_check.py --dataset NUMINA --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source math --filter-difficulty --result-dir $SKYT_HOME/data --math_difficulty_lower_bound 4 --math_difficulty_upper_bound 9 77 | 78 | python inference_and_check.py --dataset NUMINA --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source amc_aime --filter-difficulty --result-dir $SKYT_HOME/data --math_difficulty_lower_bound 1 --math_difficulty_upper_bound 9 79 | 80 | python inference_and_check.py --dataset NUMINA --model Qwen/QwQ-32B-Preview --tp 8 --max_tokens 16384 --split train --source olympiads --end 20000 --filter-difficulty --result-dir $SKYT_HOME/data --math_difficulty_lower_bound 9 --math_difficulty_upper_bound 9 81 | ``` 82 | 83 | 84 | #### Best-of-N Inference and Check 85 | ```bash 86 | python inference_and_check.py --dataset MATH500 --model Qwen/Qwen2-7B-Instruct --tp 4 --max_tokens 4096 --split test --result-dir ./ --inference --temperatures 0.7 --n 64 87 | python inference_and_check.py --dataset MATH500 --model Qwen/Qwen2-7B-Instruct --tp 4 --max_tokens 4096 --split test --result-dir ./ --check --temperatures 0.7 --n 8 88 | ``` 89 | 90 | ### Benchmark Evaluations 91 | We provide a wrapper script `eval.py` to conveniently run reasoning benchmarks. We currently support `AIME`, `MATH500`, `GPQADiamond`, and `MMLU`. This script can be used to launch evaluations for multiple benchmarks, then aggregate and log the accuracy for all benchmarks. 92 | 93 | **Note**: The `GPQADiamond` dataset is gated and requires first receiving access at this Huggingface [link](https://huggingface.co/datasets/Idavidrein/gpqa) (which is granted immediately), then logging into your Huggingface account in your terminal session with `huggingface-cli login`. 94 | 95 | **NOTE**: For reproducing `Sky-T1-32B-Preview` results on `AIME` and `GPQADiamond` dataset, pass in temperatures as `0.7`. 96 | 97 | ```shell 98 | python eval.py --model NovaSky-AI/Sky-T1-32B-Preview --evals=AIME,GPQADiamond --tp=8 --output_file=results.txt --temperatures 0.7 99 | ``` 100 | 101 | #### Example Usage 102 | ```shell 103 | python eval.py --model Qwen/QwQ-32B-Preview --evals=AIME,MATH500,GPQADiamond --tp=8 --output_file=results.txt 104 | ``` 105 | 106 | Example result: `{"AIME": , "MATH500": , "GPQADiamond": }` 107 | 108 | ## Response Rewriting 109 | The file `response_rewrite.py` provides a pipeline for filtering and rewriting responses generated with `inference_and_check.py`. We use `response_rewrite.py` to create preference pairs for preference optimization (e.g., DPO, SimPO), however, the logic can be edited for alternative filtering and rewriting steps. Details of the implemented logic can be found in `response_rewrite.py` or on [this blog post](https://novasky-ai.github.io/posts/reduce-overthinking). 110 | 111 | To use our preference optimization pipeline, first generate and score multiple responses using `inference_and_check.py`. 
For example: 112 | 113 | ```shell 114 | python inference_and_check.py --inference --dataset MATH500 --model Qwen/Qwen2-7B-Instruct --tp 4 --max_tokens 4096 --split test --result-dir ./ --temperatures 0.7 --n 8 115 | python inference_and_check.py --check --dataset MATH500 --model Qwen/Qwen2-7B-Instruct --tp 4 --max_tokens 4096 --split test --result-dir ./ --temperatures 0.7 --n 8 116 | ``` 117 | 118 | Then, use `response_rewrite.py` to process the responses into preference pairs. By default, the shortest correct responses will be used as positive examples and the longest correct responses will be used as negative samples. The argument `--SILC` can be used to also include short incorrect responses as negative examples and long correct repsonses as positive samples. 119 | 120 | ```shell 121 | python response_rewrite.py --SILC --rewrite-model meta-llama/Meta-Llama-3-8B-Instruct --target-model NovaSky-AI/Sky-T1-32B-Preview --dataset [PATH_TO_GENERATED_RESPONSES] --result-dir ./ --checkpoint --tp 8 122 | ``` 123 | 124 | The `--checkpoint` argument can optionally be used to save intermediate files of the processed data between steps, in case of failure. 125 | 126 | The resulting `.json` files can be used to train a model with preference optimization algorithms. See the `/train/` directory for more details. -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.abspath(os.path.join(__file__, "../../"))) 4 | from utils.settings import * 5 | 6 | 7 | -------------------------------------------------------------------------------- /tools/base_instruct_evals.md: -------------------------------------------------------------------------------- 1 | # Reproducing results on non-reasoning benchmarks 2 | 3 | For the full set of results, see [here](./README.md#results-on-qa-and-instruction-following-benchmarks). 4 | 5 | ## Installation instructions 6 | 7 | 1. For `lm_eval`, install the package by executing the following : 8 | 9 | ```bash 10 | git clone https://github.com/EleutherAI/lm-evaluation-harness 11 | cd lm-evaluation-harness 12 | git checkout 703fbff 13 | pip install -e ".[ifeval]" 14 | ``` 15 | 16 | For more details, you can refer to the official instructions [here](https://github.com/EleutherAI/lm-evaluation-harness/tree/703fbffd6fe5`e136bbb9d884cb40844e5503ae5d?tab=readme-ov-file#install). We report results with commit https://github.com/EleutherAI/lm-evaluation-harness/commit/703fbffd6fe5e136bbb9d884cb40844e5503ae5d 17 | 18 | 2. For `fastchat`, follow the instructions [here](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge#install). The current implementation of Fastchat is based on OpenAI version <= 0.28.0. For making use of the latest vllm backend, it is recommended to migrate the `llm_judge` folder to use openai>=1.0.0. You can run `openai migrate` for the fastchat codebase or follow the PR [here](https://github.com/lm-sys/FastChat/pull/2915/files) 19 | 3. For `BFCL`, you can follow the official instructions [here](https://github.com/ShishirPatil/gorilla/tree/main/berkeley-function-call-leaderboard#basic-installation). 
We further evaulate on all test categories, which requires [setting up environment variables](https://github.com/ShishirPatil/gorilla/tree/main/berkeley-function-call-leaderboard#setting-up-environment-variables), and [obtaining API keys for executable test categories](https://github.com/ShishirPatil/gorilla/tree/main/berkeley-function-call-leaderboard#api-keys-for-executable-test-categories). Make sure to use changes from [this PR](https://github.com/ShishirPatil/gorilla/pull/888) for QwQ and Sky-T1 model support. 20 | 4. For `Arena-Hard` results, you can follow the instructions [here](https://github.com/lmarena/arena-hard-auto). We use `gpt-4-1106-preview` as the judge. 21 | 22 | ## Commands for reproducing results 23 | 24 | All the benchmarks were run on a 8xH100 machine with the `vllm` backend. If you're running on a different device, make sure to tweak `tensor_parallel_size` and if needed the `batch_size` arguments. Expect some variance in scores (+/- 1%) for different evaluation settings (ex: `tensor_parallel_size`) 25 | 26 | All the commands below are given for `NovaSky-AI/Sky-T1-32B-Preview`. Simply substitute the model name for `Qwen/Qwen-2.5-32B-Instruct`. For `Qwen/QwQ-32B-Preview`, we further make use of two arguments `revision=refs/pr/58,tokenizer_revision=refs/pr/58` to use a corrected revision of QwQ. For more details on this, see https://github.com/NovaSky-AI/SkyThought/pull/26#issuecomment-2606435601. 27 | 28 | ### MMLU (0 shot; no CoT) 29 | 30 | ```bash 31 | lm_eval --model vllm --model_args pretrained=NovaSky-AI/Sky-T1-32B-Preview,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1,max_model_len=2048 --tasks mmlu --trust_remote_code --batch_size 8 --apply_chat_template --fewshot_as_multiturn 32 | ``` 33 | 34 | For QwQ, you would do 35 | 36 | ```bash 37 | lm_eval --model vllm --model_args pretrained=Qwen/QwQ-32B-Preview,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1,max_model_len=2048revision=refs/pr/58,tokenizer_revision=refs/pr/58 --tasks mmlu --trust_remote_code --batch_size 8 --apply_chat_template --fewshot_as_multiturn 38 | ``` 39 | 40 | ### MMLU (5 shot; no CoT) 41 | 42 | ```bash 43 | lm_eval --model vllm --model_args pretrained=NovaSky-AI/Sky-T1-32B-Preview,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1,max_model_len=2048 --tasks mmlu --trust_remote_code --batch_size 8 --apply_chat_template --fewshot_as_multiturn --num_fewshot 5 44 | ``` 45 | 46 | ### ARC-C (0 shot; no CoT) 47 | 48 | ```bash 49 | lm_eval --model vllm --model_args pretrained=NovaSky-AI/Sky-T1-32B-Preview,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1,max_model_len=2048 --tasks arc_challenge --trust_remote_code --batch_size 8 --apply_chat_template --fewshot_as_multiturn 50 | ``` 51 | 52 | ### IFEval 53 | 54 | ```bash 55 | lm_eval --model vllm --model_args pretrained=NovaSky-AI/Sky-T1-32B-Preview,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.9,data_parallel_size=1 --tasks leaderboard_ifeval --trust_remote_code --batch_size auto --apply_chat_template --fewshot_as_multiturn 56 | ``` 57 | 58 | We use the `prompt_level_strict_acc` metric following Qwen-2.5. 
59 | 60 | ### MGSM (native CoT) 61 | 62 | ```bash 63 | lm_eval --model vllm --model_args pretrained=NovaSky-AI/Sky-T1-32B-Preview,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1,max_model_len=2048 --tasks mgsm_direct --trust_remote_code --batch_size 8 --apply_chat_template --fewshot_as_multiturn 64 | ``` 65 | 66 | We report the average value of `flexible-extract` filter. 67 | 68 | ### MGSM (8-shot; native CoT) 69 | 70 | ```bash 71 | lm_eval --model vllm --model_args pretrained=NovaSky-AI/Sky-T1-32B-Preview,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1,max_model_len=2048 --tasks mgsm_direct --trust_remote_code --batch_size 8 --apply_chat_template --fewshot_as_multiturn --num_fewshot 8 72 | ``` 73 | 74 | ### LLM-as-a-Judge 75 | 76 | We use the default settings - with `max_tokens` 1024 and the `gpt-4` judge. We observe that some reasoning models like `Qwen/QwQ-32B-Preview` are unable to provide brief responses sometimes and thus get truncated responses at the used `max_tokens`. While this will effect the final rating, given the context length limitations of the commonly used `gpt-4` judge (8K tokens), we stick to the 1024 `max_tokens` budget for consistency. 77 | 78 | 1. First, serve the model with vLLM 79 | 80 | 81 | ```bash 82 | vllm serve NovaSky-AI/Sky-T1-32B-Preview --dtype auto --tensor-parallel-size 8 --gpu-memory-utilization 0.9 83 | ``` 84 | 85 | For `Qwen/QwQ-32B-Preview`, use 86 | 87 | ```bash 88 | vllm serve Qwen/QwQ-32B-Preview --dtype auto --tensor-parallel-size 8 --gpu-memory-utilization 0.9 --revision refs/pr/58 --tokenizer-revision refs/pr/58 89 | ``` 90 | 91 | 2. Next, generate model response 92 | 93 | ```bash 94 | python gen_api_answer.py --model NovaSky-AI/Sky-T1-32B-Preview --openai-api-base http://localhost:8000/v1 --parallel 50 95 | ``` 96 | 97 | Note: The generated results will be in `data/model_answer//.jsonl` . Move them to the root folder `data/model_answer/` 98 | 99 | 3. After generating responses for all the models, evaluate with the default settings 100 | 101 | ```bash 102 | export OPENAI_API_KEY=XXXXXX # set the OpenAI API key 103 | python gen_judgment.py --model-list Sky-T1-32B-Preview QwQ-32B-Preview Qwen2.5-32B-Instruct --parallel 2 104 | ``` 105 | 4. Get MTBench scores (we use the average score of both turns) 106 | 107 | ```bash 108 | python show_result.py 109 | ``` 110 | 111 | ### BFCL-v3 112 | 113 | Our results are reported on `test-category` `all` . 
Make sure to get the API keys for the executable test categories by following the instructions [here](https://github.com/ShishirPatil/gorilla/tree/main/berkeley-function-call-leaderboard#api-keys-for-executable-test-categories) 114 | 115 | Run 116 | 117 | ```bash 118 | bfcl generate --model NovaSky-AI/Sky-T1-32B-Preview --test-category all --backend vllm --num-gpus 8 --gpu-memory-utilization 0.9 119 | ``` 120 | 121 | For evaluation, you can simply run 122 | 123 | ```bash 124 | bfcl evaluate --model Qwen/QwQ-32B-Preview,NovaSky-AI/Sky-T1-32B-Preview,Qwen/Qwen2.5-32B-Instruct --test-category all --api-sanity-check 125 | ``` 126 | ### Arena Hard 127 | For `Arena-Hard`, we use the following script to start a `TGI` service for generating answers 128 | ```bash 129 | hf_pat= 130 | model=NovaSky-AI/Sky-T1-32B-Preview 131 | volume=/mnt/local_storage/data/cache 132 | port=1996 133 | 134 | huggingface-cli download $model 135 | sudo docker run --gpus 8 -e HUGGING_FACE_HUB_TOKEN=$hf_pat --shm-size 2000g -p $port:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.2.0 --model-id $model --max-input-length 8192 --max-batch-total-tokens 8193 --max-batch-prefill-tokens 8193 --max-total-tokens 8193 --sharded true 136 | ``` 137 | For running the `gen_answer.py` script, we use the following `config_api` yaml setting. For `qwq-32b-preview`, we explicitly specify the system prompt as `You are a helpful and harmless assistant. You are Qwen developed by Alibaba.` to avoid the CoT prompt. 138 | ```yaml 139 | ... 140 | sky-T1-32B-Preview: 141 | model_name: sky-T1-32B-Preview 142 | endpoints: 143 | - api_base: http://localhost:1996/v1 144 | api_key: empty 145 | api_type: openai 146 | parallel: 8 147 | ... 148 | ``` 149 | and finally for `gen_judgment.py`, we use `gpt-4-1106-preview` as the judge. 150 | 151 | #### Supplementary results for Arena-Hard 152 | 153 | Here are some supplementary results for Arena-Hard, compared with o1-mini which is the best performing model on this benchmark (as of Jan 2025). 
154 | 155 | | model | score | rating_q025 | rating_q975 | CI | avg_tokens | date | 156 | |-------|--------|------------|-------------|-------|------------|-------| 157 | | o1-mini-2024-09-12 | 91.98 | 90.88 | 93.12 | (-1.10, +1.14) | 1399.0 | 2025-01-18 | 158 | | sky-T1-32B-Preview | 74.79 | 72.28 | 76.8 | (-2.51, +2.01) | 847.0 | 2025-01-18 | 159 | | qwen2.5-32b-instruct | 66.51 | 64.55 | 68.4 | (-1.96, +1.89) | 611.0 | 2025-01-18 | 160 | | qwq-32b-preview | 52.6 | 50.86 | 54.91 | (-1.74, +2.31) | 1005.0 | 2025-01-23 | 161 | 162 | For more details, see: https://github.com/NovaSky-AI/SkyThought/pull/26#issuecomment-2599525551 163 | -------------------------------------------------------------------------------- /tools/combine_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from util.prompts import system_prompt 4 | 5 | still2_jsonl_file = "../../data/public_long_form_thought_data_5k.jsonl" 6 | code_json_file = "../../data/converted_apps_long_form_thought_data_5k.json" 7 | output_file = "../../data/converted_v2_long_form_thought_data_9k.json" 8 | 9 | # Load the JSONL file 10 | still2_data = [] 11 | with open(still2_jsonl_file, "r") as f: 12 | for line in f: 13 | still2_data.append(json.loads(line.strip())) 14 | # print(math_data) 15 | 16 | # Process the data into the desired format 17 | all_data = [] 18 | code_num = 0 19 | 20 | for entry in still2_data: 21 | question = entry["question"] 22 | combined_text = entry["combined_text"] 23 | domain = entry["domain"] 24 | if domain != "code": 25 | # Create the conversation format 26 | conversations = [ 27 | {"from": "user", "value": question}, 28 | {"from": "assistant", "value": combined_text} 29 | ] 30 | 31 | # Prepare the final structure 32 | cur_data = { 33 | "system": system_prompt, 34 | "conversations": conversations 35 | } 36 | all_data.append(cur_data) 37 | else: 38 | code_num += 1 39 | 40 | print(code_num) 41 | with open(code_json_file, "r") as f: 42 | code_data = json.load(f) 43 | # print(code_data[0]) 44 | 45 | all_data.extend(code_data) 46 | print(f"First item slice before shuffle: {all_data[0]['conversations'][-1]['value'][-50:-1]}") 47 | random.shuffle(all_data) 48 | print(f"First item slice after shuffle: {all_data[0]['conversations'][-1]['value'][-50:-1]}") 49 | print(len(all_data)) 50 | 51 | # Save the converted data to the output file 52 | with open(output_file, "w") as f: 53 | json.dump(all_data, f, indent=4) 54 | 55 | print(f"Conversion completed. 
The data has been saved to {output_file} with {len(all_data)} data.") 56 | 57 | -------------------------------------------------------------------------------- /tools/convert_format.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from tqdm import tqdm 4 | import multiprocessing as mp 5 | import openai 6 | from itertools import cycle 7 | import time 8 | import os 9 | from util.prompts import convert_prompt, convert_prompt_example 10 | 11 | global args 12 | # Function to set the OpenAI API key 13 | def set_openai_key(api_key): 14 | openai.api_key = api_key 15 | 16 | # GPT API processing function with retry logic 17 | def process_content(content, api_key): 18 | # Set the OpenAI key for this request 19 | set_openai_key(api_key) 20 | 21 | # GPT prompt 22 | prompt = convert_prompt.format(example=convert_prompt_example, content=content) 23 | 24 | retries = 3 25 | while retries > 0: 26 | try: 27 | # OpenAI API call 28 | response = openai.chat.completions.create( 29 | model="gpt-4o-mini", 30 | messages=[ 31 | {"role": "system", "content": "You are a solution format convertor."}, 32 | {"role": "user", "content": prompt} 33 | ], 34 | max_tokens=16384, 35 | temperature=0.7 36 | ) 37 | return response.choices[0].message.content 38 | except openai.RateLimitError: 39 | retries -= 1 40 | if retries == 0: 41 | return "Error: Rate limit reached and retries exhausted." 42 | print(f"Sleep for 5 seconds for API limit.") 43 | time.sleep(5) 44 | except Exception as e: 45 | return f"Error processing content: {e}" 46 | 47 | # Function for multiprocessing 48 | def process_entry(entry, api_key_cycle): 49 | key, values = entry 50 | content = values["responses"]["0.7"]["content"] 51 | 52 | # Get the next API key from the cycle 53 | api_key = next(api_key_cycle) 54 | 55 | processed = process_content(content, api_key) 56 | values["responses"]["0.7"]["processed_content"] = processed 57 | 58 | return key, values 59 | 60 | # Wrapper function for multiprocessing 61 | def process_entry_wrapper(args): 62 | return process_entry(*args) 63 | 64 | if __name__ == "__main__": 65 | parser = argparse.ArgumentParser(description="Process content and save results.") 66 | parser.add_argument("--input_dir", type=str, help="Input directory containing JSON files.") 67 | parser.add_argument("--keys", type=str, help="File containing OpenAI API keys (one per line).") 68 | 69 | global args 70 | args = parser.parse_args() 71 | 72 | # Load API keys and prepare a round-robin cycle 73 | with open(args.keys, "r") as f: 74 | api_keys = [line.strip() for line in f if line.strip()] 75 | api_key_cycle = cycle(api_keys) 76 | 77 | # Process each file in the input directory 78 | for filename in os.listdir(args.input_dir): 79 | if filename.endswith(".json"): 80 | input_path = os.path.join(args.input_dir, filename) 81 | 82 | # Load the data 83 | with open(input_path, "r") as f: 84 | data = json.load(f) 85 | 86 | # Prepare output file 87 | output_file = os.path.join(args.input_dir, f"converted_{filename}") 88 | 89 | # Use multiprocessing to process the content 90 | results = [] 91 | with mp.Pool(os.cpu_count()) as pool: 92 | tasks = [(entry, api_key_cycle) for entry in data.items()] 93 | for result in tqdm(pool.imap(process_entry_wrapper, tasks), total=len(data)): 94 | results.append(result) 95 | 96 | # Aggregate and write results in the main process 97 | aggregated_data = {key: values for key, values in results} 98 | with open(output_file, "w") as f: 99 | json.dump(aggregated_data, 
f, indent=4) 100 | 101 | print(f"Processed data saved to {output_file}") 102 | -------------------------------------------------------------------------------- /tools/convert_to_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | from util.prompts import system_prompt 5 | 6 | def main(): 7 | parser = argparse.ArgumentParser(description="Convert JSON data for processing.") 8 | parser.add_argument("--input_dir", type=str, help="Directory containing input JSON files.") 9 | parser.add_argument("--output", type=str, help="Output JSON file.") 10 | args = parser.parse_args() 11 | 12 | all_data = [] 13 | 14 | # Iterate through all files in the input directory 15 | for filename in os.listdir(args.input_dir): 16 | if filename.endswith(".json") and filename.startswith("converted"): 17 | filepath = os.path.join(args.input_dir, filename) 18 | with open(filepath, "r") as f: 19 | cur_data = json.load(f) 20 | 21 | for _, v in cur_data.items(): 22 | prompt = v["prompt"] 23 | response_data = v["responses"] 24 | 25 | for cur_temp, cur_temp_response in response_data.items(): 26 | # Only support 0.7 for this version 27 | assert cur_temp == "0.7", "Only support a single temperature=0.7 now." 28 | # Accept this data 29 | if cur_temp_response["correctness"]: 30 | # Create the conversation format 31 | conversations = [ 32 | {"from": "user", "value": prompt}, 33 | {"from": "assistant", "value": cur_temp_response["processed_content"]} 34 | ] 35 | 36 | # Prepare the final structure 37 | cur_data = { 38 | "system": system_prompt, 39 | "conversations": conversations 40 | } 41 | all_data.append(cur_data) 42 | 43 | # Save the converted data to the output file 44 | with open(args.output, "w") as f: 45 | json.dump(all_data, f, indent=4) 46 | 47 | print(f"Conversion completed. 
The data has been saved to {args.output} with {len(all_data)} data.") 48 | 49 | if __name__ == "__main__": 50 | main() 51 | -------------------------------------------------------------------------------- /tools/eval.py: -------------------------------------------------------------------------------- 1 | import __init__ 2 | import argparse 3 | import subprocess 4 | import os 5 | import json 6 | 7 | 8 | # Define eval to split mapping 9 | eval_to_split = { 10 | "MATH500": "test", 11 | "MinvervaMath": "test", 12 | "AIME": "train", 13 | "GPQADiamond": "train", 14 | "MMLU": "test", 15 | "MMLUPro": "test", 16 | "LiveCodeBench": "test", 17 | "GSM8K": "test", 18 | "ARC-C": "test", 19 | "AMC23": "train", 20 | "OlympiadBenchMath": "train", 21 | } 22 | 23 | def parse_arguments(): 24 | parser = argparse.ArgumentParser(description="Process model path, prompt format, and evals to run.") 25 | parser.add_argument("--model", required=True, type=str, default='bespokelabs/Bespoke-Stratos-7B', help="Path to the model.") 26 | parser.add_argument("--evals", required=True, type=str, default='MinvervaMath,AIME,MATH500,GPQADiamond,GSM8K,OlympiadBenchMath', help="Comma-separated list of evals to run (no spaces).") 27 | parser.add_argument("--tp", type=int, default=1, help="Tensor Parallelism Degree") 28 | parser.add_argument("--filter-difficulty", action="store_true", help="Filter difficulty.") 29 | parser.add_argument("--source", type=str, help="Source for the dataset.") 30 | parser.add_argument("--output_file", required=True, type=str, default='./eval/Bespoke-7B.txt', help="Output file to write results to.") 31 | parser.add_argument("--temperatures", type=float, nargs="+", default=[0.7], help="Temperature for sampling.") 32 | parser.add_argument("--result_dir", type=str, default='./results/generated', help="Source for the dataset.") 33 | return parser.parse_args() 34 | 35 | def extract_accuracy_from_output(output): 36 | # Iterate through all lines from the end to the beginning 37 | lines = output.splitlines()[::-1] 38 | for line in lines: 39 | try: 40 | # Attempt to parse a JSON object from the line 41 | data = json.loads(line.replace("'", '"')) 42 | if "acc" in data: 43 | return data["acc"] 44 | except json.JSONDecodeError: 45 | continue 46 | return None 47 | 48 | def write_logs_to_file(logs, output_file): 49 | try: 50 | with open(output_file, "w") as file: 51 | file.write(logs) 52 | print(f"Logs successfully written to {output_file}") 53 | except IOError as e: 54 | print(f"Failed to write logs to file {output_file}: {e}") 55 | 56 | def main(): 57 | args = parse_arguments() 58 | 59 | # Extract the arguments 60 | model_path = args.model 61 | evals = args.evals.split(",") 62 | output_file = args.output_file 63 | tp = args.tp 64 | temperatures = [str(t) for t in args.temperatures] 65 | 66 | script_path = "inference_and_check.py" 67 | 68 | # Hold all logs 69 | all_logs = "" 70 | results = {} 71 | result_dir = args.model.split("/")[-2] if 'check' in args.model.split("/")[-1] else args.model.split("/")[-1] 72 | result_dir = os.path.join(args.result_dir,result_dir) 73 | # Run the Python command for each eval and collect logs 74 | for eval_name in evals: 75 | command = [ 76 | "python", script_path, 77 | "--model", model_path, 78 | "--dataset", eval_name, 79 | "--split", eval_to_split[eval_name], 80 | "--tp", str(tp), 81 | "--result-dir", result_dir, 82 | "--temperatures" 83 | ] 84 | command.extend(temperatures) # Add temperatures as separate arguments 85 | 86 | if args.filter_difficulty: 87 | assert args.source != "", 
"No source passed for filtering difficulty." 88 | command.append("--filter-difficulty") 89 | command.append("--source") 90 | command.append(args.source) 91 | print(f"Running eval {eval_name} with command {command}") 92 | all_logs += f"\nRunning eval: {eval_name} with command {command}\n" 93 | try: 94 | with subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) as proc: 95 | output_lines = [] 96 | for line in proc.stdout: 97 | print(line, end="") # Stream output to the console 98 | output_lines.append(line) 99 | all_logs += line 100 | proc.wait() 101 | if proc.returncode != 0: 102 | raise subprocess.CalledProcessError(proc.returncode, command) 103 | 104 | # Capture output for post-processing 105 | output = "".join(output_lines) 106 | accuracy = extract_accuracy_from_output(output) 107 | results[eval_name] = accuracy 108 | 109 | except subprocess.CalledProcessError as e: 110 | error_message = f"Error occurred while running eval {eval_name}: {e}\n" 111 | print(error_message) 112 | all_logs += error_message 113 | 114 | # Write logs of all stdout / stderr to a file 115 | write_logs_to_file(all_logs, output_file) 116 | 117 | print("Results:") 118 | print(results) 119 | 120 | if __name__ == "__main__": 121 | main() 122 | -------------------------------------------------------------------------------- /tools/inference_and_check.py: -------------------------------------------------------------------------------- 1 | import __init__ 2 | import json 3 | import argparse 4 | import re 5 | from concurrent.futures import ProcessPoolExecutor, as_completed 6 | from vllm import LLM, SamplingParams 7 | from tqdm import tqdm 8 | from util.task_handlers import * 9 | from util.model_utils import * 10 | from openai import OpenAI 11 | import concurrent.futures 12 | from functools import partial 13 | 14 | class NumpyEncoder(json.JSONEncoder): 15 | def default(self, obj): 16 | if isinstance(obj, np.ndarray): 17 | return obj.tolist() 18 | return super().default(obj) 19 | 20 | def fetch_response_openai(llm, model_name, max_tokens, temp, prompt): 21 | model_name = model_name.replace("openai/", "") 22 | if "o1" in model_name: 23 | # O1 doesn't support system prompt 24 | # NOTE: might want to implement this inside handler instead 25 | for p in prompt: 26 | p["role"] = "user" 27 | 28 | response = llm.chat.completions.create( 29 | model=model_name, 30 | messages=prompt, 31 | n=1, 32 | temperature=1, # has to be 1 33 | max_completion_tokens=max_tokens 34 | ) 35 | else: 36 | response = llm.chat.completions.create( 37 | model=model_name, 38 | messages=prompt, 39 | n=1, 40 | temperature=temp, 41 | max_tokens=max_tokens 42 | ) 43 | return response 44 | 45 | def perform_inference_and_check(handler: TaskHandler, temperatures, max_tokens, result_file, llm, system_prompt, args): 46 | results = handler.load_existing_results(result_file) 47 | print(f"Loaded {len(results)} existing results.") 48 | train_data = handler.load_and_filter_dataset(args.start, args.end, split=args.split, source=args.source, \ 49 | filter_difficulty=args.filter_difficulty, args=args) 50 | remaining_data = handler.process_remaining_data(train_data, results) 51 | conversations = handler.make_conversations(remaining_data, system_prompt, args.model) 52 | 53 | for temp in temperatures: 54 | 55 | if args.model.startswith("openai"): 56 | fetch_partial = partial(fetch_response_openai, llm, args.model, max_tokens, temp) 57 | 58 | with concurrent.futures.ThreadPoolExecutor(max_workers=16) as e: 59 | responses = 
list(e.map(fetch_partial, conversations)) 60 | 61 | else: 62 | sampling_params = SamplingParams(max_tokens=max_tokens, temperature=temp) 63 | responses = llm.chat(messages=conversations, sampling_params=sampling_params, use_tqdm=True) 64 | 65 | total_correct = 0 66 | total_finish = 0 67 | with ProcessPoolExecutor(max_workers=32) as executor: 68 | # future_to_task = { 69 | # executor.submit(handler.update_results, remaining_data[idx], response): idx 70 | # for idx, response in enumerate(responses) 71 | # } 72 | future_to_task = {} 73 | token_usages = {} 74 | for idx, response in enumerate(responses): 75 | if args.model.startswith("openai"): 76 | response_str = response.choices[0].message.content.strip() 77 | else: 78 | response_str = response.outputs[0].text.strip() 79 | future_to_task[executor.submit(handler.update_results, remaining_data[idx], response_str)] = idx 80 | # print(f"Request output: {response}") 81 | 82 | if args.model.startswith("openai"): 83 | token_usages[idx] = response.usage 84 | else: 85 | token_usages[idx] = { 86 | "completion_tokens": len(response.outputs[0].token_ids), 87 | "prompt_tokens": len(response.prompt_token_ids) 88 | } 89 | 90 | for future in tqdm(as_completed(future_to_task), total=len(future_to_task), desc="Processing Generations"): 91 | idx = future_to_task[future] 92 | response_entry = future.result() 93 | total_correct += response_entry["correctness"] 94 | total_finish += 1 95 | 96 | problem_key = remaining_data[idx][handler.get_question_key()] 97 | if problem_key not in results: 98 | results[problem_key] = remaining_data[idx] 99 | if isinstance(handler, NUMINATaskHandler): 100 | results[problem_key]["messages"] = "" 101 | results[problem_key]["responses"] = {} 102 | results[problem_key]["token_usages"] = {} 103 | prompt = conversations[idx][1]["content"] 104 | results[problem_key]["prompt"] = prompt 105 | results[problem_key]["input_conversation"] = conversations[idx] 106 | 107 | results[problem_key]["responses"][str(temp)] = response_entry 108 | 109 | if args.model.startswith("openai"): 110 | results[problem_key]["token_usages"][str(temp)] = { 111 | "completion_tokens": token_usages[idx].completion_tokens, 112 | "prompt_tokens": token_usages[idx].prompt_tokens, 113 | } 114 | else: 115 | # TODO: vLLM model, can it do the same thing 116 | results[problem_key]["token_usages"][str(temp)] = token_usages[idx] 117 | 118 | print(f"Final acc: {total_correct}/{total_finish}") 119 | acc = round(total_correct / total_finish, 4) if total_finish > 0 else 0 120 | print(json.dumps({"acc": acc})) 121 | 122 | completion_tokens = [ 123 | results[key].get("token_usages", {}).get(str(temp), {}).get("completion_tokens", 0) 124 | for key in results for temp in temperatures 125 | ] 126 | prompt_tokens = [ 127 | results[key].get("token_usages", {}).get(str(temp), {}).get("prompt_tokens", 0) 128 | for key in results for temp in temperatures 129 | ] 130 | 131 | # Token usage summary 132 | result_dir, result_name = os.path.split(result_file) 133 | token_usage_dir = os.path.join(result_dir, "token_usage") 134 | os.makedirs(token_usage_dir, exist_ok=True) 135 | 136 | # Construct the token usage result file path 137 | token_usage_result_file = os.path.join(token_usage_dir, result_name) 138 | 139 | # Prepare the token usage dictionary 140 | token_dict = { 141 | "completion_tokens": sum(completion_tokens), 142 | "prompt_tokens": sum(prompt_tokens), 143 | "avg_completion_tokens": round(sum(completion_tokens) / len(completion_tokens), 3) if completion_tokens else 0, 144 | 
"avg_prompt_tokens": round(sum(prompt_tokens) / len(prompt_tokens), 3) if prompt_tokens else 0, 145 | } 146 | 147 | # Save the token usage dictionary to the result file 148 | with open(token_usage_result_file, "w") as f: 149 | json.dump(token_dict, f, indent=4) 150 | 151 | print(f"Token usage saved to {token_usage_result_file}") 152 | 153 | with open(result_file, 'w', encoding='utf-8') as file: 154 | json.dump(results, file, ensure_ascii=False, indent=4, cls=NumpyEncoder) 155 | 156 | def perform_check(handler: TaskHandler, temperatures, result_file, args): 157 | results = handler.load_existing_results(result_file) 158 | print(f"Loaded {len(results)} existing results.") 159 | 160 | train_data = handler.load_and_filter_dataset(args.start, args.end, split=args.split, source=args.source, \ 161 | filter_difficulty=args.filter_difficulty, args=args) 162 | remaining_data = handler.process_remaining_data(train_data, {}) 163 | 164 | tasks = [] 165 | for item in remaining_data: 166 | problem_key = item[handler.get_question_key()] 167 | # If this item exists in the results file, check each temperature 168 | if problem_key in results and "responses" in results[problem_key]: 169 | for temp in temperatures: 170 | if str(temp) in results[problem_key]["responses"]: 171 | response_entries = results[problem_key]["responses"][str(temp)] 172 | for sample_id, response_entry in enumerate(response_entries): 173 | if sample_id > (args.n - 1): continue 174 | if True or response_entry["correctness"] is None: 175 | processed = "processed_content" in response_entry 176 | tasks.append((item, temp, response_entry["processed_content"] if processed else response_entry["content"], sample_id)) 177 | 178 | print(f"Found {len(tasks)} responses requiring reject sampling...") 179 | 180 | total_correct = 0 181 | total_finish = 0 182 | correct = { temp: {} for temp in temperatures } 183 | with ProcessPoolExecutor(max_workers=32) as executor: 184 | future_to_task = { 185 | executor.submit(handler.update_results, item, content): (item, temp, sample_id) 186 | for (item, temp, content, sample_id) in tasks 187 | } 188 | 189 | # 4. Collect the results as they finish. 
190 | for future in tqdm(as_completed(future_to_task), total=len(future_to_task), desc="Processing Reject Sampling"): 191 | item, temp, sample_id = future_to_task[future] 192 | new_response_entry = future.result() 193 | total_correct += new_response_entry["correctness"] 194 | total_finish += 1 195 | 196 | # Update the corresponding record in results 197 | problem_key = item[handler.get_question_key()] 198 | if problem_key not in correct[temp]: 199 | correct[temp][problem_key] = False 200 | if new_response_entry["correctness"]: 201 | correct[temp][problem_key] = True 202 | assert problem_key in results and "responses" in results[problem_key] and str(temp) in results[problem_key]["responses"] 203 | response_entry = results[problem_key]["responses"][str(temp)][sample_id] 204 | response_entry["correctness"] = new_response_entry["correctness"] 205 | response_entry["reason"] = new_response_entry["reason"] 206 | results[problem_key]["responses"][str(temp)][sample_id] = response_entry 207 | 208 | print(f"Final reject-sampling accuracy: {total_correct}/{total_finish}") 209 | # per temperature acc 210 | for temp in temperatures: 211 | temp_correct = sum(correct[temp].values()) 212 | temp_total = len(correct[temp]) 213 | temp_acc = round(temp_correct / temp_total, 4) if temp_total > 0 else 0 214 | print(f"Temperature {temp} acc: {temp_correct}/{temp_total} ({temp_acc})") 215 | 216 | with open(result_file, 'w', encoding='utf-8') as file: 217 | json.dump(results, file, ensure_ascii=False, indent=4, cls=NumpyEncoder) 218 | 219 | def perform_inference_and_save(handler: TaskHandler, temperatures, max_tokens, result_file, llm, system_prompt, args): 220 | print(system_prompt) 221 | results = handler.load_existing_results(result_file) 222 | print(f"Loaded {len(results)} existing results.") 223 | train_data = handler.load_and_filter_dataset(args.start, args.end, split=args.split, source=args.source, \ 224 | filter_difficulty=args.filter_difficulty, args=args) 225 | remaining_data = handler.process_remaining_data(train_data, results) 226 | conversations = handler.make_conversations(remaining_data, system_prompt, args.model) 227 | 228 | for temp in temperatures: 229 | if args.model.startswith("openai"): 230 | fetch_partial = partial(fetch_response_openai, llm, args.model, max_tokens, temp) 231 | 232 | with concurrent.futures.ThreadPoolExecutor(max_workers=16) as e: 233 | responses = list(e.map(fetch_partial, conversations)) 234 | 235 | else: 236 | sampling_params = SamplingParams(n=args.n, max_tokens=max_tokens, temperature=temp) 237 | responses = llm.chat(messages=conversations, sampling_params=sampling_params, use_tqdm=True) 238 | 239 | completion_tokens = [] 240 | prompt_tokens = [] 241 | for idx, response in enumerate(responses): 242 | response_entries = [] 243 | token_usages = [] 244 | completion_token = 0 245 | for sample_idx in range(args.n): 246 | response_entry = { 247 | "content": response.choices[0].message.content.strip() if args.model.startswith("openai") else response.outputs[sample_idx].text.strip(), 248 | "correctness": None, 249 | "reason": None, 250 | } 251 | response_entries.append(response_entry) 252 | if not args.model.startswith("openai"): 253 | token_usages.append({ 254 | "completion_tokens": len(response.outputs[sample_idx].token_ids), 255 | "prompt_tokens": len(response.prompt_token_ids) 256 | }) 257 | completion_token += len(response.outputs[sample_idx].token_ids) 258 | completion_token /= args.n 259 | prompt_token = len(response.prompt_token_ids) 260 | 
prompt_tokens.append(prompt_token) 261 | completion_tokens.append(completion_token) 262 | 263 | problem_key = remaining_data[idx][handler.get_question_key()] # can you use this idx 264 | if problem_key not in results: 265 | results[problem_key] = remaining_data[idx] 266 | if isinstance(handler, NUMINATaskHandler): 267 | results[problem_key]["messages"] = "" 268 | results[problem_key]["responses"] = {} 269 | results[problem_key]["token_usages"] = {} 270 | prompt = conversations[idx][1]["content"] 271 | results[problem_key]["prompt"] = prompt 272 | 273 | results[problem_key]["responses"][str(temp)] = response_entries 274 | 275 | if args.model.startswith("openai"): 276 | results[problem_key]["token_usages"][str(temp)] = { 277 | "completion_tokens": response.usage.completion_tokens, 278 | "prompt_tokens": response.usage.prompt_tokens, 279 | } 280 | else: 281 | results[problem_key]["token_usages"][str(temp)] = token_usages 282 | 283 | # Token usage summary put into another subdirectory 284 | result_dir, result_name = os.path.split(result_file) 285 | token_usage_dir = os.path.join(result_dir, "token_usage") 286 | os.makedirs(token_usage_dir, exist_ok=True) 287 | 288 | # Construct the token usage result file path 289 | token_usage_result_file = os.path.join(token_usage_dir, result_name) 290 | 291 | # Prepare the token usage dictionary 292 | token_dict = { 293 | "completion_tokens": sum(completion_tokens), 294 | "prompt_tokens": sum(prompt_tokens), 295 | "avg_completion_tokens": round(sum(completion_tokens) / len(completion_tokens), 3) if completion_tokens else 0, 296 | "avg_prompt_tokens": round(sum(prompt_tokens) / len(prompt_tokens), 3) if prompt_tokens else 0, 297 | } 298 | 299 | # Save the token usage dictionary to the result file 300 | with open(token_usage_result_file, "w") as f: 301 | json.dump(token_dict, f, indent=4) 302 | 303 | print(f"Token usage saved to {token_usage_result_file}") 304 | 305 | with open(result_file, 'w', encoding='utf-8') as file: 306 | json.dump(results, file, ensure_ascii=False, indent=4, cls=NumpyEncoder) 307 | 308 | def main(): 309 | parser = argparse.ArgumentParser(description="Unified inference and checking for different datasets/tasks.") 310 | parser.add_argument("--dataset", type=str, required=True, choices=list(TASK_HANDLERS.keys()), help="Dataset to process.") 311 | parser.add_argument("--model", type=str, required=True, default="Qwen/QwQ-32B-Preview", help="The model to run.") 312 | parser.add_argument("--tp", type=int, default=8, help="Tensor Parallelism Degree") 313 | parser.add_argument("--max_tokens", type=int, default=32768, help="Max tokens for the model.") 314 | parser.add_argument("--split", type=str, default="train", help="Split to use for apps (e.g., train, test).") 315 | parser.add_argument("--source", type=str, help="Source for the dataset.") 316 | parser.add_argument("--start", type=int, default=0, help="Start index.") 317 | parser.add_argument("--end", type=int, default=-1, help="End index.") 318 | parser.add_argument("--filter-difficulty", action="store_true", help="Filter difficulty.") 319 | parser.add_argument("--result-dir", type=str, default="./", help="Result dir to save files.") 320 | parser.add_argument("--check", action="store_true", help="Perform evaluation checks on generated samples.") 321 | parser.add_argument("--inference", action="store_true", help="Perform inference.") 322 | parser.add_argument("--temperatures", type=float, nargs="+", default=[0], help="Temperature for sampling.") 323 | 
parser.add_argument("--math-difficulty-lower-bound", type=int, default=None, help="Lowest difficulty level for math.") 324 | parser.add_argument("--math-difficulty-upper-bound", type=int, default=None, help="Highest difficulty level for math.") 325 | parser.add_argument("--n", type=int, default=1, help="Number of samples generated per problem.") 326 | args = parser.parse_args() 327 | 328 | handler: TaskHandler = TASK_HANDLERS[args.dataset]() 329 | temperatures = [1] if args.model.startswith("openai/o1") else args.temperatures 330 | 331 | print(f"Temperature: {temperatures}") 332 | max_tokens = args.max_tokens 333 | if temperatures == [0] and args.n > 1: 334 | args.n = 1 335 | print("Warning: Temperature 0 does not support multiple samples. Setting n=1.") 336 | 337 | # create result dir if not exists 338 | if args.result_dir and not os.path.exists(args.result_dir): 339 | os.makedirs(args.result_dir) 340 | if args.math_difficulty_lower_bound is not None or args.math_difficulty_upper_bound is not None: 341 | result_file = os.path.join(args.result_dir, f"{get_MODEL_TO_NAME(args.model)}_{args.dataset}_{args.split}_{args.source}_{args.start}_{args.end}_{args.math_difficulty_lower_bound}_{args.math_difficulty_upper_bound}.json") 342 | else: 343 | result_file = os.path.join(args.result_dir, f"{get_MODEL_TO_NAME(args.model)}_{args.dataset}_{args.split}_{args.source}_{args.start}_{args.end}.json") 344 | 345 | if args.check: 346 | # check if converted file exists 347 | if args.math_difficulty_lower_bound is not None or args.math_difficulty_upper_bound is not None: 348 | converted_file = f"{args.result_dir}/converted_{get_MODEL_TO_NAME(args.model)}_{args.dataset}_{args.split}_{args.source}_{args.start}_{args.end}_{args.math_difficulty_lower_bound}_{args.math_difficulty_upper_bound}.json" 349 | else: 350 | converted_file = f"{args.result_dir}/converted_{get_MODEL_TO_NAME(args.model)}_{args.dataset}_{args.split}_{args.source}_{args.start}_{args.end}.json" 351 | if os.path.exists(converted_file): 352 | result_file = converted_file 353 | perform_check(handler, temperatures, result_file, args) 354 | return 355 | elif args.inference: 356 | llm = OpenAI() if args.model.startswith("openai") else LLM(model=args.model, tensor_parallel_size=args.tp) 357 | system_prompt = SYSTEM_PROMPT.get(args.model, SYSTEM_PROMPT['NovaSky-AI/Sky-T1-32B-Preview']) 358 | perform_inference_and_save(handler, temperatures, max_tokens, result_file, llm, system_prompt, args) 359 | return 360 | 361 | llm = OpenAI() if args.model.startswith("openai") else LLM(model=args.model, tensor_parallel_size=args.tp) 362 | system_prompt = SYSTEM_PROMPT.get(args.model, SYSTEM_PROMPT['NovaSky-AI/Sky-T1-32B-Preview']) 363 | perform_inference_and_check(handler, temperatures, max_tokens, result_file, llm, system_prompt, args) 364 | 365 | if __name__ == "__main__": 366 | main() 367 | -------------------------------------------------------------------------------- /tools/label_math_difficulty.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from tqdm import tqdm 4 | import multiprocessing as mp 5 | import openai 6 | from itertools import cycle 7 | import time 8 | import os 9 | from datasets import load_dataset 10 | import re 11 | import ast 12 | from util.prompts import grading_prompt, aops_criteria 13 | 14 | # Function to set the OpenAI API key 15 | def set_openai_key(api_key): 16 | openai.api_key = api_key 17 | 18 | # From FastChat 19 | def find_difficulty(judgment): 20 | 
one_score_pattern = re.compile("\[\[(\d+\.?\d*)\]\]") 21 | one_score_pattern_backup = re.compile("\[(\d+\.?\d*)\]") 22 | match = re.search(one_score_pattern, judgment) 23 | if not match: 24 | match = re.search(one_score_pattern_backup, judgment) 25 | 26 | if match: 27 | rating = ast.literal_eval(match.groups()[0]) 28 | else: 29 | rating = -1 30 | 31 | return rating 32 | 33 | # GPT API processing function with retry logic 34 | def process_content(problem, api_key): 35 | # Set the OpenAI key for this request 36 | set_openai_key(api_key) 37 | 38 | # GPT prompt 39 | prompt = grading_prompt.format(problem=problem, aops_criteria=aops_criteria) 40 | retries = 3 41 | while retries > 0: 42 | try: 43 | # OpenAI API call 44 | response = openai.chat.completions.create( 45 | model="gpt-4o-mini", 46 | messages=[ 47 | {"role": "system", "content": "You are a math problem difficulty labeler."}, 48 | {"role": "user", "content": prompt} 49 | ], 50 | max_tokens=2048, 51 | temperature=0.7 52 | ) 53 | return response.choices[0].message.content 54 | except openai.RateLimitError: 55 | retries -= 1 56 | if retries == 0: 57 | return "Error: Rate limit reached and retries exhausted." 58 | print(f"Sleep for 5 seconds for API limit.") 59 | time.sleep(5) 60 | except Exception as e: 61 | return f"Error processing content: {e}" 62 | 63 | def process_entry(entry, api_key_cycle): 64 | # Get the next API key from the cycle 65 | api_key = next(api_key_cycle) 66 | 67 | # Pass only entry["problem"] to the process_content function 68 | processed = process_content(entry["problem"], api_key) 69 | 70 | # Store the processed content in the responses 71 | entry["messages"] = "" 72 | entry["gpt_difficulty"] = processed 73 | entry["gpt_difficulty_parsed"] = find_difficulty(processed) 74 | return entry 75 | 76 | # Wrapper function for multiprocessing 77 | def process_entry_wrapper(args): 78 | return process_entry(*args) 79 | 80 | if __name__ == "__main__": 81 | parser = argparse.ArgumentParser(description="Label difficulty") 82 | parser.add_argument("--source", type=str, help="") 83 | parser.add_argument("--start", type=int, default=0, help="") 84 | parser.add_argument("--end", type=int, default=-1, help="") 85 | parser.add_argument("--keys", type=str, help="File containing OpenAI API keys (one per line).") 86 | args = parser.parse_args() 87 | 88 | dataset = load_dataset("AI-MO/NuminaMath-CoT") 89 | data = ( 90 | dataset["train"] 91 | .to_pandas() 92 | .query('source == @args.source') 93 | .iloc[args.start:args.end] 94 | ) 95 | 96 | data = data.to_dict(orient="records") 97 | 98 | # Load API keys and prepare a round-robin cycle 99 | with open(args.keys, "r") as f: 100 | api_keys = [line.strip() for line in f if line.strip()] 101 | api_key_cycle = cycle(api_keys) 102 | 103 | # Prepare output file 104 | output_file = f"labeled_{args.source}_{args.start}_{args.end}.json" 105 | 106 | # Use multiprocessing to process the content 107 | results = [] 108 | with mp.Pool(os.cpu_count()) as pool: 109 | tasks = [(entry, api_key_cycle) for entry in data] 110 | for result in tqdm(pool.imap(process_entry_wrapper, tasks), total=len(data)): 111 | results.append(result) 112 | 113 | # Aggregate and write results in the main process 114 | # aggregated_data = {key: values for key, values in results} 115 | # print(results) 116 | with open(output_file, "w") as f: 117 | json.dump(results, f, indent=4) 118 | 119 | print(f"Processed data saved to {output_file}") 120 | -------------------------------------------------------------------------------- 
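The labeled output can then be filtered by the parsed rating before downstream use. A minimal sketch; the file name follows the script's default output pattern (source `math`, start `0`, end `-1`) and the threshold is purely illustrative:

```python
import json

# Load the labeled output produced by label_math_difficulty.py.
with open("labeled_math_0_-1.json") as f:
    labeled = json.load(f)

# Keep only problems whose parsed AoPS-style difficulty rating is at least 6 (illustrative threshold).
hard = [p for p in labeled if p.get("gpt_difficulty_parsed", -1) >= 6]
print(f"Kept {len(hard)} of {len(labeled)} problems")
```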
/tools/labeled_numina_difficulty/README.md: -------------------------------------------------------------------------------- 1 | # Labeled NUMINA Difficulty Data 2 | 3 | We also include data of labeled difficulty from NUMINA, in the following files: `labeled_amc_aime_0_-1.json`, `labeled_math_0_-1.json`, `labeled_olympiads_0_-1.json`. These files can be found and downloaded from [HuggingFace](https://huggingface.co/datasets/NovaSky-AI/labeled_numina_difficulty). -------------------------------------------------------------------------------- /tools/requirements.txt: -------------------------------------------------------------------------------- 1 | vllm==0.6.2 2 | pyext 3 | word2number 4 | scipy 5 | datasets 6 | latex2sympy2 -------------------------------------------------------------------------------- /tools/response_rewrite.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | from tqdm import tqdm 6 | from util.math.testing_util import strip_answer_string 7 | from util.model_utils import * 8 | from vllm import LLM, SamplingParams 9 | 10 | def load_dataset(dataset_path : str): 11 | data = {} 12 | with open(dataset_path, 'r', encoding='utf-8') as file: 13 | data = json.load(file) 14 | return data 15 | 16 | 17 | def make_scoring_conversations(dataset, system_prompt): 18 | conversations = [] 19 | for _, key in enumerate(dataset): 20 | problem = dataset[key] 21 | gt_answer = strip_answer_string(problem["answer"]) 22 | for response_key in problem["responses"]: 23 | response = problem["responses"][response_key]["content"] 24 | prompt_text = response + "\n#####\nThe ground truth answer is " + gt_answer 25 | conversations.append([ 26 | {"role": "system", "content": system_prompt}, 27 | {"role": "user", "content": prompt_text} 28 | ]) 29 | 30 | return conversations 31 | 32 | 33 | def score_solutions(dataset, responses, outfile): 34 | idx = 0 35 | for _, key in tqdm(enumerate(dataset), total=len(dataset), desc="Scoring original solutions"): 36 | problem = dataset[key] 37 | for response_key in problem["responses"]: 38 | score = responses[idx].outputs[0].text.strip() 39 | problem["responses"][response_key]["correctness"] = (score == "True") 40 | idx += 1 41 | 42 | with open(outfile, 'w', encoding='utf-8') as new_file: 43 | json.dump(dataset, new_file, ensure_ascii=False, indent=2) 44 | return dataset 45 | 46 | 47 | def filter_solutions(dataset): 48 | # First filter out incorrect responses. 49 | for key in dataset: 50 | problem = dataset[key] 51 | keys_to_filter = [] 52 | for response_key in problem["responses"]: 53 | if not problem["responses"][response_key]["correctness"]: 54 | keys_to_filter.append(response_key) 55 | for k in keys_to_filter: 56 | del problem["responses"][k] 57 | del problem["token_usages"][k] 58 | 59 | # Next, filter out examples with <2 correct responses. 60 | keys_to_filter = [] 61 | for key in dataset: 62 | problem = dataset[key] 63 | if len(problem["responses"]) < 2: 64 | keys_to_filter.append(key) 65 | for k in keys_to_filter: 66 | del dataset[k] 67 | 68 | # Finally, filter for the shortest and longest solutions for each sample. 
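    # After this step each surviving problem keeps exactly two responses, keyed
    # "shortest" and "longest"; the long one later serves as the rejected side of the
    # preference pairs built in make_preference_conversations.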
69 | for key in dataset: 70 | problem = dataset[key] 71 | token_usages = problem["token_usages"] 72 | shortest_key, shortest_entry = min(token_usages.items(), key=lambda x: x[1]["completion_tokens"]) 73 | longest_key, longest_entry = max(token_usages.items(), key=lambda x: x[1]["completion_tokens"]) 74 | problem["token_usages"] = { 75 | "shortest": shortest_entry, 76 | "longest": longest_entry, 77 | } 78 | new_responses = { 79 | "shortest": problem["responses"][shortest_key], 80 | "longest":problem["responses"][longest_key], 81 | } 82 | problem["responses"] = new_responses 83 | 84 | return dataset 85 | 86 | 87 | def make_splitting_conversations(data, system_prompt): 88 | conversations = [] 89 | for problem in data: 90 | response = data[problem]["responses"]["shortest"] 91 | prompt_text = response["content"] 92 | conversations.append([ 93 | {"role": "system", "content": system_prompt}, 94 | {"role": "user", "content": prompt_text} 95 | ]) 96 | return conversations 97 | 98 | 99 | def split_solutions(dataset, responses, delimiter): 100 | outputs = [] 101 | for _, response in tqdm(enumerate(responses), total=len(responses), desc="Splitting responses"): 102 | content = response.outputs[0].text.strip() 103 | # Split response by configured delimiter. 104 | split_content = content.split(delimiter) 105 | split_content = [x.strip() for x in split_content if x != ""] 106 | outputs.append(split_content) 107 | for idx, key in enumerate(dataset): 108 | solutions = outputs[idx] 109 | problem = dataset[key] 110 | problem["responses"]["shortest"]["subsolutions"] = solutions 111 | return dataset 112 | 113 | 114 | def make_subscoring_conversations(dataset, system_prompt): 115 | conversations = [] 116 | for _, key in enumerate(dataset): 117 | problem = dataset[key] 118 | gt_answer = strip_answer_string(problem["answer"]) 119 | subsolutions = problem["responses"]["shortest"]["subsolutions"] 120 | for sub in subsolutions: 121 | prompt_text = sub + "\n#####\nThe ground truth answer is " + gt_answer 122 | conversations.append([ 123 | {"role": "system", "content": system_prompt}, 124 | {"role": "user", "content": prompt_text} 125 | ]) 126 | return conversations 127 | 128 | 129 | def score_subsolutions(dataset, responses): 130 | idx = 0 131 | for _, key in tqdm(enumerate(dataset), total=len(dataset), desc="Scoring sub-solutions"): 132 | problem = dataset[key] 133 | subsolutions = problem["responses"]["shortest"]["subsolutions"] 134 | scores = [] 135 | for _, sub in enumerate(subsolutions): 136 | score = responses[idx].outputs[0].text.strip() 137 | scores.append(score == "True") 138 | idx += 1 139 | problem["responses"]["shortest"]["scores"] = scores 140 | return dataset 141 | 142 | 143 | def build_response_variants(dataset): 144 | def clean_response_string(response): 145 | if '<|end_of_thought|>' not in response: 146 | response += '<|end_of_thought|>' 147 | return response 148 | 149 | keys_to_remove = [] 150 | 151 | for key, problem in dataset.items(): 152 | scores = problem["responses"]["shortest"]["scores"] 153 | subsolutions = problem["responses"]["shortest"]["subsolutions"] 154 | 155 | # Check if there are valid scores 156 | if True not in scores: 157 | keys_to_remove.append(key) 158 | continue 159 | 160 | # Build FCS (First Correct Solution) 161 | fcs_idx = scores.index(True) 162 | fcs_response = "\n".join(subsolutions[:fcs_idx + 1]) if fcs_idx < len(scores) - 1 else "\n".join(subsolutions[:-1]) 163 | fcs_response = clean_response_string(fcs_response) + "\n" + subsolutions[-1] 164 | problem["responses"]["fcs"] 
= fcs_response 165 | 166 | # Build FCS + 1 167 | fcs_plus1_idx = fcs_idx + 1 if fcs_idx + 1 < len(subsolutions) - 1 else fcs_idx 168 | fcs_plus1_response = "\n".join(subsolutions[:fcs_plus1_idx + 1]) 169 | fcs_plus1_response = clean_response_string(fcs_plus1_response) + "\n" + subsolutions[-1] 170 | problem["responses"]["fcs_plus1"] = fcs_plus1_response 171 | 172 | # Check if there are valid scores 173 | if True not in scores[fcs_idx + 1:]: 174 | keys_to_remove.append(key) 175 | continue 176 | 177 | # Build FCS + Reflection 178 | fcs_reflection_idx = scores.index(True, fcs_idx + 1) 179 | fcs_reflection_response = "\n".join(subsolutions[:fcs_reflection_idx + 1]) if fcs_reflection_idx < len(scores) - 1 else "\n".join(subsolutions[:-1]) 180 | fcs_reflection_response = clean_response_string(fcs_reflection_response) + "\n" + subsolutions[-1] 181 | problem["responses"]["fcs_reflection"] = fcs_reflection_response 182 | 183 | # Remove problems without valid sub-solutions 184 | for key in keys_to_remove: 185 | del dataset[key] 186 | 187 | return dataset 188 | 189 | 190 | def compute_token_usages(dataset, variants, llm): 191 | tokenizer = llm.get_tokenizer() 192 | for key in tqdm(dataset, desc="Computing token usages", total=len(dataset)): 193 | problem = dataset[key] 194 | prompt_tokens = problem["token_usages"]["shortest"]["prompt_tokens"] 195 | for variant in variants: 196 | problem["token_usages"][variant] = { 197 | "prompt_tokens": prompt_tokens, 198 | "completion_tokens": len(tokenizer(problem["responses"][variant]).input_ids) 199 | } 200 | return dataset 201 | 202 | 203 | def build_question_prompt(prompt): 204 | return "Return your final response within \\boxed{{}}" + prompt 205 | 206 | def make_preference_conversations(final_dataset, format, system_prompt): 207 | conversations = [] 208 | for prompt in final_dataset: 209 | problem = final_dataset[prompt] 210 | convo = {} 211 | convo["conversations"] = [ 212 | { 213 | "from": "system", 214 | "value": system_prompt, 215 | }, 216 | { 217 | "from": "human", 218 | "value": build_question_prompt(prompt), 219 | } 220 | ] 221 | convo["chosen"] = { 222 | "from": "gpt", 223 | "value": problem["responses"][format], 224 | } 225 | convo["rejected"] = { 226 | "from": "gpt", 227 | "value": problem["responses"]["longest"]["content"] 228 | } 229 | conversations.append(convo) 230 | 231 | return conversations 232 | 233 | 234 | def make_SILC_conversations(dataset, system_prompt): 235 | keys_to_filter = [] 236 | for prompt in dataset: 237 | problem = dataset[prompt] 238 | contition = False 239 | for response_key in problem["responses"]: 240 | if not problem["responses"][response_key]['correctness']: 241 | wrong_length = problem["token_usages"][response_key]['completion_tokens'] 242 | for k in problem["responses"]: 243 | if k != response_key and problem["token_usages"][k]['completion_tokens'] > wrong_length and problem["responses"][k]['correctness']: 244 | contition = True 245 | break 246 | break 247 | if not contition: 248 | keys_to_filter.append(prompt) 249 | 250 | for key in keys_to_filter: 251 | del dataset[key] 252 | 253 | # Build contrastive pairs out of {short incorrect, long correct} 254 | conversations = [] 255 | for prompt in dataset: 256 | problem = dataset[prompt] 257 | 258 | shortest_incorrect_key = None 259 | shortest_incorrect_length = float('inf') 260 | 261 | # Get shortest incorrect. 
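        # The shortest response marked incorrect becomes the "rejected" side of the SILC pair.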
262 | for response_key in problem["responses"]: 263 | if not problem["responses"][response_key]['correctness']: 264 | length = problem["token_usages"][response_key]['completion_tokens'] 265 | if length < shortest_incorrect_length: 266 | shortest_incorrect_length = length 267 | shortest_incorrect_key = response_key 268 | 269 | # Get next longest correct. 270 | shortest_correct_longer_key = None 271 | shortest_correct_longer_length = float('inf') 272 | for response_key in problem["responses"]: 273 | if problem["responses"][response_key]['correctness']: 274 | length = problem["token_usages"][response_key]['completion_tokens'] 275 | if length > shortest_incorrect_length and length < shortest_correct_longer_length: 276 | shortest_correct_longer_length = length 277 | shortest_correct_longer_key = response_key 278 | 279 | convo = {} 280 | convo["conversations"] = [ 281 | { 282 | "from": "system", 283 | "value": system_prompt, 284 | }, 285 | { 286 | "from": "human", 287 | "value": build_question_prompt(prompt), 288 | } 289 | ] 290 | convo["chosen"] = { 291 | "from": "gpt", 292 | "value": problem["responses"][shortest_correct_longer_key]['content'], 293 | } 294 | convo["rejected"] = { 295 | "from": "gpt", 296 | "value": problem["responses"][shortest_incorrect_key]["content"] 297 | } 298 | conversations.append(convo) 299 | 300 | return conversations 301 | 302 | 303 | def main(): 304 | parser = argparse.ArgumentParser(description="Filter, rewrite, and format generated responses for high-quality data curation.") 305 | parser.add_argument("--rewrite-model", type=str, required=True, default="meta-llama/Llama-3.3-70B-Instruct", help="The model used for response processing.") 306 | parser.add_argument("--target-model", type=str, required=True, default="NovaSky-AI/Sky-T1-32B-Preview", help="The target model the rewritten responses will be used to train.") 307 | parser.add_argument("--dataset", type=str, required=True, help="Path to the starting dataset of generated responses to filter from.") 308 | parser.add_argument("--result-dir", type=str, default="./", help="Result directory to save processed data.") 309 | parser.add_argument("--checkpoint", action="store_true", help="Whether to checkpoint the dataset at each step.") 310 | parser.add_argument("--SILC", action="store_true", help="Whether to include short-incorrect/long-correct (SILC) preference pairs.") 311 | parser.add_argument("--tp", type=int, default=8, help="Tensor Parallelism Degree") 312 | parser.add_argument("--max_tokens", type=int, default=32768, help="Max tokens for the model.") 313 | args = parser.parse_args() 314 | 315 | if args.result_dir and not os.path.exists(args.result_dir): 316 | os.makedirs(args.result_dir) 317 | 318 | # Initialize model for data processing. 319 | llm = LLM(model=args.rewrite_model, tensor_parallel_size=args.tp) 320 | sampling_params = SamplingParams(max_tokens=args.max_tokens) 321 | 322 | original_dataset = load_dataset(args.dataset) 323 | 324 | # Filter for the shortest and longest correct solutions. 325 | filtered_dataset = filter_solutions(original_dataset) 326 | if args.checkpoint: 327 | outfile = os.path.join(args.result_dir, f"filtered-responses.json") 328 | with open(outfile, 'w', encoding='utf-8') as new_file: 329 | json.dump(filtered_dataset, new_file, ensure_ascii=False, indent=2) 330 | 331 | # Split the shortest solution into subsolutions using the configured model. 
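    # The rewrite model's output is expected to contain '#####' between sub-solutions;
    # split_solutions then splits on that delimiter below.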
332 | conversations = make_splitting_conversations(filtered_dataset, SUBPROBLEM_SPLIT_PROMPT) 333 | responses = llm.chat(messages=conversations, sampling_params=sampling_params, use_tqdm=True) 334 | split_dataset = split_solutions(filtered_dataset, responses, '#####') 335 | if args.checkpoint: 336 | outfile = os.path.join(args.result_dir, f"split-solutions.json") 337 | with open(outfile, 'w', encoding='utf-8') as new_file: 338 | json.dump(split_dataset, new_file, ensure_ascii=False, indent=2) 339 | 340 | # Score the subsolutions using the configured model. 341 | subscoring_conversations = make_subscoring_conversations(split_dataset, SUBSOLUTION_EXTRACTION_PROMPT) 342 | responses = llm.chat(messages=subscoring_conversations, sampling_params=sampling_params, use_tqdm=True) 343 | scored_dataset = score_subsolutions(split_dataset, responses) 344 | if args.checkpoint: 345 | outfile = os.path.join(args.result_dir, f"scored-subsolutions.json") 346 | with open(outfile, 'w', encoding='utf-8') as new_file: 347 | json.dump(scored_dataset, new_file, ensure_ascii=False, indent=2) 348 | 349 | # Rewrite response based on variants of combining sub-solutions. Here are examples for 350 | # FCS, FCS+1, and FCS+Reflection. 351 | variants_dataset = build_response_variants(scored_dataset) 352 | if args.checkpoint: 353 | outfile = os.path.join(args.result_dir, f"response-variants.json") 354 | with open(outfile, 'w', encoding='utf-8') as new_file: 355 | json.dump(variants_dataset, new_file, ensure_ascii=False, indent=2) 356 | 357 | # Add per-variant token counts to dataset for convenience. 358 | final_dataset = compute_token_usages(variants_dataset, ["fcs", "fcs_plus1", "fcs_reflection"], llm) 359 | 360 | system_prompt = SYSTEM_PROMPT[args.target_model] 361 | 362 | # Generate conversation format for each variant, which can be used in SimPO/DPO/etc. 363 | fcs_convo = make_preference_conversations(final_dataset, "fcs", system_prompt) 364 | fcs_plus1_convo = make_preference_conversations(final_dataset, "fcs_plus1", system_prompt) 365 | fcs_reflection_convo = make_preference_conversations(final_dataset, "fcs_reflection", system_prompt) 366 | 367 | # Optionall add short incorrect, long correct (SILC) conversations 368 | if args.SILC: 369 | short_incorrect_long_correct_conversations = make_SILC_conversations(load_dataset(args.dataset), system_prompt) 370 | for convo in [fcs_convo, fcs_plus1_convo, fcs_reflection_convo]: 371 | convo += short_incorrect_long_correct_conversations 372 | random.shuffle(convo) 373 | 374 | # Save final conversation variants. 
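    # One JSON file is written per variant (FCS, FCS+1, FCS+Reflection); each holds a list of
    # preference conversations that can be used directly for DPO/SimPO-style training.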
375 | fcs_outfile = os.path.join(args.result_dir, "fcs-conversations.json") 376 | with open(fcs_outfile, 'w', encoding='utf-8') as new_file: 377 | json.dump(fcs_convo, new_file, ensure_ascii=False, indent=2) 378 | 379 | fcs_plus1_outfile = os.path.join(args.result_dir, "fcs_plus1-conversations.json") 380 | with open(fcs_plus1_outfile, 'w', encoding='utf-8') as new_file: 381 | json.dump(fcs_plus1_convo, new_file, ensure_ascii=False, indent=2) 382 | 383 | fcs_reflection_outfile = os.path.join(args.result_dir, "fcs_reflection-conversations.json") 384 | with open(fcs_reflection_outfile, 'w', encoding='utf-8') as new_file: 385 | json.dump(fcs_reflection_convo, new_file, ensure_ascii=False, indent=2) 386 | 387 | 388 | if __name__ == "__main__": 389 | main() 390 | -------------------------------------------------------------------------------- /tools/upload_hub.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/lm-sys/FastChat/ 3 | Upload weights to huggingface. 4 | 5 | Usage: 6 | python upload_hub.py --model-path ~/model_weights/Sky-T1 --hub-repo-id NovaSky-AI/Sky-T1 --private 7 | """ 8 | import argparse 9 | import tempfile 10 | 11 | import torch 12 | from transformers import AutoTokenizer, AutoModelForCausalLM 13 | 14 | def upload_hub(model_path, hub_repo_id, component, private): 15 | if component == "all": 16 | components = ["model", "tokenizer"] 17 | else: 18 | components = [component] 19 | 20 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id, "private": args.private} 21 | 22 | if "model" in components: 23 | model = AutoModelForCausalLM.from_pretrained( 24 | model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 25 | ) 26 | with tempfile.TemporaryDirectory() as tmp_path: 27 | model.save_pretrained(tmp_path, **kwargs) 28 | 29 | if "tokenizer" in components: 30 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 31 | with tempfile.TemporaryDirectory() as tmp_path: 32 | tokenizer.save_pretrained(tmp_path, **kwargs) 33 | 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--model-path", type=str, required=True) 38 | parser.add_argument("--hub-repo-id", type=str, required=True) 39 | parser.add_argument( 40 | "--component", type=str, choices=["all", "model", "tokenizer"], default="all" 41 | ) 42 | parser.add_argument("--private", action="store_true") 43 | args = parser.parse_args() 44 | 45 | upload_hub(args.model_path, args.hub_repo_id, args.component, args.private) -------------------------------------------------------------------------------- /tools/util/common.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | 3 | class TimeoutException(Exception): 4 | """Custom exception for function timeout.""" 5 | pass 6 | 7 | def timeout(seconds): 8 | """Decorator to enforce a timeout on a function using multiprocessing.""" 9 | def decorator(func): 10 | def wrapper(*args, **kwargs): 11 | # A queue to store the result or exception 12 | queue = multiprocessing.Queue() 13 | 14 | def target(queue, *args, **kwargs): 15 | try: 16 | result = func(*args, **kwargs) 17 | queue.put((True, result)) 18 | except Exception as e: 19 | queue.put((False, e)) 20 | 21 | process = multiprocessing.Process(target=target, args=(queue, *args), kwargs=kwargs) 22 | process.start() 23 | process.join(seconds) 24 | 25 | if process.is_alive(): 26 | process.terminate() 27 | process.join() 28 | raise 
TimeoutException(f"Function '{func.__name__}' timed out after {seconds} seconds!") 29 | 30 | success, value = queue.get() 31 | if success: 32 | return value 33 | else: 34 | raise value 35 | return wrapper 36 | return decorator -------------------------------------------------------------------------------- /tools/util/math/testing_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | The logic in this file largely borrows from Qwen2.5-Math codebase at https://github.com/QwenLM/Qwen2.5-Math: 3 | """ 4 | 5 | import re 6 | import regex 7 | from word2number import w2n 8 | from math import isclose 9 | from collections import defaultdict 10 | 11 | from sympy import simplify, N 12 | from sympy.parsing.sympy_parser import parse_expr 13 | from sympy.parsing.latex import parse_latex 14 | from latex2sympy2 import latex2sympy 15 | 16 | def convert_word_number(text: str) -> str: 17 | try: 18 | text = str(w2n.word_to_num(text)) 19 | except: 20 | pass 21 | return text 22 | 23 | def _fix_fracs(string): 24 | substrs = string.split("\\frac") 25 | new_str = substrs[0] 26 | if len(substrs) > 1: 27 | substrs = substrs[1:] 28 | for substr in substrs: 29 | new_str += "\\frac" 30 | if len(substr) > 0 and substr[0] == "{": 31 | new_str += substr 32 | else: 33 | try: 34 | assert len(substr) >= 2 35 | except: 36 | return string 37 | a = substr[0] 38 | b = substr[1] 39 | if b != "{": 40 | if len(substr) > 2: 41 | post_substr = substr[2:] 42 | new_str += "{" + a + "}{" + b + "}" + post_substr 43 | else: 44 | new_str += "{" + a + "}{" + b + "}" 45 | else: 46 | if len(substr) > 2: 47 | post_substr = substr[2:] 48 | new_str += "{" + a + "}" + b + post_substr 49 | else: 50 | new_str += "{" + a + "}" + b 51 | string = new_str 52 | return string 53 | 54 | 55 | def _fix_a_slash_b(string): 56 | if len(string.split("/")) != 2: 57 | return string 58 | a = string.split("/")[0] 59 | b = string.split("/")[1] 60 | try: 61 | if "sqrt" not in a: 62 | a = int(a) 63 | if "sqrt" not in b: 64 | b = int(b) 65 | assert string == "{}/{}".format(a, b) 66 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 67 | return new_string 68 | except: 69 | return string 70 | 71 | 72 | def _fix_sqrt(string): 73 | _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string) 74 | return _string 75 | 76 | def strip_answer_string(string): 77 | string = str(string).strip() 78 | # linebreaks 79 | string = string.replace("\n", "") 80 | 81 | # right "." 
82 | string = string.rstrip(".") 83 | 84 | # remove inverse spaces 85 | # replace \\ with \ 86 | string = string.replace("\\!", "") 87 | # string = string.replace("\\ ", "") 88 | # string = string.replace("\\\\", "\\") 89 | 90 | # matrix 91 | string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string) 92 | string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string) 93 | string = string.replace("bmatrix", "pmatrix") 94 | 95 | # replace tfrac and dfrac with frac 96 | string = string.replace("tfrac", "frac") 97 | string = string.replace("dfrac", "frac") 98 | string = ( 99 | string.replace("\\neq", "\\ne") 100 | .replace("\\leq", "\\le") 101 | .replace("\\geq", "\\ge") 102 | ) 103 | 104 | # remove \left and \right 105 | string = string.replace("\\left", "") 106 | string = string.replace("\\right", "") 107 | string = string.replace("\\{", "{") 108 | string = string.replace("\\}", "}") 109 | 110 | # Function to replace number words with corresponding digits 111 | def replace_match(match): 112 | word = match.group(1).lower() 113 | if convert_word_number(word) == word: 114 | return match.group(0) 115 | else: 116 | return convert_word_number(word) 117 | string = re.sub(r"\\text\{([a-zA-Z]+)\}", replace_match, string) 118 | 119 | # Before removing unit, check if the unit is squared (for surface area) 120 | string = re.sub(r"(cm|inches)\}\^2", r"\1}", string) 121 | 122 | # Remove unit: miles, dollars if after is not none 123 | _string = re.sub(r"\\text{.*?}$", "", string).strip() 124 | if _string != "" and _string != string: 125 | # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string)) 126 | string = _string 127 | 128 | # Remove circ (degrees) 129 | string = string.replace("^{\\circ}", "") 130 | string = string.replace("^\\circ", "") 131 | 132 | # remove dollar signs 133 | string = string.replace("\\$", "") 134 | string = string.replace("$", "") 135 | string = string.replace("\\(", "").replace("\\)", "") 136 | 137 | # convert word number to digit 138 | string = convert_word_number(string) 139 | 140 | # replace "\\text{...}" to "..." 141 | string = re.sub(r"\\text\{(.*?)\}", r"\1", string) 142 | for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]: 143 | string = string.replace(key, "") 144 | string = string.replace("\\emptyset", r"{}") 145 | string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}") 146 | 147 | # remove percentage 148 | string = string.replace("\\%", "") 149 | string = string.replace("\%", "") 150 | string = string.replace("%", "") 151 | 152 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." 
is the start of the string 153 | string = string.replace(" .", " 0.") 154 | string = string.replace("{.", "{0.") 155 | 156 | # cdot 157 | # string = string.replace("\\cdot", "") 158 | if ( 159 | string.startswith("{") 160 | and string.endswith("}") 161 | and string.isalnum() 162 | or string.startswith("(") 163 | and string.endswith(")") 164 | and string.isalnum() 165 | or string.startswith("[") 166 | and string.endswith("]") 167 | and string.isalnum() 168 | ): 169 | string = string[1:-1] 170 | 171 | # inf 172 | string = string.replace("infinity", "\\infty") 173 | if "\\infty" not in string: 174 | string = string.replace("inf", "\\infty") 175 | string = string.replace("+\\inity", "\\infty") 176 | 177 | # and 178 | string = string.replace("and", "") 179 | string = string.replace("\\mathbf", "") 180 | 181 | # use regex to remove \mbox{...} 182 | string = re.sub(r"\\mbox{.*?}", "", string) 183 | 184 | # quote 185 | string.replace("'", "") 186 | string.replace('"', "") 187 | 188 | # i, j 189 | if "j" in string and "i" not in string: 190 | string = string.replace("j", "i") 191 | 192 | # replace a.000b where b is not number or b is end, with ab, use regex 193 | string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string) 194 | string = re.sub(r"(\d+)\.0*$", r"\1", string) 195 | 196 | # if empty, return empty string 197 | if len(string) == 0: 198 | return string 199 | if string[0] == ".": 200 | string = "0" + string 201 | 202 | # to consider: get rid of e.g. "k = " or "q = " at beginning 203 | if len(string.split("=")) == 2: 204 | if len(string.split("=")[0]) <= 2: 205 | string = string.split("=")[1] 206 | 207 | string = _fix_sqrt(string) 208 | string = string.replace(" ", "") 209 | 210 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} 211 | string = _fix_fracs(string) 212 | 213 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y 214 | string = _fix_a_slash_b(string) 215 | 216 | # Remove unnecessary '\' before integers 217 | string = re.sub(r"\\(?=\-?\d+(\\|\)|,|\]|$))", "", string) 218 | 219 | # Remove grade level (e.g., 12th grade) and just maintain the integer 220 | string = re.sub(r"thgrade$", "", string) 221 | 222 | # If the answer is a list of integers (without parenthesis), sort them 223 | if re.fullmatch(r"(\s*-?\d+\s*,)*\s*-?\d+\s*", string): 224 | # Split the string into a list of integers 225 | try: 226 | integer_list = list(map(int, string.split(','))) 227 | except: 228 | integer_list = list(map(int, "-1,-1".split(','))) 229 | 230 | # Sort the list in ascending order 231 | sorted_list = sorted(integer_list) 232 | 233 | # Join the sorted list back into a comma-separated string 234 | string = ','.join(map(str, sorted_list)) 235 | 236 | return string 237 | 238 | def extract_answer(pred_str, use_last_number=True): 239 | pred_str = pred_str.replace("\u043a\u0438", "") 240 | if "final answer is $" in pred_str and "$. I hope" in pred_str: 241 | # minerva_math 242 | tmp = pred_str.split("final answer is $", 1)[1] 243 | pred = tmp.split("$. 
I hope", 1)[0].strip() 244 | elif "boxed" in pred_str: 245 | ans = pred_str.split("boxed")[-1] 246 | if len(ans) == 0: 247 | return "" 248 | elif ans[0] == "{": 249 | stack = 1 250 | a = "" 251 | for c in ans[1:]: 252 | if c == "{": 253 | stack += 1 254 | a += c 255 | elif c == "}": 256 | stack -= 1 257 | if stack == 0: 258 | break 259 | a += c 260 | else: 261 | a += c 262 | else: 263 | a = ans.split("$")[0].strip() 264 | pred = a 265 | elif "he answer is" in pred_str: 266 | pred = pred_str.split("he answer is")[-1].strip() 267 | elif "final answer is" in pred_str: 268 | pred = pred_str.split("final answer is")[-1].strip() 269 | elif "答案是" in pred_str: 270 | # Handle Chinese few-shot multiple choice problem answer extraction 271 | pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip() 272 | else: # use the last number 273 | if use_last_number: 274 | pattern = "-?\d*\.?\d+" 275 | pred = re.findall(pattern, pred_str.replace(",", "")) 276 | if len(pred) >= 1: 277 | pred = pred[-1] 278 | else: 279 | pred = "" 280 | else: 281 | pred = "" 282 | 283 | # multiple line 284 | # pred = pred.split("\n")[0] 285 | pred = re.sub(r"\n\s*", "", pred) 286 | if pred != "" and pred[0] == ":": 287 | pred = pred[1:] 288 | if pred != "" and pred[-1] == ".": 289 | pred = pred[:-1] 290 | if pred != "" and pred[-1] == "/": 291 | pred = pred[:-1] 292 | pred = strip_answer_string(pred) 293 | return pred 294 | 295 | def get_multiple_choice_answer(pred: str): 296 | tmp = re.findall(r"\b(A|B|C|D)\b", pred.upper()) 297 | if tmp: 298 | pred = tmp 299 | else: 300 | pred = [pred.strip().strip(".")] 301 | 302 | if len(pred) == 0: 303 | pred = "" 304 | else: 305 | pred = pred[-1] 306 | 307 | # Remove the period at the end, again! 308 | pred = pred.rstrip(".").rstrip("/") 309 | 310 | return pred 311 | 312 | def mmlu_pro_extract_answer(text): 313 | pattern = r"answer is \(?([A-J])\)?" 314 | match = re.search(pattern, text) 315 | if match: 316 | return match.group(1) 317 | else: 318 | # print("1st answer extract failed\n" + text) 319 | match = re.search(r'.*[aA]nswer:\s*([A-J])', text) 320 | if match: 321 | return match.group(1) 322 | else: 323 | # print("2nd answer extract failed\n" + text) 324 | pattern = r"\b[A-J]\b(?!.*\b[A-J]\b)" 325 | match = re.search(pattern, text, re.DOTALL) 326 | if match: 327 | return match.group(0) 328 | 329 | def choice_answer_clean(pred: str): 330 | pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":") 331 | # Clean the answer based on the dataset 332 | tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper()) 333 | if tmp: 334 | pred = tmp 335 | else: 336 | pred = [pred.strip().strip(".")] 337 | pred = pred[-1] 338 | # Remove the period at the end, again! 
339 | pred = pred.rstrip(".").rstrip("/") 340 | return pred 341 | 342 | 343 | def parse_digits(num): 344 | num = regex.sub(",", "", str(num)) 345 | try: 346 | return float(num) 347 | except: 348 | if num.endswith("%"): 349 | num = num[:-1] 350 | if num.endswith("\\"): 351 | num = num[:-1] 352 | try: 353 | return float(num) / 100 354 | except: 355 | pass 356 | return None 357 | 358 | 359 | def is_digit(num): 360 | # paired with parse_digits 361 | return parse_digits(num) is not None 362 | 363 | 364 | def str_to_pmatrix(input_str): 365 | input_str = input_str.strip() 366 | matrix_str = re.findall(r"\{.*,.*\}", input_str) 367 | pmatrix_list = [] 368 | 369 | for m in matrix_str: 370 | m = m.strip("{}") 371 | pmatrix = r"\begin{pmatrix}" + m.replace(",", "\\") + r"\end{pmatrix}" 372 | pmatrix_list.append(pmatrix) 373 | 374 | return ", ".join(pmatrix_list) 375 | 376 | 377 | def math_equal( 378 | prediction, 379 | reference, 380 | include_percentage: bool = True, 381 | is_close: bool = True, 382 | timeout: bool = False, 383 | ) -> bool: 384 | """ 385 | Exact match of math if and only if: 386 | 1. numerical equal: both can convert to float and are equal 387 | 2. symbolic equal: both can convert to sympy expression and are equal 388 | """ 389 | if prediction is None or reference is None: 390 | return False 391 | if str(prediction.strip().lower()) == str(reference.strip().lower()): 392 | return True 393 | if ( 394 | reference in ["A", "B", "C", "D", "E"] 395 | and choice_answer_clean(prediction) == reference 396 | ): 397 | return True 398 | 399 | try: # 1. numerical equal 400 | if is_digit(prediction) and is_digit(reference): 401 | prediction = parse_digits(prediction) 402 | reference = parse_digits(reference) 403 | # number questions 404 | if include_percentage: 405 | gt_result = [reference / 100, reference, reference * 100] 406 | else: 407 | gt_result = [reference] 408 | for item in gt_result: 409 | try: 410 | if is_close: 411 | if numeric_equal(prediction, item): 412 | return True 413 | else: 414 | if item == prediction: 415 | return True 416 | except Exception: 417 | continue 418 | return False 419 | except: 420 | pass 421 | 422 | if not prediction and prediction not in [0, False]: 423 | return False 424 | 425 | # 2. symbolic equal 426 | reference = str(reference).strip() 427 | prediction = str(prediction).strip() 428 | 429 | ## pmatrix (amps) 430 | if "pmatrix" in prediction and not "pmatrix" in reference: 431 | reference = str_to_pmatrix(reference) 432 | 433 | ## deal with [], (), {} 434 | pred_str, ref_str = prediction, reference 435 | if ( 436 | prediction.startswith("[") 437 | and prediction.endswith("]") 438 | and not reference.startswith("(") 439 | ) or ( 440 | prediction.startswith("(") 441 | and prediction.endswith(")") 442 | and not reference.startswith("[") 443 | ): 444 | pred_str = pred_str.strip("[]()") 445 | ref_str = ref_str.strip("[]()") 446 | for s in ["{", "}", "(", ")"]: 447 | ref_str = ref_str.replace(s, "") 448 | pred_str = pred_str.replace(s, "") 449 | if pred_str.lower() == ref_str.lower(): 450 | return True 451 | 452 | ## [a, b] vs. 
[c, d], return a==c and b==d 453 | if ( 454 | regex.match(r"(\(|\[).+(\)|\])", prediction) is not None 455 | and regex.match(r"(\(|\[).+(\)|\])", reference) is not None 456 | ): 457 | pred_parts = prediction[1:-1].split(",") 458 | ref_parts = reference[1:-1].split(",") 459 | if len(pred_parts) == len(ref_parts): 460 | if all( 461 | [ 462 | math_equal( 463 | pred_parts[i], ref_parts[i], include_percentage, is_close 464 | ) 465 | for i in range(len(pred_parts)) 466 | ] 467 | ): 468 | return True 469 | if ( 470 | ( 471 | prediction.startswith("\\begin{pmatrix}") 472 | or prediction.startswith("\\begin{bmatrix}") 473 | ) 474 | and ( 475 | prediction.endswith("\\end{pmatrix}") 476 | or prediction.endswith("\\end{bmatrix}") 477 | ) 478 | and ( 479 | reference.startswith("\\begin{pmatrix}") 480 | or reference.startswith("\\begin{bmatrix}") 481 | ) 482 | and ( 483 | reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}") 484 | ) 485 | ): 486 | pred_lines = [ 487 | line.strip() 488 | for line in prediction[ 489 | len("\\begin{pmatrix}") : -len("\\end{pmatrix}") 490 | ].split("\\\\") 491 | if line.strip() 492 | ] 493 | ref_lines = [ 494 | line.strip() 495 | for line in reference[ 496 | len("\\begin{pmatrix}") : -len("\\end{pmatrix}") 497 | ].split("\\\\") 498 | if line.strip() 499 | ] 500 | matched = True 501 | if len(pred_lines) == len(ref_lines): 502 | for pred_line, ref_line in zip(pred_lines, ref_lines): 503 | pred_parts = pred_line.split("&") 504 | ref_parts = ref_line.split("&") 505 | if len(pred_parts) == len(ref_parts): 506 | if not all( 507 | [ 508 | math_equal( 509 | pred_parts[i], 510 | ref_parts[i], 511 | include_percentage, 512 | is_close, 513 | ) 514 | for i in range(len(pred_parts)) 515 | ] 516 | ): 517 | matched = False 518 | break 519 | else: 520 | matched = False 521 | if not matched: 522 | break 523 | else: 524 | matched = False 525 | if matched: 526 | return True 527 | 528 | if prediction.count("=") == 1 and reference.count("=") == 1: 529 | pred = prediction.split("=") 530 | pred = f"{pred[0].strip()} - ({pred[1].strip()})" 531 | ref = reference.split("=") 532 | ref = f"{ref[0].strip()} - ({ref[1].strip()})" 533 | if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref): 534 | return True 535 | elif ( 536 | prediction.count("=") == 1 537 | and len(prediction.split("=")[0].strip()) <= 2 538 | and "=" not in reference 539 | ): 540 | if math_equal( 541 | prediction.split("=")[1], reference, include_percentage, is_close 542 | ): 543 | return True 544 | elif ( 545 | reference.count("=") == 1 546 | and len(reference.split("=")[0].strip()) <= 2 547 | and "=" not in prediction 548 | ): 549 | if math_equal( 550 | prediction, reference.split("=")[1], include_percentage, is_close 551 | ): 552 | return True 553 | 554 | if symbolic_equal(prediction, reference): 555 | return True 556 | 557 | return False 558 | 559 | 560 | def numeric_equal(prediction: float, reference: float): 561 | return isclose(reference, prediction, rel_tol=1e-4) 562 | 563 | 564 | def symbolic_equal(a, b): 565 | def _parse(s): 566 | for f in [parse_latex, parse_expr, latex2sympy]: 567 | try: 568 | return f(s.replace("\\\\", "\\")) 569 | except: 570 | try: 571 | return f(s) 572 | except: 573 | pass 574 | return s 575 | 576 | a = _parse(a) 577 | b = _parse(b) 578 | 579 | # direct equal 580 | try: 581 | if str(a) == str(b) or a == b: 582 | return True 583 | except: 584 | pass 585 | 586 | # simplify equal 587 | try: 588 | if a.equals(b) or simplify(a - b) == 0: 589 | return True 590 | except: 591 
| pass 592 | 593 | # equation equal 594 | try: 595 | if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)): 596 | return True 597 | except: 598 | pass 599 | 600 | try: 601 | if numeric_equal(float(N(a)), float(N(b))): 602 | return True 603 | except: 604 | pass 605 | 606 | # matrix 607 | try: 608 | # if a and b are matrix 609 | if a.shape == b.shape: 610 | _a = a.applyfunc(lambda x: round(x, 3)) 611 | _b = b.applyfunc(lambda x: round(x, 3)) 612 | if _a.equals(_b): 613 | return True 614 | except: 615 | pass 616 | 617 | return False -------------------------------------------------------------------------------- /tools/util/model_utils.py: -------------------------------------------------------------------------------- 1 | SYSTEM_PROMPT = { 2 | "Qwen/Qwen2-7B-Instruct": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.", 3 | "Qwen/QwQ-32B-Preview": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.", 4 | "Qwen/Qwen2.5-72B-Instruct": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.", 5 | "Qwen/Qwen2.5-32B-Instruct": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.", 6 | "Qwen/Qwen2.5-7B-Instruct": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.", 7 | "Qwen/Qwen2.5-1.5B-Instruct": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.", 8 | "Qwen/Qwen2.5-Math-7B-Instruct": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.", 9 | "Qwen/Qwen2.5-Math-72B-Instruct": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.", 10 | "PRIME-RL/Eurus-2-7B-PRIME": """When tackling complex reasoning tasks, you have access to the following actions. Use them as needed to progress through your thought process. After each action, determine and state the next most appropriate action to take. 11 | 12 | Actions: 13 | 14 | {actions} 15 | 16 | Your action should contain multiple steps, and each step starts with #. After each action (except OUTPUT), state which action you will take next with ''Next action: [Your action]'' and finish this turn. Continue this process until you reach a satisfactory conclusion or solution to the problem at hand, at which point you should use the [OUTPUT] action. The thought process is completely invisible to user, so [OUTPUT] should be a complete response. You should strictly follow the format below: 17 | 18 | [ACTION NAME] 19 | 20 | # Your action step 1 21 | 22 | # Your action step 2 23 | 24 | # Your action step 3 25 | 26 | ... 27 | 28 | Next action: [NEXT ACTION NAME] 29 | 30 | 31 | Now, begin with the [ASSESS] action for the following task: 32 | """, 33 | "NovaSky-AI/Sky-T1-32B-Preview": "Your role as an assistant involves thoroughly exploring questions through a systematic long \ 34 | thinking process before providing the final precise and accurate solutions. This requires \ 35 | engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, \ 36 | backtracing, and iteration to develop well-considered thinking process. \ 37 | Please structure your response into two main sections: Thought and Solution. 
\ 38 | In the Thought section, detail your reasoning process using the specified format: \ 39 | <|begin_of_thought|> {thought with steps separated with '\n\n'} \ 40 | <|end_of_thought|> \ 41 | Each step should include detailed considerations such as analisying questions, summarizing \ 42 | relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining \ 43 | any errors, and revisiting previous steps. \ 44 | In the Solution section, based on various attempts, explorations, and reflections from the Thought \ 45 | section, systematically present the final solution that you deem correct. The solution should \ 46 | remain a logical, accurate, concise expression style and detail necessary step needed to reach the \ 47 | conclusion, formatted as follows: \ 48 | <|begin_of_solution|> \ 49 | {final formatted, precise, and clear solution} \ 50 | <|end_of_solution|> \ 51 | Now, try to solve the following question through the above guidelines:", 52 | "openai/o1-mini": "Question: {input}\nAnswer: ", 53 | "openai/o1-preview": "Question: {input}\nAnswer: ", 54 | "openai/gpt-4o-mini": "User: {input}\nPlease reason step by step, and put your final answer within \\boxed{{}}.\n\nAssistant:", 55 | } 56 | 57 | def get_MODEL_TO_NAME(key): 58 | if key in MODEL_TO_NAME: 59 | return MODEL_TO_NAME[key] 60 | else: 61 | keys = key.split("/") 62 | if 'checkpoint' in keys[-1]: return keys[-2] 63 | return keys[-1] 64 | 65 | MODEL_TO_NAME = { 66 | "Qwen/Qwen2-7B-Instruct": "Qwen2-7B-Instruct", 67 | "Qwen/QwQ-32B-Preview": "QwQ-32B-Preview", 68 | "Qwen/Qwen2.5-72B-Instruct": "Qwen2.5-72B-Instruct", 69 | "Qwen/Qwen2.5-32B-Instruct": "Qwen2.5-32B-Instruct", 70 | "Qwen/Qwen2.5-7B-Instruct": "Qwen2.5-7B-Instruct", 71 | "Qwen/Qwen2.5-1.5B-Instruct": "Qwen2.5-1.5B-Instruct", 72 | "Qwen/Qwen2.5-Math-7B-Instruct": "Qwen2.5-Math-7B-Instruct", 73 | "Qwen/Qwen2.5-Math-72B-Instruct": "Qwen2.5-Math-72B-Instruct", 74 | "PRIME-RL/Eurus-2-7B-PRIME": "Eurus-2-7B-PRIME", 75 | "NovaSky-AI/Sky-T1-32B-Preview": "Sky-T1-32B-Preview", 76 | "openai/o1-mini": "o1-mini", 77 | "openai/o1-preview": "o1-preview", 78 | "openai/gpt-4o-mini": "gpt-4o-mini", 79 | } 80 | 81 | SUBPROBLEM_SPLIT_PROMPT = """ 82 | You are given a reasoning sequence that attempts to solve a math problem. 83 | This sequence contains multiple proposed solutions, then provides a the final solution. 84 | Each proposed solution within the sequence follows a different line of thought, usually to double check the answer. 85 | Your objective is to identify these separate lines of thought and add the separator string '#####' between the separate lines of thought. 86 | This is important: Your response should be the original unchanged reasoning sequence, except for '#####' injected into the sequence between distinct lines of thought. 87 | Do NOT summarize portions of the reasoning sequence with '...'. 88 | 89 | Please keep the sequence that starts with '<|begin_of_solution|>' and ends with '<|end_of_solution|>' as 90 | one single sequence with no '#####' inside of the sequence. Add the separator '#####' immediately before '<|begin_of_solution|>'. 91 | 92 | Importantly, only use '#####' if a line of thought presents an answer. 93 | If the line of thought does not include an answer, it cannot be considered a separate line of thought, and should not be separated. 94 | 95 | For example, if the input is: 96 | <|begin_of_thought|>The answer to 2+3 is 5. But wait, let me double check this. 
97 | If I have two apples and I am given three more apples, I now have 5 apples, so 5 seems like the right answer. 98 | Alternatively, 2+3 is the same as 3+2, which is also 5.<|end_of_thought|> 99 | <|begin_of_solution|>The answer is 5<|end_of_solution|>. 100 | 101 | Your output should be: 102 | <|begin_of_thought|>The answer to 2+3 is 5. 103 | ##### 104 | But wait, let me double check this. 105 | If I have two apples and I am given three more apples, I now have 5 apples, so 5 seems like the right answer. 106 | ##### 107 | Alternatively, 2+3 is the same as 3+2, which is also 5.<|end_of_thought|> 108 | ##### 109 | <|begin_of_solution|>The answer is 5<|end_of_solution|>. 110 | """ 111 | 112 | SUBSOLUTION_EXTRACTION_PROMPT = """ 113 | You are given text of an attemp to solve a math problem. The text contains a final proposed answer to the math problem. 114 | 115 | The text also contains a string '#####' and after this string the ground truth answer is presented. 116 | 117 | Your objective is to determine whether the final proposed answer is equivalent to the ground truth answer. 118 | The proposed answer and ground truth answer may be in slightly different formats. For example, the proposed answer may be '1/2' but the ground truth is '0.5'. 119 | Equivalent answers in different formats should be treated as equivalent. 120 | If the text contains multiple proposed answers, use the final proposed answer. 121 | 122 | You should return only "True" if the proposed answer is equivalent to the ground truth answer and "False" if there is no proposed answer or if the proposed answer is not equivalent to the ground truth. 123 | Do NOT respond with anything at all except "True" or "False". 124 | 125 | For example, if you are given: 126 | I believe 2+3 equals 5. 127 | ##### 128 | The ground truth answer is five. 129 | 130 | Your response should be: 131 | True 132 | 133 | Another example, if you are given: 134 | I believe 2+2 equals 4. But wait, it is actually 5. 135 | ##### 136 | The ground truth answer is five. 137 | 138 | Your response should be: 139 | True 140 | """ -------------------------------------------------------------------------------- /tools/util/prompts.py: -------------------------------------------------------------------------------- 1 | grading_prompt = " \ 2 | You will be given a math problem. Your job is to grade the difficulty level from 1-10 according to the AoPS standard. \ 3 | Here is the standard: \ 4 | {aops_criteria} \ 5 | Problem to be labeled: {problem}. Please put your estimation of the difficulty inside [[level]]. \ 6 | Important: You should place the difficulty from 1-10 into the [[]], not the solution of the problem." 7 | 8 | aops_criteria = " \ 9 | All levels are estimated and refer to averages. The following is a rough standard based on the USA tier system AMC 8 - AMC 10 - AMC 12 - AIME - USAMO/USAJMO - IMO, \ 10 | representing Middle School - Junior High - High School - Challenging High School - Olympiad levels. Other contests can be interpolated against this. \ 11 | Notes: \ 12 | Multiple choice tests like AMC are rated as though they are free-response. Test-takers can use the answer choices as hints, and so correctly answer more AMC questions than Mathcounts or AIME problems of similar difficulty. \ 13 | Some Olympiads are taken in 2 sessions, with 2 similarly difficult sets of questions, numbered as one set. For these the first half of the test (questions 1-3) is similar difficulty to the second half (questions 4-6). 
\ 14 | Scale \ 15 | 1: Problems strictly for beginner, on the easiest elementary school or middle school levels (MOEMS, MATHCOUNTS Chapter, AMC 8 1-20, AMC 10 1-10, AMC 12 1-5, and others that involve standard techniques introduced up to the middle school level), most traditional middle/high school word problems. \ 16 | 2: For motivated beginners, harder questions from the previous categories (AMC 8 21-25, harder MATHCOUNTS States questions, AMC 10 11-20, AMC 12 5-15, AIME 1-3), traditional middle/high school word problems with extremely complex problem solving. \ 17 | 3: Advanced Beginner problems that require more creative thinking (harder MATHCOUNTS National questions, AMC 10 21-25, AMC 12 15-20, AIME 4-6). \ 18 | 4: Intermediate-level problems (AMC 12 21-25, AIME 7-9). \ 19 | 5: More difficult AIME problems (10-12), simple proof-based Olympiad-style problems (early JBMO questions, easiest USAJMO 1/4). \ 20 | 6: High-leveled AIME-styled questions (13-15). Introductory-leveled Olympiad-level questions (harder USAJMO 1/4 and easier USAJMO 2/5, easier USAMO and IMO 1/4). \ 21 | 7: Tougher Olympiad-level questions, may require more technical knowledge (harder USAJMO 2/5 and most USAJMO 3/6, extremely hard USAMO and IMO 1/4, easy-medium USAMO and IMO 2/5). \ 22 | 8: High-level Olympiad-level questions (medium-hard USAMO and IMO 2/5, easiest USAMO and IMO 3/6). \ 23 | 9: Expert Olympiad-level questions (average USAMO and IMO 3/6). \ 24 | 10: Historically hard problems, generally unsuitable for very hard competitions (such as the IMO) due to being exceedingly tedious, long, and difficult (e.g. very few students are capable of solving on a worldwide basis). \ 25 | Examples \ 26 | For reference, here are problems from each of the difficulty levels 1-10: \ 27 | <1: Jamie counted the number of edges of a cube, Jimmy counted the numbers of corners, and Judy counted the number of faces. They then added the three numbers. What was the resulting sum? (2003 AMC 8, Problem 1) \ 28 | 1: How many integer values of $x$ satisfy $|x| < 3\pi$? (2021 Spring AMC 10B, Problem 1) \ 29 | 2: A fair $6$-sided die is repeatedly rolled until an odd number appears. What is the probability that every even number appears at least once before the first occurrence of an odd number? (2021 Spring AMC 10B, Problem 18) \ 30 | 3: Triangle $ABC$ with $AB=50$ and $AC=10$ has area $120$. Let $D$ be the midpoint of $\overline{AB}$, and let $E$ be the midpoint of $\overline{AC}$. The angle bisector of $\angle BAC$ intersects $\overline{DE}$ and $\overline{BC}$ at $F$ and $G$, respectively. What is the area of quadrilateral $FDBG$? (2018 AMC 10A, Problem 24) \ 31 | 4: Define a sequence recursively by $x_0=5$ and\[x_{n+1}=\frac{x_n^2+5x_n+4}{x_n+6}\]for all nonnegative integers $n.$ Let $m$ be the least positive integer such that\[x_m\leq 4+\frac{1}{2^{20}}.\]In which of the following intervals does $m$ lie? 
\ 32 | $\textbf{(A) } [9,26] \qquad\textbf{(B) } [27,80] \qquad\textbf{(C) } [81,242]\qquad\textbf{(D) } [243,728] \qquad\textbf{(E) } [729,\infty)$ \ 33 | (2019 AMC 10B, Problem 24 and 2019 AMC 12B, Problem 22) \ 34 | 5: Find all triples $(a, b, c)$ of real numbers such that the following system holds:\[a+b+c=\frac{1}{a}+\frac{1}{b}+\frac{1}{c},\]\[a^2+b^2+c^2=\frac{1}{a^2}+\frac{1}{b^2}+\frac{1}{c^2}.\](JBMO 2020/1) \ 35 | 6: Let $\triangle ABC$ be an acute triangle with circumcircle $\omega,$ and let $H$ be the intersection of the altitudes of $\triangle ABC.$ Suppose the tangent to the circumcircle of $\triangle HBC$ at $H$ intersects $\omega$ at points $X$ and $Y$ with $HA=3,HX=2,$ and $HY=6.$ The area of $\triangle ABC$ can be written in the form $m\sqrt{n},$ where $m$ and $n$ are positive integers, and $n$ is not divisible by the square of any prime. Find $m+n.$ (2020 AIME I, Problem 15) \ 36 | 7: We say that a finite set $\mathcal{S}$ in the plane is balanced if, for any two different points $A$, $B$ in $\mathcal{S}$, there is a point $C$ in $\mathcal{S}$ such that $AC=BC$. We say that $\mathcal{S}$ is centre-free if for any three points $A$, $B$, $C$ in $\mathcal{S}$, there is no point $P$ in $\mathcal{S}$ such that $PA=PB=PC$. \ 37 | Show that for all integers $n\geq 3$, there exists a balanced set consisting of $n$ points. \ 38 | Determine all integers $n\geq 3$ for which there exists a balanced centre-free set consisting of $n$ points. \ 39 | (IMO 2015/1) \ 40 | 8: For each positive integer $n$, the Bank of Cape Town issues coins of denomination $\frac1n$. Given a finite collection of such coins (of not necessarily different denominations) with total value at most most $99+\frac{1}{2}$, prove that it is possible to split this collection into $100$ or fewer groups, such that each group has total value at most $1$. (IMO 2014/5) \ 41 | 9: Let $k$ be a positive integer and let $S$ be a finite set of odd prime numbers. Prove that there is at most one way (up to rotation and reflection) to place the elements of $S$ around the circle such that the product of any two neighbors is of the form $x^2+x+k$ for some positive integer $x$. (IMO 2022/3) \ 42 | 10: Prove that there exists a positive constant $c$ such that the following statement is true: Consider an integer $n > 1$, and a set $\mathcal S$ of $n$ points in the plane such that the distance between any two different points in $\mathcal S$ is at least 1. It follows that there is a line $\ell$ separating $\mathcal S$ such that the distance from any point of $\mathcal S$ to $\ell$ is at least $cn^{-1/3}$. \ 43 | (A line $\ell$ separates a set of points S if some segment joining two points in $\mathcal S$ crosses $\ell$.) (IMO 2020/6)" 44 | 45 | 46 | convert_prompt = "Another solution is written in an unstructured way. Your job is to convert them into two sections: \ 47 | <|begin_of_thought|> \ 48 | (Thought process, you should copy exactly the thinking process of the original solution.) \ 49 | <|end_of_thought|> \ 50 | <|begin_of_solution|> \ 51 | (Final formatted, precise, and clear solution; make sure there is only one solution in this section; If it is a coding problem, make sure there is only one code block) \ 52 | <|end_of_solution|> \ 53 | Here is an example demonstration of a different question, you can refer to its format: \ 54 | {example} \ 55 | Important: You should almost copy all the contents word-by-word of the original solution. Just convert them into two sections. 
\ 56 | Make sure you include: <|begin_of_slow_thought|>, <|end_of_slow_thought|>, <|begin_of_solution|>,<|end_of_solution|> These four headers explicitly. \ 57 | Content to be converted: {content}" 58 | 59 | convert_prompt_example = ("<|begin_of_thought|>\n\n" 60 | "Okay, so I've got this problem here. Mr. Wang leaves home at 6 AM, riding his bike at 12 km/h, " 61 | "and he stops to rest for 6 minutes after every 30 minutes of riding. Then, when he arrives at a park " 62 | "that's 16.8 km away, I need to find out the angle between the hour and minute hands on his watch.\n\n" 63 | "Alright, first things first, I need to figure out how long it takes Mr. Wang to ride 16.8 km, including " 64 | "his rest periods.\n\n" 65 | "So, his speed is 12 km/h. To find out how long it takes to go 16.8 km without any stops, I can use the formula " 66 | "time = distance/speed. That would be 16.8 divided by 12, which is 1.4 hours. To make it easier, that's 1 hour and 24 minutes.\n\n" 67 | "But wait, he doesn't ride straight through. He stops for 6 minutes after every 30 minutes of riding. So, I need to see how many " 68 | "of those 30-minute riding periods are there in his total riding time.\n\n" 69 | "In 1 hour and 24 minutes of riding, how many 30-minute segments are there? Well, 1 hour is 60 minutes, plus 24 minutes makes 84 minutes " 70 | "total riding time. So, 84 divided by 30 is 2.8. That means he has two full 30-minute riding periods and a partial one.\n\n" 71 | "After each full 30-minute riding period, he rests for 6 minutes. So, for two full periods, he rests twice, which is 12 minutes of rest.\n\n" 72 | "Now, for the partial riding period. Since 2 times 30 minutes is 60 minutes, and he has 84 minutes of riding, the remaining riding time is 84 minus 60, " 73 | "which is 24 minutes. So, he rides for 24 minutes without another rest because he doesn't complete another 30-minute segment.\n\n" 74 | "So, total time taken is riding time plus rest time. That's 84 minutes riding plus 12 minutes resting, totaling 96 minutes.\n\n" 75 | "Wait a minute, but he stops after every 30 minutes of riding, but in the last partial period of 24 minutes, does he rest again? I think he only rests after " 76 | "completing 30 minutes of riding, so in this case, since the last riding period is only 24 minutes, he doesn't take an additional rest after that.\n\n" 77 | "So, total time should be 84 minutes riding plus 12 minutes resting, which is indeed 96 minutes, or 1 hour and 36 minutes.\n\n" 78 | "So, he leaves at 6 AM and takes 1 hour and 36 minutes to reach the park, arriving at 7:36 AM.\n\n" 79 | "Now, I need to find the angle between the hour and minute hands at 7:36.\n\n" 80 | "To find the angle between the hour and minute hands, I can use the formula:\n\n" 81 | "|30H - 5.5M|\n\n" 82 | "where H is the hour and M is the minutes.\n\n" 83 | "At 7:36, H is 7 and M is 36.\n\n" 84 | "So, plugging in:\n\n" 85 | "30*7 = 210\n\n" 86 | "5.5*36 = 198\n\n" 87 | "210 - 198 = 12\n\n" 88 | "So, the angle is 12 degrees.\n\n" 89 | "Wait, but I should make sure that's the smaller angle. Sometimes, the larger angle is considered, but usually, the smaller one is what is asked for.\n\n" 90 | "So, the angle between the hour and minute hands at 7:36 AM is 12 degrees.\n\n" 91 | "I think that's the answer.<|end_of_thought|>\n\n" 92 | "<|begin_of_solution|>\n\n" 93 | "Mr. Wang leaves home at 6 AM and rides at a speed of 12 km/h, stopping to rest for 6 minutes after every 30 minutes of riding. " 94 | "He arrives at a park 16.8 km away. 
To determine the angle between the hour and minute hands on his watch when he arrives, we first calculate the total time taken.\n\n" 95 | "1. **Riding time without stops**:\n\n" 96 | "$$\\text{Time} = \\frac{\\text{Distance}}{\\text{Speed}} = \\frac{16.8 \\text{ km}}{12 \\text{ km/h}} = 1.4 \\text{ hours} = 84 \\text{ minutes}$$\n\n" 97 | "2. **Rest periods**:\n\n" 98 | " - He rests for 6 minutes after every 30 minutes of riding.\n\n" 99 | " - In 84 minutes of riding, he completes 2 full 30-minute segments and a partial 24-minute segment.\n\n" 100 | " - He rests twice, totaling 12 minutes of rest.\n\n" 101 | "3. **Total time**:\n\n" 102 | "$$\\text{Total time} = 84 \\text{ minutes (riding)} + 12 \\text{ minutes (rest)} = 96 \\text{ minutes} = 1 \\text{ hour and } 36 \\text{ minutes}$$\n\n" 103 | " - He arrives at 7:36 AM.\n\n" 104 | "4. **Angle between hour and minute hands at 7:36**:\n\n" 105 | " - Use the formula:\n\n" 106 | "$$\\text{Angle} = |30H - 5.5M|$$\n\n" 107 | " - At 7:36, $H = 7$ and $M = 36$:\n\n" 108 | "$$\\text{Angle} = |30 \\times 7 - 5.5 \\times 36| = |210 - 198| = 12 \\text{ degrees}$$\n\n" 109 | "Thus, the angle between the hour and minute hands on his watch is $\\boxed{12}$.<|end_of_solution|>\n") 110 | 111 | # From https://arxiv.org/pdf/2412.09413 112 | system_prompt = "Your role as an assistant involves thoroughly exploring questions through a systematic long \ 113 | thinking process before providing the final precise and accurate solutions. This requires \ 114 | engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, \ 115 | backtracing, and iteration to develop well-considered thinking process. \ 116 | Please structure your response into two main sections: Thought and Solution. \ 117 | In the Thought section, detail your reasoning process using the specified format: \ 118 | <|begin_of_thought|> {thought with steps separated with '\n\n'} \ 119 | <|end_of_thought|> \ 120 | Each step should include detailed considerations such as analisying questions, summarizing \ 121 | relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining \ 122 | any errors, and revisiting previous steps. \ 123 | In the Solution section, based on various attempts, explorations, and reflections from the Thought \ 124 | section, systematically present the final solution that you deem correct. 
The solution should \ 125 | remain a logical, accurate, concise expression style and detail necessary step needed to reach the \ 126 | conclusion, formatted as follows: \ 127 | <|begin_of_solution|> \ 128 | {final formatted, precise, and clear solution} \ 129 | <|end_of_solution|> \ 130 | Now, try to solve the following question through the above guidelines:" -------------------------------------------------------------------------------- /tools/util/taco/pyext2.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2014 Ryan Gonzalez 3 | 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to use, 8 | copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 9 | Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | ''' 22 | 23 | g_backup = globals().copy() 24 | 25 | __version__ = '0.7' 26 | 27 | __all__ = ['overload', 'RuntimeModule', 'switch', 'tail_recurse', 'copyfunc', 'set_docstring', 'annotate', 'safe_unpack', 'modify_function', 'assign', 'fannotate', 'compare_and_swap', 'is_main', 'call_if_main', 'run_main'] 28 | 29 | import sys, inspect, types 30 | 31 | def __targspec(func, specs, attr='__orig_arg__'): 32 | if hasattr(func, '__is_overload__') and func.__is_overload__: 33 | return getattr(func, attr) 34 | return specs(func) 35 | 36 | def set_docstring(doc): 37 | '''A simple decorator to set docstrings. 38 | 39 | :param doc: The docstring to tie to the function. 40 | 41 | Example:: 42 | 43 | @set_docstring('This is a docstring') 44 | def myfunc(x): 45 | pass''' 46 | def _wrap(f): 47 | f.__doc__ = doc 48 | return f 49 | return _wrap 50 | 51 | __modify_function_doc = ''' 52 | Creates a copy of a function, changing its attributes. 53 | 54 | :param globals: Will be added to the function's globals. 55 | 56 | :param name: The new function name. Set to ``None`` to use the function's original name. 57 | 58 | :param code: The new function code object. Set to ``None`` to use the function's original code object. 59 | 60 | :param defaults: The new function defaults. Set to ``None`` to use the function's original defaults. 61 | 62 | :param closure: The new function closure. Set to ``None`` to use the function's original closure. 63 | 64 | .. warning:: This function can be potentially dangerous. 65 | ''' 66 | 67 | def copyfunc(f): 68 | '''Copies a funcion. 69 | 70 | :param f: The function to copy. 71 | 72 | :return: The copied function. 73 | 74 | .. deprecated:: 0.4 75 | Use :func:`modify_function` instead. 
76 | ''' 77 | return modify_function(f) 78 | 79 | if sys.version_info.major == 3: 80 | @set_docstring(__modify_function_doc) 81 | def modify_function(f, globals={}, name=None, code=None, defaults=None, 82 | closure=None): 83 | if code is None: code = f.__code__ 84 | if name is None: name = f.__name__ 85 | if defaults is None: defaults = f.__defaults__ 86 | if closure is None: closure = f.__closure__ 87 | newf = types.FunctionType(code, dict(f.__globals__, **globals), name=name, 88 | argdefs=defaults, closure=closure) 89 | newf.__dict__.update(f.__dict__) 90 | return newf 91 | def argspec(f): 92 | return inspect.getfullargspec(f) 93 | ofullargspec = inspect.getfullargspec 94 | def _fullargspec(func): 95 | return __targspec(func, ofullargspec) 96 | inspect.getfullargspec = _fullargspec 97 | def _exec(m,g): exec(m,g) 98 | else: 99 | @set_docstring(__modify_function_doc) 100 | def modify_function(f, globals={}, name=None, code=None, defaults=None, 101 | closure=None): 102 | if code is None: code = f.func_code 103 | if name is None: name = f.__name__ 104 | if defaults is None: defaults = f.func_defaults 105 | if closure is None: closure = f.func_closure 106 | newf = types.FunctionType(code, dict(f.func_globals, **globals), name=name, 107 | argdefs=defaults, closure=closure) 108 | newf.__dict__.update(f.__dict__) 109 | return newf 110 | def argspec(f): 111 | return inspect.getargspec(f) 112 | eval(compile('def _exec(m,g): exec m in g', '', 'exec')) 113 | 114 | def _gettypes(args): 115 | return tuple(map(type, args)) 116 | 117 | oargspec = inspect.getargspec 118 | 119 | def _argspec(func): 120 | return __targspec(func, oargspec) 121 | 122 | inspect.getargspec = _argspec 123 | 124 | try: 125 | import IPython 126 | except ImportError: 127 | IPython = None 128 | else: 129 | # Replace IPython's argspec 130 | oipyargspec = IPython.core.oinspect.getargspec 131 | def _ipyargspec(func): 132 | return __targspec(func, oipyargspec, '__orig_arg_ipy__') 133 | IPython.core.oinspect.getargspec = _ipyargspec 134 | 135 | class overload(object): 136 | '''Simple function overloading in Python.''' 137 | _items = {} 138 | _types = {} 139 | @classmethod 140 | def argc(self, argc=None): 141 | '''Overloads a function based on the specified argument count. 142 | 143 | :param argc: The argument count. Defaults to ``None``. If ``None`` is given, automatically compute the argument count from the given function. 144 | 145 | .. note:: 146 | 147 | Keyword argument counts are NOT checked! In addition, when the argument count is automatically calculated, the keyword argument count is also ignored! 
148 | 149 | Example:: 150 | 151 | @overload.argc() 152 | def func(a): 153 | print 'Function 1 called' 154 | 155 | @overload.argc() 156 | def func(a, b): 157 | print 'Function 2 called' 158 | 159 | func(1) # Calls first function 160 | func(1, 2) # Calls second function 161 | func() # Raises error 162 | ''' 163 | # Python 2 UnboundLocalError fix 164 | argc = {'argc': argc} 165 | def _wrap(f): 166 | def _newf(*args, **kwargs): 167 | if len(args) not in self._items[f.__name__]: 168 | raise TypeError("No overload of function '%s' that takes %d args" % (f.__name__, len(args))) 169 | return self._items[f.__name__][len(args)](*args, **kwargs) 170 | if f.__name__ not in self._items: 171 | self._items[f.__name__] = {} 172 | if argc['argc'] is None: 173 | argc['argc'] = len(argspec(f).args) 174 | self._items[f.__name__][argc['argc']] = f 175 | _newf.__name__ = f.__name__ 176 | _newf.__doc__ = f.__doc__ 177 | _newf.__is_overload__ = True 178 | _newf.__orig_arg__ = argspec(f) 179 | if IPython: 180 | _newf.__orig_arg_ipy__ = IPython.core.oinspect.getargspec(f) 181 | return _newf 182 | return _wrap 183 | @classmethod 184 | def args(self, *argtypes, **kw): 185 | '''Overload a function based on the specified argument types. 186 | 187 | :param argtypes: The argument types. If None is given, get the argument types from the function annotations(Python 3 only) 188 | :param kw: Can only contain 1 argument, `is_cls`. If True, the function is assumed to be part of a class. 189 | 190 | Example:: 191 | 192 | @overload.args(str) 193 | def func(s): 194 | print 'Got string' 195 | 196 | @overload.args(int, str) 197 | def func(i, s): 198 | print 'Got int and string' 199 | 200 | @overload.args() 201 | def func(i:int): # A function annotation example 202 | print 'Got int' 203 | 204 | func('s') 205 | func(1) 206 | func(1, 's') 207 | func(True) # Raises error 208 | ''' 209 | 210 | # Python 2 UnboundLocalError fix...again! 211 | argtypes = {'args': tuple(argtypes)} 212 | def _wrap(f): 213 | def _newf(*args): 214 | if len(kw) == 0: 215 | cargs = args 216 | elif len(kw) == 1 and 'is_cls' in kw and kw['is_cls']: 217 | cargs = args[1:] 218 | else: 219 | raise ValueError('Invalid keyword args specified') 220 | if _gettypes(cargs) not in self._types[f.__name__]: 221 | raise TypeError("No overload of function '%s' that takes '%s' types and %d arg(s)" % (f.__name__, _gettypes(cargs), len(cargs))) 222 | return self._types[f.__name__][_gettypes(cargs)](*args) 223 | if f.__name__ not in self._types: 224 | self._types[f.__name__] = {} 225 | if len(argtypes['args']) == 1 and argtypes['args'][0] is None: 226 | aspec = argspec(f) 227 | argtypes['args'] = tuple(map(lambda x: x[1], sorted( 228 | aspec.annotations.items(), key=lambda x: aspec.args.index(x[0])))) 229 | self._types[f.__name__][argtypes['args']] = f 230 | _newf.__name__ = f.__name__ 231 | _newf.__doc__ = f.__doc__ 232 | _newf.__is_overload__ = True 233 | _newf.__orig_arg__ = argspec(f) 234 | if IPython: 235 | _newf.__orig_arg_ipy__ = IPython.core.oinspect.getargspec(f) 236 | return _newf 237 | return _wrap 238 | 239 | class _RuntimeModule(object): 240 | 'Create a module object at runtime and insert it into sys.path. If called, same as :py:func:`from_objects`.' 
241 | def __call__(self, *args, **kwargs): 242 | return self.from_objects(*args, **kwargs) 243 | @staticmethod 244 | @overload.argc(1) 245 | def from_objects(module_name_for_code_eval, **d): 246 | return _RuntimeModule.from_objects(module_name_for_code_eval, '', **d) 247 | @staticmethod 248 | @overload.argc(2) 249 | def from_objects(module_name_for_code_eval, docstring, **d): 250 | '''Create a module at runtime from `d`. 251 | 252 | :param name: The module name. 253 | 254 | :param docstring: Optional. The module's docstring. 255 | 256 | :param \*\*d: All the keyword args, mapped from name->value. 257 | 258 | Example: ``RuntimeModule.from_objects('name', 'doc', a=1, b=2)``''' 259 | module = types.ModuleType(module_name_for_code_eval, docstring) 260 | module.__dict__.update(d) 261 | module.__file__ = '' 262 | sys.modules[module_name_for_code_eval] = module 263 | return module 264 | @staticmethod 265 | @overload.argc(2) 266 | def from_string(module_name_for_code_eval, s): 267 | return _RuntimeModule.from_string(module_name_for_code_eval, '', s) 268 | @staticmethod 269 | @overload.argc(3) 270 | def from_string(module_name_for_code_eval, docstring, s): 271 | '''Create a module at runtime from `s``. 272 | 273 | :param name: The module name. 274 | 275 | :param docstring: Optional. The module docstring. 276 | 277 | :param s: A string containing the module definition.''' 278 | g = {} 279 | _exec(s, g) 280 | return _RuntimeModule.from_objects(module_name_for_code_eval, docstring, **dict(filter(lambda x: x[0] not in g_backup, g.items()))) 281 | 282 | RuntimeModule = _RuntimeModule() 283 | 284 | class CaseObject(object): 285 | 'The object returned by a switch statement. When called, it will return True if the given argument equals its value, else False. It can be called with multiple parameters, in which case it checks if its value equals any of the arguments.' 286 | def __init__(self, value): 287 | self.value = value 288 | self.did_match = False 289 | self.did_pass = False 290 | def __call__(self, *args): 291 | if assign('res', not self.did_pass and any([self.value == rhs for rhs in args])): 292 | self.did_match = True 293 | return res 294 | def quit(self): 295 | 'Forces all other calls to return False. Equilavent of a ``break`` statement.' 296 | self.did_pass = True 297 | def default(self): 298 | "Executed if quit wasn't called." 299 | return not self.did_match and not self.did_pass 300 | def __iter__(self): 301 | yield self 302 | def __enter__(self): 303 | return self 304 | def __exit__(self, *args): 305 | pass 306 | 307 | def switch(value): 308 | '''A Python switch statement implementation that is used with a ``with`` statement. 309 | 310 | :param value: The value to "switch". 311 | 312 | ``with`` statement example:: 313 | 314 | with switch('x'): 315 | if case(1): print 'Huh?' 316 | if case('x'): print 'It works!!!' 317 | 318 | .. warning:: If you modify a variable named "case" in the same scope that you use the ``with`` statement version, you will get an UnboundLocalError. The soluction is to use ``with switch('x') as case:`` instead of ``with switch('x'):``.''' 319 | res = CaseObject(value) 320 | inspect.stack()[1][0].f_globals['case'] = res 321 | return res 322 | 323 | def tail_recurse(spec=None): 324 | '''Remove tail recursion from a function. 325 | 326 | :param spec: A function that, when given the arguments, returns a bool indicating whether or not to exit. If ``None,`` tail recursion is always called unless the function returns a value. 327 | 328 | .. 
note:: 329 | 330 | This function has a slight overhead that is noticable when using timeit. Only use it if the function has a possibility of going over the recursion limit. 331 | 332 | .. warning:: 333 | 334 | This function will BREAK any code that either uses any recursion other than tail recursion or calls itself multiple times. For example, ``def x(): return x()+1`` will fail. 335 | 336 | Example:: 337 | 338 | @tail_recurse() 339 | def add(a, b): 340 | if a == 0: return b 341 | return add(a-1, b+1) 342 | 343 | add(10000000, 1) # Doesn't max the recursion limit. 344 | ''' 345 | def _wrap(f): 346 | class TailRecursion(Exception): 347 | def __init__(self, args, kwargs): 348 | self.args = args 349 | self.kwargs = kwargs 350 | def _newf(*args, **kwargs): 351 | if inspect.stack()[1][3] == f.__name__: 352 | if (spec and spec(args)) or not spec: 353 | raise TailRecursion(args, kwargs) 354 | while True: 355 | try: 356 | res = f(*args, **kwargs) 357 | except TailRecursion as ex: 358 | args = ex.args 359 | kwargs = ex.kwargs 360 | continue 361 | else: 362 | return res 363 | _newf.__doc__ = f.__doc__ 364 | return _newf 365 | return _wrap 366 | 367 | def annotate(*args, **kwargs): 368 | '''Set function annotations using decorators. 369 | 370 | :param args: This is a list of annotations for the function, in the order of the function's parameters. For example, ``annotate('Annotation 1', 'Annotation 2')`` will set the annotations of parameter 1 of the function to ``Annotation 1``. 371 | 372 | :param kwargs: This is a mapping of argument names to annotations. Note that these are applied *after* the argument list, so any args set that way will be overriden by this mapping. If there is a key named `ret`, that will be the annotation for the function's return value. 373 | 374 | .. deprecated:: 0.5 375 | Use :func:`fannotate` instead. 376 | ''' 377 | def _wrap(f): 378 | if not hasattr(f, '__annotations__'): 379 | f.__annotations__ = {} 380 | if 'ret' in kwargs: 381 | f.__annotations__['return'] = kwargs.pop('ret') 382 | f.__annotations__.update(dict(zip(argspec(f).args, args))) 383 | f.__annotations__.update(kwargs) 384 | return f 385 | return _wrap 386 | 387 | def fannotate(*args, **kwargs): 388 | '''Set function annotations using decorators. 389 | 390 | :param \*args: The first positional argument is used for the function's return value; all others are discarded. 391 | 392 | :param \**kwargs: This is a mapping of argument names to annotations. 393 | 394 | Example:: 395 | 396 | @fannotate('This for the return value', a='Parameter a', b='Parameter b') 397 | def x(a, b): 398 | pass 399 | 400 | ''' 401 | def _wrap(f): 402 | if not hasattr(f, '__annotations__'): 403 | f.__annotations__ = {} 404 | if len(args) >= 1: 405 | f.__annotations__['return'] = args[0] 406 | f.__annotations__.update(kwargs) 407 | return f 408 | return _wrap 409 | 410 | def safe_unpack(seq, ln, fill=None): 411 | '''Safely unpack a sequence to length `ln`, without raising ValueError. Based on Lua's method of unpacking. Empty values will be filled in with `fill`, while any extra values will be cut off. 412 | 413 | :param seq: The sequence to unpack. 414 | 415 | :param ln: The expected length of the sequence. 416 | 417 | :param fill: The value to substitute if the sequence is too small. Defaults to ``None``. 
418 | 419 | Example:: 420 | 421 | s = 'a:b' 422 | a, b = safe_unpack(s.split(':'), 2) 423 | # a = 'a' 424 | # b = 'b' 425 | s = 'a' 426 | a, b = safe_unpack(s.split(':'), 2) 427 | # a = 'a' 428 | # b = None''' 429 | if len(seq) > ln: 430 | return seq[:ln] 431 | elif len(seq) < ln: 432 | return seq + type(seq)([fill]*(ln-len(seq))) 433 | else: 434 | return seq 435 | 436 | def assign(varname, value): 437 | '''Assign `value` to `varname` and return it. If `varname` is an attribute and the instance name it belongs to is not defined, a NameError is raised. 438 | This can be used to emulate assignment as an expression. For example, this:: 439 | 440 | if assign('x', 7): ... 441 | 442 | is equilavent to this C code:: 443 | 444 | if (x = 7) ... 445 | 446 | .. warning:: 447 | 448 | When assigning an attribute, the instance it belongs to MUST be declared as global prior to the assignment. Otherwise, the assignment will not work. 449 | ''' 450 | fd = inspect.stack()[1][0].f_globals 451 | if '.' not in varname: 452 | fd[varname] = value 453 | else: 454 | vsplit = list(map(str.strip, varname.split('.'))) 455 | if vsplit[0] not in fd: 456 | raise NameError('Unknown object: %s'%vsplit[0]) 457 | base = fd[vsplit[0]] 458 | for x in vsplit[1:-1]: 459 | base = getattr(base, x) 460 | setattr(base, vsplit[-1], value) 461 | return value 462 | 463 | def is_main(frame=1): 464 | "Return if the caller is main. Equilavent to ``__name__ == '__main__'``." 465 | return inspect.stack()[frame][0].f_globals['__name__'] == '__main__' 466 | 467 | def _call_if_main(frame, f, args): 468 | if is_main(frame): return f(*args) 469 | 470 | def call_if_main(f,*args): 471 | "Call the `f` with `args` if the caller's module is main." 472 | return _call_if_main(3,f,args) 473 | 474 | def run_main(f,*args): 475 | "Call `f` with the `args` and terminate the program with its return code if the caller's module is main." 476 | sys.exit(_call_if_main(3,f,args)) 477 | 478 | def compare_and_swap(var, compare, new): 479 | "If `var` is equal to `compare`, set it to `new`." 
480 | if assign('v', inspect.stack()[1][0].f_globals)[var] == compare: 481 | v[var] = new 482 | -------------------------------------------------------------------------------- /train/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.abspath(os.path.join(__file__, "../../"))) 4 | -------------------------------------------------------------------------------- /train/deepseed/ds_z3_offload_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 | "loss_scale": 0, 10 | "loss_scale_window": 1000, 11 | "initial_scale_power": 16, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | }, 15 | "bf16": { 16 | "enabled": "auto" 17 | }, 18 | "zero_optimization": { 19 | "stage": 3, 20 | "offload_optimizer": { 21 | "device": "cpu", 22 | "pin_memory": true 23 | }, 24 | "offload_param": { 25 | "device": "cpu", 26 | "pin_memory": true 27 | }, 28 | "overlap_comm": true, 29 | "contiguous_gradients": true, 30 | "sub_group_size": 1e9, 31 | "reduce_bucket_size": "auto", 32 | "stage3_prefetch_bucket_size": "auto", 33 | "stage3_param_persistence_threshold": "auto", 34 | "stage3_max_live_parameters": 1e9, 35 | "stage3_max_reuse_distance": 1e9, 36 | "stage3_gather_16bit_weights_on_model_save": true 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /train/deepseed/zero2_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 | "loss_scale": 0, 10 | "loss_scale_window": 1000, 11 | "initial_scale_power": 16, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | }, 15 | "bf16": { 16 | "enabled": "auto" 17 | }, 18 | "zero_optimization": { 19 | "stage": 2, 20 | "allgather_partitions": true, 21 | "allgather_bucket_size": 5e8, 22 | "overlap_comm": true, 23 | "reduce_scatter": true, 24 | "reduce_bucket_size": 5e8, 25 | "contiguous_gradients": true, 26 | "round_robin_gradients": true 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /train/deepseed/zero3_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 | "loss_scale": 0, 10 | "loss_scale_window": 1000, 11 | "initial_scale_power": 16, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | }, 15 | "bf16": { 16 | "enabled": "auto" 17 | }, 18 | "zero_optimization": { 19 | "stage": 3, 20 | "overlap_comm": true, 21 | "contiguous_gradients": true, 22 | "sub_group_size": 1e9, 23 | "reduce_bucket_size": "auto", 24 | "stage3_prefetch_bucket_size": "auto", 25 | "stage3_param_persistence_threshold": "auto", 26 | "stage3_max_live_parameters": 1e9, 27 | "stage3_max_reuse_distance": 1e9, 28 | "stage3_gather_16bit_weights_on_model_save": true 29 | }, 30 | "wandb": { 31 
| "enabled": true, 32 | "rank_zero_only": true 33 | } 34 | } -------------------------------------------------------------------------------- /train/deepseed/zero3_config2.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": true 4 | }, 5 | "zero_optimization": { 6 | "stage": 3, 7 | "allgather_partitions": true, 8 | "allgather_bucket_size": 2e8, 9 | "overlap_comm": true, 10 | "reduce_scatter": true, 11 | "reduce_bucket_size": "auto", 12 | "contiguous_gradients": true, 13 | "stage3_gather_16bit_weights_on_model_save": true 14 | }, 15 | "gradient_accumulation_steps": "auto", 16 | "gradient_clipping": "auto", 17 | "train_batch_size": "auto", 18 | "train_micro_batch_size_per_gpu": "auto", 19 | "wall_clock_breakdown": false 20 | } -------------------------------------------------------------------------------- /train/dpo_train.py: -------------------------------------------------------------------------------- 1 | import __init__ 2 | from utils.utils import * 3 | import os 4 | import warnings 5 | from trl import DPOConfig, DPOTrainer 6 | import wandb 7 | from utils.load_model import * 8 | import argparse 9 | from train.names import * 10 | 11 | warnings.filterwarnings('ignore') 12 | os.environ["WANDB_MODE"] = "offline" 13 | 14 | def parse_args(args=None): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--lr', type=float, default=5e-7) 17 | parser.add_argument('--beta', type=float, default=0.05) 18 | parser.add_argument('--model', type=str, default='Bespoke-7b') 19 | parser.add_argument('--dataset', type=str, default='Bespoke_dpo') 20 | parser.add_argument('--epoch', type=int, default=1) 21 | parser.add_argument('--MAX_LENGTH', type=int, default=1024 * 8) 22 | parser.add_argument('--seed', type=int, default=42) 23 | parser.add_argument('--gradient_accumulation_steps', type=int, default=12) 24 | parser.add_argument('--deepspeed', type=str, default=None) # 25 | parser.add_argument('--local_rank', type=int, default=0) 26 | return parser.parse_args(args) 27 | 28 | 29 | if __name__ == '__main__': 30 | args = parse_args() 31 | lr, beta = args.lr, args.beta 32 | model_name = model_names[args.model] 33 | MAX_LENGTH = args.MAX_LENGTH 34 | dataset_name = set_global(dataset_names[args.dataset]) 35 | wandb_name = f"{args.model}_{args.dataset}_dpo_lr{lr}_beta{beta}" 36 | 37 | training_config = DPOConfig( 38 | save_only_model=True, 39 | output_dir=set_global(f"./train/models/dpo/{wandb_name}"), 40 | per_device_train_batch_size=1, 41 | gradient_accumulation_steps=args.gradient_accumulation_steps, 42 | gradient_checkpointing=True, 43 | save_total_limit=1, 44 | num_train_epochs=args.epoch, 45 | report_to="wandb", 46 | save_strategy='epoch', 47 | logging_steps=1, 48 | learning_rate=lr, 49 | beta=beta, 50 | bf16=True, 51 | lr_scheduler_type='cosine', 52 | warmup_ratio = 0.1, 53 | max_length=MAX_LENGTH, 54 | deepspeed= set_global(args.deepspeed) if args.deepspeed is not None else None, 55 | ) 56 | 57 | train_dataset = load_train_data(dataset_name) 58 | model, tokenizer = init_train_model(model_name) 59 | ref_model = init_train_model(model_name)[0] 60 | ref_model.eval() 61 | # ref_model = None 62 | 63 | if args.local_rank == 0: 64 | wandb.login() 65 | wandb.init(project="ThinkPO", name=wandb_name) 66 | 67 | dpo_trainer = DPOTrainer( 68 | model, 69 | ref_model=ref_model, 70 | args=training_config, 71 | train_dataset=train_dataset, 72 | tokenizer=tokenizer, 73 | ) 74 | dpo_trainer.train() 
-------------------------------------------------------------------------------- /train/names.py: -------------------------------------------------------------------------------- 1 | import __init__ 2 | from utils.utils import * 3 | from utils.data_utils import save_all_results, read_saved_results, load_data 4 | from utils.load_model import * 5 | 6 | def init_train_model(model_name=''): 7 | model, tokenizer = init_model(model_name, eval=False, low_cpu_mem_usage=False) 8 | model.enable_input_require_grads() 9 | if tokenizer.pad_token is None: 10 | token = tokenizer.convert_ids_to_tokens(2) 11 | print(f"Token for ID 2: {token}") 12 | tokenizer.pad_token_id = 2 13 | return model, tokenizer 14 | 15 | def load_train_data(choose_data_name): 16 | if choose_data_name in ['Bespoke_dpo', 'Bespoke', 'Deepseek']: 17 | choose_data_name = dataset_names[choose_data_name] 18 | choose_data = load_data(choose_data_name, 'huggingface') 19 | else: 20 | choose_data_name = set_global(dataset_names[choose_data_name]) 21 | choose_data = load_data(choose_data_name, 'json') 22 | choose_data = choose_data['train'] 23 | return choose_data 24 | 25 | dataset_names = { 26 | 'NuminaMath': 'AI-MO/NuminaMath-CoT', 27 | 'Bespoke':'bespokelabs/Bespoke-Stratos-17k', 28 | 29 | 'Bespoke_dpo':f'VanWang/Bespoke_dpo_filter', 30 | 'Bespoke_dpo_long':f'data/final/Bespoke_dpo_filter_len_long.jsonl', 31 | 'Bespoke_dpo_short':f'data/final/Bespoke_dpo_filter_len_short.jsonl', 32 | 'Bespoke_dpo_middle':'data/final/Bespoke_dpo_filter_len_middle.jsonl' 33 | } 34 | model_names = { 35 | 'Deepseek-7b':'deepseek-ai/DeepSeek-R1-Distill-Qwen-7B', 36 | 37 | 'Instruct-1.5b':'Qwen/Qwen2.5-1.5B-Instruct', 38 | 'Instruct-3b':'Qwen/Qwen2.5-3B-Instruct', 39 | 'Instruct-7b':'Qwen/Qwen2.5-7B-Instruct', 40 | 'Instruct-3b':'Qwen/Qwen2.5-3B-Instruct', 41 | 'Instruct-14b':'Qwen/Qwen2.5-14B-Instruct', 42 | 'Instruct-32b':'Qwen/Qwen2.5-32B-Instruct', 43 | 44 | 'Bespoke-32b': 'bespokelabs/Bespoke-Stratos-32B', 45 | 'Bespoke-7b': 'bespokelabs/Bespoke-Stratos-7B', 46 | 47 | 'OpenThinker-7B':'open-thoughts/OpenThinker-7B', 48 | } -------------------------------------------------------------------------------- /train/sft_train.py: -------------------------------------------------------------------------------- 1 | from __init__ import * 2 | from utils.utils import * 3 | from transformers import Trainer, TrainingArguments,DataCollatorForSeq2Seq 4 | import wandb 5 | from utils.data_utils import load_data 6 | from utils.load_model import * 7 | import argparse 8 | import os 9 | from train.names import model_names, dataset_names 10 | 11 | os.environ['WANDB_MODE'] = 'offline' 12 | 13 | def init_train_model(model_name=''): 14 | model, tokenizer = init_model(model_name=model_names[model_name], eval=False) 15 | model.enable_input_require_grads() 16 | tokenizer.padding_side = "right" 17 | if tokenizer.pad_token is None: 18 | token = tokenizer.convert_ids_to_tokens(2) 19 | print(f"Token for ID 2: {token}") 20 | tokenizer.pad_token_id = 2 21 | return model, tokenizer 22 | 23 | def preprocess_function(example): 24 | global MAX_LENGTH 25 | input_ids, attention_mask, labels = [], [], [] 26 | system, conversations = example["system"], example["conversations"] 27 | targets = conversations[1]['value'].strip() 28 | messages = [ 29 | {"role": "system", "content": system}, 30 | {"role": "user", "content": conversations[0]['value']} 31 | ] 32 | inputs = tokenizer.apply_chat_template( 33 | messages, 34 | tokenize=False, 35 | add_generation_prompt=True 36 | ) 37 | instruction, response = 
tokenizer(inputs+'\n'), tokenizer(targets) 38 | input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.eos_token_id] 39 | attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1] 40 | labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.eos_token_id] 41 | if len(input_ids) > MAX_LENGTH: 42 | input_ids = input_ids[:MAX_LENGTH] 43 | attention_mask = attention_mask[:MAX_LENGTH] 44 | labels = labels[:MAX_LENGTH] 45 | return { 46 | "input_ids": input_ids, 47 | "attention_mask": attention_mask, 48 | "labels": labels 49 | } 50 | 51 | def load_train_data(dataset_name): 52 | if dataset_name == 'NuminaMath': 53 | dataset = load_data(set_global(dataset_names[dataset_name]), 'json') 54 | train_dataset = dataset['train'] 55 | if dataset_name == 'Bespoke': 56 | dataset = load_data(dataset_names[dataset_name], 'huggingface') 57 | train_dataset = dataset['train'] 58 | return train_dataset.map(preprocess_function, remove_columns=train_dataset.column_names) 59 | 60 | 61 | def parse_args(args=None): 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('--lr', type=float, default=3e-5) 64 | parser.add_argument('--MAX_LENGTH', type=int, default=1024*10) 65 | parser.add_argument('--model_name', type=str, default='Instruct-14b', help='Instruct-3b, Instruct-7b, Instruct-14b, Instruct-32b') 66 | parser.add_argument('--seed', type=int, default=42) 67 | parser.add_argument('--gradient_accumulation_steps', type=int, default=96) 68 | parser.add_argument('--dataset_name', type=str, default='Bespoke', help='Bespoke, NuminaMath') 69 | parser.add_argument('--deepspeed', type=str, default=None) # 70 | parser.add_argument('--local_rank', type=int, default=0) 71 | parser.add_argument('--epoch', type=int, default=1) 72 | return parser.parse_args(args) 73 | 74 | args = parse_args() 75 | 76 | lr = args.lr 77 | MAX_LENGTH = args.MAX_LENGTH 78 | model_name = args.model_name 79 | dataset_name = args.dataset_name 80 | wandb_name = f"{dataset_name}_{model_name}_sft_lr{lr}" 81 | 82 | model, tokenizer = init_train_model(model_name=model_name) 83 | train_dataset = load_train_data(dataset_name) 84 | 85 | if args.local_rank == 0: 86 | wandb.login() 87 | wandb.init(project="ThinkPO-SFT", name=wandb_name) 88 | 89 | training_args = TrainingArguments( 90 | save_only_model=True, 91 | output_dir=set_global(f"./train/models/sft/{wandb_name}"), 92 | per_device_train_batch_size=1, 93 | gradient_accumulation_steps=args.gradient_accumulation_steps, 94 | save_total_limit=1, 95 | num_train_epochs=args.epoch, 96 | learning_rate=lr, 97 | gradient_checkpointing=True, 98 | lr_scheduler_type='cosine', 99 | warmup_ratio=0.1, 100 | logging_steps=10, 101 | save_strategy='epoch', 102 | report_to="wandb", 103 | bf16=True, 104 | # dataloader_num_workers=4, 105 | deepspeed= set_global(args.deepspeed) if args.deepspeed is not None else None, 106 | ) 107 | 108 | trainer = Trainer( 109 | model=model, 110 | args=training_args, 111 | tokenizer=tokenizer, 112 | train_dataset=train_dataset, 113 | data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True) 114 | ) 115 | trainer.train() 116 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.abspath(os.path.join(__file__, "../../"))) -------------------------------------------------------------------------------- /utils/data_utils.py: 
-------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from .utils import cache_dir 3 | import os 4 | import json 5 | 6 | def save_results(out_path, data): 7 | if os.path.exists(out_path): 8 | with open(out_path, "a", encoding="utf-8") as f: 9 | json.dump(data, f) 10 | f.write('\n') 11 | else: 12 | save_all_results(out_path, [data]) 13 | 14 | 15 | def save_all_results(out_path, data): 16 | os.makedirs(os.path.dirname(out_path), exist_ok=True) 17 | with open(out_path, "w", encoding="utf-8") as f: 18 | for pred in data: 19 | json.dump(pred, f, ensure_ascii=False) 20 | f.write('\n') 21 | 22 | def read_saved_results(out_path): 23 | preds=[] 24 | if os.path.exists(out_path): 25 | with open(out_path, "r", encoding="utf-8") as f: 26 | i=0 27 | for line in f: 28 | l = json.loads(line) 29 | preds.append(l) 30 | i+=1 31 | print(i) 32 | return preds 33 | 34 | 35 | def load_data(dataset_name, mode='json'): 36 | if mode == 'huggingface': 37 | dataset = load_dataset(dataset_name) 38 | if mode == 'json': 39 | dataset = load_dataset('json', data_files=dataset_name) 40 | return dataset 41 | -------------------------------------------------------------------------------- /utils/eval/eval_utils.py: -------------------------------------------------------------------------------- 1 | from .qwen_math_parser import * 2 | 3 | 4 | def check_math_correctness(reference, generation): 5 | if not find_box(generation) or not find_box(reference): return False 6 | ref = extract_answer(reference) 7 | answer = strip_answer_string(ref) 8 | pred = extract_answer(generation) 9 | pred = strip_answer_string(pred) 10 | return math_equal(pred, answer) 11 | 12 | def check_math_correctness_with_model(reference, generation): 13 | pass -------------------------------------------------------------------------------- /utils/eval/qwen_math_parser.py: -------------------------------------------------------------------------------- 1 | """ 2 | The logic in this file largely borrows from Qwen2.5-Math codebase at https://github.com/QwenLM/Qwen2.5-Math: 3 | """ 4 | 5 | import re 6 | import regex 7 | from word2number import w2n 8 | from math import isclose 9 | from collections import defaultdict 10 | 11 | from sympy import simplify, N 12 | from sympy.parsing.sympy_parser import parse_expr 13 | from sympy.parsing.latex import parse_latex 14 | from latex2sympy2 import latex2sympy 15 | 16 | 17 | def convert_word_number(text: str) -> str: 18 | try: 19 | text = str(w2n.word_to_num(text)) 20 | except: 21 | pass 22 | return text 23 | 24 | def _fix_fracs(string): 25 | substrs = string.split("\\frac") 26 | new_str = substrs[0] 27 | if len(substrs) > 1: 28 | substrs = substrs[1:] 29 | for substr in substrs: 30 | new_str += "\\frac" 31 | if len(substr) > 0 and substr[0] == "{": 32 | new_str += substr 33 | else: 34 | try: 35 | assert len(substr) >= 2 36 | except: 37 | return string 38 | a = substr[0] 39 | b = substr[1] 40 | if b != "{": 41 | if len(substr) > 2: 42 | post_substr = substr[2:] 43 | new_str += "{" + a + "}{" + b + "}" + post_substr 44 | else: 45 | new_str += "{" + a + "}{" + b + "}" 46 | else: 47 | if len(substr) > 2: 48 | post_substr = substr[2:] 49 | new_str += "{" + a + "}" + b + post_substr 50 | else: 51 | new_str += "{" + a + "}" + b 52 | string = new_str 53 | return string 54 | 55 | 56 | def _fix_a_slash_b(string): 57 | if len(string.split("/")) != 2: 58 | return string 59 | a = string.split("/")[0] 60 | b = string.split("/")[1] 61 | try: 62 | if "sqrt" not in a: 63 | a 
= int(a) 64 | if "sqrt" not in b: 65 | b = int(b) 66 | assert string == "{}/{}".format(a, b) 67 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 68 | return new_string 69 | except: 70 | return string 71 | 72 | 73 | def _fix_sqrt(string): 74 | _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string) 75 | return _string 76 | 77 | def strip_answer_string(string): 78 | string = str(string).strip() 79 | # linebreaks 80 | string = string.replace("\n", "") 81 | 82 | # right "." 83 | string = string.rstrip(".") 84 | 85 | # remove inverse spaces 86 | # replace \\ with \ 87 | string = string.replace("\\!", "") 88 | # string = string.replace("\\ ", "") 89 | # string = string.replace("\\\\", "\\") 90 | 91 | # matrix 92 | string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string) 93 | string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string) 94 | string = string.replace("bmatrix", "pmatrix") 95 | 96 | # replace tfrac and dfrac with frac 97 | string = string.replace("tfrac", "frac") 98 | string = string.replace("dfrac", "frac") 99 | string = ( 100 | string.replace("\\neq", "\\ne") 101 | .replace("\\leq", "\\le") 102 | .replace("\\geq", "\\ge") 103 | ) 104 | 105 | # remove \left and \right 106 | string = string.replace("\\left", "") 107 | string = string.replace("\\right", "") 108 | string = string.replace("\\{", "{") 109 | string = string.replace("\\}", "}") 110 | 111 | # Function to replace number words with corresponding digits 112 | def replace_match(match): 113 | word = match.group(1).lower() 114 | if convert_word_number(word) == word: 115 | return match.group(0) 116 | else: 117 | return convert_word_number(word) 118 | string = re.sub(r"\\text\{([a-zA-Z]+)\}", replace_match, string) 119 | 120 | # Before removing unit, check if the unit is squared (for surface area) 121 | string = re.sub(r"(cm|inches)\}\^2", r"\1}", string) 122 | 123 | # Remove unit: miles, dollars if after is not none 124 | _string = re.sub(r"\\text{.*?}$", "", string).strip() 125 | if _string != "" and _string != string: 126 | # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string)) 127 | string = _string 128 | 129 | # Remove circ (degrees) 130 | string = string.replace("^{\\circ}", "") 131 | string = string.replace("^\\circ", "") 132 | 133 | # remove dollar signs 134 | string = string.replace("\\$", "") 135 | string = string.replace("$", "") 136 | string = string.replace("\\(", "").replace("\\)", "") 137 | 138 | # convert word number to digit 139 | string = convert_word_number(string) 140 | 141 | # replace "\\text{...}" to "..." 142 | string = re.sub(r"\\text\{(.*?)\}", r"\1", string) 143 | for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]: 144 | string = string.replace(key, "") 145 | string = string.replace("\\emptyset", r"{}") 146 | string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}") 147 | 148 | # remove percentage 149 | string = string.replace("\\%", "") 150 | string = string.replace("\%", "") 151 | string = string.replace("%", "") 152 | 153 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." 
is the start of the string 154 | string = string.replace(" .", " 0.") 155 | string = string.replace("{.", "{0.") 156 | 157 | # cdot 158 | # string = string.replace("\\cdot", "") 159 | if ( 160 | string.startswith("{") 161 | and string.endswith("}") 162 | and string.isalnum() 163 | or string.startswith("(") 164 | and string.endswith(")") 165 | and string.isalnum() 166 | or string.startswith("[") 167 | and string.endswith("]") 168 | and string.isalnum() 169 | ): 170 | string = string[1:-1] 171 | 172 | # inf 173 | string = string.replace("infinity", "\\infty") 174 | if "\\infty" not in string: 175 | string = string.replace("inf", "\\infty") 176 | string = string.replace("+\\inity", "\\infty") 177 | 178 | # and 179 | string = string.replace("and", "") 180 | string = string.replace("\\mathbf", "") 181 | 182 | # use regex to remove \mbox{...} 183 | string = re.sub(r"\\mbox{.*?}", "", string) 184 | 185 | # quote 186 | string.replace("'", "") 187 | string.replace('"', "") 188 | 189 | # i, j 190 | if "j" in string and "i" not in string: 191 | string = string.replace("j", "i") 192 | 193 | # replace a.000b where b is not number or b is end, with ab, use regex 194 | string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string) 195 | string = re.sub(r"(\d+)\.0*$", r"\1", string) 196 | 197 | # if empty, return empty string 198 | if len(string) == 0: 199 | return string 200 | if string[0] == ".": 201 | string = "0" + string 202 | 203 | # to consider: get rid of e.g. "k = " or "q = " at beginning 204 | if len(string.split("=")) == 2: 205 | if len(string.split("=")[0]) <= 2: 206 | string = string.split("=")[1] 207 | 208 | string = _fix_sqrt(string) 209 | string = string.replace(" ", "") 210 | 211 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} 212 | string = _fix_fracs(string) 213 | 214 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y 215 | string = _fix_a_slash_b(string) 216 | 217 | # Remove unnecessary '\' before integers 218 | string = re.sub(r"\\(?=\-?\d+(\\|\)|,|\]|$))", "", string) 219 | 220 | # Remove grade level (e.g., 12th grade) and just maintain the integer 221 | string = re.sub(r"thgrade$", "", string) 222 | 223 | # If the answer is a list of integers (without parenthesis), sort them 224 | if re.fullmatch(r"(\s*-?\d+\s*,)*\s*-?\d+\s*", string): 225 | if ',' in string: 226 | # Split the string into a list of integers 227 | integer_list = list(map(int, string.split(','))) 228 | 229 | # Sort the list in ascending order 230 | sorted_list = sorted(integer_list) 231 | 232 | # Join the sorted list back into a comma-separated string 233 | string = ','.join(map(str, sorted_list)) 234 | 235 | return string 236 | 237 | def extract_answer(pred_str, use_last_number=True): 238 | pred_str = pred_str.replace("\u043a\u0438", "") 239 | if "final answer is $" in pred_str and "$. I hope" in pred_str: 240 | # minerva_math 241 | tmp = pred_str.split("final answer is $", 1)[1] 242 | pred = tmp.split("$. 
I hope", 1)[0].strip() 243 | elif "boxed" in pred_str: 244 | ans = pred_str.split("boxed")[-1] 245 | if len(ans) == 0: 246 | return "" 247 | elif ans[0] == "{": 248 | stack = 1 249 | a = "" 250 | for c in ans[1:]: 251 | if c == "{": 252 | stack += 1 253 | a += c 254 | elif c == "}": 255 | stack -= 1 256 | if stack == 0: 257 | break 258 | a += c 259 | else: 260 | a += c 261 | else: 262 | a = ans.split("$")[0].strip() 263 | pred = a 264 | elif "he answer is" in pred_str: 265 | pred = pred_str.split("he answer is")[-1].strip() 266 | elif "final answer is" in pred_str: 267 | pred = pred_str.split("final answer is")[-1].strip() 268 | elif "答案是" in pred_str: 269 | # Handle Chinese few-shot multiple choice problem answer extraction 270 | pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip() 271 | else: # use the last number 272 | if use_last_number: 273 | pattern = "-?\d*\.?\d+" 274 | pred = re.findall(pattern, pred_str.replace(",", "")) 275 | if len(pred) >= 1: 276 | pred = pred[-1] 277 | else: 278 | pred = "" 279 | else: 280 | pred = "" 281 | 282 | # multiple line 283 | # pred = pred.split("\n")[0] 284 | pred = re.sub(r"\n\s*", "", pred) 285 | if pred != "" and pred[0] == ":": 286 | pred = pred[1:] 287 | if pred != "" and pred[-1] == ".": 288 | pred = pred[:-1] 289 | if pred != "" and pred[-1] == "/": 290 | pred = pred[:-1] 291 | pred = strip_answer_string(pred) 292 | return pred 293 | 294 | def get_multiple_choice_answer(pred: str): 295 | tmp = re.findall(r"\b(A|B|C|D)\b", pred.upper()) 296 | if tmp: 297 | pred = tmp 298 | else: 299 | pred = [pred.strip().strip(".")] 300 | 301 | if len(pred) == 0: 302 | pred = "" 303 | else: 304 | pred = pred[-1] 305 | 306 | # Remove the period at the end, again! 307 | pred = pred.rstrip(".").rstrip("/") 308 | 309 | return pred 310 | 311 | def choice_answer_clean(pred: str): 312 | pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":") 313 | # Clean the answer based on the dataset 314 | tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper()) 315 | if tmp: 316 | pred = tmp 317 | else: 318 | pred = [pred.strip().strip(".")] 319 | pred = pred[-1] 320 | # Remove the period at the end, again! 321 | pred = pred.rstrip(".").rstrip("/") 322 | return pred 323 | 324 | 325 | def parse_digits(num): 326 | num = regex.sub(",", "", str(num)) 327 | try: 328 | return float(num) 329 | except: 330 | if num.endswith("%"): 331 | num = num[:-1] 332 | if num.endswith("\\"): 333 | num = num[:-1] 334 | try: 335 | return float(num) / 100 336 | except: 337 | pass 338 | return None 339 | 340 | 341 | def is_digit(num): 342 | # paired with parse_digits 343 | return parse_digits(num) is not None 344 | 345 | 346 | def str_to_pmatrix(input_str): 347 | input_str = input_str.strip() 348 | matrix_str = re.findall(r"\{.*,.*\}", input_str) 349 | pmatrix_list = [] 350 | 351 | for m in matrix_str: 352 | m = m.strip("{}") 353 | pmatrix = r"\begin{pmatrix}" + m.replace(",", "\\") + r"\end{pmatrix}" 354 | pmatrix_list.append(pmatrix) 355 | 356 | return ", ".join(pmatrix_list) 357 | 358 | 359 | def math_equal( 360 | prediction, 361 | reference, 362 | include_percentage: bool = True, 363 | is_close: bool = True, 364 | timeout: bool = False, 365 | ) -> bool: 366 | """ 367 | Exact match of math if and only if: 368 | 1. numerical equal: both can convert to float and are equal 369 | 2. 
symbolic equal: both can convert to sympy expression and are equal 370 | """ 371 | if prediction is None or reference is None: 372 | return False 373 | if str(prediction.strip().lower()) == str(reference.strip().lower()): 374 | return True 375 | if ( 376 | reference in ["A", "B", "C", "D", "E"] 377 | and choice_answer_clean(prediction) == reference 378 | ): 379 | return True 380 | 381 | try: # 1. numerical equal 382 | if is_digit(prediction) and is_digit(reference): 383 | prediction = parse_digits(prediction) 384 | reference = parse_digits(reference) 385 | # number questions 386 | if include_percentage: 387 | gt_result = [reference / 100, reference, reference * 100] 388 | else: 389 | gt_result = [reference] 390 | for item in gt_result: 391 | try: 392 | if is_close: 393 | if numeric_equal(prediction, item): 394 | return True 395 | else: 396 | if item == prediction: 397 | return True 398 | except Exception: 399 | continue 400 | return False 401 | except: 402 | pass 403 | 404 | if not prediction and prediction not in [0, False]: 405 | return False 406 | 407 | # 2. symbolic equal 408 | reference = str(reference).strip() 409 | prediction = str(prediction).strip() 410 | 411 | ## pmatrix (amps) 412 | if "pmatrix" in prediction and not "pmatrix" in reference: 413 | reference = str_to_pmatrix(reference) 414 | 415 | ## deal with [], (), {} 416 | pred_str, ref_str = prediction, reference 417 | if ( 418 | prediction.startswith("[") 419 | and prediction.endswith("]") 420 | and not reference.startswith("(") 421 | ) or ( 422 | prediction.startswith("(") 423 | and prediction.endswith(")") 424 | and not reference.startswith("[") 425 | ): 426 | pred_str = pred_str.strip("[]()") 427 | ref_str = ref_str.strip("[]()") 428 | for s in ["{", "}", "(", ")"]: 429 | ref_str = ref_str.replace(s, "") 430 | pred_str = pred_str.replace(s, "") 431 | if pred_str.lower() == ref_str.lower(): 432 | return True 433 | 434 | ## [a, b] vs. 
[c, d], return a==c and b==d 435 | if ( 436 | regex.match(r"(\(|\[).+(\)|\])", prediction) is not None 437 | and regex.match(r"(\(|\[).+(\)|\])", reference) is not None 438 | ): 439 | pred_parts = prediction[1:-1].split(",") 440 | ref_parts = reference[1:-1].split(",") 441 | if len(pred_parts) == len(ref_parts): 442 | if all( 443 | [ 444 | math_equal( 445 | pred_parts[i], ref_parts[i], include_percentage, is_close 446 | ) 447 | for i in range(len(pred_parts)) 448 | ] 449 | ): 450 | return True 451 | if ( 452 | ( 453 | prediction.startswith("\\begin{pmatrix}") 454 | or prediction.startswith("\\begin{bmatrix}") 455 | ) 456 | and ( 457 | prediction.endswith("\\end{pmatrix}") 458 | or prediction.endswith("\\end{bmatrix}") 459 | ) 460 | and ( 461 | reference.startswith("\\begin{pmatrix}") 462 | or reference.startswith("\\begin{bmatrix}") 463 | ) 464 | and ( 465 | reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}") 466 | ) 467 | ): 468 | pred_lines = [ 469 | line.strip() 470 | for line in prediction[ 471 | len("\\begin{pmatrix}") : -len("\\end{pmatrix}") 472 | ].split("\\\\") 473 | if line.strip() 474 | ] 475 | ref_lines = [ 476 | line.strip() 477 | for line in reference[ 478 | len("\\begin{pmatrix}") : -len("\\end{pmatrix}") 479 | ].split("\\\\") 480 | if line.strip() 481 | ] 482 | matched = True 483 | if len(pred_lines) == len(ref_lines): 484 | for pred_line, ref_line in zip(pred_lines, ref_lines): 485 | pred_parts = pred_line.split("&") 486 | ref_parts = ref_line.split("&") 487 | if len(pred_parts) == len(ref_parts): 488 | if not all( 489 | [ 490 | math_equal( 491 | pred_parts[i], 492 | ref_parts[i], 493 | include_percentage, 494 | is_close, 495 | ) 496 | for i in range(len(pred_parts)) 497 | ] 498 | ): 499 | matched = False 500 | break 501 | else: 502 | matched = False 503 | if not matched: 504 | break 505 | else: 506 | matched = False 507 | if matched: 508 | return True 509 | 510 | if prediction.count("=") == 1 and reference.count("=") == 1: 511 | pred = prediction.split("=") 512 | pred = f"{pred[0].strip()} - ({pred[1].strip()})" 513 | ref = reference.split("=") 514 | ref = f"{ref[0].strip()} - ({ref[1].strip()})" 515 | if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref): 516 | return True 517 | elif ( 518 | prediction.count("=") == 1 519 | and len(prediction.split("=")[0].strip()) <= 2 520 | and "=" not in reference 521 | ): 522 | if math_equal( 523 | prediction.split("=")[1], reference, include_percentage, is_close 524 | ): 525 | return True 526 | elif ( 527 | reference.count("=") == 1 528 | and len(reference.split("=")[0].strip()) <= 2 529 | and "=" not in prediction 530 | ): 531 | if math_equal( 532 | prediction, reference.split("=")[1], include_percentage, is_close 533 | ): 534 | return True 535 | 536 | if symbolic_equal(prediction, reference): 537 | return True 538 | 539 | return False 540 | 541 | 542 | def numeric_equal(prediction: float, reference: float): 543 | return isclose(reference, prediction, rel_tol=1e-4) 544 | 545 | 546 | def symbolic_equal(a, b): 547 | def _parse(s): 548 | for f in [parse_latex, parse_expr, latex2sympy]: 549 | try: 550 | return f(s.replace("\\\\", "\\")) 551 | except: 552 | try: 553 | return f(s) 554 | except: 555 | pass 556 | return s 557 | 558 | a = _parse(a) 559 | b = _parse(b) 560 | 561 | # direct equal 562 | try: 563 | if str(a) == str(b) or a == b: 564 | return True 565 | except: 566 | pass 567 | 568 | # simplify equal 569 | try: 570 | if a.equals(b) or simplify(a - b) == 0: 571 | return True 572 | except: 573 
| pass 574 | 575 | # equation equal 576 | try: 577 | if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)): 578 | return True 579 | except: 580 | pass 581 | 582 | try: 583 | if numeric_equal(float(N(a)), float(N(b))): 584 | return True 585 | except: 586 | pass 587 | 588 | # matrix 589 | try: 590 | # if a and b are matrix 591 | if a.shape == b.shape: 592 | _a = a.applyfunc(lambda x: round(x, 3)) 593 | _b = b.applyfunc(lambda x: round(x, 3)) 594 | if _a.equals(_b): 595 | return True 596 | except: 597 | pass 598 | 599 | return False 600 | 601 | def find_box(pred_str: str): 602 | ans = pred_str.split("boxed")[-1] 603 | if not ans: 604 | return "" 605 | if ans[0] == "{": 606 | stack = 1 607 | a = "" 608 | for c in ans[1:]: 609 | if c == "{": 610 | stack += 1 611 | a += c 612 | elif c == "}": 613 | stack -= 1 614 | if stack == 0: 615 | break 616 | a += c 617 | else: 618 | a += c 619 | else: 620 | a = ans.split("$")[0].strip() 621 | return a -------------------------------------------------------------------------------- /utils/load_model.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from transformers import AutoTokenizer, AutoModelForCausalLM 3 | from vllm import LLM, SamplingParams 4 | 5 | def count_parameters(model): 6 | return sum(p.numel() for name, p in model.named_parameters() if p.requires_grad) 7 | # return sum(p.numel() for p in model.parameters() if p.requires_grad) 8 | def check_grad_update_status(model): 9 | return None 10 | 11 | def init_model(model_name='', check=True, eval=True, low_cpu_mem_usage=True,): 12 | if eval: 13 | model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", 14 | low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch.bfloat16, 15 | use_flash_attention_2=True) 16 | else: 17 | model = AutoModelForCausalLM.from_pretrained(model_name, use_cache=False, low_cpu_mem_usage=low_cpu_mem_usage, 18 | torch_dtype=torch.bfloat16) 19 | tokenizer = AutoTokenizer.from_pretrained(model_name,) 20 | if check: 21 | check_grad_update_status(model) 22 | return model, tokenizer 23 | 24 | def load_tokenizer(model_name): 25 | tokenizer = AutoTokenizer.from_pretrained(model_name,) 26 | return tokenizer 27 | 28 | def init_vllm_model(model_name, gpus=1): 29 | llm = LLM(model=model_name, tensor_parallel_size=gpus) 30 | return llm 31 | 32 | def llm_generate(model, tokenizer, prompts:list, max_length=256, temperature=0.7, top_p=0.9): 33 | inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=max_length) 34 | inputs = {k: v.to(model.device) for k, v in inputs.items()} 35 | with torch.no_grad(): 36 | outputs = model.generate( 37 | **inputs, 38 | max_length=max_length, 39 | temperature=temperature, 40 | top_p=top_p, 41 | # num_beams=4, 42 | # do_sample=False, 43 | repetition_penalty=1.1, 44 | pad_token_id=tokenizer.pad_token_id, 45 | eos_token_id=tokenizer.eos_token_id 46 | ) 47 | decoded = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs] 48 | for i, prompt in enumerate(prompts): 49 | decoded[i] = decoded[i].replace(prompt, '').strip() 50 | return decoded 51 | 52 | def vllm_generate(llm, prompts:list, max_length=256, temperature=0.7, top_p=0.9): 53 | sampling_params = SamplingParams(temperature=temperature,top_p=top_p, max_tokens=max_length) 54 | outputs = llm.generate(prompts, sampling_params) 55 | return [output.outputs[0].text for output in outputs] 56 | 57 | def add_template(model_name, prompt, tokenizer=None): 58 | if tokenizer is None: 
tokenizer = load_tokenizer(model_name) 59 | if 'QwQ' in model_name: 60 | messages = [ 61 | {"role": "system", "content": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step. Return your final response within \\boxed{{}}"}, 62 | {"role": "user", "content": prompt} 63 | ] 64 | prompt = tokenizer.apply_chat_template( 65 | messages, 66 | tokenize=False, 67 | add_generation_prompt=True 68 | ) 69 | elif 'Sky-T1-32B-Preview' in model_name: 70 | d = {'prompt': "Your role as an assistant involves thoroughly exploring questions through a systematic long \ 71 | thinking process before providing the final precise and accurate solutions. This requires \ 72 | engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, \ 73 | backtracing, and iteration to develop well-considered thinking process. \ 74 | Please structure your response into two main sections: Thought and Solution. \ 75 | In the Thought section, detail your reasoning process using the specified format: \ 76 | <|begin_of_thought|> {thought with steps separated with '\n\n'} \ 77 | <|end_of_thought|> \ 78 | Each step should include detailed considerations such as analisying questions, summarizing \ 79 | relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining \ 80 | any errors, and revisiting previous steps. \ 81 | In the Solution section, based on various attempts, explorations, and reflections from the Thought \ 82 | section, systematically present the final solution that you deem correct. The solution should \ 83 | remain a logical, accurate, concise expression style and detail necessary step needed to reach the \ 84 | conclusion, formatted as follows: \ 85 | <|begin_of_solution|> \ 86 | {final formatted, precise, and clear solution} \ 87 | <|end_of_solution|> \ 88 | Now, try to solve the following question through the above guidelines: "} 89 | prompt = d["prompt"] + prompt 90 | elif 'sky' in model_name: 91 | prompt = 'You are a helpful and harmless assistant. You should solve this math problem using step-by-step reasoning. Require that the output of each step ends with the "\n\n" token. Return your final response with \\[ \\boxed{} \\], if answer is "1/2", the output is \\[ \\boxed{1/2} \\]. '+prompt 92 | else: # elif 'math' in model_name: 93 | # 'follow this structure to step-by-step solve math problem, and final answer must formated with \\[ \\boxed{} \\], if answer is "1/2", the output is \\[ \\boxed{1/2} \\]' 94 | # 'solve this math problem using step-by-step reasoning. Require that the output of each step ends with the "\n\n" token.' 95 | prompt = "You are a helpful and harmless assistant. You should think step-by-step. Return your final response within \\boxed{{}}. "+prompt 96 | return prompt -------------------------------------------------------------------------------- /utils/settings.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | 4 | project_dir = './' 5 | cache_dir = None 6 | hug_token = '' 7 | if cache_dir: 8 | os.environ["HF_HOME"] = cache_dir 9 | 10 | 11 | 12 | SYSTEM_PROMPT = { 13 | "Qwen/Qwen2-7B-Instruct": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.", 14 | "Qwen/QwQ-32B-Preview": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.", 15 | "Qwen/Qwen2.5-72B-Instruct": "You are a helpful and harmless assistant. 
You are Qwen developed by Alibaba. You should think step-by-step.", 16 | "Qwen/Qwen2.5-32B-Instruct": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.", 17 | "Qwen/Qwen2.5-7B-Instruct": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.", 18 | "NovaSky-AI/Sky-T1-32B-Preview": "Your role as an assistant involves thoroughly exploring questions through a systematic long \ 19 | thinking process before providing the final precise and accurate solutions. This requires \ 20 | engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, \ 21 | backtracing, and iteration to develop well-considered thinking process. \ 22 | Please structure your response into two main sections: Thought and Solution. \ 23 | In the Thought section, detail your reasoning process using the specified format: \ 24 | <|begin_of_thought|> {thought with steps separated with '\n\n'} \ 25 | <|end_of_thought|> \ 26 | Each step should include detailed considerations such as analisying questions, summarizing \ 27 | relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining \ 28 | any errors, and revisiting previous steps. \ 29 | In the Solution section, based on various attempts, explorations, and reflections from the Thought \ 30 | section, systematically present the final solution that you deem correct. The solution should \ 31 | remain a logical, accurate, concise expression style and detail necessary step needed to reach the \ 32 | conclusion, formatted as follows: \ 33 | <|begin_of_solution|> \ 34 | {final formatted, precise, and clear solution} \ 35 | <|end_of_solution|> \ 36 | Now, try to solve the following question through the above guidelines:", 37 | "openai/o1-mini": "Question: {input}\nAnswer: ", 38 | "openai/o1-preview": "Question: {input}\nAnswer: ", 39 | "openai/gpt-4o-mini": "User: {input}\nPlease reason step by step, and put your final answer within \\boxed{{}}.\n\nAssistant:", 40 | } -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import random 4 | import torch 5 | from utils.settings import * 6 | 7 | 8 | def set_global(path): 9 | return os.path.join(project_dir, path) 10 | 11 | from huggingface_hub import login 12 | if hug_token: login(token=hug_token) 13 | 14 | api_keys = { 15 | 'openai': 'your openai api key' 16 | } 17 | 18 | 19 | def seed_everything(seed): 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed(seed) 22 | np.random.seed(seed) 23 | random.seed(seed) 24 | torch.backends.cudnn.benchmark = False 25 | torch.backends.cudnn.deterministic = True 26 | torch.cuda.manual_seed_all(seed) 27 | 28 | def Logger(content): 29 | print(content) --------------------------------------------------------------------------------
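For a quick sanity check outside the full Sky-Thought evaluation pipeline, the helpers in `utils/load_model.py` can be combined as follows. This is a minimal sketch, assuming it is run from the repository root with `utils/settings.py` filled in; the model name and sampling settings are only examples.

```python
# Minimal inference sketch built on the repo's own helpers (utils/load_model.py).
# Assumes the repository root is the working directory and utils/settings.py is configured;
# the model name and generation settings below are illustrative.
from utils.load_model import init_vllm_model, vllm_generate, add_template

model_name = "VanWang/Bespoke-Stratos-7B-ThinkPO"
question = "What is the sum of the first 10 positive integers?"

prompt = add_template(model_name, question)   # prepends the step-by-step / \boxed{} system instruction
llm = init_vllm_model(model_name, gpus=1)     # vLLM engine with tensor_parallel_size=1
outputs = vllm_generate(llm, [prompt], max_length=4096, temperature=0.7, top_p=0.9)
print(outputs[0])
```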