├── prompt
│   ├── prompt_augmenter_new.py
│   ├── prompt_decoder.py
│   ├── text_attack_pytorch_warper.py
│   └── prompt_augmenter.py
├── dataset
│   ├── dataset_warper.py
│   ├── dataset_format_convet.py
│   ├── split_dataset.py
│   ├── extract_arguemented_dataset.py
│   ├── multi_thread_data_augmenter.py
│   └── align_dataset.py
├── LICENSE
├── test
│   └── test_datawarper.py
├── README.md
├── run
│   ├── inference_dataset_oldversion2.py
│   ├── inference_dataset_old3.py
│   ├── inference_dataset_testset.py
│   ├── inference_dataset.py
│   ├── inference_testset.py
│   └── inference_dataset_old_version.py
├── eval
│   ├── eval.py
│   └── select_evaluation.py
└── model
    └── warper.py

--------------------------------------------------------------------------------
/prompt/prompt_augmenter_new.py:
--------------------------------------------------------------------------------
1 | from textattack.attack_recipes.textfooler_jin_2019 import TextFoolerJin2019
2 | from warper import ModelWarper, tokenizer  # GPT-2 wrapper pieces from model/warper.py
3 | from text_attack_pytorch_warper import PyTorchModelWrapper
4 | 
5 | if __name__ == '__main__':
6 |     # model_warpper was left undefined in the original script; assume the
7 |     # TextAttack-wrapped GPT-2 model built from the pieces in model/warper.py.
8 |     model_warpper = PyTorchModelWrapper(ModelWarper(), tokenizer)
9 |     TextFoolerJin2019.build(model_warpper)

--------------------------------------------------------------------------------
/dataset/dataset_warper.py:
--------------------------------------------------------------------------------
1 | from datasets import load_from_disk
2 | from icecream import ic
3 | import textattack
4 | 
5 | # Convert the saved HuggingFace test split into a TextAttack dataset of (text, label) pairs.
6 | dataset_qwq = load_from_disk('./yelp_review_full_split_train_dev')
7 | dataset = []
8 | for i in dataset_qwq['test']:
9 |     dataset.append((i['text'], i['label']))
10 | ic(dataset[0:3])
11 | dataset = textattack.datasets.Dataset(dataset)
12 | ic(dataset)

--------------------------------------------------------------------------------
/dataset/dataset_format_convet.py:
--------------------------------------------------------------------------------
1 | from datasets import load_from_disk
2 | from icecream import ic
3 | import pandas as pd
4 | 
5 | # Dump the test split to a CSV with text/label columns.
6 | dataset = load_from_disk('./yelp_review_full_split_train_dev')
7 | df_dict = {
8 |     'text':[],
9 |     'label':[]
10 | }
11 | ic(len(dataset['test']))
12 | for i in dataset['test']:
13 |     df_dict['text'].append(i['text'])
14 |     df_dict['label'].append(i['label'])
15 | ic(len(df_dict['text']))
16 | df = pd.DataFrame.from_dict(df_dict)
17 | df.to_csv("yelp_testset.csv")

--------------------------------------------------------------------------------
/dataset/split_dataset.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm, trange
2 | from icecream import ic
3 | 
4 | if __name__=='__main__':
5 |     from datasets import load_from_disk
6 |     dataset = load_from_disk('yelp_review_full')
7 |     ic(dataset)
8 |     dataset['train'] = dataset['train'].shuffle(seed=42)
9 |     # Take 2,000 train and 2,000 dev examples per label (5 labels, 20,000 in total).
10 |     counts = [0,0,0,0,0]
11 |     selected_train = []
12 |     selected_dev = []
13 |     ans = 0
14 |     for i in trange(len(dataset['train'])):
15 |         label = dataset['train'][i]['label']
16 |         if counts[label]<2000:
17 |             counts[label]+=1
18 |             selected_train.append(i)
19 |             ans+=1
20 |         elif counts[label]<4000:
21 |             counts[label]+=1
22 |             selected_dev.append(i)
23 |             ans+=1
24 |         elif ans==20000:
25 |             break
26 |     dataset['selected_train'] = dataset['train'].select(selected_train)
27 |     dataset['selected_dev'] = dataset['train'].select(selected_dev)
28 |     ic(dataset)
29 |     dataset.save_to_disk('./yelp_review_full_split_train_dev')

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 ZHIHENG LYU
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | 
of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /prompt/prompt_decoder.py: -------------------------------------------------------------------------------- 1 | from icecream import ic 2 | with open("seed_prompt.txt") as f: 3 | content = ''.join(f.readlines()) 4 | seed_prompt_candidates = content.split('\n\n\n') 5 | ic(len(seed_prompt_candidates)) 6 | df_dict = { 7 | "text":[], 8 | "label":[], 9 | } 10 | for i in range(len(seed_prompt_candidates)): 11 | setup = i//10+1 12 | df_dict['label'].append(setup) 13 | if (setup==1): 14 | text = f"{seed_prompt_candidates[i]}The person’s rating was" 15 | elif (setup==2): 16 | text = f"{seed_prompt_candidates[i]}" 17 | else: 18 | text = "A person saw this Yelp review: "+"f{review_text}"+f"{seed_prompt_candidates[i]} In this case, the person guessed the rating was " 19 | 20 | df_dict['text'].append(text) 21 | 22 | import pandas as pd 23 | ic(len(df_dict['text'])) 24 | df = pd.DataFrame.from_dict(df_dict) 25 | df.to_csv("seed_prompt.csv") 26 | """ 27 | [[SETUP 1]] 28 | [For all seed prompts, use them as] 29 | 30 | “f{seed_prompt} 31 | The person’s rating was” [now let GPT generate] 32 | 33 | 34 | [[SETUP 2]] 35 | [For all seed prompts, use them as] 36 | 37 | “f{seed_prompt}” 38 | 39 | [[SETUP 3]] 40 | [For all seed prompts, use them as] 41 | 42 | “A person saw this Yelp review: f{review_text}. f{seed_prompt}. 
In this case, the person guessed the rating was ”
43 | 
44 | """

--------------------------------------------------------------------------------
/test/test_datawarper.py:
--------------------------------------------------------------------------------
1 | from warper import TokenizerWarper, generate_prompts
2 | from datasets import load_from_disk
3 | from icecream import ic
4 | 
5 | # What we want to do is wrap the dataset input text the same way gpt2-large sees it:
6 | # first tokenize with the GPT-2 tokenizer, then convert the tokens back to text.
7 | class GPT3TokenizerWarper(TokenizerWarper):
8 |     def __init__(self, prompt, tokenizer_path = "./gpt2-large_saved/token"):
9 |         super().__init__(prompt, tokenizer_path)
10 | 
11 |     def calc_single_str(self, text):
12 |         token_ids = self.encode_single_sentence(text)['input_ids']
13 |         return self.tokenizer.convert_tokens_to_string(self.tokenizer.convert_ids_to_tokens(token_ids))
14 | 
15 |     def __call__(self, text):
16 |         output_dict = {
17 |             "input_prompts":[],
18 |         }
19 |         if (type(text)==list):
20 |             for i in text:
21 |                 output_dict["input_prompts"].append(self.calc_single_str(i))
22 |         else:
23 |             output_dict["input_prompts"].append(self.calc_single_str(text))
24 |         return output_dict
25 | 
26 | if __name__=='__main__':
27 |     import pandas as pd
28 |     df = pd.read_csv('prompt_pool.csv')
29 |     ic(len(df))
30 |     perfix_list, postfix_list = generate_prompts(df, 1)
31 |     ic(len(perfix_list))
32 | 
33 |     dataset = load_from_disk('./yelp_review_full_split_train_dev')
34 |     subdataset = dataset['train']
35 |     tokenizer_list = []
36 |     for i in range(len(perfix_list)):
37 |         tokenizer_list.append(GPT3TokenizerWarper([perfix_list[i], postfix_list[i]]))
38 |     print(tokenizer_list[0](subdataset[100]['text'])['input_prompts'][0])
39 | 

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # llm-bivariate-causal-discovery
2 | [[paper](https://openreview.net/forum?id=ucHh-ytUkOH)]
3 | ## Introduction
4 | Identifying the causal direction between two variables has long been an important but challenging task for causal inference. Existing work proposes to distinguish whether $X\rightarrow Y$ or $Y \rightarrow X$ by setting up an input-output learning task using the two variables, since causal and anticausal learning perform differently under semi-supervised learning and domain shift. This approach works for many task-specific models trained on the input-output pairs. However, with the rise of general-purpose large language models (LLMs), this previous task-specific learning approach faces several challenges, since continued training of LLMs is less likely to be affordable for university labs, and LLMs are no longer trained on specific input-output pairs. In this work, we propose a new paradigm to distinguish cause from effect using LLMs. Specifically, we conduct post-hoc analysis using natural language prompts that describe different possible causal stories behind the $X$, $Y$ pairs, and test their zero-shot performance. Through our experiments, we show that the natural language prompts that describe the same causal story as the ground-truth data-generating direction achieve the highest zero-shot performance, with a 2\% margin over anticausal prompts. We highlight that identifying more causal relations using LLMs will be an interesting direction for future work.
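5 | 
6 | As a minimal illustration of the zero-shot scoring idea (a sketch only: the prompt strings and rating-token ids below come from this repo, but `score_prompt`, the example review, and the small `gpt2` checkpoint are illustrative, not the actual pipeline in `./run/`):
7 | 
8 | ```python
9 | import torch
10 | from transformers import AutoTokenizer, GPT2LMHeadModel
11 | 
12 | tok = AutoTokenizer.from_pretrained("gpt2")
13 | model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
14 | LABEL_IDS = [352, 362, 513, 604, 642]  # GPT-2 token ids of " 1" ... " 5"
15 | 
16 | def score_prompt(prefix, review, postfix):
17 |     # Return the logits of the five rating tokens at the final position.
18 |     ids = tok(prefix + review + postfix, return_tensors="pt")
19 |     with torch.no_grad():
20 |         logits = model(**ids).logits[0, -1]
21 |     return logits[LABEL_IDS]
22 | 
23 | review = "Great food, but the service was painfully slow."
24 | # A causal-story prompt: the rating was decided first, then explained by the review.
25 | scores = score_prompt(
26 |     "I just finished eating at a restaurant. Then I opened my Yelp app.\n\nI first gave a rating, and then wrote the following review:\n\n",
27 |     review,
28 |     "\n\nThe review is an explanation of why I gave a rating (out of 1 to 5 stars) of")
29 | print(f"predicted rating: {int(scores.argmax()) + 1}")
30 | ```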
31 | 
32 | ## Code Structure
33 | * `./dataset/`
34 |   * code for splitting and augmenting the dataset
35 | * `./eval/`
36 |   * code for evaluating the results
37 | * `./model/`
38 |   * code for wrapping the model
39 | * `./prompt/`
40 |   * code for generating the prompts and wrapping them around the dataset
41 | * `./run/`
42 |   * code for running inference with the GPT-2 series models on the given dataset
43 | * `./test/`
44 |   * tests for the implementation of some modules
45 | 
46 | ## Some special remarks
47 | In this version, we use the GPT-3 API to manually generate the prompts, and the `textattack` package to paraphrase each prompt and select the best one. We use the `gpt2` series models with their default Hugging Face configurations.
48 | 
49 | Originally the code was run from a single folder; to give the repository a clearer structure, it has been split by function into the folders above.
50 | 

--------------------------------------------------------------------------------
/run/inference_dataset_oldversion2.py:
--------------------------------------------------------------------------------
1 | from tqdm import trange
2 | import torch
3 | from icecream import ic
4 | import pandas as pd
5 | from warper import TokenizerWarper, ModelWarper
6 | max_length = 1024
7 | def generate_prompts(df):
8 |     perfix = []
9 |     postfix = []
10 |     for i in range(len(df)):
11 |         current_str = df['text'].iloc[i].replace("\\n","\n")
12 |         perfix.append(current_str.split("\n\n\n\n")[0]+"\n\n")
13 |         postfix.append("\n\n"+current_str.split("\n\n\n\n")[-1])
14 |     return perfix, postfix
15 | 
16 | if __name__=='__main__':
17 |     from datasets import load_from_disk
18 |     dataset = load_from_disk('./yelp_review_full_split_train_dev')
19 | 
20 |     model = ModelWarper()
21 | 
22 |     print(f"Model.device is: {model.device}")
23 | 
24 |     final_dict = {
25 |         "label":[],
26 |         "text":[],
27 |         "answer":[],
28 |     }
29 |     import numpy as np
30 | 
31 |     df = pd.read_csv('yelp_prompt_arguement_full_version.csv')
32 |     perfix_list, postfix_list = generate_prompts(df)
33 |     tokenizer_list = []
34 |     for i in range(len(perfix_list)):
35 |         tokenizer_list.append(TokenizerWarper([perfix_list[i], postfix_list[i]]))
36 |     label_idx = np.array([352, 362, 513, 604, 642])
37 |     sum = 0
38 |     with torch.no_grad():
39 |         for i in trange(0,len(dataset['selected_train']),8):
40 |             batch_data = dataset['selected_train'][i:i+8]
41 |             sum+=1
42 |             if (sum==10):
43 |                 #break
44 |                 pass
45 |             final_dict['label'].append(batch_data['label'])
46 |             final_dict['text'].append(batch_data['text'])
47 |             inputs = [tokenizer(batch_data['text']) for tokenizer in tokenizer_list]
48 |             outputs = []
49 |             for k in inputs:
50 |                 outputs.append(model(**k))
51 | 
52 |             final_dict["answer"].append(outputs)
53 | 
54 |         for i in trange(0,len(dataset['selected_dev']),8):
55 |             batch_data = dataset['selected_dev'][i:i+8]
56 |             sum+=1
57 |             if (sum==10):
58 |                 #break
59 |                 pass
60 |             final_dict['label'].append(batch_data['label'])
61 |             final_dict['text'].append(batch_data['text'])
62 |             # Tokenize the batch text and unpack the inputs into the model,
63 |             # exactly as in the selected_train loop above.
64 |             inputs = [tokenizer(batch_data['text']) for tokenizer in tokenizer_list]
65 |             outputs = []
66 |             for k in inputs:
67 |                 outputs.append(model(**k))
68 | 
69 |             final_dict["answer"].append(outputs)
70 |     a=np.array(final_dict)
71 |     np.save("/cluster/project/sachan/zhiheng/causal_prompting/intermediate_data/predict_train_14.npy",a)
72 | 

--------------------------------------------------------------------------------
/run/inference_dataset_old3.py:
--------------------------------------------------------------------------------
1 | def generate_prompts(df):
2 |     perfix = []
3 |     postfix = []
4 |     for i in range(len(df)):
5 | current_str = df['text'].iloc[i].replace("\\n","\n") 6 | perfix.append(current_str.split("\n\n\n\n")[0]) 7 | postfix.append(current_str.split("\n\n\n\n")[-1]) 8 | return perfix, postfix 9 | if __name__=='__main__': 10 | import pandas as pd 11 | df = pd.read_csv('yelp_prompt_example.csv') 12 | """ 13 | perfix_list = ["I just finished eating at a restaurant. Then I opened my Yelp app.\n\nI first gave a rating, and then wrote the following review:\n\n", 14 | "I just finished eating at a restaurant. Then I opened my Yelp app, and wrote the following review: \n\n", 15 | "I opened my Yelp app, and started to read some reviews of a restaurant that I want to try. I saw a user wrote this review:\n\n"] 16 | postfix_list = ["\n\nThe review is an explanation of why I gave a rating (out of 1 to 5 stars) of", 17 | "\n\nThen I gave the rating. In terms of 1 to 5 stars, I think this restaurant is worth a", 18 | "\n\nIn terms of 1 to 5 stars, I think this user rated it a"] 19 | """ 20 | label_idx = np.array([352, 362, 513, 604, 642]) 21 | with torch.no_grad(): 22 | for i in tqdm(dataset['test']): 23 | final_dict['label'].append(i['label']) 24 | final_dict['text'].append(i['text']) 25 | for j in range(3): 26 | current_text = perfix_list[j]+i['text']+postfix_list[j] 27 | input = tokenizer(current_text, return_tensors='pt', padding=True) 28 | if (len(input['input_ids'][0])>1024): 29 | perfix_input = tokenizer(perfix_list[j], return_tensors='pt') 30 | text_input = tokenizer(i['text'], return_tensors='pt') 31 | postfix_input = tokenizer("..."+postfix_list[j], return_tensors='pt') 32 | for k in input: 33 | ic([perfix_input[k],text_input[k][:,1024-len(perfix_input[k][0])-len(postfix_input[k][0])],postfix_input[k]]) 34 | input[k] = torch.cat([perfix_input[k],text_input[k][:,:1024-len(perfix_input[k][0])-len(postfix_input[k][0])],postfix_input[k]],dim=1) 35 | for k in input: 36 | input[k]=input[k].cuda() 37 | output = model(**input)[0].cpu().detach().numpy()[0][-1][label_idx] 38 | final_dict[f"setup{j+1}_answer"].append(output) 39 | #ic(i['label'], output) 40 | a=np.array(final_dict) 41 | np.save("/cluster/project/sachan/zhiheng/causal_prompting/intermediate_data/predict_result.npy",a) 42 | 43 | 44 | print(generate_prompts(df)) 45 | -------------------------------------------------------------------------------- /eval/eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sklearn.metrics 3 | from sklearn.metrics import accuracy_score, f1_score 4 | from collections import Counter 5 | from icecream import ic 6 | if __name__=='__main__': 7 | result_path = "/cluster/project/sachan/zhiheng/causal_prompting/intermediate_data/predict_result.npy" 8 | a = np.load(result_path, allow_pickle = True) 9 | result_dict=a.tolist() 10 | ic(result_dict.keys()) 11 | gt_labels = result_dict['label'][0:2000] 12 | result_dict['setup1_answer'] = np.array(result_dict['setup1_answer'][0:2000]) 13 | result_dict['setup2_answer'] = np.array(result_dict['setup2_answer'][0:2000]) 14 | result_dict['setup3_answer'] = np.array(result_dict['setup3_answer'][0:2000]) 15 | lr = 0.01 16 | result1 = np.array(result_dict['setup1_answer']).argmax(axis=1) 17 | result2 = np.array(result_dict['setup2_answer']).argmax(axis=1) 18 | result3 = np.array(result_dict['setup3_answer']).argmax(axis=1) 19 | offset1 = [0,0,0,0,0] 20 | offset2 = [0,0,0,0,0] 21 | offset3 = [0,0,0,0,0] 22 | ic(Counter(gt_labels), Counter(result1),Counter(result2),Counter(result3)) 23 | ic(accuracy_score(gt_labels, 
result1), f1_score(gt_labels, result1, average='weighted'))
24 |     ic(accuracy_score(gt_labels, result2), f1_score(gt_labels, result2, average='weighted'))
25 |     ic(accuracy_score(gt_labels, result3), f1_score(gt_labels, result3, average='weighted'))
26 |     cgt = dict(Counter(gt_labels))
27 |     # Iteratively learn per-class offsets so that the predicted label distribution
28 |     # matches the ground-truth class frequencies (a simple calibration loop).
29 |     for i in range(1000):
30 |         result1 = np.array(result_dict['setup1_answer']).argmax(axis = 1)
31 |         result2 = np.array(result_dict['setup2_answer']).argmax(axis = 1)
32 |         result3 = np.array(result_dict['setup3_answer']).argmax(axis = 1)
33 |         for label in range(5):
34 |             offset1[label]+=lr*(1-list(result1).count(label)/cgt[label])
35 |             offset2[label]+=lr*(1-list(result2).count(label)/cgt[label])
36 |             offset3[label]+=lr*(1-list(result3).count(label)/cgt[label])
37 |             result_dict['setup1_answer'][:,label]+=lr*(1-list(result1).count(label)/cgt[label])
38 |             result_dict['setup2_answer'][:,label]+=lr*(1-list(result2).count(label)/cgt[label])
39 |             result_dict['setup3_answer'][:,label]+=lr*(1-list(result3).count(label)/cgt[label])
40 |     result1 = np.array(result_dict['setup1_answer']).argmax(axis = 1)
41 |     result2 = np.array(result_dict['setup2_answer']).argmax(axis = 1)
42 |     result3 = np.array(result_dict['setup3_answer']).argmax(axis = 1)
43 |     ic(Counter(gt_labels), Counter(result1),Counter(result2),Counter(result3))
44 |     ic(accuracy_score(gt_labels, result1), f1_score(gt_labels, result1, average='weighted'))
45 |     ic(accuracy_score(gt_labels, result2), f1_score(gt_labels, result2, average='weighted'))
46 |     ic(accuracy_score(gt_labels, result3), f1_score(gt_labels, result3, average='weighted'))
47 |     ic(offset1)
48 |     ic(offset2)
49 |     ic(offset3)
50 |     #ic(gt_labels[0:100], result3[0:100])
51 |     #ic(result1[0:100], result2[0:100])

--------------------------------------------------------------------------------
/run/inference_dataset_testset.py:
--------------------------------------------------------------------------------
1 | from tqdm import trange
2 | import torch
3 | from icecream import ic
4 | import pandas as pd
5 | from warper import TokenizerWarper, ModelWarper
6 | import argparse
7 | 
8 | max_length = 1024
9 | def generate_prompts(df, setup):
10 |     perfix = []
11 |     postfix = []
12 |     for i in range(len(df)):
13 |         if (int(df['label'].iloc[i])!=setup or setup == -1):
14 |             continue
15 |         current_str = df['text'].iloc[i].replace("\\n","\n")
16 |         perfix.append(current_str.split("\n\n\n\n")[0]+"\n\n")
17 |         postfix.append("\n\n"+current_str.split("\n\n\n\n")[-1])
18 |     ic(len(perfix))
19 |     return perfix, postfix
20 | 
21 | if __name__=='__main__':
22 | 
23 |     parser = argparse.ArgumentParser()
24 |     parser.add_argument("setup", help="the selected setup number")
25 |     parser.add_argument("--output", help="the path of the output file")
26 |     parser.add_argument("--first", help="use the first 50% of the dataset, otherwise the remaining 50%", action="store_true")
27 |     parser.add_argument("--train", help="use the selected train split", action="store_true")
28 |     parser.add_argument("--dev", help="use the selected dev split", action="store_true")
29 |     args = parser.parse_args()
30 |     from datasets import load_from_disk
31 |     dataset = load_from_disk('./yelp_review_full_split_train_dev')
32 |     dataset_idx = 0
33 |     if (args.train):
34 |         subdataset = dataset['selected_train']
35 |     elif (args.dev):
36 |         subdataset = dataset['selected_dev']
37 |         dataset_idx += 2
38 |     else:
39 |         raise Exception('You should select a subdataset')
40 |     begin, end = 0, len(subdataset)
41 |     if (args.first):
42 |         end = end//2
43 |     else:
44 |         begin = end//2
45 |         dataset_idx += 1
46 |     ic(begin, end)
47 | 
48 |     model = ModelWarper()
49 | 
50 |     print(f"Model.device is: {model.device}")
51 |     setup = int(args.setup)
52 |     ic(setup)
53 |     if (setup<-1 or
setup>3): 54 | raise Exception('incorrect setup number') 55 | 56 | final_dict = { 57 | "dataset_idx":dataset_idx, 58 | "setup":setup, 59 | "label":[], 60 | "text":[], 61 | "answer":[], 62 | } 63 | import numpy as np 64 | 65 | df = pd.read_csv('yelp_prompt_arguement_full_version.csv') 66 | perfix_list, postfix_list = generate_prompts(df, setup) 67 | tokenizer_list = [] 68 | for i in range(len(perfix_list)): 69 | tokenizer_list.append(TokenizerWarper([perfix_list[i], postfix_list[i]])) 70 | label_idx = np.array([352, 362, 513, 604, 642]) 71 | sum = 0 72 | with torch.no_grad(): 73 | for i in trange(begin,end,4): 74 | batch_data =subdataset[i:min(i+4, end)] 75 | sum+=1 76 | if (sum==10): 77 | #break 78 | pass 79 | final_dict['label'].append(batch_data['label']) 80 | final_dict['text'].append(batch_data['text']) 81 | inputs = [tokenizer(batch_data['text']) for tokenizer in tokenizer_list] 82 | outputs = [] 83 | for k in inputs: 84 | #if (len(k['input_ids'][0])>1024): 85 | outputs.append(model(**k)) 86 | 87 | final_dict["answer"].append(outputs) 88 | 89 | a=np.array(final_dict) 90 | np.save(args.output,a) 91 | 92 | 93 | -------------------------------------------------------------------------------- /dataset/extract_arguemented_dataset.py: -------------------------------------------------------------------------------- 1 | import jsonlines 2 | from icecream import ic 3 | import pandas as pd 4 | def load_jsonl_list(path): 5 | jsonl_list = [] 6 | with jsonlines.open(path, 'r') as reader: 7 | for i in reader: 8 | jsonl_list.append(dict(i)) 9 | return jsonl_list 10 | 11 | def extract_and_save_dataset(results): 12 | from datasets import Dataset 13 | dataset_dict = { 14 | "text":[], 15 | "label":[], 16 | "origin_data_idx":[], 17 | "origin_text":[], 18 | } 19 | idx_based_dict = {} 20 | for i in range(50000): 21 | idx_based_dict[i] = [] 22 | for result in results: 23 | for result_text in result['result']: 24 | if (len(idx_based_dict[result['idx']])<10 and not (result_text in idx_based_dict[result['idx']])): 25 | idx_based_dict[result['idx']].append(result_text) 26 | dataset_dict['text'].append(result_text) 27 | dataset_dict['label'].append(result['label']) 28 | dataset_dict['origin_data_idx'].append(result['idx']) 29 | dataset_dict['origin_text'].append(result['text']) 30 | 31 | ic(len(dataset_dict['text'])) 32 | df = pd.DataFrame.from_dict(dataset_dict) 33 | df.to_csv("augmented_dataset_test_full.csv") 34 | from datasets import Dataset 35 | dataset = Dataset.from_dict(dataset_dict) 36 | dataset.save_to_disk("augmented_dataset_test_full") 37 | ic(dataset) 38 | 39 | sum = 0 40 | li = [] 41 | for i in idx_based_dict: 42 | if len(idx_based_dict[i])<10: 43 | sum+=(10-len(idx_based_dict[i])+3)//4 44 | for j in range((10-len(idx_based_dict[i])+3)//4): 45 | li.append(i) 46 | ic(li[0:100]) 47 | ic(sum) 48 | import numpy as np 49 | a = np.array(li) 50 | np.save("regenerate_dataset.npy", a) 51 | 52 | 53 | if __name__ == "__main__": 54 | base_path = "/cluster/project/sachan/zhiheng/causal_prompting/intermediate_data/" 55 | count_dict = {} 56 | for i in range(50000): 57 | count_dict[i]=0 58 | result_list = [] 59 | 60 | for postfix in ["", "_1", "_2"]: 61 | for perfix in range(1,6,1): 62 | file_name = f"augment_test_{perfix}{postfix}.jsonl" 63 | x = load_jsonl_list(base_path+file_name) 64 | for augment_result in x: 65 | count_dict[augment_result['idx']]+=4 66 | result_list.append(augment_result) 67 | import json 68 | 69 | file_name = f"regenerating_dataset.jsonl" 70 | f = open(base_path + file_name) 71 | lines = 
f.readlines() 72 | import json 73 | for augment_json in lines: 74 | try: 75 | augment_result = json.loads(augment_json) 76 | except: 77 | continue 78 | count_dict[augment_result['idx']] += 4 79 | result_list.append(augment_result) 80 | """ 81 | sum = 0 82 | li = [] 83 | for i in count_dict: 84 | if count_dict[i]<10: 85 | sum+=(10-count_dict[i]+3)//4 86 | for j in range((10-count_dict[i]+3)//4): 87 | li.append(i) 88 | #ic(li[0:100]) 89 | ic(sum) 90 | import numpy as np 91 | a = np.array(li) 92 | np.save("regenerate_dataset.npy", a) 93 | """ 94 | extract_and_save_dataset(result_list) 95 | -------------------------------------------------------------------------------- /run/inference_dataset.py: -------------------------------------------------------------------------------- 1 | from tqdm import trange 2 | import torch 3 | from icecream import ic 4 | import pandas as pd 5 | from warper import TokenizerWarper, ModelWarper 6 | from warper import generate_prompts 7 | import argparse 8 | 9 | max_length = 1024 10 | 11 | """ 12 | def generate_prompts(df, setup): 13 | perfix = [] 14 | postfix = [] 15 | for i in range(len(df)): 16 | if (int(df['label'].iloc[i])!=setup): 17 | continue 18 | current_str = df['text'].iloc[i].replace("\\n","\n") 19 | perfix.append(current_str.split("f{review_text}")[0]) 20 | postfix.append(current_str.split("f{review_text}")[-1]) 21 | ic(len(perfix)) 22 | print(perfix) 23 | print(postfix) 24 | return perfix, postfix 25 | """ 26 | 27 | if __name__=='__main__': 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("setup", help="the selected setup number") 31 | parser.add_argument("model", help="the name of selected model") 32 | #parser.add_argument("--output", help="the path of output file") 33 | #parser.add_argument("--first", help="use the first 50% of the dataset, otherwise the remain 50%",action="store_true") 34 | parser.add_argument("--train", help="use the train dataset", action="store_true") 35 | parser.add_argument("--dev", help="use the develop dataset", action="store_true") 36 | args = parser.parse_args() 37 | 38 | from datasets import load_from_disk 39 | dataset = load_from_disk('./yelp_review_full_split_train_dev') 40 | dataset_idx = 0 41 | if (args.train): 42 | subdataset = dataset['selected_train'] 43 | output_path = f"./intermediate_data/s{args.setup}_train_{args.model}_pools3.npy" 44 | elif (args.dev): 45 | subdataset = dataset['selected_dev'] 46 | output_path = f"./intermediate_data/s{args.setup}_dev_{args.model}_pools3.npy" 47 | dataset_idx += 2 48 | else: 49 | raise Exception('You should select a subdataset') 50 | ic(output_path) 51 | begin, end = 0, len(subdataset) 52 | ic(begin, end) 53 | 54 | model = ModelWarper(model_path = f"./{args.model}_saved/model") 55 | 56 | print(f"Model.device is: {model.device}") 57 | setup = int(args.setup) 58 | ic(setup) 59 | if (setup<0 or setup>4): 60 | raise Exception('incorrect setup number') 61 | 62 | final_dict = { 63 | "dataset_idx":dataset_idx, 64 | "setup":setup, 65 | "label":[], 66 | "text":[], 67 | "answer":[], 68 | } 69 | import numpy as np 70 | 71 | #df = pd.read_csv('yelp_prompt_example.csv') 72 | #df = pd.read_csv('finalized_prompt.csv') 73 | df = pd.read_csv('prompt_pool.csv') 74 | perfix_list, postfix_list = generate_prompts(df, setup) 75 | tokenizer_list = [] 76 | for i in range(len(perfix_list)): 77 | tokenizer_list.append(TokenizerWarper([perfix_list[i], postfix_list[i]], tokenizer_path = f"./{args.model}_saved/token")) 78 | label_idx = np.array([352, 362, 513, 604, 642]) 79 | sum = 0 80 
|     with torch.no_grad():
81 |         for i in trange(begin,end,4):
82 |             batch_data = subdataset[i:min(i+4, end)]
83 |             sum+=1
84 |             if (sum==10):
85 |                 #break
86 |                 pass
87 |             final_dict['label'].append(batch_data['label'])
88 |             final_dict['text'].append(batch_data['text'])
89 |             inputs = [tokenizer(batch_data['text']) for tokenizer in tokenizer_list]
90 |             outputs = []
91 |             for k in inputs:
92 |                 #if (len(k['input_ids'][0])>1024):
93 |                 outputs.append(model(**k))
94 | 
95 |             final_dict["answer"].append(outputs)
96 | 
97 |     a=np.array(final_dict)
98 |     np.save(output_path,a)
99 | 
100 | 
101 | 

--------------------------------------------------------------------------------
/dataset/multi_thread_data_augmenter.py:
--------------------------------------------------------------------------------
1 | from textattack.augmentation.recipes import EasyDataAugmenter
2 | from tqdm import trange
3 | import multiprocessing
4 | import argparse
5 | from icecream import ic
6 | from multiprocessing import Process, Queue
7 | import jsonlines
8 | import time
9 | import random
10 | import numpy as np
11 | 
12 | class MyProcess(Process): # subclasses Process
13 |     def __init__(self, name, sub_dataset, path):
14 |         super(MyProcess,self).__init__()
15 |         self.name = name
16 |         self.augmenter = EasyDataAugmenter()
17 |         self.augmenter.fast_augment = True
18 |         self.augmenter.high_yield = True
19 |         self.augmenter.transformations_per_example = 20
20 |         self.sub_dataset = sub_dataset
21 |         self.Q = q
22 |         self.path = path
23 | 
24 |     def run(self):
25 |         for datapoint in self.sub_dataset:
26 |             while True:
27 |                 try:
28 |                     result = self.augmenter.augment(datapoint['text'])
29 |                     break
30 |                 except:
31 |                     print(f"An error occurred at {self.name}")
32 |                     time.sleep(random.random())
33 |             datapoint['result'] = result
34 |             datapoint['name'] = self.name
35 |             ic(self.name)
36 |             while True:
37 |                 try:
38 |                     with jsonlines.open(self.path, 'a') as writer:
39 |                         writer.write(datapoint)
40 |                     break
41 |                 except:
42 |                     time.sleep(random.random())
43 |             ic(self.name+" finished writing file")
44 | 
45 | 
46 | 
47 | if __name__ == '__main__':
48 | 
49 |     #parser = argparse.ArgumentParser()
50 |     #parser.add_argument("task_id", help="the selected setup number")
51 |     #args = parser.parse_args()
52 |     #task_id = int(args.task_id)
53 |     #output_path = f"./intermediate_data/augment_test_{task_id}_2.jsonl"
54 |     output_path = f"./intermediate_data/regenerating_dataset.jsonl"
55 |     #ic(task_id, output_path)
56 | 
57 |     from datasets import load_from_disk
58 | 
59 |     results = []
60 |     dataset = load_from_disk('./yelp_review_full_split_train_dev')
61 | 
62 |     dataset_idx = np.load("regenerate_dataset.npy")
63 |     current_point = 0
64 |     process_list = []
65 |     q = Queue()
66 |     for i in range(50): # start the worker processes
67 |         #begin = (task_id-1)*10000+i*20
68 |         #subdataset = dataset['test'][begin:begin+200]
69 |         subdataset_jsonl = []
70 |         for j in range(1):
71 |             if (current_point>len(dataset_idx)):
72 |                 break
73 |             current_idx = int(dataset_idx[current_point])
74 |             datapoint = {
75 |                 "idx": current_idx,
76 |                 "text": dataset['test'][current_idx]['text'],
77 |                 'label': dataset['test'][current_idx]['label']
78 |             }
79 |             current_point += 1
80 |             subdataset_jsonl.append(datapoint)
81 |         p = MyProcess(f"augment_Task_Process{i}", subdataset_jsonl, output_path) # instantiate the process object
82 |         p.start()
83 |         time.sleep(1)
84 |         process_list.append(p)
85 | 
86 |     for i in process_list:
87 |         p.join()
88 |         time.sleep(1)
89 | 
90 |     """
91 |     with jsonlines.open(output_path, 'a') as writer:
92 |         for i in trange(10000):
93 |             while (True):
94 |                 try:
95 |                     result = q.get()
96 |                     writer.write(result)
97 |                     ic(result, sum)
98 |                     break
99 |                 except:
100 |                     time.sleep(5)
101 |     """
102 | 
103 |     """
104 |     process_list = []
105 |     for i in range(5): # start 5 child processes
106 |         p = MyProcess('Python') # instantiate the process object
107 |         p.start()
108 |         process_list.append(p)
109 | 
110 |     for i in process_list:
111 |         p.join()
112 | 
113 |     print('test finished')
114 |     """

--------------------------------------------------------------------------------
/prompt/text_attack_pytorch_warper.py:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | import torch
4 | from torch.nn import CrossEntropyLoss
5 | 
6 | import textattack
7 | 
8 | from textattack.models.wrappers.model_wrapper import ModelWrapper
9 | 
10 | torch.cuda.empty_cache()
11 | from icecream import ic
12 | 
13 | 
14 | class PyTorchModelWrapper(ModelWrapper):
15 |     """Loads a PyTorch model (`nn.Module`) and tokenizer.
16 |     Args:
17 |         model (torch.nn.Module): PyTorch model
18 |         tokenizer: tokenizer whose output can be packed as a tensor and passed to the model.
19 |             No type requirement, but most have `tokenizer` method that accepts list of strings.
20 |     """
21 | 
22 |     def __init__(self, model, tokenizer):
23 |         if not isinstance(model, torch.nn.Module):
24 |             raise TypeError(
25 |                 f"PyTorch model must be torch.nn.Module, got type {type(model)}"
26 |             )
27 | 
28 |         self.model = model
29 |         self.tokenizer = tokenizer
30 | 
31 |     def to(self, device):
32 |         self.model.to(device)
33 | 
34 |     def __call__(self, text_input_list, batch_size=8):
35 |         ic(batch_size)
36 |         model_device = next(self.model.parameters()).device
37 |         ids = self.tokenizer(text_input_list)
38 | 
39 |         for k in ids:
40 |             ids[k] = torch.tensor(ids[k]).to(model_device)
41 |         outputs = []
42 |         with torch.no_grad():
43 |             for i in range(0, len(ids['input_ids']),batch_size):
44 |                 outputs.append(self.model(**{'input_ids':ids['input_ids'][i:i+batch_size], 'attention_mask':ids['attention_mask'][i:i+batch_size]}))
45 |         outputs = torch.cat(outputs).reshape([-1,5])
46 |         ic(outputs.shape)
47 |         return outputs
48 | 
49 |     def get_grad(self, text_input, loss_fn=CrossEntropyLoss()):
50 |         """Get gradient of loss with respect to input tokens.
51 |         Args:
52 |             text_input (str): input string
53 |             loss_fn (torch.nn.Module): loss function. Default is `torch.nn.CrossEntropyLoss`
54 |         Returns:
55 |             Dict of ids, tokens, and gradient as numpy array.
56 | """ 57 | 58 | if not hasattr(self.model, "get_input_embeddings"): 59 | raise AttributeError( 60 | f"{type(self.model)} must have method `get_input_embeddings` that returns `torch.nn.Embedding` object that represents input embedding layer" 61 | ) 62 | if not isinstance(loss_fn, torch.nn.Module): 63 | raise ValueError("Loss function must be of type `torch.nn.Module`.") 64 | 65 | self.model.train() 66 | 67 | embedding_layer = self.model.get_input_embeddings() 68 | original_state = embedding_layer.weight.requires_grad 69 | embedding_layer.weight.requires_grad = True 70 | 71 | emb_grads = [] 72 | 73 | def grad_hook(module, grad_in, grad_out): 74 | emb_grads.append(grad_out[0]) 75 | 76 | emb_hook = embedding_layer.register_backward_hook(grad_hook) 77 | 78 | self.model.zero_grad() 79 | model_device = next(self.model.parameters()).device 80 | ids = self.tokenizer([text_input]) 81 | for k in ids: 82 | ids[k] = torch.tensor(ids[k]).to(model_device) 83 | 84 | predictions = self.model(**ids) 85 | 86 | output = predictions.argmax(dim=1) 87 | loss = loss_fn(predictions, output) 88 | loss.backward() 89 | 90 | # grad w.r.t to word embeddings 91 | grad = torch.transpose(emb_grads[0], 0, 1)[0].cpu().numpy() 92 | 93 | embedding_layer.weight.requires_grad = original_state 94 | emb_hook.remove() 95 | self.model.eval() 96 | 97 | output = {"ids": ids[0].tolist(), "gradient": grad} 98 | 99 | return output 100 | 101 | def _tokenize(self, inputs): 102 | """Helper method that for `tokenize` 103 | Args: 104 | inputs (list[str]): list of input strings 105 | Returns: 106 | tokens (list[list[str]]): List of list of tokens as strings 107 | """ 108 | return [self.tokenizer.convert_ids_to_tokens(self.tokenizer(x)) for x in inputs] -------------------------------------------------------------------------------- /run/inference_testset.py: -------------------------------------------------------------------------------- 1 | from tqdm import trange 2 | import torch 3 | from icecream import ic 4 | import pandas as pd 5 | from warper import TokenizerWarper, ModelWarper 6 | import argparse 7 | 8 | max_length = 1024 9 | def generate_prompts(df, setup): 10 | perfix = [] 11 | postfix = [] 12 | for i in range(len(df)): 13 | if (int(df['label'].iloc[i])!=setup): 14 | continue 15 | current_str = df['text'].iloc[i].replace("\\n","\n") 16 | perfix.append(current_str.split("f{review_text}")[0]) 17 | postfix.append(current_str.split("f{review_text}")[-1]) 18 | ic(len(perfix)) 19 | print(perfix) 20 | print(postfix) 21 | return perfix, postfix 22 | 23 | if __name__=='__main__': 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("setup", help="the selected setup number") 27 | parser.add_argument("model", help="the name of selected model") 28 | parser.add_argument("part", help="use the first, second, third,forth 25%, of the dataset") 29 | #parser.add_argument("--output", help="the path of output file") 30 | parser.add_argument("--test", help="use the train dataset", action="store_true") 31 | parser.add_argument("--augment", help="use the develop dataset", action="store_true") 32 | args = parser.parse_args() 33 | 34 | from datasets import load_from_disk 35 | #dataset = load_from_disk('./yelp_review_full_split_train_dev') 36 | dataset_idx = 0 37 | if (args.test): 38 | subdataset = load_from_disk('./origin_dataset_test') 39 | dataset_name = 'test_finalize' 40 | elif (args.augment): 41 | subdataset = load_from_disk('./augmented_dataset_test') 42 | dataset_name = 'augment_finalize' 43 | dataset_idx += 2 44 | else: 45 | raise 
46 |     begin, end = 0, len(subdataset)
47 |     bounder = [0, end//4, end//2, end*3//4 ,end]
48 |     begin = bounder[int(args.part)-1]
49 |     end = bounder[int(args.part)]
50 |     ic(begin, end)
51 | 
52 |     model = ModelWarper(model_path = f"./{args.model}_saved/model")
53 |     output_path = f"./intermediate_data/s{args.setup}_{dataset_name}{int(args.part)}_{args.model}.npy"
54 |     ic(output_path)
55 |     import os
56 |     if (os.path.exists(output_path)):
57 |         ic("file already exists")
58 |         exit(0)
59 |     print(f"Model.device is: {model.device}")
60 |     setup = int(args.setup)
61 |     ic(setup)
62 |     if (setup<0 or setup>3):
63 |         raise Exception('incorrect setup number')
64 | 
65 |     final_dict = {
66 |         "dataset_idx":dataset_idx,
67 |         "setup":setup,
68 |         "label":[],
69 |         "text":[],
70 |         "answer":[],
71 |     }
72 |     import numpy as np
73 | 
74 |     #df = pd.read_csv('yelp_prompt_example.csv')
75 |     #df = pd.read_csv('finalized_prompt.csv')
76 |     df = pd.read_csv('probing_prompt.csv')
77 |     perfix_list, postfix_list = generate_prompts(df, setup)
78 |     tokenizer_list = []
79 |     for i in range(len(perfix_list)):
80 |         tokenizer_list.append(TokenizerWarper([perfix_list[i], postfix_list[i]], tokenizer_path = f"./{args.model}_saved/token"))
81 |     label_idx = np.array([352, 362, 513, 604, 642])
82 |     sum = 0
83 |     with torch.no_grad():
84 |         for i in trange(begin,end,4):
85 |             batch_data = subdataset[i:min(i+4, end)]
86 |             sum+=1
87 |             if (sum==10):
88 |                 #break
89 |                 pass
90 |             final_dict['label'].append(batch_data['label'])
91 |             final_dict['text'].append(batch_data['text'])
92 |             inputs = [tokenizer(batch_data['text']) for tokenizer in tokenizer_list]
93 |             outputs = []
94 |             for k in inputs:
95 |                 #if (len(k['input_ids'][0])>1024):
96 |                 outputs.append(model(**k))
97 | 
98 |             final_dict["answer"].append(outputs)
99 | 
100 |     a=np.array(final_dict)
101 |     np.save(output_path,a)
102 | 
103 | 
104 | 
105 | """
106 | from transformers import AutoTokenizer
107 | from transformers import GPT2LMHeadModel
108 | 
109 | for model_name in ['gpt2-xl']:
110 |     tokenizer = AutoTokenizer.from_pretrained(model_name)
111 |     model = GPT2LMHeadModel.from_pretrained(model_name)
112 | 
113 |     tokenizer.save_pretrained(f"./{model_name}_saved/token")
114 |     model.save_pretrained(f"./{model_name}_saved/model")
115 | """

--------------------------------------------------------------------------------
/prompt/prompt_augmenter.py:
--------------------------------------------------------------------------------
1 | from textattack.augmentation.recipes import EasyDataAugmenter
2 | from tqdm import trange
3 | import multiprocessing
4 | import argparse
5 | from icecream import ic
6 | from multiprocessing import Process, Queue
7 | import jsonlines
8 | import time
9 | import random
10 | import numpy as np
11 | import nltk
12 | import nltk.translate.gleu_score as gleu
13 | import numpy
14 | import os
15 | 
16 | try:
17 |     nltk.data.find('tokenizers/punkt')
18 | except LookupError:
19 |     nltk.download('punkt')
20 | 
21 | class MyProcess(Process): # subclasses Process
22 |     def __init__(self, name, sub_dataset, path):
23 |         super(MyProcess,self).__init__()
24 |         self.name = name
25 |         self.augmenter = EasyDataAugmenter()
26 |         self.augmenter.fast_augment = True
27 |         self.augmenter.high_yield = True
28 |         self.augmenter.transformations_per_example = 20
29 |         self.sub_dataset = sub_dataset
30 |         self.Q = q
31 |         self.path = path
32 | 
33 |     def run(self):
34 |         for datapoint in self.sub_dataset:
35 |             while True:
36 |                 try:
37 |                     result = self.augmenter.augment(datapoint['text'])
38 |                     break
39 |                 except:
40 |                     print(f"An error occurred at {self.name}")
41 |                     time.sleep(random.random())
42 |             datapoint['result'] = result
43 |             datapoint['name'] = self.name
44 |             ic(self.name)
45 |             while True:
46 |                 try:
47 |                     with jsonlines.open(self.path, 'a') as writer:
48 |                         writer.write(datapoint)
49 |                     break
50 |                 except:
51 |                     time.sleep(random.random())
52 |             ic(self.name+" finished writing file")
53 | 
54 | 
55 | 
56 | def augment_single_sentence(text, max_transformation_words, number_examples):
57 |     #ic(text)
58 |     augmenter = EasyDataAugmenter(pct_words_to_swap=max_transformation_words, transformations_per_example = number_examples)
59 |     augmenter.fast_augment = True
60 |     augmenter.high_yield = True
61 |     #augmenter.transformations_per_example = max_transformation_words
62 |     result = augmenter.augment(text)
63 |     return result
64 | 
65 | # Use an MT metric (GLEU) to evaluate the distance between the augmented sentence and the original sentence.
66 | def evaluate_distance(hyp, ref_b):
67 |     hyp = hyp.split()
68 |     ref_b = ref_b.split()
69 |     score_1to4grams = gleu.sentence_gleu([ref_b], hyp, min_len=1, max_len=4)
70 |     return score_1to4grams
71 | 
72 | 
73 | def generate_augmented_sentence(perfix, postfix):
74 |     # Input is one prompt (prefix/postfix pair); output is a [4, 10]-shaped list of (GLEU score, augmented prompt) pairs.
75 |     augmented_list = []
76 |     prompt = perfix+' f{review_text} '+postfix
77 |     for max_transformation_words in [0.05, 0.1, 0.2, 0.4]:
78 |         perfix_augmented_example = augment_single_sentence(perfix, max_transformation_words, 16)
79 |         postfix_augmented_example = augment_single_sentence(postfix, max_transformation_words, 16)
80 |         for i in range(min(len(perfix_augmented_example),len(postfix_augmented_example) )):
81 |             prompt_aug = perfix_augmented_example[i]+' f{review_text} '+postfix_augmented_example[i]
82 |             augmented_list.append((evaluate_distance(prompt, prompt_aug), prompt_aug))
83 |     augmented_list = sorted(augmented_list, reverse=True)
84 |     ic(len(augmented_list))
85 |     #ic(augmented_list[0])
86 |     #ic(augmented_list[1])
87 |     #ic(augmented_list[-2])
88 |     #ic(augmented_list[-1])
89 |     result_list = []
90 |     for i in range(0,len(augmented_list),16):
91 |         result_list.append(augmented_list[i:i+10])
92 |     return result_list
93 | # 1. split the prompt by perfix and postfix
94 | # 2. augment them separately
95 | # 3. concatenate them
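96 | # Illustrative usage (a sketch; the prefix/postfix strings here are made up):
97 | #   groups = generate_augmented_sentence("I wrote this review:", "My rating was")
98 | # 'groups' holds 4 bands cut from the GLEU-sorted candidate list (most similar
99 | # band first), each with up to 10 (GLEU score, augmented prompt) pairs.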
100 | 
101 | if __name__ == '__main__':
102 |     import pandas as pd
103 |     df = pd.read_csv('smaller_prompt_pool.csv')
104 |     ic(len(df))
105 |     from warper import generate_prompts
106 |     perfix_list, postfix_list = generate_prompts(df, 0)
107 |     df_dict = {
108 |         "text":[],
109 |         "label":[],
110 |         "score":[],
111 |         "group":[],
112 |     }
113 |     for i in range(6):
114 |         # Augment the i-th prompt; its label is taken from row i of the pool below.
115 |         augmented_dataset = generate_augmented_sentence(perfix_list[i], postfix_list[i])
116 |         for group_number in range(4):
117 |             for k in range(10):
118 |                 df_dict['group'].append(group_number+1)
119 |                 df_dict['text'].append(augmented_dataset[group_number][k][1])
120 |                 df_dict['score'].append(augmented_dataset[group_number][k][0])
121 |                 df_dict['label'].append(df['label'].iloc[i])
122 | 
123 |     new_df = pd.DataFrame.from_dict(df_dict)
124 |     new_df.to_csv('prompt_pool_augmented.csv')
125 | 
126 | 

--------------------------------------------------------------------------------
/run/inference_dataset_old_version.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import torch
3 | from icecream import ic
4 | import pandas as pd
5 | max_length = 1024
6 | def generate_prompts(df):
7 |     perfix = []
8 |     postfix = []
9 |     for i in range(len(df)):
10 |         current_str = df['text'].iloc[i].replace("\\n","\n")
11 |         perfix.append(current_str.split("\n\n\n\n")[0]+"\n\n")
12 |         postfix.append("\n\n"+current_str.split("\n\n\n\n")[-1])
13 |     return perfix, postfix
14 | 
15 | if __name__=='__main__':
16 |     from datasets import load_from_disk
17 |     dataset = load_from_disk('./yelp_review_full_split_train_dev')
18 |     #dataset.save_to_disk('./yelp_review_full')
19 | 
20 |     from transformers import AutoTokenizer
21 |     from transformers import GPT2LMHeadModel
22 |     tokenizer = AutoTokenizer.from_pretrained("./gpt2-large_saved/token")
23 |     tokenizer.pad_token = tokenizer.eos_token
24 |     model = GPT2LMHeadModel.from_pretrained('./gpt2-large_saved/model')
25 | 
26 |     #tokenizer.save_pretrained('./gpt2-large_saved/token')
27 |     #model.save_pretrained('./gpt2-large_saved/model')
28 | 
29 |     device = 'cuda:0'
30 |     model = model.to(device)
31 | 
32 |     print(f"Model.device is: {model.device}")
33 | 
34 |     final_dict = {
35 |         "label":[],
36 |         "text":[],
37 |         "answer":[],
38 |     }
39 |     import numpy as np
40 | 
41 |     df = pd.read_csv('yelp_prompt_arguement_full_version.csv')
42 |     perfix_list, postfix_list = generate_prompts(df)
43 |     """
44 |     perfix_list = ["I just finished eating at a restaurant. Then I opened my Yelp app.\n\nI first gave a rating, and then wrote the following review:\n\n",
45 |                    "I just finished eating at a restaurant. Then I opened my Yelp app, and wrote the following review: \n\n",
46 |                    "I opened my Yelp app, and started to read some reviews of a restaurant that I want to try. I saw a user wrote this review:\n\n"]
47 |     postfix_list = ["\n\nThe review is an explanation of why I gave a rating (out of 1 to 5 stars) of",
48 |                     "\n\nThen I gave the rating.
In terms of 1 to 5 stars, I think this restaurant is worth a", 49 | "\n\nIn terms of 1 to 5 stars, I think this user rated it a"] 50 | """ 51 | label_idx = np.array([352, 362, 513, 604, 642]) 52 | perfix_input = [] 53 | for i in perfix_list: 54 | perfix_input.append(tokenizer(i, return_tensors='pt')) 55 | postfix_input = [] 56 | for i in postfix_list: 57 | postfix_input.append(tokenizer("..."+i, return_tensors='pt')) 58 | sum=0 59 | ic(dataset['selected_train'][9]) 60 | with torch.no_grad(): 61 | for i in tqdm(dataset['selected_train']): 62 | sum+=1 63 | if (sum==10): 64 | #break 65 | pass 66 | final_dict['label'].append(i['label']) 67 | final_dict['text'].append(i['text']) 68 | current_text = [] 69 | answer_list = [] 70 | for j in range(len(perfix_list)): 71 | current_text.append(perfix_list[j]+i['text']+postfix_list[j]) 72 | input = tokenizer(current_text, return_tensors='pt', padding=True) 73 | #ic(input['input_ids'].shape) 74 | if (len(input['input_ids'][0])>max_length): 75 | text_input = tokenizer(i['text'], return_tensors='pt') 76 | for k in input: 77 | input[k] = [] 78 | for j in range(len(perfix_input)): 79 | #perfix_input = tokenizer(perfix_list[j], return_tensors='pt') 80 | #postfix_input = tokenizer("..."+postfix_list[j], return_tensors='pt') 81 | ic(perfix_input[j][k][0],text_input[k][0,:max_length-len(perfix_input[j][k][0])-len(postfix_input[j][k][0])],postfix_input[j][k][0]) 82 | input[k].append(torch.cat([perfix_input[j][k][0],text_input[k][0,:max_length-len(perfix_input[j][k][0])-len(postfix_input[j][k][0])],postfix_input[j][k][0]],dim=0)) 83 | input[k]=torch.cat(input[k]).reshape([-1, max_length]).contiguous() 84 | ic(input[k].shape) 85 | for k in input: 86 | input[k] = input[k].to(device) 87 | output = model(**input)[0].cpu().detach().numpy()[:,:, label_idx] 88 | answer = [] 89 | for i in range(len(perfix_list)): 90 | selected_output = output[i][input['attention_mask'][i].cpu().detach().numpy()==1][-1] 91 | answer.append(selected_output) 92 | #TODO 93 | #ic(len(answer)) 94 | #ic(answer[0].shape) 95 | #ic(output.shape) 96 | final_dict[f"answer"].append(output) 97 | #ic(i['label'], output) 98 | a=np.array(final_dict) 99 | np.save("/cluster/project/sachan/zhiheng/causal_prompting/intermediate_data/predict_train_14.npy",a) 100 | 101 | 102 | -------------------------------------------------------------------------------- /eval/select_evaluation.py: -------------------------------------------------------------------------------- 1 | # This code is writed for train all the offset and testing the accuracy of all the prompts 2 | # and select the best one 3 | # 1. 
Concate the structure from the 4 | #TODO 5 | import numpy as np 6 | import sklearn.metrics 7 | from sklearn.metrics import accuracy_score, f1_score 8 | from collections import Counter 9 | from icecream import ic 10 | import pandas as pd 11 | from tqdm import trange 12 | final_answer = [] 13 | 14 | def generate_prompts(df): 15 | prompt_list = [] 16 | for i in range(len(df)): 17 | current_str = df['text'].iloc[i].replace("\\n","\n") 18 | perfix = current_str.split("f{review_text}")[0] 19 | postfix = current_str.split("f{review_text}")[-1] 20 | label = df['group'].iloc[i] 21 | prompt_list.append({ 22 | 'perfix':perfix, 23 | 'postfix':postfix, 24 | 'label':label, 25 | }) 26 | return prompt_list 27 | 28 | def eval_given_offset(result_list, offset): 29 | pass 30 | 31 | def eval_and_offset(result_list): 32 | # len(result_list) = 20,000, the first 10,000 is train set, used to generate the offset 33 | # the last 10,000 is the dev set, used to evaluate the performance 34 | #ic(result_list[0:10]) 35 | train_list, dev_list = result_list[:len(result_list)//2],result_list[len(result_list)//2:] 36 | gt_labels = [i[1] for i in train_list] 37 | predict_prob = np.array([i[0] for i in train_list]) 38 | predict_labels = predict_prob.argmax(axis=1) 39 | 40 | 41 | gt_labels_dev = [i[1] for i in dev_list] 42 | predict_prob_dev = np.array([i[0] for i in dev_list]) 43 | predict_labels_dev = predict_prob_dev.argmax(axis=1) 44 | 45 | print("###########before normalize#############") 46 | ic(Counter(gt_labels), Counter(predict_labels)) 47 | ic(accuracy_score(gt_labels, predict_labels), f1_score(gt_labels, predict_labels, average='weighted')) 48 | 49 | ic(Counter(gt_labels_dev), Counter(predict_labels_dev)) 50 | ic(accuracy_score(gt_labels_dev, predict_labels_dev), f1_score(gt_labels_dev, predict_labels_dev, average='weighted')) 51 | 52 | offset = np.array([0.0,0.0,0.0,0.0,0.0]) 53 | lr = 0.01 54 | cgt = dict(Counter(gt_labels)) 55 | for i in range(1000): 56 | predict_labels = predict_prob.argmax(axis=1) 57 | predict_labels_dev = predict_prob_dev.argmax(axis=1) 58 | flag = False 59 | for label in range(5): 60 | delta = lr*(1-list(predict_labels).count(label)/cgt[label]) 61 | if (delta!=0): 62 | flag = True 63 | offset[label]+=delta 64 | predict_prob[:,label]+=delta 65 | predict_prob_dev[:,label]+=delta 66 | #ic(offset[label], delta, flag) 67 | if (not flag): 68 | break 69 | 70 | predict_labels = predict_prob.argmax(axis=1) 71 | predict_labels_dev = predict_prob_dev.argmax(axis=1) 72 | print("###########after normalize#############") 73 | ic(Counter(gt_labels), Counter(predict_labels)) 74 | ic(accuracy_score(gt_labels, predict_labels), f1_score(gt_labels, predict_labels, average='weighted')) 75 | ic(Counter(gt_labels_dev), Counter(predict_labels_dev)) 76 | ic(accuracy_score(gt_labels_dev, predict_labels_dev), 77 | f1_score(gt_labels_dev, predict_labels_dev, average='weighted')) 78 | return offset, accuracy_score(gt_labels, predict_labels), accuracy_score(gt_labels_dev, predict_labels_dev) 79 | def renormalize_data(prompt_list, origin_data_paths = "/cluster/project/sachan/zhiheng/causal_prompting/intermediate_data/"): 80 | #dim1, setup number, 3; dim2, number of prompt, 10; dim3, 20000, prediction of each prompt 81 | whole_dataset = [[],[],[],[]] 82 | for setup_number in [1,2,3]: 83 | for prompt in prompt_list: 84 | if (prompt['label'] == setup_number): 85 | whole_dataset[setup_number].append([]) 86 | for subdataset_name in ['train', 'dev']: 87 | for sub_dataset in [1,2]: 88 | file_name = 
f"s{setup_number}_{subdataset_name}{sub_dataset}_selected.npy" 89 | result_path = origin_data_paths + file_name 90 | ic(result_path) 91 | a = np.load(result_path, allow_pickle=True) 92 | result_dict = a.tolist() 93 | ic(result_dict.keys()) 94 | for batch_id in trange(len(result_dict['answer'])): 95 | for prompt_id in range(len(result_dict['answer'][batch_id])): 96 | batch_label = result_dict['label'][batch_id] 97 | batch_predict = result_dict['answer'][batch_id][prompt_id] 98 | #ic(batch_label, batch_predict) 99 | for i in range(len(batch_label)): 100 | whole_dataset[setup_number][prompt_id].append((batch_predict[i], batch_label[i])) 101 | 102 | for prompt_id in range(len(whole_dataset[setup_number])): 103 | #ic(whole_dataset[setup_number][prompt_id][0:10]) 104 | offset, train_accuracy, accuracy = eval_and_offset(whole_dataset[setup_number][prompt_id]) 105 | print(f"For setup {setup_number}, prompt {prompt_id}, the final accuracy is {accuracy}, offset is {offset}") 106 | final_answer.append(f"For setup {setup_number}, prompt {prompt_id},the train accuracy is {train_accuracy}, the dev accuracy is {accuracy}") 107 | 108 | if __name__=='__main__': 109 | df = pd.read_csv('selected_prompt.csv') 110 | prompt_list = generate_prompts(df) 111 | renormalized_data = renormalize_data(prompt_list) 112 | sum = 0 113 | for i in final_answer: 114 | sum+=1 115 | print(sum, i) -------------------------------------------------------------------------------- /model/warper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from transformers import AutoTokenizer 4 | from transformers import GPT2LMHeadModel 5 | import sys 6 | from icecream import ic 7 | sys.path.append('/cluster/project/sachan/zhiheng/causal_prompting') 8 | 9 | def generate_prompts(df, setup): 10 | perfix = [] 11 | postfix = [] 12 | for i in range(len(df)): 13 | if (int(df['group'].iloc[i])!=setup and setup!=0): 14 | continue 15 | current_str = df['text'].iloc[i].replace("\\n","\n") 16 | perfix.append(current_str.split(" f{review_text} ")[0]) 17 | postfix.append(current_str.split(" f{review_text} ")[-1]) 18 | ic(len(perfix)) 19 | print(perfix) 20 | print(postfix) 21 | return perfix, postfix 22 | 23 | class ModelWarper(nn.Module): 24 | def __init__(self, offset = [0,0,0,0,0], model_path = './gpt2-large_saved/model'): 25 | super().__init__() 26 | self.model = GPT2LMHeadModel.from_pretrained(model_path) 27 | self.device = 'cpu' 28 | if (torch.cuda.is_available()): 29 | self.device = 'cuda:0' 30 | self.model = self.model.to(self.device) 31 | self.label_idx = torch.tensor([352, 362, 513, 604, 642]) 32 | #self.offset = torch.tensor(offset).to(self.device) 33 | 34 | def forward(self, **kwargs): 35 | for i in kwargs: 36 | kwargs[i] = kwargs[i].to(self.device) 37 | #ic(kwargs[i].shape) 38 | output = self.model(**kwargs)[0][:, :, self.label_idx] 39 | # ic(output) 40 | answer = [] 41 | for i in range(len(kwargs['input_ids'])): 42 | selected_output = output[i][kwargs['attention_mask'][i]== 1][-1].cpu().detach().numpy() 43 | answer.append(selected_output) 44 | #ic(answer) 45 | return answer#+self.offset 46 | 47 | # Use the maxlength method to warp the sentence 48 | class TokenizerWarper(object): 49 | def __init__(self, prompt, tokenizer_path = "./gpt2-large_saved/token"): 50 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, truncation = True) 51 | self.tokenizer.pad_token = self.tokenizer.eos_token 52 | self.prompt = prompt 53 | """ 54 | self.perfix_token = 
self.tokenizer(prompt[0]+"\"") 55 | self.postfix_token = self.tokenizer("\" "+prompt[1]) 56 | self.postfix_token_2 = self.tokenizer("...\" "+prompt[1]) 57 | """ 58 | 59 | self.perfix_token = self.tokenizer(prompt[0]) 60 | self.postfix_token = self.tokenizer(" "+prompt[1]) 61 | self.postfix_token_2 = self.tokenizer("... "+prompt[1]) 62 | self.max_length = 1024 63 | 64 | def padding(self, padded_dict, max_length): 65 | if (len(padded_dict['input_ids'])>=max_length): 66 | #ic(len(padded_dict['input_ids'])) 67 | pass 68 | for i in padded_dict: 69 | padded_dict[i] = padded_dict[i] + [0]*(max_length-len(padded_dict[i])) 70 | return padded_dict 71 | 72 | def encode_single_sentence(self, str): 73 | #str = '\n'.join(str.split('\\n')) 74 | mid = self.tokenizer(str) 75 | perfix_token = self.perfix_token 76 | postfix_token = self.postfix_token 77 | postfix_token_2 = self.postfix_token_2 78 | if (len(perfix_token['input_ids'])+len(mid['input_ids'])+len(postfix_token['input_ids'])>self.max_length): 79 | #ic(len(perfix_token['input_ids'])+len(mid['input_ids'])+len(postfix_token['input_ids'])) 80 | pass 81 | if (len(perfix_token['input_ids'])+len(mid['input_ids'])+len(postfix_token['input_ids'])>self.max_length): 82 | return { 83 | 'input_ids':perfix_token['input_ids']+mid['input_ids'][:self.max_length-len(perfix_token['input_ids'])-len(postfix_token_2['input_ids'])]+postfix_token_2['input_ids'], 84 | 'attention_mask':perfix_token['attention_mask']+mid['attention_mask'][:self.max_length-len(perfix_token['attention_mask'])-len(postfix_token_2['attention_mask'])]+postfix_token_2['attention_mask'], 85 | } 86 | else: 87 | return { 88 | 'input_ids':perfix_token['input_ids']+mid['input_ids']+postfix_token['input_ids'], 89 | 'attention_mask':perfix_token['attention_mask']+mid['attention_mask']+postfix_token['attention_mask'], 90 | } 91 | 92 | def encode(self, str): 93 | answer = { 94 | "input_ids":[], 95 | "attention_mask":[], 96 | } 97 | if (type(str)==list): 98 | max_length = 0 99 | answers = [] 100 | for i in str: 101 | answers.append(self.encode_single_sentence(i)) 102 | max_length = max(max_length, len(answers[-1]['input_ids'])) 103 | max_length = min(max_length, self.max_length) 104 | for subanswer in answers: 105 | paded_subanswer = self.padding(subanswer, max_length) 106 | answer['input_ids'].append(subanswer['input_ids']) 107 | answer['attention_mask'].append(subanswer['attention_mask']) 108 | ic(max_length) 109 | else: 110 | subanswer = self.encode_single_sentence(str) 111 | answer['input_ids'].append(subanswer['input_ids']) 112 | answer['attention_mask'].append(subanswer['attention_mask']) 113 | return answer 114 | return {'input_ids': torch.tensor(answer['input_ids']), 'attention_mask': torch.tensor(answer['attention_mask'])} 115 | 116 | def __call__(self, str): 117 | return self.encode(str) 118 | 119 | tokenizer = TokenizerWarper(["I just finished eating at a restaurant. Then I opened my Yelp app.\n\nI first gave a rating, and then wrote the following review:\n\n", 120 | "\n\nThe review is an explanation of why I gave a rating (out of 1 to 5 stars) of"]) 121 | 122 | """ 123 | perfix_list = ["I just finished eating at a restaurant. Then I opened my Yelp app.\n\nI first gave a rating, and then wrote the following review:\n\n", 124 | "I just finished eating at a restaurant. Then I opened my Yelp app, and wrote the following review: \n\n", 125 | "I opened my Yelp app, and started to read some reviews of a restaurant that I want to try. 

tokenizer = TokenizerWarper(["I just finished eating at a restaurant. Then I opened my Yelp app.\n\nI first gave a rating, and then wrote the following review:\n\n",
                             "\n\nThe review is an explanation of why I gave a rating (out of 1 to 5 stars) of"])

"""
perfix_list = ["I just finished eating at a restaurant. Then I opened my Yelp app.\n\nI first gave a rating, and then wrote the following review:\n\n",
               "I just finished eating at a restaurant. Then I opened my Yelp app, and wrote the following review: \n\n",
               "I opened my Yelp app, and started to read some reviews of a restaurant that I want to try. I saw a user wrote this review:\n\n"]
postfix_list = ["\n\nThe review is an explanation of why I gave a rating (out of 1 to 5 stars) of",
                "\n\nThen I gave the rating. In terms of 1 to 5 stars, I think this restaurant is worth a",
                "\n\nIn terms of 1 to 5 stars, I think this user rated it a"]
tokenizer = 
"""
"""
model_gpt = ModelWarper(offset=[-0.35403400000000423,
                                0.8431049999999987,
                                0.10486800000000042,
                                -0.3882869999999973,
                                -0.20565200000000122])
import textattack
import textattack.models
import textattack.models.wrappers
from text_attack_pytorch_warper import PyTorchModelWrapper as tmw
model = tmw(model_gpt, tokenizer)
from icecream import ic
ic(model)
"""

def generate_final_dict(model, tokenizer, dataset):
    # TODO: unimplemented stub.
    final_dict = {}
    pass

if __name__=='__main__':

    from datasets import load_from_disk
    dataset = load_from_disk('./yelp_review_full_split_train_dev')

    ic(dataset)
    ic(dataset['train'][0:1]['text'])
    tokenizer = TokenizerWarper(["I just finished eating at a restaurant. Then I opened my Yelp app.\n\nI first gave a rating, and then wrote the following review:\n\n",
                                 "\n\nThe review is an explanation of why I gave a rating (out of 1 to 5 stars) of"])
    tokens = tokenizer.encode(dataset['train'][0:1]['text'])
    ic(tokens)
    model = ModelWarper(offset=[-0.35403400000000423,
                                0.8431049999999987,
                                0.10486800000000042,
                                -0.3882869999999973,
                                -0.20565200000000122])
    ic(model(**tokens))
--------------------------------------------------------------------------------
/dataset/align_dataset.py:
--------------------------------------------------------------------------------
from icecream import ic
import pandas as pd
from tqdm import trange
import numpy as np
from collections import Counter
from sklearn.metrics import accuracy_score, f1_score

def align_dataset(model_name, subdataset_name, base_dataset, prompt_list, out_dataset_path):
    # whole_dataset layout: [setup (1-3)][prompt][example] -> per-example prediction.
    dataset_dict = base_dataset[:]  # materialize the dataset as a dict of columns

    origin_data_paths = "./intermediate_data/"

    whole_dataset = [[], [], [], []]
    for setup_number in [1, 2, 3]:
        for prompt in prompt_list:
            if (prompt['label'] == setup_number):
                whole_dataset[setup_number].append([])
        for sub_dataset in ["", "1", "2", "3", "4"]:
            file_name = f"s{setup_number}_{subdataset_name}{sub_dataset}_{model_name}.npy"
            result_path = origin_data_paths + file_name
            #ic(result_path)
            try:
                a = np.load(result_path, allow_pickle=True)
            except FileNotFoundError:
                ic("there is no " + result_path)
                continue

            result_dict = a.tolist()
            #ic(result_dict.keys())
            for batch_id in trange(len(result_dict['answer'])):
                for prompt_id in range(len(result_dict['answer'][batch_id])):
                    batch_label = result_dict['label'][batch_id]
                    batch_predict = result_dict['answer'][batch_id][prompt_id]
                    # ic(batch_label, batch_predict)
                    for i in range(len(batch_label)):
                        whole_dataset[setup_number][prompt_id].append(list(batch_predict[i]))
    dataset_dict['result'] = []
    for i in range(len(dataset_dict['label'])):
        dataset_dict['result'].append([])
        for setup_number in range(1, 4, 1):
            for prompt_id in range(2):  # assumes two prompts per setup here
                #ic(setup_number, prompt_id)
                #ic(whole_dataset[setup_number][prompt_id])
                if (len(whole_dataset[setup_number][prompt_id]) > i):
                    dataset_dict['result'][-1].append(whole_dataset[setup_number][prompt_id][i])
                else:
                    dataset_dict['result'][-1].append(None)
    from datasets import Dataset
    dataset = Dataset.from_dict(dataset_dict)
    ic(dataset[-3:])
    ic(dataset)
    dataset.save_to_disk(out_dataset_path)
    return dataset
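
# Layout sketch (an assumption, inferred from the loops above) of one result
# file s{setup}_{split}{part}_{model}.npy: a pickled dict of batched outputs,
#
#   {'answer': [batch][prompt_id][example] -> 5 rating logits,
#    'label':  [batch][example]            -> gold star (0-4)}
#
# so iterating batches, then prompts, then examples flattens each prompt's
# predictions into one list per (setup, prompt) pair, aligned with the labels.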

def load_whole_result(model_name, subdataset_name):
    # whole_dataset layout: [setup (1-4)][prompt][example] -> (prediction, label).
    # NOTE: relies on the module-level prompt_list defined under __main__.
    origin_data_paths = "./intermediate_data/"
    whole_dataset = [[], [], [], [], []]
    for setup_number in [1, 2, 3, 4]:
        for prompt in prompt_list:
            if (prompt['label'] == setup_number):  # the original version keyed on 'group'
                whole_dataset[setup_number].append([])
        for sub_dataset in ["", "1", "2", "3", "4"]:
            file_name = f"s{setup_number}_{subdataset_name}{sub_dataset}_{model_name}_pools_aug.npy"
            result_path = origin_data_paths + file_name
            #ic(result_path)
            try:
                a = np.load(result_path, allow_pickle=True)
            except FileNotFoundError:
                ic("there is no " + result_path)
                continue

            result_dict = a.tolist()
            #ic(result_dict.keys())

            ic(len(result_dict['answer'][0]))
            for batch_id in trange(len(result_dict['answer'])):
                for prompt_id in range(len(result_dict['answer'][batch_id])):
                    batch_label = result_dict['label'][batch_id]
                    batch_predict = result_dict['answer'][batch_id][prompt_id]
                    # ic(batch_label, batch_predict)
                    for i in range(len(batch_label)):
                        whole_dataset[setup_number][prompt_id].append((list(batch_predict[i]), batch_label[i]))
    """
    ic(len(whole_dataset[1]))
    ic(len(whole_dataset[2]))
    ic(len(whole_dataset[3]))
    ic(len(whole_dataset[1][0]))
    ic(len(whole_dataset[3][0]))
    ic(whole_dataset[3][0][0:5])
    """
    ic(len(whole_dataset[1][0]))
    return whole_dataset

def generate_offset(result_list):
    # Fit a per-class offset on the given split and report its performance.
    train_list = result_list
    gt_labels = [i[1] for i in train_list]
    predict_prob = np.array([i[0] for i in train_list])
    predict_labels = predict_prob.argmax(axis=1)

    ic(Counter(gt_labels), Counter(predict_labels))
    ic(accuracy_score(gt_labels, predict_labels), f1_score(gt_labels, predict_labels, average='weighted'))

    offset = np.array([0.0, 0.0, 0.0, 0.0, 0.0])
    lr = 0.01
    cgt = dict(Counter(gt_labels))
    # Optimize the offsets: nudge each class score up (down) while the class is
    # predicted less (more) often than it appears in the ground truth.
    for i in range(10000):
        predict_labels = predict_prob.argmax(axis=1)
        flag = False
        for label in range(5):
            delta = lr * (1 - list(predict_labels).count(label) / cgt[label])
            if (delta != 0):
                flag = True
            offset[label] += delta
            predict_prob[:, label] += delta
        if (not flag):
            break
    predict_labels = predict_prob.argmax(axis=1)
    print("###########after optimize#############")
    ic(Counter(gt_labels), Counter(predict_labels))
    ic(accuracy_score(gt_labels, predict_labels), f1_score(gt_labels, predict_labels, average='weighted'))
    ic(offset)
    return offset, [accuracy_score(gt_labels, predict_labels), f1_score(gt_labels, predict_labels, average='weighted')]
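
# Worked example of the update rule above (illustrative numbers): with
# lr = 0.01, if class 2 receives 1,000 of 10,000 predictions but holds
# 2,000 gold labels, its offset grows by 0.01 * (1 - 1000/2000) = 0.005
# in that step; an over-predicted class gets a negative delta. The loop
# stops once every predicted class frequency matches its ground-truth
# frequency exactly (every delta is 0), or after 10,000 steps.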

def evaluate_performance(result_list, offset):
    # Apply a previously fitted offset, then score accuracy and weighted F1.
    gt_labels = [i[1] for i in result_list]
    predict_prob = np.array([i[0] + offset for i in result_list])
    #ic(predict_prob[0:10])
    #ic(result_list[0:10])
    predict_labels = predict_prob.argmax(axis=1)
    ic(Counter(gt_labels), Counter(predict_labels))
    ic(accuracy_score(gt_labels, predict_labels), f1_score(gt_labels, predict_labels, average='weighted'))
    return [accuracy_score(gt_labels, predict_labels), f1_score(gt_labels, predict_labels, average='weighted')]

def generate_robustness_KL(result_list, result_list_augment, augmented_dataset, offset):
    # Robustness metric: mean KL divergence between the softmaxed prediction on
    # an original example and the prediction on its augmented counterpart.
    from scipy.special import softmax
    def KL(px, py):
        return np.sum(px * np.log(px / py))
    gt_labels = [i[1] for i in result_list]
    predict_prob = np.array([softmax(i[0] + offset) for i in result_list])
    predict_prob_augment = np.array([softmax(i[0] + offset) for i in result_list_augment])
    KL_list = []
    for i in trange(len(result_list_augment)):
        KL_list.append(KL(predict_prob[augmented_dataset[i]['origin_data_idx']], predict_prob_augment[i]))
    #ic(KL_list[0:10])
    ic(np.average(KL_list))
    return np.average(KL_list)

def evaluate_prompt(train_list, dev_list, test_list, augment_list, augmented_dataset):
    # Fit the offset on the train split only, then reuse it on every other split.
    ic("Train set performance")
    evals = []
    offset, scores = generate_offset(train_list)
    evals.append(scores)
    ic("Dev set performance")
    scores = evaluate_performance(dev_list, offset)
    evals.append(scores)
    if (test_list is None):
        return evals
    ic("Test set performance")
    scores = evaluate_performance(test_list, offset)
    evals.append(scores)
    if (augment_list is None):
        return evals
    ic("Aug set performance")
    scores = evaluate_performance(augment_list, offset)
    evals.append(scores)
    #ic("Aug set KL_divergence")
    #scores = generate_robustness_KL(test_list, augment_list, augmented_dataset, offset)
    #evals.append(scores)
    return evals



if __name__=='__main__':
    from select_evaluation import generate_prompts
    #df = pd.read_csv('prompt_pool.csv')
    #prompt_list = generate_prompts(df)
    #ic(len(prompt_list))
    model_name = 'gpt2-large'
    #subdataset_name = 'test'
    from datasets import load_from_disk
    import datasets
    #base_dataset = load_from_disk("/cluster/project/sachan/zhiheng/causal_prompting/origin_dataset_test")
    #augmented_dataset = load_from_disk("/cluster/project/sachan/zhiheng/causal_prompting/augmented_dataset_test")
    df = pd.read_csv('prompt_pool_augmented.csv')
    prompt_list = generate_prompts(df)
    #out_dataset_path = "/cluster/project/sachan/zhiheng/causal_prompting/gpt2-medium_aligned_dataset_test"
    #align_dataset(model_name, subdataset_name, base_dataset, prompt_list, out_dataset_path)
    train_list = load_whole_result("gpt2-large", 'train')
    dev_list = load_whole_result("gpt2-large", 'dev')
    #ic(len(train_list[2]))
    #test_list = load_whole_result("gpt2-large", 'test')
    #augment_list = load_whole_result("gpt2-large", 'augment_finalize')
    #ic(len(augment_list[1][1]))
    #ic(len(augment_list[2][1]))
    #ic(len(augment_list[3][1]))
    final_result = []
    #generate_offset(augment_list[1][0])
    #final_result.append(evaluate_prompt(train_list[1][0], dev_list[1][0], test_list[1][0], augment_list[1][0], augmented_dataset))
    ic(len(train_list[0]), len(train_list[1]), len(train_list[2]), len(train_list[3]))
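    # Protocol recap (summarizing the calls above and below): for each
    # (setup, prompt) pair, evaluate_prompt fits the per-class offset on the
    # train split and reuses it unchanged on the dev split, so the calibration
    # never sees the data it is scored on.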
    """
    final_result.append(evaluate_prompt(train_list[1][3], dev_list[1][3], None, None, None))
    final_result.append(evaluate_prompt(train_list[1][0], dev_list[1][0], None, None, None))
    final_result.append(evaluate_prompt(train_list[2][1], dev_list[2][1], None, None, None))
    final_result.append(evaluate_prompt(train_list[1][1], dev_list[1][1], None, None, None))
    final_result.append(evaluate_prompt(train_list[2][0], dev_list[2][0], None, None, None))
    final_result.append(evaluate_prompt(train_list[1][2], dev_list[1][2], None, None, None))
    final_result.append(evaluate_prompt(train_list[2][2], dev_list[2][2], None, None, None))
    """

    """
    final_result.append(evaluate_prompt(train_list[1][0], dev_list[1][0], None, None, None))
    final_result.append(evaluate_prompt(train_list[1][1], dev_list[1][1], None, None, None))
    final_result.append(evaluate_prompt(train_list[1][2], dev_list[1][2], None, None, None))
    final_result.append(evaluate_prompt(train_list[1][3], dev_list[1][3], None, None, None))
    final_result.append(evaluate_prompt(train_list[1][4], dev_list[1][4], None, None, None))
    final_result.append(evaluate_prompt(train_list[2][0], dev_list[2][0], None, None, None))
    final_result.append(evaluate_prompt(train_list[2][1], dev_list[2][1], None, None, None))
    final_result.append(evaluate_prompt(train_list[2][2], dev_list[2][2], None, None, None))
    """
    # Evaluate every prompt of setup 4 only (i != 4 skips the other setups).
    for i in range(5):
        print(len(train_list[i]))
        if (i != 4):
            continue
        for j in range(len(train_list[i])):
            final_result.append(evaluate_prompt(train_list[i][j], dev_list[i][j], None, None, None))

    print(final_result)
--------------------------------------------------------------------------------