├── assets ├── fig1.png ├── fig2.png ├── fig3.png ├── fig4.png ├── thm1.png └── thm2.png ├── src ├── scripts │ ├── compute_mi_trajectories.sh │ ├── run_TTTS.sh │ └── run_RR.sh ├── applications │ ├── data │ │ └── DeepSeek-R1-Distill-Llama-8B.jsonl │ ├── data_loader.py │ ├── evaluate.py │ ├── trajectory.py │ ├── python_executor.py │ ├── utils.py │ ├── RR_model.py │ ├── model_utils.py │ ├── math_utils.py │ ├── grader.py │ ├── TTTS_evaluate.py │ ├── RR_evaluate.py │ └── parser.py ├── mi_estimators.py ├── CKA.py ├── generate_gt_activation.py ├── cal_mi.py ├── generate_activation.py └── plot_mi_peaks.ipynb ├── README.md └── LICENSE /assets/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChnQ/MI-Peaks/HEAD/assets/fig1.png -------------------------------------------------------------------------------- /assets/fig2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChnQ/MI-Peaks/HEAD/assets/fig2.png -------------------------------------------------------------------------------- /assets/fig3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChnQ/MI-Peaks/HEAD/assets/fig3.png -------------------------------------------------------------------------------- /assets/fig4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChnQ/MI-Peaks/HEAD/assets/fig4.png -------------------------------------------------------------------------------- /assets/thm1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChnQ/MI-Peaks/HEAD/assets/thm1.png -------------------------------------------------------------------------------- /assets/thm2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChnQ/MI-Peaks/HEAD/assets/thm2.png -------------------------------------------------------------------------------- /src/scripts/compute_mi_trajectories.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | models=( 5 | deepseek-ai/DeepSeek-R1-Distill-Llama-8B 6 | ) 7 | 8 | for model in "${models[@]}"; do 9 | # 1. generate representations 10 | echo generate reps on $model ... 11 | python generate_activation.py --model_path $model --layers 31 --dataset math_train_12k --sample_num 100 12 | python generate_gt_activation.py --model_path $model --layers 31 --dataset math_train_12k --sample_num 100 13 | 14 | # 2. compute mutual information 15 | echo compute mi on $model ...
16 | python cal_mi.py --gt_model $model --test_model $model --layers 31 --dataset math_train_12k --sample_num 100 & 17 | 18 | done 19 | 20 | -------------------------------------------------------------------------------- /src/scripts/run_TTTS.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd applications/ 4 | 5 | model="your_dir/DeepSeek-R1-Distill-Llama-8B" 6 | 7 | model_name=$(basename "$model") 8 | dataset=aime24 9 | 10 | gpu_id=0 11 | 12 | token_budget=4096 13 | 14 | save_dir="results/${model_name}/${dataset}_budget${token_budget}/" 15 | mkdir -p "$save_dir" 16 | 17 | 18 | CUDA_VISIBLE_DEVICES=$gpu_id python3 -u TTTS_evaluate.py \ 19 | --model_name_or_path $model \ 20 | --data_names $dataset \ 21 | --output_dir $save_dir \ 22 | --split "test" \ 23 | --prompt_type "deepseek-math" \ 24 | --num_test_sample -1 \ 25 | --seed 0 \ 26 | --temperature 0 \ 27 | --n_sampling 1 \ 28 | --top_p 1 \ 29 | --start 0 \ 30 | --end -1 \ 31 | --overwrite \ 32 | --use_vllm \ 33 | --thinking_tokens_file_path data/${model_name}.jsonl \ 34 | --max_tokens_per_call 4096 \ 35 | --token_budget $token_budget 36 | -------------------------------------------------------------------------------- /src/scripts/run_RR.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -x 3 | 4 | cd applications/ 5 | 6 | PROMPT_TYPE="deepseek-math" 7 | SPLIT="test" 8 | NUM_TEST_SAMPLE=-1 9 | LOG_DIR=your_log/$(date +%m-%d_%H-%M) 10 | mkdir -p $LOG_DIR 11 | 12 | 13 | gpu_counter=0 14 | 15 | 16 | model_path='your_dir/DeepSeek-R1-Distill-Llama-8B' 17 | 18 | datasets=( 19 | "aime24" 20 | ) 21 | ei_layers=(23) 22 | 23 | 24 | model_name=$(basename "$model_path") 25 | 26 | for (( j=0; j<${#datasets[@]}; j++ )); do 27 | dataset=${datasets[$j]} 28 | 29 | for (( k=0; k<${#ei_layers[@]}; k++ )); do 30 | ei_layer=${ei_layers[$k]} 31 | save_dir="scores/${model_name}/${ei_layer}/" 32 | responses_dir="responses/${model_name}/${dataset}" 33 | mkdir -p "$save_dir" 34 | mkdir -p "$responses_dir" 35 | 36 | echo "Launched: model=$model_name, dataset=$dataset, layer=$ei_layer, GPU=$gpu_counter" 37 | 38 | log_file="${LOG_DIR}/${model_name}_${dataset}_${ei_layer}.log" 39 | mkdir -p "$(dirname "$log_file")" 40 | 41 | CUDA_VISIBLE_DEVICES=$gpu_counter \ 42 | python RR_evaluate.py \ 43 | --model_name_or_path "$model_path" \ 44 | --data_names "$dataset" \ 45 | --inject_layer_id $ei_layer \ 46 | --extract_layer_id $ei_layer \ 47 | --num_test_sample "$NUM_TEST_SAMPLE" \ 48 | --output_file "$responses_dir/${ei_layer}.jsonl"\ 49 | --use_recursive_thinking True \ 50 | --num_recursive_steps 1 \ 51 | --output_dir "$save_dir" \ 52 | --split "$SPLIT" \ 53 | --prompt_type "$PROMPT_TYPE" \ 54 | --seed 0 \ 55 | --temperature 0 \ 56 | --n_sampling 1 \ 57 | --top_p 1 \ 58 | --start 0 \ 59 | --end -1 \ 60 | --save_outputs \ 61 | --overwrite \ 62 | --interested_tokens_file_path data/${model_name}.jsonl \ 63 | --max_tokens_per_call 16000 >> "$log_file" 2>&1 & 64 | 65 | ((gpu_counter++)) 66 | done 67 | done 68 | 69 | # Wait for all background tasks to complete 70 | wait 71 | echo "All tasks completed!" 
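Both launcher scripts above point their `--thinking_tokens_file_path` / `--interested_tokens_file_path` flags at the per-model JSONL shown next, which records the words observed at MI peaks together with their frequencies and tokenizer ids. The snippet below is an illustrative sketch only (the repository's actual loading logic lives in `TTTS_evaluate.py` / `RR_evaluate.py`, which are not reproduced here) of how such a file could be flattened into a set of token ids:

```python
# Illustrative sketch only: flatten the per-model "thinking tokens" JSONL
# (format shown below: {"word": ..., "freq": ..., "token_ids": [...]})
# into a single set of token ids. Not part of the original codebase.
import json

def load_thinking_token_ids(path):
    token_ids = set()
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            record = json.loads(line)
            token_ids.update(record["token_ids"])
    return token_ids

# Example usage (path as referenced by the scripts above):
# ids = load_thinking_token_ids("data/DeepSeek-R1-Distill-Llama-8B.jsonl")
```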
-------------------------------------------------------------------------------- /src/applications/data/DeepSeek-R1-Distill-Llama-8B.jsonl: -------------------------------------------------------------------------------- 1 | {"word": "So", "freq": 394, "token_ids": [2100, 4516, 78012]} 2 | {"word": "Let", "freq": 218, "token_ids": [6914, 10267]} 3 | {"word": "Hmm", "freq": 127, "token_ids": [81122, 89290]} 4 | {"word": " ", "freq": 118, "token_ids": []} 5 | {"word": "\n\n", "freq": 114, "token_ids": [271, 4815, 15152, 16176, 19124, 24356, 35033, 35249, 40965, 62098, 66367, 66768, 72348, 75625, 81923, 126595]} 6 | {"word": "I", "freq": 112, "token_ids": [40, 358, 25494]} 7 | {"word": "The", "freq": 89, "token_ids": [578, 791, 33026]} 8 | {"word": "Okay", "freq": 88, "token_ids": [33413, 36539]} 9 | {"word": "That", "freq": 81, "token_ids": [3011, 4897]} 10 | {"word": "First", "freq": 79, "token_ids": [5451, 5629]} 11 | {"word": "Now", "freq": 74, "token_ids": [4800, 7184]} 12 | {"word": "\\", "freq": 70, "token_ids": [59, 1144]} 13 | {"word": "Wait", "freq": 68, "token_ids": [14144, 14524]} 14 | {"word": "But", "freq": 64, "token_ids": [2030, 4071]} 15 | {"word": "-", "freq": 47, "token_ids": [12, 482]} 16 | {"word": "Then", "freq": 45, "token_ids": [5112, 12487]} 17 | {"word": "Since", "freq": 45, "token_ids": [8876, 12834]} 18 | {"word": "\\(", "freq": 44, "token_ids": [18240, 45392]} 19 | {"word": "Therefore", "freq": 35, "token_ids": [15636, 55915]} 20 | {"word": "=", "freq": 32, "token_ids": [28, 284]} 21 | {"word": "1", "freq": 32, "token_ids": [16]} 22 | {"word": "Maybe", "freq": 29, "token_ids": [10926, 22105]} 23 | {"word": "x", "freq": 28, "token_ids": [87, 865, 10436]} 24 | {"word": "\n", "freq": 27, "token_ids": [198, 319, 720, 1084, 1602, 1734, 1827, 2355, 2451, 2591, 3456, 4574, 4660, 5996, 6053, 6336, 6494, 6557, 7071, 7786, 9175, 10636, 10912, 11187, 12064, 12586, 12858, 16052, 16244, 16462, 16554, 17707, 17934, 18737, 19548, 20254, 20959, 25332, 26510, 27381, 27644, 28465, 28768, 29347, 31745, 33645, 34741, 35583, 37677, 38304, 38792, 39185, 39420, 39912, 40748, 41437, 42736, 46675, 46907, 47526, 52050, 52224, 52580, 53820, 56547, 57696, 59659, 63317, 65883, 66417, 67934, 68764, 69862, 70977, 72764, 72879, 73453, 74296, 76328, 79093, 79455, 80100, 85952, 86522, 87870, 89966, 90001, 90260, 91406, 91458, 92305, 95791, 96047, 97117, 98414, 99307, 99351, 106050, 120582, 122019]} 25 | {"word": "If", "freq": 25, "token_ids": [1442, 2746, 52792]} 26 | {"word": "To", "freq": 24, "token_ids": [1271, 2057]} 27 | {"word": "**", "freq": 24, "token_ids": [334, 3146]} 28 | {"word": "2", "freq": 20, "token_ids": [17]} 29 | {"word": "(", "freq": 20, "token_ids": [7, 320]} 30 | {"word": "5", "freq": 17, "token_ids": [20]} 31 | -------------------------------------------------------------------------------- /src/mi_estimators.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def distmat(X): 8 | """ distance matrix 9 | """ 10 | if len(X.shape) == 1: 11 | X = X.view(-1, 1) 12 | r = torch.sum(X * X, 1) 13 | r = r.view([-1, 1]) 14 | a = torch.mm(X, torch.transpose(X, 0, 1)) 15 | D = r.expand_as(a) - 2 * a + torch.transpose(r, 0, 1).expand_as(a) 16 | D = torch.abs(D) 17 | return D 18 | 19 | 20 | def sigma_estimation(X, Y): 21 | """ sigma from median distance 22 | """ 23 | D = distmat(torch.cat([X, Y])) 24 | D = D.detach().cpu().numpy() 25 | Itri = np.tril_indices(D.shape[0], -1) 26 
| Tri = D[Itri] 27 | med = np.median(Tri) 28 | if med <= 0: 29 | med = np.mean(Tri) 30 | if med < 1E-2: 31 | med = 1E-2 32 | return med 33 | 34 | 35 | def kernelmat(X, sigma, ktype='gaussian'): 36 | """ kernel matrix baker 37 | """ 38 | if len(X.shape) == 1: 39 | X = X.view(-1, 1) 40 | 41 | m = int(X.size()[0]) 42 | H = torch.eye(m) - (1. / m) * torch.ones([m, m]) 43 | 44 | if ktype == "gaussian": 45 | Dxx = distmat(X) 46 | 47 | if sigma: 48 | variance = 2. * sigma * sigma * X.size()[1] 49 | Kx = torch.exp(-Dxx / variance).type(torch.FloatTensor) # kernel matrices 50 | # print(sigma, torch.mean(Kx), torch.max(Kx), torch.min(Kx)) 51 | else: 52 | try: 53 | sx = sigma_estimation(X, X) 54 | Kx = torch.exp(-Dxx / (2. * sx * sx)).type(torch.FloatTensor) 55 | except RuntimeError as e: 56 | raise RuntimeError("Unstable sigma {} with maximum/minimum input ({},{})".format( 57 | sx, torch.max(X), torch.min(X))) 58 | 59 | 60 | elif ktype == "linear": 61 | Kx = torch.mm(X, X.T).type(torch.FloatTensor) 62 | 63 | elif ktype == 'IMQ': 64 | Dxx = distmat(X) 65 | Kx = 1 * torch.rsqrt(Dxx + 1) 66 | 67 | Kxc = torch.mm(Kx, H) 68 | 69 | return Kxc 70 | 71 | 72 | def hsic_normalized_cca(x, y, sigma=50., ktype='gaussian'): 73 | if len(x.shape) == 1: 74 | x = x.reshape(-1, 1) 75 | if len(y.shape) == 1: 76 | y = y.reshape(-1, 1) 77 | # x = torch.from_numpy(x) 78 | # y = torch.from_numpy(y) 79 | 80 | m = int(x.size()[0]) 81 | Kxc = kernelmat(x, sigma=sigma, ktype=ktype) 82 | Kyc = kernelmat(y, sigma=sigma, ktype=ktype) 83 | 84 | epsilon = 1E-5 85 | K_I = torch.eye(m) 86 | Kxc_i = torch.inverse(Kxc + epsilon * m * K_I) 87 | Kyc_i = torch.inverse(Kyc + epsilon * m * K_I) 88 | Rx = (Kxc.mm(Kxc_i)) 89 | Ry = (Kyc.mm(Kyc_i)) 90 | Pxy = torch.sum(torch.mul(Rx, Ry.t())) 91 | 92 | return Pxy 93 | 94 | 95 | def estimate_mi_hsic(x, y, ktype='gaussian', sigma=50.): 96 | estimate_IXY = hsic_normalized_cca(x, y, ktype=ktype, sigma=sigma) 97 | return estimate_IXY -------------------------------------------------------------------------------- /src/CKA.py: -------------------------------------------------------------------------------- 1 | # python ablation_cka.py --dataset pku-rlhf-10k --base_model llama-2-7b-chat 2 | # python ablation_cka.py --dataset bigcodebench_complete_prompts --base_model codellama-7b 3 | # python ablation_cka.py --dataset mmlu_all_test_questions --base_model llama-2-7b-chat 4 | 5 | import os 6 | import argparse 7 | # from model import Projection, MLP 8 | # from generate_head_activations import load_acts 9 | import numpy as np 10 | import math 11 | import pickle 12 | import torch 13 | import torch.optim as optim 14 | 15 | 16 | class CKA(object): 17 | def __init__(self): 18 | pass 19 | 20 | def centering(self, K): 21 | n = K.shape[0] 22 | unit = np.ones([n, n]) 23 | I = np.eye(n) 24 | H = I - unit / n 25 | return np.dot(np.dot(H, K), H) 26 | 27 | def rbf(self, X, sigma=None): 28 | GX = np.dot(X, X.T) 29 | KX = np.diag(GX) - GX + (np.diag(GX) - GX).T 30 | if sigma is None: 31 | mdist = np.median(KX[KX != 0]) 32 | sigma = math.sqrt(mdist) 33 | KX *= - 0.5 / (sigma * sigma) 34 | KX = np.exp(KX) 35 | return KX 36 | 37 | def kernel_HSIC(self, X, Y, sigma): 38 | return np.sum(self.centering(self.rbf(X, sigma)) * self.centering(self.rbf(Y, sigma))) 39 | 40 | def linear_HSIC(self, X, Y): 41 | L_X = X @ X.T 42 | L_Y = Y @ Y.T 43 | return np.sum(self.centering(L_X) * self.centering(L_Y)) 44 | 45 | def linear_CKA(self, X, Y): 46 | hsic = self.linear_HSIC(X, Y) 47 | var1 = np.sqrt(self.linear_HSIC(X, X)) 48 | var2 = 
np.sqrt(self.linear_HSIC(Y, Y)) 49 | 50 | return hsic / (var1 * var2) 51 | 52 | def kernel_CKA(self, X, Y, sigma=None): 53 | hsic = self.kernel_HSIC(X, Y, sigma) 54 | var1 = np.sqrt(self.kernel_HSIC(X, X, sigma)) 55 | var2 = np.sqrt(self.kernel_HSIC(Y, Y, sigma)) 56 | 57 | return hsic / (var1 * var2) 58 | 59 | 60 | class CudaCKA(object): 61 | def __init__(self, device): 62 | self.device = device 63 | 64 | def centering(self, K): 65 | n = K.shape[0] 66 | unit = torch.ones([n, n], device=self.device) 67 | I = torch.eye(n, device=self.device) 68 | H = I - unit / n 69 | return torch.matmul(torch.matmul(H, K), H) 70 | 71 | def rbf(self, X, sigma=None): 72 | GX = torch.matmul(X, X.T) 73 | KX = torch.diag(GX) - GX + (torch.diag(GX) - GX).T 74 | if sigma is None: 75 | mdist = torch.median(KX[KX != 0]) 76 | sigma = math.sqrt(mdist) 77 | KX *= - 0.5 / (sigma * sigma) 78 | KX = torch.exp(KX) 79 | return KX 80 | 81 | def kernel_HSIC(self, X, Y, sigma): 82 | return torch.sum(self.centering(self.rbf(X, sigma)) * self.centering(self.rbf(Y, sigma))) 83 | 84 | def linear_HSIC(self, X, Y): 85 | L_X = torch.matmul(X, X.T) 86 | L_Y = torch.matmul(Y, Y.T) 87 | return torch.sum(self.centering(L_X) * self.centering(L_Y)) 88 | 89 | def linear_CKA(self, X, Y): 90 | hsic = self.linear_HSIC(X, Y) 91 | var1 = torch.sqrt(self.linear_HSIC(X, X)) 92 | var2 = torch.sqrt(self.linear_HSIC(Y, Y)) 93 | 94 | return hsic / (var1 * var2) 95 | 96 | def kernel_CKA(self, X, Y, sigma=None): 97 | hsic = self.kernel_HSIC(X, Y, sigma) 98 | var1 = torch.sqrt(self.kernel_HSIC(X, X, sigma)) 99 | var2 = torch.sqrt(self.kernel_HSIC(Y, Y, sigma)) 100 | return hsic / (var1 * var2) 101 | 102 | 103 | -------------------------------------------------------------------------------- /src/generate_gt_activation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import argparse 4 | import pandas as pd 5 | from transformers import AutoTokenizer, AutoModelForCausalLM 6 | from tqdm import tqdm 7 | 8 | 9 | class Hook: 10 | def __init__(self, token_position=-1): 11 | self.token_position = token_position 12 | self.tokens_embeddings = [] 13 | 14 | def __call__(self, module, module_inputs, module_outputs): 15 | # output: [batch, seq_len, hidden_size] 16 | hidden_states = module_outputs[0] if isinstance(module_outputs, tuple) else module_outputs 17 | emb = hidden_states[0, self.token_position].detach().cpu() 18 | 19 | self.tokens_embeddings.append(emb) 20 | 21 | 22 | def load_model(model_path): 23 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 24 | model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, device_map="auto") 25 | 26 | return tokenizer, model 27 | 28 | 29 | def get_acts(query_list, tokenizer, model, layers, device, token_pos=-1): 30 | """ 31 | Get given layer activations for the statements. 32 | Return dictionary of stacked activations. 
33 | 34 | token_pos: default to fetch the last token's activations 35 | """ 36 | # attach hooks 37 | hooks, handles = [], [] 38 | for layer in layers: 39 | hook = Hook(token_position=token_pos) 40 | handle = model.model.layers[layer].register_forward_hook(hook) 41 | hooks.append(hook), handles.append(handle) 42 | 43 | # get activations 44 | acts = {id: {layer: [] for layer in layers} for id in range(len(query_list))} 45 | 46 | for id, query in tqdm(enumerate(query_list), total=len(query_list), desc="Processing Queries"): 47 | 48 | input_ids = tokenizer.encode(query, return_tensors="pt") 49 | with torch.no_grad(): 50 | model(input_ids) 51 | for layer, hook in zip(layers, hooks): 52 | acts[id][layer] = hook.tokens_embeddings 53 | 54 | for hook in hooks: 55 | hook.tokens_embeddings = [] 56 | 57 | for id, layer_acts in acts.items(): 58 | for layer, emb in layer_acts.items(): 59 | layer_acts[layer] = torch.stack(emb).float() 60 | 61 | # remove hooks 62 | for handle in handles: 63 | handle.remove() 64 | 65 | return acts 66 | 67 | 68 | def main(): 69 | parser = argparse.ArgumentParser(description="Generate gt activations for statements in a dataset") 70 | parser.add_argument("--model_path", default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B") 71 | parser.add_argument("--layers", nargs='+', help="Layers to save embeddings from") 72 | parser.add_argument("--dataset", default='math_train_12k') 73 | parser.add_argument("--sample_num", type=int, default=100) 74 | parser.add_argument("--output_dir", default="acts", help="Directory to save activations to") 75 | args = parser.parse_args() 76 | 77 | dataset = args.dataset 78 | 79 | math_data = pd.read_csv(f'data/{dataset}.csv') 80 | 81 | solution_list = math_data['solution'].tolist()[:args.sample_num] 82 | 83 | tokenizer, model = load_model(args.model_path) 84 | 85 | layers = [int(layer) for layer in args.layers] 86 | if layers == [-1]: 87 | layers = list(range(len(model.model.layers))) 88 | 89 | acts = get_acts(solution_list, tokenizer, model, layers, device, token_pos=-1) 90 | 91 | # save representations 92 | os.makedirs(f'{args.output_dir}/gt', exist_ok=True) 93 | torch.save(acts, f"{args.output_dir}/gt/{dataset}_{args.model_path.split('/')[-1]}.pth") 94 | 95 | 96 | if __name__=='__main__': 97 | main() -------------------------------------------------------------------------------- /src/applications/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import datasets 5 | from datasets import load_dataset, Dataset, concatenate_datasets 6 | from utils import load_jsonl, lower_keys 7 | 8 | 9 | def load_data(data_name, split, data_dir="./data"): 10 | data_file = f"{data_dir}/{data_name}/{split}.jsonl" 11 | if os.path.exists(data_file): 12 | examples = list(load_jsonl(data_file)) 13 | else: 14 | if data_name == "math": 15 | dataset = load_dataset( 16 | "competition_math", 17 | split=split, 18 | name="main", 19 | cache_dir=f"{data_dir}/temp", 20 | ) 21 | elif data_name == "MATH-500": 22 | dataset = load_dataset(data_name, split=split) 23 | elif data_name == "gsm8k": 24 | dataset = load_dataset(data_name, split=split) 25 | elif data_name == "svamp": 26 | # evaluate on training set + test set 27 | dataset = load_dataset("ChilleD/SVAMP", split="train") 28 | dataset = concatenate_datasets( 29 | [dataset, load_dataset("ChilleD/SVAMP", split="test")] 30 | ) 31 | elif data_name == "asdiv": 32 | dataset = load_dataset("EleutherAI/asdiv", split="validation") 33 | dataset = 
dataset.filter( 34 | lambda x: ";" not in x["answer"] 35 | ) # remove multi-answer examples 36 | elif data_name == "mawps": 37 | examples = [] 38 | # four sub-tasks 39 | for data_name in ["singleeq", "singleop", "addsub", "multiarith"]: 40 | sub_examples = list(load_jsonl(f"{data_dir}/mawps/{data_name}.jsonl")) 41 | for example in sub_examples: 42 | example["type"] = data_name 43 | examples.extend(sub_examples) 44 | dataset = Dataset.from_list(examples) 45 | elif data_name == "mmlu_stem": 46 | dataset = load_dataset("hails/mmlu_no_train", "all", split="test") 47 | # only keep stem subjects 48 | stem_subjects = [ 49 | "abstract_algebra", 50 | "astronomy", 51 | "college_biology", 52 | "college_chemistry", 53 | "college_computer_science", 54 | "college_mathematics", 55 | "college_physics", 56 | "computer_security", 57 | "conceptual_physics", 58 | "electrical_engineering", 59 | "elementary_mathematics", 60 | "high_school_biology", 61 | "high_school_chemistry", 62 | "high_school_computer_science", 63 | "high_school_mathematics", 64 | "high_school_physics", 65 | "high_school_statistics", 66 | "machine_learning", 67 | ] 68 | dataset = dataset.rename_column("subject", "type") 69 | dataset = dataset.filter(lambda x: x["type"] in stem_subjects) 70 | elif data_name == "carp_en": 71 | dataset = load_jsonl(f"{data_dir}/carp_en/test.jsonl") 72 | else: 73 | raise NotImplementedError(data_name) 74 | 75 | examples = list(dataset) 76 | examples = [lower_keys(example) for example in examples] 77 | dataset = Dataset.from_list(examples) 78 | os.makedirs(f"{data_dir}/{data_name}", exist_ok=True) 79 | dataset.to_json(data_file) 80 | 81 | # add 'idx' in the first column 82 | if "idx" not in examples[0]: 83 | examples = [{"idx": i, **example} for i, example in enumerate(examples)] 84 | 85 | # dedepulicate & sort 86 | examples = sorted(examples, key=lambda x: x["idx"]) 87 | return examples 88 | -------------------------------------------------------------------------------- /src/cal_mi.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import torch.nn.functional as F 5 | from tqdm import tqdm 6 | from mi_estimators import estimate_mi_hsic 7 | 8 | 9 | def calculate_mi(acts, gt_acts, layers=[], num_samples=-1, save_dir='results/mi/', args=None): 10 | 11 | num_samples = len(acts) if num_samples < 0 else num_samples 12 | num_layers = len(acts[0]['reps']) 13 | 14 | mi_list = [] 15 | 16 | all_mi_matrices_list = [] 17 | 18 | try: 19 | final_mi_dict = torch.load(f'{save_dir}/{args.dataset}_gtmodel={args.gt_model}_testmodel={args.test_model}.pth') 20 | except: 21 | final_mi_dict = { 22 | id: { 23 | 'reps': {layer: [] for layer in range(num_layers)}, 24 | 'total_tokens': -1 25 | } 26 | for id in range(num_samples) 27 | } 28 | 29 | for id in tqdm(range(0, num_samples)): # for each query 30 | 31 | if final_mi_dict[id]['total_tokens'] > 0: 32 | print(f'id {id} has been computed, skip.') 33 | continue 34 | 35 | final_mi_dict[id]['total_tokens'] = acts[id]['token_ids'].shape[0] 36 | 37 | 38 | if len(layers) == 0: 39 | layers = [layer for layer in acts[id]['reps'].keys()] 40 | 41 | layer_mi_matrix = [] 42 | for layer in layers: 43 | here_num_tokens = acts[id]['reps'][layer].shape[0] 44 | layer_mi_list = torch.zeros(here_num_tokens) 45 | 46 | for i in range(here_num_tokens): 47 | layer_mi_list[i] = estimate_mi_hsic(acts[id]['reps'][layer][i], gt_acts[id][layer][0]) 48 | print(f'loop {i} finished!') 49 | 50 | 
layer_mi_matrix.append(layer_mi_list) 51 | final_mi_dict[id]['reps'][layer] = layer_mi_list 52 | 53 | 54 | print(f'[id={id}, layer={layer}] layer_mi_list:', layer_mi_list) 55 | print('total_tokens:', acts[id]['token_ids'].shape[0]) 56 | 57 | all_mi_matrices_list.append(layer_mi_matrix) 58 | 59 | print(f'id {id} mi_matrix computation finished!') 60 | 61 | # save 62 | os.makedirs(save_dir, exist_ok=True) 63 | torch.save(final_mi_dict, f'{save_dir}/{args.dataset}_gtmodel={args.gt_model}_testmodel={args.test_model}.pth') 64 | 65 | 66 | return final_mi_dict 67 | 68 | 69 | 70 | def load_reps(dataset_name, model_tag, is_gt=False, step_level=False): 71 | print(f'Loading activations of model [{model_tag}], on dataset [{dataset_name}]...') 72 | 73 | if is_gt: 74 | return torch.load(f"acts/gt/{dataset_name}_{model_tag}.pth") 75 | else: 76 | return torch.load(f"acts/reasoning_evolve/{dataset_name}_{model_tag}.pth") 77 | 78 | 79 | def main(): 80 | parser = argparse.ArgumentParser(description="Generate activations for statements in a dataset") 81 | 82 | parser.add_argument("--gt_model", default='deepseek-ai/DeepSeek-R1-Distill-Llama-8B') 83 | parser.add_argument("--test_model", default='deepseek-ai/DeepSeek-R1-Distill-Llama-8B') 84 | parser.add_argument("--dataset", default='math_train_12k') 85 | parser.add_argument("--layers", nargs='*', type=int, default=[]) 86 | 87 | parser.add_argument("--sample_num", type=int, default=100) 88 | 89 | 90 | args = parser.parse_args() 91 | 92 | sample_num = args.sample_num 93 | 94 | 95 | dataset_name = args.dataset 96 | 97 | gt_model = args.gt_model.split('/')[-1] 98 | test_model = args.test_model.split('/')[-1] 99 | layers = args.layers 100 | 101 | acts = load_reps(dataset_name=dataset_name, model_tag=test_model) 102 | gt_acts = load_reps(dataset_name=dataset_name, model_tag=gt_model, is_gt=True) 103 | 104 | 105 | save_dir = f'results/mi' 106 | os.makedirs(save_dir, exist_ok=True) 107 | 108 | final_mi_dict = calculate_mi(acts=acts, gt_acts=gt_acts, layers=layers, num_samples=args.sample_num, save_dir=save_dir, args=args) 109 | 110 | 111 | torch.save(final_mi_dict, f'{save_dir}/{dataset_name}_gtmodel={gt_model}_testmodel={test_model}.pth') 112 | 113 | 114 | if __name__=='__main__': 115 | main() 116 | 117 | -------------------------------------------------------------------------------- /src/generate_activation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import pandas as pd 4 | from transformers import AutoTokenizer, AutoModelForCausalLM 5 | from tqdm import tqdm 6 | 7 | 8 | class Hook: 9 | def __init__(self, token_position=-1): 10 | self.token_position = token_position 11 | self.tokens_embeddings = [] 12 | 13 | def __call__(self, module, module_inputs, module_outputs): 14 | # output: [batch, seq_len, hidden_size] 15 | hidden_states = module_outputs[0] if isinstance(module_outputs, tuple) else module_outputs 16 | emb = hidden_states[0, self.token_position].detach().cpu() 17 | 18 | self.tokens_embeddings.append(emb) 19 | 20 | 21 | def load_model(model_path): 22 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 23 | model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, device_map="auto") 24 | 25 | return tokenizer, model 26 | 27 | 28 | def get_acts(query_list, tokenizer, model, layers, token_pos=-1): 29 | """ 30 | Get given layer activations for the statements. 31 | Return dictionary of stacked activations. 
32 | 33 | token_pos: default to fetch the last token's activations 34 | """ 35 | # attach hooks 36 | hooks, handles = [], [] 37 | for layer in layers: 38 | hook = Hook(token_position=token_pos) 39 | handle = model.model.layers[layer].register_forward_hook(hook) 40 | hooks.append(hook), handles.append(handle) 41 | 42 | # get activations 43 | acts = { 44 | id: { 45 | 'reps': {layer: [] for layer in layers}, 46 | 'token_ids': [] 47 | } 48 | for id in range(len(query_list)) 49 | } 50 | 51 | 52 | for id, query in tqdm(enumerate(query_list), total=len(query_list), desc="Processing Queries"): 53 | 54 | input_ids = tokenizer.encode(query, return_tensors="pt") 55 | with torch.no_grad(): 56 | outputs = model.generate( 57 | input_ids, 58 | max_new_tokens=512, 59 | do_sample=False, # greedy 60 | return_dict_in_generate=True, 61 | output_hidden_states=True 62 | ) 63 | 64 | response = tokenizer.batch_decode(outputs[0][:, input_ids.shape[1]:-1])[0].strip() 65 | 66 | for layer, hook in zip(layers, hooks): 67 | acts[id]['reps'][layer] = hook.tokens_embeddings 68 | 69 | acts[id]['token_ids'] = outputs[0][:, input_ids.shape[1]:-1].squeeze().cpu() 70 | 71 | for hook in hooks: 72 | hook.tokens_embeddings = [] 73 | 74 | for id, layer_acts in acts.items(): 75 | for layer, emb in layer_acts['reps'].items(): 76 | layer_acts['reps'][layer] = torch.stack(emb).float() 77 | 78 | 79 | # remove hooks 80 | for handle in handles: 81 | handle.remove() 82 | 83 | return acts 84 | 85 | 86 | def main(): 87 | parser = argparse.ArgumentParser(description="Generate activations for statements in a dataset") 88 | parser.add_argument("--model_path", default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B") 89 | parser.add_argument("--layers", nargs='+', help="Layers to save embeddings from") 90 | parser.add_argument("--dataset", default='math_train_12k') 91 | parser.add_argument("--sample_num", type=int, default=100) 92 | parser.add_argument("--output_dir", default="acts", help="Directory to save activations to") 93 | args = parser.parse_args() 94 | 95 | dataset = args.dataset 96 | 97 | math_data = pd.read_csv(f'data/{dataset}.csv') 98 | 99 | query_list = math_data['problem'].tolist()[:args.sample_num] 100 | 101 | tokenizer, model = load_model(args.model_path) 102 | 103 | layers = [int(layer) for layer in args.layers] 104 | if layers == [-1]: 105 | layers = list(range(len(model.model.layers))) 106 | 107 | acts = get_acts(query_list, tokenizer, model, layers, token_pos=-1) 108 | 109 | # save representations 110 | os.makedirs(f'{args.output_dir}/reasoning_evolve/', exist_ok=True) 111 | torch.save(acts, f"{args.output_dir}/reasoning_evolve/{dataset}_{args.model_path.split('/')[-1]}.pth") 112 | 113 | 114 | if __name__=='__main__': 115 | main() 116 | 117 | -------------------------------------------------------------------------------- /src/applications/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | from tqdm import tqdm 4 | from pebble import ProcessPool 5 | from concurrent.futures import TimeoutError 6 | 7 | from grader import * 8 | 9 | from parser import * 10 | from utils import load_jsonl 11 | from python_executor import PythonExecutor 12 | 13 | 14 | def evaluate(data_name, prompt_type, samples: list=None, file_path: str=None, max_num_samples=None, execute=False): 15 | assert samples or file_path, "samples or file_path must be provided" 16 | if not samples: 17 | samples = list(load_jsonl(file_path)) 18 | if 'idx' in samples[0]: 19 | samples = 
{sample['idx']: sample for sample in samples}.values() 20 | samples = sorted(samples, key=lambda x: x['idx']) 21 | else: 22 | samples = [dict(idx=idx, **sample) for idx, sample in enumerate(samples)] 23 | 24 | if max_num_samples: 25 | print(f"max_num_samples: {max_num_samples} / {len(samples)}") 26 | samples = samples[:max_num_samples] 27 | 28 | # parse gt 29 | for sample in samples: 30 | sample['gt_cot'], sample['gt'] = parse_ground_truth(sample, data_name) 31 | params = [(idx, pred, sample['gt']) for idx, sample in enumerate(samples) for pred in sample['pred']] 32 | 33 | scores = [] 34 | timeout_cnt = 0 35 | 36 | with ProcessPool(max_workers=1) as pool: 37 | future = pool.map(math_equal_process, params, timeout=3) 38 | iterator = future.result() 39 | with tqdm(total=len(samples), desc="Evaluate") as progress_bar: 40 | while True: 41 | try: 42 | result = next(iterator) 43 | scores.append(result) 44 | except StopIteration: 45 | break 46 | except TimeoutError as error: 47 | print(error) 48 | scores.append(False) 49 | timeout_cnt += 1 50 | except Exception as error: 51 | print(error.traceback) 52 | exit() 53 | progress_bar.update(1) 54 | 55 | idx = 0 56 | score_mat = [] 57 | for sample in samples: 58 | sample['score'] = scores[idx: idx+len(sample['pred'])] 59 | assert len(sample['score']) == len(sample['pred']) 60 | score_mat.append(sample['score']) 61 | idx += len(sample['pred']) 62 | 63 | max_len = max([len(s) for s in score_mat]) 64 | 65 | for i, s in enumerate(score_mat): 66 | if len(s) < max_len: 67 | score_mat[i] = s + [s[-1]] * (max_len - len(s)) # pad 68 | 69 | # output mean of each column of scores 70 | col_means= np.array(score_mat).mean(axis=0) 71 | mean_score = list(np.round(col_means * 100, decimals=1)) 72 | 73 | result_json = { 74 | "num_samples": len(samples), 75 | "num_scores": len(scores), 76 | "timeout_samples": timeout_cnt, 77 | "empty_samples": len([s for s in samples if not s['pred'][-1]]), 78 | "acc": mean_score[0] 79 | } 80 | 81 | # each type score 82 | if "type" in samples[0]: 83 | type_scores = {} 84 | for sample in samples: 85 | if sample['type'] not in type_scores: 86 | type_scores[sample['type']] = [] 87 | type_scores[sample['type']].append(sample['score'][-1]) 88 | type_scores = {k: np.round(np.array(v).mean() * 100, decimals=1) for k, v in type_scores.items()} 89 | type_scores = {k: v for k, v in sorted(type_scores.items(), key=lambda item: item[0])} 90 | result_json['type_acc'] = type_scores 91 | 92 | print(result_json) 93 | return samples, result_json 94 | 95 | 96 | def parse_args(): 97 | parser = argparse.ArgumentParser() 98 | parser.add_argument("--data_name", type=str, default="math") 99 | parser.add_argument("--prompt_type", type=str, default="tool-integrated") 100 | parser.add_argument("--file_path", type=str, default=None, required=True) 101 | parser.add_argument("--max_num_samples", type=int, default=None) 102 | parser.add_argument("--execute", action="store_true") 103 | args = parser.parse_args() 104 | return args 105 | 106 | if __name__ == "__main__": 107 | args = parse_args() 108 | evaluate(data_name=args.data_name, prompt_type=args.prompt_type, file_path=args.file_path, 109 | max_num_samples=args.max_num_samples, execute=args.execute) 110 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 |
4 |

Demystifying Reasoning Dynamics with Mutual Information: Thinking Tokens are Information Peaks in LLM Reasoning

5 |
6 | 📢 If you are interested in our work, please star ⭐ our project. 7 | 8 |

9 | 10 | License 11 | 12 |

13 |
14 | 15 | 16 | ## 🌈 Introduction 17 | 18 | 19 | ![Overview Diagram](assets/fig1.png) 20 | 21 | Large reasoning models (LRMs) have demonstrated impressive capabilities in 22 | complex problem-solving, yet their internal reasoning mechanisms remain poorly 23 | understood. In this paper, we investigate the reasoning trajectories of LRMs from 24 | an information-theoretic perspective. By tracking how mutual information (MI) 25 | between intermediate representations and the correct answer evolves during LRM 26 | reasoning, we observe an interesting *MI peaks* phenomenon: **the MI at specific 27 | generative steps exhibits a sudden and significant increase during LRM’s 28 | reasoning process**. We theoretically analyze this phenomenon and show that as 29 | MI increases, the probability of the model’s prediction error decreases. Furthermore, 30 | **these MI peaks often correspond to tokens expressing reflection or transition, such as “Hmm”, “Wait” and “Therefore,”** which we term the thinking 31 | tokens. We then demonstrate that these thinking tokens are crucial for LRM’s 32 | reasoning performance, while other tokens have minimal impact. Building on 33 | these analyses, we propose two simple yet effective methods to improve LRM’s 34 | reasoning performance by carefully leveraging these thinking tokens. Overall, our work provides novel insights into the reasoning mechanisms of LRMs and offers 35 | practical ways to improve their reasoning capabilities. 36 | 37 | 38 | ## 🚩Main Analyses 39 | 40 | 43 | 44 | **Certain steps exhibit sudden and significant increases in MI during the reasoning process of LRMs, and these MI peaks are sparse and non-uniformly distributed.** 45 | 46 | Table 1 47 | 48 | 49 | 50 | --- 51 | 52 | 55 | 56 | **Theoretical Insights: Higher MI Leads to Tighter Bounds on Prediction Error.** 57 | 58 | Theorem 1 59 | 60 | Theorem 2 61 | 62 | 63 | 64 | **Non-reasoning LLMs exhibit weaker and less pronounced MI peaks than LRMs, and the overall MI in non-reasoning LLMs during the reasoning process is lower than in their 65 | corresponding LRMs.** 66 | 67 | Figure 3 68 | 69 | 70 | 71 | 72 | **The tokens that appear at MI peaks are mostly connective words that express self-reflection or 73 | transitions in LRM’s reasoning process.** 74 | 75 | Figure 3 76 | 77 | 78 | ## 🚀Quick Start 79 | 80 | 81 | ### 🔧Requirements 82 | 83 | The following packages are required to run the code: 84 | 85 | - python==3.11.5 86 | 87 | - pytorch==2.1.2 88 | 89 | - transformers==4.46.1 90 | 91 | - numpy==1.26.4 92 | 93 | ### 🌟Usage 94 | 95 | ```bash 96 | cd src/ 97 | ``` 98 | 99 | **1. Collect the representations and compute the MI** 100 | 101 | ```bash 102 | sh scripts/compute_mi_trajectories.sh 103 | ``` 104 | 105 | **2. Plot figures to observe the MI Peaks phenomenon** (a minimal loading sketch is also shown below, after the Acknowledgements) 106 | 107 | ```bash 108 | jupyter notebook plot_mi_peaks.ipynb 109 | ``` 110 | 111 | 112 | **3. Run the Representation Recycling (RR)** 113 | 114 | ```bash 115 | sh scripts/run_RR.sh 116 | ``` 117 | 118 | ## 📝License 119 | Distributed under the Apache-2.0 License. See LICENSE for more information. 120 | 121 | 122 | ## Acknowledgements 123 | 124 | Some code in this project is adapted from resources provided by the following repositories: 125 | 126 | - https://github.com/QwenLM/Qwen2.5-Math 127 | - https://github.com/ChnQ/TracingLLM 128 | 129 | We greatly appreciate the contributions of the original authors.
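As a quick check after running the pipeline above, the saved MI trajectories can be inspected directly. The following is a minimal sketch (not part of the original codebase; it assumes the default output paths written by `compute_mi_trajectories.sh` and mirrors what `plot_mi_peaks.ipynb` does): load the per-sample MI values for one layer, take the top-k positions as MI peaks, and decode the corresponding tokens.

```python
# Minimal sketch: locate MI peaks for one sample and decode the tokens at those positions.
# Assumes the default save paths used by cal_mi.py / generate_activation.py.
import torch
from transformers import AutoTokenizer

model_path = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
model_name = model_path.split("/")[-1]
dataset = "math_train_12k"
layer, top_k, sample_id = 31, 20, 0

tokenizer = AutoTokenizer.from_pretrained(model_path)
mi = torch.load(f"results/mi/{dataset}_gtmodel={model_name}_testmodel={model_name}.pth")
acts = torch.load(f"acts/reasoning_evolve/{dataset}_{model_name}.pth")

mi_values = mi[sample_id]["reps"][layer]          # one MI value per generated token
token_ids = acts[sample_id]["token_ids"].tolist()

peaks = sorted(range(len(mi_values)), key=lambda i: float(mi_values[i]), reverse=True)[:top_k]
peaks = [i for i in peaks if i < len(token_ids)]  # guard against off-by-one at the sequence end
peak_tokens = tokenizer.batch_decode([token_ids[i] for i in peaks])
print(list(zip(peaks, peak_tokens)))
```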
130 | 131 | ## 📖BibTeX 132 | ``` 133 | @article{qian2025demystifying, 134 | title={Demystifying Reasoning Dynamics with Mutual Information: Thinking Tokens are Information Peaks in LLM Reasoning}, 135 | author={Qian, Chen and Liu, Dongrui and Wen, Haochen and Bai, Zhen and Liu, Yong and Shao, Jing}, 136 | journal={arXiv preprint arXiv:2506.02867}, 137 | year={2025} 138 | } -------------------------------------------------------------------------------- /src/applications/trajectory.py: -------------------------------------------------------------------------------- 1 | import re 2 | """ 3 | trajcectory: 4 | [ 5 | {"role": "rationale", "content": "..."}, 6 | {"role": "program", "content": "..."}, 7 | {"role": "output", "content": "..."}, 8 | {"role": "rationale", "content": "..."}, 9 | ... 10 | ] 11 | """ 12 | 13 | def text_to_trajectory(traj_str: str) -> None: 14 | """ 15 | """ 16 | # parse the above interleaved string of raionale, program, output, raionale, program, output, ... 17 | # output a list of dict 18 | trajectory = [] 19 | cur_role = "rationale" 20 | cur_content = "" 21 | 22 | # print(traj_str) 23 | for i, line in enumerate(traj_str.split("\n")): 24 | if line == "```python": # program begin 25 | assert cur_role == "rationale" 26 | if cur_content: 27 | trajectory.append({"role": cur_role, "content": cur_content}) 28 | cur_content = "" 29 | cur_role = "program" 30 | elif cur_role == "program" and line == "```": # program end 31 | assert cur_content 32 | trajectory.append({"role": cur_role, "content": cur_content}) 33 | cur_content = "" 34 | cur_role = "output" 35 | elif cur_role == "output" and line.startswith("```output"): # output begin 36 | assert cur_content == "" 37 | elif cur_role == "output" and line == "```": # output end 38 | trajectory.append({"role": cur_role, "content": cur_content}) 39 | cur_content = "" 40 | cur_role = "rationale" 41 | else: # content 42 | cur_content += line 43 | if i < len(traj_str.split("\n")) - 1: 44 | cur_content += "\n" 45 | # the last content 46 | if cur_content: 47 | trajectory.append({"role": cur_role, "content": cur_content}) 48 | return trajectory 49 | 50 | 51 | def trajectory_to_text(trajectory: list) -> str: 52 | text = "" 53 | for item in trajectory: 54 | content = item["content"] 55 | if item["role"] == "program": 56 | content = f"```python\n{content}```\n" 57 | elif item["role"] == "output": 58 | content = f"```output\n{content}```\n" 59 | text += content 60 | return text 61 | 62 | 63 | def is_execution_success(output): 64 | error_key_words = ["error", "exception", "no algorithms", "no algorithms", "cannot", "nan", "..."] 65 | success = all([k not in output.lower() for k in error_key_words]) 66 | return success 67 | 68 | 69 | def extract_program(text:str=None, trajectory:list=None, last_only=False) -> str: 70 | assert text is not None or trajectory is not None, "Either text or trajectory should be provided." 
71 | if trajectory is None: 72 | try: 73 | trajectory = text_to_trajectory(text) 74 | except: 75 | return "raise ValueError('Invalid trajectory')" 76 | 77 | program_list = [] 78 | import_lines = [] 79 | for i, item in enumerate(trajectory): 80 | if item["role"] == "program": 81 | cur_program = item["content"] 82 | if i < len(trajectory) - 1: 83 | assert trajectory[i+1]["role"] == "output" 84 | output = trajectory[i+1]["content"].strip() 85 | if is_execution_success(output): 86 | program_list.append(cur_program) 87 | else: 88 | # extract import lines only 89 | for line in cur_program.split("\n"): 90 | if line.startswith("import") or line.startswith("from"): 91 | import_lines.append(line) 92 | else: 93 | program_list.append(cur_program) 94 | # add import lines to the first program 95 | if len(program_list) == 0: 96 | program_list.append("") 97 | if len(import_lines) > 0: 98 | program_list[0] = "\n".join(import_lines) + "\n" + program_list[0] 99 | for i, program in enumerate(program_list[:-1]): 100 | program_list[i] = "\n".join([line for line in program.split("\n") if not line.strip().startswith("print(")]) 101 | 102 | if last_only: 103 | program = program_list[-1] 104 | else: 105 | program = "\n".join(program_list) 106 | return program 107 | 108 | 109 | def extract_program_output(pred_str, last_only=True): 110 | """ 111 | extract output between ```output\n...\n```, use regex, there might be multiple outputs, each output may have multiple lines 112 | """ 113 | outputs = re.findall(r"```output\n(.*?)\n```", pred_str, re.DOTALL) 114 | if last_only: 115 | return outputs[-1] if len(outputs) > 0 else "" 116 | else: 117 | return outputs 118 | 119 | 120 | def _test_str_to_trajectory(): 121 | traj_text = """We can use the Chinese Remainder Theorem to solve this system of congruences. The Chinese Remainder Theorem states that if $n_1, n_2, \ldots, n_k$ are pairwise coprime (i.e., the greatest common divisor of any two of them is 1), and if $a_1, a_2, \ldots, a_k$ are any integers, then there is a number $x$ that satisfies the system of congruences: 122 | 123 | $x \equiv a_1 \pmod{n_1}$ 124 | $x \equiv a_2 \pmod{n_2}$ 125 | $\vdots$ 126 | $x \equiv a_k \pmod{n_k}$ 127 | 128 | Moreover, any two solutions of this system are congruent modulo the product $n_1n_2 \cdots n_k$. 129 | 130 | Let's write a function to solve this: 131 | 132 | ```python 133 | from sympy import symbols, solve 134 | 135 | def smallest_x(): 136 | x = symbols('x') 137 | # Solve the system of congruences 138 | solution = solve([x - 7, x - 17], x, domain='Z') 139 | # Return the smallest positive solution 140 | return min([sol for sol in solution if sol >= 0]) 141 | 142 | result = smallest_x() 143 | print(result) 144 | ``` 145 | ```output 146 | ValueError: min() arg is an empty sequence 147 | ``` 148 | The error in the last example is due to the incorrect usage of the `solve` function. The `solve` function is used to solve equations, not congruences. To solve a system of congruences, we can use the `solve_congruence` function from the `sympy.ntheory.modular` module. 
149 | 150 | Here is the corrected code: 151 | 152 | ```python 153 | from sympy.ntheory.modular import solve_congruence 154 | 155 | def smallest_x(): 156 | # Solve the system of congruences 157 | solution = solve_congruence(('x', 7, 12), ('x', 17, 11)) 158 | # Return the smallest positive solution 159 | return solution[0][0] 160 | 161 | result = smallest_x() 162 | print(result) 163 | ``` 164 | ```output 165 | good 166 | ``` 167 | ValueError: too many values to unpack (expected 2) 168 | The error in the last example is due to the incorrect usage of the `solve_congruence` function. The `solve_congruence` function returns a tuple of two elements: the solution and the modulus. We need to unpack these two values correctly. 169 | 170 | Here is the corrected code: 171 | 172 | ```python 173 | from sympy.ntheory.modular import solve_congruence 174 | 175 | def smallest_x(): 176 | # Solve the system of congruences 177 | solution, modulus = solve_congruence(('x', 7, 12), ('x', 17, 11)) 178 | # Return the smallest positive solution 179 | return solution 180 | 181 | result = smallest_x() 182 | print(result) 183 | ```""" 184 | 185 | import pprint 186 | trajectory = text_to_trajectory(traj_text) 187 | pprint.pprint(trajectory) 188 | 189 | text = trajectory_to_text(trajectory) 190 | assert text == traj_text 191 | 192 | # print(extract_program(traj_text)) 193 | 194 | 195 | if __name__ == "__main__": 196 | _test_str_to_trajectory() -------------------------------------------------------------------------------- /src/applications/python_executor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import regex 4 | import pickle 5 | import traceback 6 | import copy 7 | import datetime 8 | import dateutil.relativedelta 9 | import multiprocess 10 | from multiprocess import Pool 11 | from typing import Any, Dict, Optional 12 | from pebble import ProcessPool 13 | from tqdm import tqdm 14 | from concurrent.futures import TimeoutError 15 | from functools import partial 16 | from timeout_decorator import timeout 17 | from contextlib import redirect_stdout 18 | 19 | 20 | class GenericRuntime: 21 | GLOBAL_DICT = {} 22 | LOCAL_DICT = None 23 | HEADERS = [] 24 | def __init__(self): 25 | self._global_vars = copy.copy(self.GLOBAL_DICT) 26 | self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None 27 | 28 | for c in self.HEADERS: 29 | self.exec_code(c) 30 | 31 | def exec_code(self, code_piece: str) -> None: 32 | if regex.search(r'(\s|^)?input\(', code_piece): 33 | # regex.search(r'(\s|^)?os.', code_piece): 34 | raise RuntimeError() 35 | exec(code_piece, self._global_vars) 36 | 37 | # TODO: use: https://github.com/shroominic/codebox-api 38 | # @high safe exec in sandbox 39 | # byte_code = compile_restricted( 40 | # code_piece, 41 | # filename='', 42 | # mode='exec' 43 | # ) 44 | # print("global vars:", self._global_vars) 45 | # _print_ = PrintCollector 46 | # exec(byte_code, {'__builtins__': utility_builtins}, None) 47 | 48 | def eval_code(self, expr: str) -> Any: 49 | return eval(expr, self._global_vars) 50 | 51 | def inject(self, var_dict: Dict[str, Any]) -> None: 52 | for k, v in var_dict.items(): 53 | self._global_vars[k] = v 54 | 55 | @property 56 | def answer(self): 57 | return self._global_vars['answer'] 58 | 59 | class DateRuntime(GenericRuntime): 60 | GLOBAL_DICT = { 61 | 'datetime': datetime.datetime, 62 | 'timedelta': dateutil.relativedelta.relativedelta, 63 | 'relativedelta': dateutil.relativedelta.relativedelta 64 | } 65 | 66 | 67 | 
class CustomDict(dict): 68 | def __iter__(self): 69 | return list(super().__iter__()).__iter__() 70 | 71 | class ColorObjectRuntime(GenericRuntime): 72 | GLOBAL_DICT = {'dict': CustomDict} 73 | 74 | 75 | class PythonExecutor: 76 | def __init__( 77 | self, 78 | runtime: Optional[Any] = None, 79 | get_answer_symbol: Optional[str] = None, 80 | get_answer_expr: Optional[str] = None, 81 | get_answer_from_stdout: bool = False, 82 | timeout_length: int = 5, 83 | ) -> None: 84 | self.runtime = runtime if runtime else GenericRuntime() 85 | self.answer_symbol = get_answer_symbol 86 | self.answer_expr = get_answer_expr 87 | self.get_answer_from_stdout = get_answer_from_stdout 88 | self.pool = Pool(multiprocess.cpu_count()) 89 | self.timeout_length = timeout_length 90 | 91 | def process_generation_to_code(self, gens: str): 92 | return [g.strip().split('\n') for g in gens] 93 | 94 | @staticmethod 95 | def execute( 96 | code, 97 | get_answer_from_stdout = None, 98 | runtime = None, 99 | answer_symbol = None, 100 | answer_expr = None, 101 | timeout_length = 10, 102 | auto_mode=False 103 | ): 104 | try: 105 | if auto_mode: 106 | if "print(" in code[-1]: 107 | program_io = io.StringIO() 108 | with redirect_stdout(program_io): 109 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code)) 110 | program_io.seek(0) 111 | result = program_io.read() 112 | else: 113 | print(code) 114 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code[:-1])) 115 | result = timeout(timeout_length)(runtime.eval_code)(code[-1]) 116 | else: 117 | if get_answer_from_stdout: 118 | program_io = io.StringIO() 119 | with redirect_stdout(program_io): 120 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code)) 121 | program_io.seek(0) 122 | result = program_io.read() 123 | elif answer_symbol: 124 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code)) 125 | result = runtime._global_vars[answer_symbol] 126 | elif answer_expr: 127 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code)) 128 | result = timeout(timeout_length)(runtime.eval_code)(answer_expr) 129 | else: 130 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code[:-1])) 131 | result = timeout(timeout_length)(runtime.eval_code)(code[-1]) 132 | report = "Done" 133 | str(result) 134 | pickle.dumps(result) # serialization check 135 | except: 136 | result = '' 137 | report = traceback.format_exc().split('\n')[-2] 138 | return result, report 139 | 140 | def apply(self, code): 141 | return self.batch_apply([code])[0] 142 | 143 | @staticmethod 144 | def truncate(s, max_length=400): 145 | half = max_length // 2 146 | if len(s) > max_length: 147 | s = s[:half] + "..." 
+ s[-half:] 148 | return s 149 | 150 | def batch_apply(self, batch_code): 151 | all_code_snippets = self.process_generation_to_code(batch_code) 152 | 153 | timeout_cnt = 0 154 | all_exec_results = [] 155 | # with ProcessPool(max_workers=min(len(all_code_snippets), os.cpu_count())) as pool: 156 | with ProcessPool(max_workers=min(len(all_code_snippets), 1)) as pool: 157 | executor = partial( 158 | self.execute, 159 | get_answer_from_stdout=self.get_answer_from_stdout, 160 | runtime=self.runtime, 161 | answer_symbol=self.answer_symbol, 162 | answer_expr=self.answer_expr, 163 | timeout_length=self.timeout_length, # this timeout not work 164 | auto_mode=True 165 | ) 166 | future = pool.map(executor, all_code_snippets, timeout=self.timeout_length) 167 | iterator = future.result() 168 | 169 | if len(all_code_snippets) > 100: 170 | progress_bar = tqdm(total=len(all_code_snippets), desc="Execute") 171 | else: 172 | progress_bar = None 173 | 174 | while True: 175 | try: 176 | result = next(iterator) 177 | all_exec_results.append(result) 178 | except StopIteration: 179 | break 180 | except TimeoutError as error: 181 | print(error) 182 | all_exec_results.append(("", "Timeout Error")) 183 | timeout_cnt += 1 184 | except Exception as error: 185 | print(error) 186 | exit() 187 | if progress_bar is not None: 188 | progress_bar.update(1) 189 | 190 | if progress_bar is not None: 191 | progress_bar.close() 192 | 193 | batch_results = [] 194 | for code, (res, report) in zip(all_code_snippets, all_exec_results): 195 | # post processing 196 | res, report = str(res).strip(), str(report).strip() 197 | res, report = self.truncate(res), self.truncate(report) 198 | batch_results.append((res, report)) 199 | return batch_results 200 | 201 | 202 | def _test(): 203 | batch_code = [ 204 | """ 205 | from sympy import Matrix 206 | 207 | def null_space_basis(): 208 | # Define the matrix 209 | A = Matrix([[3, 3, -1, -6], [9, -1, -8, -1], [7, 4, -2, -9]]) 210 | 211 | # Compute the basis for the null space 212 | basis = A.nullspace() 213 | 214 | # Round the elements of the basis vectors to three decimal places 215 | basis_rounded = [v.evalf(3) for v in basis] 216 | 217 | return basis_rounded 218 | 219 | result = null_space_basis() 220 | print(result) 221 | """ 222 | ] 223 | 224 | executor = PythonExecutor(get_answer_from_stdout=True) 225 | predictions = executor.apply(batch_code[0]) 226 | print(predictions) 227 | 228 | 229 | if __name__ == '__main__': 230 | _test() -------------------------------------------------------------------------------- /src/plot_mi_peaks.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### Plot the MI trajectories" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "\n", 17 | "import seaborn as sns\n", 18 | "import torch\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "import seaborn as sns\n", 21 | "import json\n", 22 | "import numpy as np\n", 23 | "from transformers import AutoTokenizer\n", 24 | "\n", 25 | "sns.set_theme(style=\"whitegrid\", context=\"talk\", palette=\"muted\", font_scale=1.0)\n", 26 | "\n", 27 | "\n", 28 | "def load_model_data(model_path, dataset, target_layer=31):\n", 29 | " tokenizer = AutoTokenizer.from_pretrained(model_path)\n", 30 | " model_name = model_path.split('/')[-1]\n", 31 | "\n", 32 | " data_path = 
f'results/mi/{dataset}_gtmodel={model_name}_testmodel={model_name}.pth'\n", 33 | " data = torch.load(data_path)\n", 34 | "\n", 35 | " all_sample_mi_list = []\n", 36 | " all_mi_peak_list = [] \n", 37 | "\n", 38 | " for id in data.keys():\n", 39 | " try:\n", 40 | " this_id_mi_list = data[id]['reps'][target_layer]\n", 41 | " \n", 42 | " all_sample_mi_list.append(this_id_mi_list[:])\n", 43 | "\n", 44 | " top_indices = sorted(range(len(this_id_mi_list)), \n", 45 | " key=lambda i: this_id_mi_list[i], reverse=True)[:20] # approximately take top-20\n", 46 | " all_mi_peak_list.append(top_indices)\n", 47 | "\n", 48 | " except Exception as e:\n", 49 | " print(f'[id:{id}] Error:', e)\n", 50 | "\n", 51 | " return all_sample_mi_list, all_mi_peak_list\n", 52 | "\n", 53 | "\n", 54 | "dataset = 'math_train_12k'\n", 55 | "model_path = 'deepseek-ai/DeepSeek-R1-Distill-Llama-8B'\n", 56 | "model_name = model_path.split('/')[-1]\n", 57 | "\n", 58 | "\n", 59 | "model_data_dict = {}\n", 60 | "\n", 61 | "mi_list, mi_peak_list = load_model_data(model_path, dataset, target_layer=31)\n", 62 | "model_data_dict[model_name] = {\n", 63 | " 'mi': mi_list,\n", 64 | " 'peaks': mi_peak_list,\n", 65 | "}\n", 66 | "\n", 67 | "\n", 68 | "fig, axes = plt.subplots(2, 5, figsize=(30, 10))\n", 69 | "axes = axes.flatten()\n", 70 | "\n", 71 | "for sample_idx in range(10):\n", 72 | " ax = axes[sample_idx]\n", 73 | " mi_values = model_data_dict[model_name]['mi'][sample_idx]\n", 74 | " steps = model_data_dict[model_name]['peaks'][sample_idx]\n", 75 | "\n", 76 | " ax.plot(mi_values, linewidth=2, alpha=0.8)\n", 77 | " ax.scatter(steps, [mi_values[i] for i in steps],\n", 78 | " s=50, edgecolor=\"white\", linewidth=0.8, zorder=3)\n", 79 | "\n", 80 | " ax.grid(axis=\"y\", linestyle=\"--\", alpha=0.4)\n", 81 | " ax.set_facecolor(\"#fafafa\")\n", 82 | " sns.despine(ax=ax, top=True, right=True)\n", 83 | "\n", 84 | " ax.set_title(model_name, fontsize=22, pad=8) \n", 85 | " ax.set_xlabel(\"Reasoning Step\", fontsize=18) \n", 86 | " ax.set_ylabel(\"MI Value\", fontsize=20)\n", 87 | " ax.tick_params(axis=\"x\", labelsize=18) \n", 88 | " ax.tick_params(axis=\"y\", labelsize=18)\n", 89 | "\n", 90 | "\n", 91 | "plt.subplots_adjust(hspace=0.4, wspace=0.13)\n", 92 | "\n", 93 | "plt.show()" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "### Projecting the MI-peak representations to token space " 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "import re\n", 110 | "import torch\n", 111 | "import matplotlib.pyplot as plt\n", 112 | "import seaborn as sns\n", 113 | "import json\n", 114 | "import numpy as np\n", 115 | "from transformers import AutoTokenizer\n", 116 | "from collections import Counter\n", 117 | "from matplotlib.colors import LinearSegmentedColormap\n", 118 | "\n", 119 | "sns.set(style=\"whitegrid\", context=\"talk\", font_scale=1.3)\n", 120 | "\n", 121 | "\n", 122 | "def plot_token_freq(model_path, dataset, target_layer=31):\n", 123 | "\n", 124 | " model_name = model_path.split('/')[-1]\n", 125 | " tokenizer = AutoTokenizer.from_pretrained(model_path)\n", 126 | "\n", 127 | " data_path = f'results/mi/{dataset}_gtmodel={model_name}_testmodel={model_name}.pth'\n", 128 | " data = torch.load(data_path)\n", 129 | "\n", 130 | " acts = torch.load(f'acts/reasoning_evolve/{dataset}_{model_name}.pth')\n", 131 | "\n", 132 | " all_sample_mi_list = [] \n", 133 | " all_mi_peak_list = []\n", 134 | " all_tokens 
= []\n", 135 | "\n", 136 | " fail_id_list = []\n", 137 | " for id in data.keys():\n", 138 | " try:\n", 139 | " this_id_mi_list = data[id]['reps'][target_layer] \n", 140 | " all_sample_mi_list.append(this_id_mi_list[:]) \n", 141 | " top_indices = sorted(range(len(this_id_mi_list)), key=lambda i: this_id_mi_list[i], reverse=True)[:20] # approximately take top-20\n", 142 | "\n", 143 | " all_mi_peak_list.append(top_indices)\n", 144 | "\n", 145 | " token_list = acts[id]['token_ids'].tolist()\n", 146 | " token_list.append(2) # [eos] token\n", 147 | " top_prob_token_ids = [token_list[i] for i in top_indices]\n", 148 | "\n", 149 | " batch_top_n_tokens = tokenizer.batch_decode(top_prob_token_ids, skip_special_tokens=False)\n", 150 | " all_tokens.extend(batch_top_n_tokens)\n", 151 | "\n", 152 | " except Exception as e:\n", 153 | " fail_id_list.append(id)\n", 154 | "\n", 155 | " print('fail_id_list:', fail_id_list)\n", 156 | "\n", 157 | "\n", 158 | " english_pattern = re.compile(r'^[a-zA-Z]+$')\n", 159 | "\n", 160 | " processed_all_tokens = []\n", 161 | " for token in all_tokens:\n", 162 | " if english_pattern.match(token.strip()):\n", 163 | " processed_all_tokens.append(token)\n", 164 | "\n", 165 | " \n", 166 | " token_freq = Counter(processed_all_tokens)\n", 167 | "\n", 168 | " common_tokens = token_freq.most_common(15)\n", 169 | " print('$'*50)\n", 170 | " print(f'model: {model_name}')\n", 171 | " print(\"Most common tokens:\", common_tokens)\n", 172 | "\n", 173 | " colors = [\n", 174 | " (0.0, \"#3d61aa\"), \n", 175 | " (0.5, \"#b1bee9\"), \n", 176 | " (1.0, \"#9673c4\") \n", 177 | " ]\n", 178 | " cmap = LinearSegmentedColormap.from_list(\"blue_purple\", colors)\n", 179 | "\n", 180 | "\n", 181 | " # ------------------------------ plot --------------------------------------\n", 182 | "\n", 183 | " token_names, token_counts = zip(*common_tokens)\n", 184 | " token_names_processed = [token.replace('$', '\\$').replace('_', '\\_').replace('^', '\\^') for token in token_names]\n", 185 | " token_names_repr = [repr(token) for token in token_names_processed]\n", 186 | "\n", 187 | " n_bars = len(token_names_repr)\n", 188 | " palette = [cmap(i / (n_bars - 1)) for i in range(n_bars)]\n", 189 | "\n", 190 | " plt.figure(figsize=(8, 5))\n", 191 | " sns.barplot(x=list(token_names_repr), y=list(token_counts), palette=palette)\n", 192 | " plt.xticks(rotation=45, ha='right', size=17)\n", 193 | " plt.xlabel('Tokens at MI Peaks')\n", 194 | " plt.ylabel('Frequency')\n", 195 | " plt.title(f'{model_name}')\n", 196 | "\n", 197 | "\n", 198 | " plt.show()\n", 199 | " \n", 200 | "\n", 201 | "dataset = 'math_train_12k'\n", 202 | "model_path = 'deepseek-ai/DeepSeek-R1-Distill-Llama-8B'\n", 203 | "\n", 204 | "\n", 205 | "plot_token_freq(\n", 206 | " model_path=model_path,\n", 207 | " dataset=dataset,\n", 208 | " target_layer=31,\n", 209 | ")\n", 210 | "\n" 211 | ] 212 | } 213 | ], 214 | "metadata": { 215 | "kernelspec": { 216 | "display_name": "vllm053", 217 | "language": "python", 218 | "name": "python3" 219 | }, 220 | "language_info": { 221 | "codemirror_mode": { 222 | "name": "ipython", 223 | "version": 3 224 | }, 225 | "file_extension": ".py", 226 | "mimetype": "text/x-python", 227 | "name": "python", 228 | "nbconvert_exporter": "python", 229 | "pygments_lexer": "ipython3", 230 | "version": "3.11.5" 231 | } 232 | }, 233 | "nbformat": 4, 234 | "nbformat_minor": 2 235 | } 236 | -------------------------------------------------------------------------------- /src/applications/utils.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import json 5 | import os 6 | import numpy as np 7 | from pathlib import Path 8 | from typing import Iterable, Union, Any 9 | 10 | from examples import get_examples 11 | 12 | 13 | def set_seed(seed: int = 42) -> None: 14 | np.random.seed(seed) 15 | random.seed(seed) 16 | os.environ["PYTHONHASHSEED"] = str(seed) 17 | print(f"Random seed set as {seed}") 18 | 19 | 20 | def load_jsonl(file: Union[str, Path]) -> Iterable[Any]: 21 | with open(file, "r", encoding="utf-8") as f: 22 | for line in f: 23 | try: 24 | yield json.loads(line) 25 | except: 26 | print("Error in loading:", line) 27 | exit() 28 | 29 | 30 | def save_jsonl(samples, save_path): 31 | # ensure path 32 | folder = os.path.dirname(save_path) 33 | os.makedirs(folder, exist_ok=True) 34 | 35 | with open(save_path, "w", encoding="utf-8") as f: 36 | for sample in samples: 37 | f.write(json.dumps(sample, ensure_ascii=False) + "\n") 38 | print("Saved to", save_path) 39 | 40 | 41 | def lower_keys(example): 42 | new_example = {} 43 | for key, value in example.items(): 44 | if key != key.lower(): 45 | new_key = key.lower() 46 | new_example[new_key] = value 47 | else: 48 | new_example[key] = value 49 | return new_example 50 | 51 | 52 | EXAMPLES = get_examples() 53 | 54 | 55 | def load_prompt(data_name, prompt_type, num_shots): 56 | if not num_shots: 57 | return [] 58 | 59 | if data_name in ["gsm_hard", "svamp", "tabmwp", "asdiv", "mawps"]: 60 | data_name = "gsm8k" 61 | if data_name in ["math_oai", "hungarian_exam", "math-oai", "aime24", "amc23"]: 62 | data_name = "math" 63 | if data_name in ["sat_math"]: 64 | data_name = "mmlu_stem" 65 | if data_name in [ 66 | "gaokao2024_I", 67 | "gaokao2024_II", 68 | "gaokao_math_qa", 69 | "gaokao2024_mix", 70 | "cn_middle_school", 71 | ]: 72 | data_name = "gaokao" 73 | 74 | if prompt_type in ["tool-integrated"]: 75 | prompt_type = "tora" 76 | 77 | return EXAMPLES[data_name][:num_shots] 78 | 79 | 80 | PROMPT_TEMPLATES = { 81 | "direct": ("Question: {input}\nAnswer: ", "{output}", "\n\n"), 82 | "cot": ("Question: {input}\nAnswer: ", "{output}", "\n\n\n"), 83 | "pal": ("Question: {input}\n\n", "{output}", "\n---\n"), 84 | "tool-integrated": ("Question: {input}\n\nSolution:\n", "{output}", "\n---\n"), 85 | "self-instruct": ("<|user|>\n{input}\n<|assistant|>\n", "{output}", "\n"), 86 | "tora": ("<|user|>\n{input}\n<|assistant|>\n", "{output}", "\n"), 87 | "wizard_zs": ( 88 | "### Instruction:\n{input}\n\n### Response: Let's think step by step.", 89 | "{output}", 90 | "\n\n\n", 91 | ), 92 | "platypus_fs": ( 93 | "### Instruction:\n{input}\n\n### Response:\n", 94 | "{output}", 95 | "\n\n\n", 96 | ), 97 | "deepseek-math": ( 98 | "User: {input}\nPlease reason step by step, " 99 | "and put your final answer within \\boxed{{}}.\n\nAssistant:", 100 | "{output}", 101 | "\n\n\n", 102 | ), 103 | "kpmath": ( 104 | "User: Please reason step by step and put your final answer at the end " 105 | 'with "The answer is: ".\n\n{input}\n\nAssistant:', 106 | "{output}", 107 | ), 108 | "jiuzhang": ( 109 | "## Question\n{input}\n\n## Solution\n", 110 | "{output}", 111 | "\n\n\n", 112 | ), 113 | "jiuzhang_tora": ( 114 | "## Question\n{input}\n\n## Code Solution\n", 115 | "{output}", 116 | "\n\n\n", 117 | ), 118 | "jiuzhang_nl": ( 119 | "## Question\n{input}\n\n## Natural Language Solution\n", 120 | "{output}", 121 | "\n\n\n", 122 | ), 123 | "mmiqc": ( 124 | 'Please solve the following problem and put your 
answer at the end with "The answer is: ".\n\n{input}\n\n', 125 | "{output}", 126 | "\n\n\n", 127 | ), 128 | "abel": ( 129 | "Question:\n{input}\nAnswer:\nLet's think step by step.\n", 130 | "{output}", 131 | "\n\n", 132 | ), 133 | "shepherd": ("{input}\n", "{output}", "\n\n\n"), 134 | "qwen-boxed": ( 135 | "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" 136 | "<|im_start|>user\n{input}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n" 137 | "<|im_start|>assistant\n", 138 | "{output}", 139 | "\n\n", 140 | ), 141 | "qwen25-math-cot": ( 142 | "<|im_start|>system\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n" 143 | "<|im_start|>user\n{input}<|im_end|>\n" 144 | "<|im_start|>assistant\n", 145 | "{output}", 146 | "\n\n", 147 | ), 148 | "mathstral": ( 149 | "{input}\nPlease reason step by step, and put your final answer within \\boxed{{}}.", 150 | "{output}", 151 | "\n\n", 152 | ), 153 | "internlm-math-fs": ("Question:{input}\nAnswer:", "{output}", "\n"), 154 | "internlm-math-chat": ( 155 | "<|im_start|>user\n{input}<|im_end|>\n" "<|im_start|>assistant\n", 156 | "{output}", 157 | "\n\n", 158 | ), 159 | "mistral": ( 160 | "[INST] {input}[/INST]", 161 | "{output}", 162 | "\n\n", 163 | ), 164 | "numina": ("### Problem: {input}\n### Solution:", " {output}", "\n\n"), 165 | } 166 | 167 | 168 | def construct_prompt(example, data_name, args): 169 | if args.adapt_few_shot and data_name in [ 170 | "gaokao2024_I", 171 | "gaokao2024_II", 172 | "gaokao_math_qa", 173 | "gaokao2024_mix", 174 | "cn_middle_school", 175 | ]: 176 | demos = load_prompt(data_name, args.prompt_type, 5) 177 | else: 178 | demos = load_prompt(data_name, args.prompt_type, args.num_shots) 179 | prompt_type = args.prompt_type 180 | if prompt_type == "platypus_fs": 181 | prompt_type = "cot" 182 | if prompt_type == "tool-integrated": 183 | prompt_type = "tora" 184 | 185 | prompt_temp = PROMPT_TEMPLATES[args.prompt_type] 186 | 187 | splitter = prompt_temp[2] 188 | input_template, output_template, splitter = ( 189 | prompt_temp[0], 190 | prompt_temp[1], 191 | prompt_temp[2], 192 | ) 193 | if args.prompt_type == "qwen25-math-cot": 194 | # Hotfix to support putting all demos into a single turn 195 | demo_prompt = splitter.join([q + "\n" + a for q, a in demos]) 196 | else: 197 | demo_prompt = splitter.join( 198 | [ 199 | input_template.format(input=q) + output_template.format(output=a) 200 | for q, a in demos 201 | ] 202 | ) 203 | context = input_template.format(input=example["question"]) 204 | if len(demo_prompt) == 0 or ( 205 | args.adapt_few_shot and example["gt_ans"] not in ["A", "B", "C", "D", "E"] 206 | ): 207 | full_prompt = context 208 | else: 209 | if args.prompt_type == "qwen25-math-cot": 210 | # Hotfix to supportting put all demos into a single turn 211 | full_prompt = demo_prompt + splitter + example["question"] 212 | full_prompt = input_template.format(input=full_prompt) 213 | else: 214 | full_prompt = demo_prompt + splitter + context 215 | 216 | if args.prompt_type == "platypus_fs": 217 | full_prompt_temp = ( 218 | "Below is an instruction that describes a task. 
" 219 | "Write a response that appropriately completes the request.\n\n" 220 | "### Instruction:\n{instruction}\n\n### Response:\n" 221 | ) 222 | full_prompt = full_prompt_temp.format(instruction=full_prompt) 223 | 224 | if prompt_type == "tora": 225 | full_prompt = ( 226 | """Integrate step-by-step reasoning and Python code to solve math problems using the following guidelines: 227 | 228 | - Analyze the question and write functions to solve the problem; the function should not take any arguments. 229 | - Present the final result in LaTeX using a `\boxed{}` without any units. 230 | - Utilize the `pi` symbol and `Rational`` from Sympy for $\pi$ and fractions, and simplify all fractions and square roots without converting them to decimal values. 231 | 232 | Here are some examples you may refer to: 233 | 234 | --- 235 | 236 | """ 237 | + full_prompt 238 | ) 239 | 240 | return full_prompt.strip(" ") # important! 241 | 242 | 243 | key_map = { 244 | "gt": "Ground Truth", 245 | "pred": "Prediction", 246 | "gt_cot": "Reference CoT", 247 | "score": "Score", 248 | } 249 | 250 | 251 | def show_sample(sample, print_all_preds=False): 252 | print("==" * 20) 253 | for key in ["idx", "type", "level", "dataset"]: 254 | if key in sample: 255 | # capitalize 256 | print("{}: {}".format(key[0].upper() + key[1:], sample[key])) 257 | print("Question:", repr(sample["question"])) 258 | if "code" in sample: 259 | if print_all_preds: 260 | for code in sample["code"]: 261 | print("-" * 20) 262 | print("code:", code) 263 | print("Execution:", sample["report"]) 264 | else: 265 | print("Solution:\n", sample["code"][0]) 266 | print("Execution:", sample["report"][0]) 267 | if "pred" in sample: 268 | print("Prediction:", repr(sample["pred"][0])) 269 | for key in ["gt", "score", "unit", "gt_cot"]: 270 | if key in sample: 271 | _key = key_map.get(key, key) 272 | print("{}: {}".format(_key, repr(sample[key]))) 273 | print() 274 | -------------------------------------------------------------------------------- /src/applications/RR_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | from transformers import AutoModelForCausalLM, AutoTokenizer 4 | from typing import Optional, List 5 | import gc 6 | import json 7 | import time 8 | import random 9 | 10 | 11 | class RecursiveThinkingModel(torch.nn.Module): 12 | def __init__( 13 | self, 14 | base_model_name: str = None, 15 | extract_layer_id: int = None, 16 | inject_layer_id: int = None, 17 | num_recursive_steps: int = 1, # Number of recursive optimization steps 18 | use_recursive_thinking: bool = True, 19 | output_file: str = None 20 | ): 21 | super().__init__() 22 | 23 | 24 | self.base_model = AutoModelForCausalLM.from_pretrained(base_model_name) 25 | self.tokenizer = AutoTokenizer.from_pretrained(base_model_name) 26 | self.tokenizer.pad_token = self.tokenizer.eos_token 27 | 28 | self.extract_layer_id = extract_layer_id or -1 29 | self.inject_layer_id = inject_layer_id or -1 30 | self.extracted_hidden = None 31 | 32 | self.num_recursive_steps = num_recursive_steps 33 | self.use_recursive_thinking = use_recursive_thinking 34 | 35 | 36 | self._find_layers() 37 | 38 | 39 | self.extract_hook = None 40 | self.inject_hook = None 41 | self.enable_inject = False 42 | self.enable_extract = True 43 | 44 | 45 | self.output_file = output_file 46 | if output_file is None: 47 | raise ValueError("output_file must be provided as a valid path string") 48 | 49 | 50 | def _find_layers(self): 51 | """Identify the 
layer structure of the model""" 52 | if hasattr(self.base_model, 'model') and hasattr(self.base_model.model, 'layers'): 53 | self.layers = self.base_model.model.layers 54 | elif hasattr(self.base_model, 'transformer') and hasattr(self.base_model.transformer, 'h'): 55 | self.layers = self.base_model.transformer.h 56 | elif hasattr(self.base_model, 'encoder') and hasattr(self.base_model.encoder, 'layer'): 57 | self.layers = self.base_model.encoder.layer 58 | else: 59 | raise ValueError(f"Unsupported model architecture: {type(self.base_model)}") 60 | 61 | def _register_hooks(self): 62 | """Register feature passing hooks""" 63 | 64 | # Remove existing hooks (if any) 65 | self._remove_hooks() 66 | 67 | # Feature extraction hook (captures representation at the specified layer) 68 | def extract_hook(module, inputs, outputs): 69 | if self.enable_extract: 70 | self.extracted_hidden = outputs[0].clone() 71 | 72 | # Feature injection hook (injects captured representation into the target layer) 73 | def inject_hook(module, inputs): 74 | if self.enable_inject and self.extracted_hidden is not None: 75 | modified_hidden_states = inputs[0].clone() 76 | modified_hidden_states[:, -1:, :] = self.extracted_hidden[:, -1:, :] 77 | 78 | return (modified_hidden_states,) + inputs[1:] 79 | 80 | return inputs 81 | 82 | 83 | # Register hooks 84 | self.inject_hook = self.layers[self.inject_layer_id].register_forward_pre_hook(inject_hook) 85 | self.extract_hook = self.layers[self.extract_layer_id].register_forward_hook(extract_hook) 86 | 87 | def _remove_hooks(self): 88 | """Remove hooks to prevent memory leaks""" 89 | if hasattr(self, 'extract_hook') and self.extract_hook is not None: 90 | self.extract_hook.remove() 91 | self.extract_hook = None 92 | 93 | if hasattr(self, 'inject_hook') and self.inject_hook is not None: 94 | self.inject_hook.remove() 95 | self.inject_hook = None 96 | 97 | def _clear_memory(self): 98 | """Clean up memory, release unnecessary tensors and cache""" 99 | gc.collect() 100 | if torch.cuda.is_available(): 101 | torch.cuda.empty_cache() 102 | 103 | @torch.inference_mode() 104 | def generate( 105 | self, 106 | max_tokens: int, 107 | prompt: str, 108 | interested_tokens: set = None, 109 | use_recursive_thinking: bool = True, 110 | **kwargs 111 | ): 112 | 113 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 114 | self.to(device).eval() 115 | 116 | # Register hooks only when needed 117 | if use_recursive_thinking: 118 | self._register_hooks() 119 | 120 | try: 121 | # Initialize input 122 | inputs = self.tokenizer(prompt, return_tensors="pt").to(device) 123 | input_ids = inputs.input_ids 124 | attention_mask = inputs.attention_mask 125 | 126 | past_key_values = None 127 | generated = input_ids.clone() 128 | full_input_ids = input_ids.clone() # Save the complete input sequence 129 | print(f"Prompt: {prompt}") 130 | 131 | recursive_tokens_count = 0 132 | regular_tokens_count = 0 133 | lose_tokens_count = 0 134 | 135 | for _ in tqdm(range(max_tokens), desc="Generating"): 136 | # Regular generation to get candidate token 137 | 138 | outputs = self.base_model( 139 | input_ids=input_ids, 140 | attention_mask=attention_mask, 141 | past_key_values=past_key_values, 142 | use_cache=True, 143 | output_hidden_states=False 144 | ) 145 | next_token_logits = outputs.logits[:, -1, :] 146 | candidate_token = next_token_logits.argmax(dim=-1, keepdim=True) 147 | del next_token_logits 148 | 149 | if candidate_token.item() == self.tokenizer.eos_token_id: 150 | break 151 | 152 | # recursive 
decoding 153 | if use_recursive_thinking and candidate_token.item() in interested_tokens and random.random() < 0.5: 154 | recursive_tokens_count += 1 155 | print(f"Recursive token") 156 | 157 | final_token = None 158 | 159 | # ------------------------------ Forward propagation after replacing representation ------------------------------ 160 | self.enable_inject = True 161 | self.enable_extract = False 162 | self.base_model.config.use_cache=True 163 | 164 | recursive_outputs = self.base_model( 165 | input_ids=input_ids, 166 | attention_mask=attention_mask, 167 | past_key_values=past_key_values, 168 | use_cache=True 169 | ) 170 | self.enable_extract = True 171 | # Disable hooks, reset hidden state 172 | self.enable_inject = False 173 | final_logits = recursive_outputs.logits[:, -1, :] 174 | final_token = final_logits.argmax(dim=-1, keepdim=True) 175 | 176 | 177 | # ------------------------------ Update KV cache ------------------------------ 178 | past_key_values = recursive_outputs.past_key_values 179 | # Update sequence 180 | generated = torch.cat([generated,final_token], dim=-1) 181 | full_input_ids = torch.cat([full_input_ids,final_token], dim=-1) 182 | input_ids = final_token 183 | 184 | # Update attention_mask 185 | attention_mask = torch.cat([ 186 | attention_mask, 187 | torch.ones((1, 1), device=device) 188 | ], dim=1) 189 | # regular decoding 190 | else: 191 | regular_tokens_count += 1 192 | past_key_values = outputs.past_key_values 193 | # Update sequence 194 | generated = torch.cat([generated, candidate_token], dim=-1) 195 | full_input_ids = torch.cat([full_input_ids, candidate_token], dim=-1) 196 | input_ids = candidate_token 197 | 198 | # Update attention_mask 199 | attention_mask = torch.cat([ 200 | attention_mask, 201 | torch.ones((1, 1), device=device) 202 | ], dim=1) 203 | 204 | 205 | outputs = None 206 | 207 | if _ % 5000 == 0: 208 | gc.collect() 209 | if torch.cuda.is_available(): 210 | torch.cuda.empty_cache() 211 | 212 | 213 | # Save performance statistics 214 | response_str = self.tokenizer.decode(generated[0], skip_special_tokens=True) 215 | entry = { 216 | "query": prompt, 217 | "response": response_str, 218 | "performance": { 219 | "recursive_tokens": recursive_tokens_count, 220 | "regular_tokens": regular_tokens_count, 221 | "lose_tokens": lose_tokens_count, 222 | 223 | } 224 | } 225 | with open(self.output_file, "a") as f: 226 | f.write(json.dumps(entry, indent=2) + "\n") 227 | 228 | print(f"\nGenerated {recursive_tokens_count} recursive tokens and {regular_tokens_count} regular tokens") 229 | print('-'*50) 230 | 231 | generated = None 232 | full_input_ids = None 233 | past_key_values = None 234 | input_ids = None 235 | attention_mask = None 236 | 237 | 238 | return response_str 239 | 240 | finally: 241 | self._remove_hooks() 242 | self._clear_memory() -------------------------------------------------------------------------------- /src/applications/model_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/allenai/open-instruct 3 | """ 4 | import torch 5 | import tqdm 6 | from transformers import StoppingCriteria, StoppingCriteriaList 7 | 8 | 9 | class KeywordsStoppingCriteria(StoppingCriteria): 10 | def __init__(self, keywords_str, tokenizer): 11 | StoppingCriteria.__init__(self) 12 | self.current_context = [] 13 | self.tokenizer = tokenizer 14 | self.keywords_str = keywords_str 15 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 16 | if 
len(self.current_context) == 0: 17 | self.current_context = [[] for _ in range(input_ids.shape[0])] 18 | 19 | # self.current_context.append(input_ids[0][-1].item()) 20 | sequences_should_be_stopped = [] 21 | for i in range(input_ids.shape[0]): 22 | _id = input_ids[i][-1].item() 23 | self.current_context[i].append(_id) 24 | current_context = self.tokenizer.decode(self.current_context[i]) 25 | should_be_stopped = False 26 | for word in self.keywords_str: 27 | if word in current_context: 28 | should_be_stopped = True 29 | break 30 | sequences_should_be_stopped.append(should_be_stopped) 31 | return all(sequences_should_be_stopped) 32 | 33 | 34 | class KeyWordsCriteriaTrunc(StoppingCriteria): 35 | def __init__(self, stop_id_sequences, prompt_length): 36 | assert isinstance(stop_id_sequences[0], list), "stop_id_sequences should be a list of list of ids" 37 | self.stop_sequences = stop_id_sequences 38 | self.prompt_length = prompt_length 39 | 40 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 41 | sequences_should_be_stopped = [] 42 | for i in range(input_ids.shape[0]): 43 | ids = input_ids[i][self.prompt_length:].tolist() 44 | should_be_stopped = False 45 | for stop_sequence in self.stop_sequences: 46 | if input_ids.shape[0] == 1: 47 | _ids = ids[-len(stop_sequence):] 48 | else: 49 | _ids = ids 50 | for j in range(len(_ids), 0, -len(stop_sequence)): 51 | if _ids[max(j - len(stop_sequence), 0): j] == stop_sequence: 52 | should_be_stopped = True 53 | break 54 | if should_be_stopped: 55 | break 56 | sequences_should_be_stopped.append(should_be_stopped) 57 | return all(sequences_should_be_stopped) 58 | 59 | 60 | class KeyWordsCriteria(StoppingCriteria): 61 | def __init__(self, stop_id_sequences): 62 | assert isinstance(stop_id_sequences[0], list), "stop_id_sequences should be a list of list of ids" 63 | self.stop_sequences = stop_id_sequences 64 | 65 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 66 | sequences_should_be_stopped = [] 67 | for i in range(input_ids.shape[0]): 68 | sequence_should_be_stopped = False 69 | for stop_sequence in self.stop_sequences: 70 | if input_ids[i][-len(stop_sequence):].tolist() == stop_sequence: 71 | sequence_should_be_stopped = True 72 | break 73 | sequences_should_be_stopped.append(sequence_should_be_stopped) 74 | return all(sequences_should_be_stopped) 75 | 76 | 77 | @torch.no_grad() 78 | def generate_completions(model, tokenizer, prompts, batch_size=1, stop_id_sequences=None, add_special_tokens=True, disable_tqdm=False, **generation_kwargs): 79 | generations = [] 80 | if not disable_tqdm: 81 | progress = tqdm.tqdm(total=len(prompts), desc="Generating Completions") 82 | 83 | num_return_sequences = generation_kwargs.get("num_return_sequences", 1) 84 | for i in range(0, len(prompts), batch_size): 85 | batch_prompts = prompts[i:i+batch_size] 86 | tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=add_special_tokens) 87 | batch_input_ids = tokenized_prompts.input_ids 88 | attention_mask = tokenized_prompts.attention_mask 89 | 90 | if model.device.type == "cuda": 91 | batch_input_ids = batch_input_ids.cuda() 92 | attention_mask = attention_mask.cuda() 93 | 94 | # try: 95 | stop_criteria = KeywordsStoppingCriteria(stop_id_sequences, tokenizer) 96 | batch_outputs = model.generate( 97 | input_ids=batch_input_ids, 98 | attention_mask=attention_mask, 99 | stopping_criteria=StoppingCriteriaList([stop_criteria]), 100 | # 
stopping_criteria=[KeyWordsCriteria(stop_id_sequences)] if stop_id_sequences else None, 101 | # stopping_criteria=[KeyWordsCriteriaTrunc(stop_id_sequences, batch_input_ids.size(1))] if stop_id_sequences else None, 102 | **generation_kwargs 103 | ) 104 | 105 | # the stopping criteria is applied at batch level, so if other examples are not stopped, the entire batch will continue to generate. 106 | # so some outputs still have the stop sequence, which we need to remove. 107 | # if stop_id_sequences: 108 | # for output_idx in range(batch_outputs.shape[0]): 109 | # for token_idx in range(batch_input_ids.shape[1], batch_outputs.shape[1]): 110 | # if any(batch_outputs[output_idx, token_idx: token_idx+len(stop_sequence)].tolist() == stop_sequence for stop_sequence in stop_id_sequences): 111 | # batch_outputs[output_idx, token_idx:] = tokenizer.pad_token_id 112 | # break 113 | 114 | # remove the prompt from the output 115 | # we need to re-encode the prompt because we need to make sure the special tokens are treated the same way as in the outputs. 116 | # we changed our previous way of truncating the output token ids dicrectly because some tokenizer (e.g., llama) won't add space token before the first token. 117 | # space is important for some tasks (e.g., code completion). 118 | batch_outputs = tokenizer.batch_decode(batch_outputs, skip_special_tokens=True) 119 | batch_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True) 120 | # duplicate the prompts to match the number of return sequences 121 | batch_prompts = [prompt for prompt in batch_prompts for _ in range(num_return_sequences)] 122 | batch_generations = [ 123 | output[len(prompt):] for prompt, output in zip(batch_prompts, batch_outputs) 124 | ] 125 | 126 | # remove the remain stop sequence from the output. 
127 | for idx, prediction in enumerate(batch_generations): 128 | for stop_sequence in stop_id_sequences: 129 | batch_generations[idx] = prediction.split(stop_sequence)[0] 130 | 131 | generations += batch_generations 132 | 133 | if not disable_tqdm: 134 | progress.update(len(batch_prompts)//num_return_sequences) 135 | 136 | assert len(generations) == len(prompts) * num_return_sequences, "number of generations should be equal to number of prompts * num_return_sequences" 137 | return generations 138 | 139 | 140 | def load_hf_lm_and_tokenizer( 141 | model_name_or_path, 142 | tokenizer_name_or_path=None, 143 | device_map="auto", 144 | load_in_8bit=False, 145 | load_in_half=True, 146 | gptq_model=False, 147 | use_fast_tokenizer=False, 148 | padding_side="left", 149 | use_safetensors=False, 150 | ): 151 | import torch 152 | from transformers import AutoModelForCausalLM, AutoTokenizer 153 | 154 | if not tokenizer_name_or_path: 155 | tokenizer_name_or_path = model_name_or_path 156 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_fast=use_fast_tokenizer, padding_side=padding_side, trust_remote_code=True) 157 | # tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, legacy=False, use_fast=use_fast_tokenizer, padding_side=padding_side, trust_remote_code=True) 158 | 159 | # set pad token to eos token if pad token is not set 160 | if tokenizer.pad_token is None: 161 | if tokenizer.unk_token: 162 | tokenizer.pad_token = tokenizer.unk_token 163 | tokenizer.pad_token_id = tokenizer.unk_token_id 164 | elif tokenizer.eos_token: 165 | tokenizer.pad_token = tokenizer.eos_token 166 | tokenizer.pad_token_id = tokenizer.eos_token_id 167 | else: 168 | raise ValueError("You are using a new tokenizer without a pad token." 169 | "This is not supported by this script.") 170 | 171 | # if tokenizer.pad_token is None: 172 | # tokenizer.pad_token = tokenizer.unk_token 173 | # tokenizer.pad_token_id = tokenizer.unk_token_id 174 | 175 | if gptq_model: 176 | from auto_gptq import AutoGPTQForCausalLM 177 | model_wrapper = AutoGPTQForCausalLM.from_quantized( 178 | model_name_or_path, device="cuda:0", use_triton=True 179 | ) 180 | model = model_wrapper.model 181 | elif load_in_8bit: 182 | model = AutoModelForCausalLM.from_pretrained( 183 | model_name_or_path, 184 | device_map=device_map, 185 | load_in_8bit=True 186 | ) 187 | else: 188 | # return "", tokenizer 189 | # defaul load in float16 190 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path, 191 | torch_dtype=torch.float16, 192 | device_map=device_map, 193 | trust_remote_code=True, 194 | use_safetensors=use_safetensors) 195 | if torch.cuda.is_available(): 196 | model = model.cuda() 197 | if load_in_half: 198 | model = model.half() 199 | model.eval() 200 | return model, tokenizer 201 | 202 | 203 | def _test_generate_completions(): 204 | model_name_or_path = "../models/codellama_7b/v1-16k" 205 | llm, tokenizer = load_hf_lm_and_tokenizer( 206 | model_name_or_path=model_name_or_path, 207 | load_in_half=True, 208 | use_fast_tokenizer=True, 209 | use_safetensors=True, 210 | ) 211 | # some math word problems 212 | prompts = [ 213 | "---\n1+1=2\n---2+2=4\n---3+3=6\n---4+4=8\n---5+5=10\n---6+6=", 214 | "---\n1+1=2\n---12+12=24\n---3+3=6\n---12345+12345=", 215 | # "A train leaves Chicago at 7am and travels at 60mph. Another train leaves Chicago at 9am and travels at 80mph. When will the second train overtake the first?", 216 | # "The sum of two numbers is 10. The difference of the same two numbers is 4. 
What are the two numbers?", 217 | ] 218 | 219 | stop_sequences = ["\n\n\n", "---"] 220 | # Because many tokenizers will treat the word after space differently from the original word alone, 221 | # to be consistent, we add a space before tokenization and remove it after tokenization. 222 | # stop_id_sequences = [tokenizer.encode(" " + x, add_special_tokens=False)[1:] for x in stop_sequences] 223 | outputs = generate_completions( 224 | model=llm, 225 | tokenizer=tokenizer, 226 | prompts=prompts, 227 | max_new_tokens=128, 228 | batch_size=16, 229 | # stop_id_sequences=stop_id_sequences, 230 | stop_id_sequences=stop_sequences, 231 | ) 232 | print(outputs) 233 | 234 | if __name__ == "__main__": 235 | _test_generate_completions() -------------------------------------------------------------------------------- /src/applications/math_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import time 3 | import os 4 | import json 5 | import random 6 | import string 7 | from enum import Enum, auto 8 | from tqdm import tqdm 9 | from collections import OrderedDict 10 | import dataclasses 11 | import pandas as pd 12 | import timeout_decorator 13 | import mpmath 14 | import sympy as sp 15 | from sympy.parsing.latex import parse_latex 16 | import sympy as sp 17 | from sympy import simplify 18 | from sympy.printing import latex 19 | from sympy.core.relational import Relational 20 | from sympy.solvers.solveset import solvify 21 | from sympy.solvers.inequalities import reduce_inequalities 22 | from sympy.parsing.sympy_parser import ( 23 | parse_expr, 24 | standard_transformations, 25 | implicit_multiplication, 26 | ) 27 | 28 | 29 | def compare_numerical_ans(ans_p, ans_l): 30 | if ans_p is None: 31 | return False 32 | ans_p = ans_p.replace(",", "").replace("$", "") 33 | ans_l = ans_l.replace(",", "").replace("$", "") 34 | try: 35 | if ans_p.endswith("%"): 36 | ans_p = float(ans_p.rstrip("%")) / 100 37 | if isinstance(ans_p, str): 38 | ans_p = float(ans_p) 39 | if isinstance(ans_l, str): 40 | ans_l = float(ans_l) 41 | except Exception as e: 42 | return False 43 | return abs(ans_p - float(ans_l)) < 1e-3 44 | 45 | 46 | def my_parse_latex(expr_str): 47 | expr_str = expr_str.replace("dfrac", "frac") 48 | expr = parse_latex(expr_str) 49 | if "\\pi" in expr_str: 50 | expr = expr.subs({sp.Symbol("pi"): sp.pi}) 51 | expr = expr.subs({sp.Symbol("i"): sp.I}) 52 | return expr 53 | 54 | 55 | def is_number(element: str) -> bool: 56 | try: 57 | float(element.replace(" ", "")) 58 | return True 59 | except ValueError: 60 | return False 61 | 62 | 63 | def percentage_to_fraction(text): 64 | pattern = r"(\d+(\.\d+)?%)" 65 | matches = re.findall(pattern, text) 66 | for match in matches: 67 | percentage_str = match[0] 68 | percentage = float(percentage_str.strip("%")) / 100 69 | fraction = str(percentage) 70 | text = text.replace(percentage_str, fraction) 71 | return text 72 | 73 | 74 | def clean_expr_str(expr_str): 75 | expr_str = ( 76 | expr_str.replace(" . ", ".") 77 | .replace(". 
", ".") 78 | .replace("**", "^") 79 | .replace("\\pm", "") 80 | .replace("*", "\\times ") 81 | .replace("\\\\", "\\") 82 | .replace("\\ne ", "\\neq ") 83 | .replace("!=", "\\neq") 84 | .replace(">=", "\\ge") 85 | .replace("<=", "\\le") 86 | .replace("≠", "\\neq") 87 | .replace("dfrac", "frac") 88 | .replace("tfrac", "frac") 89 | .replace("\\$", "") 90 | .replace("$", "") 91 | .replace("\\%", "") 92 | .replace("%", "") 93 | .replace("\\!", "") 94 | .replace("^\circ", "\\times \\pi / 180") 95 | .replace("//", "/") 96 | .replace('"', "") 97 | # .replace(",", "") # TODO 98 | ) 99 | # expr_str = re.sub(r"\^\s(.*)", r"\^\s{\1}", expr_str) 100 | expr_str = re.sub(r"\\+", r"\\", expr_str) 101 | expr_str = re.sub(r"\^\s?\((.*?)\)", r"^{\1}", expr_str) 102 | expr_str = re.sub(r"\\frac\s?(\d)\s?(\d+)", r"\\frac{\1}{\2}", expr_str) 103 | expr_str = re.sub(r"\\log_\s?(\d)\s?(\d+)", r"\\log_{\1}{\2}", expr_str) 104 | expr_str = re.sub(r"\\frac\s?{(.*?)}\s?(\d)", r"\\frac{\1}{\2}", expr_str) 105 | expr_str = re.sub(r"\\frac\s?(\d)\s?{(.*?)}", r"\\frac{\1}{\2}", expr_str) 106 | expr_str = re.sub(r"\\sqrt\s?(\d)", r"\\sqrt{\1}", expr_str) 107 | expr_str = re.sub(r"sqrt\s?\((\d+)\)", r"\\sqrt{\1}", expr_str) 108 | expr_str = re.sub(r"sqrt\s?\((.*?)\)", r"\\sqrt{\1}", expr_str) 109 | expr_str = expr_str.replace(" sqrt", "\\sqrt") 110 | expr_str = ( 111 | expr_str.replace("\\left", "").replace("\\right.", "").replace("\\right", "") 112 | ) 113 | return expr_str 114 | 115 | 116 | def parse_latex_answer(sample): 117 | if isinstance(sample, int) or isinstance(sample, float): 118 | sample = str(sample) 119 | # return sample 120 | sample = clean_expr_str(sample) 121 | try: 122 | expr = my_parse_latex(sample) 123 | except: 124 | print("[parse failed]", sample) 125 | return None 126 | return expr 127 | 128 | 129 | def my_equals(ans_p, ans_l): 130 | return ans_p.equals(ans_l) 131 | 132 | 133 | def is_expr_equal(ans_p, ans_l, is_strict=False): 134 | def is_equ_num_equal(equation, number): 135 | if ( 136 | isinstance(equation, sp.Eq) 137 | # and isinstance(equation.lhs, sp.Symbol) 138 | and equation.rhs.is_number 139 | and number.is_number 140 | ): 141 | try: 142 | ret = my_equals(equation.rhs, number) 143 | return bool(ret) 144 | except: 145 | return equation.rhs == number 146 | 147 | if ans_p is None or ans_l is None: 148 | return False 149 | if isinstance(ans_l, str): 150 | return ans_p == ans_l 151 | 152 | if ( 153 | not is_strict 154 | and is_equ_num_equal(ans_l, ans_p) 155 | or is_equ_num_equal(ans_p, ans_l) 156 | ): 157 | return True 158 | 159 | if ans_p.free_symbols != ans_l.free_symbols: 160 | return False 161 | 162 | if ans_p == ans_l: 163 | return True 164 | 165 | if isinstance(ans_l, sp.core.relational.Relational): 166 | try: 167 | if ( 168 | type(ans_l) == type(ans_p) 169 | and my_equals(ans_p.lhs, ans_l.lhs) 170 | and my_equals(ans_p.rhs, ans_l.rhs) 171 | ): 172 | return True 173 | except Exception as e: 174 | print(ans_p, ans_l, e) 175 | try: 176 | ret = my_equals(ans_p, ans_l) 177 | return bool(ret) 178 | except: 179 | return False 180 | 181 | 182 | # @timeout_decorator.timeout(5) 183 | # def compare_ans(ans_p_str, ans_l_str, is_strict=False): 184 | # ans_p_str = clean_expr_str(ans_p_str) 185 | # ans_p_str = ans_p_str.replace(",", "").replace("$", "") 186 | # ans_l_str = clean_expr_str(ans_l_str) 187 | # ans_l_str = ans_l_str.replace(",", "").replace("$", "") 188 | # if ans_p_str is None: 189 | # return False 190 | # if ans_p_str.replace(" ", "") == ans_l_str.replace(" ", ""): 191 | # return True 192 | 
# ans_p = parse_latex_answer(ans_p_str) 193 | # if ans_p is None: 194 | # return False 195 | # ans_l = parse_latex_answer(ans_l_str) 196 | # if ans_l is None: 197 | # return False 198 | # return is_expr_equal(ans_p, ans_l, is_strict=is_strict) 199 | 200 | 201 | def extract_answer_number(sentence: str) -> float: 202 | sentence = sentence.replace(",", "") 203 | pred = [s for s in re.findall(r"-?\d+\.?\d*", sentence)] 204 | if not pred: 205 | return "" 206 | return pred[-1] 207 | 208 | 209 | @timeout_decorator.timeout(5) 210 | def compare_ans(ans_p_str, ans_l_str, is_strict=False): 211 | ans_p_str = clean_expr_str(ans_p_str) 212 | ans_p_str = ans_p_str.replace(",", "").replace("$", "") 213 | ans_l_str = clean_expr_str(ans_l_str) 214 | ans_l_str = ans_l_str.replace(",", "").replace("$", "") 215 | if ans_p_str is None: 216 | return False 217 | if ans_p_str.replace(" ", "") == ans_l_str.replace(" ", ""): 218 | return True 219 | ans_p = parse_latex_answer(ans_p_str) 220 | if ans_p is None: 221 | return False 222 | ans_l = parse_latex_answer(ans_l_str) 223 | if ans_l is None: 224 | return False 225 | if is_expr_equal(ans_p, ans_l, is_strict=is_strict): 226 | return True 227 | # TODO not suitable 228 | ans_p_str = extract_answer_number(ans_p_str) 229 | if is_number(ans_p_str): 230 | ans_p = parse_latex_answer(ans_p_str) 231 | if is_expr_equal(ans_p, ans_l, is_strict=is_strict): 232 | return True 233 | return False 234 | 235 | 236 | def vote(answers): 237 | counter = Counter(answers) 238 | return counter.most_common(1)[0][0] 239 | 240 | 241 | def contains_number(s): 242 | return any(i.isdigit() for i in s) 243 | 244 | 245 | def rough_compare_ans(generation, answer): 246 | for line in generation.split("\n")[::-1]: 247 | if contains_number(line): 248 | break 249 | words = line.split() 250 | for i, w in enumerate(words): 251 | if i > 0 and words[i - 1] in ["+", "-", "*", "/", "^"]: 252 | continue 253 | if i < len(words) - 1 and words[i + 1] in ["+", "-", "*", "/", "^"]: 254 | continue 255 | if not contains_number(w): 256 | continue 257 | if compare_numerical_ans(w.replace("$", ""), answer) and "=" not in " ".join( 258 | w[i:] 259 | ): 260 | return 1 261 | return 0 262 | 263 | 264 | if __name__ == "__main__": 265 | pred = "To find the minimum value of $f(x)$, we need to find the critical points where the derivative of $f(x)$ is either zero or undefined. Let's write a function to solve this problem:\n\n```python\nfrom sympy import symbols, diff, solve\n\ndef find_minimum_value():\n x, p = symbols('x p')\n f = abs(x - p) + abs(x - 15) + abs(x - p - 15)\n f_prime = diff(f, x)\n critical_points = solve(f_prime, x)\n \n # Evaluate f(x) at the critical points and endpoints\n values = [f.subs(x, p), f.subs(x, 15)]\n for point in critical_points:\n values.append(f.subs(x, point))\n \n # Find the minimum value\n minimum_value = min(values)\n \n return minimum_value\n\nresult = find_minimum_value()\nprint(result)\n```\n```output\nValueError(\"\\nCan't calculate derivative wrt re(x) + I*im(x).\")\n```\nIt seems that there is an error in the code. 
Let's modify the code to handle the case when the derivative is undefined.\n\n```python\nfrom sympy import symbols, diff, solve, oo\n\ndef find_minimum_value():\n x, p = symbols('x p')\n f = abs(x - p) + abs(x - 15) + abs(x - p - 15)\n f_prime = diff(f, x)\n critical_points = solve(f_prime, x)\n \n # Evaluate f(x) at the critical points and endpoints\n values = [f.subs(x, p), f.subs(x, 15)]\n for point in critical_points:\n if point.is_real:\n values.append(f.subs(x, point))\n \n # Find the minimum value\n minimum_value = min(values)\n \n return minimum_value\n\nresult = find_minimum_value()\nprint(result)\n```\n```output\nValueError(\"\\nCan't calculate derivative wrt re(x) + I*im(x).\")\n```\nApologies for the error in the code. It seems that the derivative calculation is causing an issue. Let's modify the code to handle the absolute value function and find the minimum value using a different approach.\n\n```python\nfrom sympy import symbols, Piecewise, Min, oo\n\ndef find_minimum_value():\n x, p = symbols('x p')\n f = abs(x - p) + abs(x - 15) + abs(x - p - 15)\n \n # Define the piecewise function for f(x)\n f_piecewise = Piecewise((f, x <= p), (f.subs(x, p), p < x <= 15), (f.subs(x, 15), x > 15))\n \n # Find the minimum value\n minimum_value = Min(f_piecewise, (x, p, 15))\n \n return minimum_value\n\nresult = find_minimum_value()\nprint(result)\n```\n```output\nTypeError('cannot determine truth value of Relational')\n```\nApologies for the error in the code. It seems that the piecewise function is causing an issue. Let's modify the code to handle the absolute value function and find the minimum value using a different approach.\n\n```python\nfrom sympy import symbols, Min, oo\n\ndef find_minimum_value():\n x, p = symbols('x p')\n f = abs(x - p) + abs(x - 15) + abs(x - p - 15)\n \n # Find the minimum value\n minimum_value = Min(f, (x, p, 15))\n \n return minimum_value\n\nresult = find_minimum_value()\nprint(result)\n```\n" 266 | label = "15" 267 | print(compare_ans(pred, label)) 268 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /src/applications/grader.py: -------------------------------------------------------------------------------- 1 | """ 2 | This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from: 3 | - https://github.com/microsoft/ProphetNet/tree/master/CRITIC 4 | - https://github.com/openai/prm800k 5 | - https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py 6 | - https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py 7 | """ 8 | 9 | import re 10 | import regex 11 | import multiprocessing 12 | from math import isclose 13 | from typing import Union 14 | from collections import defaultdict 15 | 16 | from sympy import simplify, N 17 | from sympy.parsing.sympy_parser import parse_expr 18 | from sympy.parsing.latex import parse_latex 19 | from latex2sympy2 import latex2sympy 20 | 21 | # from .parser import choice_answer_clean, strip_string 22 | # from parser import choice_answer_clean 23 | 24 | 25 | def choice_answer_clean(pred: str): 26 | pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":") 27 | # Clean the answer based on the dataset 28 | tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper()) 29 | if tmp: 30 | pred = tmp 31 | else: 32 | pred = [pred.strip().strip(".")] 33 | pred = pred[-1] 34 | # Remove the period at the end, again! 35 | pred = pred.rstrip(".").rstrip("/") 36 | return pred 37 | 38 | 39 | def parse_digits(num): 40 | num = regex.sub(",", "", str(num)) 41 | try: 42 | return float(num) 43 | except: 44 | if num.endswith("%"): 45 | num = num[:-1] 46 | if num.endswith("\\"): 47 | num = num[:-1] 48 | try: 49 | return float(num) / 100 50 | except: 51 | pass 52 | return None 53 | 54 | 55 | def is_digit(num): 56 | # paired with parse_digits 57 | return parse_digits(num) is not None 58 | 59 | 60 | def str_to_pmatrix(input_str): 61 | input_str = input_str.strip() 62 | matrix_str = re.findall(r"\{.*,.*\}", input_str) 63 | pmatrix_list = [] 64 | 65 | for m in matrix_str: 66 | m = m.strip("{}") 67 | pmatrix = r"\begin{pmatrix}" + m.replace(",", "\\") + r"\end{pmatrix}" 68 | pmatrix_list.append(pmatrix) 69 | 70 | return ", ".join(pmatrix_list) 71 | 72 | 73 | def math_equal( 74 | prediction: Union[bool, float, str], 75 | reference: Union[float, str], 76 | include_percentage: bool = True, 77 | is_close: bool = True, 78 | timeout: bool = False, 79 | ) -> bool: 80 | """ 81 | Exact match of math if and only if: 82 | 1. numerical equal: both can convert to float and are equal 83 | 2. symbolic equal: both can convert to sympy expression and are equal 84 | """ 85 | # print("Judge:", prediction, reference) 86 | if prediction is None or reference is None: 87 | return False 88 | if str(prediction.strip().lower()) == str(reference.strip().lower()): 89 | return True 90 | if ( 91 | reference in ["A", "B", "C", "D", "E"] 92 | and choice_answer_clean(prediction) == reference 93 | ): 94 | return True 95 | 96 | try: # 1. 
numerical equal 97 | if is_digit(prediction) and is_digit(reference): 98 | prediction = parse_digits(prediction) 99 | reference = parse_digits(reference) 100 | # number questions 101 | if include_percentage: 102 | gt_result = [reference / 100, reference, reference * 100] 103 | else: 104 | gt_result = [reference] 105 | for item in gt_result: 106 | try: 107 | if is_close: 108 | if numeric_equal(prediction, item): 109 | return True 110 | else: 111 | if item == prediction: 112 | return True 113 | except Exception: 114 | continue 115 | return False 116 | except: 117 | pass 118 | 119 | if not prediction and prediction not in [0, False]: 120 | return False 121 | 122 | # 2. symbolic equal 123 | reference = str(reference).strip() 124 | prediction = str(prediction).strip() 125 | 126 | ## pmatrix (amps) 127 | if "pmatrix" in prediction and not "pmatrix" in reference: 128 | reference = str_to_pmatrix(reference) 129 | 130 | ## deal with [], (), {} 131 | pred_str, ref_str = prediction, reference 132 | if ( 133 | prediction.startswith("[") 134 | and prediction.endswith("]") 135 | and not reference.startswith("(") 136 | ) or ( 137 | prediction.startswith("(") 138 | and prediction.endswith(")") 139 | and not reference.startswith("[") 140 | ): 141 | pred_str = pred_str.strip("[]()") 142 | ref_str = ref_str.strip("[]()") 143 | for s in ["{", "}", "(", ")"]: 144 | ref_str = ref_str.replace(s, "") 145 | pred_str = pred_str.replace(s, "") 146 | if pred_str.lower() == ref_str.lower(): 147 | return True 148 | 149 | ## [a, b] vs. [c, d], return a==c and b==d 150 | if ( 151 | regex.match(r"(\(|\[).+(\)|\])", prediction) is not None 152 | and regex.match(r"(\(|\[).+(\)|\])", reference) is not None 153 | ): 154 | pred_parts = prediction[1:-1].split(",") 155 | ref_parts = reference[1:-1].split(",") 156 | if len(pred_parts) == len(ref_parts): 157 | if all( 158 | [ 159 | math_equal( 160 | pred_parts[i], ref_parts[i], include_percentage, is_close 161 | ) 162 | for i in range(len(pred_parts)) 163 | ] 164 | ): 165 | return True 166 | if ( 167 | ( 168 | prediction.startswith("\\begin{pmatrix}") 169 | or prediction.startswith("\\begin{bmatrix}") 170 | ) 171 | and ( 172 | prediction.endswith("\\end{pmatrix}") 173 | or prediction.endswith("\\end{bmatrix}") 174 | ) 175 | and ( 176 | reference.startswith("\\begin{pmatrix}") 177 | or reference.startswith("\\begin{bmatrix}") 178 | ) 179 | and ( 180 | reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}") 181 | ) 182 | ): 183 | pred_lines = [ 184 | line.strip() 185 | for line in prediction[ 186 | len("\\begin{pmatrix}") : -len("\\end{pmatrix}") 187 | ].split("\\\\") 188 | if line.strip() 189 | ] 190 | ref_lines = [ 191 | line.strip() 192 | for line in reference[ 193 | len("\\begin{pmatrix}") : -len("\\end{pmatrix}") 194 | ].split("\\\\") 195 | if line.strip() 196 | ] 197 | matched = True 198 | if len(pred_lines) == len(ref_lines): 199 | for pred_line, ref_line in zip(pred_lines, ref_lines): 200 | pred_parts = pred_line.split("&") 201 | ref_parts = ref_line.split("&") 202 | if len(pred_parts) == len(ref_parts): 203 | if not all( 204 | [ 205 | math_equal( 206 | pred_parts[i], 207 | ref_parts[i], 208 | include_percentage, 209 | is_close, 210 | ) 211 | for i in range(len(pred_parts)) 212 | ] 213 | ): 214 | matched = False 215 | break 216 | else: 217 | matched = False 218 | if not matched: 219 | break 220 | else: 221 | matched = False 222 | if matched: 223 | return True 224 | 225 | if prediction.count("=") == 1 and reference.count("=") == 1: 226 | pred = 
prediction.split("=") 227 | pred = f"{pred[0].strip()} - ({pred[1].strip()})" 228 | ref = reference.split("=") 229 | ref = f"{ref[0].strip()} - ({ref[1].strip()})" 230 | if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref): 231 | return True 232 | elif ( 233 | prediction.count("=") == 1 234 | and len(prediction.split("=")[0].strip()) <= 2 235 | and "=" not in reference 236 | ): 237 | if math_equal( 238 | prediction.split("=")[1], reference, include_percentage, is_close 239 | ): 240 | return True 241 | elif ( 242 | reference.count("=") == 1 243 | and len(reference.split("=")[0].strip()) <= 2 244 | and "=" not in prediction 245 | ): 246 | if math_equal( 247 | prediction, reference.split("=")[1], include_percentage, is_close 248 | ): 249 | return True 250 | 251 | # symbolic equal with sympy 252 | if timeout: 253 | if call_with_timeout(symbolic_equal_process, prediction, reference): 254 | return True 255 | else: 256 | if symbolic_equal(prediction, reference): 257 | return True 258 | 259 | return False 260 | 261 | 262 | def math_equal_process(param): 263 | return math_equal(param[-2], param[-1]) 264 | 265 | 266 | def numeric_equal(prediction: float, reference: float): 267 | # Note that relative tolerance has significant impact 268 | # on the result of the synthesized GSM-Hard dataset 269 | # if reference.is_integer(): 270 | # return isclose(reference, round(prediction), abs_tol=1e-4) 271 | # else: 272 | # prediction = round(prediction, len(str(reference).split(".")[-1])) 273 | return isclose(reference, prediction, rel_tol=1e-4) 274 | 275 | 276 | def symbolic_equal(a, b): 277 | def _parse(s): 278 | for f in [parse_latex, parse_expr, latex2sympy]: 279 | try: 280 | return f(s.replace("\\\\", "\\")) 281 | except: 282 | try: 283 | return f(s) 284 | except: 285 | pass 286 | return s 287 | 288 | a = _parse(a) 289 | b = _parse(b) 290 | 291 | # direct equal 292 | try: 293 | if str(a) == str(b) or a == b: 294 | return True 295 | except: 296 | pass 297 | 298 | # simplify equal 299 | try: 300 | if a.equals(b) or simplify(a - b) == 0: 301 | return True 302 | except: 303 | pass 304 | 305 | # equation equal 306 | try: 307 | if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)): 308 | return True 309 | except: 310 | pass 311 | 312 | try: 313 | if numeric_equal(float(N(a)), float(N(b))): 314 | return True 315 | except: 316 | pass 317 | 318 | # matrix 319 | try: 320 | # if a and b are matrix 321 | if a.shape == b.shape: 322 | _a = a.applyfunc(lambda x: round(x, 3)) 323 | _b = b.applyfunc(lambda x: round(x, 3)) 324 | if _a.equals(_b): 325 | return True 326 | except: 327 | pass 328 | 329 | return False 330 | 331 | 332 | def symbolic_equal_process(a, b, output_queue): 333 | result = symbolic_equal(a, b) 334 | output_queue.put(result) 335 | 336 | 337 | def call_with_timeout(func, *args, timeout=1, **kwargs): 338 | output_queue = multiprocessing.Queue() 339 | process_args = args + (output_queue,) 340 | process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs) 341 | process.start() 342 | process.join(timeout) 343 | 344 | if process.is_alive(): 345 | process.terminate() 346 | process.join() 347 | return False 348 | 349 | return output_queue.get() 350 | 351 | def _test_math_equal(): 352 | # print(math_equal("0.0833333333333333", "\\frac{1}{12}")) 353 | # print(math_equal("(1,4.5)", "(1,\\frac{9}{2})")) 354 | # print(math_equal("\\frac{x}{7}+\\frac{2}{7}", "\\frac{x+2}{7}", timeout=True)) 355 | # print(math_equal("\\sec^2(y)", "\\tan^2(y)+1", timeout=True)) 356 | # 
print(math_equal("\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\end{pmatrix}", "(\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\\\\\end{pmatrix})", timeout=True)) 357 | 358 | # pred = '\\begin{pmatrix}\\frac{1}{3x^{2/3}}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\end{pmatrix}' 359 | # gt = '(\\begin{pmatrix}\\frac{1}{3\\sqrt[3]{x}^2}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\\\\\end{pmatrix})' 360 | 361 | # pred= '-\\frac{8x^2}{9(x^2-2)^{5/3}}+\\frac{2}{3(x^2-2)^{2/3}}' 362 | # gt= '-\\frac{2(x^2+6)}{9(x^2-2)\\sqrt[3]{x^2-2}^2}' 363 | 364 | # pred = '-34x-45y+20z-100=0' 365 | # gt = '34x+45y-20z+100=0' 366 | 367 | # pred = '\\frac{100}{3}' 368 | # gt = '33.3' 369 | 370 | # pred = '\\begin{pmatrix}0.290243531202435\\\\0.196008371385084\\\\-0.186381278538813\\end{pmatrix}' 371 | # gt = '(\\begin{pmatrix}0.29\\\\0.196\\\\-0.186\\\\\\end{pmatrix})' 372 | 373 | # pred = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{2\\sqrt{33}+15}' 374 | # gt = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{15+2\\sqrt{33}}' 375 | 376 | # pred = '(+5)(b+2)' 377 | # gt = '(a+5)(b+2)' 378 | 379 | # pred = '\\frac{1+\\sqrt{5}}{2}' 380 | # gt = '2' 381 | 382 | # pred = '\\frac{34}{16}+\\frac{\\sqrt{1358}}{16}', gt = '4' 383 | # pred = '1', gt = '1\\\\sqrt{19}' 384 | 385 | # pred = "(0.6,2.6667]" 386 | # gt = "(\\frac{3}{5},\\frac{8}{3}]" 387 | 388 | gt = "x+2n+1" 389 | pred = "x+1" 390 | 391 | print(math_equal(pred, gt, timeout=True)) 392 | 393 | 394 | if __name__ == "__main__": 395 | _test_math_equal() 396 | -------------------------------------------------------------------------------- /src/applications/TTTS_evaluate.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import argparse 4 | import time 5 | import json 6 | import re 7 | import torch 8 | from datetime import datetime 9 | from tqdm import tqdm 10 | from transformers import LogitsProcessor, AutoTokenizer, AutoModelForCausalLM 11 | from vllm import LLM, SamplingParams 12 | from evaluate import evaluate 13 | from utils import set_seed, load_jsonl, save_jsonl, construct_prompt 14 | from parser import * 15 | from trajectory import * 16 | from data_loader import load_data 17 | from python_executor import PythonExecutor 18 | from model_utils import load_hf_lm_and_tokenizer, generate_completions 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--data_names", default="gsm8k,math", type=str) 23 | parser.add_argument("--data_dir", default="./data", type=str) 24 | parser.add_argument("--save_dir", default="./data", type=str) 25 | parser.add_argument("--model_name_or_path", default="gpt-4", type=str) 26 | parser.add_argument("--output_dir", default="./output", type=str) 27 | parser.add_argument("--prompt_type", default="tool-integrated", type=str) 28 | parser.add_argument("--split", default="test", type=str) 29 | parser.add_argument("--num_test_sample", default=-1, type=int) # -1: all data 30 | parser.add_argument("--seed", default=0, type=int) 31 | parser.add_argument("--start", default=0, type=int) 32 | parser.add_argument("--end", default=-1, type=int) 33 | parser.add_argument("--temperature", default=0, type=float) 34 | parser.add_argument("--n_sampling", default=1, type=int) 35 | parser.add_argument("--top_p", default=1, type=float) 36 | parser.add_argument("--shuffle", action="store_true") 37 | parser.add_argument("--use_vllm", action="store_true") 38 | parser.add_argument("--overwrite", action="store_true") 39 | parser.add_argument("--use_safetensors", 
action="store_true") 40 | parser.add_argument("--apply_chat_template", action="store_true",) 41 | parser.add_argument("--pipeline_parallel_size", type=int, default=1) 42 | parser.add_argument("--adapt_few_shot", action="store_true") 43 | parser.add_argument("--num_shots", type=int, default=0) 44 | 45 | # TTTS related 46 | parser.add_argument("--thinking_tokens_file_path", type=str, default=None) 47 | parser.add_argument("--max_tokens_per_call", default=2048, type=int) 48 | parser.add_argument("--thinking_token", default="", type=str) 49 | parser.add_argument("--token_budget", default=2048, type=int) 50 | args = parser.parse_args() 51 | 52 | args.top_p = 1 if args.temperature == 0 else args.top_p 53 | return args 54 | 55 | 56 | def load_model_and_tokenizer(args): 57 | available_gpus = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",") 58 | if args.use_vllm: 59 | llm = LLM( 60 | model=args.model_name_or_path, 61 | tensor_parallel_size=len(available_gpus) // args.pipeline_parallel_size, 62 | pipeline_parallel_size=args.pipeline_parallel_size, 63 | trust_remote_code=True, 64 | ) 65 | tokenizer = None 66 | if args.apply_chat_template: 67 | tokenizer = AutoTokenizer.from_pretrained( 68 | args.model_name_or_path, trust_remote_code=True 69 | ) 70 | else: 71 | llm, tokenizer = load_hf_lm_and_tokenizer( 72 | model_name_or_path=args.model_name_or_path, 73 | load_in_half=True, 74 | use_fast_tokenizer=True, 75 | use_safetensors=args.use_safetensors, 76 | ) 77 | return llm, tokenizer 78 | 79 | 80 | def prepare_data(data_name, args): 81 | examples = load_data(data_name, args.split, args.data_dir) 82 | 83 | if args.num_test_sample > 0: 84 | examples = examples[: args.num_test_sample] 85 | 86 | if args.shuffle: 87 | random.seed(datetime.now().timestamp()) 88 | random.shuffle(examples) 89 | 90 | examples = examples[args.start: (len(examples) if args.end == -1 else args.end)] 91 | 92 | dt_string = datetime.now().strftime("%m-%d_%H-%M") 93 | out_file_prefix = f"{args.split}_{args.prompt_type}_{args.num_test_sample}_seed{args.seed}_t{args.temperature}" 94 | output_dir = args.output_dir 95 | output_dir = os.path.join("outputs", output_dir) 96 | os.makedirs(os.path.join(output_dir, data_name), exist_ok=True) 97 | out_file = os.path.join(output_dir, data_name, f"{out_file_prefix}_s{args.start}_e{args.end}.jsonl") 98 | 99 | processed_samples = [] 100 | if not args.overwrite: 101 | processed_files = [ 102 | f for f in os.listdir(os.path.join(output_dir, data_name)) 103 | if f.endswith(".jsonl") and f.startswith(out_file_prefix) 104 | ] 105 | for f in processed_files: 106 | processed_samples.extend(list(load_jsonl(os.path.join(output_dir, data_name, f)))) 107 | 108 | processed_samples_dict = {sample["idx"]: sample for sample in processed_samples} 109 | processed_idxs = set(processed_samples_dict.keys()) 110 | examples = [example for example in examples if example["idx"] not in processed_idxs] 111 | 112 | return examples, list(processed_samples_dict.values()), out_file 113 | 114 | 115 | def setup(args): 116 | llm, tokenizer = load_model_and_tokenizer(args) 117 | data_list = args.data_names.split(",") 118 | results = [] 119 | for data_name in data_list: 120 | results.append(main(llm, tokenizer, data_name, args)) 121 | 122 | avg_acc = sum(result["acc"] for result in results) / len(results) 123 | data_list.append("avg") 124 | results.append({"acc": avg_acc}) 125 | 126 | pad = max(len(name) for name in data_list) 127 | print("\t".join(name.ljust(pad, " ") for name in data_list)) 128 | 
print("\t".join(f"{result['acc']:.1f}".ljust(pad, " ") for result in results)) 129 | 130 | 131 | def is_multi_choice(answer): 132 | return all(c in ["A", "B", "C", "D", "E"] for c in answer) 133 | 134 | 135 | # ------------------------------ Thinking Token based Test-time Scaling ---------------------------------- 136 | def batch_TTTS_generation(prompts, llm, tokenizer, args, thinking_token, stop_words): 137 | base_budget = args.max_tokens_per_call 138 | token_budget = args.token_budget 139 | 140 | 141 | ### 1. Base generation 142 | sampling_params = SamplingParams( 143 | max_tokens=base_budget, 144 | top_p=args.top_p, 145 | stop=stop_words, 146 | stop_token_ids=( 147 | [151645, 151643] 148 | if "qwen2" in args.model_name_or_path.lower() 149 | else None 150 | ), 151 | skip_special_tokens=False, 152 | temperature=0.0, 153 | ) 154 | vllm_outputs = llm.generate(prompts, sampling_params) 155 | 156 | outputs = [] 157 | for output in vllm_outputs: 158 | generated_text = output.outputs[0].text.strip() 159 | outputs.append(generated_text) 160 | 161 | 162 | ### 2. Generation with thinking tokens 163 | budget_prompts = [] 164 | 165 | for q, a in zip(prompts, outputs): 166 | concat_prompt = f'{q} {a} {thinking_token}' 167 | budget_prompts.append(concat_prompt) 168 | 169 | sampling_params = SamplingParams( 170 | max_tokens=token_budget, 171 | top_p=args.top_p, 172 | stop=stop_words, 173 | stop_token_ids=( 174 | [151645, 151643] 175 | if "qwen2" in args.model_name_or_path.lower() 176 | else None 177 | ), 178 | skip_special_tokens=False, 179 | temperature=0.0, 180 | ) 181 | vllm_outputs_budget = llm.generate(budget_prompts, sampling_params) 182 | 183 | outputs_budget = [] 184 | for output in vllm_outputs_budget: 185 | generated_text = output.outputs[0].text.strip() 186 | outputs_budget.append(generated_text) 187 | 188 | 189 | ### 3. 
Final output 190 | final_prompts = [] 191 | for q, a in zip(budget_prompts, outputs_budget): 192 | concat_prompt = q + a + r" Final Answer within \boxed{}:" 193 | final_prompts.append(concat_prompt) 194 | 195 | 196 | sampling_params = SamplingParams( 197 | max_tokens=16, 198 | top_p=args.top_p, 199 | stop=stop_words, 200 | stop_token_ids=( 201 | [151645, 151643] 202 | if "qwen2" in args.model_name_or_path.lower() 203 | else None 204 | ), 205 | skip_special_tokens=False, 206 | temperature=0.0, 207 | ) 208 | final_vllm_outputs = llm.generate(final_prompts, sampling_params) 209 | 210 | final_outputs = [] 211 | for output in final_vllm_outputs: 212 | generated_text = output.outputs[0].text.strip() 213 | final_outputs.append(generated_text) 214 | 215 | return final_vllm_outputs 216 | 217 | 218 | def main(llm, tokenizer, data_name, args): 219 | examples, processed_samples, out_file = prepare_data(data_name, args) 220 | print("Dataset:", data_name, "Total samples:", len(examples)) 221 | if examples: 222 | print(examples[0]) 223 | 224 | if "pal" in args.prompt_type: 225 | executor = PythonExecutor(get_answer_expr="solution()") 226 | else: 227 | executor = PythonExecutor(get_answer_from_stdout=True) 228 | 229 | samples = [] 230 | for example in tqdm(examples, total=len(examples)): 231 | idx = example["idx"] 232 | example["question"] = parse_question(example, data_name) 233 | if example["question"] == "": 234 | continue 235 | gt_cot, gt_ans = parse_ground_truth(example, data_name) 236 | example["gt_ans"] = gt_ans 237 | full_prompt = construct_prompt(example, data_name, args) 238 | 239 | if idx == args.start: 240 | print(full_prompt) 241 | 242 | sample = { 243 | "idx": idx, 244 | "question": example["question"], 245 | "gt_cot": gt_cot, 246 | "gt": gt_ans, 247 | "prompt": full_prompt, 248 | } 249 | for key in [ 250 | "level", "type", "unit", "solution_type", "choices", "solution", 251 | "ques_type", "ans_type", "answer_type", "dataset", "subfield", "filed", 252 | "theorem", "answer", 253 | ]: 254 | if key in example: 255 | sample[key] = example[key] 256 | samples.append(sample) 257 | 258 | input_prompts = [sample["prompt"] for sample in samples for _ in range(args.n_sampling)] 259 | if args.apply_chat_template and tokenizer is not None: 260 | input_prompts = [ 261 | tokenizer.apply_chat_template( 262 | [{"role": "user", "content": prompt.strip()}], 263 | tokenize=False, 264 | add_generation_prompt=True, 265 | ) 266 | for prompt in input_prompts 267 | ] 268 | remain_prompts = list(enumerate(input_prompts)) 269 | end_prompts = [] 270 | 271 | max_func_call = 1 if args.prompt_type in ["cot", "pal"] else 4 272 | stop_words = ["</s>", "<|im_end|>", "<|endoftext|>"] 273 | if args.prompt_type in ["cot"]: 274 | stop_words.append("\n\nQuestion:") 275 | if args.prompt_type in ["pal", "tool-integrated", "jiuzhang_tora"]: 276 | stop_words.extend(["\n\n---", "```output"]) 277 | elif args.prompt_type in ["wizard_zs", "platypus_fs"]: 278 | stop_words.extend(["Instruction", "Response"]) 279 | elif "jiuzhang" in args.prompt_type: 280 | stop_words.append("\n\n## Question") 281 | elif "numina" in args.prompt_type: 282 | stop_words.append("\n### Problem") 283 | elif "pure" in args.prompt_type: 284 | stop_words.append("\n\n\n") 285 | 286 | start_time = time.time() 287 | for epoch in range(max_func_call): 288 | print("-" * 20, "Epoch", epoch) 289 | current_prompts = remain_prompts 290 | if not current_prompts: 291 | break 292 | 293 | prompts = [item[1] for item in current_prompts] 294 | 295 | outputs = batch_TTTS_generation( 
296 | prompts=prompts, 297 | llm=llm, 298 | tokenizer=tokenizer, 299 | args=args, 300 | thinking_token=args.thinking_token, 301 | stop_words=stop_words, 302 | ) 303 | outputs = sorted(outputs, key=lambda x: int(x.request_id)) 304 | outputs = [output.outputs[0].text for output in outputs] 305 | 306 | assert len(outputs) == len(current_prompts) 307 | 308 | remain_prompts = [] 309 | remain_codes = [] 310 | for (i, query), output in zip(current_prompts, outputs): 311 | output = output.rstrip() 312 | query += output 313 | if args.prompt_type == "pal": 314 | remain_prompts.append((i, query)) 315 | if "```python" in output: 316 | output = extract_program(query) 317 | remain_codes.append(output) 318 | elif args.prompt_type == "cot": 319 | end_prompts.append((i, query)) 320 | elif "boxed" not in output and output.endswith("```"): 321 | program = extract_program(query) 322 | remain_prompts.append((i, query)) 323 | remain_codes.append(program) 324 | else: 325 | end_prompts.append((i, query)) 326 | 327 | remain_results = executor.batch_apply(remain_codes) 328 | for k in range(len(remain_prompts)): 329 | i, query = remain_prompts[k] 330 | res, report = remain_results[k] 331 | exec_result = res if res else report 332 | if "pal" in args.prompt_type: 333 | exec_result = "\\boxed{" + exec_result + "}" 334 | exec_result = f"\n```output\n{exec_result}\n```\n" 335 | query += exec_result 336 | if epoch == max_func_call - 1: 337 | query += "\nReach max function call limit." 338 | remain_prompts[k] = (i, query) 339 | 340 | end_prompts.extend(remain_prompts) 341 | end_prompts = sorted(end_prompts, key=lambda x: x[0]) 342 | 343 | codes = [] 344 | assert len(input_prompts) == len(end_prompts) 345 | for i in range(len(input_prompts)): 346 | _, end_prompt = end_prompts[i] 347 | code = end_prompt.split(input_prompts[i])[-1].strip() 348 | for stop_word in stop_words: 349 | if stop_word in code: 350 | code = code.split(stop_word)[0].strip() 351 | codes.append(code) 352 | 353 | results = [run_execute(executor, code, args.prompt_type, data_name) for code in codes] 354 | time_use = time.time() - start_time 355 | 356 | all_samples = [] 357 | for i, sample in enumerate(samples): 358 | code_list = codes[i * args.n_sampling: (i + 1) * args.n_sampling] 359 | result_list = results[i * args.n_sampling: (i + 1) * args.n_sampling] 360 | preds = [item[0] for item in result_list] 361 | reports = [item[1] for item in result_list] 362 | for j in range(len(preds)): 363 | if sample["gt"] in ["A", "B", "C", "D", "E"] and preds[j] not in ["A", "B", "C", "D", "E"]: 364 | preds[j] = choice_answer_clean(code_list[j]) 365 | elif is_multi_choice(sample["gt"]) and not is_multi_choice(preds[j]): 366 | preds[j] = "".join([c for c in preds[j] if c in ["A", "B", "C", "D", "E"]]) 367 | sample.pop("prompt", None) 368 | sample.update({"code": code_list, "pred": preds, "report": reports}) 369 | all_samples.append(sample) 370 | 371 | all_samples.extend(processed_samples) 372 | all_samples, result_json = evaluate( 373 | samples=all_samples, 374 | data_name=data_name, 375 | prompt_type=args.prompt_type, 376 | execute=True, 377 | ) 378 | out_file_metrics = out_file.replace(".jsonl", f"_{args.prompt_type}_{args.thinking_token}_metrics.json") 379 | with open(out_file_metrics, "w") as f: 380 | json.dump(result_json, f, indent=4) 381 | result_json["time_use_in_second"] = time_use 382 | result_json["time_use_in_minite"] = f"{int(time_use // 60)}:{int(time_use % 60):02d}" 383 | result_json["word"] = args.thinking_token 384 | return result_json 385 | 386 | 387 | 
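# A minimal sketch of how the three-stage TTTS call above composes, assuming
# `llm` is a vllm.LLM and `args` comes from parse_args(); the prompt and the
# thinking token below are placeholders rather than values from this repo:
#
#     final_outputs = batch_TTTS_generation(
#         prompts=["Question: What is 2 + 3? Answer:"],
#         llm=llm,
#         tokenizer=None,
#         args=args,
#         thinking_token="Wait",
#         stop_words=["</s>", "<|im_end|>", "<|endoftext|>"],
#     )
#     answer_text = final_outputs[0].outputs[0].text  # vLLM RequestOutput objects
#
# Stage 1 spends up to args.max_tokens_per_call tokens, stage 2 appends the
# thinking token and spends up to args.token_budget more, and stage 3 asks for
# the final \boxed{} answer in at most 16 tokens.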
if __name__ == "__main__": 388 | args = parse_args() 389 | set_seed(args.seed) 390 | 391 | llm, tokenizer = load_model_and_tokenizer(args) 392 | prompt_words = load_jsonl(args.thinking_tokens_file_path) 393 | 394 | prompt_words = [ 395 | entry for entry in prompt_words 396 | if re.fullmatch(r'[A-Za-z]+', entry.get('word', '')) 397 | ] 398 | 399 | output_file = f"{args.output_dir}/TTTS_results.json" 400 | 401 | aggregated_results = {} 402 | data_list = args.data_names.split(",") 403 | for data_name in data_list: 404 | aggregated_results[data_name] = {} 405 | for entry in prompt_words: 406 | word = entry["word"] 407 | args.thinking_token = word 408 | print(f"Evaluating dataset {data_name}, thinking token: {word}") 409 | result_json = main(llm, tokenizer, data_name, args) 410 | aggregated_results[data_name][word] = result_json 411 | 412 | print('aggregated_results:', aggregated_results) 413 | with open(output_file, "w") as f: 414 | json.dump(aggregated_results, f, indent=2) 415 | -------------------------------------------------------------------------------- /src/applications/RR_evaluate.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import argparse 4 | import time 5 | import re 6 | import json 7 | import torch 8 | 9 | from datetime import datetime 10 | from tqdm import tqdm 11 | from transformers import AutoTokenizer, AutoModelForCausalLM 12 | from evaluate import evaluate 13 | from utils import set_seed, load_jsonl, save_jsonl, construct_prompt 14 | from parser import * 15 | from trajectory import * 16 | from data_loader import load_data 17 | from python_executor import PythonExecutor 18 | from model_utils import load_hf_lm_and_tokenizer, generate_completions 19 | from RR_model import RecursiveThinkingModel 20 | 21 | 22 | 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--data_names", default="aime24", type=str) 27 | parser.add_argument("--data_dir", default="./data", type=str) 28 | parser.add_argument("--model_name_or_path", default="gpt-4", type=str) 29 | parser.add_argument("--output_dir", default="./output", type=str) 30 | parser.add_argument("--prompt_type", default="tool-integrated", type=str) 31 | parser.add_argument("--split", default="test", type=str) 32 | parser.add_argument("--num_test_sample", default=-1, type=int) # -1 means use all data 33 | parser.add_argument("--seed", default=0, type=int) 34 | parser.add_argument("--start", default=0, type=int) 35 | parser.add_argument("--end", default=-1, type=int) 36 | parser.add_argument("--temperature", default=0, type=float) 37 | parser.add_argument("--n_sampling", default=1, type=int) 38 | parser.add_argument("--top_p", default=1, type=float) 39 | parser.add_argument("--max_tokens_per_call", default=16000, type=int) 40 | parser.add_argument("--shuffle", action="store_true") 41 | parser.add_argument("--use_vllm", action="store_true") 42 | parser.add_argument("--save_outputs", action="store_true") 43 | parser.add_argument("--overwrite", action="store_true") 44 | parser.add_argument("--use_safetensors", action="store_true") 45 | parser.add_argument("--num_shots", type=int, default=0) 46 | parser.add_argument("--apply_chat_template", action="store_true", help="Apply chat template to prompts") 47 | parser.add_argument("--pipeline_parallel_size", type=int, default=1) 48 | parser.add_argument("--adapt_few_shot", action="store_true", help="Use few-shot examples for multiple-choice questions, zero-shot for others") 49 | 50 | 
# Add RecursiveThinkingModel specific parameters 51 | parser.add_argument("--use_recursive_thinking", type=bool, default=True, help="Enable recursive thinking feature") 52 | parser.add_argument("--extract_layer_id", type=int, default=-1, help="Layer ID for extracting hidden states") 53 | parser.add_argument("--inject_layer_id", type=int, default=-1, help="Layer ID for injecting hidden states") 54 | parser.add_argument("--num_recursive_steps", type=int, default=1, help="Number of recursive thinking steps") 55 | parser.add_argument("--interested_tokens", help="List of token IDs to apply recursive thinking, can be a string or a preprocessed list") 56 | parser.add_argument("--interested_tokens_file_path", help="interested_tokens_file_path") 57 | parser.add_argument("--output_file",default="outputs.jsonl", help="Output address for responses") 58 | args = parser.parse_args() 59 | args.top_p = (1 if args.temperature == 0 else args.top_p) # top_p must be 1 when using greedy sampling 60 | return args 61 | 62 | 63 | 64 | def setup(args): 65 | """Set up the model and evaluation environment""" 66 | # Output debug information about interested_tokens 67 | print(f"Type of interested_tokens: {type(args.interested_tokens)}") 68 | print(f"Length of interested_tokens: {len(args.interested_tokens) if hasattr(args.interested_tokens, '__len__') else 'N/A'}") 69 | if args.interested_tokens is not None and hasattr(args.interested_tokens, '__len__') and len(args.interested_tokens) > 0: 70 | print(f"Elements of interested_tokens: {args.interested_tokens}") 71 | 72 | # Load the model 73 | available_gpus = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",") 74 | 75 | 76 | # Load RecursiveThinkingModel 77 | if args.use_recursive_thinking: 78 | print("Loading RecursiveThinkingModel...") 79 | base_model = args.model_name_or_path 80 | 81 | # Process the list of interested tokens 82 | interested_tokens = args.interested_tokens 83 | print(f"Loaded interested_tokens_ids: Type={type(interested_tokens)}, Length={len(interested_tokens) if interested_tokens else 0}") 84 | if interested_tokens and len(interested_tokens) > 0: 85 | print(f"Loaded token IDs: {interested_tokens}") 86 | 87 | # Load tokenizer separately for preprocessing 88 | tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) 89 | if tokenizer.pad_token is None: 90 | tokenizer.pad_token = tokenizer.eos_token 91 | # Create RecursiveThinkingModel 92 | llm = RecursiveThinkingModel( 93 | base_model_name=base_model, 94 | extract_layer_id=args.extract_layer_id, 95 | inject_layer_id=args.inject_layer_id, 96 | num_recursive_steps=args.num_recursive_steps, 97 | use_recursive_thinking=args.use_recursive_thinking, 98 | output_file=args.output_file, 99 | ) 100 | print(f"RecursiveThinkingModel loaded, extraction layer={args.extract_layer_id}, injection layer={args.inject_layer_id}, recursive steps={args.num_recursive_steps}") 101 | else: 102 | # Standard HF model loading 103 | llm, tokenizer = load_hf_lm_and_tokenizer( 104 | model_name_or_path=args.model_name_or_path, 105 | load_in_half=True, 106 | use_fast_tokenizer=True, 107 | use_safetensors=args.use_safetensors, 108 | ) 109 | 110 | # Inference and evaluation 111 | data_list = args.data_names.split(",") 112 | results = [] 113 | for data_name in data_list: 114 | results.append(main(llm, tokenizer, data_name, args)) 115 | 116 | # Add "avg" result to data_list and results 117 | data_list.append("avg") 118 | results.append( 119 | { 120 | "acc": sum([result["acc"] for result in results]) / len(results), 
121 | } 122 | ) 123 | 124 | # Print all results 125 | pad = max([len(data_name) for data_name in data_list]) 126 | print("\t".join(data_name.ljust(pad, " ") for data_name in data_list)) 127 | print("\t".join([f"{result['acc']:.1f}".ljust(pad, " ") for result in results])) 128 | 129 | 130 | def prepare_data(data_name, args): 131 | """Prepare evaluation data""" 132 | examples = load_data(data_name, args.split, args.data_dir) 133 | 134 | # Sample num_test_sample samples from the dataset 135 | if args.num_test_sample > 0: 136 | examples = examples[: args.num_test_sample] 137 | 138 | # Shuffle data 139 | if args.shuffle: 140 | random.seed(datetime.now().timestamp()) 141 | random.shuffle(examples) 142 | 143 | # Select start and end indices 144 | examples = examples[args.start : len(examples) if args.end == -1 else args.end] 145 | 146 | # Get output filename 147 | dt_string = datetime.now().strftime("%m-%d_%H-%M") 148 | model_name = "/".join(args.model_name_or_path.split("/")[-2:]) 149 | out_file_prefix = f"{args.split}_{args.prompt_type}_{args.num_test_sample}_seed{args.seed}_t{args.temperature}" 150 | if args.use_recursive_thinking: 151 | out_file_prefix += f"_recursive{args.num_recursive_steps}_ext{args.extract_layer_id}_inj{args.inject_layer_id}" 152 | # Add an identifier indicating that interested tokens were used 153 | if args.interested_tokens is not None: 154 | out_file_prefix += "_with_interest_tokens" 155 | output_dir = args.output_dir 156 | if not os.path.exists(output_dir): 157 | output_dir = f"outputs/{output_dir}" 158 | out_file = f"{output_dir}/{data_name}/{out_file_prefix}_s{args.start}_e{args.end}.jsonl" 159 | os.makedirs(f"{output_dir}/{data_name}", exist_ok=True) 160 | 161 | # Load all processed samples 162 | processed_samples = [] 163 | if not args.overwrite: 164 | processed_files = [ 165 | f 166 | for f in os.listdir(f"{output_dir}/{data_name}/") 167 | if f.endswith(".jsonl") and f.startswith(out_file_prefix) 168 | ] 169 | for f in processed_files: 170 | processed_samples.extend( 171 | list(load_jsonl(f"{output_dir}/{data_name}/{f}")) 172 | ) 173 | 174 | # Deduplicate 175 | processed_samples = {sample["idx"]: sample for sample in processed_samples} 176 | processed_idxs = list(processed_samples.keys()) 177 | processed_samples = list(processed_samples.values()) 178 | examples = [example for example in examples if example["idx"] not in processed_idxs] 179 | return examples, processed_samples, out_file 180 | 181 | 182 | def generate_with_recursive_model(model, prompt, max_tokens, interested_tokens=None, use_recursive_thinking=True): 183 | """Generate text using RecursiveThinkingModel""" 184 | return model.generate( 185 | max_tokens=max_tokens, 186 | prompt=prompt, 187 | interested_tokens=interested_tokens, 188 | use_recursive_thinking=use_recursive_thinking 189 | ) 190 | 191 | 192 | def is_multi_choice(answer): 193 | """Check if the answer is in multiple-choice format""" 194 | for c in answer: 195 | if c not in ["A", "B", "C", "D", "E"]: 196 | return False 197 | return True 198 | 199 | 200 | def main(llm, tokenizer, data_name, args): 201 | """Main evaluation function""" 202 | examples, processed_samples, out_file = prepare_data(data_name, args) 203 | print("=" * 50) 204 | print("Data:", data_name, " , Remaining samples:", len(examples)) 205 | if len(examples) > 0: 206 | print(examples[0]) 207 | 208 | # Initialize Python executor 209 | if "pal" in args.prompt_type: 210 | executor = PythonExecutor(get_answer_expr="solution()") 211 | else: 212 | executor = 
PythonExecutor(get_answer_from_stdout=True) 213 | 214 | samples = [] 215 | for example in tqdm(examples, total=len(examples)): 216 | idx = example["idx"] 217 | 218 | # Parse question and answer 219 | example["question"] = parse_question(example, data_name) 220 | if example["question"] == "": 221 | continue 222 | gt_cot, gt_ans = parse_ground_truth(example, data_name) 223 | example["gt_ans"] = gt_ans 224 | full_prompt = construct_prompt(example, data_name, args) 225 | 226 | if idx == args.start: 227 | print(full_prompt) 228 | 229 | sample = { 230 | "idx": idx, 231 | "question": example["question"], 232 | "gt_cot": gt_cot, 233 | "gt": gt_ans, 234 | "prompt": full_prompt, 235 | } 236 | 237 | # Add remaining fields 238 | for key in [ 239 | "level", 240 | "type", 241 | "unit", 242 | "solution_type", 243 | "choices", 244 | "solution", 245 | "ques_type", 246 | "ans_type", 247 | "answer_type", 248 | "dataset", 249 | "subfield", 250 | "filed", 251 | "theorem", 252 | "answer", 253 | ]: 254 | if key in example: 255 | sample[key] = example[key] 256 | samples.append(sample) 257 | 258 | # Prepare prompts 259 | input_prompts = [ 260 | sample["prompt"] for sample in samples for _ in range(args.n_sampling) 261 | ] 262 | if args.apply_chat_template: 263 | input_prompts = [ 264 | tokenizer.apply_chat_template( 265 | [{"role": "user", "content": prompt.strip()}], 266 | tokenize=False, 267 | add_generation_prompt=True, 268 | ) 269 | for prompt in input_prompts 270 | ] 271 | remain_prompts = input_prompts 272 | remain_prompts = [(i, prompt) for i, prompt in enumerate(remain_prompts)] 273 | end_prompts = [] 274 | 275 | max_func_call = 1 if args.prompt_type in ["cot", "pal"] else 4 276 | 277 | stop_words = ["</s>", "<|im_end|>", "<|endoftext|>"] 278 | 279 | if args.prompt_type in ["cot"]: 280 | stop_words.append("\n\nQuestion:") 281 | if args.prompt_type in ["pal", "tool-integrated", "jiuzhang_tora"]: 282 | stop_words.extend(["\n\n---", "```output"]) 283 | elif args.prompt_type in ["wizard_zs", "platypus_fs"]: 284 | stop_words.extend(["Instruction", "Response"]) 285 | elif "jiuzhang" in args.prompt_type: 286 | stop_words.append("\n\n## Question") 287 | elif "numina" in args.prompt_type: 288 | stop_words.append("\n### Problem") 289 | elif "pure" in args.prompt_type: 290 | stop_words.append("\n\n\n") 291 | 292 | # Start inference 293 | # Measure time usage 294 | start_time = time.time() 295 | for epoch in range(max_func_call): 296 | print("-" * 20, "Epoch", epoch) 297 | current_prompts = remain_prompts 298 | if len(current_prompts) == 0: 299 | break 300 | 301 | # Get all outputs 302 | prompts = [item[1] for item in current_prompts] 303 | 304 | # Use model to generate outputs 305 | if hasattr(llm, 'generate') and isinstance(llm, RecursiveThinkingModel): 306 | # Use RecursiveThinkingModel's generate method 307 | outputs = [] 308 | # Process the list of interested tokens 309 | interested_tokens = args.interested_tokens 310 | 311 | for prompt in prompts: 312 | output = generate_with_recursive_model( 313 | model=llm, 314 | prompt=prompt, 315 | max_tokens=args.max_tokens_per_call, 316 | interested_tokens=interested_tokens, 317 | use_recursive_thinking=args.use_recursive_thinking 318 | ) 319 | # Remove prompt from output, keep only the generated part 320 | if prompt in output: 321 | output = output[len(prompt):] 322 | 323 | # Check for stop words and truncate 324 | for stop_word in stop_words: 325 | if stop_word in output: 326 | output = output.split(stop_word)[0] 327 | break 328 | outputs.append(output) 329 | else: 330 | # 
Standard HF generation 331 | outputs = generate_completions( 332 | model=llm, 333 | tokenizer=tokenizer, 334 | prompts=prompts, 335 | max_new_tokens=args.max_tokens_per_call, 336 | batch_size=16, 337 | stop_id_sequences=stop_words, 338 | ) 339 | 340 | assert len(outputs) == len(current_prompts) 341 | 342 | # Process all outputs 343 | remain_prompts = [] 344 | remain_codes = [] 345 | for (i, query), output in zip(current_prompts, outputs): 346 | output = output.rstrip() 347 | query += output 348 | if args.prompt_type == "pal": 349 | remain_prompts.append((i, query)) 350 | if "```python" in output: 351 | output = extract_program(query) 352 | remain_codes.append(output) 353 | elif args.prompt_type == "cot": 354 | end_prompts.append((i, query)) 355 | elif "boxed" not in output and output.endswith("```"): 356 | program = extract_program(query) 357 | remain_prompts.append((i, query)) 358 | remain_codes.append(program) 359 | else: 360 | end_prompts.append((i, query)) 361 | 362 | # Execute remaining prompts 363 | remain_results = executor.batch_apply(remain_codes) 364 | for k in range(len(remain_prompts)): 365 | i, query = remain_prompts[k] 366 | res, report = remain_results[k] 367 | exec_result = res if res else report 368 | if "pal" in args.prompt_type: 369 | exec_result = "\\boxed{" + exec_result + "}" 370 | exec_result = f"\n```output\n{exec_result}\n```\n" 371 | query += exec_result 372 | # Not finished 373 | if epoch == max_func_call - 1: 374 | query += "\nReached maximum function call limit." 375 | remain_prompts[k] = (i, query) 376 | 377 | # Unresolved samples 378 | print("Unresolved samples:", len(remain_prompts)) 379 | end_prompts.extend(remain_prompts) 380 | # Sort by index 381 | end_prompts = sorted(end_prompts, key=lambda x: x[0]) 382 | 383 | # Remove input_prompt from end_prompt 384 | codes = [] 385 | assert len(input_prompts) == len(end_prompts) 386 | for i in range(len(input_prompts)): 387 | _, end_prompt = end_prompts[i] 388 | code = end_prompt.split(input_prompts[i])[-1].strip() 389 | for stop_word in stop_words: 390 | if stop_word in code: 391 | code = code.split(stop_word)[0].strip() 392 | codes.append(code) 393 | 394 | # Extract predictions 395 | results = [ 396 | run_execute(executor, code, args.prompt_type, data_name) for code in codes 397 | ] 398 | time_use = time.time() - start_time 399 | 400 | # Put results back into samples 401 | all_samples = [] 402 | for i, sample in enumerate(samples): 403 | code = codes[i * args.n_sampling : (i + 1) * args.n_sampling] 404 | result = results[i * args.n_sampling : (i + 1) * args.n_sampling] 405 | preds = [item[0] for item in result] 406 | reports = [item[1] for item in result] 407 | for j in range(len(preds)): 408 | if sample["gt"] in ["A", "B", "C", "D", "E"] and preds[j] not in [ 409 | "A", 410 | "B", 411 | "C", 412 | "D", 413 | "E", 414 | ]: 415 | preds[j] = choice_answer_clean(code[j]) 416 | elif is_multi_choice(sample["gt"]) and not is_multi_choice(preds[j]): 417 | # Remove any non-choice characters 418 | preds[j] = "".join( 419 | [c for c in preds[j] if c in ["A", "B", "C", "D", "E"]] 420 | ) 421 | 422 | sample.pop("prompt") 423 | sample.update({"code": code, "pred": preds, "report": reports}) 424 | all_samples.append(sample) 425 | 426 | # Add processed samples 427 | all_samples.extend(processed_samples) 428 | all_samples, result_json = evaluate( 429 | samples=all_samples, 430 | data_name=data_name, 431 | prompt_type=args.prompt_type, 432 | execute=True, 433 | ) 434 | 435 | # Save outputs 436 | if len(processed_samples) < 
len(all_samples) and args.save_outputs: 437 | save_jsonl(all_samples, out_file) 438 | 439 | result_json["time_use_in_second"] = time_use 440 | result_json["time_use_in_minute"] = ( 441 | f"{int(time_use // 60)}:{int(time_use % 60):02d}" 442 | ) 443 | 444 | # Add recursive thinking parameters to metrics 445 | if args.use_recursive_thinking: 446 | result_json["recursive_thinking"] = { 447 | "enabled": args.use_recursive_thinking, 448 | "extract_layer_id": args.extract_layer_id, 449 | "inject_layer_id": args.inject_layer_id, 450 | "num_recursive_steps": args.num_recursive_steps, 451 | "interested_tokens_count": len(args.interested_tokens) if args.interested_tokens else 0 452 | } 453 | 454 | with open( 455 | out_file.replace(".jsonl", f"_{args.prompt_type}_metrics.json"), "w" 456 | ) as f: 457 | json.dump(result_json, f, indent=4) 458 | return result_json 459 | 460 | 461 | if __name__ == "__main__": 462 | 463 | args = parse_args() 464 | 465 | english_word_token_ids = set() 466 | english_pattern = re.compile(r'^[a-zA-Z]+$') 467 | 468 | english_word_count = 0 469 | target_word_count = 10 470 | 471 | # Read JSONL file 472 | with open(args.interested_tokens_file_path, 'r') as f: 473 | for line in f: 474 | try: 475 | if english_word_count >= target_word_count: 476 | break 477 | 478 | # Parse JSON 479 | record = json.loads(line) 480 | word = record.get("word", "") 481 | token_ids = record.get("token_ids", []) 482 | 483 | if english_pattern.match(word): 484 | english_word_count += 1 485 | for token_id in token_ids: 486 | english_word_token_ids.add(token_id) 487 | 488 | except json.JSONDecodeError: 489 | print(f"Error: Unable to parse JSON line: {line}") 490 | except Exception as e: 491 | print(f"Error during processing: {e}") 492 | 493 | args.interested_tokens = english_word_token_ids 494 | 495 | set_seed(args.seed) 496 | setup(args) -------------------------------------------------------------------------------- /src/applications/parser.py: -------------------------------------------------------------------------------- 1 | import random 2 | import regex 3 | import re 4 | import sympy 5 | from latex2sympy2 import latex2sympy 6 | from typing import TypeVar, Iterable, List, Union, Any, Dict 7 | from word2number import w2n 8 | from utils import * 9 | 10 | 11 | def _fix_fracs(string): 12 | substrs = string.split("\\frac") 13 | new_str = substrs[0] 14 | if len(substrs) > 1: 15 | substrs = substrs[1:] 16 | for substr in substrs: 17 | new_str += "\\frac" 18 | if len(substr) > 0 and substr[0] == "{": 19 | new_str += substr 20 | else: 21 | try: 22 | assert len(substr) >= 2 23 | except: 24 | return string 25 | a = substr[0] 26 | b = substr[1] 27 | if b != "{": 28 | if len(substr) > 2: 29 | post_substr = substr[2:] 30 | new_str += "{" + a + "}{" + b + "}" + post_substr 31 | else: 32 | new_str += "{" + a + "}{" + b + "}" 33 | else: 34 | if len(substr) > 2: 35 | post_substr = substr[2:] 36 | new_str += "{" + a + "}" + b + post_substr 37 | else: 38 | new_str += "{" + a + "}" + b 39 | string = new_str 40 | return string 41 | 42 | 43 | def _fix_a_slash_b(string): 44 | if len(string.split("/")) != 2: 45 | return string 46 | a = string.split("/")[0] 47 | b = string.split("/")[1] 48 | try: 49 | if "sqrt" not in a: 50 | a = int(a) 51 | if "sqrt" not in b: 52 | b = int(b) 53 | assert string == "{}/{}".format(a, b) 54 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 55 | return new_string 56 | except: 57 | return string 58 | 59 | 60 | def _fix_sqrt(string): 61 | _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", 
string) 62 | return _string 63 | 64 | 65 | def convert_word_number(text: str) -> str: 66 | try: 67 | text = str(w2n.word_to_num(text)) 68 | except: 69 | pass 70 | return text 71 | 72 | 73 | # units mainly from MathQA 74 | unit_texts = [ 75 | "east", 76 | "degree", 77 | "mph", 78 | "kmph", 79 | "ft", 80 | "m sqaure", 81 | " m east", 82 | "sq m", 83 | "deg", 84 | "mile", 85 | "q .", 86 | "monkey", 87 | "prime", 88 | "ratio", 89 | "profit of rs", 90 | "rd", 91 | "o", 92 | "gm", 93 | "p . m", 94 | "lb", 95 | "tile", 96 | "per", 97 | "dm", 98 | "lt", 99 | "gain", 100 | "ab", 101 | "way", 102 | "west", 103 | "a .", 104 | "b .", 105 | "c .", 106 | "d .", 107 | "e .", 108 | "f .", 109 | "g .", 110 | "h .", 111 | "t", 112 | "a", 113 | "h", 114 | "no change", 115 | "men", 116 | "soldier", 117 | "pie", 118 | "bc", 119 | "excess", 120 | "st", 121 | "inches", 122 | "noon", 123 | "percent", 124 | "by", 125 | "gal", 126 | "kmh", 127 | "c", 128 | "acre", 129 | "rise", 130 | "a . m", 131 | "th", 132 | "π r 2", 133 | "sq", 134 | "mark", 135 | "l", 136 | "toy", 137 | "coin", 138 | "sq . m", 139 | "gallon", 140 | "° f", 141 | "profit", 142 | "minw", 143 | "yr", 144 | "women", 145 | "feet", 146 | "am", 147 | "pm", 148 | "hr", 149 | "cu cm", 150 | "square", 151 | "v â € ™", 152 | "are", 153 | "rupee", 154 | "rounds", 155 | "cubic", 156 | "cc", 157 | "mtr", 158 | "s", 159 | "ohm", 160 | "number", 161 | "kmph", 162 | "day", 163 | "hour", 164 | "minute", 165 | "min", 166 | "second", 167 | "man", 168 | "woman", 169 | "sec", 170 | "cube", 171 | "mt", 172 | "sq inch", 173 | "mp", 174 | "∏ cm ³", 175 | "hectare", 176 | "more", 177 | "sec", 178 | "unit", 179 | "cu . m", 180 | "cm 2", 181 | "rs .", 182 | "rs", 183 | "kg", 184 | "g", 185 | "month", 186 | "km", 187 | "m", 188 | "cm", 189 | "mm", 190 | "apple", 191 | "liter", 192 | "loss", 193 | "yard", 194 | "pure", 195 | "year", 196 | "increase", 197 | "decrease", 198 | "d", 199 | "less", 200 | "Surface", 201 | "litre", 202 | "pi sq m", 203 | "s .", 204 | "metre", 205 | "meter", 206 | "inch", 207 | ] 208 | 209 | unit_texts.extend([t + "s" for t in unit_texts]) 210 | 211 | 212 | def strip_string(string, skip_unit=False): 213 | string = str(string).strip() 214 | # linebreaks 215 | string = string.replace("\n", "") 216 | 217 | # right "." 
218 | string = string.rstrip(".") 219 | 220 | # remove inverse spaces 221 | # replace \\ with \ 222 | string = string.replace("\\!", "") 223 | # string = string.replace("\\ ", "") 224 | # string = string.replace("\\\\", "\\") 225 | 226 | # matrix 227 | string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string) 228 | string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string) 229 | string = string.replace("bmatrix", "pmatrix") 230 | 231 | # replace tfrac and dfrac with frac 232 | string = string.replace("tfrac", "frac") 233 | string = string.replace("dfrac", "frac") 234 | string = ( 235 | string.replace("\\neq", "\\ne") 236 | .replace("\\leq", "\\le") 237 | .replace("\\geq", "\\ge") 238 | ) 239 | 240 | # remove \left and \right 241 | string = string.replace("\\left", "") 242 | string = string.replace("\\right", "") 243 | string = string.replace("\\{", "{") 244 | string = string.replace("\\}", "}") 245 | 246 | # Remove unit: miles, dollars if after is not none 247 | _string = re.sub(r"\\text{.*?}$", "", string).strip() 248 | if _string != "" and _string != string: 249 | # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string)) 250 | string = _string 251 | 252 | if not skip_unit: 253 | # Remove unit: texts 254 | for _ in range(2): 255 | for unit_text in unit_texts: 256 | # use regex, the prefix should be either the start of the string or a non-alphanumeric character 257 | # the suffix should be either the end of the string or a non-alphanumeric character 258 | _string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string) 259 | if _string != "": 260 | string = _string 261 | 262 | # Remove circ (degrees) 263 | string = string.replace("^{\\circ}", "") 264 | string = string.replace("^\\circ", "") 265 | 266 | # remove dollar signs 267 | string = string.replace("\\$", "") 268 | string = string.replace("$", "") 269 | string = string.replace("\\(", "").replace("\\)", "") 270 | 271 | # convert word number to digit 272 | string = convert_word_number(string) 273 | 274 | # replace "\\text{...}" to "..." 275 | string = re.sub(r"\\text\{(.*?)\}", r"\1", string) 276 | for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]: 277 | string = string.replace(key, "") 278 | string = string.replace("\\emptyset", r"{}") 279 | string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}") 280 | 281 | # remove percentage 282 | string = string.replace("\\%", "") 283 | string = string.replace("\%", "") 284 | string = string.replace("%", "") 285 | 286 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." 
is the start of the string 287 | string = string.replace(" .", " 0.") 288 | string = string.replace("{.", "{0.") 289 | 290 | # cdot 291 | # string = string.replace("\\cdot", "") 292 | if ( 293 | string.startswith("{") 294 | and string.endswith("}") 295 | and string.isalnum() 296 | or string.startswith("(") 297 | and string.endswith(")") 298 | and string.isalnum() 299 | or string.startswith("[") 300 | and string.endswith("]") 301 | and string.isalnum() 302 | ): 303 | string = string[1:-1] 304 | 305 | # inf 306 | string = string.replace("infinity", "\\infty") 307 | if "\\infty" not in string: 308 | string = string.replace("inf", "\\infty") 309 | string = string.replace("+\\inity", "\\infty") 310 | 311 | # and 312 | string = string.replace("and", "") 313 | string = string.replace("\\mathbf", "") 314 | 315 | # use regex to remove \mbox{...} 316 | string = re.sub(r"\\mbox{.*?}", "", string) 317 | 318 | # quote 319 | string = string.replace("'", "") 320 | string = string.replace('"', "") 321 | 322 | # i, j 323 | if "j" in string and "i" not in string: 324 | string = string.replace("j", "i") 325 | 326 | # replace a.000b where b is not number or b is end, with ab, use regex 327 | string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string) 328 | string = re.sub(r"(\d+)\.0*$", r"\1", string) 329 | 330 | # if empty, return empty string 331 | if len(string) == 0: 332 | return string 333 | if string[0] == ".": 334 | string = "0" + string 335 | 336 | # to consider: get rid of e.g. "k = " or "q = " at beginning 337 | if len(string.split("=")) == 2: 338 | if len(string.split("=")[0]) <= 2: 339 | string = string.split("=")[1] 340 | 341 | string = _fix_sqrt(string) 342 | string = string.replace(" ", "") 343 | 344 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} 345 | string = _fix_fracs(string) 346 | 347 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y 348 | string = _fix_a_slash_b(string) 349 | 350 | return string 351 | 352 | 353 | def extract_multi_choice_answer(pred_str): 354 | # TODO: SFT models 355 | if "Problem:" in pred_str: 356 | pred_str = pred_str.split("Problem:", 1)[0] 357 | pred_str = pred_str.replace("choice is", "answer is") 358 | patt = regex.search(r"answer is \(?(?P<ans>[abcde])\)?", pred_str.lower()) 359 | if patt is not None: 360 | return patt.group("ans").upper() 361 | return "placeholder" 362 | 363 | 364 | direct_answer_trigger_for_fewshot = ("choice is", "answer is") 365 | 366 | 367 | def choice_answer_clean(pred: str): 368 | pred = pred.strip("\n") 369 | 370 | # Determine if this is ICL, if so, use \n\n to split the first chunk. 371 | ICL = False 372 | for trigger in direct_answer_trigger_for_fewshot: 373 | if pred.count(trigger) > 1: 374 | ICL = True 375 | if ICL: 376 | pred = pred.split("\n\n")[0] 377 | 378 | # Split the trigger to find the answer. 379 | preds = re.split("|".join(direct_answer_trigger_for_fewshot), pred) 380 | if len(preds) > 1: 381 | answer_flag = True 382 | pred = preds[-1] 383 | else: 384 | answer_flag = False 385 | 386 | pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":") 387 | 388 | # Clean the answer based on the dataset 389 | tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper()) 390 | if tmp: 391 | pred = tmp 392 | else: 393 | pred = [pred.strip().strip(".")] 394 | 395 | if len(pred) == 0: 396 | pred = "" 397 | else: 398 | if answer_flag: 399 | # choose the first element in list ... 
400 | pred = pred[0] 401 | else: 402 | # choose the last element 403 | pred = pred[-1] 404 | 405 | # Remove the period at the end, again! 406 | pred = pred.rstrip(".").rstrip("/") 407 | 408 | return pred 409 | 410 | 411 | def find_box(pred_str: str): 412 | ans = pred_str.split("boxed")[-1] 413 | if not ans: 414 | return "" 415 | if ans[0] == "{": 416 | stack = 1 417 | a = "" 418 | for c in ans[1:]: 419 | if c == "{": 420 | stack += 1 421 | a += c 422 | elif c == "}": 423 | stack -= 1 424 | if stack == 0: 425 | break 426 | a += c 427 | else: 428 | a += c 429 | else: 430 | a = ans.split("$")[0].strip() 431 | return a 432 | 433 | 434 | def clean_units(pred_str: str): 435 | """Clean the units in the number.""" 436 | 437 | def convert_pi_to_number(code_string): 438 | code_string = code_string.replace("\\pi", "π") 439 | # Replace \pi or π not preceded by a digit or } with 3.14 440 | code_string = re.sub(r"(?<![\d}])\\?π", "3.14", code_string) 441 | # Replace π preceded by a digit with an explicit multiplication, e.g., "3π" -> "3*3.14" 442 | code_string = re.sub(r"(\d)(\\?π)", r"\1*3.14", code_string) 443 | # Handle cases where π is within braces or followed by a multiplication symbol 444 | # This replaces "{π}" with "3.14" directly and "3*π" with "3*3.14" 445 | code_string = re.sub(r"\{(\\?π)\}", "3.14", code_string) 446 | code_string = re.sub(r"\*(\\?π)", "*3.14", code_string) 447 | return code_string 448 | 449 | pred_str = convert_pi_to_number(pred_str) 450 | pred_str = pred_str.replace("%", "/100") 451 | pred_str = pred_str.replace("$", "") 452 | pred_str = pred_str.replace("¥", "") 453 | pred_str = pred_str.replace("°C", "") 454 | pred_str = pred_str.replace(" C", "") 455 | pred_str = pred_str.replace("°", "") 456 | return pred_str 457 | 458 | 459 | def extract_theoremqa_answer(pred: str, answer_flag: bool = True): 460 | if any([option in pred.lower() for option in ["yes", "true"]]): 461 | pred = "True" 462 | elif any([option in pred.lower() for option in ["no", "false"]]): 463 | pred = "False" 464 | elif any( 465 | [ 466 | option in pred.lower() 467 | for option in ["(a)", "(b)", "(c)", "(d)", "(e)", "(f)"] 468 | ] 469 | ): 470 | pass 471 | else: 472 | # Some of the models somehow get used to boxed output from pre-training 473 | if "boxed" in pred: 474 | pred = find_box(pred) 475 | 476 | if answer_flag: 477 | # Extract the numbers out of the string 478 | pred = pred.split("=")[-1].strip() 479 | pred = clean_units(pred) 480 | try: 481 | tmp = str(latex2sympy(pred)) 482 | pred = str(eval(tmp)) 483 | except Exception: 484 | if re.match(r"-?[\d\.]+\s\D+$", pred): 485 | pred = pred.split(" ")[0] 486 | elif re.match(r"-?[\d\.]+\s[^\s]+$", pred): 487 | pred = pred.split(" ")[0] 488 | else: 489 | # desperate search over the last number 490 | preds = re.findall(r"-?\d*\.?\d+", pred) 491 | if len(preds) >= 1: 492 | pred = preds[-1] 493 | else: 494 | pred = "" 495 | 496 | return pred 497 | 498 | 499 | def extract_answer(pred_str, data_name, use_last_number=True): 500 | pred_str = pred_str.replace("\u043a\u0438", "") 501 | if data_name in ["mmlu_stem", "sat_math", "aqua", "gaokao2023"]: 502 | # TODO check multiple choice 503 | return choice_answer_clean(pred_str) 504 | 505 | if "final answer is $" in pred_str and "$. I hope" in pred_str: 506 | # minerva_math 507 | tmp = pred_str.split("final answer is $", 1)[1] 508 | pred = tmp.split("$. 
I hope", 1)[0].strip() 509 | elif "boxed" in pred_str: 510 | ans = pred_str.split("boxed")[-1] 511 | if len(ans) == 0: 512 | return "" 513 | elif ans[0] == "{": 514 | stack = 1 515 | a = "" 516 | for c in ans[1:]: 517 | if c == "{": 518 | stack += 1 519 | a += c 520 | elif c == "}": 521 | stack -= 1 522 | if stack == 0: 523 | break 524 | a += c 525 | else: 526 | a += c 527 | else: 528 | a = ans.split("$")[0].strip() 529 | pred = a 530 | elif "he answer is" in pred_str: 531 | pred = pred_str.split("he answer is")[-1].strip() 532 | elif "final answer is" in pred_str: 533 | pred = pred_str.split("final answer is")[-1].strip() 534 | elif "答案是" in pred_str: 535 | # Handle Chinese few-shot multiple choice problem answer extraction 536 | pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip() 537 | else: # use the last number 538 | if use_last_number: 539 | pattern = "-?\d*\.?\d+" 540 | pred = re.findall(pattern, pred_str.replace(",", "")) 541 | if len(pred) >= 1: 542 | pred = pred[-1] 543 | else: 544 | pred = "" 545 | else: 546 | pred = "" 547 | 548 | # choice answer 549 | if ( 550 | data_name in ["sat_math", "aqua"] 551 | or "mmlu" in data_name 552 | ): 553 | tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper()) 554 | if tmp: 555 | pred = tmp[-1] 556 | else: 557 | pred = pred.strip().strip(".") 558 | 559 | # multiple line 560 | # pred = pred.split("\n")[0] 561 | pred = re.sub(r"\n\s*", "", pred) 562 | if pred != "" and pred[0] == ":": 563 | pred = pred[1:] 564 | if pred != "" and pred[-1] == ".": 565 | pred = pred[:-1] 566 | if pred != "" and pred[-1] == "/": 567 | pred = pred[:-1] 568 | pred = strip_string(pred, skip_unit=data_name in ["carp_en", "minerva_math"]) 569 | return pred 570 | 571 | 572 | STRIP_EXCEPTIONS = ["carp_en", "minerva_math"] 573 | 574 | 575 | def parse_ground_truth(example: Dict[str, Any], data_name): 576 | # 如果样本中已经有 gt_cot 和 gt 字段,直接使用(这里不做区分大小写) 577 | if "gt_cot" in example and "gt" in example: 578 | if data_name.lower() in ["math", "minerva_math"]: 579 | gt_ans = extract_answer(example["gt_cot"], data_name) 580 | elif data_name in STRIP_EXCEPTIONS: 581 | gt_ans = example["gt"] 582 | else: 583 | gt_ans = strip_string(example["gt"]) 584 | return example["gt_cot"], gt_ans 585 | 586 | # 针对不同数据集进行解析 587 | if data_name.lower() in ["math", "minerva_math"]: 588 | gt_cot = example["solution"] 589 | gt_ans = extract_answer(gt_cot, data_name) 590 | elif data_name.lower() == "math-500": 591 | # 对于 MATH-500,使用 solution 作为链式推理,直接使用 answer 作为最终答案 592 | gt_cot = example["solution"] 593 | gt_ans = strip_string(example["answer"]) 594 | elif data_name == "gsm8k": 595 | gt_cot, gt_ans = example["answer"].split("####") 596 | elif data_name == "svamp": 597 | gt_cot, gt_ans = example["Equation"], example["Answer"] 598 | elif data_name == "asdiv": 599 | gt_cot = example["formula"] 600 | gt_ans = re.sub(r"\(.*?\)", "", example["answer"]) 601 | elif data_name == "mawps": 602 | gt_cot, gt_ans = None, example["target"] 603 | elif data_name == "tabmwp": 604 | gt_cot = example["solution"] 605 | gt_ans = example["answer"] 606 | if example["ans_type"] in ["integer_number", "decimal_number"]: 607 | if "/" in gt_ans: 608 | gt_ans = int(gt_ans.split("/")[0]) / int(gt_ans.split("/")[1]) 609 | elif "," in gt_ans: 610 | gt_ans = float(gt_ans.replace(",", "")) 611 | elif "%" in gt_ans: 612 | gt_ans = float(gt_ans.split("%")[0]) / 100 613 | else: 614 | gt_ans = float(gt_ans) 615 | elif data_name == "carp_en": 616 | gt_cot, gt_ans = example["steps"], example["answer"] 617 | elif data_name == "mmlu_stem": 
618 | abcd = "ABCD" 619 | gt_cot, gt_ans = None, abcd[example["answer"]] 620 | elif data_name == "sat_math": 621 | gt_cot, gt_ans = None, example["Answer"] 622 | elif data_name == "aqua": 623 | gt_cot, gt_ans = None, example["correct"] 624 | elif data_name in ["gaokao2023en", "college_math", "gaokao_math_cloze"]: 625 | gt_cot, gt_ans = None, example["answer"].replace("$", "").strip() 626 | elif data_name == "gaokao_math_qa": 627 | gt_cot, gt_ans = None, example["label"] 628 | elif data_name in ["gaokao2024_mix", "cn_middle_school"]: 629 | if len(example["choice_answer"]) > 0: 630 | gt_cot, gt_ans = None, example["choice_answer"] 631 | else: 632 | gt_cot, gt_ans = None, example["answer"] 633 | elif data_name == "olympiadbench": 634 | gt_cot, gt_ans = None, example["final_answer"][0].strip("$") 635 | elif data_name in [ 636 | "aime24", 637 | "amc23", 638 | "cmath", 639 | "gaokao2024_I", 640 | "gaokao2024_II", 641 | "imo2024", 642 | ]: 643 | gt_cot, gt_ans = None, example["answer"] 644 | else: 645 | raise NotImplementedError(f"`{data_name}`") 646 | 647 | # Post-processing 648 | gt_cot = str(gt_cot).strip() 649 | if data_name not in STRIP_EXCEPTIONS: 650 | gt_ans = strip_string(gt_ans, skip_unit=data_name == "carp_en") 651 | else: 652 | gt_ans = ( 653 | gt_ans.replace("\\neq", "\\ne") 654 | .replace("\\leq", "\\le") 655 | .replace("\\geq", "\\ge") 656 | ) 657 | return gt_cot, gt_ans 658 | 659 | 660 | def parse_question(example, data_name): 661 | question = "" 662 | if data_name == "asdiv": 663 | question = f"{example['body'].strip()} {example['question'].strip()}" 664 | elif data_name == "svamp": 665 | body = example["Body"].strip() 666 | if not body.endswith("."): 667 | body = body + "." 668 | question = f'{body} {example["Question"].strip()}' 669 | elif data_name == "tabmwp": 670 | title_str = ( 671 | f'regarding "{example["table_title"]}" ' if example["table_title"] else "" 672 | ) 673 | question = f"Read the following table {title_str}and answer a question:\n" 674 | question += f'{example["table"]}\n{example["question"]}' 675 | if example["choices"]: 676 | question += ( 677 | f' Please select from the following options: {example["choices"]}' 678 | ) 679 | elif data_name == "carp_en": 680 | question = example["content"] 681 | elif data_name == "mmlu_stem": 682 | options = example["choices"] 683 | assert len(options) == 4 684 | for i, (label, option) in enumerate(zip("ABCD", options)): 685 | options[i] = f"({label}) {str(option).strip()}" 686 | options = " ".join(options) 687 | question = f"{example['question'].strip()}\nAnswer Choices: {options}" 688 | elif data_name == "sat_math": 689 | options = example["options"].strip() 690 | assert "A" == options[0] 691 | options = "(" + options 692 | for ch in "BCD": 693 | if f" {ch}) " in options: 694 | options = re.sub(rf" {ch}\) ", f" ({ch}) ", options) 695 | question = f"{example['question'].strip()}\nAnswer Choices: {options}" 696 | elif "aqua" in data_name: 697 | options = example["options"] 698 | choice = "(" + "(".join(options) 699 | choice = choice.replace("(", " (").replace(")", ") ").strip() 700 | choice = "\nAnswer Choices: " + choice 701 | question = example["question"].strip() + choice 702 | elif data_name == "gaokao_math_qa": 703 | options_dict = example["options"] 704 | options = [] 705 | for key in options_dict: 706 | options.append(f"({key}) {options_dict[key]}") 707 | options = " ".join(options) 708 | question = f"{example['question'].strip()}\n选项: {options}" 709 | elif data_name.lower() == "math-500": 710 | # For MATH-500, use the "problem" field as the question 711 |
question = example["problem"] 712 | else: 713 | for key in ["question", "problem", "Question", "input"]: 714 | if key in example: 715 | question = example[key] 716 | break 717 | 718 | # Add a hint for Yes/No style questions 719 | _, gt_ans = parse_ground_truth(example, data_name) 720 | if isinstance(gt_ans, str): 721 | gt_lower = gt_ans.lower() 722 | if gt_lower in ["true", "false"]: 723 | question += " (True or False)" 724 | if gt_lower in ["yes", "no"]: 725 | question += " (Yes or No)" 726 | return question.strip() 727 | 728 | 729 | def run_execute(executor, result, prompt_type, data_name, execute=False): 730 | if not result or result == "error": 731 | return None, None 732 | report = None 733 | 734 | if "program_only" in prompt_type: 735 | prediction = extract_program_output(result) 736 | elif prompt_type in ["pot", "pal"] and execute: 737 | code = extract_program(result) 738 | prediction, report = executor.apply(code) 739 | else: 740 | prediction = extract_answer(result, data_name) 741 | 742 | # prediction = strip_string(prediction, skip_unit=data_name == "carp_en") 743 | prediction = strip_string(prediction, skip_unit=data_name in STRIP_EXCEPTIONS) 744 | return prediction, report 745 | 746 | 747 | def _test_extract_answer(): 748 | text = """ 749 | This is still not equal to $0$, so we must have made another mistake. 750 | 751 | When we subtracted $7$ from $\frac{386}{64}$, we should have subtracted $7 \cdot 64$ from $386$, not the other way around. Let's correct that: 752 | 753 | \[\frac{386}{64} - 7 = \frac{386}{64} - \frac{7 \cdot 64}{1 \cdot 64} = \frac{386 - 448}{64} = \frac{-62}{64}.\] 754 | 755 | This is still not equal to $0$, so we must have made another mistake. 756 | 757 | When we subtracted $7$ from $\frac{386}{64}$, we should have subtracted $7 \cdot 64$ from $386$, not the other way around. Let's correct that: 758 | 759 | \[\frac{386}{64} 760 | """ 761 | print(extract_answer(text, "math-oai", use_last_number=False)) 762 | print(choice_answer_clean("\mathrm{(D)\}1,008,016")) 763 | # should output a dict 764 | 765 | 766 | if __name__ == "__main__": 767 | _test_extract_answer() 768 | --------------------------------------------------------------------------------
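# A short usage sketch for the extraction helpers in parser.py above, with a
# hypothetical model output; it assumes the code is run from src/applications/
# so that `parser` resolves to this file rather than anything else on the path:
#
#     from parser import extract_answer, choice_answer_clean
#
#     out = r"Therefore the final answer is \boxed{\dfrac{1}{2}}."
#     print(extract_answer(out, "math"))                    # "\frac{1}{2}" after strip_string
#     print(choice_answer_clean("The answer is (C) 42."))   # "C"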