├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── images └── motivation.png ├── infosol ├── alignment.py ├── decoding.py ├── env.py ├── evaluate.py ├── modeling_utils.py ├── models │ ├── __pycache__ │ │ └── word_edit_model.cpython-310.pyc │ └── word_edit_model.py ├── train.py └── utils │ ├── data │ └── cnn_dailymail.py │ ├── keywords.py │ └── pointer_utils.py ├── jobs └── interactive │ └── cnn-bart-s2s-len ├── requirements.txt ├── scripts ├── dowload_models.sh ├── make_data.py ├── make_eval_jobs.py └── run_eval.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | infosol_models.tar.zst 3 | data/ 4 | models/ 5 | jobs/main/ 6 | jobs/interactive/cnn-bart-editor 7 | jobs/interactive/cnn-bart-s2s 8 | jobs/interactive/cnn-bart_editor_large 9 | out/ 10 | infosol.egg-info/ 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | #Makefile for training 2 | CUDA_DEVICE ?= 0 3 | 4 | # CNN BASE 5 | base_run_dir := #/path/to/your/run_dir/ 6 | base-cnn-bart_editor_large: 7 | python infosol/train.py --cuda_device $(CUDA_DEVICE) --run_dir $(base_run_dir) --run_name cnn-bart_editor_large base --n_epochs 6 --max_train_edits 1000000 --max_val_edits 2000 --model_name bart-large --track_gradient_norm True 8 | base-cnn-bart_editor_large-clip_grad_norm: 9 | python infosol/train.py --cuda_device $(CUDA_DEVICE) --run_dir $(base_run_dir) --run_name cnn-bart_editor_large base --n_epochs 6 --max_train_edits 1000000 --max_val_edits 2000 --model_name bart-large --track_gradient_norm True --clip_grad_norm True 10 | 11 | 12 | # CNN DAGGER 13 | run_dir := #/path/to/your/run_dir/ 14 | dagger_args_ := --use_timestamp False --run_dir $(run_dir) --cuda_device $(CUDA_DEVICE) dagger --n_epochs 600 --max_train_edits 1000000 --max_val_edits 2000 --n_warmup_epochs 300 --dagger_sampling_rate 1.0 --sample_batch_size 10000 --val_sample_batch_size 1000 15 | dagger_args := $(dagger_args_) --sampling_annealing_rate 0.9 16 | 17 | # main 18 | cnn-bart_editor: 19 | python infosol/train.py --run_name cnn-bart_editor $(dagger_args) 20 | cnn-bart_editor_large: 21 | python infosol/train.py --run_name cnn-bart_editor_large $(dagger_args) --model_name bart-large --track_gradient_norm True --clip_grad_norm True 22 | cnn-bart_s2s: 23 | python infosol/train.py --run_name cnn-bart_s2s $(dagger_args) --model_name barts2s 24 | # different oracles 25 | cnn-bart_editor-adj_edits: 26 | python 
infosol/train.py --run_name cnn-bart_editor-adj_edits $(dagger_args) --adjacent_ops True 27 | cnn-bart_editor-contig_edits: 28 | python infosol/train.py --run_name cnn-bart_editor-contig_edits $(dagger_args) --contiguous_edits True 29 | cnn-bart_editor-adj_edits-contig_edits: 30 | python infosol/train.py --run_name cnn-bart_editor-adj_edits-contig_edits $(dagger_args) --adjacent_ops True --contiguous_edits True 31 | 32 | # noise fractions 33 | cnn-bart_editor-noise_0.0: 34 | python infosol/train.py --run_name cnn-bart_editor-noise_0.0 $(dagger_args) --noise_frac 0 35 | cnn-bart_editor-noise_0.1: 36 | python infosol/train.py --run_name cnn-bart_editor-noise_0.1 $(dagger_args) --noise_frac 0.1 37 | cnn-bart_editor-noise_0.2: 38 | python infosol/train.py --run_name cnn-bart_editor-noise_0.2 $(dagger_args) --noise_frac 0.2 39 | # annealing rate 40 | cnn-bart_editor-anneal_0.85: 41 | python infosol/train.py --run_name cnn-bart_editor-anneal_0.85 $(dagger_args_) --sampling_annealing_rate 0.85 42 | cnn-bart_editor-anneal_0.80: 43 | python infosol/train.py --run_name cnn-bart_editor-anneal_0.80 $(dagger_args_) --sampling_annealing_rate 0.80 44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | 3 | This project offers an implementation of the paper: 4 | 5 | **Interactive Text Generation**\ 6 | Felix Faltings, Michel Galley, Baolin Peng, Kianté Brantley, Weixin Cai, Yizhe Zhang, Jianfeng Gao, Bill Dolan\ 7 | [arXiv](https://arxiv.org/abs/2303.00908) 8 | 9 | 10 | 11 | # Installation 12 | 13 | Install dependencies using requirements.txt: 14 | `pip install -r requirements.txt` 15 | Then, install the package. From the top level directory (where `setup.py` is located), run: 16 | `pip install -e .` 17 | This will install this package as an editable module named `infosol`. 
18 | 19 | Download model files (requires wget and zstd): 20 | `sh ./scripts/dowload_models.sh` 21 | 22 | # DATA 23 | 24 | You can regenerate the data used in the paper using the `make_data.py` script. You only need to specify the `data_dir` argument where the data will be saved (under `data/cnn_bart`). This script first downloads the raw data (CNN/DailyMail) from the Hugging Face Hub. The script can easily be adapted to generate other text datasets from the Hub. 25 | 26 | `python scripts/make_data.py --data_dir=data` 27 | 28 | # Test 29 | 30 | To replicate the main experiments of the paper, run: 31 | 32 | `python scripts/make_eval_jobs.py --model_dir=models --data_dir=data/cnn --job_dir=jobs --out_dir=out` 33 | 34 | The above command creates job files in the 'jobs' directory, as well as the directory structure ('out') where test results will be stored. Then, you can pass any of the configuration files under 'jobs' as an argument to `scripts/run_eval.py`. For example, run the following to replicate the BART-large 'interactive' experiments (Table 4 of the paper): 35 | 36 | `python scripts/run_eval.py --args_path jobs/interactive/cnn-bart_editor_large --cuda_device 0` 37 | 38 | Note: In the S2S experiments of the paper, the generations were inconsistent in length, which hurt S2S performance. We therefore tuned its length_penalty hyperparameter on a held-out set; the corresponding job files can be found in jobs/interactive/cnn-bart-s2s-len. 39 | 40 | # Train 41 | 42 | In order to train all the models presented in the paper, you may use the provided Makefile. Set the `run_dir` variable (and `base_run_dir` for the `base-*` targets) to the directory where you would like model weights to be saved. You also need to set the `DATA_DIR` and `SAVE_DIR` paths in `train.py`. 43 | 44 | # Code walkthrough 45 | 46 | We suggest you go through the code in the order alignment -> model -> train/evaluate.
Alignment defines some basic objects like alignments (used heavily by the oracle) and canvases (the objects that the model and oracle operate on). The others are self-explanatory and are commented. 47 | 48 | The main files: 49 | - alignment.py 50 | - env.py 51 | - evaluate.py 52 | - models/word_edit_model.py 53 | - run_eval.py 54 | - train.py 55 | - (decoding.py) 56 | -------------------------------------------------------------------------------- /images/motivation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ffaltings/InteractiveTextGeneration/febd658a91227dd88fbc5355111382b732a91647/images/motivation.png -------------------------------------------------------------------------------- /infosol/alignment.py: -------------------------------------------------------------------------------- 1 | import random 2 | import itertools 3 | import torch 4 | import numpy as np 5 | 6 | from transformers import AutoTokenizer, BertModel 7 | 8 | def find_max_score(alignment_matrix, score_matrix, i, j, baseline_score=0): 9 | """ 10 | Maximum alignment score up to this position, given max alignment scores up to previous positions 11 | """ 12 | del_score = alignment_matrix[i, j-1] + baseline_score 13 | ins_score = alignment_matrix[i-1, j] + baseline_score 14 | match_score = alignment_matrix[i-1,j-1] + score_matrix[i-1, j-1] 15 | scores = (ins_score, del_score, match_score) # this order matters for priority b/w insert/del 16 | return np.max(scores), np.argmax(scores) 17 | 18 | def compute_alignment_scores(score_matrix, baseline_score=0): 19 | """ 20 | Compute maximum alignment score for each position using DP 21 | """ 22 | alignment_matrix = np.zeros((score_matrix.shape[0] + 1, score_matrix.shape[1] + 1)) 23 | for i in range(1, alignment_matrix.shape[0]): 24 | alignment_matrix[i,0] = i * (baseline_score*10)/10 #avoids a weird bug with floating point precision 25 | for j in range(1, alignment_matrix.shape[1]): 
26 | alignment_matrix[0,j] = j * (baseline_score*10)/10 27 | move_matrix = np.zeros(alignment_matrix.shape) 28 | for j in range(1,alignment_matrix.shape[1]): 29 | for i in range(1,alignment_matrix.shape[0]): 30 | alignment_matrix[i,j], move_matrix[i,j] = find_max_score(alignment_matrix, score_matrix, i, j, baseline_score) 31 | return alignment_matrix, move_matrix 32 | 33 | def get_alignment(move_matrix, tokens_a, tokens_b, score_matrix=None, baseline_score=0.): 34 | """ 35 | Retrieve the actual alignment from the matrix of alignment scores, 36 | and matrix of "moves" thru the score matrix 37 | """ 38 | i,j = move_matrix.shape 39 | i -= 1 40 | j -= 1 41 | alignment = [] 42 | alignment_score = [] 43 | while i >= 1 and j >= 1: 44 | move = move_matrix[i,j] 45 | if move == 2: 46 | op = (tokens_a[i-1], tokens_b[j-1]) 47 | if score_matrix is not None: 48 | alignment_score.append(score_matrix[i-1,j-1]) 49 | # op = op + (float(score_matrix[i-1, j-1]),) 50 | alignment.append(op) 51 | i -= 1 52 | j -= 1 53 | elif move == 0: 54 | op = (tokens_a[i-1], '') 55 | if score_matrix is not None: 56 | # op = op + (baseline_score,) 57 | alignment_score.append(baseline_score) 58 | alignment.append(op) 59 | i -= 1 60 | elif move == 1: 61 | op = ('',tokens_b[j-1]) 62 | if score_matrix is not None: 63 | # op = op + (baseline_score,) 64 | alignment_score.append(baseline_score) 65 | alignment.append(op) 66 | j -= 1 67 | if i == 0: 68 | while j >= 1: 69 | op = ('', tokens_b[j-1]) 70 | if score_matrix is not None: 71 | # op = op + (baseline_score,) 72 | alignment_score.append(baseline_score) 73 | alignment.append(op) 74 | j -= 1 75 | elif j == 0: 76 | while i >= 1: 77 | op = (tokens_a[i-1], '') 78 | if score_matrix is not None: 79 | # op = op + (baseline_score,) 80 | alignment_score.append(baseline_score) 81 | alignment.append(op) 82 | i -= 1 83 | return alignment[::-1], alignment_score[::-1] 84 | 85 | def batch_cos_sim(tokens_a, tokens_b, model, tokenizer, device=torch.device('cpu'), 
use_model=True): 86 | """ 87 | Compute cosine similarities 88 | """ 89 | if len(tokens_a) != len(tokens_b): 90 | raise ValueError('batch sizes cannot be different') 91 | for tokens in tokens_a + tokens_b: 92 | if len(tokens) == 0: 93 | raise ValueError('tokens cannot be empty') 94 | 95 | def batch_encodings(batch): 96 | # convert to ids 97 | batch = [torch.tensor( 98 | tokenizer.convert_tokens_to_ids(tokens), 99 | dtype=torch.long, 100 | device=device) for tokens in batch] 101 | max_len = np.max([ids.shape[0] for ids in batch]) 102 | input_ids = torch.zeros((len(batch), max_len), dtype=torch.long, 103 | device=device) 104 | attention_mask = torch.zeros_like(input_ids, dtype=torch.int32) 105 | for i,b in enumerate(batch): 106 | input_ids[i, :len(b)] = b 107 | attention_mask[i, :len(b)] = 1 108 | with torch.no_grad(): 109 | out = model(input_ids=input_ids, attention_mask=attention_mask) 110 | 111 | for i in range(len(batch)): 112 | canv_len = attention_mask[i,:].sum() 113 | yield out.last_hidden_state[i, :canv_len],\ 114 | input_ids[i, :canv_len] 115 | 116 | batch_enc_a = batch_encodings(tokens_a) 117 | batch_enc_b = batch_encodings(tokens_b) 118 | 119 | for (enc_a, ids_a), (enc_b, ids_b) in zip(batch_enc_a, 120 | batch_enc_b): 121 | exact_matches = (ids_a.unsqueeze(-1).expand(-1, ids_b.shape[0])\ 122 | == ids_b.unsqueeze(0).expand(ids_a.shape[0], -1)).type(torch.float) 123 | 124 | if not use_model: 125 | yield exact_matches 126 | 127 | dot_products = torch.matmul(enc_a, enc_b.transpose(0,1)) 128 | try: 129 | norms = torch.matmul(torch.norm(enc_a, dim=-1).unsqueeze(1), 130 | torch.norm(enc_b, dim=-1).unsqueeze(0)) 131 | except IndexError as e: 132 | print(enc_a.shape, enc_b.shape, ids_a, ids_b) 133 | raise e 134 | yield torch.max(dot_products/norms, exact_matches) 135 | 136 | def batch_align(tokens_a, tokens_b, model, tokenizer, 137 | baseline_score=0., device=torch.device('cpu'), **kwargs): 138 | """ 139 | Align tokens 140 | :param tokens_a: list of lists of tokens 
141 | :param tokens_b: list of lists of tokens 142 | :param model: encoder model 143 | :param tokenizer: tokenizer 144 | :param baseline_score: score assigned for non matches 145 | """ 146 | if len(tokens_a) != len(tokens_b): 147 | raise ValueError('batch sizes cannot differ') 148 | non_empty_idxs = [i for i,(tok_a,tok_b) in enumerate( 149 | zip(tokens_a, tokens_b)) if len(tok_a) > 0 and len(tok_b) > 0] 150 | score_matrices = batch_cos_sim([tokens_a[i] for i in non_empty_idxs], 151 | [tokens_b[i] for i in non_empty_idxs], model, tokenizer, device=device, **kwargs) 152 | 153 | for j in range(len(tokens_a)): 154 | tok_a = tokens_a[j] 155 | tok_b = tokens_b[j] 156 | if len(tok_a) == 0: 157 | alignment = [('', b) for b in tok_b] 158 | scores = [baseline_score] * len(alignment) 159 | yield Alignment(alignment, scores=scores, baseline_score=baseline_score) 160 | elif len(tok_b) == 0: 161 | alignment = [(a, '') for a in tok_a] 162 | scores = [baseline_score] * len(alignment) 163 | yield Alignment(alignment, scores=scores, baseline_score=baseline_score) 164 | else: 165 | score_matrix = next(score_matrices).cpu().numpy() 166 | alignment_matrix, move_matrix = compute_alignment_scores( 167 | score_matrix, baseline_score=baseline_score) 168 | alignment, scores = get_alignment(move_matrix, tok_a, tok_b, 169 | score_matrix=score_matrix, 170 | baseline_score=baseline_score) 171 | yield Alignment(alignment, scores=scores, baseline_score=baseline_score) 172 | 173 | def align(tokens_a, tokens_b, model, tokenizer, **kwargs): 174 | return next(batch_align([tokens_a], [tokens_b], model, tokenizer, **kwargs)) 175 | 176 | def batch_align_canvases(canvases, targets, model, tokenizer, **kwargs): 177 | """ 178 | Align canvases to target tokens, differs from aligning tokens because it handles the type ids of the canvases 179 | """ 180 | solid_tokens = [canvas.clean().tokens for canvas in canvases] 181 | for canvas,alignment in zip(canvases, 182 | batch_align(solid_tokens, targets, model, 
tokenizer, **kwargs)): 183 | recovered_alignment = [] 184 | recovered_alignment_scores = [] 185 | i,j = 0,0 186 | while i < len(canvas) or j < len(alignment): 187 | if i < len(canvas) and j < len(alignment) and \ 188 | canvas.tokens[i] == alignment.alignment[j][0] and\ 189 | not canvas.is_stricken(i) and not alignment.is_insertion(j): 190 | recovered_alignment.append(alignment.alignment[j]) 191 | recovered_alignment_scores.append(alignment.scores[j]) 192 | i += 1 193 | j += 1 194 | elif i < len(canvas) and canvas.is_stricken(i): 195 | recovered_alignment.append((canvas.tokens[i], '')) 196 | recovered_alignment_scores.append(1) 197 | i += 1 198 | elif j < len(alignment) and alignment.is_insertion(j): 199 | recovered_alignment.append(alignment.alignment[j]) 200 | recovered_alignment_scores.append(alignment.scores[j]) 201 | j += 1 202 | else: 203 | print(canvas.tokens, canvas.type_ids) 204 | print(alignment) 205 | print(i,j) 206 | print(recovered_alignment) 207 | print(canvas.tokens[i], canvas.type_ids[i], alignment.alignment[j]) 208 | raise RuntimeError('Error with alignment') 209 | alignment = Alignment(recovered_alignment, recovered_alignment_scores) 210 | alignment.set_type_ids(canvas.type_ids) 211 | yield alignment 212 | 213 | def align_canvas(canvas, target, model, tokenizer, **kwargs): 214 | return next(batch_align_canvases([canvas], [target], model, tokenizer, **kwargs)) 215 | 216 | class Canvas(): 217 | 218 | """ 219 | Canvas is a sequence of tokens with type ids 220 | """ 221 | 222 | html_types = [ 223 | "{}", # 0: plain text 224 | "{}", #1: agent inserted text 225 | "{}", #2 oracle inserted text 226 | "{}", #3: user deleted text 227 | "{}", #4 oracle deleted text 228 | ] 229 | 230 | @staticmethod 231 | def latex_type(tid, tok): 232 | """ 233 | Utility for rendering to latex 234 | """ 235 | if tid == 0: 236 | text = tok 237 | elif tid == 1: 238 | text = ''.join(('\\textcolor{seagreen}{',tok,'}')) 239 | elif tid == 2: 240 | text = 
''.join(('\\textcolor{BrickRed}{',tok,'}')) 241 | elif tid == 3: 242 | text = ''.join(('\sout{', tok, '}')) 243 | elif tid == 4: 244 | text = ''.join(('\sout{\\textcolor{BrickRed}{', tok, '}}')) 245 | if tok[0] == ' ' and tid != 0: 246 | text = ' '+text 247 | return text 248 | 249 | def __init__(self, tokens, type_ids=None): 250 | if not type_ids is None and len(tokens) != len(type_ids): 251 | raise ValueError('length of tokens must match length of type ids') 252 | self.tokens = tokens 253 | if type_ids is None: 254 | type_ids = [0] * len(tokens) 255 | self.type_ids = type_ids 256 | 257 | def __len__(self): 258 | return len(self.tokens) 259 | 260 | def __eq__(self, other): 261 | return self.tokens == other.tokens and self.type_ids == other.type_ids 262 | 263 | def __lt__(self, other): 264 | return len(self.tokens) < len(other.tokens) 265 | 266 | def __gt__(self, other): 267 | return len(self.tokens) > len(other.tokens) 268 | 269 | def __le__(self, other): 270 | return len(self.tokens) <= len(other.tokens) 271 | 272 | def __ge__(self, other): 273 | return len(self.tokens) >= len(other.tokens) 274 | 275 | def __repr__(self): 276 | return str((self.tokens, self.type_ids)) 277 | 278 | def is_stricken(self, i): 279 | return self.type_ids[i] >= 3 280 | 281 | def operate(self, loc, operation, token, agent=0): 282 | """ 283 | Apply an edit to a canvas 284 | """ 285 | if operation == 0: # insertion 286 | self.tokens.insert(loc+1, token) 287 | self.type_ids.insert(loc+1, agent+1) 288 | elif operation == 1: # substitution 289 | self.tokens[loc] = token 290 | self.type_ids[loc] = agent+1 291 | elif operation == 2: # deletion 292 | if self.type_ids[loc] != 0: 293 | del self.tokens[loc] 294 | del self.type_ids[loc] 295 | else: 296 | self.type_ids[loc] = agent + 3 297 | else: 298 | raise ValueError('Invalid Operation: {}'.format(operation)) 299 | 300 | def render(self, tokenizer, clean=True): 301 | """ 302 | Return string from the canvas 303 | """ 304 | if clean: 305 | canvas = 
self.clean() 306 | if len(canvas) == 0: return '' 307 | else: 308 | return tokenizer.convert_tokens_to_string(canvas.tokens) 309 | 310 | def render_to_html(self, tokenizer): 311 | """ 312 | Format into html 313 | """ 314 | html_string = '' 315 | for tok, tid in zip(self.tokens, self.type_ids): 316 | html_string += Canvas.html_types[tid].format( 317 | tokenizer.convert_tokens_to_string(tok)) 318 | return html_string 319 | 320 | def render_to_latex(self, tokenizer): 321 | """ 322 | Format into LATEX 323 | """ 324 | latex_string = '' 325 | for tok, tid in zip(self.tokens, self.type_ids): 326 | latex_string += Canvas.latex_type( 327 | tid, tokenizer.convert_tokens_to_string(tok)) 328 | return latex_string 329 | 330 | def clean(self): 331 | """ 332 | Remove tokens that are stricken out and remove type ids (except for oracle type ids) 333 | """ 334 | tokens = [t for i,t in enumerate(self.tokens) if not self.is_stricken(i)] 335 | type_ids = [tid for i,tid in enumerate(self.type_ids) if not self.is_stricken(i)] 336 | # reset the type ids for model insertions 337 | type_ids = [0 if tid==1 else tid for tid in type_ids] 338 | return Canvas(tokens, type_ids) 339 | 340 | def copy(self): 341 | return Canvas(list(self.tokens), list(self.type_ids)) 342 | 343 | class Alignment(): 344 | 345 | """ 346 | Alignment between two sets of tokens 347 | """ 348 | 349 | def __init__(self, alignment, scores=None, type_ids=None, 350 | baseline_score=0.3, add_type_ids=True): 351 | self.alignment=alignment 352 | 353 | if not scores is None and len(scores) != len(alignment): 354 | raise ValueError('length of scores must match length of alignment') 355 | self.scores=scores 356 | self.baseline_score = baseline_score 357 | 358 | if not type_ids is None: 359 | self.set_type_ids(type_ids) 360 | elif add_type_ids: 361 | self.set_type_ids([0] * len(self.get_source_tokens())) 362 | else: 363 | self.type_ids = None 364 | 365 | def copy(self): 366 | return Alignment(list(self.alignment), list(self.scores)) 
367 | 368 | def __len__(self): 369 | return len(self.alignment) 370 | 371 | def __str__(self): 372 | print_str = '' 373 | if len(self.alignment) == 0: return print_str 374 | src_pad = max(1, np.max([len(tup[0]) for tup in self.alignment])) 375 | tgt_pad = max(1, np.max([len(tup[1]) for tup in self.alignment])) 376 | for i in range(len(self.alignment)): 377 | line = '{src:{src_pad}} - {tgt:{tgt_pad}}'.format(src=self.alignment[i][0], 378 | tgt=self.alignment[i][1], src_pad=src_pad, tgt_pad=tgt_pad) 379 | if not self.scores is None: 380 | line += ' {:.2f}'.format(self.scores[i]) 381 | if not self.type_ids is None: 382 | line += ' {:2}'.format(self.type_ids[i]) 383 | print_str += line + '\n' 384 | return print_str 385 | 386 | def is_match(self, idx): 387 | return self.alignment[idx][0] != '' and self.alignment[idx][1] != '' 388 | 389 | def is_exact_match(self, idx): 390 | return self.alignment[idx][0] == self.alignment[idx][1] 391 | 392 | def is_insertion(self, idx): 393 | return self.alignment[idx][0] == '' and self.alignment[idx][1] != '' 394 | 395 | def is_addition(self, idx): 396 | return self.alignment[idx][1] != '' and self.alignment[idx][0] != self.alignment[idx][1] 397 | 398 | def is_deletion(self, idx): 399 | return self.alignment[idx][0] != '' and self.alignment[idx][1] == '' 400 | 401 | def get_non_const_ops(self): 402 | """ 403 | All positions in alignment that aren't exact matches 404 | """ 405 | if self.type_ids is None: 406 | return [i for i in range(len(self.alignment))\ 407 | if not self.is_exact_match(i)] 408 | else: 409 | return [i for i in range(len(self.alignment))\ 410 | if not self.is_exact_match(i) and self.type_ids[i] < 3] 411 | 412 | def get_adjacent_ops(self, idxs, adj_range = 1): 413 | """ 414 | Return idxs of adjacent ops in alignment. 
Adjacent idxs are idxs 415 | that are near a pair of matched tokens (not necessarily an exact match) 416 | """ 417 | def is_adjacent(idx, adj_range): 418 | min_lim = max(0, idx - adj_range) 419 | max_lim = min(len(self.alignment), idx + adj_range + 1) 420 | for i in range(min_lim, max_lim): 421 | if self.is_match(i): 422 | return True 423 | return False 424 | ret_idxs = [i for i in idxs if is_adjacent(i, adj_range)] 425 | if not ret_idxs: return idxs 426 | return ret_idxs 427 | 428 | def get_actions(self, action_idxs, include_deleted_words=False): 429 | """ 430 | Get actual action tuples for given idxs of an alignment 431 | """ 432 | actions = {} 433 | j = -1 434 | k = -1# <- keep track of this for insertions, -1 is sentinel token 435 | for i in range(len(self.alignment)): 436 | if not self.is_insertion(i): 437 | j += 1 438 | if self.type_ids is None: 439 | k = j 440 | else: 441 | if self.type_ids[i] < 3: 442 | k = j 443 | if i in action_idxs: 444 | if self.is_insertion(i): 445 | actions[i] = (k, 0, self.alignment[i][1]) 446 | elif self.is_deletion(i): 447 | if include_deleted_words: 448 | actions[i] = (j, 2, None, self.alignment[i][0]) 449 | else: 450 | actions[i] = (j, 2, None) 451 | else: # substitution 452 | actions[i] = (j, 1, self.alignment[i][1]) 453 | actions = [actions[i] for i in action_idxs] 454 | return actions 455 | 456 | def push_forward(self, idxs, agent = 0): 457 | """ 458 | Push alignment forward at given idxs (i.e. 
replace the source token with the target token at the location 459 | giving an exact match) 460 | """ 461 | if len(idxs) == 0: return 462 | if np.min(idxs) < 0 or np.max(idxs) >= len(self.alignment): 463 | raise ValueError('idxs out of range') 464 | delete_idxs = [] 465 | for i in idxs: 466 | if self.type_ids is None: 467 | if not self.is_deletion(i): 468 | self.alignment[i] = (self.alignment[i][1], self.alignment[i][1]) 469 | else: 470 | delete_idxs.append(i) 471 | else: 472 | if not self.is_deletion(i): 473 | self.alignment[i] = (self.alignment[i][1], self.alignment[i][1]) 474 | self.type_ids[i] = agent+1 475 | elif self.type_ids[i] == 0: 476 | self.type_ids[i] = agent+3 477 | elif self.type_ids[i] < 3: #otherwise token was already deleted 478 | delete_idxs.append(i) 479 | self.alignment = [tup for i,tup in enumerate(self.alignment) if i not in delete_idxs] 480 | if not self.scores is None: 481 | self.scores = [s for i,s in enumerate(self.scores) if i not in delete_idxs] 482 | if not self.type_ids is None: 483 | self.type_ids = [tid for i,tid in enumerate(self.type_ids) if i not in delete_idxs] 484 | 485 | def canvas_location_to_alignment_location(self, canvas_location): 486 | """ 487 | Map a location in the source canvas to the alignment 488 | """ 489 | j = 0 490 | for i in range(len(self.alignment)): 491 | if not self.is_insertion(i): 492 | canvas_location -= 1 493 | if canvas_location < -1: break 494 | j = i 495 | return j 496 | 497 | def operate(self, canvas_location, operation, token, agent = 0): 498 | ''' 499 | Operate on the source canvas while updating the alignment to the target tokens 500 | ''' 501 | if canvas_location >= len(self.get_source_tokens()): 502 | raise ValueError('canvas location out of bounds') 503 | if operation == 0: #insertion 504 | if token is None: raise ValueError('token cannot be None when inserting') 505 | # insert on the left of the next token to keep contiguous insertions and deletions grouped 506 | location = 
self.canvas_location_to_alignment_location(canvas_location+1) 507 | self.alignment.insert(location, (token, '')) 508 | if not self.scores is None: 509 | self.scores.insert(location, self.baseline_score) 510 | if not self.type_ids is None: 511 | self.type_ids.insert(location, agent+1) 512 | elif operation == 1: #substitution 513 | if token is None: raise ValueError('token cannot be None when substituting') 514 | if canvas_location == -1: raise ValueError('cannot substitute bos token') 515 | 516 | location = self.canvas_location_to_alignment_location(canvas_location) 517 | if not self.type_ids is None and self.type_ids[location] >= 3: 518 | raise ValueError('cannot edit deleted token') 519 | if not self.type_ids is None: self.type_ids[location] = 1 # trick to make deletion operate correctly 520 | self.operate(canvas_location, 2, token, agent) 521 | self.operate(canvas_location-1, 0, token, agent) 522 | elif operation == 2: #deletion 523 | location = self.canvas_location_to_alignment_location(canvas_location) 524 | if not self.type_ids is None and self.type_ids[location] >= 3: 525 | raise ValueError('cannot edit deleted token') 526 | if self.is_match(location): 527 | deleted_token = self.alignment[location][0] 528 | self.alignment[location] = ('', self.alignment[location][1]) 529 | if not self.scores is None: 530 | self.scores[location] = self.baseline_score 531 | if not self.type_ids is None and self.type_ids[location] == 0: #this needs to remain 532 | self.type_ids[location] = -1 533 | replacement_location = self.canvas_location_to_alignment_location( 534 | canvas_location) 535 | self.alignment.insert(replacement_location, (deleted_token, '')) 536 | self.type_ids.insert(replacement_location, agent+3) 537 | if not self.scores is None: 538 | self.scores.insert(replacement_location, self.baseline_score) 539 | elif not self.type_ids is None: self.type_ids[location] = -1 540 | else: 541 | if not self.type_ids is None and self.type_ids[location] == 0: 542 | 
self.type_ids[location] = agent+3 543 | else: 544 | del self.alignment[location] 545 | if not self.scores is None: 546 | del self.scores[location] 547 | if not self.type_ids is None: 548 | del self.type_ids[location] 549 | 550 | if len(self.get_type_ids()) != len(self.get_source_tokens()): 551 | print(canvas_location, operation, token) 552 | print(self.get_type_ids(), self.get_source_tokens()) 553 | print(self) 554 | raise RuntimeError('violated consitency') 555 | 556 | def set_type_ids(self, type_ids): 557 | ''' 558 | Add type ids to alignment 559 | :param type_ids: type ids to add. Assumes they correspond to the type ids of the source tokens 560 | ''' 561 | if len(type_ids) != len(self.get_source_tokens()): 562 | raise ValueError('length of type ids must match length of source tokens in alignment') 563 | self.type_ids = [] 564 | j = 0 565 | for i in range(len(self.alignment)): 566 | if not self.is_insertion(i): 567 | tag = type_ids[j] 568 | j += 1 569 | else: tag = -1 570 | self.type_ids.append(tag) 571 | 572 | def get_type_ids(self): 573 | if self.type_ids is None: return None 574 | else: return [tid for tid in self.type_ids if tid != -1] 575 | 576 | def get_source_tokens(self): 577 | return [self.alignment[i][0] for i in range(len(self.alignment))\ 578 | if not self.is_insertion(i)] 579 | 580 | def get_target_tokens(self): 581 | return [self.alignment[i][1] for i in range(len(self.alignment))\ 582 | if not self.is_deletion(i)] 583 | 584 | def get_source_canvas(self): 585 | source_tokens = self.get_source_tokens() 586 | type_ids = self.get_type_ids() 587 | return Canvas(source_tokens, type_ids) 588 | 589 | def get_target_canvas(self): 590 | target_tokens = self.get_target_tokens() 591 | return Canvas(target_tokens) 592 | 593 | def print_alignment(alignment): 594 | for (a,b) in alignment: 595 | print("{} - {}".format(a,b)) 596 | 597 | def get_frozen_idxs(alignment): 598 | """ 599 | Determine idxs with exact matches. 
Opposite of get_non_const_ops 600 | """ 601 | idxs = [] 602 | i = 0 603 | for tup in alignment: 604 | if tup[0] == tup[1]: 605 | idxs.append(i) 606 | if tup[0] != '': 607 | i += 1 608 | return idxs 609 | 610 | def score_token_idf(token, idf_dict, max_score): 611 | return idf_dict[token] if token in idf_dict else max_score 612 | 613 | def get_non_const_ops(alignment): 614 | return [i for i,op in enumerate(alignment) if op[0] != op[1] and op[2] < 3] 615 | 616 | def get_adjacent_ops(idxs, alignment, adj_range = 1): 617 | """ 618 | Return idxs of adjacent ops in alignment. Adjacent idxs are idxs 619 | that are near a pair of matched tokens (not necessarily an exact match) 620 | """ 621 | def is_adjacent(idx, adj_range): 622 | if idx < adj_range: # "adjacent to start token" 623 | return True 624 | min_lim = max(0, idx - adj_range) 625 | max_lim = min(len(alignment), idx + adj_range + 1) 626 | for i in range(min_lim, max_lim): 627 | if alignment[i][0] != '' and alignment[i][1] != '': # matched tokens? 
628 | return True 629 | return False 630 | ret_idxs = [i for i in idxs if is_adjacent(i, adj_range)] 631 | return ret_idxs 632 | 633 | def operate_canvas(canvas, token_type_ids, loc, operation, token, agent=0): 634 | """ 635 | Apply an edit to a canvas 636 | """ 637 | canvas = list(canvas) 638 | token_type_ids = list(token_type_ids) 639 | if operation == 0: # insertion 640 | canvas = canvas[:loc+1] + [token] + canvas[loc+1:] 641 | token_type_ids = token_type_ids[:loc+1] + [agent+1] + token_type_ids[loc+1:] 642 | elif operation == 1: # substitution 643 | canvas = canvas[:loc] + [token] + canvas[loc+1:] 644 | token_type_ids = token_type_ids[:loc] + [agent+1] + token_type_ids[loc+1:] 645 | elif operation == 2: # deletion 646 | if token_type_ids[loc] != 0: 647 | canvas = canvas[:loc] + canvas[loc+1:] 648 | token_type_ids = token_type_ids[:loc] + token_type_ids[loc+1:] 649 | else: 650 | token_type_ids[loc] = agent + 3 651 | else: 652 | raise ValueError('Invalid Operation: {}'.format(operation)) 653 | return canvas, token_type_ids 654 | 655 | def operate_tagged_alignment(alignment, idxs, agent = 0): 656 | """ 657 | Push alignment forward at given idxs 658 | :param alignment: tagged alignment 659 | """ 660 | alignment = list(alignment) #so it doesn't do anything in place 661 | for i in idxs: 662 | if alignment[i][1] != '': 663 | alignment[i] = (alignment[i][1], alignment[i][1], agent + 1) 664 | elif alignment[i][-1] == 0: 665 | alignment[i] = (alignment[i][0], alignment[i][1], agent + 3) 666 | else: 667 | alignment[i] = (alignment[i][0], alignment[i][1], -2) 668 | alignment = [tup for tup in alignment if tup[-1] >= -1] 669 | return alignment 670 | 671 | def compress_alignment(alignment, alignment_scores): 672 | # TODO: this should assume that the alignment can contain other tags, such as scores 673 | new_alignment, new_alignment_scores = [], [] 674 | for tup, score in zip(alignment, alignment_scores): 675 | if tup[0] == '' and tup[1] == '': continue 676 | 
new_alignment.append(tup) 677 | new_alignment_scores.append(score) 678 | return new_alignment, new_alignment_scores 679 | 680 | def update_idxs(idxs, loc, op, tok): 681 | """ 682 | Update a set of idxs given an action that has been 683 | applied to a canvas 684 | """ 685 | new_idxs = [] 686 | for i in list(idxs): 687 | if i > loc: #unaffected otherwise 688 | if op == 0: 689 | i += 1 690 | elif op == 2: 691 | i -= 1 692 | new_idxs.append(i) 693 | return set(new_idxs) 694 | 695 | def get_actions_from_tagged_alignment(alignment, action_idxs, include_deleted_words=False): 696 | """ 697 | Get actual action tuples for given idxs of an alignment 698 | """ 699 | input_tokens = [] 700 | for i,tup in enumerate(alignment): 701 | input_tokens.append(tup[0]) 702 | 703 | actions = {} 704 | j = -1 705 | k = -1# <- keep track of this for insertions, -1 is sentinel token 706 | for i, (tok, op) in enumerate(zip(input_tokens, alignment)): 707 | if tok != '' and op[2] < 3: 708 | j += 1 709 | k = j 710 | elif tok != '' and op[2] >= 3: 711 | j += 1 712 | if i in action_idxs: 713 | if op[0] == '': #insertion 714 | actions[i] = (k, 0, op[1]) 715 | elif op[1] == '': #deletion 716 | if include_deleted_words: 717 | actions[i] = (j, 2, None, op[0]) 718 | else: 719 | actions[i] = (j, 2, None) 720 | else: # substitution 721 | actions[i] = (j, 1, op[1]) 722 | actions = [actions[i] for i in action_idxs] 723 | return actions 724 | 725 | def get_correct_canvas_positions(alignment): 726 | correct_positions = [] 727 | idx = 0 728 | for p in alignment: 729 | if p[0] == '': continue 730 | if p[0] == p[1]: correct_positions.append(idx) 731 | idx += 1 732 | return correct_positions 733 | 734 | def canvas_to_text(canvas, tokenizer, token_type_ids=None): 735 | if token_type_ids is None: 736 | token_type_ids = [0] * len(canvas) 737 | canvas, token_type_ids = clean_canvas(canvas, token_type_ids) 738 | if len(canvas) == 0: return '' 739 | else: 740 | return 
tokenizer.decode(tokenizer.convert_tokens_to_ids(canvas)) 741 | 742 | def tag_alignment(alignment, token_type_ids): 743 | tagged_alignment = [] 744 | j = 0 745 | for i,tup in enumerate(alignment): 746 | if tup[0] != '': 747 | tag = token_type_ids[j] 748 | j += 1 749 | else: tag = -1 750 | tagged_alignment.append([tup[0], tup[1], tag]) 751 | return tagged_alignment 752 | 753 | def get_tags_from_alignment(alignment): 754 | return [tup[2] for tup in alignment if tup[0] != ''] 755 | 756 | def get_source_canvas(alignment): 757 | source_tokens = [] 758 | for i,tup in enumerate(alignment): 759 | if tup[0] != '': source_tokens.append(tup[0]) 760 | return source_tokens 761 | 762 | def get_target_canvas(alignment): 763 | target_tokens = [] 764 | for i,tup in enumerate(alignment): 765 | if tup[1] != '': target_tokens.append(tup[1]) 766 | return target_tokens 767 | 768 | def clean_canvas(canvas, token_type_ids): 769 | canvas = [c for c,t in zip(canvas, token_type_ids) if t < 3] 770 | token_type_ids = [t for t in token_type_ids if t < 3] 771 | # reset the type ids for model insertions 772 | token_type_ids = [0 if t==1 else t for t in token_type_ids] 773 | return canvas, token_type_ids 774 | 775 | def get_token_type_ids(alignment, idxs): 776 | token_type_ids = [] 777 | for i,tup in enumerate(alignment): 778 | if tup[0] != '': 779 | token_type_ids.append(1 if i in idxs else 0) 780 | return token_type_ids 781 | 782 | class VocabSampler(): 783 | 784 | def __init__(self,vocab_weights, vocab_dict, batch_size=1000): 785 | self.weights = torch.tensor(vocab_weights) 786 | self.dict = vocab_dict 787 | self.batch_size=batch_size 788 | self.samples = torch.multinomial(self.weights, self.batch_size) 789 | self.sample_idx = 0 790 | 791 | def get(self): 792 | if self.sample_idx == self.batch_size: 793 | self.samples = torch.multinomial(self.weights, self.batch_size) 794 | self.sample_idx = 0 795 | self.sample_idx += 1 796 | return self.dict[self.samples[self.sample_idx-1].item()] 797 | 798 | 
#### 799 | #### VVVV Most of this now deperacted, moved to models/word_edit_model.py 800 | 801 | def noise_alignment_(tagged_alignment, vocab_sampler): 802 | positions = [i for i,tup in enumerate(tagged_alignment) if tup[-1] < 3] + [-1] 803 | position = np.random.choice(positions) 804 | if position >= 0 and tagged_alignment[position][-1] == 0 and \ 805 | tagged_alignment[position][0] == tagged_alignment[position][1]: 806 | operation = np.random.randint(0, 2) 807 | else: 808 | operation = 0 809 | 810 | if operation == 0: # insert 811 | # token = vocab_dict[torch.multinomial(torch.tensor(vocab_weights), 1).item()] 812 | token = vocab_sampler.get() 813 | tagged_alignment = tagged_alignment[:position+1] + [[token, '', 1]] + tagged_alignment[position+1:] 814 | elif operation == 1: # delete a token 815 | tup = tagged_alignment[position] 816 | tagged_alignment = tagged_alignment[:position] + [['', tup[1], -1]] + [[tup[0], '', 3]] + tagged_alignment[position+1:] 817 | 818 | return tagged_alignment 819 | 820 | 821 | def noise_alignment(tagged_alignment, vocab_weights, vocab_dict, max_noise=3, noise_frac=0.2): 822 | """ 823 | Noise an alignment for training. This randomly edits the source canvas of the alignment, and updates 824 | the alignment accordingly 825 | """ 826 | canv_len = len([tup[0] for tup in tagged_alignment if tup[-1] < 3]) 827 | n_false_edits = int(np.ceil(noise_frac * canv_len)) 828 | for i in range(n_false_edits): 829 | # while True: 830 | if random.random() > noise_frac: break 831 | tagged_alignment = noise_alignment_(tagged_alignment, vocab_weights, vocab_dict) 832 | return tagged_alignment 833 | 834 | def sample_actions(alignment, tokenizer, token_type_ids, vocab_weights, 835 | vocab_dict, noise_frac=0.0): 836 | """ 837 | Sample a set of actions for training. 
This randomly pushes the alignment forward a few steps 838 | and then returns the set of remaining actions to turn the source into the target 839 | """ 840 | if token_type_ids is None: 841 | token_type_ids = [0 for tup in alignment if tup[0] != ''] 842 | alignment = tag_alignment(alignment, token_type_ids) 843 | 844 | non_const_ops = get_non_const_ops(alignment) 845 | # print(non_const_ops) 846 | ops_length = len(non_const_ops) + 1 # +1 operation for stopping 847 | 848 | 849 | #sample length 850 | length = np.random.randint(0,ops_length) # allow sampling any state 851 | 852 | #sample operations 853 | rand_ops = np.random.permutation(non_const_ops) 854 | operations = rand_ops[:length] 855 | # return_ops = rand_ops[length:] 856 | 857 | alignment = operate_tagged_alignment(alignment, operations, agent=0) 858 | alignment = noise_alignment(alignment, vocab_weights, vocab_dict, noise_frac=noise_frac) 859 | return_ops = get_non_const_ops(alignment) 860 | 861 | actions = get_actions_from_tagged_alignment(alignment, return_ops) 862 | if len(actions) == 0: actions = [(1, None, None, None)] 863 | else: actions = [(0, a[0], a[1], tokenizer.convert_tokens_to_ids(a[2])) for a in actions] 864 | 865 | canvas = get_source_canvas(alignment) 866 | token_type_ids = get_tags_from_alignment(alignment) 867 | 868 | return canvas, token_type_ids, actions, ops_length 869 | 870 | def sample_trajectory(alignment, tokenizer, token_type_ids, vocab_sampler, 871 | noise_prob=0.2, max_traj_length=64): 872 | """ 873 | A more principled approach to noising the canvas. Randomly push the alignment forward 874 | while randomly making errors (noise). 
875 | """ 876 | 877 | alignment = tag_alignment(alignment, token_type_ids) 878 | ops = get_non_const_ops(alignment) 879 | ops = np.random.permutation(ops) 880 | ops_len = len(ops) 881 | n_errors = 0 882 | n_ops = 0 883 | trajectory = [] 884 | trajectory.append((n_errors, n_ops)) 885 | for i in range(max_traj_length): 886 | if random.random() > noise_prob: 887 | if n_ops == ops_len and n_errors == 0: break 888 | if random.random() <= n_errors / (n_errors + ops_len - n_ops): 889 | n_errors -= 1 890 | else: n_ops += 1 891 | else: 892 | n_errors += 1 893 | trajectory.append((n_errors, n_ops)) 894 | traj_length = len(trajectory) 895 | traj_idx = np.random.randint(0,traj_length) # allow sampling any state 896 | n_errors, n_ops = trajectory[traj_idx] 897 | alignment = operate_tagged_alignment(alignment, ops[:n_ops]) 898 | # print(alignment) 899 | for i in range(n_errors): 900 | alignment = noise_alignment_(alignment, vocab_sampler) 901 | return_ops = get_non_const_ops(alignment) 902 | actions = get_actions_from_tagged_alignment(alignment, return_ops) 903 | if len(actions) == 0: actions = [(1, None, None, None)] 904 | else: actions = [(0, a[0], a[1], tokenizer.convert_tokens_to_ids(a[2])) for a in actions] 905 | 906 | canvas = get_source_canvas(alignment) 907 | token_type_ids = get_tags_from_alignment(alignment) 908 | 909 | return canvas, token_type_ids, actions, traj_length 910 | 911 | 912 | -------------------------------------------------------------------------------- /infosol/decoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from copy import deepcopy 3 | from queue import PriorityQueue 4 | 5 | class ParallelDecodingMixin(): 6 | """ 7 | Mixin for parallel decoding 8 | """ 9 | 10 | def forward_canvases(self, canvases, device=torch.device('cpu'), move_to_cpu=True): 11 | raise NotImplemented 12 | 13 | @staticmethod 14 | def canvas_len(canvas): 15 | raise NotImplemented 16 | 17 | def decode_loop(self, 
canvases, decode_func, state=None, max_batch_tokens=2048, device=torch.device('cpu'), queue_size=2000, return_idx=False): 18 | 19 | iter_idx = 0 20 | finished = False 21 | input_queue = PriorityQueue(maxsize=queue_size) 22 | input_iter = iter(canvases) 23 | start_state = state 24 | 25 | def enqueue(canvas, state, iter_idx): 26 | input_queue.put((canvas, iter_idx, state)) 27 | 28 | def add_input(idx): 29 | c = next(input_iter, None) 30 | if c is None: 31 | return True 32 | else: 33 | enqueue(deepcopy(c), start_state, idx) 34 | return False 35 | 36 | def get_batch(): 37 | batch = [] 38 | batch_len = 0 39 | batch_n_tokens = 0 40 | while True: 41 | if input_queue.empty(): break 42 | # get input 43 | inp = input_queue.get() 44 | canvas, idx, state = inp 45 | n_tokens = self.canvas_len(canvas) 46 | # check if can fit in batch 47 | batch_len_ = batch_len + 1 48 | batch_n_tokens_ = max(batch_n_tokens, n_tokens) 49 | # if doesn't fit, requeue and break 50 | if batch_len_ * batch_n_tokens_ > max_batch_tokens: 51 | enqueue(canvas, state, idx) # todo: swith this order of arguments, it's confusing 52 | break 53 | else: 54 | batch.append(inp) 55 | batch_len = batch_len_ 56 | batch_n_tokens = batch_n_tokens_ 57 | batch_size = len(batch) 58 | if batch_size == 0: 59 | raise RuntimeError('Unable to fit any inputs into batch. 
Try increasing max batch tokens') 60 | return batch 61 | 62 | # start by filling up the queue 63 | while not input_queue.full(): 64 | finished = add_input(iter_idx) 65 | iter_idx += 1 66 | if finished: break 67 | 68 | while not input_queue.empty(): 69 | batch = get_batch() 70 | batch_size = len(batch) 71 | 72 | canvases = [b[0] for b in batch] 73 | model_out = self.forward_canvases(canvases, device=device, move_to_cpu=True) 74 | for i in range(batch_size): 75 | m_out, (canvas, idx, state) = model_out[i], batch[i] 76 | canvas, state, stop = decode_func(m_out, canvas, state) 77 | 78 | if stop: 79 | if return_idx: 80 | yield canvas, idx 81 | else: 82 | yield canvas 83 | if not finished: 84 | finished = add_input(iter_idx) 85 | iter_idx += 1 86 | else: 87 | enqueue(canvas, state, idx) 88 | return 89 | -------------------------------------------------------------------------------- /infosol/env.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | import torch.nn.functional as F 4 | import random 5 | import numpy as np 6 | import torch.nn.functional as F 7 | 8 | from transformers import BertModel, AutoTokenizer, T5ForConditionalGeneration, BertConfig 9 | from nltk.translate.bleu_score import sentence_bleu 10 | 11 | from infosol.alignment import batch_align, align_canvas, score_token_idf, Canvas, batch_align_canvases 12 | 13 | class WordEditOracle(torch.nn.Module): 14 | """ 15 | Oracle Model 16 | """ 17 | 18 | def __init__(self, align_model, align_tokenizer, idf_dict, adjacent_ops=False, sort_ops='sort', 19 | baseline_score=0.3, adj_range=1, min_alignment_score=0.7, 20 | avoid_delete=True, complete_words=False, token_no_space=None, 21 | contiguous_edits=False, 22 | # n_contiguous_edits=1 23 | ): 24 | """ 25 | :param idf_dict: dict of idf scores for ranking oracle actions 26 | :param adjacent_ops: whether to limit the model to making adjacent edits 27 | :param sort_ops: whether to sort the actions 
returned by the oracle, sorting is done according to idf scores. This is a string (!), can be 'sort', 'random', or 'l2r' (left2right) 28 | :param n_return_actions: number of actions the oracle returns. 0 means return all actions 29 | :param baseline_score: baseline score for computing alignment, 0.3 empirically gives "natural" alignments 30 | :param adj_range: adjacency range for limiting oracle to adjacent edits 31 | :param avoid_delete: whether to avoid deletion edits 32 | """ 33 | super().__init__() 34 | self.idf_dict = idf_dict 35 | # self.n_contiguous_edits = n_contiguous_edits 36 | self.contiguous_edits = contiguous_edits 37 | self.max_idf_score = np.max(list(self.idf_dict.values())) 38 | self.adjacent_ops = adjacent_ops 39 | self.adj_range = adj_range 40 | self.complete_words = complete_words 41 | self.token_no_space = token_no_space 42 | if sort_ops not in ('sort', 'random', 'l2r'): 43 | raise ValueError('unrecognized value for sort_ops: {}'.format(sort_ops)) 44 | self.sort_ops = sort_ops 45 | self.baseline_score = baseline_score 46 | self.avoid_delete = avoid_delete 47 | 48 | self.align_model = align_model 49 | self.align_tokenizer = align_tokenizer 50 | 51 | def edit(self, canvas, target, alignment=None, return_alignment=False, 52 | device=torch.device('cpu'), k=1): 53 | """ 54 | Make an edit 55 | """ 56 | if not alignment is None: 57 | if canvas != alignment.get_source_canvas(): 58 | raise ValueError('canvas does not match alignment') 59 | if target != alignment.get_target_tokens(): 60 | raise ValueError('target does not match alignment') 61 | else: 62 | alignment = align_canvas(canvas, target, self.align_model, 63 | self.align_tokenizer, baseline_score=self.baseline_score, 64 | device=device) 65 | self.edit_alignment(alignment, k=k) 66 | canvas = alignment.get_source_canvas() 67 | 68 | ret_tup = canvas 69 | if return_alignment: ret_tup = (canvas, alignment) 70 | return ret_tup 71 | 72 | def edit_alignment(self, alignment, k=1): 73 | """ 74 | Make an 
edit to an alignment 75 | """ 76 | op_idxs = self.get_action_idxs(alignment) 77 | op_idxs = self.sort_action_idxs(op_idxs, alignment) 78 | op_idxs = self.choose_action_idxs(op_idxs, alignment, k) 79 | op_idxs = list(set(op_idxs)) 80 | alignment.push_forward(op_idxs, agent=1) 81 | return len(op_idxs) > 0 82 | # 83 | # if len(op_idxs) > 0: 84 | # idxs_to_push = [op_idxs[0]] 85 | # for op in op_idxs[1:]: 86 | # if len(idxs_to_push) >= self.n_contiguous_edits: break 87 | # else: 88 | # cond1 = op == max(idxs_to_push) + 1 89 | # cond2 = op == min(idxs_to_push) - 1 90 | # if cond1 or cond2: 91 | # idxs_to_push.append(op) 92 | # if not self.complete_words: 93 | # alignment.push_forward(idxs_to_push, agent=1) 94 | # else: 95 | # idxs_to_push = list(set(idxs_to_push)) 96 | # idxs_to_push_ = [] 97 | # for op in idxs_to_push: 98 | # for case, rel_idx in zip( 99 | # (alignment.is_insertion, alignment.is_deletion), 100 | # (1,0)): 101 | # if case(op): 102 | # for j in range(op, 0, -1): 103 | # if case(j-1)\ 104 | # and self.token_no_space(alignment.alignment[j][rel_idx]): 105 | # # and alignment.alignment[j][rel_idx].startswith('#'): 106 | # idxs_to_push_.append(j-1) 107 | # else: break 108 | # for j in range(op+1, len(alignment)): 109 | # if case(j)\ 110 | # and self.token_no_space(alignment.alignment[j][rel_idx]): 111 | # # and alignment.alignment[j][rel_idx].startswith('#'): 112 | # idxs_to_push_.append(j) 113 | # else: break 114 | # # print(idxs_to_push) 115 | # idxs_to_push = list(set(idxs_to_push + idxs_to_push_)) 116 | # alignment.push_forward(list(idxs_to_push), agent=1) 117 | # return True 118 | # else: return False 119 | 120 | def get_action_idxs(self, alignment): 121 | """ 122 | Choose actions from alignment 123 | """ 124 | 125 | # retrieve idxs in the alignment that correspond to an edit 126 | op_idxs = alignment.get_non_const_ops() 127 | if len(op_idxs) == 0: return [] 128 | 129 | # if only allowing adjacent edits, limit idxs to adjacent idxs 130 | if 
self.adjacent_ops: 131 | op_idxs = alignment.get_adjacent_ops(op_idxs, adj_range=self.adj_range) 132 | return op_idxs 133 | 134 | def sort_action_idxs(self, op_idxs, alignment): 135 | """ 136 | Sort possible actions, based on alignment scores and idf scores 137 | """ 138 | if self.sort_ops == 'sort': 139 | # score using idf score and alignment score 140 | def score_op_idx(idx, alignment, max_idf_score): 141 | score = 0 142 | if alignment.is_deletion(idx): #deletions get treated specially 143 | if self.avoid_delete: 144 | score = 0 145 | else: # note that the alignment score for a deletion is the baseline score, usually 0.3 146 | # should have a more principled way of scoring deletions? 147 | score = alignment.scores[idx] * score_token_idf( 148 | alignment.alignment[idx][0], self.idf_dict, max_idf_score) 149 | else: # For insertions, the alignment score is the baseline score, usually 0.3 150 | score = (1-alignment.scores[idx]) * score_token_idf( 151 | alignment.alignment[idx][1], self.idf_dict, max_idf_score) 152 | return score 153 | idx_scores = [score_op_idx(i, alignment, self.max_idf_score) for i in op_idxs] 154 | op_idxs = list(np.asarray(op_idxs)[np.argsort(idx_scores)[::-1]]) 155 | elif self.sort_ops == 'random': 156 | op_idxs = [op_idxs[i] for i in np.random.choice( 157 | np.arange(len(op_idxs)), len(op_idxs), replace=False)] 158 | # if going left to right, we still want to prioritize a set of "misaligned" words 159 | elif self.sort_ops == 'l2r': 160 | priority_ops = [i for i in op_idxs if alignment_scores[i] < 0.8] 161 | if len(priority_ops) > 0: 162 | op_idxs = priority_ops 163 | return op_idxs 164 | 165 | def choose_action_idxs(self, op_idxs, alignment, k): 166 | 167 | """ 168 | Choose which actions to take based on some heuristics 169 | """ 170 | 171 | def complete_word(op): 172 | """ 173 | Completes actions so they operate on whole words, not tokens 174 | """ 175 | complete_idxs = [op] 176 | for case, rel_idx in zip( 177 | (alignment.is_addition, 
alignment.is_deletion), 178 | (1,0) 179 | ): 180 | try: 181 | if case(op): 182 | for j in range(op, 0, -1): #search backwards for word boundary 183 | if case(j-1) and self.token_no_space(alignment.alignment[j][rel_idx]): 184 | complete_idxs.append(j-1) 185 | else: break 186 | for j in range(op+1, len(alignment)): #search forwards for word boundary 187 | if case(j) and self.token_no_space(alignment.alignment[j][rel_idx]): 188 | complete_idxs.append(j) 189 | else: break 190 | except TypeError as e: 191 | print(alignment) 192 | raise e 193 | return complete_idxs 194 | 195 | def add_op(op, l): 196 | if self.complete_words: 197 | l.extend(complete_word(op)) 198 | else: 199 | l.append(op) 200 | return l 201 | 202 | idxs_to_push = [] 203 | for i,op in enumerate(op_idxs): 204 | if len(idxs_to_push) >= k: break 205 | if op in idxs_to_push: continue 206 | 207 | if not self.contiguous_edits: 208 | idxs_to_push = add_op(op, idxs_to_push) 209 | else: 210 | # start building set of contiguous edits 211 | contiguous_idxs = add_op(op, []) 212 | for opp in op_idxs[i+1:]: 213 | if opp in idxs_to_push: continue 214 | if opp in contiguous_idxs: continue 215 | if len(idxs_to_push) + len(contiguous_idxs) >= k: break 216 | 217 | cond1 = opp == max(contiguous_idxs) + 1 218 | cond2 = opp == min(contiguous_idxs) - 1 219 | if cond1 or cond2: 220 | contiguous_idxs = add_op(opp, contiguous_idxs) 221 | idxs_to_push.extend(contiguous_idxs) 222 | 223 | return idxs_to_push 224 | 225 | # if len(op_idxs) > 0: 226 | # idxs_to_push = [op_idxs[0]] 227 | # for op in op_idxs[1:]: 228 | # if len(idxs_to_push) >= self.n_contiguous_edits: break 229 | # else: 230 | # cond1 = op == max(idxs_to_push) + 1 231 | # cond2 = op == min(idxs_to_push) - 1 232 | # if cond1 or cond2: 233 | # idxs_to_push.append(op) 234 | # if not self.complete_words: 235 | # alignment.push_forward(idxs_to_push, agent=1) 236 | # else: 237 | # idxs_to_push = list(set(idxs_to_push)) 238 | # idxs_to_push_ = [] 239 | # for op in idxs_to_push: 
240 | # for case, rel_idx in zip( 241 | # (alignment.is_insertion, alignment.is_deletion), 242 | # (1,0)): 243 | # if case(op): 244 | # for j in range(op, 0, -1): 245 | # if case(j-1)\ 246 | # and self.token_no_space(alignment.alignment[j][rel_idx]): 247 | # # and alignment.alignment[j][rel_idx].startswith('#'): 248 | # idxs_to_push_.append(j-1) 249 | # else: break 250 | # for j in range(op+1, len(alignment)): 251 | # if case(j)\ 252 | # and self.token_no_space(alignment.alignment[j][rel_idx]): 253 | # # and alignment.alignment[j][rel_idx].startswith('#'): 254 | # idxs_to_push_.append(j) 255 | # else: break 256 | # # print(idxs_to_push) 257 | # idxs_to_push = list(set(idxs_to_push + idxs_to_push_)) 258 | # alignment.push_forward(list(idxs_to_push), agent=1) 259 | # 260 | 261 | class EditingEnvironment(): 262 | 263 | def __init__(self, oracle, oracle_stop_p = 0.5, n_oracle_edits = -1, bleu_weights=(0.25,0.25,0.25,0.25)): 264 | """ 265 | :param oracle_stop_p: the oracle makes edits until it stops with probability oracle_stop_p 266 | :param bleu_weights: used for the reward, not important since we don't use rewards 267 | """ 268 | self.oracle = oracle 269 | self.oracle_stop_p = oracle_stop_p 270 | self.n_oracle_edits = n_oracle_edits 271 | self.bleu_weights=bleu_weights 272 | 273 | self.alignment = None 274 | 275 | def check_input_(self, canvas=None, target_tokens=None, alignment=None, device=torch.device('cpu')): 276 | if all((alignment is None, canvas is None, target_tokens is None)): 277 | raise ValueError('alignment, canvas and target cannot all be None') 278 | elif alignment is None and (canvas is None or target_tokens is None): 279 | raise ValueError('canvas and target cannot both be None') 280 | 281 | if not alignment is None: 282 | if canvas is None: 283 | canvas = alignment.get_source_canvas() 284 | if target_tokens is None: 285 | target_tokens = alignment.get_target_tokens() 286 | 287 | if canvas != alignment.get_source_canvas(): 288 | raise 
ValueError('canvas does not match alignment') 289 | if target_tokens != alignment.get_target_tokens(): 290 | raise ValueError('target does not match alignment') 291 | alignment = copy.deepcopy(alignment) 292 | else: 293 | alignment = align_canvas(canvas, target_tokens, self.oracle.align_model, 294 | self.oracle.align_tokenizer, baseline_score=self.oracle.baseline_score, 295 | device=device) 296 | return canvas, target_tokens, alignment 297 | 298 | def oracle_edit(self, canvas=None, target_tokens=None, alignment=None, device=torch.device('cpu'), return_alignment=False): 299 | canvas, target_tokens, alignment = self.check_input_(canvas, target_tokens, alignment, device=device) 300 | """ 301 | Let the oracle make an edit 302 | """ 303 | if self.n_oracle_edits >= 0: 304 | self.oracle.edit_alignment(alignment, k=self.n_oracle_edits) 305 | # for i in range(self.n_oracle_edits): 306 | # self.oracle.edit_alignment(alignment) 307 | else: 308 | k = np.random.geometric(self.oracle_stop_p) 309 | self.oracle.edit_alignment(alignment, k=k) 310 | # while True: 311 | # self.oracle.edit_alignment(alignment) 312 | # if np.random.random() <= self.oracle_stop_p: 313 | # break 314 | if return_alignment: 315 | return alignment 316 | else: 317 | return alignment.get_source_canvas() 318 | 319 | def reset(self, target_tokens=None, source_canvas=None, alignment=None, device=torch.device('cpu'), 320 | return_alignment=False): 321 | """ 322 | Reset the environment. 
Can either specify the target and source or the alignment (saves the cost of 323 | aligning the source to the target 324 | """ 325 | self.alignment = self.oracle_edit(source_canvas, target_tokens, alignment, return_alignment=True) 326 | if return_alignment: 327 | return copy.deepcopy(self.alignment) 328 | else: 329 | return copy.deepcopy(self.alignment.get_source_canvas()) 330 | 331 | def compute_reward(self, prev_canvas, agent_canvas, target_tokens): 332 | tokenizer = self.oracle.align_tokenizer 333 | 334 | prev_text = prev_canvas.render(tokenizer) 335 | target_text = Canvas(target_tokens).render(tokenizer) 336 | agent_text = agent_canvas.render(tokenizer) 337 | 338 | prev_bleu_score = sentence_bleu([target_text], 339 | prev_text, weights=self.bleu_weights) 340 | agent_bleu_score = sentence_bleu([target_text], 341 | agent_text, weights=self.bleu_weights) 342 | delta = agent_bleu_score - prev_bleu_score 343 | return delta, agent_bleu_score, prev_bleu_score 344 | 345 | def step(self, canvas, return_alignment=False, device=torch.device('cpu')): 346 | """ 347 | One step of the environment. Consists of oracle edit, computing rewards etc. 
348 | """ 349 | prev_canvas, target_tokens = self.alignment.get_source_canvas(), self.alignment.get_target_tokens() 350 | reward, agent_score, _ = self.compute_reward(prev_canvas, canvas, target_tokens) 351 | 352 | self.alignment = self.oracle_edit(canvas=canvas, target_tokens=target_tokens, device=device, return_alignment=True) 353 | if return_alignment: 354 | return copy.deepcopy(self.alignment), reward 355 | else: 356 | return copy.deepcopy(self.alignment.get_source_canvas()), reward 357 | -------------------------------------------------------------------------------- /infosol/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import itertools 4 | import json 5 | import torch 6 | import tqdm 7 | 8 | from distutils.util import strtobool 9 | from nltk.translate.bleu_score import sentence_bleu 10 | 11 | from transformers import AutoTokenizer, BartModel, BertModel 12 | from datasets import load_from_disk 13 | 14 | from infosol.alignment import Alignment, Canvas, batch_align_canvases 15 | from infosol.models.word_edit_model import BertEditor, BartS2SEditor 16 | from infosol.env import WordEditOracle, EditingEnvironment 17 | 18 | class EvaluationInstance(): 19 | 20 | def __init__(self, datum): 21 | self.done = False 22 | self.rewards = [] 23 | self.history = [] 24 | self.actions = [] 25 | self.consistency = [] 26 | self.bleu_scores = [] 27 | self.canvas_history = [] 28 | self.canvas_actions = [] 29 | self.target_tokens = datum['target_tokens'] 30 | self.canvas = Canvas(datum['source_tokens'], datum['token_type_ids']) 31 | self.source_text = datum['source_text'] 32 | self.target_text = datum['target_text'] 33 | 34 | def to_dict(self): 35 | return { 36 | 'source': self.source_text, 37 | 'target': self.target_text, 38 | 'history': self.history, 39 | 'rewards': self.rewards, 40 | 'actions': self.actions, 41 | 'consistency': self.consistency, 42 | 'bleu_scores': self.bleu_scores, 43 | 
'canvas history': self.canvas_history, 44 | 'canvas actions': self.canvas_actions 45 | } 46 | 47 | 48 | class Evaluate(): 49 | 50 | """ 51 | Evaluation job 52 | """ 53 | 54 | @classmethod 55 | def add_args(cls, parser): 56 | parser.add_argument('--model_path', type=str, 57 | default='') #/path/to/WEIGHTS.bin) 58 | parser.add_argument('--out_path', type=str, 59 | default='') #/path/to/save_file.pickle) 60 | parser.add_argument('--cuda_device', type=int, default=0) 61 | parser.add_argument('--data_path', type=str, default='') #/path/to/data/) 62 | parser.add_argument('--max_data', type=int, default=1000) 63 | parser.add_argument('--idf_path', type=str, default='') #/path/to/idf.pickle) 64 | parser.add_argument('--n_oracle_edits', type=int, default=3) 65 | parser.add_argument('--oracle_stop_p', type=float, default=0.5) 66 | parser.add_argument('--n_episodes', type=int, default=4) 67 | parser.add_argument('--BLEU_threshold', type=float, default=1.0) 68 | parser.add_argument('--adjacent_ops', type=lambda x:bool(strtobool(x)), default=True) 69 | parser.add_argument('--complete_words', type=lambda x:bool(strtobool(x)), default=True) 70 | parser.add_argument('--contiguous_edits', type=lambda x:bool(strtobool(x)), default=True) 71 | parser.add_argument('--baseline_alignment_score', type=float, default=0.3) 72 | parser.add_argument('--sort_ops', type=str, default='sort') 73 | parser.add_argument('--bleu_ngrams', type=int, default=1) 74 | parser.add_argument('--keyword_gen', type=lambda x:bool(strtobool(x)), default=False) 75 | 76 | def setup(self, args): 77 | self.args = args 78 | self.device = torch.device('cuda:{}'.format(args.cuda_device)) 79 | print('Loading Environment') 80 | self.tokenizer, self.align_model, self.oracle, self.env = self.setup_env(args) 81 | print('Loading Model') 82 | self.model = self.setup_model(args) 83 | 84 | def setup_model(self, args): 85 | raise NotImplementedError 86 | 87 | def setup_env(self, args): 88 | raise NotImplementedError 89 | 90 | def 
create_instance(self, datum): 91 | inst = EvaluationInstance(datum) 92 | alignment = Alignment(alignment=datum['alignment'], scores=datum['alignment_scores']) 93 | inst.canvas = self.env.reset(alignment=alignment) 94 | return inst 95 | 96 | def step_instance(self, inst, gen, alignment): 97 | """ 98 | Advance instance by one steps. This includes letting the oracle make changes, and tracking metrics etc. 99 | """ 100 | prev_canvas, target_tokens = inst.canvas, inst.target_tokens 101 | agent_canvas = gen 102 | # agent_canvas, target_tokens_ = alignment.get_source_canvas(), alignment.get_target_tokens() 103 | # assert target_tokens == target_tokens_, 'mismatching target tokens {target_tokens} vs. {target_tokens_}' 104 | 105 | reward, agent_score, prev_score = self.env.compute_reward(prev_canvas, agent_canvas, target_tokens) 106 | 107 | if not self.args.keyword_gen: 108 | inst.canvas = self.env.oracle_edit(alignment=alignment) 109 | else: 110 | oracle_canvas = self.env.oracle_edit(alignment=alignment) 111 | tokens_, type_ids_ = oracle_canvas.tokens, oracle_canvas.type_ids 112 | keywords = [t for t,tid in zip(tokens_, type_ids_) if tid == 2] 113 | kw_type_ids = [2] * len(keywords) 114 | inst.canvas = Canvas(keywords, type_ids=kw_type_ids) 115 | 116 | tokenizer = self.env.oracle.align_tokenizer 117 | inst.history.append(prev_canvas.render(tokenizer)) 118 | inst.canvas_history.append(prev_canvas.copy()) 119 | inst.bleu_scores.append((prev_score, agent_score)) 120 | inst.actions.append(agent_canvas.render(tokenizer)) 121 | inst.canvas_actions.append(agent_canvas.copy()) 122 | inst.rewards.append(reward) 123 | 124 | prev_text = prev_canvas.render(tokenizer) 125 | agent_text = agent_canvas.render(tokenizer) 126 | consistency = sentence_bleu([agent_text], prev_text, weights=(1,0,0,0)) 127 | inst.consistency.append(consistency) 128 | 129 | if len(inst.history) >= self.args.n_episodes: 130 | inst.done = True 131 | if max(inst.bleu_scores[-1]) >= self.args.BLEU_threshold: 132 | 
inst.done = True 133 | 134 | return inst 135 | 136 | def episode(self, instances): 137 | """ 138 | One editing episode. Model makes changes, align the model outputs to the targets, let the oracle make changes. 139 | Note the oracle makes changes when initializing the instances, that's why the order here is model then oracle 140 | instead of oracle then model as described in the paper 141 | """ 142 | canvases = [inst.canvas for inst in instances] 143 | generations = self.gen_model(canvases) 144 | clean_generations = [g.clean() for g in generations] 145 | align_model, align_tokenizer = self.env.oracle.align_model, self.env.oracle.align_tokenizer 146 | targets = [inst.target_tokens for inst in instances] 147 | alignments = batch_align_canvases(clean_generations, targets, align_model, align_tokenizer, device=self.device) 148 | # return [self.step_instance(inst, gen, a) for inst,gen,a in zip(instances, generations, alignments)] 149 | instances = [self.step_instance(inst, gen, a) for inst,gen,a in zip(instances, generations, alignments)] 150 | finished_instances = [inst for inst in instances if inst.done] 151 | instances = [inst for inst in instances if not inst.done] 152 | return instances, finished_instances 153 | 154 | def evaluate_instance(self, datum): 155 | rewards = [] 156 | history = [] 157 | actions = [] 158 | consistency = [] 159 | bleu_scores = [] 160 | canvas_history = [] 161 | canvas_actions = [] 162 | canvas = self.env.reset(alignment=Alignment( 163 | alignment=datum['alignment'], 164 | scores=datum['alignment_scores'])) 165 | for e in range(self.args.n_episodes): 166 | input_text = canvas.render(self.tokenizer) 167 | target_text = self.env.alignment.get_target_canvas().render(self.tokenizer) 168 | history.append(input_text) 169 | canvas_history.append(canvas.copy()) 170 | inp_score = sentence_bleu([target_text], input_text, weights=self.env.bleu_weights) 171 | 172 | canvas = self.gen_model(canvas) 173 | 174 | output_text = canvas.render(self.tokenizer) 175 
| out_score = sentence_bleu([target_text], output_text, weights=self.env.bleu_weights) 176 | bleu_scores.append((inp_score, out_score)) 177 | actions.append(output_text) 178 | canvas_actions.append(canvas.copy()) 179 | 180 | canvas = canvas.clean() 181 | canvas, r = self.env.step(canvas, device=self.device) 182 | rewards.append(r) 183 | consistency.append(sentence_bleu([output_text], input_text, weights=(1,0,0,0))) 184 | return { 185 | 'source': datum['source_text'], 186 | 'target': datum['target_text'], 187 | # 'target': self.env.target_text, 188 | 'history': history, 189 | 'rewards': rewards, 190 | 'actions': actions, 191 | 'consistency': consistency, 192 | 'bleu_scores': bleu_scores, 193 | 'canvas history': canvas_history, 194 | 'canvas actions': canvas_actions 195 | } 196 | 197 | def gen_model(self, canvases): 198 | raise NotImplementedError 199 | 200 | # def gen_user(self, prev_canvases, canvases, targets, device=torch.device('cpu')): 201 | # batch_size = 64 202 | # user_canvases, rewards = [], [] 203 | # def batch(lst, n=1): 204 | # l = len(lst) 205 | # for ndx in range(0, l, n): 206 | # yield lst[ndx:min(ndx + n, l)] 207 | # for pc, c, t in tqdm.tqdm(zip( 208 | # batch(prev_canvases, batch_size), batch(canvases, batch_size), 209 | # batch(targets, batch_size))): 210 | # _, user_c, r = self.env.step_(pc, c, t, device) 211 | # user_canvases.extend(user_c) 212 | # rewards.extend(r) 213 | # return zip(user_canvases, rewards) 214 | 215 | def run(self): 216 | self.model = self.model.eval().to(self.device) 217 | self.oracle = self.oracle.to(self.device) 218 | 219 | data = load_from_disk(self.args.data_path)['test'] 220 | instances = [self.create_instance(datum) for datum in itertools.islice(data, 0, self.args.max_data)] 221 | finished_instances = [] 222 | while len(instances) > 0: 223 | instances, finst_ = self.episode(instances) 224 | finished_instances.extend(finst_) 225 | # for e in range(self.args.n_episodes): 226 | # instances = self.episode(instances) 227 | # 
results = [inst.to_dict() for inst in instances] 228 | results = [inst.to_dict() for inst in finished_instances] 229 | # for datum in tqdm.tqdm(itertools.islice(data,0,self.args.max_data), 230 | # total=self.args.max_data): 231 | # res = self.evaluate_instance(datum) 232 | # results.append(res) 233 | with open(self.args.out_path, 'wb') as f: 234 | pickle.dump(results, f) 235 | 236 | class EvaluateEditor(Evaluate): 237 | 238 | @classmethod 239 | def add_args(cls, parser): 240 | super().add_args(parser) 241 | parser.add_argument('--top_k', type=int, default=10) 242 | parser.add_argument('--stop_threshold', type=float, default=0.95) 243 | 244 | def gen_model(self, canvases): 245 | canvases = list(tqdm.tqdm(self.model.batch_depth_decode( 246 | canvases, 247 | device=self.device, 248 | top_k=self.args.top_k, 249 | stop_threshold=self.args.stop_threshold, 250 | return_idx=True 251 | ), total=len(canvases))) 252 | canvases = [(i,c) for c,i in canvases] 253 | canvases = [c for i,c in sorted(canvases)] 254 | return canvases 255 | 256 | class EvaluateS2S(Evaluate): 257 | 258 | @classmethod 259 | def add_args(cls, parser): 260 | super().add_args(parser) 261 | parser.add_argument('--do_sample', type=lambda x:bool(strtobool(x)), default=False) 262 | parser.add_argument('--top_p', type=float, default=0.95) 263 | parser.add_argument('--max_length', type=int, default=64) 264 | parser.add_argument('--num_beams', type=int, default=10) 265 | parser.add_argument('--length_penalty', type=float, default=1.0) 266 | 267 | def gen_model(self, canvases): 268 | return self.model.batch_generate(canvases, device=self.device, 269 | do_sample=self.args.do_sample, top_p=self.args.top_p, max_length=self.args.max_length, 270 | num_beams=self.args.num_beams, length_penalty=self.args.length_penalty) 271 | 272 | class EvaluateBart(Evaluate): 273 | 274 | def setup_env(self, args): 275 | tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base') 276 | align_model = 
BartModel.from_pretrained('facebook/bart-base').encoder 277 | with open(self.args.idf_path, 'rb') as f: 278 | idf_dict = pickle.load(f) 279 | oracle = WordEditOracle(align_model, tokenizer, idf_dict=idf_dict, 280 | sort_ops='sort', adjacent_ops=args.adjacent_ops, 281 | baseline_score=args.baseline_alignment_score, 282 | avoid_delete=False, complete_words=args.complete_words, 283 | token_no_space=lambda x: not x.startswith('Ġ'), 284 | contiguous_edits=args.contiguous_edits) 285 | bleu_weights = [1/args.bleu_ngrams if i < args.bleu_ngrams else 0 for i in range(4)] 286 | env = EditingEnvironment(oracle, n_oracle_edits=args.n_oracle_edits, 287 | oracle_stop_p=args.oracle_stop_p, 288 | bleu_weights = bleu_weights) 289 | 290 | return tokenizer, align_model, oracle, env 291 | 292 | class EvaluateBert(Evaluate): 293 | 294 | def setup_env(self, args): 295 | tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') 296 | align_model = BertModel.from_pretrained('bert-base-uncased') 297 | with open(self.args.idf_path, 'rb') as f: 298 | idf_dict = pickle.load(f) 299 | oracle = WordEditOracle(align_model, tokenizer, idf_dict=idf_dict, 300 | sort_ops='sort', adjacent_ops=args.adjacent_ops, 301 | baseline_score=args.baseline_alignment_score, 302 | avoid_delete=False, complete_words=args.complete_words, 303 | token_no_space=lambda x: x.startswith('#'), 304 | contiguous_edits=args.contiguous_edits) 305 | bleu_weights = [1/args.bleu_ngrams if i < args.bleu_ngrams else 0 for i in range(4)] 306 | env = EditingEnvironment(oracle, n_oracle_edits=args.n_oracle_edits, 307 | oracle_stop_p=args.oracle_stop_p, 308 | bleu_weights = bleu_weights) 309 | 310 | return tokenizer, align_model, oracle, env 311 | 312 | class EvaluateBartEditor(EvaluateEditor, EvaluateBart): 313 | 314 | def setup_model(self, args): 315 | model = BertEditor(model_type='bart', 316 | tokenizer=self.tokenizer, 317 | model_file='facebook/bart-base') 318 | model.load_state_dict(torch.load(self.args.model_path)) 319 | 
return model 320 | 321 | class EvaluateBertEditor(EvaluateEditor, EvaluateBert): 322 | 323 | def setup_model(self, args): 324 | model = BertEditor(model_type='bert', 325 | tokenizer=self.tokenizer) 326 | model.load_state_dict(torch.load(self.args.model_path)) 327 | return model 328 | 329 | 330 | class EvaluateBartS2S(EvaluateS2S, EvaluateBart): 331 | 332 | def setup_model(self, args): 333 | model = BartS2SEditor(self.align_model, self.tokenizer) 334 | model.load_state_dict(torch.load(self.args.model_path)) 335 | return model 336 | 337 | class EvaluateBartLarge(EvaluateEditor, EvaluateBart): 338 | 339 | def setup_model(self, args): 340 | model = BertEditor(model_type='bart-large', 341 | tokenizer=self.tokenizer, 342 | model_file='facebook/bart-large') 343 | model.load_state_dict(torch.load(self.args.model_path)) 344 | return model 345 | 346 | class EvaluateNoModel(EvaluateBartEditor): 347 | 348 | def gen_model(self, canvases): 349 | return canvases 350 | 351 | 352 | if __name__ == '__main__': 353 | 354 | parser = argparse.ArgumentParser() 355 | subparsers = parser.add_subparsers() 356 | 357 | parser_barteditor = subparsers.add_parser('BartEditor') 358 | parser_barteditor.set_defaults(func=EvaluateBartEditor) 359 | EvaluateBartEditor.add_args(parser_barteditor) 360 | 361 | parser_barts2s = subparsers.add_parser('BartS2S') 362 | parser_barts2s.set_defaults(func=EvaluateBartS2S) 363 | EvaluateBartS2S.add_args(parser_barts2s) 364 | 365 | parser_berteditor = subparsers.add_parser('BertEditor') 366 | parser_berteditor.set_defaults(func=EvaluateBertEditor) 367 | EvaluateBertEditor.add_args(parser_berteditor) 368 | 369 | parser_bartlargeeditor = subparsers.add_parser('BartLargeEditor') 370 | parser_bartlargeeditor.set_defaults(func=EvaluateBartLarge) 371 | EvaluateBartLarge.add_args(parser_bartlargeeditor) 372 | 373 | parser_baseline = subparsers.add_parser('Baseline') 374 | parser_baseline.set_defaults(func=EvaluateNoModel) 375 | EvaluateNoModel.add_args(parser_baseline) 
376 | 377 | args = parser.parse_args() 378 | eval_instance = args.func() 379 | eval_instance.setup(args) 380 | eval_instance.run() 381 | -------------------------------------------------------------------------------- /infosol/modeling_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def compute_entropy(probs): 4 | log_probs = torch.log(probs) 5 | entropies = -torch.matmul(probs, log_probs.transpose(0,1)).diagonal() 6 | return entropies 7 | 8 | def top_p_warp(scores, top_p=0.95, filter_value=-float("Inf")): 9 | sorted_logits, sorted_indices = torch.sort(scores, descending=True) 10 | cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) 11 | 12 | # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) 13 | sorted_indices_to_remove = cumulative_probs > top_p 14 | # Shift the indices to the right to keep also the first token above the threshold 15 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 16 | sorted_indices_to_remove[..., 0] = 0 17 | 18 | # scatter sorted tensors to original indexing 19 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) 20 | scores = scores.masked_fill(indices_to_remove, filter_value) 21 | return scores 22 | -------------------------------------------------------------------------------- /infosol/models/__pycache__/word_edit_model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ffaltings/InteractiveTextGeneration/febd658a91227dd88fbc5355111382b732a91647/infosol/models/__pycache__/word_edit_model.cpython-310.pyc -------------------------------------------------------------------------------- /infosol/models/word_edit_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | import os 4 | import time 5 | import 
pdb 6 | import copy 7 | import torch.nn.functional as F 8 | import random 9 | import operator 10 | import numpy as np 11 | import torch.nn.functional as F 12 | import torch.multiprocessing as mp 13 | 14 | from queue import PriorityQueue 15 | from infosol.alignment import * 16 | from infosol.decoding import ParallelDecodingMixin 17 | from transformers import BertModel, AutoTokenizer, T5ForConditionalGeneration, BertConfig, AutoConfig, AutoModel 18 | from transformers.models.bart.modeling_bart import BartEncoder, BartModel, BartForConditionalGeneration, BartDecoder 19 | from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput 20 | 21 | MP_TIMEOUT = 2 #2 seconds timeout 22 | 23 | def log_factorial(n): 24 | if n <= 1: return 0. 25 | return np.log(np.arange(2,n+1)).sum() 26 | 27 | class BeamSearchNode(): 28 | 29 | """ 30 | Class for beam search 31 | """ 32 | 33 | def __init__(self, prevNode, canvas, stop, logp, length): 34 | self.prevNonde = prevNode 35 | self.canvas = canvas 36 | self.stop = stop 37 | self.logp = logp 38 | self.len = length 39 | 40 | def eval(self, len_normalize=False): 41 | if len_normalize: 42 | return self.logp + log_factorial(self.len-1) 43 | else: return self.logp 44 | 45 | def __eq__(self, other): 46 | return self.eval() == other.eval() 47 | 48 | def __lt__(self, other): 49 | return self.eval() < other.eval() 50 | 51 | def __gt__(self, other): 52 | return self.eval() > other.eval() 53 | 54 | def __le__(self, other): 55 | return self.eval() <= other.eval() 56 | 57 | def __ge__(self, other): 58 | return self.eval() >= other.eval() 59 | 60 | def noise_alignment(alignment, vocab_sampler): 61 | """ 62 | Make a random edit to the source canvas and update the alignment 63 | """ 64 | canvas = alignment.get_source_canvas() 65 | positions = [-1] + [i for i in range(len(canvas)) if canvas.type_ids[i] < 3] 66 | position = np.random.choice(positions) 67 | if position >= 0: 68 | operation = np.random.randint(0, 3) 69 | else: 70 | operation = 
0 71 | 72 | if operation < 2: # insert and substitute 73 | token = vocab_sampler.get() 74 | alignment.operate(position, operation, token) 75 | elif operation == 2: # delete a token 76 | alignment.operate(position, operation, None) 77 | 78 | def sample_trajectory(alignment, tokenizer, vocab_sampler, noise_prob=0.2, 79 | max_traj_length=64): 80 | 81 | """ 82 | Sample a trajectory for training. Take the alignment and randomly push it forward 83 | a few steps (make the source more like the target). Inject noise by simulating a noisy 84 | process where you make errors when pushing the alignment forward. 85 | """ 86 | 87 | ops = alignment.get_non_const_ops() 88 | ops = np.random.permutation(ops) 89 | ops_len = len(ops) 90 | n_errors = 0 91 | n_ops = 0 92 | trajectory = [] 93 | trajectory.append((n_errors, n_ops)) 94 | for i in range(max_traj_length): 95 | if random.random() > noise_prob: 96 | if n_ops == ops_len and n_errors == 0: break 97 | if random.random() <= n_errors / (n_errors + ops_len - n_ops): # you might fix an error 98 | n_errors -= 1 99 | else: n_ops += 1 100 | else: 101 | n_errors += 1 102 | trajectory.append((n_errors, n_ops)) 103 | traj_length = len(trajectory) 104 | traj_idx = np.random.randint(0,traj_length) # sample state from trajectory 105 | n_errors, n_ops = trajectory[traj_idx] 106 | alignment.push_forward(ops[:n_ops]) 107 | for i in range(n_errors): 108 | noise_alignment(alignment, vocab_sampler) 109 | return_ops = alignment.get_non_const_ops() 110 | actions = alignment.get_actions(return_ops) 111 | if len(actions) == 0: actions = [(1, None, None, None)] 112 | else: actions = [(0, a[0], a[1], tokenizer.convert_tokens_to_ids(a[2])) for a in actions] 113 | 114 | canvas = alignment.get_source_canvas() 115 | 116 | return canvas, actions, traj_length 117 | 118 | class mlp_head(torch.nn.Module): 119 | """ 120 | MLP head to put on top of BERT 121 | """ 122 | 123 | def __init__(self, hidden_size, out_size, layer_norm_eps): 124 | super().__init__() 125 | 
self.head = torch.nn.Sequential( 126 | torch.nn.Linear(hidden_size, hidden_size), 127 | torch.nn.GELU(), 128 | torch.nn.LayerNorm(hidden_size, eps=layer_norm_eps), 129 | torch.nn.Linear(hidden_size, out_size, bias=False) 130 | ) 131 | 132 | def forward(self, inp): 133 | out = self.head(inp) 134 | return out 135 | 136 | def draw_multi(p_values): 137 | return np.where(np.random.multinomial(1, p_values) == 1)[0][0] 138 | 139 | class BertEditor(torch.nn.Module, ParallelDecodingMixin): #TODO: clean up, split out Bart model 140 | """ 141 | Main editor model 142 | """ 143 | 144 | def __init__(self, model_type = 'bert', model_file='bert-base-uncased', is_decoder=False, 145 | type_vocab_size=2, use_type_ids=True, training_noise=0.2, 146 | tokenizer=None, vocab_sampler=None): 147 | """ 148 | :param model_file: model file to initialize bert model, default set to pretrained model from huggingface 149 | :param use_memory: whether to use memory of oracle actions 150 | :param is_decoder: whether to use decoder architecture for BERT 151 | :param cat_embeds: whether to concatenate memory to inputs instead of feeding to cross attention layers (in decoder case) 152 | """ 153 | super().__init__() 154 | self.model_type = model_type 155 | self.use_type_ids = use_type_ids 156 | self.training_noise = training_noise 157 | 158 | self.tokenizer = tokenizer 159 | vocab_size = len(self.tokenizer) 160 | self.vocab_sampler = vocab_sampler 161 | self.delete_type_id = 3 162 | 163 | if self.model_type == 'bert': 164 | self.base_model = BertModel.from_pretrained(model_file) 165 | self.base_model.resize_token_embeddings(vocab_size) 166 | self.base_model.embeddings.token_type_embeddings = self.base_model._get_resized_embeddings( 167 | self.base_model.embeddings.token_type_embeddings, 5) # resize token type embeddings 168 | hidden_size = self.base_model.config.hidden_size 169 | layer_norm_eps = self.base_model.config.layer_norm_eps 170 | elif self.model_type in ['bart', 'bart-large']: 171 | config = 
AutoConfig.from_pretrained(model_file) 172 | config.n_type_ids = 5 173 | self.base_model = BartForConditionalGenerationTType.from_pretrained( 174 | model_file, config=config).model.encoder 175 | hidden_size = config.hidden_size 176 | layer_norm_eps = 1e-12 177 | 178 | self.emission_head = mlp_head(hidden_size, 1, layer_norm_eps) # decides when to "emit" a document/stop 179 | self.location_head = mlp_head(hidden_size, 1, layer_norm_eps) # decides where to edit 180 | self.operation_head = mlp_head(hidden_size, 3, layer_norm_eps) # decides what operation to do 181 | self.vocabulary_head = mlp_head(hidden_size, vocab_size * 2, layer_norm_eps) # decides what token to insert/substitute for 182 | 183 | def forward(self, input_ids, attention_mask, encoder_states=None, encoder_mask=None, token_type_ids=None, 184 | return_hidden_states=False): 185 | 186 | N, L = input_ids.shape 187 | 188 | if encoder_states is None: 189 | encoder_states = torch.zeros((N, 1, self.base_model.config.hidden_size)).to(input_ids.device) 190 | encoder_mask = torch.zeros((N, 1), dtype=torch.int32).to(input_ids.device) 191 | 192 | if not self.use_type_ids: 193 | token_type_ids = None 194 | 195 | if token_type_ids is None: 196 | token_type_ids = torch.zeros_like(attention_mask) 197 | token_type_ids = torch.cat([ 198 | torch.zeros_like(encoder_mask), 199 | token_type_ids], 200 | dim=1) 201 | attention_mask = torch.cat( 202 | [encoder_mask, attention_mask], dim=1) 203 | 204 | # TODO: split into separate model classes 205 | if self.model_type == 'bert': 206 | inputs_embeds = self.base_model.embeddings.word_embeddings(input_ids) #embed token ids 207 | inputs_embeds = torch.cat([encoder_states, inputs_embeds], dim=1) #concatenate with encoder states 208 | inputs_embeds = self.base_model.embeddings(inputs_embeds=inputs_embeds, 209 | token_type_ids=token_type_ids) 210 | out = self.base_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask) 211 | elif self.model_type == 'bart' or self.model_type == 
'bart-large': 212 | inputs_embeds = self.base_model.embed_tokens(input_ids) 213 | inputs_embeds = torch.cat([encoder_states, inputs_embeds], dim=1) #concatenate with encoder states 214 | out = self.base_model(inputs_embeds=inputs_embeds, token_type_ids=token_type_ids, 215 | attention_mask=attention_mask) 216 | prefix_len = encoder_states.shape[1] 217 | 218 | hidden_states = out[0][:, prefix_len:, :] 219 | emission_out = self.emission_head(hidden_states[:,0,:]).squeeze(-1) 220 | location_out = self.location_head(hidden_states).squeeze(-1) 221 | operation_out = self.operation_head(hidden_states) 222 | vocabulary_out = self.vocabulary_head(hidden_states) 223 | vocabulary_out = vocabulary_out.reshape(N, L, 2, -1) 224 | V = vocabulary_out.shape[-1] 225 | 226 | # if np.isnan(emission_out.detach().cpu()).any(): 227 | # pdb.set_trace() 228 | 229 | # Location mask 230 | # Mask out padding 231 | padding_mask = 1 - attention_mask[:, prefix_len:] 232 | # Mask out stricken tokens 233 | stricken_mask = (token_type_ids[:, prefix_len:] >= self.delete_type_id) # cannot insert, delete or sub stricken-out tokens <-- may want to allow substitutions? 
234 | # Mask out EOS tokens 235 | eos_mask = torch.zeros_like(attention_mask[:, prefix_len:]) 236 | attention_lengths = attention_mask[:, prefix_len:].sum(dim=1, dtype=torch.int32) 237 | for i in range(attention_lengths.shape[0]): 238 | eos_mask[i, attention_lengths[i].item()-1] = 1 239 | location_mask = torch.clamp(padding_mask + stricken_mask + eos_mask, max=1) 240 | location_out += location_mask * -1e20 241 | 242 | # Operation mask 243 | # BOS mask 244 | bos_mask = torch.zeros_like(operation_out) 245 | bos_mask[:, 0, 1:] = 1 246 | operation_mask = torch.clamp(bos_mask, max=1) 247 | operation_out += operation_mask * -1e20 248 | 249 | # Vocabulary mask 250 | # Can't sub token for same one 251 | sub_mask = torch.zeros_like(vocabulary_out) 252 | sub_mask[:, :, 1, :] = F.one_hot(input_ids, num_classes=V) 253 | vocabulary_mask = torch.clamp(sub_mask, max=1) 254 | vocabulary_out += vocabulary_mask * -1e20 255 | 256 | return_out = (emission_out, location_out, operation_out, vocabulary_out) 257 | if return_hidden_states: 258 | return_out = return_out + (hidden_states,) 259 | return return_out 260 | 261 | def compute_instance_loss(self, emission_out, location_out, operation_out, vocabulary_out, 262 | emission_idx, location_idxs, operation_idxs, vocabulary_idxs, ops_len): 263 | """ 264 | Calculate cross entropy loss based on model outputs and gold targets 265 | """ 266 | n_actions = max(emission_out.shape[0] - 1, 1) 267 | 268 | emission_loss = F.binary_cross_entropy(torch.sigmoid(emission_out), emission_idx.float()) 269 | emission_mask = (1-emission_idx).type(torch.int32) 270 | 271 | location_loss = F.cross_entropy(location_out, location_idxs, reduction='none') 272 | location_loss = (location_loss * emission_mask).sum() 273 | 274 | # TODO: should technically mask here for bos, only viable operaiton is insertion 275 | operation_loss = F.cross_entropy( 276 | operation_out.gather(1, 277 | location_idxs.reshape(-1,1,1).expand(-1,-1,operation_out.size(-1))).squeeze(1), 278 | 
operation_idxs, reduction='none') 279 | # multiply by a mask so that don't count the loss in cases where the model should stop 280 | operation_loss = (operation_loss * emission_mask).sum() 281 | 282 | # TODO: should also technically mask here for sub identical token, since not allowed 283 | vocabulary_loss = F.cross_entropy( 284 | vocabulary_out.gather(1, 285 | location_idxs.reshape(-1,1,1,1).expand(-1,-1,2,vocabulary_out.size(-1))).squeeze(1).gather(1, 286 | torch.clamp(operation_idxs, max=1).reshape(-1,1,1).expand(-1,1,vocabulary_out.size(-1))).squeeze(1), 287 | vocabulary_idxs, reduction='none') # <- NOT averaged 288 | # only use vocabulary loss for substitution and insertion opertions (deletion mask) 289 | deletion_mask = 1 - (operation_idxs == 2).type(torch.long) # <- 2 = deletion operation 290 | vocabulary_loss = (vocabulary_loss * deletion_mask * emission_mask).sum() 291 | 292 | # scaling 293 | total_loss = emission_loss + location_loss + operation_loss + vocabulary_loss # this is already averaged by n-t 294 | total_loss = -log_factorial(ops_len) + ops_len * total_loss / n_actions 295 | 296 | return total_loss, emission_loss, location_loss, operation_loss, vocabulary_loss 297 | 298 | def compute_loss(self, input_ids, attention_mask, actions, prefix_lengths, ops_lengths, 299 | encoder_states=None, encoder_mask=None, token_type_ids=None): 300 | #TODO: get rid of prefix_lenghts (outdated) 301 | """ 302 | Computes the loss for a batch. Expands the batch when there are multiple actions to compute loss over all actions. 
303 | Thus, for each entry in the batch, creates N entry for each action from the oracle that the model should predict 304 | """ 305 | emission_out, location_out, operation_out, vocabulary_out = self.forward(input_ids, attention_mask, 306 | encoder_states=encoder_states, encoder_mask=encoder_mask, token_type_ids=token_type_ids)[:4] 307 | 308 | # if np.isnan(emission_out.detach().cpu()).any(): 309 | # pdb.set_trace() 310 | 311 | batch_size = location_out.size(0) 312 | batch_loss, emission_loss, location_loss, operation_loss, vocabulary_loss = [ 313 | torch.tensor(0, dtype=torch.double, device=location_out.device)]*5 # create tensors to track loss 314 | # the number of actions that the oracle returns for each state could differ, so we have to iterate over the batch 315 | for e_out, l_out, o_out, v_out, (e_idxs, l_idxs, o_idxs, v_idxs), p_len, o_len in zip( 316 | emission_out, location_out, operation_out, vocabulary_out, actions, prefix_lengths, ops_lengths): 317 | n_ops = l_idxs.size(0) 318 | instance_loss, e_loss, l_loss, o_loss, v_loss = self.compute_instance_loss( 319 | # expand by number of operations that the oracle returned for these inputs 320 | e_out.unsqueeze(0).expand(n_ops), 321 | # prefix here is because we are still including a source canvas in the model inputs, which differs for each instance 322 | l_out[p_len:-1,...].unsqueeze(0).expand(n_ops, -1), 323 | o_out[p_len:-1,...].unsqueeze(0).expand(n_ops, -1, -1), 324 | v_out[p_len:-1,...].unsqueeze(0).expand(n_ops, -1, -1, -1), 325 | e_idxs, l_idxs, o_idxs, v_idxs, o_len 326 | ) 327 | 328 | batch_loss, emission_loss, location_loss, operation_loss, vocabulary_loss = [ 329 | loss + l / batch_size for loss,l in zip( 330 | [batch_loss, emission_loss, location_loss, operation_loss, vocabulary_loss], 331 | [instance_loss, e_loss, l_loss, o_loss, v_loss]) 332 | ] # aggregate loss 333 | 334 | metrics = { 335 | 'loss': batch_loss.item(), 336 | 'em_loss': emission_loss.item(), 337 | 'loc_loss': location_loss.item(), 
338 | 'op_loss': operation_loss.item(), 339 | 'voc_loss': vocabulary_loss.item() 340 | } 341 | 342 | return batch_loss, metrics 343 | 344 | def prep_canvas(self, canvas, type_ids=None): 345 | """ 346 | Prepares canvass to feed into the model 347 | """ 348 | ids = self.tokenizer.convert_tokens_to_ids(canvas.tokens) # convert tokens to ids 349 | ids = [self.tokenizer.cls_token_id] + ids + [self.tokenizer.sep_token_id] 350 | type_ids = [0] + canvas.type_ids + [0] 351 | return ids, type_ids, 0 352 | 353 | def prep_canvases(self, canvases, device=torch.device('cpu')): 354 | """ 355 | Preps canvases into a batch to feed into the model 356 | """ 357 | canvases, type_ids, prefix_lengths = zip(*[self.prep_canvas(c) for c in canvases]) 358 | max_length = np.max([len(c) for c in canvases]) 359 | input_ids = torch.zeros((len(canvases), max_length), dtype=torch.long, device=device) 360 | attention_mask = torch.zeros(input_ids.shape, dtype=torch.int32, device=device) 361 | token_type_ids = torch.zeros_like(attention_mask, device=device) 362 | 363 | for i,(c, tids) in enumerate(zip(canvases, type_ids)): 364 | input_ids[i, :len(c)] = torch.tensor(c) 365 | token_type_ids[i, :len(c)] = torch.tensor(tids) 366 | attention_mask[i, :len(c)] = 1 367 | 368 | return input_ids, token_type_ids, attention_mask, prefix_lengths 369 | 370 | def prep_action(self, action): 371 | """ 372 | Prep action for model 373 | """ 374 | emission_idx = action[0] 375 | if emission_idx == 0: 376 | location_idx = action[1] + 1 # + 1 for sentinel token 377 | operation_idx = action[2] 378 | token_idx = action[3] if action[3] is not None else 0 379 | else: 380 | location_idx, operation_idx, token_idx = 0,0,0 # dummy values 381 | return emission_idx, location_idx, operation_idx, token_idx 382 | 383 | def prep_actions(self, actions, device=torch.device('cpu')): # preps sampled actions for a SINGLE INSTANCE 384 | """ 385 | Prep set of actions return by the oracle for the model. 
This is a single state, not for a batch 386 | """ 387 | return [torch.tensor(a, dtype=torch.long, device=device) for a in zip( 388 | *[self.prep_action(a) for a in actions])] 389 | 390 | def prep_batch(self, batch, device=torch.device('cpu'), **kwargs): 391 | """ 392 | Prepares whole batch to feed into the model 393 | """ 394 | if self.training_noise > 0 and self.vocab_sampler is None: 395 | raise ValueError('vocab sampler cannot be None when using training noise') 396 | canvases, actions, ops_lengths = zip(*[sample_trajectory( 397 | alignment, self.tokenizer, self.vocab_sampler, noise_prob = self.training_noise, **kwargs 398 | ) for alignment in batch]) 399 | 400 | input_ids, token_type_ids, attention_mask, prefix_lengths = self.prep_canvases( 401 | canvases, device=device) 402 | 403 | actions = [self.prep_actions(a, device) for a in actions] 404 | # no encoder input in this version 405 | encoder_states, encoder_mask = None, None 406 | 407 | return input_ids, attention_mask, actions, prefix_lengths, ops_lengths, encoder_states, encoder_mask, token_type_ids 408 | 409 | def move_batch(self, batch, device): 410 | input_ids, attention_mask, actions, prefix_lengths, ops_lengths, encoder_states, encoder_mask, token_type_ids = batch 411 | input_ids, attention_mask, token_type_ids = input_ids.to(device), attention_mask.to(device), token_type_ids.to(device) 412 | #encoder_states, encoder_mask = encoder_states.to(device), encoder_mask.to(device) 413 | actions = [(a[0].to(device), a[1].to(device), a[2].to(device), a[3].to(device)) for a in actions] 414 | return input_ids, attention_mask, actions, prefix_lengths, ops_lengths, encoder_states, encoder_mask, token_type_ids 415 | 416 | ### Decoding 417 | 418 | def convert_action(self, action): 419 | """ 420 | Get action from model representation of action 421 | """ 422 | stop, ref_idx, op_idx, tok_idx = action 423 | if not ref_idx is None: ref_idx -= 1 424 | if not tok_idx is None: 425 | tok_idx = 
self.tokenizer.convert_ids_to_tokens(int(tok_idx)) 426 | return stop, ref_idx, op_idx, tok_idx 427 | 428 | @staticmethod 429 | def sample_from_logits(actions, logps): 430 | action_ps = np.exp(np.asarray(logps)) 431 | action_ps = action_ps / np.sum(action_ps) 432 | action_idx = draw_multi(action_ps) 433 | return actions[action_idx], logps[action_idx] 434 | 435 | def forward_canvases(self, canvases, device=torch.device('cpu'), move_to_cpu=False): 436 | """ 437 | Forward pass taking canvases as inputs 438 | """ 439 | input_ids, token_type_ids, attention_mask, prefix_lengths = self.prep_canvases( 440 | canvases, device=device) 441 | with torch.no_grad(): 442 | emission_out, location_out, operation_out, vocabulary_out = self.forward( 443 | input_ids, attention_mask, token_type_ids = token_type_ids)[:4] 444 | if move_to_cpu: 445 | emission_out, location_out, operation_out, vocabulary_out =\ 446 | emission_out.cpu(), location_out.cpu(), operation_out.cpu(), vocabulary_out.cpu() 447 | B = emission_out.shape[0] 448 | out = [] 449 | for i in range(B): 450 | plen = prefix_lengths[i] 451 | e_out, l_out, op_out, v_out = emission_out[i], location_out[i], operation_out[i], vocabulary_out[i] 452 | # yield e_out, l_out[plen:], op_out[plen:], v_out[plen:] 453 | out.append((e_out, l_out[plen:], op_out[plen:], v_out[plen:])) 454 | return out 455 | 456 | def forward_canvas(self, canvas, device=torch.device('cpu')): 457 | # return list(self.forward_canvases(canvas, device=device)) 458 | return self.forward_canvases(canvas, device=device) 459 | 460 | def get_topk_actions(self, canvas, device=torch.device('cpu'), top_k=1, **kwargs): 461 | """ 462 | Get top actions from model predictions 463 | """ 464 | input_ids, token_type_ids, attention_mask, prefix_lengths = self.prep_canvases( 465 | [canvas], device=device) 466 | plen = prefix_lengths[0] 467 | with torch.no_grad(): 468 | emission_out, location_out, operation_out, vocabulary_out = self.forward( 469 | input_ids, attention_mask, 
token_type_ids = token_type_ids)[:4] 470 | actions, logps, stop_lp = get_topk_action_logps( 471 | emission_out, location_out[:, plen:], operation_out[:, plen:], 472 | vocabulary_out[:, plen:], top_k=top_k, **kwargs) 473 | return actions, logps, stop_lp 474 | 475 | @staticmethod 476 | def ancestral_sample(emission_out, location_out, operation_out, vocabulary_out): 477 | """ 478 | Ancestral sampling: sample next action based on previous predictions 479 | """ 480 | stop_p = torch.sigmoid(emission_out) 481 | if np.random.random() <= stop_p: 482 | return (1, None, None, None) 483 | 484 | def sample_from_out(out): 485 | prob = torch.softmax(out, dim=-1) 486 | return torch.multinomial(prob, 1).item() 487 | 488 | location = sample_from_out(location_out) 489 | op = sample_from_out(operation_out[location]) 490 | if op == 2: return (0, location, op, None) 491 | token = sample_from_out(vocabulary_out[location,op]) 492 | return (0, location, op, token) 493 | 494 | @staticmethod 495 | def get_topk_action_logps(stop_out, location_out, action_out, vocab_out, top_k=10, exclude_stop=False): # warning: this will modify the outputs from the model! 496 | """ 497 | Searches the model's output distribution for the top k actions. The actions the model can take 498 | are organized in a tree. The first level corresponds to deciding whether or not to edit, then where to 499 | edit, then what operation to take, and finally what token to use for insertions/substitutions. 500 | We search the tree mainting a list of the most likely actions, and since the probability of a node is always 501 | lower than that of its parent (since the probabilities multiply and are all leq 1), we can easily determine 502 | which nodes to explore by exploring them in sorted order. 
503 | """ 504 | class ActionBuffer: 505 | 506 | def __init__(self, top_k): 507 | self.top_k = top_k 508 | self.actions = [None] * top_k 509 | self.logps = [-1e20] * top_k 510 | 511 | def insert(self, action, logp): 512 | # find insertion index 513 | for i in range(self.top_k-1, -2, -1): 514 | if i == -1: 515 | break 516 | if logp < self.logps[i]: 517 | break 518 | 519 | self.logps = self.logps[:i+1] + [logp] + self.logps[i+1:] 520 | self.actions = self.actions[:i+1] + [action] + self.actions[i+1:] 521 | 522 | self.logps, self.actions = self.logps[:self.top_k], self.actions[:self.top_k] 523 | 524 | def min_lp(self): 525 | return self.logps[-1] 526 | 527 | # def iter_torch_sort(sort_obj): 528 | # """ 529 | # Helper function for iterating over tensors. Much more efficient than directly iterating over tensor 530 | # since we often don't iterate over all the values 531 | # """ 532 | # values, indices = sort_obj.values, sort_obj.indices 533 | # # values = values.cpu() 534 | # for i in range(values.shape[0]): 535 | # yield values[i], indices[i] 536 | 537 | def iter_np_sort(sort_obj, descending=True): 538 | sorted_indices = np.argsort(sort_obj) 539 | if descending: sorted_indices = reversed(sorted_indices) 540 | for i in sorted_indices: #desce 541 | yield sort_obj[i], i 542 | 543 | # always include first index in frozen_idxs since cannot sub or del the CLS token 544 | # buffer for actions and their logps 545 | action_buffer = ActionBuffer(top_k) 546 | 547 | levels = (stop_out, location_out, action_out, vocab_out) 548 | level_logps = np.array([0,0,0,0], dtype=np.float32) 549 | # stopping probabilities 550 | stop_p = torch.sigmoid(stop_out) 551 | stop_lp = torch.log(stop_p) 552 | if not exclude_stop: 553 | action_buffer.insert((1, None, None, None), stop_lp.item()) 554 | level_logps[0] = np.log(1-stop_p) 555 | 556 | def search_level(level, idxs): 557 | if level > 3: return 558 | out = levels[level].unsqueeze(0)[idxs] 559 | logprobs = F.log_softmax(out, dim=-1) 560 | for 
lprob, i in iter_np_sort(logprobs.numpy()): 561 | # for lprob, i in iter_torch_sort(torch.sort(logprobs, descending=True)): 562 | level_logps[level] = lprob 563 | cumlp = level_logps[:level+1].sum() 564 | if cumlp <= action_buffer.min_lp(): 565 | return 566 | if level == 3: 567 | action_buffer.insert(idxs + (i,), cumlp) 568 | elif level == 2 and i==2: #deletion action is special because it doesn't specify a token 569 | action_buffer.insert(idxs + (i, None,), cumlp) 570 | else: 571 | search_level(level+1, idxs + (i,)) 572 | 573 | search_level(1, (0,)) 574 | 575 | return action_buffer.actions, action_buffer.logps, stop_lp.item() 576 | 577 | @staticmethod 578 | def sample(model_out, top_k=None): 579 | if top_k is None: 580 | action = BertEditor.ancestral_sample(*model_out) 581 | else: 582 | actions, logps, stop_lp = BertEditor.get_topk_action_logps(*model_out, top_k=top_k) 583 | action, _ = BertEditor.sample_from_logits(actions, logps) 584 | return action 585 | 586 | def edit(self, canvas, top_k=None, device=torch.device('cpu')): 587 | """ 588 | Make an edit to a canvas 589 | """ 590 | model_out = next(self.forward_canvas_([canvas], device=device, move_to_cpu=True)) 591 | action = BertEditor.sample(model_out, top_k=top_k) 592 | action = self.convert_action(action) 593 | if action[0]: return True 594 | canvas.operate(*action[1:], agent=0) 595 | return False 596 | 597 | # def decode_loop(self, canvases, decode_func, state=None, max_batch_tokens=2048, parallel=False, 598 | # n_processes=None, device=torch.device('cpu'), queue_size=2000, return_idx=False): 599 | # ''' parallelized decoding loop ''' 600 | # 601 | # def enqueue(canvas, state, iter_idx): 602 | # input_queue.put((canvas, iter_idx, state)) 603 | # 604 | # input_iter = iter(canvases) 605 | # start_state = state 606 | # def add_input(idx): 607 | # # add input to queue 608 | # c = next(input_iter, None) 609 | # if c is None: 610 | # return True 611 | # else: 612 | # enqueue(c.copy(), start_state, idx) 613 | # 
return False 614 | # 615 | # def inputs_generator(batch): 616 | # canvases = [b[0] for b in batch] 617 | # for i, model_out in enumerate(self.forward_canvas_(canvases, device=device, move_to_cpu=True)): 618 | # canvas, idx, state = batch[i] 619 | # yield model_out, canvas, state, idx 620 | # 621 | # iter_idx = 0 622 | # finished = False 623 | # input_queue = PriorityQueue(maxsize=queue_size) # priority queue groups similar sized inputs together 624 | # 625 | # # fill up queue 626 | # while not input_queue.full(): 627 | # finished = add_input(iter_idx) 628 | # iter_idx += 1 629 | # if finished: break 630 | # 631 | # if parallel: 632 | # if n_processes is None: n_processes = os.cpu_count() 633 | # else: n_processes = min(os.cpu_count(), n_processes) 634 | ## print(f'Using {n_processes} processes') 635 | # worker_pool = mp.Pool(processes=n_processes) 636 | # 637 | # start_time = time.time() 638 | # 639 | # try: 640 | # while not input_queue.empty(): 641 | # # prepare batch 642 | # batch = [] 643 | # batch_len = 0 644 | # batch_n_tokens = 0 645 | # while True: 646 | # if input_queue.empty(): break 647 | # # get input 648 | # inp = input_queue.get() 649 | # canvas, idx, state = inp 650 | # n_tokens = len(canvas) 651 | # # check if it can fit in batch 652 | # new_batch_len = batch_len + 1 653 | # new_batch_n_tokens = max(batch_n_tokens, n_tokens) # max because inputs will get padded 654 | # # if it doesn't fit, requeue and break 655 | # if new_batch_len * new_batch_n_tokens > max_batch_tokens: 656 | # enqueue(canvas, state, idx) # requeue 657 | # break 658 | # # otherwise add to batch 659 | # else: 660 | # batch.append(inp) 661 | # batch_len = new_batch_len 662 | # batch_n_tokens = new_batch_n_tokens 663 | # 664 | # batch_size = len(batch) 665 | # if batch_size == 0: 666 | # raise RuntimeError('Unable to fit any inputs into batch. 
Try increasing max batch tokens') 667 | # inputs_feed = list(inputs_generator(batch)) 668 | # if parallel: 669 | # chunk_size = batch_size // n_processes 670 | ## outputs_generator = worker_pool.imap_unordered(decode_func, inputs_feed, chunk_size) 671 | # outputs_generator = [worker_pool.apply_async(decode_func, args=(inp,)) for inp in inputs_feed] 672 | # else: 673 | # outputs_generator = map(decode_func, inputs_feed) 674 | # did_timeout = False 675 | # for out in outputs_generator: 676 | # if parallel: 677 | # try: 678 | # out = out.get(MP_TIMEOUT) 679 | # except mp.TimeoutError: 680 | # did_timeout = True 681 | # if return_idx: 682 | # yield None, None 683 | # else: 684 | # yield None 685 | # continue 686 | # canvas, action, stop, state, idx = out 687 | # if stop: 688 | # if return_idx: 689 | # yield canvas, idx 690 | # else: 691 | # yield canvas 692 | # if not finished: 693 | # finished = add_input(iter_idx) 694 | # iter_idx += 1 695 | # else: 696 | # action = self.convert_action(action) 697 | # canvas.operate(*action[1:], agent=0) 698 | # enqueue(canvas.copy(), state, idx) 699 | # if did_timeout and parallel: #reset the worker pool to avoid accumulating dead processes 700 | # print('restarting pool') 701 | # worker_pool.terminate() 702 | # worker_pool = mp.Pool(processes=n_processes) 703 | ## assert finished 704 | # finally: 705 | # if parallel: 706 | # worker_pool.terminate() 707 | # 708 | # delta = time.time() - start_time 709 | # delta = int(1000*delta) 710 | ## print(f'Took {delta} ms') 711 | ## return out_canvases 712 | # return 713 | 714 | @staticmethod 715 | def canvas_len(canvas): 716 | return len(canvas) 717 | 718 | def decode_(self, model_out, canvas, state): 719 | """ 720 | Regular decoding with sampling 721 | """ 722 | # if np.random.random() <= 0.01: 723 | # time.sleep(9999) 724 | # model_out, canvas, state, idx = inputs 725 | top_k, max_iter, cur_iter = state 726 | state = top_k, max_iter, cur_iter + 1 727 | action = BertEditor.sample(model_out, 
top_k=top_k) 728 | if action[0] or cur_iter > max_iter: 729 | # return canvas, action, True, state, idx 730 | return canvas, state, True 731 | else: 732 | action = self.convert_action(action) 733 | canvas.operate(*action[1:], agent=0) 734 | # return canvas, action, False, state, idx 735 | return canvas, state, False 736 | 737 | def batch_decode(self, canvases, top_k=10, max_iter=32, **kwargs): 738 | state = top_k, max_iter, 0 739 | for out in self.decode_loop(canvases, self.decode_, state=state, **kwargs): 740 | yield out 741 | 742 | def decode(self, canvas, **kwargs): 743 | return next(self.batch_decode([canvas], **kwargs)) 744 | 745 | def depth_decode_(self, model_out, canvas, state): 746 | """ 747 | Depth decoding: ignores stop actions, continus until reaching a stopping condition. Then 748 | returns whichever canvas had the highest stopping probability 749 | """ 750 | # model_out, canvas, state, idx = inputs 751 | top_k, stop_threshold, max_iter, cur_iter, top_canvas, top_lp = state 752 | 753 | actions, logps, stop_lp = BertEditor.get_topk_action_logps(*model_out, top_k=top_k, exclude_stop=True) 754 | if stop_lp > top_lp: 755 | top_lp = stop_lp 756 | top_canvas = canvas.copy() 757 | state = (top_k, stop_threshold, max_iter, cur_iter + 1, top_canvas, top_lp) 758 | if np.exp(stop_lp) >= stop_threshold or cur_iter > max_iter: 759 | return top_canvas, state, True 760 | # return top_canvas, None, True, state, idx 761 | else: 762 | action, _ = BertEditor.sample_from_logits(actions, logps) 763 | action = self.convert_action(action) 764 | canvas.operate(*action[1:], agent=0) 765 | return canvas, state, False 766 | # return canvas, action, False, state, idx 767 | 768 | def batch_depth_decode(self, canvases, top_k=10, stop_threshold=0.95, max_iter=64, **kwargs): 769 | state = top_k, stop_threshold, max_iter, 0, None, -1e20 770 | for out in self.decode_loop(canvases, self.depth_decode_, state=state, **kwargs): 771 | yield out 772 | 773 | def depth_decode(self, canvas, 
**kwargs): 774 | return next(self.batch_depth_decode([canvas], **kwargs)) 775 | # 776 | # def depth_decode(self, canvas, top_k=20, device=torch.device('cpu'), max_iter=64, stop_threshold=0.95, len_normalize=False, only_stop_lp=False): 777 | # canvases = PriorityQueue() 778 | # running_logp = 0 779 | # canvas = canvas.copy() 780 | # for i in range(max_iter): 781 | # actions, logps, stop_lp = self.get_topk_actions(canvas, device, top_k=top_k, exclude_stop=True) 782 | # 783 | # # add completed canvas 784 | # if only_stop_lp: 785 | # canvases.put((-stop_lp, canvas)) 786 | # else: 787 | # canvases.put((-(running_logp + stop_lp), canvas)) 788 | # 789 | # #stopping condition 790 | # if np.exp(stop_lp) >= stop_threshold: break 791 | # 792 | # # make new canvas 793 | # canvas = canvas.copy() 794 | # action, lp = self.sample_action_(actions, logps) 795 | # action = self.convert_action(action) 796 | # canvas.operate(*action[1:], agent=0) 797 | # running_logp += lp 798 | # if len_normalize: running_logp += np.log(i+1) 799 | # return canvases.get()[1] 800 | 801 | def beam_decode(self, canvas, top_k=1, beam_width=10, max_length=32, len_normalize=False, device=torch.device('cpu')): 802 | """ 803 | Beam decoding, not actually used 804 | """ 805 | 806 | if top_k > beam_width: 807 | raise ValueError('top_k cannot be greater than beam_width') 808 | 809 | decoded_batch = [] 810 | endnodes = PriorityQueue() 811 | 812 | node = BeamSearchNode(None, canvas, False, 0, 0) 813 | nodes = PriorityQueue() 814 | 815 | nodes.put((-node.eval(len_normalize=len_normalize), node)) 816 | qsize = 1 817 | 818 | for i in range(max_length): 819 | if nodes.empty(): break 820 | nextnodes = PriorityQueue() 821 | # go through nodes at current depth 822 | assert nodes.qsize() <= beam_width, nodes.qsize() 823 | while not nodes.empty(): 824 | score, n = nodes.get() 825 | canvas = n.canvas 826 | 827 | if n.stop: 828 | endnodes.put((score, n)) 829 | continue 830 | 831 | actions, logps, _ = 
self.get_topk_actions(canvas, device, top_k=beam_width) 832 | for a,lp in zip(actions, logps): 833 | edited_canvas = canvas.copy() 834 | a = self.convert_action(a) 835 | if not a[0]: edited_canvas.operate(*a[1:], agent=0) 836 | 837 | 838 | node = BeamSearchNode(n, edited_canvas, a[0], n.logp + lp, n.len + 1) 839 | score = -node.eval(len_normalize=len_normalize) 840 | nextnodes.put((score, node)) 841 | 842 | for _ in range(beam_width): 843 | if nextnodes.empty(): break 844 | score, nn = nextnodes.get() 845 | nodes.put((score, nn)) 846 | 847 | if endnodes.empty(): 848 | for _ in range(top_k): 849 | score,n = nodes.get() 850 | endnodes.put((score, n)) 851 | 852 | for _ in range(top_k): 853 | if endnodes.empty(): break 854 | score, n = endnodes.get() 855 | decoded_batch.append(n.canvas) 856 | 857 | return decoded_batch 858 | 859 | class BartEncoderTType(BartEncoder): 860 | 861 | """ 862 | Token type ids for BART (not supported natively) 863 | """ 864 | 865 | def __init__(self, config, embed_tokens): 866 | super().__init__(config, embed_tokens) 867 | 868 | self.token_type_embeddings = torch.nn.Embedding(config.n_type_ids, config.d_model) 869 | 870 | def forward(self, input_ids=None, inputs_embeds=None, token_type_ids=None, **kwargs): 871 | 872 | if input_ids is not None and inputs_embeds is not None: 873 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 874 | elif input_ids is not None: 875 | input_shape = input_ids.size() 876 | input_ids = input_ids.view(-1, input_shape[-1]) 877 | elif inputs_embeds is not None: 878 | input_shape = inputs_embeds.size()[:-1] 879 | else: 880 | raise ValueError("You have to specify either input_ids or inputs_embeds") 881 | 882 | if token_type_ids is None: 883 | token_type_ids = torch.zeros_like(input_ids) 884 | 885 | if inputs_embeds is None: 886 | inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale 887 | 888 | token_type_embeds = self.token_type_embeddings(token_type_ids) 889 | 
inputs_embeds += token_type_embeds 890 | 891 | return super().forward(inputs_embeds=inputs_embeds, **kwargs) 892 | 893 | class BartModelTType(BartModel): 894 | 895 | def __init__(self, config): 896 | super(BartModel, self).__init__(config) 897 | 898 | padding_idx, vocab_size = config.pad_token_id, config.vocab_size 899 | self.shared = torch.nn.Embedding(vocab_size, config.d_model, padding_idx) 900 | 901 | self.encoder = BartEncoderTType(config, self.shared) 902 | self.decoder = BartDecoder(config, self.shared) 903 | 904 | def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, 905 | output_attentions=None, output_hidden_states=None, return_dict=None, 906 | token_type_ids=None, encoder_outputs = None, **kwargs): 907 | 908 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 909 | output_hidden_states = ( 910 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 911 | ) 912 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 913 | 914 | if encoder_outputs is None: 915 | encoder_outputs = self.encoder( 916 | input_ids=input_ids, 917 | attention_mask=attention_mask, 918 | head_mask=head_mask, 919 | inputs_embeds=inputs_embeds, 920 | output_attentions=output_attentions, 921 | output_hidden_states=output_hidden_states, 922 | return_dict=return_dict, 923 | token_type_ids=token_type_ids 924 | ) 925 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): 926 | encoder_outputs = BaseModelOutput( 927 | last_hidden_state=encoder_outputs[0], 928 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, 929 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, 930 | ) 931 | 932 | return super().forward(input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, 933 | inputs_embeds=inputs_embeds, output_attentions=output_attentions, 934 | 
output_hidden_states=output_hidden_states, return_dict=return_dict, 935 | encoder_outputs=encoder_outputs, **kwargs) 936 | 937 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): 938 | """ 939 | Shift input ids one token to the right. 940 | """ 941 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 942 | shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() 943 | shifted_input_ids[:, 0] = decoder_start_token_id 944 | 945 | if pad_token_id is None: 946 | raise ValueError("self.model.config.pad_token_id has to be defined.") 947 | # replace possible -100 values in labels by `pad_token_id` 948 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) 949 | 950 | return shifted_input_ids 951 | 952 | class BartForConditionalGenerationTType(BartForConditionalGeneration): 953 | 954 | """ 955 | Derived class to add in the token type ids 956 | """ 957 | 958 | def __init__(self, config): 959 | super().__init__(config) 960 | self.model = BartModelTType(config) 961 | self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) 962 | self.lm_head = torch.nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) 963 | 964 | self.init_weights() 965 | 966 | def forward( 967 | self, 968 | input_ids=None, 969 | attention_mask=None, 970 | decoder_input_ids=None, 971 | decoder_attention_mask=None, 972 | head_mask=None, 973 | decoder_head_mask=None, 974 | encoder_outputs=None, 975 | past_key_values=None, 976 | inputs_embeds=None, 977 | decoder_inputs_embeds=None, 978 | labels=None, 979 | use_cache=None, 980 | output_attentions=None, 981 | output_hidden_states=None, 982 | return_dict=None, 983 | token_type_ids=None, 984 | **kwargs # why not absorb all these arguments in kwargs?? 
985 | ): 986 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 987 | 988 | if labels is not None: 989 | if decoder_input_ids is None and decoder_inputs_embeds is None: 990 | decoder_input_ids = shift_tokens_right( 991 | labels, self.config.pad_token_id, self.config.decoder_start_token_id 992 | ) 993 | 994 | outputs = self.model( 995 | input_ids, 996 | attention_mask=attention_mask, 997 | decoder_input_ids=decoder_input_ids, 998 | encoder_outputs=encoder_outputs, 999 | decoder_attention_mask=decoder_attention_mask, 1000 | head_mask=head_mask, 1001 | decoder_head_mask=decoder_head_mask, 1002 | past_key_values=past_key_values, 1003 | inputs_embeds=inputs_embeds, 1004 | decoder_inputs_embeds=decoder_inputs_embeds, 1005 | use_cache=use_cache, 1006 | output_attentions=output_attentions, 1007 | output_hidden_states=output_hidden_states, 1008 | return_dict=return_dict, 1009 | token_type_ids=token_type_ids, 1010 | **kwargs 1011 | ) 1012 | lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias 1013 | 1014 | masked_lm_loss = None 1015 | if labels is not None: 1016 | loss_fct = torch.nn.CrossEntropyLoss() 1017 | masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) 1018 | 1019 | if not return_dict: 1020 | output = (lm_logits,) + outputs[1:] 1021 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 1022 | 1023 | return Seq2SeqLMOutput( 1024 | loss=masked_lm_loss, 1025 | logits=lm_logits, 1026 | past_key_values=outputs.past_key_values, 1027 | decoder_hidden_states=outputs.decoder_hidden_states, 1028 | decoder_attentions=outputs.decoder_attentions, 1029 | cross_attentions=outputs.cross_attentions, 1030 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 1031 | encoder_hidden_states=outputs.encoder_hidden_states, 1032 | encoder_attentions=outputs.encoder_attentions, 1033 | ) 1034 | 1035 | class BartS2SEditor(torch.nn.Module, ParallelDecodingMixin): 1036 | """ 1037 | 
S2S model 1038 | """ 1039 | 1040 | def __init__(self, align_model, tokenizer, model_file='facebook/bart-base', from_pretrained=True, 1041 | alignment_baseline_score=0.2): 1042 | super().__init__() 1043 | 1044 | self.align_model = align_model 1045 | for p in self.align_model.parameters(): 1046 | p.requires_grad = False 1047 | self.alignment_baseline_score=alignment_baseline_score 1048 | self.tokenizer = tokenizer 1049 | vocab_size = len(self.tokenizer) 1050 | 1051 | config = AutoConfig.from_pretrained(model_file) 1052 | config.n_type_ids = 5 1053 | 1054 | if from_pretrained: 1055 | self.bart_model = BartForConditionalGenerationTType.from_pretrained( 1056 | model_file, config=config) 1057 | else: 1058 | self.bart_model = BartForConditionalGenerationTType.from_config(config) 1059 | self.bart_model.resize_token_embeddings(vocab_size) 1060 | 1061 | def forward(self, **kwargs): 1062 | return self.bart_model(**kwargs) 1063 | 1064 | def compute_loss(self, input_ids, token_type_ids, attention_mask, labels): 1065 | model_loss = self.bart_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, 1066 | token_type_ids=token_type_ids).loss 1067 | metrics = {'loss': model_loss.item()} 1068 | return model_loss, metrics 1069 | 1070 | def prep_canvas(self, canvas): 1071 | tokens, type_ids = canvas.tokens, canvas.type_ids 1072 | token_ids = [self.tokenizer.bos_token_id] + \ 1073 | self.tokenizer.convert_tokens_to_ids(tokens) +\ 1074 | [self.tokenizer.eos_token_id] 1075 | type_ids = [0] + type_ids + [0] 1076 | assert len(token_ids) == len(type_ids) 1077 | return token_ids, type_ids 1078 | 1079 | def prep_canvases(self, canvases, device=torch.device('cpu')): 1080 | token_ids, type_ids = zip(*[self.prep_canvas(c) for c in canvases]) 1081 | max_length = np.max([len(ids) for ids in token_ids]) 1082 | input_ids = torch.ones((len(canvases), max_length), dtype=torch.long, 1083 | device=device) * self.tokenizer.pad_token_id 1084 | attention_mask = torch.zeros(input_ids.shape, 
dtype=torch.int32, device=device) 1085 | token_type_ids = torch.zeros_like(attention_mask) 1086 | for i in range(len(canvases)): 1087 | input_ids[i, :len(token_ids[i])] = torch.tensor(token_ids[i]) 1088 | token_type_ids[i, :len(token_ids[i])] = torch.tensor(type_ids[i]) 1089 | attention_mask[i, :len(token_ids[i])] = 1 1090 | return input_ids, attention_mask, token_type_ids 1091 | 1092 | def prep_target(self, target): 1093 | target = [self.tokenizer.bos_token_id] +\ 1094 | self.tokenizer.convert_tokens_to_ids(target)+\ 1095 | [self.tokenizer.eos_token_id] 1096 | return target 1097 | 1098 | def prep_targets(self, targets, device=torch.device('cpu')): 1099 | token_ids = [self.prep_target(t) for t in targets] 1100 | max_length = np.max([len(ids) for ids in token_ids]) 1101 | input_ids = torch.ones((len(targets), max_length), dtype=torch.long, 1102 | device=device) * -100 1103 | for i in range(len(targets)): 1104 | input_ids[i, :len(token_ids[i])] = torch.tensor(token_ids[i]) 1105 | return input_ids 1106 | 1107 | def prep_gen(self, gen): 1108 | return self.tokenizer.convert_tokens_to_ids(gen) 1109 | 1110 | def prep_generations(self, generations, device=torch.device('cpu')): 1111 | token_ids = [self.prep_gen(g) for g in generations] 1112 | max_length = np.max([len(ids) for ids in token_idxs]) 1113 | input_ids = torch.ones((len(targets), max_length), dtype=torch.long, 1114 | device=device) * self.tokenizer.pad_token_id 1115 | input_lengths = [] 1116 | for i in range(len(generations)): 1117 | input_ids[i, :len(token_ids[i])] = torch.tensor(token_ids[i]) 1118 | input_lengths.append(len(token_ids[i])) 1119 | return input_ids, input_lengths 1120 | 1121 | def prep_batch(self, batch, device=torch.device('cpu'), **kwargs): 1122 | canvases = [a.get_source_canvas() for a in batch] 1123 | input_ids, attention_mask, type_ids = self.prep_canvases(canvases, device=device) 1124 | targets = [a.get_target_tokens() for a in batch] 1125 | labels = self.prep_targets(targets, device=device) 
1126 | 1127 | return input_ids, type_ids, attention_mask, labels 1128 | 1129 | def move_batch(self, batch, device): 1130 | input_ids, type_ids, attention_mask, labels = batch 1131 | return input_ids.to(device), type_ids.to(device), attention_mask.to(device), labels.to(device) 1132 | 1133 | # ====== DECODING ============ 1134 | 1135 | # @staticmethod 1136 | # def canvas_len(canvas): 1137 | # canvas, gen = canvas 1138 | # return len(canvas) 1139 | # 1140 | # def forward_canvases(self, canvases, device=torch.device('cpu'), move_to_cpu=True): 1141 | # batch_size = len(canvases) 1142 | # canvases, gen = zip(*canvases) 1143 | # input_ids, attention_mask, type_ids = self.prep_canvases(canvases, device=device) 1144 | # gen_ids, gen_lengths = self.prep_generations(gen, device=device) 1145 | # with torch.no_grad(): 1146 | # model_out = self.bart_model(input_ids=input_ids, attention_mask=attention_mask, 1147 | # token_type_ids=type_ids, decoder_input_ids=gen_ids) 1148 | # out = [] 1149 | # for i in range(batch_size): 1150 | 1151 | 1152 | def batch_generate(self, canvases, device=torch.device('cpu'), **kwargs): 1153 | generations = [] 1154 | for c in tqdm.tqdm(canvases): 1155 | input_ids, attention_mask, type_ids = self.prep_canvases([c], device=device) 1156 | # print(input_ids.shape, attention_mask.shape) 1157 | gen = self.bart_model.generate(input_ids=input_ids, 1158 | attention_mask=attention_mask, token_type_ids=type_ids, early_stopping=False,**kwargs) 1159 | # print(gen.shape) 1160 | # for g in gen.cpu().numpy(): 1161 | # print(self.tokenizer.decode(g)) 1162 | tokens = self.tokenizer.convert_ids_to_tokens(gen[0].cpu().numpy()) 1163 | tokens = [t for t in tokens if not t in ('', '', '')] # get rid of extraneous tokens 1164 | # generations.append(self.tokenizer.convert_ids_to_tokens(gen[0][2:-1].cpu().numpy())) 1165 | generations.append(tokens) 1166 | generated_canvases = [] 1167 | # print(generations) 1168 | for alignment in batch_align_canvases(canvases, generations, 
self.align_model, self.tokenizer, 1169 | baseline_score=self.alignment_baseline_score, device=device): 1170 | alignment.push_forward(alignment.get_non_const_ops()) 1171 | generated_canvases.append(alignment.get_source_canvas()) 1172 | return generated_canvases 1173 | 1174 | def generate(self, canvas, **kwargs): 1175 | return self.batch_generate([canvas], **kwargs)[0] 1176 | 1177 | #def sample_action(stop_out, location_out, action_out, vocab_out, input_ids, sample=False): 1178 | # if sample: 1179 | # stop = int(np.random.random() <= stop_out.item()) 1180 | # else: 1181 | # stop = int(stop_out.item() > 0.5) 1182 | # if stop == 1: 1183 | # return stop, None, None, None 1184 | # 1185 | # if sample: 1186 | # refinement_idx = torch.multinomial(F.softmax(location_out, dim=-1), 1).item() 1187 | # else: 1188 | # refinement_idx = torch.argmax(F.softmax(location_out, dim=-1)).item() 1189 | # if refinement_idx > 0: 1190 | # if sample: 1191 | # action_idx = torch.multinomial(F.softmax(action_out[:, refinement_idx, :], dim=-1), 1).item() 1192 | # else: 1193 | # action_idx = torch.argmax(F.softmax(action_out[:, refinement_idx, :], dim=-1)).item() 1194 | # else: # can only insert when selecting the sentinel 1195 | # action_idx = 0 1196 | # 1197 | # if action_idx != 2: 1198 | # if sample: 1199 | # token = torch.multinomial(F.softmax(vocab_out[:, refinement_idx, action_idx, :], dim=-1), 1).squeeze().item() 1200 | # else: 1201 | # top_tokens = torch.argsort(F.softmax(vocab_out[:, refinement_idx, action_idx, :], dim=-1)).squeeze() 1202 | # token = top_tokens[-1].item() 1203 | # if token == input_ids[:, refinement_idx] and action_idx == 1: token = top_tokens[-2].item() 1204 | # else: token = None 1205 | # 1206 | # return stop, refinement_idx, action_idx, token 1207 | # 1208 | 1209 | 1210 | def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): 1211 | """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 1212 | Args: 1213 | logits: 
logits distribution shape (vocabulary size) 1214 | top_k >0: keep only top k tokens with highest probability (top-k filtering). 1215 | top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). 1216 | Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) 1217 | """ 1218 | assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear 1219 | top_k = min(top_k, logits.size(-1)) # Safety check 1220 | if top_k > 0: 1221 | # Remove all tokens with a probability less than the last token of the top-k 1222 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 1223 | logits[indices_to_remove] = filter_value 1224 | 1225 | if top_p > 0.0: 1226 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 1227 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 1228 | 1229 | # Remove tokens with cumulative probability above the threshold 1230 | sorted_indices_to_remove = cumulative_probs > top_p 1231 | # Shift the indices to the right to keep also the first token above the threshold 1232 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 1233 | sorted_indices_to_remove[..., 0] = 0 1234 | 1235 | indices_to_remove = sorted_indices[sorted_indices_to_remove] 1236 | logits[indices_to_remove] = filter_value 1237 | return logits 1238 | 1239 | 1240 | -------------------------------------------------------------------------------- /infosol/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import wandb 4 | import pickle 5 | import json 6 | import random 7 | import tqdm 8 | import torch 9 | import itertools 10 | import numpy as np 11 | import argparse 12 | import datetime 13 | 14 | from distutils.util import strtobool 15 | from torch.utils.data import DataLoader 16 | from torch.utils.tensorboard import SummaryWriter 17 | from 
datasets import DatasetDict, Dataset, load_from_disk, concatenate_datasets, load_dataset 18 | from transformers import BertModel, AutoTokenizer, BartModel 19 | from infosol.models.word_edit_model import BertEditor, sample_trajectory, BartS2SEditor 20 | from infosol.env import WordEditOracle, EditingEnvironment 21 | from infosol.alignment import * 22 | 23 | DATA_DIR = #/path/to/data_dir/ 24 | SAVE_DIR = #/path/to/default/save_dir/ 25 | IDF_PATH = os.path.join(DATA_DIR, 'misc', 'cnn_bart_idfs.pickle') 26 | TF_PATH = os.path.join(DATA_DIR, 'misc', 'cnn_bart_tfs.pickle') 27 | DATA_PATH = os.path.join(DATA_DIR, 'cnn', 'filtered_bart_64') 28 | 29 | def custom_cat_datasets(data1, data2): 30 | """ 31 | Utility to concatenate two huggingface datasets 32 | """ 33 | l1, l2 = len(data1), len(data2) 34 | data1 = data1.to_dict() 35 | data2 = data2.to_dict() 36 | keys = set(data1.keys()).union(set(data2.keys())) 37 | cat_data= {} 38 | for k in list(keys): 39 | d1 = data1.get(k) 40 | if d1 is None: 41 | d1 = [None] * l1 42 | d2 = data2.get(k) 43 | if d2 is None: 44 | d2 = [None] * l2 45 | cat_data[k] = d1 + d2 46 | return Dataset.from_dict(cat_data) 47 | 48 | class WBLogger(): 49 | 50 | def log(self, metrics): 51 | wandb.log(metrics) 52 | 53 | class Train(): 54 | 55 | """ 56 | Base training job. 
Doesn't use any generations from model 57 | """ 58 | 59 | def __init__( 60 | self, 61 | config=None, 62 | log_dir=None, 63 | save_dir=SAVE_DIR, 64 | rng=None, 65 | project_name=None, 66 | run_name=None, 67 | accumulation_steps=8, 68 | learning_rate=1e-4, 69 | n_epochs=1, 70 | report_every=50, 71 | val_every=1000, 72 | device=torch.device('cpu'), 73 | resume_from_epoch=0, 74 | keep_token_type_ids=True, 75 | track_gradient_norm=False, 76 | clip_grad_norm=False, 77 | max_grad_norm=2000, 78 | **kwargs): 79 | 80 | self.config = config 81 | self.project_name = project_name 82 | self.run_name = run_name 83 | self.rng = rng 84 | self.save_dir=save_dir 85 | self.model_save_path = os.path.join(self.save_dir, 'WEIGHTS.bin') 86 | self.log_dir=log_dir 87 | self.accumulation_steps = accumulation_steps 88 | self.n_epochs = n_epochs 89 | self.report_every = report_every 90 | self.val_every = val_every 91 | self.device = device 92 | self.resume_from_epoch = resume_from_epoch 93 | self.keep_token_type_ids = keep_token_type_ids 94 | self.track_gradient_norm = track_gradient_norm 95 | self.clip_grad_norm = clip_grad_norm 96 | self.max_grad_norm = max_grad_norm 97 | 98 | # self.logger = SummaryWriter(log_dir = self.log_dir) 99 | self.logger = WBLogger() 100 | 101 | self.kwargs = kwargs 102 | 103 | print("Loading environment") 104 | self.load_env(**kwargs) 105 | print("Loading data") 106 | self.load_data(**kwargs) 107 | print("Loading model") 108 | self.load_model(**kwargs) 109 | print("Loading optimizer") 110 | self.load_optimizer(**kwargs) 111 | 112 | def log(self, logdict, prefix, i, iter_name='iter'): 113 | for n in logdict: 114 | metrics = { 115 | '/'.join((prefix,n)): np.mean(logdict[n]) for n in logdict} 116 | metrics[iter_name] = i 117 | self.logger.log(metrics) 118 | # self.logger.add_scalar('/'.join((prefix, n)), 119 | # np.mean(logdict[n]), 120 | # i) 121 | 122 | def load_env(self, 123 | env_type='bart', 124 | idf_path=IDF_PATH, 125 | sort_ops='sort', 126 | 
adjacent_ops=False, 127 | avoid_delete=False, 128 | contiguous_edits=False, 129 | complete_words=True, 130 | baseline_score=0.3, 131 | oracle_stop_p=0.25, 132 | n_oracle_hints=-1, 133 | **kwargs): 134 | print(f'n_oracle_hints={n_oracle_hints}') 135 | 136 | with open(idf_path, 'rb') as f: 137 | idf_dict = pickle.load(f) 138 | 139 | if env_type == 'bert': 140 | self.align_model = BertModel.from_pretrained('bert-base-uncased') 141 | self.align_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') 142 | token_no_space = lambda x: x.startswith('#') 143 | elif env_type == 'bart': 144 | self.align_model = BartModel.from_pretrained('facebook/bart-base').encoder 145 | self.align_tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base') 146 | token_no_space=lambda x: not x.startswith('Ġ') 147 | else: 148 | raise ValueError('Unknown env type: {}'.format(env_type)) 149 | 150 | self.oracle = WordEditOracle(self.align_model, self.align_tokenizer, idf_dict, 151 | sort_ops=sort_ops, adjacent_ops=adjacent_ops, 152 | avoid_delete=avoid_delete, baseline_score=baseline_score, 153 | contiguous_edits=contiguous_edits, complete_words=complete_words, 154 | token_no_space=token_no_space) 155 | self.env = EditingEnvironment(self.oracle, oracle_stop_p, n_oracle_edits=n_oracle_hints) 156 | 157 | def load_data(self, 158 | data_path=DATA_PATH, 159 | batch_size=16, 160 | max_train_edits=0, 161 | max_val_edits=1000, 162 | **kwargs): 163 | self.data = load_from_disk(data_path) 164 | self.batch_size = batch_size 165 | 166 | max_train_edits = len(self.data['train']) if max_train_edits == 0 else max_train_edits 167 | max_val_edits = len(self.data['val']) if max_val_edits == 0 else max_val_edits 168 | 169 | self.data['train'] = self.data['train'].shuffle(generator=self.rng).select(list(range(max_train_edits))) 170 | self.data['val'] = self.data['val'].shuffle(generator=self.rng).select(list(range(max_val_edits))) 171 | 172 | def load_model(self, 173 | tf_path=TF_PATH, 174 | 
model_name='bart', 175 | noise_frac=0.0, 176 | resume_from_ckpt=None, 177 | **kwargs): 178 | with open(tf_path, 'rb') as f: 179 | tf_dict = pickle.load(f) 180 | tf_map, tf_weights = {}, [] 181 | for i,k in enumerate(tf_dict): 182 | tf_map[i] = k 183 | tf_weights.append(tf_dict[k]) 184 | vocab_sampler = VocabSampler(tf_weights, tf_map) 185 | 186 | self.model_name = model_name 187 | if model_name == 'bert': 188 | self.model = BertEditor( 189 | tokenizer=self.align_tokenizer, 190 | vocab_sampler=vocab_sampler, 191 | training_noise=noise_frac) 192 | elif model_name == 'bart': 193 | self.model = BertEditor( 194 | tokenizer=self.align_tokenizer, 195 | vocab_sampler=vocab_sampler, 196 | training_noise=noise_frac, 197 | model_type='bart', 198 | model_file='facebook/bart-base') 199 | elif model_name == 'bart-large': 200 | self.model = BertEditor( 201 | tokenizer=self.align_tokenizer, 202 | vocab_sampler=vocab_sampler, 203 | training_noise=noise_frac, 204 | model_type='bart', 205 | model_file='facebook/bart-large') 206 | elif model_name == 'barts2s': 207 | self.model = BartS2SEditor(self.align_model, self.align_tokenizer) 208 | elif model_name == 'barts2s-large': 209 | self.model = BartS2SEditor(self.align_model, self.align_tokenizer, model_file='facebook/bart-large') 210 | else: 211 | raise NotImplementedError(f'Unknown model name: {model_name}') 212 | 213 | if not resume_from_ckpt is None: 214 | print(f'loading model weights from ckpt {resume_from_ckpt}') 215 | self.model.load_state_dict(torch.load(resume_from_ckpt)) 216 | 217 | self.model = self.model.to(self.device) 218 | self.model = self.model.train() 219 | 220 | def load_optimizer(self, learning_rate=1e-4, **kwargs): 221 | if self.model_name in ('bert', 'bart', 'bart-large'): 222 | self.optimizer = torch.optim.Adam(params=self.model.parameters(), lr=learning_rate) 223 | elif self.model_name == 'barts2s': 224 | self.optimizer = torch.optim.Adam(params=self.model.bart_model.parameters(), lr=learning_rate) 225 | else: 
226 | raise NotImplementedError(f'Unknown model name: {model_name}') 227 | 228 | def pre_train(self): 229 | self.train_loader = DataLoader(self.data['train'], collate_fn = self.prep_batch, batch_size=self.batch_size, 230 | shuffle=True, drop_last=True) 231 | self.val_loader = DataLoader(self.data['val'], collate_fn = self.prep_batch, batch_size=self.batch_size, 232 | shuffle=False, drop_last=False) 233 | 234 | def train(self): 235 | 236 | def val_(min_val_loss): 237 | val_loss, val_metrics = self.validate(cur_iter) 238 | self.log(val_metrics, 'val', cur_iter+1) 239 | if val_loss <= min_val_loss: 240 | torch.save(self.model.state_dict(), self.model_save_path) 241 | min_val_loss = val_loss 242 | return min_val_loss 243 | 244 | self.pre_train() 245 | min_val_loss = 1e10 246 | print("Training") 247 | cur_iter = 0 248 | batch_num = 0 249 | with wandb.init(config=self.config, project=self.project_name, name=self.run_name, dir=self.log_dir) as wandb_run: 250 | wandb.watch(self.model, log='all') 251 | for e in range(self.resume_from_epoch, self.n_epochs): 252 | print("Starting epoch: {}".format(e)) 253 | self.optimizer.zero_grad() 254 | metrics = {} 255 | for i,batch in tqdm.tqdm(enumerate(self.train_loader), total=len(self.train_loader)): 256 | batch_num += 1 257 | metrics = self.train_step(batch, metrics) 258 | if (i+1) % self.accumulation_steps == 0: 259 | cur_iter += 1 260 | if self.clip_grad_norm: 261 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm = self.max_grad_norm) 262 | if self.track_gradient_norm: 263 | gradient_norm = 0 264 | for p in self.model.parameters(): 265 | if not p.requires_grad: continue 266 | gradient_norm += torch.square(p.grad).sum().item() 267 | gradient_norm = np.sqrt(gradient_norm) 268 | self.log({'grad_norm': gradient_norm}, 'train', cur_iter) 269 | self.optimizer.step() 270 | self.optimizer.zero_grad() 271 | if (cur_iter+1) % args.report_every == 0: 272 | self.log(metrics, 'train', cur_iter+1) 273 | if (cur_iter+1) % 
args.val_every == 0: 274 | min_val_loss = val_(min_val_loss) 275 | # val_loss, val_metrics = self.validate(cur_iter) 276 | # self.log(val_metrics, 'val', cur_iter+1) 277 | # if val_loss <= min_val_loss: 278 | # min_val_loss = val_loss 279 | # torch.save(self.model.state_dict(), self.model_save_path) 280 | metrics = {} 281 | self.post_epoch(e) 282 | 283 | _ = val_(min_val_loss) 284 | 285 | def post_epoch(self, e): 286 | return 287 | 288 | def train_step(self, batch, metrics): 289 | loss, batch_metrics = self.model.compute_loss(*batch) #TODO: logging 290 | for n in batch_metrics: 291 | if n not in metrics: 292 | metrics[n] = [] 293 | metrics[n].append(batch_metrics[n]) 294 | loss /= self.accumulation_steps 295 | loss.backward() 296 | return metrics 297 | 298 | def validate(self, cur_iter): 299 | self.model.eval() 300 | running_loss = 0 301 | n_batches = 0 302 | metrics = {} 303 | for batch in self.val_loader: # TODO: <= make dataloader 304 | batch = self.model.move_batch(batch, device) 305 | with torch.no_grad(): 306 | loss, batch_metrics = self.model.compute_loss(*batch) 307 | for n in batch_metrics: 308 | if n not in metrics: 309 | metrics[n] = [] 310 | metrics[n].append(batch_metrics[n]) 311 | running_loss += loss.item() 312 | n_batches += 1 313 | # for n,l in zip(val_metrics, (loss, e_loss, l_loss, o_loss, v_loss)): 314 | # val_metrics[n].append(l.item()) 315 | # batch = self.model.move_batch(batch, torch.device('cpu')) 316 | 317 | print("============> Val loss, epoch: iteration: {}".format(cur_iter)) 318 | print(running_loss/n_batches) 319 | # for n in val_metrics: 320 | # print("{}: {}".format(n, np.mean(val_metrics[n]))) 321 | # tb_writer.add_scalar('/'.join(('Val', n)), np.mean(val_metrics[n]), cur_iter) 322 | self.model.train() 323 | return running_loss/n_batches, metrics 324 | 325 | 326 | def prep_batch(self, batch): 327 | alignments = [] 328 | for b in batch: 329 | if self.keep_token_type_ids: 330 | token_type_ids = b.get('token_type_ids') 331 | else: 332 
| token_type_ids = None 333 | alignment = Alignment(b['alignment'], b['alignment_scores'], token_type_ids) 334 | alignment = self.env.oracle_edit(alignment=alignment, device=self.device, return_alignment=True) # TODO: move to sampling 335 | alignments.append(alignment) 336 | batch = self.model.prep_batch(alignments, device=self.device) 337 | return batch 338 | 339 | @classmethod 340 | def add_args(self, parser): 341 | parser.add_argument('--accumulation_steps', type=int, default=8) 342 | parser.add_argument('--n_epochs', type=int, default=20) 343 | parser.add_argument('--report_every', type=int, default=50) 344 | parser.add_argument('--val_every', type=int, default=1000) 345 | parser.add_argument('--resume_from_ckpt', type=str) 346 | parser.add_argument('--resume_from_epoch', type=int, default=0) 347 | parser.add_argument('--track_gradient_norm', type=lambda x:bool(strtobool(x)), default=False) 348 | parser.add_argument('--clip_grad_norm', type=lambda x:bool(strtobool(x)), default=False) 349 | parser.add_argument('--max_grad_norm', type=float, default=1000) 350 | 351 | data_group = parser.add_argument_group('data') 352 | data_group.add_argument('--batch_size', type=int, default=16) 353 | data_group.add_argument('--gen_batch_size', type=int, default=16) 354 | data_group.add_argument('--max_train_edits', type=int, default=500) 355 | data_group.add_argument('--max_val_edits', type=int, default=1000) 356 | data_group.add_argument('--data_path', type=str, default= 357 | '/data/scratch/faltings/data/infosol/cnn/filtered_bart_64') 358 | data_group.add_argument('--keep_token_type_ids', type=lambda x:bool(strtobool(x)), default=True) 359 | 360 | env_group = parser.add_argument_group('env') 361 | env_group.add_argument('--env_type', type=str, default='bart') 362 | env_group.add_argument('--oracle_stop_p', type=float, default=0.25) 363 | env_group.add_argument('--n_oracle_hints', type=int, default=-1) 364 | env_group.add_argument('--idf_path', type=str, default= 365 | 
'/data/scratch/faltings/data/infosol/misc/cnn_bart_idfs.pickle') 366 | env_group.add_argument('--n_return_actions', type=int, default=1) 367 | env_group.add_argument('--sort_ops', type=str, default='sort') 368 | env_group.add_argument('--avoid_delete', type=lambda x:bool(strtobool(x)), default=False) 369 | env_group.add_argument('--adjacent_ops', type=lambda x:bool(strtobool(x)), default=False) 370 | env_group.add_argument('--contiguous_edits', type=lambda x:bool(strtobool(x)), default=False) 371 | env_group.add_argument('--complete_words', type=lambda x:bool(strtobool(x)), default=True) 372 | 373 | model_group = parser.add_argument_group('model') 374 | model_group.add_argument('--model_name', type=str, default='bart') 375 | model_group.add_argument('--noise_frac', type=float, default=0.3) 376 | model_group.add_argument('--max_traj_length', type=int, default=64) 377 | model_group.add_argument('--tf_path', type=str, default= 378 | '/data/scratch/faltings/data/infosol/misc/cnn_bart_tfs.pickle') 379 | 380 | opt_group = parser.add_argument_group('optimizer') 381 | opt_group.add_argument('--learning_rate', type=float, default=1e-4) 382 | 383 | # parser.add_argument('--n_forward_iterations', type=int, default=1) TODO: ForwarTrain 384 | 385 | class GenerationInstance(): 386 | 387 | """ 388 | Used for training jobs that need generation 389 | """ 390 | 391 | def __init__(self, datum, env): 392 | self.done = False 393 | self.target = datum['target_tokens'] 394 | self.target_text = datum['target_text'] 395 | canvas = Canvas(datum['source_tokens'], datum['token_type_ids']) 396 | assert (not datum['alignment'] is None) and (not datum['alignment_scores'] is None) 397 | alignment = Alignment(datum['alignment'], datum['alignment_scores']) 398 | self.history = [(canvas.copy(), alignment.copy())] 399 | self.oracle_canvas = env.reset(alignment=alignment) 400 | 401 | def make_data(self, tokenizer): 402 | instances = [] 403 | for canvas, alignment in self.history: 404 | inst = { 405 | 
'source_text': canvas.render(tokenizer), 406 | 'source_tokens': canvas.tokens, 407 | 'token_type_ids': canvas.type_ids, 408 | 'target_text': self.target_text, 409 | 'target_tokens': self.target, 410 | 'alignment': alignment.alignment, 411 | 'alignment_scores': alignment.scores 412 | } 413 | instances.append(inst) 414 | return instances 415 | 416 | class DaggerTrain(Train): 417 | 418 | """ 419 | Dagger training job. Split into epochs that run through small batches of data generated 420 | from the model or sampled from the dataset. 421 | """ 422 | 423 | def __init__( 424 | self, 425 | top_k=10, 426 | top_p=1.0, 427 | stop_threshold=0.9, 428 | max_iter=16, 429 | do_sample=True, 430 | max_length=64, 431 | parallel_decode=False, 432 | n_processes=10, 433 | n_warmup_epochs=0, 434 | sampling_annealing_rate=1.0, 435 | dagger_sampling_rate=1.0, 436 | max_trajectory_length=2, 437 | sample_batch_size=4096, 438 | val_sample_batch_size=1024, 439 | sample_val_every_n_epoch=25, 440 | sample_train_every_n_epoch=10, 441 | **kwargs): 442 | 443 | super().__init__(**kwargs) 444 | 445 | self.top_k = top_k 446 | self.top_p = top_p 447 | self.stop_threshold = stop_threshold 448 | self.max_iter = max_iter 449 | self.do_sample = do_sample 450 | self.max_length = max_length 451 | self.parallel_decode = parallel_decode 452 | self.n_processes = n_processes 453 | 454 | self.n_warmup_epochs = n_warmup_epochs 455 | self.sampling_annealing_rate = sampling_annealing_rate 456 | self.sample_expert_p = dagger_sampling_rate 457 | self.dagger_sampling_rate = dagger_sampling_rate # base rate 458 | self.max_trajectory_length = max_trajectory_length 459 | self.sample_batch_size = sample_batch_size 460 | self.sample_val_every_n_epoch = sample_val_every_n_epoch 461 | self.sample_train_every_n_epoch = sample_train_every_n_epoch 462 | self.val_sample_batch_size = val_sample_batch_size 463 | 464 | self.model_save_path = os.path.join(self.save_dir, 'WEIGHTS.bin') 465 | 466 | def gen_episode(self, instances): 
467 | """ 468 | Generate one episode of a trajectory 469 | """ 470 | canvases = [inst.oracle_canvas for inst in instances] 471 | generations = self.gen_model(canvases) 472 | clean_generations = [g.clean() for g in generations] 473 | align_model, align_tokenizer = self.env.oracle.align_model, self.env.oracle.align_tokenizer 474 | targets = [inst.target for inst in instances] 475 | alignments = list(batch_align_canvases(clean_generations, targets, align_model, align_tokenizer, device=self.device)) 476 | for i,inst in enumerate(instances): 477 | inst.history.append((clean_generations[i].copy(), alignments[i].copy())) 478 | inst.oracle_canvas = self.env.oracle_edit(alignment=alignments[i]) 479 | return instances 480 | 481 | def sample_batch(self, data): 482 | 483 | """ 484 | Sample a new batch of data, i.e. generated trajectories from the current model 485 | """ 486 | 487 | instances = [GenerationInstance(d, self.env) for d in data] 488 | # for inst in instances: 489 | # if np.random.random() <= self.sample_expert_p: 490 | # inst.done=True 491 | 492 | self.env.oracle = self.env.oracle.to(self.device) 493 | self.model = self.model.eval() 494 | finished_instances = [] 495 | n_sampled_states = 0 496 | for i in range(self.max_trajectory_length): 497 | for inst in instances: 498 | if np.random.random() <= self.sample_expert_p or len(inst.history) >= self.max_trajectory_length: 499 | finished_instances.append(inst) 500 | inst.done = True 501 | n_sampled_states += len(inst.history) 502 | instances = [inst for inst in instances if not inst.done] 503 | # TODO: instead of fixing a maximum trajectory length, fix a max number of instances, then get a distribution over traj lengths 504 | # if n_sampled_states > self.sample_batch_size or len(instances) == 0: 505 | if len(instances) == 0: 506 | break 507 | 508 | instances = self.gen_episode(instances) 509 | finished_instances.extend(instances) 510 | self.env.oracle = self.env.oracle.cpu() 511 | self.model = self.model.train() 512 | 
513 | sampled_states = [] 514 | for inst in finished_instances: 515 | sampled_states.extend(inst.make_data(self.align_tokenizer)) 516 | return sampled_states 517 | 518 | def pre_train(self): 519 | # self.sampling_rate_updates = 0 520 | self.post_epoch(-1) 521 | 522 | def post_epoch(self, e): 523 | # Sample trajectories here! 524 | def get_batch(data, size): 525 | idxs = np.random.choice(np.arange(len(data)), size, replace=False) 526 | for i in idxs: 527 | yield data[int(i)] 528 | 529 | if e < self.n_warmup_epochs or (e+1) % self.sample_train_every_n_epoch == 0: 530 | print("Sampling train batch") 531 | if e >= self.n_warmup_epochs: 532 | # self.sampling_rate_updates += 1 533 | self.sample_expert_p *= self.sampling_annealing_rate 534 | # np.exp(self.sampling_rate_updates * np.log(self.sampling_annealing_rate)) * self.dagger_sampling_rate 535 | print(self.sample_expert_p) 536 | train_batch = self.sample_batch(get_batch(self.data['train'], self.sample_batch_size)) 537 | self.train_loader = DataLoader(train_batch, collate_fn = self.prep_batch, batch_size=self.batch_size, 538 | shuffle=True, drop_last=True) 539 | 540 | if (e+1) % self.sample_val_every_n_epoch == 0: 541 | print("Sampling val batch") 542 | val_batch = self.sample_batch(get_batch(self.data['val'], self.val_sample_batch_size)) 543 | self.val_loader = DataLoader(val_batch, collate_fn = self.prep_batch, batch_size=self.batch_size, 544 | shuffle=False, drop_last=False) 545 | 546 | 547 | # def sample_trajectory(self, datum): 548 | # alignments = [] 549 | # alignment = Alignment(alignment=datum['alignment'], scores=datum['alignment_scores']) 550 | # alignment = self.env.reset(alignment=alignment, return_alignment=True) 551 | # i = 0 552 | # while True: 553 | # alignments.append(alignment) 554 | # if np.random.random() > self.sample_expert_p\ 555 | # and i < self.max_trajectory_length: 556 | # canvas = alignment.get_source_canvas() 557 | # canvas = self.gen_model(canvas) 558 | # alignment,_ = 
self.env.step(canvas=canvas, return_alignment=True, device=self.device) 559 | # i += 1 560 | # else: 561 | # break 562 | # return alignments 563 | # 564 | # def prep_batch(self, batch): 565 | # alignments = [] 566 | # for b in batch: 567 | # alignments.extend(self.sample_trajectory(b)) 568 | # if len(alignments) >= self.batch_size: break 569 | # batch = self.model.prep_batch(alignments, device=self.device) 570 | # return batch 571 | 572 | def gen_model(self, canvases): 573 | if self.model_name in ('bart', 'bert', 'bart-large'): 574 | canvases = list(tqdm.tqdm(self.model.batch_depth_decode( 575 | canvases, 576 | top_k=self.top_k, 577 | max_batch_tokens=2048, 578 | device=self.device, 579 | # parallel=self.parallel_decode, 580 | return_idx=True, 581 | queue_size=2000, 582 | max_iter=self.max_iter, 583 | # n_processes=self.n_processes 584 | ), total=len(canvases))) 585 | canvases = [(i,c) for c,i in canvases] 586 | canvases = [c for i,c in sorted(canvases)] 587 | return canvases 588 | elif self.model_name == 'barts2s': #TODO 589 | canvases = self.model.batch_generate(canvases, device=self.device, 590 | do_sample=self.do_sample, top_p=self.top_p, max_length=self.max_length) 591 | # print(canvases) 592 | return canvases 593 | 594 | @classmethod 595 | def add_args(cls, parser): 596 | super().add_args(parser) 597 | 598 | model_gen_group = parser.add_argument_group('model generation') 599 | model_gen_group.add_argument('--top_k', type=int, default=10) 600 | model_gen_group.add_argument('--top_p', type=float, default=0.95) 601 | model_gen_group.add_argument('--stop_threshold', type=float, default=0.9) 602 | model_gen_group.add_argument('--parallel_decode', type=lambda x:bool(strtobool(x)), default=False) 603 | model_gen_group.add_argument('--n_processes', type=int, default=10) 604 | model_gen_group.add_argument('--do_sample', type=lambda x:bool(strtobool(x)), default=True) 605 | model_gen_group.add_argument('--max_length', type=int, default=64) 606 | 
model_gen_group.add_argument('--max_iter', type=int, default=32) 607 | 608 | dagger_group = parser.add_argument_group('dagger') 609 | dagger_group.add_argument('--n_warmup_epochs', type=int, default=0) 610 | dagger_group.add_argument('--sampling_annealing_rate', type=float, default=1.0) 611 | dagger_group.add_argument('--dagger_sampling_rate', type=float, default=0.5) 612 | dagger_group.add_argument('--max_trajectory_length', type=int, default=2) 613 | dagger_group.add_argument('--sample_batch_size', type=int, default=100) 614 | dagger_group.add_argument('--sample_val_every_n_epoch', type=int, default=50) 615 | dagger_group.add_argument('--sample_train_every_n_epoch', type=int, default=15) 616 | dagger_group.add_argument('--val_sample_batch_size', type=int, default=100) 617 | 618 | class ForwardTrain(Train): 619 | 620 | """ 621 | Forward training job. Similar to dagger but generates less frequently. 622 | Not used anymore 623 | """ 624 | 625 | def __init__( 626 | self, 627 | top_k=10, 628 | top_p=1.0, 629 | stop_threshold=1.0, 630 | max_iter=16, 631 | do_sample=True, 632 | max_length=64, 633 | n_forward_iter=2, 634 | resume_forward_iter=0, 635 | sample_alg='depth', 636 | sample_gen_reverse_steps=64, 637 | force_regen=False, 638 | n_processes=None, 639 | parallel_decode=False, 640 | **kwargs): 641 | 642 | super().__init__(**kwargs) 643 | 644 | self.top_k = top_k 645 | self.sample_alg = sample_alg 646 | self.top_p = top_p 647 | self.stop_threshold = stop_threshold 648 | self.max_iter = max_iter 649 | self.do_sample = do_sample 650 | self.max_length = max_length 651 | self.sample_gen_reverse_steps = sample_gen_reverse_steps 652 | self.n_forward_iter = n_forward_iter 653 | self.resume_forward_iter = resume_forward_iter 654 | self.force_regen = force_regen 655 | self.n_processes = n_processes 656 | self.parallel_decode = parallel_decode 657 | 658 | self.meta_log_dir = self.log_dir 659 | self.meta_run_name = self.run_name 660 | 661 | def gen_model(self, input_generator): 
662 | if self.model_name in ('bart', 'bart-large', 'bert'): 663 | if self.sample_alg == 'sample': 664 | for canvas, idx in self.model.batch_decode( 665 | input_generator, 666 | top_k = self.top_k, 667 | max_batch_tokens=2048, 668 | device=self.device, 669 | # parallel=self.parallel_decode, 670 | return_idx=True, 671 | queue_size=2000, 672 | max_iter=self.max_iter, 673 | # n_processes=self.n_processes 674 | ): 675 | yield canvas, idx 676 | elif self.sample_alg == 'depth': 677 | for canvas, idx in self.model.batch_depth_decode( 678 | input_generator, 679 | top_k = self.top_k, 680 | max_batch_tokens=2048, 681 | device=self.device, 682 | # parallel=self.parallel_decode, 683 | return_idx=True, 684 | queue_size=2000, 685 | max_iter=self.max_iter, 686 | stop_threshold=self.stop_threshold, 687 | # n_processes=self.n_processes 688 | ): 689 | yield canvas, idx 690 | elif self.model_name == 'barts2s': 691 | for i,canvas in enumerate(input_generator): 692 | canvas = self.model.generate( 693 | canvas, device=self.device, do_sample=self.do_sample, top_p=self.top_p, max_length=self.max_length 694 | ) 695 | yield canvas, i 696 | return 697 | 698 | def batch_generate(self, data, save_path, chunk_size=10000, **kwargs):# batch_size=1 for now because it is still not scaling up well (because of differences in canvas lengths, etc.) 
699 | def generator_(data, data_buffer): 700 | for i,datum in enumerate(data): 701 | if self.keep_token_type_ids: 702 | token_type_ids = datum.get('token_type_ids') 703 | else: 704 | token_type_ids = None 705 | alignment = Alignment(datum['alignment'], datum['alignment_scores'], token_type_ids) 706 | alignment = self.env.oracle_edit(alignment=alignment, return_alignment=True) 707 | 708 | # push forward 709 | non_const_ops = alignment.get_non_const_ops() 710 | n_forward_steps = max(0, len(non_const_ops) - self.sample_gen_reverse_steps) 711 | forward_ops = np.random.choice(non_const_ops, n_forward_steps) 712 | alignment.push_forward(forward_ops) 713 | data_buffer[i] = alignment 714 | yield alignment.get_source_canvas() 715 | 716 | # chunk generator 717 | def chunks(iterator, n): 718 | for first in iterator: # take one item out (exits loop if `iterator` is empty) 719 | rest_of_chunk = itertools.islice(iterator, 0, n - 1) 720 | yield itertools.chain([first], rest_of_chunk) # concatenate the first item back 721 | 722 | self.model = self.model.eval() 723 | with open(save_path, 'wt') as f: 724 | for i,chunk in enumerate(chunks(iter(data), chunk_size)): 725 | print(f'On chunk {i}') 726 | timeouts = 0 727 | data_buffer = {} 728 | popped_idxs = set() 729 | input_generator = generator_(chunk, data_buffer) 730 | output_generator = self.gen_model(input_generator) 731 | for canvas, idx in tqdm.tqdm(output_generator, total=chunk_size): 732 | if canvas is None: 733 | timeouts += 1 734 | print('timeout') 735 | continue 736 | try: 737 | alignment = data_buffer.pop(idx) 738 | except KeyError as e: 739 | print(np.max(list(data_buffer.keys()))) 740 | print(idx in popped_idxs) 741 | raise e 742 | popped_idxs.add(idx) 743 | canvas = canvas.clean() 744 | json_str = json.dumps({ 745 | 'source_tokens': canvas.tokens, 746 | 'token_type_ids': canvas.type_ids, 747 | # 'idx': idx, 748 | 'source_text': canvas.render(self.align_tokenizer), 749 | 'target_tokens': alignment.get_target_tokens(), 750 
| 'target_text': alignment.get_target_canvas().render(self.align_tokenizer), 751 | }) 752 | f.write(json_str + '\n') 753 | print(f'{timeouts} timeouts') 754 | 755 | def align_batch(self, batch): 756 | tokens_a = batch['source_tokens'] 757 | tokens_b = batch['target_tokens'] 758 | alignments = list( 759 | batch_align( 760 | tokens_a, tokens_b, self.align_model, self.align_tokenizer, 761 | baseline_score=0.3, device=self.device 762 | ) 763 | ) 764 | alignment_scores = [a.scores for a in alignments] 765 | alignments = [a.alignment for a in alignments] 766 | batch['alignment'] = alignments 767 | batch['alignment_scores'] = alignment_scores 768 | return batch 769 | 770 | def load_data(self, **kwargs): 771 | super().load_data(**kwargs) 772 | self.gen_data = self.data 773 | 774 | def pre_train(self): 775 | self.train_loader = DataLoader(self.data['train'], collate_fn = self.prep_batch, batch_size=self.batch_size, 776 | shuffle=True, drop_last=True) 777 | self.val_loader = DataLoader(self.data['val'], collate_fn = self.prep_batch, batch_size=self.batch_size, 778 | shuffle=False, drop_last=False) 779 | 780 | def train(self): 781 | if self.resume_forward_iter > 0: 782 | forward_iter = self.resume_forward_iter - 1 783 | iter_save_dir = os.path.join(self.save_dir, f'forward_iter_{forward_iter}') 784 | for i in range(forward_iter): 785 | gen_save_path = os.path.join(self.save_dir, f'forward_iter_{i}', 'gen_data') 786 | self.gen_data = load_from_disk(gen_save_path) 787 | self.add_data(self.gen_data) 788 | 789 | gen_save_path = os.path.join(iter_save_dir, 'gen_data') 790 | if not os.path.exists(gen_save_path) or self.force_regen: 791 | weights_path = os.path.join(iter_save_dir, 'WEIGHTS.bin') 792 | self.kwargs.update({'resume_from_ckpt': weights_path}) 793 | self.load_model(**self.kwargs) 794 | self.kwargs.update({'resume_from_ckpt': None}) 795 | self.generate(forward_iter, iter_save_dir) 796 | else: 797 | self.gen_data = load_from_disk(gen_save_path) 798 | 
self.add_data(self.gen_data) 799 | 800 | self.load_model(**self.kwargs) 801 | self.load_optimizer(**self.kwargs) 802 | 803 | for forward_iter in range(self.resume_forward_iter, self.n_forward_iter): 804 | 805 | iter_save_dir = os.path.join(self.save_dir, f'forward_iter_{forward_iter}') 806 | if not os.path.exists(iter_save_dir): 807 | os.makedirs(iter_save_dir) 808 | self.model_save_path = os.path.join(iter_save_dir, 'WEIGHTS.bin') 809 | 810 | iter_logdir = os.path.join(self.meta_log_dir, f'forward_iter_{forward_iter}') 811 | if not os.path.exists(iter_logdir): 812 | os.makedirs(iter_logdir) 813 | self.log_dir = iter_logdir 814 | 815 | self.run_name = '-'.join((self.meta_run_name, f'forward_iter_{forward_iter}')) 816 | # self.logger = SummaryWriter(log_dir = iter_logdir) 817 | 818 | super().train() 819 | 820 | if forward_iter == (self.n_forward_iter - 1): 821 | break 822 | 823 | # load best checkpoint 824 | self.kwargs.update({'resume_from_ckpt': self.model_save_path}) 825 | self.load_model(**self.kwargs) 826 | self.kwargs.update({'resume_from_ckpt': None}) 827 | 828 | self.generate(forward_iter, iter_save_dir) 829 | 830 | print('reloading model') 831 | 832 | self.load_model(**self.kwargs) 833 | self.load_optimizer(**self.kwargs) 834 | 835 | def generate(self, forward_iter, iter_save_dir): 836 | 837 | print('generating') 838 | train_save_path = os.path.join(iter_save_dir, 'train_data') 839 | val_save_path = os.path.join(iter_save_dir, 'val_data') 840 | self.batch_generate(self.gen_data['train'], train_save_path) 841 | self.batch_generate(self.gen_data['val'], val_save_path) 842 | 843 | def listdict2dictlist(ld): 844 | keys = ld[0].keys() 845 | return { 846 | k: [d[k] for d in ld] for k in keys 847 | } 848 | 849 | train_generations, val_generations = [], [] 850 | with open(train_save_path, 'rt') as f: 851 | for line in f: 852 | train_generations.append(json.loads(line)) 853 | with open(val_save_path, 'rt') as f: 854 | for line in f: 855 | 
val_generations.append(json.loads(line)) 856 | 857 | generated_data = DatasetDict({ 858 | 'train': Dataset.from_dict(listdict2dictlist(train_generations)), 859 | 'val': Dataset.from_dict(listdict2dictlist(val_generations)) 860 | }) 861 | 862 | # generated_data = DatasetDict({ 863 | # 'train': load_dataset('json', train_save_path), 864 | # 'val': load_dataset('json', val_save_path) 865 | # }) 866 | 867 | self.align_model = self.align_model.to(self.device) 868 | print('aligning') 869 | generated_data = generated_data.map(self.align_batch, batched=True, batch_size=64) 870 | self.align_model = self.align_model.cpu() 871 | self.gen_data = generated_data 872 | 873 | os.remove(train_save_path) 874 | os.remove(val_save_path) 875 | 876 | gen_save_path = os.path.join(iter_save_dir, 'gen_data') 877 | generated_data.save_to_disk(gen_save_path) 878 | 879 | self.add_data(generated_data) 880 | 881 | def add_data(self, dataset): 882 | 883 | self.data = DatasetDict({ 884 | 'train': custom_cat_datasets(dataset['train'], self.data['train']), 885 | 'val': custom_cat_datasets(dataset['val'], self.data['val']) 886 | }) 887 | 888 | @classmethod 889 | def add_args(cls, parser): 890 | super().add_args(parser) 891 | 892 | parser.add_argument('--sample_alg', choices=['depth', 'sample'], default='depth') 893 | parser.add_argument('--top_k', type=int, default=10) 894 | parser.add_argument('--top_p', type=float, default=0.95) 895 | parser.add_argument('--stop_threshold', type=float, default=0.9) 896 | parser.add_argument('--max_iter', type=int, default=32) 897 | parser.add_argument('--do_sample', type=lambda x:bool(strtobool(x)), default=False) 898 | parser.add_argument('--max_length', type=int, default=64) 899 | parser.add_argument('--n_forward_iter', type=int, default=2) 900 | parser.add_argument('--resume_forward_iter', type=int, default=0) 901 | parser.add_argument('--sample_gen_reverse_steps', type=int, default=64) 902 | parser.add_argument('--force_regen', type=lambda 
x:bool(strtobool(x)), default=False) 903 | parser.add_argument('--parallel_decode', type=lambda x:bool(strtobool(x)), default=False) 904 | parser.add_argument('--n_processes', type=int) 905 | 906 | #class NBTrain(Train): 907 | # 908 | # def __init__( 909 | # self, 910 | # top_k=10): 911 | # 912 | # self.top_k = top_k 913 | 914 | 915 | 916 | if __name__ == '__main__': 917 | 918 | parser = argparse.ArgumentParser() 919 | parser.add_argument('--cuda_device', type=int, default=0) 920 | parser.add_argument('--use_timestamp', type=lambda x:bool(strtobool(x)), default=False) 921 | parser.add_argument('--seed', type=int, default=42) 922 | parser.add_argument('--run_dir', type=str, 923 | default='/Mounts/rbg-storage1/users/faltings/cache/infosol/experiment_results/train/') 924 | parser.add_argument('--run_name', type=str, 925 | default='debug') 926 | parser.add_argument('--project_name', type=str, 927 | default='infosol') 928 | subparsers = parser.add_subparsers() 929 | 930 | base_parser = subparsers.add_parser('base') 931 | base_parser.set_defaults(func=Train) 932 | Train.add_args(base_parser) 933 | 934 | dagger_parser = subparsers.add_parser('dagger') 935 | dagger_parser.set_defaults(func=DaggerTrain) 936 | DaggerTrain.add_args(dagger_parser) 937 | 938 | forward_parser = subparsers.add_parser('forward') 939 | forward_parser.set_defaults(func=ForwardTrain) 940 | ForwardTrain.add_args(forward_parser) 941 | 942 | args = parser.parse_args() 943 | 944 | device=torch.device('cuda:{}'.format(args.cuda_device)) if torch.cuda.is_available() else torch.device('cpu') 945 | 946 | if args.use_timestamp: 947 | run_id = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 948 | run_dir = os.path.join(args.run_dir, args.run_name, run_id) 949 | else: run_dir = os.path.join(args.run_dir, args.run_name) 950 | if not os.path.exists(run_dir): 951 | os.makedirs(run_dir) 952 | 953 | config_path = os.path.join(run_dir, 'config.conf') 954 | with open(config_path, 'w') as f: 955 | for k,v in 
vars(args).items(): 956 | f.write('--' + k + '\n' + str(v) + '\n') 957 | 958 | log_dir = os.path.join(run_dir, 'logs') 959 | # if os.path.exists(log_dir): 960 | # shutil.rmtree(log_dir) 961 | # for logfile in os.listdir(log_dir): 962 | # os.remove(os.path.join(log_dir, logfile)) 963 | if not os.path.exists(log_dir): 964 | os.makedirs(log_dir) 965 | 966 | rng = np.random.default_rng(seed = args.seed) 967 | train = args.func( 968 | config = vars(args), 969 | save_dir = run_dir, 970 | log_dir = log_dir, 971 | device = device, 972 | rng = rng, 973 | **vars(args)) 974 | train.train() 975 | -------------------------------------------------------------------------------- /infosol/utils/data/cnn_dailymail.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from nltk.tokenize import sent_tokenize 4 | 5 | def prep_data_inst(tokenizer, 6 | data_inst, 7 | n_article_sentences = 3, 8 | prompt_length=10, 9 | use_highlights=True): 10 | highlight_prefix = 'Summary: ' 11 | article_prefix = 'Article: ' 12 | highlights = highlight_prefix + ' '.join(data_inst['highlights'].split('\n')) + '\n' 13 | highlight_ids = tokenizer(highlights, return_tensors='pt')['input_ids'].squeeze() 14 | article_text = ' '.join(sent_tokenize(data_inst['article'])[:n_article_sentences]) + '\n' 15 | article_ids = tokenizer(article_text, return_tensors='pt')['input_ids'].squeeze() 16 | article_prefix_length = 0 17 | if use_highlights: 18 | article_text = article_prefix + article_text 19 | article_length_no_prefix = article_ids.size(0) 20 | article_ids = tokenizer(article_text, return_tensors='pt')['input_ids'].squeeze() 21 | article_prefix_length = article_ids.size(0) - article_length_no_prefix 22 | if use_highlights: 23 | input_ids = torch.cat((highlight_ids, article_ids[:prompt_length])) 24 | highlight_length = highlight_ids.size(0) 25 | else: 26 | input_ids = article_ids[:prompt_length] 27 | highlight_length = 0 28 | target_ids = article_ids 29 | 
return input_ids, target_ids, highlight_length, article_prefix_length 30 | -------------------------------------------------------------------------------- /infosol/utils/keywords.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | 4 | def order_keywords(target_text, keywords, highlighter): 5 | highlighted_string = highlighter.highlight(target_text, keywords) 6 | kw_re = re.compile('[^<]*') 7 | ordered_keywords = kw_re.findall(highlighted_string) 8 | ordered_keywords = [re.sub('', '', re.sub('', '', kw)) for kw in ordered_keywords] 9 | for i in range(len(ordered_keywords)): 10 | kw_matched = False 11 | for kw_pair in keywords: 12 | if kw_pair[0] == ordered_keywords[i]: 13 | ordered_keywords[i] = kw_pair 14 | kw_matched = True 15 | break 16 | if not kw_matched: ordered_keywords[i] = (ordered_keywords[i], 0) 17 | return ordered_keywords 18 | 19 | def choose_top_kws(kws, n_kws): 20 | scores = [k[1] for k in kws] 21 | top_kws = [kws[i] for i in np.sort(np.argsort(scores)[-n_kws:])] 22 | return top_kws 23 | -------------------------------------------------------------------------------- /infosol/utils/pointer_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | 4 | sys.path.append('/home/v-fefal/POINTER') 5 | from pytorch_transformers import BertForMaskedLM 6 | from pytorch_transformers.tokenization_bert import BertTokenizer 7 | from inference import convert_example_to_features, greedy_search, PregeneratedDataset 8 | from util import MAX_TURN, PREVENT_FACTOR, PROMOTE_FACTOR, PREVENT_LIST, REDUCE_LIST, STOP_LIST, boolean_string 9 | 10 | class POINTERArgs: 11 | 12 | def __init__(self, bert_model, do_lower_case=False, noi_decay=1, reduce_decay=1, prevent=True, 13 | reduce_stop=True, lessrepeat=True, max_seq_length=256, no_ins_at_first=False, verbose=0): 14 | self.bert_model = bert_model 15 | self.do_lower_case = do_lower_case 
16 | self.noi_decay = noi_decay 17 | self.reduce_decay = reduce_decay 18 | self.prevent = prevent 19 | self.reduce_stop = reduce_stop 20 | self.lessrepeat = lessrepeat 21 | self.max_seq_length = max_seq_length 22 | self.no_ins_at_first = no_ins_at_first 23 | self.verbose = verbose 24 | 25 | class POINTERWrapper: 26 | 27 | def __init__(self, args): 28 | self.args = args 29 | self.tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 30 | self.model = BertForMaskedLM.from_pretrained(args.bert_model) 31 | self.model.eval() 32 | 33 | prevent = [ self.tokenizer.vocab.get(x) for x in PREVENT_LIST] if args.prevent else None 34 | if args.reduce_stop: 35 | # import pdb; pdb.set_trace() 36 | reduce_l = REDUCE_LIST | STOP_LIST 37 | reduce = None 38 | if args.prevent: 39 | reduce = [ self.tokenizer.vocab.get(x) for x in reduce_l] 40 | reduce = [s for s in reduce if s] 41 | self.prevent = prevent 42 | self.reduce = reduce 43 | 44 | def prep_input(self, inp_tokens, device): 45 | # canvas = [c.strip().lstrip() for c in canvas] 46 | features = convert_example_to_features(inp_tokens, self.tokenizer, self.args.max_seq_length, 47 | no_ins_at_first = self.args.no_ins_at_first, id=0, tokenizing=True) 48 | out = (features.input_ids, features.input_mask, features.segment_ids, features.lm_label_ids, features.no_ins) 49 | out = tuple(torch.tensor(o.reshape(1,-1)).long().to(device) for o in out) 50 | return out 51 | 52 | def generate_(self, input_ids, segment_ids, input_mask, no_ins): 53 | sep_tok = self.tokenizer.vocab['[SEP]'] 54 | cls_tok = self.tokenizer.vocab['[CLS]'] 55 | pad_tok = self.tokenizer.vocab['[PAD]'] 56 | 57 | predict_ids = greedy_search(self.model, input_ids, segment_ids, input_mask, no_ins = no_ins, args=self.args, 58 | tokenizer=self.tokenizer, prevent=self.prevent, reduce=self.reduce) 59 | output = " ".join([str(self.tokenizer.ids_to_tokens.get(x, "noa").encode('ascii', 'ignore').decode('ascii')) for x in 
predict_ids[0].detach().cpu().numpy() if x!=sep_tok and x != pad_tok and x != cls_tok]) 60 | output = output.replace(" ##", "") 61 | return output 62 | 63 | def generate(self, inp_tokens, device=torch.device('cpu')): 64 | input_ids, input_mask, segment_ids, lm_label_ids, no_ins = self.prep_input(inp_tokens, device) 65 | return self.generate_(input_ids, segment_ids, input_mask, no_ins) 66 | -------------------------------------------------------------------------------- /jobs/interactive/cnn-bart-s2s-len: -------------------------------------------------------------------------------- 1 | BartS2S --model_path models/dagger/cnn-bart-s2s/WEIGHTS.bin --out_path out/main/interactive/cnn-bart-s2s/1x6_lp3.3 --data_path data/cnn_bart --max_data 2000 --idf_path data/misc/cnn_bart_idfs.pickle --n_episodes 1 --n_oracle_edits 6 --adjacent_ops True --complete_words True --contiguous_edits True --bleu_ngrams 1 --length_penalty 3.3 2 | BartS2S --model_path models/dagger/cnn-bart-s2s/WEIGHTS.bin --out_path out/main/interactive/cnn-bart-s2s/2x3_lp1.5 --data_path data/cnn_bart --max_data 2000 --idf_path data/misc/cnn_bart_idfs.pickle --n_episodes 2 --n_oracle_edits 3 --adjacent_ops True --complete_words True --contiguous_edits True --bleu_ngrams 1 --length_penalty 1.5 3 | BartS2S --model_path models/dagger/cnn-bart-s2s/WEIGHTS.bin --out_path out/main/interactive/cnn-bart-s2s/3x2_lp0.8 --data_path data/cnn_bart --max_data 2000 --idf_path data/misc/cnn_bart_idfs.pickle --n_episodes 3 --n_oracle_edits 2 --adjacent_ops True --complete_words True --contiguous_edits True --bleu_ngrams 1 --length_penalty 0.8 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.1 2 | datasets==2.4.0 3 | matplotlib==3.5.1 4 | nltk==3.7 5 | setuptools==65.5.1 6 | tqdm==4.64.0 7 | transformers==4.18.0 8 | numpy==1.22.4 9 | wandb==0.13.2 10 | scikit-learn==1.1.1 11 
| -------------------------------------------------------------------------------- /scripts/dowload_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget "https://onedrive.live.com/download?cid=7AA4E17800AB44C8&resid=7AA4E17800AB44C8%21510310&authkey=ABxsQgnGdJk7wWI&download=1" -O infosol_models.tar.zst 4 | zstd -dc infosol_models.tar.zst | tar -xvf - ./models 5 | -------------------------------------------------------------------------------- /scripts/make_data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import json 4 | import torch 5 | import itertools 6 | import tqdm 7 | import pickle 8 | import random 9 | import argparse 10 | import numpy as np 11 | import matplotlib.pyplot as pltq 12 | 13 | from sklearn.feature_extraction.text import CountVectorizer 14 | from datasets import Dataset, DatasetDict, load_from_disk, load_dataset, concatenate_datasets 15 | from transformers import AutoTokenizer, BertModel, BartModel, AutoModel 16 | from infosol.alignment import batch_align, get_non_const_ops, sample_actions 17 | from nltk.tokenize import sent_tokenize 18 | 19 | from transformers import AutoModel 20 | model = AutoModel.from_pretrained('facebook/bart-large') 21 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 22 | params = sum([np.prod(p.size()) for p in model_parameters]) 23 | 24 | def tokenize_text(data_inst): 25 | data_inst['target_tokens'] = tokenizer.tokenize(data_inst['target_text']) 26 | data_inst['source_tokens'] = tokenizer.tokenize(data_inst['source_text']) 27 | return data_inst 28 | 29 | def make_len_filter(max_len=32, min_len=0): 30 | def filter_(data_inst): 31 | return len(data_inst['target_tokens']) <= max_len and len(data_inst['target_tokens']) >= min_len 32 | return filter_ 33 | 34 | def align_batch(data_batch, model, tokenizer, baseline_score=0.3, device=torch.device('cpu')): 35 | 
tokens_a = data_batch['source_tokens'] 36 | tokens_b = data_batch['target_tokens'] 37 | alignments = list(batch_align(tokens_a, tokens_b, 38 | model, tokenizer, baseline_score=baseline_score, device=device)) 39 | alignment_scores = [a.scores for a in alignments] 40 | alignments = [a.alignment for a in alignments] 41 | data_batch['alignment'] = alignments 42 | data_batch['alignment_scores'] = alignment_scores 43 | return data_batch 44 | 45 | def is_edit(data_inst): 46 | return len(get_non_const_ops(data_inst['alignment'])) > 0 47 | 48 | def compute_tf_idf_scores(data, tokenizer): 49 | sentences = [datum['target_text'] for datum in data] 50 | cvectorizer = CountVectorizer(analyzer=tokenizer.tokenize) 51 | counts = cvectorizer.fit_transform(sentences) 52 | 53 | document_counts = counts > 0 54 | 55 | doc_counts_sum = document_counts.sum(axis=0) 56 | counts_sum = counts.sum(axis=0) 57 | 58 | idf_scores = {k: -np.log(doc_counts_sum[0,cvectorizer.vocabulary_[k]]) for k in tqdm.tqdm(cvectorizer.vocabulary_)} 59 | idf_scores = {k: np.log(document_counts.shape[0]) + idf_scores[k] for k in idf_scores} 60 | 61 | total = counts_sum.sum() 62 | tf_dict = {k: counts_sum[0, cvectorizer.vocabulary_[k]]/total for k in cvectorizer.vocabulary_} 63 | 64 | return idf_scores, tf_dict 65 | 66 | def add_type_ids(datum): 67 | datum['token_type_ids'] = [0] * len(datum['source_tokens']) 68 | return datum 69 | 70 | def remove_type_ids(datum): 71 | datum.pop('token_type_ids') 72 | return datum 73 | 74 | def tokenize_text(data_inst, tokenizer): 75 | data_inst['target_tokens'] = tokenizer.tokenize(data_inst['target_text']) 76 | data_inst['source_tokens'] = tokenizer.tokenize(data_inst['source_text']) 77 | return data_inst 78 | 79 | def random_subset(data, size, seed=42): 80 | return data.shuffle( 81 | generator=np.random.default_rng(seed) 82 | ).select( 83 | range(size) 84 | ) 85 | 86 | def random_subset_dict(dataset_dict, split_sizes, seed=42): 87 | return DatasetDict({ 88 | n: 
random_subset(dataset_dict[n], s, seed=seed) for n,s in zip(dataset_dict, split_sizes) 89 | }) 90 | 91 | def process_cnn_instances(data_inst, tokenizer): 92 | return_instances = [] 93 | for sent in data_inst['highlights'][0].split('\n'): 94 | return_inst = {} 95 | target_text = sent 96 | target_tokens = tokenizer.tokenize(target_text) 97 | alignment = [('', t) for t in target_tokens] 98 | alignment_scores = [0.3] * len(alignment) 99 | source_tokens = [] 100 | source_text = '' 101 | token_type_ids = [] 102 | return_inst = { 103 | 'id': data_inst['id'][0], 104 | 'target_text': target_text, 105 | 'target_tokens': target_tokens, 106 | 'source_text': source_text, 107 | 'source_tokens': source_tokens, 108 | 'alignment': alignment, 109 | 'alignment_scores': alignment_scores, 110 | 'token_type_ids': token_type_ids 111 | } 112 | return_instances.append(return_inst) 113 | return_instances = {k: [inst[k] for inst in return_instances] for k in return_instances[0]} 114 | return return_instances 115 | 116 | def make_cnn_data(tokenizer): 117 | data = load_dataset('cnn_dailymail', '3.0.0') 118 | data = data.map(lambda x: process_cnn_instances(x, tokenizer), batched=True, batch_size=1, remove_columns=['article', 'highlights', 'id']) 119 | data = data.filter(make_len_filter(max_len=64, min_len=10)) 120 | return data 121 | 122 | if __name__ == '__main__': 123 | 124 | parser = argparse.ArgumentParser() 125 | parser.add_argument('--data_dir', type=str, default='data') # where to save data 126 | args = parser.parse_args() 127 | 128 | bart_tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base') 129 | bart_model = BartModel.from_pretrained('facebook/bart-base').encoder 130 | device = torch.device('cuda') 131 | bart_model = bart_model.to(device) 132 | 133 | cnn_data = load_dataset('ccdv/cnn_dailymail', '3.0.0') 134 | 135 | bart_cnn_data = make_cnn_data(bart_tokenizer) 136 | 137 | save_path = os.path.join(args.data_dir, 'cnn_bart') 138 | bart_cnn_data.save_to_disk(save_path) 139 | 
140 | cnn_bart_idfs, cnn_bart_tfs = compute_tf_idf_scores(bart_cnn_data['train'].select(range(500000)), bart_tokenizer) 141 | 142 | misc_dir = os.path.join(args.data_dir, 'misc') 143 | if not os.path.exists(misc_dir): 144 | os.makedirs(misc_dir) 145 | 146 | bart_idf_save_path = os.path.join(misc_dir, 'cnn_bart_idfs.pickle') 147 | with open(bart_idf_save_path, 'wb') as f: 148 | pickle.dump(cnn_bart_idfs, f) 149 | 150 | bart_tf_save_path = os.path.join(misc_dir, 'cnn_bart_tfs.pickle') 151 | with open(bart_tf_save_path, 'wb') as f: 152 | pickle.dump(cnn_bart_tfs, f) 153 | -------------------------------------------------------------------------------- /scripts/make_eval_jobs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import itertools 4 | import json 5 | import os 6 | import re 7 | 8 | from distutils.util import strtobool 9 | 10 | """ 11 | Expects the following directory structure: 12 | model_dir / 13 | exp_name / 14 | model_name / 15 | WEIGHTS.bin 16 | model_name 17 | exp_name / 18 | data_dir / 19 | cnn / 20 | filtered_bart_64.pickle # cnn dataset 21 | misc / 22 | cnn_bart_idfs.pickle # cnn idf scores 23 | 24 | Will write the argument files under job_dir, and the scripts will write results to out_dir/job_name. Pass the argument files to the run_eval.py script 25 | to run the experiments. 
26 | """ 27 | 28 | command_template = '{} \ 29 | --model_path {} \ 30 | --out_path {} \ 31 | --data_path {} \ 32 | --max_data {} \ 33 | --idf_path {} \ 34 | --n_episodes {} \ 35 | --n_oracle_edits {} \ 36 | --adjacent_ops {} \ 37 | --complete_words {} \ 38 | --contiguous_edits {} \ 39 | --bleu_ngrams {}' 40 | 41 | def job_args_from_job_name(job_name): 42 | if 'bert' in job_name: 43 | if 'cnn' in job_name: 44 | data_path_ = os.path.join(data_dir, 'cnn', 'filtered_bert_64') 45 | idf_path_ = os.path.join(data_dir, 'misc', 'cnn_bert_idfs.pickle') 46 | elif 'yelp' in job_name: 47 | data_path_ = os.path.join(data_dir, 'yelp_pe', 'bert_gen_100') 48 | idf_path_ = os.path.join(data_dir, 'misc', 'yelp_bert_idfs.pickle') 49 | elif 'bart' in job_name: 50 | if 'cnn' in job_name: 51 | data_path_ = os.path.join(data_dir, 'cnn', 'filtered_bart_64') 52 | idf_path_ = os.path.join(data_dir, 'misc', 'cnn_bart_idfs.pickle') 53 | elif 'yelp' in job_name: 54 | data_path_ = os.path.join(data_dir, 'yelp_pe', 'bart_gen_100') 55 | idf_path_ = os.path.join(data_dir, 'misc', 'yelp_bart_idfs.pickle') 56 | 57 | if 'bert' in job_name: 58 | func_ = 'BertEditor' 59 | elif 'bart-s2s' in job_name: 60 | func_ = 'BartS2S' 61 | elif 'bart_editor_large' in job_name: 62 | func_ = 'BartLargeEditor' 63 | elif 'bart' in job_name: 64 | func_ = 'BartEditor' 65 | 66 | return data_path_, idf_path_, func_ 67 | 68 | def make_interactive_eval_job(name, func, model_path, idf_path): 69 | out_dir_ = os.path.join(int_out_dir, name) 70 | if not os.path.exists(out_dir_): 71 | os.makedirs(out_dir_) 72 | n_episodes_ = [1,2,3] 73 | n_oracle_edits_ = [6,3,2] 74 | arg_tuples = [] 75 | 76 | for n_ep, n_oe in zip(n_episodes_, n_oracle_edits_): 77 | out_path_ = os.path.join(out_dir_, f"{n_ep}x{n_oe}") 78 | arg_tuple = ( 79 | func, model_path, out_path_, 80 | data_path, max_data, idf_path, 81 | n_ep, n_oe, 82 | adjacent_ops, complete_words, contiguous_edits, 83 | bleu_ngrams) 84 | arg_tuples.append(arg_tuple) 85 | 86 | job_path = 
os.path.join(int_job_dir, name) 87 | with open(job_path, 'wt') as f: 88 | for a in arg_tuples: 89 | f.write(command_template.format(*a) + '\n') 90 | 91 | if __name__ == '__main__': 92 | 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument('--job_dir', type=str, 95 | default='jobs') # where to write job args 96 | parser.add_argument('--model_dir', type=str, 97 | default='models') # where to find model checkpoints 98 | parser.add_argument('--out_dir', type=str, 99 | default='out') # where evaluation jobs should write results 100 | parser.add_argument('--data_dir', type=str, 101 | default='data') # data dir 102 | args = parser.parse_args() 103 | 104 | # MAIN ============================ 105 | 106 | ## DEFAULT SETTINGS 107 | 108 | job_dir = os.path.join(args.job_dir, 'main') 109 | if not os.path.exists(job_dir): 110 | os.makedirs(job_dir) 111 | 112 | model_dir = args.model_dir 113 | exp_name = 'dagger' 114 | model_name = 'cnn-bart-editor' 115 | 116 | func = 'BartEditor' 117 | model_path = os.path.join(model_dir, exp_name, model_name, 'WEIGHTS.bin') 118 | out_dir = os.path.join(args.out_dir, 'main') 119 | out_path = os.path.join(out_dir, 'main.pickle') 120 | data_dir = args.data_dir 121 | data_path = os.path.join(data_dir, 'cnn_bart') 122 | max_data = 2000 123 | idf_path = os.path.join(data_dir, 'misc', 'cnn_bart_idfs.pickle') 124 | n_episodes = 4 125 | n_oracle_edits = 3 126 | adjacent_ops = True 127 | complete_words = True 128 | contiguous_edits = True 129 | bleu_ngrams = 1 130 | 131 | ## Different Oracles 132 | 133 | out_dir_ = os.path.join(out_dir, 'oracles') 134 | if not os.path.exists(out_dir_): 135 | os.makedirs(out_dir_) 136 | oracle_names = ['unrestricted', 'contiguous', 'adjacent'] 137 | oracle_settings = [(False, False), (False, True), (True, False)] 138 | arg_tuples = [] 139 | for (adjacent_ops_, contiguous_edits_), job_name in zip(oracle_settings, oracle_names): 140 | out_path_ = os.path.join(out_dir_, job_name) 141 | arg_tuple = ( 142 | func, 
model_path, out_path_, 143 | data_path, max_data, idf_path, 144 | n_episodes, n_oracle_edits, 145 | adjacent_ops_, complete_words, contiguous_edits_, 146 | bleu_ngrams) 147 | arg_tuples.append(arg_tuple) 148 | 149 | job_path = os.path.join(job_dir, 'oracles') 150 | with open(job_path, 'wt') as f: 151 | for a in arg_tuples: 152 | f.write(command_template.format(*a) + '\n') 153 | 154 | 155 | ## DAGGER 156 | 157 | model_names = os.listdir(os.path.join(model_dir, exp_name)) 158 | block_list = [] 159 | model_names = [n for n in model_names if n not in block_list] 160 | 161 | out_dir_ = os.path.join(out_dir, exp_name) 162 | if not os.path.exists(out_dir_): 163 | os.makedirs(out_dir_) 164 | arg_tuples = [] 165 | for job_name in model_names: 166 | model_path_ = os.path.join(model_dir, exp_name, job_name, 'WEIGHTS.bin') 167 | out_path_ = os.path.join(out_dir_, job_name) 168 | 169 | data_path_, idf_path_, func_ = job_args_from_job_name(job_name) #<= placeholder because name format incorrect for debug 170 | 171 | arg_tuple = ( 172 | func_, model_path_, out_path_, 173 | data_path_, max_data, idf_path_, 174 | n_episodes, n_oracle_edits, 175 | adjacent_ops, complete_words, contiguous_edits, 176 | bleu_ngrams) 177 | arg_tuples.append(arg_tuple) 178 | #baseline 179 | arg_tuples.append(('Baseline', model_path, os.path.join(out_dir_, 'baseline'), 180 | data_path, max_data, idf_path, n_episodes, n_oracle_edits, 181 | adjacent_ops, complete_words, contiguous_edits, bleu_ngrams)) 182 | 183 | job_strings = [command_template.format(*a) for a in arg_tuples] 184 | 185 | job_path = os.path.join(job_dir, exp_name) 186 | with open(job_path, 'wt') as f: 187 | for line in job_strings: 188 | f.write(line + '\n') 189 | 190 | # Interactive vs. 
One Shot =========================== 191 | 192 | int_job_dir = os.path.join(args.job_dir, 'interactive') 193 | if not os.path.exists(int_job_dir): 194 | os.makedirs(int_job_dir) 195 | int_out_dir = os.path.join(out_dir, 'interactive') 196 | int_model_dir = os.path.join(model_dir, 'dagger') # evaluation done with dagger models 197 | 198 | names = [ 199 | 'cnn-bart-editor', 'cnn-bart_editor_large', 200 | 'cnn-bart-s2s' 201 | ] 202 | funcs = [ 203 | 'BartEditor', 'BartLargeEditor', 204 | 'BartS2S' 205 | ] 206 | idf_names = [ 207 | 'cnn_bart_idfs', 'cnn_bart_idfs', 208 | 'cnn_bart_idfs' 209 | ] 210 | 211 | int_job_args = [] 212 | for n,f,idf_n in zip(names, funcs, idf_names): 213 | mp = os.path.join(int_model_dir, n, 'WEIGHTS.bin') 214 | ip = os.path.join(data_dir, 'misc', idf_n + '.pickle') 215 | int_job_args.append((n,f,mp,ip)) 216 | 217 | for a in int_job_args: 218 | make_interactive_eval_job(*a) 219 | -------------------------------------------------------------------------------- /scripts/run_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from infosol.evaluate import EvaluateBartEditor, EvaluateBertEditor, EvaluateBartS2S, EvaluateBartLarge, EvaluateNoModel 4 | 5 | if __name__ == '__main__': 6 | 7 | meta_parser = argparse.ArgumentParser() 8 | meta_parser.add_argument('--args_path', type=str) 9 | meta_parser.add_argument('--cuda_device', type=int) 10 | meta_args = meta_parser.parse_args() 11 | 12 | parser = argparse.ArgumentParser() 13 | subparsers = parser.add_subparsers() 14 | 15 | parser_barteditor = subparsers.add_parser('BartEditor') 16 | parser_barteditor.set_defaults(func=EvaluateBartEditor) 17 | EvaluateBartEditor.add_args(parser_barteditor) 18 | 19 | parser_berteditor = subparsers.add_parser('BertEditor') 20 | parser_berteditor.set_defaults(func=EvaluateBertEditor) 21 | EvaluateBertEditor.add_args(parser_berteditor) 22 | 23 | parser_barts2s = subparsers.add_parser('BartS2S') 24 | 
parser_barts2s.set_defaults(func=EvaluateBartS2S) 25 | EvaluateBartS2S.add_args(parser_barts2s) 26 | 27 | parser_bartlarge = subparsers.add_parser('BartLargeEditor') 28 | parser_bartlarge.set_defaults(func=EvaluateBartLarge) 29 | EvaluateBartLarge.add_args(parser_bartlarge) 30 | 31 | parser_baseline = subparsers.add_parser('Baseline') 32 | parser_baseline.set_defaults(func=EvaluateNoModel) 33 | EvaluateNoModel.add_args(parser_baseline) 34 | 35 | with open(meta_args.args_path, 'rt') as f: 36 | for i,line in enumerate(f): 37 | print(f'#### On job {i} ####') 38 | args_list = line.strip().split(' ') 39 | args_list.extend(['--cuda_device', str(meta_args.cuda_device)]) 40 | args = parser.parse_args(args_list) 41 | eval_instance = args.func() 42 | eval_instance.setup(args) 43 | eval_instance.run() 44 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name='infosol', version='1.0', packages=find_packages()) 4 | --------------------------------------------------------------------------------