├── exp_result
│   ├── gpt2-large_result_oursall.pt
│   └── final_gpt2-large_result_oursall.pt
├── README.md
├── interpret
│   ├── README.md
│   ├── mt_saliency.py
│   └── lm_saliency.py
├── TDD_step2.py
└── TDD_step1.py

/exp_result/gpt2-large_result_oursall.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zijian678/TDD/HEAD/exp_result/gpt2-large_result_oursall.pt
--------------------------------------------------------------------------------
/exp_result/final_gpt2-large_result_oursall.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zijian678/TDD/HEAD/exp_result/final_gpt2-large_result_oursall.pt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TDD
2 | **Unveiling and Manipulating Prompt Influence in Large Language Models (ICLR 2024)** [Link](https://openreview.net/forum?id=ap1ByuwQrX)
3 |
4 | TDD explores using **token distributions** to explain **autoregressive LLMs**.
5 | Our other work, PromptExplainer, explains **masked language models such as BERT and RoBERTa** using token distributions. Feel free to check out [PromptExplainer](https://github.com/zijian678/PromptExplainer)!
6 |
7 | ## Reproduce our results
8 | There are two steps to reproduce our results.
9 | * Step 1: Generate saliency scores using TDD_step1.py. You may choose different datasets and LLMs when generating the saliency scores.
10 | * Step 2: Evaluate with AOPC and Sufficiency using TDD_step2.py. It calculates the AOPC and Suff scores from the saliency scores produced in Step 1.
11 |
12 | Please use your own LLaMA access token when experimenting with LLaMA models.
13 |
14 | ## Acknowledgement
15 | The code for the contrastive explanation baselines is from [interpret-lm](https://github.com/kayoyin/interpret-lm). The dataset is from [BLiMP](https://github.com/alexwarstadt/blimp). We thank the authors for their excellent contributions!
16 |
17 | ## Citation
18 | If you find our work useful, please consider citing TDD:
19 | ```
20 | @inproceedings{feng2024tdd,
21 |   title={Unveiling and Manipulating Prompt Influence in Large Language Models},
22 |   author={Feng, Zijian and Zhou, Hanzhang and Zhu, Zixiao and Qian, Junlang and Mao, Kezhi},
23 |   booktitle={The Twelfth International Conference on Learning Representations},
24 |   year={2024}
25 | }
26 | ```
27 |
--------------------------------------------------------------------------------
/interpret/README.md:
--------------------------------------------------------------------------------
1 | # Interpreting Language Models with Contrastive Explanations
2 |
3 | Code supporting the paper [Interpreting Language Models with Contrastive Explanations](https://arxiv.org/abs/2202.10419)
4 |
5 | Currently supports:
6 | * Contrastive explanations for language models (GPT-2, GPT-Neo) [(Colab)](https://colab.research.google.com/drive/1L6VjQ9_XAlbkPENmJxMpCntR3X_grpih?usp=sharing)
7 | * Contrastive explanations for NMT models (MarianMT) [(Colab)](https://colab.research.google.com/drive/1rkSOGGxinVH_pzHxmswtt0ZDKmgf-sPL?usp=sharing)
8 |
9 | ## Requirements
10 | * [PyTorch](https://pytorch.org/get-started/locally/) >= 1.11.0
11 | * [SentencePiece](https://github.com/google/sentencepiece) >= 0.1.90
12 | * [Transformers](https://github.com/huggingface/transformers)
13 | * Python >= 3.6
14 |
15 | ## Examples
16 | ### 1. Load models
17 |
18 | LM:
19 | ```
20 | from transformers import GPT2Tokenizer, GPT2LMHeadModel
21 |
22 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
23 | model = GPT2LMHeadModel.from_pretrained("gpt2")
24 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25 | model.to(device)
26 | ```
27 |
28 | NMT:
29 | ```
30 | from transformers import MarianTokenizer, MarianMTModel
31 |
32 | model_name = f"Helsinki-NLP/opus-mt-en-fr"
33 | tokenizer = MarianTokenizer.from_pretrained(model_name)
34 | model = MarianMTModel.from_pretrained(model_name)
35 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36 | model.to(device)
37 | ```
38 |
39 | ### 2. Define inputs
40 |
41 | LM:
42 | ```
43 | input = "Can you stop the dog from "
44 | input_tokens = tokenizer(input)['input_ids']
45 | attention_ids = tokenizer(input)['attention_mask']
46 | ```
47 |
48 | NMT:
49 | ```
50 | encoder_input = "I can't find the seat, do you know where it is?"
51 | decoder_input = "Je ne trouve pas la place, tu sais où"
52 | decoder_input = f" {decoder_input.strip()} "
53 |
54 | input_ids = tokenizer(encoder_input, return_tensors="pt").input_ids.to(device)
55 | decoder_input_ids = tokenizer(decoder_input, return_tensors="pt", add_special_tokens=False,).input_ids.to(device)
56 | ```
57 |
58 | ### 3. Visualize explanations
59 |
60 | LM:
61 | ```
62 | from lm_saliency import *
63 |
64 | target = "barking"
65 | foil = "crying"
66 | CORRECT_ID = tokenizer(" "+ target)['input_ids'][0]
67 | FOIL_ID = tokenizer(" "+ foil)['input_ids'][0]
68 |
69 | base_saliency_matrix, base_embd_matrix = saliency(model, input_tokens, attention_ids)
70 | saliency_matrix, embd_matrix = saliency(model, input_tokens, attention_ids, foil=FOIL_ID)
71 |
72 | # Input x gradient
73 | base_explanation = input_x_gradient(base_saliency_matrix, base_embd_matrix, normalize=True)
74 | contra_explanation = input_x_gradient(saliency_matrix, embd_matrix, normalize=True)
75 |
76 | # Gradient norm
77 | base_explanation = l1_grad_norm(base_saliency_matrix, normalize=True)
78 | contra_explanation = l1_grad_norm(saliency_matrix, normalize=True)
79 |
80 | # Erasure
81 | base_explanation = erasure_scores(model, input_tokens, attention_ids, normalize=True)
82 | contra_explanation = erasure_scores(model, input_tokens, attention_ids, correct=CORRECT_ID, foil=FOIL_ID, normalize=True)
83 |
84 | visualize(np.array(base_explanation), tokenizer, [input_tokens], print_text=True, title=f"Why did the model predict {target}?")
85 | visualize(np.array(contra_explanation), tokenizer, [input_tokens], print_text=True, title=f"Why did the model predict {target} instead of {foil}?")
86 | ```
87 |
88 | NMT:
89 | ```
90 | from lm_saliency import visualize
91 | from mt_saliency import *
92 |
93 | target = "elle"
94 | foil = "il"
95 | CORRECT_ID = tokenizer(" "+ target)['input_ids'][0]
96 | FOIL_ID = tokenizer(" "+ foil)['input_ids'][0]
97 |
98 | base_enc_saliency, base_enc_embed, base_dec_saliency, base_dec_embed = saliency(model, input_ids, decoder_input_ids)
99 | enc_saliency, enc_embed, dec_saliency, dec_embed = saliency(model, input_ids, decoder_input_ids, foil=FOIL_ID)
100 |
101 | # Input x gradient
102 | base_enc_explanation = input_x_gradient(base_enc_saliency, base_enc_embed, normalize=False)
103 | base_dec_explanation = input_x_gradient(base_dec_saliency, base_dec_embed, normalize=False)
104 | enc_explanation = input_x_gradient(enc_saliency, enc_embed, normalize=False)
105 | dec_explanation = input_x_gradient(dec_saliency, dec_embed,
normalize=False) 106 | 107 | # Gradient norm 108 | base_enc_explanation = l1_grad_norm(base_enc_saliency, normalize=False) 109 | base_dec_explanation = l1_grad_norm(base_dec_saliency, normalize=False) 110 | enc_explanation = l1_grad_norm(enc_saliency, normalize=False) 111 | dec_explanation = l1_grad_norm(dec_saliency, normalize=False) 112 | 113 | # Erasure 114 | base_enc_explanation, base_dec_explanation = erasure_scores(model, input_ids, decoder_input_ids, correct=CORRECT_ID, normalize=False) 115 | enc_explanation, dec_explanation = erasure_scores(model, input_ids, decoder_input_ids, correct=CORRECT_ID, foil=FOIL_ID, normalize=False) 116 | 117 | # Normalize 118 | base_norm = np.linalg.norm(np.concatenate((base_enc_explanation, base_dec_explanation)), ord=1) 119 | base_enc_explanation /= base_norm 120 | base_dec_explanation /= base_norm 121 | norm = np.linalg.norm(np.concatenate((enc_explanation, dec_explanation)), ord=1) 122 | enc_explanation /= norm 123 | dec_explanation /= norm 124 | 125 | # Visualize 126 | visualize(base_enc_explanation, tokenizer, input_ids, print_text=True, title=f"Why did the model predict {target}? (encoder input)") 127 | visualize(base_dec_explanation, tokenizer, decoder_input_ids, print_text=True, title=f"Why did the model predict {target}? (decoder input)") 128 | visualize(enc_explanation, tokenizer, input_ids, print_text=True, title=f"Why did the model predict {target} instead of {foil}? (encoder input)") 129 | visualize(dec_explanation, tokenizer, decoder_input_ids, print_text=True, title=f"Why did the model predict {target} instead of {foil}? (decoder input)") 130 | ``` 131 | -------------------------------------------------------------------------------- /interpret/mt_saliency.py: -------------------------------------------------------------------------------- 1 | import argparse, json 2 | import random 3 | import torch 4 | import numpy as np 5 | from transformers import ( 6 | WEIGHTS_NAME, 7 | GPT2Config, 8 | GPT2Tokenizer, 9 | GPT2LMHeadModel, 10 | GPTNeoForCausalLM, 11 | ) 12 | 13 | import matplotlib as mpl 14 | import matplotlib.pyplot as plt 15 | 16 | plt.rcParams['figure.figsize'] = [10, 10] 17 | 18 | config = GPT2Config.from_pretrained("gpt2") 19 | VOCAB_SIZE = config.vocab_size 20 | 21 | def model_preds(model, input_ids, decoder_input_ids, tokenizer, k=10, verbose=False): 22 | softmax = torch.nn.Softmax(dim=0) 23 | A = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) 24 | probs = softmax(A.logits[0][-1]) 25 | top_preds = probs.topk(k) 26 | if verbose: 27 | print("Top model predictions:") 28 | for p,i in zip(top_preds.values, top_preds.indices): 29 | print(f"{np.round(p.item(), 3)}: {tokenizer.decode(i)}") 30 | return top_preds.indices 31 | 32 | # Adapted from AllenNLP Interpret and Han et al. 
2020 33 | 34 | def register_embedding_list_hook(model, embeddings_list): 35 | def forward_hook(module, inputs, output): 36 | embeddings_list.append(output.squeeze(0).clone().cpu().detach().numpy()) 37 | embedding_layer = model.model.encoder.embed_tokens 38 | handle = embedding_layer.register_forward_hook(forward_hook) 39 | return handle 40 | 41 | def register_embedding_gradient_hooks(model, embeddings_gradients): 42 | def hook_layers(module, grad_in, grad_out): 43 | embeddings_gradients.append(grad_out[0].detach().cpu().numpy()) 44 | embedding_layer = model.model.encoder.embed_tokens 45 | hook = embedding_layer.register_backward_hook(hook_layers) 46 | return hook 47 | 48 | def saliency(model, input_ids, decoder_input_ids, batch=0, correct=None, foil=None): 49 | torch.enable_grad() 50 | model.eval() 51 | embeddings_list = [] 52 | handle = register_embedding_list_hook(model, embeddings_list) 53 | embeddings_gradients = [] 54 | hook = register_embedding_gradient_hooks(model, embeddings_gradients) 55 | 56 | if correct is None: 57 | correct = input_ids[0][-1] 58 | input_ids = torch.tensor(input_ids, dtype=torch.long).to(model.device) 59 | decoder_input_ids = torch.tensor(decoder_input_ids, dtype=torch.long).to(model.device) 60 | 61 | model.zero_grad() 62 | A = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) 63 | 64 | if foil is not None: 65 | if correct == foil: 66 | (A.logits[0][-1][correct]).backward() 67 | else: 68 | (A.logits[0][-1][correct]-A.logits[0][-1][foil]).backward() 69 | else: 70 | (A.logits[0][-1][correct]).backward() 71 | handle.remove() 72 | hook.remove() 73 | 74 | dec_saliency, enc_saliency = embeddings_gradients 75 | enc_embed, dec_embed = embeddings_list 76 | return enc_saliency.squeeze(), enc_embed, dec_saliency.squeeze(), dec_embed 77 | 78 | def input_x_gradient(grads, embds, normalize=False): 79 | # same as LM saliency 80 | input_grad = np.sum(grads * embds, axis=-1).squeeze() 81 | 82 | if normalize: 83 | norm = np.linalg.norm(input_grad, ord=1) 84 | input_grad /= norm 85 | 86 | return input_grad 87 | 88 | def l1_grad_norm(grads, normalize=False): 89 | # same as LM saliency 90 | l1_grad = np.linalg.norm(grads, ord=1, axis=-1).squeeze() 91 | 92 | if normalize: 93 | norm = np.linalg.norm(l1_grad, ord=1) 94 | l1_grad /= norm 95 | return l1_grad 96 | 97 | 98 | def erasure_scores(model, input_ids, decoder_input_ids, correct=None, foil=None, normalize=False): 99 | model.eval() 100 | if correct is None: 101 | correct = input_ids[0][-1] 102 | input_ids = torch.tensor(input_ids, dtype=torch.long).to(model.device) 103 | decoder_input_ids = torch.tensor(decoder_input_ids, dtype=torch.long).to(model.device) 104 | 105 | A = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) 106 | softmax = torch.nn.Softmax(dim=0) 107 | logits = A.logits[0][-1] 108 | probs = softmax(logits) 109 | if foil is not None and correct != foil: 110 | base_score = (probs[correct]-probs[foil]).detach().cpu().numpy() 111 | else: 112 | base_score = (probs[correct]).detach().cpu().numpy() 113 | 114 | enc_scores = np.zeros(len(input_ids[0])) 115 | for i in range(len(input_ids[0])): 116 | input_ids_i = torch.cat((input_ids[0][:i], input_ids[0][i+1:])).unsqueeze(0) 117 | A = model(input_ids=input_ids_i, decoder_input_ids=decoder_input_ids) 118 | logits = A.logits[0][-1] 119 | probs = softmax(logits) 120 | if foil is not None and correct != foil: 121 | erased_score = (probs[correct]-probs[foil]).detach().cpu().numpy() 122 | else: 123 | erased_score = (probs[correct]).detach().cpu().numpy() 
124 | 125 | enc_scores[i] = base_score - erased_score # higher score = lower confidence in correct = more influential input 126 | 127 | dec_scores = np.zeros(len(decoder_input_ids[0])) 128 | for i in range(len(decoder_input_ids[0])): 129 | decoder_input_ids_i = torch.cat((decoder_input_ids[0][:i], decoder_input_ids[0][i+1:])).unsqueeze(0) 130 | A = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids_i) 131 | logits = A.logits[0][-1] 132 | probs = softmax(logits) 133 | if foil is not None and correct != foil: 134 | erased_score = (probs[correct]-probs[foil]).detach().cpu().numpy() 135 | else: 136 | erased_score = (probs[correct]).detach().cpu().numpy() 137 | 138 | dec_scores[i] = base_score - erased_score # higher score = lower confidence in correct = more influential input 139 | 140 | 141 | if normalize: 142 | norm = np.linalg.norm(enc_scores, ord=1) 143 | enc_scores /= norm 144 | norm = np.linalg.norm(dec_scores, ord=1) 145 | dec_scores /= norm 146 | 147 | return enc_scores, dec_scores 148 | 149 | 150 | def main(): 151 | pass 152 | 153 | if __name__ == "__main__": 154 | main() 155 | -------------------------------------------------------------------------------- /interpret/lm_saliency.py: -------------------------------------------------------------------------------- 1 | import argparse, json 2 | import random 3 | import torch 4 | import numpy as np 5 | from transformers import ( 6 | WEIGHTS_NAME, 7 | GPT2Config, 8 | GPT2Tokenizer, 9 | GPT2LMHeadModel, 10 | GPTNeoForCausalLM, 11 | 12 | ) 13 | 14 | import matplotlib as mpl 15 | import matplotlib.pyplot as plt 16 | 17 | plt.rcParams['figure.figsize'] = [10, 10] 18 | 19 | config = GPT2Config.from_pretrained("gpt2") 20 | VOCAB_SIZE = config.vocab_size 21 | 22 | 23 | def model_preds(model, input_ids, input_mask, pos, tokenizer, foils=None, k=10, verbose=False): 24 | # Obtain model's top predictions for given input 25 | input_ids = torch.tensor(input_ids, dtype=torch.long).to(model.device) 26 | input_mask = torch.tensor(input_mask, dtype=torch.long).to(model.device) 27 | softmax = torch.nn.Softmax(dim=0) 28 | A = model(input_ids[:, :pos], attention_mask=input_mask[:, :pos]) 29 | probs = softmax(A.logits[0][pos-1]) 30 | top_preds = probs.topk(k) 31 | if verbose: 32 | if foils: 33 | for foil in foils: 34 | print("Contrastive loss: ", A.logits[0][pos-1][input_ids[0, pos]] - A.logits[0][pos-1][foil]) 35 | print(f"{np.round(probs[foil].item(), 3)}: {tokenizer.decode(foil)}") 36 | print("Top model predictions:") 37 | for p,i in zip(top_preds.values, top_preds.indices): 38 | print(f"{np.round(p.item(), 3)}: {tokenizer.decode(i)}") 39 | return top_preds.indices 40 | 41 | # Adapted from AllenNLP Interpret and Han et al. 
2020 42 | def register_embedding_list_hook(model, embeddings_list,model_name = "gpt2-large"): 43 | def forward_hook(module, inputs, output): 44 | embeddings_list.append(output.squeeze(0).clone().cpu().detach().numpy()) 45 | 46 | if model_name == "gpt2-large" or model_name == "gpt-j-6B" or model_name == 'gpt2-medium': 47 | embedding_layer = model.transformer.wte 48 | elif model_name == "llama-7B": 49 | embedding_layer = model.model.embed_tokens 50 | elif model_name == "bloom-7b": 51 | embedding_layer = model.transformer.word_embeddings 52 | elif model_name == "Pythia-6.9b": 53 | embedding_layer = model.gpt_neox.embed_in 54 | elif model_name == "opt-6.7b": 55 | embedding_layer = model.model.decoder.embed_tokens 56 | 57 | handle = embedding_layer.register_forward_hook(forward_hook) 58 | return handle 59 | 60 | def register_embedding_gradient_hooks(model, embeddings_gradients,model_name = "gpt2-large"): 61 | def hook_layers(module, grad_in, grad_out): 62 | embeddings_gradients.append(grad_out[0].detach().cpu().numpy()) 63 | # embedding_layer = model.transformer.wte 64 | if model_name == "gpt2-large" or model_name == "gpt-j-6B" or model_name == 'gpt2-medium': 65 | embedding_layer = model.transformer.wte 66 | elif model_name == "llama-7B": 67 | embedding_layer = model.model.embed_tokens 68 | elif model_name == "bloom-7b": 69 | embedding_layer = model.transformer.word_embeddings 70 | elif model_name == "Pythia-6.9b": 71 | embedding_layer = model.gpt_neox.embed_in 72 | elif model_name == "opt-6.7b": 73 | embedding_layer = model.model.decoder.embed_tokens 74 | 75 | 76 | hook = embedding_layer.register_backward_hook(hook_layers) 77 | return hook 78 | 79 | def saliency(model, model_name,input_ids, input_mask, batch=0, correct=None, foil=None): 80 | # Get model gradients and input embeddings 81 | torch.enable_grad() 82 | model.eval() 83 | embeddings_list = [] 84 | handle = register_embedding_list_hook(model, embeddings_list,model_name) 85 | embeddings_gradients = [] 86 | hook = register_embedding_gradient_hooks(model, embeddings_gradients,model_name) 87 | 88 | if correct is None: 89 | correct = input_ids[-1] 90 | input_ids = input_ids[:-1] 91 | input_mask = input_mask[:-1] 92 | # if correct is None: 93 | # correct = input_ids 94 | # input_ids = input_ids 95 | # input_mask = input_mask 96 | 97 | input_ids = torch.tensor(input_ids, dtype=torch.long).to(model.device) 98 | input_mask = torch.tensor(input_mask, dtype=torch.long).to(model.device) 99 | 100 | model.zero_grad() 101 | # print('input_ids:',input_ids.shape,input_ids) 102 | 103 | # revise ### 104 | if model_name == "gpt2-large" or model_name == "gpt-j-6B": 105 | A = model(input_ids, attention_mask=input_mask) 106 | # print('A.logits:', A.logits.shape) 107 | if foil is not None and correct != foil: 108 | (A.logits[-1][correct] - A.logits[-1][foil]).backward() 109 | else: 110 | (A.logits[-1][correct]).backward() 111 | handle.remove() 112 | hook.remove() 113 | 114 | out1 = np.array(embeddings_gradients).squeeze() 115 | out2 = np.array(embeddings_list).squeeze() 116 | # print('out1:',out1.shape) 117 | # print('out2:',out2.shape) 118 | else: 119 | input_ids = input_ids.unsqueeze(0) 120 | input_mask = input_mask.unsqueeze(0) 121 | A = model(input_ids, attention_mask=input_mask) 122 | # print('A.logits:',A.logits.shape) 123 | if foil is not None and correct != foil: 124 | (A.logits[0][-1][correct] - A.logits[0][-1][foil]).backward() 125 | else: 126 | (A.logits[0][-1][correct]).backward() 127 | handle.remove() 128 | hook.remove() 129 | 130 | out1 = 
np.array(embeddings_gradients).squeeze() 131 | out2 = np.array(embeddings_list).squeeze() 132 | # print('out1:', out1.shape) 133 | # print('out2:',out2.shape) 134 | 135 | 136 | 137 | 138 | 139 | return out1,out2 140 | 141 | def input_x_gradient(grads, embds, normalize=False): 142 | # print('grads:',grads.shape,'embds:',embds.shape) 143 | if grads.shape[0] == 1: 144 | input_grad = np.sum(grads * embds, axis=-1) 145 | else: 146 | input_grad = np.sum(grads * embds, axis=-1).squeeze() 147 | 148 | 149 | if normalize: 150 | 151 | 152 | if input_grad.shape == (): 153 | input_grad = np.array([1.0]) 154 | else: 155 | norm = np.linalg.norm(input_grad, ord=1) 156 | input_grad /= norm 157 | # print('input_grad:', input_grad.shape, input_grad) 158 | 159 | return input_grad 160 | 161 | def l1_grad_norm(grads, normalize=False): 162 | if grads.shape[0] == 1: 163 | l1_grad = np.linalg.norm(grads, ord=1, axis=-1) 164 | else: 165 | l1_grad = np.linalg.norm(grads, ord=1, axis=-1).squeeze() 166 | 167 | 168 | if normalize: 169 | if l1_grad.shape == (): 170 | l1_grad = np.array([1.0]) 171 | else: 172 | norm = np.linalg.norm(l1_grad, ord=1) 173 | l1_grad /= norm 174 | 175 | return l1_grad 176 | def erasure_scores(model, input_ids, input_mask, correct=None, foil=None, remove=False, normalize=False): 177 | model.eval() 178 | if correct is None: 179 | correct = input_ids[-1] 180 | input_ids = input_ids[:-1] 181 | input_mask = input_mask[:-1] 182 | input_ids = torch.unsqueeze(torch.tensor(input_ids, dtype=torch.long).to(model.device), 0) 183 | input_mask = torch.unsqueeze(torch.tensor(input_mask, dtype=torch.long).to(model.device), 0) 184 | 185 | A = model(input_ids, attention_mask=input_mask) 186 | softmax = torch.nn.Softmax(dim=0) 187 | logits = A.logits[0][-1] 188 | probs = softmax(logits) 189 | if foil is not None and correct != foil: 190 | base_score = (probs[correct]-probs[foil]).detach().cpu().numpy() 191 | else: 192 | base_score = (probs[correct]).detach().cpu().numpy() 193 | 194 | scores = np.zeros(len(input_ids[0])) 195 | for i in range(len(input_ids[0])): 196 | if remove: 197 | input_ids_i = torch.cat((input_ids[0][:i], input_ids[0][i+1:])).unsqueeze(0) 198 | input_mask_i = torch.cat((input_mask[0][:i], input_mask[0][i+1:])).unsqueeze(0) 199 | else: 200 | input_ids_i = torch.clone(input_ids) 201 | input_mask_i = torch.clone(input_mask) 202 | input_mask_i[0][i] = 0 203 | 204 | A = model(input_ids_i, attention_mask=input_mask_i) 205 | logits = A.logits[0][-1] 206 | probs = softmax(logits) 207 | if foil is not None and correct != foil: 208 | erased_score = (probs[correct]-probs[foil]).detach().cpu().numpy() 209 | else: 210 | erased_score = (probs[correct]).detach().cpu().numpy() 211 | 212 | scores[i] = base_score - erased_score # higher score = lower confidence in correct = more influential input 213 | # print('scores:',scores.shape,scores) 214 | if scores.shape[0] == 1: 215 | scores[0] = 1 216 | if normalize: 217 | norm = np.linalg.norm(scores, ord=1) 218 | scores /= norm 219 | return scores 220 | 221 | def visualize(attention, tokenizer, input_ids, gold=None, normalize=False, print_text=True, save_file=None, title=None, figsize=60, fontsize=36): 222 | tokens = [tokenizer.decode(i) for i in input_ids[0][:len(attention) + 1]] 223 | if gold is not None: 224 | for i, g in enumerate(gold): 225 | if g == 1: 226 | tokens[i] = "**" + tokens[i] + "**" 227 | 228 | # Normalize to [-1, 1] 229 | if normalize: 230 | a,b = min(attention), max(attention) 231 | x = 2/(b-a) 232 | y = 1-b*x 233 | attention = [g*x + y for g in 
attention] 234 | attention = np.array([list(map(float, attention))]) 235 | 236 | fig, ax = plt.subplots(figsize=(figsize,figsize)) 237 | norm = mpl.colors.Normalize(vmin=-1, vmax=1) 238 | im = ax.imshow(attention, cmap='seismic', norm=norm) 239 | 240 | if print_text: 241 | ax.set_xticks(np.arange(len(tokens))) 242 | ax.set_xticklabels(tokens, fontsize=fontsize) 243 | else: 244 | ax.get_xaxis().set_visible(False) 245 | ax.get_yaxis().set_visible(False) 246 | 247 | 248 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", 249 | rotation_mode="anchor") 250 | for (i, j), z in np.ndenumerate(attention): 251 | ax.text(j, i, '{:0.2f}'.format(z), ha='center', va='center', fontsize=fontsize) 252 | 253 | 254 | ax.set_title("") 255 | fig.tight_layout() 256 | if title is not None: 257 | plt.title(title, fontsize=36) 258 | 259 | if save_file is not None: 260 | plt.savefig(save_file, bbox_inches = 'tight', 261 | pad_inches = 0) 262 | plt.close() 263 | else: 264 | plt.show() 265 | 266 | def main(): 267 | pass 268 | 269 | if __name__ == "__main__": 270 | main() 271 | -------------------------------------------------------------------------------- /TDD_step2.py: -------------------------------------------------------------------------------- 1 | import os 2 | # import os 3 | 4 | from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPTJForCausalLM, AutoTokenizer, AutoModelForCausalLM,BloomTokenizerFast, BloomForCausalLM, GPTNeoXForCausalLM 5 | import torch 6 | import json 7 | import numpy as np 8 | from tqdm import tqdm 9 | from interpret.lm_saliency import * 10 | # from experiment import flip 11 | import random 12 | import torch.nn.functional as F 13 | 14 | # input parameters 15 | model_name = 'gpt2-large' # gpt2-large (36 layers), gpt-j-6B (28 layers), llama-7B (32 layers), bloom-7b (30 layerss), Pythia-6.9b 32 layers 16 | flip_case = 'generate' # pruning generate ---- generate is the activation task 17 | method = "ours" # ours # sota # erasure # rollout 18 | if method == 'sota': 19 | result = {'IG_base': {}, 'IG_con': {}, 'GN_base': {}, 'GN_con': {}} 20 | if method == "ours": 21 | result = {'ours': {},'ours_back':{},'ours_add':{}} 22 | if method == "erasure": 23 | result = {'erasure': {}} 24 | if method == "rollout": 25 | result = {'rollout': {}} 26 | 27 | # model_name = "Pythia-6.9b" 28 | # data_name = ["anaphor_gender_agreement",'anaphor_number_agreement','animate_subject_passive', 29 | # 'determiner_noun_agreement_1','determiner_noun_agreement_irregular_1', 30 | # 'determiner_noun_agreement_with_adjective_1','determiner_noun_agreement_with_adj_irregular_1', 31 | # 'npi_present_1','distractor_agreement_relational_noun','irregular_plural_subject_verb_agreement_1', 32 | # 'regular_plural_subject_verb_agreement_1'] 33 | data_name = ['anaphor_gender_agreement','anaphor_number_agreement'] 34 | save_name_post = "all" 35 | 36 | ############################################### 37 | 38 | res_path = './exp_result/' + model_name + '_result_'+method + save_name_post + '.pt' 39 | save_path = './exp_result/' + 'final_' + model_name + '_result_'+method + save_name_post + '.pt' 40 | print('save_path:',save_path) 41 | exp_all = torch.load(res_path) 42 | 43 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 44 | layer_norm = None 45 | if model_name == "gpt2-large": 46 | tokenizer = GPT2Tokenizer.from_pretrained(model_name) 47 | model = GPT2LMHeadModel.from_pretrained(model_name) 48 | lm_head = model.lm_head 49 | layer_norm = model.transformer.ln_f 50 | save_name_prefix = 'exp2_gpt2large_' 51 
| 52 | if model_name == "gpt-j-6B": 53 | 54 | model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.float16,resume_download=True) 55 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") 56 | lm_head = model.lm_head 57 | # layer_norm = model.transformer.ln_f 58 | layer_norm = model.transformer.ln_f 59 | save_name_prefix = 'exp2_gptj' 60 | 61 | if model_name == 'llama-7B': 62 | model_path = "meta-llama/Llama-2-7b-hf" 63 | 64 | tokenizer = AutoTokenizer.from_pretrained(model_path, token=" ") 65 | model = AutoModelForCausalLM.from_pretrained(model_path, token=" ", 66 | torch_dtype=torch.float16) # load_in_4bit=True 67 | lm_head = model.lm_head 68 | layer_norm = model.model.norm 69 | save_name_prefix = 'exp2_llama7B' 70 | 71 | if model_name == "bloom-7b": 72 | model_path = 'bigscience/bloom-7b1' 73 | model = BloomForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16,resume_download=True) 74 | tokenizer = BloomTokenizerFast.from_pretrained(model_path) 75 | lm_head = model.lm_head 76 | layer_norm = model.transformer.ln_f 77 | save_name_prefix = 'exp2_bloom7b' 78 | 79 | if model_name == "opt-6.7b": 80 | # no 81 | model_path = "facebook/opt-6.7b" 82 | model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16,resume_download=True) 83 | tokenizer = AutoTokenizer.from_pretrained(model_path) 84 | lm_head = model.lm_head 85 | # layer_norm = model.decoder. 86 | save_name_prefix = 'exp2_opt6b' 87 | 88 | if model_name == 'Pythia-6.9b': 89 | model_path = "EleutherAI/pythia-6.9b-deduped-v0" 90 | model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, resume_download=True) 91 | tokenizer = AutoTokenizer.from_pretrained(model_path) 92 | lm_head = model.embed_out 93 | layer_norm = model.gpt_neox.final_layer_norm 94 | save_name_prefix = 'exp2_pythia6b' 95 | 96 | # print(list(model.modules())) 97 | if model_name == "gpt2-large": 98 | model.to(device) 99 | 100 | 101 | def flip(model, x, token_ids, tokens, target_ids, fracs, flip_case,random_order = False, tokenizer=None, device='cuda',loss_ids = None): 102 | 103 | x = np.array(x) 104 | # y_true = y_true.squeeze(0) 105 | # print('x:',x.shape,x) 106 | # print('token_ids:',token_ids.shape,token_ids) 107 | # print('tokens input:',tokenizer.convert_ids_to_tokens(token_ids[0])) 108 | # # print('y_true:',y_true) 109 | # print('fracs:',fracs) 110 | # # print('flip_case:',flip_case) 111 | 112 | if model_name == "llama-7B": 113 | UNK_IDX = tokenizer.encode(' ')[1] 114 | else: 115 | UNK_IDX = tokenizer.encode(' ')[0] 116 | 117 | 118 | # print('UNK_IDX:',UNK_IDX,'***',tokenizer.encode(' '),'convert:',tokenizer.convert_tokens_to_ids(' ')) 119 | inputs0 = torch.tensor(token_ids).to(device) 120 | model = model.to(device) 121 | 122 | model_input = {"input_ids":inputs0.long(),"loss_ids":loss_ids,"meta": None} 123 | 124 | # y0 = model.forward(model_input)[0].squeeze().detach().cpu().numpy() 125 | # with torch.no_grad(): 126 | # y0 = model(model_input, output_hidden_states=True) 127 | # # print('y0:',y0) 128 | # orig_token_ids = np.copy(token_ids.detach().cpu().numpy()) 129 | 130 | if random_order==False: 131 | inds_sorted = np.argsort(x)[::-1] 132 | # print('inds_sorted:',inds_sorted) 133 | # if flip_case=='generate': 134 | # inds_sorted = np.argsort(x)[::-1] 135 | # elif flip_case=='pruning': 136 | # inds_sorted = np.argsort(np.abs(x)) 137 | # else: 138 | # raise 139 | else: 140 | 141 | inds_ = np.array(list(range(x.shape[-1]))) 142 | remain_inds = np.array(inds_) 143 | 
np.random.shuffle(remain_inds) 144 | 145 | inds_sorted = remain_inds 146 | 147 | inds_sorted = inds_sorted.copy() 148 | # print('inds_sorted:',inds_sorted) 149 | # vals = x[inds_sorted] 150 | 151 | mse = [] 152 | evidence = [] 153 | # model_outs = {'sentence': tokens, 'y_true':y_true.detach().cpu().numpy(), 'y0':y0} 154 | 155 | # print('x shpape:',x.shape) 156 | 157 | N=len(x) 158 | 159 | evolution = {} 160 | for frac in fracs: 161 | inds_generator = iter(inds_sorted) 162 | n_flip=int(np.ceil(frac*N)) 163 | inds_flip = [next(inds_generator) for i in range(n_flip)] 164 | 165 | if flip_case == 'pruning': 166 | 167 | inputs = inputs0 168 | for i in inds_flip: 169 | inputs[:,i] = UNK_IDX 170 | 171 | elif flip_case == 'generate': 172 | inputs = UNK_IDX*torch.ones_like(inputs0) 173 | # Set pad tokens 174 | inputs[inputs0==0] = 0 175 | 176 | for i in inds_flip: 177 | inputs[:,i] = inputs0[:,i] 178 | # print('original inputs:', inputs) 179 | # inputs[:,1] = 83 180 | # inputs[:, 2] = 50264 181 | # inputs[:, 3] = 340 182 | # inputs[:, 4] = 4832 183 | 184 | # print('inputs:',inputs) 185 | # print('tokens:',tokenizer.convert_ids_to_tokens(inputs.squeeze())) 186 | # print('tokens:',tokenizer.decode(inputs.squeeze())) 187 | 188 | # model_input = {"input_ids":inputs.long(),"loss_ids":loss_ids,"meta": None} 189 | model_input = inputs.long() 190 | 191 | # y = model(inputs, labels = torch.tensor([y_true]*len(token_ids)).long().to(device))['logits'].detach().cpu().numpy() 192 | # y = model.forward(model_input)[0].squeeze().detach().cpu().numpy() 193 | y = model(model_input, output_hidden_states=True).logits[0] 194 | y = y[-1] 195 | 196 | if model_name == "llama-7B": 197 | target = target_ids[0] 198 | foil = target_ids[1] 199 | # CORRECT_ID = tokenizer(target)['input_ids'] 200 | # print('correct id',CORRECT_ID,tokenizer.convert_ids_to_tokens(CORRECT_ID)) 201 | CORRECT_ID = tokenizer(target)['input_ids'][1] 202 | FOIL_ID = tokenizer(foil)['input_ids'][1] 203 | else: 204 | CORRECT_ID = tokenizer(" " + target_ids[0])['input_ids'][0] 205 | FOIL_ID = tokenizer(" " + target_ids[1])['input_ids'][0] 206 | 207 | 208 | 209 | 210 | # print('CORRECT_ID:',CORRECT_ID,tokenizer.decode(CORRECT_ID),'FOIL_ID:',FOIL_ID,tokenizer.decode(FOIL_ID)) 211 | 212 | probs = [float(y[CORRECT_ID]),float(y[FOIL_ID])] 213 | # print('original probs:',probs) 214 | probs = torch.tensor(probs) 215 | probs = torch.nn.functional.softmax(probs,dim = -1) 216 | # print('final probs:',probs) 217 | 218 | y = probs[0] 219 | 220 | # err = np.sum((y0-y)**2) 221 | # mse.append(err) 222 | # evidence.append(softmax(y)[int(y_true)]) 223 | 224 | # print('{:0.2f}'.format(frac), ' '.join(tokenizer.convert_ids_to_tokens(inputs.detach().cpu().numpy().squeeze()))) 225 | evolution[frac] = (inputs.detach().cpu().numpy(), y) 226 | 227 | # if flip_case == 'generate' and frac == 1.: 228 | # assert (inputs0 == inputs).all() 229 | # 230 | # 231 | # model_outs['flip_evolution'] = evolution 232 | return evolution 233 | 234 | # method = "sota" # ours # sota 235 | if method == 'sota': 236 | res = {'IG_base': {}, 'IG_con': {}, 'GN_base': {}, 'GN_con': {}} 237 | if method == "ours": 238 | res = {'ours': {},'ours_back':{},'ours_add':{}} 239 | 240 | if method == "erasure": 241 | res = {'erasure': {}} 242 | if method == "rollout": 243 | res = {'rollout': {}} 244 | 245 | 246 | for nnn in data_name: 247 | for kkk in result.keys(): 248 | res[kkk][nnn] = [] 249 | print('initial result:',result) 250 | 251 | for name in data_name: 252 | dataset = './data/' + name + '.jsonl' 253 | 
print('current dataset:',dataset) 254 | all_sample = [] 255 | with open(dataset, 'r', encoding='utf-8') as f: 256 | for line in f: 257 | json_object = json.loads(line) 258 | all_sample.append(json_object) 259 | 260 | for kkk in exp_all.keys(): 261 | exps = exp_all[kkk][name] 262 | 263 | for idx in tqdm(range(len(exps))): 264 | sample = all_sample[idx] 265 | 266 | 267 | input = sample["one_prefix_prefix"].strip() 268 | # print('input:', input) 269 | sample_length = len(input.split(' ')) 270 | if sample_length < 2: 271 | continue 272 | 273 | target = sample["one_prefix_word_good"] 274 | foil = sample['one_prefix_word_bad'] 275 | explanation = exps[idx] 276 | explanation = explanation.squeeze() 277 | token_ids = tokenizer(input, return_tensors="pt").input_ids 278 | target_ids = [target, foil] 279 | fracs = np.linspace(0, 1, 6) 280 | # print('fracs:', fracs) 281 | 282 | evolution = flip(model, 283 | x=explanation, 284 | token_ids=token_ids, 285 | tokens=input, 286 | target_ids=target_ids, 287 | fracs=fracs, 288 | flip_case=flip_case, 289 | random_order=False, 290 | tokenizer=tokenizer, ) 291 | res[kkk][name].append(evolution) 292 | 293 | torch.save(res,save_path) 294 | for k2 in res.keys(): 295 | print('current method:',k2) 296 | 297 | current_data = res[k2] 298 | for k5 in current_data.keys(): 299 | print('current dataset:',k5) 300 | current_ = res[k2][k5] 301 | all_acc = [] 302 | for ww in current_: 303 | sample_acc = [] 304 | for lll in ww.keys(): 305 | # print('ww[lll]:',ww[lll]) 306 | sample_acc.append(ww[lll][1]) 307 | all_acc.append(sample_acc) 308 | all_acc = np.array(all_acc) 309 | print('accuracy:', np.mean(all_acc)) 310 | 311 | 312 | # out_name = 'acc_' + k2 + '.pt' 313 | # torch.save(res,out_name) -------------------------------------------------------------------------------- /TDD_step1.py: -------------------------------------------------------------------------------- 1 | import os 2 | # import os 3 | 4 | from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPTJForCausalLM, AutoTokenizer, AutoModelForCausalLM,BloomTokenizerFast, BloomForCausalLM, GPTNeoXForCausalLM 5 | import torch 6 | import json 7 | import numpy as np 8 | from tqdm import tqdm 9 | from interpret.lm_saliency import * 10 | # from experiment import flip 11 | import random 12 | import torch.nn.functional as F 13 | 14 | # input parameters 15 | model_name = 'gpt2-large' # gpt2-large (36 layers), gpt-j-6B (28 layers), llama-7B (32 layers), bloom-7b (30 layerss), Pythia-6.9b 32 layers 16 | 17 | method = "ours" # ours # sota #erasure # rollout 18 | if method == 'sota': 19 | result = {'IG_base': {}, 'IG_con': {}, 'GN_base': {}, 'GN_con': {}} 20 | if method == "ours": 21 | result = {'ours': {},'ours_back':{},'ours_add':{}} 22 | if method == "rollout": 23 | result = {'rollout': {}} 24 | 25 | 26 | # model_name = "Pythia-6.9b" 27 | data_name = ["anaphor_gender_agreement",'anaphor_number_agreement','animate_subject_passive', 28 | 'determiner_noun_agreement_1','determiner_noun_agreement_irregular_1', 29 | 'determiner_noun_agreement_with_adjective_1','determiner_noun_agreement_with_adj_irregular_1', 30 | 'npi_present_1','distractor_agreement_relational_noun','irregular_plural_subject_verb_agreement_1', 31 | 'regular_plural_subject_verb_agreement_1'] 32 | data_name = ['anaphor_gender_agreement','anaphor_number_agreement'] 33 | save_name_post = "all" 34 | 35 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 36 | layer_norm = None 37 | 38 | for nnn in data_name: 39 | for kkk in result.keys(): 40 | 
result[kkk][nnn] = [] 41 | print('initial result:',result) 42 | 43 | # load models 44 | if model_name == "gpt2-large": 45 | tokenizer = GPT2Tokenizer.from_pretrained(model_name) 46 | model = GPT2LMHeadModel.from_pretrained(model_name) 47 | lm_head = model.lm_head 48 | layer_norm = model.transformer.ln_f 49 | save_name_prefix = 'exp2_gpt2large_' 50 | 51 | if model_name == "gpt-j-6B": 52 | 53 | model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", load_in_8bit=True,resume_download=True) 54 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") 55 | lm_head = model.lm_head 56 | # layer_norm = model.transformer.ln_f 57 | layer_norm = model.transformer.ln_f 58 | save_name_prefix = 'exp2_gptj' 59 | 60 | if model_name == 'llama-7B': 61 | model_path = "meta-llama/Llama-2-7b-hf" 62 | 63 | tokenizer = AutoTokenizer.from_pretrained(model_path, token=" ") 64 | model = AutoModelForCausalLM.from_pretrained(model_path, token=" ", 65 | load_in_8bit=True) # load_in_4bit=True 66 | lm_head = model.lm_head 67 | layer_norm = model.model.norm 68 | save_name_prefix = 'exp2_llama7B' 69 | 70 | if model_name == "bloom-7b": 71 | model_path = 'bigscience/bloom-7b1' 72 | model = BloomForCausalLM.from_pretrained(model_path, load_in_8bit=True,resume_download=True) 73 | tokenizer = BloomTokenizerFast.from_pretrained(model_path) 74 | lm_head = model.lm_head 75 | layer_norm = model.transformer.ln_f 76 | save_name_prefix = 'exp2_bloom7b' 77 | 78 | if model_name == "opt-6.7b": 79 | # no 80 | model_path = "facebook/opt-6.7b" 81 | model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16,resume_download=True) 82 | tokenizer = AutoTokenizer.from_pretrained(model_path) 83 | lm_head = model.lm_head 84 | # layer_norm = model.decoder. 85 | save_name_prefix = 'exp2_opt6b' 86 | 87 | if model_name == 'Pythia-6.9b': 88 | model_path = "EleutherAI/pythia-6.9b-deduped-v0" 89 | model = AutoModelForCausalLM.from_pretrained(model_path, load_in_8bit=True, resume_download=True) 90 | tokenizer = AutoTokenizer.from_pretrained(model_path) 91 | lm_head = model.embed_out 92 | layer_norm = model.gpt_neox.final_layer_norm 93 | save_name_prefix = 'exp2_pythia6b' 94 | 95 | # print(list(model.modules())) 96 | if model_name == "gpt2-large": 97 | model.to(device) 98 | 99 | 100 | 101 | 102 | # load dataset 103 | 104 | def attention_rollout(attentions): 105 | # Initialize with identity matrix 106 | seq_len = attentions[0].size(2) # Assuming attentions shape is (batch_size, num_heads, seq_len, seq_len) 107 | rollout_attention = torch.eye(seq_len).unsqueeze(0).unsqueeze(1) # shape: (batch_size, 1, seq_len, seq_len) 108 | rollout_attention = rollout_attention.to('cuda') 109 | rollout_attention = rollout_attention.float() 110 | 111 | # Iterate over attention weights from top to bottom 112 | for layer_attention in reversed(attentions): 113 | # Average attention weights across heads 114 | avg_attention = layer_attention.mean(dim=1, keepdim=True) # shape: (batch_size, 1, seq_len, seq_len) 115 | avg_attention = avg_attention.float() 116 | 117 | # Multiply the rolled-out attention so far with the current layer's attention 118 | # print('avg_attention:',avg_attention) 119 | # print('rollout_attention:',rollout_attention) 120 | rollout_attention = torch.bmm(avg_attention.squeeze(1), rollout_attention.squeeze(1)).unsqueeze(1) 121 | 122 | return rollout_attention 123 | 124 | for name in data_name: 125 | dataset = './data/' + name + '.jsonl' 126 | print('current dataset:',dataset) 127 | all_sample = [] 128 | with 
open(dataset, 'r', encoding='utf-8') as f: 129 | for line in f: 130 | json_object = json.loads(line) 131 | all_sample.append(json_object) 132 | 133 | for ss in tqdm(all_sample): 134 | sen_input = ss["one_prefix_prefix"].strip() 135 | 136 | # calculate length 137 | sen_len = len('sen_input split') 138 | 139 | if model_name == "llama-7B": 140 | if method == "sota" or method == "erasure": 141 | sen_input = sen_input + ' ' 142 | target = ss["one_prefix_word_good"] 143 | foil = ss['one_prefix_word_bad'] 144 | # CORRECT_ID = tokenizer(target)['input_ids'] 145 | # print('correct id',CORRECT_ID,tokenizer.convert_ids_to_tokens(CORRECT_ID)) 146 | CORRECT_ID = tokenizer(target)['input_ids'][1] 147 | FOIL_ID = tokenizer(foil)['input_ids'][1] 148 | # print('target', target, 'foil', foil, 'CORRECT_ID:', CORRECT_ID, 'decoded:', tokenizer.decode(CORRECT_ID), 149 | # 'FOIL_ID:', FOIL_ID, tokenizer.decode(FOIL_ID)) 150 | else: 151 | if method == "sota" or method == "erasure": 152 | sen_input = sen_input + ' ' 153 | target = ss["one_prefix_word_good"] 154 | foil = ss['one_prefix_word_bad'] 155 | CORRECT_ID = tokenizer(" " + target)['input_ids'][0] 156 | FOIL_ID = tokenizer(" " + foil)['input_ids'][0] 157 | # print('target', target, 'foil', foil, 'CORRECT_ID:', CORRECT_ID, 'decoded:', tokenizer.decode(CORRECT_ID), 158 | # 'FOIL_ID:', FOIL_ID, tokenizer.decode(FOIL_ID)) 159 | 160 | 161 | 162 | if method == "rollout": 163 | input_ids = tokenizer(sen_input, return_tensors="pt").input_ids 164 | input_ids = input_ids.to('cuda') 165 | with torch.no_grad(): 166 | outputs = model(input_ids, output_hidden_states=True, output_attentions=True) 167 | hidden_states = outputs.hidden_states 168 | # target_hidden = hidden_states[k+1] 169 | 170 | attentions = outputs.attentions 171 | rollout = attention_rollout(attentions) 172 | rollout = rollout.cpu() 173 | rollout = np.array(rollout) 174 | final_rollout = rollout[0][0][-1] 175 | # print('final_rollout:',final_rollout.shape,final_rollout) 176 | result['rollout'][name].append(final_rollout) 177 | 178 | elif method == "sota": 179 | input_tokens = tokenizer(sen_input)['input_ids'] 180 | # print('input_tokens:', input_tokens) 181 | # print(tokenizer.convert_ids_to_tokens(input_tokens)) 182 | attention_ids = tokenizer(sen_input)['attention_mask'] 183 | base_saliency_matrix, base_embd_matrix = saliency(model, model_name,input_tokens, attention_ids) 184 | # print('base_saliency_matrix:',len(base_saliency_matrix)) 185 | # print('base_embd_matrix:',len(base_embd_matrix)) 186 | saliency_matrix, embd_matrix = saliency(model, model_name,input_tokens, attention_ids, foil=FOIL_ID) 187 | base_explanation = input_x_gradient(base_saliency_matrix, base_embd_matrix, normalize=True) 188 | contra_explanation = input_x_gradient(saliency_matrix, embd_matrix, normalize=True) 189 | # print('base_explanation:', base_explanation) 190 | (result['IG_base'])[name].append(base_explanation) 191 | result['IG_con'][name].append(contra_explanation) 192 | 193 | base_explanation = l1_grad_norm(base_saliency_matrix, normalize=True) 194 | contra_explanation = l1_grad_norm(saliency_matrix, normalize=True) 195 | result['GN_base'][name].append(base_explanation) 196 | result['GN_con'][name].append(contra_explanation) 197 | 198 | elif method == 'ours': 199 | 200 | input_ids = tokenizer(sen_input, return_tensors="pt").input_ids 201 | input_ids = input_ids.to('cuda') 202 | # print('input_ids:',input_ids.shape,input_ids) 203 | # print(tokenizer.decode(input_ids[0])) 204 | 205 | original_input = input_ids 206 | 207 | 
all_exp = [] 208 | repeat_num = 2 209 | 210 | # repeat_num = len(input_ids.shape[-1]) 211 | 212 | for rn in range(repeat_num): 213 | if rn == 0: 214 | input_ids = original_input 215 | permuted_indices = None 216 | 217 | input_ids = input_ids.to('cuda') 218 | # print('********** rn:',rn) 219 | # print('input_ids:', input_ids.shape, input_ids) 220 | # print(tokenizer.decode(input_ids[0])) 221 | with torch.no_grad(): 222 | outputs = model(input_ids, output_hidden_states=True) 223 | hidden_states = outputs.hidden_states 224 | # print('hidden_states:',len(hidden_states)) 225 | logits = outputs.logits[0] 226 | base_hidden = outputs.hidden_states[0] 227 | base_logits = lm_head(base_hidden)[0] 228 | base_score = [] 229 | for pp, qq in zip(input_ids[0], base_logits): 230 | 231 | 232 | current_score = [float(qq[CORRECT_ID]), float(qq[FOIL_ID])] 233 | current_score = torch.tensor(current_score) 234 | current_score = torch.nn.functional.softmax(current_score, dim=-1) 235 | # print('current score:',current_score) 236 | diff = current_score[0] - current_score[ 237 | 1] # should be base (taken from layer 1 or embedding layer) + difference (current context, pervious influence) 238 | base_score.append(float(diff)) 239 | 240 | # print('token:', tokenizer.decode(pp), 'target:', target, 'prob:', qq[CORRECT_ID], 'foil:', foil, 'prob:', 241 | # qq[FOIL_ID], 'diff:', diff) 242 | 243 | learn_score = [] 244 | tokens = [] 245 | for xx, yy in zip(input_ids[0], logits): 246 | # method 1: yy[CORRECT_ID] * yy[CORRECT_ID]/yy[FOIL_ID] after softmax 247 | # method 2: should be softmax 248 | 249 | current_score = [float(yy[CORRECT_ID]), float(yy[FOIL_ID])] 250 | current_score = torch.tensor(current_score) 251 | current_score = torch.nn.functional.softmax(current_score, dim=-1) 252 | # print('current score:',current_score) 253 | diff = current_score[0] - current_score[ 254 | 1] # should be base (taken from layer 1 or embedding layer) + difference (current context, pervious influence) 255 | learn_score.append(float(diff)) 256 | tokens.append(tokenizer.decode(xx)) 257 | 258 | # print('token:', tokenizer.decode(xx), 'target:', target, 'prob:', yy[CORRECT_ID], 'foil:', foil, 'prob:', 259 | # yy[FOIL_ID], 'diff:', diff) 260 | 261 | # print('target:',target,'prob:',logits[CORRECT_ID]) 262 | # print('foil:',foil,'prob:',logits[FOIL_ID]) 263 | 264 | base_score = torch.tensor(base_score) 265 | learn_score = torch.tensor(learn_score) 266 | final_exp = [] 267 | for ind in range(len(learn_score)): 268 | if ind == 0: 269 | final_exp.append(float(learn_score[ind])) 270 | else: 271 | diff = learn_score[ind] - learn_score[ind - 1] 272 | final_exp.append(float(diff)) 273 | 274 | # check 275 | 276 | final_exp = torch.tensor(final_exp).unsqueeze(0) 277 | # print('origianl final exp:', final_exp) 278 | 279 | if permuted_indices is not None: 280 | final_exp = final_exp[0, permuted_indices].unsqueeze(0) 281 | 282 | final_exp = np.array(final_exp[0]) 283 | 284 | else: 285 | tensor = input_ids 286 | learn_score= [] 287 | 288 | for i in range(1, tensor.shape[1] + 1): 289 | new_tensor = tensor[:, -i:] # This gets the first element 290 | # new_tensor = torch.cat((new_tensor, tensor[:, i:]), 291 | # 1) # This concatenates the remaining elements 292 | # print(new_tensor) 293 | # print('new_tensor:',tokenizer.decode(new_tensor[0])) 294 | with torch.no_grad(): 295 | outputs = model(new_tensor, output_hidden_states=True) 296 | hidden_states = outputs.hidden_states 297 | # print('hidden_states:',len(hidden_states)) 298 | logits = outputs.logits[0] 299 | 
final_logit = logits[-1] 300 | # print('final_logit:',final_logit.shape,final_logit) 301 | 302 | current_score = [float(final_logit[CORRECT_ID]), float(final_logit[FOIL_ID])] 303 | current_score = torch.tensor(current_score) 304 | current_score = torch.nn.functional.softmax(current_score, dim=-1) 305 | # print('current score:',current_score) 306 | diff = current_score[0] - current_score[ 307 | 1] # should be base (taken from layer 1 or embedding layer) + difference (current context, pervious influence) 308 | learn_score.append(float(diff)) 309 | 310 | learn_score = torch.tensor(learn_score) 311 | final_exp = [] 312 | for ind in range(len(learn_score)): 313 | if ind == 0: 314 | final_exp.append(float(learn_score[ind])) 315 | else: 316 | diff = learn_score[ind] - learn_score[ind - 1] 317 | final_exp.append(float(diff)) 318 | final_exp = np.array(final_exp) 319 | # print('final_exp:',final_exp) 320 | 321 | # print('final_exp = np.array(final_exp[0]):',final_exp.shape,final_exp) 322 | final_exp = final_exp[::-1] 323 | # print('final_exp FF:', final_exp) 324 | # print('final final final_exp',final_exp.shape,final_exp) 325 | 326 | 327 | 328 | # print('recovered final_exp:',final_exp) 329 | all_exp.append(final_exp) 330 | 331 | ## start save different samples 332 | result["ours"][name].append(all_exp[0]) 333 | 334 | result["ours_back"][name].append(all_exp[1]) 335 | # print('original exp:',all_exp[1]) 336 | # print('original exp 2:', all_exp) 337 | all_exp = np.array(all_exp) 338 | 339 | # print('all_exp:',all_exp.shape,all_exp) 340 | ave_exp = np.mean(all_exp,axis = 0 ) 341 | # print('ave_exp:',ave_exp.shape,ave_exp) 342 | result["ours_add"][name].append(ave_exp) 343 | 344 | 345 | out_name = './exp_result/' + model_name + '_result_'+method + save_name_post + '.pt' 346 | torch.save(result,out_name) --------------------------------------------------------------------------------
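
**Note on the TDD saliency computation.** As we read `TDD_step1.py`, the forward `ours` explanation contrasts the target and foil tokens at every prefix position: it softmaxes the final-layer logits over just the (target, foil) pair, takes the resulting probability margin at each position, and uses the first difference of that margin across positions as each token's saliency. The snippet below is a minimal, self-contained sketch of that computation for GPT-2; it is not part of the repository, and the prompt and target/foil pair are illustrative placeholders rather than values taken from the code or data.

```python
# Minimal sketch (not part of the repository) of the forward "ours" score from TDD_step1.py.
# Assumes GPT-2 and a BLiMP-style (prefix, good word, bad word) triple; values are illustrative.
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

prompt = "Katherine can't help"          # hypothetical one_prefix_prefix
target, foil = "herself", "himself"      # hypothetical good/bad continuations

input_ids = tokenizer(prompt, return_tensors="pt").input_ids
correct_id = tokenizer(" " + target)["input_ids"][0]
foil_id = tokenizer(" " + foil)["input_ids"][0]

with torch.no_grad():
    logits = model(input_ids).logits[0]                  # (seq_len, vocab_size)

# Softmax over just the (target, foil) pair at every prefix position.
pair = torch.stack([logits[:, correct_id], logits[:, foil_id]], dim=-1)
probs = torch.softmax(pair, dim=-1)
margin = probs[:, 0] - probs[:, 1]                       # target-vs-foil margin per prefix

# First difference across positions: how much each new input token shifts the margin.
saliency = torch.cat([margin[:1], margin[1:] - margin[:-1]])

for tok, s in zip(tokenizer.convert_ids_to_tokens(input_ids[0].tolist()), saliency.tolist()):
    print(f"{tok!r}: {s:+.4f}")
```

The script's `ours_back` variant repeats the same margin computation on suffixes of the prompt and reverses the result, and `ours_add` averages the forward and backward scores.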