├── README.md
├── download_models.sh
├── notebooks
│   └── evaluation.ipynb
├── requirements.txt
├── run_all.sh
├── run_stage_one.py
├── run_stage_two.py
└── src
    ├── __init__.py
    ├── dataset.py
    ├── edit_finder.py
    ├── editor.py
    ├── masker.py
    ├── predictors
    │   ├── imdb
    │   │   ├── imdb_dataset_reader.py
    │   │   └── imdb_roberta.json
    │   ├── newsgroups
    │   │   ├── newsgroups_dataset_reader.py
    │   │   └── newsgroups_roberta.json
    │   ├── predictor_utils.py
    │   └── race
    │       ├── race_dataset_reader.py
    │       └── race_roberta.json
    ├── stage_one.py
    ├── stage_two.py
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Minimal Contrastive Editing (MiCE) 🐭
2 |
3 | This repository contains code for our paper, [Explaining NLP Models via Minimal Contrastive Editing (MiCE)](https://arxiv.org/pdf/2012.13985.pdf).
4 |
5 | ## Citation
6 | ```bibtex
7 | @inproceedings{Ross2020ExplainingNM,
8 | title = "Explaining NLP Models via Minimal Contrastive Editing (MiCE)",
9 | author = "Ross, Alexis and Marasovi{\'c}, Ana and Peters, Matthew E.",
10 | booktitle = "Findings of the Association for Computational Linguistics: ACL 2021",
11 | publisher = "Association for Computational Linguistics",
12 |     url = "https://arxiv.org/abs/2012.13985",
13 | }
14 | ```
15 | ## Installation
16 |
17 | 1. Clone the repository.
18 | ```bash
19 | git clone https://github.com/allenai/mice.git
20 | cd mice
21 | ```
22 |
23 | 2. [Download and install Conda](https://conda.io/projects/conda/en/latest/user-guide/install/index.html).
24 |
25 | 3. Create a Conda environment.
26 |
27 | ```bash
28 | conda create -n mice python=3.7
29 | ```
30 |
31 | 4. Activate the environment.
32 |
33 | ```bash
34 | conda activate mice
35 | ```
36 |
37 | 5. Download the requirements.
38 |
39 | ```bash
40 | pip3 install -r requirements.txt
41 | ```
42 |
43 | ## Quick Start
44 |
45 | 1. **Download Task Data**: If you want to work with the RACE dataset, download it here: [Link](https://www.cs.cmu.edu/~glai1/data/race/).
46 | The commands below assume that this data, once downloaded, is stored in `data/RACE/`.
47 | All other task-specific datasets are automatically downloaded by the commands below.
48 | 2. **Download Pretrained Models**: You can download pretrained models by running:
49 |
50 | ```bash
51 | bash download_models.sh
52 | ```
53 |
54 | For each task (IMDB/Newsgroups/RACE), this script saves:
55 |
56 | - Predictor model to: `trained_predictors/{TASK}/model/model.tar.gz`.
57 | - Editor checkpoint to: `results/{TASK}/editors/mice/{TASK}_editor.pth`.
58 |
59 | 3. **Generate Edits**: Run the following command to generate edits for a particular task with our pretrained editor. It will write edits to `results/{TASK}/edits/{STAGE2EXP}/edits.csv`.
60 |
61 | ```bash
62 | python run_stage_two.py -task {TASK} -stage2_exp {STAGE2EXP} -editor_path results/{TASK}/editors/mice/{TASK}_editor.pth
62 | ```
63 | 
64 | For instance, to generate edits for the IMDB task, the following command will save edits to `results/imdb/edits/mice_binary/edits.csv`:
65 |
66 | ```bash
67 | python run_stage_two.py -task imdb -stage2_exp mice_binary -editor_path results/imdb/editors/mice/imdb_editor.pth
68 | ```
69 |
70 |
71 | 4. **Inspect Edits**: Inspect these edits with the demo notebook `notebooks/evaluation.ipynb`.
72 |
73 | ## More Information
74 |
75 | `run_all.sh` contains commands for recreating the main experiments in our paper.
76 |
77 | ### Training Predictors
78 |
79 | We use AllenNLP to train our Predictor models. Code for training Predictors can be found in `src/predictors/`.
80 | See `run_all.sh` for commands used to train Predictors, which will save models to subfolders in `trained_predictors`.
81 |
82 | Alternatively, you can work with our pretrained models, which you can download with `download_models.sh`.
83 |
84 |
85 | ### Training Editors
86 | The following command will train an editor (i.e. run Stage 1 of MiCE) for a particular task. It saves checkpoints to `results/{TASK}/editors/{STAGE1EXP}/checkpoints/`.
87 |
88 | ```bash
89 | python run_stage_one.py -task {TASK} -stage1_exp {STAGE1EXP}
89 | ```
90 | 
91 | ### Generating Edits
92 | The following command will find MiCE edits (i.e. run Stage 2 of MiCE) for a particular task. It saves edits to `results/{TASK}/edits/{STAGE2EXP}/edits.csv`. The `-editor_path` argument determines which Editor model to use; it defaults to our pretrained Editor.
93 |
94 | ```bash
95 | python run_stage_two.py -task {TASK} -stage2_exp {STAGE2EXP} -editor_path results/{TASK}/editors/mice/{TASK}_editor.pth
95 | ```
96 | 
97 | ### Inspecting Edits
98 | The notebook `notebooks/evaluation.ipynb` contains some code to inspect edits.
99 | To compute fluency of edits, see the `EditEvaluator` class in `src/edit_finder.py`.
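
If you prefer to compute the summary metrics outside the notebook, the sketch below mirrors the flip-rate and minimality computation in `notebooks/evaluation.ipynb`. It assumes the tab-separated `edits.csv` written by Stage 2, with the columns used in that notebook (`sorted_idx`, `data_idx`, `new_pred`, `contrast_pred`, `minimality`).

```python
import pandas as pd

def evaluate_edits(path):
    """Summarize MiCE edits: flip rate and mean minimality of the best edits."""
    edits = pd.read_csv(path, sep="\t", lineterminator="\n")
    best = edits[edits["sorted_idx"] == 0]  # smallest edit found per input
    flipped = best[best["new_pred"].astype(str) == best["contrast_pred"].astype(str)]
    num_inputs = best["data_idx"].nunique()
    return {
        "num_total": num_inputs,
        "num_flipped": len(flipped),
        "flip_rate": len(flipped) / num_inputs,
        "minimality": best["minimality"].mean(),
    }
```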
100 |
101 | ## Adding a Task
102 | Follow the steps below to extend this repo for your own task.
103 |
104 | 1. Create a task subfolder: `src/predictors/{TASK}`.
105 |
106 | 2. **Dataset reader**: Create a task-specific dataset reader in a file `{TASK}_dataset_reader.py` within that subfolder. It should have the methods `text_to_instance()`, `_read()`, and `get_inputs()`.
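
   As a rough illustration of that three-method interface, here is a schematic reader using plain Python stand-ins rather than real AllenNLP classes. The file format and the `get_inputs()` signature here are invented for illustration; a real reader subclasses `allennlp.data.DatasetReader` and returns `Instance` objects, so see the existing readers in `src/predictors/` for the actual contract.

```python
class MyTaskDatasetReader:
    """Schematic stand-in for an AllenNLP DatasetReader (illustration only)."""

    def text_to_instance(self, text, label=None):
        # Real version: build an allennlp Instance with a TextField (+ LabelField).
        instance = {"tokens": text.split()}
        if label is not None:
            instance["label"] = label
        return instance

    def _read(self, file_path):
        # Real version: yield one Instance per example in the data file.
        # Here we assume a made-up label<TAB>text format.
        with open(file_path) as f:
            for line in f:
                label, _, text = line.rstrip("\n").partition("\t")
                yield self.text_to_instance(text, label)

    def get_inputs(self, file_path):
        # MiCE uses this to recover raw input strings for editing; check the
        # readers in src/predictors/ for the exact signature it expects.
        with open(file_path) as f:
            return [line.rstrip("\n").partition("\t")[2] for line in f]
```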
107 |
108 | 3. **Train Predictor**: Create a training config (see `src/predictors/imdb/imdb_roberta.json` for an example). Then train the Predictor using AllenNLP (see above commands or commands in `run_all.sh` for examples).
109 |
110 | 4. **Train Editor Model**: Depending on the task, you may have to create a new `StageOneDataset` subclass (see `RaceStageOneDataset` in `src/dataset.py` for an example of how to inherit from `StageOneDataset`).
111 | - For classification tasks, the existing base `StageOneDataset` class should work.
112 | - For new multiple-choice QA tasks with dataset readers patterned after the `RaceDatasetReader` (`src/predictors/race/race_dataset_reader.py`), the existing `RaceStageOneDataset` class should work.
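
   For intuition, Stage One frames editor training as T5-style span infilling: masked spans in the input are replaced by sentinel tokens, and the target spells out the hidden spans. Below is a minimal sketch of that input/target format; the real masking logic, including how spans are chosen, lives in `src/masker.py` and `src/dataset.py`.

```python
def mask_spans(tokens, spans):
    """Replace each (start, end) token span with a T5 sentinel token.

    Returns (masked_input, target): the editor is trained to map the masked
    input to the target, which lists the hidden spans after each sentinel.
    `spans` must be sorted and non-overlapping.
    """
    masked, target, prev_end = [], [], 0
    for i, (start, end) in enumerate(spans):
        sentinel = f"<extra_id_{i}>"
        masked.extend(tokens[prev_end:start])  # keep the unmasked prefix
        masked.append(sentinel)                # hide the span behind a sentinel
        target.append(sentinel)
        target.extend(tokens[start:end])       # target reveals the hidden span
        prev_end = end
    masked.extend(tokens[prev_end:])
    return " ".join(masked), " ".join(target)
```

   For example, `mask_spans("the movie was great fun".split(), [(3, 5)])` yields the masked input `"the movie was <extra_id_0>"` and the target `"<extra_id_0> great fun"`.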
113 |
114 | 5. **Generate Edits**: Depending on the task, you may have to create a new `Editor` subclass (see `RaceEditor` in `src/editor.py` for an example of how to inherit from `Editor`).
115 | - For classification tasks, the existing base `Editor` class should work.
116 | - For multiple-choice QA with dataset readers patterned after `RaceDatasetReader`, the existing `RaceEditor` class should work.
117 |
118 |
--------------------------------------------------------------------------------
/download_models.sh:
--------------------------------------------------------------------------------
1 | for TASK in imdb newsgroups race
2 | do
3 | mkdir -p trained_predictors/${TASK}/model
4 | mkdir -p results/${TASK}/editors/mice/
5 | wget https://storage.googleapis.com/allennlp-public-models/mice-${TASK}-predictor.tar.gz -O trained_predictors/${TASK}/model/model.tar.gz
6 | wget https://storage.googleapis.com/allennlp-public-models/mice-${TASK}-editor.pth -O results/${TASK}/editors/mice/${TASK}_editor.pth
7 | done
8 |
9 |
--------------------------------------------------------------------------------
/notebooks/evaluation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 3,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "import sys\n",
11 | "sys.path.append(\"..\")\n",
12 | "from src.utils import html_highlight_diffs\n",
13 | "from IPython.core.display import display, HTML\n",
14 | "import numpy as np"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 4,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "from src.utils import load_predictor, get_ints_to_labels"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 5,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "TASK = \"imdb\"\n",
33 | "STAGE2EXP = \"mice_binary\"\n",
34 | "EDIT_PATH = f\"../results/{TASK}/edits/{STAGE2EXP}/edits.csv\""
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 6,
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "def read_edits(path):\n",
44 |     "    edits = pd.read_csv(path, sep=\"\\t\", lineterminator=\"\\n\", error_bad_lines=False, warn_bad_lines=True)\n",
45 | "\n",
46 |     "    if edits['new_pred'].dtype == np.dtype('float64'):\n",
47 | " edits['new_pred'] = edits.apply(lambda row: str(int(row['new_pred']) if not np.isnan(row['new_pred']) else \"\"), axis=1)\n",
48 | " edits['orig_pred'] = edits.apply(lambda row: str(int(row['orig_pred']) if not np.isnan(row['orig_pred']) else \"\"), axis=1)\n",
49 | " edits['contrast_pred'] = edits.apply(lambda row: str(int(row['contrast_pred']) if not np.isnan(row['contrast_pred']) else \"\"), axis=1)\n",
50 | " else:\n",
51 | " edits['new_pred'].fillna(value=\"\", inplace=True)\n",
52 | " edits['orig_pred'].fillna(value=\"\", inplace=True)\n",
53 | " edits['contrast_pred'].fillna(value=\"\", inplace=True)\n",
54 | " return edits"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 7,
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "def get_best_edits(edits):\n",
64 | " \"\"\" MiCE writes all edits that are found in Stage 2, \n",
65 | " but we only want to evaluate the smallest per input. \n",
66 |     "        These correspond to rows with sorted_idx == 0. \"\"\"\n",
67 | " return edits[edits['sorted_idx'] == 0]\n",
68 | " \n",
69 | "def evaluate_edits(edits):\n",
70 | " temp = edits[edits['sorted_idx'] == 0]\n",
71 | " minim = temp['minimality'].mean()\n",
72 | " flipped = temp[temp['new_pred'].astype(str)==temp['contrast_pred'].astype(str)]\n",
73 | " nunique = temp['data_idx'].nunique()\n",
74 | " flip_rate = len(flipped)/nunique\n",
75 | " duration=temp['duration'].mean()\n",
76 | " metrics = {\n",
77 | " \"num_total\": nunique,\n",
78 | " \"num_flipped\": len(flipped),\n",
79 | " \"flip_rate\": flip_rate,\n",
80 | " \"minimality\": minim,\n",
81 | " \"duration\": duration,\n",
82 | " }\n",
83 | " for k, v in metrics.items():\n",
84 | " print(f\"{k}: \\t{round(v, 3)}\")\n",
85 | " return metrics"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": 8,
91 | "metadata": {},
92 | "outputs": [],
93 | "source": [
94 | "def display_edits(row):\n",
95 | " html_original, html_edited = html_highlight_diffs(row['orig_editable_seg'], row['edited_editable_seg'])\n",
96 | " minim = round(row['minimality'], 3)\n",
97 | " print(f\"MINIMALITY: \\t{minim}\")\n",
98 | " print(\"\")\n",
99 | " display(HTML(html_original))\n",
100 | " display(HTML(html_edited))\n",
101 | "\n",
102 | "def display_classif_results(rows):\n",
103 | " for _, row in rows.iterrows():\n",
104 | " orig_contrast_prob_pred = round(row['orig_contrast_prob_pred'], 3)\n",
105 | " new_contrast_prob_pred = round(row['new_contrast_prob_pred'], 3)\n",
106 | " print(\"-----------------------\")\n",
107 | " print(f\"ORIG LABEL: \\t{row['orig_pred']}\")\n",
108 | " print(f\"CONTR LABEL: \\t{row['contrast_pred']} (Orig Pred Prob: {orig_contrast_prob_pred})\")\n",
109 | " print(f\"NEW LABEL: \\t{row['new_pred']} (New Pred Prob: {new_contrast_prob_pred})\")\n",
110 | " print(\"\")\n",
111 | " display_edits(row)\n",
112 | "\n",
113 | "def display_race_results(rows):\n",
114 | " for _, row in rows.iterrows():\n",
115 | " orig_contrast_prob_pred = round(row['orig_contrast_prob_pred'], 3)\n",
116 | " new_contrast_prob_pred = round(row['new_contrast_prob_pred'], 3)\n",
117 | " orig_input = eval(row['orig_input'])\n",
118 | " options = orig_input['options']\n",
119 | " print(\"-----------------------\")\n",
120 | " print(f\"QUESTION: {orig_input['question']}\")\n",
121 | " print(\"\\nOPTIONS:\")\n",
122 | " for opt_idx, opt in enumerate(options):\n",
123 | " print(f\" ({opt_idx}) {opt}\")\n",
124 | " print(f\"\\nORIG LABEL: \\t{row['orig_pred']}\")\n",
125 | " print(f\"CONTR LABEL: \\t{row['contrast_pred']} (Orig Pred Prob: {orig_contrast_prob_pred})\")\n",
126 | " print(f\"NEW LABEL: \\t{row['new_pred']} (New Pred Prob: {new_contrast_prob_pred})\")\n",
127 | " print(\"\")\n",
128 | " display_edits(row)"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": 9,
134 | "metadata": {},
135 | "outputs": [
136 | {
137 | "name": "stdout",
138 | "output_type": "stream",
139 | "text": [
140 | "num_total: \t47\n",
141 | "num_flipped: \t47\n",
142 | "flip_rate: \t1.0\n",
143 | "minimality: \t0.183\n",
144 | "duration: \t8.814\n"
145 | ]
146 |     }
155 | ],
156 | "source": [
157 | "edits = read_edits(EDIT_PATH)\n",
158 | "edits = get_best_edits(edits)\n",
159 | "metrics = evaluate_edits(edits)"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": 15,
165 | "metadata": {},
166 | "outputs": [
167 | {
168 | "name": "stdout",
169 | "output_type": "stream",
170 | "text": [
171 | "-----------------------\n",
172 | "ORIG LABEL: \t0\n",
173 | "CONTR LABEL: \t1 (Orig Pred Prob: 0.002)\n",
174 | "NEW LABEL: \t1 (New Pred Prob: 0.986)\n",
175 | "\n",
176 | "MINIMALITY: \t0.071\n",
177 | "\n"
178 | ]
179 | },
180 | {
181 | "data": {
182 | "text/html": [
183 | "Eddie Murphy put a lot into this movie by directed wrote starred and produced this story about two nighclub owners in the 30s who try to fight mobsters and corrupt cops from taking over their club..a great cast in Murphy Redd Foxx, Richard Pryor, Danny Aiello, Della Reese, and a gorgeous Jasmine Guy that would make it worth seeing on its own..but the story just doesnt hold up interest or give the great cast enough to work with.. on a scale of one to ten..a 4 "
184 | ],
185 | "text/plain": [
186 | ""
187 | ]
188 | },
189 | "metadata": {},
190 | "output_type": "display_data"
191 | },
192 | {
193 | "data": {
194 | "text/html": [
195 | "Eddie Murphy put a lot into this movie by directed wrote starred and produced this story about two nighclub owners in the 30s who try to fight mobsters and corrupt cops from taking over their club..a great cast in Murphy Redd Foxx, Richard Pryor, Danny Aiello, Della Reese, and a gorgeous Jasmine Guy that would make it worth seeing on its own..but the story just might not hold up interest or give the great cast enough to work with.. on a budget of one to spare..a must see.."
196 | ],
197 | "text/plain": [
198 | ""
199 | ]
200 | },
201 | "metadata": {},
202 | "output_type": "display_data"
203 | }
204 | ],
205 | "source": [
206 | "random_rows = edits.sample(1)\n",
207 | "display_classif_results(random_rows)\n",
208 | "# display_race_results(random_rows)"
209 | ]
210 | }
211 | ],
212 | "metadata": {
213 | "kernelspec": {
214 | "display_name": "Python 3",
215 | "language": "python",
216 | "name": "python3"
217 | },
218 | "language_info": {
219 | "codemirror_mode": {
220 | "name": "ipython",
221 | "version": 3
222 | },
223 | "file_extension": ".py",
224 | "mimetype": "text/x-python",
225 | "name": "python",
226 | "nbconvert_exporter": "python",
227 | "pygments_lexer": "ipython3",
228 | "version": "3.7.9"
229 | }
230 | },
231 | "nbformat": 4,
232 | "nbformat_minor": 2
233 | }
234 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pandas==1.1.3
2 | numpy==1.19.0
3 | nltk==3.5
4 | allennlp_models==1.2.2
5 | torch==1.7.1
6 | allennlp==1.2.2
7 | munch==2.5.0
8 | tqdm==4.47.0
9 | overrides==3.1.0
10 | transformers==3.5.0
11 | more_itertools==8.4.0
12 | scikit_learn==0.24.2
13 |
--------------------------------------------------------------------------------
/run_all.sh:
--------------------------------------------------------------------------------
1 | ################################################################
2 | ####################### TRAIN PREDICTORS #######################
3 | ################################################################
4 |
5 | # Train RACE Predictor
6 | allennlp train src/predictors/race/race_roberta.json \
7 | --include-package src.predictors.race.race_dataset_reader \
8 | -s trained_predictors/models/race/
9 |
10 | # Train IMDB Predictor
11 | allennlp train src/predictors/imdb/imdb_roberta.json \
12 | --include-package src.predictors.imdb.imdb_dataset_reader \
13 | -s trained_predictors/models/imdb/
14 |
15 | # Train Newsgroups Predictor
16 | allennlp train src/predictors/newsgroups/newsgroups_roberta.json \
17 | --include-package src.predictors.newsgroups.newsgroups_dataset_reader \
18 | -s trained_predictors/models/newsgroups/
19 |
20 | ################################################################
21 | ########################## STAGE ONE ###########################
22 | ################################################################
23 |
24 | STAGE1EXP=mice_gold
25 |
26 | python run_stage_one.py -task imdb -stage1_exp ${STAGE1EXP}
27 | python run_stage_one.py -task newsgroups -stage1_exp ${STAGE1EXP}
28 | python run_stage_one.py -task race -stage1_exp ${STAGE1EXP}
29 |
30 | ################################################################
31 | ########################## STAGE TWO ###########################
32 | ################################################################
33 |
34 | STAGE2EXP=mice_binary
35 |
36 | python run_stage_two.py -task imdb \
37 | -editor_path results/imdb/editors/${STAGE1EXP}/checkpoints/ \
38 | -stage2_exp ${STAGE2EXP}
39 |
40 | python run_stage_two.py -task newsgroups \
41 | -editor_path results/newsgroups/editors/${STAGE1EXP}/checkpoints/ \
42 | -stage2_exp ${STAGE2EXP}
43 |
44 | python run_stage_two.py -task race \
45 | -editor_path results/race/editors/${STAGE1EXP}/checkpoints/ \
46 | -stage2_exp ${STAGE2EXP}
47 |
--------------------------------------------------------------------------------
/run_stage_one.py:
--------------------------------------------------------------------------------
1 | # Local imports
2 | from src.stage_one import run_train_editor
3 | from src.utils import get_args, load_predictor, get_dataset_reader
4 |
5 | if __name__ == '__main__':
6 |
7 | args = get_args("stage1")
8 | predictor = load_predictor(args.meta.task)
9 | dr = get_dataset_reader(args.meta.task, predictor)
10 | run_train_editor(predictor, dr, args)
11 |
--------------------------------------------------------------------------------
/run_stage_two.py:
--------------------------------------------------------------------------------
1 | # Local imports
2 | from src.stage_two import run_edit_test
3 | from src.utils import get_args
4 |
5 | if __name__ == '__main__':
6 |
7 | args = get_args("stage2")
8 | run_edit_test(args)
9 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from tqdm import tqdm
4 | import numpy as np
5 | import random
6 | import logging
7 |
8 | # Local imports
9 | from src.masker import MaskError
10 | from src.utils import *
11 |
12 | logger = logging.getLogger(__name__)
13 | logger.setLevel(logging.INFO)
14 |
15 | class StageOneDataset(Dataset):
16 | """ Dataset for training Editor models in Stage One. Creates masked inputs
17 | from task training inputs. Inherits from torch.utils.data.Dataset. """
18 |
19 | def __init__(
20 | self,
21 | tokenizer,
22 | max_length=700,
23 | masked_strings=None,
24 | targets=None
25 | ):
26 | self.tokenizer = tokenizer
27 | self.masked_strings = masked_strings
28 | self.targets = targets
29 | self.max_length = max_length
30 |
31 | def __len__(self):
32 | return len(self.masked_strings)
33 |
34 | def __getitem__(self, index):
35 | input_text = self.masked_strings[index]
36 | label_text = self.targets[index]
37 |
38 | source = self.tokenizer.batch_encode_plus([input_text],
39 | truncation=True, max_length=self.max_length,
40 | pad_to_max_length=True, return_tensors='pt')
41 | target = self.tokenizer.batch_encode_plus([label_text],
42 | truncation=True, max_length=self.max_length,
43 | pad_to_max_length=True, return_tensors='pt')
44 |
45 | source_ids = source['input_ids'].squeeze()
46 | source_mask = source['attention_mask'].squeeze()
47 | target_ids = target['input_ids'].squeeze()
48 | target_mask = target['attention_mask'].squeeze()
49 |
50 | eos_id = torch.LongTensor([self.tokenizer.encode(label_text)[-1]])
51 |
52 | return {
53 | 'eos_id': eos_id,
54 | 'source_ids': source_ids.to(dtype=torch.long),
55 | 'source_mask': source_mask.to(dtype=torch.long),
56 | 'target_ids': target_ids.to(dtype=torch.long),
57 | 'target_ids_y': target_ids.to(dtype=torch.long)
58 | }
59 |
60 | def create_inputs(
61 | self, orig_inputs,
62 | orig_labels, predictor,
63 | masker, target_label = "pred",
64 | mask_fracs=np.arange(0.2, 0.6, 0.05),
65 | mask_frac_probs=[0.125] * 8
66 | ):
67 | target_label_options = ["pred", "gold"]
68 | if target_label not in target_label_options:
69 | error_msg = f"target_label must be in {target_label_options} "
70 | error_msg += f"but got '{target_label}'"
71 | raise ValueError(error_msg)
72 |
73 | masked_strings, targets = [], []
74 | labels_to_ints = get_labels_to_ints(predictor)
75 |
76 | num_errors = 0
77 | iterator = enumerate(zip(orig_inputs, orig_labels))
78 | for i, (orig_inp, orig_label) in tqdm(iterator, total=len(orig_inputs)):
79 | masker.mask_frac = np.random.choice(mask_fracs, 1,
80 | p=mask_frac_probs)[0]
81 |
82 | pred = predictor.predict(orig_inp)
83 | pred_label = pred['label']
84 |
85 | label_to_use = pred_label if target_label == "pred" else orig_label
86 | label_idx = labels_to_ints[label_to_use]
87 |
88 | predictor_tokenized = get_predictor_tokenized(predictor, orig_inp)
89 |
90 | try:
91 | _, word_indices_to_mask, masked_input, target = \
92 | masker.get_masked_string(orig_inp, label_idx,
93 | predictor_tok_end_idx=len(predictor_tokenized))
94 | masked_string = format_classif_input(masked_input, label_to_use)
95 | masked_strings.append(masked_string)
96 | targets.append(target)
97 | 
98 |                 # Log progress inside the try block so masked_string and
99 |                 # target are guaranteed to be defined here
100 |                 if i % 500 == 0:
101 |                     rounded_mask_frac = round(masker.mask_frac, 3)
102 |                     logger.info(wrap_text(f"Original input ({i}): " + orig_inp))
103 |                     logger.info(wrap_text(f"Mask frac: {rounded_mask_frac}"))
104 |                     logger.info(wrap_text(f"Editor input: {masked_string}"))
105 |                     logger.info(wrap_text("Editor target: " + target))
106 | 
107 |             except MaskError:
108 |                 num_errors += 1
109 |
110 | self.masked_strings = masked_strings
111 | self.targets = targets
112 |
113 | class RaceStageOneDataset(StageOneDataset):
114 | def __init__(self, *args, **kwargs):
115 | super().__init__(*args, **kwargs)
116 |
117 | def create_inputs(
118 | self, dr, orig_inputs,
119 | orig_labels, predictor, masker,
120 | mask_fracs=np.arange(0.2, 0.6, 0.05),
121 | mask_frac_probs=[0.125] * 8,
122 | editable_key = "article",
123 | target_label = "pred"
124 | ):
125 |
126 | editable_keys = ["article", "question"]
127 | if editable_key not in editable_keys:
128 | raise ValueError(f"Editable key must be in {editable_keys} \
129 | but got value {editable_key}")
130 |
131 | labels_to_ints = get_labels_to_ints(predictor)
132 |
133 | num_errors = 0
134 | masked_strings, targets = [], []
135 |
136 | iterator = enumerate(zip(orig_inputs, orig_labels))
137 | for i, (orig_inp, gold_label) in tqdm(iterator, total=len(orig_inputs)):
138 | masker.mask_frac = np.random.choice(mask_fracs, 1,
139 | p=mask_frac_probs)[0]
140 |
141 | instance, length_lst, _ = dr.text_to_instance(
142 | orig_inp["id"], orig_inp["article"],
143 | orig_inp["question"], orig_inp["options"]
144 | )
145 | options = orig_inp["options"]
146 | pred = predictor.predict_instance(instance)
147 | pred_label = int(pred['best_alternative'])
148 |
149 | # For RACE, label is already int, not string
150 | label_idx = pred_label if target_label == "pred" else gold_label
151 |
152 | try:
153 | # Mask the article
154 | if editable_key == "article":
155 | article_tok = get_predictor_tokenized(predictor,
156 | orig_inp["article"])
157 | predictor_tok_end_idx = min(len(article_tok),
158 | length_lst[label_idx])
159 | _, word_indices_to_mask, masked_article, target = \
160 | masker.get_masked_string(
161 | orig_inp["article"], label_idx,
162 | labeled_instance=instance,
163 | predictor_tok_end_idx=predictor_tok_end_idx
164 | )
165 | question = orig_inp["question"]
166 | article = masked_article
167 |
168 | # Mask the question
169 | # TODO: Does this work? Have only tested article
170 | elif editable_key == "question":
171 | question_tok = get_predictor_tokenized(predictor,
172 | orig_inp["question"])
173 | predictor_tok_end_idx = length_lst[label_idx] + \
174 | len(question_tok)
175 | _, word_indices_to_mask, masked_question, target = \
176 | masker.get_masked_string(
177 | orig_inp["question"], label_idx,
178 | labeled_instance=instance,
179 | predictor_tok_start_idx=length_lst[label_idx],
180 | predictor_tok_end_idx=predictor_tok_end_idx
181 | )
182 | question = masked_question
183 | article = orig_inp["article"]
184 |
185 | masked_string = format_multiple_choice_input(
186 | article, question, options, label_idx)
187 | masked_strings.append(masked_string)
188 | targets.append(target)
189 |
190 | except MaskError:
191 | num_errors += 1
192 |
193 | self.masked_strings = masked_strings
194 | self.targets = targets
195 |
--------------------------------------------------------------------------------
/src/edit_finder.py:
--------------------------------------------------------------------------------
1 | from allennlp.predictors import Predictor, TextClassifierPredictor
2 |
3 | import sys
4 | import allennlp
5 | from allennlp.data.tokenizers.spacy_tokenizer import SpacyTokenizer
6 | from allennlp.data.batch import Batch
7 |
8 | import torch
9 | import nltk
10 | import numpy as np
11 |
12 | import re
13 | import more_itertools as mit
14 | import math
15 | import textwrap
16 | import time
17 | import logging
18 | import os
19 | import heapq
20 | import difflib
21 |
22 | from transformers import T5Tokenizer, T5Model, T5Config
23 | from transformers import T5ForConditionalGeneration
24 |
25 | from src.masker import Masker, RandomMasker, GradientMasker
26 | from src.utils import *
27 |
28 | logger = logging.getLogger("my-logger")
29 | FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
30 | logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"), format=FORMAT)
31 | logger.setLevel(logging.INFO)
32 |
33 | ####################################################################
34 | ############################## Utils ###############################
35 | ####################################################################
36 |
37 | def get_max_instance(predictor, instance_candidates, contrast_pred_idx):
38 | """ Returns candidate with highest predicted prob of contrast_pred_idx. """
39 |
40 | batched_preds = predictor.predict_batch_instance(instance_candidates)
41 | batched_preds = [add_prob(pred) for pred in batched_preds]
42 | max_idx = max(range(len(batched_preds)), key=lambda index: \
43 | batched_preds[index]['probs'][contrast_pred_idx])
44 | max_candidate = instance_candidates[max_idx]
45 | max_prob = batched_preds[max_idx]['probs'][contrast_pred_idx]
46 |
47 | return max_candidate, max_prob
48 |
49 | class EditEvaluator():
50 | def __init__(
51 | self,
52 | fluency_model_name = "t5-base",
53 | fluency_masker = RandomMasker(None, SpacyTokenizer(), 512)
54 | ):
55 | self.device = get_device()
56 | self.fluency_model = T5ForConditionalGeneration.from_pretrained(
57 | fluency_model_name).to(self.device)
58 | self.fluency_tokenizer = T5Tokenizer.from_pretrained(
59 | fluency_model_name)
60 | self.fluency_masker = fluency_masker
61 |
62 | def score_fluency(self, sent):
63 | temp_losses = []
64 | masked_strings, span_labels = \
65 | self.fluency_masker.get_all_masked_strings(sent)
66 | for masked, label in zip(masked_strings, span_labels):
67 | input_ids = self.fluency_tokenizer.encode(masked,
68 | truncation="longest_first", max_length=600,
69 | return_tensors="pt")
70 | input_ids = input_ids.to(self.device)
71 | labels = self.fluency_tokenizer.encode(label,
72 | truncation="longest_first", max_length=600,
73 | return_tensors="pt")
74 | labels = labels.to(self.device)
75 | outputs = self.fluency_model(input_ids=input_ids, labels=labels)
76 | loss = outputs[0]
77 | temp_losses.append(loss.item())
78 | del input_ids
79 | del labels
80 | del loss
81 | torch.cuda.empty_cache()
82 | avg_loss = sum(temp_losses)/len(temp_losses)
83 | return avg_loss
84 |
85 | def score_minimality(self, orig_sent, edited_sent, normalized=True):
86 | spacy = SpacyTokenizer()
87 | tokenized_original = [t.text for t in spacy.tokenize(orig_sent)]
88 | tokenized_edited = [t.text for t in spacy.tokenize(edited_sent)]
89 | lev = nltk.edit_distance(tokenized_original, tokenized_edited)
90 | if normalized:
91 | return lev/len(tokenized_original)
92 | else:
93 | return lev
94 |
95 | def sort_instances_by_score(scores, *args):
96 | """ Sorts *args in order of decreasing scores """
97 |
98 | zipped = list(zip(scores, *args))
99 |     zipped.sort(key=lambda t: t[0], reverse=True)  # sort by score only; tied Instances aren't comparable
100 | return list(zipped)
101 |
102 | def get_scores(predictor, instance_candidates, contrast_pred_idx, k = None):
103 | """ Gets (top k) predicted probs of contrast_pred_idx on candidates. """
104 |
105 | # Get predictions
106 | with torch.no_grad():
107 | cuda_device = predictor._model._get_prediction_device()
108 | dataset = Batch(instance_candidates)
109 | dataset.index_instances(predictor._model.vocab)
110 | model_input = allennlp.nn.util.move_to_device(
111 | dataset.as_tensor_dict(), cuda_device)
112 | outputs = predictor._model.make_output_human_readable(
113 | predictor._model(**model_input))
114 | outputs = add_probs(outputs)
115 | probs = outputs['probs']
116 |
117 |     if k is not None:
118 | pred_indices = torch.argmax(probs, dim=1)
119 |
120 | # Compute this only for remaining
121 |         contrast_pred_tensor = torch.tensor([contrast_pred_idx], device=probs.device)
122 | bool_equal = (pred_indices == contrast_pred_tensor)
123 | pred_is_contrast_indices = bool_equal.reshape(-1).nonzero().reshape(-1)
124 |
125 | num_to_return = max(k, len(pred_is_contrast_indices))
126 |
127 | contrast_probs = probs[:, contrast_pred_idx]
128 | sorted_contrast_probs = torch.argsort(contrast_probs, descending=True)
129 | highest_indices = sorted_contrast_probs[:num_to_return]
130 | cpu_contrast_probs = torch.index_select(
131 | contrast_probs, 0, highest_indices).cpu().numpy()
132 | selected_pred_indices = torch.index_select(
133 | pred_indices, 0, highest_indices)
134 | cpu_pred_indices = selected_pred_indices.cpu().numpy()
135 | highest_indices = highest_indices.cpu().numpy()
136 | else:
137 | cpu_pred_indices = torch.argmax(probs, dim=1).cpu().numpy()
138 | cpu_contrast_probs = probs[:, contrast_pred_idx].cpu().numpy()
139 | highest_indices = range(len(cpu_contrast_probs))
140 |
141 |     assert cpu_pred_indices.shape == cpu_contrast_probs.shape
142 |     del outputs
143 |     del probs
144 |     if k is not None:
145 |         # these intermediates only exist on the top-k branch above
146 |         del contrast_probs
147 |         del selected_pred_indices
148 |         del contrast_pred_tensor
149 |         del pred_is_contrast_indices
150 |     return cpu_contrast_probs, cpu_pred_indices, highest_indices
149 |
150 | ####################################################################
151 | ########################## Main Classes ############################
152 | ####################################################################
153 |
154 | class EditFinder():
155 | """ Runs search algorithms to find edits. """
156 |
157 | def __init__(
158 | self,
159 | predictor,
160 | editor,
161 | beam_width = 3,
162 | search_method = "binary",
163 | max_mask_frac = 0.5,
164 | max_search_levels = 10,
165 | verbose = True
166 | ):
167 | self.predictor = predictor
168 | self.editor = editor
169 | self.beam_width = beam_width
170 | self.ints_to_labels = self.editor.ints_to_labels
171 | self.search_method = search_method
172 | self.max_search_levels = max_search_levels
173 | self.device = get_device()
174 | self.verbose = verbose
175 | self.max_mask_frac = max_mask_frac
176 |
177 | def run_edit_round(
178 | self,
179 | edit_list,
180 | input_cand,
181 | contrast_pred_idx,
182 | num_rounds,
183 | mask_frac,
184 | edit_evaluator=None,
185 | sorted_token_indices=None
186 | ):
187 | logger.info(wrap_text(f"Running candidate generation for mask frac: \
188 | {mask_frac}; max mask frac: {self.max_mask_frac}"))
189 | self.editor.masker.mask_frac = mask_frac
190 |
191 | candidates, masked_sentence = self.editor.get_candidates(
192 | edit_list.contrast_label, input_cand, contrast_pred_idx,
193 | edit_list.orig_pred_idx,
194 | sorted_token_indices=sorted_token_indices)
195 |
196 | input_cands = [e['edited_input'] for e in candidates]
197 | editable_seg_cands = [e['edited_editable_seg'] for e in candidates]
198 |
199 | instance_cands = [self.editor.input_to_instance(inp, editable_seg=es) \
200 | for inp, es in zip(input_cands, editable_seg_cands)]
201 |
202 | # TODO: can this happen? We might get [] if all generations are bad,
203 | # but it should not (with no good generations, original infillings are used)
204 | if len(input_cands) == 0:
205 | logger.info("no candidates returned...")
206 | return False
207 |
208 | probs, pred_indices, highest_indices = get_scores(self.predictor,
209 | instance_cands, contrast_pred_idx, k = self.beam_width)
210 |
211 | input_cands = [input_cands[idx] for idx in highest_indices]
212 | editable_seg_cands = [editable_seg_cands[idx] for idx in highest_indices]
213 |
214 | # Sort these by highest score for iteration
215 | sorted_cands = sort_instances_by_score(
216 | probs, pred_indices, input_cands, editable_seg_cands)
217 | found_cand = False
218 | beam_inputs = [s for _, _, s in edit_list.beam]
219 | iterator = enumerate(sorted_cands)
220 | for sort_idx, (prob, pred_idx, input_cand, editable_seg_cand) in iterator:
221 | if self.verbose and sort_idx == 0:
222 | logger.info(wrap_text(f"Post edit round, top contr prob: {prob}"))
223 | logger.info(wrap_text(f"Post edit round, top cand: {input_cand}"))
224 | pred_idx = int(pred_idx)
225 | edit_list.counter += 1
226 | if input_cand not in beam_inputs:
227 | heapq.heappush(edit_list.beam,
228 | (prob, edit_list.counter, input_cand))
229 |
230 | label = self.ints_to_labels[pred_idx]
231 |
232 | if pred_idx == contrast_pred_idx:
233 | found_cand = True
234 |
235 | # Score minimality because we order edits by minimality scores
236 | if edit_evaluator is not None:
237 | minimality = edit_evaluator.score_minimality(
238 | edit_list.orig_editable_seg,
239 | editable_seg_cand, normalized=True)
240 | edit = {"edited_editable_seg": editable_seg_cand,
241 | "edited_input": input_cand,
242 | "minimality": minimality,
243 | "masked_sentence": masked_sentence,
244 | "edited_contrast_prob": prob,
245 | "edited_label": label,
246 | "mask_frac": mask_frac,
247 | "num_edit_rounds": num_rounds}
248 | edit_list.add_edit(edit)
249 |
250 | if len(edit_list.beam) > self.beam_width:
251 | _ = heapq.heappop(edit_list.beam)
252 |
253 | del probs
254 | del pred_indices
255 | return found_cand
256 |
257 | def binary_search_edit(
258 | self, edit_list, input_cand, contrast_pred_idx, num_rounds,
259 | min_mask_frac=0.0, max_mask_frac=0.5, num_levels=1,
260 | max_levels=None, edit_evaluator=None, sorted_token_indices=None):
261 |
262 | """ Runs binary search over masking percentages, starting at
263 | midpoint between min_mask_frac and max_mask_frac.
264 | Calls run_edit_round at each mask percentage. """
265 |
266 | if max_levels is None:
267 | max_levels = self.max_search_levels
268 |
269 | mid_mask_frac = (max_mask_frac + min_mask_frac) / 2
270 |
271 | if self.verbose:
272 | logger.info(wrap_text("binary search mid: " + str(mid_mask_frac)))
273 | found_cand = self.run_edit_round(
274 | edit_list, input_cand, contrast_pred_idx, num_rounds,
275 | mid_mask_frac, edit_evaluator=edit_evaluator,
276 | sorted_token_indices=sorted_token_indices)
277 | if self.verbose:
278 | logger.info(wrap_text("Binary search # levels: " + str(num_levels)))
279 | logger.info(wrap_text("Found cand: " + str(found_cand)))
280 |
281 |
282 | if num_levels == max_levels:
283 | return found_cand
284 |
285 | elif num_levels < max_levels:
286 | if found_cand:
287 | return self.binary_search_edit(
288 | edit_list, input_cand, contrast_pred_idx, num_rounds,
289 | min_mask_frac=min_mask_frac,
290 | max_mask_frac=mid_mask_frac,
291 | num_levels=num_levels+1,
292 | sorted_token_indices=sorted_token_indices,
293 | edit_evaluator=edit_evaluator)
294 | else:
295 | return self.binary_search_edit(edit_list, input_cand,
296 | contrast_pred_idx, num_rounds,
297 | min_mask_frac=mid_mask_frac,
298 | max_mask_frac=max_mask_frac,
299 | num_levels=num_levels+1,
300 | sorted_token_indices=sorted_token_indices,
301 | edit_evaluator=edit_evaluator)
302 | else:
303 | error_msg = "Reached > max binary search levels " + \
304 | f"({num_levels} > {max_levels})"
305 | raise RuntimeError(error_msg)
306 |
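The recursion above reduces to a simple interval-halving scheme. A standalone sketch (the `try_edit` callback is a hypothetical stand-in for `run_edit_round`): if an edit that flips the prediction is found at the midpoint, recurse into the lower half, since smaller mask fractions yield more minimal edits; otherwise recurse into the upper half.

```python
# Standalone sketch of the mask-fraction binary search in binary_search_edit.
def binary_search_mask_frac(try_edit, lo=0.0, hi=0.5, max_levels=10, level=1):
    mid = (lo + hi) / 2
    found = try_edit(mid)             # stand-in for run_edit_round
    if level == max_levels:
        return found, mid
    if found:                         # success: try smaller masks
        return binary_search_mask_frac(try_edit, lo, mid, max_levels, level + 1)
    return binary_search_mask_frac(try_edit, mid, hi, max_levels, level + 1)

# Toy predicate: an edit succeeds once at least 30% of tokens are masked.
found, frac = binary_search_mask_frac(lambda f: f >= 0.3)
```

After `max_levels` halvings the returned fraction has converged close to the smallest mask fraction that still flips the prediction.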
307 | def linear_search_edit(
308 | self, edit_list, input_cand, contrast_pred_idx, num_rounds,
309 | min_mask_frac=0.0, max_mask_frac=0.5, max_levels=None,
310 | edit_evaluator=None, sorted_token_indices=None):
311 |
312 | """ Runs linear search over masking percentages from min_mask_frac
313 | to max_mask_frac. Calls run_edit_round at each mask percentage. """
314 |
315 | if max_levels is None:
316 | max_levels = self.max_search_levels
317 | mask_frac_step = (max_mask_frac - min_mask_frac) / max_levels
318 | mask_frac_iterator = np.arange(min_mask_frac+mask_frac_step,
319 | max_mask_frac + mask_frac_step,
320 | mask_frac_step)
321 | for mask_frac in mask_frac_iterator:
322 | found_cand = self.run_edit_round(
323 | edit_list, input_cand, contrast_pred_idx,
324 | num_rounds, mask_frac, edit_evaluator=edit_evaluator,
325 | sorted_token_indices=sorted_token_indices)
326 | logger.info(wrap_text("Linear search mask_frac: " + str(mask_frac)))
327 | logger.info(wrap_text("Found cand: " + str(found_cand)))
328 | if found_cand:
329 | return found_cand
330 | return found_cand
331 |
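The linear variant sweeps the same interval in fixed steps. A standalone sketch (again with a hypothetical `try_edit` callback standing in for `run_edit_round`): step from `min_mask_frac` (exclusive) up to `max_mask_frac` (inclusive) in `max_levels` equal increments, stopping at the first fraction whose edit flips the prediction.

```python
# Standalone sketch of linear_search_edit's sweep over mask fractions.
def linear_search_mask_frac(try_edit, lo=0.0, hi=0.5, max_levels=10):
    step = (hi - lo) / max_levels
    for i in range(1, max_levels + 1):
        frac = lo + i * step
        if try_edit(frac):            # stand-in for run_edit_round
            return True, frac
    return False, hi

found, frac = linear_search_mask_frac(lambda f: f >= 0.3)
```

Unlike the binary variant, the linear sweep visits every level, so it is slower but cannot overshoot a narrow band of successful mask fractions.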
332 | def minimally_edit(
333 | self, orig_input, contrast_pred_idx = -2,
334 | max_edit_rounds = 10, edit_evaluator=None):
335 |
336 | """ Gets minimal edits for given input.
337 | Calls search algorithm (linear/binary) based on self.search_method.
338 | contrast_pred_idx specifies which label to use as the contrast.
339 | Defaults to -2, i.e. use label with 2nd highest pred prob.
340 |
341 | Returns EditList() object. """
342 |
343 | editor = self.editor
344 | beam_width = self.beam_width
345 |
346 | # Get truncated editable part of input
347 | editable_seg = self.editor.get_editable_seg_from_input(orig_input)
348 | editable_seg = self.editor.truncate_editable_segs(
349 | [editable_seg], inp=orig_input)[0]
350 |
351 | orig_input = self.editor.get_input_from_editable_seg(
352 | orig_input, editable_seg)
353 | num_toks = len(get_predictor_tokenized(self.predictor, editable_seg))
354 | assert num_toks <= self.predictor._dataset_reader._tokenizer._max_length
355 |
356 | editable_seg = self.editor.tokenizer.decode(
357 | self.editor.tokenizer.encode(editable_seg),
358 | clean_up_tokenization_spaces=True).replace("</s>", " ")
359 | start_time = time.time()
360 |
361 | instance = self.editor.input_to_instance(
362 | orig_input, editable_seg=editable_seg)
363 |
364 | orig_pred = self.predictor.predict_instance(instance)
365 | orig_pred = add_probs(orig_pred)
366 | orig_probs = orig_pred['probs']
367 | orig_pred_idx = np.array(orig_probs).argsort()[-1]
368 | orig_pred_label = self.editor.ints_to_labels[orig_pred_idx]
369 |
370 | assert orig_pred_label == str(orig_pred_label)
371 |
372 | contrast_pred_idx = np.array(orig_probs).argsort()[contrast_pred_idx]
373 | contrast_label = self.ints_to_labels[contrast_pred_idx]
374 |
375 | orig_contrast_prob = get_prob_pred(orig_pred, contrast_pred_idx)
376 |
377 |
378 | assert orig_contrast_prob < 1.0
379 |
380 | num_rounds = 0
381 | new_pred_label = orig_pred_label
382 |
383 | logger.info(f"Contrast label: {contrast_label}")
384 | logger.info(f"Orig contrast prob: {round(orig_contrast_prob, 3)}")
385 |
386 | edit_list = EditList(orig_input, editable_seg, orig_contrast_prob,
387 | orig_pred_label, contrast_label, orig_pred_idx)
388 |
389 | while new_pred_label != contrast_label:
390 | num_rounds += 1
391 | prev_beam = edit_list.beam.copy()
392 |
393 | # Iterate through in reversed order (highest probabilities first)
394 | iterator = enumerate(reversed(sorted(prev_beam)))
395 | for beam_elem_idx, (score, _, input_cand) in iterator:
396 |
397 | sys.stdout.flush()
398 | logger.info(wrap_text(f"Updating beam for: {input_cand}"))
399 | logger.info(wrap_text(f"Edit round: {num_rounds} (1-indexed)"))
400 | logger.info(wrap_text(f"Element {beam_elem_idx} of beam"))
401 | logger.info(wrap_text(f"Contrast label: {contrast_label}"))
402 | logger.info(wrap_text(f"Contrast prob: {round(score, 3)}"))
403 | logger.info(wrap_text("Generating candidates..."))
404 |
405 | if self.editor.grad_pred == "original":
406 | pred_idx = orig_pred_idx
407 | elif self.editor.grad_pred == "contrast":
408 | pred_idx = contrast_pred_idx
409 |
410 | sorted_token_indices = self.editor.get_sorted_token_indices(
411 | input_cand, pred_idx)
412 |
413 | if self.search_method == "binary":
414 | self.binary_search_edit(edit_list, input_cand,
415 | contrast_pred_idx, num_rounds,
416 | max_mask_frac=self.max_mask_frac, num_levels=1,
417 | edit_evaluator=edit_evaluator,
418 | sorted_token_indices=sorted_token_indices)
419 |
420 | elif self.search_method == "linear":
421 | self.linear_search_edit(edit_list, input_cand,
422 | contrast_pred_idx, num_rounds,
423 | max_mask_frac=self.max_mask_frac,
424 | sorted_token_indices=sorted_token_indices,
425 | edit_evaluator=edit_evaluator)
426 |
427 | if len(edit_list.successful_edits) != 0:
428 | logger.info(f"Found edit at edit round: {num_rounds}")
429 | return edit_list
430 |
431 | logger.info("CURRENT BEAM after considering candidates: ")
432 | for prob, _, input_cand in reversed(sorted(edit_list.beam)):
433 | logger.info(wrap_text(f"({round(prob, 4)}) {input_cand}"))
434 |
435 | highest_beam_element = sorted(list(edit_list.beam))[-1]
436 | _, _, input_cand = highest_beam_element
437 | num_minutes = round((time.time() - start_time)/60, 3)
438 |
439 | # If we've reached max # edit rounds, return highest cand in beam
440 | if num_rounds >= max_edit_rounds:
441 | logger.info(wrap_text("Reached max substitutions!"))
442 | return edit_list
443 |
444 | if edit_list.beam == prev_beam:
445 | logger.info(wrap_text("Beam unchanged after updating beam."))
446 | return edit_list
447 |
448 | return edit_list
449 |
450 | class EditList():
451 | """ Class for storing edits/beam for a particular input. """
452 |
453 | def __init__(
454 | self, orig_input, orig_editable_seg, orig_contrast_prob,
455 | orig_label, contrast_label, orig_pred_idx):
456 |
457 | self.orig_input = orig_input
458 | self.orig_editable_seg = orig_editable_seg
459 | self.successful_edits = []
460 | self.orig_contrast_prob = orig_contrast_prob
461 | self.orig_label = orig_label
462 | self.orig_pred_idx = orig_pred_idx
463 | self.contrast_label = contrast_label
464 | self.counter = 0
465 | self.beam = [(orig_contrast_prob, self.counter, orig_input)]
466 | heapq.heapify(self.beam)
467 |
468 | def add_edit(self, edit): # edit should be a dict
469 | orig_len = len(self.successful_edits)
470 | self.successful_edits.append(edit)
471 | assert len(self.successful_edits) == orig_len + 1
472 |
473 | def get_sorted_edits(self):
474 | return sorted(self.successful_edits, key=lambda k: k['minimality'])
475 |
476 |
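The beam bookkeeping shared by `EditList` and `EditFinder.run_edit_round` can be sketched in isolation: a min-heap of `(contrast_prob, counter, input)` tuples capped at `beam_width` by popping the weakest element, where the monotonically increasing counter breaks probability ties so `heapq` never falls back to comparing the string inputs themselves.

```python
import heapq

# Toy version of the beam maintenance in EditList / run_edit_round.
beam_width = 2
beam, counter = [], 0
for prob, text in [(0.1, "edit a"), (0.4, "edit b"), (0.3, "edit c")]:
    counter += 1
    heapq.heappush(beam, (prob, counter, text))
    if len(beam) > beam_width:
        heapq.heappop(beam)  # drop the lowest-probability candidate

# Highest-probability candidate sits at the end after sorting.
best_prob, _, best_text = sorted(beam)[-1]
```

This is why `minimally_edit` reads the best candidate with `sorted(list(edit_list.beam))[-1]`: the heap is ordered ascending, so the strongest element is last.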
--------------------------------------------------------------------------------
/src/editor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import re
4 | import os
5 | import sys
6 | import more_itertools as mit
7 | import math
8 | import textwrap
9 | import logging
10 | import warnings
11 |
12 | # Local imports
13 | from src.utils import *
14 |
15 | logger = logging.getLogger("my-logger")
16 | FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
17 | logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"), format=FORMAT)
18 | logger.setLevel(logging.INFO)
19 |
20 | class Editor():
21 | def __init__(
22 | self,
23 | tokenizer_wrapper,
24 | tokenizer,
25 | editor_model,
26 | masker,
27 | num_gens = 15,
28 | num_beams = 30,
29 | grad_pred = "contrast",
30 | generate_type = "sample",
31 | no_repeat_ngram_size = 2,
32 | top_k = 30,
33 | top_p = 0.92,
34 | length_penalty = 0.5,
35 | verbose = True,
36 | prepend_label = True,
37 | ints_to_labels = None
38 | ):
39 | self.tokenizer = tokenizer
40 | self.device = get_device()
41 | self.num_gens = num_gens
42 | self.editor_model = editor_model.to(self.device)
43 | self.tokenizer_wrapper = tokenizer_wrapper
44 | self.masker = masker
45 | if ints_to_labels is None:
46 | ints_to_labels = get_ints_to_labels(self.masker.predictor)
47 | self.ints_to_labels = ints_to_labels
48 | self.max_length = self.editor_model.config.n_positions
49 | self.predictor = self.masker.predictor
50 | self.dataset_reader = self.predictor._dataset_reader
51 | self.grad_pred = grad_pred
52 | self.verbose = verbose
53 | self.generate_type = generate_type
54 | self.no_repeat_ngram_size = no_repeat_ngram_size
55 | self.top_k = top_k
56 | self.top_p = top_p
57 | self.length_penalty = length_penalty
58 | self.num_beams = num_beams
59 | self.prepend_label = prepend_label
60 |
61 | def get_editor_input(self, targ_pred_label, masked_editable_seg, *args):
62 | """ Format input for editor """
63 |
64 | prefix = "" if not self.prepend_label else "label: " + \
65 | targ_pred_label + ". input: "
66 | return prefix + masked_editable_seg
67 |
68 | def get_editable_seg_from_input(self, inp):
69 | """ Map whole input -> editable seg.
70 | These are the same for single-input classification. """
71 |
72 | return inp
73 |
74 | def get_input_from_editable_seg(self, inp, editable_seg):
75 | """ Map editable seg -> whole input.
76 | These are the same for IMDB/Newsgroups. """
77 |
78 | return editable_seg
79 |
80 | def truncate_editable_segs(self, editable_segs, **kwargs):
81 | """ Truncate editable segments to max length of Predictor. """
82 |
83 | trunc_es = [None] * len(editable_segs)
84 | for s_idx, s in enumerate(editable_segs):
85 | assert(len(s) > 0)
86 | predic_tokenized = get_predictor_tokenized(self.predictor, s)
87 |
88 | max_predic_tokens = self.dataset_reader._tokenizer._max_length
89 | if len(predic_tokenized) >= max_predic_tokens:
90 | for idx, token in enumerate(reversed(predic_tokenized)):
91 | if token.idx_end is not None:
92 | last_idx = token.idx_end
93 | break
94 | trunc_es[s_idx] = s[0:last_idx]
95 | else:
96 | trunc_es[s_idx] = s
97 | return trunc_es
98 |
99 | def input_to_instance(self, inp, editable_seg = None, return_tuple = False):
100 | """ Convert input to AllenNLP instance object """
101 |
102 | if editable_seg is None:
103 | instance = self.dataset_reader.text_to_instance(inp)
104 | else:
105 | instance = self.dataset_reader.text_to_instance(editable_seg)
106 | if return_tuple:
107 | # TODO: hacky bc for race dataset reader, we return length list
108 | return instance, [None]
109 | return instance
110 |
111 | def get_sorted_token_indices(self, inp, grad_pred_idx):
112 | """ Get token indices to mask, sorted by gradient value """
113 |
114 | editable_seg = self.get_editable_seg_from_input(inp)
115 | editable_toks = self.tokenizer_wrapper.tokenize(editable_seg)[:-1]
116 | sorted_token_indices = self.masker.get_important_editor_tokens(
117 | editable_seg, grad_pred_idx, editable_toks,
118 | num_return_toks = len(editable_toks)
119 | )
120 | return sorted_token_indices
121 |
122 | def get_candidates(
123 | self, targ_pred_label, inp, targ_pred_idx, orig_pred_idx,
124 | sorted_token_indices = None):
125 | """ Gets edit candidates after infilling with Editor.
126 | Returns dicts with edited inputs (i.e. whole inputs, dicts in the case
127 | of RACE) and edited editable segs (i.e. just the parts of inputs
128 | that are editable, articles in the case of RACE). """
129 |
130 | assert targ_pred_idx != orig_pred_idx
131 |
132 | if self.grad_pred == "contrast":
133 | grad_pred_idx = targ_pred_idx
134 | elif self.grad_pred == "original":
135 | grad_pred_idx = orig_pred_idx
136 | else:
137 | raise ValueError(f"Invalid grad_pred: {self.grad_pred}")
138 |
139 | num_spans, token_ind_to_mask, masked_inp, orig_spans, max_length = \
140 | self._prepare_input_for_editor(
141 | inp, targ_pred_idx, grad_pred_idx,
142 | sorted_token_indices=sorted_token_indices)
143 |
144 | edited_editable_segs = self._sample_edits(targ_pred_label, inp,
145 | masked_inp, targ_pred_idx, num_spans=num_spans,
146 | orig_spans=orig_spans, max_length=max_length)
147 | edited_cands = [None] * len(edited_editable_segs)
148 | for idx, es in enumerate(edited_editable_segs):
149 | cand = {}
150 | es = self.truncate_editable_segs([es], inp=inp)[0]
151 | cand['edited_input'] = self.get_input_from_editable_seg(inp, es)
152 | cand['edited_editable_seg'] = es
153 | edited_cands[idx] = cand
154 |
155 | return edited_cands, masked_inp
156 |
157 | def _prepare_input_for_editor(
158 | self, inp, targ_pred_idx, grad_pred_idx,
159 | sorted_token_indices = None):
160 | """ Helper function that prepares masked input for Editor. """
161 |
162 | tokens = self.tokenizer_wrapper.tokenize(inp)[:-1]
163 | tokens = [t.text for t in tokens]
164 |
165 | if sorted_token_indices is not None:
166 | num_return_toks = math.ceil(self.masker.mask_frac * len(tokens))
167 | token_ind_to_mask = sorted_token_indices[:num_return_toks]
168 | grouped_ind_to_mask, token_ind_to_mask, masked_inp, orig_spans = \
169 | self.masker.get_masked_string(
170 | inp, grad_pred_idx,
171 | editor_mask_indices=token_ind_to_mask
172 | )
173 |
174 | else:
175 | grouped_ind_to_mask, token_ind_to_mask, masked_inp, orig_spans = \
176 | self.masker.get_masked_string(inp, grad_pred_idx)
177 |
178 | max_length = math.ceil((self.masker.mask_frac + 0.2) * \
179 | len(sorted_token_indices if sorted_token_indices is not None else tokens))
180 | num_spans = len(grouped_ind_to_mask)
181 |
182 | return num_spans, token_ind_to_mask, masked_inp, orig_spans, max_length
183 |
184 | def _process_gen(self, masked_inp, gen, sentinel_toks):
185 | """ Helper function that processes decoded gen """
186 |
187 | bad_gen = False
188 | first_bad_tok = None
189 |
190 | # Hacky: If sentinel tokens are consecutive, then re split won't work
191 | gen = gen.replace("><extra_id", "> <extra_id")
192 |
193 | # Remove <pad> prefix, etc.
194 | gen = gen[gen.find("<extra_id_0>"):]
195 |
196 | # Sanity check
197 | assert not gen.startswith(self.tokenizer.pad_token)
198 |
199 | # This is for baseline T5 which does not handle masked last tokens well.
200 | # Baseline often predicts </s> as last token instead of sentinel tok.
201 | # Heuristically treating the first </s> tok as the final sentinel tok.
202 | # TODO: will this mess things up for non-baseline editor_models?
203 | if sentinel_toks[-1] not in gen:
204 | first_eos_token_idx = gen.find(self.tokenizer.eos_token)
205 | gen = gen[:first_eos_token_idx] + sentinel_toks[-1] + \
206 | gen[first_eos_token_idx + len(self.tokenizer.eos_token):]
207 |
208 | last_sentin_idx = gen.find(sentinel_toks[-1])
209 | if last_sentin_idx != -1:
210 | # If the last token we are looking for is in generation, truncate
211 | gen = gen[:last_sentin_idx + len(sentinel_toks[-1])]
212 |
213 | # If </s> is in generation, truncate
214 | eos_token_idx = gen.find(self.tokenizer.eos_token)
215 | if eos_token_idx != -1:
216 | gen = gen[:eos_token_idx]
217 |
218 | # Check if every sentinel token is in the gen
219 | for x in sentinel_toks:
220 | if x not in gen:
221 | bad_gen = True
222 | first_bad_tok = self.tokenizer.encode(x)[0]
223 | break
224 |
225 | tokens = list(filter(None, re.split(
226 | '<extra_id_.>|<extra_id_..>|<extra_id_...>', gen)))
227 | gen_sentinel_toks = re.findall(
228 | '<extra_id_.>|<extra_id_..>|<extra_id_...>', gen)
229 |
230 | gen_sentinel_toks = gen_sentinel_toks[:len(tokens)]
231 |
232 | temp = masked_inp
233 | ctr = 0
234 | prev_temp = temp
235 | tok_sentinel_iterator = zip(tokens, gen_sentinel_toks)
236 | for idx, (token, sentinel_tok) in enumerate(tok_sentinel_iterator):
237 | sentinel_idx = sentinel_tok[-2:-1] if len(sentinel_tok) == 12 \
238 | else sentinel_tok[-3:-1]
239 | sentinel_idx = int(sentinel_idx)
240 |
241 | # Check order of generated sentinel tokens
242 | if sentinel_idx != ctr:
243 | first_bad_tok = self.tokenizer.encode(f"<extra_id_{ctr}>")[0]
244 | bad_gen = True
245 | break
246 |
247 | if idx != 0:
248 | temp = temp.replace(prev_sentinel_tok, prev_token)
249 | prev_sentinel_tok = sentinel_tok
250 | prev_token = token
251 |
252 | # If last replacement, make sure final sentinel token was generated
253 | is_last = (idx == len(tokens)-1)
254 | if is_last and gen_sentinel_toks[-1] in sentinel_toks and not bad_gen:
255 | if " " + sentinel_tok in temp:
256 | temp = temp.replace(" " + sentinel_tok, token)
257 | elif "-" + sentinel_tok in temp:
258 | # If span follows "-" character, remove first white space
259 | if token[0] == " ":
260 | token = token[1:]
261 | temp = temp.replace(sentinel_tok, token)
262 | else:
263 | temp = temp.replace(sentinel_tok, token)
264 | else:
265 | first_bad_tok = self.tokenizer.encode(f"<extra_id_{ctr}>")[0]
266 | ctr += 1
267 |
268 | return bad_gen, first_bad_tok, temp, gen
269 |
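The core move in `_process_gen` can be illustrated with a toy, self-contained version of the sentinel bookkeeping (a simplified sketch, not the module's full validation logic): split a T5-style generation on `<extra_id_N>` sentinels and splice each infilled span back into the masked input at the matching sentinel.

```python
import re

# Toy version of the sentinel splicing in _process_gen: each <extra_id_N>
# token in the masked input is replaced by the span the editor generated
# between <extra_id_N> and <extra_id_N+1>.
def splice_infills(masked_inp, gen):
    spans = list(filter(None, re.split(r"<extra_id_\d+>", gen)))
    sentinels = re.findall(r"<extra_id_\d+>", gen)
    for sentinel, span in zip(sentinels, spans):
        masked_inp = masked_inp.replace(sentinel, span.strip())
    return masked_inp

out = splice_infills("a <extra_id_0> movie with <extra_id_1> acting",
                     "<extra_id_0> truly great <extra_id_1> superb")
```

The real method additionally checks that the sentinels appear in order, handles a missing final sentinel or stray `</s>`, and flags generations with missing sentinels as "bad" so they can be re-infilled in a later sub-round.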
270 | def _get_pred_with_replacement(self, temp_gen, orig_spans, *args):
271 | """ Replaces sentinel tokens in gen with orig text and returns pred.
272 | Used for intermediate bad generations. """
273 |
274 | orig_tokens = list(filter(None, re.split(
275 | '<extra_id_.>|<extra_id_..>', orig_spans)))
276 | orig_sentinel_toks = re.findall('<extra_id_.>|<extra_id_..>', orig_spans)
277 |
278 | for token, sentinel_tok in zip(orig_tokens, orig_sentinel_toks[:-1]):
279 | if sentinel_tok in temp_gen:
280 | temp_gen = temp_gen.replace(sentinel_tok, token)
281 | temp_instance = self.dataset_reader.text_to_instance(temp_gen)
282 | return temp_gen, self.predictor.predict_instance(temp_instance)
283 |
284 |
285 | def _sample_edits(
286 | self, targ_pred_label, inp, masked_editable_seg, targ_pred_idx,
287 | num_spans = None, orig_spans = None, max_length = None):
288 | """ Returns self.num_gens copies of masked_editable_seg with infills.
289 | Called by get_candidates(). """
290 |
291 | self.editor_model.eval()
292 |
293 | editor_input = self.get_editor_input(
294 | targ_pred_label, masked_editable_seg, inp)
295 |
296 | editor_inputs = [editor_input]
297 | editable_segs = [masked_editable_seg]
298 | span_end_offsets = [num_spans]
299 | orig_token_ids_lst = [self.tokenizer.encode(orig_spans)[:-1]]
300 | orig_spans_lst = [orig_spans]
301 | masked_token_ids_lst = [self.tokenizer.encode(editor_input)[:-1]]
302 |
303 | k_intermediate = 3
304 |
305 | sentinel_start = self.tokenizer.encode("<extra_id_99>")[0]
306 | sentinel_end = self.tokenizer.encode("<extra_id_0>")[0]
307 |
308 | num_sub_rounds = 0
309 | edited_editable_segs = [] # list of tuples with meta information
310 |
311 | max_sub_rounds = 3
312 | while editable_segs != []:
313 |
314 | # Break if past max sub rounds
315 | if num_sub_rounds > max_sub_rounds:
316 | break
317 |
318 | new_editor_inputs = []
319 | new_editable_segs = []
320 | new_span_end_offsets = []
321 | new_orig_token_ids_lst = []
322 | new_orig_spans_lst = []
323 | new_masked_token_ids_lst = []
324 | num_inputs = len(editor_inputs)
325 |
326 | iterator = enumerate(zip(
327 | editor_inputs, editable_segs, masked_token_ids_lst,
328 | span_end_offsets, orig_token_ids_lst, orig_spans_lst))
329 | for inp_idx, (editor_input, editable_seg, masked_token_ids, \
330 | span_end, orig_token_ids, orig_spans) in iterator:
331 |
332 | num_inputs = len(editor_inputs)
333 | num_return_seqs = int(math.ceil(self.num_gens/num_inputs)) \
334 | if num_sub_rounds != 0 else self.num_gens
335 | num_beams = self.num_beams if num_sub_rounds == 0 \
336 | else num_return_seqs
337 | last_sentin = f"<extra_id_{span_end}>"
338 | end_token_id = self.tokenizer.convert_tokens_to_ids(last_sentin)
339 | masked_token_ids_tensor = torch.LongTensor(
340 | masked_token_ids + [self.tokenizer.eos_token_id]
341 | ).unsqueeze(0).to(self.device)
342 | eos_id = self.tokenizer.eos_token_id
343 | bad_tokens_ids = [[x] for x in range(
344 | sentinel_start, end_token_id)] + [[eos_id]]
345 | max_length = max(int(4/3 * max_length), 200)
346 | logger.info(wrap_text("Sub round: " + str(num_sub_rounds)))
347 | logger.info(wrap_text(f"Input: {inp_idx} of {num_inputs-1}"))
348 | logger.info(wrap_text(f"Last sentinel: {last_sentin}"))
349 | logger.info(wrap_text("INPUT TO EDITOR: " + \
350 | f"{self.tokenizer.decode(masked_token_ids)}"))
351 |
352 | with torch.no_grad():
353 | if self.generate_type == "beam":
354 | output = self.editor_model.generate(
355 | input_ids=masked_token_ids_tensor,
356 | num_beams=num_beams,
357 | num_return_sequences=num_return_seqs,
358 | no_repeat_ngram_size=self.no_repeat_ngram_size,
359 | eos_token_id=end_token_id,
360 | early_stopping=True,
361 | length_penalty=self.length_penalty,
362 | bad_words_ids=bad_tokens_ids,
363 | max_length=max_length)
364 |
365 | elif self.generate_type == "sample":
366 | output = self.editor_model.generate(
367 | input_ids=masked_token_ids_tensor,
368 | do_sample=True,
369 | top_p=self.top_p,
370 | top_k=self.top_k,
371 | num_return_sequences=num_return_seqs,
372 | no_repeat_ngram_size=self.no_repeat_ngram_size,
373 | eos_token_id=end_token_id,
374 | early_stopping=True,
375 | length_penalty=self.length_penalty,
376 | bad_words_ids=bad_tokens_ids,
377 | max_length=max_length)
378 | output = output.cpu()
379 | del masked_token_ids_tensor
380 | torch.cuda.empty_cache()
381 |
382 | batch_decoded = self.tokenizer.batch_decode(output)
383 | num_gens_with_pad = 0
384 | num_bad_gens = 0
385 | temp_edited_editable_segs = []
386 | logger.info(wrap_text("first batch: " + batch_decoded[0]))
387 | for batch_idx, batch in enumerate(batch_decoded):
388 | sentinel_toks = [f"<extra_id_{idx}>" for idx in \
389 | range(0, span_end + 1)]
390 | bad_gen, first_bad_tok, temp, stripped_batch = \
391 | self._process_gen(editable_seg, batch, sentinel_toks)
392 |
393 | if len(sentinel_toks) > 3:
394 | assert sentinel_toks[-2] in editor_input
395 |
396 | if "<pad>" in batch[4:]:
397 | num_gens_with_pad += 1
398 | if bad_gen:
399 |
400 | num_bad_gens += 1
401 | temp_span_end_offset = first_bad_tok - end_token_id + 1
402 |
403 | new_editable_token_ids = np.array(
404 | self.tokenizer.encode(temp)[:-1])
405 |
406 | sentinel_indices = np.where(
407 | (new_editable_token_ids >= sentinel_start) & \
408 | (new_editable_token_ids <= sentinel_end))[0]
409 | new_first_token = max(
410 | new_editable_token_ids[sentinel_indices])
411 | diff = sentinel_end - new_first_token
412 | new_editable_token_ids[sentinel_indices] += diff
413 |
414 | new_span_end_offsets.append(len(sentinel_indices))
415 |
416 | new_editable_seg = self.tokenizer.decode(
417 | new_editable_token_ids)
418 | new_editable_segs.append(new_editable_seg)
419 |
420 | new_input = self.get_editor_input(targ_pred_label,
421 | new_editable_seg, inp)
422 |
423 | new_masked_token_ids = self.tokenizer.encode(new_input)[:-1]
424 | new_masked_token_ids_lst.append(new_masked_token_ids)
425 |
426 | # Hacky but re-decode to remove spaces b/w sentinels
427 | new_editor_input = self.tokenizer.decode(
428 | new_masked_token_ids)
429 | new_editor_inputs.append(new_editor_input)
430 |
431 | # Get orig token ids from new first token and on
432 | new_orig_token_ids = np.array(orig_token_ids[np.where(
433 | orig_token_ids == new_first_token)[0][0]:])
434 | sentinel_indices = np.where((
435 | new_orig_token_ids >= sentinel_start) & \
436 | (new_orig_token_ids <= sentinel_end))[0]
437 | new_orig_token_ids[sentinel_indices] += diff
438 | new_orig_token_ids_lst.append(new_orig_token_ids)
439 | new_orig_spans = self.tokenizer.decode(new_orig_token_ids)
440 | new_orig_spans_lst.append(new_orig_spans)
441 |
442 | else:
443 | temp_edited_editable_segs.append(temp)
444 | assert "<extra_id_" not in temp, \
445 | f"<extra_id_ should not be in edited seg: {temp}"
447 |
448 | edited_editable_segs.extend(temp_edited_editable_segs)
449 |
450 | if new_editor_inputs == []:
451 | break
452 |
453 | _, unique_batch_indices = np.unique(new_editor_inputs,
454 | return_index=True)
455 |
456 | targ_probs = [-1] * len(new_editable_segs)
457 | for idx in unique_batch_indices:
458 | ot = new_orig_spans_lst[idx].replace("</s>", "")
459 | temp, pred = self._get_pred_with_replacement(
460 | new_editable_segs[idx], ot, inp)
461 | pred = add_probs(pred)
462 | targ_probs[idx] = pred['probs'][targ_pred_idx]
463 | predicted_label = self.ints_to_labels[np.argmax(pred['probs'])]
464 | contrast_label = self.ints_to_labels[targ_pred_idx]
465 | if predicted_label == contrast_label:
466 | edited_editable_segs.append(temp)
467 |
468 | highest_indices = np.argsort(targ_probs)[-k_intermediate:]
469 | filt_indices = [idx for idx in highest_indices \
470 | if targ_probs[idx] != -1]
471 | editor_inputs = [new_editor_inputs[idx] for idx in filt_indices]
472 | editable_segs = [new_editable_segs[idx] for idx in filt_indices]
473 | span_end_offsets = [new_span_end_offsets[idx] for idx in filt_indices]
474 | orig_token_ids_lst = [new_orig_token_ids_lst[idx] for idx in filt_indices]
475 | orig_spans_lst = [new_orig_spans_lst[idx] for idx in filt_indices]
476 | masked_token_ids_lst = [new_masked_token_ids_lst[idx] for idx in filt_indices]
477 |
478 | sys.stdout.flush()
479 | num_sub_rounds += 1
480 |
481 | for idx, es in enumerate(edited_editable_segs):
482 | assert es.find("</s>") in [len(es)-4, -1]
483 | edited_editable_segs[idx] = es.replace("</s>", " ")
484 | assert "<extra_id_" not in es, \
485 | f"<extra_id_ should not be in edited inp: {edited_editable_segs[idx][0]}"
488 |
489 |
490 | return set(edited_editable_segs)
491 |
492 | class RaceEditor(Editor):
493 | def __init__(
494 | self,
495 | tokenizer_wrapper,
496 | tokenizer,
497 | editor_model,
498 | masker,
499 | num_gens = 30,
500 | num_beams = 30,
501 | grad_pred = "contrast",
502 | generate_type = "sample",
503 | length_penalty = 1.0,
504 | no_repeat_ngram_size = 2,
505 | top_k = 50,
506 | top_p = 0.92,
507 | verbose = False,
508 | editable_key = "article"
509 | ):
510 | super().__init__(
511 | tokenizer_wrapper, tokenizer, editor_model, masker,
512 | num_gens=num_gens, num_beams=num_beams,
513 | ints_to_labels=[str(idx) for idx in range(4)],
514 | grad_pred=grad_pred,
515 | generate_type=generate_type,
516 | no_repeat_ngram_size=no_repeat_ngram_size,
517 | top_k=top_k, top_p=top_p,
518 | length_penalty=length_penalty,
519 | verbose=verbose)
520 |
521 | self.editable_key = editable_key
522 | if self.editable_key not in ["question", "article"]:
523 | raise ValueError("Invalid value for editable_key")
524 |
525 | def _get_pred_with_replacement(self, temp_gen, orig_spans, inp):
526 | """ Replaces sentinel tokens in gen with orig text and returns pred.
527 | Used for intermediate bad generations. """
528 |
529 | orig_tokens = list(filter(None, re.split(
530 | '<extra_id_.>|<extra_id_..>|<extra_id_...>', orig_spans)))
531 | orig_sentinel_toks = re.findall(
532 | '<extra_id_.>|<extra_id_..>|<extra_id_...>', orig_spans)
533 |
534 | for token, sentinel_tok in zip(orig_tokens, orig_sentinel_toks[:-1]):
535 | if sentinel_tok in temp_gen:
536 | temp_gen = temp_gen.replace(sentinel_tok, token)
537 | # temp_gen is article for RACE
538 | temp_instance = self.dataset_reader.text_to_instance(
539 | inp["id"], temp_gen, inp["question"], inp["options"])[0]
540 | return temp_gen, self.predictor.predict_instance(temp_instance)
541 |
542 | def get_editable_seg_from_input(self, inp):
543 | """ Map whole input -> editable seg. """
544 |
545 | return inp[self.editable_key]
546 |
547 | def get_input_from_editable_seg(self, inp, editable_seg):
548 | """ Map editable seg -> whole input. """
549 |
550 | new_inp = inp.copy()
551 | new_inp[self.editable_key] = editable_seg
552 | return new_inp
553 |
554 | def truncate_editable_segs(self, editable_segs, inp = None):
555 | """ Truncate editable segments to max length of Predictor. """
556 |
557 | trunc_inputs = [None] * len(editable_segs)
558 | instance, length_lst, max_length_lst = self.input_to_instance(
559 | inp, return_tuple = True)
560 | for s_idx, es in enumerate(editable_segs):
561 | editable_toks = get_predictor_tokenized(self.predictor, es)
562 | predic_tok_end_idx = len(editable_toks)
563 | predic_tok_end_idx = min(predic_tok_end_idx, max(max_length_lst))
564 | last_index = editable_toks[predic_tok_end_idx - 1].idx_end
565 | editable_seg = es[:last_index]
566 | trunc_inputs[s_idx] = editable_seg
567 | return trunc_inputs
568 |
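The truncation above relies on predictor-token character offsets (`idx_end`) to cut the raw string at a token boundary. A minimal stdlib-only sketch of that trick, with a hypothetical `truncate_by_token_limit` helper standing in for the tokenizer plumbing:

```python
def truncate_by_token_limit(text, token_end_offsets, max_tokens):
    """Cut text at the character where the max_tokens-th token ends --
    the same char-offset trick truncate_editable_segs relies on."""
    end_tok = min(len(token_end_offsets), max_tokens)
    return text[:token_end_offsets[end_tok - 1]]

# End offsets of the whitespace-separated tokens in "a bb ccc dddd"
offsets = [1, 4, 8, 13]
print(truncate_by_token_limit("a bb ccc dddd", offsets, 3))  # 'a bb ccc'
```

Truncating on character offsets (rather than re-joining tokens) preserves the original whitespace and casing of the kept prefix.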
569 | def get_editor_input(self, targ_pred_label, masked_editable_seg, inp):
570 | """ Format input for editor """
571 |
572 | options = inp["options"]
573 | if masked_editable_seg is None:
574 | article = inp["article"]
575 | question = inp["question"]
576 | else: # masked editable input given
577 | if self.editable_key == "article":
578 | article = masked_editable_seg
579 | question = inp["question"]
580 | elif self.editable_key == "question":
581 | article = inp["article"]
582 | question = masked_editable_seg
583 |
584 | editor_input = format_multiple_choice_input(
585 | article, question, options, int(targ_pred_label))
586 | return editor_input
587 |
588 | def input_to_instance(
589 | self, inp, editable_seg = None, return_tuple = False):
590 | """ Convert input to AllenNLP instance object """
591 |
592 | if editable_seg is None:
593 | article = inp["article"]
594 | question = inp["question"]
595 | else: # editable input given
596 | if self.editable_key == "article":
597 | article = editable_seg
598 | question = inp["question"]
599 | elif self.editable_key == "question":
600 | article = inp["article"]
601 | question = editable_seg
602 | output = self.dataset_reader.text_to_instance(
603 | inp["id"], article, question, inp["options"])
604 | if return_tuple:
605 | return output
606 | return output[0]
607 |
608 | def get_sorted_token_indices(self, inp, grad_pred_idx):
609 | """ Get token indices to mask, sorted by gradient value """
610 |
611 | editable_seg = self.get_editable_seg_from_input(inp)
612 |
613 | inst, length_lst, _ = self.input_to_instance(inp, return_tuple=True)
614 | editable_toks = get_predictor_tokenized(self.predictor, editable_seg)
615 | num_editab_toks = len(editable_toks)
616 |
617 | predic_tok_end_idx = len(editable_toks)
618 | predic_tok_end_idx = min(
619 | predic_tok_end_idx, length_lst[grad_pred_idx])
620 |
621 | if self.editable_key == "article":
622 | predic_tok_start_idx = 0
623 | elif self.editable_key == "question":
624 | predic_tok_start_idx = length_lst[grad_pred_idx]
625 | predic_tok_end_idx = length_lst[grad_pred_idx] + num_editab_toks
626 |
627 | editable_toks = self.tokenizer_wrapper.tokenize(editable_seg)[:-1]
628 | sorted_token_indices = self.masker.get_important_editor_tokens(
629 | editable_seg, grad_pred_idx, editable_toks,
630 | num_return_toks=len(editable_toks),
631 | labeled_instance=inst,
632 | predic_tok_end_idx=predic_tok_end_idx,
633 | predic_tok_start_idx=predic_tok_start_idx)
634 | return sorted_token_indices
635 |
636 | def _prepare_input_for_editor(self, inp, targ_pred_idx, grad_pred_idx,
637 | sorted_token_indices = None):
638 | """ Helper function that prepares masked input for Editor. """
639 |
640 | editable_seg = self.get_editable_seg_from_input(inp)
641 |
642 | tokens = [t.text for t in \
643 | self.tokenizer_wrapper.tokenize(editable_seg)[:-1]]
644 |
645 | instance, length_lst, _ = self.input_to_instance(
646 | inp, return_tuple=True)
647 | editable_toks = get_predictor_tokenized(self.predictor, editable_seg)
648 | num_editab_toks = len(editable_toks)
649 | predic_tok_end_idx = len(editable_toks)
650 | predic_tok_end_idx = min(
651 | predic_tok_end_idx, length_lst[grad_pred_idx])
652 |
653 | if self.editable_key == "article":
654 | predic_tok_start_idx = 0
655 | elif self.editable_key == "question":
656 | predic_tok_start_idx = length_lst[grad_pred_idx]
657 | predic_tok_end_idx = length_lst[grad_pred_idx] + num_editab_toks
658 |
659 | if sorted_token_indices is not None:
660 | num_return_toks = math.ceil(
661 | self.masker.mask_frac * len(tokens))
662 | token_ind_to_mask = sorted_token_indices[:num_return_toks]
663 |
664 | grouped_ind_to_mask, token_ind_to_mask, masked_inp, orig_spans = \
665 | self.masker.get_masked_string(editable_seg, grad_pred_idx,
666 | editor_mask_indices=token_ind_to_mask,
667 | predic_tok_start_idx=predic_tok_start_idx,
668 | predic_tok_end_idx=predic_tok_end_idx)
669 |
670 | else:
671 | grouped_ind_to_mask, token_ind_to_mask, masked_inp, orig_spans = \
672 | self.masker.get_masked_string(
673 | editable_seg, grad_pred_idx,
674 | labeled_instance=instance,
675 | predic_tok_end_idx=predic_tok_end_idx,
676 | predic_tok_start_idx=predic_tok_start_idx)
677 |
678 | num_spans = len(grouped_ind_to_mask)
679 | max_length = math.ceil(
680 | (self.masker.mask_frac+0.2) * len(sorted_token_indices))
681 |
682 | masked_inp = masked_inp.replace(self.tokenizer.eos_token, " ")
683 | return num_spans, token_ind_to_mask, masked_inp, orig_spans, max_length
684 |
--------------------------------------------------------------------------------
/src/masker.py:
--------------------------------------------------------------------------------
1 | import more_itertools as mit
2 | import logging
3 | import random
4 | import math
5 | import numpy as np
6 |
7 | from allennlp.data.batch import Batch
8 | from allennlp.nn import util
9 |
10 | import torch
11 | import torch.nn.functional as F
12 | from torch import backends
13 |
14 | FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
15 | logging.basicConfig(format=FORMAT)
16 |
17 | class MaskError(Exception):
18 | pass
19 |
20 | class Masker():
21 | """
22 | Class used to mask inputs for Editors.
23 | Two subclasses: RandomMasker and GradientMasker
24 |
25 | mask_frac: float
26 | Fraction of input tokens to mask.
27 |     editor_tok_wrapper: allennlp.data.tokenizers.tokenizer
28 | Wraps around Editor tokenizer.
29 | Has capabilities for mapping Predictor tokens to Editor tokens.
30 | max_tokens: int
31 | Maximum number of tokens a masked input should have.
32 | """
33 |
34 | def __init__(
35 | self,
36 | mask_frac,
37 | editor_tok_wrapper,
38 | max_tokens
39 | ):
40 | self.mask_frac = mask_frac
41 | self.editor_tok_wrapper = editor_tok_wrapper
42 | self.max_tokens = max_tokens
43 |
44 |     def _get_mask_indices(self, editable_seg, editor_toks, pred_idx, **kwargs):
45 | """ Helper function to get indices of Editor tokens to mask. """
46 | raise NotImplementedError("Need to implement this in subclass")
47 |
48 | def get_all_masked_strings(self, editable_seg):
49 | """ Returns a list of masked inps/targets where each inp has
50 | one word replaced by a sentinel token.
51 | Used for calculating fluency. """
52 |
53 | editor_toks = self.editor_tok_wrapper.tokenize(editable_seg)
54 | masked_segs = [None] * len(editor_toks)
55 | labels = [None] * len(editor_toks)
56 |
57 | for i, token in enumerate(editor_toks):
58 | token_start, token_end = token.idx, token.idx_end
59 | masked_segs[i] = editable_seg[:token_start] + \
60 | Masker._get_sentinel_token(0) + editable_seg[token_end:]
61 | labels[i] = Masker._get_sentinel_token(0) + \
62 | editable_seg[token_start:token_end] + \
63 | Masker._get_sentinel_token(1)
64 |
65 | return masked_segs, labels
66 |
67 | def _get_sentinel_token(idx):
68 | """ Helper function to get sentinel token based on given idx """
69 |
70 |         return "<extra_id_" + str(idx) + ">"
71 |
72 | def _get_grouped_mask_indices(
73 | self, editable_seg, pred_idx, editor_mask_indices,
74 | editor_toks, **kwargs):
75 | """ Groups consecutive mask indices.
76 | Applies heuristics to enable better generation:
77 | - If > 27 spans, mask tokens b/w neighboring spans as well.
78 | (See Appendix: observed degeneration after 27th sentinel token)
79 | - Mask max of 100 spans (since there are 100 sentinel tokens in T5)
80 | """
81 |
82 | if editor_mask_indices is None:
83 | editor_mask_indices = self._get_mask_indices(
84 | editable_seg, editor_toks, pred_idx, **kwargs)
85 |
86 | new_editor_mask_indices = set(editor_mask_indices)
87 | grouped_editor_mask_indices = [list(group) for group in \
88 | mit.consecutive_groups(sorted(new_editor_mask_indices))]
89 |
90 | if len(grouped_editor_mask_indices) > 27:
91 | for t_idx in editor_mask_indices:
92 | if t_idx + 2 in editor_mask_indices:
93 | new_editor_mask_indices.add(t_idx + 1)
94 |
95 | grouped_editor_mask_indices = [list(group) for group in \
96 | mit.consecutive_groups(sorted(new_editor_mask_indices))]
97 |
98 | if len(grouped_editor_mask_indices) > 27:
99 | for t_idx in editor_mask_indices:
100 | if t_idx + 3 in editor_mask_indices:
101 | new_editor_mask_indices.add(t_idx + 1)
102 | new_editor_mask_indices.add(t_idx + 2)
103 |
104 | new_editor_mask_indices = list(new_editor_mask_indices)
105 | grouped_editor_mask_indices = [list(group) for group in \
106 | mit.consecutive_groups(sorted(new_editor_mask_indices))]
107 |
108 | grouped_editor_mask_indices = grouped_editor_mask_indices[:99]
109 | return grouped_editor_mask_indices
110 |
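The grouping step above turns individual token indices into contiguous spans and caps the span count (T5 provides only 100 sentinel tokens). A stdlib-only sketch of the consecutive-grouping part, using a hypothetical `group_consecutive` in place of `more_itertools.consecutive_groups`:

```python
import itertools

def group_consecutive(indices, max_spans=99):
    """Group token indices into runs of consecutive integers, mimicking
    more_itertools.consecutive_groups, and cap the number of spans."""
    groups = []
    # For consecutive values, value - position is constant within a run.
    for _, run in itertools.groupby(
            enumerate(sorted(set(indices))), key=lambda p: p[1] - p[0]):
        groups.append([idx for _, idx in run])
    return groups[:max_spans]

print(group_consecutive([1, 2, 3, 7, 8, 12]))  # [[1, 2, 3], [7, 8], [12]]
```

The bridging heuristic in `_get_grouped_mask_indices` (adding tokens between near-adjacent spans when there are more than 27 of them) simply makes more indices consecutive before this grouping runs, so fewer, longer spans come out.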
111 | def get_masked_string(
112 | self, editable_seg, pred_idx,
113 | editor_mask_indices = None, **kwargs):
114 | """ Gets masked string masking tokens w highest predictor gradients.
115 | Requires mapping predictor tokens to Editor tokens because edits are
116 | made on Editor tokens. """
117 |
118 | editor_toks = self.editor_tok_wrapper.tokenize(editable_seg)
119 | grpd_editor_mask_indices = self._get_grouped_mask_indices(
120 | editable_seg, pred_idx, editor_mask_indices,
121 | editor_toks, **kwargs)
122 |
123 | span_idx = len(grpd_editor_mask_indices) - 1
124 | label = Masker._get_sentinel_token(len(grpd_editor_mask_indices))
125 | masked_seg = editable_seg
126 |
127 | # Iterate over spans in reverse order and mask tokens
128 | for span in grpd_editor_mask_indices[::-1]:
129 |
130 | span_char_start = editor_toks[span[0]].idx
131 | span_char_end = editor_toks[span[-1]].idx_end
132 | end_token_idx = span[-1]
133 |
134 | # If last span tok is last t5 tok, heuristically set char end idx
135 | if span_char_end is None and end_token_idx == len(editor_toks)-1:
136 | span_char_end = span_char_start + 1
137 |
138 | if not span_char_end > span_char_start:
139 | raise MaskError
140 |
141 | label = Masker._get_sentinel_token(span_idx) + \
142 | masked_seg[span_char_start:span_char_end] + label
143 | masked_seg = masked_seg[:span_char_start] + \
144 | Masker._get_sentinel_token(span_idx) + \
145 | masked_seg[span_char_end:]
146 | span_idx -= 1
147 |
148 | return grpd_editor_mask_indices, editor_mask_indices, masked_seg, label
149 |
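`get_masked_string` walks the spans right to left so that earlier character offsets stay valid while it splices in sentinel tokens. A self-contained sketch of that strategy on plain character spans (the `sentinel` and `mask_spans` names are illustrative, not from the source):

```python
def sentinel(idx):
    # T5-style sentinel token
    return f"<extra_id_{idx}>"

def mask_spans(text, char_spans):
    """Replace each (start, end) char span with a sentinel token,
    iterating in reverse so earlier offsets remain valid, and build
    the matching target string as get_masked_string does."""
    label = sentinel(len(char_spans))
    masked = text
    for span_idx in range(len(char_spans) - 1, -1, -1):
        start, end = char_spans[span_idx]
        label = sentinel(span_idx) + masked[start:end] + label
        masked = masked[:start] + sentinel(span_idx) + masked[end:]
    return masked, label

masked, label = mask_spans("the quick brown fox", [(0, 3), (10, 15)])
# masked == "<extra_id_0> quick <extra_id_1> fox"
# label  == "<extra_id_0>the<extra_id_1>brown<extra_id_2>"
```

The target (`label`) interleaves the removed text with the same sentinels, which is exactly the input/target pairing T5's span-corruption objective expects.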
150 | class RandomMasker(Masker):
151 | """ Masks randomly chosen spans. """
152 |
153 | def __init__(
154 | self,
155 | mask_frac,
156 | editor_tok_wrapper,
157 | max_tokens
158 | ):
159 | super().__init__(mask_frac, editor_tok_wrapper, max_tokens)
160 |
161 | def _get_mask_indices(self, editable_seg, editor_toks, pred_idx, **kwargs):
162 | """ Helper function to get indices of Editor tokens to mask. """
163 |
164 | num_tokens = min(self.max_tokens, len(editor_toks))
165 | return random.sample(
166 | range(num_tokens), math.ceil(self.mask_frac * num_tokens))
167 |
168 | class GradientMasker(Masker):
169 | """ Masks spans based on gradients of Predictor wrt. given predicted label.
170 |
171 | mask_frac: float
172 | Fraction of input tokens to mask.
173 |     editor_tok_wrapper: allennlp.data.tokenizers.tokenizer
174 | Wraps around Editor tokenizer.
175 | Has capabilities for mapping Predictor tokens to Editor tokens.
176 | max_tokens: int
177 | Maximum number of tokens a masked input should have.
178 | grad_type: str, one of ["integrated_l1", "integrated_signed",
179 | "normal_l1", "normal_signed", "normal_l2", "integrated_l2"]
180 | Specifies how gradient value should be calculated
181 | Integrated vs. normal:
182 | Integrated: https://arxiv.org/pdf/1703.01365.pdf
183 | Normal: 'Vanilla' gradient
184 | Signed vs. l1 vs. l2:
185 | Signed: Sum gradients over embedding dimension.
186 | l1: Take l1 norm over embedding dimension.
187 | l2: Take l2 norm over embedding dimension.
188 | sign_direction: One of [-1, 1, None]
189 | When grad_type is signed, determines whether we want to get most
190 | negative or positive gradient values.
191 | This should depend on what label is being used
192 | (pred_idx argument to get_masked_string).
193 |         For example, during Stage One, we want to mask tokens that push *towards*
194 | gold label, whereas during Stage Two, we want to mask tokens that
195 | push *away* from the target label.
196 | Sign direction plays no role if only gradient *magnitudes* are used
197 | (i.e. if grad_type is not signed, but involves taking the l1/l2 norm.)
198 | num_integrated_grad_steps: int
199 | Hyperparameter for integrated gradients.
200 | Only used when grad_type is one of integrated types.
201 | """
202 |
203 | def __init__(
204 | self,
205 | mask_frac,
206 | editor_tok_wrapper,
207 | predictor,
208 | max_tokens,
209 | grad_type = "normal_l2",
210 | sign_direction = None,
211 | num_integrated_grad_steps = 10
212 | ):
213 | super().__init__(mask_frac, editor_tok_wrapper, max_tokens)
214 |
215 | self.predictor = predictor
216 | self.grad_type = grad_type
217 | self.num_integrated_grad_steps = num_integrated_grad_steps
218 | self.sign_direction = sign_direction
219 |
220 | if ("signed" in self.grad_type and sign_direction is None):
221 | error_msg = "To calculate a signed gradient value, need to " + \
222 | "specify sign direction but got None for sign_direction"
223 | raise ValueError(error_msg)
224 |
225 | if sign_direction not in [1, -1, None]:
226 | error_msg = f"Invalid value for sign_direction: {sign_direction}"
227 | raise ValueError(error_msg)
228 |
229 | temp_tokenizer = self.predictor._dataset_reader._tokenizer
230 |
231 |         # Used later to avoid masking special tokens like <s> and </s>
232 | self.predictor_special_toks = \
233 | temp_tokenizer.sequence_pair_start_tokens + \
234 | temp_tokenizer.sequence_pair_mid_tokens + \
235 | temp_tokenizer.sequence_pair_end_tokens + \
236 | temp_tokenizer.single_sequence_start_tokens + \
237 | temp_tokenizer.single_sequence_end_tokens
238 |
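The docstring's three aggregation modes (signed sum, l1 norm, squared l2) collapse a per-token gradient matrix to one saliency score per token. A small NumPy sketch of the reductions as the source applies them (the `aggregate_grads` helper is illustrative):

```python
import numpy as np

def aggregate_grads(grad, mode, sign_direction=1):
    """Collapse a (num_tokens, embed_dim) gradient matrix to one score
    per token, mirroring GradientMasker's grad_type options."""
    if mode == "signed":
        # Sum over the embedding dim; sign_direction flips whether the
        # most positive or most negative gradients sort first.
        return sign_direction * grad.sum(axis=1)
    if mode == "l1":
        return np.abs(grad).sum(axis=1)
    if mode == "l2":
        # Note: the source uses g.dot(g), i.e. the *squared* l2 norm,
        # which preserves the ordering of the true norm.
        return np.array([g.dot(g) for g in grad])
    raise ValueError(mode)

grad = np.array([[1.0, -2.0], [0.5, 0.5]])
print(aggregate_grads(grad, "signed"))  # [-1.  1.]
print(aggregate_grads(grad, "l1"))      # [3. 1.]
print(aggregate_grads(grad, "l2"))      # [5.  0.5]
```

Sorting by the "signed" score with `sign_direction = -1` surfaces tokens pushing hardest *against* the chosen label, which is what Stage Two needs.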
239 | def _get_gradients_by_prob(self, instance, pred_idx):
240 | """ Helper function to get gradient values of predicted logit
241 | Largely copied from Predictor class of AllenNLP """
242 |
243 | instances = [instance]
244 | original_param_name_to_requires_grad_dict = {}
245 | for param_name, param in self.predictor._model.named_parameters():
246 | original_param_name_to_requires_grad_dict[param_name] = \
247 | param.requires_grad
248 | param.requires_grad = True
249 |
250 | embedding_gradients: List[Tensor] = []
251 | hooks: List[RemovableHandle] = \
252 | self.predictor._register_embedding_gradient_hooks(
253 | embedding_gradients)
254 |
255 | dataset = Batch(instances)
256 | dataset.index_instances(self.predictor._model.vocab)
257 | dataset_tensor_dict = util.move_to_device(
258 | dataset.as_tensor_dict(), self.predictor.cuda_device)
259 | with backends.cudnn.flags(enabled=False):
260 | outputs = self.predictor._model.make_output_human_readable(
261 | self.predictor._model.forward(**dataset_tensor_dict)
262 | )
263 |
264 | # Differs here
265 | prob = outputs["logits"][0][pred_idx]
266 |
267 | self.predictor._model.zero_grad()
268 | prob.backward()
269 |
270 | for hook in hooks:
271 | hook.remove()
272 |
273 | grad_dict = dict()
274 | for idx, grad in enumerate(embedding_gradients):
275 | key = "grad_input_" + str(idx + 1)
276 | grad_dict[key] = grad.detach().cpu().numpy()
277 |
278 | # Restore original requires_grad values of the parameters
279 | for param_name, param in self.predictor._model.named_parameters():
280 | param.requires_grad = \
281 | original_param_name_to_requires_grad_dict[param_name]
282 |
283 | del dataset_tensor_dict
284 | torch.cuda.empty_cache()
285 | return grad_dict, outputs
286 |
287 | def _get_word_positions(self, predic_tok, editor_toks):
288 | """ Helper function to map from (sub)tokens of Predictor to
289 | token indices of Editor tokenizer. Assumes the tokens are in order.
290 | Raises MaskError if tokens cannot be mapped
291 | This sometimes happens due to inconsistencies in way text is
292 | tokenized by different tokenizers. """
293 |
294 | return_word_idx = None
295 | predic_tok_start = predic_tok.idx
296 | predic_tok_end = predic_tok.idx_end
297 |
298 | if predic_tok_start is None or predic_tok_end is None:
299 | return [], [], []
300 |
301 | class Found(Exception): pass
302 | try:
303 | for word_idx, word_token in reversed(list(enumerate(editor_toks))):
304 | if editor_toks[word_idx].idx is None:
305 | continue
306 |
307 | # Ensure predic_tok start >= start of last Editor tok
308 | if word_idx == len(editor_toks) - 1:
309 | if predic_tok_start >= word_token.idx:
310 | return_word_idx = word_idx
311 | raise Found
312 |
313 | # For all other Editor toks, ensure predic_tok start
314 | # >= Editor tok start and < next Editor tok start
315 | elif predic_tok_start >= word_token.idx:
316 | for cand_idx in range(word_idx + 1, len(editor_toks)):
317 | if editor_toks[cand_idx].idx is None:
318 | continue
319 | elif predic_tok_start < editor_toks[cand_idx].idx:
320 | return_word_idx = word_idx
321 | raise Found
322 | except Found:
323 | pass
324 |
325 | if return_word_idx is None:
326 | return [], [], []
327 |
328 | last_idx = return_word_idx
329 | if predic_tok_end > editor_toks[return_word_idx].idx_end:
330 | for next_idx in range(return_word_idx, len(editor_toks)):
331 | if editor_toks[next_idx].idx_end is None:
332 | continue
333 | if predic_tok_end <= editor_toks[next_idx].idx_end:
334 | last_idx = next_idx
335 | break
336 |
337 |             return_indices = []
338 |             return_starts = []
339 |             return_ends = []
340 | 
341 |             for cand_idx in range(return_word_idx, last_idx + 1):
342 |                 return_indices.append(cand_idx)
343 |                 return_starts.append(editor_toks[cand_idx].idx)
344 |                 return_ends.append(editor_toks[cand_idx].idx_end)
345 |             if not predic_tok_start >= editor_toks[return_word_idx].idx:
346 |                 raise MaskError
347 | 
348 |             # Sometimes BERT tokenizers add extra tokens if spaces at end
349 |             if last_idx != len(editor_toks)-1 and \
350 |                     predic_tok_end > editor_toks[last_idx].idx_end:
351 |                 raise MaskError
352 | 
353 |             return return_indices, return_starts, return_ends
354 |
355 | return_tuple = ([return_word_idx],
356 | [editor_toks[return_word_idx].idx],
357 | [editor_toks[return_word_idx].idx_end])
358 | return return_tuple
359 |
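The core idea of `_get_word_positions` is aligning two tokenizations of the same string through shared character offsets: a predictor subtoken maps to every editor token whose character span it touches. A simplified sketch using direct overlap tests instead of the source's ordered scan (the `map_subtoken_to_words` helper is illustrative):

```python
def map_subtoken_to_words(sub_start, sub_end, word_spans):
    """Given a predictor subtoken's character span and (start, end)
    character spans for editor tokens, return the indices of every
    editor token the subtoken overlaps."""
    hits = []
    for w_idx, (w_start, w_end) in enumerate(word_spans):
        if sub_start < w_end and sub_end > w_start:  # char ranges overlap
            hits.append(w_idx)
    return hits

word_spans = [(0, 5), (6, 11)]                    # e.g. "don't", "panic"
print(map_subtoken_to_words(3, 5, word_spans))    # [0]
print(map_subtoken_to_words(4, 8, word_spans))    # [0, 1]
```

The real method also has to handle tokens with `None` offsets (special tokens) and inconsistent whitespace handling between tokenizers, which is why it can raise `MaskError` instead of returning a match.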
360 | # Copied from AllenNLP integrated gradient
361 | def _integrated_register_forward_hook(self, alpha, embeddings_list):
362 | """ Helper function for integrated gradients """
363 |
364 | def forward_hook(module, inputs, output):
365 | if alpha == 0:
366 | embeddings_list.append(
367 | output.squeeze(0).clone().detach().cpu().numpy())
368 |
369 | output.mul_(alpha)
370 |
371 | embedding_layer = util.find_embedding_layer(self.predictor._model)
372 | handle = embedding_layer.register_forward_hook(forward_hook)
373 | return handle
374 |
375 | # Copied from AllenNLP integrated gradient
376 | def _get_integrated_gradients(self, instance, pred_idx, steps):
377 | """ Helper function for integrated gradients """
378 |
379 | ig_grads: Dict[str, Any] = {}
380 |
381 | # List of Embedding inputs
382 | embeddings_list: List[np.ndarray] = []
383 |
384 | # Exclude the endpoint because we do a left point integral approx
385 | for alpha in np.linspace(0, 1.0, num=steps, endpoint=False):
386 | # Hook for modifying embedding value
387 | handle = self._integrated_register_forward_hook(
388 | alpha, embeddings_list)
389 |
390 | grads = self._get_gradients_by_prob(instance, pred_idx)[0]
391 | handle.remove()
392 |
393 | # Running sum of gradients
394 | if ig_grads == {}:
395 | ig_grads = grads
396 | else:
397 | for key in grads.keys():
398 | ig_grads[key] += grads[key]
399 |
400 | # Average of each gradient term
401 | for key in ig_grads.keys():
402 | ig_grads[key] /= steps
403 |
404 | # Gradients come back in reverse order of order sent into the network
405 | embeddings_list.reverse()
406 |
407 | # Element-wise multiply average gradient by the input
408 | for idx, input_embedding in enumerate(embeddings_list):
409 | key = "grad_input_" + str(idx + 1)
410 | ig_grads[key] *= input_embedding
411 |
412 | return ig_grads
413 |
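The loop above is a left-Riemann approximation of the integrated-gradients path integral: average the gradient at inputs scaled by `alpha in [0, 1)`, then multiply by the input (baseline zero). A scalar sketch of the same computation, with a hypothetical `integrated_gradient_1d` helper:

```python
import numpy as np

def integrated_gradient_1d(f_grad, x, steps=10):
    """Left-endpoint Riemann approximation of integrated gradients for
    a scalar input with baseline 0: x * mean_k f'(alpha_k * x)."""
    total = 0.0
    # Endpoint excluded, matching the loop in _get_integrated_gradients
    for alpha in np.linspace(0.0, 1.0, num=steps, endpoint=False):
        total += f_grad(alpha * x)   # gradient at the scaled input
    avg_grad = total / steps
    return x * avg_grad              # multiply by (input - baseline)

# For f(x) = x**2, f'(x) = 2x; IG from 0 to x should approach f(x) - f(0).
approx = integrated_gradient_1d(lambda v: 2 * v, 3.0, steps=1000)
print(approx)  # ~9.0 (slightly under, from the left-endpoint bias)
```

The completeness property (attributions summing to `f(x) - f(baseline)`) is what makes integrated gradients attractive here: token scores directly account for the predictor's output.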
414 | def get_important_editor_tokens(
415 | self, editable_seg, pred_idx, editor_toks,
416 | labeled_instance=None,
417 | predic_tok_start_idx=None,
418 | predic_tok_end_idx=None,
419 | num_return_toks=None):
420 | """ Gets Editor tokens that correspond to Predictor toks
421 | with highest gradient values (with respect to pred_idx).
422 |
423 | editable_seg:
424 | Original inp to mask.
425 | pred_idx:
426 | Index of label (in Predictor label space) to take gradient of.
427 | editor_toks:
428 | Tokenized words using Editor tokenizer
429 | labeled_instance:
430 | Instance object for Predictor
431 | predic_tok_start_idx:
432 | Start index of Predictor tokens to consider masking.
433 | Helpful for when we only want to mask part of the input,
434 | as in RACE (only mask article). In this case, editable_seg
435 | will contain a subinp of the original input, but the
436 | labeled_instance used to get gradient values will correspond
437 | to the whole original input, and so predic_tok_start_idx
438 | is used to line up gradient values with tokens of editable_seg.
439 | predic_tok_end_idx:
440 | End index of Predictor tokens to consider masking.
441 | Similar to predic_tok_start_idx.
442 | num_return_toks: int
443 | If set to value k, return k Editor tokens that correspond to
444 | Predictor tokens with highest gradients.
445 | If not supplied, use self.mask_frac to calculate # tokens to return
446 | """
447 |
448 | integrated_grad_steps = self.num_integrated_grad_steps
449 |
450 | max_length = self.predictor._dataset_reader._tokenizer._max_length
451 |
452 | temp_tokenizer = self.predictor._dataset_reader._tokenizer
453 | all_predic_toks = temp_tokenizer.tokenize(editable_seg)
454 |
455 | # TODO: Does NOT work for RACE
456 | # If labeled_instance is not supplied, create one
457 | if labeled_instance is None:
458 | labeled_instance = self.predictor.json_to_labeled_instances(
459 | {"sentence": editable_seg})[0]
460 |
461 | grad_type_options = ["integrated_l1", "integrated_signed", "normal_l1",
462 | "normal_signed", "normal_l2", "integrated_l2"]
463 | if self.grad_type not in grad_type_options:
464 | raise ValueError("Invalid value for grad_type")
465 |
466 | # Grad_magnitudes is used for sorting; highest values ordered first.
467 | # -> For signed, to only mask most neg values, multiply by -1
468 |
469 | if self.grad_type == "integrated_l1":
470 | grads = self._get_integrated_gradients(
471 | labeled_instance, pred_idx, steps = integrated_grad_steps)
472 | grad = grads["grad_input_1"][0]
473 | grad_signed = np.sum(abs(grad), axis = 1)
474 | grad_magnitudes = grad_signed.copy()
475 |
476 | elif self.grad_type == "integrated_signed":
477 | grads = self._get_integrated_gradients(
478 | labeled_instance, pred_idx, steps = integrated_grad_steps)
479 | grad = grads["grad_input_1"][0]
480 | grad_signed = np.sum(grad, axis = 1)
481 | grad_magnitudes = self.sign_direction * grad_signed
482 |
483 | elif self.grad_type == "integrated_l2":
484 | grads = self._get_integrated_gradients(
485 | labeled_instance, pred_idx, steps = integrated_grad_steps)
486 | grad = grads["grad_input_1"][0]
487 | grad_signed = [g.dot(g) for g in grad]
488 | grad_magnitudes = grad_signed.copy()
489 |
490 | elif self.grad_type == "normal_l1":
491 | grads = self._get_gradients_by_prob(labeled_instance, pred_idx)[0]
492 | grad = grads["grad_input_1"][0]
493 | grad_signed = np.sum(abs(grad), axis = 1)
494 | grad_magnitudes = grad_signed.copy()
495 |
496 | elif self.grad_type == "normal_signed":
497 | grads = self._get_gradients_by_prob(labeled_instance, pred_idx)[0]
498 | grad = grads["grad_input_1"][0]
499 | grad_signed = np.sum(grad, axis = 1)
500 | grad_magnitudes = self.sign_direction * grad_signed
501 |
502 | elif self.grad_type == "normal_l2":
503 | grads = self._get_gradients_by_prob(labeled_instance, pred_idx)[0]
504 | grad = grads["grad_input_1"][0]
505 | grad_signed = [g.dot(g) for g in grad]
506 | grad_magnitudes = grad_signed.copy()
507 |
508 | # Include only gradient values for editable parts of the inp
509 | if predic_tok_end_idx is not None:
510 | if predic_tok_start_idx is not None:
511 | grad_magnitudes = grad_magnitudes[
512 | predic_tok_start_idx:predic_tok_end_idx]
513 | grad_signed = grad_signed[
514 | predic_tok_start_idx:predic_tok_end_idx]
515 | else:
516 | grad_magnitudes = grad_magnitudes[:predic_tok_end_idx]
517 | grad_signed = grad_signed[:predic_tok_end_idx]
518 |
519 | # Order Predictor tokens from largest to smallest gradient values
520 | ordered_predic_tok_indices = np.argsort(grad_magnitudes)[::-1]
521 |
522 |         # Lists of Editor token indices for each Predictor token, ordered by gradient
523 | ordered_word_indices_by_grad = [self._get_word_positions(
524 | all_predic_toks[idx], editor_toks)[0] \
525 | for idx in ordered_predic_tok_indices \
526 | if all_predic_toks[idx] not in self.predictor_special_toks]
527 | ordered_word_indices_by_grad = [item for sublist in \
528 | ordered_word_indices_by_grad for item in sublist]
529 |
530 | # Sanity checks
531 | if predic_tok_end_idx is not None:
532 | if predic_tok_start_idx is not None:
533 | assert(len(grad_magnitudes) == \
534 | predic_tok_end_idx - predic_tok_start_idx)
535 | else:
536 | assert(len(grad_magnitudes) == predic_tok_end_idx)
537 | elif max_length is not None and (len(grad_magnitudes)) >= max_length:
538 | assert(max_length == (len(grad_magnitudes)))
539 | else:
540 | assert(len(all_predic_toks) == (len(grad_magnitudes)))
541 |
542 | # Get num words to return
543 | if num_return_toks is None:
544 | num_return_toks = math.ceil(
545 | self.mask_frac * len(ordered_word_indices_by_grad))
546 | highest_editor_tok_indices = []
547 | for idx in ordered_word_indices_by_grad:
548 | if idx not in highest_editor_tok_indices:
549 | highest_editor_tok_indices.append(idx)
550 | if len(highest_editor_tok_indices) == num_return_toks:
551 | break
552 |
553 | highest_predic_tok_indices = ordered_predic_tok_indices[:num_return_toks]
554 | return highest_editor_tok_indices
555 |
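The selection at the end of `get_important_editor_tokens` is: sort predictor tokens by gradient magnitude (descending), expand each to its editor-token indices, then keep the first `num_return_toks` distinct indices in that order. A miniature of that pipeline (the `top_token_indices` helper and the subword-to-word map are illustrative):

```python
import numpy as np

def top_token_indices(grad_magnitudes, word_indices_per_tok, num_return_toks):
    """Order predictor tokens by gradient magnitude, map each to its
    editor-token indices, and take the first distinct indices."""
    order = np.argsort(grad_magnitudes)[::-1]        # largest first
    flat = [w for idx in order for w in word_indices_per_tok[idx]]
    chosen = []
    for idx in flat:
        if idx not in chosen:
            chosen.append(idx)
        if len(chosen) == num_return_toks:
            break
    return chosen

mags = np.array([0.1, 0.9, 0.4])
# Predictor token -> editor token indices (hypothetical alignment)
word_map = {0: [0], 1: [2, 3], 2: [2]}
print(top_token_indices(mags, word_map, 2))  # [2, 3]
```

Deduplicating while preserving order matters because several predictor subtokens can map to the same editor token; without it the budget of `num_return_toks` would be spent on repeats.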
556 | def _get_mask_indices(self, editable_seg, editor_toks, pred_idx, **kwargs):
557 | """ Helper function to get indices of Editor tokens to mask. """
558 |
559 | editor_mask_indices = self.get_important_editor_tokens(
560 | editable_seg, pred_idx, editor_toks, **kwargs)
561 | return editor_mask_indices
562 |
--------------------------------------------------------------------------------
/src/predictors/imdb/imdb_dataset_reader.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional
2 | import logging
3 |
4 | from allennlp.data import Tokenizer
5 | from overrides import overrides
6 | from nltk.tree import Tree
7 |
8 |
9 | from allennlp.common.file_utils import cached_path
10 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader
11 | from allennlp.data.fields import LabelField, TextField, Field
12 | from allennlp.data.instance import Instance
13 | from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
14 | from allennlp.data.tokenizers import Token
15 | from allennlp.data.tokenizers.spacy_tokenizer import SpacyTokenizer
16 | from allennlp.common.checks import ConfigurationError
17 |
18 | from pathlib import Path
19 | from itertools import chain
20 | import os.path as osp
21 | import tarfile
22 | import numpy as np
23 | import math
24 |
25 | from src.predictors.predictor_utils import clean_text
26 | logger = logging.getLogger(__name__)
27 |
28 | TRAIN_VAL_SPLIT_RATIO = 0.9
29 |
30 | def get_label(p):
31 | assert "pos" in p or "neg" in p
32 | return "1" if "pos" in p else "0"
33 |
34 | @DatasetReader.register("imdb")
35 | class ImdbDatasetReader(DatasetReader):
36 |
37 | TAR_URL = 'https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'
38 | TRAIN_DIR = 'aclImdb/train'
39 | TEST_DIR = 'aclImdb/test'
40 |
41 | def __init__(self,
42 | token_indexers: Dict[str, TokenIndexer] = None,
43 | tokenizer: Optional[Tokenizer] = None,
44 | **kwargs) -> None:
45 | super().__init__(**kwargs)
46 |
47 | self._tokenizer = tokenizer or SpacyTokenizer()
48 | self._token_indexers = token_indexers or \
49 | {"tokens": SingleIdTokenIndexer()}
50 |
51 | self.random_seed = 0 # numpy random seed
52 |
53 | def get_path(self, file_path):
54 | tar_path = cached_path(self.TAR_URL)
55 | tf = tarfile.open(tar_path, 'r')
56 | cache_dir = Path(osp.dirname(tar_path))
57 | if not (cache_dir / self.TRAIN_DIR).exists() and \
58 | not (cache_dir / self.TEST_DIR).exists():
59 | tf.extractall(cache_dir)
60 |
61 | if file_path == 'train':
62 | pos_dir = osp.join(self.TRAIN_DIR, 'pos')
63 | neg_dir = osp.join(self.TRAIN_DIR, 'neg')
64 | path = chain(
65 | Path(cache_dir.joinpath(pos_dir)).glob('*.txt'),
66 | Path(cache_dir.joinpath(neg_dir)).glob('*.txt'))
67 | elif file_path in ['train_split', 'dev_split']:
68 | pos_dir = osp.join(self.TRAIN_DIR, 'pos')
69 | neg_dir = osp.join(self.TRAIN_DIR, 'neg')
70 | path = chain(
71 | Path(cache_dir.joinpath(pos_dir)).glob('*.txt'),
72 | Path(cache_dir.joinpath(neg_dir)).glob('*.txt'))
73 | path_lst = list(path)
74 | np.random.shuffle(path_lst)
75 | num_train_strings = math.ceil(
76 | TRAIN_VAL_SPLIT_RATIO * len(path_lst))
77 |             train_path = path_lst[:num_train_strings]
78 | val_path = path_lst[num_train_strings:]
79 |             path = train_path if file_path == "train_split" else val_path
80 | elif file_path == 'test':
81 | pos_dir = osp.join(self.TEST_DIR, 'pos')
82 | neg_dir = osp.join(self.TEST_DIR, 'neg')
83 | path = chain(
84 | Path(cache_dir.joinpath(pos_dir)).glob('*.txt'),
85 | Path(cache_dir.joinpath(neg_dir)).glob('*.txt'))
86 | elif file_path == "unlabeled":
87 | unsup_dir = osp.join(self.TRAIN_DIR, 'unsup')
88 | path = chain(Path(cache_dir.joinpath(unsup_dir)).glob('*.txt'))
89 | else:
90 |             raise ValueError(f"Invalid option for file_path: {file_path}")
91 | return path
92 |
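The `train_split` / `dev_split` branch produces a deterministic 90/10 split: fix the NumPy seed, shuffle the path list, and slice at `ceil(0.9 * n)`. A self-contained sketch of that logic (the `split_paths` helper is illustrative):

```python
import math
import numpy as np

TRAIN_VAL_SPLIT_RATIO = 0.9

def split_paths(paths, seed=0):
    """Deterministically shuffle then split file paths 90/10, as
    get_path does for 'train_split' / 'dev_split'."""
    path_lst = list(paths)
    np.random.seed(seed)           # the reader fixes this seed, so the
    np.random.shuffle(path_lst)    # split is reproducible across runs
    n_train = math.ceil(TRAIN_VAL_SPLIT_RATIO * len(path_lst))
    return path_lst[:n_train], path_lst[n_train:]

train, val = split_paths([f"review_{i}.txt" for i in range(10)])
print(len(train), len(val))  # 9 1
```

Because both splits re-run the same seeded shuffle, `train_split` and `dev_split` are guaranteed disjoint and jointly cover the training directory.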
93 | def get_inputs(self, file_path, return_labels = False):
94 | np.random.seed(self.random_seed)
95 |
96 | path_lst = list(self.get_path(file_path))
97 | strings = [None] * len(path_lst)
98 | labels = [None] * len(path_lst)
99 | for i, p in enumerate(path_lst):
100 | labels[i] = get_label(str(p))
101 |             strings[i] = clean_text(p.read_text(),
102 |                     special_chars=["<br />", "\t"])
103 | if return_labels:
104 | return strings, labels
105 | return strings
106 |
107 | @overrides
108 | def _read(self, file_path):
109 | np.random.seed(self.random_seed)
110 | tar_path = cached_path(self.TAR_URL)
111 | tf = tarfile.open(tar_path, 'r')
112 | cache_dir = Path(osp.dirname(tar_path))
113 | if not (cache_dir / self.TRAIN_DIR).exists() and \
114 | not (cache_dir / self.TEST_DIR).exists():
115 | tf.extractall(cache_dir)
116 | path = self.get_path(file_path)
117 | for p in path:
118 | label = get_label(str(p))
119 | yield self.text_to_instance(
120 |                 clean_text(p.read_text(), special_chars=["<br />", "\t"]),
121 | label)
122 |
123 | def text_to_instance(
124 | self, string: str, label:str = None) -> Optional[Instance]:
125 | tokens = self._tokenizer.tokenize(string)
126 | text_field = TextField(tokens, token_indexers=self._token_indexers)
127 | fields: Dict[str, Field] = {"tokens": text_field}
128 | if label is not None:
129 | fields["label"] = LabelField(label)
130 | return Instance(fields)
131 |
--------------------------------------------------------------------------------
/src/predictors/imdb/imdb_roberta.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_reader": {
3 | "type": "imdb",
4 | "token_indexers": {
5 | "tokens": {
6 | "type": "pretrained_transformer",
7 | "max_length": 512,
8 | "model_name": "roberta-large"
9 | }
10 | },
11 | "tokenizer": {
12 | "type": "pretrained_transformer",
13 | "max_length": 512,
14 | "model_name": "roberta-large"
15 | }
16 | },
17 | "model": {
18 | "type": "basic_classifier",
19 | "namespace": "tags",
20 | "seq2vec_encoder": {
21 | "type": "bert_pooler",
22 | "dropout": 0.1,
23 | "pretrained_model": "roberta-large"
24 | },
25 | "text_field_embedder": {
26 | "token_embedders": {
27 | "tokens": {
28 | "type": "pretrained_transformer",
29 | "max_length": 512,
30 | "model_name": "roberta-large"
31 | }
32 | }
33 | }
34 | },
35 | "train_data_path": "train",
36 | "test_data_path": "test",
37 | "trainer": {
38 | "num_epochs": 5,
39 | "learning_rate_scheduler": {
40 | "type": "slanted_triangular",
41 | "cut_frac": 0.06
42 | },
43 | "optimizer": {
44 | "type": "huggingface_adamw",
45 | "lr": 2e-05,
46 | "weight_decay": 0.1
47 | }
48 | },
49 | "evaluate_on_test": true,
50 | "data_loader": {
51 | "batch_sampler": {
52 | "type": "bucket",
53 | "batch_size": 8,
54 | "sorting_keys": [
55 | "tokens"
56 | ]
57 | }
58 | }
59 | }
60 |
--------------------------------------------------------------------------------
/src/predictors/newsgroups/newsgroups_dataset_reader.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional
2 | import logging
3 |
4 | from allennlp.data import Tokenizer
5 | from overrides import overrides
6 | from nltk.tree import Tree
7 |
8 |
9 | from allennlp.common.file_utils import cached_path
10 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader
11 | from allennlp.data.fields import LabelField, TextField, Field
12 | from allennlp.data.instance import Instance
13 | from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
14 | from allennlp.data.tokenizers import Token
15 | from allennlp.data.tokenizers.spacy_tokenizer import SpacyTokenizer
16 | from allennlp.common.checks import ConfigurationError
17 |
18 | from pathlib import Path
19 | from itertools import chain
20 | import os.path as osp
21 | import tarfile
22 | from tqdm import tqdm as tqdm
23 | import numpy as np
24 | import math
25 | from sklearn.datasets import fetch_20newsgroups
26 |
27 | from src.predictors.predictor_utils import clean_text
28 |
29 | logger = logging.getLogger(__name__)
30 |
31 | TRAIN_VAL_SPLIT_RATIO = 0.9
32 |
33 | @DatasetReader.register("newsgroups")
34 | class NewsgroupsDatasetReader(DatasetReader):
35 | def __init__(self,
36 | token_indexers: Dict[str, TokenIndexer] = None,
37 | tokenizer: Optional[Tokenizer] = None,
38 | **kwargs) -> None:
39 | super().__init__(**kwargs)
40 |
41 | self._tokenizer = tokenizer or SpacyTokenizer()
42 | self._token_indexers = token_indexers or \
43 | {"tokens": SingleIdTokenIndexer()}
44 |
45 | self.random_seed = 0
46 | np.random.seed(self.random_seed)
47 |
48 | def get_data_indices(self, subset):
49 | np.random.seed(self.random_seed)
50 | if subset in ['train', 'test']:
51 | newsgroups_data = fetch_20newsgroups(
52 | subset=subset, remove=("headers", "footers", "quotes"))
53 | data_indices = np.array(range(len(newsgroups_data.data)))
54 | elif subset in ['train_split', 'dev_split']:
55 | newsgroups_data = fetch_20newsgroups(
56 | subset='train', remove=("headers", "footers", "quotes"))
57 | data_indices = np.array(range(len(newsgroups_data.data)))
58 | np.random.shuffle(data_indices)
59 | num_train = math.ceil(TRAIN_VAL_SPLIT_RATIO * len(data_indices))
60 | train_indices = data_indices[:num_train]
61 | val_indices = data_indices[num_train:]
62 | data_indices = train_indices if subset == 'train_split' else val_indices
63 | else:
64 | raise ValueError("Invalid value for subset")
65 | return data_indices, newsgroups_data
66 |
67 | def get_inputs(self, subset, return_labels = False):
68 | np.random.seed(self.random_seed)
69 | data_indices, newsgroups_data = self.get_data_indices(subset)
70 | strings = [None] * len(data_indices)
71 | labels = [None] * len(data_indices)
72 | for i, idx in enumerate(data_indices):
73 | txt = newsgroups_data.data[idx]
74 | topic = newsgroups_data.target[idx]
75 | label = newsgroups_data['target_names'][topic].split(".")[0]
76 | txt = clean_text(txt, special_chars=["\n", "\t"])
77 | if len(txt) == 0 or len(label) == 0:
78 | strings[i] = None
79 | labels[i] = None
80 | else:
81 | strings[i] = txt
82 | labels[i] = label
83 |
84 | strings = [x for x in strings if x is not None]
85 | labels = [x for x in labels if x is not None]
86 | assert len(strings) == len(labels)
87 |
88 | if return_labels:
89 | return strings, labels
90 | return strings
91 |
92 | @overrides
93 | def _read(self, subset):
94 | np.random.seed(self.random_seed)
95 | data_indices, newsgroups_data = self.get_data_indices(subset)
96 | for idx in data_indices:
97 | txt = newsgroups_data.data[idx]
98 | topic = newsgroups_data.target[idx]
99 | label = newsgroups_data['target_names'][topic].split(".")[0]
100 | txt = clean_text(txt, special_chars=["\n", "\t"])
101 | if len(txt) == 0 or len(label) == 0:
102 | continue
103 | yield self.text_to_instance(txt, label)
104 |
105 | def text_to_instance(
106 | self, string: str, label:str = None) -> Optional[Instance]:
107 | tokens = self._tokenizer.tokenize(string)
108 | text_field = TextField(tokens, token_indexers=self._token_indexers)
109 | fields: Dict[str, Field] = {"tokens": text_field}
110 | if label is not None:
111 | fields["label"] = LabelField(label)
112 | return Instance(fields)
113 |
114 |
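The deterministic train/dev split in `get_data_indices` above can be checked in isolation. The sketch below is a self-contained stand-in (using the stdlib `random` module in place of the reader's `np.random` calls, so the exact permutation differs, but the seeding and 90/10 partition logic are the same idea):

```python
import math
import random

TRAIN_VAL_SPLIT_RATIO = 0.9

def split_indices(n, seed=0):
    # Deterministic shuffle, then a 90/10 train/dev partition,
    # mirroring the logic in NewsgroupsDatasetReader.get_data_indices.
    rng = random.Random(seed)
    indices = list(range(n))
    rng.shuffle(indices)
    num_train = math.ceil(TRAIN_VAL_SPLIT_RATIO * n)
    return indices[:num_train], indices[num_train:]

train_idx, dev_idx = split_indices(100)
print(len(train_idx), len(dev_idx))  # 90 10
```

Because the seed is fixed, repeated calls reproduce the same partition, which is what lets `train_split` and `dev_split` stay consistent across runs.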
--------------------------------------------------------------------------------
/src/predictors/newsgroups/newsgroups_roberta.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_reader": {
3 | "type": "newsgroups",
4 | "token_indexers": {
5 | "tokens": {
6 | "type": "pretrained_transformer",
7 | "max_length": 512,
8 | "model_name": "roberta-large"
9 | }
10 | },
11 | "tokenizer": {
12 | "type": "pretrained_transformer",
13 | "max_length": 512,
14 | "model_name": "roberta-large"
15 | }
16 | },
17 | "model": {
18 | "type": "basic_classifier",
19 | "namespace": "tags",
20 | "seq2vec_encoder": {
21 | "type": "bert_pooler",
22 | "dropout": 0.1,
23 | "pretrained_model": "roberta-large"
24 | },
25 | "text_field_embedder": {
26 | "token_embedders": {
27 | "tokens": {
28 | "type": "pretrained_transformer",
29 | "max_length": 512,
30 | "model_name": "roberta-large"
31 | }
32 | }
33 | }
34 | },
35 | "train_data_path": "train",
36 | "test_data_path": "test",
37 | "trainer": {
38 | "num_epochs": 5,
39 | "optimizer": {
40 | "type": "huggingface_adamw",
41 | "lr": 2e-05,
42 | "weight_decay": 0.1
43 | },
44 | "learning_rate_scheduler": {
45 | "type": "slanted_triangular",
46 | "cut_frac": 0.06
47 | }
48 | },
49 | "evaluate_on_test": true,
50 | "data_loader": {
51 | "batch_sampler": {
52 | "type": "bucket",
53 | "batch_size": 8,
54 | "sorting_keys": [
55 | "tokens"
56 | ]
57 | }
58 | }
59 | }
60 |
--------------------------------------------------------------------------------
/src/predictors/predictor_utils.py:
--------------------------------------------------------------------------------
1 | def clean_text(text, special_chars=["\n", "\t"]):
2 | for char in special_chars:
3 | text = text.replace(char, " ")
4 | return text
5 |
6 |
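The helper above is small enough to sanity-check directly. The snippet below is a standalone copy (not an import from this repo) showing its behavior on a string containing both default special characters:

```python
def clean_text(text, special_chars=["\n", "\t"]):
    # Replace each special character with a single space.
    for char in special_chars:
        text = text.replace(char, " ")
    return text

print(clean_text("movie\twas\ngreat"))  # movie was great
```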
--------------------------------------------------------------------------------
/src/predictors/race/race_dataset_reader.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from allennlp.data import DatasetReader
4 | from typing import List, Optional
5 | from allennlp.data import DatasetReader, Instance
6 |
7 | from overrides import overrides
8 |
9 | from allennlp_models.mc.dataset_readers.transformer_mc import TransformerMCReader
10 |
11 | from pathlib import Path
12 | from itertools import chain
13 | import os.path as osp
14 | import tarfile
15 | from tqdm import tqdm as tqdm
16 | import json
17 |
18 | from src.predictors.predictor_utils import clean_text
19 | import os
20 |
21 | logger = logging.getLogger(__name__)
22 | FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
23 | logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"), format=FORMAT)
24 | logger.setLevel(logging.INFO)
25 |
26 | @DatasetReader.register("race")
27 | class RaceDatasetReader(DatasetReader):
28 |
29 | def __init__(
30 | self,
31 | transformer_model_name: str = "roberta-large",
32 | qa_length_limit: int = 128,
33 | total_length_limit: int = 512,
34 | data_dir = "data/RACE",
35 | answer_mapping = {"A": 0, "B": 1, "C": 2, "D": 3},
36 | **kwargs
37 | ) -> None:
38 | super().__init__(**kwargs)
39 | from allennlp.data.tokenizers import PretrainedTransformerTokenizer
40 |
41 | self._tokenizer = PretrainedTransformerTokenizer(
42 | transformer_model_name, add_special_tokens=False,
43 | max_length=total_length_limit
44 | )
45 | from allennlp.data.token_indexers import PretrainedTransformerIndexer
46 |
47 | self._token_indexers = {"tokens": PretrainedTransformerIndexer(
48 | transformer_model_name)}
49 | self.qa_length_limit = qa_length_limit
50 | self.total_length_limit = total_length_limit
51 | self.data_dir = data_dir
52 |
53 | self.train_dir = os.path.join(self.data_dir, "train")
54 | self.dev_dir = os.path.join(self.data_dir, "dev")
55 | self.test_dir = os.path.join(self.data_dir, "test")
56 |
57 | self.answer_mapping = answer_mapping
58 |
59 | def get_path(self, file_path):
60 | if file_path == "train":
61 | high_dir = osp.join(self.train_dir, "high")
62 | middle_dir = osp.join(self.train_dir, "middle")
63 | elif file_path == "test":
64 | high_dir = osp.join(self.test_dir, "high")
65 | middle_dir = osp.join(self.test_dir, "middle")
66 | elif file_path == "dev":
67 | high_dir = osp.join(self.dev_dir, "high")
68 | middle_dir = osp.join(self.dev_dir, "middle")
69 | else:
70 | raise ValueError("Invalid value for file_path")
71 |
72 | path = chain(
73 | Path(high_dir).glob('*.txt'),
74 | Path(middle_dir).glob('*.txt'))
75 | return path
76 |
77 | @overrides
78 | def _read(self, file_path: str):
79 | path = self.get_path(file_path)
80 |
81 | for p in path:
82 | data = json.loads(p.read_text())
83 | iterator = enumerate(zip(data["answers"],
84 | data["options"],
85 | data["questions"]))
86 | for idx, (answer, options, question) in iterator:
87 | qid = str(data["id"][:-4] + "_" + str(idx))
88 | yield self.text_to_instance(qid, clean_text(data["article"]),
89 | question, options, self.answer_mapping[answer])[0]
90 |
91 | def get_inputs(self, file_path: str):
92 | logger.info(f"Getting RACE task inputs from file path: {file_path}")
93 | path = self.get_path(file_path)
94 | inputs = []
95 |
96 | for p in path:
97 | data = json.loads(p.read_text())
98 | iterator = enumerate(zip(data["answers"],
99 | data["options"],
100 | data["questions"]))
101 | for idx, (answer, options, question) in iterator:
102 | qid = str(data["id"][:-4] + "_" + str(idx))
103 | inp = {"id": qid,
104 | "article": clean_text(data["article"]),
105 | "question": question,
106 | "options": options,
107 | "answer_idx": self.answer_mapping[answer]}
108 | inputs.append(inp)
109 | return inputs
110 |
111 | @overrides
112 | def text_to_instance(
113 | self, # type: ignore
114 | qid: str,
115 | article: str,
116 | question: str,
117 | alternatives: List[str],
118 | label: Optional[int] = None,
119 | ) -> Instance:
120 | # tokenize
121 | article = self._tokenizer.tokenize(article)
122 | question = self._tokenizer.tokenize(question)
123 |
124 | # FORMAT: article, special tokens, question, special tokens, option
125 | sequences = []
126 | article_lengths = []
127 | max_article_lengths = []
128 |
129 | for alternative in alternatives:
130 | alternative = self._tokenizer.tokenize(alternative)
131 | qa_pair = self._tokenizer.add_special_tokens(
132 | question, alternative)[:self.qa_length_limit]
133 | length_for_article = self.total_length_limit - len(qa_pair) - \
134 | self._tokenizer.num_special_tokens_for_pair()
135 | sequence = self._tokenizer.add_special_tokens(
136 | article[:length_for_article], qa_pair)
137 | if len(sequence) > self.total_length_limit:
138 | print(len(sequence))
139 | assert len(sequence) <= self.total_length_limit
140 | sequences.append(sequence)
141 | article_lengths.append(len(article[:length_for_article]))
142 | max_article_lengths.append(length_for_article)
143 |
144 | # make fields
145 | from allennlp.data.fields import TextField
146 |
147 | sequences = [TextField(seq, self._token_indexers) for seq in sequences]
148 | from allennlp.data.fields import ListField
149 |
150 | sequences = ListField(sequences)
151 |
152 | from allennlp.data.fields import MetadataField
153 |
154 | fields = {
155 | "alternatives": sequences,
156 | "qid": MetadataField(qid),
157 | }
158 |
159 | if label is not None:
160 | if label < 0 or label >= len(sequences):
161 | raise ValueError(f"Alternative {label} does not exist")
162 | from allennlp.data.fields import IndexField
163 |
164 | fields["correct_alternative"] = IndexField(label, sequences)
165 |
166 | return Instance(fields), article_lengths, max_article_lengths
167 |
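The truncation arithmetic in `text_to_instance` above is easy to verify by hand: the article gets whatever token budget remains after the question/answer pair and the pair's special tokens. A minimal sketch (the token counts below are made-up illustrative numbers, not real tokenizer output):

```python
def article_budget(total_length_limit, qa_pair_len, num_special_tokens_for_pair):
    # Room left for article tokens after the QA pair and separator tokens,
    # mirroring the length_for_article computation in text_to_instance.
    return total_length_limit - qa_pair_len - num_special_tokens_for_pair

# e.g. a 512-token limit, a 40-token question+option pair, 4 special tokens
print(article_budget(512, 40, 4))  # 468
```

Truncating the article to this budget is what guarantees the `len(sequence) <= self.total_length_limit` assertion holds.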
--------------------------------------------------------------------------------
/src/predictors/race/race_roberta.json:
--------------------------------------------------------------------------------
1 | local transformer_model = "roberta-large";
2 |
3 | local epochs = 3;
4 |
5 | local gpu_batch_size = 4;
6 | local gradient_accumulation_steps = 16;
7 |
8 | {
9 | "dataset_reader": {
10 | "type": "race",
11 | "transformer_model_name": transformer_model,
12 | },
13 | "train_data_path": "train",
14 | "validation_data_path": "dev",
15 | "model": {
16 | "type": "transformer_mc",
17 | "transformer_model": transformer_model,
18 | },
19 | "data_loader": {
20 | "sampler": "random",
21 | "batch_size": gpu_batch_size
22 | },
23 | "trainer": {
24 | "optimizer": {
25 | "type": "huggingface_adamw",
26 | "weight_decay": 0.01,
27 | "parameter_groups": [[["bias", "LayerNorm\\.weight", "layer_norm\\.weight"], {"weight_decay": 0}]],
28 | "lr": 1e-5,
29 | "eps": 1e-8,
30 | "correct_bias": true
31 | },
32 | "learning_rate_scheduler": {
33 | "type": "linear_with_warmup",
34 | "warmup_steps": 100
35 | },
36 | // "grad_norm": 1.0,
37 | "num_epochs": epochs,
38 | "num_gradient_accumulation_steps": gradient_accumulation_steps,
39 | "patience": 3,
40 | "validation_metric": "+acc",
41 | "tensorboard_writer": {
42 | "summary_interval": 10,
43 | "should_log_learning_rate": true
44 | },
45 | },
46 | "random_seed": 42,
47 | "numpy_seed": 42,
48 | "pytorch_seed": 42,
49 | }
50 |
--------------------------------------------------------------------------------
/src/stage_one.py:
--------------------------------------------------------------------------------
1 | from transformers import T5Tokenizer, T5ForConditionalGeneration
2 | from transformers import T5Config, T5TokenizerFast
3 | from allennlp.data.tokenizers import PretrainedTransformerTokenizer
4 | from torch.utils.data import Dataset, DataLoader, RandomSampler
5 | import torch
6 | import more_itertools as mit
7 | import math
8 | import numpy as np
9 | import pandas as pd
10 | import os
11 | from tqdm import tqdm
12 | from types import SimpleNamespace
13 | import logging
14 | import json
15 | import sys
16 |
17 | # Local imports
18 | from src.masker import Masker, RandomMasker, GradientMasker
19 | from src.dataset import StageOneDataset, RaceStageOneDataset
20 | from src.utils import *
21 |
22 | logger = logging.getLogger(__name__)
23 | FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
24 | logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"), format=FORMAT)
25 |
26 | def train_epoch(epoch, editor_tokenizer, editor_model, device, loader, optim):
27 | """ Runs training for epoch """
28 |
29 | editor_model.train()
30 | total_loss = 0
31 | logger.info(f"Training epoch: {epoch}")
32 |
33 | for _, data in tqdm(enumerate(loader, 0), total = len(loader)):
34 | lm_labels = data['target_ids'].to(device, dtype = torch.long)
35 | lm_labels[lm_labels[:, :] == editor_tokenizer.pad_token_id] = -100
36 | ids = data['source_ids'].to(device, dtype = torch.long)
37 | outputs = editor_model(input_ids = ids, labels=lm_labels)
38 | loss = outputs[0]
39 | total_loss += loss.item()
40 |
41 | optim.zero_grad()
42 | loss.backward()
43 | optim.step()
44 | del lm_labels
45 | del ids
46 | torch.cuda.empty_cache()
47 |
48 | logger.info(f'Epoch: {epoch}, Avg Batch Loss: {total_loss/len(loader)}')
49 | return total_loss
50 |
51 | def validate_epoch(epoch, editor_tokenizer, editor_model, device, loader):
52 | """ Runs validation for epoch """
53 |
54 | editor_model.eval()
55 | total_loss = 0
56 | logger.info(f"Validating epoch: {epoch}")
57 |
58 | for _, data in tqdm(enumerate(loader, 0), total = len(loader)):
59 | lm_labels = data['target_ids'].to(device, dtype = torch.long)
60 | lm_labels[lm_labels[:, :] == editor_tokenizer.pad_token_id] = -100
61 | ids = data['source_ids'].to(device, dtype = torch.long)
62 |
63 | outputs = editor_model(input_ids = ids, labels=lm_labels)
64 | loss = outputs[0]
65 | total_loss += loss.item()
66 |
67 | del lm_labels
68 | del ids
69 | torch.cuda.empty_cache()
70 |
71 | logger.info(f'Epoch: {epoch}, Avg Batch Loss: {total_loss/len(loader)}')
72 | return total_loss
73 |
74 | def get_datasets(predictor, dr, masker, data_dir, train_inputs, val_inputs,
75 | train_labels, val_labels, editor_tokenizer, args):
76 | """ Writes data for Editor fine-tuning """
77 |
78 | train_data_path = os.path.join(data_dir, "train_data.csv")
79 | val_data_path = os.path.join(data_dir, "val_data.csv")
80 |
81 | # If data already exists for experiment, read data
82 | if os.path.exists(train_data_path) and os.path.exists(val_data_path):
83 | logger.info("Data for Editor fine-tuning already exist.")
84 | logger.info(f"Loading train data from: {train_data_path}")
85 | logger.info(f"Loading val data from: {val_data_path}")
86 |
87 | train_csv = pd.read_csv(train_data_path, sep="\t")
88 | val_csv = pd.read_csv(val_data_path, sep="\t")
89 |
90 | train_dataset = StageOneDataset(editor_tokenizer,
91 | max_length=args.model.model_max_length,
92 | masked_strings=train_csv['inputs'],
93 | targets=train_csv['targets'])
94 | val_dataset = StageOneDataset(editor_tokenizer,
95 | max_length=args.model.model_max_length,
96 | masked_strings=val_csv['inputs'],
97 | targets=val_csv['targets'])
98 |
99 | # Else, create data by calling create_inputs() function in dataset.py
100 | else:
101 | logger.info("Creating masked data for Editor fine-tuning...")
102 | logger.info(f"Target label (options are 'pred' or 'gold'): " + \
103 | f"{args.misc.target_label}")
104 | # For RACE, pass dr to create_inputs() to correctly truncate
105 | if args.meta.task == "race":
106 | train_dataset = RaceStageOneDataset(editor_tokenizer,
107 | max_length=args.model.model_max_length)
108 | train_dataset.create_inputs(dr, train_inputs, train_labels,
109 | predictor, masker, target_label=args.misc.target_label)
110 | val_dataset = RaceStageOneDataset(editor_tokenizer,
111 | max_length=args.model.model_max_length)
112 | val_dataset.create_inputs(dr, val_inputs, val_labels, predictor,
113 | masker, target_label=args.misc.target_label)
114 | else:
115 | train_dataset = StageOneDataset(editor_tokenizer,
116 | max_length=args.model.model_max_length)
117 | val_dataset = StageOneDataset(editor_tokenizer,
118 | max_length=args.model.model_max_length)
119 | train_dataset.create_inputs(train_inputs, train_labels, predictor,
120 | masker, target_label=args.misc.target_label)
121 | val_dataset.create_inputs(val_inputs, val_labels, predictor,
122 | masker, target_label=args.misc.target_label)
123 | logger.info("Done creating data.")
124 |
125 | # Write data
126 | logger.info(f"Writing train data to: {train_data_path}")
127 | train_masked_df = pd.DataFrame({
128 | 'inputs':train_dataset.masked_strings,
129 | 'targets':train_dataset.targets})
130 | train_masked_df.to_csv(train_data_path, sep="\t")
131 | logger.info(f"Writing val data to: {val_data_path}")
132 | val_masked_df = pd.DataFrame({
133 | 'inputs':val_dataset.masked_strings,
134 | 'targets':val_dataset.targets})
135 | val_masked_df.to_csv(val_data_path, sep="\t")
136 |
137 | return train_dataset, val_dataset
138 |
139 | def get_stage_one_masker(args, predictor):
140 | """ Helper function for loading appropriate masker, random or grad """
141 |
142 | logger.info(f"Creating masker of type: {args.mask.mask_type}")
143 | editor_tokenizer_wrapper = PretrainedTransformerTokenizer(
144 | "t5-base", max_length=args.model.model_max_length)
145 | if args.mask.mask_type == "random":
146 | logger.info("Loading Random masker...")
147 | masker = RandomMasker(None, editor_tokenizer_wrapper,
148 | args.model.model_max_length)
149 | elif args.mask.mask_type == "grad":
150 | logger.info("Loading Gradient Masker...")
151 | # In stage 1, if signed gradients, mask tokens pushing *towards* target
152 | sign_direction = 1 if "signed" in args.mask.grad_type else None
153 | masker = GradientMasker(None, editor_tokenizer_wrapper, predictor,
154 | args.model.model_max_length, grad_type=args.mask.grad_type,
155 | sign_direction=sign_direction)
156 | logger.info("Done.")
157 | return masker
158 |
159 | def get_task_data(args, dr):
160 | """ Helper function for loading original data of task.
161 | Calls get_inputs() function of dataset reader dr """
162 |
163 | if args.meta.task == 'race':
164 | strings = dr.get_inputs('train')
165 | labels = [int(s['answer_idx']) for s in strings]
166 | elif args.meta.task == "newsgroups" or args.meta.task == "imdb":
167 | strings, labels = dr.get_inputs('train', return_labels=True)
168 |
169 | string_indices = np.array(range(len(strings)))
170 | np.random.shuffle(string_indices)
171 | num_train = math.ceil(args.train.data_split_ratio * len(strings))
172 | train_string_indices, val_string_indices = \
173 | string_indices[:num_train], string_indices[num_train:]
174 | train_inputs = [strings[idx] for idx in train_string_indices]
175 | train_labels = [labels[idx] for idx in train_string_indices]
176 | val_inputs = [strings[idx] for idx in val_string_indices]
177 | val_labels = [labels[idx] for idx in val_string_indices]
178 |
179 | logger.info(f"Num train for Editor fine-tuning: {len(train_inputs)}")
180 | logger.info(f"Num val for Editor fine-tuning: {len(val_inputs)}")
181 |
182 | return train_inputs, val_inputs, train_labels, val_labels
183 |
184 | def run_train_editor(predictor, dr, args):
185 | """ Runs Editor training """
186 |
187 | # Set random seeds for reproducibility
188 | torch.manual_seed(args.train.seed)
189 | np.random.seed(args.train.seed)
190 | torch.backends.cudnn.deterministic = True
191 |
192 | editor_tokenizer, editor_model = load_base_t5(
193 | max_length=args.model.model_max_length)
194 | device = get_device()
195 | editor_model = editor_model.to(device)
196 |
197 | task_dir = os.path.join(args.meta.results_dir, args.meta.task)
198 | stage_one_dir = os.path.join(task_dir, f"editors/{args.meta.stage1_exp}")
199 | data_dir = os.path.join(stage_one_dir, 'editor_train_data')
200 | checkpoint_dir = os.path.join(stage_one_dir, 'checkpoints')
201 |
202 | logger.info(f"Task dir: {task_dir}")
203 | logger.info(f"Stage one dir: {stage_one_dir}")
204 | logger.info(f"Stage one training data dir: {data_dir}")
205 | logger.info(f"Checkpoints dir: {checkpoint_dir}")
206 |
207 | for dir in [task_dir, data_dir, stage_one_dir, checkpoint_dir]:
208 | if not os.path.exists(dir):
209 | os.makedirs(dir)
210 |
211 | # Save args
212 | args_path = os.path.join(stage_one_dir, "stage_one_args.json")
213 | write_args(args_path, args)
214 |
215 | masker = get_stage_one_masker(args, predictor)
216 |
217 | # Defining the parameters for creation of dataloaders
218 | train_params = {
219 | 'batch_size': args.train.train_batch_size,
220 | 'shuffle': True,
221 | 'num_workers': 0
222 | }
223 |
224 | val_params = {
225 | 'batch_size': args.train.val_batch_size,
226 | 'shuffle': False,
227 | 'num_workers': 0
228 | }
229 |
230 | optim = torch.optim.Adam(params=editor_model.parameters(), \
231 | lr=args.train.lr)
232 |
233 | # Load original task data
234 | train_inputs, val_inputs, train_labels, val_labels = \
235 | get_task_data(args, dr)
236 |
237 | # Get datasets for Editor training
238 | train_dataset, val_dataset = get_datasets(predictor, dr, masker,
239 | data_dir, train_inputs, val_inputs, train_labels, val_labels,
240 | editor_tokenizer, args)
241 | train_data_loader = DataLoader(train_dataset, **train_params)
242 | val_data_loader = DataLoader(val_dataset, **val_params)
243 |
244 | # Training loop
245 | logger.info('Initiating Editor Fine-Tuning.')
246 |
247 | best_path = os.path.join(checkpoint_dir, 'best.pth')
248 | best_val_loss = 1000000
249 | for epoch in range(args.train.num_epochs):
250 | path = os.path.join(checkpoint_dir, f"{epoch}.pth")
251 | if os.path.exists(path):
252 | logger.info(f"Found checkpoint for epoch. Loading from: {path}")
253 | editor_model.load_state_dict(torch.load(path))
254 | else:
255 | train_loss = train_epoch(epoch, editor_tokenizer, editor_model,
256 | device, train_data_loader, optim)
257 | logger.info("Saving Editor checkpoint to: " + path)
258 | torch.save(editor_model.state_dict(), path)
259 |
260 | val_loss = validate_epoch(epoch, editor_tokenizer, editor_model,
261 | device, val_data_loader)
262 | if val_loss < best_val_loss:
263 | best_val_loss = val_loss
264 | logger.info(f"Lowest loss. Saving weights to: {best_path}")
265 | torch.save(editor_model.state_dict(), best_path)
266 |
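The `lm_labels` handling in `train_epoch`/`validate_epoch` above relies on a HuggingFace convention: label positions set to -100 are ignored by the cross-entropy loss, so padding does not contribute to training. A minimal pure-Python sketch of that masking step (`pad_token_id = 0` matches T5, but treat it as an assumption here; the real code does the same replacement with tensor indexing):

```python
pad_token_id = 0  # T5's pad id; treated as an assumption here

target_ids = [[37, 564, 1, 0, 0],
              [19, 1, 0, 0, 0]]

# Mirror the masking line in train_epoch: positions equal to the pad id
# become -100 so the loss function skips them.
lm_labels = [[-100 if t == pad_token_id else t for t in row]
             for row in target_ids]
print(lm_labels[0])  # [37, 564, 1, -100, -100]
```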
--------------------------------------------------------------------------------
/src/stage_two.py:
--------------------------------------------------------------------------------
1 | from transformers import T5Tokenizer, T5Model, T5Config
2 | from transformers import T5ForConditionalGeneration, T5TokenizerFast
3 | from allennlp.predictors import Predictor, TextClassifierPredictor
4 | from allennlp_models.classification import StanfordSentimentTreeBankDatasetReader
5 | from allennlp.data.tokenizers import PretrainedTransformerTokenizer
6 | from allennlp.data.tokenizers.spacy_tokenizer import SpacyTokenizer
7 |
8 | import torch
9 | import os
10 | import csv
11 | import heapq
12 | import sys
13 | import operator
14 | from tqdm import tqdm
15 | import re
16 | import nltk
17 | import warnings
18 | import argparse
19 | import pandas as pd
20 | import numpy as np
21 | import random
22 | import time
23 | import logging
24 | import json
25 |
26 | # Local imports
27 | from src.utils import *
28 | from src.edit_finder import EditFinder, EditEvaluator, EditList
29 | from src.editor import Editor, RaceEditor
30 | from src.predictors.imdb.imdb_dataset_reader import ImdbDatasetReader
31 | from src.predictors.newsgroups.newsgroups_dataset_reader import NewsgroupsDatasetReader
32 | from src.predictors.race.race_dataset_reader import RaceDatasetReader
33 |
34 | logger = logging.getLogger("my-logger")
35 | FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
36 | logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"), format=FORMAT)
37 | logger.setLevel(logging.INFO)
38 |
39 | def get_grad_sign_direction(grad_type, grad_pred):
40 | """ Helper function to get sign direction. When grad_type is signed,
41 | determines whether to get most negative or positive gradient values.
42 | This should depend on grad_pred, i.e. what label is being used to
43 | compute gradients for masking.
44 |
45 | During Stage Two, we want to mask tokens that push *away* from the contrast
46 | label or *towards* the original label.
47 |
48 | Sign direction plays no role if only gradient *magnitudes* are used
49 | (i.e. if grad_type is not signed, but involves taking the l1/l2 norm.)
50 | """
51 | assert grad_pred in ["contrast", "original"]
52 | assert grad_type in ["integrated_l1", "integrated_signed", "normal_l1",
53 | "normal_signed", "normal_l2", "integrated_l2"]
54 | if "signed" in grad_type and "contrast" in grad_pred:
55 | sign_direction = -1
56 | elif "signed" in grad_type and "original" in grad_pred:
57 | sign_direction = 1
58 | else:
59 | sign_direction = None
60 | return sign_direction
61 |
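The docstring above reduces to a small truth table. This standalone sketch reproduces the mapping so it can be checked directly (it simplifies the `"contrast" in grad_pred` substring checks to equality, which the assertion on `grad_pred` already guarantees):

```python
def get_grad_sign_direction(grad_type, grad_pred):
    # Signed gradients + contrast label -> most negative values (-1);
    # signed gradients + original label -> most positive values (+1);
    # magnitude-only grad types (l1/l2) ignore sign entirely (None).
    if "signed" in grad_type and grad_pred == "contrast":
        return -1
    if "signed" in grad_type and grad_pred == "original":
        return 1
    return None

print(get_grad_sign_direction("normal_signed", "contrast"))  # -1
```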
62 | def load_editor_weights(editor_model, editor_path):
63 | """ Loads Editor weights from editor_path """
64 |
65 | if os.path.isdir(editor_path):
66 | editor_path = os.path.join(editor_path, "best.pth")
67 | if not os.path.exists(editor_path):
68 | raise ValueError("If a directory is given for editor_path, \
69 | it must contain a 'best.pth' file, but none was found in the \
70 | given dir. Please give a direct path to the file containing weights.")
71 | logger.info(f"Loading Editor weights from: {editor_path}")
72 | editor_model.load_state_dict(torch.load(editor_path))
73 | return editor_model
74 |
75 | def load_models(args):
76 | """ Loads Predictor and Editor by task and other args """
77 |
78 | logger.info("Loading models...")
79 | predictor = load_predictor(args.meta.task)
80 | editor_tokenizer_wrapper = PretrainedTransformerTokenizer(
81 | 't5-base', max_length=args.model.model_max_length)
82 | editor_tokenizer, editor_model = load_base_t5(
83 | max_length=args.model.model_max_length)
84 | device = get_device()
85 | editor_model = load_editor_weights(editor_model, args.meta.editor_path)
86 | editor_model = editor_model.to(device)
87 |
88 | sign_direction = get_grad_sign_direction(
89 | args.mask.grad_type, args.misc.grad_pred)
90 |
91 | masker = GradientMasker(args.search.max_mask_frac,
92 | editor_tokenizer_wrapper, predictor,
93 | args.model.model_max_length,
94 | grad_type=args.mask.grad_type,
95 | sign_direction=sign_direction)
96 |
97 | if "race" in args.meta.task:
98 | editor = RaceEditor(editor_tokenizer_wrapper, editor_tokenizer,
99 | editor_model, masker,
100 | num_gens=args.generation.num_generations,
101 | num_beams=args.generation.generation_num_beams,
102 | grad_pred=args.misc.grad_pred,
103 | generate_type=args.generation.generate_type,
104 | length_penalty=args.generation.length_penalty,
105 | no_repeat_ngram_size=args.generation.no_repeat_ngram_size,
106 | top_p=args.generation.top_p,
107 | top_k=args.generation.top_k,
108 | verbose=False,
109 | editable_key="article")
110 | else:
111 | editor = Editor(editor_tokenizer_wrapper, editor_tokenizer,
112 | editor_model, masker,
113 | num_gens=args.generation.num_generations,
114 | num_beams=args.generation.generation_num_beams,
115 | grad_pred=args.misc.grad_pred,
116 | generate_type=args.generation.generate_type,
117 | no_repeat_ngram_size=args.generation.no_repeat_ngram_size,
118 | top_p=args.generation.top_p,
119 | top_k=args.generation.top_k,
120 | length_penalty=args.generation.length_penalty,
121 | verbose=False)
122 | logger.info("Done loading models.")
123 | return editor, predictor
124 |
125 | def run_edit_test(args):
126 | """ Runs Stage 2 on test inputs by task. """
127 |
128 | task_dir = os.path.join(args.meta.results_dir, args.meta.task)
129 | stage_two_dir = os.path.join(task_dir, f"edits/{args.meta.stage2_exp}")
130 |
131 | if not os.path.exists(stage_two_dir):
132 | os.makedirs(stage_two_dir)
133 |
134 | logger.info(f"Task dir: {task_dir}")
135 | logger.info(f"Stage two dir: {stage_two_dir}")
136 |
137 | # Save args
138 | args_path = os.path.join(stage_two_dir, "stage_two_args.json")
139 | write_args(args_path, args)
140 |
141 | out_file = os.path.join(stage_two_dir, "edits.csv")
142 | meta_log_file = os.path.join(stage_two_dir, "meta_log.txt")
143 |
144 | meta_f = open(meta_log_file, 'w', 1)
145 |
146 | # Load models and Edit objects
147 | editor, predictor = load_models(args)
148 | dr = get_dataset_reader(args.meta.task, predictor)
149 | edit_evaluator = EditEvaluator()
150 | edit_finder = EditFinder(predictor, editor,
151 | beam_width=args.search.beam_width,
152 | max_mask_frac=args.search.max_mask_frac,
153 | search_method=args.search.search_method,
154 | max_search_levels=args.search.max_search_levels)
155 |
156 | # Get inputs
157 | inputs = dr.get_inputs('test')
158 | if "race" not in args.meta.task:
159 | inputs = [x for x in inputs if len(x) > 0 and re.search('[a-zA-Z]', x)]
160 |
161 | np.random.seed(0)
162 | input_indices = np.array(range(len(inputs)))
163 | np.random.shuffle(input_indices)
164 |
165 | # Find edits and write to file
166 | with open(out_file, "w") as csv_file:
167 | fieldnames = ["data_idx", "sorted_idx", "orig_pred", "new_pred",
168 | "contrast_pred", "orig_contrast_prob_pred",
169 | "new_contrast_prob_pred", "orig_input", "edited_input",
170 | "orig_editable_seg", "edited_editable_seg",
171 | "minimality", "num_edit_rounds", "mask_frac",
172 | "duration", "error"]
173 | writer = csv.writer(csv_file, delimiter="\t")
174 | writer.writerow(fieldnames)
175 |
176 | for idx, i in tqdm(enumerate(input_indices), total=len(input_indices)):
177 | inp = inputs[i]
178 | logger.info(wrap_text(f"ORIGINAL INSTANCE ({i}): {inp}"))
179 |
180 | start_time = time.time()
181 | error = False
182 | try:
183 | edited_list = edit_finder.minimally_edit(inp,
184 | max_edit_rounds=args.search.max_edit_rounds,
185 | edit_evaluator=edit_evaluator)
186 |
187 | torch.cuda.empty_cache()
188 | sorted_list = edited_list.get_sorted_edits()
189 |
190 | except Exception as e:
191 | logger.info(f"ERROR: {e}")
192 | error = True
193 | edited_list, sorted_list = None, []
194 |
195 | end_time = time.time()
196 |
197 | duration = end_time - start_time
198 | for s_idx, s in enumerate(sorted_list):
199 | writer.writerow([i, s_idx, edited_list.orig_label,
200 | s['edited_label'], edited_list.contrast_label,
201 | edited_list.orig_contrast_prob, s['edited_contrast_prob'],
202 | edited_list.orig_input, s['edited_input'],
203 | edited_list.orig_editable_seg,
204 | s['edited_editable_seg'], s['minimality'],
205 | s['num_edit_rounds'], s['mask_frac'], duration, error])
206 | csv_file.flush()
207 | if not sorted_list:
208 | # edited_list is None if minimally_edit raised an exception
209 | vals = [edited_list.orig_label, None,
210 | edited_list.contrast_label,
211 | edited_list.orig_contrast_prob, None,
212 | edited_list.orig_input, None,
213 | edited_list.orig_editable_seg] if edited_list else [None] * 8
214 | writer.writerow([i, 0] + vals + [None] * 4 + [duration, error])
215 | meta_f.flush()
216 |
217 | # csv_file is closed automatically on exiting the with block
218 | meta_f.close()
219 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | from transformers import T5Tokenizer, T5Model, T5ForConditionalGeneration, \
2 | T5Config, T5TokenizerFast
3 | from allennlp.predictors import Predictor, TextClassifierPredictor
4 | from allennlp_models.classification \
5 | import StanfordSentimentTreeBankDatasetReader
6 | from allennlp.data.tokenizers import PretrainedTransformerTokenizer
7 | from allennlp.data.tokenizers import Token
8 | from allennlp.data.tokenizers.spacy_tokenizer import SpacyTokenizer
9 | from allennlp.data.batch import Batch
10 | from allennlp.models import Model
11 | from allennlp.models.archival import Archive, load_archive
12 | from allennlp.nn import util
13 | from torch import backends
14 | from nltk.tokenize.treebank import TreebankWordDetokenizer as Detok
15 | import more_itertools as mit
16 | import numpy as np
17 | import torch
18 | import textwrap
19 | import time
20 | import os
21 | import sys
22 | import logging
23 | import argparse
24 | import json
25 | import difflib
26 | from munch import Munch
27 |
28 | # local imports
29 | from src.predictors.imdb.imdb_dataset_reader import ImdbDatasetReader
30 | from src.predictors.newsgroups.newsgroups_dataset_reader \
31 | import NewsgroupsDatasetReader
32 | from src.predictors.race.race_dataset_reader import RaceDatasetReader
33 | from src.masker import Masker, RandomMasker, GradientMasker
34 |
35 | logger = logging.getLogger(__name__)
36 | logger.setLevel(logging.INFO)
37 |
38 | ####################################################################
39 | ######################## Arg Parsing Utils #########################
40 | ####################################################################
41 |
42 | def get_shared_parsers():
43 | """ Helper function to get parsers.
44 | Gets parsers that are shared across stage one and stage two. """
45 |
46 | meta_parser = argparse.ArgumentParser()
47 | meta_parser.add_argument("-task", required=True,
48 | help='Name of task. Currently, only RACE, IMDB, \
49 | and Newsgroups are supported.',
50 | choices=['race', 'imdb', 'newsgroups'])
51 | meta_parser.add_argument("-results_dir", default="results",
52 | help='Results dir. Where to store results.')
53 |
54 | mask_parser = argparse.ArgumentParser()
55 | mask_parser.add_argument("-mask_type", default="grad",
56 | choices=["grad", "random"])
57 | mask_parser.add_argument("-grad_type", default="normal_l1",
58 | choices=["integrated_l1", "integrated_signed", "normal_l1", \
59 | "normal_signed", "normal_l2", "integrated_l2"],
60 | help="Which gradient method to use for grad-based masking. \
61 | l1/signed/l2 determine how to aggregate over the emb dim.")
62 |
63 | model_parser = argparse.ArgumentParser()
64 | model_parser.add_argument("-model_max_length", default=700, type=int,
65 | help="Maximum number of tokens the Editor model can take")
66 | return {"meta": meta_parser, "mask": mask_parser, "model": model_parser}
67 |
68 | def get_stage_one_parsers():
69 | """ Helper function to get parsers for Stage 1. """
70 |
71 | train_parser = argparse.ArgumentParser()
72 | train_parser.add_argument("-train_batch_size", default=4, type=int)
73 | train_parser.add_argument("-val_batch_size", default=1, type=int)
74 | train_parser.add_argument("-num_epochs", default=10, type=int)
75 | train_parser.add_argument("-lr", default=5e-5, type=float)
76 | train_parser.add_argument("-seed", default=42, type=int)
77 | train_parser.add_argument("-data_split_ratio", default=0.75, type=float)
78 |
79 | misc_parser = argparse.ArgumentParser()
80 | misc_parser.add_argument("-target_label", default="gold",
81 | choices=["gold", "pred"],
82 | help="Which label to use as the target during Editor training")
83 | return {"train": train_parser, "misc": misc_parser}
84 |
85 | def get_stage_two_parsers():
86 | """ Helper function to get parsers for Stage 2. """
87 |
88 | generation_parser = argparse.ArgumentParser()
89 | generation_parser.add_argument("-generate_type", default="sample",
90 | choices=['beam', 'sample'])
91 | generation_parser.add_argument("-top_k", default=30, type=int)
92 | generation_parser.add_argument("-top_p", default=0.95, type=float)
93 | generation_parser.add_argument("-length_penalty", default=1.0, type=float)
94 | generation_parser.add_argument("-generation_num_beams", default=15, type=int)
95 | generation_parser.add_argument("-num_generations", default=15, type=int)
96 | generation_parser.add_argument("-no_repeat_ngram_size", default=2, type=int)
97 |
98 | search_parser = argparse.ArgumentParser()
99 | search_parser.add_argument("-max_mask_frac", default=0.55, type=float,
100 | help="Maximum mask fraction")
101 | search_parser.add_argument("-max_edit_rounds", default=3, type=int,
102 | help="Maximum number of edit rounds")
103 | search_parser.add_argument("-max_search_levels", default=4, type=int,
104 | help="Maximum number of search levels")
105 | search_parser.add_argument("-beam_width", default=3, type=int,
106 | help="Beam width for beam search over edits.")
107 | search_parser.add_argument("-search_method", default="binary",
108 | choices=["binary", "linear"],
109 | help="Which kind of search method to use: binary or linear.")
110 |
111 | misc_parser = argparse.ArgumentParser()
112 | misc_parser.add_argument("-grad_pred", default="original",
113 | choices=["original", "contrast"], help="Whether to take gradient \
114 | with respect to the contrast or original prediction")
115 |
116 | return {"generation": generation_parser,
117 | "search": search_parser,
118 | "misc": misc_parser}
119 |
120 | def get_parsers_by_stage(stage="stage1"):
121 | """ Gets parsers by stage. """
122 |
123 | if stage not in ["stage1", "stage2"]:
124 | raise ValueError(f"stage must be 'stage1' or 'stage2' but got {stage}")
125 | parsers = get_shared_parsers()
126 | if stage == "stage1":
127 | parsers.update(get_stage_one_parsers())
128 | parsers["meta"].add_argument("-stage1_exp", required=True,
129 | help='Stage 1 exp name. Used to create subdir in results dir \
130 | for trained Editor.')
131 | else:
132 | parsers.update(get_stage_two_parsers())
133 | parsers["meta"].add_argument("-editor_path", required=True,
134 | help="Path to trained Editor checkpoint. Can be a directory \
135 | containing 'best.pth' file OR a direct path to file \
136 | containing weights (if training ended prematurely).")
137 | parsers["meta"].add_argument("-stage2_exp", required=True,
138 | help='Stage 2 experiment name. Used to create subdir within \
139 | stage 1 directory for editing results.')
140 | return parsers
141 |
142 | def get_args(stage):
143 | """ Gets args by stage. """
144 |
145 | if stage not in ["stage1", "stage2"]:
146 | raise ValueError("stage must be one of ['stage1', 'stage2'] " + \
147 | f"but got value {stage}")
148 | parsers = get_parsers_by_stage(stage)
149 | args = {}
150 | extra_args = sys.argv[1:]
151 | for arg_subset, parser in parsers.items():
152 | temp_args, extra_args = parser.parse_known_args(extra_args)
153 | args[arg_subset] = Munch(vars(temp_args))
154 | assert extra_args == [], f"Unrecognized arguments supplied: {extra_args}"
155 | return Munch(args)
156 |
157 | def write_args(args_path, args):
158 | """ Helper function to write args.
159 | Args:
160 | args_path: str
161 | args: Munch of arg subsets keyed by parser name
162 | """
163 | logger.info("Writing args to: " + args_path)
164 | for name, sub_args in args.items():
165 | logger.info(f"{name} args: {sub_args}")
166 | with open(args_path, "w") as f:
167 | f.write(json.dumps(args, indent=4))
168 |
169 |
170 | ####################################################################
171 | ####################### Task Specific Utils ########################
172 | ####################################################################
173 |
174 | def get_dataset_reader(task, predictor):
175 | task_options = ["imdb", "race", "newsgroups"]
176 | if task not in task_options:
177 | raise NotImplementedError(f"Task {task} not implemented; \
178 | must be one of {task_options}")
179 | if task == "imdb":
180 | return ImdbDatasetReader(
181 | token_indexers=predictor._dataset_reader._token_indexers,
182 | tokenizer=predictor._dataset_reader._tokenizer)
183 | elif task == "race":
184 | return RaceDatasetReader()
185 | elif task == "newsgroups":
186 | return NewsgroupsDatasetReader(
187 | token_indexers=predictor._dataset_reader._token_indexers,
188 | tokenizer=predictor._dataset_reader._tokenizer)
189 |
190 | def format_classif_input(inp, label):
191 | return "label: " + label + ". input: " + inp
192 |
193 | def format_multiple_choice_input(context, question, options, answer_idx):
194 | formatted_str = f"question: {question} answer: choice {answer_idx}:" + \
195 | f"{options[answer_idx]} context: {context}"
196 | for option_idx, option in enumerate(options):
197 | formatted_str += " choice" + str(option_idx) + ": " + option
198 | return formatted_str
199 |
200 | def load_predictor(task, predictor_folder="trained_predictors/"):
201 | task_options = ["imdb", "race", "newsgroups"]
202 | if task not in task_options:
203 | raise NotImplementedError(f"Task {task} not implemented; \
204 | must be one of {task_options}")
205 | predictor_path = os.path.join(predictor_folder, task, "model/model.tar.gz")
206 | if not os.path.exists(predictor_path):
207 | raise ValueError(f"Cannot find predictor path {predictor_path}")
208 | logger.info(f"Loading Predictor from: {predictor_path}")
209 |
210 | dr_map = {
211 | "imdb": ImdbDatasetReader,
212 | "newsgroups": NewsgroupsDatasetReader,
213 | "race": RaceDatasetReader,
214 | }
215 |
216 | cuda_device = 0 if torch.cuda.is_available() else -1
217 | predictor = Predictor.from_path(predictor_path,
218 | dataset_reader_to_load=dr_map[task],
219 | cuda_device=cuda_device, frozen=True)
220 | logger.info("Done loading predictor.")
221 | return predictor
222 |
223 | ####################################################################
224 | ########################### Model Utils ############################
225 | ####################################################################
226 |
227 | def load_base_t5(max_length=700):
228 | t5_config = T5Config.from_pretrained("t5-base", n_positions=max_length)
229 | model = T5ForConditionalGeneration.from_pretrained("t5-base",
230 | config=t5_config)
231 | tokenizer = T5TokenizerFast.from_pretrained("t5-base", truncation=True)
232 | return tokenizer, model
233 |
234 | def get_device():
235 | return 'cuda' if torch.cuda.is_available() else 'cpu'
236 |
237 | def get_prob_pred(pred, label_idx):
238 | """ Given a prediction, gets predicted probability of label_idx. """
239 |
240 | # index directly instead of scanning all probs
241 | return pred['probs'][label_idx]
242 |
243 |
244 | def get_ints_to_labels(predictor):
245 | vocab = predictor._model.vocab
246 | ints_to_labels = vocab.get_index_to_token_vocabulary('labels')
247 | return ints_to_labels
248 |
249 | def get_labels_to_ints(predictor):
250 | vocab = predictor._model.vocab
251 | labels_to_ints = vocab.get_token_to_index_vocabulary('labels')
252 | return labels_to_ints
253 |
254 | def get_predictor_tokenized(predictor, string):
255 | return predictor._dataset_reader._tokenizer.tokenize(string)
256 |
257 | def add_probs(pred):
258 | """ Computes predicted probs from logits. """
259 |
260 | if 'probs' not in pred:
261 | if isinstance(pred['logits'], torch.Tensor):
262 | pred['probs'] = torch.nn.functional.softmax(pred['logits'], dim=-1)
263 | else:
264 | pred['probs'] = np.exp(pred['logits'])/sum(np.exp(pred['logits']))
265 | return pred
266 |
267 | ####################################################################
268 | ########################### Other Utils ############################
269 | ####################################################################
270 |
271 | def wrap_text(text, num_indents=6, width=100):
272 | """ Util for pretty printing. """
273 |
274 | indent = '\t' * num_indents
275 | return textwrap.fill(text, subsequent_indent = indent, width=width)
276 |
277 | def html_highlight_diffs(orig, edited, tokenizer_wrapper=SpacyTokenizer()):
278 | """ Given an orig and edited inputs, mark up differences in HTML. """
279 |
280 | orig = orig.replace("<br />", "<-br />")
281 |
282 | orig_tok = tokenizer_wrapper.tokenize(orig)
283 | edited_tok = tokenizer_wrapper.tokenize(edited)
284 |
285 | orig_text_tok = [t.text for t in orig_tok]
286 | edited_text_tok = [t.text for t in edited_tok]
287 |
288 | orig_mark_indices, _, _ = get_marked_indices(
289 | orig_text_tok, edited_text_tok, "-")
290 | edited_mark_indices, _, _ = get_marked_indices(
291 | orig_text_tok, edited_text_tok, "+")
292 |
293 | marked_original = orig
294 | for idx in reversed(orig_mark_indices):
295 | token = orig_tok[idx]
296 | start, end = token.idx, token.idx_end
297 | if start is None or end is None:
298 | logger.info(f"Missing token indices: {token} {start} {end}")
299 | marked_original = marked_original[:start] + "<b>" + \
300 | marked_original[start:end] + "</b>" + \
301 | marked_original[end:]
302 |
303 | marked_edited = edited.replace("<br />", "<-br />")
304 | for idx in reversed(edited_mark_indices):
305 | token = edited_tok[idx]
306 | start, end = token.idx, token.idx_end
307 | if start is None or end is None:
308 | logger.info(f"Missing token indices: {token} {start} {end}")
309 | marked_edited = marked_edited[:start] + "<b>" + \
310 | marked_edited[start:end] + "</b>" + marked_edited[end:]
311 | return marked_original, marked_edited
312 |
313 | def get_marked_indices(orig_tokenized, tokenized_contrast, symbol):
314 | """ Helper function for html_highlight_diffs.
315 | Will only return indices of words deleted or replaced (not inserted). """
316 |
317 | d = difflib.Differ()
318 | diff = d.compare(orig_tokenized, tokenized_contrast)
319 |
320 | list_diff = list(diff)
321 | tokens, modified_tokens, indices = [], [], []
322 | counter = 0
323 | additions, deletions = 0, 0
324 |
325 | for token_idx, token in enumerate(list_diff):
326 | marker = token[0]
327 | word = token[2:]
328 | if marker == symbol:
329 | tokens.append(word)
330 | indices.append(counter)
331 | counter += 1
332 | elif marker == " ":
333 | modified_tokens.append(word)
334 | counter += 1
335 |
336 | if marker == "+":
337 | additions += 1
338 | if marker == "-":
339 | deletions += 1
340 |
341 | return indices, additions, deletions
342 |
343 |
--------------------------------------------------------------------------------
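The diff-marking logic in `get_marked_indices` hinges on the line prefixes that `difflib.Differ.compare` emits (`"  "` unchanged, `"- "` deleted, `"+ "` added, `"? "` hints). A minimal standalone sketch of the same counting scheme, using hypothetical tokens and a simplified return value (indices only, no addition/deletion counts):

```python
import difflib

def marked_indices(orig_tokens, edited_tokens, symbol):
    """Return positions of tokens tagged with `symbol` ("-" or "+"),
    counted over unchanged tokens plus tokens tagged with `symbol`."""
    indices, counter = [], 0
    for line in difflib.Differ().compare(orig_tokens, edited_tokens):
        marker, word = line[0], line[2:]
        if marker == symbol:
            indices.append(counter)
            counter += 1
        elif marker == " ":  # token unchanged on both sides
            counter += 1
        # "? " hint lines and the opposite symbol do not advance the counter

    return indices

orig = "the movie was great".split()
edited = "the movie was terrible".split()
print(marked_indices(orig, edited, "-"))  # -> [3]  ("great" in the original)
print(marked_indices(orig, edited, "+"))  # -> [3]  ("terrible" in the edit)
```

Skipping the opposite symbol when advancing the counter is what keeps the returned indices valid for each side's own token list, which is why `html_highlight_diffs` can use them to index into `orig_tok` and `edited_tok` directly.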