├── data_fns
│   ├── __init__.py
│   ├── data_fns_fewrel.py
│   ├── data_fns_yelp.py
│   ├── data_fns_featurebased_checklists.py
│   └── data_fns_override_checklists.py
├── config
│   └── task-finetuning.yaml
├── PATCH_DIR
│   └── override_patch_data
│       └── process_into_applies.py
├── README.md
├── train_models.py
├── feature-checklist-baselines.ipynb
├── helpers.py
├── finetuning_experiments.py
├── environment.yml
├── eval_utils.py
├── orig_model_results.ipynb
├── training_utils.py
├── patch_dataset.py
├── model.py
├── convert_yaml_to_data.py
├── override-checklist-experiments.ipynb
└── override_patches_sentiment.ipynb

/data_fns/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_fns_yelp import *
2 | from .data_fns_fewrel import *
3 | from .data_fns_override_checklists import *
4 | from .data_fns_featurebased_checklists import *
--------------------------------------------------------------------------------
/config/task-finetuning.yaml:
--------------------------------------------------------------------------------
1 | mode: 'train'
2 | data: 'toy'
3 | use_as: 'sentiment-classifier'
4 | model_type: t5-large
5 | prompt_style: 'p1'
6 | train:
7 |   num_warmup_steps: 500
8 |   save_path: 'gpt2-finetuned'
9 |   num_epochs: 3
10 |   lr: 1e-4
11 |   train_batch_size: 2
12 |   accum_steps: 16
13 |   eval_batch_size: 64
--------------------------------------------------------------------------------
/PATCH_DIR/override_patch_data/process_into_applies.py:
--------------------------------------------------------------------------------
1 | import json
2 | 
3 | DEFAULT_EXP = 'default noop explanation'
4 | def get_cond(explanation):
5 |     return ' '.join(explanation.split(',')[0].split(' ')[1:])
6 | 
7 | with open('synthetic_data_old.json', 'r') as reader:
8 |     data = json.load(reader)
9 | 
10 | 
11 | 
12 | new_data = {key: [] for key in data}
13 | all_explanations = set()
14 | for exp, instance, label in zip(data['explanations'], data['instances'], data['labels']):
15 |     if exp == DEFAULT_EXP:
16 |         new_data['labels'].append(0)  # instances paired with the noop explanation are negatives for the "patch applies" head, so they get label 0
17 | new_data['instances'].append(instance) 18 | new_data['explanations'].append(exp) 19 | else: 20 | new_data['labels'].append(1) 21 | cond = get_cond(exp) 22 | new_data['instances'].append(instance) 23 | new_data['explanations'].append(cond) 24 | 25 | all_explanations.add(cond) 26 | 27 | for exp in all_explanations: 28 | print(exp) 29 | 30 | with open('synthetic_data.json', 'w') as writer: 31 | json.dump(new_data,writer) 32 | print(len(new_data['explanations'])) 33 | -------------------------------------------------------------------------------- /data_fns/data_fns_fewrel.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | 3 | # setting path 4 | from checklist.editor import Editor 5 | from collections import defaultdict as ddict 6 | import random 7 | import pandas as pd 8 | 9 | editor = Editor() 10 | 11 | import re 12 | 13 | ### Data slice from FewRel 14 | def get_fewrel_data(): 15 | neg_relations = [ 16 | "P3373", 17 | "P22", 18 | "P6", 19 | "P57", 20 | "P674", 21 | "P1344", 22 | "P22", 23 | "P991", 24 | "P106", 25 | "P463", 26 | "P40", 27 | "P25", 28 | "P108", 29 | ] 30 | pos_relations = ["P26"] 31 | dataset = load_dataset("few_rel", "default") 32 | keys = ["train_wiki", "val_wiki", "val_semeval"] 33 | 34 | def get_all(data, relation_types): 35 | filtered = [] 36 | for ex in data: 37 | if ex["relation"] in relation_types: 38 | filtered.append(ex) 39 | return filtered 40 | 41 | def process(ex): 42 | if ex[0][-1] != ".": 43 | ex[0] = "{}.".format(ex[0]) 44 | return "{} Entity1: {}. Entity2: {}".format(ex[0], ex[1], ex[2]) 45 | 46 | positives = [] 47 | negatives = [] 48 | for key in keys: 49 | negatives += get_all(dataset[key], neg_relations) 50 | for key in keys: 51 | positives += get_all(dataset[key], pos_relations) 52 | 53 | positive_data = [ 54 | process([" ".join(l["tokens"]), l["head"]["text"], l["tail"]["text"]]) 55 | for l in positives 56 | ] 57 | negative_data = [ 58 | process([" ".join(l["tokens"]), l["head"]["text"], l["tail"]["text"]]) 59 | for l in negatives 60 | ] 61 | return positive_data, negative_data 62 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Setup 2 | 3 | Install all dependencies using `conda`: 4 | 5 | ``` 6 | conda env create -f environment.yml 7 | conda activate lang-patching 8 | pip install -e . 9 | ``` 10 | 11 | ## Training Pipeline 12 | Note that this repository uses `hydra` for managing hyperparameters and experiments. The configs we use for training can be found under `config`. `hydra` creates unique outputs for every experiment under the directory `output`. 13 | 14 | ### Creating Patch Finetuning Data 15 | To start, create synthetic data for patch finetuning using the yaml format (some examples are in the `PATCH_DIR` folder), and then use `convert_yaml_to_data.py` to create json files. The JSON files used in our experiments can be found in the `PATCH_DIR` folder. 16 | 17 | ### Training Patchable Models 18 | 19 | The entry script for training patchable models is `train_models.py`. 
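Since `train_models.py` is a standard `hydra` entry point, any key in `config/task-finetuning.yaml` can be overridden from the command line, and keys not already in the config are added with a `+` prefix; for example (the values here are only illustrative):

```
python train_models.py train.lr=5e-5 train.num_epochs=5
```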
Run it as:
20 | 
21 | ``` python train_models.py train.save_path={SAVE_PATH} +protocol={protocol} +patch_type={SUB_FOLDER} +multitask_sst=True +train.load_path={TASK_FINETUNED_MODEL} +learnt_interpreter={True/False}```
22 | 
23 | - {SAVE_PATH}: path where the patchable model will be saved
24 | - {protocol}: one of
25 |   - `simple`: train a model on just the task ("Task Finetuning")
26 |   - `patch_finetune_conds`: train a patchable model for Sentiment Classification
27 |   - `patch_re`: train a patchable model for Relation Extraction
28 | - {SUB_FOLDER}: one of the folders in the PATCH_DIR directory. To train models with override patches, use `override_patch_data`; to train models with feature-based patches, use `feature_based_patch_data`.
29 | - learnt_interpreter: set this to `True` when training with feature-based patches.
30 | 
31 | 
32 | 
33 | ### Model Checkpoints
34 | Checkpoints for models used in this work can be found at this [link](https://drive.google.com/drive/folders/1TWdPW7QS6um21cDlBzH26-gs3fkDnoat?usp=share_link). We also provide notebooks to reproduce the tables in the paper. To reproduce results:
35 | - For Table 2, see the notebooks with `checklist` in the name
36 | - For Tables 3, 4, and 5, follow the instructions in the notebooks `override_patches_sentiment.ipynb` and `orig_model_results.ipynb`
37 | - For Figure 4, use `finetuning_experiments.py`
38 | 
39 | 
40 | To cite this paper, use:
41 | ```
42 | @inproceedings{murty2022patches,
43 |     title = "Fixing Model Bugs with Natural Language Patches",
44 |     author = "Murty, Shikhar and
45 |       Manning, Christopher and
46 |       Lundberg, Scott and
47 |       Ribeiro, Marco Tulio",
48 |     booktitle = "Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing",
49 |     year = "2022",
50 | }
51 | ```
52 | 
--------------------------------------------------------------------------------
/train_models.py:
--------------------------------------------------------------------------------
1 | # Entry script: Train language models so that they can be patched with natural language at test time
2 | 
3 | import random
4 | import os
5 | import torch
6 | import logging
7 | import numpy as np
8 | from transformers import AutoTokenizer
9 | 
10 | import json
11 | from training_utils import train_loop
12 | 
13 | # for managing experiments
14 | import hydra
15 | from hydra.utils import get_original_cwd
16 | 
17 | 
18 | from model import T5Interpeter, T5ForConditionalGenerationMultipleHeads
19 | from build_datasets import (
20 |     get_data_task_finetuning,
21 |     get_data_patch_finetuning_sentiment,
22 |     get_data_patch_finetuning_re,
23 | )
24 | 
25 | import wandb
26 | 
27 | ### SET THIS AS YOUR OWN PROJECT!
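# (wandb is used here only for experiment tracking; the `entity` below belongs to
#  the authors, so replace it with your own wandb entity, or set WANDB_MODE=offline
#  in the environment to keep runs local)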
28 | wandb.init(project="patches", entity="shikharmurty") 29 | 30 | 31 | @hydra.main(config_path="config", config_name="task-finetuning") 32 | def main(cfg): 33 | log = logging.getLogger(__name__) 34 | # save the config to wandb run 35 | wandb.config = cfg 36 | # make the model name descriptive enough so it doubles as a run name 37 | wandb.run.name = cfg.train.save_path 38 | wandb.run.save() 39 | # set seed 40 | random.seed(cfg.get("seed", 42)) 41 | orig_working_dir = get_original_cwd() 42 | 43 | model_type = cfg.model_type 44 | model = T5ForConditionalGenerationMultipleHeads.from_pretrained(model_type) 45 | tokenizer = AutoTokenizer.from_pretrained(model_type) 46 | 47 | data_protocol = cfg.get("protocol", "simple") 48 | if "patch" in data_protocol and not cfg.get("learnt_interpreter", False): 49 | primary_mode = "patch_applies_predictor" 50 | else: 51 | primary_mode = "task_predictor" 52 | if cfg.train.get("load_path", None): 53 | load_path = cfg.train.load_path 54 | log.info("loading a checkpoint from {}".format(load_path)) 55 | try: 56 | model.load_state_dict(torch.load(os.path.join(orig_working_dir, load_path))) 57 | model_obj = T5Interpeter( 58 | model, tokenizer, primary_mode=primary_mode, train_multihead=True 59 | ) 60 | except: 61 | model_obj = T5Interpeter( 62 | model, tokenizer, primary_mode=primary_mode, train_multihead=True 63 | ) 64 | model_obj.load_state_dict( 65 | torch.load(os.path.join(orig_working_dir, load_path)), strict=False 66 | ) 67 | 68 | else: 69 | model_obj = T5Interpeter( 70 | model, tokenizer, primary_mode=primary_mode, train_multihead=True 71 | ) 72 | 73 | if torch.cuda.is_available(): 74 | device = torch.device("cuda") 75 | model_obj.to(device) 76 | 77 | ### Get data for Task Finetuning stage 78 | if data_protocol == "simple": 79 | train_data, val_data = get_data_task_finetuning(cfg, tokenizer) 80 | ### Get data for Patch Finetuning stage 81 | elif data_protocol == "patch_finetune_conds": 82 | train_data, val_data = get_data_patch_finetuning_sentiment(cfg, tokenizer) 83 | elif data_protocol == "patch_re": 84 | train_data, val_data = get_data_patch_finetuning_re(cfg, tokenizer) 85 | 86 | wandb.watch(model_obj) 87 | if cfg.data == "spouse_re" or data_protocol == "patch_re": 88 | metric = "f1" 89 | elif cfg.get("learnt_interpreter", False): 90 | metric = "task_data_patch_acc" 91 | else: 92 | metric = "acc" 93 | train_loop(model_obj, log, cfg.train, train_data, val_data, metric) 94 | 95 | 96 | if __name__ == "__main__": 97 | main() 98 | -------------------------------------------------------------------------------- /feature-checklist-baselines.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "0c03b5a6", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext autoreload\n", 11 | "%autoreload 2" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "8de87b68", 18 | "metadata": {}, 19 | "outputs": [ 20 | { 21 | "name": "stderr", 22 | "output_type": "stream", 23 | "text": [ 24 | "Some weights of T5ForConditionalGenerationMultipleHeads were not initialized from the model checkpoint at t5-large and are newly initialized: ['encoder.embed_tokens.weight']\n", 25 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 26 | ] 27 | }, 28 | { 29 | "name": "stdout", 30 | "output_type": "stream", 31 | "text": [ 32 | "primary mode: 
exp_applies_predictor\n",
33 |      "splicing parts from pretrained model\n"
34 |     ]
35 |    },
36 |    {
37 |     "name": "stderr",
38 |     "output_type": "stream",
39 |     "text": [
40 |      "Some weights of T5ForConditionalGeneration were not initialized from the model checkpoint at t5-large and are newly initialized: ['encoder.embed_tokens.weight']\n",
41 |      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
42 |     ]
43 |    }
44 |   ],
45 |   "source": [
46 |    "from eval_utils import load_model\n",
47 |    "path_name = '/u/scr/smurty/LanguageExplanations/trained_models/t5-large-sst-no-exp'\n",
48 |    "model_obj = load_model(path_name, primary_mode='exp_applies_predictor')"
49 |   ]
50 |  },
51 |  {
52 |   "cell_type": "code",
53 |   "execution_count": null,
54 |   "id": "ff4c7555",
55 |   "metadata": {},
56 |   "outputs": [],
57 |   "source": [
58 |    "from collections import Counter\n",
59 |    "from eval_utils import predict_stuff\n",
60 |    "import itertools\n",
61 |    "import numpy as np\n",
62 |    "\n",
63 |    "\n",
64 |    "def single_prompt_pred(model, data, prompt, examine=False):\n",
65 |    "    prompted_inps = [(prompt, ex) for ex in data[0]]\n",
66 |    "    probs = predict_stuff(prompted_inps, itertools.repeat(0), \n",
67 |    "                          model, 'p1', verbose=False, mode='task_predictor')\n",
68 |    "    #print(np.mean(probs.argmax(axis=1)==data[1]))\n",
69 |    "    return probs.argmax(axis=1)\n",
70 |    "\n",
71 |    "def get_prompting_scores(model_obj, data, prompt_set, examine=False):\n",
72 |    "    orig_preds = single_prompt_pred(model_obj, data, '')\n",
73 |    "    all_preds = []\n",
74 |    "    for prompt in prompt_set:\n",
75 |    "        all_preds.append(single_prompt_pred(model_obj, data, prompt))\n",
76 |    "    p = np.stack(all_preds, axis=1)\n",
77 |    "    model_preds = []\n",
78 |    "    for x in p:\n",
79 |    "        counts = Counter(x)\n",
80 |    "        model_preds.append(counts[1] > counts[0])\n",
81 |    "    return np.array(model_preds), orig_preds"
82 |   ]
83 |  },
84 |  {
85 |   "cell_type": "code",
86 |   "execution_count": null,
87 |   "id": "c00d54d1",
88 |   "metadata": {},
89 |   "outputs": [],
90 |   "source": [
91 |    "from data_fns import knowledge_absn\n",
92 |    "inputs, labels, patches = knowledge_absn(abstraction=True)\n",
93 |    "\n",
94 |    "model_preds, orig_preds = get_prompting_scores(model_obj, (inputs, labels), patches)\n",
95 |    "print(np.mean(model_preds == orig_preds))"
96 |   ]
97 |  },
98 |  {
99 |   "cell_type": "code",
100 |   "execution_count": null,
101 |   "id": "39797ed6",
102 |   "metadata": {},
103 |   "outputs": [],
104 |   "source": [
105 |    "from data_fns import knowledge_checklists\n",
106 |    "out, patches = knowledge_checklists(abstraction=True)\n",
107 |    "inputs = []\n",
108 |    "labels = []\n",
109 |    "for key in out:\n",
110 |    "    inputs += out[key]['instances']\n",
111 |    "    labels += out[key]['labels']"
112 |   ]
113 |  },
114 |  {
115 |   "cell_type": "code",
116 |   "execution_count": null,
117 |   "id": "e46acfec",
118 |   "metadata": {},
119 |   "outputs": [],
120 |   "source": [
121 |    "model_preds, orig_preds = get_prompting_scores(model_obj, (inputs, labels), patches)\n",
122 |    "print(np.mean(model_preds == labels))\n",
123 |    "print(np.mean(orig_preds == labels))"
124 |   ]
125 |  },
126 |  {
127 |   "cell_type": "code",
128 |   "execution_count": null,
129 |   "id": "37d0a75b",
130 |   "metadata": {},
131 |   "outputs": [],
132 |   "source": []
133 |  }
134 | ],
135 | "metadata": {
136 |  "kernelspec": {
137 |   "display_name": "shikhar-basic",
138 |   "language": "python",
139 |   "name": "shikhar-basic"
140 |  },
141 |  "language_info": {
142 |   "codemirror_mode": {
143 |    "name": 
"ipython", 144 | "version": 3 145 | }, 146 | "file_extension": ".py", 147 | "mimetype": "text/x-python", 148 | "name": "python", 149 | "nbconvert_exporter": "python", 150 | "pygments_lexer": "ipython3", 151 | "version": "3.8.10" 152 | } 153 | }, 154 | "nbformat": 4, 155 | "nbformat_minor": 5 156 | } 157 | -------------------------------------------------------------------------------- /data_fns/data_fns_yelp.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | 3 | from collections import defaultdict as ddict 4 | import random 5 | import pandas as pd 6 | 7 | from checklist.editor import Editor 8 | import re 9 | 10 | editor = Editor() 11 | ### Data slices from Yelp 12 | def get_yelp_data(conflicting=False): 13 | df = pd.read_csv("compare_model_steering_labeled.csv") 14 | fix_quotes = lambda i: i if i[-1] != "'" else i[:-1] 15 | df_inputs = [fix_quotes(cinput) for cinput in df["Input"]] 16 | df_labels = [label for label in df["Label"]] 17 | df_service_labels = [label for label in df["service-label"]] 18 | df_food_labels = [label for label in df["food-label"]] 19 | 20 | label_word_dict = {1: "positive", 0: "negative"} 21 | 22 | def get_val(key): 23 | if key in label_word_dict: 24 | return label_word_dict[key] 25 | else: 26 | return "NAN" 27 | 28 | def subset(clist, idxs): 29 | return [clist[idx] for idx in idxs] 30 | 31 | def conflict(ldict): 32 | return ldict[0]["polarity"] != ldict[1]["polarity"] 33 | 34 | df_label_dict_list = [ 35 | [ 36 | {"category": "service", "polarity": get_val(df_service_labels[idx])}, 37 | {"category": "food", "polarity": get_val(df_food_labels[idx])}, 38 | ] 39 | for idx, _ in enumerate(df_labels) 40 | ] 41 | if conflicting: 42 | idxs = [idx for idx, ldict in enumerate(df_label_dict_list) if conflict(ldict)] 43 | return ( 44 | subset(df_inputs, idxs), 45 | subset(df_label_dict_list, idxs), 46 | subset(df_labels, idxs), 47 | ) 48 | else: 49 | return df_inputs, df_label_dict_list, df_labels 50 | 51 | 52 | def get_all_yelp_data(): 53 | yelp_dataset = load_dataset("yelp_polarity", split="train") 54 | yelp_dataset = [ 55 | {"sentence": ex["text"], "label": ex["label"]} 56 | for ex in yelp_dataset 57 | if len(ex["text"]) < 240 58 | ] 59 | return ( 60 | [ex["sentence"] for ex in yelp_dataset], 61 | [ex["label"] for ex in yelp_dataset], 62 | None, 63 | ) 64 | 65 | 66 | def get_yelp_colloquial_control(): 67 | all_templates = [] 68 | all_templates.append( 69 | editor.template( 70 | "The bomb was {verb} by the {auth} at the {place}", 71 | verb=["found", "diffused"], 72 | auth=["waiter", "waitress", "manager", "police"], 73 | place=["restaurant", "cafe", "bar"], 74 | ) 75 | ) 76 | all_templates.append( 77 | editor.template( 78 | "The {auth} {verb} the bomb at the {place}", 79 | verb=["found", "diffused"], 80 | auth=["waiter", "waitress", "manager", "police"], 81 | place=["restaurant", "cafe", "bar"], 82 | ) 83 | ) 84 | all_templates.append( 85 | editor.template( 86 | "The {placeorperson} bombed the order", 87 | placeorperson=["restaurant", "waiter", "server", "waitress"], 88 | ) 89 | ) 90 | all_templates.append( 91 | editor.template( 92 | "The {place} looked like a bomb had exploded", 93 | place=["restaurant", "cafe", "pub", "bar"], 94 | ) 95 | ) 96 | all_templates.append( 97 | editor.template( 98 | "The {person} was a dope and kept forgetting people's orders", 99 | person=["waiter", "server", "manager", "chef"], 100 | ) 101 | ) 102 | all_templates.append( 103 | editor.template( 104 | "The {person} seemed 
like he was on dope", 105 | person=["waiter", "server", "manager", "chef"], 106 | ) 107 | ) 108 | all_templates.append( 109 | editor.template( 110 | "The {aspect} was quite shitty", aspect=["food", "service", "ambience"] 111 | ) 112 | ) 113 | 114 | examples = [] 115 | labels = [] 116 | for template in all_templates: 117 | examples += template.data 118 | labels += [0] * len(template.data) 119 | return examples, labels 120 | 121 | 122 | def get_yelp_colloquial(): 123 | def balance(inps, labels): 124 | label2idx = ddict(list) 125 | for idx, label in enumerate(labels): 126 | label2idx[label].append(idx) 127 | min_size = min(len(val) for _, val in label2idx.items()) 128 | pos_idxs = random.sample(label2idx[1], k=min_size) 129 | neg_idxs = random.sample(label2idx[0], k=min_size) 130 | return [inps[idx] for idx in pos_idxs + neg_idxs], [ 131 | labels[idx] for idx in pos_idxs + neg_idxs 132 | ] 133 | 134 | examples, labels, _ = get_all_yelp_data() 135 | terms = ["wtf", "omg", "the shit", "bomb", "dope", "suck"] 136 | filtered_idxs = [] 137 | for idx, ex in enumerate(examples): 138 | if any(word in ex.lower() for word in terms): 139 | filtered_idxs.append(idx) 140 | 141 | examples = [examples[idx] for idx in filtered_idxs] 142 | labels = [labels[idx] for idx in filtered_idxs] 143 | examples_b, labels_b = balance(examples, labels) 144 | return examples_b, labels_b 145 | 146 | 147 | def get_yelp_stars(): 148 | examples, labels, _ = get_all_yelp_data() 149 | idxs = [idx for idx, ex in enumerate(examples) if re.search(r"\bstars?\b", ex)] 150 | 151 | patches = [ 152 | "If review gives more than 3 stars, then sentiment is positive and if review gives less than 3 stars then sentiment is negative", 153 | "", 154 | ] 155 | return [examples[idx] for idx in idxs], [labels[idx] for idx in idxs], patches 156 | -------------------------------------------------------------------------------- /helpers.py: -------------------------------------------------------------------------------- 1 | from checklist.editor import Editor 2 | import os 3 | import re 4 | import pickle 5 | from datasets import Dataset as HFDataset 6 | from hydra.utils import get_original_cwd 7 | 8 | editor = Editor() 9 | pos_adjectives = ["good", "great", "amazing", "wonderful", "awesome"] 10 | neg_adjectives = ["bad", "pathetic", "awful", "terrible", "horrid"] 11 | # add in some gibberish words 12 | gibberish_adjectives = ["zonker", "wonker", "zubin", "wugly", "shug"] 13 | all_pos_adjectives = pos_adjectives + gibberish_adjectives 14 | all_neg_adjectives = neg_adjectives + gibberish_adjectives 15 | 16 | restaurant_patches = [ 17 | ("food is more important for determining sentiment than service", "food"), 18 | ( 19 | "service is more important for determining sentiment than food", 20 | "service", 21 | ), 22 | ("quality of food determines sentiment", "food"), 23 | ("quality of service determines sentiment", "service"), 24 | ("food matters more than service for determining sentiment", "food"), 25 | ("service matters more than food for determining sentiment", "service"), 26 | ] 27 | 28 | 29 | # this is the null prompt 30 | def prompt_style_0(patch, sentence): 31 | if len(patch) > 0: 32 | return (patch, sentence) 33 | else: 34 | return sentence 35 | 36 | 37 | def make_re_transforms(patch, sentence): 38 | text, ent_info = sentence.split(" Entity1:") 39 | e1, e2 = ent_info.split(". 
Entity2:") 40 | 41 | e1 = e1.strip() 42 | e2 = e2.strip() 43 | exp_new = patch.replace("Entity1", e1).replace("Entity2", e2) 44 | return exp_new, text 45 | 46 | 47 | def prompt_style_1(patch, sentence): 48 | if len(patch) > 0: 49 | # if 'Entity1:' in sentence: 50 | # patch, sentence = make_re_transforms(patch, sentence) 51 | out = "patch: {}. Input: {}".format(patch, sentence) 52 | else: 53 | out = "Input: {}".format(sentence) 54 | out = out.rstrip() 55 | return out[:-1].rstrip() if out[-1] == "." else out 56 | 57 | 58 | def prompt_style_reverse(patch, sentence): 59 | if len(patch) > 0: 60 | out = "Input: {}. patch: {}".format(sentence, patch) 61 | else: 62 | out = "Input: {}".format(sentence) 63 | out = out.rstrip() 64 | return out[:-1].rstrip() if out[-1] == "." else out 65 | 66 | 67 | def prompt_style_2(patch, sentence): 68 | if len(patch) > 0: 69 | return "Steering hints: %s. '%s'" % (patch, sentence) 70 | else: 71 | return "'%s'" % sentence 72 | 73 | 74 | def prompt_style_1_exp_applies(patch, sentence): 75 | out = "patch: {}. Input: {}".format(patch, sentence) 76 | out = out.rstrip() 77 | return out[:-1] if out[-1] == "." else out 78 | 79 | 80 | prompt_styles = { 81 | "p0": prompt_style_0, 82 | "p1": prompt_style_1, 83 | "p1_reverse": prompt_style_reverse, 84 | "p1_exp_applies": prompt_style_1_exp_applies, 85 | "p2": prompt_style_2, 86 | } 87 | 88 | 89 | def convert_to_tensors(data_list, tokenizer): 90 | dataset = { 91 | "sentence": [ex for ex, _ in data_list], 92 | "label": tokenizer([label for _, label in data_list])["input_ids"], 93 | } 94 | dataset = HFDataset.from_dict(dataset) 95 | 96 | tokenize_func = lambda examples: tokenizer(examples["sentence"], truncation=True) 97 | tensored_dataset = dataset.map( 98 | tokenize_func, batched=True, remove_columns=["sentence"] 99 | ) 100 | return tensored_dataset 101 | 102 | 103 | def on_azure(): 104 | return any("AZURE" in key for key in os.environ) 105 | 106 | 107 | def get_spouse_data(split, prompt_style, use_percent=1.0): 108 | return get_re_data(split, "SPOUSE_DATA", prompt_style, use_percent) 109 | 110 | 111 | def verbalize_all(data, prompt_style="p1"): 112 | label_verbalizer = {1: "positive", 0: "negative"} 113 | pf = prompt_styles[prompt_style] 114 | verbalized_data = [] 115 | for ex, label in data: 116 | verbalized_ex = pf("", ex) 117 | verbalized_label = label_verbalizer[int(label)] 118 | verbalized_data.append((verbalized_ex, verbalized_label)) 119 | return verbalized_data 120 | 121 | 122 | from itertools import chain, combinations, permutations 123 | 124 | 125 | def powerset(iterable): 126 | "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)" 127 | s = list(iterable) 128 | return chain.from_iterable(permutations(s, r) for r in range(len(s) + 1)) 129 | 130 | 131 | def verbalize_examples( 132 | examples, prompt_style="p0", labels_given=False, label_name_dict=None 133 | ): 134 | verbalized_examples = [] 135 | if labels_given: 136 | if label_name_dict is None: 137 | label_name_dict = {0: "negative", 1: "positive"} 138 | 139 | for example in examples: 140 | if type(example) == tuple: 141 | text, label = example 142 | assert labels_given 143 | verbalized_label = label_name_dict[label] 144 | else: 145 | text = example 146 | verbalized_label = None 147 | 148 | verbalized_text = text 149 | 150 | if verbalized_label: 151 | verbalized_examples.append((verbalized_text, verbalized_label)) 152 | else: 153 | verbalized_examples.append(verbalized_text) 154 | return verbalized_examples 155 | 156 | 157 | def 
symbol_at_eos(examples):
158 |     all_data = []
159 |     for ex in examples:
160 |         cinput = ex["sentence"].rstrip() + " :z"
161 |         patch = ":z at the end of input changes sentiment"
162 |         label = 1 - int(ex["label"] > 0.5)
163 |         all_data.append((cinput, label, patch))
164 |     return all_data
--------------------------------------------------------------------------------
/finetuning_experiments.py:
--------------------------------------------------------------------------------
1 | # ### How many labeled examples is patching worth? ####
2 | 
3 | 
4 | from eval_utils import predict_stuff
5 | from data_fns import (
6 |     get_yelp_colloquial_control,
7 |     get_yelp_stars,
8 |     get_yelp_colloquial,
9 |     get_fewrel_data,
10 | )
11 | from eval_utils import load_model
12 | import numpy as np
13 | import argparse
14 | import random
15 | import re
16 | import torch
17 | from eval_utils import fewshot_finetune
18 | 
19 | 
20 | def compare(inputs, labels, exp, path_name):
21 |     model = load_model(path_name)
22 |     inputs_without_exp = [("", inp) for inp in inputs]
23 |     baseline_1_out = predict_stuff(
24 |         inputs_without_exp, labels, model, verbose=True, prompt_style="p1"
25 |     )
26 |     baseline_1_labels = baseline_1_out.argmax(axis=1)
27 | 
28 |     inputs_with_exp = [(exp, inp) for inp in inputs]
29 |     baseline_1_out_withexp = predict_stuff(
30 |         inputs_with_exp, labels, model, verbose=True, prompt_style="p1"
31 |     )
32 |     baseline_1_labels_withexp = baseline_1_out_withexp.argmax(axis=1)
33 | 
34 |     print(np.mean(baseline_1_labels == np.array(labels)))
35 |     print(np.mean(baseline_1_labels_withexp == np.array(labels)))
36 | 
37 | 
38 | # keep 80% of yelp stars for testing, and let's say we get 20% for finetuning. what happens?
39 | def get_sample(inputs, labels, sample_size):
40 |     all_idxs = [idx for idx, _ in enumerate(inputs)]
41 |     sampled_idxs = set(random.sample(all_idxs, k=sample_size))
42 | 
43 |     train_inputs = [inputs[idx] for idx, _ in enumerate(inputs) if idx in sampled_idxs]
44 |     test_inputs = [
45 |         inputs[idx] for idx, _ in enumerate(inputs) if idx not in sampled_idxs
46 |     ]
47 | 
48 |     train_labels = [labels[idx] for idx, _ in enumerate(inputs) if idx in sampled_idxs]
49 |     test_labels = [
50 |         labels[idx] for idx, _ in enumerate(inputs) if idx not in sampled_idxs
51 |     ]
52 |     return (train_inputs, train_labels), (test_inputs, test_labels)
53 | 
54 | 
55 | def get_fewshot_curve(train_data, test_data, path_name, metric="acc"):
56 |     # for all tasks this is how we get subsets to stay balanced
57 |     def get_subset_balanced(num_examples):
58 |         shots = num_examples // 2
59 |         pos_ex = [
60 |             train_data[0][idx] for idx, label in enumerate(train_data[1]) if label == 1
61 |         ]
62 |         neg_ex = [
63 |             train_data[0][idx] for idx, label in enumerate(train_data[1]) if label == 0
64 |         ]
65 |         chosen_pos = random.sample(pos_ex, k=shots)
66 |         chosen_neg = random.sample(neg_ex, k=shots)
67 |         labels = [1] * len(chosen_pos) + [0] * len(chosen_neg)
68 |         return chosen_pos + chosen_neg, labels
69 | 
70 |     # for relation extraction though, we get subsets like this, since we don't want to balance the dataset!
71 | def get_subset_re(num_examples): 72 | all_idxs = [idx for idx, _ in enumerate(train_data[0])] 73 | sampled_idxs = random.sample(all_idxs, k=num_examples) 74 | chosen_inps = [train_data[0][idx] for idx in sampled_idxs] 75 | chosen_labels = [train_data[1][idx] for idx in sampled_idxs] 76 | return chosen_inps, chosen_labels 77 | 78 | all_acc_dict = {} 79 | for num_examples in [2, 4, 8, 16, 32, 64, 128]: 80 | try: 81 | if metric == "acc": 82 | train_subset = get_subset_balanced(num_examples) 83 | else: 84 | train_subset = get_subset_re(num_examples) 85 | except: 86 | break 87 | curr_acc = fewshot_finetune( 88 | path_name, 89 | update_steps=64, 90 | train_tuple_list=train_subset, 91 | val_tuple_list=test_data, 92 | metric=metric, 93 | ) 94 | all_acc_dict[num_examples] = curr_acc 95 | print(curr_acc) 96 | return all_acc_dict 97 | 98 | 99 | def set_seed(rand_seed): 100 | random.seed(rand_seed) 101 | np.random.seed(rand_seed) 102 | torch.manual_seed(rand_seed) 103 | torch.cuda.manual_seed_all(rand_seed) 104 | 105 | 106 | def main(): 107 | parser = argparse.ArgumentParser() 108 | parser.add_argument("--type", type=str, default="stars") 109 | parser.add_argument("--seed", type=int, default=42) 110 | parser.add_argument("--path_name", type=str) 111 | parser.add_argument("--sample_size", type=int, default=256) 112 | 113 | args = parser.parse_args() 114 | 115 | set_seed(args.seed) 116 | if args.type == "stars": 117 | star_inputs, star_labels, _ = get_yelp_stars() 118 | star_train_data, star_test_data = get_sample( 119 | star_inputs, star_labels, args.sample_size 120 | ) 121 | all_data_dict = get_fewshot_curve( 122 | star_train_data, star_test_data, args.path_name 123 | ) 124 | with open("finetuning_logs/{}_{}.txt".format(args.type, args.seed), "w") as f: 125 | f.write(str(all_data_dict)) 126 | elif args.type == "spouse_nyt": 127 | pos, neg = get_fewrel_data() 128 | labels = [1] * len(pos) + [0] * len(neg) 129 | inputs = pos + neg 130 | train_data, test_data = get_sample(inputs, labels, args.sample_size) 131 | all_data_dict = get_fewshot_curve( 132 | train_data, test_data, args.path_name, metric="f1" 133 | ) 134 | with open("finetuning_logs/{}_{}.txt".format(args.type, args.seed), "w") as f: 135 | f.write(str(all_data_dict)) 136 | elif "yelp_colloquial" in args.type: 137 | inputs, labels = get_yelp_colloquial() 138 | print(len(inputs)) 139 | train_data, test_data = get_sample(inputs, labels, args.sample_size) 140 | if "control" in args.type: 141 | test_data = get_yelp_colloquial_control() 142 | all_data_dict = get_fewshot_curve(train_data, test_data, args.path_name) 143 | with open("finetuning_logs/{}_{}.txt".format(args.type, args.seed), "w") as f: 144 | f.write(str(all_data_dict)) 145 | 146 | 147 | if __name__ == "__main__": 148 | main() 149 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: lang-patching 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=4.5=1_gnu 9 | - backcall=0.2.0=py_0 10 | - blas=1.0=mkl 11 | - bzip2=1.0.8=h7b6447c_0 12 | - ca-certificates=2021.10.26=h06a4308_2 13 | - certifi=2021.10.8=py38h06a4308_0 14 | - cudatoolkit=11.3.1=h2bc3f7f_2 15 | - ffmpeg=4.3=hf484d3e_0 16 | - freetype=2.11.0=h70c0345_0 17 | - giflib=5.2.1=h7b6447c_0 18 | - gmp=6.2.1=h2531618_2 19 | - gnutls=3.6.15=he1e5248_0 20 | - intel-openmp=2021.4.0=h06a4308_3561 21 | - 
ipython_genutils=0.2.0=py38_0 22 | - jedi=0.18.0=py38h06a4308_1 23 | - jpeg=9d=h7f8727e_0 24 | - jupyter_client=6.1.7=py_0 25 | - jupyter_core=4.6.3=py38_0 26 | - lame=3.100=h7b6447c_0 27 | - lcms2=2.12=h3be6417_0 28 | - ld_impl_linux-64=2.35.1=h7274673_9 29 | - libffi=3.3=he6710b0_2 30 | - libgcc-ng=9.3.0=h5101ec6_17 31 | - libgomp=9.3.0=h5101ec6_17 32 | - libiconv=1.15=h63c8f33_5 33 | - libidn2=2.3.2=h7f8727e_0 34 | - libpng=1.6.37=hbc83047_0 35 | - libsodium=1.0.18=h7b6447c_0 36 | - libstdcxx-ng=9.3.0=hd4cf53a_17 37 | - libtasn1=4.16.0=h27cfd23_0 38 | - libtiff=4.2.0=h85742a9_0 39 | - libunistring=0.9.10=h27cfd23_0 40 | - libuv=1.40.0=h7b6447c_0 41 | - libwebp=1.2.0=h89dd481_0 42 | - libwebp-base=1.2.0=h27cfd23_0 43 | - lz4-c=1.9.3=h295c915_1 44 | - mkl=2021.4.0=h06a4308_640 45 | - mkl-service=2.4.0=py38h7f8727e_0 46 | - mkl_fft=1.3.1=py38hd3c417c_0 47 | - mkl_random=1.2.2=py38h51133e4_0 48 | - ncurses=6.2=he6710b0_1 49 | - nettle=3.7.3=hbbd107a_1 50 | - olefile=0.46=pyhd3eb1b0_0 51 | - openh264=2.1.0=hd408876_0 52 | - openssl=1.1.1l=h7f8727e_0 53 | - pexpect=4.8.0=py38_0 54 | - pickleshare=0.7.5=py38_1000 55 | - pip=21.1.3=py38h06a4308_0 56 | - python=3.8.10=h12debd9_8 57 | - python-dateutil=2.8.1=py_0 58 | - pytorch=1.10.0=py3.8_cuda11.3_cudnn8.2.0_0 59 | - pytorch-mutex=1.0=cuda 60 | - readline=8.1=h27cfd23_0 61 | - setuptools=52.0.0=py38h06a4308_0 62 | - sqlite=3.36.0=hc218d9a_0 63 | - tk=8.6.10=hbc83047_0 64 | - traitlets=5.0.5=py_0 65 | - typing_extensions=3.10.0.2=pyh06a4308_0 66 | - wcwidth=0.2.5=py_0 67 | - wheel=0.36.2=pyhd3eb1b0_0 68 | - xz=5.2.5=h7b6447c_0 69 | - zeromq=4.3.3=he6710b0_3 70 | - zlib=1.2.11=h7b6447c_3 71 | - zstd=1.4.9=haebb681_0 72 | - pip: 73 | - absl-py==1.0.0 74 | - aiohttp==3.7.4.post0 75 | - allennlp==2.9.3 76 | - antlr4-python3-runtime==4.8 77 | - argon2-cffi==20.1.0 78 | - async-generator==1.10 79 | - async-timeout==3.0.1 80 | - attrs==21.2.0 81 | - backports-csv==1.0.7 82 | - base58==2.1.1 83 | - beautifulsoup4==4.10.0 84 | - black==22.10.0 85 | - bleach==3.3.0 86 | - blis==0.7.4 87 | - boto3==1.23.1 88 | - botocore==1.26.1 89 | - cached-path==1.1.2 90 | - cachetools==5.0.0 91 | - captum==0.4.0 92 | - catalogue==2.0.6 93 | - cffi==1.14.5 94 | - chardet==4.0.0 95 | - checklist==0.0.11 96 | - cheroot==8.5.2 97 | - cherrypy==18.6.1 98 | - click==8.1.3 99 | - cloudpickle==2.0.0 100 | - configparser==5.0.2 101 | - conllu==4.5.2 102 | - cryptography==3.4.8 103 | - cycler==0.10.0 104 | - cymem==2.0.5 105 | - datasets==2.6.1 106 | - debugpy==1.3.0 107 | - decorator==5.0.9 108 | - defusedxml==0.7.1 109 | - dill==0.3.4 110 | - docker-pycreds==0.4.0 111 | - emoji==1.6.1 112 | - en-core-web-sm==3.1.0 113 | - entrypoints==0.3 114 | - evaluate==0.3.0 115 | - fairscale==0.4.6 116 | - feedparser==6.0.8 117 | - filelock==3.6.0 118 | - fsspec==2022.10.0 119 | - future==0.18.2 120 | - gitdb==4.0.7 121 | - gitpython==3.1.24 122 | - google-api-core==2.7.3 123 | - google-auth==2.6.2 124 | - google-auth-oauthlib==0.4.6 125 | - google-cloud-core==2.3.0 126 | - google-cloud-storage==2.3.0 127 | - google-crc32c==1.3.0 128 | - google-resumable-media==2.3.2 129 | - googleapis-common-protos==1.56.1 130 | - grpcio==1.44.0 131 | - h5py==3.6.0 132 | - huggingface-hub==0.10.1 133 | - hydra-core==1.1.1 134 | - idna==2.10 135 | - imageio==2.21.0 136 | - importlib-metadata==4.11.3 137 | - importlib-resources==5.2.2 138 | - iniconfig==1.1.1 139 | - ipykernel==6.0.0 140 | - ipython==7.25.0 141 | - ipython-genutils==0.2.0 142 | - ipywidgets==7.6.3 143 | - iso-639==0.4.5 144 | - 
jaraco-classes==3.2.1 145 | - jaraco-collections==3.4.0 146 | - jaraco-functools==3.3.0 147 | - jaraco-text==3.5.1 148 | - jinja2==3.0.1 149 | - jmespath==1.0.0 150 | - joblib==1.0.1 151 | - jsonnet==0.18.0 152 | - jsonschema==3.2.0 153 | - jupyter==1.0.0 154 | - jupyter-client==6.1.12 155 | - jupyter-console==6.4.0 156 | - jupyter-core==4.7.1 157 | - jupyterlab-pygments==0.1.2 158 | - jupyterlab-widgets==1.0.0 159 | - kiwisolver==1.3.2 160 | - langcodes==3.3.0 161 | - llvmlite==0.37.0 162 | - lmdb==1.3.0 163 | - lxml==4.6.3 164 | - markdown==3.3.6 165 | - markupsafe==2.0.1 166 | - matplotlib==3.4.3 167 | - matplotlib-inline==0.1.2 168 | - mistune==0.8.4 169 | - more-itertools==8.10.0 170 | - multidict==5.2.0 171 | - multiprocess==0.70.12.2 172 | - munch==2.5.0 173 | - murmurhash==1.0.5 174 | - mypy-extensions==0.4.3 175 | - nbclient==0.5.3 176 | - nbconvert==6.1.0 177 | - nbformat==5.1.3 178 | - nest-asyncio==1.5.1 179 | - nltk==3.6.2 180 | - notebook==6.4.0 181 | - numba==0.54.1 182 | - numpy==1.20.3 183 | - oauthlib==3.2.0 184 | - omegaconf==2.1.1 185 | - packaging==21.3 186 | - pandas==1.2.5 187 | - pandocfilters==1.4.3 188 | - parso==0.8.2 189 | - pathspec==0.10.1 190 | - pathtools==0.1.2 191 | - pathy==0.6.0 192 | - patsy==0.5.2 193 | - patternfork-nosql==3.6 194 | - pdfminer-six==20201018 195 | - pillow==8.3.0 196 | - platformdirs==2.5.2 197 | - pluggy==1.0.0 198 | - portend==2.7.1 199 | - preshed==3.0.5 200 | - prometheus-client==0.11.0 201 | - promise==2.3 202 | - prompt-toolkit==3.0.19 203 | - protobuf==3.18.0 204 | - psutil==5.8.0 205 | - ptyprocess==0.7.0 206 | - py==1.11.0 207 | - pyarrow==6.0.1 208 | - pyasn1==0.4.8 209 | - pyasn1-modules==0.2.8 210 | - pycparser==2.20 211 | - pydantic==1.8.2 212 | - pygments==2.9.0 213 | - pyparsing==2.4.7 214 | - pyrsistent==0.18.0 215 | - pytest==7.1.2 216 | - python-docx==0.8.11 217 | - pytz==2021.1 218 | - pyyaml==5.4.1 219 | - pyzmq==22.1.0 220 | - qtconsole==5.1.1 221 | - qtpy==1.9.0 222 | - regex==2021.4.4 223 | - requests==2.25.1 224 | - requests-oauthlib==1.3.1 225 | - responses==0.18.0 226 | - rsa==4.8 227 | - s3transfer==0.5.2 228 | - sacremoses==0.0.45 229 | - scikit-learn==1.0.1 230 | - scipy==1.7.1 231 | - seaborn==0.11.2 232 | - send2trash==1.7.1 233 | - sentencepiece==0.1.96 234 | - sentry-sdk==1.4.3 235 | - sgmllib3k==1.0.0 236 | - shap==0.40.0 237 | - shortuuid==1.0.1 238 | - six==1.16.0 239 | - slicer==0.0.7 240 | - smart-open==5.2.1 241 | - smmap==4.0.0 242 | - sortedcontainers==2.4.0 243 | - soupsieve==2.2.1 244 | - spacy==3.2.4 245 | - spacy-legacy==3.0.8 246 | - spacy-loggers==1.0.2 247 | - srsly==2.4.1 248 | - stanza==1.4.0 249 | - statsmodels==0.13.2 250 | - subprocess32==3.5.4 251 | - tabulate==0.8.9 252 | - tempora==4.1.1 253 | - tensorboard==2.8.0 254 | - tensorboard-data-server==0.6.1 255 | - tensorboard-plugin-wit==1.8.1 256 | - tensorboardx==2.5 257 | - termcolor==1.1.0 258 | - terminado==0.10.1 259 | - testpath==0.5.0 260 | - thinc==8.0.15 261 | - threadpoolctl==3.0.0 262 | - tokenizers==0.12.1 263 | - tomli==2.0.1 264 | - torch==1.9.0 265 | - torchaudio==0.9.0 266 | - torchvision==0.10.0 267 | - tornado==6.1 268 | - tqdm==4.62.3 269 | - transformers==4.18.0 270 | - typer==0.4.1 271 | - typing-extensions==3.10.0.0 272 | - urllib3==1.26.6 273 | - wandb==0.12.3 274 | - wasabi==0.8.2 275 | - webencodings==0.5.1 276 | - werkzeug==2.1.1 277 | - widgetsnbextension==3.5.1 278 | - xxhash==2.0.2 279 | - yarl==1.7.0 280 | - yaspin==2.1.0 281 | - zc-lockfile==2.0 282 | - zipp==3.5.0 283 | prefix: 
/u/nlp/anaconda/main/anaconda3/envs/lang-patching 284 | -------------------------------------------------------------------------------- /eval_utils.py: -------------------------------------------------------------------------------- 1 | # load in the tokenizer 2 | import os, sys 3 | from torch.nn import functional as F 4 | import numpy as np 5 | from torch.utils.data import ( 6 | DataLoader, 7 | RandomSampler, 8 | SequentialSampler, 9 | ) 10 | from munch import Munch 11 | 12 | from transformers import T5ForConditionalGeneration, AutoTokenizer 13 | from model import T5ForConditionalGenerationMultipleHeads, T5Interpeter 14 | import torch 15 | import itertools 16 | from helpers import convert_to_tensors, prompt_styles, verbalize_examples 17 | from transformers.data.data_collator import DataCollatorWithPadding 18 | from patch_dataset import SimpleDataset 19 | 20 | from training_utils import train_loop_fixed_steps 21 | 22 | 23 | def apply_patch_soft(patch_applies_probs, baseline_probs, conditioned_probs): 24 | applies_prob = patch_applies_probs[:, 1].reshape(-1, 1) 25 | return (applies_prob * conditioned_probs) + (1 - applies_prob) * baseline_probs 26 | 27 | 28 | def dissect(patch): 29 | cond, consequence = patch.split(",") 30 | cond = " ".join(cond.split(" ")[1:]) 31 | consequence = " ".join(consequence.split(" ")[2:]) 32 | print(cond, consequence) 33 | return cond, consequence 34 | 35 | 36 | def get_scores_multiple_patches_hard(model_obj, data, patch_list, silent=False): 37 | no_exps = [("", ex) for ex in data[0]] 38 | no_exp_probs = predict_stuff( 39 | no_exps, 40 | [0] * len(no_exps), 41 | model_obj, 42 | "p1", 43 | verbose=False, 44 | mode="task_predictor", 45 | ) 46 | if not silent: 47 | print(np.mean(no_exp_probs.argmax(axis=1) == data[1])) 48 | cond_probs = [] 49 | all_patched_probs = [] 50 | for idx, patch in enumerate(patch_list): 51 | if patch == "": 52 | continue 53 | 54 | cond, consequence = dissect(patch) 55 | contextualized = [(cond, ex) for ex in data[0]] 56 | gating_probs = predict_stuff( 57 | contextualized, itertools.repeat(0), model_obj, "p1", verbose=False 58 | ) 59 | cond_probs.append(np.log(gating_probs[:, 1])) # log(p(c | x)) 60 | 61 | conditioning_examples = [(consequence, ex) for ex in data[0]] 62 | conditioned_probs = predict_stuff( 63 | conditioning_examples, 64 | itertools.repeat(0), 65 | model_obj, 66 | "p1", 67 | verbose=True, 68 | mode="task_predictor", 69 | ) 70 | 71 | patched_probs = apply_patch_soft(gating_probs, no_exp_probs, conditioned_probs) 72 | 73 | if not silent: 74 | print("Applying patch {}".format(cond)) 75 | all_patched_probs.append(patched_probs[:, 1]) 76 | # how much should each be weighted by? 77 | # pick best patch and apply it! 
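    # (concretely: stack each patch's interpolated output and each condition's
    #  log-probability, take the argmax condition per input, and report the output
    #  of that single best patch)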
78 | all_patched_probs = np.stack(all_patched_probs, axis=1) # D x P 79 | cond_probs = np.stack(cond_probs, axis=1) # D x P 80 | best_patches = np.argmax(cond_probs, axis=1) # D x l 81 | 82 | ptrue = np.array([p[idx] for p, idx in zip(all_patched_probs, best_patches)]) 83 | pfalse = 1.0 - ptrue 84 | return no_exp_probs, np.stack([pfalse, ptrue]).T 85 | 86 | 87 | def get_data(tuple_list, tokenizer, prompt_style="p1"): 88 | inputs, labels = tuple_list 89 | prompt_func = prompt_styles[prompt_style] 90 | verbalizer_label = {0: "negative", 1: "positive"} 91 | all_data = [] 92 | for inp, label in zip(inputs, labels): 93 | ex = prompt_func("", inp) 94 | all_data.append((ex, verbalizer_label[label])) 95 | 96 | return SimpleDataset(all_data, tokenizer, as_lm=True) 97 | 98 | 99 | def fewshot_finetune(path_name, update_steps, train_tuple_list, val_tuple_list, metric): 100 | # load the model in 101 | model_obj = load_model(path_name) 102 | train_data = get_data(train_tuple_list, model_obj.tokenizer) 103 | 104 | if type(val_tuple_list) == dict: 105 | val_data = { 106 | key: get_data(_val, model_obj.tokenizer) 107 | for key, _val in val_tuple_list.items() 108 | } 109 | else: 110 | val_data = get_data(val_tuple_list, model_obj.tokenizer) 111 | 112 | # TODO: figure out a way to get the config. 113 | cfg = Munch( 114 | num_warmup_steps=0, 115 | lr=1e-4, 116 | train_batch_size=4, 117 | accum_steps=4, 118 | eval_batch_size=256, 119 | ) 120 | return train_loop_fixed_steps( 121 | model_obj, cfg, {"task_data": train_data}, val_data, update_steps, metric 122 | ) 123 | 124 | 125 | def load_model(path_name, primary_mode="task_predictor", device_idx=0): 126 | if "t5" in path_name: 127 | tokenizer = AutoTokenizer.from_pretrained("t5-large") 128 | try: 129 | base_model = T5ForConditionalGenerationMultipleHeads.from_pretrained( 130 | "t5-large" 131 | ) 132 | model_obj = T5Interpeter( 133 | base_model, tokenizer, primary_mode=primary_mode, train_multihead=True 134 | ) 135 | # don't set strict to true here because we want all keys to match! 
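            # (i.e. strict stays at its default of True: a key mismatch makes this
            #  load raise a RuntimeError, which routes us to the base-model fallback
            #  in the except branch below)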
136 |             model_obj.load_state_dict(torch.load(path_name, map_location="cpu"))
137 |         except RuntimeError:
138 |             print("only loading base model!")
139 |             base_model = T5ForConditionalGeneration.from_pretrained("t5-large")
140 |             base_model.load_state_dict(
141 |                 torch.load(path_name, map_location="cpu"), strict=False
142 |             )
143 |             model_obj = T5Interpeter(
144 |                 base_model, tokenizer, primary_mode=primary_mode, train_multihead=False
145 |             )
146 |         # except:
147 |         #     print("loading base model with multiple heads")
148 |         #     base_model = T5ForConditionalGenerationMultipleHeads.from_pretrained('t5-large')
149 |         #     model_obj = T5Interpeter(base_model, tokenizer, primary_mode=primary_mode, train_multihead=False)
150 |         #     model_obj.load_state_dict(torch.load(path_name, map_location='cpu'))
151 |     else:
152 |         print("model not supported")
153 |         sys.exit(1)
154 | 
155 |     if torch.cuda.is_available():
156 |         if device_idx:
157 |             device = torch.device("cuda:{}".format(device_idx))
158 |         else:
159 |             device = torch.device("cuda")
160 |         model_obj.to(device)
161 |     else:
162 |         print("No cuda!!")
163 |     model_obj.eval()
164 |     return model_obj
165 | 
166 | 
167 | def predict_stuff_helper(
168 |     model,
169 |     dataset,
170 |     verbose,
171 |     interchange=True,
172 |     data_collator_to_use=None,
173 |     batch_size=64,
174 |     mode=None,
175 |     ret_result=False,
176 | ):
177 |     if data_collator_to_use is None:
178 |         tokenizer = model.tokenizer
179 |         data_collator_to_use = DataCollatorWithPadding(tokenizer=tokenizer)
180 |     dataloader = DataLoader(
181 |         dataset,
182 |         sampler=SequentialSampler(dataset),
183 |         batch_size=batch_size,
184 |         collate_fn=data_collator_to_use,
185 |     )
186 |     result = model.evaluator(dataloader, verbose=verbose, mode=mode)
187 |     if ret_result:
188 |         return result
189 |     else:
190 |         pp = F.softmax(result["logits"], dim=1).numpy()
191 |         if interchange:
192 |             pp = np.hstack((pp[:, 1:], pp[:, 0:1]))
193 |         return pp
194 | 
195 | 
196 | def predict_stuff(
197 |     examples,
198 |     labels,
199 |     model,
200 |     prompt_style,
201 |     verbose=False,
202 |     interchange=True,
203 |     verbalize=True,
204 |     batch_size=64,
205 |     mode=None,
206 | ):
207 |     prompt_func = prompt_styles[prompt_style]
208 |     tokenizer = model.tokenizer
209 |     examples = [x if type(x) == str else prompt_func(x[0], x[1]) for x in examples]
210 |     if verbalize:
211 |         verbalized_examples = verbalize_examples(
212 |             [(x, label) for (x, label) in zip(examples, labels)],
213 |             prompt_style,
214 |             labels_given=True,
215 |         )
216 |     else:
217 |         verbalized_examples = [(x, label) for (x, label) in zip(examples, labels)]
218 |     if verbose:
219 |         print(verbalized_examples[0])
220 |     test_dataset = convert_to_tensors(verbalized_examples, tokenizer)
221 |     return predict_stuff_helper(
222 |         model,
223 |         test_dataset,
224 |         verbose,
225 |         interchange=interchange,
226 |         batch_size=batch_size,
227 |         mode=mode,
228 |     )
229 | 
230 | 
231 | def get_predictions(patches, inputs, model_dict, prompt_style=None):
232 |     model2preds = {}
233 |     for model_name in model_dict:
234 |         preds = {}
235 |         try:
236 |             model_obj = load_model(model_dict[model_name], None)
237 |         except:
238 |             continue
239 |         if not prompt_style:
240 |             if "p2" in model_name:
241 |                 prompt_style = "p2"
242 |             else:
243 |                 prompt_style = "p1"
244 | 
245 |         for patch in patches:
246 |             if len(patch) == 0:
247 |                 input_examples = inputs
248 |             else:
249 |                 input_examples = [(patch, cinput) for cinput in inputs]
250 |             preds[patch] = predict_stuff(
251 |                 input_examples, [0] * len(inputs), model_obj, prompt_style
252 |             )
253 |         model2preds[model_name] = 
preds 254 | return model2preds 255 | -------------------------------------------------------------------------------- /data_fns/data_fns_featurebased_checklists.py: -------------------------------------------------------------------------------- 1 | # ==== Functions for generating data for evaluating knowledge patches ===== 2 | from collections import defaultdict as ddict 3 | import json 4 | from helpers import prompt_styles 5 | from checklist.editor import Editor 6 | 7 | editor = Editor() 8 | pf = prompt_styles["p1"] 9 | 10 | 11 | def knowledge_absn( 12 | abstraction=False, 13 | ): 14 | negated, patches = knowledge_checklists_negating_contexts(abstraction) 15 | # irrelevant 16 | irrelevant, _ = knowledge_checklists( 17 | abstraction, 18 | use_irrelevant=True, 19 | ) 20 | 21 | # non predictive 22 | non_predictive, _ = knowledge_checklists_flips(abstraction) 23 | 24 | all_inputs = [] 25 | all_labels = [] 26 | 27 | all_types = [negated, irrelevant, non_predictive] 28 | for ddtype in all_types: 29 | for k in ddtype: 30 | all_inputs += ddtype[k]["instances"] 31 | all_labels += ddtype[k]["labels"] 32 | 33 | return all_inputs, all_labels, patches 34 | 35 | 36 | def knowledge_checklists_negating_contexts( 37 | abstraction=False, 38 | ): 39 | # contexts that negate the meaning such as e.g. 40 | # I did not think that the food was {} 41 | # I thought the food was not {} 42 | # The food was not {}, in my opinion 43 | # my friends thought that the food was {}, but i did not think so. 44 | 45 | # padj = ['seeet', 'bgesx', 'weref'] 46 | # nadj = ['wuex', 'sercx', 'wety'] 47 | 48 | padj = ["numf", "weref", "wety"] 49 | nadj = ["wuex", "muxy", "wegry"] 50 | 51 | patches = [ 52 | "If food is described as {}, then food is good".format(adj) for adj in padj 53 | ] + ["If food is described as {}, then food is bad".format(adj) for adj in nadj] 54 | if abstraction: 55 | food = ["food", "steak", "tacos", "pizza", "pasta", "oysters", "filet mignon"] 56 | else: 57 | food = ["food"] 58 | 59 | e1 = "" 60 | e2 = "" 61 | 62 | templates_1 = [ 63 | editor.template(pf(e1, "The {food} wasn't {padj}"), food=food, padj=padj), 64 | editor.template(pf(e2, "The {food} wasn't {nadj}"), food=food, nadj=nadj), 65 | ] 66 | labels_1 = [0, 1] 67 | templates_2 = [ 68 | editor.template( 69 | pf(e1, "I did not think that the {food} was {padj}"), food=food, padj=padj 70 | ), 71 | editor.template( 72 | pf(e2, "I did not think that the {food} was {nadj}"), food=food, nadj=nadj 73 | ), 74 | ] 75 | labels_2 = [0, 1] 76 | templates_3 = [ 77 | editor.template( 78 | pf(e1, "The {food} was not {padj}, in my opinion"), food=food, padj=padj 79 | ), 80 | editor.template( 81 | pf(e2, "The {food} was not {nadj}, in my opinion"), food=food, nadj=nadj 82 | ), 83 | ] 84 | labels_3 = [0, 1] 85 | 86 | return { 87 | "d1": get_metadata(templates_1, labels_1), 88 | "d2": get_metadata(templates_2, labels_2), 89 | "d3": get_metadata(templates_3, labels_3), 90 | }, patches 91 | 92 | 93 | def knowledge_checklists_flips( 94 | abstraction=False, 95 | ): 96 | # performing well here indicates that the model isn't just copying 97 | padj = ["numf", "weref", "wety"] 98 | nadj = ["wuex", "muxy", "wegry"] 99 | patches = [ 100 | "If food is described as {}, then food is good".format(adj) for adj in padj 101 | ] + ["If food is described as {}, then food is bad".format(adj) for adj in nadj] 102 | if abstraction: 103 | food = ["steak", "tacos", "pizza", "pasta", "oysters", "filet mignon"] 104 | else: 105 | food = ["food"] 106 | 107 | e1 = "" 108 | e2 = "" 109 | # e1 = 
'{padj} is a good word' 110 | # e2 = '{nadj} is a bad word' 111 | 112 | templates_1 = [ 113 | editor.template( 114 | pf(e1, "The {food} was {padj}, but everything else was really {o_nadj}"), 115 | padj=padj, 116 | food=food, 117 | o_nadj=["bad", "poor", "pathetic"], 118 | ), 119 | editor.template( 120 | pf(e2, "The {food} was {nadj}, but everything else was really {o_padj}"), 121 | nadj=nadj, 122 | food=food, 123 | o_padj=["amazing", "wonderful"], 124 | ), 125 | ] 126 | labels_1 = [ 127 | 0, 128 | 1, 129 | ] # not 0 and 1 but the model preds for food was good, but everything else was really 130 | 131 | templates_2 = [ 132 | editor.template( 133 | pf( 134 | e1, 135 | "Unfortunately everything else was really {o_nadj} even though the {food} was {padj}", 136 | ), 137 | food=food, 138 | padj=padj, 139 | o_nadj=["bad", "poor", "pathetic"], 140 | ), 141 | editor.template( 142 | pf( 143 | e2, 144 | "Fortunately, everything else was really {o_padj} even though the {food} was {nadj}", 145 | ), 146 | food=food, 147 | nadj=nadj, 148 | o_padj=["amazing", "wonderful"], 149 | ), 150 | ] 151 | labels_2 = [0, 1] 152 | 153 | return { 154 | "d1": get_metadata(templates_1, labels_1), 155 | "d2": get_metadata(templates_2, labels_2), 156 | }, patches 157 | 158 | 159 | def knowledge_checklists( 160 | abstraction=False, 161 | use_irrelevant=False, 162 | ): 163 | easy_labels = [1, 0, 1, 0] 164 | padj = ["the bomb", "the shizz"] 165 | nadj = ["unusual", "strange"] 166 | s_padj = padj 167 | s_nadj = nadj 168 | o_padj = ["good", "decent"] 169 | o_nadj = ["bad"] 170 | 171 | if abstraction: 172 | food = ["steak", "tacos", "pizza", "pasta", "oysters", "filet mignon"] 173 | service = ["bartender", "server", "barista", "host"] 174 | else: 175 | service = ["service"] 176 | food = ["food"] 177 | 178 | patches = [f"if food is described as {adj}, then food is good" for adj in padj] 179 | patches += [f"if food is described as {adj}, then food is bad" for adj in nadj] 180 | patches += [ 181 | f"if service is described as {adj}, then service is good" for adj in padj 182 | ] 183 | patches += [ 184 | f"if service is described as {adj}, then service is bad" for adj in nadj 185 | ] 186 | 187 | e1 = e2 = e3 = e4 = "" 188 | 189 | simple_templates = [ 190 | editor.template( 191 | pf(e1, "The restaurant has {padj} {food}"), food=food, padj=padj 192 | ), 193 | editor.template( 194 | pf(e2, "The restaurant has {nadj} {food}"), food=food, nadj=nadj 195 | ), 196 | ] 197 | 198 | irrelevant_simple = [ 199 | editor.template( 200 | pf(e3, "The restaurant has a {padj} {service}"), 201 | service=service, 202 | padj=s_padj, 203 | ), 204 | editor.template( 205 | pf(e4, "The restaurant has a {nadj} {service}"), 206 | service=service, 207 | nadj=s_nadj, 208 | ), 209 | ] 210 | 211 | compound_templates = [ 212 | editor.template( 213 | pf( 214 | e1, "The restaurant has a {nadj} {service} but {food} was really {padj}" 215 | ), 216 | nadj=o_nadj, 217 | padj=padj, 218 | food=food, 219 | service=service, 220 | ), 221 | editor.template( 222 | pf( 223 | e2, "The restaurant has a {padj} {service} but {food} was really {nadj}" 224 | ), 225 | padj=o_padj, 226 | nadj=nadj, 227 | food=food, 228 | service=service, 229 | ), 230 | ] 231 | 232 | irrelevant_compound = [ 233 | editor.template( 234 | pf( 235 | e3, 236 | "The restaurant has {nadj} {food} but the {service} was really {padj}", 237 | ), 238 | nadj=o_nadj, 239 | padj=s_padj, 240 | food=food, 241 | service=service, 242 | ), 243 | editor.template( 244 | pf( 245 | e4, 246 | "The restaurant has {padj} {food} but 
the {service} was really {nadj}", 247 | ), 248 | padj=o_padj, 249 | nadj=s_nadj, 250 | food=food, 251 | service=service, 252 | ), 253 | ] 254 | 255 | if use_irrelevant: 256 | return { 257 | "simple": get_metadata(irrelevant_simple, easy_labels), 258 | "compound": get_metadata(irrelevant_compound, easy_labels), 259 | }, patches 260 | else: 261 | return { 262 | "simple": get_metadata(simple_templates, easy_labels), 263 | "compound": get_metadata(compound_templates, easy_labels), 264 | }, patches 265 | 266 | 267 | def subset(metadata, idxs): 268 | metadata_subset = {} 269 | for key in metadata: 270 | metadata_subset[key] = [metadata[key][idx] for idx in idxs] 271 | return metadata_subset 272 | 273 | 274 | def deconstruct(patch_and_instance_list): 275 | patches = [] 276 | instances = [] 277 | for eandi in patch_and_instance_list: 278 | if len(eandi.split(".")) > 1: 279 | patch = eandi.split(".")[0].split(":")[-1].strip() 280 | instance = eandi.split(".")[1].split(":")[-1].strip() 281 | else: 282 | instance = eandi.split(":")[-1].strip() 283 | patch = "" 284 | patches.append(patch) 285 | instances.append(instance) 286 | 287 | return instances, patches 288 | 289 | 290 | def get_metadata(all_data, all_labels): 291 | labels = [ 292 | label for label, template in zip(all_labels, all_data) for _ in template["data"] 293 | ] 294 | try: 295 | all_instances, all_patches = deconstruct( 296 | [ex for template in all_data for ex in template["data"]] 297 | ) 298 | except: 299 | all_instances = [ex for template in all_data for ex in template["data"]] 300 | all_patches = ["" for _ in all_instances] 301 | return { 302 | "instances": all_instances, 303 | "patches": all_patches, 304 | "labels": labels, 305 | } 306 | -------------------------------------------------------------------------------- /orig_model_results.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "870d3673", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext autoreload\n", 11 | "%autoreload 2" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 3, 17 | "id": "9c7c4cbe", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "from data_fns import get_yelp_stars\n", 22 | "from eval_utils import predict_stuff\n", 23 | "import numpy as np" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 4, 29 | "id": "8072177b", 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "name": "stderr", 34 | "output_type": "stream", 35 | "text": [ 36 | "Some weights of T5ForConditionalGenerationMultipleHeads were not initialized from the model checkpoint at t5-large and are newly initialized: ['encoder.embed_tokens.weight']\n", 37 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 38 | ] 39 | }, 40 | { 41 | "name": "stdout", 42 | "output_type": "stream", 43 | "text": [ 44 | "primary mode: exp_applies_predictor\n", 45 | "splicing parts from pretrained model\n" 46 | ] 47 | }, 48 | { 49 | "name": "stderr", 50 | "output_type": "stream", 51 | "text": [ 52 | "Some weights of T5ForConditionalGeneration were not initialized from the model checkpoint at t5-large and are newly initialized: ['encoder.embed_tokens.weight']\n", 53 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 54 | ] 55 | }, 56 | { 57 | "name": "stdout", 58 | "output_type": "stream", 59 | 
"text": [ 60 | "only loading base model!\n" 61 | ] 62 | }, 63 | { 64 | "name": "stderr", 65 | "output_type": "stream", 66 | "text": [ 67 | "Some weights of T5ForConditionalGeneration were not initialized from the model checkpoint at t5-large and are newly initialized: ['encoder.embed_tokens.weight']\n", 68 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 69 | ] 70 | }, 71 | { 72 | "name": "stdout", 73 | "output_type": "stream", 74 | "text": [ 75 | "primary mode: exp_applies_predictor\n" 76 | ] 77 | } 78 | ], 79 | "source": [ 80 | "from eval_utils import load_model\n", 81 | "\n", 82 | "path_name = '/u/scr/smurty/LanguageExplanations/trained_models/t5-large-sst-no-exp'\n", 83 | "model_obj = load_model(path_name, primary_mode='exp_applies_predictor')\n" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 5, 89 | "id": "d1dc6e3d", 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "name": "stderr", 94 | "output_type": "stream", 95 | "text": [ 96 | "Reusing dataset yelp_polarity (/u/scr/smurty/yelp_polarity/plain_text/1.0.0/a770787b2526bdcbfc29ac2d9beb8e820fbc15a03afd3ebc4fb9d8529de57544)\n" 97 | ] 98 | } 99 | ], 100 | "source": [ 101 | "tests_yelp = get_yelp_stars()\n", 102 | "inps, labels = tests_yelp[0], tests_yelp[1]" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 6, 108 | "id": "a7b8d795", 109 | "metadata": {}, 110 | "outputs": [ 111 | { 112 | "data": { 113 | "application/vnd.jupyter.widget-view+json": { 114 | "model_id": "dbc7aa43983c4efc9ba237ed4eced712", 115 | "version_major": 2, 116 | "version_minor": 0 117 | }, 118 | "text/plain": [ 119 | " 0%| | 0/4 [00:00 best_metric: 163 | best_metric = to_log[metric] 164 | log.info( 165 | "Saving model at {}".format(os.path.join(orig_working_dir, cfg.save_path)) 166 | ) 167 | torch.save( 168 | model.state_dict(), 169 | "{}/{}".format(orig_working_dir, cfg.save_path), 170 | ) 171 | return best_metric 172 | 173 | 174 | def train_loop_fixed_steps(model, cfg, train_data_dict, val_data, t_total, metric): 175 | accum_steps = cfg.get("accum_steps", 1) 176 | opt = get_opt(cfg, model) 177 | scheduler = get_scheduler(cfg, opt, t_total) 178 | num_steps = 0 179 | 180 | tokenizer = model.tokenizer 181 | train_data_collator = DataCollatorWithPadding(tokenizer=tokenizer) 182 | val_data_collator = DataCollatorWithPadding(tokenizer=tokenizer) 183 | 184 | # number of total 185 | pbar = tqdm(total=t_total) 186 | while num_steps < t_total: 187 | train_dataloaders = {} 188 | total_train_sz = [] 189 | for key, train_data in train_data_dict.items(): 190 | train_data_curr = train_data.get_data() 191 | total_train_sz.append(len(train_data_curr)) 192 | train = DataLoader( 193 | train_data_curr, 194 | sampler=RandomSampler(train_data_curr), 195 | batch_size=cfg.train_batch_size, 196 | collate_fn=train_data_collator, 197 | ) 198 | train_dataloaders[key] = train 199 | with torch.enable_grad(): 200 | losses = [] 201 | all_keys = list(train_dataloaders.keys()) 202 | for all_batches in zip(*train_dataloaders.values()): 203 | curr_batch_dict = dict(zip(all_keys, all_batches)) 204 | model.train() 205 | loss_curr = model.get_loss(curr_batch_dict) 206 | loss_curr /= accum_steps 207 | loss_curr.backward() 208 | losses.append(loss_curr.item()) 209 | if len(losses) == accum_steps: 210 | num_steps += 1 211 | pbar.update(1) 212 | opt.step() 213 | scheduler.step() 214 | model.zero_grad() 215 | losses = [] 216 | 217 | if num_steps == t_total: 218 | break 219 | if losses: 
220 |             num_steps += 1
221 |             pbar.update(1)
222 |             opt.step()
223 |             scheduler.step()
224 |             model.zero_grad()
225 |             losses = []
226 |     pbar.close()
227 | 
228 |     print("Evaluating on Test Data.")
229 |     return eval_func(cfg, model, val_data, val_data_collator, None, metric=metric)
230 | 
231 | 
232 | def train_loop(model, log, cfg, train_data_dict, val_data, metric="acc"):
233 |     num_epochs = cfg.num_epochs
234 |     accum_steps = cfg.get("accum_steps", 1)
235 |     eval_every = cfg.get("eval_every", None)
236 |     max_grad_norm = cfg.get("max_grad_norm", 5)
237 |     opt = get_opt(cfg, model)
238 |     # total optimizer steps: each step consumes accum_steps * train_batch_size examples
239 |     t_total = num_epochs * (
240 |         min(len(train_data_dict[key]) for key in train_data_dict)
241 |         // (accum_steps * cfg.train_batch_size)
242 |     )
243 |     scheduler = get_scheduler(cfg, opt, t_total)
244 |     num_steps = 0
245 |     best_acc = 0
246 |     orig_working_dir = working_dir()
247 | 
248 |     tokenizer = model.tokenizer
249 |     train_data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
250 |     val_data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
251 | 
252 |     # evaluate once at the beginning to see if evaluation pipeline is A-ok
253 |     for epoch in range(num_epochs):
254 |         # Evaluate on this epoch
255 |         train_dataloaders = {}
256 |         total_train_sz = []
257 |         for key, train_data in train_data_dict.items():
258 |             train_data_curr = train_data.get_data()
259 |             total_train_sz.append(len(train_data_curr))
260 |             train = DataLoader(
261 |                 train_data_curr,
262 |                 sampler=RandomSampler(train_data_curr),
263 |                 batch_size=cfg.train_batch_size,
264 |                 collate_fn=train_data_collator,
265 |             )
266 |             train_dataloaders[key] = train
267 | 
268 |         log.info("Epoch: {}".format(epoch))
269 |         with torch.enable_grad(), tqdm(total=min(total_train_sz)) as progress_bar:
270 |             # Train on this epoch
271 |             losses = []
272 |             all_keys = list(train_dataloaders.keys())
273 |             canon_key = all_keys[0]
274 |             for all_batches in zip(*train_dataloaders.values()):
275 |                 curr_batch_dict = dict(zip(all_keys, all_batches))
276 |                 model.train()
277 |                 loss_curr = model.get_loss(curr_batch_dict)
278 |                 progress_bar.update(len(curr_batch_dict[canon_key]["input_ids"]))
279 |                 loss_curr /= accum_steps
280 |                 loss_curr.backward()
281 |                 losses.append(loss_curr.item())
282 |                 if len(losses) == accum_steps:
283 |                     num_steps += 1
284 |                     progress_bar.set_postfix(
285 |                         {"loss": sum(losses) / len(losses), "num_steps": num_steps}
286 |                     )
287 |                     torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # clip before stepping
288 |                     opt.step()
289 |                     scheduler.step()
290 |                     model.zero_grad()
291 |                     losses = []
292 |                     if eval_every and num_steps % eval_every == 0:
293 |                         log.info("Evaluating at step {}".format(num_steps))
294 |                         best_acc = eval_func(
295 |                             cfg,
296 |                             model,
297 |                             val_data,
298 |                             val_data_collator,
299 |                             log,
300 |                             best_acc,
301 |                             metric,
302 |                         )
303 |         # evaluate at the end of the epoch.
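        # e.g. with train_batch_size=2 and accum_steps=16, one optimizer step
        # consumes 2 * 16 = 32 examples, so num_steps advances (num_examples // 32)
        # times per epoch; when eval_every is unset, this epoch-end eval is the only one: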
304 |         if not eval_every:
305 |             log.info("Evaluating at step {}".format(num_steps))
306 |             best_acc = eval_func(
307 |                 cfg, model, val_data, val_data_collator, log, best_acc, metric
308 |             )
309 |     return
310 | 
-------------------------------------------------------------------------------- /patch_dataset.py: --------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | from helpers import verbalize_examples, prompt_styles, powerset
4 | from tqdm import tqdm
5 | from datasets import Dataset as HFDataset
6 | 
7 | from collections import defaultdict as ddict, Counter
8 | 
9 | # from sentence_transformers import SentenceTransformer, util
10 | 
11 | DEFAULT_EXP = "default noop explanation"
12 | 
13 | 
14 | def get_examples2patch_dict(patch2examples_dict, all_texts, with_idxs=False):
15 |     examples2patch_pos = ddict(list)
16 |     examples2patch_neg = ddict(list)
17 |     total_len = 0
18 |     for patch, idxs in patch2examples_dict.items():
19 |         # the empty patch is handled at the end
20 |         if patch == "":
21 |             continue
22 |         for idx in idxs:
23 |             if with_idxs:
24 |                 examples2patch_pos[all_texts[idx]].append((patch, idx))
25 |             else:
26 |                 examples2patch_pos[all_texts[idx]].append(patch)
27 |             total_len += 1
28 | 
29 |     # handle the empty patch at the end.
30 |     # We do this because we add the empty patch ONLY if the example has no non-zero patches
31 |     for idx in patch2examples_dict[""]:
32 |         assert with_idxs
33 |         examples2patch_pos[all_texts[idx]].append(("", idx))
34 |         total_len += 1
35 | 
36 |     all_patches = set([patch for patch in patch2examples_dict]) - set(
37 |         [""]
38 |     )  # remove '' from all_patches; it is handled at the end
39 |     # post process such that for inputs without a patch, we add ''.
40 |     for text in examples2patch_pos:
41 |         if with_idxs:
42 |             all_negs = list(
43 |                 all_patches - set([patch for patch, _ in examples2patch_pos[text]])
44 |             )
45 |         else:
46 |             all_negs = list(all_patches - set(examples2patch_pos[text]))
47 |         examples2patch_neg[text] = all_negs
48 | 
49 |     return examples2patch_pos, examples2patch_neg, total_len
50 | 
51 | 
52 | class SimpleDataset:
53 |     def __init__(
54 |         self,
55 |         all_data,
56 |         tokenizer,
57 |         as_lm=True,
58 |         deverb_dict={"positive": 1, "negative": 0},
59 |     ):
60 |         self.all_data = all_data
61 |         self.tensored_dataset = self.get_tensored_dataset(tokenizer, as_lm)
62 |         self.tokenizer = tokenizer
63 |         self.deverb_dict = deverb_dict
64 | 
65 |     def get_tensored_dataset(self, tokenizer, as_lm):
66 |         def pad(label, max_len, val):
67 |             to_pad = max_len - len(label)
68 |             return label + [val] * to_pad
69 | 
70 |         data_list = self.all_data
71 |         if as_lm:
72 |             all_labels = tokenizer([label for _, label in data_list])["input_ids"]
73 |             max_len = max(len(label) for label in all_labels)
74 |             all_labels = [pad(label, max_len, -100) for label in all_labels]
75 |             # unequal-length labels are padded with -100, which the loss ignores
76 |             print(Counter([tuple(l) for l in all_labels]))
77 |             dataset = {"sentence": [ex for ex, _ in data_list], "label": all_labels}
78 |         else:
79 |             # deverbalize
80 |             deverbalize_dict = self.deverb_dict
81 |             dataset = {
82 |                 "sentence": [ex for ex, _ in data_list],
83 |                 "label": [deverbalize_dict[label] for _, label in data_list],
84 |             }
85 |         dataset = HFDataset.from_dict(dataset)
86 |         tokenize_func = lambda examples: tokenizer(
87 |             examples["sentence"], truncation=True, max_length=128
88 |         )
89 |         tensored_dataset = dataset.map(
90 |             tokenize_func, batched=True, remove_columns=["sentence"]
91 |         )
92 |         return tensored_dataset
93 | 
94 |     def __len__(self):
95 |         return 
len(self.all_data) 96 | 97 | def get_data(self, max_size=-1): 98 | return self.tensored_dataset 99 | 100 | 101 | class PatchApplies: 102 | def __init__(self, patch2examples_dict, texts, tokenizer): 103 | self.texts = texts 104 | examples2patch_pos, examples2patch_neg, total_len = get_examples2patch_dict( 105 | patch2examples_dict, texts 106 | ) 107 | self.examples2patch_pos = examples2patch_pos 108 | self.examples2patch_neg = examples2patch_neg 109 | self.total_len = total_len 110 | self.patch2examples_dict = patch2examples_dict 111 | self.tokenizer = tokenizer 112 | 113 | def __len__(self): 114 | return self.total_len 115 | 116 | def get_samples(self, text): 117 | all_negatives = self.examples2patch_neg[text] 118 | all_positives = self.examples2patch_pos[text] 119 | chosen_positives = random.choice(all_positives) 120 | chosen_negatives = random.choice(all_negatives) 121 | return chosen_positives, chosen_negatives 122 | 123 | def get_data(self): 124 | tokenizer = self.tokenizer 125 | all_data = {"labels": [], "sentence": []} 126 | verbalizer_label = {0: "no", 1: "yes"} 127 | prompt_func = prompt_styles["p1_patch_applies"] 128 | for example in self.examples2patch_pos: 129 | positive_ex, negative_ex = self.get_samples(example) 130 | with_correct_patch = prompt_func(positive_ex, example) 131 | with_incorrect_patch = prompt_func(negative_ex, example) 132 | all_data["labels"].append(1) 133 | all_data["sentence"].append(with_correct_patch) 134 | all_data["labels"].append(0) 135 | all_data["sentence"].append(with_incorrect_patch) 136 | all_data["labels"] = [ 137 | verbalizer_label[sentiment] for sentiment in all_data["labels"] 138 | ] 139 | all_data["sentence"] = verbalize_examples( 140 | all_data["sentence"], prompt_style="p1_exp_applies" 141 | ) 142 | return self.process_into_hf_dataset(all_data, tokenizer) 143 | 144 | def process_into_hf_dataset(self, all_data, tokenizer): 145 | all_data["labels"] = tokenizer(all_data["labels"])["input_ids"] 146 | dataset = HFDataset.from_dict(all_data) 147 | tokenize_func = lambda examples: tokenizer( 148 | examples["sentence"], truncation=True 149 | ) 150 | return dataset.map(tokenize_func, batched=True, remove_columns=["sentence"]) 151 | 152 | 153 | class PatchDataset: 154 | def __init__( 155 | self, 156 | patch2examples_dict, 157 | texts, 158 | labels, 159 | tokenizer, 160 | prompt_style="p1", 161 | get_hard_negs=False, 162 | use_negatives=True, 163 | ): 164 | self.texts = texts 165 | self.patch2examples_dict = patch2examples_dict 166 | self.use_negatives = use_negatives 167 | examples2patch_pos, examples2patch_neg, total_len = get_examples2patch_dict( 168 | patch2examples_dict, texts, with_idxs=True 169 | ) 170 | 171 | # default label is 0. 
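        # example2noop_label maps an input text to the label the model should
        # keep predicting when no patch applies; e.g. if texts = ["great tacos"],
        # labels = [1] and patch2examples_dict = {"default noop explanation": [0]},
        # then example2noop_label["great tacos"] == 1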
172 | example2noop_label = ddict(int) 173 | if DEFAULT_EXP in patch2examples_dict: 174 | for idx in patch2examples_dict[DEFAULT_EXP]: 175 | example2noop_label[texts[idx]] = labels[idx] 176 | 177 | self.example2noop_label = example2noop_label 178 | self.examples2patch_pos = examples2patch_pos 179 | self.examples2patch_neg = examples2patch_neg 180 | self.total_len = total_len 181 | self.prompt_style = prompt_style 182 | self.labels = labels 183 | self.tokenizer = tokenizer 184 | 185 | # how many examples are constructed per epoch 186 | def __len__(self): 187 | return self.total_len 188 | 189 | def get_neg_data(self, num_samples=5): 190 | tokenizer = self.tokenizer 191 | all_data = {"no_patch": [], "with_incorrect_patch": []} 192 | verbalizer_label = {0: "negative", 1: "positive"} 193 | prompt_func = prompt_styles[self.prompt_style] 194 | for example_text in self.examples2patch_pos: 195 | all_negatives = self.examples2patch_neg[example_text] 196 | sampled_patches = random.sample( 197 | all_negatives, k=min(num_samples, len(all_negatives)) 198 | ) 199 | for patch in sampled_patches: 200 | all_data["with_incorrect_patch"].append( 201 | prompt_func(patch, example_text) 202 | ) 203 | all_data["no_patch"].append(prompt_func("", example_text)) 204 | all_data["with_incorrect_patch"] = verbalize_examples( 205 | all_data["with_incorrect_patch"], prompt_style="p1" 206 | ) 207 | all_data["no_patch"] = verbalize_examples( 208 | all_data["no_patch"], prompt_style="p1" 209 | ) 210 | return self.process_into_hf_dataset(all_data, tokenizer) 211 | 212 | def combine(self, first_patch, second_patch): 213 | # invariant: both cannot be zero 214 | if len(first_patch) == 0: 215 | return second_patch 216 | elif len(second_patch) == 0: 217 | return first_patch 218 | else: 219 | return "{}. {}".format(first_patch, second_patch) 220 | 221 | def subset(self, data_list, indices): 222 | return [data_list[idx] for idx in indices] 223 | 224 | def get_data_helper(self, verbose, postprocess=True, max_size=-1): 225 | tokenizer = self.tokenizer 226 | verbalizer_label = {0: "negative", 1: "positive"} 227 | all_data = { 228 | "sentence": [], 229 | "label": [], 230 | "instances": [], 231 | "patches": [], 232 | "is_pos": [], 233 | } 234 | prompt_func = prompt_styles[self.prompt_style] 235 | 236 | for example_text in self.examples2patch_pos: 237 | for (positive_ex, idx) in self.examples2patch_pos[example_text]: 238 | if positive_ex == DEFAULT_EXP: 239 | continue 240 | instance = prompt_func(positive_ex, example_text) 241 | label = verbalizer_label[self.labels[idx]] 242 | all_data["sentence"].append(instance) 243 | all_data["instances"].append(example_text) 244 | all_data["patches"].append(positive_ex) 245 | all_data["is_pos"].append(1) 246 | all_data["label"].append(label) 247 | if self.use_negatives: 248 | all_negatives = self.examples2patch_neg[example_text] 249 | if len(all_negatives) > 10: 250 | all_negatives = random.sample(all_negatives, k=10) 251 | for neg in all_negatives: 252 | if neg == DEFAULT_EXP: 253 | continue 254 | instance = prompt_func(neg, example_text) 255 | # get the noop label, and put that here. 
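                        # (a non-applying patch must leave the prediction unchanged,
                        # so the target here is example_text's noop label rather
                        # than the label associated with the patch)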
256 | all_data["sentence"].append(instance) 257 | all_data["instances"].append(example_text) 258 | all_data["patches"].append(neg) 259 | all_data["is_pos"].append(0) 260 | all_data["label"].append( 261 | verbalizer_label[self.example2noop_label[example_text]] 262 | ) 263 | if not postprocess: 264 | all_data["label"] = [int(l == "positive") for l in all_data["label"]] 265 | return all_data 266 | 267 | dataset = { 268 | "sentence": all_data["sentence"], 269 | "label": tokenizer(all_data["label"])["input_ids"], 270 | } 271 | if max_size != -1 and len(all_data["sentence"]) > max_size: 272 | all_indices = list(range(len(all_data["sentence"]))) 273 | indices = random.sample(all_indices, k=max_size) 274 | dataset = {key: self.subset(dataset[key], indices) for key in dataset} 275 | 276 | dataset = HFDataset.from_dict(dataset) 277 | if verbose: 278 | for ex in dataset: 279 | print(ex["sentence"]) 280 | print(ex["label"]) 281 | tokenize_func = lambda examples: tokenizer( 282 | examples["sentence"], truncation=True 283 | ) 284 | tensored_dataset = dataset.map( 285 | tokenize_func, batched=True, remove_columns=["sentence"] 286 | ) 287 | return tensored_dataset 288 | 289 | def get_data(self, verbose=False, postprocess=True, max_size=-1): 290 | return self.get_data_helper(verbose, postprocess=postprocess, max_size=max_size) 291 | 292 | def process_into_hf_dataset(self, all_data, tokenizer): 293 | dataset = HFDataset.from_dict(all_data) 294 | tokenize_func = lambda key: lambda ex: { 295 | "{}_{}".format(k, key): val 296 | for k, val in tokenizer(ex[key], truncation=True).items() 297 | } 298 | if "with_correct_patch" in all_data: 299 | dataset = dataset.map( 300 | tokenize_func("with_correct_patch"), 301 | batched=True, 302 | remove_columns=["with_correct_patch"], 303 | ) 304 | tensored_dataset = dataset.map( 305 | tokenize_func("with_incorrect_patch"), 306 | batched=True, 307 | remove_columns=["with_incorrect_patch"], 308 | ) 309 | return tensored_dataset.map( 310 | tokenize_func("no_patch"), batched=True, remove_columns=["no_patch"] 311 | ) 312 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from turtle import forward 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import wandb 6 | from tqdm import tqdm 7 | from munch import Munch 8 | 9 | EPS = 1e-9 10 | 11 | from transformers import T5ForConditionalGeneration 12 | from transformers.models.t5.modeling_t5 import T5Stack 13 | 14 | from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput 15 | 16 | import warnings 17 | import copy 18 | 19 | from torch.nn import CrossEntropyLoss 20 | 21 | __HEAD_MASK_WARNING_MSG = """ 22 | The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, 23 | `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. 24 | If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, 25 | num_heads)`. 
26 | """ 27 | 28 | 29 | class T5ForConditionalGenerationMultipleHeads(T5ForConditionalGeneration): 30 | def __init__(self, config): 31 | super().__init__(config) 32 | 33 | def forward( 34 | self, 35 | input_ids=None, 36 | attention_mask=None, 37 | decoder_input_ids=None, 38 | decoder_attention_mask=None, 39 | aux_decoder=None, 40 | aux_lm_head=None, 41 | head_mask=None, 42 | decoder_head_mask=None, 43 | cross_attn_head_mask=None, 44 | encoder_outputs=None, 45 | past_key_values=None, 46 | inputs_embeds=None, 47 | decoder_inputs_embeds=None, 48 | labels=None, 49 | use_cache=None, 50 | output_attentions=None, 51 | output_hidden_states=None, 52 | return_dict=None, 53 | ): 54 | use_cache = use_cache if use_cache is not None else self.config.use_cache 55 | return_dict = ( 56 | return_dict if return_dict is not None else self.config.use_return_dict 57 | ) 58 | 59 | # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask 60 | if head_mask is not None and decoder_head_mask is None: 61 | if self.config.num_layers == self.config.num_decoder_layers: 62 | warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) 63 | decoder_head_mask = head_mask 64 | 65 | # Encode if needed (training, first prediction pass) 66 | if encoder_outputs is None: 67 | # Convert encoder inputs in embeddings if needed 68 | encoder_outputs = self.encoder( 69 | input_ids=input_ids, 70 | attention_mask=attention_mask, 71 | inputs_embeds=inputs_embeds, 72 | head_mask=head_mask, 73 | output_attentions=output_attentions, 74 | output_hidden_states=output_hidden_states, 75 | return_dict=return_dict, 76 | ) 77 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): 78 | encoder_outputs = BaseModelOutput( 79 | last_hidden_state=encoder_outputs[0], 80 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, 81 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, 82 | ) 83 | 84 | hidden_states = encoder_outputs[0] 85 | 86 | # Decode 87 | if aux_decoder: 88 | decoder = aux_decoder 89 | else: 90 | decoder = self.decoder 91 | 92 | if self.model_parallel: 93 | torch.cuda.set_device(decoder.first_device) 94 | 95 | if ( 96 | labels is not None 97 | and decoder_input_ids is None 98 | and decoder_inputs_embeds is None 99 | ): 100 | # get decoder inputs from shifting lm labels to the right 101 | decoder_input_ids = self._shift_right(labels) 102 | 103 | # Set device for model parallelism 104 | if self.model_parallel: 105 | torch.cuda.set_device(decoder.first_device) 106 | hidden_states = hidden_states.to(decoder.first_device) 107 | if decoder_input_ids is not None: 108 | decoder_input_ids = decoder_input_ids.to(decoder.first_device) 109 | if attention_mask is not None: 110 | attention_mask = attention_mask.to(decoder.first_device) 111 | if decoder_attention_mask is not None: 112 | decoder_attention_mask = decoder_attention_mask.to(decoder.first_device) 113 | 114 | # Decode 115 | decoder_outputs = decoder( 116 | input_ids=decoder_input_ids, 117 | attention_mask=decoder_attention_mask, 118 | inputs_embeds=decoder_inputs_embeds, 119 | past_key_values=past_key_values, 120 | encoder_hidden_states=hidden_states, 121 | encoder_attention_mask=attention_mask, 122 | head_mask=decoder_head_mask, 123 | cross_attn_head_mask=cross_attn_head_mask, 124 | use_cache=use_cache, 125 | output_attentions=output_attentions, 126 | output_hidden_states=output_hidden_states, 127 | return_dict=return_dict, 128 | ) 129 | 130 | sequence_output = decoder_outputs[0] 131 | 132 | # Set device for 
model parallelism 133 | if self.model_parallel: 134 | torch.cuda.set_device(self.encoder.first_device) 135 | self.lm_head = self.lm_head.to(self.encoder.first_device) 136 | sequence_output = sequence_output.to(self.lm_head.weight.device) 137 | 138 | if self.config.tie_word_embeddings: 139 | # Rescale output before projecting on vocab 140 | # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 141 | sequence_output = sequence_output * (self.model_dim**-0.5) 142 | 143 | if aux_lm_head: 144 | lm_head = aux_lm_head 145 | else: 146 | lm_head = self.lm_head 147 | 148 | lm_logits = lm_head(sequence_output) 149 | 150 | loss = None 151 | if labels is not None: 152 | loss_fct = CrossEntropyLoss(ignore_index=-100) 153 | loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) 154 | # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 155 | 156 | if not return_dict: 157 | output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs 158 | return ((loss,) + output) if loss is not None else output 159 | 160 | return Seq2SeqLMOutput( 161 | loss=loss, 162 | logits=lm_logits, 163 | past_key_values=decoder_outputs.past_key_values, 164 | decoder_hidden_states=decoder_outputs.hidden_states, 165 | decoder_attentions=decoder_outputs.attentions, 166 | cross_attentions=decoder_outputs.cross_attentions, 167 | encoder_last_hidden_state=encoder_outputs.last_hidden_state, 168 | encoder_hidden_states=encoder_outputs.hidden_states, 169 | encoder_attentions=encoder_outputs.attentions, 170 | ) 171 | 172 | 173 | def get_from_pretrained_t5(): 174 | print("splicing parts from pretrained model") 175 | model = T5ForConditionalGeneration.from_pretrained("t5-large") 176 | aux_decoder = model.decoder 177 | aux_lm_head = model.lm_head 178 | return aux_decoder, aux_lm_head 179 | 180 | 181 | class T5Interpeter(nn.Module): 182 | def __init__( 183 | self, 184 | model, 185 | tokenizer, 186 | label_list=["positive", "negative"], 187 | primary_mode="task_predictor", 188 | train_multihead=False, 189 | ): 190 | super().__init__() 191 | self.model = model 192 | self.primary_mode = primary_mode 193 | self.train_multihead = train_multihead 194 | print("primary mode: {}".format(primary_mode)) 195 | if self.train_multihead: 196 | decoder_config = copy.deepcopy(self.model.config) 197 | decoder_config.is_decoder = True 198 | decoder_config.is_encoder_decoder = False 199 | decoder_config.num_layers = self.model.config.num_decoder_layers 200 | 201 | aux_decoder, aux_lm_head = get_from_pretrained_t5() 202 | self.aux_decoder = aux_decoder 203 | self.aux_lm_head = aux_lm_head 204 | """ 205 | self.aux_decoder = T5Stack(decoder_config, 206 | nn.Embedding(self.model.config.vocab_size, self.model.config.d_model)) 207 | self.aux_lm_head = nn.Linear(self.model.config.d_model, self.model.config.vocab_size, bias=False) 208 | """ 209 | else: 210 | self.aux_decoder = None 211 | self.aux_lm_head = None 212 | 213 | self.tokenizer = tokenizer 214 | self.loss_fn = nn.CrossEntropyLoss() 215 | pos_idx = tokenizer(label_list[0])["input_ids"] 216 | neg_idx = tokenizer(label_list[1])["input_ids"] 217 | self.pos_idx = pos_idx[0] 218 | self.neg_idx = neg_idx[0] 219 | self.label_list = [self.pos_idx, self.neg_idx] 220 | self.label_list_words = label_list 221 | 222 | def forward_helper(self, batch, mode): 223 | for key in batch: 224 | batch[key] = batch[key].to(self.model.device) 225 | # labels 
are -100 unless the input_id refers to either positive or negative 226 | if mode == "patch_applies_predictor": 227 | assert self.aux_decoder is not None 228 | out = self.model( 229 | input_ids=batch["input_ids"], 230 | attention_mask=batch["attention_mask"], 231 | labels=batch["labels"], 232 | aux_decoder=self.aux_decoder, 233 | aux_lm_head=self.aux_lm_head, 234 | ) 235 | else: 236 | out = self.model( 237 | input_ids=batch["input_ids"], 238 | attention_mask=batch["attention_mask"], 239 | labels=batch["labels"], 240 | ) 241 | 242 | return out 243 | 244 | def get_task_tensors(self, logits, batch): 245 | cls_logits = logits[:, 0] 246 | if "labels" in batch: 247 | return cls_logits, batch["labels"][:, 0] 248 | else: 249 | return cls_logits, None 250 | 251 | def compute_confusion_matrix(self, preds, labels_curr): 252 | tp = 0.0 253 | tn = 0.0 254 | fp = 0.0 255 | fn = 0.0 256 | for pred, label in zip(preds, labels_curr): 257 | # label might be padded 258 | if type(label) == list and label[-1] == -100: 259 | idx = label.index(-100) 260 | label = label[:idx] 261 | if label == self.pos_idx: 262 | tp += int(pred == self.pos_idx) 263 | fn += int(pred == self.neg_idx) 264 | else: 265 | tn += int(pred == self.neg_idx) 266 | fp += int(pred == self.pos_idx) 267 | return tp, tn, fp, fn 268 | 269 | def get_acc(self, batch, mode): 270 | with torch.no_grad(): 271 | out = self.forward_helper(batch, mode=mode) 272 | logits, labels = self.get_task_tensors(out.logits, batch) 273 | labels = labels.cpu().tolist() 274 | task_logits = logits[ 275 | :, self.label_list 276 | ] # first logit is for positive, second logit is for negative. 277 | preds = task_logits.argmax(dim=-1) 278 | 279 | # just compare positive and negative 280 | preds_task = [self.label_list[pred] for pred in preds] 281 | return task_logits, labels, preds_task 282 | 283 | def get_loss(self, batch): 284 | if type(batch) == dict: 285 | out_list = [] 286 | for key in batch: 287 | if key == "patch_grounding_data": 288 | out_list.append( 289 | self.forward_helper(batch[key], mode="patch_applies_predictor") 290 | ) 291 | else: 292 | out_list.append( 293 | self.forward_helper(batch[key], mode="task_predictor") 294 | ) 295 | loss_curr = sum(out.loss for out in out_list) 296 | else: 297 | out = self.forward_helper(batch, mode=self.primary_mode) 298 | loss_curr = out.loss 299 | try: 300 | wandb.log({"loss": loss_curr.item()}) 301 | except: 302 | pass 303 | return loss_curr 304 | 305 | def evaluator(self, examples, mode=None, verbose=True): 306 | task_logits_all = [] 307 | labels = [] 308 | 309 | correct = 0.0 310 | tp = 0.0 311 | fp = 0.0 312 | tn = 0.0 313 | fn = 0.0 314 | 315 | if not mode: 316 | mode = self.primary_mode 317 | if verbose: 318 | iterate_over = tqdm(examples) 319 | else: 320 | iterate_over = examples 321 | 322 | for batch in iterate_over: 323 | task_logits, labels_curr, preds = self.get_acc(batch, mode) 324 | # sum(p == l for p, l in zip(preds, labels_curr)) 325 | tp_curr, tn_curr, fp_curr, fn_curr = self.compute_confusion_matrix( 326 | preds, labels_curr 327 | ) 328 | tp += tp_curr 329 | fp += fp_curr 330 | tn += tn_curr 331 | fn += fn_curr 332 | 333 | correct += tp_curr + tn_curr 334 | task_logits_all.append(task_logits) 335 | labels += labels_curr 336 | 337 | task_logits = torch.cat(task_logits_all) 338 | probs = F.softmax(task_logits, dim=1).cpu().numpy() 339 | precision = tp / (tp + fp + EPS) # prevent div by 0 340 | recall = tp / (tp + fn + EPS) # prevent div by 0 341 | f1 = 2 * precision * recall / (precision + recall + EPS) 342 | 343 
| return { 344 | "labels": labels, 345 | "probs": probs, 346 | "f1": f1, 347 | "precision": precision, 348 | "recall": recall, 349 | "logits": task_logits.cpu(), 350 | "acc": (correct) / (1.0 * len(labels)), 351 | } 352 | -------------------------------------------------------------------------------- /data_fns/data_fns_override_checklists.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | 3 | # setting path 4 | from checklist.editor import Editor 5 | from collections import defaultdict as ddict 6 | import random 7 | import pandas as pd 8 | 9 | editor = Editor() 10 | 11 | import re 12 | 13 | ##################### Data generation for synthetic datasets: Override patches ######################## 14 | def get_checklist_data_but(): 15 | food_list = ["food", "taco", "steak"] 16 | service_list = ["service", "waiter", "staff", "manager", "waitress"] 17 | padj = ["good", "nice", "great"] 18 | nadj = ["bad", "poor", "horrible"] 19 | editor = Editor() 20 | x1 = editor.template( 21 | "The restaurant has {nadj} {service} but {food} was {padj}", 22 | food=food_list, 23 | service=service_list, 24 | padj=padj, 25 | nadj=nadj, 26 | ) 27 | x1_label_set = [ 28 | [ 29 | {"category": "food", "polarity": "positive"}, 30 | {"category": "service", "polarity": "negative"}, 31 | ] 32 | for _ in x1["data"] 33 | ] 34 | x2 = editor.template( 35 | "The restaurant has {padj} {service} but {food} was {nadj}", 36 | food=food_list, 37 | service=service_list, 38 | padj=padj, 39 | nadj=nadj, 40 | ) 41 | x2_label_set = [ 42 | [ 43 | {"category": "food", "polarity": "negative"}, 44 | {"category": "service", "polarity": "positive"}, 45 | ] 46 | for _ in x2["data"] 47 | ] 48 | x3 = editor.template( 49 | "The restaurant has {nadj} {food} but {service} was {padj}", 50 | food=food_list, 51 | service=service_list, 52 | padj=padj, 53 | nadj=nadj, 54 | ) 55 | x3_label_set = [ 56 | [ 57 | {"category": "food", "polarity": "negative"}, 58 | {"category": "service", "polarity": "positive"}, 59 | ] 60 | for _ in x3["data"] 61 | ] 62 | x4 = editor.template( 63 | "The restaurant has {padj} {food} but {service} was {nadj}", 64 | food=food_list, 65 | service=service_list, 66 | padj=padj, 67 | nadj=nadj, 68 | ) 69 | x4_label_set = [ 70 | [ 71 | {"category": "food", "polarity": "positive"}, 72 | {"category": "service", "polarity": "negative"}, 73 | ] 74 | for _ in x4["data"] 75 | ] 76 | 77 | all_templates = [x1, x2, x3, x4] 78 | all_labels_checklist = x1_label_set + x2_label_set + x3_label_set + x4_label_set 79 | all_examples_checklist = [ 80 | data for template in all_templates for data in template["data"] 81 | ] 82 | overall_labels = ( 83 | [1] * len(x1["data"]) 84 | + [0] * len(x2["data"]) 85 | + [1] * len(x3["data"]) 86 | + [0] * len(x4["data"]) 87 | ) 88 | return all_examples_checklist, all_labels_checklist, overall_labels 89 | 90 | 91 | def get_checklist_data(): 92 | food_list = ["food", "taco", "steak"] 93 | service_list = ["service", "waiter", "staff", "manager", "waitress"] 94 | padj = ["good", "nice", "great"] 95 | nadj = ["bad", "horrible", "below average"] 96 | editor = Editor() 97 | x1 = editor.template( 98 | "The restaurant has {padj} {food} and {nadj} service", 99 | food=food_list, 100 | service=service_list, 101 | padj=padj, 102 | nadj=nadj, 103 | ) 104 | x1_label_set = [ 105 | [ 106 | {"category": "food", "polarity": "positive"}, 107 | {"category": "service", "polarity": "negative"}, 108 | ] 109 | for _ in x1["data"] 110 | ] 111 | x2 = editor.template( 
112 | "The restaurant has {nadj} {food} and {padj} service", 113 | food=food_list, 114 | service=service_list, 115 | padj=padj, 116 | nadj=nadj, 117 | ) 118 | x2_label_set = [ 119 | [ 120 | {"category": "food", "polarity": "negative"}, 121 | {"category": "service", "polarity": "positive"}, 122 | ] 123 | for _ in x2["data"] 124 | ] 125 | x3 = editor.template( 126 | "The restaurant has {padj} {service} and {nadj} food", 127 | food=food_list, 128 | service=service_list, 129 | padj=padj, 130 | nadj=nadj, 131 | ) 132 | x3_label_set = [ 133 | [ 134 | {"category": "food", "polarity": "negative"}, 135 | {"category": "service", "polarity": "positive"}, 136 | ] 137 | for _ in x3["data"] 138 | ] 139 | x4 = editor.template( 140 | "The restaurant has {nadj} {service} and {padj} {food}", 141 | food=food_list, 142 | service=service_list, 143 | padj=padj, 144 | nadj=nadj, 145 | ) 146 | x4_label_set = [ 147 | [ 148 | {"category": "food", "polarity": "positive"}, 149 | {"category": "service", "polarity": "negative"}, 150 | ] 151 | for _ in x4["data"] 152 | ] 153 | 154 | all_templates = [x1, x2, x3, x4] 155 | all_labels_checklist = x1_label_set + x2_label_set + x3_label_set + x4_label_set 156 | all_examples_checklist = [ 157 | data for template in all_templates for data in template["data"] 158 | ] 159 | overall_labels = ( 160 | [0] * len(x1["data"]) 161 | + [1] * len(x2["data"]) 162 | + [0] * len(x3["data"]) 163 | + [1] * len(x4["data"]) 164 | ) 165 | return all_examples_checklist, all_labels_checklist, overall_labels 166 | 167 | 168 | def get_checklist_data_negated(): 169 | food_list = ["food", "taco", "steak"] 170 | service_list = ["service", "waiter", "staff", "manager", "waitress"] 171 | padj = ["good", "nice", "great"] 172 | nadj = ["bad", "horrible", "below average"] 173 | editor = Editor() 174 | x1 = editor.template( 175 | "I do not think that the restaurant has {padj} {food}, {service} was {padj}", 176 | food=food_list, 177 | service=service_list, 178 | padj=padj, 179 | ) 180 | x1_label_set = [ 181 | [ 182 | {"category": "food", "polarity": "negative"}, 183 | {"category": "service", "polarity": "positive"}, 184 | ] 185 | for _ in x1["data"] 186 | ] 187 | 188 | x2 = editor.template( 189 | "I do not think that the restaurant has {nadj} {food}, {service} was {padj}", 190 | food=food_list, 191 | service=service_list, 192 | padj=padj, 193 | nadj=nadj, 194 | ) 195 | x2_label_set = [ 196 | [ 197 | {"category": "food", "polarity": "positive"}, 198 | {"category": "service", "polarity": "positive"}, 199 | ] 200 | for _ in x2["data"] 201 | ] 202 | 203 | x3 = editor.template( 204 | "I do not think that the restaurant has {padj} {food}, {service} was {nadj}", 205 | food=food_list, 206 | service=service_list, 207 | padj=padj, 208 | nadj=nadj, 209 | ) 210 | x3_label_set = [ 211 | [ 212 | {"category": "food", "polarity": "negative"}, 213 | {"category": "service", "polarity": "negative"}, 214 | ] 215 | for _ in x3["data"] 216 | ] 217 | 218 | x4 = editor.template( 219 | "I do not think that the restaurant has {nadj} {food}, {service} was {nadj}", 220 | food=food_list, 221 | service=service_list, 222 | nadj=nadj, 223 | ) 224 | x4_label_set = [ 225 | [ 226 | {"category": "food", "polarity": "positive"}, 227 | {"category": "service", "polarity": "negative"}, 228 | ] 229 | for _ in x4["data"] 230 | ] 231 | 232 | x5 = editor.template( 233 | "{service} was {padj} and I do not think that the restaurant has {padj} {food}", 234 | food=food_list, 235 | service=service_list, 236 | padj=padj, 237 | ) 238 | x5_label_set = [ 239 | 
[ 240 | {"category": "food", "polarity": "negative"}, 241 | {"category": "service", "polarity": "positive"}, 242 | ] 243 | for _ in x1["data"] 244 | ] 245 | 246 | x6 = editor.template( 247 | " {service} was {padj} and I do not think that the restaurant has {nadj} {food}", 248 | food=food_list, 249 | service=service_list, 250 | padj=padj, 251 | nadj=nadj, 252 | ) 253 | x6_label_set = [ 254 | [ 255 | {"category": "food", "polarity": "positive"}, 256 | {"category": "service", "polarity": "positive"}, 257 | ] 258 | for _ in x2["data"] 259 | ] 260 | 261 | x7 = editor.template( 262 | "{service} was {nadj} and I do not think that the restaurant has {padj} {food}", 263 | food=food_list, 264 | service=service_list, 265 | padj=padj, 266 | nadj=nadj, 267 | ) 268 | x7_label_set = [ 269 | [ 270 | {"category": "food", "polarity": "negative"}, 271 | {"category": "service", "polarity": "negative"}, 272 | ] 273 | for _ in x3["data"] 274 | ] 275 | 276 | x8 = editor.template( 277 | "{service} was {nadj} and I do not think that the restaurant has {nadj} {food}", 278 | food=food_list, 279 | service=service_list, 280 | nadj=nadj, 281 | ) 282 | x8_label_set = [ 283 | [ 284 | {"category": "food", "polarity": "positive"}, 285 | {"category": "service", "polarity": "negative"}, 286 | ] 287 | for _ in x4["data"] 288 | ] 289 | 290 | all_templates = [x1, x2, x3, x4, x5, x6, x7, x8] 291 | all_labels_checklist = ( 292 | x1_label_set 293 | + x2_label_set 294 | + x3_label_set 295 | + x4_label_set 296 | + x5_label_set 297 | + x6_label_set 298 | + x7_label_set 299 | + x8_label_set 300 | ) 301 | all_examples_checklist = [ 302 | data for template in all_templates for data in template["data"] 303 | ] 304 | 305 | all_templates = [x1, x2, x3, x4, x5, x6, x7, x8] 306 | return {"food is good": [x1 + x3 + x5 + x7], "food is bad": [x2 + x4 + x6 + x8]} 307 | 308 | 309 | def aspect_abstraction_test_fn_negated(): 310 | words = ["good", "nice"] 311 | words_2 = ["weird", "surprising", "unexpected", "unusual"] 312 | 313 | all_patches = [ 314 | "If food is described as {}, then sentiment is negative".format(word) 315 | for word in words_2 316 | ] 317 | aspects = ["steak", "tacos", "pizza", "pasta", "oysters", "filet mignon"] 318 | service_aspects = ["bartender", "waiter", "waitress", "manager", "barista"] 319 | 320 | inputs_1 = editor.template( 321 | "The {aspect} at the restaurant was not {words}", aspect=aspects, words=words_2 322 | )["data"] 323 | inputs_2 = editor.template( 324 | "I did not think that the {aspect} at the restaurant was {words}", 325 | aspect=aspects, 326 | words=words_2, 327 | )["data"] 328 | 329 | inputs_3 = editor.template( 330 | "The {aspect} at the restaurant was {words}", aspect=aspects, words=words 331 | )["data"] 332 | inputs_4 = editor.template( 333 | "The {aspect} at the restaurant was {words}", 334 | aspect=service_aspects, 335 | words=words_2, 336 | )["data"] 337 | 338 | inputs = inputs_1 + inputs_2 + inputs_3 + inputs_4 339 | labels = ( 340 | [1] * len(inputs_1) 341 | + [1] * len(inputs_2) 342 | + [1] * len(inputs_3) 343 | + [0] * len(inputs_4) 344 | ) 345 | return inputs, labels, all_patches 346 | 347 | 348 | def aspect_abstraction_test_fn(): 349 | words_1 = ["weird", "surprising", "unexpected", "unusual"] 350 | words_2 = ["wowowow", "goooood", "da bomb", "ultimate"] 351 | 352 | ## remember, that the baseline works for these 353 | patches = [ 354 | "If food is described as {}, then sentiment is negative".format(word) 355 | for word in words_1 356 | ] 357 | patches += [""] 358 | 359 | # but not for these. 
360 | patches_2 = [ 361 | "If food is described as {}, then sentiment is positive".format(word) 362 | for word in words_2 363 | ] 364 | all_patches = patches + patches_2 365 | 366 | aspects = ["steak", "tacos", "pizza", "pasta", "oysters", "filet mignon"] 367 | inputs_neg = editor.template( 368 | "The {aspect} at the restaurant was {words}", aspect=aspects, words=words_1 369 | )["data"] 370 | inputs_pos = editor.template( 371 | "The {aspect} at the restaurant was {words}", aspect=aspects, words=words_2 372 | )["data"] 373 | inputs = inputs_pos + inputs_neg 374 | labels = [1] * len(inputs_pos) + [0] * len(inputs_neg) 375 | return inputs, labels, patches 376 | 377 | 378 | def keyword_matching_test(words_1, words_2): 379 | patches = [ 380 | "If review contains phrases or words like {}, then sentiment is negative".format( 381 | word 382 | ) 383 | for word in words_1 384 | ] 385 | patches += [ 386 | "If review contains phrases or words like {}, then sentiment is positive".format( 387 | word 388 | ) 389 | for word in words_2 390 | ] 391 | patches += [""] 392 | restaurant_aspects = [ 393 | "service", 394 | "ambience", 395 | "food", 396 | "lighting", 397 | "steak", 398 | "waiter", 399 | "pasta", 400 | "pizza", 401 | ] 402 | movie_aspects = ["plot", "casting", "storyline", "ending", "writing"] 403 | subjs = ["book", "movie", "restaurant", "bar"] 404 | # Here we look at restaurants but also generic subjects. 405 | neg_templates = [ 406 | editor.template( 407 | "The {aspect} at the restaurant was {words}", 408 | aspect=restaurant_aspects, 409 | words=words_1, 410 | ), 411 | editor.template( 412 | "The {aspect} of the movie was {words}", aspect=movie_aspects, words=words_1 413 | ), 414 | editor.template( 415 | "We found the {subj} to be quite {words}", words=words_1, subj=subjs 416 | ), 417 | ] 418 | 419 | pos_templates = [ 420 | editor.template( 421 | "The {aspect} at the restaurant was {words}", 422 | aspect=restaurant_aspects, 423 | words=words_2, 424 | ), 425 | editor.template( 426 | "The {aspect} of the movie was {words}", aspect=movie_aspects, words=words_2 427 | ), 428 | editor.template( 429 | "We found the {subj} to be quite {words}", words=words_2, subj=subjs 430 | ), 431 | ] 432 | inputs = [] 433 | labels = [] 434 | for t in neg_templates: 435 | inputs += t["data"] 436 | labels += [0] * len(t["data"]) 437 | for t in pos_templates: 438 | inputs += t["data"] 439 | labels += [1] * len(t["data"]) 440 | return inputs, labels, patches 441 | 442 | 443 | def keyword_matching_real_words(): 444 | words_1 = ["weird", "unexpected", "unusual", "strange"] 445 | words_2 = ["interesting", "amazeballs", "gooood", "wooooow"] 446 | return keyword_matching_test(words_1, words_2) 447 | 448 | 449 | def keyword_matching_gibberish_words(): 450 | words_1_gibberish = ["wug", "zubin", "shug"] 451 | words_2_gibberish = ["stup", "zink", "zoop"] 452 | return keyword_matching_test(words_1_gibberish, words_2_gibberish) 453 | -------------------------------------------------------------------------------- /convert_yaml_to_data.py: -------------------------------------------------------------------------------- 1 | from curses import meta 2 | from distutils.spawn import find_executable 3 | import yaml 4 | import argparse 5 | import json 6 | import random 7 | from checklist.editor import Editor 8 | from collections import Counter 9 | from copy import deepcopy 10 | from eval_utils import load_model, predict_stuff 11 | from itertools import product 12 | from collections import defaultdict as ddict 13 | import numpy as np 14 | 
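# Converts yaml patch specs (Explanations / Templates / Fillers) into json
# training records; e.g. the prompt construction used throughout:
#   >>> prompt_style_1('food is good', 'The tacos were great')
#   'Explanation: food is good. Input: The tacos were great'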
15 | editor = Editor()
16 | 
17 | 
18 | def prompt_style_0(explanation, sentence):
19 |     return "Explanation: %s.\nInput: %s." % (explanation, sentence)
20 | 
21 | def prompt_style_1(explanation, sentence):
22 |     if len(explanation) > 0:
23 |         out = "Explanation: {}. Input: {}".format(explanation, sentence)
24 |     else:
25 |         out = "Input: {}".format(sentence)
26 |     out = out.rstrip()
27 |     return out[:-1].rstrip() if out[-1] == '.' else out
28 | 
29 | def prompt_style_2(explanation, sentence):
30 |     return "Steering hints: %s. '%s'" % (explanation, sentence)
31 | 
32 | def deconstruct(explanation_and_instance_list):
33 |     explanations = []
34 |     instances = []
35 |     for eandi in explanation_and_instance_list:
36 |         try:
37 |             split_idx = eandi.find('Input')
38 |             explanation = ' '.join(eandi[:split_idx].split(':')[1:]).strip()
39 |             instance = ' '.join(eandi[split_idx:].split(':')[1:]).strip()
40 |         except:
41 |             instance = ' '.join(eandi.split(':')[1:]).strip()
42 |             explanation = ''
43 |         if explanation and explanation[-1] == '.':
44 |             explanation = explanation[:-1]
45 |         explanations.append(explanation)
46 |         instances.append(instance)
47 |     return explanations, instances
48 | 
49 | def get_default_labels(instances, mode='sentiment'):
50 |     instances_with_exp = [('', instance) for instance in instances]
51 |     if mode == 'sentiment':
52 |         model_obj = load_model('models/t5-large-sst-no-exp')
53 |     else:
54 |         model_obj = load_model('models/t5-large-spouse_re_0.1')
55 | 
56 |     # label each instance with the base model's prediction given no explanation
57 |     model_out = predict_stuff(instances_with_exp, [0]*len(instances), model_obj, verbose=True, prompt_style='p1')
58 |     return model_out.argmax(axis=1)
59 | 
60 | 
61 | prompt_styles = {"p0": prompt_style_0, "p1": prompt_style_1, "p2": prompt_style_2}
62 | chosen_prompt_style = "p1"
63 | 
64 | 
65 | def create_data(gpt_input, vals, key):
66 |     # replace all occurrences
67 |     num_occurences = gpt_input.count(key)
68 |     # consider the cartesian product num_occurences times
69 |     replace_tuple_list = list(product(vals, repeat=num_occurences))
70 |     all_data = []
71 |     offset = len(key)
72 |     for replace_tuple in replace_tuple_list:
73 |         i = 0
74 |         f = 0
75 |         gpt_input_curr = deepcopy(gpt_input)
76 |         while f < num_occurences:
77 |             if gpt_input_curr[i : i + offset] == key:
78 |                 gpt_input_curr = "%s%s%s" % (
79 |                     gpt_input_curr[:i],
80 |                     replace_tuple[f],
81 |                     gpt_input_curr[i + offset :],
82 |                 )
83 |                 f += 1
84 |             i += 1
85 |         all_data.append(gpt_input_curr)
86 |     return {"data": all_data}
87 | 
88 | 
89 | def gender(person):
90 |     male_names = ['Bob', 'Stephen', 'Lee', 'Tao']
91 |     female_names = ['Mary', 'Alice', 'Stacy']
92 |     if person in male_names:
93 |         return 'male'
94 |     elif person in female_names:
95 |         return 'female'
96 |     else:
97 |         return 'unknown'
98 | 
99 | def subsample(data_dict, sample_size):
100 |     idxs = random.sample(range(len(next(iter(data_dict.values())))), k=sample_size)  # same indices for every parallel list
101 |     return {key: [data_dict[key][i] for i in idxs] for key in data_dict}
102 | 
103 | def read_data_zsb(file_name, add_noop=True):
104 |     data = {'explanations': [], 'instances': [], 'labels': []}
105 |     with open(file_name, 'r') as stream:
106 |         yaml_dict = yaml.load(stream)
107 |     all_explanations = [val for _, val in yaml_dict['Explanations'].items()]
108 |     all_template_sets = [val for _, val in yaml_dict['Templates'].items()]
109 |     for _, (exps, template_set) in enumerate(zip(all_explanations, all_template_sets)):
110 |         if type(exps) != list:
111 |             exps = [exps]
112 |         for exp in exps:
113 |             for template in template_set:
114 |                 label = template[-1]
115 |                 sentence = template[0]
116 |                 all_args = {arg: 
yaml_dict["Fillers"][arg] for arg in template[1:-1]} 117 | print(sentence) 118 | all_examples = editor.template(sentence, **all_args, remove_duplicates=True, meta=True) 119 | data['explanations'] += [exp] * len(all_examples.data) 120 | data['instances'] += all_examples.data 121 | data['labels'] += [label] * len(all_examples.data) 122 | 123 | print(Counter(data['labels'])) 124 | if add_noop: 125 | instance2data = ddict(list) 126 | for idx, instance in enumerate(data['instances']): 127 | label = data['labels'][idx] 128 | instance = data['instances'][idx] 129 | exp = data['explanations'][idx] 130 | #instance2data[instance].append((label, instance, exp)) 131 | # TODO: TRAIN AN APPLIES CLASSIFIER! 132 | instance2data[instance].append((1, instance, exp)) 133 | 134 | all_instances = list(set(data['instances'])) 135 | 136 | 137 | 138 | print("getting default labels for {} instances".format(len(all_instances))) 139 | #labels = get_default_labels(all_instances, mode='re') 140 | labels = [0]*len(all_instances) 141 | # baseline out... 142 | idx2labels = ddict(list) 143 | for idx, label in enumerate(labels): 144 | idx2labels[label].append(idx) 145 | 146 | 147 | #if len(idx2labels[1]) <= len(idx2labels[0]): 148 | # chosen_negs = random.sample(idx2labels[0], k = len(idx2labels[1])) 149 | # chosen_idxs = [(1, idx) for idx in idx2labels[1]] + [(0, idx) for idx in chosen_negs] 150 | #else: 151 | chosen_idxs = [(1, idx) for idx in idx2labels[1]] + [(0, idx) for idx in idx2labels[0]] 152 | #print("baseline acc: {}".format(baseline_acc)) 153 | print(len(chosen_idxs), len(idx2labels[1]), len(idx2labels[0])) 154 | #for label, instance in zip(labels, all_instances): 155 | 156 | # we need to keep examples that have a different label compared to the noop label! 157 | # if that is satisfied, we are good to go! 158 | 159 | new_data = {'explanations': [], 'labels': [], 'instances': []} 160 | for label, idx in chosen_idxs: 161 | instance = all_instances[idx] 162 | old_data_list = instance2data[instance] 163 | to_use = False 164 | for l, i, exp in old_data_list: 165 | #if l != label: 166 | if True: 167 | to_use = True 168 | new_data['explanations'].append(exp) 169 | new_data['instances'].append(i) 170 | new_data['labels'].append(l) 171 | if to_use: 172 | new_data['explanations'].append('default noop explanation') 173 | #new_data['labels'].append(int(label)) 174 | #TODO: TRAIN AN APPLIES CLASSIFIER! 175 | new_data['labels'].append(0) 176 | new_data['instances'].append(instance) 177 | else: 178 | pass 179 | # TODO: throw this into a rejected pile, and later, maybe accept some! 180 | # now add the old inputs here! 
181 | # subsample 10k examples 182 | return new_data 183 | 184 | 185 | def read_data_re3(file_name): 186 | def get_only_unique(data): 187 | new_data = {'explanations': [], 'instances': [], 'labels': []} 188 | seen = set() 189 | for exp, instance, label in zip(data['explanations'], data['instances'], data['labels']): 190 | if (exp, instance) not in seen: 191 | new_data['explanations'].append(exp) 192 | new_data['instances'].append(instance) 193 | new_data['labels'].append(label) 194 | seen.add((exp, instance)) 195 | 196 | print(len(data['instances']), len(new_data['instances'])) 197 | return new_data 198 | def get_instances(template, patch, entity_key, fillers): 199 | editor_temp = '%s\t%s' %(patch, template) 200 | all_args = {arg: fillers[arg] for arg in fillers if '{%s'%arg in editor_temp} 201 | all_examples = editor.template(editor_temp, **all_args, remove_duplicates=True, meta=True) 202 | instances = [] 203 | patches = [] 204 | 205 | 206 | person_1_key, person_2_key = entity_key.split('_') 207 | for inp, metadata in zip(all_examples.data, all_examples.meta): 208 | patch_curr, inp_curr = inp.split('\t') 209 | try: 210 | p1 = metadata[person_1_key] 211 | p2 = metadata[person_2_key] 212 | except: 213 | import pdb; pdb.set_trace(); 214 | instances.append('{}. Entity1: {}. Entity2: {}'.format(inp_curr, p1, p2)) 215 | patches.append(patch_curr) 216 | return instances, patches 217 | # for feature based patches on spouse 218 | data = {'explanations': [], 'instances': [], 'labels': []} 219 | with open(file_name, 'r') as stream: 220 | yaml_dict = yaml.load(stream) 221 | fillers = yaml_dict['FILLERS'] 222 | for key in yaml_dict['Templates']: 223 | templates = yaml_dict['Templates'][key] 224 | 225 | patches = yaml_dict['Patches'][key] # get corresponding patches 226 | for template in templates: 227 | for patch in patches: 228 | all_labels = yaml_dict['Labels'][key][0] 229 | for entity_key in all_labels: 230 | instances_curr, patches_curr = get_instances(template, patch, entity_key, fillers) 231 | # also add entity2_entity? 
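                    # (the swapped Entity2/Entity1 ordering is generated just below,
                    # so each patch is grounded regardless of entity order)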
232 | label_curr = all_labels[entity_key] 233 | data['instances'] += instances_curr 234 | data['explanations'] += patches_curr 235 | data['labels'] +=[label_curr]*len(instances_curr) 236 | 237 | p1, p2 = entity_key.split('_') 238 | instances_curr, patches_curr = get_instances(template, patch, '{}_{}'.format(p2, p1), fillers) 239 | # also add entity2_entity1 240 | label_curr = all_labels[entity_key] 241 | data['instances'] += instances_curr 242 | data['explanations'] += patches_curr 243 | data['labels'] +=[label_curr]*len(instances_curr) 244 | 245 | 246 | data = get_only_unique(data) 247 | return data 248 | 249 | 250 | 251 | def read_data_re2(file_name): 252 | inverses = {'e1': 'e2', 253 | 'e2': 'e1', 254 | 'e3': 'e3', 255 | 'e4': 'e4', 256 | 'e5': 'e5', 257 | 'e6': 'e6'} 258 | def get_instances(template, patch, entity_key, fillers): 259 | editor_temp = '%s\t%s' %(patch, template) 260 | all_args = {arg: fillers[arg] for arg in fillers if '{%s'%arg in editor_temp} 261 | all_examples = editor.template(editor_temp, **all_args, remove_duplicates=True, meta=True) 262 | #if '{location}' in template: 263 | #all_examples = editor.template(template, p=fillers['p'], location=fillers['location'], remove_duplicates=True, meta=True) 264 | #else: 265 | #all_examples = editor.template(template, p=fillers['p'], remove_duplicates=True, meta=True) 266 | instances = [] 267 | patches = [] 268 | 269 | 270 | person_1_key, person_2_key = entity_key.split('_') 271 | for inp, metadata in zip(all_examples.data, all_examples.meta): 272 | patch_curr, inp_curr = inp.split('\t') 273 | try: 274 | p1 = metadata[person_1_key] 275 | p2 = metadata[person_2_key] 276 | except: 277 | import pdb; pdb.set_trace(); 278 | # create biased synthetic data 279 | # if label == 1 and gender(p1) == gender(p2): 280 | # continue 281 | instances.append('{}. Entity1: {}. 
Entity2: {}'.format(inp_curr, p1, p2)) 282 | patches.append(patch_curr) 283 | return instances, patches 284 | 285 | def get_only_unique(data): 286 | new_data = {'explanations': [], 'instances': [], 'labels': []} 287 | seen = set() 288 | for exp, instance, label in zip(data['explanations'], data['instances'], data['labels']): 289 | if (exp, instance) not in seen: 290 | new_data['explanations'].append(exp) 291 | new_data['instances'].append(instance) 292 | new_data['labels'].append(label) 293 | seen.add((exp, instance)) 294 | 295 | print(len(data['instances']), len(new_data['instances'])) 296 | return new_data 297 | 298 | data = {'explanations': [], 'instances': [], 'labels': []} 299 | with open(file_name, 'r') as stream: 300 | yaml_dict = yaml.load(stream) 301 | all_explanations = {key: val[0] for key, val in yaml_dict['Explanations'].items()} 302 | for key in yaml_dict['Templates']: 303 | templates = yaml_dict['Templates'][key] 304 | labels = yaml_dict['LABELS'][key] 305 | for entity_key_dict in labels: 306 | entity_key = list(entity_key_dict.keys())[0] 307 | labels = entity_key_dict[entity_key].split(', ') 308 | print(len(labels)) 309 | print(len(templates)) 310 | for idx, template in enumerate(templates): 311 | if labels[idx] == '_': 312 | continue 313 | positive_patch = all_explanations[labels[idx]] 314 | instances, pos_patches = get_instances(template, positive_patch, entity_key, yaml_dict['FILLERS']) 315 | data['instances'] += instances 316 | data['explanations'] += pos_patches 317 | data['labels'] += [1]*len(instances) 318 | 319 | if labels[idx] in inverses: 320 | inverse_patch = all_explanations[inverses[labels[idx]]] 321 | e1, e2 = entity_key.split('_') 322 | instances_2, inverse_patches = get_instances(template, inverse_patch, '{}_{}'.format(e2, e1), yaml_dict['FILLERS']) 323 | data['instances'] += instances_2 324 | data['explanations'] += inverse_patches 325 | data['labels'] += [1]*len(instances_2) 326 | 327 | 328 | data = get_only_unique(data) 329 | return data 330 | 331 | 332 | def read_data_re(file_name, add_noop=True): 333 | data = {'explanations': [], 'instances': [], 'labels': []} 334 | with open(file_name, 'r') as stream: 335 | yaml_dict = yaml.load(stream) 336 | all_explanations = [val for _, val in yaml_dict['Explanations'].items()] 337 | all_template_sets = [val for _, val in yaml_dict['Templates'].items()] 338 | for _, (exps, template_set) in enumerate(zip(all_explanations, all_template_sets)): 339 | if type(exps) != list: 340 | exps = [exps] 341 | for exp in exps: 342 | for template in template_set: 343 | label = template[-1] 344 | person_2_key = template[-2] 345 | person_1_key = template[-3] 346 | sentence = template[0] 347 | all_args = {arg: yaml_dict["Fillers"][arg] for arg in template[1:-3]} 348 | all_examples = editor.template(sentence, **all_args, remove_duplicates=True, meta=True) 349 | instances_curr = [] 350 | for inp, metadata in zip(all_examples.data, all_examples.meta): 351 | p1 = metadata[person_1_key] 352 | p2 = metadata[person_2_key] 353 | # create biased synthetic data 354 | if label == 1 and gender(p1) == gender(p2): 355 | continue 356 | instances_curr.append('{}. Entity1: {}. 
def read_data_re(file_name, add_noop=True):
    data = {'explanations': [], 'instances': [], 'labels': []}
    with open(file_name, 'r') as stream:
        yaml_dict = yaml.safe_load(stream)
    all_explanations = [val for _, val in yaml_dict['Explanations'].items()]
    all_template_sets = [val for _, val in yaml_dict['Templates'].items()]
    for exps, template_set in zip(all_explanations, all_template_sets):
        if not isinstance(exps, list):
            exps = [exps]
        for exp in exps:
            for template in template_set:
                label = template[-1]
                person_2_key = template[-2]
                person_1_key = template[-3]
                sentence = template[0]
                all_args = {arg: yaml_dict["Fillers"][arg] for arg in template[1:-3]}
                all_examples = editor.template(sentence, **all_args, remove_duplicates=True, meta=True)
                instances_curr = []
                for inp, metadata in zip(all_examples.data, all_examples.meta):
                    p1 = metadata[person_1_key]
                    p2 = metadata[person_2_key]
                    # create biased synthetic data: drop same-gender positives
                    if label == 1 and gender(p1) == gender(p2):
                        continue
                    instances_curr.append('{}. Entity1: {}. Entity2: {}'.format(inp, p1, p2))
                data['explanations'] += [exp] * len(instances_curr)
                data['instances'] += instances_curr
                data['labels'] += [label] * len(instances_curr)

    print(Counter(data['labels']))
    if add_noop:
        instance2data = ddict(list)
        for idx, instance in enumerate(data['instances']):
            exp = data['explanations'][idx]
            # TODO: train an applies classifier; every paired patch counts as "applies" (label 1)
            instance2data[instance].append((1, instance, exp))

        all_instances = list(set(data['instances']))

        print("getting default labels for {} instances".format(len(all_instances)))
        # baseline the default labels out: every instance gets a noop label of 0
        labels = [0] * len(all_instances)
        idx2labels = ddict(list)
        for idx, label in enumerate(labels):
            idx2labels[label].append(idx)

        chosen_idxs = [(1, idx) for idx in idx2labels[1]] + [(0, idx) for idx in idx2labels[0]]
        print(len(chosen_idxs), len(idx2labels[1]), len(idx2labels[0]))

        new_data = {'explanations': [], 'labels': [], 'instances': []}
        for label, idx in chosen_idxs:
            instance = all_instances[idx]
            old_data_list = instance2data[instance]
            to_use = False
            for l, i, exp in old_data_list:
                # we used to keep only pairs whose label differs from the noop label;
                # we now keep every (patch, instance) pair
                to_use = True
                new_data['explanations'].append(exp)
                new_data['instances'].append(i)
                new_data['labels'].append(l)
            if to_use:
                new_data['explanations'].append('default noop explanation')
                # TODO: train an applies classifier; the noop never "applies" (label 0)
                new_data['labels'].append(0)
                new_data['instances'].append(instance)
            else:
                pass
                # TODO: throw this into a rejected pile, and later, maybe accept some!
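        # At this point each kept instance appears once per matching patch with
        # label 1 ("patch applies") and once with the 'default noop explanation'
        # and label 0, which is exactly the supervision an applies classifier needs.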
        # TODO: now add the old inputs here; subsample 10k examples
        return new_data
    else:
        return data


def read_data(args, file_name):
    prompt_func = prompt_styles[chosen_prompt_style]
    data = {"examples": [], "labels": [], 'is_gold': []}
    with open(file_name, "r") as stream:
        yaml_dict = yaml.safe_load(stream)
    all_explanations = [val for key, val in yaml_dict["Explanations"].items()]
    all_template_sets = [val for key, val in yaml_dict["Templates"].items()]
    for explanation_obj, template_set in zip(all_explanations, all_template_sets):
        if not isinstance(explanation_obj, list):
            explanation_obj = [explanation_obj]

        for explanation in explanation_obj:
            for template in template_set:
                label = template[-1]
                sentence = template[0]
                gpt_input = prompt_func(explanation, sentence)
                if len(template) > 2:
                    all_args = {
                        arg: yaml_dict["Fillers"][arg] for arg in template[1:-1]
                    }
                    all_padj = all('padj' in arg for arg in all_args if 'adj' in arg)
                    all_nadj = all('nadj' in arg for arg in all_args if 'adj' in arg)
                    try:
                        all_examples = editor.template(gpt_input, **all_args, remove_duplicates=True)
                    except Exception:
                        raise ValueError('could not instantiate template: {}'.format(gpt_input))
                    data["examples"] += all_examples["data"]
                    if isinstance(label, int):
                        data["labels"] += [label] * len(all_examples["data"])
                        if all_padj or all_nadj:
                            print(gpt_input)
                            data['is_gold'] += [1] * len(all_examples['data'])
                        else:
                            data['is_gold'] += [0] * len(all_examples['data'])
                    else:
                        # label is a fraction: mark that share of the examples positive
                        idxs = list(range(len(all_examples["data"])))
                        chosen_idxs = random.sample(idxs, k=int(len(idxs) * label))
                        data["labels"] += [
                            (1 if idx in chosen_idxs else 0) for idx in idxs
                        ]
                        data['is_gold'] += [0] * len(idxs)  # keep is_gold aligned with examples
                else:
                    data["examples"] += [gpt_input]
                    data["labels"] += [label]
                    data['is_gold'] += [0]  # keep is_gold aligned with examples
    explanations, instances = deconstruct(data['examples'])
    data['explanations'] = [exp.lower() for exp in explanations]
    data['instances'] = instances
    # for each explanation we have positives; the negatives come from the noop,
    # so we only need to fix up the labels below
    default_idxs = [idx for idx, ex in enumerate(data['explanations']) if ex == 'default noop explanation']
    if len(default_idxs) == 0:
        print("No default explanation. Make sure this is the correct behavior.")
        all_instances = list(set(data['instances']))
        print("getting default labels for {} instances".format(len(all_instances)))
        labels = get_default_labels(all_instances)
        for instance, label in zip(all_instances, labels):
            data['explanations'].append('default noop explanation')
            data['labels'].append(int(label))
            data['instances'].append(instance)
    else:
        default_instances = [data['instances'][idx] for idx in default_idxs]
        labels = get_default_labels(default_instances)
        for oidx, idx in enumerate(default_idxs):
            data['labels'][idx] = int(labels[oidx])
    return data


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="create data from yaml file")
    parser.add_argument("--exp_dir", type=str)
    parser.add_argument("--mode", type=str, default='sentiment')

    args = parser.parse_args()
    file_name = "{}/explanations.yaml".format(args.exp_dir)
    if args.mode == 'sentiment':
        data = read_data(args, file_name)
    elif args.mode == 're':
        data = read_data_re3(file_name)
    else:
        data = read_data_zsb(file_name, add_noop=True)
    print(Counter(data['labels']))
    with open("{}/synthetic_data.json".format(args.exp_dir), "w") as writer:
        json.dump(data, writer)
--------------------------------------------------------------------------------
/override-checklist-experiments.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ad94f801",
   "metadata": {},
   "source": [
    "#### For override patches, we use checklists to evaluate how well our model can:\n",
    "- use patches when they refer to abstract conditions\n",
    "- avoid spurious behaviors like keyword matching or ignoring negated contexts"
   ]
  },
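  {
   "cell_type": "markdown",
   "id": "patch-example-note",
   "metadata": {},
   "source": [
    "(An illustrative patch/instance pair, sketched here for intuition rather than taken from the datasets:\n",
    "patch = *If the review says the food is wug, then food is good*; instance = *The pasta was wug, but the service was slow.*\n",
    "A well-patched model should read the food aspect as positive while leaving its service prediction unchanged.)"
   ]
  },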
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "24e423ed",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6baf50f7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of T5ForConditionalGenerationMultipleHeads were not initialized from the model checkpoint at t5-large and are newly initialized: ['encoder.embed_tokens.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "primary mode: exp_applies_predictor\n",
      "splicing parts from pretrained model\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of T5ForConditionalGeneration were not initialized from the model checkpoint at t5-large and are newly initialized: ['encoder.embed_tokens.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from eval_utils import load_model\n",
    "\n",
    "# TODO: make this notebook better. try to use mtl vs non mtl and see difference.\n",
    "# TODO: see what the baseline without any explanations or anything gets.\n",
    "\n",
    "#path_name = '/u/scr/smurty/LanguageExplanations/gpt2-finetune-generalization/models/t5-large-sst-no-exp'\n",
    "#path_name = '/u/scr/smurty/LanguageExplanations/all_models/t5-large-sst-fix-mtl-4'\n",
    "\n",
    "path_name = 't5-sst-overrides-mtl-newest'\n",
    "model_obj = load_model(path_name, primary_mode='exp_applies_predictor')"
   ]
  },
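  {
   "cell_type": "markdown",
   "id": "gold-applies-note",
   "metadata": {},
   "source": [
    "The next cell builds, for every checklist dataset, (i) each patch's bare condition together with its label clause `[P(negative), P(positive)]`, and (ii) gold per-input \"applies\" labels: a keyword patch applies exactly when its keyword occurs in the input, while for datasets keyed with `'checklist'` the labels come from aspect annotations via `cond2label_dict`."
   ]
  },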
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f65adba1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from data_fns import keyword_matching_real_words, keyword_matching_gibberish_words, aspect_abstraction_test_fn\n",
    "\n",
    "all_datasets = {'keyword_real': keyword_matching_real_words(),\n",
    "                'aspect': aspect_abstraction_test_fn(),\n",
    "                'keyword_gibberish': keyword_matching_gibberish_words()}\n",
    "\n",
    "\n",
    "### construct a label according to the conditions, since we have an override.\n",
    "def cond2label_dict(cond, orig_label):\n",
    "    is_food = 'food' in cond\n",
    "    is_good = 'good' in cond\n",
    "    label_name2label = {'positive': 1, 'negative': 0, 'NAN': -1}\n",
    "    if is_food:\n",
    "        dict_to_use = [label for label in orig_label if label['category'] == 'food'][0]\n",
    "    else:\n",
    "        dict_to_use = [label for label in orig_label if label['category'] == 'service'][0]\n",
    "\n",
    "    label = label_name2label[dict_to_use['polarity']]\n",
    "    cond_label = int(label_name2label[dict_to_use['polarity']] == is_good)\n",
    "    return label, cond_label\n",
    "\n",
    "\n",
    "def conds_and_labels(data_tuple):\n",
    "    def helper(explanation, inputs):\n",
    "        word = explanation.split(\",\")[0].split(\" \")[-1]\n",
    "        ### the patch applies iff its keyword appears in the input\n",
    "        labels = [int(word in cinput) for cinput in inputs]\n",
    "        cond = \" \".join(explanation.split(\",\")[0].split(\" \")[1:])\n",
    "        return cond, labels\n",
    "\n",
    "    data_conds = []\n",
    "    data_label_sets = {}\n",
    "    for explanation in data_tuple[-1]:\n",
    "        cond, labels = helper(explanation, data_tuple[0])\n",
    "        if 'positive' in explanation:\n",
    "            l = [0, 1]\n",
    "        else:\n",
    "            l = [1, 0]\n",
    "        data_conds.append((cond, l))\n",
    "        data_label_sets[cond] = labels\n",
    "    return data_conds, data_label_sets\n",
    "\n",
    "\n",
    "all_conds = {}\n",
    "all_label_sets = {}\n",
    "for key in all_datasets:\n",
    "    val = all_datasets[key]\n",
    "    if 'checklist' in key:\n",
    "        conds = [('food is good', [0,1]), ('service is good', [0,1]), ('food is bad', [1,0]), ('service is bad', [1,0])]\n",
    "        label_sets = {cond: [cond2label_dict(cond, l) for l in val[1]] for cond, _ in conds}\n",
    "    else:\n",
    "        conds, label_sets = conds_and_labels(val)\n",
    "\n",
    "    all_conds[key] = conds\n",
    "    all_label_sets[key] = label_sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f2e86a82",
   "metadata": {},
   "outputs": [],
   "source": [
    "from eval_utils import predict_stuff\n",
    "import numpy as np\n",
    "import itertools\n",
    "\n",
    "\n",
    "def apply_patch_soft(exp_applies_probs, baseline_probs, label_clause):\n",
    "    # interpolate between the patch's label clause and the unpatched prediction,\n",
    "    # weighted by how likely the patch is to apply\n",
    "    x = np.array([label_clause]).repeat(len(baseline_probs), 0)\n",
    "    applies_prob = exp_applies_probs[:, 1].reshape(-1, 1)\n",
    "    return applies_prob * x + (1 - applies_prob) * baseline_probs\n",
    "\n",
    "\n",
    "def get_scores_multiple_patches_hard(data, cond_list, examine=False):\n",
    "    no_exps = [('', ex) for ex in data[0]]\n",
    "    no_exp_probs = predict_stuff(no_exps, [0]*len(no_exps), model_obj, 'p1', verbose=False, mode='task_predictor')\n",
    "    cond_probs = []\n",
    "    all_patched_probs = []\n",
    "    for idx, (cond, label_clause) in enumerate(cond_list):\n",
    "        if cond == '':\n",
    "            continue\n",
    "        contextualized = [(cond, ex) for ex in data[0]]\n",
    "        output_probs = predict_stuff(contextualized, itertools.repeat(0), model_obj, 'p1', verbose=False)\n",
    "        cond_probs.append(np.log(output_probs[:, 1]))  # log p(cond | x)\n",
    "\n",
    "        patched_probs = apply_patch_soft(output_probs, no_exp_probs, label_clause)  # Pr(y | x, lp)\n",
    "        all_patched_probs.append(patched_probs[:, 1])\n",
    "\n",
    "    # how should each patch be weighted? for now: pick the best patch and apply it\n",
\n", 180 | " all_patched_probs = np.stack(all_patched_probs, axis=1) # D x P\n", 181 | " cond_probs = np.stack(cond_probs, axis=1) # D x P\n", 182 | " best_patches = np.argmax(cond_probs, axis=1) # D x l\n", 183 | " \n", 184 | " ptrue = np.array([p[idx] for p, idx in zip(all_patched_probs, best_patches)])\n", 185 | " pfalse = 1.0 - ptrue\n", 186 | " return no_exp_probs, np.stack([pfalse, ptrue]).T" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 5, 192 | "id": "365a8396", 193 | "metadata": {}, 194 | "outputs": [ 195 | { 196 | "data": { 197 | "application/vnd.jupyter.widget-view+json": { 198 | "model_id": "48c169ef981749fb8250bce7cf8ff4f5", 199 | "version_major": 2, 200 | "version_minor": 0 201 | }, 202 | "text/plain": [ 203 | " 0%| | 0/1 [00:00\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mdata_fns\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_yelp_data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;31m# set conflicting to True for Table-4 and False for Table-3\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0md1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_yelp_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconflicting\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 387 | "\u001b[0;32m/juice2/scr2/smurty/LanguagePatching/data_fns.py\u001b[0m in \u001b[0;36mget_yelp_data\u001b[0;34m(conflicting)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_yelp_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconflicting\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 51\u001b[0;31m \u001b[0mdf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_csv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'compare_model_steering_labeled.csv'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 52\u001b[0m \u001b[0mfix_quotes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m\"'\"\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0mdf_inputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mfix_quotes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcinput\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mcinput\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdf\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'Input'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 388 | 
"\u001b[0;32m/u/nlp/anaconda/main/anaconda3/envs/shikhar-basic/lib/python3.8/site-packages/pandas/io/parsers.py\u001b[0m in \u001b[0;36mread_csv\u001b[0;34m(filepath_or_buffer, sep, delimiter, header, names, index_col, usecols, squeeze, prefix, mangle_dupe_cols, dtype, engine, converters, true_values, false_values, skipinitialspace, skiprows, skipfooter, nrows, na_values, keep_default_na, na_filter, verbose, skip_blank_lines, parse_dates, infer_datetime_format, keep_date_col, date_parser, dayfirst, cache_dates, iterator, chunksize, compression, thousands, decimal, lineterminator, quotechar, quoting, doublequote, escapechar, comment, encoding, dialect, error_bad_lines, warn_bad_lines, delim_whitespace, low_memory, memory_map, float_precision, storage_options)\u001b[0m\n\u001b[1;32m 608\u001b[0m \u001b[0mkwds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwds_defaults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 609\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 610\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_read\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilepath_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 611\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 612\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 389 | "\u001b[0;32m/u/nlp/anaconda/main/anaconda3/envs/shikhar-basic/lib/python3.8/site-packages/pandas/io/parsers.py\u001b[0m in \u001b[0;36m_read\u001b[0;34m(filepath_or_buffer, kwds)\u001b[0m\n\u001b[1;32m 460\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 461\u001b[0m \u001b[0;31m# Create the parser.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 462\u001b[0;31m \u001b[0mparser\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTextFileReader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilepath_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 463\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 464\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mchunksize\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0miterator\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 390 | "\u001b[0;32m/u/nlp/anaconda/main/anaconda3/envs/shikhar-basic/lib/python3.8/site-packages/pandas/io/parsers.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, f, engine, **kwds)\u001b[0m\n\u001b[1;32m 817\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptions\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"has_index_names\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mkwds\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"has_index_names\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 818\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 819\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_engine\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_make_engine\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mengine\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 820\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 821\u001b[0m \u001b[0;32mdef\u001b[0m 
\u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 391 | "\u001b[0;32m/u/nlp/anaconda/main/anaconda3/envs/shikhar-basic/lib/python3.8/site-packages/pandas/io/parsers.py\u001b[0m in \u001b[0;36m_make_engine\u001b[0;34m(self, engine)\u001b[0m\n\u001b[1;32m 1048\u001b[0m )\n\u001b[1;32m 1049\u001b[0m \u001b[0;31m# error: Too many arguments for \"ParserBase\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1050\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mmapping\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mengine\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptions\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[call-arg]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1051\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1052\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_failover_to_python\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 392 | "\u001b[0;32m/u/nlp/anaconda/main/anaconda3/envs/shikhar-basic/lib/python3.8/site-packages/pandas/io/parsers.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, src, **kwds)\u001b[0m\n\u001b[1;32m 1865\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1866\u001b[0m \u001b[0;31m# open handles\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1867\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_open_handles\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msrc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1868\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhandles\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1869\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mkey\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m\"storage_options\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"encoding\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"memory_map\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"compression\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 393 | "\u001b[0;32m/u/nlp/anaconda/main/anaconda3/envs/shikhar-basic/lib/python3.8/site-packages/pandas/io/parsers.py\u001b[0m in \u001b[0;36m_open_handles\u001b[0;34m(self, src, kwds)\u001b[0m\n\u001b[1;32m 1360\u001b[0m \u001b[0mLet\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mreaders\u001b[0m \u001b[0mopen\u001b[0m \u001b[0mIOHanldes\u001b[0m \u001b[0mafter\u001b[0m \u001b[0mthey\u001b[0m \u001b[0mare\u001b[0m \u001b[0mdone\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtheir\u001b[0m \u001b[0mpotential\u001b[0m \u001b[0mraises\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1361\u001b[0m \"\"\"\n\u001b[0;32m-> 1362\u001b[0;31m self.handles = get_handle(\n\u001b[0m\u001b[1;32m 1363\u001b[0m \u001b[0msrc\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1364\u001b[0m 
\u001b[0;34m\"r\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 394 | "\u001b[0;32m/u/nlp/anaconda/main/anaconda3/envs/shikhar-basic/lib/python3.8/site-packages/pandas/io/common.py\u001b[0m in \u001b[0;36mget_handle\u001b[0;34m(path_or_buf, mode, encoding, compression, memory_map, is_text, errors, storage_options)\u001b[0m\n\u001b[1;32m 645\u001b[0m \u001b[0merrors\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"replace\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 646\u001b[0m \u001b[0;31m# Encoding\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 647\u001b[0;31m handle = open(\n\u001b[0m\u001b[1;32m 648\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 649\u001b[0m \u001b[0mioargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 395 | "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'compare_model_steering_labeled.csv'" 396 | ] 397 | } 398 | ], 399 | "source": [ 400 | "from data_fns import get_yelp_data\n", 401 | "# set conflicting to True for Table-4 and False for Table-3\n", 402 | "d1 = get_yelp_data(conflicting=False)\n", 403 | "print(len(d1))" 404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "execution_count": null, 409 | "id": "ac8a014c", 410 | "metadata": {}, 411 | "outputs": [], 412 | "source": [ 413 | "def cond2label_dict(cond, orig_label):\n", 414 | " is_food = 'food' in cond\n", 415 | " is_good = 'good' in cond\n", 416 | " \n", 417 | " label_name2label = {'positive': 1, 'negative': 0, 'NAN': -1}\n", 418 | " if is_food:\n", 419 | " dict_to_use = [label for label in orig_label if label['category'] == 'food'][0]\n", 420 | " else:\n", 421 | " dict_to_use = [label for label in orig_label if label['category'] == 'service'][0]\n", 422 | " # aspect sentiment. 
    "    # returns (aspect-sentiment label, does the patch apply)\n",
    "    return label_name2label[dict_to_use['polarity']], int(label_name2label[dict_to_use['polarity']] == is_good)\n",
    "\n",
    "\n",
    "conds = [('food is good', [0,1]), ('service is good', [0,1]), ('food is bad', [1,0]), ('service is bad', [1,0])]\n",
    "label_sets = {cond: [cond2label_dict(cond, l) for l in d1[1]] for cond, _ in conds}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad65db33",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_steering_acc(data, labels, cond_labels, cond, cons, use_exps=True):\n",
    "    no_exps = [('', ex) for ex in data]\n",
    "    no_exp_probs = predict_stuff(no_exps, [0]*len(no_exps), model_obj, 'p1', verbose=False, mode='task_predictor')\n",
    "    no_exp_preds = no_exp_probs.argmax(axis=1)\n",
    "\n",
    "    if not use_exps:\n",
    "        # without patches the predictions trivially match the baseline\n",
    "        acc_1 = np.sum((no_exp_preds == labels) & cond_labels)\n",
    "        return acc_1, np.sum(1-cond_labels), np.sum(cond_labels), np.sum(1-cond_labels)\n",
    "    else:\n",
    "        contextualized = [(cond, ex) for ex in data]\n",
    "        output_probs = predict_stuff(contextualized, cond_labels, model_obj, 'p1', verbose=False)\n",
    "        patched_probs = apply_patch_soft(output_probs, no_exp_probs, cons)  # Pr(y | x, lp)\n",
    "        patched_preds = patched_probs.argmax(axis=1)\n",
    "\n",
    "        # if the patch applies, how often is the model correct?\n",
    "        acc_1 = np.sum((patched_preds == labels) & cond_labels)\n",
    "\n",
    "        # if the patch doesn't apply, how often does the prediction stay the same?\n",
    "        acc_2 = np.sum((patched_preds == no_exp_preds) & (1-cond_labels))\n",
    "        return acc_1, acc_2, np.sum(cond_labels), np.sum(1-cond_labels)\n",
    "\n",
    "\n",
    "def get_scores(conds, use_exps=True):\n",
    "    t1 = 0.0\n",
    "    t2 = 0.0\n",
    "\n",
    "    total1 = 0.0\n",
    "    total2 = 0.0\n",
    "\n",
    "    for cond, cons in conds:\n",
    "        curr = label_sets[cond]\n",
    "        aspect_labels = np.array([a for a, _ in curr])\n",
    "        cond_applies = np.array([ca for _, ca in curr])\n",
    "\n",
    "        print(cond)\n",
    "        t1_c, t2_c, total1_c, total2_c = get_steering_acc(d1[0], aspect_labels, cond_applies, cond, cons, use_exps=use_exps)\n",
    "        t1 += t1_c\n",
    "        t2 += t2_c\n",
    "        total1 += total1_c\n",
    "        total2 += total2_c\n",
    "    return t1 / total1, t2 / total2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa055d5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "s1, s2 = get_scores(conds)\n",
    "print(s1, s2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d47a0f6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "s1, s2 = get_scores(conds, use_exps=False)\n",
    "print(s1, s2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e8b9eb5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "shikhar-basic",
   "language": "python",
   "name": "shikhar-basic"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
527 | "version": "3.8.10" 528 | } 529 | }, 530 | "nbformat": 4, 531 | "nbformat_minor": 5 532 | } 533 | --------------------------------------------------------------------------------