├── ASSETS ├── chain.png ├── feature.png ├── coe_insight.png ├── definition.png └── chain-of-embedding.png ├── requirements.txt ├── config_pool.py ├── Data └── load_data.py ├── Scripts └── llm_infer.sh ├── Model └── load_model.py ├── main.py ├── score.py ├── Evaluation ├── eval.py └── match.py ├── prompt_pool.py ├── README.md ├── inference.py └── LICENSE /ASSETS/chain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alsace08/Chain-of-Embedding/HEAD/ASSETS/chain.png -------------------------------------------------------------------------------- /ASSETS/feature.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alsace08/Chain-of-Embedding/HEAD/ASSETS/feature.png -------------------------------------------------------------------------------- /ASSETS/coe_insight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alsace08/Chain-of-Embedding/HEAD/ASSETS/coe_insight.png -------------------------------------------------------------------------------- /ASSETS/definition.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alsace08/Chain-of-Embedding/HEAD/ASSETS/definition.png -------------------------------------------------------------------------------- /ASSETS/chain-of-embedding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alsace08/Chain-of-Embedding/HEAD/ASSETS/chain-of-embedding.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jsonlines 2 | scikit-learn 3 | matplotlib 4 | seaborn 5 | tqdm==4.66.4 6 | scipy==1.12.0 7 | numpy==1.26 8 | transformers==4.45 9 | accelerate>=0.26.0 10 | -------------------------------------------------------------------------------- /config_pool.py: -------------------------------------------------------------------------------- 1 | MODEL_POOL = [ 2 | "Llama-3-8B-Instruct", 3 | "Llama-3-70B-Instruct", 4 | "qwen2-7B-Instruct", 5 | "qwen2-72B-Instruct", 6 | "Mistral-7B-Instruct", 7 | ] 8 | 9 | 10 | DATASET_POOL = [ 11 | "mgsm", 12 | "math", 13 | "commonsenseqa", 14 | "theoremqa", 15 | ] 16 | 17 | LANGUAGE_MAPPING = { 18 | "mgsm": ["en", "bn", "de", "es", "fr", "ja", "ru", "sw", "te", "th", "zh"], 19 | } -------------------------------------------------------------------------------- /Data/load_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import jsonlines 3 | 4 | dataset_path = "./Data/" 5 | 6 | class DatasetInfo: 7 | def __init__(self, dataset_name): 8 | self.dataset_name = dataset_name 9 | self.data = [] 10 | with open(dataset_path + self.dataset_name + ".jsonl", "r+", encoding="utf8") as f: 11 | for item in jsonlines.Reader(f): 12 | self.data.append(item) 13 | self.data_size = len(self.data) 14 | 15 | 16 | def load_one_sample(self, idx): 17 | return self.data[idx] 18 | 19 | -------------------------------------------------------------------------------- /Scripts/llm_infer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PROJECT_PATH="/your/path/to/Chain-of-Embedding" 4 | export CUDA_VISIBLE_DEVICES="0,1" 5 | 6 | model_name="qwen2-7B-Instruct" 7 | 
dataset_list=(mgsm) 8 | 9 | for i in ${dataset_list[*]}; do 10 | python main.py --model_name $model_name \ 11 | --dataset "$i" \ 12 | --print_model_parameter \ 13 | --save_output \ 14 | --save_hidden_states \ 15 | --save_coe_score \ 16 | --save_coe_figure 17 | done 18 | -------------------------------------------------------------------------------- /Model/load_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import random 5 | 6 | import argparse 7 | import scipy.spatial 8 | import math 9 | import json 10 | import torch 11 | import torch.nn as nn 12 | 13 | import numpy as np 14 | import pickle 15 | from tqdm import tqdm 16 | from transformers import ( 17 | AutoTokenizer, 18 | AutoModelForCausalLM, 19 | AutoConfig, 20 | GenerationConfig, 21 | ) 22 | 23 | 24 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 25 | MODEL_REPO = "./Model/" 26 | 27 | 28 | def load_base_model(args): 29 | config_kwargs = { 30 | "trust_remote_code": True, 31 | "cache_dir": None, 32 | "revision": 'main', 33 | "use_auth_token": None, 34 | "output_hidden_states": True 35 | } 36 | config = AutoConfig.from_pretrained(MODEL_REPO + args.model_name, **config_kwargs) 37 | model = AutoModelForCausalLM.from_pretrained( 38 | MODEL_REPO + args.model_name, 39 | config=config, 40 | torch_dtype=torch.float32, 41 | device_map='auto', 42 | trust_remote_code=True 43 | ) 44 | 45 | tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO + args.model_name, trust_remote_code=True) 46 | tokenizer.pad_token_id = 0 47 | 48 | return model, tokenizer, config -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import math 5 | import json 6 | import random 7 | import argparse 8 | from tqdm import tqdm 9 | 10 | import numpy as np 11 | import pickle 12 | import scipy.spatial 13 | import torch 14 | import torch.nn as nn 15 | 16 | from transformers import ( 17 | AutoTokenizer, 18 | AutoModelForCausalLM, 19 | AutoConfig, 20 | GenerationConfig, 21 | ) 22 | 23 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 24 | print(f'device: {device}') 25 | 26 | project_root_path = os.environ["PROJECT_PATH"] 27 | sys.path.append(project_root_path) 28 | from Data.load_data import DatasetInfo 29 | from Model.load_model import load_base_model 30 | from config_pool import MODEL_POOL, DATASET_POOL, LANGUAGE_MAPPING 31 | from inference import Inference 32 | 33 | 34 | if __name__ == '__main__': 35 | parser = argparse.ArgumentParser(description="chain-of-embedding") 36 | 37 | parser.add_argument("--model_name", type=str, default="Llama-3-8B-Instruct", choices=MODEL_POOL) 38 | parser.add_argument("--dataset", type=str, default="mgsm", choices=DATASET_POOL) 39 | parser.add_argument("--max_output_token", type=int, default=2048) 40 | 41 | parser.add_argument("--print_model_parameter", action="store_true") 42 | parser.add_argument("--save_output", action="store_true") 43 | parser.add_argument("--save_hidden_states", action="store_true") 44 | parser.add_argument("--save_coe_score", action="store_true") 45 | parser.add_argument("--save_coe_figure", action="store_true") 46 | 47 | args = parser.parse_args() 48 | args.max_output_token = 2048 if "Instruct" in args.model_name else 128 49 | 50 | model, tokenizer, config = load_base_model(args) 51 | if args.print_model_parameter: 52 | 
print("********** Module Name and Size **********\n") 53 | for param_tensor in model.state_dict(): 54 | print(param_tensor,'\t',model.state_dict()[param_tensor].size()) 55 | 56 | model_info = { 57 | "model_name": args.model_name, 58 | "model_ckpt": model, 59 | "tokenizer": tokenizer, 60 | "model_config": config, 61 | "generation_config": GenerationConfig(), 62 | "max_output_token": args.max_output_token 63 | } 64 | dataset_info = { 65 | "dataset_name": args.dataset, 66 | } 67 | verbose = { 68 | "save_output": args.save_output, 69 | "save_hidden_states": args.save_hidden_states, 70 | "save_coe_score": args.save_coe_score, 71 | "save_coe_figure": args.save_coe_figure 72 | } 73 | 74 | print(f"***** Model Name: *****\n{args.model_name}") 75 | print(f"***** Dataset Name: *****\n{args.dataset}") 76 | print(f"***** Dataset Size: *****\n{DatasetInfo(args.dataset).data_size}") 77 | 78 | language_list = LANGUAGE_MAPPING[args.dataset] if args.dataset in LANGUAGE_MAPPING else ["en"] 79 | for lang in language_list: 80 | dataset_info["language"] = lang 81 | Infer = Inference(model_info, dataset_info, verbose) 82 | Infer.dataset_inference() 83 | -------------------------------------------------------------------------------- /score.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 5 | import scipy.spatial 6 | from scipy.stats import entropy 7 | import math 8 | import json 9 | 10 | from torch.utils.dlpack import to_dlpack 11 | from torch.utils.dlpack import from_dlpack 12 | import torch 13 | import torch.nn.functional as F 14 | import numpy as np 15 | 16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 17 | 18 | 19 | 20 | class OutputScoreInfo: 21 | def __init__(self, output_scores): 22 | self.output_scores = output_scores 23 | self.all_token_re = [] 24 | self.all_token_max_re = [] 25 | for token in range(len(self.output_scores)): 26 | re = self.output_scores[token][0].tolist() 27 | re = F.softmax(torch.tensor(re).to(device), 0).cpu().tolist() 28 | self.all_token_re.append(re) 29 | self.all_token_max_re.append(max(re)) 30 | 31 | def compute_maxprob(self): 32 | seq_prob_list = self.all_token_max_re 33 | max_prob = np.mean(seq_prob_list) 34 | return max_prob 35 | 36 | def compute_ppl(self): 37 | seq_ppl_list = [math.log(max_re) for max_re in self.all_token_max_re] 38 | ppl = -np.mean(seq_ppl_list) 39 | return ppl 40 | 41 | def compute_entropy(self): 42 | seq_entropy_list = [entropy(re, base=2) for re in self.all_token_re] 43 | seq_entropy = np.mean(seq_entropy_list) 44 | return seq_entropy 45 | 46 | 47 | 48 | class CoEScoreInfo: 49 | def __init__(self, hidden_states): 50 | self.hidden_states = hidden_states 51 | 52 | def compute_CoE_Mag(self): 53 | hs_all_layer = self.hidden_states 54 | layer_num = len(hs_all_layer) 55 | 56 | norm_denominator = np.linalg.norm(hs_all_layer[-1] - hs_all_layer[0], ord=2) 57 | al_repdiff = np.array([hs_all_layer[i+1] - hs_all_layer[i] for i in range(layer_num - 1)]) 58 | al_repdiff_norm = [np.linalg.norm(item, ord=2) / norm_denominator for item in al_repdiff] 59 | al_repdiff_ave = np.mean(np.array(al_repdiff_norm)) 60 | al_repdiff_var = np.var(np.array(al_repdiff_norm)) 61 | return al_repdiff_norm, al_repdiff_ave, al_repdiff_var 62 | 63 | 64 | def compute_CoE_Ang(self): 65 | hs_all_layer = self.hidden_states 66 | layer_num = len(hs_all_layer) 67 | 68 | al_semdiff = [] 69 | norm_denominator = np.dot(hs_all_layer[-1], hs_all_layer[0]) / (np.linalg.norm(hs_all_layer[-1], 
ord=2) * np.linalg.norm(hs_all_layer[0], ord=2)) 70 | norm_denominator = math.acos(norm_denominator) 71 | for i in range(layer_num - 1): 72 | a = hs_all_layer[i + 1] 73 | b = hs_all_layer[i] 74 | dot_product = np.dot(a, b) 75 | norm_a, norm_b = np.linalg.norm(a, ord=2), np.linalg.norm(b, ord=2) 76 | similarity = dot_product / (norm_a * norm_b) 77 | similarity = similarity if similarity <= 1 else 1 78 | 79 | arccos_sim = math.acos(similarity) 80 | al_semdiff.append(arccos_sim / norm_denominator) 81 | 82 | al_semdiff_norm = np.array(al_semdiff) 83 | al_semdiff_ave = np.mean(np.array(al_semdiff_norm)) 84 | al_semdiff_var = np.var(np.array(al_semdiff_norm)) 85 | 86 | return al_semdiff_norm, al_semdiff_ave, al_semdiff_var 87 | 88 | def compute_CoE_R(self): 89 | _, al_repdiff_ave, _ = self.compute_CoE_Mag() 90 | _, al_semdiff_ave, _ = self.compute_CoE_Ang() 91 | 92 | return al_repdiff_ave - al_semdiff_ave 93 | 94 | def compute_CoE_C(self): 95 | al_repdiff_norm, _, _ = self.compute_CoE_Mag() 96 | al_semdiff_norm, _, _ = self.compute_CoE_Ang() 97 | x_list = np.array([al_repdiff_norm[i] * math.cos(al_semdiff_norm[i]) for i in range(len(al_semdiff_norm))]) 98 | y_list = np.array([al_repdiff_norm[i] * math.sin(al_semdiff_norm[i]) for i in range(len(al_semdiff_norm))]) 99 | al_combdiff_x_ave = np.mean(x_list) 100 | al_combdiff_y_ave = np.mean(y_list) 101 | al_combdiff_x_var = np.mean(x_list) 102 | al_combdiff_y_var = np.mean(y_list) 103 | 104 | return math.sqrt(al_combdiff_x_ave ** 2 + al_combdiff_y_ave ** 2) 105 | -------------------------------------------------------------------------------- /Evaluation/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | import pickle 5 | import argparse 6 | 7 | from sklearn.metrics import roc_curve, auc, precision_recall_curve 8 | from scipy import interpolate 9 | 10 | project_root_path = os.environ["PROJECT_PATH"] 11 | sys.path.append(project_root_path) 12 | from Data.load_data import DatasetInfo 13 | from config_pool import MODEL_POOL, DATASET_POOL, LANGUAGE_MAPPING 14 | from match import AnswerParsing 15 | 16 | 17 | class StandardEvaluation: 18 | def __init__(self, dataset_list): 19 | self.data_all = [] 20 | self.data_size = 0 21 | for i, dataset in enumerate(dataset_list): 22 | data_loader = DatasetInfo(args.dataset) 23 | self.data_all.extend(data_loader.data) 24 | self.data_size += data_loader.data_size 25 | 26 | def std_eval(self, args): 27 | answerparsing = AnswerParsing(args.dataset) 28 | output_dir = os.path.join(project_root_path, f"OutputInfo/{args.language}/Output", args.model_name, args.dataset) 29 | coe_dir = os.path.join(project_root_path, f"OutputInfo/{args.language}/CoE", args.model_name, args.dataset) 30 | 31 | output_list, coe_list, binary_list = [], [], [] 32 | acc = 0 33 | for i in range(self.data_size): 34 | sample = self.data_all[i] 35 | true_output = sample["answer"] 36 | 37 | with open(os.path.join(output_dir, f"{args.dataset}_{str(i)}.pkl"), 'rb') as file: 38 | output = pickle.load(file) 39 | pred_output = output["output_seq"] 40 | 41 | with open(os.path.join(coe_dir, f"{args.dataset}_{str(i)}.pkl"), 'rb') as file: 42 | coe = pickle.load(file) 43 | 44 | extracted_answer, binary = answerparsing.dataset_parse(pred_output, true_output, sample) 45 | if binary: acc += 1 46 | 47 | output_list.append(output) 48 | coe_list.append(coe) 49 | binary_list.append(binary) 50 | 51 | return round(acc / self.data_size, 3), output_list, coe_list, binary_list 52 | 53 | 54 | 
class SelfEvaluation: 55 | def __init__(self, dataset_list): 56 | self.data_all = [] 57 | self.data_size = 0 58 | for i, dataset in enumerate(dataset_list): 59 | data_loader = DatasetInfo(args.dataset) 60 | self.data_all.extend(data_loader.data) 61 | self.data_size += data_loader.data_size 62 | 63 | def self_eval(self, score_list, binary_list): 64 | fpr, tpr, thresholds = roc_curve(binary_list, score_list) 65 | auroc = auc(fpr, tpr) 66 | fpr95 = float(interpolate.interp1d(tpr, fpr)(0.95)) 67 | precision, recall, _ = precision_recall_curve(binary_list, score_list) 68 | aupr = auc(recall, precision) 69 | 70 | return round(auroc * 100, 2), round(fpr95 * 100, 2), round(aupr * 100, 2) 71 | 72 | 73 | if __name__ == '__main__': 74 | parser = argparse.ArgumentParser(description="eval") 75 | parser.add_argument("--model_name", type=str, default="Llama-3-8B-Instruct", choices=MODEL_POOL) 76 | parser.add_argument("--dataset", type=str, default="mgsm", choices=DATASET_POOL) 77 | parser.add_argument("--language", type=str, default="en") 78 | 79 | args = parser.parse_args() 80 | 81 | stdeval = StandardEvaluation([args.dataset]) 82 | acc, output_list, coe_list, binary_list = stdeval.std_eval(args) 83 | print(f"# Accuracy: {acc}") 84 | 85 | input_list = [output_list[i]["input_seq"] for i in range(len(output_list))] 86 | maxprob_list = [output_list[i]["maxprob"] for i in range(len(output_list))] 87 | ppl_list = [1 / output_list[i]["ppl"] for i in range(len(output_list))] 88 | entropy_list = [1 / output_list[i]["entropy"] for i in range(len(output_list))] 89 | coer_list = [coe_list[i]["R"] for i in range(len(coe_list))] 90 | coec_list = [coe_list[i]["C"] for i in range(len(coe_list))] 91 | 92 | selfeval = SelfEvaluation([args.dataset]) 93 | maxprob_auroc, maxprob_fpr95, maxprob_aupr = selfeval.self_eval(maxprob_list, binary_list) 94 | ppl_auroc, ppl_fpr95, ppl_aupr = selfeval.self_eval(ppl_list, binary_list) 95 | entropy_auroc, entropy_fpr95, entropy_aupr = selfeval.self_eval(entropy_list, binary_list) 96 | coer_auroc, coer_fpr95, coer_aupr = selfeval.self_eval(coer_list, binary_list) 97 | coec_auroc, coec_fpr95, coec_aupr = selfeval.self_eval(coec_list, binary_list) 98 | 99 | print(f"{'maxprob_auroc'.rjust(13)}: {maxprob_auroc:.2f} {'maxprob_fpr95'.rjust(13)}: {maxprob_fpr95:.2f} {'maxprob_aupr'.rjust(13)}: {maxprob_aupr:.2f}") 100 | print(f"{'ppl_auroc'.rjust(13)}: {ppl_auroc:.2f} {'ppl_fpr95'.rjust(13)}: {ppl_fpr95:.2f} {'ppl_aupr'.rjust(13)}: {ppl_aupr:.2f}") 101 | print(f"{'entropy_auroc'.rjust(13)}: {entropy_auroc:.2f} {'entropy_fpr95'.rjust(13)}: {entropy_fpr95:.2f} {'entropy_aupr'.rjust(13)}: {entropy_aupr:.2f}") 102 | print(f"{'coer_auroc'.rjust(13)}: {coer_auroc:.2f} {'coer_fpr95'.rjust(13)}: {coer_fpr95:.2f} {'coer_aupr'.rjust(13)}: {coer_aupr:.2f}") 103 | print(f"{'coec_auroc'.rjust(13)}: {coec_auroc:.2f} {'coec_fpr95'.rjust(13)}: {coec_fpr95:.2f} {'coec_aupr'.rjust(13)}: {coec_aupr:.2f}") 104 | -------------------------------------------------------------------------------- /prompt_pool.py: -------------------------------------------------------------------------------- 1 | DATASET_PROMPTS = { 2 | "GSM8K": "Solve this math problem. Give the reasoning steps before giving the final answer on the last line by itself in the format of \"Answer:\". Do not add anything other than the integer answer after \"Answer:\".\n\nQuestion:\n{input_data}\n", 3 | "mgsm": "Solve this math problem. Give the reasoning steps before giving the final answer on the last line by itself in the format of \"Answer:\". 
Do not add anything other than the integer answer after \"Answer:\".\n\nQuestion:\n{input_data}\n", 4 | "math": "Question: {input_data}\nPlease reason step by step, and put your final answer within \\boxed{{}}\n", 5 | "mmmlu": "Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.\n\nQuestion:\n{input_data}\n", 6 | "belebele": "Answer the following multiple choice reading-comprehension question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Please fully understand the passage and give explanations step by step before answering.\n\n{input_data}\n", 7 | "commonsenseqa": "Answer the following multiple choice common-sense reasoning question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCDE. Think step by step and output the reasoning process before answering.\n\n{input_data}", 8 | "theoremqa": ''' 9 | Below is an instruction that describes a task, paired with an input that provides further context. 10 | Write a response that appropriately completes the request. 11 | 12 | ### Instruction: 13 | Please read a math problem, and then think step by step to derive the answer. The answer is decided by Answer Type. 14 | If the Answer type in [bool], the answer needs to be True or False. 15 | Else if the Answer type in [integer, float] , The answer needs to be in numerical form. 16 | Else if the Answer type in [list of integer, list of float] , the answer needs to be a list of number like [2, 3, 4]. 17 | Else if the Answer type in [option], the answer needs to be an option like (a), (b), (c), (d). 18 | You need to output the answer in your final sentence like 'Therefore, the answer is ...'. 19 | 20 | ### Question: 21 | ''' + "{input_data}\n\n### Answer_type: {answer_type}\n\n### Response:" 22 | } 23 | 24 | 25 | 26 | QUESTIONS = { 27 | "en": "Question:", 28 | "es": "Pregunta:", 29 | "fr": "Question:", 30 | "de": "Frage:", 31 | "ru": "Задача:", 32 | "zh": "问题:", 33 | "ja": "問題:", 34 | "th": "โจทย์:", 35 | "sw": "Swali:", 36 | "bn": "প্রশ্ন:", 37 | "te": "ప్రశ్న:" 38 | } 39 | 40 | ANSWERS = { 41 | "en": "Step-by-Step Answer:", 42 | "es": "Respuesta paso a paso:", 43 | "fr": "Réponse étape par étape :", 44 | "de": "Schritt-für-Schritt-Antwort:", 45 | "ru": "Пошаговоерешение:", 46 | "zh": "逐步解答:", 47 | "ja": "ステップごとの答え:", 48 | "th": "คำตอบทีละขั้นตอน:", 49 | "sw": "Jibu la Hatua kwa Hatua:", 50 | "bn": "ধাপে ধাপে উত্তর:", 51 | "te": "దశలవారీగా సమాధానం:" 52 | } 53 | 54 | 55 | INSTRUCTIONS = { 56 | "en": """Solve this math problem. Give the reasoning steps before giving the final answer on the last line by itself in the format of "Answer:". Do not add anything other than the integer answer after "Answer:".""", 57 | "bn": """এই গণিতের সমস্যাটি সমাধান করুন। চূড়ান্ত উত্তর দেওয়ার আগে যুক্তিসম্পন্ন পদক্ষেপ প্রদান করুন। চূড়ান্ত উত্তরটি একক সংখ্যা হিসাবে "উত্তর:" এর পরে শেষ লাইনে দিন। "উত্তর:" এর পরে অন্য কিছু যুক্ত করবেন না।.""", 58 | "de": """Löse dieses Mathematikproblem. Gib die Schritte zur Begründung an, bevor du die endgültige Antwort in der letzten Zeile alleine im Format "Antwort:" gibst. Füge nichts anderes als die ganzzahlige Antwort nach "Antwort:" hinzu.""", 59 | "es": """Resuelve este problema matemático. 
Proporciona los pasos de razonamiento antes de dar la respuesta final en la última línea por sí misma en el formato de "Respuesta:". No añadas nada más que la respuesta entera después de "Respuesta:".""", 60 | "fr": """Résolvez ce problème de mathématiques. Donnez les étapes de raisonnement avant de fournir la réponse finale sur la dernière ligne elle-même dans le format de "Réponse:". N'ajoutez rien d'autre que la réponse entière après "Réponse:".""", 61 | "ja": """の数学の問題を解いてください。最終的な答えを出す前に、解答の推論過程を記述してください。そして最後の行には "答え:" の形式で答えを記述し、その後には整数の答え以外何も追加しないでください。""", 62 | "ru": """Решите эту математическую задачу. Объясните шаги рассуждения перед тем, как дать окончательный ответ в последней строке сам по себе в формате "Ответ:". Не добавляйте ничего, кроме целочисленного ответа после "Ответ:".""", 63 | "sw": """Suluhisha tatizo hili la hesabu. Toa hatua za mantiki kabla ya kutoa jibu la mwisho kwenye mstari wa mwisho peke yake katika muundo wa "Jibu:". Usiongeze chochote kingine isipokuwa jibu la integer baada ya "Jibu:".""", 64 | "te": """ఈ గణిత సమస్యను పరిష్కరించండి. చివరి సమాధానాన్ని ఇవ్వదానికి ముందు తర్కాత్మక అదుగులను ఇవ్వండి. చివరి పంక్తిలో మాత్రమే 'సమాధానం:' అనే ఆకారంలో చివరి సమాధానాద్ని ఇవ్వండి సమాధానం: తర్వాత పూర్ణాంక సమాధానానికి తప్పించి ఎదేనా చేర్చవద్దు.""", 65 | "th": """แก้ปัญหาคณิตศาสตร์นี้ ให้ให้ขั้นตอนการใช้เหตุผลก่อนที่จะให้คำตอบสุดท้ายในบรรทัดสุดท้ายโดยอยู่ในรูปแบบ "คำตอบ:" ไม่ควรเพิ่มอะไรนอกจากคำตอบที่เป็นจำนวนเต็มหลังจาก "คำตอบ:""""", 66 | "zh": """解决这个数学问题。在最后一行给出答案前,请提供推理步骤。最后一行应该以 "答案: " 的形式独立给出答案。在 "答案:" 后不要添加除整数答案之外的任何内容。""", 67 | } 68 | 69 | ANSWER_PREFIX = { 70 | "en": "Answer", 71 | "bn": "উত্তর", 72 | "de": "Antwort", 73 | "es": "Respuesta", 74 | "fr": "Réponse", 75 | "ja": "答え", 76 | "ru": "Ответ", 77 | "sw": "Jibu", 78 | "te": "సమాధానం", 79 | "th": "คำตอบ", 80 | "zh": "答案", 81 | } 82 | 83 | 84 | -------------------------------------------------------------------------------- /Evaluation/match.py: -------------------------------------------------------------------------------- 1 | import re 2 | from prompt_pool import ANSWER_PREFIX 3 | 4 | 5 | class AnswerParsing: 6 | def __init__(self, dataset): 7 | self.dataset = dataset 8 | 9 | def dataset_parse(self, pred, true, sample): 10 | if self.dataset == "mgsm" or self.dataset == "gsm8k": 11 | extracted_answer, binary = self.mgsm_parse(ANSWER_PREFIX["en"], pred, true) 12 | elif self.dataset == "math": 13 | extracted_answer, binary = self.math_parse(ANSWER_PREFIX["en"], pred, true) 14 | elif self.dataset == "commonsenseqa": 15 | extracted_answer, binary = self.commonsenseqa_parse(pred, true) 16 | elif self.dataset == "theoremqa": 17 | extracted_answer, binary = self.theoremqa_parse(sample["answer_type"], pred, true) 18 | elif self.dataset == "mmmlu": 19 | extracted_answer, binary = self.mmmlu_parse(pred, true) 20 | elif self.dataset == "belebele": 21 | extracted_answer, binary = self.belebele_parse(pred, true) 22 | 23 | return extracted_answer, binary 24 | 25 | 26 | def extract_boxed_content(self, text): 27 | pattern = re.compile(r'\\boxed{') 28 | matches = pattern.finditer(text) 29 | results = [] 30 | for match in matches: 31 | start_pos = match.end() 32 | brace_count = 1 33 | i = start_pos 34 | while i < len(text) and brace_count > 0: 35 | if text[i] == '{': 36 | brace_count += 1 37 | elif text[i] == '}': 38 | brace_count -= 1 39 | i += 1 40 | if brace_count == 0: 41 | results.append(text[start_pos:i-1]) 42 | return results 43 | 44 | 45 | def mgsm_parse(self, answer_prefix, pred, true): 46 | if "<|im_end|>" not in 
pred and "<|eot_id|>" not in pred and "" not in pred and "<|END_OF_TURN_TOKEN|>" not in pred: 47 | return "Incomplete", False 48 | 49 | if answer_prefix not in pred: 50 | return None, False 51 | 52 | answer_text = pred.split(answer_prefix)[-1].strip() 53 | numbers = re.findall(r"\d+\.?\d*", answer_text.replace(",", "")) 54 | extracted_answer = numbers[-1].rstrip(".") if numbers else "" 55 | 56 | if "." in extracted_answer: 57 | extracted_answer = extracted_answer.rstrip("0").rstrip(".") 58 | true = true.replace(",", "") 59 | extracted_answer = extracted_answer.replace(",", "") 60 | 61 | return extracted_answer, true == extracted_answer 62 | 63 | 64 | def math_parse(self, answer_prefix, pred, true): 65 | if "<|im_end|>" not in pred and "<|eot_id|>" not in pred and "" not in pred and "<|END_OF_TURN_TOKEN|>" not in pred: 66 | return "Incomplete", False 67 | 68 | extracted_answer = self.extract_boxed_content(pred) 69 | extracted_answer = extracted_answer[0] if extracted_answer else None 70 | if extracted_answer: 71 | extracted_answer = extracted_answer.replace(" ", "") 72 | true = true.replace(" ", "") 73 | 74 | return extracted_answer, true == extracted_answer 75 | 76 | 77 | def mmmlu_parse(self, pred, true): 78 | if "<|im_end|>" not in pred and "<|eot_id|>" not in pred and "" not in pred and "<|END_OF_TURN_TOKEN|>" not in pred: 79 | return "Incomplete", False 80 | 81 | ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer\s*:\s*([A-D])" 82 | pred = pred.replace("$", "") 83 | pred = pred.replace("(", "") 84 | pred = pred.replace(")", "") 85 | match = re.search(ANSWER_PATTERN_MULTICHOICE, pred) 86 | extracted_answer = match.group(1) if match else None 87 | 88 | return extracted_answer, true == extracted_answer 89 | 90 | 91 | def commonsenseqa_parse(self, pred, true): 92 | if "<|im_end|>" not in pred and "<|eot_id|>" not in pred and "" not in pred and "<|END_OF_TURN_TOKEN|>" not in pred: 93 | return "Incomplete", False 94 | 95 | ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer\s*:\s*([A-E])" 96 | pred = pred.replace("$", "") 97 | pred = pred.replace("(", "") 98 | pred = pred.replace(")", "") 99 | match = re.search(ANSWER_PATTERN_MULTICHOICE, pred) 100 | extracted_answer = match.group(1) if match else None 101 | 102 | return extracted_answer, true == extracted_answer 103 | 104 | 105 | def belebele_parse(self, pred, true): 106 | alpha_map = ["", "A", "B", "C", "D"] 107 | if "<|im_end|>" not in pred and "<|eot_id|>" not in pred and "" not in pred and "<|END_OF_TURN_TOKEN|>" not in pred: 108 | return "Incomplete", False 109 | 110 | ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer\s*:\s*([A-D])" 111 | pred = pred.replace("$", "") 112 | pred = pred.replace("(", "") 113 | pred = pred.replace(")", "") 114 | match = re.search(ANSWER_PATTERN_MULTICHOICE, pred) 115 | extracted_answer = match.group(1) if match else None 116 | 117 | return extracted_answer, alpha_map[int(true)] == extracted_answer 118 | 119 | def theoremqa_parse(self, answer_type, pred, true): 120 | if "<|im_end|>" not in pred and "<|eot_id|>" not in pred and "" not in pred and "<|END_OF_TURN_TOKEN|>" not in pred: 121 | return "Incomplete", False 122 | 123 | answer_text = pred.split("answer is")[-1].strip() 124 | pred = re.sub(r'\<.*\>', '', answer_text) 125 | 126 | if answer_type == "bool": 127 | if "True" in pred or "true" in pred: 128 | extracted_answer = "True" 129 | elif "False" in pred or "false" in pred: 130 | extracted_answer = "False" 131 | else: 132 | extracted_answer = None 133 | elif answer_type == "integer" or answer_type == "float": 134 | numbers = 
re.findall(r"\d+\.?\d*", answer_text.replace(",", "")) 135 | extracted_answer = numbers[-1].rstrip(".") if numbers else "" 136 | elif answer_type == "list of integer" or answer_type == "list of float": 137 | match = re.search(r"\[(.*?)\]", pred) 138 | extracted_answer = match.group(1) if match else None 139 | true = true[1:-1] 140 | elif answer_type == "option": 141 | match = re.search(r"\(([a-d])\)", pred) 142 | extracted_answer = match.group(1) if match else None 143 | true = true[1:-1] 144 | 145 | return extracted_answer, true == extracted_answer -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ##
Latent Space Chain-of-Embedding Enables Output-free LLM Self-Evaluation
2 | 3 | 4 | ![Arxiv](https://img.shields.io/badge/Arxiv-2410.13640-red.svg?style=plastic) 5 | ![python 3.10](https://img.shields.io/badge/python-3.10-royalblue.svg?style=plastic) 6 | ![transformer 4.45](https://img.shields.io/badge/transformer-4.45-green.svg?style=plastic) 7 | ![license Apache-2.0](https://img.shields.io/badge/license-Apache%202.0-inactive.svg?style=plastic) 8 | 9 | 10 | $\boxed{\text{Chain-of-Embedding (CoE)}}$ is a brand-new interpretability tool, which **captures a progressive embedding chain from input to output space** by tracking the hidden states of language models during inference. 11 | 12 |
13 | ![Chain-of-Embedding overview](ASSETS/chain-of-embedding.png)
15 | 16 | ***Definition***: 17 |
18 | ![CoE definition](ASSETS/definition.png)
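In code (see ``print_hidden_states`` in ``inference.py``), each layer's sentence embedding is the average of that layer's hidden states over all generated tokens, and the CoE is simply the layer-indexed sequence of these embeddings. The notation below is our shorthand for what the code computes, not the paper's original formulation:

$$
h^{(l)} = \frac{1}{T}\sum_{t=1}^{T} \mathbf{h}^{(l)}_{t}, \qquad
\mathrm{CoE} = \big(h^{(0)}, h^{(1)}, \dots, h^{(L)}\big),
$$

where $T$ is the number of generated tokens and $l$ ranges over the $L+1$ hidden-state layers returned by the model (embedding layer included).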
20 | 21 | ***Two CoE Features***: 22 |
23 | ![Two CoE features: magnitude and angle](ASSETS/feature.png)
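Concretely, the two features can be read off from ``compute_CoE_Mag`` and ``compute_CoE_Ang`` in ``score.py`` (again using our shorthand $h^{(l)}$ from above): the per-layer magnitude is the step length between consecutive layer embeddings, and the per-layer angle is the rotation between them, each normalized by the corresponding input-to-output quantity:

$$
\mathrm{Mag}_{l} = \frac{\lVert h^{(l+1)} - h^{(l)} \rVert_2}{\lVert h^{(L)} - h^{(0)} \rVert_2}, \qquad
\mathrm{Ang}_{l} = \frac{\arccos\!\big(\cos(h^{(l+1)}, h^{(l)})\big)}{\arccos\!\big(\cos(h^{(L)}, h^{(0)})\big)},
$$

where $\cos(\cdot,\cdot)$ denotes cosine similarity. The reported features are the layer-averages of $\mathrm{Mag}_{l}$ and $\mathrm{Ang}_{l}$; ``score.py`` then combines them into the two scores CoE-R and CoE-C (see the code sketch below).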
25 | 26 | ***Insight***: When the language model responds correctly and incorrectly, the CoE in the latent space will produce differentiated representations. 27 |
28 | ![CoE insight: correct vs. incorrect responses](ASSETS/coe_insight.png)
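For readers who prefer code to figures, the sketch below shows how the two scores are assembled from a stack of per-layer sentence embeddings. It mirrors ``CoEScoreInfo`` in ``score.py``; the function name ``coe_scores`` and the random toy input are ours, not part of the repository.

```py
import numpy as np

def coe_scores(hs):
    """hs: array of shape (num_layers, hidden_dim) -- per-layer sentence embeddings."""
    diffs = np.diff(hs, axis=0)                          # layer-to-layer embedding changes
    mag_den = np.linalg.norm(hs[-1] - hs[0])             # total input-to-output displacement
    mag = np.linalg.norm(diffs, axis=1) / mag_den        # normalized per-layer step lengths

    def angle(a, b):                                     # angle between two layer embeddings
        cos = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
        return np.arccos(np.clip(cos, -1.0, 1.0))

    ang_den = angle(hs[-1], hs[0])                       # total input-to-output rotation
    ang = np.array([angle(hs[i + 1], hs[i]) for i in range(len(hs) - 1)]) / ang_den

    coe_r = mag.mean() - ang.mean()                      # CoE-R: average magnitude minus average angle
    coe_c = np.hypot((mag * np.cos(ang)).mean(),         # CoE-C: length of the mean 2-D step vector
                     (mag * np.sin(ang)).mean())
    return coe_r, coe_c

# Toy usage: 33 "layers" of 128-dimensional embeddings.
rng = np.random.default_rng(0)
print(coe_scores(rng.normal(size=(33, 128))))
```

As used in ``Evaluation/eval.py``, higher CoE-R / CoE-C values are treated as evidence that the response is correct (they are fed directly into the AUROC/FPR95/AUPR computation).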
30 | 31 | 32 | --- 33 | 34 | ## *Usage Instruction* 35 | 36 | We provide automated scripts in this repository to help you obtain inference-time **CoE scores** and **Visualization** of each sample. 37 | 38 | ## Environment Installation 39 | 40 | ```sh 41 | conda create -n coeeval python=3.10 42 | conda activate coeeval 43 | pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121 44 | pip install -r requirements.txt 45 | ``` 46 | 47 | ## Preparation 48 | 49 | 50 | ### 1. Model 51 | 52 | To prevent instability in remote access, our code uses local model loading. You need to download the model you need to deploy (e.g., Llama-3-8B-Instruct) into the ``Model`` folder and add the name of the model folder to ``MODEL_POOL`` in ``config_pool.py``. 53 | 54 | ``Model/`` 55 | 56 | ```sh 57 | Model 58 | └── Llama-3-8B-Instruct ### add 59 | ├── config.json 60 | ├── model-00001-... 61 | ├── tokenizer_config.json 62 | └── ... 63 | ``` 64 | 65 | ``config_pool.py`` 66 | 67 | ```py 68 | MODEL_POOL = [ 69 | "Llama-3-8B-Instruct", ### add 70 | ] 71 | ``` 72 | 73 | *Note*: Our work is primarily focused on the **zero-shot paradigm for instruct-based models**. If using a base model, it may need to be modified to a few-shot paradigm. 74 | 75 | 76 | ### 2. Dataset 77 | 78 | You need to add the dataset (``.jsonl`` format) to be tested into the ``Data`` folder (e.g. math). 79 | 80 | ``Data/`` 81 | 82 | ```sh 83 | Model 84 | └── math.jsonl ### add 85 | ``` 86 | 87 | Each sample must contain **at least the following keys and values**: 88 | 89 | ```py 90 | { 91 | "id": 1, ### Unique identifier 92 | "en": "Find the units digit of $29 \\cdot 79 + 31 \\cdot 81$.", ### Question described in English 93 | "answer": "2" ### Standard answer without solution process 94 | } 95 | ``` 96 | 97 | Similarly, please add the dataset name to ``DATASET_POOL`` in ``config_pool.py``. 98 | 99 | ``config_pool.py`` 100 | 101 | ```py 102 | DATASET_POOL = [ 103 | "math", ### add 104 | ] 105 | ``` 106 | 107 | We also provide a multilingual interface. If you need to infer a multilingual version of a dataset, please first add the desired language list to ``LANGUAGE_MAPPING`` in ``config_pool.py``, and add question descriptions in the dataset file using the language names as keywords (refer to the mgsm dataset). 108 | 109 | ### 3. Instruction 110 | 111 | The instructions corresponding to different datasets are stored under ``DATASET_PROMPTS`` in ``prompt_pool.py``. We provide instructions for all the datasets used in the paper, with some of them referencing and . 112 | 113 | ``prompt_pool.py`` 114 | 115 | ```py 116 | DATASET_PROMPTS = { 117 | "mgsm": "Solve this math problem. Give the reasoning steps before giving the final answer on the last line by itself in the format of \"Answer:\". Do not add anything other than the integer answer after \"Answer:\".\n\nQuestion:\n{input_data}\n", ### add 118 | } 119 | ``` 120 | 121 | *Note*: Please use **string parsing format** to facilitate automated parsing of input questions. 
122 | 
123 | ## Inference
124 | 
125 | ```sh
126 | bash Scripts/llm_infer.sh
127 | ```
128 | 
129 | You can modify the following parameters in this script:
130 | 
131 | ```sh
132 | #!/bin/bash
133 | 
134 | export PROJECT_PATH="your/path/to/Chain-of-Embedding" ### your local path
135 | export CUDA_VISIBLE_DEVICES="0,1" ### cuda id
136 | 
137 | model_name="qwen2-7B-Instruct" ### model name
138 | dataset_list=(mgsm) ### dataset names (multiple datasets can be tested in one loop)
139 | 
140 | for i in ${dataset_list[*]}; do
141 |     python main.py --model_name $model_name \
142 |         --dataset "$i" \
143 |         ### The following flags are optional:
144 |         --print_model_parameter \
145 |         --save_output \
146 |         --save_hidden_states \
147 |         --save_coe_score \
148 |         --save_coe_figure
149 | done
150 | ```
151 | 
152 | ### Optional Parameters:
153 | 
154 | ```--print_model_parameter```: Print the name and shape of every parameter tensor of the loaded model
155 | 
156 | ```--save_output```: Save the output (text sequence, perplexity, etc.) of each sample to ```./OutputInfo/{language}/Output/{model}/{dataset}/{dataset}_{sample_id}.pkl```
157 | 
158 | ```--save_hidden_states```: Save the layer-wise sentence embeddings (hidden states averaged over output tokens) of each sample to ```./OutputInfo/{language}/HiddenStates/{model}/{dataset}/{dataset}_{sample_id}.pkl```
159 | 
160 | ```--save_coe_score```: Save the CoE score of each sample to ```./OutputInfo/{language}/CoE/{model}/{dataset}/{dataset}_{sample_id}.pkl```
161 | 
162 | ```--save_coe_figure```: Draw the CoE trajectory of each sample and save it to ```./Figure/{language}/{model}/{dataset}/{dataset}_{sample_id}.png```
163 | 
164 | ### Core CoE Score Computation
165 | 
166 | In ```score.py```.
167 | 
168 | ## Citation
169 | 
170 | If you use our technique or are inspired by our work, you are welcome to cite our paper and offer valuable suggestions.
171 | 172 | ``` 173 | @article{wang2024latent, 174 | title={Latent Space Chain-of-Embedding Enables Output-free LLM Self-Evaluation}, 175 | author={Wang, Yiming and Zhang, Pei and Yang, Baosong and Wong, Derek F and Wang, Rui}, 176 | journal={arXiv preprint arXiv:2410.13640}, 177 | year={2024} 178 | } 179 | ``` 180 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import random 5 | 6 | import pickle 7 | import argparse 8 | import scipy.spatial 9 | import math 10 | import json 11 | import torch 12 | import torch.nn as nn 13 | 14 | from sklearn.decomposition import PCA 15 | import matplotlib.pyplot as plt 16 | import matplotlib.font_manager as fm 17 | from matplotlib.colors import Normalize 18 | import seaborn as sns 19 | from collections import Counter 20 | 21 | import numpy as np 22 | import pickle 23 | from tqdm import tqdm 24 | from transformers import ( 25 | AutoTokenizer, 26 | AutoModelForCausalLM, 27 | AutoConfig, 28 | GenerationConfig, 29 | ) 30 | 31 | project_root_path = os.environ["PROJECT_PATH"] 32 | sys.path.append(project_root_path) 33 | from Data.load_data import DatasetInfo 34 | from prompt_pool import * 35 | from score import OutputScoreInfo, CoEScoreInfo 36 | 37 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 38 | 39 | 40 | class Inference: 41 | def __init__(self, model_info: dict, dataset_info: dict, verbose: dict): 42 | self.model_info = model_info 43 | self.dataset_info = dataset_info 44 | self.verbose = verbose 45 | 46 | self.model = self.model_info["model_ckpt"] 47 | self.model_name = self.model_info["model_name"] 48 | self.config = self.model_info["model_config"] 49 | self.generation_config = self.model_info["generation_config"] 50 | self.tokenizer = self.model_info["tokenizer"] 51 | self.max_output_token = self.model_info["max_output_token"] 52 | 53 | self.dataset_name = self.dataset_info["dataset_name"] 54 | self.data_loader = DatasetInfo(self.dataset_name) 55 | self.data_all = self.data_loader.data 56 | self.data_size = self.data_loader.data_size 57 | self.language = self.dataset_info["language"] 58 | 59 | self.sample_info = {} 60 | 61 | 62 | def dataset_inference(self): 63 | self.greedy_inference() 64 | 65 | 66 | def greedy_inference(self): 67 | for i in tqdm(range(self.data_size)): 68 | print("*"*30 + f" index {str(i)} " + "*"*30) 69 | sample = self.data_all[i] 70 | input_data, output_data, model_input, input_ids = self.parse_input(sample) 71 | self.sample_info = { 72 | "input": { 73 | "raw_input_data": input_data, 74 | "model_input": model_input, 75 | "model_input_ids": input_ids, 76 | }, 77 | "output": { 78 | "raw_output_data": output_data, 79 | } 80 | } 81 | 82 | with torch.no_grad(): 83 | generation_output = self.model_inference() 84 | self.sample_info["output"]["output_scores"] = generation_output.scores 85 | self.sample_info["output"]["output_seq"] = generation_output.sequences 86 | self.sample_info["output"]["attentions"] = generation_output.attentions 87 | self.sample_info["output"]["all_token_hidden_states"] = generation_output.hidden_states # output_len x layer_num x sampling_num x beam_search x hidden_dim 88 | self.sample_info["output"]["output_len"] = min(self.max_output_token, len(generation_output.scores)) 89 | 90 | output_seq, maxprob, ppl, entropy = self.print_output() 91 | output = {'id': i, 92 | 'answer_type': sample["answer_type"] if 
self.dataset_name == "theoremqa" else "", 93 | 'input_seq': self.sample_info["input"]["model_input"], 94 | 'output_seq': output_seq, 95 | 'maxprob': maxprob, 96 | 'ppl': ppl, 97 | 'entropy': entropy} 98 | if self.verbose["save_output"]: self.save_output(output, i) 99 | 100 | hidden_states = self.print_hidden_states() 101 | if self.verbose["save_hidden_states"]: self.save_hidden_states(hidden_states, i) 102 | 103 | CoE_score = self.print_CoE_score() 104 | if self.verbose["save_coe_score"]: self.save_CoE_score(CoE_score, i) 105 | if self.verbose["save_coe_figure"]: self.save_CoE_figure(hidden_states, i) 106 | 107 | 108 | def model_inference(self): 109 | input_ids = self.sample_info["input"]["model_input_ids"] 110 | self.model.eval() 111 | terminators = [self.tokenizer.eos_token_id, self.tokenizer.convert_tokens_to_ids("<|eot_id|>")] \ 112 | if "Llama" in self.model_name else self.tokenizer.eos_token_id 113 | 114 | time_start = time.time() 115 | generation_output = self.model.generate( 116 | input_ids=input_ids.to(device), 117 | pad_token_id=self.tokenizer.eos_token_id, 118 | eos_token_id=terminators, 119 | generation_config=self.generation_config, 120 | return_dict_in_generate=True, 121 | max_new_tokens=self.max_output_token, 122 | output_attentions=True, 123 | output_hidden_states=True, 124 | output_scores=True, 125 | do_sample=False, 126 | ) 127 | time_end = time.time() 128 | print(f'inference time: {round(time_end - time_start, 4)}') 129 | 130 | return generation_output 131 | 132 | 133 | def parse_input(self, sample): 134 | input_data = sample[self.language] 135 | output_data = sample["answer"] 136 | 137 | model_input = DATASET_PROMPTS[self.dataset_name].replace("{input_data}", input_data) 138 | if self.dataset_name == "theoremqa": 139 | model_input = model_input.replace("{answer_type}", sample["answer_type"]) 140 | input_ids = self.tokenizer.apply_chat_template([{"role": "user", "content": model_input}], 141 | tokenize=True, add_generation_prompt=True, return_tensors="pt") 142 | input_len = len(input_ids[0]) 143 | 144 | print(f"********** Input Text (length: {input_len}) **********\n{input_data}\n") 145 | print(f"********** Input ID **********\n{input_ids}\n") 146 | 147 | return input_data, output_data, model_input, input_ids 148 | 149 | 150 | def print_output(self): 151 | output_scores = self.sample_info["output"]["output_scores"] 152 | output_seq = self.sample_info["output"]["output_seq"] 153 | true_output = self.sample_info["output"]["raw_output_data"] 154 | output_len = self.sample_info["output"]["output_len"] 155 | 156 | output_seq = self.tokenizer.decode(output_seq[0][-output_len:]) 157 | print(f"********** Model-generated Text (length: {output_len}) **********\n{output_seq}\n") 158 | print(f"********** True Output Text **********\n{true_output}\n") 159 | 160 | outputinfo = OutputScoreInfo(output_scores) 161 | maxprob = outputinfo.compute_maxprob() 162 | ppl = outputinfo.compute_ppl() 163 | entropy = outputinfo.compute_entropy() 164 | print(f"********** Output Info: **********\nmaxprob {maxprob}; perplexity {ppl}; entropy {entropy}\n") 165 | 166 | return output_seq, maxprob, ppl, entropy 167 | 168 | 169 | def save_output(self, output, i): 170 | filedir = os.path.join(project_root_path, f'OutputInfo/{self.language}/Output', self.model_name, self.dataset_name) 171 | if not os.path.exists(filedir): 172 | os.makedirs(filedir) 173 | with open(os.path.join(filedir, self.dataset_name + '_' + str(i) + '.pkl'), 'wb') as file: 174 | pickle.dump(output, file) 175 | 176 | 177 | def 
print_hidden_states(self): 178 | hidden_states = self.sample_info["output"]["all_token_hidden_states"] 179 | output_len = self.sample_info["output"]["output_len"] 180 | 181 | layer_num = len(hidden_states[1]) 182 | hs_all_layer = [] 183 | for j in range(layer_num): 184 | all_pos_hs = np.array([np.array(hidden_states[pos][j][0][0].cpu()) for pos in range(0, output_len)]) 185 | hs_all_layer.append(np.mean(all_pos_hs, axis=0)) 186 | hidden_states = hs_all_layer 187 | print(f"********** Hidden State Size: **********\n{np.array(hidden_states).shape}\n") 188 | 189 | return hidden_states 190 | 191 | 192 | def save_hidden_states(self, hidden_states, i): 193 | hs = {'hidden_states': hidden_states} 194 | filedir = os.path.join(project_root_path, f'OutputInfo/{self.language}/HiddenStates', self.model_name, self.dataset_name) 195 | if not os.path.exists(filedir): 196 | os.makedirs(filedir) 197 | with open(os.path.join(filedir, self.dataset_name + '_' + str(i) + '.pkl'), 'wb') as file: 198 | pickle.dump(hs, file) 199 | 200 | 201 | def print_CoE_score(self): 202 | hidden_states = self.sample_info["output"]["all_token_hidden_states"] 203 | output_len = self.sample_info["output"]["output_len"] 204 | layer_num = len(hidden_states[1]) 205 | 206 | hs_all_layer = [] 207 | for j in range(layer_num): 208 | all_pos_hs = np.array([np.array(hidden_states[pos][j][0][0].cpu()) for pos in range(0, output_len)]) 209 | hs_all_layer.append(np.mean(all_pos_hs, axis=0)) 210 | 211 | coescoreinfo = CoEScoreInfo(hs_all_layer) 212 | _, coe_mag, _ = coescoreinfo.compute_CoE_Mag() 213 | _, coe_ang, _ = coescoreinfo.compute_CoE_Ang() 214 | coe_r = coescoreinfo.compute_CoE_R() 215 | coe_c = coescoreinfo.compute_CoE_C() 216 | 217 | print(f"********** CoE Score Info: **********\nMag {coe_mag}; Ang {coe_ang}; R {coe_r}; C {coe_c}\n") 218 | return { 219 | "Mag": coe_mag, 220 | "Ang": coe_ang, 221 | "R": coe_r, 222 | "C": coe_c 223 | } 224 | 225 | 226 | def save_CoE_score(self, CoE_score, i): 227 | filedir = os.path.join(project_root_path, f'OutputInfo/{self.language}/CoE', self.model_name, self.dataset_name) 228 | if not os.path.exists(filedir): 229 | os.makedirs(filedir) 230 | with open(os.path.join(filedir, self.dataset_name + '_' + str(i) + '.pkl'), 'wb') as file: 231 | pickle.dump(CoE_score, file) 232 | 233 | 234 | def save_CoE_figure(self, hidden_states, i): 235 | embeddings = PCA(n_components=2, random_state=2024).fit_transform(np.array(hidden_states)) 236 | 237 | fig = plt.figure(figsize=(14, 8)) 238 | #fig.suptitle('Embedding Trajectory under Correct/Incorrect Samples', fontsize=40, fontweight='bold') 239 | ax1 = fig.add_subplot(1, 1, 1, facecolor='w') 240 | 241 | traj_x = np.array(embeddings[:, 0]) 242 | traj_y = np.array(embeddings[:, 1]) 243 | 244 | ax1.scatter(traj_x, traj_y, color='blue', alpha=1.0, edgecolor='white', s=200) 245 | ax1.plot(traj_x, traj_y, color='gray', linestyle='-', linewidth=2, alpha=0.5) 246 | ax1.text(0, 0, "Origin (0,0)", color='black', fontsize=10) 247 | 248 | ax1.set_xlabel('X-axis', fontsize=24, fontweight='bold') 249 | ax1.set_ylabel('Y-axis', fontsize=24, fontweight='bold') 250 | 251 | '''save''' 252 | filedir = os.path.join(project_root_path, f'Figure/{self.language}', self.model_name, self.dataset_name) 253 | if not os.path.exists(filedir): 254 | os.makedirs(filedir) 255 | plt.savefig(os.path.join(filedir, self.dataset_name + '_' + str(i) + '.png'), bbox_inches='tight', pad_inches=0) 256 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------