├── assets
│   ├── fig1.png
│   ├── fig2.png
│   ├── fig3.png
│   ├── fig4.png
│   ├── thm1.png
│   └── thm2.png
├── src
│   ├── scripts
│   │   ├── compute_mi_trajectories.sh
│   │   ├── run_TTTS.sh
│   │   └── run_RR.sh
│   ├── applications
│   │   ├── data
│   │   │   └── DeepSeek-R1-Distill-Llama-8B.jsonl
│   │   ├── data_loader.py
│   │   ├── evaluate.py
│   │   ├── trajectory.py
│   │   ├── python_executor.py
│   │   ├── utils.py
│   │   ├── RR_model.py
│   │   ├── model_utils.py
│   │   ├── math_utils.py
│   │   ├── grader.py
│   │   ├── TTTS_evaluate.py
│   │   ├── RR_evaluate.py
│   │   └── parser.py
│   ├── mi_estimators.py
│   ├── CKA.py
│   ├── generate_gt_activation.py
│   ├── cal_mi.py
│   ├── generate_activation.py
│   └── plot_mi_peaks.ipynb
├── README.md
└── LICENSE
/assets/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChnQ/MI-Peaks/HEAD/assets/fig1.png
--------------------------------------------------------------------------------
/assets/fig2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChnQ/MI-Peaks/HEAD/assets/fig2.png
--------------------------------------------------------------------------------
/assets/fig3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChnQ/MI-Peaks/HEAD/assets/fig3.png
--------------------------------------------------------------------------------
/assets/fig4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChnQ/MI-Peaks/HEAD/assets/fig4.png
--------------------------------------------------------------------------------
/assets/thm1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChnQ/MI-Peaks/HEAD/assets/thm1.png
--------------------------------------------------------------------------------
/assets/thm2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChnQ/MI-Peaks/HEAD/assets/thm2.png
--------------------------------------------------------------------------------
/src/scripts/compute_mi_trajectories.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
2 |
3 |
4 | models=(
5 | deepseek-ai/DeepSeek-R1-Distill-Llama-8B
6 | )
7 |
8 | for model in "${models[@]}"; do
9 | # 1. generate representations
10 | echo generate reps on $model ...
11 | python generate_activation.py --model_path $model --layers 31 --dataset math_train_12k --sample_num 100
12 | python generate_gt_activation.py --model_path $model --layers 31 --dataset math_train_12k --sample_num 100
13 |
14 | # 2. compute mutual information
15 | echo compute mi on $model ...
16 | python cal_mi.py --gt_model $model --test_model $model --layers 31 --dataset math_train_12k --sample_num 100 &
17 |
18 | done
19 |
20 | wait  # let the backgrounded MI computation finish
--------------------------------------------------------------------------------
/src/scripts/run_TTTS.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd applications/
4 |
5 | model="your_dir/DeepSeek-R1-Distill-Llama-8B"
6 |
7 | model_name=$(basename "$model")
8 | dataset=aime24
9 |
10 | gpu_id=0
11 |
12 | token_budget=4096
13 |
14 | save_dir="results/${model_name}/${dataset}_budget${token_budget}/"
15 | mkdir -p "$save_dir"
16 |
17 |
18 | CUDA_VISIBLE_DEVICES=$gpu_id python3 -u TTTS_evaluate.py \
19 | --model_name_or_path $model \
20 | --data_names $dataset \
21 | --output_dir $save_dir \
22 | --split "test" \
23 | --prompt_type "deepseek-math" \
24 | --num_test_sample -1 \
25 | --seed 0 \
26 | --temperature 0 \
27 | --n_sampling 1 \
28 | --top_p 1 \
29 | --start 0 \
30 | --end -1 \
31 | --overwrite \
32 | --use_vllm \
33 | --thinking_tokens_file_path data/${model_name}.jsonl \
34 | --max_tokens_per_call 4096 \
35 | --token_budget $token_budget
36 |
--------------------------------------------------------------------------------
/src/scripts/run_RR.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | cd applications/
5 |
6 | PROMPT_TYPE="deepseek-math"
7 | SPLIT="test"
8 | NUM_TEST_SAMPLE=-1
9 | LOG_DIR=your_log/$(date +%m-%d_%H-%M)
10 | mkdir -p $LOG_DIR
11 |
12 |
13 | gpu_counter=0
14 |
15 |
16 | model_path='your_dir/DeepSeek-R1-Distill-Llama-8B'
17 |
18 | datasets=(
19 | "aime24"
20 | )
21 | ei_layers=(23)
22 |
23 |
24 | model_name=$(basename "$model_path")
25 |
26 | for (( j=0; j<${#datasets[@]}; j++ )); do
27 | dataset=${datasets[$j]}
28 |
29 | for (( k=0; k<${#ei_layers[@]}; k++ )); do
30 | ei_layer=${ei_layers[$k]}
31 | save_dir="scores/${model_name}/${ei_layer}/"
32 | responses_dir="responses/${model_name}/${dataset}"
33 | mkdir -p "$save_dir"
34 | mkdir -p "$responses_dir"
35 |
36 | echo "Launched: model=$model_name, dataset=$dataset, layer=$ei_layer, GPU=$gpu_counter"
37 |
38 | log_file="${LOG_DIR}/${model_name}_${dataset}_${ei_layer}.log"
39 | mkdir -p "$(dirname "$log_file")"
40 |
41 | CUDA_VISIBLE_DEVICES=$gpu_counter \
42 | python RR_evaluate.py \
43 | --model_name_or_path "$model_path" \
44 | --data_names "$dataset" \
45 | --inject_layer_id $ei_layer \
46 | --extract_layer_id $ei_layer \
47 | --num_test_sample "$NUM_TEST_SAMPLE" \
48 | --output_file "$responses_dir/${ei_layer}.jsonl"\
49 | --use_recursive_thinking True \
50 | --num_recursive_steps 1 \
51 | --output_dir "$save_dir" \
52 | --split "$SPLIT" \
53 | --prompt_type "$PROMPT_TYPE" \
54 | --seed 0 \
55 | --temperature 0 \
56 | --n_sampling 1 \
57 | --top_p 1 \
58 | --start 0 \
59 | --end -1 \
60 | --save_outputs \
61 | --overwrite \
62 | --interested_tokens_file_path data/${model_name}.jsonl \
63 | --max_tokens_per_call 16000 >> "$log_file" 2>&1 &
64 |
65 | ((gpu_counter++))
66 | done
67 | done
68 |
69 | # Wait for all background tasks to complete
70 | wait
71 | echo "All tasks completed!"
--------------------------------------------------------------------------------
/src/applications/data/DeepSeek-R1-Distill-Llama-8B.jsonl:
--------------------------------------------------------------------------------
1 | {"word": "So", "freq": 394, "token_ids": [2100, 4516, 78012]}
2 | {"word": "Let", "freq": 218, "token_ids": [6914, 10267]}
3 | {"word": "Hmm", "freq": 127, "token_ids": [81122, 89290]}
4 | {"word": " ", "freq": 118, "token_ids": []}
5 | {"word": "\n\n", "freq": 114, "token_ids": [271, 4815, 15152, 16176, 19124, 24356, 35033, 35249, 40965, 62098, 66367, 66768, 72348, 75625, 81923, 126595]}
6 | {"word": "I", "freq": 112, "token_ids": [40, 358, 25494]}
7 | {"word": "The", "freq": 89, "token_ids": [578, 791, 33026]}
8 | {"word": "Okay", "freq": 88, "token_ids": [33413, 36539]}
9 | {"word": "That", "freq": 81, "token_ids": [3011, 4897]}
10 | {"word": "First", "freq": 79, "token_ids": [5451, 5629]}
11 | {"word": "Now", "freq": 74, "token_ids": [4800, 7184]}
12 | {"word": "\\", "freq": 70, "token_ids": [59, 1144]}
13 | {"word": "Wait", "freq": 68, "token_ids": [14144, 14524]}
14 | {"word": "But", "freq": 64, "token_ids": [2030, 4071]}
15 | {"word": "-", "freq": 47, "token_ids": [12, 482]}
16 | {"word": "Then", "freq": 45, "token_ids": [5112, 12487]}
17 | {"word": "Since", "freq": 45, "token_ids": [8876, 12834]}
18 | {"word": "\\(", "freq": 44, "token_ids": [18240, 45392]}
19 | {"word": "Therefore", "freq": 35, "token_ids": [15636, 55915]}
20 | {"word": "=", "freq": 32, "token_ids": [28, 284]}
21 | {"word": "1", "freq": 32, "token_ids": [16]}
22 | {"word": "Maybe", "freq": 29, "token_ids": [10926, 22105]}
23 | {"word": "x", "freq": 28, "token_ids": [87, 865, 10436]}
24 | {"word": "\n", "freq": 27, "token_ids": [198, 319, 720, 1084, 1602, 1734, 1827, 2355, 2451, 2591, 3456, 4574, 4660, 5996, 6053, 6336, 6494, 6557, 7071, 7786, 9175, 10636, 10912, 11187, 12064, 12586, 12858, 16052, 16244, 16462, 16554, 17707, 17934, 18737, 19548, 20254, 20959, 25332, 26510, 27381, 27644, 28465, 28768, 29347, 31745, 33645, 34741, 35583, 37677, 38304, 38792, 39185, 39420, 39912, 40748, 41437, 42736, 46675, 46907, 47526, 52050, 52224, 52580, 53820, 56547, 57696, 59659, 63317, 65883, 66417, 67934, 68764, 69862, 70977, 72764, 72879, 73453, 74296, 76328, 79093, 79455, 80100, 85952, 86522, 87870, 89966, 90001, 90260, 91406, 91458, 92305, 95791, 96047, 97117, 98414, 99307, 99351, 106050, 120582, 122019]}
25 | {"word": "If", "freq": 25, "token_ids": [1442, 2746, 52792]}
26 | {"word": "To", "freq": 24, "token_ids": [1271, 2057]}
27 | {"word": "**", "freq": 24, "token_ids": [334, 3146]}
28 | {"word": "2", "freq": 20, "token_ids": [17]}
29 | {"word": "(", "freq": 20, "token_ids": [7, 320]}
30 | {"word": "5", "freq": 17, "token_ids": [20]}
31 |
--------------------------------------------------------------------------------
/src/mi_estimators.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import torch
4 | import numpy as np
5 |
6 |
7 | def distmat(X):
8 | """ distance matrix
9 | """
10 | if len(X.shape) == 1:
11 | X = X.view(-1, 1)
12 | r = torch.sum(X * X, 1)
13 | r = r.view([-1, 1])
14 | a = torch.mm(X, torch.transpose(X, 0, 1))
15 | D = r.expand_as(a) - 2 * a + torch.transpose(r, 0, 1).expand_as(a)
16 | D = torch.abs(D)
17 | return D
18 |
19 |
20 | def sigma_estimation(X, Y):
21 | """ sigma from median distance
22 | """
23 | D = distmat(torch.cat([X, Y]))
24 | D = D.detach().cpu().numpy()
25 | Itri = np.tril_indices(D.shape[0], -1)
26 | Tri = D[Itri]
27 | med = np.median(Tri)
28 | if med <= 0:
29 | med = np.mean(Tri)
30 | if med < 1E-2:
31 | med = 1E-2
32 | return med
33 |
34 |
35 | def kernelmat(X, sigma, ktype='gaussian'):
36 | """ kernel matrix baker
37 | """
38 | if len(X.shape) == 1:
39 | X = X.view(-1, 1)
40 |
41 | m = int(X.size()[0])
42 | H = torch.eye(m) - (1. / m) * torch.ones([m, m])
43 |
44 | if ktype == "gaussian":
45 | Dxx = distmat(X)
46 |
47 | if sigma:
48 | variance = 2. * sigma * sigma * X.size()[1]
49 | Kx = torch.exp(-Dxx / variance).type(torch.FloatTensor) # kernel matrices
50 | # print(sigma, torch.mean(Kx), torch.max(Kx), torch.min(Kx))
51 | else:
52 | try:
53 | sx = sigma_estimation(X, X)
54 | Kx = torch.exp(-Dxx / (2. * sx * sx)).type(torch.FloatTensor)
55 | except RuntimeError as e:
56 | raise RuntimeError("Unstable sigma {} with maximum/minimum input ({},{})".format(
57 | sx, torch.max(X), torch.min(X)))
58 |
59 |
60 | elif ktype == "linear":
61 | Kx = torch.mm(X, X.T).type(torch.FloatTensor)
62 |
63 | elif ktype == 'IMQ':
64 | Dxx = distmat(X)
65 | Kx = 1 * torch.rsqrt(Dxx + 1)
66 |
67 | Kxc = torch.mm(Kx, H)
68 |
69 | return Kxc
70 |
71 |
72 | def hsic_normalized_cca(x, y, sigma=50., ktype='gaussian'):
73 | if len(x.shape) == 1:
74 | x = x.reshape(-1, 1)
75 | if len(y.shape) == 1:
76 | y = y.reshape(-1, 1)
77 | # x = torch.from_numpy(x)
78 | # y = torch.from_numpy(y)
79 |
80 | m = int(x.size()[0])
81 | Kxc = kernelmat(x, sigma=sigma, ktype=ktype)
82 | Kyc = kernelmat(y, sigma=sigma, ktype=ktype)
83 |
84 | epsilon = 1E-5
85 | K_I = torch.eye(m)
86 | Kxc_i = torch.inverse(Kxc + epsilon * m * K_I)
87 | Kyc_i = torch.inverse(Kyc + epsilon * m * K_I)
88 | Rx = (Kxc.mm(Kxc_i))
89 | Ry = (Kyc.mm(Kyc_i))
90 | Pxy = torch.sum(torch.mul(Rx, Ry.t()))
91 |
92 | return Pxy
93 |
94 |
95 | def estimate_mi_hsic(x, y, ktype='gaussian', sigma=50.):
96 | estimate_IXY = hsic_normalized_cca(x, y, ktype=ktype, sigma=sigma)
97 | return estimate_IXY
--------------------------------------------------------------------------------
/src/CKA.py:
--------------------------------------------------------------------------------
1 | # python ablation_cka.py --dataset pku-rlhf-10k --base_model llama-2-7b-chat
2 | # python ablation_cka.py --dataset bigcodebench_complete_prompts --base_model codellama-7b
3 | # python ablation_cka.py --dataset mmlu_all_test_questions --base_model llama-2-7b-chat
4 |
5 | import os
6 | import argparse
7 | # from model import Projection, MLP
8 | # from generate_head_activations import load_acts
9 | import numpy as np
10 | import math
11 | import pickle
12 | import torch
13 | import torch.optim as optim
14 |
15 |
16 | class CKA(object):
17 | def __init__(self):
18 | pass
19 |
20 | def centering(self, K):
21 | n = K.shape[0]
22 | unit = np.ones([n, n])
23 | I = np.eye(n)
24 | H = I - unit / n
25 | return np.dot(np.dot(H, K), H)
26 |
27 | def rbf(self, X, sigma=None):
28 | GX = np.dot(X, X.T)
29 | KX = np.diag(GX) - GX + (np.diag(GX) - GX).T
30 | if sigma is None:
31 | mdist = np.median(KX[KX != 0])
32 | sigma = math.sqrt(mdist)
33 | KX *= - 0.5 / (sigma * sigma)
34 | KX = np.exp(KX)
35 | return KX
36 |
37 | def kernel_HSIC(self, X, Y, sigma):
38 | return np.sum(self.centering(self.rbf(X, sigma)) * self.centering(self.rbf(Y, sigma)))
39 |
40 | def linear_HSIC(self, X, Y):
41 | L_X = X @ X.T
42 | L_Y = Y @ Y.T
43 | return np.sum(self.centering(L_X) * self.centering(L_Y))
44 |
45 | def linear_CKA(self, X, Y):
46 | hsic = self.linear_HSIC(X, Y)
47 | var1 = np.sqrt(self.linear_HSIC(X, X))
48 | var2 = np.sqrt(self.linear_HSIC(Y, Y))
49 |
50 | return hsic / (var1 * var2)
51 |
52 | def kernel_CKA(self, X, Y, sigma=None):
53 | hsic = self.kernel_HSIC(X, Y, sigma)
54 | var1 = np.sqrt(self.kernel_HSIC(X, X, sigma))
55 | var2 = np.sqrt(self.kernel_HSIC(Y, Y, sigma))
56 |
57 | return hsic / (var1 * var2)
58 |
59 |
60 | class CudaCKA(object):
61 | def __init__(self, device):
62 | self.device = device
63 |
64 | def centering(self, K):
65 | n = K.shape[0]
66 | unit = torch.ones([n, n], device=self.device)
67 | I = torch.eye(n, device=self.device)
68 | H = I - unit / n
69 | return torch.matmul(torch.matmul(H, K), H)
70 |
71 | def rbf(self, X, sigma=None):
72 | GX = torch.matmul(X, X.T)
73 | KX = torch.diag(GX) - GX + (torch.diag(GX) - GX).T
74 | if sigma is None:
75 | mdist = torch.median(KX[KX != 0])
76 | sigma = math.sqrt(mdist)
77 | KX *= - 0.5 / (sigma * sigma)
78 | KX = torch.exp(KX)
79 | return KX
80 |
81 | def kernel_HSIC(self, X, Y, sigma):
82 | return torch.sum(self.centering(self.rbf(X, sigma)) * self.centering(self.rbf(Y, sigma)))
83 |
84 | def linear_HSIC(self, X, Y):
85 | L_X = torch.matmul(X, X.T)
86 | L_Y = torch.matmul(Y, Y.T)
87 | return torch.sum(self.centering(L_X) * self.centering(L_Y))
88 |
89 | def linear_CKA(self, X, Y):
90 | hsic = self.linear_HSIC(X, Y)
91 | var1 = torch.sqrt(self.linear_HSIC(X, X))
92 | var2 = torch.sqrt(self.linear_HSIC(Y, Y))
93 |
94 | return hsic / (var1 * var2)
95 |
96 | def kernel_CKA(self, X, Y, sigma=None):
97 | hsic = self.kernel_HSIC(X, Y, sigma)
98 | var1 = torch.sqrt(self.kernel_HSIC(X, X, sigma))
99 | var2 = torch.sqrt(self.kernel_HSIC(Y, Y, sigma))
100 | return hsic / (var1 * var2)
101 |
102 |
103 |
--------------------------------------------------------------------------------
/src/generate_gt_activation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import argparse
4 | import pandas as pd
5 | from transformers import AutoTokenizer, AutoModelForCausalLM
6 | from tqdm import tqdm
7 |
8 |
9 | class Hook:
10 | def __init__(self, token_position=-1):
11 | self.token_position = token_position
12 | self.tokens_embeddings = []
13 |
14 | def __call__(self, module, module_inputs, module_outputs):
15 | # output: [batch, seq_len, hidden_size]
16 | hidden_states = module_outputs[0] if isinstance(module_outputs, tuple) else module_outputs
17 | emb = hidden_states[0, self.token_position].detach().cpu()
18 |
19 | self.tokens_embeddings.append(emb)
20 |
21 |
22 | def load_model(model_path):
23 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
24 | model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, device_map="auto")
25 |
26 | return tokenizer, model
27 |
28 |
29 | def get_acts(query_list, tokenizer, model, layers, token_pos=-1):
30 | """
31 | Get given layer activations for the statements.
32 | Return dictionary of stacked activations.
33 |
34 | token_pos: default to fetch the last token's activations
35 | """
36 | # attach hooks
37 | hooks, handles = [], []
38 | for layer in layers:
39 | hook = Hook(token_position=token_pos)
40 | handle = model.model.layers[layer].register_forward_hook(hook)
41 | hooks.append(hook), handles.append(handle)
42 |
43 | # get activations
44 | acts = {id: {layer: [] for layer in layers} for id in range(len(query_list))}
45 |
46 | for id, query in tqdm(enumerate(query_list), total=len(query_list), desc="Processing Queries"):
47 |
48 | input_ids = tokenizer.encode(query, return_tensors="pt")
49 | with torch.no_grad():
50 | model(input_ids)
51 | for layer, hook in zip(layers, hooks):
52 | acts[id][layer] = hook.tokens_embeddings
53 |
54 | for hook in hooks:
55 | hook.tokens_embeddings = []
56 |
57 | for id, layer_acts in acts.items():
58 | for layer, emb in layer_acts.items():
59 | layer_acts[layer] = torch.stack(emb).float()
60 |
61 | # remove hooks
62 | for handle in handles:
63 | handle.remove()
64 |
65 | return acts
66 |
67 |
68 | def main():
69 | parser = argparse.ArgumentParser(description="Generate gt activations for statements in a dataset")
70 | parser.add_argument("--model_path", default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B")
71 | parser.add_argument("--layers", nargs='+', help="Layers to save embeddings from")
72 | parser.add_argument("--dataset", default='math_train_12k')
73 | parser.add_argument("--sample_num", type=int, default=100)
74 | parser.add_argument("--output_dir", default="acts", help="Directory to save activations to")
75 | args = parser.parse_args()
76 |
77 | dataset = args.dataset
78 |
79 | math_data = pd.read_csv(f'data/{dataset}.csv')
80 |
81 | solution_list = math_data['solution'].tolist()[:args.sample_num]
82 |
83 | tokenizer, model = load_model(args.model_path)
84 |
85 | layers = [int(layer) for layer in args.layers]
86 | if layers == [-1]:
87 | layers = list(range(len(model.model.layers)))
88 |
89 |     acts = get_acts(solution_list, tokenizer, model, layers, token_pos=-1)
90 |
91 | # save representations
92 | os.makedirs(f'{args.output_dir}/gt', exist_ok=True)
93 | torch.save(acts, f"{args.output_dir}/gt/{dataset}_{args.model_path.split('/')[-1]}.pth")
94 |
95 |
96 | if __name__=='__main__':
97 | main()
--------------------------------------------------------------------------------
/src/applications/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | import datasets
5 | from datasets import load_dataset, Dataset, concatenate_datasets
6 | from utils import load_jsonl, lower_keys
7 |
8 |
9 | def load_data(data_name, split, data_dir="./data"):
10 | data_file = f"{data_dir}/{data_name}/{split}.jsonl"
11 | if os.path.exists(data_file):
12 | examples = list(load_jsonl(data_file))
13 | else:
14 | if data_name == "math":
15 | dataset = load_dataset(
16 | "competition_math",
17 | split=split,
18 | name="main",
19 | cache_dir=f"{data_dir}/temp",
20 | )
21 | elif data_name == "MATH-500":
22 | dataset = load_dataset(data_name, split=split)
23 | elif data_name == "gsm8k":
24 | dataset = load_dataset(data_name, split=split)
25 | elif data_name == "svamp":
26 | # evaluate on training set + test set
27 | dataset = load_dataset("ChilleD/SVAMP", split="train")
28 | dataset = concatenate_datasets(
29 | [dataset, load_dataset("ChilleD/SVAMP", split="test")]
30 | )
31 | elif data_name == "asdiv":
32 | dataset = load_dataset("EleutherAI/asdiv", split="validation")
33 | dataset = dataset.filter(
34 | lambda x: ";" not in x["answer"]
35 | ) # remove multi-answer examples
36 | elif data_name == "mawps":
37 | examples = []
38 | # four sub-tasks
39 | for data_name in ["singleeq", "singleop", "addsub", "multiarith"]:
40 | sub_examples = list(load_jsonl(f"{data_dir}/mawps/{data_name}.jsonl"))
41 | for example in sub_examples:
42 | example["type"] = data_name
43 | examples.extend(sub_examples)
44 | dataset = Dataset.from_list(examples)
45 | elif data_name == "mmlu_stem":
46 | dataset = load_dataset("hails/mmlu_no_train", "all", split="test")
47 | # only keep stem subjects
48 | stem_subjects = [
49 | "abstract_algebra",
50 | "astronomy",
51 | "college_biology",
52 | "college_chemistry",
53 | "college_computer_science",
54 | "college_mathematics",
55 | "college_physics",
56 | "computer_security",
57 | "conceptual_physics",
58 | "electrical_engineering",
59 | "elementary_mathematics",
60 | "high_school_biology",
61 | "high_school_chemistry",
62 | "high_school_computer_science",
63 | "high_school_mathematics",
64 | "high_school_physics",
65 | "high_school_statistics",
66 | "machine_learning",
67 | ]
68 | dataset = dataset.rename_column("subject", "type")
69 | dataset = dataset.filter(lambda x: x["type"] in stem_subjects)
70 | elif data_name == "carp_en":
71 | dataset = load_jsonl(f"{data_dir}/carp_en/test.jsonl")
72 | else:
73 | raise NotImplementedError(data_name)
74 |
75 | examples = list(dataset)
76 | examples = [lower_keys(example) for example in examples]
77 | dataset = Dataset.from_list(examples)
78 | os.makedirs(f"{data_dir}/{data_name}", exist_ok=True)
79 | dataset.to_json(data_file)
80 |
81 | # add 'idx' in the first column
82 | if "idx" not in examples[0]:
83 | examples = [{"idx": i, **example} for i, example in enumerate(examples)]
84 |
85 |     # deduplicate & sort
86 | examples = sorted(examples, key=lambda x: x["idx"])
87 | return examples
88 |
--------------------------------------------------------------------------------
/src/cal_mi.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | import torch.nn.functional as F
5 | from tqdm import tqdm
6 | from mi_estimators import estimate_mi_hsic
7 |
8 |
9 | def calculate_mi(acts, gt_acts, layers=[], num_samples=-1, save_dir='results/mi/', args=None):
10 |
11 | num_samples = len(acts) if num_samples < 0 else num_samples
12 | num_layers = len(acts[0]['reps'])
13 |
14 | mi_list = []
15 |
16 | all_mi_matrices_list = []
17 |
18 | try:
19 | final_mi_dict = torch.load(f'{save_dir}/{args.dataset}_gtmodel={args.gt_model}_testmodel={args.test_model}.pth')
20 | except:
21 | final_mi_dict = {
22 | id: {
23 | 'reps': {layer: [] for layer in range(num_layers)},
24 | 'total_tokens': -1
25 | }
26 | for id in range(num_samples)
27 | }
28 |
29 | for id in tqdm(range(0, num_samples)): # for each query
30 |
31 | if final_mi_dict[id]['total_tokens'] > 0:
32 | print(f'id {id} has been computed, skip.')
33 | continue
34 |
35 | final_mi_dict[id]['total_tokens'] = acts[id]['token_ids'].shape[0]
36 |
37 |
38 | if len(layers) == 0:
39 | layers = [layer for layer in acts[id]['reps'].keys()]
40 |
41 | layer_mi_matrix = []
42 | for layer in layers:
43 | here_num_tokens = acts[id]['reps'][layer].shape[0]
44 | layer_mi_list = torch.zeros(here_num_tokens)
45 |
46 | for i in range(here_num_tokens):
47 | layer_mi_list[i] = estimate_mi_hsic(acts[id]['reps'][layer][i], gt_acts[id][layer][0])
48 | print(f'loop {i} finished!')
49 |
50 | layer_mi_matrix.append(layer_mi_list)
51 | final_mi_dict[id]['reps'][layer] = layer_mi_list
52 |
53 |
54 | print(f'[id={id}, layer={layer}] layer_mi_list:', layer_mi_list)
55 | print('total_tokens:', acts[id]['token_ids'].shape[0])
56 |
57 | all_mi_matrices_list.append(layer_mi_matrix)
58 |
59 | print(f'id {id} mi_matrix computation finished!')
60 |
61 | # save
62 | os.makedirs(save_dir, exist_ok=True)
63 | torch.save(final_mi_dict, f'{save_dir}/{args.dataset}_gtmodel={args.gt_model}_testmodel={args.test_model}.pth')
64 |
65 |
66 | return final_mi_dict
67 |
68 |
69 |
70 | def load_reps(dataset_name, model_tag, is_gt=False, step_level=False):
71 | print(f'Loading activations of model [{model_tag}], on dataset [{dataset_name}]...')
72 |
73 | if is_gt:
74 | return torch.load(f"acts/gt/{dataset_name}_{model_tag}.pth")
75 | else:
76 | return torch.load(f"acts/reasoning_evolve/{dataset_name}_{model_tag}.pth")
77 |
78 |
79 | def main():
80 | parser = argparse.ArgumentParser(description="Generate activations for statements in a dataset")
81 |
82 | parser.add_argument("--gt_model", default='deepseek-ai/DeepSeek-R1-Distill-Llama-8B')
83 | parser.add_argument("--test_model", default='deepseek-ai/DeepSeek-R1-Distill-Llama-8B')
84 | parser.add_argument("--dataset", default='math_train_12k')
85 | parser.add_argument("--layers", nargs='*', type=int, default=[])
86 |
87 | parser.add_argument("--sample_num", type=int, default=100)
88 |
89 |
90 | args = parser.parse_args()
91 |
92 | sample_num = args.sample_num
93 |
94 |
95 | dataset_name = args.dataset
96 |
97 | gt_model = args.gt_model.split('/')[-1]
98 | test_model = args.test_model.split('/')[-1]
99 | layers = args.layers
100 |
101 | acts = load_reps(dataset_name=dataset_name, model_tag=test_model)
102 | gt_acts = load_reps(dataset_name=dataset_name, model_tag=gt_model, is_gt=True)
103 |
104 |
105 | save_dir = f'results/mi'
106 | os.makedirs(save_dir, exist_ok=True)
107 |
108 | final_mi_dict = calculate_mi(acts=acts, gt_acts=gt_acts, layers=layers, num_samples=args.sample_num, save_dir=save_dir, args=args)
109 |
110 |
111 | torch.save(final_mi_dict, f'{save_dir}/{dataset_name}_gtmodel={gt_model}_testmodel={test_model}.pth')
112 |
113 |
114 | if __name__=='__main__':
115 | main()
116 |
117 |
--------------------------------------------------------------------------------
/src/generate_activation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import pandas as pd
4 | from transformers import AutoTokenizer, AutoModelForCausalLM
5 | from tqdm import tqdm
6 | import os
7 |
8 | class Hook:
9 | def __init__(self, token_position=-1):
10 | self.token_position = token_position
11 | self.tokens_embeddings = []
12 |
13 | def __call__(self, module, module_inputs, module_outputs):
14 | # output: [batch, seq_len, hidden_size]
15 | hidden_states = module_outputs[0] if isinstance(module_outputs, tuple) else module_outputs
16 | emb = hidden_states[0, self.token_position].detach().cpu()
17 |
18 | self.tokens_embeddings.append(emb)
19 |
20 |
21 | def load_model(model_path):
22 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
23 | model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, device_map="auto")
24 |
25 | return tokenizer, model
26 |
27 |
28 | def get_acts(query_list, tokenizer, model, layers, token_pos=-1):
29 | """
30 | Get given layer activations for the statements.
31 | Return dictionary of stacked activations.
32 |
33 | token_pos: default to fetch the last token's activations
34 | """
35 | # attach hooks
36 | hooks, handles = [], []
37 | for layer in layers:
38 | hook = Hook(token_position=token_pos)
39 | handle = model.model.layers[layer].register_forward_hook(hook)
40 | hooks.append(hook), handles.append(handle)
41 |
42 | # get activations
43 | acts = {
44 | id: {
45 | 'reps': {layer: [] for layer in layers},
46 | 'token_ids': []
47 | }
48 | for id in range(len(query_list))
49 | }
50 |
51 |
52 | for id, query in tqdm(enumerate(query_list), total=len(query_list), desc="Processing Queries"):
53 |
54 | input_ids = tokenizer.encode(query, return_tensors="pt")
55 | with torch.no_grad():
56 | outputs = model.generate(
57 | input_ids,
58 | max_new_tokens=512,
59 | do_sample=False, # greedy
60 | return_dict_in_generate=True,
61 | output_hidden_states=True
62 | )
63 |
64 | response = tokenizer.batch_decode(outputs[0][:, input_ids.shape[1]:-1])[0].strip()
65 |
66 | for layer, hook in zip(layers, hooks):
67 | acts[id]['reps'][layer] = hook.tokens_embeddings
68 |
69 | acts[id]['token_ids'] = outputs[0][:, input_ids.shape[1]:-1].squeeze().cpu()
70 |
71 | for hook in hooks:
72 | hook.tokens_embeddings = []
73 |
74 | for id, layer_acts in acts.items():
75 | for layer, emb in layer_acts['reps'].items():
76 | layer_acts['reps'][layer] = torch.stack(emb).float()
77 |
78 |
79 | # remove hooks
80 | for handle in handles:
81 | handle.remove()
82 |
83 | return acts
84 |
85 |
86 | def main():
87 | parser = argparse.ArgumentParser(description="Generate activations for statements in a dataset")
88 | parser.add_argument("--model_path", default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B")
89 | parser.add_argument("--layers", nargs='+', help="Layers to save embeddings from")
90 | parser.add_argument("--dataset", default='math_train_12k')
91 | parser.add_argument("--sample_num", type=int, default=100)
92 | parser.add_argument("--output_dir", default="acts", help="Directory to save activations to")
93 | args = parser.parse_args()
94 |
95 | dataset = args.dataset
96 |
97 | math_data = pd.read_csv(f'data/{dataset}.csv')
98 |
99 | query_list = math_data['problem'].tolist()[:args.sample_num]
100 |
101 | tokenizer, model = load_model(args.model_path)
102 |
103 | layers = [int(layer) for layer in args.layers]
104 | if layers == [-1]:
105 | layers = list(range(len(model.model.layers)))
106 |
107 | acts = get_acts(query_list, tokenizer, model, layers, token_pos=-1)
108 |
109 | # save representations
110 | os.makedirs(f'{args.output_dir}/reasoning_evolve/', exist_ok=True)
111 | torch.save(acts, f"{args.output_dir}/reasoning_evolve/{dataset}_{args.model_path.split('/')[-1]}.pth")
112 |
113 |
114 | if __name__=='__main__':
115 | main()
116 |
117 |
--------------------------------------------------------------------------------
/src/applications/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | from tqdm import tqdm
4 | from pebble import ProcessPool
5 | from concurrent.futures import TimeoutError
6 |
7 | from grader import *
8 |
9 | from parser import *
10 | from utils import load_jsonl
11 | from python_executor import PythonExecutor
12 |
13 |
14 | def evaluate(data_name, prompt_type, samples: list=None, file_path: str=None, max_num_samples=None, execute=False):
15 | assert samples or file_path, "samples or file_path must be provided"
16 | if not samples:
17 | samples = list(load_jsonl(file_path))
18 | if 'idx' in samples[0]:
19 | samples = {sample['idx']: sample for sample in samples}.values()
20 | samples = sorted(samples, key=lambda x: x['idx'])
21 | else:
22 | samples = [dict(idx=idx, **sample) for idx, sample in enumerate(samples)]
23 |
24 | if max_num_samples:
25 | print(f"max_num_samples: {max_num_samples} / {len(samples)}")
26 | samples = samples[:max_num_samples]
27 |
28 | # parse gt
29 | for sample in samples:
30 | sample['gt_cot'], sample['gt'] = parse_ground_truth(sample, data_name)
31 | params = [(idx, pred, sample['gt']) for idx, sample in enumerate(samples) for pred in sample['pred']]
32 |
33 | scores = []
34 | timeout_cnt = 0
35 |
36 | with ProcessPool(max_workers=1) as pool:
37 | future = pool.map(math_equal_process, params, timeout=3)
38 | iterator = future.result()
39 | with tqdm(total=len(samples), desc="Evaluate") as progress_bar:
40 | while True:
41 | try:
42 | result = next(iterator)
43 | scores.append(result)
44 | except StopIteration:
45 | break
46 | except TimeoutError as error:
47 | print(error)
48 | scores.append(False)
49 | timeout_cnt += 1
50 | except Exception as error:
51 | print(error.traceback)
52 | exit()
53 | progress_bar.update(1)
54 |
55 | idx = 0
56 | score_mat = []
57 | for sample in samples:
58 | sample['score'] = scores[idx: idx+len(sample['pred'])]
59 | assert len(sample['score']) == len(sample['pred'])
60 | score_mat.append(sample['score'])
61 | idx += len(sample['pred'])
62 |
63 | max_len = max([len(s) for s in score_mat])
64 |
65 | for i, s in enumerate(score_mat):
66 | if len(s) < max_len:
67 | score_mat[i] = s + [s[-1]] * (max_len - len(s)) # pad
68 |
69 | # output mean of each column of scores
70 | col_means= np.array(score_mat).mean(axis=0)
71 | mean_score = list(np.round(col_means * 100, decimals=1))
72 |
73 | result_json = {
74 | "num_samples": len(samples),
75 | "num_scores": len(scores),
76 | "timeout_samples": timeout_cnt,
77 | "empty_samples": len([s for s in samples if not s['pred'][-1]]),
78 | "acc": mean_score[0]
79 | }
80 |
81 | # each type score
82 | if "type" in samples[0]:
83 | type_scores = {}
84 | for sample in samples:
85 | if sample['type'] not in type_scores:
86 | type_scores[sample['type']] = []
87 | type_scores[sample['type']].append(sample['score'][-1])
88 | type_scores = {k: np.round(np.array(v).mean() * 100, decimals=1) for k, v in type_scores.items()}
89 | type_scores = {k: v for k, v in sorted(type_scores.items(), key=lambda item: item[0])}
90 | result_json['type_acc'] = type_scores
91 |
92 | print(result_json)
93 | return samples, result_json
94 |
95 |
96 | def parse_args():
97 | parser = argparse.ArgumentParser()
98 | parser.add_argument("--data_name", type=str, default="math")
99 | parser.add_argument("--prompt_type", type=str, default="tool-integrated")
100 | parser.add_argument("--file_path", type=str, default=None, required=True)
101 | parser.add_argument("--max_num_samples", type=int, default=None)
102 | parser.add_argument("--execute", action="store_true")
103 | args = parser.parse_args()
104 | return args
105 |
106 | if __name__ == "__main__":
107 | args = parse_args()
108 | evaluate(data_name=args.data_name, prompt_type=args.prompt_type, file_path=args.file_path,
109 | max_num_samples=args.max_num_samples, execute=args.execute)
110 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Demystifying Reasoning Dynamics with Mutual Information: Thinking Tokens are Information Peaks in LLM Reasoning
2 |
3 | 📢 If you are interested in our work, please star ⭐ our project.
4 |
16 | ## 🌈 Introduction
17 |
18 |
19 | 
20 |
21 | Large reasoning models (LRMs) have demonstrated impressive capabilities in
22 | complex problem-solving, yet their internal reasoning mechanisms remain poorly
23 | understood. In this paper, we investigate the reasoning trajectories of LRMs from
24 | an information-theoretic perspective. By tracking how mutual information (MI)
25 | between intermediate representations and the correct answer evolves during LRM
26 | reasoning, we observe an interesting *MI peaks* phenomenon: **the MI at specific
27 | generative steps exhibits a sudden and significant increase during the LRM’s
28 | reasoning process**. We theoretically analyze this phenomenon and show that as
29 | MI increases, the probability of the model’s prediction error decreases. Furthermore,
30 | **these MI peaks often correspond to tokens expressing reflection or transition, such as “Hmm”, “Wait”, and “Therefore”**, which we term thinking
31 | tokens. We then demonstrate that these thinking tokens are crucial for the LRM’s
32 | reasoning performance, while other tokens have minimal impact. Building on
33 | these analyses, we propose two simple yet effective methods to improve LRMs’
34 | reasoning performance by delicately leveraging these thinking tokens. Overall, our work provides novel insights into the reasoning mechanisms of LRMs and offers
35 | practical ways to improve their reasoning capabilities.
36 |
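As a rough sketch of the quantity being tracked (assuming the default `acts/` output paths produced by `src/generate_activation.py` and `src/generate_gt_activation.py`; the sample, layer, and step indices below are purely illustrative), the per-step MI estimate reduces to one call to the HSIC-based estimator in `src/mi_estimators.py`:

```python
# Minimal sketch, run from src/ after the representations have been collected.
# This is the same per-token computation that cal_mi.py loops over (it is not cheap).
import torch
from mi_estimators import estimate_mi_hsic

model_tag = "DeepSeek-R1-Distill-Llama-8B"
acts = torch.load(f"acts/reasoning_evolve/math_train_12k_{model_tag}.pth")  # per-step hidden states
gt_acts = torch.load(f"acts/gt/math_train_12k_{model_tag}.pth")            # gold-solution hidden states

sample_id, layer, step = 0, 31, 0                # illustrative indices
x = acts[sample_id]["reps"][layer][step]         # representation at one generative step
y = gt_acts[sample_id][layer][0]                 # representation of the ground-truth solution
print(estimate_mi_hsic(x, y))                    # HSIC-based MI estimate for this step
```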
37 |
38 | ## 🚩Main Analyses
39 |
40 |
43 |
44 | **Certain steps exhibit sudden and significant increases in MI during the reasoning process of LRMs, and these MI peaks are sparse and distributed non-uniformly.**
45 |
46 |
47 |
48 |
49 |
50 | ---
51 |
52 |
55 |
56 | **Theoretical Insights: Higher MI Leads to Tighter Bounds on Prediction Error.**
57 |
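The exact statements and proofs are given in the paper (see `assets/thm1.png` and `assets/thm2.png`). As a purely classical illustration of why larger MI constrains prediction error (not the paper's bound), Fano's inequality gives, for a discrete target Y and any predictor built from a representation X,

$$
\Pr[\hat{Y} \neq Y] \ \ge\ \frac{H(Y) - I(X;Y) - 1}{\log |\mathcal{Y}|},
$$

so the information-theoretic lower bound on the error probability decreases as I(X;Y) grows.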
58 |
59 |
60 |
61 |
62 |
63 |
64 | **Non-reasoning LLMs exhibit weaker and less pronounced MI peaks compared to LRMs, and the overall MI in non-reasoning LLMs during the reasoning process is lower than in their
65 | corresponding LRMs.**
66 |
67 |
68 |
69 |
70 |
71 |
72 | **The tokens that appear at MI peaks are mostly connective words that express self-reflection or
73 | transitions in the LRM’s reasoning process.**
74 |
75 |
76 |
77 |
78 | ## 🚀Quick Start
79 |
80 |
81 | ### 🔧Requirements
82 |
83 | The following packages are required to run the code:
84 |
85 | - python==3.11.5
86 |
87 | - pytorch==2.1.2
88 |
89 | - transformers==4.46.1
90 |
91 | - numpy==1.26.4
92 |
93 | ### 🌟Usage
94 |
95 | ```bash
96 | cd src/
97 | ```
98 |
99 | **1. Collect the representations and compute the MI**
100 |
101 | ```bash
102 | sh scripts/compute_mi_trajectories.sh
103 | ```
104 |
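The script saves per-token MI trajectories to `results/mi/` (one `.pth` dictionary per model/dataset pair, as written by `cal_mi.py`). A minimal sketch of inspecting that file, with an illustrative top-20 cutoff for peak positions (the notebook in step 2 does this plotting properly):

```python
# Minimal sketch: list the highest-MI generative steps for one sample.
import torch

model_tag = "DeepSeek-R1-Distill-Llama-8B"
mi = torch.load(f"results/mi/math_train_12k_gtmodel={model_tag}_testmodel={model_tag}.pth")

sample_id, layer, top_k = 0, 31, 20                      # illustrative choices
trajectory = mi[sample_id]["reps"][layer]                # per-token MI values for this sample
peak_steps = sorted(range(len(trajectory)),
                    key=lambda i: trajectory[i], reverse=True)[:top_k]
print(peak_steps)                                        # candidate MI-peak positions
```
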
105 | **2. Plot figures to observe the MI Peaks phenomenon**
106 |
107 | Open and run the notebook `plot_mi_peaks.ipynb` (e.g., in Jupyter).
110 |
111 |
112 | **3. Run the Representation Recycling (RR)**
113 |
114 | ```bash
115 | sh scripts/run_RR.sh
116 | ```
117 |
118 | ## 📝License
119 | Distributed under the Apache-2.0 License. See LICENSE for more information.
120 |
121 |
122 | ## Acknowledgements
123 |
124 | Some code in this project is adapted from resources provided by the following repositories:
125 |
126 | - https://github.com/QwenLM/Qwen2.5-Math
127 | - https://github.com/ChnQ/TracingLLM
128 |
129 | We greatly appreciate the contributions of the original authors.
130 |
131 | ## 📖BibTeX
132 | ```
133 | @article{qian2025demystifying,
134 | title={Demystifying Reasoning Dynamics with Mutual Information: Thinking Tokens are Information Peaks in LLM Reasoning},
135 | author={Qian, Chen and Liu, Dongrui and Wen, Haochen and Bai, Zhen and Liu, Yong and Shao, Jing},
136 | journal={arXiv preprint arXiv:2506.02867},
137 | year={2025}
138 | }
139 | ```
--------------------------------------------------------------------------------
/src/applications/trajectory.py:
--------------------------------------------------------------------------------
1 | import re
2 | """
3 | trajectory:
4 | [
5 | {"role": "rationale", "content": "..."},
6 | {"role": "program", "content": "..."},
7 | {"role": "output", "content": "..."},
8 | {"role": "rationale", "content": "..."},
9 | ...
10 | ]
11 | """
12 |
13 | def text_to_trajectory(traj_str: str) -> list:
14 |     """
15 |     Parse an interleaved string of rationale, program, output, rationale, ...
16 |     into a list of {"role": ..., "content": ...} dicts.
17 |     """
18 | trajectory = []
19 | cur_role = "rationale"
20 | cur_content = ""
21 |
22 | # print(traj_str)
23 | for i, line in enumerate(traj_str.split("\n")):
24 | if line == "```python": # program begin
25 | assert cur_role == "rationale"
26 | if cur_content:
27 | trajectory.append({"role": cur_role, "content": cur_content})
28 | cur_content = ""
29 | cur_role = "program"
30 | elif cur_role == "program" and line == "```": # program end
31 | assert cur_content
32 | trajectory.append({"role": cur_role, "content": cur_content})
33 | cur_content = ""
34 | cur_role = "output"
35 | elif cur_role == "output" and line.startswith("```output"): # output begin
36 | assert cur_content == ""
37 | elif cur_role == "output" and line == "```": # output end
38 | trajectory.append({"role": cur_role, "content": cur_content})
39 | cur_content = ""
40 | cur_role = "rationale"
41 | else: # content
42 | cur_content += line
43 | if i < len(traj_str.split("\n")) - 1:
44 | cur_content += "\n"
45 | # the last content
46 | if cur_content:
47 | trajectory.append({"role": cur_role, "content": cur_content})
48 | return trajectory
49 |
50 |
51 | def trajectory_to_text(trajectory: list) -> str:
52 | text = ""
53 | for item in trajectory:
54 | content = item["content"]
55 | if item["role"] == "program":
56 | content = f"```python\n{content}```\n"
57 | elif item["role"] == "output":
58 | content = f"```output\n{content}```\n"
59 | text += content
60 | return text
61 |
62 |
63 | def is_execution_success(output):
64 |     error_key_words = ["error", "exception", "no algorithms", "cannot", "nan", "..."]
65 | success = all([k not in output.lower() for k in error_key_words])
66 | return success
67 |
68 |
69 | def extract_program(text:str=None, trajectory:list=None, last_only=False) -> str:
70 | assert text is not None or trajectory is not None, "Either text or trajectory should be provided."
71 | if trajectory is None:
72 | try:
73 | trajectory = text_to_trajectory(text)
74 | except:
75 | return "raise ValueError('Invalid trajectory')"
76 |
77 | program_list = []
78 | import_lines = []
79 | for i, item in enumerate(trajectory):
80 | if item["role"] == "program":
81 | cur_program = item["content"]
82 | if i < len(trajectory) - 1:
83 | assert trajectory[i+1]["role"] == "output"
84 | output = trajectory[i+1]["content"].strip()
85 | if is_execution_success(output):
86 | program_list.append(cur_program)
87 | else:
88 | # extract import lines only
89 | for line in cur_program.split("\n"):
90 | if line.startswith("import") or line.startswith("from"):
91 | import_lines.append(line)
92 | else:
93 | program_list.append(cur_program)
94 | # add import lines to the first program
95 | if len(program_list) == 0:
96 | program_list.append("")
97 | if len(import_lines) > 0:
98 | program_list[0] = "\n".join(import_lines) + "\n" + program_list[0]
99 | for i, program in enumerate(program_list[:-1]):
100 | program_list[i] = "\n".join([line for line in program.split("\n") if not line.strip().startswith("print(")])
101 |
102 | if last_only:
103 | program = program_list[-1]
104 | else:
105 | program = "\n".join(program_list)
106 | return program
107 |
108 |
109 | def extract_program_output(pred_str, last_only=True):
110 | """
111 | extract output between ```output\n...\n```, use regex, there might be multiple outputs, each output may have multiple lines
112 | """
113 | outputs = re.findall(r"```output\n(.*?)\n```", pred_str, re.DOTALL)
114 | if last_only:
115 | return outputs[-1] if len(outputs) > 0 else ""
116 | else:
117 | return outputs
118 |
119 |
120 | def _test_str_to_trajectory():
121 | traj_text = """We can use the Chinese Remainder Theorem to solve this system of congruences. The Chinese Remainder Theorem states that if $n_1, n_2, \ldots, n_k$ are pairwise coprime (i.e., the greatest common divisor of any two of them is 1), and if $a_1, a_2, \ldots, a_k$ are any integers, then there is a number $x$ that satisfies the system of congruences:
122 |
123 | $x \equiv a_1 \pmod{n_1}$
124 | $x \equiv a_2 \pmod{n_2}$
125 | $\vdots$
126 | $x \equiv a_k \pmod{n_k}$
127 |
128 | Moreover, any two solutions of this system are congruent modulo the product $n_1n_2 \cdots n_k$.
129 |
130 | Let's write a function to solve this:
131 |
132 | ```python
133 | from sympy import symbols, solve
134 |
135 | def smallest_x():
136 | x = symbols('x')
137 | # Solve the system of congruences
138 | solution = solve([x - 7, x - 17], x, domain='Z')
139 | # Return the smallest positive solution
140 | return min([sol for sol in solution if sol >= 0])
141 |
142 | result = smallest_x()
143 | print(result)
144 | ```
145 | ```output
146 | ValueError: min() arg is an empty sequence
147 | ```
148 | The error in the last example is due to the incorrect usage of the `solve` function. The `solve` function is used to solve equations, not congruences. To solve a system of congruences, we can use the `solve_congruence` function from the `sympy.ntheory.modular` module.
149 |
150 | Here is the corrected code:
151 |
152 | ```python
153 | from sympy.ntheory.modular import solve_congruence
154 |
155 | def smallest_x():
156 | # Solve the system of congruences
157 | solution = solve_congruence(('x', 7, 12), ('x', 17, 11))
158 | # Return the smallest positive solution
159 | return solution[0][0]
160 |
161 | result = smallest_x()
162 | print(result)
163 | ```
164 | ```output
165 | good
166 | ```
167 | ValueError: too many values to unpack (expected 2)
168 | The error in the last example is due to the incorrect usage of the `solve_congruence` function. The `solve_congruence` function returns a tuple of two elements: the solution and the modulus. We need to unpack these two values correctly.
169 |
170 | Here is the corrected code:
171 |
172 | ```python
173 | from sympy.ntheory.modular import solve_congruence
174 |
175 | def smallest_x():
176 | # Solve the system of congruences
177 | solution, modulus = solve_congruence(('x', 7, 12), ('x', 17, 11))
178 | # Return the smallest positive solution
179 | return solution
180 |
181 | result = smallest_x()
182 | print(result)
183 | ```"""
184 |
185 | import pprint
186 | trajectory = text_to_trajectory(traj_text)
187 | pprint.pprint(trajectory)
188 |
189 | text = trajectory_to_text(trajectory)
190 | assert text == traj_text
191 |
192 | # print(extract_program(traj_text))
193 |
194 |
195 | if __name__ == "__main__":
196 | _test_str_to_trajectory()
--------------------------------------------------------------------------------
/src/applications/python_executor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import io
3 | import regex
4 | import pickle
5 | import traceback
6 | import copy
7 | import datetime
8 | import dateutil.relativedelta
9 | import multiprocess
10 | from multiprocess import Pool
11 | from typing import Any, Dict, Optional
12 | from pebble import ProcessPool
13 | from tqdm import tqdm
14 | from concurrent.futures import TimeoutError
15 | from functools import partial
16 | from timeout_decorator import timeout
17 | from contextlib import redirect_stdout
18 |
19 |
20 | class GenericRuntime:
21 | GLOBAL_DICT = {}
22 | LOCAL_DICT = None
23 | HEADERS = []
24 | def __init__(self):
25 | self._global_vars = copy.copy(self.GLOBAL_DICT)
26 | self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None
27 |
28 | for c in self.HEADERS:
29 | self.exec_code(c)
30 |
31 | def exec_code(self, code_piece: str) -> None:
32 | if regex.search(r'(\s|^)?input\(', code_piece):
33 | # regex.search(r'(\s|^)?os.', code_piece):
34 | raise RuntimeError()
35 | exec(code_piece, self._global_vars)
36 |
37 | # TODO: use: https://github.com/shroominic/codebox-api
38 | # @high safe exec in sandbox
39 | # byte_code = compile_restricted(
40 | # code_piece,
41 | # filename='',
42 | # mode='exec'
43 | # )
44 | # print("global vars:", self._global_vars)
45 | # _print_ = PrintCollector
46 | # exec(byte_code, {'__builtins__': utility_builtins}, None)
47 |
48 | def eval_code(self, expr: str) -> Any:
49 | return eval(expr, self._global_vars)
50 |
51 | def inject(self, var_dict: Dict[str, Any]) -> None:
52 | for k, v in var_dict.items():
53 | self._global_vars[k] = v
54 |
55 | @property
56 | def answer(self):
57 | return self._global_vars['answer']
58 |
59 | class DateRuntime(GenericRuntime):
60 | GLOBAL_DICT = {
61 | 'datetime': datetime.datetime,
62 | 'timedelta': dateutil.relativedelta.relativedelta,
63 | 'relativedelta': dateutil.relativedelta.relativedelta
64 | }
65 |
66 |
67 | class CustomDict(dict):
68 | def __iter__(self):
69 | return list(super().__iter__()).__iter__()
70 |
71 | class ColorObjectRuntime(GenericRuntime):
72 | GLOBAL_DICT = {'dict': CustomDict}
73 |
74 |
75 | class PythonExecutor:
76 | def __init__(
77 | self,
78 | runtime: Optional[Any] = None,
79 | get_answer_symbol: Optional[str] = None,
80 | get_answer_expr: Optional[str] = None,
81 | get_answer_from_stdout: bool = False,
82 | timeout_length: int = 5,
83 | ) -> None:
84 | self.runtime = runtime if runtime else GenericRuntime()
85 | self.answer_symbol = get_answer_symbol
86 | self.answer_expr = get_answer_expr
87 | self.get_answer_from_stdout = get_answer_from_stdout
88 | self.pool = Pool(multiprocess.cpu_count())
89 | self.timeout_length = timeout_length
90 |
91 | def process_generation_to_code(self, gens: str):
92 | return [g.strip().split('\n') for g in gens]
93 |
94 | @staticmethod
95 | def execute(
96 | code,
97 | get_answer_from_stdout = None,
98 | runtime = None,
99 | answer_symbol = None,
100 | answer_expr = None,
101 | timeout_length = 10,
102 | auto_mode=False
103 | ):
104 | try:
105 | if auto_mode:
106 | if "print(" in code[-1]:
107 | program_io = io.StringIO()
108 | with redirect_stdout(program_io):
109 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
110 | program_io.seek(0)
111 | result = program_io.read()
112 | else:
113 | print(code)
114 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code[:-1]))
115 | result = timeout(timeout_length)(runtime.eval_code)(code[-1])
116 | else:
117 | if get_answer_from_stdout:
118 | program_io = io.StringIO()
119 | with redirect_stdout(program_io):
120 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
121 | program_io.seek(0)
122 | result = program_io.read()
123 | elif answer_symbol:
124 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
125 | result = runtime._global_vars[answer_symbol]
126 | elif answer_expr:
127 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
128 | result = timeout(timeout_length)(runtime.eval_code)(answer_expr)
129 | else:
130 | timeout(timeout_length)(runtime.exec_code)('\n'.join(code[:-1]))
131 | result = timeout(timeout_length)(runtime.eval_code)(code[-1])
132 | report = "Done"
133 | str(result)
134 | pickle.dumps(result) # serialization check
135 | except:
136 | result = ''
137 | report = traceback.format_exc().split('\n')[-2]
138 | return result, report
139 |
140 | def apply(self, code):
141 | return self.batch_apply([code])[0]
142 |
143 | @staticmethod
144 | def truncate(s, max_length=400):
145 | half = max_length // 2
146 | if len(s) > max_length:
147 | s = s[:half] + "..." + s[-half:]
148 | return s
149 |
150 | def batch_apply(self, batch_code):
151 | all_code_snippets = self.process_generation_to_code(batch_code)
152 |
153 | timeout_cnt = 0
154 | all_exec_results = []
155 | # with ProcessPool(max_workers=min(len(all_code_snippets), os.cpu_count())) as pool:
156 | with ProcessPool(max_workers=min(len(all_code_snippets), 1)) as pool:
157 | executor = partial(
158 | self.execute,
159 | get_answer_from_stdout=self.get_answer_from_stdout,
160 | runtime=self.runtime,
161 | answer_symbol=self.answer_symbol,
162 | answer_expr=self.answer_expr,
163 |             timeout_length=self.timeout_length,  # note: this per-call timeout does not take effect here
164 | auto_mode=True
165 | )
166 | future = pool.map(executor, all_code_snippets, timeout=self.timeout_length)
167 | iterator = future.result()
168 |
169 | if len(all_code_snippets) > 100:
170 | progress_bar = tqdm(total=len(all_code_snippets), desc="Execute")
171 | else:
172 | progress_bar = None
173 |
174 | while True:
175 | try:
176 | result = next(iterator)
177 | all_exec_results.append(result)
178 | except StopIteration:
179 | break
180 | except TimeoutError as error:
181 | print(error)
182 | all_exec_results.append(("", "Timeout Error"))
183 | timeout_cnt += 1
184 | except Exception as error:
185 | print(error)
186 | exit()
187 | if progress_bar is not None:
188 | progress_bar.update(1)
189 |
190 | if progress_bar is not None:
191 | progress_bar.close()
192 |
193 | batch_results = []
194 | for code, (res, report) in zip(all_code_snippets, all_exec_results):
195 | # post processing
196 | res, report = str(res).strip(), str(report).strip()
197 | res, report = self.truncate(res), self.truncate(report)
198 | batch_results.append((res, report))
199 | return batch_results
200 |
201 |
202 | def _test():
203 | batch_code = [
204 | """
205 | from sympy import Matrix
206 |
207 | def null_space_basis():
208 | # Define the matrix
209 | A = Matrix([[3, 3, -1, -6], [9, -1, -8, -1], [7, 4, -2, -9]])
210 |
211 | # Compute the basis for the null space
212 | basis = A.nullspace()
213 |
214 | # Round the elements of the basis vectors to three decimal places
215 | basis_rounded = [v.evalf(3) for v in basis]
216 |
217 | return basis_rounded
218 |
219 | result = null_space_basis()
220 | print(result)
221 | """
222 | ]
223 |
224 | executor = PythonExecutor(get_answer_from_stdout=True)
225 | predictions = executor.apply(batch_code[0])
226 | print(predictions)
227 |
228 |
229 | if __name__ == '__main__':
230 | _test()
--------------------------------------------------------------------------------
/src/plot_mi_peaks.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "### Plot the MI trajectories"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "\n",
17 | "import seaborn as sns\n",
18 | "import torch\n",
19 | "import matplotlib.pyplot as plt\n",
20 | "import seaborn as sns\n",
21 | "import json\n",
22 | "import numpy as np\n",
23 | "from transformers import AutoTokenizer\n",
24 | "\n",
25 | "sns.set_theme(style=\"whitegrid\", context=\"talk\", palette=\"muted\", font_scale=1.0)\n",
26 | "\n",
27 | "\n",
28 | "def load_model_data(model_path, dataset, target_layer=31):\n",
29 | " tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
30 | " model_name = model_path.split('/')[-1]\n",
31 | "\n",
32 | " data_path = f'results/mi/{dataset}_gtmodel={model_name}_testmodel={model_name}.pth'\n",
33 | " data = torch.load(data_path)\n",
34 | "\n",
35 | " all_sample_mi_list = []\n",
36 | " all_mi_peak_list = [] \n",
37 | "\n",
38 | " for id in data.keys():\n",
39 | " try:\n",
40 | " this_id_mi_list = data[id]['reps'][target_layer]\n",
41 | " \n",
42 | " all_sample_mi_list.append(this_id_mi_list[:])\n",
43 | "\n",
44 | " top_indices = sorted(range(len(this_id_mi_list)), \n",
45 | " key=lambda i: this_id_mi_list[i], reverse=True)[:20] # approximately take top-20\n",
46 | " all_mi_peak_list.append(top_indices)\n",
47 | "\n",
48 | " except Exception as e:\n",
49 | " print(f'[id:{id}] Error:', e)\n",
50 | "\n",
51 | " return all_sample_mi_list, all_mi_peak_list\n",
52 | "\n",
53 | "\n",
54 | "dataset = 'math_train_12k'\n",
55 | "model_path = 'deepseek-ai/DeepSeek-R1-Distill-Llama-8B'\n",
56 | "model_name = model_path.split('/')[-1]\n",
57 | "\n",
58 | "\n",
59 | "model_data_dict = {}\n",
60 | "\n",
61 | "mi_list, mi_peak_list = load_model_data(model_path, dataset, target_layer=31)\n",
62 | "model_data_dict[model_name] = {\n",
63 | " 'mi': mi_list,\n",
64 | " 'peaks': mi_peak_list,\n",
65 | "}\n",
66 | "\n",
67 | "\n",
68 | "fig, axes = plt.subplots(2, 5, figsize=(30, 10))\n",
69 | "axes = axes.flatten()\n",
70 | "\n",
71 | "for sample_idx in range(10):\n",
72 | " ax = axes[sample_idx]\n",
73 | " mi_values = model_data_dict[model_name]['mi'][sample_idx]\n",
74 | " steps = model_data_dict[model_name]['peaks'][sample_idx]\n",
75 | "\n",
76 | " ax.plot(mi_values, linewidth=2, alpha=0.8)\n",
77 | " ax.scatter(steps, [mi_values[i] for i in steps],\n",
78 | " s=50, edgecolor=\"white\", linewidth=0.8, zorder=3)\n",
79 | "\n",
80 | " ax.grid(axis=\"y\", linestyle=\"--\", alpha=0.4)\n",
81 | " ax.set_facecolor(\"#fafafa\")\n",
82 | " sns.despine(ax=ax, top=True, right=True)\n",
83 | "\n",
84 | " ax.set_title(model_name, fontsize=22, pad=8) \n",
85 | " ax.set_xlabel(\"Reasoning Step\", fontsize=18) \n",
86 | " ax.set_ylabel(\"MI Value\", fontsize=20)\n",
87 | " ax.tick_params(axis=\"x\", labelsize=18) \n",
88 | " ax.tick_params(axis=\"y\", labelsize=18)\n",
89 | "\n",
90 | "\n",
91 | "plt.subplots_adjust(hspace=0.4, wspace=0.13)\n",
92 | "\n",
93 | "plt.show()"
94 | ]
95 | },
96 | {
97 | "cell_type": "markdown",
98 | "metadata": {},
99 | "source": [
100 | "### Projecting the MI-peak representations to token space "
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": null,
106 | "metadata": {},
107 | "outputs": [],
108 | "source": [
109 | "import re\n",
110 | "import torch\n",
111 | "import matplotlib.pyplot as plt\n",
112 | "import seaborn as sns\n",
113 | "import json\n",
114 | "import numpy as np\n",
115 | "from transformers import AutoTokenizer\n",
116 | "from collections import Counter\n",
117 | "from matplotlib.colors import LinearSegmentedColormap\n",
118 | "\n",
119 | "sns.set(style=\"whitegrid\", context=\"talk\", font_scale=1.3)\n",
120 | "\n",
121 | "\n",
122 | "def plot_token_freq(model_path, dataset, target_layer=31):\n",
123 | "\n",
124 | " model_name = model_path.split('/')[-1]\n",
125 | " tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
126 | "\n",
127 | " data_path = f'results/mi/{dataset}_gtmodel={model_name}_testmodel={model_name}.pth'\n",
128 | " data = torch.load(data_path)\n",
129 | "\n",
130 | " acts = torch.load(f'acts/reasoning_evolve/{dataset}_{model_name}.pth')\n",
131 | "\n",
132 | " all_sample_mi_list = [] \n",
133 | " all_mi_peak_list = []\n",
134 | " all_tokens = []\n",
135 | "\n",
136 | " fail_id_list = []\n",
137 | " for id in data.keys():\n",
138 | " try:\n",
139 | " this_id_mi_list = data[id]['reps'][target_layer] \n",
140 | " all_sample_mi_list.append(this_id_mi_list[:]) \n",
141 | " top_indices = sorted(range(len(this_id_mi_list)), key=lambda i: this_id_mi_list[i], reverse=True)[:20] # approximately take top-20\n",
142 | "\n",
143 | " all_mi_peak_list.append(top_indices)\n",
144 | "\n",
145 | " token_list = acts[id]['token_ids'].tolist()\n",
146 | " token_list.append(2) # [eos] token\n",
147 | " top_prob_token_ids = [token_list[i] for i in top_indices]\n",
148 | "\n",
149 | " batch_top_n_tokens = tokenizer.batch_decode(top_prob_token_ids, skip_special_tokens=False)\n",
150 | " all_tokens.extend(batch_top_n_tokens)\n",
151 | "\n",
152 | " except Exception as e:\n",
153 | " fail_id_list.append(id)\n",
154 | "\n",
155 | " print('fail_id_list:', fail_id_list)\n",
156 | "\n",
157 | "\n",
158 | " english_pattern = re.compile(r'^[a-zA-Z]+$')\n",
159 | "\n",
160 | " processed_all_tokens = []\n",
161 | " for token in all_tokens:\n",
162 | " if english_pattern.match(token.strip()):\n",
163 | " processed_all_tokens.append(token)\n",
164 | "\n",
165 | " \n",
166 | " token_freq = Counter(processed_all_tokens)\n",
167 | "\n",
168 | " common_tokens = token_freq.most_common(15)\n",
169 | " print('$'*50)\n",
170 | " print(f'model: {model_name}')\n",
171 | " print(\"Most common tokens:\", common_tokens)\n",
172 | "\n",
173 | " colors = [\n",
174 | " (0.0, \"#3d61aa\"), \n",
175 | " (0.5, \"#b1bee9\"), \n",
176 | " (1.0, \"#9673c4\") \n",
177 | " ]\n",
178 | " cmap = LinearSegmentedColormap.from_list(\"blue_purple\", colors)\n",
179 | "\n",
180 | "\n",
181 | " # ------------------------------ plot --------------------------------------\n",
182 | "\n",
183 | " token_names, token_counts = zip(*common_tokens)\n",
184 | " token_names_processed = [token.replace('$', '\\$').replace('_', '\\_').replace('^', '\\^') for token in token_names]\n",
185 | " token_names_repr = [repr(token) for token in token_names_processed]\n",
186 | "\n",
187 | " n_bars = len(token_names_repr)\n",
188 | " palette = [cmap(i / (n_bars - 1)) for i in range(n_bars)]\n",
189 | "\n",
190 | " plt.figure(figsize=(8, 5))\n",
191 | " sns.barplot(x=list(token_names_repr), y=list(token_counts), palette=palette)\n",
192 | " plt.xticks(rotation=45, ha='right', size=17)\n",
193 | " plt.xlabel('Tokens at MI Peaks')\n",
194 | " plt.ylabel('Frequency')\n",
195 | " plt.title(f'{model_name}')\n",
196 | "\n",
197 | "\n",
198 | " plt.show()\n",
199 | " \n",
200 | "\n",
201 | "dataset = 'math_train_12k'\n",
202 | "model_path = 'deepseek-ai/DeepSeek-R1-Distill-Llama-8B'\n",
203 | "\n",
204 | "\n",
205 | "plot_token_freq(\n",
206 | " model_path=model_path,\n",
207 | " dataset=dataset,\n",
208 | " target_layer=31,\n",
209 | ")\n",
210 | "\n"
211 | ]
212 | }
213 | ],
214 | "metadata": {
215 | "kernelspec": {
216 | "display_name": "vllm053",
217 | "language": "python",
218 | "name": "python3"
219 | },
220 | "language_info": {
221 | "codemirror_mode": {
222 | "name": "ipython",
223 | "version": 3
224 | },
225 | "file_extension": ".py",
226 | "mimetype": "text/x-python",
227 | "name": "python",
228 | "nbconvert_exporter": "python",
229 | "pygments_lexer": "ipython3",
230 | "version": "3.11.5"
231 | }
232 | },
233 | "nbformat": 4,
234 | "nbformat_minor": 2
235 | }
236 |
--------------------------------------------------------------------------------
/src/applications/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
6 | import numpy as np
7 | from pathlib import Path
8 | from typing import Iterable, Union, Any
9 |
10 | from examples import get_examples
11 |
12 |
13 | def set_seed(seed: int = 42) -> None:
14 | np.random.seed(seed)
15 | random.seed(seed)
16 | os.environ["PYTHONHASHSEED"] = str(seed)
17 | print(f"Random seed set as {seed}")
18 |
19 |
20 | def load_jsonl(file: Union[str, Path]) -> Iterable[Any]:
21 | with open(file, "r", encoding="utf-8") as f:
22 | for line in f:
23 | try:
24 | yield json.loads(line)
25 | except:
26 | print("Error in loading:", line)
27 | exit()
28 |
29 |
30 | def save_jsonl(samples, save_path):
31 | # ensure path
32 | folder = os.path.dirname(save_path)
33 | os.makedirs(folder, exist_ok=True)
34 |
35 | with open(save_path, "w", encoding="utf-8") as f:
36 | for sample in samples:
37 | f.write(json.dumps(sample, ensure_ascii=False) + "\n")
38 | print("Saved to", save_path)
39 |
40 |
41 | def lower_keys(example):
42 | new_example = {}
43 | for key, value in example.items():
44 | if key != key.lower():
45 | new_key = key.lower()
46 | new_example[new_key] = value
47 | else:
48 | new_example[key] = value
49 | return new_example
50 |
51 |
52 | EXAMPLES = get_examples()
53 |
54 |
55 | def load_prompt(data_name, prompt_type, num_shots):
56 | if not num_shots:
57 | return []
58 |
59 | if data_name in ["gsm_hard", "svamp", "tabmwp", "asdiv", "mawps"]:
60 | data_name = "gsm8k"
61 | if data_name in ["math_oai", "hungarian_exam", "math-oai", "aime24", "amc23"]:
62 | data_name = "math"
63 | if data_name in ["sat_math"]:
64 | data_name = "mmlu_stem"
65 | if data_name in [
66 | "gaokao2024_I",
67 | "gaokao2024_II",
68 | "gaokao_math_qa",
69 | "gaokao2024_mix",
70 | "cn_middle_school",
71 | ]:
72 | data_name = "gaokao"
73 |
74 | if prompt_type in ["tool-integrated"]:
75 | prompt_type = "tora"
76 |
77 | return EXAMPLES[data_name][:num_shots]
78 |
79 |
80 | PROMPT_TEMPLATES = {
81 | "direct": ("Question: {input}\nAnswer: ", "{output}", "\n\n"),
82 | "cot": ("Question: {input}\nAnswer: ", "{output}", "\n\n\n"),
83 | "pal": ("Question: {input}\n\n", "{output}", "\n---\n"),
84 | "tool-integrated": ("Question: {input}\n\nSolution:\n", "{output}", "\n---\n"),
85 | "self-instruct": ("<|user|>\n{input}\n<|assistant|>\n", "{output}", "\n"),
86 | "tora": ("<|user|>\n{input}\n<|assistant|>\n", "{output}", "\n"),
87 | "wizard_zs": (
88 | "### Instruction:\n{input}\n\n### Response: Let's think step by step.",
89 | "{output}",
90 | "\n\n\n",
91 | ),
92 | "platypus_fs": (
93 | "### Instruction:\n{input}\n\n### Response:\n",
94 | "{output}",
95 | "\n\n\n",
96 | ),
97 | "deepseek-math": (
98 | "User: {input}\nPlease reason step by step, "
99 | "and put your final answer within \\boxed{{}}.\n\nAssistant:",
100 | "{output}",
101 | "\n\n\n",
102 | ),
103 | "kpmath": (
104 | "User: Please reason step by step and put your final answer at the end "
105 | 'with "The answer is: ".\n\n{input}\n\nAssistant:',
106 | "{output}",
107 | ),
108 | "jiuzhang": (
109 | "## Question\n{input}\n\n## Solution\n",
110 | "{output}",
111 | "\n\n\n",
112 | ),
113 | "jiuzhang_tora": (
114 | "## Question\n{input}\n\n## Code Solution\n",
115 | "{output}",
116 | "\n\n\n",
117 | ),
118 | "jiuzhang_nl": (
119 | "## Question\n{input}\n\n## Natural Language Solution\n",
120 | "{output}",
121 | "\n\n\n",
122 | ),
123 | "mmiqc": (
124 | 'Please solve the following problem and put your answer at the end with "The answer is: ".\n\n{input}\n\n',
125 | "{output}",
126 | "\n\n\n",
127 | ),
128 | "abel": (
129 | "Question:\n{input}\nAnswer:\nLet's think step by step.\n",
130 | "{output}",
131 | "\n\n",
132 | ),
133 | "shepherd": ("{input}\n", "{output}", "\n\n\n"),
134 | "qwen-boxed": (
135 | "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
136 | "<|im_start|>user\n{input}\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n"
137 | "<|im_start|>assistant\n",
138 | "{output}",
139 | "\n\n",
140 | ),
141 | "qwen25-math-cot": (
142 | "<|im_start|>system\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n"
143 | "<|im_start|>user\n{input}<|im_end|>\n"
144 | "<|im_start|>assistant\n",
145 | "{output}",
146 | "\n\n",
147 | ),
148 | "mathstral": (
149 | "{input}\nPlease reason step by step, and put your final answer within \\boxed{{}}.",
150 | "{output}",
151 | "\n\n",
152 | ),
153 | "internlm-math-fs": ("Question:{input}\nAnswer:", "{output}", "\n"),
154 | "internlm-math-chat": (
155 | "<|im_start|>user\n{input}<|im_end|>\n" "<|im_start|>assistant\n",
156 | "{output}",
157 | "\n\n",
158 | ),
159 | "mistral": (
160 | "[INST] {input}[/INST]",
161 | "{output}",
162 | "\n\n",
163 | ),
164 | "numina": ("### Problem: {input}\n### Solution:", " {output}", "\n\n"),
165 | }
166 |
167 |
168 | def construct_prompt(example, data_name, args):
169 | if args.adapt_few_shot and data_name in [
170 | "gaokao2024_I",
171 | "gaokao2024_II",
172 | "gaokao_math_qa",
173 | "gaokao2024_mix",
174 | "cn_middle_school",
175 | ]:
176 | demos = load_prompt(data_name, args.prompt_type, 5)
177 | else:
178 | demos = load_prompt(data_name, args.prompt_type, args.num_shots)
179 | prompt_type = args.prompt_type
180 | if prompt_type == "platypus_fs":
181 | prompt_type = "cot"
182 | if prompt_type == "tool-integrated":
183 | prompt_type = "tora"
184 |
185 | prompt_temp = PROMPT_TEMPLATES[args.prompt_type]
186 |
187 | splitter = prompt_temp[2]
188 | input_template, output_template, splitter = (
189 | prompt_temp[0],
190 | prompt_temp[1],
191 | prompt_temp[2],
192 | )
193 | if args.prompt_type == "qwen25-math-cot":
194 | # Hotfix to support putting all demos into a single turn
195 | demo_prompt = splitter.join([q + "\n" + a for q, a in demos])
196 | else:
197 | demo_prompt = splitter.join(
198 | [
199 | input_template.format(input=q) + output_template.format(output=a)
200 | for q, a in demos
201 | ]
202 | )
203 | context = input_template.format(input=example["question"])
204 | if len(demo_prompt) == 0 or (
205 | args.adapt_few_shot and example["gt_ans"] not in ["A", "B", "C", "D", "E"]
206 | ):
207 | full_prompt = context
208 | else:
209 | if args.prompt_type == "qwen25-math-cot":
210 |             # Hotfix to support putting all demos into a single turn
211 | full_prompt = demo_prompt + splitter + example["question"]
212 | full_prompt = input_template.format(input=full_prompt)
213 | else:
214 | full_prompt = demo_prompt + splitter + context
215 |
216 | if args.prompt_type == "platypus_fs":
217 | full_prompt_temp = (
218 | "Below is an instruction that describes a task. "
219 | "Write a response that appropriately completes the request.\n\n"
220 | "### Instruction:\n{instruction}\n\n### Response:\n"
221 | )
222 | full_prompt = full_prompt_temp.format(instruction=full_prompt)
223 |
224 | if prompt_type == "tora":
225 | full_prompt = (
226 | """Integrate step-by-step reasoning and Python code to solve math problems using the following guidelines:
227 |
228 | - Analyze the question and write functions to solve the problem; the function should not take any arguments.
229 | - Present the final result in LaTeX using a `\boxed{}` without any units.
230 | - Utilize the `pi` symbol and `Rational` from Sympy for $\pi$ and fractions, and simplify all fractions and square roots without converting them to decimal values.
231 |
232 | Here are some examples you may refer to:
233 |
234 | ---
235 |
236 | """
237 | + full_prompt
238 | )
239 |
240 | return full_prompt.strip(" ") # important!
241 |
242 |
243 | key_map = {
244 | "gt": "Ground Truth",
245 | "pred": "Prediction",
246 | "gt_cot": "Reference CoT",
247 | "score": "Score",
248 | }
249 |
250 |
251 | def show_sample(sample, print_all_preds=False):
252 | print("==" * 20)
253 | for key in ["idx", "type", "level", "dataset"]:
254 | if key in sample:
255 | # capitalize
256 | print("{}: {}".format(key[0].upper() + key[1:], sample[key]))
257 | print("Question:", repr(sample["question"]))
258 | if "code" in sample:
259 | if print_all_preds:
260 | for code in sample["code"]:
261 | print("-" * 20)
262 | print("code:", code)
263 | print("Execution:", sample["report"])
264 | else:
265 | print("Solution:\n", sample["code"][0])
266 | print("Execution:", sample["report"][0])
267 | if "pred" in sample:
268 | print("Prediction:", repr(sample["pred"][0]))
269 | for key in ["gt", "score", "unit", "gt_cot"]:
270 | if key in sample:
271 | _key = key_map.get(key, key)
272 | print("{}: {}".format(_key, repr(sample[key])))
273 | print()
274 |
--------------------------------------------------------------------------------
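A minimal sketch of how `construct_prompt` above stitches demos and the test question together from a `PROMPT_TEMPLATES` entry. The demo Q/A pair and question are invented purely for illustration; in the repo the demos come from `examples.get_examples()` and the template is selected by `--prompt_type`.

    # Illustrative only: replicates the few-shot assembly logic of construct_prompt()
    input_template, output_template, splitter = (
        "Question: {input}\nAnswer: ",   # the "direct" entry of PROMPT_TEMPLATES
        "{output}",
        "\n\n",
    )

    demos = [("1 + 1 = ?", "2")]   # hypothetical demo pair
    question = "2 + 3 = ?"         # hypothetical test question

    demo_prompt = splitter.join(
        input_template.format(input=q) + output_template.format(output=a)
        for q, a in demos
    )
    context = input_template.format(input=question)
    full_prompt = (demo_prompt + splitter + context) if demos else context
    print(full_prompt)
    # Question: 1 + 1 = ?
    # Answer: 2
    #
    # Question: 2 + 3 = ?
    # Answer:
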
/src/applications/RR_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | from transformers import AutoModelForCausalLM, AutoTokenizer
4 | from typing import Optional, List
5 | import gc
6 | import json
7 | import time
8 | import random
9 |
10 |
11 | class RecursiveThinkingModel(torch.nn.Module):
12 | def __init__(
13 | self,
14 | base_model_name: str = None,
15 | extract_layer_id: int = None,
16 | inject_layer_id: int = None,
17 | num_recursive_steps: int = 1, # Number of recursive optimization steps
18 | use_recursive_thinking: bool = True,
19 | output_file: str = None
20 | ):
21 | super().__init__()
22 |
23 |
24 | self.base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
25 | self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
26 | self.tokenizer.pad_token = self.tokenizer.eos_token
27 |
28 | self.extract_layer_id = extract_layer_id or -1
29 | self.inject_layer_id = inject_layer_id or -1
30 | self.extracted_hidden = None
31 |
32 | self.num_recursive_steps = num_recursive_steps
33 | self.use_recursive_thinking = use_recursive_thinking
34 |
35 |
36 | self._find_layers()
37 |
38 |
39 | self.extract_hook = None
40 | self.inject_hook = None
41 | self.enable_inject = False
42 | self.enable_extract = True
43 |
44 |
45 | self.output_file = output_file
46 | if output_file is None:
47 | raise ValueError("output_file must be provided as a valid path string")
48 |
49 |
50 | def _find_layers(self):
51 | """Identify the layer structure of the model"""
52 | if hasattr(self.base_model, 'model') and hasattr(self.base_model.model, 'layers'):
53 | self.layers = self.base_model.model.layers
54 | elif hasattr(self.base_model, 'transformer') and hasattr(self.base_model.transformer, 'h'):
55 | self.layers = self.base_model.transformer.h
56 | elif hasattr(self.base_model, 'encoder') and hasattr(self.base_model.encoder, 'layer'):
57 | self.layers = self.base_model.encoder.layer
58 | else:
59 | raise ValueError(f"Unsupported model architecture: {type(self.base_model)}")
60 |
61 | def _register_hooks(self):
62 | """Register feature passing hooks"""
63 |
64 | # Remove existing hooks (if any)
65 | self._remove_hooks()
66 |
67 | # Feature extraction hook (captures representation at the specified layer)
68 | def extract_hook(module, inputs, outputs):
69 | if self.enable_extract:
70 | self.extracted_hidden = outputs[0].clone()
71 |
72 | # Feature injection hook (injects captured representation into the target layer)
73 | def inject_hook(module, inputs):
74 | if self.enable_inject and self.extracted_hidden is not None:
75 | modified_hidden_states = inputs[0].clone()
76 | modified_hidden_states[:, -1:, :] = self.extracted_hidden[:, -1:, :]
77 |
78 | return (modified_hidden_states,) + inputs[1:]
79 |
80 | return inputs
81 |
82 |
83 | # Register hooks
84 | self.inject_hook = self.layers[self.inject_layer_id].register_forward_pre_hook(inject_hook)
85 | self.extract_hook = self.layers[self.extract_layer_id].register_forward_hook(extract_hook)
86 |
87 | def _remove_hooks(self):
88 | """Remove hooks to prevent memory leaks"""
89 | if hasattr(self, 'extract_hook') and self.extract_hook is not None:
90 | self.extract_hook.remove()
91 | self.extract_hook = None
92 |
93 | if hasattr(self, 'inject_hook') and self.inject_hook is not None:
94 | self.inject_hook.remove()
95 | self.inject_hook = None
96 |
97 | def _clear_memory(self):
98 | """Clean up memory, release unnecessary tensors and cache"""
99 | gc.collect()
100 | if torch.cuda.is_available():
101 | torch.cuda.empty_cache()
102 |
103 | @torch.inference_mode()
104 | def generate(
105 | self,
106 | max_tokens: int,
107 | prompt: str,
108 | interested_tokens: set = None,
109 | use_recursive_thinking: bool = True,
110 | **kwargs
111 | ):
112 |
113 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
114 | self.to(device).eval()
115 |
116 | # Register hooks only when needed
117 | if use_recursive_thinking:
118 | self._register_hooks()
119 |
120 | try:
121 | # Initialize input
122 | inputs = self.tokenizer(prompt, return_tensors="pt").to(device)
123 | input_ids = inputs.input_ids
124 | attention_mask = inputs.attention_mask
125 |
126 | past_key_values = None
127 | generated = input_ids.clone()
128 | full_input_ids = input_ids.clone() # Save the complete input sequence
129 | print(f"Prompt: {prompt}")
130 |
131 | recursive_tokens_count = 0
132 | regular_tokens_count = 0
133 | lose_tokens_count = 0
134 |
135 | for _ in tqdm(range(max_tokens), desc="Generating"):
136 | # Regular generation to get candidate token
137 |
138 | outputs = self.base_model(
139 | input_ids=input_ids,
140 | attention_mask=attention_mask,
141 | past_key_values=past_key_values,
142 | use_cache=True,
143 | output_hidden_states=False
144 | )
145 | next_token_logits = outputs.logits[:, -1, :]
146 | candidate_token = next_token_logits.argmax(dim=-1, keepdim=True)
147 | del next_token_logits
148 |
149 | if candidate_token.item() == self.tokenizer.eos_token_id:
150 | break
151 |
152 | # recursive decoding
153 |                 if use_recursive_thinking and interested_tokens and candidate_token.item() in interested_tokens and random.random() < 0.5:
154 | recursive_tokens_count += 1
155 | print(f"Recursive token")
156 |
157 | final_token = None
158 |
159 | # ------------------------------ Forward propagation after replacing representation ------------------------------
160 | self.enable_inject = True
161 | self.enable_extract = False
162 | self.base_model.config.use_cache=True
163 |
164 | recursive_outputs = self.base_model(
165 | input_ids=input_ids,
166 | attention_mask=attention_mask,
167 | past_key_values=past_key_values,
168 | use_cache=True
169 | )
170 | self.enable_extract = True
171 | # Disable hooks, reset hidden state
172 | self.enable_inject = False
173 | final_logits = recursive_outputs.logits[:, -1, :]
174 | final_token = final_logits.argmax(dim=-1, keepdim=True)
175 |
176 |
177 | # ------------------------------ Update KV cache ------------------------------
178 | past_key_values = recursive_outputs.past_key_values
179 | # Update sequence
180 | generated = torch.cat([generated,final_token], dim=-1)
181 | full_input_ids = torch.cat([full_input_ids,final_token], dim=-1)
182 | input_ids = final_token
183 |
184 | # Update attention_mask
185 | attention_mask = torch.cat([
186 | attention_mask,
187 | torch.ones((1, 1), device=device)
188 | ], dim=1)
189 | # regular decoding
190 | else:
191 | regular_tokens_count += 1
192 | past_key_values = outputs.past_key_values
193 | # Update sequence
194 | generated = torch.cat([generated, candidate_token], dim=-1)
195 | full_input_ids = torch.cat([full_input_ids, candidate_token], dim=-1)
196 | input_ids = candidate_token
197 |
198 | # Update attention_mask
199 | attention_mask = torch.cat([
200 | attention_mask,
201 | torch.ones((1, 1), device=device)
202 | ], dim=1)
203 |
204 |
205 | outputs = None
206 |
207 | if _ % 5000 == 0:
208 | gc.collect()
209 | if torch.cuda.is_available():
210 | torch.cuda.empty_cache()
211 |
212 |
213 | # Save performance statistics
214 | response_str = self.tokenizer.decode(generated[0], skip_special_tokens=True)
215 | entry = {
216 | "query": prompt,
217 | "response": response_str,
218 | "performance": {
219 | "recursive_tokens": recursive_tokens_count,
220 | "regular_tokens": regular_tokens_count,
221 | "lose_tokens": lose_tokens_count,
222 |
223 | }
224 | }
225 | with open(self.output_file, "a") as f:
226 | f.write(json.dumps(entry, indent=2) + "\n")
227 |
228 | print(f"\nGenerated {recursive_tokens_count} recursive tokens and {regular_tokens_count} regular tokens")
229 | print('-'*50)
230 |
231 | generated = None
232 | full_input_ids = None
233 | past_key_values = None
234 | input_ids = None
235 | attention_mask = None
236 |
237 |
238 | return response_str
239 |
240 | finally:
241 | self._remove_hooks()
242 | self._clear_memory()
--------------------------------------------------------------------------------
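For orientation, a hypothetical usage sketch of `RecursiveThinkingModel` above. The model path, layer indices, output path, and trigger-token set are all placeholders; `RR_evaluate.py` and `scripts/run_RR.sh` are the actual entry points used by the repo.

    # Hypothetical usage sketch; run from src/applications/ so RR_model is importable.
    from transformers import AutoTokenizer
    from RR_model import RecursiveThinkingModel

    model_path = "your_dir/DeepSeek-R1-Distill-Llama-8B"   # placeholder path
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    model = RecursiveThinkingModel(
        base_model_name=model_path,
        extract_layer_id=31,   # layer whose output hidden state is captured
        inject_layer_id=31,    # layer whose input is overwritten on the recursive pass
        output_file="results/rr_generations.jsonl",
    )

    # Token ids at which the recursive pass may be (stochastically) triggered; placeholder choice.
    interested_tokens = set(tokenizer.encode(" Wait", add_special_tokens=False))

    response = model.generate(
        max_tokens=512,
        prompt="User: What is 17 * 24?\n\nAssistant:",
        interested_tokens=interested_tokens,
        use_recursive_thinking=True,
    )
    print(response)
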
/src/applications/model_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/allenai/open-instruct
3 | """
4 | import torch
5 | import tqdm
6 | from transformers import StoppingCriteria, StoppingCriteriaList
7 |
8 |
9 | class KeywordsStoppingCriteria(StoppingCriteria):
10 | def __init__(self, keywords_str, tokenizer):
11 | StoppingCriteria.__init__(self)
12 | self.current_context = []
13 | self.tokenizer = tokenizer
14 | self.keywords_str = keywords_str
15 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
16 | if len(self.current_context) == 0:
17 | self.current_context = [[] for _ in range(input_ids.shape[0])]
18 |
19 | # self.current_context.append(input_ids[0][-1].item())
20 | sequences_should_be_stopped = []
21 | for i in range(input_ids.shape[0]):
22 | _id = input_ids[i][-1].item()
23 | self.current_context[i].append(_id)
24 | current_context = self.tokenizer.decode(self.current_context[i])
25 | should_be_stopped = False
26 | for word in self.keywords_str:
27 | if word in current_context:
28 | should_be_stopped = True
29 | break
30 | sequences_should_be_stopped.append(should_be_stopped)
31 | return all(sequences_should_be_stopped)
32 |
33 |
34 | class KeyWordsCriteriaTrunc(StoppingCriteria):
35 | def __init__(self, stop_id_sequences, prompt_length):
36 | assert isinstance(stop_id_sequences[0], list), "stop_id_sequences should be a list of list of ids"
37 | self.stop_sequences = stop_id_sequences
38 | self.prompt_length = prompt_length
39 |
40 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
41 | sequences_should_be_stopped = []
42 | for i in range(input_ids.shape[0]):
43 | ids = input_ids[i][self.prompt_length:].tolist()
44 | should_be_stopped = False
45 | for stop_sequence in self.stop_sequences:
46 | if input_ids.shape[0] == 1:
47 | _ids = ids[-len(stop_sequence):]
48 | else:
49 | _ids = ids
50 | for j in range(len(_ids), 0, -len(stop_sequence)):
51 | if _ids[max(j - len(stop_sequence), 0): j] == stop_sequence:
52 | should_be_stopped = True
53 | break
54 | if should_be_stopped:
55 | break
56 | sequences_should_be_stopped.append(should_be_stopped)
57 | return all(sequences_should_be_stopped)
58 |
59 |
60 | class KeyWordsCriteria(StoppingCriteria):
61 | def __init__(self, stop_id_sequences):
62 | assert isinstance(stop_id_sequences[0], list), "stop_id_sequences should be a list of list of ids"
63 | self.stop_sequences = stop_id_sequences
64 |
65 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
66 | sequences_should_be_stopped = []
67 | for i in range(input_ids.shape[0]):
68 | sequence_should_be_stopped = False
69 | for stop_sequence in self.stop_sequences:
70 | if input_ids[i][-len(stop_sequence):].tolist() == stop_sequence:
71 | sequence_should_be_stopped = True
72 | break
73 | sequences_should_be_stopped.append(sequence_should_be_stopped)
74 | return all(sequences_should_be_stopped)
75 |
76 |
77 | @torch.no_grad()
78 | def generate_completions(model, tokenizer, prompts, batch_size=1, stop_id_sequences=None, add_special_tokens=True, disable_tqdm=False, **generation_kwargs):
79 | generations = []
80 | if not disable_tqdm:
81 | progress = tqdm.tqdm(total=len(prompts), desc="Generating Completions")
82 |
83 | num_return_sequences = generation_kwargs.get("num_return_sequences", 1)
84 | for i in range(0, len(prompts), batch_size):
85 | batch_prompts = prompts[i:i+batch_size]
86 | tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=add_special_tokens)
87 | batch_input_ids = tokenized_prompts.input_ids
88 | attention_mask = tokenized_prompts.attention_mask
89 |
90 | if model.device.type == "cuda":
91 | batch_input_ids = batch_input_ids.cuda()
92 | attention_mask = attention_mask.cuda()
93 |
94 | # try:
95 | stop_criteria = KeywordsStoppingCriteria(stop_id_sequences, tokenizer)
96 | batch_outputs = model.generate(
97 | input_ids=batch_input_ids,
98 | attention_mask=attention_mask,
99 | stopping_criteria=StoppingCriteriaList([stop_criteria]),
100 | # stopping_criteria=[KeyWordsCriteria(stop_id_sequences)] if stop_id_sequences else None,
101 | # stopping_criteria=[KeyWordsCriteriaTrunc(stop_id_sequences, batch_input_ids.size(1))] if stop_id_sequences else None,
102 | **generation_kwargs
103 | )
104 |
105 | # the stopping criteria is applied at batch level, so if other examples are not stopped, the entire batch will continue to generate.
106 | # so some outputs still have the stop sequence, which we need to remove.
107 | # if stop_id_sequences:
108 | # for output_idx in range(batch_outputs.shape[0]):
109 | # for token_idx in range(batch_input_ids.shape[1], batch_outputs.shape[1]):
110 | # if any(batch_outputs[output_idx, token_idx: token_idx+len(stop_sequence)].tolist() == stop_sequence for stop_sequence in stop_id_sequences):
111 | # batch_outputs[output_idx, token_idx:] = tokenizer.pad_token_id
112 | # break
113 |
114 | # remove the prompt from the output
115 | # we need to re-encode the prompt because we need to make sure the special tokens are treated the same way as in the outputs.
116 |         # we changed our previous way of truncating the output token ids directly because some tokenizers (e.g., llama) won't add a space token before the first token.
117 | # space is important for some tasks (e.g., code completion).
118 | batch_outputs = tokenizer.batch_decode(batch_outputs, skip_special_tokens=True)
119 | batch_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True)
120 | # duplicate the prompts to match the number of return sequences
121 | batch_prompts = [prompt for prompt in batch_prompts for _ in range(num_return_sequences)]
122 | batch_generations = [
123 | output[len(prompt):] for prompt, output in zip(batch_prompts, batch_outputs)
124 | ]
125 |
126 | # remove the remain stop sequence from the output.
127 | for idx, prediction in enumerate(batch_generations):
128 | for stop_sequence in stop_id_sequences:
129 | batch_generations[idx] = prediction.split(stop_sequence)[0]
130 |
131 | generations += batch_generations
132 |
133 | if not disable_tqdm:
134 | progress.update(len(batch_prompts)//num_return_sequences)
135 |
136 | assert len(generations) == len(prompts) * num_return_sequences, "number of generations should be equal to number of prompts * num_return_sequences"
137 | return generations
138 |
139 |
140 | def load_hf_lm_and_tokenizer(
141 | model_name_or_path,
142 | tokenizer_name_or_path=None,
143 | device_map="auto",
144 | load_in_8bit=False,
145 | load_in_half=True,
146 | gptq_model=False,
147 | use_fast_tokenizer=False,
148 | padding_side="left",
149 | use_safetensors=False,
150 | ):
151 | import torch
152 | from transformers import AutoModelForCausalLM, AutoTokenizer
153 |
154 | if not tokenizer_name_or_path:
155 | tokenizer_name_or_path = model_name_or_path
156 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_fast=use_fast_tokenizer, padding_side=padding_side, trust_remote_code=True)
157 | # tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, legacy=False, use_fast=use_fast_tokenizer, padding_side=padding_side, trust_remote_code=True)
158 |
159 | # set pad token to eos token if pad token is not set
160 | if tokenizer.pad_token is None:
161 | if tokenizer.unk_token:
162 | tokenizer.pad_token = tokenizer.unk_token
163 | tokenizer.pad_token_id = tokenizer.unk_token_id
164 | elif tokenizer.eos_token:
165 | tokenizer.pad_token = tokenizer.eos_token
166 | tokenizer.pad_token_id = tokenizer.eos_token_id
167 | else:
168 | raise ValueError("You are using a new tokenizer without a pad token."
169 | "This is not supported by this script.")
170 |
171 | # if tokenizer.pad_token is None:
172 | # tokenizer.pad_token = tokenizer.unk_token
173 | # tokenizer.pad_token_id = tokenizer.unk_token_id
174 |
175 | if gptq_model:
176 | from auto_gptq import AutoGPTQForCausalLM
177 | model_wrapper = AutoGPTQForCausalLM.from_quantized(
178 | model_name_or_path, device="cuda:0", use_triton=True
179 | )
180 | model = model_wrapper.model
181 | elif load_in_8bit:
182 | model = AutoModelForCausalLM.from_pretrained(
183 | model_name_or_path,
184 | device_map=device_map,
185 | load_in_8bit=True
186 | )
187 | else:
188 | # return "", tokenizer
189 |         # default: load in float16
190 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
191 | torch_dtype=torch.float16,
192 | device_map=device_map,
193 | trust_remote_code=True,
194 | use_safetensors=use_safetensors)
195 | if torch.cuda.is_available():
196 | model = model.cuda()
197 | if load_in_half:
198 | model = model.half()
199 | model.eval()
200 | return model, tokenizer
201 |
202 |
203 | def _test_generate_completions():
204 | model_name_or_path = "../models/codellama_7b/v1-16k"
205 | llm, tokenizer = load_hf_lm_and_tokenizer(
206 | model_name_or_path=model_name_or_path,
207 | load_in_half=True,
208 | use_fast_tokenizer=True,
209 | use_safetensors=True,
210 | )
211 | # some math word problems
212 | prompts = [
213 | "---\n1+1=2\n---2+2=4\n---3+3=6\n---4+4=8\n---5+5=10\n---6+6=",
214 | "---\n1+1=2\n---12+12=24\n---3+3=6\n---12345+12345=",
215 | # "A train leaves Chicago at 7am and travels at 60mph. Another train leaves Chicago at 9am and travels at 80mph. When will the second train overtake the first?",
216 | # "The sum of two numbers is 10. The difference of the same two numbers is 4. What are the two numbers?",
217 | ]
218 |
219 | stop_sequences = ["\n\n\n", "---"]
220 | # Because many tokenizers will treat the word after space differently from the original word alone,
221 | # to be consistent, we add a space before tokenization and remove it after tokenization.
222 | # stop_id_sequences = [tokenizer.encode(" " + x, add_special_tokens=False)[1:] for x in stop_sequences]
223 | outputs = generate_completions(
224 | model=llm,
225 | tokenizer=tokenizer,
226 | prompts=prompts,
227 | max_new_tokens=128,
228 | batch_size=16,
229 | # stop_id_sequences=stop_id_sequences,
230 | stop_id_sequences=stop_sequences,
231 | )
232 | print(outputs)
233 |
234 | if __name__ == "__main__":
235 | _test_generate_completions()
--------------------------------------------------------------------------------
/src/applications/math_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import time
3 | import os
4 | import json
5 | import random
6 | import string
7 | from enum import Enum, auto
8 | from tqdm import tqdm
9 | from collections import Counter, OrderedDict  # Counter is needed by vote() below
10 | import dataclasses
11 | import pandas as pd
12 | import timeout_decorator
13 | import mpmath
14 | import sympy as sp
15 | from sympy.parsing.latex import parse_latex
17 | from sympy import simplify
18 | from sympy.printing import latex
19 | from sympy.core.relational import Relational
20 | from sympy.solvers.solveset import solvify
21 | from sympy.solvers.inequalities import reduce_inequalities
22 | from sympy.parsing.sympy_parser import (
23 | parse_expr,
24 | standard_transformations,
25 | implicit_multiplication,
26 | )
27 |
28 |
29 | def compare_numerical_ans(ans_p, ans_l):
30 | if ans_p is None:
31 | return False
32 | ans_p = ans_p.replace(",", "").replace("$", "")
33 | ans_l = ans_l.replace(",", "").replace("$", "")
34 | try:
35 | if ans_p.endswith("%"):
36 | ans_p = float(ans_p.rstrip("%")) / 100
37 | if isinstance(ans_p, str):
38 | ans_p = float(ans_p)
39 | if isinstance(ans_l, str):
40 | ans_l = float(ans_l)
41 | except Exception as e:
42 | return False
43 | return abs(ans_p - float(ans_l)) < 1e-3
44 |
45 |
46 | def my_parse_latex(expr_str):
47 | expr_str = expr_str.replace("dfrac", "frac")
48 | expr = parse_latex(expr_str)
49 | if "\\pi" in expr_str:
50 | expr = expr.subs({sp.Symbol("pi"): sp.pi})
51 | expr = expr.subs({sp.Symbol("i"): sp.I})
52 | return expr
53 |
54 |
55 | def is_number(element: str) -> bool:
56 | try:
57 | float(element.replace(" ", ""))
58 | return True
59 | except ValueError:
60 | return False
61 |
62 |
63 | def percentage_to_fraction(text):
64 | pattern = r"(\d+(\.\d+)?%)"
65 | matches = re.findall(pattern, text)
66 | for match in matches:
67 | percentage_str = match[0]
68 | percentage = float(percentage_str.strip("%")) / 100
69 | fraction = str(percentage)
70 | text = text.replace(percentage_str, fraction)
71 | return text
72 |
73 |
74 | def clean_expr_str(expr_str):
75 | expr_str = (
76 | expr_str.replace(" . ", ".")
77 | .replace(". ", ".")
78 | .replace("**", "^")
79 | .replace("\\pm", "")
80 | .replace("*", "\\times ")
81 | .replace("\\\\", "\\")
82 | .replace("\\ne ", "\\neq ")
83 | .replace("!=", "\\neq")
84 | .replace(">=", "\\ge")
85 | .replace("<=", "\\le")
86 | .replace("≠", "\\neq")
87 | .replace("dfrac", "frac")
88 | .replace("tfrac", "frac")
89 | .replace("\\$", "")
90 | .replace("$", "")
91 | .replace("\\%", "")
92 | .replace("%", "")
93 | .replace("\\!", "")
94 |         .replace("^\\circ", "\\times \\pi / 180")
95 | .replace("//", "/")
96 | .replace('"', "")
97 | # .replace(",", "") # TODO
98 | )
99 | # expr_str = re.sub(r"\^\s(.*)", r"\^\s{\1}", expr_str)
100 | expr_str = re.sub(r"\\+", r"\\", expr_str)
101 | expr_str = re.sub(r"\^\s?\((.*?)\)", r"^{\1}", expr_str)
102 | expr_str = re.sub(r"\\frac\s?(\d)\s?(\d+)", r"\\frac{\1}{\2}", expr_str)
103 | expr_str = re.sub(r"\\log_\s?(\d)\s?(\d+)", r"\\log_{\1}{\2}", expr_str)
104 | expr_str = re.sub(r"\\frac\s?{(.*?)}\s?(\d)", r"\\frac{\1}{\2}", expr_str)
105 | expr_str = re.sub(r"\\frac\s?(\d)\s?{(.*?)}", r"\\frac{\1}{\2}", expr_str)
106 | expr_str = re.sub(r"\\sqrt\s?(\d)", r"\\sqrt{\1}", expr_str)
107 | expr_str = re.sub(r"sqrt\s?\((\d+)\)", r"\\sqrt{\1}", expr_str)
108 | expr_str = re.sub(r"sqrt\s?\((.*?)\)", r"\\sqrt{\1}", expr_str)
109 | expr_str = expr_str.replace(" sqrt", "\\sqrt")
110 | expr_str = (
111 | expr_str.replace("\\left", "").replace("\\right.", "").replace("\\right", "")
112 | )
113 | return expr_str
114 |
115 |
116 | def parse_latex_answer(sample):
117 | if isinstance(sample, int) or isinstance(sample, float):
118 | sample = str(sample)
119 | # return sample
120 | sample = clean_expr_str(sample)
121 | try:
122 | expr = my_parse_latex(sample)
123 | except:
124 | print("[parse failed]", sample)
125 | return None
126 | return expr
127 |
128 |
129 | def my_equals(ans_p, ans_l):
130 | return ans_p.equals(ans_l)
131 |
132 |
133 | def is_expr_equal(ans_p, ans_l, is_strict=False):
134 | def is_equ_num_equal(equation, number):
135 | if (
136 | isinstance(equation, sp.Eq)
137 | # and isinstance(equation.lhs, sp.Symbol)
138 | and equation.rhs.is_number
139 | and number.is_number
140 | ):
141 | try:
142 | ret = my_equals(equation.rhs, number)
143 | return bool(ret)
144 | except:
145 | return equation.rhs == number
146 |
147 | if ans_p is None or ans_l is None:
148 | return False
149 | if isinstance(ans_l, str):
150 | return ans_p == ans_l
151 |
152 | if (
153 | not is_strict
154 | and is_equ_num_equal(ans_l, ans_p)
155 | or is_equ_num_equal(ans_p, ans_l)
156 | ):
157 | return True
158 |
159 | if ans_p.free_symbols != ans_l.free_symbols:
160 | return False
161 |
162 | if ans_p == ans_l:
163 | return True
164 |
165 | if isinstance(ans_l, sp.core.relational.Relational):
166 | try:
167 | if (
168 | type(ans_l) == type(ans_p)
169 | and my_equals(ans_p.lhs, ans_l.lhs)
170 | and my_equals(ans_p.rhs, ans_l.rhs)
171 | ):
172 | return True
173 | except Exception as e:
174 | print(ans_p, ans_l, e)
175 | try:
176 | ret = my_equals(ans_p, ans_l)
177 | return bool(ret)
178 | except:
179 | return False
180 |
181 |
182 | # @timeout_decorator.timeout(5)
183 | # def compare_ans(ans_p_str, ans_l_str, is_strict=False):
184 | # ans_p_str = clean_expr_str(ans_p_str)
185 | # ans_p_str = ans_p_str.replace(",", "").replace("$", "")
186 | # ans_l_str = clean_expr_str(ans_l_str)
187 | # ans_l_str = ans_l_str.replace(",", "").replace("$", "")
188 | # if ans_p_str is None:
189 | # return False
190 | # if ans_p_str.replace(" ", "") == ans_l_str.replace(" ", ""):
191 | # return True
192 | # ans_p = parse_latex_answer(ans_p_str)
193 | # if ans_p is None:
194 | # return False
195 | # ans_l = parse_latex_answer(ans_l_str)
196 | # if ans_l is None:
197 | # return False
198 | # return is_expr_equal(ans_p, ans_l, is_strict=is_strict)
199 |
200 |
201 | def extract_answer_number(sentence: str) -> float:
202 | sentence = sentence.replace(",", "")
203 | pred = [s for s in re.findall(r"-?\d+\.?\d*", sentence)]
204 | if not pred:
205 | return ""
206 | return pred[-1]
207 |
208 |
209 | @timeout_decorator.timeout(5)
210 | def compare_ans(ans_p_str, ans_l_str, is_strict=False):
211 | ans_p_str = clean_expr_str(ans_p_str)
212 | ans_p_str = ans_p_str.replace(",", "").replace("$", "")
213 | ans_l_str = clean_expr_str(ans_l_str)
214 | ans_l_str = ans_l_str.replace(",", "").replace("$", "")
215 | if ans_p_str is None:
216 | return False
217 | if ans_p_str.replace(" ", "") == ans_l_str.replace(" ", ""):
218 | return True
219 | ans_p = parse_latex_answer(ans_p_str)
220 | if ans_p is None:
221 | return False
222 | ans_l = parse_latex_answer(ans_l_str)
223 | if ans_l is None:
224 | return False
225 | if is_expr_equal(ans_p, ans_l, is_strict=is_strict):
226 | return True
227 | # TODO not suitable
228 | ans_p_str = extract_answer_number(ans_p_str)
229 | if is_number(ans_p_str):
230 | ans_p = parse_latex_answer(ans_p_str)
231 | if is_expr_equal(ans_p, ans_l, is_strict=is_strict):
232 | return True
233 | return False
234 |
235 |
236 | def vote(answers):
237 | counter = Counter(answers)
238 | return counter.most_common(1)[0][0]
239 |
240 |
241 | def contains_number(s):
242 | return any(i.isdigit() for i in s)
243 |
244 |
245 | def rough_compare_ans(generation, answer):
246 | for line in generation.split("\n")[::-1]:
247 | if contains_number(line):
248 | break
249 | words = line.split()
250 | for i, w in enumerate(words):
251 | if i > 0 and words[i - 1] in ["+", "-", "*", "/", "^"]:
252 | continue
253 | if i < len(words) - 1 and words[i + 1] in ["+", "-", "*", "/", "^"]:
254 | continue
255 | if not contains_number(w):
256 | continue
257 | if compare_numerical_ans(w.replace("$", ""), answer) and "=" not in " ".join(
258 | w[i:]
259 | ):
260 | return 1
261 | return 0
262 |
263 |
264 | if __name__ == "__main__":
265 | pred = "To find the minimum value of $f(x)$, we need to find the critical points where the derivative of $f(x)$ is either zero or undefined. Let's write a function to solve this problem:\n\n```python\nfrom sympy import symbols, diff, solve\n\ndef find_minimum_value():\n x, p = symbols('x p')\n f = abs(x - p) + abs(x - 15) + abs(x - p - 15)\n f_prime = diff(f, x)\n critical_points = solve(f_prime, x)\n \n # Evaluate f(x) at the critical points and endpoints\n values = [f.subs(x, p), f.subs(x, 15)]\n for point in critical_points:\n values.append(f.subs(x, point))\n \n # Find the minimum value\n minimum_value = min(values)\n \n return minimum_value\n\nresult = find_minimum_value()\nprint(result)\n```\n```output\nValueError(\"\\nCan't calculate derivative wrt re(x) + I*im(x).\")\n```\nIt seems that there is an error in the code. Let's modify the code to handle the case when the derivative is undefined.\n\n```python\nfrom sympy import symbols, diff, solve, oo\n\ndef find_minimum_value():\n x, p = symbols('x p')\n f = abs(x - p) + abs(x - 15) + abs(x - p - 15)\n f_prime = diff(f, x)\n critical_points = solve(f_prime, x)\n \n # Evaluate f(x) at the critical points and endpoints\n values = [f.subs(x, p), f.subs(x, 15)]\n for point in critical_points:\n if point.is_real:\n values.append(f.subs(x, point))\n \n # Find the minimum value\n minimum_value = min(values)\n \n return minimum_value\n\nresult = find_minimum_value()\nprint(result)\n```\n```output\nValueError(\"\\nCan't calculate derivative wrt re(x) + I*im(x).\")\n```\nApologies for the error in the code. It seems that the derivative calculation is causing an issue. Let's modify the code to handle the absolute value function and find the minimum value using a different approach.\n\n```python\nfrom sympy import symbols, Piecewise, Min, oo\n\ndef find_minimum_value():\n x, p = symbols('x p')\n f = abs(x - p) + abs(x - 15) + abs(x - p - 15)\n \n # Define the piecewise function for f(x)\n f_piecewise = Piecewise((f, x <= p), (f.subs(x, p), p < x <= 15), (f.subs(x, 15), x > 15))\n \n # Find the minimum value\n minimum_value = Min(f_piecewise, (x, p, 15))\n \n return minimum_value\n\nresult = find_minimum_value()\nprint(result)\n```\n```output\nTypeError('cannot determine truth value of Relational')\n```\nApologies for the error in the code. It seems that the piecewise function is causing an issue. Let's modify the code to handle the absolute value function and find the minimum value using a different approach.\n\n```python\nfrom sympy import symbols, Min, oo\n\ndef find_minimum_value():\n x, p = symbols('x p')\n f = abs(x - p) + abs(x - 15) + abs(x - p - 15)\n \n # Find the minimum value\n minimum_value = Min(f, (x, p, 15))\n \n return minimum_value\n\nresult = find_minimum_value()\nprint(result)\n```\n"
266 | label = "15"
267 | print(compare_ans(pred, label))
268 |
--------------------------------------------------------------------------------
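A few illustrative calls to `compare_ans` above. The expected results follow from the sympy-based equivalence logic in the file, assuming its dependencies (sympy's LaTeX parser, timeout_decorator, etc.) are installed.

    # Illustrative only; run from src/applications/ so math_utils is importable.
    from math_utils import compare_ans

    print(compare_ans("\\frac{1}{2}", "0.5"))   # True: 1/2 and 0.5 are numerically equal
    print(compare_ans("x + 1", "1 + x"))        # True: symbolically identical expressions
    print(compare_ans("x + 2", "x + 1"))        # False
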
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/src/applications/grader.py:
--------------------------------------------------------------------------------
1 | """
2 | This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:
3 | - https://github.com/microsoft/ProphetNet/tree/master/CRITIC
4 | - https://github.com/openai/prm800k
5 | - https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py
6 | - https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py
7 | """
8 |
9 | import re
10 | import regex
11 | import multiprocessing
12 | from math import isclose
13 | from typing import Union
14 | from collections import defaultdict
15 |
16 | from sympy import simplify, N
17 | from sympy.parsing.sympy_parser import parse_expr
18 | from sympy.parsing.latex import parse_latex
19 | from latex2sympy2 import latex2sympy
20 |
21 | # from .parser import choice_answer_clean, strip_string
22 | # from parser import choice_answer_clean
23 |
24 |
25 | def choice_answer_clean(pred: str):
26 | pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")
27 | # Clean the answer based on the dataset
28 | tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
29 | if tmp:
30 | pred = tmp
31 | else:
32 | pred = [pred.strip().strip(".")]
33 | pred = pred[-1]
34 | # Remove the period at the end, again!
35 | pred = pred.rstrip(".").rstrip("/")
36 | return pred
37 |
38 |
39 | def parse_digits(num):
40 | num = regex.sub(",", "", str(num))
41 | try:
42 | return float(num)
43 | except:
44 | if num.endswith("%"):
45 | num = num[:-1]
46 | if num.endswith("\\"):
47 | num = num[:-1]
48 | try:
49 | return float(num) / 100
50 | except:
51 | pass
52 | return None
53 |
54 |
55 | def is_digit(num):
56 | # paired with parse_digits
57 | return parse_digits(num) is not None
58 |
59 |
60 | def str_to_pmatrix(input_str):
61 | input_str = input_str.strip()
62 | matrix_str = re.findall(r"\{.*,.*\}", input_str)
63 | pmatrix_list = []
64 |
65 | for m in matrix_str:
66 | m = m.strip("{}")
67 | pmatrix = r"\begin{pmatrix}" + m.replace(",", "\\") + r"\end{pmatrix}"
68 | pmatrix_list.append(pmatrix)
69 |
70 | return ", ".join(pmatrix_list)
71 |
72 |
73 | def math_equal(
74 | prediction: Union[bool, float, str],
75 | reference: Union[float, str],
76 | include_percentage: bool = True,
77 | is_close: bool = True,
78 | timeout: bool = False,
79 | ) -> bool:
80 | """
81 | Exact match of math if and only if:
82 | 1. numerical equal: both can convert to float and are equal
83 | 2. symbolic equal: both can convert to sympy expression and are equal
84 | """
85 | # print("Judge:", prediction, reference)
86 | if prediction is None or reference is None:
87 | return False
88 | if str(prediction.strip().lower()) == str(reference.strip().lower()):
89 | return True
90 | if (
91 | reference in ["A", "B", "C", "D", "E"]
92 | and choice_answer_clean(prediction) == reference
93 | ):
94 | return True
95 |
96 | try: # 1. numerical equal
97 | if is_digit(prediction) and is_digit(reference):
98 | prediction = parse_digits(prediction)
99 | reference = parse_digits(reference)
100 | # number questions
101 | if include_percentage:
102 | gt_result = [reference / 100, reference, reference * 100]
103 | else:
104 | gt_result = [reference]
105 | for item in gt_result:
106 | try:
107 | if is_close:
108 | if numeric_equal(prediction, item):
109 | return True
110 | else:
111 | if item == prediction:
112 | return True
113 | except Exception:
114 | continue
115 | return False
116 | except:
117 | pass
118 |
119 | if not prediction and prediction not in [0, False]:
120 | return False
121 |
122 | # 2. symbolic equal
123 | reference = str(reference).strip()
124 | prediction = str(prediction).strip()
125 |
126 | ## pmatrix (amps)
127 | if "pmatrix" in prediction and not "pmatrix" in reference:
128 | reference = str_to_pmatrix(reference)
129 |
130 | ## deal with [], (), {}
131 | pred_str, ref_str = prediction, reference
132 | if (
133 | prediction.startswith("[")
134 | and prediction.endswith("]")
135 | and not reference.startswith("(")
136 | ) or (
137 | prediction.startswith("(")
138 | and prediction.endswith(")")
139 | and not reference.startswith("[")
140 | ):
141 | pred_str = pred_str.strip("[]()")
142 | ref_str = ref_str.strip("[]()")
143 | for s in ["{", "}", "(", ")"]:
144 | ref_str = ref_str.replace(s, "")
145 | pred_str = pred_str.replace(s, "")
146 | if pred_str.lower() == ref_str.lower():
147 | return True
148 |
149 | ## [a, b] vs. [c, d], return a==c and b==d
150 | if (
151 | regex.match(r"(\(|\[).+(\)|\])", prediction) is not None
152 | and regex.match(r"(\(|\[).+(\)|\])", reference) is not None
153 | ):
154 | pred_parts = prediction[1:-1].split(",")
155 | ref_parts = reference[1:-1].split(",")
156 | if len(pred_parts) == len(ref_parts):
157 | if all(
158 | [
159 | math_equal(
160 | pred_parts[i], ref_parts[i], include_percentage, is_close
161 | )
162 | for i in range(len(pred_parts))
163 | ]
164 | ):
165 | return True
166 | if (
167 | (
168 | prediction.startswith("\\begin{pmatrix}")
169 | or prediction.startswith("\\begin{bmatrix}")
170 | )
171 | and (
172 | prediction.endswith("\\end{pmatrix}")
173 | or prediction.endswith("\\end{bmatrix}")
174 | )
175 | and (
176 | reference.startswith("\\begin{pmatrix}")
177 | or reference.startswith("\\begin{bmatrix}")
178 | )
179 | and (
180 | reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")
181 | )
182 | ):
183 | pred_lines = [
184 | line.strip()
185 | for line in prediction[
186 | len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
187 | ].split("\\\\")
188 | if line.strip()
189 | ]
190 | ref_lines = [
191 | line.strip()
192 | for line in reference[
193 | len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
194 | ].split("\\\\")
195 | if line.strip()
196 | ]
197 | matched = True
198 | if len(pred_lines) == len(ref_lines):
199 | for pred_line, ref_line in zip(pred_lines, ref_lines):
200 | pred_parts = pred_line.split("&")
201 | ref_parts = ref_line.split("&")
202 | if len(pred_parts) == len(ref_parts):
203 | if not all(
204 | [
205 | math_equal(
206 | pred_parts[i],
207 | ref_parts[i],
208 | include_percentage,
209 | is_close,
210 | )
211 | for i in range(len(pred_parts))
212 | ]
213 | ):
214 | matched = False
215 | break
216 | else:
217 | matched = False
218 | if not matched:
219 | break
220 | else:
221 | matched = False
222 | if matched:
223 | return True
224 |
225 | if prediction.count("=") == 1 and reference.count("=") == 1:
226 | pred = prediction.split("=")
227 | pred = f"{pred[0].strip()} - ({pred[1].strip()})"
228 | ref = reference.split("=")
229 | ref = f"{ref[0].strip()} - ({ref[1].strip()})"
230 | if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
231 | return True
232 | elif (
233 | prediction.count("=") == 1
234 | and len(prediction.split("=")[0].strip()) <= 2
235 | and "=" not in reference
236 | ):
237 | if math_equal(
238 | prediction.split("=")[1], reference, include_percentage, is_close
239 | ):
240 | return True
241 | elif (
242 | reference.count("=") == 1
243 | and len(reference.split("=")[0].strip()) <= 2
244 | and "=" not in prediction
245 | ):
246 | if math_equal(
247 | prediction, reference.split("=")[1], include_percentage, is_close
248 | ):
249 | return True
250 |
251 | # symbolic equal with sympy
252 | if timeout:
253 | if call_with_timeout(symbolic_equal_process, prediction, reference):
254 | return True
255 | else:
256 | if symbolic_equal(prediction, reference):
257 | return True
258 |
259 | return False
260 |
261 |
262 | def math_equal_process(param):
263 | return math_equal(param[-2], param[-1])
264 |
265 |
266 | def numeric_equal(prediction: float, reference: float):
267 | # Note that relative tolerance has significant impact
268 | # on the result of the synthesized GSM-Hard dataset
269 | # if reference.is_integer():
270 | # return isclose(reference, round(prediction), abs_tol=1e-4)
271 | # else:
272 | # prediction = round(prediction, len(str(reference).split(".")[-1]))
273 | return isclose(reference, prediction, rel_tol=1e-4)
274 |
275 |
276 | def symbolic_equal(a, b):
277 | def _parse(s):
278 | for f in [parse_latex, parse_expr, latex2sympy]:
279 | try:
280 | return f(s.replace("\\\\", "\\"))
281 | except:
282 | try:
283 | return f(s)
284 | except:
285 | pass
286 | return s
287 |
288 | a = _parse(a)
289 | b = _parse(b)
290 |
291 | # direct equal
292 | try:
293 | if str(a) == str(b) or a == b:
294 | return True
295 | except:
296 | pass
297 |
298 | # simplify equal
299 | try:
300 | if a.equals(b) or simplify(a - b) == 0:
301 | return True
302 | except:
303 | pass
304 |
305 | # equation equal
306 | try:
307 | if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)):
308 | return True
309 | except:
310 | pass
311 |
312 | try:
313 | if numeric_equal(float(N(a)), float(N(b))):
314 | return True
315 | except:
316 | pass
317 |
318 | # matrix
319 | try:
320 |         # if a and b are matrices
321 | if a.shape == b.shape:
322 | _a = a.applyfunc(lambda x: round(x, 3))
323 | _b = b.applyfunc(lambda x: round(x, 3))
324 | if _a.equals(_b):
325 | return True
326 | except:
327 | pass
328 |
329 | return False
330 |
331 |
332 | def symbolic_equal_process(a, b, output_queue):
333 | result = symbolic_equal(a, b)
334 | output_queue.put(result)
335 |
336 |
337 | def call_with_timeout(func, *args, timeout=1, **kwargs):
338 | output_queue = multiprocessing.Queue()
339 | process_args = args + (output_queue,)
340 | process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
341 | process.start()
342 | process.join(timeout)
343 |
344 | if process.is_alive():
345 | process.terminate()
346 | process.join()
347 | return False
348 |
349 | return output_queue.get()
350 |
351 | def _test_math_equal():
352 | # print(math_equal("0.0833333333333333", "\\frac{1}{12}"))
353 | # print(math_equal("(1,4.5)", "(1,\\frac{9}{2})"))
354 | # print(math_equal("\\frac{x}{7}+\\frac{2}{7}", "\\frac{x+2}{7}", timeout=True))
355 | # print(math_equal("\\sec^2(y)", "\\tan^2(y)+1", timeout=True))
356 | # print(math_equal("\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\end{pmatrix}", "(\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\\\\\end{pmatrix})", timeout=True))
357 |
358 | # pred = '\\begin{pmatrix}\\frac{1}{3x^{2/3}}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\end{pmatrix}'
359 | # gt = '(\\begin{pmatrix}\\frac{1}{3\\sqrt[3]{x}^2}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\\\\\end{pmatrix})'
360 |
361 | # pred= '-\\frac{8x^2}{9(x^2-2)^{5/3}}+\\frac{2}{3(x^2-2)^{2/3}}'
362 | # gt= '-\\frac{2(x^2+6)}{9(x^2-2)\\sqrt[3]{x^2-2}^2}'
363 |
364 | # pred = '-34x-45y+20z-100=0'
365 | # gt = '34x+45y-20z+100=0'
366 |
367 | # pred = '\\frac{100}{3}'
368 | # gt = '33.3'
369 |
370 | # pred = '\\begin{pmatrix}0.290243531202435\\\\0.196008371385084\\\\-0.186381278538813\\end{pmatrix}'
371 | # gt = '(\\begin{pmatrix}0.29\\\\0.196\\\\-0.186\\\\\\end{pmatrix})'
372 |
373 | # pred = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{2\\sqrt{33}+15}'
374 | # gt = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{15+2\\sqrt{33}}'
375 |
376 | # pred = '(+5)(b+2)'
377 | # gt = '(a+5)(b+2)'
378 |
379 | # pred = '\\frac{1+\\sqrt{5}}{2}'
380 | # gt = '2'
381 |
382 | # pred = '\\frac{34}{16}+\\frac{\\sqrt{1358}}{16}', gt = '4'
383 | # pred = '1', gt = '1\\\\sqrt{19}'
384 |
385 | # pred = "(0.6,2.6667]"
386 | # gt = "(\\frac{3}{5},\\frac{8}{3}]"
387 |
388 | gt = "x+2n+1"
389 | pred = "x+1"
390 |
391 | print(math_equal(pred, gt, timeout=True))
392 |
393 |
394 | if __name__ == "__main__":
395 | _test_math_equal()
396 |
--------------------------------------------------------------------------------
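Note: a minimal usage sketch (illustrative, not part of the repository) for the answer-checking utilities above, assuming the module is importable as `grader` and that sympy/latex2sympy2 and their parsers are installed; the inputs are taken from the commented tests in `_test_math_equal`.

    # check_sketch.py (hypothetical)
    from grader import math_equal

    if __name__ == "__main__":
        # decimal vs. LaTeX fraction: accepted via sympy numeric evaluation (rel_tol=1e-4)
        print(math_equal("0.0833333333333333", "\\frac{1}{12}"))
        # tuple answers are split on "," and compared element-wise
        print(math_equal("(1,4.5)", "(1,\\frac{9}{2})"))
        # symbolic equivalence, guarded by the 1-second subprocess timeout
        print(math_equal("\\frac{x}{7}+\\frac{2}{7}", "\\frac{x+2}{7}", timeout=True))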
/src/applications/TTTS_evaluate.py:
--------------------------------------------------------------------------------
1 | import random
2 | import os
3 | import argparse
4 | import time
5 | import json
6 | import re
7 | import torch
8 | from datetime import datetime
9 | from tqdm import tqdm
10 | from transformers import LogitsProcessor, AutoTokenizer, AutoModelForCausalLM
11 | from vllm import LLM, SamplingParams
12 | from evaluate import evaluate
13 | from utils import set_seed, load_jsonl, save_jsonl, construct_prompt
14 | from parser import *
15 | from trajectory import *
16 | from data_loader import load_data
17 | from python_executor import PythonExecutor
18 | from model_utils import load_hf_lm_and_tokenizer, generate_completions
19 |
20 | def parse_args():
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument("--data_names", default="gsm8k,math", type=str)
23 | parser.add_argument("--data_dir", default="./data", type=str)
24 | parser.add_argument("--save_dir", default="./data", type=str)
25 | parser.add_argument("--model_name_or_path", default="gpt-4", type=str)
26 | parser.add_argument("--output_dir", default="./output", type=str)
27 | parser.add_argument("--prompt_type", default="tool-integrated", type=str)
28 | parser.add_argument("--split", default="test", type=str)
29 | parser.add_argument("--num_test_sample", default=-1, type=int) # -1: all data
30 | parser.add_argument("--seed", default=0, type=int)
31 | parser.add_argument("--start", default=0, type=int)
32 | parser.add_argument("--end", default=-1, type=int)
33 | parser.add_argument("--temperature", default=0, type=float)
34 | parser.add_argument("--n_sampling", default=1, type=int)
35 | parser.add_argument("--top_p", default=1, type=float)
36 | parser.add_argument("--shuffle", action="store_true")
37 | parser.add_argument("--use_vllm", action="store_true")
38 | parser.add_argument("--overwrite", action="store_true")
39 | parser.add_argument("--use_safetensors", action="store_true")
40 | parser.add_argument("--apply_chat_template", action="store_true",)
41 | parser.add_argument("--pipeline_parallel_size", type=int, default=1)
42 | parser.add_argument("--adapt_few_shot", action="store_true")
43 | parser.add_argument("--num_shots", type=int, default=0)
44 |
45 | # TTTS related
46 | parser.add_argument("--thinking_tokens_file_path", type=str, default=None)
47 | parser.add_argument("--max_tokens_per_call", default=2048, type=int)
48 | parser.add_argument("--thinking_token", default="", type=str)
49 | parser.add_argument("--token_budget", default=2048, type=int)
50 | args = parser.parse_args()
51 |
52 | args.top_p = 1 if args.temperature == 0 else args.top_p
53 | return args
54 |
55 |
56 | def load_model_and_tokenizer(args):
57 | available_gpus = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")
58 | if args.use_vllm:
59 | llm = LLM(
60 | model=args.model_name_or_path,
61 | tensor_parallel_size=len(available_gpus) // args.pipeline_parallel_size,
62 | pipeline_parallel_size=args.pipeline_parallel_size,
63 | trust_remote_code=True,
64 | )
65 | tokenizer = None
66 | if args.apply_chat_template:
67 | tokenizer = AutoTokenizer.from_pretrained(
68 | args.model_name_or_path, trust_remote_code=True
69 | )
70 | else:
71 | llm, tokenizer = load_hf_lm_and_tokenizer(
72 | model_name_or_path=args.model_name_or_path,
73 | load_in_half=True,
74 | use_fast_tokenizer=True,
75 | use_safetensors=args.use_safetensors,
76 | )
77 | return llm, tokenizer
78 |
79 |
80 | def prepare_data(data_name, args):
81 | examples = load_data(data_name, args.split, args.data_dir)
82 |
83 | if args.num_test_sample > 0:
84 | examples = examples[: args.num_test_sample]
85 |
86 | if args.shuffle:
87 | random.seed(datetime.now().timestamp())
88 | random.shuffle(examples)
89 |
90 | examples = examples[args.start: (len(examples) if args.end == -1 else args.end)]
91 |
92 | dt_string = datetime.now().strftime("%m-%d_%H-%M")
93 | out_file_prefix = f"{args.split}_{args.prompt_type}_{args.num_test_sample}_seed{args.seed}_t{args.temperature}"
94 | output_dir = args.output_dir
95 | output_dir = os.path.join("outputs", output_dir)
96 | os.makedirs(os.path.join(output_dir, data_name), exist_ok=True)
97 | out_file = os.path.join(output_dir, data_name, f"{out_file_prefix}_s{args.start}_e{args.end}.jsonl")
98 |
99 | processed_samples = []
100 | if not args.overwrite:
101 | processed_files = [
102 | f for f in os.listdir(os.path.join(output_dir, data_name))
103 | if f.endswith(".jsonl") and f.startswith(out_file_prefix)
104 | ]
105 | for f in processed_files:
106 | processed_samples.extend(list(load_jsonl(os.path.join(output_dir, data_name, f))))
107 |
108 | processed_samples_dict = {sample["idx"]: sample for sample in processed_samples}
109 | processed_idxs = set(processed_samples_dict.keys())
110 | examples = [example for example in examples if example["idx"] not in processed_idxs]
111 |
112 | return examples, list(processed_samples_dict.values()), out_file
113 |
114 |
115 | def setup(args):
116 | llm, tokenizer = load_model_and_tokenizer(args)
117 | data_list = args.data_names.split(",")
118 | results = []
119 | for data_name in data_list:
120 | results.append(main(llm, tokenizer, data_name, args))
121 |
122 | avg_acc = sum(result["acc"] for result in results) / len(results)
123 | data_list.append("avg")
124 | results.append({"acc": avg_acc})
125 |
126 | pad = max(len(name) for name in data_list)
127 | print("\t".join(name.ljust(pad, " ") for name in data_list))
128 | print("\t".join(f"{result['acc']:.1f}".ljust(pad, " ") for result in results))
129 |
130 |
131 | def is_multi_choice(answer):
132 | return all(c in ["A", "B", "C", "D", "E"] for c in answer)
133 |
134 |
135 | # ------------------------------ Thinking Token based Test-time Scaling ----------------------------------
136 | def batch_TTTS_generation(prompts, llm, tokenizer, args, thinking_token, stop_words):
137 | base_budget = args.max_tokens_per_call
138 | token_budget = args.token_budget
139 |
140 |
141 | ### 1. Base generation
142 | sampling_params = SamplingParams(
143 | max_tokens=base_budget,
144 | top_p=args.top_p,
145 | stop=stop_words,
146 | stop_token_ids=(
147 | [151645, 151643]
148 | if "qwen2" in args.model_name_or_path.lower()
149 | else None
150 | ),
151 | skip_special_tokens=False,
152 | temperature=0.0,
153 | )
154 | vllm_outputs = llm.generate(prompts, sampling_params)
155 |
156 | outputs = []
157 | for output in vllm_outputs:
158 | generated_text = output.outputs[0].text.strip()
159 | outputs.append(generated_text)
160 |
161 |
162 | ### 2. Generation with thinking tokens
163 | budget_prompts = []
164 |
165 | for q, a in zip(prompts, outputs):
166 | concat_prompt = f'{q} {a} {thinking_token}'
167 | budget_prompts.append(concat_prompt)
168 |
169 | sampling_params = SamplingParams(
170 | max_tokens=token_budget,
171 | top_p=args.top_p,
172 | stop=stop_words,
173 | stop_token_ids=(
174 | [151645, 151643]
175 | if "qwen2" in args.model_name_or_path.lower()
176 | else None
177 | ),
178 | skip_special_tokens=False,
179 | temperature=0.0,
180 | )
181 | vllm_outputs_budget = llm.generate(budget_prompts, sampling_params)
182 |
183 | outputs_budget = []
184 | for output in vllm_outputs_budget:
185 | generated_text = output.outputs[0].text.strip()
186 | outputs_budget.append(generated_text)
187 |
188 |
189 | ### 3. Final output
190 | final_prompts = []
191 | for q, a in zip(budget_prompts, outputs_budget):
192 | concat_prompt = q + a + r" Final Answer within \boxed{}:"
193 | final_prompts.append(concat_prompt)
194 |
195 |
196 | sampling_params = SamplingParams(
197 | max_tokens=16,
198 | top_p=args.top_p,
199 | stop=stop_words,
200 | stop_token_ids=(
201 | [151645, 151643]
202 | if "qwen2" in args.model_name_or_path.lower()
203 | else None
204 | ),
205 | skip_special_tokens=False,
206 | temperature=0.0,
207 | )
208 | final_vllm_outputs = llm.generate(final_prompts, sampling_params)
209 |
210 | final_outputs = []
211 | for output in final_vllm_outputs:
212 | generated_text = output.outputs[0].text.strip()
213 | final_outputs.append(generated_text)
214 |
215 | return final_vllm_outputs
216 |
217 |
218 | def main(llm, tokenizer, data_name, args):
219 | examples, processed_samples, out_file = prepare_data(data_name, args)
220 | print("Dataset:", data_name, "Total samples:", len(examples))
221 | if examples:
222 | print(examples[0])
223 |
224 | if "pal" in args.prompt_type:
225 | executor = PythonExecutor(get_answer_expr="solution()")
226 | else:
227 | executor = PythonExecutor(get_answer_from_stdout=True)
228 |
229 | samples = []
230 | for example in tqdm(examples, total=len(examples)):
231 | idx = example["idx"]
232 | example["question"] = parse_question(example, data_name)
233 | if example["question"] == "":
234 | continue
235 | gt_cot, gt_ans = parse_ground_truth(example, data_name)
236 | example["gt_ans"] = gt_ans
237 | full_prompt = construct_prompt(example, data_name, args)
238 |
239 | if idx == args.start:
240 | print(full_prompt)
241 |
242 | sample = {
243 | "idx": idx,
244 | "question": example["question"],
245 | "gt_cot": gt_cot,
246 | "gt": gt_ans,
247 | "prompt": full_prompt,
248 | }
249 | for key in [
250 | "level", "type", "unit", "solution_type", "choices", "solution",
251 | "ques_type", "ans_type", "answer_type", "dataset", "subfield", "filed",
252 | "theorem", "answer",
253 | ]:
254 | if key in example:
255 | sample[key] = example[key]
256 | samples.append(sample)
257 |
258 | input_prompts = [sample["prompt"] for sample in samples for _ in range(args.n_sampling)]
259 | if args.apply_chat_template and tokenizer is not None:
260 | input_prompts = [
261 | tokenizer.apply_chat_template(
262 | [{"role": "user", "content": prompt.strip()}],
263 | tokenize=False,
264 | add_generation_prompt=True,
265 | )
266 | for prompt in input_prompts
267 | ]
268 | remain_prompts = list(enumerate(input_prompts))
269 | end_prompts = []
270 |
271 | max_func_call = 1 if args.prompt_type in ["cot", "pal"] else 4
272 |     stop_words = ["</s>", "<|im_end|>", "<|endoftext|>"]
273 | if args.prompt_type in ["cot"]:
274 | stop_words.append("\n\nQuestion:")
275 | if args.prompt_type in ["pal", "tool-integrated", "jiuzhang_tora"]:
276 | stop_words.extend(["\n\n---", "```output"])
277 | elif args.prompt_type in ["wizard_zs", "platypus_fs"]:
278 | stop_words.extend(["Instruction", "Response"])
279 | elif "jiuzhang" in args.prompt_type:
280 | stop_words.append("\n\n## Question")
281 | elif "numina" in args.prompt_type:
282 | stop_words.append("\n### Problem")
283 | elif "pure" in args.prompt_type:
284 | stop_words.append("\n\n\n")
285 |
286 | start_time = time.time()
287 | for epoch in range(max_func_call):
288 | print("-" * 20, "Epoch", epoch)
289 | current_prompts = remain_prompts
290 | if not current_prompts:
291 | break
292 |
293 | prompts = [item[1] for item in current_prompts]
294 |
295 | outputs = batch_TTTS_generation(
296 | prompts=prompts,
297 | llm=llm,
298 | tokenizer=tokenizer,
299 | args=args,
300 | thinking_token=args.thinking_token,
301 | stop_words=stop_words,
302 | )
303 | outputs = sorted(outputs, key=lambda x: int(x.request_id))
304 | outputs = [output.outputs[0].text for output in outputs]
305 |
306 | assert len(outputs) == len(current_prompts)
307 |
308 | remain_prompts = []
309 | remain_codes = []
310 | for (i, query), output in zip(current_prompts, outputs):
311 | output = output.rstrip()
312 | query += output
313 | if args.prompt_type == "pal":
314 | remain_prompts.append((i, query))
315 | if "```python" in output:
316 | output = extract_program(query)
317 | remain_codes.append(output)
318 | elif args.prompt_type == "cot":
319 | end_prompts.append((i, query))
320 | elif "boxed" not in output and output.endswith("```"):
321 | program = extract_program(query)
322 | remain_prompts.append((i, query))
323 | remain_codes.append(program)
324 | else:
325 | end_prompts.append((i, query))
326 |
327 | remain_results = executor.batch_apply(remain_codes)
328 | for k in range(len(remain_prompts)):
329 | i, query = remain_prompts[k]
330 | res, report = remain_results[k]
331 | exec_result = res if res else report
332 | if "pal" in args.prompt_type:
333 | exec_result = "\\boxed{" + exec_result + "}"
334 | exec_result = f"\n```output\n{exec_result}\n```\n"
335 | query += exec_result
336 | if epoch == max_func_call - 1:
337 | query += "\nReach max function call limit."
338 | remain_prompts[k] = (i, query)
339 |
340 | end_prompts.extend(remain_prompts)
341 | end_prompts = sorted(end_prompts, key=lambda x: x[0])
342 |
343 | codes = []
344 | assert len(input_prompts) == len(end_prompts)
345 | for i in range(len(input_prompts)):
346 | _, end_prompt = end_prompts[i]
347 | code = end_prompt.split(input_prompts[i])[-1].strip()
348 | for stop_word in stop_words:
349 | if stop_word in code:
350 | code = code.split(stop_word)[0].strip()
351 | codes.append(code)
352 |
353 | results = [run_execute(executor, code, args.prompt_type, data_name) for code in codes]
354 | time_use = time.time() - start_time
355 |
356 | all_samples = []
357 | for i, sample in enumerate(samples):
358 | code_list = codes[i * args.n_sampling: (i + 1) * args.n_sampling]
359 | result_list = results[i * args.n_sampling: (i + 1) * args.n_sampling]
360 | preds = [item[0] for item in result_list]
361 | reports = [item[1] for item in result_list]
362 | for j in range(len(preds)):
363 | if sample["gt"] in ["A", "B", "C", "D", "E"] and preds[j] not in ["A", "B", "C", "D", "E"]:
364 | preds[j] = choice_answer_clean(code_list[j])
365 | elif is_multi_choice(sample["gt"]) and not is_multi_choice(preds[j]):
366 | preds[j] = "".join([c for c in preds[j] if c in ["A", "B", "C", "D", "E"]])
367 | sample.pop("prompt", None)
368 | sample.update({"code": code_list, "pred": preds, "report": reports})
369 | all_samples.append(sample)
370 |
371 | all_samples.extend(processed_samples)
372 | all_samples, result_json = evaluate(
373 | samples=all_samples,
374 | data_name=data_name,
375 | prompt_type=args.prompt_type,
376 | execute=True,
377 | )
378 |     result_json["time_use_in_second"] = time_use
379 |     result_json["time_use_in_minute"] = f"{int(time_use // 60)}:{int(time_use % 60):02d}"
380 |     result_json["word"] = args.thinking_token
381 |     out_file_metrics = out_file.replace(".jsonl", f"_{args.prompt_type}_{args.thinking_token}_metrics.json")
382 |     with open(out_file_metrics, "w") as f:
383 |         json.dump(result_json, f, indent=4)
384 |     return result_json
385 |
386 |
387 | if __name__ == "__main__":
388 | args = parse_args()
389 | set_seed(args.seed)
390 |
391 | llm, tokenizer = load_model_and_tokenizer(args)
392 | prompt_words = load_jsonl(args.thinking_tokens_file_path)
393 |
394 | prompt_words = [
395 | entry for entry in prompt_words
396 | if re.fullmatch(r'[A-Za-z]+', entry.get('word', ''))
397 | ]
398 |
399 | output_file = f"{args.output_dir}/TTTS_results.json"
400 |
401 | aggregated_results = {}
402 | data_list = args.data_names.split(",")
403 | for data_name in data_list:
404 | aggregated_results[data_name] = {}
405 | for entry in prompt_words:
406 | word = entry["word"]
407 | args.thinking_token = word
408 | print(f"Evaluating dataset {data_name}, thinking token: {word}")
409 | result_json = main(llm, tokenizer, data_name, args)
410 | aggregated_results[data_name][word] = result_json
411 |
412 | print('aggregated_results:', aggregated_results)
413 | with open(output_file, "w") as f:
414 | json.dump(aggregated_results, f, indent=2)
415 |
--------------------------------------------------------------------------------
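Note: a small hypothetical helper (not part of the repository) that writes the JSONL file consumed via --thinking_tokens_file_path by TTTS_evaluate.py above; the file name and example words are placeholders. Only records whose "word" field is purely alphabetic survive the re.fullmatch filter in the __main__ block, and each surviving word is evaluated in turn as the thinking token.

    # make_thinking_tokens.py (hypothetical)
    import json

    words = ["Wait", "Alternatively", "Hmm"]  # placeholder candidates
    with open("thinking_tokens.jsonl", "w") as f:
        for w in words:
            f.write(json.dumps({"word": w}) + "\n")
    # Pass the resulting path via --thinking_tokens_file_path; per-word accuracies are
    # aggregated into {output_dir}/TTTS_results.json.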
/src/applications/RR_evaluate.py:
--------------------------------------------------------------------------------
1 | import random
2 | import os
3 | import argparse
4 | import time
5 | import re
6 | import json
7 | import torch
8 |
9 | from datetime import datetime
10 | from tqdm import tqdm
11 | from transformers import AutoTokenizer, AutoModelForCausalLM
12 | from evaluate import evaluate
13 | from utils import set_seed, load_jsonl, save_jsonl, construct_prompt
14 | from parser import *
15 | from trajectory import *
16 | from data_loader import load_data
17 | from python_executor import PythonExecutor
18 | from model_utils import load_hf_lm_and_tokenizer, generate_completions
19 | from RR_model import RecursiveThinkingModel
20 |
21 |
22 |
23 |
24 | def parse_args():
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument("--data_names", default="aime24", type=str)
27 | parser.add_argument("--data_dir", default="./data", type=str)
28 | parser.add_argument("--model_name_or_path", default="gpt-4", type=str)
29 | parser.add_argument("--output_dir", default="./output", type=str)
30 | parser.add_argument("--prompt_type", default="tool-integrated", type=str)
31 | parser.add_argument("--split", default="test", type=str)
32 | parser.add_argument("--num_test_sample", default=-1, type=int) # -1 means use all data
33 | parser.add_argument("--seed", default=0, type=int)
34 | parser.add_argument("--start", default=0, type=int)
35 | parser.add_argument("--end", default=-1, type=int)
36 | parser.add_argument("--temperature", default=0, type=float)
37 | parser.add_argument("--n_sampling", default=1, type=int)
38 | parser.add_argument("--top_p", default=1, type=float)
39 | parser.add_argument("--max_tokens_per_call", default=16000, type=int)
40 | parser.add_argument("--shuffle", action="store_true")
41 | parser.add_argument("--use_vllm", action="store_true")
42 | parser.add_argument("--save_outputs", action="store_true")
43 | parser.add_argument("--overwrite", action="store_true")
44 | parser.add_argument("--use_safetensors", action="store_true")
45 | parser.add_argument("--num_shots", type=int, default=0)
46 | parser.add_argument("--apply_chat_template", action="store_true", help="Apply chat template to prompts")
47 | parser.add_argument("--pipeline_parallel_size", type=int, default=1)
48 | parser.add_argument("--adapt_few_shot", action="store_true", help="Use few-shot examples for multiple-choice questions, zero-shot for others")
49 |
50 | # Add RecursiveThinkingModel specific parameters
51 | parser.add_argument("--use_recursive_thinking", type=bool, default=True, help="Enable recursive thinking feature")
52 | parser.add_argument("--extract_layer_id", type=int, default=-1, help="Layer ID for extracting hidden states")
53 | parser.add_argument("--inject_layer_id", type=int, default=-1, help="Layer ID for injecting hidden states")
54 | parser.add_argument("--num_recursive_steps", type=int, default=1, help="Number of recursive thinking steps")
55 | parser.add_argument("--interested_tokens", help="List of token IDs to apply recursive thinking, can be a string or a preprocessed list")
56 | parser.add_argument("--interested_tokens_file_path", help="interested_tokens_file_path")
57 | parser.add_argument("--output_file",default="outputs.jsonl", help="Output address for responses")
58 | args = parser.parse_args()
59 | args.top_p = (1 if args.temperature == 0 else args.top_p) # top_p must be 1 when using greedy sampling
60 | return args
61 |
62 |
63 |
64 | def setup(args):
65 | """Set up the model and evaluation environment"""
66 | # Output debug information about interested_tokens
67 | print(f"Type of interested_tokens: {type(args.interested_tokens)}")
68 | print(f"Length of interested_tokens: {len(args.interested_tokens) if hasattr(args.interested_tokens, '__len__') else 'N/A'}")
69 | if args.interested_tokens is not None and hasattr(args.interested_tokens, '__len__') and len(args.interested_tokens) > 0:
70 | print(f"Elements of interested_tokens: {args.interested_tokens}")
71 |
72 | # Load the model
73 | available_gpus = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")
74 |
75 |
76 | # Load RecursiveThinkingModel
77 | if args.use_recursive_thinking:
78 | print("Loading RecursiveThinkingModel...")
79 | base_model = args.model_name_or_path
80 |
81 | # Process the list of interested tokens
82 | interested_tokens = args.interested_tokens
83 | print(f"Loaded interested_tokens_ids: Type={type(interested_tokens)}, Length={len(interested_tokens) if interested_tokens else 0}")
84 | if interested_tokens and len(interested_tokens) > 0:
85 | print(f"Loaded token IDs: {interested_tokens}")
86 |
87 | # Load tokenizer separately for preprocessing
88 | tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
89 | if tokenizer.pad_token is None:
90 | tokenizer.pad_token = tokenizer.eos_token
91 | # Create RecursiveThinkingModel
92 | llm = RecursiveThinkingModel(
93 | base_model_name=base_model,
94 | extract_layer_id=args.extract_layer_id,
95 | inject_layer_id=args.inject_layer_id,
96 | num_recursive_steps=args.num_recursive_steps,
97 | use_recursive_thinking=args.use_recursive_thinking,
98 | output_file=args.output_file,
99 | )
100 | print(f"RecursiveThinkingModel loaded, extraction layer={args.extract_layer_id}, injection layer={args.inject_layer_id}, recursive steps={args.num_recursive_steps}")
101 | else:
102 | # Standard HF model loading
103 | llm, tokenizer = load_hf_lm_and_tokenizer(
104 | model_name_or_path=args.model_name_or_path,
105 | load_in_half=True,
106 | use_fast_tokenizer=True,
107 | use_safetensors=args.use_safetensors,
108 | )
109 |
110 | # Inference and evaluation
111 | data_list = args.data_names.split(",")
112 | results = []
113 | for data_name in data_list:
114 | results.append(main(llm, tokenizer, data_name, args))
115 |
116 | # Add "avg" result to data_list and results
117 | data_list.append("avg")
118 | results.append(
119 | {
120 | "acc": sum([result["acc"] for result in results]) / len(results),
121 | }
122 | )
123 |
124 | # Print all results
125 | pad = max([len(data_name) for data_name in data_list])
126 | print("\t".join(data_name.ljust(pad, " ") for data_name in data_list))
127 | print("\t".join([f"{result['acc']:.1f}".ljust(pad, " ") for result in results]))
128 |
129 |
130 | def prepare_data(data_name, args):
131 | """Prepare evaluation data"""
132 | examples = load_data(data_name, args.split, args.data_dir)
133 |
134 | # Sample num_test_sample samples from the dataset
135 | if args.num_test_sample > 0:
136 | examples = examples[: args.num_test_sample]
137 |
138 | # Shuffle data
139 | if args.shuffle:
140 | random.seed(datetime.now().timestamp())
141 | random.shuffle(examples)
142 |
143 | # Select start and end indices
144 | examples = examples[args.start : len(examples) if args.end == -1 else args.end]
145 |
146 | # Get output filename
147 | dt_string = datetime.now().strftime("%m-%d_%H-%M")
148 | model_name = "/".join(args.model_name_or_path.split("/")[-2:])
149 | out_file_prefix = f"{args.split}_{args.prompt_type}_{args.num_test_sample}_seed{args.seed}_t{args.temperature}"
150 | if args.use_recursive_thinking:
151 | out_file_prefix += f"_recursive{args.num_recursive_steps}_ext{args.extract_layer_id}_inj{args.inject_layer_id}"
152 | # Add an identifier indicating that interested tokens were used
153 | if args.interested_tokens is not None:
154 | out_file_prefix += "_with_interest_tokens"
155 | output_dir = args.output_dir
156 | if not os.path.exists(output_dir):
157 | output_dir = f"outputs/{output_dir}"
158 | out_file = f"{output_dir}/{data_name}/{out_file_prefix}_s{args.start}_e{args.end}.jsonl"
159 | os.makedirs(f"{output_dir}/{data_name}", exist_ok=True)
160 |
161 | # Load all processed samples
162 | processed_samples = []
163 | if not args.overwrite:
164 | processed_files = [
165 | f
166 | for f in os.listdir(f"{output_dir}/{data_name}/")
167 | if f.endswith(".jsonl") and f.startswith(out_file_prefix)
168 | ]
169 | for f in processed_files:
170 | processed_samples.extend(
171 | list(load_jsonl(f"{output_dir}/{data_name}/{f}"))
172 | )
173 |
174 | # Deduplicate
175 | processed_samples = {sample["idx"]: sample for sample in processed_samples}
176 | processed_idxs = list(processed_samples.keys())
177 | processed_samples = list(processed_samples.values())
178 | examples = [example for example in examples if example["idx"] not in processed_idxs]
179 | return examples, processed_samples, out_file
180 |
181 |
182 | def generate_with_recursive_model(model, prompt, max_tokens, interested_tokens=None, use_recursive_thinking=True):
183 | """Generate text using RecursiveThinkingModel"""
184 | return model.generate(
185 | max_tokens=max_tokens,
186 | prompt=prompt,
187 | interested_tokens=interested_tokens,
188 | use_recursive_thinking=use_recursive_thinking
189 | )
190 |
191 |
192 | def is_multi_choice(answer):
193 | """Check if the answer is in multiple-choice format"""
194 | for c in answer:
195 | if c not in ["A", "B", "C", "D", "E"]:
196 | return False
197 | return True
198 |
199 |
200 | def main(llm, tokenizer, data_name, args):
201 | """Main evaluation function"""
202 | examples, processed_samples, out_file = prepare_data(data_name, args)
203 | print("=" * 50)
204 | print("Data:", data_name, " , Remaining samples:", len(examples))
205 | if len(examples) > 0:
206 | print(examples[0])
207 |
208 | # Initialize Python executor
209 | if "pal" in args.prompt_type:
210 | executor = PythonExecutor(get_answer_expr="solution()")
211 | else:
212 | executor = PythonExecutor(get_answer_from_stdout=True)
213 |
214 | samples = []
215 | for example in tqdm(examples, total=len(examples)):
216 | idx = example["idx"]
217 |
218 | # Parse question and answer
219 | example["question"] = parse_question(example, data_name)
220 | if example["question"] == "":
221 | continue
222 | gt_cot, gt_ans = parse_ground_truth(example, data_name)
223 | example["gt_ans"] = gt_ans
224 | full_prompt = construct_prompt(example, data_name, args)
225 |
226 | if idx == args.start:
227 | print(full_prompt)
228 |
229 | sample = {
230 | "idx": idx,
231 | "question": example["question"],
232 | "gt_cot": gt_cot,
233 | "gt": gt_ans,
234 | "prompt": full_prompt,
235 | }
236 |
237 | # Add remaining fields
238 | for key in [
239 | "level",
240 | "type",
241 | "unit",
242 | "solution_type",
243 | "choices",
244 | "solution",
245 | "ques_type",
246 | "ans_type",
247 | "answer_type",
248 | "dataset",
249 | "subfield",
250 | "filed",
251 | "theorem",
252 | "answer",
253 | ]:
254 | if key in example:
255 | sample[key] = example[key]
256 | samples.append(sample)
257 |
258 | # Prepare prompts
259 | input_prompts = [
260 | sample["prompt"] for sample in samples for _ in range(args.n_sampling)
261 | ]
262 | if args.apply_chat_template:
263 | input_prompts = [
264 | tokenizer.apply_chat_template(
265 | [{"role": "user", "content": prompt.strip()}],
266 | tokenize=False,
267 | add_generation_prompt=True,
268 | )
269 | for prompt in input_prompts
270 | ]
271 | remain_prompts = input_prompts
272 | remain_prompts = [(i, prompt) for i, prompt in enumerate(remain_prompts)]
273 | end_prompts = []
274 |
275 | max_func_call = 1 if args.prompt_type in ["cot", "pal"] else 4
276 |
277 |     stop_words = ["</s>", "<|im_end|>", "<|endoftext|>"]
278 |
279 | if args.prompt_type in ["cot"]:
280 | stop_words.append("\n\nQuestion:")
281 | if args.prompt_type in ["pal", "tool-integrated", "jiuzhang_tora"]:
282 | stop_words.extend(["\n\n---", "```output"])
283 | elif args.prompt_type in ["wizard_zs", "platypus_fs"]:
284 | stop_words.extend(["Instruction", "Response"])
285 | elif "jiuzhang" in args.prompt_type:
286 | stop_words.append("\n\n## Question")
287 | elif "numina" in args.prompt_type:
288 | stop_words.append("\n### Problem")
289 | elif "pure" in args.prompt_type:
290 | stop_words.append("\n\n\n")
291 |
292 | # Start inference
293 | # Measure time usage
294 | start_time = time.time()
295 | for epoch in range(max_func_call):
296 | print("-" * 20, "Epoch", epoch)
297 | current_prompts = remain_prompts
298 | if len(current_prompts) == 0:
299 | break
300 |
301 | # Get all outputs
302 | prompts = [item[1] for item in current_prompts]
303 |
304 | # Use model to generate outputs
305 | if hasattr(llm, 'generate') and isinstance(llm, RecursiveThinkingModel):
306 | # Use RecursiveThinkingModel's generate method
307 | outputs = []
308 | # Process the list of interested tokens
309 | interested_tokens = args.interested_tokens
310 |
311 | for prompt in prompts:
312 | output = generate_with_recursive_model(
313 | model=llm,
314 | prompt=prompt,
315 | max_tokens=args.max_tokens_per_call,
316 | interested_tokens=interested_tokens,
317 | use_recursive_thinking=args.use_recursive_thinking
318 | )
319 | # Remove prompt from output, keep only the generated part
320 | if prompt in output:
321 | output = output[len(prompt):]
322 |                 # Check for stop words and truncate before storing the output
323 |                 for stop_word in stop_words:
324 |                     if stop_word in output:
325 |                         output = output.split(stop_word)[0]
326 |                         break
327 | 
328 |                 outputs.append(output)
329 | else:
330 | # Standard HF generation
331 | outputs = generate_completions(
332 | model=llm,
333 | tokenizer=tokenizer,
334 | prompts=prompts,
335 | max_new_tokens=args.max_tokens_per_call,
336 | batch_size=16,
337 | stop_id_sequences=stop_words,
338 | )
339 |
340 | assert len(outputs) == len(current_prompts)
341 |
342 | # Process all outputs
343 | remain_prompts = []
344 | remain_codes = []
345 | for (i, query), output in zip(current_prompts, outputs):
346 | output = output.rstrip()
347 | query += output
348 | if args.prompt_type == "pal":
349 | remain_prompts.append((i, query))
350 | if "```python" in output:
351 | output = extract_program(query)
352 | remain_codes.append(output)
353 | elif args.prompt_type == "cot":
354 | end_prompts.append((i, query))
355 | elif "boxed" not in output and output.endswith("```"):
356 | program = extract_program(query)
357 | remain_prompts.append((i, query))
358 | remain_codes.append(program)
359 | else:
360 | end_prompts.append((i, query))
361 |
362 | # Execute remaining prompts
363 | remain_results = executor.batch_apply(remain_codes)
364 | for k in range(len(remain_prompts)):
365 | i, query = remain_prompts[k]
366 | res, report = remain_results[k]
367 | exec_result = res if res else report
368 | if "pal" in args.prompt_type:
369 | exec_result = "\\boxed{" + exec_result + "}"
370 | exec_result = f"\n```output\n{exec_result}\n```\n"
371 | query += exec_result
372 | # Not finished
373 | if epoch == max_func_call - 1:
374 | query += "\nReached maximum function call limit."
375 | remain_prompts[k] = (i, query)
376 |
377 | # Unresolved samples
378 | print("Unresolved samples:", len(remain_prompts))
379 | end_prompts.extend(remain_prompts)
380 | # Sort by index
381 | end_prompts = sorted(end_prompts, key=lambda x: x[0])
382 |
383 | # Remove input_prompt from end_prompt
384 | codes = []
385 | assert len(input_prompts) == len(end_prompts)
386 | for i in range(len(input_prompts)):
387 | _, end_prompt = end_prompts[i]
388 | code = end_prompt.split(input_prompts[i])[-1].strip()
389 | for stop_word in stop_words:
390 | if stop_word in code:
391 | code = code.split(stop_word)[0].strip()
392 | codes.append(code)
393 |
394 | # Extract predictions
395 | results = [
396 | run_execute(executor, code, args.prompt_type, data_name) for code in codes
397 | ]
398 | time_use = time.time() - start_time
399 |
400 | # Put results back into samples
401 | all_samples = []
402 | for i, sample in enumerate(samples):
403 | code = codes[i * args.n_sampling : (i + 1) * args.n_sampling]
404 | result = results[i * args.n_sampling : (i + 1) * args.n_sampling]
405 | preds = [item[0] for item in result]
406 | reports = [item[1] for item in result]
407 | for j in range(len(preds)):
408 | if sample["gt"] in ["A", "B", "C", "D", "E"] and preds[j] not in [
409 | "A",
410 | "B",
411 | "C",
412 | "D",
413 | "E",
414 | ]:
415 | preds[j] = choice_answer_clean(code[j])
416 | elif is_multi_choice(sample["gt"]) and not is_multi_choice(preds[j]):
417 | # Remove any non-choice characters
418 | preds[j] = "".join(
419 | [c for c in preds[j] if c in ["A", "B", "C", "D", "E"]]
420 | )
421 |
422 | sample.pop("prompt")
423 | sample.update({"code": code, "pred": preds, "report": reports})
424 | all_samples.append(sample)
425 |
426 | # Add processed samples
427 | all_samples.extend(processed_samples)
428 | all_samples, result_json = evaluate(
429 | samples=all_samples,
430 | data_name=data_name,
431 | prompt_type=args.prompt_type,
432 | execute=True,
433 | )
434 |
435 | # Save outputs
436 | if len(processed_samples) < len(all_samples) and args.save_outputs:
437 | save_jsonl(all_samples, out_file)
438 |
439 | result_json["time_use_in_second"] = time_use
440 | result_json["time_use_in_minute"] = (
441 | f"{int(time_use // 60)}:{int(time_use % 60):02d}"
442 | )
443 |
444 | # Add recursive thinking parameters to metrics
445 | if args.use_recursive_thinking:
446 | result_json["recursive_thinking"] = {
447 | "enabled": args.use_recursive_thinking,
448 | "extract_layer_id": args.extract_layer_id,
449 | "inject_layer_id": args.inject_layer_id,
450 | "num_recursive_steps": args.num_recursive_steps,
451 | "interested_tokens_count": len(args.interested_tokens) if args.interested_tokens else 0
452 | }
453 |
454 | with open(
455 | out_file.replace(".jsonl", f"_{args.prompt_type}_metrics.json"), "w"
456 | ) as f:
457 | json.dump(result_json, f, indent=4)
458 | return result_json
459 |
460 |
461 | if __name__ == "__main__":
462 |
463 | args = parse_args()
464 |
465 | english_word_token_ids = set()
466 | english_pattern = re.compile(r'^[a-zA-Z]+$')
467 |
468 | english_word_count = 0
469 | target_word_count = 10
470 |
471 | # Read JSONL file
472 | with open(args.interested_tokens_file_path, 'r') as f:
473 | for line in f:
474 | try:
475 | if english_word_count >= target_word_count:
476 | break
477 |
478 | # Parse JSON
479 | record = json.loads(line)
480 | word = record.get("word", "")
481 | token_ids = record.get("token_ids", [])
482 |
483 | if english_pattern.match(word):
484 | english_word_count += 1
485 | for token_id in token_ids:
486 | english_word_token_ids.add(token_id)
487 |
488 | except json.JSONDecodeError:
489 | print(f"Error: Unable to parse JSON line: {line}")
490 | except Exception as e:
491 | print(f"Error during processing: {e}")
492 |
493 | args.interested_tokens = english_word_token_ids
494 |
495 | set_seed(args.seed)
496 | setup(args)
--------------------------------------------------------------------------------
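Note: a small hypothetical helper (not part of the repository) that builds the JSONL file consumed via --interested_tokens_file_path by RR_evaluate.py above; the tokenizer name and word list are placeholders. The __main__ block above collects "token_ids" from the first 10 records whose "word" is purely alphabetic and passes that set to the recursive-thinking model as interested_tokens.

    # make_interested_tokens.py (hypothetical)
    import json
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Llama-8B")  # placeholder model
    with open("interested_tokens.jsonl", "w") as f:
        for word in ["Wait", "Alternatively", "Hmm"]:  # placeholder words
            ids = tok.encode(word, add_special_tokens=False)
            f.write(json.dumps({"word": word, "token_ids": ids}) + "\n")
    # Pass the resulting path via --interested_tokens_file_path when running RR_evaluate.py.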
/src/applications/parser.py:
--------------------------------------------------------------------------------
1 | import random
2 | import regex
3 | import re
4 | import sympy
5 | from latex2sympy2 import latex2sympy
6 | from typing import TypeVar, Iterable, List, Union, Any, Dict
7 | from word2number import w2n
8 | from utils import *
9 |
10 |
11 | def _fix_fracs(string):
12 | substrs = string.split("\\frac")
13 | new_str = substrs[0]
14 | if len(substrs) > 1:
15 | substrs = substrs[1:]
16 | for substr in substrs:
17 | new_str += "\\frac"
18 | if len(substr) > 0 and substr[0] == "{":
19 | new_str += substr
20 | else:
21 | try:
22 | assert len(substr) >= 2
23 | except:
24 | return string
25 | a = substr[0]
26 | b = substr[1]
27 | if b != "{":
28 | if len(substr) > 2:
29 | post_substr = substr[2:]
30 | new_str += "{" + a + "}{" + b + "}" + post_substr
31 | else:
32 | new_str += "{" + a + "}{" + b + "}"
33 | else:
34 | if len(substr) > 2:
35 | post_substr = substr[2:]
36 | new_str += "{" + a + "}" + b + post_substr
37 | else:
38 | new_str += "{" + a + "}" + b
39 | string = new_str
40 | return string
41 |
42 |
43 | def _fix_a_slash_b(string):
44 | if len(string.split("/")) != 2:
45 | return string
46 | a = string.split("/")[0]
47 | b = string.split("/")[1]
48 | try:
49 | if "sqrt" not in a:
50 | a = int(a)
51 | if "sqrt" not in b:
52 | b = int(b)
53 | assert string == "{}/{}".format(a, b)
54 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
55 | return new_string
56 | except:
57 | return string
58 |
59 |
60 | def _fix_sqrt(string):
61 | _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
62 | return _string
63 |
64 |
65 | def convert_word_number(text: str) -> str:
66 | try:
67 | text = str(w2n.word_to_num(text))
68 | except:
69 | pass
70 | return text
71 |
72 |
73 | # units mainly from MathQA
74 | unit_texts = [
75 | "east",
76 | "degree",
77 | "mph",
78 | "kmph",
79 | "ft",
80 | "m sqaure",
81 | " m east",
82 | "sq m",
83 | "deg",
84 | "mile",
85 | "q .",
86 | "monkey",
87 | "prime",
88 | "ratio",
89 | "profit of rs",
90 | "rd",
91 | "o",
92 | "gm",
93 | "p . m",
94 | "lb",
95 | "tile",
96 | "per",
97 | "dm",
98 | "lt",
99 | "gain",
100 | "ab",
101 | "way",
102 | "west",
103 | "a .",
104 | "b .",
105 | "c .",
106 | "d .",
107 | "e .",
108 | "f .",
109 | "g .",
110 | "h .",
111 | "t",
112 | "a",
113 | "h",
114 | "no change",
115 | "men",
116 | "soldier",
117 | "pie",
118 | "bc",
119 | "excess",
120 | "st",
121 | "inches",
122 | "noon",
123 | "percent",
124 | "by",
125 | "gal",
126 | "kmh",
127 | "c",
128 | "acre",
129 | "rise",
130 | "a . m",
131 | "th",
132 | "π r 2",
133 | "sq",
134 | "mark",
135 | "l",
136 | "toy",
137 | "coin",
138 | "sq . m",
139 | "gallon",
140 | "° f",
141 | "profit",
142 | "minw",
143 | "yr",
144 | "women",
145 | "feet",
146 | "am",
147 | "pm",
148 | "hr",
149 | "cu cm",
150 | "square",
151 | "v â € ™",
152 | "are",
153 | "rupee",
154 | "rounds",
155 | "cubic",
156 | "cc",
157 | "mtr",
158 | "s",
159 | "ohm",
160 | "number",
161 | "kmph",
162 | "day",
163 | "hour",
164 | "minute",
165 | "min",
166 | "second",
167 | "man",
168 | "woman",
169 | "sec",
170 | "cube",
171 | "mt",
172 | "sq inch",
173 | "mp",
174 | "∏ cm ³",
175 | "hectare",
176 | "more",
177 | "sec",
178 | "unit",
179 | "cu . m",
180 | "cm 2",
181 | "rs .",
182 | "rs",
183 | "kg",
184 | "g",
185 | "month",
186 | "km",
187 | "m",
188 | "cm",
189 | "mm",
190 | "apple",
191 | "liter",
192 | "loss",
193 | "yard",
194 | "pure",
195 | "year",
196 | "increase",
197 | "decrease",
198 | "d",
199 | "less",
200 | "Surface",
201 | "litre",
202 | "pi sq m",
203 | "s .",
204 | "metre",
205 | "meter",
206 | "inch",
207 | ]
208 |
209 | unit_texts.extend([t + "s" for t in unit_texts])
210 |
211 |
212 | def strip_string(string, skip_unit=False):
213 | string = str(string).strip()
214 | # linebreaks
215 | string = string.replace("\n", "")
216 |
217 | # right "."
218 | string = string.rstrip(".")
219 |
220 | # remove inverse spaces
221 | # replace \\ with \
222 | string = string.replace("\\!", "")
223 | # string = string.replace("\\ ", "")
224 | # string = string.replace("\\\\", "\\")
225 |
226 | # matrix
227 | string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string)
228 | string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string)
229 | string = string.replace("bmatrix", "pmatrix")
230 |
231 | # replace tfrac and dfrac with frac
232 | string = string.replace("tfrac", "frac")
233 | string = string.replace("dfrac", "frac")
234 | string = (
235 | string.replace("\\neq", "\\ne")
236 | .replace("\\leq", "\\le")
237 | .replace("\\geq", "\\ge")
238 | )
239 |
240 | # remove \left and \right
241 | string = string.replace("\\left", "")
242 | string = string.replace("\\right", "")
243 | string = string.replace("\\{", "{")
244 | string = string.replace("\\}", "}")
245 |
246 |     # Remove trailing \text{...} units (e.g., miles, dollars) if something remains after removal
247 | _string = re.sub(r"\\text{.*?}$", "", string).strip()
248 | if _string != "" and _string != string:
249 | # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
250 | string = _string
251 |
252 | if not skip_unit:
253 | # Remove unit: texts
254 | for _ in range(2):
255 | for unit_text in unit_texts:
256 | # use regex, the prefix should be either the start of the string or a non-alphanumeric character
257 | # the suffix should be either the end of the string or a non-alphanumeric character
258 | _string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string)
259 | if _string != "":
260 | string = _string
261 |
262 | # Remove circ (degrees)
263 | string = string.replace("^{\\circ}", "")
264 | string = string.replace("^\\circ", "")
265 |
266 | # remove dollar signs
267 | string = string.replace("\\$", "")
268 | string = string.replace("$", "")
269 | string = string.replace("\\(", "").replace("\\)", "")
270 |
271 | # convert word number to digit
272 | string = convert_word_number(string)
273 |
274 | # replace "\\text{...}" to "..."
275 | string = re.sub(r"\\text\{(.*?)\}", r"\1", string)
276 | for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]:
277 | string = string.replace(key, "")
278 | string = string.replace("\\emptyset", r"{}")
279 | string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}")
280 |
281 | # remove percentage
282 | string = string.replace("\\%", "")
283 | string = string.replace("\%", "")
284 | string = string.replace("%", "")
285 |
286 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
287 | string = string.replace(" .", " 0.")
288 | string = string.replace("{.", "{0.")
289 |
290 | # cdot
291 | # string = string.replace("\\cdot", "")
292 | if (
293 | string.startswith("{")
294 | and string.endswith("}")
295 | and string.isalnum()
296 | or string.startswith("(")
297 | and string.endswith(")")
298 | and string.isalnum()
299 | or string.startswith("[")
300 | and string.endswith("]")
301 | and string.isalnum()
302 | ):
303 | string = string[1:-1]
304 |
305 | # inf
306 | string = string.replace("infinity", "\\infty")
307 | if "\\infty" not in string:
308 | string = string.replace("inf", "\\infty")
309 | string = string.replace("+\\inity", "\\infty")
310 |
311 | # and
312 | string = string.replace("and", "")
313 | string = string.replace("\\mathbf", "")
314 |
315 | # use regex to remove \mbox{...}
316 | string = re.sub(r"\\mbox{.*?}", "", string)
317 |
318 | # quote
319 |     string = string.replace("'", "")
320 |     string = string.replace('"', "")
321 |
322 | # i, j
323 | if "j" in string and "i" not in string:
324 | string = string.replace("j", "i")
325 |
326 | # replace a.000b where b is not number or b is end, with ab, use regex
327 | string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string)
328 | string = re.sub(r"(\d+)\.0*$", r"\1", string)
329 |
330 | # if empty, return empty string
331 | if len(string) == 0:
332 | return string
333 | if string[0] == ".":
334 | string = "0" + string
335 |
336 | # to consider: get rid of e.g. "k = " or "q = " at beginning
337 | if len(string.split("=")) == 2:
338 | if len(string.split("=")[0]) <= 2:
339 | string = string.split("=")[1]
340 |
341 | string = _fix_sqrt(string)
342 | string = string.replace(" ", "")
343 |
344 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
345 | string = _fix_fracs(string)
346 |
347 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
348 | string = _fix_a_slash_b(string)
349 |
350 | return string
351 |
352 |
353 | def extract_multi_choice_answer(pred_str):
354 | # TODO: SFT models
355 | if "Problem:" in pred_str:
356 | pred_str = pred_str.split("Problem:", 1)[0]
357 | pred_str = pred_str.replace("choice is", "answer is")
358 |     patt = regex.search(r"answer is \(?(?P<ans>[abcde])\)?", pred_str.lower())
359 | if patt is not None:
360 | return patt.group("ans").upper()
361 | return "placeholder"
362 |
363 |
364 | direct_answer_trigger_for_fewshot = ("choice is", "answer is")
365 |
366 |
367 | def choice_answer_clean(pred: str):
368 | pred = pred.strip("\n")
369 |
370 | # Determine if this is ICL, if so, use \n\n to split the first chunk.
371 | ICL = False
372 | for trigger in direct_answer_trigger_for_fewshot:
373 | if pred.count(trigger) > 1:
374 | ICL = True
375 | if ICL:
376 | pred = pred.split("\n\n")[0]
377 |
378 | # Split the trigger to find the answer.
379 | preds = re.split("|".join(direct_answer_trigger_for_fewshot), pred)
380 | if len(preds) > 1:
381 | answer_flag = True
382 | pred = preds[-1]
383 | else:
384 | answer_flag = False
385 |
386 | pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")
387 |
388 | # Clean the answer based on the dataset
389 | tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
390 | if tmp:
391 | pred = tmp
392 | else:
393 | pred = [pred.strip().strip(".")]
394 |
395 | if len(pred) == 0:
396 | pred = ""
397 | else:
398 | if answer_flag:
399 | # choose the first element in list ...
400 | pred = pred[0]
401 | else:
402 |             # choose the last element
403 | pred = pred[-1]
404 |
405 | # Remove the period at the end, again!
406 | pred = pred.rstrip(".").rstrip("/")
407 |
408 | return pred
409 |
410 |
411 | def find_box(pred_str: str):
412 | ans = pred_str.split("boxed")[-1]
413 | if not ans:
414 | return ""
415 | if ans[0] == "{":
416 | stack = 1
417 | a = ""
418 | for c in ans[1:]:
419 | if c == "{":
420 | stack += 1
421 | a += c
422 | elif c == "}":
423 | stack -= 1
424 | if stack == 0:
425 | break
426 | a += c
427 | else:
428 | a += c
429 | else:
430 | a = ans.split("$")[0].strip()
431 | return a
432 |
433 |
434 | def clean_units(pred_str: str):
435 | """Clean the units in the number."""
436 |
437 | def convert_pi_to_number(code_string):
438 | code_string = code_string.replace("\\pi", "π")
439 | # Replace \pi or π not preceded by a digit or } with 3.14
440 |         code_string = re.sub(r"(?<![\d}])\\?π", "3.14", code_string)
441 |         # Replace instances where π is preceded by a digit but without a multiplication symbol, e.g., "3π" -> "3*3.14"
442 | code_string = re.sub(r"(\d)(\\?π)", r"\1*3.14", code_string)
443 | # Handle cases where π is within braces or followed by a multiplication symbol
444 | # This replaces "{π}" with "3.14" directly and "3*π" with "3*3.14"
445 | code_string = re.sub(r"\{(\\?π)\}", "3.14", code_string)
446 | code_string = re.sub(r"\*(\\?π)", "*3.14", code_string)
447 | return code_string
448 |
449 | pred_str = convert_pi_to_number(pred_str)
450 | pred_str = pred_str.replace("%", "/100")
451 | pred_str = pred_str.replace("$", "")
452 | pred_str = pred_str.replace("¥", "")
453 | pred_str = pred_str.replace("°C", "")
454 | pred_str = pred_str.replace(" C", "")
455 | pred_str = pred_str.replace("°", "")
456 | return pred_str
457 |
458 |
459 | def extract_theoremqa_answer(pred: str, answer_flag: bool = True):
460 | if any([option in pred.lower() for option in ["yes", "true"]]):
461 | pred = "True"
462 | elif any([option in pred.lower() for option in ["no", "false"]]):
463 | pred = "False"
464 | elif any(
465 | [
466 | option in pred.lower()
467 | for option in ["(a)", "(b)", "(c)", "(d)", "(e)", "(f)"]
468 | ]
469 | ):
470 | pass
471 | else:
472 | # Some of the models somehow get used to boxed output from pre-training
473 | if "boxed" in pred:
474 | pred = find_box(pred)
475 |
476 | if answer_flag:
477 | # Extract the numbers out of the string
478 | pred = pred.split("=")[-1].strip()
479 | pred = clean_units(pred)
480 | try:
481 | tmp = str(latex2sympy(pred))
482 | pred = str(eval(tmp))
483 | except Exception:
484 | if re.match(r"-?[\d\.]+\s\D+$", pred):
485 | pred = pred.split(" ")[0]
486 | elif re.match(r"-?[\d\.]+\s[^\s]+$", pred):
487 | pred = pred.split(" ")[0]
488 | else:
489 |                 # desperate search over the last number
490 | preds = re.findall(r"-?\d*\.?\d+", pred)
491 | if len(preds) >= 1:
492 | pred = preds[-1]
493 | else:
494 | pred = ""
495 |
496 | return pred
497 |
498 |
499 | def extract_answer(pred_str, data_name, use_last_number=True):
500 | pred_str = pred_str.replace("\u043a\u0438", "")
501 | if data_name in ["mmlu_stem", "sat_math", "aqua", "gaokao2023"]:
502 | # TODO check multiple choice
503 | return choice_answer_clean(pred_str)
504 |
505 | if "final answer is $" in pred_str and "$. I hope" in pred_str:
506 | # minerva_math
507 | tmp = pred_str.split("final answer is $", 1)[1]
508 | pred = tmp.split("$. I hope", 1)[0].strip()
509 | elif "boxed" in pred_str:
510 | ans = pred_str.split("boxed")[-1]
511 | if len(ans) == 0:
512 | return ""
513 | elif ans[0] == "{":
514 | stack = 1
515 | a = ""
516 | for c in ans[1:]:
517 | if c == "{":
518 | stack += 1
519 | a += c
520 | elif c == "}":
521 | stack -= 1
522 | if stack == 0:
523 | break
524 | a += c
525 | else:
526 | a += c
527 | else:
528 | a = ans.split("$")[0].strip()
529 | pred = a
530 | elif "he answer is" in pred_str:
531 | pred = pred_str.split("he answer is")[-1].strip()
532 | elif "final answer is" in pred_str:
533 | pred = pred_str.split("final answer is")[-1].strip()
534 | elif "答案是" in pred_str:
535 | # Handle Chinese few-shot multiple choice problem answer extraction
536 | pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
537 | else: # use the last number
538 | if use_last_number:
539 | pattern = "-?\d*\.?\d+"
540 | pred = re.findall(pattern, pred_str.replace(",", ""))
541 | if len(pred) >= 1:
542 | pred = pred[-1]
543 | else:
544 | pred = ""
545 | else:
546 | pred = ""
547 |
548 | # choice answer
549 | if (
550 | data_name in ["sat_math", "aqua"]
551 | or "mmlu" in data_name
552 | ):
553 | tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
554 | if tmp:
555 | pred = tmp[-1]
556 | else:
557 | pred = pred.strip().strip(".")
558 |
559 | # multiple line
560 | # pred = pred.split("\n")[0]
561 | pred = re.sub(r"\n\s*", "", pred)
562 | if pred != "" and pred[0] == ":":
563 | pred = pred[1:]
564 | if pred != "" and pred[-1] == ".":
565 | pred = pred[:-1]
566 | if pred != "" and pred[-1] == "/":
567 | pred = pred[:-1]
568 | pred = strip_string(pred, skip_unit=data_name in ["carp_en", "minerva_math"])
569 | return pred
570 |
571 |
572 | STRIP_EXCEPTIONS = ["carp_en", "minerva_math"]
573 |
574 |
575 | def parse_ground_truth(example: Dict[str, Any], data_name):
576 |     # If the example already contains gt_cot and gt fields, use them directly (dataset name matched case-insensitively)
577 | if "gt_cot" in example and "gt" in example:
578 | if data_name.lower() in ["math", "minerva_math"]:
579 | gt_ans = extract_answer(example["gt_cot"], data_name)
580 | elif data_name in STRIP_EXCEPTIONS:
581 | gt_ans = example["gt"]
582 | else:
583 | gt_ans = strip_string(example["gt"])
584 | return example["gt_cot"], gt_ans
585 |
586 |     # Parse according to the specific dataset
587 | if data_name.lower() in ["math", "minerva_math"]:
588 | gt_cot = example["solution"]
589 | gt_ans = extract_answer(gt_cot, data_name)
590 | elif data_name.lower() == "math-500":
591 |         # For MATH-500, use "solution" as the chain-of-thought and "answer" directly as the final answer
592 | gt_cot = example["solution"]
593 | gt_ans = strip_string(example["answer"])
594 | elif data_name == "gsm8k":
595 | gt_cot, gt_ans = example["answer"].split("####")
596 | elif data_name == "svamp":
597 | gt_cot, gt_ans = example["Equation"], example["Answer"]
598 | elif data_name == "asdiv":
599 | gt_cot = example["formula"]
600 | gt_ans = re.sub(r"\(.*?\)", "", example["answer"])
601 | elif data_name == "mawps":
602 | gt_cot, gt_ans = None, example["target"]
603 | elif data_name == "tabmwp":
604 | gt_cot = example["solution"]
605 | gt_ans = example["answer"]
606 | if example["ans_type"] in ["integer_number", "decimal_number"]:
607 | if "/" in gt_ans:
608 | gt_ans = int(gt_ans.split("/")[0]) / int(gt_ans.split("/")[1])
609 | elif "," in gt_ans:
610 | gt_ans = float(gt_ans.replace(",", ""))
611 | elif "%" in gt_ans:
612 | gt_ans = float(gt_ans.split("%")[0]) / 100
613 | else:
614 | gt_ans = float(gt_ans)
615 | elif data_name == "carp_en":
616 | gt_cot, gt_ans = example["steps"], example["answer"]
617 | elif data_name == "mmlu_stem":
618 | abcd = "ABCD"
619 | gt_cot, gt_ans = None, abcd[example["answer"]]
620 | elif data_name == "sat_math":
621 | gt_cot, gt_ans = None, example["Answer"]
622 | elif data_name == "aqua":
623 | gt_cot, gt_ans = None, example["correct"]
624 | elif data_name in ["gaokao2023en", "college_math", "gaokao_math_cloze"]:
625 | gt_cot, gt_ans = None, example["answer"].replace("$", "").strip()
626 | elif data_name == "gaokao_math_qa":
627 | gt_cot, gt_ans = None, example["label"]
628 | elif data_name in ["gaokao2024_mix", "cn_middle_school"]:
629 | if len(example["choice_answer"]) > 0:
630 | gt_cot, gt_ans = None, example["choice_answer"]
631 | else:
632 | gt_cot, gt_ans = None, example["answer"]
633 | elif data_name == "olympiadbench":
634 | gt_cot, gt_ans = None, example["final_answer"][0].strip("$")
635 | elif data_name in [
636 | "aime24",
637 | "amc23",
638 | "cmath",
639 | "gaokao2024_I",
640 | "gaokao2024_II",
641 | "imo2024",
642 | ]:
643 | gt_cot, gt_ans = None, example["answer"]
644 | else:
645 | raise NotImplementedError(f"`{data_name}`")
646 |
647 |     # Post-processing
648 | gt_cot = str(gt_cot).strip()
649 | if data_name not in STRIP_EXCEPTIONS:
650 | gt_ans = strip_string(gt_ans, skip_unit=data_name == "carp_en")
651 | else:
652 | gt_ans = (
653 | gt_ans.replace("\\neq", "\\ne")
654 | .replace("\\leq", "\\le")
655 | .replace("\\geq", "\\ge")
656 | )
657 | return gt_cot, gt_ans
658 |
659 |
660 | def parse_question(example, data_name):
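    |     # Assemble the question text for each dataset: append formatted answer
    |     # choices for multiple-choice sets, and add a "(True or False)" /
    |     # "(Yes or No)" hint when the ground-truth answer is boolean-like.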
661 | question = ""
662 | if data_name == "asdiv":
663 | question = f"{example['body'].strip()} {example['question'].strip()}"
664 | elif data_name == "svamp":
665 | body = example["Body"].strip()
666 | if not body.endswith("."):
667 | body = body + "."
668 | question = f'{body} {example["Question"].strip()}'
669 | elif data_name == "tabmwp":
670 | title_str = (
671 | f'regarding "{example["table_title"]}" ' if example["table_title"] else ""
672 | )
673 | question = f"Read the following table {title_str}and answer a question:\n"
674 | question += f'{example["table"]}\n{example["question"]}'
675 | if example["choices"]:
676 | question += (
677 | f' Please select from the following options: {example["choices"]}'
678 | )
679 | elif data_name == "carp_en":
680 | question = example["content"]
681 | elif data_name == "mmlu_stem":
682 | options = example["choices"]
683 | assert len(options) == 4
684 | for i, (label, option) in enumerate(zip("ABCD", options)):
685 | options[i] = f"({label}) {str(option).strip()}"
686 | options = " ".join(options)
687 | question = f"{example['question'].strip()}\nAnswer Choices: {options}"
688 | elif data_name == "sat_math":
689 | options = example["options"].strip()
690 | assert "A" == options[0]
691 | options = "(" + options
692 | for ch in "BCD":
693 | if f" {ch}) " in options:
694 |                 options = re.sub(rf" {ch}\) ", f" ({ch}) ", options)
695 | question = f"{example['question'].strip()}\nAnswer Choices: {options}"
696 | elif "aqua" in data_name:
697 | options = example["options"]
698 | choice = "(" + "(".join(options)
699 | choice = choice.replace("(", " (").replace(")", ") ").strip()
700 | choice = "\nAnswer Choices: " + choice
701 | question = example["question"].strip() + choice
702 | elif data_name == "gaokao_math_qa":
703 | options_dict = example["options"]
704 | options = []
705 | for key in options_dict:
706 | options.append(f"({key}) {options_dict[key]}")
707 | options = " ".join(options)
708 | question = f"{example['question'].strip()}\n选项: {options}"
709 | elif data_name.lower() == "math-500":
710 |         # For MATH-500, use the "problem" field as the question
711 | question = example["problem"]
712 | else:
713 | for key in ["question", "problem", "Question", "input"]:
714 | if key in example:
715 | question = example[key]
716 | break
717 |
718 |     # Append a (True or False) / (Yes or No) hint based on the ground-truth answer
719 | _, gt_ans = parse_ground_truth(example, data_name)
720 | if isinstance(gt_ans, str):
721 | gt_lower = gt_ans.lower()
722 | if gt_lower in ["true", "false"]:
723 | question += " (True or False)"
724 | if gt_lower in ["yes", "no"]:
725 | question += " (Yes or No)"
726 | return question.strip()
727 |
728 |
729 | def run_execute(executor, result, prompt_type, data_name, execute=False):
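    |     # Turn a raw generation into a normalized prediction: use the program
    |     # output for program-only prompts, execute extracted code for pot/pal
    |     # prompts, otherwise extract the textual answer; returns (prediction, report).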
730 | if not result or result == "error":
731 | return None, None
732 | report = None
733 |
734 | if "program_only" in prompt_type:
735 | prediction = extract_program_output(result)
736 | elif prompt_type in ["pot", "pal"] and execute:
737 | code = extract_program(result)
738 | prediction, report = executor.apply(code)
739 | else:
740 | prediction = extract_answer(result, data_name)
741 |
742 | # prediction = strip_string(prediction, skip_unit=data_name == "carp_en")
743 | prediction = strip_string(prediction, skip_unit=data_name in STRIP_EXCEPTIONS)
744 | return prediction, report
745 |
746 |
747 | def _test_extract_answer():
748 |     text = r"""
749 | This is still not equal to $0$, so we must have made another mistake.
750 |
751 | When we subtracted $7$ from $\frac{386}{64}$, we should have subtracted $7 \cdot 64$ from $386$, not the other way around. Let's correct that:
752 |
753 | \[\frac{386}{64} - 7 = \frac{386}{64} - \frac{7 \cdot 64}{1 \cdot 64} = \frac{386 - 448}{64} = \frac{-62}{64}.\]
754 |
755 | This is still not equal to $0$, so we must have made another mistake.
756 |
757 | When we subtracted $7$ from $\frac{386}{64}$, we should have subtracted $7 \cdot 64$ from $386$, not the other way around. Let's correct that:
758 |
759 | \[\frac{386}{64}
760 | """
761 | print(extract_answer(text, "math-oai", use_last_number=False))
762 |     print(choice_answer_clean(r"\mathrm{(D)\}1,008,016"))
763 |     # both calls should print extracted answer strings
764 |
765 |
766 | if __name__ == "__main__":
767 | _test_extract_answer()
768 |
--------------------------------------------------------------------------------