├── ASSETS ├── chain.png ├── feature.png ├── coe_insight.png ├── definition.png └── chain-of-embedding.png ├── requirements.txt ├── config_pool.py ├── Data └── load_data.py ├── Scripts └── llm_infer.sh ├── Model └── load_model.py ├── main.py ├── score.py ├── Evaluation ├── eval.py └── match.py ├── prompt_pool.py ├── README.md ├── inference.py └── LICENSE /ASSETS/chain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alsace08/Chain-of-Embedding/HEAD/ASSETS/chain.png -------------------------------------------------------------------------------- /ASSETS/feature.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alsace08/Chain-of-Embedding/HEAD/ASSETS/feature.png -------------------------------------------------------------------------------- /ASSETS/coe_insight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alsace08/Chain-of-Embedding/HEAD/ASSETS/coe_insight.png -------------------------------------------------------------------------------- /ASSETS/definition.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alsace08/Chain-of-Embedding/HEAD/ASSETS/definition.png -------------------------------------------------------------------------------- /ASSETS/chain-of-embedding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alsace08/Chain-of-Embedding/HEAD/ASSETS/chain-of-embedding.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jsonlines 2 | scikit-learn 3 | matplotlib 4 | seaborn 5 | tqdm==4.66.4 6 | scipy==1.12.0 7 | numpy==1.26 8 | transformers==4.45 9 | accelerate>=0.26.0 10 | -------------------------------------------------------------------------------- /config_pool.py: -------------------------------------------------------------------------------- 1 | MODEL_POOL = [ 2 | "Llama-3-8B-Instruct", 3 | "Llama-3-70B-Instruct", 4 | "qwen2-7B-Instruct", 5 | "qwen2-72B-Instruct", 6 | "Mistral-7B-Instruct", 7 | ] 8 | 9 | 10 | DATASET_POOL = [ 11 | "mgsm", 12 | "math", 13 | "commonsenseqa", 14 | "theoremqa", 15 | ] 16 | 17 | LANGUAGE_MAPPING = { 18 | "mgsm": ["en", "bn", "de", "es", "fr", "ja", "ru", "sw", "te", "th", "zh"], 19 | } -------------------------------------------------------------------------------- /Data/load_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import jsonlines 3 | 4 | dataset_path = "./Data/" 5 | 6 | class DatasetInfo: 7 | def __init__(self, dataset_name): 8 | self.dataset_name = dataset_name 9 | self.data = [] 10 | with open(dataset_path + self.dataset_name + ".jsonl", "r+", encoding="utf8") as f: 11 | for item in jsonlines.Reader(f): 12 | self.data.append(item) 13 | self.data_size = len(self.data) 14 | 15 | 16 | def load_one_sample(self, idx): 17 | return self.data[idx] 18 | 19 | -------------------------------------------------------------------------------- /Scripts/llm_infer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PROJECT_PATH="/your/path/to/Chain-of-Embedding" 4 | export CUDA_VISIBLE_DEVICES="0,1" 5 | 6 | model_name="qwen2-7B-Instruct" 7 | dataset_list=(mgsm) 8 | 9 | for i in ${dataset_list[*]}; do 10 | python main.py --model_name $model_name \ 11 | --dataset "$i" \ 12 | --print_model_parameter \ 13 | --save_output \ 14 | --save_hidden_states \ 15 | --save_coe_score \ 16 | --save_coe_figure 17 | done 18 | -------------------------------------------------------------------------------- /Model/load_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import random 5 | 6 | import argparse 7 | import scipy.spatial 8 | import math 9 | import json 10 | import torch 11 | import torch.nn as nn 12 | 13 | import numpy as np 14 | import pickle 15 | from tqdm import tqdm 16 | from transformers import ( 17 | AutoTokenizer, 18 | AutoModelForCausalLM, 19 | AutoConfig, 20 | GenerationConfig, 21 | ) 22 | 23 | 24 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 25 | MODEL_REPO = "./Model/" 26 | 27 | 28 | def load_base_model(args): 29 | config_kwargs = { 30 | "trust_remote_code": True, 31 | "cache_dir": None, 32 | "revision": 'main', 33 | "use_auth_token": None, 34 | "output_hidden_states": True 35 | } 36 | config = AutoConfig.from_pretrained(MODEL_REPO + args.model_name, **config_kwargs) 37 | model = AutoModelForCausalLM.from_pretrained( 38 | MODEL_REPO + args.model_name, 39 | config=config, 40 | torch_dtype=torch.float32, 41 | device_map='auto', 42 | trust_remote_code=True 43 | ) 44 | 45 | tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO + args.model_name, trust_remote_code=True) 46 | tokenizer.pad_token_id = 0 47 | 48 | return model, tokenizer, config -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import math 5 | import json 6 | import random 7 | import argparse 8 | from tqdm import tqdm 9 | 10 | import numpy as np 11 | import pickle 12 | import scipy.spatial 13 | import torch 14 | import torch.nn as nn 15 | 16 | from transformers import ( 17 | AutoTokenizer, 18 | AutoModelForCausalLM, 19 | AutoConfig, 20 | GenerationConfig, 21 | ) 22 | 23 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 24 | print(f'device: {device}') 25 | 26 | project_root_path = os.environ["PROJECT_PATH"] 27 | sys.path.append(project_root_path) 28 | from Data.load_data import DatasetInfo 29 | from Model.load_model import load_base_model 30 | from config_pool import MODEL_POOL, DATASET_POOL, LANGUAGE_MAPPING 31 | from inference import Inference 32 | 33 | 34 | if __name__ == '__main__': 35 | parser = argparse.ArgumentParser(description="chain-of-embedding") 36 | 37 | parser.add_argument("--model_name", type=str, default="Llama-3-8B-Instruct", choices=MODEL_POOL) 38 | parser.add_argument("--dataset", type=str, default="mgsm", choices=DATASET_POOL) 39 | parser.add_argument("--max_output_token", type=int, default=2048) 40 | 41 | parser.add_argument("--print_model_parameter", action="store_true") 42 | parser.add_argument("--save_output", action="store_true") 43 | parser.add_argument("--save_hidden_states", action="store_true") 44 | parser.add_argument("--save_coe_score", action="store_true") 45 | parser.add_argument("--save_coe_figure", action="store_true") 46 | 47 | args = parser.parse_args() 48 | args.max_output_token = 2048 if "Instruct" in args.model_name else 128 49 | 50 | model, tokenizer, config = load_base_model(args) 51 | if args.print_model_parameter: 52 | print("********** Module Name and Size **********\n") 53 | for param_tensor in model.state_dict(): 54 | print(param_tensor,'\t',model.state_dict()[param_tensor].size()) 55 | 56 | model_info = { 57 | "model_name": args.model_name, 58 | "model_ckpt": model, 59 | "tokenizer": tokenizer, 60 | "model_config": config, 61 | "generation_config": GenerationConfig(), 62 | "max_output_token": args.max_output_token 63 | } 64 | dataset_info = { 65 | "dataset_name": args.dataset, 66 | } 67 | verbose = { 68 | "save_output": args.save_output, 69 | "save_hidden_states": args.save_hidden_states, 70 | "save_coe_score": args.save_coe_score, 71 | "save_coe_figure": args.save_coe_figure 72 | } 73 | 74 | print(f"***** Model Name: *****\n{args.model_name}") 75 | print(f"***** Dataset Name: *****\n{args.dataset}") 76 | print(f"***** Dataset Size: *****\n{DatasetInfo(args.dataset).data_size}") 77 | 78 | language_list = LANGUAGE_MAPPING[args.dataset] if args.dataset in LANGUAGE_MAPPING else ["en"] 79 | for lang in language_list: 80 | dataset_info["language"] = lang 81 | Infer = Inference(model_info, dataset_info, verbose) 82 | Infer.dataset_inference() 83 | -------------------------------------------------------------------------------- /score.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 5 | import scipy.spatial 6 | from scipy.stats import entropy 7 | import math 8 | import json 9 | 10 | from torch.utils.dlpack import to_dlpack 11 | from torch.utils.dlpack import from_dlpack 12 | import torch 13 | import torch.nn.functional as F 14 | import numpy as np 15 | 16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 17 | 18 | 19 | 20 | class OutputScoreInfo: 21 | def __init__(self, output_scores): 22 | self.output_scores = output_scores 23 | self.all_token_re = [] 24 | self.all_token_max_re = [] 25 | for token in range(len(self.output_scores)): 26 | re = self.output_scores[token][0].tolist() 27 | re = F.softmax(torch.tensor(re).to(device), 0).cpu().tolist() 28 | self.all_token_re.append(re) 29 | self.all_token_max_re.append(max(re)) 30 | 31 | def compute_maxprob(self): 32 | seq_prob_list = self.all_token_max_re 33 | max_prob = np.mean(seq_prob_list) 34 | return max_prob 35 | 36 | def compute_ppl(self): 37 | seq_ppl_list = [math.log(max_re) for max_re in self.all_token_max_re] 38 | ppl = -np.mean(seq_ppl_list) 39 | return ppl 40 | 41 | def compute_entropy(self): 42 | seq_entropy_list = [entropy(re, base=2) for re in self.all_token_re] 43 | seq_entropy = np.mean(seq_entropy_list) 44 | return seq_entropy 45 | 46 | 47 | 48 | class CoEScoreInfo: 49 | def __init__(self, hidden_states): 50 | self.hidden_states = hidden_states 51 | 52 | def compute_CoE_Mag(self): 53 | hs_all_layer = self.hidden_states 54 | layer_num = len(hs_all_layer) 55 | 56 | norm_denominator = np.linalg.norm(hs_all_layer[-1] - hs_all_layer[0], ord=2) 57 | al_repdiff = np.array([hs_all_layer[i+1] - hs_all_layer[i] for i in range(layer_num - 1)]) 58 | al_repdiff_norm = [np.linalg.norm(item, ord=2) / norm_denominator for item in al_repdiff] 59 | al_repdiff_ave = np.mean(np.array(al_repdiff_norm)) 60 | al_repdiff_var = np.var(np.array(al_repdiff_norm)) 61 | return al_repdiff_norm, al_repdiff_ave, al_repdiff_var 62 | 63 | 64 | def compute_CoE_Ang(self): 65 | hs_all_layer = self.hidden_states 66 | layer_num = len(hs_all_layer) 67 | 68 | al_semdiff = [] 69 | norm_denominator = np.dot(hs_all_layer[-1], hs_all_layer[0]) / (np.linalg.norm(hs_all_layer[-1], ord=2) * np.linalg.norm(hs_all_layer[0], ord=2)) 70 | norm_denominator = math.acos(norm_denominator) 71 | for i in range(layer_num - 1): 72 | a = hs_all_layer[i + 1] 73 | b = hs_all_layer[i] 74 | dot_product = np.dot(a, b) 75 | norm_a, norm_b = np.linalg.norm(a, ord=2), np.linalg.norm(b, ord=2) 76 | similarity = dot_product / (norm_a * norm_b) 77 | similarity = similarity if similarity <= 1 else 1 78 | 79 | arccos_sim = math.acos(similarity) 80 | al_semdiff.append(arccos_sim / norm_denominator) 81 | 82 | al_semdiff_norm = np.array(al_semdiff) 83 | al_semdiff_ave = np.mean(np.array(al_semdiff_norm)) 84 | al_semdiff_var = np.var(np.array(al_semdiff_norm)) 85 | 86 | return al_semdiff_norm, al_semdiff_ave, al_semdiff_var 87 | 88 | def compute_CoE_R(self): 89 | _, al_repdiff_ave, _ = self.compute_CoE_Mag() 90 | _, al_semdiff_ave, _ = self.compute_CoE_Ang() 91 | 92 | return al_repdiff_ave - al_semdiff_ave 93 | 94 | def compute_CoE_C(self): 95 | al_repdiff_norm, _, _ = self.compute_CoE_Mag() 96 | al_semdiff_norm, _, _ = self.compute_CoE_Ang() 97 | x_list = np.array([al_repdiff_norm[i] * math.cos(al_semdiff_norm[i]) for i in range(len(al_semdiff_norm))]) 98 | y_list = np.array([al_repdiff_norm[i] * math.sin(al_semdiff_norm[i]) for i in range(len(al_semdiff_norm))]) 99 | al_combdiff_x_ave = np.mean(x_list) 100 | al_combdiff_y_ave = np.mean(y_list) 101 | al_combdiff_x_var = np.mean(x_list) 102 | al_combdiff_y_var = np.mean(y_list) 103 | 104 | return math.sqrt(al_combdiff_x_ave ** 2 + al_combdiff_y_ave ** 2) 105 | -------------------------------------------------------------------------------- /Evaluation/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | import pickle 5 | import argparse 6 | 7 | from sklearn.metrics import roc_curve, auc, precision_recall_curve 8 | from scipy import interpolate 9 | 10 | project_root_path = os.environ["PROJECT_PATH"] 11 | sys.path.append(project_root_path) 12 | from Data.load_data import DatasetInfo 13 | from config_pool import MODEL_POOL, DATASET_POOL, LANGUAGE_MAPPING 14 | from match import AnswerParsing 15 | 16 | 17 | class StandardEvaluation: 18 | def __init__(self, dataset_list): 19 | self.data_all = [] 20 | self.data_size = 0 21 | for i, dataset in enumerate(dataset_list): 22 | data_loader = DatasetInfo(args.dataset) 23 | self.data_all.extend(data_loader.data) 24 | self.data_size += data_loader.data_size 25 | 26 | def std_eval(self, args): 27 | answerparsing = AnswerParsing(args.dataset) 28 | output_dir = os.path.join(project_root_path, f"OutputInfo/{args.language}/Output", args.model_name, args.dataset) 29 | coe_dir = os.path.join(project_root_path, f"OutputInfo/{args.language}/CoE", args.model_name, args.dataset) 30 | 31 | output_list, coe_list, binary_list = [], [], [] 32 | acc = 0 33 | for i in range(self.data_size): 34 | sample = self.data_all[i] 35 | true_output = sample["answer"] 36 | 37 | with open(os.path.join(output_dir, f"{args.dataset}_{str(i)}.pkl"), 'rb') as file: 38 | output = pickle.load(file) 39 | pred_output = output["output_seq"] 40 | 41 | with open(os.path.join(coe_dir, f"{args.dataset}_{str(i)}.pkl"), 'rb') as file: 42 | coe = pickle.load(file) 43 | 44 | extracted_answer, binary = answerparsing.dataset_parse(pred_output, true_output, sample) 45 | if binary: acc += 1 46 | 47 | output_list.append(output) 48 | coe_list.append(coe) 49 | binary_list.append(binary) 50 | 51 | return round(acc / self.data_size, 3), output_list, coe_list, binary_list 52 | 53 | 54 | class SelfEvaluation: 55 | def __init__(self, dataset_list): 56 | self.data_all = [] 57 | self.data_size = 0 58 | for i, dataset in enumerate(dataset_list): 59 | data_loader = DatasetInfo(args.dataset) 60 | self.data_all.extend(data_loader.data) 61 | self.data_size += data_loader.data_size 62 | 63 | def self_eval(self, score_list, binary_list): 64 | fpr, tpr, thresholds = roc_curve(binary_list, score_list) 65 | auroc = auc(fpr, tpr) 66 | fpr95 = float(interpolate.interp1d(tpr, fpr)(0.95)) 67 | precision, recall, _ = precision_recall_curve(binary_list, score_list) 68 | aupr = auc(recall, precision) 69 | 70 | return round(auroc * 100, 2), round(fpr95 * 100, 2), round(aupr * 100, 2) 71 | 72 | 73 | if __name__ == '__main__': 74 | parser = argparse.ArgumentParser(description="eval") 75 | parser.add_argument("--model_name", type=str, default="Llama-3-8B-Instruct", choices=MODEL_POOL) 76 | parser.add_argument("--dataset", type=str, default="mgsm", choices=DATASET_POOL) 77 | parser.add_argument("--language", type=str, default="en") 78 | 79 | args = parser.parse_args() 80 | 81 | stdeval = StandardEvaluation([args.dataset]) 82 | acc, output_list, coe_list, binary_list = stdeval.std_eval(args) 83 | print(f"# Accuracy: {acc}") 84 | 85 | input_list = [output_list[i]["input_seq"] for i in range(len(output_list))] 86 | maxprob_list = [output_list[i]["maxprob"] for i in range(len(output_list))] 87 | ppl_list = [1 / output_list[i]["ppl"] for i in range(len(output_list))] 88 | entropy_list = [1 / output_list[i]["entropy"] for i in range(len(output_list))] 89 | coer_list = [coe_list[i]["R"] for i in range(len(coe_list))] 90 | coec_list = [coe_list[i]["C"] for i in range(len(coe_list))] 91 | 92 | selfeval = SelfEvaluation([args.dataset]) 93 | maxprob_auroc, maxprob_fpr95, maxprob_aupr = selfeval.self_eval(maxprob_list, binary_list) 94 | ppl_auroc, ppl_fpr95, ppl_aupr = selfeval.self_eval(ppl_list, binary_list) 95 | entropy_auroc, entropy_fpr95, entropy_aupr = selfeval.self_eval(entropy_list, binary_list) 96 | coer_auroc, coer_fpr95, coer_aupr = selfeval.self_eval(coer_list, binary_list) 97 | coec_auroc, coec_fpr95, coec_aupr = selfeval.self_eval(coec_list, binary_list) 98 | 99 | print(f"{'maxprob_auroc'.rjust(13)}: {maxprob_auroc:.2f} {'maxprob_fpr95'.rjust(13)}: {maxprob_fpr95:.2f} {'maxprob_aupr'.rjust(13)}: {maxprob_aupr:.2f}") 100 | print(f"{'ppl_auroc'.rjust(13)}: {ppl_auroc:.2f} {'ppl_fpr95'.rjust(13)}: {ppl_fpr95:.2f} {'ppl_aupr'.rjust(13)}: {ppl_aupr:.2f}") 101 | print(f"{'entropy_auroc'.rjust(13)}: {entropy_auroc:.2f} {'entropy_fpr95'.rjust(13)}: {entropy_fpr95:.2f} {'entropy_aupr'.rjust(13)}: {entropy_aupr:.2f}") 102 | print(f"{'coer_auroc'.rjust(13)}: {coer_auroc:.2f} {'coer_fpr95'.rjust(13)}: {coer_fpr95:.2f} {'coer_aupr'.rjust(13)}: {coer_aupr:.2f}") 103 | print(f"{'coec_auroc'.rjust(13)}: {coec_auroc:.2f} {'coec_fpr95'.rjust(13)}: {coec_fpr95:.2f} {'coec_aupr'.rjust(13)}: {coec_aupr:.2f}") 104 | -------------------------------------------------------------------------------- /prompt_pool.py: -------------------------------------------------------------------------------- 1 | DATASET_PROMPTS = { 2 | "GSM8K": "Solve this math problem. Give the reasoning steps before giving the final answer on the last line by itself in the format of \"Answer:\". Do not add anything other than the integer answer after \"Answer:\".\n\nQuestion:\n{input_data}\n", 3 | "mgsm": "Solve this math problem. Give the reasoning steps before giving the final answer on the last line by itself in the format of \"Answer:\". Do not add anything other than the integer answer after \"Answer:\".\n\nQuestion:\n{input_data}\n", 4 | "math": "Question: {input_data}\nPlease reason step by step, and put your final answer within \\boxed{{}}\n", 5 | "mmmlu": "Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.\n\nQuestion:\n{input_data}\n", 6 | "belebele": "Answer the following multiple choice reading-comprehension question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Please fully understand the passage and give explanations step by step before answering.\n\n{input_data}\n", 7 | "commonsenseqa": "Answer the following multiple choice common-sense reasoning question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCDE. Think step by step and output the reasoning process before answering.\n\n{input_data}", 8 | "theoremqa": ''' 9 | Below is an instruction that describes a task, paired with an input that provides further context. 10 | Write a response that appropriately completes the request. 11 | 12 | ### Instruction: 13 | Please read a math problem, and then think step by step to derive the answer. The answer is decided by Answer Type. 14 | If the Answer type in [bool], the answer needs to be True or False. 15 | Else if the Answer type in [integer, float] , The answer needs to be in numerical form. 16 | Else if the Answer type in [list of integer, list of float] , the answer needs to be a list of number like [2, 3, 4]. 17 | Else if the Answer type in [option], the answer needs to be an option like (a), (b), (c), (d). 18 | You need to output the answer in your final sentence like 'Therefore, the answer is ...'. 19 | 20 | ### Question: 21 | ''' + "{input_data}\n\n### Answer_type: {answer_type}\n\n### Response:" 22 | } 23 | 24 | 25 | 26 | QUESTIONS = { 27 | "en": "Question:", 28 | "es": "Pregunta:", 29 | "fr": "Question:", 30 | "de": "Frage:", 31 | "ru": "Задача:", 32 | "zh": "问题:", 33 | "ja": "問題:", 34 | "th": "โจทย์:", 35 | "sw": "Swali:", 36 | "bn": "প্রশ্ন:", 37 | "te": "ప్రశ్న:" 38 | } 39 | 40 | ANSWERS = { 41 | "en": "Step-by-Step Answer:", 42 | "es": "Respuesta paso a paso:", 43 | "fr": "Réponse étape par étape :", 44 | "de": "Schritt-für-Schritt-Antwort:", 45 | "ru": "Пошаговоерешение:", 46 | "zh": "逐步解答:", 47 | "ja": "ステップごとの答え:", 48 | "th": "คำตอบทีละขั้นตอน:", 49 | "sw": "Jibu la Hatua kwa Hatua:", 50 | "bn": "ধাপে ধাপে উত্তর:", 51 | "te": "దశలవారీగా సమాధానం:" 52 | } 53 | 54 | 55 | INSTRUCTIONS = { 56 | "en": """Solve this math problem. Give the reasoning steps before giving the final answer on the last line by itself in the format of "Answer:". Do not add anything other than the integer answer after "Answer:".""", 57 | "bn": """এই গণিতের সমস্যাটি সমাধান করুন। চূড়ান্ত উত্তর দেওয়ার আগে যুক্তিসম্পন্ন পদক্ষেপ প্রদান করুন। চূড়ান্ত উত্তরটি একক সংখ্যা হিসাবে "উত্তর:" এর পরে শেষ লাইনে দিন। "উত্তর:" এর পরে অন্য কিছু যুক্ত করবেন না।.""", 58 | "de": """Löse dieses Mathematikproblem. Gib die Schritte zur Begründung an, bevor du die endgültige Antwort in der letzten Zeile alleine im Format "Antwort:" gibst. Füge nichts anderes als die ganzzahlige Antwort nach "Antwort:" hinzu.""", 59 | "es": """Resuelve este problema matemático. Proporciona los pasos de razonamiento antes de dar la respuesta final en la última línea por sí misma en el formato de "Respuesta:". No añadas nada más que la respuesta entera después de "Respuesta:".""", 60 | "fr": """Résolvez ce problème de mathématiques. Donnez les étapes de raisonnement avant de fournir la réponse finale sur la dernière ligne elle-même dans le format de "Réponse:". N'ajoutez rien d'autre que la réponse entière après "Réponse:".""", 61 | "ja": """の数学の問題を解いてください。最終的な答えを出す前に、解答の推論過程を記述してください。そして最後の行には "答え:" の形式で答えを記述し、その後には整数の答え以外何も追加しないでください。""", 62 | "ru": """Решите эту математическую задачу. Объясните шаги рассуждения перед тем, как дать окончательный ответ в последней строке сам по себе в формате "Ответ:". Не добавляйте ничего, кроме целочисленного ответа после "Ответ:".""", 63 | "sw": """Suluhisha tatizo hili la hesabu. Toa hatua za mantiki kabla ya kutoa jibu la mwisho kwenye mstari wa mwisho peke yake katika muundo wa "Jibu:". Usiongeze chochote kingine isipokuwa jibu la integer baada ya "Jibu:".""", 64 | "te": """ఈ గణిత సమస్యను పరిష్కరించండి. చివరి సమాధానాన్ని ఇవ్వదానికి ముందు తర్కాత్మక అదుగులను ఇవ్వండి. చివరి పంక్తిలో మాత్రమే 'సమాధానం:' అనే ఆకారంలో చివరి సమాధానాద్ని ఇవ్వండి సమాధానం: తర్వాత పూర్ణాంక సమాధానానికి తప్పించి ఎదేనా చేర్చవద్దు.""", 65 | "th": """แก้ปัญหาคณิตศาสตร์นี้ ให้ให้ขั้นตอนการใช้เหตุผลก่อนที่จะให้คำตอบสุดท้ายในบรรทัดสุดท้ายโดยอยู่ในรูปแบบ "คำตอบ:" ไม่ควรเพิ่มอะไรนอกจากคำตอบที่เป็นจำนวนเต็มหลังจาก "คำตอบ:""""", 66 | "zh": """解决这个数学问题。在最后一行给出答案前,请提供推理步骤。最后一行应该以 "答案: " 的形式独立给出答案。在 "答案:" 后不要添加除整数答案之外的任何内容。""", 67 | } 68 | 69 | ANSWER_PREFIX = { 70 | "en": "Answer", 71 | "bn": "উত্তর", 72 | "de": "Antwort", 73 | "es": "Respuesta", 74 | "fr": "Réponse", 75 | "ja": "答え", 76 | "ru": "Ответ", 77 | "sw": "Jibu", 78 | "te": "సమాధానం", 79 | "th": "คำตอบ", 80 | "zh": "答案", 81 | } 82 | 83 | 84 | -------------------------------------------------------------------------------- /Evaluation/match.py: -------------------------------------------------------------------------------- 1 | import re 2 | from prompt_pool import ANSWER_PREFIX 3 | 4 | 5 | class AnswerParsing: 6 | def __init__(self, dataset): 7 | self.dataset = dataset 8 | 9 | def dataset_parse(self, pred, true, sample): 10 | if self.dataset == "mgsm" or self.dataset == "gsm8k": 11 | extracted_answer, binary = self.mgsm_parse(ANSWER_PREFIX["en"], pred, true) 12 | elif self.dataset == "math": 13 | extracted_answer, binary = self.math_parse(ANSWER_PREFIX["en"], pred, true) 14 | elif self.dataset == "commonsenseqa": 15 | extracted_answer, binary = self.commonsenseqa_parse(pred, true) 16 | elif self.dataset == "theoremqa": 17 | extracted_answer, binary = self.theoremqa_parse(sample["answer_type"], pred, true) 18 | elif self.dataset == "mmmlu": 19 | extracted_answer, binary = self.mmmlu_parse(pred, true) 20 | elif self.dataset == "belebele": 21 | extracted_answer, binary = self.belebele_parse(pred, true) 22 | 23 | return extracted_answer, binary 24 | 25 | 26 | def extract_boxed_content(self, text): 27 | pattern = re.compile(r'\\boxed{') 28 | matches = pattern.finditer(text) 29 | results = [] 30 | for match in matches: 31 | start_pos = match.end() 32 | brace_count = 1 33 | i = start_pos 34 | while i < len(text) and brace_count > 0: 35 | if text[i] == '{': 36 | brace_count += 1 37 | elif text[i] == '}': 38 | brace_count -= 1 39 | i += 1 40 | if brace_count == 0: 41 | results.append(text[start_pos:i-1]) 42 | return results 43 | 44 | 45 | def mgsm_parse(self, answer_prefix, pred, true): 46 | if "<|im_end|>" not in pred and "<|eot_id|>" not in pred and "" not in pred and "<|END_OF_TURN_TOKEN|>" not in pred: 47 | return "Incomplete", False 48 | 49 | if answer_prefix not in pred: 50 | return None, False 51 | 52 | answer_text = pred.split(answer_prefix)[-1].strip() 53 | numbers = re.findall(r"\d+\.?\d*", answer_text.replace(",", "")) 54 | extracted_answer = numbers[-1].rstrip(".") if numbers else "" 55 | 56 | if "." in extracted_answer: 57 | extracted_answer = extracted_answer.rstrip("0").rstrip(".") 58 | true = true.replace(",", "") 59 | extracted_answer = extracted_answer.replace(",", "") 60 | 61 | return extracted_answer, true == extracted_answer 62 | 63 | 64 | def math_parse(self, answer_prefix, pred, true): 65 | if "<|im_end|>" not in pred and "<|eot_id|>" not in pred and "" not in pred and "<|END_OF_TURN_TOKEN|>" not in pred: 66 | return "Incomplete", False 67 | 68 | extracted_answer = self.extract_boxed_content(pred) 69 | extracted_answer = extracted_answer[0] if extracted_answer else None 70 | if extracted_answer: 71 | extracted_answer = extracted_answer.replace(" ", "") 72 | true = true.replace(" ", "") 73 | 74 | return extracted_answer, true == extracted_answer 75 | 76 | 77 | def mmmlu_parse(self, pred, true): 78 | if "<|im_end|>" not in pred and "<|eot_id|>" not in pred and "" not in pred and "<|END_OF_TURN_TOKEN|>" not in pred: 79 | return "Incomplete", False 80 | 81 | ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer\s*:\s*([A-D])" 82 | pred = pred.replace("$", "") 83 | pred = pred.replace("(", "") 84 | pred = pred.replace(")", "") 85 | match = re.search(ANSWER_PATTERN_MULTICHOICE, pred) 86 | extracted_answer = match.group(1) if match else None 87 | 88 | return extracted_answer, true == extracted_answer 89 | 90 | 91 | def commonsenseqa_parse(self, pred, true): 92 | if "<|im_end|>" not in pred and "<|eot_id|>" not in pred and "" not in pred and "<|END_OF_TURN_TOKEN|>" not in pred: 93 | return "Incomplete", False 94 | 95 | ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer\s*:\s*([A-E])" 96 | pred = pred.replace("$", "") 97 | pred = pred.replace("(", "") 98 | pred = pred.replace(")", "") 99 | match = re.search(ANSWER_PATTERN_MULTICHOICE, pred) 100 | extracted_answer = match.group(1) if match else None 101 | 102 | return extracted_answer, true == extracted_answer 103 | 104 | 105 | def belebele_parse(self, pred, true): 106 | alpha_map = ["", "A", "B", "C", "D"] 107 | if "<|im_end|>" not in pred and "<|eot_id|>" not in pred and "" not in pred and "<|END_OF_TURN_TOKEN|>" not in pred: 108 | return "Incomplete", False 109 | 110 | ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer\s*:\s*([A-D])" 111 | pred = pred.replace("$", "") 112 | pred = pred.replace("(", "") 113 | pred = pred.replace(")", "") 114 | match = re.search(ANSWER_PATTERN_MULTICHOICE, pred) 115 | extracted_answer = match.group(1) if match else None 116 | 117 | return extracted_answer, alpha_map[int(true)] == extracted_answer 118 | 119 | def theoremqa_parse(self, answer_type, pred, true): 120 | if "<|im_end|>" not in pred and "<|eot_id|>" not in pred and "" not in pred and "<|END_OF_TURN_TOKEN|>" not in pred: 121 | return "Incomplete", False 122 | 123 | answer_text = pred.split("answer is")[-1].strip() 124 | pred = re.sub(r'\<.*\>', '', answer_text) 125 | 126 | if answer_type == "bool": 127 | if "True" in pred or "true" in pred: 128 | extracted_answer = "True" 129 | elif "False" in pred or "false" in pred: 130 | extracted_answer = "False" 131 | else: 132 | extracted_answer = None 133 | elif answer_type == "integer" or answer_type == "float": 134 | numbers = re.findall(r"\d+\.?\d*", answer_text.replace(",", "")) 135 | extracted_answer = numbers[-1].rstrip(".") if numbers else "" 136 | elif answer_type == "list of integer" or answer_type == "list of float": 137 | match = re.search(r"\[(.*?)\]", pred) 138 | extracted_answer = match.group(1) if match else None 139 | true = true[1:-1] 140 | elif answer_type == "option": 141 | match = re.search(r"\(([a-d])\)", pred) 142 | extracted_answer = match.group(1) if match else None 143 | true = true[1:-1] 144 | 145 | return extracted_answer, true == extracted_answer -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ##
Latent Space Chain-of-Embedding Enables Output-free LLM Self-Evaluation
14 |
19 |
24 |
29 |