├── LICENSE
├── README.md
├── data
│   ├── README.md
│   ├── arc
│   │   └── README.md
│   └── proofwriter
│       └── README.md
├── figures
│   ├── motivation.jpg
│   └── rt_example_fig_final.jpg
├── requirements.txt
├── run_attnvisual_gpt2.sh
├── run_causal_gpt2.sh
├── run_corrupt_proofwriter.sh
├── run_finetune_gpt2.sh
├── run_probing_arc.sh
├── run_probing_gpt2.sh
├── run_probing_proofwriter.sh
├── src
│   ├── gpt2.py
│   ├── llama.py
│   ├── main.py
│   ├── proofparser.py
│   ├── rtask
│   │   ├── arc_probe.py
│   │   └── proofwriter_probe.py
│   ├── stask
│   │   ├── attn.py
│   │   ├── probe.py
│   │   └── prune.py
│   └── utils.py
└── tmp
    ├── attn
    │   ├── attn_analysis-all_min01.pdf
    │   └── attn_analysis-all_org.pdf
    ├── attn_pos
    │   ├── attn_analysis-all_min01.pdf
    │   └── attn_analysis-all_org.pdf
    ├── gpt2_cs
    │   ├── cs_min01_8_12.pdf
    │   └── cs_min01_all.pdf
    └── prune
        ├── pruning_analysis-all_min01_pos.pdf
        └── pruning_analysis-all_min01_size.pdf

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 yifan
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MechanisticProbe
2 | 
3 | [![](https://img.shields.io/badge/License-MIT-blue.svg)]()
4 | [![arxiv](https://img.shields.io/badge/arXiv-2310.14491-b31b1b)](https://arxiv.org/abs/2310.14491)
5 | [![Dataset Download](https://img.shields.io/badge/HuggingFace-Datasets-green)](https://huggingface.co/datasets/yyyyifan/MechanisticProbe_ProofWriter_ARC)
6 | [![GitHub Project](https://img.shields.io/badge/GitHub-Project-orange)](https://github.com/yifan-h/MechanisticProbe)
7 | 
8 | 
9 | ##### Source code for **[Towards a Mechanistic Interpretation of Multi-Step Reasoning Capabilities of Language Models](https://arxiv.org/abs/2310.14491)**
10 | 
11 | 
12 | ---
13 | 
14 | 
15 | ![](./figures/motivation.jpg)
16 | > In this paper, we explored how language models (LMs) perform multi-step reasoning tasks: by **memorizing answers from the massive pretraining corpus**, or by **step-by-step reasoning**.
17 | 
18 | 
19 | ---
20 | 
21 | ![](./figures/rt_example_fig_final.jpg)
22 | > To answer this research question, we propose the **MechanisticProbe** to detect the reasoning trees inside LMs via their attention patterns. We run analysis experiments on three reasoning tasks:
23 | > 1. Probe GPT-2 on the synthetic task: finding the k-th smallest number from a number list;
24 | > 2. Probe LLaMA on the synthetic task: ProofWriter;
25 | > 3. Probe LLaMA on the real-world reasoning task: AI2 Reasoning Challenge (ARC).
26 | 
27 | > The **MechanisticProbe** is quite simple: it is composed of two **kNN classifiers**:
28 | >> The first one predicts whether a statement is useful for the reasoning task;
29 | >> The second one detects the reasoning step at which a useful statement is used.
30 | 
31 | ---
32 | 
33 | ## Try MechanisticProbe
34 | 
35 | ### 1. Prepare the environment
36 | 
37 | Install the necessary libraries:
38 | 
39 |     $ pip install -r requirements.txt
40 | 
41 | Download the processed dataset folder from the [HuggingFace repo](https://huggingface.co/datasets/yyyyifan/MechanisticProbe_ProofWriter_ARC).
42 | 
43 | 
44 | ### 2. Prepare the model
45 | 
46 | Run the finetuning script to finetune GPT-2 on the synthetic reasoning task (finding the k-th smallest number):
47 | 
48 |     $ bash run_finetune_gpt2.sh
49 | 
50 | ### 3. Run the MechanisticProbe
51 | 
52 | Run our probe model to analyze the finetuned GPT-2 on the synthetic task:
53 | 
54 |     $ bash run_probing_gpt2.sh
55 | 
56 | ## Other analysis experiments
57 | 
58 | We have also prepared code for other analysis experiments. Users have to prepare the datasets as well as the LMs; the paths of the datasets and models should be specified in the scripts.
59 | 
60 | > 1. Attention visualization for the GPT-2 model on the synthetic reasoning task (*run_attnvisual_gpt2.sh*);
61 | 
62 | > 2. Causal analysis (head entropy calculation) and full attention visualization (*run_causal_gpt2.sh*);
63 | 
64 | > 3. Probing experiments for LLaMA on ProofWriter and ARC (*run_probing_proofwriter.sh* and *run_probing_arc.sh*);
65 | 
66 | > 4. Robustness experiment for LLaMA on ProofWriter (*run_corrupt_proofwriter.sh*).
67 | 
68 | ---
69 | 
70 | ### Cite
71 | 
72 | If this work/code is helpful to you, please cite our paper:
73 | 
74 | ```
75 | @article{hou2023towards,
76 |   title={Towards a Mechanistic Interpretation of Multi-Step Reasoning Capabilities of Language Models},
77 |   author={Hou, Yifan and Li, Jiaoda and Fei, Yu and Stolfo, Alessandro and Zhou, Wangchunshu and Zeng, Guangtao and Bosselut, Antoine and Sachan, Mrinmaya},
78 |   journal={arXiv preprint arXiv:2310.14491},
79 |   year={2023}
80 | }
81 | ```
82 | ---
83 | 
84 | ### Contact
85 | 
86 | Feel free to open an issue or send me (yifan.hou@inf.ethz.ch) an email if you have any questions!
87 | 

--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | ## Please download the processed datasets and put them here
2 | 
3 | The download link is [https://huggingface.co/datasets/yyyyifan/MechanisticProbe_ProofWriter_ARC](https://huggingface.co/datasets/yyyyifan/MechanisticProbe_ProofWriter_ARC).
4 | 
5 | Users can also download the original versions and process them with our code.

--------------------------------------------------------------------------------
/data/arc/README.md:
--------------------------------------------------------------------------------
1 | ## Please download the ARC dataset and put it here
2 | 
3 | The download link is [https://github.com/amazon-science/street-reasoning/tree/main/data/arc](https://github.com/amazon-science/street-reasoning/tree/main/data/arc).
4 | 
5 | The code will clean the data and put the cleaned version here.
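As a pointer back to the probe that consumes these datasets: the top-level README describes **MechanisticProbe** as two simple **kNN classifiers** over per-statement attention features. The snippet below is a minimal, self-contained sketch of that idea on toy random features and labels (the shapes and targets here are illustrative assumptions only; the actual feature extraction and probing live in *src/stask/probe.py* and *src/rtask/*):

```
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score

# toy stand-ins: one attention feature vector per statement (one entry per layer),
# a binary "useful statement?" target, and a reasoning-step target per statement
rng = np.random.default_rng(42)
num_statements, num_layers = 512, 12
attn_features = rng.random((num_statements, num_layers))
useful = rng.integers(0, 2, num_statements)  # target of classifier 1
step = rng.integers(0, 3, num_statements)    # target of classifier 2

# classifier 1: is the statement useful for the reasoning task?
knn_useful = KNeighborsClassifier(n_neighbors=8, weights="distance", p=1)
print("useful-vs-noise F1:", cross_val_score(
    knn_useful, attn_features, useful, cv=5, scoring="f1_macro").mean())

# classifier 2: at which reasoning step is a useful statement used?
# (fit on the useful statements only; near-chance scores are expected on random data)
mask = useful == 1
knn_step = KNeighborsClassifier(n_neighbors=8, weights="distance", p=1)
print("reasoning-step F1:", cross_val_score(
    knn_step, attn_features[mask], step[mask], cv=5, scoring="f1_macro").mean())
```

The `n_neighbors=8, weights="distance", p=1` settings mirror the configuration used in the repo's probing code.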
--------------------------------------------------------------------------------
/data/proofwriter/README.md:
--------------------------------------------------------------------------------
1 | ## Please download the ProofWriter dataset and put it here
2 | 
3 | The download link is [https://aristo-data-public.s3.amazonaws.com/proofwriter/proofwriter-dataset-V2020.12.3.zip](https://aristo-data-public.s3.amazonaws.com/proofwriter/proofwriter-dataset-V2020.12.3.zip).
4 | 
5 | This folder should contain two sub-folders: **CWA** and **OWA**. The code will clean the data and put the cleaned version in **CWA**.
6 | 

--------------------------------------------------------------------------------
/figures/motivation.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifan-h/MechanisticProbe/aa0bfd9dc2e9e38b675db31bad035d735083cff9/figures/motivation.jpg

--------------------------------------------------------------------------------
/figures/rt_example_fig_final.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifan-h/MechanisticProbe/aa0bfd9dc2e9e38b675db31bad035d735083cff9/figures/rt_example_fig_final.jpg

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.11.0
2 | numpy==1.21.0
3 | transformers==4.28.0
4 | tqdm==4.66.1
5 | networkx==2.6.3
6 | matplotlib==3.4.1
7 | scikit-learn==0.24.1
8 | scipy==1.10.1

--------------------------------------------------------------------------------
/run_attnvisual_gpt2.sh:
--------------------------------------------------------------------------------
1 | python -u src/main.py --analysis_task attn_ksmallest --model_path gpt2

--------------------------------------------------------------------------------
/run_causal_gpt2.sh:
--------------------------------------------------------------------------------
1 | python -u src/main.py --analysis_task causal_ksmallest --model_path gpt2

--------------------------------------------------------------------------------
/run_corrupt_proofwriter.sh:
--------------------------------------------------------------------------------
1 | python -u src/main.py --analysis_task corrupt_proofwriter --model_path llama-7b-hf --data_dir ./data/proofwriter

--------------------------------------------------------------------------------
/run_finetune_gpt2.sh:
--------------------------------------------------------------------------------
1 | python -u src/main.py --analysis_task finetune_ksmallest --model_path gpt2

--------------------------------------------------------------------------------
/run_probing_arc.sh:
--------------------------------------------------------------------------------
1 | python -u src/main.py --analysis_task probing_arc --model_path llama-7b-hf --data_dir ./data/arc

--------------------------------------------------------------------------------
/run_probing_gpt2.sh:
--------------------------------------------------------------------------------
1 | python -u src/main.py --analysis_task probing_ksmallest --model_path gpt2

--------------------------------------------------------------------------------
/run_probing_proofwriter.sh:
--------------------------------------------------------------------------------
1 | python -u src/main.py --analysis_task probing_proofwriter --model_path llama-7b-hf --data_dir
./data/proofwriter -------------------------------------------------------------------------------- /src/gpt2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | import random 5 | import numpy as np 6 | from torch.optim import AdamW 7 | from transformers import AutoTokenizer, AutoModelForCausalLM, get_constant_schedule_with_warmup, GPT2Config, GPT2LMHeadModel 8 | from tqdm import tqdm 9 | import networkx as nx 10 | 11 | from utils import data_generator, plot_loss, set_grad, avg 12 | 13 | 14 | def evaluator(args, tokenizer, model, eval_data, eval_labels, batch_num, test=False, head_mask=None): 15 | if not test: 16 | rand_idx = random.choices([i for i in range(len(eval_labels))], k=1000) 17 | eval_data = [eval_data[idx] for idx in rand_idx] 18 | eval_labels = [eval_labels[idx] for idx in rand_idx] 19 | acc_list = [] 20 | with torch.no_grad(): 21 | for s in range(0, len(eval_labels), batch_num): 22 | tmp_data = eval_data[s:s+batch_num] 23 | tmp_label = eval_labels[s:s+batch_num] 24 | inputs = tokenizer(tmp_data, return_tensors="pt").to(args.device) 25 | if head_mask is not None: 26 | head_mask = head_mask.to(args.device) 27 | outputs = model(**inputs, labels=inputs["input_ids"], head_mask=head_mask).logits.softmax(-1)[:, -1, :] 28 | else: 29 | outputs = model(**inputs, labels=inputs["input_ids"]).logits.softmax(-1)[:, -1, :] 30 | predicts = torch.argmax(outputs, dim=1) 31 | labels = tokenizer(tmp_label, return_tensors="pt")["input_ids"].to(args.device) 32 | # if predict correctly, accuracy + 1 33 | labels = torch.squeeze(labels) 34 | for idx in range(labels.shape[0]): 35 | if labels[idx] == predicts[idx]: 36 | acc_list.append(1.) 37 | else: 38 | acc_list.append(0.) 39 | return avg(acc_list) 40 | 41 | 42 | def finetune_gpt2(args): 43 | # prepare model 44 | # folder_name = args.model_path.split("/")[-1] 45 | folder_name = "training" 46 | epoch_num = args.epoch_num 47 | batch_num = 128 # 32 * 16 for acc, 128 * 2 for attn 48 | tokenizer = AutoTokenizer.from_pretrained(args.model_path) 49 | model = AutoModelForCausalLM.from_pretrained(args.model_path).to(args.device) 50 | 51 | # prepare data 52 | sdata_generator = data_generator(args, tokenizer) 53 | train_data, val_data, test_data = sdata_generator.return_data(data=True) 54 | train_label, val_label, test_label = sdata_generator.return_data(data=False) 55 | 56 | # training 57 | set_grad(model, args.tuning_param) 58 | optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=1e-3,) # 2e-5 for 16/512, 1e-6 as default (+min00) 59 | scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=500,) 60 | 61 | loss_fcn = torch.nn.CrossEntropyLoss() 62 | loss_list = [] 63 | acc_list = [0] 64 | best_path = None 65 | with tqdm(total=int(epoch_num*len(train_label)/batch_num)) as t: 66 | for e in range(epoch_num): 67 | for s in range(0, len(train_label), batch_num): 68 | # updating 69 | tmp_data = train_data[s:s+batch_num] 70 | tmp_label = train_label[s:s+batch_num] 71 | inputs = tokenizer(tmp_data, return_tensors="pt").to(args.device) 72 | outputs = model(**inputs, labels=inputs["input_ids"]).logits.softmax(-1)[:, -1, :] 73 | labels = tokenizer(tmp_label, return_tensors="pt")["input_ids"].to(args.device) 74 | labels = torch.squeeze(labels) 75 | loss = loss_fcn(outputs, labels) 76 | loss_list.append(float(loss.item())) 77 | loss.backward() 78 | scheduler.step() 79 | if s % 2 == 0: 80 | optimizer.step() 81 | optimizer.zero_grad() 82 | # evaluation 83 | if 
(e*len(train_label)+s+1) % int(int(epoch_num*len(train_label)/batch_num) / 100) == 0: 84 | acc_list.append(evaluator(args, tokenizer, model, val_data, val_label, batch_num)) 85 | # print(acc_list[-1]) 86 | if acc_list[-1] == max(acc_list): # save best model 87 | best_path = os.path.join(args.tmp_dir, folder_name, "gpt2_"+args.stask+"_"+args.tuning_param) 88 | model.save_pretrained(best_path) 89 | # print results 90 | t.set_description("Training num: {:d}".format(s+1)) 91 | t.set_postfix(loss=round(float(loss.item()), 6), acc=round(acc_list[-1], 6)) 92 | t.update(1) 93 | 94 | # test 95 | acc_list.append(evaluator(args, tokenizer, model, test_data, test_label, batch_num, test=True)) 96 | del model 97 | if best_path is not None: 98 | model = GPT2LMHeadModel.from_pretrained(best_path).to(args.device) 99 | acc_list.append(evaluator(args, tokenizer, model, test_data, test_label, batch_num, test=True)) 100 | print("Test accuracy last and best (", args.stask, ") : ", acc_list[-2], "; ", acc_list[-1]) 101 | else: 102 | print("Test accuracy last (", args.stask, ") : ", acc_list[-1]) 103 | 104 | # output results 105 | if not os.path.exists(os.path.join(args.tmp_dir, folder_name)): 106 | os.mkdir(os.path.join(args.tmp_dir, folder_name)) 107 | results = {"acc": acc_list, "loss": loss_list} 108 | with open(os.path.join(args.tmp_dir, folder_name, "gpt2_"+args.stask+"_"+args.tuning_param+"_results.json"), "w") as f: 109 | f.write(json.dumps(results)) 110 | plot_loss(loss_list, os.path.join(args.tmp_dir, folder_name, "gpt2_"+args.stask+"_"+args.tuning_param+"_loss.pdf")) 111 | plot_loss(acc_list, os.path.join(args.tmp_dir, folder_name, "gpt2_"+args.stask+"_"+args.tuning_param+"_acc.pdf")) 112 | return 113 | 114 | -------------------------------------------------------------------------------- /src/llama.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | import random 5 | import numpy as np 6 | from transformers import LlamaForCausalLM, LlamaConfig, get_constant_schedule_with_warmup, AdamW 7 | from transformers import AutoTokenizer, AutoModelForCausalLM 8 | from tqdm import tqdm 9 | import networkx as nx 10 | 11 | from utils import data_loader, plot_loss, avg, set_grad 12 | 13 | 14 | def evaluator(args, tokenizer, model, eval_data, test=False, head_mask=None, train_data=[]): 15 | # evaluation or test 16 | if not test: 17 | rand_idx = random.choices([i for i in range(len(eval_data))], k=1000) 18 | eval_data = [eval_data[idx] for idx in rand_idx] 19 | acc_list = [] 20 | # prompting 21 | with torch.no_grad(): 22 | for s in range(len(eval_data)): 23 | # zero-shot prompt or few-shot ICL examples 24 | icl_examples = "" 25 | if len(train_data): 26 | if head_mask is not None: # train_data: icl text 27 | icl_examples = train_data[s] 28 | else: # train_data: training data (need random selection) 29 | icl = random.sample(train_data, k=args.icl_num) 30 | for e in icl: 31 | icl_examples += e["context"] + " " + e["question"] + ": True or False?" + str(e["answer"]) + "\n" 32 | # prompt construction: ProofWriter label token ids: True: 5852; False: 7700 33 | tmp_data = icl_examples + eval_data[s]["context"] + " " + eval_data[s]["question"] + ": True or False?"
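# scoring below restricts the next-token distribution to the two label tokens
# (5852 = "True", 7700 = "False") and takes the argmax; the gold answer maps
# True -> index 0 and False -> index 1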
34 | if eval_data[s]["answer"]: 35 | tmp_label = 0 36 | else: 37 | tmp_label = 1 38 | # inference 39 | inputs = tokenizer(tmp_data, return_tensors="pt", add_special_tokens=False).to(args.device) 40 | outputs = model(**inputs, labels=inputs["input_ids"]).logits.softmax(-1)[:, -1, :][:,[5852, 7700]] 41 | predicts = torch.argmax(outputs) 42 | if predicts == tmp_label: 43 | acc_list.append(1.) 44 | else: 45 | acc_list.append(0.) 46 | return avg(acc_list) 47 | 48 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | from utils import fix_seed 5 | from gpt2 import finetune_gpt2 6 | 7 | from stask.attn import attn_analysis 8 | from stask.probe import probe_analysis, probe_layer 9 | from stask.prune import pruning_analysis, acc_calculation 10 | 11 | from rtask.proofwriter_probe import proofwriter_probe_analysis, proofwriter_corrupt_analysis 12 | from rtask.arc_probe import arc_probe_analysis 13 | 14 | 15 | def main_func(args): 16 | fix_seed(args.random_seed) 17 | if args.analysis_task == "finetune_ksmallest": 18 | finetune_gpt2(args) 19 | elif args.analysis_task == "attn_ksmallest": 20 | attn_analysis(args) 21 | attn_analysis(args, pos=True) 22 | elif args.analysis_task == "probing_ksmallest": 23 | probe_analysis(args) 24 | probe_layer(args) 25 | elif args.analysis_task == "causal_ksmallest": 26 | pruning_analysis(args) 27 | elif args.analysis_task == "probing_proofwriter": 28 | proofwriter_probe_analysis(args) 29 | elif args.analysis_task == "corrupt_proofwriter": 30 | proofwriter_corrupt_analysis(args) 31 | elif args.analysis_task == "probing_arc": 32 | arc_probe_analysis(args) 33 | else: 34 | raise Exception("Unknown analysis task name") 35 | 36 | if __name__ == "__main__": 37 | parser = argparse.ArgumentParser(description="MechanisticProbe") 38 | parser.add_argument("--analysis_task", type=str, default="finetune_ksmallest", 39 | help="Analysis task: [finetune_ksmallest, attn_ksmallest, probing_ksmallest, causal_ksmallest,\ probing_proofwriter, corrupt_proofwriter, probing_arc]") 40 | # basics 41 | 42 | parser.add_argument("--tmp_dir", type=str, default="./tmp/", 43 | help="the cache directory") 44 | parser.add_argument("--task_format", type=str, default="stask", 45 | help="the task for evaluation: [stask, rtask]") 46 | parser.add_argument("--random_seed", type=int, default=42, 47 | help="random seed for reproducibility.") 48 | parser.add_argument("--device", type=int, default=0, 49 | help="which GPU to use: set -1 to use CPU/prompt") 50 | parser.add_argument("--model_path", type=str, default="gpt2", 51 | help="the base model path, can be ./ by default ([llama-7b-hf, gpt2])") 52 | parser.add_argument("--epoch_num", type=int, default=2, # 2 for GPT-2 53 | help="the number of training epochs for fine-tuning LLM") 54 | 55 | # synthetic experiment (GPT-2) 56 | parser.add_argument("--tuning_param", type=str, default="all", 57 | help="the parameter group to finetune: [all, attn, mlp, ln].") 58 | parser.add_argument("--sdata_num", type=int, default=int(1e6), 59 | help="the number of synthetic data samples.") 60 | parser.add_argument("--stask_len", type=int, default=16, 61 | help="the length of the input list.") 62 | parser.add_argument("--stask_set", type=int, default=512, 63 | help="the sampling set size of the input list.") 64 | parser.add_argument("--stask", type=str, default="min01", 65 | help="the task of synthetic data: [min00, min01, ...].") 66 | 67 | # real
experiment (LLAMA) 68 | parser.add_argument("--dataset", type=str, default="proofwriter", 69 | help="the corresponding dataset for evaluation") 70 | parser.add_argument("--data_dir", type=str, default="./data/proofwriter", 71 | help="the data directory, can be ./ by default") 72 | parser.add_argument("--rtask", type=str, default="5", 73 | help="the depth of reasoning tree: [0, 1, 2, 3, 4, 5]") 74 | parser.add_argument("--icl_num", type=int, default=4, 75 | help="the number of examples for in-context learning prompt.") 76 | 77 | args = parser.parse_args() 78 | print(args) 79 | 80 | fix_seed(args.random_seed) 81 | 82 | if not os.path.exists(args.tmp_dir): 83 | os.mkdir(args.tmp_dir) 84 | 85 | main_func(args) 86 | -------------------------------------------------------------------------------- /src/proofparser.py: -------------------------------------------------------------------------------- 1 | # code from https://github.com/swarnaHub/PRover/blob/master/proof_utils.py 2 | 3 | class Node: 4 | def __init__(self, head): 5 | self.head = head 6 | def __str__(self): 7 | return str(self.head) 8 | 9 | def get_proof_graph(proof_str): 10 | stack = [] 11 | last_open = 0 12 | last_open_index = 0 13 | pop_list = [] 14 | all_edges = [] 15 | all_nodes = [] 16 | proof_str = proof_str.replace("(", " ( ") 17 | proof_str = proof_str.replace(")", " ) ") 18 | proof_str = proof_str.split() 19 | should_join = False 20 | for i in range(len(proof_str)): 21 | _s = proof_str[i] 22 | x = _s.strip() 23 | if len(x) == 0: 24 | continue 25 | if x == "(": 26 | stack.append((x, i)) 27 | last_open = len(stack) - 1 28 | last_open_index = i 29 | elif x == ")": 30 | for j in range(last_open + 1, len(stack)): 31 | if isinstance(stack[j][0], Node): 32 | pop_list.append((stack[j][1], stack[j][0])) 33 | stack = stack[:last_open] 34 | for j in range((len(stack))): 35 | if stack[j][0] == "(": 36 | last_open = j 37 | last_open_index = stack[j][1] 38 | elif x == '[' or x == ']': 39 | pass 40 | elif x == "->": 41 | should_join = True 42 | else: 43 | # terminal 44 | if x not in all_nodes: 45 | all_nodes.append(x) 46 | if should_join: 47 | new_pop_list = [] 48 | # Choose which ones to add the node to 49 | for (index, p) in pop_list: 50 | if index < last_open_index: 51 | new_pop_list.append((index, p)) 52 | else: 53 | all_edges.append((p.head, x)) 54 | pop_list = new_pop_list 55 | stack.append((Node(x), i)) 56 | should_join = False 57 | return all_nodes, all_edges 58 | 59 | def get_proof_graph_with_fail(proof_str): 60 | proof_str = proof_str[:-2].split("=")[1].strip()[1:-1] 61 | nodes = proof_str.split(" <- ") 62 | all_nodes = [] 63 | all_edges = [] 64 | for i in range(len(nodes)-1): 65 | all_nodes.append(nodes[i]) 66 | if nodes[i+1] != "FAIL": 67 | all_edges.append((nodes[i+1], nodes[i])) 68 | return all_nodes, all_edges -------------------------------------------------------------------------------- /src/rtask/arc_probe.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import json 5 | import random 6 | import numpy as np 7 | from torch.optim import Adam 8 | from sklearn.metrics import f1_score, mean_squared_error 9 | from sklearn.neighbors import KNeighborsClassifier 10 | from sklearn.model_selection import cross_val_score 11 | from scipy.stats import pearsonr 12 | from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig 13 | from transformers import AutoTokenizer, AutoModelForCausalLM 14 | from tqdm import tqdm 15 | 16 | import 
sys 17 | sys.path.append("..") 18 | from utils import data_loader, avg, plot_sim, fix_seed, label_parser 19 | from llama import evaluator 20 | 21 | 22 | def knn_classifier(args, task_type, attn_list, test_data, pred_type, depth, layerwise=-1): 23 | acc_list = [] 24 | label_list = [] 25 | new_attn_list = [] 26 | label_len = 0 27 | if depth > 1: label_len = 1 28 | for i in range(len(test_data)): 29 | label, sample_idx = label_parser(test_data[i], label_len, pred_type) 30 | new_attn = attn_list[i] 31 | if layerwise < 0: 32 | label = label[sample_idx] 33 | new_attn = new_attn[sample_idx] 34 | label_list.append(label) 35 | new_attn_list.append(new_attn) 36 | # print("The number of test data: ", len(label_list), len(test_data)) 37 | attn_list = new_attn_list 38 | attn_all = torch.cat(attn_list, dim=0).numpy() 39 | label_all = torch.squeeze(torch.cat(label_list, dim=0)) 40 | if pred_type != "noise": 41 | if len(label_all.shape) == 1: label_all = label_all.unsqueeze(-1) 42 | label_all_new = torch.argmax(label_all, dim=1) 43 | if layerwise >= 0: 44 | tmp_idx_list = [] 45 | for row in range(label_all.shape[0]): 46 | if torch.sum(label_all[row]) == 0: 47 | label_all_new[row] = 0. 48 | tmp_idx_list.append(row) 49 | else: 50 | if label_all_new[row] == layerwise: 51 | label_all_new[row] = 1. 52 | tmp_idx_list.append(row) 53 | label_all_new = label_all_new[tmp_idx_list] 54 | attn_all = attn_all[tmp_idx_list] 55 | label_all = label_all_new 56 | label_all = label_all.numpy() 57 | 58 | score_list = [] 59 | neigh = KNeighborsClassifier(n_neighbors=8, weights="distance", p=1) 60 | score_list.append(avg(cross_val_score(neigh, attn_all, label_all, cv=5, scoring='f1_macro'))) 61 | return [max(score_list)] 62 | 63 | 64 | def icl_analysis(args, model, tokenizer, test_dict, task_type, train_data=[], prune_list=[]): 65 | score_noise, score_depth = {}, {} 66 | for depth, test_data in test_dict.items(): 67 | attn_list = [] 68 | # get attention list 69 | with torch.no_grad(): 70 | for i in range(len(test_data)): 71 | data = test_data[i] 72 | # construct icl examples 73 | icl_examples = "" 74 | if len(train_data): 75 | icl = random.choices(train_data, k=args.icl_num) 76 | for e in icl: 77 | icl_examples += e["context"] + " " + e["question"] + str(e["answer"]) + "\n" 78 | # construct context 79 | context_text = "" 80 | index_list = [] 81 | for k, v in data["triples"].items(): 82 | context_text += v 83 | index_list.append(len(tokenizer.encode(v, add_special_tokens=False))) # end token index 84 | for k, v in data["rules"].items(): 85 | context_text += v 86 | index_list.append(len(tokenizer.encode(v, add_special_tokens=False))) # end token index 87 | context_text = icl_examples + context_text + " " + data["question"] 88 | sent_num = len(index_list) # get number of sentences 89 | index_list = [len(tokenizer.encode(icl_examples, add_special_tokens=False))] + index_list 90 | index_list = [sum(index_list[:j+1]) for j in range(len(index_list))] # get index 91 | index_list[0] += 1 92 | # get attention matrix 93 | inputs = tokenizer(context_text, return_tensors="pt", add_special_tokens=False).to(args.device) 94 | outputs = model(**inputs, output_attentions=True).attentions # [batch_num x attn head x attn matrix] 95 | tmp_attn_list = [] 96 | for j in range(len(outputs)): 97 | attn = torch.squeeze(outputs[j])[:, max(index_list):, :max(index_list)].cpu() # shape: [attn head x question_num x context_num] 98 | attn = torch.max(attn, dim=1)[0] # shape: [attn head x context_num] 99 | tmp_attn_list.append(attn) 100 | tmp_attn_list = 
[torch.mean(a, dim=0, keepdim=True) for a in tmp_attn_list] 101 | attn = torch.cat(tmp_attn_list, dim=0) 102 | attn = torch.transpose(attn, 0, 1) # shape: [input_len x layer_num] 103 | # get attn flow across sentences 104 | new_attn = torch.zeros(sent_num, attn.shape[1], device=attn.device) 105 | for j in range(sent_num): 106 | new_attn[j, :] = torch.mean(attn[index_list[j]:index_list[j+1], :], 0) 107 | if len(prune_list): 108 | layer_list = [i for i in range(new_attn.shape[1]) if i not in prune_list] 109 | new_attn = new_attn[:, layer_list] 110 | attn_list.append(new_attn) 111 | del new_attn 112 | # train & test probe model 113 | f1_noise = avg(knn_classifier(args, task_type, attn_list, test_data, "noise", int(depth))) 114 | f1_depth = avg(knn_classifier(args, task_type, attn_list, test_data, "depth", int(depth))) 115 | print(task_type, "depth: ", depth, " KNN Classifier F1-Macro (noise, depth): ", round(f1_noise, 6) , round(f1_depth, 6)) 116 | return score_noise, score_depth 117 | 118 | 119 | def ft_analysis(args, model, tokenizer, test_dict, task_type, prune_list=[]): 120 | score_noise, score_depth = {}, {} 121 | for depth, test_data in test_dict.items(): 122 | attn_list = [] 123 | # get attention list 124 | with torch.no_grad(): 125 | for i in range(len(test_data)): 126 | data = test_data[i] 127 | # construct context 128 | context_text = "" 129 | index_list = [] 130 | for k, v in data["triples"].items(): 131 | context_text += v 132 | index_list.append(len(tokenizer.encode(v, add_special_tokens=False))) # end token index 133 | for k, v in data["rules"].items(): 134 | context_text += v 135 | index_list.append(len(tokenizer.encode(v, add_special_tokens=False))) # end token index 136 | context_text = context_text + " " + data["question"] 137 | sent_num = len(index_list) # get number of sentences 138 | index_list = [0] + index_list 139 | index_list = [sum(index_list[:j+1]) for j in range(len(index_list))] # get index 140 | index_list[0] += 1 141 | # get attention matrix 142 | inputs = tokenizer(context_text, return_tensors="pt", add_special_tokens=False).to(args.device) 143 | outputs = model(**inputs, output_attentions=True).attentions # [batch_num x attn head x attn matrix] 144 | tmp_attn_list = [] 145 | for j in range(len(outputs)): 146 | # flows to question community 147 | attn = torch.squeeze(outputs[j])[:, max(index_list):, :max(index_list)].cpu() # shape: [attn head x question_num x context_num] 148 | attn = torch.mean(attn, dim=1) # shape: [attn head x context_num] 149 | tmp_attn_list.append(attn) 150 | tmp_attn_list = [torch.mean(a, dim=0, keepdim=True) for a in tmp_attn_list] 151 | attn = torch.cat(tmp_attn_list, dim=0) 152 | attn = torch.transpose(attn, 0, 1) # shape: [token_num x layer_num] 153 | # get attn flow across sentences 154 | new_attn = torch.zeros(sent_num, attn.shape[1], device=attn.device) 155 | for j in range(sent_num): 156 | new_attn[j, :] = torch.mean(attn[index_list[j]:index_list[j+1], :], 0) 157 | if len(prune_list): 158 | layer_list = [i for i in range(new_attn.shape[1]) if i not in prune_list] 159 | new_attn = new_attn[:, layer_list] 160 | attn_list.append(new_attn) 161 | del new_attn 162 | # train & test probe model 163 | f1_noise = avg(knn_classifier(args, task_type, attn_list, test_data, "noise", int(depth))) 164 | f1_depth = avg(knn_classifier(args, task_type, attn_list, test_data, "depth", int(depth))) 165 | print(task_type, "depth: ", depth, " KNN Classifier F1-Macro (noise, depth): ", round(f1_noise, 6) , round(f1_depth, 6)) 166 | return score_noise, 
score_depth 167 | 168 | 169 | def arc_probe_analysis(args): 170 | fix_seed(args.random_seed) 171 | if not os.path.exists(os.path.join(args.tmp_dir, "arc_probe")): 172 | os.mkdir(os.path.join(args.tmp_dir, "arc_probe")) 173 | # prepare data 174 | rdata_loader = data_loader(args) 175 | train_dict, dev_dict, test_dict = rdata_loader.return_data() 176 | train_data = [] 177 | for k, v in train_dict.items(): 178 | train_data += v 179 | # remove multi-choice samples 180 | test_dict = {depth: train_dict[depth]+dev_dict[depth]+test_dict[depth] for depth, _ in test_dict.items() if int(depth) <= 2} 181 | print("Sampled test data number: ", {k:len(v) for k, v in test_dict.items()}) 182 | tokenizer = LlamaTokenizer.from_pretrained(args.model_path) 183 | 184 | # analysis random baselines 185 | config=LlamaConfig() 186 | model = LlamaForCausalLM(config).to(args.device) 187 | score_noise_scratch, score_depth_scratch = ft_analysis(args, model, tokenizer, test_dict, "scratch") 188 | del model 189 | 190 | # analysis ICL setting 191 | model = LlamaForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.float16).to(args.device) 192 | prune_list = [31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17] 193 | score_noise_icl, score_depth_icl = icl_analysis(args, model, tokenizer, test_dict, "in-context learning", train_data, prune_list=prune_list) 194 | return 195 | -------------------------------------------------------------------------------- /src/rtask/proofwriter_probe.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import json 5 | import random 6 | import numpy as np 7 | from torch.optim import Adam 8 | from sklearn.metrics import f1_score, mean_squared_error 9 | from sklearn.neighbors import KNeighborsClassifier 10 | from sklearn.model_selection import cross_val_score 11 | from scipy.stats import pearsonr 12 | from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig 13 | from transformers import AutoTokenizer, AutoModelForCausalLM 14 | from tqdm import tqdm 15 | 16 | import sys 17 | sys.path.append("..") 18 | from utils import data_loader, avg, plot_sim, fix_seed, label_parser 19 | from llama import evaluator 20 | 21 | 22 | def knn_classifier(args, task_type, attn_list, test_data, pred_type, depth, layerwise=-1): 23 | acc_list = [] 24 | label_list = [] 25 | new_attn_list = [] 26 | label_len = 0 27 | if depth > 1: label_len = 1 28 | for i in range(len(test_data)): 29 | label, sample_idx = label_parser(test_data[i], label_len, pred_type) 30 | new_attn = attn_list[i] 31 | if layerwise < 0: 32 | label = label[sample_idx] 33 | new_attn = new_attn[sample_idx] 34 | label_list.append(label) 35 | new_attn_list.append(new_attn) 36 | # print("The number of test data: ", len(label_list), len(test_data)) 37 | attn_list = new_attn_list 38 | attn_all = torch.cat(attn_list, dim=0).numpy() 39 | label_all = torch.squeeze(torch.cat(label_list, dim=0)) 40 | if pred_type != "noise": 41 | if len(label_all.shape) == 1: label_all = label_all.unsqueeze(-1) 42 | label_all_new = torch.argmax(label_all, dim=1) 43 | if layerwise >= 0: 44 | tmp_idx_list = [] 45 | for row in range(label_all.shape[0]): 46 | if torch.sum(label_all[row]) == 0: 47 | label_all_new[row] = 0. 48 | tmp_idx_list.append(row) 49 | else: 50 | if label_all_new[row] == layerwise: 51 | label_all_new[row] = 1. 
52 | tmp_idx_list.append(row) 53 | label_all_new = label_all_new[tmp_idx_list] 54 | attn_all = attn_all[tmp_idx_list] 55 | label_all = label_all_new 56 | label_all = label_all.numpy() 57 | score_list = [] 58 | neigh = KNeighborsClassifier(n_neighbors=8, weights="distance", p=1) 59 | score_list.append(avg(cross_val_score(neigh, attn_all, label_all, cv=5, scoring='f1_macro'))) 60 | return [max(score_list)] 61 | 62 | 63 | def icl_analysis(args, model, tokenizer, test_dict, task_type, train_data=[], prune_list=[], corr=False): 64 | # get accuracy 65 | acc_dict = {} 66 | for depth, data in test_dict.items(): 67 | acc_dict[depth] = evaluator(args, tokenizer, model, data, test=True, train_data=train_data) 68 | acc = acc_dict[depth] 69 | if not corr: print(" ICL testing accuracy: ", acc_dict) 70 | attn_dict = {} 71 | score_noise, score_depth = {}, {} 72 | for depth, test_data in test_dict.items(): 73 | attn_list = [] 74 | # get attention list 75 | with torch.no_grad(): 76 | for i in range(len(test_data)): 77 | data = test_data[i] 78 | # construct icl examples 79 | icl_examples = "" 80 | if len(train_data): 81 | icl = random.choices(train_data, k=args.icl_num) 82 | for e in icl: 83 | icl_examples += e["context"] + " " + e["question"] + ": True or False?" + str(e["answer"]) + "\n" 84 | # construct context 85 | context_text = "" 86 | index_list = [] 87 | for k, v in data["triples"].items(): 88 | context_text += v 89 | index_list.append(len(tokenizer.encode(v, add_special_tokens=False))) # end token index 90 | for k, v in data["rules"].items(): 91 | context_text += v 92 | index_list.append(len(tokenizer.encode(v, add_special_tokens=False))) # end token index 93 | context_text = icl_examples + context_text + " " + data["question"] + " True or False?" 94 | sent_num = len(index_list) # get number of sentences 95 | index_list = [len(tokenizer.encode(icl_examples, add_special_tokens=False))] + index_list 96 | index_list = [sum(index_list[:j+1]) for j in range(len(index_list))] # get index 97 | index_list[0] += 1 98 | # get attention matrix 99 | inputs = tokenizer(context_text, return_tensors="pt", add_special_tokens=False).to(args.device) 100 | outputs = model(**inputs, output_attentions=True).attentions # [batch_num x attn head x attn matrix] 101 | tmp_attn_list = [] 102 | for j in range(len(outputs)): 103 | attn = torch.squeeze(outputs[j])[:, max(index_list):, :max(index_list)].cpu() # shape: [attn head x question_num x context_num] 104 | attn = torch.max(attn, dim=1)[0] # shape: [attn head x context_num] 105 | tmp_attn_list.append(attn) 106 | tmp_attn_list = [torch.mean(a, dim=0, keepdim=True) for a in tmp_attn_list] 107 | attn = torch.cat(tmp_attn_list, dim=0) 108 | attn = torch.transpose(attn, 0, 1) # shape: [input_len x layer_num] 109 | # get attn flow across sentences 110 | new_attn = torch.zeros(sent_num, attn.shape[1], device=attn.device) 111 | for j in range(sent_num): 112 | new_attn[j, :] = torch.mean(attn[index_list[j]:index_list[j+1], :], 0) 113 | if len(prune_list): 114 | layer_list = [i for i in range(new_attn.shape[1]) if i not in prune_list] 115 | new_attn = new_attn[:, layer_list] 116 | attn_list.append(new_attn) 117 | del new_attn 118 | # train & test probe model 119 | f1_noise = avg(knn_classifier(args, task_type, attn_list, test_data, "noise", int(depth))) 120 | f1_depth = avg(knn_classifier(args, task_type, attn_list, test_data, "depth", int(depth))) 121 | if corr: 122 | return acc, f1_noise, f1_depth 123 | print(task_type, "m = ", depth, " KNN Classifier F1-Macro (noise, depth): ", 
round(f1_noise, 6) , round(f1_depth, 6)) 124 | return score_noise, score_depth 125 | 126 | 127 | def ft_analysis(args, model, tokenizer, test_dict, task_type, prune_list=[], corr=False): 128 | attn_dict = {} 129 | score_noise, score_depth = {}, {} 130 | for depth, test_data in test_dict.items(): 131 | attn_list = [] 132 | # get attention list 133 | with torch.no_grad(): 134 | for i in range(len(test_data)): 135 | data = test_data[i] 136 | # construct context 137 | context_text = "" 138 | index_list = [] 139 | for k, v in data["triples"].items(): 140 | context_text += v 141 | index_list.append(len(tokenizer.encode(v, add_special_tokens=False))) # end token index 142 | for k, v in data["rules"].items(): 143 | context_text += v 144 | index_list.append(len(tokenizer.encode(v, add_special_tokens=False))) # end token index 145 | context_text = context_text + " " + data["question"] + " True or False?" 146 | sent_num = len(index_list) # get number of sentences 147 | index_list = [0] + index_list 148 | index_list = [sum(index_list[:j+1]) for j in range(len(index_list))] # get index 149 | index_list[0] += 1 150 | # get attention matrix 151 | inputs = tokenizer(context_text, return_tensors="pt", add_special_tokens=False).to(args.device) 152 | outputs = model(**inputs, output_attentions=True).attentions # [batch_num x attn head x attn matrix] 153 | tmp_attn_list = [] 154 | for j in range(len(outputs)): 155 | # flows to question community 156 | attn = torch.squeeze(outputs[j])[:, max(index_list):, :max(index_list)].cpu() # shape: [attn head x question_num x context_num] 157 | attn = torch.mean(attn, dim=1) # shape: [attn head x context_num] 158 | tmp_attn_list.append(attn) 159 | tmp_attn_list = [torch.mean(a, dim=0, keepdim=True) for a in tmp_attn_list] 160 | attn = torch.cat(tmp_attn_list, dim=0) 161 | attn = torch.transpose(attn, 0, 1) # shape: [token_num x layer_num] 162 | # get attn flow across sentences 163 | new_attn = torch.zeros(sent_num, attn.shape[1], device=attn.device) 164 | for j in range(sent_num): 165 | new_attn[j, :] = torch.mean(attn[index_list[j]:index_list[j+1], :], 0) 166 | # new_attn[j, :] = torch.max(attn[index_list[j]:index_list[j+1], :], 0)[0] 167 | if len(prune_list): 168 | layer_list = [i for i in range(new_attn.shape[1]) if i not in prune_list] 169 | new_attn = new_attn[:, layer_list] 170 | attn_list.append(new_attn) 171 | del new_attn 172 | # train & test probe model 173 | f1_noise = avg(knn_classifier(args, task_type, attn_list, test_data, "noise", int(depth))) 174 | f1_depth = avg(knn_classifier(args, task_type, attn_list, test_data, "depth", int(depth))) 175 | if corr: 176 | return f1_noise, f1_depth 177 | print(task_type, "m = ", depth, " KNN Classifier F1-Macro (noise, depth): ", round(f1_noise, 6) , round(f1_depth, 6)) 178 | return score_noise, score_depth 179 | 180 | 181 | def proofwriter_probe_analysis(args): 182 | fix_seed(args.random_seed) 183 | if not os.path.exists(os.path.join(args.tmp_dir, "proofwriter_probe")): 184 | os.mkdir(os.path.join(args.tmp_dir, "proofwriter_probe")) 185 | # prepare data 186 | rdata_loader = data_loader(args) 187 | train_dict, dev_dict, test_dict = rdata_loader.return_data() 188 | train_data = [] 189 | for k, v in train_dict.items(): 190 | train_data += v 191 | # remove multi-choice samples 192 | test_dict = {depth: [d for d in test_dict[depth]+dev_dict[depth]+train_dict[depth] if "OR" not in d["proof"]] for depth, _ in test_dict.items() if int(depth) <= 1} 193 | test_dict_sample = {1: [], 2:[], 4:[], 8:[], 12:[], 16:[], 20:[], 24:[]} 
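# the sampling loop below buckets the depth<=1 examples by m, the number of context
# sentences (triples + rules): depth-0 examples go into bucket 1 and depth-1 examples
# into the bucket matching their context size, with at most data_num examples per bucket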
194 | data_num = 1024 195 | for depth, data in test_dict.items(): 196 | random.shuffle(data) 197 | for d in data: 198 | label, sample_idx = label_parser(d, depth, "depth") 199 | label = label[sample_idx] 200 | if torch.sum(label) >= label.shape[1]: 201 | if int(depth) == 1: 202 | tmp_count = 0 203 | tmp_count += len(d["triples"]) 204 | tmp_count += len(d["rules"]) 205 | if tmp_count in test_dict_sample and len(test_dict_sample[tmp_count]) < data_num: 206 | test_dict_sample[tmp_count].append(d) 207 | elif int(depth) == 0 and len(test_dict_sample[1]) < data_num: 208 | test_dict_sample[1].append(d) 209 | test_dict = test_dict_sample 210 | print("Sampled data for probing (all depth=1, key=m): ", {k:len(v) for k, v in test_dict_sample.items()}) 211 | tokenizer = LlamaTokenizer.from_pretrained(args.model_path) 212 | # analysis random baselines 213 | config=LlamaConfig() 214 | model = LlamaForCausalLM(config).to(args.device) 215 | score_noise_scratch, score_depth_scratch = ft_analysis(args, model, tokenizer, test_dict, "scratch") 216 | del model 217 | # analysis ICL setting 218 | model = LlamaForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.float16).to(args.device) 219 | prune_list = [31, 30, 29, 27, 26, 25, 24, 23, 22, 21, 20, 18, 17] 220 | score_noise_icl, score_depth_icl = icl_analysis(args, model, tokenizer, test_dict, "in-context learning", train_data, prune_list=prune_list) 221 | return 222 | 223 | 224 | def corrput_func(args, model, tokenizer, test_dict_sample, train_data, ctype=""): 225 | test_data = test_dict_sample[4] 226 | new_test_data = [] 227 | for data in test_data: 228 | new_data = {k: v for k, v in data.items()} 229 | tmp_context = data["context"].split(". ") 230 | count = 10 231 | tmp_idx = 0 232 | while count < 10 or tmp_context[tmp_idx] in data['proof']: 233 | tmp_idx = random.choice([i for i in range(len(tmp_context))]) 234 | tmp_context[tmp_idx] = "That " + tmp_context[tmp_idx] + " is false" 235 | tmp_context = ". 
".join(tmp_context) 236 | new_data["context"] = tmp_context 237 | new_test_data.append(new_data) 238 | # get accuracy 239 | test_dict = {4: new_test_data} 240 | acc_dict = {} 241 | for depth, data in test_dict.items(): 242 | acc_dict[depth] = evaluator(args, tokenizer, model, data, test=True, train_data=train_data) 243 | acc = acc_dict[depth] 244 | return acc 245 | 246 | 247 | def proofwriter_corrupt_analysis(args): 248 | fix_seed(args.random_seed) 249 | if not os.path.exists(os.path.join(args.tmp_dir, "proofwriter_probe")): 250 | os.mkdir(os.path.join(args.tmp_dir, "proofwriter_probe")) 251 | 252 | # prepare data 253 | rdata_loader = data_loader(args) 254 | train_dict, dev_dict, test_dict = rdata_loader.return_data() 255 | train_data = [] 256 | for k, v in train_dict.items(): 257 | train_data += v 258 | # remove multi-choice samples 259 | test_dict = {depth: [d for d in test_dict[depth]+dev_dict[depth]+train_dict[depth] if "OR" not in d["proof"]] for depth, _ in test_dict.items() if int(depth) == 1} 260 | test_dict_sample = {4: random.sample(v, k=128) for depth, v in test_dict.items()} 261 | tokenizer = LlamaTokenizer.from_pretrained(args.model_path) 262 | 263 | # get scratch 264 | config=LlamaConfig() 265 | model = LlamaForCausalLM(config).to(args.device) 266 | f1_noise_scratch, f1_depth_scratch = ft_analysis(args, model, tokenizer, test_dict_sample, "scratch", corr=True) 267 | del model 268 | 269 | # analysis ICL setting 270 | prune_list = [31, 30, 29, 27, 26, 25, 24, 23, 22, 21, 20, 18, 17] 271 | model = LlamaForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.float16).to(args.device) 272 | acc_list, sp1_list, sp2_list = [], [], [] 273 | corrput_shuffle, corrput_noise, corrput_counter = [], [], [] 274 | for i in tqdm(range(1024)): 275 | test_dict_sample = {4: random.sample(v, k=random.randint(64, 128)) for depth, v in test_dict.items()} 276 | acc, f1_noise, f1_depth = icl_analysis(args, model, tokenizer, test_dict_sample, "in-context learning", train_data, prune_list=prune_list, corr=True) 277 | acc_list.append(acc) 278 | sp1_list.append((f1_noise-f1_noise_scratch)/(1-f1_noise_scratch)) 279 | sp2_list.append((f1_depth-f1_depth_scratch)/(1-f1_depth_scratch)) 280 | # corruption 281 | acc_shuffle = 0. 282 | acc_noise = 0. 
283 | acc_counter = corrput_func(args, model, tokenizer, test_dict_sample, train_data, "counter") 284 | corrput_shuffle.append(acc_shuffle-acc) 285 | corrput_noise.append(acc_noise-acc) 286 | corrput_counter.append(acc_counter-acc) 287 | del model 288 | results = {"acc": acc_list, "sp1": sp1_list, "sp2": sp2_list, "corrupt_shuffle": corrput_shuffle, "corrput_noise": corrput_noise, "corrput_counter": corrput_counter} 289 | save_path = os.path.join(args.tmp_dir, "proofwriter_probe", "corrput_results.json") 290 | with open(save_path, "w") as f: 291 | f.write(json.dumps(results)) 292 | # print results 293 | bin_num = 8 294 | bin_len = (max(sp2_list) - min(sp2_list)) / bin_num 295 | bins_shuf, bins_noise, bins_count, bins_acc = [], [], [], [] 296 | for b in range(bin_num): 297 | bin_lower = min(sp2_list) + bin_len*b 298 | bin_upper = min(sp2_list) + bin_len*(b+1) 299 | tmp_shuf, tmp_noise, tmp_count, tmp_acc = [], [], [], [] 300 | for j in range(len(sp2_list)): 301 | if sp2_list[j] >= bin_lower and sp2_list[j] < bin_upper: 302 | tmp_shuf.append(corrput_shuffle[j]) 303 | tmp_noise.append(corrput_noise[j]) 304 | tmp_count.append(corrput_counter[j]) 305 | tmp_acc.append(acc_list[j]) 306 | bins_shuf.append(avg(tmp_shuf)) 307 | bins_noise.append(avg(tmp_noise)) 308 | bins_count.append(avg(tmp_count)) 309 | bins_acc.append(avg(tmp_acc)) 310 | print("final results: ", bins_shuf, bins_noise, bins_count, bins_acc, min(sp2_list), max(sp2_list)) 311 | return 312 | -------------------------------------------------------------------------------- /src/stask/attn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | import random 5 | import numpy as np 6 | from scipy.special import softmax 7 | from torch.optim import AdamW 8 | from transformers import AutoTokenizer, AutoModelForCausalLM, get_constant_schedule_with_warmup 9 | from tqdm import tqdm 10 | import networkx as nx 11 | 12 | import sys 13 | sys.path.append("..") 14 | from utils import data_generator, plot_sim 15 | 16 | 17 | def attn_calculation(args, model, tokenizer, test_data, save_path, title, pos): 18 | # get attention list 19 | with torch.no_grad(): 20 | dist_mat = None 21 | for i in tqdm(range(len(test_data))): 22 | data = test_data[i] 23 | data_num = data.split(" ") 24 | data_sorted = sorted([int(d) for d in data_num]) 25 | if pos: 26 | label_idx = [i for i in range(len(data_sorted[:16]))] 27 | else: 28 | label_idx = [data_num.index(str(d)) for d in data_sorted[:16]] 29 | # get attention matrix 30 | inputs = tokenizer(data, return_tensors="pt").to(args.device) 31 | outputs = model(**inputs, output_attentions=True).attentions # [batch_num x attn head x attn matrix] 32 | tmp_dist_mat = [] 33 | for attn in outputs: 34 | attn = torch.squeeze(torch.mean(attn, dim=1))[-1,:].cpu().numpy() # shape: [16] 35 | tmp_dist = [attn[l_idx] for l_idx in label_idx] 36 | tmp_dist_mat.append(tmp_dist) 37 | if dist_mat is None: 38 | dist_mat = np.array(tmp_dist_mat) 39 | else: 40 | dist_mat += np.array(tmp_dist_mat) 41 | dist_mat = np.transpose(dist_mat) / len(test_data) 42 | if pos: 43 | ylabel = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", \ 44 | "9th", "10th", "11th", "12th", "13th", "14th", "15th", "16th"] 45 | else: 46 | ylabel = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", \ 47 | "9th", "10th", "11th", "12th", "13th", "14th", "15th", "16th"] 48 | plot_sim(dist_mat, save_path, title, xlabel=[str(i+1) for i in range(len(outputs))], ylabel1=ylabel, pos=pos) 
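# the saved heatmap has one row per rank (1st..16th smallest number, or the first 16
# positions when pos=True) and one column per layer: each cell is the last token's
# attention, averaged over heads and over all test samples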
49 | return 50 | 51 | 52 | def attn_analysis(args, pos=False): 53 | if pos: 54 | folder_name = "attn_pos" 55 | else: 56 | folder_name = "attn" 57 | if not os.path.exists(os.path.join(args.tmp_dir, folder_name)): 58 | os.mkdir(os.path.join(args.tmp_dir, folder_name)) 59 | # get all folders (GPT-2) 60 | for root, dirs, files in os.walk(os.path.join(args.tmp_dir, "training")): 61 | folders = dirs 62 | break 63 | folders = [args.model_path] + [os.path.join(args.tmp_dir, "training", f_path) for f_path in folders if f_path[:8]=="gpt2_min"] 64 | model_label = ["org"] + [f_name[20:20+5] for f_name in folders[1:]] 65 | # set tokenizer 66 | tokenizer = AutoTokenizer.from_pretrained(args.model_path) 67 | # prepare data 68 | sdata_generator = data_generator(args, tokenizer) 69 | _, _, test_data = sdata_generator.return_data(data=True) 70 | print(folders, model_label) 71 | for i in range(len(folders)): 72 | if "attn" in folders[i]: 73 | save_path = os.path.join(args.tmp_dir, folder_name, "attn_analysis-attn_"+model_label[i]+".pdf") 74 | elif "mlp" in folders[i]: 75 | save_path = os.path.join(args.tmp_dir, folder_name, "attn_analysis-mlp_"+model_label[i]+".pdf") 76 | elif "ln" in folders[i]: 77 | save_path = os.path.join(args.tmp_dir, folder_name, "attn_analysis-ln_"+model_label[i]+".pdf") 78 | else: 79 | save_path = os.path.join(args.tmp_dir, folder_name, "attn_analysis-all_"+model_label[i]+".pdf") 80 | title = model_label[i] 81 | model = AutoModelForCausalLM.from_pretrained(folders[i]).to(args.device) 82 | attn_calculation(args, model, tokenizer, test_data, save_path, title, pos) 83 | return 84 | -------------------------------------------------------------------------------- /src/stask/probe.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | import random 5 | import math 6 | import numpy as np 7 | from sklearn.metrics import f1_score 8 | from sklearn.neighbors import KNeighborsClassifier 9 | from sklearn.metrics import pairwise_distances 10 | from scipy.stats import pearsonr 11 | from torch.optim import Adam 12 | import torch.nn as nn 13 | from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2Config, GPT2LMHeadModel 14 | from tqdm import tqdm 15 | import networkx as nx 16 | 17 | import sys 18 | sys.path.append("..") 19 | from utils import data_generator, avg, plot_sim 20 | 21 | 22 | def label_generator(data, task_name, pred_type): 23 | data_num = data.split(" ") 24 | data_sorted = sorted([int(d) for d in data_num]) 25 | label_idx = [data_sorted.index(int(d)) for d in data_num] 26 | if pred_type == "noise": 27 | label_tensor = torch.zeros(len(data_num), 1) 28 | for i in range(task_name+1): 29 | if i not in label_idx: # repeat numbers here 30 | k = i 31 | while k not in label_idx: 32 | k = k-1 33 | label_tensor[label_idx.index(k), 0] = 1 34 | else: 35 | label_tensor[label_idx.index(i), 0] = 1 36 | if task_name not in label_idx: # repeat numbers here 37 | k = task_name 38 | while k not in label_idx: 39 | k = k-1 40 | label_tensor[label_idx.index(k), 0] = 1 41 | else: 42 | label_tensor[label_idx.index(task_name), 0] = 1 43 | # return torch.flatten(label_tensor) 44 | return label_tensor, [i for i in range(label_tensor.shape[0])] 45 | else: 46 | label_tensor = torch.zeros(len(data_num), 2) 47 | for i in range(task_name+1): 48 | if i not in label_idx: # repeat numbers here 49 | k = i 50 | while k not in label_idx: 51 | k = k-1 52 | label_tensor[label_idx.index(k), 0] = 1 53 | else: 54 | 
label_tensor[label_idx.index(i), 0] = 1 55 | if task_name not in label_idx: # repeat numbers here 56 | k = task_name 57 | while k not in label_idx: 58 | k = k-1 59 | label_tensor[label_idx.index(k), 1] = 1 60 | else: 61 | label_tensor[label_idx.index(task_name), 1] = 1 62 | # return torch.flatten(label_tensor) 63 | sample_idx = [] 64 | for i in range(label_tensor.shape[0]): 65 | if torch.sum(label_tensor[i, :]) != 0: 66 | sample_idx.append(i) 67 | return label_tensor, sample_idx 68 | 69 | 70 | def knn_classifier(args, task_name, attn_list, test_data, pred_type="depth"): 71 | acc_list = [] 72 | label_list = [] 73 | new_attn_list = [] 74 | 75 | for i in range(len(test_data)): 76 | label, sample_idx = label_generator(test_data[i], task_name, pred_type) 77 | label_list.append(label[sample_idx]) 78 | new_attn_list.append(attn_list[i][sample_idx]) 79 | attn_list = new_attn_list 80 | attn_all = torch.cat(attn_list, dim=0).numpy() 81 | label_all = torch.squeeze(torch.cat(label_list, dim=0)).numpy() 82 | 83 | neigh = KNeighborsClassifier(n_neighbors=8, weights='distance') 84 | neigh.fit(attn_all[:int(len(label_list)/2)], label_all[:int(len(label_list)/2)]) 85 | pred = neigh.predict(attn_all[int(len(label_list)/2):]) 86 | targt = label_all[int(len(label_list)/2):] 87 | acc_list.append(f1_score(targt, pred, average="macro")) 88 | 89 | neigh = KNeighborsClassifier(n_neighbors=8, weights='distance') 90 | neigh.fit(attn_all[int(len(label_list)/2):], label_all[int(len(label_list)/2):]) 91 | pred = neigh.predict(attn_all[:int(len(label_list)/2)]) 92 | targt = label_all[:int(len(label_list)/2)] 93 | acc_list.append(f1_score(targt, pred, average="macro")) 94 | return acc_list 95 | 96 | 97 | def probe_calculation(args, model, tokenizer, test_data, model_name, layer_wise=False): 98 | attn_list = [] 99 | # get attention list 100 | with torch.no_grad(): 101 | for i in tqdm(range(len(test_data))): 102 | data = test_data[i] 103 | # get attention matrix 104 | inputs = tokenizer(data, return_tensors="pt").to(args.device) 105 | outputs = model(**inputs, output_attentions=True).attentions # [batch_num x attn head x attn matrix] 106 | tmp_attn_list = [] 107 | for j in range(len(outputs)): 108 | attn = torch.squeeze(outputs[j])[:,-1,:].cpu() # shape: [12x16] 109 | tmp_attn_list.append(attn) 110 | tmp_attn_list = [torch.mean(a, dim=0, keepdim=True) for a in tmp_attn_list] 111 | attn = torch.cat(tmp_attn_list, dim=0) 112 | attn = torch.transpose(attn, 0, 1) # shape: [16x12] 113 | attn = attn/torch.max(attn) 114 | attn_list.append(attn) 115 | if layer_wise: 116 | task_name = int(model_name[-2:]) 117 | for l in range(attn_list[0].shape[1]): 118 | tmp_attn_list = [a[:, :(l+1)] for a in attn_list] 119 | f1_noise_list = avg(knn_classifier(args, task_name, tmp_attn_list, test_data, "noise")) 120 | f1_depth_list = avg(knn_classifier(args, task_name, tmp_attn_list, test_data, "depth")) 121 | print(model_name, "Layer: " + str(l) + " KNN classifier (noise, depth) F1-Macro: ", round(f1_noise_list, 6), round(f1_depth_list, 6)) 122 | else: 123 | # train & test probe model 124 | if "pretrained" in model_name or "scratch" in model_name: 125 | f1_noise_list, f1_depth_list = [], [] 126 | for tn in range(int(args.stask_len/2)): 127 | f1_noise = avg(knn_classifier(args, tn, attn_list, test_data, "noise")) 128 | f1_depth = avg(knn_classifier(args, tn, attn_list, test_data, "depth")) 129 | print(model_name, "k=", tn, " KNN classifier (noise, depth) F1-Macro: ", round(f1_noise, 6), round(f1_depth, 6)) 130 | f1_noise_list.append(f1_noise) 131 
| f1_depth_list.append(f1_depth) 132 | else: 133 | task_name = int(model_name[-2:]) 134 | f1_noise = avg(knn_classifier(args, task_name, attn_list, test_data, "noise")) 135 | f1_depth = avg(knn_classifier(args, task_name, attn_list, test_data, "depth")) 136 | print(model_name, " KNN classifier (noise, depth) F1-Macro: ", round(f1_noise, 6), round(f1_depth, 6)) 137 | f1_noise_list = f1_noise 138 | f1_depth_list = f1_depth 139 | return f1_noise_list, f1_depth_list 140 | 141 | 142 | def probe_analysis(args): 143 | folder_name = "training" 144 | # folder_name = args.model_path.split("/")[-1] 145 | if not os.path.exists(os.path.join(args.tmp_dir, "probe")): 146 | os.mkdir(os.path.join(args.tmp_dir, "probe")) 147 | # get all folders (GPT-2) 148 | for root, dirs, files in os.walk(os.path.join(args.tmp_dir, folder_name)): 149 | folders = dirs 150 | break 151 | folders = ["scratch", args.model_path] + [os.path.join(args.tmp_dir, folder_name, f_path) for f_path in folders if f_path[:8]=="gpt2_min"] 152 | model_label = ["scratch", "pretrained"] + [f_name[12+len(folder_name):12+len(folder_name)+5] for f_name in folders[2:]] 153 | # set tokenizer 154 | tokenizer = AutoTokenizer.from_pretrained(args.model_path) 155 | # prepare data 156 | sdata_generator = data_generator(args, tokenizer) 157 | _, _, test_data = sdata_generator.return_data(data=True) 158 | test_data = random.sample(test_data, k=1024) 159 | print(folders, model_label) 160 | for i in range(len(folders)): 161 | # if i < 2: continue 162 | if folders[i] == "scratch": 163 | config = GPT2Config() 164 | model = GPT2LMHeadModel(config=config).to(args.device) 165 | f1_noise_scratch, f1_depth_scratch = probe_calculation(args, model, tokenizer, test_data, model_label[i]) 166 | elif model_label[i] == "pretrained": 167 | model = AutoModelForCausalLM.from_pretrained(folders[i]).to(args.device) 168 | f1_noise_pretrained, f1_depth_pretrained = probe_calculation(args, model, tokenizer, test_data, model_label[i]) 169 | for j in range(len(f1_noise_pretrained)): 170 | sp1 = (f1_noise_pretrained[j]-f1_noise_scratch[j]) / (1-f1_noise_scratch[j]) 171 | sp2 = (f1_depth_pretrained[j]-f1_depth_scratch[j]) / (1-f1_depth_scratch[j]) 172 | print("====>>>> The pretrained GPT-2 probing scores are: (k=", str(j), ")", round(sp1, 6), round(sp2, 6)) 173 | else: 174 | model = AutoModelForCausalLM.from_pretrained(folders[i]).to(args.device) 175 | f1_noise, f1_depth = probe_calculation(args, model, tokenizer, test_data, model_label[i]) 176 | label_k = int(model_label[i][3:]) 177 | sp1 = (f1_noise-f1_noise_scratch[label_k]) / (1-f1_noise_scratch[label_k]) 178 | sp2 = (f1_depth-f1_depth_scratch[label_k]) / (1-f1_depth_scratch[label_k]) 179 | print("====>>>> The finetuned GPT-2 probing scores are: ", round(sp1, 6), round(sp2, 6)) 180 | return 181 | 182 | 183 | def probe_layer(args): 184 | folder_name = "training" 185 | for root, dirs, files in os.walk(os.path.join(args.tmp_dir, folder_name)): 186 | folders = dirs 187 | break 188 | folders = [os.path.join(args.tmp_dir, folder_name, f_path) for f_path in folders if f_path[:8]=="gpt2_min" and "all" in f_path] 189 | model_label = [f_name[20:20+5] for f_name in folders] 190 | # set tokenizer 191 | tokenizer = AutoTokenizer.from_pretrained(args.model_path) 192 | # prepare data 193 | sdata_generator = data_generator(args, tokenizer) 194 | _, _, test_data = sdata_generator.return_data(data=True) 195 | test_data = random.sample(test_data, k=1024) 196 | for i in range(len(folders)): 197 | model = 
AutoModelForCausalLM.from_pretrained(folders[i]).to(args.device) 198 | probe_calculation(args, model, tokenizer, test_data, model_label[i], layer_wise=True) 199 | return 200 | -------------------------------------------------------------------------------- /src/stask/prune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import json 5 | import random 6 | import numpy as np 7 | from transformers import AutoTokenizer, AutoModelForCausalLM, get_constant_schedule_with_warmup 8 | from tqdm import tqdm 9 | from scipy.stats import entropy 10 | 11 | import sys 12 | sys.path.append("..") 13 | from utils import data_generator, plot_entropy, plot_flow 14 | from gpt2 import evaluator 15 | 16 | 17 | def head_calculation(args, model, tokenizer, test_data, save_path="", title=""): 18 | ''' 19 | # get head entropy map 20 | entropy_pos = np.zeros((12,12)) 21 | entropy_size = np.zeros((12,12)) 22 | # get attention list 23 | with torch.no_grad(): 24 | for i in tqdm(range(len(test_data))): 25 | data = test_data[i] 26 | data_num = data.split(" ") 27 | data_sorted = sorted([int(d) for d in data_num]) 28 | label_idx = [data_num.index(str(d)) for d in data_sorted[:16]] 29 | # get attention matrix 30 | inputs = tokenizer(data, return_tensors="pt").to(args.device) 31 | outputs = model(**inputs, output_attentions=True).attentions # [1 x attn head x N x N] 32 | tmp_size, tmp_pos = [], [] 33 | for attn in outputs: 34 | attn = attn[:,:,-1,:].cpu().numpy() 35 | attn_pos = attn # [1 x attn head x N] 36 | tmp_pos.append(attn_pos) 37 | attn_size = [attn[:,:,l_idx:l_idx+1] for l_idx in label_idx] 38 | attn_size = np.concatenate(attn_size, axis=-1) 39 | tmp_size.append(attn_size) 40 | attn_pos = np.concatenate(tmp_pos, axis=0) # [layer_num x attn head x N] 41 | attn_size = np.concatenate(tmp_size, axis=0) # [layer_num x attn head x N] 42 | for l in range(attn_pos.shape[0]): 43 | for h in range(attn_pos.shape[1]): 44 | entropy_pos[l, h] += entropy(attn_pos[l,h,:]) / math.log(attn_pos.shape[2]) 45 | entropy_size[l, h] += entropy(attn_size[l,h,:]) / math.log(attn_size.shape[2]) 46 | entropy_pos = entropy_pos / len(test_data) # [layer_num x attn head x N] 47 | entropy_size = entropy_size / len(test_data) # [layer_num x attn head x N] 48 | xlabel = [str(i+1) for i in range(12)] 49 | ylabel1 = xlabel 50 | if len(save_path): 51 | plot_entropy(np.transpose(entropy_pos), save_path.replace(".pdf", "_pos_em.pdf"), xlabel=xlabel, ylabel1=ylabel1) 52 | plot_entropy(np.transpose(entropy_size), save_path.replace(".pdf", "_size_em.pdf"), xlabel=xlabel, ylabel1=ylabel1) 53 | ''' 54 | # get attention list 55 | attn_pos_all, attn_size_all = None, None 56 | with torch.no_grad(): 57 | for i in tqdm(range(len(test_data))): 58 | data = test_data[i] 59 | data_num = data.split(" ") 60 | data_sorted = sorted([int(d) for d in data_num]) 61 | label_idx = [data_num.index(str(d)) for d in data_sorted[:16]] 62 | # get attention matrix 63 | inputs = tokenizer(data, return_tensors="pt").to(args.device) 64 | outputs = model(**inputs, output_attentions=True).attentions # [1 x attn head x N x N] 65 | tmp_size, tmp_pos = [], [] 66 | for attn in outputs: 67 | attn = attn[:,:,-1,:].cpu().numpy() 68 | attn_pos = attn # [1 x attn head x N] 69 | tmp_pos.append(attn_pos) 70 | attn_size = [attn[:,:,l_idx:l_idx+1] for l_idx in label_idx] 71 | attn_size = np.concatenate(attn_size, axis=-1) 72 | tmp_size.append(attn_size) 73 | attn_pos = np.concatenate(tmp_pos, axis=0) # [layer_num x 
attn head x N] 74 | if attn_pos_all is None: 75 | attn_pos_all = attn_pos 76 | else: 77 | attn_pos_all += attn_pos 78 | attn_size = np.concatenate(tmp_size, axis=0) # [layer_num x attn head x N] 79 | if attn_size_all is None: 80 | attn_size_all = attn_size 81 | else: 82 | attn_size_all += attn_size 83 | attn_pos_all = attn_pos_all / len(test_data) # [layer_num x attn head x N] 84 | attn_size_all = attn_size_all / len(test_data) # [layer_num x attn head x N] 85 | # get head entropy map 86 | entropy_pos = np.zeros(attn_pos_all.shape[:2]) 87 | entropy_size = np.zeros(attn_size_all.shape[:2]) 88 | for l in range(attn_pos_all.shape[0]): 89 | for h in range(attn_pos_all.shape[1]): 90 | entropy_pos[l, h] = entropy(attn_pos_all[l,h,:]) / math.log(attn_pos_all.shape[2]) 91 | entropy_size[l, h] = entropy(attn_size_all[l,h,:]) / math.log(attn_size_all.shape[2]) 92 | xlabel = [str(i+1) for i in range(12)] 93 | ylabel1 = xlabel 94 | if len(save_path): 95 | plot_entropy(np.transpose(entropy_pos), save_path.replace(".pdf", "_pos.pdf"), xlabel=xlabel, ylabel1=ylabel1) 96 | plot_entropy(np.transpose(entropy_size), save_path.replace(".pdf", "_size.pdf"), xlabel=xlabel, ylabel1=ylabel1) 97 | return entropy_pos, entropy_size 98 | 99 | 100 | def get_head_mask(e, p, drop="min"): # [layer_num x attn head x N] 101 | entropy = e.copy() 102 | head_mask = torch.zeros(entropy.shape[0], entropy.shape[1]) + 1 103 | p_num = int(p / 100 * entropy.shape[0] * entropy.shape[1]) 104 | p_count = 0 105 | while p_count <= p_num: 106 | # find the minimum entropy head 107 | if drop == "max": 108 | idxl, idxh = np.where(entropy==entropy.max()) 109 | ridx = random.choice([i for i in range(len(idxl))]) 110 | idxl = int(idxl[ridx]) 111 | idxh = int(idxh[ridx]) 112 | entropy[idxl, idxh] = 0. 113 | elif drop == "min": 114 | idxl, idxh = np.where(entropy==entropy.min()) 115 | tmp_num = 1. 116 | ridx = random.choice([i for i in range(len(idxl))]) 117 | idxl = int(idxl[ridx]) 118 | idxh = int(idxh[ridx]) 119 | entropy[idxl, idxh] = 1. 
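# note: the selected (layer, head) entry is overwritten with a sentinel value (0. in the "max" branch above, 1. here in the "min" branch; the normalized entropies lie in [0, 1]) so that np.where cannot pick the same head again on a later iteration
# head_mask itself stays binary: 1 keeps a head, 0 prunes it, matching the head_mask argument of the model calls used later in this file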
120 | else: 121 | idxl = random.choice([i for i in range(entropy.shape[0])]) 122 | idxh = random.choice([i for i in range(entropy.shape[1])]) 123 | # if sum(head_mask[idxl, :]) > 1: 124 | if head_mask[idxl, idxh] == 1: 125 | p_count += 1 126 | head_mask[idxl, idxh] = 0 127 | return head_mask 128 | 129 | 130 | def acc_calculation(args, model, tokenizer, test_data, test_labels, save_path, title, entropy_pos, entropy_size): 131 | org_acc = evaluator(args, tokenizer, model, test_data, test_labels, 128, True) 132 | acc_list = [] 133 | drop_list = [10, 20, 30, 40, 50, 60, 70, 80, 90] 134 | for d in drop_list: 135 | acc_list.append(evaluator(args, tokenizer, model, test_data, test_labels, 128, True, get_head_mask(entropy_pos, d, "rand"))) 136 | print("The accuracy (drop randomly) is: ", [org_acc] + [round(acc, 6) for acc in acc_list]) 137 | acc_list = [] 138 | for d in drop_list: 139 | acc_list.append(evaluator(args, tokenizer, model, test_data, test_labels, 128, True, get_head_mask(entropy_pos, d, "min"))) 140 | print("The accuracy (drop min pos entropy first) is: ", [org_acc] + [round(acc, 6) for acc in acc_list]) 141 | acc_list = [] 142 | for d in drop_list: 143 | acc_list.append(evaluator(args, tokenizer, model, test_data, test_labels, 128, True, get_head_mask(entropy_size, d, "min"))) 144 | print("The accuracy (drop min size entropy first) is: ", [org_acc] + [round(acc, 6) for acc in acc_list]) 145 | return 146 | 147 | 148 | def pruning_analysis(args): 149 | folder_name = "prune" 150 | if not os.path.exists(os.path.join(args.tmp_dir, folder_name)): 151 | os.mkdir(os.path.join(args.tmp_dir, folder_name)) 152 | # get all folders (GPT-2) 153 | for root, dirs, files in os.walk(os.path.join(args.tmp_dir, "training")): 154 | folders = dirs 155 | break 156 | folders = [os.path.join(args.tmp_dir, "training", f_path) for f_path in folders if f_path[:8]=="gpt2_min"] 157 | new_folders = [] 158 | for i in range(len(folders)): 159 | if "attn" in folders[i]: 160 | continue 161 | elif "mlp" in folders[i]: 162 | continue 163 | elif "ln" in folders[i]: 164 | continue 165 | else: 166 | new_folders.append(folders[i]) 167 | folders = new_folders 168 | model_label = [f_name[20:20+5] for f_name in folders] 169 | # set tokenizer 170 | tokenizer = AutoTokenizer.from_pretrained(args.model_path) 171 | print(folders, model_label) 172 | for i in range(len(folders)): 173 | # prepare data 174 | args.stask = model_label[i][:5] 175 | sdata_generator = data_generator(args, tokenizer) 176 | _, _, test_data = sdata_generator.return_data(data=True) 177 | _, _, test_labels = sdata_generator.return_data(data=False) 178 | 179 | save_path = os.path.join(args.tmp_dir, folder_name, "pruning_analysis-all_"+model_label[i]+".pdf") 180 | title = model_label[i] 181 | model = AutoModelForCausalLM.from_pretrained(folders[i]).to(args.device) 182 | entropy_pos, entropy_size = head_calculation(args, model, tokenizer, test_data, save_path, title) 183 | acc_calculation(args, model, tokenizer, test_data, test_labels, save_path, title, entropy_pos, entropy_size) 184 | 185 | # visualize case study 186 | if not os.path.exists(os.path.join(args.tmp_dir, "gpt2_cs")): 187 | os.mkdir(os.path.join(args.tmp_dir, "gpt2_cs")) 188 | args.stask = "min01" 189 | sdata_generator = data_generator(args, tokenizer) 190 | _, _, test_data = sdata_generator.return_data(data=True) 191 | 192 | f_path = [f for f in folders if "min01" in f][0] 193 | model = AutoModelForCausalLM.from_pretrained(f_path).to(args.device) 194 | entropy_pos, entropy_size = 
head_calculation(args, model, tokenizer, test_data,) 195 | head_mask_pos = get_head_mask(entropy_pos, 40, "min").to(args.device) # remove 40% position heads ([layer_num x attn head]) 196 | attn_mean_all, pos_mean_all = None, None 197 | pos_00, pos_01 = 8, 12 198 | with torch.no_grad(): 199 | for data in test_data: 200 | inputs = tokenizer(data, return_tensors="pt").to(args.device) 201 | outputs = model(**inputs, output_attentions=True, head_mask=head_mask_pos).attentions # [1 x attn head x N x N] 202 | attn_pos = torch.cat(outputs, dim=0) # [layer_num x attn head x N x N] 203 | mask_sum = torch.sum(head_mask_pos, dim=1).unsqueeze(-1).unsqueeze(-1) 204 | attn_pos = torch.sum(attn_pos, dim=1) / mask_sum # [layer_num x N x N] 205 | attn_pos = attn_pos.cpu().numpy() 206 | if pos_mean_all is None: 207 | pos_mean_all = attn_pos 208 | else: 209 | pos_mean_all += attn_pos 210 | 211 | data_num = data.split(" ") 212 | if len(data_num) != len(set(data_num)): continue 213 | data_sorted = sorted([int(d) for d in data_num]) 214 | data_min00, data_min01 = str(data_sorted[0]), str(data_sorted[1]) 215 | data_drop_00, data_drop_01 = data_num[pos_00-1], data_num[pos_01-1] 216 | data_num[data_num.index(data_min00)] = data_drop_00 217 | data_num[data_num.index(data_min01)] = data_drop_01 218 | data_num[pos_00-1] = data_min00 219 | data_num[pos_01-1] = data_min01 220 | data = " ".join(data_num) 221 | 222 | inputs = tokenizer(data, return_tensors="pt").to(args.device) 223 | outputs = model(**inputs, output_attentions=True, head_mask=head_mask_pos).attentions # [1 x attn head x N x N] 224 | attn = torch.cat(outputs, dim=0) # [layer_num x attn head x N x N] 225 | mask_sum = torch.sum(head_mask_pos, dim=1).unsqueeze(-1).unsqueeze(-1) 226 | attn_mean = torch.sum(attn, dim=1) / mask_sum # [layer_num x N x N] 227 | attn_mean = attn_mean.cpu().numpy() 228 | if attn_mean_all is None: 229 | attn_mean_all = attn_mean 230 | else: 231 | attn_mean_all += attn_mean 232 | attn_mean_all = attn_mean_all/len(test_data) 233 | pos_mean_all = pos_mean_all/len(test_data) 234 | xlabel=[""] + [str(i+1) for i in range(len(outputs))] 235 | ylabel = ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16"] 236 | ylabel.reverse() 237 | save_path = os.path.join(args.tmp_dir, "gpt2_cs", "cs_min01_all.pdf") 238 | plot_flow(pos_mean_all, save_path, "", xlabel=xlabel, ylabel=ylabel, cs=True) 239 | ylabel = ["1", "2", "3", "4", "5", "6", "7", "8 (leaf)", "9", "10", "11", "12 (root)", "13", "14", "15", "16"] 240 | ylabel.reverse() 241 | save_path = os.path.join(args.tmp_dir, "gpt2_cs", "cs_min01_"+str(pos_00)+"_"+str(pos_01)+".pdf") 242 | plot_flow(attn_mean_all, save_path, "", xlabel=xlabel, ylabel=ylabel, cs=True, pos1=pos_00, pos2=pos_01) 243 | 244 | return 245 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import random 5 | import numpy as np 6 | import networkx as nx 7 | from networkx.algorithms.dag import dag_longest_path 8 | import torch 9 | from transformers import set_seed 10 | import matplotlib.pyplot as plt 11 | import matplotlib.colors as colors 12 | from sklearn.preprocessing import normalize 13 | from sklearn.pipeline import make_pipeline 14 | from sklearn.preprocessing import StandardScaler 15 | from sklearn.model_selection import train_test_split 16 | from sklearn.decomposition import PCA 17 | from 
sklearn.discriminant_analysis import LinearDiscriminantAnalysis 18 | from sklearn.neighbors import KNeighborsClassifier, NeighborhoodComponentsAnalysis 19 | from sklearn.manifold import SpectralEmbedding, Isomap, LocallyLinearEmbedding, MDS, TSNE 20 | 21 | from proofparser import get_proof_graph 22 | 23 | 24 | def fix_seed(seed): 25 | # random 26 | random.seed(seed) 27 | # Numpy 28 | np.random.seed(seed) 29 | # PyTorch 30 | torch.manual_seed(seed) 31 | torch.cuda.manual_seed_all(seed) 32 | # torch.backends.cudnn.deterministic = True 33 | # Transformers 34 | # set_seed(seed) 35 | return 36 | 37 | 38 | def set_grad(model, tuning_params="all"): 39 | if tuning_params == "attn": 40 | for name, param in model.named_parameters(): 41 | if "attn" in name: 42 | param.requires_grad = True 43 | else: 44 | param.requires_grad = False 45 | elif tuning_params == "mlp": 46 | for name, param in model.named_parameters(): 47 | if "mlp" in name: 48 | param.requires_grad = True 49 | else: 50 | param.requires_grad = False 51 | elif tuning_params == "ln": 52 | for name, param in model.named_parameters(): 53 | if "ln" in name: 54 | param.requires_grad = True 55 | else: 56 | param.requires_grad = False 57 | elif tuning_params == "lora": 58 | from transformers.adapters import LoRAConfig 59 | config = LoRAConfig(r=8, alpha=16) 60 | model.add_adapter("lora_adapter", config=config) 61 | model.train_adapter("lora_adapter") 62 | else: 63 | for name, param in model.named_parameters(): 64 | param.requires_grad = True 65 | return 66 | 67 | 68 | def avg(l): 69 | if len(l): 70 | return sum(l)/len(l) 71 | else: 72 | return 0 73 | 74 | 75 | def label_parser(data, depth, pred_type="noise", corrput=""): 76 | depth = int(depth) 77 | # get context_key 78 | context_key = [] 79 | for k, v in data["triples"].items(): 80 | context_key.append(k) 81 | for k, v in data["rules"].items(): 82 | context_key.append(k) 83 | # parse reasoning tree 84 | proof_text = data["proof"] 85 | nodes, edges = get_proof_graph(proof_text) 86 | # remove loops and repeat samples 87 | edges_noloop = [] 88 | for src, dst in edges: 89 | if (dst, src) not in edges_noloop: 90 | edges_noloop.append((src, dst)) 91 | edges = edges_noloop[:int(depth)] 92 | nodes = list(set(nodes)) 93 | # random and shorcut 94 | if len(corrput) and pred_type != "noise": 95 | # print(len(edges), edges) 96 | if corrput == "random": 97 | random.shuffle(edges) 98 | if random.random() > 0.5: edges[0] = (edges[0][1], edges[0][0]) 99 | if random.random() > 0.5: edges[1] = (edges[1][1], edges[1][0]) 100 | elif corrput == "shortcut" or corrput == "231": 101 | edges = [edges[1], edges[0]] 102 | else: 103 | corrput_int = [int(corrput[s]) for s in range(len(corrput))] 104 | triples = [edges[0][0], edges[0][1], edges[1][1]] 105 | triples = [triples[corrput_int[i]-1] for i in range(len(triples))] 106 | edges = [(triples[0], triples[1]), (triples[1], triples[2])] 107 | # construct graph 108 | g = nx.DiGraph() 109 | for n in nodes: 110 | g.add_node(n) 111 | for src, dst in edges: 112 | g.add_edge(src, dst) 113 | 114 | if pred_type == "noise": 115 | label_tensor = torch.zeros(len(context_key), 1) 116 | for n in nodes: 117 | label_tensor[context_key.index(n), 0] = 1. 118 | return label_tensor, [i for i in range(len(context_key))] 119 | else: # pred_type == "depth" 120 | # find root node 121 | root_nodes = [n for n in nodes] 122 | for src, dst in edges: 123 | if src in root_nodes: 124 | root_nodes.remove(src) 125 | if len(root_nodes) < 1: # multiple root nodes / root node in loop 126 | # print("Warning! 
no root nodes (there is loop!): ", depth, nodes, edges, root_nodes) 127 | lp = [] 128 | for n1 in g.nodes(): 129 | for n2 in g.nodes(): 130 | all_paths = nx.all_simple_paths(g, n1, n2) 131 | for p in all_paths: 132 | if len(p) > len(lp): lp=p 133 | root_nodes = [lp[-1]] 134 | # get all node depth 135 | depth_dict = {} 136 | for n in nodes: 137 | max_len = 0 138 | for root in root_nodes: 139 | for l in nx.all_simple_paths(g, n, root): 140 | max_len = max(len(l)-1, max_len) 141 | depth_dict[n] = max_len # nx.dag_longest_path_length(g, root) 142 | # add tensor 143 | label_tensor = torch.zeros(len(context_key), depth+1) 144 | # for all nodes 145 | for n in nodes: 146 | x_idx = context_key.index(n) 147 | y_idx = depth - depth_dict[n] 148 | label_tensor[x_idx, y_idx] = 1. 149 | ''' 150 | # if there is undirected loop, actually there is none of loops 151 | for src, dst in edges: 152 | # src node 153 | if 1. not in label_tensor[context_key.index(src)]: 154 | label_tensor[context_key.index(src), depth-depth_dict[src]] = 1. 155 | if depth_dict[src] > 0: # when root node has a loop 156 | if 1. not in label_tensor[context_key.index(dst)]: 157 | label_tensor[context_key.index(dst), depth-depth_dict[src]+1] = 1. 158 | # dst node 159 | if 1. not in label_tensor[context_key.index(dst)]: 160 | label_tensor[context_key.index(dst), depth-depth_dict[dst]] = 1. 161 | if depth_dict[dst] < depth: # when leaf node has a loop 162 | if 1. not in label_tensor[context_key.index(src)]: 163 | label_tensor[context_key.index(src), depth-depth_dict[dst]-1] = 1. 164 | ''' 165 | # check label 166 | sample_idx = [context_key.index(n) for n in nodes] 167 | ''' 168 | if depth >= 1: 169 | test_label_tensor = label_tensor[sample_idx] 170 | if torch.sum(test_label_tensor) != test_label_tensor.shape[0]: 171 | print("Error label: mismatching number: ", test_label_tensor, sample_idx, context_key, nodes, edges) 172 | if torch.sum(test_label_tensor) < test_label_tensor.shape[1]: 173 | print("Error label: missing number: ", test_label_tensor, sample_idx, context_key, nodes, edges) 174 | ''' 175 | return label_tensor, sample_idx 176 | 177 | 178 | def plot_loss(loss_list, save_path): 179 | if len(loss_list) == 0: return 180 | while len(loss_list) >= 10000: 181 | new_loss_list = [] 182 | for i in range(0, len(loss_list), 10): 183 | new_loss_list.append(sum(loss_list[i:i+10])/len(loss_list[i:i+10])) 184 | loss_list = new_loss_list 185 | if len(loss_list) == 0: return 186 | x = range(len(loss_list)) 187 | fig = plt.gcf() 188 | plt.plot(x, loss_list) 189 | fig.savefig(save_path, format='pdf', bbox_inches="tight") 190 | plt.close() 191 | return 192 | 193 | 194 | def plot_entropy(mat, save_path, title="", xlabel=[], ylabel1=[], ylabel2=[],): 195 | # construct figure 196 | fig = plt.figure() 197 | ax = plt.gca() 198 | im = ax.matshow(mat, aspect='equal', interpolation='none' , cmap='hot', vmin=0., vmax=1.0) 199 | plt.ylabel('Attention head', fontsize=14) 200 | fig.colorbar(im) 201 | if len(xlabel) and len(ylabel1): 202 | ax.set_xticks(np.arange(len(xlabel))) 203 | ax.set_xticklabels(xlabel) 204 | # Rotate and align bottom ticklabels 205 | plt.setp([tick.label1 for tick in ax.xaxis.get_major_ticks()], rotation=45, 206 | ha="right", va="center", rotation_mode="anchor") 207 | ax.set_yticks(np.arange(len(ylabel1))) 208 | ax.set_yticklabels(ylabel1) 209 | # Set ticks on both sides of axes on 210 | ax.tick_params(axis="x", bottom=True, top=True, labelleft=True, labelright=True, labeltop=True) 211 | ax.set_title("Layers", fontsize=14) 212 | 
fig.tight_layout() 213 | fig.savefig(save_path, format='pdf', bbox_inches="tight") 214 | plt.close() 215 | return 216 | 217 | 218 | def plot_sim(mat, save_path, title="", xlabel=[], ylabel1=[], ylabel2=[], pos=False, depth=False, depth_pred=False, prune_list=[]): 219 | # construct figure 220 | fig = plt.figure() 221 | ax = plt.gca() 222 | if pos: 223 | ax.set_title("Layers", fontsize=14) 224 | im = ax.matshow(mat, aspect='equal', interpolation='none' , cmap='seismic', vmin=0., vmax=0.92) 225 | plt.ylabel('Position ranking', fontsize=14) 226 | else: 227 | if depth: 228 | ax.set_title("Layers", fontsize=14) 229 | xlabel = [i+1 for i in range(32)] 230 | if len(prune_list): 231 | for l in prune_list: 232 | if l+1 in xlabel: 233 | xlabel.remove(l+1) 234 | ylabel1 = [i for i in range(mat.shape[0])] 235 | ylabel1 = ["NA"] + ylabel1[:-1] 236 | im = ax.matshow(mat, aspect='equal', interpolation='none' , cmap='seismic') 237 | plt.ylabel('Depth', fontsize=14) 238 | fig.colorbar(im, location='bottom', shrink=0.5, pad=0.05) 239 | else: 240 | if depth_pred: 241 | dmin, omax= 1., 0. 242 | for i in range(mat.shape[0]): 243 | for j in range(mat.shape[1]): 244 | if i == j: 245 | dmin = min(mat[i,j], dmin) 246 | else: 247 | omax = max(mat[i,j], omax) 248 | ax.set_title("Prediction", fontsize=14) 249 | im = ax.matshow(mat, aspect='equal', interpolation='none' , cmap='seismic', 250 | norm=colors.TwoSlopeNorm(vcenter=min(dmin, omax)+0.5*(max(dmin, omax)-min(dmin, omax)))) 251 | plt.ylabel('GroundTruth', fontsize=14) 252 | fig.colorbar(im,)# shrink=0.5,) 253 | for (i, j), z in np.ndenumerate(mat): 254 | ax.text(j, i, '{:0.2f}'.format(z), ha='center', va='center', fontsize=11) 255 | else: 256 | ax.set_title("Layers", fontsize=14) 257 | im = ax.matshow(mat, aspect='equal', interpolation='none' , cmap='seismic', 258 | norm=colors.TwoSlopeNorm(vcenter=0.077))#norm=colors.PowerNorm(gamma=0.2),) 259 | plt.ylabel('Size ranking', fontsize=14) 260 | fig.colorbar(im) 261 | if len(xlabel) and len(ylabel1): 262 | ax.set_xticks(np.arange(len(xlabel))) 263 | ax.set_xticklabels(xlabel) 264 | ax.set_yticks(np.arange(len(ylabel1))) 265 | ax.set_yticklabels(ylabel1) 266 | # Set ticks on both sides of axes on 267 | ax.tick_params(axis="x", bottom=True, top=True, labelleft=True, labelright=True, labeltop=True) 268 | if depth: 269 | # Rotate and align bottom ticklabels 270 | plt.setp([tick.label1 for tick in ax.xaxis.get_major_ticks()], rotation=60, 271 | ha="right", va="center", rotation_mode="anchor") 272 | # Rotate and align top ticklabels 273 | plt.setp([tick.label2 for tick in ax.xaxis.get_major_ticks()], rotation=60, 274 | ha="left", va="center",rotation_mode="anchor") 275 | # ax.set_title(title, pad=55) 276 | ''' 277 | # plot number 278 | for (i, j), z in np.ndenumerate(mat): 279 | ax.text(j, i, '{:0.2f}'.format(z), ha='center', va='center') 280 | ''' 281 | fig.tight_layout() 282 | fig.savefig(save_path, format='pdf', bbox_inches="tight") 283 | plt.close() 284 | return 285 | 286 | 287 | def plot_flow(data, save_path, title, xlabel, ylabel, cs=False, pos1=None, pos2=None): 288 | if len(data) == 0: return 289 | ax = plt.gca() 290 | ax.set_xticks(np.arange(len(xlabel))) 291 | ax.set_xticklabels(xlabel) 292 | ax.set_yticks(np.arange(len(ylabel))) 293 | ax.set_yticklabels(ylabel) 294 | plt.ylabel('Token', fontsize=14) 295 | plt.xlabel('Layer', fontsize=14) 296 | # Set ticks on both sides of axes on 297 | ax.tick_params(axis="x", bottom=False, top=False, labelleft=True, labelright=True, labeltop=False) 298 | # Rotate and align 
bottom ticklabels 299 | # plt.setp([tick.label1 for tick in ax.xaxis.get_major_ticks()], rotation=45, ha="right", va="center", rotation_mode="anchor") 300 | if cs: 301 | layer_num = data.shape[0] 302 | token_num = data.shape[1] 303 | for l in range(layer_num): 304 | for t1 in range(token_num): 305 | for t2 in range(token_num): 306 | if data[l, t1, t2] == 0: continue 307 | # if t1 != token_num-1: continue 308 | x = [l+1, l] 309 | y = [token_num-1-t1, token_num-1-t2] 310 | fig = plt.gcf() 311 | if l == layer_num-1 and token_num-1-t1 != 0: continue 312 | if pos1 is not None and pos2 is not None: 313 | if t2 == pos1-1: 314 | plt.plot(x, y, color='blue', alpha=data[l, t1, t2]) 315 | elif t2 == pos2-1: 316 | plt.plot(x, y, color='red', alpha=data[l, t1, t2]) 317 | else: 318 | plt.plot(x, y, color='grey', alpha=data[l, t1, t2]) 319 | else: 320 | plt.plot(x, y, color='grey', alpha=data[l, t1, t2]) 321 | # plot neuron 322 | nx_list, ny_list = [], [] 323 | for i in range(layer_num+1): 324 | for j in range(token_num): 325 | if i == 12 and j != 0: continue 326 | nx_list.append(i) 327 | ny_list.append(j) 328 | plt.scatter(nx_list, ny_list, s=12) 329 | else: 330 | # plot flow 331 | for d in data: 332 | x = range(len(d)) 333 | fig = plt.gcf() 334 | plt.plot(x, d, alpha=0.01) 335 | # plt.title(title) 336 | # plot neuron 337 | nx_list, ny_list = [], [] 338 | for i in range(12): 339 | for j in range(16): 340 | nx_list.append(i) 341 | ny_list.append(j) 342 | plt.scatter(nx_list, ny_list, s=12) 343 | fig.savefig(save_path, format='pdf', bbox_inches="tight") 344 | plt.close() 345 | return 346 | 347 | 348 | def get_label_idx(data, label): 349 | label_list = [] 350 | data_list = [d.split(" ") for d in data] 351 | if len(label) != len(data_list): print("Error: data length not equal to label length!") 352 | for i in range(len(data_list)): 353 | label_idx = data_list[i].index(label[i]) 354 | label_list.append(label_idx) 355 | return label_list 356 | 357 | 358 | class data_generator(): 359 | def __init__(self, args, tokenizer): 360 | fix_seed(args.random_seed) 361 | self.max_num = args.stask_set # 512/1800 numbers with 1 token 362 | self.list_len = args.stask_len # 64 for accuracy analysis, 16 for attn analysis 363 | self.tokenizer = tokenizer 364 | tmp_data_path = os.path.join(args.tmp_dir, args.stask+"_tmp_data.json") 365 | if os.path.exists(tmp_data_path): 366 | self.load_tmp_data(tmp_data_path) 367 | else: 368 | self.generate_save_data(args, tmp_data_path) 369 | 370 | def generate_save_data(self, args, path): 371 | train_data, val_data, test_data = self.generate_data(args.sdata_num) 372 | train_label = self.generate_label(args, train_data) 373 | val_label = self.generate_label(args, val_data) 374 | test_label = self.generate_label(args, test_data) 375 | # int to str 376 | train_sdata = [self.int2str(intl) for intl in train_data] 377 | val_sdata = [self.int2str(intl) for intl in val_data] 378 | test_sdata = [self.int2str(intl) for intl in test_data] 379 | train_slabel = [str(intid) for intid in train_label] 380 | val_slabel = [str(intid) for intid in val_label] 381 | test_slabel = [str(intid) for intid in test_label] 382 | # added to self 383 | self.train_data = train_sdata 384 | self.val_data = val_sdata 385 | self.test_data = test_sdata 386 | self.train_label = train_slabel 387 | self.val_label = val_slabel 388 | self.test_label = test_slabel 389 | return 390 | 391 | def generate_data(self, sdata_num): 392 | all_data = [] 393 | rand_list = [] 394 | count = 0 395 | while len(rand_list) < self.max_num: 396 | t = 
self.tokenizer(str(count)+" " + str(count), add_special_tokens=False) 397 | if len(t["input_ids"]) == 2: 398 | rand_list.append(count) 399 | count += 1 400 | if count > 1e5: break 401 | for _ in range(sdata_num): 402 | all_data.append(random.choices(rand_list, k=self.list_len)) 403 | split_idx1 = int(len(all_data)*0.98) 404 | split_idx2 = int(len(all_data)*0.01) + split_idx1 405 | return all_data[:split_idx1], all_data[split_idx1:split_idx2], all_data[split_idx2:] 406 | 407 | def generate_label(self, args, data_list): 408 | label_list = [] 409 | for data in data_list: 410 | idx = int(args.stask[-2:]) 411 | label_list.append(sorted(data)[idx]) 412 | return label_list 413 | 414 | def int2str(self, intl): 415 | strl = [str(n) for n in intl] 416 | return " ".join(strl) 417 | 418 | def return_data(self, data=True): 419 | if data: 420 | return self.train_data, self.val_data, self.test_data 421 | else: 422 | return self.train_label, self.val_label, self.test_label 423 | 424 | 425 | class data_loader(): 426 | def __init__(self, args): 427 | fix_seed(args.random_seed) 428 | train_dict, dev_dict, test_dict = self.get_data(args.data_dir) 429 | self.train_dict = train_dict 430 | self.dev_dict = dev_dict 431 | self.test_dict = test_dict 432 | 433 | def get_data(self, path): 434 | if path.split("/")[-1] == "proofwriter": 435 | path = os.path.join(path, "CWA") 436 | if not os.path.exists(os.path.join(path, "test.json")): 437 | train_dict, dev_dict, test_dict = self.preprocess_proofwriter(path) 438 | else: 439 | train_dict, dev_dict, test_dict = self.load_data(path) 440 | elif path.split("/")[-1] == "arc": 441 | if not os.path.exists(os.path.join(path, "test.json")): 442 | train_dict, dev_dict, test_dict = self.preprocess_arc(path) 443 | else: 444 | train_dict, dev_dict, test_dict = self.load_data(path) 445 | else: 446 | print("Warning: wrong data directory!") 447 | 448 | return train_dict, dev_dict, test_dict 449 | 450 | def preprocess_proofwriter(self, path): 451 | dir_list = ["depth-0", "depth-1", "depth-2", "depth-3", "depth-5"] 452 | train_dict = {i:[] for i in range(6)} 453 | dev_dict = {i:[] for i in range(6)} 454 | test_dict = {i:[] for i in range(6)} 455 | for d in dir_list: 456 | tmp_path_train = os.path.join(path, d, "meta-train.jsonl") 457 | tmp_path_dev = os.path.join(path, d, "meta-dev.jsonl") 458 | tmp_path_test = os.path.join(path, d, "meta-test.jsonl") 459 | # process data 460 | with open(tmp_path_train, "r") as f: 461 | for line in f: 462 | tmp_data = json.loads(line) 463 | # clean data 464 | c = tmp_data["theory"] 465 | t, r = {}, {} 466 | for k, v in tmp_data["triples"].items(): t[k] = v["text"] 467 | for k, v in tmp_data["rules"].items(): r[k] = v["text"] 468 | q_count = 0 469 | for qid, qdict in tmp_data["questions"].items(): 470 | if qdict["QLen"] == "": continue # ignore q without proof 471 | q = qdict["question"] 472 | a = qdict["answer"] 473 | p = qdict["proofs"] 474 | if "NAF" in p: continue # ignore proof with negation-as-failure 475 | q_count += 1 476 | # add data 477 | new_data = {} 478 | new_data["context"] = c 479 | new_data["triples"] = t 480 | new_data["rules"] = r 481 | new_data["question"] = q 482 | new_data["answer"] = a 483 | new_data["proof"] = p 484 | q_depth = qdict["QDep"] 485 | train_dict[q_depth].append(new_data) 486 | with open(tmp_path_dev, "r") as f: 487 | for line in f: 488 | tmp_data = json.loads(line) 489 | # clean data 490 | c = tmp_data["theory"] 491 | t, r = {}, {} 492 | for k, v in tmp_data["triples"].items(): t[k] = v["text"] 493 | for k, v in 
tmp_data["rules"].items(): r[k] = v["text"] 494 | q_count = 0 495 | for qid, qdict in tmp_data["questions"].items(): 496 | if qdict["QLen"] == "": continue # ignore q without proof 497 | q = qdict["question"] 498 | a = qdict["answer"] 499 | p = qdict["proofs"] 500 | if "NAF" in p: continue # ignore proof with negation-as-failure 501 | q_count += 1 502 | # add data 503 | new_data = {} 504 | new_data["context"] = c 505 | new_data["triples"] = t 506 | new_data["rules"] = r 507 | new_data["question"] = q 508 | new_data["answer"] = a 509 | new_data["proof"] = p 510 | q_depth = qdict["QDep"] 511 | dev_dict[q_depth].append(new_data) 512 | with open(tmp_path_test, "r") as f: 513 | for line in f: 514 | tmp_data = json.loads(line) 515 | # clean data 516 | c = tmp_data["theory"] 517 | t, r = {}, {} 518 | for k, v in tmp_data["triples"].items(): t[k] = v["text"] 519 | for k, v in tmp_data["rules"].items(): r[k] = v["text"] 520 | q_count = 0 521 | for qid, qdict in tmp_data["questions"].items(): 522 | if qdict["QLen"] == "": continue # ignore q without proof 523 | q = qdict["question"] 524 | a = qdict["answer"] 525 | p = qdict["proofs"] 526 | if "NAF" in p: continue # ignore proof with negation-as-failure 527 | q_count += 1 528 | # add data 529 | new_data = {} 530 | new_data["context"] = c 531 | new_data["triples"] = t 532 | new_data["rules"] = r 533 | new_data["question"] = q 534 | new_data["answer"] = a 535 | new_data["proof"] = p 536 | q_depth = qdict["QDep"] 537 | test_dict[q_depth].append(new_data) 538 | # save data 539 | with open(os.path.join(path, "train.json"), "w") as f: 540 | f.write(json.dumps(train_dict)) 541 | with open(os.path.join(path, "dev.json"), "w") as f: 542 | f.write(json.dumps(dev_dict)) 543 | with open(os.path.join(path, "test.json"), "w") as f: 544 | f.write(json.dumps(test_dict)) 545 | print("Training data number: ", {k:len(v) for k, v in train_dict.items()}) 546 | print("Dev data number: ", {k:len(v) for k, v in dev_dict.items()}) 547 | print("Test data number: ", {k:len(v) for k, v in test_dict.items()}) 548 | return train_dict, dev_dict, test_dict 549 | 550 | def preprocess_arc(self, path): 551 | dir_list = ["depth-0", "depth-1", "depth-2", "depth-3", "depth-5"] 552 | train_dict = {} 553 | dev_dict = {} 554 | test_dict = {} 555 | tmp_path_train = os.path.join(path, "reasoning_annotated_train.jsonl") 556 | tmp_path_dev = os.path.join(path, "reasoning_annotated_dev.jsonl") 557 | tmp_path_test = os.path.join(path, "reasoning_annotated_test.jsonl") 558 | # process data 559 | with open(tmp_path_train, "r") as f: 560 | for line in f: 561 | tmp_data = json.loads(line) 562 | # add data 563 | new_data = {} 564 | new_data["context"] = tmp_data["context"] 565 | all_triples = {str(k):v for k, v in tmp_data["textual_logical_units"].items()} 566 | new_all_triples = {} 567 | anwser_key = "1" 568 | for k, v in all_triples.items(): 569 | if "The answer is" not in v: 570 | new_all_triples[k] = v 571 | else: 572 | anwser_key = k 573 | new_data["triples"] = new_all_triples 574 | new_data["rules"] = {} 575 | new_data["options"] = tmp_data["options"] 576 | new_data["question"] = tmp_data["question"] + " ".join(tmp_data["options"]) 577 | new_data["answer"] = tmp_data["answer"][0] 578 | all_edges = tmp_data["reasoning_graph_edges"] 579 | tmp_g = nx.DiGraph() 580 | input_nodes = set() 581 | for e in all_edges: 582 | for a in e["antecedents"]: 583 | if int(a) < int(anwser_key) and int(e["consequent"]) < int(anwser_key): 584 | tmp_g.add_edge(str(a), str(e["consequent"])) 585 | 
input_nodes.add(str(a)) 586 | input_nodes.add(str(e["consequent"])) 587 | proof = "" 588 | g_nodes = dag_longest_path(tmp_g) 589 | g_input = g_nodes[0] 590 | for n in input_nodes: 591 | if n not in g_nodes: 592 | g_input = n + " " + g_input 593 | g_input = "(" + g_input + ")" 594 | g_nodes[0] = g_input 595 | proof = " -> ".join(g_nodes) 596 | new_data["proof"] = proof 597 | q_depth = len(g_nodes) - 1 598 | if q_depth not in train_dict: train_dict[q_depth] = [] 599 | train_dict[q_depth].append(new_data) 600 | with open(tmp_path_dev, "r") as f: 601 | for line in f: 602 | tmp_data = json.loads(line) 603 | # add data 604 | new_data = {} 605 | new_data["context"] = tmp_data["context"] 606 | all_triples = {str(k):v for k, v in tmp_data["textual_logical_units"].items()} 607 | new_all_triples = {} 608 | anwser_key = "1" 609 | for k, v in all_triples.items(): 610 | if "The answer is" not in v: 611 | new_all_triples[k] = v 612 | else: 613 | anwser_key = k 614 | new_data["triples"] = new_all_triples 615 | new_data["rules"] = {} 616 | new_data["options"] = tmp_data["options"] 617 | new_data["question"] = tmp_data["question"] + " ".join(tmp_data["options"]) 618 | new_data["answer"] = tmp_data["answer"][0] 619 | all_edges = tmp_data["reasoning_graph_edges"] 620 | tmp_g = nx.DiGraph() 621 | input_nodes = set() 622 | for e in all_edges: 623 | for a in e["antecedents"]: 624 | if int(a) < int(anwser_key) and int(e["consequent"]) < int(anwser_key): 625 | tmp_g.add_edge(str(a), str(e["consequent"])) 626 | input_nodes.add(str(a)) 627 | input_nodes.add(str(e["consequent"])) 628 | proof = "" 629 | g_nodes = dag_longest_path(tmp_g) 630 | g_input = g_nodes[0] 631 | for n in input_nodes: 632 | if n not in g_nodes: 633 | g_input = n + " " + g_input 634 | g_input = "(" + g_input + ")" 635 | g_nodes[0] = g_input 636 | proof = " -> ".join(g_nodes) 637 | new_data["proof"] = proof 638 | q_depth = len(g_nodes) - 1 639 | if q_depth not in dev_dict: dev_dict[q_depth] = [] 640 | dev_dict[q_depth].append(new_data) 641 | with open(tmp_path_test, "r") as f: 642 | for line in f: 643 | tmp_data = json.loads(line) 644 | # add data 645 | new_data = {} 646 | new_data["context"] = tmp_data["context"] 647 | all_triples = {str(k):v for k, v in tmp_data["textual_logical_units"].items()} 648 | new_all_triples = {} 649 | anwser_key = "1" 650 | for k, v in all_triples.items(): 651 | if "The answer is" not in v: 652 | new_all_triples[k] = v 653 | else: 654 | anwser_key = k 655 | new_data["triples"] = new_all_triples 656 | new_data["rules"] = {} 657 | new_data["options"] = tmp_data["options"] 658 | new_data["question"] = tmp_data["question"] + " ".join(tmp_data["options"]) 659 | new_data["answer"] = tmp_data["answer"][0] 660 | all_edges = tmp_data["reasoning_graph_edges"] 661 | tmp_g = nx.DiGraph() 662 | input_nodes = set() 663 | for e in all_edges: 664 | for a in e["antecedents"]: 665 | if int(a) < int(anwser_key) and int(e["consequent"]) < int(anwser_key): 666 | tmp_g.add_edge(str(a), str(e["consequent"])) 667 | input_nodes.add(str(a)) 668 | input_nodes.add(str(e["consequent"])) 669 | proof = "" 670 | g_nodes = dag_longest_path(tmp_g) 671 | g_input = g_nodes[0] 672 | for n in input_nodes: 673 | if n not in g_nodes: 674 | g_input = n + " " + g_input 675 | g_input = "(" + g_input + ")" 676 | g_nodes[0] = g_input 677 | proof = " -> ".join(g_nodes) 678 | new_data["proof"] = proof 679 | q_depth = len(g_nodes) - 1 680 | if q_depth not in test_dict: test_dict[q_depth] = [] 681 | test_dict[q_depth].append(new_data) 682 | # save data 683 | with 
open(os.path.join(path, "train.json"), "w") as f: 684 | f.write(json.dumps(train_dict)) 685 | with open(os.path.join(path, "dev.json"), "w") as f: 686 | f.write(json.dumps(dev_dict)) 687 | with open(os.path.join(path, "test.json"), "w") as f: 688 | f.write(json.dumps(test_dict)) 689 | print("Training data number: ", {k:len(v) for k, v in train_dict.items()}) 690 | print("Dev data number: ", {k:len(v) for k, v in dev_dict.items()}) 691 | print("Test data number: ", {k:len(v) for k, v in test_dict.items()}) 692 | return train_dict, dev_dict, test_dict 693 | 694 | def load_data(self, path): 695 | with open(os.path.join(path, "train.json"), "r") as f: 696 | train_dict = json.loads(f.read()) 697 | with open(os.path.join(path, "dev.json"), "r") as f: 698 | dev_dict = json.loads(f.read()) 699 | with open(os.path.join(path, "test.json"), "r") as f: 700 | test_dict = json.loads(f.read()) 701 | print("Training data number: ", {k:len(v) for k, v in train_dict.items()}) 702 | print("Dev data number: ", {k:len(v) for k, v in dev_dict.items()}) 703 | print("Test data number: ", {k:len(v) for k, v in test_dict.items()}) 704 | return train_dict, dev_dict, test_dict 705 | 706 | def return_data(self): 707 | return self.train_dict, self.dev_dict, self.test_dict 708 | 709 | def return_shuffled_data(self): 710 | train_data, dev_data, test_data = [], [], [] 711 | for k, v in self.train_dict.items(): train_data = train_data + v 712 | for k, v in self.dev_dict.items(): dev_data = dev_data + v 713 | for k, v in self.test_dict.items(): test_data = test_data + v 714 | random.shuffle(train_data) 715 | random.shuffle(dev_data) 716 | random.shuffle(test_data) 717 | return train_data, dev_data, test_data -------------------------------------------------------------------------------- /tmp/attn/attn_analysis-all_min01.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifan-h/MechanisticProbe/aa0bfd9dc2e9e38b675db31bad035d735083cff9/tmp/attn/attn_analysis-all_min01.pdf -------------------------------------------------------------------------------- /tmp/attn/attn_analysis-all_org.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifan-h/MechanisticProbe/aa0bfd9dc2e9e38b675db31bad035d735083cff9/tmp/attn/attn_analysis-all_org.pdf -------------------------------------------------------------------------------- /tmp/attn_pos/attn_analysis-all_min01.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifan-h/MechanisticProbe/aa0bfd9dc2e9e38b675db31bad035d735083cff9/tmp/attn_pos/attn_analysis-all_min01.pdf -------------------------------------------------------------------------------- /tmp/attn_pos/attn_analysis-all_org.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifan-h/MechanisticProbe/aa0bfd9dc2e9e38b675db31bad035d735083cff9/tmp/attn_pos/attn_analysis-all_org.pdf -------------------------------------------------------------------------------- /tmp/gpt2_cs/cs_min01_8_12.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifan-h/MechanisticProbe/aa0bfd9dc2e9e38b675db31bad035d735083cff9/tmp/gpt2_cs/cs_min01_8_12.pdf -------------------------------------------------------------------------------- /tmp/gpt2_cs/cs_min01_all.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifan-h/MechanisticProbe/aa0bfd9dc2e9e38b675db31bad035d735083cff9/tmp/gpt2_cs/cs_min01_all.pdf -------------------------------------------------------------------------------- /tmp/prune/pruning_analysis-all_min01_pos.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifan-h/MechanisticProbe/aa0bfd9dc2e9e38b675db31bad035d735083cff9/tmp/prune/pruning_analysis-all_min01_pos.pdf -------------------------------------------------------------------------------- /tmp/prune/pruning_analysis-all_min01_size.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yifan-h/MechanisticProbe/aa0bfd9dc2e9e38b675db31bad035d735083cff9/tmp/prune/pruning_analysis-all_min01_size.pdf --------------------------------------------------------------------------------
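Usage note: the sketch below strings the pruning pipeline from src/stask/prune.py together by hand, the way pruning_analysis() does internally (the repo's own entry point is src/main.py via the run_*.sh scripts). The checkpoint path, the Namespace fields, the sys.path setup, and the literal arguments 128/True to evaluator are assumptions for illustration, copied from how acc_calculation invokes it, not documented values.

```
import sys
import torch
from argparse import Namespace
from transformers import AutoTokenizer, AutoModelForCausalLM

sys.path.append("src")                    # assumption: executed from the repo root
from utils import data_generator          # repo module (src/utils.py)
from stask.prune import head_calculation, get_head_mask
from gpt2 import evaluator                # repo module (src/gpt2.py)

# hypothetical argument values; main.py normally builds this object via argparse
args = Namespace(device="cuda" if torch.cuda.is_available() else "cpu",
                 tmp_dir="tmp", stask="min01", stask_set=512, stask_len=16,
                 sdata_num=20000, random_seed=42)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# hypothetical finetuned checkpoint produced by run_finetune_gpt2.sh
model = AutoModelForCausalLM.from_pretrained("tmp/training/gpt2_min01").to(args.device)

gen = data_generator(args, tokenizer)
_, _, test_data = gen.return_data(data=True)
_, _, test_labels = gen.return_data(data=False)
test_data, test_labels = test_data[:200], test_labels[:200]  # small slice to keep the sketch fast

# per-(layer, head) normalized entropy maps; an empty save_path skips the PDF plots
entropy_pos, entropy_size = head_calculation(args, model, tokenizer, test_data)

# binary keep-mask with 40% of the lowest-position-entropy heads removed
head_mask = get_head_mask(entropy_pos, 40, "min").to(args.device)
acc = evaluator(args, tokenizer, model, test_data, test_labels, 128, True, head_mask)
print("accuracy with 40% of position heads pruned:", round(acc, 6))
```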